#include "dat.h"
#include "fns.h"

/*
 * Per-certificate state for the ssh-rsa implementation.
 * Both public-only and private keys keep their key material in
 * an RSApriv; for public-only keys (pub != 0) only priv->pub is
 * filled in (see rsapubdecode).
 */
struct CertState {
	int		pub;	/* nonzero: public key only, no private parts */
	RSApriv		*priv;	/* key material; public keys populate only priv->pub */
};

/*
 * Per-signature state: the raw RSA signature value as a big
 * integer.  Created by mkrsasig and released by rsafreesig.
 */
struct SigState {			/* XXX: botch */
	mpint		*sig;	/* signature value; owned by this state */
};

/* forward declarations for helpers used before their definitions */
static Cert*		rsaalloc(int);
static Sig*		mkrsasig(mpint*);
static void		rsafree(Cert*);

/*
 * Allocate a Cert for the ssh-rsa implementation.  The Cert is a
 * copy of the certsshrsa template with a fresh CertState attached.
 * pub marks the cert as public-only (nonzero) or private.  The
 * caller attaches the actual key to state->priv.
 */
static Cert*
rsaalloc(int pub)
{
	CertState *st;
	Cert *c;

	st = emalloc(sizeof *st);
	st->pub = pub;

	c = emalloc(sizeof *c);
	*c = certsshrsa;		/* start from the implementation template */
	c->state = st;
	return c;
}

/*
 * Wrap a raw RSA signature value in a Sig that uses the ssh-rsa
 * signature implementation.  The Sig takes ownership of sig;
 * rsafreesig releases it later with mpfree.
 */
static Sig*
mkrsasig(mpint *sig)
{
	SigState *st;
	Sig *out;

	st = emalloc(sizeof *st);
	st->sig = sig;

	out = emalloc(sizeof *out);
	*out = sigsshrsa;		/* start from the implementation template */
	out->state = st;
	return out;
}

/*
 * Release a Sig built by mkrsasig, along with its state and the
 * signature value it owns.  Passing nil is a no-op.
 */
static void
rsafreesig(Sig *sig)
{
	if(sig == nil)
		return;
	if(sig->state == nil){
		free(sig);
		return;
	}
	if(sig->state->sig != nil)			/* XXX: botch */
		mpfree(sig->state->sig);
	free(sig->state);
	free(sig);
}

/*
 * Generate a new private ssh-rsa certificate.
 * v is either nil (use defaults) or points to three ints that are
 * passed straight to rsagen: { nlen, elen, nrep }.
 * Defaults are nlen 2048, elen 6 and nrep 0, where 0 means "use
 * rsagen's default" according to the original code.
 */
static Cert*
rsacertgen(void *v)				/* XXX: botch */
{
	Cert *c;
	int *args;
	int bits, ebits, rounds;

	bits = 2048;
	ebits = 6;
	rounds = 0;
	if(v != nil){
		args = v;
		bits = args[0];
		ebits = args[1];
		rounds = args[2];
	}

	c = rsaalloc(0);
	c->state->priv = rsagen(bits, ebits, rounds);
	return c;
}

/*
 * Release an ssh-rsa Cert: its key material, its state and the
 * Cert itself.  Passing nil is a no-op.
 *
 * NOTE(review): dk is passed to mpclear and then detached before
 * rsaprivfree runs, so rsaprivfree never sees it.  If mpclear only
 * wipes the value and does not free it, the mpint leaks.  Confirm
 * mpclear's contract.
 */
static void
rsafree(Cert *p)
{
	if(p == nil)
		return;
	if(p->state == nil){
		free(p);
		return;
	}
	if(p->state->priv != nil){
		/* presumably scrubs the private exponent before release */
		mpclear(p->state->priv->dk);
		p->state->priv->dk = nil;
	}
	rsaprivfree(p->state->priv);
	free(p->state);
	free(p);
}

/*
 * Compare two ssh-rsa certificates.
 * Returns 0 if both are of the same kind (public or private),
 * share the modulus n and public exponent e, and (for private
 * keys) share the private exponent d.  Returns -1 otherwise.
 *
 * -1 is also returned, rather than crashing, when either cert
 * lacks state, key material, or one of the compared values.
 */
