#include <sys/types.h>
#include <sys/stat.h>
#include <sys/file.h>

#include <err.h>
#include <fcntl.h>
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include "arg.h"
#include "blake2.h"
#include "dedup.h"

#define SNAPSF ".snapshots"
#define STOREF ".store"

/* Return values for the callbacks driven by walk_snap() */
enum {
	WALK_CONTINUE,	/* keep visiting the remaining snapshots */
	WALK_STOP	/* walk_snap() stops after the current snapshot */
};

/* Argument block passed to extract() through walk_snap() */
struct extract_args {
	uint8_t *md;	/* digest of the snapshot to extract */
	int fd;		/* descriptor the decompressed blocks are written to */
	int ret;	/* set to 0 by extract() once the snapshot was found */
};

/* Header of the snapshot index; read from / written to ifd */
static struct snap_hdr snap_hdr;
/* Header of the block store; read from / written to sfd */
static struct blk_hdr blk_hdr;
/* Cache of known block descriptors, filled by build_icache() */
static struct icache *icache;
/* Descriptor of the snapshot index file (SNAPSF) */
static int ifd;
/* Descriptor of the block store file (STOREF) */
static int sfd;
/* Defaults; replaced by -H/-Z or by the values stored in the block header */
static int hash_algo = HASH_BLAKE2B;
static int compr_algo = COMPR_LZ4;

/* Incremented once per -v flag */
int verbose;
/* Program name printed by usage(); presumably set by ARGBEGIN - confirm in arg.h */
char *argv0;

/*
 * Write size bytes starting at md to fp as lowercase hex,
 * two digits per byte, without separators or a trailing newline.
 */
static void
print_md(FILE *fp, uint8_t *md, size_t size)
{
	const uint8_t *p;
	size_t left;

	for (p = md, left = size; left > 0; left--, p++)
		fprintf(fp, "%02x", *p);
}

/*
 * Print the deduplication statistics in st, followed by the index
 * cache hit percentage, to stderr.  Nothing is printed when no
 * blocks have been recorded.
 */
static void
print_stats(struct stats *st)
{
	unsigned long long hits, misses;
	unsigned long long avg_blk_size;
	double dedup_ratio;
	double hitratio = 0;

	if (st->nr_blks == 0)
		return;

	dedup_ratio = (double)st->orig_size / st->dedup_size;
	avg_blk_size = (unsigned long long)st->dedup_size / st->nr_blks;

	fprintf(stderr, "Original size: %llu bytes\n",
	        (unsigned long long)st->orig_size);
	fprintf(stderr, "Compressed size: %llu bytes\n",
	        (unsigned long long)st->compr_size);
	fprintf(stderr, "Deduplicated size: %llu bytes\n",
	        (unsigned long long)st->dedup_size);
	fprintf(stderr, "Deduplication ratio: %.2f\n", dedup_ratio);
	fprintf(stderr, "Min/avg/max block size: %llu/%llu/%llu bytes\n",
	        (unsigned long long)st->min_blk_size,
	        avg_blk_size,
	        (unsigned long long)st->max_blk_size);
	fprintf(stderr, "Number of unique blocks: %llu\n",
	        (unsigned long long)st->nr_blks);

	/* Leave the ratio at zero when the cache was never consulted */
	icache_stats(icache, &hits, &misses);
	if (hits != 0 || misses != 0)
		hitratio = (double)hits / (hits + misses);

	fprintf(stderr, "Index cache hit percentage: %.2f%%\n",
	        100 * hitratio);
}

/*
 * Allocate a zero-filled snapshot with no room for block
 * descriptors.  Exits the program on allocation failure.
 */
static struct snap *
alloc_snap(void)
{
	struct snap *s = calloc(1, sizeof(*s));

	if (!s)
		err(1, "%s", __func__);
	return s;
}

