/* $Id: op_acid_db.c,v 1.16 2004/04/03 19:57:32 andrewbaker Exp $ */
/*
** Copyright (C) 2001-2002 Andrew R. Baker <andrewb@snort.org>
**
** This program is distributed under the terms of version 1.0 of the 
** Q Public License.  See LICENSE.QPL for further details.
**
** This program is distributed in the hope that it will be useful,
** but WITHOUT ANY WARRANTY; without even the implied warranty of
** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
**
*/

/*  I N C L U D E S  *****************************************************/


#include <string.h>
#include <stdlib.h>
#include <sys/types.h>
#include <netinet/in.h>
#include <time.h>
#include <errno.h>
#include <unistd.h>

#include "ConfigFile.h"
#include "mstring.h"
#include "sid.h"
#include "classification.h"
#include "util.h"
#include "input-plugins/dp_alert.h"
#include "input-plugins/dp_log.h"
#include "op_plugbase.h"
#include "op_decode.h"
#include "event.h"

#include <ctype.h>

#ifdef ENABLE_MYSQL
#include <mysql.h>
#include <errmsg.h>
#endif /* ENABLE_MYSQL */

#ifdef ENABLE_POSTGRES
#include <libpq-fe.h>
#endif /* ENABLE_POSTGRES */

/*  D A T A   S T R U C T U R E S  **************************************/
/*
 * Run time context of one acid_db output plugin instance.  It is allocated
 * and filled in by OpAcidDb_ParseArgs() and stored on the plugin node.
 */
typedef struct _OpAcidDb_Data 
{
    u_int8_t flavor;  /* what flavor of db?  MySQL, postgres, ... */
    u_int8_t detail;  /* set to 1 by "detail full"; tested as a boolean */
    u_int16_t unused;
    char *server;     /* "server" argument, NULL when not given */
    char *database;   /* "database" argument, NULL when not given */
    char *user;       /* "user" argument, NULL when not given */
    char *password;   /* "password" argument, NULL when not given */
    int sensor_id;    /* "sensor_id" argument; 0 means look it up at start */
    u_int32_t event_id; /* cid for the next event, bumped after each record */
    int linktype;     /* not referenced in the visible part of this file */
    /* db handles go here */
#ifdef ENABLE_MYSQL
    MYSQL *mysql;
#endif /* ENABLE_MYSQL */
#ifdef ENABLE_POSTGRES
    PGconn *pq;
#endif /* ENABLE_POSTGRES */
} OpAcidDb_Data;


/* Size in bytes of the buffer each SQL statement is formatted into */
#define MAX_QUERY_SIZE 8192

/* database flavor defines */
#define FLAVOR_MYSQL    1
#define FLAVOR_POSTGRES 2

/* Printable flavor names indexed by the FLAVOR_* values; slot 0 (no flavor
 * selected) has no name
 */
char *db_flavours[] = {NULL, "mysql", "postgres"};

/* Output plugin API functions */
/* These are installed as callbacks on the plugin node by OpAcidDb_Init() */
int OpAcidDb_Setup(OutputPlugin *, char *args);
int OpAcidDb_Exit(OutputPlugin *);
int OpAcidDb_Start(OutputPlugin *, void *);
int OpAcidDb_Stop(OutputPlugin *);
int OpAcidDb_Log(void *, void *);
int OpAcidDb_Alert(void *, void *);
int OpAcidDb_LogConfig(OutputPlugin *outputPlugin);

/* Internal functions */
OpAcidDb_Data *OpAcidDb_ParseArgs(char *);
int DbClose(OpAcidDb_Data *data);
int DbConnect(OpAcidDb_Data *data);
/* NOTE(review): the next two are declared as returning u_int32_t but defined
 * as returning unsigned int; this only agrees where the two types match.
 */
u_int32_t AcidDbGetNextCid(OpAcidDb_Data *data);
u_int32_t AcidDbGetSensorId(OpAcidDb_Data *data);
int AcidDbCheckSchemaVersion(OpAcidDb_Data *data);
u_int32_t AcidDbGetSigId(OpAcidDb_Data *, Sid *, ClassType *, 
        unsigned int priority);
unsigned int GetAcidDbClassId(OpAcidDb_Data *data, ClassType *class_type);
int InsertSigReferences(OpAcidDb_Data *, ReferenceData *, unsigned int sig_id);
/* Per-layer writers used by OpAcidDb_Log() on a decoded packet */
int InsertIPData(OpAcidDb_Data *data, Packet *p);
int InsertICMPData(OpAcidDb_Data *data, Packet *p);
int InsertUDPData(OpAcidDb_Data *data, Packet *p);
int InsertTCPData(OpAcidDb_Data *data, Packet *p);
int InsertPayloadData(OpAcidDb_Data *data, Packet *p);

/* Flavor independent query helpers.  Their definitions are not in the
 * visible part of this file; presumably they dispatch on data->flavor to the
 * Mysql and Postgres functions below (TODO confirm).
 */
int SelectAsUInt(OpAcidDb_Data *data, char *sql, unsigned int *result);
int Insert(OpAcidDb_Data *data, char *sql, unsigned int *row_id);
char *EscapeString(OpAcidDb_Data *data, char *string);

#ifdef ENABLE_MYSQL
int MysqlConnect(OpAcidDb_Data *);
int MysqlClose(MYSQL *mysql);
int MysqlSelectAsUInt(MYSQL *mysql, char *sql, unsigned int *result);
int MysqlInsert(MYSQL *mysql, char *sql, unsigned int *row_id);
char *MysqlEscapeString(MYSQL *mysql, char *string);
#endif

#ifdef ENABLE_POSTGRES
int PostgresConnect(OpAcidDb_Data *);
int PostgresClose(PGconn *);
int PostgresSelectAsUInt(PGconn *, char *sql, unsigned int *result);
/* Unlike MysqlInsert() this takes no row_id output parameter */
int PostgresInsert(PGconn *, char *sql);
char *PostgresEscapeString(PGconn *, char *string);
#endif /* ENABLE_POSTGRES */

/* Global variables */
/* Scratch buffer the functions below format each SQL statement into before
 * handing it to SelectAsUInt() or Insert().  It is shared, so a statement
 * must be executed before the next one is formatted.
 */
static char sql_buffer[MAX_QUERY_SIZE];

/* init routine makes this processor available for dataprocessor directives */
void OpAcidDb_Init()
{
    OutputPlugin *outputPlugin;

    outputPlugin = RegisterOutputPlugin("alert_acid_db", "alert");
    outputPlugin->setupFunc = OpAcidDb_Setup;
    outputPlugin->exitFunc = OpAcidDb_Exit;
    outputPlugin->startFunc = OpAcidDb_Start;
    outputPlugin->stopFunc = OpAcidDb_Stop;
    outputPlugin->outputFunc = OpAcidDb_Alert;
    outputPlugin->logConfigFunc = OpAcidDb_LogConfig;
    
    outputPlugin = RegisterOutputPlugin("log_acid_db", "log");
    outputPlugin->setupFunc = OpAcidDb_Setup;
    outputPlugin->exitFunc = OpAcidDb_Exit;
    outputPlugin->startFunc = OpAcidDb_Start;
    outputPlugin->stopFunc = OpAcidDb_Stop;
    outputPlugin->outputFunc = OpAcidDb_Log;
    outputPlugin->logConfigFunc = OpAcidDb_LogConfig;
    
}


/* Setup hook: parses the plugin arguments into a run time context and
 * stores that context on the output plugin node.
 *
 * Returns 0.
 */
int OpAcidDb_Setup(OutputPlugin *outputPlugin, char *args)
{
    OpAcidDb_Data *context = OpAcidDb_ParseArgs(args);

    outputPlugin->data = context;
    return 0;
}

/* Inverse of the setup function: releases the context that
 * OpAcidDb_ParseArgs() built (the duplicated argument strings and the
 * context itself).  The outputPlugin is not freed since it is also the list
 * node itself.  Database handles are not touched here; they are closed by
 * OpAcidDb_Stop().
 *
 * NOTE(review): assumes memory from SafeAlloc() may be released with free()
 * and that the exit hook runs after the stop hook - confirm.
 *
 * Returns 0.
 */
int OpAcidDb_Exit(OutputPlugin *outputPlugin)
{
    OpAcidDb_Data *data;

    if(outputPlugin == NULL || outputPlugin->data == NULL)
        return 0;

    data = (OpAcidDb_Data *)outputPlugin->data;

    free(data->server);
    free(data->database);
    free(data->user);
    free(data->password);
    free(data);

    /* make a repeated call or a late Stop see that the context is gone */
    outputPlugin->data = NULL;

    return 0;
}

