#include <sys/stat.h>
#include <sys/file.h>
#include <err.h>
#include <fcntl.h>
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <lz4.h>
#include <openssl/sha.h>

#include "arg.h"
#include "dedup.h"
#include "tree.h"

/* Names of the repository files, opened relative to the current directory */
#define SNAPSF ".snapshots"	/* snapshot header followed by snapshot records */
#define STOREF ".store"		/* compressed block data */
#define CACHEF ".cache"		/* array of block descriptors */

#define MSGSIZE 256			/* size of the snapshot message buffer */
#define MDSIZE SHA256_DIGEST_LENGTH	/* size of a hash digest in bytes */

/* file format version */
#define VER_MIN 1
#define VER_MAJ 0

/* Values returned by the callbacks passed to walk() */
enum {
	WALK_CONTINUE,	/* go on with the next snapshot */
	WALK_STOP	/* stop walking after this snapshot */
};

/*
 * Running statistics, stored inside the snapshot header.
 * Updated by dedup_chunk() and printed by print_stats().
 */
struct stats {
	uint64_t orig_size;	/* sum of the uncompressed chunk sizes */
	uint64_t comp_size;	/* sum of the compressed chunk sizes */
	uint64_t dedup_size;	/* sum of the sizes of the blocks written to the store */
	uint64_t min_blk_size;	/* smallest block written to the store */
	uint64_t max_blk_size;	/* largest block written to the store */
	uint64_t nr_blks;	/* number of blocks written to the store */
	uint64_t reserved[6];	/* not used in this file */
};

/* Header found at the start of the snapshots file */
struct snapshot_hdr {
	uint64_t flags;		/* bits 0-7: minor version, bits 8-15: major version */
	uint64_t nr_snapshots;	/* number of snapshot records following the header */
	uint64_t store_size;	/* offset in the store where the next block is appended */
	uint64_t reserved[4];	/* not used in this file */
	struct stats st;
};

/* Describes one compressed block in the store file */
struct blk_desc {
	uint8_t md[MDSIZE];	/* hash of the compressed block */
	uint64_t offset;	/* offset of the block in the store file */
	uint64_t size;		/* size of the compressed block in bytes */
};

/* One snapshot record: fixed part followed by its block descriptors */
struct snapshot {
	uint64_t size;		/* size of this record, descriptors included */
	uint8_t msg[MSGSIZE];	/* NUL-terminated message, empty if none was given */
	uint8_t md[MDSIZE];	/* hash of file */
	uint64_t nr_blk_descs;	/* number of elements in blk_desc[] */
	struct blk_desc blk_desc[];
};

/* Node of the in-memory block cache; entries are ordered by blk_desc.md */
struct cache_entry {
	struct blk_desc blk_desc;
	RB_ENTRY(cache_entry) e;	/* red-black tree linkage */
};

/* Argument handed to extract() through walk() */
struct extract_args {
	uint8_t *md;	/* hash of the snapshot to extract */
	int fd;		/* descriptor the extracted data is written to */
};

static RB_HEAD(cache, cache_entry) cache_head;	/* root of the block cache tree */
static struct snapshot_hdr snaphdr;	/* in-memory copy of the snapshot header */
static int ifd;	/* descriptor of the snapshots file */
static int sfd;	/* descriptor of the store file */
static int cfd;	/* descriptor of the cache file */
static int verbose;	/* set by -v; makes term() print statistics */
static int cache_dirty;	/* set when entries were added; checked by flush_cache() */
static unsigned long long cache_hits;	/* chunks already present in the cache */
static unsigned long long cache_misses;	/* chunks that had to be stored */
char *argv0;	/* program name printed by usage(); presumably set by ARGBEGIN */

/*
 * Return the bound reported by LZ4_compressBound() for an input of
 * size bytes, i.e. how large a compression output buffer has to be.
 */
static size_t
comp_size(size_t size)
{
	int bound;

	bound = LZ4_compressBound(size);
	return bound;
}

