/* Snapshot archive implementation */
#include <sys/types.h>
#include <sys/stat.h>

#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <unistd.h>

#include <sodium.h>

#include "config.h"
#include "misc.h"
#include "queue.h"
#include "snap.h"
#include "state.h"

/* snapshot encryption algorithms (stored in struct sctx.type, chosen from param.ealgo) */
#define SNONETYPE	0x400
#define SCHACHATYPE	0x401

/* snapshot header constants */
/* 15 characters plus the terminating NUL: exactly NSHDRMAGIC bytes */
#define SHDRMAGIC	"SNAPSNAPPYSNOOP"
#define NSHDRMAGIC	16

/*
 * On-disk format version, stored in shdr.flags as
 * (VMAJ << VMAJSHIFT) | VMIN.  sopen() rejects snapshots whose major
 * version differs; the minor version is not checked.
 */
#define VMIN		0
#define VMAJ		1
#define VMINMASK	0xff
#define VMAJSHIFT	8
#define VMAJMASK	0xff

#define CRYPTOHDRSIZE	crypto_secretstream_xchacha20poly1305_HEADERBYTES
/* Packed header size: magic + crypto header + 64-bit flags + 64-bit nbd */
#define SHDRSIZE	(NSHDRMAGIC + CRYPTOHDRSIZE + 8 + 8)

/* Global runtime parameters; this file reads ealgo, keyloaded and key. */
extern struct param param;

/*
 * misc helpers: (de)serialize values to/from a byte buffer according to a
 * format string; the return value is used as the number of bytes processed.
 */
extern int pack(unsigned char *, char *, ...);
extern int unpack(unsigned char *, char *, ...);

/*
 * Snapshot header structure.
 *
 * Serialized by packshdr()/unpackshdr() into SHDRSIZE bytes at the start
 * of the snapshot file.  For encrypted snapshots the packed header is also
 * passed to libsodium as associated data for every block hash, so altering
 * any header byte makes decryption of the hashes fail.  The crypto header
 * is generated by syncchacha(); a freshly created snapshot holds zeros
 * there until it is synced.
 */
struct shdr {
	char magic[NSHDRMAGIC];			/* magic number for file(1) */
	unsigned char header[CRYPTOHDRSIZE];	/* xchacha20-poly1305 crypto header */
	uint64_t flags;				/* version: major at VMAJSHIFT, minor in the low byte */
	uint64_t nbd;				/* number of block hashes */
};

/* One entry in the in-memory list of block hashes (struct sctx.mdhead). */
struct mdnode {
	unsigned char md[MDSIZE];		/* hash of block */
	TAILQ_ENTRY(mdnode) e;			/* mdhead link node */
};

/*
 * Snapshot context returned by screat()/sopen() and released by sclose().
 * All block hashes are kept in memory; they reach the file only through
 * ssync(), which rewrites the whole snapshot from offset 0.
 */
struct sctx {
	TAILQ_HEAD(mdhead, mdnode) mdhead;	/* list of hashes contained in snapshot */
	/*
	 * sget() cursor: the node returned by the previous sget() call, or
	 * NULL before the first call and after the end of the list is reached.
	 */
	struct mdnode *mdnext;
	int type;				/* encryption algorithm: SNONETYPE or SCHACHATYPE */
	int fd;					/* underlying snapshot file descriptor */
	int rdonly;				/* when set, ssync() is a no-op */
	struct shdr shdr;			/* snapshot header */
};

/*
 * Unpack snapshot header.
 *
 * Deserialize the SHDRSIZE-byte on-disk header in buf into *shdr.  Fields
 * are read in declaration order: magic, crypto header, flags, nbd.
 * Returns the number of bytes consumed; the assertion checks that this
 * equals SHDRSIZE (debug builds only).
 */
static int
unpackshdr(unsigned char *buf, struct shdr *shdr)
{
	char fmt[BUFSIZ];
	int n;

	/*
	 * CRYPTOHDRSIZE comes from libsodium, which defines it as an unsigned
	 * constant; cast both sizes so they match the %d conversions.
	 */
	snprintf(fmt, sizeof(fmt), "'%d'%dqq", (int)NSHDRMAGIC, (int)CRYPTOHDRSIZE);
	n = unpack(buf, fmt,
	           shdr->magic,
	           shdr->header,
	           &shdr->flags,
	           &shdr->nbd);

	assert(n == SHDRSIZE);
	return n;
}

