refactor: unify the naming style of ggml extension functions (#921)

This commit is contained in:
leejet 2025-10-28 23:26:48 +08:00 committed by GitHub
parent 77eb95f8e4
commit dd75fc081c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 600 additions and 601 deletions

View File

@ -641,10 +641,10 @@ public:
// concat(patch_embedding, class_embedding) + position_embedding // concat(patch_embedding, class_embedding) + position_embedding
struct ggml_tensor* patch_embedding; struct ggml_tensor* patch_embedding;
int64_t N = pixel_values->ne[3]; int64_t N = pixel_values->ne[3];
patch_embedding = ggml_nn_conv_2d(ctx, pixel_values, patch_embed_weight, nullptr, patch_size, patch_size); // [N, embed_dim, image_size // pacht_size, image_size // pacht_size] patch_embedding = ggml_ext_conv_2d(ctx, pixel_values, patch_embed_weight, nullptr, patch_size, patch_size); // [N, embed_dim, image_size // pacht_size, image_size // pacht_size]
patch_embedding = ggml_reshape_3d(ctx, patch_embedding, num_patches, embed_dim, N); // [N, embed_dim, num_patches] patch_embedding = ggml_reshape_3d(ctx, patch_embedding, num_patches, embed_dim, N); // [N, embed_dim, num_patches]
patch_embedding = ggml_cont(ctx, ggml_permute(ctx, patch_embedding, 1, 0, 2, 3)); // [N, num_patches, embed_dim] patch_embedding = ggml_cont(ctx, ggml_permute(ctx, patch_embedding, 1, 0, 2, 3)); // [N, num_patches, embed_dim]
patch_embedding = ggml_reshape_4d(ctx, patch_embedding, 1, embed_dim, num_patches, N); // [N, num_patches, embed_dim, 1] patch_embedding = ggml_reshape_4d(ctx, patch_embedding, 1, embed_dim, num_patches, N); // [N, num_patches, embed_dim, 1]
struct ggml_tensor* class_embedding = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, embed_dim, N); struct ggml_tensor* class_embedding = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, embed_dim, N);
class_embedding = ggml_repeat(ctx, class_embed_weight, class_embedding); // [N, embed_dim] class_embedding = ggml_repeat(ctx, class_embed_weight, class_embedding); // [N, embed_dim]
@ -736,7 +736,7 @@ public:
auto text_projection = params["text_projection"]; auto text_projection = params["text_projection"];
ggml_tensor* pooled = ggml_view_1d(ctx, x, hidden_size, x->nb[1] * max_token_idx); ggml_tensor* pooled = ggml_view_1d(ctx, x, hidden_size, x->nb[1] * max_token_idx);
if (text_projection != nullptr) { if (text_projection != nullptr) {
pooled = ggml_nn_linear(ctx, pooled, text_projection, nullptr); pooled = ggml_ext_linear(ctx, pooled, text_projection, nullptr);
} else { } else {
LOG_DEBUG("identity projection"); LOG_DEBUG("identity projection");
} }
@ -836,7 +836,7 @@ public:
if (transpose_weight) { if (transpose_weight) {
w = ggml_cont(ctx, ggml_transpose(ctx, w)); w = ggml_cont(ctx, ggml_transpose(ctx, w));
} }
return ggml_nn_linear(ctx, x, w, nullptr); return ggml_ext_linear(ctx, x, w, nullptr);
} }
}; };

View File

@ -205,8 +205,8 @@ public:
auto gate_b = ggml_view_1d(ctx, b, b->ne[0] / 2, b->nb[0] * b->ne[0] / 2); // [dim_out, ] auto gate_b = ggml_view_1d(ctx, b, b->ne[0] / 2, b->nb[0] * b->ne[0] / 2); // [dim_out, ]
auto x_in = x; auto x_in = x;
x = ggml_nn_linear(ctx, x_in, x_w, x_b); // [ne3, ne2, ne1, dim_out] x = ggml_ext_linear(ctx, x_in, x_w, x_b); // [ne3, ne2, ne1, dim_out]
auto gate = ggml_nn_linear(ctx, x_in, gate_w, gate_b); // [ne3, ne2, ne1, dim_out] auto gate = ggml_ext_linear(ctx, x_in, gate_w, gate_b); // [ne3, ne2, ne1, dim_out]
gate = ggml_gelu_inplace(ctx, gate); gate = ggml_gelu_inplace(ctx, gate);
@ -325,7 +325,7 @@ public:
auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim] auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim]
auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim] auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim]
x = ggml_nn_attention_ext(ctx, backend, q, k, v, n_head, nullptr, false, false, flash_attn); // [N, n_token, inner_dim] x = ggml_ext_attention_ext(ctx, backend, q, k, v, n_head, nullptr, false, false, flash_attn); // [N, n_token, inner_dim]
x = to_out_0->forward(ctx, x); // [N, n_token, query_dim] x = to_out_0->forward(ctx, x); // [N, n_token, query_dim]
return x; return x;
@ -492,7 +492,7 @@ protected:
float get_alpha() { float get_alpha() {
// image_only_indicator is always tensor([0.]) and since mix_factor.shape is [1,] // image_only_indicator is always tensor([0.]) and since mix_factor.shape is [1,]
// so learned_with_images is same as learned // so learned_with_images is same as learned
float alpha = ggml_backend_tensor_get_f32(params["mix_factor"]); float alpha = ggml_ext_backend_tensor_get_f32(params["mix_factor"]);
return sigmoid(alpha); return sigmoid(alpha);
} }

View File

@ -462,7 +462,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
clip_skip, clip_skip,
&chunk_hidden_states2, work_ctx); &chunk_hidden_states2, work_ctx);
// concat // concat
chunk_hidden_states = ggml_tensor_concat(work_ctx, chunk_hidden_states1, chunk_hidden_states2, 0); chunk_hidden_states = ggml_ext_tensor_concat(work_ctx, chunk_hidden_states1, chunk_hidden_states2, 0);
if (chunk_idx == 0) { if (chunk_idx == 0) {
text_model2->compute(n_threads, text_model2->compute(n_threads,
@ -484,18 +484,18 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0); LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
ggml_tensor* result = ggml_dup_tensor(work_ctx, chunk_hidden_states); ggml_tensor* result = ggml_dup_tensor(work_ctx, chunk_hidden_states);
{ {
float original_mean = ggml_tensor_mean(chunk_hidden_states); float original_mean = ggml_ext_tensor_mean(chunk_hidden_states);
for (int i2 = 0; i2 < chunk_hidden_states->ne[2]; i2++) { for (int i2 = 0; i2 < chunk_hidden_states->ne[2]; i2++) {
for (int i1 = 0; i1 < chunk_hidden_states->ne[1]; i1++) { for (int i1 = 0; i1 < chunk_hidden_states->ne[1]; i1++) {
for (int i0 = 0; i0 < chunk_hidden_states->ne[0]; i0++) { for (int i0 = 0; i0 < chunk_hidden_states->ne[0]; i0++) {
float value = ggml_tensor_get_f32(chunk_hidden_states, i0, i1, i2); float value = ggml_ext_tensor_get_f32(chunk_hidden_states, i0, i1, i2);
value *= chunk_weights[i1]; value *= chunk_weights[i1];
ggml_tensor_set_f32(result, value, i0, i1, i2); ggml_ext_tensor_set_f32(result, value, i0, i1, i2);
} }
} }
} }
float new_mean = ggml_tensor_mean(result); float new_mean = ggml_ext_tensor_mean(result);
ggml_tensor_scale(result, (original_mean / new_mean)); ggml_ext_tensor_scale_inplace(result, (original_mean / new_mean));
} }
if (zero_out_masked) { if (zero_out_masked) {
float* vec = (float*)result->data; float* vec = (float*)result->data;
@ -874,18 +874,18 @@ struct SD3CLIPEmbedder : public Conditioner {
work_ctx); work_ctx);
{ {
auto tensor = chunk_hidden_states_l; auto tensor = chunk_hidden_states_l;
float original_mean = ggml_tensor_mean(tensor); float original_mean = ggml_ext_tensor_mean(tensor);
for (int i2 = 0; i2 < tensor->ne[2]; i2++) { for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
for (int i1 = 0; i1 < tensor->ne[1]; i1++) { for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) { for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float value = ggml_tensor_get_f32(tensor, i0, i1, i2); float value = ggml_ext_tensor_get_f32(tensor, i0, i1, i2);
value *= chunk_weights[i1]; value *= chunk_weights[i1];
ggml_tensor_set_f32(tensor, value, i0, i1, i2); ggml_ext_tensor_set_f32(tensor, value, i0, i1, i2);
} }
} }
} }
float new_mean = ggml_tensor_mean(tensor); float new_mean = ggml_ext_tensor_mean(tensor);
ggml_tensor_scale(tensor, (original_mean / new_mean)); ggml_ext_tensor_scale_inplace(tensor, (original_mean / new_mean));
} }
if (chunk_idx == 0) { if (chunk_idx == 0) {
@ -932,18 +932,18 @@ struct SD3CLIPEmbedder : public Conditioner {
{ {
auto tensor = chunk_hidden_states_g; auto tensor = chunk_hidden_states_g;
float original_mean = ggml_tensor_mean(tensor); float original_mean = ggml_ext_tensor_mean(tensor);
for (int i2 = 0; i2 < tensor->ne[2]; i2++) { for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
for (int i1 = 0; i1 < tensor->ne[1]; i1++) { for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) { for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float value = ggml_tensor_get_f32(tensor, i0, i1, i2); float value = ggml_ext_tensor_get_f32(tensor, i0, i1, i2);
value *= chunk_weights[i1]; value *= chunk_weights[i1];
ggml_tensor_set_f32(tensor, value, i0, i1, i2); ggml_ext_tensor_set_f32(tensor, value, i0, i1, i2);
} }
} }
} }
float new_mean = ggml_tensor_mean(tensor); float new_mean = ggml_ext_tensor_mean(tensor);
ggml_tensor_scale(tensor, (original_mean / new_mean)); ggml_ext_tensor_scale_inplace(tensor, (original_mean / new_mean));
} }
if (chunk_idx == 0) { if (chunk_idx == 0) {
@ -984,18 +984,18 @@ struct SD3CLIPEmbedder : public Conditioner {
work_ctx); work_ctx);
{ {
auto tensor = chunk_hidden_states_t5; auto tensor = chunk_hidden_states_t5;
float original_mean = ggml_tensor_mean(tensor); float original_mean = ggml_ext_tensor_mean(tensor);
for (int i2 = 0; i2 < tensor->ne[2]; i2++) { for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
for (int i1 = 0; i1 < tensor->ne[1]; i1++) { for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) { for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float value = ggml_tensor_get_f32(tensor, i0, i1, i2); float value = ggml_ext_tensor_get_f32(tensor, i0, i1, i2);
value *= chunk_weights[i1]; value *= chunk_weights[i1];
ggml_tensor_set_f32(tensor, value, i0, i1, i2); ggml_ext_tensor_set_f32(tensor, value, i0, i1, i2);
} }
} }
} }
float new_mean = ggml_tensor_mean(tensor); float new_mean = ggml_ext_tensor_mean(tensor);
ggml_tensor_scale(tensor, (original_mean / new_mean)); ggml_ext_tensor_scale_inplace(tensor, (original_mean / new_mean));
} }
} else { } else {
chunk_hidden_states_t5 = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, chunk_len); chunk_hidden_states_t5 = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, chunk_len);
@ -1013,19 +1013,19 @@ struct SD3CLIPEmbedder : public Conditioner {
for (int i0 = 0; i0 < chunk_hidden_states_lg_pad->ne[0]; i0++) { for (int i0 = 0; i0 < chunk_hidden_states_lg_pad->ne[0]; i0++) {
float value = 0.f; float value = 0.f;
if (i0 < chunk_hidden_states_l->ne[0]) { if (i0 < chunk_hidden_states_l->ne[0]) {
value = ggml_tensor_get_f32(chunk_hidden_states_l, i0, i1, i2); value = ggml_ext_tensor_get_f32(chunk_hidden_states_l, i0, i1, i2);
} else if (i0 < chunk_hidden_states_l->ne[0] + chunk_hidden_states_g->ne[0]) { } else if (i0 < chunk_hidden_states_l->ne[0] + chunk_hidden_states_g->ne[0]) {
value = ggml_tensor_get_f32(chunk_hidden_states_g, i0 - chunk_hidden_states_l->ne[0], i1, i2); value = ggml_ext_tensor_get_f32(chunk_hidden_states_g, i0 - chunk_hidden_states_l->ne[0], i1, i2);
} }
ggml_tensor_set_f32(chunk_hidden_states_lg_pad, value, i0, i1, i2); ggml_ext_tensor_set_f32(chunk_hidden_states_lg_pad, value, i0, i1, i2);
} }
} }
} }
chunk_hidden_states = ggml_tensor_concat(work_ctx, chunk_hidden_states_lg_pad, chunk_hidden_states_t5, 1); // [n_token*2, 4096] chunk_hidden_states = ggml_ext_tensor_concat(work_ctx, chunk_hidden_states_lg_pad, chunk_hidden_states_t5, 1); // [n_token*2, 4096]
if (chunk_idx == 0) { if (chunk_idx == 0) {
pooled = ggml_tensor_concat(work_ctx, pooled_l, pooled_g, 0); // [768 + 1280] pooled = ggml_ext_tensor_concat(work_ctx, pooled_l, pooled_g, 0); // [768 + 1280]
} }
int64_t t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();
@ -1269,18 +1269,18 @@ struct FluxCLIPEmbedder : public Conditioner {
work_ctx); work_ctx);
{ {
auto tensor = chunk_hidden_states; auto tensor = chunk_hidden_states;
float original_mean = ggml_tensor_mean(tensor); float original_mean = ggml_ext_tensor_mean(tensor);
for (int i2 = 0; i2 < tensor->ne[2]; i2++) { for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
for (int i1 = 0; i1 < tensor->ne[1]; i1++) { for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) { for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float value = ggml_tensor_get_f32(tensor, i0, i1, i2); float value = ggml_ext_tensor_get_f32(tensor, i0, i1, i2);
value *= chunk_weights[i1]; value *= chunk_weights[i1];
ggml_tensor_set_f32(tensor, value, i0, i1, i2); ggml_ext_tensor_set_f32(tensor, value, i0, i1, i2);
} }
} }
} }
float new_mean = ggml_tensor_mean(tensor); float new_mean = ggml_ext_tensor_mean(tensor);
ggml_tensor_scale(tensor, (original_mean / new_mean)); ggml_ext_tensor_scale_inplace(tensor, (original_mean / new_mean));
} }
} else { } else {
chunk_hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, chunk_len); chunk_hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, chunk_len);
@ -1483,18 +1483,18 @@ struct T5CLIPEmbedder : public Conditioner {
work_ctx); work_ctx);
{ {
auto tensor = chunk_hidden_states; auto tensor = chunk_hidden_states;
float original_mean = ggml_tensor_mean(tensor); float original_mean = ggml_ext_tensor_mean(tensor);
for (int i2 = 0; i2 < tensor->ne[2]; i2++) { for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
for (int i1 = 0; i1 < tensor->ne[1]; i1++) { for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) { for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float value = ggml_tensor_get_f32(tensor, i0, i1, i2); float value = ggml_ext_tensor_get_f32(tensor, i0, i1, i2);
value *= chunk_weights[i1]; value *= chunk_weights[i1];
ggml_tensor_set_f32(tensor, value, i0, i1, i2); ggml_ext_tensor_set_f32(tensor, value, i0, i1, i2);
} }
} }
} }
float new_mean = ggml_tensor_mean(tensor); float new_mean = ggml_ext_tensor_mean(tensor);
ggml_tensor_scale(tensor, (original_mean / new_mean)); ggml_ext_tensor_scale_inplace(tensor, (original_mean / new_mean));
} }
int64_t t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();
@ -1505,7 +1505,7 @@ struct T5CLIPEmbedder : public Conditioner {
for (int i1 = 0; i1 < tensor->ne[1]; i1++) { for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) { for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
if (chunk_mask[i1] < 0.f) { if (chunk_mask[i1] < 0.f) {
ggml_tensor_set_f32(tensor, 0.f, i0, i1, i2); ggml_ext_tensor_set_f32(tensor, 0.f, i0, i1, i2);
} }
} }
} }
@ -1664,7 +1664,7 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
image.data = nullptr; image.data = nullptr;
ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1); ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
sd_image_f32_to_tensor(resized_image, image_tensor, false); sd_image_f32_to_ggml_tensor(resized_image, image_tensor, false);
free(resized_image.data); free(resized_image.data);
resized_image.data = nullptr; resized_image.data = nullptr;
@ -1709,18 +1709,18 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
work_ctx); work_ctx);
{ {
auto tensor = hidden_states; auto tensor = hidden_states;
float original_mean = ggml_tensor_mean(tensor); float original_mean = ggml_ext_tensor_mean(tensor);
for (int i2 = 0; i2 < tensor->ne[2]; i2++) { for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
for (int i1 = 0; i1 < tensor->ne[1]; i1++) { for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) { for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float value = ggml_tensor_get_f32(tensor, i0, i1, i2); float value = ggml_ext_tensor_get_f32(tensor, i0, i1, i2);
value *= weights[i1]; value *= weights[i1];
ggml_tensor_set_f32(tensor, value, i0, i1, i2); ggml_ext_tensor_set_f32(tensor, value, i0, i1, i2);
} }
} }
} }
float new_mean = ggml_tensor_mean(tensor); float new_mean = ggml_ext_tensor_mean(tensor);
ggml_tensor_scale(tensor, (original_mean / new_mean)); ggml_ext_tensor_scale_inplace(tensor, (original_mean / new_mean));
} }
GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx); GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);
@ -1731,9 +1731,9 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
hidden_states->ne[1] - prompt_template_encode_start_idx, hidden_states->ne[1] - prompt_template_encode_start_idx,
hidden_states->ne[2]); hidden_states->ne[2]);
ggml_tensor_iter(new_hidden_states, [&](ggml_tensor* new_hidden_states, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { ggml_ext_tensor_iter(new_hidden_states, [&](ggml_tensor* new_hidden_states, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = ggml_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3); float value = ggml_ext_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3);
ggml_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3); ggml_ext_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3);
}); });
int64_t t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();

View File

@ -230,7 +230,7 @@ public:
auto middle_block_out = std::dynamic_pointer_cast<Conv2d>(blocks["middle_block_out.0"]); auto middle_block_out = std::dynamic_pointer_cast<Conv2d>(blocks["middle_block_out.0"]);
auto t_emb = ggml_nn_timestep_embedding(ctx, timesteps, model_channels); // [N, model_channels] auto t_emb = ggml_ext_timestep_embedding(ctx, timesteps, model_channels); // [N, model_channels]
auto emb = time_embed_0->forward(ctx, t_emb); auto emb = time_embed_0->forward(ctx, t_emb);
emb = ggml_silu_inplace(ctx, emb); emb = ggml_silu_inplace(ctx, emb);

View File

