#include "log/log.h"
#include "tpm2_common.h"

/*
 * Build a device TCTI context for the given device path.
 *
 * Tss2_Tcti_Device_Init() is called twice: first with a NULL context to
 * learn how many bytes the context needs, then again on a zeroed buffer
 * of that size together with device_path.
 *
 * Returns a calloc()-allocated context on success, or NULL after logging
 * the failure. The buffer is released here on every error path.
 */
TSS2_TCTI_CONTEXT *
tcti_device_init(char const *device_path)
{
    TSS2_TCTI_CONTEXT *ctx;
    size_t ctx_size;
    TSS2_RC status;

    /* Size query only: no context buffer is supplied yet. */
    status = Tss2_Tcti_Device_Init(NULL, &ctx_size, 0);
    if (status != TSS2_RC_SUCCESS) {
        log_error("Failed to get allocation size for device tcti context: "
                  "0x%x\n", status);
        return NULL;
    }

    ctx = (TSS2_TCTI_CONTEXT *) calloc(1, ctx_size);
    if (ctx == NULL) {
        log_error("Allocation for device TCTI context failed: %s\n",
                  strerror(errno));
        return NULL;
    }

    status = Tss2_Tcti_Device_Init(ctx, &ctx_size, device_path);
    if (status == TSS2_RC_SUCCESS) {
        return ctx;
    }

    log_error("Failed to initialize device TCTI context: 0x%x\n", status);
    free(ctx);
    return NULL;
}

/*
 * Allocate a SAPI context and initialise it on top of tcti_ctx.
 *
 * The buffer size comes from Tss2_Sys_GetContextSize(0) and the buffer is
 * zeroed by calloc() before Tss2_Sys_Initialize() is called with the ABI
 * version constants below.
 *
 * Returns the new context, or NULL after logging the failure. tcti_ctx is
 * never freed here, including on the error paths.
 */
TSS2_SYS_CONTEXT *
sapi_init_from_tcti_ctx(TSS2_TCTI_CONTEXT *tcti_ctx)
{
    TSS2_ABI_VERSION abi = {
        .tssCreator = 1,
        .tssFamily = 2,
        .tssLevel = 1,
        .tssVersion = 108,
    };
    size_t ctx_size = Tss2_Sys_GetContextSize(0);
    TSS2_SYS_CONTEXT *ctx = (TSS2_SYS_CONTEXT *) calloc(1, ctx_size);
    TSS2_RC status;

    if (ctx == NULL) {
        log_error( "Failed to allocate 0x%zx bytes for the SAPI context\n",
                   ctx_size);
        return NULL;
    }

    status = Tss2_Sys_Initialize(ctx, ctx_size, tcti_ctx, &abi);
    if (status == TSS2_RC_SUCCESS) {
        return ctx;
    }

    log_error("Failed to initialize SAPI context: 0x%x\n", status);
    free(ctx);
    return NULL;
}

/*
 * Format the tabrmd TCTI library path for the running machine.
 *
 * TSS2_TCTI_TABRMD is used as a printf-style format string with the
 * uname() machine name as its only argument (presumably it contains a
 * single "%s" -- confirm against its definition).
 *
 * On success *path receives a malloc()-allocated, NUL-terminated string
 * that the caller must free(), and the string length (> 0) is returned.
 * On any failure *path is set to NULL and 0 is returned.
 */
static int
get_dynamic_library_path(char **path)
{
    struct utsname arch_info;
    char *buf;
    int len;

    *path = NULL;

    if (uname(&arch_info) < 0) {
        return 0;
    }

    /* Let snprintf compute the exact length instead of hand-counting it. */
    len = snprintf(NULL, 0, TSS2_TCTI_TABRMD, arch_info.machine);
    if (len <= 0) {
        return 0;
    }

    buf = malloc((size_t)len + 1);
    if (buf == NULL) {
        return 0;
    }

    if (snprintf(buf, (size_t)len + 1, TSS2_TCTI_TABRMD,
                 arch_info.machine) != len) {
        free(buf);
        return 0;
    }

    *path = buf;
    return len;
}

