Compare commits

..

No commits in common. "master" and "master-477-639091f" have entirely different histories.

33 changed files with 343 additions and 856 deletions

View File

@ -15,9 +15,6 @@ API and command-line option may change frequently.***
## 🔥Important News
* **2026/01/18** 🚀 stable-diffusion.cpp now supports **FLUX.2-klein**
👉 Details: [PR #1193](https://github.com/leejet/stable-diffusion.cpp/pull/1193)
* **2025/12/01** 🚀 stable-diffusion.cpp now supports **Z-Image**
👉 Details: [PR #1020](https://github.com/leejet/stable-diffusion.cpp/pull/1020)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 870 KiB

View File

@ -479,9 +479,9 @@ public:
x = fc1->forward(ctx, x);
if (use_gelu) {
x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
x = ggml_gelu_inplace(ctx->ggml_ctx, x);
} else {
x = ggml_ext_gelu_quick(ctx->ggml_ctx, x, true);
x = ggml_gelu_quick_inplace(ctx->ggml_ctx, x);
}
x = fc2->forward(ctx, x);
return x;
@ -510,7 +510,7 @@ public:
blocks["mlp"] = std::shared_ptr<GGMLBlock>(new CLIPMLP(d_model, intermediate_size));
}
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* mask = nullptr) {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, bool mask = true) {
// x: [N, n_token, d_model]
auto self_attn = std::dynamic_pointer_cast<MultiheadAttention>(blocks["self_attn"]);
auto layer_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm1"]);
@ -542,8 +542,8 @@ public:
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* mask = nullptr,
int clip_skip = -1) {
int clip_skip = -1,
bool mask = true) {
// x: [N, n_token, d_model]
int layer_idx = n_layer - 1;
// LOG_DEBUG("clip_skip %d", clip_skip);
@ -741,7 +741,6 @@ public:
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* input_ids,
struct ggml_tensor* tkn_embeddings,
struct ggml_tensor* mask = nullptr,
size_t max_token_idx = 0,
bool return_pooled = false,
int clip_skip = -1) {
@ -751,7 +750,7 @@ public:
auto final_layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["final_layer_norm"]);
auto x = embeddings->forward(ctx, input_ids, tkn_embeddings); // [N, n_token, hidden_size]
x = encoder->forward(ctx, x, mask, return_pooled ? -1 : clip_skip);
x = encoder->forward(ctx, x, return_pooled ? -1 : clip_skip, true);
if (return_pooled || with_final_ln) {
x = final_layer_norm->forward(ctx, x);
}
@ -815,10 +814,9 @@ public:
auto x = embeddings->forward(ctx, pixel_values); // [N, num_positions, embed_dim]
x = pre_layernorm->forward(ctx, x);
x = encoder->forward(ctx, x, nullptr, clip_skip);
x = encoder->forward(ctx, x, clip_skip, false);
// print_ggml_tensor(x, true, "ClipVisionModel x: ");
auto last_hidden_state = x;
x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size]
GGML_ASSERT(x->ne[3] == 1);
@ -907,8 +905,6 @@ public:
struct CLIPTextModelRunner : public GGMLRunner {
CLIPTextModel model;
std::vector<float> attention_mask_vec;
CLIPTextModelRunner(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2TensorStorage& tensor_storage_map,
@ -942,7 +938,6 @@ struct CLIPTextModelRunner : public GGMLRunner {
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* input_ids,
struct ggml_tensor* embeddings,
struct ggml_tensor* mask,
size_t max_token_idx = 0,
bool return_pooled = false,
int clip_skip = -1) {
@ -953,7 +948,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
input_ids = ggml_reshape_2d(ctx->ggml_ctx, input_ids, model.n_token, input_ids->ne[0] / model.n_token);
}
return model.forward(ctx, input_ids, embeddings, mask, max_token_idx, return_pooled, clip_skip);
return model.forward(ctx, input_ids, embeddings, max_token_idx, return_pooled, clip_skip);
}
struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids,
@ -980,23 +975,9 @@ struct CLIPTextModelRunner : public GGMLRunner {
embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 1);
}
int n_tokens = static_cast<int>(input_ids->ne[0]);
attention_mask_vec.resize(n_tokens * n_tokens);
for (int i0 = 0; i0 < n_tokens; i0++) {
for (int i1 = 0; i1 < n_tokens; i1++) {
float value = 0.f;
if (i0 > i1) {
value = -INFINITY;
}
attention_mask_vec[i1 * n_tokens + i0] = value;
}
}
auto attention_mask = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, n_tokens, n_tokens);
set_backend_tensor_data(attention_mask, attention_mask_vec.data());
auto runner_ctx = get_context();
struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, embeddings, attention_mask, max_token_idx, return_pooled, clip_skip);
struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, embeddings, max_token_idx, return_pooled, clip_skip);
ggml_build_forward_expand(gf, hidden_states);

View File

@ -200,7 +200,7 @@ public:
gate = ggml_cont(ctx->ggml_ctx, gate);
gate = ggml_ext_gelu(ctx->ggml_ctx, gate, true);
gate = ggml_gelu_inplace(ctx->ggml_ctx, gate);
x = ggml_mul(ctx->ggml_ctx, x, gate); // [ne3, ne2, ne1, dim_out]
@ -220,7 +220,7 @@ public:
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
x = proj->forward(ctx, x);
x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
x = ggml_gelu_inplace(ctx->ggml_ctx, x);
return x;
}
};
@ -317,7 +317,7 @@ public:
auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim]
auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, inner_dim]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, inner_dim]
x = to_out_0->forward(ctx, x); // [N, n_token, query_dim]
return x;
@ -536,8 +536,8 @@ public:
// image_only_indicator is always tensor([0.])
float alpha = get_alpha();
auto x = ggml_add(ctx->ggml_ctx,
ggml_ext_scale(ctx->ggml_ctx, x_spatial, alpha),
ggml_ext_scale(ctx->ggml_ctx, x_temporal, 1.0f - alpha));
ggml_scale(ctx->ggml_ctx, x_spatial, alpha),
ggml_scale(ctx->ggml_ctx, x_temporal, 1.0f - alpha));
return x;
}
};

View File