/*
 * Compress insize bytes from in into out, which has room for outsize
 * bytes.  Returns the number of bytes written to out.  Terminates the
 * program if compression fails.
 */
static size_t
comp(uint8_t *in, uint8_t *out, size_t insize, size_t outsize)
{
	int ret;

	ret = LZ4_compress_default((char *)in, (char *)out, insize, outsize);
	/*
	 * LZ4_compress_default() signals failure by returning 0; it never
	 * returns a negative value, so testing only for < 0 would let a
	 * failed compression through as an empty block.
	 */
	if (ret <= 0)
		errx(1, "LZ4_compress_default failed");
	return ret;
}

/*
 * Decompress insize bytes from in into out, which has room for outsize
 * bytes.  Returns the number of decompressed bytes.  Terminates the
 * program when LZ4 reports an error.
 */
static size_t
decomp(uint8_t *in, uint8_t *out, size_t insize, size_t outsize)
{
	int nbytes;

	nbytes = LZ4_decompress_safe((char *)in, (char *)out, insize, outsize);
	if (nbytes >= 0)
		return nbytes;
	errx(1, "LZ4_decompress_safe failed");
}

/* Write the size bytes at md to fp as lowercase hexadecimal digits */
static void
print_md(FILE *fp, uint8_t *md, size_t size)
{
	while (size-- > 0)
		fprintf(fp, "%02x", *md++);
}

/*
 * Print the statistics in st, followed by the cache hit and miss
 * counters, to stderr.  Prints nothing when no block has been stored.
 */
static void
print_stats(struct stats *st)
{
	unsigned long long nblks, dedup;

	nblks = st->nr_blks;
	if (nblks == 0)
		return;
	dedup = st->dedup_size;

	fprintf(stderr, "original size: %llu bytes\n",
	        (unsigned long long)st->orig_size);
	fprintf(stderr, "compressed size: %llu bytes\n",
	        (unsigned long long)st->comp_size);
	fprintf(stderr, "deduplicated size: %llu bytes\n", dedup);
	fprintf(stderr, "min/avg/max block size: %llu/%llu/%llu\n",
	        (unsigned long long)st->min_blk_size,
	        dedup / nblks,
	        (unsigned long long)st->max_blk_size);
	fprintf(stderr, "number of blocks: %llu\n", nblks);
	fprintf(stderr, "cache hits: %llu\n", cache_hits);
	fprintf(stderr, "cache misses: %llu\n", cache_misses);
}

/*
 * Ordering function of the cache tree: compares the block hashes of
 * the two entries and returns -1, 0 or 1.
 */
static int
cache_entry_cmp(struct cache_entry *e1, struct cache_entry *e2)
{
	int diff;

	diff = memcmp(e1->blk_desc.md, e2->blk_desc.md, sizeof(e1->blk_desc.md));
	return (diff > 0) - (diff < 0);
}
/* Declare and define the red-black tree functions for the cache tree */
static RB_PROTOTYPE(cache, cache_entry, e, cache_entry_cmp);
static RB_GENERATE(cache, cache_entry, e, cache_entry_cmp);

/* Allocate a zero-filled cache entry; terminates the program on failure */
static struct cache_entry *
alloc_cache_entry(void)
{
	struct cache_entry *node;

	if ((node = calloc(1, sizeof(*node))) == NULL)
		err(1, "calloc");
	return node;
}

/* Release a cache entry obtained from alloc_cache_entry() */
static void
free_cache_entry(struct cache_entry *ent)
{
	free(ent);
}

/*
 * Insert ent into the cache tree, keyed by its block hash.
 *
 * The tree takes ownership of ent.  When an entry with the same hash
 * is already present, RB_INSERT() leaves the tree untouched and returns
 * the existing node; ent is then released here instead of being leaked.
 * Callers must not touch ent after this call.
 */
static void
add_cache_entry(struct cache_entry *ent)
{
	if (RB_INSERT(cache, &cache_head, ent) != NULL)
		free_cache_entry(ent);
}