TSS2_TCTI_CONTEXT *
tcti_tabrmd_init(const char * name)
{
    TSS2_TCTI_CONTEXT *tcti_ctx = NULL;
    const TSS2_TCTI_INFO *info;
    void *handle;
    char *path = NULL;
    int ret;

    ret = get_dynamic_library_path(&path);
    if (ret == 0) {
        log_error("Failed to get dynamic librarchy path");
        return NULL;
    }

    handle = dlopen(path, RTLD_LAZY);
    if (!handle) {
        log_error("Could not dlopen library: \"%s\"", path);
        free(path);
        return NULL;
    }
    free(path);
    TSS2_TCTI_INFO_FUNC infofn = (TSS2_TCTI_INFO_FUNC)dlsym(handle, TSS2_TCTI_INFO_SYMBOL);
    if (!infofn) {
        log_error("Symbol \"%s\"not found in library: \"%s\"",
                TSS2_TCTI_INFO_SYMBOL, name);
        goto err;
    }

    info = infofn();

    TSS2_TCTI_INIT_FUNC init = info->init;

    size_t size;
    TSS2_RC rc = init(NULL, &size, NULL);
    if (rc != TPM2_RC_SUCCESS) {
        log_error("tcti init setup routine failed for library: \"%s\""
                , name);
        goto err;
    }

    tcti_ctx = (TSS2_TCTI_CONTEXT*) calloc(1, size);
    if (tcti_ctx == NULL) {
        log_error( "Failed to allocate 0x%zx bytes for the TCTI context\n",  size);
        goto err;
    }

    rc = init(tcti_ctx, &size, NULL);
    if (rc != TPM2_RC_SUCCESS) {
        log_error("tcti init allocation routine failed for library: \"%s\""
                , name);
        goto err;
    }

    return tcti_ctx;
err:
    // free(path);
    free(tcti_ctx);
    dlclose(handle);
    return NULL;
}

/*
 * Create a SAPI context from the transport described by opts.
 *
 * If opts->device_file contains TC_TABRMD_MARK the tabrmd TCTI is used,
 * otherwise the device TCTI is initialised with opts->device_file (a NULL
 * device_file is passed through to tcti_device_init() unchanged).
 *
 * Returns the SAPI context, or NULL on failure. If the TCTI context was
 * created but SAPI initialisation fails, the TCTI context is finalized
 * and freed here so it does not leak.
 */
TSS2_SYS_CONTEXT *
sapi_init_from_opts(OPTIONS *opts)
{
    TSS2_TCTI_CONTEXT *tcti_ctx;
    TSS2_SYS_CONTEXT *sapi_ctx;

    if (opts == NULL) {
        log_error("No options supplied for SAPI initialization\n");
        return NULL;
    }

    /* strstr() must not be handed a NULL haystack. */
    if (opts->device_file != NULL &&
        strstr(opts->device_file, TC_TABRMD_MARK)) {
        tcti_ctx = tcti_tabrmd_init(opts->device_file);
    } else {
        tcti_ctx = tcti_device_init(opts->device_file);
    }
    if (tcti_ctx == NULL) {
        return NULL;
    }

    sapi_ctx = sapi_init_from_tcti_ctx(tcti_ctx);
    if (sapi_ctx == NULL) {
        /* Nobody else holds the TCTI context at this point. */
        Tss2_Tcti_Finalize(tcti_ctx);
        free(tcti_ctx);
    }
    return sapi_ctx;
}

TSS2_SYS_CONTEXT *
tpm2_init(uint8_t *device_name)
{
    TSS2_SYS_CONTEXT *sapi_context;
    OPTIONS opts = {
        .tcti_type = DEVICE_TCTI, 
        .device_file = device_name, 
    };
    int ret;

    sapi_context = sapi_init_from_opts(&opts);
    if (sapi_context == NULL) {
        log_error("SAPI context not initialized\n");
        return NULL;
    }

    ret = Tss2_Sys_Startup(sapi_context, TPM2_SU_CLEAR);
    if (ret != TSS2_RC_SUCCESS && ret != TPM2_RC_INITIALIZE) {
        log_error("TPM Startup FAILED! Response Code : 0x%x", ret);
        return NULL;
    }

    return sapi_context;
}

