/*
 *  glue.c
 *
 *  Written by Jari Ruusu, August 14 2004
 *
 *  Copyright 2001,2002,2003,2004 by Jari Ruusu.
 *  Redistribution of this file is permitted under the GNU Public License.
 */

#include <linux/version.h>
#include <linux/sched.h>
#include <linux/fs.h>
#include <linux/string.h>
#include <linux/types.h>
#include <linux/errno.h>
#if LINUX_VERSION_CODE >= 0x20600
# include <linux/bio.h>
# include <linux/blkdev.h>
#endif
#if LINUX_VERSION_CODE >= 0x20200
# include <linux/slab.h>
# include <linux/loop.h>
# include <asm/uaccess.h>
#else
# include <linux/malloc.h>
# include <asm/segment.h>
# include "patched-loop.h"
#endif
#if LINUX_VERSION_CODE >= 0x20400
# include <linux/spinlock.h>
#endif
#include <asm/byteorder.h>
#include "aes.h"
#include "md5.h"

/*
 * Kernel-version compatibility types. TransferSector_t is the sector number
 * type handed to the transfer/IV functions below (sector_t on 2.6+, int on
 * older kernels); LoopInfo_t is the status structure passed to keySetup_aes().
 */
#if LINUX_VERSION_CODE >= 0x20600
typedef sector_t TransferSector_t;
# define LoopInfo_t struct loop_info64
#else
typedef int TransferSector_t;
# define LoopInfo_t struct loop_info
#endif

/*
 * Fallback cpu_to_le32() for headers that do not provide it: a 32-bit byte
 * swap on big-endian CPUs, a plain cast otherwise. The big-endian form is a
 * GCC statement expression so that x is evaluated only once.
 */
#if !defined(cpu_to_le32)
# if defined(__BIG_ENDIAN)
#  define cpu_to_le32(x) ({u_int32_t __x=(x);((u_int32_t)((((u_int32_t)(__x)&(u_int32_t)0x000000ffUL)<<24)|(((u_int32_t)(__x)&(u_int32_t)0x0000ff00UL)<<8)|(((u_int32_t)(__x)&(u_int32_t)0x00ff0000UL)>>8)|(((u_int32_t)(__x)&(u_int32_t)0xff000000UL)>>24)));})
# else
#  define cpu_to_le32(x) ((u_int32_t)(x))
# endif
#endif

/*
 * Pre-2.2 kernels: emulate copy_from_user() with verify_area() plus
 * memcpy_fromfs(). Evaluates to s (nothing copied) when the area check
 * fails and to 0 after a successful copy.
 */
#if LINUX_VERSION_CODE < 0x20200
# define copy_from_user(t,f,s) (verify_area(VERIFY_READ,f,s)?(s):(memcpy_fromfs(t,f,s),0))
#endif

/* ioctl command handled by handleIoctl_aes(), if loop.h does not define it */
#if !defined(LOOP_MULTI_KEY_SETUP)
# define LOOP_MULTI_KEY_SETUP 0x4C4D
#endif

/*
 * Number of aes_context slots in AESmultiKey. 64 are used for keys; with
 * key scrubbing the count is doubled because keyScrubWork() pairs every
 * slot i with a spare slot i + keyMask + 1.
 */
#if CONFIG_BLK_DEV_LOOP_KEYSCRUB
# define KEY_ALLOC_COUNT  128
#else
# define KEY_ALLOC_COUNT  64
#endif

/*
 * Per-device AES key state, stored in lo->key_data.
 */
typedef struct {
    /* key schedules; each slot points into a kmalloc()ed page-sized batch */
    aes_context *keyPtr[KEY_ALLOC_COUNT];
    /* 0: single key in slot 0; 0x3F: key = keyPtr[sector & 0x3F] */
    unsigned    keyMask;
#if CONFIG_BLK_DEV_LOOP_KEYSCRUB
    /* read-held by transfer_aes, write-held while keys are scrubbed/reordered */
    rwlock_t    rwlock;
    /* toggled by every keyScrubWork(); nonzero while slot pairs are swapped */
    unsigned    reversed;
    /* nonzero while multiKeySetup() rewrites keys; scrubbing is skipped */
    unsigned    blocked;
    /* periodic timer that schedules scrubbing on the loop thread */
    struct timer_list timer;
#endif
} AESmultiKey;