/* Logs the configuration of one acid_db plugin instance: database flavour,
 * detail level, server and user.  The password is not logged.
 *
 * Settings that were never given are NULL in the context; a placeholder is
 * printed for them because NULL must not be passed for a %s conversion.
 *
 * Returns 0, or -1 when there is no plugin node or no context.
 */
int OpAcidDb_LogConfig(OutputPlugin *outputPlugin)
{
    OpAcidDb_Data *data = NULL;
    const char *flavour = NULL;

    if(!outputPlugin || !outputPlugin->data)
        return -1;

    data = (OpAcidDb_Data *)outputPlugin->data;

    /* only index the name table with a value it actually covers */
    if(data->flavor < sizeof(db_flavours) / sizeof(db_flavours[0]))
        flavour = db_flavours[data->flavor];

    LogMessage("OpAcidDB configured\n");
    LogMessage("  Database Flavour: %s\n", flavour ? flavour : "(unknown)");
    LogMessage("  Detail Level: %s\n", data->detail == 1 ? "Full" : "Fast");
    LogMessage("  Database Server: %s\n", 
            data->server ? data->server : "(not set)");
    LogMessage("  Database User: %s\n", 
            data->user ? data->user : "(not set)");
    return 0;
}

/* 
 * Start hook: logs the configuration when verbose, connects to the
 * database, checks the schema, determines the sensor id when none was
 * configured and fetches the first event id (cid) to use.
 *
 * outputPlugin - plugin node whose data member holds the OpAcidDb_Data
 * spool_header - not used by this plugin
 *
 * A missing context, a failed connect and a schema mismatch are reported
 * through FatalError().  Returns 0.
 */
int OpAcidDb_Start(OutputPlugin *outputPlugin, void *spool_header)
{
    OpAcidDb_Data *data = (OpAcidDb_Data *)outputPlugin->data;

    if(data == NULL)
        FatalError("ERROR: Unable to find context for AcidDb startup!\n");
    
    if(pv.verbose)
    {
        OpAcidDb_LogConfig(outputPlugin);
    }
    
    /* Connect to the database.  The password is deliberately left out of
     * the failure message so that it cannot end up in a log, and settings
     * that were never given are replaced by a placeholder since NULL must
     * not be passed for %s.
     */
    if(DbConnect(data))
        FatalError("OpAcidDb_: Failed to connect to database: %s@%s/%s\n",
                data->user ? data->user : "(not set)",
                data->server ? data->server : "(not set)",
                data->database ? data->database : "(not set)");


    /* check the db schema */
    if(AcidDbCheckSchemaVersion(data))
        FatalError("OpAcidDb_: database schema mismatch\n");
 
    /* if sensor id == 0, then we attempt to determine it dynamically */
    if(data->sensor_id == 0)
    {
        data->sensor_id = AcidDbGetSensorId(data);
        /* XXX: Error checking */
    }
    /* Get the next cid from the database */
    data->event_id = AcidDbGetNextCid(data);

    if(pv.verbose)
    {
        LogMessage("SensorID: %i\n", data->sensor_id);
        /* event_id is unsigned */
        LogMessage("Next CID: %u\n", (unsigned int)data->event_id);
    }
    return 0;
}

/* Stop hook: closes the database connection of this plugin instance.
 *
 * A missing context is reported through FatalError().  Returns 0; the
 * result of DbClose() is not propagated.
 */
int OpAcidDb_Stop(OutputPlugin *outputPlugin)
{
    OpAcidDb_Data *data = (OpAcidDb_Data *)outputPlugin->data;

    if(data == NULL)
        FatalError("ERROR: Unable to find context for AcidDb shutdown!\n");

    /* close database connection */
    DbClose(data);
    
    return 0;
}

int OpAcidDb_Alert(void *context, void *data)
{
    char timestamp[TIMEBUF_SIZE];
    Sid *sid = NULL;
    ClassType *class_type = NULL;
    UnifiedAlertRecord *record = (UnifiedAlertRecord *)data; 
    OpAcidDb_Data *op_data = (OpAcidDb_Data *)context;
    u_int32_t acid_sig_id;

    RenderTimestamp(record->ts.tv_sec, timestamp, TIMEBUF_SIZE);
    sid = GetSid(record->event.sig_generator, record->event.sig_id);
    if(sid == NULL)
        sid = FakeSid(record->event.sig_generator, record->event.sig_id);
    
    
    if(!(class_type = GetClassType(record->event.classification)) 
            && record->event.classification != 0)
    {
        LogMessage("WARNING: No ClassType found for classification '%i'\n",
                record->event.classification);
    }
    
    if((acid_sig_id = AcidDbGetSigId(op_data, sid, class_type, 
            record->event.priority)) == 0)
    {
        FatalError("op_acid_db:  Failed to retrieve ACID DB sig id\n");
    }
    
    /* Insert data into the event table */
    if(snprintf(sql_buffer, MAX_QUERY_SIZE,
                "INSERT INTO event(sid, cid, signature, timestamp) "
                "VALUES('%u', '%u', '%u', '%s')",
                op_data->sensor_id, op_data->event_id, acid_sig_id,
                timestamp) < MAX_QUERY_SIZE)
    {
        //LogMessage("SQL: %s\n", sql_buffer);
        Insert(op_data, sql_buffer, NULL);     /* XXX: Error checking */
    }
    /* insert data into the ip header table */
    if(snprintf(sql_buffer, MAX_QUERY_SIZE,
            "INSERT INTO iphdr(sid, cid, ip_src, ip_dst, ip_proto) "
            "VALUES('%u', '%u', '%u', '%u', '%u')",
            op_data->sensor_id, op_data->event_id, record->sip,
            record->dip, record->protocol) < MAX_QUERY_SIZE)
    {
        Insert(op_data, sql_buffer, NULL); /* XXX: Error checking */
    }
    /* build the protocol specific header information */
    switch(record->protocol)
    {
        case IPPROTO_TCP:
            if(snprintf(sql_buffer, MAX_QUERY_SIZE,
                    "INSERT INTO tcphdr (sid, cid, tcp_sport, tcp_dport, "
                    "tcp_flags) VALUES('%u', '%u', '%u', '%u', 0)", 
                    op_data->sensor_id, op_data->event_id, record->sp,
                    record->dp) < MAX_QUERY_SIZE)
            {
                Insert(op_data, sql_buffer, NULL); /* XXX: Error checking */
            }
            break;
        case IPPROTO_UDP:
            if(snprintf(sql_buffer, MAX_QUERY_SIZE,
                    "INSERT INTO udphdr (sid, cid, udp_sport, udp_dport) "
                    "VALUES('%u', '%u', '%u', '%u')", 
                    op_data->sensor_id, op_data->event_id, record->sp,
                    record->dp) < MAX_QUERY_SIZE)
            {
                Insert(op_data, sql_buffer, NULL);  /* XXX: Error checking */
            }
            break;
        case IPPROTO_ICMP:
            if(snprintf(sql_buffer, MAX_QUERY_SIZE,
                    "INSERT INTO icmphdr (sid, cid, icmp_type, icmp_code) "
                    "VALUES('%u', '%u', '%u', '%u')", 
                    op_data->sensor_id, op_data->event_id, record->sp,
                    record->dp) < MAX_QUERY_SIZE)
            {
                Insert(op_data, sql_buffer, NULL); /* XXX: Error Checking */
            }
            break;
    }
    ++op_data->event_id;
    return 0;
}

/* Stores one unified log record: a row in the event table and, when the
 * packet decodes and carries an IP header, the IP header, the layer 4 header
 * of unfragmented ICMP/TCP/UDP packets and (with detail enabled) the
 * payload.  The event id (cid) is advanced afterwards.
 *
 * context - the OpAcidDb_Data of this plugin instance
 * data    - the UnifiedLogRecord to store
 *
 * A signature id that cannot be obtained is fatal.  Returns 0.
 */