@ -34,7 +34,6 @@ struct Conditioner {
virtual void free_params_buffer() = 0;
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
virtual size_t get_params_buffer_size() = 0;
virtual void set_flash_attention_enabled(bool enabled) = 0;
virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {}
virtual std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
int n_threads,
@ -116,13 +115,6 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
return buffer_size;
}
void set_flash_attention_enabled(bool enabled) override {
text_model->set_flash_attention_enabled(enabled);
if (sd_version_is_sdxl(version)) {
text_model2->set_flash_attention_enabled(enabled);
}
}
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
text_model->set_weight_adapter(adapter);
if (sd_version_is_sdxl(version)) {
@ -791,18 +783,6 @@ struct SD3CLIPEmbedder : public Conditioner {
return buffer_size;
}
void set_flash_attention_enabled(bool enabled) override {
if (clip_l) {
clip_l->set_flash_attention_enabled(enabled);
}
if (clip_g) {
clip_g->set_flash_attention_enabled(enabled);
}
if (t5) {
t5->set_flash_attention_enabled(enabled);
}
}
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
if (clip_l) {
clip_l->set_weight_adapter(adapter);
@ -1211,15 +1191,6 @@ struct FluxCLIPEmbedder : public Conditioner {
return buffer_size;
}
void set_flash_attention_enabled(bool enabled) override {
if (clip_l) {
clip_l->set_flash_attention_enabled(enabled);
}
if (t5) {
t5->set_flash_attention_enabled(enabled);
}
}
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {
if (clip_l) {
clip_l->set_weight_adapter(adapter);
@ -1469,12 +1440,6 @@ struct T5CLIPEmbedder : public Conditioner {
return buffer_size;
}
void set_flash_attention_enabled(bool enabled) override {
if (t5) {
t5->set_flash_attention_enabled(enabled);
}
}
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
if (t5) {
t5->set_weight_adapter(adapter);
@ -1685,10 +1650,6 @@ struct LLMEmbedder : public Conditioner {
return buffer_size;
}
void set_flash_attention_enabled(bool enabled) override {
llm->set_flash_attention_enabled(enabled);
}
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
if (llm) {
llm->set_weight_adapter(adapter);

View File

@ -1,8 +1,6 @@
#ifndef __DENOISER_HPP__
#define __DENOISER_HPP__
#include <cmath>
#include "ggml_extend.hpp"
#include "gits_noise.inl"
@ -353,95 +351,6 @@ struct SmoothStepScheduler : SigmaScheduler {
}
};
struct BongTangentScheduler : SigmaScheduler {
static constexpr float kPi = 3.14159265358979323846f;
static std::vector<float> get_bong_tangent_sigmas(int steps, float slope, float pivot, float start, float end) {
std::vector<float> sigmas;
if (steps <= 0) {
return sigmas;
}
float smax = ((2.0f / kPi) * atanf(-slope * (0.0f - pivot)) + 1.0f) * 0.5f;
float smin = ((2.0f / kPi) * atanf(-slope * ((float)(steps - 1) - pivot)) + 1.0f) * 0.5f;
float srange = smax - smin;
float sscale = start - end;
sigmas.reserve(steps);
if (fabsf(srange) < 1e-8f) {
if (steps == 1) {
sigmas.push_back(start);
return sigmas;
}
for (int i = 0; i < steps; ++i) {
float t = (float)i / (float)(steps - 1);
sigmas.push_back(start + (end - start) * t);
}
return sigmas;
}
float inv_srange = 1.0f / srange;
for (int x = 0; x < steps; ++x) {
float v = ((2.0f / kPi) * atanf(-slope * ((float)x - pivot)) + 1.0f) * 0.5f;
float sigma = ((v - smin) * inv_srange) * sscale + end;
sigmas.push_back(sigma);
}
return sigmas;
}
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t /*t_to_sigma*/) override {
std::vector<float> result;
if (n == 0) {
return result;
}
float start = sigma_max;
float end = sigma_min;
float middle = sigma_min + (sigma_max - sigma_min) * 0.5f;
float pivot_1 = 0.6f;
float pivot_2 = 0.6f;
float slope_1 = 0.2f;
float slope_2 = 0.2f;
int steps = static_cast<int>(n) + 2;
int midpoint = static_cast<int>(((float)steps * pivot_1 + (float)steps * pivot_2) * 0.5f);
int pivot_1_i = static_cast<int>((float)steps * pivot_1);
int pivot_2_i = static_cast<int>((float)steps * pivot_2);
float slope_scale = (float)steps / 40.0f;
slope_1 = slope_1 / slope_scale;
slope_2 = slope_2 / slope_scale;
int stage_2_len = steps - midpoint;
int stage_1_len = steps - stage_2_len;
std::vector<float> sigmas_1 = get_bong_tangent_sigmas(stage_1_len, slope_1, (float)pivot_1_i, start, middle);
std::vector<float> sigmas_2 = get_bong_tangent_sigmas(stage_2_len, slope_2, (float)(pivot_2_i - stage_1_len), middle, end);
if (!sigmas_1.empty()) {
sigmas_1.pop_back();
}
result.reserve(n + 1);
result.insert(result.end(), sigmas_1.begin(), sigmas_1.end());
result.insert(result.end(), sigmas_2.begin(), sigmas_2.end());
if (result.size() < n + 1) {
while (result.size() < n + 1) {
result.push_back(end);
}
} else if (result.size() > n + 1) {
result.resize(n + 1);
}
result[n] = 0.0f;
return result;
}
};
struct KLOptimalScheduler : SigmaScheduler {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
std::vector<float> sigmas;
@ -522,10 +431,6 @@ struct Denoiser {
LOG_INFO("get_sigmas with SmoothStep scheduler");
scheduler = std::make_shared<SmoothStepScheduler>();
break;
case BONG_TANGENT_SCHEDULER:
LOG_INFO("get_sigmas with bong_tangent scheduler");
scheduler = std::make_shared<BongTangentScheduler>();
break;
case KL_OPTIMAL_SCHEDULER:
LOG_INFO("get_sigmas with KL Optimal scheduler");
scheduler = std::make_shared<KLOptimalScheduler>();
@ -1729,216 +1634,6 @@ static bool sample_k_diffusion(sample_method_t method,
}
}
} break;
case RES_MULTISTEP_SAMPLE_METHOD: // Res Multistep sampler
{
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* old_denoised = ggml_dup_tensor(work_ctx, x);
bool have_old_sigma = false;
float old_sigma_down = 0.0f;
auto t_fn = [](float sigma) -> float { return -logf(sigma); };
auto sigma_fn = [](float t) -> float { return expf(-t); };
auto phi1_fn = [](float t) -> float {
if (fabsf(t) < 1e-6f) {
return 1.0f + t * 0.5f + (t * t) / 6.0f;
}
return (expf(t) - 1.0f) / t;
};
auto phi2_fn = [&](float t) -> float {
if (fabsf(t) < 1e-6f) {
return 0.5f + t / 6.0f + (t * t) / 24.0f;
}
float phi1_val = phi1_fn(t);
return (phi1_val - 1.0f) / t;
};
for (int i = 0; i < steps; i++) {
ggml_tensor* denoised = model(x, sigmas[i], i + 1);
if (denoised == nullptr) {
return false;
}
float sigma_from = sigmas[i];
float sigma_to = sigmas[i + 1];
float sigma_up = 0.0f;
float sigma_down = sigma_to;
if (eta > 0.0f) {
float sigma_from_sq = sigma_from * sigma_from;
float sigma_to_sq = sigma_to * sigma_to;
if (sigma_from_sq > 0.0f) {
float term = sigma_to_sq * (sigma_from_sq - sigma_to_sq) / sigma_from_sq;
if (term > 0.0f) {
sigma_up = eta * std::sqrt(term);
}
}
sigma_up = std::min(sigma_up, sigma_to);
float sigma_down_sq = sigma_to_sq - sigma_up * sigma_up;
sigma_down = sigma_down_sq > 0.0f ? std::sqrt(sigma_down_sq) : 0.0f;
}
if (sigma_down == 0.0f || !have_old_sigma) {
float dt = sigma_down - sigma_from;
float* vec_x = (float*)x->data;
float* vec_denoised = (float*)denoised->data;
for (int j = 0; j < ggml_nelements(x); j++) {
float d = (vec_x[j] - vec_denoised[j]) / sigma_from;
vec_x[j] = vec_x[j] + d * dt;
}
} else {
float t = t_fn(sigma_from);
float t_old = t_fn(old_sigma_down);
float t_next = t_fn(sigma_down);
float t_prev = t_fn(sigmas[i - 1]);
float h = t_next - t;
float c2 = (t_prev - t_old) / h;
float phi1_val = phi1_fn(-h);
float phi2_val = phi2_fn(-h);
float b1 = phi1_val - phi2_val / c2;
float b2 = phi2_val / c2;
if (!std::isfinite(b1)) {
b1 = 0.0f;
}
if (!std::isfinite(b2)) {
b2 = 0.0f;
}
float sigma_h = sigma_fn(h);
float* vec_x = (float*)x->data;
float* vec_denoised = (float*)denoised->data;
float* vec_old_denoised = (float*)old_denoised->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] = sigma_h * vec_x[j] + h * (b1 * vec_denoised[j] + b2 * vec_old_denoised[j]);
}
}
if (sigmas[i + 1] > 0 && sigma_up > 0.0f) {
ggml_ext_im_set_randn_f32(noise, rng);
float* vec_x = (float*)x->data;
float* vec_noise = (float*)noise->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] = vec_x[j] + vec_noise[j] * sigma_up;
}
}
float* vec_old_denoised = (float*)old_denoised->data;
float* vec_denoised = (float*)denoised->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_old_denoised[j] = vec_denoised[j];
}
old_sigma_down = sigma_down;
have_old_sigma = true;
}
} break;
case RES_2S_SAMPLE_METHOD: // Res 2s sampler
{
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* x0 = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x);
const float c2 = 0.5f;
auto t_fn = [](float sigma) -> float { return -logf(sigma); };
auto phi1_fn = [](float t) -> float {
if (fabsf(t) < 1e-6f) {
return 1.0f + t * 0.5f + (t * t) / 6.0f;
}
return (expf(t) - 1.0f) / t;
};
auto phi2_fn = [&](float t) -> float {
if (fabsf(t) < 1e-6f) {
return 0.5f + t / 6.0f + (t * t) / 24.0f;
}
float phi1_val = phi1_fn(t);
return (phi1_val - 1.0f) / t;
};
for (int i = 0; i < steps; i++) {
float sigma_from = sigmas[i];
float sigma_to = sigmas[i + 1];
ggml_tensor* denoised = model(x, sigma_from, -(i + 1));
if (denoised == nullptr) {
return false;
}
float sigma_up = 0.0f;
float sigma_down = sigma_to;
if (eta > 0.0f) {
float sigma_from_sq = sigma_from * sigma_from;
float sigma_to_sq = sigma_to * sigma_to;
if (sigma_from_sq > 0.0f) {
float term = sigma_to_sq * (sigma_from_sq - sigma_to_sq) / sigma_from_sq;
if (term > 0.0f) {
sigma_up = eta * std::sqrt(term);
}
}
sigma_up = std::min(sigma_up, sigma_to);
float sigma_down_sq = sigma_to_sq - sigma_up * sigma_up;
sigma_down = sigma_down_sq > 0.0f ? std::sqrt(sigma_down_sq) : 0.0f;
}
float* vec_x = (float*)x->data;
float* vec_x0 = (float*)x0->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x0[j] = vec_x[j];
}
if (sigma_down == 0.0f || sigma_from == 0.0f) {
float* vec_denoised = (float*)denoised->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] = vec_denoised[j];
}
} else {
float t = t_fn(sigma_from);
float t_next = t_fn(sigma_down);
float h = t_next - t;
float a21 = c2 * phi1_fn(-h * c2);
float phi1_val = phi1_fn(-h);
float phi2_val = phi2_fn(-h);
float b2 = phi2_val / c2;
float b1 = phi1_val - b2;
float sigma_c2 = expf(-(t + h * c2));
float* vec_denoised = (float*)denoised->data;
float* vec_x2 = (float*)x2->data;
for (int j = 0; j < ggml_nelements(x); j++) {
float eps1 = vec_denoised[j] - vec_x0[j];
vec_x2[j] = vec_x0[j] + h * a21 * eps1;
}
ggml_tensor* denoised2 = model(x2, sigma_c2, i + 1);
if (denoised2 == nullptr) {
return false;
}
float* vec_denoised2 = (float*)denoised2->data;
for (int j = 0; j < ggml_nelements(x); j++) {
float eps1 = vec_denoised[j] - vec_x0[j];
float eps2 = vec_denoised2[j] - vec_x0[j];
vec_x[j] = vec_x0[j] + h * (b1 * eps1 + b2 * eps2);
}
}
if (sigmas[i + 1] > 0 && sigma_up > 0.0f) {
ggml_ext_im_set_randn_f32(noise, rng);
float* vec_x = (float*)x->data;
float* vec_noise = (float*)noise->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] = vec_x[j] + vec_noise[j] * sigma_up;
}
}
}
} break;
default:
LOG_ERROR("Attempting to sample with nonexisting sample method %i", method);

View File