#if CONFIG_BLK_DEV_LOOP_KEYSCRUB
/*
 * One key scrubbing step. For each active key slot i (0..keyMask), copy the
 * live key schedule into its spare partner slot i + keyMask + 1, swap the
 * two pointers so the fresh copy becomes live, and bitwise-invert the old
 * copy. Doing this twice restores the original pointer order; m->reversed
 * records which of the two states the array is in. Ends with a cache
 * write-back (wbinvd on x86 other than 386) or a memory barrier.
 */
static void keyScrubWork(AESmultiKey *m)
{
    int         slots = m->keyMask + 1;
    int         i;

    for(i = 0; i < slots; i++) {
        aes_context *live = m->keyPtr[i];
        aes_context *spare = m->keyPtr[i + slots];
        u_int32_t   *word = (u_int32_t *) live;
        u_int32_t   *end = word + sizeof(aes_context) / sizeof(u_int32_t);

        memcpy(spare, live, sizeof(aes_context));
        m->keyPtr[i + slots] = live;
        m->keyPtr[i] = spare;
        /* invert every bit of the retired copy */
        for(; word < end; word++)
            *word = ~*word;
    }
    m->reversed ^= 1;

    /* try to flush dirty cache data to RAM */
#if defined(CONFIG_X86_64) || (defined(CONFIG_X86) && !defined(CONFIG_M386) && !defined(CONFIG_CPU_386))
    __asm__ __volatile__ ("wbinvd": : :"memory");
#else
    mb();
#endif
}

/* called only from loop thread process context */
/*
 * Perform one scrub step under the write lock, unless scrubbing is
 * currently blocked (m->blocked != 0).
 */
static void keyScrubThreadFn(AESmultiKey *m)
{
    rwlock_t *lock = &m->rwlock;

    write_lock(lock);
    if(m->blocked == 0)
        keyScrubWork(m);
    write_unlock(lock);
}

/*
 * Forward declaration of the timer handler, at file scope: a block-scope
 * function declaration may not carry the 'static' storage class (GCC:
 * "invalid storage class for function").
 */
static void keyScrubTimerFn(unsigned long);

/*
 * (Re)arm the per-device key scrub timer to fire HZ jiffies (about one
 * second) from now. The handler keyScrubTimerFn receives lo as its
 * unsigned long argument.
 */
static void keyScrubTimerInit(struct loop_device *lo)
{
    AESmultiKey     *m;
    unsigned long   expire;

    m = (AESmultiKey *)lo->key_data;
    expire = jiffies + HZ;
    init_timer(&m->timer);
    m->timer.expires = expire;
    m->timer.data = (unsigned long)lo;
    m->timer.function = keyScrubTimerFn;
    add_timer(&m->timer);
}

/*
 * Adapter with the exact void (*)(void *) signature that
 * loop_add_keyscrub_fn() expects. Calling keyScrubThreadFn (which takes an
 * AESmultiKey *) through a cast function pointer of a different type is
 * undefined behaviour in C.
 */
static void keyScrubThreadAdapter(void *data)
{
    keyScrubThreadFn(data);
}

/* called only from timer handler context */
/*
 * Scrub timer handler: queue a scrub step for the loop thread, then re-arm
 * the timer. d is the struct loop_device pointer stored by
 * keyScrubTimerInit().
 */
static void keyScrubTimerFn(unsigned long d)
{
    struct loop_device *lo = (struct loop_device *)d;
    extern void loop_add_keyscrub_fn(struct loop_device *, void (*)(void *), void *);

    /* rw lock needs process context, so make loop thread do scrubbing */
    loop_add_keyscrub_fn(lo, keyScrubThreadAdapter, lo->key_data);
    /* start timer again */
    keyScrubTimerInit(lo);
}
#endif