/*
 * Pack snapshot header.
 *
 * Serialize *shdr into buf, which must hold at least SHDRSIZE bytes.
 * Fields are written in declaration order: magic, crypto header, flags,
 * nbd.  Returns the number of bytes produced; the assertion checks that
 * this equals SHDRSIZE (debug builds only).
 */
static int
packshdr(unsigned char *buf, struct shdr *shdr)
{
	char fmt[BUFSIZ];
	int n;

	/*
	 * CRYPTOHDRSIZE comes from libsodium, which defines it as an unsigned
	 * constant; cast both sizes so they match the %d conversions.
	 */
	snprintf(fmt, sizeof(fmt), "'%d'%dqq", (int)NSHDRMAGIC, (int)CRYPTOHDRSIZE);
	n = pack(buf, fmt,
	         shdr->magic,
	         shdr->header,
	         shdr->flags,
	         shdr->nbd);

	assert(n == SHDRSIZE);
	return n;
}

/*
 * Read one plaintext block hash (MDSIZE bytes) from the snapshot file's
 * current offset and append it to sctx->mdhead.
 *
 * `first` is unused; it exists so this loader shares a signature with
 * loadmdchacha() and can be selected through a function pointer.
 * Returns 0 on success, -1 on failure with the error recorded via seterr().
 */
static int
loadmdnone(struct sctx *sctx, int first)
{
	struct mdnode *mdnode;

	(void)first;

	mdnode = calloc(1, sizeof(*mdnode));
	if (mdnode == NULL) {
		seterr("calloc: out of memory");
		return -1;
	}

	if (xread(sctx->fd, mdnode->md, MDSIZE) != MDSIZE) {
		/* Report errno before free() has a chance to change it. */
		seterr("failed to read block hash: %s", strerror(errno));
		/* Not yet on the list, so the caller's cleanup would miss it. */
		free(mdnode);
		return -1;
	}

	TAILQ_INSERT_TAIL(&sctx->mdhead, mdnode, e);
	return 0;
}

/*
 * Read one encrypted block hash from the snapshot file, authenticate and
 * decrypt it, and append the plaintext hash to sctx->mdhead.
 *
 * The secretstream pull state is function-static and shared across calls.
 * When `first` is non-zero it is (re)initialised from the crypto header in
 * sctx->shdr and param.key.  The packed snapshot header is supplied as
 * associated data, matching what syncchacha() used when encrypting.
 * Returns 0 on success, -1 on failure with the error recorded via seterr().
 */
static int
loadmdchacha(struct sctx *sctx, int first)
{
	static crypto_secretstream_xchacha20poly1305_state state;
	unsigned char ct[MDSIZE + crypto_secretstream_xchacha20poly1305_ABYTES];
	unsigned char ad[SHDRSIZE];
	struct shdr *hp = &sctx->shdr;
	struct mdnode *node;
	int rc;

	packshdr(ad, hp);

	if (first) {
		rc = crypto_secretstream_xchacha20poly1305_init_pull(&state,
		                                                     hp->header,
		                                                     param.key);
		if (rc < 0) {
			seterr("invalid crypto header");
			return -1;
		}
	}

	if (xread(sctx->fd, ct, sizeof(ct)) != sizeof(ct)) {
		seterr("failed to read block hash: %s", strerror(errno));
		return -1;
	}

	node = calloc(1, sizeof(*node));
	if (node == NULL) {
		seterr("calloc: out of memory");
		return -1;
	}

	rc = crypto_secretstream_xchacha20poly1305_pull(&state, node->md, NULL,
	                                                NULL, ct, sizeof(ct),
	                                                ad, sizeof(ad));
	if (rc < 0) {
		free(node);
		seterr("authentication failed");
		return -1;
	}

	TAILQ_INSERT_TAIL(&sctx->mdhead, node, e);
	return 0;
}