/* Release a snapshot obtained from alloc_snap() or grow_snap() */
static void
free_snap(struct snap *snap)
{
	free(snap);
}

/*
 * The snapshot hash is calculated over the
 * hash of its block descriptors.
 */
static void
hash_snap(struct snap *snap, uint8_t *md)
{
	struct hash_ctx ctx;
	uint64_t i;

	if (hash_init(&ctx, hash_algo, MD_SIZE) < 0)
		errx(1, "hash_init failed");
	for (i = 0; i < snap->nr_blk_descs; i++) {
		struct blk_desc *blk_desc;

		blk_desc = &snap->blk_desc[i];
		hash_update(&ctx, blk_desc->md, sizeof(blk_desc->md));
	}
	hash_final(&ctx, md, MD_SIZE);
}

/*
 * Resize snap so that it can hold nr_blk_descs block descriptors
 * after the fixed part of the structure.  Returns the (possibly
 * moved) snapshot; exits on size overflow or allocation failure.
 */
static struct snap *
grow_snap(struct snap *snap, uint64_t nr_blk_descs)
{
	const size_t desc_size = sizeof(snap->blk_desc[0]);
	void *p;

	/*
	 * One check covers both the multiplication and the addition
	 * of the fixed part: the total must fit in a size_t.
	 */
	if (nr_blk_descs > (SIZE_MAX - sizeof(*snap)) / desc_size)
		errx(1, "%s: overflow", __func__);

	p = realloc(snap, sizeof(*snap) + nr_blk_descs * desc_size);
	if (p == NULL)
		err(1, "%s", __func__);
	return p;
}

/*
 * Append snap and its block descriptors at the end of the snapshot
 * index (ifd) and account for it in snap_hdr.
 *
 * snap->size is set to the on-disk size of the snapshot.  All
 * overflow checks are done before anything is written, so a failed
 * check never leaves a snapshot in the index that the header does
 * not account for.
 */
static void
append_snap(struct snap *snap)
{
	if (snap->nr_blk_descs > UINT64_MAX / BLK_DESC_SIZE)
		errx(1, "%s: overflow", __func__);
	snap->size = snap->nr_blk_descs * BLK_DESC_SIZE;

	if (snap->size > UINT64_MAX - SNAPSHOT_SIZE)
		errx(1, "%s: overflow", __func__);
	snap->size += SNAPSHOT_SIZE;

	/* Validate the header updates before touching the file */
	if (snap_hdr.size > UINT64_MAX - snap->size)
		errx(1, "%s: overflow", __func__);
	if (snap_hdr.nr_snaps > UINT64_MAX - 1)
		errx(1, "%s: overflow", __func__);

	xlseek(ifd, snap_hdr.size, SEEK_SET);
	write_snap(ifd, snap);
	write_snap_blk_descs(ifd, snap);

	snap_hdr.size += snap->size;
	snap_hdr.nr_snaps++;
}

/*
 * Allocate a zero-filled buffer of size bytes.
 * Exits the program on allocation failure.
 */
static uint8_t *
alloc_buf(size_t size)
{
	uint8_t *buf = calloc(1, size);

	if (!buf)
		err(1, "%s", __func__);
	return buf;
}

/* Release a buffer obtained from alloc_buf() */
static void
free_buf(uint8_t *buf)
{
	free(buf);
}

/*
 * Hash the size bytes at buf with the selected hash algorithm
 * and store the MD_SIZE byte digest in md.
 */
static void
hash_blk(uint8_t *buf, size_t size, uint8_t *md)
{
	struct hash_ctx ctx;
	int rc;

	rc = hash_init(&ctx, hash_algo, MD_SIZE);
	if (rc < 0)
		errx(1, "hash_init failed");

	hash_update(&ctx, buf, size);
	hash_final(&ctx, md, MD_SIZE);
}

/*
 * Read the block described by blk_desc from the store (sfd) into
 * buf.  The caller must provide at least blk_desc->size bytes.
 * Exits if the block cannot be read in full.
 */
