tftp.c - sbase - suckless unix tools
 (HTM) git clone git://git.suckless.org/sbase
 (DIR) Log
 (DIR) Files
 (DIR) Refs
 (DIR) README
 (DIR) LICENSE
       ---
       tftp.c (5791B)
       ---
            1 /* See LICENSE file for copyright and license details. */
            2 #include <sys/time.h>
            3 #include <sys/types.h>
            4 #include <sys/socket.h>
            5 
            6 #include <netdb.h>
            7 #include <netinet/in.h>
            8 
            9 #include <errno.h>
           10 #include <stdio.h>
           11 #include <stdlib.h>
           12 #include <string.h>
           13 #include <unistd.h>
           14 
           15 #include "util.h"
           16 
           /* packet framing: a 4-byte header (opcode + block number) and up to
 * BLKSIZE bytes of payload per DATA packet */
enum {
	BLKSIZE = 512,
	HDRSIZE = 4,
	PKTSIZE = BLKSIZE + HDRSIZE,
};

/* receive timeout; transfer will time out after NRETRIES * TIMEOUT_SEC */
enum {
	TIMEOUT_SEC = 5,
	NRETRIES = 5,
};

/* packet opcodes; WWQ is the write request opcode (WRQ) */
enum {
	RRQ  = 1,
	WWQ  = 2,
	DATA = 3,
	ACK  = 4,
	ERR  = 5,
};
           30 
           /*
 * Messages for TFTP error codes 0-7, indexed by the code carried in an
 * ERR packet.  The table and the literals it points to are read-only.
 */
static const char *const errtext[] = {
	"Undefined",
	"File not found",
	"Access violation",
	"Disk full or allocation exceeded",
	"Illegal TFTP operation",
	"Unknown transfer ID",
	"File already exists",
	"No such user"
};
           41 
           /* UDP socket and the peer address packets are sent to; recvfrom()
 * overwrites to/tolen with the sender of each received packet */
static int s;
static struct sockaddr_storage to;
static socklen_t tolen;

/* current protocol state: RRQ, WWQ, DATA, ACK or ERR */
static int state;
/* consecutive failed receives; reset after every successful one */
static int timeout;
           47 
           /*
 * Build a request packet in buf: a 2-byte big-endian opcode followed by
 * path and mode, each including its terminating NUL.  Returns the number
 * of bytes written.  Exits via eprintf() if path plus its NUL exceeds 256
 * bytes.  buf must have room for 2 + strlen(path) + strlen(mode) + 2 bytes.
 */
static int
packreq(unsigned char *buf, int op, const char *path, const char *mode)
{
	unsigned char *p = buf;
	/* lengths include the NUL terminator, which goes on the wire */
	size_t pathlen = strlen(path) + 1;
	size_t modelen = strlen(mode) + 1;

	*p++ = op >> 8;
	*p++ = op & 0xff;
	if (pathlen > 256)
		eprintf("filename too long\n");
	memcpy(p, path, pathlen);
	p += pathlen;
	memcpy(p, mode, modelen);
	p += modelen;
	return p - buf;
}
           63 
           64 static int
           65 packack(unsigned char *buf, int blkno)
           66 {
           67         buf[0] = ACK >> 8;
           68         buf[1] = ACK & 0xff;
           69         buf[2] = blkno >> 8;
           70         buf[3] = blkno & 0xff;
           71         return 4;
           72 }
           73 
           74 static int
           75 packdata(unsigned char *buf, int blkno)
           76 {
           77         buf[0] = DATA >> 8;
           78         buf[1] = DATA & 0xff;
           79         buf[2] = blkno >> 8;
           80         buf[3] = blkno & 0xff;
           81         return 4;
           82 }
           83 
           /* Decode the big-endian 16-bit opcode from the first two bytes of buf. */
static int
unpackop(unsigned char *buf)
{
	int hi = buf[0];
	int lo = buf[1];

	return hi * 256 + lo;
}
           89 
           /* Decode the big-endian 16-bit block number from bytes 2-3 of buf. */