static int
initmdhead(struct sctx *sctx)
{
	struct shdr *shdr;
	int (*loadmd)(struct sctx *, int);
	uint64_t i;

	if (sctx->type == SNONETYPE)
		loadmd = loadmdnone;
	else
		loadmd = loadmdchacha;

	shdr = &sctx->shdr;
	for (i = 0; i < shdr->nbd; i++) {
		if ((*loadmd)(sctx, i == 0) == 0)
			continue;

		while (!TAILQ_EMPTY(&sctx->mdhead)) {
			struct mdnode *mdnode;

			mdnode = TAILQ_FIRST(&sctx->mdhead);
			TAILQ_REMOVE(&sctx->mdhead, mdnode, e);
			free(mdnode);
		}
		return -1;
	}
	return 0;
}

/*
 * Create a new snapshot file at `path` with permission bits `mode` and
 * write an initial header that records zero block hashes.  The file must
 * not already exist (O_EXCL).  The encryption algorithm is taken from
 * param.ealgo; an encrypted snapshot also requires param.keyloaded.
 * For encrypted snapshots the crypto header written here is all zeros; the
 * real one is produced when the snapshot is synced.
 *
 * On success *sctx points to a new context owned by the caller (release it
 * with sclose()) and 0 is returned.  On failure -1 is returned and the
 * error is recorded with seterr().  If the context was already allocated,
 * it is freed and *sctx is reset to NULL.
 * NOTE(review): a file created before a failed header write is not removed.
 */
int
screat(char *path, int mode, struct sctx **sctx)
{
	unsigned char buf[SHDRSIZE];
	struct shdr *shdr;
	int type;
	int fd;

	if (path == NULL || sctx == NULL) {
		seterr("invalid params");
		return -1;
	}

	/* Determine algorithm type */
	if (strcasecmp(param.ealgo, "none") == 0) {
		type = SNONETYPE;
	} else if (strcasecmp(param.ealgo, "XChaCha20-Poly1305") == 0) {
		type = SCHACHATYPE;
	} else {
		seterr("invalid encryption type: %s", param.ealgo);
		return -1;
	}

	/* Ensure a key has been provided if caller requested encryption */
	if (type != SNONETYPE && !param.keyloaded) {
		seterr("expected encryption key");
		return -1;
	}

	if (sodium_init() < 0) {
		seterr("sodium_init: failed");
		return -1;
	}

	/* O_EXCL: never clobber an existing snapshot */
	fd = open(path, O_RDWR | O_CREAT | O_EXCL, mode);
	if (fd < 0) {
		seterr("open: %s", strerror(errno));
		return -1;
	}

	*sctx = calloc(1, sizeof(**sctx));
	if (*sctx == NULL) {
		close(fd);
		seterr("calloc: out of memory");
		return -1;
	}

	TAILQ_INIT(&(*sctx)->mdhead);
	(*sctx)->mdnext = NULL;
	(*sctx)->type = type;
	(*sctx)->fd = fd;
	(*sctx)->rdonly = 0;	/* new snapshots are written out by ssync() */

	shdr = &(*sctx)->shdr;
	memcpy(shdr->magic, SHDRMAGIC, NSHDRMAGIC);
	shdr->flags = (VMAJ << VMAJSHIFT) | VMIN;
	shdr->nbd = 0;

	packshdr(buf, shdr);
	if (xwrite(fd, buf, SHDRSIZE) != SHDRSIZE) {
		/*
		 * Record the error first: free() and close() below may
		 * overwrite errno.
		 */
		seterr("failed to write snapshot header: %s", strerror(errno));
		free(*sctx);
		*sctx = NULL;
		close(fd);
		return -1;
	}
	return 0;
}

/*
 * Open an existing snapshot at `path` read-only, validate its header and
 * load all of its block hashes into memory.
 *
 * Only flags == S_READ is accepted, because existing snapshots are
 * immutable.  `mode` is passed through to open(2) but has no effect
 * without O_CREAT.  The algorithm and key come from param.ealgo/param.key.
 * NOTE(review): the header does not record the algorithm, so a mismatch
 * between param.ealgo and the file is not detected here directly.
 *
 * On success *sctx points to a new read-only context (release it with
 * sclose()) and 0 is returned.  On failure -1 is returned and the error is
 * recorded with seterr().  If the context was already allocated, it is
 * freed and *sctx is reset to NULL.
 */