@ -401,8 +401,8 @@ struct CompVisDenoiser : public Denoiser {
// this function will modify noise/latent // this function will modify noise/latent
ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) override { ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) override {
ggml_tensor_scale(noise, sigma); ggml_ext_tensor_scale_inplace(noise, sigma);
ggml_tensor_add(latent, noise); ggml_ext_tensor_add_inplace(latent, noise);
return latent; return latent;
} }
@ -496,14 +496,14 @@ struct DiscreteFlowDenoiser : public Denoiser {
// this function will modify noise/latent // this function will modify noise/latent
ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) override { ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) override {
ggml_tensor_scale(noise, sigma); ggml_ext_tensor_scale_inplace(noise, sigma);
ggml_tensor_scale(latent, 1.0f - sigma); ggml_ext_tensor_scale_inplace(latent, 1.0f - sigma);
ggml_tensor_add(latent, noise); ggml_ext_tensor_add_inplace(latent, noise);
return latent; return latent;
} }
ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) override { ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) override {
ggml_tensor_scale(latent, 1.0f / (1.0f - sigma)); ggml_ext_tensor_scale_inplace(latent, 1.0f / (1.0f - sigma));
return latent; return latent;
} }
}; };
@ -555,14 +555,14 @@ struct FluxFlowDenoiser : public Denoiser {
// this function will modify noise/latent // this function will modify noise/latent
ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) override { ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) override {
ggml_tensor_scale(noise, sigma); ggml_ext_tensor_scale_inplace(noise, sigma);
ggml_tensor_scale(latent, 1.0f - sigma); ggml_ext_tensor_scale_inplace(latent, 1.0f - sigma);
ggml_tensor_add(latent, noise); ggml_ext_tensor_add_inplace(latent, noise);
return latent; return latent;
} }
ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) override { ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) override {
ggml_tensor_scale(latent, 1.0f / (1.0f - sigma)); ggml_ext_tensor_scale_inplace(latent, 1.0f / (1.0f - sigma));
return latent; return latent;
} }
}; };
@ -620,7 +620,7 @@ static void sample_k_diffusion(sample_method_t method,
if (sigmas[i + 1] > 0) { if (sigmas[i + 1] > 0) {
// x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up // x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
ggml_tensor_set_f32_randn(noise, rng); ggml_ext_im_set_randn_f32(noise, rng);
// noise = load_tensor_from_file(work_ctx, "./rand" + std::to_string(i+1) + ".bin"); // noise = load_tensor_from_file(work_ctx, "./rand" + std::to_string(i+1) + ".bin");
{ {
float* vec_x = (float*)x->data; float* vec_x = (float*)x->data;
@ -820,7 +820,7 @@ static void sample_k_diffusion(sample_method_t method,
// Noise addition // Noise addition
if (sigmas[i + 1] > 0) { if (sigmas[i + 1] > 0) {
ggml_tensor_set_f32_randn(noise, rng); ggml_ext_im_set_randn_f32(noise, rng);
{ {
float* vec_x = (float*)x->data; float* vec_x = (float*)x->data;
float* vec_noise = (float*)noise->data; float* vec_noise = (float*)noise->data;
@ -1085,7 +1085,7 @@ static void sample_k_diffusion(sample_method_t method,
if (sigmas[i + 1] > 0) { if (sigmas[i + 1] > 0) {
// x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1]) // x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
ggml_tensor_set_f32_randn(noise, rng); ggml_ext_im_set_randn_f32(noise, rng);
// noise = load_tensor_from_file(res_ctx, "./rand" + std::to_string(i+1) + ".bin"); // noise = load_tensor_from_file(res_ctx, "./rand" + std::to_string(i+1) + ".bin");
{ {
float* vec_x = (float*)x->data; float* vec_x = (float*)x->data;
@ -1276,7 +1276,7 @@ static void sample_k_diffusion(sample_method_t method,
} }
} }
if (eta > 0) { if (eta > 0) {
ggml_tensor_set_f32_randn(variance_noise, rng); ggml_ext_im_set_randn_f32(variance_noise, rng);
float* vec_variance_noise = float* vec_variance_noise =
(float*)variance_noise->data; (float*)variance_noise->data;
float* vec_x = (float*)x->data; float* vec_x = (float*)x->data;
@ -1444,7 +1444,7 @@ static void sample_k_diffusion(sample_method_t method,
if (eta > 0 && i != steps - 1) { if (eta > 0 && i != steps - 1) {
// In this case, x is still pred_noised_sample, // In this case, x is still pred_noised_sample,
// continue in-place // continue in-place
ggml_tensor_set_f32_randn(noise, rng); ggml_ext_im_set_randn_f32(noise, rng);
float* vec_x = (float*)x->data; float* vec_x = (float*)x->data;
float* vec_noise = (float*)noise->data; float* vec_noise = (float*)noise->data;
for (int j = 0; j < ggml_nelements(x); j++) { for (int j = 0; j < ggml_nelements(x); j++) {

View File

@ -596,16 +596,16 @@ namespace Flux {
int64_t hidden_size_x = x->ne[0]; int64_t hidden_size_x = x->ne[0];
auto mlp_params = param_generator->forward(ctx, s); auto mlp_params = param_generator->forward(ctx, s);
auto fc_params = ggml_chunk(ctx, mlp_params, 3, 0); auto fc_params = ggml_ext_chunk(ctx, mlp_params, 3, 0);
auto fc1_gate = ggml_reshape_3d(ctx, fc_params[0], hidden_size_x * mlp_ratio, hidden_size_x, batch_size); auto fc1_gate = ggml_reshape_3d(ctx, fc_params[0], hidden_size_x * mlp_ratio, hidden_size_x, batch_size);
auto fc1_value = ggml_reshape_3d(ctx, fc_params[1], hidden_size_x * mlp_ratio, hidden_size_x, batch_size); auto fc1_value = ggml_reshape_3d(ctx, fc_params[1], hidden_size_x * mlp_ratio, hidden_size_x, batch_size);
auto fc2 = ggml_reshape_3d(ctx, fc_params[2], hidden_size_x, mlp_ratio * hidden_size_x, batch_size); auto fc2 = ggml_reshape_3d(ctx, fc_params[2], hidden_size_x, mlp_ratio * hidden_size_x, batch_size);
fc1_gate = ggml_cont(ctx, ggml_torch_permute(ctx, fc1_gate, 1, 0, 2, 3)); // [batch_size, hidden_size_x*mlp_ratio, hidden_size_x] fc1_gate = ggml_cont(ctx, ggml_ext_torch_permute(ctx, fc1_gate, 1, 0, 2, 3)); // [batch_size, hidden_size_x*mlp_ratio, hidden_size_x]
fc1_gate = ggml_l2_norm(ctx, fc1_gate, 1e-12f); fc1_gate = ggml_l2_norm(ctx, fc1_gate, 1e-12f);
fc1_value = ggml_cont(ctx, ggml_torch_permute(ctx, fc1_value, 1, 0, 2, 3)); // [batch_size, hidden_size_x*mlp_ratio, hidden_size_x] fc1_value = ggml_cont(ctx, ggml_ext_torch_permute(ctx, fc1_value, 1, 0, 2, 3)); // [batch_size, hidden_size_x*mlp_ratio, hidden_size_x]
fc1_value = ggml_l2_norm(ctx, fc1_value, 1e-12f); fc1_value = ggml_l2_norm(ctx, fc1_value, 1e-12f);
fc2 = ggml_cont(ctx, ggml_torch_permute(ctx, fc2, 1, 0, 2, 3)); // [batch_size, hidden_size_x, hidden_size_x*mlp_ratio] fc2 = ggml_cont(ctx, ggml_ext_torch_permute(ctx, fc2, 1, 0, 2, 3)); // [batch_size, hidden_size_x, hidden_size_x*mlp_ratio]
fc2 = ggml_l2_norm(ctx, fc2, 1e-12f); fc2 = ggml_l2_norm(ctx, fc2, 1e-12f);
auto res_x = x; auto res_x = x;
@ -658,9 +658,9 @@ namespace Flux {
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["norm"]); auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["norm"]);
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]); auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 2, 0, 1, 3)); // [N, H, W, C] x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [N, H, W, C]
x = norm->forward(ctx, x); x = norm->forward(ctx, x);
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 1, 2, 0, 3)); // [N, C, H, W] x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 2, 0, 3)); // [N, C, H, W]
x = conv->forward(ctx, x); x = conv->forward(ctx, x);
return x; return x;
@ -851,13 +851,13 @@ namespace Flux {
if (params.is_chroma) { if (params.is_chroma) {
int64_t mod_index_length = 344; int64_t mod_index_length = 344;
auto approx = std::dynamic_pointer_cast<ChromaApproximator>(blocks["distilled_guidance_layer"]); auto approx = std::dynamic_pointer_cast<ChromaApproximator>(blocks["distilled_guidance_layer"]);
auto distill_timestep = ggml_nn_timestep_embedding(ctx, timesteps, 16, 10000, 1000.f); auto distill_timestep = ggml_ext_timestep_embedding(ctx, timesteps, 16, 10000, 1000.f);
auto distill_guidance = ggml_nn_timestep_embedding(ctx, guidance, 16, 10000, 1000.f); auto distill_guidance = ggml_ext_timestep_embedding(ctx, guidance, 16, 10000, 1000.f);
// auto mod_index_arange = ggml_arange(ctx, 0, (float)mod_index_length, 1); // auto mod_index_arange = ggml_arange(ctx, 0, (float)mod_index_length, 1);
// ggml_arange tot working on a lot of backends, precomputing it on CPU instead // ggml_arange tot working on a lot of backends, precomputing it on CPU instead
GGML_ASSERT(mod_index_arange != nullptr); GGML_ASSERT(mod_index_arange != nullptr);
auto modulation_index = ggml_nn_timestep_embedding(ctx, mod_index_arange, 32, 10000, 1000.f); // [1, 344, 32] auto modulation_index = ggml_ext_timestep_embedding(ctx, mod_index_arange, 32, 10000, 1000.f); // [1, 344, 32]
// Batch broadcast (will it ever be useful) // Batch broadcast (will it ever be useful)
modulation_index = ggml_repeat(ctx, modulation_index, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, modulation_index->ne[0], modulation_index->ne[1], img->ne[2])); // [N, 344, 32] modulation_index = ggml_repeat(ctx, modulation_index, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, modulation_index->ne[0], modulation_index->ne[1], img->ne[2])); // [N, 344, 32]
@ -876,12 +876,12 @@ namespace Flux {
} else { } else {
auto time_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["time_in"]); auto time_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["time_in"]);
auto vector_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["vector_in"]); auto vector_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["vector_in"]);
vec = time_in->forward(ctx, ggml_nn_timestep_embedding(ctx, timesteps, 256, 10000, 1000.f)); vec = time_in->forward(ctx, ggml_ext_timestep_embedding(ctx, timesteps, 256, 10000, 1000.f));
if (params.guidance_embed) { if (params.guidance_embed) {
GGML_ASSERT(guidance != nullptr); GGML_ASSERT(guidance != nullptr);
auto guidance_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["guidance_in"]); auto guidance_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["guidance_in"]);
// bf16 and fp16 result is different // bf16 and fp16 result is different
auto g_in = ggml_nn_timestep_embedding(ctx, guidance, 256, 10000, 1000.f); auto g_in = ggml_ext_timestep_embedding(ctx, guidance, 256, 10000, 1000.f);
vec = ggml_add(ctx, vec, guidance_in->forward(ctx, g_in)); vec = ggml_add(ctx, vec, guidance_in->forward(ctx, g_in));
} }
@ -959,7 +959,7 @@ namespace Flux {
img = img_in_patch->forward(ctx, img); // [N, hidden_size, H/patch_size, W/patch_size] img = img_in_patch->forward(ctx, img); // [N, hidden_size, H/patch_size, W/patch_size]
img = ggml_reshape_3d(ctx, img, img->ne[0] * img->ne[1], img->ne[2], img->ne[3]); // [N, hidden_size, H/patch_size*W/patch_size] img = ggml_reshape_3d(ctx, img, img->ne[0] * img->ne[1], img->ne[2], img->ne[3]); // [N, hidden_size, H/patch_size*W/patch_size]
img = ggml_cont(ctx, ggml_torch_permute(ctx, img, 1, 0, 2, 3)); // [N, H/patch_size*W/patch_size, hidden_size] img = ggml_cont(ctx, ggml_ext_torch_permute(ctx, img, 1, 0, 2, 3)); // [N, H/patch_size*W/patch_size, hidden_size]
auto out = forward_orig(ctx, backend, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, n_img_token, hidden_size] auto out = forward_orig(ctx, backend, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, n_img_token, hidden_size]
@ -973,8 +973,8 @@ namespace Flux {
nerf_pixels, nerf_pixels,
nerf_pixels->ne[0] / C, nerf_pixels->ne[0] / C,
C, C,
nerf_pixels->ne[1] * nerf_pixels->ne[2]); // [N*num_patches, C, patch_size*patch_size] nerf_pixels->ne[1] * nerf_pixels->ne[2]); // [N*num_patches, C, patch_size*patch_size]
nerf_pixels = ggml_cont(ctx, ggml_torch_permute(ctx, nerf_pixels, 1, 0, 2, 3)); // [N*num_patches, patch_size*patch_size, C] nerf_pixels = ggml_cont(ctx, ggml_ext_torch_permute(ctx, nerf_pixels, 1, 0, 2, 3)); // [N*num_patches, patch_size*patch_size, C]
auto nerf_hidden = ggml_reshape_2d(ctx, out, out->ne[0], out->ne[1] * out->ne[2]); // [N*num_patches, hidden_size] auto nerf_hidden = ggml_reshape_2d(ctx, out, out->ne[0], out->ne[1] * out->ne[2]); // [N*num_patches, hidden_size]
auto img_dct = nerf_image_embedder->forward(ctx, nerf_pixels, dct); // [N*num_patches, patch_size*patch_size, nerf_hidden_size] auto img_dct = nerf_image_embedder->forward(ctx, nerf_pixels, dct); // [N*num_patches, patch_size*patch_size, nerf_hidden_size]
@ -985,7 +985,7 @@ namespace Flux {
img_dct = block->forward(ctx, img_dct, nerf_hidden); img_dct = block->forward(ctx, img_dct, nerf_hidden);
} }
img_dct = ggml_cont(ctx, ggml_torch_permute(ctx, img_dct, 1, 0, 2, 3)); // [N*num_patches, nerf_hidden_size, patch_size*patch_size] img_dct = ggml_cont(ctx, ggml_ext_torch_permute(ctx, img_dct, 1, 0, 2, 3)); // [N*num_patches, nerf_hidden_size, patch_size*patch_size]
img_dct = ggml_reshape_3d(ctx, img_dct, img_dct->ne[0] * img_dct->ne[1], num_patches, img_dct->ne[2] / num_patches); // [N, num_patches, nerf_hidden_size*patch_size*patch_size] img_dct = ggml_reshape_3d(ctx, img_dct, img_dct->ne[0] * img_dct->ne[1], num_patches, img_dct->ne[2] / num_patches); // [N, num_patches, nerf_hidden_size*patch_size*patch_size]
img_dct = unpatchify(ctx, img_dct, (H + pad_h) / patch_size, (W + pad_w) / patch_size); // [N, nerf_hidden_size, H, W] img_dct = unpatchify(ctx, img_dct, (H + pad_h) / patch_size, (W + pad_w) / patch_size); // [N, nerf_hidden_size, H, W]

View File

@ -81,12 +81,12 @@ __STATIC_INLINE__ void ggml_log_callback_default(ggml_log_level level, const cha
static_assert(GGML_MAX_NAME >= 128, "GGML_MAX_NAME must be at least 128"); static_assert(GGML_MAX_NAME >= 128, "GGML_MAX_NAME must be at least 128");
// n-mode trensor-matrix product // n-mode tensor-matrix product
// example: 2-mode product // example: 2-mode product
// A: [ne03, k, ne01, ne00] // A: [ne03, k, ne01, ne00]
// B: k rows, m columns => [k, m] // B: k rows, m columns => [k, m]
// result is [ne03, m, ne01, ne00] // result is [ne03, m, ne01, ne00]
__STATIC_INLINE__ struct ggml_tensor* ggml_mul_n_mode(struct ggml_context* ctx, struct ggml_tensor* a, struct ggml_tensor* b, int mode = 0) { __STATIC_INLINE__ struct ggml_tensor* ggml_ext_mul_n_mode(struct ggml_context* ctx, struct ggml_tensor* a, struct ggml_tensor* b, int mode = 0) {
// reshape A // reshape A
// swap 0th and nth axis // swap 0th and nth axis
a = ggml_cont(ctx, ggml_permute(ctx, a, mode, mode != 1 ? 1 : 0, mode != 2 ? 2 : 0, mode != 3 ? 3 : 0)); a = ggml_cont(ctx, ggml_permute(ctx, a, mode, mode != 1 ? 1 : 0, mode != 2 ? 2 : 0, mode != 3 ? 3 : 0));
@ -105,7 +105,10 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_mul_n_mode(struct ggml_context* ctx,
return result; return result;
} }
__STATIC_INLINE__ struct ggml_tensor* ggml_merge_lora(ggml_context* ctx, struct ggml_tensor* lora_down, struct ggml_tensor* lora_up, struct ggml_tensor* lora_mid = nullptr) { __STATIC_INLINE__ struct ggml_tensor* ggml_ext_merge_lora(ggml_context* ctx,
ggml_tensor* lora_down,
ggml_tensor* lora_up,
ggml_tensor* lora_mid = nullptr) {
struct ggml_tensor* updown; struct ggml_tensor* updown;
// flat lora tensors to multiply it // flat lora tensors to multiply it
int64_t lora_up_rows = lora_up->ne[ggml_n_dims(lora_up) - 1]; int64_t lora_up_rows = lora_up->ne[ggml_n_dims(lora_up) - 1];
@ -127,7 +130,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_merge_lora(ggml_context* ctx, struct
// lora_down has shape (Rank, In, 1, 1) // lora_down has shape (Rank, In, 1, 1)
// lora_up has shape (Rank, Out, 1, 1) // lora_up has shape (Rank, Out, 1, 1)
// conv layer shape is (3, 3, Out, In) // conv layer shape is (3, 3, Out, In)
updown = ggml_mul_n_mode(ctx, ggml_mul_n_mode(ctx, lora_mid, lora_down, 3), lora_up, 2); updown = ggml_ext_mul_n_mode(ctx, ggml_ext_mul_n_mode(ctx, lora_mid, lora_down, 3), lora_up, 2);
updown = ggml_cont(ctx, updown); updown = ggml_cont(ctx, updown);
} }
return updown; return updown;
@ -135,7 +138,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_merge_lora(ggml_context* ctx, struct
// Kronecker product // Kronecker product
// [ne03,ne02,ne01,ne00] x [ne13,ne12,ne11,ne10] => [ne03*ne13,ne02*ne12,ne01*ne11,ne00*ne10] // [ne03,ne02,ne01,ne00] x [ne13,ne12,ne11,ne10] => [ne03*ne13,ne02*ne12,ne01*ne11,ne00*ne10]
__STATIC_INLINE__ struct ggml_tensor* ggml_kronecker(ggml_context* ctx, struct ggml_tensor* a, struct ggml_tensor* b) { __STATIC_INLINE__ struct ggml_tensor* ggml_ext_kronecker(ggml_context* ctx, struct ggml_tensor* a, struct ggml_tensor* b) {
return ggml_mul(ctx, return ggml_mul(ctx,
ggml_interpolate(ctx, ggml_interpolate(ctx,
a, a,
@ -147,7 +150,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_kronecker(ggml_context* ctx, struct g
b); b);
} }
__STATIC_INLINE__ void ggml_tensor_set_f32_randn(struct ggml_tensor* tensor, std::shared_ptr<RNG> rng) { __STATIC_INLINE__ void ggml_ext_im_set_randn_f32(struct ggml_tensor* tensor, std::shared_ptr<RNG> rng) {
uint32_t n = (uint32_t)ggml_nelements(tensor); uint32_t n = (uint32_t)ggml_nelements(tensor);
std::vector<float> random_numbers = rng->randn(n); std::vector<float> random_numbers = rng->randn(n);
for (uint32_t i = 0; i < n; i++) { for (uint32_t i = 0; i < n; i++) {
@ -155,38 +158,34 @@ __STATIC_INLINE__ void ggml_tensor_set_f32_randn(struct ggml_tensor* tensor, std
} }
} }
// set tensor[i, j, k, l] __STATIC_INLINE__ void ggml_ext_tensor_set_f32(struct ggml_tensor* tensor, float value, int i0, int i1 = 0, int i2 = 0, int i3 = 0) {
// set tensor[l]
// set tensor[k, l]
// set tensor[j, k, l]
__STATIC_INLINE__ void ggml_tensor_set_f32(struct ggml_tensor* tensor, float value, int l, int k = 0, int j = 0, int i = 0) {
GGML_ASSERT(tensor->nb[0] == sizeof(float)); GGML_ASSERT(tensor->nb[0] == sizeof(float));
*(float*)((char*)(tensor->data) + i * tensor->nb[3] + j * tensor->nb[2] + k * tensor->nb[1] + l * tensor->nb[0]) = value; *(float*)((char*)(tensor->data) + i3 * tensor->nb[3] + i2 * tensor->nb[2] + i1 * tensor->nb[1] + i0 * tensor->nb[0]) = value;
} }
__STATIC_INLINE__ float ggml_tensor_get_f32(const ggml_tensor* tensor, int l, int k = 0, int j = 0, int i = 0) { __STATIC_INLINE__ float ggml_ext_tensor_get_f32(const ggml_tensor* tensor, int i0, int i1 = 0, int i2 = 0, int i3 = 0) {
if (tensor->buffer != nullptr) { if (tensor->buffer != nullptr) {
float value; float value;
ggml_backend_tensor_get(tensor, &value, i * tensor->nb[3] + j * tensor->nb[2] + k * tensor->nb[1] + l * tensor->nb[0], sizeof(float)); ggml_backend_tensor_get(tensor, &value, i3 * tensor->nb[3] + i2 * tensor->nb[2] + i1 * tensor->nb[1] + i0 * tensor->nb[0], sizeof(float));
return value; return value;
} }
GGML_ASSERT(tensor->nb[0] == sizeof(float)); GGML_ASSERT(tensor->nb[0] == sizeof(float));
return *(float*)((char*)(tensor->data) + i * tensor->nb[3] + j * tensor->nb[2] + k * tensor->nb[1] + l * tensor->nb[0]); return *(float*)((char*)(tensor->data) + i3 * tensor->nb[3] + i2 * tensor->nb[2] + i1 * tensor->nb[1] + i0 * tensor->nb[0]);
} }
__STATIC_INLINE__ int ggml_tensor_get_i32(const ggml_tensor* tensor, int l, int k = 0, int j = 0, int i = 0) { __STATIC_INLINE__ int ggml_ext_tensor_get_i32(const ggml_tensor* tensor, int i0, int i1 = 0, int i2 = 0, int i3 = 0) {
if (tensor->buffer != nullptr) { if (tensor->buffer != nullptr) {
float value; float value;
ggml_backend_tensor_get(tensor, &value, i * tensor->nb[3] + j * tensor->nb[2] + k * tensor->nb[1] + l * tensor->nb[0], sizeof(int)); ggml_backend_tensor_get(tensor, &value, i3 * tensor->nb[3] + i2 * tensor->nb[2] + i1 * tensor->nb[1] + i0 * tensor->nb[0], sizeof(int));
return value; return value;
} }
GGML_ASSERT(tensor->nb[0] == sizeof(int)); GGML_ASSERT(tensor->nb[0] == sizeof(int));
return *(int*)((char*)(tensor->data) + i * tensor->nb[3] + j * tensor->nb[2] + k * tensor->nb[1] + l * tensor->nb[0]); return *(int*)((char*)(tensor->data) + i3 * tensor->nb[3] + i2 * tensor->nb[2] + i1 * tensor->nb[1] + i0 * tensor->nb[0]);
} }
__STATIC_INLINE__ ggml_fp16_t ggml_tensor_get_f16(const ggml_tensor* tensor, int l, int k = 0, int j = 0, int i = 0) { __STATIC_INLINE__ ggml_fp16_t ggml_ext_tensor_get_f16(const ggml_tensor* tensor, int i0, int i1 = 0, int i2 = 0, int i3 = 0) {
GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
return *(ggml_fp16_t*)((char*)(tensor->data) + i * tensor->nb[3] + j * tensor->nb[2] + k * tensor->nb[1] + l * tensor->nb[0]); return *(ggml_fp16_t*)((char*)(tensor->data) + i3 * tensor->nb[3] + i2 * tensor->nb[2] + i1 * tensor->nb[1] + i0 * tensor->nb[0]);
} }
__STATIC_INLINE__ float sd_image_get_f32(sd_image_t image, int iw, int ih, int ic, bool scale = true) { __STATIC_INLINE__ float sd_image_get_f32(sd_image_t image, int iw, int ih, int ic, bool scale = true) {
@ -212,28 +211,28 @@ __STATIC_INLINE__ void print_ggml_tensor(struct ggml_tensor* tensor, bool shape_
return; return;
} }
int range = 3; int range = 3;
for (int i = 0; i < tensor->ne[3]; i++) { for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
if (i >= range && i + range < tensor->ne[3]) { if (i3 >= range && i3 + range < tensor->ne[3]) {
continue; continue;
} }
for (int j = 0; j < tensor->ne[2]; j++) { for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
if (j >= range && j + range < tensor->ne[2]) { if (i2 >= range && i2 + range < tensor->ne[2]) {
continue; continue;
} }
for (int k = 0; k < tensor->ne[1]; k++) { for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
if (k >= range && k + range < tensor->ne[1]) { if (i1 >= range && i1 + range < tensor->ne[1]) {
continue; continue;
} }
for (int l = 0; l < tensor->ne[0]; l++) { for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
if (l >= range && l + range < tensor->ne[0]) { if (i0 >= range && i0 + range < tensor->ne[0]) {
continue; continue;
} }
if (tensor->type == GGML_TYPE_F32) { if (tensor->type == GGML_TYPE_F32) {
printf(" [%d, %d, %d, %d] = %f\n", i, j, k, l, ggml_tensor_get_f32(tensor, l, k, j, i)); printf(" [%d, %d, %d, %d] = %f\n", i3, i2, i1, i0, ggml_ext_tensor_get_f32(tensor, i0, i1, i2, i3));
} else if (tensor->type == GGML_TYPE_F16) { } else if (tensor->type == GGML_TYPE_F16) {
printf(" [%d, %d, %d, %d] = %f\n", i, j, k, l, ggml_fp16_to_fp32(ggml_tensor_get_f16(tensor, l, k, j, i))); printf(" [%d, %d, %d, %d] = %f\n", i3, i2, i1, i0, ggml_fp16_to_fp32(ggml_ext_tensor_get_f16(tensor, i0, i1, i2, i3)));
} else if (tensor->type == GGML_TYPE_I32) { } else if (tensor->type == GGML_TYPE_I32) {
printf(" [%d, %d, %d, %d] = %i\n", i, j, k, l, ggml_tensor_get_i32(tensor, l, k, j, i)); printf(" [%d, %d, %d, %d] = %i3\n", i3, i2, i1, i0, ggml_ext_tensor_get_i32(tensor, i0, i1, i2, i3));
} }
fflush(stdout); fflush(stdout);
} }
@ -242,7 +241,7 @@ __STATIC_INLINE__ void print_ggml_tensor(struct ggml_tensor* tensor, bool shape_
} }
} }
__STATIC_INLINE__ void ggml_tensor_iter( __STATIC_INLINE__ void ggml_ext_tensor_iter(
ggml_tensor* tensor, ggml_tensor* tensor,
const std::function<void(ggml_tensor*, int64_t, int64_t, int64_t, int64_t)>& fn) { const std::function<void(ggml_tensor*, int64_t, int64_t, int64_t, int64_t)>& fn) {
int64_t n0 = tensor->ne[0]; int64_t n0 = tensor->ne[0];
@ -261,7 +260,7 @@ __STATIC_INLINE__ void ggml_tensor_iter(
} }
} }
__STATIC_INLINE__ void ggml_tensor_iter( __STATIC_INLINE__ void ggml_ext_tensor_iter(
ggml_tensor* tensor, ggml_tensor* tensor,
const std::function<void(ggml_tensor*, int64_t)>& fn) { const std::function<void(ggml_tensor*, int64_t)>& fn) {
int64_t n0 = tensor->ne[0]; int64_t n0 = tensor->ne[0];
@ -274,14 +273,14 @@ __STATIC_INLINE__ void ggml_tensor_iter(
} }
} }
__STATIC_INLINE__ void ggml_tensor_diff( __STATIC_INLINE__ void ggml_ext_tensor_diff(
ggml_tensor* a, ggml_tensor* a,
ggml_tensor* b, ggml_tensor* b,
float gap = 0.1f) { float gap = 0.1f) {
GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b)); GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
ggml_tensor_iter(a, [&](ggml_tensor* a, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { ggml_ext_tensor_iter(a, [&](ggml_tensor* a, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float a_value = ggml_tensor_get_f32(a, i0, i1, i2, i3); float a_value = ggml_ext_tensor_get_f32(a, i0, i1, i2, i3);
float b_value = ggml_tensor_get_f32(b, i0, i1, i2, i3); float b_value = ggml_ext_tensor_get_f32(b, i0, i1, i2, i3);
if (abs(a_value - b_value) > gap) { if (abs(a_value - b_value) > gap) {
LOG_WARN("[%ld, %ld, %ld, %ld] %f %f", i3, i2, i1, i0, a_value, b_value); LOG_WARN("[%ld, %ld, %ld, %ld] %f %f", i3, i2, i1, i0, a_value, b_value);
} }
@ -375,7 +374,7 @@ __STATIC_INLINE__ float sigmoid(float x) {
// SPECIAL OPERATIONS WITH TENSORS // SPECIAL OPERATIONS WITH TENSORS
__STATIC_INLINE__ uint8_t* sd_tensor_to_image(struct ggml_tensor* input, uint8_t* image_data = nullptr) { __STATIC_INLINE__ uint8_t* ggml_tensor_to_sd_image(struct ggml_tensor* input, uint8_t* image_data = nullptr) {
int64_t width = input->ne[0]; int64_t width = input->ne[0];
int64_t height = input->ne[1]; int64_t height = input->ne[1];
int64_t channels = input->ne[2]; int64_t channels = input->ne[2];
@ -386,7 +385,7 @@ __STATIC_INLINE__ uint8_t* sd_tensor_to_image(struct ggml_tensor* input, uint8_t
for (int iy = 0; iy < height; iy++) { for (int iy = 0; iy < height; iy++) {
for (int ix = 0; ix < width; ix++) { for (int ix = 0; ix < width; ix++) {
for (int k = 0; k < channels; k++) { for (int k = 0; k < channels; k++) {
float value = ggml_tensor_get_f32(input, ix, iy, k); float value = ggml_ext_tensor_get_f32(input, ix, iy, k);
*(image_data + iy * width * channels + ix * channels + k) = (uint8_t)(value * 255.0f); *(image_data + iy * width * channels + ix * channels + k) = (uint8_t)(value * 255.0f);
} }
} }
@ -394,7 +393,7 @@ __STATIC_INLINE__ uint8_t* sd_tensor_to_image(struct ggml_tensor* input, uint8_t
return image_data; return image_data;
} }
__STATIC_INLINE__ uint8_t* sd_tensor_to_image(struct ggml_tensor* input, int idx, bool video = false) { __STATIC_INLINE__ uint8_t* ggml_tensor_to_sd_image(struct ggml_tensor* input, int idx, bool video = false) {
int64_t width = input->ne[0]; int64_t width = input->ne[0];
int64_t height = input->ne[1]; int64_t height = input->ne[1];
int64_t channels; int64_t channels;
@ -410,9 +409,9 @@ __STATIC_INLINE__ uint8_t* sd_tensor_to_image(struct ggml_tensor* input, int idx
for (int ic = 0; ic < channels; ic++) { for (int ic = 0; ic < channels; ic++) {
float value; float value;
if (video) { if (video) {
value = ggml_tensor_get_f32(input, iw, ih, idx, ic); value = ggml_ext_tensor_get_f32(input, iw, ih, idx, ic);
} else { } else {
value = ggml_tensor_get_f32(input, iw, ih, ic, idx); value = ggml_ext_tensor_get_f32(input, iw, ih, ic, idx);
} }
*(image_data + ih * width * channels + iw * channels + ic) = (uint8_t)(value * 255.0f); *(image_data + ih * width * channels + iw * channels + ic) = (uint8_t)(value * 255.0f);
} }
@ -421,24 +420,24 @@ __STATIC_INLINE__ uint8_t* sd_tensor_to_image(struct ggml_tensor* input, int idx
return image_data; return image_data;
} }
__STATIC_INLINE__ void sd_image_to_tensor(sd_image_t image, __STATIC_INLINE__ void sd_image_to_ggml_tensor(sd_image_t image,
ggml_tensor* tensor, ggml_tensor* tensor,
bool scale = true) { bool scale = true) {
GGML_ASSERT(image.width == tensor->ne[0]); GGML_ASSERT(image.width == tensor->ne[0]);
GGML_ASSERT(image.height == tensor->ne[1]); GGML_ASSERT(image.height == tensor->ne[1]);
GGML_ASSERT(image.channel == tensor->ne[2]); GGML_ASSERT(image.channel == tensor->ne[2]);
GGML_ASSERT(1 == tensor->ne[3]); GGML_ASSERT(1 == tensor->ne[3]);
GGML_ASSERT(tensor->type == GGML_TYPE_F32); GGML_ASSERT(tensor->type == GGML_TYPE_F32);
ggml_tensor_iter(tensor, [&](ggml_tensor* tensor, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { ggml_ext_tensor_iter(tensor, [&](ggml_tensor* tensor, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = sd_image_get_f32(image, i0, i1, i2, scale); float value = sd_image_get_f32(image, i0, i1, i2, scale);
ggml_tensor_set_f32(tensor, value, i0, i1, i2, i3); ggml_ext_tensor_set_f32(tensor, value, i0, i1, i2, i3);
}); });
} }
__STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data, __STATIC_INLINE__ void ggml_ext_tensor_apply_mask(struct ggml_tensor* image_data,
struct ggml_tensor* mask, struct ggml_tensor* mask,
struct ggml_tensor* output, struct ggml_tensor* output,
float masked_value = 0.5f) { float masked_value = 0.5f) {
int64_t width = output->ne[0]; int64_t width = output->ne[0];
int64_t height = output->ne[1]; int64_t height = output->ne[1];
int64_t channels = output->ne[2]; int64_t channels = output->ne[2];
@ -449,36 +448,36 @@ __STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
for (int iy = 0; iy < height; iy++) { for (int iy = 0; iy < height; iy++) {
int mx = (int)(ix * rescale_mx); int mx = (int)(ix * rescale_mx);
int my = (int)(iy * rescale_my); int my = (int)(iy * rescale_my);
float m = ggml_tensor_get_f32(mask, mx, my); float m = ggml_ext_tensor_get_f32(mask, mx, my);
m = round(m); // inpaint models need binary masks m = round(m); // inpaint models need binary masks
ggml_tensor_set_f32(mask, m, mx, my); ggml_ext_tensor_set_f32(mask, m, mx, my);
for (int k = 0; k < channels; k++) { for (int k = 0; k < channels; k++) {
float value = ggml_tensor_get_f32(image_data, ix, iy, k); float value = ggml_ext_tensor_get_f32(image_data, ix, iy, k);
value = (1 - m) * (value - masked_value) + masked_value; value = (1 - m) * (value - masked_value) + masked_value;
ggml_tensor_set_f32(output, value, ix, iy, k); ggml_ext_tensor_set_f32(output, value, ix, iy, k);
} }
} }
} }
} }
__STATIC_INLINE__ void sd_image_f32_to_tensor(sd_image_f32_t image, __STATIC_INLINE__ void sd_image_f32_to_ggml_tensor(sd_image_f32_t image,
ggml_tensor* tensor, ggml_tensor* tensor,
bool scale = true) { bool scale = true) {
GGML_ASSERT(image.width == tensor->ne[0]); GGML_ASSERT(image.width == tensor->ne[0]);
GGML_ASSERT(image.height == tensor->ne[1]); GGML_ASSERT(image.height == tensor->ne[1]);
GGML_ASSERT(image.channel == tensor->ne[2]); GGML_ASSERT(image.channel == tensor->ne[2]);
GGML_ASSERT(1 == tensor->ne[3]); GGML_ASSERT(1 == tensor->ne[3]);
GGML_ASSERT(tensor->type == GGML_TYPE_F32); GGML_ASSERT(tensor->type == GGML_TYPE_F32);
ggml_tensor_iter(tensor, [&](ggml_tensor* tensor, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { ggml_ext_tensor_iter(tensor, [&](ggml_tensor* tensor, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = sd_image_get_f32(image, i0, i1, i2, scale); float value = sd_image_get_f32(image, i0, i1, i2, scale);
ggml_tensor_set_f32(tensor, value, i0, i1, i2, i3); ggml_ext_tensor_set_f32(tensor, value, i0, i1, i2, i3);
}); });
} }
__STATIC_INLINE__ void ggml_split_tensor_2d(struct ggml_tensor* input, __STATIC_INLINE__ void ggml_ext_tensor_split_2d(struct ggml_tensor* input,
struct ggml_tensor* output, struct ggml_tensor* output,
int x, int x,
int y) { int y) {
int64_t width = output->ne[0]; int64_t width = output->ne[0];
int64_t height = output->ne[1]; int64_t height = output->ne[1];
int64_t channels = output->ne[2]; int64_t channels = output->ne[2];
@ -488,8 +487,8 @@ __STATIC_INLINE__ void ggml_split_tensor_2d(struct ggml_tensor* input,
for (int ix = 0; ix < width; ix++) { for (int ix = 0; ix < width; ix++) {
for (int k = 0; k < channels; k++) { for (int k = 0; k < channels; k++) {
for (int l = 0; l < ne3; l++) { for (int l = 0; l < ne3; l++) {
float value = ggml_tensor_get_f32(input, ix + x, iy + y, k, l); float value = ggml_ext_tensor_get_f32(input, ix + x, iy + y, k, l);
ggml_tensor_set_f32(output, value, ix, iy, k, l); ggml_ext_tensor_set_f32(output, value, ix, iy, k, l);
} }
} }
} }
@ -497,19 +496,19 @@ __STATIC_INLINE__ void ggml_split_tensor_2d(struct ggml_tensor* input,
} }
// unclamped -> expects x in the range [0-1] // unclamped -> expects x in the range [0-1]
__STATIC_INLINE__ float ggml_smootherstep_f32(const float x) { __STATIC_INLINE__ float smootherstep_f32(const float x) {
GGML_ASSERT(x >= 0.f && x <= 1.f); GGML_ASSERT(x >= 0.f && x <= 1.f);
return x * x * x * (x * (6.0f * x - 15.0f) + 10.0f); return x * x * x * (x * (6.0f * x - 15.0f) + 10.0f);
} }
__STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input, __STATIC_INLINE__ void ggml_ext_tensor_merge_2d(struct ggml_tensor* input,
struct ggml_tensor* output, struct ggml_tensor* output,
int x, int x,
int y, int y,
int overlap_x, int overlap_x,
int overlap_y, int overlap_y,
int x_skip = 0, int x_skip = 0,
int y_skip = 0) { int y_skip = 0) {
int64_t width = input->ne[0]; int64_t width = input->ne[0];
int64_t height = input->ne[1]; int64_t height = input->ne[1];
int64_t channels = input->ne[2]; int64_t channels = input->ne[2];
@ -523,9 +522,9 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input,
for (int ix = x_skip; ix < width; ix++) { for (int ix = x_skip; ix < width; ix++) {
for (int k = 0; k < channels; k++) { for (int k = 0; k < channels; k++) {
for (int l = 0; l < ne3; l++) { for (int l = 0; l < ne3; l++) {
float new_value = ggml_tensor_get_f32(input, ix, iy, k, l); float new_value = ggml_ext_tensor_get_f32(input, ix, iy, k, l);
if (overlap_x > 0 || overlap_y > 0) { // blend colors in overlapped area if (overlap_x > 0 || overlap_y > 0) { // blend colors in overlapped area
float old_value = ggml_tensor_get_f32(output, x + ix, y + iy, k, l); float old_value = ggml_ext_tensor_get_f32(output, x + ix, y + iy, k, l);
const float x_f_0 = (overlap_x > 0 && x > 0) ? (ix - x_skip) / float(overlap_x) : 1; const float x_f_0 = (overlap_x > 0 && x > 0) ? (ix - x_skip) / float(overlap_x) : 1;
const float x_f_1 = (overlap_x > 0 && x < (img_width - width)) ? (width - ix) / float(overlap_x) : 1; const float x_f_1 = (overlap_x > 0 && x < (img_width - width)) ? (width - ix) / float(overlap_x) : 1;
@ -535,12 +534,12 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input,
const float x_f = std::min(std::min(x_f_0, x_f_1), 1.f); const float x_f = std::min(std::min(x_f_0, x_f_1), 1.f);
const float y_f = std::min(std::min(y_f_0, y_f_1), 1.f); const float y_f = std::min(std::min(y_f_0, y_f_1), 1.f);
ggml_tensor_set_f32( ggml_ext_tensor_set_f32(
output, output,
old_value + new_value * ggml_smootherstep_f32(y_f) * ggml_smootherstep_f32(x_f), old_value + new_value * smootherstep_f32(y_f) * smootherstep_f32(x_f),
x + ix, y + iy, k, l); x + ix, y + iy, k, l);
} else { } else {
ggml_tensor_set_f32(output, new_value, x + ix, y + iy, k, l); ggml_ext_tensor_set_f32(output, new_value, x + ix, y + iy, k, l);
} }
} }
} }
@ -548,7 +547,7 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input,
} }
} }
__STATIC_INLINE__ float ggml_tensor_mean(struct ggml_tensor* src) { __STATIC_INLINE__ float ggml_ext_tensor_mean(struct ggml_tensor* src) {
float mean = 0.0f; float mean = 0.0f;
int64_t nelements = ggml_nelements(src); int64_t nelements = ggml_nelements(src);
float* data = (float*)src->data; float* data = (float*)src->data;
@ -559,7 +558,7 @@ __STATIC_INLINE__ float ggml_tensor_mean(struct ggml_tensor* src) {
} }
// a = a+b // a = a+b
__STATIC_INLINE__ void ggml_tensor_add(struct ggml_tensor* a, struct ggml_tensor* b) { __STATIC_INLINE__ void ggml_ext_tensor_add_inplace(struct ggml_tensor* a, struct ggml_tensor* b) {
GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b)); GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
int64_t nelements = ggml_nelements(a); int64_t nelements = ggml_nelements(a);
float* vec_a = (float*)a->data; float* vec_a = (float*)a->data;
@ -569,7 +568,7 @@ __STATIC_INLINE__ void ggml_tensor_add(struct ggml_tensor* a, struct ggml_tensor
} }
} }
__STATIC_INLINE__ void ggml_tensor_scale(struct ggml_tensor* src, float scale) { __STATIC_INLINE__ void ggml_ext_tensor_scale_inplace(struct ggml_tensor* src, float scale) {
int64_t nelements = ggml_nelements(src); int64_t nelements = ggml_nelements(src);
float* data = (float*)src->data; float* data = (float*)src->data;
for (int i = 0; i < nelements; i++) { for (int i = 0; i < nelements; i++) {
@ -577,7 +576,7 @@ __STATIC_INLINE__ void ggml_tensor_scale(struct ggml_tensor* src, float scale) {
} }
} }
__STATIC_INLINE__ void ggml_tensor_clamp(struct ggml_tensor* src, float min, float max) { __STATIC_INLINE__ void ggml_ext_tensor_clamp_inplace(struct ggml_tensor* src, float min, float max) {
int64_t nelements = ggml_nelements(src); int64_t nelements = ggml_nelements(src);
float* data = (float*)src->data; float* data = (float*)src->data;
for (int i = 0; i < nelements; i++) { for (int i = 0; i < nelements; i++) {
@ -586,10 +585,10 @@ __STATIC_INLINE__ void ggml_tensor_clamp(struct ggml_tensor* src, float min, flo
} }
} }
__STATIC_INLINE__ struct ggml_tensor* ggml_tensor_concat(struct ggml_context* ctx, __STATIC_INLINE__ struct ggml_tensor* ggml_ext_tensor_concat(struct ggml_context* ctx,
struct ggml_tensor* a, struct ggml_tensor* a,
struct ggml_tensor* b, struct ggml_tensor* b,
int dim) { int dim) {
int64_t ne[GGML_MAX_DIMS]; int64_t ne[GGML_MAX_DIMS];
for (int d = 0; d < GGML_MAX_DIMS; ++d) { for (int d = 0; d < GGML_MAX_DIMS; ++d) {
if (d == dim) { if (d == dim) {
@ -609,12 +608,12 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_tensor_concat(struct ggml_context* ct
for (int i1 = 0; i1 < result->ne[1]; i1++) { for (int i1 = 0; i1 < result->ne[1]; i1++) {
for (int i0 = 0; i0 < result->ne[0]; i0++) { for (int i0 = 0; i0 < result->ne[0]; i0++) {
if (i0 < a->ne[0] && i1 < a->ne[1] && i2 < a->ne[2] && i3 < a->ne[3]) { if (i0 < a->ne[0] && i1 < a->ne[1] && i2 < a->ne[2] && i3 < a->ne[3]) {
v = ggml_tensor_get_f32(a, i0, i1, i2, i3); v = ggml_ext_tensor_get_f32(a, i0, i1, i2, i3);
} else { } else {
v = ggml_tensor_get_f32(b, i0 - o[0], i1 - o[1], i2 - o[2], i3 - o[3]); v = ggml_ext_tensor_get_f32(b, i0 - o[0], i1 - o[1], i2 - o[2], i3 - o[3]);
} }
ggml_tensor_set_f32(result, v, i0, i1, i2, i3); ggml_ext_tensor_set_f32(result, v, i0, i1, i2, i3);
} }
} }
} }
@ -642,8 +641,8 @@ __STATIC_INLINE__ void process_vae_output_tensor(struct ggml_tensor* src) {
} }
} }
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_cont(struct ggml_context* ctx, __STATIC_INLINE__ struct ggml_tensor* ggml_ext_cont(struct ggml_context* ctx,
struct ggml_tensor* x) { struct ggml_tensor* x) {
if (ggml_is_contiguous(x)) { if (ggml_is_contiguous(x)) {
return x; return x;
} }
@ -651,12 +650,12 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_cont(struct ggml_context* ctx,
} }
// torch like permute // torch like permute
__STATIC_INLINE__ struct ggml_tensor* ggml_torch_permute(struct ggml_context* ctx, __STATIC_INLINE__ struct ggml_tensor* ggml_ext_torch_permute(struct ggml_context* ctx,
struct ggml_tensor* x, struct ggml_tensor* x,
int axis0, int axis0,
int axis1, int axis1,
int axis2, int axis2,
int axis3) { int axis3) {
int torch_axes[4] = {axis0, axis1, axis2, axis3}; int torch_axes[4] = {axis0, axis1, axis2, axis3};
int ggml_axes[4] = {0}; int ggml_axes[4] = {0};
@ -675,11 +674,11 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_torch_permute(struct ggml_context* ct
return ggml_permute(ctx, x, ggml_axes[0], ggml_axes[1], ggml_axes[2], ggml_axes[3]); return ggml_permute(ctx, x, ggml_axes[0], ggml_axes[1], ggml_axes[2], ggml_axes[3]);
} }
__STATIC_INLINE__ struct ggml_tensor* ggml_slice(struct ggml_context* ctx, __STATIC_INLINE__ struct ggml_tensor* ggml_ext_slice(struct ggml_context* ctx,
struct ggml_tensor* x, struct ggml_tensor* x,
int64_t dim, int64_t dim,
int64_t start, int64_t start,
int64_t end) { int64_t end) {
GGML_ASSERT(dim >= 0 && dim < 4); GGML_ASSERT(dim >= 0 && dim < 4);
if (x->ne[dim] == 1) { if (x->ne[dim] == 1) {
return x; return x;
@ -704,7 +703,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_slice(struct ggml_context* ctx,
inv_perm[perm[i]] = i; inv_perm[perm[i]] = i;
if (dim != 3) { if (dim != 3) {
x = ggml_torch_permute(ctx, x, perm[0], perm[1], perm[2], perm[3]); x = ggml_ext_torch_permute(ctx, x, perm[0], perm[1], perm[2], perm[3]);
x = ggml_cont(ctx, x); x = ggml_cont(ctx, x);
} }
@ -714,7 +713,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_slice(struct ggml_context* ctx,
x->nb[1], x->nb[2], x->nb[3], x->nb[3] * start); x->nb[1], x->nb[2], x->nb[3], x->nb[3] * start);
if (dim != 3) { if (dim != 3) {
x = ggml_torch_permute(ctx, x, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]); x = ggml_ext_torch_permute(ctx, x, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]);
x = ggml_cont(ctx, x); x = ggml_cont(ctx, x);
} }
@ -722,10 +721,10 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_slice(struct ggml_context* ctx,
} }
// example: [N, 3*C, H, W] => ([N, C, H, W], [N, C, H, W], [N, C, H, W]) // example: [N, 3*C, H, W] => ([N, C, H, W], [N, C, H, W], [N, C, H, W])
__STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_chunk(struct ggml_context* ctx, __STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_ext_chunk(struct ggml_context* ctx,
struct ggml_tensor* x, struct ggml_tensor* x,
int num, int num,
int64_t dim) { int64_t dim) {
GGML_ASSERT(dim >= 0 && dim < 4); GGML_ASSERT(dim >= 0 && dim < 4);
GGML_ASSERT(x->ne[dim] % num == 0); GGML_ASSERT(x->ne[dim] % num == 0);
@ -739,7 +738,7 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_chunk(struct ggml_contex
inv_perm[perm[i]] = i; inv_perm[perm[i]] = i;
if (dim != 3) { if (dim != 3) {
x = ggml_torch_permute(ctx, x, perm[0], perm[1], perm[2], perm[3]); x = ggml_ext_torch_permute(ctx, x, perm[0], perm[1], perm[2], perm[3]);
x = ggml_cont(ctx, x); x = ggml_cont(ctx, x);
} }
@ -752,7 +751,7 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_chunk(struct ggml_contex
x->nb[1], x->nb[2], x->nb[3], x->nb[3] * i * chunk_size); x->nb[1], x->nb[2], x->nb[3], x->nb[3] * i * chunk_size);
if (dim != 3) { if (dim != 3) {
chunk = ggml_torch_permute(ctx, chunk, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]); chunk = ggml_ext_torch_permute(ctx, chunk, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]);
chunk = ggml_cont(ctx, chunk); chunk = ggml_cont(ctx, chunk);
} }
chunks.push_back(chunk); chunks.push_back(chunk);
@ -913,9 +912,9 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
int overlap_y_out = decode ? tile_overlap_y * scale : tile_overlap_y; int overlap_y_out = decode ? tile_overlap_y * scale : tile_overlap_y;
int64_t t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();
ggml_split_tensor_2d(input, input_tile, x_in, y_in); ggml_ext_tensor_split_2d(input, input_tile, x_in, y_in);
on_processing(input_tile, output_tile, false); on_processing(input_tile, output_tile, false);
ggml_merge_tensor_2d(output_tile, output, x_out, y_out, overlap_x_out, overlap_y_out, dx, dy); ggml_ext_tensor_merge_2d(output_tile, output, x_out, y_out, overlap_x_out, overlap_y_out, dx, dy);
int64_t t2 = ggml_time_ms(); int64_t t2 = ggml_time_ms();
last_time = (t2 - t1) / 1000.0f; last_time = (t2 - t1) / 1000.0f;
@ -939,18 +938,18 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input,
sd_tiling_non_square(input, output, scale, tile_size, tile_size, tile_overlap_factor, on_processing); sd_tiling_non_square(input, output, scale, tile_size, tile_size, tile_overlap_factor, on_processing);
} }
__STATIC_INLINE__ struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ctx, __STATIC_INLINE__ struct ggml_tensor* ggml_ext_group_norm_32(struct ggml_context* ctx,
struct ggml_tensor* a) { struct ggml_tensor* a) {
const float eps = 1e-6f; // default eps parameter const float eps = 1e-6f; // default eps parameter
return ggml_group_norm(ctx, a, 32, eps); return ggml_group_norm(ctx, a, 32, eps);
} }
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_linear(struct ggml_context* ctx, __STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx,
struct ggml_tensor* x, struct ggml_tensor* x,
struct ggml_tensor* w, struct ggml_tensor* w,
struct ggml_tensor* b, struct ggml_tensor* b,
bool force_prec_f32 = false, bool force_prec_f32 = false,
float scale = 1.f) { float scale = 1.f) {
if (scale != 1.f) { if (scale != 1.f) {
x = ggml_scale(ctx, x, scale); x = ggml_scale(ctx, x, scale);
} }
@ -980,18 +979,18 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_linear(struct ggml_context* ctx,
// x: [N, IC, IH, IW] // x: [N, IC, IH, IW]
// b: [OC,] // b: [OC,]
// result: [N, OC, OH, OW] // result: [N, OC, OH, OW]
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d(struct ggml_context* ctx, __STATIC_INLINE__ struct ggml_tensor* ggml_ext_conv_2d(struct ggml_context* ctx,
struct ggml_tensor* x, struct ggml_tensor* x,
struct ggml_tensor* w, struct ggml_tensor* w,
struct ggml_tensor* b, struct ggml_tensor* b,
int s0 = 1, int s0 = 1,
int s1 = 1, int s1 = 1,
int p0 = 0, int p0 = 0,
int p1 = 0, int p1 = 0,
int d0 = 1, int d0 = 1,
int d1 = 1, int d1 = 1,
bool direct = false, bool direct = false,
float scale = 1.f) { float scale = 1.f) {
if (scale != 1.f) { if (scale != 1.f) {
x = ggml_scale(ctx, x, scale); x = ggml_scale(ctx, x, scale);
} }
@ -1014,20 +1013,20 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d(struct ggml_context* ctx,
// x: [N, IC, IH, IW] // x: [N, IC, IH, IW]
// b: [OC,] // b: [OC,]
// result: [N*OC, OD, OH, OW] // result: [N*OC, OD, OH, OW]
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_3d(struct ggml_context* ctx, __STATIC_INLINE__ struct ggml_tensor* ggml_ext_conv_3d(struct ggml_context* ctx,
struct ggml_tensor* x, struct ggml_tensor* x,
struct ggml_tensor* w, struct ggml_tensor* w,
struct ggml_tensor* b, struct ggml_tensor* b,
int64_t IC, int64_t IC,
int s0 = 1, int s0 = 1,
int s1 = 1, int s1 = 1,
int s2 = 1, int s2 = 1,
int p0 = 0, int p0 = 0,
int p1 = 0, int p1 = 0,
int p2 = 0, int p2 = 0,
int d0 = 1, int d0 = 1,
int d1 = 1, int d1 = 1,
int d2 = 1) { int d2 = 1) {
int64_t OC = w->ne[3] / IC; int64_t OC = w->ne[3] / IC;
int64_t N = x->ne[3] / IC; int64_t N = x->ne[3] / IC;
x = ggml_conv_3d(ctx, w, x, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2); x = ggml_conv_3d(ctx, w, x, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2);
@ -1043,13 +1042,13 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_3d(struct ggml_context* ctx,
// x: [N, IC, ID, IH*IW] // x: [N, IC, ID, IH*IW]
// b: [OC,] // b: [OC,]
// result: [N, OC, OD, OH*OW] // result: [N, OC, OD, OH*OW]
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_3d_nx1x1(struct ggml_context* ctx, __STATIC_INLINE__ struct ggml_tensor* ggml_ext_conv_3d_nx1x1(struct ggml_context* ctx,
struct ggml_tensor* x, struct ggml_tensor* x,
struct ggml_tensor* w, struct ggml_tensor* w,
struct ggml_tensor* b, struct ggml_tensor* b,
int s2 = 1, int s2 = 1,
int p2 = 1, int p2 = 1,
int d2 = 1) { int d2 = 1) {
x = ggml_conv_2d(ctx, w, x, 1, s2, 0, p2, 1, d2); // [N, OC, T, OH * OW] x = ggml_conv_2d(ctx, w, x, 1, s2, 0, p2, 1, d2); // [N, OC, T, OH * OW]
if (b != nullptr) { if (b != nullptr) {
b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1); b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
@ -1082,8 +1081,8 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> split_image_qkv(struct ggml_c
int64_t N = qkv->ne[3]; int64_t N = qkv->ne[3];
int64_t nb1 = qkv->nb[1]; int64_t nb1 = qkv->nb[1];
int64_t nb2 = qkv->nb[2]; int64_t nb2 = qkv->nb[2];
qkv = ggml_reshape_4d(ctx, qkv, W * H, C, 3, N); // [N, 3, C, H*W] qkv = ggml_reshape_4d(ctx, qkv, W * H, C, 3, N); // [N, 3, C, H*W]
qkv = ggml_cont(ctx, ggml_torch_permute(ctx, qkv, 0, 1, 3, 2)); // [3, N, C, H*W] qkv = ggml_cont(ctx, ggml_ext_torch_permute(ctx, qkv, 0, 1, 3, 2)); // [3, N, C, H*W]
int64_t offset = qkv->nb[2] * qkv->ne[2]; int64_t offset = qkv->nb[2] * qkv->ne[2];
auto q = ggml_view_4d(ctx, qkv, W, H, C, N, nb1, nb2, qkv->nb[3], offset * 0); // [N, C, H, W] auto q = ggml_view_4d(ctx, qkv, W, H, C, N, nb1, nb2, qkv->nb[3], offset * 0); // [N, C, H, W]
@ -1092,43 +1091,43 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> split_image_qkv(struct ggml_c
return {q, k, v}; return {q, k, v};
} }
__STATIC_INLINE__ struct ggml_tensor* ggml_full(struct ggml_context* ctx, __STATIC_INLINE__ struct ggml_tensor* ggml_ext_full(struct ggml_context* ctx,
float value, float value,
int64_t ne0, int64_t ne0,
int64_t ne1, int64_t ne1,
int64_t ne2, int64_t ne2,
int64_t ne3) { int64_t ne3) {
auto one = ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:one"); auto one = ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:one");
auto t = ggml_scale(ctx, one, value); // [1,] auto t = ggml_scale(ctx, one, value); // [1,]
t = ggml_repeat_4d(ctx, t, ne0, ne1, ne2, ne3); // [ne0, ne1, ne2, ne3] t = ggml_repeat_4d(ctx, t, ne0, ne1, ne2, ne3); // [ne0, ne1, ne2, ne3]
return t; return t;
} }
__STATIC_INLINE__ struct ggml_tensor* ggml_zeros(struct ggml_context* ctx, __STATIC_INLINE__ struct ggml_tensor* ggml_ext_zeros(struct ggml_context* ctx,
int64_t ne0, int64_t ne0,
int64_t ne1, int64_t ne1,
int64_t ne2, int64_t ne2,
int64_t ne3) { int64_t ne3) {
return ggml_full(ctx, 0.f, ne0, ne1, ne2, ne3); return ggml_ext_full(ctx, 0.f, ne0, ne1, ne2, ne3);
} }
__STATIC_INLINE__ struct ggml_tensor* ggml_ones(struct ggml_context* ctx, __STATIC_INLINE__ struct ggml_tensor* ggml_ext_ones(struct ggml_context* ctx,
int64_t ne0, int64_t ne0,
int64_t ne1, int64_t ne1,
int64_t ne2, int64_t ne2,
int64_t ne3) { int64_t ne3) {
return ggml_full(ctx, 1.f, ne0, ne1, ne2, ne3); return ggml_ext_full(ctx, 1.f, ne0, ne1, ne2, ne3);
} }
// q: [N * n_head, n_token, d_head] // q: [N * n_head, n_token, d_head]
// k: [N * n_head, n_k, d_head] // k: [N * n_head, n_k, d_head]
// v: [N * n_head, d_head, n_k] // v: [N * n_head, d_head, n_k]
// return: [N * n_head, n_token, d_head] // return: [N * n_head, n_token, d_head]
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx, __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention(struct ggml_context* ctx,
struct ggml_tensor* q, struct ggml_tensor* q,
struct ggml_tensor* k, struct ggml_tensor* k,
struct ggml_tensor* v, struct ggml_tensor* v,
bool mask = false) { bool mask = false) {
#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUDA) && !defined(SD_USE_METAL) && !defined(SD_USE_VULKAN) && !defined(SD_USE_SYCL) #if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUDA) && !defined(SD_USE_METAL) && !defined(SD_USE_VULKAN) && !defined(SD_USE_SYCL)
struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, n_token, d_head] struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, n_token, d_head]
#else #else
@ -1149,17 +1148,17 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx
// v: [N, L_k, n_kv_head*d_head] or [N, L_k, n_kv_head, d_head] // v: [N, L_k, n_kv_head*d_head] or [N, L_k, n_kv_head, d_head]
// mask: [N, L_q, L_k] // mask: [N, L_q, L_k]
// return: [N, L_q, C] // return: [N, L_q, C]
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* ctx, __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context* ctx,
ggml_backend_t backend, ggml_backend_t backend,
struct ggml_tensor* q, struct ggml_tensor* q,
struct ggml_tensor* k, struct ggml_tensor* k,
struct ggml_tensor* v, struct ggml_tensor* v,
int64_t n_head, int64_t n_head,
struct ggml_tensor* mask = nullptr, struct ggml_tensor* mask = nullptr,
bool diag_mask_inf = false, bool diag_mask_inf = false,
bool skip_reshape = false, bool skip_reshape = false,
bool flash_attn = false, // avoid overflow bool flash_attn = false, // avoid overflow
float kv_scale = 1.0f) { float kv_scale = 1.0f) {
int64_t L_q; int64_t L_q;
int64_t L_k; int64_t L_k;
int64_t C; int64_t C;
@ -1174,13 +1173,13 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
d_head = C / n_head; d_head = C / n_head;
n_kv_head = k->ne[0] / d_head; n_kv_head = k->ne[0] / d_head;
q = ggml_reshape_4d(ctx, q, d_head, n_head, L_q, N); // [N, L_q, n_head, d_head] q = ggml_reshape_4d(ctx, q, d_head, n_head, L_q, N); // [N, L_q, n_head, d_head]
q = ggml_nn_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, L_q, d_head] q = ggml_ext_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, L_q, d_head]
q = ggml_reshape_3d(ctx, q, d_head, L_q, n_head * N); // [N * n_head, L_q, d_head] q = ggml_reshape_3d(ctx, q, d_head, L_q, n_head * N); // [N * n_head, L_q, d_head]
k = ggml_reshape_4d(ctx, k, d_head, n_kv_head, L_k, N); // [N, L_k, n_kv_head, d_head] k = ggml_reshape_4d(ctx, k, d_head, n_kv_head, L_k, N); // [N, L_k, n_kv_head, d_head]
k = ggml_nn_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_kv_head, L_k, d_head] k = ggml_ext_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_kv_head, L_k, d_head]
k = ggml_reshape_3d(ctx, k, d_head, L_k, n_kv_head * N); // [N * n_kv_head, L_k, d_head] k = ggml_reshape_3d(ctx, k, d_head, L_k, n_kv_head * N); // [N * n_kv_head, L_k, d_head]
v = ggml_reshape_4d(ctx, v, d_head, n_kv_head, L_k, N); // [N, L_k, n_kv_head, d_head] v = ggml_reshape_4d(ctx, v, d_head, n_kv_head, L_k, N); // [N, L_k, n_kv_head, d_head]
} else { } else {
@ -1206,7 +1205,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
} }
k_in = ggml_cast(ctx, k_in, GGML_TYPE_F16); k_in = ggml_cast(ctx, k_in, GGML_TYPE_F16);
v_in = ggml_nn_cont(ctx, ggml_permute(ctx, v_in, 0, 2, 1, 3)); v_in = ggml_ext_cont(ctx, ggml_permute(ctx, v_in, 0, 2, 1, 3));
v_in = ggml_reshape_3d(ctx, v_in, d_head, L_k, n_kv_head * N); v_in = ggml_reshape_3d(ctx, v_in, d_head, L_k, n_kv_head * N);
if (kv_pad != 0) { if (kv_pad != 0) {
v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0); v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0);
@ -1220,8 +1219,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
mask_in = ggml_transpose(ctx, mask_in); mask_in = ggml_transpose(ctx, mask_in);
} else { } else {
if (kv_pad > 0) { if (kv_pad > 0) {
mask_in = ggml_zeros(ctx, L_k, L_q, 1, 1); mask_in = ggml_ext_zeros(ctx, L_k, L_q, 1, 1);
auto pad_tensor = ggml_full(ctx, -INFINITY, kv_pad, L_q, 1, 1); auto pad_tensor = ggml_ext_full(ctx, -INFINITY, kv_pad, L_q, 1, 1);
mask_in = ggml_concat(ctx, mask_in, pad_tensor, 0); mask_in = ggml_concat(ctx, mask_in, pad_tensor, 0);
} }
} }
@ -1271,8 +1270,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
// if (flash_attn) { // if (flash_attn) {
// LOG_DEBUG("fallback to default attention, L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N); // LOG_DEBUG("fallback to default attention, L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
// } // }
v = ggml_nn_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_kv_head, d_head, L_k] v = ggml_ext_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_kv_head, d_head, L_k]
v = ggml_reshape_3d(ctx, v, L_k, d_head, n_kv_head * N); // [N * n_kv_head, d_head, L_k] v = ggml_reshape_3d(ctx, v, L_k, d_head, n_kv_head * N); // [N * n_kv_head, d_head, L_k]
auto kq = ggml_mul_mat(ctx, k, q); // [N * n_head, L_q, L_k] auto kq = ggml_mul_mat(ctx, k, q); // [N * n_head, L_q, L_k]
kq = ggml_scale_inplace(ctx, kq, scale); kq = ggml_scale_inplace(ctx, kq, scale);
@ -1290,17 +1289,17 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
kqv = ggml_permute(ctx, kqv, 0, 2, 1, 3); // [N, L_q, n_head, d_head] kqv = ggml_permute(ctx, kqv, 0, 2, 1, 3); // [N, L_q, n_head, d_head]
} }
kqv = ggml_nn_cont(ctx, kqv); kqv = ggml_ext_cont(ctx, kqv);
kqv = ggml_reshape_3d(ctx, kqv, d_head * n_head, L_q, N); // [N, L_q, C] kqv = ggml_reshape_3d(ctx, kqv, d_head * n_head, L_q, N); // [N, L_q, C]
return kqv; return kqv;
} }
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_layer_norm(struct ggml_context* ctx, __STATIC_INLINE__ struct ggml_tensor* ggml_ext_layer_norm(struct ggml_context* ctx,
struct ggml_tensor* x, struct ggml_tensor* x,
struct ggml_tensor* w, struct ggml_tensor* w,
struct ggml_tensor* b, struct ggml_tensor* b,
float eps = EPS) { float eps = EPS) {
x = ggml_norm(ctx, x, eps); x = ggml_norm(ctx, x, eps);
if (w != nullptr) { if (w != nullptr) {
x = ggml_mul_inplace(ctx, x, w); x = ggml_mul_inplace(ctx, x, w);
@ -1311,11 +1310,11 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_layer_norm(struct ggml_context* ct
return x; return x;
} }
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_group_norm(struct ggml_context* ctx, __STATIC_INLINE__ struct ggml_tensor* ggml_ext_group_norm(struct ggml_context* ctx,
struct ggml_tensor* x, struct ggml_tensor* x,
struct ggml_tensor* w, struct ggml_tensor* w,
struct ggml_tensor* b, struct ggml_tensor* b,
int num_groups = 32) { int num_groups = 32) {
if (ggml_n_dims(x) >= 3 && w != nullptr && b != nullptr) { if (ggml_n_dims(x) >= 3 && w != nullptr && b != nullptr) {
w = ggml_reshape_4d(ctx, w, 1, 1, w->ne[0], 1); w = ggml_reshape_4d(ctx, w, 1, 1, w->ne[0], 1);
b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1); b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
@ -1331,7 +1330,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_group_norm(struct ggml_context* ct
return x; return x;
} }
__STATIC_INLINE__ void ggml_backend_tensor_get_and_sync(ggml_backend_t backend, const struct ggml_tensor* tensor, void* data, size_t offset, size_t size) { __STATIC_INLINE__ void ggml_ext_backend_tensor_get_and_sync(ggml_backend_t backend, const struct ggml_tensor* tensor, void* data, size_t offset, size_t size) {
#if defined(SD_USE_CUDA) || defined(SD_USE_SYCL) #if defined(SD_USE_CUDA) || defined(SD_USE_SYCL)
if (!ggml_backend_is_cpu(backend)) { if (!ggml_backend_is_cpu(backend)) {
ggml_backend_tensor_get_async(backend, tensor, data, offset, size); ggml_backend_tensor_get_async(backend, tensor, data, offset, size);
@ -1344,7 +1343,7 @@ __STATIC_INLINE__ void ggml_backend_tensor_get_and_sync(ggml_backend_t backend,
#endif #endif
} }
__STATIC_INLINE__ float ggml_backend_tensor_get_f32(ggml_tensor* tensor) { __STATIC_INLINE__ float ggml_ext_backend_tensor_get_f32(ggml_tensor* tensor) {
GGML_ASSERT(tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_I32); GGML_ASSERT(tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_I32);
float value; float value;
if (tensor->type == GGML_TYPE_F32) { if (tensor->type == GGML_TYPE_F32) {
@ -1439,7 +1438,7 @@ __STATIC_INLINE__ struct ggml_tensor* new_timestep_embedding(struct ggml_context
return embedding; return embedding;
} }
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_timestep_embedding( __STATIC_INLINE__ struct ggml_tensor* ggml_ext_timestep_embedding(
struct ggml_context* ctx, struct ggml_context* ctx,
struct ggml_tensor* timesteps, struct ggml_tensor* timesteps,
int dim, int dim,
@ -1857,7 +1856,7 @@ public:
*output = ggml_dup_tensor(output_ctx, result); *output = ggml_dup_tensor(output_ctx, result);
} }
if (*output != nullptr) { if (*output != nullptr) {
ggml_backend_tensor_get_and_sync(runtime_backend, result, (*output)->data, 0, ggml_nbytes(*output)); ggml_ext_backend_tensor_get_and_sync(runtime_backend, result, (*output)->data, 0, ggml_nbytes(*output));
} }
} }
@ -2007,7 +2006,7 @@ public:
if (bias) { if (bias) {
b = params["bias"]; b = params["bias"];
} }
return ggml_nn_linear(ctx, x, w, b, force_prec_f32, scale); return ggml_ext_linear(ctx, x, w, b, force_prec_f32, scale);
} }
}; };
@ -2111,18 +2110,18 @@ public:
if (bias) { if (bias) {
b = params["bias"]; b = params["bias"];
} }
return ggml_nn_conv_2d(ctx, return ggml_ext_conv_2d(ctx,
x, x,
w, w,
b, b,
stride.second, stride.second,
stride.first, stride.first,
padding.second, padding.second,
padding.first, padding.first,
dilation.second, dilation.second,
dilation.first, dilation.first,
direct, direct,
scale); scale);
} }
}; };
@ -2169,7 +2168,7 @@ public:
if (bias) { if (bias) {
b = params["bias"]; b = params["bias"];
} }
return ggml_nn_conv_3d_nx1x1(ctx, x, w, b, stride, padding, dilation); return ggml_ext_conv_3d_nx1x1(ctx, x, w, b, stride, padding, dilation);
} }
}; };
@ -2218,10 +2217,10 @@ public:
if (bias) { if (bias) {
b = params["bias"]; b = params["bias"];
} }
return ggml_nn_conv_3d(ctx, x, w, b, in_channels, return ggml_ext_conv_3d(ctx, x, w, b, in_channels,
std::get<2>(stride), std::get<1>(stride), std::get<0>(stride), std::get<2>(stride), std::get<1>(stride), std::get<0>(stride),
std::get<2>(padding), std::get<1>(padding), std::get<0>(padding), std::get<2>(padding), std::get<1>(padding), std::get<0>(padding),
std::get<2>(dilation), std::get<1>(dilation), std::get<0>(dilation)); std::get<2>(dilation), std::get<1>(dilation), std::get<0>(dilation));
} }
}; };
@ -2263,7 +2262,7 @@ public:
b = params["bias"]; b = params["bias"];
} }
} }
return ggml_nn_layer_norm(ctx, x, w, b, eps); return ggml_ext_layer_norm(ctx, x, w, b, eps);
} }
}; };
@ -2300,7 +2299,7 @@ public:
w = params["weight"]; w = params["weight"];
b = params["bias"]; b = params["bias"];
} }
return ggml_nn_group_norm(ctx, x, w, b, num_groups); return ggml_ext_group_norm(ctx, x, w, b, num_groups);
} }
}; };
@ -2378,7 +2377,7 @@ public:
struct ggml_tensor* k = k_proj->forward(ctx, x); struct ggml_tensor* k = k_proj->forward(ctx, x);
struct ggml_tensor* v = v_proj->forward(ctx, x); struct ggml_tensor* v = v_proj->forward(ctx, x);
x = ggml_nn_attention_ext(ctx, backend, q, k, v, n_head, nullptr, mask); // [N, n_token, embed_dim] x = ggml_ext_attention_ext(ctx, backend, q, k, v, n_head, nullptr, mask); // [N, n_token, embed_dim]
x = out_proj->forward(ctx, x); // [N, n_token, embed_dim] x = out_proj->forward(ctx, x); // [N, n_token, embed_dim]
return x; return x;

View File

@ -372,15 +372,15 @@ struct LoraModel : public GGMLRunner {
continue; continue;
} }
struct ggml_tensor* updown_1 = ggml_merge_lora(compute_ctx, hada_1_down, hada_1_up, hada_1_mid); struct ggml_tensor* updown_1 = ggml_ext_merge_lora(compute_ctx, hada_1_down, hada_1_up, hada_1_mid);
struct ggml_tensor* updown_2 = ggml_merge_lora(compute_ctx, hada_2_down, hada_2_up, hada_2_mid); struct ggml_tensor* updown_2 = ggml_ext_merge_lora(compute_ctx, hada_2_down, hada_2_up, hada_2_mid);
updown = ggml_mul_inplace(compute_ctx, updown_1, updown_2); updown = ggml_mul_inplace(compute_ctx, updown_1, updown_2);
// calc_scale // calc_scale
// TODO: .dora_scale? // TODO: .dora_scale?
int64_t rank = hada_1_down->ne[ggml_n_dims(hada_1_down) - 1]; int64_t rank = hada_1_down->ne[ggml_n_dims(hada_1_down) - 1];
if (lora_tensors.find(alpha_name) != lora_tensors.end()) { if (lora_tensors.find(alpha_name) != lora_tensors.end()) {
float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]); float alpha = ggml_ext_backend_tensor_get_f32(lora_tensors[alpha_name]);
scale_value = alpha / rank; scale_value = alpha / rank;
} }
} else if (lora_tensors.find(full_key + ".lokr_w1") != lora_tensors.end() || lora_tensors.find(full_key + ".lokr_w1_a") != lora_tensors.end()) { } else if (lora_tensors.find(full_key + ".lokr_w1") != lora_tensors.end() || lora_tensors.find(full_key + ".lokr_w1_a") != lora_tensors.end()) {
@ -418,7 +418,7 @@ struct LoraModel : public GGMLRunner {
int64_t rank = down->ne[ggml_n_dims(down) - 1]; int64_t rank = down->ne[ggml_n_dims(down) - 1];
if (lora_tensors.find(alpha_name) != lora_tensors.end()) { if (lora_tensors.find(alpha_name) != lora_tensors.end()) {
float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]); float alpha = ggml_ext_backend_tensor_get_f32(lora_tensors[alpha_name]);
scale_value = alpha / rank; scale_value = alpha / rank;
} }
} }
@ -426,7 +426,7 @@ struct LoraModel : public GGMLRunner {
up = to_f32(compute_ctx, lora_tensors[up_name]); up = to_f32(compute_ctx, lora_tensors[up_name]);
applied_lora_tensors.insert(up_name); applied_lora_tensors.insert(up_name);
} }
lokr_w1 = ggml_merge_lora(compute_ctx, down, up); lokr_w1 = ggml_ext_merge_lora(compute_ctx, down, up);
} }
if (lora_tensors.find(lokr_w2_name) != lora_tensors.end()) { if (lora_tensors.find(lokr_w2_name) != lora_tensors.end()) {
lokr_w2 = to_f32(compute_ctx, lora_tensors[lokr_w2_name]); lokr_w2 = to_f32(compute_ctx, lora_tensors[lokr_w2_name]);
@ -442,7 +442,7 @@ struct LoraModel : public GGMLRunner {
int64_t rank = down->ne[ggml_n_dims(down) - 1]; int64_t rank = down->ne[ggml_n_dims(down) - 1];
if (lora_tensors.find(alpha_name) != lora_tensors.end()) { if (lora_tensors.find(alpha_name) != lora_tensors.end()) {
float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]); float alpha = ggml_ext_backend_tensor_get_f32(lora_tensors[alpha_name]);
scale_value = alpha / rank; scale_value = alpha / rank;
} }
} }
@ -450,13 +450,13 @@ struct LoraModel : public GGMLRunner {
up = to_f32(compute_ctx, lora_tensors[up_name]); up = to_f32(compute_ctx, lora_tensors[up_name]);
applied_lora_tensors.insert(up_name); applied_lora_tensors.insert(up_name);
} }
lokr_w2 = ggml_merge_lora(compute_ctx, down, up); lokr_w2 = ggml_ext_merge_lora(compute_ctx, down, up);
} }
// Technically it might be unused, but I believe it's the expected behavior // Technically it might be unused, but I believe it's the expected behavior
applied_lora_tensors.insert(alpha_name); applied_lora_tensors.insert(alpha_name);
updown = ggml_kronecker(compute_ctx, lokr_w1, lokr_w2); updown = ggml_ext_kronecker(compute_ctx, lokr_w1, lokr_w2);
} else { } else {
// LoRA mode // LoRA mode
@ -535,30 +535,30 @@ struct LoraModel : public GGMLRunner {
float lora_v_scale = 1; float lora_v_scale = 1;
if (lora_tensors.find(split_q_scale_name) != lora_tensors.end()) { if (lora_tensors.find(split_q_scale_name) != lora_tensors.end()) {
lora_q_scale = ggml_backend_tensor_get_f32(lora_tensors[split_q_scale_name]); lora_q_scale = ggml_ext_backend_tensor_get_f32(lora_tensors[split_q_scale_name]);
applied_lora_tensors.insert(split_q_scale_name); applied_lora_tensors.insert(split_q_scale_name);
} }
if (lora_tensors.find(split_k_scale_name) != lora_tensors.end()) { if (lora_tensors.find(split_k_scale_name) != lora_tensors.end()) {
lora_k_scale = ggml_backend_tensor_get_f32(lora_tensors[split_k_scale_name]); lora_k_scale = ggml_ext_backend_tensor_get_f32(lora_tensors[split_k_scale_name]);
applied_lora_tensors.insert(split_k_scale_name); applied_lora_tensors.insert(split_k_scale_name);
} }
if (lora_tensors.find(split_v_scale_name) != lora_tensors.end()) { if (lora_tensors.find(split_v_scale_name) != lora_tensors.end()) {
lora_v_scale = ggml_backend_tensor_get_f32(lora_tensors[split_v_scale_name]); lora_v_scale = ggml_ext_backend_tensor_get_f32(lora_tensors[split_v_scale_name]);
applied_lora_tensors.insert(split_v_scale_name); applied_lora_tensors.insert(split_v_scale_name);
} }
if (lora_tensors.find(split_q_alpha_name) != lora_tensors.end()) { if (lora_tensors.find(split_q_alpha_name) != lora_tensors.end()) {
float lora_q_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_q_alpha_name]); float lora_q_alpha = ggml_ext_backend_tensor_get_f32(lora_tensors[split_q_alpha_name]);
applied_lora_tensors.insert(split_q_alpha_name); applied_lora_tensors.insert(split_q_alpha_name);
lora_q_scale = lora_q_alpha / q_rank; lora_q_scale = lora_q_alpha / q_rank;
} }
if (lora_tensors.find(split_k_alpha_name) != lora_tensors.end()) { if (lora_tensors.find(split_k_alpha_name) != lora_tensors.end()) {
float lora_k_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_k_alpha_name]); float lora_k_alpha = ggml_ext_backend_tensor_get_f32(lora_tensors[split_k_alpha_name]);
applied_lora_tensors.insert(split_k_alpha_name); applied_lora_tensors.insert(split_k_alpha_name);
lora_k_scale = lora_k_alpha / k_rank; lora_k_scale = lora_k_alpha / k_rank;
} }
if (lora_tensors.find(split_v_alpha_name) != lora_tensors.end()) { if (lora_tensors.find(split_v_alpha_name) != lora_tensors.end()) {
float lora_v_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_v_alpha_name]); float lora_v_alpha = ggml_ext_backend_tensor_get_f32(lora_tensors[split_v_alpha_name]);
applied_lora_tensors.insert(split_v_alpha_name); applied_lora_tensors.insert(split_v_alpha_name);
lora_v_scale = lora_v_alpha / v_rank; lora_v_scale = lora_v_alpha / v_rank;
} }
@ -688,39 +688,39 @@ struct LoraModel : public GGMLRunner {
float lora_m_scale = 1; float lora_m_scale = 1;
if (lora_tensors.find(split_q_scale_name) != lora_tensors.end()) { if (lora_tensors.find(split_q_scale_name) != lora_tensors.end()) {
lora_q_scale = ggml_backend_tensor_get_f32(lora_tensors[split_q_scale_name]); lora_q_scale = ggml_ext_backend_tensor_get_f32(lora_tensors[split_q_scale_name]);
applied_lora_tensors.insert(split_q_scale_name); applied_lora_tensors.insert(split_q_scale_name);
} }
if (lora_tensors.find(split_k_scale_name) != lora_tensors.end()) { if (lora_tensors.find(split_k_scale_name) != lora_tensors.end()) {
lora_k_scale = ggml_backend_tensor_get_f32(lora_tensors[split_k_scale_name]); lora_k_scale = ggml_ext_backend_tensor_get_f32(lora_tensors[split_k_scale_name]);
applied_lora_tensors.insert(split_k_scale_name); applied_lora_tensors.insert(split_k_scale_name);
} }
if (lora_tensors.find(split_v_scale_name) != lora_tensors.end()) { if (lora_tensors.find(split_v_scale_name) != lora_tensors.end()) {
lora_v_scale = ggml_backend_tensor_get_f32(lora_tensors[split_v_scale_name]); lora_v_scale = ggml_ext_backend_tensor_get_f32(lora_tensors[split_v_scale_name]);
applied_lora_tensors.insert(split_v_scale_name); applied_lora_tensors.insert(split_v_scale_name);
} }
if (lora_tensors.find(split_m_scale_name) != lora_tensors.end()) { if (lora_tensors.find(split_m_scale_name) != lora_tensors.end()) {
lora_m_scale = ggml_backend_tensor_get_f32(lora_tensors[split_m_scale_name]); lora_m_scale = ggml_ext_backend_tensor_get_f32(lora_tensors[split_m_scale_name]);
applied_lora_tensors.insert(split_m_scale_name); applied_lora_tensors.insert(split_m_scale_name);
} }
if (lora_tensors.find(split_q_alpha_name) != lora_tensors.end()) { if (lora_tensors.find(split_q_alpha_name) != lora_tensors.end()) {
float lora_q_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_q_alpha_name]); float lora_q_alpha = ggml_ext_backend_tensor_get_f32(lora_tensors[split_q_alpha_name]);
applied_lora_tensors.insert(split_q_alpha_name); applied_lora_tensors.insert(split_q_alpha_name);
lora_q_scale = lora_q_alpha / q_rank; lora_q_scale = lora_q_alpha / q_rank;
} }
if (lora_tensors.find(split_k_alpha_name) != lora_tensors.end()) { if (lora_tensors.find(split_k_alpha_name) != lora_tensors.end()) {
float lora_k_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_k_alpha_name]); float lora_k_alpha = ggml_ext_backend_tensor_get_f32(lora_tensors[split_k_alpha_name]);
applied_lora_tensors.insert(split_k_alpha_name); applied_lora_tensors.insert(split_k_alpha_name);
lora_k_scale = lora_k_alpha / k_rank; lora_k_scale = lora_k_alpha / k_rank;
} }
if (lora_tensors.find(split_v_alpha_name) != lora_tensors.end()) { if (lora_tensors.find(split_v_alpha_name) != lora_tensors.end()) {
float lora_v_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_v_alpha_name]); float lora_v_alpha = ggml_ext_backend_tensor_get_f32(lora_tensors[split_v_alpha_name]);
applied_lora_tensors.insert(split_v_alpha_name); applied_lora_tensors.insert(split_v_alpha_name);
lora_v_scale = lora_v_alpha / v_rank; lora_v_scale = lora_v_alpha / v_rank;
} }
if (lora_tensors.find(split_m_alpha_name) != lora_tensors.end()) { if (lora_tensors.find(split_m_alpha_name) != lora_tensors.end()) {
float lora_m_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_m_alpha_name]); float lora_m_alpha = ggml_ext_backend_tensor_get_f32(lora_tensors[split_m_alpha_name]);
applied_lora_tensors.insert(split_m_alpha_name); applied_lora_tensors.insert(split_m_alpha_name);
lora_m_scale = lora_m_alpha / m_rank; lora_m_scale = lora_m_alpha / m_rank;
} }
@ -816,16 +816,16 @@ struct LoraModel : public GGMLRunner {
// TODO: .dora_scale? // TODO: .dora_scale?
int64_t rank = lora_down->ne[ggml_n_dims(lora_down) - 1]; int64_t rank = lora_down->ne[ggml_n_dims(lora_down) - 1];
if (lora_tensors.find(scale_name) != lora_tensors.end()) { if (lora_tensors.find(scale_name) != lora_tensors.end()) {
scale_value = ggml_backend_tensor_get_f32(lora_tensors[scale_name]); scale_value = ggml_ext_backend_tensor_get_f32(lora_tensors[scale_name]);
applied_lora_tensors.insert(scale_name); applied_lora_tensors.insert(scale_name);
} else if (lora_tensors.find(alpha_name) != lora_tensors.end()) { } else if (lora_tensors.find(alpha_name) != lora_tensors.end()) {
float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]); float alpha = ggml_ext_backend_tensor_get_f32(lora_tensors[alpha_name]);
scale_value = alpha / rank; scale_value = alpha / rank;
// LOG_DEBUG("rank %s %ld %.2f %.2f", alpha_name.c_str(), rank, alpha, scale_value); // LOG_DEBUG("rank %s %ld %.2f %.2f", alpha_name.c_str(), rank, alpha, scale_value);
applied_lora_tensors.insert(alpha_name); applied_lora_tensors.insert(alpha_name);
} }
updown = ggml_merge_lora(compute_ctx, lora_down, lora_up, lora_mid); updown = ggml_ext_merge_lora(compute_ctx, lora_down, lora_up, lora_mid);
} }
scale_value *= multiplier; scale_value *= multiplier;
ggml_tensor* original_tensor = model_tensor; ggml_tensor* original_tensor = model_tensor;

