refactor: optimize the handling of pred type (#1048)

leejet 2025-12-04 23:31:55 +08:00 committed by GitHub
parent 3f3610b5cd
commit 985aedda32
3 changed files with 68 additions and 99 deletions

View File

@@ -409,7 +409,7 @@ struct SDCliParams {
                 return -1;
             }
             const char* preview = argv[index];
             int preview_found = -1;
             for (int m = 0; m < PREVIEW_COUNT; m++) {
                 if (!strcmp(preview, previews_str[m])) {
                     preview_found = m;
@@ -515,7 +515,7 @@ struct SDContextParams {
     bool chroma_use_t5_mask = false;
     int chroma_t5_mask_pad = 1;
-    prediction_t prediction = DEFAULT_PRED;
+    prediction_t prediction = PREDICTION_COUNT;
     lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO;
     sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
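
Note: with DEFAULT_PRED removed, PREDICTION_COUNT now doubles as the "not set, auto-detect from the model" sentinel for this field. A minimal caller-side sketch (names as they appear elsewhere in this diff; the commented override is only illustrative):

    sd_ctx_params_t params;
    sd_ctx_params_init(&params);    // leaves params.prediction = PREDICTION_COUNT, i.e. auto-detect
    // params.prediction = V_PRED;  // force a specific prediction type instead of auto-detection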

View File

@@ -707,7 +707,7 @@ public:
             return false;
         }
-        // LOG_DEBUG("model size = %.2fMB", total_size / 1024.0 / 1024.0);
+        LOG_DEBUG("finished loaded file");
         {
             size_t clip_params_mem_size = cond_stage_model->get_params_buffer_size();
@@ -782,8 +782,59 @@ public:
                      ggml_backend_is_cpu(clip_backend) ? "RAM" : "VRAM");
         }
-        if (sd_ctx_params->prediction != DEFAULT_PRED) {
-            switch (sd_ctx_params->prediction) {
+        // init denoiser
+        {
+            prediction_t pred_type = sd_ctx_params->prediction;
+            float flow_shift = sd_ctx_params->flow_shift;
+            if (pred_type == PREDICTION_COUNT) {
+                if (sd_version_is_sd2(version)) {
+                    // check is_using_v_parameterization_for_sd2
+                    if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
+                        pred_type = V_PRED;
+                    } else {
+                        pred_type = EPS_PRED;
+                    }
+                } else if (sd_version_is_sdxl(version)) {
+                    if (tensor_storage_map.find("edm_vpred.sigma_max") != tensor_storage_map.end()) {
+                        // CosXL models
+                        // TODO: get sigma_min and sigma_max values from file
+                        pred_type = EDM_V_PRED;
+                    } else if (tensor_storage_map.find("v_pred") != tensor_storage_map.end()) {
+                        pred_type = V_PRED;
+                    } else {
+                        pred_type = EPS_PRED;
+                    }
+                } else if (sd_version_is_sd3(version) ||
+                           sd_version_is_wan(version) ||
+                           sd_version_is_qwen_image(version) ||
+                           sd_version_is_z_image(version)) {
+                    pred_type = FLOW_PRED;
+                    if (flow_shift == INFINITY) {
+                        if (sd_version_is_wan(version)) {
+                            flow_shift = 5.f;
+                        } else {
+                            flow_shift = 3.f;
+                        }
+                    }
+                } else if (sd_version_is_flux(version)) {
+                    pred_type = FLUX_FLOW_PRED;
+                    if (flow_shift == INFINITY) {
+                        flow_shift = 1.0f; // TODO: validate
+                        for (const auto& [name, tensor_storage] : tensor_storage_map) {
+                            if (starts_with(name, "model.diffusion_model.guidance_in.in_layer.weight")) {
+                                flow_shift = 1.15f;
+                            }
+                        }
+                    }
+                } else if (sd_version_is_flux2(version)) {
+                    pred_type = FLUX2_FLOW_PRED;
+                } else {
+                    pred_type = EPS_PRED;
+                }
+            }
+            switch (pred_type) {
                 case EPS_PRED:
                     LOG_INFO("running in eps-prediction mode");
                     break;
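
flow_shift uses the same sentinel idea: INFINITY means "substitute a per-family default" (5.0 for Wan, 3.0 for SD3/Qwen-Image/Z-Image, 1.15 for Flux checkpoints with guidance_in weights, otherwise 1.0). A caller-side sketch, assuming flow_shift is left at INFINITY when unset:

    sd_ctx_params_t params;
    sd_ctx_params_init(&params);
    // params.flow_shift == INFINITY  -> loader picks the per-family default listed above
    params.flow_shift = 3.0f;         // hypothetical explicit override; skips the defaults entirely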
@@ -795,22 +846,14 @@ public:
                     LOG_INFO("running in v-prediction EDM mode");
                     denoiser = std::make_shared<EDMVDenoiser>();
                     break;
-                case SD3_FLOW_PRED: {
+                case FLOW_PRED: {
                     LOG_INFO("running in FLOW mode");
-                    float shift = sd_ctx_params->flow_shift;
-                    if (shift == INFINITY) {
-                        shift = 3.0;
-                    }
-                    denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
+                    denoiser = std::make_shared<DiscreteFlowDenoiser>(flow_shift);
                     break;
                 }
                 case FLUX_FLOW_PRED: {
                     LOG_INFO("running in Flux FLOW mode");
-                    float shift = sd_ctx_params->flow_shift;
-                    if (shift == INFINITY) {
-                        shift = 3.0;
-                    }
-                    denoiser = std::make_shared<FluxFlowDenoiser>(shift);
+                    denoiser = std::make_shared<FluxFlowDenoiser>(flow_shift);
                     break;
                 }
                 case FLUX2_FLOW_PRED: {
@@ -819,93 +862,21 @@ public:
                     break;
                 }
                 default: {
-                    LOG_ERROR("Unknown parametrization %i", sd_ctx_params->prediction);
-                    ggml_free(ctx);
+                    LOG_ERROR("Unknown predition type %i", pred_type);
                     return false;
                 }
             }
-        } else {
-            if (sd_version_is_sd2(version)) {
-                // check is_using_v_parameterization_for_sd2
-                if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
-                    is_using_v_parameterization = true;
-                }
-            } else if (sd_version_is_sdxl(version)) {
-                if (tensor_storage_map.find("edm_vpred.sigma_max") != tensor_storage_map.end()) {
-                    // CosXL models
-                    // TODO: get sigma_min and sigma_max values from file
-                    is_using_edm_v_parameterization = true;
-                }
-                if (tensor_storage_map.find("v_pred") != tensor_storage_map.end()) {
-                    is_using_v_parameterization = true;
-                }
-            } else if (version == VERSION_SVD) {
-                // TODO: V_PREDICTION_EDM
-                is_using_v_parameterization = true;
-            }
-            if (sd_version_is_sd3(version)) {
-                LOG_INFO("running in FLOW mode");
-                float shift = sd_ctx_params->flow_shift;
-                if (shift == INFINITY) {
-                    shift = 3.0;
-                }
-                denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
-            } else if (sd_version_is_flux(version)) {
-                LOG_INFO("running in Flux FLOW mode");
-                float shift = sd_ctx_params->flow_shift;
-                if (shift == INFINITY) {
-                    shift = 1.0f; // TODO: validate
-                    for (const auto& [name, tensor_storage] : tensor_storage_map) {
-                        if (starts_with(name, "model.diffusion_model.guidance_in.in_layer.weight")) {
-                            shift = 1.15f;
-                        }
-                    }
-                }
-                denoiser = std::make_shared<FluxFlowDenoiser>(shift);
-            } else if (sd_version_is_flux2(version)) {
-                LOG_INFO("running in Flux2 FLOW mode");
-                denoiser = std::make_shared<Flux2FlowDenoiser>();
-            } else if (sd_version_is_wan(version)) {
-                LOG_INFO("running in FLOW mode");
-                float shift = sd_ctx_params->flow_shift;
-                if (shift == INFINITY) {
-                    shift = 5.0;
-                }
-                denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
-            } else if (sd_version_is_qwen_image(version)) {
-                LOG_INFO("running in FLOW mode");
-                float shift = sd_ctx_params->flow_shift;
-                if (shift == INFINITY) {
-                    shift = 3.0;
-                }
-                denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
-            } else if (sd_version_is_z_image(version)) {
-                LOG_INFO("running in FLOW mode");
-                float shift = sd_ctx_params->flow_shift;
-                if (shift == INFINITY) {
-                    shift = 3.0f;
-                }
-                denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
-            } else if (is_using_v_parameterization) {
-                LOG_INFO("running in v-prediction mode");
-                denoiser = std::make_shared<CompVisVDenoiser>();
-            } else if (is_using_edm_v_parameterization) {
-                LOG_INFO("running in v-prediction EDM mode");
-                denoiser = std::make_shared<EDMVDenoiser>();
-            } else {
-                LOG_INFO("running in eps-prediction mode");
-            }
-        }
-        auto comp_vis_denoiser = std::dynamic_pointer_cast<CompVisDenoiser>(denoiser);
-        if (comp_vis_denoiser) {
-            for (int i = 0; i < TIMESTEPS; i++) {
-                comp_vis_denoiser->sigmas[i] = std::sqrt((1 - ((float*)alphas_cumprod_tensor->data)[i]) / ((float*)alphas_cumprod_tensor->data)[i]);
-                comp_vis_denoiser->log_sigmas[i] = std::log(comp_vis_denoiser->sigmas[i]);
-            }
-        }
-        LOG_DEBUG("finished loaded file");
+            auto comp_vis_denoiser = std::dynamic_pointer_cast<CompVisDenoiser>(denoiser);
+            if (comp_vis_denoiser) {
+                for (int i = 0; i < TIMESTEPS; i++) {
+                    comp_vis_denoiser->sigmas[i] = std::sqrt((1 - ((float*)alphas_cumprod_tensor->data)[i]) / ((float*)alphas_cumprod_tensor->data)[i]);
+                    comp_vis_denoiser->log_sigmas[i] = std::log(comp_vis_denoiser->sigmas[i]);
+                }
+            }
+        }
         ggml_free(ctx);
         use_tiny_autoencoder = use_tiny_autoencoder && !sd_ctx_params->tae_preview_only;
         return true;
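
The CompVisDenoiser branch kept above only moved inside the new init block; the math is unchanged. It converts the DDPM cumulative alphas into k-diffusion-style sigmas, per timestep i:

    sigmas[i]     = sqrt((1 - alphas_cumprod[i]) / alphas_cumprod[i])
    log_sigmas[i] = log(sigmas[i])

i.e. the usual variance-preserving mapping sigma^2 = (1 - a_bar) / a_bar, with the log table precomputed alongside it.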
@@ -2426,7 +2397,6 @@ enum scheduler_t str_to_scheduler(const char* str) {
 }
 const char* prediction_to_str[] = {
-    "default",
     "eps",
     "v",
     "edm_v",
@@ -2512,7 +2482,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
     sd_ctx_params->wtype = SD_TYPE_COUNT;
     sd_ctx_params->rng_type = CUDA_RNG;
     sd_ctx_params->sampler_rng_type = RNG_TYPE_COUNT;
-    sd_ctx_params->prediction = DEFAULT_PRED;
+    sd_ctx_params->prediction = PREDICTION_COUNT;
     sd_ctx_params->lora_apply_mode = LORA_APPLY_AUTO;
     sd_ctx_params->offload_params_to_cpu = false;
     sd_ctx_params->keep_clip_on_cpu = false;

View File

@@ -65,11 +65,10 @@ enum scheduler_t {
 };
 enum prediction_t {
-    DEFAULT_PRED,
     EPS_PRED,
     V_PRED,
     EDM_V_PRED,
-    SD3_FLOW_PRED,
+    FLOW_PRED,
     FLUX_FLOW_PRED,
     FLUX2_FLOW_PRED,
     PREDICTION_COUNT
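
Removing DEFAULT_PRED shifts every remaining enumerator down by one (EPS_PRED is now 0) and SD3_FLOW_PRED becomes FLOW_PRED, which is also why the prediction_to_str table earlier in this commit drops its "default" entry: the string array stays index-aligned with prediction_t, and PREDICTION_COUNT keeps working both as the array bound and as the "auto" sentinel. A compile-time guard one could add (a sketch, not part of this commit):

    static_assert(EPS_PRED == 0 && PREDICTION_COUNT == 6,
                  "prediction_t was renumbered; update any code that stores these values as integers");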