static void
read_blk(uint8_t *buf, struct blk_desc *blk_desc)
{
	ssize_t nread;

	xlseek(sfd, blk_desc->offset, SEEK_SET);
	nread = xread(sfd, buf, blk_desc->size);

	if (nread == 0)
		errx(1, "%s: unexpected EOF", __func__);
	else if (nread != blk_desc->size)
		errx(1, "%s: short read", __func__);
}

/*
 * Append the blk_desc->size bytes at buf to the end of the block
 * store (sfd) and grow blk_hdr.size accordingly.
 *
 * The overflow check is done before writing, so a failed check
 * never leaves data in the store that the header does not cover.
 */
static void
append_blk(uint8_t *buf, struct blk_desc *blk_desc)
{
	if (blk_hdr.size > UINT64_MAX - blk_desc->size)
		errx(1, "%s: overflow", __func__);

	xlseek(sfd, blk_hdr.size, SEEK_SET);
	xwrite(sfd, buf, blk_desc->size);

	blk_hdr.size += blk_desc->size;
}

/*
 * Compress one chunk, hash the compressed data and record the
 * resulting block descriptor in snap.  The compressed block is
 * appended to the store only when its hash is not already present
 * in the index cache.  The caller must have reserved room in snap
 * for one more descriptor.
 */
static void
dedup_chunk(struct snap *snap, uint8_t *chunkp, size_t chunk_size)
{
	uint8_t md[MD_SIZE];
	struct blk_desc blk_desc;
	struct compr_ctx ctx;
	uint8_t *compr_buf;
	size_t clen, bufsize;
	int known;

	if (compr_init(&ctx, compr_algo) < 0)
		errx(1, "compr_init failed");
	bufsize = compr_size(&ctx, BLKSIZE_MAX);
	compr_buf = alloc_buf(bufsize);

	clen = compr(&ctx, chunkp, compr_buf, chunk_size, bufsize);
	hash_blk(compr_buf, clen, md);

	snap_hdr.st.orig_size += chunk_size;
	snap_hdr.st.compr_size += clen;

	memcpy(blk_desc.md, md, sizeof(blk_desc.md));
	known = !(lookup_icache(icache, &blk_desc) < 0);
	if (!known) {
		/* New block: it will live at the current end of the store */
		blk_desc.offset = blk_hdr.size;
		blk_desc.size = clen;
	}

	snap->blk_desc[snap->nr_blk_descs++] = blk_desc;

	if (!known) {
		append_blk(compr_buf, &blk_desc);

		insert_icache(icache, &blk_desc);

		snap_hdr.st.dedup_size += blk_desc.size;
		snap_hdr.st.nr_blks++;

		if (blk_desc.size > snap_hdr.st.max_blk_size)
			snap_hdr.st.max_blk_size = blk_desc.size;
		if (blk_desc.size < snap_hdr.st.min_blk_size)
			snap_hdr.st.min_blk_size = blk_desc.size;
	}

	free(compr_buf);
	compr_final(&ctx);
}

/*
 * Split the data read from fd into chunks, deduplicate each chunk
 * and, if at least one chunk was produced, append a new snapshot
 * to the index.  msg, when not NULL, is stored as the snapshot
 * message and truncated to fit.
 */
