/*  *********************************************************************
    File: sslctx.c

    SSLRef 3.0 Final -- 11/19/96

    Copyright (c)1996 by Netscape Communications Corp.

    By retrieving this software you are bound by the licensing terms
    disclosed in the file "LICENSE.txt". Please read it, and if you don't
    accept the terms, delete this software.

    SSLRef 3.0 was developed by Netscape Communications Corp. of Mountain
    View, California <http://home.netscape.com/> and Consensus Development
    Corporation of Berkeley, California <http://www.consensus.com/>.

    *********************************************************************

    File: sslctx.c     SSLContext accessors

    Functions called by the end user which configure an SSLContext
    structure or access data stored there.

    ****************************************************************** */

#ifndef _SSLCTX_H_
#include "sslctx.h"
#endif

#ifndef _X509_H_
#include "x509.h"
#endif

#ifndef _SSLALLOC_H_
#include "sslalloc.h"
#endif

#include <string.h>

/* File-local helpers that operate on SSLCertificate linked lists; both are
 * defined further down in this file. */
static SSLErr SSLDuplicateCertificateChain(SSLCertificate *chain, SSLContext *destCtx);
static SSLErr SSLDeleteCertificateChain(SSLCertificate *chain, SSLContext *ctx);

uint32
SSLContextSize(void)
{   return sizeof(SSLContext);
}

/*
 * Resets *ctx to an all-zero state and then selects the
 * SSL_NULL_WITH_NULL_NULL cipher spec for both the read and the write
 * direction. Also calls SSLInitMACPads(). Always returns SSLNoErr.
 */
SSLErr
SSLInitContext(SSLContext *ctx)
{
    /* Start from a blank slate: every field is zero */
    memset(ctx, 0, sizeof(SSLContext));

    /* Select the null cipher spec as the current one */
    ctx->selectedCipherSpec = &SSL_NULL_WITH_NULL_NULL_CipherSpec;
    ctx->selectedCipher = ctx->selectedCipherSpec->cipherSpec;

    /* Read direction */
    ctx->readCipher.hash = ctx->selectedCipherSpec->macAlgorithm;
    ctx->readCipher.symCipher = ctx->selectedCipherSpec->cipher;

    /* Write direction */
    ctx->writeCipher.hash = ctx->selectedCipherSpec->macAlgorithm;
    ctx->writeCipher.symCipher = ctx->selectedCipherSpec->cipher;

    SSLInitMACPads();

    return SSLNoErr;
}

/*
 * Initializes *dest with SSLInitContext() and then copies the following
 * from *src: sysCtx, ioCtx, protocolSide, the local certificate chain,
 * localKey and exportKey. Finally dest->ioCtx.ioRef is overridden with the
 * ioRef argument.
 *
 * Returns the first error reported by SSLInitContext() or by the certificate
 * chain duplication, otherwise SSLNoErr.
 *
 * NOTE(review): localKey and exportKey are copied by plain assignment, so
 * both contexts end up holding the same key values - confirm who is expected
 * to destroy them.
 */
SSLErr
SSLDuplicateContext(SSLContext *src, SSLContext *dest, void *ioRef)
{
    SSLErr  status;

    status = SSLInitContext(dest);
    if (status != 0)
        return status;

    /* Connection-independent settings */
    dest->sysCtx = src->sysCtx;
    dest->ioCtx = src->ioCtx;
    dest->protocolSide = src->protocolSide;

    status = SSLDuplicateCertificateChain(src->localCert, dest);
    if (status != 0)
        return status;

    dest->localKey = src->localKey;
    dest->exportKey = src->exportKey;

    /* The duplicate gets its own I/O reference */
    dest->ioCtx.ioRef = ioRef;

    return SSLNoErr;
}

/*
 * Appends a copy of every certificate in 'chain' to the end of
 * destCtx->localCert, preserving order.
 *
 * SSLAddCertificate() stores the SSLBuffer it is given without copying the
 * bytes, and SSLDeleteCertificateChain() later frees that buffer. Each DER
 * encoding is therefore duplicated here (using destCtx's allocator) so that
 * the source chain and the duplicate never share, and later double-free,
 * the same allocation.
 *
 * Entries whose derCert.data is null are skipped, exactly as
 * SSLAddCertificate() would ignore them.
 *
 * On failure every certificate appended by this call is released again, so
 * destCtx->localCert is left as it was on entry, and the error is returned.
 * Returns SSLNoErr on success.
 */
