#include <sys/stat.h>
#include <err.h>
#include <fcntl.h>
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <lz4.h>
#include <openssl/sha.h>

#include "arg.h"
#include "tree.h"

/* Repository files, relative to the working directory (see -r in main) */
#define INDEXF ".index"
#define STOREF ".store"
#define CACHEF ".cache"

#define BLKSIZE 4096			/* max. uncompressed block (chunk) size */
#define WINSIZE 511			/* rolling-hash window; min. chunk size */
#define HASHMSK ((1ul << 10) - 1)	/* chunk where (hash & HASHMSK) == 0 */
#define MSGSIZE 256			/* size of the per-entry message buffer */
#define MDSIZE SHA256_DIGEST_LENGTH	/* digest length in bytes */

/*
 * Rotate the 32-bit unsigned value x left by y bits, 0 <= y < 32.
 * Both shift counts are masked to 0..31: without the mask, y == 0 would
 * shift by 32, which is undefined behaviour for a 32-bit operand.
 */
#define ROTL(x, y) (((x) << ((y) & 31)) | ((x) >> ((32 - (y)) & 31)))

/* Return values of walk() callbacks */
enum {
	WALK_CONTINUE = 0,	/* go on with the next index entry */
	WALK_STOP = 1		/* end the walk after this entry */
};

/*
 * Cumulative statistics.  Embedded in struct enthdr, which is written to
 * the index file verbatim, so the field layout is part of the on-disk
 * format.
 */
struct stats {
	uint64_t orig_size;	/* input bytes read by dedup() */
	uint64_t comp_size;	/* compressed bytes, duplicates included */
	uint64_t dedup_size;	/* compressed bytes actually appended to the store */
	uint64_t cache_hits;	/* blocks already present in the cache */
	uint64_t cache_misses;	/* blocks that had to be stored */
};

/*
 * Index file header.  Stored verbatim at offset 0 of INDEXF and rewritten
 * by append_ent(); the field layout is the on-disk format.
 */
struct enthdr {
	uint64_t flags;		/* not read or written by the code in this file */
	uint64_t nents;		/* number of index entries following the header */
	uint64_t store_size;	/* bytes used in STOREF; new blocks go here */
	struct stats st;
};

/*
 * Block descriptor: locates one compressed block in STOREF.  Written
 * verbatim inside index entries and as the records of the cache file.
 */
struct bdescr {
	uint8_t md[MDSIZE];	/* SHA-256 of the compressed block */
	uint64_t offset;	/* byte offset of the block in STOREF */
	uint64_t size;		/* compressed size in bytes */
};

/*
 * Index file entry.  append_ent() writes the fixed part immediately
 * followed by nblks block descriptors and walk() reads it back the same
 * way, so the layout is the on-disk format.
 */
struct ent {
	uint64_t size;	/* on-disk length: fixed part + nblks descriptors */
	uint8_t msg[MSGSIZE];	/* user message, NUL-terminated by dedup(); may be empty */
	uint8_t md[MDSIZE];	/* entry id: SHA-256 computed by dedup(), matched by extract() */
	uint64_t nblks;	/* number of descriptors in bdescr[] */
	struct bdescr bdescr[];	/* the blocks making up the data, in order */
};

/* In-memory cache entry: one known block, ordered by bdescr.md (cent_cmp) */
struct cent {
	struct bdescr bdescr;
	RB_ENTRY(cent) e;	/* red-black tree linkage in cache_head */
};

/* Argument that main() passes to the extract() walk callback */
struct extract_args {
	uint8_t *md;	/* id (ent->md) of the entry to extract, MDSIZE bytes */
	int fd;		/* descriptor the decompressed data is written to */
};

RB_HEAD(cache, cent) cache_head;	/* in-memory block cache, keyed by md */
struct enthdr enthdr;	/* in-memory copy of the index header */
int ifd;		/* index file (INDEXF) */
int sfd;		/* block store (STOREF) */
int cfd;		/* cache file (CACHEF) */
int verbose;		/* -v: print statistics to stderr */
int cache_dirty;	/* set when entries are added; flush_cache() writes only then */
char *argv0;		/* program name for usage(); presumably set by ARGBEGIN (arg.h) */

/*
 * Static table for the buzhash algorithm: 256 randomly generated 32-bit
 * integers, indexed by byte value.
 *
 * According to the original author the integers are unique and, taken
 * together, contain the same number of 0 and 1 bits (not re-verified
 * here).  With balanced bits, XORing an entry into the hash flips each
 * hash bit with roughly 50% probability, which gives better
 * pseudo-random results.
 */