@ -38,7 +38,7 @@ struct DiffusionModel {
virtual size_t get_params_buffer_size() = 0;
virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter){};
virtual int64_t get_adm_in_channels() = 0;
virtual void set_flash_attention_enabled(bool enabled) = 0;
virtual void set_flash_attn_enabled(bool enabled) = 0;
virtual void set_circular_axes(bool circular_x, bool circular_y) = 0;
};
@ -84,7 +84,7 @@ struct UNetModel : public DiffusionModel {
return unet.unet.adm_in_channels;
}
void set_flash_attention_enabled(bool enabled) {
void set_flash_attn_enabled(bool enabled) {
unet.set_flash_attention_enabled(enabled);
}
@ -149,7 +149,7 @@ struct MMDiTModel : public DiffusionModel {
return 768 + 1280;
}
void set_flash_attention_enabled(bool enabled) {
void set_flash_attn_enabled(bool enabled) {
mmdit.set_flash_attention_enabled(enabled);
}
@ -215,7 +215,7 @@ struct FluxModel : public DiffusionModel {
return 768;
}
void set_flash_attention_enabled(bool enabled) {
void set_flash_attn_enabled(bool enabled) {
flux.set_flash_attention_enabled(enabled);
}
@ -286,7 +286,7 @@ struct WanModel : public DiffusionModel {
return 768;
}
void set_flash_attention_enabled(bool enabled) {
void set_flash_attn_enabled(bool enabled) {
wan.set_flash_attention_enabled(enabled);
}
@ -357,7 +357,7 @@ struct QwenImageModel : public DiffusionModel {
return 768;
}
void set_flash_attention_enabled(bool enabled) {
void set_flash_attn_enabled(bool enabled) {
qwen_image.set_flash_attention_enabled(enabled);
}
@ -424,7 +424,7 @@ struct ZImageModel : public DiffusionModel {
return 768;
}
void set_flash_attention_enabled(bool enabled) {
void set_flash_attn_enabled(bool enabled) {
z_image.set_flash_attention_enabled(enabled);
}

View File

@ -7,9 +7,6 @@ You can run Z-Image with stable-diffusion.cpp on GPUs with 4GB of VRAM — or ev
- Download Z-Image-Turbo
- safetensors: https://huggingface.co/Comfy-Org/z_image_turbo/tree/main/split_files/diffusion_models
- gguf: https://huggingface.co/leejet/Z-Image-Turbo-GGUF/tree/main
- Download Z-Image
- safetensors: https://huggingface.co/Comfy-Org/z_image/tree/main/split_files/diffusion_models
- gguf: https://huggingface.co/unsloth/Z-Image-GGUF/tree/main
- Download vae
- safetensors: https://huggingface.co/black-forest-labs/FLUX.1-schnell/tree/main
- Download Qwen3 4b
@ -18,22 +15,12 @@ You can run Z-Image with stable-diffusion.cpp on GPUs with 4GB of VRAM — or ev
## Examples
### Z-Image-Turbo
```
.\bin\Release\sd-cli.exe --diffusion-model z_image_turbo-Q3_K.gguf --vae ..\..\ComfyUI\models\vae\ae.sft --llm ..\..\ComfyUI\models\text_encoders\Qwen3-4B-Instruct-2507-Q4_K_M.gguf -p "A cinematic, melancholic photograph of a solitary hooded figure walking through a sprawling, rain-slicked metropolis at night. The city lights are a chaotic blur of neon orange and cool blue, reflecting on the wet asphalt. The scene evokes a sense of being a single component in a vast machine. Superimposed over the image in a sleek, modern, slightly glitched font is the philosophical quote: 'THE CITY IS A CIRCUIT BOARD, AND I AM A BROKEN TRANSISTOR.' -- moody, atmospheric, profound, dark academic" --cfg-scale 1.0 -v --offload-to-cpu --diffusion-fa -H 1024 -W 512
```
<img width="256" alt="z-image example" src="../assets/z_image/q3_K.png" />
### Z-Image-Base
```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\z_image_bf16.safetensors --vae ..\..\ComfyUI\models\vae\ae.sft --llm ..\..\ComfyUI\models\text_encoders\qwen_3_4b.safetensors -p "A cinematic, melancholic photograph of a solitary hooded figure walking through a sprawling, rain-slicked metropolis at night. The city lights are a chaotic blur of neon orange and cool blue, reflecting on the wet asphalt. The scene evokes a sense of being a single component in a vast machine. Superimposed over the image in a sleek, modern, slightly glitched font is the philosophical quote: 'THE CITY IS A CIRCUIT BOARD, AND I AM A BROKEN TRANSISTOR.' -- moody, atmospheric, profound, dark academic" --cfg-scale 5.0 -v --offload-to-cpu --diffusion-fa -H 1024 -W 512
```
<img width="256" alt="z-image example" src="../assets/z_image/base_bf16.png" />
## Comparison of Different Quantization Types
| bf16 | q8_0 | q6_K | q5_0 | q4_K | q4_0 | q3_K | q2_K|

View File

@ -51,7 +51,7 @@ public:
x_cat = ggml_concat(ctx->ggml_ctx, x_cat, x4, 2);
auto x5 = conv5->forward(ctx, x_cat);
x5 = ggml_add(ctx->ggml_ctx, ggml_ext_scale(ctx->ggml_ctx, x5, 0.2f), x);
x5 = ggml_add(ctx->ggml_ctx, ggml_scale(ctx->ggml_ctx, x5, 0.2f), x);
return x5;
}
};
@ -76,7 +76,7 @@ public:
out = rdb2->forward(ctx, out);
out = rdb3->forward(ctx, out);
out = ggml_add(ctx->ggml_ctx, ggml_ext_scale(ctx->ggml_ctx, out, 0.2f), x);
out = ggml_add(ctx->ggml_ctx, ggml_scale(ctx->ggml_ctx, out, 0.2f), x);
return out;
}
};

View File

@ -52,8 +52,7 @@ Context Options:
--control-net-cpu keep controlnet in cpu (for low vram)
--clip-on-cpu keep clip in cpu (for low vram)
--vae-on-cpu keep vae in cpu (for low vram)
--fa use flash attention
--diffusion-fa use flash attention in the diffusion model only
--diffusion-fa use flash attention in the diffusion model
--diffusion-conv-direct use ggml_conv2d_direct in the diffusion model
--vae-conv-direct use ggml_conv2d_direct in the vae model
--circular enable circular padding for convolutions
@ -108,14 +107,14 @@ Generation Options:
medium
--skip-layer-start <float> SLG enabling point (default: 0.01)
--skip-layer-end <float> SLG disabling point (default: 0.2)
--eta <float> eta in DDIM, only for DDIM/TCD/res_multistep/res_2s (default: 0)
--eta <float> eta in DDIM, only for DDIM and TCD (default: 0)
--high-noise-cfg-scale <float> (high noise) unconditional guidance scale: (default: 7.0)
--high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale)
--high-noise-guidance <float> (high noise) distilled guidance scale for models with guidance input (default: 3.5)
--high-noise-slg-scale <float> (high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0)
--high-noise-skip-layer-start <float> (high noise) SLG enabling point (default: 0.01)
--high-noise-skip-layer-end <float> (high noise) SLG disabling point (default: 0.2)
--high-noise-eta <float> (high noise) eta in DDIM, only for DDIM/TCD/res_multistep/res_2s (default: 0)
--high-noise-eta <float> (high noise) eta in DDIM, only for DDIM and TCD (default: 0)
--strength <float> strength for noising/unnoising (default: 0.75)
--pm-style-strength <float>
--control-strength <float> strength to apply Control Net (default: 0.9). 1.0 corresponds to full destruction of information in init image
@ -124,12 +123,12 @@ Generation Options:
--increase-ref-index automatically increase the indices of references images based on the order they are listed (starting with 1).
--disable-auto-resize-ref-image disable auto resize of ref images
-s, --seed RNG seed (default: 42, use random seed for < 0)
--sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd,
res_multistep, res_2s] (default: euler for Flux/SD3/Wan, euler_a otherwise)
--high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing,
tcd, res_multistep, res_2s] default: euler for Flux/SD3/Wan, euler_a otherwise
--sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing,
tcd] (default: euler for Flux/SD3/Wan, euler_a otherwise)
--high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm,
ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple,
kl_optimal, lcm, bong_tangent], default: discrete
kl_optimal, lcm], default: discrete
--sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0").
--skip-layers layers to skip for SLG steps (default: [7,8,9])
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])

View File