static SSLErr
SSLDuplicateCertificateChain(SSLCertificate *chain, SSLContext *destCtx)
{
    SSLErr          err;
    SSLBuffer       derCopy;
    SSLCertificate  *oldTail, *added;

    /* Remember where the destination chain ends now, for rollback */
    oldTail = destCtx->localCert;
    while (oldTail != 0 && oldTail->next != 0)
        oldTail = oldTail->next;

    for (; chain != 0; chain = chain->next)
    {
        if (chain->derCert.data == 0)
            continue;

        if ((err = SSLAllocBuffer(&derCopy, chain->derCert.length, &destCtx->sysCtx)) != 0)
            goto rollback;
        if (chain->derCert.length != 0)
            memcpy(derCopy.data, chain->derCert.data, chain->derCert.length);

        /* With complete == 0 a failure means the node was never linked in,
         * so the copy is still ours to free. */
        if ((err = SSLAddCertificate(destCtx, derCopy, 0, 0)) != 0)
        {   SSLFreeBuffer(&derCopy, &destCtx->sysCtx);
            goto rollback;
        }
    }
    return SSLNoErr;

rollback:
    /* Unlink and free only the nodes this call appended */
    if (oldTail != 0)
    {   added = oldTail->next;
        oldTail->next = 0;
    }
    else
    {   added = destCtx->localCert;
        destCtx->localCert = 0;
    }
    SSLDeleteCertificateChain(added, destCtx);
    return err;
}

/*
 * Releases everything hanging off *ctx and then zeroes the structure
 * itself. The SSLContext memory is not freed here; only its contents are.
 *
 * Released, in order: the local and peer certificate chains, the queue of
 * records waiting to be written, (BSAFE builds only) the key and algorithm
 * objects, the Diffie-Hellman, hash-state, session and data buffers, the
 * list of acceptable distinguished names, and the four cipher suites.
 *
 * Each buffer is freed exactly once. The original code freed
 * partialReadBuffer twice, which is only safe if SSLFreeBuffer() resets the
 * buffer it is given.
 *
 * NOTE(review): SSLDuplicateContext() copies localKey and exportKey by
 * assignment; under BSAFE, deleting both the source and the duplicate
 * would appear to destroy the same key object twice - confirm.
 *
 * Always returns SSLNoErr; errors from the individual free calls are
 * ignored.
 */
SSLErr
SSLDeleteContext(SSLContext *ctx)
{   WaitingRecord   *wait, *next;
    DNListElem      *dn, *nextDN;
    SSLBuffer       buf;
    
    SSLDeleteCertificateChain(ctx->localCert, ctx);
    SSLDeleteCertificateChain(ctx->peerCert, ctx);
    
    /* Free each queued record's payload, then the queue node itself */
    wait = ctx->recordWriteQueue;
    while (wait)
    {   SSLFreeBuffer(&wait->data, &ctx->sysCtx);
        next = wait->next;      /* read the link before the node is freed */
        buf.data = (uint8*)wait;
        buf.length = sizeof(WaitingRecord);
        SSLFreeBuffer(&buf, &ctx->sysCtx);
        wait = next;
    }
    
#if BSAFE
    if (ctx->localKey)
        B_DestroyKeyObject(&ctx->localKey);
    if (ctx->exportKey)
        B_DestroyKeyObject(&ctx->exportKey);
    if (ctx->dhAnonParams)
        B_DestroyAlgorithmObject(&ctx->dhAnonParams);
    if (ctx->peerKey)
        B_DestroyKeyObject(&ctx->peerKey);
    if (ctx->peerDHParams)
        B_DestroyAlgorithmObject(&ctx->peerDHParams);
#endif
    
    SSLFreeBuffer(&ctx->dhPeerPublic, &ctx->sysCtx);
    SSLFreeBuffer(&ctx->dhExchangePublic, &ctx->sysCtx);
    SSLFreeBuffer(&ctx->dhPrivate, &ctx->sysCtx);
    
    SSLFreeBuffer(&ctx->shaState, &ctx->sysCtx);
    SSLFreeBuffer(&ctx->md5State, &ctx->sysCtx);
    
    SSLFreeBuffer(&ctx->sessionID, &ctx->sysCtx);
    SSLFreeBuffer(&ctx->peerID, &ctx->sysCtx);
    SSLFreeBuffer(&ctx->resumableSession, &ctx->sysCtx);
    SSLFreeBuffer(&ctx->preMasterSecret, &ctx->sysCtx);
    SSLFreeBuffer(&ctx->partialReadBuffer, &ctx->sysCtx);
    SSLFreeBuffer(&ctx->fragmentedMessageCache, &ctx->sysCtx);
    SSLFreeBuffer(&ctx->receivedDataBuffer, &ctx->sysCtx);

    /* Free each DN's encoded bytes, then the list node itself */
    dn = ctx->acceptableDNList;
    while (dn)
    {   SSLFreeBuffer(&dn->derDN, &ctx->sysCtx);
        nextDN = dn->next;      /* read the link before the node is freed */
        buf.data = (uint8*)dn;
        buf.length = sizeof(DNListElem);
        SSLFreeBuffer(&buf, &ctx->sysCtx);
        dn = nextDN;
    }
    
    SSLDisposeCipherSuite(&ctx->readCipher, ctx);
    SSLDisposeCipherSuite(&ctx->writeCipher, ctx);
    SSLDisposeCipherSuite(&ctx->readPending, ctx);
    SSLDisposeCipherSuite(&ctx->writePending, ctx);
    
    /* Leave no stale pointers or values behind in the caller's structure */
    memset(ctx, 0, sizeof(SSLContext));
    
    return SSLNoErr;
}