static int
rsacertcmp(Cert *a, Cert *b)
{
	RSApriv *ap, *bp;
	CertState *as, *bs;

	if(a == nil || b == nil)
		return -1;
	as = a->state;
	bs = b->state;
	if(as == nil || bs == nil)	/* can't compare implementations */
		return -1;

	if(as->pub != bs->pub)		/* must either both be public or not */
		return -1;
	ap = as->priv;
	bp = bs->priv;
	if(ap == nil || bp == nil)	/* no key attached: nothing to compare */
		return -1;
	if(ap->pub.n == nil || bp->pub.n == nil ||
	   ap->pub.ek == nil || bp->pub.ek == nil)
		return -1;
	if(mpcmp(ap->pub.n, bp->pub.n) != 0 ||
	   mpcmp(ap->pub.ek, bp->pub.ek) != 0)
		return -1;
	/* we assume chinese remainder values are OK */
	if(as->pub == 0){
		if(ap->dk == nil || bp->dk == nil)
			return -1;
		if(mpcmp(ap->dk, bp->dk) != 0)
			return -1;
	}

	return 0;
}

/*
 * Sign buf[0..n) with the private key in p.
 * The SHA-1 based message representative (sha1tomp) is raised to
 * the private exponent via rsadecrypt, giving s = H(m)^d mod n.
 * The result is wrapped in a new Sig, which the caller owns.
 * Panics if p holds only a public key.
 */
Sig*
rsasign(Cert *p, uchar *buf, int n)
{
	mpint *digest, *rsa;
	CertState *s;

	s = p->state;
	/* check before hashing; getcallerpc wants the first argument's address */
	if(s->pub != 0)
		panic("rsasign: can't sign with public key (%lux)",
		      getcallerpc(&p));
	digest = sha1tomp(buf, n);
	rsa = rsadecrypt(s->priv, digest, nil);
	mpfree(digest);
	return mkrsasig(rsa);
}

/*
 * Verify an ssh-rsa signature s over buf[0..n) against the public
 * key in p.
 *
 * rsasign produces sig = H(m)^d mod n (via rsadecrypt).  RSA
 * verification (RSAVP1) therefore raises the signature to the
 * public exponent and compares the result with the message
 * representative: sig^e mod n == H(m).  The previous code
 * encrypted H(m) instead and compared that with sig, which never
 * matches a genuine signature.
 *
 * Returns 1 if the signature verifies, 0 otherwise.  0 is also
 * returned when the cert or signature is missing its key or value,
 * or when the signature value is >= n (out of range per RFC 8017).
 */
int
rsaverify(Cert *p, Sig *s, uchar *buf, int n)
{
	int rv;
	SigState *ss;
	CertState *cs;
	mpint *b, *m;

	if(p == nil || s == nil)
		return 0;
	cs = p->state;
	ss = s->state;
	if(cs == nil || cs->priv == nil || cs->priv->pub.n == nil ||
	   ss == nil || ss->sig == nil)
		return 0;
	/* RSAVP1 step 1: the signature representative must be below n */
	if(mpcmp(ss->sig, cs->priv->pub.n) >= 0)
		return 0;

	b = sha1tomp(buf, n);				/* XXX: datafellows? */
	m = rsaencrypt(&cs->priv->pub, ss->sig, nil);	/* sig^e mod n */
	rv = mpcmp(m, b) == 0;
	mpfree(b);
	mpfree(m);

	return rv;
}

/*
 * Hash the wire encoding of certificate p (as produced by putcert)
 * with fn, writing the digest into buf.
 * len is the capacity of buf.  sz is presumably the digest length
 * fn produces.  Panics if buf is too small for it.
 */
static void
rsadigest(Cert *p, uchar *buf, int len, DigestFn fn, int sz)
{
	Blob *b;

	if(sz > len)
		panic("rsadigest: digest buffer too small (%lux)",
		      getcallerpc(&p));
	b = mkblob(sizecert(p));
	putcert(b, p);
	fn(b->rp, b->wp - b->rp, buf, nil);
	freeblob(b);
}

/*
 * Append the ssh-rsa public key blob for p to b:
 *	string "ssh-rsa", mpint e, mpint n
 * Always returns 0.
 */
static int
rsapubencode(Cert *p, Blob *b)
{
	CertState *cs = p->state;

	putstring(b, "ssh-rsa");		/* key type */
	putmpint(b, cs->priv->pub.ek);		/* public exponent */
	putmpint(b, cs->priv->pub.n);		/* modulus */
	return 0;
}

/*
 * Size in bytes of the blob rsapubencode writes for p: the type
 * string plus the two mpints (e, then n).
 */
