Compare commits: master-462...master (30 commits)
Commits (SHA1):
f957fa3d2a, c252e03c6b, e63daba33d, 3959109281, e411520407, 43e829f219,
7837232631, 4ccce027b2, fa61ea744d, 5e4579c11d, 329571131d, a48b4a3ade,
b87fe13afd, e50e1f253d, c6206fb351, 639091fbe9, 9293016c9d, 2efd19978d,
61659ef299, 9565c7f6bd, fbce16e02d, 7010bb4dff, 48d3161a8d, 271b594e74,
885e62ea82, 0e52afc651, 27b5f17401, dfe6d6c664, 9be0b91927, e7e83ed4d1
.github/workflows/build.yml (2 changes, vendored)

@@ -207,7 +207,7 @@ jobs:
       uses: docker/build-push-action@v6
       with:
         platforms: linux/amd64
-        push: true
+        push: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
         file: Dockerfile.${{ matrix.variant }}
         tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ env.BRANCH_NAME }}-${{ matrix.variant }}
         labels: ${{ steps.meta.outputs.labels }}
README.md (13 changes)

@@ -15,6 +15,9 @@ API and command-line option may change frequently.***
 
 ## 🔥Important News
 
+* **2026/01/18** 🚀 stable-diffusion.cpp now supports **FLUX.2-klein**
+👉 Details: [PR #1193](https://github.com/leejet/stable-diffusion.cpp/pull/1193)
+
 * **2025/12/01** 🚀 stable-diffusion.cpp now supports **Z-Image**
 👉 Details: [PR #1020](https://github.com/leejet/stable-diffusion.cpp/pull/1020)
 

@@ -43,8 +46,8 @@ API and command-line option may change frequently.***
 - SDXL, [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo)
 - [Some SD1.x and SDXL distilled models](./docs/distilled_sd.md)
 - [SD3/SD3.5](./docs/sd3.md)
-- [FlUX.1-dev/FlUX.1-schnell](./docs/flux.md)
-- [FLUX.2-dev](./docs/flux2.md)
+- [FLUX.1-dev/FLUX.1-schnell](./docs/flux.md)
+- [FLUX.2-dev/FLUX.2-klein](./docs/flux2.md)
 - [Chroma](./docs/chroma.md)
 - [Chroma1-Radiance](./docs/chroma_radiance.md)
 - [Qwen Image](./docs/qwen_image.md)

@@ -70,7 +73,7 @@ API and command-line option may change frequently.***
   - SYCL
 - Supported weight formats
   - Pytorch checkpoint (`.ckpt` or `.pth`)
-  - Safetensors (`./safetensors`)
+  - Safetensors (`.safetensors`)
   - GGUF (`.gguf`)
 - Supported platforms
   - Linux

@@ -127,8 +130,8 @@ If you want to improve performance or reduce VRAM/RAM usage, please refer to [pe
 
 - [SD1.x/SD2.x/SDXL](./docs/sd.md)
 - [SD3/SD3.5](./docs/sd3.md)
-- [FlUX.1-dev/FlUX.1-schnell](./docs/flux.md)
-- [FLUX.2-dev](./docs/flux2.md)
+- [FLUX.1-dev/FLUX.1-schnell](./docs/flux.md)
+- [FLUX.2-dev/FLUX.2-klein](./docs/flux2.md)
 - [FLUX.1-Kontext-dev](./docs/kontext.md)
 - [Chroma](./docs/chroma.md)
 - [🔥Qwen Image](./docs/qwen_image.md)
New binary files:
- assets/flux2/flux2-klein-4b-edit.png (510 KiB)
- assets/flux2/flux2-klein-4b.png (455 KiB)
- assets/flux2/flux2-klein-9b-edit.png (511 KiB)
- assets/flux2/flux2-klein-9b.png (491 KiB)
- assets/flux2/flux2-klein-base-4b.png (464 KiB)
- assets/flux2/flux2-klein-base-9b.png (552 KiB)
- assets/z_image/base_bf16.png (870 KiB)
clip.hpp (47 changes)

@@ -479,9 +479,9 @@ public:
 
         x = fc1->forward(ctx, x);
         if (use_gelu) {
-            x = ggml_gelu_inplace(ctx->ggml_ctx, x);
+            x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
         } else {
-            x = ggml_gelu_quick_inplace(ctx->ggml_ctx, x);
+            x = ggml_ext_gelu_quick(ctx->ggml_ctx, x, true);
         }
         x = fc2->forward(ctx, x);
         return x;

@@ -510,7 +510,7 @@ public:
         blocks["mlp"] = std::shared_ptr<GGMLBlock>(new CLIPMLP(d_model, intermediate_size));
     }
 
-    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, bool mask = true) {
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* mask = nullptr) {
         // x: [N, n_token, d_model]
         auto self_attn = std::dynamic_pointer_cast<MultiheadAttention>(blocks["self_attn"]);
         auto layer_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm1"]);

@@ -542,8 +542,8 @@ public:
 
     struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                 struct ggml_tensor* x,
-                                int clip_skip = -1,
-                                bool mask = true) {
+                                struct ggml_tensor* mask = nullptr,
+                                int clip_skip = -1) {
         // x: [N, n_token, d_model]
         int layer_idx = n_layer - 1;
         // LOG_DEBUG("clip_skip %d", clip_skip);

@@ -741,16 +741,17 @@ public:
     struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                 struct ggml_tensor* input_ids,
                                 struct ggml_tensor* tkn_embeddings,
-                                size_t max_token_idx = 0,
-                                bool return_pooled = false,
-                                int clip_skip = -1) {
+                                struct ggml_tensor* mask = nullptr,
+                                size_t max_token_idx = 0,
+                                bool return_pooled = false,
+                                int clip_skip = -1) {
         // input_ids: [N, n_token]
         auto embeddings = std::dynamic_pointer_cast<CLIPEmbeddings>(blocks["embeddings"]);
         auto encoder = std::dynamic_pointer_cast<CLIPEncoder>(blocks["encoder"]);
         auto final_layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["final_layer_norm"]);
 
         auto x = embeddings->forward(ctx, input_ids, tkn_embeddings);  // [N, n_token, hidden_size]
-        x = encoder->forward(ctx, x, return_pooled ? -1 : clip_skip, true);
+        x = encoder->forward(ctx, x, mask, return_pooled ? -1 : clip_skip);
         if (return_pooled || with_final_ln) {
             x = final_layer_norm->forward(ctx, x);
         }

@@ -814,10 +815,11 @@ public:
 
         auto x = embeddings->forward(ctx, pixel_values);  // [N, num_positions, embed_dim]
         x = pre_layernorm->forward(ctx, x);
-        x = encoder->forward(ctx, x, clip_skip, false);
-        // print_ggml_tensor(x, true, "ClipVisionModel x: ");
+        x = encoder->forward(ctx, x, nullptr, clip_skip);
+
         auto last_hidden_state = x;
-        x = post_layernorm->forward(ctx, x);  // [N, n_token, hidden_size]
+
+        x = post_layernorm->forward(ctx, x);  // [N, n_token, hidden_size]
 
         GGML_ASSERT(x->ne[3] == 1);
         if (return_pooled) {

@@ -905,6 +907,8 @@ public:
 struct CLIPTextModelRunner : public GGMLRunner {
     CLIPTextModel model;
 
+    std::vector<float> attention_mask_vec;
+
     CLIPTextModelRunner(ggml_backend_t backend,
                         bool offload_params_to_cpu,
                         const String2TensorStorage& tensor_storage_map,

@@ -938,6 +942,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
     struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                 struct ggml_tensor* input_ids,
                                 struct ggml_tensor* embeddings,
+                                struct ggml_tensor* mask,
                                 size_t max_token_idx = 0,
                                 bool return_pooled = false,
                                 int clip_skip = -1) {

@@ -948,7 +953,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
             input_ids = ggml_reshape_2d(ctx->ggml_ctx, input_ids, model.n_token, input_ids->ne[0] / model.n_token);
         }
 
-        return model.forward(ctx, input_ids, embeddings, max_token_idx, return_pooled, clip_skip);
+        return model.forward(ctx, input_ids, embeddings, mask, max_token_idx, return_pooled, clip_skip);
     }
 
     struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids,

@@ -975,9 +980,23 @@ struct CLIPTextModelRunner : public GGMLRunner {
             embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 1);
         }
 
+        int n_tokens = static_cast<int>(input_ids->ne[0]);
+        attention_mask_vec.resize(n_tokens * n_tokens);
+        for (int i0 = 0; i0 < n_tokens; i0++) {
+            for (int i1 = 0; i1 < n_tokens; i1++) {
+                float value = 0.f;
+                if (i0 > i1) {
+                    value = -INFINITY;
+                }
+                attention_mask_vec[i1 * n_tokens + i0] = value;
+            }
+        }
+        auto attention_mask = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, n_tokens, n_tokens);
+        set_backend_tensor_data(attention_mask, attention_mask_vec.data());
+
         auto runner_ctx = get_context();
 
-        struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, embeddings, max_token_idx, return_pooled, clip_skip);
+        struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, embeddings, attention_mask, max_token_idx, return_pooled, clip_skip);
 
         ggml_build_forward_expand(gf, hidden_states);
 
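For orientation, the block added to build_graph above always feeds the CLIP text encoder an additive causal mask. Assuming i1 indexes the query token and i0 the key token (which is what the i1 * n_tokens + i0 layout suggests), the values it fills are

$$ M_{i_1,i_0} = \begin{cases} 0 & i_0 \le i_1 \\ -\infty & i_0 > i_1 \end{cases} $$

so that, once this bias is added to the attention logits, the softmax assigns zero weight to tokens after the current position.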
common.hpp (10 changes)

@@ -200,7 +200,7 @@ public:
 
         gate = ggml_cont(ctx->ggml_ctx, gate);
 
-        gate = ggml_gelu_inplace(ctx->ggml_ctx, gate);
+        gate = ggml_ext_gelu(ctx->ggml_ctx, gate, true);
 
         x = ggml_mul(ctx->ggml_ctx, x, gate);  // [ne3, ne2, ne1, dim_out]
 

@@ -220,7 +220,7 @@ public:
         auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
 
         x = proj->forward(ctx, x);
-        x = ggml_gelu_inplace(ctx->ggml_ctx, x);
+        x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
         return x;
     }
 };

@@ -317,7 +317,7 @@ public:
         auto k = to_k->forward(ctx, context);  // [N, n_context, inner_dim]
         auto v = to_v->forward(ctx, context);  // [N, n_context, inner_dim]
 
-        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, false, false, ctx->flash_attn_enabled);  // [N, n_token, inner_dim]
+        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, false, ctx->flash_attn_enabled);  // [N, n_token, inner_dim]
 
         x = to_out_0->forward(ctx, x);  // [N, n_token, query_dim]
         return x;

@@ -536,8 +536,8 @@ public:
         // image_only_indicator is always tensor([0.])
         float alpha = get_alpha();
         auto x = ggml_add(ctx->ggml_ctx,
-                          ggml_scale(ctx->ggml_ctx, x_spatial, alpha),
-                          ggml_scale(ctx->ggml_ctx, x_temporal, 1.0f - alpha));
+                          ggml_ext_scale(ctx->ggml_ctx, x_spatial, alpha),
+                          ggml_ext_scale(ctx->ggml_ctx, x_temporal, 1.0f - alpha));
         return x;
     }
 };
@@ -34,6 +34,7 @@ struct Conditioner {
     virtual void free_params_buffer() = 0;
     virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
     virtual size_t get_params_buffer_size() = 0;
+    virtual void set_flash_attention_enabled(bool enabled) = 0;
     virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {}
     virtual std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
                                                                                           int n_threads,

@@ -115,6 +116,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
         return buffer_size;
     }
 
+    void set_flash_attention_enabled(bool enabled) override {
+        text_model->set_flash_attention_enabled(enabled);
+        if (sd_version_is_sdxl(version)) {
+            text_model2->set_flash_attention_enabled(enabled);
+        }
+    }
+
     void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
         text_model->set_weight_adapter(adapter);
         if (sd_version_is_sdxl(version)) {

@@ -783,6 +791,18 @@ struct SD3CLIPEmbedder : public Conditioner {
         return buffer_size;
     }
 
+    void set_flash_attention_enabled(bool enabled) override {
+        if (clip_l) {
+            clip_l->set_flash_attention_enabled(enabled);
+        }
+        if (clip_g) {
+            clip_g->set_flash_attention_enabled(enabled);
+        }
+        if (t5) {
+            t5->set_flash_attention_enabled(enabled);
+        }
+    }
+
     void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
         if (clip_l) {
             clip_l->set_weight_adapter(adapter);

@@ -1191,6 +1211,15 @@ struct FluxCLIPEmbedder : public Conditioner {
         return buffer_size;
     }
 
+    void set_flash_attention_enabled(bool enabled) override {
+        if (clip_l) {
+            clip_l->set_flash_attention_enabled(enabled);
+        }
+        if (t5) {
+            t5->set_flash_attention_enabled(enabled);
+        }
+    }
+
     void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {
         if (clip_l) {
             clip_l->set_weight_adapter(adapter);

@@ -1440,6 +1469,12 @@ struct T5CLIPEmbedder : public Conditioner {
         return buffer_size;
     }
 
+    void set_flash_attention_enabled(bool enabled) override {
+        if (t5) {
+            t5->set_flash_attention_enabled(enabled);
+        }
+    }
+
     void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
         if (t5) {
             t5->set_weight_adapter(adapter);

@@ -1614,9 +1649,9 @@ struct LLMEmbedder : public Conditioner {
                 bool enable_vision = false)
         : version(version) {
         LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL;
-        if (sd_version_is_flux2(version)) {
+        if (version == VERSION_FLUX2) {
             arch = LLM::LLMArch::MISTRAL_SMALL_3_2;
-        } else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE) {
+        } else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE || version == VERSION_FLUX2_KLEIN) {
            arch = LLM::LLMArch::QWEN3;
         }
         if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) {

@@ -1650,6 +1685,10 @@ struct LLMEmbedder : public Conditioner {
         return buffer_size;
     }
 
+    void set_flash_attention_enabled(bool enabled) override {
+        llm->set_flash_attention_enabled(enabled);
+    }
+
     void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
         if (llm) {
             llm->set_weight_adapter(adapter);

@@ -1708,6 +1747,9 @@ struct LLMEmbedder : public Conditioner {
         int prompt_template_encode_start_idx = 34;
         int max_length = 0;
         std::set<int> out_layers;
+        std::vector<int> tokens;
+        std::vector<float> weights;
+        std::vector<float> mask;
         if (llm->enable_vision && conditioner_params.ref_images.size() > 0) {
             LOG_INFO("QwenImageEditPlusPipeline");
             prompt_template_encode_start_idx = 64;

@@ -1771,7 +1813,7 @@ struct LLMEmbedder : public Conditioner {
             prompt_attn_range.second = static_cast<int>(prompt.size());
 
             prompt += "<|im_end|>\n<|im_start|>assistant\n";
-        } else if (sd_version_is_flux2(version)) {
+        } else if (version == VERSION_FLUX2) {
             prompt_template_encode_start_idx = 0;
             out_layers = {10, 20, 30};
 

@@ -1793,17 +1835,28 @@ struct LLMEmbedder : public Conditioner {
             prompt_attn_range.second = static_cast<int>(prompt.size());
 
             prompt += "<|im_end|>\n<|im_start|>assistant\n";
-        } else if (sd_version_is_flux2(version)) {
+        } else if (version == VERSION_FLUX2_KLEIN) {
             prompt_template_encode_start_idx = 0;
-            out_layers = {10, 20, 30};
+            max_length = 512;
+            out_layers = {9, 18, 27};
 
-            prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";
+            prompt = "<|im_start|>user\n";
 
             prompt_attn_range.first = static_cast<int>(prompt.size());
             prompt += conditioner_params.text;
             prompt_attn_range.second = static_cast<int>(prompt.size());
 
-            prompt += "[/INST]";
+            prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
+
+            auto tokens_and_weights = tokenize(prompt, prompt_attn_range, 0, false);
+            tokens = std::get<0>(tokens_and_weights);
+            weights = std::get<1>(tokens_and_weights);
+
+            mask.insert(mask.end(), tokens.size(), 1.f);
+            if (tokens.size() < max_length) {
+                mask.insert(mask.end(), max_length - tokens.size(), 0.f);
+                tokenizer->pad_tokens(tokens, weights, max_length, true);
+            }
         } else if (version == VERSION_OVIS_IMAGE) {
             prompt_template_encode_start_idx = 28;
             max_length = prompt_template_encode_start_idx + 256;

@@ -1827,17 +1880,34 @@ struct LLMEmbedder : public Conditioner {
             prompt += "<|im_end|>\n<|im_start|>assistant\n";
         }
 
-        auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0);
-        auto& tokens = std::get<0>(tokens_and_weights);
-        auto& weights = std::get<1>(tokens_and_weights);
+        if (tokens.empty()) {
+            auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0);
+            tokens = std::get<0>(tokens_and_weights);
+            weights = std::get<1>(tokens_and_weights);
+        }
 
         int64_t t0 = ggml_time_ms();
         struct ggml_tensor* hidden_states = nullptr;  // [N, n_token, 3584]
 
         auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
 
+        ggml_tensor* attention_mask = nullptr;
+        if (!mask.empty()) {
+            attention_mask = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, mask.size(), mask.size());
+            ggml_ext_tensor_iter(attention_mask, [&](ggml_tensor* attention_mask, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
+                float value = 0.f;
+                if (mask[i0] == 0.f) {
+                    value = -INFINITY;
+                } else if (i0 > i1) {
+                    value = -INFINITY;
+                }
+                ggml_ext_tensor_set_f32(attention_mask, value, i0, i1, i2, i3);
+            });
+        }
+
         llm->compute(n_threads,
                      input_ids,
+                     attention_mask,
                      image_embeds,
                      out_layers,
                      &hidden_states,

@@ -1861,7 +1931,7 @@ struct LLMEmbedder : public Conditioner {
         GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);
 
         int64_t min_length = 0;
-        if (sd_version_is_flux2(version)) {
+        if (version == VERSION_FLUX2) {
             min_length = 512;
         }
 
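The attention mask built in the LLMEmbedder hunk above combines padding and causality: a key position receives -INFINITY if it is a padded slot (mask[i0] == 0) or lies after the query position, and 0 otherwise. A minimal standalone sketch of that rule, using plain std::vector instead of the repo's ggml tensors (the helper name build_additive_mask and the row-major q * n + k layout here are illustrative assumptions, not the repo's API):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Build an n x n additive attention bias from a per-token validity mask:
// entry (q, k) is 0 when key k is a real token with k <= q, and -INF when
// key k is padding or lies in the future of query q.
std::vector<float> build_additive_mask(const std::vector<float>& valid) {
    const size_t n = valid.size();
    std::vector<float> bias(n * n, 0.0f);
    for (size_t q = 0; q < n; q++) {
        for (size_t k = 0; k < n; k++) {
            if (valid[k] == 0.0f || k > q) {
                bias[q * n + k] = -INFINITY;
            }
        }
    }
    return bias;
}

int main() {
    // Three real tokens followed by two padding slots.
    std::vector<float> valid = {1.f, 1.f, 1.f, 0.f, 0.f};
    std::vector<float> bias = build_additive_mask(valid);
    for (size_t q = 0; q < valid.size(); q++) {
        for (size_t k = 0; k < valid.size(); k++) {
            std::printf("%s ", std::isinf(bias[q * valid.size() + k]) ? "-inf" : "   0");
        }
        std::printf("\n");
    }
    return 0;
}
```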
denoiser.hpp (305 changes)

@@ -1,6 +1,8 @@
 #ifndef __DENOISER_HPP__
 #define __DENOISER_HPP__
 
+#include <cmath>
+
 #include "ggml_extend.hpp"
 #include "gits_noise.inl"
 

@@ -351,6 +353,95 @@ struct SmoothStepScheduler : SigmaScheduler {
     }
 };
 
+struct BongTangentScheduler : SigmaScheduler {
+    static constexpr float kPi = 3.14159265358979323846f;
+
+    static std::vector<float> get_bong_tangent_sigmas(int steps, float slope, float pivot, float start, float end) {
+        std::vector<float> sigmas;
+        if (steps <= 0) {
+            return sigmas;
+        }
+
+        float smax = ((2.0f / kPi) * atanf(-slope * (0.0f - pivot)) + 1.0f) * 0.5f;
+        float smin = ((2.0f / kPi) * atanf(-slope * ((float)(steps - 1) - pivot)) + 1.0f) * 0.5f;
+        float srange = smax - smin;
+        float sscale = start - end;
+
+        sigmas.reserve(steps);
+
+        if (fabsf(srange) < 1e-8f) {
+            if (steps == 1) {
+                sigmas.push_back(start);
+                return sigmas;
+            }
+            for (int i = 0; i < steps; ++i) {
+                float t = (float)i / (float)(steps - 1);
+                sigmas.push_back(start + (end - start) * t);
+            }
+            return sigmas;
+        }
+
+        float inv_srange = 1.0f / srange;
+        for (int x = 0; x < steps; ++x) {
+            float v = ((2.0f / kPi) * atanf(-slope * ((float)x - pivot)) + 1.0f) * 0.5f;
+            float sigma = ((v - smin) * inv_srange) * sscale + end;
+            sigmas.push_back(sigma);
+        }
+
+        return sigmas;
+    }
+
+    std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t /*t_to_sigma*/) override {
+        std::vector<float> result;
+        if (n == 0) {
+            return result;
+        }
+
+        float start = sigma_max;
+        float end = sigma_min;
+        float middle = sigma_min + (sigma_max - sigma_min) * 0.5f;
+
+        float pivot_1 = 0.6f;
+        float pivot_2 = 0.6f;
+        float slope_1 = 0.2f;
+        float slope_2 = 0.2f;
+
+        int steps = static_cast<int>(n) + 2;
+        int midpoint = static_cast<int>(((float)steps * pivot_1 + (float)steps * pivot_2) * 0.5f);
+        int pivot_1_i = static_cast<int>((float)steps * pivot_1);
+        int pivot_2_i = static_cast<int>((float)steps * pivot_2);
+
+        float slope_scale = (float)steps / 40.0f;
+        slope_1 = slope_1 / slope_scale;
+        slope_2 = slope_2 / slope_scale;
+
+        int stage_2_len = steps - midpoint;
+        int stage_1_len = steps - stage_2_len;
+
+        std::vector<float> sigmas_1 = get_bong_tangent_sigmas(stage_1_len, slope_1, (float)pivot_1_i, start, middle);
+        std::vector<float> sigmas_2 = get_bong_tangent_sigmas(stage_2_len, slope_2, (float)(pivot_2_i - stage_1_len), middle, end);
+
+        if (!sigmas_1.empty()) {
+            sigmas_1.pop_back();
+        }
+
+        result.reserve(n + 1);
+        result.insert(result.end(), sigmas_1.begin(), sigmas_1.end());
+        result.insert(result.end(), sigmas_2.begin(), sigmas_2.end());
+
+        if (result.size() < n + 1) {
+            while (result.size() < n + 1) {
+                result.push_back(end);
+            }
+        } else if (result.size() > n + 1) {
+            result.resize(n + 1);
+        }
+
+        result[n] = 0.0f;
+        return result;
+    }
+};
+
 struct KLOptimalScheduler : SigmaScheduler {
     std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
         std::vector<float> sigmas;

@@ -431,6 +522,10 @@ struct Denoiser {
             LOG_INFO("get_sigmas with SmoothStep scheduler");
             scheduler = std::make_shared<SmoothStepScheduler>();
             break;
+        case BONG_TANGENT_SCHEDULER:
+            LOG_INFO("get_sigmas with bong_tangent scheduler");
+            scheduler = std::make_shared<BongTangentScheduler>();
+            break;
         case KL_OPTIMAL_SCHEDULER:
            LOG_INFO("get_sigmas with KL Optimal scheduler");
            scheduler = std::make_shared<KLOptimalScheduler>();
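For orientation, the schedule computed by the BongTangentScheduler added above can be summarized as follows (a paraphrase of the code, not extra behaviour): with

$$ v(x) = \tfrac{1}{2}\Bigl(\tfrac{2}{\pi}\arctan\bigl(-\text{slope}\,(x - \text{pivot})\bigr) + 1\Bigr), \qquad \sigma(x) = \frac{v(x) - v(\text{steps}-1)}{v(0) - v(\text{steps}-1)}\,(\text{start} - \text{end}) + \text{end}, $$

each stage maps step index x to a sigma that falls from start to end along an arctan-shaped curve; two such stages (sigma_max to the midpoint, then the midpoint to sigma_min) are concatenated, padded or truncated to n + 1 values, and forced to end at 0.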
@@ -1634,6 +1729,216 @@ static bool sample_k_diffusion(sample_method_t method,
                 }
             }
         } break;
+        case RES_MULTISTEP_SAMPLE_METHOD:  // Res Multistep sampler
+        {
+            struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
+            struct ggml_tensor* old_denoised = ggml_dup_tensor(work_ctx, x);
+
+            bool have_old_sigma = false;
+            float old_sigma_down = 0.0f;
+
+            auto t_fn = [](float sigma) -> float { return -logf(sigma); };
+            auto sigma_fn = [](float t) -> float { return expf(-t); };
+            auto phi1_fn = [](float t) -> float {
+                if (fabsf(t) < 1e-6f) {
+                    return 1.0f + t * 0.5f + (t * t) / 6.0f;
+                }
+                return (expf(t) - 1.0f) / t;
+            };
+            auto phi2_fn = [&](float t) -> float {
+                if (fabsf(t) < 1e-6f) {
+                    return 0.5f + t / 6.0f + (t * t) / 24.0f;
+                }
+                float phi1_val = phi1_fn(t);
+                return (phi1_val - 1.0f) / t;
+            };
+
+            for (int i = 0; i < steps; i++) {
+                ggml_tensor* denoised = model(x, sigmas[i], i + 1);
+                if (denoised == nullptr) {
+                    return false;
+                }
+
+                float sigma_from = sigmas[i];
+                float sigma_to = sigmas[i + 1];
+                float sigma_up = 0.0f;
+                float sigma_down = sigma_to;
+
+                if (eta > 0.0f) {
+                    float sigma_from_sq = sigma_from * sigma_from;
+                    float sigma_to_sq = sigma_to * sigma_to;
+                    if (sigma_from_sq > 0.0f) {
+                        float term = sigma_to_sq * (sigma_from_sq - sigma_to_sq) / sigma_from_sq;
+                        if (term > 0.0f) {
+                            sigma_up = eta * std::sqrt(term);
+                        }
+                    }
+                    sigma_up = std::min(sigma_up, sigma_to);
+                    float sigma_down_sq = sigma_to_sq - sigma_up * sigma_up;
+                    sigma_down = sigma_down_sq > 0.0f ? std::sqrt(sigma_down_sq) : 0.0f;
+                }
+
+                if (sigma_down == 0.0f || !have_old_sigma) {
+                    float dt = sigma_down - sigma_from;
+                    float* vec_x = (float*)x->data;
+                    float* vec_denoised = (float*)denoised->data;
+
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        float d = (vec_x[j] - vec_denoised[j]) / sigma_from;
+                        vec_x[j] = vec_x[j] + d * dt;
+                    }
+                } else {
+                    float t = t_fn(sigma_from);
+                    float t_old = t_fn(old_sigma_down);
+                    float t_next = t_fn(sigma_down);
+                    float t_prev = t_fn(sigmas[i - 1]);
+                    float h = t_next - t;
+                    float c2 = (t_prev - t_old) / h;
+
+                    float phi1_val = phi1_fn(-h);
+                    float phi2_val = phi2_fn(-h);
+                    float b1 = phi1_val - phi2_val / c2;
+                    float b2 = phi2_val / c2;
+
+                    if (!std::isfinite(b1)) {
+                        b1 = 0.0f;
+                    }
+                    if (!std::isfinite(b2)) {
+                        b2 = 0.0f;
+                    }
+
+                    float sigma_h = sigma_fn(h);
+                    float* vec_x = (float*)x->data;
+                    float* vec_denoised = (float*)denoised->data;
+                    float* vec_old_denoised = (float*)old_denoised->data;
+
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_x[j] = sigma_h * vec_x[j] + h * (b1 * vec_denoised[j] + b2 * vec_old_denoised[j]);
+                    }
+                }
+
+                if (sigmas[i + 1] > 0 && sigma_up > 0.0f) {
+                    ggml_ext_im_set_randn_f32(noise, rng);
+                    float* vec_x = (float*)x->data;
+                    float* vec_noise = (float*)noise->data;
+
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_x[j] = vec_x[j] + vec_noise[j] * sigma_up;
+                    }
+                }
+
+                float* vec_old_denoised = (float*)old_denoised->data;
+                float* vec_denoised = (float*)denoised->data;
+                for (int j = 0; j < ggml_nelements(x); j++) {
+                    vec_old_denoised[j] = vec_denoised[j];
+                }
+
+                old_sigma_down = sigma_down;
+                have_old_sigma = true;
+            }
+        } break;
+        case RES_2S_SAMPLE_METHOD:  // Res 2s sampler
+        {
+            struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
+            struct ggml_tensor* x0 = ggml_dup_tensor(work_ctx, x);
+            struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x);
+
+            const float c2 = 0.5f;
+            auto t_fn = [](float sigma) -> float { return -logf(sigma); };
+            auto phi1_fn = [](float t) -> float {
+                if (fabsf(t) < 1e-6f) {
+                    return 1.0f + t * 0.5f + (t * t) / 6.0f;
+                }
+                return (expf(t) - 1.0f) / t;
+            };
+            auto phi2_fn = [&](float t) -> float {
+                if (fabsf(t) < 1e-6f) {
+                    return 0.5f + t / 6.0f + (t * t) / 24.0f;
+                }
+                float phi1_val = phi1_fn(t);
+                return (phi1_val - 1.0f) / t;
+            };
+
+            for (int i = 0; i < steps; i++) {
+                float sigma_from = sigmas[i];
+                float sigma_to = sigmas[i + 1];
+
+                ggml_tensor* denoised = model(x, sigma_from, -(i + 1));
+                if (denoised == nullptr) {
+                    return false;
+                }
+
+                float sigma_up = 0.0f;
+                float sigma_down = sigma_to;
+                if (eta > 0.0f) {
+                    float sigma_from_sq = sigma_from * sigma_from;
+                    float sigma_to_sq = sigma_to * sigma_to;
+                    if (sigma_from_sq > 0.0f) {
+                        float term = sigma_to_sq * (sigma_from_sq - sigma_to_sq) / sigma_from_sq;
+                        if (term > 0.0f) {
+                            sigma_up = eta * std::sqrt(term);
+                        }
+                    }
+                    sigma_up = std::min(sigma_up, sigma_to);
+                    float sigma_down_sq = sigma_to_sq - sigma_up * sigma_up;
+                    sigma_down = sigma_down_sq > 0.0f ? std::sqrt(sigma_down_sq) : 0.0f;
+                }
+
+                float* vec_x = (float*)x->data;
+                float* vec_x0 = (float*)x0->data;
+                for (int j = 0; j < ggml_nelements(x); j++) {
+                    vec_x0[j] = vec_x[j];
+                }
+
+                if (sigma_down == 0.0f || sigma_from == 0.0f) {
+                    float* vec_denoised = (float*)denoised->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_x[j] = vec_denoised[j];
+                    }
+                } else {
+                    float t = t_fn(sigma_from);
+                    float t_next = t_fn(sigma_down);
+                    float h = t_next - t;
+
+                    float a21 = c2 * phi1_fn(-h * c2);
+                    float phi1_val = phi1_fn(-h);
+                    float phi2_val = phi2_fn(-h);
+                    float b2 = phi2_val / c2;
+                    float b1 = phi1_val - b2;
+
+                    float sigma_c2 = expf(-(t + h * c2));
+
+                    float* vec_denoised = (float*)denoised->data;
+                    float* vec_x2 = (float*)x2->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        float eps1 = vec_denoised[j] - vec_x0[j];
+                        vec_x2[j] = vec_x0[j] + h * a21 * eps1;
+                    }
+
+                    ggml_tensor* denoised2 = model(x2, sigma_c2, i + 1);
+                    if (denoised2 == nullptr) {
+                        return false;
+                    }
+                    float* vec_denoised2 = (float*)denoised2->data;
+
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        float eps1 = vec_denoised[j] - vec_x0[j];
+                        float eps2 = vec_denoised2[j] - vec_x0[j];
+                        vec_x[j] = vec_x0[j] + h * (b1 * eps1 + b2 * eps2);
+                    }
+                }
+
+                if (sigmas[i + 1] > 0 && sigma_up > 0.0f) {
+                    ggml_ext_im_set_randn_f32(noise, rng);
+                    float* vec_x = (float*)x->data;
+                    float* vec_noise = (float*)noise->data;
+
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_x[j] = vec_x[j] + vec_noise[j] * sigma_up;
+                    }
+                }
+            }
+        } break;
+
         default:
             LOG_ERROR("Attempting to sample with nonexisting sample method %i", method);
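Both samplers added above implement a second-order exponential-integrator update numerically. For reference, the helper functions and the ancestral noise split they compute correspond to (a restatement of the code, not additional behaviour):

$$ \varphi_1(h) = \frac{e^{h} - 1}{h}, \qquad \varphi_2(h) = \frac{\varphi_1(h) - 1}{h}, $$

with the small-|h| Taylor fallbacks $1 + \tfrac{h}{2} + \tfrac{h^2}{6}$ and $\tfrac{1}{2} + \tfrac{h}{6} + \tfrac{h^2}{24}$ (the code evaluates them at -h and -h·c2), and, when eta > 0,

$$ \sigma_{\text{up}} = \min\!\Bigl(\sigma_{\text{to}},\ \eta\,\frac{\sigma_{\text{to}}}{\sigma_{\text{from}}}\sqrt{\sigma_{\text{from}}^{2} - \sigma_{\text{to}}^{2}}\Bigr), \qquad \sigma_{\text{down}} = \sqrt{\sigma_{\text{to}}^{2} - \sigma_{\text{up}}^{2}}. $$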
@@ -38,7 +38,7 @@ struct DiffusionModel {
     virtual size_t get_params_buffer_size() = 0;
     virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter){};
     virtual int64_t get_adm_in_channels() = 0;
-    virtual void set_flash_attn_enabled(bool enabled) = 0;
+    virtual void set_flash_attention_enabled(bool enabled) = 0;
     virtual void set_circular_axes(bool circular_x, bool circular_y) = 0;
 };
 

@@ -84,7 +84,7 @@ struct UNetModel : public DiffusionModel {
         return unet.unet.adm_in_channels;
     }
 
-    void set_flash_attn_enabled(bool enabled) {
+    void set_flash_attention_enabled(bool enabled) {
         unet.set_flash_attention_enabled(enabled);
     }
 

@@ -149,7 +149,7 @@ struct MMDiTModel : public DiffusionModel {
         return 768 + 1280;
     }
 
-    void set_flash_attn_enabled(bool enabled) {
+    void set_flash_attention_enabled(bool enabled) {
         mmdit.set_flash_attention_enabled(enabled);
     }
 

@@ -215,7 +215,7 @@ struct FluxModel : public DiffusionModel {
         return 768;
     }
 
-    void set_flash_attn_enabled(bool enabled) {
+    void set_flash_attention_enabled(bool enabled) {
         flux.set_flash_attention_enabled(enabled);
     }
 

@@ -286,7 +286,7 @@ struct WanModel : public DiffusionModel {
         return 768;
     }
 
-    void set_flash_attn_enabled(bool enabled) {
+    void set_flash_attention_enabled(bool enabled) {
         wan.set_flash_attention_enabled(enabled);
     }
 

@@ -357,7 +357,7 @@ struct QwenImageModel : public DiffusionModel {
         return 768;
     }
 
-    void set_flash_attn_enabled(bool enabled) {
+    void set_flash_attention_enabled(bool enabled) {
         qwen_image.set_flash_attention_enabled(enabled);
     }
 

@@ -424,7 +424,7 @@ struct ZImageModel : public DiffusionModel {
         return 768;
     }
 
-    void set_flash_attn_enabled(bool enabled) {
+    void set_flash_attention_enabled(bool enabled) {
         z_image.set_flash_attention_enabled(enabled);
     }
 
@@ -1,8 +1,8 @@
-# Running distilled models: SSD1B and SDx.x with tiny U-Nets
+# Running distilled models: SSD1B, Vega and SDx.x with tiny U-Nets
 
 ## Preface
 
-These models feature a reduced U-Net architecture. Unlike standard SDXL models, the SSD-1B U-Net contains only one middle block and fewer attention layers in its up- and down-blocks, resulting in significantly smaller file sizes. Using these models can reduce inference time by more than 33%. For more details, refer to Segmind's paper: https://arxiv.org/abs/2401.02677v1.
+These models feature a reduced U-Net architecture. Unlike standard SDXL models, the SSD-1B and Vega U-Net contains only one middle block and fewer attention layers in its up- and down-blocks, resulting in significantly smaller file sizes. Using these models can reduce inference time by more than 33%. For more details, refer to Segmind's paper: https://arxiv.org/abs/2401.02677v1.
 Similarly, SD1.x- and SD2.x-style models with a tiny U-Net consist of only 6 U-Net blocks, leading to very small files and time savings of up to 50%. For more information, see the paper: https://arxiv.org/pdf/2305.15798.pdf.
 
 ## SSD1B

@@ -17,7 +17,17 @@ Useful LoRAs are also available:
 * https://huggingface.co/seungminh/lora-swarovski-SSD-1B/resolve/main/pytorch_lora_weights.safetensors
 * https://huggingface.co/kylielee505/mylcmlorassd/resolve/main/pytorch_lora_weights.safetensors
 
-These files can be used out-of-the-box, unlike the models described in the next section.
+## Vega
+
+Segmind's Vega model is available online here:
+
+* https://huggingface.co/segmind/Segmind-Vega/resolve/main/segmind-vega.safetensors
+
+VegaRT is an example for an LCM-LoRA:
+
+* https://huggingface.co/segmind/Segmind-VegaRT/resolve/main/pytorch_lora_weights.safetensors
+
+Both files can be used out-of-the-box, unlike the models described in next sections.
 
 
 ## SD1.x, SD2.x with tiny U-Nets

@@ -83,7 +93,7 @@ python convert_diffusers_to_original_stable_diffusion.py \
 The file segmind_tiny-sd.ckpt will be generated and is now ready for use with sd.cpp. You can follow a similar process for the other models mentioned above.
 
 
-### Another available .ckpt file:
+##### Another available .ckpt file:
 
 * https://huggingface.co/ClashSAN/small-sd/resolve/main/tinySDdistilled.ckpt
 

@@ -97,3 +107,31 @@ for key, value in ckpt['state_dict'].items():
     ckpt['state_dict'][key] = value.contiguous()
 torch.save(ckpt, "tinySDdistilled_fixed.ckpt")
 ```
+
+
+### SDXS-512
+
+Another very tiny and **incredibly fast** model is SDXS by IDKiro et al. The authors refer to it as *"Real-Time One-Step Latent Diffusion Models with Image Conditions"*. For details read the paper: https://arxiv.org/pdf/2403.16627 . Once again the authors removed some more blocks of U-Net part and unlike other SD1 models they use an adjusted _AutoEncoderTiny_ instead of default _AutoEncoderKL_ for the VAE part.
+
+##### 1. Download the diffusers model from Hugging Face using Python:
+
+```python
+from diffusers import StableDiffusionPipeline
+pipe = StableDiffusionPipeline.from_pretrained("IDKiro/sdxs-512-dreamshaper")
+pipe.save_pretrained(save_directory="sdxs")
+```
+##### 2. Create a safetensors file
+
+```bash
+python convert_diffusers_to_original_stable_diffusion.py \
+  --model_path sdxs --checkpoint_path sdxs.safetensors --half --use_safetensors
+```
+
+##### 3. Run the model as follows:
+
+```bash
+~/stable-diffusion.cpp/build/bin/sd-cli -m sdxs.safetensors -p "portrait of a lovely cat" \
+  --cfg-scale 1 --steps 1
+```
+
+Both options: ``` --cfg-scale 1 ``` and ``` --steps 1 ``` are mandatory here.
@@ -1,6 +1,6 @@
 ## Using ESRGAN to upscale results
 
-You can use ESRGAN to upscale the generated images. At the moment, only the [RealESRGAN_x4plus_anime_6B.pth](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth) model is supported. Support for more models of this architecture will be added soon.
+You can use ESRGAN—such as the model [RealESRGAN_x4plus_anime_6B.pth](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth)—to upscale the generated images and improve their overall resolution and clarity.
 
 - Specify the model path using the `--upscale-model PATH` parameter. example:
 
@@ -1,6 +1,8 @@
 # How to Use
 
-## Download weights
+## Flux.2-dev
+
+### Download weights
 
 - Download FLUX.2-dev
   - gguf: https://huggingface.co/city96/FLUX.2-dev-gguf/tree/main

@@ -9,7 +11,7 @@
 - Download Mistral-Small-3.2-24B-Instruct-2506-GGUF
   - gguf: https://huggingface.co/unsloth/Mistral-Small-3.2-24B-Instruct-2506-GGUF/tree/main
 
-## Examples
+### Examples
 
 ```
 .\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux2-dev-Q4_K_S.gguf --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\Mistral-Small-3.2-24B-Instruct-2506-Q4_K_M.gguf -r .\kontext_input.png -p "change 'flux.cpp' to 'flux2-dev.cpp'" --cfg-scale 1.0 --sampling-method euler -v --diffusion-fa --offload-to-cpu

@@ -17,5 +19,74 @@
 
 <img alt="flux2 example" src="../assets/flux2/example.png" />
 
+## Flux.2 klein 4B / Flux.2 klein base 4B
+
+### Download weights
+
+- Download FLUX.2-klein-4B
+  - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-klein-4B
+  - gguf: https://huggingface.co/leejet/FLUX.2-klein-4B-GGUF/tree/main
+- Download FLUX.2-klein-base-4B
+  - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-klein-base-4B
+  - gguf: https://huggingface.co/leejet/FLUX.2-klein-base-4B-GGUF/tree/main
+- Download vae
+  - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main
+- Download Qwen3 4b
+  - safetensors: https://huggingface.co/Comfy-Org/flux2-klein-4B/tree/main/split_files/text_encoders
+  - gguf: https://huggingface.co/unsloth/Qwen3-4B-GGUF/tree/main
+
+### Examples
+
+```
+.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-4b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_4b.safetensors -p "a lovely cat" --cfg-scale 1.0 --steps 4 -v --offload-to-cpu --diffusion-fa
+```
+
+<img alt="flux2-klein-4b" src="../assets/flux2/flux2-klein-4b.png" />
+
+```
+.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-4b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_4b.safetensors -r .\kontext_input.png -p "change 'flux.cpp' to 'klein.cpp'" --cfg-scale 1.0 --sampling-method euler -v --diffusion-fa --offload-to-cpu --steps 4
+```
+
+<img alt="flux2-klein-4b-edit" src="../assets/flux2/flux2-klein-4b-edit.png" />
+
+```
+.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-base-4b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_4b.safetensors -p "a lovely cat" --cfg-scale 4.0 --steps 20 -v --offload-to-cpu --diffusion-fa
+```
+
+<img alt="flux2-klein-base-4b" src="../assets/flux2/flux2-klein-base-4b.png" />
+
+## Flux.2 klein 9B / Flux.2 klein base 9B
+
+### Download weights
+
+- Download FLUX.2-klein-9B
+  - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-klein-9B
+  - gguf: https://huggingface.co/leejet/FLUX.2-klein-9B-GGUF/tree/main
+- Download FLUX.2-klein-base-9B
+  - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-klein-base-9B
+  - gguf: https://huggingface.co/leejet/FLUX.2-klein-base-9B-GGUF/tree/main
+- Download vae
+  - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main
+- Download Qwen3 8B
+  - safetensors: https://huggingface.co/Comfy-Org/flux2-klein-9B/tree/main/split_files/text_encoders
+  - gguf: https://huggingface.co/unsloth/Qwen3-8B-GGUF/tree/main
+
+### Examples
+
+```
+.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-9b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_8b.safetensors -p "a lovely cat" --cfg-scale 1.0 --steps 4 -v --offload-to-cpu --diffusion-fa
+```
+
+<img alt="flux2-klein-9b" src="../assets/flux2/flux2-klein-9b.png" />
+
+```
+.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-9b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_8b.safetensors -r .\kontext_input.png -p "change 'flux.cpp' to 'klein.cpp'" --cfg-scale 1.0 --sampling-method euler -v --diffusion-fa --offload-to-cpu --steps 4
+```
+
+<img alt="flux2-klein-9b-edit" src="../assets/flux2/flux2-klein-9b-edit.png" />
+
+```
+.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-base-9b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_8b.safetensors -p "a lovely cat" --cfg-scale 4.0 --steps 20 -v --offload-to-cpu --diffusion-fa
+```
+
+<img alt="flux2-klein-base-9b" src="../assets/flux2/flux2-klein-base-9b.png" />
@@ -7,6 +7,9 @@ You can run Z-Image with stable-diffusion.cpp on GPUs with 4GB of VRAM — or ev
 - Download Z-Image-Turbo
   - safetensors: https://huggingface.co/Comfy-Org/z_image_turbo/tree/main/split_files/diffusion_models
   - gguf: https://huggingface.co/leejet/Z-Image-Turbo-GGUF/tree/main
+- Download Z-Image
+  - safetensors: https://huggingface.co/Comfy-Org/z_image/tree/main/split_files/diffusion_models
+  - gguf: https://huggingface.co/unsloth/Z-Image-GGUF/tree/main
 - Download vae
   - safetensors: https://huggingface.co/black-forest-labs/FLUX.1-schnell/tree/main
 - Download Qwen3 4b

@@ -15,12 +18,22 @@ You can run Z-Image with stable-diffusion.cpp on GPUs with 4GB of VRAM — or ev
 
 ## Examples
 
+### Z-Image-Turbo
+
 ```
 .\bin\Release\sd-cli.exe --diffusion-model z_image_turbo-Q3_K.gguf --vae ..\..\ComfyUI\models\vae\ae.sft --llm ..\..\ComfyUI\models\text_encoders\Qwen3-4B-Instruct-2507-Q4_K_M.gguf -p "A cinematic, melancholic photograph of a solitary hooded figure walking through a sprawling, rain-slicked metropolis at night. The city lights are a chaotic blur of neon orange and cool blue, reflecting on the wet asphalt. The scene evokes a sense of being a single component in a vast machine. Superimposed over the image in a sleek, modern, slightly glitched font is the philosophical quote: 'THE CITY IS A CIRCUIT BOARD, AND I AM A BROKEN TRANSISTOR.' -- moody, atmospheric, profound, dark academic" --cfg-scale 1.0 -v --offload-to-cpu --diffusion-fa -H 1024 -W 512
 ```
 
 <img width="256" alt="z-image example" src="../assets/z_image/q3_K.png" />
 
+### Z-Image-Base
+
+```
+.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\z_image_bf16.safetensors --vae ..\..\ComfyUI\models\vae\ae.sft --llm ..\..\ComfyUI\models\text_encoders\qwen_3_4b.safetensors -p "A cinematic, melancholic photograph of a solitary hooded figure walking through a sprawling, rain-slicked metropolis at night. The city lights are a chaotic blur of neon orange and cool blue, reflecting on the wet asphalt. The scene evokes a sense of being a single component in a vast machine. Superimposed over the image in a sleek, modern, slightly glitched font is the philosophical quote: 'THE CITY IS A CIRCUIT BOARD, AND I AM A BROKEN TRANSISTOR.' -- moody, atmospheric, profound, dark academic" --cfg-scale 5.0 -v --offload-to-cpu --diffusion-fa -H 1024 -W 512
+```
+
+<img width="256" alt="z-image example" src="../assets/z_image/base_bf16.png" />
+
 ## Comparison of Different Quantization Types
 
 | bf16 | q8_0 | q6_K | q5_0 | q4_K | q4_0 | q3_K | q2_K|
@@ -51,7 +51,7 @@ public:
         x_cat = ggml_concat(ctx->ggml_ctx, x_cat, x4, 2);
         auto x5 = conv5->forward(ctx, x_cat);
 
-        x5 = ggml_add(ctx->ggml_ctx, ggml_scale(ctx->ggml_ctx, x5, 0.2f), x);
+        x5 = ggml_add(ctx->ggml_ctx, ggml_ext_scale(ctx->ggml_ctx, x5, 0.2f), x);
         return x5;
     }
 };
@@ -76,7 +76,7 @@ public:
         out = rdb2->forward(ctx, out);
         out = rdb3->forward(ctx, out);
 
-        out = ggml_add(ctx->ggml_ctx, ggml_scale(ctx->ggml_ctx, out, 0.2f), x);
+        out = ggml_add(ctx->ggml_ctx, ggml_ext_scale(ctx->ggml_ctx, out, 0.2f), x);
         return out;
     }
 };
@@ -52,7 +52,8 @@ Context Options:
 --control-net-cpu keep controlnet in cpu (for low vram)
 --clip-on-cpu keep clip in cpu (for low vram)
 --vae-on-cpu keep vae in cpu (for low vram)
---diffusion-fa use flash attention in the diffusion model
+--fa use flash attention
+--diffusion-fa use flash attention in the diffusion model only
 --diffusion-conv-direct use ggml_conv2d_direct in the diffusion model
 --vae-conv-direct use ggml_conv2d_direct in the vae model
 --circular enable circular padding for convolutions
@@ -107,14 +108,14 @@ Generation Options:
 medium
 --skip-layer-start <float> SLG enabling point (default: 0.01)
 --skip-layer-end <float> SLG disabling point (default: 0.2)
---eta <float> eta in DDIM, only for DDIM and TCD (default: 0)
+--eta <float> eta in DDIM, only for DDIM/TCD/res_multistep/res_2s (default: 0)
 --high-noise-cfg-scale <float> (high noise) unconditional guidance scale: (default: 7.0)
 --high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale)
 --high-noise-guidance <float> (high noise) distilled guidance scale for models with guidance input (default: 3.5)
 --high-noise-slg-scale <float> (high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0)
 --high-noise-skip-layer-start <float> (high noise) SLG enabling point (default: 0.01)
 --high-noise-skip-layer-end <float> (high noise) SLG disabling point (default: 0.2)
---high-noise-eta <float> (high noise) eta in DDIM, only for DDIM and TCD (default: 0)
+--high-noise-eta <float> (high noise) eta in DDIM, only for DDIM/TCD/res_multistep/res_2s (default: 0)
 --strength <float> strength for noising/unnoising (default: 0.75)
 --pm-style-strength <float>
 --control-strength <float> strength to apply Control Net (default: 0.9). 1.0 corresponds to full destruction of information in init image
@@ -123,12 +124,12 @@ Generation Options:
 --increase-ref-index automatically increase the indices of references images based on the order they are listed (starting with 1).
 --disable-auto-resize-ref-image disable auto resize of ref images
 -s, --seed RNG seed (default: 42, use random seed for < 0)
---sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing,
-tcd] (default: euler for Flux/SD3/Wan, euler_a otherwise)
+--sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd,
+res_multistep, res_2s] (default: euler for Flux/SD3/Wan, euler_a otherwise)
---high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm,
-ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise
+--high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing,
+tcd, res_multistep, res_2s] default: euler for Flux/SD3/Wan, euler_a otherwise
 --scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple,
-kl_optimal, lcm], default: discrete
+kl_optimal, lcm, bong_tangent], default: discrete
 --sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0").
 --skip-layers layers to skip for SLG steps (default: [7,8,9])
 --high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
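The new sampler and scheduler names listed above can be combined on one command line. A minimal sketch only — the model and VAE file names here are placeholders, not files provided by this change:

```
.\bin\Release\sd-cli.exe --diffusion-model my_model.safetensors --vae my_vae.safetensors -p "a lovely cat" --fa --sampling-method res_multistep --scheduler bong_tangent --steps 20
```

Note that `--fa` now enables flash attention globally, while `--diffusion-fa` keeps it limited to the diffusion model.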
@@ -245,7 +245,7 @@ std::string get_image_params(const SDCliParams& cli_params, const SDContextParam
     parameter_string += "Guidance: " + std::to_string(gen_params.sample_params.guidance.distilled_guidance) + ", ";
     parameter_string += "Eta: " + std::to_string(gen_params.sample_params.eta) + ", ";
     parameter_string += "Seed: " + std::to_string(seed) + ", ";
-    parameter_string += "Size: " + std::to_string(gen_params.width) + "x" + std::to_string(gen_params.height) + ", ";
+    parameter_string += "Size: " + std::to_string(gen_params.get_resolved_width()) + "x" + std::to_string(gen_params.get_resolved_height()) + ", ";
     parameter_string += "Model: " + sd_basename(ctx_params.model_path) + ", ";
     parameter_string += "RNG: " + std::string(sd_rng_type_name(ctx_params.rng_type)) + ", ";
     if (ctx_params.sampler_rng_type != RNG_TYPE_COUNT) {
@@ -526,10 +526,10 @@ int main(int argc, const char* argv[]) {
     }
 
     bool vae_decode_only = true;
-    sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
-    sd_image_t end_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
-    sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
-    sd_image_t mask_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 1, nullptr};
+    sd_image_t init_image = {0, 0, 3, nullptr};
+    sd_image_t end_image = {0, 0, 3, nullptr};
+    sd_image_t control_image = {0, 0, 3, nullptr};
+    sd_image_t mask_image = {0, 0, 1, nullptr};
     std::vector<sd_image_t> ref_images;
     std::vector<sd_image_t> pmid_images;
     std::vector<sd_image_t> control_frames;
@@ -556,57 +556,79 @@ int main(int argc, const char* argv[]) {
         control_frames.clear();
     };
 
+    auto load_image_and_update_size = [&](const std::string& path,
+                                          sd_image_t& image,
+                                          bool resize_image = true,
+                                          int expected_channel = 3) -> bool {
+        int expected_width = 0;
+        int expected_height = 0;
+        if (resize_image && gen_params.width_and_height_are_set()) {
+            expected_width = gen_params.width;
+            expected_height = gen_params.height;
+        }
+
+        if (!load_sd_image_from_file(&image, path.c_str(), expected_width, expected_height, expected_channel)) {
+            LOG_ERROR("load image from '%s' failed", path.c_str());
+            release_all_resources();
+            return false;
+        }
+
+        gen_params.set_width_and_height_if_unset(image.width, image.height);
+        return true;
+    };
+
     if (gen_params.init_image_path.size() > 0) {
         vae_decode_only = false;
-        int width = 0;
-        int height = 0;
-        init_image.data = load_image_from_file(gen_params.init_image_path.c_str(), width, height, gen_params.width, gen_params.height);
-        if (init_image.data == nullptr) {
-            LOG_ERROR("load image from '%s' failed", gen_params.init_image_path.c_str());
-            release_all_resources();
+        if (!load_image_and_update_size(gen_params.init_image_path, init_image)) {
             return 1;
         }
     }
 
     if (gen_params.end_image_path.size() > 0) {
         vae_decode_only = false;
-        int width = 0;
-        int height = 0;
-        end_image.data = load_image_from_file(gen_params.end_image_path.c_str(), width, height, gen_params.width, gen_params.height);
-        if (end_image.data == nullptr) {
-            LOG_ERROR("load image from '%s' failed", gen_params.end_image_path.c_str());
-            release_all_resources();
+        if (!load_image_and_update_size(gen_params.init_image_path, end_image)) {
             return 1;
         }
     }
 
+    if (gen_params.ref_image_paths.size() > 0) {
+        vae_decode_only = false;
+        for (auto& path : gen_params.ref_image_paths) {
+            sd_image_t ref_image = {0, 0, 3, nullptr};
+            if (!load_image_and_update_size(path, ref_image, false)) {
+                return 1;
+            }
+            ref_images.push_back(ref_image);
+        }
+    }
+
     if (gen_params.mask_image_path.size() > 0) {
-        int c = 0;
-        int width = 0;
-        int height = 0;
-        mask_image.data = load_image_from_file(gen_params.mask_image_path.c_str(), width, height, gen_params.width, gen_params.height, 1);
-        if (mask_image.data == nullptr) {
+        if (!load_sd_image_from_file(&mask_image,
+                                     gen_params.mask_image_path.c_str(),
+                                     gen_params.get_resolved_width(),
+                                     gen_params.get_resolved_height(),
+                                     1)) {
             LOG_ERROR("load image from '%s' failed", gen_params.mask_image_path.c_str());
             release_all_resources();
             return 1;
         }
     } else {
-        mask_image.data = (uint8_t*)malloc(gen_params.width * gen_params.height);
-        memset(mask_image.data, 255, gen_params.width * gen_params.height);
+        mask_image.data = (uint8_t*)malloc(gen_params.get_resolved_width() * gen_params.get_resolved_height());
         if (mask_image.data == nullptr) {
             LOG_ERROR("malloc mask image failed");
             release_all_resources();
             return 1;
         }
+        mask_image.width = gen_params.get_resolved_width();
+        mask_image.height = gen_params.get_resolved_height();
+        memset(mask_image.data, 255, gen_params.get_resolved_width() * gen_params.get_resolved_height());
     }
 
     if (gen_params.control_image_path.size() > 0) {
-        int width = 0;
-        int height = 0;
-        control_image.data = load_image_from_file(gen_params.control_image_path.c_str(), width, height, gen_params.width, gen_params.height);
-        if (control_image.data == nullptr) {
+        if (!load_sd_image_from_file(&control_image,
+                                     gen_params.control_image_path.c_str(),
+                                     gen_params.get_resolved_width(),
+                                     gen_params.get_resolved_height())) {
             LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str());
             release_all_resources();
             return 1;
@@ -621,29 +643,11 @@ int main(int argc, const char* argv[]) {
         }
     }
 
-    if (gen_params.ref_image_paths.size() > 0) {
-        vae_decode_only = false;
-        for (auto& path : gen_params.ref_image_paths) {
-            int width = 0;
-            int height = 0;
-            uint8_t* image_buffer = load_image_from_file(path.c_str(), width, height);
-            if (image_buffer == nullptr) {
-                LOG_ERROR("load image from '%s' failed", path.c_str());
-                release_all_resources();
-                return 1;
-            }
-            ref_images.push_back({(uint32_t)width,
-                                  (uint32_t)height,
-                                  3,
-                                  image_buffer});
-        }
-    }
-
     if (!gen_params.control_video_path.empty()) {
         if (!load_images_from_dir(gen_params.control_video_path,
                                   control_frames,
-                                  gen_params.width,
-                                  gen_params.height,
+                                  gen_params.get_resolved_width(),
+                                  gen_params.get_resolved_height(),
                                   gen_params.video_frames,
                                   cli_params.verbose)) {
             release_all_resources();
@@ -717,8 +721,8 @@ int main(int argc, const char* argv[]) {
         gen_params.auto_resize_ref_image,
         gen_params.increase_ref_index,
         mask_image,
-        gen_params.width,
-        gen_params.height,
+        gen_params.get_resolved_width(),
+        gen_params.get_resolved_height(),
         gen_params.sample_params,
         gen_params.strength,
         gen_params.seed,
@@ -748,8 +752,8 @@ int main(int argc, const char* argv[]) {
         end_image,
         control_frames.data(),
         (int)control_frames.size(),
-        gen_params.width,
-        gen_params.height,
+        gen_params.get_resolved_width(),
+        gen_params.get_resolved_height(),
         gen_params.sample_params,
         gen_params.high_noise_sample_params,
         gen_params.moe_boundary,
@@ -757,6 +761,7 @@ int main(int argc, const char* argv[]) {
         gen_params.seed,
         gen_params.video_frames,
         gen_params.vace_strength,
+        ctx_params.vae_tiling_params,
         gen_params.cache_params,
     };
 
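With the output size now resolved from the loaded input image when `-W`/`-H` are not given, an image-to-image call no longer needs explicit dimensions. A minimal sketch only, assuming `-i` is the init-image option of sd-cli and `input.png` is any local image (both are placeholders, not part of this change):

```
.\bin\Release\sd-cli.exe --diffusion-model my_model.safetensors --vae my_vae.safetensors -i input.png -p "turn the photo into a watercolor painting" --strength 0.6
```

If no width/height is set anywhere, generation falls back to the 512x512 defaults returned by `get_resolved_width()`/`get_resolved_height()`.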
@@ -445,7 +445,7 @@ struct SDContextParams {
     std::string photo_maker_path;
     sd_type_t wtype = SD_TYPE_COUNT;
     std::string tensor_type_rules;
-    std::string lora_model_dir;
+    std::string lora_model_dir = ".";
 
     std::map<std::string, std::string> embedding_map;
     std::vector<sd_embedding_t> embedding_vec;
@@ -457,6 +457,7 @@ struct SDContextParams {
     bool control_net_cpu = false;
     bool clip_on_cpu = false;
     bool vae_on_cpu = false;
+    bool flash_attn = false;
     bool diffusion_flash_attn = false;
     bool diffusion_conv_direct = false;
     bool vae_conv_direct = false;
@@ -615,9 +616,13 @@ struct SDContextParams {
          "--vae-on-cpu",
          "keep vae in cpu (for low vram)",
          true, &vae_on_cpu},
+        {"",
+         "--fa",
+         "use flash attention",
+         true, &flash_attn},
         {"",
          "--diffusion-fa",
-         "use flash attention in the diffusion model",
+         "use flash attention in the diffusion model only",
          true, &diffusion_flash_attn},
         {"",
          "--diffusion-conv-direct",
@@ -904,6 +909,7 @@ struct SDContextParams {
            << " control_net_cpu: " << (control_net_cpu ? "true" : "false") << ",\n"
            << " clip_on_cpu: " << (clip_on_cpu ? "true" : "false") << ",\n"
            << " vae_on_cpu: " << (vae_on_cpu ? "true" : "false") << ",\n"
+           << " flash_attn: " << (flash_attn ? "true" : "false") << ",\n"
            << " diffusion_flash_attn: " << (diffusion_flash_attn ? "true" : "false") << ",\n"
            << " diffusion_conv_direct: " << (diffusion_conv_direct ? "true" : "false") << ",\n"
            << " vae_conv_direct: " << (vae_conv_direct ? "true" : "false") << ",\n"
@@ -968,6 +974,7 @@ struct SDContextParams {
         clip_on_cpu,
         control_net_cpu,
         vae_on_cpu,
+        flash_attn,
         diffusion_flash_attn,
         taesd_preview,
         diffusion_conv_direct,
@@ -1024,8 +1031,8 @@ struct SDGenerationParams {
     std::string prompt_with_lora; // for metadata record only
     std::string negative_prompt;
     int clip_skip = -1; // <= 0 represents unspecified
-    int width = 512;
-    int height = 512;
+    int width = -1;
+    int height = -1;
     int batch_count = 1;
     std::string init_image_path;
     std::string end_image_path;
@@ -1478,17 +1485,17 @@ struct SDGenerationParams {
          on_seed_arg},
         {"",
          "--sampling-method",
-         "sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd] "
+         "sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep, res_2s] "
          "(default: euler for Flux/SD3/Wan, euler_a otherwise)",
          on_sample_method_arg},
         {"",
          "--high-noise-sampling-method",
-         "(high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd]"
+         "(high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep, res_2s]"
          " default: euler for Flux/SD3/Wan, euler_a otherwise",
          on_high_noise_sample_method_arg},
         {"",
          "--scheduler",
-         "denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm], default: discrete",
+         "denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent], default: discrete",
          on_scheduler_arg},
         {"",
          "--sigmas",
@@ -1594,10 +1601,30 @@ struct SDGenerationParams {
         load_if_exists("skip_layers", skip_layers);
         load_if_exists("high_noise_skip_layers", high_noise_skip_layers);
 
+        load_if_exists("steps", sample_params.sample_steps);
+        load_if_exists("high_noise_steps", high_noise_sample_params.sample_steps);
         load_if_exists("cfg_scale", sample_params.guidance.txt_cfg);
         load_if_exists("img_cfg_scale", sample_params.guidance.img_cfg);
         load_if_exists("guidance", sample_params.guidance.distilled_guidance);
 
+        auto load_sampler_if_exists = [&](const char* key, enum sample_method_t& out) {
+            if (j.contains(key) && j[key].is_string()) {
+                enum sample_method_t tmp = str_to_sample_method(j[key].get<std::string>().c_str());
+                if (tmp != SAMPLE_METHOD_COUNT) {
+                    out = tmp;
+                }
+            }
+        };
+        load_sampler_if_exists("sample_method", sample_params.sample_method);
+        load_sampler_if_exists("high_noise_sample_method", high_noise_sample_params.sample_method);
+
+        if (j.contains("scheduler") && j["scheduler"].is_string()) {
+            enum scheduler_t tmp = str_to_scheduler(j["scheduler"].get<std::string>().c_str());
+            if (tmp != SCHEDULER_COUNT) {
+                sample_params.scheduler = tmp;
+            }
+        }
+
         return true;
     }
 
@@ -1685,17 +1712,24 @@ struct SDGenerationParams {
         }
     }
 
+    bool width_and_height_are_set() const {
+        return width > 0 && height > 0;
+    }
+
+    void set_width_and_height_if_unset(int w, int h) {
+        if (!width_and_height_are_set()) {
+            LOG_INFO("set width x height to %d x %d", w, h);
+            width = w;
+            height = h;
+        }
+    }
+
+    int get_resolved_width() const { return (width > 0) ? width : 512; }
+
+    int get_resolved_height() const { return (height > 0) ? height : 512; }
+
     bool process_and_check(SDMode mode, const std::string& lora_model_dir) {
         prompt_with_lora = prompt;
-        if (width <= 0) {
-            LOG_ERROR("error: the width must be greater than 0\n");
-            return false;
-        }
-
-        if (height <= 0) {
-            LOG_ERROR("error: the height must be greater than 0\n");
-            return false;
-        }
-
         if (sample_params.sample_steps <= 0) {
             LOG_ERROR("error: the sample_steps must be greater than 0\n");
@@ -2063,6 +2097,22 @@ uint8_t* load_image_from_file(const char* image_path,
     return load_image_common(false, image_path, 0, width, height, expected_width, expected_height, expected_channel);
 }
 
+bool load_sd_image_from_file(sd_image_t* image,
+                             const char* image_path,
+                             int expected_width = 0,
+                             int expected_height = 0,
+                             int expected_channel = 3) {
+    int width;
+    int height;
+    image->data = load_image_common(false, image_path, 0, width, height, expected_width, expected_height, expected_channel);
+    if (image->data == nullptr) {
+        return false;
+    }
+    image->width = width;
+    image->height = height;
+    return true;
+}
+
 uint8_t* load_image_from_memory(const char* image_bytes,
                                 int len,
                                 int& width,
@@ -44,7 +44,8 @@ Context Options:
 --clip-on-cpu keep clip in cpu (for low vram)
 --vae-on-cpu keep vae in cpu (for low vram)
 --mmap whether to memory-map model
---diffusion-fa use flash attention in the diffusion model
+--fa use flash attention
+--diffusion-fa use flash attention in the diffusion model only
 --diffusion-conv-direct use ggml_conv2d_direct in the diffusion model
 --vae-conv-direct use ggml_conv2d_direct in the vae model
 --circular enable circular padding for convolutions
@@ -99,14 +100,14 @@ Default Generation Options:
 medium
 --skip-layer-start <float> SLG enabling point (default: 0.01)
 --skip-layer-end <float> SLG disabling point (default: 0.2)
---eta <float> eta in DDIM, only for DDIM and TCD (default: 0)
+--eta <float> eta in DDIM, only for DDIM/TCD/res_multistep/res_2s (default: 0)
 --high-noise-cfg-scale <float> (high noise) unconditional guidance scale: (default: 7.0)
 --high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale)
 --high-noise-guidance <float> (high noise) distilled guidance scale for models with guidance input (default: 3.5)
 --high-noise-slg-scale <float> (high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0)
 --high-noise-skip-layer-start <float> (high noise) SLG enabling point (default: 0.01)
 --high-noise-skip-layer-end <float> (high noise) SLG disabling point (default: 0.2)
---high-noise-eta <float> (high noise) eta in DDIM, only for DDIM and TCD (default: 0)
+--high-noise-eta <float> (high noise) eta in DDIM, only for DDIM/TCD/res_multistep/res_2s (default: 0)
 --strength <float> strength for noising/unnoising (default: 0.75)
 --pm-style-strength <float>
 --control-strength <float> strength to apply Control Net (default: 0.9). 1.0 corresponds to full destruction of information in init image
@@ -115,12 +116,12 @@ Default Generation Options:
 --increase-ref-index automatically increase the indices of references images based on the order they are listed (starting with 1).
 --disable-auto-resize-ref-image disable auto resize of ref images
 -s, --seed RNG seed (default: 42, use random seed for < 0)
---sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing,
-tcd] (default: euler for Flux/SD3/Wan, euler_a otherwise)
+--sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd,
+res_multistep, res_2s] (default: euler for Flux/SD3/Wan, euler_a otherwise)
---high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm,
-ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise
+--high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing,
+tcd, res_multistep, res_2s] default: euler for Flux/SD3/Wan, euler_a otherwise
 --scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple,
-kl_optimal, lcm], default: discrete
+kl_optimal, lcm, bong_tangent], default: discrete
 --sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0").
 --skip-layers layers to skip for SLG steps (default: [7,8,9])
 --high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
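The same flags are accepted by the server at startup. A minimal sketch only, assuming the server target builds as `sd-server` and using placeholder model paths; by default it listens on 127.0.0.1:1234 as set in SDSvrParams below:

```
.\bin\Release\sd-server.exe --diffusion-model my_model.safetensors --vae my_vae.safetensors --fa
```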
@@ -86,21 +86,6 @@ std::vector<uint8_t> base64_decode(const std::string& encoded_string) {
     return ret;
 }
 
-std::string iso_timestamp_now() {
-    using namespace std::chrono;
-    auto now = system_clock::now();
-    std::time_t t = system_clock::to_time_t(now);
-    std::tm tm{};
-#ifdef _MSC_VER
-    gmtime_s(&tm, &t);
-#else
-    gmtime_r(&t, &tm);
-#endif
-    std::ostringstream oss;
-    oss << std::put_time(&tm, "%Y-%m-%dT%H:%M:%SZ");
-    return oss.str();
-}
-
 struct SDSvrParams {
     std::string listen_ip = "127.0.0.1";
     int listen_port = 1234;
@@ -404,7 +389,7 @@ int main(int argc, const char** argv) {
         }
 
         json out;
-        out["created"] = iso_timestamp_now();
+        out["created"] = static_cast<long long>(std::time(nullptr));
         out["data"] = json::array();
         out["output_format"] = output_format;
 
@@ -420,6 +405,9 @@ int main(int argc, const char** argv) {
             return;
         }
 
+        if (gen_params.sample_params.sample_steps > 100)
+            gen_params.sample_params.sample_steps = 100;
+
         if (!gen_params.process_and_check(IMG_GEN, "")) {
             res.status = 400;
             res.set_content(R"({"error":"invalid params"})", "application/json");
@@ -537,7 +525,7 @@ int main(int argc, const char** argv) {
         }
 
         std::vector<uint8_t> mask_bytes;
-        if (req.form.has_field("mask")) {
+        if (req.form.has_file("mask")) {
             auto file = req.form.get_file("mask");
             mask_bytes.assign(file.content.begin(), file.content.end());
         }
@@ -598,6 +586,9 @@ int main(int argc, const char** argv) {
             return;
         }
 
+        if (gen_params.sample_params.sample_steps > 100)
+            gen_params.sample_params.sample_steps = 100;
+
         if (!gen_params.process_and_check(IMG_GEN, "")) {
             res.status = 400;
             res.set_content(R"({"error":"invalid params"})", "application/json");
@@ -686,7 +677,7 @@ int main(int argc, const char** argv) {
         }
 
         json out;
-        out["created"] = iso_timestamp_now();
+        out["created"] = static_cast<long long>(std::time(nullptr));
         out["data"] = json::array();
         out["output_format"] = output_format;
 
@@ -726,6 +717,331 @@ int main(int argc, const char** argv) {
         }
     });
 
+    // sdapi endpoints (AUTOMATIC1111 / Forge)
+
+    auto sdapi_any2img = [&](const httplib::Request& req, httplib::Response& res, bool img2img) {
+        try {
+            if (req.body.empty()) {
+                res.status = 400;
+                res.set_content(R"({"error":"empty body"})", "application/json");
+                return;
+            }
+
+            json j = json::parse(req.body);
+
+            std::string prompt = j.value("prompt", "");
+            std::string negative_prompt = j.value("negative_prompt", "");
+            int width = j.value("width", 512);
+            int height = j.value("height", 512);
+            int steps = j.value("steps", -1);
+            float cfg_scale = j.value("cfg_scale", 7.f);
+            int64_t seed = j.value("seed", -1);
+            int batch_size = j.value("batch_size", 1);
+            int clip_skip = j.value("clip_skip", -1);
+            std::string sampler_name = j.value("sampler_name", "");
+            std::string scheduler_name = j.value("scheduler", "");
+
+            auto bad = [&](const std::string& msg) {
+                res.status = 400;
+                res.set_content("{\"error\":\"" + msg + "\"}", "application/json");
+                return;
+            };
+
+            if (width <= 0 || height <= 0) {
+                return bad("width and height must be positive");
+            }
+
+            if (steps < 1 || steps > 150) {
+                return bad("steps must be in range [1, 150]");
+            }
+
+            if (batch_size < 1 || batch_size > 8) {
+                return bad("batch_size must be in range [1, 8]");
+            }
+
+            if (cfg_scale < 0.f) {
+                return bad("cfg_scale must be positive");
+            }
+
+            if (prompt.empty()) {
+                return bad("prompt required");
+            }
+
+            auto get_sample_method = [](std::string name) -> enum sample_method_t {
+                enum sample_method_t result = str_to_sample_method(name.c_str());
+                if (result != SAMPLE_METHOD_COUNT) return result;
+                // some applications use a hardcoded sampler list
+                std::transform(name.begin(), name.end(), name.begin(),
+                               [](unsigned char c) { return std::tolower(c); });
+                static const std::unordered_map<std::string_view, sample_method_t> hardcoded{
+                    {"euler a", EULER_A_SAMPLE_METHOD},
+                    {"k_euler_a", EULER_A_SAMPLE_METHOD},
+                    {"euler", EULER_SAMPLE_METHOD},
+                    {"k_euler", EULER_SAMPLE_METHOD},
+                    {"heun", HEUN_SAMPLE_METHOD},
+                    {"k_heun", HEUN_SAMPLE_METHOD},
+                    {"dpm2", DPM2_SAMPLE_METHOD},
+                    {"k_dpm_2", DPM2_SAMPLE_METHOD},
+                    {"lcm", LCM_SAMPLE_METHOD},
+                    {"ddim", DDIM_TRAILING_SAMPLE_METHOD},
+                    {"dpm++ 2m", DPMPP2M_SAMPLE_METHOD},
+                    {"k_dpmpp_2m", DPMPP2M_SAMPLE_METHOD},
+                    {"res multistep", RES_MULTISTEP_SAMPLE_METHOD},
+                    {"k_res_multistep", RES_MULTISTEP_SAMPLE_METHOD},
+                    {"res 2s", RES_2S_SAMPLE_METHOD},
+                    {"k_res_2s", RES_2S_SAMPLE_METHOD}};
+                auto it = hardcoded.find(name);
+                if (it != hardcoded.end()) return it->second;
+                return SAMPLE_METHOD_COUNT;
+            };
+
+            enum sample_method_t sample_method = get_sample_method(sampler_name);
+
+            enum scheduler_t scheduler = str_to_scheduler(scheduler_name.c_str());
+
+            // avoid excessive resource usage
+
+            SDGenerationParams gen_params = default_gen_params;
+            gen_params.prompt = prompt;
+            gen_params.negative_prompt = negative_prompt;
+            gen_params.width = width;
+            gen_params.height = height;
+            gen_params.seed = seed;
+            gen_params.sample_params.sample_steps = steps;
+            gen_params.batch_count = batch_size;
+
+            if (clip_skip > 0) {
+                gen_params.clip_skip = clip_skip;
+            }
+
+            if (sample_method != SAMPLE_METHOD_COUNT) {
+                gen_params.sample_params.sample_method = sample_method;
+            }
+
+            if (scheduler != SCHEDULER_COUNT) {
+                gen_params.sample_params.scheduler = scheduler;
+            }
+
+            LOG_DEBUG("%s\n", gen_params.to_string().c_str());
+
+            sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
+            sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
+            sd_image_t mask_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 1, nullptr};
+            std::vector<uint8_t> mask_data;
+            std::vector<sd_image_t> pmid_images;
+            std::vector<sd_image_t> ref_images;
+
+            if (img2img) {
+                auto decode_image = [](sd_image_t& image, std::string encoded) -> bool {
+                    // remove data URI prefix if present ("data:image/png;base64,")
+                    auto comma_pos = encoded.find(',');
+                    if (comma_pos != std::string::npos) {
+                        encoded = encoded.substr(comma_pos + 1);
+                    }
+                    std::vector<uint8_t> img_data = base64_decode(encoded);
+                    if (!img_data.empty()) {
+                        int img_w = image.width;
+                        int img_h = image.height;
+                        uint8_t* raw_data = load_image_from_memory(
+                            (const char*)img_data.data(), (int)img_data.size(),
+                            img_w, img_h,
+                            image.width, image.height, image.channel);
+                        if (raw_data) {
+                            image = {(uint32_t)img_w, (uint32_t)img_h, image.channel, raw_data};
+                            return true;
+                        }
+                    }
+                    return false;
+                };
+
+                if (j.contains("init_images") && j["init_images"].is_array() && !j["init_images"].empty()) {
+                    std::string encoded = j["init_images"][0].get<std::string>();
+                    decode_image(init_image, encoded);
+                }
+
+                if (j.contains("mask") && j["mask"].is_string()) {
+                    std::string encoded = j["mask"].get<std::string>();
+                    decode_image(mask_image, encoded);
+                    bool inpainting_mask_invert = j.value("inpainting_mask_invert", 0) != 0;
+                    if (inpainting_mask_invert && mask_image.data != nullptr) {
+                        for (uint32_t i = 0; i < mask_image.width * mask_image.height; i++) {
+                            mask_image.data[i] = 255 - mask_image.data[i];
+                        }
+                    }
+                } else {
+                    mask_data = std::vector<uint8_t>(width * height, 255);
+                    mask_image.width = width;
+                    mask_image.height = height;
+                    mask_image.channel = 1;
+                    mask_image.data = mask_data.data();
+                }
+
+                if (j.contains("extra_images") && j["extra_images"].is_array()) {
+                    for (auto extra_image : j["extra_images"]) {
+                        std::string encoded = extra_image.get<std::string>();
+                        sd_image_t tmp_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
+                        if (decode_image(tmp_image, encoded)) {
+                            ref_images.push_back(tmp_image);
+                        }
+                    }
+                }
+
+                float denoising_strength = j.value("denoising_strength", -1.f);
+                if (denoising_strength >= 0.f) {
+                    denoising_strength = std::min(denoising_strength, 1.0f);
+                    gen_params.strength = denoising_strength;
+                }
+            }
+
+            sd_img_gen_params_t img_gen_params = {
+                gen_params.lora_vec.data(),
+                static_cast<uint32_t>(gen_params.lora_vec.size()),
+                gen_params.prompt.c_str(),
+                gen_params.negative_prompt.c_str(),
+                gen_params.clip_skip,
+                init_image,
+                ref_images.data(),
+                (int)ref_images.size(),
+                gen_params.auto_resize_ref_image,
+                gen_params.increase_ref_index,
+                mask_image,
+                gen_params.width,
+                gen_params.height,
+                gen_params.sample_params,
+                gen_params.strength,
+                gen_params.seed,
+                gen_params.batch_count,
+                control_image,
+                gen_params.control_strength,
+                {
+                    pmid_images.data(),
+                    (int)pmid_images.size(),
+                    gen_params.pm_id_embed_path.c_str(),
+                    gen_params.pm_style_strength,
+                },  // pm_params
+                ctx_params.vae_tiling_params,
+                gen_params.cache_params,
+            };
+
+            sd_image_t* results = nullptr;
+            int num_results = 0;
+
+            {
+                std::lock_guard<std::mutex> lock(sd_ctx_mutex);
+                results = generate_image(sd_ctx, &img_gen_params);
+                num_results = gen_params.batch_count;
+            }
+
+            json out;
+            out["images"] = json::array();
+            out["parameters"] = j;  // TODO should return changed defaults
+            out["info"] = "";
+
+            for (int i = 0; i < num_results; i++) {
+                if (results[i].data == nullptr) {
+                    continue;
+                }
+
+                auto image_bytes = write_image_to_vector(ImageFormat::PNG,
+                                                         results[i].data,
+                                                         results[i].width,
+                                                         results[i].height,
+                                                         results[i].channel);
+
+                if (image_bytes.empty()) {
+                    LOG_ERROR("write image to mem failed");
+                    continue;
+                }
+
+                std::string b64 = base64_encode(image_bytes);
+                out["images"].push_back(b64);
+            }
+
+            res.set_content(out.dump(), "application/json");
+            res.status = 200;
+
+            if (init_image.data) {
+                stbi_image_free(init_image.data);
+            }
+            if (mask_image.data && mask_data.empty()) {
+                stbi_image_free(mask_image.data);
+            }
+            for (auto ref_image : ref_images) {
+                stbi_image_free(ref_image.data);
+            }
+
+        } catch (const std::exception& e) {
+            res.status = 500;
+            json err;
+            err["error"] = "server_error";
+            err["message"] = e.what();
+            res.set_content(err.dump(), "application/json");
+        }
+    };
+
+    svr.Post("/sdapi/v1/txt2img", [&](const httplib::Request& req, httplib::Response& res) {
+        sdapi_any2img(req, res, false);
+    });
+
+    svr.Post("/sdapi/v1/img2img", [&](const httplib::Request& req, httplib::Response& res) {
+        sdapi_any2img(req, res, true);
+    });
+
+    svr.Get("/sdapi/v1/samplers", [&](const httplib::Request&, httplib::Response& res) {
+        std::vector<std::string> sampler_names;
+        sampler_names.push_back("default");
+        for (int i = 0; i < SAMPLE_METHOD_COUNT; i++) {
+            sampler_names.push_back(sd_sample_method_name((sample_method_t)i));
+        }
+        json r = json::array();
+        for (auto name : sampler_names) {
+            json entry;
+            entry["name"] = name;
+            entry["aliases"] = json::array({name});
+            entry["options"] = json::object();
+            r.push_back(entry);
+        }
+        res.set_content(r.dump(), "application/json");
+    });
+
+    svr.Get("/sdapi/v1/schedulers", [&](const httplib::Request&, httplib::Response& res) {
+        std::vector<std::string> scheduler_names;
+        scheduler_names.push_back("default");
+        for (int i = 0; i < SCHEDULER_COUNT; i++) {
+            scheduler_names.push_back(sd_scheduler_name((scheduler_t)i));
+        }
+        json r = json::array();
+        for (auto name : scheduler_names) {
+            json entry;
+            entry["name"] = name;
+            entry["label"] = name;
+            r.push_back(entry);
+        }
+        res.set_content(r.dump(), "application/json");
+    });
+
+    svr.Get("/sdapi/v1/sd-models", [&](const httplib::Request&, httplib::Response& res) {
+        fs::path model_path = ctx_params.model_path;
+        json entry;
+        entry["title"] = model_path.stem();
+        entry["model_name"] = model_path.stem();
+        entry["filename"] = model_path.filename();
+        entry["hash"] = "8888888888";
+        entry["sha256"] = "8888888888888888888888888888888888888888888888888888888888888888";
+        entry["config"] = nullptr;
+        json r = json::array();
+        r.push_back(entry);
+        res.set_content(r.dump(), "application/json");
+    });
+
+    svr.Get("/sdapi/v1/options", [&](const httplib::Request&, httplib::Response& res) {
+        fs::path model_path = ctx_params.model_path;
+        json r;
+        r["samples_format"] = "png";
+        r["sd_model_checkpoint"] = model_path.stem();
+        res.set_content(r.dump(), "application/json");
+    });
+
     LOG_INFO("listening on: %s:%d\n", svr_params.listen_ip.c_str(), svr_params.listen_port);
     svr.listen(svr_params.listen_ip, svr_params.listen_port);
 
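The new sdapi endpoints accept the usual AUTOMATIC1111-style JSON body. A minimal request sketch against the default listen address (the field values are illustrative only; note that the handler above requires a non-empty prompt and steps in [1, 150]):

```
curl -X POST http://127.0.0.1:1234/sdapi/v1/txt2img -H "Content-Type: application/json" -d "{\"prompt\": \"a lovely cat\", \"width\": 512, \"height\": 512, \"steps\": 20, \"cfg_scale\": 7.0, \"sampler_name\": \"euler a\", \"seed\": -1}"
```

The response mirrors the handler: a JSON object with base64-encoded PNGs in `images`, the echoed request in `parameters`, and an empty `info` string. GET `/sdapi/v1/samplers`, `/sdapi/v1/schedulers`, `/sdapi/v1/sd-models`, and `/sdapi/v1/options` return the discovery data many A1111-compatible frontends expect.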
148
flux.hpp
@@ -103,7 +103,7 @@ namespace Flux {
        auto norm = std::dynamic_pointer_cast<QKNorm>(blocks["norm"]);
 
        auto qkv = qkv_proj->forward(ctx, x);
-        auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv);
+        auto qkv_vec = ggml_ext_chunk(ctx->ggml_ctx, qkv, 3, 0, true);
        int64_t head_dim = qkv_vec[0]->ne[0] / num_heads;
        auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]);
        auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]);
@@ -153,7 +153,7 @@ namespace Flux {
        if (use_mlp_silu_act) {
            x = ggml_ext_silu_act(ctx->ggml_ctx, x);
        } else {
-            x = ggml_gelu_inplace(ctx->ggml_ctx, x);
+            x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
        }
        x = mlp_2->forward(ctx, x);
        return x;
@@ -376,26 +376,23 @@ namespace Flux {
        auto k = ggml_concat(ctx->ggml_ctx, txt_k, img_k, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
        auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
 
        auto attn = Rope::attention(ctx, q, k, v, pe, mask);  // [N, n_txt_token + n_img_token, n_head*d_head]
-        attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3));  // [n_txt_token + n_img_token, N, hidden_size]
        auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx,
                                         attn,
                                         attn->ne[0],
-                                         attn->ne[1],
                                         txt->ne[1],
+                                         attn->ne[2],
                                         attn->nb[1],
                                         attn->nb[2],
-                                         0);  // [n_txt_token, N, hidden_size]
+                                         0);  // [N, n_txt_token, hidden_size]
-        txt_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_attn_out, 0, 2, 1, 3));  // [N, n_txt_token, hidden_size]
        auto img_attn_out = ggml_view_3d(ctx->ggml_ctx,
                                         attn,
                                         attn->ne[0],
-                                         attn->ne[1],
                                         img->ne[1],
+                                         attn->ne[2],
                                         attn->nb[1],
                                         attn->nb[2],
-                                         attn->nb[2] * txt->ne[1]);  // [n_img_token, N, hidden_size]
+                                         txt->ne[1] * attn->nb[1]);  // [N, n_img_token, hidden_size]
-        img_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img_attn_out, 0, 2, 1, 3));  // [N, n_img_token, hidden_size]
 
        // calculate the img bloks
        img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate));
@@ -492,43 +489,29 @@ namespace Flux {
        }
 
        auto x_mod = Flux::modulate(ctx->ggml_ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale);
-        auto qkv_mlp = linear1->forward(ctx, x_mod);  // [N, n_token, hidden_size * 3 + mlp_hidden_dim]
-        qkv_mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv_mlp, 2, 0, 1, 3));  // [hidden_size * 3 + mlp_hidden_dim, N, n_token]
+        auto qkv_mlp = linear1->forward(ctx, x_mod);  // [N, n_token, hidden_size * 3 + mlp_hidden_dim*mlp_mult_factor]
 
-        auto qkv = ggml_view_3d(ctx->ggml_ctx,
-                                qkv_mlp,
-                                qkv_mlp->ne[0],
-                                qkv_mlp->ne[1],
-                                hidden_size * 3,
-                                qkv_mlp->nb[1],
-                                qkv_mlp->nb[2],
-                                0);  // [hidden_size * 3 , N, n_token]
-        qkv = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv, 1, 2, 0, 3));  // [N, n_token, hidden_size * 3]
-        auto mlp = ggml_view_3d(ctx->ggml_ctx,
-                                qkv_mlp,
-                                qkv_mlp->ne[0],
-                                qkv_mlp->ne[1],
-                                mlp_hidden_dim * mlp_mult_factor,
-                                qkv_mlp->nb[1],
-                                qkv_mlp->nb[2],
-                                qkv_mlp->nb[2] * hidden_size * 3);  // [mlp_hidden_dim*mlp_mult_factor , N, n_token]
-        mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, mlp, 1, 2, 0, 3));  // [N, n_token, mlp_hidden_dim*mlp_mult_factor]
-
-        auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv);  // q,k,v: [N, n_token, hidden_size]
+        auto q = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], 0);
+        auto k = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * qkv_mlp->nb[0]);
+        auto v = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * 2 * qkv_mlp->nb[0]);
        int64_t head_dim = hidden_size / num_heads;
-        auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]);  // [N, n_token, n_head, d_head]
-        auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]);  // [N, n_token, n_head, d_head]
-        auto v = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]);  // [N, n_token, n_head, d_head]
-        q = norm->query_norm(ctx, q);
-        k = norm->key_norm(ctx, k);
-        auto attn = Rope::attention(ctx, q, k, v, pe, mask);  // [N, n_token, hidden_size]
-
+        q = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, q), head_dim, num_heads, q->ne[1], q->ne[2]);  // [N, n_token, n_head, d_head]
+        k = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, k), head_dim, num_heads, k->ne[1], k->ne[2]);  // [N, n_token, n_head, d_head]
+        v = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, v), head_dim, num_heads, v->ne[1], v->ne[2]);  // [N, n_token, n_head, d_head]
+
+        q = norm->query_norm(ctx, q);
+        k = norm->key_norm(ctx, k);
+        auto attn = Rope::attention(ctx, q, k, v, pe, mask);  // [N, n_token, hidden_size]
+
+        auto mlp = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, mlp_hidden_dim * mlp_mult_factor, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * 3 * qkv_mlp->nb[0]);
        if (use_yak_mlp) {
            mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp, false);
        } else if (use_mlp_silu_act) {
            mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp);
        } else {
-            mlp = ggml_gelu_inplace(ctx->ggml_ctx, mlp);
+            mlp = ggml_ext_gelu(ctx->ggml_ctx, mlp, true);
        }
        auto attn_mlp = ggml_concat(ctx->ggml_ctx, attn, mlp, 0);  // [N, n_token, hidden_size + mlp_hidden_dim]
        auto output = linear2->forward(ctx, attn_mlp);  // [N, n_token, hidden_size]
@@ -580,13 +563,10 @@ namespace Flux {
        } else {
            auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
 
            auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c));  // [N, 2 * hidden_size]
-            m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], 2, c->ne[1]);  // [N, 2, hidden_size]
-            m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3));  // [2, N, hidden_size]
-
-            int64_t offset = m->nb[1] * m->ne[1];
-            shift = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0);  // [N, hidden_size]
-            scale = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1);  // [N, hidden_size]
+            auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, 2, 0);
+            shift = m_vec[0];  // [N, hidden_size]
+            scale = m_vec[1];  // [N, hidden_size]
        }
 
        x = Flux::modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale);
@@ -748,7 +728,7 @@ namespace Flux {
        int nerf_depth = 4;
        int nerf_max_freqs = 8;
        bool use_x0 = false;
bool use_patch_size_32 = false;
|
bool fake_patch_size_x2 = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct FluxParams {
|
struct FluxParams {
|
||||||
@ -786,8 +766,11 @@ namespace Flux {
|
|||||||
Flux(FluxParams params)
|
Flux(FluxParams params)
|
||||||
: params(params) {
|
: params(params) {
|
||||||
if (params.version == VERSION_CHROMA_RADIANCE) {
|
if (params.version == VERSION_CHROMA_RADIANCE) {
|
||||||
std::pair<int, int> kernel_size = {16, 16};
|
std::pair<int, int> kernel_size = {params.patch_size, params.patch_size};
|
||||||
std::pair<int, int> stride = kernel_size;
|
if (params.chroma_radiance_params.fake_patch_size_x2) {
|
||||||
|
kernel_size = {params.patch_size / 2, params.patch_size / 2};
|
||||||
|
}
|
||||||
|
std::pair<int, int> stride = kernel_size;
|
||||||
|
|
||||||
blocks["img_in_patch"] = std::make_shared<Conv2d>(params.in_channels,
|
blocks["img_in_patch"] = std::make_shared<Conv2d>(params.in_channels,
|
||||||
params.hidden_size,
|
params.hidden_size,
|
||||||
@ -1031,16 +1014,14 @@ namespace Flux {
|
|||||||
txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask, ss_mods);
|
txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask, ss_mods);
|
||||||
}
|
}
|
||||||
|
|
||||||
txt_img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_img, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
|
img = ggml_view_3d(ctx->ggml_ctx,
|
||||||
img = ggml_view_3d(ctx->ggml_ctx,
|
txt_img,
|
||||||
txt_img,
|
txt_img->ne[0],
|
||||||
txt_img->ne[0],
|
img->ne[1],
|
||||||
txt_img->ne[1],
|
txt_img->ne[2],
|
||||||
img->ne[1],
|
txt_img->nb[1],
|
||||||
txt_img->nb[1],
|
txt_img->nb[2],
|
||||||
txt_img->nb[2],
|
txt->ne[1] * txt_img->nb[1]); // [N, n_img_token, hidden_size]
|
||||||
txt_img->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size]
|
|
||||||
img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
|
|
||||||
|
|
||||||
if (final_layer) {
|
if (final_layer) {
|
||||||
img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels)
|
img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels)
|
||||||
@ -1082,7 +1063,7 @@ namespace Flux {
|
|||||||
auto img = pad_to_patch_size(ctx, x);
|
auto img = pad_to_patch_size(ctx, x);
|
||||||
auto orig_img = img;
|
auto orig_img = img;
|
||||||
|
|
||||||
if (params.chroma_radiance_params.use_patch_size_32) {
|
if (params.chroma_radiance_params.fake_patch_size_x2) {
|
||||||
// It's supposed to be using GGML_SCALE_MODE_NEAREST, but this seems more stable
|
// It's supposed to be using GGML_SCALE_MODE_NEAREST, but this seems more stable
|
||||||
// Maybe the implementation of nearest-neighbor interpolation in ggml behaves differently than the one in PyTorch?
|
// Maybe the implementation of nearest-neighbor interpolation in ggml behaves differently than the one in PyTorch?
|
||||||
// img = F.interpolate(img, size=(H//2, W//2), mode="nearest")
|
// img = F.interpolate(img, size=(H//2, W//2), mode="nearest")
|
||||||
@ -1193,9 +1174,8 @@ namespace Flux {
|
|||||||
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size]
|
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size]
|
||||||
|
|
||||||
if (out->ne[1] > img_tokens) {
|
if (out->ne[1] > img_tokens) {
|
||||||
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size]
|
out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], img_tokens, out->ne[2], out->nb[1], out->nb[2], 0);
|
||||||
out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
|
out = ggml_cont(ctx->ggml_ctx, out);
|
||||||
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
|
// rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
|
||||||
@ -1288,13 +1268,9 @@ namespace Flux {
|
|||||||
} else if (version == VERSION_OVIS_IMAGE) {
|
} else if (version == VERSION_OVIS_IMAGE) {
|
||||||
flux_params.semantic_txt_norm = true;
|
flux_params.semantic_txt_norm = true;
|
||||||
flux_params.use_yak_mlp = true;
|
flux_params.use_yak_mlp = true;
|
||||||
flux_params.context_in_dim = 2048;
|
|
||||||
flux_params.vec_in_dim = 0;
|
flux_params.vec_in_dim = 0;
|
||||||
} else if (sd_version_is_flux2(version)) {
|
} else if (sd_version_is_flux2(version)) {
|
||||||
flux_params.context_in_dim = 15360;
|
|
||||||
flux_params.in_channels = 128;
|
flux_params.in_channels = 128;
|
||||||
flux_params.hidden_size = 6144;
|
|
||||||
flux_params.num_heads = 48;
|
|
||||||
flux_params.patch_size = 1;
|
flux_params.patch_size = 1;
|
||||||
flux_params.out_channels = 128;
|
flux_params.out_channels = 128;
|
||||||
flux_params.mlp_ratio = 3.f;
|
flux_params.mlp_ratio = 3.f;
|
||||||
@ -1307,12 +1283,13 @@ namespace Flux {
|
|||||||
flux_params.ref_index_scale = 10.f;
|
flux_params.ref_index_scale = 10.f;
|
||||||
flux_params.use_mlp_silu_act = true;
|
flux_params.use_mlp_silu_act = true;
|
||||||
}
|
}
|
||||||
|
int64_t head_dim = 0;
|
||||||
|
int64_t actual_radiance_patch_size = -1;
|
||||||
for (auto pair : tensor_storage_map) {
|
for (auto pair : tensor_storage_map) {
|
||||||
std::string tensor_name = pair.first;
|
std::string tensor_name = pair.first;
|
||||||
if (!starts_with(tensor_name, prefix))
|
if (!starts_with(tensor_name, prefix))
|
||||||
continue;
|
continue;
|
||||||
if (tensor_name.find("guidance_in.in_layer.weight") != std::string::npos) {
|
if (tensor_name.find("guidance_in.in_layer.weight") != std::string::npos) {
|
||||||
// not schnell
|
|
||||||
flux_params.guidance_embed = true;
|
flux_params.guidance_embed = true;
|
||||||
}
|
}
|
||||||
if (tensor_name.find("__x0__") != std::string::npos) {
|
if (tensor_name.find("__x0__") != std::string::npos) {
|
||||||
@ -1320,9 +1297,12 @@ namespace Flux {
|
|||||||
flux_params.chroma_radiance_params.use_x0 = true;
|
flux_params.chroma_radiance_params.use_x0 = true;
|
||||||
}
|
}
|
||||||
if (tensor_name.find("__32x32__") != std::string::npos) {
|
if (tensor_name.find("__32x32__") != std::string::npos) {
|
||||||
LOG_DEBUG("using patch size 32 prediction");
|
LOG_DEBUG("using patch size 32");
|
||||||
flux_params.chroma_radiance_params.use_patch_size_32 = true;
|
flux_params.patch_size = 32;
|
||||||
flux_params.patch_size = 32;
|
}
|
||||||
|
if (tensor_name.find("img_in_patch.weight") != std::string::npos) {
|
||||||
|
actual_radiance_patch_size = pair.second.ne[0];
|
||||||
|
LOG_DEBUG("actual radiance patch size: %d", actual_radiance_patch_size);
|
||||||
}
|
}
|
||||||
if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
|
if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
|
||||||
// Chroma
|
// Chroma
|
||||||
@ -1344,13 +1324,35 @@ namespace Flux {
|
|||||||
flux_params.depth_single_blocks = block_depth + 1;
|
flux_params.depth_single_blocks = block_depth + 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (ends_with(tensor_name, "txt_in.weight")) {
|
||||||
|
flux_params.context_in_dim = pair.second.ne[0];
|
||||||
|
flux_params.hidden_size = pair.second.ne[1];
|
||||||
|
}
|
||||||
|
if (ends_with(tensor_name, "single_blocks.0.norm.key_norm.scale")) {
|
||||||
|
head_dim = pair.second.ne[0];
|
||||||
|
}
|
||||||
|
if (ends_with(tensor_name, "double_blocks.0.txt_attn.norm.key_norm.scale")) {
|
||||||
|
head_dim = pair.second.ne[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (actual_radiance_patch_size > 0 && actual_radiance_patch_size != flux_params.patch_size) {
|
||||||
|
GGML_ASSERT(flux_params.patch_size == 2 * actual_radiance_patch_size);
|
||||||
|
LOG_DEBUG("using fake x2 patch size");
|
||||||
|
flux_params.chroma_radiance_params.fake_patch_size_x2 = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG_INFO("Flux blocks: %d double, %d single", flux_params.depth, flux_params.depth_single_blocks);
|
flux_params.num_heads = static_cast<int>(flux_params.hidden_size / head_dim);
|
||||||
|
|
||||||
|
LOG_INFO("flux: depth = %d, depth_single_blocks = %d, guidance_embed = %s, context_in_dim = %" PRId64
|
||||||
|
", hidden_size = %" PRId64 ", num_heads = %d",
|
||||||
|
flux_params.depth,
|
||||||
|
flux_params.depth_single_blocks,
|
||||||
|
flux_params.guidance_embed ? "true" : "false",
|
||||||
|
flux_params.context_in_dim,
|
||||||
|
flux_params.hidden_size,
|
||||||
|
flux_params.num_heads);
|
||||||
if (flux_params.is_chroma) {
|
if (flux_params.is_chroma) {
|
||||||
LOG_INFO("Using pruned modulation (Chroma)");
|
LOG_INFO("Using pruned modulation (Chroma)");
|
||||||
} else if (!flux_params.guidance_embed) {
|
|
||||||
LOG_INFO("Flux guidance is disabled (Schnell mode)");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
flux = Flux(flux_params);
|
flux = Flux(flux_params);
|
||||||
|
|||||||
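Note (not part of the diff): a minimal standalone C++ sketch of the idea behind the new SingleStreamBlock code above — q, k, v and the MLP branch are read as sub-ranges of the fused linear1 output at offsets 0, hidden_size, 2*hidden_size and 3*hidden_size along the innermost dimension, instead of permuting the whole tensor back and forth. All names and sizes below are illustrative, not taken from the repository.

    #include <cstdio>
    #include <vector>

    int main() {
        // Toy dimensions standing in for hidden_size and the MLP width.
        const int hidden = 4, mlp = 6, n_token = 2;
        const int fused = 3 * hidden + mlp;      // width of the fused linear output
        std::vector<float> row(fused * n_token);
        for (int i = 0; i < (int)row.size(); ++i) row[i] = (float)i;

        // For each token, q/k/v/mlp are contiguous sub-ranges of the same row.
        for (int t = 0; t < n_token; ++t) {
            const float* base = row.data() + t * fused;
            const float* q = base + 0 * hidden;
            const float* k = base + 1 * hidden;
            const float* v = base + 2 * hidden;
            const float* m = base + 3 * hidden;  // MLP branch starts after q, k, v
            std::printf("token %d: q[0]=%g k[0]=%g v[0]=%g mlp[0]=%g\n",
                        t, q[0], k[0], v[0], m[0]);
        }
        return 0;
    }

The views avoid materializing intermediate copies; judging from the diff, contiguity is restored with ggml_cont only where the following reshape needs it.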
ggml (2 changed lines)
@ -1 +1 @@
- Subproject commit 3e9f2ba3b934c20b26873b3c60dbf41b116978ff
+ Subproject commit a8db410a252c8c8f2d120c6f2e7133ebe032f35d
ggml_extend.hpp (127 changed lines)
@ -687,7 +687,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_slice(struct ggml_context* ctx,
int dim,
int64_t start,
- int64_t end) {
+ int64_t end,
+ bool cont = true) {
GGML_ASSERT(dim >= 0 && dim < 4);
if (x->ne[dim] == 1) {
return x;
@ -702,27 +703,15 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_slice(struct ggml_context* ctx,
GGML_ASSERT(start >= 0 && start < x->ne[dim]);
GGML_ASSERT(end > start && end <= x->ne[dim]);
- int perm[4] = {0, 1, 2, 3};
- for (int i = dim; i < 3; ++i)
- perm[i] = perm[i + 1];
- perm[3] = dim;
- int inv_perm[4];
- for (int i = 0; i < 4; ++i)
- inv_perm[perm[i]] = i;
- if (dim != 3) {
- x = ggml_ext_torch_permute(ctx, x, perm[0], perm[1], perm[2], perm[3]);
- x = ggml_cont(ctx, x);
- }
- x = ggml_view_4d(
- ctx, x,
- x->ne[0], x->ne[1], x->ne[2], end - start,
- x->nb[1], x->nb[2], x->nb[3], x->nb[3] * start);
- if (dim != 3) {
- x = ggml_ext_torch_permute(ctx, x, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]);
+ int64_t slice_size = end - start;
+ int64_t slice_ne[4] = {x->ne[0], x->ne[1], x->ne[2], x->ne[3]};
+ slice_ne[dim] = slice_size;
+ x = ggml_view_4d(ctx, x,
+ slice_ne[0], slice_ne[1], slice_ne[2], slice_ne[3],
+ x->nb[1], x->nb[2], x->nb[3], start * x->nb[dim]);
+ if (cont) {
x = ggml_cont(ctx, x);
}
@ -960,6 +949,49 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_group_norm_32(struct ggml_context
return ggml_group_norm(ctx, a, 32, eps);
}
+ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_scale(struct ggml_context* ctx,
+ struct ggml_tensor* x,
+ float factor,
+ bool inplace = false) {
+ if (!ggml_is_contiguous(x)) {
+ x = ggml_cont(ctx, x);
+ }
+ if (inplace) {
+ x = ggml_scale_inplace(ctx, x, factor);
+ } else {
+ x = ggml_scale(ctx, x, factor);
+ }
+ return x;
+ }
+ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_gelu(struct ggml_context* ctx,
+ struct ggml_tensor* x,
+ bool inplace = false) {
+ if (!ggml_is_contiguous(x)) {
+ x = ggml_cont(ctx, x);
+ }
+ if (inplace) {
+ x = ggml_gelu_inplace(ctx, x);
+ } else {
+ x = ggml_gelu(ctx, x);
+ }
+ return x;
+ }
+ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_gelu_quick(struct ggml_context* ctx,
+ struct ggml_tensor* x,
+ bool inplace = false) {
+ if (!ggml_is_contiguous(x)) {
+ x = ggml_cont(ctx, x);
+ }
+ if (inplace) {
+ x = ggml_gelu_quick_inplace(ctx, x);
+ } else {
+ x = ggml_gelu_quick(ctx, x);
+ }
+ return x;
+ }
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx,
@ -967,7 +999,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx,
if (scale != 1.f) {
- x = ggml_scale(ctx, x, scale);
+ x = ggml_ext_scale(ctx, x, scale);
}
if (x->ne[2] * x->ne[3] > 1024) {
// workaround: avoid ggml cuda error
@ -986,7 +1018,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx,
if (scale != 1.f) {
- x = ggml_scale(ctx, x, 1.f / scale);
+ x = ggml_ext_scale(ctx, x, 1.f / scale);
}
if (b != nullptr) {
x = ggml_add_inplace(ctx, x, b);
@ -1055,7 +1087,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_conv_2d(struct ggml_context* ctx,
if (scale != 1.f) {
- x = ggml_scale(ctx, x, scale);
+ x = ggml_ext_scale(ctx, x, scale);
}
if (w->ne[2] != x->ne[2] && ggml_n_dims(w) == 2) {
w = ggml_reshape_4d(ctx, w, 1, 1, w->ne[0], w->ne[1]);
@ -1073,7 +1105,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_conv_2d(struct ggml_context* ctx,
if (scale != 1.f) {
- x = ggml_scale(ctx, x, 1.f / scale);
+ x = ggml_ext_scale(ctx, x, 1.f / scale);
}
if (b != nullptr) {
b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
@ -1171,7 +1203,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_full(struct ggml_context* ctx,
auto one = ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:one");
- auto t = ggml_scale(ctx, one, value); // [1,]
+ auto t = ggml_ext_scale(ctx, one, value); // [1,]
t = ggml_repeat_4d(ctx, t, ne0, ne1, ne2, ne3); // [ne0, ne1, ne2, ne3]
@ -1208,35 +1240,11 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_cast_f32(ggml_context* ctx, ggml_tensor*
return out;
}
- // q: [N * n_head, n_token, d_head]
- // k: [N * n_head, n_k, d_head]
- // v: [N * n_head, d_head, n_k]
- // return: [N * n_head, n_token, d_head]
- __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention(struct ggml_context* ctx,
- struct ggml_tensor* q,
- struct ggml_tensor* k,
- struct ggml_tensor* v,
- bool mask = false) {
- #if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUDA) && !defined(SD_USE_METAL) && !defined(SD_USE_VULKAN) && !defined(SD_USE_SYCL)
- struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, n_token, d_head]
- #else
- float d_head = (float)q->ne[0];
- struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, n_token, n_k]
- kq = ggml_scale_inplace(ctx, kq, 1.0f / sqrt(d_head));
- if (mask) {
- kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);
- }
- kq = ggml_soft_max_inplace(ctx, kq);
- struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, n_token, d_head]
- #endif
- return kqv;
- }
// q: [N, L_q, C(n_head*d_head)] or [N*n_head, L_q, d_head]
// k: [N, L_k, n_kv_head*d_head] or [N*n_kv_head, L_k, d_head]
// v: [N, L_k, n_kv_head*d_head] or [N, L_k, n_kv_head, d_head]
@ -1249,7 +1257,6 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
int64_t n_head,
struct ggml_tensor* mask = nullptr,
- bool diag_mask_inf = false,
bool skip_reshape = false,
bool flash_attn = false,
float kv_scale = 1.0f) { // avoid overflow
@ -1295,7 +1302,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
if (kv_scale != 1.0f) {
- k_in = ggml_scale(ctx, k_in, kv_scale);
+ k_in = ggml_ext_scale(ctx, k_in, kv_scale);
}
k_in = ggml_cast(ctx, k_in, GGML_TYPE_F16);
@ -1305,7 +1312,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
if (kv_scale != 1.0f) {
- v_in = ggml_scale(ctx, v_in, kv_scale);
+ v_in = ggml_ext_scale(ctx, v_in, kv_scale);
}
v_in = ggml_cast(ctx, v_in, GGML_TYPE_F16);
@ -1337,7 +1344,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32);
if (kv_scale != 1.0f) {
- out = ggml_scale(ctx, out, 1.0f / kv_scale);
+ out = ggml_ext_scale(ctx, out, 1.0f / kv_scale);
}
return out;
@ -1372,13 +1379,11 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
auto kq = ggml_mul_mat(ctx, k, q); // [N * n_head, L_q, L_k]
+ ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
kq = ggml_scale_inplace(ctx, kq, scale);
if (mask) {
kq = ggml_add_inplace(ctx, kq, mask);
}
- if (diag_mask_inf) {
- kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);
- }
kq = ggml_soft_max_inplace(ctx, kq);
kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, L_q, d_head]
@ -1546,7 +1551,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_timestep_embedding(
- timesteps = ggml_scale(ctx, timesteps, time_factor);
+ timesteps = ggml_ext_scale(ctx, timesteps, time_factor);
return ggml_timestep_embedding(ctx, timesteps, dim, max_period);
@ -2595,7 +2600,7 @@ public:
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
- bool mask = false) {
+ struct ggml_tensor* mask = nullptr) {
auto out_proj = std::dynamic_pointer_cast<Linear>(blocks[out_proj_name]);
@ -2618,7 +2623,7 @@ public:
v = v_proj->forward(ctx, x);
}
- x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, mask); // [N, n_token, embed_dim]
+ x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, mask, false); // [N, n_token, embed_dim]
x = out_proj->forward(ctx, x); // [N, n_token, embed_dim]
return x;
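Note (not part of the diff): the rewritten ggml_ext_slice keeps the tensor's original byte strides and simply narrows the sliced dimension, offsetting the data pointer by start * nb[dim]. The standalone sketch below redoes that arithmetic on a toy row-major 2-D array; the shapes and values are illustrative only.

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        // A 4x3 float "tensor": ne = {4, 3}, byte strides nb = {4, 16}.
        const int64_t ne0 = 4, ne1 = 3;
        const size_t nb0 = sizeof(float), nb1 = ne0 * sizeof(float);
        std::vector<float> data(ne0 * ne1);
        for (size_t i = 0; i < data.size(); ++i) data[i] = (float)i;

        // Slice dim 0 to [start, end): same strides, smaller extent,
        // base pointer advanced by start * nb0 bytes.
        const int64_t start = 1, end = 3;
        const uint8_t* base = reinterpret_cast<const uint8_t*>(data.data()) + start * nb0;
        for (int64_t i1 = 0; i1 < ne1; ++i1) {
            for (int64_t i0 = 0; i0 < end - start; ++i0) {
                const float* p = reinterpret_cast<const float*>(base + i0 * nb0 + i1 * nb1);
                std::printf("%5.1f ", *p);
            }
            std::printf("\n");
        }
        return 0;
    }

The new cont flag then lets callers decide whether the strided view should be compacted immediately or left lazy.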
llm.hpp (53 changed lines)
@ -638,7 +638,7 @@ namespace LLM {
x = mlp_0->forward(ctx, x);
- x = ggml_gelu(ctx->ggml_ctx, x);
+ x = ggml_ext_gelu(ctx->ggml_ctx, x);
x = mlp_2->forward(ctx, x);
@ -837,7 +837,8 @@ namespace LLM {
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
- struct ggml_tensor* input_pos) {
+ struct ggml_tensor* input_pos,
+ struct ggml_tensor* attention_mask = nullptr) {
// x: [N, n_token, hidden_size]
@ -880,7 +881,7 @@ namespace LLM {
k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); // [N*num_kv_heads, n_token, head_dim]
- x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, true, true, false); // [N, n_token, hidden_size]
+ x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, true, false); // [N, n_token, hidden_size]
x = out_proj->forward(ctx, x); // [N, n_token, hidden_size]
@ -898,7 +899,8 @@ namespace LLM {
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
- struct ggml_tensor* input_pos) {
+ struct ggml_tensor* input_pos,
+ struct ggml_tensor* attention_mask = nullptr) {
// x: [N, n_token, hidden_size]
@ -907,7 +909,7 @@ namespace LLM {
x = input_layernorm->forward(ctx, x);
- x = self_attn->forward(ctx, x, input_pos);
+ x = self_attn->forward(ctx, x, input_pos, attention_mask);
x = ggml_add_inplace(ctx->ggml_ctx, x, residual);
@ -936,6 +938,7 @@ namespace LLM {
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* input_ids,
struct ggml_tensor* input_pos,
+ struct ggml_tensor* attention_mask,
std::vector<std::pair<int, ggml_tensor*>> image_embeds,
std::set<int> out_layers) {
@ -990,7 +993,7 @@ namespace LLM {
auto block = std::dynamic_pointer_cast<TransformerBlock>(blocks["layers." + std::to_string(i)]);
- x = block->forward(ctx, x, input_pos);
+ x = block->forward(ctx, x, input_pos, attention_mask);
if (out_layers.find(i + 1) != out_layers.end()) {
intermediate_outputs.push_back(x);
}
@ -1036,12 +1039,13 @@ namespace LLM {
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* input_ids,
struct ggml_tensor* input_pos,
+ struct ggml_tensor* attention_mask,
std::vector<std::pair<int, ggml_tensor*>> image_embeds,
std::set<int> out_layers) {
auto model = std::dynamic_pointer_cast<TextModel>(blocks["model"]);
- auto x = model->forward(ctx, input_ids, input_pos, image_embeds, out_layers);
+ auto x = model->forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers);
return x;
@ -1063,6 +1067,7 @@ namespace LLM {
std::vector<int> input_pos_vec;
+ std::vector<float> attention_mask_vec;
std::vector<float> window_mask_vec;
std::vector<int> window_index_vec;
@ -1157,9 +1162,10 @@ namespace LLM {
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* input_ids,
struct ggml_tensor* input_pos,
+ struct ggml_tensor* attention_mask,
std::vector<std::pair<int, ggml_tensor*>> image_embeds,
std::set<int> out_layers) {
- auto hidden_states = model.forward(ctx, input_ids, input_pos, image_embeds, out_layers); // [N, n_token, hidden_size]
+ auto hidden_states = model.forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers); // [N, n_token, hidden_size]
return hidden_states;
@ -1174,6 +1180,7 @@ namespace LLM {
struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids,
+ struct ggml_tensor* attention_mask,
std::vector<std::pair<int, ggml_tensor*>> image_embeds,
std::set<int> out_layers) {
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
@ -1205,9 +1212,26 @@ namespace LLM {
set_backend_tensor_data(input_pos, input_pos_vec.data());
+ if (attention_mask != nullptr) {
+ attention_mask = to_backend(attention_mask);
+ } else {
+ attention_mask_vec.resize(n_tokens * n_tokens);
+ for (int i0 = 0; i0 < n_tokens; i0++) {
+ for (int i1 = 0; i1 < n_tokens; i1++) {
+ float value = 0.f;
+ if (i0 > i1) {
+ value = -INFINITY;
+ }
+ attention_mask_vec[i1 * n_tokens + i0] = value;
+ }
+ }
+ attention_mask = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, n_tokens, n_tokens);
+ set_backend_tensor_data(attention_mask, attention_mask_vec.data());
+ }
auto runner_ctx = get_context();
- struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, input_pos, image_embeds, out_layers);
+ struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers);
ggml_build_forward_expand(gf, hidden_states);
@ -1216,12 +1240,13 @@ namespace LLM {
bool compute(const int n_threads,
struct ggml_tensor* input_ids,
+ struct ggml_tensor* attention_mask,
std::vector<std::pair<int, ggml_tensor*>> image_embeds,
std::set<int> out_layers,
ggml_tensor** output,
ggml_context* output_ctx = nullptr) {
auto get_graph = [&]() -> struct ggml_cgraph* {
- return build_graph(input_ids, image_embeds, out_layers);
+ return build_graph(input_ids, attention_mask, image_embeds, out_layers);
};
return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
@ -1525,7 +1550,7 @@ namespace LLM {
- model.compute(8, input_ids, image_embeds, {}, &out, work_ctx);
+ model.compute(8, input_ids, nullptr, image_embeds, {}, &out, work_ctx);
@ -1565,7 +1590,7 @@ namespace LLM {
- model.compute(8, input_ids, {}, {10, 20, 30}, &out, work_ctx);
+ model.compute(8, input_ids, nullptr, {}, {10, 20, 30}, &out, work_ctx);
@ -1588,7 +1613,7 @@ namespace LLM {
- model.compute(8, input_ids, {}, {35}, &out, work_ctx);
+ model.compute(8, input_ids, nullptr, {}, {35}, &out, work_ctx);
@ -1611,7 +1636,7 @@ namespace LLM {
- model.compute(8, input_ids, {}, {}, &out, work_ctx);
+ model.compute(8, input_ids, nullptr, {}, {}, &out, work_ctx);
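Note (not part of the diff): when no attention_mask is passed, the new build_graph code fills a standard causal mask — entry [i1 * n_tokens + i0] is -INF whenever the key position i0 lies after the query position i1, and 0 otherwise — which is then added to the attention scores before softmax. The standalone sketch below reproduces that fill rule and prints it for a toy size; n_tokens and the printing are illustrative only.

    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        const int n_tokens = 4;
        std::vector<float> mask(n_tokens * n_tokens, 0.0f);
        for (int i0 = 0; i0 < n_tokens; i0++) {          // key position
            for (int i1 = 0; i1 < n_tokens; i1++) {      // query position
                if (i0 > i1) {
                    mask[i1 * n_tokens + i0] = -INFINITY; // block attention to future keys
                }
            }
        }
        for (int i1 = 0; i1 < n_tokens; i1++) {
            for (int i0 = 0; i0 < n_tokens; i0++) {
                std::printf("%6.1f ", mask[i1 * n_tokens + i0]);
            }
            std::printf("\n");
        }
        return 0;
    }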
lora.hpp (10 changed lines)
@ -195,7 +195,7 @@ struct LoraModel : public GGMLRunner {
auto curr_updown = ggml_ext_merge_lora(ctx, lora_down, lora_up, lora_mid);
- curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value);
+ curr_updown = ggml_ext_scale(ctx, curr_updown, scale_value, true);
@ -235,7 +235,7 @@ struct LoraModel : public GGMLRunner {
scale_value *= multiplier;
- curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value);
+ curr_updown = ggml_ext_scale(ctx, curr_updown, scale_value, true);
@ -340,7 +340,7 @@ struct LoraModel : public GGMLRunner {
auto curr_updown = ggml_mul_inplace(ctx, updown_1, updown_2);
- curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value);
+ curr_updown = ggml_ext_scale(ctx, curr_updown, scale_value, true);
@ -456,7 +456,7 @@ struct LoraModel : public GGMLRunner {
auto curr_updown = ggml_ext_kronecker(ctx, lokr_w1, lokr_w2);
- curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value);
+ curr_updown = ggml_ext_scale(ctx, curr_updown, scale_value, true);
@ -634,7 +634,7 @@ struct LoraModel : public GGMLRunner {
- auto curr_out_diff = ggml_scale_inplace(ctx, lx, scale_value);
+ auto curr_out_diff = ggml_ext_scale(ctx, lx, scale_value, true);
mmdit.hpp (87 changed lines)
@ -33,7 +33,7 @@ public:
x = fc1->forward(ctx, x);
- x = ggml_gelu_inplace(ctx->ggml_ctx, x);
+ x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
x = fc2->forward(ctx, x);
@ -211,8 +211,8 @@ public:
auto qkv = pre_attention(ctx, x);
- x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
+ x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = post_attention(ctx, x); // [N, n_token, dim]
@ -284,23 +284,19 @@ public:
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
- int64_t n_mods = 9;
+ int n_mods = 9;
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size]
- m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size]
- m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size]
- int64_t offset = m->nb[1] * m->ne[1];
- auto shift_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
- auto scale_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
- auto gate_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size]
- auto shift_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size]
- auto scale_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size]
- auto gate_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size]
- auto shift_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 6); // [N, hidden_size]
- auto scale_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 7); // [N, hidden_size]
- auto gate_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 8); // [N, hidden_size]
+ auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, n_mods, 0);
+ auto shift_msa = m_vec[0]; // [N, hidden_size]
+ auto scale_msa = m_vec[1]; // [N, hidden_size]
+ auto gate_msa = m_vec[2]; // [N, hidden_size]
+ auto shift_mlp = m_vec[3]; // [N, hidden_size]
+ auto scale_mlp = m_vec[4]; // [N, hidden_size]
+ auto gate_mlp = m_vec[5]; // [N, hidden_size]
+ auto shift_msa2 = m_vec[6]; // [N, hidden_size]
+ auto scale_msa2 = m_vec[7]; // [N, hidden_size]
+ auto gate_msa2 = m_vec[8]; // [N, hidden_size]
auto x_norm = norm1->forward(ctx, x);
@ -322,22 +318,20 @@ public:
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
- int64_t n_mods = 6;
+ int n_mods = 6;
if (pre_only) {
n_mods = 2;
}
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size]
- m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size]
- m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size]
- int64_t offset = m->nb[1] * m->ne[1];
- auto shift_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
- auto scale_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
+ auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, n_mods, 0);
+ auto shift_msa = m_vec[0]; // [N, hidden_size]
+ auto scale_msa = m_vec[1]; // [N, hidden_size]
if (!pre_only) {
- auto gate_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size]
- auto shift_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size]
- auto scale_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size]
- auto gate_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size]
+ auto gate_msa = m_vec[2]; // [N, hidden_size]
+ auto shift_mlp = m_vec[3]; // [N, hidden_size]
+ auto scale_mlp = m_vec[4]; // [N, hidden_size]
+ auto gate_mlp = m_vec[5]; // [N, hidden_size]
auto attn_in = modulate(ctx->ggml_ctx, norm1->forward(ctx, x), shift_msa, scale_msa);
@ -439,8 +433,8 @@ public:
auto intermediates = std::get<2>(qkv_intermediates);
- auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
- auto attn2_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv2[0], qkv2[1], qkv2[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
+ auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
+ auto attn2_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv2[0], qkv2[1], qkv2[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = post_attention_x(ctx,
attn_out,
attn2_out,
@ -456,7 +450,7 @@ public:
auto intermediates = qkv_intermediates.second;
- auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
+ auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = post_attention(ctx,
attn_out,
intermediates[0],
@ -500,26 +494,24 @@ block_mixing(GGMLRunnerContext* ctx,
qkv.push_back(ggml_concat(ctx->ggml_ctx, context_qkv[i], x_qkv[i], 1));
}
- auto attn = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_context + n_token, hidden_size]
- attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_context + n_token, N, hidden_size]
+ auto attn = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_context + n_token, hidden_size]
auto context_attn = ggml_view_3d(ctx->ggml_ctx,
attn,
attn->ne[0],
- attn->ne[1],
context->ne[1],
+ attn->ne[2],
attn->nb[1],
attn->nb[2],
- 0); // [n_context, N, hidden_size]
- context_attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, context_attn, 0, 2, 1, 3)); // [N, n_context, hidden_size]
+ 0); // [N, n_context, hidden_size]
auto x_attn = ggml_view_3d(ctx->ggml_ctx,
attn,
attn->ne[0],
- attn->ne[1],
x->ne[1],
+ attn->ne[2],
attn->nb[1],
attn->nb[2],
- attn->nb[2] * context->ne[1]); // [n_token, N, hidden_size]
- x_attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x_attn, 0, 2, 1, 3)); // [N, n_token, hidden_size]
+ context->ne[1] * attn->nb[1]); // [N, n_token, hidden_size]
if (!context_block->pre_only) {
context = context_block->post_attention(ctx,
@ -534,7 +526,7 @@ block_mixing(GGMLRunnerContext* ctx,
if (x_block->self_attn) {
- auto attn2 = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, hidden_size]
+ auto attn2 = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, hidden_size]
x = x_block->post_attention_x(ctx,
x_attn,
@ -604,13 +596,10 @@ public:
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 2 * hidden_size]
- m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size]
- m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size]
- int64_t offset = m->nb[1] * m->ne[1];
- auto shift = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
- auto scale = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
+ auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, 2, 0);
+ auto shift = m_vec[0]; // [N, hidden_size]
+ auto scale = m_vec[1]; // [N, hidden_size]
x = modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale);
x = linear->forward(ctx, x);
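Note (not part of the diff): the reshape/permute/view sequence replaced above amounts to splitting one [n_mods * hidden_size] modulation row into n_mods equal chunks, which is what ggml_ext_chunk(m, n_mods, 0) now does directly. The standalone sketch below shows the same split in plain C++ and applies a toy modulate step; the formula x * (1 + scale) + shift is the usual DiT-style modulation and is assumed here rather than taken from the diff, and all sizes are illustrative.

    #include <cstdio>
    #include <vector>

    int main() {
        const int n_mods = 6, hidden = 4;
        std::vector<float> m(n_mods * hidden);
        for (int i = 0; i < (int)m.size(); ++i) m[i] = 0.01f * i;

        // chunk(m, n_mods, dim 0): chunk j starts at offset j * hidden.
        auto chunk = [&](int j) { return m.data() + j * hidden; };
        const float* shift_msa = chunk(0);
        const float* scale_msa = chunk(1);

        // Toy modulate: x * (1 + scale) + shift, element-wise.
        std::vector<float> x(hidden, 1.0f);
        for (int i = 0; i < hidden; ++i) {
            x[i] = x[i] * (1.0f + scale_msa[i]) + shift_msa[i];
        }
        for (int i = 0; i < hidden; ++i) std::printf("%.3f ", x[i]);
        std::printf("\n");
        return 0;
    }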
36 model.cpp

@@ -376,7 +376,11 @@ bool ModelLoader::init_from_file(const std::string& file_path, const std::string& prefix)
         LOG_INFO("load %s using checkpoint format", file_path.c_str());
         return init_from_ckpt_file(file_path, prefix);
     } else {
-        LOG_WARN("unknown format %s", file_path.c_str());
+        if (file_exists(file_path)) {
+            LOG_WARN("unknown format %s", file_path.c_str());
+        } else {
+            LOG_WARN("file %s not found", file_path.c_str());
+        }
         return false;
     }
 }
@@ -1034,10 +1038,14 @@ SDVersion ModelLoader::get_sd_version() {

     bool is_xl   = false;
     bool is_flux = false;
+    bool is_flux2            = false;
+    bool has_single_block_47 = false;
     bool is_wan  = false;
     int64_t patch_embedding_channels = 0;
     bool has_img_emb        = false;
     bool has_middle_block_1 = false;
+    bool has_output_block_311 = false;
+    bool has_output_block_71  = false;

     for (auto& [name, tensor_storage] : tensor_storage_map) {
         if (!(is_xl)) {
@@ -1054,7 +1062,10 @@ SDVersion ModelLoader::get_sd_version() {
             return VERSION_QWEN_IMAGE;
         }
         if (tensor_storage.name.find("model.diffusion_model.double_stream_modulation_img.lin.weight") != std::string::npos) {
-            return VERSION_FLUX2;
+            is_flux2 = true;
+        }
+        if (tensor_storage.name.find("single_blocks.47.linear1.weight") != std::string::npos) {
+            has_single_block_47 = true;
         }
         if (tensor_storage.name.find("model.diffusion_model.double_blocks.0.img_mlp.gate_proj.weight") != std::string::npos) {
             return VERSION_OVIS_IMAGE;
@@ -1094,6 +1105,12 @@ SDVersion ModelLoader::get_sd_version() {
             tensor_storage.name.find("unet.mid_block.resnets.1.") != std::string::npos) {
             has_middle_block_1 = true;
         }
+        if (tensor_storage.name.find("model.diffusion_model.output_blocks.3.1.transformer_blocks.1") != std::string::npos) {
+            has_output_block_311 = true;
+        }
+        if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos) {
+            has_output_block_71 = true;
+        }
         if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
             tensor_storage.name == "cond_stage_model.model.token_embedding.weight" ||
             tensor_storage.name == "text_model.embeddings.token_embedding.weight" ||
@@ -1129,12 +1146,15 @@ SDVersion ModelLoader::get_sd_version() {
             return VERSION_SDXL_PIX2PIX;
         }
         if (!has_middle_block_1) {
+            if (!has_output_block_311) {
+                return VERSION_SDXL_VEGA;
+            }
             return VERSION_SDXL_SSD1B;
         }
         return VERSION_SDXL;
     }

-    if (is_flux) {
+    if (is_flux && !is_flux2) {
         if (input_block_weight.ne[0] == 384) {
             return VERSION_FLUX_FILL;
         }
@@ -1147,6 +1167,13 @@ SDVersion ModelLoader::get_sd_version() {
         return VERSION_FLUX;
     }

+    if (is_flux2) {
+        if (has_single_block_47) {
+            return VERSION_FLUX2;
+        }
+        return VERSION_FLUX2_KLEIN;
+    }
+
     if (token_embedding_weight.ne[0] == 768) {
         if (is_inpaint) {
             return VERSION_SD1_INPAINT;
@@ -1155,6 +1182,9 @@ SDVersion ModelLoader::get_sd_version() {
             return VERSION_SD1_PIX2PIX;
         }
         if (!has_middle_block_1) {
+            if (!has_output_block_71) {
+                return VERSION_SDXS;
+            }
             return VERSION_SD1_TINY_UNET;
         }
         return VERSION_SD1;
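For readers following the detection logic above: the loader first scans tensor names to set feature flags, then resolves a version from those flags — FLUX.2-klein is recognized as "FLUX.2-family weights without the 48th single block", and SDXS/Vega are recognized by which output blocks survive the distillation. The sketch below restates that decision order in plain, standalone C++; the `detect_version` helper and the simplified name patterns are illustrative only, not the actual `ModelLoader` API.

```cpp
// Minimal sketch of the flag-then-resolve pattern used in get_sd_version().
// Assumption: tensor names are available as plain strings; the enum is a stand-in.
#include <iostream>
#include <set>
#include <string>

enum class Version { FLUX2, FLUX2_KLEIN, SDXL_SSD1B, SDXL_VEGA, OTHER };

Version detect_version(const std::set<std::string>& names) {
    bool is_flux2             = false;
    bool has_single_block_47  = false;
    bool has_middle_block_1   = false;
    bool has_output_block_311 = false;

    for (const auto& n : names) {
        if (n.find("double_stream_modulation_img.lin.weight") != std::string::npos)
            is_flux2 = true;
        if (n.find("single_blocks.47.linear1.weight") != std::string::npos)
            has_single_block_47 = true;
        if (n.find("middle_block.1.") != std::string::npos)
            has_middle_block_1 = true;
        if (n.find("output_blocks.3.1.transformer_blocks.1") != std::string::npos)
            has_output_block_311 = true;
    }

    if (is_flux2)  // the full FLUX.2 model has 48 single blocks, klein has fewer
        return has_single_block_47 ? Version::FLUX2 : Version::FLUX2_KLEIN;
    if (!has_middle_block_1)  // SDXL distillations drop the middle attention block
        return has_output_block_311 ? Version::SDXL_SSD1B : Version::SDXL_VEGA;
    return Version::OTHER;
}

int main() {
    std::set<std::string> klein = {"model.diffusion_model.double_stream_modulation_img.lin.weight"};
    std::cout << (detect_version(klein) == Version::FLUX2_KLEIN) << "\n";  // prints 1
}
```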
9 model.h

@@ -28,9 +28,11 @@ enum SDVersion {
     VERSION_SD2,
     VERSION_SD2_INPAINT,
     VERSION_SD2_TINY_UNET,
+    VERSION_SDXS,
     VERSION_SDXL,
     VERSION_SDXL_INPAINT,
     VERSION_SDXL_PIX2PIX,
+    VERSION_SDXL_VEGA,
     VERSION_SDXL_SSD1B,
     VERSION_SVD,
     VERSION_SD3,
@@ -44,13 +46,14 @@ enum SDVersion {
     VERSION_WAN2_2_TI2V,
     VERSION_QWEN_IMAGE,
     VERSION_FLUX2,
+    VERSION_FLUX2_KLEIN,
     VERSION_Z_IMAGE,
     VERSION_OVIS_IMAGE,
     VERSION_COUNT,
 };

 static inline bool sd_version_is_sd1(SDVersion version) {
-    if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX || version == VERSION_SD1_TINY_UNET) {
+    if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX || version == VERSION_SD1_TINY_UNET || version == VERSION_SDXS) {
         return true;
     }
     return false;
@@ -64,7 +67,7 @@ static inline bool sd_version_is_sd2(SDVersion version) {
 }

 static inline bool sd_version_is_sdxl(SDVersion version) {
-    if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX || version == VERSION_SDXL_SSD1B) {
+    if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX || version == VERSION_SDXL_SSD1B || version == VERSION_SDXL_VEGA) {
         return true;
     }
     return false;
@@ -99,7 +102,7 @@ static inline bool sd_version_is_flux(SDVersion version) {
 }

 static inline bool sd_version_is_flux2(SDVersion version) {
-    if (version == VERSION_FLUX2) {
+    if (version == VERSION_FLUX2 || version == VERSION_FLUX2_KLEIN) {
         return true;
     }
     return false;
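The family predicates are what let the rest of the codebase pick up the new variants without further edits: any call site that already branches on `sd_version_is_sdxl()` or `sd_version_is_flux2()` now also covers Vega and klein. A minimal standalone illustration (stripped-down enum, not the real header):

```cpp
#include <cassert>

enum SDVersion { VERSION_FLUX2, VERSION_FLUX2_KLEIN, VERSION_OTHER };

static inline bool sd_version_is_flux2(SDVersion v) {
    return v == VERSION_FLUX2 || v == VERSION_FLUX2_KLEIN;
}

int main() {
    // Call sites written against the predicate treat both variants identically,
    // e.g. when choosing the FLUX.2 latent layout or TAE variant.
    assert(sd_version_is_flux2(VERSION_FLUX2_KLEIN));
    assert(!sd_version_is_flux2(VERSION_OTHER));
    return 0;
}
```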
@@ -842,6 +842,7 @@ std::string convert_sep_to_dot(std::string name) {
     "conv_in",
     "conv_out",
     "lora_down",
+    "lora_mid",
     "lora_up",
     "diff_b",
     "hada_w1_a",
@@ -997,10 +998,13 @@ std::string convert_tensor_name(std::string name, SDVersion version) {
     if (is_lora) {
         std::map<std::string, std::string> lora_suffix_map = {
             {".lora_down.weight", ".weight.lora_down"},
+            {".lora_mid.weight", ".weight.lora_mid"},
             {".lora_up.weight", ".weight.lora_up"},
             {".lora.down.weight", ".weight.lora_down"},
+            {".lora.mid.weight", ".weight.lora_mid"},
             {".lora.up.weight", ".weight.lora_up"},
             {"_lora.down.weight", ".weight.lora_down"},
+            {"_lora.mid.weight", ".weight.lora_mid"},
             {"_lora.up.weight", ".weight.lora_up"},
             {".lora_A.weight", ".weight.lora_down"},
             {".lora_B.weight", ".weight.lora_up"},
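The new `lora_mid` entries teach the tensor-name converter about LoCon-style mid weights (the extra factor that sits between `lora_down` and `lora_up` in Tucker-decomposed convolution LoRAs). A small standalone sketch of the suffix-rewriting step, assuming a simple suffix match rather than the converter's full pipeline; `apply_suffix_map` is an illustrative helper, not a project function:

```cpp
#include <iostream>
#include <map>
#include <string>

// Rewrite a LoRA tensor-name suffix to the internal ".weight.lora_*" form.
// The map entries mirror the diff above.
std::string apply_suffix_map(std::string name) {
    static const std::map<std::string, std::string> suffix_map = {
        {".lora_down.weight", ".weight.lora_down"},
        {".lora_mid.weight", ".weight.lora_mid"},
        {".lora_up.weight", ".weight.lora_up"},
    };
    for (const auto& [from, to] : suffix_map) {
        if (name.size() >= from.size() &&
            name.compare(name.size() - from.size(), from.size(), from) == 0) {
            name.replace(name.size() - from.size(), from.size(), to);
            break;
        }
    }
    return name;
}

int main() {
    std::cout << apply_suffix_map("lora_unet_conv_in.lora_mid.weight") << "\n";
    // -> lora_unet_conv_in.weight.lora_mid
}
```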
6 pmid.hpp

@@ -33,7 +33,7 @@ public:
         x = layer_norm->forward(ctx, x);
         // x = ggml_add(ctx, ggml_mul_mat(ctx, fc1_w, x), fc1_b);
         x = fc1->forward(ctx, x);
-        x = ggml_gelu_inplace(ctx->ggml_ctx, x);
+        x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
         x = fc2->forward(ctx, x);
         // x = ggml_add(ctx, ggml_mul_mat(ctx, fc2_w, x), fc2_b);
         if (use_residue)
@@ -129,8 +129,8 @@ public:
         k     = reshape_tensor(ctx->ggml_ctx, k, heads);
         v     = reshape_tensor(ctx->ggml_ctx, v, heads);
         scale = 1.f / sqrt(sqrt((float)dim_head));
-        k     = ggml_scale_inplace(ctx->ggml_ctx, k, scale);
-        q     = ggml_scale_inplace(ctx->ggml_ctx, q, scale);
+        k     = ggml_ext_scale(ctx->ggml_ctx, k, scale, true);
+        q     = ggml_ext_scale(ctx->ggml_ctx, q, scale, true);
         // auto weight = ggml_mul_mat(ctx, q, k);
         auto weight = ggml_mul_mat(ctx->ggml_ctx, k, q);  // NOTE order of mul is opposite to pytorch

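The scale applied to both q and k above is 1/sqrt(sqrt(dim_head)), so the subsequent matmul already carries the usual attention temperature without an extra pass over the score matrix:

```latex
% Pre-scaling q and k each by d^{-1/4} yields the standard 1/\sqrt{d} factor:
\left(\frac{q}{d^{1/4}}\right)\left(\frac{k}{d^{1/4}}\right)^{\!\top}
  \;=\; \frac{q\,k^{\top}}{\sqrt{d}},
\qquad d = \texttt{dim\_head}.
```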
@@ -162,26 +162,25 @@ namespace Qwen {
         auto k = ggml_concat(ctx->ggml_ctx, txt_k, img_k, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
         auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]

         auto attn = Rope::attention(ctx, q, k, v, pe, mask, (1.0f / 128.f));  // [N, n_txt_token + n_img_token, n_head*d_head]
-        attn      = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3));  // [n_txt_token + n_img_token, N, hidden_size]
         auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx,
                                          attn,
                                          attn->ne[0],
-                                         attn->ne[1],
                                          txt->ne[1],
+                                         attn->ne[2],
                                          attn->nb[1],
                                          attn->nb[2],
-                                         0);  // [n_txt_token, N, hidden_size]
-        txt_attn_out      = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_attn_out, 0, 2, 1, 3));  // [N, n_txt_token, hidden_size]
+                                         0);  // [N, n_txt_token, n_head*d_head]
         auto img_attn_out = ggml_view_3d(ctx->ggml_ctx,
                                          attn,
                                          attn->ne[0],
-                                         attn->ne[1],
                                          img->ne[1],
+                                         attn->ne[2],
                                          attn->nb[1],
                                          attn->nb[2],
-                                         attn->nb[2] * txt->ne[1]);  // [n_img_token, N, hidden_size]
-        img_attn_out      = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img_attn_out, 0, 2, 1, 3));  // [N, n_img_token, hidden_size]
+                                         txt->ne[1] * attn->nb[1]);  // [N, n_img_token, n_head*d_head]
+        img_attn_out = ggml_cont(ctx->ggml_ctx, img_attn_out);
+        txt_attn_out = ggml_cont(ctx->ggml_ctx, txt_attn_out);

         img_attn_out = to_out_0->forward(ctx, img_attn_out);
         txt_attn_out = to_add_out->forward(ctx, txt_attn_out);
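The rewrite above keeps the fused attention output in [N, token, channels] order and slices it with views instead of permuting: the text rows start at byte offset 0, and the image rows start right after the `txt->ne[1]` text rows, i.e. at `txt->ne[1] * attn->nb[1]` bytes. A standalone sketch of the same offset arithmetic on a plain row-major buffer; the `Tensor3D` struct below is illustrative, not a ggml type:

```cpp
#include <cstddef>
#include <iostream>

// Row-major [batch, tokens, channels] buffer with byte strides nb1 (per token)
// and nb2 (per batch item), mirroring ggml's ne/nb convention.
struct Tensor3D {
    size_t channels, tokens, batch;
    size_t nb1() const { return channels * sizeof(float); }  // stride between tokens
    size_t nb2() const { return tokens * nb1(); }            // stride between batch items
};

int main() {
    Tensor3D attn{3072, /*tokens=*/512 + 4096, /*batch=*/1};  // example: txt tokens + img tokens
    size_t n_txt = 512;

    size_t txt_offset = 0;                   // text block starts at the beginning
    size_t img_offset = n_txt * attn.nb1();  // image block starts after the text rows

    std::cout << "txt view offset: " << txt_offset << " bytes\n";
    std::cout << "img view offset: " << img_offset << " bytes\n";
}
```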
2 rope.hpp

@@ -642,7 +642,7 @@ namespace Rope {
         q = apply_rope(ctx->ggml_ctx, q, pe, rope_interleaved);  // [N*n_head, L, d_head]
         k = apply_rope(ctx->ggml_ctx, k, pe, rope_interleaved);  // [N*n_head, L, d_head]

-        auto x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, v->ne[1], mask, false, true, ctx->flash_attn_enabled, kv_scale);  // [N, L, n_head*d_head]
+        auto x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, v->ne[1], mask, true, ctx->flash_attn_enabled, kv_scale);  // [N, L, n_head*d_head]
         return x;
     }
 };  // namespace Rope
@@ -31,9 +31,11 @@ const char* model_version_to_str[] = {
     "SD 2.x",
     "SD 2.x Inpaint",
     "SD 2.x Tiny UNet",
+    "SDXS",
     "SDXL",
     "SDXL Inpaint",
     "SDXL Instruct-Pix2Pix",
+    "SDXL (Vega)",
     "SDXL (SSD1B)",
     "SVD",
     "SD3.x",
@@ -47,6 +49,7 @@ const char* model_version_to_str[] = {
     "Wan 2.2 TI2V",
     "Qwen Image",
     "Flux.2",
+    "Flux.2 klein",
     "Z-Image",
     "Ovis Image",
 };
@@ -64,6 +67,8 @@ const char* sampling_methods_str[] = {
     "LCM",
     "DDIM \"trailing\"",
     "TCD",
+    "Res Multistep",
+    "Res 2s",
 };

 /*================================================== Helper Functions ================================================*/
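These string tables are indexed by the corresponding enums, so every new `VERSION_*`, sampler, or scheduler entry has to be mirrored here at the same position. A compile-time guard of the following shape catches a missing entry early; the names below are a stripped-down stand-in, not the project's actual tables:

```cpp
#include <cstddef>

enum SDVersion { VERSION_SD1, VERSION_FLUX2, VERSION_FLUX2_KLEIN, VERSION_COUNT };

static const char* model_version_to_str[] = {
    "SD 1.x",
    "Flux.2",
    "Flux.2 klein",
};

// Fails to compile if the table and the enum ever drift apart.
static_assert(sizeof(model_version_to_str) / sizeof(model_version_to_str[0]) == VERSION_COUNT,
              "model_version_to_str must have one entry per SDVersion");

int main() { return 0; }
```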
@@ -407,6 +412,11 @@ public:
             vae_decode_only = false;
         }

+        bool tae_preview_only = sd_ctx_params->tae_preview_only;
+        if (version == VERSION_SDXS) {
+            tae_preview_only = false;
+        }
+
         if (sd_ctx_params->circular_x || sd_ctx_params->circular_y) {
             LOG_INFO("Using circular padding for convolutions");
         }
@@ -435,7 +445,7 @@ public:
             }
         }
         if (is_chroma) {
-            if (sd_ctx_params->diffusion_flash_attn && sd_ctx_params->chroma_use_dit_mask) {
+            if ((sd_ctx_params->flash_attn || sd_ctx_params->diffusion_flash_attn) && sd_ctx_params->chroma_use_dit_mask) {
                 LOG_WARN(
                     "!!!It looks like you are using Chroma with flash attention. "
                     "This is currently unsupported. "
@@ -561,14 +571,6 @@ public:
             }
         }

-        if (sd_ctx_params->diffusion_flash_attn) {
-            LOG_INFO("Using flash attention in the diffusion model");
-            diffusion_model->set_flash_attn_enabled(true);
-            if (high_noise_diffusion_model) {
-                high_noise_diffusion_model->set_flash_attn_enabled(true);
-            }
-        }
-
         cond_stage_model->alloc_params_buffer();
         cond_stage_model->get_param_tensors(tensors);

@@ -591,7 +593,7 @@ public:
             vae_backend = backend;
         }

-        if (!use_tiny_autoencoder || sd_ctx_params->tae_preview_only) {
+        if (!(use_tiny_autoencoder || version == VERSION_SDXS) || tae_preview_only) {
             if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
                 first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
                                                                         offload_params_to_cpu,
@@ -616,7 +618,7 @@ public:
                 LOG_INFO("Using Conv2d direct in the vae model");
                 first_stage_model->set_conv2d_direct_enabled(true);
             }
-            if (version == VERSION_SDXL &&
+            if (sd_version_is_sdxl(version) &&
                 (strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale)) {
                 float vae_conv_2d_scale = 1.f / 32.f;
                 LOG_WARN(
@@ -629,8 +631,7 @@ public:
             first_stage_model->get_param_tensors(tensors, "first_stage_model");
         }
     }
-    if (use_tiny_autoencoder) {
+    if (use_tiny_autoencoder || version == VERSION_SDXS) {
         if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
             tae_first_stage = std::make_shared<TinyVideoAutoEncoder>(vae_backend,
                                                                      offload_params_to_cpu,
@@ -645,6 +646,10 @@ public:
                                                                  "decoder.layers",
                                                                  vae_decode_only,
                                                                  version);
+            if (version == VERSION_SDXS) {
+                tae_first_stage->alloc_params_buffer();
+                tae_first_stage->get_param_tensors(tensors, "first_stage_model");
+            }
         }
         if (sd_ctx_params->vae_conv_direct) {
             LOG_INFO("Using Conv2d direct in the tae model");
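The gating above decides which decoder gets built: SDXS ships a TAE-style decoder inside the main checkpoint, so it follows the tiny-autoencoder path (with its parameters registered under `first_stage_model`) even when no external TAESD file is supplied, while `tae_preview_only` still forces the full VAE for the final decode. A compact restatement of that decision, using illustrative names rather than the real context class:

```cpp
#include <iostream>

struct VaeChoice {
    bool build_full_vae;  // regular first_stage_model
    bool build_tae;       // tiny autoencoder (TAESD file or SDXS built-in decoder)
};

// Simplified restatement of the conditions in the diff above.
VaeChoice choose_vae(bool use_tiny_autoencoder, bool is_sdxs, bool tae_preview_only) {
    VaeChoice c;
    c.build_full_vae = !(use_tiny_autoencoder || is_sdxs) || tae_preview_only;
    c.build_tae      = use_tiny_autoencoder || is_sdxs;
    return c;
}

int main() {
    auto sdxs = choose_vae(false, true, false);
    std::cout << "SDXS: full=" << sdxs.build_full_vae << " tae=" << sdxs.build_tae << "\n";      // full=0 tae=1
    auto preview = choose_vae(true, false, true);
    std::cout << "preview: full=" << preview.build_full_vae << " tae=" << preview.build_tae << "\n";  // full=1 tae=1
}
```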
@@ -712,6 +717,28 @@ public:
             pmid_model->get_param_tensors(tensors, "pmid");
         }

+        if (sd_ctx_params->flash_attn) {
+            LOG_INFO("Using flash attention");
+            cond_stage_model->set_flash_attention_enabled(true);
+            if (clip_vision) {
+                clip_vision->set_flash_attention_enabled(true);
+            }
+            if (first_stage_model) {
+                first_stage_model->set_flash_attention_enabled(true);
+            }
+            if (tae_first_stage) {
+                tae_first_stage->set_flash_attention_enabled(true);
+            }
+        }
+
+        if (sd_ctx_params->flash_attn || sd_ctx_params->diffusion_flash_attn) {
+            LOG_INFO("Using flash attention in the diffusion model");
+            diffusion_model->set_flash_attention_enabled(true);
+            if (high_noise_diffusion_model) {
+                high_noise_diffusion_model->set_flash_attention_enabled(true);
+            }
+        }
+
         diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
         if (high_noise_diffusion_model) {
             high_noise_diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
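Taken together with the block removed earlier, the new `flash_attn` flag turns flash attention on for every runner (text encoder, CLIP vision, VAE/TAE and the diffusion model), while the existing `diffusion_flash_attn` keeps its narrower meaning of "diffusion model only". A small sketch of the resolved per-component settings, with hypothetical component names:

```cpp
#include <iostream>

struct FlashAttnPlan {
    bool text_encoder, clip_vision, vae, tae, diffusion;
};

// flash_attn switches everything on; diffusion_flash_attn only the diffusion model.
FlashAttnPlan resolve(bool flash_attn, bool diffusion_flash_attn) {
    FlashAttnPlan p{};
    p.text_encoder = p.clip_vision = p.vae = p.tae = flash_attn;
    p.diffusion = flash_attn || diffusion_flash_attn;
    return p;
}

int main() {
    auto p = resolve(/*flash_attn=*/false, /*diffusion_flash_attn=*/true);
    std::cout << "diffusion=" << p.diffusion << " vae=" << p.vae << "\n";  // diffusion=1 vae=0
}
```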
@@ -782,14 +809,15 @@ public:
             unet_params_mem_size += high_noise_diffusion_model->get_params_buffer_size();
         }
         size_t vae_params_mem_size = 0;
-        if (!use_tiny_autoencoder || sd_ctx_params->tae_preview_only) {
+        if (!(use_tiny_autoencoder || version == VERSION_SDXS) || tae_preview_only) {
             vae_params_mem_size = first_stage_model->get_params_buffer_size();
         }
-        if (use_tiny_autoencoder) {
-            if (!tae_first_stage->load_from_file(taesd_path, n_threads)) {
+        if (use_tiny_autoencoder || version == VERSION_SDXS) {
+            if (use_tiny_autoencoder && !tae_first_stage->load_from_file(taesd_path, n_threads)) {
                 return false;
             }
-            vae_params_mem_size = tae_first_stage->get_params_buffer_size();
+            use_tiny_autoencoder = true;  // now the processing is identical for VERSION_SDXS
+            vae_params_mem_size  = tae_first_stage->get_params_buffer_size();
         }
         size_t control_net_params_mem_size = 0;
         if (control_net) {
@@ -945,7 +973,7 @@ public:
         }

         ggml_free(ctx);
-        use_tiny_autoencoder = use_tiny_autoencoder && !sd_ctx_params->tae_preview_only;
+        use_tiny_autoencoder = use_tiny_autoencoder && !tae_preview_only;
         return true;
     }

@@ -2489,10 +2517,15 @@ public:
             ne2 = 1;
             ne3 = C * x->ne[3];
         } else {
-            if (!use_tiny_autoencoder) {
-                C *= 2;
+            int64_t out_channels   = C;
+            bool encode_outputs_mu = use_tiny_autoencoder ||
+                                     sd_version_is_wan(version) ||
+                                     sd_version_is_flux2(version) ||
+                                     version == VERSION_CHROMA_RADIANCE;
+            if (!encode_outputs_mu) {
+                out_channels *= 2;
             }
-            ne2 = C;
+            ne2 = out_channels;
             ne3 = x->ne[3];
         }
         result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, ne2, ne3);
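The reworked branch sizes the encode output by what the encoder actually returns: encoders that emit only the latent mean (tiny autoencoders, Wan, FLUX.2, Chroma Radiance) need C channels, while a regular VAE encoder that packs mean and log-variance together needs 2C. A one-liner capturing that rule, with illustrative names:

```cpp
#include <cstdint>
#include <iostream>

// C latent channels when the encoder emits only the mean,
// 2*C when mean and log-variance are packed together (regular SD/SDXL VAE).
int64_t encode_out_channels(int64_t C, bool encoder_outputs_mu_only) {
    return encoder_outputs_mu_only ? C : 2 * C;
}

int main() {
    std::cout << encode_out_channels(16, true) << " " << encode_out_channels(4, false) << "\n";  // 16 8
}
```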
@@ -2633,7 +2666,7 @@ public:
         }
         process_latent_out(x);
         // x = load_tensor_from_file(work_ctx, "wan_vae_z.bin");
-        if (vae_tiling_params.enabled && !decode_video) {
+        if (vae_tiling_params.enabled) {
             float tile_overlap;
             int tile_size_x, tile_size_y;
             get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, vae_tiling_params, x->ne[0], x->ne[1]);
@@ -2651,7 +2684,7 @@ public:
             first_stage_model->free_compute_buffer();
             process_vae_output_tensor(result);
         } else {
-            if (vae_tiling_params.enabled && !decode_video) {
+            if (vae_tiling_params.enabled) {
                 // split latent in 64x64 tiles and compute in several steps
                 auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
                     tae_first_stage->compute(n_threads, in, true, &out);
@@ -2726,6 +2759,8 @@ const char* sample_method_to_str[] = {
     "lcm",
     "ddim_trailing",
     "tcd",
+    "res_multistep",
+    "res_2s",
 };

 const char* sd_sample_method_name(enum sample_method_t sample_method) {
@@ -2755,6 +2790,7 @@ const char* scheduler_to_str[] = {
     "smoothstep",
     "kl_optimal",
     "lcm",
+    "bong_tangent",
 };

 const char* sd_scheduler_name(enum scheduler_t scheduler) {
@@ -2920,6 +2956,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
                      "keep_clip_on_cpu: %s\n"
                      "keep_control_net_on_cpu: %s\n"
                      "keep_vae_on_cpu: %s\n"
+                     "flash_attn: %s\n"
                      "diffusion_flash_attn: %s\n"
                      "circular_x: %s\n"
                      "circular_y: %s\n"
@@ -2951,6 +2988,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
                      BOOL_STR(sd_ctx_params->keep_clip_on_cpu),
                      BOOL_STR(sd_ctx_params->keep_control_net_on_cpu),
                      BOOL_STR(sd_ctx_params->keep_vae_on_cpu),
+                     BOOL_STR(sd_ctx_params->flash_attn),
                      BOOL_STR(sd_ctx_params->diffusion_flash_attn),
                      BOOL_STR(sd_ctx_params->circular_x),
                      BOOL_STR(sd_ctx_params->circular_y),
@@ -3047,6 +3085,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
                      "sample_params: %s\n"
                      "strength: %.2f\n"
                      "seed: %" PRId64
+                     "\n"
                      "batch_count: %d\n"
                      "ref_images_count: %d\n"
                      "auto_resize_ref_image: %s\n"
@@ -3099,6 +3138,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
     sd_vid_gen_params->video_frames  = 6;
     sd_vid_gen_params->moe_boundary  = 0.875f;
     sd_vid_gen_params->vace_strength = 1.f;
+    sd_vid_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
     sd_cache_params_init(&sd_vid_gen_params->cache);
 }

@@ -3728,6 +3768,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params) {
     if (sd_ctx == nullptr || sd_vid_gen_params == nullptr) {
         return nullptr;
     }
+    sd_ctx->sd->vae_tiling_params = sd_vid_gen_params->vae_tiling_params;

     std::string prompt          = SAFE_STR(sd_vid_gen_params->prompt);
     std::string negative_prompt = SAFE_STR(sd_vid_gen_params->negative_prompt);
@@ -48,6 +48,8 @@ enum sample_method_t {
     LCM_SAMPLE_METHOD,
     DDIM_TRAILING_SAMPLE_METHOD,
     TCD_SAMPLE_METHOD,
+    RES_MULTISTEP_SAMPLE_METHOD,
+    RES_2S_SAMPLE_METHOD,
     SAMPLE_METHOD_COUNT
 };

@@ -62,6 +64,7 @@ enum scheduler_t {
     SMOOTHSTEP_SCHEDULER,
     KL_OPTIMAL_SCHEDULER,
     LCM_SCHEDULER,
+    BONG_TANGENT_SCHEDULER,
     SCHEDULER_COUNT
 };

@@ -186,6 +189,7 @@ typedef struct {
     bool keep_clip_on_cpu;
     bool keep_control_net_on_cpu;
     bool keep_vae_on_cpu;
+    bool flash_attn;
     bool diffusion_flash_attn;
     bool tae_preview_only;
     bool diffusion_conv_direct;
@@ -319,6 +323,7 @@ typedef struct {
     int64_t seed;
     int video_frames;
     float vace_strength;
+    sd_tiling_params_t vae_tiling_params;
     sd_cache_params_t cache;
 } sd_vid_gen_params_t;

4 t5.hpp

@@ -515,7 +515,7 @@ public:
         auto wi_1 = std::dynamic_pointer_cast<Linear>(blocks["wi_1"]);
         auto wo   = std::dynamic_pointer_cast<Linear>(blocks["wo"]);

-        auto hidden_gelu   = ggml_gelu_inplace(ctx->ggml_ctx, wi_0->forward(ctx, x));
+        auto hidden_gelu   = ggml_ext_gelu(ctx->ggml_ctx, wi_0->forward(ctx, x), true);
         auto hidden_linear = wi_1->forward(ctx, x);
         x = ggml_mul_inplace(ctx->ggml_ctx, hidden_gelu, hidden_linear);
         x = wo->forward(ctx, x);
@@ -608,7 +608,7 @@ public:
             }
         }

-        k = ggml_scale_inplace(ctx->ggml_ctx, k, ::sqrtf(static_cast<float>(d_head)));
+        k = ggml_ext_scale(ctx->ggml_ctx, k, ::sqrtf(static_cast<float>(d_head)), true);

         x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, mask);  // [N, n_token, d_head * n_head]

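For reference, the block touched in the first hunk is T5's gated-GELU feed-forward, which multiplies a GELU-activated projection with a parallel linear projection before the output projection; the k pre-scaling by sqrt(d_head) in the second hunk is there so that the 1/sqrt(d_head) factor applied inside the shared attention helper cancels out, matching T5's unscaled attention.

```latex
\mathrm{FFN}_{\text{GeGLU}}(x) \;=\; W_o\,\bigl(\operatorname{GELU}(W_{i,0}\,x)\,\odot\,W_{i,1}\,x\bigr)
```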
80 tae.hpp

@@ -17,22 +17,43 @@ class TAEBlock : public UnaryBlock {
 protected:
     int n_in;
     int n_out;
+    bool use_midblock_gn;

 public:
-    TAEBlock(int n_in, int n_out)
-        : n_in(n_in), n_out(n_out) {
+    TAEBlock(int n_in, int n_out, bool use_midblock_gn = false)
+        : n_in(n_in), n_out(n_out), use_midblock_gn(use_midblock_gn) {
         blocks["conv.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_out, {3, 3}, {1, 1}, {1, 1}));
         blocks["conv.2"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1}));
         blocks["conv.4"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1}));
         if (n_in != n_out) {
             blocks["skip"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_out, {1, 1}, {1, 1}, {1, 1}, {1, 1}, false));
         }
+        if (use_midblock_gn) {
+            int n_gn = n_in * 4;
+            blocks["pool.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_gn, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
+            blocks["pool.1"] = std::shared_ptr<GGMLBlock>(new GroupNorm(4, n_gn));
+            // pool.2 is ReLU, handled in forward
+            blocks["pool.3"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_gn, n_in, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
+        }
     }

     struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
         // x: [n, n_in, h, w]
         // return: [n, n_out, h, w]

+        if (use_midblock_gn) {
+            auto pool_0 = std::dynamic_pointer_cast<Conv2d>(blocks["pool.0"]);
+            auto pool_1 = std::dynamic_pointer_cast<GroupNorm>(blocks["pool.1"]);
+            auto pool_3 = std::dynamic_pointer_cast<Conv2d>(blocks["pool.3"]);
+
+            auto p = pool_0->forward(ctx, x);
+            p = pool_1->forward(ctx, p);
+            p = ggml_relu_inplace(ctx->ggml_ctx, p);
+            p = pool_3->forward(ctx, p);
+
+            x = ggml_add(ctx->ggml_ctx, x, p);
+        }
+
         auto conv_0 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.0"]);
         auto conv_2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.2"]);
         auto conv_4 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.4"]);
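The optional pool branch gives the FLUX.2 tiny-autoencoder variant a GroupNorm-gated residual, built from 1x1 convolutions, in front of the usual three-conv stack. In formula form the added step is:

```latex
x \;\leftarrow\; x \;+\; \mathrm{pool}_3\!\bigl(\operatorname{ReLU}\bigl(\mathrm{GN}_{4}\bigl(\mathrm{pool}_0(x)\bigr)\bigr)\bigr),
\qquad \mathrm{pool}_0: C \to 4C,\;\; \mathrm{pool}_3: 4C \to C\ \ \text{(1$\times$1 convolutions)}
```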
@@ -62,7 +83,7 @@ class TinyEncoder : public UnaryBlock {
     int num_blocks = 3;

 public:
-    TinyEncoder(int z_channels = 4)
+    TinyEncoder(int z_channels = 4, bool use_midblock_gn = false)
         : z_channels(z_channels) {
         int index = 0;
         blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, channels, {3, 3}, {1, 1}, {1, 1}));
@@ -80,7 +101,7 @@ public:

         blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false));
         for (int i = 0; i < num_blocks; i++) {
-            blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
+            blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels, use_midblock_gn));
         }

         blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, z_channels, {3, 3}, {1, 1}, {1, 1}));
@@ -107,7 +128,7 @@ class TinyDecoder : public UnaryBlock {
     int num_blocks = 3;

 public:
-    TinyDecoder(int z_channels = 4)
+    TinyDecoder(int z_channels = 4, bool use_midblock_gn = false)
         : z_channels(z_channels) {
         int index = 0;

@@ -115,7 +136,7 @@ public:
         index++;  // nn.ReLU()

         for (int i = 0; i < num_blocks; i++) {
-            blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
+            blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels, use_midblock_gn));
         }
         index++;  // nn.Upsample()
         blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, false));
@@ -140,9 +161,9 @@ public:
         // z: [n, z_channels, h, w]
         // return: [n, out_channels, h*8, w*8]

-        auto h = ggml_scale(ctx->ggml_ctx, z, 1.0f / 3.0f);
+        auto h = ggml_ext_scale(ctx->ggml_ctx, z, 1.0f / 3.0f);
         h = ggml_tanh_inplace(ctx->ggml_ctx, h);
-        h = ggml_scale(ctx->ggml_ctx, h, 3.0f);
+        h = ggml_ext_scale(ctx->ggml_ctx, h, 3.0f);

         for (int i = 0; i < num_blocks * 3 + 10; i++) {
             if (blocks.find(std::to_string(i)) == blocks.end()) {
@@ -379,10 +400,11 @@ public:
         auto first_conv = std::dynamic_pointer_cast<Conv2d>(blocks["1"]);

         // Clamp()
-        auto h = ggml_scale_inplace(ctx->ggml_ctx,
+        auto h = ggml_ext_scale(ctx->ggml_ctx,
                                     ggml_tanh_inplace(ctx->ggml_ctx,
-                                                      ggml_scale(ctx->ggml_ctx, z, 1.0f / 3.0f)),
-                                    3.0f);
+                                                      ggml_ext_scale(ctx->ggml_ctx, z, 1.0f / 3.0f)),
+                                    3.0f,
+                                    true);

         h = first_conv->forward(ctx, h);
         h = ggml_relu_inplace(ctx->ggml_ctx, h);
@@ -470,29 +492,44 @@ public:
 class TAESD : public GGMLBlock {
 protected:
     bool decode_only;
+    bool taef2 = false;

 public:
     TAESD(bool decode_only = true, SDVersion version = VERSION_SD1)
         : decode_only(decode_only) {
         int z_channels = 4;
+        bool use_midblock_gn = false;
+        taef2 = sd_version_is_flux2(version);
+
         if (sd_version_is_dit(version)) {
             z_channels = 16;
         }
-        blocks["decoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyDecoder(z_channels));
+        if (taef2) {
+            z_channels      = 32;
+            use_midblock_gn = true;
+        }
+        blocks["decoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyDecoder(z_channels, use_midblock_gn));

         if (!decode_only) {
-            blocks["encoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyEncoder(z_channels));
+            blocks["encoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyEncoder(z_channels, use_midblock_gn));
         }
     }

     struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
         auto decoder = std::dynamic_pointer_cast<TinyDecoder>(blocks["decoder.layers"]);
+        if (taef2) {
+            z = unpatchify(ctx->ggml_ctx, z, 2);
+        }
         return decoder->forward(ctx, z);
     }

     struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
         auto encoder = std::dynamic_pointer_cast<TinyEncoder>(blocks["encoder.layers"]);
-        return encoder->forward(ctx, x);
+        auto z = encoder->forward(ctx, x);
+        if (taef2) {
+            z = patchify(ctx->ggml_ctx, z, 2);
+        }
+        return z;
     }
 };

@@ -505,7 +542,8 @@ struct TinyAutoEncoder : public GGMLRunner {
                          struct ggml_tensor** output,
                          struct ggml_context* output_ctx = nullptr) = 0;

     virtual bool load_from_file(const std::string& file_path, int n_threads) = 0;
+    virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) = 0;
 };

 struct TinyImageAutoEncoder : public TinyAutoEncoder {
@@ -555,6 +593,10 @@ struct TinyImageAutoEncoder : public TinyAutoEncoder {
         return success;
     }

+    void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
+        taesd.get_param_tensors(tensors, prefix);
+    }
+
     struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
         struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
         z = to_backend(z);
@@ -624,6 +666,10 @@ struct TinyVideoAutoEncoder : public TinyAutoEncoder {
         return success;
     }

+    void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
+        taehv.get_param_tensors(tensors, prefix);
+    }
+
     struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
         struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
         z = to_backend(z);
unet.hpp
@ -201,6 +201,9 @@ public:
|
|||||||
num_head_channels = 64;
|
num_head_channels = 64;
|
||||||
num_heads = -1;
|
num_heads = -1;
|
||||||
use_linear_projection = true;
|
use_linear_projection = true;
|
||||||
|
if (version == VERSION_SDXL_VEGA) {
|
||||||
|
transformer_depth = {1, 1, 2};
|
||||||
|
}
|
||||||
} else if (version == VERSION_SVD) {
|
} else if (version == VERSION_SVD) {
|
||||||
in_channels = 8;
|
in_channels = 8;
|
||||||
out_channels = 4;
|
out_channels = 4;
|
||||||
@ -215,10 +218,13 @@ public:
|
|||||||
} else if (sd_version_is_unet_edit(version)) {
|
} else if (sd_version_is_unet_edit(version)) {
|
||||||
in_channels = 8;
|
in_channels = 8;
|
||||||
}
|
}
|
||||||
if (version == VERSION_SD1_TINY_UNET || version == VERSION_SD2_TINY_UNET) {
|
if (version == VERSION_SD1_TINY_UNET || version == VERSION_SD2_TINY_UNET || version == VERSION_SDXS) {
|
||||||
num_res_blocks = 1;
|
num_res_blocks = 1;
|
||||||
channel_mult = {1, 2, 4};
|
channel_mult = {1, 2, 4};
|
||||||
tiny_unet = true;
|
tiny_unet = true;
|
||||||
|
if (version == VERSION_SDXS) {
|
||||||
|
attention_resolutions = {4, 2}; // here just like SDXL
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// dims is always 2
|
// dims is always 2
|
||||||
@ -316,7 +322,7 @@ public:
|
|||||||
}
|
}
|
||||||
if (!tiny_unet) {
|
if (!tiny_unet) {
|
||||||
blocks["middle_block.0"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
|
blocks["middle_block.0"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
|
||||||
if (version != VERSION_SDXL_SSD1B) {
|
if (version != VERSION_SDXL_SSD1B && version != VERSION_SDXL_VEGA) {
|
||||||
blocks["middle_block.1"] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
|
blocks["middle_block.1"] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
|
||||||
n_head,
|
n_head,
|
||||||
d_head,
|
d_head,
|
||||||
@ -517,13 +523,13 @@ public:
|
|||||||
// middle_block
|
// middle_block
|
||||||
if (!tiny_unet) {
|
if (!tiny_unet) {
|
||||||
h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
||||||
if (version != VERSION_SDXL_SSD1B) {
|
if (version != VERSION_SDXL_SSD1B && version != VERSION_SDXL_VEGA) {
|
||||||
h = attention_layer_forward("middle_block.1", ctx, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
h = attention_layer_forward("middle_block.1", ctx, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
||||||
h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (controls.size() > 0) {
|
if (controls.size() > 0) {
|
||||||
auto cs = ggml_scale_inplace(ctx->ggml_ctx, controls[controls.size() - 1], control_strength);
|
auto cs = ggml_ext_scale(ctx->ggml_ctx, controls[controls.size() - 1], control_strength, true);
|
||||||
h = ggml_add(ctx->ggml_ctx, h, cs); // middle control
|
h = ggml_add(ctx->ggml_ctx, h, cs); // middle control
|
||||||
}
|
}
|
||||||
int control_offset = static_cast<int>(controls.size() - 2);
|
int control_offset = static_cast<int>(controls.size() - 2);
|
||||||
@ -536,7 +542,7 @@ public:
|
|||||||
hs.pop_back();
|
hs.pop_back();
|
||||||
|
|
||||||
if (controls.size() > 0) {
|
if (controls.size() > 0) {
|
||||||
auto cs = ggml_scale_inplace(ctx->ggml_ctx, controls[control_offset], control_strength);
|
auto cs = ggml_ext_scale(ctx->ggml_ctx, controls[control_offset], control_strength, true);
|
||||||
h_skip = ggml_add(ctx->ggml_ctx, h_skip, cs); // control net condition
|
h_skip = ggml_add(ctx->ggml_ctx, h_skip, cs); // control net condition
|
||||||
control_offset--;
|
control_offset--;
|
||||||
}
|
}
|
||||||
|
|||||||
13 vae.hpp

@@ -127,8 +127,6 @@ public:
             q = q_proj->forward(ctx, h_);  // [N, h * w, in_channels]
             k = k_proj->forward(ctx, h_);  // [N, h * w, in_channels]
             v = v_proj->forward(ctx, h_);  // [N, h * w, in_channels]
-
-            v = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, v, 1, 0, 2, 3));  // [N, in_channels, h * w]
         } else {
             q = q_proj->forward(ctx, h_);                                              // [N, in_channels, h, w]
             q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 1, 2, 0, 3));  // [N, h, w, in_channels]
@@ -138,11 +136,12 @@ public:
             k = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, k, 1, 2, 0, 3));  // [N, h, w, in_channels]
             k = ggml_reshape_3d(ctx->ggml_ctx, k, c, h * w, n);                        // [N, h * w, in_channels]

             v = v_proj->forward(ctx, h_);                                              // [N, in_channels, h, w]
-            v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n);                        // [N, in_channels, h * w]
+            v = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, v, 1, 2, 0, 3));  // [N, h, w, in_channels]
+            v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n);                        // [N, h * w, in_channels]
         }

-        h_ = ggml_ext_attention(ctx->ggml_ctx, q, k, v, false);  // [N, h * w, in_channels]
+        h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, ctx->flash_attn_enabled);

         if (use_linear) {
             h_ = proj_out->forward(ctx, h_);  // [N, h * w, in_channels]
@@ -254,8 +253,8 @@ public:

         float alpha = get_alpha();
         x = ggml_add(ctx->ggml_ctx,
-                     ggml_scale(ctx->ggml_ctx, x, alpha),
-                     ggml_scale(ctx->ggml_ctx, x_mix, 1.0f - alpha));
+                     ggml_ext_scale(ctx->ggml_ctx, x, alpha),
+                     ggml_ext_scale(ctx->ggml_ctx, x_mix, 1.0f - alpha));

         x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3));  // b c t (h w) -> b t c (h w)
         x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B);                     // b t c (h w) -> (b t) c h w
26 wan.hpp

@@ -572,9 +572,8 @@ namespace WAN {
         auto v = qkv_vec[2];
         v      = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n);  // [t, c, h * w]

-        x = ggml_ext_attention(ctx->ggml_ctx, q, k, v, false);  // [t, h * w, c]
-        // v = ggml_cont(ctx, ggml_ext_torch_permute(ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
-        // x = ggml_ext_attention_ext(ctx, q, k, v, q->ne[2], nullptr, false, false, true);
+        v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3));                            // [t, h * w, c]
+        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, ctx->flash_attn_enabled);   // [t, h * w, c]

         x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3));  // [t, c, h * w]
         x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n);                             // [t, c, h, w]
@@ -1394,7 +1393,7 @@ namespace WAN {
         k      = norm_k->forward(ctx, k);
         auto v = v_proj->forward(ctx, context);  // [N, n_context, dim]

-        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, false, ctx->flash_attn_enabled);  // [N, n_token, dim]
+        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, ctx->flash_attn_enabled);  // [N, n_token, dim]

         x = o_proj->forward(ctx, x);  // [N, n_token, dim]
         return x;
@@ -1443,11 +1442,8 @@ namespace WAN {
         int64_t dim             = x->ne[0];
         int64_t context_txt_len = context->ne[1] - context_img_len;

-        context          = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context, 0, 2, 1, 3));  // [context_img_len + context_txt_len, N, dim]
-        auto context_img = ggml_view_3d(ctx->ggml_ctx, context, dim, N, context_img_len, context->nb[1], context->nb[2], 0);
-        auto context_txt = ggml_view_3d(ctx->ggml_ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_img_len * context->nb[2]);
-        context_img      = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context_img, 0, 2, 1, 3));  // [N, context_img_len, dim]
-        context_txt      = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context_txt, 0, 2, 1, 3));  // [N, context_txt_len, dim]
+        auto context_img = ggml_view_3d(ctx->ggml_ctx, context, dim, context_img_len, N, context->nb[1], context->nb[2], 0);                                  // [N, context_img_len, dim]
+        auto context_txt = ggml_view_3d(ctx->ggml_ctx, context, dim, context_txt_len, N, context->nb[1], context->nb[2], context_img_len * context->nb[1]);  // [N, context_txt_len, dim]

         auto q = q_proj->forward(ctx, x);
         q      = norm_q->forward(ctx, q);
@@ -1459,8 +1455,8 @@ namespace WAN {
         k_img      = norm_k_img->forward(ctx, k_img);
         auto v_img = v_img_proj->forward(ctx, context_img);  // [N, context_img_len, dim]

-        auto img_x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k_img, v_img, num_heads, nullptr, false, false, ctx->flash_attn_enabled);  // [N, n_token, dim]
-        x          = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, false, ctx->flash_attn_enabled);          // [N, n_token, dim]
+        auto img_x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k_img, v_img, num_heads, nullptr, false, ctx->flash_attn_enabled);  // [N, n_token, dim]
+        x          = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, ctx->flash_attn_enabled);          // [N, n_token, dim]

         x = ggml_add(ctx->ggml_ctx, x, img_x);

@@ -1577,7 +1573,7 @@ namespace WAN {
         y = modulate_add(ctx->ggml_ctx, y, es[3]);

         y = ffn_0->forward(ctx, y);
-        y = ggml_gelu_inplace(ctx->ggml_ctx, y);
+        y = ggml_ext_gelu(ctx->ggml_ctx, y, true);
         y = ffn_2->forward(ctx, y);

         x = ggml_add(ctx->ggml_ctx, x, modulate_mul(ctx->ggml_ctx, y, es[5]));
@@ -1724,7 +1720,7 @@ namespace WAN {

         auto x = proj_0->forward(ctx, image_embeds);
         x      = proj_1->forward(ctx, x);
-        x      = ggml_gelu_inplace(ctx->ggml_ctx, x);
+        x      = ggml_ext_gelu(ctx->ggml_ctx, x, true);
         x      = proj_3->forward(ctx, x);
         x      = proj_4->forward(ctx, x);

@@ -1911,7 +1907,7 @@ namespace WAN {
         e0 = ggml_reshape_4d(ctx->ggml_ctx, e0, e0->ne[0] / 6, 6, e0->ne[1], e0->ne[2]);  // [N, 6, dim] or [N, T, 6, dim]

         context = text_embedding_0->forward(ctx, context);
-        context = ggml_gelu(ctx->ggml_ctx, context);
+        context = ggml_ext_gelu(ctx->ggml_ctx, context);
         context = text_embedding_2->forward(ctx, context);  // [N, context_txt_len, dim]

         int64_t context_img_len = 0;
@@ -1950,7 +1946,7 @@ namespace WAN {
         auto result = vace_block->forward(ctx, c, x_orig, e0, pe, context, context_img_len);
         auto c_skip = result.first;
         c           = result.second;
-        c_skip      = ggml_scale(ctx->ggml_ctx, c_skip, vace_strength);
+        c_skip      = ggml_ext_scale(ctx->ggml_ctx, c_skip, vace_strength);
         x           = ggml_add(ctx->ggml_ctx, x, c_skip);
     }
 }
40 z_image.hpp

@@ -54,15 +54,37 @@ namespace ZImage {

         auto qkv = qkv_proj->forward(ctx, x);                                                                            // [N, n_token, (num_heads + num_kv_heads*2)*head_dim]
         qkv      = ggml_reshape_4d(ctx->ggml_ctx, qkv, head_dim, num_heads + num_kv_heads * 2, qkv->ne[1], qkv->ne[2]);  // [N, n_token, num_heads + num_kv_heads*2, head_dim]
-        qkv      = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, qkv, 0, 2, 3, 1));                     // [num_heads + num_kv_heads*2, N, n_token, head_dim]

-        auto q = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], 0);                                           // [num_heads, N, n_token, head_dim]
-        auto k = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_kv_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], qkv->nb[3] * num_heads);                   // [num_kv_heads, N, n_token, head_dim]
-        auto v = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_kv_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], qkv->nb[3] * (num_heads + num_kv_heads));  // [num_kv_heads, N, n_token, head_dim]
-
-        q = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, q, 0, 3, 1, 2));  // [N, n_token, num_heads, head_dim]
-        k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 3, 1, 2));  // [N, n_token, num_kv_heads, head_dim]
-        v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 0, 3, 1, 2));  // [N, n_token, num_kv_heads, head_dim]
+        auto q = ggml_view_4d(ctx->ggml_ctx,
+                              qkv,
+                              qkv->ne[0],
+                              num_heads,
+                              qkv->ne[2],
+                              qkv->ne[3],
+                              qkv->nb[1],
+                              qkv->nb[2],
+                              qkv->nb[3],
+                              0);  // [N, n_token, num_heads, head_dim]
+        auto k = ggml_view_4d(ctx->ggml_ctx,
+                              qkv,
+                              qkv->ne[0],
+                              num_kv_heads,
+                              qkv->ne[2],
+                              qkv->ne[3],
+                              qkv->nb[1],
+                              qkv->nb[2],
+                              qkv->nb[3],
+                              num_heads * qkv->nb[1]);  // [N, n_token, num_kv_heads, head_dim]
+        auto v = ggml_view_4d(ctx->ggml_ctx,
+                              qkv,
+                              qkv->ne[0],
+                              num_kv_heads,
+                              qkv->ne[2],
+                              qkv->ne[3],
+                              qkv->nb[1],
+                              qkv->nb[2],
+                              qkv->nb[3],
+                              (num_heads + num_kv_heads) * qkv->nb[1]);  // [N, n_token, num_kv_heads, head_dim]

         if (qk_norm) {
             auto q_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["q_norm"]);
@@ -495,7 +517,7 @@ namespace ZImage {
         out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H);  // [N, C, H, W + pad_w]
         out = ggml_ext_slice(ctx->ggml_ctx, out, 0, 0, W);  // [N, C, H, W]

-        out = ggml_scale(ctx->ggml_ctx, out, -1.f);
+        out = ggml_ext_scale(ctx->ggml_ctx, out, -1.f);

         return out;
     }
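The rewritten views slice q, k and v straight out of the fused projection along the head dimension, so no permute/cont copy is needed: with the reshaped layout [head_dim, num_heads + 2*num_kv_heads, n_token, N], the k block starts `num_heads` strides of `nb[1]` bytes into each token's row and v starts `num_heads + num_kv_heads` strides in. A standalone sketch of that offset arithmetic for a grouped-query fused QKV projection; the struct, helper, and head counts below are illustrative examples, not the model's actual configuration:

```cpp
#include <cstddef>
#include <iostream>

// Fused QKV row layout per token: [num_heads | num_kv_heads | num_kv_heads] heads,
// each head_dim elements wide. Returns the element offsets of the q/k/v sub-blocks.
struct QkvOffsets {
    size_t q, k, v;
};

QkvOffsets fused_qkv_offsets(size_t head_dim, size_t num_heads, size_t num_kv_heads) {
    size_t nb1 = head_dim;  // elements per head; multiply by sizeof(float) for byte offsets
    return {0, num_heads * nb1, (num_heads + num_kv_heads) * nb1};
}

int main() {
    // Example grouped-query setup: 30 query heads sharing 10 kv heads, head_dim 128.
    auto off = fused_qkv_offsets(/*head_dim=*/128, /*num_heads=*/30, /*num_kv_heads=*/10);
    std::cout << off.q << " " << off.k << " " << off.v << "\n";  // 0 3840 5120
}
```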