#include "dat.h"
#include "fns.h"

/*
 * Per-certificate state for ssh-dss keys.
 * pub is nonzero when the certificate carries only a public key; in that
 * case only priv->pub is filled in (see dsspubdecode) and the secret is
 * not set.  priv holds the DSA key material.
 */
struct CertState {
	int		pub;
	DSApriv		*priv;
};

/*
 * Per-signature state: wraps the DSA signature (r, s).
 * NOTE(review): the "botch" presumably refers to needing a wrapper struct
 * just to hang a DSAsig off the generic Sig; confirm with the Sig declaration.
 */
struct SigState {			/* XXX: botch */
	DSAsig		*sig;
};

/* forward declarations for helpers used before their definitions */
static Cert*		dssalloc(int);
static Sig*		mkdsssig(DSAsig*);
static void		dssfreesig(Sig*);

/*
 * Create a Cert from the certsshdss template with its own CertState.
 * pub marks the certificate as public-only (nonzero) or private (zero).
 * The caller must install the key material in state->priv.
 */
static Cert*
dssalloc(int pub)
{
	CertState *st;
	Cert *c;

	st = emalloc(sizeof *st);
	st->pub = pub;

	c = emalloc(sizeof *c);
	*c = certsshdss;
	c->state = st;
	return c;
}

/*
 * Wrap a DSA signature in a Sig built from the sigsshdss template.
 * The returned Sig takes ownership of dsig; dssfreesig releases it.
 */
static Sig*
mkdsssig(DSAsig *dsig)
{
	SigState *st;
	Sig *out;

	st = emalloc(sizeof *st);
	st->sig = dsig;

	out = emalloc(sizeof *out);
	*out = sigsshdss;
	out->state = st;
	return out;
}

/*
 * Release a Sig made by mkdsssig, together with its state and the
 * wrapped DSA signature.  nil is accepted and ignored.
 */
static void
dssfreesig(Sig *sig)
{
	SigState *st;

	if(sig != nil){
		st = sig->state;		/* XXX: botch */
		if(st != nil){
			if(st->sig != nil)
				dsasigfree(st->sig);
			free(st);
		}
		free(sig);
	}
}

static Cert*
dsscertgen(void *v)
{
	Cert *p;

	USED(v);
	p = dssalloc(0);
	p->state->priv = dsagen(nil);

	return p;
}

/*
 * Release a DSS certificate, its state and its key material.
 * Any secret is cleared with mpclear before the key is freed.
 * Public-only certificates (see dsspubdecode) never set a secret, so the
 * secret is presumably nil for them and is skipped.  nil is accepted.
 */
static void
dssfree(Cert *p)
{
	CertState *s;
	DSApriv *dsa;

	if(p == nil)
		return;
	s = p->state;
	if(s != nil){
		dsa = s->priv;
		if(dsa != nil && dsa->secret != nil){
			mpclear(dsa->secret);
			dsa->secret = nil;
		}
		dsaprivfree(dsa);
		free(s);
	}
	free(p);
}

/*
 * Compare two DSS certificates.
 *
 * Returns 0 when all of the following hold:
 *   - both are the same kind (public-only or private);
 *   - their public parameters (p, q, alpha, key) match;
 *   - for private certificates, their secrets also match.
 * Otherwise returns -1.  This includes the cases where either
 * certificate, its state, or its key material is missing.
 */
static int
dsscertcmp(Cert *a, Cert *b)
{
	DSApriv *ap, *bp;
	CertState *as, *bs;

	if(a == nil || b == nil)
		return -1;
	as = a->state;
	bs = b->state;
	if(as == nil || bs == nil)	/* can't compare implementations */
		return -1;

	if(as->pub != bs->pub)		/* must either both be public or not */
		return -1;
	ap = as->priv;
	bp = bs->priv;
	if(ap == nil || bp == nil)	/* no key material to compare */
		return -1;
	if(mpcmp(ap->pub.p, bp->pub.p) ||
	   mpcmp(ap->pub.q, bp->pub.q) ||
	   mpcmp(ap->pub.alpha, bp->pub.alpha) ||
	   mpcmp(ap->pub.key, bp->pub.key))
		return -1;
	if(as->pub == 0 && mpcmp(ap->secret, bp->secret) != 0)
		return -1;

	return 0;
}

/*
 * Sign buf[0..n) with the certificate's DSA private key.
 * The data is first converted to an mpint digest by sha1tomp
 * (presumably a SHA-1 hash); that digest is what gets signed.
 * Panics if p holds only a public key.  The returned Sig owns the DSA
 * signature and is released by dssfreesig.
 */
Sig*
dsssign(Cert *p, uchar *buf, int n)
{
	mpint *digest;
	DSAsig *dsa;
	CertState *s;

	s = p->state;
	if(s->pub != 0){
		/* getcallerpc needs the address of the first parameter */
		panic("dsssign: can't sign with public key (%lux)",
		      getcallerpc(&p));
	}
	digest = sha1tomp(buf, n);
	dsa = dsasign(s->priv, digest);
	mpfree(digest);
	return mkdsssig(dsa);
}

