mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-01-02 18:53:36 +00:00
feat: add seamless texture generation support (#914)
* global bool * reworked circular to global flag * cleaner implementation of tiling support in sd cpp * cleaned rope * working simplified but still need wraps * Further clean of rope * resolve flux conflict * switch to pad op circular only * Set ggml to most recent * Revert ggml temp * Update ggml to most recent * Revert unneded flux change * move circular flag to the GGMLRunnerContext * Pass through circular param in all places where conv is called * fix of constant and minor cleanup * Added back --circular option * Conv2d circular in vae and various models * Fix temporal padding for qwen image and other vaes * Z Image circular tiling * x and y axis seamless only * First attempt at chroma seamless x and y * refactor into pure x and y, almost there * Fix crash on chroma * Refactor into cleaner variable choices * Removed redundant set_circular_enabled * Sync ggml * simplify circular parameter * format code * no need to perform circular pad on the clip * simplify circular_axes setting * unify function naming * remove unnecessary member variables * simplify rope --------- Co-authored-by: Phylliida <phylliidadev@gmail.com> Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
parent
88ec9d30b1
commit
50ff966445
@ -28,7 +28,7 @@ public:
|
|||||||
if (vae_downsample) {
|
if (vae_downsample) {
|
||||||
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
|
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
|
||||||
|
|
||||||
x = ggml_pad(ctx->ggml_ctx, x, 1, 1, 0, 0);
|
x = ggml_ext_pad(ctx->ggml_ctx, x, 1, 1, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
|
||||||
x = conv->forward(ctx, x);
|
x = conv->forward(ctx, x);
|
||||||
} else {
|
} else {
|
||||||
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["op"]);
|
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["op"]);
|
||||||
|
|||||||
@ -39,6 +39,7 @@ struct DiffusionModel {
|
|||||||
virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter){};
|
virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter){};
|
||||||
virtual int64_t get_adm_in_channels() = 0;
|
virtual int64_t get_adm_in_channels() = 0;
|
||||||
virtual void set_flash_attn_enabled(bool enabled) = 0;
|
virtual void set_flash_attn_enabled(bool enabled) = 0;
|
||||||
|
virtual void set_circular_axes(bool circular_x, bool circular_y) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct UNetModel : public DiffusionModel {
|
struct UNetModel : public DiffusionModel {
|
||||||
@ -87,6 +88,10 @@ struct UNetModel : public DiffusionModel {
|
|||||||
unet.set_flash_attention_enabled(enabled);
|
unet.set_flash_attention_enabled(enabled);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void set_circular_axes(bool circular_x, bool circular_y) override {
|
||||||
|
unet.set_circular_axes(circular_x, circular_y);
|
||||||
|
}
|
||||||
|
|
||||||
bool compute(int n_threads,
|
bool compute(int n_threads,
|
||||||
DiffusionParams diffusion_params,
|
DiffusionParams diffusion_params,
|
||||||
struct ggml_tensor** output = nullptr,
|
struct ggml_tensor** output = nullptr,
|
||||||
@ -148,6 +153,10 @@ struct MMDiTModel : public DiffusionModel {
|
|||||||
mmdit.set_flash_attention_enabled(enabled);
|
mmdit.set_flash_attention_enabled(enabled);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void set_circular_axes(bool circular_x, bool circular_y) override {
|
||||||
|
mmdit.set_circular_axes(circular_x, circular_y);
|
||||||
|
}
|
||||||
|
|
||||||
bool compute(int n_threads,
|
bool compute(int n_threads,
|
||||||
DiffusionParams diffusion_params,
|
DiffusionParams diffusion_params,
|
||||||
struct ggml_tensor** output = nullptr,
|
struct ggml_tensor** output = nullptr,
|
||||||
@ -210,6 +219,10 @@ struct FluxModel : public DiffusionModel {
|
|||||||
flux.set_flash_attention_enabled(enabled);
|
flux.set_flash_attention_enabled(enabled);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void set_circular_axes(bool circular_x, bool circular_y) override {
|
||||||
|
flux.set_circular_axes(circular_x, circular_y);
|
||||||
|
}
|
||||||
|
|
||||||
bool compute(int n_threads,
|
bool compute(int n_threads,
|
||||||
DiffusionParams diffusion_params,
|
DiffusionParams diffusion_params,
|
||||||
struct ggml_tensor** output = nullptr,
|
struct ggml_tensor** output = nullptr,
|
||||||
@ -277,6 +290,10 @@ struct WanModel : public DiffusionModel {
|
|||||||
wan.set_flash_attention_enabled(enabled);
|
wan.set_flash_attention_enabled(enabled);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void set_circular_axes(bool circular_x, bool circular_y) override {
|
||||||
|
wan.set_circular_axes(circular_x, circular_y);
|
||||||
|
}
|
||||||
|
|
||||||
bool compute(int n_threads,
|
bool compute(int n_threads,
|
||||||
DiffusionParams diffusion_params,
|
DiffusionParams diffusion_params,
|
||||||
struct ggml_tensor** output = nullptr,
|
struct ggml_tensor** output = nullptr,
|
||||||
@ -343,6 +360,10 @@ struct QwenImageModel : public DiffusionModel {
|
|||||||
qwen_image.set_flash_attention_enabled(enabled);
|
qwen_image.set_flash_attention_enabled(enabled);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void set_circular_axes(bool circular_x, bool circular_y) override {
|
||||||
|
qwen_image.set_circular_axes(circular_x, circular_y);
|
||||||
|
}
|
||||||
|
|
||||||
bool compute(int n_threads,
|
bool compute(int n_threads,
|
||||||
DiffusionParams diffusion_params,
|
DiffusionParams diffusion_params,
|
||||||
struct ggml_tensor** output = nullptr,
|
struct ggml_tensor** output = nullptr,
|
||||||
@ -406,6 +427,10 @@ struct ZImageModel : public DiffusionModel {
|
|||||||
z_image.set_flash_attention_enabled(enabled);
|
z_image.set_flash_attention_enabled(enabled);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void set_circular_axes(bool circular_x, bool circular_y) override {
|
||||||
|
z_image.set_circular_axes(circular_x, circular_y);
|
||||||
|
}
|
||||||
|
|
||||||
bool compute(int n_threads,
|
bool compute(int n_threads,
|
||||||
DiffusionParams diffusion_params,
|
DiffusionParams diffusion_params,
|
||||||
struct ggml_tensor** output = nullptr,
|
struct ggml_tensor** output = nullptr,
|
||||||
|
|||||||
@ -449,6 +449,10 @@ struct SDContextParams {
|
|||||||
bool diffusion_conv_direct = false;
|
bool diffusion_conv_direct = false;
|
||||||
bool vae_conv_direct = false;
|
bool vae_conv_direct = false;
|
||||||
|
|
||||||
|
bool circular = false;
|
||||||
|
bool circular_x = false;
|
||||||
|
bool circular_y = false;
|
||||||
|
|
||||||
bool chroma_use_dit_mask = true;
|
bool chroma_use_dit_mask = true;
|
||||||
bool chroma_use_t5_mask = false;
|
bool chroma_use_t5_mask = false;
|
||||||
int chroma_t5_mask_pad = 1;
|
int chroma_t5_mask_pad = 1;
|
||||||
@ -605,6 +609,18 @@ struct SDContextParams {
|
|||||||
"--vae-conv-direct",
|
"--vae-conv-direct",
|
||||||
"use ggml_conv2d_direct in the vae model",
|
"use ggml_conv2d_direct in the vae model",
|
||||||
true, &vae_conv_direct},
|
true, &vae_conv_direct},
|
||||||
|
{"",
|
||||||
|
"--circular",
|
||||||
|
"enable circular padding for convolutions",
|
||||||
|
true, &circular},
|
||||||
|
{"",
|
||||||
|
"--circularx",
|
||||||
|
"enable circular RoPE wrapping on x-axis (width) only",
|
||||||
|
true, &circular_x},
|
||||||
|
{"",
|
||||||
|
"--circulary",
|
||||||
|
"enable circular RoPE wrapping on y-axis (height) only",
|
||||||
|
true, &circular_y},
|
||||||
{"",
|
{"",
|
||||||
"--chroma-disable-dit-mask",
|
"--chroma-disable-dit-mask",
|
||||||
"disable dit mask for chroma",
|
"disable dit mask for chroma",
|
||||||
@ -868,6 +884,9 @@ struct SDContextParams {
|
|||||||
<< " diffusion_flash_attn: " << (diffusion_flash_attn ? "true" : "false") << ",\n"
|
<< " diffusion_flash_attn: " << (diffusion_flash_attn ? "true" : "false") << ",\n"
|
||||||
<< " diffusion_conv_direct: " << (diffusion_conv_direct ? "true" : "false") << ",\n"
|
<< " diffusion_conv_direct: " << (diffusion_conv_direct ? "true" : "false") << ",\n"
|
||||||
<< " vae_conv_direct: " << (vae_conv_direct ? "true" : "false") << ",\n"
|
<< " vae_conv_direct: " << (vae_conv_direct ? "true" : "false") << ",\n"
|
||||||
|
<< " circular: " << (circular ? "true" : "false") << ",\n"
|
||||||
|
<< " circular_x: " << (circular_x ? "true" : "false") << ",\n"
|
||||||
|
<< " circular_y: " << (circular_y ? "true" : "false") << ",\n"
|
||||||
<< " chroma_use_dit_mask: " << (chroma_use_dit_mask ? "true" : "false") << ",\n"
|
<< " chroma_use_dit_mask: " << (chroma_use_dit_mask ? "true" : "false") << ",\n"
|
||||||
<< " chroma_use_t5_mask: " << (chroma_use_t5_mask ? "true" : "false") << ",\n"
|
<< " chroma_use_t5_mask: " << (chroma_use_t5_mask ? "true" : "false") << ",\n"
|
||||||
<< " chroma_t5_mask_pad: " << chroma_t5_mask_pad << ",\n"
|
<< " chroma_t5_mask_pad: " << chroma_t5_mask_pad << ",\n"
|
||||||
@ -928,6 +947,8 @@ struct SDContextParams {
|
|||||||
taesd_preview,
|
taesd_preview,
|
||||||
diffusion_conv_direct,
|
diffusion_conv_direct,
|
||||||
vae_conv_direct,
|
vae_conv_direct,
|
||||||
|
circular || circular_x,
|
||||||
|
circular || circular_y,
|
||||||
force_sdxl_vae_conv_scale,
|
force_sdxl_vae_conv_scale,
|
||||||
chroma_use_dit_mask,
|
chroma_use_dit_mask,
|
||||||
chroma_use_t5_mask,
|
chroma_use_t5_mask,
|
||||||
|
|||||||
28
flux.hpp
28
flux.hpp
@ -860,14 +860,14 @@ namespace Flux {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx,
|
struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x) {
|
struct ggml_tensor* x) {
|
||||||
int64_t W = x->ne[0];
|
int64_t W = x->ne[0];
|
||||||
int64_t H = x->ne[1];
|
int64_t H = x->ne[1];
|
||||||
|
|
||||||
int pad_h = (params.patch_size - H % params.patch_size) % params.patch_size;
|
int pad_h = (params.patch_size - H % params.patch_size) % params.patch_size;
|
||||||
int pad_w = (params.patch_size - W % params.patch_size) % params.patch_size;
|
int pad_w = (params.patch_size - W % params.patch_size) % params.patch_size;
|
||||||
x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w]
|
x = ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -893,11 +893,11 @@ namespace Flux {
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* process_img(struct ggml_context* ctx,
|
struct ggml_tensor* process_img(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x) {
|
struct ggml_tensor* x) {
|
||||||
// img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
// img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||||
x = pad_to_patch_size(ctx, x);
|
x = pad_to_patch_size(ctx, x);
|
||||||
x = patchify(ctx, x);
|
x = patchify(ctx->ggml_ctx, x);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1076,7 +1076,7 @@ namespace Flux {
|
|||||||
int pad_h = (patch_size - H % patch_size) % patch_size;
|
int pad_h = (patch_size - H % patch_size) % patch_size;
|
||||||
int pad_w = (patch_size - W % patch_size) % patch_size;
|
int pad_w = (patch_size - W % patch_size) % patch_size;
|
||||||
|
|
||||||
auto img = pad_to_patch_size(ctx->ggml_ctx, x);
|
auto img = pad_to_patch_size(ctx, x);
|
||||||
auto orig_img = img;
|
auto orig_img = img;
|
||||||
|
|
||||||
if (params.chroma_radiance_params.use_patch_size_32) {
|
if (params.chroma_radiance_params.use_patch_size_32) {
|
||||||
@ -1150,7 +1150,7 @@ namespace Flux {
|
|||||||
int pad_h = (patch_size - H % patch_size) % patch_size;
|
int pad_h = (patch_size - H % patch_size) % patch_size;
|
||||||
int pad_w = (patch_size - W % patch_size) % patch_size;
|
int pad_w = (patch_size - W % patch_size) % patch_size;
|
||||||
|
|
||||||
auto img = process_img(ctx->ggml_ctx, x);
|
auto img = process_img(ctx, x);
|
||||||
uint64_t img_tokens = img->ne[1];
|
uint64_t img_tokens = img->ne[1];
|
||||||
|
|
||||||
if (params.version == VERSION_FLUX_FILL) {
|
if (params.version == VERSION_FLUX_FILL) {
|
||||||
@ -1158,8 +1158,8 @@ namespace Flux {
|
|||||||
ggml_tensor* masked = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
|
ggml_tensor* masked = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
|
||||||
ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
|
ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
|
||||||
|
|
||||||
masked = process_img(ctx->ggml_ctx, masked);
|
masked = process_img(ctx, masked);
|
||||||
mask = process_img(ctx->ggml_ctx, mask);
|
mask = process_img(ctx, mask);
|
||||||
|
|
||||||
img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, masked, mask, 0), 0);
|
img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, masked, mask, 0), 0);
|
||||||
} else if (params.version == VERSION_FLEX_2) {
|
} else if (params.version == VERSION_FLEX_2) {
|
||||||
@ -1168,21 +1168,21 @@ namespace Flux {
|
|||||||
ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
|
ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
|
||||||
ggml_tensor* control = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));
|
ggml_tensor* control = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));
|
||||||
|
|
||||||
masked = process_img(ctx->ggml_ctx, masked);
|
masked = process_img(ctx, masked);
|
||||||
mask = process_img(ctx->ggml_ctx, mask);
|
mask = process_img(ctx, mask);
|
||||||
control = process_img(ctx->ggml_ctx, control);
|
control = process_img(ctx, control);
|
||||||
|
|
||||||
img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, ggml_concat(ctx->ggml_ctx, masked, mask, 0), control, 0), 0);
|
img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, ggml_concat(ctx->ggml_ctx, masked, mask, 0), control, 0), 0);
|
||||||
} else if (params.version == VERSION_FLUX_CONTROLS) {
|
} else if (params.version == VERSION_FLUX_CONTROLS) {
|
||||||
GGML_ASSERT(c_concat != nullptr);
|
GGML_ASSERT(c_concat != nullptr);
|
||||||
|
|
||||||
auto control = process_img(ctx->ggml_ctx, c_concat);
|
auto control = process_img(ctx, c_concat);
|
||||||
img = ggml_concat(ctx->ggml_ctx, img, control, 0);
|
img = ggml_concat(ctx->ggml_ctx, img, control, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ref_latents.size() > 0) {
|
if (ref_latents.size() > 0) {
|
||||||
for (ggml_tensor* ref : ref_latents) {
|
for (ggml_tensor* ref : ref_latents) {
|
||||||
ref = process_img(ctx->ggml_ctx, ref);
|
ref = process_img(ctx, ref);
|
||||||
img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
|
img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1472,6 +1472,8 @@ namespace Flux {
|
|||||||
increase_ref_index,
|
increase_ref_index,
|
||||||
flux_params.ref_index_scale,
|
flux_params.ref_index_scale,
|
||||||
flux_params.theta,
|
flux_params.theta,
|
||||||
|
circular_y_enabled,
|
||||||
|
circular_x_enabled,
|
||||||
flux_params.axes_dim);
|
flux_params.axes_dim);
|
||||||
int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
|
int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
|
||||||
// LOG_DEBUG("pos_len %d", pos_len);
|
// LOG_DEBUG("pos_len %d", pos_len);
|
||||||
|
|||||||
2
ggml
2
ggml
@ -1 +1 @@
|
|||||||
Subproject commit f5425c0ee5e582a7d64411f06139870bff3e52e0
|
Subproject commit 3e9f2ba3b934c20b26873b3c60dbf41b116978ff
|
||||||
@ -5,6 +5,7 @@
|
|||||||
#include <inttypes.h>
|
#include <inttypes.h>
|
||||||
#include <stdarg.h>
|
#include <stdarg.h>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <atomic>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
@ -993,6 +994,48 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx,
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_pad_ext(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
int lp0,
|
||||||
|
int rp0,
|
||||||
|
int lp1,
|
||||||
|
int rp1,
|
||||||
|
int lp2,
|
||||||
|
int rp2,
|
||||||
|
int lp3,
|
||||||
|
int rp3,
|
||||||
|
bool circular_x = false,
|
||||||
|
bool circular_y = false) {
|
||||||
|
if (circular_x && circular_y) {
|
||||||
|
return ggml_pad_ext_circular(ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (circular_x && (lp0 != 0 || rp0 != 0)) {
|
||||||
|
x = ggml_pad_ext_circular(ctx, x, lp0, rp0, 0, 0, 0, 0, 0, 0);
|
||||||
|
lp0 = rp0 = 0;
|
||||||
|
}
|
||||||
|
if (circular_y && (lp1 != 0 || rp1 != 0)) {
|
||||||
|
x = ggml_pad_ext_circular(ctx, x, 0, 0, lp1, rp1, 0, 0, 0, 0);
|
||||||
|
lp1 = rp1 = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (lp0 != 0 || rp0 != 0 || lp1 != 0 || rp1 != 0 || lp2 != 0 || rp2 != 0 || lp3 != 0 || rp3 != 0) {
|
||||||
|
x = ggml_pad_ext(ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
|
||||||
|
}
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_pad(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
int p0,
|
||||||
|
int p1,
|
||||||
|
int p2 = 0,
|
||||||
|
int p3 = 0,
|
||||||
|
bool circular_x = false,
|
||||||
|
bool circular_y = false) {
|
||||||
|
return ggml_ext_pad_ext(ctx, x, p0, p0, p1, p1, p2, p2, p3, p3, circular_x, circular_y);
|
||||||
|
}
|
||||||
|
|
||||||
// w: [OC,IC, KH, KW]
|
// w: [OC,IC, KH, KW]
|
||||||
// x: [N, IC, IH, IW]
|
// x: [N, IC, IH, IW]
|
||||||
// b: [OC,]
|
// b: [OC,]
|
||||||
@ -1008,6 +1051,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_conv_2d(struct ggml_context* ctx,
|
|||||||
int d0 = 1,
|
int d0 = 1,
|
||||||
int d1 = 1,
|
int d1 = 1,
|
||||||
bool direct = false,
|
bool direct = false,
|
||||||
|
bool circular_x = false,
|
||||||
|
bool circular_y = false,
|
||||||
float scale = 1.f) {
|
float scale = 1.f) {
|
||||||
if (scale != 1.f) {
|
if (scale != 1.f) {
|
||||||
x = ggml_scale(ctx, x, scale);
|
x = ggml_scale(ctx, x, scale);
|
||||||
@ -1015,6 +1060,13 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_conv_2d(struct ggml_context* ctx,
|
|||||||
if (w->ne[2] != x->ne[2] && ggml_n_dims(w) == 2) {
|
if (w->ne[2] != x->ne[2] && ggml_n_dims(w) == 2) {
|
||||||
w = ggml_reshape_4d(ctx, w, 1, 1, w->ne[0], w->ne[1]);
|
w = ggml_reshape_4d(ctx, w, 1, 1, w->ne[0], w->ne[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ((p0 != 0 || p1 != 0) && (circular_x || circular_y)) {
|
||||||
|
x = ggml_ext_pad(ctx, x, p0, p1, 0, 0, circular_x, circular_y);
|
||||||
|
p0 = 0;
|
||||||
|
p1 = 0;
|
||||||
|
}
|
||||||
|
|
||||||
if (direct) {
|
if (direct) {
|
||||||
x = ggml_conv_2d_direct(ctx, w, x, s0, s1, p0, p1, d0, d1);
|
x = ggml_conv_2d_direct(ctx, w, x, s0, s1, p0, p1, d0, d1);
|
||||||
} else {
|
} else {
|
||||||
@ -1528,6 +1580,8 @@ struct WeightAdapter {
|
|||||||
int d0 = 1;
|
int d0 = 1;
|
||||||
int d1 = 1;
|
int d1 = 1;
|
||||||
bool direct = false;
|
bool direct = false;
|
||||||
|
bool circular_x = false;
|
||||||
|
bool circular_y = false;
|
||||||
float scale = 1.f;
|
float scale = 1.f;
|
||||||
} conv2d;
|
} conv2d;
|
||||||
};
|
};
|
||||||
@ -1546,6 +1600,8 @@ struct GGMLRunnerContext {
|
|||||||
ggml_context* ggml_ctx = nullptr;
|
ggml_context* ggml_ctx = nullptr;
|
||||||
bool flash_attn_enabled = false;
|
bool flash_attn_enabled = false;
|
||||||
bool conv2d_direct_enabled = false;
|
bool conv2d_direct_enabled = false;
|
||||||
|
bool circular_x_enabled = false;
|
||||||
|
bool circular_y_enabled = false;
|
||||||
std::shared_ptr<WeightAdapter> weight_adapter = nullptr;
|
std::shared_ptr<WeightAdapter> weight_adapter = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1582,6 +1638,8 @@ protected:
|
|||||||
|
|
||||||
bool flash_attn_enabled = false;
|
bool flash_attn_enabled = false;
|
||||||
bool conv2d_direct_enabled = false;
|
bool conv2d_direct_enabled = false;
|
||||||
|
bool circular_x_enabled = false;
|
||||||
|
bool circular_y_enabled = false;
|
||||||
|
|
||||||
void alloc_params_ctx() {
|
void alloc_params_ctx() {
|
||||||
struct ggml_init_params params;
|
struct ggml_init_params params;
|
||||||
@ -1859,6 +1917,8 @@ public:
|
|||||||
runner_ctx.backend = runtime_backend;
|
runner_ctx.backend = runtime_backend;
|
||||||
runner_ctx.flash_attn_enabled = flash_attn_enabled;
|
runner_ctx.flash_attn_enabled = flash_attn_enabled;
|
||||||
runner_ctx.conv2d_direct_enabled = conv2d_direct_enabled;
|
runner_ctx.conv2d_direct_enabled = conv2d_direct_enabled;
|
||||||
|
runner_ctx.circular_x_enabled = circular_x_enabled;
|
||||||
|
runner_ctx.circular_y_enabled = circular_y_enabled;
|
||||||
runner_ctx.weight_adapter = weight_adapter;
|
runner_ctx.weight_adapter = weight_adapter;
|
||||||
return runner_ctx;
|
return runner_ctx;
|
||||||
}
|
}
|
||||||
@ -2003,6 +2063,11 @@ public:
|
|||||||
conv2d_direct_enabled = enabled;
|
conv2d_direct_enabled = enabled;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void set_circular_axes(bool circular_x, bool circular_y) {
|
||||||
|
circular_x_enabled = circular_x;
|
||||||
|
circular_y_enabled = circular_y;
|
||||||
|
}
|
||||||
|
|
||||||
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {
|
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {
|
||||||
weight_adapter = adapter;
|
weight_adapter = adapter;
|
||||||
}
|
}
|
||||||
@ -2274,6 +2339,8 @@ public:
|
|||||||
forward_params.conv2d.d0 = dilation.second;
|
forward_params.conv2d.d0 = dilation.second;
|
||||||
forward_params.conv2d.d1 = dilation.first;
|
forward_params.conv2d.d1 = dilation.first;
|
||||||
forward_params.conv2d.direct = ctx->conv2d_direct_enabled;
|
forward_params.conv2d.direct = ctx->conv2d_direct_enabled;
|
||||||
|
forward_params.conv2d.circular_x = ctx->circular_x_enabled;
|
||||||
|
forward_params.conv2d.circular_y = ctx->circular_y_enabled;
|
||||||
forward_params.conv2d.scale = scale;
|
forward_params.conv2d.scale = scale;
|
||||||
return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params);
|
return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params);
|
||||||
}
|
}
|
||||||
@ -2288,6 +2355,8 @@ public:
|
|||||||
dilation.second,
|
dilation.second,
|
||||||
dilation.first,
|
dilation.first,
|
||||||
ctx->conv2d_direct_enabled,
|
ctx->conv2d_direct_enabled,
|
||||||
|
ctx->circular_x_enabled,
|
||||||
|
ctx->circular_y_enabled,
|
||||||
scale);
|
scale);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
8
lora.hpp
8
lora.hpp
@ -599,6 +599,8 @@ struct LoraModel : public GGMLRunner {
|
|||||||
forward_params.conv2d.d0,
|
forward_params.conv2d.d0,
|
||||||
forward_params.conv2d.d1,
|
forward_params.conv2d.d1,
|
||||||
forward_params.conv2d.direct,
|
forward_params.conv2d.direct,
|
||||||
|
forward_params.conv2d.circular_x,
|
||||||
|
forward_params.conv2d.circular_y,
|
||||||
forward_params.conv2d.scale);
|
forward_params.conv2d.scale);
|
||||||
if (lora_mid) {
|
if (lora_mid) {
|
||||||
lx = ggml_ext_conv_2d(ctx,
|
lx = ggml_ext_conv_2d(ctx,
|
||||||
@ -612,6 +614,8 @@ struct LoraModel : public GGMLRunner {
|
|||||||
1,
|
1,
|
||||||
1,
|
1,
|
||||||
forward_params.conv2d.direct,
|
forward_params.conv2d.direct,
|
||||||
|
forward_params.conv2d.circular_x,
|
||||||
|
forward_params.conv2d.circular_y,
|
||||||
forward_params.conv2d.scale);
|
forward_params.conv2d.scale);
|
||||||
}
|
}
|
||||||
lx = ggml_ext_conv_2d(ctx,
|
lx = ggml_ext_conv_2d(ctx,
|
||||||
@ -625,6 +629,8 @@ struct LoraModel : public GGMLRunner {
|
|||||||
1,
|
1,
|
||||||
1,
|
1,
|
||||||
forward_params.conv2d.direct,
|
forward_params.conv2d.direct,
|
||||||
|
forward_params.conv2d.circular_x,
|
||||||
|
forward_params.conv2d.circular_y,
|
||||||
forward_params.conv2d.scale);
|
forward_params.conv2d.scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -779,6 +785,8 @@ public:
|
|||||||
forward_params.conv2d.d0,
|
forward_params.conv2d.d0,
|
||||||
forward_params.conv2d.d1,
|
forward_params.conv2d.d1,
|
||||||
forward_params.conv2d.direct,
|
forward_params.conv2d.direct,
|
||||||
|
forward_params.conv2d.circular_x,
|
||||||
|
forward_params.conv2d.circular_y,
|
||||||
forward_params.conv2d.scale);
|
forward_params.conv2d.scale);
|
||||||
}
|
}
|
||||||
for (auto& lora_model : lora_models) {
|
for (auto& lora_model : lora_models) {
|
||||||
|
|||||||
@ -354,14 +354,14 @@ namespace Qwen {
|
|||||||
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, params.patch_size * params.patch_size * params.out_channels));
|
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, params.patch_size * params.patch_size * params.out_channels));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx,
|
struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x) {
|
struct ggml_tensor* x) {
|
||||||
int64_t W = x->ne[0];
|
int64_t W = x->ne[0];
|
||||||
int64_t H = x->ne[1];
|
int64_t H = x->ne[1];
|
||||||
|
|
||||||
int pad_h = (params.patch_size - H % params.patch_size) % params.patch_size;
|
int pad_h = (params.patch_size - H % params.patch_size) % params.patch_size;
|
||||||
int pad_w = (params.patch_size - W % params.patch_size) % params.patch_size;
|
int pad_w = (params.patch_size - W % params.patch_size) % params.patch_size;
|
||||||
x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w]
|
x = ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -387,10 +387,10 @@ namespace Qwen {
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* process_img(struct ggml_context* ctx,
|
struct ggml_tensor* process_img(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x) {
|
struct ggml_tensor* x) {
|
||||||
x = pad_to_patch_size(ctx, x);
|
x = pad_to_patch_size(ctx, x);
|
||||||
x = patchify(ctx, x);
|
x = patchify(ctx->ggml_ctx, x);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -466,12 +466,12 @@ namespace Qwen {
|
|||||||
int64_t C = x->ne[2];
|
int64_t C = x->ne[2];
|
||||||
int64_t N = x->ne[3];
|
int64_t N = x->ne[3];
|
||||||
|
|
||||||
auto img = process_img(ctx->ggml_ctx, x);
|
auto img = process_img(ctx, x);
|
||||||
uint64_t img_tokens = img->ne[1];
|
uint64_t img_tokens = img->ne[1];
|
||||||
|
|
||||||
if (ref_latents.size() > 0) {
|
if (ref_latents.size() > 0) {
|
||||||
for (ggml_tensor* ref : ref_latents) {
|
for (ggml_tensor* ref : ref_latents) {
|
||||||
ref = process_img(ctx->ggml_ctx, ref);
|
ref = process_img(ctx, ref);
|
||||||
img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
|
img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -565,6 +565,8 @@ namespace Qwen {
|
|||||||
ref_latents,
|
ref_latents,
|
||||||
increase_ref_index,
|
increase_ref_index,
|
||||||
qwen_image_params.theta,
|
qwen_image_params.theta,
|
||||||
|
circular_y_enabled,
|
||||||
|
circular_x_enabled,
|
||||||
qwen_image_params.axes_dim);
|
qwen_image_params.axes_dim);
|
||||||
int pos_len = pe_vec.size() / qwen_image_params.axes_dim_sum / 2;
|
int pos_len = pe_vec.size() / qwen_image_params.axes_dim_sum / 2;
|
||||||
// LOG_DEBUG("pos_len %d", pos_len);
|
// LOG_DEBUG("pos_len %d", pos_len);
|
||||||
|
|||||||
158
rope.hpp
158
rope.hpp
@ -1,6 +1,8 @@
|
|||||||
#ifndef __ROPE_HPP__
|
#ifndef __ROPE_HPP__
|
||||||
#define __ROPE_HPP__
|
#define __ROPE_HPP__
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "ggml_extend.hpp"
|
#include "ggml_extend.hpp"
|
||||||
|
|
||||||
@ -39,7 +41,10 @@ namespace Rope {
|
|||||||
return flat_vec;
|
return flat_vec;
|
||||||
}
|
}
|
||||||
|
|
||||||
__STATIC_INLINE__ std::vector<std::vector<float>> rope(const std::vector<float>& pos, int dim, int theta) {
|
__STATIC_INLINE__ std::vector<std::vector<float>> rope(const std::vector<float>& pos,
|
||||||
|
int dim,
|
||||||
|
int theta,
|
||||||
|
const std::vector<int>& axis_wrap_dims = {}) {
|
||||||
assert(dim % 2 == 0);
|
assert(dim % 2 == 0);
|
||||||
int half_dim = dim / 2;
|
int half_dim = dim / 2;
|
||||||
|
|
||||||
@ -47,14 +52,31 @@ namespace Rope {
|
|||||||
|
|
||||||
std::vector<float> omega(half_dim);
|
std::vector<float> omega(half_dim);
|
||||||
for (int i = 0; i < half_dim; ++i) {
|
for (int i = 0; i < half_dim; ++i) {
|
||||||
omega[i] = 1.0 / std::pow(theta, scale[i]);
|
omega[i] = 1.0f / std::pow(theta, scale[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
int pos_size = pos.size();
|
int pos_size = pos.size();
|
||||||
std::vector<std::vector<float>> out(pos_size, std::vector<float>(half_dim));
|
std::vector<std::vector<float>> out(pos_size, std::vector<float>(half_dim));
|
||||||
for (int i = 0; i < pos_size; ++i) {
|
for (int i = 0; i < pos_size; ++i) {
|
||||||
for (int j = 0; j < half_dim; ++j) {
|
for (int j = 0; j < half_dim; ++j) {
|
||||||
out[i][j] = pos[i] * omega[j];
|
float angle = pos[i] * omega[j];
|
||||||
|
if (!axis_wrap_dims.empty()) {
|
||||||
|
size_t wrap_size = axis_wrap_dims.size();
|
||||||
|
// mod batch size since we only store this for one item in the batch
|
||||||
|
size_t wrap_idx = wrap_size > 0 ? (i % wrap_size) : 0;
|
||||||
|
int wrap_dim = axis_wrap_dims[wrap_idx];
|
||||||
|
if (wrap_dim > 0) {
|
||||||
|
constexpr float TWO_PI = 6.28318530717958647692f;
|
||||||
|
float cycles = omega[j] * wrap_dim / TWO_PI;
|
||||||
|
// closest periodic harmonic, necessary to ensure things neatly tile
|
||||||
|
// without this round, things don't tile at the boundaries and you end up
|
||||||
|
// with the model knowing what is "center"
|
||||||
|
float rounded = std::round(cycles);
|
||||||
|
angle = pos[i] * TWO_PI * rounded / wrap_dim;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
out[i][j] = angle;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -146,7 +168,8 @@ namespace Rope {
|
|||||||
__STATIC_INLINE__ std::vector<float> embed_nd(const std::vector<std::vector<float>>& ids,
|
__STATIC_INLINE__ std::vector<float> embed_nd(const std::vector<std::vector<float>>& ids,
|
||||||
int bs,
|
int bs,
|
||||||
int theta,
|
int theta,
|
||||||
const std::vector<int>& axes_dim) {
|
const std::vector<int>& axes_dim,
|
||||||
|
const std::vector<std::vector<int>>& wrap_dims = {}) {
|
||||||
std::vector<std::vector<float>> trans_ids = transpose(ids);
|
std::vector<std::vector<float>> trans_ids = transpose(ids);
|
||||||
size_t pos_len = ids.size() / bs;
|
size_t pos_len = ids.size() / bs;
|
||||||
int num_axes = axes_dim.size();
|
int num_axes = axes_dim.size();
|
||||||
@ -161,7 +184,12 @@ namespace Rope {
|
|||||||
std::vector<std::vector<float>> emb(bs * pos_len, std::vector<float>(emb_dim * 2 * 2, 0.0));
|
std::vector<std::vector<float>> emb(bs * pos_len, std::vector<float>(emb_dim * 2 * 2, 0.0));
|
||||||
int offset = 0;
|
int offset = 0;
|
||||||
for (int i = 0; i < num_axes; ++i) {
|
for (int i = 0; i < num_axes; ++i) {
|
||||||
std::vector<std::vector<float>> rope_emb = rope(trans_ids[i], axes_dim[i], theta); // [bs*pos_len, axes_dim[i]/2 * 2 * 2]
|
std::vector<int> axis_wrap_dims;
|
||||||
|
if (!wrap_dims.empty() && i < (int)wrap_dims.size()) {
|
||||||
|
axis_wrap_dims = wrap_dims[i];
|
||||||
|
}
|
||||||
|
std::vector<std::vector<float>> rope_emb =
|
||||||
|
rope(trans_ids[i], axes_dim[i], theta, axis_wrap_dims); // [bs*pos_len, axes_dim[i]/2 * 2 * 2]
|
||||||
for (int b = 0; b < bs; ++b) {
|
for (int b = 0; b < bs; ++b) {
|
||||||
for (int j = 0; j < pos_len; ++j) {
|
for (int j = 0; j < pos_len; ++j) {
|
||||||
for (int k = 0; k < rope_emb[0].size(); ++k) {
|
for (int k = 0; k < rope_emb[0].size(); ++k) {
|
||||||
@ -251,6 +279,8 @@ namespace Rope {
|
|||||||
bool increase_ref_index,
|
bool increase_ref_index,
|
||||||
float ref_index_scale,
|
float ref_index_scale,
|
||||||
int theta,
|
int theta,
|
||||||
|
bool circular_h,
|
||||||
|
bool circular_w,
|
||||||
const std::vector<int>& axes_dim) {
|
const std::vector<int>& axes_dim) {
|
||||||
std::vector<std::vector<float>> ids = gen_flux_ids(h,
|
std::vector<std::vector<float>> ids = gen_flux_ids(h,
|
||||||
w,
|
w,
|
||||||
@ -262,7 +292,47 @@ namespace Rope {
|
|||||||
ref_latents,
|
ref_latents,
|
||||||
increase_ref_index,
|
increase_ref_index,
|
||||||
ref_index_scale);
|
ref_index_scale);
|
||||||
return embed_nd(ids, bs, theta, axes_dim);
|
std::vector<std::vector<int>> wrap_dims;
|
||||||
|
if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) {
|
||||||
|
int h_len = (h + (patch_size / 2)) / patch_size;
|
||||||
|
int w_len = (w + (patch_size / 2)) / patch_size;
|
||||||
|
if (h_len > 0 && w_len > 0) {
|
||||||
|
size_t pos_len = ids.size() / bs;
|
||||||
|
wrap_dims.assign(axes_dim.size(), std::vector<int>(pos_len, 0));
|
||||||
|
size_t cursor = context_len; // text first
|
||||||
|
const size_t img_tokens = static_cast<size_t>(h_len) * static_cast<size_t>(w_len);
|
||||||
|
for (size_t token_i = 0; token_i < img_tokens; ++token_i) {
|
||||||
|
if (circular_h) {
|
||||||
|
wrap_dims[1][cursor + token_i] = h_len;
|
||||||
|
}
|
||||||
|
if (circular_w) {
|
||||||
|
wrap_dims[2][cursor + token_i] = w_len;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cursor += img_tokens;
|
||||||
|
// reference latents
|
||||||
|
for (ggml_tensor* ref : ref_latents) {
|
||||||
|
if (ref == nullptr) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
int ref_h = static_cast<int>(ref->ne[1]);
|
||||||
|
int ref_w = static_cast<int>(ref->ne[0]);
|
||||||
|
int ref_h_l = (ref_h + (patch_size / 2)) / patch_size;
|
||||||
|
int ref_w_l = (ref_w + (patch_size / 2)) / patch_size;
|
||||||
|
size_t ref_tokens = static_cast<size_t>(ref_h_l) * static_cast<size_t>(ref_w_l);
|
||||||
|
for (size_t token_i = 0; token_i < ref_tokens; ++token_i) {
|
||||||
|
if (circular_h) {
|
||||||
|
wrap_dims[1][cursor + token_i] = ref_h_l;
|
||||||
|
}
|
||||||
|
if (circular_w) {
|
||||||
|
wrap_dims[2][cursor + token_i] = ref_w_l;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cursor += ref_tokens;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return embed_nd(ids, bs, theta, axes_dim, wrap_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
__STATIC_INLINE__ std::vector<std::vector<float>> gen_qwen_image_ids(int h,
|
__STATIC_INLINE__ std::vector<std::vector<float>> gen_qwen_image_ids(int h,
|
||||||
@ -301,9 +371,57 @@ namespace Rope {
|
|||||||
const std::vector<ggml_tensor*>& ref_latents,
|
const std::vector<ggml_tensor*>& ref_latents,
|
||||||
bool increase_ref_index,
|
bool increase_ref_index,
|
||||||
int theta,
|
int theta,
|
||||||
|
bool circular_h,
|
||||||
|
bool circular_w,
|
||||||
const std::vector<int>& axes_dim) {
|
const std::vector<int>& axes_dim) {
|
||||||
std::vector<std::vector<float>> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
|
std::vector<std::vector<float>> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
|
||||||
return embed_nd(ids, bs, theta, axes_dim);
|
std::vector<std::vector<int>> wrap_dims;
|
||||||
|
// This logic simply stores the (pad and patch_adjusted) sizes of images so we can make sure rope correctly tiles
|
||||||
|
if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) {
|
||||||
|
int pad_h = (patch_size - (h % patch_size)) % patch_size;
|
||||||
|
int pad_w = (patch_size - (w % patch_size)) % patch_size;
|
||||||
|
int h_len = (h + pad_h) / patch_size;
|
||||||
|
int w_len = (w + pad_w) / patch_size;
|
||||||
|
if (h_len > 0 && w_len > 0) {
|
||||||
|
const size_t total_tokens = ids.size();
|
||||||
|
// Track per-token wrap lengths for the row/column axes so only spatial tokens become periodic.
|
||||||
|
wrap_dims.assign(axes_dim.size(), std::vector<int>(total_tokens / bs, 0));
|
||||||
|
size_t cursor = context_len; // ignore text tokens
|
||||||
|
const size_t img_tokens = static_cast<size_t>(h_len) * static_cast<size_t>(w_len);
|
||||||
|
for (size_t token_i = 0; token_i < img_tokens; ++token_i) {
|
||||||
|
if (circular_h) {
|
||||||
|
wrap_dims[1][cursor + token_i] = h_len;
|
||||||
|
}
|
||||||
|
if (circular_w) {
|
||||||
|
wrap_dims[2][cursor + token_i] = w_len;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cursor += img_tokens;
|
||||||
|
// For each reference image, store wrap sizes as well
|
||||||
|
for (ggml_tensor* ref : ref_latents) {
|
||||||
|
if (ref == nullptr) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
int ref_h = static_cast<int>(ref->ne[1]);
|
||||||
|
int ref_w = static_cast<int>(ref->ne[0]);
|
||||||
|
int ref_pad_h = (patch_size - (ref_h % patch_size)) % patch_size;
|
||||||
|
int ref_pad_w = (patch_size - (ref_w % patch_size)) % patch_size;
|
||||||
|
int ref_h_len = (ref_h + ref_pad_h) / patch_size;
|
||||||
|
int ref_w_len = (ref_w + ref_pad_w) / patch_size;
|
||||||
|
size_t ref_n_tokens = static_cast<size_t>(ref_h_len) * static_cast<size_t>(ref_w_len);
|
||||||
|
for (size_t token_i = 0; token_i < ref_n_tokens; ++token_i) {
|
||||||
|
if (circular_h) {
|
||||||
|
wrap_dims[1][cursor + token_i] = ref_h_len;
|
||||||
|
}
|
||||||
|
if (circular_w) {
|
||||||
|
wrap_dims[2][cursor + token_i] = ref_w_len;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cursor += ref_n_tokens;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return embed_nd(ids, bs, theta, axes_dim, wrap_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
__STATIC_INLINE__ std::vector<std::vector<float>> gen_vid_ids(int t,
|
__STATIC_INLINE__ std::vector<std::vector<float>> gen_vid_ids(int t,
|
||||||
@ -440,9 +558,33 @@ namespace Rope {
|
|||||||
const std::vector<ggml_tensor*>& ref_latents,
|
const std::vector<ggml_tensor*>& ref_latents,
|
||||||
bool increase_ref_index,
|
bool increase_ref_index,
|
||||||
int theta,
|
int theta,
|
||||||
|
bool circular_h,
|
||||||
|
bool circular_w,
|
||||||
const std::vector<int>& axes_dim) {
|
const std::vector<int>& axes_dim) {
|
||||||
std::vector<std::vector<float>> ids = gen_z_image_ids(h, w, patch_size, bs, context_len, seq_multi_of, ref_latents, increase_ref_index);
|
std::vector<std::vector<float>> ids = gen_z_image_ids(h, w, patch_size, bs, context_len, seq_multi_of, ref_latents, increase_ref_index);
|
||||||
return embed_nd(ids, bs, theta, axes_dim);
|
std::vector<std::vector<int>> wrap_dims;
|
||||||
|
if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) {
|
||||||
|
int pad_h = (patch_size - (h % patch_size)) % patch_size;
|
||||||
|
int pad_w = (patch_size - (w % patch_size)) % patch_size;
|
||||||
|
int h_len = (h + pad_h) / patch_size;
|
||||||
|
int w_len = (w + pad_w) / patch_size;
|
||||||
|
if (h_len > 0 && w_len > 0) {
|
||||||
|
size_t pos_len = ids.size() / bs;
|
||||||
|
wrap_dims.assign(axes_dim.size(), std::vector<int>(pos_len, 0));
|
||||||
|
size_t cursor = context_len + bound_mod(context_len, seq_multi_of); // skip text (and its padding)
|
||||||
|
size_t img_tokens = static_cast<size_t>(h_len) * static_cast<size_t>(w_len);
|
||||||
|
for (size_t token_i = 0; token_i < img_tokens; ++token_i) {
|
||||||
|
if (circular_h) {
|
||||||
|
wrap_dims[1][cursor + token_i] = h_len;
|
||||||
|
}
|
||||||
|
if (circular_w) {
|
||||||
|
wrap_dims[2][cursor + token_i] = w_len;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return embed_nd(ids, bs, theta, axes_dim, wrap_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
__STATIC_INLINE__ struct ggml_tensor* apply_rope(struct ggml_context* ctx,
|
__STATIC_INLINE__ struct ggml_tensor* apply_rope(struct ggml_context* ctx,
|
||||||
|
|||||||
@ -405,6 +405,10 @@ public:
|
|||||||
vae_decode_only = false;
|
vae_decode_only = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (sd_ctx_params->circular_x || sd_ctx_params->circular_y) {
|
||||||
|
LOG_INFO("Using circular padding for convolutions");
|
||||||
|
}
|
||||||
|
|
||||||
bool clip_on_cpu = sd_ctx_params->keep_clip_on_cpu;
|
bool clip_on_cpu = sd_ctx_params->keep_clip_on_cpu;
|
||||||
|
|
||||||
{
|
{
|
||||||
@ -705,6 +709,20 @@ public:
|
|||||||
}
|
}
|
||||||
pmid_model->get_param_tensors(tensors, "pmid");
|
pmid_model->get_param_tensors(tensors, "pmid");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
|
||||||
|
if (high_noise_diffusion_model) {
|
||||||
|
high_noise_diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
|
||||||
|
}
|
||||||
|
if (control_net) {
|
||||||
|
control_net->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
|
||||||
|
}
|
||||||
|
if (first_stage_model) {
|
||||||
|
first_stage_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
|
||||||
|
}
|
||||||
|
if (tae_first_stage) {
|
||||||
|
tae_first_stage->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_init_params params;
|
struct ggml_init_params params;
|
||||||
@ -2559,6 +2577,8 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
|
|||||||
sd_ctx_params->keep_control_net_on_cpu = false;
|
sd_ctx_params->keep_control_net_on_cpu = false;
|
||||||
sd_ctx_params->keep_vae_on_cpu = false;
|
sd_ctx_params->keep_vae_on_cpu = false;
|
||||||
sd_ctx_params->diffusion_flash_attn = false;
|
sd_ctx_params->diffusion_flash_attn = false;
|
||||||
|
sd_ctx_params->circular_x = false;
|
||||||
|
sd_ctx_params->circular_y = false;
|
||||||
sd_ctx_params->chroma_use_dit_mask = true;
|
sd_ctx_params->chroma_use_dit_mask = true;
|
||||||
sd_ctx_params->chroma_use_t5_mask = false;
|
sd_ctx_params->chroma_use_t5_mask = false;
|
||||||
sd_ctx_params->chroma_t5_mask_pad = 1;
|
sd_ctx_params->chroma_t5_mask_pad = 1;
|
||||||
@ -2598,6 +2618,8 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
|
|||||||
"keep_control_net_on_cpu: %s\n"
|
"keep_control_net_on_cpu: %s\n"
|
||||||
"keep_vae_on_cpu: %s\n"
|
"keep_vae_on_cpu: %s\n"
|
||||||
"diffusion_flash_attn: %s\n"
|
"diffusion_flash_attn: %s\n"
|
||||||
|
"circular_x: %s\n"
|
||||||
|
"circular_y: %s\n"
|
||||||
"chroma_use_dit_mask: %s\n"
|
"chroma_use_dit_mask: %s\n"
|
||||||
"chroma_use_t5_mask: %s\n"
|
"chroma_use_t5_mask: %s\n"
|
||||||
"chroma_t5_mask_pad: %d\n",
|
"chroma_t5_mask_pad: %d\n",
|
||||||
@ -2627,6 +2649,8 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
|
|||||||
BOOL_STR(sd_ctx_params->keep_control_net_on_cpu),
|
BOOL_STR(sd_ctx_params->keep_control_net_on_cpu),
|
||||||
BOOL_STR(sd_ctx_params->keep_vae_on_cpu),
|
BOOL_STR(sd_ctx_params->keep_vae_on_cpu),
|
||||||
BOOL_STR(sd_ctx_params->diffusion_flash_attn),
|
BOOL_STR(sd_ctx_params->diffusion_flash_attn),
|
||||||
|
BOOL_STR(sd_ctx_params->circular_x),
|
||||||
|
BOOL_STR(sd_ctx_params->circular_y),
|
||||||
BOOL_STR(sd_ctx_params->chroma_use_dit_mask),
|
BOOL_STR(sd_ctx_params->chroma_use_dit_mask),
|
||||||
BOOL_STR(sd_ctx_params->chroma_use_t5_mask),
|
BOOL_STR(sd_ctx_params->chroma_use_t5_mask),
|
||||||
sd_ctx_params->chroma_t5_mask_pad);
|
sd_ctx_params->chroma_t5_mask_pad);
|
||||||
|
|||||||
@ -189,6 +189,8 @@ typedef struct {
|
|||||||
bool tae_preview_only;
|
bool tae_preview_only;
|
||||||
bool diffusion_conv_direct;
|
bool diffusion_conv_direct;
|
||||||
bool vae_conv_direct;
|
bool vae_conv_direct;
|
||||||
|
bool circular_x;
|
||||||
|
bool circular_y;
|
||||||
bool force_sdxl_vae_conv_scale;
|
bool force_sdxl_vae_conv_scale;
|
||||||
bool chroma_use_dit_mask;
|
bool chroma_use_dit_mask;
|
||||||
bool chroma_use_t5_mask;
|
bool chroma_use_t5_mask;
|
||||||
|
|||||||
15
wan.hpp
15
wan.hpp
@ -75,7 +75,7 @@ namespace WAN {
|
|||||||
lp2 -= (int)cache_x->ne[2];
|
lp2 -= (int)cache_x->ne[2];
|
||||||
}
|
}
|
||||||
|
|
||||||
x = ggml_pad_ext(ctx->ggml_ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, 0, 0);
|
x = ggml_ext_pad_ext(ctx->ggml_ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
|
||||||
return ggml_ext_conv_3d(ctx->ggml_ctx, x, w, b, in_channels,
|
return ggml_ext_conv_3d(ctx->ggml_ctx, x, w, b, in_channels,
|
||||||
std::get<2>(stride), std::get<1>(stride), std::get<0>(stride),
|
std::get<2>(stride), std::get<1>(stride), std::get<0>(stride),
|
||||||
0, 0, 0,
|
0, 0, 0,
|
||||||
@ -206,9 +206,9 @@ namespace WAN {
|
|||||||
} else if (mode == "upsample3d") {
|
} else if (mode == "upsample3d") {
|
||||||
x = ggml_upscale(ctx->ggml_ctx, x, 2, GGML_SCALE_MODE_NEAREST);
|
x = ggml_upscale(ctx->ggml_ctx, x, 2, GGML_SCALE_MODE_NEAREST);
|
||||||
} else if (mode == "downsample2d") {
|
} else if (mode == "downsample2d") {
|
||||||
x = ggml_pad(ctx->ggml_ctx, x, 1, 1, 0, 0);
|
x = ggml_ext_pad(ctx->ggml_ctx, x, 1, 1, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
|
||||||
} else if (mode == "downsample3d") {
|
} else if (mode == "downsample3d") {
|
||||||
x = ggml_pad(ctx->ggml_ctx, x, 1, 1, 0, 0);
|
x = ggml_ext_pad(ctx->ggml_ctx, x, 1, 1, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
|
||||||
}
|
}
|
||||||
x = resample_1->forward(ctx, x);
|
x = resample_1->forward(ctx, x);
|
||||||
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // (c, t, h, w)
|
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // (c, t, h, w)
|
||||||
@ -1826,7 +1826,7 @@ namespace WAN {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx,
|
struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x) {
|
struct ggml_tensor* x) {
|
||||||
int64_t W = x->ne[0];
|
int64_t W = x->ne[0];
|
||||||
int64_t H = x->ne[1];
|
int64_t H = x->ne[1];
|
||||||
@ -1835,8 +1835,7 @@ namespace WAN {
|
|||||||
int pad_t = (std::get<0>(params.patch_size) - T % std::get<0>(params.patch_size)) % std::get<0>(params.patch_size);
|
int pad_t = (std::get<0>(params.patch_size) - T % std::get<0>(params.patch_size)) % std::get<0>(params.patch_size);
|
||||||
int pad_h = (std::get<1>(params.patch_size) - H % std::get<1>(params.patch_size)) % std::get<1>(params.patch_size);
|
int pad_h = (std::get<1>(params.patch_size) - H % std::get<1>(params.patch_size)) % std::get<1>(params.patch_size);
|
||||||
int pad_w = (std::get<2>(params.patch_size) - W % std::get<2>(params.patch_size)) % std::get<2>(params.patch_size);
|
int pad_w = (std::get<2>(params.patch_size) - W % std::get<2>(params.patch_size)) % std::get<2>(params.patch_size);
|
||||||
x = ggml_pad(ctx, x, pad_w, pad_h, pad_t, 0); // [N*C, T + pad_t, H + pad_h, W + pad_w]
|
ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, pad_t, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
|
||||||
|
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1986,14 +1985,14 @@ namespace WAN {
|
|||||||
int64_t T = x->ne[2];
|
int64_t T = x->ne[2];
|
||||||
int64_t C = x->ne[3];
|
int64_t C = x->ne[3];
|
||||||
|
|
||||||
x = pad_to_patch_size(ctx->ggml_ctx, x);
|
x = pad_to_patch_size(ctx, x);
|
||||||
|
|
||||||
int64_t t_len = ((T + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size));
|
int64_t t_len = ((T + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size));
|
||||||
int64_t h_len = ((H + (std::get<1>(params.patch_size) / 2)) / std::get<1>(params.patch_size));
|
int64_t h_len = ((H + (std::get<1>(params.patch_size) / 2)) / std::get<1>(params.patch_size));
|
||||||
int64_t w_len = ((W + (std::get<2>(params.patch_size) / 2)) / std::get<2>(params.patch_size));
|
int64_t w_len = ((W + (std::get<2>(params.patch_size) / 2)) / std::get<2>(params.patch_size));
|
||||||
|
|
||||||
if (time_dim_concat != nullptr) {
|
if (time_dim_concat != nullptr) {
|
||||||
time_dim_concat = pad_to_patch_size(ctx->ggml_ctx, time_dim_concat);
|
time_dim_concat = pad_to_patch_size(ctx, time_dim_concat);
|
||||||
x = ggml_concat(ctx->ggml_ctx, x, time_dim_concat, 2); // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w]
|
x = ggml_concat(ctx->ggml_ctx, x, time_dim_concat, 2); // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w]
|
||||||
t_len = ((x->ne[2] + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size));
|
t_len = ((x->ne[2] + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size));
|
||||||
}
|
}
|
||||||
|
|||||||
14
z_image.hpp
14
z_image.hpp
@ -324,14 +324,14 @@ namespace ZImage {
|
|||||||
blocks["final_layer"] = std::make_shared<FinalLayer>(z_image_params.hidden_size, z_image_params.patch_size, z_image_params.out_channels);
|
blocks["final_layer"] = std::make_shared<FinalLayer>(z_image_params.hidden_size, z_image_params.patch_size, z_image_params.out_channels);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx,
|
struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x) {
|
struct ggml_tensor* x) {
|
||||||
int64_t W = x->ne[0];
|
int64_t W = x->ne[0];
|
||||||
int64_t H = x->ne[1];
|
int64_t H = x->ne[1];
|
||||||
|
|
||||||
int pad_h = (z_image_params.patch_size - H % z_image_params.patch_size) % z_image_params.patch_size;
|
int pad_h = (z_image_params.patch_size - H % z_image_params.patch_size) % z_image_params.patch_size;
|
||||||
int pad_w = (z_image_params.patch_size - W % z_image_params.patch_size) % z_image_params.patch_size;
|
int pad_w = (z_image_params.patch_size - W % z_image_params.patch_size) % z_image_params.patch_size;
|
||||||
x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w]
|
x = ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -357,10 +357,10 @@ namespace ZImage {
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* process_img(struct ggml_context* ctx,
|
struct ggml_tensor* process_img(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x) {
|
struct ggml_tensor* x) {
|
||||||
x = pad_to_patch_size(ctx, x);
|
x = pad_to_patch_size(ctx, x);
|
||||||
x = patchify(ctx, x);
|
x = patchify(ctx->ggml_ctx, x);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -473,12 +473,12 @@ namespace ZImage {
|
|||||||
int64_t C = x->ne[2];
|
int64_t C = x->ne[2];
|
||||||
int64_t N = x->ne[3];
|
int64_t N = x->ne[3];
|
||||||
|
|
||||||
auto img = process_img(ctx->ggml_ctx, x);
|
auto img = process_img(ctx, x);
|
||||||
uint64_t n_img_token = img->ne[1];
|
uint64_t n_img_token = img->ne[1];
|
||||||
|
|
||||||
if (ref_latents.size() > 0) {
|
if (ref_latents.size() > 0) {
|
||||||
for (ggml_tensor* ref : ref_latents) {
|
for (ggml_tensor* ref : ref_latents) {
|
||||||
ref = process_img(ctx->ggml_ctx, ref);
|
ref = process_img(ctx, ref);
|
||||||
img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
|
img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -552,6 +552,8 @@ namespace ZImage {
|
|||||||
ref_latents,
|
ref_latents,
|
||||||
increase_ref_index,
|
increase_ref_index,
|
||||||
z_image_params.theta,
|
z_image_params.theta,
|
||||||
|
circular_y_enabled,
|
||||||
|
circular_x_enabled,
|
||||||
z_image_params.axes_dim);
|
z_image_params.axes_dim);
|
||||||
int pos_len = pe_vec.size() / z_image_params.axes_dim_sum / 2;
|
int pos_len = pe_vec.size() / z_image_params.axes_dim_sum / 2;
|
||||||
// LOG_DEBUG("pos_len %d", pos_len);
|
// LOG_DEBUG("pos_len %d", pos_len);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user