/*
 * Tear down a SAPI context together with the TCTI context beneath it.
 *
 * The TCTI context is fetched from sapi_context first. If that lookup
 * fails its return code is passed back and nothing is released.
 * Otherwise the SAPI context is finalized and freed, then the TCTI
 * context is finalized, so the transport can release what it holds,
 * and freed.
 *
 * Returns TSS2_RC_SUCCESS on success, or the Tss2_Sys_GetTctiContext()
 * error code.
 */
TSS2_RC 
tpm2_free(TSS2_SYS_CONTEXT *sapi_context){
    TSS2_TCTI_CONTEXT *tcti_context = NULL;
    TSS2_RC rc;

    rc = Tss2_Sys_GetTctiContext(sapi_context, &tcti_context);
    if (rc != TSS2_RC_SUCCESS) {
        return rc;
    }

    Tss2_Sys_Finalize(sapi_context);
    free(sapi_context);

    /* Finalize before free(): freeing alone only drops the memory. */
    Tss2_Tcti_Finalize(tcti_context);
    free(tcti_context);

    return rc;
}

/*
 * Flush flush_handle via Tss2_Sys_FlushContext() and hand back its
 * return code unchanged.
 */
TSS2_RC
tpm2_flush_context(
    TSS2_SYS_CONTEXT        *sapi_context,
    TPMI_DH_CONTEXT          flush_handle)
{
    TSS2_RC status;

    status = Tss2_Sys_FlushContext(sapi_context, flush_handle);
    return status;
}

/*
 * Fill in the algorithm-specific parameters of a primary-key template.
 *
 * The caller sets public_info->publicArea.type beforehand. This routine
 * writes the matching member of publicArea.parameters and zeroes the
 * size field(s) of publicArea.unique. RSA, keyed-hash, ECC and
 * symmetric-cipher types are handled.
 *
 * Returns 0 on success, or -1 (after logging) for any other type.
 */
int
setup_createprimary_alg(TPM2B_PUBLIC *public_info)
{
    TPMT_PUBLIC *area = &public_info->publicArea;

    switch (area->type) {
    case TPM2_ALG_RSA: {
        TPMS_RSA_PARMS *rsa = &area->parameters.rsaDetail;

        rsa->symmetric.algorithm = TPM2_ALG_AES;
        rsa->symmetric.keyBits.aes = 128;
        rsa->symmetric.mode.aes = TPM2_ALG_CFB;
        rsa->scheme.scheme = TPM2_ALG_NULL;
        rsa->keyBits = 2048;
        rsa->exponent = 0;
        area->unique.rsa.size = 0;
        break;
    }
    case TPM2_ALG_KEYEDHASH: {
        TPMS_KEYEDHASH_PARMS *keyed = &area->parameters.keyedHashDetail;

        keyed->scheme.scheme = TPM2_ALG_XOR;
        keyed->scheme.details.exclusiveOr.hashAlg = TPM2_ALG_SHA256;
        keyed->scheme.details.exclusiveOr.kdf = TPM2_ALG_KDF1_SP800_108;
        area->unique.keyedHash.size = 0;
        break;
    }
    case TPM2_ALG_ECC: {
        TPMS_ECC_PARMS *ecc = &area->parameters.eccDetail;

        ecc->symmetric.algorithm = TPM2_ALG_AES;
        ecc->symmetric.keyBits.aes = 128;
        ecc->symmetric.mode.sym = TPM2_ALG_CFB;
        ecc->scheme.scheme = TPM2_ALG_NULL;
        ecc->curveID = TPM2_ECC_NIST_P256;
        ecc->kdf.scheme = TPM2_ALG_NULL;
        area->unique.ecc.x.size = 0;
        area->unique.ecc.y.size = 0;
        break;
    }
    case TPM2_ALG_SYMCIPHER: {
        TPMS_SYMCIPHER_PARMS *cipher = &area->parameters.symDetail;

        cipher->sym.algorithm = TPM2_ALG_AES;
        cipher->sym.keyBits.sym = 128;
        cipher->sym.mode.sym = TPM2_ALG_CFB;
        area->unique.sym.size = 0;
        break;
    }
    default:
        log_error("type alg: 0x%0x not support !\n", area->type);
        return -1;
    }

    return 0;
}

