#include <sys/types.h>
#include <sys/stat.h>
#include <sys/file.h>

#include <err.h>
#include <fcntl.h>
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <lz4.h>
#include <openssl/sha.h>

#include "arg.h"
#include "dedup.h"

/*
 * Repository files.  They are opened relative to the current directory,
 * which main() changes to the -r root (if given) before init() runs.
 */
#define SNAPSF ".snapshots"	/* snapshot index: header, snapshots and their block descriptors */
#define STOREF ".store"		/* block store: header followed by the stored blocks */
#define CACHEF ".cache"		/* persisted cache entries (digest, offset, size) */

/* Return codes for the callbacks passed to walk_snap() */
enum {
	WALK_CONTINUE = 0,	/* go on to the next snapshot */
	WALK_STOP = 1		/* end the walk after the current snapshot */
};

/* State handed through walk_snap() to the extract() callback */
struct extract_args {
	uint8_t *md;	/* digest of the snapshot to extract */
	int fd;		/* descriptor the restored data is written to */
	int ret;	/* set to 0 by extract() once the snapshot was found */
};

/* In-memory copies of the on-disk headers; term() writes both back */
static struct snapshot_hdr snap_hdr;
static struct blk_hdr blk_hdr;
static struct cache *cache;	/* lookup from stored-block digest to offset/size */
static int ifd;			/* SNAPSF: snapshot index */
static int sfd;			/* STOREF: block store */
static int cfd;			/* CACHEF: persisted cache */
static int cache_dirty;		/* nonzero when save_cache() must rewrite CACHEF */
static unsigned long long cache_hits;	/* chunks found in the cache during this run */
static unsigned long long cache_misses;	/* chunks newly stored during this run */

int verbose;	/* count of -v flags; > 0 makes term() print statistics */
char *argv0;	/* program name; presumably set by ARGBEGIN from arg.h — confirm */

/*
 * Worst-case number of bytes needed to store a `size`-byte block.
 * With compression this is LZ4's compression bound; without it the
 * block is stored verbatim, so the bound is `size` itself.
 */
static size_t
compr_size(size_t size)
{
	if (!compr_enabled)
		return size;
	return LZ4_compressBound(size);
}

/*
 * Turn `insize` bytes at `in` into their stored form in `out`, which has
 * room for `outsize` bytes (use compr_size() to size it).  Returns the
 * number of bytes written to `out`.  With compression disabled the data
 * is copied verbatim.  Exits on failure.
 */
static size_t
compr(uint8_t *in, uint8_t *out, size_t insize, size_t outsize)
{
	int ret;

	if (!compr_enabled) {
		if (insize > outsize)
			errx(1, "%s: output buffer too small", __func__);
		memcpy(out, in, insize);
		return insize;
	}

	/*
	 * LZ4_compress_default() returns the compressed size, or 0 when the
	 * data does not fit in `outsize` bytes.  It never returns a negative
	 * value, so 0 must be treated as the failure case.
	 */
	ret = LZ4_compress_default((char *)in, (char *)out, insize, outsize);
	if (ret <= 0)
		errx(1, "LZ4_compress_default failed");
	return ret;
}

/*
 * Inverse of compr(): expand the `insize` stored bytes at `in` into
 * `out`, which has room for `outsize` bytes.  Returns the number of bytes
 * produced.  With compression disabled the stored bytes are copied
 * verbatim.  Exits on malformed input or if the result would not fit.
 */
static size_t
decompr(uint8_t *in, uint8_t *out, size_t insize, size_t outsize)
{
	int ret;

	if (!compr_enabled) {
		/*
		 * insize comes from an on-disk block descriptor (see extract()),
		 * so it must not be trusted to fit in `out`.
		 */
		if (insize > outsize)
			errx(1, "%s: block too large for output buffer", __func__);
		memcpy(out, in, insize);
		return insize;
	}

	/* Bounded by outsize; returns < 0 on malformed input */
	ret = LZ4_decompress_safe((char *)in, (char *)out, insize, outsize);
	if (ret < 0)
		errx(1, "LZ4_decompress_safe failed");
	return ret;
}

