#include "database.h"
#if defined(HAVE_CONFIG_H)
#include "config.h"
#endif
#include "linkopt.h"
#if defined(DBSUPP_MYSQL)
#include "mysql.h"
#endif
#include "omysql.h"
#include "messform.h"
#include <assert.h>
#include <map>

using namespace std;

// Factory for the storage backend used by Database; the definition
// appears further down in this file.
DatabaseProtocol *NewDatabase(CLinkOptions &options);

// DatabaseProtocol implementation that keeps all data in a std::map held
// by the object itself. NewDatabase() returns it whenever the MySQL
// backend is not selected.
class MemoryData : public DatabaseProtocol
{
public:
    MemoryData();

    // Adds url as a key of Checked; returns true only if it was not
    // already present.
    virtual bool InsertFound(const std::string &url);
    // Adds src_url to the set stored under dest_url.
    virtual void InsertLink(const std::string &src_url,
	const std::string &dest_url);
    // Returns a copy of the set stored under url, or an empty set if url
    // is not a key.
    virtual std::set<std::string> GetDependencies(
	const std::string &url);

private:
    // Key: a URL; value: the source URLs recorded for it by InsertLink().
    typedef map<string, set<string> > TChecked;
    TChecked Checked;
};

#if defined(DBSUPP_MYSQL)
// DatabaseProtocol implementation that stores found URLs and links in
// MySQL tables. Host, credentials, database, table and column names are
// all read from the link options.
class MysqlDatabase : public DatabaseProtocol
{
public:
    // Connects, builds the query templates and empties the found-URLs and
    // links tables.
    MysqlDatabase(CLinkOptions &options);
    // Closes the connection.
    virtual ~MysqlDatabase();

    // Inserts url into the found-URLs table and returns true. If the
    // insert fails with MySQL error 1062 (duplicate entry), the timestamp
    // column of the existing row is set to null and false is returned.
    // Any other failure throws QueryError.
    virtual bool InsertFound(const std::string &url);
    // Inserts the (src_url, dest_url) pair into the links table. A
    // duplicate-entry error is ignored; other failures throw QueryError.
    virtual void InsertLink(const std::string &src_url,
	const std::string &dest_url);
    // Returns the source-column values of every links row whose
    // destination column equals url.
    virtual std::set<std::string> GetDependencies(
	const std::string &url);

private:
    MYSQL *Sock;		// handle returned by mysql_connect()
    MYSQL MysqlState;		// state object passed to mysql_connect()
    // Query templates built once by InitFormats(); the remaining %0/%1
    // placeholders are filled with escaped URLs on every call.
    string InsertLinkFormat;
    string InsertFoundFormat;
    string UpdateFoundFormat;
    string SelectDepsFormat;

    // Opens the connection and selects the database.
    void Connect(CLinkOptions &options);
    // Fills the four *Format members.
    void InitFormats(CLinkOptions &options);
    // Runs "delete from <table>".
    void ClearTable(const string &table);

    // Runs q and throws QueryError if it fails.
    void Query(const string &q);

    // Escapes url for use inside a single-quoted SQL string literal.
    static string GetQuoted(const string &url);
};

// Scoped owner of a MySQL result set: the constructor obtains it with
// mysql_store_result() and the destructor releases it with
// mysql_free_result().
class MysqlResult
{
public:
    // The stored result set; never null once construction has succeeded.
    MYSQL_RES *Res;

    // Fetches the result of the last query run on Sock.
    // Throws ResultError if no result set is available.
    MysqlResult(MYSQL *Sock);
    ~MysqlResult();

private:
    // Copying would make two objects free the same MYSQL_RES. These are
    // declared but intentionally never defined, so any attempt to copy
    // fails at compile or link time.
    MysqlResult(const MysqlResult &);
    MysqlResult &operator=(const MysqlResult &);
};
#endif /* DBSUPP_MYSQL */

// Out-of-line definition of the DatabaseProtocol destructor; there is
// nothing to release at this level.
DatabaseProtocol::~DatabaseProtocol()
{
}