/*
 * Walks the linked list starting at 'chain' and, for every node, releases
 * the parsed certificate, the DER buffer and finally the node itself, all
 * through ctx->sysCtx. Accepts an empty (null) chain. Always returns
 * SSLNoErr.
 */
static SSLErr
SSLDeleteCertificateChain(SSLCertificate *chain, SSLContext *ctx)
{
    SSLCertificate  *node, *following;
    SSLBuffer       nodeMem;

    for (node = chain; node != 0; node = following)
    {
        /* Fetch the link first: the node is gone by the end of the body */
        following = node->next;

        ASNFreeX509Cert(&node->cert, &ctx->sysCtx);
        SSLFreeBuffer(&node->derCert, &ctx->sysCtx);

        nodeMem.data = (uint8*)node;
        nodeMem.length = sizeof(SSLCertificate);
        SSLFreeBuffer(&nodeMem, &ctx->sysCtx);
    }

    return SSLNoErr;
}

/*
 * Stores the allocation callback in ctx->sysCtx.alloc.
 * Always returns SSLNoErr.
 */
SSLErr
SSLSetAllocFunc(SSLContext *ctx, SSLAllocFunc alloc)
{
    ctx->sysCtx.alloc = alloc;

    return SSLNoErr;
}
/*
 * Stores the deallocation callback in ctx->sysCtx.free.
 * Always returns SSLNoErr.
 */
SSLErr
SSLSetFreeFunc(SSLContext *ctx, SSLFreeFunc free)
{
    ctx->sysCtx.free = free;

    return SSLNoErr;
}
/*
 * Stores the reallocation callback in ctx->sysCtx.realloc.
 * Always returns SSLNoErr.
 */
SSLErr
SSLSetReallocFunc(SSLContext *ctx, SSLReallocFunc realloc)
{
    ctx->sysCtx.realloc = realloc;

    return SSLNoErr;
}
/*
 * Stores the caller's opaque allocator reference in ctx->sysCtx.allocRef.
 * Always returns SSLNoErr.
 */
SSLErr
SSLSetAllocRef(SSLContext *ctx, void* allocRef)
{
    ctx->sysCtx.allocRef = allocRef;

    return SSLNoErr;
}
/*
 * Stores the time callback in ctx->sysCtx.time.
 * Always returns SSLNoErr.
 */
SSLErr
SSLSetTimeFunc(SSLContext *ctx, SSLTimeFunc time)
{
    ctx->sysCtx.time = time;

    return SSLNoErr;
}
/*
 * Stores the time-conversion callback in ctx->sysCtx.convertTime.
 * Always returns SSLNoErr.
 */
SSLErr
SSLSetConvertTimeFunc(SSLContext *ctx, SSLConvertTimeFunc convertTime)
{
    ctx->sysCtx.convertTime = convertTime;

    return SSLNoErr;
}
/*
 * Stores the caller's opaque time reference in ctx->sysCtx.timeRef.
 * Always returns SSLNoErr.
 */
SSLErr
SSLSetTimeRef(SSLContext *ctx, void* timeRef)
{
    ctx->sysCtx.timeRef = timeRef;

    return SSLNoErr;
}
/*
 * Stores the random-data callback in ctx->sysCtx.random.
 * Always returns SSLNoErr.
 */