static void
flush_cache(void)
{
	struct cache_entry *ent;

	if (!cache_dirty)
		return;

	xlseek(cfd, 0, SEEK_SET);
	RB_FOREACH(ent, cache, &cache_head)
		xwrite(cfd, &ent->blk_desc, sizeof(ent->blk_desc));
}

/* Remove every entry from the cache tree and free it */
static void
free_cache(void)
{
	struct cache_entry *node, *next;

	/* the _SAFE variant is needed because node is freed inside the loop */
	RB_FOREACH_SAFE(node, cache, &cache_head, next) {
		RB_REMOVE(cache, &cache_head, node);
		free_cache_entry(node);
	}
}

/*
 * Return how many block descriptors fit in the current size of the
 * cache file.  Terminates the program if the file cannot be stat'ed.
 */
static uint64_t
cache_nr_entries(void)
{
	struct stat info;

	if (fstat(cfd, &info) >= 0)
		return info.st_size / sizeof(struct blk_desc);
	err(1, "fstat");
}

/*
 * Record snap in the snapshots file: bump the snapshot count in the
 * header, rewrite the header at offset 0, then append the snapshot
 * (fixed part plus block descriptors) at the end of the file.
 * snap->size is set to the number of bytes written for the record.
 */
static void
append_snap(struct snapshot *snap)
{
	/* Size of the record, block descriptors included */
	snap->size = sizeof(*snap);
	snap->size += snap->nr_blk_descs * sizeof(snap->blk_desc[0]);

	/* Header first ... */
	snaphdr.nr_snapshots++;
	xlseek(ifd, 0, SEEK_SET);
	xwrite(ifd, &snaphdr, sizeof(snaphdr));

	/* ... then the record itself */
	xlseek(ifd, 0, SEEK_END);
	xwrite(ifd, snap, snap->size);
}

/*
 * Allocate a zero-filled snapshot without room for block descriptors;
 * terminates the program on failure.
 */
static struct snapshot *
alloc_snap(void)
{
	struct snapshot *s;

	if ((s = calloc(1, sizeof(*s))) == NULL)
		err(1, "calloc");
	return s;
}

/* Release a snapshot obtained from alloc_snap() or grow_snap() */
static void
free_snap(struct snapshot *snap)
{
	free(snap);
}

/*
 * Resize snap so that it has room for nr_blk_descs block descriptors.
 * Returns the (possibly moved) snapshot; terminates the program when
 * the reallocation fails.
 */
static struct snapshot *
grow_snap(struct snapshot *snap, uint64_t nr_blk_descs)
{
	struct snapshot *resized;
	size_t nbytes;

	nbytes = sizeof(*snap) + nr_blk_descs * sizeof(snap->blk_desc[0]);
	resized = realloc(snap, nbytes);
	if (resized == NULL)
		err(1, "realloc");
	return resized;
}

/*
 * Allocate a zero-filled buffer of size bytes; terminates the program
 * on failure.
 */
static uint8_t *
alloc_buf(size_t size)
{
	uint8_t *buf;

	if ((buf = calloc(1, size)) == NULL)
		err(1, "calloc");
	return buf;
}

/* Release a buffer obtained from alloc_buf() */
static void
free_buf(uint8_t *buf)
{
	free(buf);
}

/* Store the SHA-256 digest of the size bytes at buf into md */
static void
hash_blk(uint8_t *buf, size_t size, uint8_t *md)
{
	SHA256_CTX sha;

	SHA256_Init(&sha);
	SHA256_Update(&sha, buf, size);
	SHA256_Final(md, &sha);
}

/*
 * Read the block described by blk_desc from the store file into buf.
 * buf must have room for blk_desc->size bytes.  Terminates the program
 * when xread() returns 0.
 */
static void
read_blk(uint8_t *buf, struct blk_desc *blk_desc)
{
	xlseek(sfd, blk_desc->offset, SEEK_SET);
	if (xread(sfd, buf, blk_desc->size) != 0)
		return;
	errx(1, "read: unexpected EOF");
}