/*
 * Fill in the algorithm-specific parameters of a child-object template.
 *
 * The caller sets public_info->publicArea.type (and nameAlg, which the
 * keyed-hash case copies into the HMAC scheme) beforehand. The
 * keyed-hash case also clears TPMA_OBJECT_DECRYPT and sets
 * TPMA_OBJECT_SIGN_ENCRYPT in objectAttributes.
 *
 * Returns 0 on success, or -1 for an unhandled type. The unhandled case
 * is reported through log_error(), as in setup_createprimary_alg().
 */
int
setup_create_alg(TPM2B_PUBLIC *public_info)
{
    TPMT_PUBLIC *area = &public_info->publicArea;

    switch (area->type)
    {
    case TPM2_ALG_RSA:
        area->parameters.rsaDetail.symmetric.algorithm = TPM2_ALG_NULL;
        area->parameters.rsaDetail.scheme.scheme = TPM2_ALG_NULL;
        area->parameters.rsaDetail.keyBits = 2048;
        area->parameters.rsaDetail.exponent = 0;
        area->unique.rsa.size = 0;
        break;
    case TPM2_ALG_KEYEDHASH:
        area->unique.keyedHash.size = 0;
        area->objectAttributes &= ~TPMA_OBJECT_DECRYPT;
        area->objectAttributes |= TPMA_OBJECT_SIGN_ENCRYPT;
        area->parameters.keyedHashDetail.scheme.scheme = TPM2_ALG_HMAC;
        area->parameters.keyedHashDetail.scheme.details.hmac.hashAlg = area->nameAlg;
        break;
    case TPM2_ALG_ECC:
        area->parameters.eccDetail.symmetric.algorithm = TPM2_ALG_NULL;
        area->parameters.eccDetail.scheme.scheme = TPM2_ALG_NULL;
        area->parameters.eccDetail.curveID = TPM2_ECC_NIST_P256;
        area->parameters.eccDetail.kdf.scheme = TPM2_ALG_NULL;
        area->unique.ecc.x.size = 0;
        area->unique.ecc.y.size = 0;
        break;
    case TPM2_ALG_SYMCIPHER:
        area->parameters.symDetail.sym.algorithm = TPM2_ALG_AES;
        area->parameters.symDetail.sym.keyBits.sym = 128;
        area->parameters.symDetail.sym.mode.sym = TPM2_ALG_CFB;
        area->unique.sym.size = 0;
        break;
    default:
        log_error("type alg: 0x%0x not support !\n", area->type);
        return -1;
    }
    return 0;
}

/*
 * Return data with its four bytes in reverse order
 * (0xAABBCCDD becomes 0xDDCCBBAA).
 */
UINT32 
tpm2_util_endian_swap_32(UINT32 data) {
    UINT32 swapped;

    swapped  = (data >> 24) & 0x000000FFu;
    swapped |= (data >>  8) & 0x0000FF00u;
    swapped |= (data <<  8) & 0x00FF0000u;
    swapped |= (data << 24) & 0xFF000000u;

    return swapped;
}

/*
 * Select the signature scheme that goes with key_type and record halg as
 * its hash algorithm:
 *   RSA        -> RSASSA
 *   ECC        -> ECDSA
 *   KEYEDHASH  -> HMAC
 *
 * Returns 0 on success, or -1 (after logging) for any other key type, in
 * which case *scheme is left untouched.
 */
int
setup_signature_scheme(
    TPMI_ALG_PUBLIC         key_type,
    TPMI_ALG_HASH           halg,
    TPMT_SIG_SCHEME        *scheme)
{
    if (key_type == TPM2_ALG_RSA) {
        scheme->scheme = TPM2_ALG_RSASSA;
        scheme->details.rsassa.hashAlg = halg;
        return 0;
    }

    if (key_type == TPM2_ALG_ECC) {
        scheme->scheme = TPM2_ALG_ECDSA;
        scheme->details.ecdsa.hashAlg = halg;
        return 0;
    }

    if (key_type == TPM2_ALG_KEYEDHASH) {
        scheme->scheme = TPM2_ALG_HMAC;
        scheme->details.hmac.hashAlg = halg;
        return 0;
    }

    log_error("Unknow key type, got: 0x%x", key_type);
    return -1;
}

