/*
 * transport.cc: socket transport class
 *
 * Copyright (c) 2000 Mount Linux Inc.
 * Licensed under the terms of the GPL
 */

#include <fcntl.h>
#include <string.h>
#include "config.h"
#include "defs.h"
#include "cipher.h"
#include "dh.h"
#include "bn.h"
#include "transport.h"

/*
 * Create a fresh TCP/IPv4 socket and configure it for use as a transport:
 * non-blocking (O_NONBLOCK, or O_NDELAY where O_NONBLOCK is unavailable),
 * with SO_REUSEADDR and SO_KEEPALIVE enabled. Both the local and remote
 * addresses start out as AF_INET / INADDR_ANY with port 0.
 *
 * Throws socketCreationError if socket(2) fails, or socketSettingError
 * naming the option that could not be applied. Before a settings failure
 * is thrown the descriptor is closed: the destructor does not run for an
 * object whose constructor throws, so the socket would otherwise leak.
 */
transport::transport(void)
{
    int one(1), flags;

    if ((sd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)) < 0)
    {
        throw socketCreationError();
    }

    kex = NULL;
    encipher = decipher = NULL;
    buffer = bptr = NULL;
    packet_length = padding_length = packet_read = 0;

    // Zero the whole sockaddr so sin_port and sin_zero are defined;
    // bind(char*) binds with whatever sin_port currently holds.
    memset(&local_addr, 0, sizeof(local_addr));
    memset(&remote_addr, 0, sizeof(remote_addr));

    local_addr.sin_family = AF_INET;
    remote_addr.sin_family = AF_INET;

    local_addr.sin_addr.s_addr = htonl(INADDR_ANY);
    remote_addr.sin_addr.s_addr = htonl(INADDR_ANY);

    // accept() treats this as a value-result length; give it a valid start.
    addr_len = sizeof(remote_addr);

    flags = fcntl(sd, F_GETFL, 0);

    // A failed F_GETFL returns -1; OR-ing a flag into that would request
    // every flag bit, so treat it as a failure to go non-blocking.
#if defined(O_NONBLOCK)
    if (flags == -1 || fcntl(sd, F_SETFL, flags | O_NONBLOCK) == -1)
    {
        close(sd);
        throw socketSettingError("O_NONBLOCK");
    }
#else
    if (flags == -1 || fcntl(sd, F_SETFL, flags | O_NDELAY) == -1)
    {
        close(sd);
        throw socketSettingError("O_NDELAY");
    }
#endif

    if (setsockopt(sd, SOL_SOCKET, SO_REUSEADDR, (void*) &one, sizeof(one)) < 0)
    {
        close(sd);
        throw socketSettingError("SO_REUSEADDR");
    }

    if (setsockopt(sd, SOL_SOCKET, SO_KEEPALIVE, (void*) &one, sizeof(one)) < 0)
    {
        close(sd);
        throw socketSettingError("SO_KEEPALIVE");
    }
}

/*
 * Wrap an already-open socket descriptor `nsd`, such as the value returned
 * by accept(). The descriptor is adopted: the destructor will close it.
 * No socket options are changed here. The address fields are zeroed and
 * given family AF_INET, so a later bind()/connect() on this object never
 * reads uninitialised memory.
 */
transport::transport(int nsd)
{
    kex = NULL;
    encipher = decipher = NULL;
    buffer = bptr = NULL;
    packet_length = padding_length = packet_read = 0;
    sd = nsd;

    memset(&local_addr, 0, sizeof(local_addr));
    memset(&remote_addr, 0, sizeof(remote_addr));
    local_addr.sin_family = AF_INET;
    remote_addr.sin_family = AF_INET;

    // Valid starting value for the value-result length used by accept().
    addr_len = sizeof(remote_addr);
}

/*
 * Tear down the transport. The following are released in this order:
 * the outbound cipher, the inbound cipher, the key-exchange object and
 * the buffer. The socket descriptor is then closed.
 * `delete` on a NULL pointer is a no-op, so members that were never
 * allocated need no special casing.
 */
transport::~transport(void)
{
    delete encipher;
    delete decipher;
    delete kex;

    // NOTE(review): `buffer` is reset alongside `bptr`, which suggests a
    // byte array; if it is allocated with new[] elsewhere this must become
    // delete[] -- confirm against the allocation site.
    delete buffer;

    close(sd);
}

int transport::getsd(void)
{
    return sd;
}

/*
 * Bind the socket to the IPv4 address `hostname` resolves to, keeping the
 * port currently stored in local_addr.sin_port. A NULL hostname keeps the
 * current local address as well, as the (hostname, port) overload does.
 *
 * Returns the result of ::bind() (0, or -1 with errno set). Returns -1
 * without touching the socket if the name cannot be resolved to an IPv4
 * address. In that case errno is not meaningful; gethostbyname reports
 * resolver failures through h_errno.
 */
int transport::bind(char* hostname)
{
    struct hostent* hp;

    if (hostname != NULL)
    {
        hp = gethostbyname(hostname);

        // Never dereference a failed lookup, and never copy more bytes than
        // sin_addr can hold.
        if (hp == NULL || hp->h_addrtype != AF_INET
            || hp->h_length != static_cast<int>(sizeof(local_addr.sin_addr))
            || hp->h_addr_list[0] == NULL)
        {
            return -1;
        }
        memcpy(&local_addr.sin_addr, hp->h_addr_list[0], sizeof(local_addr.sin_addr));
    }

    return (::bind(sd, (struct sockaddr*) &local_addr, sizeof(local_addr)));
}