uint32_t buz[] = {
	0xbc9fa594,0x30a8f827,0xced627a7,0xdb46a745,0xcfa4a9e8,0x77cccb59,0xddb66276,0x3adc532f,
	0xfe8b67d3,0x8155b59e,0x0c893666,0x1d757009,0x17394ee4,0x85d94c07,0xcacd52da,0x076c6f79,
	0xead0a798,0x6c7ccb4a,0x2639a1b8,0x3aa5ae32,0x3e6218d2,0xb290d980,0xa5149521,0x4b426119,
	0xd3230fc7,0x677c1cc4,0x2b64603c,0x01fe92a8,0xbe358296,0xa7e7fac7,0xf509bf41,0x04b017ad,
	0xf900344c,0x8e14e202,0xb2a6e9b4,0x3db3c311,0x960286a8,0xf6bf0468,0xed54ec94,0xf358070c,
	0x6a4795dd,0x3f7b925c,0x5e13a060,0xfaecbafe,0x03c8bb55,0x8a56ba88,0x633e3b49,0xe036bbbe,
	0x1ed3dbb5,0x76e8ad74,0x79d346ab,0x44b4ccc4,0x71eb22d3,0xa1aa3f24,0x50e05b81,0xa3b450d3,
	0x7f5caffb,0xa1990650,0x54c44800,0xda134b65,0x72362eea,0xbd12b8e6,0xf7c99fdc,0x020d48c7,
	0x9d9c3d46,0x32b75615,0xe61923cf,0xadc09d8f,0xab11376b,0xd66fe4cd,0xb3b086b6,0xb8345b9f,
	0x59029667,0xae0e937c,0xcbd4d4ba,0x720bb3fb,0x5f7d2ca3,0xec24ba15,0x6b40109b,0xf0a54587,
	0x3acf9420,0x466e981d,0xc66dc124,0x150ef7b4,0xc3ce718e,0x136774f5,0x46684ab4,0xb4b490f0,
	0x26508a8b,0xf12febc8,0x4b99171b,0xfc373c84,0x339b5677,0x41703ff3,0x7cadbbd7,0x15ea24e2,
	0x7a2f9783,0xed6a383a,0x649eb072,0x79970941,0x2abd28ad,0x4375e00c,0x9df084f7,0x6fdeec6c,
	0x6619ac6d,0x7d256f4d,0x9b8e658a,0x3d7627e9,0xd5a98d45,0x15f84223,0x9b6acef5,0xf876be67,
	0xe3ae7089,0x84e2b64a,0x6818a969,0x86e9ba4e,0xa24a5b57,0x61570cf1,0xa5f8fc91,0x879d8383,
	0x91b13866,0x75e87961,0x16db8138,0x5a2ff6b8,0x8f664e9b,0x894e1496,0x88235c5b,0xcdb3b580,
	0xa2e80109,0xb0f88a82,0xd12cd340,0x93fbc37d,0xf4d1eb82,0xce42f309,0x16ffd2c2,0xb4dfef2b,
	0xb8b1a33e,0x4708a5e6,0xba66dd88,0xa9ec0da6,0x6f8ee2c9,0xad8b9993,0x1d6a25a8,0x1f3d08ce,
	0x149c04e7,0x5cd1fa51,0xb84c89c7,0xeced6f8c,0xe328b30f,0x084fa836,0x6d1bb1b7,0x94c78ea5,
	0x14973034,0xf1a1bcef,0x48b798d2,0xded9ca9e,0x5fd965d0,0x92544eb1,0x5e80f189,0xcbbf5e15,
	0x4d8121f0,0x5dd3b92f,0xd9ea98fb,0x2dbf5644,0x0fbcb9b7,0x20a1db53,0x7c3fcc98,0x36744fbd,
	0xced08954,0x8e7c5efe,0x3c5f6733,0x657477be,0x3630a02d,0x38bcbda0,0xb7702575,0x4a7f4bce,
	0x0e7660fe,0x4dcb91b5,0x4fd7ffd3,0x041821c1,0xa846a181,0xc8048e9e,0xd4b05072,0x986e0509,
	0xa00aaeeb,0x02e3526a,0x2fac4843,0xfa98e805,0x923ecd8d,0x395d9546,0x8674c3cd,0xae5a8a71,
	0x966dfe45,0x5c9ceba5,0x0830a1cf,0xa1750981,0x8f604480,0x28ea0c9a,0x0da12413,0x98b0b3c5,
	0xa21d473a,0x96ce4308,0xe9a1001b,0x8bbacb44,0x18bad3f4,0xe3121acb,0x46a9b45f,0x92cd9704,
	0xc1a7c619,0x3281e361,0x462e8c79,0x9e572f93,0x7239e5f0,0x67d8e6ba,0x13747ce3,0xf01ee64a,
	0xe7d0ae12,0xeea04088,0xe5b36767,0x17558eae,0x678ffbe6,0xe0bbc866,0x0c24adec,0xa9cbb869,
	0x3fd44ee1,0x9ca4ca06,0x04c0ef00,0x04589a21,0x9cf9c819,0x976f6ca1,0x8a30e66a,0x004d6f7e,
	0x384c8851,0x5bc97eb8,0xc6c49339,0x5aa386c7,0x74bdf8af,0x9b713750,0x4112f8c2,0x2895dae1,
	0xf576d905,0x9de98bce,0xb2b26bcd,0xd46707a0,0x147fbb46,0xa52c6e50,0xe43128fc,0x374ad964,
	0x8dfd4d53,0xc4d0c087,0x31dfb5ca,0xa44589b5,0x6b637e2e,0x663f6b45,0xd2d8baa0,0x1dac7e4c
};

