/*  *********************************************************************
    File: hdskcert.c

    SSLRef 3.0 Final -- 11/19/96

    Copyright (c)1996 by Netscape Communications Corp.

    By retrieving this software you are bound by the licensing terms
    disclosed in the file "LICENSE.txt". Please read it, and if you don't
    accept the terms, delete this software.

    SSLRef 3.0 was developed by Netscape Communications Corp. of Mountain
    View, California <http://home.netscape.com/> and Consensus Development
    Corporation of Berkeley, California <http://www.consensus.com/>.

    *********************************************************************

    File: hdskcert.c   Contains support for certificate-related messages

    Support for encoding and decoding the certificate, certificate
    request, and certificate verify messages.

    ****************************************************************** */

#ifndef _SSLCTX_H_
#include "sslctx.h"
#endif

#ifndef _SSLHDSHK_H_
#include "sslhdshk.h"
#endif

#ifndef _SSLALLOC_H_
#include "sslalloc.h"
#endif

#ifndef _SSLALERT_H_
#include "sslalert.h"
#endif

#ifndef _X509_H_
#include "x509.h"
#endif

#include <string.h>

/* Build a Certificate handshake message from the chain in ctx->localCert.
 *
 * Layout written into certificate->contents:
 *   1 byte   handshake type (SSL_certificate)
 *   3 bytes  handshake body length (certificate vector length + 3)
 *   3 bytes  certificate vector length
 *   then, per certificate: a 3-byte length followed by its DER bytes.
 *
 * ctx->localCert holds the root certificate first, but the message must
 * carry it last, so certificates are emitted from the tail of the linked
 * list back toward the head. On the client side ctx->certSent is set once
 * the message has been built.
 */
SSLErr
SSLEncodeCertificate(SSLRecord *certificate, SSLContext *ctx)
{   SSLErr          err;
    uint32          chainBytes;
    int             numCerts, position, hop;
    uint8           *progress;
    SSLCertificate  *cert;
    
    /* Match DER-encoded root certs here */

    cert = ctx->localCert;
    ASSERT(cert != 0);
    
    /* First pass: count the certificates and size the vector;
       each entry carries a 3-byte length prefix. */
    for (chainBytes = 0, numCerts = 0; cert != 0; cert = cert->next, ++numCerts)
        chainBytes += 3 + cert->derCert.length;
    
    certificate->contentType = SSL_handshake;
    certificate->protocolVersion = SSL_Version_3_0;
    err = SSLAllocBuffer(&certificate->contents, chainBytes + 7, &ctx->sysCtx);
    if (err != 0)
        return err;
    
    progress = certificate->contents.data;
    *progress++ = SSL_certificate;
    progress = SSLEncodeInt(progress, chainBytes + 3, 3);   /* handshake body length */
    progress = SSLEncodeInt(progress, chainBytes, 3);       /* certificate vector length */
    
    /* Second pass: `position` is the 1-based index (from the head) of the
       node to emit next; start at the tail and count down to the head. */
    for (position = numCerts; position > 0; --position)
    {   cert = ctx->localCert;
        for (hop = 1; hop < position; ++hop)
            cert = cert->next;
        progress = SSLEncodeInt(progress, cert->derCert.length, 3);
        memcpy(progress, cert->derCert.data, cert->derCert.length);
        progress += cert->derCert.length;
    }
    
    ASSERT(progress == certificate->contents.data + certificate->contents.length);
    
    if (ctx->protocolSide == SSL_ClientSide)
        ctx->certSent = 1;

    return SSLNoErr;
}

/* Parse a received Certificate handshake message body and install the chain.
 *
 * message holds the body: a 3-byte vector length followed by entries, each
 * a 3-byte length and that many DER bytes. Every certificate is copied,
 * parsed with ASNParseX509Certificate and prepended to ctx->peerCert, so the
 * list ends up in reverse wire order. The chain is then checked with
 * X509VerifyCertChain, and the public key of the last list element is
 * converted into ctx->peerKey. If the application registered
 * ctx->certCtx.checkCertFunc, it is called with the DER blobs in list order.
 *
 * Returns SSLProtocolErr on malformed lengths, X509CertChainInvalidErr when
 * no certificate is present, or the error from any helper that fails.
 * Certificates parsed before an error stay linked into ctx->peerCert.
 */