/* Write the `size` bytes at `md` to `fp` as lowercase hex; no newline. */
static void
print_md(FILE *fp, uint8_t *md, size_t size)
{
	size_t off = 0;

	while (off < size) {
		fprintf(fp, "%02x", md[off]);
		off++;
	}
}

/*
 * Print the statistics in `st`, followed by this run's cache hit/miss
 * counters, on stderr.  Prints nothing when st->nr_blks is zero; that
 * early exit also keeps the average below from dividing by zero.
 */
static void
print_stats(struct stats *st)
{
	unsigned long long avg;

	if (!st->nr_blks)
		return;

	/* Mean size of the unique (stored) blocks */
	avg = (unsigned long long)st->dedup_size / st->nr_blks;

	fprintf(stderr, "Original size: %llu bytes\n",
	        (unsigned long long)st->orig_size);
	fprintf(stderr, "Compressed size: %llu bytes\n",
	        (unsigned long long)st->compr_size);
	fprintf(stderr, "Deduplicated size: %llu bytes\n",
	        (unsigned long long)st->dedup_size);
	fprintf(stderr, "Min/avg/max block size: %llu/%llu/%llu bytes\n",
	        (unsigned long long)st->min_blk_size, avg,
	        (unsigned long long)st->max_blk_size);
	fprintf(stderr, "Number of unique blocks: %llu\n",
	        (unsigned long long)st->nr_blks);
	fprintf(stderr, "Cache hits: %llu\n", cache_hits);
	fprintf(stderr, "Cache misses: %llu\n", cache_misses);
}

/*
 * Append `snap` to the snapshot index.
 *
 * Sets snap->size to SNAPSHOT_SIZE plus BLK_DESC_SIZE per block
 * descriptor.  Writes the snapshot record and its descriptors at offset
 * snap_hdr.size, the current end of the index.  Then advances the
 * in-memory snap_hdr past the new record and bumps its snapshot count.
 * The header itself is written back later by term().  Exits if any size
 * computation would overflow.
 */
static void
append_snap(struct snapshot *snap)
{
	/* Total on-disk size: descriptors first, then the fixed record */
	if (mul_overflow(snap->nr_blk_descs, BLK_DESC_SIZE))
		errx(1, "%s: overflow", __func__);
	snap->size = snap->nr_blk_descs * BLK_DESC_SIZE;

	if (add_overflow(SNAPSHOT_SIZE, snap->size))
		errx(1, "%s: overflow", __func__);
	snap->size += SNAPSHOT_SIZE;

	/* Records are stored back to back; snap_hdr.size marks the end */
	xlseek(ifd, snap_hdr.size, SEEK_SET);
	write_snapshot(ifd, snap);
	write_snapshot_blk_descs(ifd, snap);

	if (add_overflow(snap_hdr.size, snap->size))
		errx(1, "%s: overflow", __func__);
	snap_hdr.size += snap->size;
	snap_hdr.nr_snapshots++;
}

/*
 * Allocate a zero-filled snapshot with no room for block descriptors;
 * use grow_snap() to make room.  Exits if memory is exhausted.
 */
static struct snapshot *
alloc_snap(void)
{
	struct snapshot *s = calloc(1, sizeof(*s));

	if (!s)
		err(1, "calloc");
	return s;
}

/* Release a snapshot from alloc_snap()/grow_snap(); NULL is accepted. */
static void
free_snap(struct snapshot *snap)
{
	free(snap);
}

/*
 * Resize `snap` so that its trailing blk_desc[] array holds exactly
 * `nr_blk_descs` entries, keeping existing contents.  Returns the
 * (possibly moved) snapshot; the old pointer must not be used afterwards.
 * Exits on size overflow or when memory is exhausted.
 */
