mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-12 13:28:37 +00:00
refactor: introduce GGMLRunnerContext (#928)
* introduce GGMLRunnerContext
* add Flash Attention enable control through GGMLRunnerContext
* add conv2d_direct enable control through GGMLRunnerContext
This commit is contained in:
parent
c42826b77c
commit
6103d86e2c
87
clip.hpp
87
clip.hpp
@ -451,16 +451,16 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
// x: [N, n_token, d_model]
|
// x: [N, n_token, d_model]
|
||||||
auto fc1 = std::dynamic_pointer_cast<Linear>(blocks["fc1"]);
|
auto fc1 = std::dynamic_pointer_cast<Linear>(blocks["fc1"]);
|
||||||
auto fc2 = std::dynamic_pointer_cast<Linear>(blocks["fc2"]);
|
auto fc2 = std::dynamic_pointer_cast<Linear>(blocks["fc2"]);
|
||||||
|
|
||||||
x = fc1->forward(ctx, x);
|
x = fc1->forward(ctx, x);
|
||||||
if (use_gelu) {
|
if (use_gelu) {
|
||||||
x = ggml_gelu_inplace(ctx, x);
|
x = ggml_gelu_inplace(ctx->ggml_ctx, x);
|
||||||
} else {
|
} else {
|
||||||
x = ggml_gelu_quick_inplace(ctx, x);
|
x = ggml_gelu_quick_inplace(ctx->ggml_ctx, x);
|
||||||
}
|
}
|
||||||
x = fc2->forward(ctx, x);
|
x = fc2->forward(ctx, x);
|
||||||
return x;
|
return x;
|
||||||
@ -488,15 +488,15 @@ public:
|
|||||||
blocks["mlp"] = std::shared_ptr<GGMLBlock>(new CLIPMLP(d_model, intermediate_size));
|
blocks["mlp"] = std::shared_ptr<GGMLBlock>(new CLIPMLP(d_model, intermediate_size));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, ggml_backend_t backend, struct ggml_tensor* x, bool mask = true) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, bool mask = true) {
|
||||||
// x: [N, n_token, d_model]
|
// x: [N, n_token, d_model]
|
||||||
auto self_attn = std::dynamic_pointer_cast<MultiheadAttention>(blocks["self_attn"]);
|
auto self_attn = std::dynamic_pointer_cast<MultiheadAttention>(blocks["self_attn"]);
|
||||||
auto layer_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm1"]);
|
auto layer_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm1"]);
|
||||||
auto layer_norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm2"]);
|
auto layer_norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm2"]);
|
||||||
auto mlp = std::dynamic_pointer_cast<CLIPMLP>(blocks["mlp"]);
|
auto mlp = std::dynamic_pointer_cast<CLIPMLP>(blocks["mlp"]);
|
||||||
|
|
||||||
x = ggml_add(ctx, x, self_attn->forward(ctx, backend, layer_norm1->forward(ctx, x), mask));
|
x = ggml_add(ctx->ggml_ctx, x, self_attn->forward(ctx, layer_norm1->forward(ctx, x), mask));
|
||||||
x = ggml_add(ctx, x, mlp->forward(ctx, layer_norm2->forward(ctx, x)));
|
x = ggml_add(ctx->ggml_ctx, x, mlp->forward(ctx, layer_norm2->forward(ctx, x)));
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -517,8 +517,7 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
int clip_skip = -1,
|
int clip_skip = -1,
|
||||||
bool mask = true) {
|
bool mask = true) {
|
||||||
@ -536,7 +535,7 @@ public:
|
|||||||
}
|
}
|
||||||
std::string name = "layers." + std::to_string(i);
|
std::string name = "layers." + std::to_string(i);
|
||||||
auto layer = std::dynamic_pointer_cast<CLIPLayer>(blocks[name]);
|
auto layer = std::dynamic_pointer_cast<CLIPLayer>(blocks[name]);
|
||||||
x = layer->forward(ctx, backend, x, mask); // [N, n_token, d_model]
|
x = layer->forward(ctx, x, mask); // [N, n_token, d_model]
|
||||||
// LOG_DEBUG("layer %d", i);
|
// LOG_DEBUG("layer %d", i);
|
||||||
}
|
}
|
||||||
return x;
|
return x;
|
||||||
@ -578,7 +577,7 @@ public:
|
|||||||
return params["token_embedding.weight"];
|
return params["token_embedding.weight"];
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* input_ids,
|
struct ggml_tensor* input_ids,
|
||||||
struct ggml_tensor* custom_embed_weight) {
|
struct ggml_tensor* custom_embed_weight) {
|
||||||
// input_ids: [N, n_token]
|
// input_ids: [N, n_token]
|
||||||
@ -586,12 +585,12 @@ public:
|
|||||||
auto position_embed_weight = params["position_embedding.weight"];
|
auto position_embed_weight = params["position_embedding.weight"];
|
||||||
|
|
||||||
GGML_ASSERT(input_ids->ne[0] == position_embed_weight->ne[1]);
|
GGML_ASSERT(input_ids->ne[0] == position_embed_weight->ne[1]);
|
||||||
input_ids = ggml_reshape_3d(ctx, input_ids, input_ids->ne[0], 1, input_ids->ne[1]);
|
input_ids = ggml_reshape_3d(ctx->ggml_ctx, input_ids, input_ids->ne[0], 1, input_ids->ne[1]);
|
||||||
auto token_embedding = ggml_get_rows(ctx, custom_embed_weight != nullptr ? custom_embed_weight : token_embed_weight, input_ids);
|
auto token_embedding = ggml_get_rows(ctx->ggml_ctx, custom_embed_weight != nullptr ? custom_embed_weight : token_embed_weight, input_ids);
|
||||||
token_embedding = ggml_reshape_3d(ctx, token_embedding, token_embedding->ne[0], token_embedding->ne[1], token_embedding->ne[3]);
|
token_embedding = ggml_reshape_3d(ctx->ggml_ctx, token_embedding, token_embedding->ne[0], token_embedding->ne[1], token_embedding->ne[3]);
|
||||||
|
|
||||||
// token_embedding + position_embedding
|
// token_embedding + position_embedding
|
||||||
auto x = ggml_add(ctx,
|
auto x = ggml_add(ctx->ggml_ctx,
|
||||||
token_embedding,
|
token_embedding,
|
||||||
position_embed_weight); // [N, n_token, embed_dim]
|
position_embed_weight); // [N, n_token, embed_dim]
|
||||||
return x;
|
return x;
|
||||||
@ -629,7 +628,7 @@ public:
|
|||||||
num_positions = num_patches + 1;
|
num_positions = num_patches + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* pixel_values) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* pixel_values) {
|
||||||
// pixel_values: [N, num_channels, image_size, image_size]
|
// pixel_values: [N, num_channels, image_size, image_size]
|
||||||
// return: [N, num_positions, embed_dim]
|
// return: [N, num_positions, embed_dim]
|
||||||
GGML_ASSERT(pixel_values->ne[0] == image_size && pixel_values->ne[1] == image_size && pixel_values->ne[2] == num_channels);
|
GGML_ASSERT(pixel_values->ne[0] == image_size && pixel_values->ne[1] == image_size && pixel_values->ne[2] == num_channels);
|
||||||
@ -641,18 +640,18 @@ public:
|
|||||||
// concat(patch_embedding, class_embedding) + position_embedding
|
// concat(patch_embedding, class_embedding) + position_embedding
|
||||||
struct ggml_tensor* patch_embedding;
|
struct ggml_tensor* patch_embedding;
|
||||||
int64_t N = pixel_values->ne[3];
|
int64_t N = pixel_values->ne[3];
|
||||||
patch_embedding = ggml_ext_conv_2d(ctx, pixel_values, patch_embed_weight, nullptr, patch_size, patch_size); // [N, embed_dim, image_size // pacht_size, image_size // pacht_size]
|
patch_embedding = ggml_ext_conv_2d(ctx->ggml_ctx, pixel_values, patch_embed_weight, nullptr, patch_size, patch_size); // [N, embed_dim, image_size // pacht_size, image_size // pacht_size]
|
||||||
patch_embedding = ggml_reshape_3d(ctx, patch_embedding, num_patches, embed_dim, N); // [N, embed_dim, num_patches]
|
patch_embedding = ggml_reshape_3d(ctx->ggml_ctx, patch_embedding, num_patches, embed_dim, N); // [N, embed_dim, num_patches]
|
||||||
patch_embedding = ggml_cont(ctx, ggml_permute(ctx, patch_embedding, 1, 0, 2, 3)); // [N, num_patches, embed_dim]
|
patch_embedding = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, patch_embedding, 1, 0, 2, 3)); // [N, num_patches, embed_dim]
|
||||||
patch_embedding = ggml_reshape_4d(ctx, patch_embedding, 1, embed_dim, num_patches, N); // [N, num_patches, embed_dim, 1]
|
patch_embedding = ggml_reshape_4d(ctx->ggml_ctx, patch_embedding, 1, embed_dim, num_patches, N); // [N, num_patches, embed_dim, 1]
|
||||||
|
|
||||||
struct ggml_tensor* class_embedding = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, embed_dim, N);
|
struct ggml_tensor* class_embedding = ggml_new_tensor_2d(ctx->ggml_ctx, GGML_TYPE_F32, embed_dim, N);
|
||||||
class_embedding = ggml_repeat(ctx, class_embed_weight, class_embedding); // [N, embed_dim]
|
class_embedding = ggml_repeat(ctx->ggml_ctx, class_embed_weight, class_embedding); // [N, embed_dim]
|
||||||
class_embedding = ggml_reshape_4d(ctx, class_embedding, 1, embed_dim, 1, N); // [N, 1, embed_dim, 1]
|
class_embedding = ggml_reshape_4d(ctx->ggml_ctx, class_embedding, 1, embed_dim, 1, N); // [N, 1, embed_dim, 1]
|
||||||
|
|
||||||
struct ggml_tensor* x = ggml_concat(ctx, class_embedding, patch_embedding, 2); // [N, num_positions, embed_dim, 1]
|
struct ggml_tensor* x = ggml_concat(ctx->ggml_ctx, class_embedding, patch_embedding, 2); // [N, num_positions, embed_dim, 1]
|
||||||
x = ggml_reshape_3d(ctx, x, embed_dim, num_positions, N); // [N, num_positions, embed_dim]
|
x = ggml_reshape_3d(ctx->ggml_ctx, x, embed_dim, num_positions, N); // [N, num_positions, embed_dim]
|
||||||
x = ggml_add(ctx, x, position_embed_weight);
|
x = ggml_add(ctx->ggml_ctx, x, position_embed_weight);
|
||||||
return x; // [N, num_positions, embed_dim]
|
return x; // [N, num_positions, embed_dim]
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -714,8 +713,7 @@ public:
|
|||||||
return embeddings->get_token_embed_weight();
|
return embeddings->get_token_embed_weight();
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* input_ids,
|
struct ggml_tensor* input_ids,
|
||||||
struct ggml_tensor* tkn_embeddings,
|
struct ggml_tensor* tkn_embeddings,
|
||||||
size_t max_token_idx = 0,
|
size_t max_token_idx = 0,
|
||||||
@ -727,16 +725,16 @@ public:
|
|||||||
auto final_layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["final_layer_norm"]);
|
auto final_layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["final_layer_norm"]);
|
||||||
|
|
||||||
auto x = embeddings->forward(ctx, input_ids, tkn_embeddings); // [N, n_token, hidden_size]
|
auto x = embeddings->forward(ctx, input_ids, tkn_embeddings); // [N, n_token, hidden_size]
|
||||||
x = encoder->forward(ctx, backend, x, return_pooled ? -1 : clip_skip, true);
|
x = encoder->forward(ctx, x, return_pooled ? -1 : clip_skip, true);
|
||||||
if (return_pooled || with_final_ln) {
|
if (return_pooled || with_final_ln) {
|
||||||
x = final_layer_norm->forward(ctx, x);
|
x = final_layer_norm->forward(ctx, x);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (return_pooled) {
|
if (return_pooled) {
|
||||||
auto text_projection = params["text_projection"];
|
auto text_projection = params["text_projection"];
|
||||||
ggml_tensor* pooled = ggml_view_1d(ctx, x, hidden_size, x->nb[1] * max_token_idx);
|
ggml_tensor* pooled = ggml_view_1d(ctx->ggml_ctx, x, hidden_size, x->nb[1] * max_token_idx);
|
||||||
if (text_projection != nullptr) {
|
if (text_projection != nullptr) {
|
||||||
pooled = ggml_ext_linear(ctx, pooled, text_projection, nullptr);
|
pooled = ggml_ext_linear(ctx->ggml_ctx, pooled, text_projection, nullptr);
|
||||||
} else {
|
} else {
|
||||||
LOG_DEBUG("identity projection");
|
LOG_DEBUG("identity projection");
|
||||||
}
|
}
|
||||||
@ -779,8 +777,7 @@ public:
|
|||||||
blocks["post_layernorm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size));
|
blocks["post_layernorm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* pixel_values,
|
struct ggml_tensor* pixel_values,
|
||||||
bool return_pooled = true,
|
bool return_pooled = true,
|
||||||
int clip_skip = -1) {
|
int clip_skip = -1) {
|
||||||
@ -792,14 +789,14 @@ public:
|
|||||||
|
|
||||||
auto x = embeddings->forward(ctx, pixel_values); // [N, num_positions, embed_dim]
|
auto x = embeddings->forward(ctx, pixel_values); // [N, num_positions, embed_dim]
|
||||||
x = pre_layernorm->forward(ctx, x);
|
x = pre_layernorm->forward(ctx, x);
|
||||||
x = encoder->forward(ctx, backend, x, clip_skip, false);
|
x = encoder->forward(ctx, x, clip_skip, false);
|
||||||
// print_ggml_tensor(x, true, "ClipVisionModel x: ");
|
// print_ggml_tensor(x, true, "ClipVisionModel x: ");
|
||||||
auto last_hidden_state = x;
|
auto last_hidden_state = x;
|
||||||
x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size]
|
x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size]
|
||||||
|
|
||||||
GGML_ASSERT(x->ne[3] == 1);
|
GGML_ASSERT(x->ne[3] == 1);
|
||||||
if (return_pooled) {
|
if (return_pooled) {
|
||||||
ggml_tensor* pooled = ggml_cont(ctx, ggml_view_2d(ctx, x, x->ne[0], x->ne[2], x->nb[2], 0));
|
ggml_tensor* pooled = ggml_cont(ctx->ggml_ctx, ggml_view_2d(ctx->ggml_ctx, x, x->ne[0], x->ne[2], x->nb[2], 0));
|
||||||
return pooled; // [N, hidden_size]
|
return pooled; // [N, hidden_size]
|
||||||
} else {
|
} else {
|
||||||
// return x; // [N, n_token, hidden_size]
|
// return x; // [N, n_token, hidden_size]
|
||||||
@ -831,12 +828,12 @@ public:
|
|||||||
out_features(out_features),
|
out_features(out_features),
|
||||||
transpose_weight(transpose_weight) {}
|
transpose_weight(transpose_weight) {}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||||
struct ggml_tensor* w = params["weight"];
|
struct ggml_tensor* w = params["weight"];
|
||||||
if (transpose_weight) {
|
if (transpose_weight) {
|
||||||
w = ggml_cont(ctx, ggml_transpose(ctx, w));
|
w = ggml_cont(ctx->ggml_ctx, ggml_transpose(ctx->ggml_ctx, w));
|
||||||
}
|
}
|
||||||
return ggml_ext_linear(ctx, x, w, nullptr);
|
return ggml_ext_linear(ctx->ggml_ctx, x, w, nullptr);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -860,8 +857,7 @@ public:
|
|||||||
blocks["visual_projection"] = std::shared_ptr<GGMLBlock>(new CLIPProjection(hidden_size, projection_dim, transpose_proj_w));
|
blocks["visual_projection"] = std::shared_ptr<GGMLBlock>(new CLIPProjection(hidden_size, projection_dim, transpose_proj_w));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* pixel_values,
|
struct ggml_tensor* pixel_values,
|
||||||
bool return_pooled = true,
|
bool return_pooled = true,
|
||||||
int clip_skip = -1) {
|
int clip_skip = -1) {
|
||||||
@ -870,7 +866,7 @@ public:
|
|||||||
auto vision_model = std::dynamic_pointer_cast<CLIPVisionModel>(blocks["vision_model"]);
|
auto vision_model = std::dynamic_pointer_cast<CLIPVisionModel>(blocks["vision_model"]);
|
||||||
auto visual_projection = std::dynamic_pointer_cast<CLIPProjection>(blocks["visual_projection"]);
|
auto visual_projection = std::dynamic_pointer_cast<CLIPProjection>(blocks["visual_projection"]);
|
||||||
|
|
||||||
auto x = vision_model->forward(ctx, backend, pixel_values, return_pooled, clip_skip); // [N, hidden_size] or [N, n_token, hidden_size]
|
auto x = vision_model->forward(ctx, pixel_values, return_pooled, clip_skip); // [N, hidden_size] or [N, n_token, hidden_size]
|
||||||
|
|
||||||
if (return_pooled) {
|
if (return_pooled) {
|
||||||
x = visual_projection->forward(ctx, x); // [N, projection_dim]
|
x = visual_projection->forward(ctx, x); // [N, projection_dim]
|
||||||
@ -902,8 +898,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
|
|||||||
model.get_param_tensors(tensors, prefix);
|
model.get_param_tensors(tensors, prefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* input_ids,
|
struct ggml_tensor* input_ids,
|
||||||
struct ggml_tensor* embeddings,
|
struct ggml_tensor* embeddings,
|
||||||
size_t max_token_idx = 0,
|
size_t max_token_idx = 0,
|
||||||
@ -913,10 +908,10 @@ struct CLIPTextModelRunner : public GGMLRunner {
|
|||||||
size_t n_token = input_ids->ne[0];
|
size_t n_token = input_ids->ne[0];
|
||||||
if (input_ids->ne[0] > model.n_token) {
|
if (input_ids->ne[0] > model.n_token) {
|
||||||
GGML_ASSERT(input_ids->ne[0] % model.n_token == 0);
|
GGML_ASSERT(input_ids->ne[0] % model.n_token == 0);
|
||||||
input_ids = ggml_reshape_2d(ctx, input_ids, model.n_token, input_ids->ne[0] / model.n_token);
|
input_ids = ggml_reshape_2d(ctx->ggml_ctx, input_ids, model.n_token, input_ids->ne[0] / model.n_token);
|
||||||
}
|
}
|
||||||
|
|
||||||
return model.forward(ctx, backend, input_ids, embeddings, max_token_idx, return_pooled, clip_skip);
|
return model.forward(ctx, input_ids, embeddings, max_token_idx, return_pooled, clip_skip);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids,
|
struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids,
|
||||||
@ -943,7 +938,9 @@ struct CLIPTextModelRunner : public GGMLRunner {
|
|||||||
embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 1);
|
embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* hidden_states = forward(compute_ctx, runtime_backend, input_ids, embeddings, max_token_idx, return_pooled, clip_skip);
|
auto runner_ctx = get_context();
|
||||||
|
|
||||||
|
struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, embeddings, max_token_idx, return_pooled, clip_skip);
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, hidden_states);
|
ggml_build_forward_expand(gf, hidden_states);
|
||||||
|
|
||||||
|
|||||||
126
common.hpp
126
common.hpp
@ -23,12 +23,12 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
// x: [N, channels, h, w]
|
// x: [N, channels, h, w]
|
||||||
if (vae_downsample) {
|
if (vae_downsample) {
|
||||||
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
|
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
|
||||||
|
|
||||||
x = ggml_pad(ctx, x, 1, 1, 0, 0);
|
x = ggml_pad(ctx->ggml_ctx, x, 1, 1, 0, 0);
|
||||||
x = conv->forward(ctx, x);
|
x = conv->forward(ctx, x);
|
||||||
} else {
|
} else {
|
||||||
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["op"]);
|
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["op"]);
|
||||||
@ -52,12 +52,12 @@ public:
|
|||||||
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
|
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
// x: [N, channels, h, w]
|
// x: [N, channels, h, w]
|
||||||
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
|
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
|
||||||
|
|
||||||
x = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST); // [N, channels, h*2, w*2]
|
x = ggml_upscale(ctx->ggml_ctx, x, 2, GGML_SCALE_MODE_NEAREST); // [N, channels, h*2, w*2]
|
||||||
x = conv->forward(ctx, x); // [N, out_channels, h*2, w*2]
|
x = conv->forward(ctx, x); // [N, out_channels, h*2, w*2]
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -121,7 +121,7 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* emb = nullptr) {
|
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* emb = nullptr) {
|
||||||
// For dims==3, we reduce dimension from 5d to 4d by merging h and w, in order not to change ggml
|
// For dims==3, we reduce dimension from 5d to 4d by merging h and w, in order not to change ggml
|
||||||
// [N, c, t, h, w] => [N, c, t, h * w]
|
// [N, c, t, h, w] => [N, c, t, h * w]
|
||||||
// x: [N, channels, h, w] if dims == 2 else [N, channels, t, h, w]
|
// x: [N, channels, h, w] if dims == 2 else [N, channels, t, h, w]
|
||||||
@ -137,32 +137,32 @@ public:
|
|||||||
|
|
||||||
// in_layers
|
// in_layers
|
||||||
auto h = in_layers_0->forward(ctx, x);
|
auto h = in_layers_0->forward(ctx, x);
|
||||||
h = ggml_silu_inplace(ctx, h);
|
h = ggml_silu_inplace(ctx->ggml_ctx, h);
|
||||||
h = in_layers_2->forward(ctx, h); // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w]
|
h = in_layers_2->forward(ctx, h); // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w]
|
||||||
|
|
||||||
// emb_layers
|
// emb_layers
|
||||||
if (!skip_t_emb) {
|
if (!skip_t_emb) {
|
||||||
auto emb_layer_1 = std::dynamic_pointer_cast<Linear>(blocks["emb_layers.1"]);
|
auto emb_layer_1 = std::dynamic_pointer_cast<Linear>(blocks["emb_layers.1"]);
|
||||||
|
|
||||||
auto emb_out = ggml_silu(ctx, emb);
|
auto emb_out = ggml_silu(ctx->ggml_ctx, emb);
|
||||||
emb_out = emb_layer_1->forward(ctx, emb_out); // [N, out_channels] if dims == 2 else [N, t, out_channels]
|
emb_out = emb_layer_1->forward(ctx, emb_out); // [N, out_channels] if dims == 2 else [N, t, out_channels]
|
||||||
|
|
||||||
if (dims == 2) {
|
if (dims == 2) {
|
||||||
emb_out = ggml_reshape_4d(ctx, emb_out, 1, 1, emb_out->ne[0], emb_out->ne[1]); // [N, out_channels, 1, 1]
|
emb_out = ggml_reshape_4d(ctx->ggml_ctx, emb_out, 1, 1, emb_out->ne[0], emb_out->ne[1]); // [N, out_channels, 1, 1]
|
||||||
} else {
|
} else {
|
||||||
emb_out = ggml_reshape_4d(ctx, emb_out, 1, emb_out->ne[0], emb_out->ne[1], emb_out->ne[2]); // [N, t, out_channels, 1]
|
emb_out = ggml_reshape_4d(ctx->ggml_ctx, emb_out, 1, emb_out->ne[0], emb_out->ne[1], emb_out->ne[2]); // [N, t, out_channels, 1]
|
||||||
if (exchange_temb_dims) {
|
if (exchange_temb_dims) {
|
||||||
// emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
|
// emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
|
||||||
emb_out = ggml_cont(ctx, ggml_permute(ctx, emb_out, 0, 2, 1, 3)); // [N, out_channels, t, 1]
|
emb_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, emb_out, 0, 2, 1, 3)); // [N, out_channels, t, 1]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
h = ggml_add(ctx, h, emb_out); // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w]
|
h = ggml_add(ctx->ggml_ctx, h, emb_out); // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w]
|
||||||
}
|
}
|
||||||
|
|
||||||
// out_layers
|
// out_layers
|
||||||
h = out_layers_0->forward(ctx, h);
|
h = out_layers_0->forward(ctx, h);
|
||||||
h = ggml_silu_inplace(ctx, h);
|
h = ggml_silu_inplace(ctx->ggml_ctx, h);
|
||||||
// dropout, skip for inference
|
// dropout, skip for inference
|
||||||
h = out_layers_3->forward(ctx, h);
|
h = out_layers_3->forward(ctx, h);
|
||||||
|
|
||||||
@ -172,7 +172,7 @@ public:
|
|||||||
x = skip_connection->forward(ctx, x); // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w]
|
x = skip_connection->forward(ctx, x); // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w]
|
||||||
}
|
}
|
||||||
|
|
||||||
h = ggml_add(ctx, h, x);
|
h = ggml_add(ctx->ggml_ctx, h, x);
|
||||||
return h; // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w]
|
return h; // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w]
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -193,24 +193,24 @@ public:
|
|||||||
GEGLU(int64_t dim_in, int64_t dim_out)
|
GEGLU(int64_t dim_in, int64_t dim_out)
|
||||||
: dim_in(dim_in), dim_out(dim_out) {}
|
: dim_in(dim_in), dim_out(dim_out) {}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||||
// x: [ne3, ne2, ne1, dim_in]
|
// x: [ne3, ne2, ne1, dim_in]
|
||||||
// return: [ne3, ne2, ne1, dim_out]
|
// return: [ne3, ne2, ne1, dim_out]
|
||||||
struct ggml_tensor* w = params["proj.weight"];
|
struct ggml_tensor* w = params["proj.weight"];
|
||||||
struct ggml_tensor* b = params["proj.bias"];
|
struct ggml_tensor* b = params["proj.bias"];
|
||||||
|
|
||||||
auto x_w = ggml_view_2d(ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], 0); // [dim_out, dim_in]
|
auto x_w = ggml_view_2d(ctx->ggml_ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], 0); // [dim_out, dim_in]
|
||||||
auto x_b = ggml_view_1d(ctx, b, b->ne[0] / 2, 0); // [dim_out, dim_in]
|
auto x_b = ggml_view_1d(ctx->ggml_ctx, b, b->ne[0] / 2, 0); // [dim_out, dim_in]
|
||||||
auto gate_w = ggml_view_2d(ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], w->nb[1] * w->ne[1] / 2); // [dim_out, ]
|
auto gate_w = ggml_view_2d(ctx->ggml_ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], w->nb[1] * w->ne[1] / 2); // [dim_out, ]
|
||||||
auto gate_b = ggml_view_1d(ctx, b, b->ne[0] / 2, b->nb[0] * b->ne[0] / 2); // [dim_out, ]
|
auto gate_b = ggml_view_1d(ctx->ggml_ctx, b, b->ne[0] / 2, b->nb[0] * b->ne[0] / 2); // [dim_out, ]
|
||||||
|
|
||||||
auto x_in = x;
|
auto x_in = x;
|
||||||
x = ggml_ext_linear(ctx, x_in, x_w, x_b); // [ne3, ne2, ne1, dim_out]
|
x = ggml_ext_linear(ctx->ggml_ctx, x_in, x_w, x_b); // [ne3, ne2, ne1, dim_out]
|
||||||
auto gate = ggml_ext_linear(ctx, x_in, gate_w, gate_b); // [ne3, ne2, ne1, dim_out]
|
auto gate = ggml_ext_linear(ctx->ggml_ctx, x_in, gate_w, gate_b); // [ne3, ne2, ne1, dim_out]
|
||||||
|
|
||||||
gate = ggml_gelu_inplace(ctx, gate);
|
gate = ggml_gelu_inplace(ctx->ggml_ctx, gate);
|
||||||
|
|
||||||
x = ggml_mul(ctx, x, gate); // [ne3, ne2, ne1, dim_out]
|
x = ggml_mul(ctx->ggml_ctx, x, gate); // [ne3, ne2, ne1, dim_out]
|
||||||
|
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
@ -222,13 +222,13 @@ public:
|
|||||||
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim_in, dim_out, bias));
|
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim_in, dim_out, bias));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||||
// x: [ne3, ne2, ne1, dim_in]
|
// x: [ne3, ne2, ne1, dim_in]
|
||||||
// return: [ne3, ne2, ne1, dim_out]
|
// return: [ne3, ne2, ne1, dim_out]
|
||||||
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
|
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
|
||||||
|
|
||||||
x = proj->forward(ctx, x);
|
x = proj->forward(ctx, x);
|
||||||
x = ggml_gelu_inplace(ctx, x);
|
x = ggml_gelu_inplace(ctx->ggml_ctx, x);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -262,7 +262,7 @@ public:
|
|||||||
blocks["net.2"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim_out, true, false, false, scale));
|
blocks["net.2"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim_out, true, false, false, scale));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
// x: [ne3, ne2, ne1, dim]
|
// x: [ne3, ne2, ne1, dim]
|
||||||
// return: [ne3, ne2, ne1, dim_out]
|
// return: [ne3, ne2, ne1, dim_out]
|
||||||
|
|
||||||
@ -281,19 +281,16 @@ protected:
|
|||||||
int64_t context_dim;
|
int64_t context_dim;
|
||||||
int64_t n_head;
|
int64_t n_head;
|
||||||
int64_t d_head;
|
int64_t d_head;
|
||||||
bool flash_attn;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
CrossAttention(int64_t query_dim,
|
CrossAttention(int64_t query_dim,
|
||||||
int64_t context_dim,
|
int64_t context_dim,
|
||||||
int64_t n_head,
|
int64_t n_head,
|
||||||
int64_t d_head,
|
int64_t d_head)
|
||||||
bool flash_attn = false)
|
|
||||||
: n_head(n_head),
|
: n_head(n_head),
|
||||||
d_head(d_head),
|
d_head(d_head),
|
||||||
query_dim(query_dim),
|
query_dim(query_dim),
|
||||||
context_dim(context_dim),
|
context_dim(context_dim) {
|
||||||
flash_attn(flash_attn) {
|
|
||||||
int64_t inner_dim = d_head * n_head;
|
int64_t inner_dim = d_head * n_head;
|
||||||
|
|
||||||
blocks["to_q"] = std::shared_ptr<GGMLBlock>(new Linear(query_dim, inner_dim, false));
|
blocks["to_q"] = std::shared_ptr<GGMLBlock>(new Linear(query_dim, inner_dim, false));
|
||||||
@ -304,8 +301,7 @@ public:
|
|||||||
// to_out_1 is nn.Dropout(), skip for inference
|
// to_out_1 is nn.Dropout(), skip for inference
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* context) {
|
struct ggml_tensor* context) {
|
||||||
// x: [N, n_token, query_dim]
|
// x: [N, n_token, query_dim]
|
||||||
@ -325,7 +321,7 @@ public:
|
|||||||
auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim]
|
auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim]
|
||||||
auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim]
|
auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim]
|
||||||
|
|
||||||
x = ggml_ext_attention_ext(ctx, backend, q, k, v, n_head, nullptr, false, false, flash_attn); // [N, n_token, inner_dim]
|
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, inner_dim]
|
||||||
|
|
||||||
x = to_out_0->forward(ctx, x); // [N, n_token, query_dim]
|
x = to_out_0->forward(ctx, x); // [N, n_token, query_dim]
|
||||||
return x;
|
return x;
|
||||||
@ -343,16 +339,15 @@ public:
|
|||||||
int64_t n_head,
|
int64_t n_head,
|
||||||
int64_t d_head,
|
int64_t d_head,
|
||||||
int64_t context_dim,
|
int64_t context_dim,
|
||||||
bool ff_in = false,
|
bool ff_in = false)
|
||||||
bool flash_attn = false)
|
|
||||||
: n_head(n_head), d_head(d_head), ff_in(ff_in) {
|
: n_head(n_head), d_head(d_head), ff_in(ff_in) {
|
||||||
// disable_self_attn is always False
|
// disable_self_attn is always False
|
||||||
// disable_temporal_crossattention is always False
|
// disable_temporal_crossattention is always False
|
||||||
// switch_temporal_ca_to_sa is always False
|
// switch_temporal_ca_to_sa is always False
|
||||||
// inner_dim is always None or equal to dim
|
// inner_dim is always None or equal to dim
|
||||||
// gated_ff is always True
|
// gated_ff is always True
|
||||||
blocks["attn1"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, dim, n_head, d_head, flash_attn));
|
blocks["attn1"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, dim, n_head, d_head));
|
||||||
blocks["attn2"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, context_dim, n_head, d_head, flash_attn));
|
blocks["attn2"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, context_dim, n_head, d_head));
|
||||||
blocks["ff"] = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim));
|
blocks["ff"] = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim));
|
||||||
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
|
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
|
||||||
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
|
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
|
||||||
@ -364,8 +359,7 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* context) {
|
struct ggml_tensor* context) {
|
||||||
// x: [N, n_token, query_dim]
|
// x: [N, n_token, query_dim]
|
||||||
@ -387,21 +381,21 @@ public:
|
|||||||
x = norm_in->forward(ctx, x);
|
x = norm_in->forward(ctx, x);
|
||||||
x = ff_in->forward(ctx, x);
|
x = ff_in->forward(ctx, x);
|
||||||
// self.is_res is always True
|
// self.is_res is always True
|
||||||
x = ggml_add(ctx, x, x_skip);
|
x = ggml_add(ctx->ggml_ctx, x, x_skip);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto r = x;
|
auto r = x;
|
||||||
x = norm1->forward(ctx, x);
|
x = norm1->forward(ctx, x);
|
||||||
x = attn1->forward(ctx, backend, x, x); // self-attention
|
x = attn1->forward(ctx, x, x); // self-attention
|
||||||
x = ggml_add(ctx, x, r);
|
x = ggml_add(ctx->ggml_ctx, x, r);
|
||||||
r = x;
|
r = x;
|
||||||
x = norm2->forward(ctx, x);
|
x = norm2->forward(ctx, x);
|
||||||
x = attn2->forward(ctx, backend, x, context); // cross-attention
|
x = attn2->forward(ctx, x, context); // cross-attention
|
||||||
x = ggml_add(ctx, x, r);
|
x = ggml_add(ctx->ggml_ctx, x, r);
|
||||||
r = x;
|
r = x;
|
||||||
x = norm3->forward(ctx, x);
|
x = norm3->forward(ctx, x);
|
||||||
x = ff->forward(ctx, x);
|
x = ff->forward(ctx, x);
|
||||||
x = ggml_add(ctx, x, r);
|
x = ggml_add(ctx->ggml_ctx, x, r);
|
||||||
|
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
@ -420,8 +414,7 @@ public:
|
|||||||
int64_t n_head,
|
int64_t n_head,
|
||||||
int64_t d_head,
|
int64_t d_head,
|
||||||
int64_t depth,
|
int64_t depth,
|
||||||
int64_t context_dim,
|
int64_t context_dim)
|
||||||
bool flash_attn = false)
|
|
||||||
: in_channels(in_channels),
|
: in_channels(in_channels),
|
||||||
n_head(n_head),
|
n_head(n_head),
|
||||||
d_head(d_head),
|
d_head(d_head),
|
||||||
@ -435,14 +428,13 @@ public:
|
|||||||
|
|
||||||
for (int i = 0; i < depth; i++) {
|
for (int i = 0; i < depth; i++) {
|
||||||
std::string name = "transformer_blocks." + std::to_string(i);
|
std::string name = "transformer_blocks." + std::to_string(i);
|
||||||
blocks[name] = std::shared_ptr<GGMLBlock>(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim, false, flash_attn));
|
blocks[name] = std::shared_ptr<GGMLBlock>(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim, false));
|
||||||
}
|
}
|
||||||
|
|
||||||
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1}));
|
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1}));
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual struct ggml_tensor* forward(struct ggml_context* ctx,
|
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* context) {
|
struct ggml_tensor* context) {
|
||||||
// x: [N, in_channels, h, w]
|
// x: [N, in_channels, h, w]
|
||||||
@ -460,23 +452,23 @@ public:
|
|||||||
x = norm->forward(ctx, x);
|
x = norm->forward(ctx, x);
|
||||||
x = proj_in->forward(ctx, x); // [N, inner_dim, h, w]
|
x = proj_in->forward(ctx, x); // [N, inner_dim, h, w]
|
||||||
|
|
||||||
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 2, 0, 3)); // [N, h, w, inner_dim]
|
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 2, 0, 3)); // [N, h, w, inner_dim]
|
||||||
x = ggml_reshape_3d(ctx, x, inner_dim, w * h, n); // [N, h * w, inner_dim]
|
x = ggml_reshape_3d(ctx->ggml_ctx, x, inner_dim, w * h, n); // [N, h * w, inner_dim]
|
||||||
|
|
||||||
for (int i = 0; i < depth; i++) {
|
for (int i = 0; i < depth; i++) {
|
||||||
std::string name = "transformer_blocks." + std::to_string(i);
|
std::string name = "transformer_blocks." + std::to_string(i);
|
||||||
auto transformer_block = std::dynamic_pointer_cast<BasicTransformerBlock>(blocks[name]);
|
auto transformer_block = std::dynamic_pointer_cast<BasicTransformerBlock>(blocks[name]);
|
||||||
|
|
||||||
x = transformer_block->forward(ctx, backend, x, context);
|
x = transformer_block->forward(ctx, x, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, inner_dim, h * w]
|
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [N, inner_dim, h * w]
|
||||||
x = ggml_reshape_4d(ctx, x, w, h, inner_dim, n); // [N, inner_dim, h, w]
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, inner_dim, n); // [N, inner_dim, h, w]
|
||||||
|
|
||||||
// proj_out
|
// proj_out
|
||||||
x = proj_out->forward(ctx, x); // [N, in_channels, h, w]
|
x = proj_out->forward(ctx, x); // [N, in_channels, h, w]
|
||||||
|
|
||||||
x = ggml_add(ctx, x, x_in);
|
x = ggml_add(ctx->ggml_ctx, x, x_in);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -503,14 +495,14 @@ public:
|
|||||||
// since mix_factor.shape is [1,], we don't need rearrange using rearrange_pattern
|
// since mix_factor.shape is [1,], we don't need rearrange using rearrange_pattern
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x_spatial,
|
struct ggml_tensor* x_spatial,
|
||||||
struct ggml_tensor* x_temporal) {
|
struct ggml_tensor* x_temporal) {
|
||||||
// image_only_indicator is always tensor([0.])
|
// image_only_indicator is always tensor([0.])
|
||||||
float alpha = get_alpha();
|
float alpha = get_alpha();
|
||||||
auto x = ggml_add(ctx,
|
auto x = ggml_add(ctx->ggml_ctx,
|
||||||
ggml_scale(ctx, x_spatial, alpha),
|
ggml_scale(ctx->ggml_ctx, x_spatial, alpha),
|
||||||
ggml_scale(ctx, x_temporal, 1.0f - alpha));
|
ggml_scale(ctx->ggml_ctx, x_temporal, 1.0f - alpha));
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -528,7 +520,7 @@ public:
|
|||||||
blocks["time_mixer"] = std::shared_ptr<GGMLBlock>(new AlphaBlender());
|
blocks["time_mixer"] = std::shared_ptr<GGMLBlock>(new AlphaBlender());
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* emb,
|
struct ggml_tensor* emb,
|
||||||
int num_video_frames) {
|
int num_video_frames) {
|
||||||
@ -546,18 +538,18 @@ public:
|
|||||||
int64_t H = x->ne[1];
|
int64_t H = x->ne[1];
|
||||||
int64_t W = x->ne[0];
|
int64_t W = x->ne[0];
|
||||||
|
|
||||||
x = ggml_reshape_4d(ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w)
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w)
|
||||||
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
|
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
|
||||||
auto x_mix = x;
|
auto x_mix = x;
|
||||||
|
|
||||||
emb = ggml_reshape_4d(ctx, emb, emb->ne[0], T, B, emb->ne[3]); // (b t) ... -> b t ...
|
emb = ggml_reshape_4d(ctx->ggml_ctx, emb, emb->ne[0], T, B, emb->ne[3]); // (b t) ... -> b t ...
|
||||||
|
|
||||||
x = time_stack->forward(ctx, x, emb); // b t c (h w)
|
x = time_stack->forward(ctx, x, emb); // b t c (h w)
|
||||||
|
|
||||||
x = time_mixer->forward(ctx, x_mix, x); // b t c (h w)
|
x = time_mixer->forward(ctx, x_mix, x); // b t c (h w)
|
||||||
|
|
||||||
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
|
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
|
||||||
x = ggml_reshape_4d(ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
|
||||||
|
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -641,7 +641,9 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
|
|||||||
|
|
||||||
pixel_values = to_backend(pixel_values);
|
pixel_values = to_backend(pixel_values);
|
||||||
|
|
||||||
struct ggml_tensor* hidden_states = vision_model.forward(compute_ctx, runtime_backend, pixel_values, return_pooled, clip_skip);
|
auto runner_ctx = get_context();
|
||||||
|
|
||||||
|
struct ggml_tensor* hidden_states = vision_model.forward(&runner_ctx, pixel_values, return_pooled, clip_skip);
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, hidden_states);
|
ggml_build_forward_expand(gf, hidden_states);
|
||||||
|
|
||||||
|
|||||||
52
control.hpp
52
control.hpp
@ -165,7 +165,7 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* resblock_forward(std::string name,
|
struct ggml_tensor* resblock_forward(std::string name,
|
||||||
struct ggml_context* ctx,
|
GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* emb) {
|
struct ggml_tensor* emb) {
|
||||||
auto block = std::dynamic_pointer_cast<ResBlock>(blocks[name]);
|
auto block = std::dynamic_pointer_cast<ResBlock>(blocks[name]);
|
||||||
@ -173,15 +173,14 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* attention_layer_forward(std::string name,
|
struct ggml_tensor* attention_layer_forward(std::string name,
|
||||||
struct ggml_context* ctx,
|
GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* context) {
|
struct ggml_tensor* context) {
|
||||||
auto block = std::dynamic_pointer_cast<SpatialTransformer>(blocks[name]);
|
auto block = std::dynamic_pointer_cast<SpatialTransformer>(blocks[name]);
|
||||||
return block->forward(ctx, backend, x, context);
|
return block->forward(ctx, x, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* input_hint_block_forward(struct ggml_context* ctx,
|
struct ggml_tensor* input_hint_block_forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* hint,
|
struct ggml_tensor* hint,
|
||||||
struct ggml_tensor* emb,
|
struct ggml_tensor* emb,
|
||||||
struct ggml_tensor* context) {
|
struct ggml_tensor* context) {
|
||||||
@ -193,14 +192,13 @@ public:
|
|||||||
|
|
||||||
h = block->forward(ctx, h);
|
h = block->forward(ctx, h);
|
||||||
} else {
|
} else {
|
||||||
h = ggml_silu_inplace(ctx, h);
|
h = ggml_silu_inplace(ctx->ggml_ctx, h);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return h;
|
return h;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<struct ggml_tensor*> forward(struct ggml_context* ctx,
|
std::vector<struct ggml_tensor*> forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* hint,
|
struct ggml_tensor* hint,
|
||||||
struct ggml_tensor* guided_hint,
|
struct ggml_tensor* guided_hint,
|
||||||
@ -213,13 +211,13 @@ public:
|
|||||||
// y: [N, adm_in_channels] or [1, adm_in_channels]
|
// y: [N, adm_in_channels] or [1, adm_in_channels]
|
||||||
if (context != nullptr) {
|
if (context != nullptr) {
|
||||||
if (context->ne[2] != x->ne[3]) {
|
if (context->ne[2] != x->ne[3]) {
|
||||||
context = ggml_repeat(ctx, context, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, context->ne[0], context->ne[1], x->ne[3]));
|
context = ggml_repeat(ctx->ggml_ctx, context, ggml_new_tensor_3d(ctx->ggml_ctx, GGML_TYPE_F32, context->ne[0], context->ne[1], x->ne[3]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (y != nullptr) {
|
if (y != nullptr) {
|
||||||
if (y->ne[1] != x->ne[3]) {
|
if (y->ne[1] != x->ne[3]) {
|
||||||
y = ggml_repeat(ctx, y, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, y->ne[0], x->ne[3]));
|
y = ggml_repeat(ctx->ggml_ctx, y, ggml_new_tensor_2d(ctx->ggml_ctx, GGML_TYPE_F32, y->ne[0], x->ne[3]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -230,10 +228,10 @@ public:
|
|||||||
|
|
||||||
auto middle_block_out = std::dynamic_pointer_cast<Conv2d>(blocks["middle_block_out.0"]);
|
auto middle_block_out = std::dynamic_pointer_cast<Conv2d>(blocks["middle_block_out.0"]);
|
||||||
|
|
||||||
auto t_emb = ggml_ext_timestep_embedding(ctx, timesteps, model_channels); // [N, model_channels]
|
auto t_emb = ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, model_channels); // [N, model_channels]
|
||||||
|
|
||||||
auto emb = time_embed_0->forward(ctx, t_emb);
|
auto emb = time_embed_0->forward(ctx, t_emb);
|
||||||
emb = ggml_silu_inplace(ctx, emb);
|
emb = ggml_silu_inplace(ctx->ggml_ctx, emb);
|
||||||
emb = time_embed_2->forward(ctx, emb); // [N, time_embed_dim]
|
emb = time_embed_2->forward(ctx, emb); // [N, time_embed_dim]
|
||||||
|
|
||||||
// SDXL/SVD
|
// SDXL/SVD
|
||||||
@ -242,10 +240,10 @@ public:
|
|||||||
auto label_embed_2 = std::dynamic_pointer_cast<Linear>(blocks["label_emb.0.2"]);
|
auto label_embed_2 = std::dynamic_pointer_cast<Linear>(blocks["label_emb.0.2"]);
|
||||||
|
|
||||||
auto label_emb = label_embed_0->forward(ctx, y);
|
auto label_emb = label_embed_0->forward(ctx, y);
|
||||||
label_emb = ggml_silu_inplace(ctx, label_emb);
|
label_emb = ggml_silu_inplace(ctx->ggml_ctx, label_emb);
|
||||||
label_emb = label_embed_2->forward(ctx, label_emb); // [N, time_embed_dim]
|
label_emb = label_embed_2->forward(ctx, label_emb); // [N, time_embed_dim]
|
||||||
|
|
||||||
emb = ggml_add(ctx, emb, label_emb); // [N, time_embed_dim]
|
emb = ggml_add(ctx->ggml_ctx, emb, label_emb); // [N, time_embed_dim]
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<struct ggml_tensor*> outs;
|
std::vector<struct ggml_tensor*> outs;
|
||||||
@ -259,7 +257,7 @@ public:
|
|||||||
|
|
||||||
// input block 0
|
// input block 0
|
||||||
auto h = input_blocks_0_0->forward(ctx, x);
|
auto h = input_blocks_0_0->forward(ctx, x);
|
||||||
h = ggml_add(ctx, h, guided_hint);
|
h = ggml_add(ctx->ggml_ctx, h, guided_hint);
|
||||||
outs.push_back(zero_convs_0->forward(ctx, h));
|
outs.push_back(zero_convs_0->forward(ctx, h));
|
||||||
|
|
||||||
// input block 1-11
|
// input block 1-11
|
||||||
@ -274,7 +272,7 @@ public:
|
|||||||
h = resblock_forward(name, ctx, h, emb); // [N, mult*model_channels, h, w]
|
h = resblock_forward(name, ctx, h, emb); // [N, mult*model_channels, h, w]
|
||||||
if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
|
if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
|
||||||
std::string name = "input_blocks." + std::to_string(input_block_idx) + ".1";
|
std::string name = "input_blocks." + std::to_string(input_block_idx) + ".1";
|
||||||
h = attention_layer_forward(name, ctx, backend, h, context); // [N, mult*model_channels, h, w]
|
h = attention_layer_forward(name, ctx, h, context); // [N, mult*model_channels, h, w]
|
||||||
}
|
}
|
||||||
|
|
||||||
auto zero_conv = std::dynamic_pointer_cast<Conv2d>(blocks["zero_convs." + std::to_string(input_block_idx) + ".0"]);
|
auto zero_conv = std::dynamic_pointer_cast<Conv2d>(blocks["zero_convs." + std::to_string(input_block_idx) + ".0"]);
|
||||||
@ -298,9 +296,9 @@ public:
|
|||||||
// [N, 4*model_channels, h/8, w/8]
|
// [N, 4*model_channels, h/8, w/8]
|
||||||
|
|
||||||
// middle_block
|
// middle_block
|
||||||
h = resblock_forward("middle_block.0", ctx, h, emb); // [N, 4*model_channels, h/8, w/8]
|
h = resblock_forward("middle_block.0", ctx, h, emb); // [N, 4*model_channels, h/8, w/8]
|
||||||
h = attention_layer_forward("middle_block.1", ctx, backend, h, context); // [N, 4*model_channels, h/8, w/8]
|
h = attention_layer_forward("middle_block.1", ctx, h, context); // [N, 4*model_channels, h/8, w/8]
|
||||||
h = resblock_forward("middle_block.2", ctx, h, emb); // [N, 4*model_channels, h/8, w/8]
|
h = resblock_forward("middle_block.2", ctx, h, emb); // [N, 4*model_channels, h/8, w/8]
|
||||||
|
|
||||||
// out
|
// out
|
||||||
outs.push_back(middle_block_out->forward(ctx, h));
|
outs.push_back(middle_block_out->forward(ctx, h));
|
||||||
@ -326,17 +324,6 @@ struct ControlNet : public GGMLRunner {
|
|||||||
control_net.init(params_ctx, tensor_types, "");
|
control_net.init(params_ctx, tensor_types, "");
|
||||||
}
|
}
|
||||||
|
|
||||||
void enable_conv2d_direct() {
|
|
||||||
std::vector<GGMLBlock*> blocks;
|
|
||||||
control_net.get_all_blocks(blocks);
|
|
||||||
for (auto block : blocks) {
|
|
||||||
if (block->get_desc() == "Conv2d") {
|
|
||||||
auto conv_block = (Conv2d*)block;
|
|
||||||
conv_block->enable_direct();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
~ControlNet() override {
|
~ControlNet() override {
|
||||||
free_control_ctx();
|
free_control_ctx();
|
||||||
}
|
}
|
||||||
@ -404,8 +391,9 @@ struct ControlNet : public GGMLRunner {
|
|||||||
y = to_backend(y);
|
y = to_backend(y);
|
||||||
timesteps = to_backend(timesteps);
|
timesteps = to_backend(timesteps);
|
||||||
|
|
||||||
auto outs = control_net.forward(compute_ctx,
|
auto runner_ctx = get_context();
|
||||||
runtime_backend,
|
|
||||||
|
auto outs = control_net.forward(&runner_ctx,
|
||||||
x,
|
x,
|
||||||
hint,
|
hint,
|
||||||
guided_hint_cached ? guided_hint : nullptr,
|
guided_hint_cached ? guided_hint : nullptr,
|
||||||
|
|||||||
@ -36,6 +36,7 @@ struct DiffusionModel {
|
|||||||
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
|
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
|
||||||
virtual size_t get_params_buffer_size() = 0;
|
virtual size_t get_params_buffer_size() = 0;
|
||||||
virtual int64_t get_adm_in_channels() = 0;
|
virtual int64_t get_adm_in_channels() = 0;
|
||||||
|
virtual void set_flash_attn_enabled(bool enabled) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct UNetModel : public DiffusionModel {
|
struct UNetModel : public DiffusionModel {
|
||||||
@ -44,9 +45,8 @@ struct UNetModel : public DiffusionModel {
|
|||||||
UNetModel(ggml_backend_t backend,
|
UNetModel(ggml_backend_t backend,
|
||||||
bool offload_params_to_cpu,
|
bool offload_params_to_cpu,
|
||||||
const String2GGMLType& tensor_types = {},
|
const String2GGMLType& tensor_types = {},
|
||||||
SDVersion version = VERSION_SD1,
|
SDVersion version = VERSION_SD1)
|
||||||
bool flash_attn = false)
|
: unet(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version) {
|
||||||
: unet(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, flash_attn) {
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string get_desc() override {
|
std::string get_desc() override {
|
||||||
@ -77,6 +77,10 @@ struct UNetModel : public DiffusionModel {
|
|||||||
return unet.unet.adm_in_channels;
|
return unet.unet.adm_in_channels;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void set_flash_attn_enabled(bool enabled) {
|
||||||
|
unet.set_flash_attention_enabled(enabled);
|
||||||
|
}
|
||||||
|
|
||||||
void compute(int n_threads,
|
void compute(int n_threads,
|
||||||
DiffusionParams diffusion_params,
|
DiffusionParams diffusion_params,
|
||||||
struct ggml_tensor** output = nullptr,
|
struct ggml_tensor** output = nullptr,
|
||||||
@ -98,9 +102,8 @@ struct MMDiTModel : public DiffusionModel {
|
|||||||
|
|
||||||
MMDiTModel(ggml_backend_t backend,
|
MMDiTModel(ggml_backend_t backend,
|
||||||
bool offload_params_to_cpu,
|
bool offload_params_to_cpu,
|
||||||
bool flash_attn = false,
|
|
||||||
const String2GGMLType& tensor_types = {})
|
const String2GGMLType& tensor_types = {})
|
||||||
: mmdit(backend, offload_params_to_cpu, flash_attn, tensor_types, "model.diffusion_model") {
|
: mmdit(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model") {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string get_desc() override {
|
std::string get_desc() override {
|
||||||
@ -131,6 +134,10 @@ struct MMDiTModel : public DiffusionModel {
|
|||||||
return 768 + 1280;
|
return 768 + 1280;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void set_flash_attn_enabled(bool enabled) {
|
||||||
|
mmdit.set_flash_attention_enabled(enabled);
|
||||||
|
}
|
||||||
|
|
||||||
void compute(int n_threads,
|
void compute(int n_threads,
|
||||||
DiffusionParams diffusion_params,
|
DiffusionParams diffusion_params,
|
||||||
struct ggml_tensor** output = nullptr,
|
struct ggml_tensor** output = nullptr,
|
||||||
@ -153,9 +160,8 @@ struct FluxModel : public DiffusionModel {
|
|||||||
bool offload_params_to_cpu,
|
bool offload_params_to_cpu,
|
||||||
const String2GGMLType& tensor_types = {},
|
const String2GGMLType& tensor_types = {},
|
||||||
SDVersion version = VERSION_FLUX,
|
SDVersion version = VERSION_FLUX,
|
||||||
bool flash_attn = false,
|
|
||||||
bool use_mask = false)
|
bool use_mask = false)
|
||||||
: flux(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, flash_attn, use_mask) {
|
: flux(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, use_mask) {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string get_desc() override {
|
std::string get_desc() override {
|
||||||
@ -186,6 +192,10 @@ struct FluxModel : public DiffusionModel {
|
|||||||
return 768;
|
return 768;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void set_flash_attn_enabled(bool enabled) {
|
||||||
|
flux.set_flash_attention_enabled(enabled);
|
||||||
|
}
|
||||||
|
|
||||||
void compute(int n_threads,
|
void compute(int n_threads,
|
||||||
DiffusionParams diffusion_params,
|
DiffusionParams diffusion_params,
|
||||||
struct ggml_tensor** output = nullptr,
|
struct ggml_tensor** output = nullptr,
|
||||||
@ -213,9 +223,8 @@ struct WanModel : public DiffusionModel {
|
|||||||
bool offload_params_to_cpu,
|
bool offload_params_to_cpu,
|
||||||
const String2GGMLType& tensor_types = {},
|
const String2GGMLType& tensor_types = {},
|
||||||
const std::string prefix = "model.diffusion_model",
|
const std::string prefix = "model.diffusion_model",
|
||||||
SDVersion version = VERSION_WAN2,
|
SDVersion version = VERSION_WAN2)
|
||||||
bool flash_attn = false)
|
: prefix(prefix), wan(backend, offload_params_to_cpu, tensor_types, prefix, version) {
|
||||||
: prefix(prefix), wan(backend, offload_params_to_cpu, tensor_types, prefix, version, flash_attn) {
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string get_desc() override {
|
std::string get_desc() override {
|
||||||
@ -246,6 +255,10 @@ struct WanModel : public DiffusionModel {
|
|||||||
return 768;
|
return 768;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void set_flash_attn_enabled(bool enabled) {
|
||||||
|
wan.set_flash_attention_enabled(enabled);
|
||||||
|
}
|
||||||
|
|
||||||
void compute(int n_threads,
|
void compute(int n_threads,
|
||||||
DiffusionParams diffusion_params,
|
DiffusionParams diffusion_params,
|
||||||
struct ggml_tensor** output = nullptr,
|
struct ggml_tensor** output = nullptr,
|
||||||
@ -272,9 +285,8 @@ struct QwenImageModel : public DiffusionModel {
|
|||||||
bool offload_params_to_cpu,
|
bool offload_params_to_cpu,
|
||||||
const String2GGMLType& tensor_types = {},
|
const String2GGMLType& tensor_types = {},
|
||||||
const std::string prefix = "model.diffusion_model",
|
const std::string prefix = "model.diffusion_model",
|
||||||
SDVersion version = VERSION_QWEN_IMAGE,
|
SDVersion version = VERSION_QWEN_IMAGE)
|
||||||
bool flash_attn = false)
|
: prefix(prefix), qwen_image(backend, offload_params_to_cpu, tensor_types, prefix, version) {
|
||||||
: prefix(prefix), qwen_image(backend, offload_params_to_cpu, tensor_types, prefix, version, flash_attn) {
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string get_desc() override {
|
std::string get_desc() override {
|
||||||
@ -305,6 +317,10 @@ struct QwenImageModel : public DiffusionModel {
|
|||||||
return 768;
|
return 768;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void set_flash_attn_enabled(bool enabled) {
|
||||||
|
qwen_image.set_flash_attention_enabled(enabled);
|
||||||
|
}
|
||||||
|
|
||||||
void compute(int n_threads,
|
void compute(int n_threads,
|
||||||
DiffusionParams diffusion_params,
|
DiffusionParams diffusion_params,
|
||||||
struct ggml_tensor** output = nullptr,
|
struct ggml_tensor** output = nullptr,
|
||||||
|
|||||||
49
esrgan.hpp
49
esrgan.hpp
@ -27,11 +27,11 @@ public:
|
|||||||
blocks["conv5"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat + 4 * num_grow_ch, num_feat, {3, 3}, {1, 1}, {1, 1}));
|
blocks["conv5"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat + 4 * num_grow_ch, num_feat, {3, 3}, {1, 1}, {1, 1}));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* lrelu(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* lrelu(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
return ggml_leaky_relu(ctx, x, 0.2f, true);
|
return ggml_leaky_relu(ctx->ggml_ctx, x, 0.2f, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
// x: [n, num_feat, h, w]
|
// x: [n, num_feat, h, w]
|
||||||
// return: [n, num_feat, h, w]
|
// return: [n, num_feat, h, w]
|
||||||
|
|
||||||
@ -42,16 +42,16 @@ public:
|
|||||||
auto conv5 = std::dynamic_pointer_cast<Conv2d>(blocks["conv5"]);
|
auto conv5 = std::dynamic_pointer_cast<Conv2d>(blocks["conv5"]);
|
||||||
|
|
||||||
auto x1 = lrelu(ctx, conv1->forward(ctx, x));
|
auto x1 = lrelu(ctx, conv1->forward(ctx, x));
|
||||||
auto x_cat = ggml_concat(ctx, x, x1, 2);
|
auto x_cat = ggml_concat(ctx->ggml_ctx, x, x1, 2);
|
||||||
auto x2 = lrelu(ctx, conv2->forward(ctx, x_cat));
|
auto x2 = lrelu(ctx, conv2->forward(ctx, x_cat));
|
||||||
x_cat = ggml_concat(ctx, x_cat, x2, 2);
|
x_cat = ggml_concat(ctx->ggml_ctx, x_cat, x2, 2);
|
||||||
auto x3 = lrelu(ctx, conv3->forward(ctx, x_cat));
|
auto x3 = lrelu(ctx, conv3->forward(ctx, x_cat));
|
||||||
x_cat = ggml_concat(ctx, x_cat, x3, 2);
|
x_cat = ggml_concat(ctx->ggml_ctx, x_cat, x3, 2);
|
||||||
auto x4 = lrelu(ctx, conv4->forward(ctx, x_cat));
|
auto x4 = lrelu(ctx, conv4->forward(ctx, x_cat));
|
||||||
x_cat = ggml_concat(ctx, x_cat, x4, 2);
|
x_cat = ggml_concat(ctx->ggml_ctx, x_cat, x4, 2);
|
||||||
auto x5 = conv5->forward(ctx, x_cat);
|
auto x5 = conv5->forward(ctx, x_cat);
|
||||||
|
|
||||||
x5 = ggml_add(ctx, ggml_scale(ctx, x5, 0.2f), x);
|
x5 = ggml_add(ctx->ggml_ctx, ggml_scale(ctx->ggml_ctx, x5, 0.2f), x);
|
||||||
return x5;
|
return x5;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -64,7 +64,7 @@ public:
|
|||||||
blocks["rdb3"] = std::shared_ptr<GGMLBlock>(new ResidualDenseBlock(num_feat, num_grow_ch));
|
blocks["rdb3"] = std::shared_ptr<GGMLBlock>(new ResidualDenseBlock(num_feat, num_grow_ch));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
// x: [n, num_feat, h, w]
|
// x: [n, num_feat, h, w]
|
||||||
// return: [n, num_feat, h, w]
|
// return: [n, num_feat, h, w]
|
||||||
|
|
||||||
@ -76,7 +76,7 @@ public:
|
|||||||
out = rdb2->forward(ctx, out);
|
out = rdb2->forward(ctx, out);
|
||||||
out = rdb3->forward(ctx, out);
|
out = rdb3->forward(ctx, out);
|
||||||
|
|
||||||
out = ggml_add(ctx, ggml_scale(ctx, out, 0.2f), x);
|
out = ggml_add(ctx->ggml_ctx, ggml_scale(ctx->ggml_ctx, out, 0.2f), x);
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -112,11 +112,11 @@ public:
|
|||||||
int get_scale() { return scale; }
|
int get_scale() { return scale; }
|
||||||
int get_num_block() { return num_block; }
|
int get_num_block() { return num_block; }
|
||||||
|
|
||||||
struct ggml_tensor* lrelu(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* lrelu(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
return ggml_leaky_relu(ctx, x, 0.2f, true);
|
return ggml_leaky_relu(ctx->ggml_ctx, x, 0.2f, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
// x: [n, num_in_ch, h, w]
|
// x: [n, num_in_ch, h, w]
|
||||||
// return: [n, num_out_ch, h*scale, w*scale]
|
// return: [n, num_out_ch, h*scale, w*scale]
|
||||||
auto conv_first = std::dynamic_pointer_cast<Conv2d>(blocks["conv_first"]);
|
auto conv_first = std::dynamic_pointer_cast<Conv2d>(blocks["conv_first"]);
|
||||||
@ -133,14 +133,14 @@ public:
|
|||||||
body_feat = block->forward(ctx, body_feat);
|
body_feat = block->forward(ctx, body_feat);
|
||||||
}
|
}
|
||||||
body_feat = conv_body->forward(ctx, body_feat);
|
body_feat = conv_body->forward(ctx, body_feat);
|
||||||
feat = ggml_add(ctx, feat, body_feat);
|
feat = ggml_add(ctx->ggml_ctx, feat, body_feat);
|
||||||
// upsample
|
// upsample
|
||||||
if (scale >= 2) {
|
if (scale >= 2) {
|
||||||
auto conv_up1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up1"]);
|
auto conv_up1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up1"]);
|
||||||
feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
|
feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx->ggml_ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
|
||||||
if (scale == 4) {
|
if (scale == 4) {
|
||||||
auto conv_up2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up2"]);
|
auto conv_up2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up2"]);
|
||||||
feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
|
feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx->ggml_ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// for all scales
|
// for all scales
|
||||||
@ -161,19 +161,6 @@ struct ESRGAN : public GGMLRunner {
|
|||||||
// rrdb_net will be created in load_from_file
|
// rrdb_net will be created in load_from_file
|
||||||
}
|
}
|
||||||
|
|
||||||
void enable_conv2d_direct() {
|
|
||||||
if (!rrdb_net)
|
|
||||||
return;
|
|
||||||
std::vector<GGMLBlock*> blocks;
|
|
||||||
rrdb_net->get_all_blocks(blocks);
|
|
||||||
for (auto block : blocks) {
|
|
||||||
if (block->get_desc() == "Conv2d") {
|
|
||||||
auto conv_block = (Conv2d*)block;
|
|
||||||
conv_block->enable_direct();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string get_desc() override {
|
std::string get_desc() override {
|
||||||
return "esrgan";
|
return "esrgan";
|
||||||
}
|
}
|
||||||
@ -359,7 +346,9 @@ struct ESRGAN : public GGMLRunner {
|
|||||||
constexpr int kGraphNodes = 1 << 16; // 65k
|
constexpr int kGraphNodes = 1 << 16; // 65k
|
||||||
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, kGraphNodes, /*grads*/ false);
|
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, kGraphNodes, /*grads*/ false);
|
||||||
x = to_backend(x);
|
x = to_backend(x);
|
||||||
struct ggml_tensor* out = rrdb_net->forward(compute_ctx, x);
|
|
||||||
|
auto runner_ctx = get_context();
|
||||||
|
struct ggml_tensor* out = rrdb_net->forward(&runner_ctx, x);
|
||||||
ggml_build_forward_expand(gf, out);
|
ggml_build_forward_expand(gf, out);
|
||||||
return gf;
|
return gf;
|
||||||
}
|
}
|
||||||
|
|||||||
374
flux.hpp
374
flux.hpp
@ -19,14 +19,14 @@ namespace Flux {
|
|||||||
blocks["out_layer"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_dim, hidden_dim, true));
|
blocks["out_layer"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_dim, hidden_dim, true));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||||
// x: [..., in_dim]
|
// x: [..., in_dim]
|
||||||
// return: [..., hidden_dim]
|
// return: [..., hidden_dim]
|
||||||
auto in_layer = std::dynamic_pointer_cast<Linear>(blocks["in_layer"]);
|
auto in_layer = std::dynamic_pointer_cast<Linear>(blocks["in_layer"]);
|
||||||
auto out_layer = std::dynamic_pointer_cast<Linear>(blocks["out_layer"]);
|
auto out_layer = std::dynamic_pointer_cast<Linear>(blocks["out_layer"]);
|
||||||
|
|
||||||
x = in_layer->forward(ctx, x);
|
x = in_layer->forward(ctx, x);
|
||||||
x = ggml_silu_inplace(ctx, x);
|
x = ggml_silu_inplace(ctx->ggml_ctx, x);
|
||||||
x = out_layer->forward(ctx, x);
|
x = out_layer->forward(ctx, x);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
@ -48,10 +48,10 @@ namespace Flux {
|
|||||||
: hidden_size(hidden_size),
|
: hidden_size(hidden_size),
|
||||||
eps(eps) {}
|
eps(eps) {}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||||
struct ggml_tensor* w = params["scale"];
|
struct ggml_tensor* w = params["scale"];
|
||||||
x = ggml_rms_norm(ctx, x, eps);
|
x = ggml_rms_norm(ctx->ggml_ctx, x, eps);
|
||||||
x = ggml_mul(ctx, x, w);
|
x = ggml_mul(ctx->ggml_ctx, x, w);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -63,7 +63,7 @@ namespace Flux {
|
|||||||
blocks["key_norm"] = std::shared_ptr<GGMLBlock>(new RMSNorm(dim));
|
blocks["key_norm"] = std::shared_ptr<GGMLBlock>(new RMSNorm(dim));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* query_norm(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* query_norm(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
// x: [..., dim]
|
// x: [..., dim]
|
||||||
// return: [..., dim]
|
// return: [..., dim]
|
||||||
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["query_norm"]);
|
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["query_norm"]);
|
||||||
@ -72,7 +72,7 @@ namespace Flux {
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* key_norm(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* key_norm(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
// x: [..., dim]
|
// x: [..., dim]
|
||||||
// return: [..., dim]
|
// return: [..., dim]
|
||||||
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["key_norm"]);
|
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["key_norm"]);
|
||||||
@ -85,13 +85,11 @@ namespace Flux {
|
|||||||
struct SelfAttention : public GGMLBlock {
|
struct SelfAttention : public GGMLBlock {
|
||||||
public:
|
public:
|
||||||
int64_t num_heads;
|
int64_t num_heads;
|
||||||
bool flash_attn;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
SelfAttention(int64_t dim,
|
SelfAttention(int64_t dim,
|
||||||
int64_t num_heads = 8,
|
int64_t num_heads = 8,
|
||||||
bool qkv_bias = false,
|
bool qkv_bias = false)
|
||||||
bool flash_attn = false)
|
|
||||||
: num_heads(num_heads) {
|
: num_heads(num_heads) {
|
||||||
int64_t head_dim = dim / num_heads;
|
int64_t head_dim = dim / num_heads;
|
||||||
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
|
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
|
||||||
@ -99,39 +97,38 @@ namespace Flux {
|
|||||||
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
|
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<struct ggml_tensor*> pre_attention(struct ggml_context* ctx, struct ggml_tensor* x) {
|
std::vector<struct ggml_tensor*> pre_attention(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
auto qkv_proj = std::dynamic_pointer_cast<Linear>(blocks["qkv"]);
|
auto qkv_proj = std::dynamic_pointer_cast<Linear>(blocks["qkv"]);
|
||||||
auto norm = std::dynamic_pointer_cast<QKNorm>(blocks["norm"]);
|
auto norm = std::dynamic_pointer_cast<QKNorm>(blocks["norm"]);
|
||||||
|
|
||||||
auto qkv = qkv_proj->forward(ctx, x);
|
auto qkv = qkv_proj->forward(ctx, x);
|
||||||
auto qkv_vec = split_qkv(ctx, qkv);
|
auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv);
|
||||||
int64_t head_dim = qkv_vec[0]->ne[0] / num_heads;
|
int64_t head_dim = qkv_vec[0]->ne[0] / num_heads;
|
||||||
auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]);
|
auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]);
|
||||||
auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]);
|
auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]);
|
||||||
auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]);
|
auto v = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]);
|
||||||
q = norm->query_norm(ctx, q);
|
q = norm->query_norm(ctx, q);
|
||||||
k = norm->key_norm(ctx, k);
|
k = norm->key_norm(ctx, k);
|
||||||
return {q, k, v};
|
return {q, k, v};
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* post_attention(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* post_attention(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
|
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
|
||||||
|
|
||||||
x = proj->forward(ctx, x); // [N, n_token, dim]
|
x = proj->forward(ctx, x); // [N, n_token, dim]
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* pe,
|
struct ggml_tensor* pe,
|
||||||
struct ggml_tensor* mask) {
|
struct ggml_tensor* mask) {
|
||||||
// x: [N, n_token, dim]
|
// x: [N, n_token, dim]
|
||||||
// pe: [n_token, d_head/2, 2, 2]
|
// pe: [n_token, d_head/2, 2, 2]
|
||||||
// return [N, n_token, dim]
|
// return [N, n_token, dim]
|
||||||
auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head]
|
auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head]
|
||||||
x = Rope::attention(ctx, backend, qkv[0], qkv[1], qkv[2], pe, mask, flash_attn); // [N, n_token, dim]
|
x = Rope::attention(ctx, qkv[0], qkv[1], qkv[2], pe, mask); // [N, n_token, dim]
|
||||||
x = post_attention(ctx, x); // [N, n_token, dim]
|
x = post_attention(ctx, x); // [N, n_token, dim]
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -144,11 +141,11 @@ namespace Flux {
|
|||||||
ModulationOut(ggml_tensor* shift = nullptr, ggml_tensor* scale = nullptr, ggml_tensor* gate = nullptr)
|
ModulationOut(ggml_tensor* shift = nullptr, ggml_tensor* scale = nullptr, ggml_tensor* gate = nullptr)
|
||||||
: shift(shift), scale(scale), gate(gate) {}
|
: shift(shift), scale(scale), gate(gate) {}
|
||||||
|
|
||||||
ModulationOut(struct ggml_context* ctx, ggml_tensor* vec, int64_t offset) {
|
ModulationOut(GGMLRunnerContext* ctx, ggml_tensor* vec, int64_t offset) {
|
||||||
int64_t stride = vec->nb[1] * vec->ne[1];
|
int64_t stride = vec->nb[1] * vec->ne[1];
|
||||||
shift = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 0)); // [N, dim]
|
shift = ggml_view_2d(ctx->ggml_ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 0)); // [N, dim]
|
||||||
scale = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 1)); // [N, dim]
|
scale = ggml_view_2d(ctx->ggml_ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 1)); // [N, dim]
|
||||||
gate = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 2)); // [N, dim]
|
gate = ggml_view_2d(ctx->ggml_ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 2)); // [N, dim]
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -164,16 +161,16 @@ namespace Flux {
|
|||||||
blocks["lin"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * multiplier));
|
blocks["lin"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * multiplier));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<ModulationOut> forward(struct ggml_context* ctx, struct ggml_tensor* vec) {
|
std::vector<ModulationOut> forward(GGMLRunnerContext* ctx, struct ggml_tensor* vec) {
|
||||||
// x: [N, dim]
|
// x: [N, dim]
|
||||||
// return: [ModulationOut, ModulationOut]
|
// return: [ModulationOut, ModulationOut]
|
||||||
auto lin = std::dynamic_pointer_cast<Linear>(blocks["lin"]);
|
auto lin = std::dynamic_pointer_cast<Linear>(blocks["lin"]);
|
||||||
|
|
||||||
auto out = ggml_silu(ctx, vec);
|
auto out = ggml_silu(ctx->ggml_ctx, vec);
|
||||||
out = lin->forward(ctx, out); // [N, multiplier*dim]
|
out = lin->forward(ctx, out); // [N, multiplier*dim]
|
||||||
|
|
||||||
auto m = ggml_reshape_3d(ctx, out, vec->ne[0], multiplier, vec->ne[1]); // [N, multiplier, dim]
|
auto m = ggml_reshape_3d(ctx->ggml_ctx, out, vec->ne[0], multiplier, vec->ne[1]); // [N, multiplier, dim]
|
||||||
m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [multiplier, N, dim]
|
m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [multiplier, N, dim]
|
||||||
|
|
||||||
ModulationOut m_0 = ModulationOut(ctx, m, 0);
|
ModulationOut m_0 = ModulationOut(ctx, m, 0);
|
||||||
if (is_double) {
|
if (is_double) {
|
||||||
@ -199,7 +196,6 @@ namespace Flux {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct DoubleStreamBlock : public GGMLBlock {
|
struct DoubleStreamBlock : public GGMLBlock {
|
||||||
bool flash_attn;
|
|
||||||
bool prune_mod;
|
bool prune_mod;
|
||||||
int idx = 0;
|
int idx = 0;
|
||||||
|
|
||||||
@ -207,17 +203,16 @@ namespace Flux {
|
|||||||
DoubleStreamBlock(int64_t hidden_size,
|
DoubleStreamBlock(int64_t hidden_size,
|
||||||
int64_t num_heads,
|
int64_t num_heads,
|
||||||
float mlp_ratio,
|
float mlp_ratio,
|
||||||
int idx = 0,
|
int idx = 0,
|
||||||
bool qkv_bias = false,
|
bool qkv_bias = false,
|
||||||
bool flash_attn = false,
|
bool prune_mod = false)
|
||||||
bool prune_mod = false)
|
: idx(idx), prune_mod(prune_mod) {
|
||||||
: idx(idx), flash_attn(flash_attn), prune_mod(prune_mod) {
|
|
||||||
int64_t mlp_hidden_dim = hidden_size * mlp_ratio;
|
int64_t mlp_hidden_dim = hidden_size * mlp_ratio;
|
||||||
if (!prune_mod) {
|
if (!prune_mod) {
|
||||||
blocks["img_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
|
blocks["img_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
|
||||||
}
|
}
|
||||||
blocks["img_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
|
blocks["img_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
|
||||||
blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn));
|
blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias));
|
||||||
|
|
||||||
blocks["img_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
|
blocks["img_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
|
||||||
blocks["img_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim));
|
blocks["img_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim));
|
||||||
@ -228,7 +223,7 @@ namespace Flux {
|
|||||||
blocks["txt_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
|
blocks["txt_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
|
||||||
}
|
}
|
||||||
blocks["txt_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
|
blocks["txt_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
|
||||||
blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn));
|
blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias));
|
||||||
|
|
||||||
blocks["txt_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
|
blocks["txt_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
|
||||||
blocks["txt_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim));
|
blocks["txt_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim));
|
||||||
@ -236,7 +231,7 @@ namespace Flux {
|
|||||||
blocks["txt_mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(mlp_hidden_dim, hidden_size));
|
blocks["txt_mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(mlp_hidden_dim, hidden_size));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<ModulationOut> get_distil_img_mod(struct ggml_context* ctx, struct ggml_tensor* vec) {
|
std::vector<ModulationOut> get_distil_img_mod(GGMLRunnerContext* ctx, struct ggml_tensor* vec) {
|
||||||
// TODO: not hardcoded?
|
// TODO: not hardcoded?
|
||||||
const int single_blocks_count = 38;
|
const int single_blocks_count = 38;
|
||||||
const int double_blocks_count = 19;
|
const int double_blocks_count = 19;
|
||||||
@ -245,7 +240,7 @@ namespace Flux {
|
|||||||
return {ModulationOut(ctx, vec, offset), ModulationOut(ctx, vec, offset + 3)};
|
return {ModulationOut(ctx, vec, offset), ModulationOut(ctx, vec, offset + 3)};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<ModulationOut> get_distil_txt_mod(struct ggml_context* ctx, struct ggml_tensor* vec) {
|
std::vector<ModulationOut> get_distil_txt_mod(GGMLRunnerContext* ctx, struct ggml_tensor* vec) {
|
||||||
// TODO: not hardcoded?
|
// TODO: not hardcoded?
|
||||||
const int single_blocks_count = 38;
|
const int single_blocks_count = 38;
|
||||||
const int double_blocks_count = 19;
|
const int double_blocks_count = 19;
|
||||||
@ -254,8 +249,7 @@ namespace Flux {
|
|||||||
return {ModulationOut(ctx, vec, offset), ModulationOut(ctx, vec, offset + 3)};
|
return {ModulationOut(ctx, vec, offset), ModulationOut(ctx, vec, offset + 3)};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
|
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* img,
|
struct ggml_tensor* img,
|
||||||
struct ggml_tensor* txt,
|
struct ggml_tensor* txt,
|
||||||
struct ggml_tensor* vec,
|
struct ggml_tensor* vec,
|
||||||
@ -300,7 +294,7 @@ namespace Flux {
|
|||||||
|
|
||||||
// prepare image for attention
|
// prepare image for attention
|
||||||
auto img_modulated = img_norm1->forward(ctx, img);
|
auto img_modulated = img_norm1->forward(ctx, img);
|
||||||
img_modulated = Flux::modulate(ctx, img_modulated, img_mod1.shift, img_mod1.scale);
|
img_modulated = Flux::modulate(ctx->ggml_ctx, img_modulated, img_mod1.shift, img_mod1.scale);
|
||||||
auto img_qkv = img_attn->pre_attention(ctx, img_modulated); // q,k,v: [N, n_img_token, n_head, d_head]
|
auto img_qkv = img_attn->pre_attention(ctx, img_modulated); // q,k,v: [N, n_img_token, n_head, d_head]
|
||||||
auto img_q = img_qkv[0];
|
auto img_q = img_qkv[0];
|
||||||
auto img_k = img_qkv[1];
|
auto img_k = img_qkv[1];
|
||||||
@ -308,55 +302,55 @@ namespace Flux {
|
|||||||
|
|
||||||
// prepare txt for attention
|
// prepare txt for attention
|
||||||
auto txt_modulated = txt_norm1->forward(ctx, txt);
|
auto txt_modulated = txt_norm1->forward(ctx, txt);
|
||||||
txt_modulated = Flux::modulate(ctx, txt_modulated, txt_mod1.shift, txt_mod1.scale);
|
txt_modulated = Flux::modulate(ctx->ggml_ctx, txt_modulated, txt_mod1.shift, txt_mod1.scale);
|
||||||
auto txt_qkv = txt_attn->pre_attention(ctx, txt_modulated); // q,k,v: [N, n_txt_token, n_head, d_head]
|
auto txt_qkv = txt_attn->pre_attention(ctx, txt_modulated); // q,k,v: [N, n_txt_token, n_head, d_head]
|
||||||
auto txt_q = txt_qkv[0];
|
auto txt_q = txt_qkv[0];
|
||||||
auto txt_k = txt_qkv[1];
|
auto txt_k = txt_qkv[1];
|
||||||
auto txt_v = txt_qkv[2];
|
auto txt_v = txt_qkv[2];
|
||||||
|
|
||||||
// run actual attention
|
// run actual attention
|
||||||
auto q = ggml_concat(ctx, txt_q, img_q, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
|
auto q = ggml_concat(ctx->ggml_ctx, txt_q, img_q, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
|
||||||
auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
|
auto k = ggml_concat(ctx->ggml_ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
|
||||||
auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
|
auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
|
||||||
|
|
||||||
auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head]
|
auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_txt_token + n_img_token, n_head*d_head]
|
||||||
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
|
attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
|
||||||
auto txt_attn_out = ggml_view_3d(ctx,
|
auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx,
|
||||||
attn,
|
attn,
|
||||||
attn->ne[0],
|
attn->ne[0],
|
||||||
attn->ne[1],
|
attn->ne[1],
|
||||||
txt->ne[1],
|
txt->ne[1],
|
||||||
attn->nb[1],
|
attn->nb[1],
|
||||||
attn->nb[2],
|
attn->nb[2],
|
||||||
0); // [n_txt_token, N, hidden_size]
|
0); // [n_txt_token, N, hidden_size]
|
||||||
txt_attn_out = ggml_cont(ctx, ggml_permute(ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size]
|
txt_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size]
|
||||||
auto img_attn_out = ggml_view_3d(ctx,
|
auto img_attn_out = ggml_view_3d(ctx->ggml_ctx,
|
||||||
attn,
|
attn,
|
||||||
attn->ne[0],
|
attn->ne[0],
|
||||||
attn->ne[1],
|
attn->ne[1],
|
||||||
img->ne[1],
|
img->ne[1],
|
||||||
attn->nb[1],
|
attn->nb[1],
|
||||||
attn->nb[2],
|
attn->nb[2],
|
||||||
attn->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size]
|
attn->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size]
|
||||||
img_attn_out = ggml_cont(ctx, ggml_permute(ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
|
img_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
|
||||||
|
|
||||||
// calculate the img bloks
|
// calculate the img bloks
|
||||||
img = ggml_add(ctx, img, ggml_mul(ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate));
|
img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate));
|
||||||
|
|
||||||
auto img_mlp_out = img_mlp_0->forward(ctx, Flux::modulate(ctx, img_norm2->forward(ctx, img), img_mod2.shift, img_mod2.scale));
|
auto img_mlp_out = img_mlp_0->forward(ctx, Flux::modulate(ctx->ggml_ctx, img_norm2->forward(ctx, img), img_mod2.shift, img_mod2.scale));
|
||||||
img_mlp_out = ggml_gelu_inplace(ctx, img_mlp_out);
|
img_mlp_out = ggml_gelu_inplace(ctx->ggml_ctx, img_mlp_out);
|
||||||
img_mlp_out = img_mlp_2->forward(ctx, img_mlp_out);
|
img_mlp_out = img_mlp_2->forward(ctx, img_mlp_out);
|
||||||
|
|
||||||
img = ggml_add(ctx, img, ggml_mul(ctx, img_mlp_out, img_mod2.gate));
|
img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_mlp_out, img_mod2.gate));
|
||||||
|
|
||||||
// calculate the txt bloks
|
// calculate the txt bloks
|
||||||
txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_attn->post_attention(ctx, txt_attn_out), txt_mod1.gate));
|
txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_attn->post_attention(ctx, txt_attn_out), txt_mod1.gate));
|
||||||
|
|
||||||
auto txt_mlp_out = txt_mlp_0->forward(ctx, Flux::modulate(ctx, txt_norm2->forward(ctx, txt), txt_mod2.shift, txt_mod2.scale));
|
auto txt_mlp_out = txt_mlp_0->forward(ctx, Flux::modulate(ctx->ggml_ctx, txt_norm2->forward(ctx, txt), txt_mod2.shift, txt_mod2.scale));
|
||||||
txt_mlp_out = ggml_gelu_inplace(ctx, txt_mlp_out);
|
txt_mlp_out = ggml_gelu_inplace(ctx->ggml_ctx, txt_mlp_out);
|
||||||
txt_mlp_out = txt_mlp_2->forward(ctx, txt_mlp_out);
|
txt_mlp_out = txt_mlp_2->forward(ctx, txt_mlp_out);
|
||||||
|
|
||||||
txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_mlp_out, txt_mod2.gate));
|
txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_mlp_out, txt_mod2.gate));
|
||||||
|
|
||||||
return {img, txt};
|
return {img, txt};
|
||||||
}
|
}
|
||||||
@ -367,7 +361,6 @@ namespace Flux {
|
|||||||
int64_t num_heads;
|
int64_t num_heads;
|
||||||
int64_t hidden_size;
|
int64_t hidden_size;
|
||||||
int64_t mlp_hidden_dim;
|
int64_t mlp_hidden_dim;
|
||||||
bool flash_attn;
|
|
||||||
bool prune_mod;
|
bool prune_mod;
|
||||||
int idx = 0;
|
int idx = 0;
|
||||||
|
|
||||||
@ -377,9 +370,8 @@ namespace Flux {
|
|||||||
float mlp_ratio = 4.0f,
|
float mlp_ratio = 4.0f,
|
||||||
int idx = 0,
|
int idx = 0,
|
||||||
float qk_scale = 0.f,
|
float qk_scale = 0.f,
|
||||||
bool flash_attn = false,
|
|
||||||
bool prune_mod = false)
|
bool prune_mod = false)
|
||||||
: hidden_size(hidden_size), num_heads(num_heads), idx(idx), flash_attn(flash_attn), prune_mod(prune_mod) {
|
: hidden_size(hidden_size), num_heads(num_heads), idx(idx), prune_mod(prune_mod) {
|
||||||
int64_t head_dim = hidden_size / num_heads;
|
int64_t head_dim = hidden_size / num_heads;
|
||||||
float scale = qk_scale;
|
float scale = qk_scale;
|
||||||
if (scale <= 0.f) {
|
if (scale <= 0.f) {
|
||||||
@ -397,13 +389,12 @@ namespace Flux {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ModulationOut get_distil_mod(struct ggml_context* ctx, struct ggml_tensor* vec) {
|
ModulationOut get_distil_mod(GGMLRunnerContext* ctx, struct ggml_tensor* vec) {
|
||||||
int64_t offset = 3 * idx;
|
int64_t offset = 3 * idx;
|
||||||
return ModulationOut(ctx, vec, offset);
|
return ModulationOut(ctx, vec, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* vec,
|
struct ggml_tensor* vec,
|
||||||
struct ggml_tensor* pe,
|
struct ggml_tensor* pe,
|
||||||
@ -424,42 +415,42 @@ namespace Flux {
|
|||||||
|
|
||||||
mod = modulation->forward(ctx, vec)[0];
|
mod = modulation->forward(ctx, vec)[0];
|
||||||
}
|
}
|
||||||
auto x_mod = Flux::modulate(ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale);
|
auto x_mod = Flux::modulate(ctx->ggml_ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale);
|
||||||
auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim]
|
auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim]
|
||||||
qkv_mlp = ggml_cont(ctx, ggml_permute(ctx, qkv_mlp, 2, 0, 1, 3)); // [hidden_size * 3 + mlp_hidden_dim, N, n_token]
|
qkv_mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv_mlp, 2, 0, 1, 3)); // [hidden_size * 3 + mlp_hidden_dim, N, n_token]
|
||||||
|
|
||||||
auto qkv = ggml_view_3d(ctx,
|
auto qkv = ggml_view_3d(ctx->ggml_ctx,
|
||||||
qkv_mlp,
|
qkv_mlp,
|
||||||
qkv_mlp->ne[0],
|
qkv_mlp->ne[0],
|
||||||
qkv_mlp->ne[1],
|
qkv_mlp->ne[1],
|
||||||
hidden_size * 3,
|
hidden_size * 3,
|
||||||
qkv_mlp->nb[1],
|
qkv_mlp->nb[1],
|
||||||
qkv_mlp->nb[2],
|
qkv_mlp->nb[2],
|
||||||
0); // [hidden_size * 3 , N, n_token]
|
0); // [hidden_size * 3 , N, n_token]
|
||||||
qkv = ggml_cont(ctx, ggml_permute(ctx, qkv, 1, 2, 0, 3)); // [N, n_token, hidden_size * 3]
|
qkv = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv, 1, 2, 0, 3)); // [N, n_token, hidden_size * 3]
|
||||||
auto mlp = ggml_view_3d(ctx,
|
auto mlp = ggml_view_3d(ctx->ggml_ctx,
|
||||||
qkv_mlp,
|
qkv_mlp,
|
||||||
qkv_mlp->ne[0],
|
qkv_mlp->ne[0],
|
||||||
qkv_mlp->ne[1],
|
qkv_mlp->ne[1],
|
||||||
mlp_hidden_dim,
|
mlp_hidden_dim,
|
||||||
qkv_mlp->nb[1],
|
qkv_mlp->nb[1],
|
||||||
qkv_mlp->nb[2],
|
qkv_mlp->nb[2],
|
||||||
qkv_mlp->nb[2] * hidden_size * 3); // [mlp_hidden_dim , N, n_token]
|
qkv_mlp->nb[2] * hidden_size * 3); // [mlp_hidden_dim , N, n_token]
|
||||||
mlp = ggml_cont(ctx, ggml_permute(ctx, mlp, 1, 2, 0, 3)); // [N, n_token, mlp_hidden_dim]
|
mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, mlp, 1, 2, 0, 3)); // [N, n_token, mlp_hidden_dim]
|
||||||
|
|
||||||
auto qkv_vec = split_qkv(ctx, qkv); // q,k,v: [N, n_token, hidden_size]
|
auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv); // q,k,v: [N, n_token, hidden_size]
|
||||||
int64_t head_dim = hidden_size / num_heads;
|
int64_t head_dim = hidden_size / num_heads;
|
||||||
auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head]
|
auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head]
|
||||||
auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head]
|
auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head]
|
||||||
auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head]
|
auto v = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head]
|
||||||
q = norm->query_norm(ctx, q);
|
q = norm->query_norm(ctx, q);
|
||||||
k = norm->key_norm(ctx, k);
|
k = norm->key_norm(ctx, k);
|
||||||
auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_token, hidden_size]
|
auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_token, hidden_size]
|
||||||
|
|
||||||
auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim]
|
auto attn_mlp = ggml_concat(ctx->ggml_ctx, attn, ggml_gelu_inplace(ctx->ggml_ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim]
|
||||||
auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size]
|
auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size]
|
||||||
|
|
||||||
output = ggml_add(ctx, x, ggml_mul(ctx, output, mod.gate));
|
output = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, output, mod.gate));
|
||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -480,16 +471,16 @@ namespace Flux {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ModulationOut get_distil_mod(struct ggml_context* ctx, struct ggml_tensor* vec) {
|
ModulationOut get_distil_mod(GGMLRunnerContext* ctx, struct ggml_tensor* vec) {
|
||||||
int64_t offset = vec->ne[2] - 2;
|
int64_t offset = vec->ne[2] - 2;
|
||||||
int64_t stride = vec->nb[1] * vec->ne[1];
|
int64_t stride = vec->nb[1] * vec->ne[1];
|
||||||
auto shift = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 0)); // [N, dim]
|
auto shift = ggml_view_2d(ctx->ggml_ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 0)); // [N, dim]
|
||||||
auto scale = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 1)); // [N, dim]
|
auto scale = ggml_view_2d(ctx->ggml_ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 1)); // [N, dim]
|
||||||
// No gate
|
// No gate
|
||||||
return {shift, scale, nullptr};
|
return {shift, scale, nullptr};
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* c) {
|
struct ggml_tensor* c) {
|
||||||
// x: [N, n_token, hidden_size]
|
// x: [N, n_token, hidden_size]
|
||||||
@ -505,16 +496,16 @@ namespace Flux {
|
|||||||
} else {
|
} else {
|
||||||
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
|
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
|
||||||
|
|
||||||
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, 2 * hidden_size]
|
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 2 * hidden_size]
|
||||||
m = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size]
|
m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size]
|
||||||
m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size]
|
m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size]
|
||||||
|
|
||||||
int64_t offset = m->nb[1] * m->ne[1];
|
int64_t offset = m->nb[1] * m->ne[1];
|
||||||
shift = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
|
shift = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
|
||||||
scale = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
|
scale = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
|
||||||
}
|
}
|
||||||
|
|
||||||
x = Flux::modulate(ctx, norm_final->forward(ctx, x), shift, scale);
|
x = Flux::modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale);
|
||||||
x = linear->forward(ctx, x);
|
x = linear->forward(ctx, x);
|
||||||
|
|
||||||
return x;
|
return x;
|
||||||
@ -533,7 +524,7 @@ namespace Flux {
|
|||||||
blocks["out_proj"] = std::shared_ptr<GGMLBlock>(new Linear(inner_size, hidden_size, true));
|
blocks["out_proj"] = std::shared_ptr<GGMLBlock>(new Linear(inner_size, hidden_size, true));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
auto in_proj = std::dynamic_pointer_cast<Linear>(blocks["in_proj"]);
|
auto in_proj = std::dynamic_pointer_cast<Linear>(blocks["in_proj"]);
|
||||||
auto out_proj = std::dynamic_pointer_cast<Linear>(blocks["out_proj"]);
|
auto out_proj = std::dynamic_pointer_cast<Linear>(blocks["out_proj"]);
|
||||||
|
|
||||||
@ -541,7 +532,7 @@ namespace Flux {
|
|||||||
for (int i = 0; i < n_layers; i++) {
|
for (int i = 0; i < n_layers; i++) {
|
||||||
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["norms." + std::to_string(i)]);
|
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["norms." + std::to_string(i)]);
|
||||||
auto embed = std::dynamic_pointer_cast<MLPEmbedder>(blocks["layers." + std::to_string(i)]);
|
auto embed = std::dynamic_pointer_cast<MLPEmbedder>(blocks["layers." + std::to_string(i)]);
|
||||||
x = ggml_add_inplace(ctx, x, embed->forward(ctx, norm->forward(ctx, x)));
|
x = ggml_add_inplace(ctx->ggml_ctx, x, embed->forward(ctx, norm->forward(ctx, x)));
|
||||||
}
|
}
|
||||||
x = out_proj->forward(ctx, x);
|
x = out_proj->forward(ctx, x);
|
||||||
|
|
||||||
@ -556,7 +547,7 @@ namespace Flux {
|
|||||||
blocks["embedder.0"] = std::make_shared<Linear>(in_channels + max_freqs * max_freqs, hidden_size_input);
|
blocks["embedder.0"] = std::make_shared<Linear>(in_channels + max_freqs * max_freqs, hidden_size_input);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* dct) {
|
struct ggml_tensor* dct) {
|
||||||
// x: (B, P^2, C)
|
// x: (B, P^2, C)
|
||||||
@ -564,8 +555,8 @@ namespace Flux {
|
|||||||
// return: (B, P^2, hidden_size_input)
|
// return: (B, P^2, hidden_size_input)
|
||||||
auto embedder = std::dynamic_pointer_cast<Linear>(blocks["embedder.0"]);
|
auto embedder = std::dynamic_pointer_cast<Linear>(blocks["embedder.0"]);
|
||||||
|
|
||||||
dct = ggml_repeat_4d(ctx, dct, dct->ne[0], dct->ne[1], x->ne[2], x->ne[3]);
|
dct = ggml_repeat_4d(ctx->ggml_ctx, dct, dct->ne[0], dct->ne[1], x->ne[2], x->ne[3]);
|
||||||
x = ggml_concat(ctx, x, dct, 0);
|
x = ggml_concat(ctx->ggml_ctx, x, dct, 0);
|
||||||
x = embedder->forward(ctx, x);
|
x = embedder->forward(ctx, x);
|
||||||
|
|
||||||
return x;
|
return x;
|
||||||
@ -583,7 +574,7 @@ namespace Flux {
|
|||||||
blocks["norm"] = std::make_shared<RMSNorm>(hidden_size_x);
|
blocks["norm"] = std::make_shared<RMSNorm>(hidden_size_x);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* s) {
|
struct ggml_tensor* s) {
|
||||||
// x: (batch_size, n_token, hidden_size_x)
|
// x: (batch_size, n_token, hidden_size_x)
|
||||||
@ -596,31 +587,31 @@ namespace Flux {
|
|||||||
int64_t hidden_size_x = x->ne[0];
|
int64_t hidden_size_x = x->ne[0];
|
||||||
|
|
||||||
auto mlp_params = param_generator->forward(ctx, s);
|
auto mlp_params = param_generator->forward(ctx, s);
|
||||||
auto fc_params = ggml_ext_chunk(ctx, mlp_params, 3, 0);
|
auto fc_params = ggml_ext_chunk(ctx->ggml_ctx, mlp_params, 3, 0);
|
||||||
auto fc1_gate = ggml_reshape_3d(ctx, fc_params[0], hidden_size_x * mlp_ratio, hidden_size_x, batch_size);
|
auto fc1_gate = ggml_reshape_3d(ctx->ggml_ctx, fc_params[0], hidden_size_x * mlp_ratio, hidden_size_x, batch_size);
|
||||||
auto fc1_value = ggml_reshape_3d(ctx, fc_params[1], hidden_size_x * mlp_ratio, hidden_size_x, batch_size);
|
auto fc1_value = ggml_reshape_3d(ctx->ggml_ctx, fc_params[1], hidden_size_x * mlp_ratio, hidden_size_x, batch_size);
|
||||||
auto fc2 = ggml_reshape_3d(ctx, fc_params[2], hidden_size_x, mlp_ratio * hidden_size_x, batch_size);
|
auto fc2 = ggml_reshape_3d(ctx->ggml_ctx, fc_params[2], hidden_size_x, mlp_ratio * hidden_size_x, batch_size);
|
||||||
|
|
||||||
fc1_gate = ggml_cont(ctx, ggml_ext_torch_permute(ctx, fc1_gate, 1, 0, 2, 3)); // [batch_size, hidden_size_x*mlp_ratio, hidden_size_x]
|
fc1_gate = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, fc1_gate, 1, 0, 2, 3)); // [batch_size, hidden_size_x*mlp_ratio, hidden_size_x]
|
||||||
fc1_gate = ggml_l2_norm(ctx, fc1_gate, 1e-12f);
|
fc1_gate = ggml_l2_norm(ctx->ggml_ctx, fc1_gate, 1e-12f);
|
||||||
fc1_value = ggml_cont(ctx, ggml_ext_torch_permute(ctx, fc1_value, 1, 0, 2, 3)); // [batch_size, hidden_size_x*mlp_ratio, hidden_size_x]
|
fc1_value = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, fc1_value, 1, 0, 2, 3)); // [batch_size, hidden_size_x*mlp_ratio, hidden_size_x]
|
||||||
fc1_value = ggml_l2_norm(ctx, fc1_value, 1e-12f);
|
fc1_value = ggml_l2_norm(ctx->ggml_ctx, fc1_value, 1e-12f);
|
||||||
fc2 = ggml_cont(ctx, ggml_ext_torch_permute(ctx, fc2, 1, 0, 2, 3)); // [batch_size, hidden_size_x, hidden_size_x*mlp_ratio]
|
fc2 = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, fc2, 1, 0, 2, 3)); // [batch_size, hidden_size_x, hidden_size_x*mlp_ratio]
|
||||||
fc2 = ggml_l2_norm(ctx, fc2, 1e-12f);
|
fc2 = ggml_l2_norm(ctx->ggml_ctx, fc2, 1e-12f);
|
||||||
|
|
||||||
auto res_x = x;
|
auto res_x = x;
|
||||||
x = norm->forward(ctx, x); // [batch_size, n_token, hidden_size_x]
|
x = norm->forward(ctx, x); // [batch_size, n_token, hidden_size_x]
|
||||||
|
|
||||||
auto x1 = ggml_mul_mat(ctx, fc1_gate, x); // [batch_size, n_token, hidden_size_x*mlp_ratio]
|
auto x1 = ggml_mul_mat(ctx->ggml_ctx, fc1_gate, x); // [batch_size, n_token, hidden_size_x*mlp_ratio]
|
||||||
x1 = ggml_silu_inplace(ctx, x1);
|
x1 = ggml_silu_inplace(ctx->ggml_ctx, x1);
|
||||||
|
|
||||||
auto x2 = ggml_mul_mat(ctx, fc1_value, x); // [batch_size, n_token, hidden_size_x*mlp_ratio]
|
auto x2 = ggml_mul_mat(ctx->ggml_ctx, fc1_value, x); // [batch_size, n_token, hidden_size_x*mlp_ratio]
|
||||||
|
|
||||||
x = ggml_mul_inplace(ctx, x1, x2); // [batch_size, n_token, hidden_size_x*mlp_ratio]
|
x = ggml_mul_inplace(ctx->ggml_ctx, x1, x2); // [batch_size, n_token, hidden_size_x*mlp_ratio]
|
||||||
|
|
||||||
x = ggml_mul_mat(ctx, fc2, x); // [batch_size, n_token, hidden_size_x]
|
x = ggml_mul_mat(ctx->ggml_ctx, fc2, x); // [batch_size, n_token, hidden_size_x]
|
||||||
|
|
||||||
x = ggml_add_inplace(ctx, x, res_x);
|
x = ggml_add_inplace(ctx->ggml_ctx, x, res_x);
|
||||||
|
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
@ -633,7 +624,7 @@ namespace Flux {
|
|||||||
blocks["linear"] = std::make_shared<Linear>(hidden_size, out_channels);
|
blocks["linear"] = std::make_shared<Linear>(hidden_size, out_channels);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x) {
|
struct ggml_tensor* x) {
|
||||||
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["norm"]);
|
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["norm"]);
|
||||||
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
|
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
|
||||||
@ -652,15 +643,15 @@ namespace Flux {
|
|||||||
blocks["conv"] = std::make_shared<Conv2d>(hidden_size, out_channels, std::pair{3, 3}, std::pair{1, 1}, std::pair{1, 1});
|
blocks["conv"] = std::make_shared<Conv2d>(hidden_size, out_channels, std::pair{3, 3}, std::pair{1, 1}, std::pair{1, 1});
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x) {
|
struct ggml_tensor* x) {
|
||||||
// x: [N, C, H, W]
|
// x: [N, C, H, W]
|
||||||
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["norm"]);
|
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["norm"]);
|
||||||
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
|
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
|
||||||
|
|
||||||
x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [N, H, W, C]
|
x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 2, 0, 1, 3)); // [N, H, W, C]
|
||||||
x = norm->forward(ctx, x);
|
x = norm->forward(ctx, x);
|
||||||
x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 2, 0, 3)); // [N, C, H, W]
|
x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 2, 0, 3)); // [N, C, H, W]
|
||||||
x = conv->forward(ctx, x);
|
x = conv->forward(ctx, x);
|
||||||
|
|
||||||
return x;
|
return x;
|
||||||
@ -692,7 +683,6 @@ namespace Flux {
|
|||||||
int theta = 10000;
|
int theta = 10000;
|
||||||
bool qkv_bias = true;
|
bool qkv_bias = true;
|
||||||
bool guidance_embed = true;
|
bool guidance_embed = true;
|
||||||
bool flash_attn = true;
|
|
||||||
int64_t in_dim = 64;
|
int64_t in_dim = 64;
|
||||||
ChromaRadianceParams chroma_radiance_params;
|
ChromaRadianceParams chroma_radiance_params;
|
||||||
};
|
};
|
||||||
@ -731,7 +721,6 @@ namespace Flux {
|
|||||||
params.mlp_ratio,
|
params.mlp_ratio,
|
||||||
i,
|
i,
|
||||||
params.qkv_bias,
|
params.qkv_bias,
|
||||||
params.flash_attn,
|
|
||||||
params.is_chroma);
|
params.is_chroma);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -741,7 +730,6 @@ namespace Flux {
|
|||||||
params.mlp_ratio,
|
params.mlp_ratio,
|
||||||
i,
|
i,
|
||||||
0.f,
|
0.f,
|
||||||
params.flash_attn,
|
|
||||||
params.is_chroma);
|
params.is_chroma);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -828,8 +816,7 @@ namespace Flux {
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward_orig(struct ggml_context* ctx,
|
struct ggml_tensor* forward_orig(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* img,
|
struct ggml_tensor* img,
|
||||||
struct ggml_tensor* txt,
|
struct ggml_tensor* txt,
|
||||||
struct ggml_tensor* timesteps,
|
struct ggml_tensor* timesteps,
|
||||||
@ -851,41 +838,41 @@ namespace Flux {
|
|||||||
if (params.is_chroma) {
|
if (params.is_chroma) {
|
||||||
int64_t mod_index_length = 344;
|
int64_t mod_index_length = 344;
|
||||||
auto approx = std::dynamic_pointer_cast<ChromaApproximator>(blocks["distilled_guidance_layer"]);
|
auto approx = std::dynamic_pointer_cast<ChromaApproximator>(blocks["distilled_guidance_layer"]);
|
||||||
auto distill_timestep = ggml_ext_timestep_embedding(ctx, timesteps, 16, 10000, 1000.f);
|
auto distill_timestep = ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, 16, 10000, 1000.f);
|
||||||
auto distill_guidance = ggml_ext_timestep_embedding(ctx, guidance, 16, 10000, 1000.f);
|
auto distill_guidance = ggml_ext_timestep_embedding(ctx->ggml_ctx, guidance, 16, 10000, 1000.f);
|
||||||
|
|
||||||
// auto mod_index_arange = ggml_arange(ctx, 0, (float)mod_index_length, 1);
|
// auto mod_index_arange = ggml_arange(ctx, 0, (float)mod_index_length, 1);
|
||||||
// ggml_arange tot working on a lot of backends, precomputing it on CPU instead
|
// ggml_arange tot working on a lot of backends, precomputing it on CPU instead
|
||||||
GGML_ASSERT(mod_index_arange != nullptr);
|
GGML_ASSERT(mod_index_arange != nullptr);
|
||||||
auto modulation_index = ggml_ext_timestep_embedding(ctx, mod_index_arange, 32, 10000, 1000.f); // [1, 344, 32]
|
auto modulation_index = ggml_ext_timestep_embedding(ctx->ggml_ctx, mod_index_arange, 32, 10000, 1000.f); // [1, 344, 32]
|
||||||
|
|
||||||
// Batch broadcast (will it ever be useful)
|
// Batch broadcast (will it ever be useful)
|
||||||
modulation_index = ggml_repeat(ctx, modulation_index, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, modulation_index->ne[0], modulation_index->ne[1], img->ne[2])); // [N, 344, 32]
|
modulation_index = ggml_repeat(ctx->ggml_ctx, modulation_index, ggml_new_tensor_3d(ctx->ggml_ctx, GGML_TYPE_F32, modulation_index->ne[0], modulation_index->ne[1], img->ne[2])); // [N, 344, 32]
|
||||||
|
|
||||||
auto timestep_guidance = ggml_concat(ctx, distill_timestep, distill_guidance, 0); // [N, 1, 32]
|
auto timestep_guidance = ggml_concat(ctx->ggml_ctx, distill_timestep, distill_guidance, 0); // [N, 1, 32]
|
||||||
timestep_guidance = ggml_repeat(ctx, timestep_guidance, modulation_index); // [N, 344, 32]
|
timestep_guidance = ggml_repeat(ctx->ggml_ctx, timestep_guidance, modulation_index); // [N, 344, 32]
|
||||||
|
|
||||||
vec = ggml_concat(ctx, timestep_guidance, modulation_index, 0); // [N, 344, 64]
|
vec = ggml_concat(ctx->ggml_ctx, timestep_guidance, modulation_index, 0); // [N, 344, 64]
|
||||||
// Permute for consistency with non-distilled modulation implementation
|
// Permute for consistency with non-distilled modulation implementation
|
||||||
vec = ggml_cont(ctx, ggml_permute(ctx, vec, 0, 2, 1, 3)); // [344, N, 64]
|
vec = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, vec, 0, 2, 1, 3)); // [344, N, 64]
|
||||||
vec = approx->forward(ctx, vec); // [344, N, hidden_size]
|
vec = approx->forward(ctx, vec); // [344, N, hidden_size]
|
||||||
|
|
||||||
if (y != nullptr) {
|
if (y != nullptr) {
|
||||||
txt_img_mask = ggml_pad(ctx, y, img->ne[1], 0, 0, 0);
|
txt_img_mask = ggml_pad(ctx->ggml_ctx, y, img->ne[1], 0, 0, 0);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
auto time_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["time_in"]);
|
auto time_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["time_in"]);
|
||||||
auto vector_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["vector_in"]);
|
auto vector_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["vector_in"]);
|
||||||
vec = time_in->forward(ctx, ggml_ext_timestep_embedding(ctx, timesteps, 256, 10000, 1000.f));
|
vec = time_in->forward(ctx, ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, 256, 10000, 1000.f));
|
||||||
if (params.guidance_embed) {
|
if (params.guidance_embed) {
|
||||||
GGML_ASSERT(guidance != nullptr);
|
GGML_ASSERT(guidance != nullptr);
|
||||||
auto guidance_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["guidance_in"]);
|
auto guidance_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["guidance_in"]);
|
||||||
// bf16 and fp16 result is different
|
// bf16 and fp16 result is different
|
||||||
auto g_in = ggml_ext_timestep_embedding(ctx, guidance, 256, 10000, 1000.f);
|
auto g_in = ggml_ext_timestep_embedding(ctx->ggml_ctx, guidance, 256, 10000, 1000.f);
|
||||||
vec = ggml_add(ctx, vec, guidance_in->forward(ctx, g_in));
|
vec = ggml_add(ctx->ggml_ctx, vec, guidance_in->forward(ctx, g_in));
|
||||||
}
|
}
|
||||||
|
|
||||||
vec = ggml_add(ctx, vec, vector_in->forward(ctx, y));
|
vec = ggml_add(ctx->ggml_ctx, vec, vector_in->forward(ctx, y));
|
||||||
}
|
}
|
||||||
|
|
||||||
txt = txt_in->forward(ctx, txt);
|
txt = txt_in->forward(ctx, txt);
|
||||||
@ -897,31 +884,31 @@ namespace Flux {
|
|||||||
|
|
||||||
auto block = std::dynamic_pointer_cast<DoubleStreamBlock>(blocks["double_blocks." + std::to_string(i)]);
|
auto block = std::dynamic_pointer_cast<DoubleStreamBlock>(blocks["double_blocks." + std::to_string(i)]);
|
||||||
|
|
||||||
auto img_txt = block->forward(ctx, backend, img, txt, vec, pe, txt_img_mask);
|
auto img_txt = block->forward(ctx, img, txt, vec, pe, txt_img_mask);
|
||||||
img = img_txt.first; // [N, n_img_token, hidden_size]
|
img = img_txt.first; // [N, n_img_token, hidden_size]
|
||||||
txt = img_txt.second; // [N, n_txt_token, hidden_size]
|
txt = img_txt.second; // [N, n_txt_token, hidden_size]
|
||||||
}
|
}
|
||||||
|
|
||||||
auto txt_img = ggml_concat(ctx, txt, img, 1); // [N, n_txt_token + n_img_token, hidden_size]
|
auto txt_img = ggml_concat(ctx->ggml_ctx, txt, img, 1); // [N, n_txt_token + n_img_token, hidden_size]
|
||||||
for (int i = 0; i < params.depth_single_blocks; i++) {
|
for (int i = 0; i < params.depth_single_blocks; i++) {
|
||||||
if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i + params.depth) != skip_layers.end()) {
|
if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i + params.depth) != skip_layers.end()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto block = std::dynamic_pointer_cast<SingleStreamBlock>(blocks["single_blocks." + std::to_string(i)]);
|
auto block = std::dynamic_pointer_cast<SingleStreamBlock>(blocks["single_blocks." + std::to_string(i)]);
|
||||||
|
|
||||||
txt_img = block->forward(ctx, backend, txt_img, vec, pe, txt_img_mask);
|
txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask);
|
||||||
}
|
}
|
||||||
|
|
||||||
txt_img = ggml_cont(ctx, ggml_permute(ctx, txt_img, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
|
txt_img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_img, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
|
||||||
img = ggml_view_3d(ctx,
|
img = ggml_view_3d(ctx->ggml_ctx,
|
||||||
txt_img,
|
txt_img,
|
||||||
txt_img->ne[0],
|
txt_img->ne[0],
|
||||||
txt_img->ne[1],
|
txt_img->ne[1],
|
||||||
img->ne[1],
|
img->ne[1],
|
||||||
txt_img->nb[1],
|
txt_img->nb[1],
|
||||||
txt_img->nb[2],
|
txt_img->nb[2],
|
||||||
txt_img->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size]
|
txt_img->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size]
|
||||||
img = ggml_cont(ctx, ggml_permute(ctx, img, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
|
img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
|
||||||
|
|
||||||
if (final_layer) {
|
if (final_layer) {
|
||||||
img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels)
|
img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels)
|
||||||
@ -930,8 +917,7 @@ namespace Flux {
|
|||||||
return img;
|
return img;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward_chroma_radiance(struct ggml_context* ctx,
|
struct ggml_tensor* forward_chroma_radiance(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* timestep,
|
struct ggml_tensor* timestep,
|
||||||
struct ggml_tensor* context,
|
struct ggml_tensor* context,
|
||||||
@ -952,32 +938,32 @@ namespace Flux {
|
|||||||
int pad_h = (patch_size - H % patch_size) % patch_size;
|
int pad_h = (patch_size - H % patch_size) % patch_size;
|
||||||
int pad_w = (patch_size - W % patch_size) % patch_size;
|
int pad_w = (patch_size - W % patch_size) % patch_size;
|
||||||
|
|
||||||
auto img = pad_to_patch_size(ctx, x);
|
auto img = pad_to_patch_size(ctx->ggml_ctx, x);
|
||||||
auto orig_img = img;
|
auto orig_img = img;
|
||||||
|
|
||||||
auto img_in_patch = std::dynamic_pointer_cast<Conv2d>(blocks["img_in_patch"]);
|
auto img_in_patch = std::dynamic_pointer_cast<Conv2d>(blocks["img_in_patch"]);
|
||||||
|
|
||||||
img = img_in_patch->forward(ctx, img); // [N, hidden_size, H/patch_size, W/patch_size]
|
img = img_in_patch->forward(ctx, img); // [N, hidden_size, H/patch_size, W/patch_size]
|
||||||
img = ggml_reshape_3d(ctx, img, img->ne[0] * img->ne[1], img->ne[2], img->ne[3]); // [N, hidden_size, H/patch_size*W/patch_size]
|
img = ggml_reshape_3d(ctx->ggml_ctx, img, img->ne[0] * img->ne[1], img->ne[2], img->ne[3]); // [N, hidden_size, H/patch_size*W/patch_size]
|
||||||
img = ggml_cont(ctx, ggml_ext_torch_permute(ctx, img, 1, 0, 2, 3)); // [N, H/patch_size*W/patch_size, hidden_size]
|
img = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, img, 1, 0, 2, 3)); // [N, H/patch_size*W/patch_size, hidden_size]
|
||||||
|
|
||||||
auto out = forward_orig(ctx, backend, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, n_img_token, hidden_size]
|
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, n_img_token, hidden_size]
|
||||||
|
|
||||||
// nerf decode
|
// nerf decode
|
||||||
auto nerf_image_embedder = std::dynamic_pointer_cast<NerfEmbedder>(blocks["nerf_image_embedder"]);
|
auto nerf_image_embedder = std::dynamic_pointer_cast<NerfEmbedder>(blocks["nerf_image_embedder"]);
|
||||||
auto nerf_final_layer_conv = std::dynamic_pointer_cast<NerfFinalLayerConv>(blocks["nerf_final_layer_conv"]);
|
auto nerf_final_layer_conv = std::dynamic_pointer_cast<NerfFinalLayerConv>(blocks["nerf_final_layer_conv"]);
|
||||||
|
|
||||||
auto nerf_pixels = patchify(ctx, orig_img); // [N, num_patches, C * patch_size * patch_size]
|
auto nerf_pixels = patchify(ctx->ggml_ctx, orig_img); // [N, num_patches, C * patch_size * patch_size]
|
||||||
int64_t num_patches = nerf_pixels->ne[1];
|
int64_t num_patches = nerf_pixels->ne[1];
|
||||||
nerf_pixels = ggml_reshape_3d(ctx,
|
nerf_pixels = ggml_reshape_3d(ctx->ggml_ctx,
|
||||||
nerf_pixels,
|
nerf_pixels,
|
||||||
nerf_pixels->ne[0] / C,
|
nerf_pixels->ne[0] / C,
|
||||||
C,
|
C,
|
||||||
nerf_pixels->ne[1] * nerf_pixels->ne[2]); // [N*num_patches, C, patch_size*patch_size]
|
nerf_pixels->ne[1] * nerf_pixels->ne[2]); // [N*num_patches, C, patch_size*patch_size]
|
||||||
nerf_pixels = ggml_cont(ctx, ggml_ext_torch_permute(ctx, nerf_pixels, 1, 0, 2, 3)); // [N*num_patches, patch_size*patch_size, C]
|
nerf_pixels = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, nerf_pixels, 1, 0, 2, 3)); // [N*num_patches, patch_size*patch_size, C]
|
||||||
|
|
||||||
auto nerf_hidden = ggml_reshape_2d(ctx, out, out->ne[0], out->ne[1] * out->ne[2]); // [N*num_patches, hidden_size]
|
auto nerf_hidden = ggml_reshape_2d(ctx->ggml_ctx, out, out->ne[0], out->ne[1] * out->ne[2]); // [N*num_patches, hidden_size]
|
||||||
auto img_dct = nerf_image_embedder->forward(ctx, nerf_pixels, dct); // [N*num_patches, patch_size*patch_size, nerf_hidden_size]
|
auto img_dct = nerf_image_embedder->forward(ctx, nerf_pixels, dct); // [N*num_patches, patch_size*patch_size, nerf_hidden_size]
|
||||||
|
|
||||||
for (int i = 0; i < params.chroma_radiance_params.nerf_depth; i++) {
|
for (int i = 0; i < params.chroma_radiance_params.nerf_depth; i++) {
|
||||||
auto block = std::dynamic_pointer_cast<NerfGLUBlock>(blocks["nerf_blocks." + std::to_string(i)]);
|
auto block = std::dynamic_pointer_cast<NerfGLUBlock>(blocks["nerf_blocks." + std::to_string(i)]);
|
||||||
@ -985,17 +971,16 @@ namespace Flux {
|
|||||||
img_dct = block->forward(ctx, img_dct, nerf_hidden);
|
img_dct = block->forward(ctx, img_dct, nerf_hidden);
|
||||||
}
|
}
|
||||||
|
|
||||||
img_dct = ggml_cont(ctx, ggml_ext_torch_permute(ctx, img_dct, 1, 0, 2, 3)); // [N*num_patches, nerf_hidden_size, patch_size*patch_size]
|
img_dct = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, img_dct, 1, 0, 2, 3)); // [N*num_patches, nerf_hidden_size, patch_size*patch_size]
|
||||||
img_dct = ggml_reshape_3d(ctx, img_dct, img_dct->ne[0] * img_dct->ne[1], num_patches, img_dct->ne[2] / num_patches); // [N, num_patches, nerf_hidden_size*patch_size*patch_size]
|
img_dct = ggml_reshape_3d(ctx->ggml_ctx, img_dct, img_dct->ne[0] * img_dct->ne[1], num_patches, img_dct->ne[2] / num_patches); // [N, num_patches, nerf_hidden_size*patch_size*patch_size]
|
||||||
img_dct = unpatchify(ctx, img_dct, (H + pad_h) / patch_size, (W + pad_w) / patch_size); // [N, nerf_hidden_size, H, W]
|
img_dct = unpatchify(ctx->ggml_ctx, img_dct, (H + pad_h) / patch_size, (W + pad_w) / patch_size); // [N, nerf_hidden_size, H, W]
|
||||||
|
|
||||||
out = nerf_final_layer_conv->forward(ctx, img_dct); // [N, C, H, W]
|
out = nerf_final_layer_conv->forward(ctx, img_dct); // [N, C, H, W]
|
||||||
|
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward_flux_chroma(struct ggml_context* ctx,
|
struct ggml_tensor* forward_flux_chroma(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* timestep,
|
struct ggml_tensor* timestep,
|
||||||
struct ggml_tensor* context,
|
struct ggml_tensor* context,
|
||||||
@ -1016,58 +1001,57 @@ namespace Flux {
|
|||||||
int pad_h = (patch_size - H % patch_size) % patch_size;
|
int pad_h = (patch_size - H % patch_size) % patch_size;
|
||||||
int pad_w = (patch_size - W % patch_size) % patch_size;
|
int pad_w = (patch_size - W % patch_size) % patch_size;
|
||||||
|
|
||||||
auto img = process_img(ctx, x);
|
auto img = process_img(ctx->ggml_ctx, x);
|
||||||
uint64_t img_tokens = img->ne[1];
|
uint64_t img_tokens = img->ne[1];
|
||||||
|
|
||||||
if (params.version == VERSION_FLUX_FILL) {
|
if (params.version == VERSION_FLUX_FILL) {
|
||||||
GGML_ASSERT(c_concat != nullptr);
|
GGML_ASSERT(c_concat != nullptr);
|
||||||
ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
|
ggml_tensor* masked = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
|
||||||
ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
|
ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
|
||||||
|
|
||||||
masked = process_img(ctx, masked);
|
masked = process_img(ctx->ggml_ctx, masked);
|
||||||
mask = process_img(ctx, mask);
|
mask = process_img(ctx->ggml_ctx, mask);
|
||||||
|
|
||||||
img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0);
|
img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, masked, mask, 0), 0);
|
||||||
} else if (params.version == VERSION_FLEX_2) {
|
} else if (params.version == VERSION_FLEX_2) {
|
||||||
GGML_ASSERT(c_concat != nullptr);
|
GGML_ASSERT(c_concat != nullptr);
|
||||||
ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
|
ggml_tensor* masked = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
|
||||||
ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
|
ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
|
||||||
ggml_tensor* control = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));
|
ggml_tensor* control = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));
|
||||||
|
|
||||||
masked = process_img(ctx, masked);
|
masked = process_img(ctx->ggml_ctx, masked);
|
||||||
mask = process_img(ctx, mask);
|
mask = process_img(ctx->ggml_ctx, mask);
|
||||||
control = process_img(ctx, control);
|
control = process_img(ctx->ggml_ctx, control);
|
||||||
|
|
||||||
img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0);
|
img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, ggml_concat(ctx->ggml_ctx, masked, mask, 0), control, 0), 0);
|
||||||
} else if (params.version == VERSION_FLUX_CONTROLS) {
|
} else if (params.version == VERSION_FLUX_CONTROLS) {
|
||||||
GGML_ASSERT(c_concat != nullptr);
|
GGML_ASSERT(c_concat != nullptr);
|
||||||
|
|
||||||
auto control = process_img(ctx, c_concat);
|
auto control = process_img(ctx->ggml_ctx, c_concat);
|
||||||
img = ggml_concat(ctx, img, control, 0);
|
img = ggml_concat(ctx->ggml_ctx, img, control, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ref_latents.size() > 0) {
|
if (ref_latents.size() > 0) {
|
||||||
for (ggml_tensor* ref : ref_latents) {
|
for (ggml_tensor* ref : ref_latents) {
|
||||||
ref = process_img(ctx, ref);
|
ref = process_img(ctx->ggml_ctx, ref);
|
||||||
img = ggml_concat(ctx, img, ref, 1);
|
img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto out = forward_orig(ctx, backend, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size]
|
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size]
|
||||||
|
|
||||||
if (out->ne[1] > img_tokens) {
|
if (out->ne[1] > img_tokens) {
|
||||||
out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size]
|
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size]
|
||||||
out = ggml_view_3d(ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
|
out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
|
||||||
out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size]
|
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size]
|
||||||
}
|
}
|
||||||
|
|
||||||
// rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
|
// rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
|
||||||
out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size); // [N, C, H + pad_h, W + pad_w]
|
out = unpatchify(ctx->ggml_ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size); // [N, C, H + pad_h, W + pad_w]
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* timestep,
|
struct ggml_tensor* timestep,
|
||||||
struct ggml_tensor* context,
|
struct ggml_tensor* context,
|
||||||
@ -1091,7 +1075,6 @@ namespace Flux {
|
|||||||
|
|
||||||
if (params.version == VERSION_CHROMA_RADIANCE) {
|
if (params.version == VERSION_CHROMA_RADIANCE) {
|
||||||
return forward_chroma_radiance(ctx,
|
return forward_chroma_radiance(ctx,
|
||||||
backend,
|
|
||||||
x,
|
x,
|
||||||
timestep,
|
timestep,
|
||||||
context,
|
context,
|
||||||
@ -1105,7 +1088,6 @@ namespace Flux {
|
|||||||
skip_layers);
|
skip_layers);
|
||||||
} else {
|
} else {
|
||||||
return forward_flux_chroma(ctx,
|
return forward_flux_chroma(ctx,
|
||||||
backend,
|
|
||||||
x,
|
x,
|
||||||
timestep,
|
timestep,
|
||||||
context,
|
context,
|
||||||
@ -1136,11 +1118,9 @@ namespace Flux {
|
|||||||
const String2GGMLType& tensor_types = {},
|
const String2GGMLType& tensor_types = {},
|
||||||
const std::string prefix = "",
|
const std::string prefix = "",
|
||||||
SDVersion version = VERSION_FLUX,
|
SDVersion version = VERSION_FLUX,
|
||||||
bool flash_attn = false,
|
|
||||||
bool use_mask = false)
|
bool use_mask = false)
|
||||||
: GGMLRunner(backend, offload_params_to_cpu), version(version), use_mask(use_mask) {
|
: GGMLRunner(backend, offload_params_to_cpu), version(version), use_mask(use_mask) {
|
||||||
flux_params.version = version;
|
flux_params.version = version;
|
||||||
flux_params.flash_attn = flash_attn;
|
|
||||||
flux_params.guidance_embed = false;
|
flux_params.guidance_embed = false;
|
||||||
flux_params.depth = 0;
|
flux_params.depth = 0;
|
||||||
flux_params.depth_single_blocks = 0;
|
flux_params.depth_single_blocks = 0;
|
||||||
@ -1323,8 +1303,9 @@ namespace Flux {
|
|||||||
set_backend_tensor_data(dct, dct_vec.data());
|
set_backend_tensor_data(dct, dct_vec.data());
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* out = flux.forward(compute_ctx,
|
auto runner_ctx = get_context();
|
||||||
runtime_backend,
|
|
||||||
|
struct ggml_tensor* out = flux.forward(&runner_ctx,
|
||||||
x,
|
x,
|
||||||
timesteps,
|
timesteps,
|
||||||
context,
|
context,
|
||||||
@ -1435,8 +1416,7 @@ namespace Flux {
|
|||||||
tensor_types,
|
tensor_types,
|
||||||
"model.diffusion_model",
|
"model.diffusion_model",
|
||||||
VERSION_CHROMA_RADIANCE,
|
VERSION_CHROMA_RADIANCE,
|
||||||
false,
|
false);
|
||||||
true);
|
|
||||||
|
|
||||||
flux->alloc_params_buffer();
|
flux->alloc_params_buffer();
|
||||||
std::map<std::string, ggml_tensor*> tensors;
|
std::map<std::string, ggml_tensor*> tensors;
|
||||||
|
|||||||
103
ggml_extend.hpp
103
ggml_extend.hpp
@ -1157,8 +1157,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
|
|||||||
struct ggml_tensor* mask = nullptr,
|
struct ggml_tensor* mask = nullptr,
|
||||||
bool diag_mask_inf = false,
|
bool diag_mask_inf = false,
|
||||||
bool skip_reshape = false,
|
bool skip_reshape = false,
|
||||||
bool flash_attn = false, // avoid overflow
|
bool flash_attn = false,
|
||||||
float kv_scale = 1.0f) {
|
float kv_scale = 1.0f) { // avoid overflow
|
||||||
int64_t L_q;
|
int64_t L_q;
|
||||||
int64_t L_k;
|
int64_t L_k;
|
||||||
int64_t C;
|
int64_t C;
|
||||||
@ -1462,6 +1462,13 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {
|
|||||||
|
|
||||||
typedef std::map<std::string, enum ggml_type> String2GGMLType;
|
typedef std::map<std::string, enum ggml_type> String2GGMLType;
|
||||||
|
|
||||||
|
struct GGMLRunnerContext {
|
||||||
|
ggml_backend_t backend = nullptr;
|
||||||
|
ggml_context* ggml_ctx = nullptr;
|
||||||
|
bool flash_attn_enabled = false;
|
||||||
|
bool conv2d_direct_enabled = false;
|
||||||
|
};
|
||||||
|
|
||||||
struct GGMLRunner {
|
struct GGMLRunner {
|
||||||
protected:
|
protected:
|
||||||
typedef std::function<struct ggml_cgraph*()> get_graph_cb_t;
|
typedef std::function<struct ggml_cgraph*()> get_graph_cb_t;
|
||||||
@ -1488,6 +1495,9 @@ protected:
|
|||||||
std::map<std::string, struct ggml_tensor*> cache_tensor_map; // name -> tensor
|
std::map<std::string, struct ggml_tensor*> cache_tensor_map; // name -> tensor
|
||||||
const std::string final_result_name = "ggml_runner_final_result_tensor";
|
const std::string final_result_name = "ggml_runner_final_result_tensor";
|
||||||
|
|
||||||
|
bool flash_attn_enabled = false;
|
||||||
|
bool conv2d_direct_enabled = false;
|
||||||
|
|
||||||
void alloc_params_ctx() {
|
void alloc_params_ctx() {
|
||||||
struct ggml_init_params params;
|
struct ggml_init_params params;
|
||||||
params.mem_size = static_cast<size_t>(MAX_PARAMS_TENSOR_NUM * ggml_tensor_overhead());
|
params.mem_size = static_cast<size_t>(MAX_PARAMS_TENSOR_NUM * ggml_tensor_overhead());
|
||||||
@ -1744,6 +1754,15 @@ public:
|
|||||||
free_cache_ctx_and_buffer();
|
free_cache_ctx_and_buffer();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
virtual GGMLRunnerContext get_context() {
|
||||||
|
GGMLRunnerContext runner_ctx;
|
||||||
|
runner_ctx.ggml_ctx = compute_ctx;
|
||||||
|
runner_ctx.backend = runtime_backend;
|
||||||
|
runner_ctx.flash_attn_enabled = flash_attn_enabled;
|
||||||
|
runner_ctx.conv2d_direct_enabled = conv2d_direct_enabled;
|
||||||
|
return runner_ctx;
|
||||||
|
}
|
||||||
|
|
||||||
void reset_compute_ctx() {
|
void reset_compute_ctx() {
|
||||||
free_compute_ctx();
|
free_compute_ctx();
|
||||||
alloc_compute_ctx();
|
alloc_compute_ctx();
|
||||||
@ -1864,6 +1883,14 @@ public:
|
|||||||
free_compute_buffer();
|
free_compute_buffer();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void set_flash_attention_enabled(bool enabled) {
|
||||||
|
flash_attn_enabled = enabled;
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_conv2d_direct_enabled(bool enabled) {
|
||||||
|
conv2d_direct_enabled = enabled;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class GGMLBlock {
|
class GGMLBlock {
|
||||||
@ -1955,12 +1982,12 @@ public:
|
|||||||
|
|
||||||
class UnaryBlock : public GGMLBlock {
|
class UnaryBlock : public GGMLBlock {
|
||||||
public:
|
public:
|
||||||
virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) = 0;
|
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
class Identity : public UnaryBlock {
|
class Identity : public UnaryBlock {
|
||||||
public:
|
public:
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -1974,7 +2001,7 @@ protected:
|
|||||||
bool force_prec_f32;
|
bool force_prec_f32;
|
||||||
float scale;
|
float scale;
|
||||||
|
|
||||||
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
|
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
|
||||||
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
|
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
|
||||||
if (in_features % ggml_blck_size(wtype) != 0 || force_f32) {
|
if (in_features % ggml_blck_size(wtype) != 0 || force_f32) {
|
||||||
wtype = GGML_TYPE_F32;
|
wtype = GGML_TYPE_F32;
|
||||||
@ -2000,13 +2027,13 @@ public:
|
|||||||
force_prec_f32(force_prec_f32),
|
force_prec_f32(force_prec_f32),
|
||||||
scale(scale) {}
|
scale(scale) {}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
struct ggml_tensor* w = params["weight"];
|
struct ggml_tensor* w = params["weight"];
|
||||||
struct ggml_tensor* b = nullptr;
|
struct ggml_tensor* b = nullptr;
|
||||||
if (bias) {
|
if (bias) {
|
||||||
b = params["bias"];
|
b = params["bias"];
|
||||||
}
|
}
|
||||||
return ggml_ext_linear(ctx, x, w, b, force_prec_f32, scale);
|
return ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -2022,7 +2049,7 @@ class Embedding : public UnaryBlock {
|
|||||||
protected:
|
protected:
|
||||||
int64_t embedding_dim;
|
int64_t embedding_dim;
|
||||||
int64_t num_embeddings;
|
int64_t num_embeddings;
|
||||||
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
|
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") override {
|
||||||
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
|
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
|
||||||
if (!support_get_rows(wtype)) {
|
if (!support_get_rows(wtype)) {
|
||||||
wtype = GGML_TYPE_F32;
|
wtype = GGML_TYPE_F32;
|
||||||
@ -2036,7 +2063,7 @@ public:
|
|||||||
num_embeddings(num_embeddings) {
|
num_embeddings(num_embeddings) {
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* input_ids) {
|
struct ggml_tensor* input_ids) {
|
||||||
// input_ids: [N, n_token]
|
// input_ids: [N, n_token]
|
||||||
auto weight = params["weight"];
|
auto weight = params["weight"];
|
||||||
@ -2044,11 +2071,11 @@ public:
|
|||||||
// There are issues with ggml batch inference, so we are expanding it here first.
|
// There are issues with ggml batch inference, so we are expanding it here first.
|
||||||
// TODO: fix ggml batch inference
|
// TODO: fix ggml batch inference
|
||||||
int64_t n = input_ids->ne[1];
|
int64_t n = input_ids->ne[1];
|
||||||
input_ids = ggml_reshape_1d(ctx, input_ids, input_ids->ne[0] * input_ids->ne[1]);
|
input_ids = ggml_reshape_1d(ctx->ggml_ctx, input_ids, input_ids->ne[0] * input_ids->ne[1]);
|
||||||
|
|
||||||
input_ids = ggml_reshape_3d(ctx, input_ids, input_ids->ne[0], 1, input_ids->ne[1]);
|
input_ids = ggml_reshape_3d(ctx->ggml_ctx, input_ids, input_ids->ne[0], 1, input_ids->ne[1]);
|
||||||
auto embedding = ggml_get_rows(ctx, weight, input_ids);
|
auto embedding = ggml_get_rows(ctx->ggml_ctx, weight, input_ids);
|
||||||
embedding = ggml_reshape_3d(ctx, embedding, embedding->ne[0], embedding->ne[1] / n, n);
|
embedding = ggml_reshape_3d(ctx->ggml_ctx, embedding, embedding->ne[0], embedding->ne[1] / n, n);
|
||||||
|
|
||||||
// [N, n_token, embedding_dim]
|
// [N, n_token, embedding_dim]
|
||||||
return embedding;
|
return embedding;
|
||||||
@ -2064,10 +2091,9 @@ protected:
|
|||||||
std::pair<int, int> padding;
|
std::pair<int, int> padding;
|
||||||
std::pair<int, int> dilation;
|
std::pair<int, int> dilation;
|
||||||
bool bias;
|
bool bias;
|
||||||
bool direct = false;
|
|
||||||
float scale = 1.f;
|
float scale = 1.f;
|
||||||
|
|
||||||
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
|
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") override {
|
||||||
enum ggml_type wtype = GGML_TYPE_F16;
|
enum ggml_type wtype = GGML_TYPE_F16;
|
||||||
params["weight"] = ggml_new_tensor_4d(ctx, wtype, kernel_size.second, kernel_size.first, in_channels, out_channels);
|
params["weight"] = ggml_new_tensor_4d(ctx, wtype, kernel_size.second, kernel_size.first, in_channels, out_channels);
|
||||||
if (bias) {
|
if (bias) {
|
||||||
@ -2092,10 +2118,6 @@ public:
|
|||||||
dilation(dilation),
|
dilation(dilation),
|
||||||
bias(bias) {}
|
bias(bias) {}
|
||||||
|
|
||||||
void enable_direct() {
|
|
||||||
direct = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_scale(float scale_value) {
|
void set_scale(float scale_value) {
|
||||||
scale = scale_value;
|
scale = scale_value;
|
||||||
}
|
}
|
||||||
@ -2104,13 +2126,13 @@ public:
|
|||||||
return "Conv2d";
|
return "Conv2d";
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
struct ggml_tensor* w = params["weight"];
|
struct ggml_tensor* w = params["weight"];
|
||||||
struct ggml_tensor* b = nullptr;
|
struct ggml_tensor* b = nullptr;
|
||||||
if (bias) {
|
if (bias) {
|
||||||
b = params["bias"];
|
b = params["bias"];
|
||||||
}
|
}
|
||||||
return ggml_ext_conv_2d(ctx,
|
return ggml_ext_conv_2d(ctx->ggml_ctx,
|
||||||
x,
|
x,
|
||||||
w,
|
w,
|
||||||
b,
|
b,
|
||||||
@ -2120,7 +2142,7 @@ public:
|
|||||||
padding.first,
|
padding.first,
|
||||||
dilation.second,
|
dilation.second,
|
||||||
dilation.first,
|
dilation.first,
|
||||||
direct,
|
ctx->conv2d_direct_enabled,
|
||||||
scale);
|
scale);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -2135,7 +2157,7 @@ protected:
|
|||||||
int64_t dilation;
|
int64_t dilation;
|
||||||
bool bias;
|
bool bias;
|
||||||
|
|
||||||
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
|
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") override {
|
||||||
enum ggml_type wtype = GGML_TYPE_F16;
|
enum ggml_type wtype = GGML_TYPE_F16;
|
||||||
params["weight"] = ggml_new_tensor_4d(ctx, wtype, 1, kernel_size, in_channels, out_channels); // 5d => 4d
|
params["weight"] = ggml_new_tensor_4d(ctx, wtype, 1, kernel_size, in_channels, out_channels); // 5d => 4d
|
||||||
if (bias) {
|
if (bias) {
|
||||||
@ -2162,13 +2184,13 @@ public:
|
|||||||
|
|
||||||
// x: [N, IC, ID, IH*IW]
|
// x: [N, IC, ID, IH*IW]
|
||||||
// result: [N, OC, OD, OH*OW]
|
// result: [N, OC, OD, OH*OW]
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
struct ggml_tensor* w = params["weight"];
|
struct ggml_tensor* w = params["weight"];
|
||||||
struct ggml_tensor* b = nullptr;
|
struct ggml_tensor* b = nullptr;
|
||||||
if (bias) {
|
if (bias) {
|
||||||
b = params["bias"];
|
b = params["bias"];
|
||||||
}
|
}
|
||||||
return ggml_ext_conv_3d_nx1x1(ctx, x, w, b, stride, padding, dilation);
|
return ggml_ext_conv_3d_nx1x1(ctx->ggml_ctx, x, w, b, stride, padding, dilation);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -2182,7 +2204,7 @@ protected:
|
|||||||
std::tuple<int, int, int> dilation;
|
std::tuple<int, int, int> dilation;
|
||||||
bool bias;
|
bool bias;
|
||||||
|
|
||||||
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
|
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") override {
|
||||||
enum ggml_type wtype = GGML_TYPE_F16;
|
enum ggml_type wtype = GGML_TYPE_F16;
|
||||||
params["weight"] = ggml_new_tensor_4d(ctx,
|
params["weight"] = ggml_new_tensor_4d(ctx,
|
||||||
wtype,
|
wtype,
|
||||||
@ -2211,13 +2233,13 @@ public:
|
|||||||
dilation(dilation),
|
dilation(dilation),
|
||||||
bias(bias) {}
|
bias(bias) {}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
struct ggml_tensor* w = params["weight"];
|
struct ggml_tensor* w = params["weight"];
|
||||||
struct ggml_tensor* b = nullptr;
|
struct ggml_tensor* b = nullptr;
|
||||||
if (bias) {
|
if (bias) {
|
||||||
b = params["bias"];
|
b = params["bias"];
|
||||||
}
|
}
|
||||||
return ggml_ext_conv_3d(ctx, x, w, b, in_channels,
|
return ggml_ext_conv_3d(ctx->ggml_ctx, x, w, b, in_channels,
|
||||||
std::get<2>(stride), std::get<1>(stride), std::get<0>(stride),
|
std::get<2>(stride), std::get<1>(stride), std::get<0>(stride),
|
||||||
std::get<2>(padding), std::get<1>(padding), std::get<0>(padding),
|
std::get<2>(padding), std::get<1>(padding), std::get<0>(padding),
|
||||||
std::get<2>(dilation), std::get<1>(dilation), std::get<0>(dilation));
|
std::get<2>(dilation), std::get<1>(dilation), std::get<0>(dilation));
|
||||||
@ -2231,7 +2253,7 @@ protected:
|
|||||||
bool elementwise_affine;
|
bool elementwise_affine;
|
||||||
bool bias;
|
bool bias;
|
||||||
|
|
||||||
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
|
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
|
||||||
if (elementwise_affine) {
|
if (elementwise_affine) {
|
||||||
enum ggml_type wtype = GGML_TYPE_F32;
|
enum ggml_type wtype = GGML_TYPE_F32;
|
||||||
params["weight"] = ggml_new_tensor_1d(ctx, wtype, normalized_shape);
|
params["weight"] = ggml_new_tensor_1d(ctx, wtype, normalized_shape);
|
||||||
@ -2252,7 +2274,7 @@ public:
|
|||||||
elementwise_affine(elementwise_affine),
|
elementwise_affine(elementwise_affine),
|
||||||
bias(bias) {}
|
bias(bias) {}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
struct ggml_tensor* w = nullptr;
|
struct ggml_tensor* w = nullptr;
|
||||||
struct ggml_tensor* b = nullptr;
|
struct ggml_tensor* b = nullptr;
|
||||||
|
|
||||||
@ -2262,7 +2284,7 @@ public:
|
|||||||
b = params["bias"];
|
b = params["bias"];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ggml_ext_layer_norm(ctx, x, w, b, eps);
|
return ggml_ext_layer_norm(ctx->ggml_ctx, x, w, b, eps);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -2273,7 +2295,7 @@ protected:
|
|||||||
float eps;
|
float eps;
|
||||||
bool affine;
|
bool affine;
|
||||||
|
|
||||||
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
|
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
|
||||||
if (affine) {
|
if (affine) {
|
||||||
enum ggml_type wtype = GGML_TYPE_F32;
|
enum ggml_type wtype = GGML_TYPE_F32;
|
||||||
enum ggml_type bias_wtype = GGML_TYPE_F32;
|
enum ggml_type bias_wtype = GGML_TYPE_F32;
|
||||||
@ -2292,14 +2314,14 @@ public:
|
|||||||
eps(eps),
|
eps(eps),
|
||||||
affine(affine) {}
|
affine(affine) {}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
struct ggml_tensor* w = nullptr;
|
struct ggml_tensor* w = nullptr;
|
||||||
struct ggml_tensor* b = nullptr;
|
struct ggml_tensor* b = nullptr;
|
||||||
if (affine) {
|
if (affine) {
|
||||||
w = params["weight"];
|
w = params["weight"];
|
||||||
b = params["bias"];
|
b = params["bias"];
|
||||||
}
|
}
|
||||||
return ggml_ext_group_norm(ctx, x, w, b, num_groups);
|
return ggml_ext_group_norm(ctx->ggml_ctx, x, w, b, num_groups);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -2314,7 +2336,7 @@ protected:
|
|||||||
int64_t hidden_size;
|
int64_t hidden_size;
|
||||||
float eps;
|
float eps;
|
||||||
|
|
||||||
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
|
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") override {
|
||||||
enum ggml_type wtype = GGML_TYPE_F32;
|
enum ggml_type wtype = GGML_TYPE_F32;
|
||||||
params["weight"] = ggml_new_tensor_1d(ctx, wtype, hidden_size);
|
params["weight"] = ggml_new_tensor_1d(ctx, wtype, hidden_size);
|
||||||
}
|
}
|
||||||
@ -2325,10 +2347,10 @@ public:
|
|||||||
: hidden_size(hidden_size),
|
: hidden_size(hidden_size),
|
||||||
eps(eps) {}
|
eps(eps) {}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
struct ggml_tensor* w = params["weight"];
|
struct ggml_tensor* w = params["weight"];
|
||||||
x = ggml_rms_norm(ctx, x, eps);
|
x = ggml_rms_norm(ctx->ggml_ctx, x, eps);
|
||||||
x = ggml_mul_inplace(ctx, x, w);
|
x = ggml_mul_inplace(ctx->ggml_ctx, x, w);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -2364,8 +2386,7 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// x: [N, n_token, embed_dim]
|
// x: [N, n_token, embed_dim]
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
bool mask = false) {
|
bool mask = false) {
|
||||||
auto q_proj = std::dynamic_pointer_cast<Linear>(blocks[q_proj_name]);
|
auto q_proj = std::dynamic_pointer_cast<Linear>(blocks[q_proj_name]);
|
||||||
@ -2377,7 +2398,7 @@ public:
|
|||||||
struct ggml_tensor* k = k_proj->forward(ctx, x);
|
struct ggml_tensor* k = k_proj->forward(ctx, x);
|
||||||
struct ggml_tensor* v = v_proj->forward(ctx, x);
|
struct ggml_tensor* v = v_proj->forward(ctx, x);
|
||||||
|
|
||||||
x = ggml_ext_attention_ext(ctx, backend, q, k, v, n_head, nullptr, mask); // [N, n_token, embed_dim]
|
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, mask); // [N, n_token, embed_dim]
|
||||||
|
|
||||||
x = out_proj->forward(ctx, x); // [N, n_token, embed_dim]
|
x = out_proj->forward(ctx, x); // [N, n_token, embed_dim]
|
||||||
return x;
|
return x;
|
||||||
|
|||||||
2
ltxv.hpp
2
ltxv.hpp
@ -27,7 +27,7 @@ namespace LTXV {
|
|||||||
bias));
|
bias));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
bool causal = true) {
|
bool causal = true) {
|
||||||
// x: [N*IC, ID, IH, IW]
|
// x: [N*IC, ID, IH, IW]
|
||||||
|
|||||||
243
mmdit.hpp
243
mmdit.hpp
@ -27,13 +27,13 @@ public:
|
|||||||
blocks["fc2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_features, out_features, bias));
|
blocks["fc2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_features, out_features, bias));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
// x: [N, n_token, in_features]
|
// x: [N, n_token, in_features]
|
||||||
auto fc1 = std::dynamic_pointer_cast<Linear>(blocks["fc1"]);
|
auto fc1 = std::dynamic_pointer_cast<Linear>(blocks["fc1"]);
|
||||||
auto fc2 = std::dynamic_pointer_cast<Linear>(blocks["fc2"]);
|
auto fc2 = std::dynamic_pointer_cast<Linear>(blocks["fc2"]);
|
||||||
|
|
||||||
x = fc1->forward(ctx, x);
|
x = fc1->forward(ctx, x);
|
||||||
x = ggml_gelu_inplace(ctx, x);
|
x = ggml_gelu_inplace(ctx->ggml_ctx, x);
|
||||||
x = fc2->forward(ctx, x);
|
x = fc2->forward(ctx, x);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
@ -72,7 +72,7 @@ public:
|
|||||||
bias));
|
bias));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
// x: [N, C, H, W]
|
// x: [N, C, H, W]
|
||||||
// return: [N, H*W, embed_dim]
|
// return: [N, H*W, embed_dim]
|
||||||
auto proj = std::dynamic_pointer_cast<Conv2d>(blocks["proj"]);
|
auto proj = std::dynamic_pointer_cast<Conv2d>(blocks["proj"]);
|
||||||
@ -82,13 +82,13 @@ public:
|
|||||||
int64_t H = x->ne[1];
|
int64_t H = x->ne[1];
|
||||||
int pad_h = (patch_size - H % patch_size) % patch_size;
|
int pad_h = (patch_size - H % patch_size) % patch_size;
|
||||||
int pad_w = (patch_size - W % patch_size) % patch_size;
|
int pad_w = (patch_size - W % patch_size) % patch_size;
|
||||||
x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // TODO: reflect pad mode
|
x = ggml_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0); // TODO: reflect pad mode
|
||||||
}
|
}
|
||||||
x = proj->forward(ctx, x);
|
x = proj->forward(ctx, x);
|
||||||
|
|
||||||
if (flatten) {
|
if (flatten) {
|
||||||
x = ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1], x->ne[2], x->ne[3]);
|
x = ggml_reshape_3d(ctx->ggml_ctx, x, x->ne[0] * x->ne[1], x->ne[2], x->ne[3]);
|
||||||
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3));
|
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3));
|
||||||
}
|
}
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
@ -107,16 +107,16 @@ public:
|
|||||||
blocks["mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size, true, true));
|
blocks["mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size, true, true));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* t) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* t) {
|
||||||
// t: [N, ]
|
// t: [N, ]
|
||||||
// return: [N, hidden_size]
|
// return: [N, hidden_size]
|
||||||
auto mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["mlp.0"]);
|
auto mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["mlp.0"]);
|
||||||
auto mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["mlp.2"]);
|
auto mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["mlp.2"]);
|
||||||
|
|
||||||
auto t_freq = ggml_ext_timestep_embedding(ctx, t, frequency_embedding_size); // [N, frequency_embedding_size]
|
auto t_freq = ggml_ext_timestep_embedding(ctx->ggml_ctx, t, frequency_embedding_size); // [N, frequency_embedding_size]
|
||||||
|
|
||||||
auto t_emb = mlp_0->forward(ctx, t_freq);
|
auto t_emb = mlp_0->forward(ctx, t_freq);
|
||||||
t_emb = ggml_silu_inplace(ctx, t_emb);
|
t_emb = ggml_silu_inplace(ctx->ggml_ctx, t_emb);
|
||||||
t_emb = mlp_2->forward(ctx, t_emb);
|
t_emb = mlp_2->forward(ctx, t_emb);
|
||||||
return t_emb;
|
return t_emb;
|
||||||
}
|
}
|
||||||
@ -131,14 +131,14 @@ public:
|
|||||||
blocks["mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size, true, true));
|
blocks["mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size, true, true));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
// x: [N, input_dim]
|
// x: [N, input_dim]
|
||||||
// return: [N, hidden_size]
|
// return: [N, hidden_size]
|
||||||
auto mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["mlp.0"]);
|
auto mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["mlp.0"]);
|
||||||
auto mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["mlp.2"]);
|
auto mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["mlp.2"]);
|
||||||
|
|
||||||
x = mlp_0->forward(ctx, x);
|
x = mlp_0->forward(ctx, x);
|
||||||
x = ggml_silu_inplace(ctx, x);
|
x = ggml_silu_inplace(ctx->ggml_ctx, x);
|
||||||
x = mlp_2->forward(ctx, x);
|
x = mlp_2->forward(ctx, x);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
@ -149,16 +149,14 @@ public:
|
|||||||
int64_t num_heads;
|
int64_t num_heads;
|
||||||
bool pre_only;
|
bool pre_only;
|
||||||
std::string qk_norm;
|
std::string qk_norm;
|
||||||
bool flash_attn;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
SelfAttention(int64_t dim,
|
SelfAttention(int64_t dim,
|
||||||
int64_t num_heads = 8,
|
int64_t num_heads = 8,
|
||||||
std::string qk_norm = "",
|
std::string qk_norm = "",
|
||||||
bool qkv_bias = false,
|
bool qkv_bias = false,
|
||||||
bool pre_only = false,
|
bool pre_only = false)
|
||||||
bool flash_attn = false)
|
: num_heads(num_heads), pre_only(pre_only), qk_norm(qk_norm) {
|
||||||
: num_heads(num_heads), pre_only(pre_only), qk_norm(qk_norm), flash_attn(flash_attn) {
|
|
||||||
int64_t d_head = dim / num_heads;
|
int64_t d_head = dim / num_heads;
|
||||||
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
|
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
|
||||||
if (!pre_only) {
|
if (!pre_only) {
|
||||||
@ -173,15 +171,15 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<struct ggml_tensor*> pre_attention(struct ggml_context* ctx, struct ggml_tensor* x) {
|
std::vector<struct ggml_tensor*> pre_attention(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
auto qkv_proj = std::dynamic_pointer_cast<Linear>(blocks["qkv"]);
|
auto qkv_proj = std::dynamic_pointer_cast<Linear>(blocks["qkv"]);
|
||||||
|
|
||||||
auto qkv = qkv_proj->forward(ctx, x);
|
auto qkv = qkv_proj->forward(ctx, x);
|
||||||
auto qkv_vec = split_qkv(ctx, qkv);
|
auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv);
|
||||||
int64_t head_dim = qkv_vec[0]->ne[0] / num_heads;
|
int64_t head_dim = qkv_vec[0]->ne[0] / num_heads;
|
||||||
auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head]
|
auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head]
|
||||||
auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head]
|
auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head]
|
||||||
auto v = qkv_vec[2]; // [N, n_token, n_head*d_head]
|
auto v = qkv_vec[2]; // [N, n_token, n_head*d_head]
|
||||||
|
|
||||||
if (qk_norm == "rms" || qk_norm == "ln") {
|
if (qk_norm == "rms" || qk_norm == "ln") {
|
||||||
auto ln_q = std::dynamic_pointer_cast<UnaryBlock>(blocks["ln_q"]);
|
auto ln_q = std::dynamic_pointer_cast<UnaryBlock>(blocks["ln_q"]);
|
||||||
@ -190,13 +188,13 @@ public:
|
|||||||
k = ln_k->forward(ctx, k);
|
k = ln_k->forward(ctx, k);
|
||||||
}
|
}
|
||||||
|
|
||||||
q = ggml_reshape_3d(ctx, q, q->ne[0] * q->ne[1], q->ne[2], q->ne[3]); // [N, n_token, n_head*d_head]
|
q = ggml_reshape_3d(ctx->ggml_ctx, q, q->ne[0] * q->ne[1], q->ne[2], q->ne[3]); // [N, n_token, n_head*d_head]
|
||||||
k = ggml_reshape_3d(ctx, k, k->ne[0] * k->ne[1], k->ne[2], k->ne[3]); // [N, n_token, n_head*d_head]
|
k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0] * k->ne[1], k->ne[2], k->ne[3]); // [N, n_token, n_head*d_head]
|
||||||
|
|
||||||
return {q, k, v};
|
return {q, k, v};
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* post_attention(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* post_attention(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
GGML_ASSERT(!pre_only);
|
GGML_ASSERT(!pre_only);
|
||||||
|
|
||||||
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
|
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
|
||||||
@ -206,12 +204,11 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// x: [N, n_token, dim]
|
// x: [N, n_token, dim]
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x) {
|
struct ggml_tensor* x) {
|
||||||
auto qkv = pre_attention(ctx, x);
|
auto qkv = pre_attention(ctx, x);
|
||||||
x = ggml_ext_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, true); // [N, n_token, dim]
|
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
|
||||||
x = post_attention(ctx, x); // [N, n_token, dim]
|
x = post_attention(ctx, x); // [N, n_token, dim]
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -236,7 +233,6 @@ public:
|
|||||||
int64_t num_heads;
|
int64_t num_heads;
|
||||||
bool pre_only;
|
bool pre_only;
|
||||||
bool self_attn;
|
bool self_attn;
|
||||||
bool flash_attn;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
DismantledBlock(int64_t hidden_size,
|
DismantledBlock(int64_t hidden_size,
|
||||||
@ -245,17 +241,16 @@ public:
|
|||||||
std::string qk_norm = "",
|
std::string qk_norm = "",
|
||||||
bool qkv_bias = false,
|
bool qkv_bias = false,
|
||||||
bool pre_only = false,
|
bool pre_only = false,
|
||||||
bool self_attn = false,
|
bool self_attn = false)
|
||||||
bool flash_attn = false)
|
|
||||||
: num_heads(num_heads), pre_only(pre_only), self_attn(self_attn) {
|
: num_heads(num_heads), pre_only(pre_only), self_attn(self_attn) {
|
||||||
// rmsnorm is always False
|
// rmsnorm is always False
|
||||||
// scale_mod_only is always False
|
// scale_mod_only is always False
|
||||||
// swiglu is always False
|
// swiglu is always False
|
||||||
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
|
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
|
||||||
blocks["attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only, flash_attn));
|
blocks["attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only));
|
||||||
|
|
||||||
if (self_attn) {
|
if (self_attn) {
|
||||||
blocks["attn2"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false, flash_attn));
|
blocks["attn2"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!pre_only) {
|
if (!pre_only) {
|
||||||
@ -274,9 +269,9 @@ public:
|
|||||||
blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, n_mods * hidden_size));
|
blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, n_mods * hidden_size));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>> pre_attention_x(struct ggml_context* ctx,
|
std::tuple<std::vector<ggml_tensor*>, std::vector<ggml_tensor*>, std::vector<ggml_tensor*>> pre_attention_x(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* c) {
|
struct ggml_tensor* c) {
|
||||||
GGML_ASSERT(self_attn);
|
GGML_ASSERT(self_attn);
|
||||||
// x: [N, n_token, hidden_size]
|
// x: [N, n_token, hidden_size]
|
||||||
// c: [N, hidden_size]
|
// c: [N, hidden_size]
|
||||||
@ -286,35 +281,35 @@ public:
|
|||||||
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
|
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
|
||||||
|
|
||||||
int64_t n_mods = 9;
|
int64_t n_mods = 9;
|
||||||
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, n_mods * hidden_size]
|
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size]
|
||||||
m = ggml_reshape_3d(ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size]
|
m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size]
|
||||||
m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size]
|
m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size]
|
||||||
|
|
||||||
int64_t offset = m->nb[1] * m->ne[1];
|
int64_t offset = m->nb[1] * m->ne[1];
|
||||||
auto shift_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
|
auto shift_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
|
||||||
auto scale_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
|
auto scale_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
|
||||||
auto gate_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size]
|
auto gate_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size]
|
||||||
|
|
||||||
auto shift_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size]
|
auto shift_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size]
|
||||||
auto scale_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size]
|
auto scale_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size]
|
||||||
auto gate_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size]
|
auto gate_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size]
|
||||||
|
|
||||||
auto shift_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 6); // [N, hidden_size]
|
auto shift_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 6); // [N, hidden_size]
|
||||||
auto scale_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 7); // [N, hidden_size]
|
auto scale_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 7); // [N, hidden_size]
|
||||||
auto gate_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 8); // [N, hidden_size]
|
auto gate_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 8); // [N, hidden_size]
|
||||||
|
|
||||||
auto x_norm = norm1->forward(ctx, x);
|
auto x_norm = norm1->forward(ctx, x);
|
||||||
|
|
||||||
auto attn_in = modulate(ctx, x_norm, shift_msa, scale_msa);
|
auto attn_in = modulate(ctx->ggml_ctx, x_norm, shift_msa, scale_msa);
|
||||||
auto qkv = attn->pre_attention(ctx, attn_in);
|
auto qkv = attn->pre_attention(ctx, attn_in);
|
||||||
|
|
||||||
auto attn2_in = modulate(ctx, x_norm, shift_msa2, scale_msa2);
|
auto attn2_in = modulate(ctx->ggml_ctx, x_norm, shift_msa2, scale_msa2);
|
||||||
auto qkv2 = attn2->pre_attention(ctx, attn2_in);
|
auto qkv2 = attn2->pre_attention(ctx, attn2_in);
|
||||||
|
|
||||||
return {qkv, qkv2, {x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2}};
|
return {qkv, qkv2, {x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2}};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>> pre_attention(struct ggml_context* ctx,
|
std::pair<std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>> pre_attention(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* c) {
|
struct ggml_tensor* c) {
|
||||||
// x: [N, n_token, hidden_size]
|
// x: [N, n_token, hidden_size]
|
||||||
@ -327,33 +322,33 @@ public:
|
|||||||
if (pre_only) {
|
if (pre_only) {
|
||||||
n_mods = 2;
|
n_mods = 2;
|
||||||
}
|
}
|
||||||
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, n_mods * hidden_size]
|
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size]
|
||||||
m = ggml_reshape_3d(ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size]
|
m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size]
|
||||||
m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size]
|
m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size]
|
||||||
|
|
||||||
int64_t offset = m->nb[1] * m->ne[1];
|
int64_t offset = m->nb[1] * m->ne[1];
|
||||||
auto shift_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
|
auto shift_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
|
||||||
auto scale_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
|
auto scale_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
|
||||||
if (!pre_only) {
|
if (!pre_only) {
|
||||||
auto gate_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size]
|
auto gate_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size]
|
||||||
auto shift_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size]
|
auto shift_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size]
|
||||||
auto scale_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size]
|
auto scale_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size]
|
||||||
auto gate_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size]
|
auto gate_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size]
|
||||||
|
|
||||||
auto attn_in = modulate(ctx, norm1->forward(ctx, x), shift_msa, scale_msa);
|
auto attn_in = modulate(ctx->ggml_ctx, norm1->forward(ctx, x), shift_msa, scale_msa);
|
||||||
|
|
||||||
auto qkv = attn->pre_attention(ctx, attn_in);
|
auto qkv = attn->pre_attention(ctx, attn_in);
|
||||||
|
|
||||||
return {qkv, {x, gate_msa, shift_mlp, scale_mlp, gate_mlp}};
|
return {qkv, {x, gate_msa, shift_mlp, scale_mlp, gate_mlp}};
|
||||||
} else {
|
} else {
|
||||||
auto attn_in = modulate(ctx, norm1->forward(ctx, x), shift_msa, scale_msa);
|
auto attn_in = modulate(ctx->ggml_ctx, norm1->forward(ctx, x), shift_msa, scale_msa);
|
||||||
auto qkv = attn->pre_attention(ctx, attn_in);
|
auto qkv = attn->pre_attention(ctx, attn_in);
|
||||||
|
|
||||||
return {qkv, {nullptr, nullptr, nullptr, nullptr, nullptr}};
|
return {qkv, {nullptr, nullptr, nullptr, nullptr, nullptr}};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* post_attention_x(struct ggml_context* ctx,
|
struct ggml_tensor* post_attention_x(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* attn_out,
|
struct ggml_tensor* attn_out,
|
||||||
struct ggml_tensor* attn2_out,
|
struct ggml_tensor* attn2_out,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
@ -376,22 +371,22 @@ public:
|
|||||||
auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
|
auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
|
||||||
auto mlp = std::dynamic_pointer_cast<Mlp>(blocks["mlp"]);
|
auto mlp = std::dynamic_pointer_cast<Mlp>(blocks["mlp"]);
|
||||||
|
|
||||||
gate_msa = ggml_reshape_3d(ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]); // [N, 1, hidden_size]
|
gate_msa = ggml_reshape_3d(ctx->ggml_ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]); // [N, 1, hidden_size]
|
||||||
gate_mlp = ggml_reshape_3d(ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]); // [N, 1, hidden_size]
|
gate_mlp = ggml_reshape_3d(ctx->ggml_ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]); // [N, 1, hidden_size]
|
||||||
gate_msa2 = ggml_reshape_3d(ctx, gate_msa2, gate_msa2->ne[0], 1, gate_msa2->ne[1]); // [N, 1, hidden_size]
|
gate_msa2 = ggml_reshape_3d(ctx->ggml_ctx, gate_msa2, gate_msa2->ne[0], 1, gate_msa2->ne[1]); // [N, 1, hidden_size]
|
||||||
|
|
||||||
attn_out = attn->post_attention(ctx, attn_out);
|
attn_out = attn->post_attention(ctx, attn_out);
|
||||||
attn2_out = attn2->post_attention(ctx, attn2_out);
|
attn2_out = attn2->post_attention(ctx, attn2_out);
|
||||||
|
|
||||||
x = ggml_add(ctx, x, ggml_mul(ctx, attn_out, gate_msa));
|
x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, attn_out, gate_msa));
|
||||||
x = ggml_add(ctx, x, ggml_mul(ctx, attn2_out, gate_msa2));
|
x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, attn2_out, gate_msa2));
|
||||||
auto mlp_out = mlp->forward(ctx, modulate(ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp));
|
auto mlp_out = mlp->forward(ctx, modulate(ctx->ggml_ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp));
|
||||||
x = ggml_add(ctx, x, ggml_mul(ctx, mlp_out, gate_mlp));
|
x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, mlp_out, gate_mlp));
|
||||||
|
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* post_attention(struct ggml_context* ctx,
|
struct ggml_tensor* post_attention(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* attn_out,
|
struct ggml_tensor* attn_out,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* gate_msa,
|
struct ggml_tensor* gate_msa,
|
||||||
@ -411,20 +406,19 @@ public:
|
|||||||
auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
|
auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
|
||||||
auto mlp = std::dynamic_pointer_cast<Mlp>(blocks["mlp"]);
|
auto mlp = std::dynamic_pointer_cast<Mlp>(blocks["mlp"]);
|
||||||
|
|
||||||
gate_msa = ggml_reshape_3d(ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]); // [N, 1, hidden_size]
|
gate_msa = ggml_reshape_3d(ctx->ggml_ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]); // [N, 1, hidden_size]
|
||||||
gate_mlp = ggml_reshape_3d(ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]); // [N, 1, hidden_size]
|
gate_mlp = ggml_reshape_3d(ctx->ggml_ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]); // [N, 1, hidden_size]
|
||||||
|
|
||||||
attn_out = attn->post_attention(ctx, attn_out);
|
attn_out = attn->post_attention(ctx, attn_out);
|
||||||
|
|
||||||
x = ggml_add(ctx, x, ggml_mul(ctx, attn_out, gate_msa));
|
x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, attn_out, gate_msa));
|
||||||
auto mlp_out = mlp->forward(ctx, modulate(ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp));
|
auto mlp_out = mlp->forward(ctx, modulate(ctx->ggml_ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp));
|
||||||
x = ggml_add(ctx, x, ggml_mul(ctx, mlp_out, gate_mlp));
|
x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, mlp_out, gate_mlp));
|
||||||
|
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* c) {
|
struct ggml_tensor* c) {
|
||||||
// x: [N, n_token, hidden_size]
|
// x: [N, n_token, hidden_size]
|
||||||
@ -441,8 +435,8 @@ public:
|
|||||||
auto qkv2 = std::get<1>(qkv_intermediates);
|
auto qkv2 = std::get<1>(qkv_intermediates);
|
||||||
auto intermediates = std::get<2>(qkv_intermediates);
|
auto intermediates = std::get<2>(qkv_intermediates);
|
||||||
|
|
||||||
auto attn_out = ggml_ext_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim]
|
auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
|
||||||
auto attn2_out = ggml_ext_attention_ext(ctx, backend, qkv2[0], qkv2[1], qkv2[2], num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim]
|
auto attn2_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv2[0], qkv2[1], qkv2[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
|
||||||
x = post_attention_x(ctx,
|
x = post_attention_x(ctx,
|
||||||
attn_out,
|
attn_out,
|
||||||
attn2_out,
|
attn2_out,
|
||||||
@ -458,7 +452,7 @@ public:
|
|||||||
auto qkv = qkv_intermediates.first;
|
auto qkv = qkv_intermediates.first;
|
||||||
auto intermediates = qkv_intermediates.second;
|
auto intermediates = qkv_intermediates.second;
|
||||||
|
|
||||||
auto attn_out = ggml_ext_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim]
|
auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
|
||||||
x = post_attention(ctx,
|
x = post_attention(ctx,
|
||||||
attn_out,
|
attn_out,
|
||||||
intermediates[0],
|
intermediates[0],
|
||||||
@ -472,9 +466,7 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
__STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*>
|
__STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*>
|
||||||
block_mixing(struct ggml_context* ctx,
|
block_mixing(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
bool flash_attn,
|
|
||||||
struct ggml_tensor* context,
|
struct ggml_tensor* context,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* c,
|
struct ggml_tensor* c,
|
||||||
@ -501,29 +493,29 @@ block_mixing(struct ggml_context* ctx,
|
|||||||
}
|
}
|
||||||
std::vector<struct ggml_tensor*> qkv;
|
std::vector<struct ggml_tensor*> qkv;
|
||||||
for (int i = 0; i < 3; i++) {
|
for (int i = 0; i < 3; i++) {
|
||||||
qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1));
|
qkv.push_back(ggml_concat(ctx->ggml_ctx, context_qkv[i], x_qkv[i], 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto attn = ggml_ext_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, false, flash_attn); // [N, n_context + n_token, hidden_size]
|
auto attn = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_context + n_token, hidden_size]
|
||||||
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_context + n_token, N, hidden_size]
|
attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_context + n_token, N, hidden_size]
|
||||||
auto context_attn = ggml_view_3d(ctx,
|
auto context_attn = ggml_view_3d(ctx->ggml_ctx,
|
||||||
attn,
|
attn,
|
||||||
attn->ne[0],
|
attn->ne[0],
|
||||||
attn->ne[1],
|
attn->ne[1],
|
||||||
context->ne[1],
|
context->ne[1],
|
||||||
attn->nb[1],
|
attn->nb[1],
|
||||||
attn->nb[2],
|
attn->nb[2],
|
||||||
0); // [n_context, N, hidden_size]
|
0); // [n_context, N, hidden_size]
|
||||||
context_attn = ggml_cont(ctx, ggml_permute(ctx, context_attn, 0, 2, 1, 3)); // [N, n_context, hidden_size]
|
context_attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, context_attn, 0, 2, 1, 3)); // [N, n_context, hidden_size]
|
||||||
auto x_attn = ggml_view_3d(ctx,
|
auto x_attn = ggml_view_3d(ctx->ggml_ctx,
|
||||||
attn,
|
attn,
|
||||||
attn->ne[0],
|
attn->ne[0],
|
||||||
attn->ne[1],
|
attn->ne[1],
|
||||||
x->ne[1],
|
x->ne[1],
|
||||||
attn->nb[1],
|
attn->nb[1],
|
||||||
attn->nb[2],
|
attn->nb[2],
|
||||||
attn->nb[2] * context->ne[1]); // [n_token, N, hidden_size]
|
attn->nb[2] * context->ne[1]); // [n_token, N, hidden_size]
|
||||||
x_attn = ggml_cont(ctx, ggml_permute(ctx, x_attn, 0, 2, 1, 3)); // [N, n_token, hidden_size]
|
x_attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x_attn, 0, 2, 1, 3)); // [N, n_token, hidden_size]
|
||||||
|
|
||||||
if (!context_block->pre_only) {
|
if (!context_block->pre_only) {
|
||||||
context = context_block->post_attention(ctx,
|
context = context_block->post_attention(ctx,
|
||||||
@ -538,7 +530,7 @@ block_mixing(struct ggml_context* ctx,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (x_block->self_attn) {
|
if (x_block->self_attn) {
|
||||||
auto attn2 = ggml_ext_attention_ext(ctx, backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads); // [N, n_token, hidden_size]
|
auto attn2 = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, hidden_size]
|
||||||
|
|
||||||
x = x_block->post_attention_x(ctx,
|
x = x_block->post_attention_x(ctx,
|
||||||
x_attn,
|
x_attn,
|
||||||
@ -563,8 +555,6 @@ block_mixing(struct ggml_context* ctx,
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct JointBlock : public GGMLBlock {
|
struct JointBlock : public GGMLBlock {
|
||||||
bool flash_attn;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
JointBlock(int64_t hidden_size,
|
JointBlock(int64_t hidden_size,
|
||||||
int64_t num_heads,
|
int64_t num_heads,
|
||||||
@ -572,22 +562,19 @@ public:
|
|||||||
std::string qk_norm = "",
|
std::string qk_norm = "",
|
||||||
bool qkv_bias = false,
|
bool qkv_bias = false,
|
||||||
bool pre_only = false,
|
bool pre_only = false,
|
||||||
bool self_attn_x = false,
|
bool self_attn_x = false) {
|
||||||
bool flash_attn = false)
|
blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only, false));
|
||||||
: flash_attn(flash_attn) {
|
blocks["x_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x));
|
||||||
blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only, false, flash_attn));
|
|
||||||
blocks["x_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x, flash_attn));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
|
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* context,
|
struct ggml_tensor* context,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* c) {
|
struct ggml_tensor* c) {
|
||||||
auto context_block = std::dynamic_pointer_cast<DismantledBlock>(blocks["context_block"]);
|
auto context_block = std::dynamic_pointer_cast<DismantledBlock>(blocks["context_block"]);
|
||||||
auto x_block = std::dynamic_pointer_cast<DismantledBlock>(blocks["x_block"]);
|
auto x_block = std::dynamic_pointer_cast<DismantledBlock>(blocks["x_block"]);
|
||||||
|
|
||||||
return block_mixing(ctx, backend, flash_attn, context, x, c, context_block, x_block);
|
return block_mixing(ctx, context, x, c, context_block, x_block);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -603,7 +590,7 @@ public:
|
|||||||
blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, 2 * hidden_size));
|
blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, 2 * hidden_size));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* c) {
|
struct ggml_tensor* c) {
|
||||||
// x: [N, n_token, hidden_size]
|
// x: [N, n_token, hidden_size]
|
||||||
@ -613,15 +600,15 @@ public:
|
|||||||
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
|
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
|
||||||
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
|
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
|
||||||
|
|
||||||
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, 2 * hidden_size]
|
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 2 * hidden_size]
|
||||||
m = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size]
|
m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size]
|
||||||
m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size]
|
m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size]
|
||||||
|
|
||||||
int64_t offset = m->nb[1] * m->ne[1];
|
int64_t offset = m->nb[1] * m->ne[1];
|
||||||
auto shift = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
|
auto shift = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
|
||||||
auto scale = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
|
auto scale = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
|
||||||
|
|
||||||
x = modulate(ctx, norm_final->forward(ctx, x), shift, scale);
|
x = modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale);
|
||||||
x = linear->forward(ctx, x);
|
x = linear->forward(ctx, x);
|
||||||
|
|
||||||
return x;
|
return x;
|
||||||
@ -645,7 +632,6 @@ protected:
|
|||||||
int64_t context_embedder_out_dim = 1536;
|
int64_t context_embedder_out_dim = 1536;
|
||||||
int64_t hidden_size;
|
int64_t hidden_size;
|
||||||
std::string qk_norm;
|
std::string qk_norm;
|
||||||
bool flash_attn = false;
|
|
||||||
|
|
||||||
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") override {
|
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") override {
|
||||||
enum ggml_type wtype = GGML_TYPE_F32;
|
enum ggml_type wtype = GGML_TYPE_F32;
|
||||||
@ -653,8 +639,7 @@ protected:
|
|||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
MMDiT(bool flash_attn = false, const String2GGMLType& tensor_types = {})
|
MMDiT(const String2GGMLType& tensor_types = {}) {
|
||||||
: flash_attn(flash_attn) {
|
|
||||||
// input_size is always None
|
// input_size is always None
|
||||||
// learn_sigma is always False
|
// learn_sigma is always False
|
||||||
// register_length is always 0
|
// register_length is always 0
|
||||||
@ -722,8 +707,7 @@ public:
|
|||||||
qk_norm,
|
qk_norm,
|
||||||
true,
|
true,
|
||||||
i == depth - 1,
|
i == depth - 1,
|
||||||
i <= d_self,
|
i <= d_self));
|
||||||
flash_attn));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new FinalLayer(hidden_size, patch_size, out_channels));
|
blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new FinalLayer(hidden_size, patch_size, out_channels));
|
||||||
@ -791,8 +775,7 @@ public:
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward_core_with_concat(struct ggml_context* ctx,
|
struct ggml_tensor* forward_core_with_concat(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* c_mod,
|
struct ggml_tensor* c_mod,
|
||||||
struct ggml_tensor* context,
|
struct ggml_tensor* context,
|
||||||
@ -811,7 +794,7 @@ public:
|
|||||||
|
|
||||||
auto block = std::dynamic_pointer_cast<JointBlock>(blocks["joint_blocks." + std::to_string(i)]);
|
auto block = std::dynamic_pointer_cast<JointBlock>(blocks["joint_blocks." + std::to_string(i)]);
|
||||||
|
|
||||||
auto context_x = block->forward(ctx, backend, context, x, c_mod);
|
auto context_x = block->forward(ctx, context, x, c_mod);
|
||||||
context = context_x.first;
|
context = context_x.first;
|
||||||
x = context_x.second;
|
x = context_x.second;
|
||||||
}
|
}
|
||||||
@ -821,8 +804,7 @@ public:
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* t,
|
struct ggml_tensor* t,
|
||||||
struct ggml_tensor* y = nullptr,
|
struct ggml_tensor* y = nullptr,
|
||||||
@ -840,16 +822,16 @@ public:
|
|||||||
int64_t w = x->ne[0];
|
int64_t w = x->ne[0];
|
||||||
int64_t h = x->ne[1];
|
int64_t h = x->ne[1];
|
||||||
|
|
||||||
auto patch_embed = x_embedder->forward(ctx, x); // [N, H*W, hidden_size]
|
auto patch_embed = x_embedder->forward(ctx, x); // [N, H*W, hidden_size]
|
||||||
auto pos_embed = cropped_pos_embed(ctx, h, w); // [1, H*W, hidden_size]
|
auto pos_embed = cropped_pos_embed(ctx->ggml_ctx, h, w); // [1, H*W, hidden_size]
|
||||||
x = ggml_add(ctx, patch_embed, pos_embed); // [N, H*W, hidden_size]
|
x = ggml_add(ctx->ggml_ctx, patch_embed, pos_embed); // [N, H*W, hidden_size]
|
||||||
|
|
||||||
auto c = t_embedder->forward(ctx, t); // [N, hidden_size]
|
auto c = t_embedder->forward(ctx, t); // [N, hidden_size]
|
||||||
if (y != nullptr && adm_in_channels != -1) {
|
if (y != nullptr && adm_in_channels != -1) {
|
||||||
auto y_embedder = std::dynamic_pointer_cast<VectorEmbedder>(blocks["y_embedder"]);
|
auto y_embedder = std::dynamic_pointer_cast<VectorEmbedder>(blocks["y_embedder"]);
|
||||||
|
|
||||||
y = y_embedder->forward(ctx, y); // [N, hidden_size]
|
y = y_embedder->forward(ctx, y); // [N, hidden_size]
|
||||||
c = ggml_add(ctx, c, y);
|
c = ggml_add(ctx->ggml_ctx, c, y);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (context != nullptr) {
|
if (context != nullptr) {
|
||||||
@ -858,9 +840,9 @@ public:
|
|||||||
context = context_embedder->forward(ctx, context); // [N, L, D] aka [N, L, 1536]
|
context = context_embedder->forward(ctx, context); // [N, L, D] aka [N, L, 1536]
|
||||||
}
|
}
|
||||||
|
|
||||||
x = forward_core_with_concat(ctx, backend, x, c, context, skip_layers); // (N, H*W, patch_size ** 2 * out_channels)
|
x = forward_core_with_concat(ctx, x, c, context, skip_layers); // (N, H*W, patch_size ** 2 * out_channels)
|
||||||
|
|
||||||
x = unpatchify(ctx, x, h, w); // [N, C, H, W]
|
x = unpatchify(ctx->ggml_ctx, x, h, w); // [N, C, H, W]
|
||||||
|
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
@ -870,10 +852,9 @@ struct MMDiTRunner : public GGMLRunner {
|
|||||||
|
|
||||||
MMDiTRunner(ggml_backend_t backend,
|
MMDiTRunner(ggml_backend_t backend,
|
||||||
bool offload_params_to_cpu,
|
bool offload_params_to_cpu,
|
||||||
bool flash_attn,
|
|
||||||
const String2GGMLType& tensor_types = {},
|
const String2GGMLType& tensor_types = {},
|
||||||
const std::string prefix = "")
|
const std::string prefix = "")
|
||||||
: GGMLRunner(backend, offload_params_to_cpu), mmdit(flash_attn, tensor_types) {
|
: GGMLRunner(backend, offload_params_to_cpu), mmdit(tensor_types) {
|
||||||
mmdit.init(params_ctx, tensor_types, prefix);
|
mmdit.init(params_ctx, tensor_types, prefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -897,8 +878,8 @@ struct MMDiTRunner : public GGMLRunner {
|
|||||||
y = to_backend(y);
|
y = to_backend(y);
|
||||||
timesteps = to_backend(timesteps);
|
timesteps = to_backend(timesteps);
|
||||||
|
|
||||||
struct ggml_tensor* out = mmdit.forward(compute_ctx,
|
auto runner_ctx = get_context();
|
||||||
runtime_backend,
|
struct ggml_tensor* out = mmdit.forward(&runner_ctx,
|
||||||
x,
|
x,
|
||||||
timesteps,
|
timesteps,
|
||||||
y,
|
y,
|
||||||
@ -972,7 +953,7 @@ struct MMDiTRunner : public GGMLRunner {
|
|||||||
// ggml_backend_t backend = ggml_backend_cuda_init(0);
|
// ggml_backend_t backend = ggml_backend_cuda_init(0);
|
||||||
ggml_backend_t backend = ggml_backend_cpu_init();
|
ggml_backend_t backend = ggml_backend_cpu_init();
|
||||||
ggml_type model_data_type = GGML_TYPE_F16;
|
ggml_type model_data_type = GGML_TYPE_F16;
|
||||||
std::shared_ptr<MMDiTRunner> mmdit = std::make_shared<MMDiTRunner>(backend, false, false);
|
std::shared_ptr<MMDiTRunner> mmdit = std::make_shared<MMDiTRunner>(backend, false);
|
||||||
{
|
{
|
||||||
LOG_INFO("loading from '%s'", file_path.c_str());
|
LOG_INFO("loading from '%s'", file_path.c_str());
|
||||||
|
|
||||||
|
|||||||
124
pmid.hpp
124
pmid.hpp
@ -21,7 +21,7 @@ public:
|
|||||||
blocks["layernorm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(in_dim));
|
blocks["layernorm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(in_dim));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
// x: [N, channels, h, w]
|
// x: [N, channels, h, w]
|
||||||
|
|
||||||
auto fc1 = std::dynamic_pointer_cast<Linear>(blocks["fc1"]);
|
auto fc1 = std::dynamic_pointer_cast<Linear>(blocks["fc1"]);
|
||||||
@ -33,11 +33,11 @@ public:
|
|||||||
x = layer_norm->forward(ctx, x);
|
x = layer_norm->forward(ctx, x);
|
||||||
// x = ggml_add(ctx, ggml_mul_mat(ctx, fc1_w, x), fc1_b);
|
// x = ggml_add(ctx, ggml_mul_mat(ctx, fc1_w, x), fc1_b);
|
||||||
x = fc1->forward(ctx, x);
|
x = fc1->forward(ctx, x);
|
||||||
x = ggml_gelu_inplace(ctx, x);
|
x = ggml_gelu_inplace(ctx->ggml_ctx, x);
|
||||||
x = fc2->forward(ctx, x);
|
x = fc2->forward(ctx, x);
|
||||||
// x = ggml_add(ctx, ggml_mul_mat(ctx, fc2_w, x), fc2_b);
|
// x = ggml_add(ctx, ggml_mul_mat(ctx, fc2_w, x), fc2_b);
|
||||||
if (use_residue)
|
if (use_residue)
|
||||||
x = ggml_add(ctx, x, r);
|
x = ggml_add(ctx->ggml_ctx, x, r);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -54,7 +54,7 @@ public:
|
|||||||
blocks["1"] = std::shared_ptr<GGMLBlock>(new Mlp(dim, inner_dim, dim, false));
|
blocks["1"] = std::shared_ptr<GGMLBlock>(new Mlp(dim, inner_dim, dim, false));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x) {
|
struct ggml_tensor* x) {
|
||||||
auto norm = std::dynamic_pointer_cast<LayerNorm>(blocks["0"]);
|
auto norm = std::dynamic_pointer_cast<LayerNorm>(blocks["0"]);
|
||||||
auto ff = std::dynamic_pointer_cast<Mlp>(blocks["1"]);
|
auto ff = std::dynamic_pointer_cast<Mlp>(blocks["1"]);
|
||||||
@ -100,7 +100,7 @@ public:
|
|||||||
ggml_cont(ctx, tli)};
|
ggml_cont(ctx, tli)};
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* latents) {
|
struct ggml_tensor* latents) {
|
||||||
// x (torch.Tensor): image features
|
// x (torch.Tensor): image features
|
||||||
@ -118,33 +118,33 @@ public:
|
|||||||
auto to_q = std::dynamic_pointer_cast<Linear>(blocks["to_q"]);
|
auto to_q = std::dynamic_pointer_cast<Linear>(blocks["to_q"]);
|
||||||
auto q = to_q->forward(ctx, latents);
|
auto q = to_q->forward(ctx, latents);
|
||||||
|
|
||||||
auto kv_input = ggml_concat(ctx, x, latents, 1);
|
auto kv_input = ggml_concat(ctx->ggml_ctx, x, latents, 1);
|
||||||
auto to_kv = std::dynamic_pointer_cast<Linear>(blocks["to_kv"]);
|
auto to_kv = std::dynamic_pointer_cast<Linear>(blocks["to_kv"]);
|
||||||
auto kv = to_kv->forward(ctx, kv_input);
|
auto kv = to_kv->forward(ctx, kv_input);
|
||||||
auto k = ggml_view_4d(ctx, kv, kv->ne[0] / 2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1] / 2, kv->nb[2] / 2, kv->nb[3] / 2, 0);
|
auto k = ggml_view_4d(ctx->ggml_ctx, kv, kv->ne[0] / 2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1] / 2, kv->nb[2] / 2, kv->nb[3] / 2, 0);
|
||||||
auto v = ggml_view_4d(ctx, kv, kv->ne[0] / 2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1] / 2, kv->nb[2] / 2, kv->nb[3] / 2, kv->nb[0] * (kv->ne[0] / 2));
|
auto v = ggml_view_4d(ctx->ggml_ctx, kv, kv->ne[0] / 2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1] / 2, kv->nb[2] / 2, kv->nb[3] / 2, kv->nb[0] * (kv->ne[0] / 2));
|
||||||
k = ggml_cont(ctx, k);
|
k = ggml_cont(ctx->ggml_ctx, k);
|
||||||
v = ggml_cont(ctx, v);
|
v = ggml_cont(ctx->ggml_ctx, v);
|
||||||
q = reshape_tensor(ctx, q, heads);
|
q = reshape_tensor(ctx->ggml_ctx, q, heads);
|
||||||
k = reshape_tensor(ctx, k, heads);
|
k = reshape_tensor(ctx->ggml_ctx, k, heads);
|
||||||
v = reshape_tensor(ctx, v, heads);
|
v = reshape_tensor(ctx->ggml_ctx, v, heads);
|
||||||
scale = 1.f / sqrt(sqrt((float)dim_head));
|
scale = 1.f / sqrt(sqrt((float)dim_head));
|
||||||
k = ggml_scale_inplace(ctx, k, scale);
|
k = ggml_scale_inplace(ctx->ggml_ctx, k, scale);
|
||||||
q = ggml_scale_inplace(ctx, q, scale);
|
q = ggml_scale_inplace(ctx->ggml_ctx, q, scale);
|
||||||
// auto weight = ggml_mul_mat(ctx, q, k);
|
// auto weight = ggml_mul_mat(ctx, q, k);
|
||||||
auto weight = ggml_mul_mat(ctx, k, q); // NOTE order of mul is opposite to pytorch
|
auto weight = ggml_mul_mat(ctx->ggml_ctx, k, q); // NOTE order of mul is opposite to pytorch
|
||||||
|
|
||||||
// GGML's softmax() is equivalent to pytorch's softmax(x, dim=-1)
|
// GGML's softmax() is equivalent to pytorch's softmax(x, dim=-1)
|
||||||
// in this case, dimension along which Softmax will be computed is the last dim
|
// in this case, dimension along which Softmax will be computed is the last dim
|
||||||
// in torch and the first dim in GGML, consistent with the convention that pytorch's
|
// in torch and the first dim in GGML, consistent with the convention that pytorch's
|
||||||
// last dimension (varying most rapidly) corresponds to GGML's first (varying most rapidly).
|
// last dimension (varying most rapidly) corresponds to GGML's first (varying most rapidly).
|
||||||
// weight = ggml_soft_max(ctx, weight);
|
// weight = ggml_soft_max(ctx, weight);
|
||||||
weight = ggml_soft_max_inplace(ctx, weight);
|
weight = ggml_soft_max_inplace(ctx->ggml_ctx, weight);
|
||||||
v = ggml_cont(ctx, ggml_transpose(ctx, v));
|
v = ggml_cont(ctx->ggml_ctx, ggml_transpose(ctx->ggml_ctx, v));
|
||||||
// auto out = ggml_mul_mat(ctx, weight, v);
|
// auto out = ggml_mul_mat(ctx, weight, v);
|
||||||
auto out = ggml_mul_mat(ctx, v, weight); // NOTE order of mul is opposite to pytorch
|
auto out = ggml_mul_mat(ctx->ggml_ctx, v, weight); // NOTE order of mul is opposite to pytorch
|
||||||
out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3));
|
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3));
|
||||||
out = ggml_reshape_3d(ctx, out, ne[0], ne[1], ggml_nelements(out) / (ne[0] * ne[1]));
|
out = ggml_reshape_3d(ctx->ggml_ctx, out, ne[0], ne[1], ggml_nelements(out) / (ne[0] * ne[1]));
|
||||||
auto to_out = std::dynamic_pointer_cast<Linear>(blocks["to_out"]);
|
auto to_out = std::dynamic_pointer_cast<Linear>(blocks["to_out"]);
|
||||||
out = to_out->forward(ctx, out);
|
out = to_out->forward(ctx, out);
|
||||||
return out;
|
return out;
|
||||||
@ -176,7 +176,7 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* latents,
|
struct ggml_tensor* latents,
|
||||||
struct ggml_tensor* x) {
|
struct ggml_tensor* x) {
|
||||||
// x: [N, channels, h, w]
|
// x: [N, channels, h, w]
|
||||||
@ -191,9 +191,9 @@ public:
|
|||||||
name = "layers." + std::to_string(i) + ".1";
|
name = "layers." + std::to_string(i) + ".1";
|
||||||
auto ff = std::dynamic_pointer_cast<PMFeedForward>(blocks[name]);
|
auto ff = std::dynamic_pointer_cast<PMFeedForward>(blocks[name]);
|
||||||
auto t = attn->forward(ctx, x, latents);
|
auto t = attn->forward(ctx, x, latents);
|
||||||
latents = ggml_add(ctx, t, latents);
|
latents = ggml_add(ctx->ggml_ctx, t, latents);
|
||||||
t = ff->forward(ctx, latents);
|
t = ff->forward(ctx, latents);
|
||||||
latents = ggml_add(ctx, t, latents);
|
latents = ggml_add(ctx->ggml_ctx, t, latents);
|
||||||
}
|
}
|
||||||
latents = proj_out->forward(ctx, latents);
|
latents = proj_out->forward(ctx, latents);
|
||||||
latents = norm_out->forward(ctx, latents);
|
latents = norm_out->forward(ctx, latents);
|
||||||
@ -225,7 +225,7 @@ public:
|
|||||||
4));
|
4));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* last_hidden_state) {
|
struct ggml_tensor* last_hidden_state) {
|
||||||
// x: [N, channels, h, w]
|
// x: [N, channels, h, w]
|
||||||
@ -235,11 +235,11 @@ public:
|
|||||||
|
|
||||||
x = token_proj->forward(ctx, x);
|
x = token_proj->forward(ctx, x);
|
||||||
int64_t nel = ggml_nelements(x);
|
int64_t nel = ggml_nelements(x);
|
||||||
x = ggml_reshape_3d(ctx, x, cross_attention_dim, num_tokens, nel / (cross_attention_dim * num_tokens));
|
x = ggml_reshape_3d(ctx->ggml_ctx, x, cross_attention_dim, num_tokens, nel / (cross_attention_dim * num_tokens));
|
||||||
x = token_norm->forward(ctx, x);
|
x = token_norm->forward(ctx, x);
|
||||||
struct ggml_tensor* out = perceiver_resampler->forward(ctx, x, last_hidden_state);
|
struct ggml_tensor* out = perceiver_resampler->forward(ctx, x, last_hidden_state);
|
||||||
if (use_residul)
|
if (use_residul)
|
||||||
out = ggml_add(ctx, x, out);
|
out = ggml_add(ctx->ggml_ctx, x, out);
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -256,24 +256,24 @@ public:
|
|||||||
blocks["layer_norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(embed_dim));
|
blocks["layer_norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(embed_dim));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* fuse_fn(struct ggml_context* ctx,
|
struct ggml_tensor* fuse_fn(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* prompt_embeds,
|
struct ggml_tensor* prompt_embeds,
|
||||||
struct ggml_tensor* id_embeds) {
|
struct ggml_tensor* id_embeds) {
|
||||||
auto mlp1 = std::dynamic_pointer_cast<FuseBlock>(blocks["mlp1"]);
|
auto mlp1 = std::dynamic_pointer_cast<FuseBlock>(blocks["mlp1"]);
|
||||||
auto mlp2 = std::dynamic_pointer_cast<FuseBlock>(blocks["mlp2"]);
|
auto mlp2 = std::dynamic_pointer_cast<FuseBlock>(blocks["mlp2"]);
|
||||||
auto layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm"]);
|
auto layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm"]);
|
||||||
|
|
||||||
auto stacked_id_embeds = ggml_concat(ctx, prompt_embeds, id_embeds, 0);
|
auto stacked_id_embeds = ggml_concat(ctx->ggml_ctx, prompt_embeds, id_embeds, 0);
|
||||||
|
|
||||||
stacked_id_embeds = mlp1->forward(ctx, stacked_id_embeds);
|
stacked_id_embeds = mlp1->forward(ctx, stacked_id_embeds);
|
||||||
stacked_id_embeds = ggml_add(ctx, stacked_id_embeds, prompt_embeds);
|
stacked_id_embeds = ggml_add(ctx->ggml_ctx, stacked_id_embeds, prompt_embeds);
|
||||||
stacked_id_embeds = mlp2->forward(ctx, stacked_id_embeds);
|
stacked_id_embeds = mlp2->forward(ctx, stacked_id_embeds);
|
||||||
stacked_id_embeds = layer_norm->forward(ctx, stacked_id_embeds);
|
stacked_id_embeds = layer_norm->forward(ctx, stacked_id_embeds);
|
||||||
|
|
||||||
return stacked_id_embeds;
|
return stacked_id_embeds;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* prompt_embeds,
|
struct ggml_tensor* prompt_embeds,
|
||||||
struct ggml_tensor* id_embeds,
|
struct ggml_tensor* id_embeds,
|
||||||
struct ggml_tensor* class_tokens_mask,
|
struct ggml_tensor* class_tokens_mask,
|
||||||
@ -286,25 +286,25 @@ public:
|
|||||||
// # slice out the image token embeddings
|
// # slice out the image token embeddings
|
||||||
ggml_set_name(class_tokens_mask_pos, "class_tokens_mask_pos");
|
ggml_set_name(class_tokens_mask_pos, "class_tokens_mask_pos");
|
||||||
ggml_set_name(prompt_embeds, "prompt_embeds");
|
ggml_set_name(prompt_embeds, "prompt_embeds");
|
||||||
struct ggml_tensor* image_token_embeds = ggml_get_rows(ctx, prompt_embeds, class_tokens_mask_pos);
|
struct ggml_tensor* image_token_embeds = ggml_get_rows(ctx->ggml_ctx, prompt_embeds, class_tokens_mask_pos);
|
||||||
ggml_set_name(image_token_embeds, "image_token_embeds");
|
ggml_set_name(image_token_embeds, "image_token_embeds");
|
||||||
valid_id_embeds = ggml_reshape_2d(ctx, valid_id_embeds, valid_id_embeds->ne[0],
|
valid_id_embeds = ggml_reshape_2d(ctx->ggml_ctx, valid_id_embeds, valid_id_embeds->ne[0],
|
||||||
ggml_nelements(valid_id_embeds) / valid_id_embeds->ne[0]);
|
ggml_nelements(valid_id_embeds) / valid_id_embeds->ne[0]);
|
||||||
struct ggml_tensor* stacked_id_embeds = fuse_fn(ctx, image_token_embeds, valid_id_embeds);
|
struct ggml_tensor* stacked_id_embeds = fuse_fn(ctx, image_token_embeds, valid_id_embeds);
|
||||||
|
|
||||||
if (left && right) {
|
if (left && right) {
|
||||||
stacked_id_embeds = ggml_concat(ctx, left, stacked_id_embeds, 1);
|
stacked_id_embeds = ggml_concat(ctx->ggml_ctx, left, stacked_id_embeds, 1);
|
||||||
stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right, 1);
|
stacked_id_embeds = ggml_concat(ctx->ggml_ctx, stacked_id_embeds, right, 1);
|
||||||
} else if (left) {
|
} else if (left) {
|
||||||
stacked_id_embeds = ggml_concat(ctx, left, stacked_id_embeds, 1);
|
stacked_id_embeds = ggml_concat(ctx->ggml_ctx, left, stacked_id_embeds, 1);
|
||||||
} else if (right) {
|
} else if (right) {
|
||||||
stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right, 1);
|
stacked_id_embeds = ggml_concat(ctx->ggml_ctx, stacked_id_embeds, right, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
class_tokens_mask = ggml_cont(ctx, ggml_transpose(ctx, class_tokens_mask));
|
class_tokens_mask = ggml_cont(ctx->ggml_ctx, ggml_transpose(ctx->ggml_ctx, class_tokens_mask));
|
||||||
class_tokens_mask = ggml_repeat(ctx, class_tokens_mask, prompt_embeds);
|
class_tokens_mask = ggml_repeat(ctx->ggml_ctx, class_tokens_mask, prompt_embeds);
|
||||||
prompt_embeds = ggml_mul(ctx, prompt_embeds, class_tokens_mask);
|
prompt_embeds = ggml_mul(ctx->ggml_ctx, prompt_embeds, class_tokens_mask);
|
||||||
struct ggml_tensor* updated_prompt_embeds = ggml_add(ctx, prompt_embeds, stacked_id_embeds);
|
struct ggml_tensor* updated_prompt_embeds = ggml_add(ctx->ggml_ctx, prompt_embeds, stacked_id_embeds);
|
||||||
ggml_set_name(updated_prompt_embeds, "updated_prompt_embeds");
|
ggml_set_name(updated_prompt_embeds, "updated_prompt_embeds");
|
||||||
return updated_prompt_embeds;
|
return updated_prompt_embeds;
|
||||||
}
|
}
|
||||||
@ -317,8 +317,7 @@ struct PhotoMakerIDEncoderBlock : public CLIPVisionModelProjection {
|
|||||||
blocks["fuse_module"] = std::shared_ptr<GGMLBlock>(new FuseModule(2048));
|
blocks["fuse_module"] = std::shared_ptr<GGMLBlock>(new FuseModule(2048));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* id_pixel_values,
|
struct ggml_tensor* id_pixel_values,
|
||||||
struct ggml_tensor* prompt_embeds,
|
struct ggml_tensor* prompt_embeds,
|
||||||
struct ggml_tensor* class_tokens_mask,
|
struct ggml_tensor* class_tokens_mask,
|
||||||
@ -331,15 +330,15 @@ struct PhotoMakerIDEncoderBlock : public CLIPVisionModelProjection {
|
|||||||
auto visual_projection_2 = std::dynamic_pointer_cast<Linear>(blocks["visual_projection_2"]);
|
auto visual_projection_2 = std::dynamic_pointer_cast<Linear>(blocks["visual_projection_2"]);
|
||||||
auto fuse_module = std::dynamic_pointer_cast<FuseModule>(blocks["fuse_module"]);
|
auto fuse_module = std::dynamic_pointer_cast<FuseModule>(blocks["fuse_module"]);
|
||||||
|
|
||||||
struct ggml_tensor* shared_id_embeds = vision_model->forward(ctx, backend, id_pixel_values); // [N, hidden_size]
|
struct ggml_tensor* shared_id_embeds = vision_model->forward(ctx, id_pixel_values); // [N, hidden_size]
|
||||||
struct ggml_tensor* id_embeds = visual_projection->forward(ctx, shared_id_embeds); // [N, proj_dim(768)]
|
struct ggml_tensor* id_embeds = visual_projection->forward(ctx, shared_id_embeds); // [N, proj_dim(768)]
|
||||||
struct ggml_tensor* id_embeds_2 = visual_projection_2->forward(ctx, shared_id_embeds); // [N, 1280]
|
struct ggml_tensor* id_embeds_2 = visual_projection_2->forward(ctx, shared_id_embeds); // [N, 1280]
|
||||||
|
|
||||||
id_embeds = ggml_cont(ctx, ggml_permute(ctx, id_embeds, 2, 0, 1, 3));
|
id_embeds = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, id_embeds, 2, 0, 1, 3));
|
||||||
id_embeds_2 = ggml_cont(ctx, ggml_permute(ctx, id_embeds_2, 2, 0, 1, 3));
|
id_embeds_2 = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, id_embeds_2, 2, 0, 1, 3));
|
||||||
|
|
||||||
id_embeds = ggml_concat(ctx, id_embeds, id_embeds_2, 2); // [batch_size, seq_length, 1, 2048] check whether concat at dim 2 is right
|
id_embeds = ggml_concat(ctx->ggml_ctx, id_embeds, id_embeds_2, 2); // [batch_size, seq_length, 1, 2048] check whether concat at dim 2 is right
|
||||||
id_embeds = ggml_cont(ctx, ggml_permute(ctx, id_embeds, 1, 2, 0, 3));
|
id_embeds = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, id_embeds, 1, 2, 0, 3));
|
||||||
|
|
||||||
struct ggml_tensor* updated_prompt_embeds = fuse_module->forward(ctx,
|
struct ggml_tensor* updated_prompt_embeds = fuse_module->forward(ctx,
|
||||||
prompt_embeds,
|
prompt_embeds,
|
||||||
@ -366,8 +365,7 @@ struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionMo
|
|||||||
num_tokens));
|
num_tokens));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* id_pixel_values,
|
struct ggml_tensor* id_pixel_values,
|
||||||
struct ggml_tensor* prompt_embeds,
|
struct ggml_tensor* prompt_embeds,
|
||||||
struct ggml_tensor* class_tokens_mask,
|
struct ggml_tensor* class_tokens_mask,
|
||||||
@ -381,7 +379,7 @@ struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionMo
|
|||||||
auto qformer_perceiver = std::dynamic_pointer_cast<QFormerPerceiver>(blocks["qformer_perceiver"]);
|
auto qformer_perceiver = std::dynamic_pointer_cast<QFormerPerceiver>(blocks["qformer_perceiver"]);
|
||||||
|
|
||||||
// struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, id_pixel_values); // [N, hidden_size]
|
// struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, id_pixel_values); // [N, hidden_size]
|
||||||
struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, backend, id_pixel_values, false); // [N, hidden_size]
|
struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, id_pixel_values, false); // [N, hidden_size]
|
||||||
id_embeds = qformer_perceiver->forward(ctx, id_embeds, last_hidden_state);
|
id_embeds = qformer_perceiver->forward(ctx, id_embeds, last_hidden_state);
|
||||||
|
|
||||||
struct ggml_tensor* updated_prompt_embeds = fuse_module->forward(ctx,
|
struct ggml_tensor* updated_prompt_embeds = fuse_module->forward(ctx,
|
||||||
@ -458,7 +456,7 @@ public:
|
|||||||
zeros_right.clear();
|
zeros_right.clear();
|
||||||
zeros_right_16.clear();
|
zeros_right_16.clear();
|
||||||
|
|
||||||
ggml_context* ctx0 = compute_ctx;
|
auto runner_ctx = get_context();
|
||||||
|
|
||||||
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
|
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
|
||||||
|
|
||||||
@ -466,7 +464,7 @@ public:
|
|||||||
int64_t seq_length = prompt_embeds->ne[1];
|
int64_t seq_length = prompt_embeds->ne[1];
|
||||||
ggml_type type = GGML_TYPE_F32;
|
ggml_type type = GGML_TYPE_F32;
|
||||||
|
|
||||||
struct ggml_tensor* class_tokens_mask_d = ggml_new_tensor_1d(ctx0, type, class_tokens_mask.size());
|
struct ggml_tensor* class_tokens_mask_d = ggml_new_tensor_1d(runner_ctx.ggml_ctx, type, class_tokens_mask.size());
|
||||||
|
|
||||||
struct ggml_tensor* id_pixel_values_d = to_backend(id_pixel_values);
|
struct ggml_tensor* id_pixel_values_d = to_backend(id_pixel_values);
|
||||||
struct ggml_tensor* prompt_embeds_d = to_backend(prompt_embeds);
|
struct ggml_tensor* prompt_embeds_d = to_backend(prompt_embeds);
|
||||||
@ -488,16 +486,16 @@ public:
|
|||||||
}
|
}
|
||||||
// printf("\n");
|
// printf("\n");
|
||||||
if (ctmpos[0] > 0) {
|
if (ctmpos[0] > 0) {
|
||||||
// left = ggml_new_tensor_3d(ctx0, type, hidden_size, 1, ctmpos[0]);
|
// left = ggml_new_tensor_3d(runner_ctx.ggml_ctx, type, hidden_size, 1, ctmpos[0]);
|
||||||
left = ggml_new_tensor_3d(ctx0, type, hidden_size, ctmpos[0], 1);
|
left = ggml_new_tensor_3d(runner_ctx.ggml_ctx, type, hidden_size, ctmpos[0], 1);
|
||||||
}
|
}
|
||||||
if (ctmpos[ctmpos.size() - 1] < seq_length - 1) {
|
if (ctmpos[ctmpos.size() - 1] < seq_length - 1) {
|
||||||
// right = ggml_new_tensor_3d(ctx0, type,
|
// right = ggml_new_tensor_3d(runner_ctx.ggml_ctx, type,
|
||||||
// hidden_size, 1, seq_length - ctmpos[ctmpos.size() - 1] - 1);
|
// hidden_size, 1, seq_length - ctmpos[ctmpos.size() - 1] - 1);
|
||||||
right = ggml_new_tensor_3d(ctx0, type,
|
right = ggml_new_tensor_3d(runner_ctx.ggml_ctx, type,
|
||||||
hidden_size, seq_length - ctmpos[ctmpos.size() - 1] - 1, 1);
|
hidden_size, seq_length - ctmpos[ctmpos.size() - 1] - 1, 1);
|
||||||
}
|
}
|
||||||
struct ggml_tensor* class_tokens_mask_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ctmpos.size());
|
struct ggml_tensor* class_tokens_mask_pos = ggml_new_tensor_1d(runner_ctx.ggml_ctx, GGML_TYPE_I32, ctmpos.size());
|
||||||
|
|
||||||
{
|
{
|
||||||
if (type == GGML_TYPE_F16)
|
if (type == GGML_TYPE_F16)
|
||||||
@ -530,16 +528,14 @@ public:
|
|||||||
}
|
}
|
||||||
struct ggml_tensor* updated_prompt_embeds = nullptr;
|
struct ggml_tensor* updated_prompt_embeds = nullptr;
|
||||||
if (pm_version == PM_VERSION_1)
|
if (pm_version == PM_VERSION_1)
|
||||||
updated_prompt_embeds = id_encoder.forward(ctx0,
|
updated_prompt_embeds = id_encoder.forward(&runner_ctx,
|
||||||
runtime_backend,
|
|
||||||
id_pixel_values_d,
|
id_pixel_values_d,
|
||||||
prompt_embeds_d,
|
prompt_embeds_d,
|
||||||
class_tokens_mask_d,
|
class_tokens_mask_d,
|
||||||
class_tokens_mask_pos,
|
class_tokens_mask_pos,
|
||||||
left, right);
|
left, right);
|
||||||
else if (pm_version == PM_VERSION_2)
|
else if (pm_version == PM_VERSION_2)
|
||||||
updated_prompt_embeds = id_encoder2.forward(ctx0,
|
updated_prompt_embeds = id_encoder2.forward(&runner_ctx,
|
||||||
runtime_backend,
|
|
||||||
id_pixel_values_d,
|
id_pixel_values_d,
|
||||||
prompt_embeds_d,
|
prompt_embeds_d,
|
||||||
class_tokens_mask_d,
|
class_tokens_mask_d,
|
||||||
|
|||||||
140
qwen_image.hpp
140
qwen_image.hpp
@ -27,18 +27,18 @@ namespace Qwen {
|
|||||||
blocks["linear_2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, out_dim, sample_proj_bias));
|
blocks["linear_2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, out_dim, sample_proj_bias));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* sample,
|
struct ggml_tensor* sample,
|
||||||
struct ggml_tensor* condition = nullptr) {
|
struct ggml_tensor* condition = nullptr) {
|
||||||
if (condition != nullptr) {
|
if (condition != nullptr) {
|
||||||
auto cond_proj = std::dynamic_pointer_cast<Linear>(blocks["cond_proj"]);
|
auto cond_proj = std::dynamic_pointer_cast<Linear>(blocks["cond_proj"]);
|
||||||
sample = ggml_add(ctx, sample, cond_proj->forward(ctx, condition));
|
sample = ggml_add(ctx->ggml_ctx, sample, cond_proj->forward(ctx, condition));
|
||||||
}
|
}
|
||||||
auto linear_1 = std::dynamic_pointer_cast<Linear>(blocks["linear_1"]);
|
auto linear_1 = std::dynamic_pointer_cast<Linear>(blocks["linear_1"]);
|
||||||
auto linear_2 = std::dynamic_pointer_cast<Linear>(blocks["linear_2"]);
|
auto linear_2 = std::dynamic_pointer_cast<Linear>(blocks["linear_2"]);
|
||||||
|
|
||||||
sample = linear_1->forward(ctx, sample);
|
sample = linear_1->forward(ctx, sample);
|
||||||
sample = ggml_silu_inplace(ctx, sample);
|
sample = ggml_silu_inplace(ctx->ggml_ctx, sample);
|
||||||
sample = linear_2->forward(ctx, sample);
|
sample = linear_2->forward(ctx, sample);
|
||||||
return sample;
|
return sample;
|
||||||
}
|
}
|
||||||
@ -50,13 +50,13 @@ namespace Qwen {
|
|||||||
blocks["timestep_embedder"] = std::shared_ptr<GGMLBlock>(new TimestepEmbedding(256, embedding_dim));
|
blocks["timestep_embedder"] = std::shared_ptr<GGMLBlock>(new TimestepEmbedding(256, embedding_dim));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* timesteps) {
|
struct ggml_tensor* timesteps) {
|
||||||
// timesteps: [N,]
|
// timesteps: [N,]
|
||||||
// return: [N, embedding_dim]
|
// return: [N, embedding_dim]
|
||||||
auto timestep_embedder = std::dynamic_pointer_cast<TimestepEmbedding>(blocks["timestep_embedder"]);
|
auto timestep_embedder = std::dynamic_pointer_cast<TimestepEmbedding>(blocks["timestep_embedder"]);
|
||||||
|
|
||||||
auto timesteps_proj = ggml_ext_timestep_embedding(ctx, timesteps, 256, 10000, 1.f);
|
auto timesteps_proj = ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, 256, 10000, 1.f);
|
||||||
auto timesteps_emb = timestep_embedder->forward(ctx, timesteps_proj);
|
auto timesteps_emb = timestep_embedder->forward(ctx, timesteps_proj);
|
||||||
return timesteps_emb;
|
return timesteps_emb;
|
||||||
}
|
}
|
||||||
@ -65,7 +65,6 @@ namespace Qwen {
|
|||||||
struct QwenImageAttention : public GGMLBlock {
|
struct QwenImageAttention : public GGMLBlock {
|
||||||
protected:
|
protected:
|
||||||
int64_t dim_head;
|
int64_t dim_head;
|
||||||
bool flash_attn;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
QwenImageAttention(int64_t query_dim,
|
QwenImageAttention(int64_t query_dim,
|
||||||
@ -75,9 +74,8 @@ namespace Qwen {
|
|||||||
int64_t out_context_dim = 0,
|
int64_t out_context_dim = 0,
|
||||||
bool bias = true,
|
bool bias = true,
|
||||||
bool out_bias = true,
|
bool out_bias = true,
|
||||||
float eps = 1e-6,
|
float eps = 1e-6)
|
||||||
bool flash_attn = false)
|
: dim_head(dim_head) {
|
||||||
: dim_head(dim_head), flash_attn(flash_attn) {
|
|
||||||
int64_t inner_dim = out_dim > 0 ? out_dim : dim_head * num_heads;
|
int64_t inner_dim = out_dim > 0 ? out_dim : dim_head * num_heads;
|
||||||
out_dim = out_dim > 0 ? out_dim : query_dim;
|
out_dim = out_dim > 0 ? out_dim : query_dim;
|
||||||
out_context_dim = out_context_dim > 0 ? out_context_dim : query_dim;
|
out_context_dim = out_context_dim > 0 ? out_context_dim : query_dim;
|
||||||
@ -105,8 +103,7 @@ namespace Qwen {
|
|||||||
blocks["to_add_out"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, out_context_dim, out_bias, false, false, scale));
|
blocks["to_add_out"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, out_context_dim, out_bias, false, false, scale));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<ggml_tensor*, ggml_tensor*> forward(struct ggml_context* ctx,
|
std::pair<ggml_tensor*, ggml_tensor*> forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* img,
|
struct ggml_tensor* img,
|
||||||
struct ggml_tensor* txt,
|
struct ggml_tensor* txt,
|
||||||
struct ggml_tensor* pe,
|
struct ggml_tensor* pe,
|
||||||
@ -138,49 +135,49 @@ namespace Qwen {
|
|||||||
|
|
||||||
auto img_q = to_q->forward(ctx, img);
|
auto img_q = to_q->forward(ctx, img);
|
||||||
int64_t num_heads = img_q->ne[0] / dim_head;
|
int64_t num_heads = img_q->ne[0] / dim_head;
|
||||||
img_q = ggml_reshape_4d(ctx, img_q, dim_head, num_heads, n_img_token, N); // [N, n_img_token, n_head, d_head]
|
img_q = ggml_reshape_4d(ctx->ggml_ctx, img_q, dim_head, num_heads, n_img_token, N); // [N, n_img_token, n_head, d_head]
|
||||||
auto img_k = to_k->forward(ctx, img);
|
auto img_k = to_k->forward(ctx, img);
|
||||||
img_k = ggml_reshape_4d(ctx, img_k, dim_head, num_heads, n_img_token, N); // [N, n_img_token, n_head, d_head]
|
img_k = ggml_reshape_4d(ctx->ggml_ctx, img_k, dim_head, num_heads, n_img_token, N); // [N, n_img_token, n_head, d_head]
|
||||||
auto img_v = to_v->forward(ctx, img);
|
auto img_v = to_v->forward(ctx, img);
|
||||||
img_v = ggml_reshape_4d(ctx, img_v, dim_head, num_heads, n_img_token, N); // [N, n_img_token, n_head, d_head]
|
img_v = ggml_reshape_4d(ctx->ggml_ctx, img_v, dim_head, num_heads, n_img_token, N); // [N, n_img_token, n_head, d_head]
|
||||||
|
|
||||||
img_q = norm_q->forward(ctx, img_q);
|
img_q = norm_q->forward(ctx, img_q);
|
||||||
img_k = norm_k->forward(ctx, img_k);
|
img_k = norm_k->forward(ctx, img_k);
|
||||||
|
|
||||||
auto txt_q = add_q_proj->forward(ctx, txt);
|
auto txt_q = add_q_proj->forward(ctx, txt);
|
||||||
txt_q = ggml_reshape_4d(ctx, txt_q, dim_head, num_heads, n_txt_token, N); // [N, n_txt_token, n_head, d_head]
|
txt_q = ggml_reshape_4d(ctx->ggml_ctx, txt_q, dim_head, num_heads, n_txt_token, N); // [N, n_txt_token, n_head, d_head]
|
||||||
auto txt_k = add_k_proj->forward(ctx, txt);
|
auto txt_k = add_k_proj->forward(ctx, txt);
|
||||||
txt_k = ggml_reshape_4d(ctx, txt_k, dim_head, num_heads, n_txt_token, N); // [N, n_txt_token, n_head, d_head]
|
txt_k = ggml_reshape_4d(ctx->ggml_ctx, txt_k, dim_head, num_heads, n_txt_token, N); // [N, n_txt_token, n_head, d_head]
|
||||||
auto txt_v = add_v_proj->forward(ctx, txt);
|
auto txt_v = add_v_proj->forward(ctx, txt);
|
||||||
txt_v = ggml_reshape_4d(ctx, txt_v, dim_head, num_heads, n_txt_token, N); // [N, n_txt_token, n_head, d_head]
|
txt_v = ggml_reshape_4d(ctx->ggml_ctx, txt_v, dim_head, num_heads, n_txt_token, N); // [N, n_txt_token, n_head, d_head]
|
||||||
|
|
||||||
txt_q = norm_added_q->forward(ctx, txt_q);
|
txt_q = norm_added_q->forward(ctx, txt_q);
|
||||||
txt_k = norm_added_k->forward(ctx, txt_k);
|
txt_k = norm_added_k->forward(ctx, txt_k);
|
||||||
|
|
||||||
auto q = ggml_concat(ctx, txt_q, img_q, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
|
auto q = ggml_concat(ctx->ggml_ctx, txt_q, img_q, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
|
||||||
auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
|
auto k = ggml_concat(ctx->ggml_ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
|
||||||
auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
|
auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
|
||||||
|
|
||||||
auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn, (1.0f / 128.f)); // [N, n_txt_token + n_img_token, n_head*d_head]
|
auto attn = Rope::attention(ctx, q, k, v, pe, mask, (1.0f / 128.f)); // [N, n_txt_token + n_img_token, n_head*d_head]
|
||||||
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
|
attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
|
||||||
auto txt_attn_out = ggml_view_3d(ctx,
|
auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx,
|
||||||
attn,
|
attn,
|
||||||
attn->ne[0],
|
attn->ne[0],
|
||||||
attn->ne[1],
|
attn->ne[1],
|
||||||
txt->ne[1],
|
txt->ne[1],
|
||||||
attn->nb[1],
|
attn->nb[1],
|
||||||
attn->nb[2],
|
attn->nb[2],
|
||||||
0); // [n_txt_token, N, hidden_size]
|
0); // [n_txt_token, N, hidden_size]
|
||||||
txt_attn_out = ggml_cont(ctx, ggml_permute(ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size]
|
txt_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size]
|
||||||
auto img_attn_out = ggml_view_3d(ctx,
|
auto img_attn_out = ggml_view_3d(ctx->ggml_ctx,
|
||||||
attn,
|
attn,
|
||||||
attn->ne[0],
|
attn->ne[0],
|
||||||
attn->ne[1],
|
attn->ne[1],
|
||||||
img->ne[1],
|
img->ne[1],
|
||||||
attn->nb[1],
|
attn->nb[1],
|
||||||
attn->nb[2],
|
attn->nb[2],
|
||||||
attn->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size]
|
attn->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size]
|
||||||
img_attn_out = ggml_cont(ctx, ggml_permute(ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
|
img_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
|
||||||
|
|
||||||
img_attn_out = to_out_0->forward(ctx, img_attn_out);
|
img_attn_out = to_out_0->forward(ctx, img_attn_out);
|
||||||
txt_attn_out = to_add_out->forward(ctx, txt_attn_out);
|
txt_attn_out = to_add_out->forward(ctx, txt_attn_out);
|
||||||
@ -194,8 +191,7 @@ namespace Qwen {
|
|||||||
QwenImageTransformerBlock(int64_t dim,
|
QwenImageTransformerBlock(int64_t dim,
|
||||||
int64_t num_attention_heads,
|
int64_t num_attention_heads,
|
||||||
int64_t attention_head_dim,
|
int64_t attention_head_dim,
|
||||||
float eps = 1e-6,
|
float eps = 1e-6) {
|
||||||
bool flash_attn = false) {
|
|
||||||
// img_mod.0 is nn.SiLU()
|
// img_mod.0 is nn.SiLU()
|
||||||
blocks["img_mod.1"] = std::shared_ptr<GGMLBlock>(new Linear(dim, 6 * dim, true));
|
blocks["img_mod.1"] = std::shared_ptr<GGMLBlock>(new Linear(dim, 6 * dim, true));
|
||||||
|
|
||||||
@ -217,12 +213,10 @@ namespace Qwen {
|
|||||||
0, // out_context-dim
|
0, // out_context-dim
|
||||||
true, // bias
|
true, // bias
|
||||||
true, // out_bias
|
true, // out_bias
|
||||||
eps,
|
eps));
|
||||||
flash_attn));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual std::pair<ggml_tensor*, ggml_tensor*> forward(struct ggml_context* ctx,
|
virtual std::pair<ggml_tensor*, ggml_tensor*> forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* img,
|
struct ggml_tensor* img,
|
||||||
struct ggml_tensor* txt,
|
struct ggml_tensor* txt,
|
||||||
struct ggml_tensor* t_emb,
|
struct ggml_tensor* t_emb,
|
||||||
@ -244,40 +238,40 @@ namespace Qwen {
|
|||||||
|
|
||||||
auto attn = std::dynamic_pointer_cast<QwenImageAttention>(blocks["attn"]);
|
auto attn = std::dynamic_pointer_cast<QwenImageAttention>(blocks["attn"]);
|
||||||
|
|
||||||
auto img_mod_params = ggml_silu(ctx, t_emb);
|
auto img_mod_params = ggml_silu(ctx->ggml_ctx, t_emb);
|
||||||
img_mod_params = img_mod_1->forward(ctx, img_mod_params);
|
img_mod_params = img_mod_1->forward(ctx, img_mod_params);
|
||||||
auto img_mod_param_vec = ggml_ext_chunk(ctx, img_mod_params, 6, 0);
|
auto img_mod_param_vec = ggml_ext_chunk(ctx->ggml_ctx, img_mod_params, 6, 0);
|
||||||
|
|
||||||
auto txt_mod_params = ggml_silu(ctx, t_emb);
|
auto txt_mod_params = ggml_silu(ctx->ggml_ctx, t_emb);
|
||||||
txt_mod_params = txt_mod_1->forward(ctx, txt_mod_params);
|
txt_mod_params = txt_mod_1->forward(ctx, txt_mod_params);
|
||||||
auto txt_mod_param_vec = ggml_ext_chunk(ctx, txt_mod_params, 6, 0);
|
auto txt_mod_param_vec = ggml_ext_chunk(ctx->ggml_ctx, txt_mod_params, 6, 0);
|
||||||
|
|
||||||
auto img_normed = img_norm1->forward(ctx, img);
|
auto img_normed = img_norm1->forward(ctx, img);
|
||||||
auto img_modulated = Flux::modulate(ctx, img_normed, img_mod_param_vec[0], img_mod_param_vec[1]);
|
auto img_modulated = Flux::modulate(ctx->ggml_ctx, img_normed, img_mod_param_vec[0], img_mod_param_vec[1]);
|
||||||
auto img_gate1 = img_mod_param_vec[2];
|
auto img_gate1 = img_mod_param_vec[2];
|
||||||
|
|
||||||
auto txt_normed = txt_norm1->forward(ctx, txt);
|
auto txt_normed = txt_norm1->forward(ctx, txt);
|
||||||
auto txt_modulated = Flux::modulate(ctx, txt_normed, txt_mod_param_vec[0], txt_mod_param_vec[1]);
|
auto txt_modulated = Flux::modulate(ctx->ggml_ctx, txt_normed, txt_mod_param_vec[0], txt_mod_param_vec[1]);
|
||||||
auto txt_gate1 = txt_mod_param_vec[2];
|
auto txt_gate1 = txt_mod_param_vec[2];
|
||||||
|
|
||||||
auto [img_attn_output, txt_attn_output] = attn->forward(ctx, backend, img_modulated, txt_modulated, pe);
|
auto [img_attn_output, txt_attn_output] = attn->forward(ctx, img_modulated, txt_modulated, pe);
|
||||||
|
|
||||||
img = ggml_add(ctx, img, ggml_mul(ctx, img_attn_output, img_gate1));
|
img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_attn_output, img_gate1));
|
||||||
txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_attn_output, txt_gate1));
|
txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_attn_output, txt_gate1));
|
||||||
|
|
||||||
auto img_normed2 = img_norm2->forward(ctx, img);
|
auto img_normed2 = img_norm2->forward(ctx, img);
|
||||||
auto img_modulated2 = Flux::modulate(ctx, img_normed2, img_mod_param_vec[3], img_mod_param_vec[4]);
|
auto img_modulated2 = Flux::modulate(ctx->ggml_ctx, img_normed2, img_mod_param_vec[3], img_mod_param_vec[4]);
|
||||||
auto img_gate2 = img_mod_param_vec[5];
|
auto img_gate2 = img_mod_param_vec[5];
|
||||||
|
|
||||||
auto txt_normed2 = txt_norm2->forward(ctx, txt);
|
auto txt_normed2 = txt_norm2->forward(ctx, txt);
|
||||||
auto txt_modulated2 = Flux::modulate(ctx, txt_normed2, txt_mod_param_vec[3], txt_mod_param_vec[4]);
|
auto txt_modulated2 = Flux::modulate(ctx->ggml_ctx, txt_normed2, txt_mod_param_vec[3], txt_mod_param_vec[4]);
|
||||||
auto txt_gate2 = txt_mod_param_vec[5];
|
auto txt_gate2 = txt_mod_param_vec[5];
|
||||||
|
|
||||||
auto img_mlp_out = img_mlp->forward(ctx, img_modulated2);
|
auto img_mlp_out = img_mlp->forward(ctx, img_modulated2);
|
||||||
auto txt_mlp_out = txt_mlp->forward(ctx, txt_modulated2);
|
auto txt_mlp_out = txt_mlp->forward(ctx, txt_modulated2);
|
||||||
|
|
||||||
img = ggml_add(ctx, img, ggml_mul(ctx, img_mlp_out, img_gate2));
|
img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_mlp_out, img_gate2));
|
||||||
txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_mlp_out, txt_gate2));
|
txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_mlp_out, txt_gate2));
|
||||||
|
|
||||||
return {img, txt};
|
return {img, txt};
|
||||||
}
|
}
|
||||||
@ -294,7 +288,7 @@ namespace Qwen {
|
|||||||
blocks["linear"] = std::shared_ptr<GGMLBlock>(new Linear(conditioning_embedding_dim, embedding_dim * 2, bias));
|
blocks["linear"] = std::shared_ptr<GGMLBlock>(new Linear(conditioning_embedding_dim, embedding_dim * 2, bias));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* c) {
|
struct ggml_tensor* c) {
|
||||||
// x: [N, n_token, hidden_size]
|
// x: [N, n_token, hidden_size]
|
||||||
@ -304,13 +298,13 @@ namespace Qwen {
|
|||||||
auto norm = std::dynamic_pointer_cast<LayerNorm>(blocks["norm"]);
|
auto norm = std::dynamic_pointer_cast<LayerNorm>(blocks["norm"]);
|
||||||
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
|
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
|
||||||
|
|
||||||
auto emb = linear->forward(ctx, ggml_silu(ctx, c));
|
auto emb = linear->forward(ctx, ggml_silu(ctx->ggml_ctx, c));
|
||||||
auto mods = ggml_ext_chunk(ctx, emb, 2, 0);
|
auto mods = ggml_ext_chunk(ctx->ggml_ctx, emb, 2, 0);
|
||||||
auto scale = mods[0];
|
auto scale = mods[0];
|
||||||
auto shift = mods[1];
|
auto shift = mods[1];
|
||||||
|
|
||||||
x = norm->forward(ctx, x);
|
x = norm->forward(ctx, x);
|
||||||
x = Flux::modulate(ctx, x, shift, scale);
|
x = Flux::modulate(ctx->ggml_ctx, x, shift, scale);
|
||||||
|
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
@ -327,7 +321,6 @@ namespace Qwen {
|
|||||||
float theta = 10000;
|
float theta = 10000;
|
||||||
std::vector<int> axes_dim = {16, 56, 56};
|
std::vector<int> axes_dim = {16, 56, 56};
|
||||||
int64_t axes_dim_sum = 128;
|
int64_t axes_dim_sum = 128;
|
||||||
bool flash_attn = false;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class QwenImageModel : public GGMLBlock {
|
class QwenImageModel : public GGMLBlock {
|
||||||
@ -349,8 +342,7 @@ namespace Qwen {
|
|||||||
auto block = std::shared_ptr<GGMLBlock>(new QwenImageTransformerBlock(inner_dim,
|
auto block = std::shared_ptr<GGMLBlock>(new QwenImageTransformerBlock(inner_dim,
|
||||||
params.num_attention_heads,
|
params.num_attention_heads,
|
||||||
params.attention_head_dim,
|
params.attention_head_dim,
|
||||||
1e-6f,
|
1e-6f));
|
||||||
params.flash_attn));
|
|
||||||
blocks["transformer_blocks." + std::to_string(i)] = block;
|
blocks["transformer_blocks." + std::to_string(i)] = block;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -421,8 +413,7 @@ namespace Qwen {
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward_orig(struct ggml_context* ctx,
|
struct ggml_tensor* forward_orig(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* timestep,
|
struct ggml_tensor* timestep,
|
||||||
struct ggml_tensor* context,
|
struct ggml_tensor* context,
|
||||||
@ -442,7 +433,7 @@ namespace Qwen {
|
|||||||
for (int i = 0; i < params.num_layers; i++) {
|
for (int i = 0; i < params.num_layers; i++) {
|
||||||
auto block = std::dynamic_pointer_cast<QwenImageTransformerBlock>(blocks["transformer_blocks." + std::to_string(i)]);
|
auto block = std::dynamic_pointer_cast<QwenImageTransformerBlock>(blocks["transformer_blocks." + std::to_string(i)]);
|
||||||
|
|
||||||
auto result = block->forward(ctx, backend, img, txt, t_emb, pe);
|
auto result = block->forward(ctx, img, txt, t_emb, pe);
|
||||||
img = result.first;
|
img = result.first;
|
||||||
txt = result.second;
|
txt = result.second;
|
||||||
}
|
}
|
||||||
@ -453,8 +444,7 @@ namespace Qwen {
|
|||||||
return img;
|
return img;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* timestep,
|
struct ggml_tensor* timestep,
|
||||||
struct ggml_tensor* context,
|
struct ggml_tensor* context,
|
||||||
@ -472,32 +462,32 @@ namespace Qwen {
|
|||||||
int64_t C = x->ne[2];
|
int64_t C = x->ne[2];
|
||||||
int64_t N = x->ne[3];
|
int64_t N = x->ne[3];
|
||||||
|
|
||||||
auto img = process_img(ctx, x);
|
auto img = process_img(ctx->ggml_ctx, x);
|
||||||
uint64_t img_tokens = img->ne[1];
|
uint64_t img_tokens = img->ne[1];
|
||||||
|
|
||||||
if (ref_latents.size() > 0) {
|
if (ref_latents.size() > 0) {
|
||||||
for (ggml_tensor* ref : ref_latents) {
|
for (ggml_tensor* ref : ref_latents) {
|
||||||
ref = process_img(ctx, ref);
|
ref = process_img(ctx->ggml_ctx, ref);
|
||||||
img = ggml_concat(ctx, img, ref, 1);
|
img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t h_len = ((H + (params.patch_size / 2)) / params.patch_size);
|
int64_t h_len = ((H + (params.patch_size / 2)) / params.patch_size);
|
||||||
int64_t w_len = ((W + (params.patch_size / 2)) / params.patch_size);
|
int64_t w_len = ((W + (params.patch_size / 2)) / params.patch_size);
|
||||||
|
|
||||||
auto out = forward_orig(ctx, backend, img, timestep, context, pe); // [N, h_len*w_len, ph*pw*C]
|
auto out = forward_orig(ctx, img, timestep, context, pe); // [N, h_len*w_len, ph*pw*C]
|
||||||
|
|
||||||
if (out->ne[1] > img_tokens) {
|
if (out->ne[1] > img_tokens) {
|
||||||
out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size]
|
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size]
|
||||||
out = ggml_view_3d(ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
|
out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
|
||||||
out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size]
|
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size]
|
||||||
}
|
}
|
||||||
|
|
||||||
out = unpatchify(ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w]
|
out = unpatchify(ctx->ggml_ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w]
|
||||||
|
|
||||||
// slice
|
// slice
|
||||||
out = ggml_ext_slice(ctx, out, 1, 0, H); // [N, C, H, W + pad_w]
|
out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H); // [N, C, H, W + pad_w]
|
||||||
out = ggml_ext_slice(ctx, out, 0, 0, W); // [N, C, H, W]
|
out = ggml_ext_slice(ctx->ggml_ctx, out, 0, 0, W); // [N, C, H, W]
|
||||||
|
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
@ -514,10 +504,8 @@ namespace Qwen {
|
|||||||
bool offload_params_to_cpu,
|
bool offload_params_to_cpu,
|
||||||
const String2GGMLType& tensor_types = {},
|
const String2GGMLType& tensor_types = {},
|
||||||
const std::string prefix = "",
|
const std::string prefix = "",
|
||||||
SDVersion version = VERSION_QWEN_IMAGE,
|
SDVersion version = VERSION_QWEN_IMAGE)
|
||||||
bool flash_attn = false)
|
|
||||||
: GGMLRunner(backend, offload_params_to_cpu) {
|
: GGMLRunner(backend, offload_params_to_cpu) {
|
||||||
qwen_image_params.flash_attn = flash_attn;
|
|
||||||
qwen_image_params.num_layers = 0;
|
qwen_image_params.num_layers = 0;
|
||||||
for (auto pair : tensor_types) {
|
for (auto pair : tensor_types) {
|
||||||
std::string tensor_name = pair.first;
|
std::string tensor_name = pair.first;
|
||||||
@ -582,8 +570,9 @@ namespace Qwen {
|
|||||||
// pe->data = nullptr;
|
// pe->data = nullptr;
|
||||||
set_backend_tensor_data(pe, pe_vec.data());
|
set_backend_tensor_data(pe, pe_vec.data());
|
||||||
|
|
||||||
struct ggml_tensor* out = qwen_image.forward(compute_ctx,
|
auto runner_ctx = get_context();
|
||||||
runtime_backend,
|
|
||||||
|
struct ggml_tensor* out = qwen_image.forward(&runner_ctx,
|
||||||
x,
|
x,
|
||||||
timesteps,
|
timesteps,
|
||||||
context,
|
context,
|
||||||
@ -672,8 +661,7 @@ namespace Qwen {
|
|||||||
false,
|
false,
|
||||||
tensor_types,
|
tensor_types,
|
||||||
"model.diffusion_model",
|
"model.diffusion_model",
|
||||||
VERSION_QWEN_IMAGE,
|
VERSION_QWEN_IMAGE);
|
||||||
true);
|
|
||||||
|
|
||||||
qwen_image->alloc_params_buffer();
|
qwen_image->alloc_params_buffer();
|
||||||
std::map<std::string, ggml_tensor*> tensors;
|
std::map<std::string, ggml_tensor*> tensors;
|
||||||
|
|||||||
140
qwenvl.hpp
140
qwenvl.hpp
@ -349,15 +349,15 @@ namespace Qwen {
|
|||||||
blocks["down_proj"] = std::shared_ptr<GGMLBlock>(new Linear(intermediate_size, hidden_size, bias));
|
blocks["down_proj"] = std::shared_ptr<GGMLBlock>(new Linear(intermediate_size, hidden_size, bias));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
// x: [N, n_token, hidden_size]
|
// x: [N, n_token, hidden_size]
|
||||||
auto gate_proj = std::dynamic_pointer_cast<Linear>(blocks["gate_proj"]);
|
auto gate_proj = std::dynamic_pointer_cast<Linear>(blocks["gate_proj"]);
|
||||||
auto up_proj = std::dynamic_pointer_cast<Linear>(blocks["up_proj"]);
|
auto up_proj = std::dynamic_pointer_cast<Linear>(blocks["up_proj"]);
|
||||||
auto down_proj = std::dynamic_pointer_cast<Linear>(blocks["down_proj"]);
|
auto down_proj = std::dynamic_pointer_cast<Linear>(blocks["down_proj"]);
|
||||||
|
|
||||||
auto h = gate_proj->forward(ctx, x);
|
auto h = gate_proj->forward(ctx, x);
|
||||||
h = ggml_silu_inplace(ctx, h);
|
h = ggml_silu_inplace(ctx->ggml_ctx, h);
|
||||||
h = ggml_mul_inplace(ctx, h, up_proj->forward(ctx, x));
|
h = ggml_mul_inplace(ctx->ggml_ctx, h, up_proj->forward(ctx, x));
|
||||||
h = down_proj->forward(ctx, h);
|
h = down_proj->forward(ctx, h);
|
||||||
return h;
|
return h;
|
||||||
}
|
}
|
||||||
@ -409,10 +409,10 @@ namespace Qwen {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
// x: [N*grid_t*grid_h*grid_w, in_channels, temporal_patch_size*patch_size*patch_size]
|
// x: [N*grid_t*grid_h*grid_w, in_channels, temporal_patch_size*patch_size*patch_size]
|
||||||
// return: [N*grid_t*grid_h*grid_w, embed_dim]
|
// return: [N*grid_t*grid_h*grid_w, embed_dim]
|
||||||
x = ggml_reshape_4d(ctx,
|
x = ggml_reshape_4d(ctx->ggml_ctx,
|
||||||
x,
|
x,
|
||||||
patch_size,
|
patch_size,
|
||||||
patch_size,
|
patch_size,
|
||||||
@ -423,22 +423,22 @@ namespace Qwen {
|
|||||||
auto proj_0 = std::dynamic_pointer_cast<Conv2d>(blocks["proj.0"]);
|
auto proj_0 = std::dynamic_pointer_cast<Conv2d>(blocks["proj.0"]);
|
||||||
auto proj_1 = std::dynamic_pointer_cast<Conv2d>(blocks["proj.1"]);
|
auto proj_1 = std::dynamic_pointer_cast<Conv2d>(blocks["proj.1"]);
|
||||||
|
|
||||||
auto x0 = ggml_ext_slice(ctx, x, 2, 0, 1);
|
auto x0 = ggml_ext_slice(ctx->ggml_ctx, x, 2, 0, 1);
|
||||||
x0 = ggml_reshape_4d(ctx, x0, x0->ne[0], x0->ne[1], in_channels, x0->ne[3] / in_channels);
|
x0 = ggml_reshape_4d(ctx->ggml_ctx, x0, x0->ne[0], x0->ne[1], in_channels, x0->ne[3] / in_channels);
|
||||||
x0 = proj_0->forward(ctx, x0);
|
x0 = proj_0->forward(ctx, x0);
|
||||||
|
|
||||||
auto x1 = ggml_ext_slice(ctx, x, 2, 1, 2);
|
auto x1 = ggml_ext_slice(ctx->ggml_ctx, x, 2, 1, 2);
|
||||||
x1 = ggml_reshape_4d(ctx, x1, x1->ne[0], x1->ne[1], in_channels, x1->ne[3] / in_channels);
|
x1 = ggml_reshape_4d(ctx->ggml_ctx, x1, x1->ne[0], x1->ne[1], in_channels, x1->ne[3] / in_channels);
|
||||||
x1 = proj_1->forward(ctx, x1);
|
x1 = proj_1->forward(ctx, x1);
|
||||||
|
|
||||||
x = ggml_add(ctx, x0, x1);
|
x = ggml_add(ctx->ggml_ctx, x0, x1);
|
||||||
} else {
|
} else {
|
||||||
auto proj = std::dynamic_pointer_cast<Conv3d>(blocks["proj"]);
|
auto proj = std::dynamic_pointer_cast<Conv3d>(blocks["proj"]);
|
||||||
|
|
||||||
x = proj->forward(ctx, x);
|
x = proj->forward(ctx, x);
|
||||||
}
|
}
|
||||||
|
|
||||||
x = ggml_reshape_2d(ctx, x, embed_dim, ggml_nelements(x) / embed_dim);
|
x = ggml_reshape_2d(ctx->ggml_ctx, x, embed_dim, ggml_nelements(x) / embed_dim);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -458,15 +458,15 @@ namespace Qwen {
|
|||||||
blocks["mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, dim));
|
blocks["mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, dim));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
auto ln_q = std::dynamic_pointer_cast<RMSNorm>(blocks["ln_q"]);
|
auto ln_q = std::dynamic_pointer_cast<RMSNorm>(blocks["ln_q"]);
|
||||||
auto mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["mlp.0"]);
|
auto mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["mlp.0"]);
|
||||||
auto mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["mlp.2"]);
|
auto mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["mlp.2"]);
|
||||||
|
|
||||||
x = ln_q->forward(ctx, x);
|
x = ln_q->forward(ctx, x);
|
||||||
x = ggml_reshape_2d(ctx, x, hidden_size, ggml_nelements(x) / hidden_size);
|
x = ggml_reshape_2d(ctx->ggml_ctx, x, hidden_size, ggml_nelements(x) / hidden_size);
|
||||||
x = mlp_0->forward(ctx, x);
|
x = mlp_0->forward(ctx, x);
|
||||||
x = ggml_gelu(ctx, x);
|
x = ggml_gelu(ctx->ggml_ctx, x);
|
||||||
x = mlp_2->forward(ctx, x);
|
x = mlp_2->forward(ctx, x);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
@ -495,8 +495,7 @@ namespace Qwen {
|
|||||||
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size));
|
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* pe,
|
struct ggml_tensor* pe,
|
||||||
struct ggml_tensor* mask = nullptr) {
|
struct ggml_tensor* mask = nullptr) {
|
||||||
@ -519,14 +518,14 @@ namespace Qwen {
|
|||||||
} else {
|
} else {
|
||||||
auto qkv_proj = std::dynamic_pointer_cast<Linear>(blocks["qkv"]);
|
auto qkv_proj = std::dynamic_pointer_cast<Linear>(blocks["qkv"]);
|
||||||
auto qkv = qkv_proj->forward(ctx, x);
|
auto qkv = qkv_proj->forward(ctx, x);
|
||||||
qkv_vec = split_qkv(ctx, qkv);
|
qkv_vec = split_qkv(ctx->ggml_ctx, qkv);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head]
|
auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head]
|
||||||
auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head]
|
auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head]
|
||||||
auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head]
|
auto v = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head]
|
||||||
|
|
||||||
x = Rope::attention(ctx, backend, q, k, v, pe, mask, false, 1.f, false); // [N, n_token, hidden_size]
|
x = Rope::attention(ctx, q, k, v, pe, mask, 1.f, false); // [N, n_token, hidden_size]
|
||||||
|
|
||||||
x = proj->forward(ctx, x); // [N, n_token, hidden_size]
|
x = proj->forward(ctx, x); // [N, n_token, hidden_size]
|
||||||
return x;
|
return x;
|
||||||
@ -546,8 +545,7 @@ namespace Qwen {
|
|||||||
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new RMSNorm(hidden_size, eps));
|
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new RMSNorm(hidden_size, eps));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* pe,
|
struct ggml_tensor* pe,
|
||||||
struct ggml_tensor* mask = nullptr) {
|
struct ggml_tensor* mask = nullptr) {
|
||||||
@ -559,13 +557,13 @@ namespace Qwen {
|
|||||||
|
|
||||||
auto residual = x;
|
auto residual = x;
|
||||||
x = norm1->forward(ctx, x);
|
x = norm1->forward(ctx, x);
|
||||||
x = attn->forward(ctx, backend, x, pe, mask);
|
x = attn->forward(ctx, x, pe, mask);
|
||||||
x = ggml_add_inplace(ctx, x, residual);
|
x = ggml_add_inplace(ctx->ggml_ctx, x, residual);
|
||||||
|
|
||||||
residual = x;
|
residual = x;
|
||||||
x = norm2->forward(ctx, x);
|
x = norm2->forward(ctx, x);
|
||||||
x = mlp->forward(ctx, x);
|
x = mlp->forward(ctx, x);
|
||||||
x = ggml_add_inplace(ctx, x, residual);
|
x = ggml_add_inplace(ctx->ggml_ctx, x, residual);
|
||||||
|
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
@ -607,8 +605,7 @@ namespace Qwen {
|
|||||||
blocks["merger"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLPatchMerger(out_hidden_size, hidden_size, spatial_merge_size));
|
blocks["merger"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLPatchMerger(out_hidden_size, hidden_size, spatial_merge_size));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* pixel_values,
|
struct ggml_tensor* pixel_values,
|
||||||
struct ggml_tensor* pe,
|
struct ggml_tensor* pe,
|
||||||
struct ggml_tensor* window_index,
|
struct ggml_tensor* window_index,
|
||||||
@ -623,9 +620,9 @@ namespace Qwen {
|
|||||||
|
|
||||||
auto x = patch_embed->forward(ctx, pixel_values);
|
auto x = patch_embed->forward(ctx, pixel_values);
|
||||||
|
|
||||||
x = ggml_reshape_4d(ctx, x, x->ne[0] * spatial_merge_size * spatial_merge_size, x->ne[1] / spatial_merge_size / spatial_merge_size, x->ne[2], x->ne[3]);
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, x->ne[0] * spatial_merge_size * spatial_merge_size, x->ne[1] / spatial_merge_size / spatial_merge_size, x->ne[2], x->ne[3]);
|
||||||
x = ggml_get_rows(ctx, x, window_index);
|
x = ggml_get_rows(ctx->ggml_ctx, x, window_index);
|
||||||
x = ggml_reshape_4d(ctx, x, x->ne[0] / spatial_merge_size / spatial_merge_size, x->ne[1] * spatial_merge_size * spatial_merge_size, x->ne[2], x->ne[3]);
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, x->ne[0] / spatial_merge_size / spatial_merge_size, x->ne[1] * spatial_merge_size * spatial_merge_size, x->ne[2], x->ne[3]);
|
||||||
|
|
||||||
for (int i = 0; i < num_layers; i++) {
|
for (int i = 0; i < num_layers; i++) {
|
||||||
auto block = std::dynamic_pointer_cast<Qwen2_5_VLVisionBlock>(blocks["blocks." + std::to_string(i)]);
|
auto block = std::dynamic_pointer_cast<Qwen2_5_VLVisionBlock>(blocks["blocks." + std::to_string(i)]);
|
||||||
@ -634,12 +631,12 @@ namespace Qwen {
|
|||||||
if (fullatt_block_indexes.find(i) != fullatt_block_indexes.end()) {
|
if (fullatt_block_indexes.find(i) != fullatt_block_indexes.end()) {
|
||||||
mask = nullptr;
|
mask = nullptr;
|
||||||
}
|
}
|
||||||
x = block->forward(ctx, backend, x, pe, mask);
|
x = block->forward(ctx, x, pe, mask);
|
||||||
}
|
}
|
||||||
|
|
||||||
x = merger->forward(ctx, x);
|
x = merger->forward(ctx, x);
|
||||||
|
|
||||||
x = ggml_get_rows(ctx, x, window_inverse_index);
|
x = ggml_get_rows(ctx->ggml_ctx, x, window_inverse_index);
|
||||||
|
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
@ -664,8 +661,7 @@ namespace Qwen {
|
|||||||
blocks["o_proj"] = std::shared_ptr<GGMLBlock>(new Linear(num_heads * head_dim, hidden_size, false));
|
blocks["o_proj"] = std::shared_ptr<GGMLBlock>(new Linear(num_heads * head_dim, hidden_size, false));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* input_pos) {
|
struct ggml_tensor* input_pos) {
|
||||||
// x: [N, n_token, hidden_size]
|
// x: [N, n_token, hidden_size]
|
||||||
@ -680,21 +676,21 @@ namespace Qwen {
|
|||||||
auto k = k_proj->forward(ctx, x); // [N, n_token, num_kv_heads*head_dim]
|
auto k = k_proj->forward(ctx, x); // [N, n_token, num_kv_heads*head_dim]
|
||||||
auto v = v_proj->forward(ctx, x); // [N, n_token, num_kv_heads*head_dim]
|
auto v = v_proj->forward(ctx, x); // [N, n_token, num_kv_heads*head_dim]
|
||||||
|
|
||||||
q = ggml_reshape_4d(ctx, q, head_dim, num_heads, n_token, N); // [N, n_token, num_heads, head_dim]
|
q = ggml_reshape_4d(ctx->ggml_ctx, q, head_dim, num_heads, n_token, N); // [N, n_token, num_heads, head_dim]
|
||||||
k = ggml_reshape_4d(ctx, k, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim]
|
k = ggml_reshape_4d(ctx->ggml_ctx, k, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim]
|
||||||
v = ggml_reshape_4d(ctx, v, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim]
|
v = ggml_reshape_4d(ctx->ggml_ctx, v, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim]
|
||||||
|
|
||||||
int sections[4] = {16, 24, 24, 0};
|
int sections[4] = {16, 24, 24, 0};
|
||||||
q = ggml_rope_multi(ctx, q, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
q = ggml_rope_multi(ctx->ggml_ctx, q, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
||||||
k = ggml_rope_multi(ctx, k, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
k = ggml_rope_multi(ctx->ggml_ctx, k, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
||||||
|
|
||||||
q = ggml_cont(ctx, ggml_ext_torch_permute(ctx, q, 0, 2, 1, 3)); // [N, num_heads, n_token, head_dim]
|
q = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, q, 0, 2, 1, 3)); // [N, num_heads, n_token, head_dim]
|
||||||
q = ggml_reshape_3d(ctx, q, q->ne[0], q->ne[1], q->ne[2] * q->ne[3]); // [N*num_heads, n_token, head_dim]
|
q = ggml_reshape_3d(ctx->ggml_ctx, q, q->ne[0], q->ne[1], q->ne[2] * q->ne[3]); // [N*num_heads, n_token, head_dim]
|
||||||
|
|
||||||
k = ggml_cont(ctx, ggml_ext_torch_permute(ctx, k, 0, 2, 1, 3)); // [N, num_kv_heads, n_token, head_dim]
|
k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 2, 1, 3)); // [N, num_kv_heads, n_token, head_dim]
|
||||||
k = ggml_reshape_3d(ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); // [N*num_kv_heads, n_token, head_dim]
|
k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); // [N*num_kv_heads, n_token, head_dim]
|
||||||
|
|
||||||
x = ggml_ext_attention_ext(ctx, backend, q, k, v, num_heads, nullptr, true, true, false); // [N, n_token, hidden_size]
|
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, true, true, false); // [N, n_token, hidden_size]
|
||||||
|
|
||||||
x = out_proj->forward(ctx, x); // [N, n_token, hidden_size]
|
x = out_proj->forward(ctx, x); // [N, n_token, hidden_size]
|
||||||
return x;
|
return x;
|
||||||
@ -714,8 +710,7 @@ namespace Qwen {
|
|||||||
blocks["post_attention_layernorm"] = std::shared_ptr<GGMLBlock>(new RMSNorm(hidden_size, eps));
|
blocks["post_attention_layernorm"] = std::shared_ptr<GGMLBlock>(new RMSNorm(hidden_size, eps));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* input_pos) {
|
struct ggml_tensor* input_pos) {
|
||||||
// x: [N, n_token, hidden_size]
|
// x: [N, n_token, hidden_size]
|
||||||
@ -726,13 +721,13 @@ namespace Qwen {
|
|||||||
|
|
||||||
auto residual = x;
|
auto residual = x;
|
||||||
x = input_layernorm->forward(ctx, x);
|
x = input_layernorm->forward(ctx, x);
|
||||||
x = self_attn->forward(ctx, backend, x, input_pos);
|
x = self_attn->forward(ctx, x, input_pos);
|
||||||
x = ggml_add_inplace(ctx, x, residual);
|
x = ggml_add_inplace(ctx->ggml_ctx, x, residual);
|
||||||
|
|
||||||
residual = x;
|
residual = x;
|
||||||
x = post_attention_layernorm->forward(ctx, x);
|
x = post_attention_layernorm->forward(ctx, x);
|
||||||
x = mlp->forward(ctx, x);
|
x = mlp->forward(ctx, x);
|
||||||
x = ggml_add_inplace(ctx, x, residual);
|
x = ggml_add_inplace(ctx->ggml_ctx, x, residual);
|
||||||
|
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
@ -761,8 +756,7 @@ namespace Qwen {
|
|||||||
blocks["norm"] = std::shared_ptr<GGMLBlock>(new RMSNorm(hidden_size, eps));
|
blocks["norm"] = std::shared_ptr<GGMLBlock>(new RMSNorm(hidden_size, eps));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* input_ids,
|
struct ggml_tensor* input_ids,
|
||||||
struct ggml_tensor* input_pos,
|
struct ggml_tensor* input_pos,
|
||||||
std::vector<std::pair<int, ggml_tensor*>> image_embeds) {
|
std::vector<std::pair<int, ggml_tensor*>> image_embeds) {
|
||||||
@ -777,7 +771,7 @@ namespace Qwen {
|
|||||||
if (image_embeds.size() > 0) {
|
if (image_embeds.size() > 0) {
|
||||||
GGML_ASSERT(x->ne[2] == 1); // N == 1
|
GGML_ASSERT(x->ne[2] == 1); // N == 1
|
||||||
|
|
||||||
auto raw_x = ggml_cast(ctx, x, image_embeds[0].second->type);
|
auto raw_x = ggml_cast(ctx->ggml_ctx, x, image_embeds[0].second->type);
|
||||||
int64_t txt_token_start = 0;
|
int64_t txt_token_start = 0;
|
||||||
int64_t txt_token_end = 0;
|
int64_t txt_token_end = 0;
|
||||||
|
|
||||||
@ -791,23 +785,23 @@ namespace Qwen {
|
|||||||
}
|
}
|
||||||
txt_token_end = image_embeds[i].first;
|
txt_token_end = image_embeds[i].first;
|
||||||
|
|
||||||
auto txt_embed = ggml_ext_slice(ctx, raw_x, 1, txt_token_start, txt_token_end);
|
auto txt_embed = ggml_ext_slice(ctx->ggml_ctx, raw_x, 1, txt_token_start, txt_token_end);
|
||||||
if (input_embed == nullptr) {
|
if (input_embed == nullptr) {
|
||||||
input_embed = txt_embed;
|
input_embed = txt_embed;
|
||||||
} else {
|
} else {
|
||||||
input_embed = ggml_concat(ctx, input_embed, txt_embed, 1);
|
input_embed = ggml_concat(ctx->ggml_ctx, input_embed, txt_embed, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto image_embed = image_embeds[i].second;
|
auto image_embed = image_embeds[i].second;
|
||||||
input_embed = ggml_concat(ctx, input_embed, image_embed, 1);
|
input_embed = ggml_concat(ctx->ggml_ctx, input_embed, image_embed, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
txt_token_start = image_embeds[image_embeds.size() - 1].first + image_embeds[image_embeds.size() - 1].second->ne[1];
|
txt_token_start = image_embeds[image_embeds.size() - 1].first + image_embeds[image_embeds.size() - 1].second->ne[1];
|
||||||
txt_token_end = raw_x->ne[1];
|
txt_token_end = raw_x->ne[1];
|
||||||
|
|
||||||
auto final_txt_embed = ggml_ext_slice(ctx, raw_x, 1, txt_token_start, txt_token_end);
|
auto final_txt_embed = ggml_ext_slice(ctx->ggml_ctx, raw_x, 1, txt_token_start, txt_token_end);
|
||||||
|
|
||||||
input_embed = ggml_concat(ctx, input_embed, final_txt_embed, 1);
|
input_embed = ggml_concat(ctx->ggml_ctx, input_embed, final_txt_embed, 1);
|
||||||
GGML_ASSERT(raw_x->ne[1] == input_embed->ne[1]);
|
GGML_ASSERT(raw_x->ne[1] == input_embed->ne[1]);
|
||||||
|
|
||||||
x = input_embed;
|
x = input_embed;
|
||||||
@ -816,7 +810,7 @@ namespace Qwen {
|
|||||||
for (int i = 0; i < num_layers; i++) {
|
for (int i = 0; i < num_layers; i++) {
|
||||||
auto block = std::dynamic_pointer_cast<Qwen2_5_VLBlock>(blocks["layers." + std::to_string(i)]);
|
auto block = std::dynamic_pointer_cast<Qwen2_5_VLBlock>(blocks["layers." + std::to_string(i)]);
|
||||||
|
|
||||||
x = block->forward(ctx, backend, x, input_pos);
|
x = block->forward(ctx, x, input_pos);
|
||||||
}
|
}
|
||||||
|
|
||||||
x = norm->forward(ctx, x);
|
x = norm->forward(ctx, x);
|
||||||
@ -880,20 +874,18 @@ namespace Qwen {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* input_ids,
|
struct ggml_tensor* input_ids,
|
||||||
struct ggml_tensor* input_pos,
|
struct ggml_tensor* input_pos,
|
||||||
std::vector<std::pair<int, ggml_tensor*>> image_embeds) {
|
std::vector<std::pair<int, ggml_tensor*>> image_embeds) {
|
||||||
// input_ids: [N, n_token]
|
// input_ids: [N, n_token]
|
||||||
auto model = std::dynamic_pointer_cast<Qwen2_5_VLTextModel>(blocks["model"]);
|
auto model = std::dynamic_pointer_cast<Qwen2_5_VLTextModel>(blocks["model"]);
|
||||||
|
|
||||||
auto x = model->forward(ctx, backend, input_ids, input_pos, image_embeds);
|
auto x = model->forward(ctx, input_ids, input_pos, image_embeds);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* vision_forward(struct ggml_context* ctx,
|
struct ggml_tensor* vision_forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* pixel_values,
|
struct ggml_tensor* pixel_values,
|
||||||
struct ggml_tensor* pe,
|
struct ggml_tensor* pe,
|
||||||
struct ggml_tensor* window_index,
|
struct ggml_tensor* window_index,
|
||||||
@ -901,7 +893,7 @@ namespace Qwen {
|
|||||||
struct ggml_tensor* window_mask) {
|
struct ggml_tensor* window_mask) {
|
||||||
GGML_ASSERT(enable_vision);
|
GGML_ASSERT(enable_vision);
|
||||||
auto vision_model = std::dynamic_pointer_cast<Qwen2_5_VLVisionModel>(blocks["visual"]);
|
auto vision_model = std::dynamic_pointer_cast<Qwen2_5_VLVisionModel>(blocks["visual"]);
|
||||||
return vision_model->forward(ctx, backend, pixel_values, pe, window_index, window_inverse_index, window_mask);
|
return vision_model->forward(ctx, pixel_values, pe, window_index, window_inverse_index, window_mask);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -959,23 +951,21 @@ namespace Qwen {
|
|||||||
model.get_param_tensors(tensors, prefix);
|
model.get_param_tensors(tensors, prefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* input_ids,
|
struct ggml_tensor* input_ids,
|
||||||
struct ggml_tensor* input_pos,
|
struct ggml_tensor* input_pos,
|
||||||
std::vector<std::pair<int, ggml_tensor*>> image_embeds) {
|
std::vector<std::pair<int, ggml_tensor*>> image_embeds) {
|
||||||
auto hidden_states = model.forward(ctx, backend, input_ids, input_pos, image_embeds); // [N, n_token, hidden_size]
|
auto hidden_states = model.forward(ctx, input_ids, input_pos, image_embeds); // [N, n_token, hidden_size]
|
||||||
return hidden_states;
|
return hidden_states;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* vision_forward(struct ggml_context* ctx,
|
struct ggml_tensor* vision_forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* pixel_values,
|
struct ggml_tensor* pixel_values,
|
||||||
struct ggml_tensor* input_pos,
|
struct ggml_tensor* input_pos,
|
||||||
struct ggml_tensor* window_index,
|
struct ggml_tensor* window_index,
|
||||||
struct ggml_tensor* window_inverse_index,
|
struct ggml_tensor* window_inverse_index,
|
||||||
struct ggml_tensor* window_mask) {
|
struct ggml_tensor* window_mask) {
|
||||||
auto hidden_states = model.vision_forward(ctx, backend, pixel_values, input_pos, window_index, window_inverse_index, window_mask);
|
auto hidden_states = model.vision_forward(ctx, pixel_values, input_pos, window_index, window_inverse_index, window_mask);
|
||||||
return hidden_states;
|
return hidden_states;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1002,7 +992,9 @@ namespace Qwen {
|
|||||||
n_tokens * 4);
|
n_tokens * 4);
|
||||||
set_backend_tensor_data(input_pos, input_pos_vec.data());
|
set_backend_tensor_data(input_pos, input_pos_vec.data());
|
||||||
|
|
||||||
struct ggml_tensor* hidden_states = forward(compute_ctx, runtime_backend, input_ids, input_pos, image_embeds);
|
auto runner_ctx = get_context();
|
||||||
|
|
||||||
|
struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, input_pos, image_embeds);
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, hidden_states);
|
ggml_build_forward_expand(gf, hidden_states);
|
||||||
|
|
||||||
@ -1167,8 +1159,8 @@ namespace Qwen {
|
|||||||
// pe->data = nullptr;
|
// pe->data = nullptr;
|
||||||
set_backend_tensor_data(pe, pe_vec.data());
|
set_backend_tensor_data(pe, pe_vec.data());
|
||||||
|
|
||||||
struct ggml_tensor* hidden_states = vision_forward(compute_ctx,
|
auto runnter_ctx = get_context();
|
||||||
runtime_backend,
|
struct ggml_tensor* hidden_states = vision_forward(&runnter_ctx,
|
||||||
pixel_values,
|
pixel_values,
|
||||||
pe,
|
pe,
|
||||||
window_index,
|
window_index,
|
||||||
|
|||||||
10
rope.hpp
10
rope.hpp
@ -386,23 +386,21 @@ namespace Rope {
|
|||||||
return x_out;
|
return x_out;
|
||||||
}
|
}
|
||||||
|
|
||||||
__STATIC_INLINE__ struct ggml_tensor* attention(struct ggml_context* ctx,
|
__STATIC_INLINE__ struct ggml_tensor* attention(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* q,
|
struct ggml_tensor* q,
|
||||||
struct ggml_tensor* k,
|
struct ggml_tensor* k,
|
||||||
struct ggml_tensor* v,
|
struct ggml_tensor* v,
|
||||||
struct ggml_tensor* pe,
|
struct ggml_tensor* pe,
|
||||||
struct ggml_tensor* mask,
|
struct ggml_tensor* mask,
|
||||||
bool flash_attn,
|
|
||||||
float kv_scale = 1.0f,
|
float kv_scale = 1.0f,
|
||||||
bool rope_interleaved = true) {
|
bool rope_interleaved = true) {
|
||||||
// q,k,v: [N, L, n_head, d_head]
|
// q,k,v: [N, L, n_head, d_head]
|
||||||
// pe: [L, d_head/2, 2, 2]
|
// pe: [L, d_head/2, 2, 2]
|
||||||
// return: [N, L, n_head*d_head]
|
// return: [N, L, n_head*d_head]
|
||||||
q = apply_rope(ctx, q, pe, rope_interleaved); // [N*n_head, L, d_head]
|
q = apply_rope(ctx->ggml_ctx, q, pe, rope_interleaved); // [N*n_head, L, d_head]
|
||||||
k = apply_rope(ctx, k, pe, rope_interleaved); // [N*n_head, L, d_head]
|
k = apply_rope(ctx->ggml_ctx, k, pe, rope_interleaved); // [N*n_head, L, d_head]
|
||||||
|
|
||||||
auto x = ggml_ext_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn, kv_scale); // [N, L, n_head*d_head]
|
auto x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, v->ne[1], mask, false, true, ctx->flash_attn_enabled, kv_scale); // [N, L, n_head*d_head]
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
}; // namespace Rope
|
}; // namespace Rope
|
||||||
|
|||||||
@ -341,16 +341,12 @@ public:
|
|||||||
LOG_INFO("CLIP: Using CPU backend");
|
LOG_INFO("CLIP: Using CPU backend");
|
||||||
clip_backend = ggml_backend_cpu_init();
|
clip_backend = ggml_backend_cpu_init();
|
||||||
}
|
}
|
||||||
if (sd_ctx_params->diffusion_flash_attn) {
|
|
||||||
LOG_INFO("Using flash attention in the diffusion model");
|
|
||||||
}
|
|
||||||
if (sd_version_is_sd3(version)) {
|
if (sd_version_is_sd3(version)) {
|
||||||
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend,
|
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend,
|
||||||
offload_params_to_cpu,
|
offload_params_to_cpu,
|
||||||
model_loader.tensor_storages_types);
|
model_loader.tensor_storages_types);
|
||||||
diffusion_model = std::make_shared<MMDiTModel>(backend,
|
diffusion_model = std::make_shared<MMDiTModel>(backend,
|
||||||
offload_params_to_cpu,
|
offload_params_to_cpu,
|
||||||
sd_ctx_params->diffusion_flash_attn,
|
|
||||||
model_loader.tensor_storages_types);
|
model_loader.tensor_storages_types);
|
||||||
} else if (sd_version_is_flux(version)) {
|
} else if (sd_version_is_flux(version)) {
|
||||||
bool is_chroma = false;
|
bool is_chroma = false;
|
||||||
@ -384,7 +380,6 @@ public:
|
|||||||
offload_params_to_cpu,
|
offload_params_to_cpu,
|
||||||
model_loader.tensor_storages_types,
|
model_loader.tensor_storages_types,
|
||||||
version,
|
version,
|
||||||
sd_ctx_params->diffusion_flash_attn,
|
|
||||||
sd_ctx_params->chroma_use_dit_mask);
|
sd_ctx_params->chroma_use_dit_mask);
|
||||||
} else if (sd_version_is_wan(version)) {
|
} else if (sd_version_is_wan(version)) {
|
||||||
cond_stage_model = std::make_shared<T5CLIPEmbedder>(clip_backend,
|
cond_stage_model = std::make_shared<T5CLIPEmbedder>(clip_backend,
|
||||||
@ -397,15 +392,13 @@ public:
|
|||||||
offload_params_to_cpu,
|
offload_params_to_cpu,
|
||||||
model_loader.tensor_storages_types,
|
model_loader.tensor_storages_types,
|
||||||
"model.diffusion_model",
|
"model.diffusion_model",
|
||||||
version,
|
version);
|
||||||
sd_ctx_params->diffusion_flash_attn);
|
|
||||||
if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) {
|
if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) {
|
||||||
high_noise_diffusion_model = std::make_shared<WanModel>(backend,
|
high_noise_diffusion_model = std::make_shared<WanModel>(backend,
|
||||||
offload_params_to_cpu,
|
offload_params_to_cpu,
|
||||||
model_loader.tensor_storages_types,
|
model_loader.tensor_storages_types,
|
||||||
"model.high_noise_diffusion_model",
|
"model.high_noise_diffusion_model",
|
||||||
version,
|
version);
|
||||||
sd_ctx_params->diffusion_flash_attn);
|
|
||||||
}
|
}
|
||||||
if (diffusion_model->get_desc() == "Wan2.1-I2V-14B" || diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
|
if (diffusion_model->get_desc() == "Wan2.1-I2V-14B" || diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
|
||||||
clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend,
|
clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend,
|
||||||
@ -428,8 +421,7 @@ public:
|
|||||||
offload_params_to_cpu,
|
offload_params_to_cpu,
|
||||||
model_loader.tensor_storages_types,
|
model_loader.tensor_storages_types,
|
||||||
"model.diffusion_model",
|
"model.diffusion_model",
|
||||||
version,
|
version);
|
||||||
sd_ctx_params->diffusion_flash_attn);
|
|
||||||
} else { // SD1.x SD2.x SDXL
|
} else { // SD1.x SD2.x SDXL
|
||||||
if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
|
if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
|
||||||
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
|
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
|
||||||
@ -448,14 +440,18 @@ public:
|
|||||||
diffusion_model = std::make_shared<UNetModel>(backend,
|
diffusion_model = std::make_shared<UNetModel>(backend,
|
||||||
offload_params_to_cpu,
|
offload_params_to_cpu,
|
||||||
model_loader.tensor_storages_types,
|
model_loader.tensor_storages_types,
|
||||||
version,
|
version);
|
||||||
sd_ctx_params->diffusion_flash_attn);
|
|
||||||
if (sd_ctx_params->diffusion_conv_direct) {
|
if (sd_ctx_params->diffusion_conv_direct) {
|
||||||
LOG_INFO("Using Conv2d direct in the diffusion model");
|
LOG_INFO("Using Conv2d direct in the diffusion model");
|
||||||
std::dynamic_pointer_cast<UNetModel>(diffusion_model)->unet.enable_conv2d_direct();
|
std::dynamic_pointer_cast<UNetModel>(diffusion_model)->unet.set_conv2d_direct_enabled(true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (sd_ctx_params->diffusion_flash_attn) {
|
||||||
|
LOG_INFO("Using flash attention in the diffusion model");
|
||||||
|
diffusion_model->set_flash_attn_enabled(true);
|
||||||
|
}
|
||||||
|
|
||||||
cond_stage_model->alloc_params_buffer();
|
cond_stage_model->alloc_params_buffer();
|
||||||
cond_stage_model->get_param_tensors(tensors);
|
cond_stage_model->get_param_tensors(tensors);
|
||||||
|
|
||||||
@ -500,7 +496,7 @@ public:
|
|||||||
version);
|
version);
|
||||||
if (sd_ctx_params->vae_conv_direct) {
|
if (sd_ctx_params->vae_conv_direct) {
|
||||||
LOG_INFO("Using Conv2d direct in the vae model");
|
LOG_INFO("Using Conv2d direct in the vae model");
|
||||||
first_stage_model->enable_conv2d_direct();
|
first_stage_model->set_conv2d_direct_enabled(true);
|
||||||
}
|
}
|
||||||
if (version == VERSION_SDXL &&
|
if (version == VERSION_SDXL &&
|
||||||
(strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale)) {
|
(strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale)) {
|
||||||
@ -522,7 +518,7 @@ public:
|
|||||||
version);
|
version);
|
||||||
if (sd_ctx_params->vae_conv_direct) {
|
if (sd_ctx_params->vae_conv_direct) {
|
||||||
LOG_INFO("Using Conv2d direct in the tae model");
|
LOG_INFO("Using Conv2d direct in the tae model");
|
||||||
tae_first_stage->enable_conv2d_direct();
|
tae_first_stage->set_conv2d_direct_enabled(true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// first_stage_model->get_param_tensors(tensors, "first_stage_model.");
|
// first_stage_model->get_param_tensors(tensors, "first_stage_model.");
|
||||||
@ -541,7 +537,7 @@ public:
|
|||||||
version);
|
version);
|
||||||
if (sd_ctx_params->diffusion_conv_direct) {
|
if (sd_ctx_params->diffusion_conv_direct) {
|
||||||
LOG_INFO("Using Conv2d direct in the control net");
|
LOG_INFO("Using Conv2d direct in the control net");
|
||||||
control_net->enable_conv2d_direct();
|
control_net->set_conv2d_direct_enabled(true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
67
t5.hpp
67
t5.hpp
@ -472,10 +472,10 @@ public:
|
|||||||
: hidden_size(hidden_size),
|
: hidden_size(hidden_size),
|
||||||
eps(eps) {}
|
eps(eps) {}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||||
struct ggml_tensor* w = params["weight"];
|
struct ggml_tensor* w = params["weight"];
|
||||||
x = ggml_rms_norm(ctx, x, eps);
|
x = ggml_rms_norm(ctx->ggml_ctx, x, eps);
|
||||||
x = ggml_mul(ctx, x, w);
|
x = ggml_mul(ctx->ggml_ctx, x, w);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -487,13 +487,13 @@ public:
|
|||||||
blocks["wo"] = std::shared_ptr<GGMLBlock>(new Linear(ff_dim, model_dim, false));
|
blocks["wo"] = std::shared_ptr<GGMLBlock>(new Linear(ff_dim, model_dim, false));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||||
// x: [N, n_token, model_dim]
|
// x: [N, n_token, model_dim]
|
||||||
auto wi = std::dynamic_pointer_cast<Linear>(blocks["wi"]);
|
auto wi = std::dynamic_pointer_cast<Linear>(blocks["wi"]);
|
||||||
auto wo = std::dynamic_pointer_cast<Linear>(blocks["wo"]);
|
auto wo = std::dynamic_pointer_cast<Linear>(blocks["wo"]);
|
||||||
|
|
||||||
x = wi->forward(ctx, x);
|
x = wi->forward(ctx, x);
|
||||||
x = ggml_relu_inplace(ctx, x);
|
x = ggml_relu_inplace(ctx->ggml_ctx, x);
|
||||||
x = wo->forward(ctx, x);
|
x = wo->forward(ctx, x);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
@ -509,15 +509,15 @@ public:
|
|||||||
blocks["wo"] = std::shared_ptr<GGMLBlock>(new Linear(ff_dim, model_dim, false, false, false, scale));
|
blocks["wo"] = std::shared_ptr<GGMLBlock>(new Linear(ff_dim, model_dim, false, false, false, scale));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||||
// x: [N, n_token, model_dim]
|
// x: [N, n_token, model_dim]
|
||||||
auto wi_0 = std::dynamic_pointer_cast<Linear>(blocks["wi_0"]);
|
auto wi_0 = std::dynamic_pointer_cast<Linear>(blocks["wi_0"]);
|
||||||
auto wi_1 = std::dynamic_pointer_cast<Linear>(blocks["wi_1"]);
|
auto wi_1 = std::dynamic_pointer_cast<Linear>(blocks["wi_1"]);
|
||||||
auto wo = std::dynamic_pointer_cast<Linear>(blocks["wo"]);
|
auto wo = std::dynamic_pointer_cast<Linear>(blocks["wo"]);
|
||||||
|
|
||||||
auto hidden_gelu = ggml_gelu_inplace(ctx, wi_0->forward(ctx, x));
|
auto hidden_gelu = ggml_gelu_inplace(ctx->ggml_ctx, wi_0->forward(ctx, x));
|
||||||
auto hidden_linear = wi_1->forward(ctx, x);
|
auto hidden_linear = wi_1->forward(ctx, x);
|
||||||
x = ggml_mul_inplace(ctx, hidden_gelu, hidden_linear);
|
x = ggml_mul_inplace(ctx->ggml_ctx, hidden_gelu, hidden_linear);
|
||||||
x = wo->forward(ctx, x);
|
x = wo->forward(ctx, x);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
@ -530,14 +530,14 @@ public:
|
|||||||
blocks["layer_norm"] = std::shared_ptr<GGMLBlock>(new T5LayerNorm(model_dim));
|
blocks["layer_norm"] = std::shared_ptr<GGMLBlock>(new T5LayerNorm(model_dim));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||||
// x: [N, n_token, model_dim]
|
// x: [N, n_token, model_dim]
|
||||||
auto DenseReluDense = std::dynamic_pointer_cast<T5DenseGatedActDense>(blocks["DenseReluDense"]);
|
auto DenseReluDense = std::dynamic_pointer_cast<T5DenseGatedActDense>(blocks["DenseReluDense"]);
|
||||||
auto layer_norm = std::dynamic_pointer_cast<T5LayerNorm>(blocks["layer_norm"]);
|
auto layer_norm = std::dynamic_pointer_cast<T5LayerNorm>(blocks["layer_norm"]);
|
||||||
|
|
||||||
auto forwarded_states = layer_norm->forward(ctx, x);
|
auto forwarded_states = layer_norm->forward(ctx, x);
|
||||||
forwarded_states = DenseReluDense->forward(ctx, forwarded_states);
|
forwarded_states = DenseReluDense->forward(ctx, forwarded_states);
|
||||||
x = ggml_add_inplace(ctx, forwarded_states, x);
|
x = ggml_add_inplace(ctx->ggml_ctx, forwarded_states, x);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -569,18 +569,17 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* compute_bias(struct ggml_context* ctx,
|
struct ggml_tensor* compute_bias(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* relative_position_bucket) {
|
struct ggml_tensor* relative_position_bucket) {
|
||||||
auto relative_attention_bias = std::dynamic_pointer_cast<Embedding>(blocks["relative_attention_bias"]);
|
auto relative_attention_bias = std::dynamic_pointer_cast<Embedding>(blocks["relative_attention_bias"]);
|
||||||
|
|
||||||
auto values = relative_attention_bias->forward(ctx, relative_position_bucket); // shape (query_length, key_length, num_heads)
|
auto values = relative_attention_bias->forward(ctx, relative_position_bucket); // shape (query_length, key_length, num_heads)
|
||||||
values = ggml_cont(ctx, ggml_permute(ctx, values, 2, 0, 1, 3)); // shape (1, num_heads, query_length, key_length)
|
values = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, values, 2, 0, 1, 3)); // shape (1, num_heads, query_length, key_length)
|
||||||
return values;
|
return values;
|
||||||
}
|
}
|
||||||
|
|
||||||
// x: [N, n_token, model_dim]
|
// x: [N, n_token, model_dim]
|
||||||
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
|
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* past_bias = nullptr,
|
struct ggml_tensor* past_bias = nullptr,
|
||||||
struct ggml_tensor* mask = nullptr,
|
struct ggml_tensor* mask = nullptr,
|
||||||
@ -602,16 +601,16 @@ public:
|
|||||||
}
|
}
|
||||||
if (past_bias != nullptr) {
|
if (past_bias != nullptr) {
|
||||||
if (mask != nullptr) {
|
if (mask != nullptr) {
|
||||||
mask = ggml_repeat(ctx, mask, past_bias);
|
mask = ggml_repeat(ctx->ggml_ctx, mask, past_bias);
|
||||||
mask = ggml_add(ctx, mask, past_bias);
|
mask = ggml_add(ctx->ggml_ctx, mask, past_bias);
|
||||||
} else {
|
} else {
|
||||||
mask = past_bias;
|
mask = past_bias;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
k = ggml_scale_inplace(ctx, k, sqrt(d_head));
|
k = ggml_scale_inplace(ctx->ggml_ctx, k, sqrt(d_head));
|
||||||
|
|
||||||
x = ggml_ext_attention_ext(ctx, backend, q, k, v, num_heads, mask); // [N, n_token, d_head * n_head]
|
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, mask); // [N, n_token, d_head * n_head]
|
||||||
|
|
||||||
x = out_proj->forward(ctx, x); // [N, n_token, model_dim]
|
x = out_proj->forward(ctx, x); // [N, n_token, model_dim]
|
||||||
return {x, past_bias};
|
return {x, past_bias};
|
||||||
@ -629,8 +628,7 @@ public:
|
|||||||
blocks["layer_norm"] = std::shared_ptr<GGMLBlock>(new T5LayerNorm(model_dim));
|
blocks["layer_norm"] = std::shared_ptr<GGMLBlock>(new T5LayerNorm(model_dim));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
|
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* past_bias = nullptr,
|
struct ggml_tensor* past_bias = nullptr,
|
||||||
struct ggml_tensor* mask = nullptr,
|
struct ggml_tensor* mask = nullptr,
|
||||||
@ -640,11 +638,11 @@ public:
|
|||||||
auto layer_norm = std::dynamic_pointer_cast<T5LayerNorm>(blocks["layer_norm"]);
|
auto layer_norm = std::dynamic_pointer_cast<T5LayerNorm>(blocks["layer_norm"]);
|
||||||
|
|
||||||
auto normed_hidden_state = layer_norm->forward(ctx, x);
|
auto normed_hidden_state = layer_norm->forward(ctx, x);
|
||||||
auto ret = SelfAttention->forward(ctx, backend, normed_hidden_state, past_bias, mask, relative_position_bucket);
|
auto ret = SelfAttention->forward(ctx, normed_hidden_state, past_bias, mask, relative_position_bucket);
|
||||||
auto output = ret.first;
|
auto output = ret.first;
|
||||||
past_bias = ret.second;
|
past_bias = ret.second;
|
||||||
|
|
||||||
x = ggml_add_inplace(ctx, output, x);
|
x = ggml_add_inplace(ctx->ggml_ctx, output, x);
|
||||||
return {x, past_bias};
|
return {x, past_bias};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -656,8 +654,7 @@ public:
|
|||||||
blocks["layer.1"] = std::shared_ptr<GGMLBlock>(new T5LayerFF(model_dim, ff_dim));
|
blocks["layer.1"] = std::shared_ptr<GGMLBlock>(new T5LayerFF(model_dim, ff_dim));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
|
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* past_bias = nullptr,
|
struct ggml_tensor* past_bias = nullptr,
|
||||||
struct ggml_tensor* mask = nullptr,
|
struct ggml_tensor* mask = nullptr,
|
||||||
@ -666,7 +663,7 @@ public:
|
|||||||
auto layer_0 = std::dynamic_pointer_cast<T5LayerSelfAttention>(blocks["layer.0"]);
|
auto layer_0 = std::dynamic_pointer_cast<T5LayerSelfAttention>(blocks["layer.0"]);
|
||||||
auto layer_1 = std::dynamic_pointer_cast<T5LayerFF>(blocks["layer.1"]);
|
auto layer_1 = std::dynamic_pointer_cast<T5LayerFF>(blocks["layer.1"]);
|
||||||
|
|
||||||
auto ret = layer_0->forward(ctx, backend, x, past_bias, mask, relative_position_bucket);
|
auto ret = layer_0->forward(ctx, x, past_bias, mask, relative_position_bucket);
|
||||||
x = ret.first;
|
x = ret.first;
|
||||||
past_bias = ret.second;
|
past_bias = ret.second;
|
||||||
x = layer_1->forward(ctx, x);
|
x = layer_1->forward(ctx, x);
|
||||||
@ -692,8 +689,7 @@ public:
|
|||||||
blocks["final_layer_norm"] = std::shared_ptr<GGMLBlock>(new T5LayerNorm(model_dim));
|
blocks["final_layer_norm"] = std::shared_ptr<GGMLBlock>(new T5LayerNorm(model_dim));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* past_bias = nullptr,
|
struct ggml_tensor* past_bias = nullptr,
|
||||||
struct ggml_tensor* attention_mask = nullptr,
|
struct ggml_tensor* attention_mask = nullptr,
|
||||||
@ -702,7 +698,7 @@ public:
|
|||||||
for (int i = 0; i < num_layers; i++) {
|
for (int i = 0; i < num_layers; i++) {
|
||||||
auto block = std::dynamic_pointer_cast<T5Block>(blocks["block." + std::to_string(i)]);
|
auto block = std::dynamic_pointer_cast<T5Block>(blocks["block." + std::to_string(i)]);
|
||||||
|
|
||||||
auto ret = block->forward(ctx, backend, x, past_bias, attention_mask, relative_position_bucket);
|
auto ret = block->forward(ctx, x, past_bias, attention_mask, relative_position_bucket);
|
||||||
x = ret.first;
|
x = ret.first;
|
||||||
past_bias = ret.second;
|
past_bias = ret.second;
|
||||||
}
|
}
|
||||||
@ -740,8 +736,7 @@ public:
|
|||||||
params.model_dim));
|
params.model_dim));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* input_ids,
|
struct ggml_tensor* input_ids,
|
||||||
struct ggml_tensor* past_bias = nullptr,
|
struct ggml_tensor* past_bias = nullptr,
|
||||||
struct ggml_tensor* attention_mask = nullptr,
|
struct ggml_tensor* attention_mask = nullptr,
|
||||||
@ -752,7 +747,7 @@ public:
|
|||||||
auto encoder = std::dynamic_pointer_cast<T5Stack>(blocks["encoder"]);
|
auto encoder = std::dynamic_pointer_cast<T5Stack>(blocks["encoder"]);
|
||||||
|
|
||||||
auto x = shared->forward(ctx, input_ids);
|
auto x = shared->forward(ctx, input_ids);
|
||||||
x = encoder->forward(ctx, backend, x, past_bias, attention_mask, relative_position_bucket);
|
x = encoder->forward(ctx, x, past_bias, attention_mask, relative_position_bucket);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -784,15 +779,14 @@ struct T5Runner : public GGMLRunner {
|
|||||||
model.get_param_tensors(tensors, prefix);
|
model.get_param_tensors(tensors, prefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* input_ids,
|
struct ggml_tensor* input_ids,
|
||||||
struct ggml_tensor* relative_position_bucket,
|
struct ggml_tensor* relative_position_bucket,
|
||||||
struct ggml_tensor* attention_mask = nullptr) {
|
struct ggml_tensor* attention_mask = nullptr) {
|
||||||
size_t N = input_ids->ne[1];
|
size_t N = input_ids->ne[1];
|
||||||
size_t n_token = input_ids->ne[0];
|
size_t n_token = input_ids->ne[0];
|
||||||
|
|
||||||
auto hidden_states = model.forward(ctx, backend, input_ids, nullptr, attention_mask, relative_position_bucket); // [N, n_token, model_dim]
|
auto hidden_states = model.forward(ctx, input_ids, nullptr, attention_mask, relative_position_bucket); // [N, n_token, model_dim]
|
||||||
return hidden_states;
|
return hidden_states;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -818,7 +812,8 @@ struct T5Runner : public GGMLRunner {
|
|||||||
input_ids->ne[0]);
|
input_ids->ne[0]);
|
||||||
set_backend_tensor_data(relative_position_bucket, relative_position_bucket_vec.data());
|
set_backend_tensor_data(relative_position_bucket, relative_position_bucket_vec.data());
|
||||||
|
|
||||||
struct ggml_tensor* hidden_states = forward(compute_ctx, runtime_backend, input_ids, relative_position_bucket, attention_mask);
|
auto runner_ctx = get_context();
|
||||||
|
struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, relative_position_bucket, attention_mask);
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, hidden_states);
|
ggml_build_forward_expand(gf, hidden_states);
|
||||||
|
|
||||||
|
|||||||
42
tae.hpp
42
tae.hpp
@ -29,7 +29,7 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||||
// x: [n, n_in, h, w]
|
// x: [n, n_in, h, w]
|
||||||
// return: [n, n_out, h, w]
|
// return: [n, n_out, h, w]
|
||||||
|
|
||||||
@ -38,9 +38,9 @@ public:
|
|||||||
auto conv_4 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.4"]);
|
auto conv_4 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.4"]);
|
||||||
|
|
||||||
auto h = conv_0->forward(ctx, x);
|
auto h = conv_0->forward(ctx, x);
|
||||||
h = ggml_relu_inplace(ctx, h);
|
h = ggml_relu_inplace(ctx->ggml_ctx, h);
|
||||||
h = conv_2->forward(ctx, h);
|
h = conv_2->forward(ctx, h);
|
||||||
h = ggml_relu_inplace(ctx, h);
|
h = ggml_relu_inplace(ctx->ggml_ctx, h);
|
||||||
h = conv_4->forward(ctx, h);
|
h = conv_4->forward(ctx, h);
|
||||||
|
|
||||||
if (n_in != n_out) {
|
if (n_in != n_out) {
|
||||||
@ -49,8 +49,8 @@ public:
|
|||||||
x = skip->forward(ctx, x);
|
x = skip->forward(ctx, x);
|
||||||
}
|
}
|
||||||
|
|
||||||
h = ggml_add(ctx, h, x);
|
h = ggml_add(ctx->ggml_ctx, h, x);
|
||||||
h = ggml_relu_inplace(ctx, h);
|
h = ggml_relu_inplace(ctx->ggml_ctx, h);
|
||||||
return h;
|
return h;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -86,7 +86,7 @@ public:
|
|||||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, z_channels, {3, 3}, {1, 1}, {1, 1}));
|
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, z_channels, {3, 3}, {1, 1}, {1, 1}));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||||
// x: [n, in_channels, h, w]
|
// x: [n, in_channels, h, w]
|
||||||
// return: [n, z_channels, h/8, w/8]
|
// return: [n, z_channels, h/8, w/8]
|
||||||
|
|
||||||
@ -136,20 +136,20 @@ public:
|
|||||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
|
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* z) override {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) override {
|
||||||
// z: [n, z_channels, h, w]
|
// z: [n, z_channels, h, w]
|
||||||
// return: [n, out_channels, h*8, w*8]
|
// return: [n, out_channels, h*8, w*8]
|
||||||
|
|
||||||
auto h = ggml_scale(ctx, z, 1.0f / 3.0f);
|
auto h = ggml_scale(ctx->ggml_ctx, z, 1.0f / 3.0f);
|
||||||
h = ggml_tanh_inplace(ctx, h);
|
h = ggml_tanh_inplace(ctx->ggml_ctx, h);
|
||||||
h = ggml_scale(ctx, h, 3.0f);
|
h = ggml_scale(ctx->ggml_ctx, h, 3.0f);
|
||||||
|
|
||||||
for (int i = 0; i < num_blocks * 3 + 10; i++) {
|
for (int i = 0; i < num_blocks * 3 + 10; i++) {
|
||||||
if (blocks.find(std::to_string(i)) == blocks.end()) {
|
if (blocks.find(std::to_string(i)) == blocks.end()) {
|
||||||
if (i == 1) {
|
if (i == 1) {
|
||||||
h = ggml_relu_inplace(ctx, h);
|
h = ggml_relu_inplace(ctx->ggml_ctx, h);
|
||||||
} else {
|
} else {
|
||||||
h = ggml_upscale(ctx, h, 2, GGML_SCALE_MODE_NEAREST);
|
h = ggml_upscale(ctx->ggml_ctx, h, 2, GGML_SCALE_MODE_NEAREST);
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -180,12 +180,12 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* decode(struct ggml_context* ctx, struct ggml_tensor* z) {
|
struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
|
||||||
auto decoder = std::dynamic_pointer_cast<TinyDecoder>(blocks["decoder.layers"]);
|
auto decoder = std::dynamic_pointer_cast<TinyDecoder>(blocks["decoder.layers"]);
|
||||||
return decoder->forward(ctx, z);
|
return decoder->forward(ctx, z);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* encode(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
auto encoder = std::dynamic_pointer_cast<TinyEncoder>(blocks["encoder.layers"]);
|
auto encoder = std::dynamic_pointer_cast<TinyEncoder>(blocks["encoder.layers"]);
|
||||||
return encoder->forward(ctx, x);
|
return encoder->forward(ctx, x);
|
||||||
}
|
}
|
||||||
@ -207,17 +207,6 @@ struct TinyAutoEncoder : public GGMLRunner {
|
|||||||
taesd.init(params_ctx, tensor_types, prefix);
|
taesd.init(params_ctx, tensor_types, prefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
void enable_conv2d_direct() {
|
|
||||||
std::vector<GGMLBlock*> blocks;
|
|
||||||
taesd.get_all_blocks(blocks);
|
|
||||||
for (auto block : blocks) {
|
|
||||||
if (block->get_desc() == "Conv2d") {
|
|
||||||
auto conv_block = (Conv2d*)block;
|
|
||||||
conv_block->enable_direct();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string get_desc() override {
|
std::string get_desc() override {
|
||||||
return "taesd";
|
return "taesd";
|
||||||
}
|
}
|
||||||
@ -252,7 +241,8 @@ struct TinyAutoEncoder : public GGMLRunner {
|
|||||||
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
|
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
|
||||||
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
|
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
|
||||||
z = to_backend(z);
|
z = to_backend(z);
|
||||||
struct ggml_tensor* out = decode_graph ? taesd.decode(compute_ctx, z) : taesd.encode(compute_ctx, z);
|
auto runner_ctx = get_context();
|
||||||
|
struct ggml_tensor* out = decode_graph ? taesd.decode(&runner_ctx, z) : taesd.encode(&runner_ctx, z);
|
||||||
ggml_build_forward_expand(gf, out);
|
ggml_build_forward_expand(gf, out);
|
||||||
return gf;
|
return gf;
|
||||||
}
|
}
|
||||||
|
|||||||
121
unet.hpp
121
unet.hpp
@ -60,8 +60,7 @@ public:
|
|||||||
blocks["time_mixer"] = std::shared_ptr<GGMLBlock>(new AlphaBlender());
|
blocks["time_mixer"] = std::shared_ptr<GGMLBlock>(new AlphaBlender());
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* context,
|
struct ggml_tensor* context,
|
||||||
int timesteps) {
|
int timesteps) {
|
||||||
@ -92,7 +91,7 @@ public:
|
|||||||
auto time_context = context; // [b*t, n_context, context_dim]
|
auto time_context = context; // [b*t, n_context, context_dim]
|
||||||
auto spatial_context = context;
|
auto spatial_context = context;
|
||||||
// time_context_first_timestep = time_context[::timesteps]
|
// time_context_first_timestep = time_context[::timesteps]
|
||||||
auto time_context_first_timestep = ggml_view_3d(ctx,
|
auto time_context_first_timestep = ggml_view_3d(ctx->ggml_ctx,
|
||||||
time_context,
|
time_context,
|
||||||
time_context->ne[0],
|
time_context->ne[0],
|
||||||
time_context->ne[1],
|
time_context->ne[1],
|
||||||
@ -100,26 +99,26 @@ public:
|
|||||||
time_context->nb[1],
|
time_context->nb[1],
|
||||||
time_context->nb[2],
|
time_context->nb[2],
|
||||||
0); // [b, n_context, context_dim]
|
0); // [b, n_context, context_dim]
|
||||||
time_context = ggml_new_tensor_3d(ctx, GGML_TYPE_F32,
|
time_context = ggml_new_tensor_3d(ctx->ggml_ctx, GGML_TYPE_F32,
|
||||||
time_context_first_timestep->ne[0],
|
time_context_first_timestep->ne[0],
|
||||||
time_context_first_timestep->ne[1],
|
time_context_first_timestep->ne[1],
|
||||||
time_context_first_timestep->ne[2] * h * w);
|
time_context_first_timestep->ne[2] * h * w);
|
||||||
time_context = ggml_repeat(ctx, time_context_first_timestep, time_context); // [b*h*w, n_context, context_dim]
|
time_context = ggml_repeat(ctx->ggml_ctx, time_context_first_timestep, time_context); // [b*h*w, n_context, context_dim]
|
||||||
|
|
||||||
x = norm->forward(ctx, x);
|
x = norm->forward(ctx, x);
|
||||||
x = proj_in->forward(ctx, x); // [N, inner_dim, h, w]
|
x = proj_in->forward(ctx, x); // [N, inner_dim, h, w]
|
||||||
|
|
||||||
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 2, 0, 3)); // [N, h, w, inner_dim]
|
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 2, 0, 3)); // [N, h, w, inner_dim]
|
||||||
x = ggml_reshape_3d(ctx, x, inner_dim, w * h, n); // [N, h * w, inner_dim]
|
x = ggml_reshape_3d(ctx->ggml_ctx, x, inner_dim, w * h, n); // [N, h * w, inner_dim]
|
||||||
|
|
||||||
auto num_frames = ggml_arange(ctx, 0, timesteps, 1);
|
auto num_frames = ggml_arange(ctx->ggml_ctx, 0, timesteps, 1);
|
||||||
// since b is 1, no need to do repeat
|
// since b is 1, no need to do repeat
|
||||||
auto t_emb = ggml_ext_timestep_embedding(ctx, num_frames, in_channels, max_time_embed_period); // [N, in_channels]
|
auto t_emb = ggml_ext_timestep_embedding(ctx->ggml_ctx, num_frames, in_channels, max_time_embed_period); // [N, in_channels]
|
||||||
|
|
||||||
auto emb = time_pos_embed_0->forward(ctx, t_emb);
|
auto emb = time_pos_embed_0->forward(ctx, t_emb);
|
||||||
emb = ggml_silu_inplace(ctx, emb);
|
emb = ggml_silu_inplace(ctx->ggml_ctx, emb);
|
||||||
emb = time_pos_embed_2->forward(ctx, emb); // [N, in_channels]
|
emb = time_pos_embed_2->forward(ctx, emb); // [N, in_channels]
|
||||||
emb = ggml_reshape_3d(ctx, emb, emb->ne[0], 1, emb->ne[1]); // [N, 1, in_channels]
|
emb = ggml_reshape_3d(ctx->ggml_ctx, emb, emb->ne[0], 1, emb->ne[1]); // [N, 1, in_channels]
|
||||||
|
|
||||||
for (int i = 0; i < depth; i++) {
|
for (int i = 0; i < depth; i++) {
|
||||||
std::string transformer_name = "transformer_blocks." + std::to_string(i);
|
std::string transformer_name = "transformer_blocks." + std::to_string(i);
|
||||||
@ -128,11 +127,11 @@ public:
|
|||||||
auto block = std::dynamic_pointer_cast<BasicTransformerBlock>(blocks[transformer_name]);
|
auto block = std::dynamic_pointer_cast<BasicTransformerBlock>(blocks[transformer_name]);
|
||||||
auto mix_block = std::dynamic_pointer_cast<BasicTransformerBlock>(blocks[time_stack_name]);
|
auto mix_block = std::dynamic_pointer_cast<BasicTransformerBlock>(blocks[time_stack_name]);
|
||||||
|
|
||||||
x = block->forward(ctx, backend, x, spatial_context); // [N, h * w, inner_dim]
|
x = block->forward(ctx, x, spatial_context); // [N, h * w, inner_dim]
|
||||||
|
|
||||||
// in_channels == inner_dim
|
// in_channels == inner_dim
|
||||||
auto x_mix = x;
|
auto x_mix = x;
|
||||||
x_mix = ggml_add(ctx, x_mix, emb); // [N, h * w, inner_dim]
|
x_mix = ggml_add(ctx->ggml_ctx, x_mix, emb); // [N, h * w, inner_dim]
|
||||||
|
|
||||||
int64_t N = x_mix->ne[2];
|
int64_t N = x_mix->ne[2];
|
||||||
int64_t T = timesteps;
|
int64_t T = timesteps;
|
||||||
@ -140,26 +139,26 @@ public:
|
|||||||
int64_t S = x_mix->ne[1];
|
int64_t S = x_mix->ne[1];
|
||||||
int64_t C = x_mix->ne[0];
|
int64_t C = x_mix->ne[0];
|
||||||
|
|
||||||
x_mix = ggml_reshape_4d(ctx, x_mix, C, S, T, B); // (b t) s c -> b t s c
|
x_mix = ggml_reshape_4d(ctx->ggml_ctx, x_mix, C, S, T, B); // (b t) s c -> b t s c
|
||||||
x_mix = ggml_cont(ctx, ggml_permute(ctx, x_mix, 0, 2, 1, 3)); // b t s c -> b s t c
|
x_mix = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x_mix, 0, 2, 1, 3)); // b t s c -> b s t c
|
||||||
x_mix = ggml_reshape_3d(ctx, x_mix, C, T, S * B); // b s t c -> (b s) t c
|
x_mix = ggml_reshape_3d(ctx->ggml_ctx, x_mix, C, T, S * B); // b s t c -> (b s) t c
|
||||||
|
|
||||||
x_mix = mix_block->forward(ctx, backend, x_mix, time_context); // [B * h * w, T, inner_dim]
|
x_mix = mix_block->forward(ctx, x_mix, time_context); // [B * h * w, T, inner_dim]
|
||||||
|
|
||||||
x_mix = ggml_reshape_4d(ctx, x_mix, C, T, S, B); // (b s) t c -> b s t c
|
x_mix = ggml_reshape_4d(ctx->ggml_ctx, x_mix, C, T, S, B); // (b s) t c -> b s t c
|
||||||
x_mix = ggml_cont(ctx, ggml_permute(ctx, x_mix, 0, 2, 1, 3)); // b s t c -> b t s c
|
x_mix = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x_mix, 0, 2, 1, 3)); // b s t c -> b t s c
|
||||||
x_mix = ggml_reshape_3d(ctx, x_mix, C, S, T * B); // b t s c -> (b t) s c
|
x_mix = ggml_reshape_3d(ctx->ggml_ctx, x_mix, C, S, T * B); // b t s c -> (b t) s c
|
||||||
|
|
||||||
x = time_mixer->forward(ctx, x, x_mix); // [N, h * w, inner_dim]
|
x = time_mixer->forward(ctx, x, x_mix); // [N, h * w, inner_dim]
|
||||||
}
|
}
|
||||||
|
|
||||||
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, inner_dim, h * w]
|
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [N, inner_dim, h * w]
|
||||||
x = ggml_reshape_4d(ctx, x, w, h, inner_dim, n); // [N, inner_dim, h, w]
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, inner_dim, n); // [N, inner_dim, h, w]
|
||||||
|
|
||||||
// proj_out
|
// proj_out
|
||||||
x = proj_out->forward(ctx, x); // [N, in_channels, h, w]
|
x = proj_out->forward(ctx, x); // [N, in_channels, h, w]
|
||||||
|
|
||||||
x = ggml_add(ctx, x, x_in);
|
x = ggml_add(ctx->ggml_ctx, x, x_in);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -184,7 +183,7 @@ public:
|
|||||||
int model_channels = 320;
|
int model_channels = 320;
|
||||||
int adm_in_channels = 2816; // only for VERSION_SDXL/SVD
|
int adm_in_channels = 2816; // only for VERSION_SDXL/SVD
|
||||||
|
|
||||||
UnetModelBlock(SDVersion version = VERSION_SD1, const String2GGMLType& tensor_types = {}, bool flash_attn = false)
|
UnetModelBlock(SDVersion version = VERSION_SD1, const String2GGMLType& tensor_types = {})
|
||||||
: version(version) {
|
: version(version) {
|
||||||
if (sd_version_is_sd2(version)) {
|
if (sd_version_is_sd2(version)) {
|
||||||
context_dim = 1024;
|
context_dim = 1024;
|
||||||
@ -252,7 +251,7 @@ public:
|
|||||||
if (version == VERSION_SVD) {
|
if (version == VERSION_SVD) {
|
||||||
return new SpatialVideoTransformer(in_channels, n_head, d_head, depth, context_dim);
|
return new SpatialVideoTransformer(in_channels, n_head, d_head, depth, context_dim);
|
||||||
} else {
|
} else {
|
||||||
return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim, flash_attn);
|
return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -377,7 +376,7 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* resblock_forward(std::string name,
|
struct ggml_tensor* resblock_forward(std::string name,
|
||||||
struct ggml_context* ctx,
|
GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* emb,
|
struct ggml_tensor* emb,
|
||||||
int num_video_frames) {
|
int num_video_frames) {
|
||||||
@ -393,24 +392,22 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* attention_layer_forward(std::string name,
|
struct ggml_tensor* attention_layer_forward(std::string name,
|
||||||
struct ggml_context* ctx,
|
GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* context,
|
struct ggml_tensor* context,
|
||||||
int timesteps) {
|
int timesteps) {
|
||||||
if (version == VERSION_SVD) {
|
if (version == VERSION_SVD) {
|
||||||
auto block = std::dynamic_pointer_cast<SpatialVideoTransformer>(blocks[name]);
|
auto block = std::dynamic_pointer_cast<SpatialVideoTransformer>(blocks[name]);
|
||||||
|
|
||||||
return block->forward(ctx, backend, x, context, timesteps);
|
return block->forward(ctx, x, context, timesteps);
|
||||||
} else {
|
} else {
|
||||||
auto block = std::dynamic_pointer_cast<SpatialTransformer>(blocks[name]);
|
auto block = std::dynamic_pointer_cast<SpatialTransformer>(blocks[name]);
|
||||||
|
|
||||||
return block->forward(ctx, backend, x, context);
|
return block->forward(ctx, x, context);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* timesteps,
|
struct ggml_tensor* timesteps,
|
||||||
struct ggml_tensor* context,
|
struct ggml_tensor* context,
|
||||||
@ -427,20 +424,20 @@ public:
|
|||||||
// return: [N, out_channels, h, w]
|
// return: [N, out_channels, h, w]
|
||||||
if (context != nullptr) {
|
if (context != nullptr) {
|
||||||
if (context->ne[2] != x->ne[3]) {
|
if (context->ne[2] != x->ne[3]) {
|
||||||
context = ggml_repeat(ctx, context, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, context->ne[0], context->ne[1], x->ne[3]));
|
context = ggml_repeat(ctx->ggml_ctx, context, ggml_new_tensor_3d(ctx->ggml_ctx, GGML_TYPE_F32, context->ne[0], context->ne[1], x->ne[3]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (c_concat != nullptr) {
|
if (c_concat != nullptr) {
|
||||||
if (c_concat->ne[3] != x->ne[3]) {
|
if (c_concat->ne[3] != x->ne[3]) {
|
||||||
c_concat = ggml_repeat(ctx, c_concat, x);
|
c_concat = ggml_repeat(ctx->ggml_ctx, c_concat, x);
|
||||||
}
|
}
|
||||||
x = ggml_concat(ctx, x, c_concat, 2);
|
x = ggml_concat(ctx->ggml_ctx, x, c_concat, 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (y != nullptr) {
|
if (y != nullptr) {
|
||||||
if (y->ne[1] != x->ne[3]) {
|
if (y->ne[1] != x->ne[3]) {
|
||||||
y = ggml_repeat(ctx, y, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, y->ne[0], x->ne[3]));
|
y = ggml_repeat(ctx->ggml_ctx, y, ggml_new_tensor_2d(ctx->ggml_ctx, GGML_TYPE_F32, y->ne[0], x->ne[3]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -451,10 +448,10 @@ public:
|
|||||||
auto out_0 = std::dynamic_pointer_cast<GroupNorm32>(blocks["out.0"]);
|
auto out_0 = std::dynamic_pointer_cast<GroupNorm32>(blocks["out.0"]);
|
||||||
auto out_2 = std::dynamic_pointer_cast<Conv2d>(blocks["out.2"]);
|
auto out_2 = std::dynamic_pointer_cast<Conv2d>(blocks["out.2"]);
|
||||||
|
|
||||||
auto t_emb = ggml_ext_timestep_embedding(ctx, timesteps, model_channels); // [N, model_channels]
|
auto t_emb = ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, model_channels); // [N, model_channels]
|
||||||
|
|
||||||
auto emb = time_embed_0->forward(ctx, t_emb);
|
auto emb = time_embed_0->forward(ctx, t_emb);
|
||||||
emb = ggml_silu_inplace(ctx, emb);
|
emb = ggml_silu_inplace(ctx->ggml_ctx, emb);
|
||||||
emb = time_embed_2->forward(ctx, emb); // [N, time_embed_dim]
|
emb = time_embed_2->forward(ctx, emb); // [N, time_embed_dim]
|
||||||
|
|
||||||
// SDXL/SVD
|
// SDXL/SVD
|
||||||
@ -463,10 +460,10 @@ public:
|
|||||||
auto label_embed_2 = std::dynamic_pointer_cast<Linear>(blocks["label_emb.0.2"]);
|
auto label_embed_2 = std::dynamic_pointer_cast<Linear>(blocks["label_emb.0.2"]);
|
||||||
|
|
||||||
auto label_emb = label_embed_0->forward(ctx, y);
|
auto label_emb = label_embed_0->forward(ctx, y);
|
||||||
label_emb = ggml_silu_inplace(ctx, label_emb);
|
label_emb = ggml_silu_inplace(ctx->ggml_ctx, label_emb);
|
||||||
label_emb = label_embed_2->forward(ctx, label_emb); // [N, time_embed_dim]
|
label_emb = label_embed_2->forward(ctx, label_emb); // [N, time_embed_dim]
|
||||||
|
|
||||||
emb = ggml_add(ctx, emb, label_emb); // [N, time_embed_dim]
|
emb = ggml_add(ctx->ggml_ctx, emb, label_emb); // [N, time_embed_dim]
|
||||||
}
|
}
|
||||||
|
|
||||||
// input_blocks
|
// input_blocks
|
||||||
@ -489,7 +486,7 @@ public:
|
|||||||
h = resblock_forward(name, ctx, h, emb, num_video_frames); // [N, mult*model_channels, h, w]
|
h = resblock_forward(name, ctx, h, emb, num_video_frames); // [N, mult*model_channels, h, w]
|
||||||
if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
|
if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
|
||||||
std::string name = "input_blocks." + std::to_string(input_block_idx) + ".1";
|
std::string name = "input_blocks." + std::to_string(input_block_idx) + ".1";
|
||||||
h = attention_layer_forward(name, ctx, backend, h, context, num_video_frames); // [N, mult*model_channels, h, w]
|
h = attention_layer_forward(name, ctx, h, context, num_video_frames); // [N, mult*model_channels, h, w]
|
||||||
}
|
}
|
||||||
hs.push_back(h);
|
hs.push_back(h);
|
||||||
}
|
}
|
||||||
@ -513,13 +510,13 @@ public:
|
|||||||
if (version != VERSION_SD1_TINY_UNET) {
|
if (version != VERSION_SD1_TINY_UNET) {
|
||||||
h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
||||||
if (version != VERSION_SDXL_SSD1B) {
|
if (version != VERSION_SDXL_SSD1B) {
|
||||||
h = attention_layer_forward("middle_block.1", ctx, backend, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
h = attention_layer_forward("middle_block.1", ctx, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
||||||
h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (controls.size() > 0) {
|
if (controls.size() > 0) {
|
||||||
auto cs = ggml_scale_inplace(ctx, controls[controls.size() - 1], control_strength);
|
auto cs = ggml_scale_inplace(ctx->ggml_ctx, controls[controls.size() - 1], control_strength);
|
||||||
h = ggml_add(ctx, h, cs); // middle control
|
h = ggml_add(ctx->ggml_ctx, h, cs); // middle control
|
||||||
}
|
}
|
||||||
int control_offset = controls.size() - 2;
|
int control_offset = controls.size() - 2;
|
||||||
|
|
||||||
@ -531,12 +528,12 @@ public:
|
|||||||
hs.pop_back();
|
hs.pop_back();
|
||||||
|
|
||||||
if (controls.size() > 0) {
|
if (controls.size() > 0) {
|
||||||
auto cs = ggml_scale_inplace(ctx, controls[control_offset], control_strength);
|
auto cs = ggml_scale_inplace(ctx->ggml_ctx, controls[control_offset], control_strength);
|
||||||
h_skip = ggml_add(ctx, h_skip, cs); // control net condition
|
h_skip = ggml_add(ctx->ggml_ctx, h_skip, cs); // control net condition
|
||||||
control_offset--;
|
control_offset--;
|
||||||
}
|
}
|
||||||
|
|
||||||
h = ggml_concat(ctx, h, h_skip, 2);
|
h = ggml_concat(ctx->ggml_ctx, h, h_skip, 2);
|
||||||
|
|
||||||
std::string name = "output_blocks." + std::to_string(output_block_idx) + ".0";
|
std::string name = "output_blocks." + std::to_string(output_block_idx) + ".0";
|
||||||
|
|
||||||
@ -546,7 +543,7 @@ public:
|
|||||||
if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
|
if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
|
||||||
std::string name = "output_blocks." + std::to_string(output_block_idx) + ".1";
|
std::string name = "output_blocks." + std::to_string(output_block_idx) + ".1";
|
||||||
|
|
||||||
h = attention_layer_forward(name, ctx, backend, h, context, num_video_frames);
|
h = attention_layer_forward(name, ctx, h, context, num_video_frames);
|
||||||
|
|
||||||
up_sample_idx++;
|
up_sample_idx++;
|
||||||
}
|
}
|
||||||
@ -572,7 +569,7 @@ public:
|
|||||||
|
|
||||||
// out
|
// out
|
||||||
h = out_0->forward(ctx, h);
|
h = out_0->forward(ctx, h);
|
||||||
h = ggml_silu_inplace(ctx, h);
|
h = ggml_silu_inplace(ctx->ggml_ctx, h);
|
||||||
h = out_2->forward(ctx, h);
|
h = out_2->forward(ctx, h);
|
||||||
ggml_set_name(h, "bench-end");
|
ggml_set_name(h, "bench-end");
|
||||||
return h; // [N, out_channels, h, w]
|
return h; // [N, out_channels, h, w]
|
||||||
@ -586,24 +583,11 @@ struct UNetModelRunner : public GGMLRunner {
|
|||||||
bool offload_params_to_cpu,
|
bool offload_params_to_cpu,
|
||||||
const String2GGMLType& tensor_types,
|
const String2GGMLType& tensor_types,
|
||||||
const std::string prefix,
|
const std::string prefix,
|
||||||
SDVersion version = VERSION_SD1,
|
SDVersion version = VERSION_SD1)
|
||||||
bool flash_attn = false)
|
: GGMLRunner(backend, offload_params_to_cpu), unet(version, tensor_types) {
|
||||||
: GGMLRunner(backend, offload_params_to_cpu), unet(version, tensor_types, flash_attn) {
|
|
||||||
unet.init(params_ctx, tensor_types, prefix);
|
unet.init(params_ctx, tensor_types, prefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
void enable_conv2d_direct() {
|
|
||||||
std::vector<GGMLBlock*> blocks;
|
|
||||||
unet.get_all_blocks(blocks);
|
|
||||||
for (auto block : blocks) {
|
|
||||||
if (block->get_desc() == "Conv2d") {
|
|
||||||
LOG_DEBUG("block %s", block->get_desc().c_str());
|
|
||||||
auto conv_block = (Conv2d*)block;
|
|
||||||
conv_block->enable_direct();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string get_desc() override {
|
std::string get_desc() override {
|
||||||
return "unet";
|
return "unet";
|
||||||
}
|
}
|
||||||
@ -636,8 +620,9 @@ struct UNetModelRunner : public GGMLRunner {
|
|||||||
controls[i] = to_backend(controls[i]);
|
controls[i] = to_backend(controls[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* out = unet.forward(compute_ctx,
|
auto runner_ctx = get_context();
|
||||||
runtime_backend,
|
|
||||||
|
struct ggml_tensor* out = unet.forward(&runner_ctx,
|
||||||
x,
|
x,
|
||||||
timesteps,
|
timesteps,
|
||||||
context,
|
context,
|
||||||
|
|||||||
@ -53,7 +53,7 @@ struct UpscalerGGML {
|
|||||||
LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type));
|
LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type));
|
||||||
esrgan_upscaler = std::make_shared<ESRGAN>(backend, offload_params_to_cpu, model_loader.tensor_storages_types);
|
esrgan_upscaler = std::make_shared<ESRGAN>(backend, offload_params_to_cpu, model_loader.tensor_storages_types);
|
||||||
if (direct) {
|
if (direct) {
|
||||||
esrgan_upscaler->enable_conv2d_direct();
|
esrgan_upscaler->set_conv2d_direct_enabled(true);
|
||||||
}
|
}
|
||||||
if (!esrgan_upscaler->load_from_file(esrgan_path, n_threads)) {
|
if (!esrgan_upscaler->load_from_file(esrgan_path, n_threads)) {
|
||||||
return false;
|
return false;
|
||||||
|
|||||||
96
vae.hpp
96
vae.hpp
@ -30,7 +30,7 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||||
// x: [N, in_channels, h, w]
|
// x: [N, in_channels, h, w]
|
||||||
// t_emb is always None
|
// t_emb is always None
|
||||||
auto norm1 = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm1"]);
|
auto norm1 = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm1"]);
|
||||||
@ -40,12 +40,12 @@ public:
|
|||||||
|
|
||||||
auto h = x;
|
auto h = x;
|
||||||
h = norm1->forward(ctx, h);
|
h = norm1->forward(ctx, h);
|
||||||
h = ggml_silu_inplace(ctx, h); // swish
|
h = ggml_silu_inplace(ctx->ggml_ctx, h); // swish
|
||||||
h = conv1->forward(ctx, h);
|
h = conv1->forward(ctx, h);
|
||||||
// return h;
|
// return h;
|
||||||
|
|
||||||
h = norm2->forward(ctx, h);
|
h = norm2->forward(ctx, h);
|
||||||
h = ggml_silu_inplace(ctx, h); // swish
|
h = ggml_silu_inplace(ctx->ggml_ctx, h); // swish
|
||||||
// dropout, skip for inference
|
// dropout, skip for inference
|
||||||
h = conv2->forward(ctx, h);
|
h = conv2->forward(ctx, h);
|
||||||
|
|
||||||
@ -56,7 +56,7 @@ public:
|
|||||||
x = nin_shortcut->forward(ctx, x); // [N, out_channels, h, w]
|
x = nin_shortcut->forward(ctx, x); // [N, out_channels, h, w]
|
||||||
}
|
}
|
||||||
|
|
||||||
h = ggml_add(ctx, h, x);
|
h = ggml_add(ctx->ggml_ctx, h, x);
|
||||||
return h; // [N, out_channels, h, w]
|
return h; // [N, out_channels, h, w]
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -76,7 +76,7 @@ public:
|
|||||||
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
|
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||||
// x: [N, in_channels, h, w]
|
// x: [N, in_channels, h, w]
|
||||||
auto norm = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm"]);
|
auto norm = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm"]);
|
||||||
auto q_proj = std::dynamic_pointer_cast<Conv2d>(blocks["q"]);
|
auto q_proj = std::dynamic_pointer_cast<Conv2d>(blocks["q"]);
|
||||||
@ -91,25 +91,25 @@ public:
|
|||||||
const int64_t h = h_->ne[1];
|
const int64_t h = h_->ne[1];
|
||||||
const int64_t w = h_->ne[0];
|
const int64_t w = h_->ne[0];
|
||||||
|
|
||||||
auto q = q_proj->forward(ctx, h_); // [N, in_channels, h, w]
|
auto q = q_proj->forward(ctx, h_); // [N, in_channels, h, w]
|
||||||
q = ggml_cont(ctx, ggml_permute(ctx, q, 1, 2, 0, 3)); // [N, h, w, in_channels]
|
q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 1, 2, 0, 3)); // [N, h, w, in_channels]
|
||||||
q = ggml_reshape_3d(ctx, q, c, h * w, n); // [N, h * w, in_channels]
|
q = ggml_reshape_3d(ctx->ggml_ctx, q, c, h * w, n); // [N, h * w, in_channels]
|
||||||
|
|
||||||
auto k = k_proj->forward(ctx, h_); // [N, in_channels, h, w]
|
auto k = k_proj->forward(ctx, h_); // [N, in_channels, h, w]
|
||||||
k = ggml_cont(ctx, ggml_permute(ctx, k, 1, 2, 0, 3)); // [N, h, w, in_channels]
|
k = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, k, 1, 2, 0, 3)); // [N, h, w, in_channels]
|
||||||
k = ggml_reshape_3d(ctx, k, c, h * w, n); // [N, h * w, in_channels]
|
k = ggml_reshape_3d(ctx->ggml_ctx, k, c, h * w, n); // [N, h * w, in_channels]
|
||||||
|
|
||||||
auto v = v_proj->forward(ctx, h_); // [N, in_channels, h, w]
|
auto v = v_proj->forward(ctx, h_); // [N, in_channels, h, w]
|
||||||
v = ggml_reshape_3d(ctx, v, h * w, c, n); // [N, in_channels, h * w]
|
v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [N, in_channels, h * w]
|
||||||
|
|
||||||
h_ = ggml_ext_attention(ctx, q, k, v, false); // [N, h * w, in_channels]
|
h_ = ggml_ext_attention(ctx->ggml_ctx, q, k, v, false); // [N, h * w, in_channels]
|
||||||
|
|
||||||
h_ = ggml_cont(ctx, ggml_permute(ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w]
|
h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w]
|
||||||
h_ = ggml_reshape_4d(ctx, h_, w, h, c, n); // [N, in_channels, h, w]
|
h_ = ggml_reshape_4d(ctx->ggml_ctx, h_, w, h, c, n); // [N, in_channels, h, w]
|
||||||
|
|
||||||
h_ = proj_out->forward(ctx, h_); // [N, in_channels, h, w]
|
h_ = proj_out->forward(ctx, h_); // [N, in_channels, h, w]
|
||||||
|
|
||||||
h_ = ggml_add(ctx, h_, x);
|
h_ = ggml_add(ctx->ggml_ctx, h_, x);
|
||||||
return h_;
|
return h_;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -133,7 +133,7 @@ public:
|
|||||||
kernel_padding));
|
kernel_padding));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x) override {
|
struct ggml_tensor* x) override {
|
||||||
// timesteps always None
|
// timesteps always None
|
||||||
// skip_video always False
|
// skip_video always False
|
||||||
@ -152,12 +152,12 @@ public:
|
|||||||
int64_t H = x->ne[1];
|
int64_t H = x->ne[1];
|
||||||
int64_t W = x->ne[0];
|
int64_t W = x->ne[0];
|
||||||
|
|
||||||
x = ggml_reshape_4d(ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w)
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w)
|
||||||
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
|
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
|
||||||
x = time_mix_conv->forward(ctx, x); // [B, OC, T, OH * OW]
|
x = time_mix_conv->forward(ctx, x); // [B, OC, T, OH * OW]
|
||||||
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
|
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
|
||||||
x = ggml_reshape_4d(ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
|
||||||
return x; // [B*T, OC, OH, OW]
|
return x; // [B*T, OC, OH, OW]
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -182,7 +182,7 @@ public:
|
|||||||
blocks["time_stack"] = std::shared_ptr<GGMLBlock>(new ResBlock(out_channels, 0, out_channels, {video_kernel_size, 1}, 3, false, true));
|
blocks["time_stack"] = std::shared_ptr<GGMLBlock>(new ResBlock(out_channels, 0, out_channels, {video_kernel_size, 1}, 3, false, true));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||||
// x: [N, in_channels, h, w] aka [b*t, in_channels, h, w]
|
// x: [N, in_channels, h, w] aka [b*t, in_channels, h, w]
|
||||||
// return: [N, out_channels, h, w] aka [b*t, out_channels, h, w]
|
// return: [N, out_channels, h, w] aka [b*t, out_channels, h, w]
|
||||||
// t_emb is always None
|
// t_emb is always None
|
||||||
@ -199,19 +199,19 @@ public:
|
|||||||
int64_t H = x->ne[1];
|
int64_t H = x->ne[1];
|
||||||
int64_t W = x->ne[0];
|
int64_t W = x->ne[0];
|
||||||
|
|
||||||
x = ggml_reshape_4d(ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w)
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w)
|
||||||
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
|
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
|
||||||
auto x_mix = x;
|
auto x_mix = x;
|
||||||
|
|
||||||
x = time_stack->forward(ctx, x); // b t c (h w)
|
x = time_stack->forward(ctx, x); // b t c (h w)
|
||||||
|
|
||||||
float alpha = get_alpha();
|
float alpha = get_alpha();
|
||||||
x = ggml_add(ctx,
|
x = ggml_add(ctx->ggml_ctx,
|
||||||
ggml_scale(ctx, x, alpha),
|
ggml_scale(ctx->ggml_ctx, x, alpha),
|
||||||
ggml_scale(ctx, x_mix, 1.0f - alpha));
|
ggml_scale(ctx->ggml_ctx, x_mix, 1.0f - alpha));
|
||||||
|
|
||||||
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
|
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
|
||||||
x = ggml_reshape_4d(ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
|
||||||
|
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
@ -271,7 +271,7 @@ public:
|
|||||||
blocks["conv_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(block_in, double_z ? z_channels * 2 : z_channels, {3, 3}, {1, 1}, {1, 1}));
|
blocks["conv_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(block_in, double_z ? z_channels * 2 : z_channels, {3, 3}, {1, 1}, {1, 1}));
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
// x: [N, in_channels, h, w]
|
// x: [N, in_channels, h, w]
|
||||||
|
|
||||||
auto conv_in = std::dynamic_pointer_cast<Conv2d>(blocks["conv_in"]);
|
auto conv_in = std::dynamic_pointer_cast<Conv2d>(blocks["conv_in"]);
|
||||||
@ -307,8 +307,8 @@ public:
|
|||||||
|
|
||||||
// end
|
// end
|
||||||
h = norm_out->forward(ctx, h);
|
h = norm_out->forward(ctx, h);
|
||||||
h = ggml_silu_inplace(ctx, h); // nonlinearity/swish
|
h = ggml_silu_inplace(ctx->ggml_ctx, h); // nonlinearity/swish
|
||||||
h = conv_out->forward(ctx, h); // [N, z_channels*2, h, w]
|
h = conv_out->forward(ctx, h); // [N, z_channels*2, h, w]
|
||||||
return h;
|
return h;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -388,7 +388,7 @@ public:
|
|||||||
blocks["conv_out"] = get_conv_out(block_in, out_ch, {3, 3}, {1, 1}, {1, 1});
|
blocks["conv_out"] = get_conv_out(block_in, out_ch, {3, 3}, {1, 1}, {1, 1});
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* z) {
|
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
|
||||||
// z: [N, z_channels, h, w]
|
// z: [N, z_channels, h, w]
|
||||||
// alpha is always 0
|
// alpha is always 0
|
||||||
// merge_strategy is always learned
|
// merge_strategy is always learned
|
||||||
@ -429,8 +429,8 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
h = norm_out->forward(ctx, h);
|
h = norm_out->forward(ctx, h);
|
||||||
h = ggml_silu_inplace(ctx, h); // nonlinearity/swish
|
h = ggml_silu_inplace(ctx->ggml_ctx, h); // nonlinearity/swish
|
||||||
h = conv_out->forward(ctx, h); // [N, out_ch, h*8, w*8]
|
h = conv_out->forward(ctx, h); // [N, out_ch, h*8, w*8]
|
||||||
return h;
|
return h;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -493,7 +493,7 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* decode(struct ggml_context* ctx, struct ggml_tensor* z) {
|
struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
|
||||||
// z: [N, z_channels, h, w]
|
// z: [N, z_channels, h, w]
|
||||||
if (use_quant) {
|
if (use_quant) {
|
||||||
auto post_quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["post_quant_conv"]);
|
auto post_quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["post_quant_conv"]);
|
||||||
@ -507,7 +507,7 @@ public:
|
|||||||
return h;
|
return h;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* encode(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
// x: [N, in_channels, h, w]
|
// x: [N, in_channels, h, w]
|
||||||
auto encoder = std::dynamic_pointer_cast<Encoder>(blocks["encoder"]);
|
auto encoder = std::dynamic_pointer_cast<Encoder>(blocks["encoder"]);
|
||||||
|
|
||||||
@ -529,7 +529,6 @@ struct VAE : public GGMLRunner {
|
|||||||
struct ggml_tensor** output,
|
struct ggml_tensor** output,
|
||||||
struct ggml_context* output_ctx) = 0;
|
struct ggml_context* output_ctx) = 0;
|
||||||
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) = 0;
|
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) = 0;
|
||||||
virtual void enable_conv2d_direct(){};
|
|
||||||
virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); };
|
virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); };
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -572,17 +571,6 @@ struct AutoEncoderKL : public VAE {
|
|||||||
ae.init(params_ctx, tensor_types, prefix);
|
ae.init(params_ctx, tensor_types, prefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
void enable_conv2d_direct() override {
|
|
||||||
std::vector<GGMLBlock*> blocks;
|
|
||||||
ae.get_all_blocks(blocks);
|
|
||||||
for (auto block : blocks) {
|
|
||||||
if (block->get_desc() == "Conv2d") {
|
|
||||||
auto conv_block = (Conv2d*)block;
|
|
||||||
conv_block->enable_direct();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_conv2d_scale(float scale) override {
|
void set_conv2d_scale(float scale) override {
|
||||||
std::vector<GGMLBlock*> blocks;
|
std::vector<GGMLBlock*> blocks;
|
||||||
ae.get_all_blocks(blocks);
|
ae.get_all_blocks(blocks);
|
||||||
@ -607,7 +595,9 @@ struct AutoEncoderKL : public VAE {
|
|||||||
|
|
||||||
z = to_backend(z);
|
z = to_backend(z);
|
||||||
|
|
||||||
struct ggml_tensor* out = decode_graph ? ae.decode(compute_ctx, z) : ae.encode(compute_ctx, z);
|
auto runner_ctx = get_context();
|
||||||
|
|
||||||
|
struct ggml_tensor* out = decode_graph ? ae.decode(&runner_ctx, z) : ae.encode(&runner_ctx, z);
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, out);
|
ggml_build_forward_expand(gf, out);
|
||||||
|
|
||||||
|
|||||||
424
wan.hpp
424
wan.hpp
@ -54,7 +54,7 @@ namespace WAN {
|
|||||||
dilation(std::move(dilation)),
|
dilation(std::move(dilation)),
|
||||||
bias(bias) {}
|
bias(bias) {}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* cache_x = nullptr) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* cache_x = nullptr) {
|
||||||
// x: [N*IC, ID, IH, IW]
|
// x: [N*IC, ID, IH, IW]
|
||||||
// result: x: [N*OC, ID, IH, IW]
|
// result: x: [N*OC, ID, IH, IW]
|
||||||
struct ggml_tensor* w = params["weight"];
|
struct ggml_tensor* w = params["weight"];
|
||||||
@ -71,12 +71,12 @@ namespace WAN {
|
|||||||
int rp2 = 0;
|
int rp2 = 0;
|
||||||
|
|
||||||
if (cache_x != nullptr && lp2 > 0) {
|
if (cache_x != nullptr && lp2 > 0) {
|
||||||
x = ggml_concat(ctx, cache_x, x, 2);
|
x = ggml_concat(ctx->ggml_ctx, cache_x, x, 2);
|
||||||
lp2 -= (int)cache_x->ne[2];
|
lp2 -= (int)cache_x->ne[2];
|
||||||
}
|
}
|
||||||
|
|
||||||
x = ggml_pad_ext(ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, 0, 0);
|
x = ggml_pad_ext(ctx->ggml_ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, 0, 0);
|
||||||
return ggml_ext_conv_3d(ctx, x, w, b, in_channels,
|
return ggml_ext_conv_3d(ctx->ggml_ctx, x, w, b, in_channels,
|
||||||
std::get<2>(stride), std::get<1>(stride), std::get<0>(stride),
|
std::get<2>(stride), std::get<1>(stride), std::get<0>(stride),
|
||||||
0, 0, 0,
|
0, 0, 0,
|
||||||
std::get<2>(dilation), std::get<1>(dilation), std::get<0>(dilation));
|
std::get<2>(dilation), std::get<1>(dilation), std::get<0>(dilation));
|
||||||
@ -96,15 +96,15 @@ namespace WAN {
|
|||||||
RMS_norm(int64_t dim)
|
RMS_norm(int64_t dim)
|
||||||
: dim(dim) {}
|
: dim(dim) {}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||||
// x: [N*IC, ID, IH, IW], IC == dim
|
// x: [N*IC, ID, IH, IW], IC == dim
|
||||||
// assert N == 1
|
// assert N == 1
|
||||||
|
|
||||||
struct ggml_tensor* w = params["gamma"];
|
struct ggml_tensor* w = params["gamma"];
|
||||||
auto h = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 3, 0, 1, 2)); // [ID, IH, IW, N*IC]
|
auto h = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 3, 0, 1, 2)); // [ID, IH, IW, N*IC]
|
||||||
h = ggml_rms_norm(ctx, h, 1e-12);
|
h = ggml_rms_norm(ctx->ggml_ctx, h, 1e-12);
|
||||||
h = ggml_mul(ctx, h, w);
|
h = ggml_mul(ctx->ggml_ctx, h, w);
|
||||||
h = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, h, 1, 2, 3, 0));
|
h = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, h, 1, 2, 3, 0));
|
||||||
|
|
||||||
return h;
|
return h;
|
||||||
}
|
}
|
||||||
@ -143,7 +143,7 @@ namespace WAN {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
int64_t b,
|
int64_t b,
|
||||||
std::vector<struct ggml_tensor*>& feat_cache,
|
std::vector<struct ggml_tensor*>& feat_cache,
|
||||||
@ -165,16 +165,16 @@ namespace WAN {
|
|||||||
} else {
|
} else {
|
||||||
auto time_conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["time_conv"]);
|
auto time_conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["time_conv"]);
|
||||||
|
|
||||||
auto cache_x = ggml_ext_slice(ctx, x, 2, -CACHE_T, x->ne[2]);
|
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
|
||||||
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) { // chunk_idx >= 2
|
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) { // chunk_idx >= 2
|
||||||
// cache last frame of last two chunk
|
// cache last frame of last two chunk
|
||||||
cache_x = ggml_concat(ctx,
|
cache_x = ggml_concat(ctx->ggml_ctx,
|
||||||
ggml_ext_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
|
ggml_ext_slice(ctx->ggml_ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
|
||||||
cache_x,
|
cache_x,
|
||||||
2);
|
2);
|
||||||
}
|
}
|
||||||
if (chunk_idx == 1 && cache_x->ne[2] < 2) { // Rep
|
if (chunk_idx == 1 && cache_x->ne[2] < 2) { // Rep
|
||||||
cache_x = ggml_pad_ext(ctx, cache_x, 0, 0, 0, 0, (int)cache_x->ne[2], 0, 0, 0);
|
cache_x = ggml_pad_ext(ctx->ggml_ctx, cache_x, 0, 0, 0, 0, (int)cache_x->ne[2], 0, 0, 0);
|
||||||
// aka cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device),cache_x],dim=2)
|
// aka cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device),cache_x],dim=2)
|
||||||
}
|
}
|
||||||
if (chunk_idx == 1) {
|
if (chunk_idx == 1) {
|
||||||
@ -183,9 +183,9 @@ namespace WAN {
|
|||||||
x = time_conv->forward(ctx, x, feat_cache[idx]);
|
x = time_conv->forward(ctx, x, feat_cache[idx]);
|
||||||
}
|
}
|
||||||
feat_cache[idx] = cache_x;
|
feat_cache[idx] = cache_x;
|
||||||
x = ggml_reshape_4d(ctx, x, w * h, t, c, 2); // (2, c, t, h*w)
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, w * h, t, c, 2); // (2, c, t, h*w)
|
||||||
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 3, 1, 2)); // (c, t, 2, h*w)
|
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 3, 1, 2)); // (c, t, 2, h*w)
|
||||||
x = ggml_reshape_4d(ctx, x, w, h, 2 * t, c); // (c, t*2, h, w)
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, 2 * t, c); // (c, t*2, h, w)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -194,18 +194,18 @@ namespace WAN {
|
|||||||
if (mode != "none") {
|
if (mode != "none") {
|
||||||
auto resample_1 = std::dynamic_pointer_cast<Conv2d>(blocks["resample.1"]);
|
auto resample_1 = std::dynamic_pointer_cast<Conv2d>(blocks["resample.1"]);
|
||||||
|
|
||||||
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 1, 3, 2)); // (t, c, h, w)
|
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // (t, c, h, w)
|
||||||
if (mode == "upsample2d") {
|
if (mode == "upsample2d") {
|
||||||
x = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST);
|
x = ggml_upscale(ctx->ggml_ctx, x, 2, GGML_SCALE_MODE_NEAREST);
|
||||||
} else if (mode == "upsample3d") {
|
} else if (mode == "upsample3d") {
|
||||||
x = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST);
|
x = ggml_upscale(ctx->ggml_ctx, x, 2, GGML_SCALE_MODE_NEAREST);
|
||||||
} else if (mode == "downsample2d") {
|
} else if (mode == "downsample2d") {
|
||||||
x = ggml_pad(ctx, x, 1, 1, 0, 0);
|
x = ggml_pad(ctx->ggml_ctx, x, 1, 1, 0, 0);
|
||||||
} else if (mode == "downsample3d") {
|
} else if (mode == "downsample3d") {
|
||||||
x = ggml_pad(ctx, x, 1, 1, 0, 0);
|
x = ggml_pad(ctx->ggml_ctx, x, 1, 1, 0, 0);
|
||||||
}
|
}
|
||||||
x = resample_1->forward(ctx, x);
|
x = resample_1->forward(ctx, x);
|
||||||
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 1, 3, 2)); // (c, t, h, w)
|
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // (c, t, h, w)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (mode == "downsample3d") {
|
if (mode == "downsample3d") {
|
||||||
@ -217,9 +217,9 @@ namespace WAN {
|
|||||||
} else {
|
} else {
|
||||||
auto time_conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["time_conv"]);
|
auto time_conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["time_conv"]);
|
||||||
|
|
||||||
auto cache_x = ggml_ext_slice(ctx, x, 2, -1, x->ne[2]);
|
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -1, x->ne[2]);
|
||||||
x = ggml_concat(ctx,
|
x = ggml_concat(ctx->ggml_ctx,
|
||||||
ggml_ext_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
|
ggml_ext_slice(ctx->ggml_ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
|
||||||
x,
|
x,
|
||||||
2);
|
2);
|
||||||
x = time_conv->forward(ctx, x);
|
x = time_conv->forward(ctx, x);
|
||||||
@ -249,7 +249,7 @@ namespace WAN {
|
|||||||
GGML_ASSERT(in_channels * factor % out_channels == 0);
|
GGML_ASSERT(in_channels * factor % out_channels == 0);
|
||||||
group_size = in_channels * factor / out_channels;
|
group_size = in_channels * factor / out_channels;
|
||||||
}
|
}
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
int64_t B = 1) {
|
int64_t B = 1) {
|
||||||
// x: [B*IC, T, H, W]
|
// x: [B*IC, T, H, W]
|
||||||
@ -262,20 +262,20 @@ namespace WAN {
|
|||||||
|
|
||||||
int64_t pad_t = (factor_t - T % factor_t) % factor_t;
|
int64_t pad_t = (factor_t - T % factor_t) % factor_t;
|
||||||
|
|
||||||
x = ggml_pad_ext(ctx, x, 0, 0, 0, 0, pad_t, 0, 0, 0);
|
x = ggml_pad_ext(ctx->ggml_ctx, x, 0, 0, 0, 0, pad_t, 0, 0, 0);
|
||||||
T = x->ne[2];
|
T = x->ne[2];
|
||||||
|
|
||||||
x = ggml_reshape_4d(ctx, x, W * H, factor_t, T / factor_t, C); // [C, T/factor_t, factor_t, H*W]
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, factor_t, T / factor_t, C); // [C, T/factor_t, factor_t, H*W]
|
||||||
x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [C, factor_t, T/factor_t, H*W]
|
x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // [C, factor_t, T/factor_t, H*W]
|
||||||
x = ggml_reshape_4d(ctx, x, W, factor_s, (H / factor_s) * (T / factor_t), factor_t * C); // [C*factor_t, T/factor_t*H/factor_s, factor_s, W]
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, W, factor_s, (H / factor_s) * (T / factor_t), factor_t * C); // [C*factor_t, T/factor_t*H/factor_s, factor_s, W]
|
||||||
x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [C*factor_t, factor_s, T/factor_t*H/factor_s, W]
|
x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // [C*factor_t, factor_s, T/factor_t*H/factor_s, W]
|
||||||
x = ggml_reshape_4d(ctx, x, factor_s, W / factor_s, (H / factor_s) * (T / factor_t), factor_s * factor_t * C); // [C*factor_t*factor_s, T/factor_t*H/factor_s, W/factor_s, factor_s]
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, factor_s, W / factor_s, (H / factor_s) * (T / factor_t), factor_s * factor_t * C); // [C*factor_t*factor_s, T/factor_t*H/factor_s, W/factor_s, factor_s]
|
||||||
x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 2, 0, 3)); // [C*factor_t*factor_s, factor_s, T/factor_t*H/factor_s, W/factor_s]
|
x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 2, 0, 3)); // [C*factor_t*factor_s, factor_s, T/factor_t*H/factor_s, W/factor_s]
|
||||||
x = ggml_reshape_3d(ctx, x, (W / factor_s) * (H / factor_s) * (T / factor_t), group_size, out_channels); // [out_channels, group_size, T/factor_t*H/factor_s*W/factor_s]
|
x = ggml_reshape_3d(ctx->ggml_ctx, x, (W / factor_s) * (H / factor_s) * (T / factor_t), group_size, out_channels); // [out_channels, group_size, T/factor_t*H/factor_s*W/factor_s]
|
||||||
|
|
||||||
x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 0, 2, 3)); // [out_channels, T/factor_t*H/factor_s*W/factor_s, group_size]
|
x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [out_channels, T/factor_t*H/factor_s*W/factor_s, group_size]
|
||||||
x = ggml_mean(ctx, x); // [out_channels, T/factor_t*H/factor_s*W/factor_s, 1]
|
x = ggml_mean(ctx->ggml_ctx, x); // [out_channels, T/factor_t*H/factor_s*W/factor_s, 1]
|
||||||
x = ggml_reshape_4d(ctx, x, W / factor_s, H / factor_s, T / factor_t, out_channels);
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, W / factor_s, H / factor_s, T / factor_t, out_channels);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -296,7 +296,7 @@ namespace WAN {
|
|||||||
GGML_ASSERT(out_channels * factor % in_channels == 0);
|
GGML_ASSERT(out_channels * factor % in_channels == 0);
|
||||||
repeats = out_channels * factor / in_channels;
|
repeats = out_channels * factor / in_channels;
|
||||||
}
|
}
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
bool first_chunk = false,
|
bool first_chunk = false,
|
||||||
int64_t B = 1) {
|
int64_t B = 1) {
|
||||||
@ -310,21 +310,21 @@ namespace WAN {
|
|||||||
|
|
||||||
auto x_ = x;
|
auto x_ = x;
|
||||||
for (int64_t i = 1; i < repeats; i++) {
|
for (int64_t i = 1; i < repeats; i++) {
|
||||||
x = ggml_concat(ctx, x, x_, 2);
|
x = ggml_concat(ctx->ggml_ctx, x, x_, 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
C = out_channels;
|
C = out_channels;
|
||||||
|
|
||||||
x = ggml_reshape_4d(ctx, x, W, H * T, factor_s, factor_s * factor_t * C); // [C*factor_t*factor_s, factor_s, T*H, W]
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H * T, factor_s, factor_s * factor_t * C); // [C*factor_t*factor_s, factor_s, T*H, W]
|
||||||
x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [C*factor_t*factor_s, T*H, W, factor_s]
|
x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 2, 0, 1, 3)); // [C*factor_t*factor_s, T*H, W, factor_s]
|
||||||
x = ggml_reshape_4d(ctx, x, factor_s * W, H * T, factor_s, factor_t * C); // [C*factor_t, factor_s, T*H, W*factor_s]
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, factor_s * W, H * T, factor_s, factor_t * C); // [C*factor_t, factor_s, T*H, W*factor_s]
|
||||||
x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [C*factor_t, T*H, factor_s, W*factor_s]
|
x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // [C*factor_t, T*H, factor_s, W*factor_s]
|
||||||
x = ggml_reshape_4d(ctx, x, factor_s * W * factor_s * H, T, factor_t, C); // [C, factor_t, T, H*factor_s*W*factor_s]
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, factor_s * W * factor_s * H, T, factor_t, C); // [C, factor_t, T, H*factor_s*W*factor_s]
|
||||||
x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [C, T, factor_t, H*factor_s*W*factor_s]
|
x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // [C, T, factor_t, H*factor_s*W*factor_s]
|
||||||
x = ggml_reshape_4d(ctx, x, factor_s * W, factor_s * H, factor_t * T, C); // [C, T*factor_t, H*factor_s, W*factor_s]
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, factor_s * W, factor_s * H, factor_t * T, C); // [C, T*factor_t, H*factor_s, W*factor_s]
|
||||||
|
|
||||||
if (first_chunk) {
|
if (first_chunk) {
|
||||||
x = ggml_ext_slice(ctx, x, 2, factor_t - 1, x->ne[2]);
|
x = ggml_ext_slice(ctx->ggml_ctx, x, 2, factor_t - 1, x->ne[2]);
|
||||||
}
|
}
|
||||||
|
|
||||||
return x;
|
return x;
|
||||||
@ -351,7 +351,7 @@ namespace WAN {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
int64_t b,
|
int64_t b,
|
||||||
std::vector<struct ggml_tensor*>& feat_cache,
|
std::vector<struct ggml_tensor*>& feat_cache,
|
||||||
@ -374,11 +374,11 @@ namespace WAN {
|
|||||||
|
|
||||||
if (feat_cache.size() > 0) {
|
if (feat_cache.size() > 0) {
|
||||||
int idx = feat_idx;
|
int idx = feat_idx;
|
||||||
auto cache_x = ggml_ext_slice(ctx, x, 2, -CACHE_T, x->ne[2]);
|
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
|
||||||
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
|
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
|
||||||
// cache last frame of last two chunk
|
// cache last frame of last two chunk
|
||||||
cache_x = ggml_concat(ctx,
|
cache_x = ggml_concat(ctx->ggml_ctx,
|
||||||
ggml_ext_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
|
ggml_ext_slice(ctx->ggml_ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
|
||||||
cache_x,
|
cache_x,
|
||||||
2);
|
2);
|
||||||
}
|
}
|
||||||
@ -388,13 +388,13 @@ namespace WAN {
|
|||||||
feat_idx += 1;
|
feat_idx += 1;
|
||||||
}
|
}
|
||||||
} else if (i == 1 || i == 4) {
|
} else if (i == 1 || i == 4) {
|
||||||
x = ggml_silu(ctx, x);
|
x = ggml_silu(ctx->ggml_ctx, x);
|
||||||
} else { // i == 5
|
} else { // i == 5
|
||||||
// nn.Dropout(), ignore
|
// nn.Dropout(), ignore
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
x = ggml_add(ctx, x, h);
|
x = ggml_add(ctx->ggml_ctx, x, h);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -425,7 +425,7 @@ namespace WAN {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
int64_t b,
|
int64_t b,
|
||||||
std::vector<struct ggml_tensor*>& feat_cache,
|
std::vector<struct ggml_tensor*>& feat_cache,
|
||||||
@ -453,7 +453,7 @@ namespace WAN {
|
|||||||
|
|
||||||
auto shortcut = avg_shortcut->forward(ctx, x_copy, b);
|
auto shortcut = avg_shortcut->forward(ctx, x_copy, b);
|
||||||
|
|
||||||
x = ggml_add(ctx, x, shortcut);
|
x = ggml_add(ctx->ggml_ctx, x, shortcut);
|
||||||
|
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
@ -487,7 +487,7 @@ namespace WAN {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
int64_t b,
|
int64_t b,
|
||||||
std::vector<struct ggml_tensor*>& feat_cache,
|
std::vector<struct ggml_tensor*>& feat_cache,
|
||||||
@ -513,7 +513,7 @@ namespace WAN {
|
|||||||
auto avg_shortcut = std::dynamic_pointer_cast<DupUp3D>(blocks["avg_shortcut"]);
|
auto avg_shortcut = std::dynamic_pointer_cast<DupUp3D>(blocks["avg_shortcut"]);
|
||||||
auto shortcut = avg_shortcut->forward(ctx, x_copy, chunk_idx == 0, b);
|
auto shortcut = avg_shortcut->forward(ctx, x_copy, chunk_idx == 0, b);
|
||||||
|
|
||||||
x = ggml_add(ctx, x, shortcut);
|
x = ggml_add(ctx->ggml_ctx, x, shortcut);
|
||||||
}
|
}
|
||||||
|
|
||||||
return x;
|
return x;
|
||||||
@ -532,7 +532,7 @@ namespace WAN {
|
|||||||
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim, {1, 1}));
|
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim, {1, 1}));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
int64_t b) {
|
int64_t b) {
|
||||||
// x: [b*c, t, h, w]
|
// x: [b*c, t, h, w]
|
||||||
@ -545,7 +545,7 @@ namespace WAN {
|
|||||||
|
|
||||||
x = norm->forward(ctx, x);
|
x = norm->forward(ctx, x);
|
||||||
|
|
||||||
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 1, 3, 2)); // (t, c, h, w)
|
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // (t, c, h, w)
|
||||||
|
|
||||||
const int64_t n = x->ne[3];
|
const int64_t n = x->ne[3];
|
||||||
const int64_t c = x->ne[2];
|
const int64_t c = x->ne[2];
|
||||||
@ -553,31 +553,31 @@ namespace WAN {
|
|||||||
const int64_t w = x->ne[0];
|
const int64_t w = x->ne[0];
|
||||||
|
|
||||||
auto qkv = to_qkv->forward(ctx, x);
|
auto qkv = to_qkv->forward(ctx, x);
|
||||||
auto qkv_vec = split_image_qkv(ctx, qkv);
|
auto qkv_vec = split_image_qkv(ctx->ggml_ctx, qkv);
|
||||||
|
|
||||||
auto q = qkv_vec[0];
|
auto q = qkv_vec[0];
|
||||||
q = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, q, 2, 0, 1, 3)); // [t, h, w, c]
|
q = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, q, 2, 0, 1, 3)); // [t, h, w, c]
|
||||||
q = ggml_reshape_3d(ctx, q, c, h * w, n); // [t, h * w, c]
|
q = ggml_reshape_3d(ctx->ggml_ctx, q, c, h * w, n); // [t, h * w, c]
|
||||||
|
|
||||||
auto k = qkv_vec[1];
|
auto k = qkv_vec[1];
|
||||||
k = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, k, 2, 0, 1, 3)); // [t, h, w, c]
|
k = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 2, 0, 1, 3)); // [t, h, w, c]
|
||||||
k = ggml_reshape_3d(ctx, k, c, h * w, n); // [t, h * w, c]
|
k = ggml_reshape_3d(ctx->ggml_ctx, k, c, h * w, n); // [t, h * w, c]
|
||||||
|
|
||||||
auto v = qkv_vec[2];
|
auto v = qkv_vec[2];
|
||||||
v = ggml_reshape_3d(ctx, v, h * w, c, n); // [t, c, h * w]
|
v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [t, c, h * w]
|
||||||
|
|
||||||
x = ggml_ext_attention(ctx, q, k, v, false); // [t, h * w, c]
|
x = ggml_ext_attention(ctx->ggml_ctx, q, k, v, false); // [t, h * w, c]
|
||||||
// v = ggml_cont(ctx, ggml_ext_torch_permute(ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
|
// v = ggml_cont(ctx, ggml_ext_torch_permute(ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
|
||||||
// x = ggml_ext_attention_ext(ctx, q, k, v, q->ne[2], nullptr, false, false, true);
|
// x = ggml_ext_attention_ext(ctx, q, k, v, q->ne[2], nullptr, false, false, true);
|
||||||
|
|
||||||
x = ggml_ext_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [t, c, h * w]
|
x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [t, c, h * w]
|
||||||
x = ggml_reshape_4d(ctx, x, w, h, c, n); // [t, c, h, w]
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n); // [t, c, h, w]
|
||||||
|
|
||||||
x = proj->forward(ctx, x);
|
x = proj->forward(ctx, x);
|
||||||
|
|
||||||
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 1, 3, 2)); // (c, t, h, w)
|
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // (c, t, h, w)
|
||||||
|
|
||||||
x = ggml_add(ctx, x, identity);
|
x = ggml_add(ctx->ggml_ctx, x, identity);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -655,7 +655,7 @@ namespace WAN {
|
|||||||
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, z_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, z_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
int64_t b,
|
int64_t b,
|
||||||
std::vector<struct ggml_tensor*>& feat_cache,
|
std::vector<struct ggml_tensor*>& feat_cache,
|
||||||
@ -673,11 +673,11 @@ namespace WAN {
|
|||||||
// conv1
|
// conv1
|
||||||
if (feat_cache.size() > 0) {
|
if (feat_cache.size() > 0) {
|
||||||
int idx = feat_idx;
|
int idx = feat_idx;
|
||||||
auto cache_x = ggml_ext_slice(ctx, x, 2, -CACHE_T, x->ne[2]);
|
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
|
||||||
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
|
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
|
||||||
// cache last frame of last two chunk
|
// cache last frame of last two chunk
|
||||||
cache_x = ggml_concat(ctx,
|
cache_x = ggml_concat(ctx->ggml_ctx,
|
||||||
ggml_ext_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
|
ggml_ext_slice(ctx->ggml_ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
|
||||||
cache_x,
|
cache_x,
|
||||||
2);
|
2);
|
||||||
}
|
}
|
||||||
@ -722,14 +722,14 @@ namespace WAN {
|
|||||||
|
|
||||||
// head
|
// head
|
||||||
x = head_0->forward(ctx, x);
|
x = head_0->forward(ctx, x);
|
||||||
x = ggml_silu(ctx, x);
|
x = ggml_silu(ctx->ggml_ctx, x);
|
||||||
if (feat_cache.size() > 0) {
|
if (feat_cache.size() > 0) {
|
||||||
int idx = feat_idx;
|
int idx = feat_idx;
|
||||||
auto cache_x = ggml_ext_slice(ctx, x, 2, -CACHE_T, x->ne[2]);
|
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
|
||||||
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
|
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
|
||||||
// cache last frame of last two chunk
|
// cache last frame of last two chunk
|
||||||
cache_x = ggml_concat(ctx,
|
cache_x = ggml_concat(ctx->ggml_ctx,
|
||||||
ggml_ext_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
|
ggml_ext_slice(ctx->ggml_ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
|
||||||
cache_x,
|
cache_x,
|
||||||
2);
|
2);
|
||||||
}
|
}
|
||||||
@ -826,7 +826,7 @@ namespace WAN {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
int64_t b,
|
int64_t b,
|
||||||
std::vector<struct ggml_tensor*>& feat_cache,
|
std::vector<struct ggml_tensor*>& feat_cache,
|
||||||
@ -844,11 +844,11 @@ namespace WAN {
|
|||||||
// conv1
|
// conv1
|
||||||
if (feat_cache.size() > 0) {
|
if (feat_cache.size() > 0) {
|
||||||
int idx = feat_idx;
|
int idx = feat_idx;
|
||||||
auto cache_x = ggml_ext_slice(ctx, x, 2, -CACHE_T, x->ne[2]);
|
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
|
||||||
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
|
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
|
||||||
// cache last frame of last two chunk
|
// cache last frame of last two chunk
|
||||||
cache_x = ggml_concat(ctx,
|
cache_x = ggml_concat(ctx->ggml_ctx,
|
||||||
ggml_ext_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
|
ggml_ext_slice(ctx->ggml_ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
|
||||||
cache_x,
|
cache_x,
|
||||||
2);
|
2);
|
||||||
}
|
}
|
||||||
@ -893,14 +893,14 @@ namespace WAN {
|
|||||||
|
|
||||||
// head
|
// head
|
||||||
x = head_0->forward(ctx, x);
|
x = head_0->forward(ctx, x);
|
||||||
x = ggml_silu(ctx, x);
|
x = ggml_silu(ctx->ggml_ctx, x);
|
||||||
if (feat_cache.size() > 0) {
|
if (feat_cache.size() > 0) {
|
||||||
int idx = feat_idx;
|
int idx = feat_idx;
|
||||||
auto cache_x = ggml_ext_slice(ctx, x, 2, -CACHE_T, x->ne[2]);
|
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
|
||||||
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
|
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
|
||||||
// cache last frame of last two chunk
|
// cache last frame of last two chunk
|
||||||
cache_x = ggml_concat(ctx,
|
cache_x = ggml_concat(ctx->ggml_ctx,
|
||||||
ggml_ext_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
|
ggml_ext_slice(ctx->ggml_ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
|
||||||
cache_x,
|
cache_x,
|
||||||
2);
|
2);
|
||||||
}
|
}
|
||||||
@ -1015,7 +1015,7 @@ namespace WAN {
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* encode(struct ggml_context* ctx,
|
struct ggml_tensor* encode(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
int64_t b = 1) {
|
int64_t b = 1) {
|
||||||
// x: [b*c, t, h, w]
|
// x: [b*c, t, h, w]
|
||||||
@ -1025,7 +1025,7 @@ namespace WAN {
|
|||||||
clear_cache();
|
clear_cache();
|
||||||
|
|
||||||
if (wan2_2) {
|
if (wan2_2) {
|
||||||
x = patchify(ctx, x, 2, b);
|
x = patchify(ctx->ggml_ctx, x, 2, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto encoder = std::dynamic_pointer_cast<Encoder3d>(blocks["encoder"]);
|
auto encoder = std::dynamic_pointer_cast<Encoder3d>(blocks["encoder"]);
|
||||||
@ -1037,21 +1037,21 @@ namespace WAN {
|
|||||||
for (int i = 0; i < iter_; i++) {
|
for (int i = 0; i < iter_; i++) {
|
||||||
_enc_conv_idx = 0;
|
_enc_conv_idx = 0;
|
||||||
if (i == 0) {
|
if (i == 0) {
|
||||||
auto in = ggml_ext_slice(ctx, x, 2, 0, 1); // [b*c, 1, h, w]
|
auto in = ggml_ext_slice(ctx->ggml_ctx, x, 2, 0, 1); // [b*c, 1, h, w]
|
||||||
out = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx, i);
|
out = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx, i);
|
||||||
} else {
|
} else {
|
||||||
auto in = ggml_ext_slice(ctx, x, 2, 1 + 4 * (i - 1), 1 + 4 * i); // [b*c, 4, h, w]
|
auto in = ggml_ext_slice(ctx->ggml_ctx, x, 2, 1 + 4 * (i - 1), 1 + 4 * i); // [b*c, 4, h, w]
|
||||||
auto out_ = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx, i);
|
auto out_ = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx, i);
|
||||||
out = ggml_concat(ctx, out, out_, 2);
|
out = ggml_concat(ctx->ggml_ctx, out, out_, 2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
out = conv1->forward(ctx, out);
|
out = conv1->forward(ctx, out);
|
||||||
auto mu = ggml_ext_chunk(ctx, out, 2, 3)[0];
|
auto mu = ggml_ext_chunk(ctx->ggml_ctx, out, 2, 3)[0];
|
||||||
clear_cache();
|
clear_cache();
|
||||||
return mu;
|
return mu;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* decode(struct ggml_context* ctx,
|
struct ggml_tensor* decode(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* z,
|
struct ggml_tensor* z,
|
||||||
int64_t b = 1) {
|
int64_t b = 1) {
|
||||||
// z: [b*c, t, h, w]
|
// z: [b*c, t, h, w]
|
||||||
@ -1068,22 +1068,22 @@ namespace WAN {
|
|||||||
for (int64_t i = 0; i < iter_; i++) {
|
for (int64_t i = 0; i < iter_; i++) {
|
||||||
_conv_idx = 0;
|
_conv_idx = 0;
|
||||||
if (i == 0) {
|
if (i == 0) {
|
||||||
auto in = ggml_ext_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
|
auto in = ggml_ext_slice(ctx->ggml_ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
|
||||||
out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i);
|
out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i);
|
||||||
} else {
|
} else {
|
||||||
auto in = ggml_ext_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
|
auto in = ggml_ext_slice(ctx->ggml_ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
|
||||||
auto out_ = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i);
|
auto out_ = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i);
|
||||||
out = ggml_concat(ctx, out, out_, 2);
|
out = ggml_concat(ctx->ggml_ctx, out, out_, 2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (wan2_2) {
|
if (wan2_2) {
|
||||||
out = unpatchify(ctx, out, 2, b);
|
out = unpatchify(ctx->ggml_ctx, out, 2, b);
|
||||||
}
|
}
|
||||||
clear_cache();
|
clear_cache();
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* decode_partial(struct ggml_context* ctx,
|
struct ggml_tensor* decode_partial(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* z,
|
struct ggml_tensor* z,
|
||||||
int64_t i,
|
int64_t i,
|
||||||
int64_t b = 1) {
|
int64_t b = 1) {
|
||||||
@ -1094,11 +1094,11 @@ namespace WAN {
|
|||||||
auto conv2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv2"]);
|
auto conv2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv2"]);
|
||||||
|
|
||||||
auto x = conv2->forward(ctx, z);
|
auto x = conv2->forward(ctx, z);
|
||||||
auto in = ggml_ext_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
|
auto in = ggml_ext_slice(ctx->ggml_ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
|
||||||
_conv_idx = 0;
|
_conv_idx = 0;
|
||||||
auto out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i);
|
auto out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i);
|
||||||
if (wan2_2) {
|
if (wan2_2) {
|
||||||
out = unpatchify(ctx, out, 2, b);
|
out = unpatchify(ctx->ggml_ctx, out, 2, b);
|
||||||
}
|
}
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
@ -1131,7 +1131,9 @@ namespace WAN {
|
|||||||
|
|
||||||
z = to_backend(z);
|
z = to_backend(z);
|
||||||
|
|
||||||
struct ggml_tensor* out = decode_graph ? ae.decode(compute_ctx, z) : ae.encode(compute_ctx, z);
|
auto runner_ctx = get_context();
|
||||||
|
|
||||||
|
struct ggml_tensor* out = decode_graph ? ae.decode(&runner_ctx, z) : ae.encode(&runner_ctx, z);
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, out);
|
ggml_build_forward_expand(gf, out);
|
||||||
|
|
||||||
@ -1150,7 +1152,9 @@ namespace WAN {
|
|||||||
|
|
||||||
z = to_backend(z);
|
z = to_backend(z);
|
||||||
|
|
||||||
struct ggml_tensor* out = decode_graph ? ae.decode_partial(compute_ctx, z, i) : ae.encode(compute_ctx, z);
|
auto runner_ctx = get_context();
|
||||||
|
|
||||||
|
struct ggml_tensor* out = decode_graph ? ae.decode_partial(&runner_ctx, z, i) : ae.encode(&runner_ctx, z);
|
||||||
|
|
||||||
for (int64_t feat_idx = 0; feat_idx < ae._feat_map.size(); feat_idx++) {
|
for (int64_t feat_idx = 0; feat_idx < ae._feat_map.size(); feat_idx++) {
|
||||||
ggml_tensor* feat_cache = ae._feat_map[feat_idx];
|
ggml_tensor* feat_cache = ae._feat_map[feat_idx];
|
||||||
@ -1283,15 +1287,13 @@ namespace WAN {
|
|||||||
public:
|
public:
|
||||||
int64_t num_heads;
|
int64_t num_heads;
|
||||||
int64_t head_dim;
|
int64_t head_dim;
|
||||||
bool flash_attn;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
WanSelfAttention(int64_t dim,
|
WanSelfAttention(int64_t dim,
|
||||||
int64_t num_heads,
|
int64_t num_heads,
|
||||||
bool qk_norm = true,
|
bool qk_norm = true,
|
||||||
float eps = 1e-6,
|
float eps = 1e-6)
|
||||||
bool flash_attn = false)
|
: num_heads(num_heads) {
|
||||||
: num_heads(num_heads), flash_attn(flash_attn) {
|
|
||||||
head_dim = dim / num_heads;
|
head_dim = dim / num_heads;
|
||||||
blocks["q"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
|
blocks["q"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
|
||||||
blocks["k"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
|
blocks["k"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
|
||||||
@ -1307,8 +1309,7 @@ namespace WAN {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual struct ggml_tensor* forward(struct ggml_context* ctx,
|
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* pe,
|
struct ggml_tensor* pe,
|
||||||
struct ggml_tensor* mask = nullptr) {
|
struct ggml_tensor* mask = nullptr) {
|
||||||
@ -1331,11 +1332,11 @@ namespace WAN {
|
|||||||
k = norm_k->forward(ctx, k);
|
k = norm_k->forward(ctx, k);
|
||||||
auto v = v_proj->forward(ctx, x); // [N, n_token, n_head*d_head]
|
auto v = v_proj->forward(ctx, x); // [N, n_token, n_head*d_head]
|
||||||
|
|
||||||
q = ggml_reshape_4d(ctx, q, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head]
|
q = ggml_reshape_4d(ctx->ggml_ctx, q, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head]
|
||||||
k = ggml_reshape_4d(ctx, k, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head]
|
k = ggml_reshape_4d(ctx->ggml_ctx, k, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head]
|
||||||
v = ggml_reshape_4d(ctx, v, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head]
|
v = ggml_reshape_4d(ctx->ggml_ctx, v, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head]
|
||||||
|
|
||||||
x = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_token, dim]
|
x = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_token, dim]
|
||||||
|
|
||||||
x = o_proj->forward(ctx, x); // [N, n_token, dim]
|
x = o_proj->forward(ctx, x); // [N, n_token, dim]
|
||||||
return x;
|
return x;
|
||||||
@ -1346,12 +1347,10 @@ namespace WAN {
|
|||||||
public:
|
public:
|
||||||
WanCrossAttention(int64_t dim,
|
WanCrossAttention(int64_t dim,
|
||||||
int64_t num_heads,
|
int64_t num_heads,
|
||||||
bool qk_norm = true,
|
bool qk_norm = true,
|
||||||
float eps = 1e-6,
|
float eps = 1e-6)
|
||||||
bool flash_attn = false)
|
: WanSelfAttention(dim, num_heads, qk_norm, eps) {}
|
||||||
: WanSelfAttention(dim, num_heads, qk_norm, eps, flash_attn) {}
|
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
virtual struct ggml_tensor* forward(struct ggml_context* ctx,
|
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* context,
|
struct ggml_tensor* context,
|
||||||
int64_t context_img_len) = 0;
|
int64_t context_img_len) = 0;
|
||||||
@ -1361,12 +1360,10 @@ namespace WAN {
|
|||||||
public:
|
public:
|
||||||
WanT2VCrossAttention(int64_t dim,
|
WanT2VCrossAttention(int64_t dim,
|
||||||
int64_t num_heads,
|
int64_t num_heads,
|
||||||
bool qk_norm = true,
|
bool qk_norm = true,
|
||||||
float eps = 1e-6,
|
float eps = 1e-6)
|
||||||
bool flash_attn = false)
|
: WanCrossAttention(dim, num_heads, qk_norm, eps) {}
|
||||||
: WanCrossAttention(dim, num_heads, qk_norm, eps, flash_attn) {}
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* context,
|
struct ggml_tensor* context,
|
||||||
int64_t context_img_len) override {
|
int64_t context_img_len) override {
|
||||||
@ -1390,7 +1387,7 @@ namespace WAN {
|
|||||||
k = norm_k->forward(ctx, k);
|
k = norm_k->forward(ctx, k);
|
||||||
auto v = v_proj->forward(ctx, context); // [N, n_context, dim]
|
auto v = v_proj->forward(ctx, context); // [N, n_context, dim]
|
||||||
|
|
||||||
x = ggml_ext_attention_ext(ctx, backend, q, k, v, num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim]
|
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
|
||||||
|
|
||||||
x = o_proj->forward(ctx, x); // [N, n_token, dim]
|
x = o_proj->forward(ctx, x); // [N, n_token, dim]
|
||||||
return x;
|
return x;
|
||||||
@ -1401,10 +1398,9 @@ namespace WAN {
|
|||||||
public:
|
public:
|
||||||
WanI2VCrossAttention(int64_t dim,
|
WanI2VCrossAttention(int64_t dim,
|
||||||
int64_t num_heads,
|
int64_t num_heads,
|
||||||
bool qk_norm = true,
|
bool qk_norm = true,
|
||||||
float eps = 1e-6,
|
float eps = 1e-6)
|
||||||
bool flash_attn = false)
|
: WanCrossAttention(dim, num_heads, qk_norm, eps) {
|
||||||
: WanCrossAttention(dim, num_heads, qk_norm, eps, flash_attn) {
|
|
||||||
blocks["k_img"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
|
blocks["k_img"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
|
||||||
blocks["v_img"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
|
blocks["v_img"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
|
||||||
|
|
||||||
@ -1415,8 +1411,7 @@ namespace WAN {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* context,
|
struct ggml_tensor* context,
|
||||||
int64_t context_img_len) override {
|
int64_t context_img_len) override {
|
||||||
@ -1441,11 +1436,11 @@ namespace WAN {
|
|||||||
int64_t dim = x->ne[0];
|
int64_t dim = x->ne[0];
|
||||||
int64_t context_txt_len = context->ne[1] - context_img_len;
|
int64_t context_txt_len = context->ne[1] - context_img_len;
|
||||||
|
|
||||||
context = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, context, 0, 2, 1, 3)); // [context_img_len + context_txt_len, N, dim]
|
context = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context, 0, 2, 1, 3)); // [context_img_len + context_txt_len, N, dim]
|
||||||
auto context_img = ggml_view_3d(ctx, context, dim, N, context_img_len, context->nb[1], context->nb[2], 0);
|
auto context_img = ggml_view_3d(ctx->ggml_ctx, context, dim, N, context_img_len, context->nb[1], context->nb[2], 0);
|
||||||
auto context_txt = ggml_view_3d(ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_img_len * context->nb[2]);
|
auto context_txt = ggml_view_3d(ctx->ggml_ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_img_len * context->nb[2]);
|
||||||
context_img = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, context_img, 0, 2, 1, 3)); // [N, context_img_len, dim]
|
context_img = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context_img, 0, 2, 1, 3)); // [N, context_img_len, dim]
|
||||||
context_txt = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, context_txt, 0, 2, 1, 3)); // [N, context_txt_len, dim]
|
context_txt = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context_txt, 0, 2, 1, 3)); // [N, context_txt_len, dim]
|
||||||
|
|
||||||
auto q = q_proj->forward(ctx, x);
|
auto q = q_proj->forward(ctx, x);
|
||||||
q = norm_q->forward(ctx, q);
|
q = norm_q->forward(ctx, q);
|
||||||
@ -1457,10 +1452,10 @@ namespace WAN {
|
|||||||
k_img = norm_k_img->forward(ctx, k_img);
|
k_img = norm_k_img->forward(ctx, k_img);
|
||||||
auto v_img = v_img_proj->forward(ctx, context_img); // [N, context_img_len, dim]
|
auto v_img = v_img_proj->forward(ctx, context_img); // [N, context_img_len, dim]
|
||||||
|
|
||||||
auto img_x = ggml_ext_attention_ext(ctx, backend, q, k_img, v_img, num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim]
|
auto img_x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k_img, v_img, num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
|
||||||
x = ggml_ext_attention_ext(ctx, backend, q, k, v, num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim]
|
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
|
||||||
|
|
||||||
x = ggml_add(ctx, x, img_x);
|
x = ggml_add(ctx->ggml_ctx, x, img_x);
|
||||||
|
|
||||||
x = o_proj->forward(ctx, x); // [N, n_token, dim]
|
x = o_proj->forward(ctx, x); // [N, n_token, dim]
|
||||||
return x;
|
return x;
|
||||||
@ -1511,20 +1506,19 @@ namespace WAN {
|
|||||||
int64_t num_heads,
|
int64_t num_heads,
|
||||||
bool qk_norm = true,
|
bool qk_norm = true,
|
||||||
bool cross_attn_norm = false,
|
bool cross_attn_norm = false,
|
||||||
float eps = 1e-6,
|
float eps = 1e-6)
|
||||||
bool flash_attn = false)
|
|
||||||
: dim(dim) {
|
: dim(dim) {
|
||||||
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
|
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
|
||||||
blocks["self_attn"] = std::shared_ptr<GGMLBlock>(new WanSelfAttention(dim, num_heads, qk_norm, eps, flash_attn));
|
blocks["self_attn"] = std::shared_ptr<GGMLBlock>(new WanSelfAttention(dim, num_heads, qk_norm, eps));
|
||||||
if (cross_attn_norm) {
|
if (cross_attn_norm) {
|
||||||
blocks["norm3"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, true));
|
blocks["norm3"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, true));
|
||||||
} else {
|
} else {
|
||||||
blocks["norm3"] = std::shared_ptr<GGMLBlock>(new Identity());
|
blocks["norm3"] = std::shared_ptr<GGMLBlock>(new Identity());
|
||||||
}
|
}
|
||||||
if (t2v_cross_attn) {
|
if (t2v_cross_attn) {
|
||||||
blocks["cross_attn"] = std::shared_ptr<GGMLBlock>(new WanT2VCrossAttention(dim, num_heads, qk_norm, eps, flash_attn));
|
blocks["cross_attn"] = std::shared_ptr<GGMLBlock>(new WanT2VCrossAttention(dim, num_heads, qk_norm, eps));
|
||||||
} else {
|
} else {
|
||||||
blocks["cross_attn"] = std::shared_ptr<GGMLBlock>(new WanI2VCrossAttention(dim, num_heads, qk_norm, eps, flash_attn));
|
blocks["cross_attn"] = std::shared_ptr<GGMLBlock>(new WanI2VCrossAttention(dim, num_heads, qk_norm, eps));
|
||||||
}
|
}
|
||||||
|
|
||||||
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
|
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
|
||||||
@ -1534,8 +1528,7 @@ namespace WAN {
|
|||||||
blocks["ffn.2"] = std::shared_ptr<GGMLBlock>(new Linear(ffn_dim, dim));
|
blocks["ffn.2"] = std::shared_ptr<GGMLBlock>(new Linear(ffn_dim, dim));
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual struct ggml_tensor* forward(struct ggml_context* ctx,
|
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* e,
|
struct ggml_tensor* e,
|
||||||
struct ggml_tensor* pe,
|
struct ggml_tensor* pe,
|
||||||
@ -1547,8 +1540,8 @@ namespace WAN {
|
|||||||
// return [N, n_token, dim]
|
// return [N, n_token, dim]
|
||||||
|
|
||||||
auto modulation = params["modulation"];
|
auto modulation = params["modulation"];
|
||||||
e = ggml_add(ctx, e, modulation); // [N, 6, dim] or [N, T, 6, dim]
|
e = ggml_add(ctx->ggml_ctx, e, modulation); // [N, 6, dim] or [N, T, 6, dim]
|
||||||
auto es = ggml_ext_chunk(ctx, e, 6, 1); // ([N, 1, dim], ...) or [N, T, 1, dim]
|
auto es = ggml_ext_chunk(ctx->ggml_ctx, e, 6, 1); // ([N, 1, dim], ...) or [N, T, 1, dim]
|
||||||
|
|
||||||
auto norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
|
auto norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
|
||||||
auto self_attn = std::dynamic_pointer_cast<WanSelfAttention>(blocks["self_attn"]);
|
auto self_attn = std::dynamic_pointer_cast<WanSelfAttention>(blocks["self_attn"]);
|
||||||
@ -1560,27 +1553,27 @@ namespace WAN {
|
|||||||
|
|
||||||
// self-attention
|
// self-attention
|
||||||
auto y = norm1->forward(ctx, x);
|
auto y = norm1->forward(ctx, x);
|
||||||
y = ggml_add(ctx, y, modulate_mul(ctx, y, es[1]));
|
y = ggml_add(ctx->ggml_ctx, y, modulate_mul(ctx->ggml_ctx, y, es[1]));
|
||||||
y = modulate_add(ctx, y, es[0]);
|
y = modulate_add(ctx->ggml_ctx, y, es[0]);
|
||||||
y = self_attn->forward(ctx, backend, y, pe);
|
y = self_attn->forward(ctx, y, pe);
|
||||||
|
|
||||||
x = ggml_add(ctx, x, modulate_mul(ctx, y, es[2]));
|
x = ggml_add(ctx->ggml_ctx, x, modulate_mul(ctx->ggml_ctx, y, es[2]));
|
||||||
|
|
||||||
// cross-attention
|
// cross-attention
|
||||||
x = ggml_add(ctx,
|
x = ggml_add(ctx->ggml_ctx,
|
||||||
x,
|
x,
|
||||||
cross_attn->forward(ctx, backend, norm3->forward(ctx, x), context, context_img_len));
|
cross_attn->forward(ctx, norm3->forward(ctx, x), context, context_img_len));
|
||||||
|
|
||||||
// ffn
|
// ffn
|
||||||
y = norm2->forward(ctx, x);
|
y = norm2->forward(ctx, x);
|
||||||
y = ggml_add(ctx, y, modulate_mul(ctx, y, es[4]));
|
y = ggml_add(ctx->ggml_ctx, y, modulate_mul(ctx->ggml_ctx, y, es[4]));
|
||||||
y = modulate_add(ctx, y, es[3]);
|
y = modulate_add(ctx->ggml_ctx, y, es[3]);
|
||||||
|
|
||||||
y = ffn_0->forward(ctx, y);
|
y = ffn_0->forward(ctx, y);
|
||||||
y = ggml_gelu_inplace(ctx, y);
|
y = ggml_gelu_inplace(ctx->ggml_ctx, y);
|
||||||
y = ffn_2->forward(ctx, y);
|
y = ffn_2->forward(ctx, y);
|
||||||
|
|
||||||
x = ggml_add(ctx, x, modulate_mul(ctx, y, es[5]));
|
x = ggml_add(ctx->ggml_ctx, x, modulate_mul(ctx->ggml_ctx, y, es[5]));
|
||||||
|
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
@ -1602,17 +1595,15 @@ namespace WAN {
|
|||||||
bool qk_norm = true,
|
bool qk_norm = true,
|
||||||
bool cross_attn_norm = false,
|
bool cross_attn_norm = false,
|
||||||
float eps = 1e-6,
|
float eps = 1e-6,
|
||||||
int block_id = 0,
|
int block_id = 0)
|
||||||
bool flash_attn = false)
|
: WanAttentionBlock(t2v_cross_attn, dim, ffn_dim, num_heads, qk_norm, cross_attn_norm, eps), block_id(block_id) {
|
||||||
: WanAttentionBlock(t2v_cross_attn, dim, ffn_dim, num_heads, qk_norm, cross_attn_norm, eps, flash_attn), block_id(block_id) {
|
|
||||||
if (block_id == 0) {
|
if (block_id == 0) {
|
||||||
blocks["before_proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
|
blocks["before_proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
|
||||||
}
|
}
|
||||||
blocks["after_proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
|
blocks["after_proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<ggml_tensor*, ggml_tensor*> forward(struct ggml_context* ctx,
|
std::pair<ggml_tensor*, ggml_tensor*> forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* c,
|
struct ggml_tensor* c,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* e,
|
struct ggml_tensor* e,
|
||||||
@ -1627,12 +1618,12 @@ namespace WAN {
|
|||||||
auto before_proj = std::dynamic_pointer_cast<Linear>(blocks["before_proj"]);
|
auto before_proj = std::dynamic_pointer_cast<Linear>(blocks["before_proj"]);
|
||||||
|
|
||||||
c = before_proj->forward(ctx, c);
|
c = before_proj->forward(ctx, c);
|
||||||
c = ggml_add(ctx, c, x);
|
c = ggml_add(ctx->ggml_ctx, c, x);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto after_proj = std::dynamic_pointer_cast<Linear>(blocks["after_proj"]);
|
auto after_proj = std::dynamic_pointer_cast<Linear>(blocks["after_proj"]);
|
||||||
|
|
||||||
c = WanAttentionBlock::forward(ctx, backend, c, e, pe, context, context_img_len);
|
c = WanAttentionBlock::forward(ctx, c, e, pe, context, context_img_len);
|
||||||
auto c_skip = after_proj->forward(ctx, c);
|
auto c_skip = after_proj->forward(ctx, c);
|
||||||
|
|
||||||
return {c_skip, c};
|
return {c_skip, c};
|
||||||
@ -1660,7 +1651,7 @@ namespace WAN {
|
|||||||
blocks["head"] = std::shared_ptr<GGMLBlock>(new Linear(dim, out_dim));
|
blocks["head"] = std::shared_ptr<GGMLBlock>(new Linear(dim, out_dim));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* e) {
|
struct ggml_tensor* e) {
|
||||||
// x: [N, n_token, dim]
|
// x: [N, n_token, dim]
|
||||||
@ -1668,18 +1659,18 @@ namespace WAN {
|
|||||||
// return [N, n_token, out_dim]
|
// return [N, n_token, out_dim]
|
||||||
|
|
||||||
auto modulation = params["modulation"];
|
auto modulation = params["modulation"];
|
||||||
e = ggml_reshape_4d(ctx, e, e->ne[0], 1, e->ne[1], e->ne[2]); // [N, 1, dim] or [N, T, 1, dim]
|
e = ggml_reshape_4d(ctx->ggml_ctx, e, e->ne[0], 1, e->ne[1], e->ne[2]); // [N, 1, dim] or [N, T, 1, dim]
|
||||||
e = ggml_repeat_4d(ctx, e, e->ne[0], 2, e->ne[2], e->ne[3]); // [N, 2, dim] or [N, T, 2, dim]
|
e = ggml_repeat_4d(ctx->ggml_ctx, e, e->ne[0], 2, e->ne[2], e->ne[3]); // [N, 2, dim] or [N, T, 2, dim]
|
||||||
|
|
||||||
e = ggml_add(ctx, e, modulation); // [N, 2, dim] or [N, T, 2, dim]
|
e = ggml_add(ctx->ggml_ctx, e, modulation); // [N, 2, dim] or [N, T, 2, dim]
|
||||||
auto es = ggml_ext_chunk(ctx, e, 2, 1); // ([N, 1, dim], ...) or ([N, T, 1, dim], ...)
|
auto es = ggml_ext_chunk(ctx->ggml_ctx, e, 2, 1); // ([N, 1, dim], ...) or ([N, T, 1, dim], ...)
|
||||||
|
|
||||||
auto norm = std::dynamic_pointer_cast<LayerNorm>(blocks["norm"]);
|
auto norm = std::dynamic_pointer_cast<LayerNorm>(blocks["norm"]);
|
||||||
auto head = std::dynamic_pointer_cast<Linear>(blocks["head"]);
|
auto head = std::dynamic_pointer_cast<Linear>(blocks["head"]);
|
||||||
|
|
||||||
x = norm->forward(ctx, x);
|
x = norm->forward(ctx, x);
|
||||||
x = ggml_add(ctx, x, modulate_mul(ctx, x, es[1]));
|
x = ggml_add(ctx->ggml_ctx, x, modulate_mul(ctx->ggml_ctx, x, es[1]));
|
||||||
x = modulate_add(ctx, x, es[0]);
|
x = modulate_add(ctx->ggml_ctx, x, es[0]);
|
||||||
x = head->forward(ctx, x);
|
x = head->forward(ctx, x);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
@ -1708,15 +1699,15 @@ namespace WAN {
|
|||||||
blocks["proj.4"] = std::shared_ptr<GGMLBlock>(new LayerNorm(out_dim));
|
blocks["proj.4"] = std::shared_ptr<GGMLBlock>(new LayerNorm(out_dim));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* image_embeds) {
|
struct ggml_tensor* image_embeds) {
|
||||||
if (flf_pos_embed_token_number > 0) {
|
if (flf_pos_embed_token_number > 0) {
|
||||||
auto emb_pos = params["emb_pos"];
|
auto emb_pos = params["emb_pos"];
|
||||||
|
|
||||||
auto a = ggml_ext_slice(ctx, image_embeds, 1, 0, emb_pos->ne[1]);
|
auto a = ggml_ext_slice(ctx->ggml_ctx, image_embeds, 1, 0, emb_pos->ne[1]);
|
||||||
auto b = ggml_ext_slice(ctx, emb_pos, 1, 0, image_embeds->ne[1]);
|
auto b = ggml_ext_slice(ctx->ggml_ctx, emb_pos, 1, 0, image_embeds->ne[1]);
|
||||||
|
|
||||||
image_embeds = ggml_add(ctx, a, b);
|
image_embeds = ggml_add(ctx->ggml_ctx, a, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto proj_0 = std::dynamic_pointer_cast<LayerNorm>(blocks["proj.0"]);
|
auto proj_0 = std::dynamic_pointer_cast<LayerNorm>(blocks["proj.0"]);
|
||||||
@ -1726,7 +1717,7 @@ namespace WAN {
|
|||||||
|
|
||||||
auto x = proj_0->forward(ctx, image_embeds);
|
auto x = proj_0->forward(ctx, image_embeds);
|
||||||
x = proj_1->forward(ctx, x);
|
x = proj_1->forward(ctx, x);
|
||||||
x = ggml_gelu_inplace(ctx, x);
|
x = ggml_gelu_inplace(ctx->ggml_ctx, x);
|
||||||
x = proj_3->forward(ctx, x);
|
x = proj_3->forward(ctx, x);
|
||||||
x = proj_4->forward(ctx, x);
|
x = proj_4->forward(ctx, x);
|
||||||
|
|
||||||
@ -1757,7 +1748,6 @@ namespace WAN {
|
|||||||
// wan2.1 1.3B: 1536/12, wan2.1/2.2 14B: 5120/40, wan2.2 5B: 3074/24
|
// wan2.1 1.3B: 1536/12, wan2.1/2.2 14B: 5120/40, wan2.2 5B: 3074/24
|
||||||
std::vector<int> axes_dim = {44, 42, 42};
|
std::vector<int> axes_dim = {44, 42, 42};
|
||||||
int64_t axes_dim_sum = 128;
|
int64_t axes_dim_sum = 128;
|
||||||
bool flash_attn = false;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class Wan : public GGMLBlock {
|
class Wan : public GGMLBlock {
|
||||||
@ -1792,8 +1782,7 @@ namespace WAN {
|
|||||||
params.num_heads,
|
params.num_heads,
|
||||||
params.qk_norm,
|
params.qk_norm,
|
||||||
params.cross_attn_norm,
|
params.cross_attn_norm,
|
||||||
params.eps,
|
params.eps));
|
||||||
params.flash_attn));
|
|
||||||
blocks["blocks." + std::to_string(i)] = block;
|
blocks["blocks." + std::to_string(i)] = block;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1815,8 +1804,7 @@ namespace WAN {
|
|||||||
params.qk_norm,
|
params.qk_norm,
|
||||||
params.cross_attn_norm,
|
params.cross_attn_norm,
|
||||||
params.eps,
|
params.eps,
|
||||||
i,
|
i));
|
||||||
params.flash_attn));
|
|
||||||
blocks["vace_blocks." + std::to_string(i)] = block;
|
blocks["vace_blocks." + std::to_string(i)] = block;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1872,8 +1860,7 @@ namespace WAN {
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward_orig(struct ggml_context* ctx,
|
struct ggml_tensor* forward_orig(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* timestep,
|
struct ggml_tensor* timestep,
|
||||||
struct ggml_tensor* context,
|
struct ggml_tensor* context,
|
||||||
@ -1902,31 +1889,31 @@ namespace WAN {
|
|||||||
auto head = std::dynamic_pointer_cast<Head>(blocks["head"]);
|
auto head = std::dynamic_pointer_cast<Head>(blocks["head"]);
|
||||||
|
|
||||||
// patch_embedding
|
// patch_embedding
|
||||||
x = patch_embedding->forward(ctx, x); // [N*dim, t_len, h_len, w_len]
|
x = patch_embedding->forward(ctx, x); // [N*dim, t_len, h_len, w_len]
|
||||||
x = ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1] * x->ne[2], x->ne[3] / N, N); // [N, dim, t_len*h_len*w_len]
|
x = ggml_reshape_3d(ctx->ggml_ctx, x, x->ne[0] * x->ne[1] * x->ne[2], x->ne[3] / N, N); // [N, dim, t_len*h_len*w_len]
|
||||||
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 0, 2, 3)); // [N, t_len*h_len*w_len, dim]
|
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [N, t_len*h_len*w_len, dim]
|
||||||
|
|
||||||
// time_embedding
|
// time_embedding
|
||||||
auto e = ggml_ext_timestep_embedding(ctx, timestep, params.freq_dim);
|
auto e = ggml_ext_timestep_embedding(ctx->ggml_ctx, timestep, params.freq_dim);
|
||||||
e = time_embedding_0->forward(ctx, e);
|
e = time_embedding_0->forward(ctx, e);
|
||||||
e = ggml_silu_inplace(ctx, e);
|
e = ggml_silu_inplace(ctx->ggml_ctx, e);
|
||||||
e = time_embedding_2->forward(ctx, e); // [N, dim] or [N, T, dim]
|
e = time_embedding_2->forward(ctx, e); // [N, dim] or [N, T, dim]
|
||||||
|
|
||||||
// time_projection
|
// time_projection
|
||||||
auto e0 = ggml_silu(ctx, e);
|
auto e0 = ggml_silu(ctx->ggml_ctx, e);
|
||||||
e0 = time_projection_1->forward(ctx, e0);
|
e0 = time_projection_1->forward(ctx, e0);
|
||||||
e0 = ggml_reshape_4d(ctx, e0, e0->ne[0] / 6, 6, e0->ne[1], e0->ne[2]); // [N, 6, dim] or [N, T, 6, dim]
|
e0 = ggml_reshape_4d(ctx->ggml_ctx, e0, e0->ne[0] / 6, 6, e0->ne[1], e0->ne[2]); // [N, 6, dim] or [N, T, 6, dim]
|
||||||
|
|
||||||
context = text_embedding_0->forward(ctx, context);
|
context = text_embedding_0->forward(ctx, context);
|
||||||
context = ggml_gelu(ctx, context);
|
context = ggml_gelu(ctx->ggml_ctx, context);
|
||||||
context = text_embedding_2->forward(ctx, context); // [N, context_txt_len, dim]
|
context = text_embedding_2->forward(ctx, context); // [N, context_txt_len, dim]
|
||||||
|
|
||||||
int64_t context_img_len = 0;
|
int64_t context_img_len = 0;
|
||||||
if (clip_fea != nullptr) {
|
if (clip_fea != nullptr) {
|
||||||
if (params.model_type == "i2v") {
|
if (params.model_type == "i2v") {
|
||||||
auto img_emb = std::dynamic_pointer_cast<MLPProj>(blocks["img_emb"]);
|
auto img_emb = std::dynamic_pointer_cast<MLPProj>(blocks["img_emb"]);
|
||||||
auto context_img = img_emb->forward(ctx, clip_fea); // [N, context_img_len, dim]
|
auto context_img = img_emb->forward(ctx, clip_fea); // [N, context_img_len, dim]
|
||||||
context = ggml_concat(ctx, context_img, context, 1); // [N, context_img_len + context_txt_len, dim]
|
context = ggml_concat(ctx->ggml_ctx, context_img, context, 1); // [N, context_img_len + context_txt_len, dim]
|
||||||
}
|
}
|
||||||
context_img_len = clip_fea->ne[1]; // 257
|
context_img_len = clip_fea->ne[1]; // 257
|
||||||
}
|
}
|
||||||
@ -1936,9 +1923,9 @@ namespace WAN {
|
|||||||
if (params.vace_layers > 0) {
|
if (params.vace_layers > 0) {
|
||||||
auto vace_patch_embedding = std::dynamic_pointer_cast<Conv3d>(blocks["vace_patch_embedding"]);
|
auto vace_patch_embedding = std::dynamic_pointer_cast<Conv3d>(blocks["vace_patch_embedding"]);
|
||||||
|
|
||||||
c = vace_patch_embedding->forward(ctx, vace_context); // [N*dim, t_len, h_len, w_len]
|
c = vace_patch_embedding->forward(ctx, vace_context); // [N*dim, t_len, h_len, w_len]
|
||||||
c = ggml_reshape_3d(ctx, c, c->ne[0] * c->ne[1] * c->ne[2], c->ne[3] / N, N); // [N, dim, t_len*h_len*w_len]
|
c = ggml_reshape_3d(ctx->ggml_ctx, c, c->ne[0] * c->ne[1] * c->ne[2], c->ne[3] / N, N); // [N, dim, t_len*h_len*w_len]
|
||||||
c = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, c, 1, 0, 2, 3)); // [N, t_len*h_len*w_len, dim]
|
c = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, c, 1, 0, 2, 3)); // [N, t_len*h_len*w_len, dim]
|
||||||
}
|
}
|
||||||
|
|
||||||
auto x_orig = x;
|
auto x_orig = x;
|
||||||
@ -1946,7 +1933,7 @@ namespace WAN {
|
|||||||
for (int i = 0; i < params.num_layers; i++) {
|
for (int i = 0; i < params.num_layers; i++) {
|
||||||
auto block = std::dynamic_pointer_cast<WanAttentionBlock>(blocks["blocks." + std::to_string(i)]);
|
auto block = std::dynamic_pointer_cast<WanAttentionBlock>(blocks["blocks." + std::to_string(i)]);
|
||||||
|
|
||||||
x = block->forward(ctx, backend, x, e0, pe, context, context_img_len);
|
x = block->forward(ctx, x, e0, pe, context, context_img_len);
|
||||||
|
|
||||||
auto iter = params.vace_layers_mapping.find(i);
|
auto iter = params.vace_layers_mapping.find(i);
|
||||||
if (iter != params.vace_layers_mapping.end()) {
|
if (iter != params.vace_layers_mapping.end()) {
|
||||||
@ -1954,11 +1941,11 @@ namespace WAN {
|
|||||||
|
|
||||||
auto vace_block = std::dynamic_pointer_cast<VaceWanAttentionBlock>(blocks["vace_blocks." + std::to_string(n)]);
|
auto vace_block = std::dynamic_pointer_cast<VaceWanAttentionBlock>(blocks["vace_blocks." + std::to_string(n)]);
|
||||||
|
|
||||||
auto result = vace_block->forward(ctx, backend, c, x_orig, e0, pe, context, context_img_len);
|
auto result = vace_block->forward(ctx, c, x_orig, e0, pe, context, context_img_len);
|
||||||
auto c_skip = result.first;
|
auto c_skip = result.first;
|
||||||
c = result.second;
|
c = result.second;
|
||||||
c_skip = ggml_scale(ctx, c_skip, vace_strength);
|
c_skip = ggml_scale(ctx->ggml_ctx, c_skip, vace_strength);
|
||||||
x = ggml_add(ctx, x, c_skip);
|
x = ggml_add(ctx->ggml_ctx, x, c_skip);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1967,8 +1954,7 @@ namespace WAN {
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* timestep,
|
struct ggml_tensor* timestep,
|
||||||
struct ggml_tensor* context,
|
struct ggml_tensor* context,
|
||||||
@ -1993,27 +1979,27 @@ namespace WAN {
|
|||||||
int64_t T = x->ne[2];
|
int64_t T = x->ne[2];
|
||||||
int64_t C = x->ne[3];
|
int64_t C = x->ne[3];
|
||||||
|
|
||||||
x = pad_to_patch_size(ctx, x);
|
x = pad_to_patch_size(ctx->ggml_ctx, x);
|
||||||
|
|
||||||
int64_t t_len = ((T + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size));
|
int64_t t_len = ((T + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size));
|
||||||
int64_t h_len = ((H + (std::get<1>(params.patch_size) / 2)) / std::get<1>(params.patch_size));
|
int64_t h_len = ((H + (std::get<1>(params.patch_size) / 2)) / std::get<1>(params.patch_size));
|
||||||
int64_t w_len = ((W + (std::get<2>(params.patch_size) / 2)) / std::get<2>(params.patch_size));
|
int64_t w_len = ((W + (std::get<2>(params.patch_size) / 2)) / std::get<2>(params.patch_size));
|
||||||
|
|
||||||
if (time_dim_concat != nullptr) {
|
if (time_dim_concat != nullptr) {
|
||||||
time_dim_concat = pad_to_patch_size(ctx, time_dim_concat);
|
time_dim_concat = pad_to_patch_size(ctx->ggml_ctx, time_dim_concat);
|
||||||
x = ggml_concat(ctx, x, time_dim_concat, 2); // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w]
|
x = ggml_concat(ctx->ggml_ctx, x, time_dim_concat, 2); // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w]
|
||||||
t_len = ((x->ne[2] + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size));
|
t_len = ((x->ne[2] + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto out = forward_orig(ctx, backend, x, timestep, context, pe, clip_fea, vace_context, vace_strength, N); // [N, t_len*h_len*w_len, pt*ph*pw*C]
|
auto out = forward_orig(ctx, x, timestep, context, pe, clip_fea, vace_context, vace_strength, N); // [N, t_len*h_len*w_len, pt*ph*pw*C]
|
||||||
|
|
||||||
out = unpatchify(ctx, out, t_len, h_len, w_len); // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w]
|
out = unpatchify(ctx->ggml_ctx, out, t_len, h_len, w_len); // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w]
|
||||||
|
|
||||||
// slice
|
// slice
|
||||||
|
|
||||||
out = ggml_ext_slice(ctx, out, 2, 0, T); // [N*C, T, H + pad_h, W + pad_w]
|
out = ggml_ext_slice(ctx->ggml_ctx, out, 2, 0, T); // [N*C, T, H + pad_h, W + pad_w]
|
||||||
out = ggml_ext_slice(ctx, out, 1, 0, H); // [N*C, T, H, W + pad_w]
|
out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H); // [N*C, T, H, W + pad_w]
|
||||||
out = ggml_ext_slice(ctx, out, 0, 0, W); // [N*C, T, H, W]
|
out = ggml_ext_slice(ctx->ggml_ctx, out, 0, 0, W); // [N*C, T, H, W]
|
||||||
|
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
@ -2031,10 +2017,8 @@ namespace WAN {
|
|||||||
bool offload_params_to_cpu,
|
bool offload_params_to_cpu,
|
||||||
const String2GGMLType& tensor_types = {},
|
const String2GGMLType& tensor_types = {},
|
||||||
const std::string prefix = "",
|
const std::string prefix = "",
|
||||||
SDVersion version = VERSION_WAN2,
|
SDVersion version = VERSION_WAN2)
|
||||||
bool flash_attn = false)
|
|
||||||
: GGMLRunner(backend, offload_params_to_cpu) {
|
: GGMLRunner(backend, offload_params_to_cpu) {
|
||||||
wan_params.flash_attn = flash_attn;
|
|
||||||
wan_params.num_layers = 0;
|
wan_params.num_layers = 0;
|
||||||
for (auto pair : tensor_types) {
|
for (auto pair : tensor_types) {
|
||||||
std::string tensor_name = pair.first;
|
std::string tensor_name = pair.first;
|
||||||
@ -2183,8 +2167,9 @@ namespace WAN {
|
|||||||
x = ggml_concat(compute_ctx, x, c_concat, 3);
|
x = ggml_concat(compute_ctx, x, c_concat, 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* out = wan.forward(compute_ctx,
|
auto runner_ctx = get_context();
|
||||||
runtime_backend,
|
|
||||||
|
struct ggml_tensor* out = wan.forward(&runner_ctx,
|
||||||
x,
|
x,
|
||||||
timesteps,
|
timesteps,
|
||||||
context,
|
context,
|
||||||
@ -2281,8 +2266,7 @@ namespace WAN {
|
|||||||
false,
|
false,
|
||||||
tensor_types,
|
tensor_types,
|
||||||
"model.diffusion_model",
|
"model.diffusion_model",
|
||||||
VERSION_WAN2_2_TI2V,
|
VERSION_WAN2_2_TI2V);
|
||||||
true);
|
|
||||||
|
|
||||||
wan->alloc_params_buffer();
|
wan->alloc_params_buffer();
|
||||||
std::map<std::string, ggml_tensor*> tensors;
|
std::map<std::string, ggml_tensor*> tensors;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user