mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-05-08 08:18:51 +00:00
add GemmaTokenizer
This commit is contained in:
parent
6b675a5ede
commit
51d681e159
@ -7,6 +7,10 @@ add_executable(${TARGET}
|
||||
image_metadata.cpp
|
||||
main.cpp
|
||||
)
|
||||
target_include_directories(${TARGET} PRIVATE
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/.."
|
||||
"${PROJECT_SOURCE_DIR}/src"
|
||||
)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE stable-diffusion zip ${CMAKE_THREAD_LIBS_INIT})
|
||||
if(SD_WEBP)
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
#include "common/media_io.h"
|
||||
#include "common/resource_owners.hpp"
|
||||
#include "image_metadata.h"
|
||||
#include "llm.hpp"
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
@ -498,6 +499,15 @@ int main(int argc, const char* argv[]) {
|
||||
SDContextParams ctx_params;
|
||||
SDGenerationParams gen_params;
|
||||
|
||||
cli_params.verbose = true;
|
||||
sd_set_log_callback(sd_log_cb, (void*)&cli_params);
|
||||
LLM::GemmaTokenizer tokenizer;
|
||||
auto tokens = tokenizer.tokenize("<html> 一只可爱的小猫");
|
||||
for (auto token : tokens) {
|
||||
LOG_INFO("%d", token);
|
||||
}
|
||||
return 0;
|
||||
|
||||
parse_args(argc, argv, cli_params, ctx_params, gen_params);
|
||||
sd_set_log_callback(sd_log_cb, (void*)&cli_params);
|
||||
log_verbose = cli_params.verbose;
|
||||
|
||||
323
src/llm.hpp
323
src/llm.hpp
@ -47,24 +47,12 @@ namespace LLM {
|
||||
|
||||
std::vector<std::string> special_tokens;
|
||||
|
||||
bool add_bos_token = false;
|
||||
bool add_bos_token = false;
|
||||
bool byte_level_bpe = true;
|
||||
bool byte_fallback = false;
|
||||
|
||||
protected:
|
||||
// Returns a copy of `str` with leading and trailing whitespace
// (space, \t, \n, \r, \v, \f) removed.
static std::string strip(const std::string& str) {
    const char* const blanks = " \t\n\r\v\f";

    const std::string::size_type first = str.find_first_not_of(blanks);
    if (first == std::string::npos) {
        // String contains only whitespace characters (or is empty)
        return "";
    }

    const std::string::size_type last = str.find_last_not_of(blanks);
    const std::string::size_type count = last - first + 1;
    return str.substr(first, count);
}
|
||||
|
||||
static std::string whitespace_clean(std::string text) {
|
||||
text = std::regex_replace(text, std::regex(R"(\s+)"), " ");
|
||||
text = strip(text);
|
||||
// Overridable hook that maps a text piece to the form that gets encoded.
// This base implementation is the identity: it returns `text` unchanged.
virtual std::string preprocess(const std::string& text) const {
|
||||
return text;
|
||||
}
|
||||
|
||||
@ -92,6 +80,22 @@ namespace LLM {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Splits `s` on every occurrence of `delim`.
// The delimiter itself is dropped and empty pieces are kept, so the result
// always holds (number of delimiters + 1) entries; an empty input yields a
// single empty piece.
static std::vector<std::u32string> split_utf32(const std::u32string& s, char32_t delim) {
    std::vector<std::u32string> pieces;
    size_t begin = 0;

    for (size_t hit = s.find(delim, begin); hit != std::u32string::npos; hit = s.find(delim, begin)) {
        pieces.push_back(s.substr(begin, hit - begin));
        begin = hit + 1;
    }

    // Whatever follows the last delimiter (possibly nothing) is the final piece.
    pieces.push_back(s.substr(begin));
    return pieces;
}
|
||||
|
||||
public:
|
||||
BPETokenizer() = default;
|
||||
|
||||
@ -208,7 +212,7 @@ namespace LLM {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int> encode(std::string text, on_new_token_cb_t on_new_token_cb = nullptr) {
|
||||
virtual std::vector<int> encode(std::string text, on_new_token_cb_t on_new_token_cb = nullptr) {
|
||||
std::string original_text = text;
|
||||
std::vector<int32_t> bpe_tokens;
|
||||
std::vector<std::string> token_strs;
|
||||
@ -230,25 +234,43 @@ namespace LLM {
|
||||
}
|
||||
}
|
||||
|
||||
std::string token_str = token;
|
||||
std::string token_str = preprocess(token);
|
||||
std::u32string utf32_token;
|
||||
for (int i = 0; i < token_str.length(); i++) {
|
||||
unsigned char b = token_str[i];
|
||||
utf32_token += byte_encoder[b];
|
||||
if (byte_level_bpe) {
|
||||
for (int i = 0; i < token_str.length(); i++) {
|
||||
unsigned char b = token_str[i];
|
||||
utf32_token += byte_encoder[b];
|
||||
}
|
||||
} else {
|
||||
utf32_token = utf8_to_utf32(token_str);
|
||||
}
|
||||
auto bpe_strs = bpe(utf32_token);
|
||||
size_t start = 0;
|
||||
size_t pos;
|
||||
while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) {
|
||||
auto bpe_str = bpe_strs.substr(start, pos - start);
|
||||
bpe_tokens.push_back(encoder[bpe_str]);
|
||||
token_strs.push_back(utf32_to_utf8(bpe_str));
|
||||
|
||||
start = pos + 1;
|
||||
auto bpe_strs = bpe(utf32_token);
|
||||
// Map every space-separated piece of the BPE output to a token id.
for (const auto& bpe_str : split_utf32(bpe_strs, U' ')) {
    // Common case: the merged piece is a vocabulary entry.
    auto iter = encoder.find(bpe_str);
    if (iter != encoder.end()) {
        bpe_tokens.push_back(iter->second);
        token_strs.push_back(utf32_to_utf8(bpe_str));
        continue;
    }

    if (!byte_fallback) {
        // Unknown piece and no byte fallback: emit the UNK token.
        bpe_tokens.push_back(UNK_TOKEN_ID);
        token_strs.push_back(utf32_to_utf8(bpe_str));
        continue;
    }

    // Byte fallback: emit one "<0xNN>" token per UTF-8 byte of the unknown piece.
    const std::string utf8_token_str = utf32_to_utf8(bpe_str);
    for (const unsigned char b : utf8_token_str) {
        char hex_buf[16];
        // %X expects an unsigned int argument.
        snprintf(hex_buf, sizeof(hex_buf), "<0x%02X>", static_cast<unsigned int>(b));
        auto byte_iter = encoder.find(utf8_to_utf32(hex_buf));
        GGML_ASSERT(byte_iter != encoder.end());
        // Push the id of the byte token that was just looked up. The previous
        // code pushed `token_id`, which is never assigned on this path.
        bpe_tokens.push_back(byte_iter->second);
        token_strs.push_back(hex_buf);
    }
}
|
||||
auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start);
|
||||
bpe_tokens.push_back(encoder[bpe_str]);
|
||||
token_strs.push_back(utf32_to_utf8(bpe_str));
|
||||
}
|
||||
}
|
||||
|
||||
@ -259,7 +281,6 @@ namespace LLM {
|
||||
}
|
||||
ss << "]";
|
||||
LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str());
|
||||
// printf("split prompt \"%s\" to tokens %s \n", original_text.c_str(), ss.str().c_str());
|
||||
return bpe_tokens;
|
||||
}
|
||||
};
|
||||
@ -268,35 +289,18 @@ namespace LLM {
|
||||
protected:
|
||||
void load_from_merges(const std::string& merges_utf8_str) {
|
||||
auto byte_unicode_pairs = bytes_to_unicode();
|
||||
// printf("byte_unicode_pairs have %lu pairs \n", byte_unicode_pairs.size());
|
||||
byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
|
||||
byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
|
||||
for (auto& pair : byte_unicode_pairs) {
|
||||
byte_decoder[pair.second] = pair.first;
|
||||
}
|
||||
// for (auto & pair: byte_unicode_pairs) {
|
||||
// std::cout << pair.first << ": " << pair.second << std::endl;
|
||||
// }
|
||||
std::vector<std::u32string> merges;
|
||||
size_t start = 0;
|
||||
size_t pos;
|
||||
std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str);
|
||||
while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) {
|
||||
merges.push_back(merges_utf32_str.substr(start, pos - start));
|
||||
start = pos + 1;
|
||||
}
|
||||
LOG_DEBUG("merges size %llu", merges.size());
|
||||
merges = std::vector<std::u32string>(merges.begin(), merges.end());
|
||||
std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str);
|
||||
std::vector<std::u32string> merges = split_utf32(merges_utf32_str, U'\n');
|
||||
std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
|
||||
// int print_num = 10;
|
||||
for (const auto& merge : merges) {
|
||||
size_t space_pos = merge.find(' ');
|
||||
merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
|
||||
// if (print_num > 0) {
|
||||
// print_num--;
|
||||
// printf("%s :: %s | %s \n", utf32_to_utf8(merge).c_str(), utf32_to_utf8(merge.substr(0, space_pos)).c_str(),
|
||||
// utf32_to_utf8(merge.substr(space_pos + 1)).c_str());
|
||||
// }
|
||||
}
|
||||
LOG_DEBUG("merges size %zu", merge_pairs.size());
|
||||
|
||||
std::vector<std::u32string> tokens;
|
||||
for (const auto& pair : byte_unicode_pairs) {
|
||||
@ -396,27 +400,14 @@ namespace LLM {
|
||||
for (auto& pair : byte_unicode_pairs) {
|
||||
byte_decoder[pair.second] = pair.first;
|
||||
}
|
||||
std::vector<std::u32string> merges;
|
||||
size_t start = 0;
|
||||
size_t pos;
|
||||
std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str);
|
||||
while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) {
|
||||
merges.push_back(merges_utf32_str.substr(start, pos - start));
|
||||
start = pos + 1;
|
||||
}
|
||||
LOG_DEBUG("merges size %llu", merges.size());
|
||||
merges = std::vector<std::u32string>(merges.begin(), merges.end());
|
||||
std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str);
|
||||
std::vector<std::u32string> merges = split_utf32(merges_utf32_str, U'\n');
|
||||
std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
|
||||
// int print_num = 10;
|
||||
for (const auto& merge : merges) {
|
||||
size_t space_pos = merge.find(' ');
|
||||
merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
|
||||
// if (print_num > 0) {
|
||||
// print_num--;
|
||||
// printf("%s :: %s | %s \n", utf32_to_utf8(merge).c_str(), utf32_to_utf8(merge.substr(0, space_pos)).c_str(),
|
||||
// utf32_to_utf8(merge.substr(space_pos + 1)).c_str());
|
||||
// }
|
||||
}
|
||||
LOG_DEBUG("merges size %zu", merge_pairs.size());
|
||||
|
||||
int rank = 0;
|
||||
for (const auto& merge : merge_pairs) {
|
||||
@ -473,6 +464,198 @@ namespace LLM {
|
||||
}
|
||||
};
|
||||
|
||||
class GemmaTokenizer : public BPETokenizer {
|
||||
protected:
|
||||
std::vector<std::string> special_tokens_before_merge;
|
||||
std::vector<std::string> special_tokens_after_merge;
|
||||
|
||||
std::string preprocess(const std::string& text) const override {
|
||||
std::string normalized = text;
|
||||
size_t pos = 0;
|
||||
while ((pos = normalized.find(' ', pos)) != std::string::npos) {
|
||||
normalized.replace(pos, 1, "\xE2\x96\x81");
|
||||
pos += 3;
|
||||
}
|
||||
return normalized;
|
||||
}
|
||||
|
||||
void load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) {
|
||||
nlohmann::json vocab;
|
||||
try {
|
||||
vocab = nlohmann::json::parse(vocab_utf8_str);
|
||||
} catch (const nlohmann::json::parse_error&) {
|
||||
GGML_ABORT("invalid vocab json str");
|
||||
}
|
||||
for (const auto& [key, value] : vocab.items()) {
|
||||
std::u32string token = utf8_to_utf32(key);
|
||||
int i = value;
|
||||
encoder[token] = i;
|
||||
decoder[i] = token;
|
||||
}
|
||||
encoder_len = static_cast<int>(vocab.size());
|
||||
LOG_DEBUG("vocab size: %d", encoder_len);
|
||||
|
||||
std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str);
|
||||
std::vector<std::u32string> merges = split_utf32(merges_utf32_str, U'\n');
|
||||
std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
|
||||
for (const auto& merge : merges) {
|
||||
size_t space_pos = merge.find(' ');
|
||||
merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
|
||||
}
|
||||
LOG_DEBUG("merges size %zu", merge_pairs.size());
|
||||
|
||||
int rank = 0;
|
||||
for (const auto& merge : merge_pairs) {
|
||||
bpe_ranks[merge] = rank++;
|
||||
}
|
||||
bpe_len = rank;
|
||||
};
|
||||
|
||||
public:
|
||||
explicit GemmaTokenizer(const std::string& merges_utf8_str = "", const std::string& vocab_json_utf8_str = "") {
|
||||
byte_level_bpe = false;
|
||||
byte_fallback = true;
|
||||
add_bos_token = true;
|
||||
PAD_TOKEN = "<pad>";
|
||||
EOS_TOKEN = "<eos>";
|
||||
BOS_TOKEN = "<bos>";
|
||||
UNK_TOKEN = "<unk>";
|
||||
|
||||
PAD_TOKEN_ID = 0;
|
||||
EOS_TOKEN_ID = 1;
|
||||
BOS_TOKEN_ID = 2;
|
||||
UNK_TOKEN_ID = 3;
|
||||
|
||||
special_tokens_before_merge = {
|
||||
PAD_TOKEN,
|
||||
EOS_TOKEN,
|
||||
BOS_TOKEN,
|
||||
UNK_TOKEN,
|
||||
"<mask>",
|
||||
"[multimodal]",
|
||||
};
|
||||
for (int i = 0; i <= 98; i++) {
|
||||
special_tokens_before_merge.push_back("<unused" + std::to_string(i) + ">");
|
||||
}
|
||||
special_tokens_before_merge.push_back("<start_of_turn>");
|
||||
special_tokens_before_merge.push_back("<end_of_turn>");
|
||||
for (int i = 1; i <= 31; i++) {
|
||||
special_tokens_before_merge.push_back(std::string(i, '\n'));
|
||||
}
|
||||
for (int i = 2; i <= 31; i++) {
|
||||
std::string whitespace_token;
|
||||
for (int j = 0; j < i; j++) {
|
||||
whitespace_token += "\xE2\x96\x81";
|
||||
}
|
||||
special_tokens_before_merge.push_back(whitespace_token);
|
||||
}
|
||||
std::vector<std::string> html_tokens = {
|
||||
"<table>",
|
||||
"<caption>",
|
||||
"<thead>",
|
||||
"<tbody>",
|
||||
"<tfoot>",
|
||||
"<tr>",
|
||||
"<th>",
|
||||
"<td>",
|
||||
"</table>",
|
||||
"</caption>",
|
||||
"</thead>",
|
||||
"</tbody>",
|
||||
"</tfoot>",
|
||||
"</tr>",
|
||||
"</th>",
|
||||
"</td>",
|
||||
"<h1>",
|
||||
"<h2>",
|
||||
"<h3>",
|
||||
"<h4>",
|
||||
"<h5>",
|
||||
"<h6>",
|
||||
"<blockquote>",
|
||||
"</h1>",
|
||||
"</h2>",
|
||||
"</h3>",
|
||||
"</h4>",
|
||||
"</h5>",
|
||||
"</h6>",
|
||||
"</blockquote>",
|
||||
"<strong>",
|
||||
"<em>",
|
||||
"<b>",
|
||||
"<i>",
|
||||
"<u>",
|
||||
"<s>",
|
||||
"<sub>",
|
||||
"<sup>",
|
||||
"<code>",
|
||||
"</strong>",
|
||||
"</em>",
|
||||
"</b>",
|
||||
"</i>",
|
||||
"</u>",
|
||||
"</s>",
|
||||
"</sub>",
|
||||
"</sup>",
|
||||
"</code>",
|
||||
"<a>",
|
||||
"<html>",
|
||||
"<body>",
|
||||
"<img>",
|
||||
"<span>",
|
||||
"<bbox>",
|
||||
"<ul>",
|
||||
"<li>",
|
||||
"<div>",
|
||||
"<iframe>",
|
||||
"<footer>",
|
||||
"</a>",
|
||||
"</html>",
|
||||
"</body>",
|
||||
"</img>",
|
||||
"</span>",
|
||||
"</bbox>",
|
||||
"</ul>",
|
||||
"</li>",
|
||||
"</div>",
|
||||
"</iframe>",
|
||||
"</footer>",
|
||||
};
|
||||
special_tokens_before_merge.insert(special_tokens_before_merge.end(),
|
||||
html_tokens.begin(),
|
||||
html_tokens.end());
|
||||
for (int i = 0; i <= 0xFF; i++) {
|
||||
char hex_buf[16];
|
||||
snprintf(hex_buf, sizeof(hex_buf), "<0x%02X>", i);
|
||||
special_tokens_before_merge.push_back(hex_buf);
|
||||
}
|
||||
|
||||
special_tokens_after_merge = {
|
||||
"<start_of_image>",
|
||||
"<end_of_image>",
|
||||
};
|
||||
for (int i = 1; i <= 31; i++) {
|
||||
special_tokens_after_merge.insert(special_tokens_after_merge.begin() + i - 1,
|
||||
std::string(i, '\t'));
|
||||
}
|
||||
for (int i = 99; i <= 6241; i++) {
|
||||
special_tokens_after_merge.push_back("<unused" + std::to_string(i) + ">");
|
||||
}
|
||||
special_tokens_after_merge.push_back("<image_soft_token>");
|
||||
|
||||
special_tokens = special_tokens_before_merge;
|
||||
special_tokens.insert(special_tokens.end(),
|
||||
special_tokens_after_merge.begin(),
|
||||
special_tokens_after_merge.end());
|
||||
|
||||
if (merges_utf8_str.size() > 0 && vocab_json_utf8_str.size() > 0) {
|
||||
load_from_merges(merges_utf8_str, vocab_json_utf8_str);
|
||||
} else {
|
||||
load_from_merges(load_gemma_merges(), load_gemma_vocab_json());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
enum class LLMArch {
|
||||
QWEN2_5_VL,
|
||||
QWEN3,
|
||||
|
||||
BIN
src/vocab/gemma_merges.hpp
Normal file
BIN
src/vocab/gemma_merges.hpp
Normal file
Binary file not shown.
BIN
src/vocab/gemma_vocab.hpp
Normal file
BIN
src/vocab/gemma_vocab.hpp
Normal file
Binary file not shown.
@ -1,5 +1,7 @@
|
||||
#include "vocab.h"
|
||||
#include "clip_t5.hpp"
|
||||
#include "gemma_merges.hpp"
|
||||
#include "gemma_vocab.hpp"
|
||||
#include "mistral.hpp"
|
||||
#include "qwen.hpp"
|
||||
#include "umt5.hpp"
|
||||
@ -32,4 +34,14 @@ std::string load_t5_tokenizer_json() {
|
||||
std::string load_umt5_tokenizer_json() {
|
||||
std::string json_str(reinterpret_cast<const char*>(umt5_tokenizer_json_str), sizeof(umt5_tokenizer_json_str));
|
||||
return json_str;
|
||||
}
|
||||
|
||||
std::string load_gemma_merges() {
|
||||
std::string merges_utf8_str(reinterpret_cast<const char*>(gemma_merges_utf8_c_str), sizeof(gemma_merges_utf8_c_str));
|
||||
return merges_utf8_str;
|
||||
}
|
||||
|
||||
std::string load_gemma_vocab_json() {
|
||||
std::string json_str(reinterpret_cast<const char*>(gemma_vocab_json_utf8_c_str), sizeof(gemma_vocab_json_utf8_c_str));
|
||||
return json_str;
|
||||
}
|
||||
@ -9,5 +9,7 @@ std::string load_mistral_merges();
|
||||
std::string load_mistral_vocab_json();
|
||||
std::string load_t5_tokenizer_json();
|
||||
std::string load_umt5_tokenizer_json();
|
||||
std::string load_gemma_merges();
|
||||
std::string load_gemma_vocab_json();
|
||||
|
||||
#endif // __VOCAB_H__
|
||||
Loading…
x
Reference in New Issue
Block a user