From 51d681e1597e76d5b888d4ce9e59afe2e1715cbb Mon Sep 17 00:00:00 2001 From: leejet Date: Sun, 12 Apr 2026 23:44:43 +0800 Subject: [PATCH] add GemmaTokenizer --- examples/cli/CMakeLists.txt | 4 + examples/cli/main.cpp | 10 ++ src/llm.hpp | 323 ++++++++++++++++++++++++++++-------- src/vocab/gemma_merges.hpp | Bin 0 -> 70887168 bytes src/vocab/gemma_vocab.hpp | Bin 0 -> 60368288 bytes src/vocab/vocab.cpp | 12 ++ src/vocab/vocab.h | 2 + 7 files changed, 281 insertions(+), 70 deletions(-) create mode 100644 src/vocab/gemma_merges.hpp create mode 100644 src/vocab/gemma_vocab.hpp diff --git a/examples/cli/CMakeLists.txt b/examples/cli/CMakeLists.txt index db1f4ca3..cbc485bb 100644 --- a/examples/cli/CMakeLists.txt +++ b/examples/cli/CMakeLists.txt @@ -7,6 +7,10 @@ add_executable(${TARGET} image_metadata.cpp main.cpp ) +target_include_directories(${TARGET} PRIVATE + "${CMAKE_CURRENT_SOURCE_DIR}/.." + "${PROJECT_SOURCE_DIR}/src" +) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE stable-diffusion zip ${CMAKE_THREAD_LIBS_INIT}) if(SD_WEBP) diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index a5b0037b..0a8063f4 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -19,6 +19,7 @@ #include "common/media_io.h" #include "common/resource_owners.hpp" #include "image_metadata.h" +#include "llm.hpp" namespace fs = std::filesystem; @@ -498,6 +499,15 @@ int main(int argc, const char* argv[]) { SDContextParams ctx_params; SDGenerationParams gen_params; + cli_params.verbose = true; + sd_set_log_callback(sd_log_cb, (void*)&cli_params); + LLM::GemmaTokenizer tokenizer; + auto tokens = tokenizer.tokenize(" 一只可爱的小猫"); + for (auto token : tokens) { + LOG_INFO("%d", token); + } + return 0; + parse_args(argc, argv, cli_params, ctx_params, gen_params); sd_set_log_callback(sd_log_cb, (void*)&cli_params); log_verbose = cli_params.verbose; diff --git a/src/llm.hpp b/src/llm.hpp index c6c29614..17743396 100644 --- a/src/llm.hpp +++ b/src/llm.hpp @@ -47,24 +47,12 @@ namespace LLM { std::vector special_tokens; - bool add_bos_token = false; + bool add_bos_token = false; + bool byte_level_bpe = true; + bool byte_fallback = false; protected: - static std::string strip(const std::string& str) { - std::string::size_type start = str.find_first_not_of(" \t\n\r\v\f"); - std::string::size_type end = str.find_last_not_of(" \t\n\r\v\f"); - - if (start == std::string::npos) { - // String contains only whitespace characters - return ""; - } - - return str.substr(start, end - start + 1); - } - - static std::string whitespace_clean(std::string text) { - text = std::regex_replace(text, std::regex(R"(\s+)"), " "); - text = strip(text); + virtual std::string preprocess(const std::string& text) const { return text; } @@ -92,6 +80,22 @@ namespace LLM { return false; } + static std::vector split_utf32(const std::u32string& s, char32_t delim) { + std::vector result; + size_t start = 0; + + while (true) { + size_t pos = s.find(delim, start); + if (pos == std::u32string::npos) { + result.emplace_back(s.substr(start)); + break; + } + result.emplace_back(s.substr(start, pos - start)); + start = pos + 1; + } + return result; + } + public: BPETokenizer() = default; @@ -208,7 +212,7 @@ namespace LLM { } } - std::vector encode(std::string text, on_new_token_cb_t on_new_token_cb = nullptr) { + virtual std::vector encode(std::string text, on_new_token_cb_t on_new_token_cb = nullptr) { std::string original_text = text; std::vector bpe_tokens; std::vector token_strs; @@ -230,25 +234,43 @@ namespace LLM { } } - std::string token_str = token; + std::string token_str = preprocess(token); std::u32string utf32_token; - for (int i = 0; i < token_str.length(); i++) { - unsigned char b = token_str[i]; - utf32_token += byte_encoder[b]; + if (byte_level_bpe) { + for (int i = 0; i < token_str.length(); i++) { + unsigned char b = token_str[i]; + utf32_token += byte_encoder[b]; + } + } else { + utf32_token = utf8_to_utf32(token_str); } - auto bpe_strs = bpe(utf32_token); - size_t start = 0; - size_t pos; - while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) { - auto bpe_str = bpe_strs.substr(start, pos - start); - bpe_tokens.push_back(encoder[bpe_str]); - token_strs.push_back(utf32_to_utf8(bpe_str)); - start = pos + 1; + auto bpe_strs = bpe(utf32_token); + for (const auto& bpe_str : split_utf32(bpe_strs, U' ')) { + int token_id; + auto iter = encoder.find(bpe_str); + if (iter != encoder.end()) { + token_id = iter->second; + } else { + if (byte_fallback) { + auto utf8_token_str = utf32_to_utf8(bpe_str); + for (int i = 0; i < utf8_token_str.length(); i++) { + unsigned char b = utf8_token_str[i]; + char hex_buf[16]; + snprintf(hex_buf, sizeof(hex_buf), "<0x%02X>", b); + iter = encoder.find(utf8_to_utf32(hex_buf)); + GGML_ASSERT(iter != encoder.end()); + bpe_tokens.push_back(token_id); + token_strs.push_back(hex_buf); + } + continue; + } else { + token_id = UNK_TOKEN_ID; + } + } + bpe_tokens.push_back(token_id); + token_strs.push_back(utf32_to_utf8(bpe_str)); } - auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start); - bpe_tokens.push_back(encoder[bpe_str]); - token_strs.push_back(utf32_to_utf8(bpe_str)); } } @@ -259,7 +281,6 @@ namespace LLM { } ss << "]"; LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str()); - // printf("split prompt \"%s\" to tokens %s \n", original_text.c_str(), ss.str().c_str()); return bpe_tokens; } }; @@ -268,35 +289,18 @@ namespace LLM { protected: void load_from_merges(const std::string& merges_utf8_str) { auto byte_unicode_pairs = bytes_to_unicode(); - // printf("byte_unicode_pairs have %lu pairs \n", byte_unicode_pairs.size()); - byte_encoder = std::map(byte_unicode_pairs.begin(), byte_unicode_pairs.end()); + byte_encoder = std::map(byte_unicode_pairs.begin(), byte_unicode_pairs.end()); for (auto& pair : byte_unicode_pairs) { byte_decoder[pair.second] = pair.first; } - // for (auto & pair: byte_unicode_pairs) { - // std::cout << pair.first << ": " << pair.second << std::endl; - // } - std::vector merges; - size_t start = 0; - size_t pos; - std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str); - while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) { - merges.push_back(merges_utf32_str.substr(start, pos - start)); - start = pos + 1; - } - LOG_DEBUG("merges size %llu", merges.size()); - merges = std::vector(merges.begin(), merges.end()); + std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str); + std::vector merges = split_utf32(merges_utf32_str, U'\n'); std::vector> merge_pairs; - // int print_num = 10; for (const auto& merge : merges) { size_t space_pos = merge.find(' '); merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1)); - // if (print_num > 0) { - // print_num--; - // printf("%s :: %s | %s \n", utf32_to_utf8(merge).c_str(), utf32_to_utf8(merge.substr(0, space_pos)).c_str(), - // utf32_to_utf8(merge.substr(space_pos + 1)).c_str()); - // } } + LOG_DEBUG("merges size %zu", merge_pairs.size()); std::vector tokens; for (const auto& pair : byte_unicode_pairs) { @@ -396,27 +400,14 @@ namespace LLM { for (auto& pair : byte_unicode_pairs) { byte_decoder[pair.second] = pair.first; } - std::vector merges; - size_t start = 0; - size_t pos; - std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str); - while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) { - merges.push_back(merges_utf32_str.substr(start, pos - start)); - start = pos + 1; - } - LOG_DEBUG("merges size %llu", merges.size()); - merges = std::vector(merges.begin(), merges.end()); + std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str); + std::vector merges = split_utf32(merges_utf32_str, U'\n'); std::vector> merge_pairs; - // int print_num = 10; for (const auto& merge : merges) { size_t space_pos = merge.find(' '); merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1)); - // if (print_num > 0) { - // print_num--; - // printf("%s :: %s | %s \n", utf32_to_utf8(merge).c_str(), utf32_to_utf8(merge.substr(0, space_pos)).c_str(), - // utf32_to_utf8(merge.substr(space_pos + 1)).c_str()); - // } } + LOG_DEBUG("merges size %zu", merge_pairs.size()); int rank = 0; for (const auto& merge : merge_pairs) { @@ -473,6 +464,198 @@ namespace LLM { } }; + class GemmaTokenizer : public BPETokenizer { + protected: + std::vector special_tokens_before_merge; + std::vector special_tokens_after_merge; + + std::string preprocess(const std::string& text) const override { + std::string normalized = text; + size_t pos = 0; + while ((pos = normalized.find(' ', pos)) != std::string::npos) { + normalized.replace(pos, 1, "\xE2\x96\x81"); + pos += 3; + } + return normalized; + } + + void load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) { + nlohmann::json vocab; + try { + vocab = nlohmann::json::parse(vocab_utf8_str); + } catch (const nlohmann::json::parse_error&) { + GGML_ABORT("invalid vocab json str"); + } + for (const auto& [key, value] : vocab.items()) { + std::u32string token = utf8_to_utf32(key); + int i = value; + encoder[token] = i; + decoder[i] = token; + } + encoder_len = static_cast(vocab.size()); + LOG_DEBUG("vocab size: %d", encoder_len); + + std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str); + std::vector merges = split_utf32(merges_utf32_str, U'\n'); + std::vector> merge_pairs; + for (const auto& merge : merges) { + size_t space_pos = merge.find(' '); + merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1)); + } + LOG_DEBUG("merges size %zu", merge_pairs.size()); + + int rank = 0; + for (const auto& merge : merge_pairs) { + bpe_ranks[merge] = rank++; + } + bpe_len = rank; + }; + + public: + explicit GemmaTokenizer(const std::string& merges_utf8_str = "", const std::string& vocab_json_utf8_str = "") { + byte_level_bpe = false; + byte_fallback = true; + add_bos_token = true; + PAD_TOKEN = ""; + EOS_TOKEN = ""; + BOS_TOKEN = ""; + UNK_TOKEN = ""; + + PAD_TOKEN_ID = 0; + EOS_TOKEN_ID = 1; + BOS_TOKEN_ID = 2; + UNK_TOKEN_ID = 3; + + special_tokens_before_merge = { + PAD_TOKEN, + EOS_TOKEN, + BOS_TOKEN, + UNK_TOKEN, + "", + "[multimodal]", + }; + for (int i = 0; i <= 98; i++) { + special_tokens_before_merge.push_back(""); + } + special_tokens_before_merge.push_back(""); + special_tokens_before_merge.push_back(""); + for (int i = 1; i <= 31; i++) { + special_tokens_before_merge.push_back(std::string(i, '\n')); + } + for (int i = 2; i <= 31; i++) { + std::string whitespace_token; + for (int j = 0; j < i; j++) { + whitespace_token += "\xE2\x96\x81"; + } + special_tokens_before_merge.push_back(whitespace_token); + } + std::vector html_tokens = { + "", + "", + "", + "", + "", + "
", + "
", + "", + "
", + "", + "", + "", + "", + "", + "", + "", + "

", + "

", + "

", + "

", + "

", + "
", + "
", + "
", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "