#include <stdarg.h>

#include "tc_private/tc_handle.h"
#include "tc_private/tc_private.h"

#include "log/log.h"

#include "tc_tpm2.h"
#include "tpm2_common.h"

#include "tc_type.h"
#include "tc_errcode.h"

/*
 * Per-call state for the sign API: allocated and filled by tpm2_sign_init()
 * and stored in api_ctx->data, consumed by tpm2_sign(), released by
 * tpm2_sign_free(). The TC_BUFFER pointers are borrowed from the caller.
 */
struct tpm2_sign_ctx
{
    TC_HANDLE        handle;        /* cast to TC_HANDLE_CTX* by tpm2_sign() */
    uint32_t         key_index;     /* object-table index, or a TPM handle used directly when > MAX_OBJECT_NODE_COUNT */
    TC_BUFFER       *key_auth_msg;  /* optional key password for the password session (may be NULL) */
    TC_ALG           alg_sign;      /* signature algorithm; only consulted in the direct-TPM-handle path */
    TC_ALG           alg_hash;      /* hash algorithm; only consulted in the direct-TPM-handle path */
    TC_BUFFER       *plain_text;    /* input: message to hash and sign */
    TC_BUFFER       *sign_text;     /* output: receives a malloc'd signature buffer and its size */
};

TC_RC tpm2_sign_init(struct api_ctx_st *api_ctx, int num, ...)
{
    TC_RC rc = TC_SUCCESS;
    struct tpm2_sign_ctx* sctx = (struct tpm2_sign_ctx*)malloc(sizeof(struct tpm2_sign_ctx));

    va_list ap;
    va_start(ap, num);
    sctx->handle = va_arg(ap, TC_HANDLE);
    sctx->key_index = va_arg(ap, uint32_t);
    sctx->key_auth_msg = va_arg(ap, TC_BUFFER*);
    sctx->alg_sign = va_arg(ap, TC_ALG);
    sctx->alg_hash = va_arg(ap, TC_ALG);
    sctx->plain_text = va_arg(ap, TC_BUFFER*);
    sctx->sign_text = va_arg(ap, TC_BUFFER*);
    va_end(ap);

    api_ctx->data = (HANDLE_DATA*)sctx;
    return rc;
}

/*
 * Release the context that tpm2_sign_init() stored in api_ctx->data and
 * reset the API context (data pointer cleared, cmd_code set to API_NULL).
 * TC_BUFFERs referenced by the context belong to the caller and are not
 * freed here. Always returns TC_SUCCESS.
 */
TC_RC tpm2_sign_free(struct api_ctx_st *api_ctx)
{
    free(api_ctx->data);
    api_ctx->data = NULL;            /* no dangling pointer left behind */
    api_ctx->cmd_code = API_NULL;

    return TC_SUCCESS;
}

/*
 * Hash sctx->plain_text on the TPM and sign the digest, storing the raw
 * signature bytes in sctx->sign_text.
 *
 * ctx->data must be the struct tpm2_sign_ctx created by tpm2_sign_init().
 * The key is chosen in one of two ways:
 *   - key_index > MAX_OBJECT_NODE_COUNT: key_index is passed straight to
 *     Tss2_Sys_Sign as the key handle, and the hash/signature algorithms
 *     come from sctx->alg_hash / sctx->alg_sign;
 *   - otherwise: key_index selects an entry of the handle's object table,
 *     whose alg_object, name_hash_alg and obj_handle are used instead.
 * Authorization uses a password session (TPM2_RS_PW) carrying
 * sctx->key_auth_msg when it is non-NULL.
 *
 * On success sign_text->buffer is set to a newly malloc'd buffer of
 * sign_text->size bytes, which the caller must free.
 *
 * Returns TC_SUCCESS, TC_AUTH_HMAC_OVERSIZE, TC_SIGN_BUFFER_OVERSIZE,
 * TC_UNDEFINE_ALGO, TC_OBJECT_INDEX, or TC_COMMAND_SIGN (bad arguments,
 * allocation failure, or a failing TPM/TSS call).
 */
