fix: normalize CLIP prompts before special-token splitting (#1670)

This commit is contained in:
leejet 2026-06-17 00:33:00 +08:00 committed by GitHub
parent 92a3b73cdb
commit 7f0e728b7d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 28 additions and 28 deletions

View File

@ -142,8 +142,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
std::shared_ptr<RunnerWeightManager> weight_manager = nullptr) std::shared_ptr<RunnerWeightManager> weight_manager = nullptr)
: version(version), tokenizer(sd_version_is_sd2(version) ? 0 : 49407) { : version(version), tokenizer(sd_version_is_sd2(version) ? 0 : 49407) {
for (const auto& kv : orig_embedding_map) { for (const auto& kv : orig_embedding_map) {
std::string name = kv.first; std::string name = normalize_embedding_name(kv.first);
std::transform(name.begin(), name.end(), name.begin(), [](unsigned char c) { return std::tolower(c); });
embedding_map[name] = kv.second; embedding_map[name] = kv.second;
tokenizer.add_special_token(name); tokenizer.add_special_token(name);
} }
@ -278,17 +277,23 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
return true; return true;
} }
std::vector<int> convert_token_to_id(std::string text) { static std::string normalize_embedding_name(std::string name) {
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool { std::transform(name.begin(), name.end(), name.begin(), [](unsigned char c) { return std::tolower(c); });
auto iter = embedding_map.find(str); return name;
}
bool append_embedding_tokens(std::string str, std::vector<int32_t>& bpe_tokens) {
std::string name = normalize_embedding_name(std::move(str));
auto iter = embedding_map.find(name);
if (iter == embedding_map.end()) { if (iter == embedding_map.end()) {
return false; return false;
} }
std::string embedding_path = iter->second; return load_embedding(name, iter->second, bpe_tokens);
if (load_embedding(str, embedding_path, bpe_tokens)) {
return true;
} }
return false;
std::vector<int> convert_token_to_id(std::string text) {
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
return append_embedding_tokens(str, bpe_tokens);
}; };
std::vector<int> curr_tokens = tokenizer.encode(text, on_new_token_cb); std::vector<int> curr_tokens = tokenizer.encode(text, on_new_token_cb);
return curr_tokens; return curr_tokens;
@ -315,15 +320,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
} }
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool { auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
auto iter = embedding_map.find(str); return append_embedding_tokens(str, bpe_tokens);
if (iter == embedding_map.end()) {
return false;
}
std::string embedding_path = iter->second;
if (load_embedding(str, embedding_path, bpe_tokens)) {
return true;
}
return false;
}; };
std::vector<int> tokens; std::vector<int> tokens;

View File

@ -134,7 +134,8 @@ std::vector<int> BPETokenizer::encode(const std::string& text, on_new_token_cb_t
std::vector<int32_t> bpe_tokens; std::vector<int32_t> bpe_tokens;
std::vector<std::string> token_strs; std::vector<std::string> token_strs;
auto splited_texts = split_with_special_tokens(text, special_tokens); std::string normalized_text = normalize_before_split ? normalize(text) : text;
auto splited_texts = split_with_special_tokens(normalized_text, special_tokens);
for (auto& splited_text : splited_texts) { for (auto& splited_text : splited_texts) {
if (is_special_token(splited_text)) { if (is_special_token(splited_text)) {
@ -159,7 +160,7 @@ std::vector<int> BPETokenizer::encode(const std::string& text, on_new_token_cb_t
} }
} }
std::string token_str = normalize(token); std::string token_str = normalize_before_split ? token : normalize(token);
std::u32string utf32_token; std::u32string utf32_token;
if (byte_level_bpe) { if (byte_level_bpe) {
for (int i = 0; i < token_str.length(); i++) { for (int i = 0; i < token_str.length(); i++) {

View File

@ -25,6 +25,7 @@ CLIPTokenizer::CLIPTokenizer(int pad_token_id, const std::string& merges_utf8_st
end_of_word_suffix = "</w>"; end_of_word_suffix = "</w>";
add_bos_token = true; add_bos_token = true;
add_eos_token = true; add_eos_token = true;
normalize_before_split = true;
if (merges_utf8_str.size() > 0) { if (merges_utf8_str.size() > 0) {
load_from_merges(merges_utf8_str); load_from_merges(merges_utf8_str);

View File

@ -15,6 +15,7 @@ protected:
bool add_bos_token = false; bool add_bos_token = false;
bool add_eos_token = false; bool add_eos_token = false;
bool pad_left = false; bool pad_left = false;
bool normalize_before_split = false;
std::string end_of_word_suffix; std::string end_of_word_suffix;
virtual std::string decode_token(int token_id) const = 0; virtual std::string decode_token(int token_id) const = 0;