/*
 * Allocate a zeroed AESmultiKey and attach its first batch of aes_context
 * slots. Each batch is one kmalloc() holding as many contexts as fit in
 * PAGE_SIZE (at least one), assigned to keyPtr[] in order up to
 * KEY_ALLOC_COUNT. With key scrubbing, at least two slots are needed
 * (keyScrubWork pairs slot 0 with slot 1 while keyMask is 0), so batches
 * are added until two exist.
 * Returns the new structure, or 0 if any allocation fails (nothing leaks).
 */
static AESmultiKey *allocMultiKey(void)
{
    AESmultiKey *mk;
    aes_context *batch;
    int filled = 0;
    int left;
#if CONFIG_BLK_DEV_LOOP_KEYSCRUB
    const int needed = 2;
#else
    const int needed = 1;
#endif

    mk = (AESmultiKey *) kmalloc(sizeof(AESmultiKey), GFP_KERNEL);
    if(!mk) return 0;
    memset(mk, 0, sizeof(AESmultiKey));
#if CONFIG_BLK_DEV_LOOP_KEYSCRUB
    rwlock_init(&mk->rwlock);
    init_timer(&mk->timer);
#endif

    do {
        left = PAGE_SIZE / sizeof(aes_context);
        if(left == 0) left = 1;

        batch = (aes_context *) kmalloc(sizeof(aes_context) * left, GFP_KERNEL);
        if(batch == 0) {
#if CONFIG_BLK_DEV_LOOP_KEYSCRUB
            /* an earlier pass may already own the first batch */
            if(filled) kfree(mk->keyPtr[0]);
#endif
            kfree(mk);
            return 0;
        }

        for(; (filled < KEY_ALLOC_COUNT) && (left > 0); left--)
            mk->keyPtr[filled++] = batch++;
    } while(filled < needed);

    return mk;
}

/*
 * Zero and free every aes_context batch, then free the AESmultiKey itself.
 * Batches start at multiples of the per-page context count. The walk stops
 * at the first unallocated batch start. With key scrubbing, the scrub
 * timer is stopped first, and the original pointer order is restored so
 * that each batch-start slot again points at the start of a kmalloc()ed
 * batch.
 */
static void clearAndFreeMultiKey(AESmultiKey *m)
{
    aes_context *batch;
    int perBatch, idx;

#if CONFIG_BLK_DEV_LOOP_KEYSCRUB
    /* stop scrub timer. loop thread was killed earlier */
    del_timer_sync(&m->timer);
    /* make sure allocated keys are in original order */
    if(m->reversed) keyScrubWork(m);
#endif
    perBatch = PAGE_SIZE / sizeof(aes_context);
    if(perBatch == 0) perBatch = 1;

    for(idx = 0; idx < KEY_ALLOC_COUNT; idx += perBatch) {
        batch = m->keyPtr[idx];
        if(batch == 0)
            break;
        memset(batch, 0, sizeof(aes_context) * perBatch);
        kfree(batch);
    }

    kfree(m);
}

/*
 * LOOP_MULTI_KEY_SETUP handler: switch the device to 64-key mode.
 *
 * k is a user-space pointer to 64 consecutive 32-byte keys. Key i is loaded
 * into keyPtr[i] with lo->lo_encrypt_key_size as the key size. Empty
 * aes_context slots are filled with new page-sized kmalloc() batches. With
 * key scrubbing, slots 64..127 are only allocated (they receive copies
 * during scrubbing), and scrubbing is blocked while keys are rewritten.
 *
 * Returns:
 *  - 0 on success;
 *  - -EPERM if the caller is neither the key owner nor CAP_SYS_ADMIN
 *    (2.2+ kernels only);
 *  - -ENXIO if no key data is attached;
 *  - -ENOMEM or -EFAULT on failure.
 * After -ENOMEM/-EFAULT, keys loaded before the failure remain in place and
 * keyMask is not changed.
 */