int OpAcidDb_Log(void *context, void *data)
{
    UnifiedLogRecord *record = (UnifiedLogRecord *)data; 
    OpAcidDb_Data *op_data = (OpAcidDb_Data *)context;
    char timestamp[TIMEBUF_SIZE];
    Sid *sid = NULL;
    ClassType *class_type;
    u_int32_t acid_sig_id;
    Packet p;
    int len;

#if 0 /* this is broken */
    /* skip tagged packets, since the db does not have a mechanism to 
     * deal with them properly
     */
    if(record->log.event.event_reference)
    {
        LogMessage("Skipping tagged packet %i\n", record->log.event.event_reference);
        return 0;
    }
#endif

    RenderTimestamp(record->log.pkth.ts.tv_sec, timestamp, TIMEBUF_SIZE);

    /* fall back to a made up sid when the rule is not known */
    sid = GetSid(record->log.event.sig_generator, record->log.event.sig_id);
    if(sid == NULL)
    {
        sid = FakeSid(record->log.event.sig_generator, 
                record->log.event.sig_id);
    }
    class_type = GetClassType(record->log.event.classification);

    acid_sig_id = AcidDbGetSigId(op_data, sid, class_type, 
            record->log.event.priority);
    if(acid_sig_id == 0)
    {
        FatalError("op_acid_db:  Failed to retrieve ACID DB sig id\n");
    }

    /* event table */
    len = snprintf(sql_buffer, MAX_QUERY_SIZE,
            "INSERT INTO event(sid, cid, signature, timestamp) "
            "VALUES('%u', '%u', '%u', '%s')", 
            op_data->sensor_id, op_data->event_id, acid_sig_id, timestamp);
    if(len < MAX_QUERY_SIZE)
        Insert(op_data, sql_buffer, NULL);  /* XXX: Error checking */

    /* Decode the packet; only packets with an IP header are stored.
     * NOTE(review): the reason for the + 2 offset into the packet buffer is
     * not visible here - confirm against the record layout.
     */
    if(DecodePacket(&p, &record->log.pkth, record->pkt + 2) == 0 && p.iph)
    {
        InsertIPData(op_data, &p);

        /* layer 4 data is stored for non fragmented packets only */
        if(!(p.pkt_flags & PKT_FRAG_FLAG))
        {
            if(p.iph->ip_proto == IPPROTO_ICMP)
                InsertICMPData(op_data, &p);
            else if(p.iph->ip_proto == IPPROTO_TCP)
                InsertTCPData(op_data, &p);
            else if(p.iph->ip_proto == IPPROTO_UDP)
                InsertUDPData(op_data, &p);
        }

        if(op_data->detail)
            InsertPayloadData(op_data, &p);
    }

    ++op_data->event_id;
    return 0;
}

/* Inserts the IP header of a decoded packet into the iphdr table for the
 * current sensor id and event id.
 *
 * With detail enabled every header field is stored, otherwise only the
 * source address, destination address and protocol.  The caller must make
 * sure p->iph is set; it is dereferenced without a check.
 *
 * Returns 0.  A statement that does not fit into sql_buffer is skipped and
 * the result of the insert is not checked.
 */
int InsertIPData(OpAcidDb_Data *op_data, Packet *p)
{
    if(op_data->detail)
    {
        /* ip_flags and ip_off are both cut out of the header's ip_off
         * member.  The masks differ per host byte order; presumably the
         * member is kept as read from the wire, so that the three flag bits
         * sit in the first byte (TODO confirm against the decoder).
         * htons() is used for the swap where ntohs() would state the intent;
         * both perform the same conversion.
         */
        if(snprintf(sql_buffer, MAX_QUERY_SIZE,
                "INSERT INTO iphdr(sid, cid, ip_src, ip_dst, ip_proto, "
                "ip_ver, ip_hlen, ip_tos, ip_len, ip_id, ip_flags, ip_off, "
                "ip_ttl, ip_csum) VALUES('%u', '%u', '%u', '%u', '%u', "
                "'%u', '%u', '%u', '%u', '%u', '%u', '%u', "
                "'%u', '%u')",
                op_data->sensor_id, op_data->event_id, 
                ntohl(p->iph->ip_src.s_addr), ntohl(p->iph->ip_dst.s_addr), 
                p->iph->ip_proto, IP_VER(p->iph), IP_HLEN(p->iph),
                p->iph->ip_tos, ntohs(p->iph->ip_len), ntohs(p->iph->ip_id),
#if defined(WORDS_BIGENDIAN)
                ((p->iph->ip_off & 0xE000) >> 13),
                htons(p->iph->ip_off & 0x1FFF),
#else
                ((p->iph->ip_off & 0x00E0) >> 5),
                htons(p->iph->ip_off & 0xFF1F), 
#endif
                p->iph->ip_ttl,
                htons(p->iph->ip_csum)) < MAX_QUERY_SIZE)
        {
            Insert(op_data, sql_buffer, NULL);  /* XXX: Error Checking */
        }
        /* XXX: IP Options not handled */
    }
    else
    {
        /* fast mode: addresses and protocol only */
        if(snprintf(sql_buffer, MAX_QUERY_SIZE,
                "INSERT INTO iphdr(sid, cid, ip_src, ip_dst, ip_proto) "
                "VALUES('%u', '%u', '%u', '%u', '%u')",
                op_data->sensor_id, op_data->event_id, 
                ntohl(p->iph->ip_src.s_addr), ntohl(p->iph->ip_dst.s_addr), 
                p->iph->ip_proto) < MAX_QUERY_SIZE)
        {
            Insert(op_data, sql_buffer, NULL);  /* XXX: Error Checking */
        }
    }
    return 0;
}

/* Inserts the UDP header of a decoded packet into the udphdr table for the
 * current sensor id and event id.  With detail enabled the length and the
 * checksum are stored in addition to the ports.
 *
 * Returns 0; nothing is done for a packet without a UDP header.  A statement
 * that does not fit into sql_buffer is skipped.
 */
int InsertUDPData(OpAcidDb_Data *op_data, Packet *p)
{
    int len;

    if(p->udph == NULL)
        return 0;

    if(!op_data->detail)
    {
        /* fast mode: ports only */
        len = snprintf(sql_buffer, MAX_QUERY_SIZE,
                "INSERT INTO udphdr (sid, cid, udp_sport, udp_dport) "
                "VALUES('%u', '%u', '%u', '%u')", 
                op_data->sensor_id, op_data->event_id, p->sp, p->dp);
    }
    else
    {
        len = snprintf(sql_buffer, MAX_QUERY_SIZE,
                "INSERT INTO udphdr(sid, cid, udp_sport, udp_dport, udp_len, "
                "udp_csum) VALUES('%u', '%u', '%u', '%u', '%u', '%u')", 
                op_data->sensor_id, op_data->event_id, p->sp, p->dp, 
                ntohs(p->udph->uh_len), ntohs(p->udph->uh_chk));
    }

    if(len < MAX_QUERY_SIZE)
        Insert(op_data, sql_buffer, NULL);  /* XXX: Error Checking */

    return 0;
}

/* Inserts the TCP header of a decoded packet into the tcphdr table for the
 * current sensor id and event id.  With detail enabled all header fields
 * are stored, otherwise only the ports.
 *
 * Returns 0; nothing is done for a packet without a TCP header.  A statement
 * that does not fit into sql_buffer is skipped.
 */
int InsertTCPData(OpAcidDb_Data *op_data, Packet *p)
{
    int len;

    if(p->tcph == NULL)
        return 0;

    if(!op_data->detail)
    {
        /* fast mode: ports only */
        len = snprintf(sql_buffer, MAX_QUERY_SIZE,
                "INSERT INTO tcphdr (sid, cid, tcp_sport, tcp_dport) "
                "VALUES('%u', '%u', '%u', '%u')", 
                op_data->sensor_id, op_data->event_id, p->sp, p->dp);
    }
    else
    {
        len = snprintf(sql_buffer, MAX_QUERY_SIZE,
                "INSERT INTO tcphdr(sid, cid, tcp_sport, tcp_dport, tcp_seq, "
                "tcp_ack, tcp_off, tcp_res, tcp_flags, tcp_win, tcp_csum, "
                "tcp_urp) VALUES('%u', '%u', '%u', '%u', '%u', "
                "'%u', '%u', '%u', '%u', '%u', '%u', '%u')",
                op_data->sensor_id, op_data->event_id, p->sp, p->dp, 
                ntohl(p->tcph->th_seq), ntohl(p->tcph->th_ack),
                TCP_OFFSET(p->tcph), TCP_X2(p->tcph), p->tcph->th_flags,
                ntohs(p->tcph->th_win), ntohs(p->tcph->th_sum),
                ntohs(p->tcph->th_urp));
        /* XXX: TCP Options not handled */
    }

    if(len < MAX_QUERY_SIZE)
        Insert(op_data, sql_buffer, NULL);  /* XXX: Error checking */

    return 0;
}