static struct snapshot *
grow_snap(struct snapshot *snap, uint64_t nr_blk_descs)
{
	size_t bytes;

	if (mul_overflow(nr_blk_descs, sizeof(snap->blk_desc[0])))
		errx(1, "%s: overflow", __func__);
	bytes = nr_blk_descs * sizeof(snap->blk_desc[0]);

	if (add_overflow(bytes, sizeof(*snap)))
		errx(1, "%s: overflow", __func__);
	bytes += sizeof(*snap);

	if ((snap = realloc(snap, bytes)) == NULL)
		err(1, "realloc");
	return snap;
}

/* Return `size` zero-filled heap bytes; exits if allocation fails. */
static uint8_t *
alloc_buf(size_t size)
{
	uint8_t *buf = calloc(1, size);

	if (!buf)
		err(1, "calloc");
	return buf;
}

/* Release a buffer obtained from alloc_buf(); NULL is accepted. */
static void
free_buf(uint8_t *buf)
{
	free(buf);
}

/*
 * Store the SHA-256 digest of the `size` bytes at `buf` in `md`.
 * `md` must have room for a SHA-256 digest (32 bytes).
 */
static void
hash_blk(uint8_t *buf, size_t size, uint8_t *md)
{
	SHA256_CTX sha;

	SHA256_Init(&sha);
	SHA256_Update(&sha, buf, size);
	SHA256_Final(md, &sha);
}

/*
 * Read the stored bytes of the block described by `blk_desc` into `buf`:
 * seek to blk_desc->offset in the block store and read blk_desc->size
 * bytes.  `buf` must have room for blk_desc->size bytes; this function
 * does not check that.  Exits if xread() returns 0 (end of file).
 */
static void
read_blk(uint8_t *buf, struct blk_desc *blk_desc)
{
	xlseek(sfd, blk_desc->offset, SEEK_SET);
	if (!xread(sfd, buf, blk_desc->size))
		errx(1, "read: unexpected EOF");
}

/*
 * Write the blk_desc->size bytes at `buf` at the end of the block store
 * (offset blk_hdr.size), then advance blk_hdr.size past them.
 * dedup_chunk() sets blk_desc->offset to blk_hdr.size before calling.
 */
static void
append_blk(uint8_t *buf, struct blk_desc *blk_desc)
{
	xlseek(sfd, blk_hdr.size, SEEK_SET);
	xwrite(sfd, buf, blk_desc->size);
	blk_hdr.size = blk_hdr.size + blk_desc->size;
}

static void
dedup_chunk(struct snapshot *snap, uint8_t *chunkp, size_t chunk_size)
{
	uint8_t md[MDSIZE];
	struct cache_entry cache_entry;
	uint8_t *compr_buf;
	size_t n;

	compr_buf = alloc_buf(compr_size(BLKSIZE_MAX));

	n = compr(chunkp, compr_buf, chunk_size, compr_size(BLKSIZE_MAX));
	hash_blk(compr_buf, n, md);

	snap_hdr.st.orig_size += chunk_size;
	snap_hdr.st.compr_size += n;

	memcpy(cache_entry.md, md, sizeof(cache_entry.md));
	if (lookup_cache_entry(cache, &cache_entry) < 0) {
		struct blk_desc blk_desc;

		memcpy(&blk_desc.md, md, sizeof(blk_desc.md));
		blk_desc.offset = blk_hdr.size;
		blk_desc.size = n;

		snap->blk_desc[snap->nr_blk_descs++] = blk_desc;
		append_blk(compr_buf, &blk_desc);

		cache_entry.offset = blk_desc.offset;
		cache_entry.size = blk_desc.size;
		cache_dirty = 1;
		add_cache_entry(cache, &cache_entry);
		cache_misses++;

		snap_hdr.st.dedup_size += blk_desc.size;
		snap_hdr.st.nr_blks++;

		if (blk_desc.size > snap_hdr.st.max_blk_size)
			snap_hdr.st.max_blk_size = blk_desc.size;
		if (blk_desc.size < snap_hdr.st.min_blk_size)
			snap_hdr.st.min_blk_size = blk_desc.size;
	} else {
		struct blk_desc blk_desc;

		memcpy(&blk_desc.md, cache_entry.md, sizeof(blk_desc.md));
		blk_desc.offset = cache_entry.offset;
		blk_desc.size = cache_entry.size;
		snap->blk_desc[snap->nr_blk_descs++] = blk_desc;
		cache_hits++;
	}

	free(compr_buf);
}