// Creates the storage backend selected by the options.
//
// When the file is built with DBSUPP_MYSQL and the o_mysql_support option
// is set, a MysqlDatabase is returned; in every other case a MemoryData.
// The object is allocated with new.
DatabaseProtocol *NewDatabase(CLinkOptions &options)
{
    DatabaseProtocol *backend = 0;

#if defined(DBSUPP_MYSQL)
    if (options.Get_bool(o_mysql_support))
        backend = new MysqlDatabase(options);
#endif

    // Nothing chosen above: fall back to the in-memory implementation.
    if (!backend)
        backend = new MemoryData();

    return backend;
}

// Creates the backend chosen by options and starts its reference count
// at one.
Database::Database(CLinkOptions &options):
    Impl(NewDatabase(options))
{
    int *counter = 0;

    try
    {
        counter = new int(1);
    }
    catch (...)
    {
        // The counter could not be allocated: release the backend that
        // was already created, then let the exception continue.
        delete Impl;
        throw;
    }

    RefCount = counter;
}

void Database::DelRef()
{
    if (--(*RefCount))
        return;

    delete RefCount;
    delete Impl;
}

// Copy constructor: shares other's backend and counter, then calls
// AddRef() (not defined in this file; presumably it increments the shared
// count -- confirm against the header).
Database::Database(const Database &other):
    Impl(other.Impl),
    RefCount(other.RefCount)
{
    AddRef();
}

// Gives up this handle's reference; DelRef() deletes the backend and the
// counter once the count reaches zero.
Database::~Database()
{
    DelRef();
}

// Makes this handle share other's backend and counter, giving up the
// reference it held before.
Database &Database::operator=(const Database &other)
{
    // Order matters: other.AddRef() runs before DelRef(). Assuming
    // AddRef() increments the shared count (it is not defined in this
    // file), this keeps self-assignment from dropping the count to zero
    // and deleting the backend that is about to be adopted.
    other.AddRef();
    DelRef();
    Impl = other.Impl;
    RefCount = other.RefCount;
    return *this;
}

// Forwards to the backend's InsertFound() and returns its result
// unchanged.
bool Database::InsertFound(const std::string &url)
{
    const bool result = Impl->InsertFound(url);
    return result;
}

// Forwards to the backend's InsertLink() with both URLs unchanged.
void Database::InsertLink(
    const std::string &src_url,
    const std::string &dest_url)
{
    (*Impl).InsertLink(src_url, dest_url);
}

// Forwards to the backend's GetDependencies() and returns the set it
// produced.
std::set<std::string> Database::GetDependencies(const std::string &url)
{
    std::set<std::string> deps = Impl->GetDependencies(url);
    return deps;
}

// Nothing to set up: Checked starts out as an empty map.
MemoryData::MemoryData()
{
}

bool MemoryData::InsertFound(const std::string &url)
{
    if (Checked.find(url) == Checked.end())
    {   Checked.insert(TChecked::value_type(url,
	    set<string>()));
	return true;
    }
    else
	return false;
}

void MemoryData::InsertLink(const std::string &src_url,
    const std::string &dest_url)
{
    Checked[dest_url].insert(src_url);
}

std::set<std::string> MemoryData::GetDependencies(
    const std::string &url)
{
    TChecked::const_iterator i = Checked.find(url);
    if (i == Checked.end())
	return set<string>();
    else
	return i->second;
}

#if defined(DBSUPP_MYSQL)
// Opens the connection, prepares the query templates and empties the
// found-URLs and links tables.
//
// Propagates whatever Connect(), InitFormats() or ClearTable() throw. If
// the failure happens after the connection was opened, the connection is
// closed before the exception leaves the constructor.
MysqlDatabase::MysqlDatabase(CLinkOptions &options)
{
    Connect(options);

    // The destructor does not run for an object whose constructor throws,
    // so from here on the open connection has to be closed by hand.
    try
    {
        InitFormats(options);
        // somebody should prevent multiple linkcheck instances
        // from running at the same time...
        ClearTable(options.Get(o_mysql_found_urls_table));
        ClearTable(options.Get(o_mysql_links_table));
    }
    catch (...)
    {
        mysql_close(Sock);
        throw;
    }
}