int
sopen(char *path, int flags, int mode, struct sctx **sctx)
{
	unsigned char buf[SHDRSIZE];
	struct shdr *shdr;
	int type;
	int fd;

	if (path == NULL || sctx == NULL) {
		seterr("invalid params");
		return -1;
	}

	/* Existing snapshots are immutable */
	if (flags != S_READ) {
		seterr("invalid params");
		return -1;
	}

	/* Determine algorithm type */
	if (strcasecmp(param.ealgo, "none") == 0) {
		type = SNONETYPE;
	} else if (strcasecmp(param.ealgo, "XChaCha20-Poly1305") == 0) {
		type = SCHACHATYPE;
	} else {
		seterr("invalid encryption type: %s", param.ealgo);
		return -1;
	}

	/* Ensure a key has been provided if caller requested encryption */
	if (type != SNONETYPE && !param.keyloaded) {
		seterr("expected encryption key");
		return -1;
	}

	if (sodium_init() < 0) {
		seterr("sodium_init: failed");
		return -1;
	}

	fd = open(path, O_RDONLY, mode);
	if (fd < 0) {
		seterr("open: %s", strerror(errno));
		return -1;
	}

	*sctx = calloc(1, sizeof(**sctx));
	if (*sctx == NULL) {
		close(fd);
		seterr("calloc: out of memory");
		return -1;
	}

	TAILQ_INIT(&(*sctx)->mdhead);
	(*sctx)->mdnext = NULL;
	(*sctx)->type = type;
	(*sctx)->fd = fd;
	(*sctx)->rdonly = 1;

	shdr = &(*sctx)->shdr;

	/*
	 * Each failure below records its error first, so errno is reported
	 * before close() can change it.  It then jumps to the shared cleanup,
	 * which frees the context itself rather than the caller's pointer
	 * to it.
	 */
	if (xread(fd, buf, SHDRSIZE) != SHDRSIZE) {
		seterr("failed to read snapshot header: %s", strerror(errno));
		goto fail;
	}
	unpackshdr(buf, shdr);

	if (memcmp(shdr->magic, SHDRMAGIC, NSHDRMAGIC) != 0) {
		seterr("unknown snapshot header magic");
		goto fail;
	}

	/* If the major version is different, the format is incompatible */
	if (((shdr->flags >> VMAJSHIFT) & VMAJMASK) != VMAJ) {
		seterr("snapshot header version mismatch");
		goto fail;
	}

	/* initmdhead() records its own error and empties the list on failure */
	if (initmdhead(*sctx) < 0)
		goto fail;
	return 0;

fail:
	free(*sctx);
	*sctx = NULL;
	close(fd);
	return -1;
}

/*
 * Append a copy of the MDSIZE-byte block hash `md` to the snapshot's
 * in-memory hash list and bump the header's hash count.  Nothing is written
 * to disk here; the list is persisted by ssync().
 * Returns 0 on success, -1 on failure with the error recorded via seterr().
 */
int
sput(struct sctx *sctx, unsigned char *md)
{
	struct mdnode *node;

	if (sctx == NULL || md == NULL) {
		seterr("invalid params");
		return -1;
	}

	node = calloc(1, sizeof(*node));
	if (node == NULL) {
		seterr("calloc: out of memory");
		return -1;
	}

	memcpy(node->md, md, MDSIZE);
	TAILQ_INSERT_TAIL(&sctx->mdhead, node, e);
	sctx->shdr.nbd++;
	return 0;
}

/*
 * Copy the next block hash of the snapshot into `md` (MDSIZE bytes).
 *
 * Returns MDSIZE when a hash was copied and 0 once the end of the list is
 * reached.  Reaching the end resets the cursor, so a further call starts
 * again from the first hash.  Returns -1 (with seterr()) on NULL arguments.
 */
int
sget(struct sctx *sctx, unsigned char *md)
{
	struct mdnode *cur;

	if (sctx == NULL || md == NULL) {
		seterr("invalid params");
		return -1;
	}

	cur = (sctx->mdnext != NULL) ? TAILQ_NEXT(sctx->mdnext, e)
	                             : TAILQ_FIRST(&sctx->mdhead);
	sctx->mdnext = cur;

	if (cur == NULL)
		return 0;

	memcpy(md, cur->md, MDSIZE);
	return MDSIZE;
}

