mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-05-08 16:28:53 +00:00
refactor: introduce shared tokenizer abstraction and split implementations (#1423)
This commit is contained in:
parent
ee5bf956b0
commit
9ac7b672c2
@ -156,8 +156,10 @@ file(GLOB SD_LIB_SOURCES
|
|||||||
"src/*.h"
|
"src/*.h"
|
||||||
"src/*.cpp"
|
"src/*.cpp"
|
||||||
"src/*.hpp"
|
"src/*.hpp"
|
||||||
"src/vocab/*.h"
|
"src/tokenizers/*.h"
|
||||||
"src/vocab/*.cpp"
|
"src/tokenizers/*.cpp"
|
||||||
|
"src/tokenizers/vocab/*.h"
|
||||||
|
"src/tokenizers/vocab/*.cpp"
|
||||||
)
|
)
|
||||||
|
|
||||||
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
|
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
|
||||||
@ -250,7 +252,7 @@ endif()
|
|||||||
add_subdirectory(thirdparty)
|
add_subdirectory(thirdparty)
|
||||||
|
|
||||||
target_link_libraries(${SD_LIB} PUBLIC ggml zip)
|
target_link_libraries(${SD_LIB} PUBLIC ggml zip)
|
||||||
target_include_directories(${SD_LIB} PUBLIC . include)
|
target_include_directories(${SD_LIB} PUBLIC . src include)
|
||||||
target_include_directories(${SD_LIB} PUBLIC . thirdparty)
|
target_include_directories(${SD_LIB} PUBLIC . thirdparty)
|
||||||
target_compile_features(${SD_LIB} PUBLIC c_std_11 cxx_std_17)
|
target_compile_features(${SD_LIB} PUBLIC c_std_11 cxx_std_17)
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
for f in src/*.cpp src/*.h src/*.hpp src/vocab/*.h src/vocab/*.cpp \
|
for f in src/*.cpp src/*.h src/*.hpp src/tokenizers/*.h src/tokenizers/*.cpp src/tokenizers/vocab/*.h src/tokenizers/vocab/*.cpp \
|
||||||
examples/cli/*.cpp examples/cli/*.h examples/server/*.cpp \
|
examples/cli/*.cpp examples/cli/*.h examples/server/*.cpp \
|
||||||
examples/common/*.hpp examples/common/*.h examples/common/*.cpp; do
|
examples/common/*.hpp examples/common/*.h examples/common/*.cpp; do
|
||||||
[[ "$f" == vocab* ]] && continue
|
[[ "$f" == vocab* ]] && continue
|
||||||
|
|||||||
450
src/clip.hpp
450
src/clip.hpp
@ -3,455 +3,7 @@
|
|||||||
|
|
||||||
#include "ggml_extend.hpp"
|
#include "ggml_extend.hpp"
|
||||||
#include "model.h"
|
#include "model.h"
|
||||||
#include "tokenize_util.h"
|
#include "tokenizers/clip_tokenizer.h"
|
||||||
#include "vocab/vocab.h"
|
|
||||||
|
|
||||||
/*================================================== CLIPTokenizer ===================================================*/
|
|
||||||
|
|
||||||
__STATIC_INLINE__ std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
|
|
||||||
std::vector<std::pair<int, std::u32string>> byte_unicode_pairs;
|
|
||||||
std::set<int> byte_set;
|
|
||||||
for (int b = static_cast<int>('!'); b <= static_cast<int>('~'); ++b) {
|
|
||||||
byte_set.insert(b);
|
|
||||||
byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
|
|
||||||
}
|
|
||||||
for (int b = 161; b <= 172; ++b) {
|
|
||||||
byte_set.insert(b);
|
|
||||||
byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
|
|
||||||
}
|
|
||||||
for (int b = 174; b <= 255; ++b) {
|
|
||||||
byte_set.insert(b);
|
|
||||||
byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
|
|
||||||
}
|
|
||||||
int n = 0;
|
|
||||||
for (int b = 0; b < 256; ++b) {
|
|
||||||
if (byte_set.find(b) == byte_set.end()) {
|
|
||||||
byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(n + 256)));
|
|
||||||
++n;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// LOG_DEBUG("byte_unicode_pairs %d", byte_unicode_pairs.size());
|
|
||||||
return byte_unicode_pairs;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
|
|
||||||
|
|
||||||
typedef std::function<bool(std::string&, std::vector<int32_t>&)> on_new_token_cb_t;
|
|
||||||
|
|
||||||
class CLIPTokenizer {
|
|
||||||
private:
|
|
||||||
std::map<int, std::u32string> byte_encoder;
|
|
||||||
std::map<std::u32string, int> byte_decoder;
|
|
||||||
std::map<std::u32string, int> encoder;
|
|
||||||
std::map<int, std::u32string> decoder;
|
|
||||||
std::map<std::pair<std::u32string, std::u32string>, int> bpe_ranks;
|
|
||||||
std::regex pat;
|
|
||||||
int encoder_len;
|
|
||||||
int bpe_len;
|
|
||||||
|
|
||||||
std::vector<std::string> special_tokens;
|
|
||||||
|
|
||||||
public:
|
|
||||||
const std::string UNK_TOKEN = "<|endoftext|>";
|
|
||||||
const std::string BOS_TOKEN = "<|startoftext|>";
|
|
||||||
const std::string EOS_TOKEN = "<|endoftext|>";
|
|
||||||
const std::string PAD_TOKEN = "<|endoftext|>";
|
|
||||||
|
|
||||||
const int UNK_TOKEN_ID = 49407;
|
|
||||||
const int BOS_TOKEN_ID = 49406;
|
|
||||||
const int EOS_TOKEN_ID = 49407;
|
|
||||||
const int PAD_TOKEN_ID = 49407;
|
|
||||||
|
|
||||||
private:
|
|
||||||
static std::string strip(const std::string& str) {
|
|
||||||
std::string::size_type start = str.find_first_not_of(" \t\n\r\v\f");
|
|
||||||
std::string::size_type end = str.find_last_not_of(" \t\n\r\v\f");
|
|
||||||
|
|
||||||
if (start == std::string::npos) {
|
|
||||||
// String contains only whitespace characters
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
|
|
||||||
return str.substr(start, end - start + 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::string whitespace_clean(std::string text) {
|
|
||||||
text = std::regex_replace(text, std::regex(R"(\s+)"), " ");
|
|
||||||
text = strip(text);
|
|
||||||
return text;
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::set<std::pair<std::u32string, std::u32string>> get_pairs(const std::vector<std::u32string>& subwords) {
|
|
||||||
std::set<std::pair<std::u32string, std::u32string>> pairs;
|
|
||||||
if (subwords.size() == 0) {
|
|
||||||
return pairs;
|
|
||||||
}
|
|
||||||
std::u32string prev_subword = subwords[0];
|
|
||||||
for (int i = 1; i < subwords.size(); i++) {
|
|
||||||
std::u32string subword = subwords[i];
|
|
||||||
std::pair<std::u32string, std::u32string> pair(prev_subword, subword);
|
|
||||||
pairs.insert(pair);
|
|
||||||
prev_subword = subword;
|
|
||||||
}
|
|
||||||
return pairs;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool is_special_token(const std::string& token) {
|
|
||||||
for (auto& special_token : special_tokens) {
|
|
||||||
if (special_token == token) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
public:
|
|
||||||
CLIPTokenizer(int pad_token_id = 49407, const std::string& merges_utf8_str = "")
|
|
||||||
: PAD_TOKEN_ID(pad_token_id) {
|
|
||||||
if (merges_utf8_str.size() > 0) {
|
|
||||||
load_from_merges(merges_utf8_str);
|
|
||||||
} else {
|
|
||||||
load_from_merges(load_clip_merges());
|
|
||||||
}
|
|
||||||
add_special_token("<|startoftext|>");
|
|
||||||
add_special_token("<|endoftext|>");
|
|
||||||
}
|
|
||||||
|
|
||||||
void load_from_merges(const std::string& merges_utf8_str) {
|
|
||||||
auto byte_unicode_pairs = bytes_to_unicode();
|
|
||||||
// printf("byte_unicode_pairs have %lu pairs \n", byte_unicode_pairs.size());
|
|
||||||
byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
|
|
||||||
for (auto& pair : byte_unicode_pairs) {
|
|
||||||
byte_decoder[pair.second] = pair.first;
|
|
||||||
}
|
|
||||||
// for (auto & pair: byte_unicode_pairs) {
|
|
||||||
// std::cout << pair.first << ": " << pair.second << std::endl;
|
|
||||||
// }
|
|
||||||
std::vector<std::u32string> merges;
|
|
||||||
size_t start = 0;
|
|
||||||
size_t pos;
|
|
||||||
std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str);
|
|
||||||
while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) {
|
|
||||||
merges.push_back(merges_utf32_str.substr(start, pos - start));
|
|
||||||
start = pos + 1;
|
|
||||||
}
|
|
||||||
// LOG_DEBUG("merges size %llu", merges.size());
|
|
||||||
GGML_ASSERT(merges.size() == 48895);
|
|
||||||
merges = std::vector<std::u32string>(merges.begin() + 1, merges.end());
|
|
||||||
std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
|
|
||||||
for (const auto& merge : merges) {
|
|
||||||
size_t space_pos = merge.find(' ');
|
|
||||||
merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
|
|
||||||
// LOG_DEBUG("%s", utf32_to_utf8(merge.substr(space_pos + 1)).c_str());
|
|
||||||
// printf("%s :: %s | %s \n", utf32_to_utf8(merge).c_str(), utf32_to_utf8(merge.substr(0, space_pos)).c_str(),
|
|
||||||
// utf32_to_utf8(merge.substr(space_pos + 1)).c_str());
|
|
||||||
}
|
|
||||||
std::vector<std::u32string> vocab;
|
|
||||||
for (const auto& pair : byte_unicode_pairs) {
|
|
||||||
vocab.push_back(pair.second);
|
|
||||||
}
|
|
||||||
for (const auto& pair : byte_unicode_pairs) {
|
|
||||||
vocab.push_back(pair.second + utf8_to_utf32("</w>"));
|
|
||||||
}
|
|
||||||
for (const auto& merge : merge_pairs) {
|
|
||||||
vocab.push_back(merge.first + merge.second);
|
|
||||||
}
|
|
||||||
vocab.push_back(utf8_to_utf32("<|startoftext|>"));
|
|
||||||
vocab.push_back(utf8_to_utf32("<|endoftext|>"));
|
|
||||||
LOG_DEBUG("vocab size: %llu", vocab.size());
|
|
||||||
int i = 0;
|
|
||||||
for (const auto& token : vocab) {
|
|
||||||
encoder[token] = i;
|
|
||||||
decoder[i] = token;
|
|
||||||
i++;
|
|
||||||
}
|
|
||||||
encoder_len = i;
|
|
||||||
|
|
||||||
auto it = encoder.find(utf8_to_utf32("img</w>"));
|
|
||||||
if (it != encoder.end()) {
|
|
||||||
LOG_DEBUG("trigger word img already in vocab");
|
|
||||||
} else {
|
|
||||||
LOG_DEBUG("trigger word img not in vocab yet");
|
|
||||||
}
|
|
||||||
|
|
||||||
int rank = 0;
|
|
||||||
for (const auto& merge : merge_pairs) {
|
|
||||||
bpe_ranks[merge] = rank++;
|
|
||||||
}
|
|
||||||
bpe_len = rank;
|
|
||||||
};
|
|
||||||
|
|
||||||
void add_token(const std::string& text) {
|
|
||||||
std::u32string token = utf8_to_utf32(text);
|
|
||||||
auto it = encoder.find(token);
|
|
||||||
if (it != encoder.end()) {
|
|
||||||
encoder[token] = encoder_len;
|
|
||||||
decoder[encoder_len] = token;
|
|
||||||
encoder_len++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void add_special_token(const std::string& token) {
|
|
||||||
special_tokens.push_back(token);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::u32string bpe(const std::u32string& token) {
|
|
||||||
std::vector<std::u32string> word;
|
|
||||||
|
|
||||||
for (int i = 0; i < token.size() - 1; i++) {
|
|
||||||
word.emplace_back(1, token[i]);
|
|
||||||
}
|
|
||||||
word.push_back(token.substr(token.size() - 1) + utf8_to_utf32("</w>"));
|
|
||||||
|
|
||||||
std::set<std::pair<std::u32string, std::u32string>> pairs = get_pairs(word);
|
|
||||||
|
|
||||||
if (pairs.empty()) {
|
|
||||||
return token + utf8_to_utf32("</w>");
|
|
||||||
}
|
|
||||||
|
|
||||||
while (true) {
|
|
||||||
auto min_pair_iter = std::min_element(pairs.begin(),
|
|
||||||
pairs.end(),
|
|
||||||
[&](const std::pair<std::u32string, std::u32string>& a,
|
|
||||||
const std::pair<std::u32string, std::u32string>& b) {
|
|
||||||
if (bpe_ranks.find(a) == bpe_ranks.end()) {
|
|
||||||
return false;
|
|
||||||
} else if (bpe_ranks.find(b) == bpe_ranks.end()) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return bpe_ranks.at(a) < bpe_ranks.at(b);
|
|
||||||
});
|
|
||||||
|
|
||||||
const std::pair<std::u32string, std::u32string>& bigram = *min_pair_iter;
|
|
||||||
|
|
||||||
if (bpe_ranks.find(bigram) == bpe_ranks.end()) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::u32string first = bigram.first;
|
|
||||||
std::u32string second = bigram.second;
|
|
||||||
std::vector<std::u32string> new_word;
|
|
||||||
int32_t i = 0;
|
|
||||||
|
|
||||||
while (i < word.size()) {
|
|
||||||
auto it = std::find(word.begin() + i, word.end(), first);
|
|
||||||
if (it == word.end()) {
|
|
||||||
new_word.insert(new_word.end(), word.begin() + i, word.end());
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
new_word.insert(new_word.end(), word.begin() + i, it);
|
|
||||||
i = static_cast<int32_t>(std::distance(word.begin(), it));
|
|
||||||
|
|
||||||
if (word[i] == first && i < static_cast<int32_t>(word.size()) - 1 && word[i + 1] == second) {
|
|
||||||
new_word.push_back(first + second);
|
|
||||||
i += 2;
|
|
||||||
} else {
|
|
||||||
new_word.push_back(word[i]);
|
|
||||||
i += 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
word = new_word;
|
|
||||||
|
|
||||||
if (word.size() == 1) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
pairs = get_pairs(word);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::u32string result;
|
|
||||||
for (int i = 0; i < word.size(); i++) {
|
|
||||||
result += word[i];
|
|
||||||
if (i != word.size() - 1) {
|
|
||||||
result += utf8_to_utf32(" ");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<int> tokenize(std::string text,
|
|
||||||
on_new_token_cb_t on_new_token_cb,
|
|
||||||
size_t max_length = 0,
|
|
||||||
bool padding = false) {
|
|
||||||
std::vector<int32_t> tokens = encode(text, on_new_token_cb);
|
|
||||||
|
|
||||||
tokens.insert(tokens.begin(), BOS_TOKEN_ID);
|
|
||||||
if (max_length > 0) {
|
|
||||||
if (tokens.size() > max_length - 1) {
|
|
||||||
tokens.resize(max_length - 1);
|
|
||||||
tokens.push_back(EOS_TOKEN_ID);
|
|
||||||
} else {
|
|
||||||
tokens.push_back(EOS_TOKEN_ID);
|
|
||||||
if (padding) {
|
|
||||||
tokens.insert(tokens.end(), max_length - tokens.size(), PAD_TOKEN_ID);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return tokens;
|
|
||||||
}
|
|
||||||
|
|
||||||
void pad_tokens(std::vector<int>& tokens,
|
|
||||||
std::vector<float>& weights,
|
|
||||||
size_t max_length = 0,
|
|
||||||
bool padding = false) {
|
|
||||||
if (max_length > 0 && padding) {
|
|
||||||
size_t n = static_cast<size_t>(std::ceil(tokens.size() * 1.0 / (max_length - 2)));
|
|
||||||
if (n == 0) {
|
|
||||||
n = 1;
|
|
||||||
}
|
|
||||||
size_t length = max_length * n;
|
|
||||||
LOG_DEBUG("token length: %llu", length);
|
|
||||||
std::vector<int> new_tokens;
|
|
||||||
std::vector<float> new_weights;
|
|
||||||
new_tokens.push_back(BOS_TOKEN_ID);
|
|
||||||
new_weights.push_back(1.0);
|
|
||||||
int token_idx = 0;
|
|
||||||
for (int i = 1; i < length; i++) {
|
|
||||||
if (token_idx >= tokens.size()) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if (i % max_length == 0) {
|
|
||||||
new_tokens.push_back(BOS_TOKEN_ID);
|
|
||||||
new_weights.push_back(1.0);
|
|
||||||
} else if (i % max_length == max_length - 1) {
|
|
||||||
new_tokens.push_back(EOS_TOKEN_ID);
|
|
||||||
new_weights.push_back(1.0);
|
|
||||||
} else {
|
|
||||||
new_tokens.push_back(tokens[token_idx]);
|
|
||||||
new_weights.push_back(weights[token_idx]);
|
|
||||||
token_idx++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
new_tokens.push_back(EOS_TOKEN_ID);
|
|
||||||
new_weights.push_back(1.0);
|
|
||||||
tokens = new_tokens;
|
|
||||||
weights = new_weights;
|
|
||||||
|
|
||||||
if (padding) {
|
|
||||||
tokens.insert(tokens.end(), length - tokens.size(), PAD_TOKEN_ID);
|
|
||||||
weights.insert(weights.end(), length - weights.size(), 1.0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string clean_up_tokenization(std::string& text) {
|
|
||||||
std::regex pattern(R"( ,)");
|
|
||||||
// Replace " ," with ","
|
|
||||||
std::string result = std::regex_replace(text, pattern, ",");
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string decode(const std::vector<int>& tokens) {
|
|
||||||
std::string text = "";
|
|
||||||
for (int t : tokens) {
|
|
||||||
if (t == 49406 || t == 49407)
|
|
||||||
continue;
|
|
||||||
std::u32string ts = decoder[t];
|
|
||||||
// printf("%d, %s \n", t, utf32_to_utf8(ts).c_str());
|
|
||||||
std::string s = utf32_to_utf8(ts);
|
|
||||||
if (s.length() >= 4) {
|
|
||||||
if (ends_with(s, "</w>")) {
|
|
||||||
text += s.replace(s.length() - 4, s.length() - 1, "") + " ";
|
|
||||||
} else {
|
|
||||||
text += s;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
text += " " + s;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// std::vector<unsigned char> bytes;
|
|
||||||
// for (auto c : text){
|
|
||||||
// bytes.push_back(byte_decoder[c]);
|
|
||||||
// }
|
|
||||||
|
|
||||||
// std::string s((char *)bytes.data());
|
|
||||||
// std::string s = "";
|
|
||||||
text = clean_up_tokenization(text);
|
|
||||||
return trim(text);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::string> token_split(const std::string& text) {
|
|
||||||
std::regex pat(R"('s|'t|'re|'ve|'m|'ll|'d|[[:alpha:]]+|[[:digit:]]|[^[:space:][:alpha:][:digit:]]+)",
|
|
||||||
std::regex::icase);
|
|
||||||
std::sregex_iterator iter(text.begin(), text.end(), pat);
|
|
||||||
std::sregex_iterator end;
|
|
||||||
|
|
||||||
std::vector<std::string> result;
|
|
||||||
for (; iter != end; ++iter) {
|
|
||||||
result.emplace_back(iter->str());
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<int> encode(std::string text, on_new_token_cb_t on_new_token_cb) {
|
|
||||||
std::string original_text = text;
|
|
||||||
std::vector<int32_t> bpe_tokens;
|
|
||||||
text = whitespace_clean(text);
|
|
||||||
std::transform(text.begin(), text.end(), text.begin(), [](unsigned char c) { return std::tolower(c); });
|
|
||||||
|
|
||||||
std::string str = text;
|
|
||||||
std::vector<std::string> token_strs;
|
|
||||||
|
|
||||||
auto splited_texts = split_with_special_tokens(text, special_tokens);
|
|
||||||
|
|
||||||
for (auto& splited_text : splited_texts) {
|
|
||||||
LOG_DEBUG("token %s", splited_text.c_str());
|
|
||||||
if (is_special_token(splited_text)) {
|
|
||||||
LOG_DEBUG("special %s", splited_text.c_str());
|
|
||||||
bool skip = on_new_token_cb(splited_text, bpe_tokens);
|
|
||||||
if (skip) {
|
|
||||||
token_strs.push_back(splited_text);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto tokens = token_split(splited_text);
|
|
||||||
for (auto& token : tokens) {
|
|
||||||
if (on_new_token_cb != nullptr) {
|
|
||||||
bool skip = on_new_token_cb(token, bpe_tokens);
|
|
||||||
if (skip) {
|
|
||||||
token_strs.push_back(token);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string token_str = token;
|
|
||||||
std::u32string utf32_token;
|
|
||||||
for (int i = 0; i < token_str.length(); i++) {
|
|
||||||
unsigned char b = token_str[i];
|
|
||||||
utf32_token += byte_encoder[b];
|
|
||||||
}
|
|
||||||
auto bpe_strs = bpe(utf32_token);
|
|
||||||
size_t start = 0;
|
|
||||||
size_t pos;
|
|
||||||
while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) {
|
|
||||||
auto bpe_str = bpe_strs.substr(start, pos - start);
|
|
||||||
bpe_tokens.push_back(encoder[bpe_str]);
|
|
||||||
token_strs.push_back(utf32_to_utf8(bpe_str));
|
|
||||||
|
|
||||||
start = pos + 1;
|
|
||||||
}
|
|
||||||
auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start);
|
|
||||||
bpe_tokens.push_back(encoder[bpe_str]);
|
|
||||||
token_strs.push_back(utf32_to_utf8(bpe_str));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// std::stringstream ss;
|
|
||||||
// ss << "[";
|
|
||||||
// for (auto token : token_strs) {
|
|
||||||
// ss << "\"" << token << "\", ";
|
|
||||||
// }
|
|
||||||
// ss << "]";
|
|
||||||
// LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str());
|
|
||||||
// printf("split prompt \"%s\" to tokens %s \n", original_text.c_str(), ss.str().c_str());
|
|
||||||
return bpe_tokens;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/*================================================ FrozenCLIPEmbedder ================================================*/
|
/*================================================ FrozenCLIPEmbedder ================================================*/
|
||||||
|
|
||||||
|
|||||||
@ -256,15 +256,6 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<std::vector<int>, std::vector<float>, std::vector<bool>>
|
|
||||||
tokenize_with_trigger_token(std::string text,
|
|
||||||
int num_input_imgs,
|
|
||||||
int32_t image_token,
|
|
||||||
bool padding = false) {
|
|
||||||
return tokenize_with_trigger_token(text, num_input_imgs, image_token,
|
|
||||||
text_model->model.n_token, padding);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<int> convert_token_to_id(std::string text) {
|
std::vector<int> convert_token_to_id(std::string text) {
|
||||||
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
|
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
|
||||||
auto iter = embedding_map.find(str);
|
auto iter = embedding_map.find(str);
|
||||||
@ -288,9 +279,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
|||||||
std::tuple<std::vector<int>, std::vector<float>, std::vector<bool>>
|
std::tuple<std::vector<int>, std::vector<float>, std::vector<bool>>
|
||||||
tokenize_with_trigger_token(std::string text,
|
tokenize_with_trigger_token(std::string text,
|
||||||
int num_input_imgs,
|
int num_input_imgs,
|
||||||
int32_t image_token,
|
int32_t image_token) {
|
||||||
size_t max_length = 0,
|
|
||||||
bool padding = false) {
|
|
||||||
auto parsed_attention = parse_prompt_attention(text);
|
auto parsed_attention = parse_prompt_attention(text);
|
||||||
|
|
||||||
{
|
{
|
||||||
@ -377,7 +366,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
|||||||
// tokens.insert(tokens.begin(), tokenizer.BOS_TOKEN_ID);
|
// tokens.insert(tokens.begin(), tokenizer.BOS_TOKEN_ID);
|
||||||
// weights.insert(weights.begin(), 1.0);
|
// weights.insert(weights.begin(), 1.0);
|
||||||
|
|
||||||
tokenizer.pad_tokens(tokens, weights, max_length, padding);
|
tokenizer.pad_tokens(tokens, &weights, nullptr, text_model->model.n_token, text_model->model.n_token, true);
|
||||||
int offset = pm_version == PM_VERSION_2 ? 2 * num_input_imgs : num_input_imgs;
|
int offset = pm_version == PM_VERSION_2 ? 2 * num_input_imgs : num_input_imgs;
|
||||||
for (int i = 0; i < tokens.size(); i++) {
|
for (int i = 0; i < tokens.size(); i++) {
|
||||||
// if (class_idx + 1 <= i && i < class_idx + 1 + 2*num_input_imgs) // photomaker V2 has num_tokens(=2)*num_input_imgs
|
// if (class_idx + 1 <= i && i < class_idx + 1 + 2*num_input_imgs) // photomaker V2 has num_tokens(=2)*num_input_imgs
|
||||||
@ -403,13 +392,9 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::pair<std::vector<int>, std::vector<float>> tokenize(std::string text,
|
std::pair<std::vector<int>, std::vector<float>> tokenize(std::string text,
|
||||||
bool padding = false) {
|
size_t min_length = 0,
|
||||||
return tokenize(text, text_model->model.n_token, padding);
|
size_t max_length = 0,
|
||||||
}
|
bool allow_overflow_expand = true) {
|
||||||
|
|
||||||
std::pair<std::vector<int>, std::vector<float>> tokenize(std::string text,
|
|
||||||
size_t max_length = 0,
|
|
||||||
bool padding = false) {
|
|
||||||
auto parsed_attention = parse_prompt_attention(text);
|
auto parsed_attention = parse_prompt_attention(text);
|
||||||
|
|
||||||
{
|
{
|
||||||
@ -460,7 +445,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
|||||||
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
|
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenizer.pad_tokens(tokens, weights, max_length, padding);
|
tokenizer.pad_tokens(tokens, &weights, nullptr, min_length, max_length, allow_overflow_expand);
|
||||||
|
|
||||||
// for (int i = 0; i < tokens.size(); i++) {
|
// for (int i = 0; i < tokens.size(); i++) {
|
||||||
// std::cout << tokens[i] << ":" << weights[i] << ", ";
|
// std::cout << tokens[i] << ":" << weights[i] << ", ";
|
||||||
@ -603,8 +588,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
|||||||
GGML_ASSERT(image_tokens.size() == 1);
|
GGML_ASSERT(image_tokens.size() == 1);
|
||||||
auto tokens_and_weights = tokenize_with_trigger_token(conditioner_params.text,
|
auto tokens_and_weights = tokenize_with_trigger_token(conditioner_params.text,
|
||||||
conditioner_params.num_input_imgs,
|
conditioner_params.num_input_imgs,
|
||||||
image_tokens[0],
|
image_tokens[0]);
|
||||||
true);
|
|
||||||
std::vector<int>& tokens = std::get<0>(tokens_and_weights);
|
std::vector<int>& tokens = std::get<0>(tokens_and_weights);
|
||||||
std::vector<float>& weights = std::get<1>(tokens_and_weights);
|
std::vector<float>& weights = std::get<1>(tokens_and_weights);
|
||||||
std::vector<bool>& clsm = std::get<2>(tokens_and_weights);
|
std::vector<bool>& clsm = std::get<2>(tokens_and_weights);
|
||||||
@ -630,7 +614,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
|||||||
std::string remove_trigger_from_prompt(const std::string& prompt) override {
|
std::string remove_trigger_from_prompt(const std::string& prompt) override {
|
||||||
auto image_tokens = convert_token_to_id(trigger_word);
|
auto image_tokens = convert_token_to_id(trigger_word);
|
||||||
GGML_ASSERT(image_tokens.size() == 1);
|
GGML_ASSERT(image_tokens.size() == 1);
|
||||||
auto tokens_and_weights = tokenize(prompt, false);
|
auto tokens_and_weights = tokenize(prompt);
|
||||||
std::vector<int>& tokens = tokens_and_weights.first;
|
std::vector<int>& tokens = tokens_and_weights.first;
|
||||||
auto it = std::find(tokens.begin(), tokens.end(), image_tokens[0]);
|
auto it = std::find(tokens.begin(), tokens.end(), image_tokens[0]);
|
||||||
GGML_ASSERT(it != tokens.end()); // prompt must have trigger word
|
GGML_ASSERT(it != tokens.end()); // prompt must have trigger word
|
||||||
@ -640,7 +624,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
|||||||
|
|
||||||
SDCondition get_learned_condition(int n_threads,
|
SDCondition get_learned_condition(int n_threads,
|
||||||
const ConditionerParams& conditioner_params) override {
|
const ConditionerParams& conditioner_params) override {
|
||||||
auto tokens_and_weights = tokenize(conditioner_params.text, true);
|
auto tokens_and_weights = tokenize(conditioner_params.text, text_model->model.n_token, text_model->model.n_token, true);
|
||||||
std::vector<int>& tokens = tokens_and_weights.first;
|
std::vector<int>& tokens = tokens_and_weights.first;
|
||||||
std::vector<float>& weights = tokens_and_weights.second;
|
std::vector<float>& weights = tokens_and_weights.second;
|
||||||
return get_learned_condition_common(n_threads,
|
return get_learned_condition_common(n_threads,
|
||||||
@ -822,8 +806,9 @@ struct SD3CLIPEmbedder : public Conditioner {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::pair<std::vector<int>, std::vector<float>>> tokenize(std::string text,
|
std::vector<std::pair<std::vector<int>, std::vector<float>>> tokenize(std::string text,
|
||||||
size_t max_length = 0,
|
size_t min_length = 0,
|
||||||
bool padding = false) {
|
size_t max_length = 0,
|
||||||
|
bool allow_overflow_expand = true) {
|
||||||
auto parsed_attention = parse_prompt_attention(text);
|
auto parsed_attention = parse_prompt_attention(text);
|
||||||
|
|
||||||
{
|
{
|
||||||
@ -860,20 +845,20 @@ struct SD3CLIPEmbedder : public Conditioner {
|
|||||||
clip_g_weights.insert(clip_g_weights.end(), curr_tokens.size(), curr_weight);
|
clip_g_weights.insert(clip_g_weights.end(), curr_tokens.size(), curr_weight);
|
||||||
}
|
}
|
||||||
if (t5) {
|
if (t5) {
|
||||||
std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
|
std::vector<int> curr_tokens = t5_tokenizer.encode(curr_text);
|
||||||
t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
|
t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
|
||||||
t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
|
t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (clip_l) {
|
if (clip_l) {
|
||||||
clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, max_length, padding);
|
clip_l_tokenizer.pad_tokens(clip_l_tokens, &clip_l_weights, nullptr, min_length, max_length, allow_overflow_expand);
|
||||||
}
|
}
|
||||||
if (clip_g) {
|
if (clip_g) {
|
||||||
clip_g_tokenizer.pad_tokens(clip_g_tokens, clip_g_weights, max_length, padding);
|
clip_g_tokenizer.pad_tokens(clip_g_tokens, &clip_g_weights, nullptr, min_length, max_length, allow_overflow_expand);
|
||||||
}
|
}
|
||||||
if (t5) {
|
if (t5) {
|
||||||
t5_tokenizer.pad_tokens(t5_tokens, t5_weights, nullptr, max_length, padding);
|
t5_tokenizer.pad_tokens(t5_tokens, &t5_weights, nullptr, min_length, max_length, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
// for (int i = 0; i < clip_l_tokens.size(); i++) {
|
// for (int i = 0; i < clip_l_tokens.size(); i++) {
|
||||||
@ -1056,7 +1041,7 @@ struct SD3CLIPEmbedder : public Conditioner {
|
|||||||
|
|
||||||
SDCondition get_learned_condition(int n_threads,
|
SDCondition get_learned_condition(int n_threads,
|
||||||
const ConditionerParams& conditioner_params) override {
|
const ConditionerParams& conditioner_params) override {
|
||||||
auto tokens_and_weights = tokenize(conditioner_params.text, 77, true);
|
auto tokens_and_weights = tokenize(conditioner_params.text, 77, 77, true);
|
||||||
return get_learned_condition_common(n_threads,
|
return get_learned_condition_common(n_threads,
|
||||||
tokens_and_weights,
|
tokens_and_weights,
|
||||||
conditioner_params.clip_skip,
|
conditioner_params.clip_skip,
|
||||||
@ -1158,8 +1143,8 @@ struct FluxCLIPEmbedder : public Conditioner {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::pair<std::vector<int>, std::vector<float>>> tokenize(std::string text,
|
std::vector<std::pair<std::vector<int>, std::vector<float>>> tokenize(std::string text,
|
||||||
size_t max_length = 0,
|
size_t min_length = 0,
|
||||||
bool padding = false) {
|
size_t max_length = 0) {
|
||||||
auto parsed_attention = parse_prompt_attention(text);
|
auto parsed_attention = parse_prompt_attention(text);
|
||||||
|
|
||||||
{
|
{
|
||||||
@ -1189,17 +1174,17 @@ struct FluxCLIPEmbedder : public Conditioner {
|
|||||||
clip_l_weights.insert(clip_l_weights.end(), curr_tokens.size(), curr_weight);
|
clip_l_weights.insert(clip_l_weights.end(), curr_tokens.size(), curr_weight);
|
||||||
}
|
}
|
||||||
if (t5) {
|
if (t5) {
|
||||||
std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
|
std::vector<int> curr_tokens = t5_tokenizer.encode(curr_text);
|
||||||
t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
|
t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
|
||||||
t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
|
t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (clip_l) {
|
if (clip_l) {
|
||||||
clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, 77, padding);
|
clip_l_tokenizer.pad_tokens(clip_l_tokens, &clip_l_weights, nullptr, 77, 77, true);
|
||||||
}
|
}
|
||||||
if (t5) {
|
if (t5) {
|
||||||
t5_tokenizer.pad_tokens(t5_tokens, t5_weights, nullptr, max_length, padding);
|
t5_tokenizer.pad_tokens(t5_tokens, &t5_weights, nullptr, min_length, max_length, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
// for (int i = 0; i < clip_l_tokens.size(); i++) {
|
// for (int i = 0; i < clip_l_tokens.size(); i++) {
|
||||||
@ -1300,7 +1285,7 @@ struct FluxCLIPEmbedder : public Conditioner {
|
|||||||
|
|
||||||
SDCondition get_learned_condition(int n_threads,
|
SDCondition get_learned_condition(int n_threads,
|
||||||
const ConditionerParams& conditioner_params) override {
|
const ConditionerParams& conditioner_params) override {
|
||||||
auto tokens_and_weights = tokenize(conditioner_params.text, chunk_len, true);
|
auto tokens_and_weights = tokenize(conditioner_params.text, chunk_len, chunk_len);
|
||||||
return get_learned_condition_common(n_threads,
|
return get_learned_condition_common(n_threads,
|
||||||
tokens_and_weights,
|
tokens_and_weights,
|
||||||
conditioner_params.clip_skip,
|
conditioner_params.clip_skip,
|
||||||
@ -1377,8 +1362,8 @@ struct T5CLIPEmbedder : public Conditioner {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> tokenize(std::string text,
|
std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> tokenize(std::string text,
|
||||||
size_t max_length = 0,
|
size_t min_length = 0,
|
||||||
bool padding = false) {
|
size_t max_length = 0) {
|
||||||
auto parsed_attention = parse_prompt_attention(text);
|
auto parsed_attention = parse_prompt_attention(text);
|
||||||
|
|
||||||
{
|
{
|
||||||
@ -1403,12 +1388,15 @@ struct T5CLIPEmbedder : public Conditioner {
|
|||||||
const std::string& curr_text = item.first;
|
const std::string& curr_text = item.first;
|
||||||
float curr_weight = item.second;
|
float curr_weight = item.second;
|
||||||
|
|
||||||
std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
|
std::vector<int> curr_tokens = t5_tokenizer.encode(curr_text);
|
||||||
t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
|
t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
|
||||||
t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
|
t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
|
||||||
}
|
}
|
||||||
|
|
||||||
t5_tokenizer.pad_tokens(t5_tokens, t5_weights, &t5_mask, max_length, padding);
|
t5_tokenizer.pad_tokens(t5_tokens, &t5_weights, &t5_mask, min_length, max_length, true);
|
||||||
|
for (auto& mask_value : t5_mask) {
|
||||||
|
mask_value = mask_value > 0.0f ? 0.0f : -HUGE_VALF;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return {t5_tokens, t5_weights, t5_mask};
|
return {t5_tokens, t5_weights, t5_mask};
|
||||||
}
|
}
|
||||||
@ -1496,7 +1484,7 @@ struct T5CLIPEmbedder : public Conditioner {
|
|||||||
|
|
||||||
SDCondition get_learned_condition(int n_threads,
|
SDCondition get_learned_condition(int n_threads,
|
||||||
const ConditionerParams& conditioner_params) override {
|
const ConditionerParams& conditioner_params) override {
|
||||||
auto tokens_and_weights = tokenize(conditioner_params.text, chunk_len, true);
|
auto tokens_and_weights = tokenize(conditioner_params.text, chunk_len, chunk_len);
|
||||||
return get_learned_condition_common(n_threads,
|
return get_learned_condition_common(n_threads,
|
||||||
tokens_and_weights,
|
tokens_and_weights,
|
||||||
conditioner_params.clip_skip,
|
conditioner_params.clip_skip,
|
||||||
@ -1505,14 +1493,14 @@ struct T5CLIPEmbedder : public Conditioner {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct AnimaConditioner : public Conditioner {
|
struct AnimaConditioner : public Conditioner {
|
||||||
std::shared_ptr<LLM::BPETokenizer> qwen_tokenizer;
|
std::shared_ptr<BPETokenizer> qwen_tokenizer;
|
||||||
T5UniGramTokenizer t5_tokenizer;
|
T5UniGramTokenizer t5_tokenizer;
|
||||||
std::shared_ptr<LLM::LLMRunner> llm;
|
std::shared_ptr<LLM::LLMRunner> llm;
|
||||||
|
|
||||||
AnimaConditioner(ggml_backend_t backend,
|
AnimaConditioner(ggml_backend_t backend,
|
||||||
bool offload_params_to_cpu,
|
bool offload_params_to_cpu,
|
||||||
const String2TensorStorage& tensor_storage_map = {}) {
|
const String2TensorStorage& tensor_storage_map = {}) {
|
||||||
qwen_tokenizer = std::make_shared<LLM::Qwen2Tokenizer>();
|
qwen_tokenizer = std::make_shared<Qwen2Tokenizer>();
|
||||||
llm = std::make_shared<LLM::LLMRunner>(LLM::LLMArch::QWEN3,
|
llm = std::make_shared<LLM::LLMRunner>(LLM::LLMArch::QWEN3,
|
||||||
backend,
|
backend,
|
||||||
offload_params_to_cpu,
|
offload_params_to_cpu,
|
||||||
@ -1578,7 +1566,7 @@ struct AnimaConditioner : public Conditioner {
|
|||||||
for (const auto& item : parsed_attention) {
|
for (const auto& item : parsed_attention) {
|
||||||
const std::string& curr_text = item.first;
|
const std::string& curr_text = item.first;
|
||||||
float curr_weight = item.second;
|
float curr_weight = item.second;
|
||||||
std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
|
std::vector<int> curr_tokens = t5_tokenizer.tokenize(curr_text, nullptr, true);
|
||||||
t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
|
t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
|
||||||
t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
|
t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
|
||||||
}
|
}
|
||||||
@ -1620,7 +1608,7 @@ struct AnimaConditioner : public Conditioner {
|
|||||||
|
|
||||||
struct LLMEmbedder : public Conditioner {
|
struct LLMEmbedder : public Conditioner {
|
||||||
SDVersion version;
|
SDVersion version;
|
||||||
std::shared_ptr<LLM::BPETokenizer> tokenizer;
|
std::shared_ptr<BPETokenizer> tokenizer;
|
||||||
std::shared_ptr<LLM::LLMRunner> llm;
|
std::shared_ptr<LLM::LLMRunner> llm;
|
||||||
|
|
||||||
LLMEmbedder(ggml_backend_t backend,
|
LLMEmbedder(ggml_backend_t backend,
|
||||||
@ -1637,9 +1625,9 @@ struct LLMEmbedder : public Conditioner {
|
|||||||
arch = LLM::LLMArch::QWEN3;
|
arch = LLM::LLMArch::QWEN3;
|
||||||
}
|
}
|
||||||
if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) {
|
if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) {
|
||||||
tokenizer = std::make_shared<LLM::MistralTokenizer>();
|
tokenizer = std::make_shared<MistralTokenizer>();
|
||||||
} else {
|
} else {
|
||||||
tokenizer = std::make_shared<LLM::Qwen2Tokenizer>();
|
tokenizer = std::make_shared<Qwen2Tokenizer>();
|
||||||
}
|
}
|
||||||
llm = std::make_shared<LLM::LLMRunner>(arch,
|
llm = std::make_shared<LLM::LLMRunner>(arch,
|
||||||
backend,
|
backend,
|
||||||
@ -1677,10 +1665,10 @@ struct LLMEmbedder : public Conditioner {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text,
|
std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> tokenize(std::string text,
|
||||||
const std::pair<int, int>& attn_range,
|
const std::pair<int, int>& attn_range,
|
||||||
size_t max_length = 0,
|
size_t min_length = 0,
|
||||||
bool padding = false) {
|
size_t max_length = 100000000) {
|
||||||
std::vector<std::pair<std::string, float>> parsed_attention;
|
std::vector<std::pair<std::string, float>> parsed_attention;
|
||||||
if (attn_range.first >= 0 && attn_range.second > 0) {
|
if (attn_range.first >= 0 && attn_range.second > 0) {
|
||||||
parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f);
|
parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f);
|
||||||
@ -1710,39 +1698,34 @@ struct LLMEmbedder : public Conditioner {
|
|||||||
for (const auto& item : parsed_attention) {
|
for (const auto& item : parsed_attention) {
|
||||||
const std::string& curr_text = item.first;
|
const std::string& curr_text = item.first;
|
||||||
float curr_weight = item.second;
|
float curr_weight = item.second;
|
||||||
std::vector<int> curr_tokens = tokenizer->tokenize(curr_text, nullptr);
|
std::vector<int> curr_tokens = tokenizer->encode(curr_text, nullptr);
|
||||||
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
|
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
|
||||||
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
|
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenizer->pad_tokens(tokens, weights, max_length, padding);
|
std::vector<float> mask;
|
||||||
|
tokenizer->pad_tokens(tokens, &weights, &mask, min_length, max_length);
|
||||||
|
|
||||||
// for (int i = 0; i < tokens.size(); i++) {
|
// for (int i = 0; i < tokens.size(); i++) {
|
||||||
// std::cout << tokens[i] << ":" << weights[i] << ", " << i << std::endl;
|
// std::cout << tokens[i] << ":" << weights[i] << ", " << i << std::endl;
|
||||||
// }
|
// }
|
||||||
// std::cout << std::endl;
|
// std::cout << std::endl;
|
||||||
|
|
||||||
return {tokens, weights};
|
return {tokens, weights, mask};
|
||||||
}
|
}
|
||||||
|
|
||||||
sd::Tensor<float> encode_prompt(int n_threads,
|
sd::Tensor<float> encode_prompt(int n_threads,
|
||||||
const std::string prompt,
|
const std::string prompt,
|
||||||
const std::pair<int, int>& prompt_attn_range,
|
const std::pair<int, int>& prompt_attn_range,
|
||||||
int max_length,
|
|
||||||
int min_length,
|
int min_length,
|
||||||
|
int hidden_states_min_length,
|
||||||
const std::vector<std::pair<int, sd::Tensor<float>>>& image_embeds,
|
const std::vector<std::pair<int, sd::Tensor<float>>>& image_embeds,
|
||||||
const std::set<int>& out_layers,
|
const std::set<int>& out_layers,
|
||||||
int prompt_template_encode_start_idx) {
|
int prompt_template_encode_start_idx) {
|
||||||
auto tokens_and_weights = tokenize(prompt, prompt_attn_range);
|
auto tokens_weights_mask = tokenize(prompt, prompt_attn_range, min_length);
|
||||||
auto& tokens = std::get<0>(tokens_and_weights);
|
auto& tokens = std::get<0>(tokens_weights_mask);
|
||||||
auto& weights = std::get<1>(tokens_and_weights);
|
auto& weights = std::get<1>(tokens_weights_mask);
|
||||||
std::vector<float> mask;
|
auto& mask = std::get<2>(tokens_weights_mask);
|
||||||
|
|
||||||
if (max_length > 0 && tokens.size() < max_length) {
|
|
||||||
mask.insert(mask.end(), tokens.size(), 1.f);
|
|
||||||
mask.insert(mask.end(), max_length - tokens.size(), 0.f);
|
|
||||||
tokenizer->pad_tokens(tokens, weights, max_length, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
sd::Tensor<int32_t> input_ids({static_cast<int64_t>(tokens.size())}, tokens);
|
sd::Tensor<int32_t> input_ids({static_cast<int64_t>(tokens.size())}, tokens);
|
||||||
sd::Tensor<float> attention_mask;
|
sd::Tensor<float> attention_mask;
|
||||||
@ -1769,9 +1752,9 @@ struct LLMEmbedder : public Conditioner {
|
|||||||
GGML_ASSERT(hidden_states.shape()[1] > prompt_template_encode_start_idx);
|
GGML_ASSERT(hidden_states.shape()[1] > prompt_template_encode_start_idx);
|
||||||
|
|
||||||
int64_t zero_pad_len = 0;
|
int64_t zero_pad_len = 0;
|
||||||
if (min_length > 0) {
|
if (hidden_states_min_length > 0) {
|
||||||
if (hidden_states.shape()[1] - prompt_template_encode_start_idx < min_length) {
|
if (hidden_states.shape()[1] - prompt_template_encode_start_idx < hidden_states_min_length) {
|
||||||
zero_pad_len = min_length - hidden_states.shape()[1] + prompt_template_encode_start_idx;
|
zero_pad_len = hidden_states_min_length - hidden_states.shape()[1] + prompt_template_encode_start_idx;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1798,8 +1781,8 @@ struct LLMEmbedder : public Conditioner {
|
|||||||
std::vector<std::pair<int, int>> extra_prompts_attn_range;
|
std::vector<std::pair<int, int>> extra_prompts_attn_range;
|
||||||
std::vector<std::pair<int, sd::Tensor<float>>> image_embeds;
|
std::vector<std::pair<int, sd::Tensor<float>>> image_embeds;
|
||||||
int prompt_template_encode_start_idx = 34;
|
int prompt_template_encode_start_idx = 34;
|
||||||
int max_length = 0; // pad tokens
|
int min_length = 0; // pad tokens
|
||||||
int min_length = 0; // zero pad hidden_states
|
int hidden_states_min_length = 0; // zero pad hidden_states
|
||||||
std::set<int> out_layers;
|
std::set<int> out_layers;
|
||||||
|
|
||||||
int64_t t0 = ggml_time_ms();
|
int64_t t0 = ggml_time_ms();
|
||||||
@ -1874,7 +1857,7 @@ struct LLMEmbedder : public Conditioner {
|
|||||||
}
|
}
|
||||||
} else if (version == VERSION_FLUX2) {
|
} else if (version == VERSION_FLUX2) {
|
||||||
prompt_template_encode_start_idx = 0;
|
prompt_template_encode_start_idx = 0;
|
||||||
min_length = 512;
|
hidden_states_min_length = 512;
|
||||||
out_layers = {10, 20, 30};
|
out_layers = {10, 20, 30};
|
||||||
|
|
||||||
prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";
|
prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";
|
||||||
@ -1907,7 +1890,7 @@ struct LLMEmbedder : public Conditioner {
|
|||||||
}
|
}
|
||||||
} else if (version == VERSION_FLUX2_KLEIN) {
|
} else if (version == VERSION_FLUX2_KLEIN) {
|
||||||
prompt_template_encode_start_idx = 0;
|
prompt_template_encode_start_idx = 0;
|
||||||
max_length = 512;
|
min_length = 512;
|
||||||
out_layers = {9, 18, 27};
|
out_layers = {9, 18, 27};
|
||||||
|
|
||||||
prompt = "<|im_start|>user\n";
|
prompt = "<|im_start|>user\n";
|
||||||
@ -1919,7 +1902,7 @@ struct LLMEmbedder : public Conditioner {
|
|||||||
prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
|
prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
|
||||||
} else if (version == VERSION_OVIS_IMAGE) {
|
} else if (version == VERSION_OVIS_IMAGE) {
|
||||||
prompt_template_encode_start_idx = 28;
|
prompt_template_encode_start_idx = 28;
|
||||||
max_length = prompt_template_encode_start_idx + 256;
|
min_length = prompt_template_encode_start_idx + 256;
|
||||||
|
|
||||||
prompt = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background:";
|
prompt = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background:";
|
||||||
|
|
||||||
@ -1935,8 +1918,8 @@ struct LLMEmbedder : public Conditioner {
|
|||||||
auto hidden_states = encode_prompt(n_threads,
|
auto hidden_states = encode_prompt(n_threads,
|
||||||
prompt,
|
prompt,
|
||||||
prompt_attn_range,
|
prompt_attn_range,
|
||||||
max_length,
|
|
||||||
min_length,
|
min_length,
|
||||||
|
hidden_states_min_length,
|
||||||
image_embeds,
|
image_embeds,
|
||||||
out_layers,
|
out_layers,
|
||||||
prompt_template_encode_start_idx);
|
prompt_template_encode_start_idx);
|
||||||
@ -1945,8 +1928,8 @@ struct LLMEmbedder : public Conditioner {
|
|||||||
auto extra_hidden_states = encode_prompt(n_threads,
|
auto extra_hidden_states = encode_prompt(n_threads,
|
||||||
extra_prompts[i],
|
extra_prompts[i],
|
||||||
extra_prompts_attn_range[i],
|
extra_prompts_attn_range[i],
|
||||||
max_length,
|
|
||||||
min_length,
|
min_length,
|
||||||
|
hidden_states_min_length,
|
||||||
image_embeds,
|
image_embeds,
|
||||||
out_layers,
|
out_layers,
|
||||||
prompt_template_encode_start_idx);
|
prompt_template_encode_start_idx);
|
||||||
|
|||||||
457
src/llm.hpp
457
src/llm.hpp
@ -14,465 +14,16 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "clip.hpp"
|
|
||||||
#include "ggml_extend.hpp"
|
#include "ggml_extend.hpp"
|
||||||
#include "json.hpp"
|
#include "json.hpp"
|
||||||
#include "rope.hpp"
|
#include "rope.hpp"
|
||||||
#include "tokenize_util.h"
|
#include "tokenizers/bpe_tokenizer.h"
|
||||||
#include "vocab/vocab.h"
|
#include "tokenizers/mistral_tokenizer.h"
|
||||||
|
#include "tokenizers/qwen2_tokenizer.h"
|
||||||
|
|
||||||
namespace LLM {
|
namespace LLM {
|
||||||
constexpr int LLM_GRAPH_SIZE = 10240;
|
constexpr int LLM_GRAPH_SIZE = 10240;
|
||||||
|
|
||||||
class BPETokenizer {
|
|
||||||
protected:
|
|
||||||
std::map<int, std::u32string> byte_encoder;
|
|
||||||
std::map<std::u32string, int> byte_decoder;
|
|
||||||
std::map<std::u32string, int> encoder;
|
|
||||||
std::map<int, std::u32string> decoder;
|
|
||||||
std::map<std::pair<std::u32string, std::u32string>, int> bpe_ranks;
|
|
||||||
std::regex pat;
|
|
||||||
int encoder_len;
|
|
||||||
int bpe_len;
|
|
||||||
|
|
||||||
std::string UNK_TOKEN;
|
|
||||||
std::string BOS_TOKEN;
|
|
||||||
std::string EOS_TOKEN;
|
|
||||||
std::string PAD_TOKEN;
|
|
||||||
|
|
||||||
int UNK_TOKEN_ID;
|
|
||||||
int BOS_TOKEN_ID;
|
|
||||||
int EOS_TOKEN_ID;
|
|
||||||
int PAD_TOKEN_ID;
|
|
||||||
|
|
||||||
std::vector<std::string> special_tokens;
|
|
||||||
|
|
||||||
bool add_bos_token = false;
|
|
||||||
|
|
||||||
protected:
|
|
||||||
static std::string strip(const std::string& str) {
|
|
||||||
std::string::size_type start = str.find_first_not_of(" \t\n\r\v\f");
|
|
||||||
std::string::size_type end = str.find_last_not_of(" \t\n\r\v\f");
|
|
||||||
|
|
||||||
if (start == std::string::npos) {
|
|
||||||
// String contains only whitespace characters
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
|
|
||||||
return str.substr(start, end - start + 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::string whitespace_clean(std::string text) {
|
|
||||||
text = std::regex_replace(text, std::regex(R"(\s+)"), " ");
|
|
||||||
text = strip(text);
|
|
||||||
return text;
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::set<std::pair<std::u32string, std::u32string>> get_pairs(const std::vector<std::u32string>& subwords) {
|
|
||||||
std::set<std::pair<std::u32string, std::u32string>> pairs;
|
|
||||||
if (subwords.size() == 0) {
|
|
||||||
return pairs;
|
|
||||||
}
|
|
||||||
std::u32string prev_subword = subwords[0];
|
|
||||||
for (int i = 1; i < subwords.size(); i++) {
|
|
||||||
std::u32string subword = subwords[i];
|
|
||||||
std::pair<std::u32string, std::u32string> pair(prev_subword, subword);
|
|
||||||
pairs.insert(pair);
|
|
||||||
prev_subword = subword;
|
|
||||||
}
|
|
||||||
return pairs;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool is_special_token(const std::string& token) {
|
|
||||||
for (auto& special_token : special_tokens) {
|
|
||||||
if (special_token == token) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
public:
|
|
||||||
BPETokenizer() = default;
|
|
||||||
|
|
||||||
std::u32string bpe(const std::u32string& token) {
|
|
||||||
std::vector<std::u32string> word;
|
|
||||||
|
|
||||||
for (int i = 0; i < token.size(); i++) {
|
|
||||||
word.emplace_back(1, token[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::set<std::pair<std::u32string, std::u32string>> pairs = get_pairs(word);
|
|
||||||
|
|
||||||
if (pairs.empty()) {
|
|
||||||
return token;
|
|
||||||
}
|
|
||||||
|
|
||||||
while (true) {
|
|
||||||
auto min_pair_iter = std::min_element(pairs.begin(),
|
|
||||||
pairs.end(),
|
|
||||||
[&](const std::pair<std::u32string, std::u32string>& a,
|
|
||||||
const std::pair<std::u32string, std::u32string>& b) {
|
|
||||||
if (bpe_ranks.find(a) == bpe_ranks.end()) {
|
|
||||||
return false;
|
|
||||||
} else if (bpe_ranks.find(b) == bpe_ranks.end()) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return bpe_ranks.at(a) < bpe_ranks.at(b);
|
|
||||||
});
|
|
||||||
|
|
||||||
const std::pair<std::u32string, std::u32string>& bigram = *min_pair_iter;
|
|
||||||
|
|
||||||
if (bpe_ranks.find(bigram) == bpe_ranks.end()) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::u32string first = bigram.first;
|
|
||||||
std::u32string second = bigram.second;
|
|
||||||
std::vector<std::u32string> new_word;
|
|
||||||
int32_t i = 0;
|
|
||||||
|
|
||||||
while (i < word.size()) {
|
|
||||||
auto it = std::find(word.begin() + i, word.end(), first);
|
|
||||||
if (it == word.end()) {
|
|
||||||
new_word.insert(new_word.end(), word.begin() + i, word.end());
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
new_word.insert(new_word.end(), word.begin() + i, it);
|
|
||||||
i = static_cast<int32_t>(std::distance(word.begin(), it));
|
|
||||||
|
|
||||||
if (word[i] == first && i < static_cast<int32_t>(word.size()) - 1 && word[i + 1] == second) {
|
|
||||||
new_word.push_back(first + second);
|
|
||||||
i += 2;
|
|
||||||
} else {
|
|
||||||
new_word.push_back(word[i]);
|
|
||||||
i += 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
word = new_word;
|
|
||||||
|
|
||||||
if (word.size() == 1) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
pairs = get_pairs(word);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::u32string result;
|
|
||||||
for (int i = 0; i < word.size(); i++) {
|
|
||||||
result += word[i];
|
|
||||||
if (i != word.size() - 1) {
|
|
||||||
result += utf8_to_utf32(" ");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<int> tokenize(std::string text,
|
|
||||||
on_new_token_cb_t on_new_token_cb = nullptr,
|
|
||||||
size_t max_length = 0,
|
|
||||||
bool padding = false) {
|
|
||||||
std::vector<int32_t> tokens = encode(text, on_new_token_cb);
|
|
||||||
|
|
||||||
if (max_length > 0) {
|
|
||||||
if (tokens.size() < max_length) {
|
|
||||||
tokens.resize(max_length);
|
|
||||||
} else {
|
|
||||||
if (padding) {
|
|
||||||
tokens.insert(tokens.end(), max_length - tokens.size(), PAD_TOKEN_ID);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return tokens;
|
|
||||||
}
|
|
||||||
|
|
||||||
void pad_tokens(std::vector<int>& tokens,
|
|
||||||
std::vector<float>& weights,
|
|
||||||
size_t max_length = 0,
|
|
||||||
bool padding = false) {
|
|
||||||
if (add_bos_token) {
|
|
||||||
tokens.insert(tokens.begin(), BOS_TOKEN_ID);
|
|
||||||
weights.insert(weights.begin(), 1.f);
|
|
||||||
}
|
|
||||||
if (max_length > 0 && padding) {
|
|
||||||
size_t n = static_cast<size_t>(std::ceil(tokens.size() * 1.f / max_length));
|
|
||||||
if (n == 0) {
|
|
||||||
n = 1;
|
|
||||||
}
|
|
||||||
size_t length = max_length * n;
|
|
||||||
LOG_DEBUG("token length: %llu", length);
|
|
||||||
tokens.insert(tokens.end(), length - tokens.size(), PAD_TOKEN_ID);
|
|
||||||
weights.insert(weights.end(), length - weights.size(), 1.f);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<int> encode(std::string text, on_new_token_cb_t on_new_token_cb = nullptr) {
|
|
||||||
std::string original_text = text;
|
|
||||||
std::vector<int32_t> bpe_tokens;
|
|
||||||
std::vector<std::string> token_strs;
|
|
||||||
|
|
||||||
auto splited_texts = split_with_special_tokens(text, special_tokens);
|
|
||||||
|
|
||||||
for (auto& splited_text : splited_texts) {
|
|
||||||
if (is_special_token(splited_text)) {
|
|
||||||
bpe_tokens.push_back(encoder[utf8_to_utf32(splited_text)]);
|
|
||||||
token_strs.push_back(splited_text);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
auto tokens = token_split(splited_text);
|
|
||||||
for (auto& token : tokens) {
|
|
||||||
if (on_new_token_cb != nullptr) {
|
|
||||||
bool skip = on_new_token_cb(token, bpe_tokens);
|
|
||||||
if (skip) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string token_str = token;
|
|
||||||
std::u32string utf32_token;
|
|
||||||
for (int i = 0; i < token_str.length(); i++) {
|
|
||||||
unsigned char b = token_str[i];
|
|
||||||
utf32_token += byte_encoder[b];
|
|
||||||
}
|
|
||||||
auto bpe_strs = bpe(utf32_token);
|
|
||||||
size_t start = 0;
|
|
||||||
size_t pos;
|
|
||||||
while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) {
|
|
||||||
auto bpe_str = bpe_strs.substr(start, pos - start);
|
|
||||||
bpe_tokens.push_back(encoder[bpe_str]);
|
|
||||||
token_strs.push_back(utf32_to_utf8(bpe_str));
|
|
||||||
|
|
||||||
start = pos + 1;
|
|
||||||
}
|
|
||||||
auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start);
|
|
||||||
bpe_tokens.push_back(encoder[bpe_str]);
|
|
||||||
token_strs.push_back(utf32_to_utf8(bpe_str));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::stringstream ss;
|
|
||||||
ss << "[";
|
|
||||||
for (auto token : token_strs) {
|
|
||||||
ss << "\"" << token << "\", ";
|
|
||||||
}
|
|
||||||
ss << "]";
|
|
||||||
LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str());
|
|
||||||
// printf("split prompt \"%s\" to tokens %s \n", original_text.c_str(), ss.str().c_str());
|
|
||||||
return bpe_tokens;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class Qwen2Tokenizer : public BPETokenizer {
|
|
||||||
protected:
|
|
||||||
void load_from_merges(const std::string& merges_utf8_str) {
|
|
||||||
auto byte_unicode_pairs = bytes_to_unicode();
|
|
||||||
// printf("byte_unicode_pairs have %lu pairs \n", byte_unicode_pairs.size());
|
|
||||||
byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
|
|
||||||
for (auto& pair : byte_unicode_pairs) {
|
|
||||||
byte_decoder[pair.second] = pair.first;
|
|
||||||
}
|
|
||||||
// for (auto & pair: byte_unicode_pairs) {
|
|
||||||
// std::cout << pair.first << ": " << pair.second << std::endl;
|
|
||||||
// }
|
|
||||||
std::vector<std::u32string> merges;
|
|
||||||
size_t start = 0;
|
|
||||||
size_t pos;
|
|
||||||
std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str);
|
|
||||||
while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) {
|
|
||||||
merges.push_back(merges_utf32_str.substr(start, pos - start));
|
|
||||||
start = pos + 1;
|
|
||||||
}
|
|
||||||
LOG_DEBUG("merges size %llu", merges.size());
|
|
||||||
merges = std::vector<std::u32string>(merges.begin(), merges.end());
|
|
||||||
std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
|
|
||||||
// int print_num = 10;
|
|
||||||
for (const auto& merge : merges) {
|
|
||||||
size_t space_pos = merge.find(' ');
|
|
||||||
merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
|
|
||||||
// if (print_num > 0) {
|
|
||||||
// print_num--;
|
|
||||||
// printf("%s :: %s | %s \n", utf32_to_utf8(merge).c_str(), utf32_to_utf8(merge.substr(0, space_pos)).c_str(),
|
|
||||||
// utf32_to_utf8(merge.substr(space_pos + 1)).c_str());
|
|
||||||
// }
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::u32string> tokens;
|
|
||||||
for (const auto& pair : byte_unicode_pairs) {
|
|
||||||
tokens.push_back(pair.second);
|
|
||||||
}
|
|
||||||
for (const auto& merge : merge_pairs) {
|
|
||||||
tokens.push_back(merge.first + merge.second);
|
|
||||||
}
|
|
||||||
for (auto& special_token : special_tokens) {
|
|
||||||
tokens.push_back(utf8_to_utf32(special_token));
|
|
||||||
}
|
|
||||||
|
|
||||||
int i = 0;
|
|
||||||
for (const auto& token : tokens) {
|
|
||||||
encoder[token] = i;
|
|
||||||
decoder[i] = token;
|
|
||||||
i++;
|
|
||||||
}
|
|
||||||
encoder_len = i;
|
|
||||||
LOG_DEBUG("vocab size: %d", encoder_len);
|
|
||||||
|
|
||||||
int rank = 0;
|
|
||||||
for (const auto& merge : merge_pairs) {
|
|
||||||
bpe_ranks[merge] = rank++;
|
|
||||||
}
|
|
||||||
bpe_len = rank;
|
|
||||||
};
|
|
||||||
|
|
||||||
public:
|
|
||||||
explicit Qwen2Tokenizer(const std::string& merges_utf8_str = "") {
|
|
||||||
UNK_TOKEN = "<|endoftext|>";
|
|
||||||
EOS_TOKEN = "<|endoftext|>";
|
|
||||||
PAD_TOKEN = "<|endoftext|>";
|
|
||||||
|
|
||||||
UNK_TOKEN_ID = 151643;
|
|
||||||
EOS_TOKEN_ID = 151643;
|
|
||||||
PAD_TOKEN_ID = 151643;
|
|
||||||
|
|
||||||
special_tokens = {
|
|
||||||
"<|endoftext|>",
|
|
||||||
"<|im_start|>",
|
|
||||||
"<|im_end|>",
|
|
||||||
"<|object_ref_start|>",
|
|
||||||
"<|object_ref_end|>",
|
|
||||||
"<|box_start|>",
|
|
||||||
"<|box_end|>",
|
|
||||||
"<|quad_start|>",
|
|
||||||
"<|quad_end|>",
|
|
||||||
"<|vision_start|>",
|
|
||||||
"<|vision_end|>",
|
|
||||||
"<|vision_pad|>",
|
|
||||||
"<|image_pad|>",
|
|
||||||
"<|video_pad|>",
|
|
||||||
"<tool_call>",
|
|
||||||
"</tool_call>",
|
|
||||||
"<|fim_prefix|>",
|
|
||||||
"<|fim_middle|>",
|
|
||||||
"<|fim_suffix|>",
|
|
||||||
"<|fim_pad|>",
|
|
||||||
"<|repo_name|>",
|
|
||||||
"<|file_sep|>",
|
|
||||||
"<tool_response>",
|
|
||||||
"</tool_response>",
|
|
||||||
"<think>",
|
|
||||||
"</think>",
|
|
||||||
};
|
|
||||||
|
|
||||||
if (merges_utf8_str.size() > 0) {
|
|
||||||
load_from_merges(merges_utf8_str);
|
|
||||||
} else {
|
|
||||||
load_from_merges(load_qwen2_merges());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class MistralTokenizer : public BPETokenizer {
|
|
||||||
protected:
|
|
||||||
void load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) {
|
|
||||||
nlohmann::json vocab;
|
|
||||||
|
|
||||||
try {
|
|
||||||
vocab = nlohmann::json::parse(vocab_utf8_str);
|
|
||||||
} catch (const nlohmann::json::parse_error&) {
|
|
||||||
GGML_ABORT("invalid vocab json str");
|
|
||||||
}
|
|
||||||
for (const auto& [key, value] : vocab.items()) {
|
|
||||||
std::u32string token = utf8_to_utf32(key);
|
|
||||||
int i = value;
|
|
||||||
encoder[token] = i;
|
|
||||||
decoder[i] = token;
|
|
||||||
}
|
|
||||||
encoder_len = static_cast<int>(vocab.size());
|
|
||||||
LOG_DEBUG("vocab size: %d", encoder_len);
|
|
||||||
|
|
||||||
auto byte_unicode_pairs = bytes_to_unicode();
|
|
||||||
byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
|
|
||||||
for (auto& pair : byte_unicode_pairs) {
|
|
||||||
byte_decoder[pair.second] = pair.first;
|
|
||||||
}
|
|
||||||
std::vector<std::u32string> merges;
|
|
||||||
size_t start = 0;
|
|
||||||
size_t pos;
|
|
||||||
std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str);
|
|
||||||
while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) {
|
|
||||||
merges.push_back(merges_utf32_str.substr(start, pos - start));
|
|
||||||
start = pos + 1;
|
|
||||||
}
|
|
||||||
LOG_DEBUG("merges size %llu", merges.size());
|
|
||||||
merges = std::vector<std::u32string>(merges.begin(), merges.end());
|
|
||||||
std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
|
|
||||||
// int print_num = 10;
|
|
||||||
for (const auto& merge : merges) {
|
|
||||||
size_t space_pos = merge.find(' ');
|
|
||||||
merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
|
|
||||||
// if (print_num > 0) {
|
|
||||||
// print_num--;
|
|
||||||
// printf("%s :: %s | %s \n", utf32_to_utf8(merge).c_str(), utf32_to_utf8(merge.substr(0, space_pos)).c_str(),
|
|
||||||
// utf32_to_utf8(merge.substr(space_pos + 1)).c_str());
|
|
||||||
// }
|
|
||||||
}
|
|
||||||
|
|
||||||
int rank = 0;
|
|
||||||
for (const auto& merge : merge_pairs) {
|
|
||||||
bpe_ranks[merge] = rank++;
|
|
||||||
}
|
|
||||||
bpe_len = rank;
|
|
||||||
};
|
|
||||||
|
|
||||||
public:
|
|
||||||
explicit MistralTokenizer(const std::string& merges_utf8_str = "", const std::string& vocab_utf8_str = "") {
|
|
||||||
add_bos_token = true;
|
|
||||||
|
|
||||||
UNK_TOKEN = "<unk>";
|
|
||||||
BOS_TOKEN = "<s>";
|
|
||||||
EOS_TOKEN = "</s>";
|
|
||||||
PAD_TOKEN = "<pad>";
|
|
||||||
|
|
||||||
UNK_TOKEN_ID = 0;
|
|
||||||
BOS_TOKEN_ID = 1;
|
|
||||||
EOS_TOKEN_ID = 2;
|
|
||||||
PAD_TOKEN_ID = 11;
|
|
||||||
|
|
||||||
special_tokens = {
|
|
||||||
"<unk>",
|
|
||||||
"<s>",
|
|
||||||
"</s>",
|
|
||||||
"[INST]",
|
|
||||||
"[/INST]",
|
|
||||||
"[AVAILABLE_TOOLS]",
|
|
||||||
"[/AVAILABLE_TOOLS]",
|
|
||||||
"[TOOL_RESULTS]",
|
|
||||||
"[/TOOL_RESULTS]",
|
|
||||||
"[TOOL_CALLS]",
|
|
||||||
"[IMG]",
|
|
||||||
"<pad>",
|
|
||||||
"[IMG_BREAK]",
|
|
||||||
"[IMG_END]",
|
|
||||||
"[PREFIX]",
|
|
||||||
"[MIDDLE]",
|
|
||||||
"[SUFFIX]",
|
|
||||||
"[SYSTEM_PROMPT]",
|
|
||||||
"[/SYSTEM_PROMPT]",
|
|
||||||
"[TOOL_CONTENT]",
|
|
||||||
};
|
|
||||||
for (int i = 20; i < 1000; i++) {
|
|
||||||
special_tokens.push_back("<SPECIAL_" + std::to_string(i) + ">");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (merges_utf8_str.size() > 0 && vocab_utf8_str.size() > 0) {
|
|
||||||
load_from_merges(merges_utf8_str, vocab_utf8_str);
|
|
||||||
} else {
|
|
||||||
load_from_merges(load_mistral_merges(), load_mistral_vocab_json());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
enum class LLMArch {
|
enum class LLMArch {
|
||||||
QWEN2_5_VL,
|
QWEN2_5_VL,
|
||||||
QWEN3,
|
QWEN3,
|
||||||
@ -1479,7 +1030,7 @@ namespace LLM {
|
|||||||
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
|
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenizer->pad_tokens(tokens, weights, max_length, padding);
|
tokenizer->pad_tokens(tokens, &weights, nullptr, padding ? max_length : 0, padding ? max_length : 100000000, padding);
|
||||||
|
|
||||||
// for (int i = 0; i < tokens.size(); i++) {
|
// for (int i = 0; i < tokens.size(); i++) {
|
||||||
// std::cout << tokens[i] << ":" << weights[i] << ", ";
|
// std::cout << tokens[i] << ":" << weights[i] << ", ";
|
||||||
|
|||||||
456
src/t5.hpp
456
src/t5.hpp
@ -10,452 +10,9 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
#include "darts.h"
|
|
||||||
#include "ggml_extend.hpp"
|
#include "ggml_extend.hpp"
|
||||||
#include "json.hpp"
|
|
||||||
#include "model.h"
|
#include "model.h"
|
||||||
#include "vocab/vocab.h"
|
#include "tokenizers/t5_unigram_tokenizer.h"
|
||||||
|
|
||||||
// Port from: https://github.com/google/sentencepiece/blob/master/src/unigram_model.h
|
|
||||||
// and https://github.com/google/sentencepiece/blob/master/src/unigram_model.h.
|
|
||||||
// Original License: https://github.com/google/sentencepiece/blob/master/LICENSE
|
|
||||||
//
|
|
||||||
// Since tokenization is not the bottleneck in SD, performance was not a major consideration
|
|
||||||
// during the migration.
|
|
||||||
class MetaspacePreTokenizer {
|
|
||||||
private:
|
|
||||||
std::string replacement;
|
|
||||||
bool add_prefix_space;
|
|
||||||
|
|
||||||
public:
|
|
||||||
MetaspacePreTokenizer(const std::string replacement = " ", bool add_prefix_space = true)
|
|
||||||
: replacement(replacement), add_prefix_space(add_prefix_space) {}
|
|
||||||
|
|
||||||
std::string tokenize(const std::string& input) const {
|
|
||||||
std::string tokens;
|
|
||||||
std::stringstream ss(input);
|
|
||||||
|
|
||||||
if (add_prefix_space) {
|
|
||||||
tokens += replacement;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string token;
|
|
||||||
bool firstToken = true;
|
|
||||||
while (std::getline(ss, token, ' ')) {
|
|
||||||
if (!firstToken)
|
|
||||||
tokens += replacement + token;
|
|
||||||
else
|
|
||||||
tokens += token;
|
|
||||||
|
|
||||||
firstToken = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return tokens;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
using EncodeResult = std::vector<std::pair<std::string, int>>;
|
|
||||||
class T5UniGramTokenizer {
|
|
||||||
public:
|
|
||||||
enum Status {
|
|
||||||
OK,
|
|
||||||
NO_PIECES_LOADED,
|
|
||||||
NO_ENTRY_FOUND,
|
|
||||||
BUILD_DOUBLE_ARRAY_FAILED,
|
|
||||||
PIECE_ALREADY_DEFINED,
|
|
||||||
INVLIAD_JSON
|
|
||||||
};
|
|
||||||
|
|
||||||
protected:
|
|
||||||
MetaspacePreTokenizer pre_tokenizer;
|
|
||||||
|
|
||||||
// all <piece, score> pairs
|
|
||||||
std::vector<std::pair<std::string, float>> piece_score_pairs;
|
|
||||||
|
|
||||||
float min_score_ = 0.0;
|
|
||||||
float max_score_ = 0.0;
|
|
||||||
std::unique_ptr<Darts::DoubleArray> trie_;
|
|
||||||
|
|
||||||
// Maximum size of the return value of Trie, which corresponds
|
|
||||||
// to the maximum size of shared common prefix in the sentence pieces.
|
|
||||||
int trie_results_size_;
|
|
||||||
// unknown id.
|
|
||||||
int unk_id_ = 2;
|
|
||||||
std::string eos_token_ = "</s>";
|
|
||||||
int eos_id_ = 1;
|
|
||||||
int pad_id_ = 0;
|
|
||||||
// status.
|
|
||||||
Status status_ = OK;
|
|
||||||
|
|
||||||
float kUnkPenalty = 10.0;
|
|
||||||
|
|
||||||
std::string replacement;
|
|
||||||
bool add_prefix_space = true;
|
|
||||||
|
|
||||||
void InitializePieces(const std::string& json_str) {
|
|
||||||
nlohmann::json data;
|
|
||||||
|
|
||||||
try {
|
|
||||||
data = nlohmann::json::parse(json_str);
|
|
||||||
} catch (const nlohmann::json::parse_error&) {
|
|
||||||
status_ = INVLIAD_JSON;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (!data.contains("model")) {
|
|
||||||
status_ = INVLIAD_JSON;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
nlohmann::json model = data["model"];
|
|
||||||
if (!model.contains("vocab")) {
|
|
||||||
status_ = INVLIAD_JSON;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (model.contains("unk_id")) {
|
|
||||||
unk_id_ = model["unk_id"];
|
|
||||||
}
|
|
||||||
|
|
||||||
replacement = data["pre_tokenizer"]["replacement"];
|
|
||||||
add_prefix_space = data["pre_tokenizer"]["add_prefix_space"];
|
|
||||||
|
|
||||||
pre_tokenizer = MetaspacePreTokenizer(replacement, add_prefix_space);
|
|
||||||
|
|
||||||
for (const auto& item : model["vocab"]) {
|
|
||||||
if (item.size() != 2 || !item[0].is_string() || !item[1].is_number_float()) {
|
|
||||||
status_ = INVLIAD_JSON;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
std::string piece = item[0];
|
|
||||||
if (piece.empty()) {
|
|
||||||
piece = "<empty_token>";
|
|
||||||
}
|
|
||||||
float score = item[1];
|
|
||||||
piece_score_pairs.emplace_back(piece, score);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Builds a Trie index.
|
|
||||||
void BuildTrie(std::vector<std::pair<std::string, int>>* pieces) {
|
|
||||||
if (status_ != OK)
|
|
||||||
return;
|
|
||||||
|
|
||||||
if (pieces->empty()) {
|
|
||||||
status_ = NO_PIECES_LOADED;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// sort by sentencepiece since DoubleArray::build()
|
|
||||||
// only accepts sorted strings.
|
|
||||||
sort(pieces->begin(), pieces->end());
|
|
||||||
|
|
||||||
// Makes key/value set for DoubleArrayTrie.
|
|
||||||
std::vector<const char*> key(pieces->size());
|
|
||||||
std::vector<int> value(pieces->size());
|
|
||||||
for (size_t i = 0; i < pieces->size(); ++i) {
|
|
||||||
// LOG_DEBUG("%s %d", (*pieces)[i].first.c_str(), (*pieces)[i].second);
|
|
||||||
key[i] = (*pieces)[i].first.data(); // sorted piece.
|
|
||||||
value[i] = (*pieces)[i].second; // vocab_id
|
|
||||||
}
|
|
||||||
|
|
||||||
trie_ = std::unique_ptr<Darts::DoubleArray>(new Darts::DoubleArray());
|
|
||||||
if (trie_->build(key.size(), const_cast<char**>(&key[0]), nullptr,
|
|
||||||
&value[0]) != 0) {
|
|
||||||
status_ = BUILD_DOUBLE_ARRAY_FAILED;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Computes the maximum number of shared prefixes in the trie.
|
|
||||||
const int kMaxTrieResultsSize = 1024;
|
|
||||||
std::vector<Darts::DoubleArray::result_pair_type> results(
|
|
||||||
kMaxTrieResultsSize);
|
|
||||||
trie_results_size_ = 0;
|
|
||||||
for (const auto& p : *pieces) {
|
|
||||||
const size_t num_nodes = trie_->commonPrefixSearch(
|
|
||||||
p.first.data(), results.data(), results.size(), p.first.size());
|
|
||||||
trie_results_size_ = std::max(trie_results_size_, static_cast<int>(num_nodes));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (trie_results_size_ == 0)
|
|
||||||
status_ = NO_ENTRY_FOUND;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Non-virtual (inlined) implementation for faster execution.
|
|
||||||
inline float GetScoreInlined(int id) const {
|
|
||||||
return piece_score_pairs[id].second;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool IsUnusedInlined(int id) const {
|
|
||||||
return false; // TODO
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool IsUserDefinedInlined(int id) const {
|
|
||||||
return false; // TODO
|
|
||||||
}
|
|
||||||
|
|
||||||
inline size_t OneCharLen(const char* src) const {
|
|
||||||
return "\1\1\1\1\1\1\1\1\1\1\1\1\2\2\3\4"[(*src & 0xFF) >> 4];
|
|
||||||
}
|
|
||||||
|
|
||||||
// The optimized Viterbi encode.
|
|
||||||
// Main differences from the original function:
|
|
||||||
// 1. Memorizes the best path at each postion so far,
|
|
||||||
// 2. No need to store the Lattice nodes,
|
|
||||||
// 3. Works in utf-8 directly,
|
|
||||||
// 4. Defines a new struct with fewer fields than Lattice,
|
|
||||||
// 5. Does not depend on `class Lattice` nor call `SetSentence()`,
|
|
||||||
// `PopulateNodes()`, or `Viterbi()`. It does everything in one function.
|
|
||||||
// For detailed explanations please see the comments inside the function body.
|
|
||||||
EncodeResult EncodeOptimized(const std::string& normalized) const {
|
|
||||||
// An optimized Viterbi algorithm for unigram language models. Benchmarking
|
|
||||||
// results show that it generates almost identical outputs and achieves 2.1x
|
|
||||||
// speedup on average for 102 languages compared to the original
|
|
||||||
// implementation. It's based on the following three ideas:
|
|
||||||
//
|
|
||||||
// 1. Because it uses the *unigram* model:
|
|
||||||
// best_score(x1, x2, ... xt) = best_score(x1, x2, ... x{t-1}) + score(xt)
|
|
||||||
// Deciding the best path (and score) can be decoupled into two isolated
|
|
||||||
// terms: (a) the best path ended before the last token `best_score(x1, x2, ...)`
|
|
||||||
// x{t-1})`, and (b) the last token and its `score(xt)`. The two terms are
|
|
||||||
// not related to each other at all.
|
|
||||||
//
|
|
||||||
// Therefore, we can compute once and store the *best_path ending at
|
|
||||||
// each character position*. In this way, when we know best_path_ends_at[M],
|
|
||||||
// we can reuse it to compute all the best_path_ends_at_[...] where the last
|
|
||||||
// token starts at the same character position M.
|
|
||||||
//
|
|
||||||
// This improves the time complexity from O(n*k*k) to O(n*k) because it
|
|
||||||
// eliminates the extra loop of recomputing the best path ending at the same
|
|
||||||
// position, where n is the input length and k is the maximum number of tokens
|
|
||||||
// that can be recognized starting at each position.
|
|
||||||
//
|
|
||||||
// 2. Again, because it uses the *unigram* model, we don't need to actually
|
|
||||||
// store the lattice nodes. We still recognize all the tokens and lattice
|
|
||||||
// nodes from the input, but along identifying them, we use and discard them
|
|
||||||
// on the fly. There is no need to actually store them for best path Viterbi
|
|
||||||
// decoding. The only thing we need to store is the best_path ending at
|
|
||||||
// each character position.
|
|
||||||
//
|
|
||||||
// This improvement reduces the things needed to store in memory from O(n*k)
|
|
||||||
// to O(n), where n is the input length and k is the maximum number of tokens
|
|
||||||
// that can be recognized starting at each position.
|
|
||||||
//
|
|
||||||
// It also avoids the need of dynamic-size lattice node pool, because the
|
|
||||||
// number of things to store is fixed as n.
|
|
||||||
//
|
|
||||||
// 3. SentencePiece is designed to work with unicode, taking utf-8 encoding
|
|
||||||
// inputs. In the original implementation, the lattice positions are based on
|
|
||||||
// unicode positions. A mapping from unicode position to the utf-8 position is
|
|
||||||
// maintained to recover the utf-8 string piece.
|
|
||||||
//
|
|
||||||
// We found that it is sufficient and beneficial to directly work with utf-8
|
|
||||||
// positions:
|
|
||||||
//
|
|
||||||
// Firstly, it saves the conversion and mapping between unicode positions and
|
|
||||||
// utf-8 positions.
|
|
||||||
//
|
|
||||||
// Secondly, it reduces the number of fields we need to maintain in the
|
|
||||||
// node/path structure. Specifically, there are 8 fields defined in
|
|
||||||
// `Lattice::Node` used by the original encoder, but here in the optimized
|
|
||||||
// encoder we only need to define 3 fields in `BestPathNode`.
|
|
||||||
|
|
||||||
if (status() != OK || normalized.empty()) {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
// Represents the last node of the best path.
|
|
||||||
struct BestPathNode {
|
|
||||||
int id = -1; // The vocab id. (maybe -1 for UNK)
|
|
||||||
float best_path_score =
|
|
||||||
0; // The total score of the best path ending at this node.
|
|
||||||
int starts_at =
|
|
||||||
-1; // The starting position (in utf-8) of this node. The entire best
|
|
||||||
// path can be constructed by backtracking along this link.
|
|
||||||
};
|
|
||||||
const int size = static_cast<int>(normalized.size());
|
|
||||||
const float unk_score = min_score() - kUnkPenalty;
|
|
||||||
// The ends are exclusive.
|
|
||||||
std::vector<BestPathNode> best_path_ends_at(size + 1);
|
|
||||||
// Generate lattice on-the-fly (not stored) and update best_path_ends_at.
|
|
||||||
int starts_at = 0;
|
|
||||||
while (starts_at < size) {
|
|
||||||
std::size_t node_pos = 0;
|
|
||||||
std::size_t key_pos = starts_at;
|
|
||||||
const auto best_path_score_till_here =
|
|
||||||
best_path_ends_at[starts_at].best_path_score;
|
|
||||||
bool has_single_node = false;
|
|
||||||
const int mblen =
|
|
||||||
std::min<int>(static_cast<int>(OneCharLen(normalized.data() + starts_at)),
|
|
||||||
size - starts_at);
|
|
||||||
while (key_pos < size) {
|
|
||||||
const int ret =
|
|
||||||
trie_->traverse(normalized.data(), node_pos, key_pos, key_pos + 1);
|
|
||||||
if (ret == -2)
|
|
||||||
break;
|
|
||||||
if (ret >= 0) {
|
|
||||||
if (IsUnusedInlined(ret))
|
|
||||||
continue;
|
|
||||||
// Update the best path node.
|
|
||||||
auto& target_node = best_path_ends_at[key_pos];
|
|
||||||
const auto length = (key_pos - starts_at);
|
|
||||||
// User defined symbol receives extra bonus to always be selected.
|
|
||||||
const auto score = IsUserDefinedInlined(ret)
|
|
||||||
? (length * max_score_ - 0.1)
|
|
||||||
: GetScoreInlined(ret);
|
|
||||||
const auto candidate_best_path_score =
|
|
||||||
score + best_path_score_till_here;
|
|
||||||
if (target_node.starts_at == -1 ||
|
|
||||||
candidate_best_path_score > target_node.best_path_score) {
|
|
||||||
target_node.best_path_score = static_cast<float>(candidate_best_path_score);
|
|
||||||
target_node.starts_at = starts_at;
|
|
||||||
target_node.id = ret;
|
|
||||||
}
|
|
||||||
if (!has_single_node && length == mblen) {
|
|
||||||
has_single_node = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!has_single_node) {
|
|
||||||
auto& target_node = best_path_ends_at[starts_at + mblen];
|
|
||||||
const auto candidate_best_path_score =
|
|
||||||
unk_score + best_path_score_till_here;
|
|
||||||
if (target_node.starts_at == -1 ||
|
|
||||||
candidate_best_path_score > target_node.best_path_score) {
|
|
||||||
target_node.best_path_score = candidate_best_path_score;
|
|
||||||
target_node.starts_at = starts_at;
|
|
||||||
target_node.id = unk_id_;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Move by one unicode character.
|
|
||||||
starts_at += mblen;
|
|
||||||
}
|
|
||||||
// Backtrack to identify the best path.
|
|
||||||
EncodeResult results;
|
|
||||||
int ends_at = size;
|
|
||||||
while (ends_at > 0) {
|
|
||||||
const auto& node = best_path_ends_at[ends_at];
|
|
||||||
results.emplace_back(
|
|
||||||
normalized.substr(node.starts_at, ends_at - node.starts_at), node.id);
|
|
||||||
ends_at = node.starts_at;
|
|
||||||
}
|
|
||||||
std::reverse(results.begin(), results.end());
|
|
||||||
return results;
|
|
||||||
}
|
|
||||||
|
|
||||||
public:
|
|
||||||
explicit T5UniGramTokenizer(bool is_umt5 = false) {
|
|
||||||
if (is_umt5) {
|
|
||||||
InitializePieces(load_umt5_tokenizer_json());
|
|
||||||
} else {
|
|
||||||
InitializePieces(load_t5_tokenizer_json());
|
|
||||||
}
|
|
||||||
|
|
||||||
min_score_ = FLT_MAX;
|
|
||||||
max_score_ = FLT_MIN;
|
|
||||||
|
|
||||||
std::vector<std::pair<std::string, int>> pieces;
|
|
||||||
for (int i = 0; i < piece_score_pairs.size(); i++) {
|
|
||||||
const auto& sp = piece_score_pairs[i];
|
|
||||||
|
|
||||||
min_score_ = std::min(min_score_, sp.second);
|
|
||||||
max_score_ = std::max(max_score_, sp.second);
|
|
||||||
|
|
||||||
pieces.emplace_back(sp.first, i);
|
|
||||||
}
|
|
||||||
|
|
||||||
BuildTrie(&pieces);
|
|
||||||
}
|
|
||||||
~T5UniGramTokenizer(){};
|
|
||||||
|
|
||||||
std::string Normalize(const std::string& input) const {
|
|
||||||
// Ref: https://github.com/huggingface/tokenizers/blob/1ff56c0c70b045f0cd82da1af9ac08cd4c7a6f9f/bindings/python/py_src/tokenizers/implementations/sentencepiece_unigram.py#L29
|
|
||||||
// TODO: nmt-nfkc
|
|
||||||
std::string normalized = std::regex_replace(input, std::regex(" {2,}"), " ");
|
|
||||||
return normalized;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<int> Encode(const std::string& input, bool append_eos_if_not_present = true) const {
|
|
||||||
std::string normalized = Normalize(input);
|
|
||||||
normalized = pre_tokenizer.tokenize(normalized);
|
|
||||||
EncodeResult result = EncodeOptimized(normalized);
|
|
||||||
if (result.size() > 0 && append_eos_if_not_present) {
|
|
||||||
auto item = result[result.size() - 1];
|
|
||||||
if (item.first != eos_token_) {
|
|
||||||
result.emplace_back(eos_token_, eos_id_);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
std::vector<int> tokens;
|
|
||||||
for (auto item : result) {
|
|
||||||
tokens.push_back(item.second);
|
|
||||||
}
|
|
||||||
return tokens;
|
|
||||||
}
|
|
||||||
|
|
||||||
void pad_tokens(std::vector<int>& tokens,
|
|
||||||
std::vector<float>& weights,
|
|
||||||
std::vector<float>* attention_mask,
|
|
||||||
size_t max_length = 0,
|
|
||||||
bool padding = false) {
|
|
||||||
if (max_length > 0 && padding) {
|
|
||||||
size_t orig_token_num = tokens.size() - 1;
|
|
||||||
size_t n = static_cast<size_t>(std::ceil(orig_token_num * 1.0 / (max_length - 1)));
|
|
||||||
if (n == 0) {
|
|
||||||
n = 1;
|
|
||||||
}
|
|
||||||
size_t length = max_length * n;
|
|
||||||
LOG_DEBUG("token length: %llu", length);
|
|
||||||
std::vector<int> new_tokens;
|
|
||||||
std::vector<float> new_weights;
|
|
||||||
std::vector<float> new_attention_mask;
|
|
||||||
int token_idx = 0;
|
|
||||||
for (int i = 0; i < length; i++) {
|
|
||||||
if (token_idx >= orig_token_num) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if (attention_mask != nullptr) {
|
|
||||||
new_attention_mask.push_back(0.0);
|
|
||||||
}
|
|
||||||
if (i % max_length == max_length - 1) {
|
|
||||||
new_tokens.push_back(eos_id_);
|
|
||||||
new_weights.push_back(1.0);
|
|
||||||
} else {
|
|
||||||
new_tokens.push_back(tokens[token_idx]);
|
|
||||||
new_weights.push_back(weights[token_idx]);
|
|
||||||
token_idx++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
new_tokens.push_back(eos_id_);
|
|
||||||
new_weights.push_back(1.0);
|
|
||||||
if (attention_mask != nullptr) {
|
|
||||||
new_attention_mask.push_back(0.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
tokens = new_tokens;
|
|
||||||
weights = new_weights;
|
|
||||||
if (attention_mask != nullptr) {
|
|
||||||
*attention_mask = new_attention_mask;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (padding) {
|
|
||||||
int pad_token_id = pad_id_;
|
|
||||||
tokens.insert(tokens.end(), length - tokens.size(), pad_token_id);
|
|
||||||
weights.insert(weights.end(), length - weights.size(), 1.0);
|
|
||||||
if (attention_mask != nullptr) {
|
|
||||||
// maybe keep some padding tokens unmasked?
|
|
||||||
attention_mask->insert(attention_mask->end(), length - attention_mask->size(), -HUGE_VALF);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the minimum score in sentence pieces.
|
|
||||||
// min_score() - 10 is used for the cost of unknown sentence.
|
|
||||||
float min_score() const { return min_score_; }
|
|
||||||
|
|
||||||
// Returns the maximum score in sentence pieces.
|
|
||||||
// max_score() is used for the cost of user defined symbols.
|
|
||||||
float max_score() const { return max_score_; }
|
|
||||||
|
|
||||||
Status status() const { return status_; }
|
|
||||||
};
|
|
||||||
|
|
||||||
class T5LayerNorm : public UnaryBlock {
|
class T5LayerNorm : public UnaryBlock {
|
||||||
protected:
|
protected:
|
||||||
@ -937,18 +494,17 @@ struct T5Embedder {
|
|||||||
for (const auto& item : parsed_attention) {
|
for (const auto& item : parsed_attention) {
|
||||||
const std::string& curr_text = item.first;
|
const std::string& curr_text = item.first;
|
||||||
float curr_weight = item.second;
|
float curr_weight = item.second;
|
||||||
std::vector<int> curr_tokens = tokenizer.Encode(curr_text, false);
|
std::vector<int> curr_tokens = tokenizer.encode(curr_text);
|
||||||
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
|
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
|
||||||
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
|
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
|
||||||
}
|
}
|
||||||
|
|
||||||
int EOS_TOKEN_ID = 1;
|
|
||||||
tokens.push_back(EOS_TOKEN_ID);
|
|
||||||
weights.push_back(1.0);
|
|
||||||
|
|
||||||
std::vector<float> attention_mask;
|
std::vector<float> attention_mask;
|
||||||
|
|
||||||
tokenizer.pad_tokens(tokens, weights, &attention_mask, max_length, padding);
|
tokenizer.pad_tokens(tokens, &weights, &attention_mask, padding ? max_length : 0, padding ? max_length : 100000000, padding);
|
||||||
|
for (auto& mask_value : attention_mask) {
|
||||||
|
mask_value = mask_value > 0.0f ? 0.0f : -HUGE_VALF;
|
||||||
|
}
|
||||||
|
|
||||||
// for (int i = 0; i < tokens.size(); i++) {
|
// for (int i = 0; i < tokens.size(); i++) {
|
||||||
// std::cout << tokens[i] << ":" << weights[i] << ", ";
|
// std::cout << tokens[i] << ":" << weights[i] << ", ";
|
||||||
|
|||||||
189
src/tokenizers/bpe_tokenizer.cpp
Normal file
189
src/tokenizers/bpe_tokenizer.cpp
Normal file
@ -0,0 +1,189 @@
|
|||||||
|
#include "bpe_tokenizer.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "tokenize_util.h"
|
||||||
|
#include "util.h"
|
||||||
|
|
||||||
|
std::vector<std::pair<int, std::u32string>> BPETokenizer::bytes_to_unicode() {
|
||||||
|
std::vector<std::pair<int, std::u32string>> byte_unicode_pairs;
|
||||||
|
std::set<int> byte_set;
|
||||||
|
for (int b = static_cast<int>('!'); b <= static_cast<int>('~'); ++b) {
|
||||||
|
byte_set.insert(b);
|
||||||
|
byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
|
||||||
|
}
|
||||||
|
for (int b = 161; b <= 172; ++b) {
|
||||||
|
byte_set.insert(b);
|
||||||
|
byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
|
||||||
|
}
|
||||||
|
for (int b = 174; b <= 255; ++b) {
|
||||||
|
byte_set.insert(b);
|
||||||
|
byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
|
||||||
|
}
|
||||||
|
int n = 0;
|
||||||
|
for (int b = 0; b < 256; ++b) {
|
||||||
|
if (byte_set.find(b) == byte_set.end()) {
|
||||||
|
byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(n + 256)));
|
||||||
|
++n;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return byte_unicode_pairs;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> BPETokenizer::token_split(const std::string& text) const {
|
||||||
|
return ::token_split(text);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::u32string> BPETokenizer::split_utf32(const std::string& text, char32_t delimiter) {
|
||||||
|
std::vector<std::u32string> result;
|
||||||
|
size_t start = 0;
|
||||||
|
size_t pos = 0;
|
||||||
|
std::u32string utf32_text = utf8_to_utf32(text);
|
||||||
|
while ((pos = utf32_text.find(delimiter, start)) != std::u32string::npos) {
|
||||||
|
result.push_back(utf32_text.substr(start, pos - start));
|
||||||
|
start = pos + 1;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::set<std::pair<std::u32string, std::u32string>> get_pairs(const std::vector<std::u32string>& subwords) {
|
||||||
|
std::set<std::pair<std::u32string, std::u32string>> pairs;
|
||||||
|
if (subwords.empty()) {
|
||||||
|
return pairs;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::u32string prev_subword = subwords[0];
|
||||||
|
for (int i = 1; i < static_cast<int>(subwords.size()); i++) {
|
||||||
|
std::u32string subword = subwords[i];
|
||||||
|
std::pair<std::u32string, std::u32string> pair(prev_subword, subword);
|
||||||
|
pairs.insert(pair);
|
||||||
|
prev_subword = subword;
|
||||||
|
}
|
||||||
|
return pairs;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::u32string> BPETokenizer::bpe(const std::u32string& token) const {
|
||||||
|
std::vector<std::u32string> word;
|
||||||
|
|
||||||
|
for (int i = 0; i < static_cast<int>(token.size()) - 1; i++) {
|
||||||
|
word.emplace_back(1, token[i]);
|
||||||
|
}
|
||||||
|
word.push_back(token.substr(token.size() - 1) + utf8_to_utf32(end_of_word_suffix));
|
||||||
|
|
||||||
|
std::set<std::pair<std::u32string, std::u32string>> pairs = get_pairs(word);
|
||||||
|
|
||||||
|
if (pairs.empty()) {
|
||||||
|
return {token + utf8_to_utf32(end_of_word_suffix)};
|
||||||
|
}
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
auto min_pair_iter = std::min_element(pairs.begin(),
|
||||||
|
pairs.end(),
|
||||||
|
[&](const std::pair<std::u32string, std::u32string>& a,
|
||||||
|
const std::pair<std::u32string, std::u32string>& b) {
|
||||||
|
if (bpe_ranks.find(a) == bpe_ranks.end()) {
|
||||||
|
return false;
|
||||||
|
} else if (bpe_ranks.find(b) == bpe_ranks.end()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return bpe_ranks.at(a) < bpe_ranks.at(b);
|
||||||
|
});
|
||||||
|
|
||||||
|
const std::pair<std::u32string, std::u32string>& bigram = *min_pair_iter;
|
||||||
|
|
||||||
|
if (bpe_ranks.find(bigram) == bpe_ranks.end()) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::u32string first = bigram.first;
|
||||||
|
std::u32string second = bigram.second;
|
||||||
|
std::vector<std::u32string> new_word;
|
||||||
|
int32_t i = 0;
|
||||||
|
|
||||||
|
while (i < static_cast<int32_t>(word.size())) {
|
||||||
|
auto it = std::find(word.begin() + i, word.end(), first);
|
||||||
|
if (it == word.end()) {
|
||||||
|
new_word.insert(new_word.end(), word.begin() + i, word.end());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
new_word.insert(new_word.end(), word.begin() + i, it);
|
||||||
|
i = static_cast<int32_t>(std::distance(word.begin(), it));
|
||||||
|
|
||||||
|
if (word[i] == first && i < static_cast<int32_t>(word.size()) - 1 && word[i + 1] == second) {
|
||||||
|
new_word.push_back(first + second);
|
||||||
|
i += 2;
|
||||||
|
} else {
|
||||||
|
new_word.push_back(word[i]);
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
word = new_word;
|
||||||
|
|
||||||
|
if (word.size() == 1) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
pairs = get_pairs(word);
|
||||||
|
}
|
||||||
|
|
||||||
|
return word;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> BPETokenizer::encode(const std::string& text, on_new_token_cb_t on_new_token_cb) {
|
||||||
|
std::string normalized_text = normalize(text);
|
||||||
|
std::vector<int32_t> bpe_tokens;
|
||||||
|
std::vector<std::string> token_strs;
|
||||||
|
|
||||||
|
auto splited_texts = split_with_special_tokens(normalized_text, special_tokens);
|
||||||
|
|
||||||
|
for (auto& splited_text : splited_texts) {
|
||||||
|
if (is_special_token(splited_text)) {
|
||||||
|
if (on_new_token_cb != nullptr) {
|
||||||
|
bool skip = on_new_token_cb(splited_text, bpe_tokens);
|
||||||
|
if (skip) {
|
||||||
|
token_strs.push_back(splited_text);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bpe_tokens.push_back(encoder[utf8_to_utf32(splited_text)]);
|
||||||
|
token_strs.push_back(splited_text);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto tokens = token_split(splited_text);
|
||||||
|
for (auto& token : tokens) {
|
||||||
|
if (on_new_token_cb != nullptr) {
|
||||||
|
bool skip = on_new_token_cb(token, bpe_tokens);
|
||||||
|
if (skip) {
|
||||||
|
token_strs.push_back(splited_text);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string token_str = token;
|
||||||
|
std::u32string utf32_token;
|
||||||
|
for (int i = 0; i < static_cast<int>(token_str.length()); i++) {
|
||||||
|
unsigned char b = token_str[i];
|
||||||
|
utf32_token += byte_encoder[b];
|
||||||
|
}
|
||||||
|
auto bpe_strs = bpe(utf32_token);
|
||||||
|
for (auto bpe_str : bpe_strs) {
|
||||||
|
bpe_tokens.push_back(encoder[bpe_str]);
|
||||||
|
token_strs.push_back(utf32_to_utf8(bpe_str));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "[";
|
||||||
|
for (auto token : token_strs) {
|
||||||
|
ss << "\"" << token << "\", ";
|
||||||
|
}
|
||||||
|
ss << "]";
|
||||||
|
LOG_DEBUG("split prompt \"%s\" to tokens %s", text.c_str(), ss.str().c_str());
|
||||||
|
return bpe_tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string BPETokenizer::decode_token(int token_id) const {
|
||||||
|
return utf32_to_utf8(decoder.at(token_id));
|
||||||
|
}
|
||||||
40
src/tokenizers/bpe_tokenizer.h
Normal file
40
src/tokenizers/bpe_tokenizer.h
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
#ifndef __SD_TOKENIZERS_BPE_TOKENIZER_H__
|
||||||
|
#define __SD_TOKENIZERS_BPE_TOKENIZER_H__
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <functional>
|
||||||
|
#include <map>
|
||||||
|
#include <regex>
|
||||||
|
#include <set>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tokenizer.h"
|
||||||
|
|
||||||
|
class BPETokenizer : public Tokenizer {
|
||||||
|
protected:
|
||||||
|
std::map<int, std::u32string> byte_encoder;
|
||||||
|
std::map<std::u32string, int> byte_decoder;
|
||||||
|
std::map<std::u32string, int> encoder;
|
||||||
|
std::map<int, std::u32string> decoder;
|
||||||
|
std::map<std::pair<std::u32string, std::u32string>, int> bpe_ranks;
|
||||||
|
int encoder_len = 0;
|
||||||
|
int bpe_len = 0;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
static std::vector<std::pair<int, std::u32string>> bytes_to_unicode();
|
||||||
|
static std::vector<std::u32string> split_utf32(const std::string& text, char32_t delimiter = U'\n');
|
||||||
|
virtual std::vector<std::string> token_split(const std::string& text) const;
|
||||||
|
std::vector<std::u32string> bpe(const std::u32string& token) const;
|
||||||
|
std::string decode_token(int token_id) const override;
|
||||||
|
|
||||||
|
public:
|
||||||
|
BPETokenizer() = default;
|
||||||
|
virtual ~BPETokenizer() = default;
|
||||||
|
|
||||||
|
std::vector<int> encode(const std::string& text, on_new_token_cb_t on_new_token_cb = nullptr) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // __SD_TOKENIZERS_BPE_TOKENIZER_H__
|
||||||
116
src/tokenizers/clip_tokenizer.cpp
Normal file
116
src/tokenizers/clip_tokenizer.cpp
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
#include "clip_tokenizer.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cctype>
|
||||||
|
#include <cmath>
|
||||||
|
#include <regex>
|
||||||
|
#include <set>
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
#include "tokenize_util.h"
|
||||||
|
#include "util.h"
|
||||||
|
#include "vocab/vocab.h"
|
||||||
|
|
||||||
|
CLIPTokenizer::CLIPTokenizer(int pad_token_id, const std::string& merges_utf8_str) {
|
||||||
|
UNK_TOKEN = "<|endoftext|>";
|
||||||
|
BOS_TOKEN = "<|startoftext|>";
|
||||||
|
EOS_TOKEN = "<|endoftext|>";
|
||||||
|
PAD_TOKEN = "<|endoftext|>";
|
||||||
|
|
||||||
|
UNK_TOKEN_ID = 49407;
|
||||||
|
BOS_TOKEN_ID = 49406;
|
||||||
|
EOS_TOKEN_ID = 49407;
|
||||||
|
PAD_TOKEN_ID = pad_token_id;
|
||||||
|
|
||||||
|
end_of_word_suffix = "</w>";
|
||||||
|
add_bos_token = true;
|
||||||
|
add_eos_token = true;
|
||||||
|
|
||||||
|
if (merges_utf8_str.size() > 0) {
|
||||||
|
load_from_merges(merges_utf8_str);
|
||||||
|
} else {
|
||||||
|
load_from_merges(load_clip_merges());
|
||||||
|
}
|
||||||
|
add_special_token("<|startoftext|>");
|
||||||
|
add_special_token("<|endoftext|>");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Builds the CLIP vocabulary and BPE merge ranks from a merges.txt blob.
// Token ids are assigned purely by insertion order, so the ordering below
// (byte units, byte units + "</w>", merged pairs, then the two special
// tokens) must not change: it is what makes the ids match the reference
// CLIP vocabulary.
void CLIPTokenizer::load_from_merges(const std::string& merges_utf8_str) {
    // Byte <-> printable-unicode tables (GPT-2 style byte-level BPE).
    auto byte_unicode_pairs = bytes_to_unicode();
    byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
    for (auto& pair : byte_unicode_pairs) {
        byte_decoder[pair.second] = pair.first;
    }

    std::vector<std::u32string> merges = split_utf32(merges_utf8_str);
    // The embedded CLIP merges file has exactly 48895 lines (header + 48894 merges).
    GGML_ASSERT(merges.size() == 48895);
    merges = std::vector<std::u32string>(merges.begin() + 1, merges.end());  // drop the header line
    std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
    for (const auto& merge : merges) {
        // Each line is "left right"; split on the first space.
        size_t space_pos = merge.find(' ');
        merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
    }
    std::vector<std::u32string> vocab;
    for (const auto& pair : byte_unicode_pairs) {
        vocab.push_back(pair.second);
    }
    for (const auto& pair : byte_unicode_pairs) {
        vocab.push_back(pair.second + utf8_to_utf32("</w>"));  // word-final variants
    }
    for (const auto& merge : merge_pairs) {
        vocab.push_back(merge.first + merge.second);
    }
    vocab.push_back(utf8_to_utf32("<|startoftext|>"));
    vocab.push_back(utf8_to_utf32("<|endoftext|>"));
    LOG_DEBUG("vocab size: %llu", vocab.size());
    int i = 0;
    for (const auto& token : vocab) {
        encoder[token] = i;
        decoder[i] = token;
        i++;
    }
    encoder_len = i;

    // Merge priority: lower rank = applied earlier during BPE.
    int rank = 0;
    for (const auto& merge : merge_pairs) {
        bpe_ranks[merge] = rank++;
    }
    bpe_len = rank;
}
|
||||||
|
|
||||||
|
// Trims leading and trailing ASCII whitespace. Returns "" when the input is
// empty or all-whitespace.
static std::string strip(const std::string& str) {
    static const char* kWhitespace = " \t\n\r\v\f";
    std::string::size_type start = str.find_first_not_of(kWhitespace);

    if (start == std::string::npos) {
        return "";
    }

    std::string::size_type end = str.find_last_not_of(kWhitespace);
    return str.substr(start, end - start + 1);
}

// Collapses every run of whitespace into a single space, then strips the
// result. Mirrors the whitespace_clean() of the original CLIP tokenizer.
static std::string whitespace_clean(const std::string& text) {
    // The pattern is constant: compile it once instead of on every call
    // (std::regex construction is expensive).
    static const std::regex ws_re(R"(\s+)");
    return strip(std::regex_replace(text, ws_re, " "));
}
|
||||||
|
|
||||||
|
std::string CLIPTokenizer::normalize(const std::string& text) const {
|
||||||
|
auto normalized_text = whitespace_clean(text);
|
||||||
|
std::transform(normalized_text.begin(), normalized_text.end(), normalized_text.begin(), [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
|
||||||
|
return normalized_text;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Splits normalized text into CLIP word units using the pattern of the
// original CLIP BPE tokenizer: contractions ('s, 't, ...), alphabetic runs,
// single digits, and runs of other non-space symbols.
std::vector<std::string> CLIPTokenizer::token_split(const std::string& text) const {
    // The pattern is constant; compile it once instead of on every call
    // (std::regex construction is expensive and this runs per prompt chunk).
    static const std::regex clip_pat(R"('s|'t|'re|'ve|'m|'ll|'d|[[:alpha:]]+|[[:digit:]]|[^[:space:][:alpha:][:digit:]]+)",
                                     std::regex::icase);
    std::sregex_iterator iter(text.begin(), text.end(), clip_pat);
    std::sregex_iterator end;

    std::vector<std::string> result;
    for (; iter != end; ++iter) {
        result.emplace_back(iter->str());
    }

    return result;
}
|
||||||
20
src/tokenizers/clip_tokenizer.h
Normal file
20
src/tokenizers/clip_tokenizer.h
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
#ifndef __SD_TOKENIZERS_CLIP_TOKENIZER_H__
|
||||||
|
#define __SD_TOKENIZERS_CLIP_TOKENIZER_H__
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "bpe_tokenizer.h"
|
||||||
|
|
||||||
|
// BPE tokenizer implementing the original OpenAI CLIP scheme: lowercase +
// whitespace-collapse normalization, CLIP's regex word split, and a
// vocabulary built from the embedded (or supplied) merges list.
class CLIPTokenizer : public BPETokenizer {
protected:
    // Builds encoder/decoder tables and BPE merge ranks from a merges.txt blob.
    void load_from_merges(const std::string& merges_utf8_str);
    // Lowercases and collapses whitespace (CLIP-style normalization).
    std::string normalize(const std::string& text) const override;
    // CLIP regex pre-tokenization into word-level units.
    std::vector<std::string> token_split(const std::string& text) const override;

public:
    // pad_token_id is configurable (some CLIP variants pad with a different
    // id). Empty merges_utf8_str loads the embedded CLIP merges.
    explicit CLIPTokenizer(int pad_token_id = 49407, const std::string& merges_utf8_str = "");
};
|
||||||
|
|
||||||
|
#endif // __SD_TOKENIZERS_CLIP_TOKENIZER_H__
|
||||||
89
src/tokenizers/mistral_tokenizer.cpp
Normal file
89
src/tokenizers/mistral_tokenizer.cpp
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
#include "mistral_tokenizer.h"
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
#include "json.hpp"
|
||||||
|
#include "util.h"
|
||||||
|
#include "vocab/vocab.h"
|
||||||
|
|
||||||
|
// Builds the Mistral vocabulary from a JSON token->id map and BPE merge
// ranks from a merges list. Unlike the CLIP loader, token ids come directly
// from the vocab JSON rather than from insertion order.
void MistralTokenizer::load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) {
    nlohmann::json vocab;

    try {
        vocab = nlohmann::json::parse(vocab_utf8_str);
    } catch (const nlohmann::json::parse_error&) {
        GGML_ABORT("invalid vocab json str");  // hard failure: tokenizer data is required
    }
    for (const auto& [key, value] : vocab.items()) {
        std::u32string token = utf8_to_utf32(key);
        int i = value;
        encoder[token] = i;
        decoder[i] = token;
    }
    encoder_len = static_cast<int>(vocab.size());
    LOG_DEBUG("vocab size: %d", encoder_len);

    // Byte <-> printable-unicode tables (GPT-2 style byte-level BPE).
    auto byte_unicode_pairs = bytes_to_unicode();
    byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
    for (auto& pair : byte_unicode_pairs) {
        byte_decoder[pair.second] = pair.first;
    }
    std::vector<std::u32string> merges = split_utf32(merges_utf8_str);
    LOG_DEBUG("merges size %llu", merges.size());
    std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
    for (const auto& merge : merges) {
        // Each line is "left right"; split on the first space.
        // NOTE(review): assumes every line contains a space and the data has
        // no header line (the CLIP loader skips one) — confirm against the
        // embedded mistral merges data.
        size_t space_pos = merge.find(' ');
        merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
    }

    // Merge priority: lower rank = applied earlier during BPE.
    int rank = 0;
    for (const auto& merge : merge_pairs) {
        bpe_ranks[merge] = rank++;
    }
    bpe_len = rank;
}
|
||||||
|
|
||||||
|
// Constructs a Mistral byte-level BPE tokenizer. The special-token set
// covers chat/tool-use/image markers plus the reserved <SPECIAL_20> ..
// <SPECIAL_999> placeholder tokens. Empty arguments load the embedded
// merges/vocab data.
MistralTokenizer::MistralTokenizer(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) {
    add_bos_token = true;

    UNK_TOKEN = "<unk>";
    BOS_TOKEN = "<s>";
    EOS_TOKEN = "</s>";
    PAD_TOKEN = "<pad>";

    UNK_TOKEN_ID = 0;
    BOS_TOKEN_ID = 1;
    EOS_TOKEN_ID = 2;
    PAD_TOKEN_ID = 11;

    special_tokens = {
        "<unk>",
        "<s>",
        "</s>",
        "[INST]",
        "[/INST]",
        "[AVAILABLE_TOOLS]",
        "[/AVAILABLE_TOOLS]",
        "[TOOL_RESULTS]",
        "[/TOOL_RESULTS]",
        "[TOOL_CALLS]",
        "[IMG]",
        "<pad>",
        "[IMG_BREAK]",
        "[IMG_END]",
        "[PREFIX]",
        "[MIDDLE]",
        "[SUFFIX]",
        "[SYSTEM_PROMPT]",
        "[/SYSTEM_PROMPT]",
        "[TOOL_CONTENT]",
    };
    // Reserved placeholder special tokens <SPECIAL_20> .. <SPECIAL_999>.
    for (int i = 20; i < 1000; i++) {
        special_tokens.push_back("<SPECIAL_" + std::to_string(i) + ">");
    }

    if (merges_utf8_str.size() > 0 && vocab_utf8_str.size() > 0) {
        load_from_merges(merges_utf8_str, vocab_utf8_str);
    } else {
        load_from_merges(load_mistral_merges(), load_mistral_vocab_json());  // embedded defaults
    }
}
|
||||||
16
src/tokenizers/mistral_tokenizer.h
Normal file
16
src/tokenizers/mistral_tokenizer.h
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
#ifndef __SD_TOKENIZERS_MISTRAL_TOKENIZER_H__
|
||||||
|
#define __SD_TOKENIZERS_MISTRAL_TOKENIZER_H__
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "bpe_tokenizer.h"
|
||||||
|
|
||||||
|
// Byte-level BPE tokenizer for Mistral models. The vocabulary comes from a
// JSON token->id map plus a merges list; the special-token set includes
// chat/tool-use/image markers and reserved <SPECIAL_N> placeholders.
class MistralTokenizer : public BPETokenizer {
protected:
    // Builds encoder/decoder tables from the vocab JSON and BPE merge ranks
    // from the merges list.
    void load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str);

public:
    // Empty arguments load the embedded Mistral merges/vocab data.
    explicit MistralTokenizer(const std::string& merges_utf8_str = "", const std::string& vocab_utf8_str = "");
};
|
||||||
|
|
||||||
|
#endif // __SD_TOKENIZERS_MISTRAL_TOKENIZER_H__
|
||||||
91
src/tokenizers/qwen2_tokenizer.cpp
Normal file
91
src/tokenizers/qwen2_tokenizer.cpp
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
#include "qwen2_tokenizer.h"
|
||||||
|
|
||||||
|
#include "util.h"
|
||||||
|
#include "vocab/vocab.h"
|
||||||
|
|
||||||
|
// Builds the Qwen2 vocabulary and BPE merge ranks from a merges blob.
// Token ids are assigned by insertion order: byte units first, then merged
// pairs, then the special tokens — this ordering must not change or ids
// will no longer match the reference Qwen2 vocabulary.
void Qwen2Tokenizer::load_from_merges(const std::string& merges_utf8_str) {
    // Byte <-> printable-unicode tables (GPT-2 style byte-level BPE).
    auto byte_unicode_pairs = bytes_to_unicode();
    byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
    for (auto& pair : byte_unicode_pairs) {
        byte_decoder[pair.second] = pair.first;
    }

    std::vector<std::u32string> merges = split_utf32(merges_utf8_str);
    LOG_DEBUG("merges size %llu", merges.size());
    std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
    for (const auto& merge : merges) {
        // Each line is "left right"; split on the first space.
        // NOTE(review): a line without a space (e.g. a trailing empty line)
        // would produce a degenerate pair here — assumes the merges data
        // contains none; confirm against the embedded qwen2 data.
        size_t space_pos = merge.find(' ');
        merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
    }

    std::vector<std::u32string> tokens;
    for (const auto& pair : byte_unicode_pairs) {
        tokens.push_back(pair.second);
    }
    for (const auto& merge : merge_pairs) {
        tokens.push_back(merge.first + merge.second);
    }
    for (auto& special_token : special_tokens) {
        tokens.push_back(utf8_to_utf32(special_token));
    }

    int i = 0;
    for (const auto& token : tokens) {
        encoder[token] = i;
        decoder[i] = token;
        i++;
    }
    encoder_len = i;
    LOG_DEBUG("vocab size: %d", encoder_len);

    // Merge priority: lower rank = applied earlier during BPE.
    int rank = 0;
    for (const auto& merge : merge_pairs) {
        bpe_ranks[merge] = rank++;
    }
    bpe_len = rank;
}
|
||||||
|
|
||||||
|
// Constructs a Qwen2 byte-level BPE tokenizer. All three special ids share
// the end-of-text token (151643). No BOS token is configured. An empty
// merges_utf8_str loads the embedded Qwen2 merges data.
Qwen2Tokenizer::Qwen2Tokenizer(const std::string& merges_utf8_str) {
    UNK_TOKEN = "<|endoftext|>";
    EOS_TOKEN = "<|endoftext|>";
    PAD_TOKEN = "<|endoftext|>";

    UNK_TOKEN_ID = 151643;
    EOS_TOKEN_ID = 151643;
    PAD_TOKEN_ID = 151643;

    // Order matters: these are appended to the vocabulary after the merge
    // tokens in load_from_merges(), which fixes their ids.
    special_tokens = {
        "<|endoftext|>",
        "<|im_start|>",
        "<|im_end|>",
        "<|object_ref_start|>",
        "<|object_ref_end|>",
        "<|box_start|>",
        "<|box_end|>",
        "<|quad_start|>",
        "<|quad_end|>",
        "<|vision_start|>",
        "<|vision_end|>",
        "<|vision_pad|>",
        "<|image_pad|>",
        "<|video_pad|>",
        "<tool_call>",
        "</tool_call>",
        "<|fim_prefix|>",
        "<|fim_middle|>",
        "<|fim_suffix|>",
        "<|fim_pad|>",
        "<|repo_name|>",
        "<|file_sep|>",
        "<tool_response>",
        "</tool_response>",
        "<think>",
        "</think>",
    };

    if (merges_utf8_str.size() > 0) {
        load_from_merges(merges_utf8_str);
    } else {
        load_from_merges(load_qwen2_merges());  // embedded default merges
    }
}
|
||||||
16
src/tokenizers/qwen2_tokenizer.h
Normal file
16
src/tokenizers/qwen2_tokenizer.h
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
#ifndef __SD_TOKENIZERS_QWEN2_TOKENIZER_H__
|
||||||
|
#define __SD_TOKENIZERS_QWEN2_TOKENIZER_H__
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "bpe_tokenizer.h"
|
||||||
|
|
||||||
|
// Byte-level BPE tokenizer for Qwen2 models. The vocabulary is built from a
// merges list (byte units + merge products + special tokens); ids follow
// insertion order, so the loader's ordering is significant.
class Qwen2Tokenizer : public BPETokenizer {
protected:
    // Builds encoder/decoder tables and BPE merge ranks from a merges blob.
    void load_from_merges(const std::string& merges_utf8_str);

public:
    // Empty merges_utf8_str loads the embedded Qwen2 merges data.
    explicit Qwen2Tokenizer(const std::string& merges_utf8_str = "");
};
|
||||||
|
|
||||||
|
#endif // __SD_TOKENIZERS_QWEN2_TOKENIZER_H__
|
||||||
339
src/tokenizers/t5_unigram_tokenizer.cpp
Normal file
339
src/tokenizers/t5_unigram_tokenizer.cpp
Normal file
@ -0,0 +1,339 @@
|
|||||||
|
#include "t5_unigram_tokenizer.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cfloat>
|
||||||
|
#include <cmath>
|
||||||
|
#include <regex>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "json.hpp"
|
||||||
|
#include "tokenize_util.h"
|
||||||
|
#include "util.h"
|
||||||
|
#include "vocab/vocab.h"
|
||||||
|
|
||||||
|
// Port from: https://github.com/google/sentencepiece/blob/master/src/unigram_model.h
|
||||||
|
// and https://github.com/google/sentencepiece/blob/master/src/unigram_model.h.
|
||||||
|
// Original License: https://github.com/google/sentencepiece/blob/master/LICENSE
|
||||||
|
//
|
||||||
|
// Since tokenization is not the bottleneck in SD, performance was not a major consideration
|
||||||
|
// during the migration.
|
||||||
|
|
||||||
|
// Metaspace pre-tokenizer: rewrites spaces to `replacement` and optionally
// prefixes the first word, following the sentencepiece/HF "metaspace"
// convention.
// Take the string by (non-const) value and move it into the member: the
// declared `const std::string` parameter forced an extra copy. Top-level
// const on a by-value parameter is not part of the signature, so this still
// matches the header declaration.
MetaspacePreTokenizer::MetaspacePreTokenizer(std::string replacement, bool add_prefix_space)
    : replacement(std::move(replacement)), add_prefix_space(add_prefix_space) {}
|
||||||
|
|
||||||
|
std::string MetaspacePreTokenizer::tokenize(const std::string& input) const {
|
||||||
|
std::string tokens;
|
||||||
|
std::stringstream ss(input);
|
||||||
|
|
||||||
|
if (add_prefix_space) {
|
||||||
|
tokens += replacement;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string token;
|
||||||
|
bool first_token = true;
|
||||||
|
while (std::getline(ss, token, ' ')) {
|
||||||
|
if (!first_token) {
|
||||||
|
tokens += replacement + token;
|
||||||
|
} else {
|
||||||
|
tokens += token;
|
||||||
|
}
|
||||||
|
|
||||||
|
first_token = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
void T5UniGramTokenizer::InitializePieces(const std::string& json_str) {
|
||||||
|
nlohmann::json data;
|
||||||
|
|
||||||
|
try {
|
||||||
|
data = nlohmann::json::parse(json_str);
|
||||||
|
} catch (const nlohmann::json::parse_error&) {
|
||||||
|
status_ = INVLIAD_JSON;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!data.contains("model")) {
|
||||||
|
status_ = INVLIAD_JSON;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
nlohmann::json model = data["model"];
|
||||||
|
if (!model.contains("vocab")) {
|
||||||
|
status_ = INVLIAD_JSON;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (model.contains("unk_id")) {
|
||||||
|
UNK_TOKEN_ID = model["unk_id"];
|
||||||
|
}
|
||||||
|
|
||||||
|
replacement = data["pre_tokenizer"]["replacement"];
|
||||||
|
add_prefix_space = data["pre_tokenizer"]["add_prefix_space"];
|
||||||
|
|
||||||
|
pre_tokenizer = MetaspacePreTokenizer(replacement, add_prefix_space);
|
||||||
|
|
||||||
|
for (const auto& item : model["vocab"]) {
|
||||||
|
if (item.size() != 2 || !item[0].is_string() || !item[1].is_number_float()) {
|
||||||
|
status_ = INVLIAD_JSON;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
std::string piece = item[0];
|
||||||
|
if (piece.empty()) {
|
||||||
|
piece = "<empty_token>";
|
||||||
|
}
|
||||||
|
float score = item[1];
|
||||||
|
piece_score_pairs.emplace_back(piece, score);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Builds the Darts double-array trie over all (piece, id) pairs and records
// the maximum number of common-prefix matches any piece produces
// (trie_results_size_). Pieces are sorted first, as required by
// Darts::DoubleArray::build. Failures are reported through status_.
void T5UniGramTokenizer::BuildTrie(std::vector<std::pair<std::string, int>>* pieces) {
    if (status_ != OK) {
        return;  // an earlier stage already failed
    }

    if (pieces->empty()) {
        status_ = NO_PIECES_LOADED;
        return;
    }

    // Darts requires keys in sorted order.
    std::sort(pieces->begin(), pieces->end());

    // Parallel key/value arrays for the Darts build API. The const char*
    // keys point into the (now sorted, stable) strings in *pieces.
    std::vector<const char*> key(pieces->size());
    std::vector<int> value(pieces->size());
    for (size_t i = 0; i < pieces->size(); ++i) {
        key[i] = (*pieces)[i].first.data();
        value[i] = (*pieces)[i].second;
    }

    trie_ = std::unique_ptr<Darts::DoubleArray>(new Darts::DoubleArray());
    if (trie_->build(key.size(), const_cast<char**>(&key[0]), nullptr, &value[0]) != 0) {
        status_ = BUILD_DOUBLE_ARRAY_FAILED;
        return;
    }

    // Probe every piece to find the largest common-prefix result set.
    const int kMaxTrieResultsSize = 1024;
    std::vector<Darts::DoubleArray::result_pair_type> results(kMaxTrieResultsSize);
    trie_results_size_ = 0;
    for (const auto& p : *pieces) {
        const size_t num_nodes = trie_->commonPrefixSearch(
            p.first.data(), results.data(), results.size(), p.first.size());
        trie_results_size_ = std::max(trie_results_size_, static_cast<int>(num_nodes));
    }

    if (trie_results_size_ == 0) {
        status_ = NO_ENTRY_FOUND;
    }
}
|
||||||
|
|
||||||
|
// Score (log probability) of a vocabulary piece, indexed by token id.
float T5UniGramTokenizer::GetScoreInlined(int id) const {
    return piece_score_pairs[id].second;
}

// Whether a piece is "unused" in the sentencepiece sense. This port loads no
// unused pieces, so this is always false.
bool T5UniGramTokenizer::IsUnusedInlined(int id) const {
    (void)id;
    return false;
}

// Whether a piece is a user-defined symbol. Always false in this port.
bool T5UniGramTokenizer::IsUserDefinedInlined(int id) const {
    (void)id;
    return false;
}

// Byte length of the UTF-8 sequence starting at *src, looked up from the
// high nibble of the lead byte (continuation/invalid lead bytes map to 1).
size_t T5UniGramTokenizer::OneCharLen(const char* src) const {
    return "\1\1\1\1\1\1\1\1\1\1\1\1\2\2\3\4"[(*src & 0xFF) >> 4];
}
|
||||||
|
|
||||||
|
// Viterbi-style unigram segmentation, ported from sentencepiece's
// Model::EncodeOptimized. Builds a lattice over byte positions: for every
// start position, all dictionary pieces found by trie traversal propose an
// edge; if no piece covers exactly one UTF-8 character, a single-character
// UNK edge with a penalized score is added so the lattice always reaches
// the end. The best-scoring path is then read back from the end.
EncodeResult T5UniGramTokenizer::EncodeOptimized(const std::string& normalized) const {
    if (status() != OK || normalized.empty()) {
        return {};
    }

    // Best incoming edge per byte position (node i = "best path ending at i").
    struct BestPathNode {
        int id = -1;               // token id of the edge ending here
        float best_path_score = 0; // accumulated score of the best path
        int starts_at = -1;        // where that edge begins; -1 = unreached
    };

    const int size = static_cast<int>(normalized.size());
    const float unk_score = min_score() - kUnkPenalty;  // worse than any real piece
    std::vector<BestPathNode> best_path_ends_at(size + 1);

    int starts_at = 0;
    while (starts_at < size) {
        std::size_t node_pos = 0;
        std::size_t key_pos = starts_at;
        const auto best_path_score_till_here = best_path_ends_at[starts_at].best_path_score;
        bool has_single_node = false;
        // Length in bytes of the single UTF-8 character at starts_at (clamped
        // to the remaining input).
        const int mblen = std::min<int>(static_cast<int>(OneCharLen(normalized.data() + starts_at)), size - starts_at);
        // Incrementally extend the trie match one byte at a time; each hit is
        // a dictionary piece spanning [starts_at, key_pos).
        while (key_pos < static_cast<size_t>(size)) {
            const int ret = trie_->traverse(normalized.data(), node_pos, key_pos, key_pos + 1);
            if (ret == -2) {
                break;  // no piece continues with this byte
            }
            if (ret >= 0) {
                if (IsUnusedInlined(ret)) {
                    continue;
                }
                auto& target_node = best_path_ends_at[key_pos];
                const auto length = static_cast<int>(key_pos - starts_at);
                // User-defined pieces get a length-proportional score in
                // sentencepiece; none exist in this port.
                const auto score = IsUserDefinedInlined(ret) ? (length * max_score_ - 0.1f) : GetScoreInlined(ret);
                const auto candidate_best_path_score = score + best_path_score_till_here;
                // Relax the edge if this position is unreached or improved.
                if (target_node.starts_at == -1 || candidate_best_path_score > target_node.best_path_score) {
                    target_node.best_path_score = static_cast<float>(candidate_best_path_score);
                    target_node.starts_at = starts_at;
                    target_node.id = ret;
                }
                if (!has_single_node && length == mblen) {
                    has_single_node = true;  // a piece covers exactly one character
                }
            }
        }
        // No single-character piece: add a UNK edge so the lattice stays
        // connected across this character.
        if (!has_single_node) {
            auto& target_node = best_path_ends_at[starts_at + mblen];
            const auto candidate_best_path_score = unk_score + best_path_score_till_here;
            if (target_node.starts_at == -1 || candidate_best_path_score > target_node.best_path_score) {
                target_node.best_path_score = candidate_best_path_score;
                target_node.starts_at = starts_at;
                target_node.id = UNK_TOKEN_ID;
            }
        }
        starts_at += mblen;  // advance one UTF-8 character
    }

    // Backtrack from the end and reverse to get (surface, id) in order.
    EncodeResult results;
    int ends_at = size;
    while (ends_at > 0) {
        const auto& node = best_path_ends_at[ends_at];
        results.emplace_back(normalized.substr(node.starts_at, ends_at - node.starts_at), node.id);
        ends_at = node.starts_at;
    }
    std::reverse(results.begin(), results.end());
    return results;
}
|
||||||
|
|
||||||
|
// Constructs a T5 (or UMT5) unigram tokenizer: configures special tokens,
// loads the embedded tokenizer.json, computes min/max piece scores, and
// builds the piece trie. UMT5 additionally defines a BOS token <s> (id 2).
T5UniGramTokenizer::T5UniGramTokenizer(bool is_umt5) {
    add_bos_token = false;
    add_eos_token = true;

    if (is_umt5) {
        PAD_TOKEN_ID = 0;
        EOS_TOKEN_ID = 1;
        BOS_TOKEN_ID = 2;
        UNK_TOKEN_ID = 3;

        PAD_TOKEN = "<pad>";
        EOS_TOKEN = "</s>";
        BOS_TOKEN = "<s>";
        UNK_TOKEN = "<unk>";
    } else {
        PAD_TOKEN_ID = 0;
        EOS_TOKEN_ID = 1;
        UNK_TOKEN_ID = 2;

        PAD_TOKEN = "<pad>";
        EOS_TOKEN = "</s>";
        UNK_TOKEN = "<unk>";
    }

    special_tokens = {
        "<pad>",
        "</s>",
        "<unk>",
    };

    if (is_umt5) {
        special_tokens.push_back("<s>");
    }

    // May set status_ != OK; BuildTrie() below checks it before doing work.
    if (is_umt5) {
        InitializePieces(load_umt5_tokenizer_json());
    } else {
        InitializePieces(load_t5_tokenizer_json());
    }

    min_score_ = FLT_MAX;
    // NOTE(review): FLT_MIN is the smallest *positive* float, not the most
    // negative value; with negative log-prob scores max_score_ stays at
    // ~1e-38. This matches the upstream sentencepiece port, and max_score_
    // only feeds the user-defined-piece branch (never taken here), but
    // std::numeric_limits<float>::lowest() would be the conventional seed —
    // confirm before changing.
    max_score_ = FLT_MIN;

    std::vector<std::pair<std::string, int>> pieces;
    for (int i = 0; i < static_cast<int>(piece_score_pairs.size()); i++) {
        const auto& sp = piece_score_pairs[i];

        min_score_ = std::min(min_score_, sp.second);
        max_score_ = std::max(max_score_, sp.second);

        pieces.emplace_back(sp.first, i);  // index doubles as the token id
    }

    BuildTrie(&pieces);
}
|
||||||
|
|
||||||
|
// Defaulted out of line; owned members (trie_, piece tables) clean up via RAII.
T5UniGramTokenizer::~T5UniGramTokenizer() = default;
|
||||||
|
|
||||||
|
std::string T5UniGramTokenizer::decode_token(int token_id) const {
|
||||||
|
if (token_id < 0 || token_id >= static_cast<int>(piece_score_pairs.size())) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string& piece = piece_score_pairs[token_id].first;
|
||||||
|
if (piece == "<empty_token>") {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
return piece;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collapses runs of two or more spaces into a single space.
// Ref: https://github.com/huggingface/tokenizers/blob/1ff56c0c70b045f0cd82da1af9ac08cd4c7a6f9f/bindings/python/py_src/tokenizers/implementations/sentencepiece_unigram.py#L29
// TODO: nmt-nfkc
std::string T5UniGramTokenizer::normalize(const std::string& input) const {
    // The pattern is constant; compile it once instead of on every call
    // (std::regex construction is expensive and normalize runs per prompt).
    static const std::regex multi_space_re(" {2,}");
    std::string normalized = std::regex_replace(input, multi_space_re, " ");
    return normalized;
}
|
||||||
|
|
||||||
|
// Encodes text to token ids: normalize, split around special tokens, then
// metaspace-pretokenize and unigram-encode each plain-text segment.
// on_new_token_cb (if set) is offered each special token first and may
// consume it (returning true skips the built-in handling).
std::vector<int> T5UniGramTokenizer::encode(const std::string& input, on_new_token_cb_t on_new_token_cb) {
    std::vector<int32_t> tokens;
    std::vector<std::string> token_strs;  // surface forms, for debug logging only
    std::string normalized = normalize(input);
    auto splited_texts = split_with_special_tokens(normalized, special_tokens);
    if (splited_texts.empty()) {
        splited_texts.push_back(normalized);  // for empty string
    }

    for (auto& splited_text : splited_texts) {
        if (is_special_token(splited_text)) {
            // Give the callback first refusal on the special token.
            if (on_new_token_cb != nullptr) {
                bool skip = on_new_token_cb(splited_text, tokens);
                if (skip) {
                    token_strs.push_back(splited_text);
                    continue;
                }
            }

            // NOTE(review): BOS_TOKEN is not checked here, so a literal "<s>"
            // in UMT5 input is silently dropped unless the callback consumes
            // it — confirm this is intended.
            if (splited_text == UNK_TOKEN) {
                tokens.push_back(UNK_TOKEN_ID);
                token_strs.push_back(UNK_TOKEN);
            } else if (splited_text == EOS_TOKEN) {
                tokens.push_back(EOS_TOKEN_ID);
                token_strs.push_back(EOS_TOKEN);
            } else if (splited_text == PAD_TOKEN) {
                tokens.push_back(PAD_TOKEN_ID);
                token_strs.push_back(PAD_TOKEN);
            }
            continue;
        }

        // Plain text: metaspace pretokenization, then Viterbi unigram encode.
        std::string pretokenized = pre_tokenizer.tokenize(splited_text);
        EncodeResult result = EncodeOptimized(pretokenized);
        for (const auto& item : result) {
            tokens.push_back(item.second);
            token_strs.push_back(item.first);
        }
    }

    // Debug dump of the chosen segmentation.
    std::stringstream ss;
    ss << "[";
    for (const auto& token_str : token_strs) {
        ss << "\"" << token_str << "\", ";
    }
    ss << "]";
    LOG_DEBUG("split prompt \"%s\" to tokens %s", input.c_str(), ss.str().c_str());

    return tokens;
}
|
||||||
70
src/tokenizers/t5_unigram_tokenizer.h
Normal file
70
src/tokenizers/t5_unigram_tokenizer.h
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
#ifndef __SD_TOKENIZERS_T5_UNIGRAM_TOKENIZER_H__
|
||||||
|
#define __SD_TOKENIZERS_T5_UNIGRAM_TOKENIZER_H__
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "darts.h"
|
||||||
|
#include "tokenizer.h"
|
||||||
|
|
||||||
|
// Pre-tokenizer that rewrites spaces to a replacement marker before unigram
// encoding (sentencepiece/HF "metaspace" convention).
class MetaspacePreTokenizer {
private:
    std::string replacement;  // marker substituted for ' '
    bool add_prefix_space;    // whether to prepend the marker to the input

public:
    MetaspacePreTokenizer(const std::string replacement = " ", bool add_prefix_space = true);

    // Returns the input with spaces rewritten to the replacement marker.
    std::string tokenize(const std::string& input) const;
};
|
||||||
|
|
||||||
|
using EncodeResult = std::vector<std::pair<std::string, int>>;
|
||||||
|
|
||||||
|
// Unigram-LM tokenizer for T5/UMT5, ported from google/sentencepiece (see
// t5_unigram_tokenizer.cpp for provenance). Pieces and scores come from an
// embedded tokenizer.json; lookup uses a Darts double-array trie.
class T5UniGramTokenizer : public Tokenizer {
public:
    // Load/build state; any value other than OK disables encoding.
    enum Status {
        OK,
        NO_PIECES_LOADED,
        NO_ENTRY_FOUND,
        BUILD_DOUBLE_ARRAY_FAILED,
        PIECE_ALREADY_DEFINED,
        INVLIAD_JSON  // [sic] — spelling kept for source compatibility
    };

protected:
    MetaspacePreTokenizer pre_tokenizer;
    // Indexed by token id: (piece surface form, unigram score).
    std::vector<std::pair<std::string, float>> piece_score_pairs;
    float min_score_ = 0.0f;
    float max_score_ = 0.0f;
    std::unique_ptr<Darts::DoubleArray> trie_;  // prefix trie over all pieces
    int trie_results_size_ = 0;  // max common-prefix matches observed at build
    Status status_ = OK;
    float kUnkPenalty = 10.0f;  // subtracted from min_score() for UNK edges
    std::string replacement;    // metaspace marker from tokenizer.json
    bool add_prefix_space = true;

    // Parse tokenizer.json into piece_score_pairs / pre-tokenizer settings.
    void InitializePieces(const std::string& json_str);
    // Build the Darts trie from sorted (piece, id) pairs.
    void BuildTrie(std::vector<std::pair<std::string, int>>* pieces);
    float GetScoreInlined(int id) const;
    bool IsUnusedInlined(int id) const;
    bool IsUserDefinedInlined(int id) const;
    // Byte length of the UTF-8 character starting at *src.
    size_t OneCharLen(const char* src) const;
    // Viterbi segmentation of a normalized string into (surface, id) pairs.
    EncodeResult EncodeOptimized(const std::string& normalized) const;

    float min_score() const { return min_score_; }
    float max_score() const { return max_score_; }
    Status status() const { return status_; }
    std::string decode_token(int token_id) const override;
    std::string normalize(const std::string& input) const override;

public:
    // is_umt5 switches special-token ids and the embedded tokenizer.json.
    explicit T5UniGramTokenizer(bool is_umt5 = false);
    ~T5UniGramTokenizer();

    std::vector<int> encode(const std::string& input, on_new_token_cb_t on_new_token_cb = nullptr) override;
};
|
||||||
|
|
||||||
|
#endif // __SD_TOKENIZERS_T5_UNIGRAM_TOKENIZER_H__
|
||||||
@ -1,5 +1,5 @@
|
|||||||
#ifndef __TOKENIZE_UTIL__
|
#ifndef __SD_TOKENIZERS_BPE_TOKENIZE_UTIL_H__
|
||||||
#define __TOKENIZE_UTIL__
|
#define __SD_TOKENIZERS_BPE_TOKENIZE_UTIL_H__
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@ -7,4 +7,4 @@
|
|||||||
std::vector<std::string> token_split(const std::string& text);
|
std::vector<std::string> token_split(const std::string& text);
|
||||||
std::vector<std::string> split_with_special_tokens(const std::string& text, const std::vector<std::string>& special_tokens);
|
std::vector<std::string> split_with_special_tokens(const std::string& text, const std::vector<std::string>& special_tokens);
|
||||||
|
|
||||||
#endif // __TOKENIZE_UTIL__
|
#endif // __SD_TOKENIZERS_BPE_TOKENIZE_UTIL_H__
|
||||||
211
src/tokenizers/tokenizer.cpp
Normal file
211
src/tokenizers/tokenizer.cpp
Normal file
@ -0,0 +1,211 @@
|
|||||||
|
#include "tokenizer.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <regex>
|
||||||
|
|
||||||
|
#include "util.h"
|
||||||
|
|
||||||
|
// Registers an additional special token; encode() implementations treat
// registered tokens atomically instead of running them through BPE/unigram.
void Tokenizer::add_special_token(const std::string& token) {
    special_tokens.emplace_back(token);
}
|
||||||
|
|
||||||
|
bool Tokenizer::is_special_token(const std::string& token) const {
|
||||||
|
for (const auto& special_token : special_tokens) {
|
||||||
|
if (special_token == token) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default normalization is the identity; subclasses override it to apply
// scheme-specific normalization (e.g. CLIP lowercases and collapses
// whitespace).
std::string Tokenizer::normalize(const std::string& text) const {
    return text;
}
|
||||||
|
|
||||||
|
// Convenience wrapper: encode the text, then optionally pad/chunk the result
// via pad_tokens() (without weight or mask outputs). min_length/max_length
// and allow_overflow_expand are forwarded to pad_tokens unchanged.
std::vector<int> Tokenizer::tokenize(const std::string& text,
                                     on_new_token_cb_t on_new_token_cb,
                                     bool padding,
                                     size_t min_length,
                                     size_t max_length,
                                     bool allow_overflow_expand) {
    std::vector<int> tokens = encode(text, on_new_token_cb);
    if (padding) {
        pad_tokens(tokens, nullptr, nullptr, min_length, max_length, allow_overflow_expand);
    }
    return tokens;
}
|
||||||
|
|
||||||
|
void Tokenizer::pad_tokens(std::vector<int>& tokens,
|
||||||
|
std::vector<float>* weights,
|
||||||
|
std::vector<float>* mask,
|
||||||
|
size_t min_length,
|
||||||
|
size_t max_length,
|
||||||
|
bool allow_overflow_expand) {
|
||||||
|
const bool use_weights = weights != nullptr;
|
||||||
|
const bool use_mask = mask != nullptr;
|
||||||
|
|
||||||
|
if (use_weights && tokens.size() != weights->size()) {
|
||||||
|
LOG_ERROR("tokens size != weights size");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const size_t bos_count = add_bos_token ? 1 : 0;
|
||||||
|
const size_t eos_count = add_eos_token ? 1 : 0;
|
||||||
|
const size_t special_token_count = bos_count + eos_count;
|
||||||
|
|
||||||
|
auto build_sequence = [&](size_t begin,
|
||||||
|
size_t count,
|
||||||
|
size_t target_length,
|
||||||
|
std::vector<int>& out_tokens,
|
||||||
|
std::vector<float>& out_weights,
|
||||||
|
std::vector<float>& out_mask) {
|
||||||
|
const size_t base_length = count + special_token_count;
|
||||||
|
const size_t final_length = std::max(target_length, base_length);
|
||||||
|
|
||||||
|
out_tokens.clear();
|
||||||
|
out_weights.clear();
|
||||||
|
out_mask.clear();
|
||||||
|
|
||||||
|
out_tokens.reserve(final_length);
|
||||||
|
if (use_weights) {
|
||||||
|
out_weights.reserve(final_length);
|
||||||
|
}
|
||||||
|
if (use_mask) {
|
||||||
|
out_mask.reserve(final_length);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (add_bos_token) {
|
||||||
|
out_tokens.push_back(BOS_TOKEN_ID);
|
||||||
|
if (use_weights) {
|
||||||
|
out_weights.push_back(1.0f);
|
||||||
|
}
|
||||||
|
if (use_mask) {
|
||||||
|
out_mask.push_back(1.0f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < count; ++i) {
|
||||||
|
out_tokens.push_back(tokens[begin + i]);
|
||||||
|
if (use_weights) {
|
||||||
|
out_weights.push_back((*weights)[begin + i]);
|
||||||
|
}
|
||||||
|
if (use_mask) {
|
||||||
|
out_mask.push_back(1.0f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (add_eos_token) {
|
||||||
|
out_tokens.push_back(EOS_TOKEN_ID);
|
||||||
|
if (use_weights) {
|
||||||
|
out_weights.push_back(1.0f);
|
||||||
|
}
|
||||||
|
if (use_mask) {
|
||||||
|
out_mask.push_back(1.0f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (final_length > out_tokens.size()) {
|
||||||
|
const size_t pad_count = final_length - out_tokens.size();
|
||||||
|
out_tokens.insert(out_tokens.end(), pad_count, PAD_TOKEN_ID);
|
||||||
|
|
||||||
|
if (use_weights) {
|
||||||
|
out_weights.insert(out_weights.end(), pad_count, 1.0f);
|
||||||
|
}
|
||||||
|
if (use_mask) {
|
||||||
|
out_mask.insert(out_mask.end(), pad_count, 0.0f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const size_t single_length = std::max(min_length, tokens.size() + special_token_count);
|
||||||
|
const bool exceeds_max_length = max_length > 0 && single_length > max_length;
|
||||||
|
|
||||||
|
std::vector<int> new_tokens;
|
||||||
|
std::vector<float> new_weights;
|
||||||
|
std::vector<float> new_mask;
|
||||||
|
|
||||||
|
if (!exceeds_max_length) {
|
||||||
|
build_sequence(0, tokens.size(), min_length, new_tokens, new_weights, new_mask);
|
||||||
|
} else if (!allow_overflow_expand) {
|
||||||
|
build_sequence(0, tokens.size(), 0, new_tokens, new_weights, new_mask);
|
||||||
|
|
||||||
|
new_tokens.resize(max_length);
|
||||||
|
if (use_weights) {
|
||||||
|
new_weights.resize(max_length);
|
||||||
|
}
|
||||||
|
if (use_mask) {
|
||||||
|
new_mask.resize(max_length);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (add_eos_token && !new_tokens.empty()) {
|
||||||
|
new_tokens.back() = EOS_TOKEN_ID;
|
||||||
|
if (use_weights) {
|
||||||
|
new_weights.back() = 1.0f;
|
||||||
|
}
|
||||||
|
if (use_mask) {
|
||||||
|
new_mask.back() = 1.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (min_length > special_token_count) {
|
||||||
|
const size_t tokens_per_chunk = min_length - special_token_count;
|
||||||
|
size_t offset = 0;
|
||||||
|
|
||||||
|
while (offset < tokens.size()) {
|
||||||
|
const size_t remaining = tokens.size() - offset;
|
||||||
|
const size_t take = std::min(tokens_per_chunk, remaining);
|
||||||
|
|
||||||
|
std::vector<int> chunk_tokens;
|
||||||
|
std::vector<float> chunk_weights;
|
||||||
|
std::vector<float> chunk_mask;
|
||||||
|
|
||||||
|
build_sequence(offset, take, min_length, chunk_tokens, chunk_weights, chunk_mask);
|
||||||
|
|
||||||
|
new_tokens.insert(new_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
|
||||||
|
if (use_weights) {
|
||||||
|
new_weights.insert(new_weights.end(), chunk_weights.begin(), chunk_weights.end());
|
||||||
|
}
|
||||||
|
if (use_mask) {
|
||||||
|
new_mask.insert(new_mask.end(), chunk_mask.begin(), chunk_mask.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
offset += take;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
build_sequence(0, tokens.size(), min_length, new_tokens, new_weights, new_mask);
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens = std::move(new_tokens);
|
||||||
|
if (use_weights) {
|
||||||
|
*weights = std::move(new_weights);
|
||||||
|
}
|
||||||
|
if (use_mask) {
|
||||||
|
*mask = std::move(new_mask);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Post-process decoded text: collapse the space the tokenizer emits before
// a comma (" ," -> ","), mirroring HuggingFace's clean_up_tokenization step.
//
// Fixes vs. original: the parameter is now `const&` (the function never
// mutates its argument, and const& also accepts temporaries), and the
// constant pattern is compiled once instead of on every call.
static std::string clean_up_tokenization(const std::string& text) {
    static const std::regex pattern(R"( ,)");
    return std::regex_replace(text, pattern, ",");
}
|
||||||
|
|
||||||
|
std::string Tokenizer::decode(const std::vector<int>& tokens) const {
|
||||||
|
std::string text;
|
||||||
|
|
||||||
|
for (int token_id : tokens) {
|
||||||
|
if (token_id == BOS_TOKEN_ID || token_id == EOS_TOKEN_ID || token_id == PAD_TOKEN_ID) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string piece = decode_token(token_id);
|
||||||
|
if (!end_of_word_suffix.empty() && ends_with(piece, end_of_word_suffix)) {
|
||||||
|
piece.erase(piece.size() - end_of_word_suffix.size());
|
||||||
|
text += piece + " ";
|
||||||
|
} else {
|
||||||
|
text += piece;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
text = clean_up_tokenization(text);
|
||||||
|
return trim(text);
|
||||||
|
}
|
||||||
52
src/tokenizers/tokenizer.h
Normal file
52
src/tokenizers/tokenizer.h
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
#ifndef __SD_TOKENIZERS_TOKENIZER_H__
#define __SD_TOKENIZERS_TOKENIZER_H__

#include <cstddef>
#include <cstdint>
#include <functional>
#include <string>
#include <vector>

// Callback invoked while scanning input text, letting the caller intercept a
// token and supply replacement ids (e.g. for embeddings / custom tokens).
// NOTE(review): exact contract of the bool return and the mutable parameters
// is defined by the encode() implementations — confirm against callers.
using on_new_token_cb_t = std::function<bool(std::string&, std::vector<int32_t>&)>;

// Abstract base class shared by the concrete tokenizer implementations
// (see src/tokenizers/*.cpp). Subclasses provide the vocabulary lookups
// (decode_token, encode); this base supplies special-token bookkeeping,
// padding/chunking (pad_tokens) and id-to-text decoding (decode).
class Tokenizer {
protected:
    // Registered special tokens; managed via add_special_token()/is_special_token().
    std::vector<std::string> special_tokens;
    // Whether pad_tokens() prepends BOS_TOKEN_ID / appends EOS_TOKEN_ID.
    bool add_bos_token = false;
    bool add_eos_token = false;
    // Suffix marking end-of-word in vocab pieces (e.g. "</w>"); when set,
    // decode() strips it and inserts a space. Empty means "no suffix".
    std::string end_of_word_suffix;

    // Map one token id back to its string piece (vocabulary lookup).
    virtual std::string decode_token(int token_id) const = 0;
    // Normalize raw text before tokenization (implementation-specific).
    virtual std::string normalize(const std::string& text) const;

public:
    // Special-token strings and their ids. Defaults of 0 are placeholders;
    // concrete tokenizers are expected to set the real values.
    std::string UNK_TOKEN;
    std::string BOS_TOKEN;
    std::string EOS_TOKEN;
    std::string PAD_TOKEN;
    int UNK_TOKEN_ID = 0;
    int BOS_TOKEN_ID = 0;
    int EOS_TOKEN_ID = 0;
    int PAD_TOKEN_ID = 0;

    virtual ~Tokenizer() = default;

    // Register a special token / query whether a string is one.
    void add_special_token(const std::string& token);
    bool is_special_token(const std::string& token) const;

    // Convert text to raw token ids (no BOS/EOS/padding applied).
    virtual std::vector<int> encode(const std::string& text, on_new_token_cb_t on_new_token_cb = nullptr) = 0;

    // Full pipeline: encode, then (when padding is requested) apply the same
    // BOS/EOS/pad/length handling as pad_tokens().
    std::vector<int> tokenize(const std::string& text,
                              on_new_token_cb_t on_new_token_cb = nullptr,
                              bool padding = false,
                              size_t min_length = 0,
                              size_t max_length = 100000000,
                              bool allow_overflow_expand = false);

    // Pad/truncate `tokens` in place: optionally adds BOS/EOS (per
    // add_bos_token/add_eos_token), pads with PAD_TOKEN_ID up to min_length,
    // and enforces max_length. If the sequence exceeds max_length:
    // truncates (keeping a trailing EOS) when allow_overflow_expand is false,
    // otherwise splits the input into min_length-sized chunks, each with its
    // own BOS/EOS/padding. `weights` and `mask` may be nullptr; when given,
    // weights must match tokens in size (1.0 for inserted special tokens),
    // and mask is 1.0 for real/special tokens, 0.0 for padding.
    void pad_tokens(std::vector<int>& tokens,
                    std::vector<float>* weights,
                    std::vector<float>* mask,
                    size_t min_length = 0,
                    size_t max_length = 100000000,
                    bool allow_overflow_expand = false);

    // Convert ids back to text: skips BOS/EOS/PAD, strips end_of_word_suffix
    // (replacing it with a space), cleans up spacing, and trims the result.
    std::string decode(const std::vector<int>& tokens) const;
};

#endif  // __SD_TOKENIZERS_TOKENIZER_H__
|
||||||
@ -1,5 +1,5 @@
|
|||||||
#ifndef __VOCAB_H__
|
#ifndef __SD_TOKENIZERS_VOCAB_VOCAB_H__
|
||||||
#define __VOCAB_H__
|
#define __SD_TOKENIZERS_VOCAB_VOCAB_H__
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
@ -10,4 +10,4 @@ std::string load_mistral_vocab_json();
|
|||||||
std::string load_t5_tokenizer_json();
|
std::string load_t5_tokenizer_json();
|
||||||
std::string load_umt5_tokenizer_json();
|
std::string load_umt5_tokenizer_json();
|
||||||
|
|
||||||
#endif // __VOCAB_H__
|
#endif // __SD_TOKENIZERS_VOCAB_VOCAB_H__
|
||||||
Loading…
x
Reference in New Issue
Block a user