SSLErr
SSLSetRandomFunc(SSLContext *ctx, SSLRandomFunc random)
{
    ctx->sysCtx.random = random;

    return SSLNoErr;
}
/*
 * Stores the caller's opaque random-source reference in
 * ctx->sysCtx.randomRef. Always returns SSLNoErr.
 */
SSLErr
SSLSetRandomRef(SSLContext *ctx, void* randomRef)
{
    ctx->sysCtx.randomRef = randomRef;

    return SSLNoErr;
}
/*
 * Stores the read callback in ctx->ioCtx.read.
 * Always returns SSLNoErr.
 */
SSLErr
SSLSetReadFunc(SSLContext *ctx, SSLIOFunc read)
{
    ctx->ioCtx.read = read;

    return SSLNoErr;
}
/*
 * Stores the write callback in ctx->ioCtx.write.
 * Always returns SSLNoErr.
 */
SSLErr
SSLSetWriteFunc(SSLContext *ctx, SSLIOFunc write)
{
    ctx->ioCtx.write = write;

    return SSLNoErr;
}
/*
 * Stores the caller's opaque I/O reference in ctx->ioCtx.ioRef.
 * Always returns SSLNoErr.
 */
SSLErr
SSLSetIORef(SSLContext *ctx, void *ioRef)
{
    ctx->ioCtx.ioRef = ioRef;

    return SSLNoErr;
}
/*
 * Stores the add-session callback in ctx->sessionCtx.addSession.
 * Always returns SSLNoErr.
 */
SSLErr
SSLSetAddSessionFunc(SSLContext *ctx, SSLAddSessionFunc addSession)
{
    ctx->sessionCtx.addSession = addSession;

    return SSLNoErr;
}
/*
 * Stores the get-session callback in ctx->sessionCtx.getSession.
 * Always returns SSLNoErr.
 */
SSLErr
SSLSetGetSessionFunc(SSLContext *ctx, SSLGetSessionFunc getSession)
{
    ctx->sessionCtx.getSession = getSession;

    return SSLNoErr;
}
/*
 * Stores the delete-session callback in ctx->sessionCtx.deleteSession.
 * Always returns SSLNoErr.
 */
SSLErr
SSLSetDeleteSessionFunc(SSLContext *ctx, SSLDeleteSessionFunc deleteSession)
{
    ctx->sessionCtx.deleteSession = deleteSession;

    return SSLNoErr;
}
/*
 * Stores the caller's opaque session reference in
 * ctx->sessionCtx.sessionRef. Always returns SSLNoErr.
 */
SSLErr
SSLSetSessionRef(SSLContext *ctx, void *sessionRef)
{
    ctx->sessionCtx.sessionRef = sessionRef;

    return SSLNoErr;
}
/*
 * Stores the certificate-check callback in ctx->certCtx.checkCertFunc.
 * Always returns SSLNoErr.
 */
SSLErr
SSLSetCheckCertificateFunc(SSLContext *ctx, SSLCheckCertificateFunc checkCertificate)
{
    ctx->certCtx.checkCertFunc = checkCertificate;

    return SSLNoErr;
}
/*
 * Stores the caller's opaque certificate-check reference in
 * ctx->certCtx.checkCertRef. Always returns SSLNoErr.
 */
SSLErr
SSLSetCheckCertificateRef(SSLContext *ctx, void *checkCertificateRef)
{
    ctx->certCtx.checkCertRef = checkCertificateRef;

    return SSLNoErr;
}
/*
 * Records which side of the protocol this context plays in
 * ctx->protocolSide and puts ctx->state back to SSLUninitialized.
 * Always returns SSLNoErr.
 */
SSLErr
SSLSetProtocolSide(SSLContext *ctx, SSLProtocolSide side)
{
    ctx->protocolSide = side;
    ctx->state = SSLUninitialized;

    return SSLNoErr;
}
/*
 * Stores 'version' in ctx->protocolVersion.
 * Always returns SSLNoErr.
 */
SSLErr
SSLSetProtocolVersion(SSLContext *ctx, SSLProtocolVersion version)
{
    ctx->protocolVersion = version;

    return SSLNoErr;
}
/*
 * Stores the requestClientCert flag in ctx->requestClientCert.
 * Always returns SSLNoErr.
 */
SSLErr
SSLSetRequestClientCert(SSLContext *ctx, int requestClientCert)
{
    ctx->requestClientCert = requestClientCert;

    return SSLNoErr;
}
/*
 * Copies the value *privKey into ctx->localKey. privKey is dereferenced
 * unconditionally, so it must not be null. Always returns SSLNoErr.
 *
 * NOTE(review): in BSAFE builds SSLDeleteContext() destroys ctx->localKey,
 * so the context presumably takes over the key - confirm against callers.
 */