/* Buzhash: https://en.wikipedia.org/wiki/Rolling_hash#Cyclic_polynomial */

/*
 * Compute the buzhash of the size-byte window starting at buf.
 * Byte k of the window contributes buz[buf[k]] rotated left by
 * (size - 1 - k) % 32 bits, so the last (newest) byte is not rotated.
 * Returns 0 for an empty window instead of wrapping size - 1 around and
 * reading past buf.
 */
uint32_t
buzh_init(uint8_t *buf, size_t size)
{
	uint32_t fp = 0;
	size_t k;

	for (k = 0; k < size; k++) {
		uint32_t v = buz[buf[k]];
		unsigned int r = (size - 1 - k) % 32;

		/* r == 0 needs no rotation (and must not shift by 32) */
		fp ^= r != 0 ? ROTL(v, r) : v;
	}
	return fp;
}

/*
 * Roll the buzhash fp of a size-byte window forward by one byte: the
 * whole hash is rotated by one bit, the contribution of the byte leaving
 * the window (out, now rotated by size bits) is cancelled and the byte
 * entering the window (in) is added unrotated.
 */
uint32_t
buzh_update(uint32_t fp, uint8_t in, uint8_t out, size_t size)
{
	uint32_t rolled = ROTL(fp, 1);
	uint32_t leaving = ROTL(buz[out], size % 32);
	uint32_t entering = buz[in];

	return rolled ^ leaving ^ entering;
}

/*
 * Return the length of the next content-defined chunk at the start of
 * buf, of which size bytes are available.
 *
 * A WINSIZE-byte window is slid over the data and a rolling buzhash is
 * maintained for it.  The chunk ends right after the first window whose
 * hash has all HASHMSK bits clear, which makes WINSIZE the smallest
 * possible chunk unless size itself is smaller.  If no boundary is
 * found, the whole buffer is one chunk.
 *
 * Because the hash only depends on the bytes currently in the window,
 * boundaries follow the content rather than absolute positions, so data
 * shifted by an insertion still tends to be cut at the same places.
 */
uint64_t
chunk_blk(uint8_t *buf, size_t size)
{
	size_t i;
	uint32_t fp;

	/* buzhash needs at least one full window */
	if (size < WINSIZE)
		return size;

	fp = buzh_init(buf, WINSIZE);
	for (i = 0; i < size - WINSIZE; i++) {
		/*
		 * Slide the window from buf[i-1 .. i+WINSIZE-2] to
		 * buf[i .. i+WINSIZE-1]: buf[i + WINSIZE - 1] enters and
		 * buf[i - 1] leaves.  buzh_update() takes (in, out) in
		 * that order; swapping them would never remove old bytes
		 * from the hash.
		 */
		if (i > 0)
			fp = buzh_update(fp, buf[WINSIZE + i - 1], buf[i - 1],
			                 WINSIZE);
		if ((fp & HASHMSK) == 0)
			return i + WINSIZE;
	}
	return size;
}

/* Upper bound on the LZ4-compressed size of size input bytes. */
size_t
comp_size(size_t size)
{
	int bound = LZ4_compressBound(size);

	return bound;
}

/*
 * LZ4-compress insize bytes from in into out (outsize bytes available)
 * and return the compressed length.  LZ4_compress_default() reports
 * failure as 0; that is fatal here instead of silently producing a
 * zero-length block.
 */
size_t
comp(uint8_t *in, uint8_t *out, size_t insize, size_t outsize)
{
	int n;

	n = LZ4_compress_default((char *)in, (char *)out, insize, outsize);
	if (n <= 0)
		errx(1, "LZ4_compress_default: failed");
	return n;
}

/*
 * LZ4-decompress the insize bytes at in into out (outsize bytes
 * available) and return the decompressed length.
 * LZ4_decompress_safe() reports malformed input as a negative value.
 * Converting that to size_t would give a huge length, so it is treated
 * as fatal here.
 */