static int multiKeySetup(struct loop_device *lo, unsigned char *k)
{
    AESmultiKey *m;
    aes_context *a;
    int x, y, n, err = 0;
    union {
        u_int32_t     w[8]; /* needed for 4 byte alignment for b[] */
        unsigned char b[32];
    } un;

#if LINUX_VERSION_CODE >= 0x20200
    if(lo->lo_key_owner != current->uid && !capable(CAP_SYS_ADMIN))
        return -EPERM;
#endif

    m = (AESmultiKey *)lo->key_data;
    if(!m) return -ENXIO;

#if CONFIG_BLK_DEV_LOOP_KEYSCRUB
    /* temporarily prevent loop thread from messing with keys */
    write_lock(&m->rwlock);
    m->blocked = 1;
    /* make sure allocated keys are in original order */
    if(m->reversed) keyScrubWork(m);
    write_unlock(&m->rwlock);
#endif
    /* contexts per kmalloc() batch: as many as fit in one page, at least one */
    n = PAGE_SIZE / sizeof(aes_context);
    if(!n) n = 1;

    x = 0;
    while(x < KEY_ALLOC_COUNT) {
        if(!m->keyPtr[x]) {
            /* new batch: hand its contexts to slots x, x+1, ... */
            a = (aes_context *) kmalloc(sizeof(aes_context) * n, GFP_KERNEL);
            if(!a) {
                err = -ENOMEM;
                goto error_out;
            }
            y = x;
            while((y < (x + n)) && (y < KEY_ALLOC_COUNT)) {
                m->keyPtr[y] = a;
                a++;
                y++;
            }
        }
#if CONFIG_BLK_DEV_LOOP_KEYSCRUB
        /* slots 64..127 are scrub spares; they get no user key */
        if(x >= 64) {
            x++;
            continue;
        }
#endif
        if(copy_from_user(&un.b[0], k, 32)) {
            err = -EFAULT;
            goto error_out;
        }
        aes_set_key(m->keyPtr[x], &un.b[0], lo->lo_encrypt_key_size, 0);
        k += 32;
        x++;
    }
    m->keyMask = 0x3F;          /* range 0...63 */
    lo->lo_flags |= 0x100000;   /* multi-key (info exported to user space) */
error_out:
    /* wipe key bytes copied from user space on every exit path,
       including the -ENOMEM and -EFAULT ones */
    memset(&un.b[0], 0, 32);
#if CONFIG_BLK_DEV_LOOP_KEYSCRUB
    /* re-enable loop thread key scrubbing */
    write_lock(&m->rwlock);
    m->blocked = 0;
    write_unlock(&m->rwlock);
#endif
    return err;
}

/*
 * Build the 128-bit IV used in single-key mode from the sector number:
 *  - ivout[0] holds the low 32 bits, little-endian;
 *  - ivout[1] holds the high 32 bits, little-endian, but only when
 *    TransferSector_t is 64 bits wide (0 otherwise);
 *  - ivout[2] and ivout[3] are zero.
 */
void loop_compute_sector_iv(TransferSector_t devSect, u_int32_t *ivout)
{
    u_int32_t high = 0;

    if(sizeof(TransferSector_t) == 8)
        high = cpu_to_le32((u_int64_t)devSect >> 32);

    ivout[0] = cpu_to_le32(devSect);
    ivout[1] = high;
    ivout[2] = 0;
    ivout[3] = 0;
}

/*
 * Compute the multi-key-mode IV for the first 16-byte block of a sector.
 *
 * An MD5 computation is run over 496 bytes of data, followed by 56 bits of
 * devSect. transfer_aes passes data = &sector[16], so this covers the
 * sector's plaintext after its first block. The MD5 padding byte (0x80) and
 * the bit length (4024) are written into the final block by hand. The
 * 128-bit result is stored in ivout[0..3] in little-endian word order on
 * both paths.
 * NOTE(review): md5_transform_CPUbyteorder comes from md5.h; presumably one
 * MD5 compression step over 16 words in CPU byte order -- confirm.
 */
