#include "stegome.h"

/*
 * Program name used as a prefix in diagnostics; assigned from argv[0]
 * at the start of main.  NOTE(review): Plan 9 libc may already provide
 * argv0 (ARGBEGIN uses it) — confirm this definition does not clash.
 */
char *argv0;

/* Global variables */
int debug = 0;  /* Non-zero (set by -d) prints diagnostic traces to fd 2 */

/*
 * usage prints the command synopsis and option summary to standard
 * error, then exits with status "usage".  It never returns.
 */
void
usage(void)
{
    fprint(2,
        "usage: %s [-m method] [-i image] [-o output] [-h|-x] [-d]\n"
        "  -m method: steganography method (lsb or blend, default: lsb)\n"
        "  -i image: carrier image file (default: stdin)\n"
        "  -o output: output file (default: stdout)\n"
        "  -h: hide data (read from stdin)\n"
        "  -x: extract data (write to stdout)\n"
        "  -d: enable debug output\n",
        argv0);
    exits("usage");
}

/*
 * hide_lsb embeds data in the least significant bit of successive bytes
 * of imgbuf, one bit per carrier byte, least significant data bit first.
 *
 * Carrier layout:
 *   bytes [0, 8*sizeof(long))   the length datalen, taken from the
 *                               in-memory bytes of the long (host order)
 *   following 8*datalen bytes   the payload
 *
 * Every byte of imgbuf is treated as carrier (no file header is skipped).
 *
 * On success *outbuf is a newly malloc'd copy of the image with the data
 * embedded (caller frees), *outsize is imgsize, and 0 is returned.
 * On failure -1 is returned with the error string set.
 */
int
hide_lsb(uchar *imgbuf, ulong imgsize, uchar *data, long datalen, uchar **outbuf, ulong *outsize)
{
    uchar *hdr, *out;
    ulong pos;
    long i;
    int b, j;

    if(debug)
        fprint(2, "hide_lsb: imgsize=%lud, datalen=%ld\n", imgsize, datalen);

    if(datalen < 0){
        werrstr("invalid data length %ld", datalen);
        return -1;
    }

    /*
     * Header plus payload need 8 carrier bytes per byte.  Compare in
     * whole bytes so the test cannot overflow or wrap: this accepts
     * exactly when 8*(sizeof(long) + datalen) <= imgsize.
     */
    if(imgsize/8 < sizeof(long) || (ulong)datalen > imgsize/8 - sizeof(long)){
        werrstr("data too large for carrier image");
        return -1;
    }

    /* Allocate buffer for output image data (same size as input) */
    *outbuf = malloc(imgsize);
    if(*outbuf == nil){
        werrstr("failed to allocate buffer");
        return -1;
    }
    *outsize = imgsize;
    out = *outbuf;
    memmove(out, imgbuf, imgsize);

    hdr = (uchar*)&datalen;
    if(debug)
        fprint(2, "hide_lsb: encoding length=%ld (%02x %02x %02x %02x)\n", 
               datalen, hdr[0], hdr[1], hdr[2], hdr[3]);

    /* The capacity check above guarantees every pos stays below imgsize. */
    pos = 0;
    for(b = 0; b < sizeof(long); b++)
        for(j = 0; j < 8; j++, pos++)
            out[pos] = (out[pos] & ~1) | ((hdr[b] >> j) & 1);
    for(i = 0; i < datalen; i++)
        for(j = 0; j < 8; j++, pos++)
            out[pos] = (out[pos] & ~1) | ((data[i] >> j) & 1);

    return 0;
}

/*
 * extract_lsb reverses hide_lsb.  It first rebuilds the length header
 * from the low bit of the first 8*sizeof(long) bytes of imgbuf, then
 * validates it.  If valid, it rebuilds that many payload bytes into outbuf.
 *
 * outbuf must have room for BUFSIZE bytes.  Returns 0 and stores the
 * payload length in *outlen on success.  Returns -1 with the error
 * string set when the decoded length is not in 1..BUFSIZE or the image
 * is too small to hold it.
 */
