mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-13 05:48:56 +00:00
fix ggml_nn_attention_ext mask
parent 48d4c1cd0b
commit aa5566f005
@@ -942,6 +942,34 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> split_image_qkv(struct ggml_c
     return {q, k, v};
 }
 
+__STATIC_INLINE__ struct ggml_tensor* ggml_full(struct ggml_context* ctx,
+                                                float value,
+                                                int64_t ne0,
+                                                int64_t ne1,
+                                                int64_t ne2,
+                                                int64_t ne3) {
+    auto one = ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:one");
+    auto t   = ggml_scale(ctx, one, value);                 // [1,]
+    t        = ggml_repeat_4d(ctx, t, ne0, ne1, ne2, ne3);  // [ne0, ne1, ne2, ne3]
+    return t;
+}
+
+__STATIC_INLINE__ struct ggml_tensor* ggml_zeros(struct ggml_context* ctx,
+                                                 int64_t ne0,
+                                                 int64_t ne1,
+                                                 int64_t ne2,
+                                                 int64_t ne3) {
+    return ggml_full(ctx, 0.f, ne0, ne1, ne2, ne3);
+}
+
+__STATIC_INLINE__ struct ggml_tensor* ggml_ones(struct ggml_context* ctx,
+                                                int64_t ne0,
+                                                int64_t ne1,
+                                                int64_t ne2,
+                                                int64_t ne3) {
+    return ggml_full(ctx, 1.f, ne0, ne1, ne2, ne3);
+}
+
 // q: [N * n_head, n_token, d_head]
 // k: [N * n_head, n_k, d_head]
 // v: [N * n_head, d_head, n_k]
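The new helpers synthesize constant tensors by scaling a named one-element leaf and repeating it, so they only work in a context where "ggml_runner_build_in_tensor:one" has already been created. Below is a minimal standalone sketch of the same pattern, not the runner's actual flow: CPU only, tensor data allocated inside the ggml context instead of by the runner, buffer size and thread count arbitrary, and a recent ggml that ships ggml_repeat_4d and ggml-cpu.h assumed.

#include "ggml.h"
#include "ggml-cpu.h"
#include <stdio.h>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,  // keep data in the context so we can set/read it directly
    };
    struct ggml_context* ctx = ggml_init(params);

    // GGMLRunner normally registers this named leaf; here we create it ourselves
    struct ggml_tensor* one = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
    ggml_set_name(one, "ggml_runner_build_in_tensor:one");
    ggml_set_f32(one, 1.0f);

    // equivalent of ggml_full(ctx, 3.5f, 4, 2, 1, 1)
    struct ggml_tensor* t = ggml_scale(ctx, ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:one"), 3.5f);
    t = ggml_repeat_4d(ctx, t, 4, 2, 1, 1);  // [2, 4] filled with 3.5f

    struct ggml_cgraph* gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, t);
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    printf("t[0] = %f\n", ggml_get_f32_1d(t, 0));  // 3.500000

    ggml_free(ctx);
    return 0;
}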
@@ -969,6 +997,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx
 // q: [N, L_q, C] or [N*n_head, L_q, d_head]
 // k: [N, L_k, C] or [N*n_head, L_k, d_head]
 // v: [N, L_k, C] or [N, L_k, n_head, d_head]
+// mask: [N, L_q, L_k]
 // return: [N, L_q, C]
 __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* ctx,
                                                             struct ggml_tensor* q,
@@ -1019,7 +1048,6 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
 
     if (mask != nullptr) {
         // TODO(Green-Sky): figure out if we can bend t5 to work too
-        can_use_flash_attn = can_use_flash_attn && mask->ne[2] == 1;
         can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1;
     }
 
@@ -1046,14 +1074,25 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
 
         if (mask != nullptr) {
             mask = ggml_transpose(ctx, mask);
-
-            if (mask->ne[1] < GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD)) {
-                LOG_DEBUG("mask dims %ld, %ld, %ld, %ld\n", mask->ne[0], mask->ne[1], mask->ne[2], mask->ne[3]);
-                LOG_DEBUG("needs padding, padding from %ld to %ld\n", mask->ne[1], GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD));
-                mask = ggml_pad(ctx, mask, 0, GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) - mask->ne[1], 0, 0);
-            }
-
-            mask = ggml_cast(ctx, mask, GGML_TYPE_F16);
+        } else {
+            if (kv_pad > 0) {
+                mask            = ggml_zeros(ctx, L_k, L_q, 1, 1);               // [L_q, L_k]
+                auto pad_tensor = ggml_full(ctx, -INFINITY, kv_pad, L_q, 1, 1);  // [L_q, kv_pad]
+                mask            = ggml_concat(ctx, mask, pad_tensor, 0);         // [L_q, L_k + kv_pad]
+            }
+        }
+
+        // mask pad
+        if (mask != nullptr) {
+            int mask_pad = 0;
+            if (mask->ne[1] % GGML_KQ_MASK_PAD != 0) {
+                mask_pad = GGML_PAD(L_q, GGML_KQ_MASK_PAD) - mask->ne[1];
+            }
+            if (mask_pad > 0) {
+                mask = ggml_pad(ctx, mask, 0, mask_pad, 0, 0);  // [L_q + mask_pad, L_k + kv_pad]
+            }
+            mask = ggml_cast(ctx, mask, GGML_TYPE_F16);
+            // LOG_DEBUG("L_k: %ld, L_q: %ld, mask->ne[1]: %ld, mask_pad: %d, kv_pad: %d", L_k, L_q, mask->ne[1], mask_pad, kv_pad);
         }
 
         kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0);
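In words: if the caller supplied no mask but K/V were padded by kv_pad, a mask is now synthesized so the padded key positions score -INFINITY; afterwards every mask, supplied or synthesized, has its row count padded up to a multiple of GGML_KQ_MASK_PAD and is cast to F16 before ggml_flash_attn_ext. A small standalone sketch of the resulting shape arithmetic follows; the sequence lengths, kv_pad, and the GGML_KQ_MASK_PAD stand-in value are made up for illustration.

#include <stdint.h>
#include <stdio.h>

// mirrors ggml's GGML_PAD macro: round x up to a multiple of n
static int64_t pad_to(int64_t x, int64_t n) { return ((x + n - 1) / n) * n; }

int main() {
    const int64_t kq_mask_pad = 32;          // stand-in for GGML_KQ_MASK_PAD; use the project's actual constant
    int64_t L_q = 77, L_k = 77, kv_pad = 3;  // hypothetical lengths

    // no caller mask, kv_pad > 0: zeros [L_q, L_k] concatenated with -inf [L_q, kv_pad] along dim 0
    int64_t ne0 = L_k + kv_pad;  // 80 key positions per query row
    int64_t ne1 = L_q;           // 77 query rows

    // "mask pad": round the row count up to a multiple of GGML_KQ_MASK_PAD
    if (ne1 % kq_mask_pad != 0) {
        ne1 += pad_to(L_q, kq_mask_pad) - ne1;  // 77 -> 96
    }

    printf("mask ne = [%lld, %lld], then cast to F16\n", (long long)ne0, (long long)ne1);
    return 0;
}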
@@ -1271,6 +1310,9 @@ protected:
     struct ggml_context* compute_ctx    = NULL;
     struct ggml_gallocr* compute_allocr = NULL;
 
+    std::vector<float> one_vec = {1.f};
+    ggml_tensor* one_tensor    = NULL;
+
     std::map<struct ggml_tensor*, const void*> backend_tensor_data_map;
 
     void alloc_params_ctx() {
@@ -1315,12 +1357,29 @@ protected:
         }
     }
 
+    void prepare_build_in_tensor_before() {
+        one_tensor = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, 1);
+        ggml_set_name(one_tensor, "ggml_runner_build_in_tensor:one");
+        set_backend_tensor_data(one_tensor, one_vec.data());
+    }
+
+    void prepare_build_in_tensor_after(struct ggml_cgraph* gf) {
+        ggml_build_forward_expand(gf, one_tensor);
+    }
+
+    struct ggml_cgraph* get_compute_graph(get_graph_cb_t get_graph) {
+        prepare_build_in_tensor_before();
+        struct ggml_cgraph* gf = get_graph();
+        prepare_build_in_tensor_after(gf);
+        return gf;
+    }
+
     bool alloc_compute_buffer(get_graph_cb_t get_graph) {
         if (compute_allocr != NULL) {
             return true;
         }
         reset_compute_ctx();
-        struct ggml_cgraph* gf = get_graph();
+        struct ggml_cgraph* gf = get_compute_graph(get_graph);
         backend_tensor_data_map.clear();
         compute_allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(runtime_backend));
 
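The ordering here matters: the named leaf has to exist, and be registered for a deferred data copy, before the graph callback runs, because ggml_full() resolves it with ggml_get_tensor(); expanding it into the graph afterwards guarantees the allocator assigns it a buffer even when the callback never used it. A standalone sketch of the same lifecycle with a bare CPU backend follows, without GGMLRunner; buffer sizes and shapes are arbitrary and a recent ggml with the split ggml-cpu.h header is assumed.

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-cpu.h"
#include <vector>

int main() {
    std::vector<float> one_vec = {1.f};

    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead() * 1024 + ggml_graph_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,  // data lives in backend buffers, not in the context
    };
    struct ggml_context* ctx = ggml_init(params);
    ggml_backend_t backend   = ggml_backend_cpu_init();

    // prepare_build_in_tensor_before(): create and name the builtin leaf
    struct ggml_tensor* one = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
    ggml_set_name(one, "ggml_runner_build_in_tensor:one");

    // the get_graph() callback: a ggml_full()-style op that looks the leaf up by name
    struct ggml_tensor* t = ggml_scale(ctx, ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:one"), 0.f);
    t = ggml_repeat_4d(ctx, t, 8, 4, 1, 1);  // [4, 8] of zeros
    struct ggml_cgraph* gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, t);

    // prepare_build_in_tensor_after(): keep the leaf in the graph even if no op consumed it
    ggml_build_forward_expand(gf, one);

    // alloc_compute_buffer() / compute(): allocate, copy the deferred data, run
    ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
    ggml_gallocr_alloc_graph(allocr, gf);
    ggml_backend_tensor_set(one, one_vec.data(), 0, sizeof(float));  // cpy_data_to_backend_tensor() equivalent
    ggml_backend_graph_compute(backend, gf);

    ggml_gallocr_free(allocr);
    ggml_backend_free(backend);
    ggml_free(ctx);
    return 0;
}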
@@ -1531,7 +1590,7 @@ public:
         }
         alloc_compute_buffer(get_graph);
         reset_compute_ctx();
-        struct ggml_cgraph* gf = get_graph();
+        struct ggml_cgraph* gf = get_compute_graph(get_graph);
         GGML_ASSERT(ggml_gallocr_alloc_graph(compute_allocr, gf));
         cpy_data_to_backend_tensor();
         if (ggml_backend_is_cpu(runtime_backend)) {
vae.hpp
@@ -529,7 +529,7 @@ struct VAE : public GGMLRunner {
                         struct ggml_tensor** output,
                         struct ggml_context* output_ctx) = 0;
     virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) = 0;
-    virtual void enable_conv2d_direct() {};
+    virtual void enable_conv2d_direct(){};
 };
 
 struct AutoEncoderKL : public VAE {