feat: add seamless texture generation support (#914)

* global bool

* reworked circular to global flag

* cleaner implementation of tiling support in sd cpp

* cleaned rope

* working simplified but still need wraps

* Further clean of rope

* resolve flux conflict

* switch to pad op circular only

* Set ggml to most recent

* Revert ggml temp

* Update ggml to most recent

* Revert unneeded flux change

* move circular flag to the GGMLRunnerContext

* Pass through circular param in all places where conv is called

* fix of constant and minor cleanup

* Added back --circular option

* Conv2d circular in vae and various models

* Fix temporal padding for qwen image and other vaes

* Z Image circular tiling

* x and y axis seamless only

* First attempt at chroma seamless x and y

* refactor into pure x and y, almost there

* Fix crash on chroma

* Refactor into cleaner variable choices

* Removed redundant set_circular_enabled

* Sync ggml

* simplify circular parameter

* format code

* no need to perform circular pad on the clip

* simplify circular_axes setting

* unify function naming

* remove unnecessary member variables

* simplify rope

---------

Co-authored-by: Phylliida <phylliidadev@gmail.com>
Co-authored-by: leejet <leejet714@gmail.com>
Phylliida Dev, 2025-12-21 03:06:47 -07:00 (committed by GitHub)
parent 88ec9d30b1
commit 50ff966445
15 changed files with 375 additions and 79 deletions


@@ -28,7 +28,7 @@ public:
         if (vae_downsample) {
             auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
-            x = ggml_pad(ctx->ggml_ctx, x, 1, 1, 0, 0);
+            x = ggml_ext_pad(ctx->ggml_ctx, x, 1, 1, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
             x = conv->forward(ctx, x);
         } else {
             auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["op"]);


@@ -366,18 +366,18 @@ struct KLOptimalScheduler : SigmaScheduler {
         for (uint32_t i = 0; i < n; ++i) {
             // t goes from 0.0 to 1.0
-            float t = static_cast<float>(i) / static_cast<float>(n-1);
+            float t = static_cast<float>(i) / static_cast<float>(n - 1);
             // Interpolate in the angle domain
             float angle = t * alpha_min + (1.0f - t) * alpha_max;
             // Convert back to sigma
             sigmas.push_back(std::tan(angle));
         }
         // Append the final zero to sigma
         sigmas.push_back(0.0f);
         return sigmas;
     }
 };


@@ -37,8 +37,9 @@ struct DiffusionModel {
     virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
     virtual size_t get_params_buffer_size() = 0;
     virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter){};
     virtual int64_t get_adm_in_channels() = 0;
     virtual void set_flash_attn_enabled(bool enabled) = 0;
+    virtual void set_circular_axes(bool circular_x, bool circular_y) = 0;
 };

 struct UNetModel : public DiffusionModel {
@@ -87,6 +88,10 @@ struct UNetModel : public DiffusionModel {
         unet.set_flash_attention_enabled(enabled);
     }

+    void set_circular_axes(bool circular_x, bool circular_y) override {
+        unet.set_circular_axes(circular_x, circular_y);
+    }
+
     bool compute(int n_threads,
                  DiffusionParams diffusion_params,
                  struct ggml_tensor** output = nullptr,
@@ -148,6 +153,10 @@ struct MMDiTModel : public DiffusionModel {
         mmdit.set_flash_attention_enabled(enabled);
     }

+    void set_circular_axes(bool circular_x, bool circular_y) override {
+        mmdit.set_circular_axes(circular_x, circular_y);
+    }
+
     bool compute(int n_threads,
                  DiffusionParams diffusion_params,
                  struct ggml_tensor** output = nullptr,
@@ -210,6 +219,10 @@ struct FluxModel : public DiffusionModel {
         flux.set_flash_attention_enabled(enabled);
     }

+    void set_circular_axes(bool circular_x, bool circular_y) override {
+        flux.set_circular_axes(circular_x, circular_y);
+    }
+
     bool compute(int n_threads,
                  DiffusionParams diffusion_params,
                  struct ggml_tensor** output = nullptr,
@@ -277,6 +290,10 @@ struct WanModel : public DiffusionModel {
         wan.set_flash_attention_enabled(enabled);
     }

+    void set_circular_axes(bool circular_x, bool circular_y) override {
+        wan.set_circular_axes(circular_x, circular_y);
+    }
+
     bool compute(int n_threads,
                  DiffusionParams diffusion_params,
                  struct ggml_tensor** output = nullptr,
@@ -343,6 +360,10 @@ struct QwenImageModel : public DiffusionModel {
         qwen_image.set_flash_attention_enabled(enabled);
     }

+    void set_circular_axes(bool circular_x, bool circular_y) override {
+        qwen_image.set_circular_axes(circular_x, circular_y);
+    }
+
     bool compute(int n_threads,
                  DiffusionParams diffusion_params,
                  struct ggml_tensor** output = nullptr,
@@ -406,6 +427,10 @@ struct ZImageModel : public DiffusionModel {
         z_image.set_flash_attention_enabled(enabled);
     }

+    void set_circular_axes(bool circular_x, bool circular_y) override {
+        z_image.set_circular_axes(circular_x, circular_y);
+    }
+
     bool compute(int n_threads,
                  DiffusionParams diffusion_params,
                  struct ggml_tensor** output = nullptr,


@@ -449,6 +449,10 @@ struct SDContextParams {
     bool diffusion_conv_direct = false;
     bool vae_conv_direct = false;
+    bool circular = false;
+    bool circular_x = false;
+    bool circular_y = false;
     bool chroma_use_dit_mask = true;
     bool chroma_use_t5_mask = false;
     int chroma_t5_mask_pad = 1;
@@ -605,6 +609,18 @@ struct SDContextParams {
         {"",
          "--vae-conv-direct",
          "use ggml_conv2d_direct in the vae model",
          true, &vae_conv_direct},
+        {"",
+         "--circular",
+         "enable circular padding for convolutions",
+         true, &circular},
+        {"",
+         "--circularx",
+         "enable circular RoPE wrapping on x-axis (width) only",
+         true, &circular_x},
+        {"",
+         "--circulary",
+         "enable circular RoPE wrapping on y-axis (height) only",
+         true, &circular_y},
         {"",
          "--chroma-disable-dit-mask",
          "disable dit mask for chroma",
@@ -868,6 +884,9 @@ struct SDContextParams {
            << " diffusion_flash_attn: " << (diffusion_flash_attn ? "true" : "false") << ",\n"
            << " diffusion_conv_direct: " << (diffusion_conv_direct ? "true" : "false") << ",\n"
            << " vae_conv_direct: " << (vae_conv_direct ? "true" : "false") << ",\n"
+           << " circular: " << (circular ? "true" : "false") << ",\n"
+           << " circular_x: " << (circular_x ? "true" : "false") << ",\n"
+           << " circular_y: " << (circular_y ? "true" : "false") << ",\n"
            << " chroma_use_dit_mask: " << (chroma_use_dit_mask ? "true" : "false") << ",\n"
            << " chroma_use_t5_mask: " << (chroma_use_t5_mask ? "true" : "false") << ",\n"
            << " chroma_t5_mask_pad: " << chroma_t5_mask_pad << ",\n"
@@ -928,6 +947,8 @@ struct SDContextParams {
         taesd_preview,
         diffusion_conv_direct,
         vae_conv_direct,
+        circular || circular_x,
+        circular || circular_y,
         force_sdxl_vae_conv_scale,
         chroma_use_dit_mask,
         chroma_use_t5_mask,


@@ -860,14 +860,14 @@ namespace Flux {
             }
         }

-        struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx,
+        struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
                                               struct ggml_tensor* x) {
             int64_t W = x->ne[0];
             int64_t H = x->ne[1];
             int pad_h = (params.patch_size - H % params.patch_size) % params.patch_size;
             int pad_w = (params.patch_size - W % params.patch_size) % params.patch_size;
-            x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0);  // [N, C, H + pad_h, W + pad_w]
+            x = ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
             return x;
         }
@@ -893,11 +893,11 @@
             return x;
         }

-        struct ggml_tensor* process_img(struct ggml_context* ctx,
+        struct ggml_tensor* process_img(GGMLRunnerContext* ctx,
                                         struct ggml_tensor* x) {
             // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
             x = pad_to_patch_size(ctx, x);
-            x = patchify(ctx, x);
+            x = patchify(ctx->ggml_ctx, x);
             return x;
         }
@@ -1076,7 +1076,7 @@
             int pad_h = (patch_size - H % patch_size) % patch_size;
             int pad_w = (patch_size - W % patch_size) % patch_size;
-            auto img = pad_to_patch_size(ctx->ggml_ctx, x);
+            auto img = pad_to_patch_size(ctx, x);
             auto orig_img = img;
             if (params.chroma_radiance_params.use_patch_size_32) {
@@ -1150,7 +1150,7 @@
             int pad_h = (patch_size - H % patch_size) % patch_size;
             int pad_w = (patch_size - W % patch_size) % patch_size;
-            auto img = process_img(ctx->ggml_ctx, x);
+            auto img = process_img(ctx, x);
             uint64_t img_tokens = img->ne[1];
             if (params.version == VERSION_FLUX_FILL) {
@@ -1158,8 +1158,8 @@
                 ggml_tensor* masked = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
                 ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
-                masked = process_img(ctx->ggml_ctx, masked);
-                mask = process_img(ctx->ggml_ctx, mask);
+                masked = process_img(ctx, masked);
+                mask = process_img(ctx, mask);
                 img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, masked, mask, 0), 0);
             } else if (params.version == VERSION_FLEX_2) {
@@ -1168,21 +1168,21 @@
                 ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
                 ggml_tensor* control = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));
-                masked = process_img(ctx->ggml_ctx, masked);
-                mask = process_img(ctx->ggml_ctx, mask);
-                control = process_img(ctx->ggml_ctx, control);
+                masked = process_img(ctx, masked);
+                mask = process_img(ctx, mask);
+                control = process_img(ctx, control);
                 img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, ggml_concat(ctx->ggml_ctx, masked, mask, 0), control, 0), 0);
             } else if (params.version == VERSION_FLUX_CONTROLS) {
                 GGML_ASSERT(c_concat != nullptr);
-                auto control = process_img(ctx->ggml_ctx, c_concat);
+                auto control = process_img(ctx, c_concat);
                 img = ggml_concat(ctx->ggml_ctx, img, control, 0);
             }
             if (ref_latents.size() > 0) {
                 for (ggml_tensor* ref : ref_latents) {
-                    ref = process_img(ctx->ggml_ctx, ref);
+                    ref = process_img(ctx, ref);
                     img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
                 }
             }
@@ -1472,6 +1472,8 @@
                                                 increase_ref_index,
                                                 flux_params.ref_index_scale,
                                                 flux_params.theta,
+                                                circular_y_enabled,
+                                                circular_x_enabled,
                                                 flux_params.axes_dim);
         int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
         // LOG_DEBUG("pos_len %d", pos_len);

ggml
@@ -1 +1 @@
-Subproject commit f5425c0ee5e582a7d64411f06139870bff3e52e0
+Subproject commit 3e9f2ba3b934c20b26873b3c60dbf41b116978ff


@@ -5,6 +5,7 @@
 #include <inttypes.h>
 #include <stdarg.h>
 #include <algorithm>
+#include <atomic>
 #include <cstring>
 #include <fstream>
 #include <functional>
@@ -993,6 +994,48 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx,
     return x;
 }

+__STATIC_INLINE__ struct ggml_tensor* ggml_ext_pad_ext(struct ggml_context* ctx,
+                                                       struct ggml_tensor* x,
+                                                       int lp0,
+                                                       int rp0,
+                                                       int lp1,
+                                                       int rp1,
+                                                       int lp2,
+                                                       int rp2,
+                                                       int lp3,
+                                                       int rp3,
+                                                       bool circular_x = false,
+                                                       bool circular_y = false) {
+    if (circular_x && circular_y) {
+        return ggml_pad_ext_circular(ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
+    }
+    if (circular_x && (lp0 != 0 || rp0 != 0)) {
+        x = ggml_pad_ext_circular(ctx, x, lp0, rp0, 0, 0, 0, 0, 0, 0);
+        lp0 = rp0 = 0;
+    }
+    if (circular_y && (lp1 != 0 || rp1 != 0)) {
+        x = ggml_pad_ext_circular(ctx, x, 0, 0, lp1, rp1, 0, 0, 0, 0);
+        lp1 = rp1 = 0;
+    }
+    if (lp0 != 0 || rp0 != 0 || lp1 != 0 || rp1 != 0 || lp2 != 0 || rp2 != 0 || lp3 != 0 || rp3 != 0) {
+        x = ggml_pad_ext(ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
+    }
+    return x;
+}
+
+__STATIC_INLINE__ struct ggml_tensor* ggml_ext_pad(struct ggml_context* ctx,
+                                                   struct ggml_tensor* x,
+                                                   int p0,
+                                                   int p1,
+                                                   int p2 = 0,
+                                                   int p3 = 0,
+                                                   bool circular_x = false,
+                                                   bool circular_y = false) {
+    return ggml_ext_pad_ext(ctx, x, p0, p0, p1, p1, p2, p2, p3, p3, circular_x, circular_y);
+}
+
 // w: [OC, IC, KH, KW]
 // x: [N, IC, IH, IW]
 // b: [OC,]
@@ -1001,20 +1044,29 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_conv_2d(struct ggml_context* ctx,
                                                        struct ggml_tensor* x,
                                                        struct ggml_tensor* w,
                                                        struct ggml_tensor* b,
                                                        int s0 = 1,
                                                        int s1 = 1,
                                                        int p0 = 0,
                                                        int p1 = 0,
                                                        int d0 = 1,
                                                        int d1 = 1,
                                                        bool direct = false,
-                                                       float scale = 1.f) {
+                                                       bool circular_x = false,
+                                                       bool circular_y = false,
+                                                       float scale = 1.f) {
     if (scale != 1.f) {
         x = ggml_scale(ctx, x, scale);
     }
     if (w->ne[2] != x->ne[2] && ggml_n_dims(w) == 2) {
         w = ggml_reshape_4d(ctx, w, 1, 1, w->ne[0], w->ne[1]);
     }
+    if ((p0 != 0 || p1 != 0) && (circular_x || circular_y)) {
+        x = ggml_ext_pad(ctx, x, p0, p1, 0, 0, circular_x, circular_y);
+        p0 = 0;
+        p1 = 0;
+    }
     if (direct) {
         x = ggml_conv_2d_direct(ctx, w, x, s0, s1, p0, p1, d0, d1);
     } else {
@@ -1521,14 +1573,16 @@ struct WeightAdapter {
         float scale = 1.f;
     } linear;
     struct {
         int s0 = 1;
         int s1 = 1;
         int p0 = 0;
         int p1 = 0;
         int d0 = 1;
         int d1 = 1;
         bool direct = false;
-        float scale = 1.f;
+        bool circular_x = false;
+        bool circular_y = false;
+        float scale = 1.f;
     } conv2d;
 };
 virtual ggml_tensor* patch_weight(ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name) = 0;
@@ -1546,6 +1600,8 @@ struct GGMLRunnerContext {
     ggml_context* ggml_ctx = nullptr;
     bool flash_attn_enabled = false;
     bool conv2d_direct_enabled = false;
+    bool circular_x_enabled = false;
+    bool circular_y_enabled = false;
     std::shared_ptr<WeightAdapter> weight_adapter = nullptr;
 };
@@ -1582,6 +1638,8 @@ protected:
     bool flash_attn_enabled = false;
     bool conv2d_direct_enabled = false;
+    bool circular_x_enabled = false;
+    bool circular_y_enabled = false;

     void alloc_params_ctx() {
         struct ggml_init_params params;
@@ -1859,6 +1917,8 @@ public:
         runner_ctx.backend = runtime_backend;
         runner_ctx.flash_attn_enabled = flash_attn_enabled;
         runner_ctx.conv2d_direct_enabled = conv2d_direct_enabled;
+        runner_ctx.circular_x_enabled = circular_x_enabled;
+        runner_ctx.circular_y_enabled = circular_y_enabled;
         runner_ctx.weight_adapter = weight_adapter;
         return runner_ctx;
     }
@@ -2003,6 +2063,11 @@ public:
         conv2d_direct_enabled = enabled;
     }

+    void set_circular_axes(bool circular_x, bool circular_y) {
+        circular_x_enabled = circular_x;
+        circular_y_enabled = circular_y;
+    }
+
     void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {
         weight_adapter = adapter;
     }
@@ -2266,15 +2331,17 @@ public:
         }
         if (ctx->weight_adapter) {
             WeightAdapter::ForwardParams forward_params;
             forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_CONV2D;
             forward_params.conv2d.s0 = stride.second;
             forward_params.conv2d.s1 = stride.first;
             forward_params.conv2d.p0 = padding.second;
             forward_params.conv2d.p1 = padding.first;
             forward_params.conv2d.d0 = dilation.second;
             forward_params.conv2d.d1 = dilation.first;
             forward_params.conv2d.direct = ctx->conv2d_direct_enabled;
-            forward_params.conv2d.scale = scale;
+            forward_params.conv2d.circular_x = ctx->circular_x_enabled;
+            forward_params.conv2d.circular_y = ctx->circular_y_enabled;
+            forward_params.conv2d.scale = scale;
             return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params);
         }
         return ggml_ext_conv_2d(ctx->ggml_ctx,
@@ -2288,6 +2355,8 @@ public:
                                 dilation.second,
                                 dilation.first,
                                 ctx->conv2d_direct_enabled,
+                                ctx->circular_x_enabled,
+                                ctx->circular_y_enabled,
                                 scale);
     }
 };
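For readers unfamiliar with circular (wrap-around) padding, the standalone sketch below illustrates the semantics that the new `ggml_pad_ext_circular`/`ggml_ext_pad` path is meant to provide along one axis; it is not part of this commit and uses a plain std::vector instead of ggml tensors.

```cpp
#include <cstdio>
#include <vector>

// Circular padding of a 1-D sequence: out-of-range indices wrap around instead of
// being zero-filled. With lp = rp = 1, {1, 2, 3, 4} becomes {4, 1, 2, 3, 4, 1}, so a
// convolution sliding over the padded data sees the signal as if it tiled seamlessly.
std::vector<float> pad_circular_1d(const std::vector<float>& x, int lp, int rp) {
    const int n = static_cast<int>(x.size());
    std::vector<float> out;
    out.reserve(n + lp + rp);
    for (int i = -lp; i < n + rp; ++i) {
        out.push_back(x[((i % n) + n) % n]);  // wrap negative and overflowing indices
    }
    return out;
}

int main() {
    std::vector<float> x = {1.f, 2.f, 3.f, 4.f};
    for (float v : pad_circular_1d(x, 1, 1)) {
        std::printf("%.0f ", v);  // prints: 4 1 2 3 4 1
    }
    std::printf("\n");
    return 0;
}
```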


@@ -599,6 +599,8 @@ struct LoraModel : public GGMLRunner {
                                forward_params.conv2d.d0,
                                forward_params.conv2d.d1,
                                forward_params.conv2d.direct,
+                               forward_params.conv2d.circular_x,
+                               forward_params.conv2d.circular_y,
                                forward_params.conv2d.scale);
         if (lora_mid) {
             lx = ggml_ext_conv_2d(ctx,
@@ -612,6 +614,8 @@
                                   1,
                                   1,
                                   forward_params.conv2d.direct,
+                                  forward_params.conv2d.circular_x,
+                                  forward_params.conv2d.circular_y,
                                   forward_params.conv2d.scale);
         }
         lx = ggml_ext_conv_2d(ctx,
@@ -625,6 +629,8 @@
                               1,
                               1,
                               forward_params.conv2d.direct,
+                              forward_params.conv2d.circular_x,
+                              forward_params.conv2d.circular_y,
                               forward_params.conv2d.scale);
     }
@@ -779,6 +785,8 @@ public:
                              forward_params.conv2d.d0,
                              forward_params.conv2d.d1,
                              forward_params.conv2d.direct,
+                             forward_params.conv2d.circular_x,
+                             forward_params.conv2d.circular_y,
                              forward_params.conv2d.scale);
     }
     for (auto& lora_model : lora_models) {


@@ -983,4 +983,4 @@ struct MMDiTRunner : public GGMLRunner {
     }
 };
 #endif


@@ -354,14 +354,14 @@ namespace Qwen {
             blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, params.patch_size * params.patch_size * params.out_channels));
         }

-        struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx,
+        struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
                                               struct ggml_tensor* x) {
             int64_t W = x->ne[0];
             int64_t H = x->ne[1];
             int pad_h = (params.patch_size - H % params.patch_size) % params.patch_size;
             int pad_w = (params.patch_size - W % params.patch_size) % params.patch_size;
-            x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0);  // [N, C, H + pad_h, W + pad_w]
+            x = ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
             return x;
         }
@@ -387,10 +387,10 @@
             return x;
         }

-        struct ggml_tensor* process_img(struct ggml_context* ctx,
+        struct ggml_tensor* process_img(GGMLRunnerContext* ctx,
                                         struct ggml_tensor* x) {
             x = pad_to_patch_size(ctx, x);
-            x = patchify(ctx, x);
+            x = patchify(ctx->ggml_ctx, x);
             return x;
         }
@@ -466,12 +466,12 @@
             int64_t C = x->ne[2];
             int64_t N = x->ne[3];

-            auto img = process_img(ctx->ggml_ctx, x);
+            auto img = process_img(ctx, x);
             uint64_t img_tokens = img->ne[1];
             if (ref_latents.size() > 0) {
                 for (ggml_tensor* ref : ref_latents) {
-                    ref = process_img(ctx->ggml_ctx, ref);
+                    ref = process_img(ctx, ref);
                     img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
                 }
             }
@@ -565,6 +565,8 @@ namespace Qwen {
                                                       ref_latents,
                                                       increase_ref_index,
                                                       qwen_image_params.theta,
+                                                      circular_y_enabled,
+                                                      circular_x_enabled,
                                                       qwen_image_params.axes_dim);
         int pos_len = pe_vec.size() / qwen_image_params.axes_dim_sum / 2;
         // LOG_DEBUG("pos_len %d", pos_len);
@@ -684,4 +686,4 @@ namespace Qwen {
 }  // namespace name
 #endif  // __QWEN_IMAGE_HPP__

rope.hpp

@@ -1,6 +1,8 @@
 #ifndef __ROPE_HPP__
 #define __ROPE_HPP__
+#include <algorithm>
+#include <cmath>
 #include <vector>
 #include "ggml_extend.hpp"
@@ -39,7 +41,10 @@ namespace Rope {
     return flat_vec;
 }

-__STATIC_INLINE__ std::vector<std::vector<float>> rope(const std::vector<float>& pos, int dim, int theta) {
+__STATIC_INLINE__ std::vector<std::vector<float>> rope(const std::vector<float>& pos,
+                                                       int dim,
+                                                       int theta,
+                                                       const std::vector<int>& axis_wrap_dims = {}) {
     assert(dim % 2 == 0);
     int half_dim = dim / 2;
@@ -47,14 +52,31 @@
     std::vector<float> omega(half_dim);
     for (int i = 0; i < half_dim; ++i) {
-        omega[i] = 1.0 / std::pow(theta, scale[i]);
+        omega[i] = 1.0f / std::pow(theta, scale[i]);
     }

     int pos_size = pos.size();
     std::vector<std::vector<float>> out(pos_size, std::vector<float>(half_dim));
     for (int i = 0; i < pos_size; ++i) {
         for (int j = 0; j < half_dim; ++j) {
-            out[i][j] = pos[i] * omega[j];
+            float angle = pos[i] * omega[j];
+            if (!axis_wrap_dims.empty()) {
+                size_t wrap_size = axis_wrap_dims.size();
+                // mod batch size since we only store this for one item in the batch
+                size_t wrap_idx = wrap_size > 0 ? (i % wrap_size) : 0;
+                int wrap_dim = axis_wrap_dims[wrap_idx];
+                if (wrap_dim > 0) {
+                    constexpr float TWO_PI = 6.28318530717958647692f;
+                    float cycles = omega[j] * wrap_dim / TWO_PI;
+                    // closest periodic harmonic, necessary to ensure things neatly tile
+                    // without this round, things don't tile at the boundaries and you end up
+                    // with the model knowing what is "center"
+                    float rounded = std::round(cycles);
+                    angle = pos[i] * TWO_PI * rounded / wrap_dim;
+                }
+            }
+            out[i][j] = angle;
         }
     }
@@ -89,9 +111,9 @@
                                     int patch_size,
                                     int bs,
                                     int axes_dim_num,
                                     int index = 0,
                                     int h_offset = 0,
                                     int w_offset = 0,
                                     bool scale_rope = false) {
     int h_len = (h + (patch_size / 2)) / patch_size;
     int w_len = (w + (patch_size / 2)) / patch_size;
@@ -146,7 +168,8 @@
 __STATIC_INLINE__ std::vector<float> embed_nd(const std::vector<std::vector<float>>& ids,
                                               int bs,
                                               int theta,
-                                              const std::vector<int>& axes_dim) {
+                                              const std::vector<int>& axes_dim,
+                                              const std::vector<std::vector<int>>& wrap_dims = {}) {
     std::vector<std::vector<float>> trans_ids = transpose(ids);
     size_t pos_len = ids.size() / bs;
     int num_axes = axes_dim.size();
@@ -161,7 +184,12 @@
     std::vector<std::vector<float>> emb(bs * pos_len, std::vector<float>(emb_dim * 2 * 2, 0.0));
     int offset = 0;
     for (int i = 0; i < num_axes; ++i) {
-        std::vector<std::vector<float>> rope_emb = rope(trans_ids[i], axes_dim[i], theta);  // [bs*pos_len, axes_dim[i]/2 * 2 * 2]
+        std::vector<int> axis_wrap_dims;
+        if (!wrap_dims.empty() && i < (int)wrap_dims.size()) {
+            axis_wrap_dims = wrap_dims[i];
+        }
+        std::vector<std::vector<float>> rope_emb =
+            rope(trans_ids[i], axes_dim[i], theta, axis_wrap_dims);  // [bs*pos_len, axes_dim[i]/2 * 2 * 2]
         for (int b = 0; b < bs; ++b) {
             for (int j = 0; j < pos_len; ++j) {
                 for (int k = 0; k < rope_emb[0].size(); ++k) {
@@ -251,6 +279,8 @@
                                                            bool increase_ref_index,
                                                            float ref_index_scale,
                                                            int theta,
+                                                           bool circular_h,
+                                                           bool circular_w,
                                                            const std::vector<int>& axes_dim) {
     std::vector<std::vector<float>> ids = gen_flux_ids(h,
                                                        w,
@@ -262,7 +292,47 @@
                                                        ref_latents,
                                                        increase_ref_index,
                                                        ref_index_scale);
-    return embed_nd(ids, bs, theta, axes_dim);
+    std::vector<std::vector<int>> wrap_dims;
+    if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) {
+        int h_len = (h + (patch_size / 2)) / patch_size;
+        int w_len = (w + (patch_size / 2)) / patch_size;
+        if (h_len > 0 && w_len > 0) {
+            size_t pos_len = ids.size() / bs;
+            wrap_dims.assign(axes_dim.size(), std::vector<int>(pos_len, 0));
+            size_t cursor = context_len;  // text first
+            const size_t img_tokens = static_cast<size_t>(h_len) * static_cast<size_t>(w_len);
+            for (size_t token_i = 0; token_i < img_tokens; ++token_i) {
+                if (circular_h) {
+                    wrap_dims[1][cursor + token_i] = h_len;
+                }
+                if (circular_w) {
+                    wrap_dims[2][cursor + token_i] = w_len;
+                }
+            }
+            cursor += img_tokens;
+            // reference latents
+            for (ggml_tensor* ref : ref_latents) {
+                if (ref == nullptr) {
+                    continue;
+                }
+                int ref_h = static_cast<int>(ref->ne[1]);
+                int ref_w = static_cast<int>(ref->ne[0]);
+                int ref_h_l = (ref_h + (patch_size / 2)) / patch_size;
+                int ref_w_l = (ref_w + (patch_size / 2)) / patch_size;
+                size_t ref_tokens = static_cast<size_t>(ref_h_l) * static_cast<size_t>(ref_w_l);
+                for (size_t token_i = 0; token_i < ref_tokens; ++token_i) {
+                    if (circular_h) {
+                        wrap_dims[1][cursor + token_i] = ref_h_l;
+                    }
+                    if (circular_w) {
+                        wrap_dims[2][cursor + token_i] = ref_w_l;
+                    }
+                }
+                cursor += ref_tokens;
+            }
+        }
+    }
+    return embed_nd(ids, bs, theta, axes_dim, wrap_dims);
 }

 __STATIC_INLINE__ std::vector<std::vector<float>> gen_qwen_image_ids(int h,
@@ -301,9 +371,57 @@
                                                                  const std::vector<ggml_tensor*>& ref_latents,
                                                                  bool increase_ref_index,
                                                                  int theta,
+                                                                 bool circular_h,
+                                                                 bool circular_w,
                                                                  const std::vector<int>& axes_dim) {
     std::vector<std::vector<float>> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
-    return embed_nd(ids, bs, theta, axes_dim);
+    std::vector<std::vector<int>> wrap_dims;
+    // This logic simply stores the (pad and patch_adjusted) sizes of images so we can make sure rope correctly tiles
+    if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) {
+        int pad_h = (patch_size - (h % patch_size)) % patch_size;
+        int pad_w = (patch_size - (w % patch_size)) % patch_size;
+        int h_len = (h + pad_h) / patch_size;
+        int w_len = (w + pad_w) / patch_size;
+        if (h_len > 0 && w_len > 0) {
+            const size_t total_tokens = ids.size();
+            // Track per-token wrap lengths for the row/column axes so only spatial tokens become periodic.
+            wrap_dims.assign(axes_dim.size(), std::vector<int>(total_tokens / bs, 0));
+            size_t cursor = context_len;  // ignore text tokens
+            const size_t img_tokens = static_cast<size_t>(h_len) * static_cast<size_t>(w_len);
+            for (size_t token_i = 0; token_i < img_tokens; ++token_i) {
+                if (circular_h) {
+                    wrap_dims[1][cursor + token_i] = h_len;
+                }
+                if (circular_w) {
+                    wrap_dims[2][cursor + token_i] = w_len;
+                }
+            }
+            cursor += img_tokens;
+            // For each reference image, store wrap sizes as well
+            for (ggml_tensor* ref : ref_latents) {
+                if (ref == nullptr) {
+                    continue;
+                }
+                int ref_h = static_cast<int>(ref->ne[1]);
+                int ref_w = static_cast<int>(ref->ne[0]);
+                int ref_pad_h = (patch_size - (ref_h % patch_size)) % patch_size;
+                int ref_pad_w = (patch_size - (ref_w % patch_size)) % patch_size;
+                int ref_h_len = (ref_h + ref_pad_h) / patch_size;
+                int ref_w_len = (ref_w + ref_pad_w) / patch_size;
+                size_t ref_n_tokens = static_cast<size_t>(ref_h_len) * static_cast<size_t>(ref_w_len);
+                for (size_t token_i = 0; token_i < ref_n_tokens; ++token_i) {
+                    if (circular_h) {
+                        wrap_dims[1][cursor + token_i] = ref_h_len;
+                    }
+                    if (circular_w) {
+                        wrap_dims[2][cursor + token_i] = ref_w_len;
+                    }
+                }
+                cursor += ref_n_tokens;
+            }
+        }
+    }
+    return embed_nd(ids, bs, theta, axes_dim, wrap_dims);
 }

 __STATIC_INLINE__ std::vector<std::vector<float>> gen_vid_ids(int t,
@@ -440,9 +558,33 @@
                                                               const std::vector<ggml_tensor*>& ref_latents,
                                                               bool increase_ref_index,
                                                               int theta,
+                                                              bool circular_h,
+                                                              bool circular_w,
                                                               const std::vector<int>& axes_dim) {
     std::vector<std::vector<float>> ids = gen_z_image_ids(h, w, patch_size, bs, context_len, seq_multi_of, ref_latents, increase_ref_index);
-    return embed_nd(ids, bs, theta, axes_dim);
+    std::vector<std::vector<int>> wrap_dims;
+    if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) {
+        int pad_h = (patch_size - (h % patch_size)) % patch_size;
+        int pad_w = (patch_size - (w % patch_size)) % patch_size;
+        int h_len = (h + pad_h) / patch_size;
+        int w_len = (w + pad_w) / patch_size;
+        if (h_len > 0 && w_len > 0) {
+            size_t pos_len = ids.size() / bs;
+            wrap_dims.assign(axes_dim.size(), std::vector<int>(pos_len, 0));
+            size_t cursor = context_len + bound_mod(context_len, seq_multi_of);  // skip text (and its padding)
+            size_t img_tokens = static_cast<size_t>(h_len) * static_cast<size_t>(w_len);
+            for (size_t token_i = 0; token_i < img_tokens; ++token_i) {
+                if (circular_h) {
+                    wrap_dims[1][cursor + token_i] = h_len;
+                }
+                if (circular_w) {
+                    wrap_dims[2][cursor + token_i] = w_len;
+                }
+            }
+        }
+    }
+    return embed_nd(ids, bs, theta, axes_dim, wrap_dims);
 }

 __STATIC_INLINE__ struct ggml_tensor* apply_rope(struct ggml_context* ctx,
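To see why the `rounded` harmonic in `rope()` matters, here is a standalone sketch (not part of the diff, with hypothetical demo parameters `dim`, `theta`, `wrap_dim`): each RoPE frequency is snapped to the nearest whole number of cycles over the wrapped axis, so position 0 and position wrap_dim land on the same angle and the positional embedding tiles seamlessly instead of marking a "center".

```cpp
#include <cmath>
#include <cstdio>

// Frequencies are snapped to the nearest whole number of cycles over `wrap_dim`,
// making angle(pos) periodic with period wrap_dim: angle(0) == angle(wrap_dim) (mod 2*pi).
int main() {
    const float TWO_PI = 6.28318530717958647692f;
    const int dim = 8;        // per-axis RoPE dimension (assumption for this demo)
    const int theta = 10000;  // RoPE base, as in rope()
    const int wrap_dim = 32;  // number of patch positions along the wrapped axis
    for (int j = 0; j < dim / 2; ++j) {
        float omega = 1.0f / std::pow((float)theta, (float)(2 * j) / (float)dim);
        float cycles = omega * wrap_dim / TWO_PI;  // cycles completed over one wrap
        float snapped = std::round(cycles);        // closest periodic harmonic
        float omega_wrapped = TWO_PI * snapped / wrap_dim;
        // With omega_wrapped, pos = 0 and pos = wrap_dim give identical angles -> seamless tiling.
        std::printf("j=%d  omega=%.6f -> wrapped=%.6f  (%.2f cycles -> %.0f)\n",
                    j, omega, omega_wrapped, cycles, snapped);
    }
    return 0;
}
```

Note that low frequencies whose snapped cycle count is zero simply become constant along that axis, which is the same behaviour the wrap-aware `rope()` above produces.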


@@ -405,6 +405,10 @@ public:
             vae_decode_only = false;
         }

+        if (sd_ctx_params->circular_x || sd_ctx_params->circular_y) {
+            LOG_INFO("Using circular padding for convolutions");
+        }
+
         bool clip_on_cpu = sd_ctx_params->keep_clip_on_cpu;
         {
@@ -705,6 +709,20 @@ public:
                 }
                 pmid_model->get_param_tensors(tensors, "pmid");
             }
+            diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
+            if (high_noise_diffusion_model) {
+                high_noise_diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
+            }
+            if (control_net) {
+                control_net->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
+            }
+            if (first_stage_model) {
+                first_stage_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
+            }
+            if (tae_first_stage) {
+                tae_first_stage->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
+            }
         }

         struct ggml_init_params params;
@@ -1519,7 +1537,7 @@
         }
         std::vector<int> skip_layers(guidance.slg.layers, guidance.slg.layers + guidance.slg.layer_count);
         float cfg_scale = guidance.txt_cfg;
         if (cfg_scale < 1.f) {
             if (cfg_scale == 0.f) {
                 // Diffusers follow the convention from the original paper
@@ -2559,6 +2577,8 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
     sd_ctx_params->keep_control_net_on_cpu = false;
     sd_ctx_params->keep_vae_on_cpu = false;
     sd_ctx_params->diffusion_flash_attn = false;
+    sd_ctx_params->circular_x = false;
+    sd_ctx_params->circular_y = false;
     sd_ctx_params->chroma_use_dit_mask = true;
     sd_ctx_params->chroma_use_t5_mask = false;
     sd_ctx_params->chroma_t5_mask_pad = 1;
@@ -2598,6 +2618,8 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
                    "keep_control_net_on_cpu: %s\n"
                    "keep_vae_on_cpu: %s\n"
                    "diffusion_flash_attn: %s\n"
+                   "circular_x: %s\n"
+                   "circular_y: %s\n"
                    "chroma_use_dit_mask: %s\n"
                    "chroma_use_t5_mask: %s\n"
                    "chroma_t5_mask_pad: %d\n",
@@ -2627,6 +2649,8 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
                    BOOL_STR(sd_ctx_params->keep_control_net_on_cpu),
                    BOOL_STR(sd_ctx_params->keep_vae_on_cpu),
                    BOOL_STR(sd_ctx_params->diffusion_flash_attn),
+                   BOOL_STR(sd_ctx_params->circular_x),
+                   BOOL_STR(sd_ctx_params->circular_y),
                    BOOL_STR(sd_ctx_params->chroma_use_dit_mask),
                    BOOL_STR(sd_ctx_params->chroma_use_t5_mask),
                    sd_ctx_params->chroma_t5_mask_pad);


@@ -189,6 +189,8 @@ typedef struct {
     bool tae_preview_only;
     bool diffusion_conv_direct;
     bool vae_conv_direct;
+    bool circular_x;
+    bool circular_y;
     bool force_sdxl_vae_conv_scale;
     bool chroma_use_dit_mask;
     bool chroma_use_t5_mask;
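For orientation only, a minimal sketch of driving the new fields through the public C API (this is not part of the diff; it assumes the usual `new_sd_ctx()`/`free_sd_ctx()` entry points and the `model_path` field are unchanged). On the CLI the same switches are exposed as `--circular`, `--circularx` and `--circulary`.

```cpp
#include "stable-diffusion.h"

// Sketch: request seamless (tileable) output by enabling circular wrapping on both axes.
int main() {
    sd_ctx_params_t params;
    sd_ctx_params_init(&params);
    params.model_path = "model.safetensors";  // hypothetical model file
    params.circular_x = true;                 // wrap convolutions/RoPE along the width axis
    params.circular_y = true;                 // wrap along the height axis too (CLI: --circular)
    sd_ctx_t* ctx = new_sd_ctx(&params);
    if (ctx == nullptr) {
        return 1;
    }
    // ... run txt2img/img2img as usual; the generated images should tile seamlessly ...
    free_sd_ctx(ctx);
    return 0;
}
```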

wan.hpp

@@ -75,7 +75,7 @@ namespace WAN {
                 lp2 -= (int)cache_x->ne[2];
             }
-            x = ggml_pad_ext(ctx->ggml_ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, 0, 0);
+            x = ggml_ext_pad_ext(ctx->ggml_ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
             return ggml_ext_conv_3d(ctx->ggml_ctx, x, w, b, in_channels,
                                     std::get<2>(stride), std::get<1>(stride), std::get<0>(stride),
                                     0, 0, 0,
@@ -206,9 +206,9 @@
             } else if (mode == "upsample3d") {
                 x = ggml_upscale(ctx->ggml_ctx, x, 2, GGML_SCALE_MODE_NEAREST);
             } else if (mode == "downsample2d") {
-                x = ggml_pad(ctx->ggml_ctx, x, 1, 1, 0, 0);
+                x = ggml_ext_pad(ctx->ggml_ctx, x, 1, 1, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
             } else if (mode == "downsample3d") {
-                x = ggml_pad(ctx->ggml_ctx, x, 1, 1, 0, 0);
+                x = ggml_ext_pad(ctx->ggml_ctx, x, 1, 1, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
             }
             x = resample_1->forward(ctx, x);
             x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2));  // (c, t, h, w)
@@ -1826,7 +1826,7 @@
             }
         }

-        struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx,
+        struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
                                               struct ggml_tensor* x) {
             int64_t W = x->ne[0];
             int64_t H = x->ne[1];
@@ -1835,8 +1835,7 @@
             int pad_t = (std::get<0>(params.patch_size) - T % std::get<0>(params.patch_size)) % std::get<0>(params.patch_size);
             int pad_h = (std::get<1>(params.patch_size) - H % std::get<1>(params.patch_size)) % std::get<1>(params.patch_size);
             int pad_w = (std::get<2>(params.patch_size) - W % std::get<2>(params.patch_size)) % std::get<2>(params.patch_size);
-            x = ggml_pad(ctx, x, pad_w, pad_h, pad_t, 0);  // [N*C, T + pad_t, H + pad_h, W + pad_w]
+            ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, pad_t, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
             return x;
         }
@@ -1986,14 +1985,14 @@
             int64_t T = x->ne[2];
             int64_t C = x->ne[3];

-            x = pad_to_patch_size(ctx->ggml_ctx, x);
+            x = pad_to_patch_size(ctx, x);
             int64_t t_len = ((T + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size));
             int64_t h_len = ((H + (std::get<1>(params.patch_size) / 2)) / std::get<1>(params.patch_size));
             int64_t w_len = ((W + (std::get<2>(params.patch_size) / 2)) / std::get<2>(params.patch_size));
             if (time_dim_concat != nullptr) {
-                time_dim_concat = pad_to_patch_size(ctx->ggml_ctx, time_dim_concat);
+                time_dim_concat = pad_to_patch_size(ctx, time_dim_concat);
                 x = ggml_concat(ctx->ggml_ctx, x, time_dim_concat, 2);  // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w]
                 t_len = ((x->ne[2] + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size));
             }


@@ -324,14 +324,14 @@ namespace ZImage {
             blocks["final_layer"] = std::make_shared<FinalLayer>(z_image_params.hidden_size, z_image_params.patch_size, z_image_params.out_channels);
         }

-        struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx,
+        struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
                                               struct ggml_tensor* x) {
             int64_t W = x->ne[0];
             int64_t H = x->ne[1];
             int pad_h = (z_image_params.patch_size - H % z_image_params.patch_size) % z_image_params.patch_size;
             int pad_w = (z_image_params.patch_size - W % z_image_params.patch_size) % z_image_params.patch_size;
-            x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0);  // [N, C, H + pad_h, W + pad_w]
+            x = ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
             return x;
         }
@@ -357,10 +357,10 @@
             return x;
         }

-        struct ggml_tensor* process_img(struct ggml_context* ctx,
+        struct ggml_tensor* process_img(GGMLRunnerContext* ctx,
                                         struct ggml_tensor* x) {
             x = pad_to_patch_size(ctx, x);
-            x = patchify(ctx, x);
+            x = patchify(ctx->ggml_ctx, x);
             return x;
         }
@@ -473,12 +473,12 @@
             int64_t C = x->ne[2];
             int64_t N = x->ne[3];

-            auto img = process_img(ctx->ggml_ctx, x);
+            auto img = process_img(ctx, x);
             uint64_t n_img_token = img->ne[1];
             if (ref_latents.size() > 0) {
                 for (ggml_tensor* ref : ref_latents) {
-                    ref = process_img(ctx->ggml_ctx, ref);
+                    ref = process_img(ctx, ref);
                     img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
                 }
             }
@@ -552,6 +552,8 @@
                                                ref_latents,
                                                increase_ref_index,
                                                z_image_params.theta,
+                                               circular_y_enabled,
+                                               circular_x_enabled,
                                                z_image_params.axes_dim);
         int pos_len = pe_vec.size() / z_image_params.axes_dim_sum / 2;
         // LOG_DEBUG("pos_len %d", pos_len);