#include <sys/stat.h>
#include <err.h>
#include <fcntl.h>
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include "arg.h"
#include "sha256.h"
#include "tree.h"

/* On-disk files, opened relative to the current directory (after -r chdir). */
#define INDEXF	".index"	/* struct enthdr followed by struct ent records */
#define STOREF	".store"	/* array of struct blk, addressed by block index */
#define CACHEF	".cache"	/* array of struct cache_data, addressed by block index */

/* Maximum payload bytes per stored block; input is split into chunks of this size. */
#define BLKSIZ	4096

/* Header at offset 0 of INDEXF; packed because it is written to disk verbatim. */
struct enthdr {
	uint64_t flags;	/* only printed by dump_enthdr in the visible code */
	uint64_t nents;	/* number of struct ent records following the header */
} __attribute__((packed));

/*
 * One INDEXF record describing a deduplicated input stream. On disk it is
 * this fixed part immediately followed by nblks block indices.
 */
struct ent {
	uint64_t size;		/* record size in bytes incl. blks[] (set by append_ent) */
	uint8_t reserved[7];	/* unused here; kept to preserve the on-disk layout */
	uint8_t md[32];		/* SHA-256 digest of the whole input stream */
	uint64_t nblks;		/* number of entries in blks[] */
	uint64_t blks[];	/* STOREF block indices, in stream order */
} __attribute__((packed));

/* One STOREF record; the whole struct (including unused data bytes) is written. */
struct blk {
	uint8_t md[32];		/* SHA-256 digest of data[0 .. size) (see hash_blk) */
	uint64_t size;		/* valid bytes in data; dedup never stores more than BLKSIZ */
	uint8_t data[BLKSIZ];
} __attribute__((packed));

/* Persisted CACHEF record mapping a block digest to its STOREF index. */
struct cache_data {
	uint8_t md[32];		/* block digest (the tree key) */
	uint64_t blkidx;	/* index of the block in STOREF */
} __attribute__((packed));

/* In-memory cache node, kept in the RB tree ordered by cache_ent_cmp (digest). */
struct cache_ent {
	struct cache_data data;	/* the part written to CACHEF by flush_cache */
	int dirty;		/* nonzero while data still needs to be written to CACHEF */
	RB_ENTRY(cache_ent) e;	/* tree linkage */
};

/* Process-wide state shared by all operations. */
RB_HEAD(cache, cache_ent) cache_head;	/* digest -> block index lookup tree */
struct enthdr enthdr;			/* in-memory copy of the INDEXF header */
int ifd;				/* INDEXF descriptor */
int sfd;				/* STOREF descriptor */
int cfd;				/* CACHEF descriptor */
int verbose;				/* set by -v */
char *argv0;				/* program name for usage(); presumably set by ARGBEGIN (arg.h) */

/* Print len bytes of md to stderr as lowercase hex, with no trailing newline. */
void
dump_md(const uint8_t *md, size_t len)
{
	size_t k = 0;

	while (k < len) {
		fprintf(stderr, "%02x", md[k]);
		k++;
	}
}

/* Debugging aid: print both index header fields to stderr in hex. */
void
dump_enthdr(struct enthdr *hdr)
{
	unsigned long long flags = hdr->flags;
	unsigned long long nents = hdr->nents;

	fprintf(stderr, "hdr->flags = %llx\n", flags);
	fprintf(stderr, "hdr->nents = %llx\n", nents);
}

/*
 * Debugging aid: print an index entry's size and digest to stderr; with -v
 * (verbose) also print its block count and every block index.
 */
void
dump_ent(struct ent *ent)
{
	uint64_t idx;

	fprintf(stderr, "ent->size: %llu\n", (unsigned long long)ent->size);
	fprintf(stderr, "ent->md: ");
	dump_md(ent->md, sizeof(ent->md));
	fputc('\n', stderr);
	if (!verbose)
		return;

	fprintf(stderr, "ent->nblks: %llu\n", (unsigned long long)ent->nblks);
	for (idx = 0; idx < ent->nblks; idx++) {
		unsigned long long blkno = ent->blks[idx];

		fprintf(stderr, "ent->blks[%llu]: %llu\n",
		        (unsigned long long)idx, blkno);
	}
}

