/*
 * primepool.cc: maintains a set of safe prime numbers
 *
 * Copyright (c) 2000 Mount Linux Inc.
 * Licensed under the terms of the GPL
 */

#include "config.h"
#include "bn.h"
#include "germain.h"
#include "defs.h"
#include "random.h"
#include "primepool.h"

/*
 * Construct an empty pool: the prime list starts with no entries.
 */
primePool::primePool(void)
    : primePoolSet(NULL)
{
}

/*
 * Destroy the pool, freeing every node in the list.  Each node's own
 * destructor releases the BigNum it holds (bnEnd + delete).
 *
 * NOTE(review): this class defines a destructor that frees owned nodes;
 * unless the header disables copying, a copied primePool would free the
 * same list twice -- confirm copy ctor/assignment are private or deleted.
 */
primePool::~primePool(void)
{
    while (primePoolSet != NULL)
    {
        primeSet *doomed = primePoolSet;

        primePoolSet = doomed->nextPrime;   /* advance before freeing */
        delete doomed;
    }
}

/*
 * Build a list node holding prime `p`, linked in front of `next`.
 * The node takes ownership of `p`: its destructor calls bnEnd() on it and
 * deletes it, so `p` must have been allocated with `new`.
 */
primePool::primeSet::primeSet(struct BigNum *p, primeSet *next)
    : prime(p),
      nextPrime(next)
{
}

/*
 * Release this node's prime (if any).  Only this node's BigNum is freed;
 * the successor node is not touched -- the link is merely cleared, since
 * the pool is responsible for freeing the rest of the list.
 */
primePool::primeSet::~primeSet(void)
{
    nextPrime = NULL;

    if (!prime)
        return;

    bnEnd(prime);   /* free the BigNum's internal storage */
    delete prime;   /* then the BigNum object itself */
}

void primePool::add(struct BigNum *p)
{
    primeSet *tmp(primePoolSet), *last(NULL);
    int nbits(bnBits(p));

    while (tmp)
    {
        if (bnBits(tmp->prime) > nbits)
        {
            break;
        }
        last = tmp;
        tmp = tmp->nextPrime;
    }

    if (last == NULL)
    {
        tmp = new primeSet(p, tmp);
        primePoolSet = tmp; 
    }
    else last->nextPrime = new primeSet(p, tmp);
}

void primePool::remove(struct BigNum *p)
{
    primeSet *tmp(primePoolSet), *last(NULL);

    while (tmp)
    {
        if (tmp->prime == p)
        {
            if (last)
            {
                last->nextPrime = tmp->nextPrime;
            }
            else
            {
                primePoolSet = tmp->nextPrime;
            }
            delete tmp;
            break;
        }
        last = tmp;
        tmp = tmp->nextPrime;
    }
}

/*
 * Generate a new prime of (at least) `nbits` bits and add it to the pool.
 *
 * A random seed of exactly `nbits` bits is built from the RNG, then handed
 * to germainPrimeGen(), which searches from that starting point.  The seed's
 * most significant bit (bit nbits-1) is forced on and any surplus random
 * bits in the top byte are cleared.  Without this, the seed could be shorter
 * than requested, and the resulting prime could fall below `nbits` bits.
 * getPrime() would then ignore it and generate again, leaving undersized
 * primes in the pool.
 */
void primePool::generate(int nbits)
{
    struct BigNum *p(new struct BigNum);
    unsigned char *buffer;
    int nbytes((nbits + 7) >> 3);

    buffer = RNG->getRandom(nbytes);

    if (nbits > 0)
    {
        /* Bytes are inserted little-endian, so the top byte is the last one. */
        int topbit((nbits - 1) & 7);
        unsigned char *top(buffer + nbytes - 1);

        *top &= (unsigned char)((2 << topbit) - 1);  /* drop bits above nbits */
        *top |= (unsigned char)(1 << topbit);        /* force exactly nbits */
    }

    bnBegin(p);
    /* NOTE(review): return values of bnInsertLittleBytes/germainPrimeGen are
     * ignored here; confirm against bnlib whether failures must be handled. */
    bnInsertLittleBytes(p, buffer, 0, nbytes);
    germainPrimeGen(p, 0, NULL, NULL);
    add(p);

    /* NOTE(review): if getRandom() allocates with new[], this must be
     * delete[] -- confirm against random.h before changing. */
    delete buffer;
}

struct BigNum *primePool::getPrime(int nbits)
{
    primeSet *tmp(primePoolSet);

    while (tmp)
    {
        if (bnBits(tmp->prime) >= nbits)
        {
            break;
        }
        tmp = tmp->nextPrime;
    }

    if (tmp == NULL)
    {
        generate(nbits);
        return getPrime(nbits);
    }
    return tmp->prime;
}
