mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-03-24 10:18:51 +00:00
Compare commits
No commits in common. "43e829f21966abb96b08c712bccee872dc820914" and "fa61ea744d1a87fa26a63f8a86e45587bc9534d6" have entirely different histories.
43e829f219
...
fa61ea744d
47
clip.hpp
47
clip.hpp
@ -479,9 +479,9 @@ public:
|
|||||||
|
|
||||||
x = fc1->forward(ctx, x);
|
x = fc1->forward(ctx, x);
|
||||||
if (use_gelu) {
|
if (use_gelu) {
|
||||||
x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
|
x = ggml_gelu_inplace(ctx->ggml_ctx, x);
|
||||||
} else {
|
} else {
|
||||||
x = ggml_ext_gelu_quick(ctx->ggml_ctx, x, true);
|
x = ggml_gelu_quick_inplace(ctx->ggml_ctx, x);
|
||||||
}
|
}
|
||||||
x = fc2->forward(ctx, x);
|
x = fc2->forward(ctx, x);
|
||||||
return x;
|
return x;
|
||||||
@ -510,7 +510,7 @@ public:
|
|||||||
blocks["mlp"] = std::shared_ptr<GGMLBlock>(new CLIPMLP(d_model, intermediate_size));
|
blocks["mlp"] = std::shared_ptr<GGMLBlock>(new CLIPMLP(d_model, intermediate_size));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* mask = nullptr) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, bool mask = true) {
|
||||||
// x: [N, n_token, d_model]
|
// x: [N, n_token, d_model]
|
||||||
auto self_attn = std::dynamic_pointer_cast<MultiheadAttention>(blocks["self_attn"]);
|
auto self_attn = std::dynamic_pointer_cast<MultiheadAttention>(blocks["self_attn"]);
|
||||||
auto layer_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm1"]);
|
auto layer_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm1"]);
|
||||||
@ -542,8 +542,8 @@ public:
|
|||||||
|
|
||||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* mask = nullptr,
|
int clip_skip = -1,
|
||||||
int clip_skip = -1) {
|
bool mask = true) {
|
||||||
// x: [N, n_token, d_model]
|
// x: [N, n_token, d_model]
|
||||||
int layer_idx = n_layer - 1;
|
int layer_idx = n_layer - 1;
|
||||||
// LOG_DEBUG("clip_skip %d", clip_skip);
|
// LOG_DEBUG("clip_skip %d", clip_skip);
|
||||||
@ -741,17 +741,16 @@ public:
|
|||||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* input_ids,
|
struct ggml_tensor* input_ids,
|
||||||
struct ggml_tensor* tkn_embeddings,
|
struct ggml_tensor* tkn_embeddings,
|
||||||
struct ggml_tensor* mask = nullptr,
|
size_t max_token_idx = 0,
|
||||||
size_t max_token_idx = 0,
|
bool return_pooled = false,
|
||||||
bool return_pooled = false,
|
int clip_skip = -1) {
|
||||||
int clip_skip = -1) {
|
|
||||||
// input_ids: [N, n_token]
|
// input_ids: [N, n_token]
|
||||||
auto embeddings = std::dynamic_pointer_cast<CLIPEmbeddings>(blocks["embeddings"]);
|
auto embeddings = std::dynamic_pointer_cast<CLIPEmbeddings>(blocks["embeddings"]);
|
||||||
auto encoder = std::dynamic_pointer_cast<CLIPEncoder>(blocks["encoder"]);
|
auto encoder = std::dynamic_pointer_cast<CLIPEncoder>(blocks["encoder"]);
|
||||||
auto final_layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["final_layer_norm"]);
|
auto final_layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["final_layer_norm"]);
|
||||||
|
|
||||||
auto x = embeddings->forward(ctx, input_ids, tkn_embeddings); // [N, n_token, hidden_size]
|
auto x = embeddings->forward(ctx, input_ids, tkn_embeddings); // [N, n_token, hidden_size]
|
||||||
x = encoder->forward(ctx, x, mask, return_pooled ? -1 : clip_skip);
|
x = encoder->forward(ctx, x, return_pooled ? -1 : clip_skip, true);
|
||||||
if (return_pooled || with_final_ln) {
|
if (return_pooled || with_final_ln) {
|
||||||
x = final_layer_norm->forward(ctx, x);
|
x = final_layer_norm->forward(ctx, x);
|
||||||
}
|
}
|
||||||
@ -815,11 +814,10 @@ public:
|
|||||||
|
|
||||||
auto x = embeddings->forward(ctx, pixel_values); // [N, num_positions, embed_dim]
|
auto x = embeddings->forward(ctx, pixel_values); // [N, num_positions, embed_dim]
|
||||||
x = pre_layernorm->forward(ctx, x);
|
x = pre_layernorm->forward(ctx, x);
|
||||||
x = encoder->forward(ctx, x, nullptr, clip_skip);
|
x = encoder->forward(ctx, x, clip_skip, false);
|
||||||
|
// print_ggml_tensor(x, true, "ClipVisionModel x: ");
|
||||||
auto last_hidden_state = x;
|
auto last_hidden_state = x;
|
||||||
|
x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size]
|
||||||
x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size]
|
|
||||||
|
|
||||||
GGML_ASSERT(x->ne[3] == 1);
|
GGML_ASSERT(x->ne[3] == 1);
|
||||||
if (return_pooled) {
|
if (return_pooled) {
|
||||||
@ -907,8 +905,6 @@ public:
|
|||||||
struct CLIPTextModelRunner : public GGMLRunner {
|
struct CLIPTextModelRunner : public GGMLRunner {
|
||||||
CLIPTextModel model;
|
CLIPTextModel model;
|
||||||
|
|
||||||
std::vector<float> attention_mask_vec;
|
|
||||||
|
|
||||||
CLIPTextModelRunner(ggml_backend_t backend,
|
CLIPTextModelRunner(ggml_backend_t backend,
|
||||||
bool offload_params_to_cpu,
|
bool offload_params_to_cpu,
|
||||||
const String2TensorStorage& tensor_storage_map,
|
const String2TensorStorage& tensor_storage_map,
|
||||||
@ -942,7 +938,6 @@ struct CLIPTextModelRunner : public GGMLRunner {
|
|||||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* input_ids,
|
struct ggml_tensor* input_ids,
|
||||||
struct ggml_tensor* embeddings,
|
struct ggml_tensor* embeddings,
|
||||||
struct ggml_tensor* mask,
|
|
||||||
size_t max_token_idx = 0,
|
size_t max_token_idx = 0,
|
||||||
bool return_pooled = false,
|
bool return_pooled = false,
|
||||||
int clip_skip = -1) {
|
int clip_skip = -1) {
|
||||||
@ -953,7 +948,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
|
|||||||
input_ids = ggml_reshape_2d(ctx->ggml_ctx, input_ids, model.n_token, input_ids->ne[0] / model.n_token);
|
input_ids = ggml_reshape_2d(ctx->ggml_ctx, input_ids, model.n_token, input_ids->ne[0] / model.n_token);
|
||||||
}
|
}
|
||||||
|
|
||||||
return model.forward(ctx, input_ids, embeddings, mask, max_token_idx, return_pooled, clip_skip);
|
return model.forward(ctx, input_ids, embeddings, max_token_idx, return_pooled, clip_skip);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids,
|
struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids,
|
||||||
@ -980,23 +975,9 @@ struct CLIPTextModelRunner : public GGMLRunner {
|
|||||||
embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 1);
|
embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
int n_tokens = static_cast<int>(input_ids->ne[0]);
|
|
||||||
attention_mask_vec.resize(n_tokens * n_tokens);
|
|
||||||
for (int i0 = 0; i0 < n_tokens; i0++) {
|
|
||||||
for (int i1 = 0; i1 < n_tokens; i1++) {
|
|
||||||
float value = 0.f;
|
|
||||||
if (i0 > i1) {
|
|
||||||
value = -INFINITY;
|
|
||||||
}
|
|
||||||
attention_mask_vec[i1 * n_tokens + i0] = value;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
auto attention_mask = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, n_tokens, n_tokens);
|
|
||||||
set_backend_tensor_data(attention_mask, attention_mask_vec.data());
|
|
||||||
|
|
||||||
auto runner_ctx = get_context();
|
auto runner_ctx = get_context();
|
||||||
|
|
||||||
struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, embeddings, attention_mask, max_token_idx, return_pooled, clip_skip);
|
struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, embeddings, max_token_idx, return_pooled, clip_skip);
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, hidden_states);
|
ggml_build_forward_expand(gf, hidden_states);
|
||||||
|
|
||||||
|
|||||||
10
common.hpp
10
common.hpp
@ -200,7 +200,7 @@ public:
|
|||||||
|
|
||||||
gate = ggml_cont(ctx->ggml_ctx, gate);
|
gate = ggml_cont(ctx->ggml_ctx, gate);
|
||||||
|
|
||||||
gate = ggml_ext_gelu(ctx->ggml_ctx, gate, true);
|
gate = ggml_gelu_inplace(ctx->ggml_ctx, gate);
|
||||||
|
|
||||||
x = ggml_mul(ctx->ggml_ctx, x, gate); // [ne3, ne2, ne1, dim_out]
|
x = ggml_mul(ctx->ggml_ctx, x, gate); // [ne3, ne2, ne1, dim_out]
|
||||||
|
|
||||||
@ -220,7 +220,7 @@ public:
|
|||||||
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
|
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
|
||||||
|
|
||||||
x = proj->forward(ctx, x);
|
x = proj->forward(ctx, x);
|
||||||
x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
|
x = ggml_gelu_inplace(ctx->ggml_ctx, x);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -317,7 +317,7 @@ public:
|
|||||||
auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim]
|
auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim]
|
||||||
auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim]
|
auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim]
|
||||||
|
|
||||||
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, inner_dim]
|
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, inner_dim]
|
||||||
|
|
||||||
x = to_out_0->forward(ctx, x); // [N, n_token, query_dim]
|
x = to_out_0->forward(ctx, x); // [N, n_token, query_dim]
|
||||||
return x;
|
return x;
|
||||||
@ -536,8 +536,8 @@ public:
|
|||||||
// image_only_indicator is always tensor([0.])
|
// image_only_indicator is always tensor([0.])
|
||||||
float alpha = get_alpha();
|
float alpha = get_alpha();
|
||||||
auto x = ggml_add(ctx->ggml_ctx,
|
auto x = ggml_add(ctx->ggml_ctx,
|
||||||
ggml_ext_scale(ctx->ggml_ctx, x_spatial, alpha),
|
ggml_scale(ctx->ggml_ctx, x_spatial, alpha),
|
||||||
ggml_ext_scale(ctx->ggml_ctx, x_temporal, 1.0f - alpha));
|
ggml_scale(ctx->ggml_ctx, x_temporal, 1.0f - alpha));
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@ -51,7 +51,7 @@ public:
|
|||||||
x_cat = ggml_concat(ctx->ggml_ctx, x_cat, x4, 2);
|
x_cat = ggml_concat(ctx->ggml_ctx, x_cat, x4, 2);
|
||||||
auto x5 = conv5->forward(ctx, x_cat);
|
auto x5 = conv5->forward(ctx, x_cat);
|
||||||
|
|
||||||
x5 = ggml_add(ctx->ggml_ctx, ggml_ext_scale(ctx->ggml_ctx, x5, 0.2f), x);
|
x5 = ggml_add(ctx->ggml_ctx, ggml_scale(ctx->ggml_ctx, x5, 0.2f), x);
|
||||||
return x5;
|
return x5;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -76,7 +76,7 @@ public:
|
|||||||
out = rdb2->forward(ctx, out);
|
out = rdb2->forward(ctx, out);
|
||||||
out = rdb3->forward(ctx, out);
|
out = rdb3->forward(ctx, out);
|
||||||
|
|
||||||
out = ggml_add(ctx->ggml_ctx, ggml_ext_scale(ctx->ggml_ctx, out, 0.2f), x);
|
out = ggml_add(ctx->ggml_ctx, ggml_scale(ctx->ggml_ctx, out, 0.2f), x);
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@ -603,11 +603,11 @@ int main(int argc, const char* argv[]) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (gen_params.mask_image_path.size() > 0) {
|
if (gen_params.mask_image_path.size() > 0) {
|
||||||
if (!load_sd_image_from_file(&mask_image,
|
if (load_sd_image_from_file(&mask_image,
|
||||||
gen_params.mask_image_path.c_str(),
|
gen_params.mask_image_path.c_str(),
|
||||||
gen_params.get_resolved_width(),
|
gen_params.get_resolved_width(),
|
||||||
gen_params.get_resolved_height(),
|
gen_params.get_resolved_height(),
|
||||||
1)) {
|
1)) {
|
||||||
LOG_ERROR("load image from '%s' failed", gen_params.mask_image_path.c_str());
|
LOG_ERROR("load image from '%s' failed", gen_params.mask_image_path.c_str());
|
||||||
release_all_resources();
|
release_all_resources();
|
||||||
return 1;
|
return 1;
|
||||||
@ -625,10 +625,10 @@ int main(int argc, const char* argv[]) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (gen_params.control_image_path.size() > 0) {
|
if (gen_params.control_image_path.size() > 0) {
|
||||||
if (!load_sd_image_from_file(&control_image,
|
if (load_sd_image_from_file(&control_image,
|
||||||
gen_params.control_image_path.c_str(),
|
gen_params.control_image_path.c_str(),
|
||||||
gen_params.get_resolved_width(),
|
gen_params.get_resolved_width(),
|
||||||
gen_params.get_resolved_height())) {
|
gen_params.get_resolved_height())) {
|
||||||
LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str());
|
LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str());
|
||||||
release_all_resources();
|
release_all_resources();
|
||||||
return 1;
|
return 1;
|
||||||
|
|||||||
93
flux.hpp
93
flux.hpp
@ -103,7 +103,7 @@ namespace Flux {
|
|||||||
auto norm = std::dynamic_pointer_cast<QKNorm>(blocks["norm"]);
|
auto norm = std::dynamic_pointer_cast<QKNorm>(blocks["norm"]);
|
||||||
|
|
||||||
auto qkv = qkv_proj->forward(ctx, x);
|
auto qkv = qkv_proj->forward(ctx, x);
|
||||||
auto qkv_vec = ggml_ext_chunk(ctx->ggml_ctx, qkv, 3, 0, true);
|
auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv);
|
||||||
int64_t head_dim = qkv_vec[0]->ne[0] / num_heads;
|
int64_t head_dim = qkv_vec[0]->ne[0] / num_heads;
|
||||||
auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]);
|
auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]);
|
||||||
auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]);
|
auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]);
|
||||||
@ -153,7 +153,7 @@ namespace Flux {
|
|||||||
if (use_mlp_silu_act) {
|
if (use_mlp_silu_act) {
|
||||||
x = ggml_ext_silu_act(ctx->ggml_ctx, x);
|
x = ggml_ext_silu_act(ctx->ggml_ctx, x);
|
||||||
} else {
|
} else {
|
||||||
x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
|
x = ggml_gelu_inplace(ctx->ggml_ctx, x);
|
||||||
}
|
}
|
||||||
x = mlp_2->forward(ctx, x);
|
x = mlp_2->forward(ctx, x);
|
||||||
return x;
|
return x;
|
||||||
@ -376,23 +376,26 @@ namespace Flux {
|
|||||||
auto k = ggml_concat(ctx->ggml_ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
|
auto k = ggml_concat(ctx->ggml_ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
|
||||||
auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
|
auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
|
||||||
|
|
||||||
auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_txt_token + n_img_token, n_head*d_head]
|
auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_txt_token + n_img_token, n_head*d_head]
|
||||||
|
attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
|
||||||
auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx,
|
auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx,
|
||||||
attn,
|
attn,
|
||||||
attn->ne[0],
|
attn->ne[0],
|
||||||
|
attn->ne[1],
|
||||||
txt->ne[1],
|
txt->ne[1],
|
||||||
attn->ne[2],
|
|
||||||
attn->nb[1],
|
attn->nb[1],
|
||||||
attn->nb[2],
|
attn->nb[2],
|
||||||
0); // [N, n_txt_token, hidden_size]
|
0); // [n_txt_token, N, hidden_size]
|
||||||
|
txt_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size]
|
||||||
auto img_attn_out = ggml_view_3d(ctx->ggml_ctx,
|
auto img_attn_out = ggml_view_3d(ctx->ggml_ctx,
|
||||||
attn,
|
attn,
|
||||||
attn->ne[0],
|
attn->ne[0],
|
||||||
|
attn->ne[1],
|
||||||
img->ne[1],
|
img->ne[1],
|
||||||
attn->ne[2],
|
|
||||||
attn->nb[1],
|
attn->nb[1],
|
||||||
attn->nb[2],
|
attn->nb[2],
|
||||||
txt->ne[1] * attn->nb[1]); // [N, n_img_token, hidden_size]
|
attn->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size]
|
||||||
|
img_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
|
||||||
|
|
||||||
// calculate the img bloks
|
// calculate the img bloks
|
||||||
img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate));
|
img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate));
|
||||||
@ -489,29 +492,43 @@ namespace Flux {
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto x_mod = Flux::modulate(ctx->ggml_ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale);
|
auto x_mod = Flux::modulate(ctx->ggml_ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale);
|
||||||
auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim*mlp_mult_factor]
|
auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim]
|
||||||
|
qkv_mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv_mlp, 2, 0, 1, 3)); // [hidden_size * 3 + mlp_hidden_dim, N, n_token]
|
||||||
|
|
||||||
auto q = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], 0);
|
auto qkv = ggml_view_3d(ctx->ggml_ctx,
|
||||||
auto k = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * qkv_mlp->nb[0]);
|
qkv_mlp,
|
||||||
auto v = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * 2 * qkv_mlp->nb[0]);
|
qkv_mlp->ne[0],
|
||||||
|
qkv_mlp->ne[1],
|
||||||
|
hidden_size * 3,
|
||||||
|
qkv_mlp->nb[1],
|
||||||
|
qkv_mlp->nb[2],
|
||||||
|
0); // [hidden_size * 3 , N, n_token]
|
||||||
|
qkv = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv, 1, 2, 0, 3)); // [N, n_token, hidden_size * 3]
|
||||||
|
auto mlp = ggml_view_3d(ctx->ggml_ctx,
|
||||||
|
qkv_mlp,
|
||||||
|
qkv_mlp->ne[0],
|
||||||
|
qkv_mlp->ne[1],
|
||||||
|
mlp_hidden_dim * mlp_mult_factor,
|
||||||
|
qkv_mlp->nb[1],
|
||||||
|
qkv_mlp->nb[2],
|
||||||
|
qkv_mlp->nb[2] * hidden_size * 3); // [mlp_hidden_dim*mlp_mult_factor , N, n_token]
|
||||||
|
mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, mlp, 1, 2, 0, 3)); // [N, n_token, mlp_hidden_dim*mlp_mult_factor]
|
||||||
|
|
||||||
|
auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv); // q,k,v: [N, n_token, hidden_size]
|
||||||
int64_t head_dim = hidden_size / num_heads;
|
int64_t head_dim = hidden_size / num_heads;
|
||||||
|
auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head]
|
||||||
|
auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head]
|
||||||
|
auto v = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head]
|
||||||
|
q = norm->query_norm(ctx, q);
|
||||||
|
k = norm->key_norm(ctx, k);
|
||||||
|
auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_token, hidden_size]
|
||||||
|
|
||||||
q = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, q), head_dim, num_heads, q->ne[1], q->ne[2]); // [N, n_token, n_head, d_head]
|
|
||||||
k = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, k), head_dim, num_heads, k->ne[1], k->ne[2]); // [N, n_token, n_head, d_head]
|
|
||||||
v = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, v), head_dim, num_heads, v->ne[1], v->ne[2]); // [N, n_token, n_head, d_head]
|
|
||||||
|
|
||||||
q = norm->query_norm(ctx, q);
|
|
||||||
k = norm->key_norm(ctx, k);
|
|
||||||
auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_token, hidden_size]
|
|
||||||
|
|
||||||
auto mlp = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, mlp_hidden_dim * mlp_mult_factor, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * 3 * qkv_mlp->nb[0]);
|
|
||||||
if (use_yak_mlp) {
|
if (use_yak_mlp) {
|
||||||
mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp, false);
|
mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp, false);
|
||||||
} else if (use_mlp_silu_act) {
|
} else if (use_mlp_silu_act) {
|
||||||
mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp);
|
mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp);
|
||||||
} else {
|
} else {
|
||||||
mlp = ggml_ext_gelu(ctx->ggml_ctx, mlp, true);
|
mlp = ggml_gelu_inplace(ctx->ggml_ctx, mlp);
|
||||||
}
|
}
|
||||||
auto attn_mlp = ggml_concat(ctx->ggml_ctx, attn, mlp, 0); // [N, n_token, hidden_size + mlp_hidden_dim]
|
auto attn_mlp = ggml_concat(ctx->ggml_ctx, attn, mlp, 0); // [N, n_token, hidden_size + mlp_hidden_dim]
|
||||||
auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size]
|
auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size]
|
||||||
@ -563,10 +580,13 @@ namespace Flux {
|
|||||||
} else {
|
} else {
|
||||||
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
|
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
|
||||||
|
|
||||||
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 2 * hidden_size]
|
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 2 * hidden_size]
|
||||||
auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, 2, 0);
|
m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size]
|
||||||
shift = m_vec[0]; // [N, hidden_size]
|
m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size]
|
||||||
scale = m_vec[1]; // [N, hidden_size]
|
|
||||||
|
int64_t offset = m->nb[1] * m->ne[1];
|
||||||
|
shift = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
|
||||||
|
scale = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
|
||||||
}
|
}
|
||||||
|
|
||||||
x = Flux::modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale);
|
x = Flux::modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale);
|
||||||
@ -1014,14 +1034,16 @@ namespace Flux {
|
|||||||
txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask, ss_mods);
|
txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask, ss_mods);
|
||||||
}
|
}
|
||||||
|
|
||||||
img = ggml_view_3d(ctx->ggml_ctx,
|
txt_img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_img, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
|
||||||
txt_img,
|
img = ggml_view_3d(ctx->ggml_ctx,
|
||||||
txt_img->ne[0],
|
txt_img,
|
||||||
img->ne[1],
|
txt_img->ne[0],
|
||||||
txt_img->ne[2],
|
txt_img->ne[1],
|
||||||
txt_img->nb[1],
|
img->ne[1],
|
||||||
txt_img->nb[2],
|
txt_img->nb[1],
|
||||||
txt->ne[1] * txt_img->nb[1]); // [N, n_img_token, hidden_size]
|
txt_img->nb[2],
|
||||||
|
txt_img->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size]
|
||||||
|
img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
|
||||||
|
|
||||||
if (final_layer) {
|
if (final_layer) {
|
||||||
img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels)
|
img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels)
|
||||||
@ -1174,8 +1196,9 @@ namespace Flux {
|
|||||||
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size]
|
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size]
|
||||||
|
|
||||||
if (out->ne[1] > img_tokens) {
|
if (out->ne[1] > img_tokens) {
|
||||||
out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], img_tokens, out->ne[2], out->nb[1], out->nb[2], 0);
|
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size]
|
||||||
out = ggml_cont(ctx->ggml_ctx, out);
|
out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
|
||||||
|
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size]
|
||||||
}
|
}
|
||||||
|
|
||||||
// rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
|
// rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
|
||||||
|
|||||||
@ -687,8 +687,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_slice(struct ggml_context* ctx,
|
|||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
int dim,
|
int dim,
|
||||||
int64_t start,
|
int64_t start,
|
||||||
int64_t end,
|
int64_t end) {
|
||||||
bool cont = true) {
|
|
||||||
GGML_ASSERT(dim >= 0 && dim < 4);
|
GGML_ASSERT(dim >= 0 && dim < 4);
|
||||||
if (x->ne[dim] == 1) {
|
if (x->ne[dim] == 1) {
|
||||||
return x;
|
return x;
|
||||||
@ -703,15 +702,27 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_slice(struct ggml_context* ctx,
|
|||||||
GGML_ASSERT(start >= 0 && start < x->ne[dim]);
|
GGML_ASSERT(start >= 0 && start < x->ne[dim]);
|
||||||
GGML_ASSERT(end > start && end <= x->ne[dim]);
|
GGML_ASSERT(end > start && end <= x->ne[dim]);
|
||||||
|
|
||||||
int64_t slice_size = end - start;
|
int perm[4] = {0, 1, 2, 3};
|
||||||
int64_t slice_ne[4] = {x->ne[0], x->ne[1], x->ne[2], x->ne[3]};
|
for (int i = dim; i < 3; ++i)
|
||||||
slice_ne[dim] = slice_size;
|
perm[i] = perm[i + 1];
|
||||||
|
perm[3] = dim;
|
||||||
|
|
||||||
x = ggml_view_4d(ctx, x,
|
int inv_perm[4];
|
||||||
slice_ne[0], slice_ne[1], slice_ne[2], slice_ne[3],
|
for (int i = 0; i < 4; ++i)
|
||||||
x->nb[1], x->nb[2], x->nb[3], start * x->nb[dim]);
|
inv_perm[perm[i]] = i;
|
||||||
|
|
||||||
if (cont) {
|
if (dim != 3) {
|
||||||
|
x = ggml_ext_torch_permute(ctx, x, perm[0], perm[1], perm[2], perm[3]);
|
||||||
|
x = ggml_cont(ctx, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
x = ggml_view_4d(
|
||||||
|
ctx, x,
|
||||||
|
x->ne[0], x->ne[1], x->ne[2], end - start,
|
||||||
|
x->nb[1], x->nb[2], x->nb[3], x->nb[3] * start);
|
||||||
|
|
||||||
|
if (dim != 3) {
|
||||||
|
x = ggml_ext_torch_permute(ctx, x, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]);
|
||||||
x = ggml_cont(ctx, x);
|
x = ggml_cont(ctx, x);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -949,49 +960,6 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_group_norm_32(struct ggml_context
|
|||||||
return ggml_group_norm(ctx, a, 32, eps);
|
return ggml_group_norm(ctx, a, 32, eps);
|
||||||
}
|
}
|
||||||
|
|
||||||
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_scale(struct ggml_context* ctx,
|
|
||||||
struct ggml_tensor* x,
|
|
||||||
float factor,
|
|
||||||
bool inplace = false) {
|
|
||||||
if (!ggml_is_contiguous(x)) {
|
|
||||||
x = ggml_cont(ctx, x);
|
|
||||||
}
|
|
||||||
if (inplace) {
|
|
||||||
x = ggml_scale_inplace(ctx, x, factor);
|
|
||||||
} else {
|
|
||||||
x = ggml_scale(ctx, x, factor);
|
|
||||||
}
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_gelu(struct ggml_context* ctx,
|
|
||||||
struct ggml_tensor* x,
|
|
||||||
bool inplace = false) {
|
|
||||||
if (!ggml_is_contiguous(x)) {
|
|
||||||
x = ggml_cont(ctx, x);
|
|
||||||
}
|
|
||||||
if (inplace) {
|
|
||||||
x = ggml_gelu_inplace(ctx, x);
|
|
||||||
} else {
|
|
||||||
x = ggml_gelu(ctx, x);
|
|
||||||
}
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_gelu_quick(struct ggml_context* ctx,
|
|
||||||
struct ggml_tensor* x,
|
|
||||||
bool inplace = false) {
|
|
||||||
if (!ggml_is_contiguous(x)) {
|
|
||||||
x = ggml_cont(ctx, x);
|
|
||||||
}
|
|
||||||
if (inplace) {
|
|
||||||
x = ggml_gelu_quick_inplace(ctx, x);
|
|
||||||
} else {
|
|
||||||
x = ggml_gelu_quick(ctx, x);
|
|
||||||
}
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx,
|
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* w,
|
struct ggml_tensor* w,
|
||||||
@ -999,7 +967,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx,
|
|||||||
bool force_prec_f32 = false,
|
bool force_prec_f32 = false,
|
||||||
float scale = 1.f) {
|
float scale = 1.f) {
|
||||||
if (scale != 1.f) {
|
if (scale != 1.f) {
|
||||||
x = ggml_ext_scale(ctx, x, scale);
|
x = ggml_scale(ctx, x, scale);
|
||||||
}
|
}
|
||||||
if (x->ne[2] * x->ne[3] > 1024) {
|
if (x->ne[2] * x->ne[3] > 1024) {
|
||||||
// workaround: avoid ggml cuda error
|
// workaround: avoid ggml cuda error
|
||||||
@ -1018,7 +986,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (scale != 1.f) {
|
if (scale != 1.f) {
|
||||||
x = ggml_ext_scale(ctx, x, 1.f / scale);
|
x = ggml_scale(ctx, x, 1.f / scale);
|
||||||
}
|
}
|
||||||
if (b != nullptr) {
|
if (b != nullptr) {
|
||||||
x = ggml_add_inplace(ctx, x, b);
|
x = ggml_add_inplace(ctx, x, b);
|
||||||
@ -1087,7 +1055,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_conv_2d(struct ggml_context* ctx,
|
|||||||
bool circular_y = false,
|
bool circular_y = false,
|
||||||
float scale = 1.f) {
|
float scale = 1.f) {
|
||||||
if (scale != 1.f) {
|
if (scale != 1.f) {
|
||||||
x = ggml_ext_scale(ctx, x, scale);
|
x = ggml_scale(ctx, x, scale);
|
||||||
}
|
}
|
||||||
if (w->ne[2] != x->ne[2] && ggml_n_dims(w) == 2) {
|
if (w->ne[2] != x->ne[2] && ggml_n_dims(w) == 2) {
|
||||||
w = ggml_reshape_4d(ctx, w, 1, 1, w->ne[0], w->ne[1]);
|
w = ggml_reshape_4d(ctx, w, 1, 1, w->ne[0], w->ne[1]);
|
||||||
@ -1105,7 +1073,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_conv_2d(struct ggml_context* ctx,
|
|||||||
x = ggml_conv_2d(ctx, w, x, s0, s1, p0, p1, d0, d1);
|
x = ggml_conv_2d(ctx, w, x, s0, s1, p0, p1, d0, d1);
|
||||||
}
|
}
|
||||||
if (scale != 1.f) {
|
if (scale != 1.f) {
|
||||||
x = ggml_ext_scale(ctx, x, 1.f / scale);
|
x = ggml_scale(ctx, x, 1.f / scale);
|
||||||
}
|
}
|
||||||
if (b != nullptr) {
|
if (b != nullptr) {
|
||||||
b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
|
b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
|
||||||
@ -1203,7 +1171,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_full(struct ggml_context* ctx,
|
|||||||
int64_t ne2,
|
int64_t ne2,
|
||||||
int64_t ne3) {
|
int64_t ne3) {
|
||||||
auto one = ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:one");
|
auto one = ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:one");
|
||||||
auto t = ggml_ext_scale(ctx, one, value); // [1,]
|
auto t = ggml_scale(ctx, one, value); // [1,]
|
||||||
t = ggml_repeat_4d(ctx, t, ne0, ne1, ne2, ne3); // [ne0, ne1, ne2, ne3]
|
t = ggml_repeat_4d(ctx, t, ne0, ne1, ne2, ne3); // [ne0, ne1, ne2, ne3]
|
||||||
return t;
|
return t;
|
||||||
}
|
}
|
||||||
@ -1257,6 +1225,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
|
|||||||
struct ggml_tensor* v,
|
struct ggml_tensor* v,
|
||||||
int64_t n_head,
|
int64_t n_head,
|
||||||
struct ggml_tensor* mask = nullptr,
|
struct ggml_tensor* mask = nullptr,
|
||||||
|
bool diag_mask_inf = false,
|
||||||
bool skip_reshape = false,
|
bool skip_reshape = false,
|
||||||
bool flash_attn = false,
|
bool flash_attn = false,
|
||||||
float kv_scale = 1.0f) { // avoid overflow
|
float kv_scale = 1.0f) { // avoid overflow
|
||||||
@ -1302,7 +1271,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
|
|||||||
k_in = ggml_pad(ctx, k_in, 0, kv_pad, 0, 0);
|
k_in = ggml_pad(ctx, k_in, 0, kv_pad, 0, 0);
|
||||||
}
|
}
|
||||||
if (kv_scale != 1.0f) {
|
if (kv_scale != 1.0f) {
|
||||||
k_in = ggml_ext_scale(ctx, k_in, kv_scale);
|
k_in = ggml_scale(ctx, k_in, kv_scale);
|
||||||
}
|
}
|
||||||
k_in = ggml_cast(ctx, k_in, GGML_TYPE_F16);
|
k_in = ggml_cast(ctx, k_in, GGML_TYPE_F16);
|
||||||
|
|
||||||
@ -1312,7 +1281,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
|
|||||||
v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0);
|
v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0);
|
||||||
}
|
}
|
||||||
if (kv_scale != 1.0f) {
|
if (kv_scale != 1.0f) {
|
||||||
v_in = ggml_ext_scale(ctx, v_in, kv_scale);
|
v_in = ggml_scale(ctx, v_in, kv_scale);
|
||||||
}
|
}
|
||||||
v_in = ggml_cast(ctx, v_in, GGML_TYPE_F16);
|
v_in = ggml_cast(ctx, v_in, GGML_TYPE_F16);
|
||||||
|
|
||||||
@ -1344,7 +1313,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
|
|||||||
auto out = ggml_flash_attn_ext(ctx, q_in, k_in, v_in, mask_in, scale / kv_scale, 0, 0);
|
auto out = ggml_flash_attn_ext(ctx, q_in, k_in, v_in, mask_in, scale / kv_scale, 0, 0);
|
||||||
ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32);
|
ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32);
|
||||||
if (kv_scale != 1.0f) {
|
if (kv_scale != 1.0f) {
|
||||||
out = ggml_ext_scale(ctx, out, 1.0f / kv_scale);
|
out = ggml_scale(ctx, out, 1.0f / kv_scale);
|
||||||
}
|
}
|
||||||
return out;
|
return out;
|
||||||
};
|
};
|
||||||
@ -1384,6 +1353,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
|
|||||||
if (mask) {
|
if (mask) {
|
||||||
kq = ggml_add_inplace(ctx, kq, mask);
|
kq = ggml_add_inplace(ctx, kq, mask);
|
||||||
}
|
}
|
||||||
|
if (diag_mask_inf) {
|
||||||
|
kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);
|
||||||
|
}
|
||||||
kq = ggml_soft_max_inplace(ctx, kq);
|
kq = ggml_soft_max_inplace(ctx, kq);
|
||||||
|
|
||||||
kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, L_q, d_head]
|
kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, L_q, d_head]
|
||||||
@ -1551,7 +1523,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_timestep_embedding(
|
|||||||
int dim,
|
int dim,
|
||||||
int max_period = 10000,
|
int max_period = 10000,
|
||||||
float time_factor = 1.0f) {
|
float time_factor = 1.0f) {
|
||||||
timesteps = ggml_ext_scale(ctx, timesteps, time_factor);
|
timesteps = ggml_scale(ctx, timesteps, time_factor);
|
||||||
return ggml_timestep_embedding(ctx, timesteps, dim, max_period);
|
return ggml_timestep_embedding(ctx, timesteps, dim, max_period);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2600,7 +2572,7 @@ public:
|
|||||||
// x: [N, n_token, embed_dim]
|
// x: [N, n_token, embed_dim]
|
||||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* mask = nullptr) {
|
bool mask = false) {
|
||||||
auto out_proj = std::dynamic_pointer_cast<Linear>(blocks[out_proj_name]);
|
auto out_proj = std::dynamic_pointer_cast<Linear>(blocks[out_proj_name]);
|
||||||
|
|
||||||
ggml_tensor* q;
|
ggml_tensor* q;
|
||||||
@ -2623,7 +2595,7 @@ public:
|
|||||||
v = v_proj->forward(ctx, x);
|
v = v_proj->forward(ctx, x);
|
||||||
}
|
}
|
||||||
|
|
||||||
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, mask); // [N, n_token, embed_dim]
|
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, mask); // [N, n_token, embed_dim]
|
||||||
|
|
||||||
x = out_proj->forward(ctx, x); // [N, n_token, embed_dim]
|
x = out_proj->forward(ctx, x); // [N, n_token, embed_dim]
|
||||||
return x;
|
return x;
|
||||||
|
|||||||
4
llm.hpp
4
llm.hpp
@ -638,7 +638,7 @@ namespace LLM {
|
|||||||
x = ln_q->forward(ctx, x);
|
x = ln_q->forward(ctx, x);
|
||||||
x = ggml_reshape_2d(ctx->ggml_ctx, x, hidden_size, ggml_nelements(x) / hidden_size);
|
x = ggml_reshape_2d(ctx->ggml_ctx, x, hidden_size, ggml_nelements(x) / hidden_size);
|
||||||
x = mlp_0->forward(ctx, x);
|
x = mlp_0->forward(ctx, x);
|
||||||
x = ggml_ext_gelu(ctx->ggml_ctx, x);
|
x = ggml_gelu(ctx->ggml_ctx, x);
|
||||||
x = mlp_2->forward(ctx, x);
|
x = mlp_2->forward(ctx, x);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
@ -881,7 +881,7 @@ namespace LLM {
|
|||||||
k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 2, 1, 3)); // [N, num_kv_heads, n_token, head_dim]
|
k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 2, 1, 3)); // [N, num_kv_heads, n_token, head_dim]
|
||||||
k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); // [N*num_kv_heads, n_token, head_dim]
|
k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); // [N*num_kv_heads, n_token, head_dim]
|
||||||
|
|
||||||
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, true, false); // [N, n_token, hidden_size]
|
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, false, true, false); // [N, n_token, hidden_size]
|
||||||
|
|
||||||
x = out_proj->forward(ctx, x); // [N, n_token, hidden_size]
|
x = out_proj->forward(ctx, x); // [N, n_token, hidden_size]
|
||||||
return x;
|
return x;
|
||||||
|
|||||||
10
lora.hpp
10
lora.hpp
@ -195,7 +195,7 @@ struct LoraModel : public GGMLRunner {
|
|||||||
scale_value *= multiplier;
|
scale_value *= multiplier;
|
||||||
|
|
||||||
auto curr_updown = ggml_ext_merge_lora(ctx, lora_down, lora_up, lora_mid);
|
auto curr_updown = ggml_ext_merge_lora(ctx, lora_down, lora_up, lora_mid);
|
||||||
curr_updown = ggml_ext_scale(ctx, curr_updown, scale_value, true);
|
curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value);
|
||||||
|
|
||||||
if (updown == nullptr) {
|
if (updown == nullptr) {
|
||||||
updown = curr_updown;
|
updown = curr_updown;
|
||||||
@ -235,7 +235,7 @@ struct LoraModel : public GGMLRunner {
|
|||||||
float scale_value = 1.0f;
|
float scale_value = 1.0f;
|
||||||
scale_value *= multiplier;
|
scale_value *= multiplier;
|
||||||
|
|
||||||
curr_updown = ggml_ext_scale(ctx, curr_updown, scale_value, true);
|
curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value);
|
||||||
|
|
||||||
if (updown == nullptr) {
|
if (updown == nullptr) {
|
||||||
updown = curr_updown;
|
updown = curr_updown;
|
||||||
@ -340,7 +340,7 @@ struct LoraModel : public GGMLRunner {
|
|||||||
struct ggml_tensor* updown_1 = ggml_ext_merge_lora(ctx, hada_1_down, hada_1_up, hada_1_mid);
|
struct ggml_tensor* updown_1 = ggml_ext_merge_lora(ctx, hada_1_down, hada_1_up, hada_1_mid);
|
||||||
struct ggml_tensor* updown_2 = ggml_ext_merge_lora(ctx, hada_2_down, hada_2_up, hada_2_mid);
|
struct ggml_tensor* updown_2 = ggml_ext_merge_lora(ctx, hada_2_down, hada_2_up, hada_2_mid);
|
||||||
auto curr_updown = ggml_mul_inplace(ctx, updown_1, updown_2);
|
auto curr_updown = ggml_mul_inplace(ctx, updown_1, updown_2);
|
||||||
curr_updown = ggml_ext_scale(ctx, curr_updown, scale_value, true);
|
curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value);
|
||||||
if (updown == nullptr) {
|
if (updown == nullptr) {
|
||||||
updown = curr_updown;
|
updown = curr_updown;
|
||||||
} else {
|
} else {
|
||||||
@ -456,7 +456,7 @@ struct LoraModel : public GGMLRunner {
|
|||||||
scale_value *= multiplier;
|
scale_value *= multiplier;
|
||||||
|
|
||||||
auto curr_updown = ggml_ext_kronecker(ctx, lokr_w1, lokr_w2);
|
auto curr_updown = ggml_ext_kronecker(ctx, lokr_w1, lokr_w2);
|
||||||
curr_updown = ggml_ext_scale(ctx, curr_updown, scale_value, true);
|
curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value);
|
||||||
|
|
||||||
if (updown == nullptr) {
|
if (updown == nullptr) {
|
||||||
updown = curr_updown;
|
updown = curr_updown;
|
||||||
@ -634,7 +634,7 @@ struct LoraModel : public GGMLRunner {
|
|||||||
forward_params.conv2d.scale);
|
forward_params.conv2d.scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto curr_out_diff = ggml_ext_scale(ctx, lx, scale_value, true);
|
auto curr_out_diff = ggml_scale_inplace(ctx, lx, scale_value);
|
||||||
|
|
||||||
if (out_diff == nullptr) {
|
if (out_diff == nullptr) {
|
||||||
out_diff = curr_out_diff;
|
out_diff = curr_out_diff;
|
||||||
|
|||||||
87
mmdit.hpp
87
mmdit.hpp
@ -33,7 +33,7 @@ public:
|
|||||||
auto fc2 = std::dynamic_pointer_cast<Linear>(blocks["fc2"]);
|
auto fc2 = std::dynamic_pointer_cast<Linear>(blocks["fc2"]);
|
||||||
|
|
||||||
x = fc1->forward(ctx, x);
|
x = fc1->forward(ctx, x);
|
||||||
x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
|
x = ggml_gelu_inplace(ctx->ggml_ctx, x);
|
||||||
x = fc2->forward(ctx, x);
|
x = fc2->forward(ctx, x);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
@ -211,8 +211,8 @@ public:
|
|||||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x) {
|
struct ggml_tensor* x) {
|
||||||
auto qkv = pre_attention(ctx, x);
|
auto qkv = pre_attention(ctx, x);
|
||||||
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
|
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
|
||||||
x = post_attention(ctx, x); // [N, n_token, dim]
|
x = post_attention(ctx, x); // [N, n_token, dim]
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -284,19 +284,23 @@ public:
|
|||||||
auto attn2 = std::dynamic_pointer_cast<SelfAttention>(blocks["attn2"]);
|
auto attn2 = std::dynamic_pointer_cast<SelfAttention>(blocks["attn2"]);
|
||||||
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
|
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
|
||||||
|
|
||||||
int n_mods = 9;
|
int64_t n_mods = 9;
|
||||||
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size]
|
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size]
|
||||||
auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, n_mods, 0);
|
m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size]
|
||||||
|
m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size]
|
||||||
|
|
||||||
auto shift_msa = m_vec[0]; // [N, hidden_size]
|
int64_t offset = m->nb[1] * m->ne[1];
|
||||||
auto scale_msa = m_vec[1]; // [N, hidden_size]
|
auto shift_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
|
||||||
auto gate_msa = m_vec[2]; // [N, hidden_size]
|
auto scale_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
|
||||||
auto shift_mlp = m_vec[3]; // [N, hidden_size]
|
auto gate_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size]
|
||||||
auto scale_mlp = m_vec[4]; // [N, hidden_size]
|
|
||||||
auto gate_mlp = m_vec[5]; // [N, hidden_size]
|
auto shift_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size]
|
||||||
auto shift_msa2 = m_vec[6]; // [N, hidden_size]
|
auto scale_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size]
|
||||||
auto scale_msa2 = m_vec[7]; // [N, hidden_size]
|
auto gate_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size]
|
||||||
auto gate_msa2 = m_vec[8]; // [N, hidden_size]
|
|
||||||
|
auto shift_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 6); // [N, hidden_size]
|
||||||
|
auto scale_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 7); // [N, hidden_size]
|
||||||
|
auto gate_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 8); // [N, hidden_size]
|
||||||
|
|
||||||
auto x_norm = norm1->forward(ctx, x);
|
auto x_norm = norm1->forward(ctx, x);
|
||||||
|
|
||||||
@ -318,20 +322,22 @@ public:
|
|||||||
auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
|
auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
|
||||||
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
|
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
|
||||||
|
|
||||||
int n_mods = 6;
|
int64_t n_mods = 6;
|
||||||
if (pre_only) {
|
if (pre_only) {
|
||||||
n_mods = 2;
|
n_mods = 2;
|
||||||
}
|
}
|
||||||
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size]
|
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size]
|
||||||
auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, n_mods, 0);
|
m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size]
|
||||||
|
m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size]
|
||||||
|
|
||||||
auto shift_msa = m_vec[0]; // [N, hidden_size]
|
int64_t offset = m->nb[1] * m->ne[1];
|
||||||
auto scale_msa = m_vec[1]; // [N, hidden_size]
|
auto shift_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
|
||||||
|
auto scale_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
|
||||||
if (!pre_only) {
|
if (!pre_only) {
|
||||||
auto gate_msa = m_vec[2]; // [N, hidden_size]
|
auto gate_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size]
|
||||||
auto shift_mlp = m_vec[3]; // [N, hidden_size]
|
auto shift_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size]
|
||||||
auto scale_mlp = m_vec[4]; // [N, hidden_size]
|
auto scale_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size]
|
||||||
auto gate_mlp = m_vec[5]; // [N, hidden_size]
|
auto gate_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size]
|
||||||
|
|
||||||
auto attn_in = modulate(ctx->ggml_ctx, norm1->forward(ctx, x), shift_msa, scale_msa);
|
auto attn_in = modulate(ctx->ggml_ctx, norm1->forward(ctx, x), shift_msa, scale_msa);
|
||||||
|
|
||||||
@ -433,8 +439,8 @@ public:
|
|||||||
auto qkv2 = std::get<1>(qkv_intermediates);
|
auto qkv2 = std::get<1>(qkv_intermediates);
|
||||||
auto intermediates = std::get<2>(qkv_intermediates);
|
auto intermediates = std::get<2>(qkv_intermediates);
|
||||||
|
|
||||||
auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
|
auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
|
||||||
auto attn2_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv2[0], qkv2[1], qkv2[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
|
auto attn2_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv2[0], qkv2[1], qkv2[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
|
||||||
x = post_attention_x(ctx,
|
x = post_attention_x(ctx,
|
||||||
attn_out,
|
attn_out,
|
||||||
attn2_out,
|
attn2_out,
|
||||||
@ -450,7 +456,7 @@ public:
|
|||||||
auto qkv = qkv_intermediates.first;
|
auto qkv = qkv_intermediates.first;
|
||||||
auto intermediates = qkv_intermediates.second;
|
auto intermediates = qkv_intermediates.second;
|
||||||
|
|
||||||
auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
|
auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
|
||||||
x = post_attention(ctx,
|
x = post_attention(ctx,
|
||||||
attn_out,
|
attn_out,
|
||||||
intermediates[0],
|
intermediates[0],
|
||||||
@ -494,24 +500,26 @@ block_mixing(GGMLRunnerContext* ctx,
|
|||||||
qkv.push_back(ggml_concat(ctx->ggml_ctx, context_qkv[i], x_qkv[i], 1));
|
qkv.push_back(ggml_concat(ctx->ggml_ctx, context_qkv[i], x_qkv[i], 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto attn = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_context + n_token, hidden_size]
|
auto attn = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_context + n_token, hidden_size]
|
||||||
|
attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_context + n_token, N, hidden_size]
|
||||||
auto context_attn = ggml_view_3d(ctx->ggml_ctx,
|
auto context_attn = ggml_view_3d(ctx->ggml_ctx,
|
||||||
attn,
|
attn,
|
||||||
attn->ne[0],
|
attn->ne[0],
|
||||||
|
attn->ne[1],
|
||||||
context->ne[1],
|
context->ne[1],
|
||||||
attn->ne[2],
|
|
||||||
attn->nb[1],
|
attn->nb[1],
|
||||||
attn->nb[2],
|
attn->nb[2],
|
||||||
0); // [N, n_context, hidden_size]
|
0); // [n_context, N, hidden_size]
|
||||||
|
context_attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, context_attn, 0, 2, 1, 3)); // [N, n_context, hidden_size]
|
||||||
auto x_attn = ggml_view_3d(ctx->ggml_ctx,
|
auto x_attn = ggml_view_3d(ctx->ggml_ctx,
|
||||||
attn,
|
attn,
|
||||||
attn->ne[0],
|
attn->ne[0],
|
||||||
|
attn->ne[1],
|
||||||
x->ne[1],
|
x->ne[1],
|
||||||
attn->ne[2],
|
|
||||||
attn->nb[1],
|
attn->nb[1],
|
||||||
attn->nb[2],
|
attn->nb[2],
|
||||||
context->ne[1] * attn->nb[1]); // [N, n_token, hidden_size]
|
attn->nb[2] * context->ne[1]); // [n_token, N, hidden_size]
|
||||||
|
x_attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x_attn, 0, 2, 1, 3)); // [N, n_token, hidden_size]
|
||||||
|
|
||||||
if (!context_block->pre_only) {
|
if (!context_block->pre_only) {
|
||||||
context = context_block->post_attention(ctx,
|
context = context_block->post_attention(ctx,
|
||||||
@ -526,7 +534,7 @@ block_mixing(GGMLRunnerContext* ctx,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (x_block->self_attn) {
|
if (x_block->self_attn) {
|
||||||
auto attn2 = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, hidden_size]
|
auto attn2 = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, hidden_size]
|
||||||
|
|
||||||
x = x_block->post_attention_x(ctx,
|
x = x_block->post_attention_x(ctx,
|
||||||
x_attn,
|
x_attn,
|
||||||
@ -596,10 +604,13 @@ public:
|
|||||||
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
|
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
|
||||||
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
|
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
|
||||||
|
|
||||||
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 2 * hidden_size]
|
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 2 * hidden_size]
|
||||||
auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, 2, 0);
|
m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size]
|
||||||
auto shift = m_vec[0]; // [N, hidden_size]
|
m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size]
|
||||||
auto scale = m_vec[1]; // [N, hidden_size]
|
|
||||||
|
int64_t offset = m->nb[1] * m->ne[1];
|
||||||
|
auto shift = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
|
||||||
|
auto scale = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
|
||||||
|
|
||||||
x = modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale);
|
x = modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale);
|
||||||
x = linear->forward(ctx, x);
|
x = linear->forward(ctx, x);
|
||||||
|
|||||||
6
pmid.hpp
6
pmid.hpp
@ -33,7 +33,7 @@ public:
|
|||||||
x = layer_norm->forward(ctx, x);
|
x = layer_norm->forward(ctx, x);
|
||||||
// x = ggml_add(ctx, ggml_mul_mat(ctx, fc1_w, x), fc1_b);
|
// x = ggml_add(ctx, ggml_mul_mat(ctx, fc1_w, x), fc1_b);
|
||||||
x = fc1->forward(ctx, x);
|
x = fc1->forward(ctx, x);
|
||||||
x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
|
x = ggml_gelu_inplace(ctx->ggml_ctx, x);
|
||||||
x = fc2->forward(ctx, x);
|
x = fc2->forward(ctx, x);
|
||||||
// x = ggml_add(ctx, ggml_mul_mat(ctx, fc2_w, x), fc2_b);
|
// x = ggml_add(ctx, ggml_mul_mat(ctx, fc2_w, x), fc2_b);
|
||||||
if (use_residue)
|
if (use_residue)
|
||||||
@ -129,8 +129,8 @@ public:
|
|||||||
k = reshape_tensor(ctx->ggml_ctx, k, heads);
|
k = reshape_tensor(ctx->ggml_ctx, k, heads);
|
||||||
v = reshape_tensor(ctx->ggml_ctx, v, heads);
|
v = reshape_tensor(ctx->ggml_ctx, v, heads);
|
||||||
scale = 1.f / sqrt(sqrt((float)dim_head));
|
scale = 1.f / sqrt(sqrt((float)dim_head));
|
||||||
k = ggml_ext_scale(ctx->ggml_ctx, k, scale, true);
|
k = ggml_scale_inplace(ctx->ggml_ctx, k, scale);
|
||||||
q = ggml_ext_scale(ctx->ggml_ctx, q, scale, true);
|
q = ggml_scale_inplace(ctx->ggml_ctx, q, scale);
|
||||||
// auto weight = ggml_mul_mat(ctx, q, k);
|
// auto weight = ggml_mul_mat(ctx, q, k);
|
||||||
auto weight = ggml_mul_mat(ctx->ggml_ctx, k, q); // NOTE order of mul is opposite to pytorch
|
auto weight = ggml_mul_mat(ctx->ggml_ctx, k, q); // NOTE order of mul is opposite to pytorch
|
||||||
|
|
||||||
|
|||||||
@ -162,25 +162,26 @@ namespace Qwen {
|
|||||||
auto k = ggml_concat(ctx->ggml_ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
|
auto k = ggml_concat(ctx->ggml_ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
|
||||||
auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
|
auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
|
||||||
|
|
||||||
auto attn = Rope::attention(ctx, q, k, v, pe, mask, (1.0f / 128.f)); // [N, n_txt_token + n_img_token, n_head*d_head]
|
auto attn = Rope::attention(ctx, q, k, v, pe, mask, (1.0f / 128.f)); // [N, n_txt_token + n_img_token, n_head*d_head]
|
||||||
|
attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
|
||||||
auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx,
|
auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx,
|
||||||
attn,
|
attn,
|
||||||
attn->ne[0],
|
attn->ne[0],
|
||||||
|
attn->ne[1],
|
||||||
txt->ne[1],
|
txt->ne[1],
|
||||||
attn->ne[2],
|
|
||||||
attn->nb[1],
|
attn->nb[1],
|
||||||
attn->nb[2],
|
attn->nb[2],
|
||||||
0); // [N, n_txt_token, n_head*d_head]
|
0); // [n_txt_token, N, hidden_size]
|
||||||
|
txt_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size]
|
||||||
auto img_attn_out = ggml_view_3d(ctx->ggml_ctx,
|
auto img_attn_out = ggml_view_3d(ctx->ggml_ctx,
|
||||||
attn,
|
attn,
|
||||||
attn->ne[0],
|
attn->ne[0],
|
||||||
|
attn->ne[1],
|
||||||
img->ne[1],
|
img->ne[1],
|
||||||
attn->ne[2],
|
|
||||||
attn->nb[1],
|
attn->nb[1],
|
||||||
attn->nb[2],
|
attn->nb[2],
|
||||||
txt->ne[1] * attn->nb[1]); // [N, n_img_token, n_head*d_head]
|
attn->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size]
|
||||||
img_attn_out = ggml_cont(ctx->ggml_ctx, img_attn_out);
|
img_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
|
||||||
txt_attn_out = ggml_cont(ctx->ggml_ctx, txt_attn_out);
|
|
||||||
|
|
||||||
img_attn_out = to_out_0->forward(ctx, img_attn_out);
|
img_attn_out = to_out_0->forward(ctx, img_attn_out);
|
||||||
txt_attn_out = to_add_out->forward(ctx, txt_attn_out);
|
txt_attn_out = to_add_out->forward(ctx, txt_attn_out);
|
||||||
|
|||||||
2
rope.hpp
2
rope.hpp
@ -642,7 +642,7 @@ namespace Rope {
|
|||||||
q = apply_rope(ctx->ggml_ctx, q, pe, rope_interleaved); // [N*n_head, L, d_head]
|
q = apply_rope(ctx->ggml_ctx, q, pe, rope_interleaved); // [N*n_head, L, d_head]
|
||||||
k = apply_rope(ctx->ggml_ctx, k, pe, rope_interleaved); // [N*n_head, L, d_head]
|
k = apply_rope(ctx->ggml_ctx, k, pe, rope_interleaved); // [N*n_head, L, d_head]
|
||||||
|
|
||||||
auto x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, v->ne[1], mask, true, ctx->flash_attn_enabled, kv_scale); // [N, L, n_head*d_head]
|
auto x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, v->ne[1], mask, false, true, ctx->flash_attn_enabled, kv_scale); // [N, L, n_head*d_head]
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
}; // namespace Rope
|
}; // namespace Rope
|
||||||
|
|||||||
4
t5.hpp
4
t5.hpp
@ -515,7 +515,7 @@ public:
|
|||||||
auto wi_1 = std::dynamic_pointer_cast<Linear>(blocks["wi_1"]);
|
auto wi_1 = std::dynamic_pointer_cast<Linear>(blocks["wi_1"]);
|
||||||
auto wo = std::dynamic_pointer_cast<Linear>(blocks["wo"]);
|
auto wo = std::dynamic_pointer_cast<Linear>(blocks["wo"]);
|
||||||
|
|
||||||
auto hidden_gelu = ggml_ext_gelu(ctx->ggml_ctx, wi_0->forward(ctx, x), true);
|
auto hidden_gelu = ggml_gelu_inplace(ctx->ggml_ctx, wi_0->forward(ctx, x));
|
||||||
auto hidden_linear = wi_1->forward(ctx, x);
|
auto hidden_linear = wi_1->forward(ctx, x);
|
||||||
x = ggml_mul_inplace(ctx->ggml_ctx, hidden_gelu, hidden_linear);
|
x = ggml_mul_inplace(ctx->ggml_ctx, hidden_gelu, hidden_linear);
|
||||||
x = wo->forward(ctx, x);
|
x = wo->forward(ctx, x);
|
||||||
@ -608,7 +608,7 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
k = ggml_ext_scale(ctx->ggml_ctx, k, ::sqrtf(static_cast<float>(d_head)), true);
|
k = ggml_scale_inplace(ctx->ggml_ctx, k, ::sqrtf(static_cast<float>(d_head)));
|
||||||
|
|
||||||
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, mask); // [N, n_token, d_head * n_head]
|
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, mask); // [N, n_token, d_head * n_head]
|
||||||
|
|
||||||
|
|||||||
13
tae.hpp
13
tae.hpp
@ -161,9 +161,9 @@ public:
|
|||||||
// z: [n, z_channels, h, w]
|
// z: [n, z_channels, h, w]
|
||||||
// return: [n, out_channels, h*8, w*8]
|
// return: [n, out_channels, h*8, w*8]
|
||||||
|
|
||||||
auto h = ggml_ext_scale(ctx->ggml_ctx, z, 1.0f / 3.0f);
|
auto h = ggml_scale(ctx->ggml_ctx, z, 1.0f / 3.0f);
|
||||||
h = ggml_tanh_inplace(ctx->ggml_ctx, h);
|
h = ggml_tanh_inplace(ctx->ggml_ctx, h);
|
||||||
h = ggml_ext_scale(ctx->ggml_ctx, h, 3.0f);
|
h = ggml_scale(ctx->ggml_ctx, h, 3.0f);
|
||||||
|
|
||||||
for (int i = 0; i < num_blocks * 3 + 10; i++) {
|
for (int i = 0; i < num_blocks * 3 + 10; i++) {
|
||||||
if (blocks.find(std::to_string(i)) == blocks.end()) {
|
if (blocks.find(std::to_string(i)) == blocks.end()) {
|
||||||
@ -400,11 +400,10 @@ public:
|
|||||||
auto first_conv = std::dynamic_pointer_cast<Conv2d>(blocks["1"]);
|
auto first_conv = std::dynamic_pointer_cast<Conv2d>(blocks["1"]);
|
||||||
|
|
||||||
// Clamp()
|
// Clamp()
|
||||||
auto h = ggml_ext_scale(ctx->ggml_ctx,
|
auto h = ggml_scale_inplace(ctx->ggml_ctx,
|
||||||
ggml_tanh_inplace(ctx->ggml_ctx,
|
ggml_tanh_inplace(ctx->ggml_ctx,
|
||||||
ggml_ext_scale(ctx->ggml_ctx, z, 1.0f / 3.0f)),
|
ggml_scale(ctx->ggml_ctx, z, 1.0f / 3.0f)),
|
||||||
3.0f,
|
3.0f);
|
||||||
true);
|
|
||||||
|
|
||||||
h = first_conv->forward(ctx, h);
|
h = first_conv->forward(ctx, h);
|
||||||
h = ggml_relu_inplace(ctx->ggml_ctx, h);
|
h = ggml_relu_inplace(ctx->ggml_ctx, h);
|
||||||
|
|||||||
4
unet.hpp
4
unet.hpp
@ -529,7 +529,7 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (controls.size() > 0) {
|
if (controls.size() > 0) {
|
||||||
auto cs = ggml_ext_scale(ctx->ggml_ctx, controls[controls.size() - 1], control_strength, true);
|
auto cs = ggml_scale_inplace(ctx->ggml_ctx, controls[controls.size() - 1], control_strength);
|
||||||
h = ggml_add(ctx->ggml_ctx, h, cs); // middle control
|
h = ggml_add(ctx->ggml_ctx, h, cs); // middle control
|
||||||
}
|
}
|
||||||
int control_offset = static_cast<int>(controls.size() - 2);
|
int control_offset = static_cast<int>(controls.size() - 2);
|
||||||
@ -542,7 +542,7 @@ public:
|
|||||||
hs.pop_back();
|
hs.pop_back();
|
||||||
|
|
||||||
if (controls.size() > 0) {
|
if (controls.size() > 0) {
|
||||||
auto cs = ggml_ext_scale(ctx->ggml_ctx, controls[control_offset], control_strength, true);
|
auto cs = ggml_scale_inplace(ctx->ggml_ctx, controls[control_offset], control_strength);
|
||||||
h_skip = ggml_add(ctx->ggml_ctx, h_skip, cs); // control net condition
|
h_skip = ggml_add(ctx->ggml_ctx, h_skip, cs); // control net condition
|
||||||
control_offset--;
|
control_offset--;
|
||||||
}
|
}
|
||||||
|
|||||||
6
vae.hpp
6
vae.hpp
@ -141,7 +141,7 @@ public:
|
|||||||
v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n); // [N, h * w, in_channels]
|
v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n); // [N, h * w, in_channels]
|
||||||
}
|
}
|
||||||
|
|
||||||
h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, false);
|
h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, true, false);
|
||||||
|
|
||||||
if (use_linear) {
|
if (use_linear) {
|
||||||
h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels]
|
h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels]
|
||||||
@ -253,8 +253,8 @@ public:
|
|||||||
|
|
||||||
float alpha = get_alpha();
|
float alpha = get_alpha();
|
||||||
x = ggml_add(ctx->ggml_ctx,
|
x = ggml_add(ctx->ggml_ctx,
|
||||||
ggml_ext_scale(ctx->ggml_ctx, x, alpha),
|
ggml_scale(ctx->ggml_ctx, x, alpha),
|
||||||
ggml_ext_scale(ctx->ggml_ctx, x_mix, 1.0f - alpha));
|
ggml_scale(ctx->ggml_ctx, x_mix, 1.0f - alpha));
|
||||||
|
|
||||||
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
|
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
|
||||||
x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
|
||||||
|
|||||||
25
wan.hpp
25
wan.hpp
@ -572,8 +572,8 @@ namespace WAN {
|
|||||||
auto v = qkv_vec[2];
|
auto v = qkv_vec[2];
|
||||||
v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [t, c, h * w]
|
v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [t, c, h * w]
|
||||||
|
|
||||||
v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
|
v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
|
||||||
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, false); // [t, h * w, c]
|
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, true, false); // [t, h * w, c]
|
||||||
|
|
||||||
x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [t, c, h * w]
|
x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [t, c, h * w]
|
||||||
x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n); // [t, c, h, w]
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n); // [t, c, h, w]
|
||||||
@ -1393,7 +1393,7 @@ namespace WAN {
|
|||||||
k = norm_k->forward(ctx, k);
|
k = norm_k->forward(ctx, k);
|
||||||
auto v = v_proj->forward(ctx, context); // [N, n_context, dim]
|
auto v = v_proj->forward(ctx, context); // [N, n_context, dim]
|
||||||
|
|
||||||
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
|
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
|
||||||
|
|
||||||
x = o_proj->forward(ctx, x); // [N, n_token, dim]
|
x = o_proj->forward(ctx, x); // [N, n_token, dim]
|
||||||
return x;
|
return x;
|
||||||
@ -1442,8 +1442,11 @@ namespace WAN {
|
|||||||
int64_t dim = x->ne[0];
|
int64_t dim = x->ne[0];
|
||||||
int64_t context_txt_len = context->ne[1] - context_img_len;
|
int64_t context_txt_len = context->ne[1] - context_img_len;
|
||||||
|
|
||||||
auto context_img = ggml_view_3d(ctx->ggml_ctx, context, dim, context_img_len, N, context->nb[1], context->nb[2], 0); // [N, context_img_len, dim]
|
context = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context, 0, 2, 1, 3)); // [context_img_len + context_txt_len, N, dim]
|
||||||
auto context_txt = ggml_view_3d(ctx->ggml_ctx, context, dim, context_txt_len, N, context->nb[1], context->nb[2], context_img_len * context->nb[1]); // [N, context_txt_len, dim]
|
auto context_img = ggml_view_3d(ctx->ggml_ctx, context, dim, N, context_img_len, context->nb[1], context->nb[2], 0);
|
||||||
|
auto context_txt = ggml_view_3d(ctx->ggml_ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_img_len * context->nb[2]);
|
||||||
|
context_img = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context_img, 0, 2, 1, 3)); // [N, context_img_len, dim]
|
||||||
|
context_txt = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context_txt, 0, 2, 1, 3)); // [N, context_txt_len, dim]
|
||||||
|
|
||||||
auto q = q_proj->forward(ctx, x);
|
auto q = q_proj->forward(ctx, x);
|
||||||
q = norm_q->forward(ctx, q);
|
q = norm_q->forward(ctx, q);
|
||||||
@ -1455,8 +1458,8 @@ namespace WAN {
|
|||||||
k_img = norm_k_img->forward(ctx, k_img);
|
k_img = norm_k_img->forward(ctx, k_img);
|
||||||
auto v_img = v_img_proj->forward(ctx, context_img); // [N, context_img_len, dim]
|
auto v_img = v_img_proj->forward(ctx, context_img); // [N, context_img_len, dim]
|
||||||
|
|
||||||
auto img_x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k_img, v_img, num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
|
auto img_x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k_img, v_img, num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
|
||||||
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
|
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
|
||||||
|
|
||||||
x = ggml_add(ctx->ggml_ctx, x, img_x);
|
x = ggml_add(ctx->ggml_ctx, x, img_x);
|
||||||
|
|
||||||
@ -1573,7 +1576,7 @@ namespace WAN {
|
|||||||
y = modulate_add(ctx->ggml_ctx, y, es[3]);
|
y = modulate_add(ctx->ggml_ctx, y, es[3]);
|
||||||
|
|
||||||
y = ffn_0->forward(ctx, y);
|
y = ffn_0->forward(ctx, y);
|
||||||
y = ggml_ext_gelu(ctx->ggml_ctx, y, true);
|
y = ggml_gelu_inplace(ctx->ggml_ctx, y);
|
||||||
y = ffn_2->forward(ctx, y);
|
y = ffn_2->forward(ctx, y);
|
||||||
|
|
||||||
x = ggml_add(ctx->ggml_ctx, x, modulate_mul(ctx->ggml_ctx, y, es[5]));
|
x = ggml_add(ctx->ggml_ctx, x, modulate_mul(ctx->ggml_ctx, y, es[5]));
|
||||||
@ -1720,7 +1723,7 @@ namespace WAN {
|
|||||||
|
|
||||||
auto x = proj_0->forward(ctx, image_embeds);
|
auto x = proj_0->forward(ctx, image_embeds);
|
||||||
x = proj_1->forward(ctx, x);
|
x = proj_1->forward(ctx, x);
|
||||||
x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
|
x = ggml_gelu_inplace(ctx->ggml_ctx, x);
|
||||||
x = proj_3->forward(ctx, x);
|
x = proj_3->forward(ctx, x);
|
||||||
x = proj_4->forward(ctx, x);
|
x = proj_4->forward(ctx, x);
|
||||||
|
|
||||||
@ -1907,7 +1910,7 @@ namespace WAN {
|
|||||||
e0 = ggml_reshape_4d(ctx->ggml_ctx, e0, e0->ne[0] / 6, 6, e0->ne[1], e0->ne[2]); // [N, 6, dim] or [N, T, 6, dim]
|
e0 = ggml_reshape_4d(ctx->ggml_ctx, e0, e0->ne[0] / 6, 6, e0->ne[1], e0->ne[2]); // [N, 6, dim] or [N, T, 6, dim]
|
||||||
|
|
||||||
context = text_embedding_0->forward(ctx, context);
|
context = text_embedding_0->forward(ctx, context);
|
||||||
context = ggml_ext_gelu(ctx->ggml_ctx, context);
|
context = ggml_gelu(ctx->ggml_ctx, context);
|
||||||
context = text_embedding_2->forward(ctx, context); // [N, context_txt_len, dim]
|
context = text_embedding_2->forward(ctx, context); // [N, context_txt_len, dim]
|
||||||
|
|
||||||
int64_t context_img_len = 0;
|
int64_t context_img_len = 0;
|
||||||
@ -1946,7 +1949,7 @@ namespace WAN {
|
|||||||
auto result = vace_block->forward(ctx, c, x_orig, e0, pe, context, context_img_len);
|
auto result = vace_block->forward(ctx, c, x_orig, e0, pe, context, context_img_len);
|
||||||
auto c_skip = result.first;
|
auto c_skip = result.first;
|
||||||
c = result.second;
|
c = result.second;
|
||||||
c_skip = ggml_ext_scale(ctx->ggml_ctx, c_skip, vace_strength);
|
c_skip = ggml_scale(ctx->ggml_ctx, c_skip, vace_strength);
|
||||||
x = ggml_add(ctx->ggml_ctx, x, c_skip);
|
x = ggml_add(ctx->ggml_ctx, x, c_skip);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
40
z_image.hpp
40
z_image.hpp
@ -54,37 +54,15 @@ namespace ZImage {
|
|||||||
|
|
||||||
auto qkv = qkv_proj->forward(ctx, x); // [N, n_token, (num_heads + num_kv_heads*2)*head_dim]
|
auto qkv = qkv_proj->forward(ctx, x); // [N, n_token, (num_heads + num_kv_heads*2)*head_dim]
|
||||||
qkv = ggml_reshape_4d(ctx->ggml_ctx, qkv, head_dim, num_heads + num_kv_heads * 2, qkv->ne[1], qkv->ne[2]); // [N, n_token, num_heads + num_kv_heads*2, head_dim]
|
qkv = ggml_reshape_4d(ctx->ggml_ctx, qkv, head_dim, num_heads + num_kv_heads * 2, qkv->ne[1], qkv->ne[2]); // [N, n_token, num_heads + num_kv_heads*2, head_dim]
|
||||||
|
qkv = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, qkv, 0, 2, 3, 1)); // [num_heads + num_kv_heads*2, N, n_token, head_dim]
|
||||||
|
|
||||||
auto q = ggml_view_4d(ctx->ggml_ctx,
|
auto q = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], 0); // [num_heads, N, n_token, head_dim]
|
||||||
qkv,
|
auto k = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_kv_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], qkv->nb[3] * num_heads); // [num_kv_heads, N, n_token, head_dim]
|
||||||
qkv->ne[0],
|
auto v = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_kv_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], qkv->nb[3] * (num_heads + num_kv_heads)); // [num_kv_heads, N, n_token, head_dim]
|
||||||
num_heads,
|
|
||||||
qkv->ne[2],
|
q = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, q, 0, 3, 1, 2)); // [N, n_token, num_heads, head_dim]
|
||||||
qkv->ne[3],
|
k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 3, 1, 2)); // [N, n_token, num_kv_heads, head_dim]
|
||||||
qkv->nb[1],
|
v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 0, 3, 1, 2)); // [N, n_token, num_kv_heads, head_dim]
|
||||||
qkv->nb[2],
|
|
||||||
qkv->nb[3],
|
|
||||||
0); // [N, n_token, num_heads, head_dim]
|
|
||||||
auto k = ggml_view_4d(ctx->ggml_ctx,
|
|
||||||
qkv,
|
|
||||||
qkv->ne[0],
|
|
||||||
num_kv_heads,
|
|
||||||
qkv->ne[2],
|
|
||||||
qkv->ne[3],
|
|
||||||
qkv->nb[1],
|
|
||||||
qkv->nb[2],
|
|
||||||
qkv->nb[3],
|
|
||||||
num_heads * qkv->nb[1]); // [N, n_token, num_kv_heads, head_dim]
|
|
||||||
auto v = ggml_view_4d(ctx->ggml_ctx,
|
|
||||||
qkv,
|
|
||||||
qkv->ne[0],
|
|
||||||
num_kv_heads,
|
|
||||||
qkv->ne[2],
|
|
||||||
qkv->ne[3],
|
|
||||||
qkv->nb[1],
|
|
||||||
qkv->nb[2],
|
|
||||||
qkv->nb[3],
|
|
||||||
(num_heads + num_kv_heads) * qkv->nb[1]); // [N, n_token, num_kv_heads, head_dim]
|
|
||||||
|
|
||||||
if (qk_norm) {
|
if (qk_norm) {
|
||||||
auto q_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["q_norm"]);
|
auto q_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["q_norm"]);
|
||||||
@ -517,7 +495,7 @@ namespace ZImage {
|
|||||||
out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H); // [N, C, H, W + pad_w]
|
out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H); // [N, C, H, W + pad_w]
|
||||||
out = ggml_ext_slice(ctx->ggml_ctx, out, 0, 0, W); // [N, C, H, W]
|
out = ggml_ext_slice(ctx->ggml_ctx, out, 0, 0, W); // [N, C, H, W]
|
||||||
|
|
||||||
out = ggml_ext_scale(ctx->ggml_ctx, out, -1.f);
|
out = ggml_scale(ctx->ggml_ctx, out, -1.f);
|
||||||
|
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user