mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-02-04 19:03:35 +00:00
Flux: simplify when patch_size is 1
This commit is contained in:
parent
203d0539fe
commit
37c5e3eca4
11
flux.hpp
11
flux.hpp
@ -891,6 +891,11 @@ namespace Flux {
|
||||
int64_t C = x->ne[2];
|
||||
int64_t H = x->ne[1];
|
||||
int64_t W = x->ne[0];
|
||||
if (params.patch_size == 1) {
|
||||
x = ggml_reshape_3d(ctx, x, H * W, C, N); // [N, C, H*W]
|
||||
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, H*W, C]
|
||||
return x;
|
||||
}
|
||||
int64_t p = params.patch_size;
|
||||
int64_t h = H / params.patch_size;
|
||||
int64_t w = W / params.patch_size;
|
||||
@ -925,6 +930,12 @@ namespace Flux {
|
||||
int64_t W = w * params.patch_size;
|
||||
int64_t p = params.patch_size;
|
||||
|
||||
if (params.patch_size == 1) {
|
||||
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, C, H*W]
|
||||
x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, H, W]
|
||||
return x;
|
||||
}
|
||||
|
||||
GGML_ASSERT(C * p * p == x->ne[0]);
|
||||
|
||||
x = ggml_reshape_4d(ctx, x, p * p, C, w * h, N); // [N, h*w, C, p*p]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user