// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd.
//
// SPDX-License-Identifier: GPL-3.0-or-later

#include "ovembproxy.h"

#include <algorithm>
#include <cstring>
#include <exception>
#include <stdexcept>
#include <utility>
#include <vector>

GLOBAL_USE_NAMESPACE

/**
 * Builds an embedding proxy identified by @p name whose models are loaded
 * through @p core (see initialize()).
 *
 * @param name  model name later reported by name()
 * @param core  OpenVINO core used to read/compile models; must be non-null
 *              (checked only in debug builds).
 * NOTE(review): ownership/lifetime of @p core depends on the declared type of
 * ovCore in the header — confirm whether the caller must keep it alive.
 */
OVEmbProxy::OVEmbProxy(const std::string &name, ov::Core *core)
    : EmbeddingProxy(),
      modelName(name),
      ovCore(core)
{
    // Debug-only sanity check; Q_ASSERT expands to nothing in release builds.
    Q_ASSERT(core);
}

// No explicit cleanup is performed here; members are destroyed implicitly in
// reverse declaration order.
OVEmbProxy::~OVEmbProxy() = default;

/// @return a copy of the model name supplied to the constructor.
std::string OVEmbProxy::name() const
{
    const std::string &storedName = modelName;
    return storedName;
}

std::list<std::vector<int32_t>> OVEmbProxy::tokenize(const std::list<std::string> &prompt, const std::map<std::string, std::string> &params)
{
    auto tokenizer = tokenizerModel.create_infer_request();

    std::list<std::vector<int32_t>> out;
    // todo 并发多batch
    for (std::string pmpt : prompt) {
        tokenizer.set_input_tensor(ov::Tensor{ov::element::string, {1}, &pmpt});
        tokenizer.infer();

        auto tensor = tokenizer.get_output_tensor(0);
        tokenizer.reset_state();
        {
            int batch = tensor.get_shape().at(0);
            int dim = tensor.get_shape().at(1);
            // 每个batch的字节数
            const int stride = tensor.get_strides().at(0);
            for (int i = 0; i < batch; ++i) {
                std::vector<int32_t> tmp;
                tmp.resize(dim);
                char *begin = (char *)tensor.data();
                memcpy(tmp.data(), begin + i * stride, stride);
                out.push_back(tmp);
            }
        }
    }

    return out;
}

std::list<std::vector<float>> OVEmbProxy::embedding(const std::list<std::vector<int32_t>> &tokens, const std::map<std::string, std::string> &params)
{
    std::list<std::vector<float>> out;
    auto model = embModel.create_infer_request();
    for (const std::vector<int32_t> &token : tokens) {
        int32_t *dataPtr = const_cast<int32_t *>(token.data());
        ov::Tensor input_ids = ov::Tensor(ov::element::i32, ov::Shape{1, token.size()}, dataPtr);
        model.set_input_tensor(0, input_ids);

        //构建mask
        QVector<int32_t> mask(token.size(), 1);
        model.set_input_tensor(1, ov::Tensor(ov::element::i32, ov::Shape{1, mask.size()}, mask.data()));

        model.infer();
        const ov::Tensor& output_tensor = model.get_output_tensor();

        {
            int batch = output_tensor.get_shape().at(0);
            int dim = output_tensor.get_shape().at(1);
            const int stride = output_tensor.get_strides().at(0);
            for (int i = 0; i < batch; ++i) {
                std::vector<float> tmp;
                tmp.resize(dim);
                char *begin = (char *)output_tensor.data();
                memcpy(tmp.data(), begin + i * stride, stride);
                out.push_back(tmp);
            }
        }
    }

    return out;
}

/**
 * Loads the openvino_tokenizers extension, then reads and compiles the
 * embedding model and the tokenizer model on the "AUTO" device.
 *
 * @param model      path to the embedding model; must be non-empty.
 * @param tokenizer  path to the tokenizer model; must be non-empty.
 * @param params     currently unused.
 * @return true if both models were compiled; false if a path is empty, no core
 *         was supplied, or OpenVINO reported an error. On failure the previously
 *         compiled models (if any) are left untouched.
 */
bool OVEmbProxy::initialize(const QString &model, const QString &tokenizer, const QVariantHash &params)
{
    if (model.isEmpty() || tokenizer.isEmpty())
        return false;

    // The constructor's Q_ASSERT is compiled out in release builds.
    if (!ovCore)
        return false;

    try {
        ovCore->add_extension("libopenvino_tokenizers.so");

        // Compile into locals first so a failure on the second model cannot leave
        // the proxy holding a new embedding model with a stale tokenizer.
        auto compiledEmb = ovCore->compile_model(ovCore->read_model(model.toStdString()), "AUTO");
        auto compiledTok = ovCore->compile_model(ovCore->read_model(tokenizer.toStdString()), "AUTO");

        embModel = std::move(compiledEmb);
        tokenizerModel = std::move(compiledTok);
    } catch (const std::exception &) {
        // ov::Exception derives from std::runtime_error; report failure through
        // the bool result instead of letting it escape this bool-returning API.
        return false;
    }
    return true;
}