@ -245,7 +245,7 @@ std::string get_image_params(const SDCliParams& cli_params, const SDContextParam
parameter_string += "Guidance: " + std::to_string(gen_params.sample_params.guidance.distilled_guidance) + ", ";
parameter_string += "Eta: " + std::to_string(gen_params.sample_params.eta) + ", ";
parameter_string += "Seed: " + std::to_string(seed) + ", ";
parameter_string += "Size: " + std::to_string(gen_params.get_resolved_width()) + "x" + std::to_string(gen_params.get_resolved_height()) + ", ";
parameter_string += "Size: " + std::to_string(gen_params.width) + "x" + std::to_string(gen_params.height) + ", ";
parameter_string += "Model: " + sd_basename(ctx_params.model_path) + ", ";
parameter_string += "RNG: " + std::string(sd_rng_type_name(ctx_params.rng_type)) + ", ";
if (ctx_params.sampler_rng_type != RNG_TYPE_COUNT) {
@ -526,10 +526,10 @@ int main(int argc, const char* argv[]) {
}
bool vae_decode_only = true;
sd_image_t init_image = {0, 0, 3, nullptr};
sd_image_t end_image = {0, 0, 3, nullptr};
sd_image_t control_image = {0, 0, 3, nullptr};
sd_image_t mask_image = {0, 0, 1, nullptr};
sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
sd_image_t end_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
sd_image_t mask_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 1, nullptr};
std::vector<sd_image_t> ref_images;
std::vector<sd_image_t> pmid_images;
std::vector<sd_image_t> control_frames;
@ -556,79 +556,57 @@ int main(int argc, const char* argv[]) {
control_frames.clear();
};
auto load_image_and_update_size = [&](const std::string& path,
sd_image_t& image,
bool resize_image = true,
int expected_channel = 3) -> bool {
int expected_width = 0;
int expected_height = 0;
if (resize_image && gen_params.width_and_height_are_set()) {
expected_width = gen_params.width;
expected_height = gen_params.height;
}
if (!load_sd_image_from_file(&image, path.c_str(), expected_width, expected_height, expected_channel)) {
LOG_ERROR("load image from '%s' failed", path.c_str());
release_all_resources();
return false;
}
gen_params.set_width_and_height_if_unset(image.width, image.height);
return true;
};
if (gen_params.init_image_path.size() > 0) {
vae_decode_only = false;
if (!load_image_and_update_size(gen_params.init_image_path, init_image)) {
int width = 0;
int height = 0;
init_image.data = load_image_from_file(gen_params.init_image_path.c_str(), width, height, gen_params.width, gen_params.height);
if (init_image.data == nullptr) {
LOG_ERROR("load image from '%s' failed", gen_params.init_image_path.c_str());
release_all_resources();
return 1;
}
}
if (gen_params.end_image_path.size() > 0) {
vae_decode_only = false;
if (!load_image_and_update_size(gen_params.init_image_path, end_image)) {
return 1;
}
}
if (gen_params.ref_image_paths.size() > 0) {
vae_decode_only = false;
for (auto& path : gen_params.ref_image_paths) {
sd_image_t ref_image = {0, 0, 3, nullptr};
if (!load_image_and_update_size(path, ref_image, false)) {
int width = 0;
int height = 0;
end_image.data = load_image_from_file(gen_params.end_image_path.c_str(), width, height, gen_params.width, gen_params.height);
if (end_image.data == nullptr) {
LOG_ERROR("load image from '%s' failed", gen_params.end_image_path.c_str());
release_all_resources();
return 1;
}
ref_images.push_back(ref_image);
}
}
if (gen_params.mask_image_path.size() > 0) {
if (!load_sd_image_from_file(&mask_image,
gen_params.mask_image_path.c_str(),
gen_params.get_resolved_width(),
gen_params.get_resolved_height(),
1)) {
int c = 0;
int width = 0;
int height = 0;
mask_image.data = load_image_from_file(gen_params.mask_image_path.c_str(), width, height, gen_params.width, gen_params.height, 1);
if (mask_image.data == nullptr) {
LOG_ERROR("load image from '%s' failed", gen_params.mask_image_path.c_str());
release_all_resources();
return 1;
}
} else {
mask_image.data = (uint8_t*)malloc(gen_params.get_resolved_width() * gen_params.get_resolved_height());
mask_image.data = (uint8_t*)malloc(gen_params.width * gen_params.height);
if (mask_image.data == nullptr) {
LOG_ERROR("malloc mask image failed");
release_all_resources();
return 1;
}
mask_image.width = gen_params.get_resolved_width();
mask_image.height = gen_params.get_resolved_height();
memset(mask_image.data, 255, gen_params.get_resolved_width() * gen_params.get_resolved_height());
memset(mask_image.data, 255, gen_params.width * gen_params.height);
}
if (gen_params.control_image_path.size() > 0) {
if (!load_sd_image_from_file(&control_image,
gen_params.control_image_path.c_str(),
gen_params.get_resolved_width(),
gen_params.get_resolved_height())) {
int width = 0;
int height = 0;
control_image.data = load_image_from_file(gen_params.control_image_path.c_str(), width, height, gen_params.width, gen_params.height);
if (control_image.data == nullptr) {
LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str());
release_all_resources();
return 1;
@ -643,11 +621,29 @@ int main(int argc, const char* argv[]) {
}
}
if (gen_params.ref_image_paths.size() > 0) {
vae_decode_only = false;
for (auto& path : gen_params.ref_image_paths) {
int width = 0;
int height = 0;
uint8_t* image_buffer = load_image_from_file(path.c_str(), width, height);
if (image_buffer == nullptr) {
LOG_ERROR("load image from '%s' failed", path.c_str());
release_all_resources();
return 1;
}
ref_images.push_back({(uint32_t)width,
(uint32_t)height,
3,
image_buffer});
}
}
if (!gen_params.control_video_path.empty()) {
if (!load_images_from_dir(gen_params.control_video_path,
control_frames,
gen_params.get_resolved_width(),
gen_params.get_resolved_height(),
gen_params.width,
gen_params.height,
gen_params.video_frames,
cli_params.verbose)) {
release_all_resources();
@ -721,8 +717,8 @@ int main(int argc, const char* argv[]) {
gen_params.auto_resize_ref_image,
gen_params.increase_ref_index,
mask_image,
gen_params.get_resolved_width(),
gen_params.get_resolved_height(),
gen_params.width,
gen_params.height,
gen_params.sample_params,
gen_params.strength,
gen_params.seed,
@ -752,8 +748,8 @@ int main(int argc, const char* argv[]) {
end_image,
control_frames.data(),
(int)control_frames.size(),
gen_params.get_resolved_width(),
gen_params.get_resolved_height(),
gen_params.width,
gen_params.height,
gen_params.sample_params,
gen_params.high_noise_sample_params,
gen_params.moe_boundary,

View File

@ -445,7 +445,7 @@ struct SDContextParams {
std::string photo_maker_path;
sd_type_t wtype = SD_TYPE_COUNT;
std::string tensor_type_rules;
std::string lora_model_dir = ".";
std::string lora_model_dir;
std::map<std::string, std::string> embedding_map;
std::vector<sd_embedding_t> embedding_vec;
@ -457,7 +457,6 @@ struct SDContextParams {
bool control_net_cpu = false;
bool clip_on_cpu = false;
bool vae_on_cpu = false;
bool flash_attn = false;
bool diffusion_flash_attn = false;
bool diffusion_conv_direct = false;
bool vae_conv_direct = false;
@ -616,13 +615,9 @@ struct SDContextParams {
"--vae-on-cpu",
"keep vae in cpu (for low vram)",
true, &vae_on_cpu},
{"",
"--fa",
"use flash attention",
true, &flash_attn},
{"",
"--diffusion-fa",
"use flash attention in the diffusion model only",
"use flash attention in the diffusion model",
true, &diffusion_flash_attn},
{"",
"--diffusion-conv-direct",
@ -909,7 +904,6 @@ struct SDContextParams {
<< " control_net_cpu: " << (control_net_cpu ? "true" : "false") << ",\n"
<< " clip_on_cpu: " << (clip_on_cpu ? "true" : "false") << ",\n"
<< " vae_on_cpu: " << (vae_on_cpu ? "true" : "false") << ",\n"
<< " flash_attn: " << (flash_attn ? "true" : "false") << ",\n"
<< " diffusion_flash_attn: " << (diffusion_flash_attn ? "true" : "false") << ",\n"
<< " diffusion_conv_direct: " << (diffusion_conv_direct ? "true" : "false") << ",\n"
<< " vae_conv_direct: " << (vae_conv_direct ? "true" : "false") << ",\n"
@ -974,7 +968,6 @@ struct SDContextParams {
clip_on_cpu,
control_net_cpu,
vae_on_cpu,
flash_attn,
diffusion_flash_attn,
taesd_preview,
diffusion_conv_direct,
@ -1031,8 +1024,8 @@ struct SDGenerationParams {
std::string prompt_with_lora; // for metadata record only
std::string negative_prompt;
int clip_skip = -1; // <= 0 represents unspecified
int width = -1;
int height = -1;
int width = 512;
int height = 512;
int batch_count = 1;
std::string init_image_path;
std::string end_image_path;
@ -1485,17 +1478,17 @@ struct SDGenerationParams {
on_seed_arg},
{"",
"--sampling-method",
"sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep, res_2s] "
"sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd] "
"(default: euler for Flux/SD3/Wan, euler_a otherwise)",
on_sample_method_arg},
{"",
"--high-noise-sampling-method",
"(high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep, res_2s]"
"(high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd]"
" default: euler for Flux/SD3/Wan, euler_a otherwise",
on_high_noise_sample_method_arg},
{"",
"--scheduler",
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent], default: discrete",
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm], default: discrete",
on_scheduler_arg},
{"",
"--sigmas",
@ -1712,24 +1705,17 @@ struct SDGenerationParams {
}
}
bool width_and_height_are_set() const {
return width > 0 && height > 0;
}
void set_width_and_height_if_unset(int w, int h) {
if (!width_and_height_are_set()) {
LOG_INFO("set width x height to %d x %d", w, h);
width = w;
height = h;
}
}
int get_resolved_width() const { return (width > 0) ? width : 512; }
int get_resolved_height() const { return (height > 0) ? height : 512; }
bool process_and_check(SDMode mode, const std::string& lora_model_dir) {
prompt_with_lora = prompt;
if (width <= 0) {
LOG_ERROR("error: the width must be greater than 0\n");
return false;
}
if (height <= 0) {
LOG_ERROR("error: the height must be greater than 0\n");
return false;
}
if (sample_params.sample_steps <= 0) {
LOG_ERROR("error: the sample_steps must be greater than 0\n");
@ -2097,22 +2083,6 @@ uint8_t* load_image_from_file(const char* image_path,
return load_image_common(false, image_path, 0, width, height, expected_width, expected_height, expected_channel);
}
bool load_sd_image_from_file(sd_image_t* image,
const char* image_path,
int expected_width = 0,
int expected_height = 0,
int expected_channel = 3) {
int width;
int height;
image->data = load_image_common(false, image_path, 0, width, height, expected_width, expected_height, expected_channel);
if (image->data == nullptr) {
return false;
}
image->width = width;
image->height = height;
return true;
}
uint8_t* load_image_from_memory(const char* image_bytes,
int len,
int& width,

View File

@ -44,8 +44,7 @@ Context Options:
--clip-on-cpu keep clip in cpu (for low vram)
--vae-on-cpu keep vae in cpu (for low vram)
--mmap whether to memory-map model
--fa use flash attention
--diffusion-fa use flash attention in the diffusion model only
--diffusion-fa use flash attention in the diffusion model
--diffusion-conv-direct use ggml_conv2d_direct in the diffusion model
--vae-conv-direct use ggml_conv2d_direct in the vae model
--circular enable circular padding for convolutions
@ -100,14 +99,14 @@ Default Generation Options:
medium
--skip-layer-start <float> SLG enabling point (default: 0.01)
--skip-layer-end <float> SLG disabling point (default: 0.2)
--eta <float> eta in DDIM, only for DDIM/TCD/res_multistep/res_2s (default: 0)
--eta <float> eta in DDIM, only for DDIM and TCD (default: 0)
--high-noise-cfg-scale <float> (high noise) unconditional guidance scale: (default: 7.0)
--high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale)
--high-noise-guidance <float> (high noise) distilled guidance scale for models with guidance input (default: 3.5)
--high-noise-slg-scale <float> (high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0)
--high-noise-skip-layer-start <float> (high noise) SLG enabling point (default: 0.01)
--high-noise-skip-layer-end <float> (high noise) SLG disabling point (default: 0.2)
--high-noise-eta <float> (high noise) eta in DDIM, only for DDIM/TCD/res_multistep/res_2s (default: 0)
--high-noise-eta <float> (high noise) eta in DDIM, only for DDIM and TCD (default: 0)
--strength <float> strength for noising/unnoising (default: 0.75)
--pm-style-strength <float>
--control-strength <float> strength to apply Control Net (default: 0.9). 1.0 corresponds to full destruction of information in init image
@ -116,12 +115,12 @@ Default Generation Options:
--increase-ref-index automatically increase the indices of references images based on the order they are listed (starting with 1).
--disable-auto-resize-ref-image disable auto resize of ref images
-s, --seed RNG seed (default: 42, use random seed for < 0)
--sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd,
res_multistep, res_2s] (default: euler for Flux/SD3/Wan, euler_a otherwise)
--high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing,
tcd, res_multistep, res_2s] default: euler for Flux/SD3/Wan, euler_a otherwise
--sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing,
tcd] (default: euler for Flux/SD3/Wan, euler_a otherwise)
--high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm,
ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple,
kl_optimal, lcm, bong_tangent], default: discrete
kl_optimal, lcm], default: discrete
--sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0").
--skip-layers layers to skip for SLG steps (default: [7,8,9])
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])

View File

@ -785,11 +785,7 @@ int main(int argc, const char** argv) {
{"lcm", LCM_SAMPLE_METHOD},
{"ddim", DDIM_TRAILING_SAMPLE_METHOD},
{"dpm++ 2m", DPMPP2M_SAMPLE_METHOD},
{"k_dpmpp_2m", DPMPP2M_SAMPLE_METHOD},
{"res multistep", RES_MULTISTEP_SAMPLE_METHOD},
{"k_res_multistep", RES_MULTISTEP_SAMPLE_METHOD},
{"res 2s", RES_2S_SAMPLE_METHOD},
{"k_res_2s", RES_2S_SAMPLE_METHOD}};
{"k_dpmpp_2m", DPMPP2M_SAMPLE_METHOD}};
auto it = hardcoded.find(name);
if (it != hardcoded.end()) return it->second;
return SAMPLE_METHOD_COUNT;

View File

@ -103,7 +103,7 @@ namespace Flux {
auto norm = std::dynamic_pointer_cast<QKNorm>(blocks["norm"]);
auto qkv = qkv_proj->forward(ctx, x);
auto qkv_vec = ggml_ext_chunk(ctx->ggml_ctx, qkv, 3, 0, true);
auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv);
int64_t head_dim = qkv_vec[0]->ne[0] / num_heads;
auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]);
auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]);
@ -153,7 +153,7 @@ namespace Flux {
if (use_mlp_silu_act) {
x = ggml_ext_silu_act(ctx->ggml_ctx, x);
} else {
x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
x = ggml_gelu_inplace(ctx->ggml_ctx, x);
}
x = mlp_2->forward(ctx, x);
return x;
@ -377,22 +377,25 @@ namespace Flux {
auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_txt_token + n_img_token, n_head*d_head]
attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx,
attn,
attn->ne[0],
attn->ne[1],
txt->ne[1],
attn->ne[2],
attn->nb[1],
attn->nb[2],
0); // [N, n_txt_token, hidden_size]
0); // [n_txt_token, N, hidden_size]
txt_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size]
auto img_attn_out = ggml_view_3d(ctx->ggml_ctx,
attn,
attn->ne[0],
attn->ne[1],
img->ne[1],
attn->ne[2],
attn->nb[1],
attn->nb[2],
txt->ne[1] * attn->nb[1]); // [N, n_img_token, hidden_size]
attn->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size]
img_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
// calculate the img bloks
img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate));
@ -489,29 +492,43 @@ namespace Flux {
}
auto x_mod = Flux::modulate(ctx->ggml_ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale);
auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim*mlp_mult_factor]
auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim]
qkv_mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv_mlp, 2, 0, 1, 3)); // [hidden_size * 3 + mlp_hidden_dim, N, n_token]
auto q = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], 0);
auto k = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * qkv_mlp->nb[0]);
auto v = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * 2 * qkv_mlp->nb[0]);
auto qkv = ggml_view_3d(ctx->ggml_ctx,
qkv_mlp,
qkv_mlp->ne[0],
qkv_mlp->ne[1],
hidden_size * 3,
qkv_mlp->nb[1],
qkv_mlp->nb[2],
0); // [hidden_size * 3 , N, n_token]
qkv = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv, 1, 2, 0, 3)); // [N, n_token, hidden_size * 3]
auto mlp = ggml_view_3d(ctx->ggml_ctx,
qkv_mlp,
qkv_mlp->ne[0],
qkv_mlp->ne[1],
mlp_hidden_dim * mlp_mult_factor,
qkv_mlp->nb[1],
qkv_mlp->nb[2],
qkv_mlp->nb[2] * hidden_size * 3); // [mlp_hidden_dim*mlp_mult_factor , N, n_token]
mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, mlp, 1, 2, 0, 3)); // [N, n_token, mlp_hidden_dim*mlp_mult_factor]
auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv); // q,k,v: [N, n_token, hidden_size]
int64_t head_dim = hidden_size / num_heads;
q = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, q), head_dim, num_heads, q->ne[1], q->ne[2]); // [N, n_token, n_head, d_head]
k = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, k), head_dim, num_heads, k->ne[1], k->ne[2]); // [N, n_token, n_head, d_head]
v = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, v), head_dim, num_heads, v->ne[1], v->ne[2]); // [N, n_token, n_head, d_head]
auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head]
auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head]
auto v = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head]
q = norm->query_norm(ctx, q);
k = norm->key_norm(ctx, k);
auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_token, hidden_size]
auto mlp = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, mlp_hidden_dim * mlp_mult_factor, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * 3 * qkv_mlp->nb[0]);
if (use_yak_mlp) {
mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp, false);
} else if (use_mlp_silu_act) {
mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp);
} else {
mlp = ggml_ext_gelu(ctx->ggml_ctx, mlp, true);
mlp = ggml_gelu_inplace(ctx->ggml_ctx, mlp);
}
auto attn_mlp = ggml_concat(ctx->ggml_ctx, attn, mlp, 0); // [N, n_token, hidden_size + mlp_hidden_dim]
auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size]
@ -564,9 +581,12 @@ namespace Flux {
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 2 * hidden_size]
auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, 2, 0);
shift = m_vec[0]; // [N, hidden_size]
scale = m_vec[1]; // [N, hidden_size]
m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size]
m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size]
int64_t offset = m->nb[1] * m->ne[1];
shift = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
scale = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
}
x = Flux::modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale);
@ -728,7 +748,7 @@ namespace Flux {
int nerf_depth = 4;
int nerf_max_freqs = 8;
bool use_x0 = false;
bool fake_patch_size_x2 = false;
bool use_patch_size_32 = false;
};
struct FluxParams {
@ -766,10 +786,7 @@ namespace Flux {
Flux(FluxParams params)
: params(params) {
if (params.version == VERSION_CHROMA_RADIANCE) {
std::pair<int, int> kernel_size = {params.patch_size, params.patch_size};
if (params.chroma_radiance_params.fake_patch_size_x2) {
kernel_size = {params.patch_size / 2, params.patch_size / 2};
}
std::pair<int, int> kernel_size = {16, 16};
std::pair<int, int> stride = kernel_size;
blocks["img_in_patch"] = std::make_shared<Conv2d>(params.in_channels,
@ -1014,14 +1031,16 @@ namespace Flux {
txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask, ss_mods);
}
txt_img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_img, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
img = ggml_view_3d(ctx->ggml_ctx,
txt_img,
txt_img->ne[0],
txt_img->ne[1],
img->ne[1],
txt_img->ne[2],
txt_img->nb[1],
txt_img->nb[2],
txt->ne[1] * txt_img->nb[1]); // [N, n_img_token, hidden_size]
txt_img->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size]
img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
if (final_layer) {
img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels)
@ -1063,7 +1082,7 @@ namespace Flux {
auto img = pad_to_patch_size(ctx, x);
auto orig_img = img;
if (params.chroma_radiance_params.fake_patch_size_x2) {
if (params.chroma_radiance_params.use_patch_size_32) {
// It's supposed to be using GGML_SCALE_MODE_NEAREST, but this seems more stable
// Maybe the implementation of nearest-neighbor interpolation in ggml behaves differently than the one in PyTorch?
// img = F.interpolate(img, size=(H//2, W//2), mode="nearest")
@ -1174,8 +1193,9 @@ namespace Flux {
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size]
if (out->ne[1] > img_tokens) {
out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], img_tokens, out->ne[2], out->nb[1], out->nb[2], 0);
out = ggml_cont(ctx->ggml_ctx, out);
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size]
out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size]
}
// rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
@ -1284,7 +1304,6 @@ namespace Flux {
flux_params.use_mlp_silu_act = true;
}
int64_t head_dim = 0;
int64_t actual_radiance_patch_size = -1;
for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first;
if (!starts_with(tensor_name, prefix))
@ -1297,13 +1316,10 @@ namespace Flux {
flux_params.chroma_radiance_params.use_x0 = true;
}
if (tensor_name.find("__32x32__") != std::string::npos) {
LOG_DEBUG("using patch size 32");
LOG_DEBUG("using patch size 32 prediction");
flux_params.chroma_radiance_params.use_patch_size_32 = true;
flux_params.patch_size = 32;
}
if (tensor_name.find("img_in_patch.weight") != std::string::npos) {
actual_radiance_patch_size = pair.second.ne[0];
LOG_DEBUG("actual radiance patch size: %d", actual_radiance_patch_size);
}
if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
// Chroma
flux_params.is_chroma = true;
@ -1335,11 +1351,6 @@ namespace Flux {
head_dim = pair.second.ne[0];
}
}
if (actual_radiance_patch_size > 0 && actual_radiance_patch_size != flux_params.patch_size) {
GGML_ASSERT(flux_params.patch_size == 2 * actual_radiance_patch_size);
LOG_DEBUG("using fake x2 patch size");
flux_params.chroma_radiance_params.fake_patch_size_x2 = true;
}
flux_params.num_heads = static_cast<int>(flux_params.hidden_size / head_dim);

2
ggml

@ -1 +1 @@
Subproject commit a8db410a252c8c8f2d120c6f2e7133ebe032f35d
Subproject commit 8891ab6fc742ac1198736d3da3b73c730e42af84

View File

@ -687,8 +687,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_slice(struct ggml_context* ctx,
struct ggml_tensor* x,
int dim,
int64_t start,
int64_t end,
bool cont = true) {
int64_t end) {
GGML_ASSERT(dim >= 0 && dim < 4);
if (x->ne[dim] == 1) {
return x;
@ -703,15 +702,27 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_slice(struct ggml_context* ctx,
GGML_ASSERT(start >= 0 && start < x->ne[dim]);
GGML_ASSERT(end > start && end <= x->ne[dim]);
int64_t slice_size = end - start;
int64_t slice_ne[4] = {x->ne[0], x->ne[1], x->ne[2], x->ne[3]};
slice_ne[dim] = slice_size;
int perm[4] = {0, 1, 2, 3};
for (int i = dim; i < 3; ++i)
perm[i] = perm[i + 1];
perm[3] = dim;
x = ggml_view_4d(ctx, x,
slice_ne[0], slice_ne[1], slice_ne[2], slice_ne[3],
x->nb[1], x->nb[2], x->nb[3], start * x->nb[dim]);
int inv_perm[4];
for (int i = 0; i < 4; ++i)
inv_perm[perm[i]] = i;
if (cont) {
if (dim != 3) {
x = ggml_ext_torch_permute(ctx, x, perm[0], perm[1], perm[2], perm[3]);
x = ggml_cont(ctx, x);
}
x = ggml_view_4d(
ctx, x,
x->ne[0], x->ne[1], x->ne[2], end - start,
x->nb[1], x->nb[2], x->nb[3], x->nb[3] * start);
if (dim != 3) {
x = ggml_ext_torch_permute(ctx, x, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]);
x = ggml_cont(ctx, x);
}
@ -949,49 +960,6 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_group_norm_32(struct ggml_context
return ggml_group_norm(ctx, a, 32, eps);
}
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_scale(struct ggml_context* ctx,
struct ggml_tensor* x,
float factor,
bool inplace = false) {
if (!ggml_is_contiguous(x)) {
x = ggml_cont(ctx, x);
}
if (inplace) {
x = ggml_scale_inplace(ctx, x, factor);
} else {
x = ggml_scale(ctx, x, factor);
}
return x;
}
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_gelu(struct ggml_context* ctx,
struct ggml_tensor* x,
bool inplace = false) {
if (!ggml_is_contiguous(x)) {
x = ggml_cont(ctx, x);
}
if (inplace) {
x = ggml_gelu_inplace(ctx, x);
} else {
x = ggml_gelu(ctx, x);
}
return x;
}
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_gelu_quick(struct ggml_context* ctx,
struct ggml_tensor* x,
bool inplace = false) {
if (!ggml_is_contiguous(x)) {
x = ggml_cont(ctx, x);
}
if (inplace) {
x = ggml_gelu_quick_inplace(ctx, x);
} else {
x = ggml_gelu_quick(ctx, x);
}
return x;
}
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* w,
@ -999,7 +967,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx,
bool force_prec_f32 = false,
float scale = 1.f) {
if (scale != 1.f) {
x = ggml_ext_scale(ctx, x, scale);
x = ggml_scale(ctx, x, scale);
}
if (x->ne[2] * x->ne[3] > 1024) {
// workaround: avoid ggml cuda error
@ -1018,7 +986,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx,
}
}
if (scale != 1.f) {
x = ggml_ext_scale(ctx, x, 1.f / scale);
x = ggml_scale(ctx, x, 1.f / scale);
}
if (b != nullptr) {
x = ggml_add_inplace(ctx, x, b);
@ -1087,7 +1055,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_conv_2d(struct ggml_context* ctx,
bool circular_y = false,
float scale = 1.f) {
if (scale != 1.f) {
x = ggml_ext_scale(ctx, x, scale);
x = ggml_scale(ctx, x, scale);
}
if (w->ne[2] != x->ne[2] && ggml_n_dims(w) == 2) {
w = ggml_reshape_4d(ctx, w, 1, 1, w->ne[0], w->ne[1]);
@ -1105,7 +1073,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_conv_2d(struct ggml_context* ctx,
x = ggml_conv_2d(ctx, w, x, s0, s1, p0, p1, d0, d1);
}
if (scale != 1.f) {
x = ggml_ext_scale(ctx, x, 1.f / scale);
x = ggml_scale(ctx, x, 1.f / scale);
}
if (b != nullptr) {
b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
@ -1203,7 +1171,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_full(struct ggml_context* ctx,
int64_t ne2,
int64_t ne3) {
auto one = ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:one");
auto t = ggml_ext_scale(ctx, one, value); // [1,]
auto t = ggml_scale(ctx, one, value); // [1,]
t = ggml_repeat_4d(ctx, t, ne0, ne1, ne2, ne3); // [ne0, ne1, ne2, ne3]
return t;
}
@ -1257,6 +1225,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
struct ggml_tensor* v,
int64_t n_head,
struct ggml_tensor* mask = nullptr,
bool diag_mask_inf = false,
bool skip_reshape = false,
bool flash_attn = false,
float kv_scale = 1.0f) { // avoid overflow
@ -1302,7 +1271,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
k_in = ggml_pad(ctx, k_in, 0, kv_pad, 0, 0);
}
if (kv_scale != 1.0f) {
k_in = ggml_ext_scale(ctx, k_in, kv_scale);
k_in = ggml_scale(ctx, k_in, kv_scale);
}
k_in = ggml_cast(ctx, k_in, GGML_TYPE_F16);
@ -1312,7 +1281,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0);
}
if (kv_scale != 1.0f) {
v_in = ggml_ext_scale(ctx, v_in, kv_scale);
v_in = ggml_scale(ctx, v_in, kv_scale);
}
v_in = ggml_cast(ctx, v_in, GGML_TYPE_F16);
@ -1344,7 +1313,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
auto out = ggml_flash_attn_ext(ctx, q_in, k_in, v_in, mask_in, scale / kv_scale, 0, 0);
ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32);
if (kv_scale != 1.0f) {
out = ggml_ext_scale(ctx, out, 1.0f / kv_scale);
out = ggml_scale(ctx, out, 1.0f / kv_scale);
}
return out;
};
@ -1384,6 +1353,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
if (mask) {
kq = ggml_add_inplace(ctx, kq, mask);
}
if (diag_mask_inf) {
kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);
}
kq = ggml_soft_max_inplace(ctx, kq);
kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, L_q, d_head]
@ -1551,7 +1523,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_timestep_embedding(
int dim,
int max_period = 10000,
float time_factor = 1.0f) {
timesteps = ggml_ext_scale(ctx, timesteps, time_factor);
timesteps = ggml_scale(ctx, timesteps, time_factor);
return ggml_timestep_embedding(ctx, timesteps, dim, max_period);
}
@ -2600,7 +2572,7 @@ public:
// x: [N, n_token, embed_dim]
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* mask = nullptr) {
bool mask = false) {
auto out_proj = std::dynamic_pointer_cast<Linear>(blocks[out_proj_name]);
ggml_tensor* q;
@ -2623,7 +2595,7 @@ public:
v = v_proj->forward(ctx, x);
}
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, mask, false); // [N, n_token, embed_dim]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, mask); // [N, n_token, embed_dim]
x = out_proj->forward(ctx, x); // [N, n_token, embed_dim]
return x;