/* Debugging aid: print a block's digest and payload size to stderr. */
void
dump_blk(struct blk *blk)
{
	fprintf(stderr, "blk->md: ");
	dump_md(blk->md, sizeof(blk->md));
	/* Terminate the digest line on stderr, with the rest of the dump. */
	fputc('\n', stderr);
	fprintf(stderr, "blk->size: %llu\n", (unsigned long long)blk->size);
}

/* Value of hex digit c (either case), or -1 if c is not a hex digit. */
static int
hexnibble(char c)
{
	if (c >= '0' && c <= '9')
		return c - '0';
	if (c >= 'a' && c <= 'f')
		return c - 'a' + 10;
	if (c >= 'A' && c <= 'F')
		return c - 'A' + 10;
	return -1;
}

/*
 * Decode the hex string s into bytes at d. Exactly strlen(s) / 2 bytes are
 * written (a trailing odd digit is ignored), so the caller must size d for
 * that. Unlike a sscanf("%2hhx") loop, signs, whitespace and "0x" prefixes
 * are rejected instead of being silently accepted.
 *
 * Returns 0 on success, or -1 at the first non-hex character, in which case
 * d is only written up to the byte before it.
 */
int
str2bin(const char *s, uint8_t *d)
{
	size_t i, len = strlen(s) / 2;

	for (i = 0; i < len; i++) {
		int hi = hexnibble(s[2 * i]);
		int lo = hexnibble(s[2 * i + 1]);

		if (hi < 0 || lo < 0)
			return -1;
		d[i] = (uint8_t)((hi << 4) | lo);
	}
	return 0;
}

/*
 * Read up to nbytes from fd into buf, retrying after short reads. Returns
 * the number of bytes read, which is below nbytes only if EOF was reached.
 * Exits via err(3) on a read error.
 */
ssize_t
xread(int fd, void *buf, size_t nbytes)
{
	uint8_t *dst = buf;
	size_t left = nbytes;

	while (left != 0) {
		ssize_t got = read(fd, dst, left);

		if (got < 0)
			err(1, "read");
		if (got == 0)
			break;
		dst += got;
		left -= (size_t)got;
	}
	return (ssize_t)(nbytes - left);
}

/*
 * Write nbytes from buf to fd, retrying after short writes. Returns the
 * number of bytes written; it is below nbytes only if write(2) returned 0.
 * Exits via err(3) on a write error.
 */
ssize_t
xwrite(int fd, const void *buf, size_t nbytes)
{
	const uint8_t *src = buf;
	size_t left = nbytes;

	while (left != 0) {
		ssize_t put = write(fd, src, left);

		if (put < 0)
			err(1, "write");
		if (put == 0)
			break;
		src += put;
		left -= (size_t)put;
	}
	return (ssize_t)(nbytes - left);
}

/*
 * Tree ordering for cache nodes: compare block digests bytewise and collapse
 * the memcmp(3) result to -1, 0 or 1.
 */
int
cache_ent_cmp(struct cache_ent *e1, struct cache_ent *e2)
{
	int diff = memcmp(e1->data.md, e2->data.md, sizeof(e1->data.md));

	return (diff > 0) - (diff < 0);
}
RB_PROTOTYPE(cache, cache_ent, e, cache_ent_cmp);
RB_GENERATE(cache, cache_ent, e, cache_ent_cmp);

/*
 * Allocate a zeroed (hence clean, dirty == 0) cache node holding digest md
 * and block index blkidx. The node is not linked into the tree. Exits on OOM.
 */
struct cache_ent *
alloc_cache_ent(uint8_t *md, uint64_t blkidx)
{
	struct cache_ent *node = calloc(1, sizeof(*node));

	if (node == NULL)
		err(1, "malloc");
	node->data.blkidx = blkidx;
	memcpy(node->data.md, md, sizeof(node->data.md));
	return node;
}

/*
 * Link ent into the in-memory cache tree, keyed on its digest.
 * NOTE(review): the RB_INSERT result is discarded. With BSD tree.h it is the
 * already-present node on a duplicate digest, in which case ent is NOT linked
 * and stays owned by the caller. Confirm against the bundled tree.h.
 */
