mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-27 00:26:38 +00:00
refactor: consolidate WAN VAE version checks (#1712)
This commit is contained in:
parent
ec4cb8104b
commit
9956436c92
@ -208,6 +208,13 @@ static inline bool sd_version_uses_flux2_vae(SDVersion version) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static inline bool sd_version_uses_wan_vae(SDVersion version) {
|
||||||
|
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_krea2(version) || sd_version_is_anima(version)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
static inline bool sd_version_is_inpaint(SDVersion version) {
|
static inline bool sd_version_is_inpaint(SDVersion version) {
|
||||||
if (version == VERSION_SD1_INPAINT ||
|
if (version == VERSION_SD1_INPAINT ||
|
||||||
version == VERSION_SD2_INPAINT ||
|
version == VERSION_SD2_INPAINT ||
|
||||||
|
|||||||
@ -616,7 +616,6 @@ struct LogitNormalScheduler : SigmaScheduler {
|
|||||||
one_minus_t_min = sigmoid(0.5f * logsnr_max);
|
one_minus_t_min = sigmoid(0.5f * logsnr_max);
|
||||||
// t_max = 1.0f / (1.0f + std::exp(0.5f * logsnr_min));
|
// t_max = 1.0f / (1.0f + std::exp(0.5f * logsnr_min));
|
||||||
one_minus_t_max = sigmoid(0.5f * logsnr_min);
|
one_minus_t_max = sigmoid(0.5f * logsnr_min);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
LogitNormalScheduler(int image_seq_len = 0, const char* extra_sample_args = nullptr) {
|
LogitNormalScheduler(int image_seq_len = 0, const char* extra_sample_args = nullptr) {
|
||||||
|
|||||||
@ -892,11 +892,7 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto create_tae = [&](bool decode_only) -> std::shared_ptr<VAE> {
|
auto create_tae = [&](bool decode_only) -> std::shared_ptr<VAE> {
|
||||||
if (sd_version_is_wan(version) ||
|
if (sd_version_uses_wan_vae(version) || sd_version_is_ltxav(version)) {
|
||||||
sd_version_is_qwen_image(version) ||
|
|
||||||
sd_version_is_krea2(version) ||
|
|
||||||
sd_version_is_anima(version) ||
|
|
||||||
sd_version_is_ltxav(version)) {
|
|
||||||
return std::make_shared<TinyVideoAutoEncoder>(backend_for(SDBackendModule::VAE),
|
return std::make_shared<TinyVideoAutoEncoder>(backend_for(SDBackendModule::VAE),
|
||||||
tensor_storage_map,
|
tensor_storage_map,
|
||||||
"decoder",
|
"decoder",
|
||||||
@ -933,10 +929,7 @@ public:
|
|||||||
false,
|
false,
|
||||||
version,
|
version,
|
||||||
model_manager);
|
model_manager);
|
||||||
} else if (sd_version_is_wan(version) ||
|
} else if (sd_version_uses_wan_vae(version)) {
|
||||||
sd_version_is_qwen_image(version) ||
|
|
||||||
sd_version_is_krea2(version) ||
|
|
||||||
sd_version_is_anima(version)) {
|
|
||||||
return std::make_shared<WAN::WanVAERunner>(backend_for(SDBackendModule::VAE),
|
return std::make_shared<WAN::WanVAERunner>(backend_for(SDBackendModule::VAE),
|
||||||
tensor_storage_map,
|
tensor_storage_map,
|
||||||
"first_stage_model",
|
"first_stage_model",
|
||||||
@ -1742,7 +1735,7 @@ public:
|
|||||||
} else if (sd_version_uses_flux_vae(version)) {
|
} else if (sd_version_uses_flux_vae(version)) {
|
||||||
latent_rgb_proj = flux_latent_rgb_proj;
|
latent_rgb_proj = flux_latent_rgb_proj;
|
||||||
latent_rgb_bias = flux_latent_rgb_bias;
|
latent_rgb_bias = flux_latent_rgb_bias;
|
||||||
} else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version) || sd_version_is_krea2(version)) {
|
} else if (sd_version_uses_wan_vae(version)) {
|
||||||
latent_rgb_proj = wan_21_latent_rgb_proj;
|
latent_rgb_proj = wan_21_latent_rgb_proj;
|
||||||
latent_rgb_bias = wan_21_latent_rgb_bias;
|
latent_rgb_bias = wan_21_latent_rgb_bias;
|
||||||
} else {
|
} else {
|
||||||
@ -3156,7 +3149,7 @@ enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_me
|
|||||||
return SIMPLE_SCHEDULER;
|
return SIMPLE_SCHEDULER;
|
||||||
} else if (sd_ctx != nullptr && sd_ctx->sd != nullptr && sd_version_is_ltxav(sd_ctx->sd->version)) {
|
} else if (sd_ctx != nullptr && sd_ctx->sd != nullptr && sd_version_is_ltxav(sd_ctx->sd->version)) {
|
||||||
return LTX2_SCHEDULER;
|
return LTX2_SCHEDULER;
|
||||||
} else if(sd_ctx != nullptr && sd_ctx->sd != nullptr && sd_version_is_ideogram4(sd_ctx->sd->version)) {
|
} else if (sd_ctx != nullptr && sd_ctx->sd != nullptr && sd_version_is_ideogram4(sd_ctx->sd->version)) {
|
||||||
return LOGIT_NORMAL_SCHEDULER;
|
return LOGIT_NORMAL_SCHEDULER;
|
||||||
}
|
}
|
||||||
return DISCRETE_SCHEDULER;
|
return DISCRETE_SCHEDULER;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user