/* Inserts the ICMP header of a decoded packet into the icmphdr table for
 * the current sensor id and event id.
 *
 * Without detail only type and code are stored.  With detail the checksum
 * is added, and for the message types 0, 8, 13, 14, 15 and 16 the id and
 * sequence number as well.
 *
 * Returns 0; nothing is done for a packet without an ICMP header.  A
 * statement that does not fit into sql_buffer is skipped.
 */
int InsertICMPData(OpAcidDb_Data *op_data, Packet *p)
{
    int len;

    if(p->icmph == NULL)
        return 0;

    if(!op_data->detail)
    {
        /* fast mode: type and code only */
        len = snprintf(sql_buffer, MAX_QUERY_SIZE,
                "INSERT INTO icmphdr (sid, cid, icmp_type, icmp_code) "
                "VALUES('%u', '%u', '%u', '%u')", 
                op_data->sensor_id, op_data->event_id, p->icmph->icmp_type,
                p->icmph->icmp_code);
    }
    else
    {
        switch(p->icmph->icmp_type)
        {
            case 0:
            case 8:
            case 13:
            case 14:
            case 15:
            case 16:
                /* these types also get their id and sequence stored */
                len = snprintf(sql_buffer, MAX_QUERY_SIZE,
                        "INSERT INTO icmphdr(sid, cid, icmp_type, icmp_code, "
                        "icmp_csum, icmp_id, icmp_seq) "
                        "VALUES('%u', '%u', '%u', '%u', '%u', '%u', '%u')", 
                        op_data->sensor_id, op_data->event_id, 
                        p->icmph->icmp_type, p->icmph->icmp_code, 
                        ntohs(p->icmph->icmp_csum),
                        htons(p->icmph->icmp_hun.ih_idseq.icd_id),
                        htons(p->icmph->icmp_hun.ih_idseq.icd_seq));
                break;

            default:
                len = snprintf(sql_buffer, MAX_QUERY_SIZE,
                        "INSERT INTO icmphdr(sid, cid, icmp_type, icmp_code, "
                        "icmp_csum) VALUES('%u', '%u', '%u', '%u', '%u')", 
                        op_data->sensor_id, op_data->event_id, 
                        p->icmph->icmp_type, p->icmph->icmp_code, 
                        ntohs(p->icmph->icmp_csum));
                break;
        }
    }

    if(len < MAX_QUERY_SIZE)
        Insert(op_data, sql_buffer, NULL);  /* XXX: Error checking */

    return 0;
}

/* Inserts the payload of a decoded packet, hex encoded by fasthex(), into
 * the data table for the current sensor id and event id.
 *
 * Returns 0 when the payload was handed to Insert() or the packet has no
 * payload.  Returns -1, after logging a warning, when the payload could not
 * be hex encoded or when the statement does not fit into sql_buffer; in
 * both cases nothing is sent to the database.  The result of Insert()
 * itself is not checked.
 */
int InsertPayloadData(OpAcidDb_Data *op_data, Packet *p)
{
    char *hex_payload;
    int len;

    if(!p->dsize)
        return 0;

    hex_payload = fasthex(p->data, p->dsize);
    if(hex_payload == NULL)
    {
        /* NULL must not be formatted with %s */
        LogMessage("WARNING: Unable to hex encode a %u byte payload, "
                "payload not stored\n", (unsigned int)p->dsize);
        return -1;
    }

    len = snprintf(sql_buffer, MAX_QUERY_SIZE,
            "INSERT INTO data(sid, cid, data_payload) "
            "VALUES('%u', '%u', '%s')", op_data->sensor_id, 
            op_data->event_id, hex_payload);
    free(hex_payload);

    /* a negative result means the buffer holds no usable statement, a too
     * large one that the statement was cut off
     */
    if(len < 0 || len >= MAX_QUERY_SIZE)
    {
        LogMessage("WARNING: %u byte payload does not fit into a query, "
                "payload not stored\n", (unsigned int)p->dsize);
        return -1;
    }

    Insert(op_data, sql_buffer, NULL);  /* XXX: Error Checking */
    return 0;
}


/* Looks up the id of the sensor row that matches hostname, interface,
 * filter and detail level (encoding is always '0').
 *
 * A NULL hostname is replaced by "localhost", a NULL interface or filter by
 * the empty string.  All three strings are escaped before they are put into
 * the query.
 *
 * The id is stored through sensor_id by SelectAsUInt() and that function's
 * result is returned.  An escape failure or an oversized query is reported
 * through FatalError().
 */
static int OpAcidDb_GetSensorId(OpAcidDb_Data *op_data, char *hostname,
        char *interface, char *bpf_filter, u_int32_t detail, 
        u_int32_t *sensor_id)
{
    char *host_sql;
    char *iface_sql;
    char *filter_sql;
    int len;
    int status;

    host_sql = EscapeString(op_data, hostname ? hostname : "localhost");
    if(host_sql == NULL)
        FatalError("Failed to escape string");

    iface_sql = EscapeString(op_data, interface ? interface : "");
    if(iface_sql == NULL)
        FatalError("Failed to escape string");

    filter_sql = EscapeString(op_data, bpf_filter ? bpf_filter : "");
    if(filter_sql == NULL)
        FatalError("Failed to escape string");

    len = snprintf(sql_buffer, MAX_QUERY_SIZE, 
            "SELECT sid FROM sensor WHERE hostname='%s' AND interface='%s' "
            "AND filter='%s' AND detail='%u' AND encoding='0'", host_sql,
            iface_sql, filter_sql, detail);
    if(len >= MAX_QUERY_SIZE)
    {
        FatalError("SQL query too big");
        return -1;
    }

    status = SelectAsUInt(op_data, sql_buffer, sensor_id);

    free(filter_sql);
    free(iface_sql);
    free(host_sql);
    return status;
}

/* Creates a sensor row for hostname, interface, filter and detail level
 * (encoding and last_cid are both '0').
 *
 * A NULL hostname is replaced by "localhost", a NULL interface or filter by
 * the empty string.  All three strings are escaped before they are put into
 * the statement.
 *
 * The new row id is stored through sensor_id by Insert() and that
 * function's result is returned.  An escape failure or an oversized
 * statement is reported through FatalError().
 */
static int OpAcidDb_InsertSensor(OpAcidDb_Data *op_data, char *hostname,
        char *interface, char *bpf_filter, u_int32_t detail, 
        unsigned int *sensor_id)
{
    char *host_sql;
    char *iface_sql;
    char *filter_sql;
    int len;
    int status;

    host_sql = EscapeString(op_data, hostname ? hostname : "localhost");
    if(host_sql == NULL)
        FatalError("Failed to escape string");

    iface_sql = EscapeString(op_data, interface ? interface : "");
    if(iface_sql == NULL)
        FatalError("Failed to escape string");

    filter_sql = EscapeString(op_data, bpf_filter ? bpf_filter : "");
    if(filter_sql == NULL)
        FatalError("Failed to escape string");

    len = snprintf(sql_buffer, MAX_QUERY_SIZE, "INSERT INTO sensor(hostname, "
            "interface, filter, detail, encoding, last_cid) "
            "VALUES('%s', '%s', '%s', '%u', '0', '0')", host_sql,
            iface_sql, filter_sql, detail);
    if(len >= MAX_QUERY_SIZE)
    {
        FatalError("SQL query too big");
        return -1;
    }

    status = Insert(op_data, sql_buffer, sensor_id);

    free(filter_sql);
    free(iface_sql);
    free(host_sql);
    return status;
}


/* Determines the sensor id for the running program from pv.hostname,
 * pv.interface, pv.bpf_filter and the configured detail level.
 *
 * The sensor table is queried first.  When that yields no id a sensor row
 * is inserted, and when the insert reports -1 as the row id the table is
 * queried once more.
 *
 * Returns the sensor id, which is still 0 when none could be determined.
 */
unsigned int AcidDbGetSensorId(OpAcidDb_Data *op_data)
{
    unsigned int sensor_id = 0;
    /* an interface that was not specified is looked up as the empty string */
    char *interface = pv.interface ? pv.interface : "";

    OpAcidDb_GetSensorId(op_data, pv.hostname, interface, pv.bpf_filter,
            op_data->detail, &sensor_id);

    if(sensor_id == 0)
    {
        OpAcidDb_InsertSensor(op_data, pv.hostname, interface, pv.bpf_filter,
                op_data->detail, &sensor_id);

        /* the insert did not hand back a usable row id, ask again */
        if(sensor_id == (unsigned int)-1)
        {
            OpAcidDb_GetSensorId(op_data, pv.hostname, interface, 
                    pv.bpf_filter, op_data->detail, &sensor_id);
        }
    }

    if(pv.verbose >= 2)
        LogMessage("sensor_id == %u\n", sensor_id);

    return sensor_id;
}