void
add_cache_ent(struct cache_ent *ent)
{
	(void)RB_INSERT(cache, &cache_head, ent);
}

/*
 * Write every dirty cache node to CACHEF at slot data.blkidx and mark it
 * clean. Nodes that are already clean are left alone.
 */
void
flush_cache(void)
{
	struct cache_ent *node;

	RB_FOREACH(node, cache, &cache_head) {
		if (node->dirty) {
			off_t off = node->data.blkidx * sizeof(node->data);

			lseek(cfd, off, SEEK_SET);
			xwrite(cfd, &node->data, sizeof(node->data));
			node->dirty = 0;
		}
	}
}

/* Unlink and free every cache node, leaving the tree empty. */
void
free_cache(void)
{
	struct cache_ent *node, *next;

	/* The _SAFE variant is required because node is freed mid-iteration. */
	RB_FOREACH_SAFE(node, cache, &cache_head, next) {
		RB_REMOVE(cache, &cache_head, node);
		free(node);
	}
}

/*
 * Add ent to INDEXF. First the in-memory header count is bumped and the
 * header is rewritten at offset 0. Then ent (fixed part plus nblks block
 * indices) is appended at EOF. On return, ent->size holds the appended byte
 * count.
 *
 * NOTE(review): the header is written before the record, so a crash between
 * the two writes leaves nents one larger than the records on disk.
 */
void
append_ent(struct ent *ent)
{
	uint64_t recsize;

	enthdr.nents++;
	lseek(ifd, 0, SEEK_SET);
	xwrite(ifd, &enthdr, sizeof(enthdr));

	recsize = sizeof(*ent) + ent->nblks * sizeof(ent->blks[0]);
	lseek(ifd, 0, SEEK_END);
	ent->size = recsize;
	xwrite(ifd, ent, recsize);
}

/* Allocate a zeroed index entry with an empty block table; exits on OOM. */
struct ent *
alloc_ent(void)
{
	struct ent *e = calloc(1, sizeof(*e));

	if (!e)
		err(1, "malloc");
	return e;
}

/*
 * Resize ent so that blks[] has room for nblks indices and return the new
 * pointer. The old ent pointer must no longer be used. ent->nblks itself is
 * not modified. Exits if the size would overflow, or on allocation failure.
 */
struct ent *
grow_ent(struct ent *ent, uint64_t nblks)
{
	size_t size;

	/*
	 * nblks may come straight from a (possibly corrupted) index file. If the
	 * byte count wrapped, the allocation would be far smaller than callers
	 * assume when indexing blks[].
	 */
	if (nblks > (SIZE_MAX - sizeof(*ent)) / sizeof(ent->blks[0]))
		errx(1, "%s: too many blocks: %llu", __func__,
		     (unsigned long long)nblks);

	size = sizeof(*ent) + (size_t)nblks * sizeof(ent->blks[0]);
	ent = realloc(ent, size);
	if (ent == NULL)
		err(1, "realloc");
	return ent;
}

/* Number of complete struct blk records currently in STOREF. */
uint64_t
storefile_nblks(void)
{
	struct stat st;

	if (fstat(sfd, &st) == -1)
		err(1, "fstat");
	return (uint64_t)st.st_size / sizeof(struct blk);
}

/* Number of complete struct cache_data records currently in CACHEF. */
uint64_t
cachefile_nblks(void)
{
	struct stat st;

	if (fstat(cfd, &st) == -1)
		err(1, "fstat");
	return (uint64_t)st.st_size / sizeof(struct cache_data);
}

/* Set blk->md to the SHA-256 digest of the blk->size valid payload bytes. */
void
hash_blk(struct blk *blk)
{
	sha256_context hctx;

	sha256_starts(&hctx);
	sha256_update(&hctx, blk->data, blk->size);
	sha256_finish(&hctx, blk->md);
}

void
read_blk(struct blk *blk, off_t blkidx)
{
	lseek(sfd, blkidx * sizeof(*blk), SEEK_SET);
	if (xread(sfd, blk, sizeof(*blk)) == 0)
		errx(1, "unexpected EOF");
}