/*
 * Write the blk_desc->size bytes at buf to the store file at the
 * offset recorded in the header, then advance that offset.
 */
static void
append_blk(uint8_t *buf, struct blk_desc *blk_desc)
{
	uint64_t len;

	len = blk_desc->size;
	xlseek(sfd, snaphdr.store_size, SEEK_SET);
	xwrite(sfd, buf, len);
	snaphdr.store_size += len;
}

static int
lookup_blk_desc(uint8_t *md, struct blk_desc *blk_desc)
{
	struct cache_entry *ent, key;

	memcpy(key.blk_desc.md, md, sizeof(key.blk_desc.md));
	ent = RB_FIND(cache, &cache_head, &key);
	if (ent != NULL) {
		*blk_desc = ent->blk_desc;
		return 0;
	}
	return -1;
}

static void
dedup_chunk(struct snapshot *snap, uint8_t *chunkp, size_t chunk_size)
{
	uint8_t md[MDSIZE];
	uint8_t *comp_buf;
	struct blk_desc blk_desc;
	size_t n;

	comp_buf = alloc_buf(comp_size(BLKSIZE_MAX));

	n = comp(chunkp, comp_buf, chunk_size, comp_size(BLKSIZE_MAX));
	hash_blk(comp_buf, n, md);

	snaphdr.st.orig_size += chunk_size;
	snaphdr.st.comp_size += n;

	if (lookup_blk_desc(md, &blk_desc) < 0) {
		struct cache_entry *ent;

		memcpy(blk_desc.md, md, sizeof(blk_desc.md));
		blk_desc.offset = snaphdr.store_size;
		blk_desc.size = n;

		snap->blk_desc[snap->nr_blk_descs++] = blk_desc;

		append_blk(comp_buf, &blk_desc);

		ent = alloc_cache_entry();
		ent->blk_desc = blk_desc;
		add_cache_entry(ent);
		cache_dirty = 1;
		cache_misses++;

		snaphdr.st.dedup_size += blk_desc.size;
		snaphdr.st.nr_blks++;

		if (blk_desc.size > snaphdr.st.max_blk_size)
			snaphdr.st.max_blk_size = blk_desc.size;
		if (blk_desc.size < snaphdr.st.min_blk_size)
			snaphdr.st.min_blk_size = blk_desc.size;
	} else {
		snap->blk_desc[snap->nr_blk_descs++] = blk_desc;
		cache_hits++;
	}

	free(comp_buf);
}

/*
 * Read fd chunk by chunk and pass every chunk to dedup_chunk().  The
 * SHA-256 of all chunk data becomes the snapshot hash.  If at least one
 * chunk was produced, the snapshot is appended to the snapshots file,
 * carrying msg (truncated to fit the message buffer) when msg is not
 * NULL.
 */
static void
dedup(int fd, char *msg)
{
	struct snapshot *snap;
	struct chunker *chunker;
	SHA256_CTX filectx;
	ssize_t n;

	snap = alloc_snap();
	chunker = alloc_chunker(BLKSIZE_MAX, fd);

	SHA256_Init(&filectx);
	for (;;) {
		uint8_t *chunk;
		size_t len;

		n = fill_chunker(chunker);
		if (n <= 0)
			break;

		chunk = get_chunk(chunker, &len);
		SHA256_Update(&filectx, chunk, len);
		/* make room for the descriptor dedup_chunk() adds */
		snap = grow_snap(snap, snap->nr_blk_descs + 1);
		dedup_chunk(snap, chunk, len);
		drain_chunker(chunker);
	}
	SHA256_Final(snap->md, &filectx);

	if (snap->nr_blk_descs > 0) {
		if (msg != NULL) {
			size_t len;

			/* copy at most what fits, always NUL-terminated */
			len = strlen(msg);
			if (len >= sizeof(snap->msg))
				len = sizeof(snap->msg) - 1;
			memcpy(snap->msg, msg, len);
			snap->msg[len] = '\0';
		}

		append_snap(snap);
	}

	free_chunker(chunker);
	free_snap(snap);
}

