#ifndef __SD_MODEL_DIFFUSION_LTXV_HPP__ #define __SD_MODEL_DIFFUSION_LTXV_HPP__ #include #include #include #include #include #include #include #include "model/common/block.hpp" #include "model/common/rope.hpp" #include "model/diffusion/flux.hpp" #include "model/diffusion/model.hpp" namespace LTXV { constexpr int LTXAV_GRAPH_SIZE = 102400; __STATIC_INLINE__ ggml_tensor* rms_norm(ggml_context* ctx, ggml_tensor* x, float eps = 1e-6f) { return ggml_rms_norm(ctx, x, eps); } __STATIC_INLINE__ ggml_tensor* align_token_modulation(ggml_context* ctx, ggml_tensor* x, ggml_tensor* mod) { if (mod != nullptr && x != nullptr && mod->ne[1] == 1 && mod->ne[2] == x->ne[1] && x->ne[2] == 1) { return ggml_permute(ctx, mod, 0, 2, 1, 3); } return mod; } __STATIC_INLINE__ ggml_tensor* modulate(ggml_context* ctx, ggml_tensor* x, ggml_tensor* shift, ggml_tensor* scale) { shift = align_token_modulation(ctx, x, shift); scale = align_token_modulation(ctx, x, scale); return Flux::modulate(ctx, x, shift, scale, true); } __STATIC_INLINE__ ggml_tensor* apply_gate(ggml_context* ctx, ggml_tensor* x, ggml_tensor* gate) { gate = align_token_modulation(ctx, x, gate); return ggml_mul(ctx, x, gate); } __STATIC_INLINE__ int count_prefix_blocks(const String2TensorStorage& tensor_storage_map, const std::string& prefix, const std::string& marker) { int max_block = -1; for (const auto& [name, _] : tensor_storage_map) { if (!starts_with(name, prefix)) { continue; } size_t pos = name.find(marker); if (pos == std::string::npos) { continue; } pos += marker.size(); size_t end = name.find(".", pos); if (end == std::string::npos) { continue; } int block = atoi(name.substr(pos, end - pos).c_str()); max_block = std::max(max_block, block); } return max_block + 1; } struct LTXAVConfig { int64_t in_channels = 128; int64_t out_channels = 128; int64_t hidden_size = 3840; int64_t cross_attention_dim = 4096; int64_t caption_channels = 3840; int64_t num_attention_heads = 30; int64_t attention_head_dim = 128; int64_t num_layers = 28; float positional_embedding_theta = 10000.f; std::vector positional_embedding_max_pos = {20, 2048, 2048}; std::tuple vae_scale_factors = {8, 32, 32}; bool causal_temporal_positioning = true; float timestep_scale_multiplier = 1000.f; int64_t audio_in_channels = 128; int64_t audio_out_channels = 128; int64_t audio_hidden_size = 2048; int64_t audio_cross_attention_dim = 2048; int64_t audio_num_attention_heads = 32; int64_t audio_attention_head_dim = 64; std::vector audio_positional_embedding_max_pos = {20}; float av_ca_timestep_scale_multiplier = 1000.f; int64_t num_audio_channels = 8; int64_t audio_frequency_bins = 16; bool use_connector = false; int64_t connector_hidden_size = 3840; int64_t connector_num_heads = 30; int64_t connector_head_dim = 128; int64_t connector_num_layers = 2; int64_t connector_num_registers = 128; bool connector_rope_interleaved = false; bool connector_apply_gated_attention = false; bool use_audio_connector = false; int64_t audio_connector_hidden_size = 2048; int64_t audio_connector_num_heads = 32; int64_t audio_connector_head_dim = 64; int64_t audio_connector_num_layers = 2; int64_t audio_connector_num_registers = 128; bool audio_connector_rope_interleaved = false; bool audio_connector_apply_gated_attention = false; bool video_rope_interleaved = false; bool use_middle_indices_grid = true; bool cross_attention_adaln = false; bool use_caption_projection = true; bool use_audio_caption_projection = true; bool caption_proj_before_connector = true; bool caption_projection_first_linear = false; bool self_attention_gated = false; bool cross_attention_gated = false; static std::pair infer_attention_layout(int64_t hidden_size, int64_t preferred_heads = -1) { if (preferred_heads > 0 && hidden_size % preferred_heads == 0) { return {preferred_heads, hidden_size / preferred_heads}; } const int candidates[] = {128, 96, 80, 64, 48, 40, 32}; for (int head_dim : candidates) { if (hidden_size % head_dim == 0) { int64_t heads = hidden_size / head_dim; if (heads >= 8 && heads <= 64) { return {heads, head_dim}; } } } return {32, hidden_size / 32}; } static int64_t infer_gate_heads(const String2TensorStorage& tensor_storage_map, const std::string& bias_name, int64_t fallback_heads) { auto it = tensor_storage_map.find(bias_name); if (it != tensor_storage_map.end()) { return it->second.ne[0]; } return fallback_heads; } static LTXAVConfig detect_from_weights(const String2TensorStorage& tensor_storage_map, const std::string& prefix) { LTXAVConfig config; auto patchify_proj_iter = tensor_storage_map.find(prefix + ".patchify_proj.weight"); if (patchify_proj_iter != tensor_storage_map.end()) { config.in_channels = patchify_proj_iter->second.ne[0]; config.hidden_size = patchify_proj_iter->second.ne[1]; int64_t video_heads = infer_gate_heads(tensor_storage_map, prefix + ".transformer_blocks.0.attn1.to_gate_logits.bias", 32); auto attn_layout = infer_attention_layout(config.hidden_size, video_heads); config.num_attention_heads = attn_layout.first; config.attention_head_dim = attn_layout.second; } auto audio_patchify_proj_iter = tensor_storage_map.find(prefix + ".audio_patchify_proj.weight"); if (audio_patchify_proj_iter != tensor_storage_map.end()) { config.audio_in_channels = audio_patchify_proj_iter->second.ne[0]; config.audio_hidden_size = audio_patchify_proj_iter->second.ne[1]; config.audio_out_channels = config.audio_in_channels; int64_t audio_heads = infer_gate_heads(tensor_storage_map, prefix + ".transformer_blocks.0.audio_attn1.to_gate_logits.bias", 32); auto audio_attn_layout = infer_attention_layout(config.audio_hidden_size, audio_heads); config.audio_num_attention_heads = audio_attn_layout.first; config.audio_attention_head_dim = audio_attn_layout.second; } auto proj_out_iter = tensor_storage_map.find(prefix + ".proj_out.weight"); if (proj_out_iter != tensor_storage_map.end()) { config.out_channels = proj_out_iter->second.ne[1]; } auto audio_proj_out_iter = tensor_storage_map.find(prefix + ".audio_proj_out.weight"); if (audio_proj_out_iter != tensor_storage_map.end()) { config.audio_out_channels = audio_proj_out_iter->second.ne[1]; } auto attn2_iter = tensor_storage_map.find(prefix + ".transformer_blocks.0.attn2.to_k.weight"); if (attn2_iter != tensor_storage_map.end()) { config.cross_attention_dim = attn2_iter->second.ne[0]; } auto audio_attn2_iter = tensor_storage_map.find(prefix + ".transformer_blocks.0.audio_attn2.to_k.weight"); if (audio_attn2_iter != tensor_storage_map.end()) { config.audio_cross_attention_dim = audio_attn2_iter->second.ne[0]; } if (tensor_storage_map.find(prefix + ".transformer_blocks.0.prompt_scale_shift_table") != tensor_storage_map.end()) { config.cross_attention_adaln = true; } if (tensor_storage_map.find(prefix + ".transformer_blocks.0.attn1.to_gate_logits.weight") != tensor_storage_map.end() || tensor_storage_map.find(prefix + ".transformer_blocks.0.audio_attn1.to_gate_logits.weight") != tensor_storage_map.end()) { config.self_attention_gated = true; } if (tensor_storage_map.find(prefix + ".transformer_blocks.0.attn2.to_gate_logits.weight") != tensor_storage_map.end() || tensor_storage_map.find(prefix + ".transformer_blocks.0.audio_attn2.to_gate_logits.weight") != tensor_storage_map.end()) { config.cross_attention_gated = true; } if (tensor_storage_map.find(prefix + ".caption_projection.linear_1.weight") == tensor_storage_map.end() && tensor_storage_map.find(prefix + ".caption_projection.linear_2.weight") == tensor_storage_map.end()) { config.use_caption_projection = false; } if (tensor_storage_map.find(prefix + ".audio_caption_projection.linear_1.weight") == tensor_storage_map.end() && tensor_storage_map.find(prefix + ".audio_caption_projection.linear_2.weight") == tensor_storage_map.end()) { config.use_audio_caption_projection = false; } config.num_layers = count_prefix_blocks(tensor_storage_map, prefix + ".", "transformer_blocks."); auto connector_iter = tensor_storage_map.find(prefix + ".video_embeddings_connector.transformer_1d_blocks.0.attn1.to_q.weight"); if (connector_iter != tensor_storage_map.end()) { config.use_connector = true; config.connector_hidden_size = connector_iter->second.ne[1]; int64_t connector_heads = infer_gate_heads(tensor_storage_map, prefix + ".video_embeddings_connector.transformer_1d_blocks.0.attn1.to_gate_logits.bias", 32); auto connector_layout = infer_attention_layout(config.connector_hidden_size, connector_heads); config.connector_num_heads = connector_layout.first; config.connector_head_dim = connector_layout.second; config.connector_num_layers = count_prefix_blocks(tensor_storage_map, prefix + ".video_embeddings_connector.", "transformer_1d_blocks."); auto register_iter = tensor_storage_map.find(prefix + ".video_embeddings_connector.learnable_registers"); if (register_iter != tensor_storage_map.end()) { config.connector_num_registers = register_iter->second.ne[1]; } if (tensor_storage_map.find(prefix + ".video_embeddings_connector.transformer_1d_blocks.0.attn1.to_gate_logits.weight") != tensor_storage_map.end()) { config.connector_apply_gated_attention = true; } } auto audio_connector_iter = tensor_storage_map.find(prefix + ".audio_embeddings_connector.transformer_1d_blocks.0.attn1.to_q.weight"); if (audio_connector_iter != tensor_storage_map.end()) { config.use_audio_connector = true; config.audio_connector_hidden_size = audio_connector_iter->second.ne[1]; int64_t connector_heads = infer_gate_heads(tensor_storage_map, prefix + ".audio_embeddings_connector.transformer_1d_blocks.0.attn1.to_gate_logits.bias", 32); auto connector_layout = infer_attention_layout(config.audio_connector_hidden_size, connector_heads); config.audio_connector_num_heads = connector_layout.first; config.audio_connector_head_dim = connector_layout.second; config.audio_connector_num_layers = count_prefix_blocks(tensor_storage_map, prefix + ".audio_embeddings_connector.", "transformer_1d_blocks."); auto register_iter = tensor_storage_map.find(prefix + ".audio_embeddings_connector.learnable_registers"); if (register_iter != tensor_storage_map.end()) { config.audio_connector_num_registers = register_iter->second.ne[1]; } if (tensor_storage_map.find(prefix + ".audio_embeddings_connector.transformer_1d_blocks.0.attn1.to_gate_logits.weight") != tensor_storage_map.end()) { config.audio_connector_apply_gated_attention = true; } } LOG_DEBUG("ltxav: num_layers = %" PRId64 ", hidden_size = %" PRId64 ", num_attention_heads = %" PRId64 ", audio_hidden_size = %" PRId64 ", audio_num_attention_heads = %" PRId64, config.num_layers, config.hidden_size, config.num_attention_heads, config.audio_hidden_size, config.audio_num_attention_heads); return config; } }; __STATIC_INLINE__ std::vector generate_freq_grid(float theta, int positional_dims, int dim) { const int n_elem = 2 * positional_dims; const int freq_count = dim / n_elem; std::vector out(freq_count); if (freq_count <= 0) { return out; } if (freq_count == 1) { out[0] = 1.5707963267948966f; return out; } const float half_pi = 1.5707963267948966f; const float log_theta = std::log(theta); for (int i = 0; i < freq_count; i++) { float ratio = static_cast(i) / static_cast(freq_count - 1); out[i] = std::exp(log_theta * ratio) * half_pi; } return out; } __STATIC_INLINE__ std::vector generate_freq_grid_double(double theta, int positional_dims, int dim) { const int n_elem = 2 * positional_dims; const int freq_count = dim / n_elem; std::vector out(freq_count); if (freq_count <= 0) { return out; } if (freq_count == 1) { out[0] = 1.5707963267948966; return out; } const double half_pi = 1.5707963267948966; const double log_theta = std::log(theta); for (int i = 0; i < freq_count; i++) { double ratio = static_cast(i) / static_cast(freq_count - 1); out[i] = std::exp(log_theta * ratio) * half_pi; } return out; } __STATIC_INLINE__ std::vector build_rope_matrix_from_frequencies( const std::vector>& frequencies, int dim) { const int half_dim = dim / 2; std::vector out(static_cast(frequencies.size()) * static_cast(half_dim) * 4, 0.f); for (size_t token = 0; token < frequencies.size(); token++) { for (int i = 0; i < half_dim; i++) { float angle = i < static_cast(frequencies[token].size()) ? frequencies[token][i] : 0.f; float c = std::cos(angle); float s = std::sin(angle); size_t base = (token * static_cast(half_dim) + static_cast(i)) * 4; out[base + 0] = c; out[base + 1] = -s; out[base + 2] = s; out[base + 3] = c; } } return out; } __STATIC_INLINE__ std::vector> split_frequencies_by_heads( const std::vector>& frequencies, int inner_dim, int num_heads) { GGML_ASSERT(num_heads > 0); GGML_ASSERT(inner_dim % num_heads == 0); const int inner_half_dim = inner_dim / 2; const int per_head_half_dim = inner_half_dim / num_heads; GGML_ASSERT(inner_half_dim % num_heads == 0); std::vector> out( frequencies.size() * static_cast(num_heads), std::vector(per_head_half_dim, 0.f)); for (size_t token = 0; token < frequencies.size(); token++) { GGML_ASSERT(static_cast(frequencies[token].size()) == inner_half_dim); for (int head = 0; head < num_heads; head++) { auto& dst = out[token * static_cast(num_heads) + static_cast(head)]; std::copy_n(frequencies[token].begin() + head * per_head_half_dim, per_head_half_dim, dst.begin()); } } return out; } __STATIC_INLINE__ std::vector build_video_rope_matrix(int64_t width, int64_t height, int64_t frames, int dim, int num_heads = 1, float frame_rate = 24.f, float theta = 10000.f, const std::vector& max_pos = {20, 2048, 2048}, const std::tuple& vae_scale_factors = {8, 32, 32}, bool causal_temporal_positioning = false, bool use_middle_indices_grid = false) { GGML_ASSERT(max_pos.size() == 3); GGML_ASSERT(dim % num_heads == 0); const std::vector indices = generate_freq_grid(theta, 3, dim); const int half_dim = dim / 2; const int pad_size = half_dim - static_cast(indices.size()) * 3; std::vector> freqs(static_cast(width * height * frames), std::vector(half_dim, 0.f)); const int scale_t = std::get<0>(vae_scale_factors); const int scale_h = std::get<1>(vae_scale_factors); const int scale_w = std::get<2>(vae_scale_factors); size_t token = 0; for (int64_t t = 0; t < frames; t++) { float pixel_t = static_cast(t * scale_t); if (causal_temporal_positioning) { pixel_t = std::max(0.f, pixel_t + 1.f - scale_t); } pixel_t /= frame_rate; if (use_middle_indices_grid) { float end = static_cast((t + 1) * scale_t); if (causal_temporal_positioning) { end = std::max(0.f, end + 1.f - scale_t); } end /= frame_rate; pixel_t = 0.5f * (pixel_t + end); } for (int64_t h = 0; h < height; h++) { float pixel_h = static_cast(h * scale_h); if (use_middle_indices_grid) { pixel_h += 0.5f * static_cast(scale_h); } for (int64_t w = 0; w < width; w++) { float pixel_w = static_cast(w * scale_w); if (use_middle_indices_grid) { pixel_w += 0.5f * static_cast(scale_w); } int out_idx = 0; for (int i = 0; i < pad_size; i++) { freqs[token][out_idx++] = 0.f; } const float coords[3] = { pixel_t / max_pos[0], pixel_h / max_pos[1], pixel_w / max_pos[2], }; for (float index : indices) { for (int axis = 0; axis < 3; axis++) { freqs[token][out_idx++] = index * (coords[axis] * 2.f - 1.f); } } token++; } } } if (num_heads > 1) { return build_rope_matrix_from_frequencies(split_frequencies_by_heads(freqs, dim, num_heads), dim / num_heads); } return build_rope_matrix_from_frequencies(freqs, dim); } __STATIC_INLINE__ std::vector build_video_rope_matrix_from_positions(const sd::Tensor& positions, int dim, int num_heads, float theta, const std::vector& max_pos, bool use_middle_indices_grid) { GGML_ASSERT(max_pos.size() == 3); GGML_ASSERT(dim % num_heads == 0); GGML_ASSERT(positions.dim() == 3 || positions.dim() == 4); GGML_ASSERT(positions.shape()[0] == 2); GGML_ASSERT(positions.shape()[1] == 3); if (positions.dim() == 4) { GGML_ASSERT(positions.shape()[3] == 1); } const int64_t tokens = positions.shape()[2]; const std::vector indices = generate_freq_grid(theta, 3, dim); const int half_dim = dim / 2; const int pad_size = half_dim - static_cast(indices.size()) * 3; std::vector> freqs(static_cast(tokens), std::vector(half_dim, 0.f)); for (int64_t token = 0; token < tokens; token++) { int out_idx = 0; for (int i = 0; i < pad_size; i++) { freqs[token][out_idx++] = 0.f; } float coords[3]; for (int axis = 0; axis < 3; axis++) { float start = positions.dim() == 4 ? positions.index(0, axis, token, 0) : positions.index(0, axis, token); float end = positions.dim() == 4 ? positions.index(1, axis, token, 0) : positions.index(1, axis, token); float coord = use_middle_indices_grid ? 0.5f * (start + end) : start; coords[axis] = coord / static_cast(max_pos[axis]); } for (float index : indices) { for (int axis = 0; axis < 3; axis++) { freqs[token][out_idx++] = index * (coords[axis] * 2.f - 1.f); } } } if (num_heads > 1) { return build_rope_matrix_from_frequencies(split_frequencies_by_heads(freqs, dim, num_heads), dim / num_heads); } return build_rope_matrix_from_frequencies(freqs, dim); } __STATIC_INLINE__ std::vector build_1d_rope_matrix(int64_t seq_len, int dim, int num_heads = 1, float theta = 10000.f, float positional_scale = 4096.f, bool double_precision = false) { GGML_ASSERT(dim % num_heads == 0); const std::vector indices = double_precision ? std::vector() : generate_freq_grid(theta, 1, dim); const std::vector indices_d = double_precision ? generate_freq_grid_double(static_cast(theta), 1, dim) : std::vector(); const int half_dim = dim / 2; const int pad_size = half_dim - static_cast(double_precision ? indices_d.size() : indices.size()); std::vector> freqs(static_cast(seq_len), std::vector(half_dim, 0.f)); for (int64_t pos = 0; pos < seq_len; pos++) { int out_idx = 0; for (int i = 0; i < pad_size; i++) { freqs[static_cast(pos)][out_idx++] = 0.f; } if (double_precision) { double coord = static_cast(pos) / static_cast(positional_scale); for (double index : indices_d) { freqs[static_cast(pos)][out_idx++] = static_cast(index * (coord * 2.0 - 1.0)); } } else { float coord = static_cast(pos) / positional_scale; for (float index : indices) { freqs[static_cast(pos)][out_idx++] = index * (coord * 2.f - 1.f); } } } if (num_heads > 1) { return build_rope_matrix_from_frequencies(split_frequencies_by_heads(freqs, dim, num_heads), dim / num_heads); } return build_rope_matrix_from_frequencies(freqs, dim); } __STATIC_INLINE__ ggml_tensor* apply_hidden_rope(ggml_context* ctx, ggml_tensor* x, ggml_tensor* pe, int64_t heads, int64_t dim_head, bool rope_interleaved) { GGML_ASSERT(x->ne[0] == heads * dim_head); auto x4 = ggml_reshape_4d(ctx, x, dim_head, heads, x->ne[1], x->ne[2]); if (pe != nullptr && pe->ne[3] == x->ne[1] * heads) { auto x_flat = ggml_reshape_4d(ctx, x4, dim_head, 1, x->ne[1] * heads, x->ne[2]); auto out_flat = Rope::apply_rope(ctx, x_flat, pe, rope_interleaved); auto out4 = ggml_reshape_4d(ctx, out_flat, dim_head, heads, x->ne[1], x->ne[2]); return ggml_reshape_3d(ctx, out4, heads * dim_head, x->ne[1], x->ne[2]); } return Rope::apply_rope(ctx, x4, pe, rope_interleaved); } struct TimestepEmbedder : public GGMLBlock { int frequency_embedding_size; TimestepEmbedder(int64_t hidden_size, int frequency_embedding_size = 256) : frequency_embedding_size(frequency_embedding_size) { blocks["linear_1"] = std::make_shared(frequency_embedding_size, hidden_size, true, true); blocks["linear_2"] = std::make_shared(hidden_size, hidden_size, true, true); } ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* timestep) { auto linear_1 = std::dynamic_pointer_cast(blocks["linear_1"]); auto linear_2 = std::dynamic_pointer_cast(blocks["linear_2"]); auto t_emb = ggml_ext_timestep_embedding(ctx->ggml_ctx, timestep, frequency_embedding_size); t_emb = linear_1->forward(ctx, t_emb); t_emb = ggml_silu_inplace(ctx->ggml_ctx, t_emb); t_emb = linear_2->forward(ctx, t_emb); return t_emb; } }; struct AdaLayerNormSingle : public GGMLBlock { int64_t embedding_dim; int64_t embedding_coefficient; AdaLayerNormSingle(int64_t embedding_dim, int64_t embedding_coefficient = 6) : embedding_dim(embedding_dim), embedding_coefficient(embedding_coefficient) { blocks["emb.timestep_embedder"] = std::make_shared(embedding_dim); blocks["linear"] = std::make_shared(embedding_dim, embedding_coefficient * embedding_dim, true, true); } std::pair forward(GGMLRunnerContext* ctx, ggml_tensor* timestep) { auto timestep_embedder = std::dynamic_pointer_cast(blocks["emb.timestep_embedder"]); auto linear = std::dynamic_pointer_cast(blocks["linear"]); auto embedded_timestep = timestep_embedder->forward(ctx, timestep); auto hidden = ggml_silu(ctx->ggml_ctx, embedded_timestep); auto out = linear->forward(ctx, hidden); return {out, embedded_timestep}; } }; struct PixArtAlphaTextProjection : public GGMLBlock { PixArtAlphaTextProjection(int64_t in_features, int64_t hidden_size, int64_t out_features = -1) { if (out_features < 0) { out_features = hidden_size; } blocks["linear_1"] = std::make_shared(in_features, hidden_size, true, true); blocks["linear_2"] = std::make_shared(hidden_size, out_features, true, true); } ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* caption) { auto linear_1 = std::dynamic_pointer_cast(blocks["linear_1"]); auto linear_2 = std::dynamic_pointer_cast(blocks["linear_2"]); caption = linear_1->forward(ctx, caption); caption = ggml_ext_gelu(ctx->ggml_ctx, caption, true); caption = linear_2->forward(ctx, caption); return caption; } }; struct NormSingleLinearTextProjection : public GGMLBlock { int64_t in_features; int64_t hidden_size; NormSingleLinearTextProjection(int64_t in_features, int64_t hidden_size) : in_features(in_features), hidden_size(hidden_size) { blocks["linear_1"] = std::make_shared(in_features, hidden_size, true, true); } ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* caption) { auto linear_1 = std::dynamic_pointer_cast(blocks["linear_1"]); caption = ggml_rms_norm(ctx->ggml_ctx, caption, 1e-6f); caption = ggml_ext_scale(ctx->ggml_ctx, caption, std::sqrt(static_cast(hidden_size) / static_cast(in_features))); return linear_1->forward(ctx, caption); } }; struct CrossAttention : public GGMLBlock { int64_t heads; int64_t dim_head; bool rope_interleaved; CrossAttention(int64_t query_dim, int64_t context_dim, int64_t heads, int64_t dim_head, bool apply_gated_attention = false, bool rope_interleaved = true) : heads(heads), dim_head(dim_head), rope_interleaved(rope_interleaved) { int64_t inner_dim = heads * dim_head; blocks["q_norm"] = std::make_shared(inner_dim, 1e-5f); blocks["k_norm"] = std::make_shared(inner_dim, 1e-5f); blocks["to_q"] = std::make_shared(query_dim, inner_dim, true); blocks["to_k"] = std::make_shared(context_dim, inner_dim, true); blocks["to_v"] = std::make_shared(context_dim, inner_dim, true); if (apply_gated_attention) { blocks["to_gate_logits"] = std::make_shared(query_dim, heads, true); } blocks["to_out.0"] = std::make_shared(inner_dim, query_dim, true); } ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* context = nullptr, ggml_tensor* mask = nullptr, ggml_tensor* pe = nullptr, ggml_tensor* k_pe = nullptr) { if (context == nullptr) { context = x; } auto to_q = std::dynamic_pointer_cast(blocks["to_q"]); auto to_k = std::dynamic_pointer_cast(blocks["to_k"]); auto to_v = std::dynamic_pointer_cast(blocks["to_v"]); auto q_norm = std::dynamic_pointer_cast(blocks["q_norm"]); auto k_norm = std::dynamic_pointer_cast(blocks["k_norm"]); auto to_out_0 = std::dynamic_pointer_cast(blocks["to_out.0"]); auto q = to_q->forward(ctx, x); auto k = to_k->forward(ctx, context); auto v = to_v->forward(ctx, context); q = q_norm->forward(ctx, q); k = k_norm->forward(ctx, k); if (pe != nullptr) { if (k_pe == nullptr) { k_pe = pe; } q = apply_hidden_rope(ctx->ggml_ctx, q, pe, heads, dim_head, rope_interleaved); k = apply_hidden_rope(ctx->ggml_ctx, k, k_pe, heads, dim_head, rope_interleaved); } auto out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, heads, mask, false, ctx->flash_attn_enabled); if (blocks.count("to_gate_logits") > 0) { auto to_gate_logits = std::dynamic_pointer_cast(blocks["to_gate_logits"]); auto gate_logits = to_gate_logits->forward(ctx, x); auto gates = ggml_sigmoid(ctx->ggml_ctx, gate_logits); gates = ggml_ext_scale(ctx->ggml_ctx, gates, 2.0f, true); gates = ggml_reshape_4d(ctx->ggml_ctx, gates, 1, heads, gate_logits->ne[1], gate_logits->ne[2]); auto out4 = ggml_reshape_4d(ctx->ggml_ctx, out, dim_head, heads, out->ne[1], out->ne[2]); gates = ggml_repeat(ctx->ggml_ctx, gates, out4); out4 = ggml_mul(ctx->ggml_ctx, out4, gates); out = ggml_reshape_3d(ctx->ggml_ctx, out4, heads * dim_head, out4->ne[2], out4->ne[3]); } return to_out_0->forward(ctx, out); } }; struct BasicTransformerBlock : public GGMLBlock { int64_t dim; bool cross_attention_adaln; bool self_attention_gated; bool cross_attention_gated; void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override { ggml_type wtype = get_type(prefix + "scale_shift_table", tensor_storage_map, GGML_TYPE_F32); params["scale_shift_table"] = ggml_new_tensor_2d(ctx, wtype, dim, cross_attention_adaln ? 9 : 6); if (cross_attention_adaln) { ggml_type prompt_wtype = get_type(prefix + "prompt_scale_shift_table", tensor_storage_map, GGML_TYPE_F32); params["prompt_scale_shift_table"] = ggml_new_tensor_2d(ctx, prompt_wtype, dim, 2); } } BasicTransformerBlock(int64_t dim, int64_t n_heads, int64_t d_head, int64_t context_dim, bool rope_interleaved = true, bool cross_attention_adaln = false, bool self_attention_gated = false, bool cross_attention_gated = false) : dim(dim), cross_attention_adaln(cross_attention_adaln), self_attention_gated(self_attention_gated), cross_attention_gated(cross_attention_gated) { blocks["attn1"] = std::make_shared(dim, dim, n_heads, d_head, self_attention_gated, rope_interleaved); blocks["attn2"] = std::make_shared(dim, context_dim, n_heads, d_head, cross_attention_gated, false); blocks["ff"] = std::make_shared(dim, dim, 4, FeedForward::Activation::GELU); } std::vector get_scale_shift_values(GGMLRunnerContext* ctx, ggml_tensor* timestep) { auto table = params["scale_shift_table"]; int64_t batch = timestep->ne[1]; int64_t coeff = cross_attention_adaln ? 9 : 6; auto t = ggml_reshape_3d(ctx->ggml_ctx, timestep, dim, coeff, batch); auto s = ggml_reshape_3d(ctx->ggml_ctx, table, dim, coeff, 1); auto e = ggml_new_tensor_3d(ctx->ggml_ctx, timestep->type, dim, coeff, batch); s = ggml_repeat(ctx->ggml_ctx, s, e); t = ggml_repeat(ctx->ggml_ctx, t, e); auto out = ggml_add(ctx->ggml_ctx, s, t); return ggml_ext_chunk(ctx->ggml_ctx, out, static_cast(coeff), 1); } std::vector get_prompt_scale_shift_values(GGMLRunnerContext* ctx, ggml_tensor* prompt_timestep) { auto table = params["prompt_scale_shift_table"]; int64_t batch = prompt_timestep->ne[1]; auto t = ggml_reshape_3d(ctx->ggml_ctx, prompt_timestep, dim, 2, batch); auto s = ggml_reshape_3d(ctx->ggml_ctx, table, dim, 2, 1); auto e = ggml_new_tensor_3d(ctx->ggml_ctx, prompt_timestep->type, dim, 2, batch); s = ggml_repeat(ctx->ggml_ctx, s, e); t = ggml_repeat(ctx->ggml_ctx, t, e); auto out = ggml_add(ctx->ggml_ctx, s, t); return ggml_ext_chunk(ctx->ggml_ctx, out, 2, 1); } ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* context, ggml_tensor* timestep, ggml_tensor* prompt_timestep, ggml_tensor* pe, ggml_tensor* attention_mask = nullptr, ggml_tensor* self_attention_mask = nullptr) { auto attn1 = std::dynamic_pointer_cast(blocks["attn1"]); auto attn2 = std::dynamic_pointer_cast(blocks["attn2"]); auto ff = std::dynamic_pointer_cast(blocks["ff"]); auto mods = get_scale_shift_values(ctx, timestep); auto shift_msa = mods[0]; auto scale_msa = mods[1]; auto gate_msa = mods[2]; auto shift_mlp = mods[3]; auto scale_mlp = mods[4]; auto gate_mlp = mods[5]; auto x_norm = rms_norm(ctx->ggml_ctx, x); x_norm = modulate(ctx->ggml_ctx, x_norm, shift_msa, scale_msa); auto msa = attn1->forward(ctx, x_norm, nullptr, self_attention_mask, pe); x = ggml_add(ctx->ggml_ctx, x, apply_gate(ctx->ggml_ctx, msa, gate_msa)); if (cross_attention_adaln) { auto shift_q = mods[6]; auto scale_q = mods[7]; auto gate_q = mods[8]; auto q = rms_norm(ctx->ggml_ctx, x); q = modulate(ctx->ggml_ctx, q, shift_q, scale_q); auto context_mod = context; if (prompt_timestep != nullptr) { auto prompt_mods = get_prompt_scale_shift_values(ctx, prompt_timestep); context_mod = modulate(ctx->ggml_ctx, context_mod, prompt_mods[0], prompt_mods[1]); } auto mca = attn2->forward(ctx, q, context_mod, attention_mask, nullptr, nullptr); x = ggml_add(ctx->ggml_ctx, x, apply_gate(ctx->ggml_ctx, mca, gate_q)); } else { auto mca = attn2->forward(ctx, x, context, attention_mask, nullptr, nullptr); x = ggml_add(ctx->ggml_ctx, x, mca); } auto y = rms_norm(ctx->ggml_ctx, x); y = modulate(ctx->ggml_ctx, y, shift_mlp, scale_mlp); auto mlp_out = ff->forward(ctx, y); x = ggml_add(ctx->ggml_ctx, x, apply_gate(ctx->ggml_ctx, mlp_out, gate_mlp)); return x; } }; struct BasicTransformerBlock1D : public GGMLBlock { BasicTransformerBlock1D(int64_t dim, int64_t n_heads, int64_t d_head, bool rope_interleaved, bool apply_gated_attention = false) { blocks["attn1"] = std::make_shared(dim, dim, n_heads, d_head, apply_gated_attention, rope_interleaved); blocks["ff"] = std::make_shared(dim, dim, 4, FeedForward::Activation::GELU); } ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* pe, ggml_tensor* attention_mask = nullptr) { auto attn1 = std::dynamic_pointer_cast(blocks["attn1"]); auto ff = std::dynamic_pointer_cast(blocks["ff"]); auto h = rms_norm(ctx->ggml_ctx, x); h = attn1->forward(ctx, h, nullptr, attention_mask, pe); x = ggml_add(ctx->ggml_ctx, x, h); h = rms_norm(ctx->ggml_ctx, x); h = ff->forward(ctx, h); x = ggml_add(ctx->ggml_ctx, x, h); return x; } }; struct Embeddings1DConnector : public GGMLBlock { int64_t hidden_size; int64_t num_attention_heads; int64_t attention_head_dim; int64_t num_layers; int64_t num_learnable_registers; bool rope_interleaved; bool apply_gated_attention; void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override { if (num_learnable_registers > 0) { ggml_type wtype = get_type(prefix + "learnable_registers", tensor_storage_map, GGML_TYPE_F32); params["learnable_registers"] = ggml_new_tensor_2d(ctx, wtype, hidden_size, num_learnable_registers); } } Embeddings1DConnector(int64_t hidden_size, int64_t num_attention_heads = 30, int64_t attention_head_dim = 128, int64_t num_layers = 2, int64_t num_learnable_registers = 128, bool rope_interleaved = false, bool apply_gated_attention = false) : hidden_size(hidden_size), num_attention_heads(num_attention_heads), attention_head_dim(attention_head_dim), num_layers(num_layers), num_learnable_registers(num_learnable_registers), rope_interleaved(rope_interleaved), apply_gated_attention(apply_gated_attention) { for (int i = 0; i < num_layers; i++) { blocks["transformer_1d_blocks." + std::to_string(i)] = std::make_shared(hidden_size, num_attention_heads, attention_head_dim, rope_interleaved, apply_gated_attention); } } ggml_tensor* append_registers(GGMLRunnerContext* ctx, ggml_tensor* hidden_states) { if (num_learnable_registers <= 0 || params.count("learnable_registers") == 0) { return hidden_states; } int64_t seq_len = hidden_states->ne[1]; int64_t target_len = std::max(1024, seq_len); int64_t duplications = (target_len + num_learnable_registers - 1) / num_learnable_registers; int64_t total_to_keep = duplications * num_learnable_registers - seq_len; if (total_to_keep <= 0) { return hidden_states; } auto regs = ggml_reshape_3d(ctx->ggml_ctx, params["learnable_registers"], hidden_size, num_learnable_registers, 1); auto temp = ggml_new_tensor_3d(ctx->ggml_ctx, regs->type, regs->ne[0], regs->ne[1], hidden_states->ne[2]); regs = ggml_repeat(ctx->ggml_ctx, regs, temp); auto regs_full = regs; for (int64_t i = 1; i < duplications; i++) { regs_full = ggml_concat(ctx->ggml_ctx, regs_full, regs, 1); } regs_full = ggml_ext_slice(ctx->ggml_ctx, regs_full, 1, seq_len, seq_len + total_to_keep); return ggml_concat(ctx->ggml_ctx, hidden_states, regs_full, 1); } ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* hidden_states, ggml_tensor* pe, ggml_tensor* attention_mask = nullptr) { hidden_states = append_registers(ctx, hidden_states); for (int i = 0; i < num_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["transformer_1d_blocks." + std::to_string(i)]); hidden_states = block->forward(ctx, hidden_states, pe, attention_mask); } return ggml_rms_norm(ctx->ggml_ctx, hidden_states, 1e-6f); } }; __STATIC_INLINE__ std::pair infer_attention_layout(int64_t hidden_size, int64_t preferred_heads = -1) { if (preferred_heads > 0 && hidden_size % preferred_heads == 0) { return {preferred_heads, hidden_size / preferred_heads}; } const int candidates[] = {128, 96, 80, 64, 48, 40, 32}; for (int head_dim : candidates) { if (hidden_size % head_dim == 0) { int64_t heads = hidden_size / head_dim; if (heads >= 8 && heads <= 64) { return {heads, head_dim}; } } } return {32, hidden_size / 32}; } __STATIC_INLINE__ std::vector build_1d_rope_matrix_from_coords(const std::vector& coords, int dim, int num_heads = 1, float theta = 10000.f, float max_pos = 20.f, bool double_precision = false) { GGML_ASSERT(dim % num_heads == 0); const std::vector indices = double_precision ? std::vector() : generate_freq_grid(theta, 1, dim); const std::vector indices_d = double_precision ? generate_freq_grid_double(static_cast(theta), 1, dim) : std::vector(); const int half_dim = dim / 2; const int pad_size = half_dim - static_cast(double_precision ? indices_d.size() : indices.size()); std::vector> freqs(coords.size(), std::vector(half_dim, 0.f)); for (size_t pos = 0; pos < coords.size(); pos++) { int out_idx = 0; for (int i = 0; i < pad_size; i++) { freqs[pos][out_idx++] = 0.f; } if (double_precision) { double coord = static_cast(coords[pos]) / static_cast(max_pos); for (double index : indices_d) { freqs[pos][out_idx++] = static_cast(index * (coord * 2.0 - 1.0)); } } else { float coord = coords[pos] / max_pos; for (float index : indices) { freqs[pos][out_idx++] = index * (coord * 2.f - 1.f); } } } if (num_heads > 1) { return build_rope_matrix_from_frequencies(split_frequencies_by_heads(freqs, dim, num_heads), dim / num_heads); } return build_rope_matrix_from_frequencies(freqs, dim); } __STATIC_INLINE__ float video_latent_corner_to_time_sec(int64_t corner_index, int scale_t, float frame_rate, bool causal_temporal_positioning) { float pixel_t = static_cast(corner_index * scale_t); if (causal_temporal_positioning) { pixel_t = std::max(0.f, pixel_t + 1.f - scale_t); } return pixel_t / frame_rate; } __STATIC_INLINE__ std::vector build_video_temporal_rope_matrix(int64_t width, int64_t height, int64_t frames, int dim, int num_heads, float frame_rate, float theta, int max_pos_t, int scale_t, bool causal_temporal_positioning, bool use_middle_indices_grid) { std::vector coords; coords.reserve(static_cast(width * height * frames)); for (int64_t t = 0; t < frames; t++) { float coord = video_latent_corner_to_time_sec(t, scale_t, frame_rate, causal_temporal_positioning); if (use_middle_indices_grid) { float end = video_latent_corner_to_time_sec(t + 1, scale_t, frame_rate, causal_temporal_positioning); coord = 0.5f * (coord + end); } for (int64_t h = 0; h < height; h++) { for (int64_t w = 0; w < width; w++) { coords.push_back(coord); } } } return build_1d_rope_matrix_from_coords(coords, dim, num_heads, theta, static_cast(max_pos_t)); } __STATIC_INLINE__ std::vector build_video_temporal_rope_matrix_from_positions(const sd::Tensor& positions, int dim, int num_heads, float theta, int max_pos_t, bool use_middle_indices_grid) { GGML_ASSERT(positions.dim() == 3 || positions.dim() == 4); GGML_ASSERT(positions.shape()[0] == 2); GGML_ASSERT(positions.shape()[1] >= 1); if (positions.dim() == 4) { GGML_ASSERT(positions.shape()[3] == 1); } std::vector coords; coords.reserve(static_cast(positions.shape()[2])); for (int64_t token = 0; token < positions.shape()[2]; token++) { float start = positions.dim() == 4 ? positions.index(0, 0, token, 0) : positions.index(0, 0, token); float end = positions.dim() == 4 ? positions.index(1, 0, token, 0) : positions.index(1, 0, token); coords.push_back(use_middle_indices_grid ? 0.5f * (start + end) : start); } return build_1d_rope_matrix_from_coords(coords, dim, num_heads, theta, static_cast(max_pos_t)); } __STATIC_INLINE__ float audio_latent_start_time_sec(int64_t latent_index, int audio_latent_downsample_factor = 4, int hop_length = 160, int sample_rate = 16000, bool causal = true) { float mel_frame = static_cast(latent_index * audio_latent_downsample_factor); if (causal) { mel_frame = std::max(0.f, mel_frame + 1.f - static_cast(audio_latent_downsample_factor)); } return mel_frame * static_cast(hop_length) / static_cast(sample_rate); } __STATIC_INLINE__ std::vector build_audio_rope_matrix(int64_t seq_len, int dim, int num_heads, float theta = 10000.f, int max_pos_t = 20, bool use_middle_indices_grid = false) { std::vector coords(static_cast(seq_len), 0.f); for (int64_t t = 0; t < seq_len; t++) { float start = audio_latent_start_time_sec(t); if (use_middle_indices_grid) { float end = audio_latent_start_time_sec(t + 1); coords[static_cast(t)] = 0.5f * (start + end); } else { coords[static_cast(t)] = start; } } return build_1d_rope_matrix_from_coords(coords, dim, num_heads, theta, static_cast(max_pos_t)); } struct BasicAVTransformerBlock : public GGMLBlock { int64_t v_dim; int64_t a_dim; bool cross_attention_adaln; void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override { int64_t coeff = cross_attention_adaln ? 9 : 6; ggml_type vw = get_type(prefix + "scale_shift_table", tensor_storage_map, GGML_TYPE_F32); ggml_type aw = get_type(prefix + "audio_scale_shift_table", tensor_storage_map, GGML_TYPE_F32); params["scale_shift_table"] = ggml_new_tensor_2d(ctx, vw, v_dim, coeff); params["audio_scale_shift_table"] = ggml_new_tensor_2d(ctx, aw, a_dim, coeff); if (cross_attention_adaln) { ggml_type vpw = get_type(prefix + "prompt_scale_shift_table", tensor_storage_map, GGML_TYPE_F32); ggml_type apw = get_type(prefix + "audio_prompt_scale_shift_table", tensor_storage_map, GGML_TYPE_F32); params["prompt_scale_shift_table"] = ggml_new_tensor_2d(ctx, vpw, v_dim, 2); params["audio_prompt_scale_shift_table"] = ggml_new_tensor_2d(ctx, apw, a_dim, 2); } ggml_type avw = get_type(prefix + "scale_shift_table_a2v_ca_audio", tensor_storage_map, GGML_TYPE_F32); ggml_type vaw = get_type(prefix + "scale_shift_table_a2v_ca_video", tensor_storage_map, GGML_TYPE_F32); params["scale_shift_table_a2v_ca_audio"] = ggml_new_tensor_2d(ctx, avw, a_dim, 5); params["scale_shift_table_a2v_ca_video"] = ggml_new_tensor_2d(ctx, vaw, v_dim, 5); } BasicAVTransformerBlock(int64_t v_dim, int64_t a_dim, int64_t v_heads, int64_t a_heads, int64_t vd_head, int64_t ad_head, int64_t v_context_dim, int64_t a_context_dim, bool apply_gated_attention, bool cross_attention_adaln, bool video_rope_interleaved) : v_dim(v_dim), a_dim(a_dim), cross_attention_adaln(cross_attention_adaln) { blocks["attn1"] = std::make_shared(v_dim, v_dim, v_heads, vd_head, apply_gated_attention, video_rope_interleaved); blocks["audio_attn1"] = std::make_shared(a_dim, a_dim, a_heads, ad_head, apply_gated_attention, false); blocks["attn2"] = std::make_shared(v_dim, v_context_dim, v_heads, vd_head, apply_gated_attention, false); blocks["audio_attn2"] = std::make_shared(a_dim, a_context_dim, a_heads, ad_head, apply_gated_attention, false); blocks["audio_to_video_attn"] = std::make_shared(v_dim, a_dim, a_heads, ad_head, apply_gated_attention, false); blocks["video_to_audio_attn"] = std::make_shared(a_dim, v_dim, a_heads, ad_head, apply_gated_attention, false); blocks["ff"] = std::make_shared(v_dim, v_dim, 4, FeedForward::Activation::GELU); blocks["audio_ff"] = std::make_shared(a_dim, a_dim, 4, FeedForward::Activation::GELU); } std::vector get_ada_values(GGMLRunnerContext* ctx, ggml_tensor* table, ggml_tensor* timestep, int64_t dim, int64_t coeff, int64_t start = 0, int64_t count = -1) { if (count < 0) { count = coeff - start; } auto t = ggml_reshape_3d(ctx->ggml_ctx, timestep, dim, coeff, timestep->ne[1]); auto s = ggml_reshape_3d(ctx->ggml_ctx, table, dim, coeff, 1); auto e = ggml_new_tensor_3d(ctx->ggml_ctx, timestep->type, dim, coeff, timestep->ne[1]); t = ggml_repeat(ctx->ggml_ctx, t, e); s = ggml_repeat(ctx->ggml_ctx, s, e); auto out = ggml_add(ctx->ggml_ctx, s, t); auto chunks = ggml_ext_chunk(ctx->ggml_ctx, out, static_cast(coeff), 1); return std::vector(chunks.begin() + start, chunks.begin() + start + count); } ggml_tensor* apply_text_cross_attention(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* context, CrossAttention* attn, ggml_tensor* table, ggml_tensor* prompt_table, ggml_tensor* timestep, ggml_tensor* prompt_timestep, int64_t dim, ggml_tensor* attention_mask) { if (cross_attention_adaln) { auto q_mods = get_ada_values(ctx, table, timestep, dim, 9, 6, 3); auto q = rms_norm(ctx->ggml_ctx, x); q = modulate(ctx->ggml_ctx, q, q_mods[0], q_mods[1]); auto context_mod = context; if (prompt_timestep != nullptr && prompt_table != nullptr) { auto p_mods = get_ada_values(ctx, prompt_table, prompt_timestep, dim, 2); context_mod = modulate(ctx->ggml_ctx, context_mod, p_mods[0], p_mods[1]); } auto out = attn->forward(ctx, q, context_mod, attention_mask, nullptr, nullptr); return apply_gate(ctx->ggml_ctx, out, q_mods[2]); } auto q = rms_norm(ctx->ggml_ctx, x); return attn->forward(ctx, q, context, attention_mask, nullptr, nullptr); } std::pair forward(GGMLRunnerContext* ctx, ggml_tensor* vx, ggml_tensor* ax, ggml_tensor* v_context, ggml_tensor* a_context, ggml_tensor* attention_mask, ggml_tensor* v_timestep, ggml_tensor* a_timestep, ggml_tensor* v_pe, ggml_tensor* a_pe, ggml_tensor* v_cross_pe, ggml_tensor* a_cross_pe, ggml_tensor* v_cross_scale_shift_timestep, ggml_tensor* a_cross_scale_shift_timestep, ggml_tensor* v_cross_gate_timestep, ggml_tensor* a_cross_gate_timestep, ggml_tensor* v_prompt_timestep, ggml_tensor* a_prompt_timestep, ggml_tensor* self_attention_mask = nullptr) { auto attn1 = std::dynamic_pointer_cast(blocks["attn1"]); auto audio_attn1 = std::dynamic_pointer_cast(blocks["audio_attn1"]); auto attn2 = std::dynamic_pointer_cast(blocks["attn2"]); auto audio_attn2 = std::dynamic_pointer_cast(blocks["audio_attn2"]); auto audio_to_video_attn = std::dynamic_pointer_cast(blocks["audio_to_video_attn"]); auto video_to_audio_attn = std::dynamic_pointer_cast(blocks["video_to_audio_attn"]); auto ff = std::dynamic_pointer_cast(blocks["ff"]); auto audio_ff = std::dynamic_pointer_cast(blocks["audio_ff"]); auto v_table = params["scale_shift_table"]; auto a_table = params["audio_scale_shift_table"]; bool run_ax = ax != nullptr && ggml_nelements(ax) > 0 && ax->ne[1] > 0; bool run_a2v = run_ax; bool run_v2a = run_ax; auto v_mods = get_ada_values(ctx, v_table, v_timestep, v_dim, cross_attention_adaln ? 9 : 6); auto v_norm = rms_norm(ctx->ggml_ctx, vx); v_norm = modulate(ctx->ggml_ctx, v_norm, v_mods[0], v_mods[1]); auto v_sa = attn1->forward(ctx, v_norm, nullptr, self_attention_mask, v_pe); vx = ggml_add(ctx->ggml_ctx, vx, apply_gate(ctx->ggml_ctx, v_sa, v_mods[2])); auto v_txt = apply_text_cross_attention(ctx, vx, v_context, attn2.get(), v_table, cross_attention_adaln ? params["prompt_scale_shift_table"] : nullptr, v_timestep, v_prompt_timestep, v_dim, attention_mask); vx = ggml_add(ctx->ggml_ctx, vx, v_txt); if (run_ax) { auto a_mods = get_ada_values(ctx, a_table, a_timestep, a_dim, cross_attention_adaln ? 9 : 6); auto a_norm = rms_norm(ctx->ggml_ctx, ax); a_norm = modulate(ctx->ggml_ctx, a_norm, a_mods[0], a_mods[1]); auto a_sa = audio_attn1->forward(ctx, a_norm, nullptr, nullptr, a_pe); ax = ggml_add(ctx->ggml_ctx, ax, apply_gate(ctx->ggml_ctx, a_sa, a_mods[2])); auto a_txt = apply_text_cross_attention(ctx, ax, a_context, audio_attn2.get(), a_table, cross_attention_adaln ? params["audio_prompt_scale_shift_table"] : nullptr, a_timestep, a_prompt_timestep, a_dim, attention_mask); ax = ggml_add(ctx->ggml_ctx, ax, a_txt); auto vx_norm3 = rms_norm(ctx->ggml_ctx, vx); auto ax_norm3 = rms_norm(ctx->ggml_ctx, ax); if (run_a2v) { auto a2v_audio_table = ggml_ext_slice(ctx->ggml_ctx, params["scale_shift_table_a2v_ca_audio"], 1, 0, 4); auto a2v_video_table = ggml_ext_slice(ctx->ggml_ctx, params["scale_shift_table_a2v_ca_video"], 1, 0, 4); auto a2v_audio = get_ada_values(ctx, a2v_audio_table, a_cross_scale_shift_timestep, a_dim, 4); auto a2v_video = get_ada_values(ctx, a2v_video_table, v_cross_scale_shift_timestep, v_dim, 4); auto vx_scaled = modulate(ctx->ggml_ctx, vx_norm3, a2v_video[1], a2v_video[0]); auto ax_scaled = modulate(ctx->ggml_ctx, ax_norm3, a2v_audio[1], a2v_audio[0]); auto a2v_out = audio_to_video_attn->forward(ctx, vx_scaled, ax_scaled, nullptr, v_cross_pe, a_cross_pe); auto a2v_gate_table = ggml_ext_slice(ctx->ggml_ctx, params["scale_shift_table_a2v_ca_video"], 1, 4, 5); auto a2v_gate = get_ada_values(ctx, a2v_gate_table, v_cross_gate_timestep, v_dim, 1)[0]; vx = ggml_add(ctx->ggml_ctx, vx, apply_gate(ctx->ggml_ctx, a2v_out, a2v_gate)); } if (run_v2a) { auto v2a_audio_table = ggml_ext_slice(ctx->ggml_ctx, params["scale_shift_table_a2v_ca_audio"], 1, 0, 4); auto v2a_video_table = ggml_ext_slice(ctx->ggml_ctx, params["scale_shift_table_a2v_ca_video"], 1, 0, 4); auto v2a_audio = get_ada_values(ctx, v2a_audio_table, a_cross_scale_shift_timestep, a_dim, 4); auto v2a_video = get_ada_values(ctx, v2a_video_table, v_cross_scale_shift_timestep, v_dim, 4); auto ax_scaled = modulate(ctx->ggml_ctx, ax_norm3, v2a_audio[3], v2a_audio[2]); auto vx_scaled = modulate(ctx->ggml_ctx, vx_norm3, v2a_video[3], v2a_video[2]); auto v2a_out = video_to_audio_attn->forward(ctx, ax_scaled, vx_scaled, nullptr, a_cross_pe, v_cross_pe); auto v2a_gate_table = ggml_ext_slice(ctx->ggml_ctx, params["scale_shift_table_a2v_ca_audio"], 1, 4, 5); auto v2a_gate = get_ada_values(ctx, v2a_gate_table, a_cross_gate_timestep, a_dim, 1)[0]; ax = ggml_add(ctx->ggml_ctx, ax, apply_gate(ctx->ggml_ctx, v2a_out, v2a_gate)); } auto a_ff_mods = get_ada_values(ctx, a_table, a_timestep, a_dim, cross_attention_adaln ? 9 : 6, 3, 3); auto ax_scaled = rms_norm(ctx->ggml_ctx, ax); ax_scaled = modulate(ctx->ggml_ctx, ax_scaled, a_ff_mods[0], a_ff_mods[1]); auto a_ff_out = audio_ff->forward(ctx, ax_scaled); ax = ggml_add(ctx->ggml_ctx, ax, apply_gate(ctx->ggml_ctx, a_ff_out, a_ff_mods[2])); } auto v_ff_mods = get_ada_values(ctx, v_table, v_timestep, v_dim, cross_attention_adaln ? 9 : 6, 3, 3); auto vx_scaled = rms_norm(ctx->ggml_ctx, vx); vx_scaled = modulate(ctx->ggml_ctx, vx_scaled, v_ff_mods[0], v_ff_mods[1]); auto v_ff_out = ff->forward(ctx, vx_scaled); vx = ggml_add(ctx->ggml_ctx, vx, apply_gate(ctx->ggml_ctx, v_ff_out, v_ff_mods[2])); return {vx, ax}; } }; struct LTXAVModelBlock : public GGMLBlock { LTXAVConfig config; void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override { params["scale_shift_table"] = ggml_new_tensor_2d(ctx, get_type(prefix + "scale_shift_table", tensor_storage_map, GGML_TYPE_F32), config.hidden_size, 2); params["audio_scale_shift_table"] = ggml_new_tensor_2d(ctx, get_type(prefix + "audio_scale_shift_table", tensor_storage_map, GGML_TYPE_F32), config.audio_hidden_size, 2); } LTXAVModelBlock(const LTXAVConfig& config) : config(config) { blocks["patchify_proj"] = std::make_shared(config.in_channels, config.hidden_size, true, true); blocks["audio_patchify_proj"] = std::make_shared(config.audio_in_channels, config.audio_hidden_size, true, true); blocks["adaln_single"] = std::make_shared(config.hidden_size, config.cross_attention_adaln ? 9 : 6); blocks["audio_adaln_single"] = std::make_shared(config.audio_hidden_size, config.cross_attention_adaln ? 9 : 6); if (config.cross_attention_adaln) { blocks["prompt_adaln_single"] = std::make_shared(config.hidden_size, 2); blocks["audio_prompt_adaln_single"] = std::make_shared(config.audio_hidden_size, 2); } blocks["av_ca_video_scale_shift_adaln_single"] = std::make_shared(config.hidden_size, 4); blocks["av_ca_a2v_gate_adaln_single"] = std::make_shared(config.hidden_size, 1); blocks["av_ca_audio_scale_shift_adaln_single"] = std::make_shared(config.audio_hidden_size, 4); blocks["av_ca_v2a_gate_adaln_single"] = std::make_shared(config.audio_hidden_size, 1); if (config.use_caption_projection) { if (config.caption_proj_before_connector) { if (config.caption_projection_first_linear) { blocks["caption_projection"] = std::make_shared(config.caption_channels, config.hidden_size); } } else { blocks["caption_projection"] = std::make_shared(config.caption_channels, config.hidden_size, config.hidden_size); } } if (config.use_audio_caption_projection) { if (config.caption_proj_before_connector) { if (config.caption_projection_first_linear) { blocks["audio_caption_projection"] = std::make_shared(config.caption_channels, config.audio_hidden_size); } } else { blocks["audio_caption_projection"] = std::make_shared(config.caption_channels, config.audio_hidden_size, config.audio_hidden_size); } } if (config.use_connector) { blocks["video_embeddings_connector"] = std::make_shared(config.connector_hidden_size, config.connector_num_heads, config.connector_head_dim, config.connector_num_layers, config.connector_num_registers, config.connector_rope_interleaved, config.connector_apply_gated_attention); } if (config.use_audio_connector) { blocks["audio_embeddings_connector"] = std::make_shared(config.audio_connector_hidden_size, config.audio_connector_num_heads, config.audio_connector_head_dim, config.audio_connector_num_layers, config.audio_connector_num_registers, config.audio_connector_rope_interleaved, config.audio_connector_apply_gated_attention); } for (int i = 0; i < config.num_layers; i++) { blocks["transformer_blocks." + std::to_string(i)] = std::make_shared(config.hidden_size, config.audio_hidden_size, config.num_attention_heads, config.audio_num_attention_heads, config.attention_head_dim, config.audio_attention_head_dim, config.cross_attention_dim, config.audio_cross_attention_dim, config.self_attention_gated || config.cross_attention_gated, config.cross_attention_adaln, config.video_rope_interleaved); } blocks["norm_out"] = std::make_shared(config.hidden_size, 1e-6f, false); blocks["proj_out"] = std::make_shared(config.hidden_size, config.out_channels, true, true); blocks["audio_norm_out"] = std::make_shared(config.audio_hidden_size, 1e-6f, false); blocks["audio_proj_out"] = std::make_shared(config.audio_hidden_size, config.audio_out_channels, true, true); } ggml_tensor* patchify_video(GGMLRunnerContext* ctx, ggml_tensor* x, int64_t n) { x = ggml_reshape_3d(ctx->ggml_ctx, x, x->ne[0] * x->ne[1] * x->ne[2], x->ne[3] / n, n); x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); return x; } ggml_tensor* unpatchify_video(GGMLRunnerContext* ctx, ggml_tensor* x, int64_t width, int64_t height, int64_t frames) { x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); x = ggml_reshape_4d(ctx->ggml_ctx, x, width, height, frames, x->ne[1] * x->ne[2]); return x; } ggml_tensor* patchify_audio(GGMLRunnerContext* ctx, ggml_tensor* ax) { // ax: [b, c, t, f] ax = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, ax, 0, 2, 1, 3)); // [b, t, c, f] ax = ggml_reshape_3d(ctx->ggml_ctx, ax, ax->ne[0] * ax->ne[1], ax->ne[2], ax->ne[3]); // [b, t, c*f] return ax; } ggml_tensor* repeat_scalar_timestep_like(GGMLRunnerContext* ctx, ggml_tensor* timestep, ggml_tensor* like) { GGML_ASSERT(timestep != nullptr && like != nullptr); if (timestep->ne[0] == like->ne[0]) { return timestep; } GGML_ASSERT(timestep->ne[0] == 1); return ggml_repeat(ctx->ggml_ctx, timestep, ggml_new_tensor_1d(ctx->ggml_ctx, timestep->type, like->ne[0])); } ggml_tensor* unpatchify_audio(GGMLRunnerContext* ctx, ggml_tensor* ax, int64_t audio_length) { if (ax == nullptr) { return nullptr; } ax = ggml_reshape_4d(ctx->ggml_ctx, ax, config.audio_frequency_bins, config.num_audio_channels, audio_length, ax->ne[2]); // [b, t, c, f] ax = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, ax, 0, 2, 1, 3)); // [b, c, t, f] return ax; } std::pair preprocess_contexts(GGMLRunnerContext* ctx, ggml_tensor* context, ggml_tensor* video_connector_pe, ggml_tensor* audio_connector_pe, bool process_audio_context) { if (context == nullptr) { return {nullptr, nullptr}; } bool is_fully_processed_context = context->ne[0] == config.cross_attention_dim + config.audio_cross_attention_dim && context->ne[1] >= 1024; bool is_unprocessed_dual_context = context->ne[0] == config.cross_attention_dim + config.audio_cross_attention_dim && context->ne[1] < 1024; if (is_fully_processed_context) { auto v_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, 0, config.cross_attention_dim); ggml_tensor* a_context = nullptr; if (process_audio_context) { a_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, config.cross_attention_dim, config.cross_attention_dim + config.audio_cross_attention_dim); } return {v_context, a_context}; } ggml_tensor* v_context = context; ggml_tensor* a_context = process_audio_context ? context : nullptr; if (is_unprocessed_dual_context) { v_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, 0, config.cross_attention_dim); if (process_audio_context) { a_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, config.cross_attention_dim, config.cross_attention_dim + config.audio_cross_attention_dim); } } else if (context->ne[0] == config.caption_channels * 2) { v_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, 0, config.caption_channels); if (process_audio_context) { a_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, config.caption_channels, config.caption_channels * 2); } } if (config.caption_proj_before_connector) { if (config.use_caption_projection && blocks.count("caption_projection") > 0 && v_context != nullptr && v_context->ne[0] == config.caption_channels) { auto caption_projection = std::dynamic_pointer_cast(blocks["caption_projection"]); if (caption_projection != nullptr) { v_context = caption_projection->forward(ctx, v_context); } } if (process_audio_context && config.use_audio_caption_projection && blocks.count("audio_caption_projection") > 0 && a_context != nullptr && a_context->ne[0] == config.caption_channels) { auto caption_projection = std::dynamic_pointer_cast(blocks["audio_caption_projection"]); if (caption_projection != nullptr) { a_context = caption_projection->forward(ctx, a_context); } } } if (config.use_connector && v_context != nullptr && v_context->ne[0] == config.connector_hidden_size) { auto connector = std::dynamic_pointer_cast(blocks["video_embeddings_connector"]); v_context = connector->forward(ctx, v_context, video_connector_pe); } if (process_audio_context && config.use_audio_connector && a_context != nullptr && a_context->ne[0] == config.audio_connector_hidden_size) { auto connector = std::dynamic_pointer_cast(blocks["audio_embeddings_connector"]); a_context = connector->forward(ctx, a_context, audio_connector_pe); } if (!config.caption_proj_before_connector && config.use_caption_projection && blocks.count("caption_projection") > 0 && v_context != nullptr && v_context->ne[0] == config.caption_channels) { auto caption_projection = std::dynamic_pointer_cast(blocks["caption_projection"]); if (caption_projection != nullptr) { v_context = caption_projection->forward(ctx, v_context); } } if (process_audio_context && !config.caption_proj_before_connector && config.use_audio_caption_projection && blocks.count("audio_caption_projection") > 0 && a_context != nullptr && a_context->ne[0] == config.caption_channels) { auto caption_projection = std::dynamic_pointer_cast(blocks["audio_caption_projection"]); if (caption_projection != nullptr) { a_context = caption_projection->forward(ctx, a_context); } } return {v_context, a_context}; } std::vector get_output_scale_shift(GGMLRunnerContext* ctx, ggml_tensor* table, ggml_tensor* embedded_timestep, int64_t dim) { auto temp = ggml_new_tensor_3d(ctx->ggml_ctx, embedded_timestep->type, dim, 2, embedded_timestep->ne[1]); auto t = ggml_repeat(ctx->ggml_ctx, ggml_reshape_3d(ctx->ggml_ctx, embedded_timestep, dim, 1, embedded_timestep->ne[1]), temp); auto s = ggml_repeat(ctx->ggml_ctx, ggml_reshape_3d(ctx->ggml_ctx, table, dim, 2, 1), temp); auto out = ggml_add(ctx->ggml_ctx, s, t); return ggml_ext_chunk(ctx->ggml_ctx, out, 2, 1); } std::pair forward(GGMLRunnerContext* ctx, ggml_tensor* vx, ggml_tensor* ax, ggml_tensor* timestep, ggml_tensor* audio_timestep, ggml_tensor* context, ggml_tensor* v_pe, ggml_tensor* a_pe, ggml_tensor* v_cross_pe, ggml_tensor* a_cross_pe, ggml_tensor* video_connector_pe, ggml_tensor* audio_connector_pe) { auto patchify_proj = std::dynamic_pointer_cast(blocks["patchify_proj"]); auto audio_patchify_proj = std::dynamic_pointer_cast(blocks["audio_patchify_proj"]); auto adaln_single = std::dynamic_pointer_cast(blocks["adaln_single"]); auto audio_adaln_single = std::dynamic_pointer_cast(blocks["audio_adaln_single"]); auto norm_out = std::dynamic_pointer_cast(blocks["norm_out"]); auto proj_out = std::dynamic_pointer_cast(blocks["proj_out"]); auto audio_norm_out = std::dynamic_pointer_cast(blocks["audio_norm_out"]); auto audio_proj_out = std::dynamic_pointer_cast(blocks["audio_proj_out"]); GGML_ASSERT(vx->ne[3] % config.in_channels == 0); int64_t n = vx->ne[3] / config.in_channels; int64_t width = vx->ne[0]; int64_t height = vx->ne[1]; int64_t frames = vx->ne[2]; int64_t audio_time = ax != nullptr ? ax->ne[1] : 0; vx = patchify_video(ctx, vx, n); vx = patchify_proj->forward(ctx, vx); if (ax != nullptr && ggml_nelements(ax) > 0 && audio_time > 0) { ax = patchify_audio(ctx, ax); ax = audio_patchify_proj->forward(ctx, ax); } else { ax = nullptr; } bool run_ax = ax != nullptr && ggml_nelements(ax) > 0 && audio_time > 0; auto contexts = preprocess_contexts(ctx, context, video_connector_pe, audio_connector_pe, run_ax); auto v_context = contexts.first; auto a_context = contexts.second != nullptr ? contexts.second : contexts.first; if (contexts.second != nullptr) { a_context = ggml_cont(ctx->ggml_ctx, a_context); } auto v_timestep_scaled = ggml_ext_scale(ctx->ggml_ctx, timestep, config.timestep_scale_multiplier); auto v_pair = adaln_single->forward(ctx, v_timestep_scaled); auto v_timestep_mod = v_pair.first; auto v_embedded_time = v_pair.second; ggml_tensor* effective_audio_timestep = audio_timestep != nullptr ? audio_timestep : timestep; auto a_timestep_scaled = ggml_ext_scale(ctx->ggml_ctx, effective_audio_timestep, config.timestep_scale_multiplier); auto a_pair = audio_adaln_single->forward(ctx, a_timestep_scaled); auto a_timestep_mod = a_pair.first; auto a_embedded_time = a_pair.second; ggml_tensor* v_prompt_timestep_mod = nullptr; ggml_tensor* a_prompt_timestep_mod = nullptr; if (config.cross_attention_adaln) { auto prompt_adaln_single = std::dynamic_pointer_cast(blocks["prompt_adaln_single"]); auto audio_prompt_adaln_single = std::dynamic_pointer_cast(blocks["audio_prompt_adaln_single"]); v_prompt_timestep_mod = prompt_adaln_single->forward(ctx, a_timestep_scaled).first; a_prompt_timestep_mod = audio_prompt_adaln_single->forward(ctx, a_timestep_scaled).first; } auto av_ca_video_timestep = repeat_scalar_timestep_like(ctx, effective_audio_timestep, timestep); auto av_ca_audio_timestep = effective_audio_timestep; auto av_ca_factor = config.av_ca_timestep_scale_multiplier / config.timestep_scale_multiplier; auto av_ca_video_scale_shift_timestep = std::dynamic_pointer_cast(blocks["av_ca_video_scale_shift_adaln_single"])->forward(ctx, av_ca_video_timestep).first; auto av_ca_a2v_gate_noise_timestep = std::dynamic_pointer_cast(blocks["av_ca_a2v_gate_adaln_single"]) ->forward(ctx, ggml_ext_scale(ctx->ggml_ctx, av_ca_video_timestep, av_ca_factor)) .first; auto av_ca_audio_scale_shift_timestep = std::dynamic_pointer_cast(blocks["av_ca_audio_scale_shift_adaln_single"])->forward(ctx, av_ca_audio_timestep).first; auto av_ca_v2a_gate_noise_timestep = std::dynamic_pointer_cast(blocks["av_ca_v2a_gate_adaln_single"]) ->forward(ctx, ggml_ext_scale(ctx->ggml_ctx, av_ca_audio_timestep, av_ca_factor)) .first; sd::ggml_graph_cut::mark_graph_cut(vx, "ltxav.prelude", "vx"); sd::ggml_graph_cut::mark_graph_cut(ax, "ltxav.prelude", "ax"); for (int i = 0; i < config.num_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["transformer_blocks." + std::to_string(i)]); auto out = block->forward(ctx, vx, ax, v_context, a_context, nullptr, v_timestep_mod, a_timestep_mod, v_pe, a_pe, v_cross_pe, a_cross_pe, av_ca_video_scale_shift_timestep, av_ca_audio_scale_shift_timestep, av_ca_a2v_gate_noise_timestep, av_ca_v2a_gate_noise_timestep, v_prompt_timestep_mod, a_prompt_timestep_mod); vx = out.first; ax = out.second; sd::ggml_graph_cut::mark_graph_cut(vx, "ltxav.transformer_blocks." + std::to_string(i), "vx"); sd::ggml_graph_cut::mark_graph_cut(ax, "ltxav.transformer_blocks." + std::to_string(i), "ax"); } auto v_shift_scale = get_output_scale_shift(ctx, params["scale_shift_table"], v_embedded_time, config.hidden_size); vx = norm_out->forward(ctx, vx); vx = modulate(ctx->ggml_ctx, vx, v_shift_scale[0], v_shift_scale[1]); vx = proj_out->forward(ctx, vx); vx = unpatchify_video(ctx, vx, width, height, frames); if (ax != nullptr && audio_time > 0) { auto a_shift_scale = get_output_scale_shift(ctx, params["audio_scale_shift_table"], a_embedded_time, config.audio_hidden_size); ax = audio_norm_out->forward(ctx, ax); ax = modulate(ctx->ggml_ctx, ax, a_shift_scale[0], a_shift_scale[1]); ax = audio_proj_out->forward(ctx, ax); ax = unpatchify_audio(ctx, ax, audio_time); } return {vx, ax}; } }; struct LTXAVRunner : public DiffusionModelRunner { LTXAVConfig config; LTXAVModelBlock model; std::vector video_pe_vec; std::vector audio_pe_vec; std::vector video_cross_pe_vec; std::vector audio_cross_pe_vec; std::vector connector_pe_vec; std::vector audio_connector_pe_vec; sd::Tensor vx_input_cache; sd::Tensor ax_input_cache; LTXAVRunner(ggml_backend_t backend, ggml_backend_t params_backend, const String2TensorStorage& tensor_storage_map = {}, const std::string& prefix = "model.diffusion_model") : DiffusionModelRunner(backend, params_backend, prefix), config(LTXAVConfig::detect_from_weights(tensor_storage_map, prefix)), model(config) { model.init(params_ctx, tensor_storage_map, prefix); } std::string get_desc() override { return "ltxav"; } void get_param_tensors(std::map& tensors, const std::string& prefix) override { model.get_param_tensors(tensors, prefix); } std::pair, sd::Tensor> split_av_latents(const sd::Tensor& x_tensor, int audio_length) const { if (x_tensor.empty()) { return {{}, {}}; } GGML_ASSERT(x_tensor.dim() == 4 || x_tensor.dim() == 5); if (x_tensor.dim() == 5) { GGML_ASSERT(x_tensor.shape()[4] == 1); } int64_t width = x_tensor.shape()[0]; int64_t height = x_tensor.shape()[1]; int64_t frames = x_tensor.shape()[2]; int64_t total_channels = x_tensor.shape()[3]; int64_t spatial_size = width * height * frames; GGML_ASSERT(total_channels >= config.in_channels); sd::Tensor vx({width, height, frames, config.in_channels}); size_t video_values = static_cast(config.in_channels * spatial_size); std::copy_n(x_tensor.data(), video_values, vx.data()); if (audio_length <= 0 || total_channels == config.in_channels) { return {vx, {}}; } int64_t needed_audio_values = static_cast(audio_length) * config.num_audio_channels * config.audio_frequency_bins; int64_t packed_audio_values = (total_channels - config.in_channels) * spatial_size; GGML_ASSERT(packed_audio_values >= needed_audio_values); sd::Tensor ax({config.audio_frequency_bins, audio_length, config.num_audio_channels, 1}); const float* audio_src = x_tensor.data() + video_values; std::copy_n(audio_src, static_cast(needed_audio_values), ax.data()); return {vx, ax}; } ggml_tensor* merge_av_latents(ggml_context* ctx, ggml_tensor* vx, ggml_tensor* ax) const { if (ax == nullptr || ggml_nelements(ax) == 0 || ax->ne[1] == 0) { return vx; } int64_t width = vx->ne[0]; int64_t height = vx->ne[1]; int64_t frames = vx->ne[2]; int64_t divisor = width * height * frames; int64_t audio_values = ax->ne[0] * ax->ne[1] * ax->ne[2] * ax->ne[3]; int64_t pad_values = (divisor - (audio_values % divisor)) % divisor; int64_t padded_len = audio_values + pad_values; ax = ggml_cont(ctx, ax); ax = ggml_reshape_4d(ctx, ax, audio_values, 1, 1, 1); if (pad_values > 0) { ax = ggml_ext_pad(ctx, ax, static_cast(pad_values), 0, 0, 0); } int64_t extra_channels = padded_len / divisor; ax = ggml_reshape_4d(ctx, ax, width, height, frames, extra_channels); return ggml_concat(ctx, vx, ax, 3); } ggml_cgraph* build_graph(const sd::Tensor& x_tensor, const sd::Tensor& timesteps_tensor, const sd::Tensor& context_tensor = {}, const sd::Tensor& audio_x_tensor = {}, const sd::Tensor& audio_timesteps_tensor = {}, int audio_length = 0, float frame_rate = 24.f, const sd::Tensor& video_positions_tensor = {}) { auto split_inputs = split_av_latents(x_tensor, audio_length); vx_input_cache = split_inputs.first; if (!audio_x_tensor.empty()) { ax_input_cache = audio_x_tensor; } else { ax_input_cache = split_inputs.second; } ggml_tensor* vx = make_input(vx_input_cache); ggml_tensor* ax = make_optional_input(ax_input_cache); ggml_tensor* timesteps = make_input(timesteps_tensor); ggml_tensor* a_timestep = make_optional_input(audio_timesteps_tensor); ggml_tensor* context = make_optional_input(context_tensor); ggml_cgraph* gf = new_graph_custom(LTXAV_GRAPH_SIZE); float video_frame_rate = frame_rate > 0.f ? frame_rate : 24.f; int64_t video_token_count = vx->ne[0] * vx->ne[1] * vx->ne[2]; bool has_video_positions = !video_positions_tensor.empty(); if (has_video_positions) { GGML_ASSERT(video_positions_tensor.shape()[2] == video_token_count); video_pe_vec = build_video_rope_matrix_from_positions(video_positions_tensor, static_cast(config.hidden_size), static_cast(config.num_attention_heads), config.positional_embedding_theta, config.positional_embedding_max_pos, config.use_middle_indices_grid); } else { video_pe_vec = build_video_rope_matrix(vx->ne[0], vx->ne[1], vx->ne[2], static_cast(config.hidden_size), static_cast(config.num_attention_heads), video_frame_rate, config.positional_embedding_theta, config.positional_embedding_max_pos, config.vae_scale_factors, config.causal_temporal_positioning, config.use_middle_indices_grid); } auto video_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.attention_head_dim / 2, video_token_count * config.num_attention_heads); ggml_set_name(video_pe, "ltxav_video_pe"); set_backend_tensor_data(video_pe, video_pe_vec.data()); ggml_tensor* audio_pe = nullptr; ggml_tensor* video_cross_pe = nullptr; ggml_tensor* audio_cross_pe = nullptr; if (ax != nullptr && ggml_nelements(ax) > 0 && ax->ne[1] > 0) { audio_pe_vec = build_audio_rope_matrix(ax->ne[1], static_cast(config.audio_hidden_size), static_cast(config.audio_num_attention_heads), config.positional_embedding_theta, config.audio_positional_embedding_max_pos[0], config.use_middle_indices_grid); audio_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.audio_attention_head_dim / 2, ax->ne[1] * config.audio_num_attention_heads); ggml_set_name(audio_pe, "ltxav_audio_pe"); set_backend_tensor_data(audio_pe, audio_pe_vec.data()); int temporal_max_pos = std::max(config.positional_embedding_max_pos[0], config.audio_positional_embedding_max_pos[0]); if (has_video_positions) { video_cross_pe_vec = build_video_temporal_rope_matrix_from_positions(video_positions_tensor, static_cast(config.audio_cross_attention_dim), static_cast(config.audio_num_attention_heads), config.positional_embedding_theta, temporal_max_pos, true); } else { video_cross_pe_vec = build_video_temporal_rope_matrix(vx->ne[0], vx->ne[1], vx->ne[2], static_cast(config.audio_cross_attention_dim), static_cast(config.audio_num_attention_heads), video_frame_rate, config.positional_embedding_theta, temporal_max_pos, std::get<0>(config.vae_scale_factors), config.causal_temporal_positioning, true); } video_cross_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.audio_attention_head_dim / 2, video_token_count * config.audio_num_attention_heads); ggml_set_name(video_cross_pe, "ltxav_video_cross_pe"); set_backend_tensor_data(video_cross_pe, video_cross_pe_vec.data()); audio_cross_pe_vec = build_audio_rope_matrix(ax->ne[1], static_cast(config.audio_cross_attention_dim), static_cast(config.audio_num_attention_heads), config.positional_embedding_theta, temporal_max_pos, true); audio_cross_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.audio_attention_head_dim / 2, ax->ne[1] * config.audio_num_attention_heads); ggml_set_name(audio_cross_pe, "ltxav_audio_cross_pe"); set_backend_tensor_data(audio_cross_pe, audio_cross_pe_vec.data()); } bool needs_video_connector_pe = config.use_connector && context != nullptr && (context->ne[0] == config.connector_hidden_size || ((context->ne[0] == config.cross_attention_dim + config.audio_cross_attention_dim || context->ne[0] == config.caption_channels * 2) && context->ne[1] < 1024)); ggml_tensor* video_connector_pe = nullptr; if (needs_video_connector_pe) { int64_t seq_len = context->ne[1]; int64_t target_len = std::max(1024, seq_len); int64_t duplications = (target_len + config.connector_num_registers - 1) / config.connector_num_registers; int64_t full_len = seq_len + duplications * config.connector_num_registers - seq_len; connector_pe_vec = build_1d_rope_matrix(full_len, static_cast(config.connector_hidden_size), static_cast(config.connector_num_heads), 10000.f, 4096.f, true); video_connector_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.connector_head_dim / 2, full_len * config.connector_num_heads); ggml_set_name(video_connector_pe, "ltxav_video_connector_pe"); set_backend_tensor_data(video_connector_pe, connector_pe_vec.data()); } bool run_audio_context = ax != nullptr && ggml_nelements(ax) > 0 && ax->ne[1] > 0; bool needs_audio_connector_pe = run_audio_context && config.use_audio_connector && context != nullptr && (context->ne[0] == config.audio_connector_hidden_size || ((context->ne[0] == config.cross_attention_dim + config.audio_cross_attention_dim || context->ne[0] == config.caption_channels * 2) && context->ne[1] < 1024)); ggml_tensor* audio_connector_pe = nullptr; if (needs_audio_connector_pe) { int64_t seq_len = context->ne[1]; int64_t target_len = std::max(1024, seq_len); int64_t duplications = (target_len + config.audio_connector_num_registers - 1) / config.audio_connector_num_registers; int64_t full_len = seq_len + duplications * config.audio_connector_num_registers - seq_len; audio_connector_pe_vec = build_1d_rope_matrix(full_len, static_cast(config.audio_connector_hidden_size), static_cast(config.audio_connector_num_heads), 10000.f, 4096.f, true); audio_connector_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.audio_connector_head_dim / 2, full_len * config.audio_connector_num_heads); ggml_set_name(audio_connector_pe, "ltxav_audio_connector_pe"); set_backend_tensor_data(audio_connector_pe, audio_connector_pe_vec.data()); } auto runner_ctx = get_context(); auto out_pair = model.forward(&runner_ctx, vx, ax, timesteps, a_timestep, context, video_pe, audio_pe, video_cross_pe, audio_cross_pe, video_connector_pe, audio_connector_pe); auto out = merge_av_latents(compute_ctx, out_pair.first, out_pair.second); ggml_build_forward_expand(gf, out); return gf; } sd::Tensor compute(int n_threads, const sd::Tensor& x, const sd::Tensor& timesteps, const sd::Tensor& context = {}, const sd::Tensor& audio_x = {}, const sd::Tensor& audio_timesteps = {}, int audio_length = 0, float frame_rate = 24.f, const sd::Tensor& video_positions = {}) { auto get_graph = [&]() -> ggml_cgraph* { return build_graph(x, timesteps, context, audio_x, audio_timesteps, audio_length, frame_rate, video_positions); }; auto out = restore_trailing_singleton_dims(GGMLRunner::compute(get_graph, n_threads, false), x.dim()); return out; } sd::Tensor compute(int n_threads, const DiffusionParams& diffusion_params) override { GGML_ASSERT(diffusion_params.x != nullptr); GGML_ASSERT(diffusion_params.timesteps != nullptr); const auto* extra = diffusion_extra_as(diffusion_params); return compute(n_threads, *diffusion_params.x, *diffusion_params.timesteps, tensor_or_empty(diffusion_params.context), tensor_or_empty(extra->audio_x), tensor_or_empty(extra->audio_timesteps), extra->audio_length, extra->frame_rate, tensor_or_empty(extra->video_positions)); } void test(const std::string& x_path, const std::string& timesteps_path = "", const std::string& context_path = "", const std::string& audio_x_path = "", const std::string& audio_timesteps_path = "") { auto x = sd::load_tensor_from_file_as_tensor(x_path); GGML_ASSERT(!x.empty()); print_sd_tensor(x, false, "ltxav_x"); sd::Tensor timesteps; if (!timesteps_path.empty()) { timesteps = sd::load_tensor_from_file_as_tensor(timesteps_path); } else { timesteps = sd::Tensor::from_vector(std::vector{1.f}); } GGML_ASSERT(!timesteps.empty()); print_sd_tensor(timesteps, false, "ltxav_timesteps"); sd::Tensor context; if (!context_path.empty()) { context = sd::load_tensor_from_file_as_tensor(context_path); GGML_ASSERT(!context.empty()); print_sd_tensor(context, false, "ltxav_context"); } sd::Tensor audio_x; int audio_length = 0; if (!audio_x_path.empty()) { audio_x = sd::load_tensor_from_file_as_tensor(audio_x_path); GGML_ASSERT(!audio_x.empty()); GGML_ASSERT(audio_x.dim() >= 2); audio_length = static_cast(audio_x.shape()[1]); print_sd_tensor(audio_x, false, "ltxav_audio_x"); } sd::Tensor audio_timesteps; if (!audio_timesteps_path.empty()) { audio_timesteps = sd::load_tensor_from_file_as_tensor(audio_timesteps_path); GGML_ASSERT(!audio_timesteps.empty()); } else if (!audio_x.empty()) { audio_timesteps = timesteps; } if (!audio_timesteps.empty()) { print_sd_tensor(audio_timesteps, false, "ltxav_audio_timesteps"); } int64_t t0 = ggml_time_ms(); auto out_opt = compute(8, x, timesteps, context, audio_x, audio_timesteps, audio_length); int64_t t1 = ggml_time_ms(); GGML_ASSERT(!out_opt.empty()); print_sd_tensor(out_opt, false, "ltxav_out"); LOG_DEBUG("ltxav test done in %lldms", t1 - t0); } static void load_from_file_and_test(const std::string& model_path, const std::string& x_path, const std::string& timesteps_path = "", const std::string& context_path = "", const std::string& embeddings_path = "", const std::string& audio_x_path = "", const std::string& audio_timesteps_path = "") { // ggml_backend_t backend = ggml_backend_cuda_init(0); ggml_backend_t backend = sd_backend_cpu_init(); LOG_INFO("loading ltxav from '%s'", model_path.c_str()); ModelLoader model_loader; if (!model_loader.init_from_file_and_convert_name(model_path, "model.diffusion_model.")) { LOG_ERROR("init model loader from file failed: '%s'", model_path.c_str()); return; } if (!embeddings_path.empty()) { LOG_INFO("loading ltxav embeddings from '%s'", embeddings_path.c_str()); if (!model_loader.init_from_file(embeddings_path)) { LOG_ERROR("init embeddings model loader from file failed: '%s'", embeddings_path.c_str()); return; } } auto& tensor_storage_map = model_loader.get_tensor_storage_map(); std::shared_ptr ltxav = std::make_shared(backend, backend, tensor_storage_map, "model.diffusion_model"); if (!ltxav->alloc_params_buffer()) { LOG_ERROR("ltxav buffer allocation failed"); return; } std::map tensors; ltxav->get_param_tensors(tensors, "model.diffusion_model"); if (!model_loader.load_tensors(tensors)) { LOG_ERROR("load tensors from model loader failed"); return; } LOG_INFO("ltxav model loaded"); ltxav->test(x_path, timesteps_path, context_path, audio_x_path, audio_timesteps_path); } }; }; // namespace LTXV #endif // __SD_MODEL_DIFFUSION_LTXV_HPP__