mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-12 13:28:37 +00:00
refactor: optimize the handling of embedding (#1068)
* optimize the handling of embedding * support case-insensitive embedding names
This commit is contained in:
parent
0392273e10
commit
96c3e64057
78
clip.hpp
78
clip.hpp
@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
#include "ggml_extend.hpp"
|
#include "ggml_extend.hpp"
|
||||||
#include "model.h"
|
#include "model.h"
|
||||||
|
#include "tokenize_util.h"
|
||||||
|
|
||||||
/*================================================== CLIPTokenizer ===================================================*/
|
/*================================================== CLIPTokenizer ===================================================*/
|
||||||
|
|
||||||
@ -72,6 +73,8 @@ private:
|
|||||||
int encoder_len;
|
int encoder_len;
|
||||||
int bpe_len;
|
int bpe_len;
|
||||||
|
|
||||||
|
std::vector<std::string> special_tokens;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
const std::string UNK_TOKEN = "<|endoftext|>";
|
const std::string UNK_TOKEN = "<|endoftext|>";
|
||||||
const std::string BOS_TOKEN = "<|startoftext|>";
|
const std::string BOS_TOKEN = "<|startoftext|>";
|
||||||
@ -117,6 +120,15 @@ private:
|
|||||||
return pairs;
|
return pairs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool is_special_token(const std::string& token) {
|
||||||
|
for (auto& special_token : special_tokens) {
|
||||||
|
if (special_token == token) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
CLIPTokenizer(int pad_token_id = 49407, const std::string& merges_utf8_str = "")
|
CLIPTokenizer(int pad_token_id = 49407, const std::string& merges_utf8_str = "")
|
||||||
: PAD_TOKEN_ID(pad_token_id) {
|
: PAD_TOKEN_ID(pad_token_id) {
|
||||||
@ -125,6 +137,8 @@ public:
|
|||||||
} else {
|
} else {
|
||||||
load_from_merges(ModelLoader::load_merges());
|
load_from_merges(ModelLoader::load_merges());
|
||||||
}
|
}
|
||||||
|
add_special_token("<|startoftext|>");
|
||||||
|
add_special_token("<|endoftext|>");
|
||||||
}
|
}
|
||||||
|
|
||||||
void load_from_merges(const std::string& merges_utf8_str) {
|
void load_from_merges(const std::string& merges_utf8_str) {
|
||||||
@ -201,6 +215,10 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void add_special_token(const std::string& token) {
|
||||||
|
special_tokens.push_back(token);
|
||||||
|
}
|
||||||
|
|
||||||
std::u32string bpe(const std::u32string& token) {
|
std::u32string bpe(const std::u32string& token) {
|
||||||
std::vector<std::u32string> word;
|
std::vector<std::u32string> word;
|
||||||
|
|
||||||
@ -379,25 +397,54 @@ public:
|
|||||||
return trim(text);
|
return trim(text);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> token_split(const std::string& text) {
|
||||||
|
std::regex pat(R"('s|'t|'re|'ve|'m|'ll|'d|[[:alpha:]]+|[[:digit:]]|[^[:space:][:alpha:][:digit:]]+)",
|
||||||
|
std::regex::icase);
|
||||||
|
std::sregex_iterator iter(text.begin(), text.end(), pat);
|
||||||
|
std::sregex_iterator end;
|
||||||
|
|
||||||
|
std::vector<std::string> result;
|
||||||
|
for (; iter != end; ++iter) {
|
||||||
|
result.emplace_back(iter->str());
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<int> encode(std::string text, on_new_token_cb_t on_new_token_cb) {
|
std::vector<int> encode(std::string text, on_new_token_cb_t on_new_token_cb) {
|
||||||
std::string original_text = text;
|
std::string original_text = text;
|
||||||
std::vector<int32_t> bpe_tokens;
|
std::vector<int32_t> bpe_tokens;
|
||||||
text = whitespace_clean(text);
|
text = whitespace_clean(text);
|
||||||
std::transform(text.begin(), text.end(), text.begin(), [](unsigned char c) { return std::tolower(c); });
|
std::transform(text.begin(), text.end(), text.begin(), [](unsigned char c) { return std::tolower(c); });
|
||||||
|
|
||||||
std::regex pat(R"(<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[[:alpha:]]+|[[:digit:]]|[^[:space:][:alpha:][:digit:]]+)",
|
|
||||||
std::regex::icase);
|
|
||||||
|
|
||||||
std::smatch matches;
|
|
||||||
std::string str = text;
|
std::string str = text;
|
||||||
std::vector<std::string> token_strs;
|
std::vector<std::string> token_strs;
|
||||||
while (std::regex_search(str, matches, pat)) {
|
|
||||||
bool skip = on_new_token_cb(str, bpe_tokens);
|
auto splited_texts = split_with_special_tokens(text, special_tokens);
|
||||||
if (skip) {
|
|
||||||
|
for (auto& splited_text : splited_texts) {
|
||||||
|
LOG_DEBUG("token %s", splited_text.c_str());
|
||||||
|
if (is_special_token(splited_text)) {
|
||||||
|
LOG_DEBUG("special %s", splited_text.c_str());
|
||||||
|
bool skip = on_new_token_cb(splited_text, bpe_tokens);
|
||||||
|
if (skip) {
|
||||||
|
token_strs.push_back(splited_text);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
for (auto& token : matches) {
|
|
||||||
std::string token_str = token.str();
|
auto tokens = token_split(splited_text);
|
||||||
|
for (auto& token : tokens) {
|
||||||
|
if (on_new_token_cb != nullptr) {
|
||||||
|
bool skip = on_new_token_cb(token, bpe_tokens);
|
||||||
|
if (skip) {
|
||||||
|
token_strs.push_back(token);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string token_str = token;
|
||||||
std::u32string utf32_token;
|
std::u32string utf32_token;
|
||||||
for (int i = 0; i < token_str.length(); i++) {
|
for (int i = 0; i < token_str.length(); i++) {
|
||||||
unsigned char b = token_str[i];
|
unsigned char b = token_str[i];
|
||||||
@ -417,14 +464,13 @@ public:
|
|||||||
bpe_tokens.push_back(encoder[bpe_str]);
|
bpe_tokens.push_back(encoder[bpe_str]);
|
||||||
token_strs.push_back(utf32_to_utf8(bpe_str));
|
token_strs.push_back(utf32_to_utf8(bpe_str));
|
||||||
}
|
}
|
||||||
str = matches.suffix();
|
|
||||||
}
|
}
|
||||||
std::stringstream ss;
|
// std::stringstream ss;
|
||||||
ss << "[";
|
// ss << "[";
|
||||||
for (auto token : token_strs) {
|
// for (auto token : token_strs) {
|
||||||
ss << "\"" << token << "\", ";
|
// ss << "\"" << token << "\", ";
|
||||||
}
|
// }
|
||||||
ss << "]";
|
// ss << "]";
|
||||||
// LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str());
|
// LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str());
|
||||||
// printf("split prompt \"%s\" to tokens %s \n", original_text.c_str(), ss.str().c_str());
|
// printf("split prompt \"%s\" to tokens %s \n", original_text.c_str(), ss.str().c_str());
|
||||||
return bpe_tokens;
|
return bpe_tokens;
|
||||||
|
|||||||
@ -56,7 +56,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
|||||||
std::shared_ptr<CLIPTextModelRunner> text_model2;
|
std::shared_ptr<CLIPTextModelRunner> text_model2;
|
||||||
|
|
||||||
std::string trigger_word = "img"; // should be user settable
|
std::string trigger_word = "img"; // should be user settable
|
||||||
std::string embd_dir;
|
std::map<std::string, std::string> embedding_map;
|
||||||
int32_t num_custom_embeddings = 0;
|
int32_t num_custom_embeddings = 0;
|
||||||
int32_t num_custom_embeddings_2 = 0;
|
int32_t num_custom_embeddings_2 = 0;
|
||||||
std::vector<uint8_t> token_embed_custom;
|
std::vector<uint8_t> token_embed_custom;
|
||||||
@ -65,11 +65,17 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
|||||||
FrozenCLIPEmbedderWithCustomWords(ggml_backend_t backend,
|
FrozenCLIPEmbedderWithCustomWords(ggml_backend_t backend,
|
||||||
bool offload_params_to_cpu,
|
bool offload_params_to_cpu,
|
||||||
const String2TensorStorage& tensor_storage_map,
|
const String2TensorStorage& tensor_storage_map,
|
||||||
const std::string& embd_dir,
|
const std::map<std::string, std::string>& orig_embedding_map,
|
||||||
SDVersion version = VERSION_SD1,
|
SDVersion version = VERSION_SD1,
|
||||||
PMVersion pv = PM_VERSION_1)
|
PMVersion pv = PM_VERSION_1)
|
||||||
: version(version), pm_version(pv), tokenizer(sd_version_is_sd2(version) ? 0 : 49407), embd_dir(embd_dir) {
|
: version(version), pm_version(pv), tokenizer(sd_version_is_sd2(version) ? 0 : 49407) {
|
||||||
bool force_clip_f32 = embd_dir.size() > 0;
|
for (const auto& kv : orig_embedding_map) {
|
||||||
|
std::string name = kv.first;
|
||||||
|
std::transform(name.begin(), name.end(), name.begin(), [](unsigned char c) { return std::tolower(c); });
|
||||||
|
embedding_map[name] = kv.second;
|
||||||
|
tokenizer.add_special_token(name);
|
||||||
|
}
|
||||||
|
bool force_clip_f32 = !embedding_map.empty();
|
||||||
if (sd_version_is_sd1(version)) {
|
if (sd_version_is_sd1(version)) {
|
||||||
text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_storage_map, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, true, force_clip_f32);
|
text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_storage_map, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, true, force_clip_f32);
|
||||||
} else if (sd_version_is_sd2(version)) {
|
} else if (sd_version_is_sd2(version)) {
|
||||||
@ -196,25 +202,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
|||||||
|
|
||||||
std::vector<int> convert_token_to_id(std::string text) {
|
std::vector<int> convert_token_to_id(std::string text) {
|
||||||
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
|
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
|
||||||
size_t word_end = str.find(",");
|
auto iter = embedding_map.find(str);
|
||||||
std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
|
if (iter == embedding_map.end()) {
|
||||||
embd_name = trim(embd_name);
|
return false;
|
||||||
std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
|
|
||||||
if (embd_path.size() == 0) {
|
|
||||||
embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
|
|
||||||
}
|
}
|
||||||
if (embd_path.size() == 0) {
|
std::string embedding_path = iter->second;
|
||||||
embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
|
if (load_embedding(str, embedding_path, bpe_tokens)) {
|
||||||
}
|
return true;
|
||||||
if (embd_path.size() > 0) {
|
|
||||||
if (load_embedding(embd_name, embd_path, bpe_tokens)) {
|
|
||||||
if (word_end != std::string::npos) {
|
|
||||||
str = str.substr(word_end);
|
|
||||||
} else {
|
|
||||||
str = "";
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
};
|
};
|
||||||
@ -245,25 +239,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
|
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
|
||||||
size_t word_end = str.find(",");
|
auto iter = embedding_map.find(str);
|
||||||
std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
|
if (iter == embedding_map.end()) {
|
||||||
embd_name = trim(embd_name);
|
return false;
|
||||||
std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
|
|
||||||
if (embd_path.size() == 0) {
|
|
||||||
embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
|
|
||||||
}
|
}
|
||||||
if (embd_path.size() == 0) {
|
std::string embedding_path = iter->second;
|
||||||
embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
|
if (load_embedding(str, embedding_path, bpe_tokens)) {
|
||||||
}
|
return true;
|
||||||
if (embd_path.size() > 0) {
|
|
||||||
if (load_embedding(embd_name, embd_path, bpe_tokens)) {
|
|
||||||
if (word_end != std::string::npos) {
|
|
||||||
str = str.substr(word_end);
|
|
||||||
} else {
|
|
||||||
str = "";
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
};
|
};
|
||||||
@ -376,25 +358,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
|
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
|
||||||
size_t word_end = str.find(",");
|
auto iter = embedding_map.find(str);
|
||||||
std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
|
if (iter == embedding_map.end()) {
|
||||||
embd_name = trim(embd_name);
|
return false;
|
||||||
std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
|
|
||||||
if (embd_path.size() == 0) {
|
|
||||||
embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
|
|
||||||
}
|
}
|
||||||
if (embd_path.size() == 0) {
|
std::string embedding_path = iter->second;
|
||||||
embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
|
if (load_embedding(str, embedding_path, bpe_tokens)) {
|
||||||
}
|
return true;
|
||||||
if (embd_path.size() > 0) {
|
|
||||||
if (load_embedding(embd_name, embd_path, bpe_tokens)) {
|
|
||||||
if (word_end != std::string::npos) {
|
|
||||||
str = str.substr(word_end);
|
|
||||||
} else {
|
|
||||||
str = "";
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
};
|
};
|
||||||
@ -1728,7 +1698,7 @@ struct LLMEmbedder : public Conditioner {
|
|||||||
std::vector<std::pair<int, ggml_tensor*>> image_embeds;
|
std::vector<std::pair<int, ggml_tensor*>> image_embeds;
|
||||||
std::pair<int, int> prompt_attn_range;
|
std::pair<int, int> prompt_attn_range;
|
||||||
int prompt_template_encode_start_idx = 34;
|
int prompt_template_encode_start_idx = 34;
|
||||||
int max_length = 0;
|
int max_length = 0;
|
||||||
std::set<int> out_layers;
|
std::set<int> out_layers;
|
||||||
if (llm->enable_vision && conditioner_params.ref_images.size() > 0) {
|
if (llm->enable_vision && conditioner_params.ref_images.size() > 0) {
|
||||||
LOG_INFO("QwenImageEditPlusPipeline");
|
LOG_INFO("QwenImageEditPlusPipeline");
|
||||||
@ -1828,7 +1798,7 @@ struct LLMEmbedder : public Conditioner {
|
|||||||
prompt += "[/INST]";
|
prompt += "[/INST]";
|
||||||
} else if (version == VERSION_OVIS_IMAGE) {
|
} else if (version == VERSION_OVIS_IMAGE) {
|
||||||
prompt_template_encode_start_idx = 28;
|
prompt_template_encode_start_idx = 28;
|
||||||
max_length = prompt_template_encode_start_idx + 256;
|
max_length = prompt_template_encode_start_idx + 256;
|
||||||
|
|
||||||
prompt = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background:";
|
prompt = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background:";
|
||||||
|
|
||||||
|
|||||||
@ -501,6 +501,9 @@ struct SDContextParams {
|
|||||||
std::string tensor_type_rules;
|
std::string tensor_type_rules;
|
||||||
std::string lora_model_dir;
|
std::string lora_model_dir;
|
||||||
|
|
||||||
|
std::map<std::string, std::string> embedding_map;
|
||||||
|
std::vector<sd_embedding_t> embedding_array;
|
||||||
|
|
||||||
rng_type_t rng_type = CUDA_RNG;
|
rng_type_t rng_type = CUDA_RNG;
|
||||||
rng_type_t sampler_rng_type = RNG_TYPE_COUNT;
|
rng_type_t sampler_rng_type = RNG_TYPE_COUNT;
|
||||||
bool offload_params_to_cpu = false;
|
bool offload_params_to_cpu = false;
|
||||||
@ -828,6 +831,37 @@ struct SDContextParams {
|
|||||||
return options;
|
return options;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void build_embedding_map() {
|
||||||
|
static const std::vector<std::string> valid_ext = {".pt", ".safetensors", ".gguf"};
|
||||||
|
|
||||||
|
if (!fs::exists(embedding_dir) || !fs::is_directory(embedding_dir)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto& p : fs::directory_iterator(embedding_dir)) {
|
||||||
|
if (!p.is_regular_file())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
auto path = p.path();
|
||||||
|
std::string ext = path.extension().string();
|
||||||
|
|
||||||
|
bool valid = false;
|
||||||
|
for (auto& e : valid_ext) {
|
||||||
|
if (ext == e) {
|
||||||
|
valid = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!valid)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
std::string key = path.stem().string();
|
||||||
|
std::string value = path.string();
|
||||||
|
|
||||||
|
embedding_map[key] = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bool process_and_check(SDMode mode) {
|
bool process_and_check(SDMode mode) {
|
||||||
if (mode != UPSCALE && model_path.length() == 0 && diffusion_model_path.length() == 0) {
|
if (mode != UPSCALE && model_path.length() == 0 && diffusion_model_path.length() == 0) {
|
||||||
fprintf(stderr, "error: the following arguments are required: model_path/diffusion_model\n");
|
fprintf(stderr, "error: the following arguments are required: model_path/diffusion_model\n");
|
||||||
@ -845,10 +879,24 @@ struct SDContextParams {
|
|||||||
n_threads = sd_get_num_physical_cores();
|
n_threads = sd_get_num_physical_cores();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
build_embedding_map();
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string to_string() const {
|
std::string to_string() const {
|
||||||
|
std::ostringstream emb_ss;
|
||||||
|
emb_ss << "{\n";
|
||||||
|
for (auto it = embedding_map.begin(); it != embedding_map.end(); ++it) {
|
||||||
|
emb_ss << " \"" << it->first << "\": \"" << it->second << "\"";
|
||||||
|
if (std::next(it) != embedding_map.end()) {
|
||||||
|
emb_ss << ",";
|
||||||
|
}
|
||||||
|
emb_ss << "\n";
|
||||||
|
}
|
||||||
|
emb_ss << " }";
|
||||||
|
|
||||||
|
std::string embeddings_str = emb_ss.str();
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
oss << "SDContextParams {\n"
|
oss << "SDContextParams {\n"
|
||||||
<< " n_threads: " << n_threads << ",\n"
|
<< " n_threads: " << n_threads << ",\n"
|
||||||
@ -866,6 +914,7 @@ struct SDContextParams {
|
|||||||
<< " esrgan_path: \"" << esrgan_path << "\",\n"
|
<< " esrgan_path: \"" << esrgan_path << "\",\n"
|
||||||
<< " control_net_path: \"" << control_net_path << "\",\n"
|
<< " control_net_path: \"" << control_net_path << "\",\n"
|
||||||
<< " embedding_dir: \"" << embedding_dir << "\",\n"
|
<< " embedding_dir: \"" << embedding_dir << "\",\n"
|
||||||
|
<< " embeddings: " << embeddings_str << "\n"
|
||||||
<< " wtype: " << sd_type_name(wtype) << ",\n"
|
<< " wtype: " << sd_type_name(wtype) << ",\n"
|
||||||
<< " tensor_type_rules: \"" << tensor_type_rules << "\",\n"
|
<< " tensor_type_rules: \"" << tensor_type_rules << "\",\n"
|
||||||
<< " lora_model_dir: \"" << lora_model_dir << "\",\n"
|
<< " lora_model_dir: \"" << lora_model_dir << "\",\n"
|
||||||
@ -898,6 +947,15 @@ struct SDContextParams {
|
|||||||
}
|
}
|
||||||
|
|
||||||
sd_ctx_params_t to_sd_ctx_params_t(bool vae_decode_only, bool free_params_immediately, bool taesd_preview) {
|
sd_ctx_params_t to_sd_ctx_params_t(bool vae_decode_only, bool free_params_immediately, bool taesd_preview) {
|
||||||
|
embedding_array.clear();
|
||||||
|
embedding_array.reserve(embedding_map.size());
|
||||||
|
for (const auto& kv : embedding_map) {
|
||||||
|
sd_embedding_t item;
|
||||||
|
item.name = kv.first.c_str();
|
||||||
|
item.path = kv.second.c_str();
|
||||||
|
embedding_array.emplace_back(item);
|
||||||
|
}
|
||||||
|
|
||||||
sd_ctx_params_t sd_ctx_params = {
|
sd_ctx_params_t sd_ctx_params = {
|
||||||
model_path.c_str(),
|
model_path.c_str(),
|
||||||
clip_l_path.c_str(),
|
clip_l_path.c_str(),
|
||||||
@ -912,7 +970,8 @@ struct SDContextParams {
|
|||||||
taesd_path.c_str(),
|
taesd_path.c_str(),
|
||||||
control_net_path.c_str(),
|
control_net_path.c_str(),
|
||||||
lora_model_dir.c_str(),
|
lora_model_dir.c_str(),
|
||||||
embedding_dir.c_str(),
|
embedding_array.data(),
|
||||||
|
static_cast<uint32_t>(embedding_array.size()),
|
||||||
photo_maker_path.c_str(),
|
photo_maker_path.c_str(),
|
||||||
tensor_type_rules.c_str(),
|
tensor_type_rules.c_str(),
|
||||||
vae_decode_only,
|
vae_decode_only,
|
||||||
|
|||||||
@ -508,18 +508,22 @@ public:
|
|||||||
"model.diffusion_model",
|
"model.diffusion_model",
|
||||||
version);
|
version);
|
||||||
} else { // SD1.x SD2.x SDXL
|
} else { // SD1.x SD2.x SDXL
|
||||||
|
std::map<std::string, std::string> embbeding_map;
|
||||||
|
for (int i = 0; i < sd_ctx_params->embedding_count; i++) {
|
||||||
|
embbeding_map.emplace(SAFE_STR(sd_ctx_params->embeddings[i].name), SAFE_STR(sd_ctx_params->embeddings[i].path));
|
||||||
|
}
|
||||||
if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
|
if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
|
||||||
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
|
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
|
||||||
offload_params_to_cpu,
|
offload_params_to_cpu,
|
||||||
tensor_storage_map,
|
tensor_storage_map,
|
||||||
SAFE_STR(sd_ctx_params->embedding_dir),
|
embbeding_map,
|
||||||
version,
|
version,
|
||||||
PM_VERSION_2);
|
PM_VERSION_2);
|
||||||
} else {
|
} else {
|
||||||
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
|
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
|
||||||
offload_params_to_cpu,
|
offload_params_to_cpu,
|
||||||
tensor_storage_map,
|
tensor_storage_map,
|
||||||
SAFE_STR(sd_ctx_params->embedding_dir),
|
embbeding_map,
|
||||||
version);
|
version);
|
||||||
}
|
}
|
||||||
diffusion_model = std::make_shared<UNetModel>(backend,
|
diffusion_model = std::make_shared<UNetModel>(backend,
|
||||||
@ -2521,7 +2525,6 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
|
|||||||
"taesd_path: %s\n"
|
"taesd_path: %s\n"
|
||||||
"control_net_path: %s\n"
|
"control_net_path: %s\n"
|
||||||
"lora_model_dir: %s\n"
|
"lora_model_dir: %s\n"
|
||||||
"embedding_dir: %s\n"
|
|
||||||
"photo_maker_path: %s\n"
|
"photo_maker_path: %s\n"
|
||||||
"tensor_type_rules: %s\n"
|
"tensor_type_rules: %s\n"
|
||||||
"vae_decode_only: %s\n"
|
"vae_decode_only: %s\n"
|
||||||
@ -2552,7 +2555,6 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
|
|||||||
SAFE_STR(sd_ctx_params->taesd_path),
|
SAFE_STR(sd_ctx_params->taesd_path),
|
||||||
SAFE_STR(sd_ctx_params->control_net_path),
|
SAFE_STR(sd_ctx_params->control_net_path),
|
||||||
SAFE_STR(sd_ctx_params->lora_model_dir),
|
SAFE_STR(sd_ctx_params->lora_model_dir),
|
||||||
SAFE_STR(sd_ctx_params->embedding_dir),
|
|
||||||
SAFE_STR(sd_ctx_params->photo_maker_path),
|
SAFE_STR(sd_ctx_params->photo_maker_path),
|
||||||
SAFE_STR(sd_ctx_params->tensor_type_rules),
|
SAFE_STR(sd_ctx_params->tensor_type_rules),
|
||||||
BOOL_STR(sd_ctx_params->vae_decode_only),
|
BOOL_STR(sd_ctx_params->vae_decode_only),
|
||||||
|
|||||||
@ -150,6 +150,11 @@ typedef struct {
|
|||||||
float rel_size_y;
|
float rel_size_y;
|
||||||
} sd_tiling_params_t;
|
} sd_tiling_params_t;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
const char* name;
|
||||||
|
const char* path;
|
||||||
|
} sd_embedding_t;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
const char* model_path;
|
const char* model_path;
|
||||||
const char* clip_l_path;
|
const char* clip_l_path;
|
||||||
@ -164,7 +169,8 @@ typedef struct {
|
|||||||
const char* taesd_path;
|
const char* taesd_path;
|
||||||
const char* control_net_path;
|
const char* control_net_path;
|
||||||
const char* lora_model_dir;
|
const char* lora_model_dir;
|
||||||
const char* embedding_dir;
|
const sd_embedding_t* embeddings;
|
||||||
|
uint32_t embedding_count;
|
||||||
const char* photo_maker_path;
|
const char* photo_maker_path;
|
||||||
const char* tensor_type_rules;
|
const char* tensor_type_rules;
|
||||||
bool vae_decode_only;
|
bool vae_decode_only;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user