static int
extract(struct snapshot *snap, void *arg)
{
	uint8_t *buf[2];
	struct extract_args *args = arg;
	uint64_t i;

	if (memcmp(snap->md, args->md, sizeof(snap->md)) != 0)
		return WALK_CONTINUE;

	buf[0] = alloc_buf(BLKSIZE_MAX);
	buf[1] = alloc_buf(comp_size(BLKSIZE_MAX));
	for (i = 0; i < snap->nr_blk_descs; i++) {
		size_t blksize;

		read_blk(buf[1], &snap->blk_desc[i]);
		blksize = decomp(buf[1], buf[0], snap->blk_desc[i].size,
		                 BLKSIZE_MAX);
		xwrite(args->fd, buf[0], blksize);
	}
	free_buf(buf[1]);
	free_buf(buf[0]);
	return WALK_STOP;
}

/*
 * walk() callback.  Reads every block referenced by snap from the
 * store, hashes it and compares the result with the hash recorded in
 * the block descriptor.  Every mismatch is reported on stderr together
 * with the expected and actual hash and the block's offset and size.
 * arg is unused.  Always returns WALK_CONTINUE.
 */
static int
check(struct snapshot *snap, void *arg)
{
	uint8_t md[MDSIZE];
	uint8_t *buf;
	SHA256_CTX ctx;
	uint64_t i;

	(void)arg;

	buf = alloc_buf(comp_size(BLKSIZE_MAX));
	/*
	 * Calculate hash for each block and compare
	 * against snapshot entry block descriptor
	 */
	for (i = 0; i < snap->nr_blk_descs; i++) {
		read_blk(buf, &snap->blk_desc[i]);

		SHA256_Init(&ctx);
		SHA256_Update(&ctx, buf, snap->blk_desc[i].size);
		SHA256_Final(md, &ctx);

		if (memcmp(snap->blk_desc[i].md, md,
		           sizeof(snap->blk_desc[i].md)) == 0)
			continue;

		fprintf(stderr, "Block hash mismatch\n");
		fprintf(stderr, "  Expected hash: ");
		/* the expected value is the block's hash, not the file hash */
		print_md(stderr, snap->blk_desc[i].md,
		         sizeof(snap->blk_desc[i].md));
		fputc('\n', stderr);
		fprintf(stderr, "  Actual hash: ");
		print_md(stderr, md, sizeof(md));
		fputc('\n', stderr);
		fprintf(stderr, "  Offset: %llu\n",
		        (unsigned long long)snap->blk_desc[i].offset);
		fprintf(stderr, "  Size: %llu\n",
		        (unsigned long long)snap->blk_desc[i].size);
	}
	free_buf(buf);
	return WALK_CONTINUE;
}

/*
 * walk() callback.  Prints the snapshot hash on stdout, followed by a
 * tab and the snapshot message when there is one.  arg is unused.
 */
static int
list(struct snapshot *snap, void *arg)
{
	print_md(stdout, snap->md, sizeof(snap->md));
	if (snap->msg[0] == '\0')
		putchar('\n');
	else
		printf("\t%s\n", snap->msg);
	return WALK_CONTINUE;
}

/*
 * walk() callback.  Adds a cache entry for every block referenced by
 * snap.  Offset and size are taken from the snapshot's block
 * descriptor; the hash is recomputed from the data found in the store,
 * so the cache describes what the store actually contains.  arg is
 * unused.  Always returns WALK_CONTINUE.
 */