int
extract_lsb(uchar *imgbuf, ulong imgsize, uchar *outbuf, long *outlen)
{
    long len;
    uchar *lb;
    ulong k, hdrbits;

    if(debug)
        fprint(2, "extract_lsb: imgsize=%lud\n", imgsize);

    len = 0;
    lb = (uchar*)&len;
    hdrbits = 8*sizeof(long);

    /* Header bits beyond the end of a too-small image remain zero. */
    for(k = 0; k < hdrbits && k < imgsize; k++)
        lb[k/8] |= (imgbuf[k] & 1) << (k%8);

    if(debug)
        fprint(2, "extract_lsb: decoded length=%ld (%02x %02x %02x %02x)\n", 
               len, lb[0], lb[1], lb[2], lb[3]);

    if(len <= 0 || len > BUFSIZE || len*8 + sizeof(long)*8 > imgsize){
        werrstr("invalid data length (datalen=%ld, BUFSIZE=%d, imgsize=%lud)", 
                len, BUFSIZE, imgsize);
        return -1;
    }

    *outlen = len;
    memset(outbuf, 0, len);

    /* Validation above guarantees every payload bit lies inside imgbuf. */
    for(k = 0; k < (ulong)len*8; k++)
        outbuf[k/8] |= (imgbuf[hdrbits + k] & 1) << (k%8);

    return 0;
}

/*
 * hide_blend embeds data in byte 3 of every 4-byte pixel (the code and
 * its comments treat it as the alpha channel; the image is assumed to be
 * RGBA-like).  That byte is set to 0xFF for a 1 bit and 0x00 for a 0 bit.
 * Each pixel carries ONE bit, least significant data bit first.
 *
 * Carrier layout, in pixels:
 *   [0, 8*sizeof(long))         the length datalen, taken from the
 *                               in-memory bytes of the long (host order)
 *   following 8*datalen pixels  the payload
 *
 * Every 4-byte group of imgbuf is treated as a pixel (no file header is
 * skipped).
 *
 * On success *outbuf is a newly malloc'd copy of the image (caller
 * frees), *outsize is imgsize, and 0 is returned.  On failure -1 is
 * returned with the error string set.
 */
int
hide_blend(uchar *imgbuf, ulong imgsize, uchar *data, long datalen, uchar **outbuf, ulong *outsize)
{
    uchar *hdr, *out;
    ulong npix, px;
    long i;
    int b, j;

    if(debug)
        fprint(2, "hide_blend: imgsize=%lud, datalen=%ld\n", imgsize, datalen);

    /* For blend method, we need RGBA format (4 bytes per pixel) */
    if(imgsize % 4 != 0){
        werrstr("image must be in RGBA format");
        return -1;
    }

    if(datalen < 0){
        werrstr("invalid data length %ld", datalen);
        return -1;
    }

    /*
     * One bit per pixel: header plus payload need 8 pixels per byte.
     * This must agree with extract_blend's validation
     * (datalen*8 + sizeof(long)*8 <= imgsize/4), or hidden data could
     * not be recovered.  The comparison is in whole bytes to avoid overflow.
     */
    npix = imgsize / 4;
    if(npix/8 < sizeof(long) || (ulong)datalen > npix/8 - sizeof(long)){
        werrstr("data too large for carrier image");
        return -1;
    }

    /* Allocate buffer for output image data */
    *outbuf = malloc(imgsize);
    if(*outbuf == nil){
        werrstr("failed to allocate buffer");
        return -1;
    }
    *outsize = imgsize;
    out = *outbuf;
    memmove(out, imgbuf, imgsize);

    hdr = (uchar*)&datalen;
    if(debug)
        fprint(2, "hide_blend: encoding length=%ld (%02x %02x %02x %02x)\n", 
               datalen, hdr[0], hdr[1], hdr[2], hdr[3]);

    /* The capacity check above guarantees px*4 + 3 < imgsize throughout. */
    px = 0;
    for(b = 0; b < sizeof(long); b++)
        for(j = 0; j < 8; j++, px++)
            out[px*4 + 3] = ((hdr[b] >> j) & 1) ? 0xFF : 0x00;
    for(i = 0; i < datalen; i++)
        for(j = 0; j < 8; j++, px++)
            out[px*4 + 3] = ((data[i] >> j) & 1) ? 0xFF : 0x00;

    return 0;
}

/*
 * extract_blend reverses hide_blend.  It reads one bit per 4-byte pixel
 * from byte 3 of the pixel: exactly 0xFF is a 1 bit, anything else a 0 bit.
 * It first rebuilds the length header from the first 8*sizeof(long)
 * pixels and validates it.  If valid, it rebuilds the payload into outbuf.
 *
 * outbuf must have room for BUFSIZE bytes.  Returns 0 and stores the
 * payload length in *outlen on success.  Returns -1 with the error
 * string set when imgsize is not a multiple of 4, or when the decoded
 * length is not in 1..BUFSIZE or does not fit in the image.
 */