static int
unpackblkno(unsigned char *buf)
{
	int hi = buf[2];
	int lo = buf[3];

	return hi * 256 + lo;
}
           95 
           96 static int
           97 unpackerrc(unsigned char *buf)
           98 {
           99         int errc;
          100 
          101         errc = (buf[2] << 8) | (buf[3] & 0xff);
          102         if (errc < 0 || errc >= LEN(errtext))
          103                 eprintf("bad error code: %d\n", errc);
          104         return errc;
          105 }
          106 
          107 static int
          108 writepkt(unsigned char *buf, int len)
          109 {
          110         int n;
          111 
          112         n = sendto(s, buf, len, 0, (struct sockaddr *)&to,
          113                    tolen);
          114         if (n < 0)
          115                 if (errno != EINTR)
          116                         eprintf("sendto:");
          117         return n;
          118 }
          119 
          120 static int
          121 readpkt(unsigned char *buf, int len)
          122 {
          123         int n;
          124 
          125         n = recvfrom(s, buf, len, 0, (struct sockaddr *)&to,
          126                      &tolen);
          127         if (n < 0) {
          128                 if (errno != EINTR && errno != EWOULDBLOCK)
          129                         eprintf("recvfrom:");
          130                 timeout++;
          131                 if (timeout == NRETRIES)
          132                         eprintf("transfer timed out\n");
          133         } else {
          134                 timeout = 0;
          135         }
          136         return n;
          137 }
          138 
          139 static void
          140 getfile(char *file)
          141 {
          142         unsigned char buf[PKTSIZE];
          143         int n, op, blkno, nextblkno = 1, done = 0;
          144 
          145         state = RRQ;
          146         for (;;) {
          147                 switch (state) {
          148                 case RRQ:
          149                         n = packreq(buf, RRQ, file, "octet");
          150                         writepkt(buf, n);
          151                         n = readpkt(buf, sizeof(buf));
          152                         if (n > 0) {
          153                                 op = unpackop(buf);
          154                                 if (op != DATA && op != ERR)
          155                                         eprintf("bad opcode: %d\n", op);
          156                                 state = op;
          157                         }
          158                         break;
          159                 case DATA:
          160                         n -= HDRSIZE;
          161                         if (n < 0)
          162                                 eprintf("truncated packet\n");
          163                         blkno = unpackblkno(buf);
          164                         if (blkno == nextblkno) {
          165                                 nextblkno++;
          166                                 write(1, &buf[HDRSIZE], n);
          167                         }
          168                         if (n < BLKSIZE)
          169                                 done = 1;
          170                         state = ACK;
          171                         break;
          172                 case ACK:
          173                         n = packack(buf, blkno);
          174                         writepkt(buf, n);
          175                         if (done)
          176                                 return;
          177                         n = readpkt(buf, sizeof(buf));
          178                         if (n > 0) {
          179                                 op = unpackop(buf);
          180                                 if (op != DATA && op != ERR)
          181                                         eprintf("bad opcode: %d\n", op);
          182                                 state = op;
          183                         }
          184                         break;
          185                 case ERR:
          186                         eprintf("error: %s\n", errtext[unpackerrc(buf)]);
          187                 }
          188         }
          189 }
          190 
          191 static void
          192 putfile(char *file)
          193 {
          194         unsigned char inbuf[PKTSIZE], outbuf[PKTSIZE];
          195         int inb, outb, op, blkno, nextblkno = 0, done = 0;
          196 
          197         state = WWQ;
          198         for (;;) {
          199                 switch (state) {
          200                 case WWQ:
          201                         outb = packreq(outbuf, WWQ, file, "octet");
          202                         writepkt(outbuf, outb);
          203                         inb = readpkt(inbuf, sizeof(inbuf));
          204                         if (inb > 0) {
          205                                 op = unpackop(inbuf);
          206                                 if (op != ACK && op != ERR)
          207                                         eprintf("bad opcode: %d\n", op);
          208                                 state = op;
          209                         }
          210                         break;
          211                 case DATA:
          212                         if (blkno == nextblkno) {
          213                                 nextblkno++;
          214                                 packdata(outbuf, nextblkno);
          215                                 outb = read(0, &outbuf[HDRSIZE], BLKSIZE);
          216                                 if (outb < BLKSIZE)
          217                                         done = 1;
          218                         }
          219                         writepkt(outbuf, outb + HDRSIZE);
          220                         inb = readpkt(inbuf, sizeof(inbuf));
          221                         if (inb > 0) {
          222                                 op = unpackop(inbuf);
          223                                 if (op != ACK && op != ERR)
          224                                         eprintf("bad opcode: %d\n", op);
          225                                 state = op;
          226                         }
          227                         break;
          228                 case ACK:
          229                         if (inb < HDRSIZE)
          230                                 eprintf("truncated packet\n");
          231                         blkno = unpackblkno(inbuf);
          232                         if (blkno == nextblkno)
          233                                 if (done)
          234                                         return;
          235                         state = DATA;
          236                         break;
          237                 case ERR:
          238                         eprintf("error: %s\n", errtext[unpackerrc(inbuf)]);
          239                 }
          240         }
          241 }
          242 
          243 static void
          244 usage(void)
          245 {
          246         eprintf("usage: %s -h host [-p port] [-x | -c] file\n", argv0);
          247 }
          248 
          /*
 * Parse options, resolve host/port to a UDP endpoint, open a socket with a
 * TIMEOUT_SEC receive timeout, then transfer argv[0].  By default (or with
 * -x) the file is fetched to stdout; with -c stdin is uploaded as the file.
 */
int
main(int argc, char *argv[])
{
	struct addrinfo hints, *res, *ai;
	struct timeval tv;
	char *host = NULL, *port = "tftp";
	void (*xfer)(char *) = getfile;
	int err;

	ARGBEGIN {
	case 'h':
		host = EARGF(usage());
		break;
	case 'p':
		port = EARGF(usage());
		break;
	case 'x':
		xfer = getfile;
		break;
	case 'c':
		xfer = putfile;
		break;
	default:
		usage();
	} ARGEND

	if (host == NULL || argc == 0)
		usage();

	/* resolve to UDP endpoints of any address family */
	memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_DGRAM;
	hints.ai_protocol = IPPROTO_UDP;
	if ((err = getaddrinfo(host, port, &hints, &res)) != 0)
		eprintf("getaddrinfo: %s\n", gai_strerror(err));

	/* use the first IPv4/IPv6 result a socket can be opened for */
	for (ai = res; ai != NULL; ai = ai->ai_next) {
		if (ai->ai_family != AF_INET && ai->ai_family != AF_INET6)
			continue;
		if ((s = socket(ai->ai_family, ai->ai_socktype,
		                ai->ai_protocol)) >= 0)
			break;
	}
	if (ai == NULL)
		eprintf("cannot create socket\n");
	memcpy(&to, ai->ai_addr, ai->ai_addrlen);
	tolen = ai->ai_addrlen;
	freeaddrinfo(res);

	/* bound each receive so readpkt() can count retries */
	tv.tv_sec = TIMEOUT_SEC;
	tv.tv_usec = 0;
	if (setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) < 0)
		eprintf("setsockopt:");

	xfer(argv[0]);
	return 0;
}