/* Retrieves the next cid to use for events of this sensor: one more than
 * the highest cid the event table holds for data->sensor_id.
 *
 * A failed query (SelectAsUInt() returning -1) or an oversized statement is
 * reported through FatalError().
 */
unsigned int AcidDbGetNextCid(OpAcidDb_Data *data)
{
    unsigned int cid = 0;
    int len;

    len = snprintf(sql_buffer, MAX_QUERY_SIZE, 
            "SELECT max(cid) FROM event WHERE sid='%u'", data->sensor_id);
    if(len >= MAX_QUERY_SIZE)
    {
        FatalError("Database Error\n");
        return cid + 1;
    }

    if(SelectAsUInt(data, sql_buffer, &cid) == -1)
    {
        FatalError("Database Error\n");
    }
#ifdef DEBUG
    LogMessage("cid == %u\n", cid); fflush(stdout);
#endif

    return cid + 1;
}

/* Schema version check.  This is a stub: the database is not consulted and
 * 0 (no mismatch) is always returned.
 */
int AcidDbCheckSchemaVersion(OpAcidDb_Data *data)
{
    (void)data;

    return 0;
}


/* 
 * Looks up the sig_id of the signature row with the given name, revision
 * and sid.  A NULL msg is looked up as the empty string; the name is
 * escaped before it is put into the query.
 *
 * The id is stored through sig_id by SelectAsUInt() and that function's
 * result is returned; callers treat 1 as success.  An escape failure or an
 * oversized query is reported through FatalError().
 */
static int OpAcidDb_GetSigId(OpAcidDb_Data *op_data, char *msg, u_int32_t rev,
        u_int32_t sid, u_int32_t *sig_id)
{
    char *name_sql;
    int len;
    int status;

    name_sql = EscapeString(op_data, msg ? msg : "");
    if(name_sql == NULL)
        FatalError("Failed to escape string");

    len = snprintf(sql_buffer, MAX_QUERY_SIZE,
            "SELECT sig_id FROM signature WHERE sig_name='%s' AND sig_rev=%u "
            "AND sig_sid=%u", name_sql, rev, sid);
    if(len >= MAX_QUERY_SIZE)
    {
        FatalError("SQL query too big");
        return -1;
    }

    status = SelectAsUInt(op_data, sql_buffer, sig_id);
    free(name_sql);
    return status;
}

/* Looks up the sig_class_id of the sig_class row with the given name.  A
 * NULL class_name is looked up as the empty string; the name is escaped
 * before it is put into the query.
 *
 * The id is stored through class_id by SelectAsUInt() and that function's
 * result is returned; callers treat 1 as success.  An escape failure or an
 * oversized query is reported through FatalError().
 */
static int OpAcidDb_GetClassId(OpAcidDb_Data *op_data, char *class_name,
        u_int32_t *class_id)
{
    char *name_sql;
    int len;
    int status;

    name_sql = EscapeString(op_data, class_name ? class_name : "");
    if(name_sql == NULL)
        FatalError("Failed to escape string");

    len = snprintf(sql_buffer, MAX_QUERY_SIZE,
            "SELECT sig_class_id FROM sig_class WHERE sig_class_name='%s'",
            name_sql);
    if(len >= MAX_QUERY_SIZE)
    {
        FatalError("SQL query too big");
        return -1;
    }

    status = SelectAsUInt(op_data, sql_buffer, class_id);
    free(name_sql);
    return status;
}

/* Looks up the ref_system_id of the reference_system row with the given
 * name.  A NULL system is looked up as the empty string; the name is
 * escaped before it is put into the query.
 *
 * The id is stored through ref_system_id by SelectAsUInt() and that
 * function's result is returned; callers treat 1 as success.  An escape
 * failure or an oversized query is reported through FatalError().
 */
static int OpAcidDb_GetRefSystemId(OpAcidDb_Data *op_data, char *system,
        u_int32_t *ref_system_id)
{
    char *name_sql;
    int len;
    int status;

    name_sql = EscapeString(op_data, system ? system : "");
    if(name_sql == NULL)
        FatalError("Failed to escape string");

    len = snprintf(sql_buffer, MAX_QUERY_SIZE,
            "SELECT ref_system_id FROM reference_system WHERE "
            "ref_system_name='%s'", name_sql);
    if(len >= MAX_QUERY_SIZE)
    {
        FatalError("SQL query too big");
        return -1;
    }

    status = SelectAsUInt(op_data, sql_buffer, ref_system_id);
    free(name_sql);
    return status;
}

/* Looks up the ref_id of the reference row with the given reference system
 * id and tag.  A NULL ref_tag is looked up as the empty string; the tag is
 * escaped before it is put into the query.
 *
 * The id is stored through ref_id by SelectAsUInt() and that function's
 * result is returned; callers treat 1 as success.  An escape failure or an
 * oversized query is reported through FatalError().
 */
static int OpAcidDb_GetReferenceId(OpAcidDb_Data *op_data, 
        u_int32_t ref_system_id, char *ref_tag, u_int32_t *ref_id)
{
    char *tag_sql;
    int len;
    int status;

    tag_sql = EscapeString(op_data, ref_tag ? ref_tag : "");
    if(tag_sql == NULL)
        FatalError("Failed to escape string");

    len = snprintf(sql_buffer, MAX_QUERY_SIZE,
            "SELECT ref_id FROM reference WHERE ref_system_id=%u AND "
            "ref_tag='%s'", ref_system_id, tag_sql);
    if(len >= MAX_QUERY_SIZE)
    {
        FatalError("SQL query too big");
        return -1;
    }

    status = SelectAsUInt(op_data, sql_buffer, ref_id);
    free(tag_sql);
    return status;
}


/* Returns the database sig_id for a snort signature, creating the signature
 * row (and its reference rows) when it does not exist yet.
 * XXX:  Unfortunately, the db does not use the same sig_ids that snort does
 *
 * op_data    - plugin context
 * sid        - signature to look up; NULL yields 0
 * class_type - classification of the signature, resolved to a class id
 *              through GetAcidDbClassId()
 * priority   - priority stored with a newly created signature row
 *
 * When the insert reports -1 as the row id the signature is looked up
 * again.  An escape failure or an oversized statement is reported through
 * FatalError().
 */
u_int32_t AcidDbGetSigId(OpAcidDb_Data *op_data, Sid *sid, 
        ClassType *class_type, unsigned int priority)
{
    unsigned int sig_id = 0;
    unsigned int class_id = 0;
    char *name_sql;
    int len;

    if(sid == NULL)
        return 0;

    /* known signature: nothing to create */
    if(OpAcidDb_GetSigId(op_data, sid->msg, sid->rev, sid->sid, &sig_id) == 1)
        return sig_id;

    /* Create a new signature entry.  The class id has to be resolved
     * before the statement is formatted because the lookup uses sql_buffer
     * as well.
     */
    class_id = GetAcidDbClassId(op_data, class_type);

    name_sql = EscapeString(op_data, sid->msg ? sid->msg : "");
    if(name_sql == NULL)
        FatalError("Failed to escape string");

    len = snprintf(sql_buffer, MAX_QUERY_SIZE,
            "INSERT INTO signature(sig_name, sig_class_id, sig_priority, "
            "sig_rev, sig_sid) VALUES('%s', '%u', '%u', '%u', '%u')",
            name_sql, class_id, priority, sid->rev, sid->sid);
    if(len >= MAX_QUERY_SIZE)
    {
        FatalError("SQL query too big");
    }
    else
    {
        Insert(op_data, sql_buffer, &sig_id); /* XXX: Error checking */
        free(name_sql);

        /* the insert did not hand back a usable row id, ask again */
        if(sig_id == (unsigned int)-1)
            OpAcidDb_GetSigId(op_data, sid->msg, sid->rev, sid->sid, &sig_id);
    }

    InsertSigReferences(op_data, sid->ref, sig_id);

    return sig_id;
}

/* Returns the database class id for a snort class type, since the db does
 * not use the standard snort ids.  The sig_class row is created when it does
 * not exist yet.
 *
 * Returns 0 when there is no class type or it has no type name.  When the
 * insert reports -1 as the row id the class is looked up again.  An escape
 * failure or an oversized statement is reported through FatalError().
 */