/*
 * Read all of `fd` through the chunker and store it as a new snapshot.
 *
 * Each chunk is deduplicated by dedup_chunk().  The snapshot's identity
 * (snap->md) is the SHA-256 of the entire input stream.  `msg`, if not
 * NULL, is stored as the snapshot message, truncated to fit snap->msg.
 * Empty input produces no snapshot.  Exits if reading the input fails.
 */
static void
dedup(int fd, char *msg)
{
	struct snapshot *snap;
	struct chunker *chunker;
	SHA256_CTX ctx;
	uint64_t cap = 0;	/* blk_desc[] slots currently allocated */
	ssize_t n;

	snap = alloc_snap();
	chunker = alloc_chunker(fd, BLKSIZE_MAX);

	SHA256_Init(&ctx);
	while ((n = fill_chunker(chunker)) > 0) {
		uint8_t *chunkp;
		size_t chunk_size;

		chunkp = get_chunk(chunker, &chunk_size);
		SHA256_Update(&ctx, chunkp, chunk_size);

		/*
		 * Grow geometrically.  Growing by one slot per chunk means one
		 * realloc (and potentially a full copy) per chunk.  Only
		 * nr_blk_descs is ever written out, so spare slots are harmless.
		 */
		if (snap->nr_blk_descs == cap) {
			cap = (cap == 0) ? 64 : cap * 2;
			snap = grow_snap(snap, cap);
		}
		dedup_chunk(snap, chunkp, chunk_size);
		drain_chunker(chunker);
	}
	/*
	 * Assumes a negative return means a read error (read(2)-style).  Do
	 * not silently record a truncated snapshot in that case.
	 */
	if (n < 0)
		errx(1, "%s: failed to read input", __func__);
	SHA256_Final(snap->md, &ctx);

	if (snap->nr_blk_descs > 0) {
		if (msg != NULL) {
			size_t size;

			size = strlen(msg) + 1;
			if (size > sizeof(snap->msg))
				size = sizeof(snap->msg);
			memcpy(snap->msg, msg, size);
			snap->msg[size - 1] = '\0';
		}

		append_snap(snap);
	}

	free_chunker(chunker);
	free_snap(snap);
}

/*
 * walk_snap() callback.  If `snap` is the snapshot whose digest is
 * args->md, it decompresses each of its blocks in order and writes the
 * data to args->fd.  It then sets args->ret to 0 and stops the walk.
 * Any other snapshot is skipped.
 *
 * Exits on a short read (read_blk()) or a decompression error
 * (decompr()).  It also exits when a descriptor claims a stored size
 * larger than any block this program writes, which means the index is
 * corrupted.
 */