View File

@ -113,7 +113,7 @@ public:
auto mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["mlp.0"]); auto mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["mlp.0"]);
auto mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["mlp.2"]); auto mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["mlp.2"]);
auto t_freq = ggml_nn_timestep_embedding(ctx, t, frequency_embedding_size); // [N, frequency_embedding_size] auto t_freq = ggml_ext_timestep_embedding(ctx, t, frequency_embedding_size); // [N, frequency_embedding_size]
auto t_emb = mlp_0->forward(ctx, t_freq); auto t_emb = mlp_0->forward(ctx, t_freq);
t_emb = ggml_silu_inplace(ctx, t_emb); t_emb = ggml_silu_inplace(ctx, t_emb);
@ -210,8 +210,8 @@ public:
ggml_backend_t backend, ggml_backend_t backend,
struct ggml_tensor* x) { struct ggml_tensor* x) {
auto qkv = pre_attention(ctx, x); auto qkv = pre_attention(ctx, x);
x = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, true); // [N, n_token, dim] x = ggml_ext_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, true); // [N, n_token, dim]
x = post_attention(ctx, x); // [N, n_token, dim] x = post_attention(ctx, x); // [N, n_token, dim]
return x; return x;
} }
}; };
@ -441,8 +441,8 @@ public:
auto qkv2 = std::get<1>(qkv_intermediates); auto qkv2 = std::get<1>(qkv_intermediates);
auto intermediates = std::get<2>(qkv_intermediates); auto intermediates = std::get<2>(qkv_intermediates);
auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim] auto attn_out = ggml_ext_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim]
auto attn2_out = ggml_nn_attention_ext(ctx, backend, qkv2[0], qkv2[1], qkv2[2], num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim] auto attn2_out = ggml_ext_attention_ext(ctx, backend, qkv2[0], qkv2[1], qkv2[2], num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim]
x = post_attention_x(ctx, x = post_attention_x(ctx,
attn_out, attn_out,
attn2_out, attn2_out,
@ -458,7 +458,7 @@ public:
auto qkv = qkv_intermediates.first; auto qkv = qkv_intermediates.first;
auto intermediates = qkv_intermediates.second; auto intermediates = qkv_intermediates.second;
auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim] auto attn_out = ggml_ext_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim]
x = post_attention(ctx, x = post_attention(ctx,
attn_out, attn_out,
intermediates[0], intermediates[0],
@ -504,8 +504,8 @@ block_mixing(struct ggml_context* ctx,
qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1)); qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1));
} }
auto attn = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, false, flash_attn); // [N, n_context + n_token, hidden_size] auto attn = ggml_ext_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, false, flash_attn); // [N, n_context + n_token, hidden_size]
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_context + n_token, N, hidden_size] attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_context + n_token, N, hidden_size]
auto context_attn = ggml_view_3d(ctx, auto context_attn = ggml_view_3d(ctx,
attn, attn,
attn->ne[0], attn->ne[0],
@ -538,7 +538,7 @@ block_mixing(struct ggml_context* ctx,
} }
if (x_block->self_attn) { if (x_block->self_attn) {
auto attn2 = ggml_nn_attention_ext(ctx, backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads); // [N, n_token, hidden_size] auto attn2 = ggml_ext_attention_ext(ctx, backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads); // [N, n_token, hidden_size]
x = x_block->post_attention_x(ctx, x = x_block->post_attention_x(ctx,
x_attn, x_attn,

View File

@ -29,7 +29,7 @@ public:
auto layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["layernorm"]); auto layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["layernorm"]);
struct ggml_tensor* r = x; struct ggml_tensor* r = x;
// x = ggml_nn_layer_norm(ctx, x, ln_w, ln_b); // x = ggml_ext_layer_norm(ctx, x, ln_w, ln_b);
x = layer_norm->forward(ctx, x); x = layer_norm->forward(ctx, x);
// x = ggml_add(ctx, ggml_mul_mat(ctx, fc1_w, x), fc1_b); // x = ggml_add(ctx, ggml_mul_mat(ctx, fc1_w, x), fc1_b);
x = fc1->forward(ctx, x); x = fc1->forward(ctx, x);

View File

@ -28,7 +28,7 @@ void gaussian_kernel(struct ggml_tensor* kernel) {
for (int x = 0; x < kernel->ne[1]; x++) { for (int x = 0; x < kernel->ne[1]; x++) {
float gy = -ks_mid + x; float gy = -ks_mid + x;
float k_ = expf(-((gx * gx + gy * gy) / (2.0f * powf(sigma, 2.0f)))) * normal; float k_ = expf(-((gx * gx + gy * gy) / (2.0f * powf(sigma, 2.0f)))) * normal;
ggml_tensor_set_f32(kernel, k_, x, y); ggml_ext_tensor_set_f32(kernel, k_, x, y);
} }
} }
} }
@ -36,11 +36,11 @@ void gaussian_kernel(struct ggml_tensor* kernel) {
void grayscale(struct ggml_tensor* rgb_img, struct ggml_tensor* grayscale) { void grayscale(struct ggml_tensor* rgb_img, struct ggml_tensor* grayscale) {
for (int iy = 0; iy < rgb_img->ne[1]; iy++) { for (int iy = 0; iy < rgb_img->ne[1]; iy++) {
for (int ix = 0; ix < rgb_img->ne[0]; ix++) { for (int ix = 0; ix < rgb_img->ne[0]; ix++) {
float r = ggml_tensor_get_f32(rgb_img, ix, iy); float r = ggml_ext_tensor_get_f32(rgb_img, ix, iy);
float g = ggml_tensor_get_f32(rgb_img, ix, iy, 1); float g = ggml_ext_tensor_get_f32(rgb_img, ix, iy, 1);
float b = ggml_tensor_get_f32(rgb_img, ix, iy, 2); float b = ggml_ext_tensor_get_f32(rgb_img, ix, iy, 2);
float gray = 0.2989f * r + 0.5870f * g + 0.1140f * b; float gray = 0.2989f * r + 0.5870f * g + 0.1140f * b;
ggml_tensor_set_f32(grayscale, gray, ix, iy); ggml_ext_tensor_set_f32(grayscale, gray, ix, iy);
} }
} }
} }
@ -81,37 +81,37 @@ void normalize_tensor(struct ggml_tensor* g) {
void non_max_supression(struct ggml_tensor* result, struct ggml_tensor* G, struct ggml_tensor* D) { void non_max_supression(struct ggml_tensor* result, struct ggml_tensor* G, struct ggml_tensor* D) {
for (int iy = 1; iy < result->ne[1] - 1; iy++) { for (int iy = 1; iy < result->ne[1] - 1; iy++) {
for (int ix = 1; ix < result->ne[0] - 1; ix++) { for (int ix = 1; ix < result->ne[0] - 1; ix++) {
float angle = ggml_tensor_get_f32(D, ix, iy) * 180.0f / M_PI_; float angle = ggml_ext_tensor_get_f32(D, ix, iy) * 180.0f / M_PI_;
angle = angle < 0.0f ? angle += 180.0f : angle; angle = angle < 0.0f ? angle += 180.0f : angle;
float q = 1.0f; float q = 1.0f;
float r = 1.0f; float r = 1.0f;
// angle 0 // angle 0
if ((0 >= angle && angle < 22.5f) || (157.5f >= angle && angle <= 180)) { if ((0 >= angle && angle < 22.5f) || (157.5f >= angle && angle <= 180)) {
q = ggml_tensor_get_f32(G, ix, iy + 1); q = ggml_ext_tensor_get_f32(G, ix, iy + 1);
r = ggml_tensor_get_f32(G, ix, iy - 1); r = ggml_ext_tensor_get_f32(G, ix, iy - 1);
} }
// angle 45 // angle 45
else if (22.5f >= angle && angle < 67.5f) { else if (22.5f >= angle && angle < 67.5f) {
q = ggml_tensor_get_f32(G, ix + 1, iy - 1); q = ggml_ext_tensor_get_f32(G, ix + 1, iy - 1);
r = ggml_tensor_get_f32(G, ix - 1, iy + 1); r = ggml_ext_tensor_get_f32(G, ix - 1, iy + 1);
} }
// angle 90 // angle 90
else if (67.5f >= angle && angle < 112.5) { else if (67.5f >= angle && angle < 112.5) {
q = ggml_tensor_get_f32(G, ix + 1, iy); q = ggml_ext_tensor_get_f32(G, ix + 1, iy);
r = ggml_tensor_get_f32(G, ix - 1, iy); r = ggml_ext_tensor_get_f32(G, ix - 1, iy);
} }
// angle 135 // angle 135
else if (112.5 >= angle && angle < 157.5f) { else if (112.5 >= angle && angle < 157.5f) {
q = ggml_tensor_get_f32(G, ix - 1, iy - 1); q = ggml_ext_tensor_get_f32(G, ix - 1, iy - 1);
r = ggml_tensor_get_f32(G, ix + 1, iy + 1); r = ggml_ext_tensor_get_f32(G, ix + 1, iy + 1);
} }
float cur = ggml_tensor_get_f32(G, ix, iy); float cur = ggml_ext_tensor_get_f32(G, ix, iy);
if ((cur >= q) && (cur >= r)) { if ((cur >= q) && (cur >= r)) {
ggml_tensor_set_f32(result, cur, ix, iy); ggml_ext_tensor_set_f32(result, cur, ix, iy);
} else { } else {
ggml_tensor_set_f32(result, 0.0f, ix, iy); ggml_ext_tensor_set_f32(result, 0.0f, ix, iy);
} }
} }
} }
@ -138,9 +138,9 @@ void threshold_hystersis(struct ggml_tensor* img, float high_threshold, float lo
for (int iy = 0; iy < img->ne[1]; iy++) { for (int iy = 0; iy < img->ne[1]; iy++) {
for (int ix = 0; ix < img->ne[0]; ix++) { for (int ix = 0; ix < img->ne[0]; ix++) {
if (ix >= 3 && ix <= img->ne[0] - 3 && iy >= 3 && iy <= img->ne[1] - 3) { if (ix >= 3 && ix <= img->ne[0] - 3 && iy >= 3 && iy <= img->ne[1] - 3) {
ggml_tensor_set_f32(img, ggml_tensor_get_f32(img, ix, iy), ix, iy); ggml_ext_tensor_set_f32(img, ggml_ext_tensor_get_f32(img, ix, iy), ix, iy);
} else { } else {
ggml_tensor_set_f32(img, 0.0f, ix, iy); ggml_ext_tensor_set_f32(img, 0.0f, ix, iy);
} }
} }
} }
@ -148,14 +148,14 @@ void threshold_hystersis(struct ggml_tensor* img, float high_threshold, float lo
// hysteresis // hysteresis
for (int iy = 1; iy < img->ne[1] - 1; iy++) { for (int iy = 1; iy < img->ne[1] - 1; iy++) {
for (int ix = 1; ix < img->ne[0] - 1; ix++) { for (int ix = 1; ix < img->ne[0] - 1; ix++) {
float imd_v = ggml_tensor_get_f32(img, ix, iy); float imd_v = ggml_ext_tensor_get_f32(img, ix, iy);
if (imd_v == weak) { if (imd_v == weak) {
if (ggml_tensor_get_f32(img, ix + 1, iy - 1) == strong || ggml_tensor_get_f32(img, ix + 1, iy) == strong || if (ggml_ext_tensor_get_f32(img, ix + 1, iy - 1) == strong || ggml_ext_tensor_get_f32(img, ix + 1, iy) == strong ||
ggml_tensor_get_f32(img, ix, iy - 1) == strong || ggml_tensor_get_f32(img, ix, iy + 1) == strong || ggml_ext_tensor_get_f32(img, ix, iy - 1) == strong || ggml_ext_tensor_get_f32(img, ix, iy + 1) == strong ||
ggml_tensor_get_f32(img, ix - 1, iy - 1) == strong || ggml_tensor_get_f32(img, ix - 1, iy) == strong) { ggml_ext_tensor_get_f32(img, ix - 1, iy - 1) == strong || ggml_ext_tensor_get_f32(img, ix - 1, iy) == strong) {
ggml_tensor_set_f32(img, strong, ix, iy); ggml_ext_tensor_set_f32(img, strong, ix, iy);
} else { } else {
ggml_tensor_set_f32(img, 0.0f, ix, iy); ggml_ext_tensor_set_f32(img, 0.0f, ix, iy);
} }
} }
} }
@ -198,7 +198,7 @@ bool preprocess_canny(sd_image_t img, float high_threshold, float low_threshold,
struct ggml_tensor* iY = ggml_dup_tensor(work_ctx, image_gray); struct ggml_tensor* iY = ggml_dup_tensor(work_ctx, image_gray);
struct ggml_tensor* G = ggml_dup_tensor(work_ctx, image_gray); struct ggml_tensor* G = ggml_dup_tensor(work_ctx, image_gray);
struct ggml_tensor* tetha = ggml_dup_tensor(work_ctx, image_gray); struct ggml_tensor* tetha = ggml_dup_tensor(work_ctx, image_gray);
sd_image_to_tensor(img, image); sd_image_to_ggml_tensor(img, image);
grayscale(image, image_gray); grayscale(image, image_gray);
convolve(image_gray, image_gray, gkernel, 2); convolve(image_gray, image_gray, gkernel, 2);
convolve(image_gray, iX, sf_kx, 1); convolve(image_gray, iX, sf_kx, 1);
@ -211,14 +211,14 @@ bool preprocess_canny(sd_image_t img, float high_threshold, float low_threshold,
// to RGB channels // to RGB channels
for (int iy = 0; iy < img.height; iy++) { for (int iy = 0; iy < img.height; iy++) {
for (int ix = 0; ix < img.width; ix++) { for (int ix = 0; ix < img.width; ix++) {
float gray = ggml_tensor_get_f32(image_gray, ix, iy); float gray = ggml_ext_tensor_get_f32(image_gray, ix, iy);
gray = inverse ? 1.0f - gray : gray; gray = inverse ? 1.0f - gray : gray;
ggml_tensor_set_f32(image, gray, ix, iy); ggml_ext_tensor_set_f32(image, gray, ix, iy);
ggml_tensor_set_f32(image, gray, ix, iy, 1); ggml_ext_tensor_set_f32(image, gray, ix, iy, 1);
ggml_tensor_set_f32(image, gray, ix, iy, 2); ggml_ext_tensor_set_f32(image, gray, ix, iy, 2);
} }
} }
sd_tensor_to_image(image, img.data); ggml_tensor_to_sd_image(image, img.data);
ggml_free(work_ctx); ggml_free(work_ctx);
return true; return true;
} }

View File

@ -56,7 +56,7 @@ namespace Qwen {
// return: [N, embedding_dim] // return: [N, embedding_dim]
auto timestep_embedder = std::dynamic_pointer_cast<TimestepEmbedding>(blocks["timestep_embedder"]); auto timestep_embedder = std::dynamic_pointer_cast<TimestepEmbedding>(blocks["timestep_embedder"]);
auto timesteps_proj = ggml_nn_timestep_embedding(ctx, timesteps, 256, 10000, 1.f); auto timesteps_proj = ggml_ext_timestep_embedding(ctx, timesteps, 256, 10000, 1.f);
auto timesteps_emb = timestep_embedder->forward(ctx, timesteps_proj); auto timesteps_emb = timestep_embedder->forward(ctx, timesteps_proj);
return timesteps_emb; return timesteps_emb;
} }
@ -246,11 +246,11 @@ namespace Qwen {
auto img_mod_params = ggml_silu(ctx, t_emb); auto img_mod_params = ggml_silu(ctx, t_emb);
img_mod_params = img_mod_1->forward(ctx, img_mod_params); img_mod_params = img_mod_1->forward(ctx, img_mod_params);
auto img_mod_param_vec = ggml_chunk(ctx, img_mod_params, 6, 0); auto img_mod_param_vec = ggml_ext_chunk(ctx, img_mod_params, 6, 0);
auto txt_mod_params = ggml_silu(ctx, t_emb); auto txt_mod_params = ggml_silu(ctx, t_emb);
txt_mod_params = txt_mod_1->forward(ctx, txt_mod_params); txt_mod_params = txt_mod_1->forward(ctx, txt_mod_params);
auto txt_mod_param_vec = ggml_chunk(ctx, txt_mod_params, 6, 0); auto txt_mod_param_vec = ggml_ext_chunk(ctx, txt_mod_params, 6, 0);
auto img_normed = img_norm1->forward(ctx, img); auto img_normed = img_norm1->forward(ctx, img);
auto img_modulated = Flux::modulate(ctx, img_normed, img_mod_param_vec[0], img_mod_param_vec[1]); auto img_modulated = Flux::modulate(ctx, img_normed, img_mod_param_vec[0], img_mod_param_vec[1]);
@ -305,7 +305,7 @@ namespace Qwen {
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]); auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
auto emb = linear->forward(ctx, ggml_silu(ctx, c)); auto emb = linear->forward(ctx, ggml_silu(ctx, c));
auto mods = ggml_chunk(ctx, emb, 2, 0); auto mods = ggml_ext_chunk(ctx, emb, 2, 0);
auto scale = mods[0]; auto scale = mods[0];
auto shift = mods[1]; auto shift = mods[1];
@ -496,8 +496,8 @@ namespace Qwen {
out = unpatchify(ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w] out = unpatchify(ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w]
// slice // slice
out = ggml_slice(ctx, out, 1, 0, H); // [N, C, H, W + pad_w] out = ggml_ext_slice(ctx, out, 1, 0, H); // [N, C, H, W + pad_w]
out = ggml_slice(ctx, out, 0, 0, W); // [N, C, H, W] out = ggml_ext_slice(ctx, out, 0, 0, W); // [N, C, H, W]
return out; return out;
} }

View File

@ -423,11 +423,11 @@ namespace Qwen {
auto proj_0 = std::dynamic_pointer_cast<Conv2d>(blocks["proj.0"]); auto proj_0 = std::dynamic_pointer_cast<Conv2d>(blocks["proj.0"]);
auto proj_1 = std::dynamic_pointer_cast<Conv2d>(blocks["proj.1"]); auto proj_1 = std::dynamic_pointer_cast<Conv2d>(blocks["proj.1"]);
auto x0 = ggml_slice(ctx, x, 2, 0, 1); auto x0 = ggml_ext_slice(ctx, x, 2, 0, 1);
x0 = ggml_reshape_4d(ctx, x0, x0->ne[0], x0->ne[1], in_channels, x0->ne[3] / in_channels); x0 = ggml_reshape_4d(ctx, x0, x0->ne[0], x0->ne[1], in_channels, x0->ne[3] / in_channels);
x0 = proj_0->forward(ctx, x0); x0 = proj_0->forward(ctx, x0);
auto x1 = ggml_slice(ctx, x, 2, 1, 2); auto x1 = ggml_ext_slice(ctx, x, 2, 1, 2);
x1 = ggml_reshape_4d(ctx, x1, x1->ne[0], x1->ne[1], in_channels, x1->ne[3] / in_channels); x1 = ggml_reshape_4d(ctx, x1, x1->ne[0], x1->ne[1], in_channels, x1->ne[3] / in_channels);
x1 = proj_1->forward(ctx, x1); x1 = proj_1->forward(ctx, x1);
@ -688,13 +688,13 @@ namespace Qwen {
q = ggml_rope_multi(ctx, q, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); q = ggml_rope_multi(ctx, q, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
k = ggml_rope_multi(ctx, k, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); k = ggml_rope_multi(ctx, k, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
q = ggml_cont(ctx, ggml_torch_permute(ctx, q, 0, 2, 1, 3)); // [N, num_heads, n_token, head_dim] q = ggml_cont(ctx, ggml_ext_torch_permute(ctx, q, 0, 2, 1, 3)); // [N, num_heads, n_token, head_dim]
q = ggml_reshape_3d(ctx, q, q->ne[0], q->ne[1], q->ne[2] * q->ne[3]); // [N*num_heads, n_token, head_dim] q = ggml_reshape_3d(ctx, q, q->ne[0], q->ne[1], q->ne[2] * q->ne[3]); // [N*num_heads, n_token, head_dim]
k = ggml_cont(ctx, ggml_torch_permute(ctx, k, 0, 2, 1, 3)); // [N, num_kv_heads, n_token, head_dim] k = ggml_cont(ctx, ggml_ext_torch_permute(ctx, k, 0, 2, 1, 3)); // [N, num_kv_heads, n_token, head_dim]
k = ggml_reshape_3d(ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); // [N*num_kv_heads, n_token, head_dim] k = ggml_reshape_3d(ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); // [N*num_kv_heads, n_token, head_dim]
x = ggml_nn_attention_ext(ctx, backend, q, k, v, num_heads, nullptr, true, true, false); // [N, n_token, hidden_size] x = ggml_ext_attention_ext(ctx, backend, q, k, v, num_heads, nullptr, true, true, false); // [N, n_token, hidden_size]
x = out_proj->forward(ctx, x); // [N, n_token, hidden_size] x = out_proj->forward(ctx, x); // [N, n_token, hidden_size]
return x; return x;
@ -791,7 +791,7 @@ namespace Qwen {
} }
txt_token_end = image_embeds[i].first; txt_token_end = image_embeds[i].first;
auto txt_embed = ggml_slice(ctx, raw_x, 1, txt_token_start, txt_token_end); auto txt_embed = ggml_ext_slice(ctx, raw_x, 1, txt_token_start, txt_token_end);
if (input_embed == nullptr) { if (input_embed == nullptr) {
input_embed = txt_embed; input_embed = txt_embed;
} else { } else {
@ -805,7 +805,7 @@ namespace Qwen {
txt_token_start = image_embeds[image_embeds.size() - 1].first + image_embeds[image_embeds.size() - 1].second->ne[1]; txt_token_start = image_embeds[image_embeds.size() - 1].first + image_embeds[image_embeds.size() - 1].second->ne[1];
txt_token_end = raw_x->ne[1]; txt_token_end = raw_x->ne[1];
auto final_txt_embed = ggml_slice(ctx, raw_x, 1, txt_token_start, txt_token_end); auto final_txt_embed = ggml_ext_slice(ctx, raw_x, 1, txt_token_start, txt_token_end);
input_embed = ggml_concat(ctx, input_embed, final_txt_embed, 1); input_embed = ggml_concat(ctx, input_embed, final_txt_embed, 1);
GGML_ASSERT(raw_x->ne[1] == input_embed->ne[1]); GGML_ASSERT(raw_x->ne[1] == input_embed->ne[1]);
@ -1042,16 +1042,16 @@ namespace Qwen {
int64_t pw = params.vision.patch_size; int64_t pw = params.vision.patch_size;
image = ggml_reshape_4d(ctx, image, pw, mw, (W / mw / pw), H * C); // [C*H, (W/mw/pw), mw, pw] image = ggml_reshape_4d(ctx, image, pw, mw, (W / mw / pw), H * C); // [C*H, (W/mw/pw), mw, pw]
image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 3, 1)); // [mw, C*H, (W/mw/pw), pw] image = ggml_cont(ctx, ggml_ext_torch_permute(ctx, image, 0, 2, 3, 1)); // [mw, C*H, (W/mw/pw), pw]
image = ggml_reshape_4d(ctx, image, pw * (W / mw / pw), H, C, mw); // [mw, C, H, (W/mw/pw)*pw] image = ggml_reshape_4d(ctx, image, pw * (W / mw / pw), H, C, mw); // [mw, C, H, (W/mw/pw)*pw]
image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 3, 1)); // [H, mw, C, (W/mw/pw)*pw] image = ggml_cont(ctx, ggml_ext_torch_permute(ctx, image, 0, 2, 3, 1)); // [H, mw, C, (W/mw/pw)*pw]
image = ggml_reshape_4d(ctx, image, pw, (W / mw / pw) * C * mw, ph, mh * (H / mh / ph)); // [(H/mh/ph)*mh, ph, mw*C*(W/mw/pw), pw] image = ggml_reshape_4d(ctx, image, pw, (W / mw / pw) * C * mw, ph, mh * (H / mh / ph)); // [(H/mh/ph)*mh, ph, mw*C*(W/mw/pw), pw]
image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 1, 3)); // [(H/mh/ph)*mh, mw*C*(W/mw/pw), ph, pw] image = ggml_cont(ctx, ggml_ext_torch_permute(ctx, image, 0, 2, 1, 3)); // [(H/mh/ph)*mh, mw*C*(W/mw/pw), ph, pw]
image = ggml_reshape_4d(ctx, image, pw * ph, (W / mw / pw), C, mw * mh * (H / mh / ph)); // [(H/mh/ph)*mh*mw, C, (W/mw/pw), ph*pw] image = ggml_reshape_4d(ctx, image, pw * ph, (W / mw / pw), C, mw * mh * (H / mh / ph)); // [(H/mh/ph)*mh*mw, C, (W/mw/pw), ph*pw]
image = ggml_concat(ctx, image, image, 0); // [(H/mh/ph)*mh*mw, C, (W/mw/pw), pt*ph*pw] image = ggml_concat(ctx, image, image, 0); // [(H/mh/ph)*mh*mw, C, (W/mw/pw), pt*ph*pw]
image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 1, 3)); // [(H/mh/ph)*mh*mw, (W/mw/pw), C, pt*ph*pw] image = ggml_cont(ctx, ggml_ext_torch_permute(ctx, image, 0, 2, 1, 3)); // [(H/mh/ph)*mh*mw, (W/mw/pw), C, pt*ph*pw]
image = ggml_reshape_4d(ctx, image, pw * ph * pt * C, (W / mw / pw), mw * mh, (H / mh / ph)); // [(H/mh/ph), mh*mw, (W/mw/pw), C*pt*ph*pw] image = ggml_reshape_4d(ctx, image, pw * ph * pt * C, (W / mw / pw), mw * mh, (H / mh / ph)); // [(H/mh/ph), mh*mw, (W/mw/pw), C*pt*ph*pw]
image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 1, 3)); // [(H/mh/ph), (W/mw/pw), mh*mw, C*pt*ph*pw] image = ggml_cont(ctx, ggml_ext_torch_permute(ctx, image, 0, 2, 1, 3)); // [(H/mh/ph), (W/mw/pw), mh*mw, C*pt*ph*pw]
image = ggml_reshape_2d(ctx, image, pw * ph * pt * C, mw * mh * (W / mw / pw) * (H / mh / ph)); // [(H/mh/ph)*(W/mw/pw)*mh*mw, C*pt*ph*pw] image = ggml_reshape_2d(ctx, image, pw * ph * pt * C, mw * mh * (W / mw / pw) * (H / mh / ph)); // [(H/mh/ph)*(W/mw/pw)*mh*mw, C*pt*ph*pw]
return image; return image;
} }
@ -1319,7 +1319,7 @@ namespace Qwen {
print_ggml_tensor(out, false, "out"); print_ggml_tensor(out, false, "out");
// auto ref_out = load_tensor_from_file(work_ctx, "qwen2vl.bin"); // auto ref_out = load_tensor_from_file(work_ctx, "qwen2vl.bin");
// ggml_tensor_diff(ref_out, out, 0.01f); // ggml_ext_tensor_diff(ref_out, out, 0.01f);
LOG_DEBUG("qwen2vl test done in %dms", t1 - t0); LOG_DEBUG("qwen2vl test done in %dms", t1 - t0);
} else { } else {

View File

@ -360,8 +360,8 @@ namespace Rope {
x = ggml_reshape_4d(ctx, x, 2, d_head / 2, L, n_head * N); // [N * n_head, L, d_head/2, 2] x = ggml_reshape_4d(ctx, x, 2, d_head / 2, L, n_head * N); // [N * n_head, L, d_head/2, 2]
x = ggml_cont(ctx, ggml_permute(ctx, x, 3, 0, 1, 2)); // [2, N * n_head, L, d_head/2] x = ggml_cont(ctx, ggml_permute(ctx, x, 3, 0, 1, 2)); // [2, N * n_head, L, d_head/2]
} else { } else {
x = ggml_reshape_4d(ctx, x, d_head / 2, 2, L, n_head * N); // [N * n_head, L, 2, d_head/2] x = ggml_reshape_4d(ctx, x, d_head / 2, 2, L, n_head * N); // [N * n_head, L, 2, d_head/2]
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 3, 1)); // [2, N * n_head, L, d_head/2] x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 3, 1)); // [2, N * n_head, L, d_head/2]
} }
int64_t offset = x->nb[2] * x->ne[2]; int64_t offset = x->nb[2] * x->ne[2];
@ -402,7 +402,7 @@ namespace Rope {
q = apply_rope(ctx, q, pe, rope_interleaved); // [N*n_head, L, d_head] q = apply_rope(ctx, q, pe, rope_interleaved); // [N*n_head, L, d_head]
k = apply_rope(ctx, k, pe, rope_interleaved); // [N*n_head, L, d_head] k = apply_rope(ctx, k, pe, rope_interleaved); // [N*n_head, L, d_head]
auto x = ggml_nn_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn, kv_scale); // [N, L, n_head*d_head] auto x = ggml_ext_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn, kv_scale); // [N, L, n_head*d_head]
return x; return x;
} }
}; // namespace Rope }; // namespace Rope