SSLErr
SSLProcessCertificate(SSLBuffer message, SSLContext *ctx)
{   SSLErr          err;
    uint32          listLen, certLen;
    SSLBuffer       buf;
    uint8           *p;
    SSLCertificate  *cert;
    
    /* The outer vector length must itself be present before it is read. */
    if (message.length < 3)
        return SSLProtocolErr;
    
    p = message.data;
    listLen = SSLDecodeInt(p,3);
    p += 3;
    if (listLen + 3 != message.length)
        return SSLProtocolErr;
    
    while (listLen > 0)
    {   /* Each entry starts with a 3-byte length; reject a truncated prefix
           before decoding it so we never read past the message. */
        if (listLen < 3)
            return SSLProtocolErr;
        certLen = SSLDecodeInt(p,3);
        p += 3;
        if (listLen < certLen + 3)
            return SSLProtocolErr;
        if ((err = SSLAllocBuffer(&buf, sizeof(SSLCertificate), &ctx->sysCtx)) != 0)
            return err;
        cert = (SSLCertificate*)buf.data;
        if ((err = SSLAllocBuffer(&cert->derCert, certLen, &ctx->sysCtx)) != 0)
        {   SSLFreeBuffer(&buf, &ctx->sysCtx);
            return err;
        }
        memcpy(cert->derCert.data, p, certLen);
        p += certLen;
        cert->next = ctx->peerCert;     /* Insert backwards; root cert will be first in linked list */
        ctx->peerCert = cert;
        if ((err = ASNParseX509Certificate(cert->derCert, &cert->cert, ctx)) != 0)
            return err;
        listLen -= 3+certLen;
    }
    ASSERT(p == message.data + message.length && listLen == 0);
    
    if (ctx->peerCert == 0)
        return X509CertChainInvalidErr;
    
    if ((err = X509VerifyCertChain(ctx->peerCert, ctx)) != 0)
        return err;

/* Server's certificate is the last one in the chain */
    cert = ctx->peerCert;
    while (cert->next != 0)
        cert = cert->next;
/* Convert its public key to RSAREF format */
    if ((err = X509ExtractPublicKey(&cert->cert.pubKey, &ctx->peerKey)) != 0)
        return err;
    
    if (ctx->certCtx.checkCertFunc != 0)
    {   SSLBuffer       certList, *certs;
        int             i,certCount;
        SSLCertificate  *c;
        
        if ((err = SSLGetPeerCertificateChainLength(ctx, &certCount)) != 0)
            return err;
        if ((err = SSLAllocBuffer(&certList, certCount * sizeof(SSLBuffer), &ctx->sysCtx)) != 0)
            return err;
        certs = (SSLBuffer *)certList.data;
        c = ctx->peerCert;
        for (i = 0; i < certCount; i++, c = c->next)
            certs[i] = c->derCert;
        
        if ((err = ctx->certCtx.checkCertFunc(certCount, certs, ctx->certCtx.checkCertRef)) != 0)
        {   SSLFreeBuffer(&certList, &ctx->sysCtx);
            return err;
        }
        SSLFreeBuffer(&certList, &ctx->sysCtx);
    }
    
    return SSLNoErr;
}

/* Build a CertificateRequest handshake message.
 *
 * Layout written into request->contents:
 *   1 byte   handshake type (SSL_certificate_request)
 *   3 bytes  body length
 *   1 byte   number of certificate types (always 1 here)
 *   1 byte   certificate type 1 (RSA-sign)
 *   2 bytes  length of the distinguished-name vector
 *   then, for each entry of ctx->acceptableDNList in list order:
 *            a 2-byte length followed by the DER-encoded name.
 */
SSLErr
SSLEncodeCertificateRequest(SSLRecord *request, SSLContext *ctx)
{   SSLErr      err;
    uint32      namesLen, bodyLen;
    uint8       *progress;
    DNListElem  *dn;
    
    dn = ctx->acceptableDNList;
    ASSERTPTR(dn);
    for (namesLen = 0; dn != 0; dn = dn->next)
        namesLen += 2 + dn->derDN.length;
    
    /* type count + one type byte + DN vector length + the DNs themselves */
    bodyLen = 1 + 1 + 2 + namesLen;
    
    request->contentType = SSL_handshake;
    request->protocolVersion = SSL_Version_3_0;
    err = SSLAllocBuffer(&request->contents, bodyLen + 4, &ctx->sysCtx);
    if (err != 0)
        return err;
    
    progress = request->contents.data;
    *progress++ = SSL_certificate_request;
    progress = SSLEncodeInt(progress, bodyLen, 3);
    *progress++ = 1;        /* one cert type */
    *progress++ = 1;        /* RSA-sign type */
    progress = SSLEncodeInt(progress, namesLen, 2);
    for (dn = ctx->acceptableDNList; dn != 0; dn = dn->next)
    {   progress = SSLEncodeInt(progress, dn->derDN.length, 2);
        memcpy(progress, dn->derDN.data, dn->derDN.length);
        progress += dn->derDN.length;
    }
    
    ASSERT(progress == request->contents.data + request->contents.length);
    
    return SSLNoErr;
}