static int
syncnone(struct sctx *sctx)
{
	unsigned char hdr[SHDRSIZE];
	struct mdnode *mdnode;
	struct shdr *shdr;

	shdr = &sctx->shdr;
	packshdr(hdr, shdr);
	if (xwrite(sctx->fd, hdr, SHDRSIZE) != SHDRSIZE) {
		seterr("failed to write snapshot header: %s", strerror(errno));
		return -1;
	}

	TAILQ_FOREACH(mdnode, &sctx->mdhead, e) {
		if (xwrite(sctx->fd, mdnode->md, MDSIZE) != MDSIZE) {
			seterr("failed to write block hash: %s",
			        strerror(errno));
			return -1;
		}
	}
	return 0;
}

static int
syncchacha(struct sctx *sctx)
{
	unsigned char hdr[SHDRSIZE];
	crypto_secretstream_xchacha20poly1305_state state;
	struct mdnode *mdnode;
	struct shdr *shdr;

	shdr = &sctx->shdr;
	crypto_secretstream_xchacha20poly1305_init_push(&state,
	                                                shdr->header,
	                                                param.key);

	packshdr(hdr, shdr);
	if (xwrite(sctx->fd, hdr, SHDRSIZE) != SHDRSIZE) {
		seterr("failed to write snapshot header: %s", strerror(errno));
		return -1;
	}

	TAILQ_FOREACH(mdnode, &sctx->mdhead, e) {
		unsigned char buf[MDSIZE + crypto_secretstream_xchacha20poly1305_ABYTES];
		unsigned char tag;

		if (TAILQ_LAST(&sctx->mdhead, mdhead) == mdnode)
			tag = crypto_secretstream_xchacha20poly1305_TAG_FINAL;
		else
			tag = 0;

		crypto_secretstream_xchacha20poly1305_push(&state,
		                                           buf, NULL,
		                                           mdnode->md, MDSIZE,
		                                           hdr, SHDRSIZE, tag);
		if (xwrite(sctx->fd, buf, sizeof(buf)) != sizeof(buf)) {
			seterr("failed to write block hash: %s",
			        strerror(errno));
			return -1;
		}
	}
	return 0;
}

/*
 * Rewrite the whole snapshot (header followed by every block hash) starting
 * at offset 0, then flush it to stable storage with fsync(2).  For contexts
 * opened read-only via sopen() this is a no-op that returns 0.
 *
 * Returns 0 on success, -1 on failure with the error recorded via seterr().
 * Write and fsync failures are propagated, so callers such as sclose() do
 * not mistake an unsaved snapshot for a saved one.
 */
int
ssync(struct sctx *sctx)
{
	int r;

	if (sctx == NULL) {
		seterr("invalid params");
		return -1;
	}

	if (sctx->rdonly)
		return 0;

	if (lseek(sctx->fd, 0, SEEK_SET) < 0) {
		seterr("lseek: %s", strerror(errno));
		return -1;
	}

	if (sctx->type == SNONETYPE)
		r = syncnone(sctx);
	else
		r = syncchacha(sctx);
	/* The sync helpers have already called seterr() on failure. */
	if (r < 0)
		return -1;

	if (fsync(sctx->fd) < 0) {
		seterr("fsync: %s", strerror(errno));
		return -1;
	}
	return 0;
}

/*
 * Sync (via ssync()) and release a snapshot context.
 *
 * If sctx is NULL or ssync() fails, -1 is returned and nothing is
 * released.  In the ssync() case the context stays open and is still owned
 * by the caller.  Otherwise the hash list and the context are freed and
 * the descriptor is closed.  The return value is then that of close(2); on
 * close failure its error is recorded via seterr().
 */
int
sclose(struct sctx *sctx)
{
	struct mdnode *mdnode;
	int saved_errno;
	int r;

	if (sctx == NULL) {
		seterr("invalid params");
		return -1;
	}

	if (ssync(sctx) < 0)
		return -1;

	/* Free block hash list */
	while ((mdnode = TAILQ_FIRST(&sctx->mdhead)) != NULL) {
		TAILQ_REMOVE(&sctx->mdhead, mdnode, e);
		free(mdnode);
	}

	r = close(sctx->fd);
	saved_errno = errno;	/* free() below must not clobber close()'s errno */
	free(sctx);
	if (r < 0)
		seterr("close: %s", strerror(saved_errno));
	return r;
}
