diff --git a/examples/cli/CMakeLists.txt b/examples/cli/CMakeLists.txt index db1f4ca3..cbc485bb 100644 --- a/examples/cli/CMakeLists.txt +++ b/examples/cli/CMakeLists.txt @@ -7,6 +7,10 @@ add_executable(${TARGET} image_metadata.cpp main.cpp ) +target_include_directories(${TARGET} PRIVATE + "${CMAKE_CURRENT_SOURCE_DIR}/.." + "${PROJECT_SOURCE_DIR}/src" +) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE stable-diffusion zip ${CMAKE_THREAD_LIBS_INIT}) if(SD_WEBP) diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index a5b0037b..0a8063f4 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -19,6 +19,7 @@ #include "common/media_io.h" #include "common/resource_owners.hpp" #include "image_metadata.h" +#include "llm.hpp" namespace fs = std::filesystem; @@ -498,6 +499,15 @@ int main(int argc, const char* argv[]) { SDContextParams ctx_params; SDGenerationParams gen_params; + cli_params.verbose = true; + sd_set_log_callback(sd_log_cb, (void*)&cli_params); + LLM::GemmaTokenizer tokenizer; + auto tokens = tokenizer.tokenize(" 一只可爱的小猫"); + for (auto token : tokens) { + LOG_INFO("%d", token); + } + return 0; + parse_args(argc, argv, cli_params, ctx_params, gen_params); sd_set_log_callback(sd_log_cb, (void*)&cli_params); log_verbose = cli_params.verbose; diff --git a/src/llm.hpp b/src/llm.hpp index c6c29614..17743396 100644 --- a/src/llm.hpp +++ b/src/llm.hpp @@ -47,24 +47,12 @@ namespace LLM { std::vector special_tokens; - bool add_bos_token = false; + bool add_bos_token = false; + bool byte_level_bpe = true; + bool byte_fallback = false; protected: - static std::string strip(const std::string& str) { - std::string::size_type start = str.find_first_not_of(" \t\n\r\v\f"); - std::string::size_type end = str.find_last_not_of(" \t\n\r\v\f"); - - if (start == std::string::npos) { - // String contains only whitespace characters - return ""; - } - - return str.substr(start, end - start + 1); - } - - static std::string whitespace_clean(std::string text) { - text = std::regex_replace(text, std::regex(R"(\s+)"), " "); - text = strip(text); + virtual std::string preprocess(const std::string& text) const { return text; } @@ -92,6 +80,22 @@ namespace LLM { return false; } + static std::vector split_utf32(const std::u32string& s, char32_t delim) { + std::vector result; + size_t start = 0; + + while (true) { + size_t pos = s.find(delim, start); + if (pos == std::u32string::npos) { + result.emplace_back(s.substr(start)); + break; + } + result.emplace_back(s.substr(start, pos - start)); + start = pos + 1; + } + return result; + } + public: BPETokenizer() = default; @@ -208,7 +212,7 @@ namespace LLM { } } - std::vector encode(std::string text, on_new_token_cb_t on_new_token_cb = nullptr) { + virtual std::vector encode(std::string text, on_new_token_cb_t on_new_token_cb = nullptr) { std::string original_text = text; std::vector bpe_tokens; std::vector token_strs; @@ -230,25 +234,43 @@ namespace LLM { } } - std::string token_str = token; + std::string token_str = preprocess(token); std::u32string utf32_token; - for (int i = 0; i < token_str.length(); i++) { - unsigned char b = token_str[i]; - utf32_token += byte_encoder[b]; + if (byte_level_bpe) { + for (int i = 0; i < token_str.length(); i++) { + unsigned char b = token_str[i]; + utf32_token += byte_encoder[b]; + } + } else { + utf32_token = utf8_to_utf32(token_str); } - auto bpe_strs = bpe(utf32_token); - size_t start = 0; - size_t pos; - while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) { - auto bpe_str = bpe_strs.substr(start, pos - start); - bpe_tokens.push_back(encoder[bpe_str]); - token_strs.push_back(utf32_to_utf8(bpe_str)); - start = pos + 1; + auto bpe_strs = bpe(utf32_token); + for (const auto& bpe_str : split_utf32(bpe_strs, U' ')) { + int token_id; + auto iter = encoder.find(bpe_str); + if (iter != encoder.end()) { + token_id = iter->second; + } else { + if (byte_fallback) { + auto utf8_token_str = utf32_to_utf8(bpe_str); + for (int i = 0; i < utf8_token_str.length(); i++) { + unsigned char b = utf8_token_str[i]; + char hex_buf[16]; + snprintf(hex_buf, sizeof(hex_buf), "<0x%02X>", b); + iter = encoder.find(utf8_to_utf32(hex_buf)); + GGML_ASSERT(iter != encoder.end()); + bpe_tokens.push_back(token_id); + token_strs.push_back(hex_buf); + } + continue; + } else { + token_id = UNK_TOKEN_ID; + } + } + bpe_tokens.push_back(token_id); + token_strs.push_back(utf32_to_utf8(bpe_str)); } - auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start); - bpe_tokens.push_back(encoder[bpe_str]); - token_strs.push_back(utf32_to_utf8(bpe_str)); } } @@ -259,7 +281,6 @@ namespace LLM { } ss << "]"; LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str()); - // printf("split prompt \"%s\" to tokens %s \n", original_text.c_str(), ss.str().c_str()); return bpe_tokens; } }; @@ -268,35 +289,18 @@ namespace LLM { protected: void load_from_merges(const std::string& merges_utf8_str) { auto byte_unicode_pairs = bytes_to_unicode(); - // printf("byte_unicode_pairs have %lu pairs \n", byte_unicode_pairs.size()); - byte_encoder = std::map(byte_unicode_pairs.begin(), byte_unicode_pairs.end()); + byte_encoder = std::map(byte_unicode_pairs.begin(), byte_unicode_pairs.end()); for (auto& pair : byte_unicode_pairs) { byte_decoder[pair.second] = pair.first; } - // for (auto & pair: byte_unicode_pairs) { - // std::cout << pair.first << ": " << pair.second << std::endl; - // } - std::vector merges; - size_t start = 0; - size_t pos; - std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str); - while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) { - merges.push_back(merges_utf32_str.substr(start, pos - start)); - start = pos + 1; - } - LOG_DEBUG("merges size %llu", merges.size()); - merges = std::vector(merges.begin(), merges.end()); + std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str); + std::vector merges = split_utf32(merges_utf32_str, U'\n'); std::vector> merge_pairs; - // int print_num = 10; for (const auto& merge : merges) { size_t space_pos = merge.find(' '); merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1)); - // if (print_num > 0) { - // print_num--; - // printf("%s :: %s | %s \n", utf32_to_utf8(merge).c_str(), utf32_to_utf8(merge.substr(0, space_pos)).c_str(), - // utf32_to_utf8(merge.substr(space_pos + 1)).c_str()); - // } } + LOG_DEBUG("merges size %zu", merge_pairs.size()); std::vector tokens; for (const auto& pair : byte_unicode_pairs) { @@ -396,27 +400,14 @@ namespace LLM { for (auto& pair : byte_unicode_pairs) { byte_decoder[pair.second] = pair.first; } - std::vector merges; - size_t start = 0; - size_t pos; - std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str); - while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) { - merges.push_back(merges_utf32_str.substr(start, pos - start)); - start = pos + 1; - } - LOG_DEBUG("merges size %llu", merges.size()); - merges = std::vector(merges.begin(), merges.end()); + std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str); + std::vector merges = split_utf32(merges_utf32_str, U'\n'); std::vector> merge_pairs; - // int print_num = 10; for (const auto& merge : merges) { size_t space_pos = merge.find(' '); merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1)); - // if (print_num > 0) { - // print_num--; - // printf("%s :: %s | %s \n", utf32_to_utf8(merge).c_str(), utf32_to_utf8(merge.substr(0, space_pos)).c_str(), - // utf32_to_utf8(merge.substr(space_pos + 1)).c_str()); - // } } + LOG_DEBUG("merges size %zu", merge_pairs.size()); int rank = 0; for (const auto& merge : merge_pairs) { @@ -473,6 +464,198 @@ namespace LLM { } }; + class GemmaTokenizer : public BPETokenizer { + protected: + std::vector special_tokens_before_merge; + std::vector special_tokens_after_merge; + + std::string preprocess(const std::string& text) const override { + std::string normalized = text; + size_t pos = 0; + while ((pos = normalized.find(' ', pos)) != std::string::npos) { + normalized.replace(pos, 1, "\xE2\x96\x81"); + pos += 3; + } + return normalized; + } + + void load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) { + nlohmann::json vocab; + try { + vocab = nlohmann::json::parse(vocab_utf8_str); + } catch (const nlohmann::json::parse_error&) { + GGML_ABORT("invalid vocab json str"); + } + for (const auto& [key, value] : vocab.items()) { + std::u32string token = utf8_to_utf32(key); + int i = value; + encoder[token] = i; + decoder[i] = token; + } + encoder_len = static_cast(vocab.size()); + LOG_DEBUG("vocab size: %d", encoder_len); + + std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str); + std::vector merges = split_utf32(merges_utf32_str, U'\n'); + std::vector> merge_pairs; + for (const auto& merge : merges) { + size_t space_pos = merge.find(' '); + merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1)); + } + LOG_DEBUG("merges size %zu", merge_pairs.size()); + + int rank = 0; + for (const auto& merge : merge_pairs) { + bpe_ranks[merge] = rank++; + } + bpe_len = rank; + }; + + public: + explicit GemmaTokenizer(const std::string& merges_utf8_str = "", const std::string& vocab_json_utf8_str = "") { + byte_level_bpe = false; + byte_fallback = true; + add_bos_token = true; + PAD_TOKEN = ""; + EOS_TOKEN = ""; + BOS_TOKEN = ""; + UNK_TOKEN = ""; + + PAD_TOKEN_ID = 0; + EOS_TOKEN_ID = 1; + BOS_TOKEN_ID = 2; + UNK_TOKEN_ID = 3; + + special_tokens_before_merge = { + PAD_TOKEN, + EOS_TOKEN, + BOS_TOKEN, + UNK_TOKEN, + "", + "[multimodal]", + }; + for (int i = 0; i <= 98; i++) { + special_tokens_before_merge.push_back(""); + } + special_tokens_before_merge.push_back(""); + special_tokens_before_merge.push_back(""); + for (int i = 1; i <= 31; i++) { + special_tokens_before_merge.push_back(std::string(i, '\n')); + } + for (int i = 2; i <= 31; i++) { + std::string whitespace_token; + for (int j = 0; j < i; j++) { + whitespace_token += "\xE2\x96\x81"; + } + special_tokens_before_merge.push_back(whitespace_token); + } + std::vector html_tokens = { + "", + "", + "", + "", + "", + "
", + "
", + "", + "
", + "", + "", + "", + "", + "", + "", + "", + "

