feat: add seamless texture generation support (#914)

* global bool

* reworked circular to global flag

* cleaner implementation of tiling support in sd cpp

* cleaned rope

* working simplified but still need wraps

* Further clean of rope

* resolve flux conflict

* switch to pad op circular only

* Set ggml to most recent

* Revert ggml temp

* Update ggml to most recent

* Revert unneded flux change

* move circular flag to the GGMLRunnerContext

* Pass through circular param in all places where conv is called

* fix of constant and minor cleanup

* Added back --circular option

* Conv2d circular in vae and various models

* Fix temporal padding for qwen image and other vaes

* Z Image circular tiling

* x and y axis seamless only

* First attempt at chroma seamless x and y

* refactor into pure x and y, almost there

* Fix crash on chroma

* Refactor into cleaner variable choices

* Removed redundant set_circular_enabled

* Sync ggml

* simplify circular parameter

* format code

* no need to perform circular pad on the clip

* simplify circular_axes setting

* unify function naming

* remove unnecessary member variables

* simplify rope

---------

Co-authored-by: Phylliida <phylliidadev@gmail.com>
Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
Phylliida Dev 2025-12-21 03:06:47 -07:00 committed by GitHub
parent 88ec9d30b1
commit 50ff966445
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 375 additions and 79 deletions

View File

@ -28,7 +28,7 @@ public:
if (vae_downsample) {
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
x = ggml_pad(ctx->ggml_ctx, x, 1, 1, 0, 0);
x = ggml_ext_pad(ctx->ggml_ctx, x, 1, 1, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
x = conv->forward(ctx, x);
} else {
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["op"]);

View File

@ -366,7 +366,7 @@ struct KLOptimalScheduler : SigmaScheduler {
for (uint32_t i = 0; i < n; ++i) {
// t goes from 0.0 to 1.0
float t = static_cast<float>(i) / static_cast<float>(n-1);
float t = static_cast<float>(i) / static_cast<float>(n - 1);
// Interpolate in the angle domain
float angle = t * alpha_min + (1.0f - t) * alpha_max;

View File

@ -39,6 +39,7 @@ struct DiffusionModel {
virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter){};
virtual int64_t get_adm_in_channels() = 0;
virtual void set_flash_attn_enabled(bool enabled) = 0;
virtual void set_circular_axes(bool circular_x, bool circular_y) = 0;
};
struct UNetModel : public DiffusionModel {
@ -87,6 +88,10 @@ struct UNetModel : public DiffusionModel {
unet.set_flash_attention_enabled(enabled);
}
void set_circular_axes(bool circular_x, bool circular_y) override {
unet.set_circular_axes(circular_x, circular_y);
}
bool compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
@ -148,6 +153,10 @@ struct MMDiTModel : public DiffusionModel {
mmdit.set_flash_attention_enabled(enabled);
}
void set_circular_axes(bool circular_x, bool circular_y) override {
mmdit.set_circular_axes(circular_x, circular_y);
}
bool compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
@ -210,6 +219,10 @@ struct FluxModel : public DiffusionModel {
flux.set_flash_attention_enabled(enabled);
}
void set_circular_axes(bool circular_x, bool circular_y) override {
flux.set_circular_axes(circular_x, circular_y);
}
bool compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
@ -277,6 +290,10 @@ struct WanModel : public DiffusionModel {
wan.set_flash_attention_enabled(enabled);
}
void set_circular_axes(bool circular_x, bool circular_y) override {
wan.set_circular_axes(circular_x, circular_y);
}
bool compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
@ -343,6 +360,10 @@ struct QwenImageModel : public DiffusionModel {
qwen_image.set_flash_attention_enabled(enabled);
}
void set_circular_axes(bool circular_x, bool circular_y) override {
qwen_image.set_circular_axes(circular_x, circular_y);
}
bool compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
@ -406,6 +427,10 @@ struct ZImageModel : public DiffusionModel {
z_image.set_flash_attention_enabled(enabled);
}
void set_circular_axes(bool circular_x, bool circular_y) override {
z_image.set_circular_axes(circular_x, circular_y);
}
bool compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,

View File

@ -449,6 +449,10 @@ struct SDContextParams {
bool diffusion_conv_direct = false;
bool vae_conv_direct = false;
bool circular = false;
bool circular_x = false;
bool circular_y = false;
bool chroma_use_dit_mask = true;
bool chroma_use_t5_mask = false;
int chroma_t5_mask_pad = 1;
@ -605,6 +609,18 @@ struct SDContextParams {
"--vae-conv-direct",
"use ggml_conv2d_direct in the vae model",
true, &vae_conv_direct},
{"",
"--circular",
"enable circular padding for convolutions",
true, &circular},
{"",
"--circularx",
"enable circular RoPE wrapping on x-axis (width) only",
true, &circular_x},
{"",
"--circulary",
"enable circular RoPE wrapping on y-axis (height) only",
true, &circular_y},
{"",
"--chroma-disable-dit-mask",
"disable dit mask for chroma",
@ -868,6 +884,9 @@ struct SDContextParams {
<< " diffusion_flash_attn: " << (diffusion_flash_attn ? "true" : "false") << ",\n"
<< " diffusion_conv_direct: " << (diffusion_conv_direct ? "true" : "false") << ",\n"
<< " vae_conv_direct: " << (vae_conv_direct ? "true" : "false") << ",\n"
<< " circular: " << (circular ? "true" : "false") << ",\n"
<< " circular_x: " << (circular_x ? "true" : "false") << ",\n"
<< " circular_y: " << (circular_y ? "true" : "false") << ",\n"
<< " chroma_use_dit_mask: " << (chroma_use_dit_mask ? "true" : "false") << ",\n"
<< " chroma_use_t5_mask: " << (chroma_use_t5_mask ? "true" : "false") << ",\n"
<< " chroma_t5_mask_pad: " << chroma_t5_mask_pad << ",\n"
@ -928,6 +947,8 @@ struct SDContextParams {
taesd_preview,
diffusion_conv_direct,
vae_conv_direct,
circular || circular_x,
circular || circular_y,
force_sdxl_vae_conv_scale,
chroma_use_dit_mask,
chroma_use_t5_mask,

View File

@ -860,14 +860,14 @@ namespace Flux {
}
}
struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx,
struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
struct ggml_tensor* x) {
int64_t W = x->ne[0];
int64_t H = x->ne[1];
int pad_h = (params.patch_size - H % params.patch_size) % params.patch_size;
int pad_w = (params.patch_size - W % params.patch_size) % params.patch_size;
x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w]
x = ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
return x;
}
@ -893,11 +893,11 @@ namespace Flux {
return x;
}
struct ggml_tensor* process_img(struct ggml_context* ctx,
struct ggml_tensor* process_img(GGMLRunnerContext* ctx,
struct ggml_tensor* x) {
// img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
x = pad_to_patch_size(ctx, x);
x = patchify(ctx, x);
x = patchify(ctx->ggml_ctx, x);
return x;
}
@ -1076,7 +1076,7 @@ namespace Flux {
int pad_h = (patch_size - H % patch_size) % patch_size;
int pad_w = (patch_size - W % patch_size) % patch_size;
auto img = pad_to_patch_size(ctx->ggml_ctx, x);
auto img = pad_to_patch_size(ctx, x);
auto orig_img = img;
if (params.chroma_radiance_params.use_patch_size_32) {
@ -1150,7 +1150,7 @@ namespace Flux {
int pad_h = (patch_size - H % patch_size) % patch_size;
int pad_w = (patch_size - W % patch_size) % patch_size;
auto img = process_img(ctx->ggml_ctx, x);
auto img = process_img(ctx, x);
uint64_t img_tokens = img->ne[1];
if (params.version == VERSION_FLUX_FILL) {
@ -1158,8 +1158,8 @@ namespace Flux {
ggml_tensor* masked = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
masked = process_img(ctx->ggml_ctx, masked);
mask = process_img(ctx->ggml_ctx, mask);
masked = process_img(ctx, masked);
mask = process_img(ctx, mask);
img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, masked, mask, 0), 0);
} else if (params.version == VERSION_FLEX_2) {
@ -1168,21 +1168,21 @@ namespace Flux {
ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
ggml_tensor* control = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));
masked = process_img(ctx->ggml_ctx, masked);
mask = process_img(ctx->ggml_ctx, mask);
control = process_img(ctx->ggml_ctx, control);
masked = process_img(ctx, masked);
mask = process_img(ctx, mask);
control = process_img(ctx, control);
img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, ggml_concat(ctx->ggml_ctx, masked, mask, 0), control, 0), 0);
} else if (params.version == VERSION_FLUX_CONTROLS) {
GGML_ASSERT(c_concat != nullptr);
auto control = process_img(ctx->ggml_ctx, c_concat);
auto control = process_img(ctx, c_concat);
img = ggml_concat(ctx->ggml_ctx, img, control, 0);
}
if (ref_latents.size() > 0) {
for (ggml_tensor* ref : ref_latents) {
ref = process_img(ctx->ggml_ctx, ref);
ref = process_img(ctx, ref);
img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
}
}
@ -1472,6 +1472,8 @@ namespace Flux {
increase_ref_index,
flux_params.ref_index_scale,
flux_params.theta,
circular_y_enabled,
circular_x_enabled,
flux_params.axes_dim);
int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
// LOG_DEBUG("pos_len %d", pos_len);

2
ggml

@ -1 +1 @@
Subproject commit f5425c0ee5e582a7d64411f06139870bff3e52e0
Subproject commit 3e9f2ba3b934c20b26873b3c60dbf41b116978ff

View File

@ -5,6 +5,7 @@
#include <inttypes.h>
#include <stdarg.h>
#include <algorithm>
#include <atomic>
#include <cstring>
#include <fstream>
#include <functional>
@ -993,6 +994,48 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx,
return x;
}
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_pad_ext(struct ggml_context* ctx,
struct ggml_tensor* x,
int lp0,
int rp0,
int lp1,
int rp1,
int lp2,
int rp2,
int lp3,
int rp3,
bool circular_x = false,
bool circular_y = false) {
if (circular_x && circular_y) {
return ggml_pad_ext_circular(ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
}
if (circular_x && (lp0 != 0 || rp0 != 0)) {
x = ggml_pad_ext_circular(ctx, x, lp0, rp0, 0, 0, 0, 0, 0, 0);
lp0 = rp0 = 0;
}
if (circular_y && (lp1 != 0 || rp1 != 0)) {
x = ggml_pad_ext_circular(ctx, x, 0, 0, lp1, rp1, 0, 0, 0, 0);
lp1 = rp1 = 0;
}
if (lp0 != 0 || rp0 != 0 || lp1 != 0 || rp1 != 0 || lp2 != 0 || rp2 != 0 || lp3 != 0 || rp3 != 0) {
x = ggml_pad_ext(ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
}
return x;
}
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_pad(struct ggml_context* ctx,
struct ggml_tensor* x,
int p0,
int p1,
int p2 = 0,
int p3 = 0,
bool circular_x = false,
bool circular_y = false) {
return ggml_ext_pad_ext(ctx, x, p0, p0, p1, p1, p2, p2, p3, p3, circular_x, circular_y);
}
// w: [OCIC, KH, KW]
// x: [N, IC, IH, IW]
// b: [OC,]
@ -1008,6 +1051,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_conv_2d(struct ggml_context* ctx,
int d0 = 1,
int d1 = 1,
bool direct = false,
bool circular_x = false,
bool circular_y = false,
float scale = 1.f) {
if (scale != 1.f) {
x = ggml_scale(ctx, x, scale);
@ -1015,6 +1060,13 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_conv_2d(struct ggml_context* ctx,
if (w->ne[2] != x->ne[2] && ggml_n_dims(w) == 2) {
w = ggml_reshape_4d(ctx, w, 1, 1, w->ne[0], w->ne[1]);
}
if ((p0 != 0 || p1 != 0) && (circular_x || circular_y)) {
x = ggml_ext_pad(ctx, x, p0, p1, 0, 0, circular_x, circular_y);
p0 = 0;
p1 = 0;
}
if (direct) {
x = ggml_conv_2d_direct(ctx, w, x, s0, s1, p0, p1, d0, d1);
} else {
@ -1528,6 +1580,8 @@ struct WeightAdapter {
int d0 = 1;
int d1 = 1;
bool direct = false;
bool circular_x = false;
bool circular_y = false;
float scale = 1.f;
} conv2d;
};
@ -1546,6 +1600,8 @@ struct GGMLRunnerContext {
ggml_context* ggml_ctx = nullptr;
bool flash_attn_enabled = false;
bool conv2d_direct_enabled = false;
bool circular_x_enabled = false;
bool circular_y_enabled = false;
std::shared_ptr<WeightAdapter> weight_adapter = nullptr;
};
@ -1582,6 +1638,8 @@ protected:
bool flash_attn_enabled = false;
bool conv2d_direct_enabled = false;
bool circular_x_enabled = false;
bool circular_y_enabled = false;
void alloc_params_ctx() {
struct ggml_init_params params;
@ -1859,6 +1917,8 @@ public:
runner_ctx.backend = runtime_backend;
runner_ctx.flash_attn_enabled = flash_attn_enabled;
runner_ctx.conv2d_direct_enabled = conv2d_direct_enabled;
runner_ctx.circular_x_enabled = circular_x_enabled;
runner_ctx.circular_y_enabled = circular_y_enabled;
runner_ctx.weight_adapter = weight_adapter;
return runner_ctx;
}
@ -2003,6 +2063,11 @@ public:
conv2d_direct_enabled = enabled;
}
void set_circular_axes(bool circular_x, bool circular_y) {
circular_x_enabled = circular_x;
circular_y_enabled = circular_y;
}
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {
weight_adapter = adapter;
}
@ -2274,6 +2339,8 @@ public:
forward_params.conv2d.d0 = dilation.second;
forward_params.conv2d.d1 = dilation.first;
forward_params.conv2d.direct = ctx->conv2d_direct_enabled;
forward_params.conv2d.circular_x = ctx->circular_x_enabled;
forward_params.conv2d.circular_y = ctx->circular_y_enabled;
forward_params.conv2d.scale = scale;
return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params);
}
@ -2288,6 +2355,8 @@ public:
dilation.second,
dilation.first,
ctx->conv2d_direct_enabled,
ctx->circular_x_enabled,
ctx->circular_y_enabled,
scale);
}
};

View File

@ -599,6 +599,8 @@ struct LoraModel : public GGMLRunner {
forward_params.conv2d.d0,
forward_params.conv2d.d1,
forward_params.conv2d.direct,
forward_params.conv2d.circular_x,
forward_params.conv2d.circular_y,
forward_params.conv2d.scale);
if (lora_mid) {
lx = ggml_ext_conv_2d(ctx,
@ -612,6 +614,8 @@ struct LoraModel : public GGMLRunner {
1,
1,
forward_params.conv2d.direct,
forward_params.conv2d.circular_x,
forward_params.conv2d.circular_y,
forward_params.conv2d.scale);
}
lx = ggml_ext_conv_2d(ctx,
@ -625,6 +629,8 @@ struct LoraModel : public GGMLRunner {
1,
1,
forward_params.conv2d.direct,
forward_params.conv2d.circular_x,
forward_params.conv2d.circular_y,
forward_params.conv2d.scale);
}
@ -779,6 +785,8 @@ public:
forward_params.conv2d.d0,
forward_params.conv2d.d1,
forward_params.conv2d.direct,
forward_params.conv2d.circular_x,
forward_params.conv2d.circular_y,
forward_params.conv2d.scale);
}
for (auto& lora_model : lora_models) {

View File

@ -354,14 +354,14 @@ namespace Qwen {
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, params.patch_size * params.patch_size * params.out_channels));
}
struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx,
struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
struct ggml_tensor* x) {
int64_t W = x->ne[0];
int64_t H = x->ne[1];
int pad_h = (params.patch_size - H % params.patch_size) % params.patch_size;
int pad_w = (params.patch_size - W % params.patch_size) % params.patch_size;
x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w]
x = ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
return x;
}
@ -387,10 +387,10 @@ namespace Qwen {
return x;
}
struct ggml_tensor* process_img(struct ggml_context* ctx,
struct ggml_tensor* process_img(GGMLRunnerContext* ctx,
struct ggml_tensor* x) {
x = pad_to_patch_size(ctx, x);
x = patchify(ctx, x);
x = patchify(ctx->ggml_ctx, x);
return x;
}
@ -466,12 +466,12 @@ namespace Qwen {
int64_t C = x->ne[2];
int64_t N = x->ne[3];
auto img = process_img(ctx->ggml_ctx, x);
auto img = process_img(ctx, x);
uint64_t img_tokens = img->ne[1];
if (ref_latents.size() > 0) {
for (ggml_tensor* ref : ref_latents) {
ref = process_img(ctx->ggml_ctx, ref);
ref = process_img(ctx, ref);
img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
}
}
@ -565,6 +565,8 @@ namespace Qwen {
ref_latents,
increase_ref_index,
qwen_image_params.theta,
circular_y_enabled,
circular_x_enabled,
qwen_image_params.axes_dim);
int pos_len = pe_vec.size() / qwen_image_params.axes_dim_sum / 2;
// LOG_DEBUG("pos_len %d", pos_len);

158
rope.hpp
View File

@ -1,6 +1,8 @@
#ifndef __ROPE_HPP__
#define __ROPE_HPP__
#include <algorithm>
#include <cmath>
#include <vector>
#include "ggml_extend.hpp"
@ -39,7 +41,10 @@ namespace Rope {
return flat_vec;
}
__STATIC_INLINE__ std::vector<std::vector<float>> rope(const std::vector<float>& pos, int dim, int theta) {
__STATIC_INLINE__ std::vector<std::vector<float>> rope(const std::vector<float>& pos,
int dim,
int theta,
const std::vector<int>& axis_wrap_dims = {}) {
assert(dim % 2 == 0);
int half_dim = dim / 2;
@ -47,14 +52,31 @@ namespace Rope {
std::vector<float> omega(half_dim);
for (int i = 0; i < half_dim; ++i) {
omega[i] = 1.0 / std::pow(theta, scale[i]);
omega[i] = 1.0f / std::pow(theta, scale[i]);
}
int pos_size = pos.size();
std::vector<std::vector<float>> out(pos_size, std::vector<float>(half_dim));
for (int i = 0; i < pos_size; ++i) {
for (int j = 0; j < half_dim; ++j) {
out[i][j] = pos[i] * omega[j];
float angle = pos[i] * omega[j];
if (!axis_wrap_dims.empty()) {
size_t wrap_size = axis_wrap_dims.size();
// mod batch size since we only store this for one item in the batch
size_t wrap_idx = wrap_size > 0 ? (i % wrap_size) : 0;
int wrap_dim = axis_wrap_dims[wrap_idx];
if (wrap_dim > 0) {
constexpr float TWO_PI = 6.28318530717958647692f;
float cycles = omega[j] * wrap_dim / TWO_PI;
// closest periodic harmonic, necessary to ensure things neatly tile
// without this round, things don't tile at the boundaries and you end up
// with the model knowing what is "center"
float rounded = std::round(cycles);
angle = pos[i] * TWO_PI * rounded / wrap_dim;
}
}
out[i][j] = angle;
}
}
@ -146,7 +168,8 @@ namespace Rope {
__STATIC_INLINE__ std::vector<float> embed_nd(const std::vector<std::vector<float>>& ids,
int bs,
int theta,
const std::vector<int>& axes_dim) {
const std::vector<int>& axes_dim,
const std::vector<std::vector<int>>& wrap_dims = {}) {
std::vector<std::vector<float>> trans_ids = transpose(ids);
size_t pos_len = ids.size() / bs;
int num_axes = axes_dim.size();
@ -161,7 +184,12 @@ namespace Rope {
std::vector<std::vector<float>> emb(bs * pos_len, std::vector<float>(emb_dim * 2 * 2, 0.0));
int offset = 0;
for (int i = 0; i < num_axes; ++i) {
std::vector<std::vector<float>> rope_emb = rope(trans_ids[i], axes_dim[i], theta); // [bs*pos_len, axes_dim[i]/2 * 2 * 2]
std::vector<int> axis_wrap_dims;
if (!wrap_dims.empty() && i < (int)wrap_dims.size()) {
axis_wrap_dims = wrap_dims[i];
}
std::vector<std::vector<float>> rope_emb =
rope(trans_ids[i], axes_dim[i], theta, axis_wrap_dims); // [bs*pos_len, axes_dim[i]/2 * 2 * 2]
for (int b = 0; b < bs; ++b) {
for (int j = 0; j < pos_len; ++j) {
for (int k = 0; k < rope_emb[0].size(); ++k) {
@ -251,6 +279,8 @@ namespace Rope {
bool increase_ref_index,
float ref_index_scale,
int theta,
bool circular_h,
bool circular_w,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_flux_ids(h,
w,
@ -262,7 +292,47 @@ namespace Rope {
ref_latents,
increase_ref_index,
ref_index_scale);
return embed_nd(ids, bs, theta, axes_dim);
std::vector<std::vector<int>> wrap_dims;
if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) {
int h_len = (h + (patch_size / 2)) / patch_size;
int w_len = (w + (patch_size / 2)) / patch_size;
if (h_len > 0 && w_len > 0) {
size_t pos_len = ids.size() / bs;
wrap_dims.assign(axes_dim.size(), std::vector<int>(pos_len, 0));
size_t cursor = context_len; // text first
const size_t img_tokens = static_cast<size_t>(h_len) * static_cast<size_t>(w_len);
for (size_t token_i = 0; token_i < img_tokens; ++token_i) {
if (circular_h) {
wrap_dims[1][cursor + token_i] = h_len;
}
if (circular_w) {
wrap_dims[2][cursor + token_i] = w_len;
}
}
cursor += img_tokens;
// reference latents
for (ggml_tensor* ref : ref_latents) {
if (ref == nullptr) {
continue;
}
int ref_h = static_cast<int>(ref->ne[1]);
int ref_w = static_cast<int>(ref->ne[0]);
int ref_h_l = (ref_h + (patch_size / 2)) / patch_size;
int ref_w_l = (ref_w + (patch_size / 2)) / patch_size;
size_t ref_tokens = static_cast<size_t>(ref_h_l) * static_cast<size_t>(ref_w_l);
for (size_t token_i = 0; token_i < ref_tokens; ++token_i) {
if (circular_h) {
wrap_dims[1][cursor + token_i] = ref_h_l;
}
if (circular_w) {
wrap_dims[2][cursor + token_i] = ref_w_l;
}
}
cursor += ref_tokens;
}
}
}
return embed_nd(ids, bs, theta, axes_dim, wrap_dims);
}
__STATIC_INLINE__ std::vector<std::vector<float>> gen_qwen_image_ids(int h,
@ -301,9 +371,57 @@ namespace Rope {
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index,
int theta,
bool circular_h,
bool circular_w,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
return embed_nd(ids, bs, theta, axes_dim);
std::vector<std::vector<int>> wrap_dims;
// This logic simply stores the (pad and patch_adjusted) sizes of images so we can make sure rope correctly tiles
if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) {
int pad_h = (patch_size - (h % patch_size)) % patch_size;
int pad_w = (patch_size - (w % patch_size)) % patch_size;
int h_len = (h + pad_h) / patch_size;
int w_len = (w + pad_w) / patch_size;
if (h_len > 0 && w_len > 0) {
const size_t total_tokens = ids.size();
// Track per-token wrap lengths for the row/column axes so only spatial tokens become periodic.
wrap_dims.assign(axes_dim.size(), std::vector<int>(total_tokens / bs, 0));
size_t cursor = context_len; // ignore text tokens
const size_t img_tokens = static_cast<size_t>(h_len) * static_cast<size_t>(w_len);
for (size_t token_i = 0; token_i < img_tokens; ++token_i) {
if (circular_h) {
wrap_dims[1][cursor + token_i] = h_len;
}
if (circular_w) {
wrap_dims[2][cursor + token_i] = w_len;
}
}
cursor += img_tokens;
// For each reference image, store wrap sizes as well
for (ggml_tensor* ref : ref_latents) {
if (ref == nullptr) {
continue;
}
int ref_h = static_cast<int>(ref->ne[1]);
int ref_w = static_cast<int>(ref->ne[0]);
int ref_pad_h = (patch_size - (ref_h % patch_size)) % patch_size;
int ref_pad_w = (patch_size - (ref_w % patch_size)) % patch_size;
int ref_h_len = (ref_h + ref_pad_h) / patch_size;
int ref_w_len = (ref_w + ref_pad_w) / patch_size;
size_t ref_n_tokens = static_cast<size_t>(ref_h_len) * static_cast<size_t>(ref_w_len);
for (size_t token_i = 0; token_i < ref_n_tokens; ++token_i) {
if (circular_h) {
wrap_dims[1][cursor + token_i] = ref_h_len;
}
if (circular_w) {
wrap_dims[2][cursor + token_i] = ref_w_len;
}
}
cursor += ref_n_tokens;
}
}
}
return embed_nd(ids, bs, theta, axes_dim, wrap_dims);
}
__STATIC_INLINE__ std::vector<std::vector<float>> gen_vid_ids(int t,
@ -440,9 +558,33 @@ namespace Rope {
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index,
int theta,
bool circular_h,
bool circular_w,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_z_image_ids(h, w, patch_size, bs, context_len, seq_multi_of, ref_latents, increase_ref_index);
return embed_nd(ids, bs, theta, axes_dim);
std::vector<std::vector<int>> wrap_dims;
if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) {
int pad_h = (patch_size - (h % patch_size)) % patch_size;
int pad_w = (patch_size - (w % patch_size)) % patch_size;
int h_len = (h + pad_h) / patch_size;
int w_len = (w + pad_w) / patch_size;
if (h_len > 0 && w_len > 0) {
size_t pos_len = ids.size() / bs;
wrap_dims.assign(axes_dim.size(), std::vector<int>(pos_len, 0));
size_t cursor = context_len + bound_mod(context_len, seq_multi_of); // skip text (and its padding)
size_t img_tokens = static_cast<size_t>(h_len) * static_cast<size_t>(w_len);
for (size_t token_i = 0; token_i < img_tokens; ++token_i) {
if (circular_h) {
wrap_dims[1][cursor + token_i] = h_len;
}
if (circular_w) {
wrap_dims[2][cursor + token_i] = w_len;
}
}
}
}
return embed_nd(ids, bs, theta, axes_dim, wrap_dims);
}
__STATIC_INLINE__ struct ggml_tensor* apply_rope(struct ggml_context* ctx,

View File

@ -405,6 +405,10 @@ public:
vae_decode_only = false;
}
if (sd_ctx_params->circular_x || sd_ctx_params->circular_y) {
LOG_INFO("Using circular padding for convolutions");
}
bool clip_on_cpu = sd_ctx_params->keep_clip_on_cpu;
{
@ -705,6 +709,20 @@ public:
}
pmid_model->get_param_tensors(tensors, "pmid");
}
diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
if (high_noise_diffusion_model) {
high_noise_diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
}
if (control_net) {
control_net->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
}
if (first_stage_model) {
first_stage_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
}
if (tae_first_stage) {
tae_first_stage->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
}
}
struct ggml_init_params params;
@ -2559,6 +2577,8 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
sd_ctx_params->keep_control_net_on_cpu = false;
sd_ctx_params->keep_vae_on_cpu = false;
sd_ctx_params->diffusion_flash_attn = false;
sd_ctx_params->circular_x = false;
sd_ctx_params->circular_y = false;
sd_ctx_params->chroma_use_dit_mask = true;
sd_ctx_params->chroma_use_t5_mask = false;
sd_ctx_params->chroma_t5_mask_pad = 1;
@ -2598,6 +2618,8 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
"keep_control_net_on_cpu: %s\n"
"keep_vae_on_cpu: %s\n"
"diffusion_flash_attn: %s\n"
"circular_x: %s\n"
"circular_y: %s\n"
"chroma_use_dit_mask: %s\n"
"chroma_use_t5_mask: %s\n"
"chroma_t5_mask_pad: %d\n",
@ -2627,6 +2649,8 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
BOOL_STR(sd_ctx_params->keep_control_net_on_cpu),
BOOL_STR(sd_ctx_params->keep_vae_on_cpu),
BOOL_STR(sd_ctx_params->diffusion_flash_attn),
BOOL_STR(sd_ctx_params->circular_x),
BOOL_STR(sd_ctx_params->circular_y),
BOOL_STR(sd_ctx_params->chroma_use_dit_mask),
BOOL_STR(sd_ctx_params->chroma_use_t5_mask),
sd_ctx_params->chroma_t5_mask_pad);

View File

@ -189,6 +189,8 @@ typedef struct {
bool tae_preview_only;
bool diffusion_conv_direct;
bool vae_conv_direct;
bool circular_x;
bool circular_y;
bool force_sdxl_vae_conv_scale;
bool chroma_use_dit_mask;
bool chroma_use_t5_mask;

15
wan.hpp
View File

@ -75,7 +75,7 @@ namespace WAN {
lp2 -= (int)cache_x->ne[2];
}
x = ggml_pad_ext(ctx->ggml_ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, 0, 0);
x = ggml_ext_pad_ext(ctx->ggml_ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
return ggml_ext_conv_3d(ctx->ggml_ctx, x, w, b, in_channels,
std::get<2>(stride), std::get<1>(stride), std::get<0>(stride),
0, 0, 0,
@ -206,9 +206,9 @@ namespace WAN {
} else if (mode == "upsample3d") {
x = ggml_upscale(ctx->ggml_ctx, x, 2, GGML_SCALE_MODE_NEAREST);
} else if (mode == "downsample2d") {
x = ggml_pad(ctx->ggml_ctx, x, 1, 1, 0, 0);
x = ggml_ext_pad(ctx->ggml_ctx, x, 1, 1, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
} else if (mode == "downsample3d") {
x = ggml_pad(ctx->ggml_ctx, x, 1, 1, 0, 0);
x = ggml_ext_pad(ctx->ggml_ctx, x, 1, 1, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
}
x = resample_1->forward(ctx, x);
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // (c, t, h, w)
@ -1826,7 +1826,7 @@ namespace WAN {
}
}
struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx,
struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
struct ggml_tensor* x) {
int64_t W = x->ne[0];
int64_t H = x->ne[1];
@ -1835,8 +1835,7 @@ namespace WAN {
int pad_t = (std::get<0>(params.patch_size) - T % std::get<0>(params.patch_size)) % std::get<0>(params.patch_size);
int pad_h = (std::get<1>(params.patch_size) - H % std::get<1>(params.patch_size)) % std::get<1>(params.patch_size);
int pad_w = (std::get<2>(params.patch_size) - W % std::get<2>(params.patch_size)) % std::get<2>(params.patch_size);
x = ggml_pad(ctx, x, pad_w, pad_h, pad_t, 0); // [N*C, T + pad_t, H + pad_h, W + pad_w]
ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, pad_t, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
return x;
}
@ -1986,14 +1985,14 @@ namespace WAN {
int64_t T = x->ne[2];
int64_t C = x->ne[3];
x = pad_to_patch_size(ctx->ggml_ctx, x);
x = pad_to_patch_size(ctx, x);
int64_t t_len = ((T + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size));
int64_t h_len = ((H + (std::get<1>(params.patch_size) / 2)) / std::get<1>(params.patch_size));
int64_t w_len = ((W + (std::get<2>(params.patch_size) / 2)) / std::get<2>(params.patch_size));
if (time_dim_concat != nullptr) {
time_dim_concat = pad_to_patch_size(ctx->ggml_ctx, time_dim_concat);
time_dim_concat = pad_to_patch_size(ctx, time_dim_concat);
x = ggml_concat(ctx->ggml_ctx, x, time_dim_concat, 2); // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w]
t_len = ((x->ne[2] + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size));
}

View File

@ -324,14 +324,14 @@ namespace ZImage {
blocks["final_layer"] = std::make_shared<FinalLayer>(z_image_params.hidden_size, z_image_params.patch_size, z_image_params.out_channels);
}
struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx,
struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
struct ggml_tensor* x) {
int64_t W = x->ne[0];
int64_t H = x->ne[1];
int pad_h = (z_image_params.patch_size - H % z_image_params.patch_size) % z_image_params.patch_size;
int pad_w = (z_image_params.patch_size - W % z_image_params.patch_size) % z_image_params.patch_size;
x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w]
x = ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
return x;
}
@ -357,10 +357,10 @@ namespace ZImage {
return x;
}
struct ggml_tensor* process_img(struct ggml_context* ctx,
struct ggml_tensor* process_img(GGMLRunnerContext* ctx,
struct ggml_tensor* x) {
x = pad_to_patch_size(ctx, x);
x = patchify(ctx, x);
x = patchify(ctx->ggml_ctx, x);
return x;
}
@ -473,12 +473,12 @@ namespace ZImage {
int64_t C = x->ne[2];
int64_t N = x->ne[3];
auto img = process_img(ctx->ggml_ctx, x);
auto img = process_img(ctx, x);
uint64_t n_img_token = img->ne[1];
if (ref_latents.size() > 0) {
for (ggml_tensor* ref : ref_latents) {
ref = process_img(ctx->ggml_ctx, ref);
ref = process_img(ctx, ref);
img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
}
}
@ -552,6 +552,8 @@ namespace ZImage {
ref_latents,
increase_ref_index,
z_image_params.theta,
circular_y_enabled,
circular_x_enabled,
z_image_params.axes_dim);
int pos_len = pe_vec.size() / z_image_params.axes_dim_sum / 2;
// LOG_DEBUG("pos_len %d", pos_len);