/* Parse a received CertificateRequest handshake message body.
 *
 * Body layout: 1-byte type count, that many type bytes, a 2-byte length of
 * the distinguished-name vector, then entries of (2-byte length, DER name).
 * If any listed type is 1 (RSA-sign), ctx->x509Requested is set. Each name
 * is copied into a new DNListElem and prepended to ctx->acceptableDNList,
 * so appended names end up in reverse wire order.
 * Returns SSLProtocolErr on any length inconsistency, or an allocation error.
 */
SSLErr
SSLProcessCertificateRequest(SSLBuffer message, SSLContext *ctx)
{   SSLErr          err;
    int             remaining, nameLen;
    unsigned int    numTypes, t;
    uint8           *progress;
    SSLBuffer       dnBuf;
    DNListElem      *dn;
    
    /* Smallest legal body: type count byte + 2-byte DN vector length
       (the type-count check below then demands at least one type). */
    if (message.length < 3)
        return ERR(SSLProtocolErr);
    
    progress = message.data;
    numTypes = *progress++;
    if (numTypes == 0 || message.length < 3 + numTypes)
        return ERR(SSLProtocolErr);
    
    /* Type 1 is RSA-sign, the only type this implementation acts on. */
    for (t = 0; t < numTypes; t++, progress++)
    {   if (*progress == 1)
            ctx->x509Requested = 1;
    }
    
    remaining = SSLDecodeInt(progress, 2);
    progress += 2;
    if (message.length != 3 + numTypes + remaining)
        return ERR(SSLProtocolErr);
    
    for (; remaining > 0; remaining -= 2 + nameLen)
    {   if (remaining < 2)
            return ERR(SSLProtocolErr);
        nameLen = SSLDecodeInt(progress, 2);
        progress += 2;
        if (remaining < 2 + nameLen)
            return ERR(SSLProtocolErr);
        
        if (ERR(err = SSLAllocBuffer(&dnBuf, sizeof(DNListElem), &ctx->sysCtx)) != 0)
            return err;
        dn = (DNListElem*)dnBuf.data;
        if (ERR(err = SSLAllocBuffer(&dn->derDN, nameLen, &ctx->sysCtx)) != 0)
        {   SSLFreeBuffer(&dnBuf, &ctx->sysCtx);
            return err;
        }
        memcpy(dn->derDN.data, progress, nameLen);
        progress += nameLen;
        
        /* Prepend to the context's list. */
        dn->next = ctx->acceptableDNList;
        ctx->acceptableDNList = dn;
    }
    
    ASSERT(progress == message.data + message.length);
    
    return SSLNoErr;
}

/* Build a CertificateVerify handshake message.
 *
 * Clones the running SHA-1 and MD5 handshake hash states, computes the
 * 36-byte hash via SSLCalculateFinishedMessage (sender argument 0), and
 * signs it with ctx->localKey. Layout written into certVerify->contents:
 *   1 byte   handshake type (SSL_certificate_verify)
 *   3 bytes  body length (signature length + 2)
 *   2 bytes  signature length (the RSA modulus length in bytes)
 *   then the signature itself.
 *
 * On every exit path the cloned hash states are released. certVerify->contents
 * starts zeroed; if a later step fails it may still hold the allocated
 * message buffer — NOTE(review): presumably freed by the caller, confirm.
 */
