Flux: simplify when patch_size is 1

This commit is contained in:
Stéphane du Hamel 2025-12-06 16:05:58 +01:00
parent 203d0539fe
commit 37c5e3eca4

View File

@ -891,6 +891,11 @@ namespace Flux {
int64_t C = x->ne[2]; int64_t C = x->ne[2];
int64_t H = x->ne[1]; int64_t H = x->ne[1];
int64_t W = x->ne[0]; int64_t W = x->ne[0];
if (params.patch_size == 1) {
x = ggml_reshape_3d(ctx, x, H * W, C, N); // [N, C, H*W]
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, H*W, C]
return x;
}
int64_t p = params.patch_size; int64_t p = params.patch_size;
int64_t h = H / params.patch_size; int64_t h = H / params.patch_size;
int64_t w = W / params.patch_size; int64_t w = W / params.patch_size;
@ -925,6 +930,12 @@ namespace Flux {
int64_t W = w * params.patch_size; int64_t W = w * params.patch_size;
int64_t p = params.patch_size; int64_t p = params.patch_size;
if (params.patch_size == 1) {
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, C, H*W]
x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, H, W]
return x;
}
GGML_ASSERT(C * p * p == x->ne[0]); GGML_ASSERT(C * p * p == x->ne[0]);
x = ggml_reshape_4d(ctx, x, p * p, C, w * h, N); // [N, h*w, C, p*p] x = ggml_reshape_4d(ctx, x, p * p, C, w * h, N); // [N, h*w, C, p*p]