mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-17 03:37:20 +00:00
fix: normalize CLIP prompts before special-token splitting (#1670)
This commit is contained in:
parent
92a3b73cdb
commit
7f0e728b7d
@ -142,8 +142,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
||||
std::shared_ptr<RunnerWeightManager> weight_manager = nullptr)
|
||||
: version(version), tokenizer(sd_version_is_sd2(version) ? 0 : 49407) {
|
||||
for (const auto& kv : orig_embedding_map) {
|
||||
std::string name = kv.first;
|
||||
std::transform(name.begin(), name.end(), name.begin(), [](unsigned char c) { return std::tolower(c); });
|
||||
std::string name = normalize_embedding_name(kv.first);
|
||||
embedding_map[name] = kv.second;
|
||||
tokenizer.add_special_token(name);
|
||||
}
|
||||
@ -278,17 +277,23 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<int> convert_token_to_id(std::string text) {
|
||||
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
|
||||
auto iter = embedding_map.find(str);
|
||||
static std::string normalize_embedding_name(std::string name) {
|
||||
std::transform(name.begin(), name.end(), name.begin(), [](unsigned char c) { return std::tolower(c); });
|
||||
return name;
|
||||
}
|
||||
|
||||
bool append_embedding_tokens(std::string str, std::vector<int32_t>& bpe_tokens) {
|
||||
std::string name = normalize_embedding_name(std::move(str));
|
||||
auto iter = embedding_map.find(name);
|
||||
if (iter == embedding_map.end()) {
|
||||
return false;
|
||||
}
|
||||
std::string embedding_path = iter->second;
|
||||
if (load_embedding(str, embedding_path, bpe_tokens)) {
|
||||
return true;
|
||||
return load_embedding(name, iter->second, bpe_tokens);
|
||||
}
|
||||
return false;
|
||||
|
||||
std::vector<int> convert_token_to_id(std::string text) {
|
||||
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
|
||||
return append_embedding_tokens(str, bpe_tokens);
|
||||
};
|
||||
std::vector<int> curr_tokens = tokenizer.encode(text, on_new_token_cb);
|
||||
return curr_tokens;
|
||||
@ -315,15 +320,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
||||
}
|
||||
|
||||
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
|
||||
auto iter = embedding_map.find(str);
|
||||
if (iter == embedding_map.end()) {
|
||||
return false;
|
||||
}
|
||||
std::string embedding_path = iter->second;
|
||||
if (load_embedding(str, embedding_path, bpe_tokens)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
return append_embedding_tokens(str, bpe_tokens);
|
||||
};
|
||||
|
||||
std::vector<int> tokens;
|
||||
|
||||
@ -134,7 +134,8 @@ std::vector<int> BPETokenizer::encode(const std::string& text, on_new_token_cb_t
|
||||
std::vector<int32_t> bpe_tokens;
|
||||
std::vector<std::string> token_strs;
|
||||
|
||||
auto splited_texts = split_with_special_tokens(text, special_tokens);
|
||||
std::string normalized_text = normalize_before_split ? normalize(text) : text;
|
||||
auto splited_texts = split_with_special_tokens(normalized_text, special_tokens);
|
||||
|
||||
for (auto& splited_text : splited_texts) {
|
||||
if (is_special_token(splited_text)) {
|
||||
@ -159,7 +160,7 @@ std::vector<int> BPETokenizer::encode(const std::string& text, on_new_token_cb_t
|
||||
}
|
||||
}
|
||||
|
||||
std::string token_str = normalize(token);
|
||||
std::string token_str = normalize_before_split ? token : normalize(token);
|
||||
std::u32string utf32_token;
|
||||
if (byte_level_bpe) {
|
||||
for (int i = 0; i < token_str.length(); i++) {
|
||||
|
||||
@ -25,6 +25,7 @@ CLIPTokenizer::CLIPTokenizer(int pad_token_id, const std::string& merges_utf8_st
|
||||
end_of_word_suffix = "</w>";
|
||||
add_bos_token = true;
|
||||
add_eos_token = true;
|
||||
normalize_before_split = true;
|
||||
|
||||
if (merges_utf8_str.size() > 0) {
|
||||
load_from_merges(merges_utf8_str);
|
||||
|
||||
@ -15,6 +15,7 @@ protected:
|
||||
bool add_bos_token = false;
|
||||
bool add_eos_token = false;
|
||||
bool pad_left = false;
|
||||
bool normalize_before_split = false;
|
||||
std::string end_of_word_suffix;
|
||||
|
||||
virtual std::string decode_token(int token_id) const = 0;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user