#include "common.h"
#include "shmem.h"
#include "sigctl.h"

#include <errno.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/ipc.h>
#include <sys/sem.h>
#include <sys/shm.h>
#include <string.h>
#include <stdio.h>
#include <unistd.h>

// SEM_A and SEM_R don't seem to be defined on Linux.
// When SEM_A is missing, map the semaphore alter/read permission names onto
// the shared memory write/read permission constants (SHM_W / SHM_R).
#ifndef SEM_A
  #define SEM_A SHM_W
  #define SEM_R SHM_R
#endif

// Declare union semun ourselves when the C library does not provide it, as
// described in the semctl man-page. The member layout follows that man-page;
// in particular `array` is a pointer to the values used by GETALL/SETALL.
#if !defined (__GNU_LIBRARY__) || defined(_SEM_SEMUN_UNDEFINED)
  union semun {
    int val;                    // value for SETVAL
    struct semid_ds* buf;       // buffer for IPC_STAT, IPC_SET
    unsigned short int* array;  // array for GETALL, SETALL
    struct seminfo *__buf;      // buffer for IPC_INFO
  };
#endif

// Dummy value handed to semctl as its fourth argument by the calls in this
// file (IPC_RMID and GETVAL).
static semun semun_foo;

// How this module is working:
//   The server process creates two IPC objects.
//     1) The shared memory area
//     2) A semaphore set with two semaphores.
//        The first semaphore is used to ensure exclusive access to the area
//        The second semaphore is used to find out if the server has died.
//        More precisely, it is incremented by one by the server process with
//        the SEM_UNDO flag set. So when it becomes zero the server has died.

//-----------------------------------------------------------------------------

// Detach the shared memory area and remove both IPC objects (shared memory
// area and semaphore set) that this server holds. Each handle is reset after
// it has been released, so calling this again does nothing for that handle.
// Signals are blocked for the duration of the cleanup.
void ShMemServer::releaseipc() {
  blockSignals();

  // Detach the mapped area:
  if(memory) {
    shmdt((char*)memory);
    memory=0;
  }

  // Remove the shared memory area:
  if(shmid!=-1) {
    shmctl(shmid,IPC_RMID,0);
    shmid=-1;
  }

  // Remove the semaphore set:
  if(semid!=-1) {
    semctl(semid,0,IPC_RMID,semun_foo);
    semid=-1;
  }

  unblockSignals();
}

char* ShMemServer::init(const char* file, int size, ShMemT*& Lock) {
  blockSignals(); // block because releaseipc is called on termination by signals
  char* c=int_init(file,size,Lock);
  if(c) releaseipc();
  unblockSignals();
  return c;
}
  