unsigned int GetAcidDbClassId(OpAcidDb_Data *op_data, ClassType *class_type)
{
    unsigned int class_id = 0;
    char *name_sql;
    int len;

    if(class_type == NULL || class_type->type == NULL)
        return 0;

    /* known class: nothing to create */
    if(OpAcidDb_GetClassId(op_data, class_type->type, &class_id) == 1)
        return class_id;

    /* the type name was checked against NULL above */
    name_sql = EscapeString(op_data, class_type->type);
    if(name_sql == NULL)
        FatalError("Failed to escape string");

    /* Insert a new sig_class record */
    len = snprintf(sql_buffer, MAX_QUERY_SIZE,
            "INSERT INTO sig_class(sig_class_name) VALUES('%s')", 
            name_sql);
    if(len >= MAX_QUERY_SIZE)
    {
        FatalError("SQL query too big");
        return class_id;
    }

    Insert(op_data, sql_buffer, &class_id); /* XXX: Error checking */
    free(name_sql);

    /* the insert did not hand back a usable row id, ask again */
    if(class_id == (unsigned int)-1)
        OpAcidDb_GetClassId(op_data, class_type->type, &class_id);

    return class_id;
}

/* Links a signature to each entry of its reference list.
 *
 * For every list entry the reference system and the reference itself are
 * looked up, and created when missing, and a sig_reference row is then
 * inserted with a sequence number that counts the list entries from 1.
 *
 * The two ids are reset for every list entry.  An entry whose reference
 * system or reference cannot be resolved is skipped with a warning, so that
 * an id left over from the previous entry can never be linked to the
 * signature by mistake.
 * NOTE(review): this treats 0 and -1 as "no id", in line with how the rest
 * of this file treats ids - confirm that the schema never hands out 0.
 *
 * op_data - plugin context
 * ref     - head of the reference list, may be NULL
 * sig_id  - database id of the signature the references belong to
 *
 * Returns 0.  An escape failure or an oversized statement is reported
 * through FatalError().
 */
int InsertSigReferences(OpAcidDb_Data *op_data, ReferenceData *ref, 
        unsigned int sig_id)
{
    unsigned int ref_seq = 1;
    char *e_ref_system_name;
    char *e_ref_tag;

    for(; ref != NULL; ref = ref->next, ++ref_seq)
    {
        unsigned int ref_system_id = 0;
        unsigned int ref_id = 0;

        if(OpAcidDb_GetRefSystemId(op_data, ref->system, &ref_system_id) != 1)
        {
            /* if not found */
            if(!(e_ref_system_name = EscapeString(op_data, 
                            ref->system ? ref->system : "")))
                FatalError("Failed to escape string");
                
            if(snprintf(sql_buffer, MAX_QUERY_SIZE,
                        "INSERT INTO reference_system(ref_system_name) "
                        "VALUES('%s')",
                        e_ref_system_name) < MAX_QUERY_SIZE)
            {
                free(e_ref_system_name);
                Insert(op_data, sql_buffer, &ref_system_id); 
                if(ref_system_id == (unsigned int)-1)
                {
                    OpAcidDb_GetRefSystemId(op_data, ref->system, 
                            &ref_system_id);
                }
            }
            else
            {
                FatalError("SQL query too big");
            }
        }

        if(ref_system_id == 0 || ref_system_id == (unsigned int)-1)
        {
            LogMessage("WARNING: Unable to resolve reference system '%s', "
                    "reference skipped\n", ref->system ? ref->system : "");
            continue;
        }

        /* Get the reference id */
        if(OpAcidDb_GetReferenceId(op_data, ref_system_id, ref->id, &ref_id) 
                != 1)
        {
            /* if not found */
            if(!(e_ref_tag = EscapeString(op_data, ref->id ? ref->id : "")))
                FatalError("Failed to escape string");
            
            if(snprintf(sql_buffer, MAX_QUERY_SIZE,
                        "INSERT INTO reference(ref_system_id, ref_tag) "
                        "VALUES('%u', '%s')", ref_system_id, e_ref_tag) 
                    < MAX_QUERY_SIZE)
            {
                free(e_ref_tag);
                Insert(op_data, sql_buffer, &ref_id); 
                if(ref_id == (unsigned int)-1)
                {
                    OpAcidDb_GetReferenceId(op_data, ref_system_id, ref->id, 
                            &ref_id);
                }
            }
            else
            {
                FatalError("SQL query too big");
            }
        }

        if(ref_id == 0 || ref_id == (unsigned int)-1)
        {
            LogMessage("WARNING: Unable to resolve reference '%s', "
                    "reference skipped\n", ref->id ? ref->id : "");
            continue;
        }

        /* INSERT record into the sig_reference join table */    
        if(snprintf(sql_buffer, MAX_QUERY_SIZE,
                    "INSERT INTO sig_reference(sig_id, ref_seq, ref_id) "
                    "VALUES('%u', '%u', '%u')", sig_id, ref_seq, ref_id) 
                < MAX_QUERY_SIZE)
        {
            Insert(op_data, sql_buffer, NULL); /* XXX: Error checking */
        }
    }
    return 0;
}

OpAcidDb_Data *OpAcidDb_ParseArgs(char *args)
{
    OpAcidDb_Data *op_data;

    op_data = (OpAcidDb_Data *)SafeAlloc(sizeof(OpAcidDb_Data));

    if(args != NULL)
    {
        char **toks;
        int num_toks;
        int i;
        /* parse out your args */
        toks = mSplit(args, ",", 31, &num_toks, '\\');
        for(i = 0; i < num_toks; ++i)
        {
            char **stoks;
            int num_stoks;
            char *index = toks[i];
            while(isspace((int)*index))
                ++index;
            stoks = mSplit(index, " ", 2, &num_stoks, 0);
            if(strcasecmp("database", stoks[0]) == 0)
            {
                if(num_stoks > 1 && op_data->database == NULL)
                    op_data->database = strdup(stoks[1]);
                else
                    LogMessage("Argument Error in %s(%i): %s\n", file_name, 
                            file_line, index);
            }
            else if(strcasecmp("server", stoks[0]) == 0)
            {
                if(num_stoks > 1 && op_data->server == NULL)
                    op_data->server = strdup(stoks[1]);
                else
                    LogMessage("Argument Error in %s(%i): %s\n", file_name, 
                            file_line, index);
            }
            else if(strcasecmp("user", stoks[0]) == 0)
            {
                if(num_stoks > 1 && op_data->user == NULL)
                    op_data->user = strdup(stoks[1]);
                else
                    LogMessage("Argument Error in %s(%i): %s\n", file_name, 
                            file_line, index);
            }
            else if(strcasecmp("password", stoks[0]) == 0)
            {
                if(num_stoks > 1 && op_data->password == NULL)
                    op_data->password = strdup(stoks[1]);
                else
                    LogMessage("Argument Error in %s(%i): %s\n", file_name, 
                            file_line, index);
            }
            else if(strcasecmp("sensor_id", stoks[0]) == 0)
            {
                if(num_stoks > 1 && op_data->sensor_id == 0)
                    op_data->sensor_id = atoi(stoks[1]);
                else
                    LogMessage("Argument Error in %s(%i): %s\n", file_name, 
                            file_line, index);
            }
            else if(strcasecmp("detail", stoks[0]) == 0)
            {
                if(num_stoks > 1)
                {
                    if(strcasecmp("full", stoks[1]) == 0)
                        op_data->detail = 1;
                }
                else 
                    LogMessage("Argument Error in %s(%i): %s\n", file_name, 
                            file_line, index);
            }
#ifdef ENABLE_MYSQL
            else if(strcasecmp("mysql", stoks[0]) == 0)
            {   
                if(op_data->flavor == 0)
                    op_data->flavor = FLAVOR_MYSQL;
                else
                    LogMessage("Argument Error in %s(%i): %s\n", file_name, 
                            file_line, index);
            }
#endif /* ENABLE_MYSQL */
#ifdef ENABLE_POSTGRES
            else if(strcasecmp("postgres", stoks[0]) == 0)
            {
                if(op_data->flavor == 0)
                    op_data->flavor = FLAVOR_POSTGRES;
                else
                    LogMessage("Argument Error in %s(%i): %s\n", file_name,
                            file_line, index);
            }
#endif /* ENABLE_POSTGRES */
            else
            {
                fprintf(stderr, "WARNING %s (%d) => Unrecognized argument for "
                        "AcidDb plugin: %s\n", file_name, file_line, index);
            }
            FreeToks(stoks, num_stoks);
        }
        /* free your mSplit tokens */
        FreeToks(toks, num_toks);
    }
    if(op_data->flavor == 0)
        FatalError("You must specify a database flavor\n");
    return op_data;
}


