add GemmaTokenizer

This commit is contained in:
leejet 2026-04-12 23:44:43 +08:00
parent 6b675a5ede
commit 51d681e159
7 changed files with 281 additions and 70 deletions

View File

@ -7,6 +7,10 @@ add_executable(${TARGET}
image_metadata.cpp
main.cpp
)
target_include_directories(${TARGET} PRIVATE
"${CMAKE_CURRENT_SOURCE_DIR}/.."
"${PROJECT_SOURCE_DIR}/src"
)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE stable-diffusion zip ${CMAKE_THREAD_LIBS_INIT})
if(SD_WEBP)

View File

@ -19,6 +19,7 @@
#include "common/media_io.h"
#include "common/resource_owners.hpp"
#include "image_metadata.h"
#include "llm.hpp"
namespace fs = std::filesystem;
@ -498,6 +499,15 @@ int main(int argc, const char* argv[]) {
SDContextParams ctx_params;
SDGenerationParams gen_params;
cli_params.verbose = true;
sd_set_log_callback(sd_log_cb, (void*)&cli_params);
LLM::GemmaTokenizer tokenizer;
auto tokens = tokenizer.tokenize("<html> 一只可爱的小猫");
for (auto token : tokens) {
LOG_INFO("%d", token);
}
return 0;
parse_args(argc, argv, cli_params, ctx_params, gen_params);
sd_set_log_callback(sd_log_cb, (void*)&cli_params);
log_verbose = cli_params.verbose;

View File

@ -47,24 +47,12 @@ namespace LLM {
std::vector<std::string> special_tokens;
bool add_bos_token = false;
bool add_bos_token = false;
bool byte_level_bpe = true;
bool byte_fallback = false;
protected:
static std::string strip(const std::string& str) {
std::string::size_type start = str.find_first_not_of(" \t\n\r\v\f");
std::string::size_type end = str.find_last_not_of(" \t\n\r\v\f");
if (start == std::string::npos) {
// String contains only whitespace characters
return "";
}
return str.substr(start, end - start + 1);
}
static std::string whitespace_clean(std::string text) {
text = std::regex_replace(text, std::regex(R"(\s+)"), " ");
text = strip(text);
virtual std::string preprocess(const std::string& text) const {
return text;
}
@ -92,6 +80,22 @@ namespace LLM {
return false;
}
static std::vector<std::u32string> split_utf32(const std::u32string& s, char32_t delim) {
std::vector<std::u32string> result;
size_t start = 0;
while (true) {
size_t pos = s.find(delim, start);
if (pos == std::u32string::npos) {
result.emplace_back(s.substr(start));
break;
}
result.emplace_back(s.substr(start, pos - start));
start = pos + 1;
}
return result;
}
public:
BPETokenizer() = default;
@ -208,7 +212,7 @@ namespace LLM {
}
}
std::vector<int> encode(std::string text, on_new_token_cb_t on_new_token_cb = nullptr) {
virtual std::vector<int> encode(std::string text, on_new_token_cb_t on_new_token_cb = nullptr) {
std::string original_text = text;
std::vector<int32_t> bpe_tokens;
std::vector<std::string> token_strs;
@ -230,25 +234,43 @@ namespace LLM {
}
}
std::string token_str = token;
std::string token_str = preprocess(token);
std::u32string utf32_token;
for (int i = 0; i < token_str.length(); i++) {
unsigned char b = token_str[i];
utf32_token += byte_encoder[b];
if (byte_level_bpe) {
for (int i = 0; i < token_str.length(); i++) {
unsigned char b = token_str[i];
utf32_token += byte_encoder[b];
}
} else {
utf32_token = utf8_to_utf32(token_str);
}
auto bpe_strs = bpe(utf32_token);
size_t start = 0;
size_t pos;
while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) {
auto bpe_str = bpe_strs.substr(start, pos - start);
bpe_tokens.push_back(encoder[bpe_str]);
token_strs.push_back(utf32_to_utf8(bpe_str));
start = pos + 1;
auto bpe_strs = bpe(utf32_token);
for (const auto& bpe_str : split_utf32(bpe_strs, U' ')) {
int token_id;
auto iter = encoder.find(bpe_str);
if (iter != encoder.end()) {
token_id = iter->second;
} else {
if (byte_fallback) {
auto utf8_token_str = utf32_to_utf8(bpe_str);
for (int i = 0; i < utf8_token_str.length(); i++) {
unsigned char b = utf8_token_str[i];
char hex_buf[16];
snprintf(hex_buf, sizeof(hex_buf), "<0x%02X>", b);
iter = encoder.find(utf8_to_utf32(hex_buf));
GGML_ASSERT(iter != encoder.end());
bpe_tokens.push_back(token_id);
token_strs.push_back(hex_buf);
}
continue;
} else {
token_id = UNK_TOKEN_ID;
}
}
bpe_tokens.push_back(token_id);
token_strs.push_back(utf32_to_utf8(bpe_str));
}
auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start);
bpe_tokens.push_back(encoder[bpe_str]);
token_strs.push_back(utf32_to_utf8(bpe_str));
}
}
@ -259,7 +281,6 @@ namespace LLM {
}
ss << "]";
LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str());
// printf("split prompt \"%s\" to tokens %s \n", original_text.c_str(), ss.str().c_str());
return bpe_tokens;
}
};
@ -268,35 +289,18 @@ namespace LLM {
protected:
void load_from_merges(const std::string& merges_utf8_str) {
auto byte_unicode_pairs = bytes_to_unicode();
// printf("byte_unicode_pairs have %lu pairs \n", byte_unicode_pairs.size());
byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
for (auto& pair : byte_unicode_pairs) {
byte_decoder[pair.second] = pair.first;
}
// for (auto & pair: byte_unicode_pairs) {
// std::cout << pair.first << ": " << pair.second << std::endl;
// }
std::vector<std::u32string> merges;
size_t start = 0;
size_t pos;
std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str);
while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) {
merges.push_back(merges_utf32_str.substr(start, pos - start));
start = pos + 1;
}
LOG_DEBUG("merges size %llu", merges.size());
merges = std::vector<std::u32string>(merges.begin(), merges.end());
std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str);
std::vector<std::u32string> merges = split_utf32(merges_utf32_str, U'\n');
std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
// int print_num = 10;
for (const auto& merge : merges) {
size_t space_pos = merge.find(' ');
merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
// if (print_num > 0) {
// print_num--;
// printf("%s :: %s | %s \n", utf32_to_utf8(merge).c_str(), utf32_to_utf8(merge.substr(0, space_pos)).c_str(),
// utf32_to_utf8(merge.substr(space_pos + 1)).c_str());
// }
}
LOG_DEBUG("merges size %zu", merge_pairs.size());
std::vector<std::u32string> tokens;
for (const auto& pair : byte_unicode_pairs) {
@ -396,27 +400,14 @@ namespace LLM {
for (auto& pair : byte_unicode_pairs) {
byte_decoder[pair.second] = pair.first;
}
std::vector<std::u32string> merges;
size_t start = 0;
size_t pos;
std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str);
while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) {
merges.push_back(merges_utf32_str.substr(start, pos - start));
start = pos + 1;
}
LOG_DEBUG("merges size %llu", merges.size());
merges = std::vector<std::u32string>(merges.begin(), merges.end());
std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str);
std::vector<std::u32string> merges = split_utf32(merges_utf32_str, U'\n');
std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
// int print_num = 10;
for (const auto& merge : merges) {
size_t space_pos = merge.find(' ');
merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
// if (print_num > 0) {
// print_num--;
// printf("%s :: %s | %s \n", utf32_to_utf8(merge).c_str(), utf32_to_utf8(merge.substr(0, space_pos)).c_str(),
// utf32_to_utf8(merge.substr(space_pos + 1)).c_str());
// }
}
LOG_DEBUG("merges size %zu", merge_pairs.size());
int rank = 0;
for (const auto& merge : merge_pairs) {
@ -473,6 +464,198 @@ namespace LLM {
}
};
class GemmaTokenizer : public BPETokenizer {
protected:
std::vector<std::string> special_tokens_before_merge;
std::vector<std::string> special_tokens_after_merge;
std::string preprocess(const std::string& text) const override {
std::string normalized = text;
size_t pos = 0;
while ((pos = normalized.find(' ', pos)) != std::string::npos) {
normalized.replace(pos, 1, "\xE2\x96\x81");
pos += 3;
}
return normalized;
}
void load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) {
nlohmann::json vocab;
try {
vocab = nlohmann::json::parse(vocab_utf8_str);
} catch (const nlohmann::json::parse_error&) {
GGML_ABORT("invalid vocab json str");
}
for (const auto& [key, value] : vocab.items()) {
std::u32string token = utf8_to_utf32(key);
int i = value;
encoder[token] = i;
decoder[i] = token;
}
encoder_len = static_cast<int>(vocab.size());
LOG_DEBUG("vocab size: %d", encoder_len);
std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str);
std::vector<std::u32string> merges = split_utf32(merges_utf32_str, U'\n');
std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
for (const auto& merge : merges) {
size_t space_pos = merge.find(' ');
merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
}
LOG_DEBUG("merges size %zu", merge_pairs.size());
int rank = 0;
for (const auto& merge : merge_pairs) {
bpe_ranks[merge] = rank++;
}
bpe_len = rank;
};
public:
explicit GemmaTokenizer(const std::string& merges_utf8_str = "", const std::string& vocab_json_utf8_str = "") {
byte_level_bpe = false;
byte_fallback = true;
add_bos_token = true;
PAD_TOKEN = "<pad>";
EOS_TOKEN = "<eos>";
BOS_TOKEN = "<bos>";
UNK_TOKEN = "<unk>";
PAD_TOKEN_ID = 0;
EOS_TOKEN_ID = 1;
BOS_TOKEN_ID = 2;
UNK_TOKEN_ID = 3;
special_tokens_before_merge = {
PAD_TOKEN,
EOS_TOKEN,
BOS_TOKEN,
UNK_TOKEN,
"<mask>",
"[multimodal]",
};
for (int i = 0; i <= 98; i++) {
special_tokens_before_merge.push_back("<unused" + std::to_string(i) + ">");
}
special_tokens_before_merge.push_back("<start_of_turn>");
special_tokens_before_merge.push_back("<end_of_turn>");
for (int i = 1; i <= 31; i++) {
special_tokens_before_merge.push_back(std::string(i, '\n'));
}
for (int i = 2; i <= 31; i++) {
std::string whitespace_token;
for (int j = 0; j < i; j++) {
whitespace_token += "\xE2\x96\x81";
}
special_tokens_before_merge.push_back(whitespace_token);
}
std::vector<std::string> html_tokens = {
"<table>",
"<caption>",
"<thead>",
"<tbody>",
"<tfoot>",
"<tr>",
"<th>",
"<td>",
"</table>",
"</caption>",
"</thead>",
"</tbody>",
"</tfoot>",
"</tr>",
"</th>",
"</td>",
"<h1>",
"<h2>",
"<h3>",
"<h4>",
"<h5>",
"<h6>",
"<blockquote>",
"</h1>",
"</h2>",
"</h3>",
"</h4>",
"</h5>",
"</h6>",
"</blockquote>",
"<strong>",
"<em>",
"<b>",
"<i>",
"<u>",
"<s>",
"<sub>",
"<sup>",
"<code>",
"</strong>",
"</em>",
"</b>",
"</i>",
"</u>",
"</s>",
"</sub>",
"</sup>",
"</code>",
"<a>",
"<html>",
"<body>",
"<img>",
"<span>",
"<bbox>",
"<ul>",
"<li>",
"<div>",
"<iframe>",
"<footer>",
"</a>",
"</html>",
"</body>",
"</img>",
"</span>",
"</bbox>",
"</ul>",
"</li>",
"</div>",
"</iframe>",
"</footer>",
};
special_tokens_before_merge.insert(special_tokens_before_merge.end(),
html_tokens.begin(),
html_tokens.end());
for (int i = 0; i <= 0xFF; i++) {
char hex_buf[16];
snprintf(hex_buf, sizeof(hex_buf), "<0x%02X>", i);
special_tokens_before_merge.push_back(hex_buf);
}
special_tokens_after_merge = {
"<start_of_image>",
"<end_of_image>",
};
for (int i = 1; i <= 31; i++) {
special_tokens_after_merge.insert(special_tokens_after_merge.begin() + i - 1,
std::string(i, '\t'));
}
for (int i = 99; i <= 6241; i++) {
special_tokens_after_merge.push_back("<unused" + std::to_string(i) + ">");
}
special_tokens_after_merge.push_back("<image_soft_token>");
special_tokens = special_tokens_before_merge;
special_tokens.insert(special_tokens.end(),
special_tokens_after_merge.begin(),
special_tokens_after_merge.end());
if (merges_utf8_str.size() > 0 && vocab_json_utf8_str.size() > 0) {
load_from_merges(merges_utf8_str, vocab_json_utf8_str);
} else {
load_from_merges(load_gemma_merges(), load_gemma_vocab_json());
}
}
};
enum class LLMArch {
QWEN2_5_VL,
QWEN3,

BIN
src/vocab/gemma_merges.hpp Normal file

Binary file not shown.

BIN
src/vocab/gemma_vocab.hpp Normal file

Binary file not shown.

View File

@ -1,5 +1,7 @@
#include "vocab.h"
#include "clip_t5.hpp"
#include "gemma_merges.hpp"
#include "gemma_vocab.hpp"
#include "mistral.hpp"
#include "qwen.hpp"
#include "umt5.hpp"
@ -33,3 +35,13 @@ std::string load_umt5_tokenizer_json() {
std::string json_str(reinterpret_cast<const char*>(umt5_tokenizer_json_str), sizeof(umt5_tokenizer_json_str));
return json_str;
}
std::string load_gemma_merges() {
std::string merges_utf8_str(reinterpret_cast<const char*>(gemma_merges_utf8_c_str), sizeof(gemma_merges_utf8_c_str));
return merges_utf8_str;
}
std::string load_gemma_vocab_json() {
std::string json_str(reinterpret_cast<const char*>(gemma_vocab_json_utf8_c_str), sizeof(gemma_vocab_json_utf8_c_str));
return json_str;
}

View File

@ -9,5 +9,7 @@ std::string load_mistral_merges();
std::string load_mistral_vocab_json();
std::string load_t5_tokenizer_json();
std::string load_umt5_tokenizer_json();
std::string load_gemma_merges();
std::string load_gemma_vocab_json();
#endif // __VOCAB_H__