static void
dedup(int fd, char *msg)
{
	struct snap *snap;
	struct chunker *chunker;

	snap = alloc_snap();
	chunker = alloc_chunker(fd, BLKSIZE_MIN, BLKSIZE_MAX,
	                        HASHMASK_BITS, WINSIZE);

	for (;;) {
		uint8_t *chunkp;
		size_t chunk_size;

		if (!(fill_chunker(chunker) > 0))
			break;

		chunkp = get_chunk(chunker, &chunk_size);
		/* Make room for the descriptor dedup_chunk() is about to add */
		snap = grow_snap(snap, snap->nr_blk_descs + 1);
		dedup_chunk(snap, chunkp, chunk_size);
		drain_chunker(chunker);
	}

	if (snap->nr_blk_descs > 0) {
		if (msg != NULL) {
			size_t len;

			/* Copy as much as fits and always NUL-terminate */
			len = strlen(msg);
			if (len >= sizeof(snap->msg))
				len = sizeof(snap->msg) - 1;
			memcpy(snap->msg, msg, len);
			snap->msg[len] = '\0';
		}
		hash_snap(snap, snap->md);
		append_snap(snap);
	}

	free_chunker(chunker);
	free_snap(snap);
}

/*
 * walk_snap() callback: if snap is the snapshot whose digest is
 * given in the extract_args passed as arg, decompress all of its
 * blocks to args->fd, set args->ret to 0 and stop the walk.
 * Snapshots with a different digest are skipped.
 *
 * Block sizes come from the on-disk index, so each one is checked
 * against the size of the read buffer before the block is read.
 */
static int
extract(struct snap *snap, void *arg)
{
	uint8_t *buf[2];
	struct extract_args *args = arg;
	struct compr_ctx ctx;
	size_t csize;
	uint64_t i;

	if (memcmp(snap->md, args->md, sizeof(snap->md)) != 0)
		return WALK_CONTINUE;

	if (compr_init(&ctx, compr_algo) < 0)
		errx(1, "compr_init failed");
	csize = compr_size(&ctx, BLKSIZE_MAX);
	buf[0] = alloc_buf(BLKSIZE_MAX);	/* decompressed data */
	buf[1] = alloc_buf(csize);		/* compressed data */
	for (i = 0; i < snap->nr_blk_descs; i++) {
		struct blk_desc *blk_desc;
		size_t blksize;

		blk_desc = &snap->blk_desc[i];
		/* A corrupted descriptor must not overflow buf[1] */
		if (blk_desc->size > csize)
			errx(1, "%s: block size %llu out of range", __func__,
			     (unsigned long long)blk_desc->size);
		read_blk(buf[1], blk_desc);
		blksize = decompr(&ctx, buf[1], buf[0], blk_desc->size, BLKSIZE_MAX);
		xwrite(args->fd, buf[0], blksize);
	}
	free_buf(buf[1]);
	free_buf(buf[0]);
	compr_final(&ctx);
	args->ret = 0;
	return WALK_STOP;
}

/*
 * walk_snap() callback: hash every block referenced by the given
 * snapshot and compare its hash with the one stored in the
 * corresponding block descriptor.
 *
 * arg points to an int that is set to -1 when a problem is found.
 * A descriptor whose size exceeds the read buffer is reported as
 * corruption and skipped instead of being read, since reading it
 * would overflow the buffer.  Always continues the walk.
 */
static int
check_snap(struct snap *snap, void *arg)
{
	struct compr_ctx ctx;
	uint8_t *buf;
	size_t bufsize;
	int *ret = arg;
	uint64_t i;

	if (verbose > 0) {
		fprintf(stderr, "Checking snapshot: ");
		print_md(stderr, snap->md, sizeof(snap->md));
		fputc('\n', stderr);
	}

	if (compr_init(&ctx, compr_algo) < 0)
		errx(1, "compr_init failed");
	bufsize = compr_size(&ctx, BLKSIZE_MAX);
	buf = alloc_buf(bufsize);
	for (i = 0; i < snap->nr_blk_descs; i++) {
		uint8_t md[MD_SIZE];
		struct blk_desc *blk_desc;

		blk_desc = &snap->blk_desc[i];
		if (blk_desc->size > bufsize) {
			fprintf(stderr, "Block size out of range\n");
			fprintf(stderr, "  Offset: %llu\n",
			        (unsigned long long)blk_desc->offset);
			fprintf(stderr, "  Size: %llu\n",
			        (unsigned long long)blk_desc->size);
			*ret = -1;
			continue;
		}
		read_blk(buf, blk_desc);
		hash_blk(buf, blk_desc->size, md);

		if (memcmp(blk_desc->md, md, sizeof(blk_desc->md)) == 0)
			continue;

		fprintf(stderr, "Block hash mismatch\n");
		fprintf(stderr, "  Expected hash: ");
		print_md(stderr, blk_desc->md, sizeof(blk_desc->md));
		fputc('\n', stderr);
		fprintf(stderr, "  Actual hash: ");
		print_md(stderr, md, sizeof(md));
		fputc('\n', stderr);
		fprintf(stderr, "  Offset: %llu\n",
		        (unsigned long long)blk_desc->offset);
		fprintf(stderr, "  Size: %llu\n",
		        (unsigned long long)blk_desc->size);
		*ret = -1;
	}
	free_buf(buf);
	compr_final(&ctx);
	return WALK_CONTINUE;
}