View File

@ -1010,7 +1010,7 @@ public:
image.data = nullptr; image.data = nullptr;
ggml_tensor* pixel_values = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1); ggml_tensor* pixel_values = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
sd_image_f32_to_tensor(resized_image, pixel_values, false); sd_image_f32_to_ggml_tensor(resized_image, pixel_values, false);
free(resized_image.data); free(resized_image.data);
resized_image.data = nullptr; resized_image.data = nullptr;
@ -1047,18 +1047,18 @@ public:
sd_image_f32_t resized_image = resize_sd_image_f32_t(image, width, height); sd_image_f32_t resized_image = resize_sd_image_f32_t(image, width, height);
free(image.data); free(image.data);
image.data = nullptr; image.data = nullptr;
sd_image_f32_to_tensor(resized_image, init_img, false); sd_image_f32_to_ggml_tensor(resized_image, init_img, false);
free(resized_image.data); free(resized_image.data);
resized_image.data = nullptr; resized_image.data = nullptr;
} else { } else {
sd_image_to_tensor(init_image, init_img); sd_image_to_ggml_tensor(init_image, init_img);
} }
if (augmentation_level > 0.f) { if (augmentation_level > 0.f) {
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, init_img); struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, init_img);
ggml_tensor_set_f32_randn(noise, rng); ggml_ext_im_set_randn_f32(noise, rng);
// encode_pixels += torch.randn_like(pixels) * augmentation_level // encode_pixels += torch.randn_like(pixels) * augmentation_level
ggml_tensor_scale(noise, augmentation_level); ggml_ext_tensor_scale_inplace(noise, augmentation_level);
ggml_tensor_add(init_img, noise); ggml_ext_tensor_add_inplace(init_img, noise);
} }
ggml_tensor* moments = vae_encode(work_ctx, init_img); ggml_tensor* moments = vae_encode(work_ctx, init_img);
c_concat = get_first_stage_encoding(work_ctx, moments); c_concat = get_first_stage_encoding(work_ctx, moments);
@ -1086,7 +1086,7 @@ public:
auto new_timesteps = std::vector<float>(init_latent->ne[2], timesteps[0]); auto new_timesteps = std::vector<float>(init_latent->ne[2], timesteps[0]);
if (denoise_mask != nullptr) { if (denoise_mask != nullptr) {
float value = ggml_tensor_get_f32(denoise_mask, 0, 0, 0, 0); float value = ggml_ext_tensor_get_f32(denoise_mask, 0, 0, 0, 0);
if (value == 0.f) { if (value == 0.f) {
new_timesteps[0] = 0.f; new_timesteps[0] = 0.f;
} }
@ -1103,10 +1103,10 @@ public:
for (int64_t i1 = 0; i1 < a->ne[1]; i1++) { for (int64_t i1 = 0; i1 < a->ne[1]; i1++) {
for (int64_t i2 = 0; i2 < a->ne[2]; i2++) { for (int64_t i2 = 0; i2 < a->ne[2]; i2++) {
for (int64_t i3 = 0; i3 < a->ne[3]; i3++) { for (int64_t i3 = 0; i3 < a->ne[3]; i3++) {
float a_value = ggml_tensor_get_f32(a, i0, i1, i2, i3); float a_value = ggml_ext_tensor_get_f32(a, i0, i1, i2, i3);
float b_value = ggml_tensor_get_f32(b, i0, i1, i2, i3); float b_value = ggml_ext_tensor_get_f32(b, i0, i1, i2, i3);
float mask_value = ggml_tensor_get_f32(mask, i0 % mask->ne[0], i1 % mask->ne[1], i2 % mask->ne[2], i3 % mask->ne[3]); float mask_value = ggml_ext_tensor_get_f32(mask, i0 % mask->ne[0], i1 % mask->ne[1], i2 % mask->ne[2], i3 % mask->ne[3]);
ggml_tensor_set_f32(a, a_value * mask_value + b_value * (1 - mask_value), i0, i1, i2, i3); ggml_ext_tensor_set_f32(a, a_value * mask_value + b_value * (1 - mask_value), i0, i1, i2, i3);
} }
} }
} }
@ -1218,7 +1218,7 @@ public:
copy_ggml_tensor(noised_input, input); copy_ggml_tensor(noised_input, input);
// noised_input = noised_input * c_in // noised_input = noised_input * c_in
ggml_tensor_scale(noised_input, c_in); ggml_ext_tensor_scale_inplace(noised_input, c_in);
if (denoise_mask != nullptr && version == VERSION_WAN2_2_TI2V) { if (denoise_mask != nullptr && version == VERSION_WAN2_2_TI2V) {
apply_mask(noised_input, init_latent, denoise_mask); apply_mask(noised_input, init_latent, denoise_mask);
@ -1446,9 +1446,9 @@ public:
for (int j = 0; j < latent->ne[2]; j++) { for (int j = 0; j < latent->ne[2]; j++) {
for (int k = 0; k < latent->ne[1]; k++) { for (int k = 0; k < latent->ne[1]; k++) {
for (int l = 0; l < latent->ne[0]; l++) { for (int l = 0; l < latent->ne[0]; l++) {
float value = ggml_tensor_get_f32(latent, l, k, j, i); float value = ggml_ext_tensor_get_f32(latent, l, k, j, i);
value = (value - mean) * scale_factor / std_; value = (value - mean) * scale_factor / std_;
ggml_tensor_set_f32(latent, value, l, k, j, i); ggml_ext_tensor_set_f32(latent, value, l, k, j, i);
} }
} }
} }
@ -1456,10 +1456,10 @@ public:
} else if (version == VERSION_CHROMA_RADIANCE) { } else if (version == VERSION_CHROMA_RADIANCE) {
// pass // pass
} else { } else {
ggml_tensor_iter(latent, [&](ggml_tensor* latent, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { ggml_ext_tensor_iter(latent, [&](ggml_tensor* latent, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = ggml_tensor_get_f32(latent, i0, i1, i2, i3); float value = ggml_ext_tensor_get_f32(latent, i0, i1, i2, i3);
value = (value - shift_factor) * scale_factor; value = (value - shift_factor) * scale_factor;
ggml_tensor_set_f32(latent, value, i0, i1, i2, i3); ggml_ext_tensor_set_f32(latent, value, i0, i1, i2, i3);
}); });
} }
} }
@ -1492,9 +1492,9 @@ public:
for (int j = 0; j < latent->ne[2]; j++) { for (int j = 0; j < latent->ne[2]; j++) {
for (int k = 0; k < latent->ne[1]; k++) { for (int k = 0; k < latent->ne[1]; k++) {
for (int l = 0; l < latent->ne[0]; l++) { for (int l = 0; l < latent->ne[0]; l++) {
float value = ggml_tensor_get_f32(latent, l, k, j, i); float value = ggml_ext_tensor_get_f32(latent, l, k, j, i);
value = value * std_ / scale_factor + mean; value = value * std_ / scale_factor + mean;
ggml_tensor_set_f32(latent, value, l, k, j, i); ggml_ext_tensor_set_f32(latent, value, l, k, j, i);
} }
} }
} }
@ -1502,10 +1502,10 @@ public:
} else if (version == VERSION_CHROMA_RADIANCE) { } else if (version == VERSION_CHROMA_RADIANCE) {
// pass // pass
} else { } else {
ggml_tensor_iter(latent, [&](ggml_tensor* latent, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { ggml_ext_tensor_iter(latent, [&](ggml_tensor* latent, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = ggml_tensor_get_f32(latent, i0, i1, i2, i3); float value = ggml_ext_tensor_get_f32(latent, i0, i1, i2, i3);
value = (value / scale_factor) + shift_factor; value = (value / scale_factor) + shift_factor;
ggml_tensor_set_f32(latent, value, i0, i1, i2, i3); ggml_ext_tensor_set_f32(latent, value, i0, i1, i2, i3);
}); });
} }
} }
@ -1606,7 +1606,7 @@ public:
// ldm.modules.distributions.distributions.DiagonalGaussianDistribution.sample // ldm.modules.distributions.distributions.DiagonalGaussianDistribution.sample
ggml_tensor* latent = ggml_new_tensor_4d(work_ctx, moments->type, moments->ne[0], moments->ne[1], moments->ne[2] / 2, moments->ne[3]); ggml_tensor* latent = ggml_new_tensor_4d(work_ctx, moments->type, moments->ne[0], moments->ne[1], moments->ne[2] / 2, moments->ne[3]);
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, latent); struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, latent);
ggml_tensor_set_f32_randn(noise, rng); ggml_ext_im_set_randn_f32(noise, rng);
{ {
float mean = 0; float mean = 0;
float logvar = 0; float logvar = 0;
@ -1616,13 +1616,13 @@ public:
for (int j = 0; j < latent->ne[2]; j++) { for (int j = 0; j < latent->ne[2]; j++) {
for (int k = 0; k < latent->ne[1]; k++) { for (int k = 0; k < latent->ne[1]; k++) {
for (int l = 0; l < latent->ne[0]; l++) { for (int l = 0; l < latent->ne[0]; l++) {
mean = ggml_tensor_get_f32(moments, l, k, j, i); mean = ggml_ext_tensor_get_f32(moments, l, k, j, i);
logvar = ggml_tensor_get_f32(moments, l, k, j + (int)latent->ne[2], i); logvar = ggml_ext_tensor_get_f32(moments, l, k, j + (int)latent->ne[2], i);
logvar = std::max(-30.0f, std::min(logvar, 20.0f)); logvar = std::max(-30.0f, std::min(logvar, 20.0f));
std_ = std::exp(0.5f * logvar); std_ = std::exp(0.5f * logvar);
value = mean + std_ * ggml_tensor_get_f32(noise, l, k, j, i); value = mean + std_ * ggml_ext_tensor_get_f32(noise, l, k, j, i);
// printf("%d %d %d %d -> %f\n", i, j, k, l, value); // printf("%d %d %d %d -> %f\n", i, j, k, l, value);
ggml_tensor_set_f32(latent, value, l, k, j, i); ggml_ext_tensor_set_f32(latent, value, l, k, j, i);
} }
} }
} }
@ -1725,7 +1725,7 @@ public:
int64_t t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();
LOG_DEBUG("computing vae decode graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); LOG_DEBUG("computing vae decode graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
ggml_tensor_clamp(result, 0.0f, 1.0f); ggml_ext_tensor_clamp_inplace(result, 0.0f, 1.0f);
return result; return result;
} }
}; };
@ -2201,9 +2201,9 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
processed_id_images.push_back(processed_id_image); processed_id_images.push_back(processed_id_image);
} }
ggml_tensor_iter(init_img, [&](ggml_tensor* init_img, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { ggml_ext_tensor_iter(init_img, [&](ggml_tensor* init_img, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = sd_image_get_f32(processed_id_images[i3], i0, i1, i2); float value = sd_image_get_f32(processed_id_images[i3], i0, i1, i2);
ggml_tensor_set_f32(init_img, value, i0, i1, i2, i3); ggml_ext_tensor_set_f32(init_img, value, i0, i1, i2, i3);
}); });
for (auto& image : processed_id_images) { for (auto& image : processed_id_images) {
@ -2276,7 +2276,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
struct ggml_tensor* image_hint = nullptr; struct ggml_tensor* image_hint = nullptr;
if (control_image.data != nullptr) { if (control_image.data != nullptr) {
image_hint = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); image_hint = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
sd_image_to_tensor(control_image, image_hint); sd_image_to_ggml_tensor(control_image, image_hint);
} }
// Sample // Sample
@ -2289,7 +2289,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
struct ggml_tensor* control_latent = nullptr; struct ggml_tensor* control_latent = nullptr;
if (sd_version_is_control(sd_ctx->sd->version) && image_hint != nullptr) { if (sd_version_is_control(sd_ctx->sd->version) && image_hint != nullptr) {
control_latent = sd_ctx->sd->encode_first_stage(work_ctx, image_hint); control_latent = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
ggml_tensor_scale(control_latent, control_strength); ggml_ext_tensor_scale_inplace(control_latent, control_strength);
} }
if (sd_version_is_inpaint(sd_ctx->sd->version)) { if (sd_version_is_inpaint(sd_ctx->sd->version)) {
@ -2306,20 +2306,20 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
if (sd_ctx->sd->version == VERSION_FLUX_FILL) { if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
// TODO: this might be wrong // TODO: this might be wrong
for (int64_t c = 0; c < init_latent->ne[2]; c++) { for (int64_t c = 0; c < init_latent->ne[2]; c++) {
ggml_tensor_set_f32(empty_latent, 0, x, y, c); ggml_ext_tensor_set_f32(empty_latent, 0, x, y, c);
} }
for (int64_t c = init_latent->ne[2]; c < empty_latent->ne[2]; c++) { for (int64_t c = init_latent->ne[2]; c < empty_latent->ne[2]; c++) {
ggml_tensor_set_f32(empty_latent, 1, x, y, c); ggml_ext_tensor_set_f32(empty_latent, 1, x, y, c);
} }
} else if (sd_ctx->sd->version == VERSION_FLEX_2) { } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
for (int64_t c = 0; c < empty_latent->ne[2]; c++) { for (int64_t c = 0; c < empty_latent->ne[2]; c++) {
// 0x16,1x1,0x16 // 0x16,1x1,0x16
ggml_tensor_set_f32(empty_latent, c == init_latent->ne[2], x, y, c); ggml_ext_tensor_set_f32(empty_latent, c == init_latent->ne[2], x, y, c);
} }
} else { } else {
ggml_tensor_set_f32(empty_latent, 1, x, y, 0); ggml_ext_tensor_set_f32(empty_latent, 1, x, y, 0);
for (int64_t c = 1; c < empty_latent->ne[2]; c++) { for (int64_t c = 1; c < empty_latent->ne[2]; c++) {
ggml_tensor_set_f32(empty_latent, 0, x, y, c); ggml_ext_tensor_set_f32(empty_latent, 0, x, y, c);
} }
} }
} }
@ -2336,12 +2336,12 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
if (no_inpaint) { if (no_inpaint) {
for (int64_t c = 0; c < concat_latent->ne[2] - control_latent->ne[2]; c++) { for (int64_t c = 0; c < concat_latent->ne[2] - control_latent->ne[2]; c++) {
// 0x16,1x1,0x16 // 0x16,1x1,0x16
ggml_tensor_set_f32(concat_latent, c == init_latent->ne[2], x, y, c); ggml_ext_tensor_set_f32(concat_latent, c == init_latent->ne[2], x, y, c);
} }
} }
for (int64_t c = 0; c < control_latent->ne[2]; c++) { for (int64_t c = 0; c < control_latent->ne[2]; c++) {
float v = ggml_tensor_get_f32(control_latent, x, y, c); float v = ggml_ext_tensor_get_f32(control_latent, x, y, c);
ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latent->ne[2] + c); ggml_ext_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latent->ne[2] + c);
} }
} }
} }
@ -2383,7 +2383,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
sd_ctx->sd->rng->manual_seed(cur_seed); sd_ctx->sd->rng->manual_seed(cur_seed);
struct ggml_tensor* x_t = init_latent; struct ggml_tensor* x_t = init_latent;
struct ggml_tensor* noise = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1); struct ggml_tensor* noise = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
ggml_tensor_set_f32_randn(noise, sd_ctx->sd->rng); ggml_ext_im_set_randn_f32(noise, sd_ctx->sd->rng);
int start_merge_step = -1; int start_merge_step = -1;
if (sd_ctx->sd->stacked_id) { if (sd_ctx->sd->stacked_id) {
@ -2454,7 +2454,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
result_images[i].width = width; result_images[i].width = width;
result_images[i].height = height; result_images[i].height = height;
result_images[i].channel = 3; result_images[i].channel = 3;
result_images[i].data = sd_tensor_to_image(decoded_images[i]); result_images[i].data = ggml_tensor_to_sd_image(decoded_images[i]);
} }
ggml_free(work_ctx); ggml_free(work_ctx);
@ -2529,8 +2529,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
ggml_tensor* mask_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 1, 1); ggml_tensor* mask_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 1, 1);
sd_image_to_tensor(sd_img_gen_params->mask_image, mask_img); sd_image_to_ggml_tensor(sd_img_gen_params->mask_image, mask_img);
sd_image_to_tensor(sd_img_gen_params->init_image, init_img); sd_image_to_ggml_tensor(sd_img_gen_params->init_image, init_img);
init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img); init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
@ -2546,12 +2546,12 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
if (sd_ctx->sd->version != VERSION_FLEX_2) { if (sd_ctx->sd->version != VERSION_FLEX_2) {
// most inpaint models mask before vae // most inpaint models mask before vae
ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
sd_apply_mask(init_img, mask_img, masked_img); ggml_ext_tensor_apply_mask(init_img, mask_img, masked_img);
masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img); masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
} else { } else {
// mask after vae // mask after vae
masked_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1); masked_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
sd_apply_mask(init_latent, mask_img, masked_latent, 0.); ggml_ext_tensor_apply_mask(init_latent, mask_img, masked_latent, 0.);
} }
concat_latent = ggml_new_tensor_4d(work_ctx, concat_latent = ggml_new_tensor_4d(work_ctx,
GGML_TYPE_F32, GGML_TYPE_F32,
@ -2565,30 +2565,30 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
int my = iy * vae_scale_factor; int my = iy * vae_scale_factor;
if (sd_ctx->sd->version == VERSION_FLUX_FILL) { if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
for (int k = 0; k < masked_latent->ne[2]; k++) { for (int k = 0; k < masked_latent->ne[2]; k++) {
float v = ggml_tensor_get_f32(masked_latent, ix, iy, k); float v = ggml_ext_tensor_get_f32(masked_latent, ix, iy, k);
ggml_tensor_set_f32(concat_latent, v, ix, iy, k); ggml_ext_tensor_set_f32(concat_latent, v, ix, iy, k);
} }
// "Encode" 8x8 mask chunks into a flattened 1x64 vector, and concatenate to masked image // "Encode" 8x8 mask chunks into a flattened 1x64 vector, and concatenate to masked image
for (int x = 0; x < vae_scale_factor; x++) { for (int x = 0; x < vae_scale_factor; x++) {
for (int y = 0; y < vae_scale_factor; y++) { for (int y = 0; y < vae_scale_factor; y++) {
float m = ggml_tensor_get_f32(mask_img, mx + x, my + y); float m = ggml_ext_tensor_get_f32(mask_img, mx + x, my + y);
// TODO: check if the way the mask is flattened is correct (is it supposed to be x*vae_scale_factor+y or x+vae_scale_factor*y?) // TODO: check if the way the mask is flattened is correct (is it supposed to be x*vae_scale_factor+y or x+vae_scale_factor*y?)
// python code was using "b (h vae_scale_factor) (w vae_scale_factor) -> b (vae_scale_factor vae_scale_factor) h w" // python code was using "b (h vae_scale_factor) (w vae_scale_factor) -> b (vae_scale_factor vae_scale_factor) h w"
ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2] + x * vae_scale_factor + y); ggml_ext_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2] + x * vae_scale_factor + y);
} }
} }
} else if (sd_ctx->sd->version == VERSION_FLEX_2) { } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
float m = ggml_tensor_get_f32(mask_img, mx, my); float m = ggml_ext_tensor_get_f32(mask_img, mx, my);
// masked image // masked image
for (int k = 0; k < masked_latent->ne[2]; k++) { for (int k = 0; k < masked_latent->ne[2]; k++) {
float v = ggml_tensor_get_f32(masked_latent, ix, iy, k); float v = ggml_ext_tensor_get_f32(masked_latent, ix, iy, k);
ggml_tensor_set_f32(concat_latent, v, ix, iy, k); ggml_ext_tensor_set_f32(concat_latent, v, ix, iy, k);
} }
// downsampled mask // downsampled mask
ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2]); ggml_ext_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2]);
// control (todo: support this) // control (todo: support this)
for (int k = 0; k < masked_latent->ne[2]; k++) { for (int k = 0; k < masked_latent->ne[2]; k++) {
ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k); ggml_ext_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
} }
} }
} }
@ -2602,8 +2602,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
for (int iy = 0; iy < denoise_mask->ne[1]; iy++) { for (int iy = 0; iy < denoise_mask->ne[1]; iy++) {
int mx = ix * vae_scale_factor; int mx = ix * vae_scale_factor;
int my = iy * vae_scale_factor; int my = iy * vae_scale_factor;
float m = ggml_tensor_get_f32(mask_img, mx, my); float m = ggml_ext_tensor_get_f32(mask_img, mx, my);
ggml_tensor_set_f32(denoise_mask, m, ix, iy); ggml_ext_tensor_set_f32(denoise_mask, m, ix, iy);
} }
} }
} }
@ -2665,7 +2665,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
resized_image.height, resized_image.height,
3, 3,
1); 1);
sd_image_f32_to_tensor(resized_image, img); sd_image_f32_to_ggml_tensor(resized_image, img);
free(resized_image.data); free(resized_image.data);
resized_image.data = nullptr; resized_image.data = nullptr;
} else { } else {
@ -2675,7 +2675,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
ref_images[i]->height, ref_images[i]->height,
3, 3,
1); 1);
sd_image_to_tensor(*ref_images[i], img); sd_image_to_ggml_tensor(*ref_images[i], img);
} }
// print_ggml_tensor(img, false, "img"); // print_ggml_tensor(img, false, "img");
@ -2818,7 +2818,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
} else { } else {
end_image_clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->end_image, false, -2, true); end_image_clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->end_image, false, -2, true);
} }
clip_vision_output = ggml_tensor_concat(work_ctx, clip_vision_output, end_image_clip_vision_output, 1); clip_vision_output = ggml_ext_tensor_concat(work_ctx, clip_vision_output, end_image_clip_vision_output, 1);
} }
int64_t t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();
@ -2827,7 +2827,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
int64_t t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();
ggml_tensor* image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, frames, 3); ggml_tensor* image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, frames, 3);
ggml_tensor_iter(image, [&](ggml_tensor* image, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { ggml_ext_tensor_iter(image, [&](ggml_tensor* image, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = 0.5f; float value = 0.5f;
if (i2 == 0 && sd_vid_gen_params->init_image.data) { // start image if (i2 == 0 && sd_vid_gen_params->init_image.data) { // start image
value = *(sd_vid_gen_params->init_image.data + i1 * width * 3 + i0 * 3 + i3); value = *(sd_vid_gen_params->init_image.data + i1 * width * 3 + i0 * 3 + i3);
@ -2836,7 +2836,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
value = *(sd_vid_gen_params->end_image.data + i1 * width * 3 + i0 * 3 + i3); value = *(sd_vid_gen_params->end_image.data + i1 * width * 3 + i0 * 3 + i3);
value /= 255.f; value /= 255.f;
} }
ggml_tensor_set_f32(image, value, i0, i1, i2, i3); ggml_ext_tensor_set_f32(image, value, i0, i1, i2, i3);
}); });
concat_latent = sd_ctx->sd->encode_first_stage(work_ctx, image); // [b*c, t, h/vae_scale_factor, w/vae_scale_factor] concat_latent = sd_ctx->sd->encode_first_stage(work_ctx, image); // [b*c, t, h/vae_scale_factor, w/vae_scale_factor]
@ -2850,23 +2850,23 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
concat_latent->ne[1], concat_latent->ne[1],
concat_latent->ne[2], concat_latent->ne[2],
4); // [b*4, t, w/vae_scale_factor, h/vae_scale_factor] 4); // [b*4, t, w/vae_scale_factor, h/vae_scale_factor]
ggml_tensor_iter(concat_mask, [&](ggml_tensor* concat_mask, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { ggml_ext_tensor_iter(concat_mask, [&](ggml_tensor* concat_mask, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = 0.0f; float value = 0.0f;
if (i2 == 0 && sd_vid_gen_params->init_image.data) { // start image if (i2 == 0 && sd_vid_gen_params->init_image.data) { // start image
value = 1.0f; value = 1.0f;
} else if (i2 == frames - 1 && sd_vid_gen_params->end_image.data && i3 == 3) { } else if (i2 == frames - 1 && sd_vid_gen_params->end_image.data && i3 == 3) {
value = 1.0f; value = 1.0f;
} }
ggml_tensor_set_f32(concat_mask, value, i0, i1, i2, i3); ggml_ext_tensor_set_f32(concat_mask, value, i0, i1, i2, i3);
}); });
concat_latent = ggml_tensor_concat(work_ctx, concat_mask, concat_latent, 3); // [b*(c+4), t, h/vae_scale_factor, w/vae_scale_factor] concat_latent = ggml_ext_tensor_concat(work_ctx, concat_mask, concat_latent, 3); // [b*(c+4), t, h/vae_scale_factor, w/vae_scale_factor]
} else if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-TI2V-5B" && sd_vid_gen_params->init_image.data) { } else if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-TI2V-5B" && sd_vid_gen_params->init_image.data) {
LOG_INFO("IMG2VID"); LOG_INFO("IMG2VID");
int64_t t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();
ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
sd_image_to_tensor(sd_vid_gen_params->init_image, init_img); sd_image_to_ggml_tensor(sd_vid_gen_params->init_image, init_img);
init_img = ggml_reshape_4d(work_ctx, init_img, width, height, 1, 3); init_img = ggml_reshape_4d(work_ctx, init_img, width, height, 1, 3);
auto init_image_latent = sd_ctx->sd->vae_encode(work_ctx, init_img); // [b*c, 1, h/16, w/16] auto init_image_latent = sd_ctx->sd->vae_encode(work_ctx, init_img); // [b*c, 1, h/16, w/16]
@ -2877,11 +2877,11 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
sd_ctx->sd->process_latent_out(init_latent); sd_ctx->sd->process_latent_out(init_latent);
ggml_tensor_iter(init_image_latent, [&](ggml_tensor* t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { ggml_ext_tensor_iter(init_image_latent, [&](ggml_tensor* t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = ggml_tensor_get_f32(t, i0, i1, i2, i3); float value = ggml_ext_tensor_get_f32(t, i0, i1, i2, i3);
ggml_tensor_set_f32(init_latent, value, i0, i1, i2, i3); ggml_ext_tensor_set_f32(init_latent, value, i0, i1, i2, i3);
if (i3 == 0) { if (i3 == 0) {
ggml_tensor_set_f32(denoise_mask, 0.f, i0, i1, i2, i3); ggml_ext_tensor_set_f32(denoise_mask, 0.f, i0, i1, i2, i3);
} }
}); });
@ -2896,36 +2896,36 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
ggml_tensor* ref_image_latent = nullptr; ggml_tensor* ref_image_latent = nullptr;
if (sd_vid_gen_params->init_image.data) { if (sd_vid_gen_params->init_image.data) {
ggml_tensor* ref_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); ggml_tensor* ref_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
sd_image_to_tensor(sd_vid_gen_params->init_image, ref_img); sd_image_to_ggml_tensor(sd_vid_gen_params->init_image, ref_img);
ref_img = ggml_reshape_4d(work_ctx, ref_img, width, height, 1, 3); ref_img = ggml_reshape_4d(work_ctx, ref_img, width, height, 1, 3);
ref_image_latent = sd_ctx->sd->encode_first_stage(work_ctx, ref_img); // [b*c, 1, h/16, w/16] ref_image_latent = sd_ctx->sd->encode_first_stage(work_ctx, ref_img); // [b*c, 1, h/16, w/16]
auto zero_latent = ggml_dup_tensor(work_ctx, ref_image_latent); auto zero_latent = ggml_dup_tensor(work_ctx, ref_image_latent);
ggml_set_f32(zero_latent, 0.f); ggml_set_f32(zero_latent, 0.f);
ref_image_latent = ggml_tensor_concat(work_ctx, ref_image_latent, zero_latent, 3); // [b*2*c, 1, h/16, w/16] ref_image_latent = ggml_ext_tensor_concat(work_ctx, ref_image_latent, zero_latent, 3); // [b*2*c, 1, h/16, w/16]
} }
ggml_tensor* control_video = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, frames, 3); ggml_tensor* control_video = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, frames, 3);
ggml_tensor_iter(control_video, [&](ggml_tensor* control_video, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { ggml_ext_tensor_iter(control_video, [&](ggml_tensor* control_video, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = 0.5f; float value = 0.5f;
if (i2 < sd_vid_gen_params->control_frames_size) { if (i2 < sd_vid_gen_params->control_frames_size) {
value = sd_image_get_f32(sd_vid_gen_params->control_frames[i2], i0, i1, i3); value = sd_image_get_f32(sd_vid_gen_params->control_frames[i2], i0, i1, i3);
} }
ggml_tensor_set_f32(control_video, value, i0, i1, i2, i3); ggml_ext_tensor_set_f32(control_video, value, i0, i1, i2, i3);
}); });
ggml_tensor* mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, frames, 1); ggml_tensor* mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, frames, 1);
ggml_set_f32(mask, 1.0f); ggml_set_f32(mask, 1.0f);
ggml_tensor* inactive = ggml_dup_tensor(work_ctx, control_video); ggml_tensor* inactive = ggml_dup_tensor(work_ctx, control_video);
ggml_tensor* reactive = ggml_dup_tensor(work_ctx, control_video); ggml_tensor* reactive = ggml_dup_tensor(work_ctx, control_video);
ggml_tensor_iter(control_video, [&](ggml_tensor* t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { ggml_ext_tensor_iter(control_video, [&](ggml_tensor* t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float control_video_value = ggml_tensor_get_f32(t, i0, i1, i2, i3) - 0.5f; float control_video_value = ggml_ext_tensor_get_f32(t, i0, i1, i2, i3) - 0.5f;
float mask_value = ggml_tensor_get_f32(mask, i0, i1, i2, 0); float mask_value = ggml_ext_tensor_get_f32(mask, i0, i1, i2, 0);
float inactive_value = (control_video_value * (1.f - mask_value)) + 0.5f; float inactive_value = (control_video_value * (1.f - mask_value)) + 0.5f;
float reactive_value = (control_video_value * mask_value) + 0.5f; float reactive_value = (control_video_value * mask_value) + 0.5f;
ggml_tensor_set_f32(inactive, inactive_value, i0, i1, i2, i3); ggml_ext_tensor_set_f32(inactive, inactive_value, i0, i1, i2, i3);
ggml_tensor_set_f32(reactive, reactive_value, i0, i1, i2, i3); ggml_ext_tensor_set_f32(reactive, reactive_value, i0, i1, i2, i3);
}); });
inactive = sd_ctx->sd->encode_first_stage(work_ctx, inactive); // [b*c, t, h/vae_scale_factor, w/vae_scale_factor] inactive = sd_ctx->sd->encode_first_stage(work_ctx, inactive); // [b*c, t, h/vae_scale_factor, w/vae_scale_factor]
@ -2938,16 +2938,16 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
ref_image_num = 1; ref_image_num = 1;
} }
vace_context = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, inactive->ne[0], inactive->ne[1], length, 96); // [b*96, t, h/vae_scale_factor, w/vae_scale_factor] vace_context = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, inactive->ne[0], inactive->ne[1], length, 96); // [b*96, t, h/vae_scale_factor, w/vae_scale_factor]
ggml_tensor_iter(vace_context, [&](ggml_tensor* vace_context, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { ggml_ext_tensor_iter(vace_context, [&](ggml_tensor* vace_context, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value; float value;
if (i3 < 32) { if (i3 < 32) {
if (ref_image_latent && i2 == 0) { if (ref_image_latent && i2 == 0) {
value = ggml_tensor_get_f32(ref_image_latent, i0, i1, 0, i3); value = ggml_ext_tensor_get_f32(ref_image_latent, i0, i1, 0, i3);
} else { } else {
if (i3 < 16) { if (i3 < 16) {
value = ggml_tensor_get_f32(inactive, i0, i1, i2 - ref_image_num, i3); value = ggml_ext_tensor_get_f32(inactive, i0, i1, i2 - ref_image_num, i3);
} else { } else {
value = ggml_tensor_get_f32(reactive, i0, i1, i2 - ref_image_num, i3 - 16); value = ggml_ext_tensor_get_f32(reactive, i0, i1, i2 - ref_image_num, i3 - 16);
} }
} }
} else { // mask } else { // mask
@ -2957,10 +2957,10 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
int64_t vae_stride = vae_scale_factor; int64_t vae_stride = vae_scale_factor;
int64_t mask_height_index = i1 * vae_stride + (i3 - 32) / vae_stride; int64_t mask_height_index = i1 * vae_stride + (i3 - 32) / vae_stride;
int64_t mask_width_index = i0 * vae_stride + (i3 - 32) % vae_stride; int64_t mask_width_index = i0 * vae_stride + (i3 - 32) % vae_stride;
value = ggml_tensor_get_f32(mask, mask_width_index, mask_height_index, i2 - ref_image_num, 0); value = ggml_ext_tensor_get_f32(mask, mask_width_index, mask_height_index, i2 - ref_image_num, 0);
} }
} }
ggml_tensor_set_f32(vace_context, value, i0, i1, i2, i3); ggml_ext_tensor_set_f32(vace_context, value, i0, i1, i2, i3);
}); });
int64_t t2 = ggml_time_ms(); int64_t t2 = ggml_time_ms();
LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1); LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1);
@ -3006,7 +3006,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
struct ggml_tensor* final_latent; struct ggml_tensor* final_latent;
struct ggml_tensor* x_t = init_latent; struct ggml_tensor* x_t = init_latent;
struct ggml_tensor* noise = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, T, C); struct ggml_tensor* noise = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, T, C);
ggml_tensor_set_f32_randn(noise, sd_ctx->sd->rng); ggml_ext_im_set_randn_f32(noise, sd_ctx->sd->rng);
// High Noise Sample // High Noise Sample
if (high_noise_sample_steps > 0) { if (high_noise_sample_steps > 0) {
LOG_DEBUG("sample(high noise) %dx%dx%d", W, H, T); LOG_DEBUG("sample(high noise) %dx%dx%d", W, H, T);
@ -3088,9 +3088,9 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
final_latent->ne[1], final_latent->ne[1],
final_latent->ne[2] - ref_image_num, final_latent->ne[2] - ref_image_num,
final_latent->ne[3]); final_latent->ne[3]);
ggml_tensor_iter(trim_latent, [&](ggml_tensor* trim_latent, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { ggml_ext_tensor_iter(trim_latent, [&](ggml_tensor* trim_latent, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = ggml_tensor_get_f32(final_latent, i0, i1, i2 + ref_image_num, i3); float value = ggml_ext_tensor_get_f32(final_latent, i0, i1, i2 + ref_image_num, i3);
ggml_tensor_set_f32(trim_latent, value, i0, i1, i2, i3); ggml_ext_tensor_set_f32(trim_latent, value, i0, i1, i2, i3);
}); });
final_latent = trim_latent; final_latent = trim_latent;
} }
@ -3115,7 +3115,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
result_images[i].width = vid->ne[0]; result_images[i].width = vid->ne[0];
result_images[i].height = vid->ne[1]; result_images[i].height = vid->ne[1];
result_images[i].channel = 3; result_images[i].channel = 3;
result_images[i].data = sd_tensor_to_image(vid, i, true); result_images[i].data = ggml_tensor_to_sd_image(vid, i, true);
} }
ggml_free(work_ctx); ggml_free(work_ctx);