", + "

", + "

", + "

", + "

", + "
", + "
", + "
", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "
    ", + "
  • ", + "
    ", + "", + "", + }; + special_tokens_before_merge.insert(special_tokens_before_merge.end(), + html_tokens.begin(), + html_tokens.end()); + for (int i = 0; i <= 0xFF; i++) { + char hex_buf[16]; + snprintf(hex_buf, sizeof(hex_buf), "<0x%02X>", i); + special_tokens_before_merge.push_back(hex_buf); + } + + special_tokens_after_merge = { + "", + "", + }; + for (int i = 1; i <= 31; i++) { + special_tokens_after_merge.insert(special_tokens_after_merge.begin() + i - 1, + std::string(i, '\t')); + } + for (int i = 99; i <= 6241; i++) { + special_tokens_after_merge.push_back(""); + } + special_tokens_after_merge.push_back(""); + + special_tokens = special_tokens_before_merge; + special_tokens.insert(special_tokens.end(), + special_tokens_after_merge.begin(), + special_tokens_after_merge.end()); + + if (merges_utf8_str.size() > 0 && vocab_json_utf8_str.size() > 0) { + load_from_merges(merges_utf8_str, vocab_json_utf8_str); + } else { + load_from_merges(load_gemma_merges(), load_gemma_vocab_json()); + } + } + }; + enum class LLMArch { QWEN2_5_VL, QWEN3, diff --git a/src/vocab/gemma_merges.hpp b/src/vocab/gemma_merges.hpp new file mode 100644 index 00000000..6573647d Binary files /dev/null and b/src/vocab/gemma_merges.hpp differ diff --git a/src/vocab/gemma_vocab.hpp b/src/vocab/gemma_vocab.hpp new file mode 100644 index 00000000..07e13b10 Binary files /dev/null and b/src/vocab/gemma_vocab.hpp differ diff --git a/src/vocab/vocab.cpp b/src/vocab/vocab.cpp index 63b28686..4884f418 100644 --- a/src/vocab/vocab.cpp +++ b/src/vocab/vocab.cpp @@ -1,5 +1,7 @@ #include "vocab.h" #include "clip_t5.hpp" +#include "gemma_merges.hpp" +#include "gemma_vocab.hpp" #include "mistral.hpp" #include "qwen.hpp" #include "umt5.hpp" @@ -32,4 +34,14 @@ std::string load_t5_tokenizer_json() { std::string load_umt5_tokenizer_json() { std::string json_str(reinterpret_cast(umt5_tokenizer_json_str), sizeof(umt5_tokenizer_json_str)); return json_str; +} + +std::string load_gemma_merges() { + std::string merges_utf8_str(reinterpret_cast(gemma_merges_utf8_c_str), sizeof(gemma_merges_utf8_c_str)); + return merges_utf8_str; +} + +std::string load_gemma_vocab_json() { + std::string json_str(reinterpret_cast(gemma_vocab_json_utf8_c_str), sizeof(gemma_vocab_json_utf8_c_str)); + return json_str; } \ No newline at end of file diff --git a/src/vocab/vocab.h b/src/vocab/vocab.h index cfa033a4..29340f49 100644 --- a/src/vocab/vocab.h +++ b/src/vocab/vocab.h @@ -9,5 +9,7 @@ std::string load_mistral_merges(); std::string load_mistral_vocab_json(); std::string load_t5_tokenizer_json(); std::string load_umt5_tokenizer_json(); +std::string load_gemma_merges(); +std::string load_gemma_vocab_json(); #endif // __VOCAB_H__ \ No newline at end of file