diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 110bbbc..86fefc4 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -942,6 +942,34 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> split_image_qkv(struct ggml_c
     return {q, k, v};
 }
 
+__STATIC_INLINE__ struct ggml_tensor* ggml_full(struct ggml_context* ctx,
+                                                float value,
+                                                int64_t ne0,
+                                                int64_t ne1,
+                                                int64_t ne2,
+                                                int64_t ne3) {
+    auto one = ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:one");
+    auto t = ggml_scale(ctx, one, value);               // [1,]
+    t = ggml_repeat_4d(ctx, t, ne0, ne1, ne2, ne3);     // [ne0, ne1, ne2, ne3]
+    return t;
+}
+
+__STATIC_INLINE__ struct ggml_tensor* ggml_zeros(struct ggml_context* ctx,
+                                                 int64_t ne0,
+                                                 int64_t ne1,
+                                                 int64_t ne2,
+                                                 int64_t ne3) {
+    return ggml_full(ctx, 0.f, ne0, ne1, ne2, ne3);
+}
+
+__STATIC_INLINE__ struct ggml_tensor* ggml_ones(struct ggml_context* ctx,
+                                                int64_t ne0,
+                                                int64_t ne1,
+                                                int64_t ne2,
+                                                int64_t ne3) {
+    return ggml_full(ctx, 1.f, ne0, ne1, ne2, ne3);
+}
+
 // q: [N * n_head, n_token, d_head]
 // k: [N * n_head, n_k, d_head]
 // v: [N * n_head, d_head, n_k]
@@ -969,6 +997,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx
 // q: [N, L_q, C] or [N*n_head, L_q, d_head]
 // k: [N, L_k, C] or [N*n_head, L_k, d_head]
 // v: [N, L_k, C] or [N, L_k, n_head, d_head]
+// mask: [N, L_q, L_k]
 // return: [N, L_q, C]
 __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* ctx,
                                                             struct ggml_tensor* q,
@@ -1019,7 +1048,6 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
 
     if (mask != nullptr) {
         // TODO(Green-Sky): figure out if we can bend t5 to work too
-        can_use_flash_attn = can_use_flash_attn && mask->ne[2] == 1;
         can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1;
     }
 
@@ -1046,14 +1074,25 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
 
         if (mask != nullptr) {
             mask = ggml_transpose(ctx, mask);
-
-            if (mask->ne[1] < GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD)) {
-                LOG_DEBUG("mask dims %ld, %ld, %ld, %ld\n", mask->ne[0], mask->ne[1], mask->ne[2], mask->ne[3]);
-                LOG_DEBUG("needs padding, padding from %ld to %ld\n", mask->ne[1], GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD));
-                mask = ggml_pad(ctx, mask, 0, GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) - mask->ne[1], 0, 0);
+        } else {
+            if (kv_pad > 0) {
+                mask = ggml_zeros(ctx, L_k, L_q, 1, 1);                          // [L_q, L_k]
+                auto pad_tensor = ggml_full(ctx, -INFINITY, kv_pad, L_q, 1, 1);  // [L_q, kv_pad]
+                mask = ggml_concat(ctx, mask, pad_tensor, 0);                    // [L_q, L_k + kv_pad]
             }
+        }
 
+        // mask pad
+        if (mask != nullptr) {
+            int mask_pad = 0;
+            if (mask->ne[1] % GGML_KQ_MASK_PAD != 0) {
+                mask_pad = GGML_PAD(L_q, GGML_KQ_MASK_PAD) - mask->ne[1];
+            }
+            if (mask_pad > 0) {
+                mask = ggml_pad(ctx, mask, 0, mask_pad, 0, 0);  // [L_q + mask_pad, L_k + kv_pad]
+            }
             mask = ggml_cast(ctx, mask, GGML_TYPE_F16);
+            // LOG_DEBUG("L_k: %ld, L_q: %ld, mask->ne[1]: %ld, mask_pad: %d, kv_pad: %d", L_k, L_q, mask->ne[1], mask_pad, kv_pad);
         }
 
         kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0);
@@ -1271,6 +1310,9 @@ protected:
     struct ggml_context* compute_ctx = NULL;
     struct ggml_gallocr* compute_allocr = NULL;
 
+    std::vector<float> one_vec = {1.f};
+    ggml_tensor* one_tensor = NULL;
+
     std::map<struct ggml_tensor*, const void*> backend_tensor_data_map;
 
     void alloc_params_ctx() {
@@ -1315,12 +1357,29 @@ protected:
         }
     }
 
+    void prepare_build_in_tensor_before() {
+        one_tensor = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, 1);
+        ggml_set_name(one_tensor, "ggml_runner_build_in_tensor:one");
+        set_backend_tensor_data(one_tensor, one_vec.data());
+    }
+
+    void prepare_build_in_tensor_after(struct ggml_cgraph* gf) {
+        ggml_build_forward_expand(gf, one_tensor);
+    }
+
+    struct ggml_cgraph* get_compute_graph(get_graph_cb_t get_graph) {
+        prepare_build_in_tensor_before();
+        struct ggml_cgraph* gf = get_graph();
+        prepare_build_in_tensor_after(gf);
+        return gf;
+    }
+
     bool alloc_compute_buffer(get_graph_cb_t get_graph) {
         if (compute_allocr != NULL) {
             return true;
         }
         reset_compute_ctx();
-        struct ggml_cgraph* gf = get_graph();
+        struct ggml_cgraph* gf = get_compute_graph(get_graph);
         backend_tensor_data_map.clear();
 
         compute_allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(runtime_backend));
@@ -1531,7 +1590,7 @@ public:
         }
         alloc_compute_buffer(get_graph);
         reset_compute_ctx();
-        struct ggml_cgraph* gf = get_graph();
+        struct ggml_cgraph* gf = get_compute_graph(get_graph);
         GGML_ASSERT(ggml_gallocr_alloc_graph(compute_allocr, gf));
         cpy_data_to_backend_tensor();
         if (ggml_backend_is_cpu(runtime_backend)) {
diff --git a/stable-diffusion.h b/stable-diffusion.h
index 7307c45..e4d2aa1 100644
--- a/stable-diffusion.h
+++ b/stable-diffusion.h
@@ -101,7 +101,7 @@ enum sd_type_t {
     // SD_TYPE_IQ4_NL_4_4 = 36,
     // SD_TYPE_IQ4_NL_4_8 = 37,
     // SD_TYPE_IQ4_NL_8_8 = 38,
-    SD_TYPE_MXFP4 = 39, // MXFP4 (1 block)
+    SD_TYPE_MXFP4 = 39,  // MXFP4 (1 block)
     SD_TYPE_COUNT = 40,
 };
 
diff --git a/vae.hpp b/vae.hpp
index 61dc2e9..408d32d 100644
--- a/vae.hpp
+++ b/vae.hpp
@@ -529,7 +529,7 @@ struct VAE : public GGMLRunner {
                         struct ggml_tensor** output,
                         struct ggml_context* output_ctx) = 0;
     virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) = 0;
-    virtual void enable_conv2d_direct() {};
+    virtual void enable_conv2d_direct(){};
 };
 
 struct AutoEncoderKL : public VAE {
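
Reviewer note (not part of the patch): below is a minimal sketch of how the new constant-tensor helpers are intended to be used from inside a GGMLRunner graph-build callback, i.e. after prepare_build_in_tensor_before() has registered the "ggml_runner_build_in_tensor:one" tensor that ggml_full() looks up. It mirrors the mask-construction path this patch adds to ggml_nn_attention_ext(); the standalone helper name build_padded_mask is a hypothetical illustration, not code from the patch.

// Sketch only: assumes ggml_extend.hpp from this repo and a graph context in
// which the build-in "one" tensor has already been created and named.
#include <math.h>  // INFINITY
#include "ggml_extend.hpp"

// Build the additive attention mask that the flash-attention path now
// synthesizes when the caller passes mask == nullptr: zeros over the real key
// positions, -INFINITY over the kv_pad padded key positions so softmax gives
// them zero weight.
static struct ggml_tensor* build_padded_mask(struct ggml_context* ctx,
                                             int64_t L_q,
                                             int64_t L_k,
                                             int64_t kv_pad) {
    struct ggml_tensor* mask = ggml_zeros(ctx, L_k, L_q, 1, 1);              // [L_q, L_k]
    if (kv_pad > 0) {
        auto pad_tensor = ggml_full(ctx, -INFINITY, kv_pad, L_q, 1, 1);      // [L_q, kv_pad]
        mask = ggml_concat(ctx, mask, pad_tensor, 0);                        // [L_q, L_k + kv_pad]
    }
    // As in the patch, the caller still pads ne[1] up to a multiple of
    // GGML_KQ_MASK_PAD and casts the mask to F16 before ggml_flash_attn_ext().
    return mask;
}

In the patch itself this logic lives directly in ggml_nn_attention_ext(), and it only works because get_compute_graph() now wraps every graph build with prepare_build_in_tensor_before()/prepare_build_in_tensor_after(), so ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:one") resolves during graph construction.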