/*
 * Check sig against buf[0..n) using the certificate's DSA public key.
 * The data is converted to an mpint digest by sha1tomp before
 * verification.  Returns dsaverify's result unchanged.
 * NOTE(review): in libsec, 0 presumably means the signature is valid;
 * confirm the convention against callers.
 */
int
dssverify(Cert *p, Sig *sig, uchar *buf, int n)
{
	DSApub *pk;
	mpint *digest;
	int ok;

	pk = &p->state->priv->pub;
	digest = sha1tomp(buf, n);
	ok = dsaverify(pk, sig->state->sig, digest);
	mpfree(digest);
	return ok;
}

/*
 * Hash the certificate's wire encoding (putcert) with fn into buf.
 * len is the size of buf and sz the digest size; panics when the digest
 * would not fit.
 */
static void
dssdigest(Cert *p, uchar *buf, int len, DigestFn fn, int sz)
{
	Blob *wire;

	if(len < sz)
		panic("dssdigest: digest buffer too small (%lux)",
		      getcallerpc(&p));
	wire = mkblob(sizecert(p));
	putcert(wire, p);
	(*fn)(wire->rp, wire->wp - wire->rp, buf, nil);
	freeblob(wire);
}

/*
 * Append the SSH wire form of the public key to b.
 * The layout is the type string "ssh-dss" followed by the mpints
 * p, q, alpha (g) and key (y), in that order.  Always returns 0.
 */
static int
dsspubencode(Cert *p, Blob *b)
{
	DSApub *pk;

	pk = &p->state->priv->pub;
	putstring(b, "ssh-dss");
	putmpint(b, pk->p);
	putmpint(b, pk->q);
	putmpint(b, pk->alpha);
	putmpint(b, pk->key);
	return 0;
}

/*
 * Size in bytes of the encoding produced by dsspubencode.
 */
static int
dsspubsize(Cert *p)
{
	DSApub *pk;

	pk = &p->state->priv->pub;
	return sizestring("ssh-dss")
		+ sizempint(pk->p)
		+ sizempint(pk->q)
		+ sizempint(pk->alpha)
		+ sizempint(pk->key);
}

/*
 * Print the public key as "ssh-dss " followed by the base64 (%[) form of
 * its wire encoding.  Returns Bprint's result.
 */
static int
dsspubprint(Cert *p, Biobuf *b)
{
	Blob *wire;
	int rv;

	wire = mkblob(sizecert(p));
	putcert(wire, p);
	rv = Bprint(b, "ssh-dss %.*[", (int)(wire->wp - wire->bp), wire->bp);
	freeblob(wire);
	return rv;
}

/*
 * Print a private key as "ssh-dss p q alpha key secret", using %B for
 * each mpint.  This is the text form read back by dssprivparse.
 * Panics for public-only certificates.  Returns Bprint's result.
 */
static int
dssprivprint(Cert *p, Biobuf *b)
{
	DSApriv *k;

	if(p->state->pub != 0)
		panic("dssprivprint: certificate has no private parts (%lux)",
		      getcallerpc(&p));
	k = p->state->priv;
	return Bprint(b, "ssh-dss %B %B %B %B %B",
		k->pub.p, k->pub.q, k->pub.alpha, k->pub.key, k->secret);
}

/*
 * Build a public-only certificate from a wire blob positioned just after
 * the "ssh-dss" type string.  The mpints p, q, alpha and key are read in
 * that order.  The secret is left unset.
 */
static Cert*
dsspubdecode(Blob *b)
{
	Cert *c;
	DSApriv *k;

	c = dssalloc(1);
	k = dsaprivalloc();

	/* the type string was already consumed by certwiresniff */
	k->pub.p = getmpint(b);
	k->pub.q = getmpint(b);
	k->pub.alpha = getmpint(b);
	k->pub.key = getmpint(b);

	c->state->priv = k;
	return c;
}

/*
 * Parse a textual public key of the form "ssh-dss <base64>".
 * The base64 payload is the SSH wire encoding: the "ssh-dss" type string
 * followed by the fields read by dsspubdecode.
 *
 * Returns the new certificate.  If the key type is wrong or the payload
 * is corrupt, returns nil and sets the error string.
 */
