diff --git a/src/model.h b/src/model.h index cce30913..44d66960 100644 --- a/src/model.h +++ b/src/model.h @@ -208,6 +208,13 @@ static inline bool sd_version_uses_flux2_vae(SDVersion version) { return false; } +static inline bool sd_version_uses_wan_vae(SDVersion version) { + if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_krea2(version) || sd_version_is_anima(version)) { + return true; + } + return false; +} + static inline bool sd_version_is_inpaint(SDVersion version) { if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || diff --git a/src/runtime/denoiser.hpp b/src/runtime/denoiser.hpp index 28b29ef2..ee77548f 100644 --- a/src/runtime/denoiser.hpp +++ b/src/runtime/denoiser.hpp @@ -616,7 +616,6 @@ struct LogitNormalScheduler : SigmaScheduler { one_minus_t_min = sigmoid(0.5f * logsnr_max); // t_max = 1.0f / (1.0f + std::exp(0.5f * logsnr_min)); one_minus_t_max = sigmoid(0.5f * logsnr_min); - } LogitNormalScheduler(int image_seq_len = 0, const char* extra_sample_args = nullptr) { diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index 311c7511..f2b01e39 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -892,11 +892,7 @@ public: } auto create_tae = [&](bool decode_only) -> std::shared_ptr { - if (sd_version_is_wan(version) || - sd_version_is_qwen_image(version) || - sd_version_is_krea2(version) || - sd_version_is_anima(version) || - sd_version_is_ltxav(version)) { + if (sd_version_uses_wan_vae(version) || sd_version_is_ltxav(version)) { return std::make_shared(backend_for(SDBackendModule::VAE), tensor_storage_map, "decoder", @@ -933,10 +929,7 @@ public: false, version, model_manager); - } else if (sd_version_is_wan(version) || - sd_version_is_qwen_image(version) || - sd_version_is_krea2(version) || - sd_version_is_anima(version)) { + } else if (sd_version_uses_wan_vae(version)) { return std::make_shared(backend_for(SDBackendModule::VAE), tensor_storage_map, "first_stage_model", @@ -1742,7 +1735,7 @@ public: } else if (sd_version_uses_flux_vae(version)) { latent_rgb_proj = flux_latent_rgb_proj; latent_rgb_bias = flux_latent_rgb_bias; - } else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version) || sd_version_is_krea2(version)) { + } else if (sd_version_uses_wan_vae(version)) { latent_rgb_proj = wan_21_latent_rgb_proj; latent_rgb_bias = wan_21_latent_rgb_bias; } else { @@ -3156,7 +3149,7 @@ enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_me return SIMPLE_SCHEDULER; } else if (sd_ctx != nullptr && sd_ctx->sd != nullptr && sd_version_is_ltxav(sd_ctx->sd->version)) { return LTX2_SCHEDULER; - } else if(sd_ctx != nullptr && sd_ctx->sd != nullptr && sd_version_is_ideogram4(sd_ctx->sd->version)) { + } else if (sd_ctx != nullptr && sd_ctx->sd != nullptr && sd_version_is_ideogram4(sd_ctx->sd->version)) { return LOGIT_NORMAL_SCHEDULER; } return DISCRETE_SCHEDULER;