/* Append the whole struct blk (header and full data array) to STOREF. */
void
append_blk(struct blk *blk)
{
	size_t len = sizeof(*blk);

	lseek(sfd, 0, SEEK_END);
	xwrite(sfd, blk, len);
}

int
lookup_blk(struct blk *blk, uint64_t *blkidx)
{
	struct cache_ent *ent, key;

	memcpy(key.data.md, blk->md, sizeof(key.data.md));
	ent = RB_FIND(cache, &cache_head, &key);
	if (ent != NULL) {
		*blkidx = ent->data.blkidx;
		return 0;
	}
	return -1;
}

/*
 * Write to fd the stream whose SHA-256 digest is the hex string id. The
 * matching entry is looked up in INDEXF and its blocks are copied from
 * STOREF. Exits in any of these cases:
 *   - id is not exactly 64 hex digits;
 *   - no entry matches;
 *   - the index or store is truncated or corrupted.
 */
void
extract(char *id, int fd)
{
	uint8_t md[32];
	uint64_t nblks, i;
	size_t idlen = strlen(id);

	/* Validate before decoding: str2bin writes strlen(id) / 2 bytes and
	   would overflow md for a longer id. */
	if (idlen != 2 * sizeof(md) ||
	    strspn(id, "0123456789abcdefABCDEF") != idlen)
		errx(1, "%s: invalid hash %s", __func__, id);
	str2bin(id, md);

	nblks = storefile_nblks();
	lseek(ifd, sizeof(enthdr), SEEK_SET);
	for (i = 0; i < enthdr.nents; i++) {
		uint64_t j, tblsz;
		struct ent *ent;

		/* Load index entry header */
		ent = alloc_ent();
		if ((size_t)xread(ifd, ent, sizeof(*ent)) != sizeof(*ent))
			errx(1, "unexpected EOF");
		/* Capture the table size now: ent may be freed below. */
		tblsz = ent->nblks * sizeof(ent->blks[0]);

		/* Not the file we want: skip over its block table */
		if (memcmp(ent->md, md, sizeof(ent->md)) != 0) {
			free(ent);
			lseek(ifd, tblsz, SEEK_CUR);
			continue;
		}

		/* Load index entry block table */
		ent = grow_ent(ent, ent->nblks);
		if ((uint64_t)xread(ifd, ent->blks, tblsz) != tblsz)
			errx(1, "unexpected EOF");

		/* Blast file blocks to file descriptor */
		for (j = 0; j < ent->nblks; j++) {
			struct blk blk;

			/* Valid block indices are 0 .. nblks - 1 */
			if (ent->blks[j] >= nblks)
				errx(1, "index is corrupted");
			read_blk(&blk, ent->blks[j]);
			if (blk.size > sizeof(blk.data))
				errx(1, "store is corrupted");
			xwrite(fd, blk.data, blk.size);
		}
		free(ent);
		break;
	}
	if (i == enthdr.nents)
		errx(1, "%s: unknown hash %s", __func__, id);
}

/*
 * Deduplicate the stream read from fd:
 *   - split it into blocks of up to BLKSIZ bytes;
 *   - append to STOREF each block whose digest is not yet cached;
 *   - record the stream in INDEXF under the SHA-256 digest of the whole
 *     stream, then flush new cache nodes to CACHEF.
 * An empty stream adds nothing to the index.
 */
void
dedup(int fd)
{
	sha256_context ctx;
	struct blk blk;
	struct ent *ent;
	ssize_t n;

	ent = alloc_ent();
	sha256_starts(&ctx);
	while ((n = xread(fd, blk.data, BLKSIZ)) > 0) {
		uint64_t blkidx;

		blk.size = n;
		/* append_blk writes the whole struct; zero the unused tail of a
		   short final block rather than storing stale stack bytes. */
		if ((size_t)n < sizeof(blk.data))
			memset(blk.data + n, 0, sizeof(blk.data) - (size_t)n);
		hash_blk(&blk);

		/* Incremental digest of the whole input stream */
		sha256_update(&ctx, blk.data, blk.size);
		/* Prepare for adding a new block index for this entry */
		ent = grow_ent(ent, ent->nblks + 1);

		if (lookup_blk(&blk, &blkidx) < 0) {
			struct cache_ent *cache_ent;

			/* New block: it goes at the current end of the store */
			blkidx = storefile_nblks();

			/* Create a cache entry for this block */
			cache_ent = alloc_cache_ent(blk.md, blkidx);
			add_cache_ent(cache_ent);
			cache_ent->dirty = 1;

			append_blk(&blk);
		}
		ent->blks[ent->nblks++] = blkidx;
	}

	if (ent->nblks > 0) {
		/* Calculate hash and add this entry to the index */
		sha256_finish(&ctx, ent->md);
		append_ent(ent);
	}
	free(ent);

	flush_cache();
}