2
t5.hpp
View File

@ -611,7 +611,7 @@ public:
k = ggml_scale_inplace(ctx, k, sqrt(d_head)); k = ggml_scale_inplace(ctx, k, sqrt(d_head));
x = ggml_nn_attention_ext(ctx, backend, q, k, v, num_heads, mask); // [N, n_token, d_head * n_head] x = ggml_ext_attention_ext(ctx, backend, q, k, v, num_heads, mask); // [N, n_token, d_head * n_head]
x = out_proj->forward(ctx, x); // [N, n_token, model_dim] x = out_proj->forward(ctx, x); // [N, n_token, model_dim]
return {x, past_bias}; return {x, past_bias};

View File

@ -114,7 +114,7 @@ public:
auto num_frames = ggml_arange(ctx, 0, timesteps, 1); auto num_frames = ggml_arange(ctx, 0, timesteps, 1);
// since b is 1, no need to do repeat // since b is 1, no need to do repeat
auto t_emb = ggml_nn_timestep_embedding(ctx, num_frames, in_channels, max_time_embed_period); // [N, in_channels] auto t_emb = ggml_ext_timestep_embedding(ctx, num_frames, in_channels, max_time_embed_period); // [N, in_channels]
auto emb = time_pos_embed_0->forward(ctx, t_emb); auto emb = time_pos_embed_0->forward(ctx, t_emb);
emb = ggml_silu_inplace(ctx, emb); emb = ggml_silu_inplace(ctx, emb);
@ -451,7 +451,7 @@ public:
auto out_0 = std::dynamic_pointer_cast<GroupNorm32>(blocks["out.0"]); auto out_0 = std::dynamic_pointer_cast<GroupNorm32>(blocks["out.0"]);
auto out_2 = std::dynamic_pointer_cast<Conv2d>(blocks["out.2"]); auto out_2 = std::dynamic_pointer_cast<Conv2d>(blocks["out.2"]);
auto t_emb = ggml_nn_timestep_embedding(ctx, timesteps, model_channels); // [N, model_channels] auto t_emb = ggml_ext_timestep_embedding(ctx, timesteps, model_channels); // [N, model_channels]
auto emb = time_embed_0->forward(ctx, t_emb); auto emb = time_embed_0->forward(ctx, t_emb);
emb = ggml_silu_inplace(ctx, emb); emb = ggml_silu_inplace(ctx, emb);

