refactor: consolidate WAN VAE version checks (#1712)

This commit is contained in:
leejet 2026-06-27 01:23:37 +08:00 committed by GitHub
parent ec4cb8104b
commit 9956436c92
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 11 additions and 12 deletions

View File

@ -208,6 +208,13 @@ static inline bool sd_version_uses_flux2_vae(SDVersion version) {
return false;
}
static inline bool sd_version_uses_wan_vae(SDVersion version) {
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_krea2(version) || sd_version_is_anima(version)) {
return true;
}
return false;
}
static inline bool sd_version_is_inpaint(SDVersion version) {
if (version == VERSION_SD1_INPAINT ||
version == VERSION_SD2_INPAINT ||

View File

@ -616,7 +616,6 @@ struct LogitNormalScheduler : SigmaScheduler {
one_minus_t_min = sigmoid(0.5f * logsnr_max);
// t_max = 1.0f / (1.0f + std::exp(0.5f * logsnr_min));
one_minus_t_max = sigmoid(0.5f * logsnr_min);
}
LogitNormalScheduler(int image_seq_len = 0, const char* extra_sample_args = nullptr) {

View File

@ -892,11 +892,7 @@ public:
}
auto create_tae = [&](bool decode_only) -> std::shared_ptr<VAE> {
if (sd_version_is_wan(version) ||
sd_version_is_qwen_image(version) ||
sd_version_is_krea2(version) ||
sd_version_is_anima(version) ||
sd_version_is_ltxav(version)) {
if (sd_version_uses_wan_vae(version) || sd_version_is_ltxav(version)) {
return std::make_shared<TinyVideoAutoEncoder>(backend_for(SDBackendModule::VAE),
tensor_storage_map,
"decoder",
@ -933,10 +929,7 @@ public:
false,
version,
model_manager);
} else if (sd_version_is_wan(version) ||
sd_version_is_qwen_image(version) ||
sd_version_is_krea2(version) ||
sd_version_is_anima(version)) {
} else if (sd_version_uses_wan_vae(version)) {
return std::make_shared<WAN::WanVAERunner>(backend_for(SDBackendModule::VAE),
tensor_storage_map,
"first_stage_model",
@ -1742,7 +1735,7 @@ public:
} else if (sd_version_uses_flux_vae(version)) {
latent_rgb_proj = flux_latent_rgb_proj;
latent_rgb_bias = flux_latent_rgb_bias;
} else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version) || sd_version_is_krea2(version)) {
} else if (sd_version_uses_wan_vae(version)) {
latent_rgb_proj = wan_21_latent_rgb_proj;
latent_rgb_bias = wan_21_latent_rgb_bias;
} else {