/*
 * Verify every index entry by recomputing the SHA-256 digest of its blocks
 * (as stored in STOREF) and comparing it with the entry's recorded digest.
 * Exits on the first mismatch, on a block index outside the store, or on a
 * truncated or corrupted index/store.
 */
void
check(void)
{
	uint64_t nblks, i, j;

	nblks = storefile_nblks();
	lseek(ifd, sizeof(enthdr), SEEK_SET);
	for (i = 0; i < enthdr.nents; i++) {
		uint8_t md[32];
		sha256_context ctx;
		struct ent *ent;
		uint64_t tblsz;

		ent = alloc_ent();
		if ((size_t)xread(ifd, ent, sizeof(*ent)) != sizeof(*ent))
			errx(1, "unexpected EOF");
		ent = grow_ent(ent, ent->nblks);
		/* Compare against the expected length: an entry without blocks
		   legitimately reads 0 bytes here. */
		tblsz = ent->nblks * sizeof(ent->blks[0]);
		if ((uint64_t)xread(ifd, ent->blks, tblsz) != tblsz)
			errx(1, "unexpected EOF");

		sha256_starts(&ctx);
		for (j = 0; j < ent->nblks; j++) {
			struct blk blk;

			/* Valid block indices are 0 .. nblks - 1 */
			if (ent->blks[j] >= nblks)
				errx(1, "index is corrupted");
			read_blk(&blk, ent->blks[j]);
			if (blk.size > sizeof(blk.data))
				errx(1, "store is corrupted");
			sha256_update(&ctx, blk.data, blk.size);
		}
		sha256_finish(&ctx, md);

		if (memcmp(ent->md, md, sizeof(ent->md)) != 0)
			errx(1, "hash mismatch");

		free(ent);
	}
}

/*
 * Print the hex digest of every index entry to stdout, one per line. With
 * -v, also print nblks * BLKSIZ, an upper bound on the stream size (dedup
 * fills every block except possibly the last).
 */
void
list(void)
{
	uint64_t i;

	lseek(ifd, sizeof(enthdr), SEEK_SET);
	for (i = 0; i < enthdr.nents; i++) {
		struct ent ent;
		size_t k;

		if ((size_t)xread(ifd, &ent, sizeof(ent)) != sizeof(ent))
			errx(1, "unexpected EOF");

		for (k = 0; k < sizeof(ent.md); k++)
			printf("%02x", ent.md[k]);
		if (verbose)
			printf(" %llu", (unsigned long long)ent.nblks * BLKSIZ);
		putchar('\n');
		/* Skip this entry's block table to reach the next entry */
		lseek(ifd, ent.nblks * sizeof(ent.blks[0]), SEEK_CUR);
	}
}

void
rebuild_cache(void)
{
	uint64_t nblks, i;

	if (verbose)
		fprintf(stderr, "rebuilding cache...");
	nblks = storefile_nblks();
	lseek(cfd, 0, SEEK_SET);
	for (i = 0; i < nblks; i++) {
		struct cache_ent *ent;
		struct blk blk;

		read_blk(&blk, i);
		ent = alloc_cache_ent(blk.md, i);
		add_cache_ent(ent);
		ent->dirty = 1;
	}
	flush_cache();
	if (verbose)
		fprintf(stderr, "done\n");
}