/*
 * Hash message on the TPM with Tss2_Sys_Hash().
 *
 * The message is copied into a local TPM2B_MAX_BUFFER, so it must fit in
 * that structure's buffer. A longer message is rejected with
 * TSS2_SYS_RC_BAD_SIZE instead of overflowing the stack copy. A NULL
 * message with a non-zero length is rejected with
 * TSS2_SYS_RC_BAD_REFERENCE.
 *
 * hash_type and hierarchy are passed straight to the TPM; the result is
 * written to *digest. The validation ticket produced by the call is
 * discarded.
 *
 * Returns the Tss2_Sys_Hash() return code, or one of the errors above.
 */
TSS2_RC
tpm_hash(
    TSS2_SYS_CONTEXT *sapi_context,
    TPMI_ALG_HASH     hash_type,
    TPM2_HANDLE       hierarchy,
    TPM2B_DIGEST     *digest,
    uint8_t          *message,
    uint32_t          message_len)
{
    TPM2B_MAX_BUFFER digest_buffer = TPM2B_TYPE_INIT(TPM2B_MAX_BUFFER, buffer);
    TPMT_TK_HASHCHECK validation;

    if (message == NULL && message_len != 0) {
        return TSS2_SYS_RC_BAD_REFERENCE;
    }
    if (message_len > sizeof(digest_buffer.buffer)) {
        return TSS2_SYS_RC_BAD_SIZE;
    }

    /* The bound check above guarantees the length fits the size field. */
    digest_buffer.size = (UINT16)message_len;
    if (message_len != 0) {
        memcpy(digest_buffer.buffer, message, message_len);
    }

    return Tss2_Sys_Hash(sapi_context,
                         NULL,
                         &digest_buffer,
                         hash_type,
                         hierarchy,
                         digest,
                         &validation,
                         NULL);
}

/*
 * Read the public area of the NV index nv_index into *nv_public.
 *
 * The index name also produced by Tss2_Sys_NV_ReadPublic() is received
 * into a local and discarded. The call's return code is passed back
 * unchanged.
 */
TSS2_RC 
tpm2_util_nv_read_public(TSS2_SYS_CONTEXT *sapi_context,
    TPMI_RH_NV_INDEX nv_index, TPM2B_NV_PUBLIC *nv_public) {

    TPM2B_NAME discarded_name = TPM2B_TYPE_INIT(TPM2B_NAME, name);
    TSS2_RC status;

    status = Tss2_Sys_NV_ReadPublic(sapi_context, nv_index, NULL, nv_public,
            &discarded_name, NULL);
    return status;
}

/*
 * Query the TPM property TPM2_PT_NV_BUFFER_MAX and store it in *size.
 *
 * One property is requested through Tss2_Sys_GetCapability(), wrapped in
 * TSS2_RETRY_EXP. *size is written only when the call succeeds and the
 * first returned entry really is TPM2_PT_NV_BUFFER_MAX. The entry is
 * checked because a successful reply may carry no entries, or one for a
 * different property.
 *
 * Returns TSS2_RC_SUCCESS on success, the GetCapability error code on
 * failure, or TSS2_SYS_RC_GENERAL_FAILURE when the TPM did not report the
 * property. Every failure is logged.
 */
TSS2_RC
tpm2_util_nv_max_buffer_size(TSS2_SYS_CONTEXT *sapi_context,
        UINT32 *size) {

    /* Zeroed so a short reply can never expose stale stack contents. */
    TPMS_CAPABILITY_DATA cap_data = { 0 };
    TPMI_YES_NO more_data;
    TSS2_RC rval = TSS2_RETRY_EXP(
               Tss2_Sys_GetCapability (sapi_context, NULL,
                   TPM2_CAP_TPM_PROPERTIES, TPM2_PT_NV_BUFFER_MAX, 1,
                   &more_data, &cap_data, NULL));
    if (rval != TSS2_RC_SUCCESS) {
        log_error("Failed to query max transmission size via "
                  "Tss2_Sys_GetCapability. Error:0x%x", rval);
        return rval;
    }

    if (cap_data.data.tpmProperties.count == 0 ||
        cap_data.data.tpmProperties.tpmProperty[0].property
            != TPM2_PT_NV_BUFFER_MAX) {
        log_error("TPM did not report TPM2_PT_NV_BUFFER_MAX");
        return TSS2_SYS_RC_GENERAL_FAILURE;
    }

    *size = cap_data.data.tpmProperties.tpmProperty[0].value;
    return rval;
}