void loop_compute_md5_iv(TransferSector_t devSect, u_int32_t *ivout, u_int32_t *data)
{
    int         x;
#if defined(__BIG_ENDIAN)
    int         y, e;
#endif
    u_int32_t   buf[16];

    /* MD5 initial chaining values */
    ivout[0] = 0x67452301;
    ivout[1] = 0xefcdab89;
    ivout[2] = 0x98badcfe;
    ivout[3] = 0x10325476;

#if defined(__BIG_ENDIAN)
    /* 8 passes: y = 7..1 hash 16 data words (64 bytes) each; the final */
    /* pass (y == 0) hashes 12 data words plus sector number and padding */
    y = 7;
    e = 16;
    do {
        if (!y) {
            e = 12;
            /* md5_transform_CPUbyteorder wants data in CPU byte order */
            /* devSect is already in CPU byte order -- no need to convert */
            if(sizeof(TransferSector_t) == 8) {
                /* use only 56 bits of sector number */
                buf[12] = devSect;
                buf[13] = (((u_int64_t)devSect >> 32) & 0xFFFFFF) | 0x80000000;
            } else {
                /* 32 bits of sector number + 24 zero bits */
                buf[12] = devSect;
                buf[13] = 0x80000000;
            }
            /* 4024 bits == 31 * 128 bit plaintext blocks + 56 bits of sector number */
            buf[14] = 4024;
            buf[15] = 0;
        }
        /* copy e data words into buf, swapping each from little-endian */
        /* storage order into CPU order (cpu_to_le32 byte-swaps here) */
        x = 0;
        do {
            buf[x    ] = cpu_to_le32(data[0]);
            buf[x + 1] = cpu_to_le32(data[1]);
            buf[x + 2] = cpu_to_le32(data[2]);
            buf[x + 3] = cpu_to_le32(data[3]);
            x += 4;
            data += 4;
        } while (x < e);
        md5_transform_CPUbyteorder(&ivout[0], &buf[0]);
    } while (--y >= 0);
    /* store the digest words in little-endian order */
    ivout[0] = cpu_to_le32(ivout[0]);
    ivout[1] = cpu_to_le32(ivout[1]);
    ivout[2] = cpu_to_le32(ivout[2]);
    ivout[3] = cpu_to_le32(ivout[3]);
#else
    /* little-endian: the first 7 full 64-byte blocks (448 bytes) are */
    /* hashed directly from data without copying */
    x = 6;
    do {
        md5_transform_CPUbyteorder(&ivout[0], data);
        data += 16;
    } while (--x >= 0);
    /* final block: last 48 data bytes + sector number + padding + length */
    memcpy(buf, data, 48);
    /* md5_transform_CPUbyteorder wants data in CPU byte order */
    /* devSect is already in CPU byte order -- no need to convert */
    if(sizeof(TransferSector_t) == 8) {
        /* use only 56 bits of sector number */
        buf[12] = devSect;
        buf[13] = (((u_int64_t)devSect >> 32) & 0xFFFFFF) | 0x80000000;
    } else {
        /* 32 bits of sector number + 24 zero bits */
        buf[12] = devSect;
        buf[13] = 0x80000000;
    }
    /* 4024 bits == 31 * 128 bit plaintext blocks + 56 bits of sector number */
    buf[14] = 4024;
    buf[15] = 0;
    md5_transform_CPUbyteorder(&ivout[0], &buf[0]);
#endif
}

