mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-09 15:56:39 +00:00
fix: strip trailing latent channels for preview decode (#1548)
This commit is contained in:
parent
8cf55a3b3b
commit
cbf92191c3
@ -1610,23 +1610,18 @@ public:
|
|||||||
std::function<void(int, int, sd_image_t*, bool, void*)> step_callback,
|
std::function<void(int, int, sd_image_t*, bool, void*)> step_callback,
|
||||||
void* step_callback_data,
|
void* step_callback_data,
|
||||||
bool is_noisy) {
|
bool is_noisy) {
|
||||||
|
bool is_video = preview_latent_tensor_is_video(latents);
|
||||||
|
uint32_t dim = is_video ? static_cast<uint32_t>(latents.shape()[3]) : static_cast<uint32_t>(latents.shape()[2]);
|
||||||
|
int channels = get_latent_channel();
|
||||||
|
auto _latents = channels != dim ? is_video ? sd::ops::slice(latents, 3, 0, channels)
|
||||||
|
: sd::ops::slice(latents, 2, 0, channels)
|
||||||
|
: latents;
|
||||||
if (preview_mode == PREVIEW_PROJ) {
|
if (preview_mode == PREVIEW_PROJ) {
|
||||||
sd::Tensor<float> _latents = latents;
|
|
||||||
int patch_sz = 1;
|
int patch_sz = 1;
|
||||||
const float(*latent_rgb_proj)[3] = nullptr;
|
const float(*latent_rgb_proj)[3] = nullptr;
|
||||||
float* latent_rgb_bias = nullptr;
|
float* latent_rgb_bias = nullptr;
|
||||||
bool is_video = preview_latent_tensor_is_video(latents);
|
|
||||||
uint32_t dim = is_video ? static_cast<uint32_t>(latents.shape()[3]) : static_cast<uint32_t>(latents.shape()[2]);
|
|
||||||
if (version == VERSION_LTXAV) {
|
|
||||||
if (is_video) {
|
|
||||||
_latents = sd::ops::slice(_latents, 3, 0, 128);
|
|
||||||
} else {
|
|
||||||
_latents = sd::ops::slice(_latents, 2, 0, 128);
|
|
||||||
}
|
|
||||||
dim = 128;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (dim == 128) {
|
if (channels == 128) {
|
||||||
if (sd_version_uses_flux2_vae(version)) {
|
if (sd_version_uses_flux2_vae(version)) {
|
||||||
latent_rgb_proj = flux2_latent_rgb_proj;
|
latent_rgb_proj = flux2_latent_rgb_proj;
|
||||||
latent_rgb_bias = flux2_latent_rgb_bias;
|
latent_rgb_bias = flux2_latent_rgb_bias;
|
||||||
@ -1638,7 +1633,7 @@ public:
|
|||||||
LOG_WARN("No latent to RGB projection known for this model");
|
LOG_WARN("No latent to RGB projection known for this model");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
} else if (dim == 48) {
|
} else if (channels == 48) {
|
||||||
if (sd_version_is_wan(version)) {
|
if (sd_version_is_wan(version)) {
|
||||||
latent_rgb_proj = wan_22_latent_rgb_proj;
|
latent_rgb_proj = wan_22_latent_rgb_proj;
|
||||||
latent_rgb_bias = wan_22_latent_rgb_bias;
|
latent_rgb_bias = wan_22_latent_rgb_bias;
|
||||||
@ -1646,7 +1641,7 @@ public:
|
|||||||
LOG_WARN("No latent to RGB projection known for this model");
|
LOG_WARN("No latent to RGB projection known for this model");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
} else if (dim == 16) {
|
} else if (channels == 16) {
|
||||||
if (sd_version_is_sd3(version)) {
|
if (sd_version_is_sd3(version)) {
|
||||||
latent_rgb_proj = sd3_latent_rgb_proj;
|
latent_rgb_proj = sd3_latent_rgb_proj;
|
||||||
latent_rgb_bias = sd3_latent_rgb_bias;
|
latent_rgb_bias = sd3_latent_rgb_bias;
|
||||||
@ -1660,7 +1655,7 @@ public:
|
|||||||
LOG_WARN("No latent to RGB projection known for this model");
|
LOG_WARN("No latent to RGB projection known for this model");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
} else if (dim == 4) {
|
} else if (channels == 4) {
|
||||||
if (sd_version_is_sdxl(version)) {
|
if (sd_version_is_sdxl(version)) {
|
||||||
latent_rgb_proj = sdxl_latent_rgb_proj;
|
latent_rgb_proj = sdxl_latent_rgb_proj;
|
||||||
latent_rgb_bias = sdxl_latent_rgb_bias;
|
latent_rgb_bias = sdxl_latent_rgb_bias;
|
||||||
@ -1671,8 +1666,8 @@ public:
|
|||||||
LOG_WARN("No latent to RGB projection known for this model");
|
LOG_WARN("No latent to RGB projection known for this model");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
} else if (dim != 3) {
|
} else if (channels != 3) {
|
||||||
LOG_WARN("No latent to RGB projection known for this model");
|
LOG_WARN("No latent to RGB projection known for this model (dim = %d)", dim);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1697,14 +1692,13 @@ public:
|
|||||||
if (preview_mode == PREVIEW_VAE || preview_mode == PREVIEW_TAE) {
|
if (preview_mode == PREVIEW_VAE || preview_mode == PREVIEW_TAE) {
|
||||||
sd::Tensor<float> vae_latents;
|
sd::Tensor<float> vae_latents;
|
||||||
sd::Tensor<float> decoded;
|
sd::Tensor<float> decoded;
|
||||||
bool is_video = preview_latent_tensor_is_video(latents);
|
|
||||||
if (preview_vae) {
|
if (preview_vae) {
|
||||||
preview_vae->set_temporal_tiling_enabled(vae_tiling_params.temporal_tiling);
|
preview_vae->set_temporal_tiling_enabled(vae_tiling_params.temporal_tiling);
|
||||||
vae_latents = preview_vae->diffusion_to_vae_latents(latents);
|
vae_latents = preview_vae->diffusion_to_vae_latents(_latents);
|
||||||
decoded = preview_vae->decode(n_threads, vae_latents, vae_tiling_params, is_video, circular_x, circular_y, true);
|
decoded = preview_vae->decode(n_threads, vae_latents, vae_tiling_params, is_video, circular_x, circular_y, true);
|
||||||
} else {
|
} else {
|
||||||
first_stage_model->set_temporal_tiling_enabled(vae_tiling_params.temporal_tiling);
|
first_stage_model->set_temporal_tiling_enabled(vae_tiling_params.temporal_tiling);
|
||||||
vae_latents = first_stage_model->diffusion_to_vae_latents(latents);
|
vae_latents = first_stage_model->diffusion_to_vae_latents(_latents);
|
||||||
decoded = first_stage_model->decode(n_threads, vae_latents, vae_tiling_params, is_video, circular_x, circular_y, true);
|
decoded = first_stage_model->decode(n_threads, vae_latents, vae_tiling_params, is_video, circular_x, circular_y, true);
|
||||||
}
|
}
|
||||||
if (decoded.empty()) {
|
if (decoded.empty()) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user