From d1b4aba07815fd84d5017b05a90127e0bc903268 Mon Sep 17 00:00:00 2001 From: ZHEQIUSHUI Date: Fri, 26 Apr 2024 16:27:13 +0800 Subject: [PATCH] remove useless code --- CMakeLists.txt | 3 +- src/runner/Tokenizer/Tokenizer.cpp | 128 +++--- src/runner/Tokenizer/chatglm.cpp | 633 ----------------------------- src/runner/Tokenizer/chatglm.h | 258 ------------ 4 files changed, 65 insertions(+), 957 deletions(-) delete mode 100644 src/runner/Tokenizer/chatglm.cpp delete mode 100644 src/runner/Tokenizer/chatglm.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 2afc6f6..4eeb0bb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -59,8 +59,7 @@ function(build_exec name main_source) src/runner/utils/memory_utils.cpp src/runner/utils/cqdm.cpp src/runner/Tokenizer/Tokenizer.cpp - src/runner/Tokenizer/QwenTokenizer.cpp - src/runner/Tokenizer/chatglm.cpp) + src/runner/Tokenizer/QwenTokenizer.cpp) target_link_libraries(${name} ax_engine ax_interpreter ax_sys ax_ivps) target_link_libraries(${name} sentencepiece re2::re2) diff --git a/src/runner/Tokenizer/Tokenizer.cpp b/src/runner/Tokenizer/Tokenizer.cpp index 5b7b61f..b5dbf4a 100644 --- a/src/runner/Tokenizer/Tokenizer.cpp +++ b/src/runner/Tokenizer/Tokenizer.cpp @@ -5,7 +5,7 @@ #include "QwenTokenizer.hpp" -#include "chatglm.h" +// #include "chatglm.h" #include "httplib.h" #include "json.hpp" @@ -231,69 +231,69 @@ class TokenizerQwen : public BaseTokenizer } }; -class TokenizerGLM3 : public BaseTokenizer -{ - std::shared_ptr sp; - bool _b_bos, _b_eos; - -private: - /* data */ -public: - bool Init(std::string model_path, bool b_bos = true, bool b_eos = false) override - { - if (!file_exist(model_path)) - { - ALOGE("tokenizer model file(%s) not exist", model_path.c_str()); - return false; - } - // std::vector sp_model_data; - // read_file(model_path, sp_model_data); - // std::string_view serialized_model_proto(sp_model_data.data(), sp_model_data.size()); - - sp.reset(new chatglm::ChatGLM3Tokenizer(model_path)); - - this->_b_bos = b_bos; - this->_b_eos = b_eos; - return true; - } - - bool Encode(std::string input, std::vector &output) override - { - if (_b_bos) - { - // input += "<|im_start|>"; - } - if (_b_eos) - { - // input += "<|endoftext|>"; - } - output = sp->encode(input, 1024); - - return true; - } - - std::vector Encode(std::string input) override - { - std::vector output; - Encode(input, output); - return output; - } - - std::string Decode(const std::vector input) override - { - return sp->decode(input); - } - - int GetBosID() override - { - return sp->sp.bos_id(); - } - - int GetEosID() override - { - return sp->sp.eos_id(); - } -}; +// class TokenizerGLM3 : public BaseTokenizer +// { +// std::shared_ptr sp; +// bool _b_bos, _b_eos; + +// private: +// /* data */ +// public: +// bool Init(std::string model_path, bool b_bos = true, bool b_eos = false) override +// { +// if (!file_exist(model_path)) +// { +// ALOGE("tokenizer model file(%s) not exist", model_path.c_str()); +// return false; +// } +// // std::vector sp_model_data; +// // read_file(model_path, sp_model_data); +// // std::string_view serialized_model_proto(sp_model_data.data(), sp_model_data.size()); + +// sp.reset(new chatglm::ChatGLM3Tokenizer(model_path)); + +// this->_b_bos = b_bos; +// this->_b_eos = b_eos; +// return true; +// } + +// bool Encode(std::string input, std::vector &output) override +// { +// if (_b_bos) +// { +// // input += "<|im_start|>"; +// } +// if (_b_eos) +// { +// // input += "<|endoftext|>"; +// } +// output = sp->encode(input, 1024); + +// return true; +// } + +// std::vector Encode(std::string input) override +// { +// std::vector output; +// Encode(input, output); +// return output; +// } + +// std::string Decode(const std::vector input) override +// { +// return sp->decode(input); +// } + +// int GetBosID() override +// { +// return sp->sp.bos_id(); +// } + +// int GetEosID() override +// { +// return sp->sp.eos_id(); +// } +// }; class Tokenizer_Http : public BaseTokenizer { diff --git a/src/runner/Tokenizer/chatglm.cpp b/src/runner/Tokenizer/chatglm.cpp deleted file mode 100644 index 40b25f9..0000000 --- a/src/runner/Tokenizer/chatglm.cpp +++ /dev/null @@ -1,633 +0,0 @@ -#include "chatglm.h" -#include "../utils/sample_log.h" - -#include -#include -namespace chatglm -{ - - class LogMessageFatal - { - public: - LogMessageFatal(const char *file, int line) { oss_ << file << ':' << line << ' '; } - [[noreturn]] ~LogMessageFatal() noexcept(false) { throw std::runtime_error(oss_.str()); } - std::ostringstream &stream() { return oss_; } - - private: - std::ostringstream oss_; - }; - -#define CHATGLM_THROW chatglm::LogMessageFatal(__FILE__, __LINE__).stream() -#define CHATGLM_CHECK(cond) \ - if (!(cond)) \ - CHATGLM_THROW << "check failed (" #cond ") " - - const std::string ToolCallMessage::TYPE_FUNCTION = "function"; - const std::string ToolCallMessage::TYPE_CODE = "code"; - - const std::string ChatMessage::ROLE_USER = "user"; - const std::string ChatMessage::ROLE_ASSISTANT = "assistant"; - const std::string ChatMessage::ROLE_SYSTEM = "system"; - const std::string ChatMessage::ROLE_OBSERVATION = "observation"; - - // trim from start (in place) - static inline void ltrim(std::string &s) - { - s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) - { return !std::isspace(ch); })); - } - - // trim from end (in place) - static inline void rtrim(std::string &s) - { - s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) - { return !std::isspace(ch); }) - .base(), - s.end()); - } - - // trim from both ends (in place) - static inline void trim(std::string &s) - { - rtrim(s); - ltrim(s); - } - - void BaseTokenizer::check_chat_messages(const std::vector &messages) - { - CHATGLM_CHECK(messages.size() % 2 == 1) << "invalid chat messages size " << messages.size(); - for (size_t i = 0; i < messages.size(); i++) - { - const std::string &target_role = (i % 2 == 0) ? ChatMessage::ROLE_USER : ChatMessage::ROLE_ASSISTANT; - CHATGLM_CHECK(messages[i].role == target_role) - << "expect messages[" << i << "].role to be " << target_role << ", but got " << messages[i].role; - } - } - - // ===== ChatGLM-6B ===== - - ChatGLMTokenizer::ChatGLMTokenizer(std::string_view serialized_model_proto) - { - const auto status = sp.LoadFromSerializedProto(serialized_model_proto); - CHATGLM_CHECK(status.ok()) << status.ToString(); - - bos_token_id = sp.PieceToId(""); - eos_token_id = sp.PieceToId(""); - mask_token_id = sp.PieceToId("[MASK]"); - gmask_token_id = sp.PieceToId("[gMASK]"); - pad_token_id = sp.PieceToId(""); - } - - std::vector ChatGLMTokenizer::encode(const std::string &text, int max_length) const - { - std::string input = preprocess(text); - std::vector ids; - sp.Encode(input, &ids); - ids.insert(ids.end(), {gmask_token_id, bos_token_id}); - if ((int)ids.size() > max_length) - { - // sliding window: always take the last max_length tokens - ids.erase(ids.begin(), ids.end() - max_length); - } - return ids; - } - - std::vector ChatGLMTokenizer::encode_messages(const std::vector &messages, int max_length) const - { - std::string prompt = build_prompt(messages); - std::vector input_ids = encode(prompt, max_length); - return input_ids; - } - - std::string ChatGLMTokenizer::build_prompt(const std::vector &messages) - { - check_chat_messages(messages); - - std::ostringstream oss_prompt; - if (messages.size() == 1) - { - oss_prompt << messages.front().content; - } - else - { - for (size_t i = 0; i < messages.size(); i += 2) - { - oss_prompt << "[Round " << i / 2 << "]\n问:" << messages[i].content << "\n答:"; - if (i + 1 < messages.size()) - { - oss_prompt << messages[i + 1].content << "\n"; - } - } - } - return oss_prompt.str(); - } - - std::string ChatGLMTokenizer::decode(const std::vector &ids) const - { - std::string text; - sp.Decode(ids, &text); - text = postprocess(text); - return text; - } - - static std::string regex_replace(const std::string &input, const std::regex ®ex, - std::function format) - { - std::ostringstream oss; - int last_index = 0; - for (auto it = std::sregex_iterator(input.begin(), input.end(), regex); it != std::sregex_iterator(); it++) - { - oss << it->prefix() << format(*it); - last_index = it->position() + it->length(); - } - oss << input.substr(last_index); - return oss.str(); - } - - std::string ChatGLMTokenizer::preprocess(const std::string &text) - { - std::string output; - - // newline token - { - static const std::regex newline_regex("\n"); - output = std::regex_replace(text, newline_regex, ""); - } - // tab token - { - static const std::regex tab_regex("\t"); - output = std::regex_replace(output, tab_regex, "<|tab|>"); - } - // blank tokens - { - static const std::regex pattern(R"([ ]{2,80})"); - output = regex_replace(output, pattern, [](const std::smatch &sm) - { - std::ostringstream oss; - oss << "<|blank_" << sm.str().size() << "|>"; - return oss.str(); }); - } - - return output; - } - - static inline std::string replace_punctuations(const std::string &text) - { - // reference: https://stackoverflow.com/questions/37989081/how-to-use-unicode-range-in-c-regex - static std::wstring_convert> converter; - static const std::vector> punct_map{ - {std::wregex(converter.from_bytes(R"(([\u4e00-\u9fff]),)")), converter.from_bytes("$1,")}, - {std::wregex(converter.from_bytes(R"(,([\u4e00-\u9fff]))")), converter.from_bytes(",$1")}, - {std::wregex(converter.from_bytes(R"(([\u4e00-\u9fff])!)")), converter.from_bytes("$1!")}, - {std::wregex(converter.from_bytes(R"(!([\u4e00-\u9fff]))")), converter.from_bytes("!$1")}, - {std::wregex(converter.from_bytes(R"(([\u4e00-\u9fff]):)")), converter.from_bytes("$1:")}, - {std::wregex(converter.from_bytes(R"(:([\u4e00-\u9fff]))")), converter.from_bytes(":$1")}, - {std::wregex(converter.from_bytes(R"(([\u4e00-\u9fff]);)")), converter.from_bytes("$1;")}, - {std::wregex(converter.from_bytes(R"(;([\u4e00-\u9fff]))")), converter.from_bytes(";$1")}, - {std::wregex(converter.from_bytes(R"(([\u4e00-\u9fff])\?)")), converter.from_bytes("$1?")}, - {std::wregex(converter.from_bytes(R"(\?([\u4e00-\u9fff]))")), converter.from_bytes("?$1")}, - }; - std::wstring w_output = converter.from_bytes(text); - for (const auto &punct_pair : punct_map) - { - w_output = std::regex_replace(w_output, punct_pair.first, punct_pair.second); - } - std::string output = converter.to_bytes(w_output); - return output; - } - - std::string ChatGLMTokenizer::postprocess(const std::string &text) - { - std::string output; - - // newline token - { - static const std::regex pattern(R"()"); - output = std::regex_replace(text, pattern, "\n"); - } - // tab token - { - static const std::regex pattern(R"(<\|tab\|>)"); - output = std::regex_replace(output, pattern, "\t"); - } - // blank tokens - { - static const std::regex pattern(R"(<\|blank_(\d+)\|>)"); - output = regex_replace(output, pattern, - [](const std::smatch &sm) - { return std::string(std::stoi(sm[1].str()), ' '); }); - } - // punctuations - output = replace_punctuations(output); - - return output; - } - - // ===== ChatGLM2-6B ===== - - ChatGLM2Tokenizer::ChatGLM2Tokenizer(std::string_view serialized_model_proto) - { - const auto status = sp.LoadFromSerializedProto(serialized_model_proto); - CHATGLM_CHECK(status.ok()) << status.ToString(); - - int special_id = sp.GetPieceSize(); - mask_token_id = special_id++; - gmask_token_id = special_id++; - smask_token_id = special_id++; - sop_token_id = special_id++; - eop_token_id = special_id++; - } - - std::vector ChatGLM2Tokenizer::encode(const std::string &text, int max_length) const - { - std::vector ids; - sp.Encode(text, &ids); - ids.insert(ids.begin(), {gmask_token_id, sop_token_id}); // special prefix - if ((int)ids.size() > max_length) - { - // sliding window: drop the least recent history while keeping the two special prefix tokens - int num_drop = (int)ids.size() - max_length; - ids.erase(ids.begin() + 2, ids.begin() + 2 + num_drop); - } - return ids; - } - - std::string ChatGLM2Tokenizer::decode(const std::vector &ids) const - { - // filter out special tokens - std::vector normal_ids(ids); - normal_ids.erase(std::remove_if(normal_ids.begin(), normal_ids.end(), [this](int id) - { return is_special_id(id); }), - normal_ids.end()); - - std::string text; - sp.Decode(normal_ids, &text); - text = replace_punctuations(text); - return text; - } - - std::vector ChatGLM2Tokenizer::encode_messages(const std::vector &messages, int max_length) const - { - std::string prompt = build_prompt(messages); - std::vector input_ids = encode(prompt, max_length); - return input_ids; - } - - std::string ChatGLM2Tokenizer::build_prompt(const std::vector &messages) - { - check_chat_messages(messages); - - std::ostringstream oss_prompt; - for (size_t i = 0; i < messages.size(); i += 2) - { - oss_prompt << "[Round " << i / 2 + 1 << "]\n\n问:" << messages[i].content << "\n\n答:"; - if (i < messages.size() - 1) - { - oss_prompt << messages[i + 1].content << "\n\n"; - } - } - return oss_prompt.str(); - } - - bool ChatGLM2Tokenizer::is_special_id(int id) const - { - return id == mask_token_id || id == gmask_token_id || id == smask_token_id || id == sop_token_id || - id == eop_token_id; - } - - // ===== ChatGLM3-6B ===== - - ChatGLM3Tokenizer::ChatGLM3Tokenizer(std::string_view serialized_model_proto) - { - const auto status = sp.LoadFromSerializedProto(serialized_model_proto); - CHATGLM_CHECK(status.ok()) << status.ToString(); - - int special_id = sp.GetPieceSize(); - mask_token_id = special_id++; - gmask_token_id = special_id++; - smask_token_id = special_id++; - sop_token_id = special_id++; - eop_token_id = special_id++; - system_token_id = special_id++; - user_token_id = special_id++; - assistant_token_id = special_id++; - observation_token_id = special_id++; - - special_tokens = { - {"[MASK]", mask_token_id}, - {"[gMASK]", gmask_token_id}, - {"[sMASK]", smask_token_id}, - {"sop", sop_token_id}, - {"eop", eop_token_id}, - {"<|system|>", system_token_id}, - {"<|user|>", user_token_id}, - {"<|assistant|>", assistant_token_id}, - {"<|observation|>", observation_token_id}, - }; - - for (const auto &item : special_tokens) - { - index_special_tokens[item.second] = item.first; - } - } - - std::vector ChatGLM3Tokenizer::encode(const std::string &text, int max_length) const - { - std::vector ids; - sp.Encode(text, &ids); - ids.insert(ids.begin(), {gmask_token_id, sop_token_id}); // special prefix - truncate(ids, max_length); - return ids; - } - - std::string ChatGLM3Tokenizer::decode(const std::vector &ids) const - { - std::string text = decode_with_special_tokens(ids); - text = remove_special_tokens(text); - return text; - } - - std::string ChatGLM3Tokenizer::decode_with_special_tokens(const std::vector &ids) const - { - std::vector pieces; - for (int id : ids) - { - auto pos = index_special_tokens.find(id); - if (pos != index_special_tokens.end()) - { - // special tokens - pieces.emplace_back(pos->second); - } - else - { - // normal tokens - pieces.emplace_back(sp.IdToPiece(id)); - } - } - - std::string text = sp.DecodePieces(pieces); - return text; - } - - std::string ChatGLM3Tokenizer::remove_special_tokens(const std::string &text) - { - std::string output = text; - static const std::vector special_token_regex{ - // std::regex(R"(<\|assistant\|> interpreter)"), - // std::regex(R"(<\|assistant\|> interpre)"), - std::regex(R"(<\|assistant\|>)"), - std::regex(R"(<\|user\|>)"), - std::regex(R"(<\|observation\|>)"), - }; - for (const auto &re : special_token_regex) - { - output = std::regex_replace(output, re, ""); - } - return output; - } - - std::vector ChatGLM3Tokenizer::encode_single_message(const std::string &role, const std::string &content) const - { - std::vector input_ids; - input_ids.emplace_back(get_command("<|" + role + "|>")); - // TODO: support metadata - std::vector newline_ids; - sp.Encode("\n", &newline_ids); - input_ids.insert(input_ids.end(), newline_ids.begin(), newline_ids.end()); - std::vector content_ids; - sp.Encode(content, &content_ids); - input_ids.insert(input_ids.end(), content_ids.begin(), content_ids.end()); - return input_ids; - } - - std::vector ChatGLM3Tokenizer::encode_messages(const std::vector &messages, int max_length) const - { - std::vector input_ids{gmask_token_id, sop_token_id}; - for (const auto &msg : messages) - { - auto msg_ids = encode_single_message(msg.role, msg.content); - input_ids.insert(input_ids.end(), msg_ids.begin(), msg_ids.end()); - - // encode code block into a separate message - if (!msg.tool_calls.empty() && msg.tool_calls.front().type == ToolCallMessage::TYPE_CODE) - { - auto code_ids = encode_single_message(msg.role, msg.tool_calls.front().code.input); - input_ids.insert(input_ids.end(), code_ids.begin(), code_ids.end()); - } - } - input_ids.emplace_back(assistant_token_id); - truncate(input_ids, max_length); - return input_ids; - } - - ChatMessage ChatGLM3Tokenizer::decode_message(const std::vector &ids) const - { - ChatMessage message; - if (!ids.empty() && ids.back() == observation_token_id) - { - // insert an <|assistant|> token before content to match possible interpreter delimiter - std::vector full_ids{assistant_token_id}; - full_ids.insert(full_ids.end(), ids.begin(), ids.end()); - - std::string output = decode_with_special_tokens(full_ids); - const std::string ci_delim = "<|assistant|> interpreter"; - size_t ci_pos = output.find(ci_delim); - if (ci_pos != std::string::npos) - { - // code interpreter - std::string chat_output = output.substr(0, ci_pos); - chat_output = remove_special_tokens(chat_output); - trim(chat_output); - std::string code_output = output.substr(ci_pos + ci_delim.size()); - code_output = remove_special_tokens(code_output); - trim(code_output); - message = ChatMessage(ChatMessage::ROLE_ASSISTANT, std::move(chat_output), - {ToolCallMessage(CodeMessage(std::move(code_output)))}); - } - else - { - // tool call - output = remove_special_tokens(output); - - // parse tool name - std::string tool_name = "PARSE_ERROR"; - size_t pos = output.find('\n'); - if (pos != std::string::npos) - { - // split tool name and args by 1st linebreak - tool_name = output.substr(0, pos); - trim(tool_name); - output.erase(0, pos + 1); - } - - // post process output - trim(output); - - // extract args - std::string tool_args = "PARSE_ERROR"; - static const std::regex args_regex(R"(```.*?\n(.*?)\n```)"); - std::smatch sm; - if (std::regex_search(output, sm, args_regex)) - { - CHATGLM_CHECK(sm.size() == 2) << "unexpected regex match results"; - tool_args = sm[1]; - } - - message = ChatMessage(ChatMessage::ROLE_ASSISTANT, std::move(output), - {ToolCallMessage(FunctionMessage(std::move(tool_name), std::move(tool_args)))}); - } - } - else - { - // conversation - message = BaseTokenizer::decode_message(ids); - trim(message.content); // strip leading linebreak in conversation mode - } - return message; - } - - int ChatGLM3Tokenizer::get_command(const std::string &token) const - { - auto pos = special_tokens.find(token); - CHATGLM_CHECK(pos != special_tokens.end()) << token << " is not a special token"; - return pos->second; - } - - bool ChatGLM3Tokenizer::is_special_id(int id) const { return index_special_tokens.count(id) > 0; } - - void ChatGLM3Tokenizer::truncate(std::vector &ids, int max_length) - { - if ((int)ids.size() > max_length) - { - // sliding window: drop the least recent history while keeping the two special prefix tokens - int num_drop = (int)ids.size() - max_length; - ids.erase(ids.begin() + 2, ids.begin() + 2 + num_drop); - } - } - - // ===== Baichuan ===== - - BaichuanTokenizer::BaichuanTokenizer(std::string_view serialized_model_proto) - { - const auto status = sp.LoadFromSerializedProto(serialized_model_proto); - CHATGLM_CHECK(status.ok()) << status.ToString(); - } - - std::vector BaichuanTokenizer::encode(const std::string &text, int max_length) const - { - std::vector ids; - sp.Encode(text, &ids); - truncate(ids, max_length); - return ids; - } - - std::string BaichuanTokenizer::decode(const std::vector &ids) const - { - std::vector normal_ids(ids); - normal_ids.erase(std::remove_if(normal_ids.begin(), normal_ids.end(), [this](int id) - { return is_special_id(id); }), - normal_ids.end()); - - std::string text; - sp.Decode(normal_ids, &text); - return text; - } - - std::vector BaichuanTokenizer::encode_messages(const std::vector &messages, int max_length) const - { - check_chat_messages(messages); - - std::vector ids; - ids.reserve(max_length); - for (const auto &msg : messages) - { - ids.push_back((msg.role == ChatMessage::ROLE_USER) ? USER_TOKEN_ID : ASSISTANT_TOKEN_ID); - std::vector content_ids = encode(msg.content, max_length); - ids.insert(ids.end(), content_ids.begin(), content_ids.end()); - } - ids.push_back(ASSISTANT_TOKEN_ID); - - truncate(ids, max_length); - return ids; - } - - bool BaichuanTokenizer::is_special_id(int id) const - { - return id == bos_token_id || id == eos_token_id || id == pad_token_id; - } - - void BaichuanTokenizer::truncate(std::vector &ids, int max_length) - { - if ((int)ids.size() > max_length) - { - ids.erase(ids.begin(), ids.end() - max_length); - } - } - - // ===== InternLM ===== - - InternLMTokenizer::InternLMTokenizer(std::string_view serialized_model_proto) - { - const auto status = sp.LoadFromSerializedProto(serialized_model_proto); - CHATGLM_CHECK(status.ok()) << status.ToString(); - } - - std::vector InternLMTokenizer::encode(const std::string &text, int max_length) const - { - std::vector ids; - sp.Encode(text, &ids); - ids.insert(ids.begin(), {bos_token_id}); // special prefix - if ((int)ids.size() > max_length) - { - // sliding window: drop the least recent history while keeping the special prefix - int num_drop = (int)ids.size() - max_length; - ids.erase(ids.begin() + 1, ids.begin() + 1 + num_drop); - } - return ids; - } - - std::string InternLMTokenizer::decode(const std::vector &ids) const - { - // filter out special tokens - std::vector normal_ids(ids); - normal_ids.erase(std::remove_if(normal_ids.begin(), normal_ids.end(), [this](int id) - { return is_special_id(id); }), - normal_ids.end()); - - std::string text; - sp.Decode(normal_ids, &text); - // remove and its following - size_t eoa_pos = text.find(""); - if (eoa_pos != std::string::npos) - { - text.erase(eoa_pos); - } - return text; - } - - std::vector InternLMTokenizer::encode_messages(const std::vector &messages, int max_length) const - { - std::string prompt = build_prompt(messages); - std::vector input_ids = encode(prompt, max_length); - return input_ids; - } - - std::string InternLMTokenizer::build_prompt(const std::vector &messages) - { - check_chat_messages(messages); - - std::ostringstream oss_prompt; - for (const auto &msg : messages) - { - if (msg.role == ChatMessage::ROLE_USER) - { - oss_prompt << "<|User|>:" << msg.content << "\n<|Bot|>:"; - } - else - { - oss_prompt << msg.content << "\n"; - } - } - return oss_prompt.str(); - } -} \ No newline at end of file diff --git a/src/runner/Tokenizer/chatglm.h b/src/runner/Tokenizer/chatglm.h deleted file mode 100644 index bcffc36..0000000 --- a/src/runner/Tokenizer/chatglm.h +++ /dev/null @@ -1,258 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - -namespace chatglm -{ - - struct FunctionMessage - { - std::string name; - std::string arguments; - - FunctionMessage() = default; - FunctionMessage(std::string name, std::string arguments) : name(std::move(name)), arguments(std::move(arguments)) {} - - friend std::ostream &operator<<(std::ostream &os, const FunctionMessage &self) - { - return os << "FunctionMessage(name=" << std::quoted(self.name) << ", arguments=" << std::quoted(self.arguments) - << ")"; - } - }; - - struct CodeMessage - { - std::string input; - - CodeMessage() = default; - CodeMessage(std::string input) : input(std::move(input)) {} - - friend std::ostream &operator<<(std::ostream &os, const CodeMessage &self) - { - return os << "CodeMessage(input=" << std::quoted(self.input) << ")"; - } - }; - - struct ToolCallMessage - { - std::string type; - FunctionMessage function; - CodeMessage code; - - static const std::string TYPE_FUNCTION; - static const std::string TYPE_CODE; - - ToolCallMessage(FunctionMessage function) : type(TYPE_FUNCTION), function(std::move(function)) {} - - ToolCallMessage(CodeMessage code) : type(TYPE_CODE), code(std::move(code)) {} - - friend std::ostream &operator<<(std::ostream &os, const ToolCallMessage &self) - { - return os << "ToolCallMessage(type=" << std::quoted(self.type) << ", function=" << self.function - << ", code=" << self.code << ")"; - } - }; - - struct ChatMessage - { - std::string role; - std::string content; - std::vector tool_calls; - - static const std::string ROLE_USER; - static const std::string ROLE_ASSISTANT; - static const std::string ROLE_SYSTEM; - static const std::string ROLE_OBSERVATION; - - ChatMessage() = default; - ChatMessage(std::string role, std::string content, std::vector tool_calls = {}) - : role(std::move(role)), content(std::move(content)), tool_calls(std::move(tool_calls)) {} - - friend std::ostream &operator<<(std::ostream &os, const ChatMessage &self) - { - os << "ChatMessage(role=" << std::quoted(self.role) << ", content=" << std::quoted(self.content) - << ", tool_calls=["; - for (size_t i = 0; i < self.tool_calls.size(); i++) - { - os << (i > 0 ? ", " : "") << self.tool_calls[i]; - } - return os << "])"; - } - }; - - class BaseTokenizer - { - public: - virtual ~BaseTokenizer() = default; - - virtual std::vector encode(const std::string &text, int max_length) const = 0; - - virtual std::string decode(const std::vector &ids) const = 0; - - virtual std::vector encode_messages(const std::vector &messages, int max_length) const = 0; - - virtual ChatMessage decode_message(const std::vector &ids) const - { - return {ChatMessage::ROLE_ASSISTANT, decode(ids)}; - } - - protected: - static void check_chat_messages(const std::vector &messages); - }; - - // ===== ChatGLM-6B ===== - - class ChatGLMTokenizer : public BaseTokenizer - { - public: - ChatGLMTokenizer(std::string_view serialized_model_proto); - - std::vector encode(const std::string &text, int max_length) const override; - - std::string decode(const std::vector &ids) const override; - - std::vector encode_messages(const std::vector &messages, int max_length) const override; - - static std::string build_prompt(const std::vector &messages); - - private: - static std::string preprocess(const std::string &text); - - static std::string postprocess(const std::string &text); - - public: - sentencepiece::SentencePieceProcessor sp; - int bos_token_id; - int eos_token_id; - int mask_token_id; - int gmask_token_id; - int pad_token_id; - }; - - // ===== ChatGLM2-6B ===== - - class ChatGLM2Tokenizer : public BaseTokenizer - { - public: - ChatGLM2Tokenizer(std::string_view serialized_model_proto); - - std::vector encode(const std::string &text, int max_length) const override; - - std::string decode(const std::vector &ids) const override; - - std::vector encode_messages(const std::vector &messages, int max_length) const override; - - static std::string build_prompt(const std::vector &messages); - - private: - bool is_special_id(int id) const; - - public: - sentencepiece::SentencePieceProcessor sp; - int mask_token_id; - int gmask_token_id; - int smask_token_id; - int sop_token_id; - int eop_token_id; - }; - - // ===== ChatGLM3-6B ===== - - class ChatGLM3Tokenizer : public BaseTokenizer - { - public: - ChatGLM3Tokenizer(std::string_view serialized_model_proto); - - std::vector encode(const std::string &text, int max_length) const override; - - std::string decode(const std::vector &ids) const override; - - std::vector encode_messages(const std::vector &messages, int max_length) const override; - - ChatMessage decode_message(const std::vector &ids) const override; - - private: - std::vector encode_single_message(const std::string &role, const std::string &content) const; - - std::string decode_with_special_tokens(const std::vector &ids) const; - - static std::string remove_special_tokens(const std::string &text); - - int get_command(const std::string &token) const; - - bool is_special_id(int id) const; - - static void truncate(std::vector &ids, int max_length); - - public: - sentencepiece::SentencePieceProcessor sp; - int mask_token_id; - int gmask_token_id; - int smask_token_id; - int sop_token_id; - int eop_token_id; - int system_token_id; - int user_token_id; - int assistant_token_id; - int observation_token_id; - std::unordered_map special_tokens; - std::unordered_map index_special_tokens; - }; - - // ===== Baichuan ===== - - class BaichuanTokenizer : public BaseTokenizer - { - public: - BaichuanTokenizer(std::string_view serialized_model_proto); - - std::vector encode(const std::string &text, int max_length) const override; - - std::string decode(const std::vector &ids) const override; - - std::vector encode_messages(const std::vector &messages, int max_length) const override; - - private: - bool is_special_id(int id) const; - - static void truncate(std::vector &ids, int max_length); - - public: - static constexpr int USER_TOKEN_ID = 195; - static constexpr int ASSISTANT_TOKEN_ID = 196; - - sentencepiece::SentencePieceProcessor sp; - int bos_token_id; - int eos_token_id; - int pad_token_id; - }; - - // ===== InternLM ===== - - class InternLMTokenizer : public BaseTokenizer - { - public: - InternLMTokenizer(std::string_view serialized_model_proto); - - std::vector encode(const std::string &text, int max_length) const override; - - std::string decode(const std::vector &ids) const override; - - std::vector encode_messages(const std::vector &messages, int max_length) const override; - - static std::string build_prompt(const std::vector &messages); - - private: - bool is_special_id(int id) const { return id == unk_token_id || id == bos_token_id || id == eos_token_id; } - - public: - sentencepiece::SentencePieceProcessor sp; - static constexpr int unk_token_id = 0; - static constexpr int bos_token_id = 1; - static constexpr int eos_token_id = 2; - }; -} // namespace chatglm \ No newline at end of file