/*
 * Open the database connection for the flavor stored in op_data->flavor.
 *
 * The call is forwarded to the backend specific connect routine (only the
 * backends compiled into this build are considered) and that routine's
 * return value is handed back unchanged.  For any other flavor value
 * FatalError() is invoked and 1 is returned.
 */
int DbConnect(OpAcidDb_Data *op_data)
{
#ifdef ENABLE_MYSQL
    if(op_data->flavor == FLAVOR_MYSQL)
        return MysqlConnect(op_data);
#endif /* ENABLE_MYSQL */
#ifdef ENABLE_POSTGRES
    if(op_data->flavor == FLAVOR_POSTGRES)
        return PostgresConnect(op_data);
#endif /* ENABLE_POSTGRES */

    /* no compiled-in backend matches the requested flavor */
    FatalError("Database flavor not supported\n");
    return 1;
}

/*
 * Close the database connection for the flavor stored in op_data->flavor.
 *
 * The backend handle is passed to the matching close routine and the
 * handle field in op_data is then reset to NULL.  Returns the value of
 * the backend close routine; for an unsupported flavor FatalError() is
 * invoked and 1 is returned.
 */
int DbClose(OpAcidDb_Data *op_data)
{
#ifdef ENABLE_MYSQL
    if(op_data->flavor == FLAVOR_MYSQL)
    {
        int status = MysqlClose(op_data->mysql);
        op_data->mysql = NULL;
        return status;
    }
#endif /* ENABLE_MYSQL */
#ifdef ENABLE_POSTGRES
    if(op_data->flavor == FLAVOR_POSTGRES)
    {
        int status = PostgresClose(op_data->pq);
        op_data->pq = NULL;
        return status;
    }
#endif /* ENABLE_POSTGRES */

    /* no compiled-in backend matches the requested flavor */
    FatalError("Database flavor not supported\n");
    return 1;
}


/*
 * Run a query and fetch the first column of its first row as an
 * unsigned integer.
 *
 * op_data  plugin state; op_data->flavor selects the backend
 * sql      query text handed to the backend
 * result   where the backend stores the fetched value
 *
 * Returns whatever the backend specific routine returns.  For an
 * unsupported flavor FatalError() is invoked and 1 is returned.
 */
int SelectAsUInt(OpAcidDb_Data *op_data, char *sql, unsigned int *result)
{
#ifdef ENABLE_MYSQL
    if(op_data->flavor == FLAVOR_MYSQL)
        return MysqlSelectAsUInt(op_data->mysql, sql, result);
#endif /* ENABLE_MYSQL */
#ifdef ENABLE_POSTGRES
    if(op_data->flavor == FLAVOR_POSTGRES)
        return PostgresSelectAsUInt(op_data->pq, sql, result);
#endif /* ENABLE_POSTGRES */

    /* no compiled-in backend matches the requested flavor */
    FatalError("Database flavor not supported\n");
    return 1;
}

/*
 * Execute an INSERT style statement on the configured backend.
 *
 * op_data  plugin state; op_data->flavor selects the backend
 * sql      statement text handed to the backend
 * row_id   optional (may be NULL).  The MySQL backend receives the
 *          pointer directly; for postgres no id is retrieved and the
 *          value is set to -1 (i.e. UINT_MAX after conversion).
 *
 * Returns whatever the backend specific routine returns.  For an
 * unsupported flavor FatalError() is invoked and 1 is returned.
 */
int Insert(OpAcidDb_Data *op_data, char *sql, unsigned int *row_id)
{
#ifdef ENABLE_MYSQL
    if(op_data->flavor == FLAVOR_MYSQL)
        return MysqlInsert(op_data->mysql, sql, row_id);
#endif /* ENABLE_MYSQL */
#ifdef ENABLE_POSTGRES
    if(op_data->flavor == FLAVOR_POSTGRES)
    {
        if(row_id != NULL)
            *row_id = -1;
        return PostgresInsert(op_data->pq, sql);
    }
#endif /* ENABLE_POSTGRES */

    /* no compiled-in backend matches the requested flavor */
    FatalError("Database flavor not supported\n");
    return 1;
}

/*
 * Produce an escaped copy of a string using the configured backend.
 *
 * Returns NULL when op_data or string is NULL.  Otherwise the result of
 * the backend specific escape routine is returned.  For an unsupported
 * flavor FatalError() is invoked and NULL is returned.
 */
char *EscapeString(OpAcidDb_Data *op_data, char *string)
{
    if(op_data == NULL || string == NULL)
        return NULL;

#ifdef ENABLE_MYSQL
    if(op_data->flavor == FLAVOR_MYSQL)
        return MysqlEscapeString(op_data->mysql, string);
#endif /* ENABLE_MYSQL */
#ifdef ENABLE_POSTGRES
    if(op_data->flavor == FLAVOR_POSTGRES)
        return PostgresEscapeString(op_data->pq, string);
#endif /* ENABLE_POSTGRES */

    /* no compiled-in backend matches the requested flavor */
    FatalError("Database flavor not supported\n");
    return NULL;
}
        

#ifdef ENABLE_MYSQL
/*
 * Create a MySQL handle and connect it to the configured server.
 *
 * The handle is stored in op_data->mysql.  The server, user, password
 * and database fields of op_data are passed to mysql_real_connect().
 *
 * Returns 0 after the connect attempt.  If the handle cannot be
 * allocated, FatalError() is invoked and 1 is returned.  A failed
 * connect is reported through FatalError() (the password is never
 * included in the message).
 */
int MysqlConnect(OpAcidDb_Data *op_data)
{
    op_data->mysql = mysql_init(NULL);
    if(op_data->mysql == NULL)
    {
        /* mysql_init() returns NULL when it cannot allocate the handle;
         * it must not be passed on to mysql_real_connect()/mysql_error()
         */
        FatalError("Failed to create mysql connection object\n");
        return 1;
    }

    if(!mysql_real_connect(op_data->mysql, op_data->server, op_data->user, 
                op_data->password, op_data->database, 0, NULL, 0))
    {
        FatalError("Failed to connect to database %s:XXXXXXXX@%s/%s: %s\n",
                op_data->user, op_data->server, 
                op_data->database, mysql_error(op_data->mysql));
    }
    return 0;
}

/*
 * Close a MySQL connection handle.
 *
 * A NULL handle is ignored.  Always returns 0.
 */
int MysqlClose(MYSQL *mysql)
{
    /* test the handle, not the address of the mysql_close() function
     * (which is always non-NULL)
     */
    if(mysql)
        mysql_close(mysql);
    return 0;
}

/*
 * Execute a SQL statement, reconnecting when the server connection is lost.
 *
 * mysql  connection handle
 * sql    statement text
 *
 * Returns 0 when mysql_query() succeeds, otherwise the non-zero value
 * returned by the last mysql_query() call.
 *
 * Error handling per failed attempt:
 *   - error numbers below CR_MIN_ERROR: logged when pv.verbose is set and
 *     the query is abandoned
 *   - CR_SERVER_LOST / CR_SERVER_GONE_ERROR: mysql_ping() is retried, with
 *     BarnyardSleep(15) between attempts, until it succeeds; then the
 *     query is issued again.  A non-zero BarnyardSleep() result abandons
 *     the query.
 *   - any other error number: logged and the query is issued again
 */
int MysqlExecuteQuery(MYSQL *mysql, char *sql)
{
    int mysqlErrno;
    int result;

    /* the assignment is parenthesised on its own so that result holds the
     * mysql_query() return code rather than the outcome of the comparison
     */
    while((result = mysql_query(mysql, sql)) != 0)
    {
        mysqlErrno = mysql_errno(mysql);
        if(mysqlErrno < CR_MIN_ERROR)
        {
            if(pv.verbose)
                LogMessage("MySQL ERROR(%i): %s.  Aborting Query\n",
                        mysqlErrno, mysql_error(mysql));
            return result;
        }
        if((mysqlErrno == CR_SERVER_LOST) 
                || (mysqlErrno == CR_SERVER_GONE_ERROR))
        {
            LogMessage("Lost connection to MySQL server.  Reconnecting\n");
            while(mysql_ping(mysql) != 0)
            {
                /* presumably a non-zero return means the wait was cut
                 * short (e.g. shutdown requested) -- TODO confirm
                 */
                if(BarnyardSleep(15))
                    return result;
            }
            LogMessage("Reconnected to MySQL server.\n");
        }
        else
        {
            /* NOTE(review): there is no delay or retry limit on this
             * path; the enclosing loop immediately re-issues the query
             */
            LogMessage("MySQL Error(%i): %s\n", mysqlErrno, mysql_error(mysql));
        }
    }
    return result;
}


