fix: tensor loading thread count (#854)

This commit is contained in:
Wagner Bruna 2025-09-24 13:26:38 -03:00 committed by GitHub
parent 98ba155fc6
commit f3140eadbb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 21 additions and 19 deletions

View File

@ -141,7 +141,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
}
return true;
};
model_loader.load_tensors(on_load);
model_loader.load_tensors(on_load, 1);
readed_embeddings.push_back(embd_name);
if (embd) {
int64_t hidden_size = text_model->model.hidden_size;

View File

@ -445,7 +445,7 @@ struct ControlNet : public GGMLRunner {
guided_hint_cached = true;
}
bool load_from_file(const std::string& file_path) {
bool load_from_file(const std::string& file_path, int n_threads) {
LOG_INFO("loading control net from '%s'", file_path.c_str());
alloc_params_buffer();
std::map<std::string, ggml_tensor*> tensors;
@ -458,7 +458,7 @@ struct ControlNet : public GGMLRunner {
return false;
}
bool success = model_loader.load_tensors(tensors, ignore_tensors);
bool success = model_loader.load_tensors(tensors, ignore_tensors, n_threads);
if (!success) {
LOG_ERROR("load control net tensors from model loader failed");

View File

@ -164,7 +164,7 @@ struct ESRGAN : public GGMLRunner {
return "esrgan";
}
bool load_from_file(const std::string& file_path) {
bool load_from_file(const std::string& file_path, int n_threads) {
LOG_INFO("loading esrgan from '%s'", file_path.c_str());
alloc_params_buffer();
@ -177,7 +177,7 @@ struct ESRGAN : public GGMLRunner {
return false;
}
bool success = model_loader.load_tensors(esrgan_tensors);
bool success = model_loader.load_tensors(esrgan_tensors, {}, n_threads);
if (!success) {
LOG_ERROR("load esrgan tensors from model loader failed");

View File

@ -116,7 +116,7 @@ struct LoraModel : public GGMLRunner {
return "lora";
}
bool load_from_file(bool filter_tensor = false, int n_threads = 0) {
bool load_from_file(bool filter_tensor, int n_threads) {
LOG_INFO("loading LoRA from '%s'", file_path.c_str());
if (load_failed) {

View File

@ -1957,7 +1957,8 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
std::atomic<int64_t> copy_to_backend_time_ms(0);
std::atomic<int64_t> convert_time_ms(0);
int num_threads_to_use = n_threads_p > 0 ? n_threads_p : (int)std::thread::hardware_concurrency();
int num_threads_to_use = n_threads_p > 0 ? n_threads_p : get_num_physical_cores();
LOG_DEBUG("using %d threads for model loading", num_threads_to_use);
int64_t start_time = ggml_time_ms();
std::vector<TensorStorage> processed_tensor_storages;

View File

@ -591,7 +591,7 @@ struct PhotoMakerIDEmbed : public GGMLRunner {
return "id_embeds";
}
bool load_from_file(bool filter_tensor = false) {
bool load_from_file(bool filter_tensor, int n_threads) {
LOG_INFO("loading PhotoMaker ID Embeds from '%s'", file_path.c_str());
if (load_failed) {
@ -623,11 +623,11 @@ struct PhotoMakerIDEmbed : public GGMLRunner {
return true;
};
model_loader->load_tensors(on_new_tensor_cb);
model_loader->load_tensors(on_new_tensor_cb, n_threads);
alloc_params_buffer();
dry_run = false;
model_loader->load_tensors(on_new_tensor_cb);
model_loader->load_tensors(on_new_tensor_cb, n_threads);
LOG_DEBUG("finished loading PhotoMaker ID Embeds ");
return true;

View File

@ -531,7 +531,7 @@ public:
}
if (strlen(SAFE_STR(sd_ctx_params->photo_maker_path)) > 0) {
pmid_lora = std::make_shared<LoraModel>(backend, sd_ctx_params->photo_maker_path, "");
if (!pmid_lora->load_from_file(true)) {
if (!pmid_lora->load_from_file(true, n_threads)) {
LOG_WARN("load photomaker lora tensors from %s failed", sd_ctx_params->photo_maker_path);
return false;
}
@ -599,14 +599,14 @@ public:
if (!use_tiny_autoencoder) {
vae_params_mem_size = first_stage_model->get_params_buffer_size();
} else {
if (!tae_first_stage->load_from_file(taesd_path)) {
if (!tae_first_stage->load_from_file(taesd_path, n_threads)) {
return false;
}
vae_params_mem_size = tae_first_stage->get_params_buffer_size();
}
size_t control_net_params_mem_size = 0;
if (control_net) {
if (!control_net->load_from_file(SAFE_STR(sd_ctx_params->control_net_path))) {
if (!control_net->load_from_file(SAFE_STR(sd_ctx_params->control_net_path), n_threads)) {
return false;
}
control_net_params_mem_size = control_net->get_params_buffer_size();
@ -836,7 +836,7 @@ public:
return;
}
LoraModel lora(backend, file_path, is_high_noise ? "model.high_noise_" : "");
if (!lora.load_from_file()) {
if (!lora.load_from_file(false, n_threads)) {
LOG_WARN("load lora tensors from %s failed", file_path.c_str());
return;
}

View File

@ -222,7 +222,7 @@ struct TinyAutoEncoder : public GGMLRunner {
return "taesd";
}
bool load_from_file(const std::string& file_path) {
bool load_from_file(const std::string& file_path, int n_threads) {
LOG_INFO("loading taesd from '%s', decode_only = %s", file_path.c_str(), decode_only ? "true" : "false");
alloc_params_buffer();
std::map<std::string, ggml_tensor*> taesd_tensors;
@ -238,7 +238,7 @@ struct TinyAutoEncoder : public GGMLRunner {
return false;
}
bool success = model_loader.load_tensors(taesd_tensors, ignore_tensors);
bool success = model_loader.load_tensors(taesd_tensors, ignore_tensors, n_threads);
if (!success) {
LOG_ERROR("load tae tensors from model loader failed");

View File

@ -18,7 +18,8 @@ struct UpscalerGGML {
}
bool load_from_file(const std::string& esrgan_path,
bool offload_params_to_cpu) {
bool offload_params_to_cpu,
int n_threads) {
ggml_log_set(ggml_log_callback_default, nullptr);
#ifdef SD_USE_CUDA
LOG_DEBUG("Using CUDA backend");
@ -54,7 +55,7 @@ struct UpscalerGGML {
if (direct) {
esrgan_upscaler->enable_conv2d_direct();
}
if (!esrgan_upscaler->load_from_file(esrgan_path)) {
if (!esrgan_upscaler->load_from_file(esrgan_path, n_threads)) {
return false;
}
return true;
@ -124,7 +125,7 @@ upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path_c_str,
return NULL;
}
if (!upscaler_ctx->upscaler->load_from_file(esrgan_path, offload_params_to_cpu)) {
if (!upscaler_ctx->upscaler->load_from_file(esrgan_path, offload_params_to_cpu, n_threads)) {
delete upscaler_ctx->upscaler;
upscaler_ctx->upscaler = NULL;
free(upscaler_ctx);