/*
 * conn.c
 *
 * Copyright (C) 2002 Thomas Graf <tgr@reeler.org>
 *
 * This file belongs to the nstats package, see COPYING for more information.
 *
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <netinet/in.h>
#include <netinet/ether.h>
#include <netinet/ip.h>
#include <netinet/ip6.h>
#include <netinet/icmp6.h>
#include <netinet/tcp.h>
#include <netinet/udp.h>
#include <netinet/ip_icmp.h>
#include <arpa/inet.h>
#include <net/if.h>
#include <sys/ioctl.h>

#include "conn.h"
#include "packet.h"
#include "stats.h"

/* Error handler defined elsewhere; this file calls it when an allocation fails. */
extern void quit(const char *);

/*
 * Node of a singly linked chain hanging off one conn_lookup bucket; each
 * node refers to one tracked connection.
 */
struct ip_conn_chain_s {
    struct ip_conn_s *conn;          /* connection this node refers to */
    struct ip_conn_chain_s *next;    /* next node in the chain */
};

/*
 * Buckets of connection chains used by handle_port() to find an existing
 * connection without walking all of stats.ip.conns; see handle_port() for
 * how a bucket index is derived.
 */
struct ip_conn_chain_s *conn_lookup[65537];

void
clean_ports(void)
{
    /*
     * Remove every port entry that has seen exactly one packet.
     *
     * The list is walked through the link that points at the current node
     * (the list head first, then each node's `next` field), so unlinking the
     * head needs no special case.
     */
    struct ip_port_s **link = &stats.ip.ports;

    while (*link) {
        struct ip_port_s *cur = *link;

        if (cur->cnt != 1) {
            link = &cur->next;
            continue;
        }

        *link = cur->next;
        free(cur);
    }
}

void
clean_conns(void)
{
    struct ip_conn_s *p1, *n, *prev=NULL;
   
    p1 = stats.ip.conns;
    while(p1) {
        n = p1->next;

        if (p1->cnt == 1) {

            if (!prev)
                stats.ip.conns = p1->next;
            else
                prev->next = p1->next;

            if (p1 == stats.ip.conns)
                stats.ip.conns = p1->next;

            if (p1 == stats.ip.conn_tail)
                stats.ip.conn_tail = prev;

            if (p1->sname)
                free(p1->sname);

            if (p1->dname)
                free(p1->dname);

            free(p1);
        } else
            prev = p1;

        p1 = n;
    }
}

void fill_ports(void)
{
    struct ip_conn_s *p1;
    struct ip_port_s *port, *n;
    int f=0;

    port = stats.ip.ports;
    while (port) {
        n = port->next;
        free(port);
        port = n;
    }

    stats.ip.ports = NULL;

    p1 = stats.ip.conns;
    while(p1) {

        stats.ip.c_ports += p1->cnt;

        for (f=0,port = stats.ip.ports; port; port = port->next) {
            if (port->port == p1->sport || port->port == p1->dport) {
                port->bs += p1->bs;
                port->cnt += p1->cnt;
                f=1;
                break;
            }
        }

        if (!f) {
            if (!(port = (struct ip_port_s *) calloc(1,sizeof(struct ip_port_s))))
                quit("Out of memory!\n");

            port->cnt = p1->cnt;
            port->bs = p1->bs;

            port->next = stats.ip.ports;
            stats.ip.ports = port;

            if (p1->dport <= 1024)
                port->port = p1->dport;
            else if (p1->sport <= 1024)
                port->port = p1->sport;
            else {
                port->port = p1->sport;

                /*
                 * this is kinda stupid, we do store a port entry for
                 * both source and destination port, i didn't find a better
                 * way to decide which is the more important port number
                 * as to look the packets in the same connection but the
                 * other direction
                 */

                if (!(port = (struct ip_port_s *)
                        calloc(1,sizeof(struct ip_port_s))))
                    quit("Out of memory!");

                port->cnt = p1->cnt;
                port->bs = p1->bs;
                port->port = p1->dport;

                port->next = stats.ip.ports;
                stats.ip.ports = port;
            }
        }

        p1 = p1->next;
    }
}

void
sort_ports(void)
{
    struct ip_port_s *p, *q, *e, *tail, *list;
    unsigned int insize, nmerges, psize, qsize, i;

    list = stats.ip.ports;
    insize = 1;

    if (!list)
        return;

    while (1) {
        p = list;
        list = NULL;
        tail = NULL;

        nmerges = 0;

        while (p) {
            nmerges++;
            q = p;
            
            for (psize=0,i=0; i < insize; i++) {
                psize++;
                q = q->next;
                if (!q)
                    break;
            }

            for (qsize=insize; psize > 0 || (qsize > 0 && q);) {

                if (psize == 0) {
                    e = q;
                    q = q->next;
                    qsize--;
                } else if (qsize == 0 || !q) {
                    e = p;
                    p = p->next;
                    psize--;
                } else if (p->bs > q->bs) {
                    e = p;
                    p = p->next;
                    psize--;
                } else {
                    e = q;
                    q = q->next;
                    qsize--;
                }

                if (tail)
                    tail->next = e;
                else
                    list = e;

                tail = e;
            }

            p = q;
        }

	    tail->next = NULL;

        if (nmerges <= 1) {
            stats.ip.ports = list;
            return;
        }

        insize *= 2;
    }
}


