mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-24 23:26:43 +00:00
340 lines
11 KiB
C++
340 lines
11 KiB
C++
#include "t5_unigram_tokenizer.h"
|
|
|
|
#include <algorithm>
|
|
#include <cfloat>
|
|
#include <cmath>
|
|
#include <regex>
|
|
#include <sstream>
|
|
|
|
#include "core/util.h"
|
|
#include "json.hpp"
|
|
#include "tokenize_util.h"
|
|
#include "vocab/vocab.h"
|
|
|
|
// Port from: https://github.com/google/sentencepiece/blob/master/src/unigram_model.h
|
|
// and https://github.com/google/sentencepiece/blob/master/src/unigram_model.cc.
|
|
// Original License: https://github.com/google/sentencepiece/blob/master/LICENSE
|
|
//
|
|
// Since tokenization is not the bottleneck in SD, performance was not a major consideration
|
|
// during the migration.
|
|
|
|
MetaspacePreTokenizer::MetaspacePreTokenizer(const std::string replacement, bool add_prefix_space)
|
|
: replacement(replacement), add_prefix_space(add_prefix_space) {}
|
|
|
|
std::string MetaspacePreTokenizer::tokenize(const std::string& input) const {
|
|
std::string tokens;
|
|
std::stringstream ss(input);
|
|
|
|
if (add_prefix_space) {
|
|
tokens += replacement;
|
|
}
|
|
|
|
std::string token;
|
|
bool first_token = true;
|
|
while (std::getline(ss, token, ' ')) {
|
|
if (!first_token) {
|
|
tokens += replacement + token;
|
|
} else {
|
|
tokens += token;
|
|
}
|
|
|
|
first_token = false;
|
|
}
|
|
|
|
return tokens;
|
|
}
|
|
|
|
void T5UniGramTokenizer::InitializePieces(const std::string& json_str) {
|
|
nlohmann::json data;
|
|
|
|
try {
|
|
data = nlohmann::json::parse(json_str);
|
|
} catch (const nlohmann::json::parse_error&) {
|
|
status_ = INVLIAD_JSON;
|
|
return;
|
|
}
|
|
if (!data.contains("model")) {
|
|
status_ = INVLIAD_JSON;
|
|
return;
|
|
}
|
|
nlohmann::json model = data["model"];
|
|
if (!model.contains("vocab")) {
|
|
status_ = INVLIAD_JSON;
|
|
return;
|
|
}
|
|
if (model.contains("unk_id")) {
|
|
UNK_TOKEN_ID = model["unk_id"];
|
|
}
|
|
|
|
replacement = data["pre_tokenizer"]["replacement"];
|
|
add_prefix_space = data["pre_tokenizer"]["add_prefix_space"];
|
|
|
|
pre_tokenizer = MetaspacePreTokenizer(replacement, add_prefix_space);
|
|
|
|
for (const auto& item : model["vocab"]) {
|
|
if (item.size() != 2 || !item[0].is_string() || !item[1].is_number_float()) {
|
|
status_ = INVLIAD_JSON;
|
|
return;
|
|
}
|
|
std::string piece = item[0];
|
|
if (piece.empty()) {
|
|
piece = "<empty_token>";
|
|
}
|
|
float score = item[1];
|
|
piece_score_pairs.emplace_back(piece, score);
|
|
}
|
|
}
|
|
|
|
void T5UniGramTokenizer::BuildTrie(std::vector<std::pair<std::string, int>>* pieces) {
|
|
if (status_ != OK) {
|
|
return;
|
|
}
|
|
|
|
if (pieces->empty()) {
|
|
status_ = NO_PIECES_LOADED;
|
|
return;
|
|
}
|
|
|
|
std::sort(pieces->begin(), pieces->end());
|
|
|
|
std::vector<const char*> key(pieces->size());
|
|
std::vector<int> value(pieces->size());
|
|
for (size_t i = 0; i < pieces->size(); ++i) {
|
|
key[i] = (*pieces)[i].first.data();
|
|
value[i] = (*pieces)[i].second;
|
|
}
|
|
|
|
trie_ = std::unique_ptr<Darts::DoubleArray>(new Darts::DoubleArray());
|
|
if (trie_->build(key.size(), const_cast<char**>(&key[0]), nullptr, &value[0]) != 0) {
|
|
status_ = BUILD_DOUBLE_ARRAY_FAILED;
|
|
return;
|
|
}
|
|
|
|
const int kMaxTrieResultsSize = 1024;
|
|
std::vector<Darts::DoubleArray::result_pair_type> results(kMaxTrieResultsSize);
|
|
trie_results_size_ = 0;
|
|
for (const auto& p : *pieces) {
|
|
const size_t num_nodes = trie_->commonPrefixSearch(
|
|
p.first.data(), results.data(), results.size(), p.first.size());
|
|
trie_results_size_ = std::max(trie_results_size_, static_cast<int>(num_nodes));
|
|
}
|
|
|
|
if (trie_results_size_ == 0) {
|
|
status_ = NO_ENTRY_FOUND;
|
|
}
|
|
}
|
|
|
|
float T5UniGramTokenizer::GetScoreInlined(int id) const {
|
|
return piece_score_pairs[id].second;
|
|
}
|
|
|
|
// Piece-type hook kept from the sentencepiece port. This tokenizer has no
// "unused" pieces, so every id is reported as usable.
bool T5UniGramTokenizer::IsUnusedInlined(int /*id*/) const {
    return false;
}
|
|
|
|
// Piece-type hook kept from the sentencepiece port. No piece is treated as
// user-defined, so EncodeOptimized always scores pieces via GetScoreInlined().
bool T5UniGramTokenizer::IsUserDefinedInlined(int /*id*/) const {
    return false;
}
|
|
|
|
// Byte length of the UTF-8 sequence whose lead byte is *src, decided by the
// high nibble of that byte:
//   0x0-0xB -> 1 (ASCII; stray continuation bytes are also treated as 1)
//   0xC-0xD -> 2, 0xE -> 3, 0xF -> 4
// The result is not clamped to the remaining input; callers do that.
size_t T5UniGramTokenizer::OneCharLen(const char* src) const {
    const unsigned int high_nibble = static_cast<unsigned char>(*src) >> 4;
    if (high_nibble < 0xC) {
        return 1;
    }
    if (high_nibble < 0xE) {
        return 2;
    }
    return high_nibble == 0xE ? 3 : 4;
}
|
|
|
|
// Viterbi segmentation of `normalized` into vocabulary pieces (port of
// sentencepiece's unigram EncodeOptimized).
//
// A lattice holds one cell per byte offset. For every character start it
// walks the trie byte by byte and relaxes the cell at the end of each matching
// piece. If no piece covers exactly the current character, an <unk> edge with
// score min_score() - kUnkPenalty spanning that character is added instead.
// The best path is then read back from the end.
// Returns (piece text, piece id) pairs in input order. Returns an empty result
// for empty input or when status() != OK.
EncodeResult T5UniGramTokenizer::EncodeOptimized(const std::string& normalized) const {
    if (status() != OK || normalized.empty()) {
        return {};
    }

    // Lattice cell: the best path ending at this byte offset ends with piece
    // `id`, which starts at byte `starts_at` (-1 while the cell is unset).
    struct BestPathNode {
        int id               = -1;
        float best_path_score = 0;
        int starts_at        = -1;
    };

    const int input_len   = static_cast<int>(normalized.size());
    const float unk_score = min_score() - kUnkPenalty;
    std::vector<BestPathNode> lattice(input_len + 1);

    // Replace the cell's edge when it is unset or the candidate scores strictly higher.
    auto relax = [](BestPathNode& cell, float path_score, int begin, int piece_id) {
        const bool unset = cell.starts_at == -1;
        if (unset || path_score > cell.best_path_score) {
            cell.best_path_score = path_score;
            cell.starts_at       = begin;
            cell.id              = piece_id;
        }
    };

    for (int begin = 0; begin < input_len;) {
        const float prefix_score = lattice[begin].best_path_score;
        const int char_len       = std::min<int>(static_cast<int>(OneCharLen(normalized.data() + begin)), input_len - begin);
        bool char_has_piece      = false;

        std::size_t trie_node = 0;
        std::size_t cursor    = begin;
        while (cursor < static_cast<size_t>(input_len)) {
            // traverse() consumes one byte and advances `cursor`.
            const int piece_id = trie_->traverse(normalized.data(), trie_node, cursor, cursor + 1);
            if (piece_id == -2) {
                break;  // no vocabulary piece continues this prefix
            }
            if (piece_id < 0 || IsUnusedInlined(piece_id)) {
                continue;  // a prefix of some piece, but not a usable piece itself
            }
            const int piece_len   = static_cast<int>(cursor - begin);
            const auto piece_score = IsUserDefinedInlined(piece_id) ? (piece_len * max_score_ - 0.1f) : GetScoreInlined(piece_id);
            relax(lattice[cursor], static_cast<float>(piece_score + prefix_score), begin, piece_id);
            if (piece_len == char_len) {
                char_has_piece = true;
            }
        }

        // Guarantee a path through this character even if no single-char piece exists.
        if (!char_has_piece) {
            relax(lattice[begin + char_len], unk_score + prefix_score, begin, UNK_TOKEN_ID);
        }
        begin += char_len;
    }

    // Walk the best path backwards from the end of the input.
    EncodeResult results;
    for (int end = input_len; end > 0;) {
        const BestPathNode& cell = lattice[end];
        results.emplace_back(normalized.substr(cell.starts_at, end - cell.starts_at), cell.id);
        end = cell.starts_at;
    }
    std::reverse(results.begin(), results.end());
    return results;
}
|
|
|
|
// Builds a T5 (is_umt5 == false) or UMT5 (is_umt5 == true) unigram tokenizer.
//
// It sets the special-token ids and strings, loads the bundled tokenizer.json
// for the chosen vocabulary, records the score range of all pieces, and builds
// the lookup trie. Failures are reported through status_, not exceptions.
// EOS is appended by default; BOS is not.
T5UniGramTokenizer::T5UniGramTokenizer(bool is_umt5) {
    add_bos_token = false;
    add_eos_token = true;

    // Shared by both vocabularies.
    PAD_TOKEN_ID = 0;
    EOS_TOKEN_ID = 1;
    PAD_TOKEN    = "<pad>";
    EOS_TOKEN    = "</s>";
    UNK_TOKEN    = "<unk>";

    special_tokens = {
        "<pad>",
        "</s>",
        "<unk>",
    };

    if (is_umt5) {
        // UMT5 reserves id 2 for <s>, pushing <unk> to id 3.
        BOS_TOKEN_ID = 2;
        UNK_TOKEN_ID = 3;
        BOS_TOKEN    = "<s>";
        special_tokens.push_back("<s>");
        // May override UNK_TOKEN_ID from the JSON's "model.unk_id".
        InitializePieces(load_umt5_tokenizer_json());
    } else {
        // T5 has no BOS token; <unk> sits at id 2.
        UNK_TOKEN_ID = 2;
        InitializePieces(load_t5_tokenizer_json());
    }

    // Running min/max of all piece scores. The max must be seeded with -FLT_MAX.
    // FLT_MIN is the smallest *positive* normalized float, not the most negative
    // one, so seeding with it would clamp max_score_ to FLT_MIN whenever every
    // score is <= 0 (e.g. log-probabilities).
    min_score_ = FLT_MAX;
    max_score_ = -FLT_MAX;

    std::vector<std::pair<std::string, int>> pieces;
    pieces.reserve(piece_score_pairs.size());
    for (int i = 0; i < static_cast<int>(piece_score_pairs.size()); i++) {
        const auto& sp = piece_score_pairs[i];

        min_score_ = std::min(min_score_, sp.second);
        max_score_ = std::max(max_score_, sp.second);

        // The piece's index in piece_score_pairs becomes its token id.
        pieces.emplace_back(sp.first, i);
    }

    BuildTrie(&pieces);
}
|
|
|
|
// Defaulted destructor, defined out of line. NOTE(review): presumably placed here so
// that members such as the std::unique_ptr<Darts::DoubleArray> trie_ are destroyed
// where Darts::DoubleArray is a complete type — confirm against the header.
T5UniGramTokenizer::~T5UniGramTokenizer() = default;
|
|
|
|
std::string T5UniGramTokenizer::decode_token(int token_id) const {
|
|
if (token_id < 0 || token_id >= static_cast<int>(piece_score_pairs.size())) {
|
|
return "";
|
|
}
|
|
|
|
const std::string& piece = piece_score_pairs[token_id].first;
|
|
if (piece == "<empty_token>") {
|
|
return "";
|
|
}
|
|
return piece;
|
|
}
|
|
|
|
// Text normalization applied before special-token splitting.
// Currently only collapses every run of two or more spaces into a single space.
std::string T5UniGramTokenizer::normalize(const std::string& input) const {
    // Ref: https://github.com/huggingface/tokenizers/blob/1ff56c0c70b045f0cd82da1af9ac08cd4c7a6f9f/bindings/python/py_src/tokenizers/implementations/sentencepiece_unigram.py#L29
    // TODO: nmt-nfkc
    // Equivalent to std::regex_replace(input, std::regex(" {2,}"), " "), but
    // without compiling a std::regex on every call: a space is skipped only when
    // the preceding input byte was also a space.
    std::string normalized;
    normalized.reserve(input.size());
    bool prev_was_space = false;
    for (const char ch : input) {
        const bool is_space = ch == ' ';
        if (!(is_space && prev_was_space)) {
            normalized.push_back(ch);
        }
        prev_was_space = is_space;
    }
    return normalized;
}
|
|
|
|
std::vector<int> T5UniGramTokenizer::encode(const std::string& input, on_new_token_cb_t on_new_token_cb) {
|
|
std::vector<int32_t> tokens;
|
|
std::vector<std::string> token_strs;
|
|
std::string normalized = normalize(input);
|
|
auto splited_texts = split_with_special_tokens(normalized, special_tokens);
|
|
if (splited_texts.empty()) {
|
|
splited_texts.push_back(normalized); // for empty string
|
|
}
|
|
|
|
for (auto& splited_text : splited_texts) {
|
|
if (is_special_token(splited_text)) {
|
|
if (on_new_token_cb != nullptr) {
|
|
bool skip = on_new_token_cb(splited_text, tokens);
|
|
if (skip) {
|
|
token_strs.push_back(splited_text);
|
|
continue;
|
|
}
|
|
}
|
|
|
|
if (splited_text == UNK_TOKEN) {
|
|
tokens.push_back(UNK_TOKEN_ID);
|
|
token_strs.push_back(UNK_TOKEN);
|
|
} else if (splited_text == EOS_TOKEN) {
|
|
tokens.push_back(EOS_TOKEN_ID);
|
|
token_strs.push_back(EOS_TOKEN);
|
|
} else if (splited_text == PAD_TOKEN) {
|
|
tokens.push_back(PAD_TOKEN_ID);
|
|
token_strs.push_back(PAD_TOKEN);
|
|
}
|
|
continue;
|
|
}
|
|
|
|
std::string pretokenized = pre_tokenizer.tokenize(splited_text);
|
|
EncodeResult result = EncodeOptimized(pretokenized);
|
|
for (const auto& item : result) {
|
|
tokens.push_back(item.second);
|
|
token_strs.push_back(item.first);
|
|
}
|
|
}
|
|
|
|
std::stringstream ss;
|
|
ss << "[";
|
|
for (const auto& token_str : token_strs) {
|
|
ss << "\"" << token_str << "\", ";
|
|
}
|
|
ss << "]";
|
|
LOG_DEBUG("split prompt \"%s\" to tokens %s", input.c_str(), ss.str().c_str());
|
|
|
|
return tokens;
|
|
}
|