mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-24 23:26:43 +00:00
add i2v support
This commit is contained in:
parent
f8a0330d37
commit
18fbb4cdfb
@ -1127,18 +1127,33 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_conv_3d(ggml_context* ctx,
|
||||
ggml_tensor* w,
|
||||
ggml_tensor* b,
|
||||
int64_t IC,
|
||||
int s0 = 1,
|
||||
int s1 = 1,
|
||||
int s2 = 1,
|
||||
int p0 = 0,
|
||||
int p1 = 0,
|
||||
int p2 = 0,
|
||||
int d0 = 1,
|
||||
int d1 = 1,
|
||||
int d2 = 1) {
|
||||
int64_t OC = w->ne[3] / IC;
|
||||
int64_t N = x->ne[3] / IC;
|
||||
x = ggml_conv_3d(ctx, w, x, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2);
|
||||
int s0 = 1,
|
||||
int s1 = 1,
|
||||
int s2 = 1,
|
||||
int p0 = 0,
|
||||
int p1 = 0,
|
||||
int p2 = 0,
|
||||
int d0 = 1,
|
||||
int d1 = 1,
|
||||
int d2 = 1,
|
||||
bool force_prec_f32 = false) {
|
||||
if (force_prec_f32) {
|
||||
ggml_tensor* im2col = ggml_im2col_3d(ctx, w, x, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, w->type);
|
||||
|
||||
int64_t OC = w->ne[3] / IC;
|
||||
int64_t N = x->ne[3] / IC;
|
||||
x = ggml_mul_mat(ctx,
|
||||
ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]),
|
||||
ggml_reshape_2d(ctx, w, w->ne[0] * w->ne[1] * w->ne[2] * IC, OC));
|
||||
ggml_mul_mat_set_prec(x, GGML_PREC_F32);
|
||||
|
||||
int64_t OD = im2col->ne[3] / N;
|
||||
x = ggml_reshape_4d(ctx, x, im2col->ne[1] * im2col->ne[2], OD, N, OC);
|
||||
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 1, 3, 2));
|
||||
x = ggml_reshape_4d(ctx, x, im2col->ne[1], im2col->ne[2], OD, OC * N);
|
||||
} else {
|
||||
x = ggml_conv_3d(ctx, w, x, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2);
|
||||
}
|
||||
|
||||
if (b != nullptr) {
|
||||
b = ggml_reshape_4d(ctx, b, 1, 1, 1, b->ne[0]); // [OC, 1, 1, 1]
|
||||
@ -3133,6 +3148,7 @@ protected:
|
||||
std::tuple<int, int, int> padding;
|
||||
std::tuple<int, int, int> dilation;
|
||||
bool bias;
|
||||
bool force_prec_f32;
|
||||
std::string prefix;
|
||||
|
||||
void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override {
|
||||
@ -3156,14 +3172,16 @@ public:
|
||||
std::tuple<int, int, int> stride = {1, 1, 1},
|
||||
std::tuple<int, int, int> padding = {0, 0, 0},
|
||||
std::tuple<int, int, int> dilation = {1, 1, 1},
|
||||
bool bias = true)
|
||||
bool bias = true,
|
||||
bool force_prec_f32 = false)
|
||||
: in_channels(in_channels),
|
||||
out_channels(out_channels),
|
||||
kernel_size(kernel_size),
|
||||
stride(stride),
|
||||
padding(padding),
|
||||
dilation(dilation),
|
||||
bias(bias) {}
|
||||
bias(bias),
|
||||
force_prec_f32(force_prec_f32) {}
|
||||
|
||||
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
|
||||
ggml_tensor* w = params["weight"];
|
||||
@ -3183,7 +3201,8 @@ public:
|
||||
return ggml_ext_conv_3d(ctx->ggml_ctx, x, w, b, in_channels,
|
||||
std::get<2>(stride), std::get<1>(stride), std::get<0>(stride),
|
||||
std::get<2>(padding), std::get<1>(padding), std::get<0>(padding),
|
||||
std::get<2>(dilation), std::get<1>(dilation), std::get<0>(dilation));
|
||||
std::get<2>(dilation), std::get<1>(dilation), std::get<0>(dilation),
|
||||
force_prec_f32);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
126
src/ltx_vae.hpp
126
src/ltx_vae.hpp
@ -89,7 +89,8 @@ namespace LTXVAE {
|
||||
int kernel_size = 3,
|
||||
std::tuple<int, int, int> stride = {1, 1, 1},
|
||||
int dilation = 1,
|
||||
bool bias = true) {
|
||||
bool bias = true,
|
||||
bool force_prec_f32 = false) {
|
||||
time_kernel_size = kernel_size;
|
||||
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv3d(in_channels,
|
||||
out_channels,
|
||||
@ -97,7 +98,8 @@ namespace LTXVAE {
|
||||
stride,
|
||||
{0, kernel_size / 2, kernel_size / 2},
|
||||
{dilation, 1, 1},
|
||||
bias));
|
||||
bias,
|
||||
force_prec_f32));
|
||||
}
|
||||
|
||||
ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||
@ -469,7 +471,8 @@ namespace LTXVAE {
|
||||
SpaceToDepthDownsample(int64_t in_channels,
|
||||
int64_t out_channels,
|
||||
int factor_t,
|
||||
int factor_s)
|
||||
int factor_s,
|
||||
bool force_conv_prec_f32 = false)
|
||||
: in_channels(in_channels),
|
||||
out_channels(out_channels),
|
||||
factor_t(factor_t),
|
||||
@ -477,7 +480,13 @@ namespace LTXVAE {
|
||||
const int64_t factor = static_cast<int64_t>(factor_t) * static_cast<int64_t>(factor_s) * static_cast<int64_t>(factor_s);
|
||||
GGML_ASSERT(out_channels % factor == 0);
|
||||
|
||||
blocks["conv"] = std::make_shared<CausalConv3d>(in_channels, out_channels / factor, 3);
|
||||
blocks["conv"] = std::make_shared<CausalConv3d>(in_channels,
|
||||
out_channels / factor,
|
||||
3,
|
||||
std::tuple<int, int, int>{1, 1, 1},
|
||||
1,
|
||||
true,
|
||||
force_conv_prec_f32);
|
||||
blocks["skip_downsample"] = std::make_shared<WAN::AvgDown3D>(in_channels, out_channels, factor_t, factor_s);
|
||||
blocks["conv_downsample"] = std::make_shared<WAN::AvgDown3D>(out_channels / factor, out_channels, factor_t, factor_s);
|
||||
}
|
||||
@ -492,7 +501,7 @@ namespace LTXVAE {
|
||||
if (factor_t > 1 && x->ne[2] > 0) {
|
||||
auto first_frame = ggml_ext_slice(ctx->ggml_ctx, x, 2, 0, 1);
|
||||
auto first_frame_pad = first_frame;
|
||||
for (int i = 1; i < factor_t; ++i) {
|
||||
for (int i = 1; i < factor_t - 1; ++i) {
|
||||
first_frame_pad = ggml_concat(ctx->ggml_ctx, first_frame_pad, first_frame, 2);
|
||||
}
|
||||
x = ggml_concat(ctx->ggml_ctx, first_frame_pad, x, 2);
|
||||
@ -550,6 +559,8 @@ namespace LTXVAE {
|
||||
std::vector<Block> blocks;
|
||||
};
|
||||
|
||||
static inline EncoderConfig get_default_encoder_config(int version);
|
||||
|
||||
static inline bool has_tensor(const String2TensorStorage& tensor_storage_map,
|
||||
const std::string& name) {
|
||||
return tensor_storage_map.find(name) != tensor_storage_map.end();
|
||||
@ -633,6 +644,84 @@ namespace LTXVAE {
|
||||
return cfg;
|
||||
}
|
||||
|
||||
static inline EncoderConfig infer_encoder_config_from_weights(const String2TensorStorage& tensor_storage_map,
|
||||
const std::string& prefix,
|
||||
int version) {
|
||||
EncoderConfig cfg;
|
||||
const std::string encoder_prefix = prefix + ".encoder.down_blocks.";
|
||||
|
||||
int64_t current_channels = get_tensor_ne0(tensor_storage_map,
|
||||
prefix + ".encoder.conv_in.conv.bias",
|
||||
0);
|
||||
for (int block_idx = 0;; ++block_idx) {
|
||||
const std::string block_prefix = encoder_prefix + std::to_string(block_idx);
|
||||
const std::string res0_bias = block_prefix + ".res_blocks.0.conv1.conv.bias";
|
||||
const std::string conv_bias = block_prefix + ".conv.conv.bias";
|
||||
|
||||
if (has_tensor(tensor_storage_map, res0_bias)) {
|
||||
int num_layers = 0;
|
||||
while (has_tensor(tensor_storage_map,
|
||||
block_prefix + ".res_blocks." + std::to_string(num_layers) + ".conv1.conv.bias")) {
|
||||
num_layers++;
|
||||
}
|
||||
cfg.blocks.push_back({"res_x", num_layers, 1});
|
||||
current_channels = get_tensor_ne0(tensor_storage_map, res0_bias, current_channels);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!has_tensor(tensor_storage_map, conv_bias)) {
|
||||
break;
|
||||
}
|
||||
|
||||
const int64_t conv_channels = get_tensor_ne0(tensor_storage_map, conv_bias);
|
||||
int64_t next_channels = 0;
|
||||
for (int next_idx = block_idx + 1;; ++next_idx) {
|
||||
const std::string next_res0_bias = encoder_prefix + std::to_string(next_idx) + ".res_blocks.0.conv1.conv.bias";
|
||||
const std::string next_conv_bias = encoder_prefix + std::to_string(next_idx) + ".conv.conv.bias";
|
||||
if (has_tensor(tensor_storage_map, next_res0_bias)) {
|
||||
next_channels = get_tensor_ne0(tensor_storage_map, next_res0_bias);
|
||||
break;
|
||||
}
|
||||
if (!has_tensor(tensor_storage_map, next_conv_bias)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t multiplier = current_channels > 0 && next_channels > 0 && next_channels % current_channels == 0
|
||||
? std::max<int64_t>(1, next_channels / current_channels)
|
||||
: 1;
|
||||
const int64_t factor = conv_channels > 0 && next_channels > 0 && next_channels % conv_channels == 0
|
||||
? next_channels / conv_channels
|
||||
: 0;
|
||||
|
||||
if (factor == 8) {
|
||||
cfg.blocks.push_back({"compress_all_res", 0, static_cast<int>(multiplier)});
|
||||
} else if (factor == 4) {
|
||||
cfg.blocks.push_back({"compress_space_res", 0, static_cast<int>(multiplier)});
|
||||
} else if (factor == 2) {
|
||||
cfg.blocks.push_back({"compress_time_res", 0, static_cast<int>(multiplier)});
|
||||
} else {
|
||||
LOG_WARN("unexpected LTX VAE encoder downsample factor at '%s': conv_out=%lld current=%lld next=%lld, falling back to compress_all_res x%d",
|
||||
block_prefix.c_str(),
|
||||
(long long)conv_channels,
|
||||
(long long)current_channels,
|
||||
(long long)next_channels,
|
||||
(int)multiplier);
|
||||
cfg.blocks.push_back({"compress_all_res", 0, static_cast<int>(multiplier)});
|
||||
}
|
||||
if (next_channels > 0) {
|
||||
current_channels = next_channels;
|
||||
} else if (current_channels > 0) {
|
||||
current_channels *= multiplier;
|
||||
}
|
||||
}
|
||||
|
||||
if (cfg.blocks.empty()) {
|
||||
return get_default_encoder_config(version);
|
||||
}
|
||||
return cfg;
|
||||
}
|
||||
|
||||
static inline int detect_ltx_vae_version(const String2TensorStorage& tensor_storage_map,
|
||||
const std::string& prefix) {
|
||||
const std::string v2_probe = prefix + ".encoder.down_blocks.1.conv.conv.bias";
|
||||
@ -647,7 +736,7 @@ namespace LTXVAE {
|
||||
return tensor_storage_map.find(prefix + ".decoder.timestep_scale_multiplier") != tensor_storage_map.end();
|
||||
}
|
||||
|
||||
static inline EncoderConfig get_encoder_config(int version) {
|
||||
static inline EncoderConfig get_default_encoder_config(int version) {
|
||||
EncoderConfig cfg;
|
||||
if (version < 2) {
|
||||
GGML_ABORT("LTX VAE encoder is only implemented for version >= 2");
|
||||
@ -674,6 +763,8 @@ namespace LTXVAE {
|
||||
int64_t latent_channels;
|
||||
|
||||
Encoder(int version,
|
||||
const String2TensorStorage& tensor_storage_map,
|
||||
const std::string& prefix,
|
||||
int patch_size = 4,
|
||||
int64_t in_channels = 3,
|
||||
int64_t latent_channels = 128)
|
||||
@ -681,9 +772,12 @@ namespace LTXVAE {
|
||||
patch_size(patch_size),
|
||||
in_channels(in_channels),
|
||||
latent_channels(latent_channels) {
|
||||
auto cfg = get_encoder_config(version);
|
||||
int64_t channels = 128;
|
||||
int64_t in_dim = in_channels * patch_size * patch_size;
|
||||
auto cfg = infer_encoder_config_from_weights(tensor_storage_map, prefix, version);
|
||||
int64_t channels = get_tensor_ne0(tensor_storage_map,
|
||||
prefix + ".encoder.conv_in.conv.bias",
|
||||
0);
|
||||
GGML_ASSERT(channels > 0);
|
||||
int64_t in_dim = in_channels * patch_size * patch_size;
|
||||
|
||||
blocks["conv_in"] = std::make_shared<CausalConv3d>(in_dim, channels, 3);
|
||||
|
||||
@ -708,11 +802,14 @@ namespace LTXVAE {
|
||||
1);
|
||||
channels = next_channels;
|
||||
} else if (block.type == "compress_all_res") {
|
||||
int64_t next_channels = channels * block.multiplier;
|
||||
int64_t next_channels = channels * block.multiplier;
|
||||
// LTX 2.3 encoder down_blocks.7.conv overflows with fp16 accumulation.
|
||||
bool force_conv_prec_f32 = block_idx == 7;
|
||||
blocks["down_blocks." + std::to_string(block_idx)] = std::make_shared<SpaceToDepthDownsample>(channels,
|
||||
next_channels,
|
||||
2,
|
||||
2);
|
||||
2,
|
||||
force_conv_prec_f32);
|
||||
channels = next_channels;
|
||||
} else {
|
||||
GGML_ABORT("Unsupported LTX VAE encoder block");
|
||||
@ -956,7 +1053,10 @@ namespace LTXVAE {
|
||||
patch_size(patch_size),
|
||||
decode_only(decode_only) {
|
||||
if (!decode_only) {
|
||||
blocks["encoder"] = std::make_shared<Encoder>(version, patch_size);
|
||||
blocks["encoder"] = std::make_shared<Encoder>(version,
|
||||
tensor_storage_map,
|
||||
prefix,
|
||||
patch_size);
|
||||
}
|
||||
blocks["decoder"] = std::make_shared<Decoder>(version,
|
||||
tensor_storage_map,
|
||||
@ -1096,7 +1196,7 @@ struct LTXVideoVAE : public VAE {
|
||||
const sd::Tensor<float>& z,
|
||||
bool decode_graph) override {
|
||||
if (!decode_graph && decode_only) {
|
||||
LOG_ERROR("LTX video VAE encoder is not implemented yet");
|
||||
LOG_ERROR("LTX video VAE encode requires encoder weights; create the context with vae_decode_only=false");
|
||||
return {};
|
||||
}
|
||||
sd::Tensor<float> input = z;
|
||||
|
||||
76
src/ltxv.hpp
76
src/ltxv.hpp
@ -23,12 +23,28 @@ namespace LTXV {
|
||||
return ggml_rms_norm(ctx, x, eps);
|
||||
}
|
||||
|
||||
__STATIC_INLINE__ ggml_tensor* align_token_modulation(ggml_context* ctx,
|
||||
ggml_tensor* x,
|
||||
ggml_tensor* mod) {
|
||||
if (mod != nullptr && x != nullptr && mod->ne[1] == 1 && mod->ne[2] == x->ne[1] && x->ne[2] == 1) {
|
||||
return ggml_permute(ctx, mod, 0, 2, 1, 3);
|
||||
}
|
||||
return mod;
|
||||
}
|
||||
|
||||
__STATIC_INLINE__ ggml_tensor* modulate(ggml_context* ctx,
|
||||
ggml_tensor* x,
|
||||
ggml_tensor* shift,
|
||||
ggml_tensor* scale) {
|
||||
shift = align_token_modulation(ctx, x, shift);
|
||||
scale = align_token_modulation(ctx, x, scale);
|
||||
return Flux::modulate(ctx, x, shift, scale, true);
|
||||
}
|
||||
|
||||
__STATIC_INLINE__ ggml_tensor* apply_gate(ggml_context* ctx,
|
||||
ggml_tensor* x,
|
||||
ggml_tensor* gate) {
|
||||
if (gate->ne[1] != 1) {
|
||||
gate = ggml_reshape_3d(ctx, gate, gate->ne[0], 1, gate->ne[2]);
|
||||
}
|
||||
gate = align_token_modulation(ctx, x, gate);
|
||||
return ggml_mul(ctx, x, gate);
|
||||
}
|
||||
|
||||
@ -538,7 +554,7 @@ namespace LTXV {
|
||||
auto gate_mlp = mods[5];
|
||||
|
||||
auto x_norm = rms_norm(ctx->ggml_ctx, x);
|
||||
x_norm = Flux::modulate(ctx->ggml_ctx, x_norm, shift_msa, scale_msa, true);
|
||||
x_norm = modulate(ctx->ggml_ctx, x_norm, shift_msa, scale_msa);
|
||||
auto msa = attn1->forward(ctx, x_norm, nullptr, self_attention_mask, pe);
|
||||
x = ggml_add(ctx->ggml_ctx, x, apply_gate(ctx->ggml_ctx, msa, gate_msa));
|
||||
|
||||
@ -548,12 +564,12 @@ namespace LTXV {
|
||||
auto gate_q = mods[8];
|
||||
|
||||
auto q = rms_norm(ctx->ggml_ctx, x);
|
||||
q = Flux::modulate(ctx->ggml_ctx, q, shift_q, scale_q, true);
|
||||
q = modulate(ctx->ggml_ctx, q, shift_q, scale_q);
|
||||
|
||||
auto context_mod = context;
|
||||
if (prompt_timestep != nullptr) {
|
||||
auto prompt_mods = get_prompt_scale_shift_values(ctx, prompt_timestep);
|
||||
context_mod = Flux::modulate(ctx->ggml_ctx, context_mod, prompt_mods[0], prompt_mods[1], true);
|
||||
context_mod = modulate(ctx->ggml_ctx, context_mod, prompt_mods[0], prompt_mods[1]);
|
||||
}
|
||||
|
||||
auto mca = attn2->forward(ctx, q, context_mod, attention_mask, nullptr, nullptr);
|
||||
@ -564,7 +580,7 @@ namespace LTXV {
|
||||
}
|
||||
|
||||
auto y = rms_norm(ctx->ggml_ctx, x);
|
||||
y = Flux::modulate(ctx->ggml_ctx, y, shift_mlp, scale_mlp, true);
|
||||
y = modulate(ctx->ggml_ctx, y, shift_mlp, scale_mlp);
|
||||
auto mlp_out = ff->forward(ctx, y);
|
||||
x = ggml_add(ctx->ggml_ctx, x, apply_gate(ctx->ggml_ctx, mlp_out, gate_mlp));
|
||||
return x;
|
||||
@ -947,11 +963,11 @@ namespace LTXV {
|
||||
if (cross_attention_adaln) {
|
||||
auto q_mods = get_ada_values(ctx, table, timestep, dim, 9, 6, 3);
|
||||
auto q = rms_norm(ctx->ggml_ctx, x);
|
||||
q = Flux::modulate(ctx->ggml_ctx, q, q_mods[0], q_mods[1], true);
|
||||
q = modulate(ctx->ggml_ctx, q, q_mods[0], q_mods[1]);
|
||||
auto context_mod = context;
|
||||
if (prompt_timestep != nullptr && prompt_table != nullptr) {
|
||||
auto p_mods = get_ada_values(ctx, prompt_table, prompt_timestep, dim, 2);
|
||||
context_mod = Flux::modulate(ctx->ggml_ctx, context_mod, p_mods[0], p_mods[1], true);
|
||||
context_mod = modulate(ctx->ggml_ctx, context_mod, p_mods[0], p_mods[1]);
|
||||
}
|
||||
auto out = attn->forward(ctx, q, context_mod, attention_mask, nullptr, nullptr);
|
||||
return apply_gate(ctx->ggml_ctx, out, q_mods[2]);
|
||||
@ -998,7 +1014,7 @@ namespace LTXV {
|
||||
|
||||
auto v_mods = get_ada_values(ctx, v_table, v_timestep, v_dim, cross_attention_adaln ? 9 : 6);
|
||||
auto v_norm = rms_norm(ctx->ggml_ctx, vx);
|
||||
v_norm = Flux::modulate(ctx->ggml_ctx, v_norm, v_mods[0], v_mods[1], true);
|
||||
v_norm = modulate(ctx->ggml_ctx, v_norm, v_mods[0], v_mods[1]);
|
||||
auto v_sa = attn1->forward(ctx, v_norm, nullptr, self_attention_mask, v_pe);
|
||||
vx = ggml_add(ctx->ggml_ctx, vx, apply_gate(ctx->ggml_ctx, v_sa, v_mods[2]));
|
||||
auto v_txt = apply_text_cross_attention(ctx,
|
||||
@ -1016,7 +1032,7 @@ namespace LTXV {
|
||||
if (run_ax) {
|
||||
auto a_mods = get_ada_values(ctx, a_table, a_timestep, a_dim, cross_attention_adaln ? 9 : 6);
|
||||
auto a_norm = rms_norm(ctx->ggml_ctx, ax);
|
||||
a_norm = Flux::modulate(ctx->ggml_ctx, a_norm, a_mods[0], a_mods[1], true);
|
||||
a_norm = modulate(ctx->ggml_ctx, a_norm, a_mods[0], a_mods[1]);
|
||||
auto a_sa = audio_attn1->forward(ctx, a_norm, nullptr, nullptr, a_pe);
|
||||
ax = ggml_add(ctx->ggml_ctx, ax, apply_gate(ctx->ggml_ctx, a_sa, a_mods[2]));
|
||||
auto a_txt = apply_text_cross_attention(ctx,
|
||||
@ -1039,8 +1055,8 @@ namespace LTXV {
|
||||
auto a2v_video_table = ggml_ext_slice(ctx->ggml_ctx, params["scale_shift_table_a2v_ca_video"], 1, 0, 4);
|
||||
auto a2v_audio = get_ada_values(ctx, a2v_audio_table, a_cross_scale_shift_timestep, a_dim, 4);
|
||||
auto a2v_video = get_ada_values(ctx, a2v_video_table, v_cross_scale_shift_timestep, v_dim, 4);
|
||||
auto vx_scaled = Flux::modulate(ctx->ggml_ctx, vx_norm3, a2v_video[1], a2v_video[0], true);
|
||||
auto ax_scaled = Flux::modulate(ctx->ggml_ctx, ax_norm3, a2v_audio[1], a2v_audio[0], true);
|
||||
auto vx_scaled = modulate(ctx->ggml_ctx, vx_norm3, a2v_video[1], a2v_video[0]);
|
||||
auto ax_scaled = modulate(ctx->ggml_ctx, ax_norm3, a2v_audio[1], a2v_audio[0]);
|
||||
auto a2v_out = audio_to_video_attn->forward(ctx, vx_scaled, ax_scaled, nullptr, v_cross_pe, a_cross_pe);
|
||||
auto a2v_gate_table = ggml_ext_slice(ctx->ggml_ctx, params["scale_shift_table_a2v_ca_video"], 1, 4, 5);
|
||||
auto a2v_gate = get_ada_values(ctx, a2v_gate_table, v_cross_gate_timestep, v_dim, 1)[0];
|
||||
@ -1052,8 +1068,8 @@ namespace LTXV {
|
||||
auto v2a_video_table = ggml_ext_slice(ctx->ggml_ctx, params["scale_shift_table_a2v_ca_video"], 1, 0, 4);
|
||||
auto v2a_audio = get_ada_values(ctx, v2a_audio_table, a_cross_scale_shift_timestep, a_dim, 4);
|
||||
auto v2a_video = get_ada_values(ctx, v2a_video_table, v_cross_scale_shift_timestep, v_dim, 4);
|
||||
auto ax_scaled = Flux::modulate(ctx->ggml_ctx, ax_norm3, v2a_audio[3], v2a_audio[2], true);
|
||||
auto vx_scaled = Flux::modulate(ctx->ggml_ctx, vx_norm3, v2a_video[3], v2a_video[2], true);
|
||||
auto ax_scaled = modulate(ctx->ggml_ctx, ax_norm3, v2a_audio[3], v2a_audio[2]);
|
||||
auto vx_scaled = modulate(ctx->ggml_ctx, vx_norm3, v2a_video[3], v2a_video[2]);
|
||||
auto v2a_out = video_to_audio_attn->forward(ctx, ax_scaled, vx_scaled, nullptr, a_cross_pe, v_cross_pe);
|
||||
auto v2a_gate_table = ggml_ext_slice(ctx->ggml_ctx, params["scale_shift_table_a2v_ca_audio"], 1, 4, 5);
|
||||
auto v2a_gate = get_ada_values(ctx, v2a_gate_table, a_cross_gate_timestep, a_dim, 1)[0];
|
||||
@ -1061,14 +1077,14 @@ namespace LTXV {
|
||||
}
|
||||
auto a_ff_mods = get_ada_values(ctx, a_table, a_timestep, a_dim, cross_attention_adaln ? 9 : 6, 3, 3);
|
||||
auto ax_scaled = rms_norm(ctx->ggml_ctx, ax);
|
||||
ax_scaled = Flux::modulate(ctx->ggml_ctx, ax_scaled, a_ff_mods[0], a_ff_mods[1], true);
|
||||
ax_scaled = modulate(ctx->ggml_ctx, ax_scaled, a_ff_mods[0], a_ff_mods[1]);
|
||||
auto a_ff_out = audio_ff->forward(ctx, ax_scaled);
|
||||
ax = ggml_add(ctx->ggml_ctx, ax, apply_gate(ctx->ggml_ctx, a_ff_out, a_ff_mods[2]));
|
||||
}
|
||||
|
||||
auto v_ff_mods = get_ada_values(ctx, v_table, v_timestep, v_dim, cross_attention_adaln ? 9 : 6, 3, 3);
|
||||
auto vx_scaled = rms_norm(ctx->ggml_ctx, vx);
|
||||
vx_scaled = Flux::modulate(ctx->ggml_ctx, vx_scaled, v_ff_mods[0], v_ff_mods[1], true);
|
||||
vx_scaled = modulate(ctx->ggml_ctx, vx_scaled, v_ff_mods[0], v_ff_mods[1]);
|
||||
auto v_ff_out = ff->forward(ctx, vx_scaled);
|
||||
vx = ggml_add(ctx->ggml_ctx, vx, apply_gate(ctx->ggml_ctx, v_ff_out, v_ff_mods[2]));
|
||||
|
||||
@ -1188,6 +1204,15 @@ namespace LTXV {
|
||||
return ax;
|
||||
}
|
||||
|
||||
ggml_tensor* repeat_scalar_timestep_like(GGMLRunnerContext* ctx, ggml_tensor* timestep, ggml_tensor* like) {
|
||||
GGML_ASSERT(timestep != nullptr && like != nullptr);
|
||||
if (timestep->ne[0] == like->ne[0]) {
|
||||
return timestep;
|
||||
}
|
||||
GGML_ASSERT(timestep->ne[0] == 1);
|
||||
return ggml_repeat(ctx->ggml_ctx, timestep, ggml_new_tensor_1d(ctx->ggml_ctx, timestep->type, like->ne[0]));
|
||||
}
|
||||
|
||||
ggml_tensor* unpatchify_audio(GGMLRunnerContext* ctx, ggml_tensor* ax, int64_t audio_length) {
|
||||
if (ax == nullptr) {
|
||||
return nullptr;
|
||||
@ -1367,21 +1392,24 @@ namespace LTXV {
|
||||
if (cfg.cross_attention_adaln) {
|
||||
auto prompt_adaln_single = std::dynamic_pointer_cast<AdaLayerNormSingle>(blocks["prompt_adaln_single"]);
|
||||
auto audio_prompt_adaln_single = std::dynamic_pointer_cast<AdaLayerNormSingle>(blocks["audio_prompt_adaln_single"]);
|
||||
v_prompt_timestep_mod = prompt_adaln_single->forward(ctx, v_timestep_scaled).first;
|
||||
v_prompt_timestep_mod = prompt_adaln_single->forward(ctx, a_timestep_scaled).first;
|
||||
a_prompt_timestep_mod = audio_prompt_adaln_single->forward(ctx, a_timestep_scaled).first;
|
||||
}
|
||||
|
||||
auto av_ca_video_timestep = repeat_scalar_timestep_like(ctx, effective_audio_timestep, timestep);
|
||||
auto av_ca_audio_timestep = effective_audio_timestep;
|
||||
auto av_ca_factor = cfg.av_ca_timestep_scale_multiplier / cfg.timestep_scale_multiplier;
|
||||
auto av_ca_video_scale_shift_timestep =
|
||||
std::dynamic_pointer_cast<AdaLayerNormSingle>(blocks["av_ca_video_scale_shift_adaln_single"])->forward(ctx, a_timestep_scaled).first;
|
||||
std::dynamic_pointer_cast<AdaLayerNormSingle>(blocks["av_ca_video_scale_shift_adaln_single"])->forward(ctx, av_ca_video_timestep).first;
|
||||
auto av_ca_a2v_gate_noise_timestep =
|
||||
std::dynamic_pointer_cast<AdaLayerNormSingle>(blocks["av_ca_a2v_gate_adaln_single"])
|
||||
->forward(ctx, ggml_ext_scale(ctx->ggml_ctx, a_timestep_scaled, cfg.av_ca_timestep_scale_multiplier / cfg.timestep_scale_multiplier))
|
||||
->forward(ctx, ggml_ext_scale(ctx->ggml_ctx, av_ca_video_timestep, av_ca_factor))
|
||||
.first;
|
||||
auto av_ca_audio_scale_shift_timestep =
|
||||
std::dynamic_pointer_cast<AdaLayerNormSingle>(blocks["av_ca_audio_scale_shift_adaln_single"])->forward(ctx, v_timestep_scaled).first;
|
||||
std::dynamic_pointer_cast<AdaLayerNormSingle>(blocks["av_ca_audio_scale_shift_adaln_single"])->forward(ctx, av_ca_audio_timestep).first;
|
||||
auto av_ca_v2a_gate_noise_timestep =
|
||||
std::dynamic_pointer_cast<AdaLayerNormSingle>(blocks["av_ca_v2a_gate_adaln_single"])
|
||||
->forward(ctx, ggml_ext_scale(ctx->ggml_ctx, v_timestep_scaled, cfg.av_ca_timestep_scale_multiplier / cfg.timestep_scale_multiplier))
|
||||
->forward(ctx, ggml_ext_scale(ctx->ggml_ctx, av_ca_audio_timestep, av_ca_factor))
|
||||
.first;
|
||||
|
||||
for (int i = 0; i < cfg.num_layers; i++) {
|
||||
@ -1410,14 +1438,14 @@ namespace LTXV {
|
||||
|
||||
auto v_shift_scale = get_output_scale_shift(ctx, params["scale_shift_table"], v_embedded_time, cfg.hidden_size);
|
||||
vx = norm_out->forward(ctx, vx);
|
||||
vx = Flux::modulate(ctx->ggml_ctx, vx, v_shift_scale[0], v_shift_scale[1], true);
|
||||
vx = modulate(ctx->ggml_ctx, vx, v_shift_scale[0], v_shift_scale[1]);
|
||||
vx = proj_out->forward(ctx, vx);
|
||||
vx = unpatchify_video(ctx, vx, width, height, frames);
|
||||
|
||||
if (ax != nullptr && audio_time > 0) {
|
||||
auto a_shift_scale = get_output_scale_shift(ctx, params["audio_scale_shift_table"], a_embedded_time, cfg.audio_hidden_size);
|
||||
ax = audio_norm_out->forward(ctx, ax);
|
||||
ax = Flux::modulate(ctx->ggml_ctx, ax, a_shift_scale[0], a_shift_scale[1], true);
|
||||
ax = modulate(ctx->ggml_ctx, ax, a_shift_scale[0], a_shift_scale[1]);
|
||||
ax = audio_proj_out->forward(ctx, ax);
|
||||
ax = unpatchify_audio(ctx, ax, audio_time);
|
||||
}
|
||||
|
||||
@ -458,10 +458,6 @@ public:
|
||||
// Might need vae encode for control cond
|
||||
vae_decode_only = false;
|
||||
}
|
||||
if (sd_version_is_ltxav(version)) {
|
||||
vae_decode_only = true;
|
||||
}
|
||||
|
||||
bool tae_preview_only = sd_ctx_params->tae_preview_only;
|
||||
if (version == VERSION_SDXS_512_DS || version == VERSION_SDXS_09) {
|
||||
tae_preview_only = false;
|
||||
@ -705,7 +701,7 @@ public:
|
||||
params_backend_for(SDBackendModule::VAE),
|
||||
tensor_storage_map,
|
||||
"first_stage_model",
|
||||
true,
|
||||
vae_decode_only,
|
||||
version);
|
||||
} else if (sd_version_is_wan(version) ||
|
||||
sd_version_is_qwen_image(version) ||
|
||||
@ -1570,6 +1566,38 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<float> process_ltxav_video_timesteps(const std::vector<float>& timesteps,
|
||||
const sd::Tensor<float>& init_latent,
|
||||
const sd::Tensor<float>& denoise_mask) {
|
||||
if (timesteps.empty() || denoise_mask.empty() || init_latent.dim() < 4 || denoise_mask.dim() < 4) {
|
||||
return timesteps;
|
||||
}
|
||||
|
||||
int64_t width = init_latent.shape()[0];
|
||||
int64_t height = init_latent.shape()[1];
|
||||
int64_t frames = init_latent.shape()[2];
|
||||
if (denoise_mask.shape()[0] != width ||
|
||||
denoise_mask.shape()[1] != height ||
|
||||
denoise_mask.shape()[2] != frames ||
|
||||
denoise_mask.shape()[3] < 1) {
|
||||
LOG_WARN("unexpected LTXAV denoise mask shape for timestep processing");
|
||||
return timesteps;
|
||||
}
|
||||
|
||||
std::vector<float> video_timesteps(static_cast<size_t>(width * height * frames));
|
||||
size_t idx = 0;
|
||||
for (int64_t t = 0; t < frames; ++t) {
|
||||
for (int64_t h = 0; h < height; ++h) {
|
||||
for (int64_t w = 0; w < width; ++w) {
|
||||
float mask = denoise_mask.dim() == 5 ? denoise_mask.index(w, h, t, 0, 0)
|
||||
: denoise_mask.index(w, h, t, 0);
|
||||
video_timesteps[idx++] = mask * timesteps[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
return video_timesteps;
|
||||
}
|
||||
|
||||
void preview_image(int step,
|
||||
const sd::Tensor<float>& latents,
|
||||
enum SDVersion version,
|
||||
@ -1846,14 +1874,24 @@ public:
|
||||
float c_out = scaling[1];
|
||||
float c_in = scaling[2];
|
||||
|
||||
std::vector<float> timesteps_vec = prepare_sample_timesteps(sigma, shifted_timestep);
|
||||
timesteps_vec = process_timesteps(timesteps_vec, init_latent, denoise_mask);
|
||||
adjust_sample_step_scalings(shifted_timestep, timesteps_vec, c_in, &c_skip, &c_out);
|
||||
std::vector<float> base_timesteps_vec = prepare_sample_timesteps(sigma, shifted_timestep);
|
||||
std::vector<float> timesteps_vec = base_timesteps_vec;
|
||||
sd::Tensor<float> audio_timesteps_tensor;
|
||||
if (sd_version_is_ltxav(version) && !denoise_mask.empty()) {
|
||||
timesteps_vec = process_ltxav_video_timesteps(base_timesteps_vec, init_latent, denoise_mask);
|
||||
audio_timesteps_tensor = sd::Tensor<float>({static_cast<int64_t>(base_timesteps_vec.size())}, base_timesteps_vec);
|
||||
} else {
|
||||
timesteps_vec = process_timesteps(timesteps_vec, init_latent, denoise_mask);
|
||||
}
|
||||
const std::vector<float>& scaling_timesteps_vec = (sd_version_is_ltxav(version) && !denoise_mask.empty())
|
||||
? base_timesteps_vec
|
||||
: timesteps_vec;
|
||||
adjust_sample_step_scalings(shifted_timestep, scaling_timesteps_vec, c_in, &c_skip, &c_out);
|
||||
|
||||
sd::Tensor<float> timesteps_tensor({static_cast<int64_t>(timesteps_vec.size())}, timesteps_vec);
|
||||
sd::Tensor<float> guidance_tensor({1}, std::vector<float>{guidance.distilled_guidance});
|
||||
sd::Tensor<float> noised_input = x * c_in;
|
||||
if (!denoise_mask.empty() && version == VERSION_WAN2_2_TI2V) {
|
||||
if (!denoise_mask.empty() && (version == VERSION_WAN2_2_TI2V || sd_version_is_ltxav(version))) {
|
||||
noised_input = noised_input * denoise_mask + init_latent * (1.0f - denoise_mask);
|
||||
}
|
||||
|
||||
@ -1884,6 +1922,7 @@ public:
|
||||
DiffusionParams diffusion_params;
|
||||
diffusion_params.x = &noised_input;
|
||||
diffusion_params.timesteps = ×teps_tensor;
|
||||
diffusion_params.audio_timesteps = audio_timesteps_tensor.empty() ? nullptr : &audio_timesteps_tensor;
|
||||
diffusion_params.guidance = &guidance_tensor;
|
||||
diffusion_params.ref_latents = &ref_latents;
|
||||
diffusion_params.increase_ref_index = increase_ref_index;
|
||||
@ -2916,6 +2955,7 @@ struct GenerationRequest {
|
||||
vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
|
||||
diffusion_model_down_factor = sd_ctx->sd->get_diffusion_model_down_factor();
|
||||
seed = sd_vid_gen_params->seed;
|
||||
strength = sd_vid_gen_params->strength;
|
||||
cache_params = &sd_vid_gen_params->cache;
|
||||
vace_strength = sd_vid_gen_params->vace_strength;
|
||||
guidance = sd_vid_gen_params->sample_params.guidance;
|
||||
@ -3204,6 +3244,55 @@ static sd::Tensor<float> pack_ltxav_audio_and_video_latents(const sd::Tensor<flo
|
||||
return packed;
|
||||
}
|
||||
|
||||
static sd::Tensor<float> pack_ltxav_audio_and_video_denoise_mask(const sd::Tensor<float>& video_mask,
|
||||
const sd::Tensor<float>& video_latent,
|
||||
const sd::Tensor<float>& audio_latent) {
|
||||
if (video_mask.empty() || audio_latent.empty()) {
|
||||
return video_mask;
|
||||
}
|
||||
|
||||
GGML_ASSERT(video_latent.dim() == 4 || video_latent.dim() == 5);
|
||||
GGML_ASSERT(audio_latent.dim() == 3 || audio_latent.dim() == 4);
|
||||
if (video_latent.dim() == 5) {
|
||||
GGML_ASSERT(video_latent.shape()[4] == 1);
|
||||
}
|
||||
if (audio_latent.dim() == 4) {
|
||||
GGML_ASSERT(audio_latent.shape()[3] == 1);
|
||||
}
|
||||
|
||||
int64_t width = video_latent.shape()[0];
|
||||
int64_t height = video_latent.shape()[1];
|
||||
int64_t frames = video_latent.shape()[2];
|
||||
int64_t video_ch = video_latent.shape()[3];
|
||||
int64_t spatial_size = width * height * frames;
|
||||
int64_t audio_values = audio_latent.numel();
|
||||
int64_t extra_ch = (audio_values + spatial_size - 1) / spatial_size;
|
||||
|
||||
GGML_ASSERT(video_mask.dim() == video_latent.dim());
|
||||
GGML_ASSERT(video_mask.shape()[0] == width);
|
||||
GGML_ASSERT(video_mask.shape()[1] == height);
|
||||
GGML_ASSERT(video_mask.shape()[2] == frames);
|
||||
if (video_mask.dim() == 5) {
|
||||
GGML_ASSERT(video_mask.shape()[4] == video_latent.shape()[4]);
|
||||
}
|
||||
|
||||
int64_t mask_ch = video_mask.shape()[3];
|
||||
if (mask_ch == video_ch + extra_ch) {
|
||||
return video_mask;
|
||||
}
|
||||
GGML_ASSERT(mask_ch == 1 || mask_ch == video_ch);
|
||||
|
||||
sd::Tensor<float> video_mask_full = video_mask;
|
||||
if (mask_ch == 1 && video_ch != 1) {
|
||||
video_mask_full = video_mask * sd::Tensor<float>::ones(video_latent.shape());
|
||||
}
|
||||
|
||||
std::vector<int64_t> audio_mask_shape = video_latent.shape();
|
||||
audio_mask_shape[3] = extra_ch;
|
||||
auto audio_mask = sd::Tensor<float>::ones(audio_mask_shape);
|
||||
return sd::ops::concat(video_mask_full, audio_mask, 3);
|
||||
}
|
||||
|
||||
static sd::Tensor<float> unpack_ltxav_audio_latent(const sd::Tensor<float>& packed_latent,
|
||||
int audio_length,
|
||||
int video_channels) {
|
||||
@ -4030,10 +4119,47 @@ static std::optional<ImageGenerationLatents> prepare_video_generation_latents(sd
|
||||
}
|
||||
|
||||
if (sd_version_is_ltxav(sd_ctx->sd->version)) {
|
||||
if (!start_image.empty() || !end_image.empty() || sd_vid_gen_params->control_frames_size > 0) {
|
||||
LOG_ERROR("LTXAV currently supports txt2vid only; init_image, end_image, and control_frames are not implemented");
|
||||
if (!end_image.empty() || sd_vid_gen_params->control_frames_size > 0) {
|
||||
LOG_ERROR("LTXAV currently supports txt2vid and init_image i2v only; end_image and control_frames are not implemented");
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (!start_image.empty()) {
|
||||
if (sd_ctx->sd->vae_decode_only) {
|
||||
LOG_ERROR("LTXAV init_image i2v requires VAE encoder weights; create the context with vae_decode_only=false");
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
LOG_INFO("IMG2VID");
|
||||
|
||||
int64_t t1 = ggml_time_ms();
|
||||
auto init_img = start_image.reshape({start_image.shape()[0],
|
||||
start_image.shape()[1],
|
||||
1,
|
||||
start_image.shape()[2],
|
||||
start_image.shape()[3]});
|
||||
auto init_image_latent = sd_ctx->sd->encode_first_stage(init_img);
|
||||
if (init_image_latent.empty()) {
|
||||
LOG_ERROR("failed to encode LTXAV init image");
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
latents.init_latent = sd_ctx->sd->generate_init_latent(request->width, request->height, request->frames, true);
|
||||
sd::ops::slice_assign(&latents.init_latent, 2, 0, init_image_latent.shape()[2], init_image_latent);
|
||||
|
||||
float conditioning_strength = std::clamp(request->strength, 0.f, 1.f);
|
||||
float conditioned_mask = 1.0f - conditioning_strength;
|
||||
latents.denoise_mask = sd::full<float>({latents.init_latent.shape()[0],
|
||||
latents.init_latent.shape()[1],
|
||||
latents.init_latent.shape()[2],
|
||||
1,
|
||||
1},
|
||||
1.f);
|
||||
sd::ops::fill_slice(&latents.denoise_mask, 2, 0, init_image_latent.shape()[2], conditioned_mask);
|
||||
|
||||
int64_t t2 = ggml_time_ms();
|
||||
LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1);
|
||||
}
|
||||
}
|
||||
|
||||
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" ||
|
||||
@ -4207,6 +4333,11 @@ static std::optional<ImageGenerationLatents> prepare_video_generation_latents(sd
|
||||
}
|
||||
|
||||
if (sd_version_is_ltxav(sd_ctx->sd->version) && !latents.audio_latent.empty()) {
|
||||
if (!latents.denoise_mask.empty()) {
|
||||
latents.denoise_mask = pack_ltxav_audio_and_video_denoise_mask(latents.denoise_mask,
|
||||
latents.init_latent,
|
||||
latents.audio_latent);
|
||||
}
|
||||
latents.init_latent = pack_ltxav_audio_and_video_latents(latents.init_latent, latents.audio_latent);
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user