// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd.
//
// SPDX-License-Identifier: GPL-3.0-or-later

#include "modeltasks.h"
#include "httpserver.h"
#include "embeddingproxy.h"
#include "llmproxy.h"

#include <QJsonParseError>
#include <QJsonObject>
#include <QJsonArray>
#include <QDebug>


#include <inja/inja.hpp>
#include <iostream>

GLOBAL_USE_NAMESPACE

// Creates a task bound to one HTTP request.
//   r   - runner handed on to the ModelTask base.
//   req - request body text, copied into reqBody.
//   m   - model proxy the task will operate on (pointer is stored as-is).
//   c   - HTTP context later used to publish the response (pointer is stored as-is).
HttpTask::HttpTask(ModelRunner *r, const QString &req, ModelProxy *m, HttpContext *c)
    : ModelTask(r), reqBody(req), model(m), ctx(c)
{
}

void EmbTask::doTask()
{
    EmbeddingProxy *emb = dynamic_cast<EmbeddingProxy *>(model);
    Q_ASSERT(emb);

    QJsonParseError er;
    auto doc = QJsonDocument::fromJson(reqBody.toUtf8(), &er);
    if (er.error == QJsonParseError::NoError) {
        QStringList prompts = doc.object().toVariantHash().value("input").toStringList();

        // check input size
        QString error;
        std::list<std::string> stdPrompts;
        for (int i = 0; i < prompts.size(); ++i) {
            int pmptSize = prompts.at(i).size();
            if (pmptSize > 5120) {
                error = QString("the input prompt %0 is too large: %1").arg(i).arg(pmptSize);
                std::cerr << error.toStdString() << std::endl;
                break;
            }
            stdPrompts.push_back(prompts.at(i).toStdString());
        }
        if (error.isEmpty()) {
            // return json as openai embeddings api.
            QJsonObject root;
            QJsonArray arry;
            if (!prompts.isEmpty()) {
                auto tokens = emb->tokenize(stdPrompts);
                auto out = emb->embedding(tokens);

                int i = 0;
                for (auto it = out.begin(); it != out.end(); ++it) {
                    QJsonObject embObj;
                    embObj.insert("object", "embedding");
                    embObj.insert("index", i++);
                    QJsonArray embValue;
                    for (const float &v : *it)
                        embValue.append(v);
                    embObj.insert("embedding", embValue);
                    arry.append(embObj);
                }
            }

            root.insert("data", arry);
            root.insert("model", QString::fromStdString(emb->name()));
            root.insert("object", "list");
            HttpServer::setContent(ctx, QString::fromUtf8(QJsonDocument(root).toJson(QJsonDocument::Compact)));
        } else {
            HttpServer::setStatus(ctx, 403);
            QJsonObject obj;
            obj.insert("invalid_request_error", error);
            HttpServer::setContent(ctx, QString::fromUtf8(QJsonDocument(obj).toJson(QJsonDocument::Compact)));
        }
    } else {
        HttpServer::setStatus(ctx, 403);
        QJsonObject obj;
        obj.insert("invalid_request_error", "Invalid input content");
        HttpServer::setContent(ctx, QString::fromUtf8(QJsonDocument(obj).toJson(QJsonDocument::Compact)));
    }
}

// Creates a chat-completions task.
//   r    - runner handed on to the ModelTask base.
//   json - already-parsed request object, copied into root.
//   m    - model proxy the task will operate on (pointer is stored as-is).
//   c    - HTTP context later used to publish the response (pointer is stored as-is).
ChatCompletionsTask::ChatCompletionsTask(ModelRunner *r, const QVariantHash &json, ModelProxy *m, HttpContext *c)
    : ModelTask(r), root(json), model(m), ctx(c)
{
}

void ChatCompletionsTask::doTask()
{
    LLMProxy *llm = dynamic_cast<LLMProxy *>(model);
    Q_ASSERT(llm);

    QString prompt = formatPrompt();
    auto tokens = llm->tokenize(prompt.toStdString());
    auto token = llm->generate(tokens, {}, nullptr, nullptr);
    QString content = QString::fromStdString(llm->detokenize({token}));
    QString templateStr = QString(R"({"choices":[{"index":0, "message":{"role":"assistant","content":"%0"}}]})")
            .arg(content);
    HttpServer::setContent(ctx, templateStr);
}