View File

@ -82,7 +82,7 @@ struct UpscalerGGML {
} }
// LOG_DEBUG("upscale work buffer size: %.2f MB", params.mem_size / 1024.f / 1024.f); // LOG_DEBUG("upscale work buffer size: %.2f MB", params.mem_size / 1024.f / 1024.f);
ggml_tensor* input_image_tensor = ggml_new_tensor_4d(upscale_ctx, GGML_TYPE_F32, input_image.width, input_image.height, 3, 1); ggml_tensor* input_image_tensor = ggml_new_tensor_4d(upscale_ctx, GGML_TYPE_F32, input_image.width, input_image.height, 3, 1);
sd_image_to_tensor(input_image, input_image_tensor); sd_image_to_ggml_tensor(input_image, input_image_tensor);
ggml_tensor* upscaled = ggml_new_tensor_4d(upscale_ctx, GGML_TYPE_F32, output_width, output_height, 3, 1); ggml_tensor* upscaled = ggml_new_tensor_4d(upscale_ctx, GGML_TYPE_F32, output_width, output_height, 3, 1);
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
@ -91,8 +91,8 @@ struct UpscalerGGML {
int64_t t0 = ggml_time_ms(); int64_t t0 = ggml_time_ms();
sd_tiling(input_image_tensor, upscaled, esrgan_upscaler->scale, esrgan_upscaler->tile_size, 0.25f, on_tiling); sd_tiling(input_image_tensor, upscaled, esrgan_upscaler->scale, esrgan_upscaler->tile_size, 0.25f, on_tiling);
esrgan_upscaler->free_compute_buffer(); esrgan_upscaler->free_compute_buffer();
ggml_tensor_clamp(upscaled, 0.f, 1.f); ggml_ext_tensor_clamp_inplace(upscaled, 0.f, 1.f);
uint8_t* upscaled_data = sd_tensor_to_image(upscaled); uint8_t* upscaled_data = ggml_tensor_to_sd_image(upscaled);
ggml_free(upscale_ctx); ggml_free(upscale_ctx);
int64_t t3 = ggml_time_ms(); int64_t t3 = ggml_time_ms();
LOG_INFO("input_image_tensor upscaled, taking %.2fs", (t3 - t0) / 1000.0f); LOG_INFO("input_image_tensor upscaled, taking %.2fs", (t3 - t0) / 1000.0f);

10
vae.hpp
View File

@ -102,7 +102,7 @@ public:
auto v = v_proj->forward(ctx, h_); // [N, in_channels, h, w] auto v = v_proj->forward(ctx, h_); // [N, in_channels, h, w]
v = ggml_reshape_3d(ctx, v, h * w, c, n); // [N, in_channels, h * w] v = ggml_reshape_3d(ctx, v, h * w, c, n); // [N, in_channels, h * w]
h_ = ggml_nn_attention(ctx, q, k, v, false); // [N, h * w, in_channels] h_ = ggml_ext_attention(ctx, q, k, v, false); // [N, h * w, in_channels]
h_ = ggml_cont(ctx, ggml_permute(ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w] h_ = ggml_cont(ctx, ggml_permute(ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w]
h_ = ggml_reshape_4d(ctx, h_, w, h, c, n); // [N, in_channels, h, w] h_ = ggml_reshape_4d(ctx, h_, w, h, c, n); // [N, in_channels, h, w]
@ -169,7 +169,7 @@ protected:
} }
float get_alpha() { float get_alpha() {
float alpha = ggml_backend_tensor_get_f32(params["mix_factor"]); float alpha = ggml_ext_backend_tensor_get_f32(params["mix_factor"]);
return sigmoid(alpha); return sigmoid(alpha);
} }
@ -544,9 +544,9 @@ struct FakeVAE : public VAE {
if (*output == nullptr && output_ctx != nullptr) { if (*output == nullptr && output_ctx != nullptr) {
*output = ggml_dup_tensor(output_ctx, z); *output = ggml_dup_tensor(output_ctx, z);
} }
ggml_tensor_iter(z, [&](ggml_tensor* z, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { ggml_ext_tensor_iter(z, [&](ggml_tensor* z, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = ggml_tensor_get_f32(z, i0, i1, i2, i3); float value = ggml_ext_tensor_get_f32(z, i0, i1, i2, i3);
ggml_tensor_set_f32(*output, value, i0, i1, i2, i3); ggml_ext_tensor_set_f32(*output, value, i0, i1, i2, i3);
}); });
} }

174
wan.hpp
View File

@ -76,10 +76,10 @@ namespace WAN {
} }
x = ggml_pad_ext(ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, 0, 0); x = ggml_pad_ext(ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, 0, 0);
return ggml_nn_conv_3d(ctx, x, w, b, in_channels, return ggml_ext_conv_3d(ctx, x, w, b, in_channels,
std::get<2>(stride), std::get<1>(stride), std::get<0>(stride), std::get<2>(stride), std::get<1>(stride), std::get<0>(stride),
0, 0, 0, 0, 0, 0,
std::get<2>(dilation), std::get<1>(dilation), std::get<0>(dilation)); std::get<2>(dilation), std::get<1>(dilation), std::get<0>(dilation));
} }
}; };
@ -101,10 +101,10 @@ namespace WAN {
// assert N == 1 // assert N == 1
struct ggml_tensor* w = params["gamma"]; struct ggml_tensor* w = params["gamma"];
auto h = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 3, 0, 1, 2)); // [ID, IH, IW, N*IC] auto h = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 3, 0, 1, 2)); // [ID, IH, IW, N*IC]
h = ggml_rms_norm(ctx, h, 1e-12); h = ggml_rms_norm(ctx, h, 1e-12);
h = ggml_mul(ctx, h, w); h = ggml_mul(ctx, h, w);
h = ggml_nn_cont(ctx, ggml_torch_permute(ctx, h, 1, 2, 3, 0)); h = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, h, 1, 2, 3, 0));
return h; return h;
} }
@ -165,11 +165,11 @@ namespace WAN {
} else { } else {
auto time_conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["time_conv"]); auto time_conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["time_conv"]);
auto cache_x = ggml_slice(ctx, x, 2, -CACHE_T, x->ne[2]); auto cache_x = ggml_ext_slice(ctx, x, 2, -CACHE_T, x->ne[2]);
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) { // chunk_idx >= 2 if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) { // chunk_idx >= 2
// cache last frame of last two chunk // cache last frame of last two chunk
cache_x = ggml_concat(ctx, cache_x = ggml_concat(ctx,
ggml_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), ggml_ext_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
cache_x, cache_x,
2); 2);
} }
@ -183,9 +183,9 @@ namespace WAN {
x = time_conv->forward(ctx, x, feat_cache[idx]); x = time_conv->forward(ctx, x, feat_cache[idx]);
} }
feat_cache[idx] = cache_x; feat_cache[idx] = cache_x;
x = ggml_reshape_4d(ctx, x, w * h, t, c, 2); // (2, c, t, h*w) x = ggml_reshape_4d(ctx, x, w * h, t, c, 2); // (2, c, t, h*w)
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 3, 1, 2)); // (c, t, 2, h*w) x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 3, 1, 2)); // (c, t, 2, h*w)
x = ggml_reshape_4d(ctx, x, w, h, 2 * t, c); // (c, t*2, h, w) x = ggml_reshape_4d(ctx, x, w, h, 2 * t, c); // (c, t*2, h, w)
} }
} }
} }
@ -194,7 +194,7 @@ namespace WAN {
if (mode != "none") { if (mode != "none") {
auto resample_1 = std::dynamic_pointer_cast<Conv2d>(blocks["resample.1"]); auto resample_1 = std::dynamic_pointer_cast<Conv2d>(blocks["resample.1"]);
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (t, c, h, w) x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 1, 3, 2)); // (t, c, h, w)
if (mode == "upsample2d") { if (mode == "upsample2d") {
x = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST); x = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST);
} else if (mode == "upsample3d") { } else if (mode == "upsample3d") {
@ -205,7 +205,7 @@ namespace WAN {
x = ggml_pad(ctx, x, 1, 1, 0, 0); x = ggml_pad(ctx, x, 1, 1, 0, 0);
} }
x = resample_1->forward(ctx, x); x = resample_1->forward(ctx, x);
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (c, t, h, w) x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 1, 3, 2)); // (c, t, h, w)
} }
if (mode == "downsample3d") { if (mode == "downsample3d") {
@ -217,9 +217,9 @@ namespace WAN {
} else { } else {
auto time_conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["time_conv"]); auto time_conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["time_conv"]);
auto cache_x = ggml_slice(ctx, x, 2, -1, x->ne[2]); auto cache_x = ggml_ext_slice(ctx, x, 2, -1, x->ne[2]);
x = ggml_concat(ctx, x = ggml_concat(ctx,
ggml_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), ggml_ext_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
x, x,
2); 2);
x = time_conv->forward(ctx, x); x = time_conv->forward(ctx, x);
@ -266,15 +266,15 @@ namespace WAN {
T = x->ne[2]; T = x->ne[2];
x = ggml_reshape_4d(ctx, x, W * H, factor_t, T / factor_t, C); // [C, T/factor_t, factor_t, H*W] x = ggml_reshape_4d(ctx, x, W * H, factor_t, T / factor_t, C); // [C, T/factor_t, factor_t, H*W]
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [C, factor_t, T/factor_t, H*W] x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [C, factor_t, T/factor_t, H*W]
x = ggml_reshape_4d(ctx, x, W, factor_s, (H / factor_s) * (T / factor_t), factor_t * C); // [C*factor_t, T/factor_t*H/factor_s, factor_s, W] x = ggml_reshape_4d(ctx, x, W, factor_s, (H / factor_s) * (T / factor_t), factor_t * C); // [C*factor_t, T/factor_t*H/factor_s, factor_s, W]
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [C*factor_t, factor_s, T/factor_t*H/factor_s, W] x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [C*factor_t, factor_s, T/factor_t*H/factor_s, W]
x = ggml_reshape_4d(ctx, x, factor_s, W / factor_s, (H / factor_s) * (T / factor_t), factor_s * factor_t * C); // [C*factor_t*factor_s, T/factor_t*H/factor_s, W/factor_s, factor_s] x = ggml_reshape_4d(ctx, x, factor_s, W / factor_s, (H / factor_s) * (T / factor_t), factor_s * factor_t * C); // [C*factor_t*factor_s, T/factor_t*H/factor_s, W/factor_s, factor_s]
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 1, 2, 0, 3)); // [C*factor_t*factor_s, factor_s, T/factor_t*H/factor_s, W/factor_s] x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 2, 0, 3)); // [C*factor_t*factor_s, factor_s, T/factor_t*H/factor_s, W/factor_s]
x = ggml_reshape_3d(ctx, x, (W / factor_s) * (H / factor_s) * (T / factor_t), group_size, out_channels); // [out_channels, group_size, T/factor_t*H/factor_s*W/factor_s] x = ggml_reshape_3d(ctx, x, (W / factor_s) * (H / factor_s) * (T / factor_t), group_size, out_channels); // [out_channels, group_size, T/factor_t*H/factor_s*W/factor_s]
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 1, 0, 2, 3)); // [out_channels, T/factor_t*H/factor_s*W/factor_s, group_size] x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 0, 2, 3)); // [out_channels, T/factor_t*H/factor_s*W/factor_s, group_size]
x = ggml_mean(ctx, x); // [out_channels, T/factor_t*H/factor_s*W/factor_s, 1] x = ggml_mean(ctx, x); // [out_channels, T/factor_t*H/factor_s*W/factor_s, 1]
x = ggml_reshape_4d(ctx, x, W / factor_s, H / factor_s, T / factor_t, out_channels); x = ggml_reshape_4d(ctx, x, W / factor_s, H / factor_s, T / factor_t, out_channels);
return x; return x;
} }
@ -316,15 +316,15 @@ namespace WAN {
C = out_channels; C = out_channels;
x = ggml_reshape_4d(ctx, x, W, H * T, factor_s, factor_s * factor_t * C); // [C*factor_t*factor_s, factor_s, T*H, W] x = ggml_reshape_4d(ctx, x, W, H * T, factor_s, factor_s * factor_t * C); // [C*factor_t*factor_s, factor_s, T*H, W]
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 2, 0, 1, 3)); // [C*factor_t*factor_s, T*H, W, factor_s] x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [C*factor_t*factor_s, T*H, W, factor_s]
x = ggml_reshape_4d(ctx, x, factor_s * W, H * T, factor_s, factor_t * C); // [C*factor_t, factor_s, T*H, W*factor_s] x = ggml_reshape_4d(ctx, x, factor_s * W, H * T, factor_s, factor_t * C); // [C*factor_t, factor_s, T*H, W*factor_s]
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [C*factor_t, T*H, factor_s, W*factor_s] x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [C*factor_t, T*H, factor_s, W*factor_s]
x = ggml_reshape_4d(ctx, x, factor_s * W * factor_s * H, T, factor_t, C); // [C, factor_t, T, H*factor_s*W*factor_s] x = ggml_reshape_4d(ctx, x, factor_s * W * factor_s * H, T, factor_t, C); // [C, factor_t, T, H*factor_s*W*factor_s]
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [C, T, factor_t, H*factor_s*W*factor_s] x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [C, T, factor_t, H*factor_s*W*factor_s]
x = ggml_reshape_4d(ctx, x, factor_s * W, factor_s * H, factor_t * T, C); // [C, T*factor_t, H*factor_s, W*factor_s] x = ggml_reshape_4d(ctx, x, factor_s * W, factor_s * H, factor_t * T, C); // [C, T*factor_t, H*factor_s, W*factor_s]
if (first_chunk) { if (first_chunk) {
x = ggml_slice(ctx, x, 2, factor_t - 1, x->ne[2]); x = ggml_ext_slice(ctx, x, 2, factor_t - 1, x->ne[2]);
} }
return x; return x;
@ -374,11 +374,11 @@ namespace WAN {
if (feat_cache.size() > 0) { if (feat_cache.size() > 0) {
int idx = feat_idx; int idx = feat_idx;
auto cache_x = ggml_slice(ctx, x, 2, -CACHE_T, x->ne[2]); auto cache_x = ggml_ext_slice(ctx, x, 2, -CACHE_T, x->ne[2]);
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) { if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
// cache last frame of last two chunk // cache last frame of last two chunk
cache_x = ggml_concat(ctx, cache_x = ggml_concat(ctx,
ggml_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), ggml_ext_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
cache_x, cache_x,
2); 2);
} }
@ -545,7 +545,7 @@ namespace WAN {
x = norm->forward(ctx, x); x = norm->forward(ctx, x);
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (t, c, h, w) x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 1, 3, 2)); // (t, c, h, w)
const int64_t n = x->ne[3]; const int64_t n = x->ne[3];
const int64_t c = x->ne[2]; const int64_t c = x->ne[2];
@ -556,26 +556,26 @@ namespace WAN {
auto qkv_vec = split_image_qkv(ctx, qkv); auto qkv_vec = split_image_qkv(ctx, qkv);
auto q = qkv_vec[0]; auto q = qkv_vec[0];
q = ggml_nn_cont(ctx, ggml_torch_permute(ctx, q, 2, 0, 1, 3)); // [t, h, w, c] q = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, q, 2, 0, 1, 3)); // [t, h, w, c]
q = ggml_reshape_3d(ctx, q, c, h * w, n); // [t, h * w, c] q = ggml_reshape_3d(ctx, q, c, h * w, n); // [t, h * w, c]
auto k = qkv_vec[1]; auto k = qkv_vec[1];
k = ggml_nn_cont(ctx, ggml_torch_permute(ctx, k, 2, 0, 1, 3)); // [t, h, w, c] k = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, k, 2, 0, 1, 3)); // [t, h, w, c]
k = ggml_reshape_3d(ctx, k, c, h * w, n); // [t, h * w, c] k = ggml_reshape_3d(ctx, k, c, h * w, n); // [t, h * w, c]
auto v = qkv_vec[2]; auto v = qkv_vec[2];
v = ggml_reshape_3d(ctx, v, h * w, c, n); // [t, c, h * w] v = ggml_reshape_3d(ctx, v, h * w, c, n); // [t, c, h * w]
x = ggml_nn_attention(ctx, q, k, v, false); // [t, h * w, c] x = ggml_ext_attention(ctx, q, k, v, false); // [t, h * w, c]
// v = ggml_cont(ctx, ggml_torch_permute(ctx, v, 1, 0, 2, 3)); // [t, h * w, c] // v = ggml_cont(ctx, ggml_ext_torch_permute(ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
// x = ggml_nn_attention_ext(ctx, q, k, v, q->ne[2], nullptr, false, false, true); // x = ggml_ext_attention_ext(ctx, q, k, v, q->ne[2], nullptr, false, false, true);
x = ggml_nn_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [t, c, h * w] x = ggml_ext_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [t, c, h * w]
x = ggml_reshape_4d(ctx, x, w, h, c, n); // [t, c, h, w] x = ggml_reshape_4d(ctx, x, w, h, c, n); // [t, c, h, w]
x = proj->forward(ctx, x); x = proj->forward(ctx, x);
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (c, t, h, w) x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 1, 3, 2)); // (c, t, h, w)
x = ggml_add(ctx, x, identity); x = ggml_add(ctx, x, identity);
return x; return x;
@ -673,11 +673,11 @@ namespace WAN {
// conv1 // conv1
if (feat_cache.size() > 0) { if (feat_cache.size() > 0) {
int idx = feat_idx; int idx = feat_idx;
auto cache_x = ggml_slice(ctx, x, 2, -CACHE_T, x->ne[2]); auto cache_x = ggml_ext_slice(ctx, x, 2, -CACHE_T, x->ne[2]);
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) { if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
// cache last frame of last two chunk // cache last frame of last two chunk
cache_x = ggml_concat(ctx, cache_x = ggml_concat(ctx,
ggml_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), ggml_ext_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
cache_x, cache_x,
2); 2);
} }
@ -725,11 +725,11 @@ namespace WAN {
x = ggml_silu(ctx, x); x = ggml_silu(ctx, x);
if (feat_cache.size() > 0) { if (feat_cache.size() > 0) {
int idx = feat_idx; int idx = feat_idx;
auto cache_x = ggml_slice(ctx, x, 2, -CACHE_T, x->ne[2]); auto cache_x = ggml_ext_slice(ctx, x, 2, -CACHE_T, x->ne[2]);
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) { if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
// cache last frame of last two chunk // cache last frame of last two chunk
cache_x = ggml_concat(ctx, cache_x = ggml_concat(ctx,
ggml_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), ggml_ext_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
cache_x, cache_x,
2); 2);
} }
@ -844,11 +844,11 @@ namespace WAN {
// conv1 // conv1
if (feat_cache.size() > 0) { if (feat_cache.size() > 0) {
int idx = feat_idx; int idx = feat_idx;
auto cache_x = ggml_slice(ctx, x, 2, -CACHE_T, x->ne[2]); auto cache_x = ggml_ext_slice(ctx, x, 2, -CACHE_T, x->ne[2]);
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) { if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
// cache last frame of last two chunk // cache last frame of last two chunk
cache_x = ggml_concat(ctx, cache_x = ggml_concat(ctx,
ggml_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), ggml_ext_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
cache_x, cache_x,
2); 2);
} }
@ -896,11 +896,11 @@ namespace WAN {
x = ggml_silu(ctx, x); x = ggml_silu(ctx, x);
if (feat_cache.size() > 0) { if (feat_cache.size() > 0) {
int idx = feat_idx; int idx = feat_idx;
auto cache_x = ggml_slice(ctx, x, 2, -CACHE_T, x->ne[2]); auto cache_x = ggml_ext_slice(ctx, x, 2, -CACHE_T, x->ne[2]);
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) { if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
// cache last frame of last two chunk // cache last frame of last two chunk
cache_x = ggml_concat(ctx, cache_x = ggml_concat(ctx,
ggml_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), ggml_ext_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
cache_x, cache_x,
2); 2);
} }
@ -978,13 +978,13 @@ namespace WAN {
int64_t h = x->ne[1] / q; int64_t h = x->ne[1] / q;
int64_t w = x->ne[0] / r; int64_t w = x->ne[0] / r;
x = ggml_reshape_4d(ctx, x, r * w, q, h, f * c * b); // [b*c*f, h, q, w*r] x = ggml_reshape_4d(ctx, x, r * w, q, h, f * c * b); // [b*c*f, h, q, w*r]
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [b*c*f, q, h, w*r] x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [b*c*f, q, h, w*r]
x = ggml_reshape_4d(ctx, x, r, w, h * q, f * c * b); // [b*c*f, q*h, w, r] x = ggml_reshape_4d(ctx, x, r, w, h * q, f * c * b); // [b*c*f, q*h, w, r]
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 1, 2, 0, 3)); // [b*c*f, r, q*h, w] x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 2, 0, 3)); // [b*c*f, r, q*h, w]
x = ggml_reshape_4d(ctx, x, w * h, q * r, f, c * b); // [b*c, f, r*q, h*w] x = ggml_reshape_4d(ctx, x, w * h, q * r, f, c * b); // [b*c, f, r*q, h*w]
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [b*c, r*q, f, h*w] x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [b*c, r*q, f, h*w]
x = ggml_reshape_4d(ctx, x, w, h, f, q * r * c * b); // [b*c*r*q, f, h, w] x = ggml_reshape_4d(ctx, x, w, h, f, q * r * c * b); // [b*c*r*q, f, h, w]
return x; return x;
} }
@ -1005,13 +1005,13 @@ namespace WAN {
int64_t h = x->ne[1]; int64_t h = x->ne[1];
int64_t w = x->ne[0]; int64_t w = x->ne[0];
x = ggml_reshape_4d(ctx, x, w * h, f, q * r, c * b); // [b*c, r*q, f, h*w] x = ggml_reshape_4d(ctx, x, w * h, f, q * r, c * b); // [b*c, r*q, f, h*w]
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [b*c, f, r*q, h*w] x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [b*c, f, r*q, h*w]
x = ggml_reshape_4d(ctx, x, w, h * q, r, f * c * b); // [b*c*f, r, q*h, w] x = ggml_reshape_4d(ctx, x, w, h * q, r, f * c * b); // [b*c*f, r, q*h, w]
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 2, 0, 1, 3)); // [b*c*f, q*h, w, r] x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [b*c*f, q*h, w, r]
x = ggml_reshape_4d(ctx, x, r * w, h, q, f * c * b); // [b*c*f, q, h, w*r] x = ggml_reshape_4d(ctx, x, r * w, h, q, f * c * b); // [b*c*f, q, h, w*r]
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [b*c*f, h, q, w*r] x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [b*c*f, h, q, w*r]
x = ggml_reshape_4d(ctx, x, r * w, q * h, f, c * b); // [b*c, f, h*q, w*r] x = ggml_reshape_4d(ctx, x, r * w, q * h, f, c * b); // [b*c, f, h*q, w*r]
return x; return x;
} }
@ -1037,16 +1037,16 @@ namespace WAN {
for (int i = 0; i < iter_; i++) { for (int i = 0; i < iter_; i++) {
_enc_conv_idx = 0; _enc_conv_idx = 0;
if (i == 0) { if (i == 0) {
auto in = ggml_slice(ctx, x, 2, 0, 1); // [b*c, 1, h, w] auto in = ggml_ext_slice(ctx, x, 2, 0, 1); // [b*c, 1, h, w]
out = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx, i); out = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx, i);
} else { } else {
auto in = ggml_slice(ctx, x, 2, 1 + 4 * (i - 1), 1 + 4 * i); // [b*c, 4, h, w] auto in = ggml_ext_slice(ctx, x, 2, 1 + 4 * (i - 1), 1 + 4 * i); // [b*c, 4, h, w]
auto out_ = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx, i); auto out_ = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx, i);
out = ggml_concat(ctx, out, out_, 2); out = ggml_concat(ctx, out, out_, 2);
} }
} }
out = conv1->forward(ctx, out); out = conv1->forward(ctx, out);
auto mu = ggml_chunk(ctx, out, 2, 3)[0]; auto mu = ggml_ext_chunk(ctx, out, 2, 3)[0];
clear_cache(); clear_cache();
return mu; return mu;
} }
@ -1068,10 +1068,10 @@ namespace WAN {
for (int64_t i = 0; i < iter_; i++) { for (int64_t i = 0; i < iter_; i++) {
_conv_idx = 0; _conv_idx = 0;
if (i == 0) { if (i == 0) {
auto in = ggml_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w] auto in = ggml_ext_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i); out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i);
} else { } else {
auto in = ggml_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w] auto in = ggml_ext_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
auto out_ = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i); auto out_ = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i);
out = ggml_concat(ctx, out, out_, 2); out = ggml_concat(ctx, out, out_, 2);
} }
@ -1094,7 +1094,7 @@ namespace WAN {
auto conv2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv2"]); auto conv2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv2"]);
auto x = conv2->forward(ctx, z); auto x = conv2->forward(ctx, z);
auto in = ggml_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w] auto in = ggml_ext_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
_conv_idx = 0; _conv_idx = 0;
auto out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i); auto out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i);
if (wan2_2) { if (wan2_2) {
@ -1197,9 +1197,9 @@ namespace WAN {
for (int64_t i2 = 0; i2 < out->ne[2]; i2++) { for (int64_t i2 = 0; i2 < out->ne[2]; i2++) {
for (int64_t i1 = 0; i1 < out->ne[1]; i1++) { for (int64_t i1 = 0; i1 < out->ne[1]; i1++) {
for (int64_t i0 = 0; i0 < out->ne[0]; i0++) { for (int64_t i0 = 0; i0 < out->ne[0]; i0++) {
float value = ggml_tensor_get_f32(out, i0, i1, i2, i3); float value = ggml_ext_tensor_get_f32(out, i0, i1, i2, i3);
int64_t offset = (i == 0) ? 0 : (1 + (i - 1) * 4); int64_t offset = (i == 0) ? 0 : (1 + (i - 1) * 4);
ggml_tensor_set_f32(*output, value, i0, i1, offset + i2, i3); ggml_ext_tensor_set_f32(*output, value, i0, i1, offset + i2, i3);
} }
} }
} }
@ -1390,7 +1390,7 @@ namespace WAN {
k = norm_k->forward(ctx, k); k = norm_k->forward(ctx, k);
auto v = v_proj->forward(ctx, context); // [N, n_context, dim] auto v = v_proj->forward(ctx, context); // [N, n_context, dim]
x = ggml_nn_attention_ext(ctx, backend, q, k, v, num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim] x = ggml_ext_attention_ext(ctx, backend, q, k, v, num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim]
x = o_proj->forward(ctx, x); // [N, n_token, dim] x = o_proj->forward(ctx, x); // [N, n_token, dim]
return x; return x;
@ -1441,11 +1441,11 @@ namespace WAN {
int64_t dim = x->ne[0]; int64_t dim = x->ne[0];
int64_t context_txt_len = context->ne[1] - context_img_len; int64_t context_txt_len = context->ne[1] - context_img_len;
context = ggml_nn_cont(ctx, ggml_torch_permute(ctx, context, 0, 2, 1, 3)); // [context_img_len + context_txt_len, N, dim] context = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, context, 0, 2, 1, 3)); // [context_img_len + context_txt_len, N, dim]
auto context_img = ggml_view_3d(ctx, context, dim, N, context_img_len, context->nb[1], context->nb[2], 0); auto context_img = ggml_view_3d(ctx, context, dim, N, context_img_len, context->nb[1], context->nb[2], 0);
auto context_txt = ggml_view_3d(ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_img_len * context->nb[2]); auto context_txt = ggml_view_3d(ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_img_len * context->nb[2]);
context_img = ggml_nn_cont(ctx, ggml_torch_permute(ctx, context_img, 0, 2, 1, 3)); // [N, context_img_len, dim] context_img = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, context_img, 0, 2, 1, 3)); // [N, context_img_len, dim]
context_txt = ggml_nn_cont(ctx, ggml_torch_permute(ctx, context_txt, 0, 2, 1, 3)); // [N, context_txt_len, dim] context_txt = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, context_txt, 0, 2, 1, 3)); // [N, context_txt_len, dim]
auto q = q_proj->forward(ctx, x); auto q = q_proj->forward(ctx, x);
q = norm_q->forward(ctx, q); q = norm_q->forward(ctx, q);
@ -1457,8 +1457,8 @@ namespace WAN {
k_img = norm_k_img->forward(ctx, k_img); k_img = norm_k_img->forward(ctx, k_img);
auto v_img = v_img_proj->forward(ctx, context_img); // [N, context_img_len, dim] auto v_img = v_img_proj->forward(ctx, context_img); // [N, context_img_len, dim]
auto img_x = ggml_nn_attention_ext(ctx, backend, q, k_img, v_img, num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim] auto img_x = ggml_ext_attention_ext(ctx, backend, q, k_img, v_img, num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim]
x = ggml_nn_attention_ext(ctx, backend, q, k, v, num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim] x = ggml_ext_attention_ext(ctx, backend, q, k, v, num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim]
x = ggml_add(ctx, x, img_x); x = ggml_add(ctx, x, img_x);
@ -1548,7 +1548,7 @@ namespace WAN {
auto modulation = params["modulation"]; auto modulation = params["modulation"];
e = ggml_add(ctx, e, modulation); // [N, 6, dim] or [N, T, 6, dim] e = ggml_add(ctx, e, modulation); // [N, 6, dim] or [N, T, 6, dim]
auto es = ggml_chunk(ctx, e, 6, 1); // ([N, 1, dim], ...) or [N, T, 1, dim] auto es = ggml_ext_chunk(ctx, e, 6, 1); // ([N, 1, dim], ...) or [N, T, 1, dim]
auto norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]); auto norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
auto self_attn = std::dynamic_pointer_cast<WanSelfAttention>(blocks["self_attn"]); auto self_attn = std::dynamic_pointer_cast<WanSelfAttention>(blocks["self_attn"]);
@ -1672,7 +1672,7 @@ namespace WAN {
e = ggml_repeat_4d(ctx, e, e->ne[0], 2, e->ne[2], e->ne[3]); // [N, 2, dim] or [N, T, 2, dim] e = ggml_repeat_4d(ctx, e, e->ne[0], 2, e->ne[2], e->ne[3]); // [N, 2, dim] or [N, T, 2, dim]
e = ggml_add(ctx, e, modulation); // [N, 2, dim] or [N, T, 2, dim] e = ggml_add(ctx, e, modulation); // [N, 2, dim] or [N, T, 2, dim]
auto es = ggml_chunk(ctx, e, 2, 1); // ([N, 1, dim], ...) or ([N, T, 1, dim], ...) auto es = ggml_ext_chunk(ctx, e, 2, 1); // ([N, 1, dim], ...) or ([N, T, 1, dim], ...)
auto norm = std::dynamic_pointer_cast<LayerNorm>(blocks["norm"]); auto norm = std::dynamic_pointer_cast<LayerNorm>(blocks["norm"]);
auto head = std::dynamic_pointer_cast<Linear>(blocks["head"]); auto head = std::dynamic_pointer_cast<Linear>(blocks["head"]);
@ -1713,8 +1713,8 @@ namespace WAN {
if (flf_pos_embed_token_number > 0) { if (flf_pos_embed_token_number > 0) {
auto emb_pos = params["emb_pos"]; auto emb_pos = params["emb_pos"];
auto a = ggml_slice(ctx, image_embeds, 1, 0, emb_pos->ne[1]); auto a = ggml_ext_slice(ctx, image_embeds, 1, 0, emb_pos->ne[1]);
auto b = ggml_slice(ctx, emb_pos, 1, 0, image_embeds->ne[1]); auto b = ggml_ext_slice(ctx, emb_pos, 1, 0, image_embeds->ne[1]);
image_embeds = ggml_add(ctx, a, b); image_embeds = ggml_add(ctx, a, b);
} }
@ -1861,13 +1861,13 @@ namespace WAN {
GGML_ASSERT(C * pt * ph * pw == x->ne[0]); GGML_ASSERT(C * pt * ph * pw == x->ne[0]);
x = ggml_reshape_4d(ctx, x, C, pw * ph * pt, w_len * h_len * t_len, N); // [N, t_len*h_len*w_len, pt*ph*pw, C] x = ggml_reshape_4d(ctx, x, C, pw * ph * pt, w_len * h_len * t_len, N); // [N, t_len*h_len*w_len, pt*ph*pw, C]
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 1, 2, 0, 3)); // [N, C, t_len*h_len*w_len, pt*ph*pw] x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 2, 0, 3)); // [N, C, t_len*h_len*w_len, pt*ph*pw]
x = ggml_reshape_4d(ctx, x, pw, ph * pt, w_len, h_len * t_len * C * N); // [N*C*t_len*h_len, w_len, pt*ph, pw] x = ggml_reshape_4d(ctx, x, pw, ph * pt, w_len, h_len * t_len * C * N); // [N*C*t_len*h_len, w_len, pt*ph, pw]
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len*h_len, pt*ph, w_len, pw] x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len*h_len, pt*ph, w_len, pw]
x = ggml_reshape_4d(ctx, x, pw * w_len, ph, pt, h_len * t_len * C * N); // [N*C*t_len*h_len, pt, ph, w_len*pw] x = ggml_reshape_4d(ctx, x, pw * w_len, ph, pt, h_len * t_len * C * N); // [N*C*t_len*h_len, pt, ph, w_len*pw]
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len*h_len, ph, pt, w_len*pw] x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len*h_len, ph, pt, w_len*pw]
x = ggml_reshape_4d(ctx, x, pw * w_len, pt, ph * h_len, t_len * C * N); // [N*C*t_len, h_len*ph, pt, w_len*pw] x = ggml_reshape_4d(ctx, x, pw * w_len, pt, ph * h_len, t_len * C * N); // [N*C*t_len, h_len*ph, pt, w_len*pw]
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len, pt, h_len*ph, w_len*pw] x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len, pt, h_len*ph, w_len*pw]
x = ggml_reshape_4d(ctx, x, pw * w_len, ph * h_len, pt * t_len, C * N); // [N*C, t_len*pt, h_len*ph, w_len*pw] x = ggml_reshape_4d(ctx, x, pw * w_len, ph * h_len, pt * t_len, C * N); // [N*C, t_len*pt, h_len*ph, w_len*pw]
return x; return x;
} }
@ -1904,10 +1904,10 @@ namespace WAN {
// patch_embedding // patch_embedding
x = patch_embedding->forward(ctx, x); // [N*dim, t_len, h_len, w_len] x = patch_embedding->forward(ctx, x); // [N*dim, t_len, h_len, w_len]
x = ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1] * x->ne[2], x->ne[3] / N, N); // [N, dim, t_len*h_len*w_len] x = ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1] * x->ne[2], x->ne[3] / N, N); // [N, dim, t_len*h_len*w_len]
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 1, 0, 2, 3)); // [N, t_len*h_len*w_len, dim] x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 0, 2, 3)); // [N, t_len*h_len*w_len, dim]
// time_embedding // time_embedding
auto e = ggml_nn_timestep_embedding(ctx, timestep, params.freq_dim); auto e = ggml_ext_timestep_embedding(ctx, timestep, params.freq_dim);
e = time_embedding_0->forward(ctx, e); e = time_embedding_0->forward(ctx, e);
e = ggml_silu_inplace(ctx, e); e = ggml_silu_inplace(ctx, e);
e = time_embedding_2->forward(ctx, e); // [N, dim] or [N, T, dim] e = time_embedding_2->forward(ctx, e); // [N, dim] or [N, T, dim]
@ -1938,7 +1938,7 @@ namespace WAN {
c = vace_patch_embedding->forward(ctx, vace_context); // [N*dim, t_len, h_len, w_len] c = vace_patch_embedding->forward(ctx, vace_context); // [N*dim, t_len, h_len, w_len]
c = ggml_reshape_3d(ctx, c, c->ne[0] * c->ne[1] * c->ne[2], c->ne[3] / N, N); // [N, dim, t_len*h_len*w_len] c = ggml_reshape_3d(ctx, c, c->ne[0] * c->ne[1] * c->ne[2], c->ne[3] / N, N); // [N, dim, t_len*h_len*w_len]
c = ggml_nn_cont(ctx, ggml_torch_permute(ctx, c, 1, 0, 2, 3)); // [N, t_len*h_len*w_len, dim] c = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, c, 1, 0, 2, 3)); // [N, t_len*h_len*w_len, dim]
} }
auto x_orig = x; auto x_orig = x;
@ -2011,9 +2011,9 @@ namespace WAN {
// slice // slice
out = ggml_slice(ctx, out, 2, 0, T); // [N*C, T, H + pad_h, W + pad_w] out = ggml_ext_slice(ctx, out, 2, 0, T); // [N*C, T, H + pad_h, W + pad_w]
out = ggml_slice(ctx, out, 1, 0, H); // [N*C, T, H, W + pad_w] out = ggml_ext_slice(ctx, out, 1, 0, H); // [N*C, T, H, W + pad_w]
out = ggml_slice(ctx, out, 0, 0, W); // [N*C, T, H, W] out = ggml_ext_slice(ctx, out, 0, 0, W); // [N*C, T, H, W]
return out; return out;
} }