size_t
decomp(uint8_t *in, uint8_t *out, size_t insize, size_t outsize)
{
	int n;

	n = LZ4_decompress_safe((char *)in, (char *)out, insize, outsize);
	if (n < 0)
		errx(1, "LZ4_decompress_safe: corrupt block");
	return n;
}

/* Print the size bytes at md to stderr as lower-case hex, no newline. */
void
print_md(const uint8_t *md, size_t size)
{
	const uint8_t *end = md + size;

	while (md < end)
		fprintf(stderr, "%02x", *md++);
}

/* Value of the hex digit c, or -1 if c is not a hex digit. */
static int
hexnibble(int c)
{
	if (c >= '0' && c <= '9')
		return c - '0';
	if (c >= 'a' && c <= 'f')
		return c - 'a' + 10;
	if (c >= 'A' && c <= 'F')
		return c - 'A' + 10;
	return -1;
}

/*
 * Decode the hexadecimal string s into bytes at d, two digits per byte;
 * d must have room for strlen(s) / 2 bytes.  Upper- and lower-case digits
 * are accepted.  A non-hex character or an odd number of digits is a
 * fatal error instead of leaving the remaining bytes of d unset.
 */
void
str2bin(char *s, uint8_t *d)
{
	size_t i, len = strlen(s);

	if (len % 2 != 0)
		errx(1, "invalid hex string: %s", s);
	for (i = 0; i < len; i += 2) {
		int hi = hexnibble((unsigned char)s[i]);
		int lo = hexnibble((unsigned char)s[i + 1]);

		if (hi < 0 || lo < 0)
			errx(1, "invalid hex string: %s", s);
		*d++ = (uint8_t)((hi << 4) | lo);
	}
}

/*
 * Read up to nbytes from fd into buf, retrying short reads.  Returns the
 * number of bytes read, which is less than nbytes only at end of file.
 * A read error is fatal.
 */
ssize_t
xread(int fd, void *buf, size_t nbytes)
{
	uint8_t *p = buf;
	size_t left = nbytes;

	while (left > 0) {
		ssize_t got = read(fd, p, left);

		if (got < 0)
			err(1, "read");
		if (got == 0)
			break;
		p += got;
		left -= (size_t)got;
	}
	return (ssize_t)(nbytes - left);
}

/*
 * Write nbytes from buf to fd, retrying short writes.  Returns the number
 * of bytes written, which is less than nbytes only if write(2) returned 0.
 * A write error is fatal.
 */
ssize_t
xwrite(int fd, const void *buf, size_t nbytes)
{
	const uint8_t *p = buf;
	size_t left = nbytes;

	while (left > 0) {
		ssize_t put = write(fd, p, left);

		if (put < 0)
			err(1, "write");
		if (put == 0)
			break;
		p += put;
		left -= (size_t)put;
	}
	return (ssize_t)(nbytes - left);
}

/* Order cache entries by block hash (bdescr.md); returns -1, 0 or 1. */
int
cent_cmp(struct cent *e1, struct cent *e2)
{
	int d = memcmp(e1->bdescr.md, e2->bdescr.md, sizeof(e1->bdescr.md));

	return (d > 0) - (d < 0);
}
RB_PROTOTYPE(cache, cent, e, cent_cmp);
RB_GENERATE(cache, cent, e, cent_cmp);

/* Allocate a zero-filled cache entry; allocation failure is fatal. */
struct cent *
alloc_cent(void)
{
	struct cent *c = calloc(1, sizeof(*c));

	if (!c)
		err(1, "calloc");
	return c;
}

/*
 * Insert cent into the in-memory cache.
 * NOTE(review): assuming BSD tree.h semantics, RB_INSERT returns the
 * existing element and does not insert cent when an entry with the same
 * md is already present.  That result is ignored here, so callers should
 * check for duplicates first or the cent is leaked.
 */
void
add_cent(struct cent *cent)
{
	(void)RB_INSERT(cache, &cache_head, cent);
}

/*
 * Rewrite the cache file from offset 0 with every in-memory cache entry,
 * in md order, but only if entries were added (cache_dirty).  The file is
 * not truncated.  This relies on the in-memory cache holding at least as
 * many entries as the file (entries are only removed by free_cache()).
 */
void
flush_cache(void)
{
	struct cent *cent;

	if (cache_dirty) {
		lseek(cfd, 0, SEEK_SET);
		RB_FOREACH(cent, cache, &cache_head)
			xwrite(cfd, &cent->bdescr, sizeof(cent->bdescr));
	}
}