/*
 * walk_snap() callback: insert every block descriptor of snap into
 * the index cache.  arg is unused.  Always continues the walk.
 *
 * No compression context or data buffer is needed here: only the
 * descriptors already held in memory are consulted.
 */
static int
build_icache(struct snap *snap, void *arg)
{
	uint64_t i;

	(void)arg;

	for (i = 0; i < snap->nr_blk_descs; i++) {
		struct blk_desc *blk_desc;

		blk_desc = &snap->blk_desc[i];
		insert_icache(icache, blk_desc);
	}
	return WALK_CONTINUE;
}

/*
 * walk_snap() callback: print the snapshot digest on stdout,
 * followed by a tab and the snapshot message when one is set.
 * arg is unused.  Always continues the walk.
 */
static int
list(struct snap *snap, void *arg)
{
	print_md(stdout, snap->md, sizeof(snap->md));

	if (snap->msg[0] == '\0') {
		putchar('\n');
		return WALK_CONTINUE;
	}

	printf("\t%s\n", snap->msg);
	return WALK_CONTINUE;
}

/* Walk through all snapshots and call fn() on each one */
static void
walk_snap(int (*fn)(struct snap *, void *), void *arg)
{
	uint64_t i;

	xlseek(ifd, SNAP_HDR_SIZE, SEEK_SET);
	for (i = 0; i < snap_hdr.nr_snaps; i++) {
		struct snap *snap;
		int ret;

		snap = alloc_snap();
		read_snap(ifd, snap);
		snap = grow_snap(snap, snap->nr_blk_descs);
		read_snap_descs(ifd, snap);

		ret = (*fn)(snap, arg);
		free(snap);
		if (ret == WALK_STOP)
			break;
	}
}

/*
 * Extract the major and minor format version from the header
 * flags v and exit unless both equal the version this program
 * was built for.
 */
static void
match_ver(uint64_t v)
{
	uint8_t maj, min;

	maj = (v >> VER_MAJ_SHIFT) & VER_MAJ_MASK;
	min = v & VER_MIN_MASK;

	if (maj != VER_MAJ || min != VER_MIN)
		errx(1, "format version mismatch: expected %u.%u but got %u.%u",
		     VER_MAJ, VER_MIN, maj, min);
}

/*
 * Set up the header of a new, empty block store: the flags carry
 * the format version and the selected compression and hash
 * algorithms, and the size covers the header only.
 */
static void
init_blk_hdr(void)
{
	blk_hdr.size = BLK_HDR_SIZE;
	blk_hdr.flags = (VER_MAJ << VER_MAJ_SHIFT) | VER_MIN |
	                (compr_algo << COMPR_ALGO_SHIFT) |
	                (hash_algo << HASH_ALGO_SHIFT);
}

/*
 * Read the block store header from sfd, verify the format version
 * and take the compression and hash algorithms from its flags.
 * Exits when either algorithm is not supported.
 */
