Mirror of https://github.com/leejet/stable-diffusion.cpp.git (synced 2025-12-13 05:48:56 +00:00)

Compare commits: 11 commits, ac54e00760 ... b5f4932696

b5f4932696, 1c168d98a5, ea9b647080, 2b1bc06477, b99cbfe4dc, 8c7719fe9a, 8f94efafa3, 07585448ad, 6ea812256e, 9b1d90bc23, 65fa646684
CMakeLists.txt

@@ -29,7 +29,6 @@ option(SD_HIPBLAS "sd: rocm backend" OFF)
 option(SD_METAL "sd: metal backend" OFF)
 option(SD_VULKAN "sd: vulkan backend" OFF)
 option(SD_SYCL "sd: sycl backend" OFF)
-option(SD_FLASH_ATTN "sd: use flash attention for x4 less memory usage" OFF)
 option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, indeterministic (sometimes, same seed don't generate same image), cuda only" OFF)
 option(SD_BUILD_SHARED_LIBS "sd: build shared libs" OFF)
 #option(SD_BUILD_SERVER "sd: build server example" ON)
@@ -61,11 +60,6 @@ if (SD_HIPBLAS)
     endif()
 endif ()
 
-if(SD_FLASH_ATTN)
-    message("-- Use Flash Attention for memory optimization")
-    add_definitions(-DSD_USE_FLASH_ATTENTION)
-endif()
-
 set(SD_LIB stable-diffusion)
 
 file(GLOB SD_LIB_SOURCES
README.md (24 changes)

@@ -24,7 +24,7 @@ Inference of Stable Diffusion and Flux in pure C/C++
 - Full CUDA, Metal, Vulkan and SYCL backend for GPU acceleration.
 - Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs models
     - No need to convert to `.ggml` or `.gguf` anymore!
-- Flash Attention for memory usage optimization (only cpu for now)
+- Flash Attention for memory usage optimization
 - Original `txt2img` and `img2img` mode
 - Negative prompt
 - [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) style tokenizer (not all the features, only token weighting for now)
@@ -182,11 +182,21 @@ Example of text2img by using SYCL backend:
 
 ##### Using Flash Attention
 
-Enabling flash attention reduces memory usage by at least 400 MB. At the moment, it is not supported when CUBLAS is enabled because the kernel implementation is missing.
+Enabling flash attention for the diffusion model reduces memory usage by varying amounts of MB.
+eg.:
+ - flux 768x768 ~600mb
+ - SD2 768x768 ~1400mb
 
-```
-cmake .. -DSD_FLASH_ATTN=ON
-cmake --build . --config Release
-```
+For most backends, it slows things down, but for cuda it generally speeds it up too.
+At the moment, it is only supported for some models and some backends (like cpu, cuda/rocm, metal).
+
+Run by adding `--diffusion-fa` to the arguments and watch for:
+```
+[INFO ] stable-diffusion.cpp:312 - Using flash attention in the diffusion model
+```
+and the compute buffer shrink in the debug log:
+```
+[DEBUG] ggml_extend.hpp:1004 - flux compute buffer size: 650.00 MB(VRAM)
+```
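For reference, a minimal run sketch (the model path, size, and prompt below are placeholders, not taken from this change; `--diffusion-fa` is the only new flag involved):

```bash
# hypothetical model path; -p/-H/-W are the usual sd options
./bin/sd -m ../models/sd_v2-1_768.safetensors -p "a lovely cat" -H 768 -W 768 --diffusion-fa
```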
 
 ### Run
 
@@ -240,6 +250,9 @@ arguments:
   --vae-tiling                       process vae in tiles to reduce memory usage
   --vae-on-cpu                       keep vae in cpu (for low vram)
   --clip-on-cpu                      keep clip in cpu (for low vram)
+  --diffusion-fa                     use flash attention in the diffusion model (for low vram)
+                                     Might lower quality, since it implies converting k and v to f16.
+                                     This might crash if it is not supported by the backend.
   --control-net-cpu                  keep controlnet in cpu (for low vram)
   --canny                            apply canny preprocessor (edge detection)
   --color                            Colors the logging tags according to level
@@ -292,12 +305,15 @@ These projects wrap `stable-diffusion.cpp` for easier use in other languages/frameworks
 
 * Golang: [seasonjs/stable-diffusion](https://github.com/seasonjs/stable-diffusion)
 * C#: [DarthAffe/StableDiffusion.NET](https://github.com/DarthAffe/StableDiffusion.NET)
 * Python: [william-murray1204/stable-diffusion-cpp-python](https://github.com/william-murray1204/stable-diffusion-cpp-python)
+* Rust: [newfla/diffusion-rs](https://github.com/newfla/diffusion-rs)
 
 ## UIs
 
 These projects use `stable-diffusion.cpp` as a backend for their image generation.
 
+- [Jellybox](https://jellybox.com)
 - [Stable Diffusion GUI](https://github.com/fszontagh/sd.cpp.gui.wx)
 
 ## Contributors
clip.hpp (31 changes)

@@ -343,6 +343,13 @@ public:
         }
     }
 
+    std::string clean_up_tokenization(std::string& text) {
+        std::regex pattern(R"( ,)");
+        // Replace " ," with ","
+        std::string result = std::regex_replace(text, pattern, ",");
+        return result;
+    }
+
     std::string decode(const std::vector<int>& tokens) {
         std::string text = "";
         for (int t : tokens) {
@@ -351,8 +358,12 @@ public:
             std::u32string ts = decoder[t];
             // printf("%d, %s \n", t, utf32_to_utf8(ts).c_str());
             std::string s = utf32_to_utf8(ts);
-            if (s.length() >= 4 && ends_with(s, "</w>")) {
-                text += " " + s.replace(s.length() - 4, s.length() - 1, "");
+            if (s.length() >= 4) {
+                if (ends_with(s, "</w>")) {
+                    text += s.replace(s.length() - 4, s.length() - 1, "") + " ";
+                } else {
+                    text += s;
+                }
             } else {
                 text += " " + s;
             }
@@ -364,6 +375,7 @@ public:
 
         // std::string s((char *)bytes.data());
         // std::string s = "";
+        text = clean_up_tokenization(text);
         return trim(text);
     }
 
@@ -711,8 +723,12 @@ public:
         if (return_pooled) {
             auto text_projection = params["text_projection"];
             ggml_tensor* pooled  = ggml_view_1d(ctx, x, hidden_size, x->nb[1] * max_token_idx);
-            pooled               = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx, text_projection)), pooled);
-            return pooled;
+            if (text_projection != NULL) {
+                pooled = ggml_nn_linear(ctx, pooled, text_projection, NULL);
+            } else {
+                LOG_DEBUG("Missing text_projection matrix, assuming identity...");
+            }
+            return pooled;  // [hidden_size, 1, 1]
         }
 
         return x;  // [N, n_token, hidden_size]
@@ -761,14 +777,17 @@ public:
         auto x = embeddings->forward(ctx, pixel_values);  // [N, num_positions, embed_dim]
         x      = pre_layernorm->forward(ctx, x);
         x      = encoder->forward(ctx, x, -1, false);
-        x      = post_layernorm->forward(ctx, x);  // [N, n_token, hidden_size]
         // print_ggml_tensor(x, true, "ClipVisionModel x: ");
+        auto last_hidden_state = x;
+        x                      = post_layernorm->forward(ctx, x);  // [N, n_token, hidden_size]
 
         GGML_ASSERT(x->ne[3] == 1);
         if (return_pooled) {
             ggml_tensor* pooled = ggml_cont(ctx, ggml_view_2d(ctx, x, x->ne[0], x->ne[2], x->nb[2], 0));
             return pooled;  // [N, hidden_size]
         } else {
-            return x;  // [N, n_token, hidden_size]
+            // return x;  // [N, n_token, hidden_size]
+            return last_hidden_state;  // [N, n_token, hidden_size]
         }
     }
 };
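For intuition, a toy Python sketch of what the reworked `decode` plus the new `clean_up_tokenization` do (illustrative only; the real code walks UTF-32 token strings from the BPE decoder):

```python
# "</w>" marks end-of-word in CLIP's BPE vocabulary; the fix appends the space
# after a finished word instead of prepending one to every long token.
def decode(token_strings):
    text = ""
    for s in token_strings:
        if len(s) >= 4:
            if s.endswith("</w>"):
                text += s[:-4] + " "
            else:
                text += s
        else:
            text += " " + s
    return text.strip().replace(" ,", ",")  # the clean_up_tokenization step

print(decode(["hel", "lo</w>", ",</w>", "world</w>"]))  # -> "hello, world"
```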
common.hpp (23 changes)

@@ -245,16 +245,19 @@ protected:
     int64_t context_dim;
     int64_t n_head;
     int64_t d_head;
+    bool flash_attn;
 
 public:
     CrossAttention(int64_t query_dim,
                    int64_t context_dim,
                    int64_t n_head,
-                   int64_t d_head)
+                   int64_t d_head,
+                   bool flash_attn = false)
         : n_head(n_head),
           d_head(d_head),
           query_dim(query_dim),
-          context_dim(context_dim) {
+          context_dim(context_dim),
+          flash_attn(flash_attn) {
         int64_t inner_dim = d_head * n_head;
 
         blocks["to_q"] = std::shared_ptr<GGMLBlock>(new Linear(query_dim, inner_dim, false));
@@ -283,7 +286,7 @@ public:
         auto k = to_k->forward(ctx, context);  // [N, n_context, inner_dim]
         auto v = to_v->forward(ctx, context);  // [N, n_context, inner_dim]
 
-        x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false);  // [N, n_token, inner_dim]
+        x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false, false, flash_attn);  // [N, n_token, inner_dim]
 
         x = to_out_0->forward(ctx, x);  // [N, n_token, query_dim]
         return x;
@@ -301,15 +304,16 @@ public:
                           int64_t n_head,
                           int64_t d_head,
                           int64_t context_dim,
-                          bool ff_in = false)
+                          bool ff_in = false,
+                          bool flash_attn = false)
         : n_head(n_head), d_head(d_head), ff_in(ff_in) {
         // disable_self_attn is always False
         // disable_temporal_crossattention is always False
         // switch_temporal_ca_to_sa is always False
         // inner_dim is always None or equal to dim
         // gated_ff is always True
-        blocks["attn1"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, dim, n_head, d_head));
-        blocks["attn2"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, context_dim, n_head, d_head));
+        blocks["attn1"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, dim, n_head, d_head, flash_attn));
+        blocks["attn2"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, context_dim, n_head, d_head, flash_attn));
         blocks["ff"]    = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim));
         blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
         blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
@@ -374,7 +378,8 @@ public:
                           int64_t n_head,
                           int64_t d_head,
                           int64_t depth,
-                          int64_t context_dim)
+                          int64_t context_dim,
+                          bool flash_attn = false)
         : in_channels(in_channels),
           n_head(n_head),
           d_head(d_head),
@@ -388,7 +393,7 @@ public:
 
         for (int i = 0; i < depth; i++) {
             std::string name = "transformer_blocks." + std::to_string(i);
-            blocks[name]     = std::shared_ptr<GGMLBlock>(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim));
+            blocks[name]     = std::shared_ptr<GGMLBlock>(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim, false, flash_attn));
         }
 
         blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1}));
@@ -511,4 +516,4 @@ public:
     }
 };
 
-#endif  // __COMMON_HPP__
+#endif  // __COMMON_HPP__
conditioner.hpp

@@ -43,7 +43,8 @@ struct Conditioner {
 // ldm.modules.encoders.modules.FrozenCLIPEmbedder
 // Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283
 struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
-    SDVersion version = VERSION_SD1;
+    SDVersion version    = VERSION_SD1;
+    PMVersion pm_version = PM_VERSION_1;
     CLIPTokenizer tokenizer;
     ggml_type wtype;
     std::shared_ptr<CLIPTextModelRunner> text_model;
@@ -59,8 +60,9 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
                                      ggml_type wtype,
                                      const std::string& embd_dir,
                                      SDVersion version = VERSION_SD1,
+                                     PMVersion pv = PM_VERSION_1,
                                      int clip_skip = -1)
-        : version(version), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) {
+        : version(version), pm_version(pv), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) {
         if (clip_skip <= 0) {
             clip_skip = 1;
             if (version == VERSION_SD2 || version == VERSION_SDXL) {
@@ -268,7 +270,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
                     std::vector<int> clean_input_ids_tmp;
                     for (uint32_t i = 0; i < class_token_index[0]; i++)
                         clean_input_ids_tmp.push_back(clean_input_ids[i]);
-                    for (uint32_t i = 0; i < num_input_imgs; i++)
+                    for (uint32_t i = 0; i < (pm_version == PM_VERSION_2 ? 2 * num_input_imgs : num_input_imgs); i++)
                         clean_input_ids_tmp.push_back(class_token);
                     for (uint32_t i = class_token_index[0] + 1; i < clean_input_ids.size(); i++)
                         clean_input_ids_tmp.push_back(clean_input_ids[i]);
@@ -279,13 +281,16 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
                 tokens.insert(tokens.end(), clean_input_ids.begin(), clean_input_ids.end());
                 weights.insert(weights.end(), clean_input_ids.size(), curr_weight);
             }
-            tokens.insert(tokens.begin(), tokenizer.BOS_TOKEN_ID);
-            weights.insert(weights.begin(), 1.0);
+            // BUG!! double counting, pad_tokens will add BOS at the beginning
+            // tokens.insert(tokens.begin(), tokenizer.BOS_TOKEN_ID);
+            // weights.insert(weights.begin(), 1.0);
 
             tokenizer.pad_tokens(tokens, weights, max_length, padding);
 
+            int offset = pm_version == PM_VERSION_2 ? 2 * num_input_imgs : num_input_imgs;
             for (uint32_t i = 0; i < tokens.size(); i++) {
-                if (class_idx + 1 <= i && i < class_idx + 1 + num_input_imgs)
+                // if (class_idx + 1 <= i && i < class_idx + 1 + 2*num_input_imgs) // photomaker V2 has num_tokens(=2)*num_input_imgs
+                if (class_idx + 1 <= i && i < class_idx + 1 + offset)  // photomaker V2 has num_tokens(=2)*num_input_imgs
                     // hardcode for now
                     class_token_mask.push_back(true);
                 else
                     class_token_mask.push_back(false);
@@ -798,21 +803,16 @@ struct SD3CLIPEmbedder : public Conditioner {
                 }
 
                 if (chunk_idx == 0) {
-                    // auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
-                    // max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
-                    // clip_l->compute(n_threads,
-                    //                 input_ids,
-                    //                 0,
-                    //                 NULL,
-                    //                 max_token_idx,
-                    //                 true,
-                    //                 &pooled_l,
-                    //                 work_ctx);
-
-                    // clip_l.transformer.text_model.text_projection no in file, ignore
-                    // TODO: use torch.eye(embed_dim) as default clip_l.transformer.text_model.text_projection
-                    pooled_l = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768);
-                    ggml_set_f32(pooled_l, 0.f);
+                    auto it       = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
+                    max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
+                    clip_l->compute(n_threads,
+                                    input_ids,
+                                    0,
+                                    NULL,
+                                    max_token_idx,
+                                    true,
+                                    &pooled_l,
+                                    work_ctx);
                 }
             }
 
@@ -852,21 +852,16 @@ struct SD3CLIPEmbedder : public Conditioner {
                 }
 
                 if (chunk_idx == 0) {
-                    // auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID);
-                    // max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
-                    // clip_g->compute(n_threads,
-                    //                 input_ids,
-                    //                 0,
-                    //                 NULL,
-                    //                 max_token_idx,
-                    //                 true,
-                    //                 &pooled_g,
-                    //                 work_ctx);
-                    // clip_l.transformer.text_model.text_projection no in file, ignore pooled_g too
-
-                    // TODO: fix pooled_g
-                    pooled_g = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1280);
-                    ggml_set_f32(pooled_g, 0.f);
+                    auto it       = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID);
+                    max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
+                    clip_g->compute(n_threads,
+                                    input_ids,
+                                    0,
+                                    NULL,
+                                    max_token_idx,
+                                    true,
+                                    &pooled_g,
+                                    work_ctx);
                 }
             }
 
@@ -1104,21 +1099,17 @@ struct FluxCLIPEmbedder : public Conditioner {
             auto input_ids       = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
             size_t max_token_idx = 0;
 
-            // auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
-            // max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
-            // clip_l->compute(n_threads,
-            //                 input_ids,
-            //                 0,
-            //                 NULL,
-            //                 max_token_idx,
-            //                 true,
-            //                 &pooled,
-            //                 work_ctx);
-
-            // clip_l.transformer.text_model.text_projection no in file, ignore
-            // TODO: use torch.eye(embed_dim) as default clip_l.transformer.text_model.text_projection
-            pooled = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768);
-            ggml_set_f32(pooled, 0.f);
+            auto it       = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
+            max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
+
+            clip_l->compute(n_threads,
+                            input_ids,
+                            0,
+                            NULL,
+                            max_token_idx,
+                            true,
+                            &pooled,
+                            work_ctx);
         }
 
         // t5
diffusion_model.hpp

@@ -17,7 +17,8 @@ struct DiffusionModel {
                          std::vector<struct ggml_tensor*> controls = {},
                          float control_strength = 0.f,
                          struct ggml_tensor** output = NULL,
-                         struct ggml_context* output_ctx = NULL) = 0;
+                         struct ggml_context* output_ctx = NULL,
+                         std::vector<int> skip_layers = std::vector<int>()) = 0;
     virtual void alloc_params_buffer() = 0;
     virtual void free_params_buffer() = 0;
     virtual void free_compute_buffer() = 0;
@@ -31,8 +32,9 @@ struct UNetModel : public DiffusionModel {
 
     UNetModel(ggml_backend_t backend,
               ggml_type wtype,
-              SDVersion version = VERSION_SD1)
-        : unet(backend, wtype, version) {
+              SDVersion version = VERSION_SD1,
+              bool flash_attn = false)
+        : unet(backend, wtype, version, flash_attn) {
     }
 
     void alloc_params_buffer() {
@@ -70,7 +72,9 @@ struct UNetModel : public DiffusionModel {
                  std::vector<struct ggml_tensor*> controls = {},
                  float control_strength = 0.f,
                  struct ggml_tensor** output = NULL,
-                 struct ggml_context* output_ctx = NULL) {
+                 struct ggml_context* output_ctx = NULL,
+                 std::vector<int> skip_layers = std::vector<int>()) {
+        (void)skip_layers;  // SLG doesn't work with UNet models
         return unet.compute(n_threads, x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength, output, output_ctx);
     }
 };
@@ -119,8 +123,9 @@ struct MMDiTModel : public DiffusionModel {
                  std::vector<struct ggml_tensor*> controls = {},
                  float control_strength = 0.f,
                  struct ggml_tensor** output = NULL,
-                 struct ggml_context* output_ctx = NULL) {
-        return mmdit.compute(n_threads, x, timesteps, context, y, output, output_ctx);
+                 struct ggml_context* output_ctx = NULL,
+                 std::vector<int> skip_layers = std::vector<int>()) {
+        return mmdit.compute(n_threads, x, timesteps, context, y, output, output_ctx, skip_layers);
     }
 };
 
@@ -129,8 +134,9 @@ struct FluxModel : public DiffusionModel {
 
     FluxModel(ggml_backend_t backend,
              ggml_type wtype,
-              SDVersion version = VERSION_FLUX_DEV)
-        : flux(backend, wtype, version) {
+              SDVersion version = VERSION_FLUX_DEV,
+              bool flash_attn = false)
+        : flux(backend, wtype, version, flash_attn) {
     }
 
     void alloc_params_buffer() {
@@ -168,9 +174,10 @@ struct FluxModel : public DiffusionModel {
                  std::vector<struct ggml_tensor*> controls = {},
                  float control_strength = 0.f,
                  struct ggml_tensor** output = NULL,
-                 struct ggml_context* output_ctx = NULL) {
-        return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx);
+                 struct ggml_context* output_ctx = NULL,
+                 std::vector<int> skip_layers = std::vector<int>()) {
+        return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx, skip_layers);
     }
 };
 
-#endif
+#endif
docs/photo_maker.md

@@ -29,4 +29,26 @@ Example:
 
 ```bash
 bin/sd -m ../models/sdxlUnstableDiffusers_v11.safetensors --vae ../models/sdxl_vae.safetensors --stacked-id-embd-dir ../models/photomaker-v1.safetensors --input-id-images-dir ../assets/photomaker_examples/scarletthead_woman -p "a girl img, retro futurism, retro game art style but extremely beautiful, intricate details, masterpiece, best quality, space-themed, cosmic, celestial, stars, galaxies, nebulas, planets, science fiction, highly detailed" -n "realistic, photo-realistic, worst quality, greyscale, bad anatomy, bad hands, error, text" --cfg-scale 5.0 --sampling-method euler -H 1024 -W 1024 --style-ratio 10 --vae-on-cpu -o output.png
 ```
+
+## PhotoMaker Version 2
+
+[PhotoMaker Version 2 (PMV2)](https://github.com/TencentARC/PhotoMaker/blob/main/README_pmv2.md) has some key improvements. Unfortunately it has a very heavy dependency, which makes running it a bit involved in `SD.cpp`.
+
+Running PMV2 is now a two-step process:
+
+- Run the python script `face_detect.py` to obtain **id_embeds** for the given input images:
+```
+python face_detect.py input_image_dir
+```
+An `id_embeds.safetensors` file will be generated in `input_images_dir`.
+
+**Note: this step only needs to be run once; the same `id_embeds` can be reused.**
+
+- Run the same command as in version 1, but replace `photomaker-v1.safetensors` with `photomaker-v2.safetensors`.
+
+You can download `photomaker-v2.safetensors` from [here](https://huggingface.co/bssrdf/PhotoMakerV2).
+
+- All the command line parameters from Version 1 remain the same for Version 2.
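Putting the two steps together, a PMV2 run under these assumptions (same flags as the V1 example above, a shortened placeholder prompt, and only the stacked-ID weights swapped) might look like:

```bash
# step 1, once per input folder: writes the id embeddings next to the images
python face_detect.py ../assets/photomaker_examples/scarletthead_woman
# step 2: identical to the V1 command except for photomaker-v2.safetensors
bin/sd -m ../models/sdxlUnstableDiffusers_v11.safetensors --vae ../models/sdxl_vae.safetensors --stacked-id-embd-dir ../models/photomaker-v2.safetensors --input-id-images-dir ../assets/photomaker_examples/scarletthead_woman -p "a girl img, masterpiece, best quality" --cfg-scale 5.0 --sampling-method euler -H 1024 -W 1024 --style-ratio 10 --vae-on-cpu -o output.png
```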
examples/cli/main.cpp

@@ -116,9 +116,15 @@ struct SDParams {
     bool normalize_input = false;
     bool clip_on_cpu = false;
     bool vae_on_cpu = false;
+    bool diffusion_flash_attn = false;
     bool canny_preprocess = false;
     bool color = false;
     int upscale_repeats = 1;
+
+    std::vector<int> skip_layers = {7, 8, 9};
+    float slg_scale = 0.;
+    float skip_layer_start = 0.01;
+    float skip_layer_end = 0.2;
 };
 
 void print_params(SDParams params) {
@@ -146,11 +152,13 @@ void print_params(SDParams params) {
     printf("    clip on cpu: %s\n", params.clip_on_cpu ? "true" : "false");
     printf("    controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false");
    printf("    vae decoder on cpu:%s\n", params.vae_on_cpu ? "true" : "false");
+    printf("    diffusion flash attention:%s\n", params.diffusion_flash_attn ? "true" : "false");
     printf("    strength(control): %.2f\n", params.control_strength);
     printf("    prompt: %s\n", params.prompt.c_str());
     printf("    negative_prompt: %s\n", params.negative_prompt.c_str());
     printf("    min_cfg: %.2f\n", params.min_cfg);
     printf("    cfg_scale: %.2f\n", params.cfg_scale);
+    printf("    slg_scale: %.2f\n", params.slg_scale);
     printf("    guidance: %.2f\n", params.guidance);
     printf("    clip_skip: %d\n", params.clip_skip);
     printf("    width: %d\n", params.width);
@@ -177,7 +185,7 @@ void print_usage(int argc, const char* argv[]) {
     printf("  -m, --model [MODEL]                path to full model\n");
     printf("  --diffusion-model                  path to the standalone diffusion model\n");
     printf("  --clip_l                           path to the clip-l text encoder\n");
-    printf("  --clip_g                           path to the clip-l text encoder\n");
+    printf("  --clip_g                           path to the clip-g text encoder\n");
     printf("  --t5xxl                            path to the t5xxl text encoder\n");
     printf("  --vae [VAE]                        path to vae\n");
     printf("  --taesd [TAESD_PATH]               path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
@@ -197,6 +205,12 @@ void print_usage(int argc, const char* argv[]) {
     printf("  -p, --prompt [PROMPT]              the prompt to render\n");
     printf("  -n, --negative-prompt PROMPT       the negative prompt (default: \"\")\n");
     printf("  --cfg-scale SCALE                  unconditional guidance scale: (default: 7.0)\n");
+    printf("  --slg-scale SCALE                  skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n");
+    printf("                                     0 means disabled, a value of 2.5 is nice for sd3.5 medium\n");
+    printf("  --skip-layers LAYERS               Layers to skip for SLG steps: (default: [7,8,9])\n");
+    printf("  --skip-layer-start START           SLG enabling point: (default: 0.01)\n");
+    printf("  --skip-layer-end END               SLG disabling point: (default: 0.2)\n");
+    printf("                                     SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END])\n");
     printf("  --strength STRENGTH                strength for noising/unnoising (default: 0.75)\n");
     printf("  --style-ratio STYLE-RATIO          strength for keeping input identity (default: 20%%)\n");
     printf("  --control-strength STRENGTH        strength to apply Control Net (default: 0.9)\n");
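As a usage sketch of these SLG options (the model path and prompt are placeholders; the 2.5 value is the one suggested in the help text above):

```bash
# SLG only applies to DiT models, e.g. sd3.5 medium; quote the list so the shell doesn't glob it
./bin/sd -m ../models/sd3.5_medium.safetensors -p "a lovely cat" --slg-scale 2.5 --skip-layers "[7,8,9]" --skip-layer-start 0.01 --skip-layer-end 0.2
```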
@@ -215,6 +229,9 @@ void print_usage(int argc, const char* argv[]) {
     printf("  --vae-tiling                       process vae in tiles to reduce memory usage\n");
     printf("  --vae-on-cpu                       keep vae in cpu (for low vram)\n");
     printf("  --clip-on-cpu                      keep clip in cpu (for low vram)\n");
+    printf("  --diffusion-fa                     use flash attention in the diffusion model (for low vram)\n");
+    printf("                                     Might lower quality, since it implies converting k and v to f16.\n");
+    printf("                                     This might crash if it is not supported by the backend.\n");
     printf("  --control-net-cpu                  keep controlnet in cpu (for low vram)\n");
     printf("  --canny                            apply canny preprocessor (edge detection)\n");
     printf("  --color                            Colors the logging tags according to level\n");
@@ -465,6 +482,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
             params.clip_on_cpu = true;  // will slow down get_learned_condition but necessary for low MEM GPUs
         } else if (arg == "--vae-on-cpu") {
             params.vae_on_cpu = true;  // will slow down latent decoding but necessary for low MEM GPUs
+        } else if (arg == "--diffusion-fa") {
+            params.diffusion_flash_attn = true;  // can reduce MEM significantly
         } else if (arg == "--canny") {
             params.canny_preprocess = true;
         } else if (arg == "-b" || arg == "--batch-count") {
@@ -534,6 +553,61 @@ void parse_args(int argc, const char** argv, SDParams& params) {
             params.verbose = true;
         } else if (arg == "--color") {
             params.color = true;
+        } else if (arg == "--slg-scale") {
+            if (++i >= argc) {
+                invalid_arg = true;
+                break;
+            }
+            params.slg_scale = std::stof(argv[i]);
+        } else if (arg == "--skip-layers") {
+            if (++i >= argc) {
+                invalid_arg = true;
+                break;
+            }
+            if (argv[i][0] != '[') {
+                invalid_arg = true;
+                break;
+            }
+            std::string layers_str = argv[i];
+            while (layers_str.back() != ']') {
+                if (++i >= argc) {
+                    invalid_arg = true;
+                    break;
+                }
+                layers_str += " " + std::string(argv[i]);
+            }
+            layers_str = layers_str.substr(1, layers_str.size() - 2);
+
+            std::regex regex("[, ]+");
+            std::sregex_token_iterator iter(layers_str.begin(), layers_str.end(), regex, -1);
+            std::sregex_token_iterator end;
+            std::vector<std::string> tokens(iter, end);
+            std::vector<int> layers;
+            for (const auto& token : tokens) {
+                try {
+                    layers.push_back(std::stoi(token));
+                } catch (const std::invalid_argument& e) {
+                    invalid_arg = true;
+                    break;
+                }
+            }
+            params.skip_layers = layers;
+
+            if (invalid_arg) {
+                break;
+            }
+        } else if (arg == "--skip-layer-start") {
+            if (++i >= argc) {
+                invalid_arg = true;
+                break;
+            }
+            params.skip_layer_start = std::stof(argv[i]);
+        } else if (arg == "--skip-layer-end") {
+            if (++i >= argc) {
+                invalid_arg = true;
+                break;
+            }
+            params.skip_layer_end = std::stof(argv[i]);
         } else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             print_usage(argc, argv);
@@ -624,6 +698,16 @@ std::string get_image_params(SDParams params, int64_t seed) {
     }
     parameter_string += "Steps: " + std::to_string(params.sample_steps) + ", ";
     parameter_string += "CFG scale: " + std::to_string(params.cfg_scale) + ", ";
+    if (params.slg_scale != 0 && params.skip_layers.size() != 0) {
+        parameter_string += "SLG scale: " + std::to_string(params.slg_scale) + ", ";
+        parameter_string += "Skip layers: [";
+        for (const auto& layer : params.skip_layers) {
+            parameter_string += std::to_string(layer) + ", ";
+        }
+        parameter_string += "], ";
+        parameter_string += "Skip layer start: " + std::to_string(params.skip_layer_start) + ", ";
+        parameter_string += "Skip layer end: " + std::to_string(params.skip_layer_end) + ", ";
+    }
     parameter_string += "Guidance: " + std::to_string(params.guidance) + ", ";
     parameter_string += "Seed: " + std::to_string(seed) + ", ";
     parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", ";
@@ -791,7 +875,8 @@ int main(int argc, const char* argv[]) {
                                       params.schedule,
                                       params.clip_on_cpu,
                                       params.control_net_cpu,
-                                      params.vae_on_cpu);
+                                      params.vae_on_cpu,
+                                      params.diffusion_flash_attn);
 
     if (sd_ctx == NULL) {
         printf("new_sd_ctx_t failed\n");
@@ -840,7 +925,11 @@ int main(int argc, const char* argv[]) {
                           params.control_strength,
                           params.style_ratio,
                           params.normalize_input,
-                          params.input_id_images_path.c_str());
+                          params.input_id_images_path.c_str(),
+                          params.skip_layers,
+                          params.slg_scale,
+                          params.skip_layer_start,
+                          params.skip_layer_end);
     } else {
         sd_image_t input_image = {(uint32_t)params.width,
                                   (uint32_t)params.height,
face_detect.py (new file, 88 lines)

@@ -0,0 +1,88 @@
import os
import sys

import numpy as np
import torch
from diffusers.utils import load_image
# pip install insightface==0.7.3
from insightface.app import FaceAnalysis
from insightface.data import get_image as ins_get_image
from safetensors.torch import save_file

###
# https://github.com/cubiq/ComfyUI_IPAdapter_plus/issues/165#issue-2055829543
###
class FaceAnalysis2(FaceAnalysis):
    # NOTE: allows setting det_size for each detection call.
    # the model allows it but the wrapping code from insightface
    # doesn't show it, and people end up loading duplicate models
    # for different sizes where there is absolutely no need to
    def get(self, img, max_num=0, det_size=(640, 640)):
        if det_size is not None:
            self.det_model.input_size = det_size

        return super().get(img, max_num)

def analyze_faces(face_analysis: FaceAnalysis, img_data: np.ndarray, det_size=(640, 640)):
    # NOTE: try detect faces, if no faces detected, lower det_size until it does
    detection_sizes = [None] + [(size, size) for size in range(640, 256, -64)] + [(256, 256)]

    for size in detection_sizes:
        faces = face_analysis.get(img_data, det_size=size)
        if len(faces) > 0:
            return faces

    return []

if __name__ == "__main__":
    #face_detector = FaceAnalysis2(providers=['CUDAExecutionProvider'], allowed_modules=['detection', 'recognition'])
    face_detector = FaceAnalysis2(providers=['CPUExecutionProvider'], allowed_modules=['detection', 'recognition'])
    face_detector.prepare(ctx_id=0, det_size=(640, 640))

    #input_folder_name = './scarletthead_woman'
    input_folder_name = sys.argv[1]
    image_basename_list = os.listdir(input_folder_name)
    image_path_list = sorted([os.path.join(input_folder_name, basename) for basename in image_basename_list])

    input_id_images = []
    for image_path in image_path_list:
        input_id_images.append(load_image(image_path))

    id_embed_list = []

    for img in input_id_images:
        img = np.array(img)
        img = img[:, :, ::-1]
        faces = analyze_faces(face_detector, img)
        if len(faces) > 0:
            id_embed_list.append(torch.from_numpy((faces[0]['embedding'])))

    if len(id_embed_list) == 0:
        raise ValueError(f"No face detected in input image pool")

    id_embeds = torch.stack(id_embed_list)

    # for r in id_embeds:
    #     print(r)
    # #torch.save(id_embeds, input_folder_name+'/id_embeds.pt');
    # weights = dict()
    # weights["id_embeds"] = id_embeds
    # save_file(weights, input_folder_name+'/id_embeds.safetensors')

    binary_data = id_embeds.numpy().tobytes()
    two = 4
    zero = 0
    one = 1
    tensor_name = "id_embeds"
    # Write binary data to a file
    with open(input_folder_name+'/id_embeds.bin', "wb") as f:
        f.write(two.to_bytes(4, byteorder='little'))
        f.write((len(tensor_name)).to_bytes(4, byteorder='little'))
        f.write(zero.to_bytes(4, byteorder='little'))
        f.write((id_embeds.shape[1]).to_bytes(4, byteorder='little'))
        f.write((id_embeds.shape[0]).to_bytes(4, byteorder='little'))
        f.write(one.to_bytes(4, byteorder='little'))
        f.write(one.to_bytes(4, byteorder='little'))
        f.write(tensor_name.encode('ascii'))
        f.write(binary_data)
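The `.bin` layout above is ad hoc, so a minimal reader sketch may help; the field meanings are inferred purely from the write order (seven little-endian 4-byte integers, then the ASCII tensor name, then raw float32 data):

```python
import struct
import sys

import numpy as np

def read_id_embeds(path):
    with open(path, "rb") as f:
        # mirrors the writer above: 4, len(name), 0, shape[1], shape[0], 1, 1
        magic, name_len, dtype, ne1, ne0, _, _ = struct.unpack("<7i", f.read(28))
        name = f.read(name_len).decode("ascii")
        data = np.frombuffer(f.read(), dtype=np.float32).reshape(ne0, ne1)
    return name, data

if __name__ == "__main__":
    name, embeds = read_id_embeds(sys.argv[1])
    print(name, embeds.shape)  # e.g. ('id_embeds', (num_images, 512))
```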
flux.hpp (78 changes)

@@ -115,25 +115,28 @@ namespace Flux {
                                               struct ggml_tensor* q,
                                               struct ggml_tensor* k,
                                               struct ggml_tensor* v,
-                                              struct ggml_tensor* pe) {
+                                              struct ggml_tensor* pe,
+                                              bool flash_attn) {
         // q,k,v: [N, L, n_head, d_head]
         // pe: [L, d_head/2, 2, 2]
         // return: [N, L, n_head*d_head]
         q = apply_rope(ctx, q, pe);  // [N*n_head, L, d_head]
         k = apply_rope(ctx, k, pe);  // [N*n_head, L, d_head]
 
-        auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], NULL, false, true);  // [N, L, n_head*d_head]
+        auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], NULL, false, true, flash_attn);  // [N, L, n_head*d_head]
         return x;
     }
 
     struct SelfAttention : public GGMLBlock {
     public:
         int64_t num_heads;
+        bool flash_attn;
 
     public:
         SelfAttention(int64_t dim,
                       int64_t num_heads = 8,
-                      bool qkv_bias = false)
+                      bool qkv_bias = false,
+                      bool flash_attn = false)
             : num_heads(num_heads) {
             int64_t head_dim = dim / num_heads;
             blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
@@ -167,9 +170,9 @@ namespace Flux {
             // x: [N, n_token, dim]
             // pe: [n_token, d_head/2, 2, 2]
             // return [N, n_token, dim]
-            auto qkv = pre_attention(ctx, x);                       // q,k,v: [N, n_token, n_head, d_head]
-            x        = attention(ctx, qkv[0], qkv[1], qkv[2], pe);  // [N, n_token, dim]
-            x        = post_attention(ctx, x);                      // [N, n_token, dim]
+            auto qkv = pre_attention(ctx, x);                                   // q,k,v: [N, n_token, n_head, d_head]
+            x        = attention(ctx, qkv[0], qkv[1], qkv[2], pe, flash_attn);  // [N, n_token, dim]
+            x        = post_attention(ctx, x);                                  // [N, n_token, dim]
             return x;
         }
     };
@@ -237,15 +240,19 @@ namespace Flux {
     }
 
     struct DoubleStreamBlock : public GGMLBlock {
+        bool flash_attn;
+
     public:
         DoubleStreamBlock(int64_t hidden_size,
                           int64_t num_heads,
                           float mlp_ratio,
-                          bool qkv_bias = false) {
+                          bool qkv_bias = false,
+                          bool flash_attn = false)
+            : flash_attn(flash_attn) {
             int64_t mlp_hidden_dim = hidden_size * mlp_ratio;
             blocks["img_mod"]   = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
             blocks["img_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
-            blocks["img_attn"]  = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias));
+            blocks["img_attn"]  = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn));
 
             blocks["img_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
             blocks["img_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim));
@@ -254,7 +261,7 @@ namespace Flux {
 
             blocks["txt_mod"]   = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
             blocks["txt_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
-            blocks["txt_attn"]  = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias));
+            blocks["txt_attn"]  = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn));
 
             blocks["txt_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
             blocks["txt_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim));
@@ -316,7 +323,7 @@ namespace Flux {
             auto k = ggml_concat(ctx, txt_k, img_k, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
             auto v = ggml_concat(ctx, txt_v, img_v, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
 
-            auto attn = attention(ctx, q, k, v, pe);  // [N, n_txt_token + n_img_token, n_head*d_head]
+            auto attn = attention(ctx, q, k, v, pe, flash_attn);  // [N, n_txt_token + n_img_token, n_head*d_head]
             attn      = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));  // [n_txt_token + n_img_token, N, hidden_size]
             auto txt_attn_out = ggml_view_3d(ctx,
                                              attn,
@@ -364,13 +371,15 @@ namespace Flux {
         int64_t num_heads;
         int64_t hidden_size;
         int64_t mlp_hidden_dim;
+        bool flash_attn;
 
     public:
         SingleStreamBlock(int64_t hidden_size,
                           int64_t num_heads,
                           float mlp_ratio = 4.0f,
-                          float qk_scale = 0.f)
-            : hidden_size(hidden_size), num_heads(num_heads) {
+                          float qk_scale = 0.f,
+                          bool flash_attn = false)
+            : hidden_size(hidden_size), num_heads(num_heads), flash_attn(flash_attn) {
             int64_t head_dim = hidden_size / num_heads;
             float scale      = qk_scale;
             if (scale <= 0.f) {
@@ -433,7 +442,7 @@ namespace Flux {
             auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]);  // [N, n_token, n_head, d_head]
             q      = norm->query_norm(ctx, q);
             k      = norm->key_norm(ctx, k);
-            auto attn = attention(ctx, q, k, v, pe);  // [N, n_token, hidden_size]
+            auto attn = attention(ctx, q, k, v, pe, flash_attn);  // [N, n_token, hidden_size]
 
             auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0);  // [N, n_token, hidden_size + mlp_hidden_dim]
             auto output   = linear2->forward(ctx, attn_mlp);                         // [N, n_token, hidden_size]
@@ -492,6 +501,7 @@ namespace Flux {
         int theta           = 10000;
         bool qkv_bias       = true;
         bool guidance_embed = true;
+        bool flash_attn     = true;
     };
 
     struct Flux : public GGMLBlock {
@@ -646,13 +656,16 @@ namespace Flux {
                 blocks["double_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new DoubleStreamBlock(params.hidden_size,
                                                                                                                 params.num_heads,
                                                                                                                 params.mlp_ratio,
-                                                                                                                params.qkv_bias));
+                                                                                                                params.qkv_bias,
+                                                                                                                params.flash_attn));
             }
 
             for (int i = 0; i < params.depth_single_blocks; i++) {
                 blocks["single_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new SingleStreamBlock(params.hidden_size,
                                                                                                                 params.num_heads,
-                                                                                                                params.mlp_ratio));
+                                                                                                                params.mlp_ratio,
+                                                                                                                0.f,
+                                                                                                                params.flash_attn));
             }
 
             blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, out_channels));
@@ -711,7 +724,8 @@ namespace Flux {
                                         struct ggml_tensor* timesteps,
                                         struct ggml_tensor* y,
                                         struct ggml_tensor* guidance,
-                                        struct ggml_tensor* pe) {
+                                        struct ggml_tensor* pe,
+                                        std::vector<int> skip_layers = std::vector<int>()) {
             auto img_in    = std::dynamic_pointer_cast<Linear>(blocks["img_in"]);
             auto time_in   = std::dynamic_pointer_cast<MLPEmbedder>(blocks["time_in"]);
             auto vector_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["vector_in"]);
@@ -733,6 +747,10 @@ namespace Flux {
             txt = txt_in->forward(ctx, txt);
 
             for (int i = 0; i < params.depth; i++) {
+                if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i) != skip_layers.end()) {
+                    continue;
+                }
+
                 auto block = std::dynamic_pointer_cast<DoubleStreamBlock>(blocks["double_blocks." + std::to_string(i)]);
 
                 auto img_txt = block->forward(ctx, img, txt, vec, pe);
@@ -742,6 +760,9 @@ namespace Flux {
 
             auto txt_img = ggml_concat(ctx, txt, img, 1);  // [N, n_txt_token + n_img_token, hidden_size]
             for (int i = 0; i < params.depth_single_blocks; i++) {
+                if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i + params.depth) != skip_layers.end()) {
+                    continue;
+                }
                 auto block = std::dynamic_pointer_cast<SingleStreamBlock>(blocks["single_blocks." + std::to_string(i)]);
 
                 txt_img = block->forward(ctx, txt_img, vec, pe);
@@ -769,7 +790,8 @@ namespace Flux {
                                     struct ggml_tensor* context,
                                     struct ggml_tensor* y,
                                     struct ggml_tensor* guidance,
-                                    struct ggml_tensor* pe) {
+                                    struct ggml_tensor* pe,
+                                    std::vector<int> skip_layers = std::vector<int>()) {
             // Forward pass of DiT.
             // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
             // timestep: (N,) tensor of diffusion timesteps
@@ -791,7 +813,7 @@ namespace Flux {
             // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
             auto img = patchify(ctx, x, patch_size);  // [N, h*w, C * patch_size * patch_size]
 
-            auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe);  // [N, h*w, C * patch_size * patch_size]
+            auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers);  // [N, h*w, C * patch_size * patch_size]
 
             // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
             out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size);  // [N, C, H + pad_h, W + pad_w]
@@ -808,11 +830,16 @@ namespace Flux {
 
         FluxRunner(ggml_backend_t backend,
                    ggml_type wtype,
-                   SDVersion version = VERSION_FLUX_DEV)
+                   SDVersion version = VERSION_FLUX_DEV,
+                   bool flash_attn = false)
             : GGMLRunner(backend, wtype) {
+            flux_params.flash_attn = flash_attn;
             if (version == VERSION_FLUX_SCHNELL) {
                 flux_params.guidance_embed = false;
             }
+            if (version == VERSION_FLUX_LITE) {
+                flux_params.depth = 8;
+            }
             flux = Flux(flux_params);
             flux.init(params_ctx, wtype);
         }
@@ -829,7 +856,8 @@ namespace Flux {
                                         struct ggml_tensor* timesteps,
                                         struct ggml_tensor* context,
                                         struct ggml_tensor* y,
-                                        struct ggml_tensor* guidance) {
+                                        struct ggml_tensor* guidance,
+                                        std::vector<int> skip_layers = std::vector<int>()) {
             GGML_ASSERT(x->ne[3] == 1);
             struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);
 
@@ -856,7 +884,8 @@ namespace Flux {
                                          context,
                                          y,
                                          guidance,
-                                         pe);
+                                         pe,
+                                         skip_layers);
 
             ggml_build_forward_expand(gf, out);
 
@@ -870,14 +899,15 @@ namespace Flux {
                      struct ggml_tensor* y,
                      struct ggml_tensor* guidance,
                      struct ggml_tensor** output = NULL,
-                     struct ggml_context* output_ctx = NULL) {
+                     struct ggml_context* output_ctx = NULL,
+                     std::vector<int> skip_layers = std::vector<int>()) {
             // x: [N, in_channels, h, w]
             // timesteps: [N, ]
             // context: [N, max_position, hidden_size]
             // y: [N, adm_in_channels] or [1, adm_in_channels]
             // guidance: [N, ]
             auto get_graph = [&]() -> struct ggml_cgraph* {
-                return build_graph(x, timesteps, context, y, guidance);
+                return build_graph(x, timesteps, context, y, guidance, skip_layers);
             };
 
             GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@@ -958,4 +988,4 @@ namespace Flux {
 
 }  // namespace Flux
 
-#endif  // __FLUX_HPP__
+#endif  // __FLUX_HPP__
ggml_extend.hpp

@@ -666,32 +666,6 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> split_qkv(struct ggml_context* ctx,
     return {q, k, v};
 }
 
-// q: [N * n_head, n_token, d_head]
-// k: [N * n_head, n_k, d_head]
-// v: [N * n_head, d_head, n_k]
-// return: [N * n_head, n_token, d_head]
-__STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx,
-                                                        struct ggml_tensor* q,
-                                                        struct ggml_tensor* k,
-                                                        struct ggml_tensor* v,
-                                                        bool mask = false) {
-#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL) && !defined(SD_USE_VULKAN) && !defined(SD_USE_SYCL)
-    struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false);  // [N * n_head, n_token, d_head]
-#else
-    float d_head = (float)q->ne[0];
-
-    struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q);  // [N * n_head, n_token, n_k]
-    kq                     = ggml_scale_inplace(ctx, kq, 1.0f / sqrt(d_head));
-    if (mask) {
-        kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);
-    }
-    kq = ggml_soft_max_inplace(ctx, kq);
-
-    struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq);  // [N * n_head, n_token, d_head]
-#endif
-    return kqv;
-}
-
 // q: [N, L_q, C] or [N*n_head, L_q, d_head]
 // k: [N, L_k, C] or [N*n_head, L_k, d_head]
 // v: [N, L_k, C] or [N, L_k, n_head, d_head]
@@ -703,7 +677,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* ctx,
                                                             int64_t n_head,
                                                             struct ggml_tensor* mask = NULL,
                                                             bool diag_mask_inf = false,
-                                                            bool skip_reshape = false) {
+                                                            bool skip_reshape = false,
+                                                            bool flash_attn = false) {
     int64_t L_q;
     int64_t L_k;
     int64_t C;
@@ -734,13 +709,42 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* ctx,
 
     float scale = (1.0f / sqrt((float)d_head));
 
-    bool use_flash_attn = false;
-    ggml_tensor* kqv    = NULL;
-    if (use_flash_attn) {
+    // if (flash_attn) {
+    //     LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
+    // }
+    // is there anything oddly shaped?? ping Green-Sky if you can trip this assert
+    GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));
+
+    bool can_use_flash_attn = true;
+    can_use_flash_attn      = can_use_flash_attn && L_k % 256 == 0;
+    can_use_flash_attn      = can_use_flash_attn && d_head % 64 == 0;  // double check
+
+    // cuda max d_head seems to be 256, cpu does seem to work with 512
+    can_use_flash_attn = can_use_flash_attn && d_head <= 256;  // double check
+
+    if (mask != nullptr) {
+        // TODO(Green-Sky): figure out if we can bend t5 to work too
+        can_use_flash_attn = can_use_flash_attn && mask->ne[2] == 1;
+        can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1;
+    }
+
+    // TODO(Green-Sky): more pad or disable for funny tensor shapes
+
+    ggml_tensor* kqv = nullptr;
+    // GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn);
+    if (can_use_flash_attn && flash_attn) {
+        // LOG_DEBUG("using flash attention");
+        k = ggml_cast(ctx, k, GGML_TYPE_F16);
+
         v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));  // [N, n_head, L_k, d_head]
         v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N);  // [N * n_head, L_k, d_head]
-        LOG_DEBUG("k->ne[1] == %d", k->ne[1]);
+        v = ggml_cast(ctx, v, GGML_TYPE_F16);
 
         kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0);
         ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
 
-        // kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_k, kqv->nb[1], kqv->nb[2], 0);
+        kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_q, kqv->nb[1], kqv->nb[2], 0);
     } else {
         v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3));  // [N, n_head, d_head, L_k]
         v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N);  // [N * n_head, d_head, L_k]
@@ -756,10 +760,12 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* ctx,
         kq = ggml_soft_max_inplace(ctx, kq);
 
         kqv = ggml_mul_mat(ctx, v, kq);  // [N * n_head, L_q, d_head]
+
+        kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N);  // [N, n_head, L_q, d_head]
+        kqv = ggml_permute(ctx, kqv, 0, 2, 1, 3);                 // [N, L_q, n_head, d_head]
     }
 
-    kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N);   // [N, n_head, L_q, d_head]
-    kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3));  // [N, L_q, n_head, d_head]
+    kqv = ggml_cont(ctx, kqv);
     kqv = ggml_reshape_3d(ctx, kqv, d_head * n_head, L_q, N);  // [N, L_q, C]
 
     return kqv;
@@ -1047,6 +1053,11 @@ public:
                  params_buffer_size / (1024.0 * 1024.0),
                  ggml_backend_is_cpu(backend) ? "RAM" : "VRAM",
                  num_tensors);
+        // printf("%s params backend buffer size = % 6.2f MB(%s) (%i tensors)\n",
+        //        get_desc().c_str(),
+        //        params_buffer_size / (1024.0 * 1024.0),
+        //        ggml_backend_is_cpu(backend) ? "RAM" : "VRAM",
+        //        num_tensors);
         return true;
     }
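In plain terms, the eligibility check added here boils down to the following predicate (a sketch mirroring `can_use_flash_attn`; the 256/64 limits are the ones in this diff, not general ggml rules):

```python
def can_use_flash_attn(L_k, d_head, mask_shape=None):
    ok = L_k % 256 == 0            # k length must be padded to a multiple of 256
    ok = ok and d_head % 64 == 0   # head dim must be a multiple of 64
    ok = ok and d_head <= 256      # "cuda max d_head seems to be 256"
    if mask_shape is not None:     # batched/broadcast masks (e.g. t5's) not supported yet
        ok = ok and mask_shape[2] == 1 and mask_shape[3] == 1
    return ok
```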
248
mmdit.hpp
248
mmdit.hpp
@ -252,6 +252,7 @@ struct DismantledBlock : public GGMLBlock {
|
||||
public:
|
||||
int64_t num_heads;
|
||||
bool pre_only;
|
||||
bool self_attn;
|
||||
|
||||
public:
|
||||
DismantledBlock(int64_t hidden_size,
|
||||
@ -259,14 +260,19 @@ public:
|
||||
float mlp_ratio = 4.0,
|
||||
std::string qk_norm = "",
|
||||
bool qkv_bias = false,
|
||||
bool pre_only = false)
|
||||
: num_heads(num_heads), pre_only(pre_only) {
|
||||
bool pre_only = false,
|
||||
bool self_attn = false)
|
||||
: num_heads(num_heads), pre_only(pre_only), self_attn(self_attn) {
|
||||
// rmsnorm is always Flase
|
||||
// scale_mod_only is always Flase
|
||||
// swiglu is always Flase
|
||||
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
|
||||
blocks["attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only));
|
||||
|
||||
if (self_attn) {
|
||||
blocks["attn2"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false));
|
||||
}
|
||||
|
||||
if (!pre_only) {
|
||||
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
|
||||
int64_t mlp_hidden_dim = (int64_t)(hidden_size * mlp_ratio);
|
||||
@ -277,9 +283,52 @@ public:
|
||||
if (pre_only) {
|
||||
n_mods = 2;
|
||||
}
|
||||
if (self_attn) {
|
||||
n_mods = 9;
|
||||
}
|
||||
blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, n_mods * hidden_size));
|
||||
}
|
||||
|
||||
std::tuple<std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>> pre_attention_x(struct ggml_context* ctx,
|
||||
struct ggml_tensor* x,
|
||||
struct ggml_tensor* c) {
|
||||
GGML_ASSERT(self_attn);
|
||||
// x: [N, n_token, hidden_size]
|
||||
// c: [N, hidden_size]
|
||||
auto norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
|
||||
auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
|
||||
auto attn2 = std::dynamic_pointer_cast<SelfAttention>(blocks["attn2"]);
|
||||
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
|
||||
|
||||
int64_t n_mods = 9;
|
||||
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, n_mods * hidden_size]
|
||||
m = ggml_reshape_3d(ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size]
|
||||
m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size]
|
||||
|
||||
int64_t offset = m->nb[1] * m->ne[1];
|
||||
auto shift_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
|
||||
auto scale_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
|
||||
auto gate_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size]
|
||||
|
||||
auto shift_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size]
|
||||
auto scale_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size]
|
||||
auto gate_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size]
|
||||
|
||||
auto shift_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 6); // [N, hidden_size]
|
||||
auto scale_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 7); // [N, hidden_size]
|
||||
auto gate_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 8); // [N, hidden_size]
|
||||
|
||||
auto x_norm = norm1->forward(ctx, x);
|
||||
|
||||
auto attn_in = modulate(ctx, x_norm, shift_msa, scale_msa);
|
||||
auto qkv = attn->pre_attention(ctx, attn_in);
|
||||
|
||||
auto attn2_in = modulate(ctx, x_norm, shift_msa2, scale_msa2);
|
||||
auto qkv2 = attn2->pre_attention(ctx, attn2_in);
|
||||
|
||||
return {qkv, qkv2, {x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2}};
|
||||
}
|
||||
|
||||
std::pair<std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>> pre_attention(struct ggml_context* ctx,
|
||||
struct ggml_tensor* x,
|
||||
struct ggml_tensor* c) {
|
||||
@ -319,6 +368,44 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_tensor* post_attention_x(struct ggml_context* ctx,
|
||||
struct ggml_tensor* attn_out,
|
||||
struct ggml_tensor* attn2_out,
|
||||
struct ggml_tensor* x,
|
||||
struct ggml_tensor* gate_msa,
|
||||
struct ggml_tensor* shift_mlp,
|
||||
struct ggml_tensor* scale_mlp,
|
||||
struct ggml_tensor* gate_mlp,
|
||||
struct ggml_tensor* gate_msa2) {
|
||||
// attn_out: [N, n_token, hidden_size]
|
||||
// x: [N, n_token, hidden_size]
|
||||
// gate_msa: [N, hidden_size]
|
||||
// shift_mlp: [N, hidden_size]
|
||||
// scale_mlp: [N, hidden_size]
|
||||
// gate_mlp: [N, hidden_size]
|
||||
// return: [N, n_token, hidden_size]
|
||||
GGML_ASSERT(!pre_only);
|
||||
|
||||
auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
|
||||
auto attn2 = std::dynamic_pointer_cast<SelfAttention>(blocks["attn2"]);
|
||||
auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
|
||||
auto mlp = std::dynamic_pointer_cast<Mlp>(blocks["mlp"]);
|
||||
|
||||
gate_msa = ggml_reshape_3d(ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]); // [N, 1, hidden_size]
|
||||
gate_mlp = ggml_reshape_3d(ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]); // [N, 1, hidden_size]
|
||||
gate_msa2 = ggml_reshape_3d(ctx, gate_msa2, gate_msa2->ne[0], 1, gate_msa2->ne[1]); // [N, 1, hidden_size]
|
||||
|
||||
attn_out = attn->post_attention(ctx, attn_out);
|
||||
attn2_out = attn2->post_attention(ctx, attn2_out);
|
||||
|
||||
x = ggml_add(ctx, x, ggml_mul(ctx, attn_out, gate_msa));
|
||||
x = ggml_add(ctx, x, ggml_mul(ctx, attn2_out, gate_msa2));
|
||||
auto mlp_out = mlp->forward(ctx, modulate(ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp));
|
||||
x = ggml_add(ctx, x, ggml_mul(ctx, mlp_out, gate_mlp));
|
||||
|
||||
return x;
|
||||
}

struct ggml_tensor* post_attention(struct ggml_context* ctx,
struct ggml_tensor* attn_out,
struct ggml_tensor* x,
@ -357,29 +444,52 @@ public:
// return: [N, n_token, hidden_size]

auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
if (self_attn) {
auto qkv_intermediates = pre_attention_x(ctx, x, c);
// auto qkv = qkv_intermediates.first;
// auto intermediates = qkv_intermediates.second;
// no longer a pair, but a tuple
auto qkv = std::get<0>(qkv_intermediates);
auto qkv2 = std::get<1>(qkv_intermediates);
auto intermediates = std::get<2>(qkv_intermediates);

auto qkv_intermediates = pre_attention(ctx, x, c);
auto qkv = qkv_intermediates.first;
auto intermediates = qkv_intermediates.second;
auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim]
auto attn2_out = ggml_nn_attention_ext(ctx, qkv2[0], qkv2[1], qkv2[2], num_heads); // [N, n_token, dim]
x = post_attention_x(ctx,
attn_out,
attn2_out,
intermediates[0],
intermediates[1],
intermediates[2],
intermediates[3],
intermediates[4],
intermediates[5]);
return x; // [N, n_token, dim]
} else {
auto qkv_intermediates = pre_attention(ctx, x, c);
auto qkv = qkv_intermediates.first;
auto intermediates = qkv_intermediates.second;

auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim]
x = post_attention(ctx,
attn_out,
intermediates[0],
intermediates[1],
intermediates[2],
intermediates[3],
intermediates[4]);
return x; // [N, n_token, dim]
}
}
};

__STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*> block_mixing(struct ggml_context* ctx,
struct ggml_tensor* context,
struct ggml_tensor* x,
struct ggml_tensor* c,
std::shared_ptr<DismantledBlock> context_block,
std::shared_ptr<DismantledBlock> x_block) {
__STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*>
block_mixing(struct ggml_context* ctx,
struct ggml_tensor* context,
struct ggml_tensor* x,
struct ggml_tensor* c,
std::shared_ptr<DismantledBlock> context_block,
std::shared_ptr<DismantledBlock> x_block) {
// context: [N, n_context, hidden_size]
// x: [N, n_token, hidden_size]
// c: [N, hidden_size]
@ -387,10 +497,18 @@ __STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*> block_mixi
auto context_qkv = context_qkv_intermediates.first;
auto context_intermediates = context_qkv_intermediates.second;

auto x_qkv_intermediates = x_block->pre_attention(ctx, x, c);
auto x_qkv = x_qkv_intermediates.first;
auto x_intermediates = x_qkv_intermediates.second;
std::vector<ggml_tensor*> x_qkv, x_qkv2, x_intermediates;

if (x_block->self_attn) {
auto x_qkv_intermediates = x_block->pre_attention_x(ctx, x, c);
x_qkv = std::get<0>(x_qkv_intermediates);
x_qkv2 = std::get<1>(x_qkv_intermediates);
x_intermediates = std::get<2>(x_qkv_intermediates);
} else {
auto x_qkv_intermediates = x_block->pre_attention(ctx, x, c);
x_qkv = x_qkv_intermediates.first;
x_intermediates = x_qkv_intermediates.second;
}
std::vector<struct ggml_tensor*> qkv;
for (int i = 0; i < 3; i++) {
qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1));
@ -429,13 +547,27 @@ __STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*> block_mixi
context = NULL;
}

x = x_block->post_attention(ctx,
x_attn,
x_intermediates[0],
x_intermediates[1],
x_intermediates[2],
x_intermediates[3],
x_intermediates[4]);
if (x_block->self_attn) {
auto attn2 = ggml_nn_attention_ext(ctx, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads); // [N, n_token, hidden_size]

x = x_block->post_attention_x(ctx,
x_attn,
attn2,
x_intermediates[0],
x_intermediates[1],
x_intermediates[2],
x_intermediates[3],
x_intermediates[4],
x_intermediates[5]);
} else {
x = x_block->post_attention(ctx,
x_attn,
x_intermediates[0],
x_intermediates[1],
x_intermediates[2],
x_intermediates[3],
x_intermediates[4]);
}

return {context, x};
}
@ -447,9 +579,10 @@ public:
float mlp_ratio = 4.0,
std::string qk_norm = "",
bool qkv_bias = false,
bool pre_only = false) {
bool pre_only = false,
bool self_attn_x = false) {
blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only));
blocks["x_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false));
blocks["x_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x));
}

std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
@ -507,6 +640,7 @@ protected:
int64_t input_size = -1;
int64_t patch_size = 2;
int64_t in_channels = 16;
int64_t d_self = -1; // >=0 for MMdiT-X
int64_t depth = 24;
float mlp_ratio = 4.0f;
int64_t adm_in_channels = 2048;
@ -561,6 +695,20 @@ public:
context_size = 4096;
context_embedder_out_dim = 2432;
qk_norm = "rms";
} else if (version == VERSION_SD3_5_2B) {
input_size = -1;
patch_size = 2;
in_channels = 16;
depth = 24;
d_self = 12;
mlp_ratio = 4.0f;
adm_in_channels = 2048;
out_channels = 16;
pos_embed_max_size = 384;
num_patchs = 147456;
context_size = 4096;
context_embedder_out_dim = 1536;
qk_norm = "rms";
}
int64_t default_out_channels = in_channels;
hidden_size = 64 * depth;
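// hidden_size is derived from depth (64 * depth): depth 24 gives 1536, which
// matches the context_embedder_out_dim of 1536 above, just as a depth-38
// configuration yields the 2432 used by the branch before it.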
@ -581,15 +729,17 @@ public:
mlp_ratio,
qk_norm,
true,
i == depth - 1));
i == depth - 1,
i <= d_self));
}

blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new FinalLayer(hidden_size, patch_size, out_channels));
}

struct ggml_tensor* cropped_pos_embed(struct ggml_context* ctx,
int64_t h,
int64_t w) {
struct ggml_tensor*
cropped_pos_embed(struct ggml_context* ctx,
int64_t h,
int64_t w) {
auto pos_embed = params["pos_embed"];

h = (h + 1) / patch_size;
@ -651,7 +801,8 @@ public:
struct ggml_tensor* forward_core_with_concat(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* c_mod,
struct ggml_tensor* context) {
struct ggml_tensor* context,
std::vector<int> skip_layers = std::vector<int>()) {
// x: [N, H*W, hidden_size]
// context: [N, n_context, d_context]
// c: [N, hidden_size]
@ -659,6 +810,11 @@ public:
auto final_layer = std::dynamic_pointer_cast<FinalLayer>(blocks["final_layer"]);

for (int i = 0; i < depth; i++) {
// skip iteration if i is in skip_layers
if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i) != skip_layers.end()) {
continue;
}
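// skip_layers is what makes skip-layer guidance (SLG) possible: the sampler
// can request a second prediction with selected joint blocks omitted and use
// the difference from the full prediction as an extra guidance signal (see the
// slg_scale handling further below).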

auto block = std::dynamic_pointer_cast<JointBlock>(blocks["joint_blocks." + std::to_string(i)]);

auto context_x = block->forward(ctx, context, x, c_mod);
@ -674,8 +830,9 @@ public:
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* t,
struct ggml_tensor* y = NULL,
struct ggml_tensor* context = NULL) {
struct ggml_tensor* y = NULL,
struct ggml_tensor* context = NULL,
std::vector<int> skip_layers = std::vector<int>()) {
// Forward pass of DiT.
// x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
// t: (N,) tensor of diffusion timesteps
@ -706,7 +863,7 @@ public:
context = context_embedder->forward(ctx, context); // [N, L, D] aka [N, L, 1536]
}

x = forward_core_with_concat(ctx, x, c, context); // (N, H*W, patch_size ** 2 * out_channels)
x = forward_core_with_concat(ctx, x, c, context, skip_layers); // (N, H*W, patch_size ** 2 * out_channels)

x = unpatchify(ctx, x, h, w); // [N, C, H, W]

@ -735,7 +892,8 @@ struct MMDiTRunner : public GGMLRunner {
struct ggml_cgraph* build_graph(struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* y) {
struct ggml_tensor* y,
std::vector<int> skip_layers = std::vector<int>()) {
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, MMDIT_GRAPH_SIZE, false);

x = to_backend(x);
@ -747,7 +905,8 @@ struct MMDiTRunner : public GGMLRunner {
x,
timesteps,
y,
context);
context,
skip_layers);

ggml_build_forward_expand(gf, out);

@ -760,13 +919,14 @@ struct MMDiTRunner : public GGMLRunner {
struct ggml_tensor* context,
struct ggml_tensor* y,
struct ggml_tensor** output = NULL,
struct ggml_context* output_ctx = NULL) {
struct ggml_context* output_ctx = NULL,
std::vector<int> skip_layers = std::vector<int>()) {
// x: [N, in_channels, h, w]
// timesteps: [N, ]
// context: [N, max_position, hidden_size]([N, 154, 4096]) or [1, max_position, hidden_size]
// y: [N, adm_in_channels] or [1, adm_in_channels]
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(x, timesteps, context, y);
return build_graph(x, timesteps, context, y, skip_layers);
};

GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);

model.cpp
@ -146,6 +146,33 @@ std::unordered_map<std::string, std::string> vae_decoder_name_map = {
{"first_stage_model.decoder.mid.attn_1.to_v.weight", "first_stage_model.decoder.mid.attn_1.v.weight"},
};

std::unordered_map<std::string, std::string> pmid_v2_name_map = {
{"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.weight",
"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.fc1.weight"},
{"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.3.weight",
"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.fc2.weight"},
{"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.weight",
"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.fc1.weight"},
{"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.3.weight",
"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.fc2.weight"},
{"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.weight",
"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.fc1.weight"},
{"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.3.weight",
"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.fc2.weight"},
{"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.weight",
"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.fc1.weight"},
{"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.3.weight",
"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.fc2.weight"},
{"pmid.qformer_perceiver.token_proj.0.bias",
"pmid.qformer_perceiver.token_proj.fc1.bias"},
{"pmid.qformer_perceiver.token_proj.2.bias",
"pmid.qformer_perceiver.token_proj.fc2.bias"},
{"pmid.qformer_perceiver.token_proj.0.weight",
"pmid.qformer_perceiver.token_proj.fc1.weight"},
{"pmid.qformer_perceiver.token_proj.2.weight",
"pmid.qformer_perceiver.token_proj.fc2.weight"},
};

std::string convert_open_clip_to_hf_clip(const std::string& name) {
std::string new_name = name;
std::string prefix;
@ -212,6 +239,13 @@ std::string convert_vae_decoder_name(const std::string& name) {
return name;
}

std::string convert_pmid_v2_name(const std::string& name) {
if (pmid_v2_name_map.find(name) != pmid_v2_name_map.end()) {
return pmid_v2_name_map[name];
}
return name;
}

/* If not a SDXL LoRA the "unet" prefix will have already been replaced by this
 * point and "te2" and "te1" don't seem to appear in non-SDXL, only "te_" */
std::string convert_sdxl_lora_name(std::string tensor_name) {
@ -443,6 +477,8 @@ std::string convert_tensor_name(std::string name) {
new_name = convert_open_clip_to_hf_clip(name);
} else if (starts_with(name, "first_stage_model.decoder")) {
new_name = convert_vae_decoder_name(name);
} else if (starts_with(name, "pmid.qformer_perceiver")) {
new_name = convert_pmid_v2_name(name);
} else if (starts_with(name, "control_model.")) { // for controlnet pth models
size_t pos = name.find('.');
if (pos != std::string::npos) {
@ -614,6 +650,47 @@ uint16_t f8_e4m3_to_f16(uint8_t f8) {
return ggml_fp32_to_fp16(*reinterpret_cast<const float*>(&result));
}

uint16_t f8_e5m2_to_f16(uint8_t fp8) {
uint8_t sign = (fp8 >> 7) & 0x1;
uint8_t exponent = (fp8 >> 2) & 0x1F;
uint8_t mantissa = fp8 & 0x3;

uint16_t fp16_sign = sign << 15;
uint16_t fp16_exponent;
uint16_t fp16_mantissa;

if (exponent == 0 && mantissa == 0) { // zero
return fp16_sign;
}

if (exponent == 0x1F) { // NAN and INF
fp16_exponent = 0x1F;
fp16_mantissa = mantissa ? (mantissa << 8) : 0;
return fp16_sign | (fp16_exponent << 10) | fp16_mantissa;
}

if (exponent == 0) { // subnormal numbers
fp16_exponent = 0;
fp16_mantissa = (mantissa << 8);
return fp16_sign | fp16_mantissa;
}

// normal numbers
int16_t true_exponent = (int16_t)exponent - 15 + 15;
if (true_exponent <= 0) {
fp16_exponent = 0;
fp16_mantissa = (mantissa << 8);
} else if (true_exponent >= 0x1F) {
fp16_exponent = 0x1F;
fp16_mantissa = 0;
} else {
fp16_exponent = (uint16_t)true_exponent;
fp16_mantissa = mantissa << 8;
}

return fp16_sign | (fp16_exponent << 10) | fp16_mantissa;
}
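// Worked example: the E5M2 byte 0x3C is 0b0'01111'00 (sign 0, exponent 15,
// mantissa 0). E5M2 shares FP16's 5-bit exponent width and bias of 15, so the
// conversion is a pure field shift: f8_e5m2_to_f16(0x3C) == 0x3C00, i.e. 1.0
// in FP16. That shared bias is also why `true_exponent = exponent - 15 + 15`
// above is an identity; it is kept to mirror the E4M3 path, whose bias (7)
// really does differ from FP16's.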

void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) {
// support inplace op
for (int64_t i = n - 1; i >= 0; i--) {
@ -627,6 +704,12 @@ void f8_e4m3_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {
dst[i] = f8_e4m3_to_f16(src[i]);
}
}
void f8_e5m2_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {
// support inplace op
for (int64_t i = n - 1; i >= 0; i--) {
dst[i] = f8_e5m2_to_f16(src[i]);
}
}
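// The backward loop is what makes the in-place call safe: each 1-byte source
// element expands to 2 bytes in the same buffer, and writing dst[i] (bytes
// 2i..2i+1) never clobbers a src[j] (byte j) with j < i that is still unread.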

void convert_tensor(void* src,
ggml_type src_type,
@ -863,6 +946,8 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
ttype = GGML_TYPE_F32;
} else if (dtype == "F8_E4M3") {
ttype = GGML_TYPE_F16;
} else if (dtype == "F8_E5M2") {
ttype = GGML_TYPE_F16;
}
return ttype;
}
@ -976,6 +1061,10 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
tensor_storage.is_f8_e4m3 = true;
// f8 -> f16
GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
} else if (dtype == "F8_E5M2") {
tensor_storage.is_f8_e5m2 = true;
// f8 -> f16
GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
} else {
GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size);
}
@ -1308,7 +1397,7 @@ bool ModelLoader::parse_data_pkl(uint8_t* buffer,
reader.tensor_storage.reverse_ne();
reader.tensor_storage.file_index = file_index;
// if(strcmp(prefix.c_str(), "scarlett") == 0)
// printf(" got tensor %s \n ", reader.tensor_storage.name.c_str());
// printf(" ZIP got tensor %s \n ", reader.tensor_storage.name.c_str());
reader.tensor_storage.name = prefix + reader.tensor_storage.name;
tensor_storages.push_back(reader.tensor_storage);
// LOG_DEBUG("%s", reader.tensor_storage.name.c_str());
@ -1345,7 +1434,8 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
size_t pos = name.find("data.pkl");
if (pos != std::string::npos) {
std::string dir = name.substr(0, pos);
void* pkl_data = NULL;
printf("ZIP %d, name = %s, dir = %s \n", i, name.c_str(), dir.c_str());
void* pkl_data = NULL;
size_t pkl_size;
zip_entry_read(zip, &pkl_data, &pkl_size);

@ -1364,15 +1454,23 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s

SDVersion ModelLoader::get_sd_version() {
TensorStorage token_embedding_weight;
bool is_flux = false;
bool is_sd3 = false;
bool is_flux = false;
bool is_schnell = true;
bool is_lite = true;
bool is_sd3 = false;
for (auto& tensor_storage : tensor_storages) {
if (tensor_storage.name.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
return VERSION_FLUX_DEV;
is_schnell = false;
}
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
is_flux = true;
}
if (tensor_storage.name.find("model.diffusion_model.double_blocks.8") != std::string::npos) {
is_lite = false;
}
if (tensor_storage.name.find("joint_blocks.0.x_block.attn2.ln_q.weight") != std::string::npos) {
return VERSION_SD3_5_2B;
}
if (tensor_storage.name.find("joint_blocks.37.x_block.attn.ln_q.weight") != std::string::npos) {
return VERSION_SD3_5_8B;
}
@ -1400,7 +1498,14 @@ SDVersion ModelLoader::get_sd_version() {
}
}
if (is_flux) {
return VERSION_FLUX_SCHNELL;
if (is_schnell) {
GGML_ASSERT(!is_lite);
return VERSION_FLUX_SCHNELL;
} else if (is_lite) {
return VERSION_FLUX_LITE;
} else {
return VERSION_FLUX_DEV;
}
}
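// How the flux variants are told apart: any double_blocks.* tensor marks a
// flux graph; guidance_in only exists in the guidance-distilled checkpoints,
// so its absence leaves is_schnell true; and a checkpoint with no
// double_blocks.8 has at most 8 double blocks, which distinguishes the
// slimmed-down Flux Lite from the 19-double-block dev model.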
if (is_sd3) {
return VERSION_SD3_2B;
@ -1629,6 +1734,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
} else if (tensor_storage.is_f8_e4m3) {
// inplace op
f8_e4m3_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
} else if (tensor_storage.is_f8_e5m2) {
// inplace op
f8_e5m2_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
}
} else {
read_buffer.resize(tensor_storage.nbytes());
@ -1640,6 +1748,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
} else if (tensor_storage.is_f8_e4m3) {
// inplace op
f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
} else if (tensor_storage.is_f8_e5m2) {
// inplace op
f8_e5m2_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
}

convert_tensor((void*)read_buffer.data(), tensor_storage.type, dst_tensor->data,
@ -1655,6 +1766,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
} else if (tensor_storage.is_f8_e4m3) {
// inplace op
f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
} else if (tensor_storage.is_f8_e5m2) {
// inplace op
f8_e5m2_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
}

if (tensor_storage.type == dst_tensor->type) {

model.h
@ -26,14 +26,43 @@ enum SDVersion {
VERSION_FLUX_DEV,
VERSION_FLUX_SCHNELL,
VERSION_SD3_5_8B,
VERSION_SD3_5_2B,
VERSION_FLUX_LITE,
VERSION_COUNT,
};

static inline bool sd_version_is_flux(SDVersion version) {
if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
return true;
}
return false;
}

static inline bool sd_version_is_sd3(SDVersion version) {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
return true;
}
return false;
}

static inline bool sd_version_is_dit(SDVersion version) {
if (sd_version_is_flux(version) || sd_version_is_sd3(version)) {
return true;
}
return false;
}

enum PMVersion {
PM_VERSION_1,
PM_VERSION_2,
};

struct TensorStorage {
std::string name;
ggml_type type = GGML_TYPE_F32;
bool is_bf16 = false;
bool is_f8_e4m3 = false;
bool is_f8_e5m2 = false;
int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
int n_dims = 0;

@ -63,7 +92,7 @@ struct TensorStorage {
}

int64_t nbytes_to_read() const {
if (is_bf16 || is_f8_e4m3) {
if (is_bf16 || is_f8_e4m3 || is_f8_e5m2) {
return nbytes() / 2;
} else {
return nbytes();
@ -113,6 +142,8 @@ struct TensorStorage {
type_name = "bf16";
} else if (is_f8_e4m3) {
type_name = "f8_e4m3";
} else if (is_f8_e5m2) {
type_name = "f8_e5m2";
}
ss << name << " | " << type_name << " | ";
ss << n_dims << " [";
@ -157,6 +188,7 @@ public:
bool load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
ggml_backend_t backend,
std::set<std::string> ignore_tensors = {});

bool save_to_gguf_file(const std::string& file_path, ggml_type type);
bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type);
int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);

pmid.hpp
@ -42,6 +42,370 @@ public:
}
};

/*
class QFormerPerceiver(nn.Module):
def __init__(self, id_embeddings_dim, cross_attention_dim, num_tokens, embedding_dim=1024, use_residual=True, ratio=4):
super().__init__()

self.num_tokens = num_tokens
self.cross_attention_dim = cross_attention_dim
self.use_residual = use_residual
print(cross_attention_dim*num_tokens)
self.token_proj = nn.Sequential(
nn.Linear(id_embeddings_dim, id_embeddings_dim*ratio),
nn.GELU(),
nn.Linear(id_embeddings_dim*ratio, cross_attention_dim*num_tokens),
)
self.token_norm = nn.LayerNorm(cross_attention_dim)
self.perceiver_resampler = FacePerceiverResampler(
dim=cross_attention_dim,
depth=4,
dim_head=128,
heads=cross_attention_dim // 128,
embedding_dim=embedding_dim,
output_dim=cross_attention_dim,
ff_mult=4,
)

def forward(self, x, last_hidden_state):
x = self.token_proj(x)
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
x = self.token_norm(x) # cls token
out = self.perceiver_resampler(x, last_hidden_state) # retrieve from patch tokens
if self.use_residual: # TODO: if use_residual is not true
out = x + 1.0 * out
return out
*/

struct PMFeedForward : public GGMLBlock {
// network hparams
int dim;

public:
PMFeedForward(int d, int multi = 4)
: dim(d) {
int inner_dim = dim * multi;
blocks["0"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
blocks["1"] = std::shared_ptr<GGMLBlock>(new Mlp(dim, inner_dim, dim, false));
}

struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x) {
auto norm = std::dynamic_pointer_cast<LayerNorm>(blocks["0"]);
auto ff = std::dynamic_pointer_cast<Mlp>(blocks["1"]);

x = norm->forward(ctx, x);
x = ff->forward(ctx, x);
return x;
}
};

struct PerceiverAttention : public GGMLBlock {
// network hparams
float scale; // = dim_head**-0.5
int dim_head; // = dim_head
int heads; // = heads
public:
PerceiverAttention(int dim, int dim_h = 64, int h = 8)
: scale(powf(dim_h, -0.5)), dim_head(dim_h), heads(h) {
int inner_dim = dim_head * heads;
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
blocks["to_q"] = std::shared_ptr<GGMLBlock>(new Linear(dim, inner_dim, false));
blocks["to_kv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, inner_dim * 2, false));
blocks["to_out"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim, false));
}

struct ggml_tensor* reshape_tensor(struct ggml_context* ctx,
struct ggml_tensor* x,
int heads) {
int64_t ne[4];
for (int i = 0; i < 4; ++i)
ne[i] = x->ne[i];
// print_ggml_tensor(x, true, "PerceiverAttention reshape x 0: ");
// printf("heads = %d \n", heads);
// x = ggml_view_4d(ctx, x, x->ne[0], x->ne[1], heads, x->ne[2]/heads,
// x->nb[1], x->nb[2], x->nb[3], 0);
x = ggml_reshape_4d(ctx, x, x->ne[0] / heads, heads, x->ne[1], x->ne[2]);
// x = ggml_view_4d(ctx, x, x->ne[0]/heads, heads, x->ne[1], x->ne[2],
// x->nb[1], x->nb[2], x->nb[3], 0);
// x = ggml_cont(ctx, x);
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));
// print_ggml_tensor(x, true, "PerceiverAttention reshape x 1: ");
// x = ggml_reshape_4d(ctx, x, ne[0], heads, ne[1], ne[2]/heads);
return x;
}

std::vector<struct ggml_tensor*> chunk_half(struct ggml_context* ctx,
struct ggml_tensor* x) {
auto tlo = ggml_view_4d(ctx, x, x->ne[0] / 2, x->ne[1], x->ne[2], x->ne[3], x->nb[1], x->nb[2], x->nb[3], 0);
auto tli = ggml_view_4d(ctx, x, x->ne[0] / 2, x->ne[1], x->ne[2], x->ne[3], x->nb[1], x->nb[2], x->nb[3], x->nb[0] * x->ne[0] / 2);
return {ggml_cont(ctx, tlo),
ggml_cont(ctx, tli)};
}
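// chunk_half splits the fused kv projection along ggml's dim 0 (the innermost
// dimension, torch's dim=-1) into two contiguous halves, mirroring
// `self.to_kv(kv_input).chunk(2, dim=-1)` in the PyTorch reference quoted below.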

struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* latents) {
// x (torch.Tensor): image features
// shape (b, n1, D)
// latent (torch.Tensor): latent features
// shape (b, n2, D)
int64_t ne[4];
for (int i = 0; i < 4; ++i)
ne[i] = latents->ne[i];

auto norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
x = norm1->forward(ctx, x);
latents = norm2->forward(ctx, latents);
auto to_q = std::dynamic_pointer_cast<Linear>(blocks["to_q"]);
auto q = to_q->forward(ctx, latents);

auto kv_input = ggml_concat(ctx, x, latents, 1);
auto to_kv = std::dynamic_pointer_cast<Linear>(blocks["to_kv"]);
auto kv = to_kv->forward(ctx, kv_input);
auto k = ggml_view_4d(ctx, kv, kv->ne[0] / 2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1] / 2, kv->nb[2] / 2, kv->nb[3] / 2, 0);
auto v = ggml_view_4d(ctx, kv, kv->ne[0] / 2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1] / 2, kv->nb[2] / 2, kv->nb[3] / 2, kv->nb[0] * (kv->ne[0] / 2));
k = ggml_cont(ctx, k);
v = ggml_cont(ctx, v);
q = reshape_tensor(ctx, q, heads);
k = reshape_tensor(ctx, k, heads);
v = reshape_tensor(ctx, v, heads);
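// Both q and k are scaled by dim_head^(-1/4) below, so their product carries
// the standard 1/sqrt(dim_head) attention scaling; per the PyTorch reference
// quoted later, splitting the factor this way is more stable in f16 than
// dividing the logits afterwards.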
scale = 1.f / sqrt(sqrt((float)dim_head));
k = ggml_scale_inplace(ctx, k, scale);
q = ggml_scale_inplace(ctx, q, scale);
// auto weight = ggml_mul_mat(ctx, q, k);
auto weight = ggml_mul_mat(ctx, k, q); // NOTE: the order of mul is opposite to pytorch

// GGML's softmax() is equivalent to pytorch's softmax(x, dim=-1):
// the dimension along which softmax is computed is the last dim in torch and
// the first dim in GGML, consistent with the convention that pytorch's last
// dimension (varying most rapidly) corresponds to GGML's first (varying most rapidly).
// weight = ggml_soft_max(ctx, weight);
weight = ggml_soft_max_inplace(ctx, weight);
v = ggml_cont(ctx, ggml_transpose(ctx, v));
// auto out = ggml_mul_mat(ctx, weight, v);
auto out = ggml_mul_mat(ctx, v, weight); // NOTE: the order of mul is opposite to pytorch
out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3));
out = ggml_reshape_3d(ctx, out, ne[0], ne[1], ggml_nelements(out) / (ne[0] * ne[1]));
auto to_out = std::dynamic_pointer_cast<Linear>(blocks["to_out"]);
out = to_out->forward(ctx, out);
return out;
}
};

struct FacePerceiverResampler : public GGMLBlock {
// network hparams
int depth;

public:
FacePerceiverResampler(int dim = 768,
int d = 4,
int dim_head = 64,
int heads = 16,
int embedding_dim = 1280,
int output_dim = 768,
int ff_mult = 4)
: depth(d) {
blocks["proj_in"] = std::shared_ptr<GGMLBlock>(new Linear(embedding_dim, dim, true));
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Linear(dim, output_dim, true));
blocks["norm_out"] = std::shared_ptr<GGMLBlock>(new LayerNorm(output_dim));

for (int i = 0; i < depth; i++) {
std::string name = "layers." + std::to_string(i) + ".0";
blocks[name] = std::shared_ptr<GGMLBlock>(new PerceiverAttention(dim, dim_head, heads));
name = "layers." + std::to_string(i) + ".1";
blocks[name] = std::shared_ptr<GGMLBlock>(new PMFeedForward(dim, ff_mult));
}
}

struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* latents,
struct ggml_tensor* x) {
// x: [N, channels, h, w]
auto proj_in = std::dynamic_pointer_cast<Linear>(blocks["proj_in"]);
auto proj_out = std::dynamic_pointer_cast<Linear>(blocks["proj_out"]);
auto norm_out = std::dynamic_pointer_cast<LayerNorm>(blocks["norm_out"]);

x = proj_in->forward(ctx, x);
for (int i = 0; i < depth; i++) {
std::string name = "layers." + std::to_string(i) + ".0";
auto attn = std::dynamic_pointer_cast<PerceiverAttention>(blocks[name]);
name = "layers." + std::to_string(i) + ".1";
auto ff = std::dynamic_pointer_cast<PMFeedForward>(blocks[name]);
auto t = attn->forward(ctx, x, latents);
latents = ggml_add(ctx, t, latents);
t = ff->forward(ctx, latents);
latents = ggml_add(ctx, t, latents);
}
latents = proj_out->forward(ctx, latents);
latents = norm_out->forward(ctx, latents);
return latents;
}
};

struct QFormerPerceiver : public GGMLBlock {
// network hparams
int num_tokens;
int cross_attention_dim;
bool use_residual;

public:
QFormerPerceiver(int id_embeddings_dim, int cross_attention_d, int num_t, int embedding_dim = 1024, bool use_r = true, int ratio = 4)
: cross_attention_dim(cross_attention_d), num_tokens(num_t), use_residual(use_r) {
blocks["token_proj"] = std::shared_ptr<GGMLBlock>(new Mlp(id_embeddings_dim,
id_embeddings_dim * ratio,
cross_attention_dim * num_tokens,
true));
blocks["token_norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(cross_attention_d));
blocks["perceiver_resampler"] = std::shared_ptr<GGMLBlock>(new FacePerceiverResampler(
cross_attention_dim,
4,
128,
cross_attention_dim / 128,
embedding_dim,
cross_attention_dim,
4));
}

/*
def forward(self, x, last_hidden_state):
x = self.token_proj(x)
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
x = self.token_norm(x) # cls token
out = self.perceiver_resampler(x, last_hidden_state) # retrieve from patch tokens
if self.use_residual: # TODO: if use_residual is not true
out = x + 1.0 * out
return out
*/

struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* last_hidden_state) {
// x: [N, channels, h, w]
auto token_proj = std::dynamic_pointer_cast<Mlp>(blocks["token_proj"]);
auto token_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["token_norm"]);
auto perceiver_resampler = std::dynamic_pointer_cast<FacePerceiverResampler>(blocks["perceiver_resampler"]);

x = token_proj->forward(ctx, x);
int64_t nel = ggml_nelements(x);
x = ggml_reshape_3d(ctx, x, cross_attention_dim, num_tokens, nel / (cross_attention_dim * num_tokens));
x = token_norm->forward(ctx, x);
struct ggml_tensor* out = perceiver_resampler->forward(ctx, x, last_hidden_state);
if (use_residual)
out = ggml_add(ctx, x, out);
return out;
}
};

/*
class FacePerceiverResampler(torch.nn.Module):
def __init__(
self,
*,
dim=768,
depth=4,
dim_head=64,
heads=16,
embedding_dim=1280,
output_dim=768,
ff_mult=4,
):
super().__init__()

self.proj_in = torch.nn.Linear(embedding_dim, dim)
self.proj_out = torch.nn.Linear(dim, output_dim)
self.norm_out = torch.nn.LayerNorm(output_dim)
self.layers = torch.nn.ModuleList([])
for _ in range(depth):
self.layers.append(
torch.nn.ModuleList(
[
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]
)
)

def forward(self, latents, x):
x = self.proj_in(x)
for attn, ff in self.layers:
latents = attn(x, latents) + latents
latents = ff(latents) + latents
latents = self.proj_out(latents)
return self.norm_out(latents)
*/

/*

def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(),
nn.Linear(inner_dim, dim, bias=False),
)

def reshape_tensor(x, heads):
bs, length, width = x.shape
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, length, heads, -1)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
x = x.transpose(1, 2)
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
x = x.reshape(bs, heads, length, -1)
return x

class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head**-0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads

self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)

self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)

def forward(self, x, latents):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latent (torch.Tensor): latent features
shape (b, n2, D)
"""
x = self.norm1(x)
latents = self.norm2(latents)

b, l, _ = latents.shape

q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)

q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)

# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v

out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

return self.to_out(out)

*/

struct FuseModule : public GGMLBlock {
// network hparams
int embed_dim;
@ -61,12 +425,19 @@ public:
auto mlp2 = std::dynamic_pointer_cast<FuseBlock>(blocks["mlp2"]);
auto layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm"]);

auto prompt_embeds0 = ggml_cont(ctx, ggml_permute(ctx, prompt_embeds, 2, 0, 1, 3));
auto id_embeds0 = ggml_cont(ctx, ggml_permute(ctx, id_embeds, 2, 0, 1, 3));
// concat is along dim 2
auto stacked_id_embeds = ggml_concat(ctx, prompt_embeds0, id_embeds0, 2);
stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 1, 2, 0, 3));
// print_ggml_tensor(id_embeds, true, "Fuseblock id_embeds: ");
// print_ggml_tensor(prompt_embeds, true, "Fuseblock prompt_embeds: ");

// auto prompt_embeds0 = ggml_cont(ctx, ggml_permute(ctx, prompt_embeds, 2, 0, 1, 3));
// auto id_embeds0 = ggml_cont(ctx, ggml_permute(ctx, id_embeds, 2, 0, 1, 3));
// print_ggml_tensor(id_embeds0, true, "Fuseblock id_embeds0: ");
// print_ggml_tensor(prompt_embeds0, true, "Fuseblock prompt_embeds0: ");
// concat is along dim 2
// auto stacked_id_embeds = ggml_concat(ctx, prompt_embeds0, id_embeds0, 2);
auto stacked_id_embeds = ggml_concat(ctx, prompt_embeds, id_embeds, 0);
// print_ggml_tensor(stacked_id_embeds, true, "Fuseblock stacked_id_embeds 0: ");
// stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 1, 2, 0, 3));
// print_ggml_tensor(stacked_id_embeds, true, "Fuseblock stacked_id_embeds 1: ");
// stacked_id_embeds = mlp1.forward(ctx, stacked_id_embeds);
// stacked_id_embeds = ggml_add(ctx, stacked_id_embeds, prompt_embeds);
// stacked_id_embeds = mlp2.forward(ctx, stacked_id_embeds);
@ -77,6 +448,8 @@ public:
stacked_id_embeds = mlp2->forward(ctx, stacked_id_embeds);
stacked_id_embeds = layer_norm->forward(ctx, stacked_id_embeds);

// print_ggml_tensor(stacked_id_embeds, true, "Fuseblock stacked_id_embeds 1: ");

return stacked_id_embeds;
}

@ -98,23 +471,31 @@ public:
// print_ggml_tensor(class_tokens_mask_pos, true, "class_tokens_mask_pos");
struct ggml_tensor* image_token_embeds = ggml_get_rows(ctx, prompt_embeds, class_tokens_mask_pos);
ggml_set_name(image_token_embeds, "image_token_embeds");
valid_id_embeds = ggml_reshape_2d(ctx, valid_id_embeds, valid_id_embeds->ne[0],
ggml_nelements(valid_id_embeds) / valid_id_embeds->ne[0]);
struct ggml_tensor* stacked_id_embeds = fuse_fn(ctx, image_token_embeds, valid_id_embeds);

stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 0, 2, 1, 3));
// stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 0, 2, 1, 3));
// print_ggml_tensor(stacked_id_embeds, true, "AA stacked_id_embeds");
// print_ggml_tensor(left, true, "AA left");
// print_ggml_tensor(right, true, "AA right");
if (left && right) {
stacked_id_embeds = ggml_concat(ctx, left, stacked_id_embeds, 2);
stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right, 2);
stacked_id_embeds = ggml_concat(ctx, left, stacked_id_embeds, 1);
stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right, 1);
} else if (left) {
stacked_id_embeds = ggml_concat(ctx, left, stacked_id_embeds, 2);
stacked_id_embeds = ggml_concat(ctx, left, stacked_id_embeds, 1);
} else if (right) {
stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right, 2);
stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right, 1);
}
stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 0, 2, 1, 3));
// print_ggml_tensor(stacked_id_embeds, true, "BB stacked_id_embeds");
// stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 0, 2, 1, 3));
// print_ggml_tensor(stacked_id_embeds, true, "CC stacked_id_embeds");
class_tokens_mask = ggml_cont(ctx, ggml_transpose(ctx, class_tokens_mask));
class_tokens_mask = ggml_repeat(ctx, class_tokens_mask, prompt_embeds);
prompt_embeds = ggml_mul(ctx, prompt_embeds, class_tokens_mask);
struct ggml_tensor* updated_prompt_embeds = ggml_add(ctx, prompt_embeds, stacked_id_embeds);
ggml_set_name(updated_prompt_embeds, "updated_prompt_embeds");
// print_ggml_tensor(updated_prompt_embeds, true, "updated_prompt_embeds: ");
return updated_prompt_embeds;
}
};
@ -159,10 +540,77 @@ struct PhotoMakerIDEncoderBlock : public CLIPVisionModelProjection {
}
};

struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionModelProjection {
int cross_attention_dim;
int num_tokens;

PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock(int id_embeddings_dim = 512)
: CLIPVisionModelProjection(OPENAI_CLIP_VIT_L_14),
cross_attention_dim(2048),
num_tokens(2) {
blocks["visual_projection_2"] = std::shared_ptr<GGMLBlock>(new Linear(1024, 1280, false));
blocks["fuse_module"] = std::shared_ptr<GGMLBlock>(new FuseModule(2048));
/*
cross_attention_dim = 2048
# projection
self.num_tokens = 2
self.cross_attention_dim = cross_attention_dim
self.qformer_perceiver = QFormerPerceiver(
id_embeddings_dim,
cross_attention_dim,
self.num_tokens,
)*/
blocks["qformer_perceiver"] = std::shared_ptr<GGMLBlock>(new QFormerPerceiver(id_embeddings_dim,
cross_attention_dim,
num_tokens));
}

/*
def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask, id_embeds):
b, num_inputs, c, h, w = id_pixel_values.shape
id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w)

last_hidden_state = self.vision_model(id_pixel_values)[0]
id_embeds = id_embeds.view(b * num_inputs, -1)

id_embeds = self.qformer_perceiver(id_embeds, last_hidden_state)
id_embeds = id_embeds.view(b, num_inputs, self.num_tokens, -1)
updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask)
*/

struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* id_pixel_values,
struct ggml_tensor* prompt_embeds,
struct ggml_tensor* class_tokens_mask,
struct ggml_tensor* class_tokens_mask_pos,
struct ggml_tensor* id_embeds,
struct ggml_tensor* left,
struct ggml_tensor* right) {
// x: [N, channels, h, w]
auto vision_model = std::dynamic_pointer_cast<CLIPVisionModel>(blocks["vision_model"]);
auto fuse_module = std::dynamic_pointer_cast<FuseModule>(blocks["fuse_module"]);
auto qformer_perceiver = std::dynamic_pointer_cast<QFormerPerceiver>(blocks["qformer_perceiver"]);

// struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, id_pixel_values); // [N, hidden_size]
struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, id_pixel_values, false); // [N, hidden_size]
id_embeds = qformer_perceiver->forward(ctx, id_embeds, last_hidden_state);

struct ggml_tensor* updated_prompt_embeds = fuse_module->forward(ctx,
prompt_embeds,
id_embeds,
class_tokens_mask,
class_tokens_mask_pos,
left, right);
return updated_prompt_embeds;
}
};

struct PhotoMakerIDEncoder : public GGMLRunner {
public:
SDVersion version = VERSION_SDXL;
PMVersion pm_version = PM_VERSION_1;
PhotoMakerIDEncoderBlock id_encoder;
PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock id_encoder2;
float style_strength;

std::vector<float> ctm;
@ -175,25 +623,38 @@ public:
std::vector<float> zeros_right;

public:
PhotoMakerIDEncoder(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_SDXL, float sty = 20.f)
PhotoMakerIDEncoder(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_SDXL, PMVersion pm_v = PM_VERSION_1, float sty = 20.f)
: GGMLRunner(backend, wtype),
version(version),
pm_version(pm_v),
style_strength(sty) {
id_encoder.init(params_ctx, wtype);
if (pm_version == PM_VERSION_1) {
id_encoder.init(params_ctx, wtype);
} else if (pm_version == PM_VERSION_2) {
id_encoder2.init(params_ctx, wtype);
}
}

std::string get_desc() {
return "pmid";
}

PMVersion get_version() const {
return pm_version;
}

void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
id_encoder.get_param_tensors(tensors, prefix);
if (pm_version == PM_VERSION_1)
id_encoder.get_param_tensors(tensors, prefix);
else if (pm_version == PM_VERSION_2)
id_encoder2.get_param_tensors(tensors, prefix);
}

struct ggml_cgraph* build_graph( // struct ggml_allocr* allocr,
struct ggml_tensor* id_pixel_values,
struct ggml_tensor* prompt_embeds,
std::vector<bool>& class_tokens_mask) {
std::vector<bool>& class_tokens_mask,
struct ggml_tensor* id_embeds) {
ctm.clear();
ctmf16.clear();
ctmpos.clear();
@ -214,25 +675,32 @@ public:

struct ggml_tensor* id_pixel_values_d = to_backend(id_pixel_values);
struct ggml_tensor* prompt_embeds_d = to_backend(prompt_embeds);
struct ggml_tensor* id_embeds_d = to_backend(id_embeds);

struct ggml_tensor* left = NULL;
struct ggml_tensor* right = NULL;
for (int i = 0; i < class_tokens_mask.size(); i++) {
if (class_tokens_mask[i]) {
// printf(" 1,");
ctm.push_back(0.f); // here use 0.f instead of 1.f to make a scale mask
ctmf16.push_back(ggml_fp32_to_fp16(0.f)); // here use 0.f instead of 1.f to make a scale mask
ctmpos.push_back(i);
} else {
// printf(" 0,");
ctm.push_back(1.f); // here use 1.f instead of 0.f to make a scale mask
ctmf16.push_back(ggml_fp32_to_fp16(1.f)); // here use 1.f instead of 0.f to make a scale mask
}
}
// printf("\n");
if (ctmpos[0] > 0) {
left = ggml_new_tensor_3d(ctx0, type, hidden_size, 1, ctmpos[0]);
// left = ggml_new_tensor_3d(ctx0, type, hidden_size, 1, ctmpos[0]);
left = ggml_new_tensor_3d(ctx0, type, hidden_size, ctmpos[0], 1);
}
if (ctmpos[ctmpos.size() - 1] < seq_length - 1) {
// right = ggml_new_tensor_3d(ctx0, type,
// hidden_size, 1, seq_length - ctmpos[ctmpos.size() - 1] - 1);
right = ggml_new_tensor_3d(ctx0, type,
hidden_size, 1, seq_length - ctmpos[ctmpos.size() - 1] - 1);
hidden_size, seq_length - ctmpos[ctmpos.size() - 1] - 1, 1);
}
struct ggml_tensor* class_tokens_mask_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ctmpos.size());

@ -265,12 +733,23 @@ public:
}
}
}
struct ggml_tensor* updated_prompt_embeds = id_encoder.forward(ctx0,
id_pixel_values_d,
prompt_embeds_d,
class_tokens_mask_d,
class_tokens_mask_pos,
left, right);
struct ggml_tensor* updated_prompt_embeds = NULL;
if (pm_version == PM_VERSION_1)
updated_prompt_embeds = id_encoder.forward(ctx0,
id_pixel_values_d,
prompt_embeds_d,
class_tokens_mask_d,
class_tokens_mask_pos,
left, right);
else if (pm_version == PM_VERSION_2)
updated_prompt_embeds = id_encoder2.forward(ctx0,
id_pixel_values_d,
prompt_embeds_d,
class_tokens_mask_d,
class_tokens_mask_pos,
id_embeds_d,
left, right);

ggml_build_forward_expand(gf, updated_prompt_embeds);

return gf;
@ -279,12 +758,13 @@ public:
void compute(const int n_threads,
struct ggml_tensor* id_pixel_values,
struct ggml_tensor* prompt_embeds,
struct ggml_tensor* id_embeds,
std::vector<bool>& class_tokens_mask,
struct ggml_tensor** updated_prompt_embeds,
ggml_context* output_ctx) {
auto get_graph = [&]() -> struct ggml_cgraph* {
// return build_graph(compute_allocr, id_pixel_values, prompt_embeds, class_tokens_mask);
return build_graph(id_pixel_values, prompt_embeds, class_tokens_mask);
return build_graph(id_pixel_values, prompt_embeds, class_tokens_mask, id_embeds);
};

// GGMLRunner::compute(get_graph, n_threads, updated_prompt_embeds);
@ -292,4 +772,75 @@ public:
}
};

struct PhotoMakerIDEmbed : public GGMLRunner {
std::map<std::string, struct ggml_tensor*> tensors;
std::string file_path;
ModelLoader* model_loader;
bool load_failed = false;
bool applied = false;

PhotoMakerIDEmbed(ggml_backend_t backend,
ggml_type wtype,
ModelLoader* ml,
const std::string& file_path = "",
const std::string& prefix = "")
: file_path(file_path), GGMLRunner(backend, wtype), model_loader(ml) {
if (!model_loader->init_from_file(file_path, prefix)) {
load_failed = true;
}
}

std::string get_desc() {
return "id_embeds";
}

bool load_from_file(bool filter_tensor = false) {
LOG_INFO("loading PhotoMaker ID Embeds from '%s'", file_path.c_str());

if (load_failed) {
LOG_ERROR("init photomaker id embed from file failed: '%s'", file_path.c_str());
return false;
}

bool dry_run = true;
auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
const std::string& name = tensor_storage.name;

if (filter_tensor && !contains(name, "pmid.id_embeds")) {
// LOG_INFO("skipping LoRA tensor '%s'", name.c_str());
return true;
}
if (dry_run) {
struct ggml_tensor* real = ggml_new_tensor(params_ctx,
tensor_storage.type,
tensor_storage.n_dims,
tensor_storage.ne);
tensors[name] = real;
} else {
auto real = tensors[name];
*dst_tensor = real;
}

return true;
};

model_loader->load_tensors(on_new_tensor_cb, backend);
alloc_params_buffer();

dry_run = false;
model_loader->load_tensors(on_new_tensor_cb, backend);

LOG_DEBUG("finished loading PhotoMaker ID Embeds");
return true;
}

struct ggml_tensor* get() {
std::map<std::string, struct ggml_tensor*>::iterator pos;
pos = tensors.find("pmid.id_embeds");
if (pos != tensors.end())
return pos->second;
return NULL;
}
};

#endif  // __PMI_HPP__
|
||||
|
||||
@ -32,7 +32,9 @@ const char* model_version_to_str[] = {
|
||||
"SD3 2B",
|
||||
"Flux Dev",
|
||||
"Flux Schnell",
|
||||
"SD3.5 8B"};
|
||||
"SD3.5 8B",
|
||||
"SD3.5 2B",
|
||||
"Flux Lite 8B"};
|
||||
|
||||
const char* sampling_methods_str[] = {
|
||||
"Euler A",
|
||||
@ -93,6 +95,7 @@ public:
|
||||
std::shared_ptr<ControlNet> control_net;
|
||||
std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
|
||||
std::shared_ptr<LoraModel> pmid_lora;
|
||||
std::shared_ptr<PhotoMakerIDEmbed> pmid_id_embeds;
|
||||
|
||||
std::string taesd_path;
|
||||
bool use_tiny_autoencoder = false;
|
||||
@ -153,7 +156,8 @@ public:
|
||||
schedule_t schedule,
|
||||
bool clip_on_cpu,
|
||||
bool control_net_cpu,
|
||||
bool vae_on_cpu) {
|
||||
bool vae_on_cpu,
|
||||
bool diffusion_flash_attn) {
|
||||
use_tiny_autoencoder = taesd_path.size() > 0;
|
||||
#ifdef SD_USE_CUBLAS
|
||||
LOG_DEBUG("Using CUDA backend");
|
||||
@ -182,13 +186,7 @@ public:
|
||||
LOG_DEBUG("Using CPU backend");
|
||||
backend = ggml_backend_cpu_init();
|
||||
}
|
||||
#ifdef SD_USE_FLASH_ATTENTION
|
||||
#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) || defined(SD_USE_SYCL) || defined(SD_USE_VULKAN)
|
||||
LOG_WARN("Flash Attention not supported with GPU Backend");
|
||||
#else
|
||||
LOG_INFO("Flash Attention enabled");
|
||||
#endif
|
||||
#endif
|
||||
|
||||
ModelLoader model_loader;
|
||||
|
||||
vae_tiling = vae_tiling_;
|
||||
@ -288,9 +286,9 @@ public:
|
||||
"try specifying SDXL VAE FP16 Fix with the --vae parameter. "
|
||||
"You can find it here: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors");
|
||||
}
|
||||
} else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) {
|
||||
} else if (sd_version_is_sd3(version)) {
|
||||
scale_factor = 1.5305f;
|
||||
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
|
||||
} else if (sd_version_is_flux(version)) {
|
||||
scale_factor = 0.3611;
|
||||
// TODO: shift_factor
|
||||
}
|
||||
@ -311,7 +309,7 @@ public:
|
||||
} else {
|
||||
clip_backend = backend;
|
||||
bool use_t5xxl = false;
|
||||
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
|
||||
if (sd_version_is_dit(version)) {
|
||||
use_t5xxl = true;
|
||||
}
|
||||
if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) {
|
||||
@ -322,15 +320,25 @@ public:
|
||||
LOG_INFO("CLIP: Using CPU backend");
|
||||
clip_backend = ggml_backend_cpu_init();
|
||||
}
|
||||
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) {
|
||||
if (diffusion_flash_attn) {
|
||||
LOG_INFO("Using flash attention in the diffusion model");
|
||||
}
|
||||
if (sd_version_is_sd3(version)) {
|
||||
if (diffusion_flash_attn) {
|
||||
LOG_WARN("flash attention in this diffusion model is currently unsupported!");
|
||||
}
|
||||
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, conditioner_wtype);
|
||||
diffusion_model = std::make_shared<MMDiTModel>(backend, diffusion_model_wtype, version);
|
||||
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
|
||||
} else if (sd_version_is_flux(version)) {
|
||||
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, conditioner_wtype);
|
||||
diffusion_model = std::make_shared<FluxModel>(backend, diffusion_model_wtype, version);
|
||||
diffusion_model = std::make_shared<FluxModel>(backend, diffusion_model_wtype, version, diffusion_flash_attn);
|
||||
} else {
|
||||
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, conditioner_wtype, embeddings_path, version);
|
||||
diffusion_model = std::make_shared<UNetModel>(backend, diffusion_model_wtype, version);
|
||||
if (id_embeddings_path.find("v2") != std::string::npos) {
|
||||
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, conditioner_wtype, embeddings_path, version, PM_VERSION_2);
|
||||
} else {
|
||||
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, conditioner_wtype, embeddings_path, version);
|
||||
}
|
||||
diffusion_model = std::make_shared<UNetModel>(backend, diffusion_model_wtype, version, diffusion_flash_attn);
|
||||
}
|
||||
cond_stage_model->alloc_params_buffer();
|
||||
cond_stage_model->get_param_tensors(tensors);
|
||||
@@ -364,7 +372,12 @@ public:
control_net = std::make_shared<ControlNet>(controlnet_backend, diffusion_model_wtype, version);
}

-pmid_model = std::make_shared<PhotoMakerIDEncoder>(clip_backend, model_wtype, version);
+if (id_embeddings_path.find("v2") != std::string::npos) {
+pmid_model = std::make_shared<PhotoMakerIDEncoder>(backend, model_wtype, version, PM_VERSION_2);
+LOG_INFO("using PhotoMaker Version 2");
+} else {
+pmid_model = std::make_shared<PhotoMakerIDEncoder>(backend, model_wtype, version);
+}
if (id_embeddings_path.size() > 0) {
pmid_lora = std::make_shared<LoraModel>(backend, model_wtype, id_embeddings_path, "");
if (!pmid_lora->load_from_file(true)) {
@@ -383,14 +396,8 @@ public:
LOG_ERROR(" pmid model params buffer allocation failed");
return false;
}
-// LOG_INFO("pmid param memory buffer size = %.2fMB ",
-// pmid_model->params_buffer_size / 1024.0 / 1024.0);
pmid_model->get_param_tensors(tensors, "pmid");
}
-// if(stacked_id){
-// pmid_model.init_params(GGML_TYPE_F32);
-// pmid_model.map_by_name(tensors, "pmid.");
-// }
}

struct ggml_init_params params;
@@ -520,10 +527,10 @@ public:
is_using_v_parameterization = true;
}

-if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) {
+if (sd_version_is_sd3(version)) {
LOG_INFO("running in FLOW mode");
denoiser = std::make_shared<DiscreteFlowDenoiser>();
-} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
+} else if (sd_version_is_flux(version)) {
LOG_INFO("running in Flux FLOW mode");
float shift = 1.15f;
if (version == VERSION_FLUX_SCHNELL) {
@@ -673,10 +680,10 @@ public:
ggml_tensor* id_encoder(ggml_context* work_ctx,
ggml_tensor* init_img,
ggml_tensor* prompts_embeds,
+ggml_tensor* id_embeds,
std::vector<bool>& class_tokens_mask) {
ggml_tensor* res = NULL;
-pmid_model->compute(n_threads, init_img, prompts_embeds, class_tokens_mask, &res, work_ctx);
+pmid_model->compute(n_threads, init_img, prompts_embeds, id_embeds, class_tokens_mask, &res, work_ctx);
return res;
}

@@ -771,7 +778,11 @@ public:
sample_method_t method,
const std::vector<float>& sigmas,
int start_merge_step,
-SDCondition id_cond) {
+SDCondition id_cond,
+std::vector<int> skip_layers = {},
+float slg_scale = 2.5,
+float skip_layer_start = 0.01,
+float skip_layer_end = 0.2) {
size_t steps = sigmas.size() - 1;
// noise = load_tensor_from_file(work_ctx, "./rand0.bin");
// print_ggml_tensor(noise);
@@ -782,13 +793,24 @@ public:
struct ggml_tensor* noised_input = ggml_dup_tensor(work_ctx, noise);

bool has_unconditioned = cfg_scale != 1.0 && uncond.c_crossattn != NULL;
+bool has_skiplayer = slg_scale != 0.0 && skip_layers.size() > 0;

// denoise wrapper
struct ggml_tensor* out_cond = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* out_uncond = NULL;
+struct ggml_tensor* out_skip = NULL;

if (has_unconditioned) {
out_uncond = ggml_dup_tensor(work_ctx, x);
}
+if (has_skiplayer) {
+if (sd_version_is_dit(version)) {
+out_skip = ggml_dup_tensor(work_ctx, x);
+} else {
+has_skiplayer = false;
+LOG_WARN("SLG is incompatible with %s models", model_version_to_str[version]);
+}
+}
struct ggml_tensor* denoised = ggml_dup_tensor(work_ctx, x);

auto denoise = [&](ggml_tensor* input, float sigma, int step) -> ggml_tensor* {
@@ -869,6 +891,28 @@ public:
&out_uncond);
negative_data = (float*)out_uncond->data;
}

+int step_count = sigmas.size();
+bool is_skiplayer_step = has_skiplayer && step > (int)(skip_layer_start * step_count) && step < (int)(skip_layer_end * step_count);
+float* skip_layer_data = NULL;
+if (is_skiplayer_step) {
+LOG_DEBUG("Skipping layers at step %d\n", step);
+// skip layer (same as conditioned)
+diffusion_model->compute(n_threads,
+noised_input,
+timesteps,
+cond.c_crossattn,
+cond.c_concat,
+cond.c_vector,
+guidance_tensor,
+-1,
+controls,
+control_strength,
+&out_skip,
+NULL,
+skip_layers);
+skip_layer_data = (float*)out_skip->data;
+}
float* vec_denoised = (float*)denoised->data;
float* vec_input = (float*)input->data;
float* positive_data = (float*)out_cond->data;
@@ -885,6 +929,9 @@ public:
latent_result = negative_data[i] + cfg_scale * (positive_data[i] - negative_data[i]);
}
}
+if (is_skiplayer_step) {
+latent_result = latent_result + (positive_data[i] - skip_layer_data[i]) * slg_scale;
+}
// v = latent_result, eps = latent_result
// denoised = (v * c_out + input * c_skip) or (input + eps * c_out)
vec_denoised[i] = latent_result * c_out + vec_input[i] * c_skip;
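Taken together, the two hunks above implement classifier-free guidance plus skip-layer guidance: at each step the denoiser blends the conditional, unconditional, and (on SLG-active steps) layer-skipped predictions. A sketch of the per-element update implied by the code, for steps where both CFG and SLG are active, writing s_cfg for `cfg_scale` and s_slg for `slg_scale`:

```
\hat{x}_i = x^{u}_i + s_{\mathrm{cfg}} \left( x^{c}_i - x^{u}_i \right) + s_{\mathrm{slg}} \left( x^{c}_i - x^{\mathrm{skip}}_i \right)
```

The extra `diffusion_model->compute` pass only runs inside the `(skip_layer_start, skip_layer_end)` window: with the defaults (0.01 and 0.2) and, say, 20 sampling steps (`step_count` = 21), `is_skiplayer_step` holds for steps 1 through 3, so SLG costs roughly one additional forward pass on just those steps.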
@@ -948,9 +995,9 @@ public:
if (use_tiny_autoencoder) {
C = 4;
} else {
-if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) {
+if (sd_version_is_sd3(version)) {
C = 32;
-} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
+} else if (sd_version_is_flux(version)) {
C = 32;
}
}
@@ -1035,7 +1082,8 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
enum schedule_t s,
bool keep_clip_on_cpu,
bool keep_control_net_cpu,
-bool keep_vae_on_cpu) {
+bool keep_vae_on_cpu,
+bool diffusion_flash_attn) {
sd_ctx_t* sd_ctx = (sd_ctx_t*)malloc(sizeof(sd_ctx_t));
if (sd_ctx == NULL) {
return NULL;
@@ -1076,7 +1124,8 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
s,
keep_clip_on_cpu,
keep_control_net_cpu,
-keep_vae_on_cpu)) {
+keep_vae_on_cpu,
+diffusion_flash_attn)) {
delete sd_ctx->sd;
sd_ctx->sd = NULL;
free(sd_ctx);
@@ -1111,7 +1160,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
float control_strength,
float style_ratio,
bool normalize_input,
-std::string input_id_images_path) {
+std::string input_id_images_path,
+std::vector<int> skip_layers = {},
+float slg_scale = 2.5,
+float skip_layer_start = 0.01,
+float skip_layer_end = 0.2) {
if (seed < 0) {
// Generally, when using the provided command line, the seed is always >0.
// However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library
@@ -1161,11 +1214,15 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
}
// preprocess input id images
std::vector<sd_image_t*> input_id_images;
+bool pmv2 = sd_ctx->sd->pmid_model->get_version() == PM_VERSION_2;
if (sd_ctx->sd->pmid_model && input_id_images_path.size() > 0) {
std::vector<std::string> img_files = get_files_from_dir(input_id_images_path);
for (std::string img_file : img_files) {
int c = 0;
int width, height;
+if (ends_with(img_file, "safetensors")) {
+continue;
+}
uint8_t* input_image_buffer = stbi_load(img_file.c_str(), &width, &height, &c, 3);
if (input_image_buffer == NULL) {
LOG_ERROR("PhotoMaker load image from '%s' failed", img_file.c_str());
@@ -1203,18 +1260,23 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
else
sd_mul_images_to_tensor(init_image->data, init_img, i, NULL, NULL);
}
-t0 = ggml_time_ms();
-auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx,
-sd_ctx->sd->n_threads, prompt,
-clip_skip,
-width,
-height,
-num_input_images,
-sd_ctx->sd->diffusion_model->get_adm_in_channels());
-id_cond = std::get<0>(cond_tup);
-class_tokens_mask = std::get<1>(cond_tup); //
-
-id_cond.c_crossattn = sd_ctx->sd->id_encoder(work_ctx, init_img, id_cond.c_crossattn, class_tokens_mask);
+t0 = ggml_time_ms();
+auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx,
+sd_ctx->sd->n_threads, prompt,
+clip_skip,
+width,
+height,
+num_input_images,
+sd_ctx->sd->diffusion_model->get_adm_in_channels());
+id_cond = std::get<0>(cond_tup);
+class_tokens_mask = std::get<1>(cond_tup); //
+struct ggml_tensor* id_embeds = NULL;
+if (pmv2) {
+// id_embeds = sd_ctx->sd->pmid_id_embeds->get();
+id_embeds = load_tensor_from_file(work_ctx, path_join(input_id_images_path, "id_embeds.bin"));
+// print_ggml_tensor(id_embeds, true, "id_embeds:");
+}
+id_cond.c_crossattn = sd_ctx->sd->id_encoder(work_ctx, init_img, id_cond.c_crossattn, id_embeds, class_tokens_mask);
t1 = ggml_time_ms();
LOG_INFO("Photomaker ID Stacking, taking %" PRId64 " ms", t1 - t0);
if (sd_ctx->sd->free_params_immediately) {
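A note on the block above: in PhotoMaker v2 mode (`pmv2`), the ID embeddings are not computed in-process; they are loaded from a precomputed `id_embeds.bin` expected to sit inside `input_id_images_path` (see the `path_join` call), and are then forwarded through the widened `id_encoder` signature.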
@@ -1281,9 +1343,9 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
// Sample
std::vector<struct ggml_tensor*> final_latents; // collect latents to decode
int C = 4;
-if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
+if (sd_version_is_sd3(sd_ctx->sd->version)) {
C = 16;
-} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
+} else if (sd_version_is_flux(sd_ctx->sd->version)) {
C = 16;
}
int W = width / 8;
@@ -1320,7 +1382,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
sample_method,
sigmas,
start_merge_step,
-id_cond);
+id_cond,
+skip_layers,
+slg_scale,
+skip_layer_start,
+skip_layer_end);
// struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
// print_ggml_tensor(x_0);
int64_t sampling_end = ggml_time_ms();
@@ -1386,7 +1452,11 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
float control_strength,
float style_ratio,
bool normalize_input,
-const char* input_id_images_path_c_str) {
+const char* input_id_images_path_c_str,
+std::vector<int> skip_layers,
+float slg_scale,
+float skip_layer_start,
+float skip_layer_end) {
LOG_DEBUG("txt2img %dx%d", width, height);
if (sd_ctx == NULL) {
return NULL;
@@ -1394,10 +1464,10 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,

struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
-if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
+if (sd_version_is_sd3(sd_ctx->sd->version)) {
params.mem_size *= 3;
}
-if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
+if (sd_version_is_flux(sd_ctx->sd->version)) {
params.mem_size *= 4;
}
if (sd_ctx->sd->stacked_id) {
@@ -1420,17 +1490,17 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);

int C = 4;
-if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
+if (sd_version_is_sd3(sd_ctx->sd->version)) {
C = 16;
-} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
+} else if (sd_version_is_flux(sd_ctx->sd->version)) {
C = 16;
}
int W = width / 8;
int H = height / 8;
ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
-if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
+if (sd_version_is_sd3(sd_ctx->sd->version)) {
ggml_set_f32(init_latent, 0.0609f);
-} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
+} else if (sd_version_is_flux(sd_ctx->sd->version)) {
ggml_set_f32(init_latent, 0.1159f);
} else {
ggml_set_f32(init_latent, 0.f);
@@ -1454,7 +1524,11 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
control_strength,
style_ratio,
normalize_input,
-input_id_images_path_c_str);
+input_id_images_path_c_str,
+skip_layers,
+slg_scale,
+skip_layer_start,
+skip_layer_end);

size_t t1 = ggml_time_ms();
@@ -1481,7 +1555,11 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
float control_strength,
float style_ratio,
bool normalize_input,
-const char* input_id_images_path_c_str) {
+const char* input_id_images_path_c_str,
+std::vector<int> skip_layers,
+float slg_scale,
+float skip_layer_start,
+float skip_layer_end) {
LOG_DEBUG("img2img %dx%d", width, height);
if (sd_ctx == NULL) {
return NULL;
@@ -1489,10 +1567,10 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,

struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
-if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
+if (sd_version_is_sd3(sd_ctx->sd->version)) {
params.mem_size *= 2;
}
-if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
+if (sd_version_is_flux(sd_ctx->sd->version)) {
params.mem_size *= 3;
}
if (sd_ctx->sd->stacked_id) {
@@ -1555,7 +1633,11 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
control_strength,
style_ratio,
normalize_input,
-input_id_images_path_c_str);
+input_id_images_path_c_str,
+skip_layers,
+slg_scale,
+skip_layer_start,
+skip_layer_end);

size_t t2 = ggml_time_ms();
stable-diffusion.h
@@ -142,7 +142,8 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
enum schedule_t s,
bool keep_clip_on_cpu,
bool keep_control_net_cpu,
-bool keep_vae_on_cpu);
+bool keep_vae_on_cpu,
+bool diffusion_flash_attn);

SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);

@@ -162,7 +163,11 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx,
float control_strength,
float style_strength,
bool normalize_input,
-const char* input_id_images_path);
+const char* input_id_images_path,
+std::vector<int> skip_layers = {},
+float slg_scale = 2.5,
+float skip_layer_start = 0.01,
+float skip_layer_end = 0.2);

SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx,
sd_image_t init_image,
@@ -182,7 +187,11 @@ SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx,
float control_strength,
float style_strength,
bool normalize_input,
-const char* input_id_images_path);
+const char* input_id_images_path,
+std::vector<int> skip_layers = {},
+float slg_scale = 2.5,
+float skip_layer_start = 0.01,
+float skip_layer_end = 0.2);

SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
sd_image_t init_image,
unet.hpp
@@ -183,7 +183,7 @@ public:
int model_channels = 320;
int adm_in_channels = 2816; // only for VERSION_SDXL/SVD

-UnetModelBlock(SDVersion version = VERSION_SD1)
+UnetModelBlock(SDVersion version = VERSION_SD1, bool flash_attn = false)
: version(version) {
if (version == VERSION_SD2) {
context_dim = 1024;
@@ -242,7 +242,7 @@ public:
if (version == VERSION_SVD) {
return new SpatialVideoTransformer(in_channels, n_head, d_head, depth, context_dim);
} else {
-return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim);
+return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim, flash_attn);
}
};

@@ -533,8 +533,9 @@ struct UNetModelRunner : public GGMLRunner {

UNetModelRunner(ggml_backend_t backend,
ggml_type wtype,
-SDVersion version = VERSION_SD1)
-: GGMLRunner(backend, wtype), unet(version) {
+SDVersion version = VERSION_SD1,
+bool flash_attn = false)
+: GGMLRunner(backend, wtype), unet(version, flash_attn) {
unet.init(params_ctx, wtype);
}
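The pattern in these unet.hpp hunks is plain constructor plumbing: the runner accepts `flash_attn` with a default of `false` and forwards it down to the attention blocks it builds. A minimal, self-contained sketch of that pattern (the class names below are illustrative, not the project's actual types):

```
// Innermost block: the flag ultimately selects the attention kernel.
struct AttentionBlock {
    bool flash_attn;
    explicit AttentionBlock(bool flash_attn = false) : flash_attn(flash_attn) {}
};

// Outer runner: accepts the flag with a default and forwards it on
// construction, mirroring UNetModelRunner(backend, wtype, version, flash_attn).
struct Runner {
    AttentionBlock attn;
    explicit Runner(bool flash_attn = false) : attn(flash_attn) {}
};

int main() {
    Runner r(/*flash_attn=*/true);
    return r.attn.flash_attn ? 0 : 1;
}
```

Keeping `flash_attn = false` as the default means every existing call site keeps compiling unchanged, which is presumably why the flag is threaded through constructors rather than a global setting.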
@@ -649,4 +650,4 @@ struct UNetModelRunner : public GGMLRunner {
}
};

-#endif // __UNET_HPP__
\ No newline at end of file
+#endif // __UNET_HPP__
util.cpp
@@ -276,6 +276,23 @@ std::string path_join(const std::string& p1, const std::string& p2) {
return p1 + "/" + p2;
}

+std::vector<std::string> splitString(const std::string& str, char delimiter) {
+std::vector<std::string> result;
+size_t start = 0;
+size_t end = str.find(delimiter);
+
+while (end != std::string::npos) {
+result.push_back(str.substr(start, end - start));
+start = end + 1;
+end = str.find(delimiter, start);
+}
+
+// Add the last segment after the last delimiter
+result.push_back(str.substr(start));
+
+return result;
+}
+
sd_image_t* preprocess_id_image(sd_image_t* img) {
int shortest_edge = 224;
int size = shortest_edge;
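`splitString` is a plain single-character-delimiter scan. A hedged usage sketch (the `parse_skip_layers` wrapper below is hypothetical, but shows how a comma-separated value such as "7,8,9" could be turned into the `skip_layers` vector the new sampling API expects):

```
#include <string>
#include <vector>

std::vector<std::string> splitString(const std::string& str, char delimiter); // declared in util.h

// Hypothetical helper: parse "7,8,9" into {7, 8, 9}.
std::vector<int> parse_skip_layers(const std::string& arg) {
    std::vector<int> layers;
    for (const std::string& part : splitString(arg, ',')) {
        if (!part.empty()) {
            layers.push_back(std::stoi(part)); // throws std::invalid_argument on malformed input
        }
    }
    return layers;
}
```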
util.h
@@ -45,7 +45,7 @@ sd_image_f32_t resize_sd_image_f32_t(sd_image_f32_t image, int target_width, int
sd_image_f32_t clip_preprocess(sd_image_f32_t image, int size);

std::string path_join(const std::string& p1, const std::string& p2);

+std::vector<std::string> splitString(const std::string& str, char delimiter);
void pretty_progress(int step, int steps, float time);

void log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...);
vae.hpp
@@ -99,10 +99,12 @@ public:
k = ggml_cont(ctx, ggml_permute(ctx, k, 1, 2, 0, 3)); // [N, h, w, in_channels]
k = ggml_reshape_3d(ctx, k, c, h * w, n); // [N, h * w, in_channels]

-auto v = v_proj->forward(ctx, h_); // [N, in_channels, h, w]
-v = ggml_reshape_3d(ctx, v, h * w, c, n); // [N, in_channels, h * w]
+auto v = v_proj->forward(ctx, h_); // [N, in_channels, h, w]
+v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, h, w, in_channels]
+v = ggml_reshape_3d(ctx, v, c, h * w, n); // [N, h * w, in_channels]

-h_ = ggml_nn_attention(ctx, q, k, v, false); // [N, h * w, in_channels]
+// h_ = ggml_nn_attention(ctx, q, k, v, false); // [N, h * w, in_channels]
+h_ = ggml_nn_attention_ext(ctx, q, k, v, 1, nullptr, false, true, false);

h_ = ggml_cont(ctx, ggml_permute(ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w]
h_ = ggml_reshape_4d(ctx, h_, w, h, c, n); // [N, in_channels, h, w]
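The permute/reshape steps above bring `q`, `k` and `v` into the same `[N, h * w, in_channels]` token layout, so the VAE's self-attention can go through the shared `ggml_nn_attention_ext` helper instead of the old VAE-specific path. Over the n = h * w spatial tokens with c channels this is ordinary scaled dot-product attention:

```
\mathrm{Attn}(Q, K, V) = \mathrm{softmax}\!\left( \frac{Q K^{\top}}{\sqrt{c}} \right) V,
\qquad Q, K, V \in \mathbb{R}^{(h \cdot w) \times c}
```

The intermediate score matrix is (h * w) x (h * w), already 4096 x 4096 for a 64 x 64 latent, which is presumably why routing the VAE through the common attention helper (and whatever memory-efficient paths it offers) is worthwhile.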
@@ -457,7 +459,7 @@ public:
bool use_video_decoder = false,
SDVersion version = VERSION_SD1)
: decode_only(decode_only), use_video_decoder(use_video_decoder) {
-if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
+if (sd_version_is_dit(version)) {
dd_config.z_channels = 16;
use_quant = false;
}
@@ -612,4 +614,4 @@ struct AutoEncoderKL : public GGMLRunner {
};
};

-#endif
\ No newline at end of file
+#endif