/* Remove and free every entry of the in-memory cache, smallest md first. */
void
free_cache(void)
{
	struct cent *cent;

	while ((cent = RB_MIN(cache, &cache_head)) != NULL) {
		RB_REMOVE(cache, &cache_head, cent);
		free(cent);
	}
}

/*
 * Number of complete block descriptors in the cache file; a trailing
 * partial record is not counted.  fstat(2) failure is fatal.
 */
uint64_t
cache_nents(void)
{
	struct stat sb;

	if (fstat(cfd, &sb) == -1)
		err(1, "fstat");
	return (uint64_t)sb.st_size / sizeof(struct bdescr);
}

/*
 * Add ent to the index: increment enthdr.nents and rewrite the header at
 * offset 0, then set ent->size to the entry's on-disk length (fixed part
 * plus nblks descriptors) and write the entry at the end of the file.
 */
void
append_ent(struct ent *ent)
{
	uint64_t len = sizeof(*ent) + ent->nblks * sizeof(ent->bdescr[0]);

	enthdr.nents++;
	lseek(ifd, 0, SEEK_SET);
	xwrite(ifd, &enthdr, sizeof(enthdr));

	ent->size = len;
	lseek(ifd, 0, SEEK_END);
	xwrite(ifd, ent, ent->size);
}

/*
 * Allocate a zero-filled index entry with room for no block descriptors;
 * allocation failure is fatal.
 */
struct ent *
alloc_ent(void)
{
	struct ent *e = calloc(1, sizeof(*e));

	if (!e)
		err(1, "calloc");
	return e;
}

/*
 * Resize ent so it holds nblks block descriptors, keeping its contents,
 * and return the (possibly moved) entry.  nblks may come straight from
 * the index file, so the size computation is checked for overflow before
 * allocating.  Overflow and allocation failure are fatal.
 */
struct ent *
grow_ent(struct ent *ent, uint64_t nblks)
{
	struct ent *nent;
	size_t size;

	if (nblks > (SIZE_MAX - sizeof(*ent)) / sizeof(ent->bdescr[0]))
		errx(1, "grow_ent: too many blocks: %llu",
		     (unsigned long long)nblks);
	size = sizeof(*ent) + nblks * sizeof(ent->bdescr[0]);

	nent = realloc(ent, size);
	if (nent == NULL)
		err(1, "realloc");
	return nent;
}

/* Allocate size zero-filled bytes; allocation failure is fatal. */
uint8_t *
alloc_buf(size_t size)
{
	uint8_t *mem = calloc(1, size);

	if (!mem)
		err(1, "calloc");
	return mem;
}

/* Store the SHA-256 digest of the size bytes at buf in md (MDSIZE bytes). */
void
hash_blk(uint8_t *buf, size_t size, uint8_t *md)
{
	SHA256(buf, size, md);
}

void
read_blk(uint8_t *buf, struct bdescr *bdescr)
{
	lseek(sfd, bdescr->offset, SEEK_SET);
	if (xread(sfd, buf, bdescr->size) == 0)
		errx(1, "read: unexpected EOF");
}

/*
 * Write the bdescr->size bytes at buf to the end of the store (offset
 * enthdr.store_size) and advance enthdr.store_size past them.
 */
void
append_blk(uint8_t *buf, struct bdescr *bdescr)
{
	uint64_t end = enthdr.store_size;

	lseek(sfd, end, SEEK_SET);
	xwrite(sfd, buf, bdescr->size);
	enthdr.store_size = end + bdescr->size;
}

/* Return 0 if a block with hash md (MDSIZE bytes) is cached, else -1. */
int
lookup_blk(uint8_t *md)
{
	struct cent key;

	memcpy(key.bdescr.md, md, sizeof(key.bdescr.md));
	return RB_FIND(cache, &cache_head, &key) != NULL ? 0 : -1;
}

