mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-09 15:56:39 +00:00
feat: add LTX temporal latent upscaler support (#1551)
This commit is contained in:
parent
645e6e9089
commit
0baf721215
@ -31,6 +31,7 @@ namespace LTXVUpsampler {
|
|||||||
float spatial_scale = 2.f;
|
float spatial_scale = 2.f;
|
||||||
int spatial_up_num = 2;
|
int spatial_up_num = 2;
|
||||||
int spatial_down_den = 1;
|
int spatial_down_den = 1;
|
||||||
|
int temporal_up_factor = 1;
|
||||||
};
|
};
|
||||||
|
|
||||||
static inline bool has_tensor(const String2TensorStorage& tensor_storage_map,
|
static inline bool has_tensor(const String2TensorStorage& tensor_storage_map,
|
||||||
@ -83,9 +84,13 @@ namespace LTXVUpsampler {
|
|||||||
if (detected_blocks > 0) {
|
if (detected_blocks > 0) {
|
||||||
config.num_blocks_per_stage = detected_blocks;
|
config.num_blocks_per_stage = detected_blocks;
|
||||||
}
|
}
|
||||||
config.rational_resampler = has_tensor(tensor_storage_map, "upsampler.conv.weight");
|
config.rational_resampler = has_tensor(tensor_storage_map, "upsampler.conv.weight");
|
||||||
config.spatial_upsample = config.rational_resampler || has_tensor(tensor_storage_map, "upsampler.0.weight");
|
int64_t upsampler_out_channels = get_tensor_ne0(tensor_storage_map, "upsampler.0.bias", 0);
|
||||||
config.temporal_upsample = has_tensor(tensor_storage_map, "temporal_upsampler.0.weight");
|
config.spatial_upsample = config.rational_resampler || upsampler_out_channels == 4 * config.mid_channels;
|
||||||
|
config.temporal_upsample = upsampler_out_channels == 2 * config.mid_channels;
|
||||||
|
if (config.temporal_upsample) {
|
||||||
|
config.temporal_up_factor = 2;
|
||||||
|
}
|
||||||
if (config.rational_resampler) {
|
if (config.rational_resampler) {
|
||||||
int64_t out_channels = get_tensor_ne(tensor_storage_map,
|
int64_t out_channels = get_tensor_ne(tensor_storage_map,
|
||||||
"upsampler.conv.weight",
|
"upsampler.conv.weight",
|
||||||
@ -207,6 +212,30 @@ namespace LTXVUpsampler {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class TemporalPixelShuffleND : public UnaryBlock {
|
||||||
|
protected:
|
||||||
|
int upscale_factor;
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit TemporalPixelShuffleND(int upscale_factor)
|
||||||
|
: upscale_factor(upscale_factor) {}
|
||||||
|
|
||||||
|
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override {
|
||||||
|
GGML_ASSERT(upscale_factor > 0);
|
||||||
|
GGML_ASSERT(x->ne[3] % upscale_factor == 0);
|
||||||
|
const int64_t W = x->ne[0];
|
||||||
|
const int64_t H = x->ne[1];
|
||||||
|
const int64_t F = x->ne[2];
|
||||||
|
const int64_t C = x->ne[3] / upscale_factor;
|
||||||
|
|
||||||
|
// x: [b, c*p, f, h, w] -> [b, c, f*p, h, w]
|
||||||
|
x = ggml_ext_cont(ctx->ggml_ctx, x);
|
||||||
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, F, upscale_factor, C);
|
||||||
|
x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3));
|
||||||
|
return ggml_reshape_4d(ctx->ggml_ctx, x, W, H, F * upscale_factor, C);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
class BlurDownsample : public GGMLBlock {
|
class BlurDownsample : public GGMLBlock {
|
||||||
protected:
|
protected:
|
||||||
int64_t channels;
|
int64_t channels;
|
||||||
@ -308,8 +337,7 @@ namespace LTXVUpsampler {
|
|||||||
explicit LatentUpsampler(LatentUpsamplerConfig config)
|
explicit LatentUpsampler(LatentUpsamplerConfig config)
|
||||||
: config(std::move(config)) {
|
: config(std::move(config)) {
|
||||||
GGML_ASSERT(this->config.dims == 3);
|
GGML_ASSERT(this->config.dims == 3);
|
||||||
GGML_ASSERT(this->config.spatial_upsample);
|
GGML_ASSERT(this->config.spatial_upsample || this->config.temporal_upsample);
|
||||||
GGML_ASSERT(!this->config.temporal_upsample);
|
|
||||||
|
|
||||||
blocks["initial_conv"] = std::shared_ptr<GGMLBlock>(new Conv3d(this->config.in_channels,
|
blocks["initial_conv"] = std::shared_ptr<GGMLBlock>(new Conv3d(this->config.in_channels,
|
||||||
this->config.mid_channels,
|
this->config.mid_channels,
|
||||||
@ -324,6 +352,13 @@ namespace LTXVUpsampler {
|
|||||||
blocks["upsampler"] = std::shared_ptr<GGMLBlock>(new SpatialRationalResampler(this->config.mid_channels,
|
blocks["upsampler"] = std::shared_ptr<GGMLBlock>(new SpatialRationalResampler(this->config.mid_channels,
|
||||||
this->config.spatial_up_num,
|
this->config.spatial_up_num,
|
||||||
this->config.spatial_down_den));
|
this->config.spatial_down_den));
|
||||||
|
} else if (this->config.temporal_upsample) {
|
||||||
|
blocks["upsampler.0"] = std::shared_ptr<GGMLBlock>(new Conv3d(this->config.mid_channels,
|
||||||
|
this->config.temporal_up_factor * this->config.mid_channels,
|
||||||
|
{3, 3, 3},
|
||||||
|
{1, 1, 1},
|
||||||
|
{1, 1, 1}));
|
||||||
|
blocks["upsampler.1"] = std::shared_ptr<GGMLBlock>(new TemporalPixelShuffleND(this->config.temporal_up_factor));
|
||||||
} else {
|
} else {
|
||||||
blocks["upsampler.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(this->config.mid_channels,
|
blocks["upsampler.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(this->config.mid_channels,
|
||||||
4 * this->config.mid_channels,
|
4 * this->config.mid_channels,
|
||||||
@ -344,7 +379,7 @@ namespace LTXVUpsampler {
|
|||||||
|
|
||||||
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
|
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
|
||||||
// x: [b, c, f, h, w]
|
// x: [b, c, f, h, w]
|
||||||
// return: [b, c, f, scaled_h, scaled_w]
|
// return: [b, c, scaled_f, scaled_h, scaled_w]
|
||||||
auto initial_conv = std::dynamic_pointer_cast<Conv3d>(blocks["initial_conv"]);
|
auto initial_conv = std::dynamic_pointer_cast<Conv3d>(blocks["initial_conv"]);
|
||||||
auto initial_norm = std::dynamic_pointer_cast<VideoGroupNorm>(blocks["initial_norm"]);
|
auto initial_norm = std::dynamic_pointer_cast<VideoGroupNorm>(blocks["initial_norm"]);
|
||||||
auto final_conv = std::dynamic_pointer_cast<Conv3d>(blocks["final_conv"]);
|
auto final_conv = std::dynamic_pointer_cast<Conv3d>(blocks["final_conv"]);
|
||||||
@ -363,6 +398,12 @@ namespace LTXVUpsampler {
|
|||||||
if (config.rational_resampler) {
|
if (config.rational_resampler) {
|
||||||
auto upsampler = std::dynamic_pointer_cast<SpatialRationalResampler>(blocks["upsampler"]);
|
auto upsampler = std::dynamic_pointer_cast<SpatialRationalResampler>(blocks["upsampler"]);
|
||||||
x = upsampler->forward(ctx, x);
|
x = upsampler->forward(ctx, x);
|
||||||
|
} else if (config.temporal_upsample) {
|
||||||
|
auto upsample_conv = std::dynamic_pointer_cast<Conv3d>(blocks["upsampler.0"]);
|
||||||
|
auto pixel_shuffle = std::dynamic_pointer_cast<TemporalPixelShuffleND>(blocks["upsampler.1"]);
|
||||||
|
x = upsample_conv->forward(ctx, x); // [b, c*2, f, h, w]
|
||||||
|
x = pixel_shuffle->forward(ctx, x); // [b, c, f*2, h, w]
|
||||||
|
x = ggml_ext_slice(ctx->ggml_ctx, x, 2, 1, x->ne[2]); // x[:, :, 1:, :, :]
|
||||||
} else {
|
} else {
|
||||||
auto upsample_conv = std::dynamic_pointer_cast<Conv2d>(blocks["upsampler.0"]);
|
auto upsample_conv = std::dynamic_pointer_cast<Conv2d>(blocks["upsampler.0"]);
|
||||||
auto pixel_shuffle = std::dynamic_pointer_cast<PixelShuffleND>(blocks["upsampler.1"]);
|
auto pixel_shuffle = std::dynamic_pointer_cast<PixelShuffleND>(blocks["upsampler.1"]);
|
||||||
@ -415,23 +456,24 @@ namespace LTXVUpsampler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const auto& tensor_storage_map = model_loader.get_tensor_storage_map();
|
const auto& tensor_storage_map = model_loader.get_tensor_storage_map();
|
||||||
bool has_regular_spatial = has_tensor(tensor_storage_map, "upsampler.0.weight");
|
bool has_regular_upsampler = has_tensor(tensor_storage_map, "upsampler.0.weight");
|
||||||
bool has_rational_spatial = has_tensor(tensor_storage_map, "upsampler.conv.weight");
|
bool has_rational_spatial = has_tensor(tensor_storage_map, "upsampler.conv.weight");
|
||||||
if (!has_tensor(tensor_storage_map, "post_upsample_res_blocks.0.conv2.bias") ||
|
if (!has_tensor(tensor_storage_map, "post_upsample_res_blocks.0.conv2.bias") ||
|
||||||
(!has_regular_spatial && !has_rational_spatial)) {
|
(!has_regular_upsampler && !has_rational_spatial)) {
|
||||||
LOG_ERROR("unsupported LTX latent upsampler weights: expected spatial upsampler tensors");
|
LOG_ERROR("unsupported LTX latent upsampler weights: expected upsampler tensors");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
LatentUpsamplerConfig config = detect_config_from_weights(tensor_storage_map);
|
LatentUpsamplerConfig config = detect_config_from_weights(tensor_storage_map);
|
||||||
if (config.dims != 3 || !config.spatial_upsample || config.temporal_upsample ||
|
if (config.dims != 3 || (!config.spatial_upsample && !config.temporal_upsample) ||
|
||||||
config.spatial_up_num < 1 || config.spatial_down_den < 1) {
|
config.spatial_up_num < 1 || config.spatial_down_den < 1 || config.temporal_up_factor < 1) {
|
||||||
LOG_ERROR("unsupported LTX latent upsampler config: dims=%d spatial=%d temporal=%d rational=%d scale=%.3f",
|
LOG_ERROR("unsupported LTX latent upsampler config: dims=%d spatial=%d temporal=%d rational=%d scale=%.3f temporal_factor=%d",
|
||||||
config.dims,
|
config.dims,
|
||||||
config.spatial_upsample,
|
config.spatial_upsample,
|
||||||
config.temporal_upsample,
|
config.temporal_upsample,
|
||||||
config.rational_resampler,
|
config.rational_resampler,
|
||||||
config.spatial_scale);
|
config.spatial_scale,
|
||||||
|
config.temporal_up_factor);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -454,11 +496,12 @@ namespace LTXVUpsampler {
|
|||||||
}
|
}
|
||||||
model->load_fixed_tensors();
|
model->load_fixed_tensors();
|
||||||
|
|
||||||
LOG_INFO("LTX latent upsampler loaded: in_channels=%" PRId64 ", mid_channels=%" PRId64 ", blocks=%d, scale=%.3f, rational=%d",
|
LOG_INFO("LTX latent upsampler loaded: in_channels=%" PRId64 ", mid_channels=%" PRId64 ", blocks=%d, scale=%.3f, temporal_factor=%d, rational=%d",
|
||||||
config.in_channels,
|
config.in_channels,
|
||||||
config.mid_channels,
|
config.mid_channels,
|
||||||
config.num_blocks_per_stage,
|
config.num_blocks_per_stage,
|
||||||
config.spatial_scale,
|
config.spatial_scale,
|
||||||
|
config.temporal_up_factor,
|
||||||
config.rational_resampler);
|
config.rational_resampler);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -153,7 +153,7 @@ namespace LTXVAE {
|
|||||||
|
|
||||||
GGML_ASSERT(x->ne[2] >= temporal_pad);
|
GGML_ASSERT(x->ne[2] >= temporal_pad);
|
||||||
|
|
||||||
int end_idx = x->ne[2] - temporal_pad;
|
int end_idx = (int)x->ne[2] - temporal_pad;
|
||||||
int start_idx = std::max(end_idx - pad, 0);
|
int start_idx = std::max(end_idx - pad, 0);
|
||||||
|
|
||||||
// Save a contiguous copy of the last `pad` frames so the large `x`
|
// Save a contiguous copy of the last `pad` frames so the large `x`
|
||||||
|
|||||||
@ -2153,19 +2153,41 @@ public:
|
|||||||
int vae_scale_factor = get_vae_scale_factor();
|
int vae_scale_factor = get_vae_scale_factor();
|
||||||
int W = width / vae_scale_factor;
|
int W = width / vae_scale_factor;
|
||||||
int H = height / vae_scale_factor;
|
int H = height / vae_scale_factor;
|
||||||
int T = frames;
|
int T = video_frames_to_latent_frames(frames);
|
||||||
if (sd_version_is_ltxav(version)) {
|
int C = get_latent_channel();
|
||||||
T = ((T - 1) / 8) + 1;
|
|
||||||
} else if (sd_version_is_wan(version)) {
|
|
||||||
T = ((T - 1) / 4) + 1;
|
|
||||||
}
|
|
||||||
int C = get_latent_channel();
|
|
||||||
if (video) {
|
if (video) {
|
||||||
return sd::zeros<float>({W, H, T, C, 1});
|
return sd::zeros<float>({W, H, T, C, 1});
|
||||||
}
|
}
|
||||||
return sd::zeros<float>({W, H, C, 1});
|
return sd::zeros<float>({W, H, C, 1});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int video_frames_to_latent_frames(int frames) {
|
||||||
|
int latent_frames = frames;
|
||||||
|
if (sd_version_is_ltxav(version)) {
|
||||||
|
latent_frames = ((frames - 1) / 8) + 1;
|
||||||
|
} else if (sd_version_is_wan(version)) {
|
||||||
|
latent_frames = ((frames - 1) / 4) + 1;
|
||||||
|
}
|
||||||
|
return latent_frames;
|
||||||
|
}
|
||||||
|
|
||||||
|
int latent_frames_to_video_frames(int latent_frames) {
|
||||||
|
if (latent_frames <= 0) {
|
||||||
|
return latent_frames;
|
||||||
|
}
|
||||||
|
if (sd_version_is_ltxav(version)) {
|
||||||
|
return (latent_frames - 1) * 8 + 1;
|
||||||
|
}
|
||||||
|
if (sd_version_is_wan(version)) {
|
||||||
|
return (latent_frames - 1) * 4 + 1;
|
||||||
|
}
|
||||||
|
return latent_frames;
|
||||||
|
}
|
||||||
|
|
||||||
|
int align_video_frames(int frames) {
|
||||||
|
return latent_frames_to_video_frames(video_frames_to_latent_frames(frames));
|
||||||
|
}
|
||||||
|
|
||||||
sd::Tensor<float> encode_to_vae_latents(const sd::Tensor<float>& x) {
|
sd::Tensor<float> encode_to_vae_latents(const sd::Tensor<float>& x) {
|
||||||
auto latents = first_stage_model->encode(n_threads, x, vae_tiling_params, circular_x, circular_y);
|
auto latents = first_stage_model->encode(n_threads, x, vae_tiling_params, circular_x, circular_y);
|
||||||
if (latents.empty()) {
|
if (latents.empty()) {
|
||||||
@ -3000,16 +3022,12 @@ struct GenerationRequest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
GenerationRequest(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params) {
|
GenerationRequest(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params) {
|
||||||
prompt = SAFE_STR(sd_vid_gen_params->prompt);
|
prompt = SAFE_STR(sd_vid_gen_params->prompt);
|
||||||
negative_prompt = SAFE_STR(sd_vid_gen_params->negative_prompt);
|
negative_prompt = SAFE_STR(sd_vid_gen_params->negative_prompt);
|
||||||
width = sd_vid_gen_params->width;
|
width = sd_vid_gen_params->width;
|
||||||
height = sd_vid_gen_params->height;
|
height = sd_vid_gen_params->height;
|
||||||
requested_frames = std::max(1, sd_vid_gen_params->video_frames);
|
requested_frames = std::max(1, sd_vid_gen_params->video_frames);
|
||||||
if (sd_version_is_ltxav(sd_ctx->sd->version)) {
|
frames = sd_ctx->sd->align_video_frames(requested_frames);
|
||||||
frames = ((requested_frames - 1 + 7) / 8) * 8 + 1;
|
|
||||||
} else {
|
|
||||||
frames = (requested_frames - 1) / 4 * 4 + 1;
|
|
||||||
}
|
|
||||||
clip_skip = sd_vid_gen_params->clip_skip;
|
clip_skip = sd_vid_gen_params->clip_skip;
|
||||||
fps = std::max(1, sd_vid_gen_params->fps);
|
fps = std::max(1, sd_vid_gen_params->fps);
|
||||||
vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
|
vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
|
||||||
@ -3567,6 +3585,30 @@ static sd::Tensor<float> unpack_ltxav_audio_latent(const sd::Tensor<float>& pack
|
|||||||
return audio_latent;
|
return audio_latent;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Builds a zero-filled LTX-AV audio latent with shape
// {16, audio_length, 8, 1}. Returns a default-constructed tensor when
// `audio_length` is not positive.
static sd::Tensor<float> make_ltxav_empty_audio_latent(int audio_length) {
    constexpr int kLtxavAudioFrequencyBins = 16;
    constexpr int kLtxavAudioChannels      = 8;

    if (audio_length > 0) {
        return sd::zeros<float>({kLtxavAudioFrequencyBins, audio_length, kLtxavAudioChannels, 1});
    }
    return {};
}
|
||||||
|
|
||||||
|
// Produces an audio latent of length `target_audio_length` along axis 1.
// The result starts as zeros; the leading min(source length, target length)
// entries along axis 1 are then copied over from `audio_latent`.
// If the zero tensor is empty (non-positive target) or `audio_latent` is
// empty, the zero tensor is returned untouched.
static sd::Tensor<float> resize_ltxav_audio_latent(const sd::Tensor<float>& audio_latent,
                                                   int target_audio_length) {
    auto output = make_ltxav_empty_audio_latent(target_audio_length);

    const bool nothing_to_copy = output.empty() || audio_latent.empty();
    if (nothing_to_copy) {
        return output;
    }

    GGML_ASSERT(audio_latent.dim() == 3 || audio_latent.dim() == 4);

    // Only the overlap between the source and target lengths can be carried over.
    const int source_length = static_cast<int>(audio_latent.shape()[1]);
    const int overlap       = std::min(source_length, target_audio_length);
    if (overlap <= 0) {
        return output;
    }

    auto head = sd::ops::slice(audio_latent, 1, 0, overlap);
    sd::ops::slice_assign(&output, 1, 0, overlap, head);
    return output;
}
|
||||||
|
|
||||||
static int get_ltxav_num_audio_latents(int frames, int fps) {
|
static int get_ltxav_num_audio_latents(int frames, int fps) {
|
||||||
GGML_ASSERT(frames > 0);
|
GGML_ASSERT(frames > 0);
|
||||||
GGML_ASSERT(fps > 0);
|
GGML_ASSERT(fps > 0);
|
||||||
@ -4396,10 +4438,8 @@ static std::optional<ImageGenerationLatents> prepare_video_generation_latents(sd
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (sd_version_is_ltxav(sd_ctx->sd->version)) {
|
if (sd_version_is_ltxav(sd_ctx->sd->version)) {
|
||||||
constexpr int kLtxavAudioFrequencyBins = 16;
|
latents.audio_length = get_ltxav_num_audio_latents(request->frames, request->fps);
|
||||||
constexpr int kLtxavAudioChannels = 8;
|
latents.audio_latent = make_ltxav_empty_audio_latent(latents.audio_length);
|
||||||
latents.audio_length = get_ltxav_num_audio_latents(request->frames, request->fps);
|
|
||||||
latents.audio_latent = sd::zeros<float>({kLtxavAudioFrequencyBins, latents.audio_length, kLtxavAudioChannels, 1});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (sd_version_is_ltxav(sd_ctx->sd->version)) {
|
if (sd_version_is_ltxav(sd_ctx->sd->version)) {
|
||||||
@ -4749,9 +4789,9 @@ static sd_image_t* decode_video_outputs(sd_ctx_t* sd_ctx,
|
|||||||
(int)vid.shape()[1],
|
(int)vid.shape()[1],
|
||||||
(int)vid.shape()[2],
|
(int)vid.shape()[2],
|
||||||
(int)vid.shape()[3]);
|
(int)vid.shape()[3]);
|
||||||
if (request.requested_frames > 0 &&
|
if (request.frames > 0 &&
|
||||||
vid.shape()[2] > request.requested_frames) {
|
vid.shape()[2] > request.frames) {
|
||||||
vid = sd::ops::slice(vid, 2, 0, request.requested_frames);
|
vid = sd::ops::slice(vid, 2, 0, request.frames);
|
||||||
}
|
}
|
||||||
|
|
||||||
sd_image_t* result_images = (sd_image_t*)calloc(vid.shape()[2], sizeof(sd_image_t));
|
sd_image_t* result_images = (sd_image_t*)calloc(vid.shape()[2], sizeof(sd_image_t));
|
||||||
@ -5118,9 +5158,46 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
|
|||||||
LOG_INFO("LTX latent spatial upscale completed, taking %.2fs",
|
LOG_INFO("LTX latent spatial upscale completed, taking %.2fs",
|
||||||
(upscale_end - upscale_start) * 1.0f / 1000);
|
(upscale_end - upscale_start) * 1.0f / 1000);
|
||||||
|
|
||||||
x_t = std::move(upscaled_latent);
|
x_t = std::move(upscaled_latent);
|
||||||
hires_request.width = static_cast<int>(x_t.shape()[0]) * hires_request.vae_scale_factor;
|
hires_request.width = static_cast<int>(x_t.shape()[0]) * hires_request.vae_scale_factor;
|
||||||
hires_request.height = static_cast<int>(x_t.shape()[1]) * hires_request.vae_scale_factor;
|
hires_request.height = static_cast<int>(x_t.shape()[1]) * hires_request.vae_scale_factor;
|
||||||
|
int upscaled_latent_frames = static_cast<int>(x_t.shape()[2]);
|
||||||
|
int upscaled_frames = sd_ctx->sd->latent_frames_to_video_frames(upscaled_latent_frames);
|
||||||
|
if (upscaled_frames != hires_request.frames) {
|
||||||
|
LOG_INFO("LTX latent upsampler output latent frames %d, frames %d -> %d",
|
||||||
|
upscaled_latent_frames,
|
||||||
|
hires_request.frames,
|
||||||
|
upscaled_frames);
|
||||||
|
hires_request.frames = upscaled_frames;
|
||||||
|
}
|
||||||
|
if (sd_version_is_ltxav(sd_ctx->sd->version) && latents.audio_length > 0) {
|
||||||
|
int target_audio_length = get_ltxav_num_audio_latents(hires_request.frames, hires_request.fps);
|
||||||
|
if (target_audio_length != latents.audio_length) {
|
||||||
|
int latent_channels = sd_ctx->sd->get_latent_channel();
|
||||||
|
sd::Tensor<float> video_latent = x_t;
|
||||||
|
sd::Tensor<float> audio_latent = latents.audio_latent;
|
||||||
|
if (x_t.shape()[3] > latent_channels) {
|
||||||
|
video_latent = sd::ops::slice(x_t, 3, 0, latent_channels);
|
||||||
|
audio_latent = unpack_ltxav_audio_latent(x_t, latents.audio_length, latent_channels);
|
||||||
|
}
|
||||||
|
audio_latent = resize_ltxav_audio_latent(audio_latent, target_audio_length);
|
||||||
|
if (audio_latent.empty()) {
|
||||||
|
LOG_ERROR("failed to resize LTX audio latent for latent upscale: %d -> %d",
|
||||||
|
latents.audio_length,
|
||||||
|
target_audio_length);
|
||||||
|
if (sd_ctx->sd->free_params_immediately) {
|
||||||
|
sd_ctx->sd->diffusion_model->free_params_buffer();
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
x_t = pack_ltxav_audio_and_video_latents(video_latent, audio_latent);
|
||||||
|
latents.audio_latent = std::move(audio_latent);
|
||||||
|
LOG_INFO("LTX audio latent length adjusted for latent upscale: %d -> %d",
|
||||||
|
latents.audio_length,
|
||||||
|
target_audio_length);
|
||||||
|
latents.audio_length = target_audio_length;
|
||||||
|
}
|
||||||
|
}
|
||||||
if ((request.hires.target_width > 0 || request.hires.target_height > 0) &&
|
if ((request.hires.target_width > 0 || request.hires.target_height > 0) &&
|
||||||
(request.hires.target_width != hires_request.width || request.hires.target_height != hires_request.height)) {
|
(request.hires.target_width != hires_request.width || request.hires.target_height != hires_request.height)) {
|
||||||
LOG_WARN("LTX latent spatial upsampler output is %dx%d; ignoring hires target %dx%d",
|
LOG_WARN("LTX latent spatial upsampler output is %dx%d; ignoring hires target %dx%d",
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user