// Connects to the server named in the options and selects the configured
// database.
//
// Throws ConnectError if the connection cannot be made, and SelectError
// if the database cannot be selected (the connection is closed first).
void MysqlDatabase::Connect(CLinkOptions &options)
{
    const std::string host = options.Get(o_mysql_host);
    const std::string user = options.Get(o_mysql_user);
    const std::string password = options.Get(o_mysql_password);

    Sock = mysql_connect(&MysqlState, host.c_str(), user.c_str(),
        password.c_str());
    if (!Sock)
        throw ConnectError(mysql_error(&MysqlState));

    const std::string db_name = options.Get(o_mysql_database);
    if (mysql_select_db(Sock, db_name.c_str()) == 0)
        return;

    // Take a copy of the error text before the connection is closed.
    const std::string reason(mysql_error(Sock));
    mysql_close(Sock);
    throw SelectError(db_name, reason);
}

// Builds the four query templates from the table and column names in the
// options. Table/column placeholders are resolved here; the quoted
// '%0'/'%1' placeholders appended afterwards stay in the stored templates
// and are filled with URLs on every query.
void MysqlDatabase::InitFormats(CLinkOptions &options)
{
    // insert into <links> (<src>, <dest>) values('%0', '%1')
    MessageFormat link_head("insert into %0 (%1, %2) ");
    link_head << options.Get(o_mysql_links_table)
              << options.Get(o_mysql_links_src_column)
              << options.Get(o_mysql_links_dest_column);
    InsertLinkFormat = link_head.GetMessage();
    InsertLinkFormat.append("values('%0', '%1')");

    // insert into <found> (<url>) values('%0')
    MessageFormat found_head("insert into %0 (%1) ");
    found_head << options.Get(o_mysql_found_urls_table)
               << options.Get(o_mysql_found_urls_url_column);
    InsertFoundFormat = found_head.GetMessage();
    InsertFoundFormat.append("values('%0')");

    // update <found> set <timestamp>=null where <url>='%0'
    MessageFormat update_head("update %0 set %1=null where %2=");
    update_head << options.Get(o_mysql_found_urls_table)
                << options.Get(o_mysql_found_urls_timestamp)
                << options.Get(o_mysql_found_urls_url_column);
    UpdateFoundFormat = update_head.GetMessage();
    UpdateFoundFormat.append("'%0'");

    // select <src> from <links> where <dest>='%0'
    MessageFormat select_head("select %0 from %1 where %2=");
    select_head << options.Get(o_mysql_links_src_column)
                << options.Get(o_mysql_links_table)
                << options.Get(o_mysql_links_dest_column);
    SelectDepsFormat = select_head.GetMessage();
    SelectDepsFormat.append("'%0'");
}
   
// Closes the connection opened by Connect().
MysqlDatabase::~MysqlDatabase()
{
    mysql_close(Sock);
}

// Escapes url so it can be placed between single quotes in a MySQL
// statement.
//
//   '    becomes  ''   (quote doubling)
//   \    becomes  \\   (MySQL treats backslash as an escape character
//                       inside string literals unless the server runs
//                       with NO_BACKSLASH_ESCAPES)
//   NUL  becomes  \0   (the statement is passed to mysql_query() as a
//                       C string, so a raw NUL would cut it short)
//
// Every other character is copied unchanged. URLs come from crawled
// pages and must be treated as untrusted input.
string MysqlDatabase::GetQuoted(const string &url)
{
    string out;
    out.reserve(url.size());

    for (string::size_type i = 0; i < url.size(); ++i)
    {
        const char c = url[i];
        switch (c)
        {
        case '\'':
            out += "''";
            break;
        case '\\':
            out += "\\\\";
            break;
        case '\0':
            out += "\\0";
            break;
        default:
            out += c;
            break;
        }
    }

    return out;
}