static int
extract(struct snapshot *snap, void *arg)
{
	uint8_t *buf[2];	/* [0]: decompressed data, [1]: stored block */
	struct extract_args *args = arg;
	size_t maxstored;
	uint64_t i;

	if (memcmp(snap->md, args->md, sizeof(snap->md)) != 0)
		return WALK_CONTINUE;

	maxstored = compr_size(BLKSIZE_MAX);
	buf[0] = alloc_buf(BLKSIZE_MAX);
	buf[1] = alloc_buf(maxstored);
	for (i = 0; i < snap->nr_blk_descs; i++) {
		struct blk_desc *blk_desc;
		size_t blksize;

		blk_desc = &snap->blk_desc[i];
		/* read_blk() reads blk_desc->size bytes unchecked into buf[1] */
		if (blk_desc->size > maxstored)
			errx(1, "%s: block at offset %llu has invalid size %llu",
			     __func__, (unsigned long long)blk_desc->offset,
			     (unsigned long long)blk_desc->size);
		read_blk(buf[1], blk_desc);
		blksize = decompr(buf[1], buf[0], blk_desc->size, BLKSIZE_MAX);
		xwrite(args->fd, buf[0], blksize);
	}
	free_buf(buf[1]);
	free_buf(buf[0]);
	args->ret = 0;
	return WALK_STOP;
}

/*
 * walk_snap() callback: verify every block referenced by `snap`.
 *
 * Each stored block is re-read and its SHA-256 recomputed, then compared
 * with the digest recorded in its descriptor.  Mismatches, and
 * descriptors whose size exceeds the largest storable block, are
 * reported on stderr.  They do not stop the walk or change the exit
 * status.  `arg` is unused.
 */
static int
check(struct snapshot *snap, void *arg)
{
	uint8_t md[MDSIZE];
	uint8_t *buf;
	SHA256_CTX ctx;
	size_t maxstored;
	uint64_t i;

	maxstored = compr_size(BLKSIZE_MAX);
	buf = alloc_buf(maxstored);
	/*
	 * Calculate hash for each block and compare
	 * against snapshot entry block descriptor
	 */
	for (i = 0; i < snap->nr_blk_descs; i++) {
		struct blk_desc *blk_desc;

		blk_desc = &snap->blk_desc[i];

		/* read_blk() does not bound the read by the buffer size */
		if (blk_desc->size > maxstored) {
			fprintf(stderr, "Invalid block size\n");
			fprintf(stderr, "  Offset: %llu\n",
			        (unsigned long long)blk_desc->offset);
			fprintf(stderr, "  Size: %llu\n",
			        (unsigned long long)blk_desc->size);
			continue;
		}
		read_blk(buf, blk_desc);

		SHA256_Init(&ctx);
		SHA256_Update(&ctx, buf, blk_desc->size);
		SHA256_Final(md, &ctx);

		if (memcmp(blk_desc->md, md, sizeof(blk_desc->md)) == 0)
			continue;

		fprintf(stderr, "Block hash mismatch\n");
		fprintf(stderr, "  Expected hash: ");
		/* The expected value is the block's digest, not the snapshot's */
		print_md(stderr, blk_desc->md, sizeof(blk_desc->md));
		fputc('\n', stderr);
		fprintf(stderr, "  Actual hash: ");
		print_md(stderr, md, sizeof(md));
		fputc('\n', stderr);
		fprintf(stderr, "  Offset: %llu\n",
		        (unsigned long long)blk_desc->offset);
		fprintf(stderr, "  Size: %llu\n",
		        (unsigned long long)blk_desc->size);
	}
	free_buf(buf);
	return WALK_CONTINUE;
}

/*
 * walk_snap() callback: print the snapshot digest in hex on stdout,
 * followed by a tab and its message when it has one, then a newline.
 * `arg` is unused.  Never stops the walk.
 */
static int
list(struct snapshot *snap, void *arg)
{
	print_md(stdout, snap->md, sizeof(snap->md));
	if (snap->msg[0] == '\0') {
		putchar('\n');
		return WALK_CONTINUE;
	}
	printf("\t%s\n", snap->msg);
	return WALK_CONTINUE;
}

/*
 * walk_snap() callback: add a cache entry for every block descriptor of
 * `snap` and mark the cache dirty so that term() persists it.
 *
 * The digest, offset and size of each entry come straight from the
 * descriptor, so the block store does not need to be read here.  The
 * previous version read and hashed every block and then discarded the
 * result.  Use -c (check()) to verify stored blocks.  `arg` is unused.
 */