static int
rebuild_cache(struct snapshot *snap, void *arg)
{
	uint8_t md[MDSIZE];
	uint8_t *buf;
	SHA256_CTX ctx;
	uint64_t i;

	(void)arg;

	buf = alloc_buf(comp_size(BLKSIZE_MAX));
	for (i = 0; i < snap->nr_blk_descs; i++) {
		struct cache_entry *ent;

		read_blk(buf, &snap->blk_desc[i]);

		SHA256_Init(&ctx);
		SHA256_Update(&ctx, buf, snap->blk_desc[i].size);
		SHA256_Final(md, &ctx);

		ent = alloc_cache_entry();
		/*
		 * Copy the descriptor first and install the computed hash
		 * afterwards; the other way round the struct assignment
		 * would overwrite the hash that was just calculated.
		 */
		ent->blk_desc = snap->blk_desc[i];
		memcpy(ent->blk_desc.md, md, sizeof(ent->blk_desc.md));
		add_cache_entry(ent);
		cache_dirty = 1;
	}
	free(buf);
	return WALK_CONTINUE;
}

/* Walk through all snapshots and call fn() on each one */
static void
walk(int (*fn)(struct snapshot *, void *), void *arg)
{
	uint64_t i;

	xlseek(ifd, sizeof(snaphdr), SEEK_SET);
	for (i = 0; i < snaphdr.nr_snapshots; i++) {
		struct snapshot *snap;
		int ret;

		snap = alloc_snap();
		if (xread(ifd, snap, sizeof(*snap)) == 0)
			errx(1, "read: unexpected EOF");

		snap = grow_snap(snap, snap->nr_blk_descs);
		if (xread(ifd, snap->blk_desc,
		          snap->nr_blk_descs * sizeof(snap->blk_desc[0])) == 0)
			errx(1, "read: unexpected EOF");

		ret = (*fn)(snap, arg);
		free(snap);
		if (ret == WALK_STOP)
			break;
	}
}

/*
 * Load the block cache from the cache file: one cache entry is created
 * for every block descriptor stored in the file.  Terminates the
 * program when xread() returns 0 before all descriptors were read.
 */
static void
init_cache(void)
{
	uint64_t nr_entries, i;

	/*
	 * The cache file is not written while it is being loaded, so its
	 * entry count is fetched once instead of calling fstat() on every
	 * iteration.
	 */
	nr_entries = cache_nr_entries();

	xlseek(cfd, 0, SEEK_SET);
	for (i = 0; i < nr_entries; i++) {
		struct cache_entry *ent;

		ent = alloc_cache_entry();
		if (xread(cfd, &ent->blk_desc, sizeof(ent->blk_desc)) == 0)
			errx(1, "read: unexpected EOF");
		add_cache_entry(ent);
	}
}

/*
 * Open (creating if needed) and exclusively lock the snapshots, store
 * and cache files.  An existing snapshot header is loaded and its
 * format version verified; for an empty snapshots file a fresh header
 * is written.  Finally the block cache is loaded from the cache file
 * or, when that file holds no entries, rebuilt from the snapshots.
 */
static void
init(void)
{
	struct stat info;

	if ((ifd = open(SNAPSF, O_RDWR | O_CREAT, 0600)) < 0)
		err(1, "open %s", SNAPSF);
	if ((sfd = open(STOREF, O_RDWR | O_CREAT, 0600)) < 0)
		err(1, "open %s", STOREF);
	if ((cfd = open(CACHEF, O_RDWR | O_CREAT, 0600)) < 0)
		err(1, "open %s", CACHEF);

	/* non-blocking: fail right away if the files are locked already */
	if (flock(ifd, LOCK_NB | LOCK_EX) < 0)
		errx(1, "busy lock");
	if (flock(sfd, LOCK_NB | LOCK_EX) < 0)
		errx(1, "busy lock");
	if (flock(cfd, LOCK_NB | LOCK_EX) < 0)
		errx(1, "busy lock");

	if (fstat(ifd, &info) < 0)
		err(1, "fstat %s", SNAPSF);

	if (info.st_size == 0) {
		/* empty snapshots file: start with a fresh header */
		snaphdr.flags = (VER_MAJ << 8) | VER_MIN;
		snaphdr.st.min_blk_size = comp_size(BLKSIZE_MAX);
		xwrite(ifd, &snaphdr, sizeof(snaphdr));
	} else {
		uint8_t vmaj, vmin;

		xread(ifd, &snaphdr, sizeof(snaphdr));
		vmin = snaphdr.flags & 0xff;
		vmaj = (snaphdr.flags >> 8) & 0xff;

		if (vmaj != VER_MAJ || vmin != VER_MIN)
			errx(1, "expected snapshot format version %u.%u but got %u.%u",
			     VER_MAJ, VER_MIN, vmaj, vmin);
	}

	if (cache_nr_entries() == 0)
		walk(rebuild_cache, NULL);
	else
		init_cache();
}