// Create the semaphore set and the shared memory area keyed on `file`, mark
// the server as running, and attach the area with the memory lock held.
//
//   file - path used with ftok() to derive the IPC key; its group/other
//          permission bits determine the permissions of the IPC objects
//   size - size passed to shmget() for the shared memory area
//   Lock - receives the attached memory area on success
//
// Returns 0 on success. On failure returns a pointer to errmsg holding a
// description; semid/shmid may already be set at that point and are not
// released here.
char* ShMemServer::int_init(const char* file, int size, ShMemT*& Lock) {
  // Try to get ipc identifier for file:
  key_t ipcid=ftok(file,0);
  if(ipcid==-1) { snprintf(errmsg,sizeof(errmsg),"%s: %s",file,strerror(errno)); return errmsg; }
  
  // Try to stat the file:
  struct stat sbuf;
  if(stat(file,&sbuf)==-1) { snprintf(errmsg,sizeof(errmsg),"%s: %s",file,strerror(errno)); return errmsg; }
  
  // Find permissions for shared memory and semaphore on basis of permissions of file:
  int grprd=sbuf.st_mode & S_IRGRP;
  int grpwr=sbuf.st_mode & S_IWGRP;
  int wrlrd=sbuf.st_mode & S_IROTH;
  int wrlwr=sbuf.st_mode & S_IWOTH;
    
  // The owner always gets read/write; the owner bits are shifted down by 3
  // and 6 to form the corresponding group and world bits.
  int shm_perm=SHM_R|SHM_W|
               (((grprd ? SHM_R : 0)|(grpwr ? SHM_W : 0))>>3)|
               (((wrlrd ? SHM_R : 0)|(wrlwr ? SHM_W : 0))>>6);
  // For the semaphore set, read or write permission on the file grants both
  // read and alter permission.
  int sem_perm=SEM_R|SEM_A|
               (((grprd||grpwr) ? (SEM_R|SEM_A) : 0)>>3)|
               (((wrlrd||wrlwr) ? (SEM_R|SEM_A) : 0)>>6);

  //---------  
  // Create a semaphore set.
  // BUGS: The construction below is not safe in the situation where two
  //       servers are started at the same time.
  
  semid=semget(ipcid,2,IPC_CREAT|IPC_EXCL|sem_perm);
  if(semid==-1) {
    if(errno!=EEXIST) {
      snprintf(errmsg,sizeof(errmsg),"Could not create semaphore: %s",strerror(errno)); return errmsg;
    }
    
    // The semaphore set already existed - try to get its ID:
    int si=semget(ipcid,0,0);
    if(si==-1) {
      snprintf(errmsg,sizeof(errmsg),"Could not get id for semaphore set: %s",strerror(errno)); return errmsg;
    }
    
    // See if a server is running:
    if(semctl(si,1,GETVAL,semun_foo)==1) {
      snprintf(errmsg,sizeof(errmsg),"The daemon is already running"); return errmsg;
    }
      
    // An old server has probably terminated without cleaning up.
    // Try to kill semaphore set:
    if(semctl(si,0,IPC_RMID,semun_foo)==-1) {
      snprintf(errmsg,sizeof(errmsg),"Semaphore existed and could not be deleted: %s",strerror(errno)); return errmsg;
    }      
          
    // And retry to create it:
    semid=semget(ipcid,2,IPC_CREAT|IPC_EXCL|sem_perm);
    if(semid==-1) {
      snprintf(errmsg,sizeof(errmsg),"Semaphore could not be recreated: %s",strerror(errno)); return errmsg;
    }
  }

  //---------------
  // Set semaphoreset to indicate that we are running and that the memory area
  // is locked. Since this is an atomic operation we can be sure that no
  // client will try to use the shared memory before it is again released.
  static sembuf s[2]={
    1,1,SEM_UNDO, // indicate we are running
    0,1,SEM_UNDO  // and lock the memory area
  };
  
  if(semop(semid,s,2)==-1) {
    snprintf(errmsg,sizeof(errmsg),"Could not manipulate semaphore: %s",strerror(errno)); return errmsg;
  }      
  
  //---------------
  // Create a shared memory area with the given id:
  shmid=shmget(ipcid,size,IPC_CREAT|IPC_EXCL|shm_perm);                               
  
  if(shmid==-1) {
    if(errno!=EEXIST) {
      snprintf(errmsg,sizeof(errmsg),"Could not create shared memory area: %s",strerror(errno)); return errmsg;
    }
    
    // The area existed - look it up and mark it for deletion. The lookup
    // result is checked so that shmctl is never called with an invalid id.
    int oldid=shmget(ipcid,0,0);
    if(oldid==-1 || shmctl(oldid,IPC_RMID,0)==-1) {
      snprintf(errmsg,sizeof(errmsg),"Could not mark shared memory area for deletion: %s",strerror(errno)); return errmsg;
    }
        
    // And retry to create it:
    shmid=shmget(ipcid,size,IPC_CREAT|IPC_EXCL|shm_perm);
    if(shmid==-1) {
      snprintf(errmsg,sizeof(errmsg),"Could not recreate shared memory area: %s",strerror(errno)); return errmsg;
    }                    
  }    
  
  // Attach the shared memory area to the memory.
  // shmat reports failure with (void*)-1, not with a null pointer, so test
  // for that value and only store the address once it is known to be valid.
  void* addr=shmat(shmid,0,0);
  if(addr==(void*)-1) {
    snprintf(errmsg,sizeof(errmsg),"Could not attach shared memory area: %s",strerror(errno)); return errmsg;
  }
  memory=(ShMemT*)addr;
  
  // And lock it:
  locked=1;
  Lock=memory;    

  return 0;
}

// Acquire the memory lock via wait(). Returns the attached memory area, or 0
// when wait() reports failure.
ShMemT* ShMemServer::Lock() {
  const int rc=wait();
  return (rc==-1) ? 0 : memory;
}

// Release the memory lock by calling signal(). The pointer argument is
// unused.
void ShMemServer::UnLock(ShMemT*) {
  this->signal();
}

void ShMemServer::signalHandler(void* s) {
  ((ShMemServer*)s)->releaseipc();
}
  