static int
rebuild_cache(struct snapshot *snap, void *arg)
{
	uint64_t i;

	for (i = 0; i < snap->nr_blk_descs; i++) {
		struct cache_entry cache_entry;
		struct blk_desc *blk_desc;

		blk_desc = &snap->blk_desc[i];
		memcpy(cache_entry.md, blk_desc->md, sizeof(cache_entry.md));
		cache_entry.offset = blk_desc->offset;
		cache_entry.size = blk_desc->size;
		add_cache_entry(cache, &cache_entry);
	}
	cache_dirty = 1;
	return WALK_CONTINUE;
}

/*
 * Visit every snapshot in the index, reading sequentially from the
 * current offset of ifd (callers seek to SNAP_HDR_SIZE first).  Each
 * snapshot is loaded together with its descriptors and passed to
 * fn(snap, arg).  The walk ends early when fn returns WALK_STOP.  The
 * snapshot is freed as soon as fn returns, so fn must not keep it.
 */
static void
walk_snap(int (*fn)(struct snapshot *, void *), void *arg)
{
	uint64_t left;

	for (left = snap_hdr.nr_snapshots; left > 0; left--) {
		struct snapshot *snap = alloc_snap();
		int stop;

		read_snapshot(ifd, snap);
		snap = grow_snap(snap, snap->nr_blk_descs);
		read_snapshot_descs(ifd, snap);

		stop = fn(snap, arg) == WALK_STOP;
		free(snap);
		if (stop)
			break;
	}
}

/*
 * Exit with a diagnostic unless the major/minor format version encoded
 * in the header flags `v` equals the version this program was built for.
 */
static void
match_ver(uint64_t v)
{
	uint8_t maj = (v >> VER_MAJ_SHIFT) & VER_MAJ_MASK;
	uint8_t min = v & VER_MIN_MASK;

	if (maj != VER_MAJ || min != VER_MIN)
		errx(1, "format version mismatch: expected %u.%u but got %u.%u",
		     VER_MAJ, VER_MIN, maj, min);
}

/*
 * walk_cache() callback used by save_cache(): write one entry to the
 * cache file at cfd's current offset.  Always returns 0.
 */
static int
flush_cache(struct cache_entry *cache_entry)
{
	write_cache_entry(cfd, cache_entry);
	return 0;
}

static void
load_cache(void)
{
	struct stat sb;
	uint64_t nr_entries;
	uint64_t i;

	xlseek(cfd, 0, SEEK_SET);

	if (fstat(cfd, &sb) < 0)
		err(1, "fstat");

	nr_entries = sb.st_size / CACHE_ENTRY_SIZE;
	if (nr_entries == 0) {
		xlseek(ifd, SNAP_HDR_SIZE, SEEK_SET);
		walk_snap(rebuild_cache, NULL);
		return;
	}

	for (i = 0; i < nr_entries; i++) {
		struct cache_entry cache_entry;

		read_cache_entry(cfd, &cache_entry);
		add_cache_entry(cache, &cache_entry);
	}
}

/*
 * If the cache changed during this run, rewrite the cache file from
 * offset 0 with every in-memory entry.  The file is not truncated
 * afterwards.
 */
static void
save_cache(void)
{
	if (!cache_dirty)
		return;
	xlseek(cfd, 0, SEEK_SET);
	walk_cache(cache, flush_cache);
}

/*
 * Initialize the in-memory block-store header for a new repository.  It
 * records the format version and the current compression setting, and
 * marks the store as holding only its header.
 */
static void
init_blk_hdr(void)
{
	blk_hdr.size = BLK_HDR_SIZE;
	blk_hdr.flags = (VER_MAJ << VER_MAJ_SHIFT) | VER_MIN |
	                (compr_enabled << COMPR_ENABLED_SHIFT);
}