int
extract_blend(uchar *imgbuf, ulong imgsize, uchar *outbuf, long *outlen)
{
    long len;
    uchar *lb;
    ulong k, npix, hdrbits;

    if(debug)
        fprint(2, "extract_blend: imgsize=%lud\n", imgsize);

    if(imgsize % 4 != 0){
        werrstr("image must be in RGBA format");
        return -1;
    }

    len = 0;
    lb = (uchar*)&len;
    npix = imgsize / 4;
    hdrbits = 8*sizeof(long);

    /* Header bits for pixels past the end of a tiny image remain zero. */
    for(k = 0; k < hdrbits && k < npix; k++)
        if(imgbuf[k*4 + 3] == 0xFF)
            lb[k/8] |= 1 << (k%8);

    if(debug)
        fprint(2, "extract_blend: decoded length=%ld (%02x %02x %02x %02x)\n", 
               len, lb[0], lb[1], lb[2], lb[3]);

    if(len <= 0 || len > BUFSIZE || len*8 + sizeof(long)*8 > imgsize/4){
        werrstr("invalid data length (datalen=%ld, BUFSIZE=%d, imgsize=%lud)", 
                len, BUFSIZE, imgsize);
        return -1;
    }

    *outlen = len;
    memset(outbuf, 0, len);

    /* Validation above guarantees every payload pixel lies inside imgbuf. */
    for(k = 0; k < (ulong)len*8; k++)
        if(imgbuf[(hdrbits + k)*4 + 3] == 0xFF)
            outbuf[k/8] |= 1 << (k%8);

    return 0;
}

/*
 * detect_format guesses an image format from the extension of filename
 * (the text after the last '.').  The comparison ignores case, so
 * "photo.JPG" is FMT_JPG and "x.PNG" is FMT_PNG.  A name with no
 * extension, or an unrecognised one, defaults to FMT_BIT (Plan 9 bit format).
 */
int
detect_format(char *filename)
{
    char *ext;

    ext = strrchr(filename, '.');
    if(ext == nil)
        return FMT_BIT;
    ext++;  /* Skip the dot */

    if(cistrcmp(ext, "bit") == 0)
        return FMT_BIT;
    if(cistrcmp(ext, "jpg") == 0 || cistrcmp(ext, "jpeg") == 0)
        return FMT_JPG;
    if(cistrcmp(ext, "png") == 0)
        return FMT_PNG;

    return FMT_BIT;
}

/*
 * read_image loads the whole of filename into a newly malloc'd buffer.
 *
 * On success *buf (caller frees) and *size describe the contents, and
 * the return value is detect_format(filename).  On failure -1 is
 * returned with the error string set, and *buf is nil.
 */
int
read_image(char *filename, uchar **buf, ulong *size)
{
    int fd;
    Dir *d;
    vlong len;

    *buf = nil;

    fd = open(filename, OREAD);
    if(fd < 0){
        werrstr("failed to open file %s: %r", filename);
        return -1;
    }

    d = dirfstat(fd);
    if(d == nil){
        werrstr("failed to stat file %s: %r", filename);
        close(fd);
        return -1;
    }
    len = d->length;
    free(d);

    /* The file length may not fit in a ulong; refuse rather than truncate. */
    if(len < 0 || (vlong)(ulong)len != len){
        werrstr("file %s too large", filename);
        close(fd);
        return -1;
    }
    *size = len;

    *buf = malloc(*size);
    if(*buf == nil){
        werrstr("failed to allocate memory");
        close(fd);
        return -1;
    }

    /* A single read may legitimately return fewer bytes; readn keeps reading. */
    if(readn(fd, *buf, *size) != *size){
        werrstr("failed to read file %s: %r", filename);
        free(*buf);
        *buf = nil;
        close(fd);
        return -1;
    }

    close(fd);
    return detect_format(filename);
}

/*
 * write_image creates (or truncates) filename with mode 0666 and writes
 * size bytes of buf to it.  format is accepted but currently unused.
 * Returns 0 on success, or -1 with the error string set.
 */
int
write_image(char *filename, uchar *buf, ulong size, int format)
{
    int ofd;
    long nw;

    USED(format);

    ofd = create(filename, OWRITE, 0666);
    if(ofd < 0){
        werrstr("failed to create file %s: %r", filename);
        return -1;
    }

    nw = write(ofd, buf, size);
    close(ofd);
    if(nw != size){
        werrstr("failed to write file %s: %r", filename);
        return -1;
    }
    return 0;
}