bool MysqlDatabase::InsertFound(const std::string &url)
{
    MessageFormat ifrm(InsertFoundFormat);
    ifrm << GetQuoted(url);

    string iqr = ifrm.GetMessage();
    int r = mysql_query(Sock, iqr.c_str());
    if (!r)
	return true;

    int err = mysql_errno(Sock);        
    if (err != 1062)
	throw QueryError(err, iqr, mysql_error(Sock));

    // url already exists - update the record

    MessageFormat ufrm(UpdateFoundFormat);
    ufrm << GetQuoted(url);
    Query(ufrm.GetMessage());

    return false;
}

void MysqlDatabase::ClearTable(const string &table)
{
    MessageFormat f("delete from %0");
    f << table;
    Query(f.GetMessage());
}

// Runs q on the connection. Throws QueryError carrying the MySQL error
// number, the statement and the error text if the server rejects it.
void MysqlDatabase::Query(const string &q)
{
    if (mysql_query(Sock, q.c_str()) == 0)
        return;

    throw QueryError(mysql_errno(Sock), q, mysql_error(Sock));
}

void MysqlDatabase::InsertLink(const std::string &src_url,
    const std::string &dest_url)
{
    MessageFormat qf(InsertLinkFormat);
    qf << GetQuoted(src_url) << GetQuoted(dest_url);

    string qs = qf.GetMessage();
    int r = mysql_query(Sock, qs.c_str());
    if (r)
    {   // "Duplicate entry" error (1062) is harmless
	// (just a page pointing to another more than
	// once) - ignore it
	int err = mysql_errno(Sock);        
        if (err != 1062)
	    throw QueryError(err, qs, mysql_error(Sock));
    }
}

// Returns the source-column value of every links row whose destination
// column equals url.
//
// Throws QueryError if the select fails and ResultError if its result
// set cannot be retrieved.
std::set<std::string> MysqlDatabase::GetDependencies(
    const std::string &url)
{
    MessageFormat f(SelectDepsFormat);
    f << GetQuoted(url);

    Query(f.GetMessage());
    MysqlResult r(Sock);	// frees the result set when it goes out of scope

    std::set<std::string> deps;
    MYSQL_ROW row;
    while ((row = mysql_fetch_row(r.Res)))
    {
        // A NULL source column is not expected. The assert still flags
        // it in debug builds, but building a std::string from a null
        // pointer is undefined behaviour, so builds with NDEBUG skip
        // such a row instead.
        assert(row[0]);
        if (row[0])
            deps.insert(string(row[0]));
    }

    return deps;
}

// Takes the stored result of the last query on Sock.
// Throws ResultError with the connection's error text if
// mysql_store_result() returns no result set.
MysqlResult::MysqlResult(MYSQL *Sock):
    Res(mysql_store_result(Sock))
{
    if (Res == 0)
        throw ResultError(mysql_error(Sock));
}

// Releases the result set obtained by the constructor.
MysqlResult::~MysqlResult()
{
    mysql_free_result(Res);
}

#endif /* DBSUPP_MYSQL */

// Stores the error text in Message.
// NOTE(review): msg is used as given; callers in this file always pass
// the result of mysql_error() -- confirm that a null pointer can never
// reach this constructor.
ConnectError::ConnectError(const char *msg):
    Message(msg)
{
}

// Stores the name of the database that could not be selected together
// with the error text.
SelectError::SelectError(const std::string &db_name,
    const std::string &msg):
    DatabaseName(db_name),
    Message(msg)
{
}

// Stores the error number, the text of the failed statement and the
// error text.
QueryError::QueryError(int error_number,
    const std::string &query,
    const char *msg):
    ErrorNumber(error_number),
    Query(query),
    Message(msg)
{
}

// Stores the error text in Message.
ResultError::ResultError(const char *msg):
    Message(msg)
{
}