/*
 * Read the block-store header from the start of STOREF, verify its format
 * version, and adopt the compression setting recorded in it.  The stored
 * setting overrides compr_enabled, including a -Z given on the command line.
 */
static void
load_blk_hdr(void)
{
	xlseek(sfd, 0, SEEK_SET);
	read_blk_hdr(sfd, &blk_hdr);
	match_ver(blk_hdr.flags);

	compr_enabled = (blk_hdr.flags >> COMPR_ENABLED_SHIFT) &
	                COMPR_ENABLED_MASK;
}

/* Write the in-memory blk_hdr back to the start of STOREF. */
static void
save_blk_hdr(void)
{
	xlseek(sfd, 0, SEEK_SET);
	write_blk_hdr(sfd, &blk_hdr);
}

/*
 * Initialize the in-memory snapshot-index header for a new repository.
 * min_blk_size starts at the largest possible stored block size, so the
 * first stored block lowers it.
 */
static void
init_snap_hdr(void)
{
	snap_hdr.st.min_blk_size = compr_size(BLKSIZE_MAX);
	snap_hdr.size = SNAP_HDR_SIZE;
	snap_hdr.flags = (VER_MAJ << VER_MAJ_SHIFT) | VER_MIN;
}

/* Read the snapshot-index header from the start of SNAPSF and verify its version. */
static void
load_snap_hdr(void)
{
	xlseek(ifd, 0, SEEK_SET);
	read_snap_hdr(ifd, &snap_hdr);
	match_ver(snap_hdr.flags);
}

/* Write the in-memory snap_hdr back to the start of SNAPSF. */
static void
save_snap_hdr(void)
{
	xlseek(ifd, 0, SEEK_SET);
	write_snap_hdr(ifd, &snap_hdr);
}

/*
 * Open the repository in the current directory and load its state.
 *
 * With `iflag`, SNAPSF and STOREF must not exist yet (O_EXCL), and the
 * headers are initialized in memory.  Otherwise both files must exist,
 * and their headers are read and version-checked.  All three files are
 * locked exclusively without blocking, so a concurrent run exits instead
 * of waiting.  Finally the block cache is allocated and loaded.  Exits
 * on any failure.
 */
static void
init(int iflag)
{
	int flags, cflags;

	flags = O_RDWR;
	if (iflag)
		flags |= O_CREAT | O_EXCL;

	ifd = open(SNAPSF, flags, 0600);
	if (ifd < 0)
		err(1, "open %s", SNAPSF);

	sfd = open(STOREF, flags, 0600);
	if (sfd < 0)
		err(1, "open %s", STOREF);

	/*
	 * The cache file does not have to exist
	 * and will be created again if deleted.
	 *
	 * A new repository must start with an empty cache.  A leftover
	 * CACHEF from an earlier repository would otherwise be loaded, and
	 * its entries would point at blocks that do not exist in the new
	 * store.  Truncating is safe here: O_EXCL above has just shown that
	 * no repository index or store exists in this directory.
	 */
	cflags = O_RDWR | O_CREAT;
	if (iflag)
		cflags |= O_TRUNC;
	cfd = open(CACHEF, cflags, 0600);
	if (cfd < 0)
		err(1, "open %s", CACHEF);

	if (flock(ifd, LOCK_NB | LOCK_EX) < 0 ||
	    flock(sfd, LOCK_NB | LOCK_EX) < 0 ||
	    flock(cfd, LOCK_NB | LOCK_EX) < 0)
		err(1, "flock");

	if (iflag) {
		init_snap_hdr();
		init_blk_hdr();
	} else {
		load_snap_hdr();
		load_blk_hdr();
	}

	cache = alloc_cache();
	load_cache();
}