/*
 * Bind the socket to `port` on the IPv4 address `hostname` resolves to.
 * A NULL hostname keeps the current local address (INADDR_ANY unless it
 * was changed earlier).
 *
 * Returns the result of ::bind() (0, or -1 with errno set). Returns -1
 * without touching the socket if the name cannot be resolved to an IPv4
 * address. In that case errno is not meaningful; gethostbyname reports
 * resolver failures through h_errno.
 */
int transport::bind(char* hostname, int port)
{
    struct hostent* hp;

    if (hostname != NULL)
    {
        hp = gethostbyname(hostname);

        // Never dereference a failed lookup, and never copy more bytes than
        // sin_addr can hold.
        if (hp == NULL || hp->h_addrtype != AF_INET
            || hp->h_length != static_cast<int>(sizeof(local_addr.sin_addr))
            || hp->h_addr_list[0] == NULL)
        {
            return -1;
        }
        memcpy(&local_addr.sin_addr, hp->h_addr_list[0], sizeof(local_addr.sin_addr));
    }
    local_addr.sin_port = htons(port);

    return (::bind(sd, (struct sockaddr*) &local_addr, sizeof(local_addr)));
}

/*
 * Connect to `port` on the IPv4 address `hostname` resolves to. A NULL
 * hostname reuses the address already in remote_addr.
 *
 * Returns the result of ::connect(). When the socket is non-blocking (the
 * default constructor makes it so), expect -1 with errno EINPROGRESS while
 * the connection is still being established. Returns -1 without touching
 * the socket if the name cannot be resolved to an IPv4 address. In that
 * case errno is not meaningful.
 */
int transport::connect(char* hostname, int port)
{
    struct hostent* hp;

    if (hostname != NULL)
    {
        hp = gethostbyname(hostname);

        // Never dereference a failed lookup, and never copy more bytes than
        // sin_addr can hold.
        if (hp == NULL || hp->h_addrtype != AF_INET
            || hp->h_length != static_cast<int>(sizeof(remote_addr.sin_addr))
            || hp->h_addr_list[0] == NULL)
        {
            return -1;
        }
        memcpy(&remote_addr.sin_addr, hp->h_addr_list[0], sizeof(remote_addr.sin_addr));
    }
    remote_addr.sin_port = htons(port);

    addr_len = sizeof(remote_addr);

    return (::connect(sd, (struct sockaddr*) &remote_addr, addr_len));
}

/*
 * Put the socket into the listening state. `backlog` goes straight to
 * listen(2), whose result (0, or -1 with errno set) is returned as-is.
 */
int transport::listen(int backlog)
{
    const int rc = ::listen(sd, backlog);
    return rc;
}

/*
 * Accept one pending connection on the listening socket. The peer's
 * address is stored in remote_addr.
 *
 * Returns the new connected descriptor; the caller owns it (for example
 * by wrapping it in transport(int)). Throws socketError(errno) on failure.
 * This includes EAGAIN/EWOULDBLOCK when the socket is non-blocking and no
 * connection is pending.
 */
int transport::accept(void)
{
    int nsd;

    // accept() reads the length as the buffer size on entry and writes the
    // actual address length on return. It must be reset on every call:
    // before any connect() it is uninitialised, and after a previous accept
    // it may be stale.
    addr_len = sizeof(remote_addr);

    if ((nsd = ::accept(sd, (struct sockaddr*) &remote_addr, &addr_len)) < 0)
        throw socketError(errno);

    return nsd;
}

void transport::kexInit(struct BigNum* p, struct BigNum* g)
{
    kex = new DH(RNG, p, g);
}

/*
 * Pass `f` to the key-exchange object's generate() step. `f` is presumably
 * the peer's public value; confirm against dh.h. Does nothing if kexInit()
 * has not been called yet.
 */
void transport::kexGenerate(struct BigNum* f)
{
    if (kex == NULL)
        return;

    kex->generate(f);
}

/*
 * Return the key-exchange object's public key, or NULL when no exchange
 * has been started with kexInit().
 */
struct BigNum* transport::kexPublic(void)
{
    struct BigNum* pub = NULL;

    if (kex != NULL)
        pub = kex->publickey();

    return pub;
}

void transport::setCipher(int type, bool client)
{
    unsigned char* key, *ivect;
    struct BigNum* privkey;
    int mlen, klen, ivlen;

    if (kex != NULL && (privkey = kex->privatekey()) != NULL)
    {
        klen = encipher->keylength(type);
        ivlen = encipher->blocklength(type);

        if (bnBytes(privkey) >= (klen + ivlen) * 2)
        {
            mlen = klen + ivlen;
        }
        else
        {
            mlen = 0;
        }

        key = new unsigned char[bnBytes(privkey)];
        bnExtractLittleBytes(privkey, key, 0, bnBytes(privkey));
        ivect = key + klen;

        if (client == true)
        {
            encipher = new cipher(type, key, klen, ivect, ivlen);
            decipher = new cipher(type, key+mlen, klen, ivect+mlen, ivlen);
        }
        else
        {
            encipher = new cipher(type, key+mlen, klen, ivect+mlen, ivlen);
            decipher = new cipher(type, key, klen, ivect, ivlen);
        }

        delete[] key;
    }
}

