diff --git a/src/ggml_extend.hpp b/src/ggml_extend.hpp index 8e2ed694..7258972b 100644 --- a/src/ggml_extend.hpp +++ b/src/ggml_extend.hpp @@ -1127,18 +1127,33 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_conv_3d(ggml_context* ctx, ggml_tensor* w, ggml_tensor* b, int64_t IC, - int s0 = 1, - int s1 = 1, - int s2 = 1, - int p0 = 0, - int p1 = 0, - int p2 = 0, - int d0 = 1, - int d1 = 1, - int d2 = 1) { - int64_t OC = w->ne[3] / IC; - int64_t N = x->ne[3] / IC; - x = ggml_conv_3d(ctx, w, x, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2); + int s0 = 1, + int s1 = 1, + int s2 = 1, + int p0 = 0, + int p1 = 0, + int p2 = 0, + int d0 = 1, + int d1 = 1, + int d2 = 1, + bool force_prec_f32 = false) { + if (force_prec_f32) { + ggml_tensor* im2col = ggml_im2col_3d(ctx, w, x, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, w->type); + + int64_t OC = w->ne[3] / IC; + int64_t N = x->ne[3] / IC; + x = ggml_mul_mat(ctx, + ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), + ggml_reshape_2d(ctx, w, w->ne[0] * w->ne[1] * w->ne[2] * IC, OC)); + ggml_mul_mat_set_prec(x, GGML_PREC_F32); + + int64_t OD = im2col->ne[3] / N; + x = ggml_reshape_4d(ctx, x, im2col->ne[1] * im2col->ne[2], OD, N, OC); + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 1, 3, 2)); + x = ggml_reshape_4d(ctx, x, im2col->ne[1], im2col->ne[2], OD, OC * N); + } else { + x = ggml_conv_3d(ctx, w, x, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2); + } if (b != nullptr) { b = ggml_reshape_4d(ctx, b, 1, 1, 1, b->ne[0]); // [OC, 1, 1, 1] @@ -3133,6 +3148,7 @@ protected: std::tuple padding; std::tuple dilation; bool bias; + bool force_prec_f32; std::string prefix; void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override { @@ -3156,14 +3172,16 @@ public: std::tuple stride = {1, 1, 1}, std::tuple padding = {0, 0, 0}, std::tuple dilation = {1, 1, 1}, - bool bias = true) + bool bias = true, + bool force_prec_f32 = false) : in_channels(in_channels), out_channels(out_channels), kernel_size(kernel_size), stride(stride), padding(padding), dilation(dilation), - bias(bias) {} + bias(bias), + force_prec_f32(force_prec_f32) {} ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { ggml_tensor* w = params["weight"]; @@ -3183,7 +3201,8 @@ public: return ggml_ext_conv_3d(ctx->ggml_ctx, x, w, b, in_channels, std::get<2>(stride), std::get<1>(stride), std::get<0>(stride), std::get<2>(padding), std::get<1>(padding), std::get<0>(padding), - std::get<2>(dilation), std::get<1>(dilation), std::get<0>(dilation)); + std::get<2>(dilation), std::get<1>(dilation), std::get<0>(dilation), + force_prec_f32); } }; diff --git a/src/ltx_vae.hpp b/src/ltx_vae.hpp index a4d47b50..8c41a51c 100644 --- a/src/ltx_vae.hpp +++ b/src/ltx_vae.hpp @@ -89,7 +89,8 @@ namespace LTXVAE { int kernel_size = 3, std::tuple stride = {1, 1, 1}, int dilation = 1, - bool bias = true) { + bool bias = true, + bool force_prec_f32 = false) { time_kernel_size = kernel_size; blocks["conv"] = std::shared_ptr(new Conv3d(in_channels, out_channels, @@ -97,7 +98,8 @@ namespace LTXVAE { stride, {0, kernel_size / 2, kernel_size / 2}, {dilation, 1, 1}, - bias)); + bias, + force_prec_f32)); } ggml_tensor* forward(GGMLRunnerContext* ctx, @@ -469,7 +471,8 @@ namespace LTXVAE { SpaceToDepthDownsample(int64_t in_channels, int64_t out_channels, int factor_t, - int factor_s) + int factor_s, + bool force_conv_prec_f32 = false) : in_channels(in_channels), out_channels(out_channels), factor_t(factor_t), @@ -477,7 +480,13 @@ namespace LTXVAE { const int64_t factor = static_cast(factor_t) * static_cast(factor_s) * static_cast(factor_s); GGML_ASSERT(out_channels % factor == 0); - blocks["conv"] = std::make_shared(in_channels, out_channels / factor, 3); + blocks["conv"] = std::make_shared(in_channels, + out_channels / factor, + 3, + std::tuple{1, 1, 1}, + 1, + true, + force_conv_prec_f32); blocks["skip_downsample"] = std::make_shared(in_channels, out_channels, factor_t, factor_s); blocks["conv_downsample"] = std::make_shared(out_channels / factor, out_channels, factor_t, factor_s); } @@ -492,7 +501,7 @@ namespace LTXVAE { if (factor_t > 1 && x->ne[2] > 0) { auto first_frame = ggml_ext_slice(ctx->ggml_ctx, x, 2, 0, 1); auto first_frame_pad = first_frame; - for (int i = 1; i < factor_t; ++i) { + for (int i = 1; i < factor_t - 1; ++i) { first_frame_pad = ggml_concat(ctx->ggml_ctx, first_frame_pad, first_frame, 2); } x = ggml_concat(ctx->ggml_ctx, first_frame_pad, x, 2); @@ -550,6 +559,8 @@ namespace LTXVAE { std::vector blocks; }; + static inline EncoderConfig get_default_encoder_config(int version); + static inline bool has_tensor(const String2TensorStorage& tensor_storage_map, const std::string& name) { return tensor_storage_map.find(name) != tensor_storage_map.end(); @@ -633,6 +644,84 @@ namespace LTXVAE { return cfg; } + static inline EncoderConfig infer_encoder_config_from_weights(const String2TensorStorage& tensor_storage_map, + const std::string& prefix, + int version) { + EncoderConfig cfg; + const std::string encoder_prefix = prefix + ".encoder.down_blocks."; + + int64_t current_channels = get_tensor_ne0(tensor_storage_map, + prefix + ".encoder.conv_in.conv.bias", + 0); + for (int block_idx = 0;; ++block_idx) { + const std::string block_prefix = encoder_prefix + std::to_string(block_idx); + const std::string res0_bias = block_prefix + ".res_blocks.0.conv1.conv.bias"; + const std::string conv_bias = block_prefix + ".conv.conv.bias"; + + if (has_tensor(tensor_storage_map, res0_bias)) { + int num_layers = 0; + while (has_tensor(tensor_storage_map, + block_prefix + ".res_blocks." + std::to_string(num_layers) + ".conv1.conv.bias")) { + num_layers++; + } + cfg.blocks.push_back({"res_x", num_layers, 1}); + current_channels = get_tensor_ne0(tensor_storage_map, res0_bias, current_channels); + continue; + } + + if (!has_tensor(tensor_storage_map, conv_bias)) { + break; + } + + const int64_t conv_channels = get_tensor_ne0(tensor_storage_map, conv_bias); + int64_t next_channels = 0; + for (int next_idx = block_idx + 1;; ++next_idx) { + const std::string next_res0_bias = encoder_prefix + std::to_string(next_idx) + ".res_blocks.0.conv1.conv.bias"; + const std::string next_conv_bias = encoder_prefix + std::to_string(next_idx) + ".conv.conv.bias"; + if (has_tensor(tensor_storage_map, next_res0_bias)) { + next_channels = get_tensor_ne0(tensor_storage_map, next_res0_bias); + break; + } + if (!has_tensor(tensor_storage_map, next_conv_bias)) { + break; + } + } + + const int64_t multiplier = current_channels > 0 && next_channels > 0 && next_channels % current_channels == 0 + ? std::max(1, next_channels / current_channels) + : 1; + const int64_t factor = conv_channels > 0 && next_channels > 0 && next_channels % conv_channels == 0 + ? next_channels / conv_channels + : 0; + + if (factor == 8) { + cfg.blocks.push_back({"compress_all_res", 0, static_cast(multiplier)}); + } else if (factor == 4) { + cfg.blocks.push_back({"compress_space_res", 0, static_cast(multiplier)}); + } else if (factor == 2) { + cfg.blocks.push_back({"compress_time_res", 0, static_cast(multiplier)}); + } else { + LOG_WARN("unexpected LTX VAE encoder downsample factor at '%s': conv_out=%lld current=%lld next=%lld, falling back to compress_all_res x%d", + block_prefix.c_str(), + (long long)conv_channels, + (long long)current_channels, + (long long)next_channels, + (int)multiplier); + cfg.blocks.push_back({"compress_all_res", 0, static_cast(multiplier)}); + } + if (next_channels > 0) { + current_channels = next_channels; + } else if (current_channels > 0) { + current_channels *= multiplier; + } + } + + if (cfg.blocks.empty()) { + return get_default_encoder_config(version); + } + return cfg; + } + static inline int detect_ltx_vae_version(const String2TensorStorage& tensor_storage_map, const std::string& prefix) { const std::string v2_probe = prefix + ".encoder.down_blocks.1.conv.conv.bias"; @@ -647,7 +736,7 @@ namespace LTXVAE { return tensor_storage_map.find(prefix + ".decoder.timestep_scale_multiplier") != tensor_storage_map.end(); } - static inline EncoderConfig get_encoder_config(int version) { + static inline EncoderConfig get_default_encoder_config(int version) { EncoderConfig cfg; if (version < 2) { GGML_ABORT("LTX VAE encoder is only implemented for version >= 2"); @@ -674,6 +763,8 @@ namespace LTXVAE { int64_t latent_channels; Encoder(int version, + const String2TensorStorage& tensor_storage_map, + const std::string& prefix, int patch_size = 4, int64_t in_channels = 3, int64_t latent_channels = 128) @@ -681,9 +772,12 @@ namespace LTXVAE { patch_size(patch_size), in_channels(in_channels), latent_channels(latent_channels) { - auto cfg = get_encoder_config(version); - int64_t channels = 128; - int64_t in_dim = in_channels * patch_size * patch_size; + auto cfg = infer_encoder_config_from_weights(tensor_storage_map, prefix, version); + int64_t channels = get_tensor_ne0(tensor_storage_map, + prefix + ".encoder.conv_in.conv.bias", + 0); + GGML_ASSERT(channels > 0); + int64_t in_dim = in_channels * patch_size * patch_size; blocks["conv_in"] = std::make_shared(in_dim, channels, 3); @@ -708,11 +802,14 @@ namespace LTXVAE { 1); channels = next_channels; } else if (block.type == "compress_all_res") { - int64_t next_channels = channels * block.multiplier; + int64_t next_channels = channels * block.multiplier; + // LTX 2.3 encoder down_blocks.7.conv overflows with fp16 accumulation. + bool force_conv_prec_f32 = block_idx == 7; blocks["down_blocks." + std::to_string(block_idx)] = std::make_shared(channels, next_channels, 2, - 2); + 2, + force_conv_prec_f32); channels = next_channels; } else { GGML_ABORT("Unsupported LTX VAE encoder block"); @@ -956,7 +1053,10 @@ namespace LTXVAE { patch_size(patch_size), decode_only(decode_only) { if (!decode_only) { - blocks["encoder"] = std::make_shared(version, patch_size); + blocks["encoder"] = std::make_shared(version, + tensor_storage_map, + prefix, + patch_size); } blocks["decoder"] = std::make_shared(version, tensor_storage_map, @@ -1096,7 +1196,7 @@ struct LTXVideoVAE : public VAE { const sd::Tensor& z, bool decode_graph) override { if (!decode_graph && decode_only) { - LOG_ERROR("LTX video VAE encoder is not implemented yet"); + LOG_ERROR("LTX video VAE encode requires encoder weights; create the context with vae_decode_only=false"); return {}; } sd::Tensor input = z; diff --git a/src/ltxv.hpp b/src/ltxv.hpp index 50142cb1..9b79abfc 100644 --- a/src/ltxv.hpp +++ b/src/ltxv.hpp @@ -23,12 +23,28 @@ namespace LTXV { return ggml_rms_norm(ctx, x, eps); } + __STATIC_INLINE__ ggml_tensor* align_token_modulation(ggml_context* ctx, + ggml_tensor* x, + ggml_tensor* mod) { + if (mod != nullptr && x != nullptr && mod->ne[1] == 1 && mod->ne[2] == x->ne[1] && x->ne[2] == 1) { + return ggml_permute(ctx, mod, 0, 2, 1, 3); + } + return mod; + } + + __STATIC_INLINE__ ggml_tensor* modulate(ggml_context* ctx, + ggml_tensor* x, + ggml_tensor* shift, + ggml_tensor* scale) { + shift = align_token_modulation(ctx, x, shift); + scale = align_token_modulation(ctx, x, scale); + return Flux::modulate(ctx, x, shift, scale, true); + } + __STATIC_INLINE__ ggml_tensor* apply_gate(ggml_context* ctx, ggml_tensor* x, ggml_tensor* gate) { - if (gate->ne[1] != 1) { - gate = ggml_reshape_3d(ctx, gate, gate->ne[0], 1, gate->ne[2]); - } + gate = align_token_modulation(ctx, x, gate); return ggml_mul(ctx, x, gate); } @@ -538,7 +554,7 @@ namespace LTXV { auto gate_mlp = mods[5]; auto x_norm = rms_norm(ctx->ggml_ctx, x); - x_norm = Flux::modulate(ctx->ggml_ctx, x_norm, shift_msa, scale_msa, true); + x_norm = modulate(ctx->ggml_ctx, x_norm, shift_msa, scale_msa); auto msa = attn1->forward(ctx, x_norm, nullptr, self_attention_mask, pe); x = ggml_add(ctx->ggml_ctx, x, apply_gate(ctx->ggml_ctx, msa, gate_msa)); @@ -548,12 +564,12 @@ namespace LTXV { auto gate_q = mods[8]; auto q = rms_norm(ctx->ggml_ctx, x); - q = Flux::modulate(ctx->ggml_ctx, q, shift_q, scale_q, true); + q = modulate(ctx->ggml_ctx, q, shift_q, scale_q); auto context_mod = context; if (prompt_timestep != nullptr) { auto prompt_mods = get_prompt_scale_shift_values(ctx, prompt_timestep); - context_mod = Flux::modulate(ctx->ggml_ctx, context_mod, prompt_mods[0], prompt_mods[1], true); + context_mod = modulate(ctx->ggml_ctx, context_mod, prompt_mods[0], prompt_mods[1]); } auto mca = attn2->forward(ctx, q, context_mod, attention_mask, nullptr, nullptr); @@ -564,7 +580,7 @@ namespace LTXV { } auto y = rms_norm(ctx->ggml_ctx, x); - y = Flux::modulate(ctx->ggml_ctx, y, shift_mlp, scale_mlp, true); + y = modulate(ctx->ggml_ctx, y, shift_mlp, scale_mlp); auto mlp_out = ff->forward(ctx, y); x = ggml_add(ctx->ggml_ctx, x, apply_gate(ctx->ggml_ctx, mlp_out, gate_mlp)); return x; @@ -947,11 +963,11 @@ namespace LTXV { if (cross_attention_adaln) { auto q_mods = get_ada_values(ctx, table, timestep, dim, 9, 6, 3); auto q = rms_norm(ctx->ggml_ctx, x); - q = Flux::modulate(ctx->ggml_ctx, q, q_mods[0], q_mods[1], true); + q = modulate(ctx->ggml_ctx, q, q_mods[0], q_mods[1]); auto context_mod = context; if (prompt_timestep != nullptr && prompt_table != nullptr) { auto p_mods = get_ada_values(ctx, prompt_table, prompt_timestep, dim, 2); - context_mod = Flux::modulate(ctx->ggml_ctx, context_mod, p_mods[0], p_mods[1], true); + context_mod = modulate(ctx->ggml_ctx, context_mod, p_mods[0], p_mods[1]); } auto out = attn->forward(ctx, q, context_mod, attention_mask, nullptr, nullptr); return apply_gate(ctx->ggml_ctx, out, q_mods[2]); @@ -998,7 +1014,7 @@ namespace LTXV { auto v_mods = get_ada_values(ctx, v_table, v_timestep, v_dim, cross_attention_adaln ? 9 : 6); auto v_norm = rms_norm(ctx->ggml_ctx, vx); - v_norm = Flux::modulate(ctx->ggml_ctx, v_norm, v_mods[0], v_mods[1], true); + v_norm = modulate(ctx->ggml_ctx, v_norm, v_mods[0], v_mods[1]); auto v_sa = attn1->forward(ctx, v_norm, nullptr, self_attention_mask, v_pe); vx = ggml_add(ctx->ggml_ctx, vx, apply_gate(ctx->ggml_ctx, v_sa, v_mods[2])); auto v_txt = apply_text_cross_attention(ctx, @@ -1016,7 +1032,7 @@ namespace LTXV { if (run_ax) { auto a_mods = get_ada_values(ctx, a_table, a_timestep, a_dim, cross_attention_adaln ? 9 : 6); auto a_norm = rms_norm(ctx->ggml_ctx, ax); - a_norm = Flux::modulate(ctx->ggml_ctx, a_norm, a_mods[0], a_mods[1], true); + a_norm = modulate(ctx->ggml_ctx, a_norm, a_mods[0], a_mods[1]); auto a_sa = audio_attn1->forward(ctx, a_norm, nullptr, nullptr, a_pe); ax = ggml_add(ctx->ggml_ctx, ax, apply_gate(ctx->ggml_ctx, a_sa, a_mods[2])); auto a_txt = apply_text_cross_attention(ctx, @@ -1039,8 +1055,8 @@ namespace LTXV { auto a2v_video_table = ggml_ext_slice(ctx->ggml_ctx, params["scale_shift_table_a2v_ca_video"], 1, 0, 4); auto a2v_audio = get_ada_values(ctx, a2v_audio_table, a_cross_scale_shift_timestep, a_dim, 4); auto a2v_video = get_ada_values(ctx, a2v_video_table, v_cross_scale_shift_timestep, v_dim, 4); - auto vx_scaled = Flux::modulate(ctx->ggml_ctx, vx_norm3, a2v_video[1], a2v_video[0], true); - auto ax_scaled = Flux::modulate(ctx->ggml_ctx, ax_norm3, a2v_audio[1], a2v_audio[0], true); + auto vx_scaled = modulate(ctx->ggml_ctx, vx_norm3, a2v_video[1], a2v_video[0]); + auto ax_scaled = modulate(ctx->ggml_ctx, ax_norm3, a2v_audio[1], a2v_audio[0]); auto a2v_out = audio_to_video_attn->forward(ctx, vx_scaled, ax_scaled, nullptr, v_cross_pe, a_cross_pe); auto a2v_gate_table = ggml_ext_slice(ctx->ggml_ctx, params["scale_shift_table_a2v_ca_video"], 1, 4, 5); auto a2v_gate = get_ada_values(ctx, a2v_gate_table, v_cross_gate_timestep, v_dim, 1)[0]; @@ -1052,8 +1068,8 @@ namespace LTXV { auto v2a_video_table = ggml_ext_slice(ctx->ggml_ctx, params["scale_shift_table_a2v_ca_video"], 1, 0, 4); auto v2a_audio = get_ada_values(ctx, v2a_audio_table, a_cross_scale_shift_timestep, a_dim, 4); auto v2a_video = get_ada_values(ctx, v2a_video_table, v_cross_scale_shift_timestep, v_dim, 4); - auto ax_scaled = Flux::modulate(ctx->ggml_ctx, ax_norm3, v2a_audio[3], v2a_audio[2], true); - auto vx_scaled = Flux::modulate(ctx->ggml_ctx, vx_norm3, v2a_video[3], v2a_video[2], true); + auto ax_scaled = modulate(ctx->ggml_ctx, ax_norm3, v2a_audio[3], v2a_audio[2]); + auto vx_scaled = modulate(ctx->ggml_ctx, vx_norm3, v2a_video[3], v2a_video[2]); auto v2a_out = video_to_audio_attn->forward(ctx, ax_scaled, vx_scaled, nullptr, a_cross_pe, v_cross_pe); auto v2a_gate_table = ggml_ext_slice(ctx->ggml_ctx, params["scale_shift_table_a2v_ca_audio"], 1, 4, 5); auto v2a_gate = get_ada_values(ctx, v2a_gate_table, a_cross_gate_timestep, a_dim, 1)[0]; @@ -1061,14 +1077,14 @@ namespace LTXV { } auto a_ff_mods = get_ada_values(ctx, a_table, a_timestep, a_dim, cross_attention_adaln ? 9 : 6, 3, 3); auto ax_scaled = rms_norm(ctx->ggml_ctx, ax); - ax_scaled = Flux::modulate(ctx->ggml_ctx, ax_scaled, a_ff_mods[0], a_ff_mods[1], true); + ax_scaled = modulate(ctx->ggml_ctx, ax_scaled, a_ff_mods[0], a_ff_mods[1]); auto a_ff_out = audio_ff->forward(ctx, ax_scaled); ax = ggml_add(ctx->ggml_ctx, ax, apply_gate(ctx->ggml_ctx, a_ff_out, a_ff_mods[2])); } auto v_ff_mods = get_ada_values(ctx, v_table, v_timestep, v_dim, cross_attention_adaln ? 9 : 6, 3, 3); auto vx_scaled = rms_norm(ctx->ggml_ctx, vx); - vx_scaled = Flux::modulate(ctx->ggml_ctx, vx_scaled, v_ff_mods[0], v_ff_mods[1], true); + vx_scaled = modulate(ctx->ggml_ctx, vx_scaled, v_ff_mods[0], v_ff_mods[1]); auto v_ff_out = ff->forward(ctx, vx_scaled); vx = ggml_add(ctx->ggml_ctx, vx, apply_gate(ctx->ggml_ctx, v_ff_out, v_ff_mods[2])); @@ -1188,6 +1204,15 @@ namespace LTXV { return ax; } + ggml_tensor* repeat_scalar_timestep_like(GGMLRunnerContext* ctx, ggml_tensor* timestep, ggml_tensor* like) { + GGML_ASSERT(timestep != nullptr && like != nullptr); + if (timestep->ne[0] == like->ne[0]) { + return timestep; + } + GGML_ASSERT(timestep->ne[0] == 1); + return ggml_repeat(ctx->ggml_ctx, timestep, ggml_new_tensor_1d(ctx->ggml_ctx, timestep->type, like->ne[0])); + } + ggml_tensor* unpatchify_audio(GGMLRunnerContext* ctx, ggml_tensor* ax, int64_t audio_length) { if (ax == nullptr) { return nullptr; @@ -1367,21 +1392,24 @@ namespace LTXV { if (cfg.cross_attention_adaln) { auto prompt_adaln_single = std::dynamic_pointer_cast(blocks["prompt_adaln_single"]); auto audio_prompt_adaln_single = std::dynamic_pointer_cast(blocks["audio_prompt_adaln_single"]); - v_prompt_timestep_mod = prompt_adaln_single->forward(ctx, v_timestep_scaled).first; + v_prompt_timestep_mod = prompt_adaln_single->forward(ctx, a_timestep_scaled).first; a_prompt_timestep_mod = audio_prompt_adaln_single->forward(ctx, a_timestep_scaled).first; } + auto av_ca_video_timestep = repeat_scalar_timestep_like(ctx, effective_audio_timestep, timestep); + auto av_ca_audio_timestep = effective_audio_timestep; + auto av_ca_factor = cfg.av_ca_timestep_scale_multiplier / cfg.timestep_scale_multiplier; auto av_ca_video_scale_shift_timestep = - std::dynamic_pointer_cast(blocks["av_ca_video_scale_shift_adaln_single"])->forward(ctx, a_timestep_scaled).first; + std::dynamic_pointer_cast(blocks["av_ca_video_scale_shift_adaln_single"])->forward(ctx, av_ca_video_timestep).first; auto av_ca_a2v_gate_noise_timestep = std::dynamic_pointer_cast(blocks["av_ca_a2v_gate_adaln_single"]) - ->forward(ctx, ggml_ext_scale(ctx->ggml_ctx, a_timestep_scaled, cfg.av_ca_timestep_scale_multiplier / cfg.timestep_scale_multiplier)) + ->forward(ctx, ggml_ext_scale(ctx->ggml_ctx, av_ca_video_timestep, av_ca_factor)) .first; auto av_ca_audio_scale_shift_timestep = - std::dynamic_pointer_cast(blocks["av_ca_audio_scale_shift_adaln_single"])->forward(ctx, v_timestep_scaled).first; + std::dynamic_pointer_cast(blocks["av_ca_audio_scale_shift_adaln_single"])->forward(ctx, av_ca_audio_timestep).first; auto av_ca_v2a_gate_noise_timestep = std::dynamic_pointer_cast(blocks["av_ca_v2a_gate_adaln_single"]) - ->forward(ctx, ggml_ext_scale(ctx->ggml_ctx, v_timestep_scaled, cfg.av_ca_timestep_scale_multiplier / cfg.timestep_scale_multiplier)) + ->forward(ctx, ggml_ext_scale(ctx->ggml_ctx, av_ca_audio_timestep, av_ca_factor)) .first; for (int i = 0; i < cfg.num_layers; i++) { @@ -1410,14 +1438,14 @@ namespace LTXV { auto v_shift_scale = get_output_scale_shift(ctx, params["scale_shift_table"], v_embedded_time, cfg.hidden_size); vx = norm_out->forward(ctx, vx); - vx = Flux::modulate(ctx->ggml_ctx, vx, v_shift_scale[0], v_shift_scale[1], true); + vx = modulate(ctx->ggml_ctx, vx, v_shift_scale[0], v_shift_scale[1]); vx = proj_out->forward(ctx, vx); vx = unpatchify_video(ctx, vx, width, height, frames); if (ax != nullptr && audio_time > 0) { auto a_shift_scale = get_output_scale_shift(ctx, params["audio_scale_shift_table"], a_embedded_time, cfg.audio_hidden_size); ax = audio_norm_out->forward(ctx, ax); - ax = Flux::modulate(ctx->ggml_ctx, ax, a_shift_scale[0], a_shift_scale[1], true); + ax = modulate(ctx->ggml_ctx, ax, a_shift_scale[0], a_shift_scale[1]); ax = audio_proj_out->forward(ctx, ax); ax = unpatchify_audio(ctx, ax, audio_time); } diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index ac2ff224..3491cadb 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -458,10 +458,6 @@ public: // Might need vae encode for control cond vae_decode_only = false; } - if (sd_version_is_ltxav(version)) { - vae_decode_only = true; - } - bool tae_preview_only = sd_ctx_params->tae_preview_only; if (version == VERSION_SDXS_512_DS || version == VERSION_SDXS_09) { tae_preview_only = false; @@ -705,7 +701,7 @@ public: params_backend_for(SDBackendModule::VAE), tensor_storage_map, "first_stage_model", - true, + vae_decode_only, version); } else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || @@ -1570,6 +1566,38 @@ public: } } + std::vector process_ltxav_video_timesteps(const std::vector& timesteps, + const sd::Tensor& init_latent, + const sd::Tensor& denoise_mask) { + if (timesteps.empty() || denoise_mask.empty() || init_latent.dim() < 4 || denoise_mask.dim() < 4) { + return timesteps; + } + + int64_t width = init_latent.shape()[0]; + int64_t height = init_latent.shape()[1]; + int64_t frames = init_latent.shape()[2]; + if (denoise_mask.shape()[0] != width || + denoise_mask.shape()[1] != height || + denoise_mask.shape()[2] != frames || + denoise_mask.shape()[3] < 1) { + LOG_WARN("unexpected LTXAV denoise mask shape for timestep processing"); + return timesteps; + } + + std::vector video_timesteps(static_cast(width * height * frames)); + size_t idx = 0; + for (int64_t t = 0; t < frames; ++t) { + for (int64_t h = 0; h < height; ++h) { + for (int64_t w = 0; w < width; ++w) { + float mask = denoise_mask.dim() == 5 ? denoise_mask.index(w, h, t, 0, 0) + : denoise_mask.index(w, h, t, 0); + video_timesteps[idx++] = mask * timesteps[0]; + } + } + } + return video_timesteps; + } + void preview_image(int step, const sd::Tensor& latents, enum SDVersion version, @@ -1846,14 +1874,24 @@ public: float c_out = scaling[1]; float c_in = scaling[2]; - std::vector timesteps_vec = prepare_sample_timesteps(sigma, shifted_timestep); - timesteps_vec = process_timesteps(timesteps_vec, init_latent, denoise_mask); - adjust_sample_step_scalings(shifted_timestep, timesteps_vec, c_in, &c_skip, &c_out); + std::vector base_timesteps_vec = prepare_sample_timesteps(sigma, shifted_timestep); + std::vector timesteps_vec = base_timesteps_vec; + sd::Tensor audio_timesteps_tensor; + if (sd_version_is_ltxav(version) && !denoise_mask.empty()) { + timesteps_vec = process_ltxav_video_timesteps(base_timesteps_vec, init_latent, denoise_mask); + audio_timesteps_tensor = sd::Tensor({static_cast(base_timesteps_vec.size())}, base_timesteps_vec); + } else { + timesteps_vec = process_timesteps(timesteps_vec, init_latent, denoise_mask); + } + const std::vector& scaling_timesteps_vec = (sd_version_is_ltxav(version) && !denoise_mask.empty()) + ? base_timesteps_vec + : timesteps_vec; + adjust_sample_step_scalings(shifted_timestep, scaling_timesteps_vec, c_in, &c_skip, &c_out); sd::Tensor timesteps_tensor({static_cast(timesteps_vec.size())}, timesteps_vec); sd::Tensor guidance_tensor({1}, std::vector{guidance.distilled_guidance}); sd::Tensor noised_input = x * c_in; - if (!denoise_mask.empty() && version == VERSION_WAN2_2_TI2V) { + if (!denoise_mask.empty() && (version == VERSION_WAN2_2_TI2V || sd_version_is_ltxav(version))) { noised_input = noised_input * denoise_mask + init_latent * (1.0f - denoise_mask); } @@ -1884,6 +1922,7 @@ public: DiffusionParams diffusion_params; diffusion_params.x = &noised_input; diffusion_params.timesteps = ×teps_tensor; + diffusion_params.audio_timesteps = audio_timesteps_tensor.empty() ? nullptr : &audio_timesteps_tensor; diffusion_params.guidance = &guidance_tensor; diffusion_params.ref_latents = &ref_latents; diffusion_params.increase_ref_index = increase_ref_index; @@ -2916,6 +2955,7 @@ struct GenerationRequest { vae_scale_factor = sd_ctx->sd->get_vae_scale_factor(); diffusion_model_down_factor = sd_ctx->sd->get_diffusion_model_down_factor(); seed = sd_vid_gen_params->seed; + strength = sd_vid_gen_params->strength; cache_params = &sd_vid_gen_params->cache; vace_strength = sd_vid_gen_params->vace_strength; guidance = sd_vid_gen_params->sample_params.guidance; @@ -3204,6 +3244,55 @@ static sd::Tensor pack_ltxav_audio_and_video_latents(const sd::Tensor pack_ltxav_audio_and_video_denoise_mask(const sd::Tensor& video_mask, + const sd::Tensor& video_latent, + const sd::Tensor& audio_latent) { + if (video_mask.empty() || audio_latent.empty()) { + return video_mask; + } + + GGML_ASSERT(video_latent.dim() == 4 || video_latent.dim() == 5); + GGML_ASSERT(audio_latent.dim() == 3 || audio_latent.dim() == 4); + if (video_latent.dim() == 5) { + GGML_ASSERT(video_latent.shape()[4] == 1); + } + if (audio_latent.dim() == 4) { + GGML_ASSERT(audio_latent.shape()[3] == 1); + } + + int64_t width = video_latent.shape()[0]; + int64_t height = video_latent.shape()[1]; + int64_t frames = video_latent.shape()[2]; + int64_t video_ch = video_latent.shape()[3]; + int64_t spatial_size = width * height * frames; + int64_t audio_values = audio_latent.numel(); + int64_t extra_ch = (audio_values + spatial_size - 1) / spatial_size; + + GGML_ASSERT(video_mask.dim() == video_latent.dim()); + GGML_ASSERT(video_mask.shape()[0] == width); + GGML_ASSERT(video_mask.shape()[1] == height); + GGML_ASSERT(video_mask.shape()[2] == frames); + if (video_mask.dim() == 5) { + GGML_ASSERT(video_mask.shape()[4] == video_latent.shape()[4]); + } + + int64_t mask_ch = video_mask.shape()[3]; + if (mask_ch == video_ch + extra_ch) { + return video_mask; + } + GGML_ASSERT(mask_ch == 1 || mask_ch == video_ch); + + sd::Tensor video_mask_full = video_mask; + if (mask_ch == 1 && video_ch != 1) { + video_mask_full = video_mask * sd::Tensor::ones(video_latent.shape()); + } + + std::vector audio_mask_shape = video_latent.shape(); + audio_mask_shape[3] = extra_ch; + auto audio_mask = sd::Tensor::ones(audio_mask_shape); + return sd::ops::concat(video_mask_full, audio_mask, 3); +} + static sd::Tensor unpack_ltxav_audio_latent(const sd::Tensor& packed_latent, int audio_length, int video_channels) { @@ -4030,10 +4119,47 @@ static std::optional prepare_video_generation_latents(sd } if (sd_version_is_ltxav(sd_ctx->sd->version)) { - if (!start_image.empty() || !end_image.empty() || sd_vid_gen_params->control_frames_size > 0) { - LOG_ERROR("LTXAV currently supports txt2vid only; init_image, end_image, and control_frames are not implemented"); + if (!end_image.empty() || sd_vid_gen_params->control_frames_size > 0) { + LOG_ERROR("LTXAV currently supports txt2vid and init_image i2v only; end_image and control_frames are not implemented"); return std::nullopt; } + + if (!start_image.empty()) { + if (sd_ctx->sd->vae_decode_only) { + LOG_ERROR("LTXAV init_image i2v requires VAE encoder weights; create the context with vae_decode_only=false"); + return std::nullopt; + } + + LOG_INFO("IMG2VID"); + + int64_t t1 = ggml_time_ms(); + auto init_img = start_image.reshape({start_image.shape()[0], + start_image.shape()[1], + 1, + start_image.shape()[2], + start_image.shape()[3]}); + auto init_image_latent = sd_ctx->sd->encode_first_stage(init_img); + if (init_image_latent.empty()) { + LOG_ERROR("failed to encode LTXAV init image"); + return std::nullopt; + } + + latents.init_latent = sd_ctx->sd->generate_init_latent(request->width, request->height, request->frames, true); + sd::ops::slice_assign(&latents.init_latent, 2, 0, init_image_latent.shape()[2], init_image_latent); + + float conditioning_strength = std::clamp(request->strength, 0.f, 1.f); + float conditioned_mask = 1.0f - conditioning_strength; + latents.denoise_mask = sd::full({latents.init_latent.shape()[0], + latents.init_latent.shape()[1], + latents.init_latent.shape()[2], + 1, + 1}, + 1.f); + sd::ops::fill_slice(&latents.denoise_mask, 2, 0, init_image_latent.shape()[2], conditioned_mask); + + int64_t t2 = ggml_time_ms(); + LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1); + } } if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" || @@ -4207,6 +4333,11 @@ static std::optional prepare_video_generation_latents(sd } if (sd_version_is_ltxav(sd_ctx->sd->version) && !latents.audio_latent.empty()) { + if (!latents.denoise_mask.empty()) { + latents.denoise_mask = pack_ltxav_audio_and_video_denoise_mask(latents.denoise_mask, + latents.init_latent, + latents.audio_latent); + } latents.init_latent = pack_ltxav_audio_and_video_latents(latents.init_latent, latents.audio_latent); }