static Cert*
dsspubparse(char *s)
{
	int n;
	char *t;
	uchar *buf;
	Blob *b;
	Cert *p;

	s = skipwhite(s);
	n = skiptext(s) - s;
	/* the whole token must match: a prefix such as "ssh" is not enough */
	if(n != strlen("ssh-dss") || strncmp("ssh-dss", s, n) != 0)
		goto Corrupt;

	s = skiptext(s);
	s = skipwhite(s);
	n = strlen(s);
	buf = emalloc(n+1);		/* base64 never decodes to more bytes than it has */
	n = dec64(buf, n+1, s, n);
	if(n <= 0){
		free(buf);
		goto Corrupt;
	}
	b = copyblob(buf, n);			/* XXX: bogus copy */
	free(buf);

	/* the payload repeats the type string; dsspubdecode expects it consumed */
	t = getstring(b);
	if(t == nil || strcmp("ssh-dss", t) != 0){
		freeblob(b);
		goto Corrupt;
	}
	p = dsspubdecode(b);
	freeblob(b);
	return p;

 Corrupt:
	werrstr("corrupt ssh-dss certificate public key");
	return nil;
}

/*
 * Parse a private key in the text form written by dssprivprint:
 * "ssh-dss p q alpha key secret", with each number in base 16.
 * Panics on malformed input or if any number fails to parse.
 * Note that tokenize modifies s in place.
 */
static Cert*
dssprivparse(char *s)
{
	int i;
	char *f[16];
	Cert *c;
	DSApriv *k;

	memset(f, 0, sizeof f);
	if(tokenize(s, f, nelem(f)) != 6)
		goto Invalid;
	for(i = 0; i < 6; i++)
		if(f[i] == nil)
			goto Invalid;
	if(strcmp(f[0], "ssh-dss") != 0)
		goto Invalid;

	c = dssalloc(0);
	k = dsaprivalloc();
	k->pub.p = strtomp(f[1], nil, 16, nil);
	k->pub.q = strtomp(f[2], nil, 16, nil);
	k->pub.alpha = strtomp(f[3], nil, 16, nil);
	k->pub.key = strtomp(f[4], nil, 16, nil);
	k->secret = strtomp(f[5], nil, 16, nil);
	c->state->priv = k;

	if(k->pub.p == nil || k->pub.q == nil || k->pub.alpha == nil ||
	   k->pub.key == nil || k->secret == nil)
		panic("dssprivparse: private key corrupt");
	return c;

 Invalid:
	panic("dssprivparse: invalid input (%lux)", getcallerpc(&s));
	return nil;	/* not reached; keeps the compiler quiet */
}

/*
 * Size of the signature encoding produced by dsssigencode.
 * The layout is the type string, a 4-byte length, then the 20-byte
 * r and 20-byte s fields.
 */
static int
dsssigsize(Sig *p)
{
	USED(p);
	return sizestring("ssh-dss") + 4 + 2*20;
}

static int
dsssigencode(Sig *p, Blob *b)
{
	DSAsig *dsa;
	uchar r[20], s[20];

	if(p == nil || p->state == nil || p->state->sig == nil)
		panic(Eencode);
	dsa = p->state->sig;
	if(mptobe(dsa->r, r, sizeof(r), nil) < 0)
		panic(Eencode);
	if(mptobe(dsa->s, s, sizeof(s), nil) < 0)
		panic(Eencode);
	putstring(b, "ssh-dss");
	putlong(b, sizeof(r)+sizeof(s));
	putbytes(b, r, sizeof(r));
	putbytes(b, s, sizeof(s));

	return 0;
}

/*
 * Decode a DSS signature body: a 32-bit length that must be 40, followed
 * by 20-byte big-endian r and s.
 * NOTE(review): the "ssh-dss" type string is presumably consumed by the
 * caller before this is called; confirm.
 * Panics with Edecode on malformed input.
 */
static Sig*
dsssigdecode(Blob *b)
{
	DSAsig *dsa;

	if(getlong(b) != 2*20)
		panic(Edecode);
	dsa = dsasigalloc();
	dsa->r = betomp(getbytes(b, 20), 20, nil);
	dsa->s = betomp(getbytes(b, 20), 20, nil);
	if(dsa->r == nil || dsa->s == nil)
		panic(Edecode);
	return mkdsssig(dsa);
}

/*
 * Method table for "ssh-dss" certificates; dssalloc copies it into every
 * new Cert.  Initialization is positional, so the entry order must match
 * the CertImpl declaration, which is not in this file.
 */
CertImpl certsshdss = {
	"ssh-dss",
	Sign,			/* NOTE(review): capability flag, presumably sign-only since crypt entries are nil */
	nil,			/* state */
	dsscertgen,
	dssfree,
	dsscertcmp,
	nil,			/* crypt */
	nil,			/* NOTE(review): presumably decrypt; unsupported for DSS */
	dsssign,
	dssverify,
	dssdigest,

	dsspubdecode,
	dsspubencode,

	dsspubparse,
	dsspubprint,

	dssprivparse,
	dssprivprint,

	dsspubsize,
};

/*
 * Method table for "ssh-dss" signatures; mkdsssig copies it into every
 * new Sig.  Entries are positional and must match the SigImpl declaration.
 */
SigImpl sigsshdss = {
	"ssh-dss",
	nil,			/* SigState */

	dssfreesig,
	dsssigencode,
	dsssigdecode,
	dsssigsize
};
