mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-03-24 02:08:51 +00:00
refactor: reuse DiT's patchify/unpatchify functions (#1304)
This commit is contained in:
parent
cec4aedcfd
commit
e64baa3611
102
src/anima.hpp
102
src/anima.hpp
@ -6,81 +6,13 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "common.hpp"
|
#include "common_block.hpp"
|
||||||
#include "flux.hpp"
|
#include "flux.hpp"
|
||||||
#include "ggml_extend.hpp"
|
|
||||||
#include "rope.hpp"
|
#include "rope.hpp"
|
||||||
|
|
||||||
namespace Anima {
|
namespace Anima {
|
||||||
constexpr int ANIMA_GRAPH_SIZE = 65536;
|
constexpr int ANIMA_GRAPH_SIZE = 65536;
|
||||||
|
|
||||||
// Folds p x p spatial patches of a single-frame tensor into the last axis.
// x:      [W*r, H*q, T, C] (layout as annotated by the original author); only T == 1 is accepted
// return: [W, H, T, C*q*r]; x is returned untouched when patch_size is 1
__STATIC_INLINE__ struct ggml_tensor* patchify_2d(struct ggml_context* ctx,
                                                  struct ggml_tensor* x,
                                                  int64_t patch_size) {
    if (patch_size == 1) {
        return x;  // a 1x1 patch is the identity
    }
    GGML_ASSERT(x->ne[2] == 1);

    const int64_t W = x->ne[0];
    const int64_t H = x->ne[1];
    const int64_t T = x->ne[2];
    const int64_t C = x->ne[3];
    const int64_t p = patch_size;
    const int64_t h = H / p;  // patch rows
    const int64_t w = W / p;  // patch columns

    GGML_ASSERT(T == 1);
    GGML_ASSERT(h * p == H && w * p == W);

    // Re-label the axes as [W, H, C, N] so the Flux patchify recipe can be reused.
    x = ggml_reshape_4d(ctx, x, W, H, C, T);

    // Flux patchify: [N, C, H, W] -> [N, h*w, C*p*p]
    x = ggml_reshape_4d(ctx, x, p, w, p, h * C * T);  // [p, w, p, h*C*N]
    struct ggml_tensor* patch_rows_first = ggml_permute(ctx, x, 0, 2, 1, 3);
    x = ggml_ext_cont(ctx, patch_rows_first);         // [p, p, w, h*C*N]
    x = ggml_reshape_4d(ctx, x, p * p, w * h, C, T);  // [p*p, h*w, C, N]
    struct ggml_tensor* channel_inner = ggml_permute(ctx, x, 0, 2, 1, 3);
    x = ggml_ext_cont(ctx, channel_inner);            // [p*p, C, h*w, N]
    x = ggml_reshape_3d(ctx, x, p * p * C, w * h, T); // [C*p*p, h*w, N]

    // Move the flattened patch axis to the end: result is [w, h, T, C*p*p].
    x = ggml_reshape_4d(ctx, x, p * p * C, w, h, T);  // [C*p*p, w, h, N]
    struct ggml_tensor* patch_axis_last = ggml_permute(ctx, x, 3, 0, 1, 2);
    x = ggml_ext_cont(ctx, patch_axis_last);          // [w, h, N, C*p*p]
    return x;
}
|
|
||||||
|
|
||||||
// Expands the last axis of a single-frame tensor back into p x p spatial patches.
// x:      [W, H, T, C*q*r] (layout as annotated by the original author); only T == 1 is accepted
// return: [W*r, H*q, T, C]; x is returned untouched when patch_size is 1
__STATIC_INLINE__ struct ggml_tensor* unpatchify_2d(struct ggml_context* ctx,
                                                    struct ggml_tensor* x,
                                                    int64_t patch_size) {
    if (patch_size == 1) {
        return x;  // a 1x1 patch is the identity
    }
    GGML_ASSERT(x->ne[2] == 1);

    const int64_t w  = x->ne[0];
    const int64_t h  = x->ne[1];
    const int64_t T  = x->ne[2];
    const int64_t Cp = x->ne[3];  // channels times patch area
    const int64_t p  = patch_size;
    const int64_t nm = p * p;     // elements per patch
    const int64_t C  = Cp / nm;
    const int64_t W  = w * p;
    const int64_t H  = h * p;

    GGML_ASSERT(T == 1);
    GGML_ASSERT(C * nm == Cp);  // the last axis must split evenly into C patches

    // [w, h, 1, C*p*p] -> [W, H, 1, C]
    x = ggml_reshape_4d(ctx, x, w, h * C, p, p);  // [w, h*C, p2, p1]
    struct ggml_tensor* interleaved = ggml_ext_torch_permute(ctx, x, 2, 0, 3, 1);
    x = ggml_ext_cont(ctx, interleaved);          // [p2, w, p1, h*C]
    return ggml_reshape_4d(ctx, x, W, H, T, C);   // [W, H, 1, C]
}
|
|
||||||
|
|
||||||
__STATIC_INLINE__ struct ggml_tensor* apply_gate(struct ggml_context* ctx,
|
__STATIC_INLINE__ struct ggml_tensor* apply_gate(struct ggml_context* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* gate) {
|
struct ggml_tensor* gate) {
|
||||||
@ -491,7 +423,7 @@ namespace Anima {
|
|||||||
int64_t text_embed_dim = 1024;
|
int64_t text_embed_dim = 1024;
|
||||||
int64_t num_heads = 16;
|
int64_t num_heads = 16;
|
||||||
int64_t head_dim = 128;
|
int64_t head_dim = 128;
|
||||||
int64_t patch_size = 2;
|
int patch_size = 2;
|
||||||
int64_t num_layers = 28;
|
int64_t num_layers = 28;
|
||||||
std::vector<int> axes_dim = {44, 42, 42};
|
std::vector<int> axes_dim = {44, 42, 42};
|
||||||
int theta = 10000;
|
int theta = 10000;
|
||||||
@ -533,24 +465,10 @@ namespace Anima {
|
|||||||
int64_t W = x->ne[0];
|
int64_t W = x->ne[0];
|
||||||
int64_t H = x->ne[1];
|
int64_t H = x->ne[1];
|
||||||
|
|
||||||
x = ggml_reshape_4d(ctx->ggml_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]); // [N*C, T, H, W] style
|
auto padding_mask = ggml_ext_zeros(ctx->ggml_ctx, x->ne[0], x->ne[1], 1, x->ne[3]);
|
||||||
|
x = ggml_concat(ctx->ggml_ctx, x, padding_mask, 2); // [N, C + 1, H, W]
|
||||||
|
|
||||||
int64_t pad_h = (patch_size - H % patch_size) % patch_size;
|
x = DiT::pad_and_patchify(ctx, x, patch_size, patch_size); // [N, h*w, (C+1)*ph*pw]
|
||||||
int64_t pad_w = (patch_size - W % patch_size) % patch_size;
|
|
||||||
if (pad_h > 0 || pad_w > 0) {
|
|
||||||
x = ggml_ext_pad(ctx->ggml_ctx, x, static_cast<int>(pad_w), static_cast<int>(pad_h), 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto padding_mask = ggml_ext_zeros(ctx->ggml_ctx, x->ne[0], x->ne[1], x->ne[2], 1);
|
|
||||||
x = ggml_concat(ctx->ggml_ctx, x, padding_mask, 3); // concat mask channel
|
|
||||||
|
|
||||||
x = patchify_2d(ctx->ggml_ctx, x, patch_size); // [C*4, T, H/2, W/2]
|
|
||||||
|
|
||||||
int64_t w_len = x->ne[0];
|
|
||||||
int64_t h_len = x->ne[1];
|
|
||||||
int64_t t_len = x->ne[2];
|
|
||||||
x = ggml_reshape_3d(ctx->ggml_ctx, x, x->ne[0] * x->ne[1] * x->ne[2], x->ne[3], 1);
|
|
||||||
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [N, n_token, C]
|
|
||||||
|
|
||||||
x = x_embedder->forward(ctx, x);
|
x = x_embedder->forward(ctx, x);
|
||||||
|
|
||||||
@ -586,15 +504,9 @@ namespace Anima {
|
|||||||
x = block->forward(ctx, x, encoder_hidden_states, embedded_timestep, temb, image_pe);
|
x = block->forward(ctx, x, encoder_hidden_states, embedded_timestep, temb, image_pe);
|
||||||
}
|
}
|
||||||
|
|
||||||
x = final_layer->forward(ctx, x, embedded_timestep, temb); // [N, n_token, C*4]
|
x = final_layer->forward(ctx, x, embedded_timestep, temb); // [N, h*w, ph*pw*C]
|
||||||
|
|
||||||
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [n_token, C*4, N]
|
x = DiT::unpatchify_and_crop(ctx->ggml_ctx, x, H, W, patch_size, patch_size, false); // [N, C, H, W]
|
||||||
x = ggml_reshape_4d(ctx->ggml_ctx, x, w_len, h_len, t_len, x->ne[1]); // [C*4, T, H/2, W/2]
|
|
||||||
x = unpatchify_2d(ctx->ggml_ctx, x, patch_size); // [C, T, H, W]
|
|
||||||
|
|
||||||
x = ggml_ext_slice(ctx->ggml_ctx, x, 1, 0, H); // [C, T, H, W + pad]
|
|
||||||
x = ggml_ext_slice(ctx->ggml_ctx, x, 0, 0, W); // [C, T, H, W]
|
|
||||||
x = ggml_reshape_4d(ctx->ggml_ctx, x, x->ne[0], x->ne[1], x->ne[3], x->ne[2]); // [N, C, H, W]
|
|
||||||
|
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
#ifndef __COMMON_HPP__
|
#ifndef __COMMON_BLOCK_HPP__
|
||||||
#define __COMMON_HPP__
|
#define __COMMON_BLOCK_HPP__
|
||||||
|
|
||||||
#include "ggml_extend.hpp"
|
#include "ggml_extend.hpp"
|
||||||
|
|
||||||
@ -590,4 +590,4 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // __COMMON_HPP__
|
#endif // __COMMON_BLOCK_HPP__
|
||||||
108
src/common_dit.hpp
Normal file
108
src/common_dit.hpp
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
#ifndef __COMMON_DIT_HPP__
|
||||||
|
#define __COMMON_DIT_HPP__
|
||||||
|
|
||||||
|
#include "ggml_extend.hpp"
|
||||||
|
|
||||||
|
namespace DiT {
|
||||||
|
// Splits an image-like tensor into non-overlapping ph x pw patches and flattens
// every patch into one token.
//
// ctx:        ggml context the new graph nodes are created in
// x:          [N, C, H, W] in torch order (ggml ne = [W, H, C, N]); H must be a
//             multiple of ph and W a multiple of pw (pad first if needed)
// pw, ph:     patch width and patch height -- note that width comes first here
// patch_last: true  -> token layout (C, ph, pw): returns [N, h*w, C*ph*pw]
//             false -> token layout (ph, pw, C): returns [N, h*w, ph*pw*C]
// return:     the token tensor described above, with h = H / ph and w = W / pw
//
// Declared inline because this definition lives in a header; without it every
// translation unit that includes the header would emit its own external definition.
inline ggml_tensor* patchify(ggml_context* ctx,
                             ggml_tensor* x,
                             int pw,
                             int ph,
                             bool patch_last = true) {
    // Patch sizes are used as divisors below; reject zero/negative values up front.
    GGML_ASSERT(pw > 0 && ph > 0);

    int64_t N = x->ne[3];
    int64_t C = x->ne[2];
    int64_t H = x->ne[1];
    int64_t W = x->ne[0];
    int64_t h = H / ph;
    int64_t w = W / pw;

    GGML_ASSERT(h * ph == H && w * pw == W);

    x = ggml_reshape_4d(ctx, x, pw, w, ph, h * C * N);     // [N*C*h, ph, w, pw]
    x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));  // [N*C*h, w, ph, pw]
    x = ggml_reshape_4d(ctx, x, pw * ph, w * h, C, N);     // [N, C, h*w, ph*pw]
    if (patch_last) {
        x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));  // [N, h*w, C, ph*pw]
        x = ggml_reshape_3d(ctx, x, pw * ph * C, w * h, N);    // [N, h*w, C*ph*pw]
    } else {
        // Channel becomes the fastest-varying axis inside each token.
        x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3));  // [N, h*w, ph*pw, C]
        x = ggml_reshape_3d(ctx, x, C * pw * ph, w * h, N);              // [N, h*w, ph*pw*C]
    }
    return x;
}
|
||||||
|
|
||||||
|
// Inverse of patchify: reassembles a token sequence into an image-like tensor.
//
// ctx:        ggml context the new graph nodes are created in
// x:          [N, h*w, C*ph*pw] if patch_last, else [N, h*w, ph*pw*C] (torch order)
// h, w:       number of patch rows / patch columns in the token grid
// ph, pw:     patch height and patch width -- note that height comes first here
// patch_last: must match the layout that produced x (see the x description)
// return:     [N, C, H, W] with H = h*ph and W = w*pw
//
// Declared inline because this definition lives in a header; without it every
// translation unit that includes the header would emit its own external definition.
inline ggml_tensor* unpatchify(ggml_context* ctx,
                               ggml_tensor* x,
                               int64_t h,
                               int64_t w,
                               int ph,
                               int pw,
                               bool patch_last = true) {
    // Patch sizes are used as divisors below; reject zero/negative values up front.
    GGML_ASSERT(ph > 0 && pw > 0);

    int64_t N = x->ne[2];
    int64_t C = x->ne[0] / ph / pw;
    int64_t H = h * ph;
    int64_t W = w * pw;

    GGML_ASSERT(C * ph * pw == x->ne[0]);
    // Fail with a readable message when the caller's grid does not match the token count.
    GGML_ASSERT(h * w == x->ne[1]);

    if (patch_last) {
        x = ggml_reshape_4d(ctx, x, pw * ph, C, w * h, N);     // [N, h*w, C, ph*pw]
        x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));  // [N, C, h*w, ph*pw]
    } else {
        x = ggml_reshape_4d(ctx, x, C, pw * ph, w * h, N);     // [N, h*w, ph*pw, C]
        x = ggml_cont(ctx, ggml_permute(ctx, x, 2, 0, 1, 3));  // [N, C, h*w, ph*pw]
    }

    // Both layouts are now [N, C, h*w, ph*pw]; interleave patch rows back into image rows.
    x = ggml_reshape_4d(ctx, x, pw, ph, w, h * C * N);     // [N*C*h, w, ph, pw]
    x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));  // [N*C*h, ph, w, pw]
    x = ggml_reshape_4d(ctx, x, W, H, C, N);               // [N, C, h*ph, w*pw]

    return x;
}
|
||||||
|
|
||||||
|
// Pads the two innermost axes of x so that they become multiples of the patch size.
//
// ctx:    runner context; supplies the ggml context and the circular-padding flags
// x:      tensor whose ne[0] is treated as width W and ne[1] as height H
// ph, pw: patch height and patch width
// return: the result of ggml_ext_pad with pad_w extra columns and pad_h extra rows
//         (both may be 0; the call is still issued, as in the original code)
//
// Declared inline because this definition lives in a header; without it every
// translation unit that includes the header would emit its own external definition.
inline ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
                                      ggml_tensor* x,
                                      int ph,
                                      int pw) {
    // Patch sizes are used as modulo divisors below; reject zero/negative values up front.
    GGML_ASSERT(ph > 0 && pw > 0);

    const int64_t W = x->ne[0];
    const int64_t H = x->ne[1];

    // Distance to the next multiple; the outer % turns "already aligned" into 0
    // instead of a full extra patch. The result is < ph / < pw, so it fits in int.
    const int pad_h = static_cast<int>((ph - H % ph) % ph);
    const int pad_w = static_cast<int>((pw - W % pw) % pw);

    // NOTE(review): the two literal zeros are presumably the padding for the
    // remaining axes -- confirm against ggml_ext_pad's declaration.
    x = ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
    return x;
}
|
||||||
|
|
||||||
|
// Convenience wrapper: pads x up to a multiple of the patch size, then patchifies it.
//
// ctx:        runner context; supplies the ggml context and the padding flags
// x:          [N, C, H, W] in torch order
// ph, pw:     patch height and patch width
// patch_last: forwarded to patchify (selects the per-token layout)
// return:     whatever patchify returns for the padded tensor
//
// Declared inline because this definition lives in a header; without it every
// translation unit that includes the header would emit its own external definition.
inline ggml_tensor* pad_and_patchify(GGMLRunnerContext* ctx,
                                     ggml_tensor* x,
                                     int ph,
                                     int pw,
                                     bool patch_last = true) {
    x = pad_to_patch_size(ctx, x, ph, pw);
    // patchify takes the patch WIDTH first and the height second, the opposite of
    // this function's (ph, pw) order. Passing them straight through swapped the
    // axes for non-square patches; square patches were unaffected.
    x = patchify(ctx->ggml_ctx, x, pw, ph, patch_last);
    return x;
}
|
||||||
|
|
||||||
|
// Reassembles tokens into an image-like tensor and removes the padding that was
// added to reach a multiple of the patch size.
//
// ctx:        ggml context the new graph nodes are created in
// x:          token tensor in the layout expected by unpatchify
// H, W:       height and width of the ORIGINAL (unpadded) image
// ph, pw:     patch height and patch width
// patch_last: forwarded to unpatchify (selects the per-token layout)
// return:     [N, C, H, W]
//
// Declared inline because this definition lives in a header; without it every
// translation unit that includes the header would emit its own external definition.
inline ggml_tensor* unpatchify_and_crop(ggml_context* ctx,
                                        ggml_tensor* x,
                                        int64_t H,
                                        int64_t W,
                                        int ph,
                                        int pw,
                                        bool patch_last = true) {
    // Patch sizes are used as divisors below; reject zero/negative values up front.
    GGML_ASSERT(ph > 0 && pw > 0);

    // Recompute the padding the same way pad_to_patch_size does, so the patch grid
    // matches the one the tokens were produced from. Results are < ph / < pw.
    const int pad_h = static_cast<int>((ph - H % ph) % ph);
    const int pad_w = static_cast<int>((pw - W % pw) % pw);
    const int64_t h = (H + pad_h) / ph;
    const int64_t w = (W + pad_w) / pw;

    x = unpatchify(ctx, x, h, w, ph, pw, patch_last);  // [N, C, H + pad_h, W + pad_w]
    x = ggml_ext_slice(ctx, x, 1, 0, H);               // [N, C, H, W + pad_w]
    x = ggml_ext_slice(ctx, x, 0, 0, W);               // [N, C, H, W]
    return x;
}
|
||||||
|
} // namespace DiT
|
||||||
|
|
||||||
|
#endif // __COMMON_DIT_HPP__
|
||||||
@ -1,8 +1,7 @@
|
|||||||
#ifndef __CONTROL_HPP__
|
#ifndef __CONTROL_HPP__
|
||||||
#define __CONTROL_HPP__
|
#define __CONTROL_HPP__
|
||||||
|
|
||||||
#include "common.hpp"
|
#include "common_block.hpp"
|
||||||
#include "ggml_extend.hpp"
|
|
||||||
#include "model.h"
|
#include "model.h"
|
||||||
|
|
||||||
#define CONTROL_NET_GRAPH_SIZE 1536
|
#define CONTROL_NET_GRAPH_SIZE 1536
|
||||||
|
|||||||
91
src/flux.hpp
91
src/flux.hpp
@ -4,7 +4,7 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "ggml_extend.hpp"
|
#include "common_dit.hpp"
|
||||||
#include "model.h"
|
#include "model.h"
|
||||||
#include "rope.hpp"
|
#include "rope.hpp"
|
||||||
|
|
||||||
@ -846,70 +846,6 @@ namespace Flux {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pads width (ne[0]) and height (ne[1]) of x up to the next multiple of params.patch_size.
struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
                                      struct ggml_tensor* x) {
    const auto patch       = params.patch_size;
    const int64_t width    = x->ne[0];
    const int64_t height   = x->ne[1];

    // The outer modulo maps an already-aligned size to 0 rather than a full patch.
    const int extra_rows = (patch - height % patch) % patch;
    const int extra_cols = (patch - width % patch) % patch;

    return ggml_ext_pad(ctx->ggml_ctx, x, extra_cols, extra_rows, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
}
|
|
||||||
|
|
||||||
struct ggml_tensor* patchify(struct ggml_context* ctx,
|
|
||||||
struct ggml_tensor* x) {
|
|
||||||
// x: [N, C, H, W]
|
|
||||||
// return: [N, h*w, C * patch_size * patch_size]
|
|
||||||
int64_t N = x->ne[3];
|
|
||||||
int64_t C = x->ne[2];
|
|
||||||
int64_t H = x->ne[1];
|
|
||||||
int64_t W = x->ne[0];
|
|
||||||
int64_t p = params.patch_size;
|
|
||||||
int64_t h = H / params.patch_size;
|
|
||||||
int64_t w = W / params.patch_size;
|
|
||||||
|
|
||||||
GGML_ASSERT(h * p == H && w * p == W);
|
|
||||||
|
|
||||||
x = ggml_reshape_4d(ctx, x, p, w, p, h * C * N); // [N*C*h, p, w, p]
|
|
||||||
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, w, p, p]
|
|
||||||
x = ggml_reshape_4d(ctx, x, p * p, w * h, C, N); // [N, C, h*w, p*p]
|
|
||||||
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, h*w, C, p*p]
|
|
||||||
x = ggml_reshape_3d(ctx, x, p * p * C, w * h, N); // [N, h*w, C*p*p]
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pads x to a whole number of patches and converts it to a token sequence.
// Equivalent einops form:
//   rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
struct ggml_tensor* process_img(GGMLRunnerContext* ctx,
                                struct ggml_tensor* x) {
    struct ggml_tensor* padded = pad_to_patch_size(ctx, x);
    return patchify(ctx->ggml_ctx, padded);
}
|
|
||||||
|
|
||||||
struct ggml_tensor* unpatchify(struct ggml_context* ctx,
|
|
||||||
struct ggml_tensor* x,
|
|
||||||
int64_t h,
|
|
||||||
int64_t w) {
|
|
||||||
// x: [N, h*w, C*patch_size*patch_size]
|
|
||||||
// return: [N, C, H, W]
|
|
||||||
int64_t N = x->ne[2];
|
|
||||||
int64_t C = x->ne[0] / params.patch_size / params.patch_size;
|
|
||||||
int64_t H = h * params.patch_size;
|
|
||||||
int64_t W = w * params.patch_size;
|
|
||||||
int64_t p = params.patch_size;
|
|
||||||
|
|
||||||
GGML_ASSERT(C * p * p == x->ne[0]);
|
|
||||||
|
|
||||||
x = ggml_reshape_4d(ctx, x, p * p, C, w * h, N); // [N, h*w, C, p*p]
|
|
||||||
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, C, h*w, p*p]
|
|
||||||
x = ggml_reshape_4d(ctx, x, p, p, w, h * C * N); // [N*C*h, w, p, p]
|
|
||||||
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, p, w, p]
|
|
||||||
x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, h*p, w*p]
|
|
||||||
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_tensor* forward_orig(GGMLRunnerContext* ctx,
|
struct ggml_tensor* forward_orig(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* img,
|
struct ggml_tensor* img,
|
||||||
struct ggml_tensor* txt,
|
struct ggml_tensor* txt,
|
||||||
@ -1060,7 +996,7 @@ namespace Flux {
|
|||||||
int pad_h = (patch_size - H % patch_size) % patch_size;
|
int pad_h = (patch_size - H % patch_size) % patch_size;
|
||||||
int pad_w = (patch_size - W % patch_size) % patch_size;
|
int pad_w = (patch_size - W % patch_size) % patch_size;
|
||||||
|
|
||||||
auto img = pad_to_patch_size(ctx, x);
|
auto img = DiT::pad_to_patch_size(ctx, x, params.patch_size, params.patch_size);
|
||||||
auto orig_img = img;
|
auto orig_img = img;
|
||||||
|
|
||||||
if (params.chroma_radiance_params.fake_patch_size_x2) {
|
if (params.chroma_radiance_params.fake_patch_size_x2) {
|
||||||
@ -1082,7 +1018,7 @@ namespace Flux {
|
|||||||
auto nerf_image_embedder = std::dynamic_pointer_cast<NerfEmbedder>(blocks["nerf_image_embedder"]);
|
auto nerf_image_embedder = std::dynamic_pointer_cast<NerfEmbedder>(blocks["nerf_image_embedder"]);
|
||||||
auto nerf_final_layer_conv = std::dynamic_pointer_cast<NerfFinalLayerConv>(blocks["nerf_final_layer_conv"]);
|
auto nerf_final_layer_conv = std::dynamic_pointer_cast<NerfFinalLayerConv>(blocks["nerf_final_layer_conv"]);
|
||||||
|
|
||||||
auto nerf_pixels = patchify(ctx->ggml_ctx, orig_img); // [N, num_patches, C * patch_size * patch_size]
|
auto nerf_pixels = DiT::patchify(ctx->ggml_ctx, orig_img, patch_size, patch_size); // [N, num_patches, C * patch_size * patch_size]
|
||||||
int64_t num_patches = nerf_pixels->ne[1];
|
int64_t num_patches = nerf_pixels->ne[1];
|
||||||
nerf_pixels = ggml_reshape_3d(ctx->ggml_ctx,
|
nerf_pixels = ggml_reshape_3d(ctx->ggml_ctx,
|
||||||
nerf_pixels,
|
nerf_pixels,
|
||||||
@ -1102,7 +1038,7 @@ namespace Flux {
|
|||||||
|
|
||||||
img_dct = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, img_dct, 1, 0, 2, 3)); // [N*num_patches, nerf_hidden_size, patch_size*patch_size]
|
img_dct = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, img_dct, 1, 0, 2, 3)); // [N*num_patches, nerf_hidden_size, patch_size*patch_size]
|
||||||
img_dct = ggml_reshape_3d(ctx->ggml_ctx, img_dct, img_dct->ne[0] * img_dct->ne[1], num_patches, img_dct->ne[2] / num_patches); // [N, num_patches, nerf_hidden_size*patch_size*patch_size]
|
img_dct = ggml_reshape_3d(ctx->ggml_ctx, img_dct, img_dct->ne[0] * img_dct->ne[1], num_patches, img_dct->ne[2] / num_patches); // [N, num_patches, nerf_hidden_size*patch_size*patch_size]
|
||||||
img_dct = unpatchify(ctx->ggml_ctx, img_dct, (H + pad_h) / patch_size, (W + pad_w) / patch_size); // [N, nerf_hidden_size, H, W]
|
img_dct = DiT::unpatchify(ctx->ggml_ctx, img_dct, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size, patch_size); // [N, nerf_hidden_size, H, W]
|
||||||
|
|
||||||
out = nerf_final_layer_conv->forward(ctx, img_dct); // [N, C, H, W]
|
out = nerf_final_layer_conv->forward(ctx, img_dct); // [N, C, H, W]
|
||||||
|
|
||||||
@ -1134,7 +1070,7 @@ namespace Flux {
|
|||||||
int pad_h = (patch_size - H % patch_size) % patch_size;
|
int pad_h = (patch_size - H % patch_size) % patch_size;
|
||||||
int pad_w = (patch_size - W % patch_size) % patch_size;
|
int pad_w = (patch_size - W % patch_size) % patch_size;
|
||||||
|
|
||||||
auto img = process_img(ctx, x);
|
auto img = DiT::pad_and_patchify(ctx, x, patch_size, patch_size);
|
||||||
int64_t img_tokens = img->ne[1];
|
int64_t img_tokens = img->ne[1];
|
||||||
|
|
||||||
if (params.version == VERSION_FLUX_FILL) {
|
if (params.version == VERSION_FLUX_FILL) {
|
||||||
@ -1142,8 +1078,8 @@ namespace Flux {
|
|||||||
ggml_tensor* masked = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
|
ggml_tensor* masked = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
|
||||||
ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
|
ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
|
||||||
|
|
||||||
masked = process_img(ctx, masked);
|
masked = DiT::pad_and_patchify(ctx, masked, patch_size, patch_size);
|
||||||
mask = process_img(ctx, mask);
|
mask = DiT::pad_and_patchify(ctx, mask, patch_size, patch_size);
|
||||||
|
|
||||||
img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, masked, mask, 0), 0);
|
img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, masked, mask, 0), 0);
|
||||||
} else if (params.version == VERSION_FLEX_2) {
|
} else if (params.version == VERSION_FLEX_2) {
|
||||||
@ -1152,21 +1088,21 @@ namespace Flux {
|
|||||||
ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
|
ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
|
||||||
ggml_tensor* control = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));
|
ggml_tensor* control = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));
|
||||||
|
|
||||||
masked = process_img(ctx, masked);
|
masked = DiT::pad_and_patchify(ctx, masked, patch_size, patch_size);
|
||||||
mask = process_img(ctx, mask);
|
mask = DiT::pad_and_patchify(ctx, mask, patch_size, patch_size);
|
||||||
control = process_img(ctx, control);
|
control = DiT::pad_and_patchify(ctx, control, patch_size, patch_size);
|
||||||
|
|
||||||
img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, ggml_concat(ctx->ggml_ctx, masked, mask, 0), control, 0), 0);
|
img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, ggml_concat(ctx->ggml_ctx, masked, mask, 0), control, 0), 0);
|
||||||
} else if (params.version == VERSION_FLUX_CONTROLS) {
|
} else if (params.version == VERSION_FLUX_CONTROLS) {
|
||||||
GGML_ASSERT(c_concat != nullptr);
|
GGML_ASSERT(c_concat != nullptr);
|
||||||
|
|
||||||
auto control = process_img(ctx, c_concat);
|
auto control = DiT::pad_and_patchify(ctx, c_concat, patch_size, patch_size);
|
||||||
img = ggml_concat(ctx->ggml_ctx, img, control, 0);
|
img = ggml_concat(ctx->ggml_ctx, img, control, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ref_latents.size() > 0) {
|
if (ref_latents.size() > 0) {
|
||||||
for (ggml_tensor* ref : ref_latents) {
|
for (ggml_tensor* ref : ref_latents) {
|
||||||
ref = process_img(ctx, ref);
|
ref = DiT::pad_and_patchify(ctx, ref, patch_size, patch_size);
|
||||||
img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
|
img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1178,8 +1114,7 @@ namespace Flux {
|
|||||||
out = ggml_cont(ctx->ggml_ctx, out);
|
out = ggml_cont(ctx->ggml_ctx, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
// rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
|
out = DiT::unpatchify_and_crop(ctx->ggml_ctx, out, H, W, patch_size, patch_size); // [N, C, H, W]
|
||||||
out = unpatchify(ctx->ggml_ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size); // [N, C, H + pad_h, W + pad_w]
|
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,8 +1,7 @@
|
|||||||
#ifndef __LTXV_HPP__
|
#ifndef __LTXV_HPP__
|
||||||
#define __LTXV_HPP__
|
#define __LTXV_HPP__
|
||||||
|
|
||||||
#include "common.hpp"
|
#include "common_block.hpp"
|
||||||
#include "ggml_extend.hpp"
|
|
||||||
|
|
||||||
namespace LTXV {
|
namespace LTXV {
|
||||||
|
|
||||||
|
|||||||
@ -745,28 +745,6 @@ public:
|
|||||||
return spatial_pos_embed;
|
return spatial_pos_embed;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* unpatchify(struct ggml_context* ctx,
|
|
||||||
struct ggml_tensor* x,
|
|
||||||
int64_t h,
|
|
||||||
int64_t w) {
|
|
||||||
// x: [N, H*W, patch_size * patch_size * C]
|
|
||||||
// return: [N, C, H, W]
|
|
||||||
int64_t n = x->ne[2];
|
|
||||||
int64_t c = out_channels;
|
|
||||||
int64_t p = patch_size;
|
|
||||||
h = (h + 1) / p;
|
|
||||||
w = (w + 1) / p;
|
|
||||||
|
|
||||||
GGML_ASSERT(h * w == x->ne[1]);
|
|
||||||
|
|
||||||
x = ggml_reshape_4d(ctx, x, c, p * p, w * h, n); // [N, H*W, P*P, C]
|
|
||||||
x = ggml_cont(ctx, ggml_permute(ctx, x, 2, 0, 1, 3)); // [N, C, H*W, P*P]
|
|
||||||
x = ggml_reshape_4d(ctx, x, p, p, w, h * c * n); // [N*C*H, W, P, P]
|
|
||||||
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*H, P, W, P]
|
|
||||||
x = ggml_reshape_4d(ctx, x, p * w, p * h, c, n); // [N, C, H*P, W*P]
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_tensor* forward_core_with_concat(GGMLRunnerContext* ctx,
|
struct ggml_tensor* forward_core_with_concat(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* c_mod,
|
struct ggml_tensor* c_mod,
|
||||||
@ -811,11 +789,11 @@ public:
|
|||||||
auto x_embedder = std::dynamic_pointer_cast<PatchEmbed>(blocks["x_embedder"]);
|
auto x_embedder = std::dynamic_pointer_cast<PatchEmbed>(blocks["x_embedder"]);
|
||||||
auto t_embedder = std::dynamic_pointer_cast<TimestepEmbedder>(blocks["t_embedder"]);
|
auto t_embedder = std::dynamic_pointer_cast<TimestepEmbedder>(blocks["t_embedder"]);
|
||||||
|
|
||||||
int64_t w = x->ne[0];
|
int64_t W = x->ne[0];
|
||||||
int64_t h = x->ne[1];
|
int64_t H = x->ne[1];
|
||||||
|
|
||||||
auto patch_embed = x_embedder->forward(ctx, x); // [N, H*W, hidden_size]
|
auto patch_embed = x_embedder->forward(ctx, x); // [N, H*W, hidden_size]
|
||||||
auto pos_embed = cropped_pos_embed(ctx->ggml_ctx, h, w); // [1, H*W, hidden_size]
|
auto pos_embed = cropped_pos_embed(ctx->ggml_ctx, H, W); // [1, H*W, hidden_size]
|
||||||
x = ggml_add(ctx->ggml_ctx, patch_embed, pos_embed); // [N, H*W, hidden_size]
|
x = ggml_add(ctx->ggml_ctx, patch_embed, pos_embed); // [N, H*W, hidden_size]
|
||||||
|
|
||||||
auto c = t_embedder->forward(ctx, t); // [N, hidden_size]
|
auto c = t_embedder->forward(ctx, t); // [N, hidden_size]
|
||||||
@ -834,7 +812,7 @@ public:
|
|||||||
|
|
||||||
x = forward_core_with_concat(ctx, x, c, context, skip_layers); // (N, H*W, patch_size ** 2 * out_channels)
|
x = forward_core_with_concat(ctx, x, c, context, skip_layers); // (N, H*W, patch_size ** 2 * out_channels)
|
||||||
|
|
||||||
x = unpatchify(ctx->ggml_ctx, x, h, w); // [N, C, H, W]
|
x = DiT::unpatchify_and_crop(ctx->ggml_ctx, x, H, W, patch_size, patch_size, /*patch_last*/ false); // [N, C, H, W]
|
||||||
|
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,9 +3,8 @@
|
|||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "common.hpp"
|
#include "common_block.hpp"
|
||||||
#include "flux.hpp"
|
#include "flux.hpp"
|
||||||
#include "ggml_extend.hpp"
|
|
||||||
|
|
||||||
namespace Qwen {
|
namespace Qwen {
|
||||||
constexpr int QWEN_IMAGE_GRAPH_SIZE = 20480;
|
constexpr int QWEN_IMAGE_GRAPH_SIZE = 20480;
|
||||||
@ -390,69 +389,6 @@ namespace Qwen {
|
|||||||
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, params.patch_size * params.patch_size * params.out_channels));
|
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, params.patch_size * params.patch_size * params.out_channels));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Grows ne[0] (width) and ne[1] (height) of x to the next multiple of params.patch_size.
struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
                                      struct ggml_tensor* x) {
    const auto ps        = params.patch_size;
    const int64_t cols   = x->ne[0];
    const int64_t rows   = x->ne[1];

    // (ps - n % ps) % ps is 0 when n is already aligned, otherwise the shortfall.
    const int missing_rows = (ps - rows % ps) % ps;
    const int missing_cols = (ps - cols % ps) % ps;

    struct ggml_tensor* padded = ggml_ext_pad(ctx->ggml_ctx, x, missing_cols, missing_rows, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
    return padded;
}
|
|
||||||
|
|
||||||
struct ggml_tensor* patchify(struct ggml_context* ctx,
|
|
||||||
struct ggml_tensor* x) {
|
|
||||||
// x: [N, C, H, W]
|
|
||||||
// return: [N, h*w, C * patch_size * patch_size]
|
|
||||||
int64_t N = x->ne[3];
|
|
||||||
int64_t C = x->ne[2];
|
|
||||||
int64_t H = x->ne[1];
|
|
||||||
int64_t W = x->ne[0];
|
|
||||||
int64_t p = params.patch_size;
|
|
||||||
int64_t h = H / params.patch_size;
|
|
||||||
int64_t w = W / params.patch_size;
|
|
||||||
|
|
||||||
GGML_ASSERT(h * p == H && w * p == W);
|
|
||||||
|
|
||||||
x = ggml_reshape_4d(ctx, x, p, w, p, h * C * N); // [N*C*h, p, w, p]
|
|
||||||
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, w, p, p]
|
|
||||||
x = ggml_reshape_4d(ctx, x, p * p, w * h, C, N); // [N, C, h*w, p*p]
|
|
||||||
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, h*w, C, p*p]
|
|
||||||
x = ggml_reshape_3d(ctx, x, p * p * C, w * h, N); // [N, h*w, C*p*p]
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pads x to a whole number of patches, then flattens it into patch tokens.
struct ggml_tensor* process_img(GGMLRunnerContext* ctx,
                                struct ggml_tensor* x) {
    struct ggml_tensor* aligned = pad_to_patch_size(ctx, x);
    return patchify(ctx->ggml_ctx, aligned);
}
|
|
||||||
|
|
||||||
struct ggml_tensor* unpatchify(struct ggml_context* ctx,
|
|
||||||
struct ggml_tensor* x,
|
|
||||||
int64_t h,
|
|
||||||
int64_t w) {
|
|
||||||
// x: [N, h*w, C*patch_size*patch_size]
|
|
||||||
// return: [N, C, H, W]
|
|
||||||
int64_t N = x->ne[2];
|
|
||||||
int64_t C = x->ne[0] / params.patch_size / params.patch_size;
|
|
||||||
int64_t H = h * params.patch_size;
|
|
||||||
int64_t W = w * params.patch_size;
|
|
||||||
int64_t p = params.patch_size;
|
|
||||||
|
|
||||||
GGML_ASSERT(C * p * p == x->ne[0]);
|
|
||||||
|
|
||||||
x = ggml_reshape_4d(ctx, x, p * p, C, w * h, N); // [N, h*w, C, p*p]
|
|
||||||
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, C, h*w, p*p]
|
|
||||||
x = ggml_reshape_4d(ctx, x, p, p, w, h * C * N); // [N*C*h, w, p, p]
|
|
||||||
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, p, w, p]
|
|
||||||
x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, h*p, w*p]
|
|
||||||
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_tensor* forward_orig(GGMLRunnerContext* ctx,
|
struct ggml_tensor* forward_orig(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* timestep,
|
struct ggml_tensor* timestep,
|
||||||
@ -512,19 +448,16 @@ namespace Qwen {
|
|||||||
int64_t C = x->ne[2];
|
int64_t C = x->ne[2];
|
||||||
int64_t N = x->ne[3];
|
int64_t N = x->ne[3];
|
||||||
|
|
||||||
auto img = process_img(ctx, x);
|
auto img = DiT::pad_and_patchify(ctx, x, params.patch_size, params.patch_size);
|
||||||
int64_t img_tokens = img->ne[1];
|
int64_t img_tokens = img->ne[1];
|
||||||
|
|
||||||
if (ref_latents.size() > 0) {
|
if (ref_latents.size() > 0) {
|
||||||
for (ggml_tensor* ref : ref_latents) {
|
for (ggml_tensor* ref : ref_latents) {
|
||||||
ref = process_img(ctx, ref);
|
ref = DiT::pad_and_patchify(ctx, ref, params.patch_size, params.patch_size);
|
||||||
img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
|
img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t h_len = ((H + (params.patch_size / 2)) / params.patch_size);
|
|
||||||
int64_t w_len = ((W + (params.patch_size / 2)) / params.patch_size);
|
|
||||||
|
|
||||||
auto out = forward_orig(ctx, img, timestep, context, pe, modulate_index); // [N, h_len*w_len, ph*pw*C]
|
auto out = forward_orig(ctx, img, timestep, context, pe, modulate_index); // [N, h_len*w_len, ph*pw*C]
|
||||||
|
|
||||||
if (out->ne[1] > img_tokens) {
|
if (out->ne[1] > img_tokens) {
|
||||||
@ -533,11 +466,7 @@ namespace Qwen {
|
|||||||
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size]
|
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size]
|
||||||
}
|
}
|
||||||
|
|
||||||
out = unpatchify(ctx->ggml_ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w]
|
out = DiT::unpatchify_and_crop(ctx->ggml_ctx, out, H, W, params.patch_size, params.patch_size); // [N, C, H, W]
|
||||||
|
|
||||||
// slice
|
|
||||||
out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H); // [N, C, H, W + pad_w]
|
|
||||||
out = ggml_ext_slice(ctx->ggml_ctx, out, 0, 0, W); // [N, C, H, W]
|
|
||||||
|
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,8 +1,7 @@
|
|||||||
#ifndef __UNET_HPP__
|
#ifndef __UNET_HPP__
|
||||||
#define __UNET_HPP__
|
#define __UNET_HPP__
|
||||||
|
|
||||||
#include "common.hpp"
|
#include "common_block.hpp"
|
||||||
#include "ggml_extend.hpp"
|
|
||||||
#include "model.h"
|
#include "model.h"
|
||||||
|
|
||||||
/*==================================================== UnetModel =====================================================*/
|
/*==================================================== UnetModel =====================================================*/
|
||||||
|
|||||||
@ -1,8 +1,7 @@
|
|||||||
#ifndef __VAE_HPP__
|
#ifndef __VAE_HPP__
|
||||||
#define __VAE_HPP__
|
#define __VAE_HPP__
|
||||||
|
|
||||||
#include "common.hpp"
|
#include "common_block.hpp"
|
||||||
#include "ggml_extend.hpp"
|
|
||||||
|
|
||||||
/*================================================== AutoEncoderKL ===================================================*/
|
/*================================================== AutoEncoderKL ===================================================*/
|
||||||
|
|
||||||
|
|||||||
@ -5,9 +5,8 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "common.hpp"
|
#include "common_block.hpp"
|
||||||
#include "flux.hpp"
|
#include "flux.hpp"
|
||||||
#include "ggml_extend.hpp"
|
|
||||||
#include "rope.hpp"
|
#include "rope.hpp"
|
||||||
#include "vae.hpp"
|
#include "vae.hpp"
|
||||||
|
|
||||||
|
|||||||
@ -346,69 +346,6 @@ namespace ZImage {
|
|||||||
blocks["final_layer"] = std::make_shared<FinalLayer>(z_image_params.hidden_size, z_image_params.patch_size, z_image_params.out_channels);
|
blocks["final_layer"] = std::make_shared<FinalLayer>(z_image_params.hidden_size, z_image_params.patch_size, z_image_params.out_channels);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
|
|
||||||
struct ggml_tensor* x) {
|
|
||||||
int64_t W = x->ne[0];
|
|
||||||
int64_t H = x->ne[1];
|
|
||||||
|
|
||||||
int pad_h = (z_image_params.patch_size - H % z_image_params.patch_size) % z_image_params.patch_size;
|
|
||||||
int pad_w = (z_image_params.patch_size - W % z_image_params.patch_size) % z_image_params.patch_size;
|
|
||||||
x = ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_tensor* patchify(struct ggml_context* ctx,
|
|
||||||
struct ggml_tensor* x) {
|
|
||||||
// x: [N, C, H, W]
|
|
||||||
// return: [N, h*w, patch_size*patch_size*C]
|
|
||||||
int64_t N = x->ne[3];
|
|
||||||
int64_t C = x->ne[2];
|
|
||||||
int64_t H = x->ne[1];
|
|
||||||
int64_t W = x->ne[0];
|
|
||||||
int64_t p = z_image_params.patch_size;
|
|
||||||
int64_t h = H / z_image_params.patch_size;
|
|
||||||
int64_t w = W / z_image_params.patch_size;
|
|
||||||
|
|
||||||
GGML_ASSERT(h * p == H && w * p == W);
|
|
||||||
|
|
||||||
x = ggml_reshape_4d(ctx, x, p, w, p, h * C * N); // [N*C*h, p, w, p]
|
|
||||||
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, w, p, p]
|
|
||||||
x = ggml_reshape_4d(ctx, x, p * p, w * h, C, N); // [N, C, h*w, p*p]
|
|
||||||
x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [N, h*w, C, p*p]
|
|
||||||
x = ggml_reshape_3d(ctx, x, C * p * p, w * h, N); // [N, h*w, p*p*C]
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_tensor* process_img(GGMLRunnerContext* ctx,
|
|
||||||
struct ggml_tensor* x) {
|
|
||||||
x = pad_to_patch_size(ctx, x);
|
|
||||||
x = patchify(ctx->ggml_ctx, x);
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_tensor* unpatchify(struct ggml_context* ctx,
|
|
||||||
struct ggml_tensor* x,
|
|
||||||
int64_t h,
|
|
||||||
int64_t w) {
|
|
||||||
// x: [N, h*w, patch_size*patch_size*C]
|
|
||||||
// return: [N, C, H, W]
|
|
||||||
int64_t N = x->ne[2];
|
|
||||||
int64_t C = x->ne[0] / z_image_params.patch_size / z_image_params.patch_size;
|
|
||||||
int64_t H = h * z_image_params.patch_size;
|
|
||||||
int64_t W = w * z_image_params.patch_size;
|
|
||||||
int64_t p = z_image_params.patch_size;
|
|
||||||
|
|
||||||
GGML_ASSERT(C * p * p == x->ne[0]);
|
|
||||||
|
|
||||||
x = ggml_reshape_4d(ctx, x, C, p * p, w * h, N); // [N, h*w, p*p, C]
|
|
||||||
x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 2, 0, 3)); // [N, C, h*w, p*p]
|
|
||||||
x = ggml_reshape_4d(ctx, x, p, p, w, h * C * N); // [N*C*h, w, p, p]
|
|
||||||
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, p, w, p]
|
|
||||||
x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, h*p, w*p]
|
|
||||||
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_tensor* forward_core(GGMLRunnerContext* ctx,
|
struct ggml_tensor* forward_core(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* timestep,
|
struct ggml_tensor* timestep,
|
||||||
@ -495,27 +432,22 @@ namespace ZImage {
|
|||||||
int64_t C = x->ne[2];
|
int64_t C = x->ne[2];
|
||||||
int64_t N = x->ne[3];
|
int64_t N = x->ne[3];
|
||||||
|
|
||||||
auto img = process_img(ctx, x);
|
int patch_size = z_image_params.patch_size;
|
||||||
|
|
||||||
|
auto img = DiT::pad_and_patchify(ctx, x, patch_size, patch_size, false);
|
||||||
uint64_t n_img_token = img->ne[1];
|
uint64_t n_img_token = img->ne[1];
|
||||||
|
|
||||||
if (ref_latents.size() > 0) {
|
if (ref_latents.size() > 0) {
|
||||||
for (ggml_tensor* ref : ref_latents) {
|
for (ggml_tensor* ref : ref_latents) {
|
||||||
ref = process_img(ctx, ref);
|
ref = DiT::pad_and_patchify(ctx, ref, patch_size, patch_size, false);
|
||||||
img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
|
img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t h_len = ((H + (z_image_params.patch_size / 2)) / z_image_params.patch_size);
|
|
||||||
int64_t w_len = ((W + (z_image_params.patch_size / 2)) / z_image_params.patch_size);
|
|
||||||
|
|
||||||
auto out = forward_core(ctx, img, timestep, context, pe);
|
auto out = forward_core(ctx, img, timestep, context, pe);
|
||||||
|
|
||||||
out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, n_img_token); // [N, n_img_token, ph*pw*C]
|
out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, n_img_token); // [N, n_img_token, ph*pw*C]
|
||||||
out = unpatchify(ctx->ggml_ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w]
|
out = DiT::unpatchify_and_crop(ctx->ggml_ctx, out, H, W, patch_size, patch_size, false); // [N, C, H, W]
|
||||||
|
|
||||||
// slice
|
|
||||||
out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H); // [N, C, H, W + pad_w]
|
|
||||||
out = ggml_ext_slice(ctx->ggml_ctx, out, 0, 0, W); // [N, C, H, W]
|
|
||||||
|
|
||||||
out = ggml_ext_scale(ctx->ggml_ctx, out, -1.f);
|
out = ggml_ext_scale(ctx->ggml_ctx, out, -1.f);
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user