refactor: introduce GGMLRunnerContext (#928)

* introduce GGMLRunnerContext

* add Flash Attention enable control through GGMLRunnerContext

* add conv2d_direct enable control through GGMLRunnerContext
leejet 2025-11-02 02:11:04 +08:00 committed by GitHub
parent c42826b77c
commit 6103d86e2c
21 changed files with 1079 additions and 1199 deletions
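The diff below replaces the old (struct ggml_context*, ggml_backend_t) argument pair with a single GGMLRunnerContext* threaded through every block's forward(). A minimal sketch of that struct, assuming only what the usages below show (ggml_ctx, backend and flash_attn_enabled appear in the diff; the conv2d_direct flag is inferred from the commit message and its name is hypothetical):

    // Illustrative sketch only; any field not visible in the diff is an assumption.
    struct GGMLRunnerContext {
        struct ggml_context* ggml_ctx = nullptr;  // graph-building context (previously passed as ctx)
        ggml_backend_t backend        = nullptr;  // compute backend (previously passed alongside ctx)
        bool flash_attn_enabled       = false;    // read by attention blocks instead of a per-block flag
        bool conv2d_direct_enabled    = false;    // assumed: replaces the removed enable_conv2d_direct() helpers
    };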

View File

@ -451,16 +451,16 @@ public:
}
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
// x: [N, n_token, d_model]
auto fc1 = std::dynamic_pointer_cast<Linear>(blocks["fc1"]);
auto fc2 = std::dynamic_pointer_cast<Linear>(blocks["fc2"]);
x = fc1->forward(ctx, x);
if (use_gelu) {
x = ggml_gelu_inplace(ctx, x);
x = ggml_gelu_inplace(ctx->ggml_ctx, x);
} else {
x = ggml_gelu_quick_inplace(ctx, x);
x = ggml_gelu_quick_inplace(ctx->ggml_ctx, x);
}
x = fc2->forward(ctx, x);
return x;
@ -488,15 +488,15 @@ public:
blocks["mlp"] = std::shared_ptr<GGMLBlock>(new CLIPMLP(d_model, intermediate_size));
}
struct ggml_tensor* forward(struct ggml_context* ctx, ggml_backend_t backend, struct ggml_tensor* x, bool mask = true) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, bool mask = true) {
// x: [N, n_token, d_model]
auto self_attn = std::dynamic_pointer_cast<MultiheadAttention>(blocks["self_attn"]);
auto layer_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm1"]);
auto layer_norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm2"]);
auto mlp = std::dynamic_pointer_cast<CLIPMLP>(blocks["mlp"]);
x = ggml_add(ctx, x, self_attn->forward(ctx, backend, layer_norm1->forward(ctx, x), mask));
x = ggml_add(ctx, x, mlp->forward(ctx, layer_norm2->forward(ctx, x)));
x = ggml_add(ctx->ggml_ctx, x, self_attn->forward(ctx, layer_norm1->forward(ctx, x), mask));
x = ggml_add(ctx->ggml_ctx, x, mlp->forward(ctx, layer_norm2->forward(ctx, x)));
return x;
}
};
@ -517,8 +517,7 @@ public:
}
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
int clip_skip = -1,
bool mask = true) {
@ -536,7 +535,7 @@ public:
}
std::string name = "layers." + std::to_string(i);
auto layer = std::dynamic_pointer_cast<CLIPLayer>(blocks[name]);
x = layer->forward(ctx, backend, x, mask); // [N, n_token, d_model]
x = layer->forward(ctx, x, mask); // [N, n_token, d_model]
// LOG_DEBUG("layer %d", i);
}
return x;
@ -578,7 +577,7 @@ public:
return params["token_embedding.weight"];
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* input_ids,
struct ggml_tensor* custom_embed_weight) {
// input_ids: [N, n_token]
@ -586,12 +585,12 @@ public:
auto position_embed_weight = params["position_embedding.weight"];
GGML_ASSERT(input_ids->ne[0] == position_embed_weight->ne[1]);
input_ids = ggml_reshape_3d(ctx, input_ids, input_ids->ne[0], 1, input_ids->ne[1]);
auto token_embedding = ggml_get_rows(ctx, custom_embed_weight != nullptr ? custom_embed_weight : token_embed_weight, input_ids);
token_embedding = ggml_reshape_3d(ctx, token_embedding, token_embedding->ne[0], token_embedding->ne[1], token_embedding->ne[3]);
input_ids = ggml_reshape_3d(ctx->ggml_ctx, input_ids, input_ids->ne[0], 1, input_ids->ne[1]);
auto token_embedding = ggml_get_rows(ctx->ggml_ctx, custom_embed_weight != nullptr ? custom_embed_weight : token_embed_weight, input_ids);
token_embedding = ggml_reshape_3d(ctx->ggml_ctx, token_embedding, token_embedding->ne[0], token_embedding->ne[1], token_embedding->ne[3]);
// token_embedding + position_embedding
auto x = ggml_add(ctx,
auto x = ggml_add(ctx->ggml_ctx,
token_embedding,
position_embed_weight); // [N, n_token, embed_dim]
return x;
@ -629,7 +628,7 @@ public:
num_positions = num_patches + 1;
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* pixel_values) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* pixel_values) {
// pixel_values: [N, num_channels, image_size, image_size]
// return: [N, num_positions, embed_dim]
GGML_ASSERT(pixel_values->ne[0] == image_size && pixel_values->ne[1] == image_size && pixel_values->ne[2] == num_channels);
@ -641,18 +640,18 @@ public:
// concat(patch_embedding, class_embedding) + position_embedding
struct ggml_tensor* patch_embedding;
int64_t N = pixel_values->ne[3];
patch_embedding = ggml_ext_conv_2d(ctx, pixel_values, patch_embed_weight, nullptr, patch_size, patch_size); // [N, embed_dim, image_size // patch_size, image_size // patch_size]
patch_embedding = ggml_reshape_3d(ctx, patch_embedding, num_patches, embed_dim, N); // [N, embed_dim, num_patches]
patch_embedding = ggml_cont(ctx, ggml_permute(ctx, patch_embedding, 1, 0, 2, 3)); // [N, num_patches, embed_dim]
patch_embedding = ggml_reshape_4d(ctx, patch_embedding, 1, embed_dim, num_patches, N); // [N, num_patches, embed_dim, 1]
patch_embedding = ggml_ext_conv_2d(ctx->ggml_ctx, pixel_values, patch_embed_weight, nullptr, patch_size, patch_size); // [N, embed_dim, image_size // patch_size, image_size // patch_size]
patch_embedding = ggml_reshape_3d(ctx->ggml_ctx, patch_embedding, num_patches, embed_dim, N); // [N, embed_dim, num_patches]
patch_embedding = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, patch_embedding, 1, 0, 2, 3)); // [N, num_patches, embed_dim]
patch_embedding = ggml_reshape_4d(ctx->ggml_ctx, patch_embedding, 1, embed_dim, num_patches, N); // [N, num_patches, embed_dim, 1]
struct ggml_tensor* class_embedding = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, embed_dim, N);
class_embedding = ggml_repeat(ctx, class_embed_weight, class_embedding); // [N, embed_dim]
class_embedding = ggml_reshape_4d(ctx, class_embedding, 1, embed_dim, 1, N); // [N, 1, embed_dim, 1]
struct ggml_tensor* class_embedding = ggml_new_tensor_2d(ctx->ggml_ctx, GGML_TYPE_F32, embed_dim, N);
class_embedding = ggml_repeat(ctx->ggml_ctx, class_embed_weight, class_embedding); // [N, embed_dim]
class_embedding = ggml_reshape_4d(ctx->ggml_ctx, class_embedding, 1, embed_dim, 1, N); // [N, 1, embed_dim, 1]
struct ggml_tensor* x = ggml_concat(ctx, class_embedding, patch_embedding, 2); // [N, num_positions, embed_dim, 1]
x = ggml_reshape_3d(ctx, x, embed_dim, num_positions, N); // [N, num_positions, embed_dim]
x = ggml_add(ctx, x, position_embed_weight);
struct ggml_tensor* x = ggml_concat(ctx->ggml_ctx, class_embedding, patch_embedding, 2); // [N, num_positions, embed_dim, 1]
x = ggml_reshape_3d(ctx->ggml_ctx, x, embed_dim, num_positions, N); // [N, num_positions, embed_dim]
x = ggml_add(ctx->ggml_ctx, x, position_embed_weight);
return x; // [N, num_positions, embed_dim]
}
};
@ -714,8 +713,7 @@ public:
return embeddings->get_token_embed_weight();
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* input_ids,
struct ggml_tensor* tkn_embeddings,
size_t max_token_idx = 0,
@ -727,16 +725,16 @@ public:
auto final_layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["final_layer_norm"]);
auto x = embeddings->forward(ctx, input_ids, tkn_embeddings); // [N, n_token, hidden_size]
x = encoder->forward(ctx, backend, x, return_pooled ? -1 : clip_skip, true);
x = encoder->forward(ctx, x, return_pooled ? -1 : clip_skip, true);
if (return_pooled || with_final_ln) {
x = final_layer_norm->forward(ctx, x);
}
if (return_pooled) {
auto text_projection = params["text_projection"];
ggml_tensor* pooled = ggml_view_1d(ctx, x, hidden_size, x->nb[1] * max_token_idx);
ggml_tensor* pooled = ggml_view_1d(ctx->ggml_ctx, x, hidden_size, x->nb[1] * max_token_idx);
if (text_projection != nullptr) {
pooled = ggml_ext_linear(ctx, pooled, text_projection, nullptr);
pooled = ggml_ext_linear(ctx->ggml_ctx, pooled, text_projection, nullptr);
} else {
LOG_DEBUG("identity projection");
}
@ -779,8 +777,7 @@ public:
blocks["post_layernorm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* pixel_values,
bool return_pooled = true,
int clip_skip = -1) {
@ -792,14 +789,14 @@ public:
auto x = embeddings->forward(ctx, pixel_values); // [N, num_positions, embed_dim]
x = pre_layernorm->forward(ctx, x);
x = encoder->forward(ctx, backend, x, clip_skip, false);
x = encoder->forward(ctx, x, clip_skip, false);
// print_ggml_tensor(x, true, "ClipVisionModel x: ");
auto last_hidden_state = x;
x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size]
GGML_ASSERT(x->ne[3] == 1);
if (return_pooled) {
ggml_tensor* pooled = ggml_cont(ctx, ggml_view_2d(ctx, x, x->ne[0], x->ne[2], x->nb[2], 0));
ggml_tensor* pooled = ggml_cont(ctx->ggml_ctx, ggml_view_2d(ctx->ggml_ctx, x, x->ne[0], x->ne[2], x->nb[2], 0));
return pooled; // [N, hidden_size]
} else {
// return x; // [N, n_token, hidden_size]
@ -831,12 +828,12 @@ public:
out_features(out_features),
transpose_weight(transpose_weight) {}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
struct ggml_tensor* w = params["weight"];
if (transpose_weight) {
w = ggml_cont(ctx, ggml_transpose(ctx, w));
w = ggml_cont(ctx->ggml_ctx, ggml_transpose(ctx->ggml_ctx, w));
}
return ggml_ext_linear(ctx, x, w, nullptr);
return ggml_ext_linear(ctx->ggml_ctx, x, w, nullptr);
}
};
@ -860,8 +857,7 @@ public:
blocks["visual_projection"] = std::shared_ptr<GGMLBlock>(new CLIPProjection(hidden_size, projection_dim, transpose_proj_w));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* pixel_values,
bool return_pooled = true,
int clip_skip = -1) {
@ -870,7 +866,7 @@ public:
auto vision_model = std::dynamic_pointer_cast<CLIPVisionModel>(blocks["vision_model"]);
auto visual_projection = std::dynamic_pointer_cast<CLIPProjection>(blocks["visual_projection"]);
auto x = vision_model->forward(ctx, backend, pixel_values, return_pooled, clip_skip); // [N, hidden_size] or [N, n_token, hidden_size]
auto x = vision_model->forward(ctx, pixel_values, return_pooled, clip_skip); // [N, hidden_size] or [N, n_token, hidden_size]
if (return_pooled) {
x = visual_projection->forward(ctx, x); // [N, projection_dim]
@ -902,8 +898,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
model.get_param_tensors(tensors, prefix);
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* input_ids,
struct ggml_tensor* embeddings,
size_t max_token_idx = 0,
@ -913,10 +908,10 @@ struct CLIPTextModelRunner : public GGMLRunner {
size_t n_token = input_ids->ne[0];
if (input_ids->ne[0] > model.n_token) {
GGML_ASSERT(input_ids->ne[0] % model.n_token == 0);
input_ids = ggml_reshape_2d(ctx, input_ids, model.n_token, input_ids->ne[0] / model.n_token);
input_ids = ggml_reshape_2d(ctx->ggml_ctx, input_ids, model.n_token, input_ids->ne[0] / model.n_token);
}
return model.forward(ctx, backend, input_ids, embeddings, max_token_idx, return_pooled, clip_skip);
return model.forward(ctx, input_ids, embeddings, max_token_idx, return_pooled, clip_skip);
}
struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids,
@ -943,7 +938,9 @@ struct CLIPTextModelRunner : public GGMLRunner {
embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 1);
}
struct ggml_tensor* hidden_states = forward(compute_ctx, runtime_backend, input_ids, embeddings, max_token_idx, return_pooled, clip_skip);
auto runner_ctx = get_context();
struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, embeddings, max_token_idx, return_pooled, clip_skip);
ggml_build_forward_expand(gf, hidden_states);
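build_graph() call sites now fetch the context from the runner, as in the hunk above. A hypothetical sketch of how GGMLRunner::get_context() presumably packs the runner state (the helper's body is not part of this diff; compute_ctx and runtime_backend are taken from the old call, the flag member is an assumption set by the new set_flash_attention_enabled()):

    // Hypothetical implementation, not shown in this diff.
    GGMLRunnerContext get_context() {
        GGMLRunnerContext runner_ctx;
        runner_ctx.ggml_ctx           = compute_ctx;        // graph-building context
        runner_ctx.backend            = runtime_backend;    // backend chosen for this compute
        runner_ctx.flash_attn_enabled = flash_attn_enabled; // assumed member behind set_flash_attention_enabled()
        return runner_ctx;
    }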

View File

@ -23,12 +23,12 @@ public:
}
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
// x: [N, channels, h, w]
if (vae_downsample) {
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
x = ggml_pad(ctx, x, 1, 1, 0, 0);
x = ggml_pad(ctx->ggml_ctx, x, 1, 1, 0, 0);
x = conv->forward(ctx, x);
} else {
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["op"]);
@ -52,11 +52,11 @@ public:
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
// x: [N, channels, h, w]
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
x = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST); // [N, channels, h*2, w*2]
x = ggml_upscale(ctx->ggml_ctx, x, 2, GGML_SCALE_MODE_NEAREST); // [N, channels, h*2, w*2]
x = conv->forward(ctx, x); // [N, out_channels, h*2, w*2]
return x;
}
@ -121,7 +121,7 @@ public:
}
}
virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* emb = nullptr) {
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* emb = nullptr) {
// For dims==3, we reduce dimension from 5d to 4d by merging h and w, in order not to change ggml
// [N, c, t, h, w] => [N, c, t, h * w]
// x: [N, channels, h, w] if dims == 2 else [N, channels, t, h, w]
@ -137,32 +137,32 @@ public:
// in_layers
auto h = in_layers_0->forward(ctx, x);
h = ggml_silu_inplace(ctx, h);
h = ggml_silu_inplace(ctx->ggml_ctx, h);
h = in_layers_2->forward(ctx, h); // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w]
// emb_layers
if (!skip_t_emb) {
auto emb_layer_1 = std::dynamic_pointer_cast<Linear>(blocks["emb_layers.1"]);
auto emb_out = ggml_silu(ctx, emb);
auto emb_out = ggml_silu(ctx->ggml_ctx, emb);
emb_out = emb_layer_1->forward(ctx, emb_out); // [N, out_channels] if dims == 2 else [N, t, out_channels]
if (dims == 2) {
emb_out = ggml_reshape_4d(ctx, emb_out, 1, 1, emb_out->ne[0], emb_out->ne[1]); // [N, out_channels, 1, 1]
emb_out = ggml_reshape_4d(ctx->ggml_ctx, emb_out, 1, 1, emb_out->ne[0], emb_out->ne[1]); // [N, out_channels, 1, 1]
} else {
emb_out = ggml_reshape_4d(ctx, emb_out, 1, emb_out->ne[0], emb_out->ne[1], emb_out->ne[2]); // [N, t, out_channels, 1]
emb_out = ggml_reshape_4d(ctx->ggml_ctx, emb_out, 1, emb_out->ne[0], emb_out->ne[1], emb_out->ne[2]); // [N, t, out_channels, 1]
if (exchange_temb_dims) {
// emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
emb_out = ggml_cont(ctx, ggml_permute(ctx, emb_out, 0, 2, 1, 3)); // [N, out_channels, t, 1]
emb_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, emb_out, 0, 2, 1, 3)); // [N, out_channels, t, 1]
}
}
h = ggml_add(ctx, h, emb_out); // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w]
h = ggml_add(ctx->ggml_ctx, h, emb_out); // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w]
}
// out_layers
h = out_layers_0->forward(ctx, h);
h = ggml_silu_inplace(ctx, h);
h = ggml_silu_inplace(ctx->ggml_ctx, h);
// dropout, skip for inference
h = out_layers_3->forward(ctx, h);
@ -172,7 +172,7 @@ public:
x = skip_connection->forward(ctx, x); // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w]
}
h = ggml_add(ctx, h, x);
h = ggml_add(ctx->ggml_ctx, h, x);
return h; // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w]
}
};
@ -193,24 +193,24 @@ public:
GEGLU(int64_t dim_in, int64_t dim_out)
: dim_in(dim_in), dim_out(dim_out) {}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
// x: [ne3, ne2, ne1, dim_in]
// return: [ne3, ne2, ne1, dim_out]
struct ggml_tensor* w = params["proj.weight"];
struct ggml_tensor* b = params["proj.bias"];
auto x_w = ggml_view_2d(ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], 0); // [dim_out, dim_in]
auto x_b = ggml_view_1d(ctx, b, b->ne[0] / 2, 0); // [dim_out, dim_in]
auto gate_w = ggml_view_2d(ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], w->nb[1] * w->ne[1] / 2); // [dim_out, ]
auto gate_b = ggml_view_1d(ctx, b, b->ne[0] / 2, b->nb[0] * b->ne[0] / 2); // [dim_out, ]
auto x_w = ggml_view_2d(ctx->ggml_ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], 0); // [dim_out, dim_in]
auto x_b = ggml_view_1d(ctx->ggml_ctx, b, b->ne[0] / 2, 0); // [dim_out, dim_in]
auto gate_w = ggml_view_2d(ctx->ggml_ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], w->nb[1] * w->ne[1] / 2); // [dim_out, ]
auto gate_b = ggml_view_1d(ctx->ggml_ctx, b, b->ne[0] / 2, b->nb[0] * b->ne[0] / 2); // [dim_out, ]
auto x_in = x;
x = ggml_ext_linear(ctx, x_in, x_w, x_b); // [ne3, ne2, ne1, dim_out]
auto gate = ggml_ext_linear(ctx, x_in, gate_w, gate_b); // [ne3, ne2, ne1, dim_out]
x = ggml_ext_linear(ctx->ggml_ctx, x_in, x_w, x_b); // [ne3, ne2, ne1, dim_out]
auto gate = ggml_ext_linear(ctx->ggml_ctx, x_in, gate_w, gate_b); // [ne3, ne2, ne1, dim_out]
gate = ggml_gelu_inplace(ctx, gate);
gate = ggml_gelu_inplace(ctx->ggml_ctx, gate);
x = ggml_mul(ctx, x, gate); // [ne3, ne2, ne1, dim_out]
x = ggml_mul(ctx->ggml_ctx, x, gate); // [ne3, ne2, ne1, dim_out]
return x;
}
@ -222,13 +222,13 @@ public:
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim_in, dim_out, bias));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
// x: [ne3, ne2, ne1, dim_in]
// return: [ne3, ne2, ne1, dim_out]
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
x = proj->forward(ctx, x);
x = ggml_gelu_inplace(ctx, x);
x = ggml_gelu_inplace(ctx->ggml_ctx, x);
return x;
}
};
@ -262,7 +262,7 @@ public:
blocks["net.2"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim_out, true, false, false, scale));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
// x: [ne3, ne2, ne1, dim]
// return: [ne3, ne2, ne1, dim_out]
@ -281,19 +281,16 @@ protected:
int64_t context_dim;
int64_t n_head;
int64_t d_head;
bool flash_attn;
public:
CrossAttention(int64_t query_dim,
int64_t context_dim,
int64_t n_head,
int64_t d_head,
bool flash_attn = false)
int64_t d_head)
: n_head(n_head),
d_head(d_head),
query_dim(query_dim),
context_dim(context_dim),
flash_attn(flash_attn) {
context_dim(context_dim) {
int64_t inner_dim = d_head * n_head;
blocks["to_q"] = std::shared_ptr<GGMLBlock>(new Linear(query_dim, inner_dim, false));
@ -304,8 +301,7 @@ public:
// to_out_1 is nn.Dropout(), skip for inference
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* context) {
// x: [N, n_token, query_dim]
@ -325,7 +321,7 @@ public:
auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim]
auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim]
x = ggml_ext_attention_ext(ctx, backend, q, k, v, n_head, nullptr, false, false, flash_attn); // [N, n_token, inner_dim]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, inner_dim]
x = to_out_0->forward(ctx, x); // [N, n_token, query_dim]
return x;
@ -343,16 +339,15 @@ public:
int64_t n_head,
int64_t d_head,
int64_t context_dim,
bool ff_in = false,
bool flash_attn = false)
bool ff_in = false)
: n_head(n_head), d_head(d_head), ff_in(ff_in) {
// disable_self_attn is always False
// disable_temporal_crossattention is always False
// switch_temporal_ca_to_sa is always False
// inner_dim is always None or equal to dim
// gated_ff is always True
blocks["attn1"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, dim, n_head, d_head, flash_attn));
blocks["attn2"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, context_dim, n_head, d_head, flash_attn));
blocks["attn1"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, dim, n_head, d_head));
blocks["attn2"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, context_dim, n_head, d_head));
blocks["ff"] = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim));
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
@ -364,8 +359,7 @@ public:
}
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* context) {
// x: [N, n_token, query_dim]
@ -387,21 +381,21 @@ public:
x = norm_in->forward(ctx, x);
x = ff_in->forward(ctx, x);
// self.is_res is always True
x = ggml_add(ctx, x, x_skip);
x = ggml_add(ctx->ggml_ctx, x, x_skip);
}
auto r = x;
x = norm1->forward(ctx, x);
x = attn1->forward(ctx, backend, x, x); // self-attention
x = ggml_add(ctx, x, r);
x = attn1->forward(ctx, x, x); // self-attention
x = ggml_add(ctx->ggml_ctx, x, r);
r = x;
x = norm2->forward(ctx, x);
x = attn2->forward(ctx, backend, x, context); // cross-attention
x = ggml_add(ctx, x, r);
x = attn2->forward(ctx, x, context); // cross-attention
x = ggml_add(ctx->ggml_ctx, x, r);
r = x;
x = norm3->forward(ctx, x);
x = ff->forward(ctx, x);
x = ggml_add(ctx, x, r);
x = ggml_add(ctx->ggml_ctx, x, r);
return x;
}
@ -420,8 +414,7 @@ public:
int64_t n_head,
int64_t d_head,
int64_t depth,
int64_t context_dim,
bool flash_attn = false)
int64_t context_dim)
: in_channels(in_channels),
n_head(n_head),
d_head(d_head),
@ -435,14 +428,13 @@ public:
for (int i = 0; i < depth; i++) {
std::string name = "transformer_blocks." + std::to_string(i);
blocks[name] = std::shared_ptr<GGMLBlock>(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim, false, flash_attn));
blocks[name] = std::shared_ptr<GGMLBlock>(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim, false));
}
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1}));
}
virtual struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* context) {
// x: [N, in_channels, h, w]
@ -460,23 +452,23 @@ public:
x = norm->forward(ctx, x);
x = proj_in->forward(ctx, x); // [N, inner_dim, h, w]
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 2, 0, 3)); // [N, h, w, inner_dim]
x = ggml_reshape_3d(ctx, x, inner_dim, w * h, n); // [N, h * w, inner_dim]
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 2, 0, 3)); // [N, h, w, inner_dim]
x = ggml_reshape_3d(ctx->ggml_ctx, x, inner_dim, w * h, n); // [N, h * w, inner_dim]
for (int i = 0; i < depth; i++) {
std::string name = "transformer_blocks." + std::to_string(i);
auto transformer_block = std::dynamic_pointer_cast<BasicTransformerBlock>(blocks[name]);
x = transformer_block->forward(ctx, backend, x, context);
x = transformer_block->forward(ctx, x, context);
}
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, inner_dim, h * w]
x = ggml_reshape_4d(ctx, x, w, h, inner_dim, n); // [N, inner_dim, h, w]
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [N, inner_dim, h * w]
x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, inner_dim, n); // [N, inner_dim, h, w]
// proj_out
x = proj_out->forward(ctx, x); // [N, in_channels, h, w]
x = ggml_add(ctx, x, x_in);
x = ggml_add(ctx->ggml_ctx, x, x_in);
return x;
}
};
@ -503,14 +495,14 @@ public:
// since mix_factor.shape is [1,], we don't need rearrange using rearrange_pattern
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x_spatial,
struct ggml_tensor* x_temporal) {
// image_only_indicator is always tensor([0.])
float alpha = get_alpha();
auto x = ggml_add(ctx,
ggml_scale(ctx, x_spatial, alpha),
ggml_scale(ctx, x_temporal, 1.0f - alpha));
auto x = ggml_add(ctx->ggml_ctx,
ggml_scale(ctx->ggml_ctx, x_spatial, alpha),
ggml_scale(ctx->ggml_ctx, x_temporal, 1.0f - alpha));
return x;
}
};
@ -528,7 +520,7 @@ public:
blocks["time_mixer"] = std::shared_ptr<GGMLBlock>(new AlphaBlender());
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* emb,
int num_video_frames) {
@ -546,18 +538,18 @@ public:
int64_t H = x->ne[1];
int64_t W = x->ne[0];
x = ggml_reshape_4d(ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w)
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w)
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
auto x_mix = x;
emb = ggml_reshape_4d(ctx, emb, emb->ne[0], T, B, emb->ne[3]); // (b t) ... -> b t ...
emb = ggml_reshape_4d(ctx->ggml_ctx, emb, emb->ne[0], T, B, emb->ne[3]); // (b t) ... -> b t ...
x = time_stack->forward(ctx, x, emb); // b t c (h w)
x = time_mixer->forward(ctx, x_mix, x); // b t c (h w)
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
x = ggml_reshape_4d(ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
return x;
}

View File

@ -641,7 +641,9 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
pixel_values = to_backend(pixel_values);
struct ggml_tensor* hidden_states = vision_model.forward(compute_ctx, runtime_backend, pixel_values, return_pooled, clip_skip);
auto runner_ctx = get_context();
struct ggml_tensor* hidden_states = vision_model.forward(&runner_ctx, pixel_values, return_pooled, clip_skip);
ggml_build_forward_expand(gf, hidden_states);

View File

@ -165,7 +165,7 @@ public:
}
struct ggml_tensor* resblock_forward(std::string name,
struct ggml_context* ctx,
GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* emb) {
auto block = std::dynamic_pointer_cast<ResBlock>(blocks[name]);
@ -173,15 +173,14 @@ public:
}
struct ggml_tensor* attention_layer_forward(std::string name,
struct ggml_context* ctx,
ggml_backend_t backend,
GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* context) {
auto block = std::dynamic_pointer_cast<SpatialTransformer>(blocks[name]);
return block->forward(ctx, backend, x, context);
return block->forward(ctx, x, context);
}
struct ggml_tensor* input_hint_block_forward(struct ggml_context* ctx,
struct ggml_tensor* input_hint_block_forward(GGMLRunnerContext* ctx,
struct ggml_tensor* hint,
struct ggml_tensor* emb,
struct ggml_tensor* context) {
@ -193,14 +192,13 @@ public:
h = block->forward(ctx, h);
} else {
h = ggml_silu_inplace(ctx, h);
h = ggml_silu_inplace(ctx->ggml_ctx, h);
}
}
return h;
}
std::vector<struct ggml_tensor*> forward(struct ggml_context* ctx,
ggml_backend_t backend,
std::vector<struct ggml_tensor*> forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* hint,
struct ggml_tensor* guided_hint,
@ -213,13 +211,13 @@ public:
// y: [N, adm_in_channels] or [1, adm_in_channels]
if (context != nullptr) {
if (context->ne[2] != x->ne[3]) {
context = ggml_repeat(ctx, context, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, context->ne[0], context->ne[1], x->ne[3]));
context = ggml_repeat(ctx->ggml_ctx, context, ggml_new_tensor_3d(ctx->ggml_ctx, GGML_TYPE_F32, context->ne[0], context->ne[1], x->ne[3]));
}
}
if (y != nullptr) {
if (y->ne[1] != x->ne[3]) {
y = ggml_repeat(ctx, y, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, y->ne[0], x->ne[3]));
y = ggml_repeat(ctx->ggml_ctx, y, ggml_new_tensor_2d(ctx->ggml_ctx, GGML_TYPE_F32, y->ne[0], x->ne[3]));
}
}
@ -230,10 +228,10 @@ public:
auto middle_block_out = std::dynamic_pointer_cast<Conv2d>(blocks["middle_block_out.0"]);
auto t_emb = ggml_ext_timestep_embedding(ctx, timesteps, model_channels); // [N, model_channels]
auto t_emb = ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, model_channels); // [N, model_channels]
auto emb = time_embed_0->forward(ctx, t_emb);
emb = ggml_silu_inplace(ctx, emb);
emb = ggml_silu_inplace(ctx->ggml_ctx, emb);
emb = time_embed_2->forward(ctx, emb); // [N, time_embed_dim]
// SDXL/SVD
@ -242,10 +240,10 @@ public:
auto label_embed_2 = std::dynamic_pointer_cast<Linear>(blocks["label_emb.0.2"]);
auto label_emb = label_embed_0->forward(ctx, y);
label_emb = ggml_silu_inplace(ctx, label_emb);
label_emb = ggml_silu_inplace(ctx->ggml_ctx, label_emb);
label_emb = label_embed_2->forward(ctx, label_emb); // [N, time_embed_dim]
emb = ggml_add(ctx, emb, label_emb); // [N, time_embed_dim]
emb = ggml_add(ctx->ggml_ctx, emb, label_emb); // [N, time_embed_dim]
}
std::vector<struct ggml_tensor*> outs;
@ -259,7 +257,7 @@ public:
// input block 0
auto h = input_blocks_0_0->forward(ctx, x);
h = ggml_add(ctx, h, guided_hint);
h = ggml_add(ctx->ggml_ctx, h, guided_hint);
outs.push_back(zero_convs_0->forward(ctx, h));
// input block 1-11
@ -274,7 +272,7 @@ public:
h = resblock_forward(name, ctx, h, emb); // [N, mult*model_channels, h, w]
if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
std::string name = "input_blocks." + std::to_string(input_block_idx) + ".1";
h = attention_layer_forward(name, ctx, backend, h, context); // [N, mult*model_channels, h, w]
h = attention_layer_forward(name, ctx, h, context); // [N, mult*model_channels, h, w]
}
auto zero_conv = std::dynamic_pointer_cast<Conv2d>(blocks["zero_convs." + std::to_string(input_block_idx) + ".0"]);
@ -299,7 +297,7 @@ public:
// middle_block
h = resblock_forward("middle_block.0", ctx, h, emb); // [N, 4*model_channels, h/8, w/8]
h = attention_layer_forward("middle_block.1", ctx, backend, h, context); // [N, 4*model_channels, h/8, w/8]
h = attention_layer_forward("middle_block.1", ctx, h, context); // [N, 4*model_channels, h/8, w/8]
h = resblock_forward("middle_block.2", ctx, h, emb); // [N, 4*model_channels, h/8, w/8]
// out
@ -326,17 +324,6 @@ struct ControlNet : public GGMLRunner {
control_net.init(params_ctx, tensor_types, "");
}
void enable_conv2d_direct() {
std::vector<GGMLBlock*> blocks;
control_net.get_all_blocks(blocks);
for (auto block : blocks) {
if (block->get_desc() == "Conv2d") {
auto conv_block = (Conv2d*)block;
conv_block->enable_direct();
}
}
}
~ControlNet() override {
free_control_ctx();
}
@ -404,8 +391,9 @@ struct ControlNet : public GGMLRunner {
y = to_backend(y);
timesteps = to_backend(timesteps);
auto outs = control_net.forward(compute_ctx,
runtime_backend,
auto runner_ctx = get_context();
auto outs = control_net.forward(&runner_ctx,
x,
hint,
guided_hint_cached ? guided_hint : nullptr,

View File

@ -36,6 +36,7 @@ struct DiffusionModel {
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
virtual size_t get_params_buffer_size() = 0;
virtual int64_t get_adm_in_channels() = 0;
virtual void set_flash_attn_enabled(bool enabled) = 0;
};
struct UNetModel : public DiffusionModel {
@ -44,9 +45,8 @@ struct UNetModel : public DiffusionModel {
UNetModel(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
SDVersion version = VERSION_SD1,
bool flash_attn = false)
: unet(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, flash_attn) {
SDVersion version = VERSION_SD1)
: unet(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version) {
}
std::string get_desc() override {
@ -77,6 +77,10 @@ struct UNetModel : public DiffusionModel {
return unet.unet.adm_in_channels;
}
void set_flash_attn_enabled(bool enabled) {
unet.set_flash_attention_enabled(enabled);
}
void compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
@ -98,9 +102,8 @@ struct MMDiTModel : public DiffusionModel {
MMDiTModel(ggml_backend_t backend,
bool offload_params_to_cpu,
bool flash_attn = false,
const String2GGMLType& tensor_types = {})
: mmdit(backend, offload_params_to_cpu, flash_attn, tensor_types, "model.diffusion_model") {
: mmdit(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model") {
}
std::string get_desc() override {
@ -131,6 +134,10 @@ struct MMDiTModel : public DiffusionModel {
return 768 + 1280;
}
void set_flash_attn_enabled(bool enabled) {
mmdit.set_flash_attention_enabled(enabled);
}
void compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
@ -153,9 +160,8 @@ struct FluxModel : public DiffusionModel {
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
SDVersion version = VERSION_FLUX,
bool flash_attn = false,
bool use_mask = false)
: flux(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, flash_attn, use_mask) {
: flux(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, use_mask) {
}
std::string get_desc() override {
@ -186,6 +192,10 @@ struct FluxModel : public DiffusionModel {
return 768;
}
void set_flash_attn_enabled(bool enabled) {
flux.set_flash_attention_enabled(enabled);
}
void compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
@ -213,9 +223,8 @@ struct WanModel : public DiffusionModel {
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "model.diffusion_model",
SDVersion version = VERSION_WAN2,
bool flash_attn = false)
: prefix(prefix), wan(backend, offload_params_to_cpu, tensor_types, prefix, version, flash_attn) {
SDVersion version = VERSION_WAN2)
: prefix(prefix), wan(backend, offload_params_to_cpu, tensor_types, prefix, version) {
}
std::string get_desc() override {
@ -246,6 +255,10 @@ struct WanModel : public DiffusionModel {
return 768;
}
void set_flash_attn_enabled(bool enabled) {
wan.set_flash_attention_enabled(enabled);
}
void compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
@ -272,9 +285,8 @@ struct QwenImageModel : public DiffusionModel {
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "model.diffusion_model",
SDVersion version = VERSION_QWEN_IMAGE,
bool flash_attn = false)
: prefix(prefix), qwen_image(backend, offload_params_to_cpu, tensor_types, prefix, version, flash_attn) {
SDVersion version = VERSION_QWEN_IMAGE)
: prefix(prefix), qwen_image(backend, offload_params_to_cpu, tensor_types, prefix, version) {
}
std::string get_desc() override {
@ -305,6 +317,10 @@ struct QwenImageModel : public DiffusionModel {
return 768;
}
void set_flash_attn_enabled(bool enabled) {
qwen_image.set_flash_attention_enabled(enabled);
}
void compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
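With the flash_attn constructor parameter removed from every DiffusionModel wrapper, Flash Attention is now toggled after construction and reaches the attention blocks through GGMLRunnerContext. An illustrative call site (constructor arguments here are placeholders; only set_flash_attn_enabled comes from the diff):

    // Hypothetical usage sketch.
    UNetModel unet_model(backend, /*offload_params_to_cpu=*/false, tensor_types, VERSION_SD1);
    unet_model.set_flash_attn_enabled(true);  // forwarded to the runner, then into GGMLRunnerContext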

View File

@ -27,11 +27,11 @@ public:
blocks["conv5"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat + 4 * num_grow_ch, num_feat, {3, 3}, {1, 1}, {1, 1}));
}
struct ggml_tensor* lrelu(struct ggml_context* ctx, struct ggml_tensor* x) {
return ggml_leaky_relu(ctx, x, 0.2f, true);
struct ggml_tensor* lrelu(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
return ggml_leaky_relu(ctx->ggml_ctx, x, 0.2f, true);
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
// x: [n, num_feat, h, w]
// return: [n, num_feat, h, w]
@ -42,16 +42,16 @@ public:
auto conv5 = std::dynamic_pointer_cast<Conv2d>(blocks["conv5"]);
auto x1 = lrelu(ctx, conv1->forward(ctx, x));
auto x_cat = ggml_concat(ctx, x, x1, 2);
auto x_cat = ggml_concat(ctx->ggml_ctx, x, x1, 2);
auto x2 = lrelu(ctx, conv2->forward(ctx, x_cat));
x_cat = ggml_concat(ctx, x_cat, x2, 2);
x_cat = ggml_concat(ctx->ggml_ctx, x_cat, x2, 2);
auto x3 = lrelu(ctx, conv3->forward(ctx, x_cat));
x_cat = ggml_concat(ctx, x_cat, x3, 2);
x_cat = ggml_concat(ctx->ggml_ctx, x_cat, x3, 2);
auto x4 = lrelu(ctx, conv4->forward(ctx, x_cat));
x_cat = ggml_concat(ctx, x_cat, x4, 2);
x_cat = ggml_concat(ctx->ggml_ctx, x_cat, x4, 2);
auto x5 = conv5->forward(ctx, x_cat);
x5 = ggml_add(ctx, ggml_scale(ctx, x5, 0.2f), x);
x5 = ggml_add(ctx->ggml_ctx, ggml_scale(ctx->ggml_ctx, x5, 0.2f), x);
return x5;
}
};
@ -64,7 +64,7 @@ public:
blocks["rdb3"] = std::shared_ptr<GGMLBlock>(new ResidualDenseBlock(num_feat, num_grow_ch));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
// x: [n, num_feat, h, w]
// return: [n, num_feat, h, w]
@ -76,7 +76,7 @@ public:
out = rdb2->forward(ctx, out);
out = rdb3->forward(ctx, out);
out = ggml_add(ctx, ggml_scale(ctx, out, 0.2f), x);
out = ggml_add(ctx->ggml_ctx, ggml_scale(ctx->ggml_ctx, out, 0.2f), x);
return out;
}
};
@ -112,11 +112,11 @@ public:
int get_scale() { return scale; }
int get_num_block() { return num_block; }
struct ggml_tensor* lrelu(struct ggml_context* ctx, struct ggml_tensor* x) {
return ggml_leaky_relu(ctx, x, 0.2f, true);
struct ggml_tensor* lrelu(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
return ggml_leaky_relu(ctx->ggml_ctx, x, 0.2f, true);
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
// x: [n, num_in_ch, h, w]
// return: [n, num_out_ch, h*scale, w*scale]
auto conv_first = std::dynamic_pointer_cast<Conv2d>(blocks["conv_first"]);
@ -133,14 +133,14 @@ public:
body_feat = block->forward(ctx, body_feat);
}
body_feat = conv_body->forward(ctx, body_feat);
feat = ggml_add(ctx, feat, body_feat);
feat = ggml_add(ctx->ggml_ctx, feat, body_feat);
// upsample
if (scale >= 2) {
auto conv_up1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up1"]);
feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx->ggml_ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
if (scale == 4) {
auto conv_up2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up2"]);
feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx->ggml_ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
}
}
// for all scales
@ -161,19 +161,6 @@ struct ESRGAN : public GGMLRunner {
// rrdb_net will be created in load_from_file
}
void enable_conv2d_direct() {
if (!rrdb_net)
return;
std::vector<GGMLBlock*> blocks;
rrdb_net->get_all_blocks(blocks);
for (auto block : blocks) {
if (block->get_desc() == "Conv2d") {
auto conv_block = (Conv2d*)block;
conv_block->enable_direct();
}
}
}
std::string get_desc() override {
return "esrgan";
}
@ -359,7 +346,9 @@ struct ESRGAN : public GGMLRunner {
constexpr int kGraphNodes = 1 << 16; // 65k
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, kGraphNodes, /*grads*/ false);
x = to_backend(x);
struct ggml_tensor* out = rrdb_net->forward(compute_ctx, x);
auto runner_ctx = get_context();
struct ggml_tensor* out = rrdb_net->forward(&runner_ctx, x);
ggml_build_forward_expand(gf, out);
return gf;
}

342 flux.hpp
View File

@ -19,14 +19,14 @@ namespace Flux {
blocks["out_layer"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_dim, hidden_dim, true));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
// x: [..., in_dim]
// return: [..., hidden_dim]
auto in_layer = std::dynamic_pointer_cast<Linear>(blocks["in_layer"]);
auto out_layer = std::dynamic_pointer_cast<Linear>(blocks["out_layer"]);
x = in_layer->forward(ctx, x);
x = ggml_silu_inplace(ctx, x);
x = ggml_silu_inplace(ctx->ggml_ctx, x);
x = out_layer->forward(ctx, x);
return x;
}
@ -48,10 +48,10 @@ namespace Flux {
: hidden_size(hidden_size),
eps(eps) {}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
struct ggml_tensor* w = params["scale"];
x = ggml_rms_norm(ctx, x, eps);
x = ggml_mul(ctx, x, w);
x = ggml_rms_norm(ctx->ggml_ctx, x, eps);
x = ggml_mul(ctx->ggml_ctx, x, w);
return x;
}
};
@ -63,7 +63,7 @@ namespace Flux {
blocks["key_norm"] = std::shared_ptr<GGMLBlock>(new RMSNorm(dim));
}
struct ggml_tensor* query_norm(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* query_norm(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
// x: [..., dim]
// return: [..., dim]
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["query_norm"]);
@ -72,7 +72,7 @@ namespace Flux {
return x;
}
struct ggml_tensor* key_norm(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* key_norm(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
// x: [..., dim]
// return: [..., dim]
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["key_norm"]);
@ -85,13 +85,11 @@ namespace Flux {
struct SelfAttention : public GGMLBlock {
public:
int64_t num_heads;
bool flash_attn;
public:
SelfAttention(int64_t dim,
int64_t num_heads = 8,
bool qkv_bias = false,
bool flash_attn = false)
bool qkv_bias = false)
: num_heads(num_heads) {
int64_t head_dim = dim / num_heads;
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
@ -99,30 +97,29 @@ namespace Flux {
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
}
std::vector<struct ggml_tensor*> pre_attention(struct ggml_context* ctx, struct ggml_tensor* x) {
std::vector<struct ggml_tensor*> pre_attention(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
auto qkv_proj = std::dynamic_pointer_cast<Linear>(blocks["qkv"]);
auto norm = std::dynamic_pointer_cast<QKNorm>(blocks["norm"]);
auto qkv = qkv_proj->forward(ctx, x);
auto qkv_vec = split_qkv(ctx, qkv);
auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv);
int64_t head_dim = qkv_vec[0]->ne[0] / num_heads;
auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]);
auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]);
auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]);
auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]);
auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]);
auto v = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]);
q = norm->query_norm(ctx, q);
k = norm->key_norm(ctx, k);
return {q, k, v};
}
struct ggml_tensor* post_attention(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* post_attention(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
x = proj->forward(ctx, x); // [N, n_token, dim]
return x;
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* pe,
struct ggml_tensor* mask) {
@ -130,7 +127,7 @@ namespace Flux {
// pe: [n_token, d_head/2, 2, 2]
// return [N, n_token, dim]
auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head]
x = Rope::attention(ctx, backend, qkv[0], qkv[1], qkv[2], pe, mask, flash_attn); // [N, n_token, dim]
x = Rope::attention(ctx, qkv[0], qkv[1], qkv[2], pe, mask); // [N, n_token, dim]
x = post_attention(ctx, x); // [N, n_token, dim]
return x;
}
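Rope::attention no longer receives the backend and flash-attention flag as arguments; it presumably reads both from the context, mirroring the ggml_ext_attention_ext call in the CrossAttention hunk earlier in this diff. An assumed sketch of the updated helper (the rotary-embedding step is elided; the trailing boolean arguments follow that earlier call):

    // Assumed shape of the updated helper; not part of this diff.
    static struct ggml_tensor* attention(GGMLRunnerContext* ctx,
                                         struct ggml_tensor* q,  // [N, n_token, n_head, d_head]
                                         struct ggml_tensor* k,
                                         struct ggml_tensor* v,
                                         struct ggml_tensor* pe,
                                         struct ggml_tensor* mask) {
        // ... apply the rotary embedding in pe to q and k (unchanged, elided) ...
        int64_t n_head = q->ne[1];
        return ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v,
                                      n_head, mask, false, false,
                                      ctx->flash_attn_enabled);
    }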
@ -144,11 +141,11 @@ namespace Flux {
ModulationOut(ggml_tensor* shift = nullptr, ggml_tensor* scale = nullptr, ggml_tensor* gate = nullptr)
: shift(shift), scale(scale), gate(gate) {}
ModulationOut(struct ggml_context* ctx, ggml_tensor* vec, int64_t offset) {
ModulationOut(GGMLRunnerContext* ctx, ggml_tensor* vec, int64_t offset) {
int64_t stride = vec->nb[1] * vec->ne[1];
shift = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 0)); // [N, dim]
scale = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 1)); // [N, dim]
gate = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 2)); // [N, dim]
shift = ggml_view_2d(ctx->ggml_ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 0)); // [N, dim]
scale = ggml_view_2d(ctx->ggml_ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 1)); // [N, dim]
gate = ggml_view_2d(ctx->ggml_ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 2)); // [N, dim]
}
};
@ -164,16 +161,16 @@ namespace Flux {
blocks["lin"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * multiplier));
}
std::vector<ModulationOut> forward(struct ggml_context* ctx, struct ggml_tensor* vec) {
std::vector<ModulationOut> forward(GGMLRunnerContext* ctx, struct ggml_tensor* vec) {
// x: [N, dim]
// return: [ModulationOut, ModulationOut]
auto lin = std::dynamic_pointer_cast<Linear>(blocks["lin"]);
auto out = ggml_silu(ctx, vec);
auto out = ggml_silu(ctx->ggml_ctx, vec);
out = lin->forward(ctx, out); // [N, multiplier*dim]
auto m = ggml_reshape_3d(ctx, out, vec->ne[0], multiplier, vec->ne[1]); // [N, multiplier, dim]
m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [multiplier, N, dim]
auto m = ggml_reshape_3d(ctx->ggml_ctx, out, vec->ne[0], multiplier, vec->ne[1]); // [N, multiplier, dim]
m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [multiplier, N, dim]
ModulationOut m_0 = ModulationOut(ctx, m, 0);
if (is_double) {
@ -199,7 +196,6 @@ namespace Flux {
}
struct DoubleStreamBlock : public GGMLBlock {
bool flash_attn;
bool prune_mod;
int idx = 0;
@ -209,15 +205,14 @@ namespace Flux {
float mlp_ratio,
int idx = 0,
bool qkv_bias = false,
bool flash_attn = false,
bool prune_mod = false)
: idx(idx), flash_attn(flash_attn), prune_mod(prune_mod) {
: idx(idx), prune_mod(prune_mod) {
int64_t mlp_hidden_dim = hidden_size * mlp_ratio;
if (!prune_mod) {
blocks["img_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
}
blocks["img_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn));
blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias));
blocks["img_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["img_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim));
@ -228,7 +223,7 @@ namespace Flux {
blocks["txt_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
}
blocks["txt_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn));
blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias));
blocks["txt_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["txt_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim));
@ -236,7 +231,7 @@ namespace Flux {
blocks["txt_mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(mlp_hidden_dim, hidden_size));
}
std::vector<ModulationOut> get_distil_img_mod(struct ggml_context* ctx, struct ggml_tensor* vec) {
std::vector<ModulationOut> get_distil_img_mod(GGMLRunnerContext* ctx, struct ggml_tensor* vec) {
// TODO: not hardcoded?
const int single_blocks_count = 38;
const int double_blocks_count = 19;
@ -245,7 +240,7 @@ namespace Flux {
return {ModulationOut(ctx, vec, offset), ModulationOut(ctx, vec, offset + 3)};
}
std::vector<ModulationOut> get_distil_txt_mod(struct ggml_context* ctx, struct ggml_tensor* vec) {
std::vector<ModulationOut> get_distil_txt_mod(GGMLRunnerContext* ctx, struct ggml_tensor* vec) {
// TODO: not hardcoded?
const int single_blocks_count = 38;
const int double_blocks_count = 19;
@ -254,8 +249,7 @@ namespace Flux {
return {ModulationOut(ctx, vec, offset), ModulationOut(ctx, vec, offset + 3)};
}
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
ggml_backend_t backend,
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(GGMLRunnerContext* ctx,
struct ggml_tensor* img,
struct ggml_tensor* txt,
struct ggml_tensor* vec,
@ -300,7 +294,7 @@ namespace Flux {
// prepare image for attention
auto img_modulated = img_norm1->forward(ctx, img);
img_modulated = Flux::modulate(ctx, img_modulated, img_mod1.shift, img_mod1.scale);
img_modulated = Flux::modulate(ctx->ggml_ctx, img_modulated, img_mod1.shift, img_mod1.scale);
auto img_qkv = img_attn->pre_attention(ctx, img_modulated); // q,k,v: [N, n_img_token, n_head, d_head]
auto img_q = img_qkv[0];
auto img_k = img_qkv[1];
@ -308,20 +302,20 @@ namespace Flux {
// prepare txt for attention
auto txt_modulated = txt_norm1->forward(ctx, txt);
txt_modulated = Flux::modulate(ctx, txt_modulated, txt_mod1.shift, txt_mod1.scale);
txt_modulated = Flux::modulate(ctx->ggml_ctx, txt_modulated, txt_mod1.shift, txt_mod1.scale);
auto txt_qkv = txt_attn->pre_attention(ctx, txt_modulated); // q,k,v: [N, n_txt_token, n_head, d_head]
auto txt_q = txt_qkv[0];
auto txt_k = txt_qkv[1];
auto txt_v = txt_qkv[2];
// run actual attention
auto q = ggml_concat(ctx, txt_q, img_q, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto q = ggml_concat(ctx->ggml_ctx, txt_q, img_q, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto k = ggml_concat(ctx->ggml_ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head]
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
auto txt_attn_out = ggml_view_3d(ctx,
auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_txt_token + n_img_token, n_head*d_head]
attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx,
attn,
attn->ne[0],
attn->ne[1],
@ -329,8 +323,8 @@ namespace Flux {
attn->nb[1],
attn->nb[2],
0); // [n_txt_token, N, hidden_size]
txt_attn_out = ggml_cont(ctx, ggml_permute(ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size]
auto img_attn_out = ggml_view_3d(ctx,
txt_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size]
auto img_attn_out = ggml_view_3d(ctx->ggml_ctx,
attn,
attn->ne[0],
attn->ne[1],
@ -338,25 +332,25 @@ namespace Flux {
attn->nb[1],
attn->nb[2],
attn->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size]
img_attn_out = ggml_cont(ctx, ggml_permute(ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
img_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
// calculate the img blocks
img = ggml_add(ctx, img, ggml_mul(ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate));
img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate));
auto img_mlp_out = img_mlp_0->forward(ctx, Flux::modulate(ctx, img_norm2->forward(ctx, img), img_mod2.shift, img_mod2.scale));
img_mlp_out = ggml_gelu_inplace(ctx, img_mlp_out);
auto img_mlp_out = img_mlp_0->forward(ctx, Flux::modulate(ctx->ggml_ctx, img_norm2->forward(ctx, img), img_mod2.shift, img_mod2.scale));
img_mlp_out = ggml_gelu_inplace(ctx->ggml_ctx, img_mlp_out);
img_mlp_out = img_mlp_2->forward(ctx, img_mlp_out);
img = ggml_add(ctx, img, ggml_mul(ctx, img_mlp_out, img_mod2.gate));
img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_mlp_out, img_mod2.gate));
// calculate the txt blocks
txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_attn->post_attention(ctx, txt_attn_out), txt_mod1.gate));
txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_attn->post_attention(ctx, txt_attn_out), txt_mod1.gate));
auto txt_mlp_out = txt_mlp_0->forward(ctx, Flux::modulate(ctx, txt_norm2->forward(ctx, txt), txt_mod2.shift, txt_mod2.scale));
txt_mlp_out = ggml_gelu_inplace(ctx, txt_mlp_out);
auto txt_mlp_out = txt_mlp_0->forward(ctx, Flux::modulate(ctx->ggml_ctx, txt_norm2->forward(ctx, txt), txt_mod2.shift, txt_mod2.scale));
txt_mlp_out = ggml_gelu_inplace(ctx->ggml_ctx, txt_mlp_out);
txt_mlp_out = txt_mlp_2->forward(ctx, txt_mlp_out);
txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_mlp_out, txt_mod2.gate));
txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_mlp_out, txt_mod2.gate));
return {img, txt};
}
@ -367,7 +361,6 @@ namespace Flux {
int64_t num_heads;
int64_t hidden_size;
int64_t mlp_hidden_dim;
bool flash_attn;
bool prune_mod;
int idx = 0;
@ -377,9 +370,8 @@ namespace Flux {
float mlp_ratio = 4.0f,
int idx = 0,
float qk_scale = 0.f,
bool flash_attn = false,
bool prune_mod = false)
: hidden_size(hidden_size), num_heads(num_heads), idx(idx), flash_attn(flash_attn), prune_mod(prune_mod) {
: hidden_size(hidden_size), num_heads(num_heads), idx(idx), prune_mod(prune_mod) {
int64_t head_dim = hidden_size / num_heads;
float scale = qk_scale;
if (scale <= 0.f) {
@ -397,13 +389,12 @@ namespace Flux {
}
}
ModulationOut get_distil_mod(struct ggml_context* ctx, struct ggml_tensor* vec) {
ModulationOut get_distil_mod(GGMLRunnerContext* ctx, struct ggml_tensor* vec) {
int64_t offset = 3 * idx;
return ModulationOut(ctx, vec, offset);
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* vec,
struct ggml_tensor* pe,
@ -424,11 +415,11 @@ namespace Flux {
mod = modulation->forward(ctx, vec)[0];
}
auto x_mod = Flux::modulate(ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale);
auto x_mod = Flux::modulate(ctx->ggml_ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale);
auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim]
qkv_mlp = ggml_cont(ctx, ggml_permute(ctx, qkv_mlp, 2, 0, 1, 3)); // [hidden_size * 3 + mlp_hidden_dim, N, n_token]
qkv_mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv_mlp, 2, 0, 1, 3)); // [hidden_size * 3 + mlp_hidden_dim, N, n_token]
auto qkv = ggml_view_3d(ctx,
auto qkv = ggml_view_3d(ctx->ggml_ctx,
qkv_mlp,
qkv_mlp->ne[0],
qkv_mlp->ne[1],
@ -436,8 +427,8 @@ namespace Flux {
qkv_mlp->nb[1],
qkv_mlp->nb[2],
0); // [hidden_size * 3 , N, n_token]
qkv = ggml_cont(ctx, ggml_permute(ctx, qkv, 1, 2, 0, 3)); // [N, n_token, hidden_size * 3]
auto mlp = ggml_view_3d(ctx,
qkv = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv, 1, 2, 0, 3)); // [N, n_token, hidden_size * 3]
auto mlp = ggml_view_3d(ctx->ggml_ctx,
qkv_mlp,
qkv_mlp->ne[0],
qkv_mlp->ne[1],
@ -445,21 +436,21 @@ namespace Flux {
qkv_mlp->nb[1],
qkv_mlp->nb[2],
qkv_mlp->nb[2] * hidden_size * 3); // [mlp_hidden_dim , N, n_token]
mlp = ggml_cont(ctx, ggml_permute(ctx, mlp, 1, 2, 0, 3)); // [N, n_token, mlp_hidden_dim]
mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, mlp, 1, 2, 0, 3)); // [N, n_token, mlp_hidden_dim]
auto qkv_vec = split_qkv(ctx, qkv); // q,k,v: [N, n_token, hidden_size]
auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv); // q,k,v: [N, n_token, hidden_size]
int64_t head_dim = hidden_size / num_heads;
auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head]
auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head]
auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head]
auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head]
auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head]
auto v = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head]
q = norm->query_norm(ctx, q);
k = norm->key_norm(ctx, k);
auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_token, hidden_size]
auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_token, hidden_size]
auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim]
auto attn_mlp = ggml_concat(ctx->ggml_ctx, attn, ggml_gelu_inplace(ctx->ggml_ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim]
auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size]
output = ggml_add(ctx, x, ggml_mul(ctx, output, mod.gate));
output = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, output, mod.gate));
return output;
}
};
@ -480,16 +471,16 @@ namespace Flux {
}
}
ModulationOut get_distil_mod(struct ggml_context* ctx, struct ggml_tensor* vec) {
ModulationOut get_distil_mod(GGMLRunnerContext* ctx, struct ggml_tensor* vec) {
int64_t offset = vec->ne[2] - 2;
int64_t stride = vec->nb[1] * vec->ne[1];
auto shift = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 0)); // [N, dim]
auto scale = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 1)); // [N, dim]
auto shift = ggml_view_2d(ctx->ggml_ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 0)); // [N, dim]
auto scale = ggml_view_2d(ctx->ggml_ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 1)); // [N, dim]
// No gate
return {shift, scale, nullptr};
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* c) {
// x: [N, n_token, hidden_size]
@ -505,16 +496,16 @@ namespace Flux {
} else {
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, 2 * hidden_size]
m = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size]
m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size]
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 2 * hidden_size]
m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size]
m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size]
int64_t offset = m->nb[1] * m->ne[1];
shift = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
scale = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
shift = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
scale = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
}
x = Flux::modulate(ctx, norm_final->forward(ctx, x), shift, scale);
x = Flux::modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale);
x = linear->forward(ctx, x);
return x;
@ -533,7 +524,7 @@ namespace Flux {
blocks["out_proj"] = std::shared_ptr<GGMLBlock>(new Linear(inner_size, hidden_size, true));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
auto in_proj = std::dynamic_pointer_cast<Linear>(blocks["in_proj"]);
auto out_proj = std::dynamic_pointer_cast<Linear>(blocks["out_proj"]);
@ -541,7 +532,7 @@ namespace Flux {
for (int i = 0; i < n_layers; i++) {
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["norms." + std::to_string(i)]);
auto embed = std::dynamic_pointer_cast<MLPEmbedder>(blocks["layers." + std::to_string(i)]);
x = ggml_add_inplace(ctx, x, embed->forward(ctx, norm->forward(ctx, x)));
x = ggml_add_inplace(ctx->ggml_ctx, x, embed->forward(ctx, norm->forward(ctx, x)));
}
x = out_proj->forward(ctx, x);
@ -556,7 +547,7 @@ namespace Flux {
blocks["embedder.0"] = std::make_shared<Linear>(in_channels + max_freqs * max_freqs, hidden_size_input);
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* dct) {
// x: (B, P^2, C)
@ -564,8 +555,8 @@ namespace Flux {
// return: (B, P^2, hidden_size_input)
auto embedder = std::dynamic_pointer_cast<Linear>(blocks["embedder.0"]);
dct = ggml_repeat_4d(ctx, dct, dct->ne[0], dct->ne[1], x->ne[2], x->ne[3]);
x = ggml_concat(ctx, x, dct, 0);
dct = ggml_repeat_4d(ctx->ggml_ctx, dct, dct->ne[0], dct->ne[1], x->ne[2], x->ne[3]);
x = ggml_concat(ctx->ggml_ctx, x, dct, 0);
x = embedder->forward(ctx, x);
return x;
@ -583,7 +574,7 @@ namespace Flux {
blocks["norm"] = std::make_shared<RMSNorm>(hidden_size_x);
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* s) {
// x: (batch_size, n_token, hidden_size_x)
@ -596,31 +587,31 @@ namespace Flux {
int64_t hidden_size_x = x->ne[0];
auto mlp_params = param_generator->forward(ctx, s);
auto fc_params = ggml_ext_chunk(ctx, mlp_params, 3, 0);
auto fc1_gate = ggml_reshape_3d(ctx, fc_params[0], hidden_size_x * mlp_ratio, hidden_size_x, batch_size);
auto fc1_value = ggml_reshape_3d(ctx, fc_params[1], hidden_size_x * mlp_ratio, hidden_size_x, batch_size);
auto fc2 = ggml_reshape_3d(ctx, fc_params[2], hidden_size_x, mlp_ratio * hidden_size_x, batch_size);
auto fc_params = ggml_ext_chunk(ctx->ggml_ctx, mlp_params, 3, 0);
auto fc1_gate = ggml_reshape_3d(ctx->ggml_ctx, fc_params[0], hidden_size_x * mlp_ratio, hidden_size_x, batch_size);
auto fc1_value = ggml_reshape_3d(ctx->ggml_ctx, fc_params[1], hidden_size_x * mlp_ratio, hidden_size_x, batch_size);
auto fc2 = ggml_reshape_3d(ctx->ggml_ctx, fc_params[2], hidden_size_x, mlp_ratio * hidden_size_x, batch_size);
fc1_gate = ggml_cont(ctx, ggml_ext_torch_permute(ctx, fc1_gate, 1, 0, 2, 3)); // [batch_size, hidden_size_x*mlp_ratio, hidden_size_x]
fc1_gate = ggml_l2_norm(ctx, fc1_gate, 1e-12f);
fc1_value = ggml_cont(ctx, ggml_ext_torch_permute(ctx, fc1_value, 1, 0, 2, 3)); // [batch_size, hidden_size_x*mlp_ratio, hidden_size_x]
fc1_value = ggml_l2_norm(ctx, fc1_value, 1e-12f);
fc2 = ggml_cont(ctx, ggml_ext_torch_permute(ctx, fc2, 1, 0, 2, 3)); // [batch_size, hidden_size_x, hidden_size_x*mlp_ratio]
fc2 = ggml_l2_norm(ctx, fc2, 1e-12f);
fc1_gate = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, fc1_gate, 1, 0, 2, 3)); // [batch_size, hidden_size_x*mlp_ratio, hidden_size_x]
fc1_gate = ggml_l2_norm(ctx->ggml_ctx, fc1_gate, 1e-12f);
fc1_value = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, fc1_value, 1, 0, 2, 3)); // [batch_size, hidden_size_x*mlp_ratio, hidden_size_x]
fc1_value = ggml_l2_norm(ctx->ggml_ctx, fc1_value, 1e-12f);
fc2 = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, fc2, 1, 0, 2, 3)); // [batch_size, hidden_size_x, hidden_size_x*mlp_ratio]
fc2 = ggml_l2_norm(ctx->ggml_ctx, fc2, 1e-12f);
auto res_x = x;
x = norm->forward(ctx, x); // [batch_size, n_token, hidden_size_x]
auto x1 = ggml_mul_mat(ctx, fc1_gate, x); // [batch_size, n_token, hidden_size_x*mlp_ratio]
x1 = ggml_silu_inplace(ctx, x1);
auto x1 = ggml_mul_mat(ctx->ggml_ctx, fc1_gate, x); // [batch_size, n_token, hidden_size_x*mlp_ratio]
x1 = ggml_silu_inplace(ctx->ggml_ctx, x1);
auto x2 = ggml_mul_mat(ctx, fc1_value, x); // [batch_size, n_token, hidden_size_x*mlp_ratio]
auto x2 = ggml_mul_mat(ctx->ggml_ctx, fc1_value, x); // [batch_size, n_token, hidden_size_x*mlp_ratio]
x = ggml_mul_inplace(ctx, x1, x2); // [batch_size, n_token, hidden_size_x*mlp_ratio]
x = ggml_mul_inplace(ctx->ggml_ctx, x1, x2); // [batch_size, n_token, hidden_size_x*mlp_ratio]
x = ggml_mul_mat(ctx, fc2, x); // [batch_size, n_token, hidden_size_x]
x = ggml_mul_mat(ctx->ggml_ctx, fc2, x); // [batch_size, n_token, hidden_size_x]
x = ggml_add_inplace(ctx, x, res_x);
x = ggml_add_inplace(ctx->ggml_ctx, x, res_x);
return x;
}
@ -633,7 +624,7 @@ namespace Flux {
blocks["linear"] = std::make_shared<Linear>(hidden_size, out_channels);
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x) {
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["norm"]);
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
@ -652,15 +643,15 @@ namespace Flux {
blocks["conv"] = std::make_shared<Conv2d>(hidden_size, out_channels, std::pair{3, 3}, std::pair{1, 1}, std::pair{1, 1});
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x) {
// x: [N, C, H, W]
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["norm"]);
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [N, H, W, C]
x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 2, 0, 1, 3)); // [N, H, W, C]
x = norm->forward(ctx, x);
x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 2, 0, 3)); // [N, C, H, W]
x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 2, 0, 3)); // [N, C, H, W]
x = conv->forward(ctx, x);
return x;
@ -692,7 +683,6 @@ namespace Flux {
int theta = 10000;
bool qkv_bias = true;
bool guidance_embed = true;
bool flash_attn = true;
int64_t in_dim = 64;
ChromaRadianceParams chroma_radiance_params;
};
@ -731,7 +721,6 @@ namespace Flux {
params.mlp_ratio,
i,
params.qkv_bias,
params.flash_attn,
params.is_chroma);
}
@ -741,7 +730,6 @@ namespace Flux {
params.mlp_ratio,
i,
0.f,
params.flash_attn,
params.is_chroma);
}
@ -828,8 +816,7 @@ namespace Flux {
return x;
}
struct ggml_tensor* forward_orig(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward_orig(GGMLRunnerContext* ctx,
struct ggml_tensor* img,
struct ggml_tensor* txt,
struct ggml_tensor* timesteps,
@ -851,41 +838,41 @@ namespace Flux {
if (params.is_chroma) {
int64_t mod_index_length = 344;
auto approx = std::dynamic_pointer_cast<ChromaApproximator>(blocks["distilled_guidance_layer"]);
auto distill_timestep = ggml_ext_timestep_embedding(ctx, timesteps, 16, 10000, 1000.f);
auto distill_guidance = ggml_ext_timestep_embedding(ctx, guidance, 16, 10000, 1000.f);
auto distill_timestep = ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, 16, 10000, 1000.f);
auto distill_guidance = ggml_ext_timestep_embedding(ctx->ggml_ctx, guidance, 16, 10000, 1000.f);
// auto mod_index_arange = ggml_arange(ctx, 0, (float)mod_index_length, 1);
// ggml_arange not working on a lot of backends, precomputing it on CPU instead
GGML_ASSERT(mod_index_arange != nullptr);
auto modulation_index = ggml_ext_timestep_embedding(ctx, mod_index_arange, 32, 10000, 1000.f); // [1, 344, 32]
auto modulation_index = ggml_ext_timestep_embedding(ctx->ggml_ctx, mod_index_arange, 32, 10000, 1000.f); // [1, 344, 32]
// Batch broadcast (will it ever be useful?)
modulation_index = ggml_repeat(ctx, modulation_index, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, modulation_index->ne[0], modulation_index->ne[1], img->ne[2])); // [N, 344, 32]
modulation_index = ggml_repeat(ctx->ggml_ctx, modulation_index, ggml_new_tensor_3d(ctx->ggml_ctx, GGML_TYPE_F32, modulation_index->ne[0], modulation_index->ne[1], img->ne[2])); // [N, 344, 32]
auto timestep_guidance = ggml_concat(ctx, distill_timestep, distill_guidance, 0); // [N, 1, 32]
timestep_guidance = ggml_repeat(ctx, timestep_guidance, modulation_index); // [N, 344, 32]
auto timestep_guidance = ggml_concat(ctx->ggml_ctx, distill_timestep, distill_guidance, 0); // [N, 1, 32]
timestep_guidance = ggml_repeat(ctx->ggml_ctx, timestep_guidance, modulation_index); // [N, 344, 32]
vec = ggml_concat(ctx, timestep_guidance, modulation_index, 0); // [N, 344, 64]
vec = ggml_concat(ctx->ggml_ctx, timestep_guidance, modulation_index, 0); // [N, 344, 64]
// Permute for consistency with non-distilled modulation implementation
vec = ggml_cont(ctx, ggml_permute(ctx, vec, 0, 2, 1, 3)); // [344, N, 64]
vec = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, vec, 0, 2, 1, 3)); // [344, N, 64]
vec = approx->forward(ctx, vec); // [344, N, hidden_size]
if (y != nullptr) {
txt_img_mask = ggml_pad(ctx, y, img->ne[1], 0, 0, 0);
txt_img_mask = ggml_pad(ctx->ggml_ctx, y, img->ne[1], 0, 0, 0);
}
} else {
auto time_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["time_in"]);
auto vector_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["vector_in"]);
vec = time_in->forward(ctx, ggml_ext_timestep_embedding(ctx, timesteps, 256, 10000, 1000.f));
vec = time_in->forward(ctx, ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, 256, 10000, 1000.f));
if (params.guidance_embed) {
GGML_ASSERT(guidance != nullptr);
auto guidance_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["guidance_in"]);
// bf16 and fp16 results are different
auto g_in = ggml_ext_timestep_embedding(ctx, guidance, 256, 10000, 1000.f);
vec = ggml_add(ctx, vec, guidance_in->forward(ctx, g_in));
auto g_in = ggml_ext_timestep_embedding(ctx->ggml_ctx, guidance, 256, 10000, 1000.f);
vec = ggml_add(ctx->ggml_ctx, vec, guidance_in->forward(ctx, g_in));
}
vec = ggml_add(ctx, vec, vector_in->forward(ctx, y));
vec = ggml_add(ctx->ggml_ctx, vec, vector_in->forward(ctx, y));
}
txt = txt_in->forward(ctx, txt);
@ -897,23 +884,23 @@ namespace Flux {
auto block = std::dynamic_pointer_cast<DoubleStreamBlock>(blocks["double_blocks." + std::to_string(i)]);
auto img_txt = block->forward(ctx, backend, img, txt, vec, pe, txt_img_mask);
auto img_txt = block->forward(ctx, img, txt, vec, pe, txt_img_mask);
img = img_txt.first; // [N, n_img_token, hidden_size]
txt = img_txt.second; // [N, n_txt_token, hidden_size]
}
auto txt_img = ggml_concat(ctx, txt, img, 1); // [N, n_txt_token + n_img_token, hidden_size]
auto txt_img = ggml_concat(ctx->ggml_ctx, txt, img, 1); // [N, n_txt_token + n_img_token, hidden_size]
for (int i = 0; i < params.depth_single_blocks; i++) {
if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i + params.depth) != skip_layers.end()) {
continue;
}
auto block = std::dynamic_pointer_cast<SingleStreamBlock>(blocks["single_blocks." + std::to_string(i)]);
txt_img = block->forward(ctx, backend, txt_img, vec, pe, txt_img_mask);
txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask);
}
txt_img = ggml_cont(ctx, ggml_permute(ctx, txt_img, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
img = ggml_view_3d(ctx,
txt_img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_img, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
img = ggml_view_3d(ctx->ggml_ctx,
txt_img,
txt_img->ne[0],
txt_img->ne[1],
@ -921,7 +908,7 @@ namespace Flux {
txt_img->nb[1],
txt_img->nb[2],
txt_img->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size]
img = ggml_cont(ctx, ggml_permute(ctx, img, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
if (final_layer) {
img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels)
@ -930,8 +917,7 @@ namespace Flux {
return img;
}
struct ggml_tensor* forward_chroma_radiance(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward_chroma_radiance(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* timestep,
struct ggml_tensor* context,
@ -952,31 +938,31 @@ namespace Flux {
int pad_h = (patch_size - H % patch_size) % patch_size;
int pad_w = (patch_size - W % patch_size) % patch_size;
auto img = pad_to_patch_size(ctx, x);
auto img = pad_to_patch_size(ctx->ggml_ctx, x);
auto orig_img = img;
auto img_in_patch = std::dynamic_pointer_cast<Conv2d>(blocks["img_in_patch"]);
img = img_in_patch->forward(ctx, img); // [N, hidden_size, H/patch_size, W/patch_size]
img = ggml_reshape_3d(ctx, img, img->ne[0] * img->ne[1], img->ne[2], img->ne[3]); // [N, hidden_size, H/patch_size*W/patch_size]
img = ggml_cont(ctx, ggml_ext_torch_permute(ctx, img, 1, 0, 2, 3)); // [N, H/patch_size*W/patch_size, hidden_size]
img = ggml_reshape_3d(ctx->ggml_ctx, img, img->ne[0] * img->ne[1], img->ne[2], img->ne[3]); // [N, hidden_size, H/patch_size*W/patch_size]
img = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, img, 1, 0, 2, 3)); // [N, H/patch_size*W/patch_size, hidden_size]
auto out = forward_orig(ctx, backend, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, n_img_token, hidden_size]
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, n_img_token, hidden_size]
// nerf decode
auto nerf_image_embedder = std::dynamic_pointer_cast<NerfEmbedder>(blocks["nerf_image_embedder"]);
auto nerf_final_layer_conv = std::dynamic_pointer_cast<NerfFinalLayerConv>(blocks["nerf_final_layer_conv"]);
auto nerf_pixels = patchify(ctx, orig_img); // [N, num_patches, C * patch_size * patch_size]
auto nerf_pixels = patchify(ctx->ggml_ctx, orig_img); // [N, num_patches, C * patch_size * patch_size]
int64_t num_patches = nerf_pixels->ne[1];
nerf_pixels = ggml_reshape_3d(ctx,
nerf_pixels = ggml_reshape_3d(ctx->ggml_ctx,
nerf_pixels,
nerf_pixels->ne[0] / C,
C,
nerf_pixels->ne[1] * nerf_pixels->ne[2]); // [N*num_patches, C, patch_size*patch_size]
nerf_pixels = ggml_cont(ctx, ggml_ext_torch_permute(ctx, nerf_pixels, 1, 0, 2, 3)); // [N*num_patches, patch_size*patch_size, C]
nerf_pixels = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, nerf_pixels, 1, 0, 2, 3)); // [N*num_patches, patch_size*patch_size, C]
auto nerf_hidden = ggml_reshape_2d(ctx, out, out->ne[0], out->ne[1] * out->ne[2]); // [N*num_patches, hidden_size]
auto nerf_hidden = ggml_reshape_2d(ctx->ggml_ctx, out, out->ne[0], out->ne[1] * out->ne[2]); // [N*num_patches, hidden_size]
auto img_dct = nerf_image_embedder->forward(ctx, nerf_pixels, dct); // [N*num_patches, patch_size*patch_size, nerf_hidden_size]
for (int i = 0; i < params.chroma_radiance_params.nerf_depth; i++) {
@ -985,17 +971,16 @@ namespace Flux {
img_dct = block->forward(ctx, img_dct, nerf_hidden);
}
img_dct = ggml_cont(ctx, ggml_ext_torch_permute(ctx, img_dct, 1, 0, 2, 3)); // [N*num_patches, nerf_hidden_size, patch_size*patch_size]
img_dct = ggml_reshape_3d(ctx, img_dct, img_dct->ne[0] * img_dct->ne[1], num_patches, img_dct->ne[2] / num_patches); // [N, num_patches, nerf_hidden_size*patch_size*patch_size]
img_dct = unpatchify(ctx, img_dct, (H + pad_h) / patch_size, (W + pad_w) / patch_size); // [N, nerf_hidden_size, H, W]
img_dct = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, img_dct, 1, 0, 2, 3)); // [N*num_patches, nerf_hidden_size, patch_size*patch_size]
img_dct = ggml_reshape_3d(ctx->ggml_ctx, img_dct, img_dct->ne[0] * img_dct->ne[1], num_patches, img_dct->ne[2] / num_patches); // [N, num_patches, nerf_hidden_size*patch_size*patch_size]
img_dct = unpatchify(ctx->ggml_ctx, img_dct, (H + pad_h) / patch_size, (W + pad_w) / patch_size); // [N, nerf_hidden_size, H, W]
out = nerf_final_layer_conv->forward(ctx, img_dct); // [N, C, H, W]
return out;
}
struct ggml_tensor* forward_flux_chroma(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward_flux_chroma(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* timestep,
struct ggml_tensor* context,
@ -1016,58 +1001,57 @@ namespace Flux {
int pad_h = (patch_size - H % patch_size) % patch_size;
int pad_w = (patch_size - W % patch_size) % patch_size;
auto img = process_img(ctx, x);
auto img = process_img(ctx->ggml_ctx, x);
uint64_t img_tokens = img->ne[1];
if (params.version == VERSION_FLUX_FILL) {
GGML_ASSERT(c_concat != nullptr);
ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
ggml_tensor* masked = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
masked = process_img(ctx, masked);
mask = process_img(ctx, mask);
masked = process_img(ctx->ggml_ctx, masked);
mask = process_img(ctx->ggml_ctx, mask);
img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0);
img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, masked, mask, 0), 0);
} else if (params.version == VERSION_FLEX_2) {
GGML_ASSERT(c_concat != nullptr);
ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
ggml_tensor* control = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));
ggml_tensor* masked = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
ggml_tensor* control = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));
masked = process_img(ctx, masked);
mask = process_img(ctx, mask);
control = process_img(ctx, control);
masked = process_img(ctx->ggml_ctx, masked);
mask = process_img(ctx->ggml_ctx, mask);
control = process_img(ctx->ggml_ctx, control);
img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0);
img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, ggml_concat(ctx->ggml_ctx, masked, mask, 0), control, 0), 0);
} else if (params.version == VERSION_FLUX_CONTROLS) {
GGML_ASSERT(c_concat != nullptr);
auto control = process_img(ctx, c_concat);
img = ggml_concat(ctx, img, control, 0);
auto control = process_img(ctx->ggml_ctx, c_concat);
img = ggml_concat(ctx->ggml_ctx, img, control, 0);
}
if (ref_latents.size() > 0) {
for (ggml_tensor* ref : ref_latents) {
ref = process_img(ctx, ref);
img = ggml_concat(ctx, img, ref, 1);
ref = process_img(ctx->ggml_ctx, ref);
img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
}
}
auto out = forward_orig(ctx, backend, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size]
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size]
if (out->ne[1] > img_tokens) {
out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size]
out = ggml_view_3d(ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size]
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size]
out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size]
}
// rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size); // [N, C, H + pad_h, W + pad_w]
out = unpatchify(ctx->ggml_ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size); // [N, C, H + pad_h, W + pad_w]
return out;
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* timestep,
struct ggml_tensor* context,
@ -1091,7 +1075,6 @@ namespace Flux {
if (params.version == VERSION_CHROMA_RADIANCE) {
return forward_chroma_radiance(ctx,
backend,
x,
timestep,
context,
@ -1105,7 +1088,6 @@ namespace Flux {
skip_layers);
} else {
return forward_flux_chroma(ctx,
backend,
x,
timestep,
context,
@ -1136,11 +1118,9 @@ namespace Flux {
const String2GGMLType& tensor_types = {},
const std::string prefix = "",
SDVersion version = VERSION_FLUX,
bool flash_attn = false,
bool use_mask = false)
: GGMLRunner(backend, offload_params_to_cpu), version(version), use_mask(use_mask) {
flux_params.version = version;
flux_params.flash_attn = flash_attn;
flux_params.guidance_embed = false;
flux_params.depth = 0;
flux_params.depth_single_blocks = 0;
@ -1323,8 +1303,9 @@ namespace Flux {
set_backend_tensor_data(dct, dct_vec.data());
}
struct ggml_tensor* out = flux.forward(compute_ctx,
runtime_backend,
auto runner_ctx = get_context();
struct ggml_tensor* out = flux.forward(&runner_ctx,
x,
timesteps,
context,
@ -1435,8 +1416,7 @@ namespace Flux {
tensor_types,
"model.diffusion_model",
VERSION_CHROMA_RADIANCE,
false,
true);
false);
flux->alloc_params_buffer();
std::map<std::string, ggml_tensor*> tensors;

View File

@ -1157,8 +1157,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
struct ggml_tensor* mask = nullptr,
bool diag_mask_inf = false,
bool skip_reshape = false,
bool flash_attn = false, // avoid overflow
float kv_scale = 1.0f) {
bool flash_attn = false,
float kv_scale = 1.0f) { // avoid overflow
int64_t L_q;
int64_t L_k;
int64_t C;
@ -1462,6 +1462,13 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {
typedef std::map<std::string, enum ggml_type> String2GGMLType;
struct GGMLRunnerContext {
ggml_backend_t backend = nullptr;
ggml_context* ggml_ctx = nullptr;
bool flash_attn_enabled = false;
bool conv2d_direct_enabled = false;
};
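// Illustrative sketch only: how a block-level forward is expected to consume these
// fields. example_block_forward is hypothetical; only the member names come from the
// struct above.
struct ggml_tensor* example_block_forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
    x = ggml_silu(ctx->ggml_ctx, x);  // graph ops are built through ctx->ggml_ctx
    if (ctx->flash_attn_enabled) {
        // attention helpers receive ctx->backend and this flag to pick the flash-attention path
    }
    return x;
}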
struct GGMLRunner {
protected:
typedef std::function<struct ggml_cgraph*()> get_graph_cb_t;
@ -1488,6 +1495,9 @@ protected:
std::map<std::string, struct ggml_tensor*> cache_tensor_map; // name -> tensor
const std::string final_result_name = "ggml_runner_final_result_tensor";
bool flash_attn_enabled = false;
bool conv2d_direct_enabled = false;
void alloc_params_ctx() {
struct ggml_init_params params;
params.mem_size = static_cast<size_t>(MAX_PARAMS_TENSOR_NUM * ggml_tensor_overhead());
@ -1744,6 +1754,15 @@ public:
free_cache_ctx_and_buffer();
}
virtual GGMLRunnerContext get_context() {
GGMLRunnerContext runner_ctx;
runner_ctx.ggml_ctx = compute_ctx;
runner_ctx.backend = runtime_backend;
runner_ctx.flash_attn_enabled = flash_attn_enabled;
runner_ctx.conv2d_direct_enabled = conv2d_direct_enabled;
return runner_ctx;
}
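// Minimal usage sketch (hypothetical helper; the member `model` stands in for any
// GGMLBlock-derived module): a runner snapshots the context once per graph build
// and passes it down, mirroring the Flux runner above.
struct ggml_cgraph* build_example_graph(struct ggml_tensor* x) {
    auto runner_ctx = get_context();  // captures compute_ctx, runtime_backend and the feature flags
    struct ggml_tensor* out = model.forward(&runner_ctx, x);
    struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
    ggml_build_forward_expand(gf, out);
    return gf;
}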
void reset_compute_ctx() {
free_compute_ctx();
alloc_compute_ctx();
@ -1864,6 +1883,14 @@ public:
free_compute_buffer();
}
}
void set_flash_attention_enabled(bool enabled) {
flash_attn_enabled = enabled;
}
void set_conv2d_direct_enabled(bool enabled) {
conv2d_direct_enabled = enabled;
}
};
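// Caller-side sketch; `flux_runner` is an illustrative instance name. The flags are
// set once on the runner and reach every block through GGMLRunnerContext:
//   flux_runner.set_flash_attention_enabled(true);  // consumed as ctx->flash_attn_enabled
//   flux_runner.set_conv2d_direct_enabled(true);    // consumed by Conv2d::forward as ctx->conv2d_direct_enabled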
class GGMLBlock {
@ -1955,12 +1982,12 @@ public:
class UnaryBlock : public GGMLBlock {
public:
virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) = 0;
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) = 0;
};
class Identity : public UnaryBlock {
public:
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
return x;
}
};
@ -1974,7 +2001,7 @@ protected:
bool force_prec_f32;
float scale;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
if (in_features % ggml_blck_size(wtype) != 0 || force_f32) {
wtype = GGML_TYPE_F32;
@ -2000,13 +2027,13 @@ public:
force_prec_f32(force_prec_f32),
scale(scale) {}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
struct ggml_tensor* w = params["weight"];
struct ggml_tensor* b = nullptr;
if (bias) {
b = params["bias"];
}
return ggml_ext_linear(ctx, x, w, b, force_prec_f32, scale);
return ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale);
}
};
@ -2022,7 +2049,7 @@ class Embedding : public UnaryBlock {
protected:
int64_t embedding_dim;
int64_t num_embeddings;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") override {
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
if (!support_get_rows(wtype)) {
wtype = GGML_TYPE_F32;
@ -2036,7 +2063,7 @@ public:
num_embeddings(num_embeddings) {
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* input_ids) {
// input_ids: [N, n_token]
auto weight = params["weight"];
@ -2044,11 +2071,11 @@ public:
// There are issues with ggml batch inference, so we are expanding it here first.
// TODO: fix ggml batch inference
int64_t n = input_ids->ne[1];
input_ids = ggml_reshape_1d(ctx, input_ids, input_ids->ne[0] * input_ids->ne[1]);
input_ids = ggml_reshape_1d(ctx->ggml_ctx, input_ids, input_ids->ne[0] * input_ids->ne[1]);
input_ids = ggml_reshape_3d(ctx, input_ids, input_ids->ne[0], 1, input_ids->ne[1]);
auto embedding = ggml_get_rows(ctx, weight, input_ids);
embedding = ggml_reshape_3d(ctx, embedding, embedding->ne[0], embedding->ne[1] / n, n);
input_ids = ggml_reshape_3d(ctx->ggml_ctx, input_ids, input_ids->ne[0], 1, input_ids->ne[1]);
auto embedding = ggml_get_rows(ctx->ggml_ctx, weight, input_ids);
embedding = ggml_reshape_3d(ctx->ggml_ctx, embedding, embedding->ne[0], embedding->ne[1] / n, n);
// [N, n_token, embedding_dim]
return embedding;
@ -2064,10 +2091,9 @@ protected:
std::pair<int, int> padding;
std::pair<int, int> dilation;
bool bias;
bool direct = false;
float scale = 1.f;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") override {
enum ggml_type wtype = GGML_TYPE_F16;
params["weight"] = ggml_new_tensor_4d(ctx, wtype, kernel_size.second, kernel_size.first, in_channels, out_channels);
if (bias) {
@ -2092,10 +2118,6 @@ public:
dilation(dilation),
bias(bias) {}
void enable_direct() {
direct = true;
}
void set_scale(float scale_value) {
scale = scale_value;
}
@ -2104,13 +2126,13 @@ public:
return "Conv2d";
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
struct ggml_tensor* w = params["weight"];
struct ggml_tensor* b = nullptr;
if (bias) {
b = params["bias"];
}
return ggml_ext_conv_2d(ctx,
return ggml_ext_conv_2d(ctx->ggml_ctx,
x,
w,
b,
@ -2120,7 +2142,7 @@ public:
padding.first,
dilation.second,
dilation.first,
direct,
ctx->conv2d_direct_enabled,
scale);
}
};
@ -2135,7 +2157,7 @@ protected:
int64_t dilation;
bool bias;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") override {
enum ggml_type wtype = GGML_TYPE_F16;
params["weight"] = ggml_new_tensor_4d(ctx, wtype, 1, kernel_size, in_channels, out_channels); // 5d => 4d
if (bias) {
@ -2162,13 +2184,13 @@ public:
// x: [N, IC, ID, IH*IW]
// result: [N, OC, OD, OH*OW]
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
struct ggml_tensor* w = params["weight"];
struct ggml_tensor* b = nullptr;
if (bias) {
b = params["bias"];
}
return ggml_ext_conv_3d_nx1x1(ctx, x, w, b, stride, padding, dilation);
return ggml_ext_conv_3d_nx1x1(ctx->ggml_ctx, x, w, b, stride, padding, dilation);
}
};
@ -2182,7 +2204,7 @@ protected:
std::tuple<int, int, int> dilation;
bool bias;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") override {
enum ggml_type wtype = GGML_TYPE_F16;
params["weight"] = ggml_new_tensor_4d(ctx,
wtype,
@ -2211,13 +2233,13 @@ public:
dilation(dilation),
bias(bias) {}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
struct ggml_tensor* w = params["weight"];
struct ggml_tensor* b = nullptr;
if (bias) {
b = params["bias"];
}
return ggml_ext_conv_3d(ctx, x, w, b, in_channels,
return ggml_ext_conv_3d(ctx->ggml_ctx, x, w, b, in_channels,
std::get<2>(stride), std::get<1>(stride), std::get<0>(stride),
std::get<2>(padding), std::get<1>(padding), std::get<0>(padding),
std::get<2>(dilation), std::get<1>(dilation), std::get<0>(dilation));
@ -2231,7 +2253,7 @@ protected:
bool elementwise_affine;
bool bias;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
if (elementwise_affine) {
enum ggml_type wtype = GGML_TYPE_F32;
params["weight"] = ggml_new_tensor_1d(ctx, wtype, normalized_shape);
@ -2252,7 +2274,7 @@ public:
elementwise_affine(elementwise_affine),
bias(bias) {}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
struct ggml_tensor* w = nullptr;
struct ggml_tensor* b = nullptr;
@ -2262,7 +2284,7 @@ public:
b = params["bias"];
}
}
return ggml_ext_layer_norm(ctx, x, w, b, eps);
return ggml_ext_layer_norm(ctx->ggml_ctx, x, w, b, eps);
}
};
@ -2273,7 +2295,7 @@ protected:
float eps;
bool affine;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
if (affine) {
enum ggml_type wtype = GGML_TYPE_F32;
enum ggml_type bias_wtype = GGML_TYPE_F32;
@ -2292,14 +2314,14 @@ public:
eps(eps),
affine(affine) {}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
struct ggml_tensor* w = nullptr;
struct ggml_tensor* b = nullptr;
if (affine) {
w = params["weight"];
b = params["bias"];
}
return ggml_ext_group_norm(ctx, x, w, b, num_groups);
return ggml_ext_group_norm(ctx->ggml_ctx, x, w, b, num_groups);
}
};
@ -2314,7 +2336,7 @@ protected:
int64_t hidden_size;
float eps;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") override {
enum ggml_type wtype = GGML_TYPE_F32;
params["weight"] = ggml_new_tensor_1d(ctx, wtype, hidden_size);
}
@ -2325,10 +2347,10 @@ public:
: hidden_size(hidden_size),
eps(eps) {}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
struct ggml_tensor* w = params["weight"];
x = ggml_rms_norm(ctx, x, eps);
x = ggml_mul_inplace(ctx, x, w);
x = ggml_rms_norm(ctx->ggml_ctx, x, eps);
x = ggml_mul_inplace(ctx->ggml_ctx, x, w);
return x;
}
};
@ -2364,8 +2386,7 @@ public:
}
// x: [N, n_token, embed_dim]
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
bool mask = false) {
auto q_proj = std::dynamic_pointer_cast<Linear>(blocks[q_proj_name]);
@ -2377,7 +2398,7 @@ public:
struct ggml_tensor* k = k_proj->forward(ctx, x);
struct ggml_tensor* v = v_proj->forward(ctx, x);
x = ggml_ext_attention_ext(ctx, backend, q, k, v, n_head, nullptr, mask); // [N, n_token, embed_dim]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, mask); // [N, n_token, embed_dim]
x = out_proj->forward(ctx, x); // [N, n_token, embed_dim]
return x;

View File

@ -27,7 +27,7 @@ namespace LTXV {
bias));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
bool causal = true) {
// x: [N*IC, ID, IH, IW]

mmdit.hpp
View File

@ -27,13 +27,13 @@ public:
blocks["fc2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_features, out_features, bias));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
// x: [N, n_token, in_features]
auto fc1 = std::dynamic_pointer_cast<Linear>(blocks["fc1"]);
auto fc2 = std::dynamic_pointer_cast<Linear>(blocks["fc2"]);
x = fc1->forward(ctx, x);
x = ggml_gelu_inplace(ctx, x);
x = ggml_gelu_inplace(ctx->ggml_ctx, x);
x = fc2->forward(ctx, x);
return x;
}
@ -72,7 +72,7 @@ public:
bias));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
// x: [N, C, H, W]
// return: [N, H*W, embed_dim]
auto proj = std::dynamic_pointer_cast<Conv2d>(blocks["proj"]);
@ -82,13 +82,13 @@ public:
int64_t H = x->ne[1];
int pad_h = (patch_size - H % patch_size) % patch_size;
int pad_w = (patch_size - W % patch_size) % patch_size;
x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // TODO: reflect pad mode
x = ggml_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0); // TODO: reflect pad mode
}
x = proj->forward(ctx, x);
if (flatten) {
x = ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1], x->ne[2], x->ne[3]);
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3));
x = ggml_reshape_3d(ctx->ggml_ctx, x, x->ne[0] * x->ne[1], x->ne[2], x->ne[3]);
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3));
}
return x;
}
@ -107,16 +107,16 @@ public:
blocks["mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size, true, true));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* t) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* t) {
// t: [N, ]
// return: [N, hidden_size]
auto mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["mlp.0"]);
auto mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["mlp.2"]);
auto t_freq = ggml_ext_timestep_embedding(ctx, t, frequency_embedding_size); // [N, frequency_embedding_size]
auto t_freq = ggml_ext_timestep_embedding(ctx->ggml_ctx, t, frequency_embedding_size); // [N, frequency_embedding_size]
auto t_emb = mlp_0->forward(ctx, t_freq);
t_emb = ggml_silu_inplace(ctx, t_emb);
t_emb = ggml_silu_inplace(ctx->ggml_ctx, t_emb);
t_emb = mlp_2->forward(ctx, t_emb);
return t_emb;
}
@ -131,14 +131,14 @@ public:
blocks["mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size, true, true));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
// x: [N, input_dim]
// return: [N, hidden_size]
auto mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["mlp.0"]);
auto mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["mlp.2"]);
x = mlp_0->forward(ctx, x);
x = ggml_silu_inplace(ctx, x);
x = ggml_silu_inplace(ctx->ggml_ctx, x);
x = mlp_2->forward(ctx, x);
return x;
}
@ -149,16 +149,14 @@ public:
int64_t num_heads;
bool pre_only;
std::string qk_norm;
bool flash_attn;
public:
SelfAttention(int64_t dim,
int64_t num_heads = 8,
std::string qk_norm = "",
bool qkv_bias = false,
bool pre_only = false,
bool flash_attn = false)
: num_heads(num_heads), pre_only(pre_only), qk_norm(qk_norm), flash_attn(flash_attn) {
bool pre_only = false)
: num_heads(num_heads), pre_only(pre_only), qk_norm(qk_norm) {
int64_t d_head = dim / num_heads;
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
if (!pre_only) {
@ -173,14 +171,14 @@ public:
}
}
std::vector<struct ggml_tensor*> pre_attention(struct ggml_context* ctx, struct ggml_tensor* x) {
std::vector<struct ggml_tensor*> pre_attention(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
auto qkv_proj = std::dynamic_pointer_cast<Linear>(blocks["qkv"]);
auto qkv = qkv_proj->forward(ctx, x);
auto qkv_vec = split_qkv(ctx, qkv);
auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv);
int64_t head_dim = qkv_vec[0]->ne[0] / num_heads;
auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head]
auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head]
auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head]
auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head]
auto v = qkv_vec[2]; // [N, n_token, n_head*d_head]
if (qk_norm == "rms" || qk_norm == "ln") {
@ -190,13 +188,13 @@ public:
k = ln_k->forward(ctx, k);
}
q = ggml_reshape_3d(ctx, q, q->ne[0] * q->ne[1], q->ne[2], q->ne[3]); // [N, n_token, n_head*d_head]
k = ggml_reshape_3d(ctx, k, k->ne[0] * k->ne[1], k->ne[2], k->ne[3]); // [N, n_token, n_head*d_head]
q = ggml_reshape_3d(ctx->ggml_ctx, q, q->ne[0] * q->ne[1], q->ne[2], q->ne[3]); // [N, n_token, n_head*d_head]
k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0] * k->ne[1], k->ne[2], k->ne[3]); // [N, n_token, n_head*d_head]
return {q, k, v};
}
struct ggml_tensor* post_attention(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* post_attention(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
GGML_ASSERT(!pre_only);
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
@ -206,11 +204,10 @@ public:
}
// x: [N, n_token, dim]
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x) {
auto qkv = pre_attention(ctx, x);
x = ggml_ext_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, true); // [N, n_token, dim]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = post_attention(ctx, x); // [N, n_token, dim]
return x;
}
@ -236,7 +233,6 @@ public:
int64_t num_heads;
bool pre_only;
bool self_attn;
bool flash_attn;
public:
DismantledBlock(int64_t hidden_size,
@ -245,17 +241,16 @@ public:
std::string qk_norm = "",
bool qkv_bias = false,
bool pre_only = false,
bool self_attn = false,
bool flash_attn = false)
bool self_attn = false)
: num_heads(num_heads), pre_only(pre_only), self_attn(self_attn) {
// rmsnorm is always False
// scale_mod_only is always False
// swiglu is always False
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
blocks["attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only, flash_attn));
blocks["attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only));
if (self_attn) {
blocks["attn2"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false, flash_attn));
blocks["attn2"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false));
}
if (!pre_only) {
@ -274,7 +269,7 @@ public:
blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, n_mods * hidden_size));
}
std::tuple<std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>> pre_attention_x(struct ggml_context* ctx,
std::tuple<std::vector<ggml_tensor*>, std::vector<ggml_tensor*>, std::vector<ggml_tensor*>> pre_attention_x(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* c) {
GGML_ASSERT(self_attn);
@ -286,35 +281,35 @@ public:
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
int64_t n_mods = 9;
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, n_mods * hidden_size]
m = ggml_reshape_3d(ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size]
m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size]
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size]
m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size]
m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size]
int64_t offset = m->nb[1] * m->ne[1];
auto shift_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
auto scale_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
auto gate_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size]
auto shift_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
auto scale_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
auto gate_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size]
auto shift_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size]
auto scale_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size]
auto gate_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size]
auto shift_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size]
auto scale_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size]
auto gate_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size]
auto shift_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 6); // [N, hidden_size]
auto scale_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 7); // [N, hidden_size]
auto gate_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 8); // [N, hidden_size]
auto shift_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 6); // [N, hidden_size]
auto scale_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 7); // [N, hidden_size]
auto gate_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 8); // [N, hidden_size]
auto x_norm = norm1->forward(ctx, x);
auto attn_in = modulate(ctx, x_norm, shift_msa, scale_msa);
auto attn_in = modulate(ctx->ggml_ctx, x_norm, shift_msa, scale_msa);
auto qkv = attn->pre_attention(ctx, attn_in);
auto attn2_in = modulate(ctx, x_norm, shift_msa2, scale_msa2);
auto attn2_in = modulate(ctx->ggml_ctx, x_norm, shift_msa2, scale_msa2);
auto qkv2 = attn2->pre_attention(ctx, attn2_in);
return {qkv, qkv2, {x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2}};
}
std::pair<std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>> pre_attention(struct ggml_context* ctx,
std::pair<std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>> pre_attention(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* c) {
// x: [N, n_token, hidden_size]
@ -327,33 +322,33 @@ public:
if (pre_only) {
n_mods = 2;
}
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, n_mods * hidden_size]
m = ggml_reshape_3d(ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size]
m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size]
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size]
m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size]
m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size]
int64_t offset = m->nb[1] * m->ne[1];
auto shift_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
auto scale_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
auto shift_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
auto scale_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
if (!pre_only) {
auto gate_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size]
auto shift_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size]
auto scale_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size]
auto gate_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size]
auto gate_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size]
auto shift_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size]
auto scale_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size]
auto gate_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size]
auto attn_in = modulate(ctx, norm1->forward(ctx, x), shift_msa, scale_msa);
auto attn_in = modulate(ctx->ggml_ctx, norm1->forward(ctx, x), shift_msa, scale_msa);
auto qkv = attn->pre_attention(ctx, attn_in);
return {qkv, {x, gate_msa, shift_mlp, scale_mlp, gate_mlp}};
} else {
auto attn_in = modulate(ctx, norm1->forward(ctx, x), shift_msa, scale_msa);
auto attn_in = modulate(ctx->ggml_ctx, norm1->forward(ctx, x), shift_msa, scale_msa);
auto qkv = attn->pre_attention(ctx, attn_in);
return {qkv, {nullptr, nullptr, nullptr, nullptr, nullptr}};
}
}
struct ggml_tensor* post_attention_x(struct ggml_context* ctx,
struct ggml_tensor* post_attention_x(GGMLRunnerContext* ctx,
struct ggml_tensor* attn_out,
struct ggml_tensor* attn2_out,
struct ggml_tensor* x,
@ -376,22 +371,22 @@ public:
auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
auto mlp = std::dynamic_pointer_cast<Mlp>(blocks["mlp"]);
gate_msa = ggml_reshape_3d(ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]); // [N, 1, hidden_size]
gate_mlp = ggml_reshape_3d(ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]); // [N, 1, hidden_size]
gate_msa2 = ggml_reshape_3d(ctx, gate_msa2, gate_msa2->ne[0], 1, gate_msa2->ne[1]); // [N, 1, hidden_size]
gate_msa = ggml_reshape_3d(ctx->ggml_ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]); // [N, 1, hidden_size]
gate_mlp = ggml_reshape_3d(ctx->ggml_ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]); // [N, 1, hidden_size]
gate_msa2 = ggml_reshape_3d(ctx->ggml_ctx, gate_msa2, gate_msa2->ne[0], 1, gate_msa2->ne[1]); // [N, 1, hidden_size]
attn_out = attn->post_attention(ctx, attn_out);
attn2_out = attn2->post_attention(ctx, attn2_out);
x = ggml_add(ctx, x, ggml_mul(ctx, attn_out, gate_msa));
x = ggml_add(ctx, x, ggml_mul(ctx, attn2_out, gate_msa2));
auto mlp_out = mlp->forward(ctx, modulate(ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp));
x = ggml_add(ctx, x, ggml_mul(ctx, mlp_out, gate_mlp));
x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, attn_out, gate_msa));
x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, attn2_out, gate_msa2));
auto mlp_out = mlp->forward(ctx, modulate(ctx->ggml_ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp));
x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, mlp_out, gate_mlp));
return x;
}
struct ggml_tensor* post_attention(struct ggml_context* ctx,
struct ggml_tensor* post_attention(GGMLRunnerContext* ctx,
struct ggml_tensor* attn_out,
struct ggml_tensor* x,
struct ggml_tensor* gate_msa,
@ -411,20 +406,19 @@ public:
auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
auto mlp = std::dynamic_pointer_cast<Mlp>(blocks["mlp"]);
gate_msa = ggml_reshape_3d(ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]); // [N, 1, hidden_size]
gate_mlp = ggml_reshape_3d(ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]); // [N, 1, hidden_size]
gate_msa = ggml_reshape_3d(ctx->ggml_ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]); // [N, 1, hidden_size]
gate_mlp = ggml_reshape_3d(ctx->ggml_ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]); // [N, 1, hidden_size]
attn_out = attn->post_attention(ctx, attn_out);
x = ggml_add(ctx, x, ggml_mul(ctx, attn_out, gate_msa));
auto mlp_out = mlp->forward(ctx, modulate(ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp));
x = ggml_add(ctx, x, ggml_mul(ctx, mlp_out, gate_mlp));
x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, attn_out, gate_msa));
auto mlp_out = mlp->forward(ctx, modulate(ctx->ggml_ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp));
x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, mlp_out, gate_mlp));
return x;
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* c) {
// x: [N, n_token, hidden_size]
@ -441,8 +435,8 @@ public:
auto qkv2 = std::get<1>(qkv_intermediates);
auto intermediates = std::get<2>(qkv_intermediates);
auto attn_out = ggml_ext_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim]
auto attn2_out = ggml_ext_attention_ext(ctx, backend, qkv2[0], qkv2[1], qkv2[2], num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim]
auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
auto attn2_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv2[0], qkv2[1], qkv2[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = post_attention_x(ctx,
attn_out,
attn2_out,
@ -458,7 +452,7 @@ public:
auto qkv = qkv_intermediates.first;
auto intermediates = qkv_intermediates.second;
auto attn_out = ggml_ext_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim]
auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = post_attention(ctx,
attn_out,
intermediates[0],
@ -472,9 +466,7 @@ public:
};
__STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*>
block_mixing(struct ggml_context* ctx,
ggml_backend_t backend,
bool flash_attn,
block_mixing(GGMLRunnerContext* ctx,
struct ggml_tensor* context,
struct ggml_tensor* x,
struct ggml_tensor* c,
@ -501,12 +493,12 @@ block_mixing(struct ggml_context* ctx,
}
std::vector<struct ggml_tensor*> qkv;
for (int i = 0; i < 3; i++) {
qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1));
qkv.push_back(ggml_concat(ctx->ggml_ctx, context_qkv[i], x_qkv[i], 1));
}
auto attn = ggml_ext_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, false, flash_attn); // [N, n_context + n_token, hidden_size]
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_context + n_token, N, hidden_size]
auto context_attn = ggml_view_3d(ctx,
auto attn = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_context + n_token, hidden_size]
attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_context + n_token, N, hidden_size]
auto context_attn = ggml_view_3d(ctx->ggml_ctx,
attn,
attn->ne[0],
attn->ne[1],
@ -514,8 +506,8 @@ block_mixing(struct ggml_context* ctx,
attn->nb[1],
attn->nb[2],
0); // [n_context, N, hidden_size]
context_attn = ggml_cont(ctx, ggml_permute(ctx, context_attn, 0, 2, 1, 3)); // [N, n_context, hidden_size]
auto x_attn = ggml_view_3d(ctx,
context_attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, context_attn, 0, 2, 1, 3)); // [N, n_context, hidden_size]
auto x_attn = ggml_view_3d(ctx->ggml_ctx,
attn,
attn->ne[0],
attn->ne[1],
@ -523,7 +515,7 @@ block_mixing(struct ggml_context* ctx,
attn->nb[1],
attn->nb[2],
attn->nb[2] * context->ne[1]); // [n_token, N, hidden_size]
x_attn = ggml_cont(ctx, ggml_permute(ctx, x_attn, 0, 2, 1, 3)); // [N, n_token, hidden_size]
x_attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x_attn, 0, 2, 1, 3)); // [N, n_token, hidden_size]
if (!context_block->pre_only) {
context = context_block->post_attention(ctx,
@ -538,7 +530,7 @@ block_mixing(struct ggml_context* ctx,
}
if (x_block->self_attn) {
auto attn2 = ggml_ext_attention_ext(ctx, backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads); // [N, n_token, hidden_size]
auto attn2 = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, hidden_size]
x = x_block->post_attention_x(ctx,
x_attn,
@ -563,8 +555,6 @@ block_mixing(struct ggml_context* ctx,
}
struct JointBlock : public GGMLBlock {
bool flash_attn;
public:
JointBlock(int64_t hidden_size,
int64_t num_heads,
@ -572,22 +562,19 @@ public:
std::string qk_norm = "",
bool qkv_bias = false,
bool pre_only = false,
bool self_attn_x = false,
bool flash_attn = false)
: flash_attn(flash_attn) {
blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only, false, flash_attn));
blocks["x_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x, flash_attn));
bool self_attn_x = false) {
blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only, false));
blocks["x_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x));
}
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
ggml_backend_t backend,
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(GGMLRunnerContext* ctx,
struct ggml_tensor* context,
struct ggml_tensor* x,
struct ggml_tensor* c) {
auto context_block = std::dynamic_pointer_cast<DismantledBlock>(blocks["context_block"]);
auto x_block = std::dynamic_pointer_cast<DismantledBlock>(blocks["x_block"]);
return block_mixing(ctx, backend, flash_attn, context, x, c, context_block, x_block);
return block_mixing(ctx, context, x, c, context_block, x_block);
}
};
@ -603,7 +590,7 @@ public:
blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, 2 * hidden_size));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* c) {
// x: [N, n_token, hidden_size]
@ -613,15 +600,15 @@ public:
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, 2 * hidden_size]
m = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size]
m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size]
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 2 * hidden_size]
m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size]
m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size]
int64_t offset = m->nb[1] * m->ne[1];
auto shift = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
auto scale = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
auto shift = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
auto scale = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
x = modulate(ctx, norm_final->forward(ctx, x), shift, scale);
x = modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale);
x = linear->forward(ctx, x);
return x;
@ -645,7 +632,6 @@ protected:
int64_t context_embedder_out_dim = 1536;
int64_t hidden_size;
std::string qk_norm;
bool flash_attn = false;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") override {
enum ggml_type wtype = GGML_TYPE_F32;
@ -653,8 +639,7 @@ protected:
}
public:
MMDiT(bool flash_attn = false, const String2GGMLType& tensor_types = {})
: flash_attn(flash_attn) {
MMDiT(const String2GGMLType& tensor_types = {}) {
// input_size is always None
// learn_sigma is always False
// register_length is always 0
@ -722,8 +707,7 @@ public:
qk_norm,
true,
i == depth - 1,
i <= d_self,
flash_attn));
i <= d_self));
}
blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new FinalLayer(hidden_size, patch_size, out_channels));
@ -791,8 +775,7 @@ public:
return x;
}
struct ggml_tensor* forward_core_with_concat(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward_core_with_concat(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* c_mod,
struct ggml_tensor* context,
@ -811,7 +794,7 @@ public:
auto block = std::dynamic_pointer_cast<JointBlock>(blocks["joint_blocks." + std::to_string(i)]);
auto context_x = block->forward(ctx, backend, context, x, c_mod);
auto context_x = block->forward(ctx, context, x, c_mod);
context = context_x.first;
x = context_x.second;
}
@ -821,8 +804,7 @@ public:
return x;
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* t,
struct ggml_tensor* y = nullptr,
@ -841,15 +823,15 @@ public:
int64_t h = x->ne[1];
auto patch_embed = x_embedder->forward(ctx, x); // [N, H*W, hidden_size]
auto pos_embed = cropped_pos_embed(ctx, h, w); // [1, H*W, hidden_size]
x = ggml_add(ctx, patch_embed, pos_embed); // [N, H*W, hidden_size]
auto pos_embed = cropped_pos_embed(ctx->ggml_ctx, h, w); // [1, H*W, hidden_size]
x = ggml_add(ctx->ggml_ctx, patch_embed, pos_embed); // [N, H*W, hidden_size]
auto c = t_embedder->forward(ctx, t); // [N, hidden_size]
if (y != nullptr && adm_in_channels != -1) {
auto y_embedder = std::dynamic_pointer_cast<VectorEmbedder>(blocks["y_embedder"]);
y = y_embedder->forward(ctx, y); // [N, hidden_size]
c = ggml_add(ctx, c, y);
c = ggml_add(ctx->ggml_ctx, c, y);
}
if (context != nullptr) {
@ -858,9 +840,9 @@ public:
context = context_embedder->forward(ctx, context); // [N, L, D] aka [N, L, 1536]
}
x = forward_core_with_concat(ctx, backend, x, c, context, skip_layers); // (N, H*W, patch_size ** 2 * out_channels)
x = forward_core_with_concat(ctx, x, c, context, skip_layers); // (N, H*W, patch_size ** 2 * out_channels)
x = unpatchify(ctx, x, h, w); // [N, C, H, W]
x = unpatchify(ctx->ggml_ctx, x, h, w); // [N, C, H, W]
return x;
}
@ -870,10 +852,9 @@ struct MMDiTRunner : public GGMLRunner {
MMDiTRunner(ggml_backend_t backend,
bool offload_params_to_cpu,
bool flash_attn,
const String2GGMLType& tensor_types = {},
const std::string prefix = "")
: GGMLRunner(backend, offload_params_to_cpu), mmdit(flash_attn, tensor_types) {
: GGMLRunner(backend, offload_params_to_cpu), mmdit(tensor_types) {
mmdit.init(params_ctx, tensor_types, prefix);
}
@ -897,8 +878,8 @@ struct MMDiTRunner : public GGMLRunner {
y = to_backend(y);
timesteps = to_backend(timesteps);
struct ggml_tensor* out = mmdit.forward(compute_ctx,
runtime_backend,
auto runner_ctx = get_context();
struct ggml_tensor* out = mmdit.forward(&runner_ctx,
x,
timesteps,
y,
@ -972,7 +953,7 @@ struct MMDiTRunner : public GGMLRunner {
// ggml_backend_t backend = ggml_backend_cuda_init(0);
ggml_backend_t backend = ggml_backend_cpu_init();
ggml_type model_data_type = GGML_TYPE_F16;
std::shared_ptr<MMDiTRunner> mmdit = std::make_shared<MMDiTRunner>(backend, false, false);
std::shared_ptr<MMDiTRunner> mmdit = std::make_shared<MMDiTRunner>(backend, false);
{
LOG_INFO("loading from '%s'", file_path.c_str());

pmid.hpp

@ -21,7 +21,7 @@ public:
blocks["layernorm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(in_dim));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
// x: [N, channels, h, w]
auto fc1 = std::dynamic_pointer_cast<Linear>(blocks["fc1"]);
@ -33,11 +33,11 @@ public:
x = layer_norm->forward(ctx, x);
// x = ggml_add(ctx, ggml_mul_mat(ctx, fc1_w, x), fc1_b);
x = fc1->forward(ctx, x);
x = ggml_gelu_inplace(ctx, x);
x = ggml_gelu_inplace(ctx->ggml_ctx, x);
x = fc2->forward(ctx, x);
// x = ggml_add(ctx, ggml_mul_mat(ctx, fc2_w, x), fc2_b);
if (use_residue)
x = ggml_add(ctx, x, r);
x = ggml_add(ctx->ggml_ctx, x, r);
return x;
}
};
@ -54,7 +54,7 @@ public:
blocks["1"] = std::shared_ptr<GGMLBlock>(new Mlp(dim, inner_dim, dim, false));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x) {
auto norm = std::dynamic_pointer_cast<LayerNorm>(blocks["0"]);
auto ff = std::dynamic_pointer_cast<Mlp>(blocks["1"]);
@ -100,7 +100,7 @@ public:
ggml_cont(ctx, tli)};
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* latents) {
// x (torch.Tensor): image features
@ -118,33 +118,33 @@ public:
auto to_q = std::dynamic_pointer_cast<Linear>(blocks["to_q"]);
auto q = to_q->forward(ctx, latents);
auto kv_input = ggml_concat(ctx, x, latents, 1);
auto kv_input = ggml_concat(ctx->ggml_ctx, x, latents, 1);
auto to_kv = std::dynamic_pointer_cast<Linear>(blocks["to_kv"]);
auto kv = to_kv->forward(ctx, kv_input);
auto k = ggml_view_4d(ctx, kv, kv->ne[0] / 2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1] / 2, kv->nb[2] / 2, kv->nb[3] / 2, 0);
auto v = ggml_view_4d(ctx, kv, kv->ne[0] / 2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1] / 2, kv->nb[2] / 2, kv->nb[3] / 2, kv->nb[0] * (kv->ne[0] / 2));
k = ggml_cont(ctx, k);
v = ggml_cont(ctx, v);
q = reshape_tensor(ctx, q, heads);
k = reshape_tensor(ctx, k, heads);
v = reshape_tensor(ctx, v, heads);
auto k = ggml_view_4d(ctx->ggml_ctx, kv, kv->ne[0] / 2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1] / 2, kv->nb[2] / 2, kv->nb[3] / 2, 0);
auto v = ggml_view_4d(ctx->ggml_ctx, kv, kv->ne[0] / 2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1] / 2, kv->nb[2] / 2, kv->nb[3] / 2, kv->nb[0] * (kv->ne[0] / 2));
k = ggml_cont(ctx->ggml_ctx, k);
v = ggml_cont(ctx->ggml_ctx, v);
q = reshape_tensor(ctx->ggml_ctx, q, heads);
k = reshape_tensor(ctx->ggml_ctx, k, heads);
v = reshape_tensor(ctx->ggml_ctx, v, heads);
scale = 1.f / sqrt(sqrt((float)dim_head));
k = ggml_scale_inplace(ctx, k, scale);
q = ggml_scale_inplace(ctx, q, scale);
k = ggml_scale_inplace(ctx->ggml_ctx, k, scale);
q = ggml_scale_inplace(ctx->ggml_ctx, q, scale);
// auto weight = ggml_mul_mat(ctx, q, k);
auto weight = ggml_mul_mat(ctx, k, q); // NOTE order of mul is opposite to pytorch
auto weight = ggml_mul_mat(ctx->ggml_ctx, k, q); // NOTE order of mul is opposite to pytorch
// GGML's softmax() is equivalent to pytorch's softmax(x, dim=-1)
// in this case, dimension along which Softmax will be computed is the last dim
// in torch and the first dim in GGML, consistent with the convention that pytorch's
// last dimension (varying most rapidly) corresponds to GGML's first (varying most rapidly).
// weight = ggml_soft_max(ctx, weight);
weight = ggml_soft_max_inplace(ctx, weight);
v = ggml_cont(ctx, ggml_transpose(ctx, v));
weight = ggml_soft_max_inplace(ctx->ggml_ctx, weight);
v = ggml_cont(ctx->ggml_ctx, ggml_transpose(ctx->ggml_ctx, v));
// auto out = ggml_mul_mat(ctx, weight, v);
auto out = ggml_mul_mat(ctx, v, weight); // NOTE order of mul is opposite to pytorch
out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3));
out = ggml_reshape_3d(ctx, out, ne[0], ne[1], ggml_nelements(out) / (ne[0] * ne[1]));
auto out = ggml_mul_mat(ctx->ggml_ctx, v, weight); // NOTE order of mul is opposite to pytorch
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3));
out = ggml_reshape_3d(ctx->ggml_ctx, out, ne[0], ne[1], ggml_nelements(out) / (ne[0] * ne[1]));
auto to_out = std::dynamic_pointer_cast<Linear>(blocks["to_out"]);
out = to_out->forward(ctx, out);
return out;
@ -176,7 +176,7 @@ public:
}
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* latents,
struct ggml_tensor* x) {
// x: [N, channels, h, w]
@ -191,9 +191,9 @@ public:
name = "layers." + std::to_string(i) + ".1";
auto ff = std::dynamic_pointer_cast<PMFeedForward>(blocks[name]);
auto t = attn->forward(ctx, x, latents);
latents = ggml_add(ctx, t, latents);
latents = ggml_add(ctx->ggml_ctx, t, latents);
t = ff->forward(ctx, latents);
latents = ggml_add(ctx, t, latents);
latents = ggml_add(ctx->ggml_ctx, t, latents);
}
latents = proj_out->forward(ctx, latents);
latents = norm_out->forward(ctx, latents);
@ -225,7 +225,7 @@ public:
4));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* last_hidden_state) {
// x: [N, channels, h, w]
@ -235,11 +235,11 @@ public:
x = token_proj->forward(ctx, x);
int64_t nel = ggml_nelements(x);
x = ggml_reshape_3d(ctx, x, cross_attention_dim, num_tokens, nel / (cross_attention_dim * num_tokens));
x = ggml_reshape_3d(ctx->ggml_ctx, x, cross_attention_dim, num_tokens, nel / (cross_attention_dim * num_tokens));
x = token_norm->forward(ctx, x);
struct ggml_tensor* out = perceiver_resampler->forward(ctx, x, last_hidden_state);
if (use_residul)
out = ggml_add(ctx, x, out);
out = ggml_add(ctx->ggml_ctx, x, out);
return out;
}
};
@ -256,24 +256,24 @@ public:
blocks["layer_norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(embed_dim));
}
struct ggml_tensor* fuse_fn(struct ggml_context* ctx,
struct ggml_tensor* fuse_fn(GGMLRunnerContext* ctx,
struct ggml_tensor* prompt_embeds,
struct ggml_tensor* id_embeds) {
auto mlp1 = std::dynamic_pointer_cast<FuseBlock>(blocks["mlp1"]);
auto mlp2 = std::dynamic_pointer_cast<FuseBlock>(blocks["mlp2"]);
auto layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm"]);
auto stacked_id_embeds = ggml_concat(ctx, prompt_embeds, id_embeds, 0);
auto stacked_id_embeds = ggml_concat(ctx->ggml_ctx, prompt_embeds, id_embeds, 0);
stacked_id_embeds = mlp1->forward(ctx, stacked_id_embeds);
stacked_id_embeds = ggml_add(ctx, stacked_id_embeds, prompt_embeds);
stacked_id_embeds = ggml_add(ctx->ggml_ctx, stacked_id_embeds, prompt_embeds);
stacked_id_embeds = mlp2->forward(ctx, stacked_id_embeds);
stacked_id_embeds = layer_norm->forward(ctx, stacked_id_embeds);
return stacked_id_embeds;
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* prompt_embeds,
struct ggml_tensor* id_embeds,
struct ggml_tensor* class_tokens_mask,
@ -286,25 +286,25 @@ public:
// # slice out the image token embeddings
ggml_set_name(class_tokens_mask_pos, "class_tokens_mask_pos");
ggml_set_name(prompt_embeds, "prompt_embeds");
struct ggml_tensor* image_token_embeds = ggml_get_rows(ctx, prompt_embeds, class_tokens_mask_pos);
struct ggml_tensor* image_token_embeds = ggml_get_rows(ctx->ggml_ctx, prompt_embeds, class_tokens_mask_pos);
ggml_set_name(image_token_embeds, "image_token_embeds");
valid_id_embeds = ggml_reshape_2d(ctx, valid_id_embeds, valid_id_embeds->ne[0],
valid_id_embeds = ggml_reshape_2d(ctx->ggml_ctx, valid_id_embeds, valid_id_embeds->ne[0],
ggml_nelements(valid_id_embeds) / valid_id_embeds->ne[0]);
struct ggml_tensor* stacked_id_embeds = fuse_fn(ctx, image_token_embeds, valid_id_embeds);
if (left && right) {
stacked_id_embeds = ggml_concat(ctx, left, stacked_id_embeds, 1);
stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right, 1);
stacked_id_embeds = ggml_concat(ctx->ggml_ctx, left, stacked_id_embeds, 1);
stacked_id_embeds = ggml_concat(ctx->ggml_ctx, stacked_id_embeds, right, 1);
} else if (left) {
stacked_id_embeds = ggml_concat(ctx, left, stacked_id_embeds, 1);
stacked_id_embeds = ggml_concat(ctx->ggml_ctx, left, stacked_id_embeds, 1);
} else if (right) {
stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right, 1);
stacked_id_embeds = ggml_concat(ctx->ggml_ctx, stacked_id_embeds, right, 1);
}
class_tokens_mask = ggml_cont(ctx, ggml_transpose(ctx, class_tokens_mask));
class_tokens_mask = ggml_repeat(ctx, class_tokens_mask, prompt_embeds);
prompt_embeds = ggml_mul(ctx, prompt_embeds, class_tokens_mask);
struct ggml_tensor* updated_prompt_embeds = ggml_add(ctx, prompt_embeds, stacked_id_embeds);
class_tokens_mask = ggml_cont(ctx->ggml_ctx, ggml_transpose(ctx->ggml_ctx, class_tokens_mask));
class_tokens_mask = ggml_repeat(ctx->ggml_ctx, class_tokens_mask, prompt_embeds);
prompt_embeds = ggml_mul(ctx->ggml_ctx, prompt_embeds, class_tokens_mask);
struct ggml_tensor* updated_prompt_embeds = ggml_add(ctx->ggml_ctx, prompt_embeds, stacked_id_embeds);
ggml_set_name(updated_prompt_embeds, "updated_prompt_embeds");
return updated_prompt_embeds;
}
@ -317,8 +317,7 @@ struct PhotoMakerIDEncoderBlock : public CLIPVisionModelProjection {
blocks["fuse_module"] = std::shared_ptr<GGMLBlock>(new FuseModule(2048));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* id_pixel_values,
struct ggml_tensor* prompt_embeds,
struct ggml_tensor* class_tokens_mask,
@ -331,15 +330,15 @@ struct PhotoMakerIDEncoderBlock : public CLIPVisionModelProjection {
auto visual_projection_2 = std::dynamic_pointer_cast<Linear>(blocks["visual_projection_2"]);
auto fuse_module = std::dynamic_pointer_cast<FuseModule>(blocks["fuse_module"]);
struct ggml_tensor* shared_id_embeds = vision_model->forward(ctx, backend, id_pixel_values); // [N, hidden_size]
struct ggml_tensor* shared_id_embeds = vision_model->forward(ctx, id_pixel_values); // [N, hidden_size]
struct ggml_tensor* id_embeds = visual_projection->forward(ctx, shared_id_embeds); // [N, proj_dim(768)]
struct ggml_tensor* id_embeds_2 = visual_projection_2->forward(ctx, shared_id_embeds); // [N, 1280]
id_embeds = ggml_cont(ctx, ggml_permute(ctx, id_embeds, 2, 0, 1, 3));
id_embeds_2 = ggml_cont(ctx, ggml_permute(ctx, id_embeds_2, 2, 0, 1, 3));
id_embeds = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, id_embeds, 2, 0, 1, 3));
id_embeds_2 = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, id_embeds_2, 2, 0, 1, 3));
id_embeds = ggml_concat(ctx, id_embeds, id_embeds_2, 2); // [batch_size, seq_length, 1, 2048] check whether concat at dim 2 is right
id_embeds = ggml_cont(ctx, ggml_permute(ctx, id_embeds, 1, 2, 0, 3));
id_embeds = ggml_concat(ctx->ggml_ctx, id_embeds, id_embeds_2, 2); // [batch_size, seq_length, 1, 2048] check whether concat at dim 2 is right
id_embeds = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, id_embeds, 1, 2, 0, 3));
struct ggml_tensor* updated_prompt_embeds = fuse_module->forward(ctx,
prompt_embeds,
@ -366,8 +365,7 @@ struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionMo
num_tokens));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* id_pixel_values,
struct ggml_tensor* prompt_embeds,
struct ggml_tensor* class_tokens_mask,
@ -381,7 +379,7 @@ struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionMo
auto qformer_perceiver = std::dynamic_pointer_cast<QFormerPerceiver>(blocks["qformer_perceiver"]);
// struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, id_pixel_values); // [N, hidden_size]
struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, backend, id_pixel_values, false); // [N, hidden_size]
struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, id_pixel_values, false); // [N, hidden_size]
id_embeds = qformer_perceiver->forward(ctx, id_embeds, last_hidden_state);
struct ggml_tensor* updated_prompt_embeds = fuse_module->forward(ctx,
@ -458,7 +456,7 @@ public:
zeros_right.clear();
zeros_right_16.clear();
ggml_context* ctx0 = compute_ctx;
auto runner_ctx = get_context();
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
@ -466,7 +464,7 @@ public:
int64_t seq_length = prompt_embeds->ne[1];
ggml_type type = GGML_TYPE_F32;
struct ggml_tensor* class_tokens_mask_d = ggml_new_tensor_1d(ctx0, type, class_tokens_mask.size());
struct ggml_tensor* class_tokens_mask_d = ggml_new_tensor_1d(runner_ctx.ggml_ctx, type, class_tokens_mask.size());
struct ggml_tensor* id_pixel_values_d = to_backend(id_pixel_values);
struct ggml_tensor* prompt_embeds_d = to_backend(prompt_embeds);
@ -488,16 +486,16 @@ public:
}
// printf("\n");
if (ctmpos[0] > 0) {
// left = ggml_new_tensor_3d(ctx0, type, hidden_size, 1, ctmpos[0]);
left = ggml_new_tensor_3d(ctx0, type, hidden_size, ctmpos[0], 1);
// left = ggml_new_tensor_3d(runner_ctx.ggml_ctx, type, hidden_size, 1, ctmpos[0]);
left = ggml_new_tensor_3d(runner_ctx.ggml_ctx, type, hidden_size, ctmpos[0], 1);
}
if (ctmpos[ctmpos.size() - 1] < seq_length - 1) {
// right = ggml_new_tensor_3d(ctx0, type,
// right = ggml_new_tensor_3d(runner_ctx.ggml_ctx, type,
// hidden_size, 1, seq_length - ctmpos[ctmpos.size() - 1] - 1);
right = ggml_new_tensor_3d(ctx0, type,
right = ggml_new_tensor_3d(runner_ctx.ggml_ctx, type,
hidden_size, seq_length - ctmpos[ctmpos.size() - 1] - 1, 1);
}
struct ggml_tensor* class_tokens_mask_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ctmpos.size());
struct ggml_tensor* class_tokens_mask_pos = ggml_new_tensor_1d(runner_ctx.ggml_ctx, GGML_TYPE_I32, ctmpos.size());
{
if (type == GGML_TYPE_F16)
@ -530,16 +528,14 @@ public:
}
struct ggml_tensor* updated_prompt_embeds = nullptr;
if (pm_version == PM_VERSION_1)
updated_prompt_embeds = id_encoder.forward(ctx0,
runtime_backend,
updated_prompt_embeds = id_encoder.forward(&runner_ctx,
id_pixel_values_d,
prompt_embeds_d,
class_tokens_mask_d,
class_tokens_mask_pos,
left, right);
else if (pm_version == PM_VERSION_2)
updated_prompt_embeds = id_encoder2.forward(ctx0,
runtime_backend,
updated_prompt_embeds = id_encoder2.forward(&runner_ctx,
id_pixel_values_d,
prompt_embeds_d,
class_tokens_mask_d,


@ -27,18 +27,18 @@ namespace Qwen {
blocks["linear_2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, out_dim, sample_proj_bias));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* sample,
struct ggml_tensor* condition = nullptr) {
if (condition != nullptr) {
auto cond_proj = std::dynamic_pointer_cast<Linear>(blocks["cond_proj"]);
sample = ggml_add(ctx, sample, cond_proj->forward(ctx, condition));
sample = ggml_add(ctx->ggml_ctx, sample, cond_proj->forward(ctx, condition));
}
auto linear_1 = std::dynamic_pointer_cast<Linear>(blocks["linear_1"]);
auto linear_2 = std::dynamic_pointer_cast<Linear>(blocks["linear_2"]);
sample = linear_1->forward(ctx, sample);
sample = ggml_silu_inplace(ctx, sample);
sample = ggml_silu_inplace(ctx->ggml_ctx, sample);
sample = linear_2->forward(ctx, sample);
return sample;
}
@ -50,13 +50,13 @@ namespace Qwen {
blocks["timestep_embedder"] = std::shared_ptr<GGMLBlock>(new TimestepEmbedding(256, embedding_dim));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* timesteps) {
// timesteps: [N,]
// return: [N, embedding_dim]
auto timestep_embedder = std::dynamic_pointer_cast<TimestepEmbedding>(blocks["timestep_embedder"]);
auto timesteps_proj = ggml_ext_timestep_embedding(ctx, timesteps, 256, 10000, 1.f);
auto timesteps_proj = ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, 256, 10000, 1.f);
auto timesteps_emb = timestep_embedder->forward(ctx, timesteps_proj);
return timesteps_emb;
}
@ -65,7 +65,6 @@ namespace Qwen {
struct QwenImageAttention : public GGMLBlock {
protected:
int64_t dim_head;
bool flash_attn;
public:
QwenImageAttention(int64_t query_dim,
@ -75,9 +74,8 @@ namespace Qwen {
int64_t out_context_dim = 0,
bool bias = true,
bool out_bias = true,
float eps = 1e-6,
bool flash_attn = false)
: dim_head(dim_head), flash_attn(flash_attn) {
float eps = 1e-6)
: dim_head(dim_head) {
int64_t inner_dim = out_dim > 0 ? out_dim : dim_head * num_heads;
out_dim = out_dim > 0 ? out_dim : query_dim;
out_context_dim = out_context_dim > 0 ? out_context_dim : query_dim;
@ -105,8 +103,7 @@ namespace Qwen {
blocks["to_add_out"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, out_context_dim, out_bias, false, false, scale));
}
std::pair<ggml_tensor*, ggml_tensor*> forward(struct ggml_context* ctx,
ggml_backend_t backend,
std::pair<ggml_tensor*, ggml_tensor*> forward(GGMLRunnerContext* ctx,
struct ggml_tensor* img,
struct ggml_tensor* txt,
struct ggml_tensor* pe,
@ -138,32 +135,32 @@ namespace Qwen {
auto img_q = to_q->forward(ctx, img);
int64_t num_heads = img_q->ne[0] / dim_head;
img_q = ggml_reshape_4d(ctx, img_q, dim_head, num_heads, n_img_token, N); // [N, n_img_token, n_head, d_head]
img_q = ggml_reshape_4d(ctx->ggml_ctx, img_q, dim_head, num_heads, n_img_token, N); // [N, n_img_token, n_head, d_head]
auto img_k = to_k->forward(ctx, img);
img_k = ggml_reshape_4d(ctx, img_k, dim_head, num_heads, n_img_token, N); // [N, n_img_token, n_head, d_head]
img_k = ggml_reshape_4d(ctx->ggml_ctx, img_k, dim_head, num_heads, n_img_token, N); // [N, n_img_token, n_head, d_head]
auto img_v = to_v->forward(ctx, img);
img_v = ggml_reshape_4d(ctx, img_v, dim_head, num_heads, n_img_token, N); // [N, n_img_token, n_head, d_head]
img_v = ggml_reshape_4d(ctx->ggml_ctx, img_v, dim_head, num_heads, n_img_token, N); // [N, n_img_token, n_head, d_head]
img_q = norm_q->forward(ctx, img_q);
img_k = norm_k->forward(ctx, img_k);
auto txt_q = add_q_proj->forward(ctx, txt);
txt_q = ggml_reshape_4d(ctx, txt_q, dim_head, num_heads, n_txt_token, N); // [N, n_txt_token, n_head, d_head]
txt_q = ggml_reshape_4d(ctx->ggml_ctx, txt_q, dim_head, num_heads, n_txt_token, N); // [N, n_txt_token, n_head, d_head]
auto txt_k = add_k_proj->forward(ctx, txt);
txt_k = ggml_reshape_4d(ctx, txt_k, dim_head, num_heads, n_txt_token, N); // [N, n_txt_token, n_head, d_head]
txt_k = ggml_reshape_4d(ctx->ggml_ctx, txt_k, dim_head, num_heads, n_txt_token, N); // [N, n_txt_token, n_head, d_head]
auto txt_v = add_v_proj->forward(ctx, txt);
txt_v = ggml_reshape_4d(ctx, txt_v, dim_head, num_heads, n_txt_token, N); // [N, n_txt_token, n_head, d_head]
txt_v = ggml_reshape_4d(ctx->ggml_ctx, txt_v, dim_head, num_heads, n_txt_token, N); // [N, n_txt_token, n_head, d_head]
txt_q = norm_added_q->forward(ctx, txt_q);
txt_k = norm_added_k->forward(ctx, txt_k);
auto q = ggml_concat(ctx, txt_q, img_q, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto q = ggml_concat(ctx->ggml_ctx, txt_q, img_q, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto k = ggml_concat(ctx->ggml_ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn, (1.0f / 128.f)); // [N, n_txt_token + n_img_token, n_head*d_head]
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
auto txt_attn_out = ggml_view_3d(ctx,
auto attn = Rope::attention(ctx, q, k, v, pe, mask, (1.0f / 128.f)); // [N, n_txt_token + n_img_token, n_head*d_head]
attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx,
attn,
attn->ne[0],
attn->ne[1],
@ -171,8 +168,8 @@ namespace Qwen {
attn->nb[1],
attn->nb[2],
0); // [n_txt_token, N, hidden_size]
txt_attn_out = ggml_cont(ctx, ggml_permute(ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size]
auto img_attn_out = ggml_view_3d(ctx,
txt_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size]
auto img_attn_out = ggml_view_3d(ctx->ggml_ctx,
attn,
attn->ne[0],
attn->ne[1],
@ -180,7 +177,7 @@ namespace Qwen {
attn->nb[1],
attn->nb[2],
attn->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size]
img_attn_out = ggml_cont(ctx, ggml_permute(ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
img_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
img_attn_out = to_out_0->forward(ctx, img_attn_out);
txt_attn_out = to_add_out->forward(ctx, txt_attn_out);
@ -194,8 +191,7 @@ namespace Qwen {
QwenImageTransformerBlock(int64_t dim,
int64_t num_attention_heads,
int64_t attention_head_dim,
float eps = 1e-6,
bool flash_attn = false) {
float eps = 1e-6) {
// img_mod.0 is nn.SiLU()
blocks["img_mod.1"] = std::shared_ptr<GGMLBlock>(new Linear(dim, 6 * dim, true));
@ -217,12 +213,10 @@ namespace Qwen {
0, // out_context-dim
true, // bias
true, // out_bias
eps,
flash_attn));
eps));
}
virtual std::pair<ggml_tensor*, ggml_tensor*> forward(struct ggml_context* ctx,
ggml_backend_t backend,
virtual std::pair<ggml_tensor*, ggml_tensor*> forward(GGMLRunnerContext* ctx,
struct ggml_tensor* img,
struct ggml_tensor* txt,
struct ggml_tensor* t_emb,
@ -244,40 +238,40 @@ namespace Qwen {
auto attn = std::dynamic_pointer_cast<QwenImageAttention>(blocks["attn"]);
auto img_mod_params = ggml_silu(ctx, t_emb);
auto img_mod_params = ggml_silu(ctx->ggml_ctx, t_emb);
img_mod_params = img_mod_1->forward(ctx, img_mod_params);
auto img_mod_param_vec = ggml_ext_chunk(ctx, img_mod_params, 6, 0);
auto img_mod_param_vec = ggml_ext_chunk(ctx->ggml_ctx, img_mod_params, 6, 0);
auto txt_mod_params = ggml_silu(ctx, t_emb);
auto txt_mod_params = ggml_silu(ctx->ggml_ctx, t_emb);
txt_mod_params = txt_mod_1->forward(ctx, txt_mod_params);
auto txt_mod_param_vec = ggml_ext_chunk(ctx, txt_mod_params, 6, 0);
auto txt_mod_param_vec = ggml_ext_chunk(ctx->ggml_ctx, txt_mod_params, 6, 0);
auto img_normed = img_norm1->forward(ctx, img);
auto img_modulated = Flux::modulate(ctx, img_normed, img_mod_param_vec[0], img_mod_param_vec[1]);
auto img_modulated = Flux::modulate(ctx->ggml_ctx, img_normed, img_mod_param_vec[0], img_mod_param_vec[1]);
auto img_gate1 = img_mod_param_vec[2];
auto txt_normed = txt_norm1->forward(ctx, txt);
auto txt_modulated = Flux::modulate(ctx, txt_normed, txt_mod_param_vec[0], txt_mod_param_vec[1]);
auto txt_modulated = Flux::modulate(ctx->ggml_ctx, txt_normed, txt_mod_param_vec[0], txt_mod_param_vec[1]);
auto txt_gate1 = txt_mod_param_vec[2];
auto [img_attn_output, txt_attn_output] = attn->forward(ctx, backend, img_modulated, txt_modulated, pe);
auto [img_attn_output, txt_attn_output] = attn->forward(ctx, img_modulated, txt_modulated, pe);
img = ggml_add(ctx, img, ggml_mul(ctx, img_attn_output, img_gate1));
txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_attn_output, txt_gate1));
img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_attn_output, img_gate1));
txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_attn_output, txt_gate1));
auto img_normed2 = img_norm2->forward(ctx, img);
auto img_modulated2 = Flux::modulate(ctx, img_normed2, img_mod_param_vec[3], img_mod_param_vec[4]);
auto img_modulated2 = Flux::modulate(ctx->ggml_ctx, img_normed2, img_mod_param_vec[3], img_mod_param_vec[4]);
auto img_gate2 = img_mod_param_vec[5];
auto txt_normed2 = txt_norm2->forward(ctx, txt);
auto txt_modulated2 = Flux::modulate(ctx, txt_normed2, txt_mod_param_vec[3], txt_mod_param_vec[4]);
auto txt_modulated2 = Flux::modulate(ctx->ggml_ctx, txt_normed2, txt_mod_param_vec[3], txt_mod_param_vec[4]);
auto txt_gate2 = txt_mod_param_vec[5];
auto img_mlp_out = img_mlp->forward(ctx, img_modulated2);
auto txt_mlp_out = txt_mlp->forward(ctx, txt_modulated2);
img = ggml_add(ctx, img, ggml_mul(ctx, img_mlp_out, img_gate2));
txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_mlp_out, txt_gate2));
img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_mlp_out, img_gate2));
txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_mlp_out, txt_gate2));
return {img, txt};
}
@ -294,7 +288,7 @@ namespace Qwen {
blocks["linear"] = std::shared_ptr<GGMLBlock>(new Linear(conditioning_embedding_dim, embedding_dim * 2, bias));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* c) {
// x: [N, n_token, hidden_size]
@ -304,13 +298,13 @@ namespace Qwen {
auto norm = std::dynamic_pointer_cast<LayerNorm>(blocks["norm"]);
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
auto emb = linear->forward(ctx, ggml_silu(ctx, c));
auto mods = ggml_ext_chunk(ctx, emb, 2, 0);
auto emb = linear->forward(ctx, ggml_silu(ctx->ggml_ctx, c));
auto mods = ggml_ext_chunk(ctx->ggml_ctx, emb, 2, 0);
auto scale = mods[0];
auto shift = mods[1];
x = norm->forward(ctx, x);
x = Flux::modulate(ctx, x, shift, scale);
x = Flux::modulate(ctx->ggml_ctx, x, shift, scale);
return x;
}
@ -327,7 +321,6 @@ namespace Qwen {
float theta = 10000;
std::vector<int> axes_dim = {16, 56, 56};
int64_t axes_dim_sum = 128;
bool flash_attn = false;
};
class QwenImageModel : public GGMLBlock {
@ -349,8 +342,7 @@ namespace Qwen {
auto block = std::shared_ptr<GGMLBlock>(new QwenImageTransformerBlock(inner_dim,
params.num_attention_heads,
params.attention_head_dim,
1e-6f,
params.flash_attn));
1e-6f));
blocks["transformer_blocks." + std::to_string(i)] = block;
}
@ -421,8 +413,7 @@ namespace Qwen {
return x;
}
struct ggml_tensor* forward_orig(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward_orig(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* timestep,
struct ggml_tensor* context,
@ -442,7 +433,7 @@ namespace Qwen {
for (int i = 0; i < params.num_layers; i++) {
auto block = std::dynamic_pointer_cast<QwenImageTransformerBlock>(blocks["transformer_blocks." + std::to_string(i)]);
auto result = block->forward(ctx, backend, img, txt, t_emb, pe);
auto result = block->forward(ctx, img, txt, t_emb, pe);
img = result.first;
txt = result.second;
}
@ -453,8 +444,7 @@ namespace Qwen {
return img;
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* timestep,
struct ggml_tensor* context,
@ -472,32 +462,32 @@ namespace Qwen {
int64_t C = x->ne[2];
int64_t N = x->ne[3];
auto img = process_img(ctx, x);
auto img = process_img(ctx->ggml_ctx, x);
uint64_t img_tokens = img->ne[1];
if (ref_latents.size() > 0) {
for (ggml_tensor* ref : ref_latents) {
ref = process_img(ctx, ref);
img = ggml_concat(ctx, img, ref, 1);
ref = process_img(ctx->ggml_ctx, ref);
img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
}
}
int64_t h_len = ((H + (params.patch_size / 2)) / params.patch_size);
int64_t w_len = ((W + (params.patch_size / 2)) / params.patch_size);
auto out = forward_orig(ctx, backend, img, timestep, context, pe); // [N, h_len*w_len, ph*pw*C]
auto out = forward_orig(ctx, img, timestep, context, pe); // [N, h_len*w_len, ph*pw*C]
if (out->ne[1] > img_tokens) {
out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size]
out = ggml_view_3d(ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size]
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size]
out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size]
}
out = unpatchify(ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w]
out = unpatchify(ctx->ggml_ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w]
// slice
out = ggml_ext_slice(ctx, out, 1, 0, H); // [N, C, H, W + pad_w]
out = ggml_ext_slice(ctx, out, 0, 0, W); // [N, C, H, W]
out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H); // [N, C, H, W + pad_w]
out = ggml_ext_slice(ctx->ggml_ctx, out, 0, 0, W); // [N, C, H, W]
return out;
}
@ -514,10 +504,8 @@ namespace Qwen {
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "",
SDVersion version = VERSION_QWEN_IMAGE,
bool flash_attn = false)
SDVersion version = VERSION_QWEN_IMAGE)
: GGMLRunner(backend, offload_params_to_cpu) {
qwen_image_params.flash_attn = flash_attn;
qwen_image_params.num_layers = 0;
for (auto pair : tensor_types) {
std::string tensor_name = pair.first;
@ -582,8 +570,9 @@ namespace Qwen {
// pe->data = nullptr;
set_backend_tensor_data(pe, pe_vec.data());
struct ggml_tensor* out = qwen_image.forward(compute_ctx,
runtime_backend,
auto runner_ctx = get_context();
struct ggml_tensor* out = qwen_image.forward(&runner_ctx,
x,
timesteps,
context,
@ -672,8 +661,7 @@ namespace Qwen {
false,
tensor_types,
"model.diffusion_model",
VERSION_QWEN_IMAGE,
true);
VERSION_QWEN_IMAGE);
qwen_image->alloc_params_buffer();
std::map<std::string, ggml_tensor*> tensors;


@ -349,15 +349,15 @@ namespace Qwen {
blocks["down_proj"] = std::shared_ptr<GGMLBlock>(new Linear(intermediate_size, hidden_size, bias));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
// x: [N, n_token, hidden_size]
auto gate_proj = std::dynamic_pointer_cast<Linear>(blocks["gate_proj"]);
auto up_proj = std::dynamic_pointer_cast<Linear>(blocks["up_proj"]);
auto down_proj = std::dynamic_pointer_cast<Linear>(blocks["down_proj"]);
auto h = gate_proj->forward(ctx, x);
h = ggml_silu_inplace(ctx, h);
h = ggml_mul_inplace(ctx, h, up_proj->forward(ctx, x));
h = ggml_silu_inplace(ctx->ggml_ctx, h);
h = ggml_mul_inplace(ctx->ggml_ctx, h, up_proj->forward(ctx, x));
h = down_proj->forward(ctx, h);
return h;
}
@ -409,10 +409,10 @@ namespace Qwen {
}
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
// x: [N*grid_t*grid_h*grid_w, in_channels, temporal_patch_size*patch_size*patch_size]
// return: [N*grid_t*grid_h*grid_w, embed_dim]
x = ggml_reshape_4d(ctx,
x = ggml_reshape_4d(ctx->ggml_ctx,
x,
patch_size,
patch_size,
@ -423,22 +423,22 @@ namespace Qwen {
auto proj_0 = std::dynamic_pointer_cast<Conv2d>(blocks["proj.0"]);
auto proj_1 = std::dynamic_pointer_cast<Conv2d>(blocks["proj.1"]);
auto x0 = ggml_ext_slice(ctx, x, 2, 0, 1);
x0 = ggml_reshape_4d(ctx, x0, x0->ne[0], x0->ne[1], in_channels, x0->ne[3] / in_channels);
auto x0 = ggml_ext_slice(ctx->ggml_ctx, x, 2, 0, 1);
x0 = ggml_reshape_4d(ctx->ggml_ctx, x0, x0->ne[0], x0->ne[1], in_channels, x0->ne[3] / in_channels);
x0 = proj_0->forward(ctx, x0);
auto x1 = ggml_ext_slice(ctx, x, 2, 1, 2);
x1 = ggml_reshape_4d(ctx, x1, x1->ne[0], x1->ne[1], in_channels, x1->ne[3] / in_channels);
auto x1 = ggml_ext_slice(ctx->ggml_ctx, x, 2, 1, 2);
x1 = ggml_reshape_4d(ctx->ggml_ctx, x1, x1->ne[0], x1->ne[1], in_channels, x1->ne[3] / in_channels);
x1 = proj_1->forward(ctx, x1);
x = ggml_add(ctx, x0, x1);
x = ggml_add(ctx->ggml_ctx, x0, x1);
} else {
auto proj = std::dynamic_pointer_cast<Conv3d>(blocks["proj"]);
x = proj->forward(ctx, x);
}
x = ggml_reshape_2d(ctx, x, embed_dim, ggml_nelements(x) / embed_dim);
x = ggml_reshape_2d(ctx->ggml_ctx, x, embed_dim, ggml_nelements(x) / embed_dim);
return x;
}
};
@ -458,15 +458,15 @@ namespace Qwen {
blocks["mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, dim));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
auto ln_q = std::dynamic_pointer_cast<RMSNorm>(blocks["ln_q"]);
auto mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["mlp.0"]);
auto mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["mlp.2"]);
x = ln_q->forward(ctx, x);
x = ggml_reshape_2d(ctx, x, hidden_size, ggml_nelements(x) / hidden_size);
x = ggml_reshape_2d(ctx->ggml_ctx, x, hidden_size, ggml_nelements(x) / hidden_size);
x = mlp_0->forward(ctx, x);
x = ggml_gelu(ctx, x);
x = ggml_gelu(ctx->ggml_ctx, x);
x = mlp_2->forward(ctx, x);
return x;
}
@ -495,8 +495,7 @@ namespace Qwen {
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* pe,
struct ggml_tensor* mask = nullptr) {
@ -519,14 +518,14 @@ namespace Qwen {
} else {
auto qkv_proj = std::dynamic_pointer_cast<Linear>(blocks["qkv"]);
auto qkv = qkv_proj->forward(ctx, x);
qkv_vec = split_qkv(ctx, qkv);
qkv_vec = split_qkv(ctx->ggml_ctx, qkv);
}
auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head]
auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head]
auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head]
auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head]
auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head]
auto v = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head]
x = Rope::attention(ctx, backend, q, k, v, pe, mask, false, 1.f, false); // [N, n_token, hidden_size]
x = Rope::attention(ctx, q, k, v, pe, mask, 1.f, false); // [N, n_token, hidden_size]
x = proj->forward(ctx, x); // [N, n_token, hidden_size]
return x;
@ -546,8 +545,7 @@ namespace Qwen {
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new RMSNorm(hidden_size, eps));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* pe,
struct ggml_tensor* mask = nullptr) {
@ -559,13 +557,13 @@ namespace Qwen {
auto residual = x;
x = norm1->forward(ctx, x);
x = attn->forward(ctx, backend, x, pe, mask);
x = ggml_add_inplace(ctx, x, residual);
x = attn->forward(ctx, x, pe, mask);
x = ggml_add_inplace(ctx->ggml_ctx, x, residual);
residual = x;
x = norm2->forward(ctx, x);
x = mlp->forward(ctx, x);
x = ggml_add_inplace(ctx, x, residual);
x = ggml_add_inplace(ctx->ggml_ctx, x, residual);
return x;
}
@ -607,8 +605,7 @@ namespace Qwen {
blocks["merger"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLPatchMerger(out_hidden_size, hidden_size, spatial_merge_size));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* pixel_values,
struct ggml_tensor* pe,
struct ggml_tensor* window_index,
@ -623,9 +620,9 @@ namespace Qwen {
auto x = patch_embed->forward(ctx, pixel_values);
x = ggml_reshape_4d(ctx, x, x->ne[0] * spatial_merge_size * spatial_merge_size, x->ne[1] / spatial_merge_size / spatial_merge_size, x->ne[2], x->ne[3]);
x = ggml_get_rows(ctx, x, window_index);
x = ggml_reshape_4d(ctx, x, x->ne[0] / spatial_merge_size / spatial_merge_size, x->ne[1] * spatial_merge_size * spatial_merge_size, x->ne[2], x->ne[3]);
x = ggml_reshape_4d(ctx->ggml_ctx, x, x->ne[0] * spatial_merge_size * spatial_merge_size, x->ne[1] / spatial_merge_size / spatial_merge_size, x->ne[2], x->ne[3]);
x = ggml_get_rows(ctx->ggml_ctx, x, window_index);
x = ggml_reshape_4d(ctx->ggml_ctx, x, x->ne[0] / spatial_merge_size / spatial_merge_size, x->ne[1] * spatial_merge_size * spatial_merge_size, x->ne[2], x->ne[3]);
for (int i = 0; i < num_layers; i++) {
auto block = std::dynamic_pointer_cast<Qwen2_5_VLVisionBlock>(blocks["blocks." + std::to_string(i)]);
@ -634,12 +631,12 @@ namespace Qwen {
if (fullatt_block_indexes.find(i) != fullatt_block_indexes.end()) {
mask = nullptr;
}
x = block->forward(ctx, backend, x, pe, mask);
x = block->forward(ctx, x, pe, mask);
}
x = merger->forward(ctx, x);
x = ggml_get_rows(ctx, x, window_inverse_index);
x = ggml_get_rows(ctx->ggml_ctx, x, window_inverse_index);
return x;
}
@ -664,8 +661,7 @@ namespace Qwen {
blocks["o_proj"] = std::shared_ptr<GGMLBlock>(new Linear(num_heads * head_dim, hidden_size, false));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* input_pos) {
// x: [N, n_token, hidden_size]
@ -680,21 +676,21 @@ namespace Qwen {
auto k = k_proj->forward(ctx, x); // [N, n_token, num_kv_heads*head_dim]
auto v = v_proj->forward(ctx, x); // [N, n_token, num_kv_heads*head_dim]
q = ggml_reshape_4d(ctx, q, head_dim, num_heads, n_token, N); // [N, n_token, num_heads, head_dim]
k = ggml_reshape_4d(ctx, k, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim]
v = ggml_reshape_4d(ctx, v, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim]
q = ggml_reshape_4d(ctx->ggml_ctx, q, head_dim, num_heads, n_token, N); // [N, n_token, num_heads, head_dim]
k = ggml_reshape_4d(ctx->ggml_ctx, k, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim]
v = ggml_reshape_4d(ctx->ggml_ctx, v, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim]
int sections[4] = {16, 24, 24, 0};
q = ggml_rope_multi(ctx, q, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
k = ggml_rope_multi(ctx, k, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
q = ggml_rope_multi(ctx->ggml_ctx, q, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
k = ggml_rope_multi(ctx->ggml_ctx, k, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
q = ggml_cont(ctx, ggml_ext_torch_permute(ctx, q, 0, 2, 1, 3)); // [N, num_heads, n_token, head_dim]
q = ggml_reshape_3d(ctx, q, q->ne[0], q->ne[1], q->ne[2] * q->ne[3]); // [N*num_heads, n_token, head_dim]
q = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, q, 0, 2, 1, 3)); // [N, num_heads, n_token, head_dim]
q = ggml_reshape_3d(ctx->ggml_ctx, q, q->ne[0], q->ne[1], q->ne[2] * q->ne[3]); // [N*num_heads, n_token, head_dim]
k = ggml_cont(ctx, ggml_ext_torch_permute(ctx, k, 0, 2, 1, 3)); // [N, num_kv_heads, n_token, head_dim]
k = ggml_reshape_3d(ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); // [N*num_kv_heads, n_token, head_dim]
k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 2, 1, 3)); // [N, num_kv_heads, n_token, head_dim]
k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); // [N*num_kv_heads, n_token, head_dim]
x = ggml_ext_attention_ext(ctx, backend, q, k, v, num_heads, nullptr, true, true, false); // [N, n_token, hidden_size]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, true, true, false); // [N, n_token, hidden_size]
x = out_proj->forward(ctx, x); // [N, n_token, hidden_size]
return x;
@ -714,8 +710,7 @@ namespace Qwen {
blocks["post_attention_layernorm"] = std::shared_ptr<GGMLBlock>(new RMSNorm(hidden_size, eps));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* input_pos) {
// x: [N, n_token, hidden_size]
@ -726,13 +721,13 @@ namespace Qwen {
auto residual = x;
x = input_layernorm->forward(ctx, x);
x = self_attn->forward(ctx, backend, x, input_pos);
x = ggml_add_inplace(ctx, x, residual);
x = self_attn->forward(ctx, x, input_pos);
x = ggml_add_inplace(ctx->ggml_ctx, x, residual);
residual = x;
x = post_attention_layernorm->forward(ctx, x);
x = mlp->forward(ctx, x);
x = ggml_add_inplace(ctx, x, residual);
x = ggml_add_inplace(ctx->ggml_ctx, x, residual);
return x;
}
@ -761,8 +756,7 @@ namespace Qwen {
blocks["norm"] = std::shared_ptr<GGMLBlock>(new RMSNorm(hidden_size, eps));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* input_ids,
struct ggml_tensor* input_pos,
std::vector<std::pair<int, ggml_tensor*>> image_embeds) {
@ -777,7 +771,7 @@ namespace Qwen {
if (image_embeds.size() > 0) {
GGML_ASSERT(x->ne[2] == 1); // N == 1
auto raw_x = ggml_cast(ctx, x, image_embeds[0].second->type);
auto raw_x = ggml_cast(ctx->ggml_ctx, x, image_embeds[0].second->type);
int64_t txt_token_start = 0;
int64_t txt_token_end = 0;
@ -791,23 +785,23 @@ namespace Qwen {
}
txt_token_end = image_embeds[i].first;
auto txt_embed = ggml_ext_slice(ctx, raw_x, 1, txt_token_start, txt_token_end);
auto txt_embed = ggml_ext_slice(ctx->ggml_ctx, raw_x, 1, txt_token_start, txt_token_end);
if (input_embed == nullptr) {
input_embed = txt_embed;
} else {
input_embed = ggml_concat(ctx, input_embed, txt_embed, 1);
input_embed = ggml_concat(ctx->ggml_ctx, input_embed, txt_embed, 1);
}
auto image_embed = image_embeds[i].second;
input_embed = ggml_concat(ctx, input_embed, image_embed, 1);
input_embed = ggml_concat(ctx->ggml_ctx, input_embed, image_embed, 1);
}
txt_token_start = image_embeds[image_embeds.size() - 1].first + image_embeds[image_embeds.size() - 1].second->ne[1];
txt_token_end = raw_x->ne[1];
auto final_txt_embed = ggml_ext_slice(ctx, raw_x, 1, txt_token_start, txt_token_end);
auto final_txt_embed = ggml_ext_slice(ctx->ggml_ctx, raw_x, 1, txt_token_start, txt_token_end);
input_embed = ggml_concat(ctx, input_embed, final_txt_embed, 1);
input_embed = ggml_concat(ctx->ggml_ctx, input_embed, final_txt_embed, 1);
GGML_ASSERT(raw_x->ne[1] == input_embed->ne[1]);
x = input_embed;
@ -816,7 +810,7 @@ namespace Qwen {
for (int i = 0; i < num_layers; i++) {
auto block = std::dynamic_pointer_cast<Qwen2_5_VLBlock>(blocks["layers." + std::to_string(i)]);
x = block->forward(ctx, backend, x, input_pos);
x = block->forward(ctx, x, input_pos);
}
x = norm->forward(ctx, x);
@ -880,20 +874,18 @@ namespace Qwen {
}
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* input_ids,
struct ggml_tensor* input_pos,
std::vector<std::pair<int, ggml_tensor*>> image_embeds) {
// input_ids: [N, n_token]
auto model = std::dynamic_pointer_cast<Qwen2_5_VLTextModel>(blocks["model"]);
auto x = model->forward(ctx, backend, input_ids, input_pos, image_embeds);
auto x = model->forward(ctx, input_ids, input_pos, image_embeds);
return x;
}
struct ggml_tensor* vision_forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* vision_forward(GGMLRunnerContext* ctx,
struct ggml_tensor* pixel_values,
struct ggml_tensor* pe,
struct ggml_tensor* window_index,
@ -901,7 +893,7 @@ namespace Qwen {
struct ggml_tensor* window_mask) {
GGML_ASSERT(enable_vision);
auto vision_model = std::dynamic_pointer_cast<Qwen2_5_VLVisionModel>(blocks["visual"]);
return vision_model->forward(ctx, backend, pixel_values, pe, window_index, window_inverse_index, window_mask);
return vision_model->forward(ctx, pixel_values, pe, window_index, window_inverse_index, window_mask);
}
};
@ -959,23 +951,21 @@ namespace Qwen {
model.get_param_tensors(tensors, prefix);
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* input_ids,
struct ggml_tensor* input_pos,
std::vector<std::pair<int, ggml_tensor*>> image_embeds) {
auto hidden_states = model.forward(ctx, backend, input_ids, input_pos, image_embeds); // [N, n_token, hidden_size]
auto hidden_states = model.forward(ctx, input_ids, input_pos, image_embeds); // [N, n_token, hidden_size]
return hidden_states;
}
struct ggml_tensor* vision_forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* vision_forward(GGMLRunnerContext* ctx,
struct ggml_tensor* pixel_values,
struct ggml_tensor* input_pos,
struct ggml_tensor* window_index,
struct ggml_tensor* window_inverse_index,
struct ggml_tensor* window_mask) {
auto hidden_states = model.vision_forward(ctx, backend, pixel_values, input_pos, window_index, window_inverse_index, window_mask);
auto hidden_states = model.vision_forward(ctx, pixel_values, input_pos, window_index, window_inverse_index, window_mask);
return hidden_states;
}
@ -1002,7 +992,9 @@ namespace Qwen {
n_tokens * 4);
set_backend_tensor_data(input_pos, input_pos_vec.data());
struct ggml_tensor* hidden_states = forward(compute_ctx, runtime_backend, input_ids, input_pos, image_embeds);
auto runner_ctx = get_context();
struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, input_pos, image_embeds);
ggml_build_forward_expand(gf, hidden_states);
@ -1167,8 +1159,8 @@ namespace Qwen {
// pe->data = nullptr;
set_backend_tensor_data(pe, pe_vec.data());
struct ggml_tensor* hidden_states = vision_forward(compute_ctx,
runtime_backend,
auto runner_ctx = get_context();
struct ggml_tensor* hidden_states = vision_forward(&runner_ctx,
pixel_values,
pe,
window_index,


@ -386,23 +386,21 @@ namespace Rope {
return x_out;
}
__STATIC_INLINE__ struct ggml_tensor* attention(struct ggml_context* ctx,
ggml_backend_t backend,
__STATIC_INLINE__ struct ggml_tensor* attention(GGMLRunnerContext* ctx,
struct ggml_tensor* q,
struct ggml_tensor* k,
struct ggml_tensor* v,
struct ggml_tensor* pe,
struct ggml_tensor* mask,
bool flash_attn,
float kv_scale = 1.0f,
bool rope_interleaved = true) {
// q,k,v: [N, L, n_head, d_head]
// pe: [L, d_head/2, 2, 2]
// return: [N, L, n_head*d_head]
q = apply_rope(ctx, q, pe, rope_interleaved); // [N*n_head, L, d_head]
k = apply_rope(ctx, k, pe, rope_interleaved); // [N*n_head, L, d_head]
q = apply_rope(ctx->ggml_ctx, q, pe, rope_interleaved); // [N*n_head, L, d_head]
k = apply_rope(ctx->ggml_ctx, k, pe, rope_interleaved); // [N*n_head, L, d_head]
auto x = ggml_ext_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn, kv_scale); // [N, L, n_head*d_head]
auto x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, v->ne[1], mask, false, true, ctx->flash_attn_enabled, kv_scale); // [N, L, n_head*d_head]
return x;
}
}; // namespace Rope
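With the flag now read from the context, Rope::attention call sites drop the backend and flash_attn arguments, and the trailing kv_scale / rope_interleaved parameters shift one position to the left. A hedged illustration with placeholder tensor names (q, k, v, pe, mask stand for whatever the surrounding block already computed), not a line taken from this commit:

// before: Rope::attention(gctx, backend, q, k, v, pe, mask, /*flash_attn*/ false, /*kv_scale*/ 1.f, /*rope_interleaved*/ false);
// after:  the flash-attention decision comes from ctx->flash_attn_enabled instead.
auto x = Rope::attention(ctx, q, k, v, pe, mask, /*kv_scale*/ 1.f, /*rope_interleaved*/ false);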


@ -341,16 +341,12 @@ public:
LOG_INFO("CLIP: Using CPU backend");
clip_backend = ggml_backend_cpu_init();
}
if (sd_ctx_params->diffusion_flash_attn) {
LOG_INFO("Using flash attention in the diffusion model");
}
if (sd_version_is_sd3(version)) {
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend,
offload_params_to_cpu,
model_loader.tensor_storages_types);
diffusion_model = std::make_shared<MMDiTModel>(backend,
offload_params_to_cpu,
sd_ctx_params->diffusion_flash_attn,
model_loader.tensor_storages_types);
} else if (sd_version_is_flux(version)) {
bool is_chroma = false;
@ -384,7 +380,6 @@ public:
offload_params_to_cpu,
model_loader.tensor_storages_types,
version,
sd_ctx_params->diffusion_flash_attn,
sd_ctx_params->chroma_use_dit_mask);
} else if (sd_version_is_wan(version)) {
cond_stage_model = std::make_shared<T5CLIPEmbedder>(clip_backend,
@ -397,15 +392,13 @@ public:
offload_params_to_cpu,
model_loader.tensor_storages_types,
"model.diffusion_model",
version,
sd_ctx_params->diffusion_flash_attn);
version);
if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) {
high_noise_diffusion_model = std::make_shared<WanModel>(backend,
offload_params_to_cpu,
model_loader.tensor_storages_types,
"model.high_noise_diffusion_model",
version,
sd_ctx_params->diffusion_flash_attn);
version);
}
if (diffusion_model->get_desc() == "Wan2.1-I2V-14B" || diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend,
@ -428,8 +421,7 @@ public:
offload_params_to_cpu,
model_loader.tensor_storages_types,
"model.diffusion_model",
version,
sd_ctx_params->diffusion_flash_attn);
version);
} else { // SD1.x SD2.x SDXL
if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
@ -448,14 +440,18 @@ public:
diffusion_model = std::make_shared<UNetModel>(backend,
offload_params_to_cpu,
model_loader.tensor_storages_types,
version,
sd_ctx_params->diffusion_flash_attn);
version);
if (sd_ctx_params->diffusion_conv_direct) {
LOG_INFO("Using Conv2d direct in the diffusion model");
std::dynamic_pointer_cast<UNetModel>(diffusion_model)->unet.enable_conv2d_direct();
std::dynamic_pointer_cast<UNetModel>(diffusion_model)->unet.set_conv2d_direct_enabled(true);
}
}
if (sd_ctx_params->diffusion_flash_attn) {
LOG_INFO("Using flash attention in the diffusion model");
diffusion_model->set_flash_attn_enabled(true);
}
cond_stage_model->alloc_params_buffer();
cond_stage_model->get_param_tensors(tensors);
@ -500,7 +496,7 @@ public:
version);
if (sd_ctx_params->vae_conv_direct) {
LOG_INFO("Using Conv2d direct in the vae model");
first_stage_model->enable_conv2d_direct();
first_stage_model->set_conv2d_direct_enabled(true);
}
if (version == VERSION_SDXL &&
(strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale)) {
@ -522,7 +518,7 @@ public:
version);
if (sd_ctx_params->vae_conv_direct) {
LOG_INFO("Using Conv2d direct in the tae model");
tae_first_stage->enable_conv2d_direct();
tae_first_stage->set_conv2d_direct_enabled(true);
}
}
// first_stage_model->get_param_tensors(tensors, "first_stage_model.");
@ -541,7 +537,7 @@ public:
version);
if (sd_ctx_params->diffusion_conv_direct) {
LOG_INFO("Using Conv2d direct in the control net");
control_net->enable_conv2d_direct();
control_net->set_conv2d_direct_enabled(true);
}
}
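Note: condensing the stable-diffusion.cpp hunks above, the per-model flash-attention constructor flag is gone and both options are now toggled on the runner after construction. A condensed view of the resulting call pattern, stitched together from the hunks above:

    diffusion_model = std::make_shared<UNetModel>(backend,
                                                  offload_params_to_cpu,
                                                  model_loader.tensor_storages_types,
                                                  version);
    if (sd_ctx_params->diffusion_conv_direct) {
        LOG_INFO("Using Conv2d direct in the diffusion model");
        std::dynamic_pointer_cast<UNetModel>(diffusion_model)->unet.set_conv2d_direct_enabled(true);
    }
    if (sd_ctx_params->diffusion_flash_attn) {
        LOG_INFO("Using flash attention in the diffusion model");
        diffusion_model->set_flash_attn_enabled(true);  // consumed later through GGMLRunnerContext at graph build time
    }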

65
t5.hpp

@ -472,10 +472,10 @@ public:
: hidden_size(hidden_size),
eps(eps) {}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
struct ggml_tensor* w = params["weight"];
x = ggml_rms_norm(ctx, x, eps);
x = ggml_mul(ctx, x, w);
x = ggml_rms_norm(ctx->ggml_ctx, x, eps);
x = ggml_mul(ctx->ggml_ctx, x, w);
return x;
}
};
@ -487,13 +487,13 @@ public:
blocks["wo"] = std::shared_ptr<GGMLBlock>(new Linear(ff_dim, model_dim, false));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
// x: [N, n_token, model_dim]
auto wi = std::dynamic_pointer_cast<Linear>(blocks["wi"]);
auto wo = std::dynamic_pointer_cast<Linear>(blocks["wo"]);
x = wi->forward(ctx, x);
x = ggml_relu_inplace(ctx, x);
x = ggml_relu_inplace(ctx->ggml_ctx, x);
x = wo->forward(ctx, x);
return x;
}
@ -509,15 +509,15 @@ public:
blocks["wo"] = std::shared_ptr<GGMLBlock>(new Linear(ff_dim, model_dim, false, false, false, scale));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
// x: [N, n_token, model_dim]
auto wi_0 = std::dynamic_pointer_cast<Linear>(blocks["wi_0"]);
auto wi_1 = std::dynamic_pointer_cast<Linear>(blocks["wi_1"]);
auto wo = std::dynamic_pointer_cast<Linear>(blocks["wo"]);
auto hidden_gelu = ggml_gelu_inplace(ctx, wi_0->forward(ctx, x));
auto hidden_gelu = ggml_gelu_inplace(ctx->ggml_ctx, wi_0->forward(ctx, x));
auto hidden_linear = wi_1->forward(ctx, x);
x = ggml_mul_inplace(ctx, hidden_gelu, hidden_linear);
x = ggml_mul_inplace(ctx->ggml_ctx, hidden_gelu, hidden_linear);
x = wo->forward(ctx, x);
return x;
}
@ -530,14 +530,14 @@ public:
blocks["layer_norm"] = std::shared_ptr<GGMLBlock>(new T5LayerNorm(model_dim));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
// x: [N, n_token, model_dim]
auto DenseReluDense = std::dynamic_pointer_cast<T5DenseGatedActDense>(blocks["DenseReluDense"]);
auto layer_norm = std::dynamic_pointer_cast<T5LayerNorm>(blocks["layer_norm"]);
auto forwarded_states = layer_norm->forward(ctx, x);
forwarded_states = DenseReluDense->forward(ctx, forwarded_states);
x = ggml_add_inplace(ctx, forwarded_states, x);
x = ggml_add_inplace(ctx->ggml_ctx, forwarded_states, x);
return x;
}
};
@ -569,18 +569,17 @@ public:
}
}
struct ggml_tensor* compute_bias(struct ggml_context* ctx,
struct ggml_tensor* compute_bias(GGMLRunnerContext* ctx,
struct ggml_tensor* relative_position_bucket) {
auto relative_attention_bias = std::dynamic_pointer_cast<Embedding>(blocks["relative_attention_bias"]);
auto values = relative_attention_bias->forward(ctx, relative_position_bucket); // shape (query_length, key_length, num_heads)
values = ggml_cont(ctx, ggml_permute(ctx, values, 2, 0, 1, 3)); // shape (1, num_heads, query_length, key_length)
values = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, values, 2, 0, 1, 3)); // shape (1, num_heads, query_length, key_length)
return values;
}
// x: [N, n_token, model_dim]
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
ggml_backend_t backend,
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* past_bias = nullptr,
struct ggml_tensor* mask = nullptr,
@ -602,16 +601,16 @@ public:
}
if (past_bias != nullptr) {
if (mask != nullptr) {
mask = ggml_repeat(ctx, mask, past_bias);
mask = ggml_add(ctx, mask, past_bias);
mask = ggml_repeat(ctx->ggml_ctx, mask, past_bias);
mask = ggml_add(ctx->ggml_ctx, mask, past_bias);
} else {
mask = past_bias;
}
}
k = ggml_scale_inplace(ctx, k, sqrt(d_head));
k = ggml_scale_inplace(ctx->ggml_ctx, k, sqrt(d_head));
x = ggml_ext_attention_ext(ctx, backend, q, k, v, num_heads, mask); // [N, n_token, d_head * n_head]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, mask); // [N, n_token, d_head * n_head]
x = out_proj->forward(ctx, x); // [N, n_token, model_dim]
return {x, past_bias};
@ -629,8 +628,7 @@ public:
blocks["layer_norm"] = std::shared_ptr<GGMLBlock>(new T5LayerNorm(model_dim));
}
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
ggml_backend_t backend,
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* past_bias = nullptr,
struct ggml_tensor* mask = nullptr,
@ -640,11 +638,11 @@ public:
auto layer_norm = std::dynamic_pointer_cast<T5LayerNorm>(blocks["layer_norm"]);
auto normed_hidden_state = layer_norm->forward(ctx, x);
auto ret = SelfAttention->forward(ctx, backend, normed_hidden_state, past_bias, mask, relative_position_bucket);
auto ret = SelfAttention->forward(ctx, normed_hidden_state, past_bias, mask, relative_position_bucket);
auto output = ret.first;
past_bias = ret.second;
x = ggml_add_inplace(ctx, output, x);
x = ggml_add_inplace(ctx->ggml_ctx, output, x);
return {x, past_bias};
}
};
@ -656,8 +654,7 @@ public:
blocks["layer.1"] = std::shared_ptr<GGMLBlock>(new T5LayerFF(model_dim, ff_dim));
}
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
ggml_backend_t backend,
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* past_bias = nullptr,
struct ggml_tensor* mask = nullptr,
@ -666,7 +663,7 @@ public:
auto layer_0 = std::dynamic_pointer_cast<T5LayerSelfAttention>(blocks["layer.0"]);
auto layer_1 = std::dynamic_pointer_cast<T5LayerFF>(blocks["layer.1"]);
auto ret = layer_0->forward(ctx, backend, x, past_bias, mask, relative_position_bucket);
auto ret = layer_0->forward(ctx, x, past_bias, mask, relative_position_bucket);
x = ret.first;
past_bias = ret.second;
x = layer_1->forward(ctx, x);
@ -692,8 +689,7 @@ public:
blocks["final_layer_norm"] = std::shared_ptr<GGMLBlock>(new T5LayerNorm(model_dim));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* past_bias = nullptr,
struct ggml_tensor* attention_mask = nullptr,
@ -702,7 +698,7 @@ public:
for (int i = 0; i < num_layers; i++) {
auto block = std::dynamic_pointer_cast<T5Block>(blocks["block." + std::to_string(i)]);
auto ret = block->forward(ctx, backend, x, past_bias, attention_mask, relative_position_bucket);
auto ret = block->forward(ctx, x, past_bias, attention_mask, relative_position_bucket);
x = ret.first;
past_bias = ret.second;
}
@ -740,8 +736,7 @@ public:
params.model_dim));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* input_ids,
struct ggml_tensor* past_bias = nullptr,
struct ggml_tensor* attention_mask = nullptr,
@ -752,7 +747,7 @@ public:
auto encoder = std::dynamic_pointer_cast<T5Stack>(blocks["encoder"]);
auto x = shared->forward(ctx, input_ids);
x = encoder->forward(ctx, backend, x, past_bias, attention_mask, relative_position_bucket);
x = encoder->forward(ctx, x, past_bias, attention_mask, relative_position_bucket);
return x;
}
};
@ -784,15 +779,14 @@ struct T5Runner : public GGMLRunner {
model.get_param_tensors(tensors, prefix);
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* input_ids,
struct ggml_tensor* relative_position_bucket,
struct ggml_tensor* attention_mask = nullptr) {
size_t N = input_ids->ne[1];
size_t n_token = input_ids->ne[0];
auto hidden_states = model.forward(ctx, backend, input_ids, nullptr, attention_mask, relative_position_bucket); // [N, n_token, model_dim]
auto hidden_states = model.forward(ctx, input_ids, nullptr, attention_mask, relative_position_bucket); // [N, n_token, model_dim]
return hidden_states;
}
@ -818,7 +812,8 @@ struct T5Runner : public GGMLRunner {
input_ids->ne[0]);
set_backend_tensor_data(relative_position_bucket, relative_position_bucket_vec.data());
struct ggml_tensor* hidden_states = forward(compute_ctx, runtime_backend, input_ids, relative_position_bucket, attention_mask);
auto runner_ctx = get_context();
struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, relative_position_bucket, attention_mask);
ggml_build_forward_expand(gf, hidden_states);
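Note: the build_graph changes in this file follow the same pattern as the other runners in this commit: instead of passing compute_ctx and runtime_backend into forward() separately, the runner asks for a bundled context and passes its address. A generic sketch of that shape, written as the body of a runner's build_graph member function (compute_ctx, to_backend, get_context and forward are the GGMLRunner members used in the hunk above; the exact argument list varies per runner):

    struct ggml_cgraph* build_graph_sketch(struct ggml_tensor* input) {
        struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
        input                  = to_backend(input);

        auto runner_ctx = get_context();                // wraps compute_ctx, runtime_backend and the feature flags
        auto out        = forward(&runner_ctx, input);  // blocks reach them as ctx->ggml_ctx / ctx->backend
        ggml_build_forward_expand(gf, out);
        return gf;
    }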

42
tae.hpp

@ -29,7 +29,7 @@ public:
}
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
// x: [n, n_in, h, w]
// return: [n, n_out, h, w]
@ -38,9 +38,9 @@ public:
auto conv_4 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.4"]);
auto h = conv_0->forward(ctx, x);
h = ggml_relu_inplace(ctx, h);
h = ggml_relu_inplace(ctx->ggml_ctx, h);
h = conv_2->forward(ctx, h);
h = ggml_relu_inplace(ctx, h);
h = ggml_relu_inplace(ctx->ggml_ctx, h);
h = conv_4->forward(ctx, h);
if (n_in != n_out) {
@ -49,8 +49,8 @@ public:
x = skip->forward(ctx, x);
}
h = ggml_add(ctx, h, x);
h = ggml_relu_inplace(ctx, h);
h = ggml_add(ctx->ggml_ctx, h, x);
h = ggml_relu_inplace(ctx->ggml_ctx, h);
return h;
}
};
@ -86,7 +86,7 @@ public:
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, z_channels, {3, 3}, {1, 1}, {1, 1}));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
// x: [n, in_channels, h, w]
// return: [n, z_channels, h/8, w/8]
@ -136,20 +136,20 @@ public:
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* z) override {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) override {
// z: [n, z_channels, h, w]
// return: [n, out_channels, h*8, w*8]
auto h = ggml_scale(ctx, z, 1.0f / 3.0f);
h = ggml_tanh_inplace(ctx, h);
h = ggml_scale(ctx, h, 3.0f);
auto h = ggml_scale(ctx->ggml_ctx, z, 1.0f / 3.0f);
h = ggml_tanh_inplace(ctx->ggml_ctx, h);
h = ggml_scale(ctx->ggml_ctx, h, 3.0f);
for (int i = 0; i < num_blocks * 3 + 10; i++) {
if (blocks.find(std::to_string(i)) == blocks.end()) {
if (i == 1) {
h = ggml_relu_inplace(ctx, h);
h = ggml_relu_inplace(ctx->ggml_ctx, h);
} else {
h = ggml_upscale(ctx, h, 2, GGML_SCALE_MODE_NEAREST);
h = ggml_upscale(ctx->ggml_ctx, h, 2, GGML_SCALE_MODE_NEAREST);
}
continue;
}
@ -180,12 +180,12 @@ public:
}
}
struct ggml_tensor* decode(struct ggml_context* ctx, struct ggml_tensor* z) {
struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
auto decoder = std::dynamic_pointer_cast<TinyDecoder>(blocks["decoder.layers"]);
return decoder->forward(ctx, z);
}
struct ggml_tensor* encode(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
auto encoder = std::dynamic_pointer_cast<TinyEncoder>(blocks["encoder.layers"]);
return encoder->forward(ctx, x);
}
@ -207,17 +207,6 @@ struct TinyAutoEncoder : public GGMLRunner {
taesd.init(params_ctx, tensor_types, prefix);
}
void enable_conv2d_direct() {
std::vector<GGMLBlock*> blocks;
taesd.get_all_blocks(blocks);
for (auto block : blocks) {
if (block->get_desc() == "Conv2d") {
auto conv_block = (Conv2d*)block;
conv_block->enable_direct();
}
}
}
std::string get_desc() override {
return "taesd";
}
@ -252,7 +241,8 @@ struct TinyAutoEncoder : public GGMLRunner {
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
z = to_backend(z);
struct ggml_tensor* out = decode_graph ? taesd.decode(compute_ctx, z) : taesd.encode(compute_ctx, z);
auto runner_ctx = get_context();
struct ggml_tensor* out = decode_graph ? taesd.decode(&runner_ctx, z) : taesd.encode(&runner_ctx, z);
ggml_build_forward_expand(gf, out);
return gf;
}
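Note: the enable_conv2d_direct() helpers deleted here (and in unet.hpp / vae.hpp below) used to walk every block and flip a per-Conv2d switch; with the new set_conv2d_direct_enabled(true) calls the choice presumably travels through GGMLRunnerContext instead. A hedged sketch of what a conv block could do with that flag at graph build time; the helper name and the exact ggml entry points are illustrative and not taken from this diff:

    // Hypothetical helper: picks the direct conv kernel when the context asks for it.
    static struct ggml_tensor* conv2d_pick_kernel(GGMLRunnerContext* ctx,
                                                  struct ggml_tensor* w,  // kernel [KW, KH, IC, OC]
                                                  struct ggml_tensor* x,  // input  [W, H, IC, N]
                                                  int s0, int s1, int p0, int p1, int d0, int d1) {
        if (ctx->conv2d_direct_enabled) {
            return ggml_conv_2d_direct(ctx->ggml_ctx, w, x, s0, s1, p0, p1, d0, d1);
        }
        return ggml_conv_2d(ctx->ggml_ctx, w, x, s0, s1, p0, p1, d0, d1);
    }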

117
unet.hpp

@ -60,8 +60,7 @@ public:
blocks["time_mixer"] = std::shared_ptr<GGMLBlock>(new AlphaBlender());
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* context,
int timesteps) {
@ -92,7 +91,7 @@ public:
auto time_context = context; // [b*t, n_context, context_dim]
auto spatial_context = context;
// time_context_first_timestep = time_context[::timesteps]
auto time_context_first_timestep = ggml_view_3d(ctx,
auto time_context_first_timestep = ggml_view_3d(ctx->ggml_ctx,
time_context,
time_context->ne[0],
time_context->ne[1],
@ -100,26 +99,26 @@ public:
time_context->nb[1],
time_context->nb[2],
0); // [b, n_context, context_dim]
time_context = ggml_new_tensor_3d(ctx, GGML_TYPE_F32,
time_context = ggml_new_tensor_3d(ctx->ggml_ctx, GGML_TYPE_F32,
time_context_first_timestep->ne[0],
time_context_first_timestep->ne[1],
time_context_first_timestep->ne[2] * h * w);
time_context = ggml_repeat(ctx, time_context_first_timestep, time_context); // [b*h*w, n_context, context_dim]
time_context = ggml_repeat(ctx->ggml_ctx, time_context_first_timestep, time_context); // [b*h*w, n_context, context_dim]
x = norm->forward(ctx, x);
x = proj_in->forward(ctx, x); // [N, inner_dim, h, w]
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 2, 0, 3)); // [N, h, w, inner_dim]
x = ggml_reshape_3d(ctx, x, inner_dim, w * h, n); // [N, h * w, inner_dim]
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 2, 0, 3)); // [N, h, w, inner_dim]
x = ggml_reshape_3d(ctx->ggml_ctx, x, inner_dim, w * h, n); // [N, h * w, inner_dim]
auto num_frames = ggml_arange(ctx, 0, timesteps, 1);
auto num_frames = ggml_arange(ctx->ggml_ctx, 0, timesteps, 1);
// since b is 1, no need to do repeat
auto t_emb = ggml_ext_timestep_embedding(ctx, num_frames, in_channels, max_time_embed_period); // [N, in_channels]
auto t_emb = ggml_ext_timestep_embedding(ctx->ggml_ctx, num_frames, in_channels, max_time_embed_period); // [N, in_channels]
auto emb = time_pos_embed_0->forward(ctx, t_emb);
emb = ggml_silu_inplace(ctx, emb);
emb = ggml_silu_inplace(ctx->ggml_ctx, emb);
emb = time_pos_embed_2->forward(ctx, emb); // [N, in_channels]
emb = ggml_reshape_3d(ctx, emb, emb->ne[0], 1, emb->ne[1]); // [N, 1, in_channels]
emb = ggml_reshape_3d(ctx->ggml_ctx, emb, emb->ne[0], 1, emb->ne[1]); // [N, 1, in_channels]
for (int i = 0; i < depth; i++) {
std::string transformer_name = "transformer_blocks." + std::to_string(i);
@ -128,11 +127,11 @@ public:
auto block = std::dynamic_pointer_cast<BasicTransformerBlock>(blocks[transformer_name]);
auto mix_block = std::dynamic_pointer_cast<BasicTransformerBlock>(blocks[time_stack_name]);
x = block->forward(ctx, backend, x, spatial_context); // [N, h * w, inner_dim]
x = block->forward(ctx, x, spatial_context); // [N, h * w, inner_dim]
// in_channels == inner_dim
auto x_mix = x;
x_mix = ggml_add(ctx, x_mix, emb); // [N, h * w, inner_dim]
x_mix = ggml_add(ctx->ggml_ctx, x_mix, emb); // [N, h * w, inner_dim]
int64_t N = x_mix->ne[2];
int64_t T = timesteps;
@ -140,26 +139,26 @@ public:
int64_t S = x_mix->ne[1];
int64_t C = x_mix->ne[0];
x_mix = ggml_reshape_4d(ctx, x_mix, C, S, T, B); // (b t) s c -> b t s c
x_mix = ggml_cont(ctx, ggml_permute(ctx, x_mix, 0, 2, 1, 3)); // b t s c -> b s t c
x_mix = ggml_reshape_3d(ctx, x_mix, C, T, S * B); // b s t c -> (b s) t c
x_mix = ggml_reshape_4d(ctx->ggml_ctx, x_mix, C, S, T, B); // (b t) s c -> b t s c
x_mix = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x_mix, 0, 2, 1, 3)); // b t s c -> b s t c
x_mix = ggml_reshape_3d(ctx->ggml_ctx, x_mix, C, T, S * B); // b s t c -> (b s) t c
x_mix = mix_block->forward(ctx, backend, x_mix, time_context); // [B * h * w, T, inner_dim]
x_mix = mix_block->forward(ctx, x_mix, time_context); // [B * h * w, T, inner_dim]
x_mix = ggml_reshape_4d(ctx, x_mix, C, T, S, B); // (b s) t c -> b s t c
x_mix = ggml_cont(ctx, ggml_permute(ctx, x_mix, 0, 2, 1, 3)); // b s t c -> b t s c
x_mix = ggml_reshape_3d(ctx, x_mix, C, S, T * B); // b t s c -> (b t) s c
x_mix = ggml_reshape_4d(ctx->ggml_ctx, x_mix, C, T, S, B); // (b s) t c -> b s t c
x_mix = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x_mix, 0, 2, 1, 3)); // b s t c -> b t s c
x_mix = ggml_reshape_3d(ctx->ggml_ctx, x_mix, C, S, T * B); // b t s c -> (b t) s c
x = time_mixer->forward(ctx, x, x_mix); // [N, h * w, inner_dim]
}
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, inner_dim, h * w]
x = ggml_reshape_4d(ctx, x, w, h, inner_dim, n); // [N, inner_dim, h, w]
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [N, inner_dim, h * w]
x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, inner_dim, n); // [N, inner_dim, h, w]
// proj_out
x = proj_out->forward(ctx, x); // [N, in_channels, h, w]
x = ggml_add(ctx, x, x_in);
x = ggml_add(ctx->ggml_ctx, x, x_in);
return x;
}
};
@ -184,7 +183,7 @@ public:
int model_channels = 320;
int adm_in_channels = 2816; // only for VERSION_SDXL/SVD
UnetModelBlock(SDVersion version = VERSION_SD1, const String2GGMLType& tensor_types = {}, bool flash_attn = false)
UnetModelBlock(SDVersion version = VERSION_SD1, const String2GGMLType& tensor_types = {})
: version(version) {
if (sd_version_is_sd2(version)) {
context_dim = 1024;
@ -252,7 +251,7 @@ public:
if (version == VERSION_SVD) {
return new SpatialVideoTransformer(in_channels, n_head, d_head, depth, context_dim);
} else {
return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim, flash_attn);
return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim);
}
};
@ -377,7 +376,7 @@ public:
}
struct ggml_tensor* resblock_forward(std::string name,
struct ggml_context* ctx,
GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* emb,
int num_video_frames) {
@ -393,24 +392,22 @@ public:
}
struct ggml_tensor* attention_layer_forward(std::string name,
struct ggml_context* ctx,
ggml_backend_t backend,
GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* context,
int timesteps) {
if (version == VERSION_SVD) {
auto block = std::dynamic_pointer_cast<SpatialVideoTransformer>(blocks[name]);
return block->forward(ctx, backend, x, context, timesteps);
return block->forward(ctx, x, context, timesteps);
} else {
auto block = std::dynamic_pointer_cast<SpatialTransformer>(blocks[name]);
return block->forward(ctx, backend, x, context);
return block->forward(ctx, x, context);
}
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
@ -427,20 +424,20 @@ public:
// return: [N, out_channels, h, w]
if (context != nullptr) {
if (context->ne[2] != x->ne[3]) {
context = ggml_repeat(ctx, context, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, context->ne[0], context->ne[1], x->ne[3]));
context = ggml_repeat(ctx->ggml_ctx, context, ggml_new_tensor_3d(ctx->ggml_ctx, GGML_TYPE_F32, context->ne[0], context->ne[1], x->ne[3]));
}
}
if (c_concat != nullptr) {
if (c_concat->ne[3] != x->ne[3]) {
c_concat = ggml_repeat(ctx, c_concat, x);
c_concat = ggml_repeat(ctx->ggml_ctx, c_concat, x);
}
x = ggml_concat(ctx, x, c_concat, 2);
x = ggml_concat(ctx->ggml_ctx, x, c_concat, 2);
}
if (y != nullptr) {
if (y->ne[1] != x->ne[3]) {
y = ggml_repeat(ctx, y, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, y->ne[0], x->ne[3]));
y = ggml_repeat(ctx->ggml_ctx, y, ggml_new_tensor_2d(ctx->ggml_ctx, GGML_TYPE_F32, y->ne[0], x->ne[3]));
}
}
@ -451,10 +448,10 @@ public:
auto out_0 = std::dynamic_pointer_cast<GroupNorm32>(blocks["out.0"]);
auto out_2 = std::dynamic_pointer_cast<Conv2d>(blocks["out.2"]);
auto t_emb = ggml_ext_timestep_embedding(ctx, timesteps, model_channels); // [N, model_channels]
auto t_emb = ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, model_channels); // [N, model_channels]
auto emb = time_embed_0->forward(ctx, t_emb);
emb = ggml_silu_inplace(ctx, emb);
emb = ggml_silu_inplace(ctx->ggml_ctx, emb);
emb = time_embed_2->forward(ctx, emb); // [N, time_embed_dim]
// SDXL/SVD
@ -463,10 +460,10 @@ public:
auto label_embed_2 = std::dynamic_pointer_cast<Linear>(blocks["label_emb.0.2"]);
auto label_emb = label_embed_0->forward(ctx, y);
label_emb = ggml_silu_inplace(ctx, label_emb);
label_emb = ggml_silu_inplace(ctx->ggml_ctx, label_emb);
label_emb = label_embed_2->forward(ctx, label_emb); // [N, time_embed_dim]
emb = ggml_add(ctx, emb, label_emb); // [N, time_embed_dim]
emb = ggml_add(ctx->ggml_ctx, emb, label_emb); // [N, time_embed_dim]
}
// input_blocks
@ -489,7 +486,7 @@ public:
h = resblock_forward(name, ctx, h, emb, num_video_frames); // [N, mult*model_channels, h, w]
if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
std::string name = "input_blocks." + std::to_string(input_block_idx) + ".1";
h = attention_layer_forward(name, ctx, backend, h, context, num_video_frames); // [N, mult*model_channels, h, w]
h = attention_layer_forward(name, ctx, h, context, num_video_frames); // [N, mult*model_channels, h, w]
}
hs.push_back(h);
}
@ -513,13 +510,13 @@ public:
if (version != VERSION_SD1_TINY_UNET) {
h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
if (version != VERSION_SDXL_SSD1B) {
h = attention_layer_forward("middle_block.1", ctx, backend, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
h = attention_layer_forward("middle_block.1", ctx, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
}
}
if (controls.size() > 0) {
auto cs = ggml_scale_inplace(ctx, controls[controls.size() - 1], control_strength);
h = ggml_add(ctx, h, cs); // middle control
auto cs = ggml_scale_inplace(ctx->ggml_ctx, controls[controls.size() - 1], control_strength);
h = ggml_add(ctx->ggml_ctx, h, cs); // middle control
}
int control_offset = controls.size() - 2;
@ -531,12 +528,12 @@ public:
hs.pop_back();
if (controls.size() > 0) {
auto cs = ggml_scale_inplace(ctx, controls[control_offset], control_strength);
h_skip = ggml_add(ctx, h_skip, cs); // control net condition
auto cs = ggml_scale_inplace(ctx->ggml_ctx, controls[control_offset], control_strength);
h_skip = ggml_add(ctx->ggml_ctx, h_skip, cs); // control net condition
control_offset--;
}
h = ggml_concat(ctx, h, h_skip, 2);
h = ggml_concat(ctx->ggml_ctx, h, h_skip, 2);
std::string name = "output_blocks." + std::to_string(output_block_idx) + ".0";
@ -546,7 +543,7 @@ public:
if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
std::string name = "output_blocks." + std::to_string(output_block_idx) + ".1";
h = attention_layer_forward(name, ctx, backend, h, context, num_video_frames);
h = attention_layer_forward(name, ctx, h, context, num_video_frames);
up_sample_idx++;
}
@ -572,7 +569,7 @@ public:
// out
h = out_0->forward(ctx, h);
h = ggml_silu_inplace(ctx, h);
h = ggml_silu_inplace(ctx->ggml_ctx, h);
h = out_2->forward(ctx, h);
ggml_set_name(h, "bench-end");
return h; // [N, out_channels, h, w]
@ -586,24 +583,11 @@ struct UNetModelRunner : public GGMLRunner {
bool offload_params_to_cpu,
const String2GGMLType& tensor_types,
const std::string prefix,
SDVersion version = VERSION_SD1,
bool flash_attn = false)
: GGMLRunner(backend, offload_params_to_cpu), unet(version, tensor_types, flash_attn) {
SDVersion version = VERSION_SD1)
: GGMLRunner(backend, offload_params_to_cpu), unet(version, tensor_types) {
unet.init(params_ctx, tensor_types, prefix);
}
void enable_conv2d_direct() {
std::vector<GGMLBlock*> blocks;
unet.get_all_blocks(blocks);
for (auto block : blocks) {
if (block->get_desc() == "Conv2d") {
LOG_DEBUG("block %s", block->get_desc().c_str());
auto conv_block = (Conv2d*)block;
conv_block->enable_direct();
}
}
}
std::string get_desc() override {
return "unet";
}
@ -636,8 +620,9 @@ struct UNetModelRunner : public GGMLRunner {
controls[i] = to_backend(controls[i]);
}
struct ggml_tensor* out = unet.forward(compute_ctx,
runtime_backend,
auto runner_ctx = get_context();
struct ggml_tensor* out = unet.forward(&runner_ctx,
x,
timesteps,
context,

View File

@ -53,7 +53,7 @@ struct UpscalerGGML {
LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type));
esrgan_upscaler = std::make_shared<ESRGAN>(backend, offload_params_to_cpu, model_loader.tensor_storages_types);
if (direct) {
esrgan_upscaler->enable_conv2d_direct();
esrgan_upscaler->set_conv2d_direct_enabled(true);
}
if (!esrgan_upscaler->load_from_file(esrgan_path, n_threads)) {
return false;

82
vae.hpp

@ -30,7 +30,7 @@ public:
}
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
// x: [N, in_channels, h, w]
// t_emb is always None
auto norm1 = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm1"]);
@ -40,12 +40,12 @@ public:
auto h = x;
h = norm1->forward(ctx, h);
h = ggml_silu_inplace(ctx, h); // swish
h = ggml_silu_inplace(ctx->ggml_ctx, h); // swish
h = conv1->forward(ctx, h);
// return h;
h = norm2->forward(ctx, h);
h = ggml_silu_inplace(ctx, h); // swish
h = ggml_silu_inplace(ctx->ggml_ctx, h); // swish
// dropout, skip for inference
h = conv2->forward(ctx, h);
@ -56,7 +56,7 @@ public:
x = nin_shortcut->forward(ctx, x); // [N, out_channels, h, w]
}
h = ggml_add(ctx, h, x);
h = ggml_add(ctx->ggml_ctx, h, x);
return h; // [N, out_channels, h, w]
}
};
@ -76,7 +76,7 @@ public:
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
// x: [N, in_channels, h, w]
auto norm = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm"]);
auto q_proj = std::dynamic_pointer_cast<Conv2d>(blocks["q"]);
@ -92,24 +92,24 @@ public:
const int64_t w = h_->ne[0];
auto q = q_proj->forward(ctx, h_); // [N, in_channels, h, w]
q = ggml_cont(ctx, ggml_permute(ctx, q, 1, 2, 0, 3)); // [N, h, w, in_channels]
q = ggml_reshape_3d(ctx, q, c, h * w, n); // [N, h * w, in_channels]
q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 1, 2, 0, 3)); // [N, h, w, in_channels]
q = ggml_reshape_3d(ctx->ggml_ctx, q, c, h * w, n); // [N, h * w, in_channels]
auto k = k_proj->forward(ctx, h_); // [N, in_channels, h, w]
k = ggml_cont(ctx, ggml_permute(ctx, k, 1, 2, 0, 3)); // [N, h, w, in_channels]
k = ggml_reshape_3d(ctx, k, c, h * w, n); // [N, h * w, in_channels]
k = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, k, 1, 2, 0, 3)); // [N, h, w, in_channels]
k = ggml_reshape_3d(ctx->ggml_ctx, k, c, h * w, n); // [N, h * w, in_channels]
auto v = v_proj->forward(ctx, h_); // [N, in_channels, h, w]
v = ggml_reshape_3d(ctx, v, h * w, c, n); // [N, in_channels, h * w]
v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [N, in_channels, h * w]
h_ = ggml_ext_attention(ctx, q, k, v, false); // [N, h * w, in_channels]
h_ = ggml_ext_attention(ctx->ggml_ctx, q, k, v, false); // [N, h * w, in_channels]
h_ = ggml_cont(ctx, ggml_permute(ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w]
h_ = ggml_reshape_4d(ctx, h_, w, h, c, n); // [N, in_channels, h, w]
h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w]
h_ = ggml_reshape_4d(ctx->ggml_ctx, h_, w, h, c, n); // [N, in_channels, h, w]
h_ = proj_out->forward(ctx, h_); // [N, in_channels, h, w]
h_ = ggml_add(ctx, h_, x);
h_ = ggml_add(ctx->ggml_ctx, h_, x);
return h_;
}
};
@ -133,7 +133,7 @@ public:
kernel_padding));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x) override {
// timesteps always None
// skip_video always False
@ -152,11 +152,11 @@ public:
int64_t H = x->ne[1];
int64_t W = x->ne[0];
x = ggml_reshape_4d(ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w)
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w)
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
x = time_mix_conv->forward(ctx, x); // [B, OC, T, OH * OW]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
x = ggml_reshape_4d(ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
return x; // [B*T, OC, OH, OW]
}
};
@ -182,7 +182,7 @@ public:
blocks["time_stack"] = std::shared_ptr<GGMLBlock>(new ResBlock(out_channels, 0, out_channels, {video_kernel_size, 1}, 3, false, true));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
// x: [N, in_channels, h, w] aka [b*t, in_channels, h, w]
// return: [N, out_channels, h, w] aka [b*t, out_channels, h, w]
// t_emb is always None
@ -199,19 +199,19 @@ public:
int64_t H = x->ne[1];
int64_t W = x->ne[0];
x = ggml_reshape_4d(ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w)
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w)
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
auto x_mix = x;
x = time_stack->forward(ctx, x); // b t c (h w)
float alpha = get_alpha();
x = ggml_add(ctx,
ggml_scale(ctx, x, alpha),
ggml_scale(ctx, x_mix, 1.0f - alpha));
x = ggml_add(ctx->ggml_ctx,
ggml_scale(ctx->ggml_ctx, x, alpha),
ggml_scale(ctx->ggml_ctx, x_mix, 1.0f - alpha));
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
x = ggml_reshape_4d(ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
return x;
}
@ -271,7 +271,7 @@ public:
blocks["conv_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(block_in, double_z ? z_channels * 2 : z_channels, {3, 3}, {1, 1}, {1, 1}));
}
virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
// x: [N, in_channels, h, w]
auto conv_in = std::dynamic_pointer_cast<Conv2d>(blocks["conv_in"]);
@ -307,7 +307,7 @@ public:
// end
h = norm_out->forward(ctx, h);
h = ggml_silu_inplace(ctx, h); // nonlinearity/swish
h = ggml_silu_inplace(ctx->ggml_ctx, h); // nonlinearity/swish
h = conv_out->forward(ctx, h); // [N, z_channels*2, h, w]
return h;
}
@ -388,7 +388,7 @@ public:
blocks["conv_out"] = get_conv_out(block_in, out_ch, {3, 3}, {1, 1}, {1, 1});
}
virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* z) {
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
// z: [N, z_channels, h, w]
// alpha is always 0
// merge_strategy is always learned
@ -429,7 +429,7 @@ public:
}
h = norm_out->forward(ctx, h);
h = ggml_silu_inplace(ctx, h); // nonlinearity/swish
h = ggml_silu_inplace(ctx->ggml_ctx, h); // nonlinearity/swish
h = conv_out->forward(ctx, h); // [N, out_ch, h*8, w*8]
return h;
}
@ -493,7 +493,7 @@ public:
}
}
struct ggml_tensor* decode(struct ggml_context* ctx, struct ggml_tensor* z) {
struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
// z: [N, z_channels, h, w]
if (use_quant) {
auto post_quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["post_quant_conv"]);
@ -507,7 +507,7 @@ public:
return h;
}
struct ggml_tensor* encode(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
// x: [N, in_channels, h, w]
auto encoder = std::dynamic_pointer_cast<Encoder>(blocks["encoder"]);
@ -529,7 +529,6 @@ struct VAE : public GGMLRunner {
struct ggml_tensor** output,
struct ggml_context* output_ctx) = 0;
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) = 0;
virtual void enable_conv2d_direct(){};
virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); };
};
@ -572,17 +571,6 @@ struct AutoEncoderKL : public VAE {
ae.init(params_ctx, tensor_types, prefix);
}
void enable_conv2d_direct() override {
std::vector<GGMLBlock*> blocks;
ae.get_all_blocks(blocks);
for (auto block : blocks) {
if (block->get_desc() == "Conv2d") {
auto conv_block = (Conv2d*)block;
conv_block->enable_direct();
}
}
}
void set_conv2d_scale(float scale) override {
std::vector<GGMLBlock*> blocks;
ae.get_all_blocks(blocks);
@ -607,7 +595,9 @@ struct AutoEncoderKL : public VAE {
z = to_backend(z);
struct ggml_tensor* out = decode_graph ? ae.decode(compute_ctx, z) : ae.encode(compute_ctx, z);
auto runner_ctx = get_context();
struct ggml_tensor* out = decode_graph ? ae.decode(&runner_ctx, z) : ae.encode(&runner_ctx, z);
ggml_build_forward_expand(gf, out);

410
wan.hpp

@ -54,7 +54,7 @@ namespace WAN {
dilation(std::move(dilation)),
bias(bias) {}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* cache_x = nullptr) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* cache_x = nullptr) {
// x: [N*IC, ID, IH, IW]
// result: x: [N*OC, ID, IH, IW]
struct ggml_tensor* w = params["weight"];
@ -71,12 +71,12 @@ namespace WAN {
int rp2 = 0;
if (cache_x != nullptr && lp2 > 0) {
x = ggml_concat(ctx, cache_x, x, 2);
x = ggml_concat(ctx->ggml_ctx, cache_x, x, 2);
lp2 -= (int)cache_x->ne[2];
}
x = ggml_pad_ext(ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, 0, 0);
return ggml_ext_conv_3d(ctx, x, w, b, in_channels,
x = ggml_pad_ext(ctx->ggml_ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, 0, 0);
return ggml_ext_conv_3d(ctx->ggml_ctx, x, w, b, in_channels,
std::get<2>(stride), std::get<1>(stride), std::get<0>(stride),
0, 0, 0,
std::get<2>(dilation), std::get<1>(dilation), std::get<0>(dilation));
@ -96,15 +96,15 @@ namespace WAN {
RMS_norm(int64_t dim)
: dim(dim) {}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
// x: [N*IC, ID, IH, IW], IC == dim
// assert N == 1
struct ggml_tensor* w = params["gamma"];
auto h = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 3, 0, 1, 2)); // [ID, IH, IW, N*IC]
h = ggml_rms_norm(ctx, h, 1e-12);
h = ggml_mul(ctx, h, w);
h = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, h, 1, 2, 3, 0));
auto h = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 3, 0, 1, 2)); // [ID, IH, IW, N*IC]
h = ggml_rms_norm(ctx->ggml_ctx, h, 1e-12);
h = ggml_mul(ctx->ggml_ctx, h, w);
h = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, h, 1, 2, 3, 0));
return h;
}
@ -143,7 +143,7 @@ namespace WAN {
}
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
int64_t b,
std::vector<struct ggml_tensor*>& feat_cache,
@ -165,16 +165,16 @@ namespace WAN {
} else {
auto time_conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["time_conv"]);
auto cache_x = ggml_ext_slice(ctx, x, 2, -CACHE_T, x->ne[2]);
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) { // chunk_idx >= 2
// cache last frame of last two chunk
cache_x = ggml_concat(ctx,
ggml_ext_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
cache_x = ggml_concat(ctx->ggml_ctx,
ggml_ext_slice(ctx->ggml_ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
cache_x,
2);
}
if (chunk_idx == 1 && cache_x->ne[2] < 2) { // Rep
cache_x = ggml_pad_ext(ctx, cache_x, 0, 0, 0, 0, (int)cache_x->ne[2], 0, 0, 0);
cache_x = ggml_pad_ext(ctx->ggml_ctx, cache_x, 0, 0, 0, 0, (int)cache_x->ne[2], 0, 0, 0);
// aka cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device),cache_x],dim=2)
}
if (chunk_idx == 1) {
@ -183,9 +183,9 @@ namespace WAN {
x = time_conv->forward(ctx, x, feat_cache[idx]);
}
feat_cache[idx] = cache_x;
x = ggml_reshape_4d(ctx, x, w * h, t, c, 2); // (2, c, t, h*w)
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 3, 1, 2)); // (c, t, 2, h*w)
x = ggml_reshape_4d(ctx, x, w, h, 2 * t, c); // (c, t*2, h, w)
x = ggml_reshape_4d(ctx->ggml_ctx, x, w * h, t, c, 2); // (2, c, t, h*w)
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 3, 1, 2)); // (c, t, 2, h*w)
x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, 2 * t, c); // (c, t*2, h, w)
}
}
}
@ -194,18 +194,18 @@ namespace WAN {
if (mode != "none") {
auto resample_1 = std::dynamic_pointer_cast<Conv2d>(blocks["resample.1"]);
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 1, 3, 2)); // (t, c, h, w)
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // (t, c, h, w)
if (mode == "upsample2d") {
x = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST);
x = ggml_upscale(ctx->ggml_ctx, x, 2, GGML_SCALE_MODE_NEAREST);
} else if (mode == "upsample3d") {
x = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST);
x = ggml_upscale(ctx->ggml_ctx, x, 2, GGML_SCALE_MODE_NEAREST);
} else if (mode == "downsample2d") {
x = ggml_pad(ctx, x, 1, 1, 0, 0);
x = ggml_pad(ctx->ggml_ctx, x, 1, 1, 0, 0);
} else if (mode == "downsample3d") {
x = ggml_pad(ctx, x, 1, 1, 0, 0);
x = ggml_pad(ctx->ggml_ctx, x, 1, 1, 0, 0);
}
x = resample_1->forward(ctx, x);
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 1, 3, 2)); // (c, t, h, w)
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // (c, t, h, w)
}
if (mode == "downsample3d") {
@ -217,9 +217,9 @@ namespace WAN {
} else {
auto time_conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["time_conv"]);
auto cache_x = ggml_ext_slice(ctx, x, 2, -1, x->ne[2]);
x = ggml_concat(ctx,
ggml_ext_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -1, x->ne[2]);
x = ggml_concat(ctx->ggml_ctx,
ggml_ext_slice(ctx->ggml_ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
x,
2);
x = time_conv->forward(ctx, x);
@ -249,7 +249,7 @@ namespace WAN {
GGML_ASSERT(in_channels * factor % out_channels == 0);
group_size = in_channels * factor / out_channels;
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
int64_t B = 1) {
// x: [B*IC, T, H, W]
@ -262,20 +262,20 @@ namespace WAN {
int64_t pad_t = (factor_t - T % factor_t) % factor_t;
x = ggml_pad_ext(ctx, x, 0, 0, 0, 0, pad_t, 0, 0, 0);
x = ggml_pad_ext(ctx->ggml_ctx, x, 0, 0, 0, 0, pad_t, 0, 0, 0);
T = x->ne[2];
x = ggml_reshape_4d(ctx, x, W * H, factor_t, T / factor_t, C); // [C, T/factor_t, factor_t, H*W]
x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [C, factor_t, T/factor_t, H*W]
x = ggml_reshape_4d(ctx, x, W, factor_s, (H / factor_s) * (T / factor_t), factor_t * C); // [C*factor_t, T/factor_t*H/factor_s, factor_s, W]
x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [C*factor_t, factor_s, T/factor_t*H/factor_s, W]
x = ggml_reshape_4d(ctx, x, factor_s, W / factor_s, (H / factor_s) * (T / factor_t), factor_s * factor_t * C); // [C*factor_t*factor_s, T/factor_t*H/factor_s, W/factor_s, factor_s]
x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 2, 0, 3)); // [C*factor_t*factor_s, factor_s, T/factor_t*H/factor_s, W/factor_s]
x = ggml_reshape_3d(ctx, x, (W / factor_s) * (H / factor_s) * (T / factor_t), group_size, out_channels); // [out_channels, group_size, T/factor_t*H/factor_s*W/factor_s]
x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, factor_t, T / factor_t, C); // [C, T/factor_t, factor_t, H*W]
x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // [C, factor_t, T/factor_t, H*W]
x = ggml_reshape_4d(ctx->ggml_ctx, x, W, factor_s, (H / factor_s) * (T / factor_t), factor_t * C); // [C*factor_t, T/factor_t*H/factor_s, factor_s, W]
x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // [C*factor_t, factor_s, T/factor_t*H/factor_s, W]
x = ggml_reshape_4d(ctx->ggml_ctx, x, factor_s, W / factor_s, (H / factor_s) * (T / factor_t), factor_s * factor_t * C); // [C*factor_t*factor_s, T/factor_t*H/factor_s, W/factor_s, factor_s]
x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 2, 0, 3)); // [C*factor_t*factor_s, factor_s, T/factor_t*H/factor_s, W/factor_s]
x = ggml_reshape_3d(ctx->ggml_ctx, x, (W / factor_s) * (H / factor_s) * (T / factor_t), group_size, out_channels); // [out_channels, group_size, T/factor_t*H/factor_s*W/factor_s]
x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 0, 2, 3)); // [out_channels, T/factor_t*H/factor_s*W/factor_s, group_size]
x = ggml_mean(ctx, x); // [out_channels, T/factor_t*H/factor_s*W/factor_s, 1]
x = ggml_reshape_4d(ctx, x, W / factor_s, H / factor_s, T / factor_t, out_channels);
x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [out_channels, T/factor_t*H/factor_s*W/factor_s, group_size]
x = ggml_mean(ctx->ggml_ctx, x); // [out_channels, T/factor_t*H/factor_s*W/factor_s, 1]
x = ggml_reshape_4d(ctx->ggml_ctx, x, W / factor_s, H / factor_s, T / factor_t, out_channels);
return x;
}
};
@ -296,7 +296,7 @@ namespace WAN {
GGML_ASSERT(out_channels * factor % in_channels == 0);
repeats = out_channels * factor / in_channels;
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
bool first_chunk = false,
int64_t B = 1) {
@ -310,21 +310,21 @@ namespace WAN {
auto x_ = x;
for (int64_t i = 1; i < repeats; i++) {
x = ggml_concat(ctx, x, x_, 2);
x = ggml_concat(ctx->ggml_ctx, x, x_, 2);
}
C = out_channels;
x = ggml_reshape_4d(ctx, x, W, H * T, factor_s, factor_s * factor_t * C); // [C*factor_t*factor_s, factor_s, T*H, W]
x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [C*factor_t*factor_s, T*H, W, factor_s]
x = ggml_reshape_4d(ctx, x, factor_s * W, H * T, factor_s, factor_t * C); // [C*factor_t, factor_s, T*H, W*factor_s]
x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [C*factor_t, T*H, factor_s, W*factor_s]
x = ggml_reshape_4d(ctx, x, factor_s * W * factor_s * H, T, factor_t, C); // [C, factor_t, T, H*factor_s*W*factor_s]
x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [C, T, factor_t, H*factor_s*W*factor_s]
x = ggml_reshape_4d(ctx, x, factor_s * W, factor_s * H, factor_t * T, C); // [C, T*factor_t, H*factor_s, W*factor_s]
x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H * T, factor_s, factor_s * factor_t * C); // [C*factor_t*factor_s, factor_s, T*H, W]
x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 2, 0, 1, 3)); // [C*factor_t*factor_s, T*H, W, factor_s]
x = ggml_reshape_4d(ctx->ggml_ctx, x, factor_s * W, H * T, factor_s, factor_t * C); // [C*factor_t, factor_s, T*H, W*factor_s]
x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // [C*factor_t, T*H, factor_s, W*factor_s]
x = ggml_reshape_4d(ctx->ggml_ctx, x, factor_s * W * factor_s * H, T, factor_t, C); // [C, factor_t, T, H*factor_s*W*factor_s]
x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // [C, T, factor_t, H*factor_s*W*factor_s]
x = ggml_reshape_4d(ctx->ggml_ctx, x, factor_s * W, factor_s * H, factor_t * T, C); // [C, T*factor_t, H*factor_s, W*factor_s]
if (first_chunk) {
x = ggml_ext_slice(ctx, x, 2, factor_t - 1, x->ne[2]);
x = ggml_ext_slice(ctx->ggml_ctx, x, 2, factor_t - 1, x->ne[2]);
}
return x;
@ -351,7 +351,7 @@ namespace WAN {
}
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
int64_t b,
std::vector<struct ggml_tensor*>& feat_cache,
@ -374,11 +374,11 @@ namespace WAN {
if (feat_cache.size() > 0) {
int idx = feat_idx;
auto cache_x = ggml_ext_slice(ctx, x, 2, -CACHE_T, x->ne[2]);
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
// cache last frame of last two chunk
cache_x = ggml_concat(ctx,
ggml_ext_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
cache_x = ggml_concat(ctx->ggml_ctx,
ggml_ext_slice(ctx->ggml_ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
cache_x,
2);
}
@ -388,13 +388,13 @@ namespace WAN {
feat_idx += 1;
}
} else if (i == 1 || i == 4) {
x = ggml_silu(ctx, x);
x = ggml_silu(ctx->ggml_ctx, x);
} else { // i == 5
// nn.Dropout(), ignore
}
}
x = ggml_add(ctx, x, h);
x = ggml_add(ctx->ggml_ctx, x, h);
return x;
}
};
@ -425,7 +425,7 @@ namespace WAN {
}
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
int64_t b,
std::vector<struct ggml_tensor*>& feat_cache,
@ -453,7 +453,7 @@ namespace WAN {
auto shortcut = avg_shortcut->forward(ctx, x_copy, b);
x = ggml_add(ctx, x, shortcut);
x = ggml_add(ctx->ggml_ctx, x, shortcut);
return x;
}
@ -487,7 +487,7 @@ namespace WAN {
}
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
int64_t b,
std::vector<struct ggml_tensor*>& feat_cache,
@ -513,7 +513,7 @@ namespace WAN {
auto avg_shortcut = std::dynamic_pointer_cast<DupUp3D>(blocks["avg_shortcut"]);
auto shortcut = avg_shortcut->forward(ctx, x_copy, chunk_idx == 0, b);
x = ggml_add(ctx, x, shortcut);
x = ggml_add(ctx->ggml_ctx, x, shortcut);
}
return x;
@ -532,7 +532,7 @@ namespace WAN {
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim, {1, 1}));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
int64_t b) {
// x: [b*c, t, h, w]
@ -545,7 +545,7 @@ namespace WAN {
x = norm->forward(ctx, x);
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 1, 3, 2)); // (t, c, h, w)
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // (t, c, h, w)
const int64_t n = x->ne[3];
const int64_t c = x->ne[2];
@ -553,31 +553,31 @@ namespace WAN {
const int64_t w = x->ne[0];
auto qkv = to_qkv->forward(ctx, x);
auto qkv_vec = split_image_qkv(ctx, qkv);
auto qkv_vec = split_image_qkv(ctx->ggml_ctx, qkv);
auto q = qkv_vec[0];
q = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, q, 2, 0, 1, 3)); // [t, h, w, c]
q = ggml_reshape_3d(ctx, q, c, h * w, n); // [t, h * w, c]
q = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, q, 2, 0, 1, 3)); // [t, h, w, c]
q = ggml_reshape_3d(ctx->ggml_ctx, q, c, h * w, n); // [t, h * w, c]
auto k = qkv_vec[1];
k = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, k, 2, 0, 1, 3)); // [t, h, w, c]
k = ggml_reshape_3d(ctx, k, c, h * w, n); // [t, h * w, c]
k = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 2, 0, 1, 3)); // [t, h, w, c]
k = ggml_reshape_3d(ctx->ggml_ctx, k, c, h * w, n); // [t, h * w, c]
auto v = qkv_vec[2];
v = ggml_reshape_3d(ctx, v, h * w, c, n); // [t, c, h * w]
v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [t, c, h * w]
x = ggml_ext_attention(ctx, q, k, v, false); // [t, h * w, c]
x = ggml_ext_attention(ctx->ggml_ctx, q, k, v, false); // [t, h * w, c]
// v = ggml_cont(ctx, ggml_ext_torch_permute(ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
// x = ggml_ext_attention_ext(ctx, q, k, v, q->ne[2], nullptr, false, false, true);
x = ggml_ext_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [t, c, h * w]
x = ggml_reshape_4d(ctx, x, w, h, c, n); // [t, c, h, w]
x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [t, c, h * w]
x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n); // [t, c, h, w]
x = proj->forward(ctx, x);
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 1, 3, 2)); // (c, t, h, w)
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // (c, t, h, w)
x = ggml_add(ctx, x, identity);
x = ggml_add(ctx->ggml_ctx, x, identity);
return x;
}
};
@ -655,7 +655,7 @@ namespace WAN {
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, z_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
int64_t b,
std::vector<struct ggml_tensor*>& feat_cache,
@ -673,11 +673,11 @@ namespace WAN {
// conv1
if (feat_cache.size() > 0) {
int idx = feat_idx;
auto cache_x = ggml_ext_slice(ctx, x, 2, -CACHE_T, x->ne[2]);
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
// cache last frame of last two chunk
cache_x = ggml_concat(ctx,
ggml_ext_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
cache_x = ggml_concat(ctx->ggml_ctx,
ggml_ext_slice(ctx->ggml_ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
cache_x,
2);
}
@ -722,14 +722,14 @@ namespace WAN {
// head
x = head_0->forward(ctx, x);
x = ggml_silu(ctx, x);
x = ggml_silu(ctx->ggml_ctx, x);
if (feat_cache.size() > 0) {
int idx = feat_idx;
auto cache_x = ggml_ext_slice(ctx, x, 2, -CACHE_T, x->ne[2]);
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
// cache last frame of last two chunk
cache_x = ggml_concat(ctx,
ggml_ext_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
cache_x = ggml_concat(ctx->ggml_ctx,
ggml_ext_slice(ctx->ggml_ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
cache_x,
2);
}
@ -826,7 +826,7 @@ namespace WAN {
}
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
int64_t b,
std::vector<struct ggml_tensor*>& feat_cache,
@ -844,11 +844,11 @@ namespace WAN {
// conv1
if (feat_cache.size() > 0) {
int idx = feat_idx;
auto cache_x = ggml_ext_slice(ctx, x, 2, -CACHE_T, x->ne[2]);
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
// cache last frame of last two chunk
cache_x = ggml_concat(ctx,
ggml_ext_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
cache_x = ggml_concat(ctx->ggml_ctx,
ggml_ext_slice(ctx->ggml_ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
cache_x,
2);
}
@ -893,14 +893,14 @@ namespace WAN {
// head
x = head_0->forward(ctx, x);
x = ggml_silu(ctx, x);
x = ggml_silu(ctx->ggml_ctx, x);
if (feat_cache.size() > 0) {
int idx = feat_idx;
auto cache_x = ggml_ext_slice(ctx, x, 2, -CACHE_T, x->ne[2]);
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
// cache last frame of last two chunk
cache_x = ggml_concat(ctx,
ggml_ext_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
cache_x = ggml_concat(ctx->ggml_ctx,
ggml_ext_slice(ctx->ggml_ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
cache_x,
2);
}
@ -1015,7 +1015,7 @@ namespace WAN {
return x;
}
struct ggml_tensor* encode(struct ggml_context* ctx,
struct ggml_tensor* encode(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
int64_t b = 1) {
// x: [b*c, t, h, w]
@ -1025,7 +1025,7 @@ namespace WAN {
clear_cache();
if (wan2_2) {
x = patchify(ctx, x, 2, b);
x = patchify(ctx->ggml_ctx, x, 2, b);
}
auto encoder = std::dynamic_pointer_cast<Encoder3d>(blocks["encoder"]);
@ -1037,21 +1037,21 @@ namespace WAN {
for (int i = 0; i < iter_; i++) {
_enc_conv_idx = 0;
if (i == 0) {
auto in = ggml_ext_slice(ctx, x, 2, 0, 1); // [b*c, 1, h, w]
auto in = ggml_ext_slice(ctx->ggml_ctx, x, 2, 0, 1); // [b*c, 1, h, w]
out = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx, i);
} else {
auto in = ggml_ext_slice(ctx, x, 2, 1 + 4 * (i - 1), 1 + 4 * i); // [b*c, 4, h, w]
auto in = ggml_ext_slice(ctx->ggml_ctx, x, 2, 1 + 4 * (i - 1), 1 + 4 * i); // [b*c, 4, h, w]
auto out_ = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx, i);
out = ggml_concat(ctx, out, out_, 2);
out = ggml_concat(ctx->ggml_ctx, out, out_, 2);
}
}
out = conv1->forward(ctx, out);
auto mu = ggml_ext_chunk(ctx, out, 2, 3)[0];
auto mu = ggml_ext_chunk(ctx->ggml_ctx, out, 2, 3)[0];
clear_cache();
return mu;
}
struct ggml_tensor* decode(struct ggml_context* ctx,
struct ggml_tensor* decode(GGMLRunnerContext* ctx,
struct ggml_tensor* z,
int64_t b = 1) {
// z: [b*c, t, h, w]
@ -1068,22 +1068,22 @@ namespace WAN {
for (int64_t i = 0; i < iter_; i++) {
_conv_idx = 0;
if (i == 0) {
auto in = ggml_ext_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
auto in = ggml_ext_slice(ctx->ggml_ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i);
} else {
auto in = ggml_ext_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
auto in = ggml_ext_slice(ctx->ggml_ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
auto out_ = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i);
out = ggml_concat(ctx, out, out_, 2);
out = ggml_concat(ctx->ggml_ctx, out, out_, 2);
}
}
if (wan2_2) {
out = unpatchify(ctx, out, 2, b);
out = unpatchify(ctx->ggml_ctx, out, 2, b);
}
clear_cache();
return out;
}
struct ggml_tensor* decode_partial(struct ggml_context* ctx,
struct ggml_tensor* decode_partial(GGMLRunnerContext* ctx,
struct ggml_tensor* z,
int64_t i,
int64_t b = 1) {
@ -1094,11 +1094,11 @@ namespace WAN {
auto conv2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv2"]);
auto x = conv2->forward(ctx, z);
auto in = ggml_ext_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
auto in = ggml_ext_slice(ctx->ggml_ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
_conv_idx = 0;
auto out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i);
if (wan2_2) {
out = unpatchify(ctx, out, 2, b);
out = unpatchify(ctx->ggml_ctx, out, 2, b);
}
return out;
}
@ -1131,7 +1131,9 @@ namespace WAN {
z = to_backend(z);
struct ggml_tensor* out = decode_graph ? ae.decode(compute_ctx, z) : ae.encode(compute_ctx, z);
auto runner_ctx = get_context();
struct ggml_tensor* out = decode_graph ? ae.decode(&runner_ctx, z) : ae.encode(&runner_ctx, z);
ggml_build_forward_expand(gf, out);
@ -1150,7 +1152,9 @@ namespace WAN {
z = to_backend(z);
struct ggml_tensor* out = decode_graph ? ae.decode_partial(compute_ctx, z, i) : ae.encode(compute_ctx, z);
auto runner_ctx = get_context();
struct ggml_tensor* out = decode_graph ? ae.decode_partial(&runner_ctx, z, i) : ae.encode(&runner_ctx, z);
for (int64_t feat_idx = 0; feat_idx < ae._feat_map.size(); feat_idx++) {
ggml_tensor* feat_cache = ae._feat_map[feat_idx];
@ -1283,15 +1287,13 @@ namespace WAN {
public:
int64_t num_heads;
int64_t head_dim;
bool flash_attn;
public:
WanSelfAttention(int64_t dim,
int64_t num_heads,
bool qk_norm = true,
float eps = 1e-6,
bool flash_attn = false)
: num_heads(num_heads), flash_attn(flash_attn) {
float eps = 1e-6)
: num_heads(num_heads) {
head_dim = dim / num_heads;
blocks["q"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
blocks["k"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
@ -1307,8 +1309,7 @@ namespace WAN {
}
}
virtual struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* pe,
struct ggml_tensor* mask = nullptr) {
@ -1331,11 +1332,11 @@ namespace WAN {
k = norm_k->forward(ctx, k);
auto v = v_proj->forward(ctx, x); // [N, n_token, n_head*d_head]
q = ggml_reshape_4d(ctx, q, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head]
k = ggml_reshape_4d(ctx, k, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head]
v = ggml_reshape_4d(ctx, v, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head]
q = ggml_reshape_4d(ctx->ggml_ctx, q, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head]
k = ggml_reshape_4d(ctx->ggml_ctx, k, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head]
v = ggml_reshape_4d(ctx->ggml_ctx, v, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head]
x = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_token, dim]
x = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_token, dim]
x = o_proj->forward(ctx, x); // [N, n_token, dim]
return x;
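
After this refactor the calling convention is uniform across the file: block-level forward() methods take the GGMLRunnerContext*, raw ggml_* primitives receive the underlying ctx->ggml_ctx, and attention helpers additionally read ctx->backend and ctx->flash_attn_enabled. A small illustration with a hypothetical block (not from this commit):

// Hypothetical block showing the post-refactor convention: sub-blocks are
// handed the runner context, ggml primitives are handed ctx->ggml_ctx.
struct ggml_tensor* example_forward(GGMLRunnerContext* ctx,
                                    std::shared_ptr<Linear> proj,
                                    struct ggml_tensor* x) {
    x = proj->forward(ctx, x);                // context-aware block
    x = ggml_gelu_inplace(ctx->ggml_ctx, x);  // raw ggml op
    return x;
}
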
@ -1347,11 +1348,9 @@ namespace WAN {
WanCrossAttention(int64_t dim,
int64_t num_heads,
bool qk_norm = true,
float eps = 1e-6,
bool flash_attn = false)
: WanSelfAttention(dim, num_heads, qk_norm, eps, flash_attn) {}
virtual struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
float eps = 1e-6)
: WanSelfAttention(dim, num_heads, qk_norm, eps) {}
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* context,
int64_t context_img_len) = 0;
@ -1362,11 +1361,9 @@ namespace WAN {
WanT2VCrossAttention(int64_t dim,
int64_t num_heads,
bool qk_norm = true,
float eps = 1e-6,
bool flash_attn = false)
: WanCrossAttention(dim, num_heads, qk_norm, eps, flash_attn) {}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
float eps = 1e-6)
: WanCrossAttention(dim, num_heads, qk_norm, eps) {}
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* context,
int64_t context_img_len) override {
@ -1390,7 +1387,7 @@ namespace WAN {
k = norm_k->forward(ctx, k);
auto v = v_proj->forward(ctx, context); // [N, n_context, dim]
x = ggml_ext_attention_ext(ctx, backend, q, k, v, num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = o_proj->forward(ctx, x); // [N, n_token, dim]
return x;
@ -1402,9 +1399,8 @@ namespace WAN {
WanI2VCrossAttention(int64_t dim,
int64_t num_heads,
bool qk_norm = true,
float eps = 1e-6,
bool flash_attn = false)
: WanCrossAttention(dim, num_heads, qk_norm, eps, flash_attn) {
float eps = 1e-6)
: WanCrossAttention(dim, num_heads, qk_norm, eps) {
blocks["k_img"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
blocks["v_img"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
@ -1415,8 +1411,7 @@ namespace WAN {
}
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* context,
int64_t context_img_len) override {
@ -1441,11 +1436,11 @@ namespace WAN {
int64_t dim = x->ne[0];
int64_t context_txt_len = context->ne[1] - context_img_len;
context = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, context, 0, 2, 1, 3)); // [context_img_len + context_txt_len, N, dim]
auto context_img = ggml_view_3d(ctx, context, dim, N, context_img_len, context->nb[1], context->nb[2], 0);
auto context_txt = ggml_view_3d(ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_img_len * context->nb[2]);
context_img = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, context_img, 0, 2, 1, 3)); // [N, context_img_len, dim]
context_txt = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, context_txt, 0, 2, 1, 3)); // [N, context_txt_len, dim]
context = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context, 0, 2, 1, 3)); // [context_img_len + context_txt_len, N, dim]
auto context_img = ggml_view_3d(ctx->ggml_ctx, context, dim, N, context_img_len, context->nb[1], context->nb[2], 0);
auto context_txt = ggml_view_3d(ctx->ggml_ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_img_len * context->nb[2]);
context_img = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context_img, 0, 2, 1, 3)); // [N, context_img_len, dim]
context_txt = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context_txt, 0, 2, 1, 3)); // [N, context_txt_len, dim]
auto q = q_proj->forward(ctx, x);
q = norm_q->forward(ctx, q);
@ -1457,10 +1452,10 @@ namespace WAN {
k_img = norm_k_img->forward(ctx, k_img);
auto v_img = v_img_proj->forward(ctx, context_img); // [N, context_img_len, dim]
auto img_x = ggml_ext_attention_ext(ctx, backend, q, k_img, v_img, num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim]
x = ggml_ext_attention_ext(ctx, backend, q, k, v, num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim]
auto img_x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k_img, v_img, num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = ggml_add(ctx, x, img_x);
x = ggml_add(ctx->ggml_ctx, x, img_x);
x = o_proj->forward(ctx, x); // [N, n_token, dim]
return x;
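
Because ggml_ext_attention_ext now consumes ctx->flash_attn_enabled directly, Flash Attention is decided once per graph build instead of being frozen into each attention block at construction time. A sketch of how a runner could fill in the context it hands out; only the field names come from this diff, the helper itself is hypothetical and the real get_context() may differ:

// Hypothetical helper; only ggml_ctx / backend / flash_attn_enabled are taken
// from the usages in this diff.
GGMLRunnerContext make_runner_context(struct ggml_context* compute_ctx,
                                      ggml_backend_t runtime_backend,
                                      bool use_flash_attn) {
    GGMLRunnerContext rctx;
    rctx.ggml_ctx           = compute_ctx;
    rctx.backend            = runtime_backend;
    rctx.flash_attn_enabled = use_flash_attn;  // read by ggml_ext_attention_ext above
    return rctx;
}
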
@ -1511,20 +1506,19 @@ namespace WAN {
int64_t num_heads,
bool qk_norm = true,
bool cross_attn_norm = false,
float eps = 1e-6,
bool flash_attn = false)
float eps = 1e-6)
: dim(dim) {
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
blocks["self_attn"] = std::shared_ptr<GGMLBlock>(new WanSelfAttention(dim, num_heads, qk_norm, eps, flash_attn));
blocks["self_attn"] = std::shared_ptr<GGMLBlock>(new WanSelfAttention(dim, num_heads, qk_norm, eps));
if (cross_attn_norm) {
blocks["norm3"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, true));
} else {
blocks["norm3"] = std::shared_ptr<GGMLBlock>(new Identity());
}
if (t2v_cross_attn) {
blocks["cross_attn"] = std::shared_ptr<GGMLBlock>(new WanT2VCrossAttention(dim, num_heads, qk_norm, eps, flash_attn));
blocks["cross_attn"] = std::shared_ptr<GGMLBlock>(new WanT2VCrossAttention(dim, num_heads, qk_norm, eps));
} else {
blocks["cross_attn"] = std::shared_ptr<GGMLBlock>(new WanI2VCrossAttention(dim, num_heads, qk_norm, eps, flash_attn));
blocks["cross_attn"] = std::shared_ptr<GGMLBlock>(new WanI2VCrossAttention(dim, num_heads, qk_norm, eps));
}
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
@ -1534,8 +1528,7 @@ namespace WAN {
blocks["ffn.2"] = std::shared_ptr<GGMLBlock>(new Linear(ffn_dim, dim));
}
virtual struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* e,
struct ggml_tensor* pe,
@ -1547,8 +1540,8 @@ namespace WAN {
// return [N, n_token, dim]
auto modulation = params["modulation"];
e = ggml_add(ctx, e, modulation); // [N, 6, dim] or [N, T, 6, dim]
auto es = ggml_ext_chunk(ctx, e, 6, 1); // ([N, 1, dim], ...) or [N, T, 1, dim]
e = ggml_add(ctx->ggml_ctx, e, modulation); // [N, 6, dim] or [N, T, 6, dim]
auto es = ggml_ext_chunk(ctx->ggml_ctx, e, 6, 1); // ([N, 1, dim], ...) or [N, T, 1, dim]
auto norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
auto self_attn = std::dynamic_pointer_cast<WanSelfAttention>(blocks["self_attn"]);
@ -1560,27 +1553,27 @@ namespace WAN {
// self-attention
auto y = norm1->forward(ctx, x);
y = ggml_add(ctx, y, modulate_mul(ctx, y, es[1]));
y = modulate_add(ctx, y, es[0]);
y = self_attn->forward(ctx, backend, y, pe);
y = ggml_add(ctx->ggml_ctx, y, modulate_mul(ctx->ggml_ctx, y, es[1]));
y = modulate_add(ctx->ggml_ctx, y, es[0]);
y = self_attn->forward(ctx, y, pe);
x = ggml_add(ctx, x, modulate_mul(ctx, y, es[2]));
x = ggml_add(ctx->ggml_ctx, x, modulate_mul(ctx->ggml_ctx, y, es[2]));
// cross-attention
x = ggml_add(ctx,
x = ggml_add(ctx->ggml_ctx,
x,
cross_attn->forward(ctx, backend, norm3->forward(ctx, x), context, context_img_len));
cross_attn->forward(ctx, norm3->forward(ctx, x), context, context_img_len));
// ffn
y = norm2->forward(ctx, x);
y = ggml_add(ctx, y, modulate_mul(ctx, y, es[4]));
y = modulate_add(ctx, y, es[3]);
y = ggml_add(ctx->ggml_ctx, y, modulate_mul(ctx->ggml_ctx, y, es[4]));
y = modulate_add(ctx->ggml_ctx, y, es[3]);
y = ffn_0->forward(ctx, y);
y = ggml_gelu_inplace(ctx, y);
y = ggml_gelu_inplace(ctx->ggml_ctx, y);
y = ffn_2->forward(ctx, y);
x = ggml_add(ctx, x, modulate_mul(ctx, y, es[5]));
x = ggml_add(ctx->ggml_ctx, x, modulate_mul(ctx->ggml_ctx, y, es[5]));
return x;
}
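
For readability, the six adaLN-style modulation chunks consumed above split as follows (directly readable from the es[...] usages in WanAttentionBlock::forward):

// es = ggml_ext_chunk(ctx->ggml_ctx, e, 6, 1):
//   es[0] = shift, es[1] = scale, es[2] = gate   -> self-attention branch
//   es[3] = shift, es[4] = scale, es[5] = gate   -> feed-forward branch
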
@ -1602,17 +1595,15 @@ namespace WAN {
bool qk_norm = true,
bool cross_attn_norm = false,
float eps = 1e-6,
int block_id = 0,
bool flash_attn = false)
: WanAttentionBlock(t2v_cross_attn, dim, ffn_dim, num_heads, qk_norm, cross_attn_norm, eps, flash_attn), block_id(block_id) {
int block_id = 0)
: WanAttentionBlock(t2v_cross_attn, dim, ffn_dim, num_heads, qk_norm, cross_attn_norm, eps), block_id(block_id) {
if (block_id == 0) {
blocks["before_proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
}
blocks["after_proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
}
std::pair<ggml_tensor*, ggml_tensor*> forward(struct ggml_context* ctx,
ggml_backend_t backend,
std::pair<ggml_tensor*, ggml_tensor*> forward(GGMLRunnerContext* ctx,
struct ggml_tensor* c,
struct ggml_tensor* x,
struct ggml_tensor* e,
@ -1627,12 +1618,12 @@ namespace WAN {
auto before_proj = std::dynamic_pointer_cast<Linear>(blocks["before_proj"]);
c = before_proj->forward(ctx, c);
c = ggml_add(ctx, c, x);
c = ggml_add(ctx->ggml_ctx, c, x);
}
auto after_proj = std::dynamic_pointer_cast<Linear>(blocks["after_proj"]);
c = WanAttentionBlock::forward(ctx, backend, c, e, pe, context, context_img_len);
c = WanAttentionBlock::forward(ctx, c, e, pe, context, context_img_len);
auto c_skip = after_proj->forward(ctx, c);
return {c_skip, c};
@ -1660,7 +1651,7 @@ namespace WAN {
blocks["head"] = std::shared_ptr<GGMLBlock>(new Linear(dim, out_dim));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* e) {
// x: [N, n_token, dim]
@ -1668,18 +1659,18 @@ namespace WAN {
// return [N, n_token, out_dim]
auto modulation = params["modulation"];
e = ggml_reshape_4d(ctx, e, e->ne[0], 1, e->ne[1], e->ne[2]); // [N, 1, dim] or [N, T, 1, dim]
e = ggml_repeat_4d(ctx, e, e->ne[0], 2, e->ne[2], e->ne[3]); // [N, 2, dim] or [N, T, 2, dim]
e = ggml_reshape_4d(ctx->ggml_ctx, e, e->ne[0], 1, e->ne[1], e->ne[2]); // [N, 1, dim] or [N, T, 1, dim]
e = ggml_repeat_4d(ctx->ggml_ctx, e, e->ne[0], 2, e->ne[2], e->ne[3]); // [N, 2, dim] or [N, T, 2, dim]
e = ggml_add(ctx, e, modulation); // [N, 2, dim] or [N, T, 2, dim]
auto es = ggml_ext_chunk(ctx, e, 2, 1); // ([N, 1, dim], ...) or ([N, T, 1, dim], ...)
e = ggml_add(ctx->ggml_ctx, e, modulation); // [N, 2, dim] or [N, T, 2, dim]
auto es = ggml_ext_chunk(ctx->ggml_ctx, e, 2, 1); // ([N, 1, dim], ...) or ([N, T, 1, dim], ...)
auto norm = std::dynamic_pointer_cast<LayerNorm>(blocks["norm"]);
auto head = std::dynamic_pointer_cast<Linear>(blocks["head"]);
x = norm->forward(ctx, x);
x = ggml_add(ctx, x, modulate_mul(ctx, x, es[1]));
x = modulate_add(ctx, x, es[0]);
x = ggml_add(ctx->ggml_ctx, x, modulate_mul(ctx->ggml_ctx, x, es[1]));
x = modulate_add(ctx->ggml_ctx, x, es[0]);
x = head->forward(ctx, x);
return x;
}
@ -1708,15 +1699,15 @@ namespace WAN {
blocks["proj.4"] = std::shared_ptr<GGMLBlock>(new LayerNorm(out_dim));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* image_embeds) {
if (flf_pos_embed_token_number > 0) {
auto emb_pos = params["emb_pos"];
auto a = ggml_ext_slice(ctx, image_embeds, 1, 0, emb_pos->ne[1]);
auto b = ggml_ext_slice(ctx, emb_pos, 1, 0, image_embeds->ne[1]);
auto a = ggml_ext_slice(ctx->ggml_ctx, image_embeds, 1, 0, emb_pos->ne[1]);
auto b = ggml_ext_slice(ctx->ggml_ctx, emb_pos, 1, 0, image_embeds->ne[1]);
image_embeds = ggml_add(ctx, a, b);
image_embeds = ggml_add(ctx->ggml_ctx, a, b);
}
auto proj_0 = std::dynamic_pointer_cast<LayerNorm>(blocks["proj.0"]);
@ -1726,7 +1717,7 @@ namespace WAN {
auto x = proj_0->forward(ctx, image_embeds);
x = proj_1->forward(ctx, x);
x = ggml_gelu_inplace(ctx, x);
x = ggml_gelu_inplace(ctx->ggml_ctx, x);
x = proj_3->forward(ctx, x);
x = proj_4->forward(ctx, x);
@ -1757,7 +1748,6 @@ namespace WAN {
// wan2.1 1.3B: 1536/12, wan2.1/2.2 14B: 5120/40, wan2.2 5B: 3072/24
std::vector<int> axes_dim = {44, 42, 42};
int64_t axes_dim_sum = 128;
bool flash_attn = false;
};
class Wan : public GGMLBlock {
@ -1792,8 +1782,7 @@ namespace WAN {
params.num_heads,
params.qk_norm,
params.cross_attn_norm,
params.eps,
params.flash_attn));
params.eps));
blocks["blocks." + std::to_string(i)] = block;
}
@ -1815,8 +1804,7 @@ namespace WAN {
params.qk_norm,
params.cross_attn_norm,
params.eps,
i,
params.flash_attn));
i));
blocks["vace_blocks." + std::to_string(i)] = block;
}
@ -1872,8 +1860,7 @@ namespace WAN {
return x;
}
struct ggml_tensor* forward_orig(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward_orig(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* timestep,
struct ggml_tensor* context,
@ -1903,22 +1890,22 @@ namespace WAN {
// patch_embedding
x = patch_embedding->forward(ctx, x); // [N*dim, t_len, h_len, w_len]
x = ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1] * x->ne[2], x->ne[3] / N, N); // [N, dim, t_len*h_len*w_len]
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 0, 2, 3)); // [N, t_len*h_len*w_len, dim]
x = ggml_reshape_3d(ctx->ggml_ctx, x, x->ne[0] * x->ne[1] * x->ne[2], x->ne[3] / N, N); // [N, dim, t_len*h_len*w_len]
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [N, t_len*h_len*w_len, dim]
// time_embedding
auto e = ggml_ext_timestep_embedding(ctx, timestep, params.freq_dim);
auto e = ggml_ext_timestep_embedding(ctx->ggml_ctx, timestep, params.freq_dim);
e = time_embedding_0->forward(ctx, e);
e = ggml_silu_inplace(ctx, e);
e = ggml_silu_inplace(ctx->ggml_ctx, e);
e = time_embedding_2->forward(ctx, e); // [N, dim] or [N, T, dim]
// time_projection
auto e0 = ggml_silu(ctx, e);
auto e0 = ggml_silu(ctx->ggml_ctx, e);
e0 = time_projection_1->forward(ctx, e0);
e0 = ggml_reshape_4d(ctx, e0, e0->ne[0] / 6, 6, e0->ne[1], e0->ne[2]); // [N, 6, dim] or [N, T, 6, dim]
e0 = ggml_reshape_4d(ctx->ggml_ctx, e0, e0->ne[0] / 6, 6, e0->ne[1], e0->ne[2]); // [N, 6, dim] or [N, T, 6, dim]
context = text_embedding_0->forward(ctx, context);
context = ggml_gelu(ctx, context);
context = ggml_gelu(ctx->ggml_ctx, context);
context = text_embedding_2->forward(ctx, context); // [N, context_txt_len, dim]
int64_t context_img_len = 0;
@ -1926,7 +1913,7 @@ namespace WAN {
if (params.model_type == "i2v") {
auto img_emb = std::dynamic_pointer_cast<MLPProj>(blocks["img_emb"]);
auto context_img = img_emb->forward(ctx, clip_fea); // [N, context_img_len, dim]
context = ggml_concat(ctx, context_img, context, 1); // [N, context_img_len + context_txt_len, dim]
context = ggml_concat(ctx->ggml_ctx, context_img, context, 1); // [N, context_img_len + context_txt_len, dim]
}
context_img_len = clip_fea->ne[1]; // 257
}
@ -1937,8 +1924,8 @@ namespace WAN {
auto vace_patch_embedding = std::dynamic_pointer_cast<Conv3d>(blocks["vace_patch_embedding"]);
c = vace_patch_embedding->forward(ctx, vace_context); // [N*dim, t_len, h_len, w_len]
c = ggml_reshape_3d(ctx, c, c->ne[0] * c->ne[1] * c->ne[2], c->ne[3] / N, N); // [N, dim, t_len*h_len*w_len]
c = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, c, 1, 0, 2, 3)); // [N, t_len*h_len*w_len, dim]
c = ggml_reshape_3d(ctx->ggml_ctx, c, c->ne[0] * c->ne[1] * c->ne[2], c->ne[3] / N, N); // [N, dim, t_len*h_len*w_len]
c = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, c, 1, 0, 2, 3)); // [N, t_len*h_len*w_len, dim]
}
auto x_orig = x;
@ -1946,7 +1933,7 @@ namespace WAN {
for (int i = 0; i < params.num_layers; i++) {
auto block = std::dynamic_pointer_cast<WanAttentionBlock>(blocks["blocks." + std::to_string(i)]);
x = block->forward(ctx, backend, x, e0, pe, context, context_img_len);
x = block->forward(ctx, x, e0, pe, context, context_img_len);
auto iter = params.vace_layers_mapping.find(i);
if (iter != params.vace_layers_mapping.end()) {
@ -1954,11 +1941,11 @@ namespace WAN {
auto vace_block = std::dynamic_pointer_cast<VaceWanAttentionBlock>(blocks["vace_blocks." + std::to_string(n)]);
auto result = vace_block->forward(ctx, backend, c, x_orig, e0, pe, context, context_img_len);
auto result = vace_block->forward(ctx, c, x_orig, e0, pe, context, context_img_len);
auto c_skip = result.first;
c = result.second;
c_skip = ggml_scale(ctx, c_skip, vace_strength);
x = ggml_add(ctx, x, c_skip);
c_skip = ggml_scale(ctx->ggml_ctx, c_skip, vace_strength);
x = ggml_add(ctx->ggml_ctx, x, c_skip);
}
}
@ -1967,8 +1954,7 @@ namespace WAN {
return x;
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* timestep,
struct ggml_tensor* context,
@ -1993,27 +1979,27 @@ namespace WAN {
int64_t T = x->ne[2];
int64_t C = x->ne[3];
x = pad_to_patch_size(ctx, x);
x = pad_to_patch_size(ctx->ggml_ctx, x);
int64_t t_len = ((T + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size));
int64_t h_len = ((H + (std::get<1>(params.patch_size) / 2)) / std::get<1>(params.patch_size));
int64_t w_len = ((W + (std::get<2>(params.patch_size) / 2)) / std::get<2>(params.patch_size));
if (time_dim_concat != nullptr) {
time_dim_concat = pad_to_patch_size(ctx, time_dim_concat);
x = ggml_concat(ctx, x, time_dim_concat, 2); // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w]
time_dim_concat = pad_to_patch_size(ctx->ggml_ctx, time_dim_concat);
x = ggml_concat(ctx->ggml_ctx, x, time_dim_concat, 2); // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w]
t_len = ((x->ne[2] + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size));
}
auto out = forward_orig(ctx, backend, x, timestep, context, pe, clip_fea, vace_context, vace_strength, N); // [N, t_len*h_len*w_len, pt*ph*pw*C]
auto out = forward_orig(ctx, x, timestep, context, pe, clip_fea, vace_context, vace_strength, N); // [N, t_len*h_len*w_len, pt*ph*pw*C]
out = unpatchify(ctx, out, t_len, h_len, w_len); // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w]
out = unpatchify(ctx->ggml_ctx, out, t_len, h_len, w_len); // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w]
// slice
out = ggml_ext_slice(ctx, out, 2, 0, T); // [N*C, T, H + pad_h, W + pad_w]
out = ggml_ext_slice(ctx, out, 1, 0, H); // [N*C, T, H, W + pad_w]
out = ggml_ext_slice(ctx, out, 0, 0, W); // [N*C, T, H, W]
out = ggml_ext_slice(ctx->ggml_ctx, out, 2, 0, T); // [N*C, T, H + pad_h, W + pad_w]
out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H); // [N*C, T, H, W + pad_w]
out = ggml_ext_slice(ctx->ggml_ctx, out, 0, 0, W); // [N*C, T, H, W]
return out;
}
@ -2031,10 +2017,8 @@ namespace WAN {
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "",
SDVersion version = VERSION_WAN2,
bool flash_attn = false)
SDVersion version = VERSION_WAN2)
: GGMLRunner(backend, offload_params_to_cpu) {
wan_params.flash_attn = flash_attn;
wan_params.num_layers = 0;
for (auto pair : tensor_types) {
std::string tensor_name = pair.first;
@ -2183,8 +2167,9 @@ namespace WAN {
x = ggml_concat(compute_ctx, x, c_concat, 3);
}
struct ggml_tensor* out = wan.forward(compute_ctx,
runtime_backend,
auto runner_ctx = get_context();
struct ggml_tensor* out = wan.forward(&runner_ctx,
x,
timesteps,
context,
@ -2281,8 +2266,7 @@ namespace WAN {
false,
tensor_types,
"model.diffusion_model",
VERSION_WAN2_2_TI2V,
true);
VERSION_WAN2_2_TI2V);
wan->alloc_params_buffer();
std::map<std::string, ggml_tensor*> tensors;