diff --git a/clip.hpp b/clip.hpp index 0070833..24c94f1 100644 --- a/clip.hpp +++ b/clip.hpp @@ -7,31 +7,6 @@ /*================================================== CLIPTokenizer ===================================================*/ -__STATIC_INLINE__ std::pair, std::string> extract_and_remove_lora(std::string text) { - std::regex re("]+)>"); - std::smatch matches; - std::unordered_map filename2multiplier; - - while (std::regex_search(text, matches, re)) { - std::string filename = matches[1].str(); - float multiplier = std::stof(matches[2].str()); - - text = std::regex_replace(text, re, "", std::regex_constants::format_first_only); - - if (multiplier == 0.f) { - continue; - } - - if (filename2multiplier.find(filename) == filename2multiplier.end()) { - filename2multiplier[filename] = multiplier; - } else { - filename2multiplier[filename] += multiplier; - } - } - - return std::make_pair(filename2multiplier, text); -} - __STATIC_INLINE__ std::vector> bytes_to_unicode() { std::vector> byte_unicode_pairs; std::set byte_set; diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 2829f4d..c55da35 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -507,7 +507,7 @@ struct SDContextParams { std::string lora_model_dir; std::map embedding_map; - std::vector embedding_array; + std::vector embedding_vec; rng_type_t rng_type = CUDA_RNG; rng_type_t sampler_rng_type = RNG_TYPE_COUNT; @@ -952,13 +952,13 @@ struct SDContextParams { } sd_ctx_params_t to_sd_ctx_params_t(bool vae_decode_only, bool free_params_immediately, bool taesd_preview) { - embedding_array.clear(); - embedding_array.reserve(embedding_map.size()); + embedding_vec.clear(); + embedding_vec.reserve(embedding_map.size()); for (const auto& kv : embedding_map) { sd_embedding_t item; item.name = kv.first.c_str(); item.path = kv.second.c_str(); - embedding_array.emplace_back(item); + embedding_vec.emplace_back(item); } sd_ctx_params_t sd_ctx_params = { @@ -975,8 +975,8 @@ struct SDContextParams { taesd_path.c_str(), control_net_path.c_str(), lora_model_dir.c_str(), - embedding_array.data(), - static_cast(embedding_array.size()), + embedding_vec.data(), + static_cast(embedding_vec.size()), photo_maker_path.c_str(), tensor_type_rules.c_str(), vae_decode_only, @@ -1030,6 +1030,15 @@ static std::string vec_str_to_string(const std::vector& v) { return oss.str(); } +static bool is_absolute_path(const std::string& p) { +#ifdef _WIN32 + // Windows: C:/path or C:\path + return p.size() > 1 && std::isalpha(static_cast(p[0])) && p[1] == ':'; +#else + return !p.empty() && p[0] == '/'; +#endif +} + struct SDGenerationParams { std::string prompt; std::string negative_prompt; @@ -1072,6 +1081,10 @@ struct SDGenerationParams { int upscale_repeats = 1; + std::map lora_map; + std::map high_noise_lora_map; + std::vector lora_vec; + SDGenerationParams() { sd_sample_params_init(&sample_params); sd_sample_params_init(&high_noise_sample_params); @@ -1442,7 +1455,88 @@ struct SDGenerationParams { return options; } - bool process_and_check(SDMode mode) { + void extract_and_remove_lora(const std::string& lora_model_dir) { + static const std::regex re(R"(]+):([^>]+)>)"); + static const std::vector valid_ext = {".pt", ".safetensors", ".gguf"}; + std::smatch m; + + std::string tmp = prompt; + + while (std::regex_search(tmp, m, re)) { + std::string raw_path = m[1].str(); + const std::string raw_mul = m[2].str(); + + float mul = 0.f; + try { + mul = std::stof(raw_mul); + } catch (...) { + tmp = m.suffix().str(); + prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only); + continue; + } + + bool is_high_noise = false; + static const std::string prefix = "|high_noise|"; + if (raw_path.rfind(prefix, 0) == 0) { + raw_path.erase(0, prefix.size()); + is_high_noise = true; + } + + fs::path final_path; + if (is_absolute_path(raw_path)) { + final_path = raw_path; + } else { + final_path = fs::path(lora_model_dir) / raw_path; + } + if (!fs::exists(final_path)) { + bool found = false; + for (const auto& ext : valid_ext) { + fs::path try_path = final_path; + try_path += ext; + if (fs::exists(try_path)) { + final_path = try_path; + found = true; + break; + } + } + if (!found) { + printf("can not found lora %s\n", final_path.lexically_normal().string().c_str()); + tmp = m.suffix().str(); + prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only); + continue; + } + } + + const std::string key = final_path.lexically_normal().string(); + + if (is_high_noise) + high_noise_lora_map[key] += mul; + else + lora_map[key] += mul; + + prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only); + + tmp = m.suffix().str(); + } + + for (const auto& kv : lora_map) { + sd_lora_t item; + item.is_high_noise = false; + item.path = kv.first.c_str(); + item.multiplier = kv.second; + lora_vec.emplace_back(item); + } + + for (const auto& kv : high_noise_lora_map) { + sd_lora_t item; + item.is_high_noise = true; + item.path = kv.first.c_str(); + item.multiplier = kv.second; + lora_vec.emplace_back(item); + } + } + + bool process_and_check(SDMode mode, const std::string& lora_model_dir) { if (width <= 0) { fprintf(stderr, "error: the width must be greater than 0\n"); return false; @@ -1553,14 +1647,44 @@ struct SDGenerationParams { seed = rand(); } + extract_and_remove_lora(lora_model_dir); + return true; } std::string to_string() const { char* sample_params_str = sd_sample_params_to_str(&sample_params); char* high_noise_sample_params_str = sd_sample_params_to_str(&high_noise_sample_params); + + std::ostringstream lora_ss; + lora_ss << "{\n"; + for (auto it = lora_map.begin(); it != lora_map.end(); ++it) { + lora_ss << " \"" << it->first << "\": \"" << it->second << "\""; + if (std::next(it) != lora_map.end()) { + lora_ss << ","; + } + lora_ss << "\n"; + } + lora_ss << " }"; + std::string loras_str = lora_ss.str(); + + lora_ss = std::ostringstream(); + ; + lora_ss << "{\n"; + for (auto it = high_noise_lora_map.begin(); it != high_noise_lora_map.end(); ++it) { + lora_ss << " \"" << it->first << "\": \"" << it->second << "\""; + if (std::next(it) != high_noise_lora_map.end()) { + lora_ss << ","; + } + lora_ss << "\n"; + } + lora_ss << " }"; + std::string high_noise_loras_str = lora_ss.str(); + std::ostringstream oss; oss << "SDGenerationParams {\n" + << " loras: \"" << loras_str << "\",\n" + << " high_noise_loras: \"" << high_noise_loras_str << "\",\n" << " prompt: \"" << prompt << "\",\n" << " negative_prompt: \"" << negative_prompt << "\",\n" << " clip_skip: " << clip_skip << ",\n" @@ -1626,7 +1750,9 @@ void parse_args(int argc, const char** argv, SDCliParams& cli_params, SDContextP exit(cli_params.normal_exit ? 0 : 1); } - if (!cli_params.process_and_check() || !ctx_params.process_and_check(cli_params.mode) || !gen_params.process_and_check(cli_params.mode)) { + if (!cli_params.process_and_check() || + !ctx_params.process_and_check(cli_params.mode) || + !gen_params.process_and_check(cli_params.mode, ctx_params.lora_model_dir)) { print_usage(argc, argv, options_vec); exit(1); } @@ -2139,6 +2265,8 @@ int main(int argc, const char* argv[]) { if (cli_params.mode == IMG_GEN) { sd_img_gen_params_t img_gen_params = { + gen_params.lora_vec.data(), + static_cast(gen_params.lora_vec.size()), gen_params.prompt.c_str(), gen_params.negative_prompt.c_str(), gen_params.clip_skip, @@ -2170,6 +2298,8 @@ int main(int argc, const char* argv[]) { num_results = gen_params.batch_count; } else if (cli_params.mode == VID_GEN) { sd_vid_gen_params_t vid_gen_params = { + gen_params.lora_vec.data(), + static_cast(gen_params.lora_vec.size()), gen_params.prompt.c_str(), gen_params.negative_prompt.c_str(), gen_params.clip_skip, diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 6ee0fca..d381bf6 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -937,28 +937,17 @@ public: float multiplier, ggml_backend_t backend, LoraModel::filter_t lora_tensor_filter = nullptr) { - std::string lora_name = lora_id; - std::string high_noise_tag = "|high_noise|"; - bool is_high_noise = false; - if (starts_with(lora_name, high_noise_tag)) { - lora_name = lora_name.substr(high_noise_tag.size()); + std::string lora_path = lora_id; + static std::string high_noise_tag = "|high_noise|"; + bool is_high_noise = false; + if (starts_with(lora_path, high_noise_tag)) { + lora_path = lora_path.substr(high_noise_tag.size()); is_high_noise = true; - LOG_DEBUG("high noise lora: %s", lora_name.c_str()); + LOG_DEBUG("high noise lora: %s", lora_path.c_str()); } - std::string st_file_path = path_join(lora_model_dir, lora_name + ".safetensors"); - std::string ckpt_file_path = path_join(lora_model_dir, lora_name + ".ckpt"); - std::string file_path; - if (file_exists(st_file_path)) { - file_path = st_file_path; - } else if (file_exists(ckpt_file_path)) { - file_path = ckpt_file_path; - } else { - LOG_WARN("can not find %s or %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str()); - return nullptr; - } - auto lora = std::make_shared(lora_id, backend, file_path, is_high_noise ? "model.high_noise_" : "", version); + auto lora = std::make_shared(lora_id, backend, lora_path, is_high_noise ? "model.high_noise_" : "", version); if (!lora->load_from_file(n_threads, lora_tensor_filter)) { - LOG_WARN("load lora tensors from %s failed", file_path.c_str()); + LOG_WARN("load lora tensors from %s failed", lora_path.c_str()); return nullptr; } @@ -1143,12 +1132,15 @@ public: } } - std::string apply_loras_from_prompt(const std::string& prompt) { - auto result_pair = extract_and_remove_lora(prompt); - std::unordered_map lora_f2m = result_pair.first; // lora_name -> multiplier - - for (auto& kv : lora_f2m) { - LOG_DEBUG("lora %s:%.2f", kv.first.c_str(), kv.second); + void apply_loras(const sd_lora_t* loras, uint32_t lora_count) { + std::unordered_map lora_f2m; + for (int i = 0; i < lora_count; i++) { + std::string lora_id = SAFE_STR(loras[i].path); + if (loras[i].is_high_noise) { + lora_id = "|high_noise|" + lora_id; + } + lora_f2m[lora_id] = loras[i].multiplier; + LOG_DEBUG("lora %s:%.2f", lora_id.c_str(), loras[i].multiplier); } int64_t t0 = ggml_time_ms(); if (apply_lora_immediately) { @@ -1159,9 +1151,7 @@ public: int64_t t1 = ggml_time_ms(); if (!lora_f2m.empty()) { LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); - LOG_DEBUG("prompt after extract and remove lora: \"%s\"", result_pair.second.c_str()); } - return result_pair.second; } ggml_tensor* id_encoder(ggml_context* work_ctx, @@ -2815,8 +2805,6 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, int sample_steps = sigmas.size() - 1; int64_t t0 = ggml_time_ms(); - // Apply lora - prompt = sd_ctx->sd->apply_loras_from_prompt(prompt); // Photo Maker std::string prompt_text_only; @@ -3188,6 +3176,9 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g size_t t0 = ggml_time_ms(); + // Apply lora + sd_ctx->sd->apply_loras(sd_img_gen_params->loras, sd_img_gen_params->lora_count); + enum sample_method_t sample_method = sd_img_gen_params->sample_params.sample_method; if (sample_method == SAMPLE_METHOD_COUNT) { sample_method = sd_get_default_sample_method(sd_ctx); @@ -3487,7 +3478,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s int64_t t0 = ggml_time_ms(); // Apply lora - prompt = sd_ctx->sd->apply_loras_from_prompt(prompt); + sd_ctx->sd->apply_loras(sd_vid_gen_params->loras, sd_vid_gen_params->lora_count); ggml_tensor* init_latent = nullptr; ggml_tensor* clip_vision_output = nullptr; diff --git a/stable-diffusion.h b/stable-diffusion.h index cc5f4fa..601b79f 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -242,6 +242,14 @@ typedef struct { } sd_easycache_params_t; typedef struct { + bool is_high_noise; + float multiplier; + const char* path; +} sd_lora_t; + +typedef struct { + const sd_lora_t* loras; + uint32_t lora_count; const char* prompt; const char* negative_prompt; int clip_skip; @@ -265,6 +273,8 @@ typedef struct { } sd_img_gen_params_t; typedef struct { + const sd_lora_t* loras; + uint32_t lora_count; const char* prompt; const char* negative_prompt; int clip_skip; diff --git a/util.cpp b/util.cpp index 4a59852..680ff80 100644 --- a/util.cpp +++ b/util.cpp @@ -95,20 +95,6 @@ bool is_directory(const std::string& path) { return (attributes != INVALID_FILE_ATTRIBUTES && (attributes & FILE_ATTRIBUTE_DIRECTORY)); } -std::string get_full_path(const std::string& dir, const std::string& filename) { - std::string full_path = dir + "\\" + filename; - - WIN32_FIND_DATA find_file_data; - HANDLE hFind = FindFirstFile(full_path.c_str(), &find_file_data); - - if (hFind != INVALID_HANDLE_VALUE) { - FindClose(hFind); - return full_path; - } else { - return ""; - } -} - #else // Unix #include #include @@ -123,26 +109,6 @@ bool is_directory(const std::string& path) { return (stat(path.c_str(), &buffer) == 0 && S_ISDIR(buffer.st_mode)); } -// TODO: add windows version -std::string get_full_path(const std::string& dir, const std::string& filename) { - DIR* dp = opendir(dir.c_str()); - - if (dp != nullptr) { - struct dirent* entry; - - while ((entry = readdir(dp)) != nullptr) { - if (strcasecmp(entry->d_name, filename.c_str()) == 0) { - closedir(dp); - return dir + "/" + entry->d_name; - } - } - - closedir(dp); - } - - return ""; -} - #endif // get_num_physical_cores is copy from diff --git a/util.h b/util.h index 61ca933..dd4a0c3 100644 --- a/util.h +++ b/util.h @@ -22,7 +22,6 @@ int round_up_to(int value, int base); bool file_exists(const std::string& filename); bool is_directory(const std::string& path); -std::string get_full_path(const std::string& dir, const std::string& filename); std::u32string utf8_to_utf32(const std::string& utf8_str); std::string utf32_to_utf8(const std::u32string& utf32_str);