/*
 * Run a query and fetch the first column of its first row as an
 * unsigned integer.
 *
 * mysql   connection handle
 * sql     query text
 * result  receives the value; set to 0 when the column is SQL NULL
 *
 * Returns 1 when a row was fetched (*result is valid), 0 when the query
 * produced no rows (*result untouched) and -1 on a NULL argument or when
 * the query or result retrieval failed (the latter two are also reported
 * through FatalError()).
 */
int MysqlSelectAsUInt(MYSQL *mysql, char *sql, unsigned int *result)
{
    int rval = 0;
    MYSQL_RES *mysql_res;
    MYSQL_ROW tuple;

    if(!mysql || !sql || !result)
        return -1;
    
    if(MysqlExecuteQuery(mysql, sql) != 0)
    {
        /* XXX: should really just return up the chain */
        FatalError("Error (%s) executing query: %s\n", mysql_error(mysql), sql);
        return -1;
    }

    /* mysql_store_result() returns NULL on failure; the result set must
     * not be handed to mysql_fetch_row() in that case
     */
    mysql_res = mysql_store_result(mysql);
    if(mysql_res == NULL)
    {
        FatalError("Error (%s) retrieving result of query: %s\n", 
                mysql_error(mysql), sql);
        return -1;
    }

    if((tuple = mysql_fetch_row(mysql_res)))
    {
        if(tuple[0] == NULL)
            *result = 0;
        else
            /* strtoul covers the full unsigned range, atoi does not */
            *result = (unsigned int)strtoul(tuple[0], NULL, 10);
        rval = 1;
    }
    mysql_free_result(mysql_res);
    return rval;
}

/*
 * Execute an INSERT style statement on a MySQL connection.
 *
 * mysql   connection handle
 * sql     statement text
 * row_id  optional (may be NULL); receives mysql_insert_id() on success
 *
 * Returns 0 on success.  When the statement fails FatalError() is
 * invoked and -1 is returned.
 */
int MysqlInsert(MYSQL *mysql, char *sql, unsigned int *row_id)
{
    if(MysqlExecuteQuery(mysql, sql) == 0)
    {
        if(row_id)
            *row_id = mysql_insert_id(mysql);
        return 0;
    }

    /* XXX: should really just return up the chain */
    FatalError("Error (%s) executing query: %s\n", mysql_error(mysql), sql);
    return -1;
}

/*
 * Return a newly allocated, escaped copy of string for use in a MySQL
 * statement.
 *
 * mysql   connection handle, passed to mysql_real_escape_string()
 * string  NUL terminated input
 *
 * Returns NULL when either argument is NULL or when the allocation
 * fails (the latter is also reported through FatalError()).  The buffer
 * is obtained with calloc(); nothing in this function releases it, so
 * the caller is presumably expected to free() it -- confirm at call sites.
 */
char *MysqlEscapeString(MYSQL *mysql, char *string)
{
    char *e_string = NULL;
    size_t len;
    size_t e_len;
    
    if(!string || !mysql)
        return NULL;

    len = strlen(string);

    /* worst case every input byte is escaped to two bytes, plus the
     * terminating NUL
     */
    e_len = (len * 2) + 1;

    if(!(e_string = calloc(e_len, sizeof(char))))
    {
        /* size_t does not match %u on every platform; print it as an
         * explicitly converted unsigned long
         */
        FatalError("Out of memory (wanted %lu bytes)",
                (unsigned long)e_len);
        return NULL;
    }

    mysql_real_escape_string(mysql, e_string, string, len);

    return e_string;
}
#endif

#ifdef ENABLE_POSTGRES
/*
 * Create a postgres connection using the server, database, user and
 * password fields of op_data and store it in op_data->pq.
 *
 * FatalError() is invoked when no connection object is returned or when
 * the connection status is not CONNECTION_OK (the password is never
 * included in the message).  Returns 0.
 */
int PostgresConnect(OpAcidDb_Data *op_data)
{
    op_data->pq = PQsetdbLogin(op_data->server, NULL, NULL, NULL, 
            op_data->database, op_data->user, op_data->password);

    if(op_data->pq == NULL)
        FatalError("Failed to create postgres connection object\n");

    if(PQstatus(op_data->pq) != CONNECTION_OK)
        FatalError("Failed to connect to database %s:XXXXXXXX@%s/%s: %s\n",
                op_data->user, op_data->server, 
                op_data->database, PQerrorMessage(op_data->pq));

    return 0;
}

/*
 * Close a postgres connection.
 *
 * A NULL connection is ignored.  Always returns 0.
 */
int PostgresClose(PGconn *pq)
{
    if(pq != NULL)
    {
        PQfinish(pq);
    }

    return 0;
}

/*
 * Run a query and fetch the first column of its first row as an
 * unsigned integer.
 *
 * pq      connection
 * sql     query text
 * result  receives the value; set to 0 when PQgetvalue() yields NULL
 *
 * Returns 1 when a row was fetched (*result is valid), 0 when the query
 * produced no rows (*result untouched) and -1 when any argument is NULL.
 * A failed PQexec() or a result status other than PGRES_TUPLES_OK is
 * reported through FatalError().
 */
int PostgresSelectAsUInt(PGconn *pq, char *sql, unsigned int *result)
{
    int rval = 0;
    PGresult *res = NULL;
    if(!pq || !sql || !result)
        return -1;

    if(!(res = PQexec(pq, sql)))
    {
        FatalError("PQexec failed: %s\n", PQerrorMessage(pq));
    }
    if(PQresultStatus(res) != PGRES_TUPLES_OK)
    {
        FatalError("SQL command '%s' failed: %s", sql, PQerrorMessage(pq));
    }
    if(PQntuples(res) > 0)
    {
        char *value = PQgetvalue(res, 0, 0);
        if(value == NULL)
            *result = 0;
        else
            /* strtoul covers the full unsigned range, atoi does not */
            *result = (unsigned int)strtoul(value, NULL, 10);
        rval = 1;
    }

    PQclear(res);
    return rval;
}

/*
 * Execute an INSERT style statement on a postgres connection.
 *
 * Returns -1 when pq or sql is NULL, otherwise 0.  A failed PQexec() or
 * a result status other than PGRES_COMMAND_OK is reported through
 * FatalError().
 */
int PostgresInsert(PGconn *pq, char *sql)
{
    PGresult *res;

    if(pq == NULL || sql == NULL)
        return -1;

    res = PQexec(pq, sql);
    if(res == NULL)
        FatalError("PQexec failed: %s\n", PQerrorMessage(pq));

    if(PQresultStatus(res) != PGRES_COMMAND_OK)
        FatalError("SQL command '%s' failed: %s", sql, PQerrorMessage(pq));

    PQclear(res);
    return 0;
}

/*
 * Return a newly allocated, escaped copy of string for use in a
 * postgres statement.
 *
 * pq      connection; only checked for NULL here, the escaping itself is
 *         done by PQescapeString() which does not take the connection
 * string  NUL terminated input
 *
 * Returns NULL when either argument is NULL or when the allocation
 * fails (the latter is also reported through FatalError()).  The buffer
 * is obtained with calloc(); nothing in this function releases it, so
 * the caller is presumably expected to free() it -- confirm at call sites.
 */
char *PostgresEscapeString(PGconn *pq, char *string)
{
    char *e_string = NULL;
    size_t len;
    size_t e_len;
    
    if(!string || !pq)
        return NULL;

    len = strlen(string);

    /* worst case every input byte is escaped to two bytes, plus the
     * terminating NUL
     */
    e_len = (len * 2) + 1;

    if(!(e_string = calloc(e_len, sizeof(char))))
    {
        /* size_t does not match %u on every platform; print it as an
         * explicitly converted unsigned long
         */
        FatalError("Out of memory (wanted %lu bytes)",
                (unsigned long)e_len);
        return NULL;
    }

    PQescapeString(e_string, string, len);

    return e_string;
}
#endif /* ENABLE_POSTGRES */