SSLErr
SSLSetPrivateKey(SSLContext *ctx, SSLRSAPrivateKey *privKey)
{
    ctx->localKey = *privKey;

    return SSLNoErr;
}
/*
 * Copies the value *exportKey into ctx->exportKey. exportKey is
 * dereferenced unconditionally, so it must not be null. Always returns
 * SSLNoErr.
 *
 * NOTE(review): in BSAFE builds SSLDeleteContext() destroys ctx->exportKey,
 * so the context presumably takes over the key - confirm against callers.
 */
SSLErr
SSLSetExportPrivateKey(SSLContext *ctx, SSLRSAPrivateKey *exportKey)
{
    ctx->exportKey = *exportKey;

    return SSLNoErr;
}
/*
 * Copies the value *dhAnonParams into ctx->dhAnonParams. dhAnonParams is
 * dereferenced unconditionally, so it must not be null. Always returns
 * SSLNoErr.
 */
SSLErr
SSLSetDHAnonParams(SSLContext *ctx, SSLDHParams *dhAnonParams)
{
    ctx->dhAnonParams = *dhAnonParams;

    return SSLNoErr;
}
/*
 * Stores a private copy of 'peerID' in ctx->peerID.
 *
 * The copy is allocated through ctx->sysCtx, so the caller keeps ownership
 * of peerID.data. Any peer ID already stored in the context is released
 * before it is replaced; previously a second call overwrote ctx->peerID and
 * leaked the earlier buffer.
 *
 * The new buffer is allocated before the old one is released, so on
 * allocation failure the existing peer ID is left untouched and the error
 * from SSLAllocBuffer() is returned. Returns SSLNoErr on success.
 */
SSLErr
SSLSetPeerID(SSLContext *ctx, SSLBuffer peerID)
{
    SSLErr      err;
    SSLBuffer   copy;

    if ((err = SSLAllocBuffer(&copy, peerID.length, &ctx->sysCtx)) != 0)
        return err;
    if (peerID.length != 0)
        memcpy(copy.data, peerID.data, peerID.length);

    /* Release the previous ID only once its replacement is ready */
    if (ctx->peerID.data != 0)
        SSLFreeBuffer(&ctx->peerID, &ctx->sysCtx);
    ctx->peerID = copy;

    return SSLNoErr;
}
/*
 * Adds one DER-encoded certificate to the context's local chain and,
 * optionally, verifies the chain.
 *
 * ctx       context whose localCert list is extended
 * derCert   encoded certificate; when derCert.data is null nothing is
 *           added. The bytes are NOT copied: the new list node keeps this
 *           very SSLBuffer, and SSLDeleteCertificateChain() frees it
 *           later.
 * parent    non-zero: insert at the head of the list; zero: append at the
 *           tail
 * complete  non-zero: run X509VerifyCertChain() on ctx->localCert before
 *           returning
 *
 * Returns the error from SSLAllocBuffer(), ASNParseX509Certificate() or
 * X509VerifyCertChain() if any of them fails, otherwise SSLNoErr. When
 * allocation or parsing fails the list is unchanged. When verification
 * fails the certificate added by this call stays in the list.
 */
SSLErr
SSLAddCertificate(SSLContext *ctx, SSLBuffer derCert, int parent, int complete)
{
    SSLErr          status;
    SSLCertificate  *node, **link;
    SSLBuffer       nodeMem;

    if (derCert.data != 0)
    {
        status = SSLAllocBuffer(&nodeMem, sizeof(SSLCertificate), &ctx->sysCtx);
        if (status != 0)
            return status;

        node = (SSLCertificate*)nodeMem.data;
        node->next = 0;
        node->derCert = derCert;

        status = ASNParseX509Certificate(derCert, &node->cert, ctx);
        if (status != 0)
        {
            SSLFreeBuffer(&nodeMem, &ctx->sysCtx);
            return status;
        }

        /* Root cert is first in the chain, so parents go at the front */
        if (parent)
        {
            node->next = ctx->localCert;
            ctx->localCert = node;
        }
        else
        {
            /* Find the terminating null link and hang the node there */
            link = &ctx->localCert;
            while (*link != 0)
                link = &(*link)->next;
            *link = node;
        }
    }

    if (complete)
    {
        status = X509VerifyCertChain(ctx->localCert, ctx);
        if (status != 0)
            return status;
    }

    return SSLNoErr;
}
/*
 * Pushes a copy of the DER-encoded distinguished name 'derDN' onto the
 * front of ctx->acceptableDNList.
 *
 * Both the list node and the copied bytes are allocated through
 * ctx->sysCtx, so the caller keeps ownership of derDN.data. If the second
 * allocation fails the node is released again. Returns the error from
 * SSLAllocBuffer() on failure, otherwise SSLNoErr.
 */