int
hide_data(char *carrierfile, char *secretfile, char *outputfile, int method)
{
    Biobuf *bf;
    uchar *imgbuf, *secretbuf, *outbuf;
    ulong imgsize, outsize;
    long n, total = 0;
    int fmt;
    
    /* Load carrier image */
    fmt = read_image(carrierfile, &imgbuf, &imgsize);
    if(fmt < 0){
        fprint(2, "%s: couldn't load carrier image %s: %r\n", argv0, carrierfile);
        return -1;
    }
    
    /* Read secret data */
    secretbuf = malloc(BUFSIZE);
    if(secretbuf == nil){
        free(imgbuf);
        fprint(2, "%s: failed to allocate memory\n", argv0);
        return -1;
    }
    
    bf = Bopen(secretfile, OREAD);
    if(bf == nil){
        free(imgbuf);
        free(secretbuf);
        fprint(2, "%s: couldn't open secret file %s: %r\n", argv0, secretfile);
        return -1;
    }
    
    while((n = Bread(bf, secretbuf+total, BUFSIZE-total)) > 0){
        total += n;
        if(total >= BUFSIZE){
            Bterm(bf);
            free(imgbuf);
            free(secretbuf);
            fprint(2, "%s: secret file too large\n", argv0);
            return -1;
        }
    }
    
    Bterm(bf);
    
    /* Hide the data using selected method */
    if(method == STEG_LSB){
        if(hide_lsb(imgbuf, imgsize, secretbuf, total, &outbuf, &outsize) < 0){
            free(imgbuf);
            free(secretbuf);
            fprint(2, "%s: failed to hide data: %r\n", argv0);
            return -1;
        }
    } else if(method == STEG_BLEND){
        if(hide_blend(imgbuf, imgsize, secretbuf, total, &outbuf, &outsize) < 0){
            free(imgbuf);
            free(secretbuf);
            fprint(2, "%s: failed to hide data: %r\n", argv0);
            return -1;
        }
    }
    
    /* Write output image */
    if(write_image(outputfile, outbuf, outsize, fmt) < 0){
        free(imgbuf);
        free(secretbuf);
        free(outbuf);
        fprint(2, "%s: failed to write output image: %r\n", argv0);
        return -1;
    }
    
    free(imgbuf);
    free(secretbuf);
    free(outbuf);
    return 0;
}

/*
 * extract_data recovers data hidden in carrierfile with method
 * (STEG_LSB or STEG_BLEND) and writes it to outputfile.
 *
 * Returns 0 on success.  On any failure (including an unknown method or
 * an error reported when the output buffer is flushed) a diagnostic is
 * printed to standard error and -1 is returned.
 */
int
extract_data(char *carrierfile, char *outputfile, int method)
{
    Biobuf *bf;
    uchar *imgbuf, *outbuf;
    ulong imgsize;
    long outlen;
    int r, rc;

    /* An unknown method must not silently "succeed" with empty output. */
    if(method != STEG_LSB && method != STEG_BLEND){
        fprint(2, "%s: unknown method %d\n", argv0, method);
        return -1;
    }

    if(read_image(carrierfile, &imgbuf, &imgsize) < 0){
        fprint(2, "%s: couldn't load carrier image %s: %r\n", argv0, carrierfile);
        return -1;
    }

    rc = -1;
    outbuf = malloc(BUFSIZE);
    if(outbuf == nil){
        fprint(2, "%s: failed to allocate memory\n", argv0);
        goto out;
    }

    outlen = 0;
    if(method == STEG_LSB)
        r = extract_lsb(imgbuf, imgsize, outbuf, &outlen);
    else
        r = extract_blend(imgbuf, imgsize, outbuf, &outlen);
    if(r < 0){
        fprint(2, "%s: failed to extract data: %r\n", argv0);
        goto out;
    }

    bf = Bopen(outputfile, OWRITE);
    if(bf == nil){
        fprint(2, "%s: couldn't open output file %s: %r\n", argv0, outputfile);
        goto out;
    }
    /* Bwrite buffers, so a write error may only surface at flush time. */
    if(Bwrite(bf, outbuf, outlen) != outlen || Bflush(bf) < 0){
        Bterm(bf);
        fprint(2, "%s: failed to write output data: %r\n", argv0);
        goto out;
    }
    Bterm(bf);

    rc = 0;
out:
    free(imgbuf);
    free(outbuf);
    return rc;
}