void
dedup(int fd, char *msg)
{
	uint8_t *buf[2];
	struct ent *ent;
	SHA256_CTX ctx;
	ssize_t n, bufsize;

	buf[0] = alloc_buf(BLKSIZE);
	buf[1] = alloc_buf(comp_size(BLKSIZE));
	ent = alloc_ent();

	bufsize = 0;
	SHA256_Init(&ctx);
	while ((n = xread(fd, buf[0] + bufsize, BLKSIZE - bufsize)) > 0 || bufsize > 0) {
		uint8_t md[MDSIZE];
		struct bdescr bdescr;
		size_t blksize, csize;
		uint8_t *inp = buf[0]; /* input buf */
		uint8_t *outp = buf[1]; /* compressed buf */

		if (n > 0) {
			bufsize += n;
			enthdr.st.orig_size += n;
		}

		blksize = chunk_blk(inp, bufsize);
		csize = comp(inp, outp, blksize, comp_size(BLKSIZE));

		memcpy(bdescr.md, md, sizeof(bdescr));
		bdescr.offset = enthdr.store_size;
		bdescr.size = csize;

		enthdr.st.comp_size += bdescr.size;

		hash_blk(outp, bdescr.size, bdescr.md);

		/* Calculate file hash one block at a time */
		SHA256_Update(&ctx, outp, bdescr.size);

		ent = grow_ent(ent, ent->nblks + 1);

		if (lookup_blk(bdescr.md) < 0) {
			struct cent *cent;

			/* Update index entry */
			ent->bdescr[ent->nblks++] = bdescr;

			/* Store block */
			append_blk(outp, &bdescr);

			/* Create a cache entry for this block */
			cent = alloc_cent();
			cent->bdescr = bdescr;
			add_cent(cent);
			cache_dirty = 1;

			enthdr.st.dedup_size += bdescr.size;
			enthdr.st.cache_misses++;
		} else {
			ent->bdescr[ent->nblks++] = bdescr;
			enthdr.st.cache_hits++;
		}

		memmove(inp, inp + blksize, bufsize - blksize);
		bufsize -= blksize;
	}

	if (ent->nblks > 0) {
		/* Calculate hash and add this entry to the index */
		SHA256_Final(ent->md, &ctx);

		if (msg != NULL) {
			size_t size;

			size = strlen(msg) + 1;
			if (size > sizeof(ent->msg))
				size = sizeof(ent->msg);
			memcpy(ent->msg, msg, size);
			ent->msg[size - 1] = '\0';
		}

		append_ent(ent);
	}

	free(ent);
	free(buf[1]);
	free(buf[0]);
}

/*
 * walk() callback.  If ent's id matches args->md, decompress its blocks
 * in order, write them to args->fd and stop the walk.  Other entries are
 * skipped.  A block that does not decompress to at most BLKSIZE bytes is
 * fatal.
 */
int
extract(struct ent *ent, void *arg)
{
	struct extract_args *args = arg;
	uint8_t *data, *cdata;
	uint64_t i;

	if (memcmp(ent->md, args->md, sizeof(ent->md)) != 0)
		return WALK_CONTINUE;

	data = alloc_buf(BLKSIZE);
	cdata = alloc_buf(comp_size(BLKSIZE));
	for (i = 0; i < ent->nblks; i++) {
		size_t blksize;

		read_blk(cdata, &ent->bdescr[i]);
		blksize = decomp(cdata, data, ent->bdescr[i].size, BLKSIZE);
		/*
		 * decomp() returns size_t; a negative LZ4 error code would
		 * appear as a value far larger than the buffer.  Never write
		 * more than data holds.
		 */
		if (blksize > BLKSIZE)
			errx(1, "extract: corrupt block at offset %llu",
			     (unsigned long long)ent->bdescr[i].offset);
		xwrite(args->fd, data, blksize);
	}
	free(cdata);
	free(data);
	return WALK_STOP;
}

int
check(struct ent *ent, void *arg)
{
	uint8_t md[MDSIZE];
	uint8_t *buf;
	SHA256_CTX ctx;
	uint64_t i;

	buf = alloc_buf(comp_size(BLKSIZE));
	/*
	 * Calculate hash for each block and compare
	 * with index entry block descriptor
	 */
	for (i = 0; i < ent->nblks; i++) {
		read_blk(buf, &ent->bdescr[i]);

		SHA256_Init(&ctx);
		SHA256_Update(&ctx, buf, ent->bdescr[i].size);
		SHA256_Final(md, &ctx);

		if (memcmp(ent->bdescr[i].md, md,
		           sizeof(ent->bdescr[i]).md) == 0)
			continue;

		fprintf(stderr, "Block hash mismatch\n");
		fprintf(stderr, "  Expected hash: ");
		print_md(ent->md, sizeof(ent->md));
		fputc('\n', stderr);
		fprintf(stderr, "  Actual hash: ");
		print_md(md, sizeof(md));
		fputc('\n', stderr);
		fprintf(stderr, "  Offset: %llu\n",
		        (unsigned long long)ent->bdescr[i].offset);
		fprintf(stderr, "  Size: %llu\n",
		        (unsigned long long)ent->bdescr[i].size);
	}
	free(buf);
	return WALK_CONTINUE;
}

/*
 * walk() callback.  Print one line per entry on stdout: the entry id in
 * lower-case hex, then a tab and the message if one was stored.  The id
 * is written to stdout like the rest of the line; print_md() writes to
 * stderr and would split redirected output.  The message is read at most
 * MSGSIZE bytes, even if a corrupt entry lacks its terminator.
 */
