safe-agent.c - safe - password-protected secret keeper
 (HTM) git clone git://git.z3bra.org/safe.git
 (DIR) Log
 (DIR) Files
 (DIR) Refs
 (DIR) README
 (DIR) LICENSE
       ---
      safe-agent.c (5746B)
       ---
#include <sys/resource.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/un.h>

#include <err.h>
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <paths.h>
#include <poll.h>
#include <signal.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <sodium.h>

#include "arg.h"
#include "config.h"

#ifndef __OpenBSD__
	#include "strlcpy.h"
#endif
           27 
           28 struct safe {
           29         int loaded;
           30         uint8_t saltkey[crypto_secretstream_xchacha20poly1305_KEYBYTES + crypto_pwhash_SALTBYTES];
           31 };
           32 
           33 char *argv0;
           34 struct safe s;
           35 char *sockp = NULL;
           36 int verbose = 0;
           37 
           38 void
           39 usage(void)
           40 {
           41         fprintf(stderr, "usage: %s [-hdv] [-t timeout] [-f socket]\n", argv0);
           42         exit(1);
           43 }
           44 
           45 char *
           46 dirname(char *path)
           47 {
           48         static char tmp[PATH_MAX];
           49         char *p = NULL;
           50         size_t len;
           51         snprintf(tmp, sizeof(tmp), "%s", path);
           52         len = strlen(tmp);
           53         for(p = tmp + len; p > tmp; p--)
           54                 if(*p == '/')
           55                         break;
           56 
           57         *p = 0;
           58         return tmp;
           59 }
           60 
           61 ssize_t
           62 xread(int fd, void *buf, size_t nbytes)
           63 {
           64         uint8_t *bp = buf;
           65         ssize_t total = 0;
           66 
           67         while (nbytes > 0) {
           68                 ssize_t n;
           69 
           70                 n = read(fd, &bp[total], nbytes);
           71                 if (n < 0)
           72                         err(1, "read");
           73                 else if (n == 0)
           74                         return total;
           75                 total += n;
           76                 nbytes -= n;
           77         }
           78         return total;
           79 }
           80 
           81 ssize_t
           82 xwrite(int fd, const void *buf, size_t nbytes)
           83 {
           84         const uint8_t *bp = buf;
           85         ssize_t total = 0;
           86 
           87         while (nbytes > 0) {
           88                 ssize_t n;
           89 
           90                 n = write(fd, &bp[total], nbytes);
           91                 if (n < 0)
           92                         err(1, "write");
           93                 else if (n == 0)
           94                         return total;
           95                 total += n;
           96                 nbytes -= n;
           97         }
           98         return total;
           99 }
          100 
          101 int
          102 creatsock(char *sockpath)
          103 {
          104         int sfd;
          105         struct sockaddr_un addr;
          106 
          107         sfd = socket(AF_UNIX, SOCK_STREAM, 0);
          108         if (sfd < 0)
          109                 return -1;
          110 
          111         umask(0177);
          112         memset(&addr, 0, sizeof(addr));
          113         addr.sun_family = AF_UNIX;
          114         strlcpy(addr.sun_path, sockpath, sizeof(addr.sun_path));
          115 
          116         if (bind(sfd, (struct sockaddr *) &addr, sizeof(addr)) < 0)
          117                 return -1;
          118 
          119         if (listen(sfd, 10) < 0)
          120                 return -1;
          121 
          122         return sfd;
          123 }
          124 
          125 void
          126 forgetkey()
          127 {
          128         sodium_memzero(s.saltkey, sizeof(s.saltkey));
          129         s.loaded = 0;
          130         alarm(0);
          131 }
          132 
          133 void
          134 sighandler(int signal)
          135 {
          136         switch (signal) {
          137         case SIGINT:
          138         case SIGTERM:
          139                 if (verbose)
          140                         fprintf(stderr, "unlocking key from memory\n");
          141                 sodium_munlock(s.saltkey, sizeof(s.saltkey));
          142 
          143                 if (verbose)
          144                         fprintf(stderr, "removing socket %s\n", sockp);
          145                 unlink(sockp);
          146                 rmdir(dirname(sockp));
          147                 exit(0);
          148                 /* NOTREACHED */
          149         case SIGALRM:
          150         case SIGUSR1:
          151                 if (verbose)
          152                         fprintf(stderr, "clearing key from memory\n");
          153                 forgetkey();
          154                 break;
          155         }
          156 }
          157 
          158 int
          159 servekey(int timeout)
          160 {
          161         int r, sfd;
          162         ssize_t n;
          163         struct pollfd pfd;
          164 
          165         if (verbose)
          166                 fprintf(stderr, "listening on %s\n", sockp);
          167         sfd = creatsock(sockp);
          168         if (sfd < 0)
          169                 err(1, "%s", sockp);
          170 
          171         s.loaded = 0;
          172 
          173         for (;;) {
          174                 pfd.fd = accept(sfd, NULL, NULL);
          175                 pfd.revents = 0;
          176                 pfd.events = POLLIN;
          177 
          178                 if (s.loaded)
          179                         pfd.events |= POLLOUT;
          180 
          181                 if (pfd.fd < 0)
          182                         err(1, "%s", sockp);
          183 
          184                 if ((r = poll(&pfd, 1, 100)) < 0)
          185                         return r;
          186 
          187                 if (pfd.revents & POLLIN) {
          188                         if (verbose)
          189                                 fprintf(stderr, "reading key from client fd %d\n", pfd.fd);
          190 
          191                         n = xread(pfd.fd, s.saltkey, sizeof(s.saltkey));
          192                         if (n == sizeof(s.saltkey)) {
          193                                 s.loaded = 1;
          194                                 if (verbose) {
          195                                         fprintf(stderr, "key loaded in memory\n");
          196                                         if (timeout > 0)
          197                                                 fprintf(stderr, "setting timeout to %d seconds\n", timeout);
          198                                 }
          199                                 alarm(timeout);
          200                         } else {
          201                                 forgetkey();
          202                                 if (verbose)
          203                                         fprintf(stderr, "failed to load key in memory\n");
          204                         }
          205                 } else if (pfd.revents & POLLOUT) {
          206                         if (verbose)
          207                                 fprintf(stderr, "sending key to client fd %d\n", pfd.fd);
          208 
          209                         xwrite(pfd.fd, s.saltkey, sizeof(s.saltkey));
          210                 }
          211 
          212                 close(pfd.fd);
          213         }
          214 
          215         /* NOTREACHED */
          216         close(sfd);
          217         return -1;
          218 }
          219 
          220 int
          221 main(int argc, char *argv[])
          222 {
          223         pid_t pid;
          224         int fd, timeout = 0, dflag = 0;
          225         size_t dirlen;
          226         char path[PATH_MAX];
          227 
          228         pid = getpid();
          229         strlcpy(path, agent_socktmp, sizeof(path));
          230 
          231         ARGBEGIN {
          232         case 'd':
          233                 dflag = 1;
          234                 break;
          235         case 'f':
          236                 sockp = EARGF(usage());
          237                 break;
          238         case 't':
          239                 timeout = atoi(EARGF(usage()));
          240                 break;
          241         case 'v':
          242                 verbose = 1;
          243                 break;
          244         default:
          245                 usage();
          246         } ARGEND
          247 
          248         sodium_mlock(&s, sizeof(s));
          249 
          250 #ifndef _DEBUG
          251         /* deny core dump as memory contains derivated key */
          252         struct rlimit rlim;
          253         rlim.rlim_cur = rlim.rlim_max = 0;
          254         if (setrlimit(RLIMIT_CORE, &rlim) < 0)
          255                 err(1, "setrlimit RLIMIT_CORE");
          256 #endif
          257 
          258         if (sockp) {
          259                 strlcpy(path, sockp, sizeof(path));
          260         } else {
          261                 if (!mkdtemp(path))
          262                         err(1, "mkdtemp: %s", path);
          263 
          264                 dirlen = strnlen(path, sizeof(path));
          265                 snprintf(path + dirlen, PATH_MAX - dirlen, agent_sockfmt, pid);
          266                 sockp = path;
          267         }
          268 
          269 #ifdef __OpenBSD__
          270         if (unveil(_PATH_DEVNULL, "rw") == -1)
          271                 err(1, "unveil %s", _PATH_DEVNULL);
          272         if (unveil(sockp, "c") == -1)
          273                 err(1, "unveil %s", sockp);
          274         if (pledge("stdio unix proc cpath", NULL) == -1)
          275                 err(1, "pledge");
          276 #endif
          277 
          278         if (dflag) {
          279                 printf("SAFE_PID=%d; export SAFE_PID\n", pid);
          280                 printf("SAFE_SOCK=%s; export SAFE_SOCK\n", sockp);
          281                 fflush(stdout);
          282                 goto skip;
          283         }
          284 
          285         if (verbose)
          286                 fprintf(stderr, "forking agent to the background\n");
          287 
          288         pid = fork();
          289         if (pid < 0)
          290                 err(1, "fork");
          291 
          292         if (pid) {
          293                 if (verbose)
          294                         fprintf(stderr, "agent pid is %d\n", pid);
          295 
          296                 printf("SAFE_PID=%d; export SAFE_PID\n", pid);
          297                 printf("SAFE_SOCK=%s; export SAFE_SOCK\n", sockp);
          298                 return 0;
          299         }
          300 
          301         if (setsid() < 0)
          302                 err(1, "setsid");
          303 
          304         if ((fd = open(_PATH_DEVNULL, O_RDWR, 0)) != -1) {
          305                 (void)dup2(fd, STDIN_FILENO);
          306                 (void)dup2(fd, STDOUT_FILENO);
          307                 (void)dup2(fd, STDERR_FILENO);
          308                 if (fd > 2)
          309                         close(fd);
          310         }
          311 
          312 skip:
          313         pid = getpid();
          314         signal(SIGINT, sighandler);
          315         signal(SIGTERM, sighandler);
          316         signal(SIGUSR1, sighandler);
          317         signal(SIGALRM, sighandler);
          318 
          319         if (sodium_init() < 0)
          320                 return -1;
          321 
          322         if (verbose)
          323                 fprintf(stderr, "locking key in memory\n");
          324 
          325         return servekey(timeout);
          326 }