void
main(int argc, char *argv[])
{
    char *method = "lsb";
    char *image = nil;
    char *output = nil;
    int hide = 0;
    int extract = 0;
    int fd, n, i;
    uchar *imgbuf;
    ulong imgsize;
    uchar *outbuf;
    ulong outsize;
    uchar data[BUFSIZE];
    long datalen;
    int fmt;
    Biobuf *bf;
    
    argv0 = argv[0]; /* Set argv0 for usage message */
    
    ARGBEGIN {
    case 'm':
        method = EARGF(usage());
        break;
    case 'i':
        image = EARGF(usage());
        break;
    case 'o':
        output = EARGF(usage());
        break;
    case 'h':
        hide = 1;
        break;
    case 'x':
        extract = 1;
        break;
    case 'd':
        debug = 1;
        break;
    default:
        usage();
    } ARGEND;
    
    if(!hide && !extract)
        usage();
    if(hide && extract)
        usage();
    
    /* Read carrier image */
    if(image != nil){
        /* Read from file */
        fmt = detect_format(image);
        
        bf = Bopen(image, OREAD);
        if(bf == nil)
            sysfatal("failed to open input file %s: %r", image);
        
        /* Get file size and allocate buffer */
        Bseek(bf, 0, 2);
        imgsize = Boffset(bf);
        Bseek(bf, 0, 0);
        
        imgbuf = mallocz(imgsize, 1);
        if(imgbuf == nil){
            Bterm(bf);
            sysfatal("failed to allocate memory");
        }
        
        /* Read file */
        if(Bread(bf, imgbuf, imgsize) != imgsize){
            Bterm(bf);
            free(imgbuf);
            sysfatal("failed to read input file %s: %r", image);
        }
        
        Bterm(bf);
    } else {
        /* Read from stdin */
        Biobuf bin;
        
        Binit(&bin, 0, OREAD);
        
        imgbuf = mallocz(BUFSIZE, 1);
        if(imgbuf == nil)
            sysfatal("failed to allocate memory");
        
        imgsize = 0;
        while((n = Bread(&bin, imgbuf+imgsize, BUFSIZE-imgsize)) > 0){
            imgsize += n;
            if(imgsize >= BUFSIZE)
                break;
        }
        
        fmt = FMT_BIT;  /* Default to bit format for stdin */
        Bterm(&bin);
    }
    
    if(hide){
        /* Read data to hide */
        Biobuf bin;
        
        Binit(&bin, 0, OREAD);
        
        datalen = 0;
        while((n = Bread(&bin, data+datalen, BUFSIZE-datalen)) > 0){
            datalen += n;
            if(datalen >= BUFSIZE)
                break;
        }
        
        Bterm(&bin);
        
        if(datalen <= 0)
            sysfatal("no data to hide");
        
        /* Hide data */
        if(strcmp(method, "lsb") == 0){
            if(hide_lsb(imgbuf, imgsize, data, datalen, &outbuf, &outsize) < 0)
                sysfatal("hide_lsb: %r");
        } else if(strcmp(method, "blend") == 0){
            if(hide_blend(imgbuf, imgsize, data, datalen, &outbuf, &outsize) < 0)
                sysfatal("hide_blend: %r");
        } else
            sysfatal("unknown method: %s", method);
        
        /* Write output */
        if(output != nil){
            bf = Bopen(output, OWRITE);
            if(bf == nil)
                sysfatal("failed to create output file %s: %r", output);
                
            if(Bwrite(bf, outbuf, outsize) != outsize){
                Bterm(bf);
                sysfatal("failed to write output file %s: %r", output);
            }
            
            Bterm(bf);
        } else {
            /* Write to stdout */
            Biobuf bout;
            
            Binit(&bout, 1, OWRITE);
            if(Bwrite(&bout, outbuf, outsize) != outsize)
                sysfatal("failed to write output: %r");
                
            Bterm(&bout);
        }
        
        free(outbuf);
    } else {
        /* Extract data */
        if(strcmp(method, "lsb") == 0){
            if(extract_lsb(imgbuf, imgsize, data, &datalen) < 0)
                sysfatal("extract_lsb: %r");
        } else if(strcmp(method, "blend") == 0){
            if(extract_blend(imgbuf, imgsize, data, &datalen) < 0)
                sysfatal("extract_blend: %r");
        } else
            sysfatal("unknown method: %s", method);
        
        /* Write extracted data to stdout */
        Biobuf bout;
        
        Binit(&bout, 1, OWRITE);
        if(Bwrite(&bout, data, datalen) != datalen)
            sysfatal("failed to write output: %r");
            
        Bterm(&bout);
    }
    
    free(imgbuf);
    exits(nil);
} 