add i2v support

This commit is contained in:
leejet 2026-05-17 03:16:36 +08:00
parent f8a0330d37
commit 18fbb4cdfb
4 changed files with 341 additions and 63 deletions

View File

@ -1135,10 +1135,25 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_conv_3d(ggml_context* ctx,
int p2 = 0, int p2 = 0,
int d0 = 1, int d0 = 1,
int d1 = 1, int d1 = 1,
int d2 = 1) { int d2 = 1,
bool force_prec_f32 = false) {
if (force_prec_f32) {
ggml_tensor* im2col = ggml_im2col_3d(ctx, w, x, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, w->type);
int64_t OC = w->ne[3] / IC; int64_t OC = w->ne[3] / IC;
int64_t N = x->ne[3] / IC; int64_t N = x->ne[3] / IC;
x = ggml_mul_mat(ctx,
ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]),
ggml_reshape_2d(ctx, w, w->ne[0] * w->ne[1] * w->ne[2] * IC, OC));
ggml_mul_mat_set_prec(x, GGML_PREC_F32);
int64_t OD = im2col->ne[3] / N;
x = ggml_reshape_4d(ctx, x, im2col->ne[1] * im2col->ne[2], OD, N, OC);
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 1, 3, 2));
x = ggml_reshape_4d(ctx, x, im2col->ne[1], im2col->ne[2], OD, OC * N);
} else {
x = ggml_conv_3d(ctx, w, x, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2); x = ggml_conv_3d(ctx, w, x, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2);
}
if (b != nullptr) { if (b != nullptr) {
b = ggml_reshape_4d(ctx, b, 1, 1, 1, b->ne[0]); // [OC, 1, 1, 1] b = ggml_reshape_4d(ctx, b, 1, 1, 1, b->ne[0]); // [OC, 1, 1, 1]
@ -3133,6 +3148,7 @@ protected:
std::tuple<int, int, int> padding; std::tuple<int, int, int> padding;
std::tuple<int, int, int> dilation; std::tuple<int, int, int> dilation;
bool bias; bool bias;
bool force_prec_f32;
std::string prefix; std::string prefix;
void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override { void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override {
@ -3156,14 +3172,16 @@ public:
std::tuple<int, int, int> stride = {1, 1, 1}, std::tuple<int, int, int> stride = {1, 1, 1},
std::tuple<int, int, int> padding = {0, 0, 0}, std::tuple<int, int, int> padding = {0, 0, 0},
std::tuple<int, int, int> dilation = {1, 1, 1}, std::tuple<int, int, int> dilation = {1, 1, 1},
bool bias = true) bool bias = true,
bool force_prec_f32 = false)
: in_channels(in_channels), : in_channels(in_channels),
out_channels(out_channels), out_channels(out_channels),
kernel_size(kernel_size), kernel_size(kernel_size),
stride(stride), stride(stride),
padding(padding), padding(padding),
dilation(dilation), dilation(dilation),
bias(bias) {} bias(bias),
force_prec_f32(force_prec_f32) {}
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
ggml_tensor* w = params["weight"]; ggml_tensor* w = params["weight"];
@ -3183,7 +3201,8 @@ public:
return ggml_ext_conv_3d(ctx->ggml_ctx, x, w, b, in_channels, return ggml_ext_conv_3d(ctx->ggml_ctx, x, w, b, in_channels,
std::get<2>(stride), std::get<1>(stride), std::get<0>(stride), std::get<2>(stride), std::get<1>(stride), std::get<0>(stride),
std::get<2>(padding), std::get<1>(padding), std::get<0>(padding), std::get<2>(padding), std::get<1>(padding), std::get<0>(padding),
std::get<2>(dilation), std::get<1>(dilation), std::get<0>(dilation)); std::get<2>(dilation), std::get<1>(dilation), std::get<0>(dilation),
force_prec_f32);
} }
}; };

View File

