refactor: unify model config detection (#1613)

This commit is contained in:
leejet 2026-06-07 01:05:12 +08:00 committed by GitHub
parent b9254dda0d
commit cfbc19d186
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 2124 additions and 1637 deletions

View File

@ -44,6 +44,8 @@ Naming conventions:
Some older code in the project may not fully follow the current conventions. Please do not submit PRs that only rewrite existing code to match style rules.
When adding or modifying model implementations, follow the model config and weight detection conventions in [docs/model_config.md](docs/model_config.md).
## AI-Assisted Contributions
AI tools may be used to assist development, but contributors are responsible for the quality and correctness of the submitted code.

118
docs/model_config.md Normal file
View File

@ -0,0 +1,118 @@
# Model Configuration Conventions
This document describes the conventions for model configuration structs and
weight-based configuration detection.
## Config Types
Model configuration should live in a model-specific `*Config` struct.
Examples:
- `ZImageConfig`
- `UNetConfig`
- `MMDiTConfig`
- `LLMConfig`
Preserve established acronym casing in type names, such as `UNet`, `MMDiT`,
`LLM`, `VAE`, and `T5`.
Place the config struct near the top of the model header, before the main model
blocks and runner types that consume it.
## Config Variables
Variables and members that hold a config should be named `config`.
Examples:
```cpp
UNetConfig config;
UnetModelBlock unet;
MMDiTRunner(...)
: DiffusionModelRunner(backend, params_backend, prefix),
config(MMDiTConfig::detect_from_weights(tensor_storage_map, prefix)),
mmdit(config) {
}
```
Avoid alternate names such as `params`, `params_cfg`, `model_params`, or
model-specific aliases unless an existing public API requires them.
## Weight Detection
If a model can derive configuration from loaded weight metadata, expose that
logic as a static method on the config type:
```cpp
static XxxConfig detect_from_weights(const String2TensorStorage& tensor_storage_map,
const std::string& prefix);
```
Additional selector arguments are allowed when required by an existing model
family, for example `SDVersion version` or an architecture enum:
```cpp
static UNetConfig detect_from_weights(const String2TensorStorage& tensor_storage_map,
const std::string& prefix,
SDVersion version = VERSION_SD1);
```
Use `TensorStorage` metadata, especially `n_dims` and `ne`, to infer shapes.
Do not load or parse tensor data for config detection.
Detection should respect `prefix`. For nested weights, construct full names from
`prefix + "." + suffix` (omitting the leading `"."` when `prefix` is empty) or
filter entries with `starts_with(name, prefix)`.
Do not add persistent config fields such as `inferred_from_weights` only to
record whether detection happened. If the function needs to decide whether to
print a debug line, keep that as local control flow inside `detect_from_weights`.
## Logging
When config values are inferred from weights, print one `LOG_DEBUG` line at the
end of `detect_from_weights`.
Example:
```cpp
LOG_DEBUG("llm: num_layers = %" PRId64 ", vocab_size = %" PRId64 ", hidden_size = %" PRId64 ", intermediate_size = %" PRId64,
config.num_layers,
config.vocab_size,
config.hidden_size,
config.intermediate_size);
```
Only print the config detection log when the function actually inferred values
from weights. Do not duplicate the same config summary in runner constructors or
model loading code.
Use the correct format specifiers for field types, such as `%" PRId64 "` for
`int64_t` and `%d` for `int`.
## Runner And Model Responsibilities
Runners should detect the config once and pass it into the model block:
```cpp
struct XxxRunner : public DiffusionModelRunner {
XxxConfig config;
XxxModel model;
XxxRunner(..., const String2TensorStorage& tensor_storage_map, const std::string prefix)
: DiffusionModelRunner(backend, params_backend, prefix),
config(XxxConfig::detect_from_weights(tensor_storage_map, prefix)),
model(config) {
model.init(params_ctx, tensor_storage_map, prefix);
}
};
```
Model blocks should consume `config` directly instead of re-scanning weights in
their constructors. Keep config-derived behavior centralized in the config
struct.
If a model has no weight-derived config today, it may still provide
`detect_from_weights` for API consistency, but it should not print a config
detection log unless it actually derives values from weights.

View File

@ -1,6 +1,7 @@
#ifndef __ANIMA_HPP__
#define __ANIMA_HPP__
#include <algorithm>
#include <cmath>
#include <memory>
#include <utility>
@ -14,6 +15,47 @@
namespace Anima {
constexpr int ANIMA_GRAPH_SIZE = 65536;
// Configuration consumed by AnimaNet and AnimaRunner. Every field has a
// built-in default; only `num_layers` is ever overridden from weight metadata.
struct AnimaConfig {
    int64_t in_channels       = 16;
    int64_t out_channels      = 16;
    int64_t hidden_size       = 2048;
    int64_t text_embed_dim    = 1024;
    int64_t num_heads         = 16;
    int64_t head_dim          = 128;
    int patch_size            = 2;
    int64_t num_layers        = 28;
    std::vector<int> axes_dim = {44, 42, 42};
    int theta                 = 10000;

    // Builds a config from tensor names only (no tensor data is read).
    //
    // tensor_storage_map: tensor name -> storage metadata; only the names are inspected.
    // prefix:             name prefix of the network; "<prefix>.blocks.<N>." entries are
    //                     counted (just "blocks.<N>." when prefix is empty).
    //
    // Returns the defaults with `num_layers` set to the highest block index + 1
    // when at least one such entry exists; otherwise the untouched defaults.
    static AnimaConfig detect_from_weights(const String2TensorStorage& tensor_storage_map, const std::string& prefix) {
        AnimaConfig config;

        const std::string layer_tag = prefix.empty() ? std::string("blocks.") : prefix + ".blocks.";

        int64_t layer_count = 0;
        for (const auto& entry : tensor_storage_map) {
            const std::string& tensor_name = entry.first;

            const size_t tag_pos = tensor_name.find(layer_tag);
            if (tag_pos == std::string::npos) {
                continue;
            }

            // The block index is the text between the tag and the next '.'.
            const size_t index_begin = tag_pos + layer_tag.size();
            const size_t index_end   = tensor_name.find('.', index_begin);
            if (index_end == std::string::npos) {
                continue;
            }

            const std::string index_text = tensor_name.substr(index_begin, index_end - index_begin);
            const int64_t candidate      = atoll(index_text.c_str()) + 1;
            if (candidate > layer_count) {
                layer_count = candidate;
            }
        }

        // Nothing matched: keep the defaults and stay silent.
        if (layer_count <= 0) {
            return config;
        }

        config.num_layers = layer_count;
        LOG_DEBUG("anima: num_layers = %" PRId64 ", hidden_size = %" PRId64 ", num_heads = %" PRId64 ", head_dim = %" PRId64,
                  config.num_layers,
                  config.hidden_size,
                  config.num_heads,
                  config.head_dim);
        return config;
    }
};
__STATIC_INLINE__ ggml_tensor* apply_gate(ggml_context* ctx,
ggml_tensor* x,
ggml_tensor* gate) {
@ -418,31 +460,22 @@ namespace Anima {
struct AnimaNet : public GGMLBlock {
public:
int64_t in_channels = 16;
int64_t out_channels = 16;
int64_t hidden_size = 2048;
int64_t text_embed_dim = 1024;
int64_t num_heads = 16;
int64_t head_dim = 128;
int patch_size = 2;
int64_t num_layers = 28;
std::vector<int> axes_dim = {44, 42, 42};
int theta = 10000;
AnimaConfig config;
public:
AnimaNet() = default;
explicit AnimaNet(int64_t num_layers)
: num_layers(num_layers) {
blocks["x_embedder"] = std::make_shared<XEmbedder>((in_channels + 1) * patch_size * patch_size, hidden_size);
blocks["t_embedder"] = std::make_shared<TimestepEmbedder>(hidden_size, hidden_size * 3);
blocks["t_embedding_norm"] = std::make_shared<RMSNorm>(hidden_size, 1e-6f);
for (int i = 0; i < num_layers; i++) {
blocks["blocks." + std::to_string(i)] = std::make_shared<TransformerBlock>(hidden_size,
text_embed_dim,
num_heads,
head_dim);
explicit AnimaNet(AnimaConfig config)
: config(config) {
blocks["x_embedder"] = std::make_shared<XEmbedder>((config.in_channels + 1) * config.patch_size * config.patch_size, config.hidden_size);
blocks["t_embedder"] = std::make_shared<TimestepEmbedder>(config.hidden_size, config.hidden_size * 3);
blocks["t_embedding_norm"] = std::make_shared<RMSNorm>(config.hidden_size, 1e-6f);
for (int i = 0; i < config.num_layers; i++) {
blocks["blocks." + std::to_string(i)] = std::make_shared<TransformerBlock>(config.hidden_size,
config.text_embed_dim,
config.num_heads,
config.head_dim);
}
blocks["final_layer"] = std::make_shared<FinalLayer>(hidden_size, patch_size, out_channels);
blocks["final_layer"] = std::make_shared<FinalLayer>(config.hidden_size, config.patch_size, config.out_channels);
blocks["llm_adapter"] = std::make_shared<LLMAdapter>(1024, 1024, 1024, 6, 16);
}
@ -469,11 +502,11 @@ namespace Anima {
auto padding_mask = ggml_ext_zeros(ctx->ggml_ctx, x->ne[0], x->ne[1], 1, x->ne[3]);
x = ggml_concat(ctx->ggml_ctx, x, padding_mask, 2); // [N, C + 1, H, W]
x = DiT::pad_and_patchify(ctx, x, patch_size, patch_size); // [N, h*w, (C+1)*ph*pw]
x = DiT::pad_and_patchify(ctx, x, config.patch_size, config.patch_size); // [N, h*w, (C+1)*ph*pw]
x = x_embedder->forward(ctx, x);
auto timestep_proj = ggml_ext_timestep_embedding(ctx->ggml_ctx, timestep, static_cast<int>(hidden_size));
auto timestep_proj = ggml_ext_timestep_embedding(ctx->ggml_ctx, timestep, static_cast<int>(config.hidden_size));
auto temb = t_embedder->forward(ctx, timestep_proj);
auto embedded_timestep = t_embedding_norm->forward(ctx, timestep_proj);
@ -505,7 +538,7 @@ namespace Anima {
sd::ggml_graph_cut::mark_graph_cut(temb, "anima.prelude", "temb");
sd::ggml_graph_cut::mark_graph_cut(encoder_hidden_states, "anima.prelude", "context");
for (int i = 0; i < num_layers; i++) {
for (int i = 0; i < config.num_layers; i++) {
auto block = std::dynamic_pointer_cast<TransformerBlock>(blocks["blocks." + std::to_string(i)]);
x = block->forward(ctx, x, encoder_hidden_states, embedded_timestep, temb, image_pe);
sd::ggml_graph_cut::mark_graph_cut(x, "anima.blocks." + std::to_string(i), "x");
@ -513,7 +546,7 @@ namespace Anima {
x = final_layer->forward(ctx, x, embedded_timestep, temb); // [N, h*w, ph*pw*C]
x = DiT::unpatchify_and_crop(ctx->ggml_ctx, x, H, W, patch_size, patch_size, false); // [N, C, H, W]
x = DiT::unpatchify_and_crop(ctx->ggml_ctx, x, H, W, config.patch_size, config.patch_size, false); // [N, C, H, W]
return x;
}
@ -524,35 +557,16 @@ namespace Anima {
std::vector<float> image_pe_vec;
std::vector<float> adapter_q_pe_vec;
std::vector<float> adapter_k_pe_vec;
AnimaConfig config;
AnimaNet net;
AnimaRunner(ggml_backend_t backend,
ggml_backend_t params_backend,
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "model.diffusion_model")
: DiffusionModelRunner(backend, params_backend, prefix) {
int64_t num_layers = 0;
std::string layer_tag = prefix + ".net.blocks.";
for (const auto& kv : tensor_storage_map) {
const std::string& tensor_name = kv.first;
size_t pos = tensor_name.find(layer_tag);
if (pos == std::string::npos) {
continue;
}
size_t start = pos + layer_tag.size();
size_t end = tensor_name.find('.', start);
if (end == std::string::npos) {
continue;
}
int64_t layer_id = atoll(tensor_name.substr(start, end - start).c_str());
num_layers = std::max(num_layers, layer_id + 1);
}
if (num_layers <= 0) {
num_layers = 28;
}
LOG_INFO("anima net layers: %" PRId64, num_layers);
net = AnimaNet(num_layers);
: DiffusionModelRunner(backend, params_backend, prefix),
config(AnimaConfig::detect_from_weights(tensor_storage_map, prefix + ".net")) {
net = AnimaNet(config);
net.init(params_ctx, tensor_storage_map, prefix + ".net");
}
@ -623,22 +637,22 @@ namespace Anima {
GGML_ASSERT(x->ne[3] == 1);
ggml_cgraph* gf = new_graph_custom(ANIMA_GRAPH_SIZE);
int64_t pad_h = (net.patch_size - x->ne[1] % net.patch_size) % net.patch_size;
int64_t pad_w = (net.patch_size - x->ne[0] % net.patch_size) % net.patch_size;
int64_t pad_h = (config.patch_size - x->ne[1] % config.patch_size) % config.patch_size;
int64_t pad_w = (config.patch_size - x->ne[0] % config.patch_size) % config.patch_size;
int64_t h_pad = x->ne[1] + pad_h;
int64_t w_pad = x->ne[0] + pad_w;
image_pe_vec = gen_anima_image_pe_vec(1,
static_cast<int>(h_pad),
static_cast<int>(w_pad),
static_cast<int>(net.patch_size),
net.theta,
net.axes_dim,
static_cast<int>(config.patch_size),
config.theta,
config.axes_dim,
4.0f,
4.0f,
1.0f);
int64_t image_pos_len = static_cast<int64_t>(image_pe_vec.size()) / (2 * 2 * (net.head_dim / 2));
auto image_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, net.head_dim / 2, image_pos_len);
int64_t image_pos_len = static_cast<int64_t>(image_pe_vec.size()) / (2 * 2 * (config.head_dim / 2));
auto image_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.head_dim / 2, image_pos_len);
set_backend_tensor_data(image_pe, image_pe_vec.data());
ggml_tensor* adapter_q_pe = nullptr;

View File

@ -1971,7 +1971,7 @@ struct LLMEmbedder : public Conditioner {
for (int i = 0; i < conditioner_params.ref_images->size(); i++) {
const auto& image = (*conditioner_params.ref_images)[i];
double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size;
double factor = llm->config.vision.patch_size * llm->config.vision.spatial_merge_size;
int height = static_cast<int>(image.shape()[1]);
int width = static_cast<int>(image.shape()[0]);
int h_bar = static_cast<int>(std::round(height / factor) * factor);
@ -2042,7 +2042,7 @@ struct LLMEmbedder : public Conditioner {
for (int i = 0; i < conditioner_params.ref_images->size(); i++) {
const auto& image = (*conditioner_params.ref_images)[i];
double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size;
double factor = llm->config.vision.patch_size * llm->config.vision.spatial_merge_size;
int height = static_cast<int>(image.shape()[1]);
int width = static_cast<int>(image.shape()[0]);
int h_bar = static_cast<int>(std::round(height / factor) * factor);

View File

@ -13,6 +13,76 @@
namespace ErnieImage {
constexpr int ERNIE_IMAGE_GRAPH_SIZE = 40960;
// Configuration consumed by ErnieImageModel and ErnieImageRunner.
// Defaults are used for anything that cannot be inferred from weight metadata.
struct ErnieImageConfig {
    int64_t hidden_size       = 4096;
    int64_t num_heads         = 32;
    int64_t num_layers        = 36;
    int64_t ffn_hidden_size   = 12288;
    int64_t in_channels       = 128;
    int64_t out_channels      = 128;
    int patch_size            = 1;
    int64_t text_in_dim       = 3072;
    int theta                 = 256;
    std::vector<int> axes_dim = {32, 48, 48};
    int axes_dim_sum          = 128;  // always recomputed from axes_dim by detect_from_weights
    float eps                 = 1e-6f;

    // Infers the config from tensor shapes/names under `prefix`; no tensor data is read.
    //
    // tensor_storage_map: tensor name -> storage metadata (`n_dims`, `ne`).
    // prefix:             only tensors whose name starts with this prefix are considered.
    //
    // Values that depend on more than one tensor (num_heads, out_channels) are
    // resolved after the scan so the result does not depend on iteration order.
    // A single summary LOG_DEBUG line is printed only if something was inferred.
    static ErnieImageConfig detect_from_weights(const String2TensorStorage& tensor_storage_map, const std::string& prefix) {
        ErnieImageConfig config;
        config.num_layers = 0;

        bool inferred             = false;
        int64_t detected_head_dim = 0;
        int64_t detected_out_dim  = 0;

        for (const auto& [name, tensor_storage] : tensor_storage_map) {
            if (!starts_with(name, prefix)) {
                continue;
            }

            if (ends_with(name, "x_embedder.proj.weight") && tensor_storage.n_dims == 4) {
                config.patch_size  = static_cast<int>(tensor_storage.ne[0]);
                config.in_channels = tensor_storage.ne[2];
                config.hidden_size = tensor_storage.ne[3];
                inferred           = true;
            } else if (ends_with(name, "text_proj.weight") && tensor_storage.n_dims == 2) {
                config.text_in_dim = tensor_storage.ne[0];
                inferred           = true;
            } else if (ends_with(name, "layers.0.self_attention.norm_q.weight")) {
                detected_head_dim = tensor_storage.ne[0];
                inferred          = true;
            } else if (ends_with(name, "layers.0.mlp.gate_proj.weight") && tensor_storage.n_dims == 2) {
                config.ffn_hidden_size = tensor_storage.ne[1];
                inferred               = true;
            } else if (ends_with(name, "final_linear.weight") && tensor_storage.n_dims == 2) {
                // Only remember the raw size here: patch_size may not have been
                // detected yet (x_embedder can be visited after final_linear).
                detected_out_dim = tensor_storage.ne[1];
                inferred         = true;
            }

            // Track the highest "layers.<N>" index seen to derive the layer count.
            size_t pos = name.find("layers.");
            if (pos != std::string::npos) {
                auto items = split_string(name.substr(pos), '.');
                if (items.size() > 1) {
                    int block_index = atoi(items[1].c_str());
                    if (block_index + 1 > config.num_layers) {
                        config.num_layers = block_index + 1;
                        inferred          = true;
                    }
                }
            }
        }

        if (config.num_layers == 0) {
            config.num_layers = 36;
        }
        if (detected_head_dim > 0) {
            config.num_heads = config.hidden_size / detected_head_dim;
        }
        if (detected_out_dim > 0) {
            // final_linear emits patch_size * patch_size * out_channels features.
            const int64_t patch_area = static_cast<int64_t>(config.patch_size) * config.patch_size;
            if (patch_area > 0) {
                config.out_channels = detected_out_dim / patch_area;
            }
        }

        config.axes_dim_sum = 0;
        for (int axis_dim : config.axes_dim) {
            config.axes_dim_sum += axis_dim;
        }

        if (inferred) {
            LOG_DEBUG("ernie_image: num_layers = %" PRId64 ", hidden_size = %" PRId64 ", num_heads = %" PRId64 ", ffn_hidden_size = %" PRId64 ", in_channels = %" PRId64 ", out_channels = %" PRId64,
                      config.num_layers,
                      config.hidden_size,
                      config.num_heads,
                      config.ffn_hidden_size,
                      config.in_channels,
                      config.out_channels);
        }
        return config;
    }
};
__STATIC_INLINE__ ggml_tensor* timestep_embedding_sin_cos(ggml_context* ctx,
ggml_tensor* timesteps,
int dim,
@ -208,51 +278,36 @@ namespace ErnieImage {
}
};
// Pre-refactor parameter bundle for ErnieImageModel. Its fields and defaults
// match ErnieImageConfig one-for-one; the surrounding hunks switch the model
// and runner over to ErnieImageConfig.
struct ErnieImageParams {
int64_t hidden_size = 4096;
int64_t num_heads = 32;
int64_t num_layers = 36;
int64_t ffn_hidden_size = 12288;
int64_t in_channels = 128;
int64_t out_channels = 128;
int patch_size = 1;
int64_t text_in_dim = 3072;
int theta = 256;
std::vector<int> axes_dim = {32, 48, 48};
// Sum of axes_dim; the runner recomputes it rather than trusting this default.
int axes_dim_sum = 128;
float eps = 1e-6f;
};
class ErnieImageModel : public GGMLBlock {
public:
ErnieImageParams params;
ErnieImageConfig config;
ErnieImageModel() = default;
ErnieImageModel(ErnieImageParams params)
: params(params) {
blocks["x_embedder.proj"] = std::make_shared<Conv2d>(params.in_channels,
params.hidden_size,
std::pair<int, int>{params.patch_size, params.patch_size},
std::pair<int, int>{params.patch_size, params.patch_size},
ErnieImageModel(ErnieImageConfig config)
: config(config) {
blocks["x_embedder.proj"] = std::make_shared<Conv2d>(config.in_channels,
config.hidden_size,
std::pair<int, int>{config.patch_size, config.patch_size},
std::pair<int, int>{config.patch_size, config.patch_size},
std::pair<int, int>{0, 0},
std::pair<int, int>{1, 1},
true);
if (params.text_in_dim != params.hidden_size) {
blocks["text_proj"] = std::make_shared<Linear>(params.text_in_dim, params.hidden_size, false);
if (config.text_in_dim != config.hidden_size) {
blocks["text_proj"] = std::make_shared<Linear>(config.text_in_dim, config.hidden_size, false);
}
blocks["time_embedding"] = std::make_shared<Qwen::TimestepEmbedding>(params.hidden_size, params.hidden_size);
blocks["adaLN_modulation.1"] = std::make_shared<Linear>(params.hidden_size, 6 * params.hidden_size, true);
blocks["time_embedding"] = std::make_shared<Qwen::TimestepEmbedding>(config.hidden_size, config.hidden_size);
blocks["adaLN_modulation.1"] = std::make_shared<Linear>(config.hidden_size, 6 * config.hidden_size, true);
for (int i = 0; i < params.num_layers; i++) {
blocks["layers." + std::to_string(i)] = std::make_shared<ErnieImageSharedAdaLNBlock>(params.hidden_size,
params.num_heads,
params.ffn_hidden_size,
params.eps);
for (int i = 0; i < config.num_layers; i++) {
blocks["layers." + std::to_string(i)] = std::make_shared<ErnieImageSharedAdaLNBlock>(config.hidden_size,
config.num_heads,
config.ffn_hidden_size,
config.eps);
}
blocks["final_norm"] = std::make_shared<ErnieImageAdaLNContinuous>(params.hidden_size, params.eps);
blocks["final_linear"] = std::make_shared<Linear>(params.hidden_size,
params.patch_size * params.patch_size * params.out_channels,
blocks["final_norm"] = std::make_shared<ErnieImageAdaLNContinuous>(config.hidden_size, config.eps);
blocks["final_linear"] = std::make_shared<Linear>(config.hidden_size,
config.patch_size * config.patch_size * config.out_channels,
true);
}
@ -265,12 +320,12 @@ namespace ErnieImage {
// context: [N, text_tokens, 3072]
// pe: [image_tokens + text_tokens, head_dim/2, 2, 2]
GGML_ASSERT(context != nullptr);
GGML_ASSERT(x->ne[1] % params.patch_size == 0 && x->ne[0] % params.patch_size == 0);
GGML_ASSERT(x->ne[1] % config.patch_size == 0 && x->ne[0] % config.patch_size == 0);
int64_t W = x->ne[0];
int64_t H = x->ne[1];
int64_t Hp = H / params.patch_size;
int64_t Wp = W / params.patch_size;
int64_t Hp = H / config.patch_size;
int64_t Wp = W / config.patch_size;
int64_t n_img = Hp * Wp;
int64_t N = x->ne[3];
@ -292,7 +347,7 @@ namespace ErnieImage {
auto hidden_states = ggml_concat(ctx->ggml_ctx, img, txt, 1); // [N, image_tokens + text_tokens, hidden_size]
auto sample = timestep_embedding_sin_cos(ctx->ggml_ctx, timestep, static_cast<int>(params.hidden_size));
auto sample = timestep_embedding_sin_cos(ctx->ggml_ctx, timestep, static_cast<int>(config.hidden_size));
auto c = time_embedding->forward(ctx, sample); // [N, hidden_size]
auto mod_params = adaLN_mod->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 6 * hidden_size]
@ -305,7 +360,7 @@ namespace ErnieImage {
temb.push_back(ggml_reshape_3d(ctx->ggml_ctx, chunk, chunk->ne[0], 1, chunk->ne[1])); // [N, 1, hidden_size]
}
for (int i = 0; i < params.num_layers; i++) {
for (int i = 0; i < config.num_layers; i++) {
auto layer = std::dynamic_pointer_cast<ErnieImageSharedAdaLNBlock>(blocks["layers." + std::to_string(i)]);
hidden_states = layer->forward(ctx, hidden_states, pe, temb);
sd::ggml_graph_cut::mark_graph_cut(hidden_states, "ernie_image.layers." + std::to_string(i), "hidden_states");
@ -319,15 +374,15 @@ namespace ErnieImage {
patches,
Hp,
Wp,
params.patch_size,
params.patch_size,
config.patch_size,
config.patch_size,
false); // [N, out_channels, H, W]
return out;
}
};
struct ErnieImageRunner : public DiffusionModelRunner {
ErnieImageParams ernie_params;
ErnieImageConfig config;
ErnieImageModel ernie_image;
std::vector<float> pe_vec;
@ -335,58 +390,9 @@ namespace ErnieImage {
ggml_backend_t params_backend,
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "")
: DiffusionModelRunner(backend, params_backend, prefix) {
ernie_params.num_layers = 0;
for (const auto& [name, tensor_storage] : tensor_storage_map) {
if (!starts_with(name, prefix)) {
continue;
}
if (ends_with(name, "x_embedder.proj.weight") && tensor_storage.n_dims == 4) {
ernie_params.patch_size = static_cast<int>(tensor_storage.ne[0]);
ernie_params.in_channels = tensor_storage.ne[2];
ernie_params.hidden_size = tensor_storage.ne[3];
} else if (ends_with(name, "text_proj.weight") && tensor_storage.n_dims == 2) {
ernie_params.text_in_dim = tensor_storage.ne[0];
} else if (ends_with(name, "layers.0.self_attention.norm_q.weight")) {
int64_t head_dim = tensor_storage.ne[0];
ernie_params.num_heads = ernie_params.hidden_size / head_dim;
} else if (ends_with(name, "layers.0.mlp.gate_proj.weight") && tensor_storage.n_dims == 2) {
ernie_params.ffn_hidden_size = tensor_storage.ne[1];
} else if (ends_with(name, "final_linear.weight") && tensor_storage.n_dims == 2) {
int64_t out_dim = tensor_storage.ne[1];
ernie_params.out_channels = out_dim / ernie_params.patch_size / ernie_params.patch_size;
}
size_t pos = name.find("layers.");
if (pos != std::string::npos) {
std::string layer_name = name.substr(pos);
auto items = split_string(layer_name, '.');
if (items.size() > 1) {
int block_index = atoi(items[1].c_str());
if (block_index + 1 > ernie_params.num_layers) {
ernie_params.num_layers = block_index + 1;
}
}
}
}
if (ernie_params.num_layers == 0) {
ernie_params.num_layers = 36;
}
ernie_params.axes_dim_sum = 0;
for (int axis_dim : ernie_params.axes_dim) {
ernie_params.axes_dim_sum += axis_dim;
}
LOG_INFO("ernie_image: layers = %" PRId64 ", hidden_size = %" PRId64 ", heads = %" PRId64
", ffn_hidden_size = %" PRId64 ", in_channels = %" PRId64 ", out_channels = %" PRId64,
ernie_params.num_layers,
ernie_params.hidden_size,
ernie_params.num_heads,
ernie_params.ffn_hidden_size,
ernie_params.in_channels,
ernie_params.out_channels);
ernie_image = ErnieImageModel(ernie_params);
: DiffusionModelRunner(backend, params_backend, prefix),
config(ErnieImageConfig::detect_from_weights(tensor_storage_map, prefix)) {
ernie_image = ErnieImageModel(config);
ernie_image.init(params_ctx, tensor_storage_map, prefix);
}
@ -410,15 +416,15 @@ namespace ErnieImage {
pe_vec = Rope::gen_ernie_image_pe(static_cast<int>(x->ne[1]),
static_cast<int>(x->ne[0]),
ernie_params.patch_size,
config.patch_size,
static_cast<int>(x->ne[3]),
static_cast<int>(context->ne[1]),
ernie_params.theta,
config.theta,
circular_y_enabled,
circular_x_enabled,
ernie_params.axes_dim);
int pos_len = static_cast<int>(pe_vec.size() / ernie_params.axes_dim_sum / 2);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, ernie_params.axes_dim_sum, 1, pos_len, 2);
config.axes_dim);
int pos_len = static_cast<int>(pe_vec.size() / config.axes_dim_sum / 2);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, config.axes_dim_sum, 1, pos_len, 2);
set_backend_tensor_data(pe, pe_vec.data());
auto runner_ctx = get_context();

View File

@ -13,6 +13,155 @@
namespace Flux {
// Extra settings read by the Flux block only when version == VERSION_CHROMA_RADIANCE.
struct ChromaRadianceConfig {
// Passed to the NerfEmbedder, NerfGLUBlock and NerfFinalLayerConv constructors.
int64_t nerf_hidden_size = 64;
// Passed to each NerfGLUBlock.
int nerf_mlp_ratio = 4;
// Number of "nerf_blocks.<i>" entries created.
int nerf_depth = 4;
// Passed to the NerfEmbedder.
int nerf_max_freqs = 8;
// Set by FluxConfig::detect_from_weights when a tensor name contains "__x0__".
bool use_x0 = false;
// Set by FluxConfig::detect_from_weights when the img_in_patch weight's ne[0]
// is half of patch_size; the Flux constructor then halves the conv kernel/stride.
bool fake_patch_size_x2 = false;
};
// Configuration consumed by the Flux block. Defaults describe the base
// VERSION_FLUX layout; detect_from_weights applies per-version presets and
// then refines them from weight metadata.
struct FluxConfig {
    SDVersion version         = VERSION_FLUX;
    bool is_chroma            = false;
    int patch_size            = 2;
    int64_t in_channels       = 64;
    int64_t out_channels      = 64;
    int64_t vec_in_dim        = 768;
    int64_t context_in_dim    = 4096;
    int64_t hidden_size       = 3072;
    float mlp_ratio           = 4.0f;
    int num_heads             = 24;
    int depth                 = 19;
    int depth_single_blocks   = 38;
    std::vector<int> axes_dim = {16, 56, 56};
    int axes_dim_sum          = 128;  // always recomputed from axes_dim by detect_from_weights
    int theta                 = 10000;
    bool qkv_bias             = true;
    bool guidance_embed       = true;
    int64_t in_dim            = 64;
    bool disable_bias         = false;
    bool share_modulation     = false;
    bool semantic_txt_norm    = false;
    bool use_yak_mlp          = false;
    bool use_mlp_silu_act     = false;
    float ref_index_scale     = 1.f;
    ChromaRadianceConfig chroma_radiance_params;

    // Builds the config for `version` and refines it from tensor names/shapes
    // under `prefix`; no tensor data is read.
    //
    // tensor_storage_map: tensor name -> storage metadata (`ne` is inspected).
    // prefix:             only tensors whose name starts with this prefix are considered.
    // version:            selects the static per-family presets applied before the scan.
    //
    // Note: depth, depth_single_blocks and guidance_embed start from 0/false and
    // are only raised by what the scan finds. A single summary LOG_DEBUG line is
    // printed only if the scan inferred something.
    static FluxConfig detect_from_weights(const String2TensorStorage& tensor_storage_map,
                                          const std::string& prefix,
                                          SDVersion version = VERSION_FLUX) {
        FluxConfig config;
        config.version             = version;
        config.guidance_embed      = false;
        config.depth               = 0;
        config.depth_single_blocks = 0;

        // Static presets per model family.
        if (version == VERSION_FLUX_FILL) {
            config.in_channels = 384;
        } else if (version == VERSION_FLUX_CONTROLS) {
            config.in_channels = 128;
        } else if (version == VERSION_FLEX_2) {
            config.in_channels = 196;
        } else if (version == VERSION_CHROMA_RADIANCE) {
            config.in_channels = 3;
            config.patch_size  = 16;
        } else if (version == VERSION_OVIS_IMAGE) {
            config.semantic_txt_norm = true;
            config.use_yak_mlp       = true;
            config.vec_in_dim        = 0;
        } else if (sd_version_is_flux2(version)) {
            config.in_channels      = 128;
            config.patch_size       = 1;
            config.out_channels     = 128;
            config.mlp_ratio        = 3.f;
            config.theta            = 2000;
            config.axes_dim         = {32, 32, 32, 32};
            config.vec_in_dim       = 0;
            config.qkv_bias         = false;
            config.disable_bias     = true;
            config.share_modulation = true;
            config.ref_index_scale  = 10.f;
            config.use_mlp_silu_act = true;
        } else if (sd_version_is_longcat(version)) {
            config.context_in_dim = 3584;
            config.vec_in_dim     = 0;
        }

        // Returns the integer that follows `tag` in `name` (text up to the next
        // '.'), or -1 when `tag` does not occur. Non-numeric text parses as 0,
        // matching atoi.
        const auto block_index_after = [](const std::string& name, const std::string& tag) -> int {
            const size_t tag_pos = name.find(tag);
            if (tag_pos == std::string::npos) {
                return -1;
            }
            const size_t index_begin = tag_pos + tag.size();
            const size_t index_end   = name.find('.', index_begin);
            // substr takes a character count, not an end position.
            const size_t count = (index_end == std::string::npos) ? std::string::npos : index_end - index_begin;
            return atoi(name.substr(index_begin, count).c_str());
        };
        const std::string double_tag = "double_blocks.";
        const std::string single_tag = "single_blocks.";

        bool inferred                      = false;
        int64_t head_dim                   = 0;
        int64_t actual_radiance_patch_size = -1;

        for (const auto& [name, tensor_storage] : tensor_storage_map) {
            if (!starts_with(name, prefix)) {
                continue;
            }

            if (name.find("guidance_in.in_layer.weight") != std::string::npos) {
                config.guidance_embed = true;
                inferred              = true;
            }
            if (name.find("__x0__") != std::string::npos) {
                LOG_DEBUG("using x0 prediction");
                config.chroma_radiance_params.use_x0 = true;
                inferred                             = true;
            }
            if (name.find("__32x32__") != std::string::npos) {
                LOG_DEBUG("using patch size 32");
                config.patch_size = 32;
                inferred          = true;
            }
            if (name.find("img_in_patch.weight") != std::string::npos) {
                actual_radiance_patch_size = tensor_storage.ne[0];
                LOG_DEBUG("actual radiance patch size: %" PRId64, actual_radiance_patch_size);
                inferred = true;
            }
            if (name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
                config.is_chroma = true;
                inferred         = true;
            }

            // Block counts are the highest index seen + 1.
            const int double_index = block_index_after(name, double_tag);
            if (double_index >= 0) {
                inferred = true;
                if (double_index + 1 > config.depth) {
                    config.depth = double_index + 1;
                }
            }
            const int single_index = block_index_after(name, single_tag);
            if (single_index >= 0) {
                inferred = true;
                if (single_index + 1 > config.depth_single_blocks) {
                    config.depth_single_blocks = single_index + 1;
                }
            }

            if (ends_with(name, "txt_in.weight")) {
                config.context_in_dim = tensor_storage.ne[0];
                config.hidden_size    = tensor_storage.ne[1];
                inferred              = true;
            }
            if (ends_with(name, "single_blocks.0.norm.key_norm.scale")) {
                head_dim = tensor_storage.ne[0];
                inferred = true;
            }
            if (ends_with(name, "double_blocks.0.txt_attn.norm.key_norm.scale")) {
                head_dim = tensor_storage.ne[0];
                inferred = true;
            }
        }

        // Resolved after the scan so the "__32x32__" marker and the conv weight
        // can be visited in either order.
        if (actual_radiance_patch_size > 0 && actual_radiance_patch_size != config.patch_size) {
            GGML_ASSERT(config.patch_size == 2 * actual_radiance_patch_size);
            LOG_DEBUG("using fake x2 patch size");
            config.chroma_radiance_params.fake_patch_size_x2 = true;
        }
        if (head_dim > 0) {
            config.num_heads = static_cast<int>(config.hidden_size / head_dim);
        }

        config.axes_dim_sum = 0;
        for (int axis_dim : config.axes_dim) {
            config.axes_dim_sum += axis_dim;
        }

        if (inferred) {
            LOG_DEBUG("flux: depth = %d, depth_single_blocks = %d, guidance_embed = %s, context_in_dim = %" PRId64 ", hidden_size = %" PRId64 ", num_heads = %d",
                      config.depth,
                      config.depth_single_blocks,
                      config.guidance_embed ? "true" : "false",
                      config.context_in_dim,
                      config.hidden_size,
                      config.num_heads);
        }
        return config;
    }
};
struct MLPEmbedder : public UnaryBlock {
public:
MLPEmbedder(int64_t in_dim, int64_t hidden_dim, bool bias = true) {
@ -723,127 +872,90 @@ namespace Flux {
}
};
// Pre-refactor form of the Chroma Radiance settings. Its fields and defaults
// match ChromaRadianceConfig one-for-one.
struct ChromaRadianceParams {
int64_t nerf_hidden_size = 64;
int nerf_mlp_ratio = 4;
int nerf_depth = 4;
int nerf_max_freqs = 8;
bool use_x0 = false;
bool fake_patch_size_x2 = false;
};
// Pre-refactor parameter bundle for the Flux block. Its fields and defaults
// match FluxConfig one-for-one, except that the nested Chroma Radiance
// settings use the ChromaRadianceParams type here.
struct FluxParams {
SDVersion version = VERSION_FLUX;
bool is_chroma = false;
int patch_size = 2;
int64_t in_channels = 64;
int64_t out_channels = 64;
int64_t vec_in_dim = 768;
int64_t context_in_dim = 4096;
int64_t hidden_size = 3072;
float mlp_ratio = 4.0f;
int num_heads = 24;
int depth = 19;
int depth_single_blocks = 38;
std::vector<int> axes_dim = {16, 56, 56};
int axes_dim_sum = 128;
int theta = 10000;
bool qkv_bias = true;
bool guidance_embed = true;
int64_t in_dim = 64;
bool disable_bias = false;
bool share_modulation = false;
bool semantic_txt_norm = false;
bool use_yak_mlp = false;
bool use_mlp_silu_act = false;
float ref_index_scale = 1.f;
ChromaRadianceParams chroma_radiance_params;
};
struct Flux : public GGMLBlock {
public:
FluxParams params;
FluxConfig config;
Flux() {}
Flux(FluxParams params)
: params(params) {
if (params.version == VERSION_CHROMA_RADIANCE) {
std::pair<int, int> kernel_size = {params.patch_size, params.patch_size};
if (params.chroma_radiance_params.fake_patch_size_x2) {
kernel_size = {params.patch_size / 2, params.patch_size / 2};
Flux(FluxConfig config)
: config(config) {
if (config.version == VERSION_CHROMA_RADIANCE) {
std::pair<int, int> kernel_size = {config.patch_size, config.patch_size};
if (config.chroma_radiance_params.fake_patch_size_x2) {
kernel_size = {config.patch_size / 2, config.patch_size / 2};
}
std::pair<int, int> stride = kernel_size;
blocks["img_in_patch"] = std::make_shared<Conv2d>(params.in_channels,
params.hidden_size,
blocks["img_in_patch"] = std::make_shared<Conv2d>(config.in_channels,
config.hidden_size,
kernel_size,
stride);
} else {
blocks["img_in"] = std::make_shared<Linear>(params.in_channels, params.hidden_size, !params.disable_bias);
blocks["img_in"] = std::make_shared<Linear>(config.in_channels, config.hidden_size, !config.disable_bias);
}
if (params.is_chroma) {
blocks["distilled_guidance_layer"] = std::make_shared<ChromaApproximator>(params.in_dim, params.hidden_size);
if (config.is_chroma) {
blocks["distilled_guidance_layer"] = std::make_shared<ChromaApproximator>(config.in_dim, config.hidden_size);
} else {
blocks["time_in"] = std::make_shared<MLPEmbedder>(256, params.hidden_size, !params.disable_bias);
if (params.vec_in_dim > 0) {
blocks["vector_in"] = std::make_shared<MLPEmbedder>(params.vec_in_dim, params.hidden_size, !params.disable_bias);
blocks["time_in"] = std::make_shared<MLPEmbedder>(256, config.hidden_size, !config.disable_bias);
if (config.vec_in_dim > 0) {
blocks["vector_in"] = std::make_shared<MLPEmbedder>(config.vec_in_dim, config.hidden_size, !config.disable_bias);
}
if (params.guidance_embed) {
blocks["guidance_in"] = std::make_shared<MLPEmbedder>(256, params.hidden_size, !params.disable_bias);
if (config.guidance_embed) {
blocks["guidance_in"] = std::make_shared<MLPEmbedder>(256, config.hidden_size, !config.disable_bias);
}
}
if (params.semantic_txt_norm) {
blocks["txt_norm"] = std::make_shared<RMSNorm>(params.context_in_dim);
if (config.semantic_txt_norm) {
blocks["txt_norm"] = std::make_shared<RMSNorm>(config.context_in_dim);
}
blocks["txt_in"] = std::make_shared<Linear>(params.context_in_dim, params.hidden_size, !params.disable_bias);
blocks["txt_in"] = std::make_shared<Linear>(config.context_in_dim, config.hidden_size, !config.disable_bias);
for (int i = 0; i < params.depth; i++) {
blocks["double_blocks." + std::to_string(i)] = std::make_shared<DoubleStreamBlock>(params.hidden_size,
params.num_heads,
params.mlp_ratio,
for (int i = 0; i < config.depth; i++) {
blocks["double_blocks." + std::to_string(i)] = std::make_shared<DoubleStreamBlock>(config.hidden_size,
config.num_heads,
config.mlp_ratio,
i,
params.qkv_bias,
params.is_chroma,
params.share_modulation,
!params.disable_bias,
params.use_yak_mlp,
params.use_mlp_silu_act);
config.qkv_bias,
config.is_chroma,
config.share_modulation,
!config.disable_bias,
config.use_yak_mlp,
config.use_mlp_silu_act);
}
for (int i = 0; i < params.depth_single_blocks; i++) {
blocks["single_blocks." + std::to_string(i)] = std::make_shared<SingleStreamBlock>(params.hidden_size,
params.num_heads,
params.mlp_ratio,
for (int i = 0; i < config.depth_single_blocks; i++) {
blocks["single_blocks." + std::to_string(i)] = std::make_shared<SingleStreamBlock>(config.hidden_size,
config.num_heads,
config.mlp_ratio,
i,
0.f,
params.is_chroma,
params.share_modulation,
!params.disable_bias,
params.use_yak_mlp,
params.use_mlp_silu_act);
config.is_chroma,
config.share_modulation,
!config.disable_bias,
config.use_yak_mlp,
config.use_mlp_silu_act);
}
if (params.version == VERSION_CHROMA_RADIANCE) {
blocks["nerf_image_embedder"] = std::make_shared<NerfEmbedder>(params.in_channels,
params.chroma_radiance_params.nerf_hidden_size,
params.chroma_radiance_params.nerf_max_freqs);
if (config.version == VERSION_CHROMA_RADIANCE) {
blocks["nerf_image_embedder"] = std::make_shared<NerfEmbedder>(config.in_channels,
config.chroma_radiance_params.nerf_hidden_size,
config.chroma_radiance_params.nerf_max_freqs);
for (int i = 0; i < params.chroma_radiance_params.nerf_depth; i++) {
blocks["nerf_blocks." + std::to_string(i)] = std::make_shared<NerfGLUBlock>(params.hidden_size,
params.chroma_radiance_params.nerf_hidden_size,
params.chroma_radiance_params.nerf_mlp_ratio);
for (int i = 0; i < config.chroma_radiance_params.nerf_depth; i++) {
blocks["nerf_blocks." + std::to_string(i)] = std::make_shared<NerfGLUBlock>(config.hidden_size,
config.chroma_radiance_params.nerf_hidden_size,
config.chroma_radiance_params.nerf_mlp_ratio);
}
blocks["nerf_final_layer_conv"] = std::make_shared<NerfFinalLayerConv>(params.chroma_radiance_params.nerf_hidden_size,
params.in_channels);
blocks["nerf_final_layer_conv"] = std::make_shared<NerfFinalLayerConv>(config.chroma_radiance_params.nerf_hidden_size,
config.in_channels);
} else {
blocks["final_layer"] = std::make_shared<LastLayer>(params.hidden_size, 1, params.out_channels, params.is_chroma, !params.disable_bias);
blocks["final_layer"] = std::make_shared<LastLayer>(config.hidden_size, 1, config.out_channels, config.is_chroma, !config.disable_bias);
}
if (params.share_modulation) {
blocks["double_stream_modulation_img"] = std::make_shared<Modulation>(params.hidden_size, true, !params.disable_bias);
blocks["double_stream_modulation_txt"] = std::make_shared<Modulation>(params.hidden_size, true, !params.disable_bias);
blocks["single_stream_modulation"] = std::make_shared<Modulation>(params.hidden_size, false, !params.disable_bias);
if (config.share_modulation) {
blocks["double_stream_modulation_img"] = std::make_shared<Modulation>(config.hidden_size, true, !config.disable_bias);
blocks["double_stream_modulation_txt"] = std::make_shared<Modulation>(config.hidden_size, true, !config.disable_bias);
blocks["single_stream_modulation"] = std::make_shared<Modulation>(config.hidden_size, false, !config.disable_bias);
}
}
@ -866,7 +978,7 @@ namespace Flux {
ggml_tensor* vec;
ggml_tensor* txt_img_mask = nullptr;
if (params.is_chroma) {
if (config.is_chroma) {
int64_t mod_index_length = 344;
auto approx = std::dynamic_pointer_cast<ChromaApproximator>(blocks["distilled_guidance_layer"]);
auto distill_timestep = ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, 16, 10000, 1000.f);
@ -894,7 +1006,7 @@ namespace Flux {
} else {
auto time_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["time_in"]);
vec = time_in->forward(ctx, ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, 256, 10000, 1000.f));
if (params.guidance_embed) {
if (config.guidance_embed) {
GGML_ASSERT(guidance != nullptr);
auto guidance_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["guidance_in"]);
// bf16 and fp16 result is different
@ -902,7 +1014,7 @@ namespace Flux {
vec = ggml_add(ctx->ggml_ctx, vec, guidance_in->forward(ctx, g_in));
}
if (params.vec_in_dim > 0) {
if (config.vec_in_dim > 0) {
auto vector_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["vector_in"]);
vec = ggml_add(ctx->ggml_ctx, vec, vector_in->forward(ctx, y));
}
@ -911,7 +1023,7 @@ namespace Flux {
std::vector<ModulationOut> ds_img_mods;
std::vector<ModulationOut> ds_txt_mods;
std::vector<ModulationOut> ss_mods;
if (params.share_modulation) {
if (config.share_modulation) {
auto double_stream_modulation_img = std::dynamic_pointer_cast<Modulation>(blocks["double_stream_modulation_img"]);
auto double_stream_modulation_txt = std::dynamic_pointer_cast<Modulation>(blocks["double_stream_modulation_txt"]);
auto single_stream_modulation = std::dynamic_pointer_cast<Modulation>(blocks["single_stream_modulation"]);
@ -921,7 +1033,7 @@ namespace Flux {
ss_mods = single_stream_modulation->forward(ctx, vec);
}
if (params.semantic_txt_norm) {
if (config.semantic_txt_norm) {
auto semantic_txt_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["txt_norm"]);
txt = semantic_txt_norm->forward(ctx, txt);
@ -932,7 +1044,7 @@ namespace Flux {
sd::ggml_graph_cut::mark_graph_cut(txt, "flux.prelude", "txt");
sd::ggml_graph_cut::mark_graph_cut(vec, "flux.prelude", "vec");
for (int i = 0; i < params.depth; i++) {
for (int i = 0; i < config.depth; i++) {
if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i) != skip_layers.end()) {
continue;
}
@ -947,8 +1059,8 @@ namespace Flux {
}
auto txt_img = ggml_concat(ctx->ggml_ctx, txt, img, 1); // [N, n_txt_token + n_img_token, hidden_size]
for (int i = 0; i < params.depth_single_blocks; i++) {
if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i + params.depth) != skip_layers.end()) {
for (int i = 0; i < config.depth_single_blocks; i++) {
if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i + config.depth) != skip_layers.end()) {
continue;
}
auto block = std::dynamic_pointer_cast<SingleStreamBlock>(blocks["single_blocks." + std::to_string(i)]);
@ -999,14 +1111,14 @@ namespace Flux {
int64_t W = x->ne[0];
int64_t H = x->ne[1];
int64_t C = x->ne[2];
int patch_size = params.patch_size;
int patch_size = config.patch_size;
int pad_h = (patch_size - H % patch_size) % patch_size;
int pad_w = (patch_size - W % patch_size) % patch_size;
auto img = DiT::pad_to_patch_size(ctx, x, params.patch_size, params.patch_size);
auto img = DiT::pad_to_patch_size(ctx, x, config.patch_size, config.patch_size);
auto orig_img = img;
if (params.chroma_radiance_params.fake_patch_size_x2) {
if (config.chroma_radiance_params.fake_patch_size_x2) {
// It's supposed to be using GGML_SCALE_MODE_NEAREST, but this seems more stable
// Maybe the implementation of nearest-neighbor interpolation in ggml behaves differently than the one in PyTorch?
// img = F.interpolate(img, size=(H//2, W//2), mode="nearest")
@ -1037,7 +1149,7 @@ namespace Flux {
auto nerf_hidden = ggml_reshape_2d(ctx->ggml_ctx, out, out->ne[0], out->ne[1] * out->ne[2]); // [N*num_patches, hidden_size]
auto img_dct = nerf_image_embedder->forward(ctx, nerf_pixels, dct); // [N*num_patches, patch_size*patch_size, nerf_hidden_size]
for (int i = 0; i < params.chroma_radiance_params.nerf_depth; i++) {
for (int i = 0; i < config.chroma_radiance_params.nerf_depth; i++) {
auto block = std::dynamic_pointer_cast<NerfGLUBlock>(blocks["nerf_blocks." + std::to_string(i)]);
img_dct = block->forward(ctx, img_dct, nerf_hidden);
@ -1049,7 +1161,7 @@ namespace Flux {
out = nerf_final_layer_conv->forward(ctx, img_dct); // [N, C, H, W]
if (params.chroma_radiance_params.use_x0) {
if (config.chroma_radiance_params.use_x0) {
out = _apply_x0_residual(ctx, out, orig_img, timestep);
}
@ -1073,14 +1185,14 @@ namespace Flux {
int64_t W = x->ne[0];
int64_t H = x->ne[1];
int64_t C = x->ne[2];
int patch_size = params.patch_size;
int patch_size = config.patch_size;
int pad_h = (patch_size - H % patch_size) % patch_size;
int pad_w = (patch_size - W % patch_size) % patch_size;
auto img = DiT::pad_and_patchify(ctx, x, patch_size, patch_size);
int64_t img_tokens = img->ne[1];
if (params.version == VERSION_FLUX_FILL) {
if (config.version == VERSION_FLUX_FILL) {
GGML_ASSERT(c_concat != nullptr);
ggml_tensor* masked = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
@ -1089,7 +1201,7 @@ namespace Flux {
mask = DiT::pad_and_patchify(ctx, mask, patch_size, patch_size);
img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, masked, mask, 0), 0);
} else if (params.version == VERSION_FLEX_2) {
} else if (config.version == VERSION_FLEX_2) {
GGML_ASSERT(c_concat != nullptr);
ggml_tensor* masked = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
@ -1100,7 +1212,7 @@ namespace Flux {
control = DiT::pad_and_patchify(ctx, control, patch_size, patch_size);
img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, ggml_concat(ctx->ggml_ctx, masked, mask, 0), control, 0), 0);
} else if (params.version == VERSION_FLUX_CONTROLS) {
} else if (config.version == VERSION_FLUX_CONTROLS) {
GGML_ASSERT(c_concat != nullptr);
auto control = DiT::pad_and_patchify(ctx, c_concat, patch_size, patch_size);
@ -1147,7 +1259,7 @@ namespace Flux {
// pe: (L, d_head/2, 2, 2)
// return: (N, C, H, W)
if (params.version == VERSION_CHROMA_RADIANCE) {
if (config.version == VERSION_CHROMA_RADIANCE) {
return forward_chroma_radiance(ctx,
x,
timestep,
@ -1179,7 +1291,7 @@ namespace Flux {
struct FluxRunner : public DiffusionModelRunner {
public:
FluxParams flux_params;
FluxConfig config;
Flux flux;
std::vector<float> pe_vec;
std::vector<float> mod_index_arange_vec;
@ -1194,114 +1306,15 @@ namespace Flux {
const std::string prefix = "",
SDVersion version = VERSION_FLUX,
bool use_mask = false)
: DiffusionModelRunner(backend, params_backend, prefix), version(version), use_mask(use_mask) {
flux_params.version = version;
flux_params.guidance_embed = false;
flux_params.depth = 0;
flux_params.depth_single_blocks = 0;
if (version == VERSION_FLUX_FILL) {
flux_params.in_channels = 384;
} else if (version == VERSION_FLUX_CONTROLS) {
flux_params.in_channels = 128;
} else if (version == VERSION_FLEX_2) {
flux_params.in_channels = 196;
} else if (version == VERSION_CHROMA_RADIANCE) {
flux_params.in_channels = 3;
flux_params.patch_size = 16;
} else if (version == VERSION_OVIS_IMAGE) {
flux_params.semantic_txt_norm = true;
flux_params.use_yak_mlp = true;
flux_params.vec_in_dim = 0;
} else if (sd_version_is_flux2(version)) {
flux_params.in_channels = 128;
flux_params.patch_size = 1;
flux_params.out_channels = 128;
flux_params.mlp_ratio = 3.f;
flux_params.theta = 2000;
flux_params.axes_dim = {32, 32, 32, 32};
flux_params.vec_in_dim = 0;
flux_params.qkv_bias = false;
flux_params.disable_bias = true;
flux_params.share_modulation = true;
flux_params.ref_index_scale = 10.f;
flux_params.use_mlp_silu_act = true;
} else if (sd_version_is_longcat(version)) {
flux_params.context_in_dim = 3584;
flux_params.vec_in_dim = 0;
}
int64_t head_dim = 0;
int64_t actual_radiance_patch_size = -1;
for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first;
if (!starts_with(tensor_name, prefix))
continue;
if (tensor_name.find("guidance_in.in_layer.weight") != std::string::npos) {
flux_params.guidance_embed = true;
}
if (tensor_name.find("__x0__") != std::string::npos) {
LOG_DEBUG("using x0 prediction");
flux_params.chroma_radiance_params.use_x0 = true;
}
if (tensor_name.find("__32x32__") != std::string::npos) {
LOG_DEBUG("using patch size 32");
flux_params.patch_size = 32;
}
if (tensor_name.find("img_in_patch.weight") != std::string::npos) {
actual_radiance_patch_size = pair.second.ne[0];
LOG_DEBUG("actual radiance patch size: %d", actual_radiance_patch_size);
}
if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
// Chroma
flux_params.is_chroma = true;
}
size_t db = tensor_name.find("double_blocks.");
if (db != std::string::npos) {
tensor_name = tensor_name.substr(db); // remove prefix
int block_depth = atoi(tensor_name.substr(14, tensor_name.find(".", 14)).c_str());
if (block_depth + 1 > flux_params.depth) {
flux_params.depth = block_depth + 1;
}
}
size_t sb = tensor_name.find("single_blocks.");
if (sb != std::string::npos) {
tensor_name = tensor_name.substr(sb); // remove prefix
int block_depth = atoi(tensor_name.substr(14, tensor_name.find(".", 14)).c_str());
if (block_depth + 1 > flux_params.depth_single_blocks) {
flux_params.depth_single_blocks = block_depth + 1;
}
}
if (ends_with(tensor_name, "txt_in.weight")) {
flux_params.context_in_dim = pair.second.ne[0];
flux_params.hidden_size = pair.second.ne[1];
}
if (ends_with(tensor_name, "single_blocks.0.norm.key_norm.scale")) {
head_dim = pair.second.ne[0];
}
if (ends_with(tensor_name, "double_blocks.0.txt_attn.norm.key_norm.scale")) {
head_dim = pair.second.ne[0];
}
}
if (actual_radiance_patch_size > 0 && actual_radiance_patch_size != flux_params.patch_size) {
GGML_ASSERT(flux_params.patch_size == 2 * actual_radiance_patch_size);
LOG_DEBUG("using fake x2 patch size");
flux_params.chroma_radiance_params.fake_patch_size_x2 = true;
}
flux_params.num_heads = static_cast<int>(flux_params.hidden_size / head_dim);
LOG_INFO("flux: depth = %d, depth_single_blocks = %d, guidance_embed = %s, context_in_dim = %" PRId64
", hidden_size = %" PRId64 ", num_heads = %d",
flux_params.depth,
flux_params.depth_single_blocks,
flux_params.guidance_embed ? "true" : "false",
flux_params.context_in_dim,
flux_params.hidden_size,
flux_params.num_heads);
if (flux_params.is_chroma) {
: DiffusionModelRunner(backend, params_backend, prefix),
config(FluxConfig::detect_from_weights(tensor_storage_map, prefix, version)),
version(version),
use_mask(use_mask) {
if (config.is_chroma) {
LOG_INFO("Using pruned modulation (Chroma)");
}
flux = Flux(flux_params);
flux = Flux(config);
flux.init(params_ctx, tensor_storage_map, prefix);
}
@ -1377,10 +1390,10 @@ namespace Flux {
ggml_tensor* context = make_optional_input(context_tensor);
ggml_tensor* c_concat = make_optional_input(c_concat_tensor);
ggml_tensor* y = make_optional_input(y_tensor);
if (flux_params.guidance_embed || flux_params.is_chroma) {
if (config.guidance_embed || config.is_chroma) {
if (!guidance_tensor.empty()) {
this->guidance_tensor = guidance_tensor;
if (flux_params.is_chroma) {
if (config.is_chroma) {
this->guidance_tensor.fill_(0.f);
}
}
@ -1398,7 +1411,7 @@ namespace Flux {
ggml_tensor* mod_index_arange = nullptr;
ggml_tensor* dct = nullptr; // for chroma radiance
if (flux_params.is_chroma) {
if (config.is_chroma) {
if (!use_mask) {
y = nullptr;
}
@ -1417,29 +1430,29 @@ namespace Flux {
}
pe_vec = Rope::gen_flux_pe(static_cast<int>(x->ne[1]),
static_cast<int>(x->ne[0]),
flux_params.patch_size,
config.patch_size,
static_cast<int>(x->ne[3]),
static_cast<int>(context->ne[1]),
txt_arange_dims,
ref_latents,
increase_ref_index,
flux_params.ref_index_scale,
flux_params.theta,
config.ref_index_scale,
config.theta,
circular_y_enabled,
circular_x_enabled,
flux_params.axes_dim,
config.axes_dim,
sd_version_is_longcat(version));
int pos_len = static_cast<int>(pe_vec.size() / flux_params.axes_dim_sum / 2);
int pos_len = static_cast<int>(pe_vec.size() / config.axes_dim_sum / 2);
// LOG_DEBUG("pos_len %d", pos_len);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.axes_dim_sum / 2, pos_len);
// pe->data = pe_vec.data();
// print_ggml_tensor(pe);
// pe->data = nullptr;
set_backend_tensor_data(pe, pe_vec.data());
if (version == VERSION_CHROMA_RADIANCE) {
int patch_size = flux_params.patch_size;
int nerf_max_freqs = flux_params.chroma_radiance_params.nerf_max_freqs;
int patch_size = config.patch_size;
int nerf_max_freqs = config.chroma_radiance_params.nerf_max_freqs;
dct_vec = fetch_dct_pos(patch_size, nerf_max_freqs);
dct = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, nerf_max_freqs * nerf_max_freqs, patch_size * patch_size);
// dct->data = dct_vec.data();

View File

@ -1707,7 +1707,7 @@ protected:
uint64_t resident_state_token = 0;
size_t max_graph_vram_bytes = 0;
bool stream_layers_enabled = false;
bool stream_layers_enabled = false;
size_t observed_max_effective_budget_ = 0;
sd::layer_registry::LayerRegistry layer_registry_;

View File

@ -23,6 +23,39 @@ namespace HiDreamO1 {
constexpr int IMAGE_TOKEN_ID = 151655;
constexpr int VISION_START_TOKEN_ID = 151652;
// Configuration for the HiDream-O1 model: the LLM configuration (with its
// nested vision configuration) plus the image patch size.
struct HiDreamO1Config {
    LLM::LLMConfig llm;
    int patch_size = PATCH_SIZE;

    // Returns the HiDream-O1 configuration.
    //
    // Both arguments are accepted but never read: every value assigned below
    // is a hard-coded constant, so the result does not depend on the weights.
    static HiDreamO1Config detect_from_weights(const String2TensorStorage& tensor_storage_map, const std::string& prefix) {
        (void)tensor_storage_map;
        (void)prefix;

        HiDreamO1Config result;

        // Text model settings.
        auto& text             = result.llm;
        text.arch              = LLM::LLMArch::QWEN3_VL;
        text.num_layers        = 36;
        text.hidden_size       = 4096;
        text.intermediate_size = 12288;
        text.num_heads         = 32;
        text.num_kv_heads      = 8;
        text.head_dim          = 128;
        text.qkv_bias          = false;
        text.qk_norm           = true;
        text.vocab_size        = 151936;
        text.rms_norm_eps      = 1e-6f;

        // Vision tower settings.
        auto& vision                   = result.llm.vision;
        vision.arch                    = LLM::LLMVisionArch::QWEN3_VL;
        vision.num_layers              = 27;
        vision.hidden_size             = 1152;
        vision.intermediate_size       = 4304;
        vision.num_heads               = 16;
        vision.out_hidden_size         = 4096;
        vision.patch_size              = 16;
        vision.spatial_merge_size      = 2;
        vision.temporal_patch_size     = 2;
        vision.num_position_embeddings = 2304;

        return result;
    }
};
static inline std::string repeat_special_token(const std::string& token, int64_t count) {
std::string out;
out.reserve(static_cast<size_t>(count) * token.size());
@ -205,50 +238,19 @@ namespace HiDreamO1 {
}
};
// Parameter bundle for the HiDream-O1 model: the language-model parameters
// together with the image patch size.
struct HiDreamO1Params {
// Parameters of the underlying LLM; its nested `vision` member carries the
// vision-model parameters.
LLM::LLMParams llm;
// Defaults to the PATCH_SIZE constant. HiDreamO1Model uses it squared
// (patch_size * patch_size * 3) when sizing x_embedder and final_layer2.
int patch_size = PATCH_SIZE;
};
// Builds the fixed HiDream-O1 parameter set. Takes no input: every value is a
// hard-coded constant.
static inline HiDreamO1Params make_hidream_o1_params() {
    HiDreamO1Params result;

    // Text model settings.
    auto& text             = result.llm;
    text.arch              = LLM::LLMArch::QWEN3_VL;
    text.num_layers        = 36;
    text.hidden_size       = 4096;
    text.intermediate_size = 12288;
    text.num_heads         = 32;
    text.num_kv_heads      = 8;
    text.head_dim          = 128;
    text.qkv_bias          = false;
    text.qk_norm           = true;
    text.vocab_size        = 151936;
    text.rms_norm_eps      = 1e-6f;

    // Vision tower settings.
    auto& vision                   = result.llm.vision;
    vision.arch                    = LLM::LLMVisionArch::QWEN3_VL;
    vision.num_layers              = 27;
    vision.hidden_size             = 1152;
    vision.intermediate_size       = 4304;
    vision.num_heads               = 16;
    vision.out_hidden_size         = 4096;
    vision.patch_size              = 16;
    vision.spatial_merge_size      = 2;
    vision.temporal_patch_size     = 2;
    vision.num_position_embeddings = 2304;

    return result;
}
struct HiDreamO1Model : public GGMLBlock {
HiDreamO1Params params;
HiDreamO1Config config;
HiDreamO1Model() = default;
explicit HiDreamO1Model(HiDreamO1Params params)
: params(std::move(params)) {
blocks["language_model"] = std::make_shared<LLM::TextModel>(this->params.llm);
blocks["t_embedder1"] = std::make_shared<TimestepEmbedder>(this->params.llm.hidden_size);
blocks["x_embedder"] = std::make_shared<BottleneckPatchEmbed>(this->params.patch_size * this->params.patch_size * 3,
this->params.llm.hidden_size / 4,
this->params.llm.hidden_size);
blocks["final_layer2"] = std::make_shared<FinalLayer>(this->params.llm.hidden_size,
this->params.patch_size * this->params.patch_size * 3);
explicit HiDreamO1Model(HiDreamO1Config config)
: config(std::move(config)) {
blocks["language_model"] = std::make_shared<LLM::TextModel>(this->config.llm);
blocks["t_embedder1"] = std::make_shared<TimestepEmbedder>(this->config.llm.hidden_size);
blocks["x_embedder"] = std::make_shared<BottleneckPatchEmbed>(this->config.patch_size * this->config.patch_size * 3,
this->config.llm.hidden_size / 4,
this->config.llm.hidden_size);
blocks["final_layer2"] = std::make_shared<FinalLayer>(this->config.llm.hidden_size,
this->config.patch_size * this->config.patch_size * 3);
}
std::shared_ptr<LLM::TextModel> text_model() {
@ -269,7 +271,7 @@ namespace HiDreamO1 {
};
struct HiDreamO1VisionRunner : public GGMLRunner {
HiDreamO1Params params;
HiDreamO1Config config;
std::shared_ptr<LLM::VisionModel> model;
std::vector<int> window_index_vec;
@ -284,8 +286,8 @@ namespace HiDreamO1 {
const String2TensorStorage& tensor_storage_map = {},
const std::string& prefix = "model.visual")
: GGMLRunner(backend, params_backend),
params(make_hidream_o1_params()),
model(std::make_shared<LLM::VisionModel>(false, params.llm.vision)) {
config(HiDreamO1Config::detect_from_weights(tensor_storage_map, prefix)),
model(std::make_shared<LLM::VisionModel>(false, config.llm.vision)) {
model->init(params_ctx, tensor_storage_map, prefix);
}
@ -302,7 +304,7 @@ namespace HiDreamO1 {
compute_ctx,
runner_ctx,
image,
params.llm.vision,
config.llm.vision,
model,
window_index_vec,
window_inverse_index_vec,
@ -331,7 +333,7 @@ namespace HiDreamO1 {
};
struct HiDreamO1Runner : public DiffusionModelRunner {
HiDreamO1Params params;
HiDreamO1Config config;
HiDreamO1Model model;
std::vector<float> attention_mask_vec;
@ -341,8 +343,8 @@ namespace HiDreamO1 {
const String2TensorStorage& tensor_storage_map = {},
const std::string& prefix = "model")
: DiffusionModelRunner(backend, params_backend, prefix),
params(make_hidream_o1_params()) {
model = HiDreamO1Model(params);
config(HiDreamO1Config::detect_from_weights(tensor_storage_map, prefix)) {
model = HiDreamO1Model(config);
model.init(params_ctx, tensor_storage_map, prefix);
}

View File

@ -38,6 +38,34 @@ namespace Ideogram4 {
std::vector<int> mrope_section = {DEFAULT_MROPE_SECTION_T,
DEFAULT_MROPE_SECTION_H,
DEFAULT_MROPE_SECTION_W};
// Derives an Ideogram4Config from the tensor names in `tensor_storage_map`.
//
// Only `num_layers` is detected; all other fields keep their in-class
// defaults. A tensor contributes when its name starts with "layers." (empty
// `prefix`) or "<prefix>.layers." and is followed by "<index>.", where
// <index> consists solely of decimal digits. `num_layers` becomes the highest
// such index + 1. If no layer tensor is found, the default `num_layers` is
// left untouched and nothing is logged.
static Ideogram4Config detect_from_weights(const String2TensorStorage& tensor_storage_map,
                                           const std::string& prefix) {
    Ideogram4Config config;
    const std::string layer_prefix = prefix.empty() ? "layers." : prefix + ".layers.";

    int64_t detected_layers = 0;
    for (const auto& [name, _] : tensor_storage_map) {
        // Anchored prefix test: compares only the first layer_prefix.size()
        // characters instead of searching the whole name for the prefix.
        if (name.compare(0, layer_prefix.size(), layer_prefix) != 0) {
            continue;
        }
        const size_t index_begin = layer_prefix.size();
        const size_t dot         = name.find('.', index_begin);
        if (dot == std::string::npos) {
            continue;
        }
        const std::string index_text = name.substr(index_begin, dot - index_begin);
        // std::atoi returns 0 for an empty or non-numeric token, which would
        // be miscounted as layer 0; skip anything that is not a plain number.
        if (index_text.empty() || index_text.find_first_not_of("0123456789") != std::string::npos) {
            continue;
        }
        const int layer_idx = std::atoi(index_text.c_str());
        // Widen before adding 1 so the increment is done in 64-bit.
        detected_layers = std::max<int64_t>(detected_layers, static_cast<int64_t>(layer_idx) + 1);
    }

    if (detected_layers > 0) {
        config.num_layers = detected_layers;
        LOG_DEBUG("ideogram4: num_layers = %" PRId64 ", emb_dim = %" PRId64 ", num_heads = %" PRId64 ", intermediate_size = %" PRId64,
                  config.num_layers,
                  config.emb_dim,
                  config.num_heads,
                  config.intermediate_size);
    }
    return config;
}
};
__STATIC_INLINE__ ggml_tensor* timestep_embedding_sin_cos(ggml_context* ctx,
@ -380,26 +408,6 @@ namespace Ideogram4 {
class Ideogram4Runner : public DiffusionModelRunner {
protected:
// Returns the number of transformer layers implied by the tensor names in
// `tensor_storage_map`: the highest "<index>" found in names of the form
// "layers.<index>." (empty `prefix`) or "<prefix>.layers.<index>.", plus one.
// Returns 0 when no such tensor exists. Only purely numeric <index> tokens
// are counted.
static int64_t detect_num_layers(const String2TensorStorage& tensor_storage_map,
                                 const std::string& prefix) {
    const std::string layer_prefix = prefix.empty() ? "layers." : prefix + ".layers.";

    int64_t detected_layers = 0;
    for (const auto& pair : tensor_storage_map) {
        const std::string& name = pair.first;
        // Anchored prefix test: compares only the first layer_prefix.size()
        // characters instead of searching the whole name for the prefix.
        if (name.compare(0, layer_prefix.size(), layer_prefix) != 0) {
            continue;
        }
        const size_t index_begin = layer_prefix.size();
        const size_t dot         = name.find('.', index_begin);
        if (dot == std::string::npos) {
            continue;
        }
        const std::string index_text = name.substr(index_begin, dot - index_begin);
        // std::atoi returns 0 for an empty or non-numeric token, which would
        // be miscounted as layer 0; skip anything that is not a plain number.
        if (index_text.empty() || index_text.find_first_not_of("0123456789") != std::string::npos) {
            continue;
        }
        const int layer_idx = std::atoi(index_text.c_str());
        // Widen before adding 1 so the increment is done in 64-bit.
        detected_layers = std::max<int64_t>(detected_layers, static_cast<int64_t>(layer_idx) + 1);
    }
    return detected_layers;
}
bool should_use_uncond_model(const DiffusionParams& diffusion_params) const {
return has_uncond_model &&
diffusion_params.context == nullptr &&
@ -421,12 +429,8 @@ namespace Ideogram4 {
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "")
: DiffusionModelRunner(backend, params_backend, prefix),
config(Ideogram4Config::detect_from_weights(tensor_storage_map, prefix)),
uncond_prefix(prefix + ".uncond") {
int64_t detected_layers = detect_num_layers(tensor_storage_map, prefix);
if (detected_layers > 0) {
config.num_layers = detected_layers;
}
model = Ideogram4Transformer(config);
model.init(params_ctx, tensor_storage_map, prefix);
for (const auto& pair : tensor_storage_map) {

View File

@ -13,6 +13,71 @@
namespace Lens {
constexpr int LENS_GRAPH_SIZE = 40960;
// Configuration of the Lens model. The in-class values are defaults;
// detect_from_weights() overrides the ones that can be read from tensor shapes.
struct LensConfig {
    int patch_size              = 2;
    int64_t in_channels         = 128;   // taken from ne[0] of img_in.weight
    int64_t out_channels        = 32;    // ne[1] of proj_out.weight / patch_size^2
    int num_layers              = 48;    // highest transformer_blocks index + 1
    int64_t attention_head_dim  = 64;    // ne[0] of transformer_blocks.0.attn.norm_q.weight
    int64_t num_attention_heads = 24;    // ne[1] of img_in.weight / attention_head_dim
    int64_t joint_attention_dim = 2880;  // not detected; divides ne[0] of txt_in.weight
    int selected_layer_count    = 4;     // ne[0] of txt_in.weight / joint_attention_dim
    int theta                   = 10000;
    std::vector<int> axes_dim   = {8, 28, 28};
    int axes_dim_sum            = 64;    // always recomputed as the sum of axes_dim

    // Builds a LensConfig from the shapes and names of the tensors in
    // `tensor_storage_map` whose name starts with `prefix`.
    //
    // Fields with no matching tensor keep their defaults. If no
    // "transformer_blocks.<i>" tensor is found, num_layers falls back to 48.
    // The head count is derived only after the whole map has been scanned, so
    // the result does not depend on whether img_in.weight or norm_q.weight is
    // visited first.
    static LensConfig detect_from_weights(const String2TensorStorage& tensor_storage_map, const std::string& prefix) {
        LensConfig config;
        config.num_layers = 0;

        // Output width of img_in.weight; 0 means the tensor was not seen.
        int64_t inner_dim = 0;

        for (const auto& [name, tensor_storage] : tensor_storage_map) {
            if (!starts_with(name, prefix)) {
                continue;
            }

            if (ends_with(name, "img_in.weight") && tensor_storage.n_dims == 2) {
                config.in_channels = tensor_storage.ne[0];
                // Remembered and divided after the loop: attention_head_dim
                // may still hold its default at this point.
                inner_dim = tensor_storage.ne[1];
            } else if (ends_with(name, "txt_in.weight") && tensor_storage.n_dims == 2) {
                config.selected_layer_count = static_cast<int>(tensor_storage.ne[0] / config.joint_attention_dim);
            } else if (ends_with(name, "proj_out.weight") && tensor_storage.n_dims == 2) {
                const int64_t patch_area = static_cast<int64_t>(config.patch_size) * config.patch_size;
                config.out_channels      = tensor_storage.ne[1] / patch_area;
            } else if (ends_with(name, "transformer_blocks.0.attn.norm_q.weight") && tensor_storage.n_dims == 1) {
                config.attention_head_dim = tensor_storage.ne[0];
            }

            // Track the highest block index seen to determine the depth.
            const size_t pos = name.find("transformer_blocks.");
            if (pos != std::string::npos) {
                auto items = split_string(name.substr(pos), '.');
                if (items.size() > 1) {
                    const int block_index = atoi(items[1].c_str());
                    if (block_index + 1 > config.num_layers) {
                        config.num_layers = block_index + 1;
                    }
                }
            }
        }

        // Derive the head count from the final (possibly detected) head dim.
        if (inner_dim > 0 && config.attention_head_dim > 0) {
            config.num_attention_heads = inner_dim / config.attention_head_dim;
        }

        if (config.num_layers == 0) {
            config.num_layers = 48;
        }

        config.axes_dim_sum = 0;
        for (int axis_dim : config.axes_dim) {
            config.axes_dim_sum += axis_dim;
        }

        LOG_DEBUG("lens: num_layers = %d, selected_layer_count = %d, hidden_size = %" PRId64 ", num_attention_heads = %" PRId64 ", attention_head_dim = %" PRId64 ", in_channels = %" PRId64 ", out_channels = %" PRId64,
                  config.num_layers,
                  config.selected_layer_count,
                  config.num_attention_heads * config.attention_head_dim,
                  config.num_attention_heads,
                  config.attention_head_dim,
                  config.in_channels,
                  config.out_channels);
        return config;
    }
};
struct LensTimestepProjEmbeddings : public GGMLBlock {
LensTimestepProjEmbeddings(int64_t embedding_dim) {
blocks["timestep_embedder"] = std::make_shared<Qwen::TimestepEmbedding>(256, embedding_dim);
@ -209,41 +274,27 @@ namespace Lens {
}
};
// Default hyper-parameters for the Lens model, consumed by LensModel when it
// builds its blocks.
struct LensParams {
// Used squared (patch_size * patch_size * out_channels) as the proj_out width.
int patch_size = 2;
// Input width of the img_in linear layer.
int64_t in_channels = 128;
// Multiplied by patch_size^2 to size the proj_out output.
int64_t out_channels = 32;
// Number of "transformer_blocks.<i>" blocks created.
int num_layers = 48;
// inner_dim of the model is num_attention_heads * attention_head_dim.
int64_t attention_head_dim = 64;
int64_t num_attention_heads = 24;
// Width of each text chunk; txt_in takes joint_attention_dim * selected_layer_count inputs.
int64_t joint_attention_dim = 2880;
// Number of chunks the context is split into, each with its own txt_norm block.
int selected_layer_count = 4;
// Passed to the positional-embedding generator together with axes_dim
// (presumably the rotary base frequency; confirm against that generator).
int theta = 10000;
std::vector<int> axes_dim = {8, 28, 28};
// Sum of axes_dim (8 + 28 + 28).
int axes_dim_sum = 64;
};
class LensModel : public GGMLBlock {
public:
LensParams params;
LensConfig config;
LensModel() = default;
LensModel(LensParams params)
: params(params) {
int64_t inner_dim = params.num_attention_heads * params.attention_head_dim;
LensModel(LensConfig config)
: config(config) {
int64_t inner_dim = config.num_attention_heads * config.attention_head_dim;
blocks["time_text_embed"] = std::make_shared<LensTimestepProjEmbeddings>(inner_dim);
blocks["img_in"] = std::make_shared<Linear>(params.in_channels, inner_dim, true);
blocks["txt_in"] = std::make_shared<Linear>(params.joint_attention_dim * params.selected_layer_count, inner_dim, true);
for (int i = 0; i < params.selected_layer_count; ++i) {
blocks["txt_norm." + std::to_string(i)] = std::make_shared<RMSNorm>(params.joint_attention_dim, 1e-5f);
blocks["img_in"] = std::make_shared<Linear>(config.in_channels, inner_dim, true);
blocks["txt_in"] = std::make_shared<Linear>(config.joint_attention_dim * config.selected_layer_count, inner_dim, true);
for (int i = 0; i < config.selected_layer_count; ++i) {
blocks["txt_norm." + std::to_string(i)] = std::make_shared<RMSNorm>(config.joint_attention_dim, 1e-5f);
}
for (int i = 0; i < params.num_layers; ++i) {
for (int i = 0; i < config.num_layers; ++i) {
blocks["transformer_blocks." + std::to_string(i)] = std::make_shared<LensTransformerBlock>(inner_dim,
params.num_attention_heads,
params.attention_head_dim);
config.num_attention_heads,
config.attention_head_dim);
}
blocks["norm_out"] = std::make_shared<LensAdaLayerNormContinuous>(inner_dim, 1e-6f);
blocks["proj_out"] = std::make_shared<Linear>(inner_dim, params.patch_size * params.patch_size * params.out_channels, true);
blocks["proj_out"] = std::make_shared<Linear>(inner_dim, config.patch_size * config.patch_size * config.out_channels, true);
}
ggml_tensor* forward(GGMLRunnerContext* ctx,
@ -269,9 +320,9 @@ namespace Lens {
img = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, img, 1, 0, 2, 3));
img = img_in->forward(ctx, img);
std::vector<ggml_tensor*> txt_chunks = ggml_ext_chunk(ctx->ggml_ctx, context, params.selected_layer_count, 0);
std::vector<ggml_tensor*> txt_chunks = ggml_ext_chunk(ctx->ggml_ctx, context, config.selected_layer_count, 0);
ggml_tensor* txt = nullptr;
for (int i = 0; i < params.selected_layer_count; ++i) {
for (int i = 0; i < config.selected_layer_count; ++i) {
auto txt_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["txt_norm." + std::to_string(i)]);
auto chunk = txt_norm->forward(ctx, txt_chunks[i]);
txt = txt == nullptr ? chunk : ggml_concat(ctx->ggml_ctx, txt, chunk, 0);
@ -281,7 +332,7 @@ namespace Lens {
sd::ggml_graph_cut::mark_graph_cut(img, "lens.prelude", "img");
sd::ggml_graph_cut::mark_graph_cut(txt, "lens.prelude", "txt");
for (int i = 0; i < params.num_layers; ++i) {
for (int i = 0; i < config.num_layers; ++i) {
auto block = std::dynamic_pointer_cast<LensTransformerBlock>(blocks["transformer_blocks." + std::to_string(i)]);
auto out = block->forward(ctx, img, txt, t_emb, pe);
img = out.first;
@ -294,13 +345,13 @@ namespace Lens {
img = proj_out->forward(ctx, img);
auto out = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, img, 1, 0, 2, 3));
out = ggml_reshape_4d(ctx->ggml_ctx, out, W, H, params.patch_size * params.patch_size * params.out_channels, N);
out = ggml_reshape_4d(ctx->ggml_ctx, out, W, H, config.patch_size * config.patch_size * config.out_channels, N);
return out;
}
};
struct LensRunner : public DiffusionModelRunner {
LensParams lens_params;
LensConfig config;
LensModel lens;
std::vector<float> pe_vec;
@ -308,53 +359,9 @@ namespace Lens {
ggml_backend_t params_backend,
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "")
: DiffusionModelRunner(backend, params_backend, prefix) {
lens_params.num_layers = 0;
for (const auto& [name, tensor_storage] : tensor_storage_map) {
if (!starts_with(name, prefix)) {
continue;
}
if (ends_with(name, "img_in.weight") && tensor_storage.n_dims == 2) {
lens_params.in_channels = tensor_storage.ne[0];
int64_t inner_dim = tensor_storage.ne[1];
lens_params.num_attention_heads = inner_dim / lens_params.attention_head_dim;
} else if (ends_with(name, "txt_in.weight") && tensor_storage.n_dims == 2) {
lens_params.selected_layer_count = static_cast<int>(tensor_storage.ne[0] / lens_params.joint_attention_dim);
} else if (ends_with(name, "proj_out.weight") && tensor_storage.n_dims == 2) {
lens_params.out_channels = tensor_storage.ne[1] / lens_params.patch_size / lens_params.patch_size;
} else if (ends_with(name, "transformer_blocks.0.attn.norm_q.weight") && tensor_storage.n_dims == 1) {
lens_params.attention_head_dim = tensor_storage.ne[0];
}
size_t pos = name.find("transformer_blocks.");
if (pos != std::string::npos) {
std::string layer_name = name.substr(pos);
auto items = split_string(layer_name, '.');
if (items.size() > 1) {
int block_index = atoi(items[1].c_str());
if (block_index + 1 > lens_params.num_layers) {
lens_params.num_layers = block_index + 1;
}
}
}
}
if (lens_params.num_layers == 0) {
lens_params.num_layers = 48;
}
lens_params.axes_dim_sum = 0;
for (int axis_dim : lens_params.axes_dim) {
lens_params.axes_dim_sum += axis_dim;
}
LOG_INFO("lens: layers = %d, in_channels = %" PRId64 ", out_channels = %" PRId64
", heads = %" PRId64 ", head_dim = %" PRId64,
lens_params.num_layers,
lens_params.in_channels,
lens_params.out_channels,
lens_params.num_attention_heads,
lens_params.attention_head_dim);
lens = LensModel(lens_params);
: DiffusionModelRunner(backend, params_backend, prefix),
config(LensConfig::detect_from_weights(tensor_storage_map, prefix)) {
lens = LensModel(config);
lens.init(params_ctx, tensor_storage_map, prefix);
}
@ -380,12 +387,12 @@ namespace Lens {
static_cast<int>(x->ne[0]),
static_cast<int>(x->ne[3]),
static_cast<int>(context->ne[1]),
lens_params.theta,
config.theta,
circular_y_enabled,
circular_x_enabled,
lens_params.axes_dim);
int pos_len = static_cast<int>(pe_vec.size() / lens_params.axes_dim_sum / 2);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, lens_params.axes_dim_sum / 2, pos_len);
config.axes_dim);
int pos_len = static_cast<int>(pe_vec.size() / config.axes_dim_sum / 2);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.axes_dim_sum / 2, pos_len);
set_backend_tensor_data(pe, pe_vec.data());
auto runner_ctx = get_context();

View File

@ -63,7 +63,7 @@ namespace LLM {
QWEN3_VL,
};
struct LLMVisionParams {
struct LLMVisionConfig {
LLMVisionArch arch = LLMVisionArch::QWEN2_5_VL;
int num_layers = 32;
int64_t hidden_size = 1280;
@ -79,7 +79,7 @@ namespace LLM {
std::set<int> fullatt_block_indexes = {7, 15, 23, 31};
};
struct LLMParams {
struct LLMConfig {
LLMArch arch = LLMArch::QWEN2_5_VL;
int64_t num_layers = 28;
int64_t hidden_size = 3584;
@ -101,7 +101,129 @@ namespace LLM {
std::vector<int> sliding_attention;
int64_t num_experts = 0;
int64_t num_experts_per_tok = 0;
LLMVisionParams vision;
LLMVisionConfig vision;
bool have_vision_weight = false;
bool llama_cpp_style = false;
/**
 * Build the LLMConfig for `arch` and refine it with what the weight map reveals.
 *
 * Two phases:
 *   1. Apply the fixed, architecture-specific hyperparameters.
 *   2. Scan the tensors accepted by starts_with(name, prefix) to derive num_layers,
 *      hidden_size, vocab_size and intermediate_size, and to record whether vision
 *      ("visual.") weights are present and whether they use the llama.cpp layout.
 *
 * @param tensor_storage_map  tensor name -> storage metadata of the loaded weights
 * @param prefix              name prefix selecting the tensors that belong to this model
 * @param arch                text model architecture to configure
 * @return the populated configuration
 */
static LLMConfig detect_from_weights(const String2TensorStorage& tensor_storage_map,
                                     const std::string& prefix,
                                     LLMArch arch) {
    LLMConfig config;
    config.arch = arch;

    // Phase 1: per-architecture constants. Architectures without a case keep the struct defaults.
    switch (arch) {
        case LLMArch::MISTRAL_SMALL_3_2:
        case LLMArch::MINISTRAL_3_3B:
            config.head_dim     = 128;
            config.num_heads    = 32;
            config.num_kv_heads = 8;
            config.qkv_bias     = false;
            config.rms_norm_eps = 1e-5f;
            break;

        case LLMArch::QWEN3:
        case LLMArch::QWEN3_VL:
            config.head_dim     = 128;
            config.num_heads    = 32;
            config.num_kv_heads = 8;
            config.qkv_bias     = false;
            config.qk_norm      = true;
            config.rms_norm_eps = 1e-6f;
            if (arch == LLMArch::QWEN3_VL) {
                config.max_position_embeddings = 262144;
                config.rope_thetas             = {5000000.f};
                config.vision.arch             = LLMVisionArch::QWEN3_VL;
            }
            break;

        case LLMArch::GEMMA3_12B:
            config.head_dim     = 256;
            config.num_heads    = 16;
            config.num_kv_heads = 8;
            config.qkv_bias     = false;
            config.qk_norm      = true;
            config.rms_norm_eps = 1e-6f;
            // llama.cpp adds +1 to Gemma3 norm.weight when exporting GGUF, so rms_norm_add
            // has to stay disabled here or the offset would be applied twice.
            config.rms_norm_add            = false;
            config.normalize_input         = true;
            config.max_position_embeddings = 131072;
            config.mlp_activation          = MLPActivation::GELU_TANH;
            config.rope_thetas             = {1000000.f, 10000.f};
            config.rope_scales             = {8.f, 1.f};
            config.sliding_attention       = {1024, 1024, 1024, 1024, 1024, 0};
            break;

        case LLMArch::GEMMA2_2B:
            config.head_dim                = 256;
            config.num_heads               = 8;
            config.num_kv_heads            = 4;
            config.qkv_bias                = false;
            config.qk_norm                 = false;
            config.rms_norm_eps            = 1e-6f;
            config.rms_norm_add            = true;
            config.normalize_input         = true;
            config.max_position_embeddings = 8192;
            config.mlp_activation          = MLPActivation::GELU_TANH;
            config.hidden_size             = 2304;
            config.intermediate_size       = 9216;
            config.num_layers              = 26;
            config.vocab_size              = 256000;
            break;

        case LLMArch::GPT_OSS_20B:
            config.head_dim                = 64;
            config.num_heads               = 64;
            config.num_kv_heads            = 8;
            config.qkv_bias                = true;
            config.attention_out_bias      = true;
            config.qk_norm                 = false;
            config.rms_norm_eps            = 1e-5f;
            config.hidden_size             = 2880;
            config.intermediate_size       = 2880;
            config.num_layers              = 24;
            config.vocab_size              = 201088;
            config.max_position_embeddings = 131072;
            config.rope_thetas             = {150000.f};
            config.rope_scales             = {32.f};
            config.sliding_attention       = {128, 0};
            config.num_experts             = 32;
            config.num_experts_per_tok     = 4;
            break;

        default:
            break;
    }

    // Phase 2: the layer count is always recomputed from the weights, so any num_layers
    // assigned above is discarded here.
    config.num_layers = 0;
    for (const auto& entry : tensor_storage_map) {
        const auto& tensor_name = entry.first;
        const auto& storage     = entry.second;

        if (!starts_with(tensor_name, prefix)) {
            continue;
        }

        // Vision tower tensors only set the vision flags; they never feed the text-model sizes.
        if (tensor_name.find("visual.") != std::string::npos) {
            config.have_vision_weight = true;
            if (contains(tensor_name, "attn.q_proj")) {
                config.llama_cpp_style = true;
            }
            continue;
        }

        // "layers.<index>." -> track the highest block index seen (+1 gives the layer count).
        const size_t layers_pos = tensor_name.find("layers.");
        if (layers_pos != std::string::npos) {
            const auto parts = split_string(tensor_name.substr(layers_pos), '.');
            if (parts.size() > 1) {
                const int layer_count = atoi(parts[1].c_str()) + 1;
                if (layer_count > config.num_layers) {
                    config.num_layers = layer_count;
                }
            }
        }

        // Sizes are taken from the tensor dimensions and override any architecture default.
        if (contains(tensor_name, "embed_tokens.weight")) {
            config.hidden_size = storage.ne[0];
            config.vocab_size  = storage.ne[1];
        }
        if (contains(tensor_name, "layers.0.mlp.gate_proj.weight")) {
            config.intermediate_size = storage.ne[1];
        }
        if (contains(tensor_name, "layers.0.mlp.experts.gate_up_proj.weight")) {
            // The fused gate+up projection is twice as wide as a single projection.
            config.intermediate_size = storage.ne[1] / 2;
        }
        if (contains(tensor_name, "layers.0.mlp.experts.gate_proj.weight")) {
            config.intermediate_size = storage.ne[1];
        }
    }

    // 28-layer Qwen3 checkpoints (the Qwen3 2B variant) use 16 attention heads instead of 32.
    if (arch == LLMArch::QWEN3 && config.num_layers == 28) {
        config.num_heads = 16;
    }

    LOG_DEBUG("llm: num_layers = %" PRId64 ", vocab_size = %" PRId64 ", hidden_size = %" PRId64 ", intermediate_size = %" PRId64,
              config.num_layers,
              config.vocab_size,
              config.hidden_size,
              config.intermediate_size);
    return config;
}
};
struct LLMRMSNorm : public UnaryBlock {
@ -232,11 +354,11 @@ namespace LLM {
}
public:
GPTOSSMLP(const LLMParams& params)
: hidden_size(params.hidden_size),
intermediate_size(params.intermediate_size),
num_experts(params.num_experts),
num_experts_per_tok(params.num_experts_per_tok) {}
GPTOSSMLP(const LLMConfig& config)
: hidden_size(config.hidden_size),
intermediate_size(config.intermediate_size),
num_experts(config.num_experts),
num_experts_per_tok(config.num_experts_per_tok) {}
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
// x: [N, n_token, hidden_size]
@ -667,7 +789,7 @@ namespace LLM {
public:
VisionModel(bool llama_cpp_style,
const LLMVisionParams& vision_params,
const LLMVisionConfig& vision_params,
float eps = 1e-6f)
: arch_(vision_params.arch),
num_layers(vision_params.num_layers),
@ -784,23 +906,23 @@ namespace LLM {
}
public:
Attention(const LLMParams& params)
: arch(params.arch),
num_heads(params.num_heads),
num_kv_heads(params.num_kv_heads),
head_dim(params.head_dim),
qk_norm(params.qk_norm),
max_position_embeddings(params.max_position_embeddings),
rope_thetas(params.rope_thetas),
rope_scales(params.rope_scales),
has_attention_sinks(params.arch == LLMArch::GPT_OSS_20B) {
blocks["q_proj"] = std::make_shared<Linear>(params.hidden_size, num_heads * head_dim, params.qkv_bias);
blocks["k_proj"] = std::make_shared<Linear>(params.hidden_size, num_kv_heads * head_dim, params.qkv_bias);
blocks["v_proj"] = std::make_shared<Linear>(params.hidden_size, num_kv_heads * head_dim, params.qkv_bias);
blocks["o_proj"] = std::make_shared<Linear>(num_heads * head_dim, params.hidden_size, params.attention_out_bias);
if (params.qk_norm) {
blocks["q_norm"] = std::make_shared<LLMRMSNorm>(head_dim, params.rms_norm_eps, params.rms_norm_add);
blocks["k_norm"] = std::make_shared<LLMRMSNorm>(head_dim, params.rms_norm_eps, params.rms_norm_add);
Attention(const LLMConfig& config)
: arch(config.arch),
num_heads(config.num_heads),
num_kv_heads(config.num_kv_heads),
head_dim(config.head_dim),
qk_norm(config.qk_norm),
max_position_embeddings(config.max_position_embeddings),
rope_thetas(config.rope_thetas),
rope_scales(config.rope_scales),
has_attention_sinks(config.arch == LLMArch::GPT_OSS_20B) {
blocks["q_proj"] = std::make_shared<Linear>(config.hidden_size, num_heads * head_dim, config.qkv_bias);
blocks["k_proj"] = std::make_shared<Linear>(config.hidden_size, num_kv_heads * head_dim, config.qkv_bias);
blocks["v_proj"] = std::make_shared<Linear>(config.hidden_size, num_kv_heads * head_dim, config.qkv_bias);
blocks["o_proj"] = std::make_shared<Linear>(num_heads * head_dim, config.hidden_size, config.attention_out_bias);
if (config.qk_norm) {
blocks["q_norm"] = std::make_shared<LLMRMSNorm>(head_dim, config.rms_norm_eps, config.rms_norm_add);
blocks["k_norm"] = std::make_shared<LLMRMSNorm>(head_dim, config.rms_norm_eps, config.rms_norm_add);
}
}
@ -982,42 +1104,42 @@ namespace LLM {
std::string post_ffw_norm_name;
public:
TransformerBlock(const LLMParams& params, int layer_index)
: arch(params.arch),
TransformerBlock(const LLMConfig& config, int layer_index)
: arch(config.arch),
sliding_attention(0) {
if (params.arch == LLMArch::GEMMA3_12B) {
if (config.arch == LLMArch::GEMMA3_12B) {
post_attention_norm_name = "post_attention_norm"; // attn_post_norm
pre_ffw_norm_name = "post_attention_layernorm"; // ffn_norm
post_ffw_norm_name = "post_ffw_norm"; // ffn_post_norm
} else if (params.arch == LLMArch::GEMMA2_2B) {
} else if (config.arch == LLMArch::GEMMA2_2B) {
post_attention_norm_name = "post_attention_layernorm"; // ffn_norm
pre_ffw_norm_name = "pre_feedforward_layernorm";
post_ffw_norm_name = "post_feedforward_layernorm";
} else if (params.arch == LLMArch::GPT_OSS_20B) {
} else if (config.arch == LLMArch::GPT_OSS_20B) {
pre_ffw_norm_name = "post_attention_norm"; // attn_post_norm
} else {
pre_ffw_norm_name = "post_attention_layernorm"; // ffn_norm
}
blocks["self_attn"] = std::make_shared<Attention>(params);
if (params.arch == LLMArch::GPT_OSS_20B) {
blocks["mlp"] = std::make_shared<GPTOSSMLP>(params);
blocks["self_attn"] = std::make_shared<Attention>(config);
if (config.arch == LLMArch::GPT_OSS_20B) {
blocks["mlp"] = std::make_shared<GPTOSSMLP>(config);
} else {
blocks["mlp"] = std::make_shared<MLP>(params.hidden_size,
params.intermediate_size,
blocks["mlp"] = std::make_shared<MLP>(config.hidden_size,
config.intermediate_size,
false,
params.mlp_activation);
config.mlp_activation);
}
blocks["input_layernorm"] = std::make_shared<LLMRMSNorm>(params.hidden_size, params.rms_norm_eps, params.rms_norm_add);
blocks[pre_ffw_norm_name] = std::make_shared<LLMRMSNorm>(params.hidden_size, params.rms_norm_eps, params.rms_norm_add);
blocks["input_layernorm"] = std::make_shared<LLMRMSNorm>(config.hidden_size, config.rms_norm_eps, config.rms_norm_add);
blocks[pre_ffw_norm_name] = std::make_shared<LLMRMSNorm>(config.hidden_size, config.rms_norm_eps, config.rms_norm_add);
if (!post_attention_norm_name.empty()) {
blocks[post_attention_norm_name] = std::make_shared<LLMRMSNorm>(params.hidden_size, params.rms_norm_eps, params.rms_norm_add);
blocks[post_attention_norm_name] = std::make_shared<LLMRMSNorm>(config.hidden_size, config.rms_norm_eps, config.rms_norm_add);
}
if (!post_ffw_norm_name.empty()) {
blocks[post_ffw_norm_name] = std::make_shared<LLMRMSNorm>(params.hidden_size, params.rms_norm_eps, params.rms_norm_add);
blocks[post_ffw_norm_name] = std::make_shared<LLMRMSNorm>(config.hidden_size, config.rms_norm_eps, config.rms_norm_add);
}
if (!params.sliding_attention.empty()) {
sliding_attention = params.sliding_attention[layer_index % params.sliding_attention.size()];
if (!config.sliding_attention.empty()) {
sliding_attention = config.sliding_attention[layer_index % config.sliding_attention.size()];
}
}
@ -1074,16 +1196,16 @@ namespace LLM {
struct TextModel : public GGMLBlock {
protected:
int64_t num_layers;
LLMParams params;
LLMConfig config;
public:
TextModel(const LLMParams& params)
: num_layers(params.num_layers), params(params) {
blocks["embed_tokens"] = std::shared_ptr<GGMLBlock>(new Embedding(params.vocab_size, params.hidden_size));
TextModel(const LLMConfig& config)
: num_layers(config.num_layers), config(config) {
blocks["embed_tokens"] = std::shared_ptr<GGMLBlock>(new Embedding(config.vocab_size, config.hidden_size));
for (int i = 0; i < num_layers; i++) {
blocks["layers." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new TransformerBlock(params, i));
blocks["layers." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new TransformerBlock(config, i));
}
blocks["norm"] = std::shared_ptr<GGMLBlock>(new LLMRMSNorm(params.hidden_size, params.rms_norm_eps, params.rms_norm_add));
blocks["norm"] = std::shared_ptr<GGMLBlock>(new LLMRMSNorm(config.hidden_size, config.rms_norm_eps, config.rms_norm_add));
}
ggml_tensor* embed(GGMLRunnerContext* ctx,
@ -1103,8 +1225,8 @@ namespace LLM {
auto norm = std::dynamic_pointer_cast<LLMRMSNorm>(blocks["norm"]);
std::vector<ggml_tensor*> intermediate_outputs;
if (params.normalize_input) {
x = ggml_ext_scale(ctx->ggml_ctx, x, std::sqrt(static_cast<float>(params.hidden_size)), true);
if (config.normalize_input) {
x = ggml_ext_scale(ctx->ggml_ctx, x, std::sqrt(static_cast<float>(config.hidden_size)), true);
}
if (return_all_hidden_states) {
intermediate_outputs.push_back(x);
@ -1174,15 +1296,15 @@ namespace LLM {
struct LLM : public GGMLBlock {
bool enable_vision;
LLMParams params;
LLMConfig config;
public:
LLM() = default;
LLM(LLMParams params, bool enable_vision = false, bool llama_cpp_style = false)
: enable_vision(enable_vision), params(params) {
blocks["model"] = std::shared_ptr<GGMLBlock>(new TextModel(params));
LLM(LLMConfig config, bool enable_vision = false, bool llama_cpp_style = false)
: enable_vision(enable_vision), config(config) {
blocks["model"] = std::shared_ptr<GGMLBlock>(new TextModel(config));
if (enable_vision) {
blocks["visual"] = std::shared_ptr<GGMLBlock>(new VisionModel(llama_cpp_style, params.vision));
blocks["visual"] = std::shared_ptr<GGMLBlock>(new VisionModel(llama_cpp_style, config.vision));
}
}
@ -1226,7 +1348,7 @@ namespace LLM {
};
struct LLMRunner : public GGMLRunner {
LLMParams params;
LLMConfig config;
bool enable_vision;
LLM model;
@ -1242,7 +1364,7 @@ namespace LLM {
static ggml_tensor* process_image_common(ggml_context* ctx,
ggml_tensor* image,
const LLMVisionParams& vision_params) {
const LLMVisionConfig& vision_params) {
// image: [C, H, W]
// return: [grid_t*(H/mh/ph)*(W/mw/pw)*mh*mw, C*pt*ph*pw], grid_t == 1
int64_t C = image->ne[2];
@ -1337,7 +1459,7 @@ namespace LLM {
ggml_context* compute_ctx,
GGMLRunnerContext* runner_ctx,
ggml_tensor* image,
const LLMVisionParams& vision_params,
const LLMVisionConfig& vision_params,
std::shared_ptr<VisionModel> vision_model,
std::vector<int>& window_index_vec,
std::vector<int>& window_inverse_index_vec,
@ -1452,141 +1574,25 @@ namespace LLM {
const String2TensorStorage& tensor_storage_map,
const std::string prefix,
bool enable_vision_ = false)
: GGMLRunner(backend, params_backend), enable_vision(enable_vision_) {
params.arch = arch;
if (arch == LLMArch::MISTRAL_SMALL_3_2 || arch == LLMArch::MINISTRAL_3_3B) {
params.head_dim = 128;
params.num_heads = 32;
params.num_kv_heads = 8;
params.qkv_bias = false;
params.rms_norm_eps = 1e-5f;
} else if (arch == LLMArch::QWEN3 || arch == LLMArch::QWEN3_VL) {
params.head_dim = 128;
params.num_heads = 32;
params.num_kv_heads = 8;
params.qkv_bias = false;
params.qk_norm = true;
params.rms_norm_eps = 1e-6f;
if (arch == LLMArch::QWEN3_VL) {
params.max_position_embeddings = 262144;
params.rope_thetas = {5000000.f};
params.vision.arch = LLMVisionArch::QWEN3_VL;
}
} else if (arch == LLMArch::GEMMA3_12B) {
params.head_dim = 256;
params.num_heads = 16;
params.num_kv_heads = 8;
params.qkv_bias = false;
params.qk_norm = true;
params.rms_norm_eps = 1e-6f;
// llama.cpp adds +1 to Gemma3 norm.weight when exporting GGUF, so GGUF loading
// must keep rms_norm_add disabled here or the offset gets applied twice.
// Convenient for the converter, less convenient for whoever gets to debug it later.
params.rms_norm_add = false;
params.normalize_input = true;
params.max_position_embeddings = 131072;
params.mlp_activation = MLPActivation::GELU_TANH;
params.rope_thetas = {1000000.f, 10000.f};
params.rope_scales = {8.f, 1.f};
params.sliding_attention = {1024, 1024, 1024, 1024, 1024, 0};
} else if (arch == LLMArch::GEMMA2_2B) {
params.head_dim = 256;
params.num_heads = 8;
params.num_kv_heads = 4;
params.qkv_bias = false;
params.qk_norm = false;
params.rms_norm_eps = 1e-6f;
params.rms_norm_add = true;
params.normalize_input = true;
params.max_position_embeddings = 8192;
params.mlp_activation = MLPActivation::GELU_TANH;
params.hidden_size = 2304;
params.intermediate_size = 9216;
params.num_layers = 26;
params.vocab_size = 256000;
} else if (arch == LLMArch::GPT_OSS_20B) {
params.head_dim = 64;
params.num_heads = 64;
params.num_kv_heads = 8;
params.qkv_bias = true;
params.attention_out_bias = true;
params.qk_norm = false;
params.rms_norm_eps = 1e-5f;
params.hidden_size = 2880;
params.intermediate_size = 2880;
params.num_layers = 24;
params.vocab_size = 201088;
params.max_position_embeddings = 131072;
params.rope_thetas = {150000.f};
params.rope_scales = {32.f};
params.sliding_attention = {128, 0};
params.num_experts = 32;
params.num_experts_per_tok = 4;
}
bool have_vision_weight = false;
bool llama_cpp_style = false;
params.num_layers = 0;
for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first;
if (tensor_name.find(prefix) == std::string::npos)
continue;
size_t pos = tensor_name.find("visual.");
if (pos != std::string::npos) {
have_vision_weight = true;
if (contains(tensor_name, "attn.q_proj")) {
llama_cpp_style = true;
}
continue;
}
pos = tensor_name.find("layers.");
if (pos != std::string::npos) {
tensor_name = tensor_name.substr(pos); // remove prefix
auto items = split_string(tensor_name, '.');
if (items.size() > 1) {
int block_index = atoi(items[1].c_str());
if (block_index + 1 > params.num_layers) {
params.num_layers = block_index + 1;
}
}
}
if (contains(tensor_name, "embed_tokens.weight")) {
params.hidden_size = pair.second.ne[0];
params.vocab_size = pair.second.ne[1];
}
if (contains(tensor_name, "layers.0.mlp.gate_proj.weight")) {
params.intermediate_size = pair.second.ne[1];
}
if (contains(tensor_name, "layers.0.mlp.experts.gate_up_proj.weight")) {
params.intermediate_size = pair.second.ne[1] / 2;
}
if (contains(tensor_name, "layers.0.mlp.experts.gate_proj.weight")) {
params.intermediate_size = pair.second.ne[1];
}
}
if (arch == LLMArch::QWEN3 && params.num_layers == 28) { // Qwen3 2B
params.num_heads = 16;
}
LOG_DEBUG("llm: num_layers = %" PRId64 ", vocab_size = %" PRId64 ", hidden_size = %" PRId64 ", intermediate_size = %" PRId64,
params.num_layers,
params.vocab_size,
params.hidden_size,
params.intermediate_size);
if (enable_vision && !have_vision_weight) {
: GGMLRunner(backend, params_backend),
config(LLMConfig::detect_from_weights(tensor_storage_map, prefix, arch)),
enable_vision(enable_vision_) {
if (enable_vision && !config.have_vision_weight) {
LOG_WARN("no vision weights detected, vision disabled");
enable_vision = false;
}
if (enable_vision) {
LOG_DEBUG("enable llm vision");
if (llama_cpp_style) {
if (config.llama_cpp_style) {
LOG_DEBUG("llama.cpp style vision weight");
}
}
model = LLM(params, enable_vision, llama_cpp_style);
model = LLM(config, enable_vision, config.llama_cpp_style);
model.init(params_ctx, tensor_storage_map, prefix);
}
std::string get_desc() override {
return llm_arch_to_str[static_cast<int>(params.arch)];
return llm_arch_to_str[static_cast<int>(config.arch)];
}
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) {
@ -1638,12 +1644,12 @@ namespace LLM {
}
int64_t n_tokens = input_ids->ne[0];
if (params.arch == LLMArch::MISTRAL_SMALL_3_2 ||
params.arch == LLMArch::MINISTRAL_3_3B ||
params.arch == LLMArch::QWEN3 ||
params.arch == LLMArch::GEMMA3_12B ||
params.arch == LLMArch::GEMMA2_2B ||
params.arch == LLMArch::GPT_OSS_20B) {
if (config.arch == LLMArch::MISTRAL_SMALL_3_2 ||
config.arch == LLMArch::MINISTRAL_3_3B ||
config.arch == LLMArch::QWEN3 ||
config.arch == LLMArch::GEMMA3_12B ||
config.arch == LLMArch::GEMMA2_2B ||
config.arch == LLMArch::GPT_OSS_20B) {
input_pos_vec.resize(n_tokens);
for (int i = 0; i < n_tokens; ++i) {
input_pos_vec[i] = i;
@ -1682,9 +1688,9 @@ namespace LLM {
set_backend_tensor_data(attention_mask, attention_mask_vec.data());
}
if (params.arch == LLMArch::GEMMA3_12B || params.arch == LLMArch::GPT_OSS_20B) {
if (config.arch == LLMArch::GEMMA3_12B || config.arch == LLMArch::GPT_OSS_20B) {
int sliding_window = 0;
for (int window : params.sliding_attention) {
for (int window : config.sliding_attention) {
sliding_window = std::max(sliding_window, window);
}
sliding_attention_mask_vec.resize(n_tokens * n_tokens);
@ -1740,15 +1746,15 @@ namespace LLM {
int64_t get_num_image_tokens(int64_t t, int64_t h, int64_t w) {
int64_t grid_t = 1;
int64_t grid_h = h / params.vision.patch_size;
int64_t grid_w = w / params.vision.patch_size;
int64_t llm_grid_h = grid_h / params.vision.spatial_merge_size;
int64_t llm_grid_w = grid_w / params.vision.spatial_merge_size;
int64_t grid_h = h / config.vision.patch_size;
int64_t grid_w = w / config.vision.patch_size;
int64_t llm_grid_h = grid_h / config.vision.spatial_merge_size;
int64_t llm_grid_w = grid_w / config.vision.spatial_merge_size;
return grid_t * grid_h * grid_w;
}
ggml_tensor* process_image(ggml_context* ctx, ggml_tensor* image) {
return process_image_common(ctx, image, params.vision);
return process_image_common(ctx, image, config.vision);
}
ggml_tensor* build_patch_pos_embeds(GGMLRunnerContext* runner_ctx,
@ -1770,7 +1776,7 @@ namespace LLM {
compute_ctx,
runner_ctx,
image,
params.vision,
config.vision,
model.vision_model(),
window_index_vec,
window_inverse_index_vec,
@ -1784,8 +1790,8 @@ namespace LLM {
ggml_cgraph* gf = new_graph_custom(LLM_GRAPH_SIZE);
ggml_tensor* image = make_input(image_tensor);
GGML_ASSERT(image->ne[1] % (params.vision.patch_size * params.vision.spatial_merge_size) == 0);
GGML_ASSERT(image->ne[0] % (params.vision.patch_size * params.vision.spatial_merge_size) == 0);
GGML_ASSERT(image->ne[1] % (config.vision.patch_size * config.vision.spatial_merge_size) == 0);
GGML_ASSERT(image->ne[0] % (config.vision.patch_size * config.vision.spatial_merge_size) == 0);
auto runnter_ctx = get_context();
ggml_tensor* hidden_states = encode_image(&runnter_ctx, image);

View File

@ -58,11 +58,12 @@ namespace LTXV {
return base_output_sample_rate();
}
static LTXAudioVAEConfig detect_from_weights(const String2TensorStorage& tensor_storage_map) {
static LTXAudioVAEConfig detect_from_weights(const String2TensorStorage& tensor_storage_map, const std::string& prefix = "") {
LTXAudioVAEConfig config;
auto require = [&](const std::string& name) -> const TensorStorage* {
auto iter = tensor_storage_map.find(name);
std::string tensor_name = prefix.empty() ? name : prefix + "." + name;
auto iter = tensor_storage_map.find(tensor_name);
if (iter == tensor_storage_map.end()) {
return nullptr;
}
@ -168,6 +169,12 @@ namespace LTXV {
if (config.audio_channels != 2 || config.latent_channels != 8 || config.mel_bins != 64) {
return config;
}
LOG_DEBUG("ltx_audio_vae: sample_rate = %d, mel_bins = %d, latent_channels = %d, latent_frequency_bins = %d, has_bwe = %s",
config.sample_rate,
config.mel_bins,
config.latent_channels,
config.latent_frequency_bins,
config.has_bwe ? "true" : "false");
return config;
}
};

View File

@ -72,6 +72,200 @@ namespace LTXV {
return max_block + 1;
}
struct LTXAVConfig {
int64_t in_channels = 128;
int64_t out_channels = 128;
int64_t hidden_size = 3840;
int64_t cross_attention_dim = 4096;
int64_t caption_channels = 3840;
int64_t num_attention_heads = 30;
int64_t attention_head_dim = 128;
int64_t num_layers = 28;
float positional_embedding_theta = 10000.f;
std::vector<int> positional_embedding_max_pos = {20, 2048, 2048};
std::tuple<int, int, int> vae_scale_factors = {8, 32, 32};
bool causal_temporal_positioning = true;
float timestep_scale_multiplier = 1000.f;
int64_t audio_in_channels = 128;
int64_t audio_out_channels = 128;
int64_t audio_hidden_size = 2048;
int64_t audio_cross_attention_dim = 2048;
int64_t audio_num_attention_heads = 32;
int64_t audio_attention_head_dim = 64;
std::vector<int> audio_positional_embedding_max_pos = {20};
float av_ca_timestep_scale_multiplier = 1000.f;
int64_t num_audio_channels = 8;
int64_t audio_frequency_bins = 16;
bool use_connector = false;
int64_t connector_hidden_size = 3840;
int64_t connector_num_heads = 30;
int64_t connector_head_dim = 128;
int64_t connector_num_layers = 2;
int64_t connector_num_registers = 128;
bool connector_rope_interleaved = false;
bool connector_apply_gated_attention = false;
bool use_audio_connector = false;
int64_t audio_connector_hidden_size = 2048;
int64_t audio_connector_num_heads = 32;
int64_t audio_connector_head_dim = 64;
int64_t audio_connector_num_layers = 2;
int64_t audio_connector_num_registers = 128;
bool audio_connector_rope_interleaved = false;
bool audio_connector_apply_gated_attention = false;
bool video_rope_interleaved = false;
bool use_middle_indices_grid = true;
bool cross_attention_adaln = false;
bool use_caption_projection = true;
bool use_audio_caption_projection = true;
bool caption_proj_before_connector = true;
bool caption_projection_first_linear = false;
bool self_attention_gated = false;
bool cross_attention_gated = false;
static std::pair<int64_t, int64_t> infer_attention_layout(int64_t hidden_size,
int64_t preferred_heads = -1) {
if (preferred_heads > 0 && hidden_size % preferred_heads == 0) {
return {preferred_heads, hidden_size / preferred_heads};
}
const int candidates[] = {128, 96, 80, 64, 48, 40, 32};
for (int head_dim : candidates) {
if (hidden_size % head_dim == 0) {
int64_t heads = hidden_size / head_dim;
if (heads >= 8 && heads <= 64) {
return {heads, head_dim};
}
}
}
return {32, hidden_size / 32};
}
static int64_t infer_gate_heads(const String2TensorStorage& tensor_storage_map,
const std::string& bias_name,
int64_t fallback_heads) {
auto it = tensor_storage_map.find(bias_name);
if (it != tensor_storage_map.end()) {
return it->second.ne[0];
}
return fallback_heads;
}
static LTXAVConfig detect_from_weights(const String2TensorStorage& tensor_storage_map, const std::string& prefix) {
LTXAVConfig config;
auto patchify_proj_iter = tensor_storage_map.find(prefix + ".patchify_proj.weight");
if (patchify_proj_iter != tensor_storage_map.end()) {
config.in_channels = patchify_proj_iter->second.ne[0];
config.hidden_size = patchify_proj_iter->second.ne[1];
int64_t video_heads = infer_gate_heads(tensor_storage_map, prefix + ".transformer_blocks.0.attn1.to_gate_logits.bias", 32);
auto attn_layout = infer_attention_layout(config.hidden_size, video_heads);
config.num_attention_heads = attn_layout.first;
config.attention_head_dim = attn_layout.second;
}
auto audio_patchify_proj_iter = tensor_storage_map.find(prefix + ".audio_patchify_proj.weight");
if (audio_patchify_proj_iter != tensor_storage_map.end()) {
config.audio_in_channels = audio_patchify_proj_iter->second.ne[0];
config.audio_hidden_size = audio_patchify_proj_iter->second.ne[1];
config.audio_out_channels = config.audio_in_channels;
int64_t audio_heads = infer_gate_heads(tensor_storage_map, prefix + ".transformer_blocks.0.audio_attn1.to_gate_logits.bias", 32);
auto audio_attn_layout = infer_attention_layout(config.audio_hidden_size, audio_heads);
config.audio_num_attention_heads = audio_attn_layout.first;
config.audio_attention_head_dim = audio_attn_layout.second;
}
auto proj_out_iter = tensor_storage_map.find(prefix + ".proj_out.weight");
if (proj_out_iter != tensor_storage_map.end()) {
config.out_channels = proj_out_iter->second.ne[1];
}
auto audio_proj_out_iter = tensor_storage_map.find(prefix + ".audio_proj_out.weight");
if (audio_proj_out_iter != tensor_storage_map.end()) {
config.audio_out_channels = audio_proj_out_iter->second.ne[1];
}
auto attn2_iter = tensor_storage_map.find(prefix + ".transformer_blocks.0.attn2.to_k.weight");
if (attn2_iter != tensor_storage_map.end()) {
config.cross_attention_dim = attn2_iter->second.ne[0];
}
auto audio_attn2_iter = tensor_storage_map.find(prefix + ".transformer_blocks.0.audio_attn2.to_k.weight");
if (audio_attn2_iter != tensor_storage_map.end()) {
config.audio_cross_attention_dim = audio_attn2_iter->second.ne[0];
}
if (tensor_storage_map.find(prefix + ".transformer_blocks.0.prompt_scale_shift_table") != tensor_storage_map.end()) {
config.cross_attention_adaln = true;
}
if (tensor_storage_map.find(prefix + ".transformer_blocks.0.attn1.to_gate_logits.weight") != tensor_storage_map.end() ||
tensor_storage_map.find(prefix + ".transformer_blocks.0.audio_attn1.to_gate_logits.weight") != tensor_storage_map.end()) {
config.self_attention_gated = true;
}
if (tensor_storage_map.find(prefix + ".transformer_blocks.0.attn2.to_gate_logits.weight") != tensor_storage_map.end() ||
tensor_storage_map.find(prefix + ".transformer_blocks.0.audio_attn2.to_gate_logits.weight") != tensor_storage_map.end()) {
config.cross_attention_gated = true;
}
if (tensor_storage_map.find(prefix + ".caption_projection.linear_1.weight") == tensor_storage_map.end() &&
tensor_storage_map.find(prefix + ".caption_projection.linear_2.weight") == tensor_storage_map.end()) {
config.use_caption_projection = false;
}
if (tensor_storage_map.find(prefix + ".audio_caption_projection.linear_1.weight") == tensor_storage_map.end() &&
tensor_storage_map.find(prefix + ".audio_caption_projection.linear_2.weight") == tensor_storage_map.end()) {
config.use_audio_caption_projection = false;
}
config.num_layers = count_prefix_blocks(tensor_storage_map, prefix + ".", "transformer_blocks.");
auto connector_iter = tensor_storage_map.find(prefix + ".video_embeddings_connector.transformer_1d_blocks.0.attn1.to_q.weight");
if (connector_iter != tensor_storage_map.end()) {
config.use_connector = true;
config.connector_hidden_size = connector_iter->second.ne[1];
int64_t connector_heads = infer_gate_heads(tensor_storage_map,
prefix + ".video_embeddings_connector.transformer_1d_blocks.0.attn1.to_gate_logits.bias",
32);
auto connector_layout = infer_attention_layout(config.connector_hidden_size, connector_heads);
config.connector_num_heads = connector_layout.first;
config.connector_head_dim = connector_layout.second;
config.connector_num_layers = count_prefix_blocks(tensor_storage_map, prefix + ".video_embeddings_connector.", "transformer_1d_blocks.");
auto register_iter = tensor_storage_map.find(prefix + ".video_embeddings_connector.learnable_registers");
if (register_iter != tensor_storage_map.end()) {
config.connector_num_registers = register_iter->second.ne[1];
}
if (tensor_storage_map.find(prefix + ".video_embeddings_connector.transformer_1d_blocks.0.attn1.to_gate_logits.weight") != tensor_storage_map.end()) {
config.connector_apply_gated_attention = true;
}
}
auto audio_connector_iter = tensor_storage_map.find(prefix + ".audio_embeddings_connector.transformer_1d_blocks.0.attn1.to_q.weight");
if (audio_connector_iter != tensor_storage_map.end()) {
config.use_audio_connector = true;
config.audio_connector_hidden_size = audio_connector_iter->second.ne[1];
int64_t connector_heads = infer_gate_heads(tensor_storage_map,
prefix + ".audio_embeddings_connector.transformer_1d_blocks.0.attn1.to_gate_logits.bias",
32);
auto connector_layout = infer_attention_layout(config.audio_connector_hidden_size, connector_heads);
config.audio_connector_num_heads = connector_layout.first;
config.audio_connector_head_dim = connector_layout.second;
config.audio_connector_num_layers = count_prefix_blocks(tensor_storage_map, prefix + ".audio_embeddings_connector.", "transformer_1d_blocks.");
auto register_iter = tensor_storage_map.find(prefix + ".audio_embeddings_connector.learnable_registers");
if (register_iter != tensor_storage_map.end()) {
config.audio_connector_num_registers = register_iter->second.ne[1];
}
if (tensor_storage_map.find(prefix + ".audio_embeddings_connector.transformer_1d_blocks.0.attn1.to_gate_logits.weight") != tensor_storage_map.end()) {
config.audio_connector_apply_gated_attention = true;
}
}
LOG_DEBUG("ltxav: num_layers = %" PRId64 ", hidden_size = %" PRId64 ", num_attention_heads = %" PRId64 ", audio_hidden_size = %" PRId64 ", audio_num_attention_heads = %" PRId64,
config.num_layers,
config.hidden_size,
config.num_attention_heads,
config.audio_hidden_size,
config.audio_num_attention_heads);
return config;
}
};
__STATIC_INLINE__ std::vector<float> generate_freq_grid(float theta,
int positional_dims,
int dim) {
@ -749,63 +943,6 @@ namespace LTXV {
}
};
// Plain bundle of hyperparameters for the LTX audio/video transformer.
// Every field carries a default value; the struct has no methods of its own.
// Keep the member order stable: reordering would change positional (aggregate) initialization.
struct LTXAVParams {
// Main stream (presumably the video branch - TODO confirm): channels, width, attention shape.
int64_t in_channels = 128;
int64_t out_channels = 128;
int64_t hidden_size = 3840;
int64_t cross_attention_dim = 4096;
int64_t caption_channels = 3840;
int64_t num_attention_heads = 30;
int64_t attention_head_dim = 128;
int64_t num_layers = 28;
// Positional-embedding and timestep settings for the main stream.
float positional_embedding_theta = 10000.f;
std::vector<int> positional_embedding_max_pos = {20, 2048, 2048};
std::tuple<int, int, int> vae_scale_factors = {8, 32, 32};
bool causal_temporal_positioning = true;
float timestep_scale_multiplier = 1000.f;
// Audio stream counterparts of the fields above.
int64_t audio_in_channels = 128;
int64_t audio_out_channels = 128;
int64_t audio_hidden_size = 2048;
int64_t audio_cross_attention_dim = 2048;
int64_t audio_num_attention_heads = 32;
int64_t audio_attention_head_dim = 64;
std::vector<int> audio_positional_embedding_max_pos = {20};
float av_ca_timestep_scale_multiplier = 1000.f;
int64_t num_audio_channels = 8;
int64_t audio_frequency_bins = 16;
// Connector settings; disabled by default.
bool use_connector = false;
int64_t connector_hidden_size = 3840;
int64_t connector_num_heads = 30;
int64_t connector_head_dim = 128;
int64_t connector_num_layers = 2;
int64_t connector_num_registers = 128;
bool connector_rope_interleaved = false;
bool connector_apply_gated_attention = false;
// Audio connector settings; disabled by default.
bool use_audio_connector = false;
int64_t audio_connector_hidden_size = 2048;
int64_t audio_connector_num_heads = 32;
int64_t audio_connector_head_dim = 64;
int64_t audio_connector_num_layers = 2;
int64_t audio_connector_num_registers = 128;
bool audio_connector_rope_interleaved = false;
bool audio_connector_apply_gated_attention = false;
// Feature switches.
bool video_rope_interleaved = false;
bool use_middle_indices_grid = true;
bool cross_attention_adaln = false;
bool use_caption_projection = true;
bool use_audio_caption_projection = true;
bool caption_proj_before_connector = true;
bool caption_projection_first_linear = false;
bool self_attention_gated = false;
bool cross_attention_gated = false;
};
__STATIC_INLINE__ std::pair<int64_t, int64_t> infer_attention_layout(int64_t hidden_size,
int64_t preferred_heads = -1) {
if (preferred_heads > 0 && hidden_size % preferred_heads == 0) {
@ -1169,92 +1306,92 @@ namespace LTXV {
};
struct LTXAVModelBlock : public GGMLBlock {
LTXAVParams cfg;
LTXAVConfig config;
// Registers the two parameter tensors owned directly by this block:
// "scale_shift_table" and "audio_scale_shift_table".
//
// ctx                - ggml context in which the tensors are created.
// tensor_storage_map - passed to get_type() together with the prefixed tensor
//                      name and GGML_TYPE_F32 (presumably the fallback type
//                      when the name is not in the map; confirm in get_type).
// prefix             - prepended to the tensor names for the type lookup.
//
// Both tensors are 2-D with a second dimension of 2; the first dimension is
// the video hidden size and the audio hidden size respectively.
//
// NOTE(review): each call lists the width argument twice, once as `cfg.*` and
// once as `config.*`, on consecutive lines. Confirm which spelling belongs in
// the final file.
void init_params(ggml_context* ctx,
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "") override {
// Video-side table, sized by the video hidden width.
params["scale_shift_table"] = ggml_new_tensor_2d(ctx,
get_type(prefix + "scale_shift_table", tensor_storage_map, GGML_TYPE_F32),
cfg.hidden_size,
config.hidden_size,
2);
// Audio-side table, sized by the audio hidden width.
params["audio_scale_shift_table"] = ggml_new_tensor_2d(ctx,
get_type(prefix + "audio_scale_shift_table", tensor_storage_map, GGML_TYPE_F32),
cfg.audio_hidden_size,
config.audio_hidden_size,
2);
}
// Constructor: stores the model configuration and registers every sub-block
// in `blocks` under its name.
//
// NOTE(review): this span carries two renditions of the same constructor, one
// taking `LTXAVParams` and reading `cfg`, one taking `LTXAVConfig` and reading
// `config`. They are interleaved statement group by statement group. Both
// register the same block names with the same arguments; only the spelling of
// the configuration object differs.
//
// Registered blocks, in order:
//   - patchify_proj / audio_patchify_proj: Linear, in_channels -> hidden_size
//   - adaln_single / audio_adaln_single: AdaLayerNormSingle, second argument
//     9 when cross_attention_adaln is set, otherwise 6
//   - prompt_adaln_single / audio_prompt_adaln_single: only when
//     cross_attention_adaln is set
//   - four av_ca_* AdaLayerNormSingle blocks (second argument 4 or 1)
//   - caption_projection / audio_caption_projection (conditional, see below)
//   - video_embeddings_connector / audio_embeddings_connector (conditional)
//   - transformer_blocks.<i> for i in [0, num_layers)
//   - norm_out / proj_out / audio_norm_out / audio_proj_out
LTXAVModelBlock(const LTXAVParams& params)
: cfg(params) {
blocks["patchify_proj"] = std::make_shared<Linear>(cfg.in_channels, cfg.hidden_size, true, true);
blocks["audio_patchify_proj"] = std::make_shared<Linear>(cfg.audio_in_channels, cfg.audio_hidden_size, true, true);
blocks["adaln_single"] = std::make_shared<AdaLayerNormSingle>(cfg.hidden_size, cfg.cross_attention_adaln ? 9 : 6);
blocks["audio_adaln_single"] = std::make_shared<AdaLayerNormSingle>(cfg.audio_hidden_size, cfg.cross_attention_adaln ? 9 : 6);
if (cfg.cross_attention_adaln) {
blocks["prompt_adaln_single"] = std::make_shared<AdaLayerNormSingle>(cfg.hidden_size, 2);
blocks["audio_prompt_adaln_single"] = std::make_shared<AdaLayerNormSingle>(cfg.audio_hidden_size, 2);
// `config` rendition of the constructor head and the statements above.
LTXAVModelBlock(const LTXAVConfig& config)
: config(config) {
blocks["patchify_proj"] = std::make_shared<Linear>(config.in_channels, config.hidden_size, true, true);
blocks["audio_patchify_proj"] = std::make_shared<Linear>(config.audio_in_channels, config.audio_hidden_size, true, true);
blocks["adaln_single"] = std::make_shared<AdaLayerNormSingle>(config.hidden_size, config.cross_attention_adaln ? 9 : 6);
blocks["audio_adaln_single"] = std::make_shared<AdaLayerNormSingle>(config.audio_hidden_size, config.cross_attention_adaln ? 9 : 6);
if (config.cross_attention_adaln) {
blocks["prompt_adaln_single"] = std::make_shared<AdaLayerNormSingle>(config.hidden_size, 2);
blocks["audio_prompt_adaln_single"] = std::make_shared<AdaLayerNormSingle>(config.audio_hidden_size, 2);
}
// AdaLayerNormSingle blocks for the audio<->video cross-attention path
// (`cfg` rendition first, then `config`).
blocks["av_ca_video_scale_shift_adaln_single"] = std::make_shared<AdaLayerNormSingle>(cfg.hidden_size, 4);
blocks["av_ca_a2v_gate_adaln_single"] = std::make_shared<AdaLayerNormSingle>(cfg.hidden_size, 1);
blocks["av_ca_audio_scale_shift_adaln_single"] = std::make_shared<AdaLayerNormSingle>(cfg.audio_hidden_size, 4);
blocks["av_ca_v2a_gate_adaln_single"] = std::make_shared<AdaLayerNormSingle>(cfg.audio_hidden_size, 1);
blocks["av_ca_video_scale_shift_adaln_single"] = std::make_shared<AdaLayerNormSingle>(config.hidden_size, 4);
blocks["av_ca_a2v_gate_adaln_single"] = std::make_shared<AdaLayerNormSingle>(config.hidden_size, 1);
blocks["av_ca_audio_scale_shift_adaln_single"] = std::make_shared<AdaLayerNormSingle>(config.audio_hidden_size, 4);
blocks["av_ca_v2a_gate_adaln_single"] = std::make_shared<AdaLayerNormSingle>(config.audio_hidden_size, 1);
// Video caption projection. When the projection runs before the connector,
// a block is only registered if caption_projection_first_linear is set
// (NormSingleLinearTextProjection); otherwise "caption_projection" stays
// absent. When it runs after the connector, a PixArtAlphaTextProjection is
// registered.
if (cfg.use_caption_projection) {
if (cfg.caption_proj_before_connector) {
if (cfg.caption_projection_first_linear) {
blocks["caption_projection"] = std::make_shared<NormSingleLinearTextProjection>(cfg.caption_channels, cfg.hidden_size);
if (config.use_caption_projection) {
if (config.caption_proj_before_connector) {
if (config.caption_projection_first_linear) {
blocks["caption_projection"] = std::make_shared<NormSingleLinearTextProjection>(config.caption_channels, config.hidden_size);
}
} else {
blocks["caption_projection"] = std::make_shared<PixArtAlphaTextProjection>(cfg.caption_channels, cfg.hidden_size, cfg.hidden_size);
blocks["caption_projection"] = std::make_shared<PixArtAlphaTextProjection>(config.caption_channels, config.hidden_size, config.hidden_size);
}
}
// Audio caption projection: same decision tree as the video one, but it
// targets audio_hidden_size.
if (cfg.use_audio_caption_projection) {
if (cfg.caption_proj_before_connector) {
if (cfg.caption_projection_first_linear) {
blocks["audio_caption_projection"] = std::make_shared<NormSingleLinearTextProjection>(cfg.caption_channels, cfg.audio_hidden_size);
if (config.use_audio_caption_projection) {
if (config.caption_proj_before_connector) {
if (config.caption_projection_first_linear) {
blocks["audio_caption_projection"] = std::make_shared<NormSingleLinearTextProjection>(config.caption_channels, config.audio_hidden_size);
}
} else {
blocks["audio_caption_projection"] = std::make_shared<PixArtAlphaTextProjection>(cfg.caption_channels, cfg.audio_hidden_size, cfg.audio_hidden_size);
blocks["audio_caption_projection"] = std::make_shared<PixArtAlphaTextProjection>(config.caption_channels, config.audio_hidden_size, config.audio_hidden_size);
}
}
// Optional video-side embeddings connector, built from the connector_* fields.
if (cfg.use_connector) {
blocks["video_embeddings_connector"] = std::make_shared<Embeddings1DConnector>(cfg.connector_hidden_size,
cfg.connector_num_heads,
cfg.connector_head_dim,
cfg.connector_num_layers,
cfg.connector_num_registers,
cfg.connector_rope_interleaved,
cfg.connector_apply_gated_attention);
if (config.use_connector) {
blocks["video_embeddings_connector"] = std::make_shared<Embeddings1DConnector>(config.connector_hidden_size,
config.connector_num_heads,
config.connector_head_dim,
config.connector_num_layers,
config.connector_num_registers,
config.connector_rope_interleaved,
config.connector_apply_gated_attention);
}
// Optional audio-side embeddings connector, built from the audio_connector_* fields.
if (cfg.use_audio_connector) {
blocks["audio_embeddings_connector"] = std::make_shared<Embeddings1DConnector>(cfg.audio_connector_hidden_size,
cfg.audio_connector_num_heads,
cfg.audio_connector_head_dim,
cfg.audio_connector_num_layers,
cfg.audio_connector_num_registers,
cfg.audio_connector_rope_interleaved,
cfg.audio_connector_apply_gated_attention);
if (config.use_audio_connector) {
blocks["audio_embeddings_connector"] = std::make_shared<Embeddings1DConnector>(config.audio_connector_hidden_size,
config.audio_connector_num_heads,
config.audio_connector_head_dim,
config.audio_connector_num_layers,
config.audio_connector_num_registers,
config.audio_connector_rope_interleaved,
config.audio_connector_apply_gated_attention);
}
// One BasicAVTransformerBlock per layer. The single gating argument is the
// OR of self_attention_gated and cross_attention_gated, so enabling either
// flag passes `true` to every block.
for (int i = 0; i < cfg.num_layers; i++) {
blocks["transformer_blocks." + std::to_string(i)] = std::make_shared<BasicAVTransformerBlock>(cfg.hidden_size,
cfg.audio_hidden_size,
cfg.num_attention_heads,
cfg.audio_num_attention_heads,
cfg.attention_head_dim,
cfg.audio_attention_head_dim,
cfg.cross_attention_dim,
cfg.audio_cross_attention_dim,
cfg.self_attention_gated || cfg.cross_attention_gated,
cfg.cross_attention_adaln,
cfg.video_rope_interleaved);
for (int i = 0; i < config.num_layers; i++) {
blocks["transformer_blocks." + std::to_string(i)] = std::make_shared<BasicAVTransformerBlock>(config.hidden_size,
config.audio_hidden_size,
config.num_attention_heads,
config.audio_num_attention_heads,
config.attention_head_dim,
config.audio_attention_head_dim,
config.cross_attention_dim,
config.audio_cross_attention_dim,
config.self_attention_gated || config.cross_attention_gated,
config.cross_attention_adaln,
config.video_rope_interleaved);
}
// Output heads: LayerNorm followed by a Linear projection to the output
// channel count, for the video and audio streams.
blocks["norm_out"] = std::make_shared<LayerNorm>(cfg.hidden_size, 1e-6f, false);
blocks["proj_out"] = std::make_shared<Linear>(cfg.hidden_size, cfg.out_channels, true, true);
blocks["audio_norm_out"] = std::make_shared<LayerNorm>(cfg.audio_hidden_size, 1e-6f, false);
blocks["audio_proj_out"] = std::make_shared<Linear>(cfg.audio_hidden_size, cfg.audio_out_channels, true, true);
blocks["norm_out"] = std::make_shared<LayerNorm>(config.hidden_size, 1e-6f, false);
blocks["proj_out"] = std::make_shared<Linear>(config.hidden_size, config.out_channels, true, true);
blocks["audio_norm_out"] = std::make_shared<LayerNorm>(config.audio_hidden_size, 1e-6f, false);
blocks["audio_proj_out"] = std::make_shared<Linear>(config.audio_hidden_size, config.audio_out_channels, true, true);
}
ggml_tensor* patchify_video(GGMLRunnerContext* ctx, ggml_tensor* x, int64_t n) {
@ -1293,8 +1430,8 @@ namespace LTXV {
if (ax == nullptr) {
return nullptr;
}
ax = ggml_reshape_4d(ctx->ggml_ctx, ax, cfg.audio_frequency_bins, cfg.num_audio_channels, audio_length, ax->ne[2]); // [b, t, c, f]
ax = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, ax, 0, 2, 1, 3)); // [b, c, t, f]
ax = ggml_reshape_4d(ctx->ggml_ctx, ax, config.audio_frequency_bins, config.num_audio_channels, audio_length, ax->ne[2]); // [b, t, c, f]
ax = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, ax, 0, 2, 1, 3)); // [b, c, t, f]
return ax;
}
@ -1308,17 +1445,17 @@ namespace LTXV {
}
bool is_fully_processed_context =
context->ne[0] == cfg.cross_attention_dim + cfg.audio_cross_attention_dim &&
context->ne[0] == config.cross_attention_dim + config.audio_cross_attention_dim &&
context->ne[1] >= 1024;
bool is_unprocessed_dual_context =
context->ne[0] == cfg.cross_attention_dim + cfg.audio_cross_attention_dim &&
context->ne[0] == config.cross_attention_dim + config.audio_cross_attention_dim &&
context->ne[1] < 1024;
if (is_fully_processed_context) {
auto v_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, 0, cfg.cross_attention_dim);
auto v_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, 0, config.cross_attention_dim);
ggml_tensor* a_context = nullptr;
if (process_audio_context) {
a_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, cfg.cross_attention_dim, cfg.cross_attention_dim + cfg.audio_cross_attention_dim);
a_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, config.cross_attention_dim, config.cross_attention_dim + config.audio_cross_attention_dim);
}
return {v_context, a_context};
}
@ -1326,32 +1463,32 @@ namespace LTXV {
ggml_tensor* v_context = context;
ggml_tensor* a_context = process_audio_context ? context : nullptr;
if (is_unprocessed_dual_context) {
v_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, 0, cfg.cross_attention_dim);
v_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, 0, config.cross_attention_dim);
if (process_audio_context) {
a_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, cfg.cross_attention_dim, cfg.cross_attention_dim + cfg.audio_cross_attention_dim);
a_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, config.cross_attention_dim, config.cross_attention_dim + config.audio_cross_attention_dim);
}
} else if (context->ne[0] == cfg.caption_channels * 2) {
v_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, 0, cfg.caption_channels);
} else if (context->ne[0] == config.caption_channels * 2) {
v_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, 0, config.caption_channels);
if (process_audio_context) {
a_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, cfg.caption_channels, cfg.caption_channels * 2);
a_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, config.caption_channels, config.caption_channels * 2);
}
}
if (cfg.caption_proj_before_connector) {
if (cfg.use_caption_projection &&
if (config.caption_proj_before_connector) {
if (config.use_caption_projection &&
blocks.count("caption_projection") > 0 &&
v_context != nullptr &&
v_context->ne[0] == cfg.caption_channels) {
v_context->ne[0] == config.caption_channels) {
auto caption_projection = std::dynamic_pointer_cast<NormSingleLinearTextProjection>(blocks["caption_projection"]);
if (caption_projection != nullptr) {
v_context = caption_projection->forward(ctx, v_context);
}
}
if (process_audio_context &&
cfg.use_audio_caption_projection &&
config.use_audio_caption_projection &&
blocks.count("audio_caption_projection") > 0 &&
a_context != nullptr &&
a_context->ne[0] == cfg.caption_channels) {
a_context->ne[0] == config.caption_channels) {
auto caption_projection = std::dynamic_pointer_cast<NormSingleLinearTextProjection>(blocks["audio_caption_projection"]);
if (caption_projection != nullptr) {
a_context = caption_projection->forward(ctx, a_context);
@ -1359,34 +1496,34 @@ namespace LTXV {
}
}
if (cfg.use_connector && v_context != nullptr && v_context->ne[0] == cfg.connector_hidden_size) {
if (config.use_connector && v_context != nullptr && v_context->ne[0] == config.connector_hidden_size) {
auto connector = std::dynamic_pointer_cast<Embeddings1DConnector>(blocks["video_embeddings_connector"]);
v_context = connector->forward(ctx, v_context, video_connector_pe);
}
if (process_audio_context &&
cfg.use_audio_connector &&
config.use_audio_connector &&
a_context != nullptr &&
a_context->ne[0] == cfg.audio_connector_hidden_size) {
a_context->ne[0] == config.audio_connector_hidden_size) {
auto connector = std::dynamic_pointer_cast<Embeddings1DConnector>(blocks["audio_embeddings_connector"]);
a_context = connector->forward(ctx, a_context, audio_connector_pe);
}
if (!cfg.caption_proj_before_connector &&
cfg.use_caption_projection &&
if (!config.caption_proj_before_connector &&
config.use_caption_projection &&
blocks.count("caption_projection") > 0 &&
v_context != nullptr &&
v_context->ne[0] == cfg.caption_channels) {
v_context->ne[0] == config.caption_channels) {
auto caption_projection = std::dynamic_pointer_cast<PixArtAlphaTextProjection>(blocks["caption_projection"]);
if (caption_projection != nullptr) {
v_context = caption_projection->forward(ctx, v_context);
}
}
if (process_audio_context &&
!cfg.caption_proj_before_connector &&
cfg.use_audio_caption_projection &&
!config.caption_proj_before_connector &&
config.use_audio_caption_projection &&
blocks.count("audio_caption_projection") > 0 &&
a_context != nullptr &&
a_context->ne[0] == cfg.caption_channels) {
a_context->ne[0] == config.caption_channels) {
auto caption_projection = std::dynamic_pointer_cast<PixArtAlphaTextProjection>(blocks["audio_caption_projection"]);
if (caption_projection != nullptr) {
a_context = caption_projection->forward(ctx, a_context);
@ -1428,8 +1565,8 @@ namespace LTXV {
auto audio_norm_out = std::dynamic_pointer_cast<LayerNorm>(blocks["audio_norm_out"]);
auto audio_proj_out = std::dynamic_pointer_cast<Linear>(blocks["audio_proj_out"]);
GGML_ASSERT(vx->ne[3] % cfg.in_channels == 0);
int64_t n = vx->ne[3] / cfg.in_channels;
GGML_ASSERT(vx->ne[3] % config.in_channels == 0);
int64_t n = vx->ne[3] / config.in_channels;
int64_t width = vx->ne[0];
int64_t height = vx->ne[1];
int64_t frames = vx->ne[2];
@ -1452,20 +1589,20 @@ namespace LTXV {
a_context = ggml_cont(ctx->ggml_ctx, a_context);
}
auto v_timestep_scaled = ggml_ext_scale(ctx->ggml_ctx, timestep, cfg.timestep_scale_multiplier);
auto v_timestep_scaled = ggml_ext_scale(ctx->ggml_ctx, timestep, config.timestep_scale_multiplier);
auto v_pair = adaln_single->forward(ctx, v_timestep_scaled);
auto v_timestep_mod = v_pair.first;
auto v_embedded_time = v_pair.second;
ggml_tensor* effective_audio_timestep = audio_timestep != nullptr ? audio_timestep : timestep;
auto a_timestep_scaled = ggml_ext_scale(ctx->ggml_ctx, effective_audio_timestep, cfg.timestep_scale_multiplier);
auto a_timestep_scaled = ggml_ext_scale(ctx->ggml_ctx, effective_audio_timestep, config.timestep_scale_multiplier);
auto a_pair = audio_adaln_single->forward(ctx, a_timestep_scaled);
auto a_timestep_mod = a_pair.first;
auto a_embedded_time = a_pair.second;
ggml_tensor* v_prompt_timestep_mod = nullptr;
ggml_tensor* a_prompt_timestep_mod = nullptr;
if (cfg.cross_attention_adaln) {
if (config.cross_attention_adaln) {
auto prompt_adaln_single = std::dynamic_pointer_cast<AdaLayerNormSingle>(blocks["prompt_adaln_single"]);
auto audio_prompt_adaln_single = std::dynamic_pointer_cast<AdaLayerNormSingle>(blocks["audio_prompt_adaln_single"]);
v_prompt_timestep_mod = prompt_adaln_single->forward(ctx, a_timestep_scaled).first;
@ -1474,7 +1611,7 @@ namespace LTXV {
auto av_ca_video_timestep = repeat_scalar_timestep_like(ctx, effective_audio_timestep, timestep);
auto av_ca_audio_timestep = effective_audio_timestep;
auto av_ca_factor = cfg.av_ca_timestep_scale_multiplier / cfg.timestep_scale_multiplier;
auto av_ca_factor = config.av_ca_timestep_scale_multiplier / config.timestep_scale_multiplier;
auto av_ca_video_scale_shift_timestep =
std::dynamic_pointer_cast<AdaLayerNormSingle>(blocks["av_ca_video_scale_shift_adaln_single"])->forward(ctx, av_ca_video_timestep).first;
auto av_ca_a2v_gate_noise_timestep =
@ -1491,7 +1628,7 @@ namespace LTXV {
sd::ggml_graph_cut::mark_graph_cut(vx, "ltxav.prelude", "vx");
sd::ggml_graph_cut::mark_graph_cut(ax, "ltxav.prelude", "ax");
for (int i = 0; i < cfg.num_layers; i++) {
for (int i = 0; i < config.num_layers; i++) {
auto block = std::dynamic_pointer_cast<BasicAVTransformerBlock>(blocks["transformer_blocks." + std::to_string(i)]);
auto out = block->forward(ctx,
vx,
@ -1517,14 +1654,14 @@ namespace LTXV {
sd::ggml_graph_cut::mark_graph_cut(ax, "ltxav.transformer_blocks." + std::to_string(i), "ax");
}
auto v_shift_scale = get_output_scale_shift(ctx, params["scale_shift_table"], v_embedded_time, cfg.hidden_size);
auto v_shift_scale = get_output_scale_shift(ctx, params["scale_shift_table"], v_embedded_time, config.hidden_size);
vx = norm_out->forward(ctx, vx);
vx = modulate(ctx->ggml_ctx, vx, v_shift_scale[0], v_shift_scale[1]);
vx = proj_out->forward(ctx, vx);
vx = unpatchify_video(ctx, vx, width, height, frames);
if (ax != nullptr && audio_time > 0) {
auto a_shift_scale = get_output_scale_shift(ctx, params["audio_scale_shift_table"], a_embedded_time, cfg.audio_hidden_size);
auto a_shift_scale = get_output_scale_shift(ctx, params["audio_scale_shift_table"], a_embedded_time, config.audio_hidden_size);
ax = audio_norm_out->forward(ctx, ax);
ax = modulate(ctx->ggml_ctx, ax, a_shift_scale[0], a_shift_scale[1]);
ax = audio_proj_out->forward(ctx, ax);
@ -1536,7 +1673,7 @@ namespace LTXV {
};
struct LTXAVRunner : public DiffusionModelRunner {
LTXAVParams params;
LTXAVConfig config;
LTXAVModelBlock model;
std::vector<float> video_pe_vec;
std::vector<float> audio_pe_vec;
@ -1547,124 +1684,13 @@ namespace LTXV {
sd::Tensor<float> vx_input_cache;
sd::Tensor<float> ax_input_cache;
// Looks up the tensor named `bias_name` in `tensor_storage_map` and returns
// its first dimension (ne[0]) as the head count. When no such tensor exists,
// `fallback_heads` is returned unchanged.
static int64_t infer_gate_heads(const String2TensorStorage& tensor_storage_map,
                                const std::string& bias_name,
                                int64_t fallback_heads) {
    const auto entry = tensor_storage_map.find(bias_name);
    if (entry == tensor_storage_map.end()) {
        // Tensor absent: nothing to infer from.
        return fallback_heads;
    }
    return entry->second.ne[0];
}
// Constructs the runner: forwards the backends and prefix to
// DiffusionModelRunner, derives the model configuration from the tensors
// listed in `tensor_storage_map`, builds the model block and initialises its
// parameters through model.init().
//
// backend / params_backend - forwarded to DiffusionModelRunner.
// tensor_storage_map       - metadata of the loaded weights, keyed by name.
// prefix                   - name prefix of the diffusion model's tensors
//                            (defaults to "model.diffusion_model").
//
// NOTE(review): this span carries two renditions of the constructor,
// interleaved. The first default-constructs `params`, builds `model` from
// those defaults, probes the weight map inline to fill `params`, and then
// re-assigns `model`. The second obtains the whole configuration from
// LTXAVConfig::detect_from_weights() in the initializer list.
LTXAVRunner(ggml_backend_t backend,
ggml_backend_t params_backend,
const String2TensorStorage& tensor_storage_map = {},
const std::string& prefix = "model.diffusion_model")
: DiffusionModelRunner(backend, params_backend, prefix),
params(),
model(params) {
// Video patchify projection: ne[0] is taken as in_channels and ne[1] as
// hidden_size. The head layout comes from the gate-logits bias length
// (fallback 32) via infer_attention_layout().
auto patchify_proj_iter = tensor_storage_map.find(prefix + ".patchify_proj.weight");
if (patchify_proj_iter != tensor_storage_map.end()) {
params.in_channels = patchify_proj_iter->second.ne[0];
params.hidden_size = patchify_proj_iter->second.ne[1];
int64_t video_heads = infer_gate_heads(tensor_storage_map, prefix + ".transformer_blocks.0.attn1.to_gate_logits.bias", 32);
auto attn_layout = infer_attention_layout(params.hidden_size, video_heads);
params.num_attention_heads = attn_layout.first;
params.attention_head_dim = attn_layout.second;
}
// Audio patchify projection: same scheme. audio_out_channels is first set
// equal to audio_in_channels and may be overridden by audio_proj_out below.
auto audio_patchify_proj_iter = tensor_storage_map.find(prefix + ".audio_patchify_proj.weight");
if (audio_patchify_proj_iter != tensor_storage_map.end()) {
params.audio_in_channels = audio_patchify_proj_iter->second.ne[0];
params.audio_hidden_size = audio_patchify_proj_iter->second.ne[1];
params.audio_out_channels = params.audio_in_channels;
int64_t audio_heads = infer_gate_heads(tensor_storage_map, prefix + ".transformer_blocks.0.audio_attn1.to_gate_logits.bias", 32);
auto audio_attn_layout = infer_attention_layout(params.audio_hidden_size, audio_heads);
params.audio_num_attention_heads = audio_attn_layout.first;
params.audio_attention_head_dim = audio_attn_layout.second;
}
// Output channel counts come from ne[1] of the output projections.
auto proj_out_iter = tensor_storage_map.find(prefix + ".proj_out.weight");
if (proj_out_iter != tensor_storage_map.end()) {
params.out_channels = proj_out_iter->second.ne[1];
}
auto audio_proj_out_iter = tensor_storage_map.find(prefix + ".audio_proj_out.weight");
if (audio_proj_out_iter != tensor_storage_map.end()) {
params.audio_out_channels = audio_proj_out_iter->second.ne[1];
}
// Cross-attention widths come from ne[0] of the first block's to_k weights.
auto attn2_iter = tensor_storage_map.find(prefix + ".transformer_blocks.0.attn2.to_k.weight");
if (attn2_iter != tensor_storage_map.end()) {
params.cross_attention_dim = attn2_iter->second.ne[0];
}
auto audio_attn2_iter = tensor_storage_map.find(prefix + ".transformer_blocks.0.audio_attn2.to_k.weight");
if (audio_attn2_iter != tensor_storage_map.end()) {
params.audio_cross_attention_dim = audio_attn2_iter->second.ne[0];
}
// Feature flags switched on by the mere presence of specific tensors.
if (tensor_storage_map.find(prefix + ".transformer_blocks.0.prompt_scale_shift_table") != tensor_storage_map.end()) {
params.cross_attention_adaln = true;
}
if (tensor_storage_map.find(prefix + ".transformer_blocks.0.attn1.to_gate_logits.weight") != tensor_storage_map.end() ||
tensor_storage_map.find(prefix + ".transformer_blocks.0.audio_attn1.to_gate_logits.weight") != tensor_storage_map.end()) {
params.self_attention_gated = true;
}
if (tensor_storage_map.find(prefix + ".transformer_blocks.0.attn2.to_gate_logits.weight") != tensor_storage_map.end() ||
tensor_storage_map.find(prefix + ".transformer_blocks.0.audio_attn2.to_gate_logits.weight") != tensor_storage_map.end()) {
params.cross_attention_gated = true;
}
// Caption projections default to enabled. They are switched off only when
// neither linear_1 nor linear_2 of the projection is present.
if (tensor_storage_map.find(prefix + ".caption_projection.linear_1.weight") == tensor_storage_map.end() &&
tensor_storage_map.find(prefix + ".caption_projection.linear_2.weight") == tensor_storage_map.end()) {
params.use_caption_projection = false;
}
if (tensor_storage_map.find(prefix + ".audio_caption_projection.linear_1.weight") == tensor_storage_map.end() &&
tensor_storage_map.find(prefix + ".audio_caption_projection.linear_2.weight") == tensor_storage_map.end()) {
params.use_audio_caption_projection = false;
}
// NOTE(review): num_layers is assigned unconditionally from
// count_prefix_blocks(). Confirm what that helper returns when the map has
// no "transformer_blocks." entries, since that would replace the default.
params.num_layers = count_prefix_blocks(tensor_storage_map, prefix + ".", "transformer_blocks.");
// Video embeddings connector: enabled when its first block's to_q weight
// exists. Width, head layout, layer count, register count and gating are
// read from the connector's own tensors.
auto connector_iter = tensor_storage_map.find(prefix + ".video_embeddings_connector.transformer_1d_blocks.0.attn1.to_q.weight");
if (connector_iter != tensor_storage_map.end()) {
params.use_connector = true;
params.connector_hidden_size = connector_iter->second.ne[1];
int64_t connector_heads = infer_gate_heads(tensor_storage_map,
prefix + ".video_embeddings_connector.transformer_1d_blocks.0.attn1.to_gate_logits.bias",
32);
auto connector_layout = infer_attention_layout(params.connector_hidden_size, connector_heads);
params.connector_num_heads = connector_layout.first;
params.connector_head_dim = connector_layout.second;
params.connector_num_layers = count_prefix_blocks(tensor_storage_map, prefix + ".video_embeddings_connector.", "transformer_1d_blocks.");
auto register_iter = tensor_storage_map.find(prefix + ".video_embeddings_connector.learnable_registers");
if (register_iter != tensor_storage_map.end()) {
params.connector_num_registers = register_iter->second.ne[1];
}
if (tensor_storage_map.find(prefix + ".video_embeddings_connector.transformer_1d_blocks.0.attn1.to_gate_logits.weight") != tensor_storage_map.end()) {
params.connector_apply_gated_attention = true;
}
}
// Audio embeddings connector: mirrors the video connector detection.
auto audio_connector_iter = tensor_storage_map.find(prefix + ".audio_embeddings_connector.transformer_1d_blocks.0.attn1.to_q.weight");
if (audio_connector_iter != tensor_storage_map.end()) {
params.use_audio_connector = true;
params.audio_connector_hidden_size = audio_connector_iter->second.ne[1];
int64_t connector_heads = infer_gate_heads(tensor_storage_map,
prefix + ".audio_embeddings_connector.transformer_1d_blocks.0.attn1.to_gate_logits.bias",
32);
auto connector_layout = infer_attention_layout(params.audio_connector_hidden_size, connector_heads);
params.audio_connector_num_heads = connector_layout.first;
params.audio_connector_head_dim = connector_layout.second;
params.audio_connector_num_layers = count_prefix_blocks(tensor_storage_map, prefix + ".audio_embeddings_connector.", "transformer_1d_blocks.");
auto register_iter = tensor_storage_map.find(prefix + ".audio_embeddings_connector.learnable_registers");
if (register_iter != tensor_storage_map.end()) {
params.audio_connector_num_registers = register_iter->second.ne[1];
}
if (tensor_storage_map.find(prefix + ".audio_embeddings_connector.transformer_1d_blocks.0.attn1.to_gate_logits.weight") != tensor_storage_map.end()) {
params.audio_connector_apply_gated_attention = true;
}
}
// First rendition: rebuild the model now that `params` reflects the weights.
model = LTXAVModelBlock(params);
// Second rendition: initializer-list tail that gets the configuration from
// LTXAVConfig::detect_from_weights() and builds `model` from it directly.
config(LTXAVConfig::detect_from_weights(tensor_storage_map, prefix)),
model(config) {
model.init(params_ctx, tensor_storage_map, prefix);
}
@ -1692,21 +1718,21 @@ namespace LTXV {
int64_t total_channels = x_tensor.shape()[3];
int64_t spatial_size = width * height * frames;
GGML_ASSERT(total_channels >= params.in_channels);
GGML_ASSERT(total_channels >= config.in_channels);
sd::Tensor<float> vx({width, height, frames, params.in_channels});
size_t video_values = static_cast<size_t>(params.in_channels * spatial_size);
sd::Tensor<float> vx({width, height, frames, config.in_channels});
size_t video_values = static_cast<size_t>(config.in_channels * spatial_size);
std::copy_n(x_tensor.data(), video_values, vx.data());
if (audio_length <= 0 || total_channels == params.in_channels) {
if (audio_length <= 0 || total_channels == config.in_channels) {
return {vx, {}};
}
int64_t needed_audio_values = static_cast<int64_t>(audio_length) * params.num_audio_channels * params.audio_frequency_bins;
int64_t packed_audio_values = (total_channels - params.in_channels) * spatial_size;
int64_t needed_audio_values = static_cast<int64_t>(audio_length) * config.num_audio_channels * config.audio_frequency_bins;
int64_t packed_audio_values = (total_channels - config.in_channels) * spatial_size;
GGML_ASSERT(packed_audio_values >= needed_audio_values);
sd::Tensor<float> ax({params.audio_frequency_bins, audio_length, params.num_audio_channels, 1});
sd::Tensor<float> ax({config.audio_frequency_bins, audio_length, config.num_audio_channels, 1});
const float* audio_src = x_tensor.data() + video_values;
std::copy_n(audio_src, static_cast<size_t>(needed_audio_values), ax.data());
return {vx, ax};
@ -1767,25 +1793,25 @@ namespace LTXV {
if (has_video_positions) {
GGML_ASSERT(video_positions_tensor.shape()[2] == video_token_count);
video_pe_vec = build_video_rope_matrix_from_positions(video_positions_tensor,
static_cast<int>(params.hidden_size),
static_cast<int>(params.num_attention_heads),
params.positional_embedding_theta,
params.positional_embedding_max_pos,
params.use_middle_indices_grid);
static_cast<int>(config.hidden_size),
static_cast<int>(config.num_attention_heads),
config.positional_embedding_theta,
config.positional_embedding_max_pos,
config.use_middle_indices_grid);
} else {
video_pe_vec = build_video_rope_matrix(vx->ne[0],
vx->ne[1],
vx->ne[2],
static_cast<int>(params.hidden_size),
static_cast<int>(params.num_attention_heads),
static_cast<int>(config.hidden_size),
static_cast<int>(config.num_attention_heads),
video_frame_rate,
params.positional_embedding_theta,
params.positional_embedding_max_pos,
params.vae_scale_factors,
params.causal_temporal_positioning,
params.use_middle_indices_grid);
config.positional_embedding_theta,
config.positional_embedding_max_pos,
config.vae_scale_factors,
config.causal_temporal_positioning,
config.use_middle_indices_grid);
}
auto video_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, params.attention_head_dim / 2, video_token_count * params.num_attention_heads);
auto video_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.attention_head_dim / 2, video_token_count * config.num_attention_heads);
ggml_set_name(video_pe, "ltxav_video_pe");
set_backend_tensor_data(video_pe, video_pe_vec.data());
@ -1794,66 +1820,66 @@ namespace LTXV {
ggml_tensor* audio_cross_pe = nullptr;
if (ax != nullptr && ggml_nelements(ax) > 0 && ax->ne[1] > 0) {
audio_pe_vec = build_audio_rope_matrix(ax->ne[1],
static_cast<int>(params.audio_hidden_size),
static_cast<int>(params.audio_num_attention_heads),
params.positional_embedding_theta,
params.audio_positional_embedding_max_pos[0],
params.use_middle_indices_grid);
audio_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, params.audio_attention_head_dim / 2, ax->ne[1] * params.audio_num_attention_heads);
static_cast<int>(config.audio_hidden_size),
static_cast<int>(config.audio_num_attention_heads),
config.positional_embedding_theta,
config.audio_positional_embedding_max_pos[0],
config.use_middle_indices_grid);
audio_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.audio_attention_head_dim / 2, ax->ne[1] * config.audio_num_attention_heads);
ggml_set_name(audio_pe, "ltxav_audio_pe");
set_backend_tensor_data(audio_pe, audio_pe_vec.data());
int temporal_max_pos = std::max(params.positional_embedding_max_pos[0], params.audio_positional_embedding_max_pos[0]);
int temporal_max_pos = std::max(config.positional_embedding_max_pos[0], config.audio_positional_embedding_max_pos[0]);
if (has_video_positions) {
video_cross_pe_vec = build_video_temporal_rope_matrix_from_positions(video_positions_tensor,
static_cast<int>(params.audio_cross_attention_dim),
static_cast<int>(params.audio_num_attention_heads),
params.positional_embedding_theta,
static_cast<int>(config.audio_cross_attention_dim),
static_cast<int>(config.audio_num_attention_heads),
config.positional_embedding_theta,
temporal_max_pos,
true);
} else {
video_cross_pe_vec = build_video_temporal_rope_matrix(vx->ne[0],
vx->ne[1],
vx->ne[2],
static_cast<int>(params.audio_cross_attention_dim),
static_cast<int>(params.audio_num_attention_heads),
static_cast<int>(config.audio_cross_attention_dim),
static_cast<int>(config.audio_num_attention_heads),
video_frame_rate,
params.positional_embedding_theta,
config.positional_embedding_theta,
temporal_max_pos,
std::get<0>(params.vae_scale_factors),
params.causal_temporal_positioning,
std::get<0>(config.vae_scale_factors),
config.causal_temporal_positioning,
true);
}
video_cross_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, params.audio_attention_head_dim / 2, video_token_count * params.audio_num_attention_heads);
video_cross_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.audio_attention_head_dim / 2, video_token_count * config.audio_num_attention_heads);
ggml_set_name(video_cross_pe, "ltxav_video_cross_pe");
set_backend_tensor_data(video_cross_pe, video_cross_pe_vec.data());
audio_cross_pe_vec = build_audio_rope_matrix(ax->ne[1],
static_cast<int>(params.audio_cross_attention_dim),
static_cast<int>(params.audio_num_attention_heads),
params.positional_embedding_theta,
static_cast<int>(config.audio_cross_attention_dim),
static_cast<int>(config.audio_num_attention_heads),
config.positional_embedding_theta,
temporal_max_pos,
true);
audio_cross_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, params.audio_attention_head_dim / 2, ax->ne[1] * params.audio_num_attention_heads);
audio_cross_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.audio_attention_head_dim / 2, ax->ne[1] * config.audio_num_attention_heads);
ggml_set_name(audio_cross_pe, "ltxav_audio_cross_pe");
set_backend_tensor_data(audio_cross_pe, audio_cross_pe_vec.data());
}
bool needs_video_connector_pe =
params.use_connector &&
config.use_connector &&
context != nullptr &&
(context->ne[0] == params.connector_hidden_size ||
((context->ne[0] == params.cross_attention_dim + params.audio_cross_attention_dim ||
context->ne[0] == params.caption_channels * 2) &&
(context->ne[0] == config.connector_hidden_size ||
((context->ne[0] == config.cross_attention_dim + config.audio_cross_attention_dim ||
context->ne[0] == config.caption_channels * 2) &&
context->ne[1] < 1024));
ggml_tensor* video_connector_pe = nullptr;
if (needs_video_connector_pe) {
int64_t seq_len = context->ne[1];
int64_t target_len = std::max<int64_t>(1024, seq_len);
int64_t duplications = (target_len + params.connector_num_registers - 1) / params.connector_num_registers;
int64_t full_len = seq_len + duplications * params.connector_num_registers - seq_len;
connector_pe_vec = build_1d_rope_matrix(full_len, static_cast<int>(params.connector_hidden_size), static_cast<int>(params.connector_num_heads), 10000.f, 4096.f, true);
video_connector_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, params.connector_head_dim / 2, full_len * params.connector_num_heads);
int64_t duplications = (target_len + config.connector_num_registers - 1) / config.connector_num_registers;
int64_t full_len = seq_len + duplications * config.connector_num_registers - seq_len;
connector_pe_vec = build_1d_rope_matrix(full_len, static_cast<int>(config.connector_hidden_size), static_cast<int>(config.connector_num_heads), 10000.f, 4096.f, true);
video_connector_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.connector_head_dim / 2, full_len * config.connector_num_heads);
ggml_set_name(video_connector_pe, "ltxav_video_connector_pe");
set_backend_tensor_data(video_connector_pe, connector_pe_vec.data());
}
@ -1864,20 +1890,20 @@ namespace LTXV {
ax->ne[1] > 0;
bool needs_audio_connector_pe =
run_audio_context &&
params.use_audio_connector &&
config.use_audio_connector &&
context != nullptr &&
(context->ne[0] == params.audio_connector_hidden_size ||
((context->ne[0] == params.cross_attention_dim + params.audio_cross_attention_dim ||
context->ne[0] == params.caption_channels * 2) &&
(context->ne[0] == config.audio_connector_hidden_size ||
((context->ne[0] == config.cross_attention_dim + config.audio_cross_attention_dim ||
context->ne[0] == config.caption_channels * 2) &&
context->ne[1] < 1024));
ggml_tensor* audio_connector_pe = nullptr;
if (needs_audio_connector_pe) {
int64_t seq_len = context->ne[1];
int64_t target_len = std::max<int64_t>(1024, seq_len);
int64_t duplications = (target_len + params.audio_connector_num_registers - 1) / params.audio_connector_num_registers;
int64_t full_len = seq_len + duplications * params.audio_connector_num_registers - seq_len;
audio_connector_pe_vec = build_1d_rope_matrix(full_len, static_cast<int>(params.audio_connector_hidden_size), static_cast<int>(params.audio_connector_num_heads), 10000.f, 4096.f, true);
audio_connector_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, params.audio_connector_head_dim / 2, full_len * params.audio_connector_num_heads);
int64_t duplications = (target_len + config.audio_connector_num_registers - 1) / config.audio_connector_num_registers;
int64_t full_len = seq_len + duplications * config.audio_connector_num_registers - seq_len;
audio_connector_pe_vec = build_1d_rope_matrix(full_len, static_cast<int>(config.audio_connector_hidden_size), static_cast<int>(config.audio_connector_num_heads), 10000.f, 4096.f, true);
audio_connector_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.audio_connector_head_dim / 2, full_len * config.audio_connector_num_heads);
ggml_set_name(audio_connector_pe, "ltxav_audio_connector_pe");
set_backend_tensor_data(audio_connector_pe, audio_connector_pe_vec.data());
}

View File

@ -1,7 +1,10 @@
#ifndef __MMDIT_HPP__
#define __MMDIT_HPP__
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "diffusion_model.hpp"
#include "ggml_extend.hpp"
@ -9,6 +12,128 @@
#define MMDIT_GRAPH_SIZE 10240
// Hyper-parameters of the MMDiT / MMDiT-X diffusion transformer, plus the
// logic that derives them from the shapes and names of the loaded weights.
struct MMDiTConfig {
    int64_t input_size               = -1;
    int patch_size                   = 2;
    int64_t in_channels              = 16;
    int64_t d_self                   = -1;  // >=0 for MMDiT-X: index of the last joint block that has an attn2 tensor
    int64_t depth                    = 24;
    float mlp_ratio                  = 4.0f;
    int64_t adm_in_channels          = 2048;
    int64_t out_channels             = 16;
    int64_t pos_embed_max_size       = 192;
    int64_t num_patches              = 36864;  // 192 * 192
    int64_t context_size             = 4096;
    int64_t context_embedder_out_dim = 1536;
    int64_t hidden_size              = 1536;
    std::string qk_norm;  // "", "ln" or "rms"

    // Scans every tensor whose name starts with `prefix` and overrides the
    // defaults above with whatever can be read from tensor shapes/names.
    //
    // The result does not depend on the iteration order of
    // `tensor_storage_map`: values that combine information from several
    // tensors (out_channels, qk_norm) are resolved in an order-independent way.
    //
    // NOTE(review): `depth` only ever grows from its default of 24, so a
    // checkpoint with fewer joint blocks still reports depth = 24.
    static MMDiTConfig detect_from_weights(const String2TensorStorage& tensor_storage_map, const std::string& prefix) {
        MMDiTConfig config;
        bool has_weight_config = false;
        bool has_pos_embed     = false;
        bool has_hidden_size   = false;
        bool has_context_embed = false;
        bool has_final_layer   = false;
        // ne[1] of final_layer.linear.weight. It is divided by patch_size^2
        // only after the scan, once patch_size itself has been detected.
        int64_t final_layer_out_dim = 0;

        const std::string joint_blocks_tag = "joint_blocks.";

        for (const auto& [name, tensor_storage] : tensor_storage_map) {
            if (!starts_with(name, prefix)) {
                continue;
            }

            if (name.find("x_embedder.proj.weight") != std::string::npos && tensor_storage.n_dims == 4) {
                has_weight_config  = true;
                has_hidden_size    = true;
                config.patch_size  = static_cast<int>(tensor_storage.ne[0]);
                config.in_channels = tensor_storage.ne[2];
                config.hidden_size = tensor_storage.ne[3];
            } else if (name.find("t_embedder.mlp.0.weight") != std::string::npos && tensor_storage.n_dims == 2) {
                has_weight_config  = true;
                has_hidden_size    = true;
                config.hidden_size = tensor_storage.ne[1];
            } else if (name.find("y_embedder.mlp.0.weight") != std::string::npos && tensor_storage.n_dims == 2) {
                has_weight_config      = true;
                has_hidden_size        = true;
                config.adm_in_channels = tensor_storage.ne[0];
                config.hidden_size     = tensor_storage.ne[1];
            } else if (name.find("context_embedder.weight") != std::string::npos && tensor_storage.n_dims == 2) {
                has_weight_config               = true;
                has_context_embed               = true;
                config.context_size             = tensor_storage.ne[0];
                config.context_embedder_out_dim = tensor_storage.ne[1];
            } else if (name.find("final_layer.linear.weight") != std::string::npos && tensor_storage.n_dims == 2) {
                has_weight_config   = true;
                has_hidden_size     = true;
                has_final_layer     = true;
                config.hidden_size  = tensor_storage.ne[0];
                final_layer_out_dim = tensor_storage.ne[1];
            } else if (name.find("pos_embed") != std::string::npos && tensor_storage.n_dims == 3) {
                has_weight_config  = true;
                has_pos_embed      = true;
                has_hidden_size    = true;
                config.hidden_size = tensor_storage.ne[0];
                config.num_patches = tensor_storage.ne[1];
                // pos_embed_max_size is the integer square root of num_patches;
                // it keeps its previous value when num_patches is not a perfect square.
                for (int64_t size = 1; size * size <= config.num_patches; size++) {
                    if (size * size == config.num_patches) {
                        config.pos_embed_max_size = size;
                        break;
                    }
                }
            }

            size_t jb = name.find(joint_blocks_tag);
            if (jb == std::string::npos) {
                continue;
            }
            has_weight_config = true;

            // block_name looks like "joint_blocks.<index>.<rest>".
            std::string block_name = name.substr(jb);
            size_t index_begin     = joint_blocks_tag.size();
            size_t index_end       = block_name.find('.', index_begin);
            // substr() takes a length, not an end position.
            size_t index_len    = index_end == std::string::npos ? std::string::npos : index_end - index_begin;
            int64_t block_depth = atoi(block_name.substr(index_begin, index_len).c_str());
            if (block_depth + 1 > config.depth) {
                config.depth = block_depth + 1;
            }

            if (block_name.find("attn.ln") != std::string::npos) {
                if (block_name.find(".bias") != std::string::npos) {
                    config.qk_norm = "ln";
                } else if (config.qk_norm != "ln") {
                    // A bias tensor marks the norm as "ln". A model that has one
                    // also has matching non-bias tensors, so a non-bias tensor
                    // must never downgrade an "ln" that was already detected.
                    config.qk_norm = "rms";
                }
            }

            if (block_name.find("attn2") != std::string::npos) {
                if (block_depth > config.d_self) {
                    config.d_self = block_depth;
                }
            }
        }

        if (has_final_layer) {
            // Uses the final patch_size, regardless of whether x_embedder was
            // visited before or after final_layer during the scan.
            int64_t patch_area = static_cast<int64_t>(config.patch_size) * config.patch_size;
            if (patch_area > 0) {
                config.out_channels = final_layer_out_dim / patch_area;
            }
        }

        // Without a pos_embed tensor, MMDiT-X models fall back to a default
        // position table that is twice as large per side.
        if (!has_pos_embed && config.d_self >= 0) {
            config.pos_embed_max_size *= 2;
            config.num_patches *= 4;
        }
        if (!has_hidden_size || config.hidden_size <= 0) {
            config.hidden_size = 64 * config.depth;
        }
        if (!has_context_embed || config.context_embedder_out_dim <= 0) {
            config.context_embedder_out_dim = config.hidden_size;
        }

        if (has_weight_config) {
            LOG_DEBUG("mmdit: num_layers = %" PRId64 ", num_mmdit_x_layers = %" PRId64 ", hidden_size = %" PRId64 ", patch_size = %d, in_channels = %" PRId64 ", out_channels = %" PRId64 ", context_size = %" PRId64 ", adm_in_channels = %" PRId64 ", qk_norm = %s",
                      config.depth,
                      config.d_self + 1,
                      config.hidden_size,
                      config.patch_size,
                      config.in_channels,
                      config.out_channels,
                      config.context_size,
                      config.adm_in_channels,
                      config.qk_norm.empty() ? "none" : config.qk_norm.c_str());
        }
        return config;
    }
};
struct Mlp : public GGMLBlock {
public:
Mlp(int64_t in_features,
@ -612,28 +737,16 @@ public:
struct MMDiT : public GGMLBlock {
// Diffusion model with a Transformer backbone.
protected:
int64_t input_size = -1;
int patch_size = 2;
int64_t in_channels = 16;
int64_t d_self = -1; // >=0 for MMdiT-X
int64_t depth = 24;
float mlp_ratio = 4.0f;
int64_t adm_in_channels = 2048;
int64_t out_channels = 16;
int64_t pos_embed_max_size = 192;
int64_t num_patchs = 36864; // 192 * 192
int64_t context_size = 4096;
int64_t context_embedder_out_dim = 1536;
int64_t hidden_size;
std::string qk_norm;
void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override {
enum ggml_type wtype = GGML_TYPE_F32;
params["pos_embed"] = ggml_new_tensor_3d(ctx, wtype, hidden_size, num_patchs, 1);
params["pos_embed"] = ggml_new_tensor_3d(ctx, wtype, config.hidden_size, config.num_patches, 1);
}
public:
MMDiT(const String2TensorStorage& tensor_storage_map = {}) {
MMDiTConfig config;
explicit MMDiT(MMDiTConfig config = {})
: config(config) {
// input_size is always None
// learn_sigma is always False
// register_length is always 0
@ -646,64 +759,30 @@ public:
// pos_embed_offset is not used
// context_embedder_config is always {'target': 'torch.nn.Linear', 'params': {'in_features': 4096, 'out_features': 1536}}
for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first;
if (tensor_name.find("model.diffusion_model.") == std::string::npos)
continue;
size_t jb = tensor_name.find("joint_blocks.");
if (jb != std::string::npos) {
tensor_name = tensor_name.substr(jb); // remove prefix
int block_depth = atoi(tensor_name.substr(13, tensor_name.find(".", 13)).c_str());
if (block_depth + 1 > depth) {
depth = block_depth + 1;
}
if (tensor_name.find("attn.ln") != std::string::npos) {
if (tensor_name.find(".bias") != std::string::npos) {
qk_norm = "ln";
} else {
qk_norm = "rms";
}
}
if (tensor_name.find("attn2") != std::string::npos) {
if (block_depth > d_self) {
d_self = block_depth;
}
}
}
blocks["x_embedder"] = std::shared_ptr<GGMLBlock>(new PatchEmbed(config.input_size,
config.patch_size,
config.in_channels,
config.hidden_size,
true));
blocks["t_embedder"] = std::shared_ptr<GGMLBlock>(new TimestepEmbedder(config.hidden_size));
if (config.adm_in_channels != -1) {
blocks["y_embedder"] = std::shared_ptr<GGMLBlock>(new VectorEmbedder(config.adm_in_channels, config.hidden_size));
}
if (d_self >= 0) {
pos_embed_max_size *= 2;
num_patchs *= 4;
}
blocks["context_embedder"] = std::shared_ptr<GGMLBlock>(new Linear(config.context_size, config.context_embedder_out_dim, true, true));
LOG_INFO("MMDiT layers: %d (including %d MMDiT-x layers)", depth, d_self + 1);
int64_t default_out_channels = in_channels;
hidden_size = 64 * depth;
context_embedder_out_dim = 64 * depth;
int64_t num_heads = depth;
blocks["x_embedder"] = std::shared_ptr<GGMLBlock>(new PatchEmbed(input_size, patch_size, in_channels, hidden_size, true));
blocks["t_embedder"] = std::shared_ptr<GGMLBlock>(new TimestepEmbedder(hidden_size));
if (adm_in_channels != -1) {
blocks["y_embedder"] = std::shared_ptr<GGMLBlock>(new VectorEmbedder(adm_in_channels, hidden_size));
}
blocks["context_embedder"] = std::shared_ptr<GGMLBlock>(new Linear(4096, context_embedder_out_dim, true, true));
for (int i = 0; i < depth; i++) {
blocks["joint_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new JointBlock(hidden_size,
num_heads,
mlp_ratio,
qk_norm,
for (int i = 0; i < config.depth; i++) {
blocks["joint_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new JointBlock(config.hidden_size,
config.depth,
config.mlp_ratio,
config.qk_norm,
true,
i == depth - 1,
i <= d_self));
i == config.depth - 1,
i <= config.d_self));
}
blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new FinalLayer(hidden_size, patch_size, out_channels));
blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new FinalLayer(config.hidden_size, config.patch_size, config.out_channels));
}
ggml_tensor*
@ -712,22 +791,22 @@ public:
int64_t w) {
auto pos_embed = params["pos_embed"];
h = (h + 1) / patch_size;
w = (w + 1) / patch_size;
h = (h + 1) / config.patch_size;
w = (w + 1) / config.patch_size;
GGML_ASSERT(h <= pos_embed_max_size && h > 0);
GGML_ASSERT(w <= pos_embed_max_size && w > 0);
GGML_ASSERT(h <= config.pos_embed_max_size && h > 0);
GGML_ASSERT(w <= config.pos_embed_max_size && w > 0);
int64_t top = (pos_embed_max_size - h) / 2;
int64_t left = (pos_embed_max_size - w) / 2;
int64_t top = (config.pos_embed_max_size - h) / 2;
int64_t left = (config.pos_embed_max_size - w) / 2;
auto spatial_pos_embed = ggml_reshape_3d(ctx, pos_embed, hidden_size, pos_embed_max_size, pos_embed_max_size);
auto spatial_pos_embed = ggml_reshape_3d(ctx, pos_embed, config.hidden_size, config.pos_embed_max_size, config.pos_embed_max_size);
// spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
spatial_pos_embed = ggml_view_3d(ctx,
spatial_pos_embed,
hidden_size,
pos_embed_max_size,
config.hidden_size,
config.pos_embed_max_size,
h,
spatial_pos_embed->nb[1],
spatial_pos_embed->nb[2],
@ -735,14 +814,14 @@ public:
spatial_pos_embed = ggml_cont(ctx, ggml_permute(ctx, spatial_pos_embed, 0, 2, 1, 3)); // [pos_embed_max_size, h, hidden_size]
spatial_pos_embed = ggml_view_3d(ctx,
spatial_pos_embed,
hidden_size,
config.hidden_size,
h,
w,
spatial_pos_embed->nb[1],
spatial_pos_embed->nb[2],
spatial_pos_embed->nb[2] * left); // [w, h, hidden_size]
spatial_pos_embed = ggml_cont(ctx, ggml_permute(ctx, spatial_pos_embed, 0, 2, 1, 3)); // [h, w, hidden_size]
spatial_pos_embed = ggml_reshape_3d(ctx, spatial_pos_embed, hidden_size, h * w, 1); // [1, h*w, hidden_size]
spatial_pos_embed->nb[2] * left); // [w, h, hidden_size]
spatial_pos_embed = ggml_cont(ctx, ggml_permute(ctx, spatial_pos_embed, 0, 2, 1, 3)); // [h, w, hidden_size]
spatial_pos_embed = ggml_reshape_3d(ctx, spatial_pos_embed, config.hidden_size, h * w, 1); // [1, h*w, hidden_size]
return spatial_pos_embed;
}
@ -757,7 +836,7 @@ public:
// return: [N, N*W, patch_size * patch_size * out_channels]
auto final_layer = std::dynamic_pointer_cast<FinalLayer>(blocks["final_layer"]);
for (int i = 0; i < depth; i++) {
for (int i = 0; i < config.depth; i++) {
// skip iteration if i is in skip_layers
if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i) != skip_layers.end()) {
continue;
@ -800,7 +879,7 @@ public:
x = ggml_add(ctx->ggml_ctx, patch_embed, pos_embed); // [N, H*W, hidden_size]
auto c = t_embedder->forward(ctx, t); // [N, hidden_size]
if (y != nullptr && adm_in_channels != -1) {
if (y != nullptr && config.adm_in_channels != -1) {
auto y_embedder = std::dynamic_pointer_cast<VectorEmbedder>(blocks["y_embedder"]);
y = y_embedder->forward(ctx, y); // [N, hidden_size]
@ -820,19 +899,22 @@ public:
x = forward_core_with_concat(ctx, x, c, context, skip_layers); // (N, H*W, patch_size ** 2 * out_channels)
x = DiT::unpatchify_and_crop(ctx->ggml_ctx, x, H, W, patch_size, patch_size, /*patch_last*/ false); // [N, C, H, W]
x = DiT::unpatchify_and_crop(ctx->ggml_ctx, x, H, W, config.patch_size, config.patch_size, /*patch_last*/ false); // [N, C, H, W]
return x;
}
};
struct MMDiTRunner : public DiffusionModelRunner {
MMDiTConfig config;
MMDiT mmdit;
MMDiTRunner(ggml_backend_t backend,
ggml_backend_t params_backend,
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "")
: DiffusionModelRunner(backend, params_backend, prefix), mmdit(tensor_storage_map) {
: DiffusionModelRunner(backend, params_backend, prefix),
config(MMDiTConfig::detect_from_weights(tensor_storage_map, prefix)),
mmdit(config) {
mmdit.init(params_ctx, tensor_storage_map, prefix);
}

View File

@ -16,7 +16,7 @@ namespace Pid {
constexpr int PID_GRAPH_SIZE = 196608;
constexpr float PID_PI = 3.14159265358979323846f;
struct PixelDiTParams {
struct PixelDiTConfig {
int64_t in_channels = 3;
int64_t hidden_size = 1536;
int64_t num_groups = 24;
@ -38,6 +38,45 @@ namespace Pid {
int64_t lq_latent_down_factor = 8;
int64_t rope_ref_grid_h = 64;
int64_t rope_ref_grid_w = 64;
// Derives a PixelDiTConfig from the tensors stored under `prefix`.
//
// Starting from the struct defaults, the scan
//   - raises patch_depth / pixel_depth to (highest block index + 1) found in
//     "patch_blocks.<i>." / "pixel_blocks.<i>." tensor names,
//   - reads lq_latent_channels from ne[2] of "lq_proj.latent_proj.0.weight"
//     and picks lq_latent_down_factor 16 (>= 64 channels) or 8 from it,
//   - reads patch_mlp_hidden_dim from ne[1] of "patch_blocks.0.mlp_x.w1.weight".
static PixelDiTConfig detect_from_weights(const String2TensorStorage& tensor_storage_map, const std::string& prefix) {
    PixelDiTConfig config;

    // If `tensor_name` contains `marker` followed by a block index, grow
    // `depth` so that it covers that index. Otherwise leave `depth` untouched.
    auto grow_depth = [](const std::string& tensor_name, const char* marker, auto& depth) {
        const size_t marker_pos = tensor_name.find(marker);
        if (marker_pos == std::string::npos) {
            return;
        }
        const auto parts = split_string(tensor_name.substr(marker_pos), '.');
        if (parts.size() <= 1) {
            return;
        }
        int block_index = atoi(parts[1].c_str());
        depth           = std::max<int64_t>(depth, block_index + 1);
    };

    for (const auto& entry : tensor_storage_map) {
        const std::string& tensor_name = entry.first;
        if (!starts_with(tensor_name, prefix)) {
            continue;
        }

        grow_depth(tensor_name, "patch_blocks.", config.patch_depth);
        grow_depth(tensor_name, "pixel_blocks.", config.pixel_depth);

        if (tensor_name.find("lq_proj.latent_proj.0.weight") != std::string::npos) {
            config.lq_latent_channels = entry.second.ne[2];
            if (config.lq_latent_channels >= 64) {
                config.lq_latent_down_factor = 16;
            } else {
                config.lq_latent_down_factor = 8;
            }
        }
        if (tensor_name.find("patch_blocks.0.mlp_x.w1.weight") != std::string::npos) {
            config.patch_mlp_hidden_dim = entry.second.ne[1];
        }
    }

    LOG_DEBUG("pid: patch_depth = %" PRId64 ", pixel_depth = %" PRId64 ", patch_mlp_hidden_dim = %" PRId64 ", lq_latent_channels = %" PRId64 ", lq_latent_down_factor = %" PRId64,
              config.patch_depth,
              config.pixel_depth,
              config.patch_mlp_hidden_dim,
              config.lq_latent_channels,
              config.lq_latent_down_factor);
    return config;
}
};
inline std::vector<float> make_rope_1d(int length,
@ -466,29 +505,29 @@ namespace Pid {
};
struct LQProjection2D : public GGMLBlock {
PixelDiTParams params_cfg;
PixelDiTConfig config;
LQProjection2D(const PixelDiTParams& params_cfg)
: params_cfg(params_cfg) {
blocks["latent_proj.0"] = std::make_shared<Conv2d>(params_cfg.lq_latent_channels, params_cfg.lq_hidden_dim, std::pair<int, int>{3, 3}, std::pair<int, int>{1, 1}, std::pair<int, int>{1, 1});
blocks["latent_proj.2"] = std::make_shared<Conv2d>(params_cfg.lq_hidden_dim, params_cfg.lq_hidden_dim, std::pair<int, int>{3, 3}, std::pair<int, int>{1, 1}, std::pair<int, int>{1, 1});
for (int i = 0; i < params_cfg.lq_num_res_blocks; ++i) {
blocks["latent_proj." + std::to_string(3 + i)] = std::make_shared<PiDResBlock>(params_cfg.lq_hidden_dim);
LQProjection2D(const PixelDiTConfig& config)
: config(config) {
blocks["latent_proj.0"] = std::make_shared<Conv2d>(config.lq_latent_channels, config.lq_hidden_dim, std::pair<int, int>{3, 3}, std::pair<int, int>{1, 1}, std::pair<int, int>{1, 1});
blocks["latent_proj.2"] = std::make_shared<Conv2d>(config.lq_hidden_dim, config.lq_hidden_dim, std::pair<int, int>{3, 3}, std::pair<int, int>{1, 1}, std::pair<int, int>{1, 1});
for (int i = 0; i < config.lq_num_res_blocks; ++i) {
blocks["latent_proj." + std::to_string(3 + i)] = std::make_shared<PiDResBlock>(config.lq_hidden_dim);
}
int num_outputs = static_cast<int>((params_cfg.patch_depth + params_cfg.lq_interval - 1) / params_cfg.lq_interval);
int num_outputs = static_cast<int>((config.patch_depth + config.lq_interval - 1) / config.lq_interval);
for (int i = 0; i < num_outputs; ++i) {
blocks["output_heads." + std::to_string(i)] = std::make_shared<Linear>(params_cfg.lq_hidden_dim, params_cfg.hidden_size, true);
blocks["gate_modules." + std::to_string(i)] = std::make_shared<SigmaAwareGate>(params_cfg.hidden_size);
blocks["output_heads." + std::to_string(i)] = std::make_shared<Linear>(config.lq_hidden_dim, config.hidden_size, true);
blocks["gate_modules." + std::to_string(i)] = std::make_shared<SigmaAwareGate>(config.hidden_size);
}
}
bool is_gate_active(int block_idx) const {
return block_idx % params_cfg.lq_interval == 0;
return block_idx % config.lq_interval == 0;
}
int get_output_index(int block_idx) const {
return block_idx / static_cast<int>(params_cfg.lq_interval);
return block_idx / static_cast<int>(config.lq_interval);
}
ggml_tensor* gate(GGMLRunnerContext* ctx,
@ -506,8 +545,8 @@ namespace Pid {
int64_t target_pW) {
auto conv0 = std::dynamic_pointer_cast<Conv2d>(blocks["latent_proj.0"]);
auto conv2 = std::dynamic_pointer_cast<Conv2d>(blocks["latent_proj.2"]);
float z_to_patch_ratio = static_cast<float>(params_cfg.lq_sr_scale * params_cfg.lq_latent_down_factor) /
static_cast<float>(params_cfg.patch_size);
float z_to_patch_ratio = static_cast<float>(config.lq_sr_scale * config.lq_latent_down_factor) /
static_cast<float>(config.patch_size);
GGML_ASSERT(z_to_patch_ratio >= 1.0f);
if (lq_latent->ne[0] != target_pW || lq_latent->ne[1] != target_pH) {
lq_latent = ggml_interpolate(ctx->ggml_ctx,
@ -522,7 +561,7 @@ namespace Pid {
auto feat = conv0->forward(ctx, lq_latent);
feat = ggml_silu_inplace(ctx->ggml_ctx, feat);
feat = conv2->forward(ctx, feat);
for (int i = 0; i < params_cfg.lq_num_res_blocks; ++i) {
for (int i = 0; i < config.lq_num_res_blocks; ++i) {
auto block = std::dynamic_pointer_cast<PiDResBlock>(blocks["latent_proj." + std::to_string(3 + i)]);
feat = block->forward(ctx, feat);
}
@ -533,7 +572,7 @@ namespace Pid {
auto tokens = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, feat, 2, 0, 1, 3));
tokens = ggml_reshape_3d(ctx->ggml_ctx, tokens, C, L, B);
int num_outputs = static_cast<int>((params_cfg.patch_depth + params_cfg.lq_interval - 1) / params_cfg.lq_interval);
int num_outputs = static_cast<int>((config.patch_depth + config.lq_interval - 1) / config.lq_interval);
std::vector<ggml_tensor*> outputs;
outputs.reserve(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
@ -545,34 +584,34 @@ namespace Pid {
};
struct PixelDiT : public GGMLBlock {
PixelDiTParams params_cfg;
PixelDiTConfig config;
PixelDiT() = default;
PixelDiT(const PixelDiTParams& params_cfg)
: params_cfg(params_cfg) {
blocks["pixel_embedder"] = std::make_shared<PixelTokenEmbedder>(params_cfg.in_channels, params_cfg.pixel_hidden_size);
blocks["s_embedder"] = std::make_shared<PatchTokenEmbedder>(params_cfg.in_channels * params_cfg.patch_size * params_cfg.patch_size, params_cfg.hidden_size, false, true);
blocks["t_embedder"] = std::make_shared<PixelDiTTimestepEmbedder>(params_cfg.hidden_size);
blocks["y_embedder"] = std::make_shared<PatchTokenEmbedder>(params_cfg.txt_embed_dim, params_cfg.hidden_size, true, true);
for (int i = 0; i < params_cfg.patch_depth; ++i) {
blocks["patch_blocks." + std::to_string(i)] = std::make_shared<MMDiTBlockT2I>(params_cfg.hidden_size, params_cfg.num_groups, params_cfg.patch_mlp_hidden_dim);
PixelDiT(const PixelDiTConfig& config)
: config(config) {
blocks["pixel_embedder"] = std::make_shared<PixelTokenEmbedder>(config.in_channels, config.pixel_hidden_size);
blocks["s_embedder"] = std::make_shared<PatchTokenEmbedder>(config.in_channels * config.patch_size * config.patch_size, config.hidden_size, false, true);
blocks["t_embedder"] = std::make_shared<PixelDiTTimestepEmbedder>(config.hidden_size);
blocks["y_embedder"] = std::make_shared<PatchTokenEmbedder>(config.txt_embed_dim, config.hidden_size, true, true);
for (int i = 0; i < config.patch_depth; ++i) {
blocks["patch_blocks." + std::to_string(i)] = std::make_shared<MMDiTBlockT2I>(config.hidden_size, config.num_groups, config.patch_mlp_hidden_dim);
}
for (int i = 0; i < params_cfg.pixel_depth; ++i) {
blocks["pixel_blocks." + std::to_string(i)] = std::make_shared<PiTBlock>(params_cfg.pixel_hidden_size,
params_cfg.hidden_size,
params_cfg.patch_size,
params_cfg.pixel_attn_hidden_size,
params_cfg.pixel_num_groups);
for (int i = 0; i < config.pixel_depth; ++i) {
blocks["pixel_blocks." + std::to_string(i)] = std::make_shared<PiTBlock>(config.pixel_hidden_size,
config.hidden_size,
config.patch_size,
config.pixel_attn_hidden_size,
config.pixel_num_groups);
}
blocks["final_layer"] = std::make_shared<FinalLayer>(params_cfg.pixel_hidden_size, params_cfg.in_channels);
blocks["lq_proj"] = std::make_shared<LQProjection2D>(params_cfg);
blocks["final_layer"] = std::make_shared<FinalLayer>(config.pixel_hidden_size, config.in_channels);
blocks["lq_proj"] = std::make_shared<LQProjection2D>(config);
}
void init_params(ggml_context* ctx,
const String2TensorStorage& tensor_storage_map = {},
std::string prefix = "") override {
params["y_pos_embedding"] = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, params_cfg.hidden_size, params_cfg.txt_max_length, 1);
params["y_pos_embedding"] = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, config.hidden_size, config.txt_max_length, 1);
}
ggml_tensor* forward(GGMLRunnerContext* ctx,
@ -594,21 +633,21 @@ namespace Pid {
int64_t W_orig = x->ne[0];
int64_t H_orig = x->ne[1];
x = DiT::pad_to_patch_size(ctx, x, static_cast<int>(params_cfg.patch_size), static_cast<int>(params_cfg.patch_size));
x = DiT::pad_to_patch_size(ctx, x, static_cast<int>(config.patch_size), static_cast<int>(config.patch_size));
int64_t W = x->ne[0];
int64_t H = x->ne[1];
int64_t B = x->ne[3];
int64_t Hs = H / params_cfg.patch_size;
int64_t Ws = W / params_cfg.patch_size;
int64_t Hs = H / config.patch_size;
int64_t Ws = W / config.patch_size;
int64_t L = Hs * Ws;
int64_t P2 = params_cfg.patch_size * params_cfg.patch_size;
int64_t P2 = config.patch_size * config.patch_size;
auto x_patches = DiT::patchify(ctx->ggml_ctx, x, static_cast<int>(params_cfg.patch_size), static_cast<int>(params_cfg.patch_size), true);
auto x_patches = DiT::patchify(ctx->ggml_ctx, x, static_cast<int>(config.patch_size), static_cast<int>(config.patch_size), true);
auto t_emb = t_embedder->forward(ctx, timesteps);
auto condition = ggml_silu(ctx->ggml_ctx, t_emb);
GGML_ASSERT(context != nullptr);
int64_t Ltxt = std::min<int64_t>(context->ne[1], params_cfg.txt_max_length);
int64_t Ltxt = std::min<int64_t>(context->ne[1], config.txt_max_length);
auto y = ggml_ext_slice(ctx->ggml_ctx, context, 1, 0, Ltxt);
auto y_emb = y_embedder->forward(ctx, y);
auto y_pos = ggml_ext_slice(ctx->ggml_ctx, params["y_pos_embedding"], 1, 0, Ltxt);
@ -618,7 +657,7 @@ namespace Pid {
auto s = s_embedder->forward(ctx, x_patches);
for (int i = 0; i < params_cfg.patch_depth; ++i) {
for (int i = 0; i < config.patch_depth; ++i) {
if (lq_proj->is_gate_active(i)) {
int out_idx = lq_proj->get_output_index(i);
if (out_idx < static_cast<int>(lq_features.size())) {
@ -639,22 +678,22 @@ namespace Pid {
}
s = ggml_silu(ctx->ggml_ctx, ggml_add(ctx->ggml_ctx, s, t_emb));
auto s_cond = ggml_reshape_2d(ctx->ggml_ctx, s, params_cfg.hidden_size, L * B);
auto pixels = pixel_embedder->forward(ctx, x, params_cfg.patch_size, pixel_pos_full);
for (int i = 0; i < params_cfg.pixel_depth; ++i) {
auto s_cond = ggml_reshape_2d(ctx->ggml_ctx, s, config.hidden_size, L * B);
auto pixels = pixel_embedder->forward(ctx, x, config.patch_size, pixel_pos_full);
for (int i = 0; i < config.pixel_depth; ++i) {
auto block = std::dynamic_pointer_cast<PiTBlock>(blocks["pixel_blocks." + std::to_string(i)]);
pixels = block->forward(ctx, pixels, s_cond, H, W, pixel_pos_comp);
sd::ggml_graph_cut::mark_graph_cut(pixels, "pid.pixel_blocks." + std::to_string(i), "pixels");
}
pixels = final_layer->forward(ctx, pixels);
pixels = ggml_reshape_3d(ctx->ggml_ctx, pixels, params_cfg.in_channels * P2, L, B);
pixels = ggml_reshape_3d(ctx->ggml_ctx, pixels, config.in_channels * P2, L, B);
auto out = DiT::unpatchify(ctx->ggml_ctx,
pixels,
Hs,
Ws,
static_cast<int>(params_cfg.patch_size),
static_cast<int>(params_cfg.patch_size),
static_cast<int>(config.patch_size),
static_cast<int>(config.patch_size),
false);
out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H_orig);
out = ggml_ext_slice(ctx->ggml_ctx, out, 0, 0, W_orig);
@ -663,7 +702,7 @@ namespace Pid {
};
struct PiDRunner : public DiffusionModelRunner {
PixelDiTParams params_cfg;
PixelDiTConfig config;
PixelDiT model;
std::vector<float> pos_img_vec;
std::vector<float> pos_txt_vec;
@ -674,43 +713,9 @@ namespace Pid {
ggml_backend_t params_backend,
const String2TensorStorage& tensor_storage_map,
const std::string prefix = "model.diffusion_model")
: DiffusionModelRunner(backend, params_backend, prefix) {
for (const auto& pair : tensor_storage_map) {
const std::string& tensor_name = pair.first;
if (tensor_name.find(prefix) == std::string::npos) {
continue;
}
size_t pos = tensor_name.find("patch_blocks.");
if (pos != std::string::npos) {
auto items = split_string(tensor_name.substr(pos), '.');
if (items.size() > 1) {
int block_index = atoi(items[1].c_str());
params_cfg.patch_depth = std::max<int64_t>(params_cfg.patch_depth, block_index + 1);
}
}
pos = tensor_name.find("pixel_blocks.");
if (pos != std::string::npos) {
auto items = split_string(tensor_name.substr(pos), '.');
if (items.size() > 1) {
int block_index = atoi(items[1].c_str());
params_cfg.pixel_depth = std::max<int64_t>(params_cfg.pixel_depth, block_index + 1);
}
}
if (tensor_name.find("lq_proj.latent_proj.0.weight") != std::string::npos) {
params_cfg.lq_latent_channels = pair.second.ne[2];
params_cfg.lq_latent_down_factor = params_cfg.lq_latent_channels >= 64 ? 16 : 8;
}
if (tensor_name.find("patch_blocks.0.mlp_x.w1.weight") != std::string::npos) {
params_cfg.patch_mlp_hidden_dim = pair.second.ne[1];
}
}
LOG_INFO("PiD params: patch_depth=%" PRId64 ", pixel_depth=%" PRId64 ", patch_mlp_hidden_dim=%" PRId64 ", lq_latent_channels=%" PRId64 ", lq_latent_down_factor=%" PRId64,
params_cfg.patch_depth,
params_cfg.pixel_depth,
params_cfg.patch_mlp_hidden_dim,
params_cfg.lq_latent_channels,
params_cfg.lq_latent_down_factor);
model = PixelDiT(params_cfg);
: DiffusionModelRunner(backend, params_backend, prefix),
config(PixelDiTConfig::detect_from_weights(tensor_storage_map, prefix)) {
model = PixelDiT(config);
model.init(params_ctx, tensor_storage_map, prefix);
}
@ -737,60 +742,60 @@ namespace Pid {
int64_t W = x->ne[0];
int64_t H = x->ne[1];
int64_t B = x->ne[3];
int64_t Wp = align_up(static_cast<int>(W), static_cast<int>(params_cfg.patch_size));
int64_t Hp = align_up(static_cast<int>(H), static_cast<int>(params_cfg.patch_size));
int64_t Hs = Hp / params_cfg.patch_size;
int64_t Ws = Wp / params_cfg.patch_size;
int64_t Wp = align_up(static_cast<int>(W), static_cast<int>(config.patch_size));
int64_t Hp = align_up(static_cast<int>(H), static_cast<int>(config.patch_size));
int64_t Hs = Hp / config.patch_size;
int64_t Ws = Wp / config.patch_size;
pos_img_vec = make_rope_2d(static_cast<int>(Hs),
static_cast<int>(Ws),
static_cast<int>(params_cfg.hidden_size / params_cfg.num_groups),
static_cast<int>(config.hidden_size / config.num_groups),
10000.f,
16.f,
static_cast<int>(params_cfg.rope_ref_grid_h),
static_cast<int>(params_cfg.rope_ref_grid_w));
static_cast<int>(config.rope_ref_grid_h),
static_cast<int>(config.rope_ref_grid_w));
auto pos_img = ggml_new_tensor_4d(compute_ctx,
GGML_TYPE_F32,
2,
2,
params_cfg.hidden_size / params_cfg.num_groups / 2,
config.hidden_size / config.num_groups / 2,
Hs * Ws);
set_backend_tensor_data(pos_img, pos_img_vec.data());
int64_t Ltxt = std::min<int64_t>(context->ne[1], params_cfg.txt_max_length);
int64_t Ltxt = std::min<int64_t>(context->ne[1], config.txt_max_length);
pos_txt_vec = make_rope_1d(static_cast<int>(Ltxt),
static_cast<int>(params_cfg.hidden_size / params_cfg.num_groups),
params_cfg.text_rope_theta);
static_cast<int>(config.hidden_size / config.num_groups),
config.text_rope_theta);
auto pos_txt = ggml_new_tensor_4d(compute_ctx,
GGML_TYPE_F32,
2,
2,
params_cfg.hidden_size / params_cfg.num_groups / 2,
config.hidden_size / config.num_groups / 2,
Ltxt);
set_backend_tensor_data(pos_txt, pos_txt_vec.data());
pixel_pos_vec = make_pixel_abs_pos(static_cast<int>(Hp),
static_cast<int>(Wp),
static_cast<int>(params_cfg.pixel_hidden_size));
static_cast<int>(config.pixel_hidden_size));
auto pixel_pos = ggml_new_tensor_3d(compute_ctx,
GGML_TYPE_F32,
params_cfg.pixel_hidden_size,
config.pixel_hidden_size,
Wp * Hp,
1);
set_backend_tensor_data(pixel_pos, pixel_pos_vec.data());
pixel_pos_comp_vec = make_rope_2d(static_cast<int>(Hs),
static_cast<int>(Ws),
static_cast<int>(params_cfg.pixel_attn_hidden_size / params_cfg.pixel_num_groups),
static_cast<int>(config.pixel_attn_hidden_size / config.pixel_num_groups),
10000.f,
16.f,
static_cast<int>(params_cfg.rope_ref_grid_h),
static_cast<int>(params_cfg.rope_ref_grid_w));
static_cast<int>(config.rope_ref_grid_h),
static_cast<int>(config.rope_ref_grid_w));
auto pixel_pos_comp = ggml_new_tensor_4d(compute_ctx,
GGML_TYPE_F32,
2,
2,
params_cfg.pixel_attn_hidden_size / params_cfg.pixel_num_groups / 2,
config.pixel_attn_hidden_size / config.pixel_num_groups / 2,
Hs * Ws);
set_backend_tensor_data(pixel_pos_comp, pixel_pos_comp_vec.data());

View File

@ -10,6 +10,48 @@
namespace Qwen {
constexpr int QWEN_IMAGE_GRAPH_SIZE = 20480;
// Configuration of the Qwen-Image diffusion transformer.
struct QwenImageConfig {
    int patch_size              = 2;
    int64_t in_channels         = 64;
    int64_t out_channels        = 16;
    int num_layers              = 60;
    int64_t attention_head_dim  = 128;
    int64_t num_attention_heads = 24;
    int64_t joint_attention_dim = 3584;
    int theta                   = 10000;
    std::vector<int> axes_dim   = {16, 56, 56};
    int axes_dim_sum            = 128;
    bool zero_cond_t            = false;

    // Builds a config from the tensor names found under `prefix`.
    //
    // num_layers starts from 0 (not the struct default) and becomes the
    // highest "transformer_blocks.<i>" index + 1. zero_cond_t is switched on
    // when any tensor name contains "__index_timestep_zero__". All other
    // fields keep their defaults.
    static QwenImageConfig detect_from_weights(const String2TensorStorage& tensor_storage_map, const std::string& prefix) {
        QwenImageConfig config;
        config.num_layers = 0;

        for (const auto& entry : tensor_storage_map) {
            const std::string& tensor_name = entry.first;
            if (!starts_with(tensor_name, prefix)) {
                continue;
            }

            if (tensor_name.find("__index_timestep_zero__") != std::string::npos) {
                config.zero_cond_t = true;
            }

            const size_t block_pos = tensor_name.find("transformer_blocks.");
            if (block_pos != std::string::npos) {
                // parts[0] is "transformer_blocks", parts[1] the block index.
                const auto parts = split_string(tensor_name.substr(block_pos), '.');
                if (parts.size() > 1) {
                    const int layer_count = atoi(parts[1].c_str()) + 1;
                    if (layer_count > config.num_layers) {
                        config.num_layers = layer_count;
                    }
                }
            }
        }

        LOG_DEBUG("qwen_image: num_layers = %d, zero_cond_t = %s",
                  config.num_layers,
                  config.zero_cond_t ? "true" : "false");
        return config;
    }
};
struct TimestepEmbedding : public GGMLBlock {
public:
TimestepEmbedding(int64_t in_channels,
@ -350,46 +392,32 @@ namespace Qwen {
}
};
// Plain bundle of Qwen-Image hyper-parameters; every member carries a default.
struct QwenImageParams {
    int patch_size{2};
    int64_t in_channels{64};
    int64_t out_channels{16};
    int num_layers{60};
    int64_t attention_head_dim{128};
    int64_t num_attention_heads{24};
    int64_t joint_attention_dim{3584};
    int theta{10000};
    std::vector<int> axes_dim{16, 56, 56};
    int axes_dim_sum{128};
    bool zero_cond_t{false};
};
class QwenImageModel : public GGMLBlock {
protected:
QwenImageParams params;
QwenImageConfig config;
public:
QwenImageModel() {}
QwenImageModel(QwenImageParams params)
: params(params) {
int64_t inner_dim = params.num_attention_heads * params.attention_head_dim;
QwenImageModel(QwenImageConfig config)
: config(config) {
int64_t inner_dim = config.num_attention_heads * config.attention_head_dim;
blocks["time_text_embed"] = std::shared_ptr<GGMLBlock>(new QwenTimestepProjEmbeddings(inner_dim));
blocks["txt_norm"] = std::shared_ptr<GGMLBlock>(new RMSNorm(params.joint_attention_dim, 1e-6f));
blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, inner_dim));
blocks["txt_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.joint_attention_dim, inner_dim));
blocks["txt_norm"] = std::shared_ptr<GGMLBlock>(new RMSNorm(config.joint_attention_dim, 1e-6f));
blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(config.in_channels, inner_dim));
blocks["txt_in"] = std::shared_ptr<GGMLBlock>(new Linear(config.joint_attention_dim, inner_dim));
// blocks
for (int i = 0; i < params.num_layers; i++) {
for (int i = 0; i < config.num_layers; i++) {
auto block = std::shared_ptr<GGMLBlock>(new QwenImageTransformerBlock(inner_dim,
params.num_attention_heads,
params.attention_head_dim,
config.num_attention_heads,
config.attention_head_dim,
1e-6f,
params.zero_cond_t));
config.zero_cond_t));
blocks["transformer_blocks." + std::to_string(i)] = block;
}
blocks["norm_out"] = std::shared_ptr<GGMLBlock>(new AdaLayerNormContinuous(inner_dim, inner_dim, false, 1e-6f));
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, params.patch_size * params.patch_size * params.out_channels));
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, config.patch_size * config.patch_size * config.out_channels));
}
ggml_tensor* forward_orig(GGMLRunnerContext* ctx,
@ -406,7 +434,7 @@ namespace Qwen {
auto proj_out = std::dynamic_pointer_cast<Linear>(blocks["proj_out"]);
auto t_emb = time_text_embed->forward(ctx, timestep);
if (params.zero_cond_t) {
if (config.zero_cond_t) {
auto t_emb_0 = time_text_embed->forward(ctx, ggml_ext_zeros_like(ctx->ggml_ctx, timestep));
t_emb = ggml_concat(ctx->ggml_ctx, t_emb, t_emb_0, 1);
}
@ -417,7 +445,7 @@ namespace Qwen {
sd::ggml_graph_cut::mark_graph_cut(txt, "qwen_image.prelude", "txt");
// sd::ggml_graph_cut::mark_graph_cut(t_emb, "qwen_image.prelude", "t_emb");
for (int i = 0; i < params.num_layers; i++) {
for (int i = 0; i < config.num_layers; i++) {
auto block = std::dynamic_pointer_cast<QwenImageTransformerBlock>(blocks["transformer_blocks." + std::to_string(i)]);
auto result = block->forward(ctx, img, txt, t_emb, pe, modulate_index);
@ -427,7 +455,7 @@ namespace Qwen {
sd::ggml_graph_cut::mark_graph_cut(txt, "qwen_image.transformer_blocks." + std::to_string(i), "txt");
}
if (params.zero_cond_t) {
if (config.zero_cond_t) {
t_emb = ggml_ext_chunk(ctx->ggml_ctx, t_emb, 2, 1)[0];
}
@ -456,12 +484,12 @@ namespace Qwen {
int64_t C = x->ne[2];
int64_t N = x->ne[3];
auto img = DiT::pad_and_patchify(ctx, x, params.patch_size, params.patch_size);
auto img = DiT::pad_and_patchify(ctx, x, config.patch_size, config.patch_size);
int64_t img_tokens = img->ne[1];
if (ref_latents.size() > 0) {
for (ggml_tensor* ref : ref_latents) {
ref = DiT::pad_and_patchify(ctx, ref, params.patch_size, params.patch_size);
ref = DiT::pad_and_patchify(ctx, ref, config.patch_size, config.patch_size);
img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
}
}
@ -474,7 +502,7 @@ namespace Qwen {
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size]
}
out = DiT::unpatchify_and_crop(ctx->ggml_ctx, out, H, W, params.patch_size, params.patch_size); // [N, C, H, W]
out = DiT::unpatchify_and_crop(ctx->ggml_ctx, out, H, W, config.patch_size, config.patch_size); // [N, C, H, W]
return out;
}
@ -482,7 +510,7 @@ namespace Qwen {
struct QwenImageRunner : public DiffusionModelRunner {
public:
QwenImageParams qwen_image_params;
QwenImageConfig config;
QwenImageModel qwen_image;
std::vector<float> pe_vec;
std::vector<float> modulate_index_vec;
@ -494,34 +522,10 @@ namespace Qwen {
const std::string prefix = "",
SDVersion version = VERSION_QWEN_IMAGE,
bool zero_cond_t = false)
: DiffusionModelRunner(backend, params_backend, prefix) {
qwen_image_params.num_layers = 0;
qwen_image_params.zero_cond_t = zero_cond_t;
for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first;
if (tensor_name.find(prefix) == std::string::npos)
continue;
if (tensor_name.find("__index_timestep_zero__") != std::string::npos) {
qwen_image_params.zero_cond_t = true;
}
size_t pos = tensor_name.find("transformer_blocks.");
if (pos != std::string::npos) {
tensor_name = tensor_name.substr(pos); // remove prefix
auto items = split_string(tensor_name, '.');
if (items.size() > 1) {
int block_index = atoi(items[1].c_str());
if (block_index + 1 > qwen_image_params.num_layers) {
qwen_image_params.num_layers = block_index + 1;
}
}
continue;
}
}
LOG_INFO("qwen_image_params.num_layers: %ld", qwen_image_params.num_layers);
if (qwen_image_params.zero_cond_t) {
LOG_INFO("use zero_cond_t");
}
qwen_image = QwenImageModel(qwen_image_params);
: DiffusionModelRunner(backend, params_backend, prefix),
config(QwenImageConfig::detect_from_weights(tensor_storage_map, prefix)) {
config.zero_cond_t = config.zero_cond_t || zero_cond_t;
qwen_image = QwenImageModel(config);
qwen_image.init(params_ctx, tensor_storage_map, prefix);
}
@ -552,36 +556,36 @@ namespace Qwen {
pe_vec = Rope::gen_qwen_image_pe(static_cast<int>(x->ne[1]),
static_cast<int>(x->ne[0]),
qwen_image_params.patch_size,
config.patch_size,
static_cast<int>(x->ne[3]),
static_cast<int>(context->ne[1]),
ref_latents,
increase_ref_index,
qwen_image_params.theta,
config.theta,
circular_y_enabled,
circular_x_enabled,
qwen_image_params.axes_dim);
int pos_len = static_cast<int>(pe_vec.size() / qwen_image_params.axes_dim_sum / 2);
config.axes_dim);
int pos_len = static_cast<int>(pe_vec.size() / config.axes_dim_sum / 2);
// LOG_DEBUG("pos_len %d", pos_len);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, qwen_image_params.axes_dim_sum / 2, pos_len);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.axes_dim_sum / 2, pos_len);
// pe->data = pe_vec.data();
// print_ggml_tensor(pe, true, "pe");
// pe->data = nullptr;
set_backend_tensor_data(pe, pe_vec.data());
ggml_tensor* modulate_index = nullptr;
if (qwen_image_params.zero_cond_t) {
if (config.zero_cond_t) {
modulate_index_vec.clear();
int64_t h_len = ((x->ne[1] + (qwen_image_params.patch_size / 2)) / qwen_image_params.patch_size);
int64_t w_len = ((x->ne[0] + (qwen_image_params.patch_size / 2)) / qwen_image_params.patch_size);
int64_t h_len = ((x->ne[1] + (config.patch_size / 2)) / config.patch_size);
int64_t w_len = ((x->ne[0] + (config.patch_size / 2)) / config.patch_size);
int64_t num_img_tokens = h_len * w_len;
modulate_index_vec.insert(modulate_index_vec.end(), num_img_tokens, 0.f);
int64_t num_ref_img_tokens = 0;
for (ggml_tensor* ref : ref_latents) {
int64_t h_len = ((ref->ne[1] + (qwen_image_params.patch_size / 2)) / qwen_image_params.patch_size);
int64_t w_len = ((ref->ne[0] + (qwen_image_params.patch_size / 2)) / qwen_image_params.patch_size);
int64_t h_len = ((ref->ne[1] + (config.patch_size / 2)) / config.patch_size);
int64_t w_len = ((ref->ne[0] + (config.patch_size / 2)) / config.patch_size);
num_ref_img_tokens += h_len * w_len;
}

View File

@ -14,6 +14,28 @@
#include "model.h"
#include "tokenizers/t5_unigram_tokenizer.h"
// Hyper-parameters used to build the T5 encoder stack and its shared embedding.
struct T5Config {
    int64_t num_layers      = 24;
    int64_t model_dim       = 4096;
    int64_t ff_dim          = 10240;
    int64_t num_heads       = 64;
    int64_t vocab_size      = 32128;
    bool relative_attention = true;

    // Returns the default config, or the UMT5 variant when `is_umt5` is set
    // (vocab_size = 256384, relative_attention = false).
    //
    // `tensor_storage_map` and `prefix` are accepted but currently not
    // inspected: nothing is derived from the weights here.
    static T5Config detect_from_weights(const String2TensorStorage& tensor_storage_map,
                                        const std::string& prefix,
                                        bool is_umt5 = false) {
        (void)tensor_storage_map;
        (void)prefix;

        T5Config config;
        if (!is_umt5) {
            return config;
        }
        config.vocab_size         = 256384;
        config.relative_attention = false;
        return config;
    }
};
class T5LayerNorm : public UnaryBlock {
protected:
int64_t hidden_size;
@ -272,30 +294,21 @@ public:
}
};
// Plain bundle of T5 hyper-parameters; every member carries a default.
struct T5Params {
    int64_t num_layers{24};
    int64_t model_dim{4096};
    int64_t ff_dim{10240};
    int64_t num_heads{64};
    int64_t vocab_size{32128};
    bool relative_attention{true};
};
struct T5 : public GGMLBlock {
T5Params params;
T5Config config;
public:
T5() {}
T5(T5Params params)
: params(params) {
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new T5Stack(params.num_layers,
params.model_dim,
params.model_dim,
params.ff_dim,
params.num_heads,
params.relative_attention));
blocks["shared"] = std::shared_ptr<GGMLBlock>(new Embedding(params.vocab_size,
params.model_dim));
T5(T5Config config)
: config(config) {
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new T5Stack(config.num_layers,
config.model_dim,
config.model_dim,
config.ff_dim,
config.num_heads,
config.relative_attention));
blocks["shared"] = std::shared_ptr<GGMLBlock>(new Embedding(config.vocab_size,
config.model_dim));
}
ggml_tensor* forward(GGMLRunnerContext* ctx,
@ -316,7 +329,7 @@ public:
};
struct T5Runner : public GGMLRunner {
T5Params params;
T5Config config;
T5 model;
std::vector<int> relative_position_bucket_vec;
@ -325,12 +338,9 @@ struct T5Runner : public GGMLRunner {
const String2TensorStorage& tensor_storage_map,
const std::string prefix,
bool is_umt5 = false)
: GGMLRunner(backend, params_backend) {
if (is_umt5) {
params.vocab_size = 256384;
params.relative_attention = false;
}
model = T5(params);
: GGMLRunner(backend, params_backend),
config(T5Config::detect_from_weights(tensor_storage_map, prefix, is_umt5)) {
model = T5(config);
model.init(params_ctx, tensor_storage_map, prefix);
}

View File

@ -1,6 +1,9 @@
#ifndef __UNET_HPP__
#define __UNET_HPP__
#include <algorithm>
#include <vector>
#include "common_block.hpp"
#include "diffusion_model.hpp"
#include "model.h"
@ -9,6 +12,125 @@
#define UNET_GRAPH_SIZE 102400
struct UNetConfig {
SDVersion version = VERSION_SD1;
// network hparams
int in_channels = 4;
int out_channels = 4;
int num_res_blocks = 2;
std::vector<int> attention_resolutions = {4, 2, 1};
std::vector<int> channel_mult = {1, 2, 4, 4};
std::vector<int> transformer_depth = {1, 1, 1, 1};
int time_embed_dim = 1280; // model_channels*4
int num_heads = 8;
int num_head_channels = -1; // channels // num_heads
int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL
bool use_linear_projection = false;
bool tiny_unet = false;
int model_channels = 320;
int adm_in_channels = 2816; // only for VERSION_SDXL/SVD
static UNetConfig detect_from_weights(const String2TensorStorage& tensor_storage_map,
const std::string& prefix,
SDVersion version = VERSION_SD1) {
UNetConfig config;
config.version = version;
if (sd_version_is_sd2(version)) {
config.context_dim = 1024;
config.num_head_channels = 64;
config.num_heads = -1;
config.use_linear_projection = true;
} else if (sd_version_is_sdxl(version)) {
config.context_dim = 2048;
config.attention_resolutions = {4, 2};
config.channel_mult = {1, 2, 4};
config.transformer_depth = {1, 2, 10};
config.num_head_channels = 64;
config.num_heads = -1;
config.use_linear_projection = true;
if (version == VERSION_SDXL_VEGA) {
config.transformer_depth = {1, 1, 2};
}
} else if (version == VERSION_SVD) {
config.in_channels = 8;
config.out_channels = 4;
config.context_dim = 1024;
config.adm_in_channels = 768;
config.num_head_channels = 64;
config.num_heads = -1;
config.use_linear_projection = true;
}
if (sd_version_is_inpaint(version)) {
config.in_channels = 9;
} else if (sd_version_is_unet_edit(version)) {
config.in_channels = 8;
}
if (version == VERSION_SD1_TINY_UNET || version == VERSION_SD2_TINY_UNET || version == VERSION_SDXS_512_DS || version == VERSION_SDXS_09) {
config.num_res_blocks = 1;
config.channel_mult = {1, 2, 4};
config.tiny_unet = true;
if (version == VERSION_SDXS_512_DS) {
config.attention_resolutions = {4, 2}; // here just like SDXL
}
}
auto find_weight = [&](const std::string& suffix) -> const TensorStorage* {
std::string name = prefix.empty() ? suffix : prefix + "." + suffix;
auto it = tensor_storage_map.find(name);
if (it == tensor_storage_map.end()) {
return nullptr;
}
return &it->second;
};
if (const TensorStorage* input = find_weight("input_blocks.0.0.weight")) {
if (input->n_dims == 4) {
config.in_channels = static_cast<int>(input->ne[2]);
config.model_channels = static_cast<int>(input->ne[3]);
config.time_embed_dim = config.model_channels * 4;
}
}
if (const TensorStorage* time_embed = find_weight("time_embed.0.weight")) {
if (time_embed->n_dims == 2) {
config.model_channels = static_cast<int>(time_embed->ne[0]);
config.time_embed_dim = static_cast<int>(time_embed->ne[1]);
}
}
if (const TensorStorage* label_emb = find_weight("label_emb.0.0.weight")) {
if (label_emb->n_dims == 2) {
config.adm_in_channels = static_cast<int>(label_emb->ne[0]);
config.time_embed_dim = static_cast<int>(label_emb->ne[1]);
}
}
if (const TensorStorage* out = find_weight("out.2.weight")) {
if (out->n_dims == 4) {
config.out_channels = static_cast<int>(out->ne[3]);
}
}
for (const auto& [name, tensor_storage] : tensor_storage_map) {
if (!starts_with(name, prefix)) {
continue;
}
if (name.find("attn2.to_k.weight") != std::string::npos && tensor_storage.n_dims == 2) {
config.context_dim = static_cast<int>(tensor_storage.ne[0]);
break;
}
}
LOG_DEBUG("unet: in_channels = %d, out_channels = %d, model_channels = %d, time_embed_dim = %d, context_dim = %d, adm_in_channels = %d, num_res_blocks = %d, tiny_unet = %s",
config.in_channels,
config.out_channels,
config.model_channels,
config.time_embed_dim,
config.context_dim,
config.adm_in_channels,
config.num_res_blocks,
config.tiny_unet ? "true" : "false");
return config;
}
};
class SpatialVideoTransformer : public SpatialTransformer {
protected:
int64_t time_depth;
@ -166,66 +288,26 @@ public:
// ldm.modules.diffusionmodules.openaimodel.UNetModel
class UnetModelBlock : public GGMLBlock {
protected:
SDVersion version = VERSION_SD1;
// network hparams
int in_channels = 4;
int out_channels = 4;
int num_res_blocks = 2;
std::vector<int> attention_resolutions = {4, 2, 1};
std::vector<int> channel_mult = {1, 2, 4, 4};
std::vector<int> transformer_depth = {1, 1, 1, 1};
int time_embed_dim = 1280; // model_channels*4
int num_heads = 8;
int num_head_channels = -1; // channels // num_heads
int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL
bool use_linear_projection = false;
bool tiny_unet = false;
public:
int model_channels = 320;
int adm_in_channels = 2816; // only for VERSION_SDXL/SVD
UNetConfig config;
UnetModelBlock(SDVersion version = VERSION_SD1, const String2TensorStorage& tensor_storage_map = {})
: version(version) {
if (sd_version_is_sd2(version)) {
context_dim = 1024;
num_head_channels = 64;
num_heads = -1;
use_linear_projection = true;
} else if (sd_version_is_sdxl(version)) {
context_dim = 2048;
attention_resolutions = {4, 2};
channel_mult = {1, 2, 4};
transformer_depth = {1, 2, 10};
num_head_channels = 64;
num_heads = -1;
use_linear_projection = true;
if (version == VERSION_SDXL_VEGA) {
transformer_depth = {1, 1, 2};
}
} else if (version == VERSION_SVD) {
in_channels = 8;
out_channels = 4;
context_dim = 1024;
adm_in_channels = 768;
num_head_channels = 64;
num_heads = -1;
use_linear_projection = true;
}
if (sd_version_is_inpaint(version)) {
in_channels = 9;
} else if (sd_version_is_unet_edit(version)) {
in_channels = 8;
}
if (version == VERSION_SD1_TINY_UNET || version == VERSION_SD2_TINY_UNET || version == VERSION_SDXS_512_DS || version == VERSION_SDXS_09) {
num_res_blocks = 1;
channel_mult = {1, 2, 4};
tiny_unet = true;
if (version == VERSION_SDXS_512_DS) {
attention_resolutions = {4, 2}; // here just like SDXL
}
}
explicit UnetModelBlock(UNetConfig config = {})
: config(config) {
const SDVersion version = this->config.version;
const int in_channels = this->config.in_channels;
const int out_channels = this->config.out_channels;
const int num_res_blocks = this->config.num_res_blocks;
const auto& attention_resolutions = this->config.attention_resolutions;
const auto& channel_mult = this->config.channel_mult;
const auto& transformer_depth = this->config.transformer_depth;
const int time_embed_dim = this->config.time_embed_dim;
const int num_heads = this->config.num_heads;
const int num_head_channels = this->config.num_head_channels;
const int context_dim = this->config.context_dim;
const bool use_linear_projection = this->config.use_linear_projection;
const bool tiny_unet = this->config.tiny_unet;
const int model_channels = this->config.model_channels;
const int adm_in_channels = this->config.adm_in_channels;
// dims is always 2
// use_temporal_attention is always True for SVD
@ -398,7 +480,7 @@ public:
ggml_tensor* x,
ggml_tensor* emb,
int num_video_frames) {
if (version == VERSION_SVD) {
if (config.version == VERSION_SVD) {
auto block = std::dynamic_pointer_cast<VideoResBlock>(blocks[name]);
return block->forward(ctx, x, emb, num_video_frames);
@ -414,7 +496,7 @@ public:
ggml_tensor* x,
ggml_tensor* context,
int timesteps) {
if (version == VERSION_SVD) {
if (config.version == VERSION_SVD) {
auto block = std::dynamic_pointer_cast<SpatialVideoTransformer>(blocks[name]);
return block->forward(ctx, x, context, timesteps);
@ -440,6 +522,13 @@ public:
// c_concat: [N, in_channels, h, w] or [1, in_channels, h, w]
// y: [N, adm_in_channels] or [1, adm_in_channels]
// return: [N, out_channels, h, w]
const SDVersion version = config.version;
const int model_channels = config.model_channels;
const int num_res_blocks = config.num_res_blocks;
const auto& attention_resolutions = config.attention_resolutions;
const auto& channel_mult = config.channel_mult;
const bool tiny_unet = config.tiny_unet;
if (context != nullptr) {
if (context->ne[2] != x->ne[3]) {
context = ggml_repeat(ctx->ggml_ctx, context, ggml_new_tensor_3d(ctx->ggml_ctx, GGML_TYPE_F32, context->ne[0], context->ne[1], x->ne[3]));
@ -601,6 +690,7 @@ public:
};
struct UNetModelRunner : public DiffusionModelRunner {
UNetConfig config;
UnetModelBlock unet;
UNetModelRunner(ggml_backend_t backend,
@ -608,7 +698,9 @@ struct UNetModelRunner : public DiffusionModelRunner {
const String2TensorStorage& tensor_storage_map,
const std::string prefix,
SDVersion version = VERSION_SD1)
: DiffusionModelRunner(backend, params_backend, prefix), unet(version, tensor_storage_map) {
: DiffusionModelRunner(backend, params_backend, prefix),
config(UNetConfig::detect_from_weights(tensor_storage_map, prefix, version)),
unet(config) {
unet.init(params_ctx, tensor_storage_map, prefix);
}

View File

@ -16,6 +16,77 @@ namespace WAN {
constexpr int CACHE_T = 2;
constexpr int WAN_GRAPH_SIZE = 10240;
// Hyper-parameters consumed by the Wan model and WanRunner.
// detect_from_weights() fills in the fields that can be read off tensor names;
// the remaining ones keep the defaults below unless the caller adjusts them.
struct WanConfig {
    std::string model_type                = "t2v";
    std::tuple<int, int, int> patch_size  = {1, 2, 2};
    int64_t text_len                      = 512;
    int64_t in_dim                        = 16;
    int64_t dim                           = 2048;
    int64_t ffn_dim                       = 8192;
    int freq_dim                          = 256;
    int64_t text_dim                      = 4096;
    int64_t out_dim                       = 16;
    int64_t num_heads                     = 16;
    int num_layers                        = 32;
    int vace_layers                       = 0;
    int64_t vace_in_dim                   = 96;
    std::map<int, int> vace_layers_mapping = {};
    bool qk_norm                          = true;
    bool cross_attn_norm                  = true;
    float eps                             = 1e-6f;
    int64_t flf_pos_embed_token_number    = 0;
    int theta                             = 10000;
    // wan2.1 1.3B: 1536/12, wan2.1/2.2 14B: 5120/40, wan2.2 5B: 3072/24
    std::vector<int> axes_dim = {44, 42, 42};
    int64_t axes_dim_sum      = 128;

    // Builds a config by scanning the names in `tensor_storage_map`.
    //
    // Only names that start with `prefix` are considered. For those:
    //   - names containing "vace_blocks." raise vace_layers to (index + 1);
    //   - other names containing "blocks." raise num_layers to (index + 1);
    //   - any remaining name containing "img_emb" switches model_type to "i2v",
    //     and "img_emb.emb_pos" additionally sets flf_pos_embed_token_number to 514.
    // num_layers is 0 when no "blocks." name is found.
    static WanConfig detect_from_weights(const String2TensorStorage& tensor_storage_map, const std::string& prefix) {
        WanConfig config;
        config.num_layers = 0;  // rebuilt from the tensor names below

        // `indexed_name` looks like "<group>.<index>...."; grows `layer_count`
        // so that it covers that index.
        const auto cover_block_index = [](const std::string& indexed_name, int& layer_count) {
            const auto parts = split_string(indexed_name, '.');
            if (parts.size() > 1) {
                const int block_index = atoi(parts[1].c_str());
                if (block_index + 1 > layer_count) {
                    layer_count = block_index + 1;
                }
            }
        };

        for (const auto& [tensor_name, _] : tensor_storage_map) {
            if (!starts_with(tensor_name, prefix)) {
                continue;
            }

            // "vace_blocks." must be tested first: it also contains "blocks.".
            const size_t vace_pos = tensor_name.find("vace_blocks.");
            if (vace_pos != std::string::npos) {
                cover_block_index(tensor_name.substr(vace_pos), config.vace_layers);
                continue;
            }

            const size_t block_pos = tensor_name.find("blocks.");
            if (block_pos != std::string::npos) {
                cover_block_index(tensor_name.substr(block_pos), config.num_layers);
                continue;
            }

            if (tensor_name.find("img_emb") != std::string::npos) {
                config.model_type = "i2v";
                if (tensor_name.find("img_emb.emb_pos") != std::string::npos) {
                    config.flf_pos_embed_token_number = 514;
                }
            }
        }

        LOG_DEBUG("wan: model_type = %s, num_layers = %d, vace_layers = %d, dim = %" PRId64 ", ffn_dim = %" PRId64 ", num_heads = %" PRId64,
                  config.model_type.c_str(),
                  config.num_layers,
                  config.vace_layers,
                  config.dim,
                  config.ffn_dim,
                  config.num_heads);
        return config;
    }
};
class CausalConv3d : public GGMLBlock {
protected:
int64_t in_channels;
@ -1799,97 +1870,72 @@ namespace WAN {
}
};
// Plain bundle of Wan hyper-parameters; every member carries a default.
struct WanParams {
    std::string model_type{"t2v"};
    std::tuple<int, int, int> patch_size{1, 2, 2};
    int64_t text_len{512};
    int64_t in_dim{16};
    int64_t dim{2048};
    int64_t ffn_dim{8192};
    int freq_dim{256};
    int64_t text_dim{4096};
    int64_t out_dim{16};
    int64_t num_heads{16};
    int num_layers{32};
    int vace_layers{0};
    int64_t vace_in_dim{96};
    std::map<int, int> vace_layers_mapping{};
    bool qk_norm{true};
    bool cross_attn_norm{true};
    float eps{1e-6f};
    int64_t flf_pos_embed_token_number{0};
    int theta{10000};
    // wan2.1 1.3B: 1536/12, wan2.1/2.2 14B: 5120/40, wan2.2 5B: 3072/24
    std::vector<int> axes_dim{44, 42, 42};
    int64_t axes_dim_sum{128};
};
class Wan : public GGMLBlock {
protected:
WanParams params;
WanConfig config;
public:
Wan() {}
Wan(WanParams params)
: params(params) {
Wan(WanConfig config)
: config(config) {
// patch_embedding
blocks["patch_embedding"] = std::shared_ptr<GGMLBlock>(new Conv3d(params.in_dim, params.dim, params.patch_size, params.patch_size));
blocks["patch_embedding"] = std::shared_ptr<GGMLBlock>(new Conv3d(config.in_dim, config.dim, config.patch_size, config.patch_size));
// text_embedding
blocks["text_embedding.0"] = std::shared_ptr<GGMLBlock>(new Linear(params.text_dim, params.dim));
blocks["text_embedding.0"] = std::shared_ptr<GGMLBlock>(new Linear(config.text_dim, config.dim));
// text_embedding.1 is nn.GELU()
blocks["text_embedding.2"] = std::shared_ptr<GGMLBlock>(new Linear(params.dim, params.dim));
blocks["text_embedding.2"] = std::shared_ptr<GGMLBlock>(new Linear(config.dim, config.dim));
// time_embedding
blocks["time_embedding.0"] = std::shared_ptr<GGMLBlock>(new Linear(params.freq_dim, params.dim));
blocks["time_embedding.0"] = std::shared_ptr<GGMLBlock>(new Linear(config.freq_dim, config.dim));
// time_embedding.1 is nn.SiLU()
blocks["time_embedding.2"] = std::shared_ptr<GGMLBlock>(new Linear(params.dim, params.dim));
blocks["time_embedding.2"] = std::shared_ptr<GGMLBlock>(new Linear(config.dim, config.dim));
// time_projection.0 is nn.SiLU()
blocks["time_projection.1"] = std::shared_ptr<GGMLBlock>(new Linear(params.dim, params.dim * 6));
blocks["time_projection.1"] = std::shared_ptr<GGMLBlock>(new Linear(config.dim, config.dim * 6));
// blocks
for (int i = 0; i < params.num_layers; i++) {
auto block = std::shared_ptr<GGMLBlock>(new WanAttentionBlock(params.model_type == "t2v",
params.dim,
params.ffn_dim,
params.num_heads,
params.qk_norm,
params.cross_attn_norm,
params.eps));
for (int i = 0; i < config.num_layers; i++) {
auto block = std::shared_ptr<GGMLBlock>(new WanAttentionBlock(config.model_type == "t2v",
config.dim,
config.ffn_dim,
config.num_heads,
config.qk_norm,
config.cross_attn_norm,
config.eps));
blocks["blocks." + std::to_string(i)] = block;
}
// head
blocks["head"] = std::shared_ptr<GGMLBlock>(new Head(params.dim, params.out_dim, params.patch_size, params.eps));
blocks["head"] = std::shared_ptr<GGMLBlock>(new Head(config.dim, config.out_dim, config.patch_size, config.eps));
// img_emb
if (params.model_type == "i2v") {
blocks["img_emb"] = std::shared_ptr<GGMLBlock>(new MLPProj(1280, params.dim, params.flf_pos_embed_token_number));
if (config.model_type == "i2v") {
blocks["img_emb"] = std::shared_ptr<GGMLBlock>(new MLPProj(1280, config.dim, config.flf_pos_embed_token_number));
}
// vace
if (params.vace_layers > 0) {
for (int i = 0; i < params.vace_layers; i++) {
auto block = std::shared_ptr<GGMLBlock>(new VaceWanAttentionBlock(params.model_type == "t2v",
params.dim,
params.ffn_dim,
params.num_heads,
params.qk_norm,
params.cross_attn_norm,
params.eps,
if (config.vace_layers > 0) {
for (int i = 0; i < config.vace_layers; i++) {
auto block = std::shared_ptr<GGMLBlock>(new VaceWanAttentionBlock(config.model_type == "t2v",
config.dim,
config.ffn_dim,
config.num_heads,
config.qk_norm,
config.cross_attn_norm,
config.eps,
i));
blocks["vace_blocks." + std::to_string(i)] = block;
}
int step = params.num_layers / params.vace_layers;
int step = config.num_layers / config.vace_layers;
int n = 0;
for (int i = 0; i < params.num_layers; i += step) {
this->params.vace_layers_mapping[i] = n;
for (int i = 0; i < config.num_layers; i += step) {
this->config.vace_layers_mapping[i] = n;
n++;
}
blocks["vace_patch_embedding"] = std::shared_ptr<GGMLBlock>(new Conv3d(params.vace_in_dim, params.dim, params.patch_size, params.patch_size));
blocks["vace_patch_embedding"] = std::shared_ptr<GGMLBlock>(new Conv3d(config.vace_in_dim, config.dim, config.patch_size, config.patch_size));
}
}
@ -1899,9 +1945,9 @@ namespace WAN {
int64_t H = x->ne[1];
int64_t T = x->ne[2];
int pad_t = (std::get<0>(params.patch_size) - T % std::get<0>(params.patch_size)) % std::get<0>(params.patch_size);
int pad_h = (std::get<1>(params.patch_size) - H % std::get<1>(params.patch_size)) % std::get<1>(params.patch_size);
int pad_w = (std::get<2>(params.patch_size) - W % std::get<2>(params.patch_size)) % std::get<2>(params.patch_size);
int pad_t = (std::get<0>(config.patch_size) - T % std::get<0>(config.patch_size)) % std::get<0>(config.patch_size);
int pad_h = (std::get<1>(config.patch_size) - H % std::get<1>(config.patch_size)) % std::get<1>(config.patch_size);
int pad_w = (std::get<2>(config.patch_size) - W % std::get<2>(config.patch_size)) % std::get<2>(config.patch_size);
ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, pad_t, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
return x;
}
@ -1914,9 +1960,9 @@ namespace WAN {
// x: [N, t_len*h_len*w_len, pt*ph*pw*C]
// return: [N*C, t_len*pt, h_len*ph, w_len*pw]
int64_t N = x->ne[3];
int64_t pt = std::get<0>(params.patch_size);
int64_t ph = std::get<1>(params.patch_size);
int64_t pw = std::get<2>(params.patch_size);
int64_t pt = std::get<0>(config.patch_size);
int64_t ph = std::get<1>(config.patch_size);
int64_t pw = std::get<2>(config.patch_size);
int64_t C = x->ne[0] / pt / ph / pw;
GGML_ASSERT(C * pt * ph * pw == x->ne[0]);
@ -1967,7 +2013,7 @@ namespace WAN {
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [N, t_len*h_len*w_len, dim]
// time_embedding
auto e = ggml_ext_timestep_embedding(ctx->ggml_ctx, timestep, params.freq_dim);
auto e = ggml_ext_timestep_embedding(ctx->ggml_ctx, timestep, config.freq_dim);
e = time_embedding_0->forward(ctx, e);
e = ggml_silu_inplace(ctx->ggml_ctx, e);
e = time_embedding_2->forward(ctx, e); // [N, dim] or [N, T, dim]
@ -1983,7 +2029,7 @@ namespace WAN {
int64_t context_img_len = 0;
if (clip_fea != nullptr) {
if (params.model_type == "i2v") {
if (config.model_type == "i2v") {
auto img_emb = std::dynamic_pointer_cast<MLPProj>(blocks["img_emb"]);
auto context_img = img_emb->forward(ctx, clip_fea); // [N, context_img_len, dim]
context = ggml_concat(ctx->ggml_ctx, context_img, context, 1); // [N, context_img_len + context_txt_len, dim]
@ -1993,7 +2039,7 @@ namespace WAN {
// vace_patch_embedding
ggml_tensor* c = nullptr;
if (params.vace_layers > 0) {
if (config.vace_layers > 0) {
auto vace_patch_embedding = std::dynamic_pointer_cast<Conv3d>(blocks["vace_patch_embedding"]);
c = vace_patch_embedding->forward(ctx, vace_context); // [N*dim, t_len, h_len, w_len]
@ -2010,13 +2056,13 @@ namespace WAN {
auto x_orig = x;
for (int i = 0; i < params.num_layers; i++) {
for (int i = 0; i < config.num_layers; i++) {
auto block = std::dynamic_pointer_cast<WanAttentionBlock>(blocks["blocks." + std::to_string(i)]);
x = block->forward(ctx, x, e0, pe, context, context_img_len);
auto iter = params.vace_layers_mapping.find(i);
if (iter != params.vace_layers_mapping.end()) {
auto iter = config.vace_layers_mapping.find(i);
if (iter != config.vace_layers_mapping.end()) {
int n = iter->second;
auto vace_block = std::dynamic_pointer_cast<VaceWanAttentionBlock>(blocks["vace_blocks." + std::to_string(n)]);
@ -2065,14 +2111,14 @@ namespace WAN {
x = pad_to_patch_size(ctx, x);
int64_t t_len = ((T + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size));
int64_t h_len = ((H + (std::get<1>(params.patch_size) / 2)) / std::get<1>(params.patch_size));
int64_t w_len = ((W + (std::get<2>(params.patch_size) / 2)) / std::get<2>(params.patch_size));
int64_t t_len = ((T + (std::get<0>(config.patch_size) / 2)) / std::get<0>(config.patch_size));
int64_t h_len = ((H + (std::get<1>(config.patch_size) / 2)) / std::get<1>(config.patch_size));
int64_t w_len = ((W + (std::get<2>(config.patch_size) / 2)) / std::get<2>(config.patch_size));
if (time_dim_concat != nullptr) {
time_dim_concat = pad_to_patch_size(ctx, time_dim_concat);
x = ggml_concat(ctx->ggml_ctx, x, time_dim_concat, 2); // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w]
t_len = ((x->ne[2] + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size));
t_len = ((x->ne[2] + (std::get<0>(config.patch_size) / 2)) / std::get<0>(config.patch_size));
}
auto out = forward_orig(ctx, x, timestep, context, pe, clip_fea, vace_context, vace_strength, N); // [N, t_len*h_len*w_len, pt*ph*pw*C]
@ -2092,7 +2138,7 @@ namespace WAN {
struct WanRunner : public DiffusionModelRunner {
public:
std::string desc = "wan";
WanParams wan_params;
WanConfig config;
Wan wan;
std::vector<float> pe_vec;
SDVersion version;
@ -2102,109 +2148,73 @@ namespace WAN {
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "",
SDVersion version = VERSION_WAN2)
: DiffusionModelRunner(backend, params_backend, prefix) {
wan_params.num_layers = 0;
for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first;
if (tensor_name.find(prefix) == std::string::npos)
continue;
size_t pos = tensor_name.find("vace_blocks.");
if (pos != std::string::npos) {
tensor_name = tensor_name.substr(pos); // remove prefix
auto items = split_string(tensor_name, '.');
if (items.size() > 1) {
int block_index = atoi(items[1].c_str());
if (block_index + 1 > wan_params.vace_layers) {
wan_params.vace_layers = block_index + 1;
}
}
continue;
}
pos = tensor_name.find("blocks.");
if (pos != std::string::npos) {
tensor_name = tensor_name.substr(pos); // remove prefix
auto items = split_string(tensor_name, '.');
if (items.size() > 1) {
int block_index = atoi(items[1].c_str());
if (block_index + 1 > wan_params.num_layers) {
wan_params.num_layers = block_index + 1;
}
}
continue;
}
if (tensor_name.find("img_emb") != std::string::npos) {
wan_params.model_type = "i2v";
}
if (tensor_name.find("img_emb.emb_pos") != std::string::npos) {
wan_params.flf_pos_embed_token_number = 514;
}
}
if (wan_params.num_layers == 30) {
: DiffusionModelRunner(backend, params_backend, prefix),
config(WanConfig::detect_from_weights(tensor_storage_map, prefix)) {
if (config.num_layers == 30) {
if (version == VERSION_WAN2_2_TI2V) {
desc = "Wan2.2-TI2V-5B";
wan_params.dim = 3072;
wan_params.eps = 1e-06f;
wan_params.ffn_dim = 14336;
wan_params.freq_dim = 256;
wan_params.in_dim = 48;
wan_params.num_heads = 24;
wan_params.out_dim = 48;
wan_params.text_len = 512;
desc = "Wan2.2-TI2V-5B";
config.dim = 3072;
config.eps = 1e-06f;
config.ffn_dim = 14336;
config.freq_dim = 256;
config.in_dim = 48;
config.num_heads = 24;
config.out_dim = 48;
config.text_len = 512;
} else {
if (wan_params.vace_layers > 0) {
desc = "Wan2.1-VACE-1.3B";
wan_params.in_dim = 16;
} else if (wan_params.model_type == "i2v") {
desc = "Wan2.1-I2V-1.3B";
wan_params.in_dim = 36;
if (config.vace_layers > 0) {
desc = "Wan2.1-VACE-1.3B";
config.in_dim = 16;
} else if (config.model_type == "i2v") {
desc = "Wan2.1-I2V-1.3B";
config.in_dim = 36;
} else {
desc = "Wan2.1-T2V-1.3B";
wan_params.in_dim = 16;
desc = "Wan2.1-T2V-1.3B";
config.in_dim = 16;
}
wan_params.dim = 1536;
wan_params.eps = 1e-06f;
wan_params.ffn_dim = 8960;
wan_params.freq_dim = 256;
wan_params.num_heads = 12;
wan_params.out_dim = 16;
wan_params.text_len = 512;
config.dim = 1536;
config.eps = 1e-06f;
config.ffn_dim = 8960;
config.freq_dim = 256;
config.num_heads = 12;
config.out_dim = 16;
config.text_len = 512;
}
} else if (wan_params.num_layers == 40) {
if (wan_params.model_type == "t2v") {
} else if (config.num_layers == 40) {
if (config.model_type == "t2v") {
if (version == VERSION_WAN2_2_I2V) {
desc = "Wan2.2-I2V-14B";
wan_params.in_dim = 36;
desc = "Wan2.2-I2V-14B";
config.in_dim = 36;
} else {
if (wan_params.vace_layers > 0) {
if (config.vace_layers > 0) {
desc = "Wan2.x-VACE-14B";
} else {
desc = "Wan2.x-T2V-14B";
}
wan_params.in_dim = 16;
config.in_dim = 16;
}
} else {
wan_params.in_dim = 36;
if (wan_params.flf_pos_embed_token_number > 0) {
config.in_dim = 36;
if (config.flf_pos_embed_token_number > 0) {
desc = "Wan2.1-FLF2V-14B";
} else {
desc = "Wan2.1-I2V-14B";
}
}
wan_params.dim = 5120;
wan_params.eps = 1e-06f;
wan_params.ffn_dim = 13824;
wan_params.freq_dim = 256;
wan_params.num_heads = 40;
wan_params.out_dim = 16;
wan_params.text_len = 512;
config.dim = 5120;
config.eps = 1e-06f;
config.ffn_dim = 13824;
config.freq_dim = 256;
config.num_heads = 40;
config.out_dim = 16;
config.text_len = 512;
} else {
GGML_ABORT("invalid num_layers(%d) of wan", wan_params.num_layers);
GGML_ABORT("invalid num_layers(%d) of wan", config.num_layers);
}
LOG_INFO("%s", desc.c_str());
wan = Wan(wan_params);
wan = Wan(config);
wan.init(params_ctx, tensor_storage_map, prefix);
}
@ -2237,15 +2247,15 @@ namespace WAN {
pe_vec = Rope::gen_wan_pe(static_cast<int>(x->ne[2]),
static_cast<int>(x->ne[1]),
static_cast<int>(x->ne[0]),
std::get<0>(wan_params.patch_size),
std::get<1>(wan_params.patch_size),
std::get<2>(wan_params.patch_size),
std::get<0>(config.patch_size),
std::get<1>(config.patch_size),
std::get<2>(config.patch_size),
1,
wan_params.theta,
wan_params.axes_dim);
int pos_len = static_cast<int>(pe_vec.size() / wan_params.axes_dim_sum / 2);
config.theta,
config.axes_dim);
int pos_len = static_cast<int>(pe_vec.size() / config.axes_dim_sum / 2);
// LOG_DEBUG("pos_len %d", pos_len);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, wan_params.axes_dim_sum / 2, pos_len);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.axes_dim_sum / 2, pos_len);
// pe->data = pe_vec.data();
// print_ggml_tensor(pe);
// pe->data = nullptr;

View File

@ -20,6 +20,104 @@ namespace ZImage {
constexpr int ADALN_EMBED_DIM = 256;
constexpr int SEQ_MULTI_OF = 32;
// Hyper-parameters of the Z-Image model. The in-class defaults are used for
// every field that detect_from_weights() cannot read from tensor metadata.
struct ZImageConfig {
    int patch_size             = 2;
    int64_t hidden_size        = 3840;
    int64_t in_channels        = 16;
    int64_t out_channels       = 16;
    int64_t num_layers         = 30;
    int64_t num_refiner_layers = 2;
    int64_t head_dim           = 128;
    int64_t num_heads          = 30;
    int64_t num_kv_heads       = 30;
    int64_t multiple_of        = 256;
    float ffn_dim_multiplier   = 8.0f / 3.0f;
    float norm_eps             = 1e-5f;
    bool qk_norm               = true;
    int64_t cap_feat_dim       = 2560;
    int theta                  = 256;
    std::vector<int> axes_dim  = {32, 48, 48};
    int64_t axes_dim_sum       = 128;

    // Derives a config from the shapes and names of the tensors whose name
    // starts with `prefix`. Fields that no matching tensor describes keep
    // their defaults. The resulting values are reported via LOG_DEBUG.
    static ZImageConfig detect_from_weights(const String2TensorStorage& tensor_storage_map, const std::string& prefix) {
        ZImageConfig config;

        int64_t main_depth            = 0;
        int64_t noise_refiner_depth   = 0;
        int64_t context_refiner_depth = 0;
        int64_t head_dim_seen         = 0;
        int64_t qkv_dim_seen          = 0;

        // patch_size is never modified below, so the patch area is loop-invariant.
        const int64_t patch_area = config.patch_size * config.patch_size;

        // If `tensor_name` contains `marker`, parse the '.'-separated component
        // that follows it as a block index and raise `depth` to index + 1.
        auto track_depth = [](const std::string& tensor_name, const char* marker, int64_t& depth) {
            const size_t at = tensor_name.find(marker);
            if (at == std::string::npos) {
                return;
            }
            auto parts = split_string(tensor_name.substr(at), '.');
            if (parts.size() <= 1) {
                return;
            }
            const int index = atoi(parts[1].c_str());
            depth           = std::max<int64_t>(depth, index + 1);
        };

        for (const auto& [name, tensor_storage] : tensor_storage_map) {
            if (!starts_with(name, prefix)) {
                continue;
            }

            const bool is_vector = tensor_storage.n_dims == 1;
            const bool is_matrix = tensor_storage.n_dims == 2;

            if (is_matrix && ends_with(name, "x_embedder.weight")) {
                config.in_channels = tensor_storage.ne[0] / patch_area;
                config.hidden_size = tensor_storage.ne[1];
            } else if (is_matrix && ends_with(name, "cap_embedder.1.weight")) {
                config.cap_feat_dim = tensor_storage.ne[0];
                config.hidden_size  = tensor_storage.ne[1];
            } else if (is_vector && ends_with(name, "layers.0.attention.q_norm.weight")) {
                head_dim_seen = tensor_storage.ne[0];
            } else if (is_matrix && ends_with(name, "layers.0.attention.qkv.weight")) {
                qkv_dim_seen = tensor_storage.ne[1];
            } else if (is_matrix && ends_with(name, "final_layer.linear.weight")) {
                config.out_channels = tensor_storage.ne[1] / patch_area;
            }

            // The block counters are updated for every tensor, independently
            // of which shape rule (if any) matched above.
            track_depth(name, "layers.", main_depth);
            track_depth(name, "noise_refiner.", noise_refiner_depth);
            track_depth(name, "context_refiner.", context_refiner_depth);
        }

        if (main_depth > 0) {
            config.num_layers = main_depth;
        }

        // Both counters are non-negative, so their maximum is positive exactly
        // when at least one refiner stack was found.
        const int64_t refiner_depth = std::max(noise_refiner_depth, context_refiner_depth);
        if (refiner_depth > 0) {
            config.num_refiner_layers = refiner_depth;
        }

        if (head_dim_seen > 0) {
            config.head_dim  = head_dim_seen;
            config.num_heads = config.hidden_size / config.head_dim;
            if (qkv_dim_seen > 0) {
                // The qkv width is treated as (num_heads + 2 * num_kv_heads) * head_dim.
                const int64_t total_heads = qkv_dim_seen / config.head_dim;
                config.num_kv_heads       = std::max<int64_t>(1, (total_heads - config.num_heads) / 2);
            }
        }

        LOG_DEBUG("z_image: num_layers = %" PRId64 ", num_refiner_layers = %" PRId64 ", hidden_size = %" PRId64 ", num_heads = %" PRId64 ", num_kv_heads = %" PRId64 ", in_channels = %" PRId64 ", out_channels = %" PRId64,
                  config.num_layers,
                  config.num_refiner_layers,
                  config.hidden_size,
                  config.num_heads,
                  config.num_kv_heads,
                  config.in_channels,
                  config.out_channels);
        return config;
    }
};
struct JointAttention : public GGMLBlock {
protected:
int64_t head_dim;
@ -263,90 +361,70 @@ namespace ZImage {
}
};
// Hyper-parameters consumed by ZImageModel when it builds its blocks.
// Every field carries an in-class default; nothing in this struct derives
// values from loaded weights.
struct ZImageParams {
// Side length of a patch; x_embedder takes patch_size * patch_size * in_channels inputs.
int patch_size = 2;
// Output width of x_embedder and cap_embedder.1, and the size of the pad tokens.
int64_t hidden_size = 3840;
int64_t in_channels = 16;
int64_t out_channels = 16;
// Number of "layers.<i>" transformer blocks.
int64_t num_layers = 30;
// Number of blocks in each of the "noise_refiner.<i>" and "context_refiner.<i>" stacks.
int64_t num_refiner_layers = 2;
int64_t head_dim = 128;
int64_t num_heads = 30;
int64_t num_kv_heads = 30;
int64_t multiple_of = 256;
float ffn_dim_multiplier = 8.0f / 3.0f;
float norm_eps = 1e-5f;
bool qk_norm = true;
// Input width of cap_embedder.0 (RMSNorm) and cap_embedder.1 (Linear).
int64_t cap_feat_dim = 2560;
// theta and axes_dim are passed to Rope::gen_z_image_pe by the runner.
int theta = 256;
std::vector<int> axes_dim = {32, 48, 48};
// The default equals 32 + 48 + 48; the runner uses it to size the pe tensor.
int64_t axes_dim_sum = 128;
};
class ZImageModel : public GGMLBlock {
protected:
ZImageParams z_image_params;
ZImageConfig config;
void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
params["cap_pad_token"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, z_image_params.hidden_size);
params["x_pad_token"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, z_image_params.hidden_size);
params["cap_pad_token"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, config.hidden_size);
params["x_pad_token"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, config.hidden_size);
}
public:
ZImageModel() = default;
ZImageModel(ZImageParams z_image_params)
: z_image_params(z_image_params) {
blocks["x_embedder"] = std::make_shared<Linear>(z_image_params.patch_size * z_image_params.patch_size * z_image_params.in_channels, z_image_params.hidden_size);
blocks["t_embedder"] = std::make_shared<TimestepEmbedder>(MIN(z_image_params.hidden_size, 1024), 256, 256);
blocks["cap_embedder.0"] = std::make_shared<RMSNorm>(z_image_params.cap_feat_dim, z_image_params.norm_eps);
blocks["cap_embedder.1"] = std::make_shared<Linear>(z_image_params.cap_feat_dim, z_image_params.hidden_size);
ZImageModel(ZImageConfig config)
: config(config) {
blocks["x_embedder"] = std::make_shared<Linear>(config.patch_size * config.patch_size * config.in_channels, config.hidden_size);
blocks["t_embedder"] = std::make_shared<TimestepEmbedder>(MIN(config.hidden_size, 1024), 256, 256);
blocks["cap_embedder.0"] = std::make_shared<RMSNorm>(config.cap_feat_dim, config.norm_eps);
blocks["cap_embedder.1"] = std::make_shared<Linear>(config.cap_feat_dim, config.hidden_size);
for (int i = 0; i < z_image_params.num_refiner_layers; i++) {
for (int i = 0; i < config.num_refiner_layers; i++) {
auto block = std::make_shared<JointTransformerBlock>(i,
z_image_params.hidden_size,
z_image_params.head_dim,
z_image_params.num_heads,
z_image_params.num_kv_heads,
z_image_params.multiple_of,
z_image_params.ffn_dim_multiplier,
z_image_params.norm_eps,
z_image_params.qk_norm,
config.hidden_size,
config.head_dim,
config.num_heads,
config.num_kv_heads,
config.multiple_of,
config.ffn_dim_multiplier,
config.norm_eps,
config.qk_norm,
true);
blocks["noise_refiner." + std::to_string(i)] = block;
}
for (int i = 0; i < z_image_params.num_refiner_layers; i++) {
for (int i = 0; i < config.num_refiner_layers; i++) {
auto block = std::make_shared<JointTransformerBlock>(i,
z_image_params.hidden_size,
z_image_params.head_dim,
z_image_params.num_heads,
z_image_params.num_kv_heads,
z_image_params.multiple_of,
z_image_params.ffn_dim_multiplier,
z_image_params.norm_eps,
z_image_params.qk_norm,
config.hidden_size,
config.head_dim,
config.num_heads,
config.num_kv_heads,
config.multiple_of,
config.ffn_dim_multiplier,
config.norm_eps,
config.qk_norm,
false);
blocks["context_refiner." + std::to_string(i)] = block;
}
for (int i = 0; i < z_image_params.num_layers; i++) {
for (int i = 0; i < config.num_layers; i++) {
auto block = std::make_shared<JointTransformerBlock>(i,
z_image_params.hidden_size,
z_image_params.head_dim,
z_image_params.num_heads,
z_image_params.num_kv_heads,
z_image_params.multiple_of,
z_image_params.ffn_dim_multiplier,
z_image_params.norm_eps,
z_image_params.qk_norm,
config.hidden_size,
config.head_dim,
config.num_heads,
config.num_kv_heads,
config.multiple_of,
config.ffn_dim_multiplier,
config.norm_eps,
config.qk_norm,
true);
blocks["layers." + std::to_string(i)] = block;
}
blocks["final_layer"] = std::make_shared<FinalLayer>(z_image_params.hidden_size, z_image_params.patch_size, z_image_params.out_channels);
blocks["final_layer"] = std::make_shared<FinalLayer>(config.hidden_size, config.patch_size, config.out_channels);
}
ggml_tensor* forward_core(GGMLRunnerContext* ctx,
@ -393,14 +471,14 @@ namespace ZImage {
auto txt_pe = ggml_ext_slice(ctx->ggml_ctx, pe, 3, 0, txt->ne[1]);
auto img_pe = ggml_ext_slice(ctx->ggml_ctx, pe, 3, txt->ne[1], pe->ne[3]);
for (int i = 0; i < z_image_params.num_refiner_layers; i++) {
for (int i = 0; i < config.num_refiner_layers; i++) {
auto block = std::dynamic_pointer_cast<JointTransformerBlock>(blocks["context_refiner." + std::to_string(i)]);
txt = block->forward(ctx, txt, txt_pe, nullptr, nullptr);
sd::ggml_graph_cut::mark_graph_cut(txt, "z_image.context_refiner." + std::to_string(i), "txt");
}
for (int i = 0; i < z_image_params.num_refiner_layers; i++) {
for (int i = 0; i < config.num_refiner_layers; i++) {
auto block = std::dynamic_pointer_cast<JointTransformerBlock>(blocks["noise_refiner." + std::to_string(i)]);
img = block->forward(ctx, img, img_pe, nullptr, t_emb);
@ -410,7 +488,7 @@ namespace ZImage {
auto txt_img = ggml_concat(ctx->ggml_ctx, txt, img, 1); // [N, n_txt_token + n_txt_pad_token + n_img_token + n_img_pad_token, hidden_size]
sd::ggml_graph_cut::mark_graph_cut(txt_img, "z_image.prelude", "txt_img");
for (int i = 0; i < z_image_params.num_layers; i++) {
for (int i = 0; i < config.num_layers; i++) {
auto block = std::dynamic_pointer_cast<JointTransformerBlock>(blocks["layers." + std::to_string(i)]);
txt_img = block->forward(ctx, txt_img, pe, nullptr, t_emb);
@ -442,7 +520,7 @@ namespace ZImage {
int64_t C = x->ne[2];
int64_t N = x->ne[3];
int patch_size = z_image_params.patch_size;
int patch_size = config.patch_size;
auto img = DiT::pad_and_patchify(ctx, x, patch_size, patch_size, false);
uint64_t n_img_token = img->ne[1];
@ -467,7 +545,7 @@ namespace ZImage {
struct ZImageRunner : public DiffusionModelRunner {
public:
ZImageParams z_image_params;
ZImageConfig config;
ZImageModel z_image;
std::vector<float> pe_vec;
std::vector<float> timestep_vec;
@ -478,8 +556,9 @@ namespace ZImage {
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "",
SDVersion version = VERSION_Z_IMAGE)
: DiffusionModelRunner(backend, params_backend, prefix) {
z_image = ZImageModel(z_image_params);
: DiffusionModelRunner(backend, params_backend, prefix),
config(ZImageConfig::detect_from_weights(tensor_storage_map, prefix)) {
z_image = ZImageModel(config);
z_image.init(params_ctx, tensor_storage_map, prefix);
}
@ -510,19 +589,19 @@ namespace ZImage {
pe_vec = Rope::gen_z_image_pe(static_cast<int>(x->ne[1]),
static_cast<int>(x->ne[0]),
z_image_params.patch_size,
config.patch_size,
static_cast<int>(x->ne[3]),
static_cast<int>(context->ne[1]),
SEQ_MULTI_OF,
ref_latents,
increase_ref_index,
z_image_params.theta,
config.theta,
circular_y_enabled,
circular_x_enabled,
z_image_params.axes_dim);
int pos_len = static_cast<int>(pe_vec.size() / z_image_params.axes_dim_sum / 2);
config.axes_dim);
int pos_len = static_cast<int>(pe_vec.size() / config.axes_dim_sum / 2);
// LOG_DEBUG("pos_len %d", pos_len);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, z_image_params.axes_dim_sum / 2, pos_len);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.axes_dim_sum / 2, pos_len);
// pe->data = pe_vec.data();
// print_ggml_tensor(pe, true, "pe");
// pe->data = nullptr;