void
init_cache(void)
{
	uint64_t nblks, i;

	if (verbose)
		fprintf(stderr, "initializing cache...");
	nblks = cachefile_nblks();
	lseek(cfd, 0, SEEK_SET);
	for (i = 0; i < nblks; i++) {
		struct blk blk;
		struct cache_ent *ent;

		ent = alloc_cache_ent(blk.md, i);
		if (xread(cfd, &ent->data, sizeof(ent->data)) == 0)
			errx(1, "unexpected EOF");
		add_cache_ent(ent);
	}
	if (verbose)
		fprintf(stderr, "done\n");
}

/*
 * Open INDEXF, STOREF and CACHEF in the current directory, creating any that
 * are missing, and load the index header. Then populate the block cache,
 * rebuilding CACHEF from STOREF when their block counts differ.
 */
void
init(void)
{
	struct stat sb;

	ifd = open(INDEXF, O_RDWR | O_CREAT, 0600);
	if (ifd < 0)
		err(1, "open %s", INDEXF);

	sfd = open(STOREF, O_RDWR | O_CREAT, 0600);
	if (sfd < 0)
		err(1, "open %s", STOREF);

	cfd = open(CACHEF, O_RDWR | O_CREAT, 0600);
	if (cfd < 0)
		err(1, "open %s", CACHEF);

	if (fstat(ifd, &sb) < 0)
		err(1, "fstat %s", INDEXF);
	/* An empty index keeps the zero-initialized global header. A non-empty
	   one must hold at least a complete header. */
	if (sb.st_size != 0 &&
	    (size_t)xread(ifd, &enthdr, sizeof(enthdr)) != sizeof(enthdr))
		errx(1, "%s: unexpected EOF", INDEXF);

	if (cachefile_nblks() != storefile_nblks())
		rebuild_cache();
	else
		init_cache();
}

void
term(void)
{
	free_cache();

	fsync(ifd);
	fsync(sfd);
	fsync(cfd);

	close(ifd);
	close(sfd);
	close(cfd);
}

/* Print the command synopsis to stderr and exit with status 1. */
void
usage(void)
{
	fprintf(stderr,
	        "usage: %s [-clv] [-e id] [-r root] [file]\n",
	        argv0);
	exit(1);
}

/*
 * Entry point. Options:
 *   -c       verify every index entry against the store
 *   -l       list the digests of all index entries
 *   -e id    extract the stream with digest id to file (default: stdout)
 *   -r root  operate inside directory root, creating it if needed
 *   -v       verbose output
 * Without -c, -l or -e, file (default: stdin) is deduplicated into the store.
 * -c and -l take precedence over -e and deduplication.
 */
int
main(int argc, char *argv[])
{
	char *id = NULL, *root = NULL;
	int fd = -1, lflag = 0, cflag = 0;

	ARGBEGIN {
	case 'c':
		cflag = 1;
		break;
	case 'e':
		id = EARGF(usage());
		break;
	case 'l':
		lflag = 1;
		break;
	case 'r':
		root = EARGF(usage());
		break;
	case 'v':
		verbose = 1;
		break;
	default:
		usage();
	} ARGEND

	/* file is opened before chdir(root), so it is relative to the caller's cwd */
	if (argc > 1) {
		usage();
	} else if (argc == 1) {
		if (id) {
			int oflags = O_WRONLY | O_CREAT;

			/* Truncate so an existing, longer file does not keep
			   trailing bytes of its old contents. Skip it when -c or
			   -l win and nothing will be extracted. */
			if (!cflag && !lflag)
				oflags |= O_TRUNC;
			fd = open(argv[0], oflags, 0600);
			if (fd < 0)
				err(1, "open %s", argv[0]);
		} else {
			fd = open(argv[0], O_RDONLY);
			if (fd < 0)
				err(1, "open %s", argv[0]);
		}
	} else {
		if (id)
			fd = STDOUT_FILENO;
		else
			fd = STDIN_FILENO;
	}

	if (root != NULL) {
		/* mkdir failure (e.g. EEXIST) is tolerated; chdir decides. */
		mkdir(root, 0700);
		if (chdir(root) < 0)
			err(1, "chdir: %s", root);
	}

	init();

	if (cflag) {
		check();
		term();
		return 0;
	}

	if (lflag) {
		list();
		term();
		return 0;
	}

	if (id) {
		extract(id, fd);
	} else {
		dedup(fd);
	}

	term();
	return 0;
}
