Merge branch 'master' into vace

This commit is contained in:
leejet 2025-09-08 22:51:42 +08:00
commit 4b9bf2b513
21 changed files with 344 additions and 152 deletions

.gitignore vendored
View File

@@ -4,6 +4,7 @@ test/
 .cache/
 *.swp
 .vscode/
+.idea/
 *.bat
 *.bin
 *.exe

View File

@@ -137,7 +137,9 @@ This provides BLAS acceleration using the ROCm cores of your AMD GPU. Make sure
 Windows User Refer to [docs/hipBLAS_on_Windows.md](docs%2FhipBLAS_on_Windows.md) for a comprehensive guide.
 ```
-cmake .. -G "Ninja" -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DSD_HIPBLAS=ON -DCMAKE_BUILD_TYPE=Release -DAMDGPU_TARGETS=gfx1100 -DCMAKE_BUILD_WITH_INSTALL_RPATH=ON
+export GFX_NAME=$(rocminfo | grep -m 1 -E "gfx[^0]{1}" | sed -e 's/ *Name: *//' | awk '{$1=$1; print}' || echo "rocminfo missing")
+echo $GFX_NAME
+cmake .. -G "Ninja" -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DSD_HIPBLAS=ON -DCMAKE_BUILD_TYPE=Release -DGPU_TARGETS=$GFX_NAME -DCMAKE_BUILD_WITH_INSTALL_RPATH=ON
 cmake --build . --config Release
 ```
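
If `rocminfo` is unavailable, the detection above just echoes the placeholder `rocminfo missing`; in that case the target can be set by hand before running cmake (a manual override, using the `gfx1100` target from the previously hardcoded flag as an example):

```
export GFX_NAME=gfx1100
```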
@@ -319,6 +321,7 @@ arguments:
 -i, --end-img [IMAGE] path to the end image, required by flf2v
 --control-image [IMAGE] path to image condition, control net
 -r, --ref-image [PATH] reference image for Flux Kontext models (can be used multiple times)
+--increase-ref-index automatically increase the indices of reference images based on the order they are listed (starting with 1).
 -o, --output OUTPUT path to write result image to (default: ./output.png)
 -p, --prompt [PROMPT] the prompt to render
 -n, --negative-prompt PROMPT the negative prompt (default: "")
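
As an illustration, a hypothetical invocation combining these flags (model-loading flags elided; image paths are placeholders):

```
sd -r ./ref1.png -r ./ref2.png --increase-ref-index \
   -p "a scene combining both references" -o ./output/result.png
```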

View File

@@ -488,14 +488,14 @@ public:
         blocks["mlp"] = std::shared_ptr<GGMLBlock>(new CLIPMLP(d_model, intermediate_size));
     }

-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, bool mask = true) {
+    struct ggml_tensor* forward(struct ggml_context* ctx, ggml_backend_t backend, struct ggml_tensor* x, bool mask = true) {
         // x: [N, n_token, d_model]
         auto self_attn = std::dynamic_pointer_cast<MultiheadAttention>(blocks["self_attn"]);
         auto layer_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm1"]);
         auto layer_norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm2"]);
         auto mlp = std::dynamic_pointer_cast<CLIPMLP>(blocks["mlp"]);

-        x = ggml_add(ctx, x, self_attn->forward(ctx, layer_norm1->forward(ctx, x), mask));
+        x = ggml_add(ctx, x, self_attn->forward(ctx, backend, layer_norm1->forward(ctx, x), mask));
         x = ggml_add(ctx, x, mlp->forward(ctx, layer_norm2->forward(ctx, x)));
         return x;
     }
@@ -517,7 +517,11 @@ public:
         }
     }

-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, int clip_skip = -1, bool mask = true) {
+    struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
+                                struct ggml_tensor* x,
+                                int clip_skip = -1,
+                                bool mask = true) {
         // x: [N, n_token, d_model]
         int layer_idx = n_layer - 1;
         // LOG_DEBUG("clip_skip %d", clip_skip);
@@ -532,7 +536,7 @@ public:
             }
             std::string name = "layers." + std::to_string(i);
             auto layer = std::dynamic_pointer_cast<CLIPLayer>(blocks[name]);
-            x = layer->forward(ctx, x, mask); // [N, n_token, d_model]
+            x = layer->forward(ctx, backend, x, mask); // [N, n_token, d_model]
             // LOG_DEBUG("layer %d", i);
         }
         return x;
@@ -712,6 +716,7 @@ public:
     }

     struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
                                 struct ggml_tensor* input_ids,
                                 struct ggml_tensor* tkn_embeddings,
                                 size_t max_token_idx = 0,
@@ -722,7 +727,7 @@ public:
         auto final_layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["final_layer_norm"]);

         auto x = embeddings->forward(ctx, input_ids, tkn_embeddings); // [N, n_token, hidden_size]
-        x = encoder->forward(ctx, x, return_pooled ? -1 : clip_skip, true);
+        x = encoder->forward(ctx, backend, x, return_pooled ? -1 : clip_skip, true);
         if (return_pooled || with_final_ln) {
             x = final_layer_norm->forward(ctx, x);
         }
@@ -775,6 +780,7 @@ public:
     }

     struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
                                 struct ggml_tensor* pixel_values,
                                 bool return_pooled = true,
                                 int clip_skip = -1) {
@@ -786,7 +792,7 @@ public:
         auto x = embeddings->forward(ctx, pixel_values); // [N, num_positions, embed_dim]
         x = pre_layernorm->forward(ctx, x);
-        x = encoder->forward(ctx, x, clip_skip, false);
+        x = encoder->forward(ctx, backend, x, clip_skip, false);
         // print_ggml_tensor(x, true, "ClipVisionModel x: ");
         auto last_hidden_state = x;
         x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size]
@@ -855,6 +861,7 @@ public:
     }

     struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
                                 struct ggml_tensor* pixel_values,
                                 bool return_pooled = true,
                                 int clip_skip = -1) {
@@ -863,7 +870,7 @@ public:
         auto vision_model = std::dynamic_pointer_cast<CLIPVisionModel>(blocks["vision_model"]);
         auto visual_projection = std::dynamic_pointer_cast<CLIPProjection>(blocks["visual_projection"]);

-        auto x = vision_model->forward(ctx, pixel_values, return_pooled, clip_skip); // [N, hidden_size] or [N, n_token, hidden_size]
+        auto x = vision_model->forward(ctx, backend, pixel_values, return_pooled, clip_skip); // [N, hidden_size] or [N, n_token, hidden_size]

         if (return_pooled) {
             x = visual_projection->forward(ctx, x); // [N, projection_dim]
@@ -900,6 +907,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
     }

     struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
                                 struct ggml_tensor* input_ids,
                                 struct ggml_tensor* embeddings,
                                 size_t max_token_idx = 0,
@@ -911,7 +919,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
             input_ids = ggml_reshape_2d(ctx, input_ids, model.n_token, input_ids->ne[0] / model.n_token);
         }

-        return model.forward(ctx, input_ids, embeddings, max_token_idx, return_pooled);
+        return model.forward(ctx, backend, input_ids, embeddings, max_token_idx, return_pooled);
     }

     struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids,
@@ -937,7 +945,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
             embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 1);
         }

-        struct ggml_tensor* hidden_states = forward(compute_ctx, input_ids, embeddings, max_token_idx, return_pooled);
+        struct ggml_tensor* hidden_states = forward(compute_ctx, runtime_backend, input_ids, embeddings, max_token_idx, return_pooled);

         ggml_build_forward_expand(gf, hidden_states);

View File

@@ -270,7 +270,10 @@ public:
         // to_out_1 is nn.Dropout(), skip for inference
     }

-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) {
+    struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
+                                struct ggml_tensor* x,
+                                struct ggml_tensor* context) {
         // x: [N, n_token, query_dim]
         // context: [N, n_context, context_dim]
         // return: [N, n_token, query_dim]
@@ -288,7 +291,7 @@ public:
         auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim]
         auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim]

-        x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false, false, flash_attn); // [N, n_token, inner_dim]
+        x = ggml_nn_attention_ext(ctx, backend, q, k, v, n_head, NULL, false, false, flash_attn); // [N, n_token, inner_dim]

         x = to_out_0->forward(ctx, x); // [N, n_token, query_dim]
         return x;
@@ -327,7 +330,10 @@ public:
         }
     }

-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) {
+    struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
+                                struct ggml_tensor* x,
+                                struct ggml_tensor* context) {
         // x: [N, n_token, query_dim]
         // context: [N, n_context, context_dim]
         // return: [N, n_token, query_dim]
@@ -352,11 +358,11 @@ public:
         auto r = x;
         x = norm1->forward(ctx, x);
-        x = attn1->forward(ctx, x, x); // self-attention
+        x = attn1->forward(ctx, backend, x, x); // self-attention
         x = ggml_add(ctx, x, r);
         r = x;
         x = norm2->forward(ctx, x);
-        x = attn2->forward(ctx, x, context); // cross-attention
+        x = attn2->forward(ctx, backend, x, context); // cross-attention
         x = ggml_add(ctx, x, r);
         r = x;
         x = norm3->forward(ctx, x);
@@ -401,7 +407,10 @@ public:
         blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1}));
     }

-    virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) {
+    virtual struct ggml_tensor* forward(struct ggml_context* ctx,
+                                        ggml_backend_t backend,
+                                        struct ggml_tensor* x,
+                                        struct ggml_tensor* context) {
         // x: [N, in_channels, h, w]
         // context: [N, max_position(aka n_token), hidden_size(aka context_dim)]
         auto norm = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm"]);
@@ -424,7 +433,7 @@ public:
             std::string name = "transformer_blocks." + std::to_string(i);
             auto transformer_block = std::dynamic_pointer_cast<BasicTransformerBlock>(blocks[name]);

-            x = transformer_block->forward(ctx, x, context);
+            x = transformer_block->forward(ctx, backend, x, context);
         }

         x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, inner_dim, h * w]

View File

@@ -639,7 +639,7 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
         pixel_values = to_backend(pixel_values);

-        struct ggml_tensor* hidden_states = vision_model.forward(compute_ctx, pixel_values, return_pooled, clip_skip);
+        struct ggml_tensor* hidden_states = vision_model.forward(compute_ctx, runtime_backend, pixel_values, return_pooled, clip_skip);

         ggml_build_forward_expand(gf, hidden_states);

View File

@@ -174,10 +174,11 @@ public:
     struct ggml_tensor* attention_layer_forward(std::string name,
                                                 struct ggml_context* ctx,
+                                                ggml_backend_t backend,
                                                 struct ggml_tensor* x,
                                                 struct ggml_tensor* context) {
         auto block = std::dynamic_pointer_cast<SpatialTransformer>(blocks[name]);
-        return block->forward(ctx, x, context);
+        return block->forward(ctx, backend, x, context);
     }

     struct ggml_tensor* input_hint_block_forward(struct ggml_context* ctx,
@@ -199,6 +200,7 @@ public:
     }

     std::vector<struct ggml_tensor*> forward(struct ggml_context* ctx,
+                                             ggml_backend_t backend,
                                              struct ggml_tensor* x,
                                              struct ggml_tensor* hint,
                                              struct ggml_tensor* guided_hint,
@@ -272,7 +274,7 @@ public:
             h = resblock_forward(name, ctx, h, emb); // [N, mult*model_channels, h, w]
             if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
                 std::string name = "input_blocks." + std::to_string(input_block_idx) + ".1";
-                h = attention_layer_forward(name, ctx, h, context); // [N, mult*model_channels, h, w]
+                h = attention_layer_forward(name, ctx, backend, h, context); // [N, mult*model_channels, h, w]
             }

             auto zero_conv = std::dynamic_pointer_cast<Conv2d>(blocks["zero_convs." + std::to_string(input_block_idx) + ".0"]);
@@ -297,7 +299,7 @@ public:
         // middle_block
         h = resblock_forward("middle_block.0", ctx, h, emb); // [N, 4*model_channels, h/8, w/8]
-        h = attention_layer_forward("middle_block.1", ctx, h, context); // [N, 4*model_channels, h/8, w/8]
+        h = attention_layer_forward("middle_block.1", ctx, backend, h, context); // [N, 4*model_channels, h/8, w/8]
         h = resblock_forward("middle_block.2", ctx, h, emb); // [N, 4*model_channels, h/8, w/8]

         // out
@@ -403,6 +405,7 @@ struct ControlNet : public GGMLRunner {
         timesteps = to_backend(timesteps);

         auto outs = control_net.forward(compute_ctx,
+                                        runtime_backend,
                                         x,
                                         hint,
                                         guided_hint_cached ? guided_hint : NULL,

View File

@@ -14,6 +14,7 @@ struct DiffusionParams {
     struct ggml_tensor* y = NULL;
     struct ggml_tensor* guidance = NULL;
     std::vector<ggml_tensor*> ref_latents = {};
+    bool increase_ref_index = false;
     int num_video_frames = -1;
     std::vector<struct ggml_tensor*> controls = {};
     float control_strength = 0.f;
@@ -195,6 +196,7 @@ struct FluxModel : public DiffusionModel {
                      diffusion_params.y,
                      diffusion_params.guidance,
                      diffusion_params.ref_latents,
+                     diffusion_params.increase_ref_index,
                      output,
                      output_ctx,
                      diffusion_params.skip_layers);

View File

@@ -11,3 +11,29 @@ Here's a simple example:
 ```
 `../models/marblesh.safetensors` or `../models/marblesh.ckpt` will be applied to the model
+
+# Support matrix
+
+> CUDA `get_rows` support is defined here:
+> [ggml-org/ggml/src/ggml-cuda/getrows.cu#L156](https://github.com/ggml-org/ggml/blob/7dee1d6a1e7611f238d09be96738388da97c88ed/src/ggml-cuda/getrows.cu#L156)
+> Currently only the basic types + Q4/Q5/Q8 are implemented. K-quants are **not** supported.
+
+NOTE: Other backends may have different support.
+
+| Quant / Type | CUDA |
+|--------------|------|
+| F32          | ✔️   |
+| F16          | ✔️   |
+| BF16         | ✔️   |
+| I32          | ✔️   |
+| Q4_0         | ✔️   |
+| Q4_1         | ✔️   |
+| Q5_0         | ✔️   |
+| Q5_1         | ✔️   |
+| Q8_0         | ✔️   |
+| Q2_K         | ❌   |
+| Q3_K         | ❌   |
+| Q4_K         | ❌   |
+| Q5_K         | ❌   |
+| Q6_K         | ❌   |
+| Q8_K         | ❌   |
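
For other backends, support can be probed at runtime. A minimal sketch (not project code; assumes the ggml headers are available and `backend` is an initialized `ggml_backend_t`) that asks a backend whether it can execute `get_rows` for a given type:

```
#include "ggml.h"
#include "ggml-backend.h"

static bool supports_get_rows(ggml_backend_t backend, enum ggml_type type) {
    // no_alloc = true: we only need tensor metadata to describe the op
    struct ggml_init_params params = {/*mem_size*/ 1 << 16, /*mem_buffer*/ nullptr, /*no_alloc*/ true};
    struct ggml_context* ctx = ggml_init(params);
    // ne0 = 256 keeps the row size a multiple of every quant block size
    struct ggml_tensor* weight = ggml_new_tensor_2d(ctx, type, 256, 8);
    struct ggml_tensor* ids    = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
    struct ggml_tensor* rows   = ggml_get_rows(ctx, weight, ids);
    bool ok = ggml_backend_supports_op(backend, rows);
    ggml_free(ctx);
    return ok; // e.g. false for Q4_K on the CUDA backend above
}
```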

View File

@@ -1,6 +1,7 @@
 #include <stdio.h>
 #include <string.h>
 #include <time.h>
+#include <filesystem>
 #include <functional>
 #include <iostream>
 #include <map>
@@ -74,6 +75,7 @@ struct SDParams {
     std::string mask_image_path;
     std::string control_image_path;
     std::vector<std::string> ref_image_paths;
+    bool increase_ref_index = false;

     std::string prompt;
     std::string negative_prompt;
@@ -156,6 +158,7 @@ void print_params(SDParams params) {
     for (auto& path : params.ref_image_paths) {
         printf(" %s\n", path.c_str());
     };
+    printf(" increase_ref_index: %s\n", params.increase_ref_index ? "true" : "false");
     printf(" offload_params_to_cpu: %s\n", params.offload_params_to_cpu ? "true" : "false");
     printf(" clip_on_cpu: %s\n", params.clip_on_cpu ? "true" : "false");
     printf(" control_net_cpu: %s\n", params.control_net_cpu ? "true" : "false");
@@ -222,6 +225,7 @@ void print_usage(int argc, const char* argv[]) {
     printf(" -i, --end-img [IMAGE] path to the end image, required by flf2v\n");
     printf(" --control-image [IMAGE] path to image condition, control net\n");
     printf(" -r, --ref-image [PATH] reference image for Flux Kontext models (can be used multiple times)\n");
+    printf(" --increase-ref-index automatically increase the indices of reference images based on the order they are listed (starting with 1).\n");
     printf(" -o, --output OUTPUT path to write result image to (default: ./output.png)\n");
     printf(" -p, --prompt [PROMPT] the prompt to render\n");
     printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n");
@@ -536,6 +540,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
         {"", "--color", "", true, &params.color},
         {"", "--chroma-disable-dit-mask", "", false, &params.chroma_use_dit_mask},
         {"", "--chroma-enable-t5-mask", "", true, &params.chroma_use_t5_mask},
+        {"", "--increase-ref-index", "", true, &params.increase_ref_index},
     };

     auto on_mode_arg = [&](int argc, const char** argv, int index) {
@@ -1207,6 +1212,7 @@ int main(int argc, const char* argv[]) {
                               init_image,
                               ref_images.data(),
                               (int)ref_images.size(),
+                              params.increase_ref_index,
                               mask_image,
                               params.width,
                               params.height,
@@ -1278,6 +1284,21 @@ int main(int argc, const char* argv[]) {
         }
     }

+    // create the output directory if it does not exist
+    {
+        namespace fs = std::filesystem;
+        const fs::path out_path = params.output_path;
+        if (const fs::path out_dir = out_path.parent_path(); !out_dir.empty()) {
+            std::error_code ec;
+            fs::create_directories(out_dir, ec); // OK if already exists
+            if (ec) {
+                fprintf(stderr, "failed to create directory '%s': %s\n",
+                        out_dir.string().c_str(), ec.message().c_str());
+                return 1;
+            }
+        }
+    }
+
     std::string base_path;
     std::string file_ext;
     std::string file_ext_lower;

View File

@@ -114,6 +114,7 @@ namespace Flux {
    }

    __STATIC_INLINE__ struct ggml_tensor* attention(struct ggml_context* ctx,
+                                                   ggml_backend_t backend,
                                                    struct ggml_tensor* q,
                                                    struct ggml_tensor* k,
                                                    struct ggml_tensor* v,
@@ -126,7 +127,7 @@ namespace Flux {
        q = apply_rope(ctx, q, pe); // [N*n_head, L, d_head]
        k = apply_rope(ctx, k, pe); // [N*n_head, L, d_head]

-       auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], mask, false, true, flash_attn); // [N, L, n_head*d_head]
+       auto x = ggml_nn_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn); // [N, L, n_head*d_head]
        return x;
    }
@@ -169,12 +170,16 @@ namespace Flux {
            return x;
        }

-       struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* pe, struct ggml_tensor* mask) {
+       struct ggml_tensor* forward(struct ggml_context* ctx,
+                                   ggml_backend_t backend,
+                                   struct ggml_tensor* x,
+                                   struct ggml_tensor* pe,
+                                   struct ggml_tensor* mask) {
            // x: [N, n_token, dim]
            // pe: [n_token, d_head/2, 2, 2]
            // return [N, n_token, dim]
            auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head]
-           x = attention(ctx, qkv[0], qkv[1], qkv[2], pe, mask, flash_attn); // [N, n_token, dim]
+           x = attention(ctx, backend, qkv[0], qkv[1], qkv[2], pe, mask, flash_attn); // [N, n_token, dim]
            x = post_attention(ctx, x); // [N, n_token, dim]
            return x;
        }
@@ -299,6 +304,7 @@ namespace Flux {
        }

        std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
+                                                                    ggml_backend_t backend,
                                                                     struct ggml_tensor* img,
                                                                     struct ggml_tensor* txt,
                                                                     struct ggml_tensor* vec,
@@ -362,7 +368,7 @@ namespace Flux {
            auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
            auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]

-           auto attn = attention(ctx, q, k, v, pe, mask, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head]
+           auto attn = attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head]
            attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
            auto txt_attn_out = ggml_view_3d(ctx,
                                             attn,
@@ -446,6 +452,7 @@ namespace Flux {
        }

        struct ggml_tensor* forward(struct ggml_context* ctx,
+                                   ggml_backend_t backend,
                                    struct ggml_tensor* x,
                                    struct ggml_tensor* vec,
                                    struct ggml_tensor* pe,
@@ -496,7 +503,7 @@ namespace Flux {
            auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head]
            q = norm->query_norm(ctx, q);
            k = norm->key_norm(ctx, k);
-           auto attn = attention(ctx, q, k, v, pe, mask, flash_attn); // [N, n_token, hidden_size]
+           auto attn = attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_token, hidden_size]

            auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim]
            auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size]
@@ -699,6 +706,7 @@ namespace Flux {
        }

        struct ggml_tensor* forward_orig(struct ggml_context* ctx,
+                                        ggml_backend_t backend,
                                         struct ggml_tensor* img,
                                         struct ggml_tensor* txt,
                                         struct ggml_tensor* timesteps,
@@ -763,7 +771,7 @@ namespace Flux {
                auto block = std::dynamic_pointer_cast<DoubleStreamBlock>(blocks["double_blocks." + std::to_string(i)]);

-               auto img_txt = block->forward(ctx, img, txt, vec, pe, txt_img_mask);
+               auto img_txt = block->forward(ctx, backend, img, txt, vec, pe, txt_img_mask);
                img = img_txt.first;  // [N, n_img_token, hidden_size]
                txt = img_txt.second; // [N, n_txt_token, hidden_size]
            }
@@ -775,7 +783,7 @@ namespace Flux {
                }
                auto block = std::dynamic_pointer_cast<SingleStreamBlock>(blocks["single_blocks." + std::to_string(i)]);

-               txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask);
+               txt_img = block->forward(ctx, backend, txt_img, vec, pe, txt_img_mask);
            }

            txt_img = ggml_cont(ctx, ggml_permute(ctx, txt_img, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
@@ -808,6 +816,7 @@ namespace Flux {
        }

        struct ggml_tensor* forward(struct ggml_context* ctx,
+                                   ggml_backend_t backend,
                                    struct ggml_tensor* x,
                                    struct ggml_tensor* timestep,
                                    struct ggml_tensor* context,
@@ -857,7 +866,7 @@ namespace Flux {
                }
            }

-           auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size]
+           auto out = forward_orig(ctx, backend, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size]
            if (out->ne[1] > img_tokens) {
                out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size]
                out = ggml_view_3d(ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
@@ -951,6 +960,7 @@ namespace Flux {
                                        struct ggml_tensor* y,
                                        struct ggml_tensor* guidance,
                                        std::vector<ggml_tensor*> ref_latents = {},
+                                       bool increase_ref_index = false,
                                        std::vector<int> skip_layers = {}) {
            GGML_ASSERT(x->ne[3] == 1);
            struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);
@@ -990,6 +1000,7 @@ namespace Flux {
                                                   x->ne[3],
                                                   context->ne[1],
                                                   ref_latents,
+                                                  increase_ref_index,
                                                   flux_params.theta,
                                                   flux_params.axes_dim);
            int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
@@ -1001,6 +1012,7 @@ namespace Flux {
            set_backend_tensor_data(pe, pe_vec.data());

            struct ggml_tensor* out = flux.forward(compute_ctx,
+                                                  runtime_backend,
                                                   x,
                                                   timesteps,
                                                   context,
@@ -1025,6 +1037,7 @@ namespace Flux {
                     struct ggml_tensor* y,
                     struct ggml_tensor* guidance,
                     std::vector<ggml_tensor*> ref_latents = {},
+                    bool increase_ref_index = false,
                     struct ggml_tensor** output = NULL,
                     struct ggml_context* output_ctx = NULL,
                     std::vector<int> skip_layers = std::vector<int>()) {
@@ -1034,7 +1047,7 @@ namespace Flux {
            // y: [N, adm_in_channels] or [1, adm_in_channels]
            // guidance: [N, ]
            auto get_graph = [&]() -> struct ggml_cgraph* {
-               return build_graph(x, timesteps, context, c_concat, y, guidance, ref_latents, skip_layers);
+               return build_graph(x, timesteps, context, c_concat, y, guidance, ref_latents, increase_ref_index, skip_layers);
            };

            GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@@ -1074,7 +1087,7 @@ namespace Flux {
            struct ggml_tensor* out = NULL;

            int t0 = ggml_time_ms();
-           compute(8, x, timesteps, context, NULL, y, guidance, {}, &out, work_ctx);
+           compute(8, x, timesteps, context, NULL, y, guidance, {}, false, &out, work_ctx);
            int t1 = ggml_time_ms();

            print_ggml_tensor(out);

View File

@@ -1032,6 +1032,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx
 // mask: [N, L_q, L_k]
 // return: [N, L_q, C]
 __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* ctx,
+                                                            ggml_backend_t backend,
                                                             struct ggml_tensor* q,
                                                             struct ggml_tensor* k,
                                                             struct ggml_tensor* v,
@@ -1071,6 +1072,47 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     float scale = (1.0f / sqrt((float)d_head));

     int kv_pad = 0;
+    ggml_tensor* kqv = nullptr;
+    auto build_kqv = [&](ggml_tensor* q_in, ggml_tensor* k_in, ggml_tensor* v_in, ggml_tensor* mask_in) -> ggml_tensor* {
+        if (kv_pad != 0) {
+            k_in = ggml_pad(ctx, k_in, 0, kv_pad, 0, 0);
+        }
+        k_in = ggml_cast(ctx, k_in, GGML_TYPE_F16);
+
+        v_in = ggml_nn_cont(ctx, ggml_permute(ctx, v_in, 0, 2, 1, 3));
+        v_in = ggml_reshape_3d(ctx, v_in, d_head, L_k, n_head * N);
+        if (kv_pad != 0) {
+            v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0);
+        }
+        v_in = ggml_cast(ctx, v_in, GGML_TYPE_F16);
+
+        if (mask_in != nullptr) {
+            mask_in = ggml_transpose(ctx, mask_in);
+        } else {
+            if (kv_pad > 0) {
+                mask_in = ggml_zeros(ctx, L_k, L_q, 1, 1);
+                auto pad_tensor = ggml_full(ctx, -INFINITY, kv_pad, L_q, 1, 1);
+                mask_in = ggml_concat(ctx, mask_in, pad_tensor, 0);
+            }
+        }
+        if (mask_in != nullptr) {
+            int mask_pad = 0;
+            if (mask_in->ne[1] % GGML_KQ_MASK_PAD != 0) {
+                mask_pad = GGML_PAD(L_q, GGML_KQ_MASK_PAD) - mask_in->ne[1];
+            }
+            if (mask_pad > 0) {
+                mask_in = ggml_pad(ctx, mask_in, 0, mask_pad, 0, 0);
+            }
+            mask_in = ggml_cast(ctx, mask_in, GGML_TYPE_F16);
+        }
+
+        auto out = ggml_flash_attn_ext(ctx, q_in, k_in, v_in, mask_in, scale, 0, 0);
+        ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32);
+        return out;
+    };
+
     if (flash_attn) {
         // LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
         bool can_use_flash_attn = true;
@@ -1079,60 +1121,24 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
         }
         if (mask != nullptr) {
-            // TODO(Green-Sky): figure out if we can bend t5 to work too
+            // TODO: figure out if we can bend t5 to work too
             can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1;
         }

-        if (!can_use_flash_attn) {
-            flash_attn = false;
-        }
-    }
-
-    ggml_tensor* kqv = nullptr;
-    if (flash_attn) {
-        // LOG_DEBUG(" uses flash attention");
-        if (kv_pad != 0) {
-            // LOG_DEBUG(" padding k and v dim1 by %d", kv_pad);
-            k = ggml_pad(ctx, k, 0, kv_pad, 0, 0);
-        }
-        k = ggml_cast(ctx, k, GGML_TYPE_F16);
-
-        v = ggml_nn_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3)); // [N, n_head, L_k, d_head]
-        v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N);    // [N * n_head, L_k, d_head]
-        if (kv_pad != 0) {
-            v = ggml_pad(ctx, v, 0, kv_pad, 0, 0);
-        }
-        v = ggml_cast(ctx, v, GGML_TYPE_F16);
-
-        if (mask != nullptr) {
-            mask = ggml_transpose(ctx, mask);
-        } else {
-            if (kv_pad > 0) {
-                mask = ggml_zeros(ctx, L_k, L_q, 1, 1);                         // [L_q, L_k]
-                auto pad_tensor = ggml_full(ctx, -INFINITY, kv_pad, L_q, 1, 1); // [L_q, kv_pad]
-                mask = ggml_concat(ctx, mask, pad_tensor, 0);                   // [L_q, L_k + kv_pad]
-            }
-        }
-
-        // mask pad
-        if (mask != nullptr) {
-            int mask_pad = 0;
-            if (mask->ne[1] % GGML_KQ_MASK_PAD != 0) {
-                mask_pad = GGML_PAD(L_q, GGML_KQ_MASK_PAD) - mask->ne[1];
-            }
-            if (mask_pad > 0) {
-                mask = ggml_pad(ctx, mask, 0, mask_pad, 0, 0); // [L_q + mask_pad, L_k + kv_pad]
-            }
-            mask = ggml_cast(ctx, mask, GGML_TYPE_F16);
-            // LOG_DEBUG("L_k: %ld, L_q: %ld, mask->ne[1]: %ld, mask_pad: %d, kv_pad: %d", L_k, L_q, mask->ne[1], mask_pad, kv_pad);
-        }
-
-        kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0);
-        ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
-
-        // kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_k, kqv->nb[1], kqv->nb[2], 0);
-        kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_q, kqv->nb[1], kqv->nb[2], 0);
-    } else {
+        if (can_use_flash_attn) {
+            kqv = build_kqv(q, k, v, mask);
+            if (!ggml_backend_supports_op(backend, kqv)) {
+                kqv = nullptr;
+            } else {
+                kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_q, kqv->nb[1], kqv->nb[2], 0);
+            }
+        }
+    }
+    if (kqv == nullptr) {
+        // if (flash_attn) {
+        //     LOG_DEBUG("fallback to default attention, L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
+        // }
         v = ggml_nn_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, L_k]
         v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N);    // [N * n_head, d_head, L_k]
@@ -2196,7 +2202,10 @@ public:
     }

     // x: [N, n_token, embed_dim]
-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, bool mask = false) {
+    struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
+                                struct ggml_tensor* x,
+                                bool mask = false) {
         auto q_proj = std::dynamic_pointer_cast<Linear>(blocks[q_proj_name]);
         auto k_proj = std::dynamic_pointer_cast<Linear>(blocks[k_proj_name]);
         auto v_proj = std::dynamic_pointer_cast<Linear>(blocks[v_proj_name]);
@@ -2206,7 +2215,7 @@ public:
         struct ggml_tensor* k = k_proj->forward(ctx, x);
         struct ggml_tensor* v = v_proj->forward(ctx, x);

-        x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, mask); // [N, n_token, embed_dim]
+        x = ggml_nn_attention_ext(ctx, backend, q, k, v, n_head, NULL, mask); // [N, n_token, embed_dim]

         x = out_proj->forward(ctx, x); // [N, n_token, embed_dim]
         return x;

View File

@@ -58,6 +58,7 @@ struct LoraModel : public GGMLRunner {
         {"x_block.attn.proj", "attn.to_out.0"},
         {"x_block.attn2.proj", "attn2.to_out.0"},
         // flux
+        {"img_in", "x_embedder"},
         // singlestream
         {"linear2", "proj_out"},
         {"modulation.lin", "norm.linear"},

View File

@@ -202,9 +202,11 @@ public:
     }

     // x: [N, n_token, dim]
-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
+    struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
+                                struct ggml_tensor* x) {
         auto qkv = pre_attention(ctx, x);
-        x = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim]
+        x = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim]
         x = post_attention(ctx, x); // [N, n_token, dim]
         return x;
     }
@@ -415,7 +417,10 @@ public:
         return x;
     }

-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* c) {
+    struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
+                                struct ggml_tensor* x,
+                                struct ggml_tensor* c) {
         // x: [N, n_token, hidden_size]
         // c: [N, hidden_size]
         // return: [N, n_token, hidden_size]
@@ -430,8 +435,8 @@ public:
             auto qkv2 = std::get<1>(qkv_intermediates);
             auto intermediates = std::get<2>(qkv_intermediates);

-            auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim]
-            auto attn2_out = ggml_nn_attention_ext(ctx, qkv2[0], qkv2[1], qkv2[2], num_heads); // [N, n_token, dim]
+            auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim]
+            auto attn2_out = ggml_nn_attention_ext(ctx, backend, qkv2[0], qkv2[1], qkv2[2], num_heads); // [N, n_token, dim]
             x = post_attention_x(ctx,
                                  attn_out,
                                  attn2_out,
@@ -447,7 +452,7 @@ public:
             auto qkv = qkv_intermediates.first;
             auto intermediates = qkv_intermediates.second;

-            auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim]
+            auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim]
             x = post_attention(ctx,
                                attn_out,
                                intermediates[0],
@@ -462,6 +467,7 @@ public:

 __STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*>
 block_mixing(struct ggml_context* ctx,
+             ggml_backend_t backend,
              struct ggml_tensor* context,
              struct ggml_tensor* x,
              struct ggml_tensor* c,
@@ -491,7 +497,7 @@ block_mixing(struct ggml_context* ctx,
         qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1));
     }

-    auto attn = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], x_block->num_heads); // [N, n_context + n_token, hidden_size]
+    auto attn = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], x_block->num_heads); // [N, n_context + n_token, hidden_size]
     attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_context + n_token, N, hidden_size]
     auto context_attn = ggml_view_3d(ctx,
                                      attn,
@@ -525,7 +531,7 @@ block_mixing(struct ggml_context* ctx,
     }

     if (x_block->self_attn) {
-        auto attn2 = ggml_nn_attention_ext(ctx, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads); // [N, n_token, hidden_size]
+        auto attn2 = ggml_nn_attention_ext(ctx, backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads); // [N, n_token, hidden_size]

         x = x_block->post_attention_x(ctx,
                                       x_attn,
@@ -563,13 +569,14 @@ public:
     }

     std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
+                                                                ggml_backend_t backend,
                                                                 struct ggml_tensor* context,
                                                                 struct ggml_tensor* x,
                                                                 struct ggml_tensor* c) {
         auto context_block = std::dynamic_pointer_cast<DismantledBlock>(blocks["context_block"]);
         auto x_block = std::dynamic_pointer_cast<DismantledBlock>(blocks["x_block"]);

-        return block_mixing(ctx, context, x, c, context_block, x_block);
+        return block_mixing(ctx, backend, context, x, c, context_block, x_block);
     }
 };
@@ -771,6 +778,7 @@ public:
     }

     struct ggml_tensor* forward_core_with_concat(struct ggml_context* ctx,
+                                                 ggml_backend_t backend,
                                                  struct ggml_tensor* x,
                                                  struct ggml_tensor* c_mod,
                                                  struct ggml_tensor* context,
@@ -789,7 +797,7 @@ public:
             auto block = std::dynamic_pointer_cast<JointBlock>(blocks["joint_blocks." + std::to_string(i)]);

-            auto context_x = block->forward(ctx, context, x, c_mod);
+            auto context_x = block->forward(ctx, backend, context, x, c_mod);
             context = context_x.first;
             x = context_x.second;
         }
@@ -800,6 +808,7 @@ public:
     }

     struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
                                 struct ggml_tensor* x,
                                 struct ggml_tensor* t,
                                 struct ggml_tensor* y = NULL,
@@ -835,7 +844,7 @@ public:
             context = context_embedder->forward(ctx, context); // [N, L, D] aka [N, L, 1536]
         }

-        x = forward_core_with_concat(ctx, x, c, context, skip_layers); // (N, H*W, patch_size ** 2 * out_channels)
+        x = forward_core_with_concat(ctx, backend, x, c, context, skip_layers); // (N, H*W, patch_size ** 2 * out_channels)

         x = unpatchify(ctx, x, h, w); // [N, C, H, W]
@@ -874,6 +883,7 @@ struct MMDiTRunner : public GGMLRunner {
         timesteps = to_backend(timesteps);

         struct ggml_tensor* out = mmdit.forward(compute_ctx,
+                                                runtime_backend,
                                                 x,
                                                 timesteps,
                                                 y,

View File

@@ -1966,6 +1966,16 @@ std::vector<TensorStorage> remove_duplicates(const std::vector<TensorStorage>& v
 }

 bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {
+    int64_t process_time_ms = 0;
+    int64_t read_time_ms = 0;
+    int64_t memcpy_time_ms = 0;
+    int64_t copy_to_backend_time_ms = 0;
+    int64_t convert_time_ms = 0;
+
+    int64_t prev_time_ms = 0;
+    int64_t curr_time_ms = 0;
+    int64_t start_time = ggml_time_ms();
+    prev_time_ms = start_time;
     std::vector<TensorStorage> processed_tensor_storages;
     for (auto& tensor_storage : tensor_storages) {
         // LOG_DEBUG("%s", name.c_str());
@@ -1978,6 +1988,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {
     }
     std::vector<TensorStorage> dedup = remove_duplicates(processed_tensor_storages);
     processed_tensor_storages = dedup;
+    curr_time_ms = ggml_time_ms();
+    process_time_ms = curr_time_ms - prev_time_ms;
+    prev_time_ms = curr_time_ms;

     bool success = true;
     for (size_t file_index = 0; file_index < file_paths_.size(); file_index++) {
@@ -2019,15 +2032,27 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {
                     size_t entry_size = zip_entry_size(zip);
                     if (entry_size != n) {
                         read_buffer.resize(entry_size);
+                        prev_time_ms = ggml_time_ms();
                         zip_entry_noallocread(zip, (void*)read_buffer.data(), entry_size);
+                        curr_time_ms = ggml_time_ms();
+                        read_time_ms += curr_time_ms - prev_time_ms;
+                        prev_time_ms = curr_time_ms;
                         memcpy((void*)buf, (void*)(read_buffer.data() + tensor_storage.offset), n);
+                        curr_time_ms = ggml_time_ms();
+                        memcpy_time_ms += curr_time_ms - prev_time_ms;
                     } else {
+                        prev_time_ms = ggml_time_ms();
                         zip_entry_noallocread(zip, (void*)buf, n);
+                        curr_time_ms = ggml_time_ms();
+                        read_time_ms += curr_time_ms - prev_time_ms;
                     }
                     zip_entry_close(zip);
                 } else {
+                    prev_time_ms = ggml_time_ms();
                     file.seekg(tensor_storage.offset);
                     file.read(buf, n);
+                    curr_time_ms = ggml_time_ms();
+                    read_time_ms += curr_time_ms - prev_time_ms;
                     if (!file) {
                         LOG_ERROR("read tensor data failed: '%s'", file_path.c_str());
                         return false;
@@ -2072,6 +2097,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {
                     read_data(tensor_storage, (char*)dst_tensor->data, nbytes_to_read);
                 }

+                prev_time_ms = ggml_time_ms();
                 if (tensor_storage.is_bf16) {
                     // inplace op
                     bf16_to_f32_vec((uint16_t*)dst_tensor->data, (float*)dst_tensor->data, tensor_storage.nelements());
@@ -2086,10 +2112,13 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {
                 } else if (tensor_storage.is_i64) {
                     i64_to_i32_vec((int64_t*)read_buffer.data(), (int32_t*)dst_tensor->data, tensor_storage.nelements());
                 }
+                curr_time_ms = ggml_time_ms();
+                convert_time_ms += curr_time_ms - prev_time_ms;
             } else {
                 read_buffer.resize(std::max(tensor_storage.nbytes(), tensor_storage.nbytes_to_read()));
                 read_data(tensor_storage, (char*)read_buffer.data(), nbytes_to_read);

+                prev_time_ms = ggml_time_ms();
                 if (tensor_storage.is_bf16) {
                     // inplace op
                     bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
@@ -2109,11 +2138,14 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {

                 convert_tensor((void*)read_buffer.data(), tensor_storage.type, dst_tensor->data,
                                dst_tensor->type, (int)tensor_storage.nelements() / (int)tensor_storage.ne[0], (int)tensor_storage.ne[0]);
+                curr_time_ms = ggml_time_ms();
+                convert_time_ms += curr_time_ms - prev_time_ms;
             }
         } else {
             read_buffer.resize(std::max(tensor_storage.nbytes(), tensor_storage.nbytes_to_read()));
             read_data(tensor_storage, (char*)read_buffer.data(), nbytes_to_read);

+            prev_time_ms = ggml_time_ms();
             if (tensor_storage.is_bf16) {
                 // inplace op
                 bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
@@ -2133,14 +2165,24 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {

             if (tensor_storage.type == dst_tensor->type) {
                 // copy to device memory
+                curr_time_ms = ggml_time_ms();
+                convert_time_ms += curr_time_ms - prev_time_ms;
+                prev_time_ms = curr_time_ms;
                 ggml_backend_tensor_set(dst_tensor, read_buffer.data(), 0, ggml_nbytes(dst_tensor));
+                curr_time_ms = ggml_time_ms();
+                copy_to_backend_time_ms += curr_time_ms - prev_time_ms;
             } else {
                 // convert first, then copy to device memory
                 convert_buffer.resize(ggml_nbytes(dst_tensor));
                 convert_tensor((void*)read_buffer.data(), tensor_storage.type,
                                (void*)convert_buffer.data(), dst_tensor->type,
                                (int)tensor_storage.nelements() / (int)tensor_storage.ne[0], (int)tensor_storage.ne[0]);
+                curr_time_ms = ggml_time_ms();
+                convert_time_ms += curr_time_ms - prev_time_ms;
+                prev_time_ms = curr_time_ms;
                 ggml_backend_tensor_set(dst_tensor, convert_buffer.data(), 0, ggml_nbytes(dst_tensor));
+                curr_time_ms = ggml_time_ms();
+                copy_to_backend_time_ms += curr_time_ms - prev_time_ms;
             }
         }
         ++tensor_count;
@@ -2170,6 +2212,14 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {
             break;
         }
     }
+    int64_t end_time = ggml_time_ms();
+    LOG_INFO("loading tensors completed, taking %.2fs (process: %.2fs, read: %.2fs, memcpy: %.2fs, convert: %.2fs, copy_to_backend: %.2fs)",
+             (end_time - start_time) / 1000.f,
+             process_time_ms / 1000.f,
+             read_time_ms / 1000.f,
+             memcpy_time_ms / 1000.f,
+             convert_time_ms / 1000.f,
+             copy_to_backend_time_ms / 1000.f);

     return success;
 }

View File

@@ -508,6 +508,7 @@ struct PhotoMakerIDEncoderBlock : public CLIPVisionModelProjection {
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* id_pixel_values,
struct ggml_tensor* prompt_embeds,
struct ggml_tensor* class_tokens_mask,
@@ -520,7 +521,7 @@ struct PhotoMakerIDEncoderBlock : public CLIPVisionModelProjection {
auto visual_projection_2 = std::dynamic_pointer_cast<Linear>(blocks["visual_projection_2"]);
auto fuse_module = std::dynamic_pointer_cast<FuseModule>(blocks["fuse_module"]);
struct ggml_tensor* shared_id_embeds = vision_model->forward(ctx, backend, id_pixel_values); // [N, hidden_size]
struct ggml_tensor* id_embeds = visual_projection->forward(ctx, shared_id_embeds); // [N, proj_dim(768)]
struct ggml_tensor* id_embeds_2 = visual_projection_2->forward(ctx, shared_id_embeds); // [N, 1280]
@@ -579,6 +580,7 @@ struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionMo
*/
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* id_pixel_values,
struct ggml_tensor* prompt_embeds,
struct ggml_tensor* class_tokens_mask,
@@ -592,7 +594,7 @@ struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionMo
auto qformer_perceiver = std::dynamic_pointer_cast<QFormerPerceiver>(blocks["qformer_perceiver"]);
// struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, id_pixel_values); // [N, hidden_size]
struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, backend, id_pixel_values, false); // [N, hidden_size]
id_embeds = qformer_perceiver->forward(ctx, id_embeds, last_hidden_state);
struct ggml_tensor* updated_prompt_embeds = fuse_module->forward(ctx,
@@ -742,6 +744,7 @@ public:
struct ggml_tensor* updated_prompt_embeds = NULL;
if (pm_version == PM_VERSION_1)
updated_prompt_embeds = id_encoder.forward(ctx0,
runtime_backend,
id_pixel_values_d,
prompt_embeds_d,
class_tokens_mask_d,
@@ -749,6 +752,7 @@ public:
left, right);
else if (pm_version == PM_VERSION_2)
updated_prompt_embeds = id_encoder2.forward(ctx0,
runtime_backend,
id_pixel_values_d,
prompt_embeds_d,
class_tokens_mask_d,
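The recurring change in this merge: block forward() methods gain a ggml_backend_t parameter, threaded down from the runner's runtime_backend, so attention helpers can pick backend-appropriate kernels instead of consulting global state. A plain-C++ analogy of the threading pattern (all types and names here are illustrative stand-ins, not the ggml API):

```
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Backend { std::string name; };  // stand-in for ggml_backend_t

struct Block {
    virtual ~Block() = default;
    virtual int forward(const Backend& backend, int x) = 0;
};

struct Attention : Block {
    int forward(const Backend& backend, int x) override {
        // a backend-aware kernel choice (e.g. flash attention) would go here
        std::cout << "attention on " << backend.name << "\n";
        return x + 1;
    }
};

struct Stack : Block {
    std::vector<std::unique_ptr<Block>> blocks;
    int forward(const Backend& backend, int x) override {
        for (auto& b : blocks)
            x = b->forward(backend, x);  // pass the handle down, layer by layer
        return x;
    }
};

int main() {
    Stack model;
    model.blocks.push_back(std::make_unique<Attention>());
    Backend runtime_backend{"cuda0"};  // owned by the runner in the real code
    std::cout << model.forward(runtime_backend, 0) << "\n";
}
```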

rope.hpp

@@ -156,25 +156,33 @@ struct Rope {
int patch_size,
int bs,
int context_len,
std::vector<ggml_tensor*> ref_latents,
bool increase_ref_index) {
auto txt_ids = gen_txt_ids(bs, context_len);
auto img_ids = gen_img_ids(h, w, patch_size, bs);
auto ids = concat_ids(txt_ids, img_ids, bs);
uint64_t curr_h_offset = 0;
uint64_t curr_w_offset = 0;
int index = 1;
for (ggml_tensor* ref : ref_latents) {
uint64_t h_offset = 0;
uint64_t w_offset = 0;
if (!increase_ref_index) {
if (ref->ne[1] + curr_h_offset > ref->ne[0] + curr_w_offset) {
w_offset = curr_w_offset;
} else {
h_offset = curr_h_offset;
}
}
auto ref_ids = gen_img_ids(ref->ne[1], ref->ne[0], patch_size, bs, index, h_offset, w_offset);
ids = concat_ids(ids, ref_ids, bs);
if (increase_ref_index) {
index++;
}
curr_h_offset = std::max(curr_h_offset, ref->ne[1] + h_offset);
curr_w_offset = std::max(curr_w_offset, ref->ne[0] + w_offset);
}
@@ -188,9 +196,10 @@ struct Rope {
int bs,
int context_len,
std::vector<ggml_tensor*> ref_latents,
bool increase_ref_index,
int theta,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_flux_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
return embed_nd(ids, bs, theta, axes_dim);
}
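What the new flag changes: with increase_ref_index off, every reference latent shares id index 1 and the refs are separated spatially via diagonal h/w offsets; with it on, each ref keeps offset (0, 0) and gets its own index 1, 2, 3, ... The bookkeeping, traced in a standalone sketch (plain C++; RefDim stands in for the latent's ne[0]/ne[1]):

```
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

struct RefDim { uint64_t w, h; };  // ref->ne[0], ref->ne[1]

void plan_ref_ids(const std::vector<RefDim>& refs, bool increase_ref_index) {
    uint64_t curr_h = 0, curr_w = 0;
    int index = 1;
    for (const RefDim& ref : refs) {
        uint64_t h_off = 0, w_off = 0;
        if (!increase_ref_index) {
            // same diagonal layout rule as gen_flux_ids above
            if (ref.h + curr_h > ref.w + curr_w) w_off = curr_w;
            else                                 h_off = curr_h;
        }
        std::printf("ref: index=%d h_off=%llu w_off=%llu\n", index,
                    (unsigned long long)h_off, (unsigned long long)w_off);
        if (increase_ref_index) index++;  // next ref gets its own id index
        curr_h = std::max(curr_h, ref.h + h_off);
        curr_w = std::max(curr_w, ref.w + w_off);
    }
}

int main() {
    plan_ref_ids({{32, 48}, {32, 32}}, true);   // indices 1, 2; no offsets
    plan_ref_ids({{32, 48}, {32, 32}}, false);  // index 1; diagonal offsets
}
```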

stable-diffusion.cpp

@@ -330,7 +330,7 @@ public:
if (sd_version_is_dit(version)) {
use_t5xxl = true;
}
if (!clip_on_cpu && !ggml_backend_is_cpu(backend) && use_t5xxl) {
LOG_WARN(
"!!!It appears that you are using the T5 model. Some backends may encounter issues with it."
"If you notice that the generated images are completely black,"
@@ -557,8 +557,6 @@ public:
// load weights
LOG_DEBUG("loading weights");
int64_t t0 = ggml_time_ms();
std::set<std::string> ignore_tensors;
tensors["alphas_cumprod"] = alphas_cumprod_tensor;
if (use_tiny_autoencoder) {
@@ -656,11 +654,7 @@ public:
ggml_backend_is_cpu(clip_backend) ? "RAM" : "VRAM");
}
int64_t t1 = ggml_time_ms();
LOG_INFO("loading model from '%s' completed, taking %.2fs", SAFE_STR(sd_ctx_params->model_path), (t1 - t0) * 1.0f / 1000);
// check is_using_v_parameterization_for_sd2
if (sd_version_is_sd2(version)) {
if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
is_using_v_parameterization = true;
@@ -1037,6 +1031,7 @@ public:
int start_merge_step,
SDCondition id_cond,
std::vector<ggml_tensor*> ref_latents = {},
bool increase_ref_index = false,
ggml_tensor* denoise_mask = NULL,
ggml_tensor* vace_context = NULL,
float vace_strength = 1.f) {
@@ -1128,6 +1123,7 @@ public:
diffusion_params.timesteps = timesteps;
diffusion_params.guidance = guidance_tensor;
diffusion_params.ref_latents = ref_latents;
diffusion_params.increase_ref_index = increase_ref_index;
diffusion_params.controls = controls;
diffusion_params.control_strength = control_strength;
diffusion_params.vace_context = vace_context;
@@ -1697,6 +1693,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
"\n"
"batch_count: %d\n"
"ref_images_count: %d\n"
"increase_ref_index: %s\n"
"control_strength: %.2f\n"
"style_strength: %.2f\n"
"normalize_input: %s\n"
@@ -1711,6 +1708,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
sd_img_gen_params->seed,
sd_img_gen_params->batch_count,
sd_img_gen_params->ref_images_count,
BOOL_STR(sd_img_gen_params->increase_ref_index),
sd_img_gen_params->control_strength,
sd_img_gen_params->style_strength,
BOOL_STR(sd_img_gen_params->normalize_input),
@@ -1784,6 +1782,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
bool normalize_input,
std::string input_id_images_path,
std::vector<ggml_tensor*> ref_latents,
bool increase_ref_index,
ggml_tensor* concat_latent = NULL,
ggml_tensor* denoise_mask = NULL) {
if (seed < 0) {
@@ -2041,6 +2040,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
start_merge_step,
id_cond,
ref_latents,
increase_ref_index,
denoise_mask);
// print_ggml_tensor(x_0);
int64_t sampling_end = ggml_time_ms();
@@ -2291,7 +2291,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
LOG_INFO("EDIT mode");
}
std::vector<ggml_tensor*> ref_latents;
for (int i = 0; i < sd_img_gen_params->ref_images_count; i++) {
ggml_tensor* img = ggml_new_tensor_4d(work_ctx,
GGML_TYPE_F32,
@@ -2346,6 +2346,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
sd_img_gen_params->normalize_input,
sd_img_gen_params->input_id_images_path,
ref_latents,
sd_img_gen_params->increase_ref_index,
concat_latent,
denoise_mask);
@@ -2641,6 +2642,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
-1,
{},
{},
false,
denoise_mask,
vace_context);
@@ -2674,6 +2676,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
-1,
{},
{},
false,
denoise_mask,
vace_context);

stable-diffusion.h

@@ -182,6 +182,7 @@ typedef struct {
sd_image_t init_image;
sd_image_t* ref_images;
int ref_images_count;
bool increase_ref_index;
sd_image_t mask_image;
int width;
int height;
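A hedged consumer-side sketch of the new field: ref_images, ref_images_count, and increase_ref_index come from the struct above, sd_image_t's width/height/channel/data layout follows this header, and all other generation parameters are left zero-initialized for brevity (real callers would fill them in before calling generate_image):

```
#include <cstdint>
#include "stable-diffusion.h"

int main() {
    static uint8_t px[64 * 64 * 3] = {0};  // dummy pixel buffer
    sd_image_t refs[2] = {{64, 64, 3, px}, {64, 64, 3, px}};

    sd_img_gen_params_t params = {};
    params.ref_images         = refs;
    params.ref_images_count   = 2;
    params.increase_ref_index = true;  // number refs 1, 2, ... in the RoPE ids
    (void)params;                      // would be passed to generate_image(...)
    return 0;
}
```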

t5.hpp

@@ -578,6 +578,7 @@ public:
// x: [N, n_token, model_dim]
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* past_bias = NULL,
struct ggml_tensor* mask = NULL,
@@ -608,7 +609,7 @@ public:
k = ggml_scale_inplace(ctx, k, sqrt(d_head));
x = ggml_nn_attention_ext(ctx, backend, q, k, v, num_heads, mask); // [N, n_token, d_head * n_head]
x = out_proj->forward(ctx, x); // [N, n_token, model_dim]
return {x, past_bias};
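Why k is pre-scaled by sqrt(d_head) above: T5 uses unscaled dot-product attention, while the shared ggml_nn_attention_ext helper appears to apply the standard 1/sqrt(d_head) factor, so the pre-scaling cancels it out. A scalar demonstration of the cancellation (plain C++, no ggml):

```
#include <cmath>
#include <cstdio>

int main() {
    const double d_head = 64.0;
    double q = 0.5, k = -1.25;  // one scalar "head dim" stand-in
    // helper path: K pre-multiplied by sqrt(d), then 1/sqrt(d) applied inside
    double fused = (q * (std::sqrt(d_head) * k)) / std::sqrt(d_head);
    double t5    = q * k;       // T5's unscaled attention score
    std::printf("fused=%f t5=%f\n", fused, t5);  // identical values
}
```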
@@ -627,6 +628,7 @@ public:
}
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* past_bias = NULL,
struct ggml_tensor* mask = NULL,
@@ -636,7 +638,7 @@ public:
auto layer_norm = std::dynamic_pointer_cast<T5LayerNorm>(blocks["layer_norm"]);
auto normed_hidden_state = layer_norm->forward(ctx, x);
auto ret = SelfAttention->forward(ctx, backend, normed_hidden_state, past_bias, mask, relative_position_bucket);
auto output = ret.first;
past_bias = ret.second;
@@ -653,6 +655,7 @@ public:
}
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* past_bias = NULL,
struct ggml_tensor* mask = NULL,
@@ -661,7 +664,7 @@ public:
auto layer_0 = std::dynamic_pointer_cast<T5LayerSelfAttention>(blocks["layer.0"]);
auto layer_1 = std::dynamic_pointer_cast<T5LayerFF>(blocks["layer.1"]);
auto ret = layer_0->forward(ctx, backend, x, past_bias, mask, relative_position_bucket);
x = ret.first;
past_bias = ret.second;
x = layer_1->forward(ctx, x);
@@ -688,6 +691,7 @@ public:
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* past_bias = NULL,
struct ggml_tensor* attention_mask = NULL,
@@ -696,7 +700,7 @@ public:
for (int i = 0; i < num_layers; i++) {
auto block = std::dynamic_pointer_cast<T5Block>(blocks["block." + std::to_string(i)]);
auto ret = block->forward(ctx, backend, x, past_bias, attention_mask, relative_position_bucket);
x = ret.first;
past_bias = ret.second;
}
@@ -735,6 +739,7 @@ public:
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* input_ids,
struct ggml_tensor* past_bias = NULL,
struct ggml_tensor* attention_mask = NULL,
@@ -745,7 +750,7 @@ public:
auto encoder = std::dynamic_pointer_cast<T5Stack>(blocks["encoder"]);
auto x = shared->forward(ctx, input_ids);
x = encoder->forward(ctx, backend, x, past_bias, attention_mask, relative_position_bucket);
return x;
}
};
@@ -778,13 +783,14 @@ struct T5Runner : public GGMLRunner {
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* input_ids,
struct ggml_tensor* relative_position_bucket,
struct ggml_tensor* attention_mask = NULL) {
size_t N = input_ids->ne[1];
size_t n_token = input_ids->ne[0];
auto hidden_states = model.forward(ctx, backend, input_ids, NULL, attention_mask, relative_position_bucket); // [N, n_token, model_dim]
return hidden_states;
}
@@ -810,7 +816,7 @@ struct T5Runner : public GGMLRunner {
input_ids->ne[0]);
set_backend_tensor_data(relative_position_bucket, relative_position_bucket_vec.data());
struct ggml_tensor* hidden_states = forward(compute_ctx, runtime_backend, input_ids, relative_position_bucket, attention_mask);
ggml_build_forward_expand(gf, hidden_states);

unet.hpp

@@ -61,6 +61,7 @@ public:
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* context,
int timesteps) {
@@ -127,7 +128,7 @@ public:
auto block = std::dynamic_pointer_cast<BasicTransformerBlock>(blocks[transformer_name]);
auto mix_block = std::dynamic_pointer_cast<BasicTransformerBlock>(blocks[time_stack_name]);
x = block->forward(ctx, backend, x, spatial_context); // [N, h * w, inner_dim]
// in_channels == inner_dim
auto x_mix = x;
@@ -143,7 +144,7 @@ public:
x_mix = ggml_cont(ctx, ggml_permute(ctx, x_mix, 0, 2, 1, 3)); // b t s c -> b s t c
x_mix = ggml_reshape_3d(ctx, x_mix, C, T, S * B); // b s t c -> (b s) t c
x_mix = mix_block->forward(ctx, backend, x_mix, time_context); // [B * h * w, T, inner_dim]
x_mix = ggml_reshape_4d(ctx, x_mix, C, T, S, B); // (b s) t c -> b s t c
x_mix = ggml_cont(ctx, ggml_permute(ctx, x_mix, 0, 2, 1, 3)); // b s t c -> b t s c
@@ -363,21 +364,23 @@ public:
struct ggml_tensor* attention_layer_forward(std::string name,
struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* context,
int timesteps) {
if (version == VERSION_SVD) {
auto block = std::dynamic_pointer_cast<SpatialVideoTransformer>(blocks[name]);
return block->forward(ctx, backend, x, context, timesteps);
} else {
auto block = std::dynamic_pointer_cast<SpatialTransformer>(blocks[name]);
return block->forward(ctx, backend, x, context);
}
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
@@ -456,7 +459,7 @@ public:
h = resblock_forward(name, ctx, h, emb, num_video_frames); // [N, mult*model_channels, h, w]
if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
std::string name = "input_blocks." + std::to_string(input_block_idx) + ".1";
h = attention_layer_forward(name, ctx, backend, h, context, num_video_frames); // [N, mult*model_channels, h, w]
}
hs.push_back(h);
}
@@ -475,7 +478,7 @@ public:
// middle_block
h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
h = attention_layer_forward("middle_block.1", ctx, backend, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
if (controls.size() > 0) {
@@ -507,7 +510,7 @@ public:
if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
std::string name = "output_blocks." + std::to_string(output_block_idx) + ".1";
h = attention_layer_forward(name, ctx, backend, h, context, num_video_frames);
up_sample_idx++;
}
@@ -592,6 +595,7 @@ struct UNetModelRunner : public GGMLRunner {
}
struct ggml_tensor* out = unet.forward(compute_ctx,
runtime_backend,
x,
timesteps,
context,

wan.hpp

@@ -1306,6 +1306,7 @@ namespace WAN {
}
virtual struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* pe,
struct ggml_tensor* mask = NULL) {
@@ -1332,7 +1333,7 @@ namespace WAN {
k = ggml_reshape_4d(ctx, k, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head]
v = ggml_reshape_4d(ctx, v, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head]
x = Flux::attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_token, dim]
x = o_proj->forward(ctx, x); // [N, n_token, dim]
return x;
@@ -1348,6 +1349,7 @@ namespace WAN {
bool flash_attn = false)
: WanSelfAttention(dim, num_heads, qk_norm, eps, flash_attn) {}
virtual struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* context,
int64_t context_img_len) = 0;
@@ -1362,6 +1364,7 @@ namespace WAN {
bool flash_attn = false)
: WanCrossAttention(dim, num_heads, qk_norm, eps, flash_attn) {}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* context,
int64_t context_img_len) {
@@ -1385,7 +1388,7 @@ namespace WAN {
k = norm_k->forward(ctx, k);
auto v = v_proj->forward(ctx, context); // [N, n_context, dim]
x = ggml_nn_attention_ext(ctx, backend, q, k, v, num_heads, NULL, false, false, flash_attn); // [N, n_token, dim]
x = o_proj->forward(ctx, x); // [N, n_token, dim]
return x;
@@ -1411,6 +1414,7 @@ namespace WAN {
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* context,
int64_t context_img_len) {
@@ -1451,8 +1455,8 @@ namespace WAN {
k_img = norm_k_img->forward(ctx, k_img);
auto v_img = v_img_proj->forward(ctx, context_img); // [N, context_img_len, dim]
auto img_x = ggml_nn_attention_ext(ctx, backend, q, k_img, v_img, num_heads, NULL, false, false, flash_attn); // [N, n_token, dim]
x = ggml_nn_attention_ext(ctx, backend, q, k, v, num_heads, NULL, false, false, flash_attn); // [N, n_token, dim]
x = ggml_add(ctx, x, img_x);
@@ -1529,6 +1533,7 @@ namespace WAN {
}
virtual struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* e,
struct ggml_tensor* pe,
@@ -1555,14 +1560,14 @@ namespace WAN {
auto y = norm1->forward(ctx, x);
y = ggml_add(ctx, y, modulate_mul(ctx, y, es[1]));
y = modulate_add(ctx, y, es[0]);
y = self_attn->forward(ctx, backend, y, pe);
x = ggml_add(ctx, x, modulate_mul(ctx, y, es[2]));
// cross-attention
x = ggml_add(ctx,
x,
cross_attn->forward(ctx, backend, norm3->forward(ctx, x), context, context_img_len));
// ffn
y = norm2->forward(ctx, x);
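The norm1/modulate sequence above follows the usual DiT recipe. A scalar trace of it, assuming modulate_mul(ctx, y, e) multiplies elementwise and modulate_add(ctx, y, e) adds elementwise (helper semantics inferred from their names, not shown in this diff):

```
#include <cstdio>

int main() {
    float x = 0.8f;                         // one activation in the stream
    float e0 = 0.1f, e1 = 0.5f, e2 = 0.9f;  // shift / scale / gate from es[]
    float y = x;                            // stand-in for norm1(x)
    y = y + y * e1;                         // modulate_mul: y * (1 + scale)
    y = y + e0;                             // modulate_add: + shift
    float attn = y;                         // stand-in for self_attn(y, pe)
    x = x + attn * e2;                      // gated residual back into x
    std::printf("x=%f\n", x);
}
```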
@@ -1605,6 +1610,7 @@ namespace WAN {
}
std::pair<ggml_tensor*, ggml_tensor*> forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* c,
struct ggml_tensor* x,
struct ggml_tensor* e,
@@ -1624,7 +1630,7 @@ namespace WAN {
auto after_proj = std::dynamic_pointer_cast<Linear>(blocks["after_proj"]);
c = WanAttentionBlock::forward(ctx, backend, c, e, pe, context, context_img_len);
auto c_skip = after_proj->forward(ctx, c);
return {c_skip, c};
@@ -1865,6 +1871,7 @@ namespace WAN {
}
struct ggml_tensor* forward_orig(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* timestep,
struct ggml_tensor* context,
@@ -1937,7 +1944,7 @@ namespace WAN {
for (int i = 0; i < params.num_layers; i++) {
auto block = std::dynamic_pointer_cast<WanAttentionBlock>(blocks["blocks." + std::to_string(i)]);
x = block->forward(ctx, backend, x, e0, pe, context, context_img_len);
auto iter = params.vace_layers_mapping.find(i);
if (iter != params.vace_layers_mapping.end()) {
@@ -1945,7 +1952,7 @@ namespace WAN {
auto vace_block = std::dynamic_pointer_cast<VaceWanAttentionBlock>(blocks["vace_blocks." + std::to_string(n)]);
auto result = vace_block->forward(ctx, backend, c, x_orig, e0, pe, context, context_img_len);
auto c_skip = result.first;
c = result.second;
c_skip = ggml_scale(ctx, c_skip, vace_strength);
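Each WAN layer with a mapped VACE block pulls a skip tensor out of the VACE branch and scales it by vace_strength; the scaled skip is then (presumably, the add itself falls outside this hunk) folded back into the hidden states. A minimal standalone sketch of that injection:

```
#include <cstdio>
#include <vector>

// x = x + vace_strength * c_skip, applied per mapped layer;
// vace_strength = 0 disables the VACE branch entirely.
std::vector<float> inject_vace_skip(std::vector<float> x,
                                    const std::vector<float>& c_skip,
                                    float vace_strength) {
    for (size_t i = 0; i < x.size(); ++i)
        x[i] += vace_strength * c_skip[i];
    return x;
}

int main() {
    auto out = inject_vace_skip({1.f, 2.f}, {0.5f, -0.5f}, 1.f);
    std::printf("%f %f\n", out[0], out[1]);
}
```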
@@ -1959,6 +1966,7 @@ namespace WAN {
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* timestep,
struct ggml_tensor* context,
@@ -1995,7 +2003,7 @@ namespace WAN {
t_len = ((x->ne[2] + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size));
}
auto out = forward_orig(ctx, backend, x, timestep, context, pe, clip_fea, vace_context, vace_strength, N); // [N, t_len*h_len*w_len, pt*ph*pw*C]
out = unpatchify(ctx, out, t_len, h_len, w_len); // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w]
@@ -2174,6 +2182,7 @@ namespace WAN {
}
struct ggml_tensor* out = wan.forward(compute_ctx,
runtime_backend,
x,
timesteps,
context,