SSLErr
SSLEncodeCertificateVerify(SSLRecord *certVerify, SSLContext *ctx)
{   SSLErr          err;
    uint8           signedHashData[36];
    SSLBuffer       hashData, shaMsgState, md5MsgState;
    uint32          len;
    unsigned int    outputLen;
    
    certVerify->contents.data = 0;
    /* The fail: path frees both clones; start them empty so that is safe
       even when cloning never ran or failed part way. */
    shaMsgState.data = 0;
    md5MsgState.data = 0;
    hashData.data = signedHashData;
    hashData.length = 36;
    
    if (ERR(err = CloneHashState(&SSLHashSHA1, ctx->shaState, &shaMsgState, ctx)) != 0)
        goto fail;
    if (ERR(err = CloneHashState(&SSLHashMD5, ctx->md5State, &md5MsgState, ctx)) != 0)
        goto fail;
    if (ERR(err = SSLCalculateFinishedMessage(hashData, shaMsgState, md5MsgState, 0, ctx)) != 0)
        goto fail;
    
#if RSAREF
    len = (ctx->localKey.bits + 7)/8;
#elif BSAFE
    {   A_RSA_KEY   *keyInfo;
        int         rsaResult;
        
        if ((rsaResult = B_GetKeyInfo((POINTER*)&keyInfo, ctx->localKey, KI_RSAPublic)) != 0)
        {   err = ERR(SSLUnknownErr);
            goto fail;
        }
        len = keyInfo->modulus.len;
    }
#endif /* RSAREF / BSAFE */
    
    certVerify->contentType = SSL_handshake;
    certVerify->protocolVersion = SSL_Version_3_0;
    if (ERR(err = SSLAllocBuffer(&certVerify->contents, len + 6, &ctx->sysCtx)) != 0)
        goto fail;
    
    certVerify->contents.data[0] = SSL_certificate_verify;
    SSLEncodeInt(certVerify->contents.data+1, len+2, 3);
    SSLEncodeInt(certVerify->contents.data+4, len, 2);
#if RSAREF
    if (RSAPrivateEncrypt(certVerify->contents.data+6, &outputLen,
                    signedHashData, 36, &ctx->localKey) != 0)   /* Sign the structure */
    {   err = ERR(SSLUnknownErr);
        goto fail;
    }
#elif BSAFE
    {   B_ALGORITHM_OBJ     rsa;
        B_ALGORITHM_METHOD  *chooser[] = { &AM_RSA_CRT_ENCRYPT, 0 };
        int                 rsaResult;
        unsigned int        finalLen;
        
        if (ERR(rsaResult = B_CreateAlgorithmObject(&rsa)) != 0)
        {   err = SSLUnknownErr;
            goto fail;
        }
        /* Run each step only while the previous one succeeded, so the
           algorithm object is destroyed on every path. */
        rsaResult = B_SetAlgorithmInfo(rsa, AI_PKCS_RSAPrivate, 0);
        if (rsaResult == 0)
            rsaResult = B_EncryptInit(rsa, ctx->localKey, chooser, NO_SURR);
        if (rsaResult == 0)
            rsaResult = B_EncryptUpdate(rsa, certVerify->contents.data+6,
                    &outputLen, len, signedHashData, 36, 0, NO_SURR);
        if (rsaResult == 0)
            rsaResult = B_EncryptFinal(rsa, certVerify->contents.data+6+outputLen,
                    &finalLen, len-outputLen, 0, NO_SURR);
        B_DestroyAlgorithmObject(&rsa);
        if (ERR(rsaResult) != 0)
        {   err = SSLUnknownErr;
            goto fail;
        }
        /* Signature length is what Update and Final produced together;
           Final must not overwrite the Update count. */
        outputLen += finalLen;
    }
#endif /* RSAREF / BSAFE */
    
    ASSERT(outputLen == len);
    
    err = SSLNoErr;
    
fail:
    ERR(SSLFreeBuffer(&shaMsgState, &ctx->sysCtx));
    ERR(SSLFreeBuffer(&md5MsgState, &ctx->sysCtx));

    return err;
}

/* Verify a received CertificateVerify handshake message body.
 *
 * message holds a 2-byte signature length followed by the signature, which
 * must be exactly as long as the peer's RSA modulus. The signature is
 * decrypted with ctx->peerKey and must yield the same 36 bytes that
 * SSLCalculateFinishedMessage (sender argument 0) produces from clones of
 * the running SHA-1 and MD5 handshake hash states.
 *
 * Returns SSLNoErr on a match, SSLProtocolErr on malformed lengths or a
 * mismatch, SSLUnknownErr when the RSA operation fails, or a helper's error.
 * The hash-state clones and the decrypt buffer are released on every path
 * that reaches the fail: label.
 */
