fix: strip trailing latent channels for preview decode (#1548)

This commit is contained in:
stduhpf 2026-05-22 18:26:40 +02:00 committed by GitHub
parent 8cf55a3b3b
commit cbf92191c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1610,23 +1610,18 @@ public:
std::function<void(int, int, sd_image_t*, bool, void*)> step_callback, std::function<void(int, int, sd_image_t*, bool, void*)> step_callback,
void* step_callback_data, void* step_callback_data,
bool is_noisy) { bool is_noisy) {
bool is_video = preview_latent_tensor_is_video(latents);
uint32_t dim = is_video ? static_cast<uint32_t>(latents.shape()[3]) : static_cast<uint32_t>(latents.shape()[2]);
int channels = get_latent_channel();
auto _latents = channels != dim ? is_video ? sd::ops::slice(latents, 3, 0, channels)
: sd::ops::slice(latents, 2, 0, channels)
: latents;
if (preview_mode == PREVIEW_PROJ) { if (preview_mode == PREVIEW_PROJ) {
sd::Tensor<float> _latents = latents;
int patch_sz = 1; int patch_sz = 1;
const float(*latent_rgb_proj)[3] = nullptr; const float(*latent_rgb_proj)[3] = nullptr;
float* latent_rgb_bias = nullptr; float* latent_rgb_bias = nullptr;
bool is_video = preview_latent_tensor_is_video(latents);
uint32_t dim = is_video ? static_cast<uint32_t>(latents.shape()[3]) : static_cast<uint32_t>(latents.shape()[2]);
if (version == VERSION_LTXAV) {
if (is_video) {
_latents = sd::ops::slice(_latents, 3, 0, 128);
} else {
_latents = sd::ops::slice(_latents, 2, 0, 128);
}
dim = 128;
}
if (dim == 128) { if (channels == 128) {
if (sd_version_uses_flux2_vae(version)) { if (sd_version_uses_flux2_vae(version)) {
latent_rgb_proj = flux2_latent_rgb_proj; latent_rgb_proj = flux2_latent_rgb_proj;
latent_rgb_bias = flux2_latent_rgb_bias; latent_rgb_bias = flux2_latent_rgb_bias;
@ -1638,7 +1633,7 @@ public:
LOG_WARN("No latent to RGB projection known for this model"); LOG_WARN("No latent to RGB projection known for this model");
return; return;
} }
} else if (dim == 48) { } else if (channels == 48) {
if (sd_version_is_wan(version)) { if (sd_version_is_wan(version)) {
latent_rgb_proj = wan_22_latent_rgb_proj; latent_rgb_proj = wan_22_latent_rgb_proj;
latent_rgb_bias = wan_22_latent_rgb_bias; latent_rgb_bias = wan_22_latent_rgb_bias;
@ -1646,7 +1641,7 @@ public:
LOG_WARN("No latent to RGB projection known for this model"); LOG_WARN("No latent to RGB projection known for this model");
return; return;
} }
} else if (dim == 16) { } else if (channels == 16) {
if (sd_version_is_sd3(version)) { if (sd_version_is_sd3(version)) {
latent_rgb_proj = sd3_latent_rgb_proj; latent_rgb_proj = sd3_latent_rgb_proj;
latent_rgb_bias = sd3_latent_rgb_bias; latent_rgb_bias = sd3_latent_rgb_bias;
@ -1660,7 +1655,7 @@ public:
LOG_WARN("No latent to RGB projection known for this model"); LOG_WARN("No latent to RGB projection known for this model");
return; return;
} }
} else if (dim == 4) { } else if (channels == 4) {
if (sd_version_is_sdxl(version)) { if (sd_version_is_sdxl(version)) {
latent_rgb_proj = sdxl_latent_rgb_proj; latent_rgb_proj = sdxl_latent_rgb_proj;
latent_rgb_bias = sdxl_latent_rgb_bias; latent_rgb_bias = sdxl_latent_rgb_bias;
@ -1671,8 +1666,8 @@ public:
LOG_WARN("No latent to RGB projection known for this model"); LOG_WARN("No latent to RGB projection known for this model");
return; return;
} }
} else if (dim != 3) { } else if (channels != 3) {
LOG_WARN("No latent to RGB projection known for this model"); LOG_WARN("No latent to RGB projection known for this model (dim = %d)", dim);
return; return;
} }
@ -1697,14 +1692,13 @@ public:
if (preview_mode == PREVIEW_VAE || preview_mode == PREVIEW_TAE) { if (preview_mode == PREVIEW_VAE || preview_mode == PREVIEW_TAE) {
sd::Tensor<float> vae_latents; sd::Tensor<float> vae_latents;
sd::Tensor<float> decoded; sd::Tensor<float> decoded;
bool is_video = preview_latent_tensor_is_video(latents);
if (preview_vae) { if (preview_vae) {
preview_vae->set_temporal_tiling_enabled(vae_tiling_params.temporal_tiling); preview_vae->set_temporal_tiling_enabled(vae_tiling_params.temporal_tiling);
vae_latents = preview_vae->diffusion_to_vae_latents(latents); vae_latents = preview_vae->diffusion_to_vae_latents(_latents);
decoded = preview_vae->decode(n_threads, vae_latents, vae_tiling_params, is_video, circular_x, circular_y, true); decoded = preview_vae->decode(n_threads, vae_latents, vae_tiling_params, is_video, circular_x, circular_y, true);
} else { } else {
first_stage_model->set_temporal_tiling_enabled(vae_tiling_params.temporal_tiling); first_stage_model->set_temporal_tiling_enabled(vae_tiling_params.temporal_tiling);
vae_latents = first_stage_model->diffusion_to_vae_latents(latents); vae_latents = first_stage_model->diffusion_to_vae_latents(_latents);
decoded = first_stage_model->decode(n_threads, vae_latents, vae_tiling_params, is_video, circular_x, circular_y, true); decoded = first_stage_model->decode(n_threads, vae_latents, vae_tiling_params, is_video, circular_x, circular_y, true);
} }
if (decoded.empty()) { if (decoded.empty()) {