/*
 * Loop transfer function: AES in CBC mode over 512-byte sectors.
 *
 * cmd == READ decrypts size bytes from raw_buf into loop_buf. Any other cmd
 * encrypts loop_buf into raw_buf. devSect is the first sector's number and
 * is incremented once per sector.
 *
 * Single-key mode (m->keyMask == 0): every sector uses keyPtr[0], and the
 * IV comes from loop_compute_sector_iv(devSect).
 *
 * Multi-key mode (m->keyMask != 0): the key is keyPtr[devSect & keyMask].
 * The IV of the sector's first 16-byte block is loop_compute_md5_iv() over
 * the sector's other 496 plaintext bytes and the sector number.
 *
 * Returns 0, or -EINVAL when size is 0 or not a multiple of 512.
 * NOTE(review): the u_int32_t casts on raw_buf/loop_buf assume both buffers
 * are at least 4-byte aligned -- confirm for all callers.
 */
int transfer_aes(struct loop_device *lo, int cmd, char *raw_buf,
          char *loop_buf, int size, TransferSector_t devSect)
{
    aes_context     *a;
    AESmultiKey     *m;
    int             x;
    unsigned        y;
    u_int32_t       iv[8];

    if(!size || (size & 511)) {
        return -EINVAL;
    }
    m = (AESmultiKey *)lo->key_data;
    y = m->keyMask;
    if(cmd == READ) {
        while(size) {
#if CONFIG_BLK_DEV_LOOP_KEYSCRUB
            /* keep keyScrubWork from swapping/inverting keys mid-sector */
            read_lock(&m->rwlock);
#endif
            a = m->keyPtr[((unsigned)devSect) & y];
            if(y) {
                /* block 0's IV needs the plaintext of blocks 1..31, so */
                /* decrypt those first, chained from block 0's ciphertext */
                memcpy(&iv[0], raw_buf, 16);
                raw_buf += 16;
                loop_buf += 16;
            } else {
                loop_compute_sector_iv(devSect, &iv[0]);
            }
            /* 16 iterations, two 16-byte blocks each; iv[0..3] and */
            /* iv[4..7] alternate as the previous ciphertext block */
            x = 15;
            do {
                memcpy(&iv[4], raw_buf, 16);
                aes_decrypt(a, raw_buf, loop_buf);
                *((u_int32_t *)(&loop_buf[ 0])) ^= iv[0];
                *((u_int32_t *)(&loop_buf[ 4])) ^= iv[1];
                *((u_int32_t *)(&loop_buf[ 8])) ^= iv[2];
                *((u_int32_t *)(&loop_buf[12])) ^= iv[3];
                if(y && !x) {
                    /* blocks 1..31 done: rewind to block 0 and derive */
                    /* its IV from the now-decrypted rest of the sector */
                    raw_buf -= 496;
                    loop_buf -= 496;
                    loop_compute_md5_iv(devSect, &iv[4], (u_int32_t *)(&loop_buf[16]));
                } else {
                    raw_buf += 16;
                    loop_buf += 16;
                    memcpy(&iv[0], raw_buf, 16);
                }
                aes_decrypt(a, raw_buf, loop_buf);
                *((u_int32_t *)(&loop_buf[ 0])) ^= iv[4];
                *((u_int32_t *)(&loop_buf[ 4])) ^= iv[5];
                *((u_int32_t *)(&loop_buf[ 8])) ^= iv[6];
                *((u_int32_t *)(&loop_buf[12])) ^= iv[7];
                if(y && !x) {
                    /* from block 0 jump to the start of the next sector */
                    raw_buf += 512;
                    loop_buf += 512;
                } else {
                    raw_buf += 16;
                    loop_buf += 16;
                }
            } while(--x >= 0);
#if CONFIG_BLK_DEV_LOOP_KEYSCRUB
            read_unlock(&m->rwlock);
#endif
            /* give up the CPU between sectors if a reschedule is pending */
#if LINUX_VERSION_CODE >= 0x20600
            cond_resched();
#elif LINUX_VERSION_CODE >= 0x20400
            if(current->need_resched) {set_current_state(TASK_RUNNING);schedule();}
#elif LINUX_VERSION_CODE >= 0x20200
            if(current->need_resched) {current->state=TASK_RUNNING;schedule();}
#else
            if(need_resched) schedule();
#endif
            size -= 512;
            devSect++;
        }
    } else {
        while(size) {
#if CONFIG_BLK_DEV_LOOP_KEYSCRUB
            read_lock(&m->rwlock);
#endif
            a = m->keyPtr[((unsigned)devSect) & y];
            if(y) {
#if LINUX_VERSION_CODE < 0x20400
                /* on 2.2 and older kernels, real raw_buf may be doing */
                /* writes at any time, so this needs to be stack buffer */
                u_int32_t tmp_raw_buf[128];
                char *TMP_RAW_BUF = (char *)(&tmp_raw_buf[0]);
#else
                /* on 2.4 and later kernels, real raw_buf is not doing */
                /* any writes now so it can be used as temp buffer */
                /* (TMP_RAW_BUF stays defined for the rest of the file) */
# define TMP_RAW_BUF raw_buf
#endif
                /* stage the plaintext so the MD5 IV can be computed over */
                /* blocks 1..31 before encryption overwrites raw_buf */
                memcpy(TMP_RAW_BUF, loop_buf, 512);
                loop_compute_md5_iv(devSect, &iv[0], (u_int32_t *)(&TMP_RAW_BUF[16]));
                /* CBC: each ciphertext block becomes the next block's IV */
                x = 15;
                do {
                    iv[0] ^= *((u_int32_t *)(&TMP_RAW_BUF[ 0]));
                    iv[1] ^= *((u_int32_t *)(&TMP_RAW_BUF[ 4]));
                    iv[2] ^= *((u_int32_t *)(&TMP_RAW_BUF[ 8]));
                    iv[3] ^= *((u_int32_t *)(&TMP_RAW_BUF[12]));
                    aes_encrypt(a, (unsigned char *)(&iv[0]), raw_buf);
                    memcpy(&iv[0], raw_buf, 16);
                    raw_buf += 16;
#if LINUX_VERSION_CODE < 0x20400
                    TMP_RAW_BUF += 16;
#endif
                    iv[0] ^= *((u_int32_t *)(&TMP_RAW_BUF[ 0]));
                    iv[1] ^= *((u_int32_t *)(&TMP_RAW_BUF[ 4]));
                    iv[2] ^= *((u_int32_t *)(&TMP_RAW_BUF[ 8]));
                    iv[3] ^= *((u_int32_t *)(&TMP_RAW_BUF[12]));
                    aes_encrypt(a, (unsigned char *)(&iv[0]), raw_buf);
                    memcpy(&iv[0], raw_buf, 16);
                    raw_buf += 16;
#if LINUX_VERSION_CODE < 0x20400
                    TMP_RAW_BUF += 16;
#endif
                } while(--x >= 0);
                /* only raw_buf advanced inside the loop */
                loop_buf += 512;
            } else {
                loop_compute_sector_iv(devSect, &iv[0]);
                x = 15;
                do {
                    iv[0] ^= *((u_int32_t *)(&loop_buf[ 0]));
                    iv[1] ^= *((u_int32_t *)(&loop_buf[ 4]));
                    iv[2] ^= *((u_int32_t *)(&loop_buf[ 8]));
                    iv[3] ^= *((u_int32_t *)(&loop_buf[12]));
                    aes_encrypt(a, (unsigned char *)(&iv[0]), raw_buf);
                    memcpy(&iv[0], raw_buf, 16);
                    loop_buf += 16;
                    raw_buf += 16;
                    iv[0] ^= *((u_int32_t *)(&loop_buf[ 0]));
                    iv[1] ^= *((u_int32_t *)(&loop_buf[ 4]));
                    iv[2] ^= *((u_int32_t *)(&loop_buf[ 8]));
                    iv[3] ^= *((u_int32_t *)(&loop_buf[12]));
                    aes_encrypt(a, (unsigned char *)(&iv[0]), raw_buf);
                    memcpy(&iv[0], raw_buf, 16);
                    loop_buf += 16;
                    raw_buf += 16;
                } while(--x >= 0);
            }
#if CONFIG_BLK_DEV_LOOP_KEYSCRUB
            read_unlock(&m->rwlock);
#endif
            /* give up the CPU between sectors if a reschedule is pending */
#if LINUX_VERSION_CODE >= 0x20600
            cond_resched();
#elif LINUX_VERSION_CODE >= 0x20400
            if(current->need_resched) {set_current_state(TASK_RUNNING);schedule();}
#elif LINUX_VERSION_CODE >= 0x20200
            if(current->need_resched) {current->state=TASK_RUNNING;schedule();}
#else
            if(need_resched) schedule();
#endif
            size -= 512;
            devSect++;
        }
    }
    return(0);
}