static void
load_blk_hdr(void)
{
	uint64_t field;

	xlseek(sfd, 0, SEEK_SET);
	read_blk_hdr(sfd, &blk_hdr);
	match_ver(blk_hdr.flags);

	/* Compression algorithm */
	field = blk_hdr.flags >> COMPR_ALGO_SHIFT;
	field &= COMPR_ALGO_MASK;
	compr_algo = field;
	if (!(compr_algo >= 0 && compr_algo < NR_COMPRS))
		errx(1, "unsupported compression algorithm: %d", compr_algo);
	if (verbose > 0)
		fprintf(stderr, "Compression algorithm: %s\n",
		        compr_type2name(compr_algo));

	/* Hash algorithm */
	field = blk_hdr.flags >> HASH_ALGO_SHIFT;
	field &= HASH_ALGO_MASK;
	hash_algo = field;
	if (!(hash_algo >= 0 && hash_algo < NR_HASHES))
		errx(1, "unsupported hash algorithm: %d", hash_algo);
	if (verbose > 0)
		fprintf(stderr, "Hash algorithm: %s\n",
		        hash_type2name(hash_algo));
}

/* Write the in-memory block store header at the start of sfd */
static void
save_blk_hdr(void)
{
	xlseek(sfd, 0, SEEK_SET);
	write_blk_hdr(sfd, &blk_hdr);
}

/*
 * Set up the header of a new, empty snapshot index.  The minimum
 * block size statistic starts at the largest possible compressed
 * block size so that the first stored block lowers it.
 */
static void
init_snap_hdr(void)
{
	struct compr_ctx ctx;
	int rc;

	rc = compr_init(&ctx, compr_algo);
	if (rc < 0)
		errx(1, "compr_init failed");

	snap_hdr.st.min_blk_size = compr_size(&ctx, BLKSIZE_MAX);
	compr_final(&ctx);

	snap_hdr.size = SNAP_HDR_SIZE;
	snap_hdr.flags = (VER_MAJ << VER_MAJ_SHIFT) | VER_MIN;
}

/*
 * Read the snapshot index header from the start of ifd and
 * verify its format version.
 */
static void
load_snap_hdr(void)
{
	xlseek(ifd, 0, SEEK_SET);
	read_snap_hdr(ifd, &snap_hdr);
	match_ver(snap_hdr.flags);
}

/* Write the in-memory snapshot index header at the start of ifd */
static void
save_snap_hdr(void)
{
	xlseek(ifd, 0, SEEK_SET);
	write_snap_hdr(ifd, &snap_hdr);
}

/*
 * Open and exclusively lock the snapshot index and the block
 * store in the current directory, then load their headers and
 * populate the index cache.  With iflag set both files are
 * created (they must not exist yet) and fresh headers are
 * initialised instead of loaded.
 */
static void
init(int iflag)
{
	int oflags = iflag ? (O_RDWR | O_CREAT | O_EXCL) : O_RDWR;

	if ((ifd = open(SNAPSF, oflags, 0600)) < 0)
		err(1, "open %s", SNAPSF);

	if ((sfd = open(STOREF, oflags, 0600)) < 0)
		err(1, "open %s", STOREF);

	/* Non-blocking: fail right away if another instance holds a lock */
	if (flock(ifd, LOCK_NB | LOCK_EX) < 0)
		err(1, "flock");
	if (flock(sfd, LOCK_NB | LOCK_EX) < 0)
		err(1, "flock");

	if (!iflag) {
		load_snap_hdr();
		load_blk_hdr();
	} else {
		init_snap_hdr();
		init_blk_hdr();
	}

	icache = alloc_icache();
	walk_snap(build_icache, NULL);
}

static void
term(void)
{
	if (verbose > 0)
		print_stats(&snap_hdr.st);

	free_icache(icache);

	save_blk_hdr();
	save_snap_hdr();

	fsync(sfd);
	fsync(ifd);

	close(sfd);
	close(ifd);
}