static int
rsapubsize(Cert *p)
{
	CertState *cs = p->state;

	return sizestring("ssh-rsa")
		+ sizempint(cs->priv->pub.ek)
		+ sizempint(cs->priv->pub.n);
}

/*
 * Print p in the one-line textual form "ssh-rsa <base64 blob>"
 * to b, where the blob is the putcert wire encoding.
 * Returns Bprint's result.
 */
static int
rsapubprint(Cert *p, Biobuf *b)
{
	Blob *enc;
	int len, w;

	enc = mkblob(sizecert(p));
	putcert(enc, p);
	len = enc->wp - enc->bp;
	w = Bprint(b, "ssh-rsa %.*[", len, enc->bp);
	freeblob(enc);
	return w;
}

/*
 * Print the private key in p to b as
 *	ssh-rsa n e d p q kp kq c2
 * with every value formatted by %B.  rsaprivparse reads this form
 * back, parsing each field in base 16.
 * Panics if p has no private parts.  Returns Bprint's result.
 */
static int
rsaprivprint(Cert *p, Biobuf *b)
{
	CertState *cs;
	RSApriv *k;

	cs = p->state;
	if(cs->pub)
		panic("rsaprivprint: certificate has no private parts (%lux)",
		      getcallerpc(&p));
	k = cs->priv;
	return Bprint(b, "ssh-rsa %B %B %B %B %B %B %B %B",
		      k->pub.n, k->pub.ek, k->dk, k->p, k->q,
		      k->kp, k->kq, k->c2);
}

/*
 * Decode an ssh-rsa public key from b into a new public-only Cert.
 * The leading type string has already been consumed by the caller.
 * What remains is "mpint e, mpint n" (RFC 4253 section 6.6), the same
 * order rsapubencode writes.  The previous code read n first,
 * swapping the modulus and the exponent.
 */
static Cert*
rsapubdecode(Blob *b)
{
	Cert *p;
	RSApriv *rsa;

	p = rsaalloc(1);
	rsa = rsaprivalloc();

	/* string already snarfed by certwiresniff */
	rsa->pub.ek = getmpint(b);	/* e comes first on the wire */
	rsa->pub.n = getmpint(b);
	p->state->priv = rsa;

	return p;
}

/*
 * Parse a textual public key of the form
 *	ssh-rsa <base64 wire blob> [anything else]
 * into a public-only Cert.
 *
 * The type token must be exactly "ssh-rsa".  The previous strncmp
 * accepted any prefix of it, including an empty token.  Only the
 * base64 token is decoded, so a trailing comment is ignored.  This
 * input may be untrusted, so malformed text yields nil with the
 * error string set instead of a panic.  Callers already handle the
 * nil return of the type-token check.
 */
static Cert*
rsapubparse(char *s)
{
	int n, len;
	char *t;
	uchar *buf;
	Blob *b;
	Cert *p;

	s = skipwhite(s);
	n = skiptext(s) - s;
	if(n != strlen("ssh-rsa") || strncmp("ssh-rsa", s, n) != 0){
		werrstr("corrupt ssh-rsa certificate public key");
		return nil;
	}

	/* decode just the base64 token that follows the type */
	s = skipwhite(s + n);
	len = skiptext(s) - s;
	buf = emalloc(len + 1);
	n = dec64(buf, len + 1, s, len);
	if(n <= 0){
		free(buf);
		werrstr("corrupt ssh-rsa certificate public key");
		return nil;
	}
	b = copyblob(buf, n);			/* XXX: bogus copy */
	free(buf);

	/* rsapubdecode doesn't read string type field */
	/* NOTE(review): whether getstring's result must be freed is not visible here */
	t = getstring(b);
	if(t == nil || strcmp("ssh-rsa", t) != 0){
		freeblob(b);
		werrstr("ssh-rsa certificate public key type mismatch");
		return nil;
	}
	p = rsapubdecode(b);
	freeblob(b);

	return p;
}

/*
 * Parse a private key line in the form written by rsaprivprint:
 *	ssh-rsa n e d p q kp kq c2
 * Each value is parsed as hexadecimal.  Exactly nine fields are
 * required.  Panics on malformed input or on any value that fails
 * to parse.  Note that tokenize modifies s in place.
 */