/*
 * Transfer init hook: attach a fresh AESmultiKey to lo->key_data (0 on
 * failure) and load the single key from info->lo_encrypt_key into key slot
 * 0. Then wipe both the caller's key buffer and the local copy. With key
 * scrubbing, also start the scrub timer.
 * Returns 0, or -ENOMEM if allocation fails.
 */
int keySetup_aes(struct loop_device *lo, LoopInfo_t *info)
{
    AESmultiKey     *mk;
    union {
        u_int32_t     w[8]; /* needed for 4 byte alignment for b[] */
        unsigned char b[32];
    } keyCopy;

    mk = allocMultiKey();
    lo->key_data = mk;
    if(mk == 0)
        return -ENOMEM;

    memcpy(keyCopy.b, info->lo_encrypt_key, 32);
    aes_set_key(mk->keyPtr[0], keyCopy.b, info->lo_encrypt_key_size, 0);

    /* do not leave key material behind in either buffer */
    memset(info->lo_encrypt_key, 0, sizeof(info->lo_encrypt_key));
    memset(keyCopy.b, 0, 32);

#if CONFIG_BLK_DEV_LOOP_KEYSCRUB
    keyScrubTimerInit(lo);
#endif
    return 0;
}

int keyClean_aes(struct loop_device *lo)
{
    if(lo->key_data) {
        clearAndFreeMultiKey((AESmultiKey *)lo->key_data);
        lo->key_data = 0;
    }
    return(0);
}