View File

@ -638,7 +638,7 @@ namespace LLM {
x = ln_q->forward(ctx, x);
x = ggml_reshape_2d(ctx->ggml_ctx, x, hidden_size, ggml_nelements(x) / hidden_size);
x = mlp_0->forward(ctx, x);
x = ggml_ext_gelu(ctx->ggml_ctx, x);
x = ggml_gelu(ctx->ggml_ctx, x);
x = mlp_2->forward(ctx, x);
return x;
}
@ -881,7 +881,7 @@ namespace LLM {
k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 2, 1, 3)); // [N, num_kv_heads, n_token, head_dim]
k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); // [N*num_kv_heads, n_token, head_dim]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, true, false); // [N, n_token, hidden_size]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, false, true, false); // [N, n_token, hidden_size]
x = out_proj->forward(ctx, x); // [N, n_token, hidden_size]
return x;

View File

@ -195,7 +195,7 @@ struct LoraModel : public GGMLRunner {
scale_value *= multiplier;
auto curr_updown = ggml_ext_merge_lora(ctx, lora_down, lora_up, lora_mid);
curr_updown = ggml_ext_scale(ctx, curr_updown, scale_value, true);
curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value);
if (updown == nullptr) {
updown = curr_updown;
@ -235,7 +235,7 @@ struct LoraModel : public GGMLRunner {
float scale_value = 1.0f;
scale_value *= multiplier;
curr_updown = ggml_ext_scale(ctx, curr_updown, scale_value, true);
curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value);
if (updown == nullptr) {
updown = curr_updown;
@ -340,7 +340,7 @@ struct LoraModel : public GGMLRunner {
struct ggml_tensor* updown_1 = ggml_ext_merge_lora(ctx, hada_1_down, hada_1_up, hada_1_mid);
struct ggml_tensor* updown_2 = ggml_ext_merge_lora(ctx, hada_2_down, hada_2_up, hada_2_mid);
auto curr_updown = ggml_mul_inplace(ctx, updown_1, updown_2);
curr_updown = ggml_ext_scale(ctx, curr_updown, scale_value, true);
curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value);
if (updown == nullptr) {
updown = curr_updown;
} else {
@ -456,7 +456,7 @@ struct LoraModel : public GGMLRunner {
scale_value *= multiplier;
auto curr_updown = ggml_ext_kronecker(ctx, lokr_w1, lokr_w2);
curr_updown = ggml_ext_scale(ctx, curr_updown, scale_value, true);
curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value);
if (updown == nullptr) {
updown = curr_updown;
@ -634,7 +634,7 @@ struct LoraModel : public GGMLRunner {
forward_params.conv2d.scale);
}
auto curr_out_diff = ggml_ext_scale(ctx, lx, scale_value, true);
auto curr_out_diff = ggml_scale_inplace(ctx, lx, scale_value);
if (out_diff == nullptr) {
out_diff = curr_out_diff;

View File

@ -33,7 +33,7 @@ public:
auto fc2 = std::dynamic_pointer_cast<Linear>(blocks["fc2"]);
x = fc1->forward(ctx, x);
x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
x = ggml_gelu_inplace(ctx->ggml_ctx, x);
x = fc2->forward(ctx, x);
return x;
}
@ -211,7 +211,7 @@ public:
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x) {
auto qkv = pre_attention(ctx, x);
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = post_attention(ctx, x); // [N, n_token, dim]
return x;
}
@ -284,19 +284,23 @@ public:
auto attn2 = std::dynamic_pointer_cast<SelfAttention>(blocks["attn2"]);
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
int n_mods = 9;
int64_t n_mods = 9;
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size]
auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, n_mods, 0);
m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size]
m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size]
auto shift_msa = m_vec[0]; // [N, hidden_size]
auto scale_msa = m_vec[1]; // [N, hidden_size]
auto gate_msa = m_vec[2]; // [N, hidden_size]
auto shift_mlp = m_vec[3]; // [N, hidden_size]
auto scale_mlp = m_vec[4]; // [N, hidden_size]
auto gate_mlp = m_vec[5]; // [N, hidden_size]
auto shift_msa2 = m_vec[6]; // [N, hidden_size]
auto scale_msa2 = m_vec[7]; // [N, hidden_size]
auto gate_msa2 = m_vec[8]; // [N, hidden_size]
int64_t offset = m->nb[1] * m->ne[1];
auto shift_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
auto scale_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
auto gate_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size]
auto shift_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size]
auto scale_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size]
auto gate_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size]
auto shift_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 6); // [N, hidden_size]
auto scale_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 7); // [N, hidden_size]
auto gate_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 8); // [N, hidden_size]
auto x_norm = norm1->forward(ctx, x);
@ -318,20 +322,22 @@ public:
auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
int n_mods = 6;
int64_t n_mods = 6;
if (pre_only) {
n_mods = 2;
}
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size]
auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, n_mods, 0);
m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size]
m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size]
auto shift_msa = m_vec[0]; // [N, hidden_size]
auto scale_msa = m_vec[1]; // [N, hidden_size]
int64_t offset = m->nb[1] * m->ne[1];
auto shift_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
auto scale_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
if (!pre_only) {
auto gate_msa = m_vec[2]; // [N, hidden_size]
auto shift_mlp = m_vec[3]; // [N, hidden_size]
auto scale_mlp = m_vec[4]; // [N, hidden_size]
auto gate_mlp = m_vec[5]; // [N, hidden_size]
auto gate_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size]
auto shift_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size]
auto scale_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size]
auto gate_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size]
auto attn_in = modulate(ctx->ggml_ctx, norm1->forward(ctx, x), shift_msa, scale_msa);
@ -433,8 +439,8 @@ public:
auto qkv2 = std::get<1>(qkv_intermediates);
auto intermediates = std::get<2>(qkv_intermediates);
auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
auto attn2_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv2[0], qkv2[1], qkv2[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
auto attn2_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv2[0], qkv2[1], qkv2[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = post_attention_x(ctx,
attn_out,
attn2_out,
@ -450,7 +456,7 @@ public:
auto qkv = qkv_intermediates.first;
auto intermediates = qkv_intermediates.second;
auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = post_attention(ctx,
attn_out,
intermediates[0],
@ -494,24 +500,26 @@ block_mixing(GGMLRunnerContext* ctx,
qkv.push_back(ggml_concat(ctx->ggml_ctx, context_qkv[i], x_qkv[i], 1));
}
auto attn = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_context + n_token, hidden_size]
auto attn = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_context + n_token, hidden_size]
attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_context + n_token, N, hidden_size]
auto context_attn = ggml_view_3d(ctx->ggml_ctx,
attn,
attn->ne[0],
attn->ne[1],
context->ne[1],
attn->ne[2],
attn->nb[1],
attn->nb[2],
0); // [N, n_context, hidden_size]
0); // [n_context, N, hidden_size]
context_attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, context_attn, 0, 2, 1, 3)); // [N, n_context, hidden_size]
auto x_attn = ggml_view_3d(ctx->ggml_ctx,
attn,
attn->ne[0],
attn->ne[1],
x->ne[1],
attn->ne[2],
attn->nb[1],
attn->nb[2],
context->ne[1] * attn->nb[1]); // [N, n_token, hidden_size]
attn->nb[2] * context->ne[1]); // [n_token, N, hidden_size]
x_attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x_attn, 0, 2, 1, 3)); // [N, n_token, hidden_size]
if (!context_block->pre_only) {
context = context_block->post_attention(ctx,
@ -526,7 +534,7 @@ block_mixing(GGMLRunnerContext* ctx,
}
if (x_block->self_attn) {
auto attn2 = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, hidden_size]
auto attn2 = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, hidden_size]
x = x_block->post_attention_x(ctx,
x_attn,
@ -597,9 +605,12 @@ public:
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 2 * hidden_size]
auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, 2, 0);
auto shift = m_vec[0]; // [N, hidden_size]
auto scale = m_vec[1]; // [N, hidden_size]
m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size]
m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size]
int64_t offset = m->nb[1] * m->ne[1];
auto shift = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
auto scale = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
x = modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale);
x = linear->forward(ctx, x);