SSLErr
SSLProcessCertificateVerify(SSLBuffer message, SSLContext *ctx)
{   SSLErr          err;
    uint8           signedHashData[36];
    uint16          signatureLen;
    SSLBuffer       hashData, shaMsgState, md5MsgState, outputData;
    unsigned int    outputLen, publicModulusLen;
    
    /* Start all owned buffers empty so the fail: path can free them blindly. */
    shaMsgState.data = 0;
    md5MsgState.data = 0;
    outputData.data = 0;
    
    if (message.length < 2)
        return ERR(SSLProtocolErr);     
    
    signatureLen = (uint16)SSLDecodeInt(message.data, 2);
    if (message.length != 2 + signatureLen)
        return ERR(SSLProtocolErr);
    
#if RSAREF
    publicModulusLen = (ctx->peerKey.bits + 7)/8;
#elif BSAFE
    {   A_RSA_KEY   *keyInfo;
        int         rsaResult;
        
        if ((rsaResult = B_GetKeyInfo((POINTER*)&keyInfo, ctx->peerKey, KI_RSAPublic)) != 0)
            return SSLUnknownErr;
        publicModulusLen = keyInfo->modulus.len;
    }
#endif /* RSAREF / BSAFE */
    
    if (signatureLen != publicModulusLen)
        return ERR(SSLProtocolErr);
    
    hashData.data = signedHashData;
    hashData.length = 36;
    
    if (ERR(err = CloneHashState(&SSLHashSHA1, ctx->shaState, &shaMsgState, ctx)) != 0)
        goto fail;
    if (ERR(err = CloneHashState(&SSLHashMD5, ctx->md5State, &md5MsgState, ctx)) != 0)
        goto fail;
    if (ERR(err = SSLCalculateFinishedMessage(hashData, shaMsgState, md5MsgState, 0, ctx)) != 0)
        goto fail;
    
    if (ERR(err = SSLAllocBuffer(&outputData, publicModulusLen, &ctx->sysCtx)) != 0)
        goto fail;
    
#if RSAREF
    if (RSAPublicDecrypt(outputData.data, &outputLen,
        message.data + 2, signatureLen, &ctx->peerKey) != 0)
    {   ERR(err = SSLUnknownErr);
        goto fail;
    }
#elif BSAFE
    {   B_ALGORITHM_OBJ     rsa;
        B_ALGORITHM_METHOD  *chooser[] = { &AM_MD2, &AM_MD5, &AM_RSA_DECRYPT, 0 };
        int                 rsaResult;
        unsigned int        decryptLen;
        
        if ((rsaResult = B_CreateAlgorithmObject(&rsa)) != 0)
        {   err = SSLUnknownErr;
            goto fail;
        }
        /* Run each step only while the previous one succeeded, so the
           algorithm object is destroyed and buffers are freed on failure. */
        rsaResult = B_SetAlgorithmInfo(rsa, AI_PKCS_RSAPublic, 0);
        if (rsaResult == 0)
            rsaResult = B_DecryptInit(rsa, ctx->peerKey, chooser, NO_SURR);
        if (rsaResult == 0)
            rsaResult = B_DecryptUpdate(rsa, outputData.data, &outputLen, 36,
                    message.data + 2, signatureLen, 0, NO_SURR);
        if (rsaResult == 0)
            rsaResult = B_DecryptFinal(rsa, outputData.data+outputLen,
                    &decryptLen, 36-outputLen, 0, NO_SURR);
        B_DestroyAlgorithmObject(&rsa);
        if (rsaResult != 0)
        {   err = SSLUnknownErr;
            goto fail;
        }
        outputLen += decryptLen;
    }
#endif /* RSAREF / BSAFE */
    
    if (outputLen != 36)
    {   ERR(err = SSLProtocolErr);
        goto fail;
    }
    outputData.length = outputLen;
    
    DUMP_BUFFER_NAME("Finished got   ", outputData);
    DUMP_BUFFER_NAME("Finished wanted", hashData);
    
    if (memcmp(outputData.data, signedHashData, 36) != 0)
    {   ERR(err = SSLProtocolErr);
        goto fail;
    }
    
    err = SSLNoErr;
    
fail:
    ERR(SSLFreeBuffer(&shaMsgState, &ctx->sysCtx));
    ERR(SSLFreeBuffer(&md5MsgState, &ctx->sysCtx));
    ERR(SSLFreeBuffer(&outputData, &ctx->sysCtx));

    return err;
}
