From 985aedda32bfd3c3e39d0f6d702483d2ad22a870 Mon Sep 17 00:00:00 2001 From: leejet Date: Thu, 4 Dec 2025 23:31:55 +0800 Subject: [PATCH] refactor: optimize the handling of pred type (#1048) --- examples/cli/main.cpp | 4 +- stable-diffusion.cpp | 160 +++++++++++++++++------------------------- stable-diffusion.h | 3 +- 3 files changed, 68 insertions(+), 99 deletions(-) diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 070dfa0..bf42f5a 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -409,7 +409,7 @@ struct SDCliParams { return -1; } const char* preview = argv[index]; - int preview_found = -1; + int preview_found = -1; for (int m = 0; m < PREVIEW_COUNT; m++) { if (!strcmp(preview, previews_str[m])) { preview_found = m; @@ -515,7 +515,7 @@ struct SDContextParams { bool chroma_use_t5_mask = false; int chroma_t5_mask_pad = 1; - prediction_t prediction = DEFAULT_PRED; + prediction_t prediction = PREDICTION_COUNT; lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO; sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f}; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 831beb9..e554f09 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -707,7 +707,7 @@ public: return false; } - // LOG_DEBUG("model size = %.2fMB", total_size / 1024.0 / 1024.0); + LOG_DEBUG("finished loaded file"); { size_t clip_params_mem_size = cond_stage_model->get_params_buffer_size(); @@ -782,8 +782,59 @@ public: ggml_backend_is_cpu(clip_backend) ? "RAM" : "VRAM"); } - if (sd_ctx_params->prediction != DEFAULT_PRED) { - switch (sd_ctx_params->prediction) { + // init denoiser + { + prediction_t pred_type = sd_ctx_params->prediction; + float flow_shift = sd_ctx_params->flow_shift; + + if (pred_type == PREDICTION_COUNT) { + if (sd_version_is_sd2(version)) { + // check is_using_v_parameterization_for_sd2 + if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) { + pred_type = V_PRED; + } else { + pred_type = EPS_PRED; + } + } else if (sd_version_is_sdxl(version)) { + if (tensor_storage_map.find("edm_vpred.sigma_max") != tensor_storage_map.end()) { + // CosXL models + // TODO: get sigma_min and sigma_max values from file + pred_type = EDM_V_PRED; + } else if (tensor_storage_map.find("v_pred") != tensor_storage_map.end()) { + pred_type = V_PRED; + } else { + pred_type = EPS_PRED; + } + } else if (sd_version_is_sd3(version) || + sd_version_is_wan(version) || + sd_version_is_qwen_image(version) || + sd_version_is_z_image(version)) { + pred_type = FLOW_PRED; + if (flow_shift == INFINITY) { + if (sd_version_is_wan(version)) { + flow_shift = 5.f; + } else { + flow_shift = 3.f; + } + } + } else if (sd_version_is_flux(version)) { + pred_type = FLUX_FLOW_PRED; + if (flow_shift == INFINITY) { + flow_shift = 1.0f; // TODO: validate + for (const auto& [name, tensor_storage] : tensor_storage_map) { + if (starts_with(name, "model.diffusion_model.guidance_in.in_layer.weight")) { + flow_shift = 1.15f; + } + } + } + } else if (sd_version_is_flux2(version)) { + pred_type = FLUX2_FLOW_PRED; + } else { + pred_type = EPS_PRED; + } + } + + switch (pred_type) { case EPS_PRED: LOG_INFO("running in eps-prediction mode"); break; @@ -795,22 +846,14 @@ public: LOG_INFO("running in v-prediction EDM mode"); denoiser = std::make_shared(); break; - case SD3_FLOW_PRED: { + case FLOW_PRED: { LOG_INFO("running in FLOW mode"); - float shift = sd_ctx_params->flow_shift; - if (shift == INFINITY) { - shift = 3.0; - } - denoiser = std::make_shared(shift); + denoiser = std::make_shared(flow_shift); break; } case FLUX_FLOW_PRED: { LOG_INFO("running in Flux FLOW mode"); - float shift = sd_ctx_params->flow_shift; - if (shift == INFINITY) { - shift = 3.0; - } - denoiser = std::make_shared(shift); + denoiser = std::make_shared(flow_shift); break; } case FLUX2_FLOW_PRED: { @@ -819,93 +862,21 @@ public: break; } default: { - LOG_ERROR("Unknown parametrization %i", sd_ctx_params->prediction); + LOG_ERROR("Unknown predition type %i", pred_type); + ggml_free(ctx); return false; } } - } else { - if (sd_version_is_sd2(version)) { - // check is_using_v_parameterization_for_sd2 - if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) { - is_using_v_parameterization = true; - } - } else if (sd_version_is_sdxl(version)) { - if (tensor_storage_map.find("edm_vpred.sigma_max") != tensor_storage_map.end()) { - // CosXL models - // TODO: get sigma_min and sigma_max values from file - is_using_edm_v_parameterization = true; - } - if (tensor_storage_map.find("v_pred") != tensor_storage_map.end()) { - is_using_v_parameterization = true; - } - } else if (version == VERSION_SVD) { - // TODO: V_PREDICTION_EDM - is_using_v_parameterization = true; - } - if (sd_version_is_sd3(version)) { - LOG_INFO("running in FLOW mode"); - float shift = sd_ctx_params->flow_shift; - if (shift == INFINITY) { - shift = 3.0; + auto comp_vis_denoiser = std::dynamic_pointer_cast(denoiser); + if (comp_vis_denoiser) { + for (int i = 0; i < TIMESTEPS; i++) { + comp_vis_denoiser->sigmas[i] = std::sqrt((1 - ((float*)alphas_cumprod_tensor->data)[i]) / ((float*)alphas_cumprod_tensor->data)[i]); + comp_vis_denoiser->log_sigmas[i] = std::log(comp_vis_denoiser->sigmas[i]); } - denoiser = std::make_shared(shift); - } else if (sd_version_is_flux(version)) { - LOG_INFO("running in Flux FLOW mode"); - float shift = sd_ctx_params->flow_shift; - if (shift == INFINITY) { - shift = 1.0f; // TODO: validate - for (const auto& [name, tensor_storage] : tensor_storage_map) { - if (starts_with(name, "model.diffusion_model.guidance_in.in_layer.weight")) { - shift = 1.15f; - } - } - } - denoiser = std::make_shared(shift); - } else if (sd_version_is_flux2(version)) { - LOG_INFO("running in Flux2 FLOW mode"); - denoiser = std::make_shared(); - } else if (sd_version_is_wan(version)) { - LOG_INFO("running in FLOW mode"); - float shift = sd_ctx_params->flow_shift; - if (shift == INFINITY) { - shift = 5.0; - } - denoiser = std::make_shared(shift); - } else if (sd_version_is_qwen_image(version)) { - LOG_INFO("running in FLOW mode"); - float shift = sd_ctx_params->flow_shift; - if (shift == INFINITY) { - shift = 3.0; - } - denoiser = std::make_shared(shift); - } else if (sd_version_is_z_image(version)) { - LOG_INFO("running in FLOW mode"); - float shift = sd_ctx_params->flow_shift; - if (shift == INFINITY) { - shift = 3.0f; - } - denoiser = std::make_shared(shift); - } else if (is_using_v_parameterization) { - LOG_INFO("running in v-prediction mode"); - denoiser = std::make_shared(); - } else if (is_using_edm_v_parameterization) { - LOG_INFO("running in v-prediction EDM mode"); - denoiser = std::make_shared(); - } else { - LOG_INFO("running in eps-prediction mode"); } } - auto comp_vis_denoiser = std::dynamic_pointer_cast(denoiser); - if (comp_vis_denoiser) { - for (int i = 0; i < TIMESTEPS; i++) { - comp_vis_denoiser->sigmas[i] = std::sqrt((1 - ((float*)alphas_cumprod_tensor->data)[i]) / ((float*)alphas_cumprod_tensor->data)[i]); - comp_vis_denoiser->log_sigmas[i] = std::log(comp_vis_denoiser->sigmas[i]); - } - } - - LOG_DEBUG("finished loaded file"); ggml_free(ctx); use_tiny_autoencoder = use_tiny_autoencoder && !sd_ctx_params->tae_preview_only; return true; @@ -2426,7 +2397,6 @@ enum scheduler_t str_to_scheduler(const char* str) { } const char* prediction_to_str[] = { - "default", "eps", "v", "edm_v", @@ -2512,7 +2482,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) { sd_ctx_params->wtype = SD_TYPE_COUNT; sd_ctx_params->rng_type = CUDA_RNG; sd_ctx_params->sampler_rng_type = RNG_TYPE_COUNT; - sd_ctx_params->prediction = DEFAULT_PRED; + sd_ctx_params->prediction = PREDICTION_COUNT; sd_ctx_params->lora_apply_mode = LORA_APPLY_AUTO; sd_ctx_params->offload_params_to_cpu = false; sd_ctx_params->keep_clip_on_cpu = false; diff --git a/stable-diffusion.h b/stable-diffusion.h index b0d3ee6..e34cdec 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -65,11 +65,10 @@ enum scheduler_t { }; enum prediction_t { - DEFAULT_PRED, EPS_PRED, V_PRED, EDM_V_PRED, - SD3_FLOW_PRED, + FLOW_PRED, FLUX_FLOW_PRED, FLUX2_FLOW_PRED, PREDICTION_COUNT