View File

@ -376,11 +376,7 @@ bool ModelLoader::init_from_file(const std::string& file_path, const std::string
LOG_INFO("load %s using checkpoint format", file_path.c_str());
return init_from_ckpt_file(file_path, prefix);
} else {
if (file_exists(file_path)) {
LOG_WARN("unknown format %s", file_path.c_str());
} else {
LOG_WARN("file %s not found", file_path.c_str());
}
return false;
}
}

View File

@ -842,7 +842,6 @@ std::string convert_sep_to_dot(std::string name) {
"conv_in",
"conv_out",
"lora_down",
"lora_mid",
"lora_up",
"diff_b",
"hada_w1_a",
@ -998,13 +997,10 @@ std::string convert_tensor_name(std::string name, SDVersion version) {
if (is_lora) {
std::map<std::string, std::string> lora_suffix_map = {
{".lora_down.weight", ".weight.lora_down"},
{".lora_mid.weight", ".weight.lora_mid"},
{".lora_up.weight", ".weight.lora_up"},
{".lora.down.weight", ".weight.lora_down"},
{".lora.mid.weight", ".weight.lora_mid"},
{".lora.up.weight", ".weight.lora_up"},
{"_lora.down.weight", ".weight.lora_down"},
{"_lora.mid.weight", ".weight.lora_mid"},
{"_lora.up.weight", ".weight.lora_up"},
{".lora_A.weight", ".weight.lora_down"},
{".lora_B.weight", ".weight.lora_up"},

View File

@ -33,7 +33,7 @@ public:
x = layer_norm->forward(ctx, x);
// x = ggml_add(ctx, ggml_mul_mat(ctx, fc1_w, x), fc1_b);
x = fc1->forward(ctx, x);
x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
x = ggml_gelu_inplace(ctx->ggml_ctx, x);
x = fc2->forward(ctx, x);
// x = ggml_add(ctx, ggml_mul_mat(ctx, fc2_w, x), fc2_b);
if (use_residue)
@ -129,8 +129,8 @@ public:
k = reshape_tensor(ctx->ggml_ctx, k, heads);
v = reshape_tensor(ctx->ggml_ctx, v, heads);
scale = 1.f / sqrt(sqrt((float)dim_head));
k = ggml_ext_scale(ctx->ggml_ctx, k, scale, true);
q = ggml_ext_scale(ctx->ggml_ctx, q, scale, true);
k = ggml_scale_inplace(ctx->ggml_ctx, k, scale);
q = ggml_scale_inplace(ctx->ggml_ctx, q, scale);
// auto weight = ggml_mul_mat(ctx, q, k);
auto weight = ggml_mul_mat(ctx->ggml_ctx, k, q); // NOTE order of mul is opposite to pytorch

View File

@ -163,24 +163,25 @@ namespace Qwen {
auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto attn = Rope::attention(ctx, q, k, v, pe, mask, (1.0f / 128.f)); // [N, n_txt_token + n_img_token, n_head*d_head]
attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx,
attn,
attn->ne[0],
attn->ne[1],
txt->ne[1],
attn->ne[2],
attn->nb[1],
attn->nb[2],
0); // [N, n_txt_token, n_head*d_head]
0); // [n_txt_token, N, hidden_size]
txt_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size]
auto img_attn_out = ggml_view_3d(ctx->ggml_ctx,
attn,
attn->ne[0],
attn->ne[1],
img->ne[1],
attn->ne[2],
attn->nb[1],
attn->nb[2],
txt->ne[1] * attn->nb[1]); // [N, n_img_token, n_head*d_head]
img_attn_out = ggml_cont(ctx->ggml_ctx, img_attn_out);
txt_attn_out = ggml_cont(ctx->ggml_ctx, txt_attn_out);
attn->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size]
img_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
img_attn_out = to_out_0->forward(ctx, img_attn_out);
txt_attn_out = to_add_out->forward(ctx, txt_attn_out);
@ -212,7 +213,7 @@ namespace Qwen {
blocks["txt_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
blocks["txt_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
blocks["txt_mlp"] = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim, 4, FeedForward::Activation::GELU, true));
blocks["txt_mlp"] = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim, 4, FeedForward::Activation::GELU));
blocks["attn"] = std::shared_ptr<GGMLBlock>(new QwenImageAttention(dim,
attention_head_dim,

View File

@ -642,7 +642,7 @@ namespace Rope {
q = apply_rope(ctx->ggml_ctx, q, pe, rope_interleaved); // [N*n_head, L, d_head]
k = apply_rope(ctx->ggml_ctx, k, pe, rope_interleaved); // [N*n_head, L, d_head]
auto x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, v->ne[1], mask, true, ctx->flash_attn_enabled, kv_scale); // [N, L, n_head*d_head]
auto x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, v->ne[1], mask, false, true, ctx->flash_attn_enabled, kv_scale); // [N, L, n_head*d_head]
return x;
}
}; // namespace Rope