// Converts the request's "messages" list into the prompt text fed to the model.
//
// When the runner has no chat template, the "content" of every message is
// emitted on its own line. Otherwise messages are grouped into rounds of
// (system, prompt, response) according to their "role" (compared
// case-insensitively; unknown roles are ignored), and the template is rendered
// once per round with the variables "System", "Prompt" and "Response".
//
// Returns the concatenated text, or an empty string if rendering throws.
QString ChatCompletionsTask::formatPrompt()
{
    const auto messages = root.value("messages").value<QVariantList>();
    const std::string tmplText = runner->chatTmpl.toStdString();
    QString result = "";

    // No template: plain newline-terminated concatenation of the contents.
    if (tmplText.empty()) {
        for (const QVariant &entry : messages) {
            const QString content = entry.value<QVariantMap>().value("content").toString();
            result.append(QString("%0\n").arg(content));
        }
        return result;
    }

    struct Round {
        std::string system;
        std::string prompt;
        std::string response;

        // True as soon as any of the three slots has been filled.
        bool hasAny() const { return !(system.empty() && prompt.empty() && response.empty()); }
    };

    QList<Round> rounds;
    Round current;

    // Closes the round being built and starts a fresh one.
    auto flush = [&rounds, &current]() {
        rounds.append(current);
        current = Round();
    };

    for (const QVariant &entry : messages) {
        const auto fields = entry.value<QVariantMap>();
        const QString role = fields.value("role").toString().toLower();
        const QString content = fields.value("content").toString();

        if (role == "user") {
            // A new user turn ends a round that already has a prompt or a reply.
            if (!current.prompt.empty() || !current.response.empty())
                flush();
            current.prompt = content.toStdString();
        } else if (role == "assistant") {
            // A second reply in a row starts its own round.
            if (!current.response.empty())
                flush();
            current.response = content.toStdString();
        } else if (role == "system") {
            // A system message always opens a round unless the current one is untouched.
            if (current.hasAny())
                flush();
            current.system = content.toStdString();
        }
    }

    // the final round, if anything was collected
    if (current.hasAny())
        flush();

    try {
        inja::Environment env;
        // add a space for line statement "##" to enable "###"
        env.set_line_statement(inja::LexerConfig().line_statement + " ");
        for (const Round &round : rounds) {
            inja::json vars;
            vars["System"] = round.system;
            vars["Prompt"] = round.prompt;
            vars["Response"] = round.response;

            result.append(QString::fromStdString(env.render(tmplText, vars)));
        }
    } catch (const std::exception &error) {
        std::cerr << "fail to format prompt, please check the template file of model " << error.what() << std::endl;
        return "";
    }

    return result;
}

// Out-of-line destructor; no explicit cleanup is performed here.
ChatStreamTask::~ChatStreamTask() = default;

// Runs a streaming chat completion.
//
// Generated text fragments are appended to `text` under `genMtx`, and `con`
// is notified after every fragment. When generation ends (or when there is
// no prompt to run) `stop` is set under the mutex and `con` is notified one
// last time. Presumably another thread waits on `con` to drain `text` -
// that consumer is not visible in this file.
void ChatStreamTask::doTask()
{
    LLMProxy *llm = dynamic_cast<LLMProxy *>(model);
    Q_ASSERT(llm);

    QString prompt = formatPrompt();
    if (prompt.isEmpty()) {
        // Nothing to generate: flag completion and wake anyone waiting on con.
        {
            QMutexLocker lk(&genMtx);
            stop = true;
        }
        con.notify_all();
        return;
    }

    auto tokens = llm->tokenize(prompt.toStdString());

    // Per-fragment callback handed to generate(); `user` carries `this`.
    // Returning false asks the generator to stop.
    auto append = [](const std::string &text, void *user) {
        ChatStreamTask *self = static_cast<ChatStreamTask *>(user);
        bool keepGoing = false;
        {
            QMutexLocker lk(&self->genMtx);
            self->text.append(QString::fromStdString(text));
            // Read stop while still holding genMtx: every write to it in this
            // file happens under the same mutex, so reading it after unlocking
            // would race with those writes.
            keepGoing = !self->stop;
        }

        self->con.notify_all();
        return keepGoing;
    };

    // `*append` converts the capture-less lambda to a plain function.
    llm->generate(tokens, {}, *append, this);

    // Generation finished: flag completion and wake anyone waiting on con.
    {
        QMutexLocker lk(&genMtx);
        stop = true;
    }

    con.notify_all();
}