int
list(struct ent *ent, void *arg)
{
	size_t i;

	(void)arg;
	for (i = 0; i < sizeof(ent->md); i++)
		printf("%02x", ent->md[i]);
	if (ent->msg[0] != '\0')
		printf("\t%.*s\n", (int)sizeof(ent->msg), (char *)ent->msg);
	else
		putchar('\n');
	return WALK_CONTINUE;
}

/*
 * walk() callback used when the cache file is empty.  Add every distinct
 * block referenced by ent to the in-memory cache, using the descriptors
 * recorded in the index as-is.  The store is not re-read: the hash
 * computed from it was never used for the cache, and `-c` exists to
 * verify stored blocks.
 */
int
rebuild_cache(struct ent *ent, void *arg)
{
	uint64_t i;

	(void)arg;
	for (i = 0; i < ent->nblks; i++) {
		struct cent *cent;

		/*
		 * Deduplicated data references the same block many times;
		 * RB_INSERT would not insert the duplicate and the cent
		 * would be leaked.
		 */
		if (lookup_blk(ent->bdescr[i].md) == 0)
			continue;

		cent = alloc_cent();
		cent->bdescr = ent->bdescr[i];
		add_cent(cent);
		cache_dirty = 1;
	}
	return WALK_CONTINUE;
}

/*
 * Walk through all index entries in file order and call fn(ent, arg) on
 * each one, stopping early when fn returns WALK_STOP.  The ent passed to
 * fn is reused and freed by the walk, so it is only valid during the
 * call.  A truncated entry or an implausible block count is fatal.
 */
void
walk(int (*fn)(struct ent *, void *), void *arg)
{
	struct ent *ent;
	uint64_t i;

	ent = alloc_ent();
	if (lseek(ifd, sizeof(enthdr), SEEK_SET) < 0)
		err(1, "lseek");
	for (i = 0; i < enthdr.nents; i++) {
		size_t len;

		if (xread(ifd, ent, sizeof(*ent)) != (ssize_t)sizeof(*ent))
			errx(1, "read: unexpected EOF");

		/* nblks comes from the file: bound it before sizing anything */
		if (ent->nblks > (SIZE_MAX - sizeof(*ent)) / sizeof(ent->bdescr[0]))
			errx(1, "read: corrupt index entry");
		len = ent->nblks * sizeof(ent->bdescr[0]);

		ent = grow_ent(ent, ent->nblks);
		if (xread(ifd, ent->bdescr, len) != (ssize_t)len)
			errx(1, "read: unexpected EOF");

		if ((*fn)(ent, arg) == WALK_STOP)
			break;
	}
	free(ent);
}

/*
 * Load every block descriptor stored in the cache file into the in-memory
 * cache.  With -v, also print the min/avg/max compressed size of the
 * cached blocks.  A trailing partial record is ignored (see
 * cache_nents()); a short read of a counted record is fatal.
 */
void
init_cache(void)
{
	uint64_t nents, i;
	uint64_t min, max, avg;

	min = comp_size(BLKSIZE);
	max = 0;
	avg = 0;

	nents = cache_nents();
	if (lseek(cfd, 0, SEEK_SET) < 0)
		err(1, "lseek");
	for (i = 0; i < nents; i++) {
		struct cent *cent;
		uint64_t size;

		cent = alloc_cent();
		if (xread(cfd, &cent->bdescr, sizeof(cent->bdescr)) !=
		    (ssize_t)sizeof(cent->bdescr))
			errx(1, "read: unexpected EOF");
		size = cent->bdescr.size;
		add_cent(cent);

		if (size > max)
			max = size;
		if (size < min)
			min = size;
		avg += size;
	}
	/* init() only calls this for a non-empty cache; don't rely on it */
	if (nents > 0)
		avg /= nents;

	if (verbose) {
		fprintf(stderr, "min/avg/max block size: %llu/%llu/%llu\n",
		        (unsigned long long)min,
		        (unsigned long long)avg,
		        (unsigned long long)max);
	}
}

/*
 * Open (creating if needed) the index, store and cache files in the
 * current directory and load the index header.  With -v, print the
 * stored statistics.  Then populate the in-memory cache, either from the
 * cache file or, when that is empty, by walking the index.  A non-empty
 * index shorter than its header is fatal.
 */
