diff --git a/flux.hpp b/flux.hpp index c500302..fc30987 100644 --- a/flux.hpp +++ b/flux.hpp @@ -891,6 +891,11 @@ namespace Flux { int64_t C = x->ne[2]; int64_t H = x->ne[1]; int64_t W = x->ne[0]; + if (params.patch_size == 1) { + x = ggml_reshape_3d(ctx, x, H * W, C, N); // [N, C, H*W] + x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, H*W, C] + return x; + } int64_t p = params.patch_size; int64_t h = H / params.patch_size; int64_t w = W / params.patch_size; @@ -925,6 +930,12 @@ namespace Flux { int64_t W = w * params.patch_size; int64_t p = params.patch_size; + if (params.patch_size == 1) { + x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, C, H*W] + x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, H, W] + return x; + } + GGML_ASSERT(C * p * p == x->ne[0]); x = ggml_reshape_4d(ctx, x, p * p, C, w * h, N); // [N, h*w, C, p*p]