/*
 * Shut down: print the statistics when running verbosely, write the
 * cache back, free it, then sync and close the three repository files.
 */
static void
term(void)
{
	int fds[3];
	size_t i;

	if (verbose)
		print_stats(&snaphdr.st);
	flush_cache();
	free_cache();

	fds[0] = ifd;
	fds[1] = sfd;
	fds[2] = cfd;

	for (i = 0; i < sizeof(fds) / sizeof(fds[0]); i++)
		fsync(fds[i]);
	for (i = 0; i < sizeof(fds) / sizeof(fds[0]); i++)
		close(fds[i]);
}

/* Print the command line synopsis on stderr and exit with status 1 */
static void
usage(void)
{
	fprintf(stderr,
	        "usage: %s [-clv] [-e id] [-r root] [-m message] [file]\n",
	        argv0);
	exit(1);
}

/*
 * Entry point.  Options:
 *   -c          check the blocks of every snapshot against their hashes
 *   -l          list the snapshots
 *   -e id       extract the snapshot whose hash is id
 *   -r root     use root as the repository directory
 *   -m message  message stored with a new snapshot
 *   -v          print statistics on exit
 * Without -c, -l or -e the input is stored as a new snapshot.  The
 * optional file operand names the input (or, with -e, the output);
 * when it is absent stdin (or, with -e, stdout) is used.
 */
int
main(int argc, char *argv[])
{
	uint8_t md[MDSIZE];
	char *id = NULL, *root = NULL, *msg = NULL;
	int fd = -1, lflag = 0, cflag = 0;

	ARGBEGIN {
	case 'c':
		cflag = 1;
		break;
	case 'e':
		id = EARGF(usage());
		break;
	case 'l':
		lflag = 1;
		break;
	case 'r':
		root = EARGF(usage());
		break;
	case 'm':
		msg = EARGF(usage());
		break;
	case 'v':
		verbose = 1;
		break;
	default:
		usage();
	} ARGEND

	/* At most one file operand; it is opened before any chdir() below */
	if (argc > 1) {
		usage();
	} else if (argc == 1) {
		if (id) {
			/* extracting: the operand is the output file */
			fd = open(argv[0], O_RDWR | O_CREAT, 0600);
			if (fd < 0)
				err(1, "open %s", argv[0]);
		} else {
			/* storing: the operand is the input file */
			fd = open(argv[0], O_RDONLY);
			if (fd < 0)
				err(1, "open %s", argv[0]);
		}
	} else {
		if (id)
			fd = STDOUT_FILENO;
		else
			fd = STDIN_FILENO;
	}

	if (root != NULL) {
		/* result of mkdir() is not checked; chdir() reports a bad root */
		mkdir(root, 0700);
		if (chdir(root) < 0)
			err(1, "chdir: %s", root);
	}

	init();

	if (cflag) {
		walk(check, NULL);
		term();
		return 0;
	}

	if (lflag) {
		walk(list, NULL);
		term();
		return 0;
	}

	if (id) {
		/* convert the textual id to a binary hash and extract it */
		str2bin(id, md);
		walk(extract, &(struct extract_args){ .md = md, .fd = fd });
	} else {
		dedup(fd, msg);
	}

	term();
	return 0;
}