SSLErr
SSLAddDistinguishedName(SSLContext *ctx, SSLBuffer derDN)
{
    SSLBuffer       nodeMem;
    DNListElem      *entry;
    SSLErr          status;

    status = SSLAllocBuffer(&nodeMem, sizeof(DNListElem), &ctx->sysCtx);
    if (status != 0)
        return status;
    entry = (DNListElem*)nodeMem.data;

    status = SSLAllocBuffer(&entry->derDN, derDN.length, &ctx->sysCtx);
    if (status != 0)
    {
        SSLFreeBuffer(&nodeMem, &ctx->sysCtx);
        return status;
    }
    memcpy(entry->derDN.data, derDN.data, derDN.length);

    /* Link in at the head of the list */
    entry->next = ctx->acceptableDNList;
    ctx->acceptableDNList = entry;

    return SSLNoErr;
}
/*
 * Writes ctx->protocolVersion to *version.
 * Always returns SSLNoErr.
 */
SSLErr
SSLGetProtocolVersion(SSLContext *ctx, SSLProtocolVersion *version)
{
    *version = ctx->protocolVersion;

    return SSLNoErr;
}
/*
 * Counts the nodes in the ctx->peerCert list and writes the total to
 * *chainLen (0 when the list is empty). Always returns SSLNoErr.
 */
SSLErr
SSLGetPeerCertificateChainLength(SSLContext *ctx, int *chainLen)
{
    SSLCertificate  *node;

    *chainLen = 0;
    for (node = ctx->peerCert; node != 0; node = node->next)
        *chainLen += 1;

    return SSLNoErr;
}
/*
 * Copies the DER encoding of one certificate from the ctx->peerCert list
 * into a buffer newly allocated through ctx->sysCtx and returned in
 * *derCert.
 *
 * 'index' is 1-based in list order: 1 selects the first node, 2 the
 * second, and so on. Values below 1 are not rejected; they also select the
 * first node. If the list has fewer than 'index' nodes (or is empty),
 * SSLOverflowErr is returned. An allocation failure returns the error from
 * SSLAllocBuffer(). Otherwise returns SSLNoErr.
 */
SSLErr
SSLGetPeerCertificate(SSLContext *ctx, int index, SSLBuffer *derCert)
{
    SSLErr          status;
    SSLCertificate  *node;

    /* Step forward index-1 times, stopping early at the end of the list */
    node = ctx->peerCert;
    for (index -= 1; index > 0 && node != 0; --index)
        node = node->next;

    if (node == 0)
        return SSLOverflowErr;

    status = SSLAllocBuffer(derCert, node->derCert.length, &ctx->sysCtx);
    if (status != 0)
        return status;
    memcpy(derCert->data, node->derCert.data, node->derCert.length);

    return SSLNoErr;
}
/*
 * Writes ctx->selectedCipher to *cipherSuite.
 * Always returns SSLNoErr.
 */
SSLErr
SSLGetNegotiatedCipher(SSLContext *ctx, uint16 *cipherSuite)
{
    *cipherSuite = ctx->selectedCipher;

    return SSLNoErr;
}
/*
 * Sums, over every record in ctx->recordWriteQueue, the part of the record
 * not yet sent (data.length - sent) and writes the total to *waitingBytes.
 * Always returns SSLNoErr.
 */
SSLErr
SSLGetWritePendingSize(SSLContext *ctx, uint32 *waitingBytes)
{
    WaitingRecord   *rec;

    *waitingBytes = 0;
    for (rec = ctx->recordWriteQueue; rec != 0; rec = rec->next)
        *waitingBytes += rec->data.length - rec->sent;

    return SSLNoErr;
}
/*
 * Writes to *waitingBytes the difference between the length of
 * ctx->partialReadBuffer and ctx->amountRead. Always returns SSLNoErr.
 */
SSLErr
SSLGetReadPendingSize(SSLContext *ctx, uint32 *waitingBytes)
{
    *waitingBytes = ctx->partialReadBuffer.length - ctx->amountRead;

    return SSLNoErr;
}