View File

@ -67,8 +67,6 @@ const char* sampling_methods_str[] = {
"LCM",
"DDIM \"trailing\"",
"TCD",
"Res Multistep",
"Res 2s",
};
/*================================================== Helper Functions ================================================*/
@ -445,7 +443,7 @@ public:
}
}
if (is_chroma) {
if ((sd_ctx_params->flash_attn || sd_ctx_params->diffusion_flash_attn) && sd_ctx_params->chroma_use_dit_mask) {
if (sd_ctx_params->diffusion_flash_attn && sd_ctx_params->chroma_use_dit_mask) {
LOG_WARN(
"!!!It looks like you are using Chroma with flash attention. "
"This is currently unsupported. "
@ -571,6 +569,14 @@ public:
}
}
if (sd_ctx_params->diffusion_flash_attn) {
LOG_INFO("Using flash attention in the diffusion model");
diffusion_model->set_flash_attn_enabled(true);
if (high_noise_diffusion_model) {
high_noise_diffusion_model->set_flash_attn_enabled(true);
}
}
cond_stage_model->alloc_params_buffer();
cond_stage_model->get_param_tensors(tensors);
@ -618,7 +624,7 @@ public:
LOG_INFO("Using Conv2d direct in the vae model");
first_stage_model->set_conv2d_direct_enabled(true);
}
if (sd_version_is_sdxl(version) &&
if (version == VERSION_SDXL &&
(strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale)) {
float vae_conv_2d_scale = 1.f / 32.f;
LOG_WARN(
@ -717,28 +723,6 @@ public:
pmid_model->get_param_tensors(tensors, "pmid");
}
if (sd_ctx_params->flash_attn) {
LOG_INFO("Using flash attention");
cond_stage_model->set_flash_attention_enabled(true);
if (clip_vision) {
clip_vision->set_flash_attention_enabled(true);
}
if (first_stage_model) {
first_stage_model->set_flash_attention_enabled(true);
}
if (tae_first_stage) {
tae_first_stage->set_flash_attention_enabled(true);
}
}
if (sd_ctx_params->flash_attn || sd_ctx_params->diffusion_flash_attn) {
LOG_INFO("Using flash attention in the diffusion model");
diffusion_model->set_flash_attention_enabled(true);
if (high_noise_diffusion_model) {
high_noise_diffusion_model->set_flash_attention_enabled(true);
}
}
diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
if (high_noise_diffusion_model) {
high_noise_diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
@ -2759,8 +2743,6 @@ const char* sample_method_to_str[] = {
"lcm",
"ddim_trailing",
"tcd",
"res_multistep",
"res_2s",
};
const char* sd_sample_method_name(enum sample_method_t sample_method) {
@ -2790,7 +2772,6 @@ const char* scheduler_to_str[] = {
"smoothstep",
"kl_optimal",
"lcm",
"bong_tangent",
};
const char* sd_scheduler_name(enum scheduler_t scheduler) {
@ -2956,7 +2937,6 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
"keep_clip_on_cpu: %s\n"
"keep_control_net_on_cpu: %s\n"
"keep_vae_on_cpu: %s\n"
"flash_attn: %s\n"
"diffusion_flash_attn: %s\n"
"circular_x: %s\n"
"circular_y: %s\n"
@ -2988,7 +2968,6 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
BOOL_STR(sd_ctx_params->keep_clip_on_cpu),
BOOL_STR(sd_ctx_params->keep_control_net_on_cpu),
BOOL_STR(sd_ctx_params->keep_vae_on_cpu),
BOOL_STR(sd_ctx_params->flash_attn),
BOOL_STR(sd_ctx_params->diffusion_flash_attn),
BOOL_STR(sd_ctx_params->circular_x),
BOOL_STR(sd_ctx_params->circular_y),

View File

@ -48,8 +48,6 @@ enum sample_method_t {
LCM_SAMPLE_METHOD,
DDIM_TRAILING_SAMPLE_METHOD,
TCD_SAMPLE_METHOD,
RES_MULTISTEP_SAMPLE_METHOD,
RES_2S_SAMPLE_METHOD,
SAMPLE_METHOD_COUNT
};
@ -64,7 +62,6 @@ enum scheduler_t {
SMOOTHSTEP_SCHEDULER,
KL_OPTIMAL_SCHEDULER,
LCM_SCHEDULER,
BONG_TANGENT_SCHEDULER,
SCHEDULER_COUNT
};
@ -189,7 +186,6 @@ typedef struct {
bool keep_clip_on_cpu;
bool keep_control_net_on_cpu;
bool keep_vae_on_cpu;
bool flash_attn;
bool diffusion_flash_attn;
bool tae_preview_only;
bool diffusion_conv_direct;

4
t5.hpp
View File

@ -515,7 +515,7 @@ public:
auto wi_1 = std::dynamic_pointer_cast<Linear>(blocks["wi_1"]);
auto wo = std::dynamic_pointer_cast<Linear>(blocks["wo"]);
auto hidden_gelu = ggml_ext_gelu(ctx->ggml_ctx, wi_0->forward(ctx, x), true);
auto hidden_gelu = ggml_gelu_inplace(ctx->ggml_ctx, wi_0->forward(ctx, x));
auto hidden_linear = wi_1->forward(ctx, x);
x = ggml_mul_inplace(ctx->ggml_ctx, hidden_gelu, hidden_linear);
x = wo->forward(ctx, x);
@ -608,7 +608,7 @@ public:
}
}
k = ggml_ext_scale(ctx->ggml_ctx, k, ::sqrtf(static_cast<float>(d_head)), true);
k = ggml_scale_inplace(ctx->ggml_ctx, k, ::sqrtf(static_cast<float>(d_head)));
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, mask); // [N, n_token, d_head * n_head]

65
tae.hpp
View File

@ -17,43 +17,22 @@ class TAEBlock : public UnaryBlock {
protected:
int n_in;
int n_out;
bool use_midblock_gn;
public:
TAEBlock(int n_in, int n_out, bool use_midblock_gn = false)
: n_in(n_in), n_out(n_out), use_midblock_gn(use_midblock_gn) {
TAEBlock(int n_in, int n_out)
: n_in(n_in), n_out(n_out) {
blocks["conv.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_out, {3, 3}, {1, 1}, {1, 1}));
blocks["conv.2"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1}));
blocks["conv.4"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1}));
if (n_in != n_out) {
blocks["skip"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_out, {1, 1}, {1, 1}, {1, 1}, {1, 1}, false));
}
if (use_midblock_gn) {
int n_gn = n_in * 4;
blocks["pool.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_gn, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
blocks["pool.1"] = std::shared_ptr<GGMLBlock>(new GroupNorm(4, n_gn));
// pool.2 is ReLU, handled in forward
blocks["pool.3"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_gn, n_in, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
}
}
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
// x: [n, n_in, h, w]
// return: [n, n_out, h, w]
if (use_midblock_gn) {
auto pool_0 = std::dynamic_pointer_cast<Conv2d>(blocks["pool.0"]);
auto pool_1 = std::dynamic_pointer_cast<GroupNorm>(blocks["pool.1"]);
auto pool_3 = std::dynamic_pointer_cast<Conv2d>(blocks["pool.3"]);
auto p = pool_0->forward(ctx, x);
p = pool_1->forward(ctx, p);
p = ggml_relu_inplace(ctx->ggml_ctx, p);
p = pool_3->forward(ctx, p);
x = ggml_add(ctx->ggml_ctx, x, p);
}
auto conv_0 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.0"]);
auto conv_2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.2"]);
auto conv_4 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.4"]);
@ -83,7 +62,7 @@ class TinyEncoder : public UnaryBlock {
int num_blocks = 3;
public:
TinyEncoder(int z_channels = 4, bool use_midblock_gn = false)
TinyEncoder(int z_channels = 4)
: z_channels(z_channels) {
int index = 0;
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, channels, {3, 3}, {1, 1}, {1, 1}));
@ -101,7 +80,7 @@ public:
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false));
for (int i = 0; i < num_blocks; i++) {
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels, use_midblock_gn));
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
}
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, z_channels, {3, 3}, {1, 1}, {1, 1}));
@ -128,7 +107,7 @@ class TinyDecoder : public UnaryBlock {
int num_blocks = 3;
public:
TinyDecoder(int z_channels = 4, bool use_midblock_gn = false)
TinyDecoder(int z_channels = 4)
: z_channels(z_channels) {
int index = 0;
@ -136,7 +115,7 @@ public:
index++; // nn.ReLU()
for (int i = 0; i < num_blocks; i++) {
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels, use_midblock_gn));
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
}
index++; // nn.Upsample()
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, false));
@ -161,9 +140,9 @@ public:
// z: [n, z_channels, h, w]
// return: [n, out_channels, h*8, w*8]
auto h = ggml_ext_scale(ctx->ggml_ctx, z, 1.0f / 3.0f);
auto h = ggml_scale(ctx->ggml_ctx, z, 1.0f / 3.0f);
h = ggml_tanh_inplace(ctx->ggml_ctx, h);
h = ggml_ext_scale(ctx->ggml_ctx, h, 3.0f);
h = ggml_scale(ctx->ggml_ctx, h, 3.0f);
for (int i = 0; i < num_blocks * 3 + 10; i++) {
if (blocks.find(std::to_string(i)) == blocks.end()) {
@ -400,11 +379,10 @@ public:
auto first_conv = std::dynamic_pointer_cast<Conv2d>(blocks["1"]);
// Clamp()
auto h = ggml_ext_scale(ctx->ggml_ctx,
auto h = ggml_scale_inplace(ctx->ggml_ctx,
ggml_tanh_inplace(ctx->ggml_ctx,
ggml_ext_scale(ctx->ggml_ctx, z, 1.0f / 3.0f)),
3.0f,
true);
ggml_scale(ctx->ggml_ctx, z, 1.0f / 3.0f)),
3.0f);
h = first_conv->forward(ctx, h);
h = ggml_relu_inplace(ctx->ggml_ctx, h);
@ -492,44 +470,29 @@ public:
class TAESD : public GGMLBlock {
protected:
bool decode_only;
bool taef2 = false;
public:
TAESD(bool decode_only = true, SDVersion version = VERSION_SD1)
: decode_only(decode_only) {
int z_channels = 4;
bool use_midblock_gn = false;
taef2 = sd_version_is_flux2(version);
if (sd_version_is_dit(version)) {
z_channels = 16;
}
if (taef2) {
z_channels = 32;
use_midblock_gn = true;
}
blocks["decoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyDecoder(z_channels, use_midblock_gn));
blocks["decoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyDecoder(z_channels));
if (!decode_only) {
blocks["encoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyEncoder(z_channels, use_midblock_gn));
blocks["encoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyEncoder(z_channels));
}
}
struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
auto decoder = std::dynamic_pointer_cast<TinyDecoder>(blocks["decoder.layers"]);
if (taef2) {
z = unpatchify(ctx->ggml_ctx, z, 2);
}
return decoder->forward(ctx, z);
}
struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
auto encoder = std::dynamic_pointer_cast<TinyEncoder>(blocks["encoder.layers"]);
auto z = encoder->forward(ctx, x);
if (taef2) {
z = patchify(ctx->ggml_ctx, z, 2);
}
return z;
return encoder->forward(ctx, x);
}
};

View File

@ -529,7 +529,7 @@ public:
}
}
if (controls.size() > 0) {
auto cs = ggml_ext_scale(ctx->ggml_ctx, controls[controls.size() - 1], control_strength, true);
auto cs = ggml_scale_inplace(ctx->ggml_ctx, controls[controls.size() - 1], control_strength);
h = ggml_add(ctx->ggml_ctx, h, cs); // middle control
}
int control_offset = static_cast<int>(controls.size() - 2);
@ -542,7 +542,7 @@ public:
hs.pop_back();
if (controls.size() > 0) {
auto cs = ggml_ext_scale(ctx->ggml_ctx, controls[control_offset], control_strength, true);
auto cs = ggml_scale_inplace(ctx->ggml_ctx, controls[control_offset], control_strength);
h_skip = ggml_add(ctx->ggml_ctx, h_skip, cs); // control net condition
control_offset--;
}

View File

@ -141,7 +141,7 @@ public:
v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n); // [N, h * w, in_channels]
}
h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, ctx->flash_attn_enabled);
h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, true, false);
if (use_linear) {
h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels]
@ -253,8 +253,8 @@ public:
float alpha = get_alpha();
x = ggml_add(ctx->ggml_ctx,
ggml_ext_scale(ctx->ggml_ctx, x, alpha),
ggml_ext_scale(ctx->ggml_ctx, x_mix, 1.0f - alpha));
ggml_scale(ctx->ggml_ctx, x, alpha),
ggml_scale(ctx->ggml_ctx, x_mix, 1.0f - alpha));
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w

