refactor: consolidate WAN VAE version checks (#1712)

This commit is contained in:
leejet 2026-06-27 01:23:37 +08:00 committed by GitHub
parent ec4cb8104b
commit 9956436c92
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 11 additions and 12 deletions

View File

@ -208,6 +208,13 @@ static inline bool sd_version_uses_flux2_vae(SDVersion version) {
return false; return false;
} }
static inline bool sd_version_uses_wan_vae(SDVersion version) {
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_krea2(version) || sd_version_is_anima(version)) {
return true;
}
return false;
}
static inline bool sd_version_is_inpaint(SDVersion version) { static inline bool sd_version_is_inpaint(SDVersion version) {
if (version == VERSION_SD1_INPAINT || if (version == VERSION_SD1_INPAINT ||
version == VERSION_SD2_INPAINT || version == VERSION_SD2_INPAINT ||

View File

@ -616,7 +616,6 @@ struct LogitNormalScheduler : SigmaScheduler {
one_minus_t_min = sigmoid(0.5f * logsnr_max); one_minus_t_min = sigmoid(0.5f * logsnr_max);
// t_max = 1.0f / (1.0f + std::exp(0.5f * logsnr_min)); // t_max = 1.0f / (1.0f + std::exp(0.5f * logsnr_min));
one_minus_t_max = sigmoid(0.5f * logsnr_min); one_minus_t_max = sigmoid(0.5f * logsnr_min);
} }
LogitNormalScheduler(int image_seq_len = 0, const char* extra_sample_args = nullptr) { LogitNormalScheduler(int image_seq_len = 0, const char* extra_sample_args = nullptr) {

View File

@ -892,11 +892,7 @@ public:
} }
auto create_tae = [&](bool decode_only) -> std::shared_ptr<VAE> { auto create_tae = [&](bool decode_only) -> std::shared_ptr<VAE> {
if (sd_version_is_wan(version) || if (sd_version_uses_wan_vae(version) || sd_version_is_ltxav(version)) {
sd_version_is_qwen_image(version) ||
sd_version_is_krea2(version) ||
sd_version_is_anima(version) ||
sd_version_is_ltxav(version)) {
return std::make_shared<TinyVideoAutoEncoder>(backend_for(SDBackendModule::VAE), return std::make_shared<TinyVideoAutoEncoder>(backend_for(SDBackendModule::VAE),
tensor_storage_map, tensor_storage_map,
"decoder", "decoder",
@ -933,10 +929,7 @@ public:
false, false,
version, version,
model_manager); model_manager);
} else if (sd_version_is_wan(version) || } else if (sd_version_uses_wan_vae(version)) {
sd_version_is_qwen_image(version) ||
sd_version_is_krea2(version) ||
sd_version_is_anima(version)) {
return std::make_shared<WAN::WanVAERunner>(backend_for(SDBackendModule::VAE), return std::make_shared<WAN::WanVAERunner>(backend_for(SDBackendModule::VAE),
tensor_storage_map, tensor_storage_map,
"first_stage_model", "first_stage_model",
@ -1742,7 +1735,7 @@ public:
} else if (sd_version_uses_flux_vae(version)) { } else if (sd_version_uses_flux_vae(version)) {
latent_rgb_proj = flux_latent_rgb_proj; latent_rgb_proj = flux_latent_rgb_proj;
latent_rgb_bias = flux_latent_rgb_bias; latent_rgb_bias = flux_latent_rgb_bias;
} else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version) || sd_version_is_krea2(version)) { } else if (sd_version_uses_wan_vae(version)) {
latent_rgb_proj = wan_21_latent_rgb_proj; latent_rgb_proj = wan_21_latent_rgb_proj;
latent_rgb_bias = wan_21_latent_rgb_bias; latent_rgb_bias = wan_21_latent_rgb_bias;
} else { } else {
@ -3156,7 +3149,7 @@ enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_me
return SIMPLE_SCHEDULER; return SIMPLE_SCHEDULER;
} else if (sd_ctx != nullptr && sd_ctx->sd != nullptr && sd_version_is_ltxav(sd_ctx->sd->version)) { } else if (sd_ctx != nullptr && sd_ctx->sd != nullptr && sd_version_is_ltxav(sd_ctx->sd->version)) {
return LTX2_SCHEDULER; return LTX2_SCHEDULER;
} else if(sd_ctx != nullptr && sd_ctx->sd != nullptr && sd_version_is_ideogram4(sd_ctx->sd->version)) { } else if (sd_ctx != nullptr && sd_ctx->sd != nullptr && sd_version_is_ideogram4(sd_ctx->sd->version)) {
return LOGIT_NORMAL_SCHEDULER; return LOGIT_NORMAL_SCHEDULER;
} }
return DISCRETE_SCHEDULER; return DISCRETE_SCHEDULER;