fix: tensor loading thread count (#854)

This commit is contained in:
Wagner Bruna 2025-09-24 13:26:38 -03:00 committed by GitHub
parent 98ba155fc6
commit f3140eadbb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 21 additions and 19 deletions

View File

@@ -141,7 +141,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
 }
 return true;
 };
-model_loader.load_tensors(on_load);
+model_loader.load_tensors(on_load, 1);
 readed_embeddings.push_back(embd_name);
 if (embd) {
 int64_t hidden_size = text_model->model.hidden_size;

View File

@@ -445,7 +445,7 @@ struct ControlNet : public GGMLRunner {
 guided_hint_cached = true;
 }
-bool load_from_file(const std::string& file_path) {
+bool load_from_file(const std::string& file_path, int n_threads) {
 LOG_INFO("loading control net from '%s'", file_path.c_str());
 alloc_params_buffer();
 std::map<std::string, ggml_tensor*> tensors;
@@ -458,7 +458,7 @@ struct ControlNet : public GGMLRunner {
 return false;
 }
-bool success = model_loader.load_tensors(tensors, ignore_tensors);
+bool success = model_loader.load_tensors(tensors, ignore_tensors, n_threads);
 if (!success) {
 LOG_ERROR("load control net tensors from model loader failed");

View File

@@ -164,7 +164,7 @@ struct ESRGAN : public GGMLRunner {
 return "esrgan";
 }
-bool load_from_file(const std::string& file_path) {
+bool load_from_file(const std::string& file_path, int n_threads) {
 LOG_INFO("loading esrgan from '%s'", file_path.c_str());
 alloc_params_buffer();
@@ -177,7 +177,7 @@ struct ESRGAN : public GGMLRunner {
 return false;
 }
-bool success = model_loader.load_tensors(esrgan_tensors);
+bool success = model_loader.load_tensors(esrgan_tensors, {}, n_threads);
 if (!success) {
 LOG_ERROR("load esrgan tensors from model loader failed");

View File

@@ -116,7 +116,7 @@ struct LoraModel : public GGMLRunner {
 return "lora";
 }
-bool load_from_file(bool filter_tensor = false, int n_threads = 0) {
+bool load_from_file(bool filter_tensor, int n_threads) {
 LOG_INFO("loading LoRA from '%s'", file_path.c_str());
 if (load_failed) {

View File

@@ -1957,7 +1957,8 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
 std::atomic<int64_t> copy_to_backend_time_ms(0);
 std::atomic<int64_t> convert_time_ms(0);
-int num_threads_to_use = n_threads_p > 0 ? n_threads_p : (int)std::thread::hardware_concurrency();
+int num_threads_to_use = n_threads_p > 0 ? n_threads_p : get_num_physical_cores();
+LOG_DEBUG("using %d threads for model loading", num_threads_to_use);
 int64_t start_time = ggml_time_ms();
 std::vector<TensorStorage> processed_tensor_storages;

View File

@@ -591,7 +591,7 @@ struct PhotoMakerIDEmbed : public GGMLRunner {
 return "id_embeds";
 }
-bool load_from_file(bool filter_tensor = false) {
+bool load_from_file(bool filter_tensor, int n_threads) {
 LOG_INFO("loading PhotoMaker ID Embeds from '%s'", file_path.c_str());
 if (load_failed) {
@@ -623,11 +623,11 @@ struct PhotoMakerIDEmbed : public GGMLRunner {
 return true;
 };
-model_loader->load_tensors(on_new_tensor_cb);
+model_loader->load_tensors(on_new_tensor_cb, n_threads);
 alloc_params_buffer();
 dry_run = false;
-model_loader->load_tensors(on_new_tensor_cb);
+model_loader->load_tensors(on_new_tensor_cb, n_threads);
 LOG_DEBUG("finished loading PhotoMaker ID Embeds ");
 return true;

View File

@@ -531,7 +531,7 @@ public:
 }
 if (strlen(SAFE_STR(sd_ctx_params->photo_maker_path)) > 0) {
 pmid_lora = std::make_shared<LoraModel>(backend, sd_ctx_params->photo_maker_path, "");
-if (!pmid_lora->load_from_file(true)) {
+if (!pmid_lora->load_from_file(true, n_threads)) {
 LOG_WARN("load photomaker lora tensors from %s failed", sd_ctx_params->photo_maker_path);
 return false;
 }
@@ -599,14 +599,14 @@ public:
 if (!use_tiny_autoencoder) {
 vae_params_mem_size = first_stage_model->get_params_buffer_size();
 } else {
-if (!tae_first_stage->load_from_file(taesd_path)) {
+if (!tae_first_stage->load_from_file(taesd_path, n_threads)) {
 return false;
 }
 vae_params_mem_size = tae_first_stage->get_params_buffer_size();
 }
 size_t control_net_params_mem_size = 0;
 if (control_net) {
-if (!control_net->load_from_file(SAFE_STR(sd_ctx_params->control_net_path))) {
+if (!control_net->load_from_file(SAFE_STR(sd_ctx_params->control_net_path), n_threads)) {
 return false;
 }
 control_net_params_mem_size = control_net->get_params_buffer_size();
@@ -836,7 +836,7 @@ public:
 return;
 }
 LoraModel lora(backend, file_path, is_high_noise ? "model.high_noise_" : "");
-if (!lora.load_from_file()) {
+if (!lora.load_from_file(false, n_threads)) {
 LOG_WARN("load lora tensors from %s failed", file_path.c_str());
 return;
 }

View File

@@ -222,7 +222,7 @@ struct TinyAutoEncoder : public GGMLRunner {
 return "taesd";
 }
-bool load_from_file(const std::string& file_path) {
+bool load_from_file(const std::string& file_path, int n_threads) {
 LOG_INFO("loading taesd from '%s', decode_only = %s", file_path.c_str(), decode_only ? "true" : "false");
 alloc_params_buffer();
 std::map<std::string, ggml_tensor*> taesd_tensors;
@@ -238,7 +238,7 @@ struct TinyAutoEncoder : public GGMLRunner {
 return false;
 }
-bool success = model_loader.load_tensors(taesd_tensors, ignore_tensors);
+bool success = model_loader.load_tensors(taesd_tensors, ignore_tensors, n_threads);
 if (!success) {
 LOG_ERROR("load tae tensors from model loader failed");

View File

@@ -18,7 +18,8 @@ struct UpscalerGGML {
 }
 bool load_from_file(const std::string& esrgan_path,
-bool offload_params_to_cpu) {
+bool offload_params_to_cpu,
+int n_threads) {
 ggml_log_set(ggml_log_callback_default, nullptr);
 #ifdef SD_USE_CUDA
 LOG_DEBUG("Using CUDA backend");
@@ -54,7 +55,7 @@ struct UpscalerGGML {
 if (direct) {
 esrgan_upscaler->enable_conv2d_direct();
 }
-if (!esrgan_upscaler->load_from_file(esrgan_path)) {
+if (!esrgan_upscaler->load_from_file(esrgan_path, n_threads)) {
 return false;
 }
 return true;
@@ -124,7 +125,7 @@ upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path_c_str,
 return NULL;
 }
-if (!upscaler_ctx->upscaler->load_from_file(esrgan_path, offload_params_to_cpu)) {
+if (!upscaler_ctx->upscaler->load_from_file(esrgan_path, offload_params_to_cpu, n_threads)) {
 delete upscaler_ctx->upscaler;
 upscaler_ctx->upscaler = NULL;
 free(upscaler_ctx);