mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-12 13:28:37 +00:00
refactor: optimize the handling of embedding (#1068)
* optimize the handling of embedding * support case-insensitive embedding names
This commit is contained in:
parent
0392273e10
commit
96c3e64057
76
clip.hpp
76
clip.hpp
@ -3,6 +3,7 @@
|
||||
|
||||
#include "ggml_extend.hpp"
|
||||
#include "model.h"
|
||||
#include "tokenize_util.h"
|
||||
|
||||
/*================================================== CLIPTokenizer ===================================================*/
|
||||
|
||||
@ -72,6 +73,8 @@ private:
|
||||
int encoder_len;
|
||||
int bpe_len;
|
||||
|
||||
std::vector<std::string> special_tokens;
|
||||
|
||||
public:
|
||||
const std::string UNK_TOKEN = "<|endoftext|>";
|
||||
const std::string BOS_TOKEN = "<|startoftext|>";
|
||||
@ -117,6 +120,15 @@ private:
|
||||
return pairs;
|
||||
}
|
||||
|
||||
bool is_special_token(const std::string& token) {
|
||||
for (auto& special_token : special_tokens) {
|
||||
if (special_token == token) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public:
|
||||
CLIPTokenizer(int pad_token_id = 49407, const std::string& merges_utf8_str = "")
|
||||
: PAD_TOKEN_ID(pad_token_id) {
|
||||
@ -125,6 +137,8 @@ public:
|
||||
} else {
|
||||
load_from_merges(ModelLoader::load_merges());
|
||||
}
|
||||
add_special_token("<|startoftext|>");
|
||||
add_special_token("<|endoftext|>");
|
||||
}
|
||||
|
||||
void load_from_merges(const std::string& merges_utf8_str) {
|
||||
@ -201,6 +215,10 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
void add_special_token(const std::string& token) {
|
||||
special_tokens.push_back(token);
|
||||
}
|
||||
|
||||
std::u32string bpe(const std::u32string& token) {
|
||||
std::vector<std::u32string> word;
|
||||
|
||||
@ -379,25 +397,54 @@ public:
|
||||
return trim(text);
|
||||
}
|
||||
|
||||
std::vector<std::string> token_split(const std::string& text) {
|
||||
std::regex pat(R"('s|'t|'re|'ve|'m|'ll|'d|[[:alpha:]]+|[[:digit:]]|[^[:space:][:alpha:][:digit:]]+)",
|
||||
std::regex::icase);
|
||||
std::sregex_iterator iter(text.begin(), text.end(), pat);
|
||||
std::sregex_iterator end;
|
||||
|
||||
std::vector<std::string> result;
|
||||
for (; iter != end; ++iter) {
|
||||
result.emplace_back(iter->str());
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<int> encode(std::string text, on_new_token_cb_t on_new_token_cb) {
|
||||
std::string original_text = text;
|
||||
std::vector<int32_t> bpe_tokens;
|
||||
text = whitespace_clean(text);
|
||||
std::transform(text.begin(), text.end(), text.begin(), [](unsigned char c) { return std::tolower(c); });
|
||||
|
||||
std::regex pat(R"(<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[[:alpha:]]+|[[:digit:]]|[^[:space:][:alpha:][:digit:]]+)",
|
||||
std::regex::icase);
|
||||
|
||||
std::smatch matches;
|
||||
std::string str = text;
|
||||
std::vector<std::string> token_strs;
|
||||
while (std::regex_search(str, matches, pat)) {
|
||||
bool skip = on_new_token_cb(str, bpe_tokens);
|
||||
|
||||
auto splited_texts = split_with_special_tokens(text, special_tokens);
|
||||
|
||||
for (auto& splited_text : splited_texts) {
|
||||
LOG_DEBUG("token %s", splited_text.c_str());
|
||||
if (is_special_token(splited_text)) {
|
||||
LOG_DEBUG("special %s", splited_text.c_str());
|
||||
bool skip = on_new_token_cb(splited_text, bpe_tokens);
|
||||
if (skip) {
|
||||
token_strs.push_back(splited_text);
|
||||
continue;
|
||||
}
|
||||
for (auto& token : matches) {
|
||||
std::string token_str = token.str();
|
||||
continue;
|
||||
}
|
||||
|
||||
auto tokens = token_split(splited_text);
|
||||
for (auto& token : tokens) {
|
||||
if (on_new_token_cb != nullptr) {
|
||||
bool skip = on_new_token_cb(token, bpe_tokens);
|
||||
if (skip) {
|
||||
token_strs.push_back(token);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
std::string token_str = token;
|
||||
std::u32string utf32_token;
|
||||
for (int i = 0; i < token_str.length(); i++) {
|
||||
unsigned char b = token_str[i];
|
||||
@ -417,14 +464,13 @@ public:
|
||||
bpe_tokens.push_back(encoder[bpe_str]);
|
||||
token_strs.push_back(utf32_to_utf8(bpe_str));
|
||||
}
|
||||
str = matches.suffix();
|
||||
}
|
||||
std::stringstream ss;
|
||||
ss << "[";
|
||||
for (auto token : token_strs) {
|
||||
ss << "\"" << token << "\", ";
|
||||
}
|
||||
ss << "]";
|
||||
// std::stringstream ss;
|
||||
// ss << "[";
|
||||
// for (auto token : token_strs) {
|
||||
// ss << "\"" << token << "\", ";
|
||||
// }
|
||||
// ss << "]";
|
||||
// LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str());
|
||||
// printf("split prompt \"%s\" to tokens %s \n", original_text.c_str(), ss.str().c_str());
|
||||
return bpe_tokens;
|
||||
|
||||
@ -56,7 +56,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
||||
std::shared_ptr<CLIPTextModelRunner> text_model2;
|
||||
|
||||
std::string trigger_word = "img"; // should be user settable
|
||||
std::string embd_dir;
|
||||
std::map<std::string, std::string> embedding_map;
|
||||
int32_t num_custom_embeddings = 0;
|
||||
int32_t num_custom_embeddings_2 = 0;
|
||||
std::vector<uint8_t> token_embed_custom;
|
||||
@ -65,11 +65,17 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
||||
FrozenCLIPEmbedderWithCustomWords(ggml_backend_t backend,
|
||||
bool offload_params_to_cpu,
|
||||
const String2TensorStorage& tensor_storage_map,
|
||||
const std::string& embd_dir,
|
||||
const std::map<std::string, std::string>& orig_embedding_map,
|
||||
SDVersion version = VERSION_SD1,
|
||||
PMVersion pv = PM_VERSION_1)
|
||||
: version(version), pm_version(pv), tokenizer(sd_version_is_sd2(version) ? 0 : 49407), embd_dir(embd_dir) {
|
||||
bool force_clip_f32 = embd_dir.size() > 0;
|
||||
: version(version), pm_version(pv), tokenizer(sd_version_is_sd2(version) ? 0 : 49407) {
|
||||
for (const auto& kv : orig_embedding_map) {
|
||||
std::string name = kv.first;
|
||||
std::transform(name.begin(), name.end(), name.begin(), [](unsigned char c) { return std::tolower(c); });
|
||||
embedding_map[name] = kv.second;
|
||||
tokenizer.add_special_token(name);
|
||||
}
|
||||
bool force_clip_f32 = !embedding_map.empty();
|
||||
if (sd_version_is_sd1(version)) {
|
||||
text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_storage_map, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, true, force_clip_f32);
|
||||
} else if (sd_version_is_sd2(version)) {
|
||||
@ -196,26 +202,14 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
||||
|
||||
std::vector<int> convert_token_to_id(std::string text) {
|
||||
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
|
||||
size_t word_end = str.find(",");
|
||||
std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
|
||||
embd_name = trim(embd_name);
|
||||
std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
|
||||
if (embd_path.size() == 0) {
|
||||
embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
|
||||
}
|
||||
if (embd_path.size() == 0) {
|
||||
embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
|
||||
}
|
||||
if (embd_path.size() > 0) {
|
||||
if (load_embedding(embd_name, embd_path, bpe_tokens)) {
|
||||
if (word_end != std::string::npos) {
|
||||
str = str.substr(word_end);
|
||||
} else {
|
||||
str = "";
|
||||
auto iter = embedding_map.find(str);
|
||||
if (iter == embedding_map.end()) {
|
||||
return false;
|
||||
}
|
||||
std::string embedding_path = iter->second;
|
||||
if (load_embedding(str, embedding_path, bpe_tokens)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
std::vector<int> curr_tokens = tokenizer.encode(text, on_new_token_cb);
|
||||
@ -245,26 +239,14 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
||||
}
|
||||
|
||||
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
|
||||
size_t word_end = str.find(",");
|
||||
std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
|
||||
embd_name = trim(embd_name);
|
||||
std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
|
||||
if (embd_path.size() == 0) {
|
||||
embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
|
||||
}
|
||||
if (embd_path.size() == 0) {
|
||||
embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
|
||||
}
|
||||
if (embd_path.size() > 0) {
|
||||
if (load_embedding(embd_name, embd_path, bpe_tokens)) {
|
||||
if (word_end != std::string::npos) {
|
||||
str = str.substr(word_end);
|
||||
} else {
|
||||
str = "";
|
||||
auto iter = embedding_map.find(str);
|
||||
if (iter == embedding_map.end()) {
|
||||
return false;
|
||||
}
|
||||
std::string embedding_path = iter->second;
|
||||
if (load_embedding(str, embedding_path, bpe_tokens)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
@ -376,26 +358,14 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
||||
}
|
||||
|
||||
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
|
||||
size_t word_end = str.find(",");
|
||||
std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
|
||||
embd_name = trim(embd_name);
|
||||
std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
|
||||
if (embd_path.size() == 0) {
|
||||
embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
|
||||
}
|
||||
if (embd_path.size() == 0) {
|
||||
embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
|
||||
}
|
||||
if (embd_path.size() > 0) {
|
||||
if (load_embedding(embd_name, embd_path, bpe_tokens)) {
|
||||
if (word_end != std::string::npos) {
|
||||
str = str.substr(word_end);
|
||||
} else {
|
||||
str = "";
|
||||
auto iter = embedding_map.find(str);
|
||||
if (iter == embedding_map.end()) {
|
||||
return false;
|
||||
}
|
||||
std::string embedding_path = iter->second;
|
||||
if (load_embedding(str, embedding_path, bpe_tokens)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
|
||||
@ -501,6 +501,9 @@ struct SDContextParams {
|
||||
std::string tensor_type_rules;
|
||||
std::string lora_model_dir;
|
||||
|
||||
std::map<std::string, std::string> embedding_map;
|
||||
std::vector<sd_embedding_t> embedding_array;
|
||||
|
||||
rng_type_t rng_type = CUDA_RNG;
|
||||
rng_type_t sampler_rng_type = RNG_TYPE_COUNT;
|
||||
bool offload_params_to_cpu = false;
|
||||
@ -828,6 +831,37 @@ struct SDContextParams {
|
||||
return options;
|
||||
}
|
||||
|
||||
void build_embedding_map() {
|
||||
static const std::vector<std::string> valid_ext = {".pt", ".safetensors", ".gguf"};
|
||||
|
||||
if (!fs::exists(embedding_dir) || !fs::is_directory(embedding_dir)) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (auto& p : fs::directory_iterator(embedding_dir)) {
|
||||
if (!p.is_regular_file())
|
||||
continue;
|
||||
|
||||
auto path = p.path();
|
||||
std::string ext = path.extension().string();
|
||||
|
||||
bool valid = false;
|
||||
for (auto& e : valid_ext) {
|
||||
if (ext == e) {
|
||||
valid = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!valid)
|
||||
continue;
|
||||
|
||||
std::string key = path.stem().string();
|
||||
std::string value = path.string();
|
||||
|
||||
embedding_map[key] = value;
|
||||
}
|
||||
}
|
||||
|
||||
bool process_and_check(SDMode mode) {
|
||||
if (mode != UPSCALE && model_path.length() == 0 && diffusion_model_path.length() == 0) {
|
||||
fprintf(stderr, "error: the following arguments are required: model_path/diffusion_model\n");
|
||||
@ -845,10 +879,24 @@ struct SDContextParams {
|
||||
n_threads = sd_get_num_physical_cores();
|
||||
}
|
||||
|
||||
build_embedding_map();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string to_string() const {
|
||||
std::ostringstream emb_ss;
|
||||
emb_ss << "{\n";
|
||||
for (auto it = embedding_map.begin(); it != embedding_map.end(); ++it) {
|
||||
emb_ss << " \"" << it->first << "\": \"" << it->second << "\"";
|
||||
if (std::next(it) != embedding_map.end()) {
|
||||
emb_ss << ",";
|
||||
}
|
||||
emb_ss << "\n";
|
||||
}
|
||||
emb_ss << " }";
|
||||
|
||||
std::string embeddings_str = emb_ss.str();
|
||||
std::ostringstream oss;
|
||||
oss << "SDContextParams {\n"
|
||||
<< " n_threads: " << n_threads << ",\n"
|
||||
@ -866,6 +914,7 @@ struct SDContextParams {
|
||||
<< " esrgan_path: \"" << esrgan_path << "\",\n"
|
||||
<< " control_net_path: \"" << control_net_path << "\",\n"
|
||||
<< " embedding_dir: \"" << embedding_dir << "\",\n"
|
||||
<< " embeddings: " << embeddings_str << "\n"
|
||||
<< " wtype: " << sd_type_name(wtype) << ",\n"
|
||||
<< " tensor_type_rules: \"" << tensor_type_rules << "\",\n"
|
||||
<< " lora_model_dir: \"" << lora_model_dir << "\",\n"
|
||||
@ -898,6 +947,15 @@ struct SDContextParams {
|
||||
}
|
||||
|
||||
sd_ctx_params_t to_sd_ctx_params_t(bool vae_decode_only, bool free_params_immediately, bool taesd_preview) {
|
||||
embedding_array.clear();
|
||||
embedding_array.reserve(embedding_map.size());
|
||||
for (const auto& kv : embedding_map) {
|
||||
sd_embedding_t item;
|
||||
item.name = kv.first.c_str();
|
||||
item.path = kv.second.c_str();
|
||||
embedding_array.emplace_back(item);
|
||||
}
|
||||
|
||||
sd_ctx_params_t sd_ctx_params = {
|
||||
model_path.c_str(),
|
||||
clip_l_path.c_str(),
|
||||
@ -912,7 +970,8 @@ struct SDContextParams {
|
||||
taesd_path.c_str(),
|
||||
control_net_path.c_str(),
|
||||
lora_model_dir.c_str(),
|
||||
embedding_dir.c_str(),
|
||||
embedding_array.data(),
|
||||
static_cast<uint32_t>(embedding_array.size()),
|
||||
photo_maker_path.c_str(),
|
||||
tensor_type_rules.c_str(),
|
||||
vae_decode_only,
|
||||
|
||||
@ -508,18 +508,22 @@ public:
|
||||
"model.diffusion_model",
|
||||
version);
|
||||
} else { // SD1.x SD2.x SDXL
|
||||
std::map<std::string, std::string> embbeding_map;
|
||||
for (int i = 0; i < sd_ctx_params->embedding_count; i++) {
|
||||
embbeding_map.emplace(SAFE_STR(sd_ctx_params->embeddings[i].name), SAFE_STR(sd_ctx_params->embeddings[i].path));
|
||||
}
|
||||
if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
|
||||
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
|
||||
offload_params_to_cpu,
|
||||
tensor_storage_map,
|
||||
SAFE_STR(sd_ctx_params->embedding_dir),
|
||||
embbeding_map,
|
||||
version,
|
||||
PM_VERSION_2);
|
||||
} else {
|
||||
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
|
||||
offload_params_to_cpu,
|
||||
tensor_storage_map,
|
||||
SAFE_STR(sd_ctx_params->embedding_dir),
|
||||
embbeding_map,
|
||||
version);
|
||||
}
|
||||
diffusion_model = std::make_shared<UNetModel>(backend,
|
||||
@ -2521,7 +2525,6 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
|
||||
"taesd_path: %s\n"
|
||||
"control_net_path: %s\n"
|
||||
"lora_model_dir: %s\n"
|
||||
"embedding_dir: %s\n"
|
||||
"photo_maker_path: %s\n"
|
||||
"tensor_type_rules: %s\n"
|
||||
"vae_decode_only: %s\n"
|
||||
@ -2552,7 +2555,6 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
|
||||
SAFE_STR(sd_ctx_params->taesd_path),
|
||||
SAFE_STR(sd_ctx_params->control_net_path),
|
||||
SAFE_STR(sd_ctx_params->lora_model_dir),
|
||||
SAFE_STR(sd_ctx_params->embedding_dir),
|
||||
SAFE_STR(sd_ctx_params->photo_maker_path),
|
||||
SAFE_STR(sd_ctx_params->tensor_type_rules),
|
||||
BOOL_STR(sd_ctx_params->vae_decode_only),
|
||||
|
||||
@ -150,6 +150,11 @@ typedef struct {
|
||||
float rel_size_y;
|
||||
} sd_tiling_params_t;
|
||||
|
||||
typedef struct {
|
||||
const char* name;
|
||||
const char* path;
|
||||
} sd_embedding_t;
|
||||
|
||||
typedef struct {
|
||||
const char* model_path;
|
||||
const char* clip_l_path;
|
||||
@ -164,7 +169,8 @@ typedef struct {
|
||||
const char* taesd_path;
|
||||
const char* control_net_path;
|
||||
const char* lora_model_dir;
|
||||
const char* embedding_dir;
|
||||
const sd_embedding_t* embeddings;
|
||||
uint32_t embedding_count;
|
||||
const char* photo_maker_path;
|
||||
const char* tensor_type_rules;
|
||||
bool vae_decode_only;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user