mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-13 05:48:56 +00:00
optimize ggml_ext_chunk: add an optional `cont` parameter (default `true`) so callers can skip the ggml_cont copies
This commit is contained in:
parent
11ab095230
commit
0835e5c227
@ -194,10 +194,12 @@ public:
|
||||
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
|
||||
|
||||
x = proj->forward(ctx, x); // [ne3, ne2, ne1, dim_out*2]
|
||||
auto x_vec = ggml_ext_chunk(ctx->ggml_ctx, x, 2, 0);
|
||||
auto x_vec = ggml_ext_chunk(ctx->ggml_ctx, x, 2, 0, false);
|
||||
x = x_vec[0]; // [ne3, ne2, ne1, dim_out]
|
||||
auto gate = x_vec[1]; // [ne3, ne2, ne1, dim_out]
|
||||
|
||||
gate = ggml_cont(ctx->ggml_ctx, gate);
|
||||
|
||||
gate = ggml_gelu_inplace(ctx->ggml_ctx, gate);
|
||||
|
||||
x = ggml_mul(ctx->ggml_ctx, x, gate); // [ne3, ne2, ne1, dim_out]
|
||||
|
||||
@ -732,7 +732,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_slice(struct ggml_context* ctx,
|
||||
__STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_ext_chunk(struct ggml_context* ctx,
|
||||
struct ggml_tensor* x,
|
||||
int num,
|
||||
int64_t dim) {
|
||||
int64_t dim,
|
||||
bool cont = true) {
|
||||
GGML_ASSERT(dim >= 0 && dim < 4);
|
||||
GGML_ASSERT(x->ne[dim] % num == 0);
|
||||
|
||||
@ -747,7 +748,9 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_ext_chunk(struct ggml_co
|
||||
|
||||
if (dim != 3) {
|
||||
x = ggml_ext_torch_permute(ctx, x, perm[0], perm[1], perm[2], perm[3]);
|
||||
x = ggml_cont(ctx, x);
|
||||
if (cont) {
|
||||
x = ggml_cont(ctx, x);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<struct ggml_tensor*> chunks;
|
||||
@ -760,7 +763,9 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_ext_chunk(struct ggml_co
|
||||
|
||||
if (dim != 3) {
|
||||
chunk = ggml_ext_torch_permute(ctx, chunk, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]);
|
||||
chunk = ggml_cont(ctx, chunk);
|
||||
if (cont) {
|
||||
chunk = ggml_cont(ctx, chunk);
|
||||
}
|
||||
}
|
||||
chunks.push_back(chunk);
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user