/* Print the command line synopsis to stderr and exit with status 1 */
static void
usage(void)
{
	fprintf(stderr,
	        "usage: %s [-cilv] [-H hash] [-Z compressor] "
	        "[-e id] [-r root] [-m message] [file]\n",
	        argv0);
	exit(1);
}

/*
 * Entry point.  Depending on the flags the program
 *   -i  initialises a new repository,
 *   -c  checks all snapshots for corruption,
 *   -l  lists all snapshots,
 *   -e  extracts the snapshot with the given id to file or stdout,
 * and otherwise deduplicates file or stdin into a new snapshot.
 * -r selects the repository directory (created if necessary),
 * -m sets the snapshot message, -H/-Z select the hash and
 * compression algorithms and -v increases verbosity.
 */
int
main(int argc, char *argv[])
{
	uint8_t md[MD_SIZE];
	char *id = NULL, *root = NULL, *msg = NULL, *hash_name = NULL;
	char *compr_name = NULL;
	int iflag = 0, lflag = 0, cflag = 0;
	int fd = -1;

	ARGBEGIN {
	case 'H':
		hash_name = EARGF(usage());
		hash_algo = hash_name2type(hash_name);
		if (hash_algo < 0)
			errx(1, "unknown hash: %s", hash_name);
		break;
	case 'Z':
		compr_name = EARGF(usage());
		compr_algo = compr_name2type(compr_name);
		if (compr_algo < 0)
			errx(1, "unknown compressor: %s", compr_name);
		break;
	case 'c':
		cflag = 1;
		break;
	case 'e':
		id = EARGF(usage());
		break;
	case 'i':
		iflag = 1;
		break;
	case 'l':
		lflag = 1;
		break;
	case 'r':
		root = EARGF(usage());
		break;
	case 'm':
		msg = EARGF(usage());
		break;
	case 'v':
		verbose++;
		break;
	default:
		usage();
	} ARGEND

	/* The optional file is the output when extracting, else the input */
	if (argc > 1) {
		usage();
	} else if (argc == 1) {
		if (id) {
			fd = open(argv[0], O_RDWR | O_CREAT, 0600);
			if (fd < 0)
				err(1, "open %s", argv[0]);
		} else {
			fd = open(argv[0], O_RDONLY);
			if (fd < 0)
				err(1, "open %s", argv[0]);
		}
	} else {
		if (id)
			fd = STDOUT_FILENO;
		else
			fd = STDIN_FILENO;
	}

	if (root != NULL) {
		/* Best effort: a real failure is caught by chdir() below */
		mkdir(root, 0700);
		if (chdir(root) < 0)
			err(1, "chdir: %s", root);
	}

	init(iflag);

	if (iflag) {
		term();
		return 0;
	}

	if (cflag) {
		int ret;

		ret = 0;
		walk_snap(check_snap, &ret);
		if (ret != 0)
			errx(1, "%s or %s is corrupted", SNAPSF, STOREF);

		term();
		return 0;
	}

	if (lflag) {
		walk_snap(list, NULL);
		term();
		return 0;
	}

	if (id) {
		struct extract_args args;

		/*
		 * A digest is printed as two hex digits per byte, so a
		 * valid id has exactly twice as many characters as md
		 * has bytes.  Rejecting anything else keeps str2bin()
		 * from leaving md partly uninitialised (short id) and,
		 * presumably, from writing past its end (long id).
		 */
		if (strlen(id) != 2 * sizeof(md))
			errx(1, "invalid snapshot id: %s", id);

		str2bin(id, md);
		args.md = md;
		args.fd = fd;
		args.ret = -1;
		walk_snap(extract, &args);
		if (args.ret != 0)
			errx(1, "unknown snapshot: %s", id);
	} else {
		dedup(fd, msg);
	}

	term();
	return 0;
}