@ -89,7 +89,8 @@ namespace LTXVAE {
int kernel_size = 3, int kernel_size = 3,
std::tuple<int, int, int> stride = {1, 1, 1}, std::tuple<int, int, int> stride = {1, 1, 1},
int dilation = 1, int dilation = 1,
bool bias = true) { bool bias = true,
bool force_prec_f32 = false) {
time_kernel_size = kernel_size; time_kernel_size = kernel_size;
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv3d(in_channels, blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv3d(in_channels,
out_channels, out_channels,
@ -97,7 +98,8 @@ namespace LTXVAE {
stride, stride,
{0, kernel_size / 2, kernel_size / 2}, {0, kernel_size / 2, kernel_size / 2},
{dilation, 1, 1}, {dilation, 1, 1},
bias)); bias,
force_prec_f32));
} }
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* forward(GGMLRunnerContext* ctx,
@ -469,7 +471,8 @@ namespace LTXVAE {
SpaceToDepthDownsample(int64_t in_channels, SpaceToDepthDownsample(int64_t in_channels,
int64_t out_channels, int64_t out_channels,
int factor_t, int factor_t,
int factor_s) int factor_s,
bool force_conv_prec_f32 = false)
: in_channels(in_channels), : in_channels(in_channels),
out_channels(out_channels), out_channels(out_channels),
factor_t(factor_t), factor_t(factor_t),
@ -477,7 +480,13 @@ namespace LTXVAE {
const int64_t factor = static_cast<int64_t>(factor_t) * static_cast<int64_t>(factor_s) * static_cast<int64_t>(factor_s); const int64_t factor = static_cast<int64_t>(factor_t) * static_cast<int64_t>(factor_s) * static_cast<int64_t>(factor_s);
GGML_ASSERT(out_channels % factor == 0); GGML_ASSERT(out_channels % factor == 0);
blocks["conv"] = std::make_shared<CausalConv3d>(in_channels, out_channels / factor, 3); blocks["conv"] = std::make_shared<CausalConv3d>(in_channels,
out_channels / factor,
3,
std::tuple<int, int, int>{1, 1, 1},
1,
true,
force_conv_prec_f32);
blocks["skip_downsample"] = std::make_shared<WAN::AvgDown3D>(in_channels, out_channels, factor_t, factor_s); blocks["skip_downsample"] = std::make_shared<WAN::AvgDown3D>(in_channels, out_channels, factor_t, factor_s);
blocks["conv_downsample"] = std::make_shared<WAN::AvgDown3D>(out_channels / factor, out_channels, factor_t, factor_s); blocks["conv_downsample"] = std::make_shared<WAN::AvgDown3D>(out_channels / factor, out_channels, factor_t, factor_s);
} }
@ -492,7 +501,7 @@ namespace LTXVAE {
if (factor_t > 1 && x->ne[2] > 0) { if (factor_t > 1 && x->ne[2] > 0) {
auto first_frame = ggml_ext_slice(ctx->ggml_ctx, x, 2, 0, 1); auto first_frame = ggml_ext_slice(ctx->ggml_ctx, x, 2, 0, 1);
auto first_frame_pad = first_frame; auto first_frame_pad = first_frame;
for (int i = 1; i < factor_t; ++i) { for (int i = 1; i < factor_t - 1; ++i) {
first_frame_pad = ggml_concat(ctx->ggml_ctx, first_frame_pad, first_frame, 2); first_frame_pad = ggml_concat(ctx->ggml_ctx, first_frame_pad, first_frame, 2);
} }
x = ggml_concat(ctx->ggml_ctx, first_frame_pad, x, 2); x = ggml_concat(ctx->ggml_ctx, first_frame_pad, x, 2);
@ -550,6 +559,8 @@ namespace LTXVAE {
std::vector<Block> blocks; std::vector<Block> blocks;
}; };
// Forward declaration: infer_encoder_config_from_weights() calls this as its
// fallback, but the definition only appears further down in this namespace.
static inline EncoderConfig get_default_encoder_config(int version);
static inline bool has_tensor(const String2TensorStorage& tensor_storage_map, static inline bool has_tensor(const String2TensorStorage& tensor_storage_map,
const std::string& name) { const std::string& name) {
return tensor_storage_map.find(name) != tensor_storage_map.end(); return tensor_storage_map.find(name) != tensor_storage_map.end();
@ -633,6 +644,84 @@ namespace LTXVAE {
return cfg; return cfg;
} }
// Rebuilds the LTX VAE encoder block layout by probing which tensors exist under
// "<prefix>.encoder.down_blocks.<i>", walking i = 0, 1, 2, ... until an index has
// neither of the probed tensors.
//
//  - An index exposing "res_blocks.0.conv1.conv.bias" yields a "res_x" entry whose
//    layer count is the number of consecutive res_blocks.<k> found.
//  - An index exposing "conv.conv.bias" yields a downsample entry. Its kind is picked
//    from next_channels / conv_channels, where next_channels is read from the first
//    later res block: 8 -> compress_all_res, 4 -> compress_space_res,
//    2 -> compress_time_res. Any other ratio logs a warning and falls back to
//    compress_all_res.
//
// Returns get_default_encoder_config(version) when no block could be discovered.
static inline EncoderConfig infer_encoder_config_from_weights(const String2TensorStorage& tensor_storage_map,
                                                              const std::string& prefix,
                                                              int version) {
    const std::string down_blocks_prefix = prefix + ".encoder.down_blocks.";
    const std::string res0_suffix        = ".res_blocks.0.conv1.conv.bias";
    const std::string conv_suffix        = ".conv.conv.bias";

    EncoderConfig cfg;
    // Channel count entering the block currently being inspected (0 = unknown).
    int64_t channels = get_tensor_ne0(tensor_storage_map,
                                      prefix + ".encoder.conv_in.conv.bias",
                                      0);

    for (int block_idx = 0;; ++block_idx) {
        const std::string block_prefix = down_blocks_prefix + std::to_string(block_idx);
        const std::string res0_bias    = block_prefix + res0_suffix;
        const std::string conv_bias    = block_prefix + conv_suffix;

        if (has_tensor(tensor_storage_map, res0_bias)) {
            // Layer 0 is known to exist; keep counting until the first missing layer.
            int num_layers = 0;
            do {
                ++num_layers;
            } while (has_tensor(tensor_storage_map,
                                block_prefix + ".res_blocks." + std::to_string(num_layers) + ".conv1.conv.bias"));

            cfg.blocks.push_back({"res_x", num_layers, 1});
            channels = get_tensor_ne0(tensor_storage_map, res0_bias, channels);
            continue;
        }

        if (!has_tensor(tensor_storage_map, conv_bias)) {
            // Neither a res block nor a downsample conv: end of the encoder stack.
            break;
        }

        const int64_t conv_channels = get_tensor_ne0(tensor_storage_map, conv_bias);

        // Look ahead, skipping over consecutive downsample blocks, for the next res
        // block; its channel count tells us what this downsample must produce.
        int64_t next_channels = 0;
        for (int probe_idx = block_idx + 1;; ++probe_idx) {
            const std::string probe_prefix   = down_blocks_prefix + std::to_string(probe_idx);
            const std::string probe_res0_bias = probe_prefix + res0_suffix;
            if (has_tensor(tensor_storage_map, probe_res0_bias)) {
                next_channels = get_tensor_ne0(tensor_storage_map, probe_res0_bias);
                break;
            }
            if (!has_tensor(tensor_storage_map, probe_prefix + conv_suffix)) {
                break;
            }
        }

        // Channel growth across this block; stays 1 when it cannot be derived.
        int64_t multiplier = 1;
        if (channels > 0 && next_channels > 0 && next_channels % channels == 0) {
            multiplier = std::max<int64_t>(1, next_channels / channels);
        }

        // Ratio between the block output and the conv output; 0 when it cannot be derived.
        int64_t factor = 0;
        if (conv_channels > 0 && next_channels > 0 && next_channels % conv_channels == 0) {
            factor = next_channels / conv_channels;
        }

        const char* downsample_type = "compress_all_res";
        switch (factor) {
            case 8:
                downsample_type = "compress_all_res";
                break;
            case 4:
                downsample_type = "compress_space_res";
                break;
            case 2:
                downsample_type = "compress_time_res";
                break;
            default:
                LOG_WARN("unexpected LTX VAE encoder downsample factor at '%s': conv_out=%lld current=%lld next=%lld, falling back to compress_all_res x%d",
                         block_prefix.c_str(),
                         static_cast<long long>(conv_channels),
                         static_cast<long long>(channels),
                         static_cast<long long>(next_channels),
                         static_cast<int>(multiplier));
                downsample_type = "compress_all_res";
                break;
        }
        cfg.blocks.push_back({downsample_type, 0, static_cast<int>(multiplier)});

        if (next_channels > 0) {
            channels = next_channels;
        } else if (channels > 0) {
            channels *= multiplier;
        }
    }

    if (cfg.blocks.empty()) {
        return get_default_encoder_config(version);
    }
    return cfg;
}
static inline int detect_ltx_vae_version(const String2TensorStorage& tensor_storage_map, static inline int detect_ltx_vae_version(const String2TensorStorage& tensor_storage_map,
const std::string& prefix) { const std::string& prefix) {
const std::string v2_probe = prefix + ".encoder.down_blocks.1.conv.conv.bias"; const std::string v2_probe = prefix + ".encoder.down_blocks.1.conv.conv.bias";
@ -647,7 +736,7 @@ namespace LTXVAE {
return tensor_storage_map.find(prefix + ".decoder.timestep_scale_multiplier") != tensor_storage_map.end(); return tensor_storage_map.find(prefix + ".decoder.timestep_scale_multiplier") != tensor_storage_map.end();
} }
static inline EncoderConfig get_encoder_config(int version) { static inline EncoderConfig get_default_encoder_config(int version) {
EncoderConfig cfg; EncoderConfig cfg;
if (version < 2) { if (version < 2) {
GGML_ABORT("LTX VAE encoder is only implemented for version >= 2"); GGML_ABORT("LTX VAE encoder is only implemented for version >= 2");
@ -674,6 +763,8 @@ namespace LTXVAE {
int64_t latent_channels; int64_t latent_channels;
Encoder(int version, Encoder(int version,
const String2TensorStorage& tensor_storage_map,
const std::string& prefix,
int patch_size = 4, int patch_size = 4,
int64_t in_channels = 3, int64_t in_channels = 3,
int64_t latent_channels = 128) int64_t latent_channels = 128)
@ -681,8 +772,11 @@ namespace LTXVAE {
patch_size(patch_size), patch_size(patch_size),
in_channels(in_channels), in_channels(in_channels),
latent_channels(latent_channels) { latent_channels(latent_channels) {
auto cfg = get_encoder_config(version); auto cfg = infer_encoder_config_from_weights(tensor_storage_map, prefix, version);
int64_t channels = 128; int64_t channels = get_tensor_ne0(tensor_storage_map,
prefix + ".encoder.conv_in.conv.bias",
0);
GGML_ASSERT(channels > 0);
int64_t in_dim = in_channels * patch_size * patch_size; int64_t in_dim = in_channels * patch_size * patch_size;
blocks["conv_in"] = std::make_shared<CausalConv3d>(in_dim, channels, 3); blocks["conv_in"] = std::make_shared<CausalConv3d>(in_dim, channels, 3);
@ -709,10 +803,13 @@ namespace LTXVAE {
channels = next_channels; channels = next_channels;
} else if (block.type == "compress_all_res") { } else if (block.type == "compress_all_res") {
int64_t next_channels = channels * block.multiplier; int64_t next_channels = channels * block.multiplier;
// LTX 2.3 encoder down_blocks.7.conv overflows with fp16 accumulation.
bool force_conv_prec_f32 = block_idx == 7;
blocks["down_blocks." + std::to_string(block_idx)] = std::make_shared<SpaceToDepthDownsample>(channels, blocks["down_blocks." + std::to_string(block_idx)] = std::make_shared<SpaceToDepthDownsample>(channels,
next_channels, next_channels,
2, 2,
2); 2,
force_conv_prec_f32);
channels = next_channels; channels = next_channels;
} else { } else {
GGML_ABORT("Unsupported LTX VAE encoder block"); GGML_ABORT("Unsupported LTX VAE encoder block");
@ -956,7 +1053,10 @@ namespace LTXVAE {
patch_size(patch_size), patch_size(patch_size),
decode_only(decode_only) { decode_only(decode_only) {
if (!decode_only) { if (!decode_only) {
blocks["encoder"] = std::make_shared<Encoder>(version, patch_size); blocks["encoder"] = std::make_shared<Encoder>(version,
tensor_storage_map,
prefix,
patch_size);
} }
blocks["decoder"] = std::make_shared<Decoder>(version, blocks["decoder"] = std::make_shared<Decoder>(version,
tensor_storage_map, tensor_storage_map,
@ -1096,7 +1196,7 @@ struct LTXVideoVAE : public VAE {
const sd::Tensor<float>& z, const sd::Tensor<float>& z,
bool decode_graph) override { bool decode_graph) override {
if (!decode_graph && decode_only) { if (!decode_graph && decode_only) {
LOG_ERROR("LTX video VAE encoder is not implemented yet"); LOG_ERROR("LTX video VAE encode requires encoder weights; create the context with vae_decode_only=false");
return {}; return {};
} }
sd::Tensor<float> input = z; sd::Tensor<float> input = z;

View File

@ -23,12 +23,28 @@ namespace LTXV {
return ggml_rms_norm(ctx, x, eps); return ggml_rms_norm(ctx, x, eps);
} }
// Lines a modulation tensor up with x when the two disagree on which axis holds
// the per-token dimension.
// If mod has ne[1] == 1 and ne[2] == x->ne[1] while x->ne[2] == 1, mod is returned
// with axes 1 and 2 swapped. In every other case, including a null mod or x,
// mod is returned untouched.
__STATIC_INLINE__ ggml_tensor* align_token_modulation(ggml_context* ctx,
                                                      ggml_tensor* x,
                                                      ggml_tensor* mod) {
    if (mod == nullptr || x == nullptr) {
        return mod;
    }
    const bool tokens_on_axis2 = mod->ne[1] == 1 && mod->ne[2] == x->ne[1] && x->ne[2] == 1;
    return tokens_on_axis2 ? ggml_permute(ctx, mod, 0, 2, 1, 3) : mod;
}
// LTXV wrapper around Flux::modulate: shift and scale are first passed through
// align_token_modulation() against x, then forwarded with the trailing flag
// fixed to true.
__STATIC_INLINE__ ggml_tensor* modulate(ggml_context* ctx,
                                        ggml_tensor* x,
                                        ggml_tensor* shift,
                                        ggml_tensor* scale) {
    ggml_tensor* aligned_shift = align_token_modulation(ctx, x, shift);
    ggml_tensor* aligned_scale = align_token_modulation(ctx, x, scale);
    return Flux::modulate(ctx, x, aligned_shift, aligned_scale, true);
}
__STATIC_INLINE__ ggml_tensor* apply_gate(ggml_context* ctx, __STATIC_INLINE__ ggml_tensor* apply_gate(ggml_context* ctx,
ggml_tensor* x, ggml_tensor* x,
ggml_tensor* gate) { ggml_tensor* gate) {
if (gate->ne[1] != 1) { gate = align_token_modulation(ctx, x, gate);
gate = ggml_reshape_3d(ctx, gate, gate->ne[0], 1, gate->ne[2]);
}
return ggml_mul(ctx, x, gate); return ggml_mul(ctx, x, gate);
} }
@ -538,7 +554,7 @@ namespace LTXV {
auto gate_mlp = mods[5]; auto gate_mlp = mods[5];
auto x_norm = rms_norm(ctx->ggml_ctx, x); auto x_norm = rms_norm(ctx->ggml_ctx, x);
x_norm = Flux::modulate(ctx->ggml_ctx, x_norm, shift_msa, scale_msa, true); x_norm = modulate(ctx->ggml_ctx, x_norm, shift_msa, scale_msa);
auto msa = attn1->forward(ctx, x_norm, nullptr, self_attention_mask, pe); auto msa = attn1->forward(ctx, x_norm, nullptr, self_attention_mask, pe);
x = ggml_add(ctx->ggml_ctx, x, apply_gate(ctx->ggml_ctx, msa, gate_msa)); x = ggml_add(ctx->ggml_ctx, x, apply_gate(ctx->ggml_ctx, msa, gate_msa));
@ -548,12 +564,12 @@ namespace LTXV {
auto gate_q = mods[8]; auto gate_q = mods[8];
auto q = rms_norm(ctx->ggml_ctx, x); auto q = rms_norm(ctx->ggml_ctx, x);
q = Flux::modulate(ctx->ggml_ctx, q, shift_q, scale_q, true); q = modulate(ctx->ggml_ctx, q, shift_q, scale_q);
auto context_mod = context; auto context_mod = context;
if (prompt_timestep != nullptr) { if (prompt_timestep != nullptr) {
auto prompt_mods = get_prompt_scale_shift_values(ctx, prompt_timestep); auto prompt_mods = get_prompt_scale_shift_values(ctx, prompt_timestep);
context_mod = Flux::modulate(ctx->ggml_ctx, context_mod, prompt_mods[0], prompt_mods[1], true); context_mod = modulate(ctx->ggml_ctx, context_mod, prompt_mods[0], prompt_mods[1]);
} }
auto mca = attn2->forward(ctx, q, context_mod, attention_mask, nullptr, nullptr); auto mca = attn2->forward(ctx, q, context_mod, attention_mask, nullptr, nullptr);
@ -564,7 +580,7 @@ namespace LTXV {
} }
auto y = rms_norm(ctx->ggml_ctx, x); auto y = rms_norm(ctx->ggml_ctx, x);
y = Flux::modulate(ctx->ggml_ctx, y, shift_mlp, scale_mlp, true); y = modulate(ctx->ggml_ctx, y, shift_mlp, scale_mlp);
auto mlp_out = ff->forward(ctx, y); auto mlp_out = ff->forward(ctx, y);
x = ggml_add(ctx->ggml_ctx, x, apply_gate(ctx->ggml_ctx, mlp_out, gate_mlp)); x = ggml_add(ctx->ggml_ctx, x, apply_gate(ctx->ggml_ctx, mlp_out, gate_mlp));
return x; return x;
@ -947,11 +963,11 @@ namespace LTXV {
if (cross_attention_adaln) { if (cross_attention_adaln) {
auto q_mods = get_ada_values(ctx, table, timestep, dim, 9, 6, 3); auto q_mods = get_ada_values(ctx, table, timestep, dim, 9, 6, 3);
auto q = rms_norm(ctx->ggml_ctx, x); auto q = rms_norm(ctx->ggml_ctx, x);
q = Flux::modulate(ctx->ggml_ctx, q, q_mods[0], q_mods[1], true); q = modulate(ctx->ggml_ctx, q, q_mods[0], q_mods[1]);
auto context_mod = context; auto context_mod = context;
if (prompt_timestep != nullptr && prompt_table != nullptr) { if (prompt_timestep != nullptr && prompt_table != nullptr) {
auto p_mods = get_ada_values(ctx, prompt_table, prompt_timestep, dim, 2); auto p_mods = get_ada_values(ctx, prompt_table, prompt_timestep, dim, 2);
context_mod = Flux::modulate(ctx->ggml_ctx, context_mod, p_mods[0], p_mods[1], true); context_mod = modulate(ctx->ggml_ctx, context_mod, p_mods[0], p_mods[1]);
} }
auto out = attn->forward(ctx, q, context_mod, attention_mask, nullptr, nullptr); auto out = attn->forward(ctx, q, context_mod, attention_mask, nullptr, nullptr);
return apply_gate(ctx->ggml_ctx, out, q_mods[2]); return apply_gate(ctx->ggml_ctx, out, q_mods[2]);
@ -998,7 +1014,7 @@ namespace LTXV {
auto v_mods = get_ada_values(ctx, v_table, v_timestep, v_dim, cross_attention_adaln ? 9 : 6); auto v_mods = get_ada_values(ctx, v_table, v_timestep, v_dim, cross_attention_adaln ? 9 : 6);
auto v_norm = rms_norm(ctx->ggml_ctx, vx); auto v_norm = rms_norm(ctx->ggml_ctx, vx);
v_norm = Flux::modulate(ctx->ggml_ctx, v_norm, v_mods[0], v_mods[1], true); v_norm = modulate(ctx->ggml_ctx, v_norm, v_mods[0], v_mods[1]);
auto v_sa = attn1->forward(ctx, v_norm, nullptr, self_attention_mask, v_pe); auto v_sa = attn1->forward(ctx, v_norm, nullptr, self_attention_mask, v_pe);
vx = ggml_add(ctx->ggml_ctx, vx, apply_gate(ctx->ggml_ctx, v_sa, v_mods[2])); vx = ggml_add(ctx->ggml_ctx, vx, apply_gate(ctx->ggml_ctx, v_sa, v_mods[2]));
auto v_txt = apply_text_cross_attention(ctx, auto v_txt = apply_text_cross_attention(ctx,
@ -1016,7 +1032,7 @@ namespace LTXV {
if (run_ax) { if (run_ax) {
auto a_mods = get_ada_values(ctx, a_table, a_timestep, a_dim, cross_attention_adaln ? 9 : 6); auto a_mods = get_ada_values(ctx, a_table, a_timestep, a_dim, cross_attention_adaln ? 9 : 6);
auto a_norm = rms_norm(ctx->ggml_ctx, ax); auto a_norm = rms_norm(ctx->ggml_ctx, ax);
a_norm = Flux::modulate(ctx->ggml_ctx, a_norm, a_mods[0], a_mods[1], true); a_norm = modulate(ctx->ggml_ctx, a_norm, a_mods[0], a_mods[1]);
auto a_sa = audio_attn1->forward(ctx, a_norm, nullptr, nullptr, a_pe); auto a_sa = audio_attn1->forward(ctx, a_norm, nullptr, nullptr, a_pe);
ax = ggml_add(ctx->ggml_ctx, ax, apply_gate(ctx->ggml_ctx, a_sa, a_mods[2])); ax = ggml_add(ctx->ggml_ctx, ax, apply_gate(ctx->ggml_ctx, a_sa, a_mods[2]));
auto a_txt = apply_text_cross_attention(ctx, auto a_txt = apply_text_cross_attention(ctx,
@ -1039,8 +1055,8 @@ namespace LTXV {
auto a2v_video_table = ggml_ext_slice(ctx->ggml_ctx, params["scale_shift_table_a2v_ca_video"], 1, 0, 4); auto a2v_video_table = ggml_ext_slice(ctx->ggml_ctx, params["scale_shift_table_a2v_ca_video"], 1, 0, 4);
auto a2v_audio = get_ada_values(ctx, a2v_audio_table, a_cross_scale_shift_timestep, a_dim, 4); auto a2v_audio = get_ada_values(ctx, a2v_audio_table, a_cross_scale_shift_timestep, a_dim, 4);
auto a2v_video = get_ada_values(ctx, a2v_video_table, v_cross_scale_shift_timestep, v_dim, 4); auto a2v_video = get_ada_values(ctx, a2v_video_table, v_cross_scale_shift_timestep, v_dim, 4);
auto vx_scaled = Flux::modulate(ctx->ggml_ctx, vx_norm3, a2v_video[1], a2v_video[0], true); auto vx_scaled = modulate(ctx->ggml_ctx, vx_norm3, a2v_video[1], a2v_video[0]);
auto ax_scaled = Flux::modulate(ctx->ggml_ctx, ax_norm3, a2v_audio[1], a2v_audio[0], true); auto ax_scaled = modulate(ctx->ggml_ctx, ax_norm3, a2v_audio[1], a2v_audio[0]);
auto a2v_out = audio_to_video_attn->forward(ctx, vx_scaled, ax_scaled, nullptr, v_cross_pe, a_cross_pe); auto a2v_out = audio_to_video_attn->forward(ctx, vx_scaled, ax_scaled, nullptr, v_cross_pe, a_cross_pe);
auto a2v_gate_table = ggml_ext_slice(ctx->ggml_ctx, params["scale_shift_table_a2v_ca_video"], 1, 4, 5); auto a2v_gate_table = ggml_ext_slice(ctx->ggml_ctx, params["scale_shift_table_a2v_ca_video"], 1, 4, 5);
auto a2v_gate = get_ada_values(ctx, a2v_gate_table, v_cross_gate_timestep, v_dim, 1)[0]; auto a2v_gate = get_ada_values(ctx, a2v_gate_table, v_cross_gate_timestep, v_dim, 1)[0];
@ -1052,8 +1068,8 @@ namespace LTXV {
auto v2a_video_table = ggml_ext_slice(ctx->ggml_ctx, params["scale_shift_table_a2v_ca_video"], 1, 0, 4); auto v2a_video_table = ggml_ext_slice(ctx->ggml_ctx, params["scale_shift_table_a2v_ca_video"], 1, 0, 4);
auto v2a_audio = get_ada_values(ctx, v2a_audio_table, a_cross_scale_shift_timestep, a_dim, 4); auto v2a_audio = get_ada_values(ctx, v2a_audio_table, a_cross_scale_shift_timestep, a_dim, 4);
auto v2a_video = get_ada_values(ctx, v2a_video_table, v_cross_scale_shift_timestep, v_dim, 4); auto v2a_video = get_ada_values(ctx, v2a_video_table, v_cross_scale_shift_timestep, v_dim, 4);
auto ax_scaled = Flux::modulate(ctx->ggml_ctx, ax_norm3, v2a_audio[3], v2a_audio[2], true); auto ax_scaled = modulate(ctx->ggml_ctx, ax_norm3, v2a_audio[3], v2a_audio[2]);
auto vx_scaled = Flux::modulate(ctx->ggml_ctx, vx_norm3, v2a_video[3], v2a_video[2], true); auto vx_scaled = modulate(ctx->ggml_ctx, vx_norm3, v2a_video[3], v2a_video[2]);
auto v2a_out = video_to_audio_attn->forward(ctx, ax_scaled, vx_scaled, nullptr, a_cross_pe, v_cross_pe); auto v2a_out = video_to_audio_attn->forward(ctx, ax_scaled, vx_scaled, nullptr, a_cross_pe, v_cross_pe);
auto v2a_gate_table = ggml_ext_slice(ctx->ggml_ctx, params["scale_shift_table_a2v_ca_audio"], 1, 4, 5); auto v2a_gate_table = ggml_ext_slice(ctx->ggml_ctx, params["scale_shift_table_a2v_ca_audio"], 1, 4, 5);
auto v2a_gate = get_ada_values(ctx, v2a_gate_table, a_cross_gate_timestep, a_dim, 1)[0]; auto v2a_gate = get_ada_values(ctx, v2a_gate_table, a_cross_gate_timestep, a_dim, 1)[0];
@ -1061,14 +1077,14 @@ namespace LTXV {
} }
auto a_ff_mods = get_ada_values(ctx, a_table, a_timestep, a_dim, cross_attention_adaln ? 9 : 6, 3, 3); auto a_ff_mods = get_ada_values(ctx, a_table, a_timestep, a_dim, cross_attention_adaln ? 9 : 6, 3, 3);
auto ax_scaled = rms_norm(ctx->ggml_ctx, ax); auto ax_scaled = rms_norm(ctx->ggml_ctx, ax);
ax_scaled = Flux::modulate(ctx->ggml_ctx, ax_scaled, a_ff_mods[0], a_ff_mods[1], true); ax_scaled = modulate(ctx->ggml_ctx, ax_scaled, a_ff_mods[0], a_ff_mods[1]);
auto a_ff_out = audio_ff->forward(ctx, ax_scaled); auto a_ff_out = audio_ff->forward(ctx, ax_scaled);
ax = ggml_add(ctx->ggml_ctx, ax, apply_gate(ctx->ggml_ctx, a_ff_out, a_ff_mods[2])); ax = ggml_add(ctx->ggml_ctx, ax, apply_gate(ctx->ggml_ctx, a_ff_out, a_ff_mods[2]));
} }
auto v_ff_mods = get_ada_values(ctx, v_table, v_timestep, v_dim, cross_attention_adaln ? 9 : 6, 3, 3); auto v_ff_mods = get_ada_values(ctx, v_table, v_timestep, v_dim, cross_attention_adaln ? 9 : 6, 3, 3);
auto vx_scaled = rms_norm(ctx->ggml_ctx, vx); auto vx_scaled = rms_norm(ctx->ggml_ctx, vx);
vx_scaled = Flux::modulate(ctx->ggml_ctx, vx_scaled, v_ff_mods[0], v_ff_mods[1], true); vx_scaled = modulate(ctx->ggml_ctx, vx_scaled, v_ff_mods[0], v_ff_mods[1]);
auto v_ff_out = ff->forward(ctx, vx_scaled); auto v_ff_out = ff->forward(ctx, vx_scaled);
vx = ggml_add(ctx->ggml_ctx, vx, apply_gate(ctx->ggml_ctx, v_ff_out, v_ff_mods[2])); vx = ggml_add(ctx->ggml_ctx, vx, apply_gate(ctx->ggml_ctx, v_ff_out, v_ff_mods[2]));
@ -1188,6 +1204,15 @@ namespace LTXV {
return ax; return ax;
} }
// Returns a timestep tensor whose ne[0] matches like->ne[0].
// A timestep that already has that length is returned unchanged; otherwise it must
// hold exactly one element along ne[0] (asserted) and is repeated to the target length.
// Both pointers must be non-null (asserted).
ggml_tensor* repeat_scalar_timestep_like(GGMLRunnerContext* ctx, ggml_tensor* timestep, ggml_tensor* like) {
    GGML_ASSERT(timestep != nullptr && like != nullptr);
    const int64_t target_len = like->ne[0];
    if (timestep->ne[0] != target_len) {
        GGML_ASSERT(timestep->ne[0] == 1);
        // Freshly created 1-D tensor passed to ggml_repeat as the target-shape argument.
        ggml_tensor* shape_ref = ggml_new_tensor_1d(ctx->ggml_ctx, timestep->type, target_len);
        return ggml_repeat(ctx->ggml_ctx, timestep, shape_ref);
    }
    return timestep;
}
ggml_tensor* unpatchify_audio(GGMLRunnerContext* ctx, ggml_tensor* ax, int64_t audio_length) { ggml_tensor* unpatchify_audio(GGMLRunnerContext* ctx, ggml_tensor* ax, int64_t audio_length) {
if (ax == nullptr) { if (ax == nullptr) {
return nullptr; return nullptr;
@ -1367,21 +1392,24 @@ namespace LTXV {
if (cfg.cross_attention_adaln) { if (cfg.cross_attention_adaln) {
auto prompt_adaln_single = std::dynamic_pointer_cast<AdaLayerNormSingle>(blocks["prompt_adaln_single"]); auto prompt_adaln_single = std::dynamic_pointer_cast<AdaLayerNormSingle>(blocks["prompt_adaln_single"]);
auto audio_prompt_adaln_single = std::dynamic_pointer_cast<AdaLayerNormSingle>(blocks["audio_prompt_adaln_single"]); auto audio_prompt_adaln_single = std::dynamic_pointer_cast<AdaLayerNormSingle>(blocks["audio_prompt_adaln_single"]);
v_prompt_timestep_mod = prompt_adaln_single->forward(ctx, v_timestep_scaled).first; v_prompt_timestep_mod = prompt_adaln_single->forward(ctx, a_timestep_scaled).first;
a_prompt_timestep_mod = audio_prompt_adaln_single->forward(ctx, a_timestep_scaled).first; a_prompt_timestep_mod = audio_prompt_adaln_single->forward(ctx, a_timestep_scaled).first;
} }
auto av_ca_video_timestep = repeat_scalar_timestep_like(ctx, effective_audio_timestep, timestep);
auto av_ca_audio_timestep = effective_audio_timestep;
auto av_ca_factor = cfg.av_ca_timestep_scale_multiplier / cfg.timestep_scale_multiplier;
auto av_ca_video_scale_shift_timestep = auto av_ca_video_scale_shift_timestep =
std::dynamic_pointer_cast<AdaLayerNormSingle>(blocks["av_ca_video_scale_shift_adaln_single"])->forward(ctx, a_timestep_scaled).first; std::dynamic_pointer_cast<AdaLayerNormSingle>(blocks["av_ca_video_scale_shift_adaln_single"])->forward(ctx, av_ca_video_timestep).first;
auto av_ca_a2v_gate_noise_timestep = auto av_ca_a2v_gate_noise_timestep =
std::dynamic_pointer_cast<AdaLayerNormSingle>(blocks["av_ca_a2v_gate_adaln_single"]) std::dynamic_pointer_cast<AdaLayerNormSingle>(blocks["av_ca_a2v_gate_adaln_single"])
->forward(ctx, ggml_ext_scale(ctx->ggml_ctx, a_timestep_scaled, cfg.av_ca_timestep_scale_multiplier / cfg.timestep_scale_multiplier)) ->forward(ctx, ggml_ext_scale(ctx->ggml_ctx, av_ca_video_timestep, av_ca_factor))
.first; .first;
auto av_ca_audio_scale_shift_timestep = auto av_ca_audio_scale_shift_timestep =
std::dynamic_pointer_cast<AdaLayerNormSingle>(blocks["av_ca_audio_scale_shift_adaln_single"])->forward(ctx, v_timestep_scaled).first; std::dynamic_pointer_cast<AdaLayerNormSingle>(blocks["av_ca_audio_scale_shift_adaln_single"])->forward(ctx, av_ca_audio_timestep).first;
auto av_ca_v2a_gate_noise_timestep = auto av_ca_v2a_gate_noise_timestep =
std::dynamic_pointer_cast<AdaLayerNormSingle>(blocks["av_ca_v2a_gate_adaln_single"]) std::dynamic_pointer_cast<AdaLayerNormSingle>(blocks["av_ca_v2a_gate_adaln_single"])
->forward(ctx, ggml_ext_scale(ctx->ggml_ctx, v_timestep_scaled, cfg.av_ca_timestep_scale_multiplier / cfg.timestep_scale_multiplier)) ->forward(ctx, ggml_ext_scale(ctx->ggml_ctx, av_ca_audio_timestep, av_ca_factor))
.first; .first;
for (int i = 0; i < cfg.num_layers; i++) { for (int i = 0; i < cfg.num_layers; i++) {
@ -1410,14 +1438,14 @@ namespace LTXV {
auto v_shift_scale = get_output_scale_shift(ctx, params["scale_shift_table"], v_embedded_time, cfg.hidden_size); auto v_shift_scale = get_output_scale_shift(ctx, params["scale_shift_table"], v_embedded_time, cfg.hidden_size);
vx = norm_out->forward(ctx, vx); vx = norm_out->forward(ctx, vx);
vx = Flux::modulate(ctx->ggml_ctx, vx, v_shift_scale[0], v_shift_scale[1], true); vx = modulate(ctx->ggml_ctx, vx, v_shift_scale[0], v_shift_scale[1]);
vx = proj_out->forward(ctx, vx); vx = proj_out->forward(ctx, vx);
vx = unpatchify_video(ctx, vx, width, height, frames); vx = unpatchify_video(ctx, vx, width, height, frames);
if (ax != nullptr && audio_time > 0) { if (ax != nullptr && audio_time > 0) {
auto a_shift_scale = get_output_scale_shift(ctx, params["audio_scale_shift_table"], a_embedded_time, cfg.audio_hidden_size); auto a_shift_scale = get_output_scale_shift(ctx, params["audio_scale_shift_table"], a_embedded_time, cfg.audio_hidden_size);
ax = audio_norm_out->forward(ctx, ax); ax = audio_norm_out->forward(ctx, ax);
ax = Flux::modulate(ctx->ggml_ctx, ax, a_shift_scale[0], a_shift_scale[1], true); ax = modulate(ctx->ggml_ctx, ax, a_shift_scale[0], a_shift_scale[1]);
ax = audio_proj_out->forward(ctx, ax); ax = audio_proj_out->forward(ctx, ax);
ax = unpatchify_audio(ctx, ax, audio_time); ax = unpatchify_audio(ctx, ax, audio_time);
} }

View File

@ -458,10 +458,6 @@ public:
// Might need vae encode for control cond // Might need vae encode for control cond
vae_decode_only = false; vae_decode_only = false;
} }
if (sd_version_is_ltxav(version)) {
vae_decode_only = true;
}
bool tae_preview_only = sd_ctx_params->tae_preview_only; bool tae_preview_only = sd_ctx_params->tae_preview_only;
if (version == VERSION_SDXS_512_DS || version == VERSION_SDXS_09) { if (version == VERSION_SDXS_512_DS || version == VERSION_SDXS_09) {
tae_preview_only = false; tae_preview_only = false;
@ -705,7 +701,7 @@ public:
params_backend_for(SDBackendModule::VAE), params_backend_for(SDBackendModule::VAE),
tensor_storage_map, tensor_storage_map,
"first_stage_model", "first_stage_model",
true, vae_decode_only,
version); version);
} else if (sd_version_is_wan(version) || } else if (sd_version_is_wan(version) ||
sd_version_is_qwen_image(version) || sd_version_is_qwen_image(version) ||
@ -1570,6 +1566,38 @@ public:
} }
} }
// Expands the first sampling timestep into one value per latent position for LTXAV:
// every (w, h, t) entry of the result is denoise_mask(w, h, t, 0[, 0]) * timesteps[0].
// The result is laid out with width varying fastest, then height, then frames.
//
// The input timesteps are returned unchanged when they are empty, when the mask is
// empty, when either tensor has fewer than 4 dims, or when the mask's first three
// extents do not match init_latent (the last case also logs a warning).
std::vector<float> process_ltxav_video_timesteps(const std::vector<float>& timesteps,
                                                 const sd::Tensor<float>& init_latent,
                                                 const sd::Tensor<float>& denoise_mask) {
    const bool inputs_usable = !timesteps.empty() &&
                               !denoise_mask.empty() &&
                               init_latent.dim() >= 4 &&
                               denoise_mask.dim() >= 4;
    if (!inputs_usable) {
        return timesteps;
    }

    const int64_t width  = init_latent.shape()[0];
    const int64_t height = init_latent.shape()[1];
    const int64_t frames = init_latent.shape()[2];

    const bool mask_matches_latent = denoise_mask.shape()[0] == width &&
                                     denoise_mask.shape()[1] == height &&
                                     denoise_mask.shape()[2] == frames &&
                                     denoise_mask.shape()[3] >= 1;
    if (!mask_matches_latent) {
        LOG_WARN("unexpected LTXAV denoise mask shape for timestep processing");
        return timesteps;
    }

    const float base_timestep = timesteps[0];
    const bool mask_is_5d     = denoise_mask.dim() == 5;

    std::vector<float> video_timesteps(static_cast<size_t>(width * height * frames));
    auto out = video_timesteps.begin();
    for (int64_t t = 0; t < frames; ++t) {
        for (int64_t h = 0; h < height; ++h) {
            for (int64_t w = 0; w < width; ++w) {
                // Only channel 0 (and batch 0 for a 5-D mask) is consulted.
                const float mask_value = mask_is_5d ? denoise_mask.index(w, h, t, 0, 0)
                                                    : denoise_mask.index(w, h, t, 0);
                *out++ = mask_value * base_timestep;
            }
        }
    }
    return video_timesteps;
}
void preview_image(int step, void preview_image(int step,
const sd::Tensor<float>& latents, const sd::Tensor<float>& latents,
enum SDVersion version, enum SDVersion version,
@ -1846,14 +1874,24 @@ public:
float c_out = scaling[1]; float c_out = scaling[1];
float c_in = scaling[2]; float c_in = scaling[2];
std::vector<float> timesteps_vec = prepare_sample_timesteps(sigma, shifted_timestep); std::vector<float> base_timesteps_vec = prepare_sample_timesteps(sigma, shifted_timestep);
std::vector<float> timesteps_vec = base_timesteps_vec;
sd::Tensor<float> audio_timesteps_tensor;
if (sd_version_is_ltxav(version) && !denoise_mask.empty()) {
timesteps_vec = process_ltxav_video_timesteps(base_timesteps_vec, init_latent, denoise_mask);
audio_timesteps_tensor = sd::Tensor<float>({static_cast<int64_t>(base_timesteps_vec.size())}, base_timesteps_vec);
} else {
timesteps_vec = process_timesteps(timesteps_vec, init_latent, denoise_mask); timesteps_vec = process_timesteps(timesteps_vec, init_latent, denoise_mask);
adjust_sample_step_scalings(shifted_timestep, timesteps_vec, c_in, &c_skip, &c_out); }
const std::vector<float>& scaling_timesteps_vec = (sd_version_is_ltxav(version) && !denoise_mask.empty())
? base_timesteps_vec
: timesteps_vec;
adjust_sample_step_scalings(shifted_timestep, scaling_timesteps_vec, c_in, &c_skip, &c_out);
sd::Tensor<float> timesteps_tensor({static_cast<int64_t>(timesteps_vec.size())}, timesteps_vec); sd::Tensor<float> timesteps_tensor({static_cast<int64_t>(timesteps_vec.size())}, timesteps_vec);
sd::Tensor<float> guidance_tensor({1}, std::vector<float>{guidance.distilled_guidance}); sd::Tensor<float> guidance_tensor({1}, std::vector<float>{guidance.distilled_guidance});
sd::Tensor<float> noised_input = x * c_in; sd::Tensor<float> noised_input = x * c_in;
if (!denoise_mask.empty() && version == VERSION_WAN2_2_TI2V) { if (!denoise_mask.empty() && (version == VERSION_WAN2_2_TI2V || sd_version_is_ltxav(version))) {
noised_input = noised_input * denoise_mask + init_latent * (1.0f - denoise_mask); noised_input = noised_input * denoise_mask + init_latent * (1.0f - denoise_mask);
} }
@ -1884,6 +1922,7 @@ public:
DiffusionParams diffusion_params; DiffusionParams diffusion_params;
diffusion_params.x = &noised_input; diffusion_params.x = &noised_input;
diffusion_params.timesteps = &timesteps_tensor; diffusion_params.timesteps = &timesteps_tensor;
diffusion_params.audio_timesteps = audio_timesteps_tensor.empty() ? nullptr : &audio_timesteps_tensor;
diffusion_params.guidance = &guidance_tensor; diffusion_params.guidance = &guidance_tensor;
diffusion_params.ref_latents = &ref_latents; diffusion_params.ref_latents = &ref_latents;
diffusion_params.increase_ref_index = increase_ref_index; diffusion_params.increase_ref_index = increase_ref_index;
@ -2916,6 +2955,7 @@ struct GenerationRequest {
vae_scale_factor = sd_ctx->sd->get_vae_scale_factor(); vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
diffusion_model_down_factor = sd_ctx->sd->get_diffusion_model_down_factor(); diffusion_model_down_factor = sd_ctx->sd->get_diffusion_model_down_factor();
seed = sd_vid_gen_params->seed; seed = sd_vid_gen_params->seed;
strength = sd_vid_gen_params->strength;
cache_params = &sd_vid_gen_params->cache; cache_params = &sd_vid_gen_params->cache;
vace_strength = sd_vid_gen_params->vace_strength; vace_strength = sd_vid_gen_params->vace_strength;
guidance = sd_vid_gen_params->sample_params.guidance; guidance = sd_vid_gen_params->sample_params.guidance;
@ -3204,6 +3244,55 @@ static sd::Tensor<float> pack_ltxav_audio_and_video_latents(const sd::Tensor<flo
return packed; return packed;
} }
// Extends a video denoise mask along the channel axis (dim 3) so that it has
// `video_ch + extra_ch` channels, where
//   extra_ch = ceil(audio_latent.numel() / (width * height * frames)).
// NOTE(review): this presumably mirrors the channel layout produced by
// pack_ltxav_audio_and_video_latents (not visible from here) - confirm.
//
// Parameters:
//   video_mask   - mask with the same rank and the same dim 0..2 extents as
//                  `video_latent`. Its dim 3 extent must be 1, `video_ch`, or
//                  already `video_ch + extra_ch`.
//   video_latent - 4-D or 5-D tensor; dims 0..3 are read as width, height,
//                  frames and channels. A 5th dim, when present, must be 1.
//   audio_latent - 3-D or 4-D tensor; a 4th dim, when present, must be 1.
//
// Returns:
//   `video_mask` unchanged when the mask or the audio latent is empty, or when
//   the mask already has `video_ch + extra_ch` channels. Otherwise the mask
//   (broadcast to `video_ch` channels when it has a single channel)
//   concatenated along dim 3 with `extra_ch` channels filled with 1.
//
// Shape violations abort through GGML_ASSERT.
static sd::Tensor<float> pack_ltxav_audio_and_video_denoise_mask(const sd::Tensor<float>& video_mask,
                                                                 const sd::Tensor<float>& video_latent,
                                                                 const sd::Tensor<float>& audio_latent) {
    if (video_mask.empty() || audio_latent.empty()) {
        return video_mask;
    }

    GGML_ASSERT(video_latent.dim() == 4 || video_latent.dim() == 5);
    GGML_ASSERT(audio_latent.dim() == 3 || audio_latent.dim() == 4);
    if (video_latent.dim() == 5) {
        GGML_ASSERT(video_latent.shape()[4] == 1);
    }
    if (audio_latent.dim() == 4) {
        GGML_ASSERT(audio_latent.shape()[3] == 1);
    }

    const int64_t width        = video_latent.shape()[0];
    const int64_t height       = video_latent.shape()[1];
    const int64_t frames       = video_latent.shape()[2];
    const int64_t video_ch     = video_latent.shape()[3];
    const int64_t spatial_size = width * height * frames;

    // Validate the mask against the latent BEFORE dividing by spatial_size, so
    // a mismatched or degenerate latent is reported by an assertion instead of
    // triggering an integer division by zero.
    GGML_ASSERT(video_mask.dim() == video_latent.dim());
    GGML_ASSERT(video_mask.shape()[0] == width);
    GGML_ASSERT(video_mask.shape()[1] == height);
    GGML_ASSERT(video_mask.shape()[2] == frames);
    if (video_mask.dim() == 5) {
        GGML_ASSERT(video_mask.shape()[4] == video_latent.shape()[4]);
    }
    GGML_ASSERT(spatial_size > 0);

    // Number of whole channels needed to hold every audio value (ceil division).
    const int64_t audio_values = audio_latent.numel();
    const int64_t extra_ch     = (audio_values + spatial_size - 1) / spatial_size;

    const int64_t mask_ch = video_mask.shape()[3];
    if (mask_ch == video_ch + extra_ch) {
        // The mask already covers the video channels plus the extra channels.
        return video_mask;
    }
    GGML_ASSERT(mask_ch == 1 || mask_ch == video_ch);

    // Expand a single-channel mask to one value per video channel by
    // multiplying with an all-ones tensor of the latent's shape.
    sd::Tensor<float> video_mask_full = video_mask;
    if (mask_ch == 1 && video_ch != 1) {
        video_mask_full = video_mask * sd::Tensor<float>::ones(video_latent.shape());
    }

    // The appended channels get a mask value of 1 everywhere.
    std::vector<int64_t> audio_mask_shape = video_latent.shape();
    audio_mask_shape[3]                   = extra_ch;
    auto audio_mask                       = sd::Tensor<float>::ones(audio_mask_shape);

    return sd::ops::concat(video_mask_full, audio_mask, 3);
}
static sd::Tensor<float> unpack_ltxav_audio_latent(const sd::Tensor<float>& packed_latent, static sd::Tensor<float> unpack_ltxav_audio_latent(const sd::Tensor<float>& packed_latent,
int audio_length, int audio_length,
int video_channels) { int video_channels) {
@ -4030,10 +4119,47 @@ static std::optional<ImageGenerationLatents> prepare_video_generation_latents(sd
} }
if (sd_version_is_ltxav(sd_ctx->sd->version)) { if (sd_version_is_ltxav(sd_ctx->sd->version)) {
if (!start_image.empty() || !end_image.empty() || sd_vid_gen_params->control_frames_size > 0) { if (!end_image.empty() || sd_vid_gen_params->control_frames_size > 0) {
LOG_ERROR("LTXAV currently supports txt2vid only; init_image, end_image, and control_frames are not implemented"); LOG_ERROR("LTXAV currently supports txt2vid and init_image i2v only; end_image and control_frames are not implemented");
return std::nullopt; return std::nullopt;
} }
if (!start_image.empty()) {
if (sd_ctx->sd->vae_decode_only) {
LOG_ERROR("LTXAV init_image i2v requires VAE encoder weights; create the context with vae_decode_only=false");
return std::nullopt;
}
LOG_INFO("IMG2VID");
int64_t t1 = ggml_time_ms();
auto init_img = start_image.reshape({start_image.shape()[0],
start_image.shape()[1],
1,
start_image.shape()[2],
start_image.shape()[3]});
auto init_image_latent = sd_ctx->sd->encode_first_stage(init_img);
if (init_image_latent.empty()) {
LOG_ERROR("failed to encode LTXAV init image");
return std::nullopt;
}
latents.init_latent = sd_ctx->sd->generate_init_latent(request->width, request->height, request->frames, true);
sd::ops::slice_assign(&latents.init_latent, 2, 0, init_image_latent.shape()[2], init_image_latent);
float conditioning_strength = std::clamp(request->strength, 0.f, 1.f);
float conditioned_mask = 1.0f - conditioning_strength;
latents.denoise_mask = sd::full<float>({latents.init_latent.shape()[0],
latents.init_latent.shape()[1],
latents.init_latent.shape()[2],
1,
1},
1.f);
sd::ops::fill_slice(&latents.denoise_mask, 2, 0, init_image_latent.shape()[2], conditioned_mask);
int64_t t2 = ggml_time_ms();
LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1);
}
} }
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" || if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" ||
@ -4207,6 +4333,11 @@ static std::optional<ImageGenerationLatents> prepare_video_generation_latents(sd
} }
if (sd_version_is_ltxav(sd_ctx->sd->version) && !latents.audio_latent.empty()) { if (sd_version_is_ltxav(sd_ctx->sd->version) && !latents.audio_latent.empty()) {
if (!latents.denoise_mask.empty()) {
latents.denoise_mask = pack_ltxav_audio_and_video_denoise_mask(latents.denoise_mask,
latents.init_latent,
latents.audio_latent);
}
latents.init_latent = pack_ltxav_audio_and_video_latents(latents.init_latent, latents.audio_latent); latents.init_latent = pack_ltxav_audio_and_video_latents(latents.init_latent, latents.audio_latent);
} }