/*
 * Flush all state and release the repository.
 *
 * Prints statistics if verbose > 0, writes back the cache (if dirty) and
 * both headers, then fsyncs and closes every repository file.  A failing
 * fsync() or close() means the data may not have reached the disk.  It
 * is reported, and the program exits non-zero instead of returning
 * success.
 */
static void
term(void)
{
	if (verbose > 0)
		print_stats(&snap_hdr.st);

	save_cache();
	free_cache(cache);

	save_blk_hdr();
	save_snap_hdr();

	if (fsync(cfd) < 0 || fsync(sfd) < 0 || fsync(ifd) < 0)
		err(1, "fsync");

	if (close(cfd) < 0 || close(sfd) < 0 || close(ifd) < 0)
		err(1, "close");
}

/* Print the command-line synopsis to stderr and exit with status 1. */
static void
usage(void)
{
	fprintf(stderr,
	        "usage: %s [-Zcilv] [-e id] [-r root] "
	        "[-m message] [file]\n",
	        argv0);
	exit(1);
}

/*
 * Entry point.
 *
 * The optional file argument is opened before -r changes into the
 * repository directory (creating it if needed), so relative paths refer
 * to the original working directory.  After the repository is opened,
 * the first matching mode runs:
 *   -i       initialize a new repository
 *   -c       verify all stored blocks
 *   -l       list snapshots
 *   -e id    extract snapshot `id` to `file` or stdout
 *   default  deduplicate `file` or stdin into a new snapshot, tagged
 *            with the -m message
 */
int
main(int argc, char *argv[])
{
	uint8_t md[MDSIZE];
	char *id = NULL, *root = NULL, *msg = NULL;
	int iflag = 0, lflag = 0, cflag = 0;
	int fd = -1;

	ARGBEGIN {
	case 'Z':
		compr_enabled = 0;
		break;
	case 'c':
		cflag = 1;
		break;
	case 'e':
		id = EARGF(usage());
		break;
	case 'i':
		iflag = 1;
		break;
	case 'l':
		lflag = 1;
		break;
	case 'r':
		root = EARGF(usage());
		break;
	case 'm':
		msg = EARGF(usage());
		break;
	case 'v':
		verbose++;
		break;
	default:
		usage();
	} ARGEND

	if (argc > 1) {
		usage();
	} else if (argc == 1) {
		if (id) {
			/*
			 * Extraction target.  Truncate it, like a shell ">"
			 * redirect, so that restoring over an existing, larger
			 * file does not leave its old tail behind.
			 */
			fd = open(argv[0], O_WRONLY | O_CREAT | O_TRUNC, 0600);
			if (fd < 0)
				err(1, "open %s", argv[0]);
		} else {
			fd = open(argv[0], O_RDONLY);
			if (fd < 0)
				err(1, "open %s", argv[0]);
		}
	} else {
		if (id)
			fd = STDOUT_FILENO;
		else
			fd = STDIN_FILENO;
	}

	if (root != NULL) {
		/* Failure (e.g. it already exists) is caught by chdir below */
		mkdir(root, 0700);
		if (chdir(root) < 0)
			err(1, "chdir: %s", root);
	}

	init(iflag);

	if (iflag) {
		term();
		return 0;
	}

	if (cflag) {
		xlseek(ifd, SNAP_HDR_SIZE, SEEK_SET);
		walk_snap(check, NULL);
		term();
		return 0;
	}

	if (lflag) {
		xlseek(ifd, SNAP_HDR_SIZE, SEEK_SET);
		walk_snap(list, NULL);
		term();
		return 0;
	}

	if (id) {
		struct extract_args args;

		xlseek(ifd, SNAP_HDR_SIZE, SEEK_SET);
		str2bin(id, md);
		args.md = md;
		args.fd = fd;
		args.ret = -1;
		walk_snap(extract, &args);
		if (args.ret != 0)
			errx(1, "unknown snapshot: %s", id);
	} else {
		dedup(fd, msg);
	}

	term();
	return 0;
}
