mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-12 13:28:37 +00:00
refactor: optimize the handling of pred type (#1048)
This commit is contained in:
parent
3f3610b5cd
commit
985aedda32
@ -515,7 +515,7 @@ struct SDContextParams {
|
||||
bool chroma_use_t5_mask = false;
|
||||
int chroma_t5_mask_pad = 1;
|
||||
|
||||
prediction_t prediction = DEFAULT_PRED;
|
||||
prediction_t prediction = PREDICTION_COUNT;
|
||||
lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO;
|
||||
|
||||
sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
|
||||
|
||||
@ -707,7 +707,7 @@ public:
|
||||
return false;
|
||||
}
|
||||
|
||||
// LOG_DEBUG("model size = %.2fMB", total_size / 1024.0 / 1024.0);
|
||||
LOG_DEBUG("finished loaded file");
|
||||
|
||||
{
|
||||
size_t clip_params_mem_size = cond_stage_model->get_params_buffer_size();
|
||||
@ -782,8 +782,59 @@ public:
|
||||
ggml_backend_is_cpu(clip_backend) ? "RAM" : "VRAM");
|
||||
}
|
||||
|
||||
if (sd_ctx_params->prediction != DEFAULT_PRED) {
|
||||
switch (sd_ctx_params->prediction) {
|
||||
// init denoiser
|
||||
{
|
||||
prediction_t pred_type = sd_ctx_params->prediction;
|
||||
float flow_shift = sd_ctx_params->flow_shift;
|
||||
|
||||
if (pred_type == PREDICTION_COUNT) {
|
||||
if (sd_version_is_sd2(version)) {
|
||||
// check is_using_v_parameterization_for_sd2
|
||||
if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
|
||||
pred_type = V_PRED;
|
||||
} else {
|
||||
pred_type = EPS_PRED;
|
||||
}
|
||||
} else if (sd_version_is_sdxl(version)) {
|
||||
if (tensor_storage_map.find("edm_vpred.sigma_max") != tensor_storage_map.end()) {
|
||||
// CosXL models
|
||||
// TODO: get sigma_min and sigma_max values from file
|
||||
pred_type = EDM_V_PRED;
|
||||
} else if (tensor_storage_map.find("v_pred") != tensor_storage_map.end()) {
|
||||
pred_type = V_PRED;
|
||||
} else {
|
||||
pred_type = EPS_PRED;
|
||||
}
|
||||
} else if (sd_version_is_sd3(version) ||
|
||||
sd_version_is_wan(version) ||
|
||||
sd_version_is_qwen_image(version) ||
|
||||
sd_version_is_z_image(version)) {
|
||||
pred_type = FLOW_PRED;
|
||||
if (flow_shift == INFINITY) {
|
||||
if (sd_version_is_wan(version)) {
|
||||
flow_shift = 5.f;
|
||||
} else {
|
||||
flow_shift = 3.f;
|
||||
}
|
||||
}
|
||||
} else if (sd_version_is_flux(version)) {
|
||||
pred_type = FLUX_FLOW_PRED;
|
||||
if (flow_shift == INFINITY) {
|
||||
flow_shift = 1.0f; // TODO: validate
|
||||
for (const auto& [name, tensor_storage] : tensor_storage_map) {
|
||||
if (starts_with(name, "model.diffusion_model.guidance_in.in_layer.weight")) {
|
||||
flow_shift = 1.15f;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (sd_version_is_flux2(version)) {
|
||||
pred_type = FLUX2_FLOW_PRED;
|
||||
} else {
|
||||
pred_type = EPS_PRED;
|
||||
}
|
||||
}
|
||||
|
||||
switch (pred_type) {
|
||||
case EPS_PRED:
|
||||
LOG_INFO("running in eps-prediction mode");
|
||||
break;
|
||||
@ -795,22 +846,14 @@ public:
|
||||
LOG_INFO("running in v-prediction EDM mode");
|
||||
denoiser = std::make_shared<EDMVDenoiser>();
|
||||
break;
|
||||
case SD3_FLOW_PRED: {
|
||||
case FLOW_PRED: {
|
||||
LOG_INFO("running in FLOW mode");
|
||||
float shift = sd_ctx_params->flow_shift;
|
||||
if (shift == INFINITY) {
|
||||
shift = 3.0;
|
||||
}
|
||||
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
|
||||
denoiser = std::make_shared<DiscreteFlowDenoiser>(flow_shift);
|
||||
break;
|
||||
}
|
||||
case FLUX_FLOW_PRED: {
|
||||
LOG_INFO("running in Flux FLOW mode");
|
||||
float shift = sd_ctx_params->flow_shift;
|
||||
if (shift == INFINITY) {
|
||||
shift = 3.0;
|
||||
}
|
||||
denoiser = std::make_shared<FluxFlowDenoiser>(shift);
|
||||
denoiser = std::make_shared<FluxFlowDenoiser>(flow_shift);
|
||||
break;
|
||||
}
|
||||
case FLUX2_FLOW_PRED: {
|
||||
@ -819,83 +862,11 @@ public:
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
LOG_ERROR("Unknown parametrization %i", sd_ctx_params->prediction);
|
||||
LOG_ERROR("Unknown predition type %i", pred_type);
|
||||
ggml_free(ctx);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (sd_version_is_sd2(version)) {
|
||||
// check is_using_v_parameterization_for_sd2
|
||||
if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
|
||||
is_using_v_parameterization = true;
|
||||
}
|
||||
} else if (sd_version_is_sdxl(version)) {
|
||||
if (tensor_storage_map.find("edm_vpred.sigma_max") != tensor_storage_map.end()) {
|
||||
// CosXL models
|
||||
// TODO: get sigma_min and sigma_max values from file
|
||||
is_using_edm_v_parameterization = true;
|
||||
}
|
||||
if (tensor_storage_map.find("v_pred") != tensor_storage_map.end()) {
|
||||
is_using_v_parameterization = true;
|
||||
}
|
||||
} else if (version == VERSION_SVD) {
|
||||
// TODO: V_PREDICTION_EDM
|
||||
is_using_v_parameterization = true;
|
||||
}
|
||||
|
||||
if (sd_version_is_sd3(version)) {
|
||||
LOG_INFO("running in FLOW mode");
|
||||
float shift = sd_ctx_params->flow_shift;
|
||||
if (shift == INFINITY) {
|
||||
shift = 3.0;
|
||||
}
|
||||
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
|
||||
} else if (sd_version_is_flux(version)) {
|
||||
LOG_INFO("running in Flux FLOW mode");
|
||||
float shift = sd_ctx_params->flow_shift;
|
||||
if (shift == INFINITY) {
|
||||
shift = 1.0f; // TODO: validate
|
||||
for (const auto& [name, tensor_storage] : tensor_storage_map) {
|
||||
if (starts_with(name, "model.diffusion_model.guidance_in.in_layer.weight")) {
|
||||
shift = 1.15f;
|
||||
}
|
||||
}
|
||||
}
|
||||
denoiser = std::make_shared<FluxFlowDenoiser>(shift);
|
||||
} else if (sd_version_is_flux2(version)) {
|
||||
LOG_INFO("running in Flux2 FLOW mode");
|
||||
denoiser = std::make_shared<Flux2FlowDenoiser>();
|
||||
} else if (sd_version_is_wan(version)) {
|
||||
LOG_INFO("running in FLOW mode");
|
||||
float shift = sd_ctx_params->flow_shift;
|
||||
if (shift == INFINITY) {
|
||||
shift = 5.0;
|
||||
}
|
||||
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
|
||||
} else if (sd_version_is_qwen_image(version)) {
|
||||
LOG_INFO("running in FLOW mode");
|
||||
float shift = sd_ctx_params->flow_shift;
|
||||
if (shift == INFINITY) {
|
||||
shift = 3.0;
|
||||
}
|
||||
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
|
||||
} else if (sd_version_is_z_image(version)) {
|
||||
LOG_INFO("running in FLOW mode");
|
||||
float shift = sd_ctx_params->flow_shift;
|
||||
if (shift == INFINITY) {
|
||||
shift = 3.0f;
|
||||
}
|
||||
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
|
||||
} else if (is_using_v_parameterization) {
|
||||
LOG_INFO("running in v-prediction mode");
|
||||
denoiser = std::make_shared<CompVisVDenoiser>();
|
||||
} else if (is_using_edm_v_parameterization) {
|
||||
LOG_INFO("running in v-prediction EDM mode");
|
||||
denoiser = std::make_shared<EDMVDenoiser>();
|
||||
} else {
|
||||
LOG_INFO("running in eps-prediction mode");
|
||||
}
|
||||
}
|
||||
|
||||
auto comp_vis_denoiser = std::dynamic_pointer_cast<CompVisDenoiser>(denoiser);
|
||||
if (comp_vis_denoiser) {
|
||||
@ -904,8 +875,8 @@ public:
|
||||
comp_vis_denoiser->log_sigmas[i] = std::log(comp_vis_denoiser->sigmas[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
LOG_DEBUG("finished loaded file");
|
||||
ggml_free(ctx);
|
||||
use_tiny_autoencoder = use_tiny_autoencoder && !sd_ctx_params->tae_preview_only;
|
||||
return true;
|
||||
@ -2426,7 +2397,6 @@ enum scheduler_t str_to_scheduler(const char* str) {
|
||||
}
|
||||
|
||||
const char* prediction_to_str[] = {
|
||||
"default",
|
||||
"eps",
|
||||
"v",
|
||||
"edm_v",
|
||||
@ -2512,7 +2482,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
|
||||
sd_ctx_params->wtype = SD_TYPE_COUNT;
|
||||
sd_ctx_params->rng_type = CUDA_RNG;
|
||||
sd_ctx_params->sampler_rng_type = RNG_TYPE_COUNT;
|
||||
sd_ctx_params->prediction = DEFAULT_PRED;
|
||||
sd_ctx_params->prediction = PREDICTION_COUNT;
|
||||
sd_ctx_params->lora_apply_mode = LORA_APPLY_AUTO;
|
||||
sd_ctx_params->offload_params_to_cpu = false;
|
||||
sd_ctx_params->keep_clip_on_cpu = false;
|
||||
|
||||
@ -65,11 +65,10 @@ enum scheduler_t {
|
||||
};
|
||||
|
||||
enum prediction_t {
|
||||
DEFAULT_PRED,
|
||||
EPS_PRED,
|
||||
V_PRED,
|
||||
EDM_V_PRED,
|
||||
SD3_FLOW_PRED,
|
||||
FLOW_PRED,
|
||||
FLUX_FLOW_PRED,
|
||||
FLUX2_FLOW_PRED,
|
||||
PREDICTION_COUNT
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user