TC_RC tpm2_sign(API_CTX *ctx)
{
    TC_RC rc = TC_SUCCESS;

    struct tpm2_sign_ctx* sctx = (struct tpm2_sign_ctx*)ctx->data;
    if (sctx == NULL || sctx->plain_text == NULL || sctx->sign_text == NULL) {
        log_error("Sign context or its input/output buffers are missing\n");
        return TC_COMMAND_SIGN;
    }
    TC_HANDLE_CTX* tc_handle_ctx = (TC_HANDLE_CTX*)(sctx->handle);
    TSS2_SYS_CONTEXT* sys_ctx = (TSS2_SYS_CONTEXT*)tc_handle_ctx->handle.tc_context;

    TPM2B_DIGEST digest = TPM2B_TYPE_INIT(TPM2B_DIGEST, buffer);
    TPMT_SIGNATURE signature;
    TPMT_SIG_SCHEME in_scheme;
    /*
     * tpm_hash() hands back no hash-check ticket, so send an empty one.
     * digest.size must be explicitly 0: an uninitialised size either makes
     * marshalling fail or is treated by the TPM as a ticket to verify.
     */
    TPMT_TK_HASHCHECK validation = {
        .tag       = TPM2_ST_HASHCHECK,
        .hierarchy = TPM2_RH_OWNER,
        .digest    = { .size = 0 }
    };
    TSS2L_SYS_AUTH_RESPONSE sessionsDataout;
    TSS2L_SYS_AUTH_COMMAND sessionsData = {
        .auths    = {{.sessionHandle = TPM2_RS_PW}},
        .count    = 1
    };
    uint8_t *sig_buf = NULL;
    uint16_t sig_size = 0;

    if (sctx->key_auth_msg != NULL) {
        /* The password is carried in the session's hmac field, whose
         * buffer is sizeof(TPMU_HA) bytes. */
        if (sctx->key_auth_msg->size > sizeof(TPMU_HA)) {
            log_error("The length of the key authorization authentication password exceeds the limit\n");
            return TC_AUTH_HMAC_OVERSIZE;
        }
        sessionsData.auths[0].hmac.size = sctx->key_auth_msg->size;
        memcpy(sessionsData.auths[0].hmac.buffer,
               sctx->key_auth_msg->buffer,
               sctx->key_auth_msg->size);
    }
    if (sctx->plain_text->size > TPM2_MAX_DIGEST_BUFFER) {
        log_error("The length of the data to be signed exceeds the limit\n");
        return TC_SIGN_BUFFER_OVERSIZE;
    }

    if (sctx->key_index > MAX_OBJECT_NODE_COUNT) {
        /* key_index is used directly as a TPM key handle. */
        TPMI_ALG_HASH   nameAlg;
        TPMI_ALG_PUBLIC signAlg;

        switch (sctx->alg_hash)
        {
        case TC_SHA256:
            nameAlg = TPM2_ALG_SHA256;
            break;
        case TC_SM3:
            nameAlg = TPM2_ALG_SM3_256;
            break;
        case TC_SHA1:
            nameAlg = TPM2_ALG_SHA1;
            break;
        default:
            log_error("unrecogize the tpm2_hash algorithms, %d\n", sctx->alg_hash);
            return TC_UNDEFINE_ALGO;
        }

        rc = tpm_hash(sys_ctx,
                      nameAlg,
                      TPM2_RH_OWNER,
                      &digest,
                      sctx->plain_text->buffer,
                      sctx->plain_text->size);
        if (rc != TSS2_RC_SUCCESS) {
            log_error("Failed to hash message\n");
            goto end;
        }

        switch (sctx->alg_sign)
        {
        case TC_RSA:
            signAlg = TPM2_ALG_RSA;
            break;
        case TC_SM2:
            signAlg = TPM2_ALG_SM2;
            break;
        case TC_SYMMETRIC:
            signAlg = TPM2_ALG_SYMCIPHER;
            break;
        default:
            log_error("unrecogize the crypto algorithms, %d\n", sctx->alg_sign);
            return TC_UNDEFINE_ALGO;
        }

        if (setup_signature_scheme(signAlg, nameAlg, &in_scheme)) {
            log_error("setup signature scheme failed\n");
            return TC_UNDEFINE_ALGO;
        }

        rc = Tss2_Sys_Sign(sys_ctx,
                           sctx->key_index,
                           &sessionsData,
                           &digest,
                           &in_scheme,
                           &validation,
                           &signature,
                           &sessionsDataout);
    } else {
        /*
         * key_index selects an entry of the loaded-object table.
         * NOTE(review): if node_info is 0-based with `count` entries the
         * bound should be `>=`; `>` is kept because the indexing base is
         * not visible here. The NULL checks at least reject empty slots.
         */
        if (tc_handle_ctx->handle.tc_object == NULL ||
            sctx->key_index > tc_handle_ctx->handle.tc_object->count ||
            tc_handle_ctx->handle.tc_object->node_info[sctx->key_index] == NULL) {
            log_error("Invalid object index\n");
            return TC_OBJECT_INDEX;
        }

        if (setup_signature_scheme(
                tc_handle_ctx->handle.tc_object->node_info[sctx->key_index]->alg_object,
                tc_handle_ctx->handle.tc_object->node_info[sctx->key_index]->name_hash_alg,
                &in_scheme)) {
            log_error("setup signature scheme failed\n");
            return TC_UNDEFINE_ALGO;
        }

        rc = tpm_hash(sys_ctx,
                      tc_handle_ctx->handle.tc_object->node_info[sctx->key_index]->name_hash_alg,
                      TPM2_RH_OWNER,
                      &digest,
                      sctx->plain_text->buffer,
                      sctx->plain_text->size);
        if (rc != TSS2_RC_SUCCESS) {
            log_error("Failed to hash message\n");
            goto end;
        }

        rc = Tss2_Sys_Sign(sys_ctx,
                           tc_handle_ctx->handle.tc_object->node_info[sctx->key_index]->obj_handle,
                           &sessionsData,
                           &digest,
                           &in_scheme,
                           &validation,
                           &signature,
                           &sessionsDataout);
    }

    /* On failure `signature` was never written; do not read it. */
    if (rc != TSS2_RC_SUCCESS) {
        goto end;
    }

    /*
     * NOTE(review): the signature is always read through the rsapss union
     * member. That matches RSASSA/RSAPSS, but for ECC-family schemes such as
     * SM2 the active member is TPMS_SIGNATURE_ECC, and this presumably yields
     * only signatureR (S is dropped). For HMAC it reads the wrong member.
     * Left unchanged because the encoding expected by consumers of
     * sign_text is not visible here.
     */
    sig_size = signature.signature.rsapss.sig.size;
    /* Allocate at least one byte so a NULL result always means failure. */
    sig_buf = (uint8_t*)malloc(sig_size > 0 ? sig_size : 1);
    if (sig_buf == NULL) {
        log_error("Failed to allocate the signature buffer\n");
        return TC_COMMAND_SIGN;
    }
    memcpy(sig_buf, signature.signature.rsapss.sig.buffer, sig_size);
    sctx->sign_text->buffer = sig_buf;
    sctx->sign_text->size = sig_size;
end:
    /* Map any TSS/TPM failure onto this API's sign error code. */
    if (rc != TSS2_RC_SUCCESS) {
        log_error("Failed to run api_sign:0x%0x\n", rc);
        rc = TC_COMMAND_SIGN;
    }
    return rc;
}