23
wan.hpp
View File

@ -573,7 +573,7 @@ namespace WAN {
v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [t, c, h * w]
v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, ctx->flash_attn_enabled); // [t, h * w, c]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, true, false); // [t, h * w, c]
x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [t, c, h * w]
x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n); // [t, c, h, w]
@ -1393,7 +1393,7 @@ namespace WAN {
k = norm_k->forward(ctx, k);
auto v = v_proj->forward(ctx, context); // [N, n_context, dim]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = o_proj->forward(ctx, x); // [N, n_token, dim]
return x;
@ -1442,8 +1442,11 @@ namespace WAN {
int64_t dim = x->ne[0];
int64_t context_txt_len = context->ne[1] - context_img_len;
auto context_img = ggml_view_3d(ctx->ggml_ctx, context, dim, context_img_len, N, context->nb[1], context->nb[2], 0); // [N, context_img_len, dim]
auto context_txt = ggml_view_3d(ctx->ggml_ctx, context, dim, context_txt_len, N, context->nb[1], context->nb[2], context_img_len * context->nb[1]); // [N, context_txt_len, dim]
context = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context, 0, 2, 1, 3)); // [context_img_len + context_txt_len, N, dim]
auto context_img = ggml_view_3d(ctx->ggml_ctx, context, dim, N, context_img_len, context->nb[1], context->nb[2], 0);
auto context_txt = ggml_view_3d(ctx->ggml_ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_img_len * context->nb[2]);
context_img = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context_img, 0, 2, 1, 3)); // [N, context_img_len, dim]
context_txt = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context_txt, 0, 2, 1, 3)); // [N, context_txt_len, dim]
auto q = q_proj->forward(ctx, x);
q = norm_q->forward(ctx, q);
@ -1455,8 +1458,8 @@ namespace WAN {
k_img = norm_k_img->forward(ctx, k_img);
auto v_img = v_img_proj->forward(ctx, context_img); // [N, context_img_len, dim]
auto img_x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k_img, v_img, num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
auto img_x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k_img, v_img, num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = ggml_add(ctx->ggml_ctx, x, img_x);
@ -1573,7 +1576,7 @@ namespace WAN {
y = modulate_add(ctx->ggml_ctx, y, es[3]);
y = ffn_0->forward(ctx, y);
y = ggml_ext_gelu(ctx->ggml_ctx, y, true);
y = ggml_gelu_inplace(ctx->ggml_ctx, y);
y = ffn_2->forward(ctx, y);
x = ggml_add(ctx->ggml_ctx, x, modulate_mul(ctx->ggml_ctx, y, es[5]));
@ -1720,7 +1723,7 @@ namespace WAN {
auto x = proj_0->forward(ctx, image_embeds);
x = proj_1->forward(ctx, x);
x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
x = ggml_gelu_inplace(ctx->ggml_ctx, x);
x = proj_3->forward(ctx, x);
x = proj_4->forward(ctx, x);
@ -1907,7 +1910,7 @@ namespace WAN {
e0 = ggml_reshape_4d(ctx->ggml_ctx, e0, e0->ne[0] / 6, 6, e0->ne[1], e0->ne[2]); // [N, 6, dim] or [N, T, 6, dim]
context = text_embedding_0->forward(ctx, context);
context = ggml_ext_gelu(ctx->ggml_ctx, context);
context = ggml_gelu(ctx->ggml_ctx, context);
context = text_embedding_2->forward(ctx, context); // [N, context_txt_len, dim]
int64_t context_img_len = 0;
@ -1946,7 +1949,7 @@ namespace WAN {
auto result = vace_block->forward(ctx, c, x_orig, e0, pe, context, context_img_len);
auto c_skip = result.first;
c = result.second;
c_skip = ggml_ext_scale(ctx->ggml_ctx, c_skip, vace_strength);
c_skip = ggml_scale(ctx->ggml_ctx, c_skip, vace_strength);
x = ggml_add(ctx->ggml_ctx, x, c_skip);
}
}

View File

@ -54,37 +54,15 @@ namespace ZImage {
auto qkv = qkv_proj->forward(ctx, x); // [N, n_token, (num_heads + num_kv_heads*2)*head_dim]
qkv = ggml_reshape_4d(ctx->ggml_ctx, qkv, head_dim, num_heads + num_kv_heads * 2, qkv->ne[1], qkv->ne[2]); // [N, n_token, num_heads + num_kv_heads*2, head_dim]
qkv = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, qkv, 0, 2, 3, 1)); // [num_heads + num_kv_heads*2, N, n_token, head_dim]
auto q = ggml_view_4d(ctx->ggml_ctx,
qkv,
qkv->ne[0],
num_heads,
qkv->ne[2],
qkv->ne[3],
qkv->nb[1],
qkv->nb[2],
qkv->nb[3],
0); // [N, n_token, num_heads, head_dim]
auto k = ggml_view_4d(ctx->ggml_ctx,
qkv,
qkv->ne[0],
num_kv_heads,
qkv->ne[2],
qkv->ne[3],
qkv->nb[1],
qkv->nb[2],
qkv->nb[3],
num_heads * qkv->nb[1]); // [N, n_token, num_kv_heads, head_dim]
auto v = ggml_view_4d(ctx->ggml_ctx,
qkv,
qkv->ne[0],
num_kv_heads,
qkv->ne[2],
qkv->ne[3],
qkv->nb[1],
qkv->nb[2],
qkv->nb[3],
(num_heads + num_kv_heads) * qkv->nb[1]); // [N, n_token, num_kv_heads, head_dim]
auto q = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], 0); // [num_heads, N, n_token, head_dim]
auto k = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_kv_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], qkv->nb[3] * num_heads); // [num_kv_heads, N, n_token, head_dim]
auto v = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_kv_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], qkv->nb[3] * (num_heads + num_kv_heads)); // [num_kv_heads, N, n_token, head_dim]
q = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, q, 0, 3, 1, 2)); // [N, n_token, num_heads, head_dim]
k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 3, 1, 2)); // [N, n_token, num_kv_heads, head_dim]
v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 0, 3, 1, 2)); // [N, n_token, num_kv_heads, head_dim]
if (qk_norm) {
auto q_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["q_norm"]);
@ -517,7 +495,7 @@ namespace ZImage {
out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H); // [N, C, H, W + pad_w]
out = ggml_ext_slice(ctx->ggml_ctx, out, 0, 0, W); // [N, C, H, W]
out = ggml_ext_scale(ctx->ggml_ctx, out, -1.f);
out = ggml_scale(ctx->ggml_ctx, out, -1.f);
return out;
}