/*
 * Transfer ioctl hook. LOOP_MULTI_KEY_SETUP is forwarded to multiKeySetup()
 * with arg as the user-space key array; every other cmd gives -EINVAL.
 */
int handleIoctl_aes(struct loop_device *lo, int cmd, unsigned long arg)
{
    if(cmd != LOOP_MULTI_KEY_SETUP)
        return -EINVAL;

    return multiKeySetup(lo, (unsigned char *)arg);
}

#if LINUX_VERSION_CODE >= 0x20200

/*
 * Transfer table registered with the loop driver by init_module_aes().
 * Uses the old GCC "field: value" initializer syntax.
 * NOTE(review): the (void *) casts presumably adapt the handler prototypes
 * to the loop_func_table member types, which vary across kernel versions --
 * confirm against loop.h.
 */
static struct loop_func_table funcs_aes = {
    number:     16,     /* 16 == AES */
    transfer:   (void *) transfer_aes,
    init:       (void *) keySetup_aes,
    release:    keyClean_aes,
    ioctl:      (void *) handleIoctl_aes
};

/*
 * Register the AES transfer with the loop driver.
 * Returns 0 on success, -EIO if registration is refused.
 */
int init_module_aes(void)
{
    if(loop_register_transfer(&funcs_aes) != 0) {
        printk("loop: unable to register AES transfer\n");
        return -EIO;
    }

#if CONFIG_BLK_DEV_LOOP_KEYSCRUB
    printk("loop: AES key scrubbing enabled\n");
#endif
    return 0;
}

/*
 * Unregister the AES transfer. A failure is only logged.
 */
void cleanup_module_aes(void)
{
    if(loop_unregister_transfer(funcs_aes.number) == 0)
        return;

    printk("loop: unable to unregister AES transfer\n");
}

#endif