void
init(void)
{
	struct stat sb;

	ifd = open(INDEXF, O_RDWR | O_CREAT, 0600);
	if (ifd < 0)
		err(1, "open %s", INDEXF);

	sfd = open(STOREF, O_RDWR | O_CREAT, 0600);
	if (sfd < 0)
		err(1, "open %s", STOREF);

	cfd = open(CACHEF, O_RDWR | O_CREAT, 0600);
	if (cfd < 0)
		err(1, "open %s", CACHEF);

	if (fstat(ifd, &sb) < 0)
		err(1, "fstat %s", INDEXF);
	/* A new, empty index keeps the zero-initialised global header */
	if (sb.st_size != 0 &&
	    xread(ifd, &enthdr, sizeof(enthdr)) != (ssize_t)sizeof(enthdr))
		errx(1, "%s: truncated header", INDEXF);
	if (verbose) {
		fprintf(stderr, "original size: %llu bytes\n",
		        (unsigned long long)enthdr.st.orig_size);
		fprintf(stderr, "compressed size: %llu bytes\n",
		        (unsigned long long)enthdr.st.comp_size);
		fprintf(stderr, "deduplicated size: %llu bytes\n",
		        (unsigned long long)enthdr.st.dedup_size);

		fprintf(stderr, "cache hits: %llu\n",
		        (unsigned long long)enthdr.st.cache_hits);
		fprintf(stderr, "cache misses: %llu\n",
		        (unsigned long long)enthdr.st.cache_misses);
	}

	if (cache_nents() != 0)
		init_cache();
	else
		walk(rebuild_cache, NULL);
}

/*
 * Write out and release the in-memory cache, then fsync and close the
 * index, store and cache files (in that order).  Errors from fsync(2)
 * and close(2) are ignored.
 */
void
term(void)
{
	int fds[] = { ifd, sfd, cfd };
	size_t i, nfds = sizeof(fds) / sizeof(fds[0]);

	flush_cache();
	free_cache();

	for (i = 0; i < nfds; i++)
		fsync(fds[i]);
	for (i = 0; i < nfds; i++)
		close(fds[i]);
}

/*
 * Print the command synopsis to stderr and exit with status 1.
 * -c check blocks, -l list entries, -v verbose, -e extract an entry,
 * -r repository directory, -m message for a new entry.
 */
void
usage(void)
{
	fprintf(stderr,
	        "usage: %s [-clv] [-e id] [-r root] [-m message] [file]\n",
	        argv0);
	exit(1);
}

/*
 * Entry point.  Options:
 *   -c          check the hash of every stored block
 *   -l          list index entries (id and message)
 *   -e id       extract entry id into file (default: stdout)
 *   -m message  message stored with a new entry
 *   -r root     use root as the repository directory (created if needed)
 *   -v          print statistics to stderr
 * Without -c, -l or -e, the data read from file (default: stdin) is
 * deduplicated into the repository.
 */
int
main(int argc, char *argv[])
{
	uint8_t md[MDSIZE];
	char *id = NULL, *root = NULL, *msg = NULL;
	int fd = -1, lflag = 0, cflag = 0;

	ARGBEGIN {
	case 'c':
		cflag = 1;
		break;
	case 'e':
		id = EARGF(usage());
		break;
	case 'l':
		lflag = 1;
		break;
	case 'r':
		root = EARGF(usage());
		break;
	case 'm':
		msg = EARGF(usage());
		break;
	case 'v':
		verbose = 1;
		break;
	default:
		usage();
	} ARGEND

	if (argc > 1)
		usage();

	/*
	 * str2bin() writes strlen(id) / 2 bytes into md, so an over-long id
	 * would overflow it.  Validate before the output file is opened
	 * (and truncated).
	 */
	if (id != NULL && (strlen(id) != 2 * MDSIZE ||
	    strspn(id, "0123456789abcdefABCDEF") != 2 * MDSIZE))
		errx(1, "invalid id: %s", id);

	if (argc == 1) {
		if (id) {
			/* O_TRUNC: leave no stale bytes after the extracted data */
			fd = open(argv[0], O_RDWR | O_CREAT | O_TRUNC, 0600);
			if (fd < 0)
				err(1, "open %s", argv[0]);
		} else {
			fd = open(argv[0], O_RDONLY);
			if (fd < 0)
				err(1, "open %s", argv[0]);
		}
	} else {
		if (id)
			fd = STDOUT_FILENO;
		else
			fd = STDIN_FILENO;
	}

	/* The file argument is opened first, relative to the original cwd */
	if (root != NULL) {
		mkdir(root, 0700);
		if (chdir(root) < 0)
			err(1, "chdir: %s", root);
	}

	init();

	if (cflag) {
		walk(check, NULL);
		term();
		return 0;
	}

	if (lflag) {
		walk(list, NULL);
		term();
		return 0;
	}

	if (id) {
		str2bin(id, md);
		walk(extract, &(struct extract_args){ .md = md, .fd = fd });
	} else {
		dedup(fd, msg);
	}

	term();
	return 0;
}