static Cert*
rsaprivparse(char *s)
{
	char *fld[16];
	int i, nf;
	Cert *c;
	RSApriv *k;

	memset(fld, 0, sizeof(fld));
	nf = tokenize(s, fld, nelem(fld));
	if(nf != 9 || fld[0] == nil || strcmp(fld[0], "ssh-rsa") != 0)
		panic("rsaprivparse: invalid input (%lux)", getcallerpc(&s));
	for(i = 1; i < 9; i++)
		if(fld[i] == nil)
			panic("rsaprivparse: invalid input (%lux)", getcallerpc(&s));

	c = rsaalloc(0);
	k = rsaprivalloc();
	k->pub.n = strtomp(fld[1], nil, 16, nil);
	k->pub.ek = strtomp(fld[2], nil, 16, nil);
	k->dk = strtomp(fld[3], nil, 16, nil);
	k->p = strtomp(fld[4], nil, 16, nil);
	k->q = strtomp(fld[5], nil, 16, nil);
	k->kp = strtomp(fld[6], nil, 16, nil);
	k->kq = strtomp(fld[7], nil, 16, nil);
	k->c2 = strtomp(fld[8], nil, 16, nil);
	c->state->priv = k;

	if(k->pub.n == nil || k->pub.ek == nil ||
	   k->dk == nil || k->p == nil || k->q == nil ||
	   k->kp == nil || k->kq == nil || k->c2 == nil)
		panic("rsaprivparse: private key corrupt");
	return c;
}

/*
 * Append the ssh-rsa signature blob for p to b:
 *	string "ssh-rsa", string <signature bytes, big-endian>
 * Panics if p carries no signature value.  Always returns 0.
 */
static int
rsasigencode(Sig *p, Blob *b)
{
	uchar *bytes;
	int len;

	if(p == nil || p->state == nil || p->state->sig == nil)
		panic(Eencode);

	/* dump sign; rfc requires >0.  mptobe allocates bytes for us */
	bytes = nil;
	len = mptobe(p->state->sig, nil, 0, &bytes);
	putstring(b, "ssh-rsa");
	putbstring(b, bytes, len);
	free(bytes);
	return 0;
}

/*
 * Size of the blob rsasigencode writes for p.  The signature part is
 * estimated from the mpint's digit count plus a 4-byte length prefix.
 * It appears to be an upper bound rather than the exact mptobe length.
 */
int
rsasigsize(Sig *p)
{
	mpint *s = p->state->sig;

	/* XXX: botch ("(b)string") */
	return sizestring("ssh-rsa") + (s->top+1)*Dbytes + 4;
}

/*
 * Decode an ssh-rsa signature from b.  Only the signature byte
 * string is read here; no type string is consumed.  The bytes are
 * taken as an unsigned big-endian integer and wrapped in a new Sig.
 */
static Sig*
rsasigdecode(Blob *b)
{
	uchar *raw;
	int len;

	raw = getbstring(b, &len);
	return mkrsasig(betomp(raw, len, nil));	/* positive w/no sign */
}

/*
 * Method table and template for ssh-rsa certificates.  rsaalloc
 * copies it into every new Cert.  This is a positional initializer:
 * the slot order must match the CertImpl declaration, which is not
 * visible here.  Slot descriptions below follow the function names.
 */
CertImpl certsshrsa = {
	"ssh-rsa",
	Sign,			/* XXX: encrypt, decrypt */
	nil,			/* state */

	rsacertgen,		/* generate a new private key */
	rsafree,		/* release a cert */
	rsacertcmp,		/* compare two certs (0 == equal) */
	nil,			/* crypt */
	nil,			/* presumably decrypt; unsupported -- confirm against CertImpl */
	rsasign,		/* sign a buffer */
	rsaverify,		/* verify a signature */
	rsadigest,		/* hash the wire encoding */

	rsapubdecode,		/* wire blob -> cert */
	rsapubencode,		/* cert -> wire blob */

	rsapubparse,		/* text -> public cert */
	rsapubprint,		/* public cert -> text */

	rsaprivparse,		/* text -> private cert */
	rsaprivprint,		/* private cert -> text */

	rsapubsize		/* size of rsapubencode output */
};

/*
 * Method table and template for ssh-rsa signatures.  mkrsasig
 * copies it into every new Sig.  This is a positional initializer:
 * the slot order must match the SigImpl declaration, which is not
 * visible here.
 */
SigImpl sigsshrsa = {
	"ssh-rsa",
	nil,			/* presumably the per-signature state slot */

	rsafreesig,		/* release a signature */
	rsasigencode,		/* signature -> wire blob */

	rsasigdecode,		/* wire blob -> signature */
	rsasigsize		/* size of rsasigencode output (upper bound) */
};