// Start with no IPC objects held and register signalHandler so that
// releaseipc runs for this instance on signalTerm.
ShMemServer::ShMemServer() {
  // Nothing attached or created yet:
  memory=0;
  shmid=-1;
  semid=-1;

  // Install signal handler:
  registerSignalHandler(this,signalHandler,signalTerm);
}

// Release all IPC objects and unregister the signal handler. Both steps run
// with signals blocked.
ShMemServer::~ShMemServer() {
  blockSignals();

  // Free the IPC objects first ...
  releaseipc();

  // ... then uninstall the signal handler:
  unregisterSignalHandler(this,signalHandler);

  unblockSignals();
}

//-----------------------------------------------------------------------------

// Look up the semaphore set and shared memory area keyed on `file` and check
// that the server is still marked as running.
//
//   file   - path used with ftok() to derive the IPC key
//   rdonly - stored in `readonly`; Lock() uses it to decide whether the
//            area is attached read-only
//
// Returns 0 on success, or a pointer to errmsg describing the failure.
char* ShMemClient::init(const char* file, int rdonly) {
  readonly=rdonly;
 
  // Try to get identifier for file:
  key_t ipcid=ftok(file,0);
  if(ipcid==-1) { snprintf(errmsg,sizeof(errmsg),"%s: %s",file,strerror(errno)); return errmsg; }

  // Then try to get the semaphore:
  semid=semget(ipcid,1,0);
  if(semid==-1) { snprintf(errmsg,sizeof(errmsg),"Could not get semaphore: %s. Is daemon running?",strerror(errno)); return errmsg; }
  
  // See if the server still runs. Query the set through the id validated
  // above instead of looking it up a second time without checking the result.
  int i=semctl(semid,1,GETVAL,semun_foo);
  if(i==0) {
    snprintf(errmsg,sizeof(errmsg),"It look likes the daemon has been killed"); return errmsg;
  } else if(i==-1) {
    snprintf(errmsg,sizeof(errmsg),"Could not read semaphore value: %s",strerror(errno)); return errmsg;
  }
    
  // Try to get the shared memory area:
  shmid=shmget(ipcid,0,0);
  if(shmid==-1) { snprintf(errmsg,sizeof(errmsg),"Could not get shared memory: %s. Is daemon running?",strerror(errno)); return errmsg; }  
    
  return 0;
}  

// Acquire the memory lock: wait until semaphore 0 is zero and then increment
// it, both in one semop call. The call is retried when interrupted (EINTR).
// Returns the semop result; `locked` is set only when the call succeeded.
inline int ShMemClient::wait() {
  assert(!locked);

  static sembuf ops[2]={
    {0,0,SEM_UNDO}, // wait
    {0,1,SEM_UNDO}  // then increment
  };

  int rc;
  for(;;) {
    rc=semop(semid,ops,2);
    if(rc!=-1 || errno!=EINTR) break;
  }

  // On failure the lock was not taken; maybe the semaphore was deleted by a
  // new server.
  if(rc!=-1) locked=1;
  return rc;
}  

// Release the memory lock by decrementing semaphore 0. `locked` is cleared
// before the semop call; the semop result is returned.
inline int ShMemClient::signal() {
  assert(locked);

  static sembuf release[1]={
    {0,-1,SEM_UNDO}
  };

  locked=0;
  const int rc=semop(semid,release,1);
  return rc;
}

// Lock the shared memory area and attach it (read-only when `readonly` is
// set). Returns the attached area, or 0 when the server is not marked as
// running, the lock could not be taken, or the area could not be attached.
// On success signals stay blocked until UnLock() is called.
ShMemT* ShMemClient::Lock() {
  // Return if server has died:
  if(semctl(semid,1,GETVAL,semun_foo)!=1) return 0;  

  if(wait()==-1) return 0;    
  mem=(ShMemT*)shmat(shmid,0,readonly ? SHM_RDONLY : 0);
  if(mem==0 || mem==(ShMemT*)-1) { // If it couldn't be attached
    mem=0;
    // wait() has already taken the lock; give it back so that a failed
    // attach neither keeps the semaphore held nor leaves `locked` set.
    signal();
    return 0;
  }

  // Block for signals so we leave memory in consistent state if killed:
  blockSignals();
  
  return mem;
}
  
// Undo Lock(): release the memory lock, detach the area, and unblock the
// signals that Lock() blocked. `a` must be the pointer Lock() returned.
void ShMemClient::UnLock(ShMemT* a) {
  assert(a==mem);

  // Give the lock back, then detach the area:
  signal();
  shmdt((char*)mem);

  // Unblock signals:
  unblockSignals();
}