void
sort_conns(void)
{
    struct ip_conn_s *p, *q, *e, *tail, *list;
    unsigned int insize, nmerges, psize, qsize, i;

    list = stats.ip.conns;
    insize = 1;

    if (!list)
        return;

    while (1) {
        p = list;
        list = NULL;
        tail = NULL;

        nmerges = 0;

        while (p) {
            nmerges++;
            q = p;
            
            for (psize=0,i=0; i < insize; i++) {
                psize++;
                q = q->next;
                if (!q)
                    break;
            }

            for (qsize=insize; psize > 0 || (qsize > 0 && q);) {

                if (psize == 0) {
                    e = q;
                    q = q->next;
                    qsize--;
                } else if (qsize == 0 || !q) {
                    e = p;
                    p = p->next;
                    psize--;
                } else if (p->bs > q->bs) {
                    e = p;
                    p = p->next;
                    psize--;
                } else {
                    e = q;
                    q = q->next;
                    qsize--;
                }

                if (tail)
                    tail->next = e;
                else
                    list = e;

                tail = e;
            }

            p = q;
        }

	    tail->next = NULL;

        if (nmerges <= 1) {
            stats.ip.conns = list;
            return;
        }

        insize *= 2;
    }
}


__inline__
int match(struct ip_conn_s *c, int sp, int dp, char *saddr, char *daddr, int type)
{
    if (c->dport != dp || c->sport != sp)
        return 0;

    switch(type) {
        case ETHERTYPE_IP: /* ip4 */
            if ( *((int *) saddr) == c->src_u.ip4.s_addr &&
                 *((int *) daddr) == c->dst_u.ip4.s_addr)
                return 1;
            break;

        case 0x86DD: /* ip6 */
            if (!memcmp(saddr, &c->src_u.ip6, 16) &&
                !memcmp(daddr, &c->dst_u.ip6, 16))
                return 1;
            break;
    }

    return 0;
}

void
handle_port(int sp, int dp, char *saddr, char *daddr, int type,
            const struct pcap_pkthdr *hdr)
{
    struct ip_conn_chain_s *cc;

    if ( sp < 0 || dp < 0 || sp > 65536 || dp > 65536 )
        return;

    if ( dp && sp ) {
        int f = 0;
        struct ip_conn_s *c = stats.ip.conns;

        if ( (cc = conn_lookup[dp]) ) {
            while (cc) {
                if (match(cc->conn, sp, dp, saddr, daddr, type)) {
                    cc->conn->cnt++;
                    cc->conn->bs += hdr->len;
                    cc->conn->last_activity = time(0);
                    f = 1;
                    break;
                }
                cc = cc->next;
            }
        }

        if (!f) {
            if ( (cc = conn_lookup[sp]) ) {
                while (cc) {
                    if (match(cc->conn, sp, dp, saddr, daddr, type)) {
                        cc->conn->cnt++;
                        cc->conn->bs += hdr->len;
                        cc->conn->last_activity = time(0);
                        f = 1;
                        break;
                    }
                    cc = cc->next;
                }
            }
        }

        if (!f) {

            if ( !(c = (struct ip_conn_s *) calloc(1,sizeof(struct ip_conn_s))) )
                quit("Out of memory!\n");

            if ( !stats.ip.conn_tail ) {
                stats.ip.conns = c;
                stats.ip.conn_tail = c;
            } else {
                stats.ip.conn_tail->next = c;
                stats.ip.conn_tail = c;
            }

            if ( !(cc = (struct ip_conn_chain_s *)
                    calloc(1, sizeof(struct ip_conn_chain_s))) )
                quit("Out of memory!\n");

            cc->conn = c;

            cc->next = conn_lookup[dp];
            conn_lookup[dp] = cc;

            if (dp != sp) {
                cc->next = conn_lookup[sp];
                conn_lookup[sp] = cc;
            }

            c->cnt = 1;
            c->bs += hdr->len;

            c->dport = dp;
            c->sport = sp;
            c->type = type;

            /* XXX: use the timestamp in packet header? */
            c->last_activity = time(0);

            switch(type) {
                case ETHERTYPE_IP: /* ip4 */
                    memcpy(&c->src_u.ip4, saddr, 4);
                    memcpy(&c->dst_u.ip4, daddr, 4);
                    break;

                case 0x86DD: /* ip6 */
                    memcpy(&c->src_u.ip6, saddr, 16);
                    memcpy(&c->dst_u.ip6, daddr, 16);
                    break;
            }

        }
    }
}

void
calc_conn_rate(void)
{
    struct ip_conn_s *p1;
   
    p1 = stats.ip.conns;
    while(p1) {

        if (p1->bs_old)
            p1->r_bs += (p1->bs - p1->bs_old);
        p1->bs_old = p1->bs;

        if (p1->cnt_old)
            p1->r_packets += (p1->cnt - p1->cnt_old);
        p1->cnt_old = p1->cnt;

        p1->ratec++;

        if (p1->ratec % 5 == 0) {

            p1->r_bs /= 6;
            p1->r_packets /= 6;

            p1->ratec = 0;
        }

        p1 = p1->next;
    }
}

void
expire_conns(int secs)
{
    struct ip_conn_s *p1, *n, *prev=NULL;
    time_t now = time(0);
   
    p1 = stats.ip.conns;
    while( p1 ) {
        n = p1->next;

        if ( (now - p1->last_activity) >= secs ) {

            if ( ! prev )
                stats.ip.conns = p1->next;
            else
                prev->next = p1->next;

            if ( p1 == stats.ip.conns )
                stats.ip.conns = p1->next;

            if ( p1 == stats.ip.conn_tail )
                stats.ip.conn_tail = prev;

            if ( p1->sname )
                free( p1->sname );

            if ( p1->dname )
                free( p1->dname );
            
            free( p1 );
        } else
            prev = p1;

        p1 = n;
    }
}
