Compare commits

...

8 Commits

Author | SHA1 | Message | Date
leejet | dcf91f9e0f | chore: change SD_CUBLAS/SD_USE_CUBLAS to SD_CUDA/SD_USE_CUDA | 2024-12-28 13:27:51 +08:00
stduhpf | 348a54e34a | feat: use pretty-progress for tensor loading (#516) | 2024-12-28 13:14:52 +08:00
stduhpf | d50473dc49 | feat: support 16 channel tae (taesd/taef1) (#527) | 2024-12-28 13:13:48 +08:00
piallai | b5cc1422da | fix: fix typo for skip layers parameters (#492) | 2024-12-28 13:12:08 +08:00
R0CKSTAR | 5cc74d1f09 | feat: support Moore Threads GPU (#529) (Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>) | 2024-12-28 13:08:36 +08:00
stduhpf | 0d9d6659a7 | fix: fix metal build (#513) | 2024-12-28 13:06:17 +08:00
stduhpf | 8f4ab9add3 | feat: support Inpaint models (#511) | 2024-12-28 13:04:49 +08:00
stduhpf | cc92a6a1b3 | feat: support more LoRA models (#520) | 2024-12-28 12:56:44 +08:00
19 changed files with 1026 additions and 177 deletions


@@ -163,7 +163,7 @@ jobs:
          - build: "avx512"
            defines: "-DGGML_AVX512=ON -DSD_BUILD_SHARED_LIBS=ON"
          - build: "cuda12"
-           defines: "-DSD_CUBLAS=ON -DSD_BUILD_SHARED_LIBS=ON"
+           defines: "-DSD_CUDA=ON -DSD_BUILD_SHARED_LIBS=ON"
          # - build: "rocm5.5"
          #   defines: '-G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DSD_HIPBLAS=ON -DCMAKE_BUILD_TYPE=Release -DAMDGPU_TARGETS="gfx1100;gfx1102;gfx1030" -DSD_BUILD_SHARED_LIBS=ON'
          - build: 'vulkan'


@@ -24,19 +24,20 @@ endif()
# general
#option(SD_BUILD_TESTS "sd: build tests" ${SD_STANDALONE})
option(SD_BUILD_EXAMPLES "sd: build examples" ${SD_STANDALONE})
-option(SD_CUBLAS "sd: cuda backend" OFF)
+option(SD_CUDA "sd: cuda backend" OFF)
option(SD_HIPBLAS "sd: rocm backend" OFF)
option(SD_METAL "sd: metal backend" OFF)
option(SD_VULKAN "sd: vulkan backend" OFF)
option(SD_SYCL "sd: sycl backend" OFF)
option(SD_MUSA "sd: musa backend" OFF)
option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, indeterministic (sometimes, same seed don't generate same image), cuda only" OFF)
option(SD_BUILD_SHARED_LIBS "sd: build shared libs" OFF)
#option(SD_BUILD_SERVER "sd: build server example" ON)
-if(SD_CUBLAS)
-    message("-- Use CUBLAS as backend stable-diffusion")
+if(SD_CUDA)
+    message("-- Use CUDA as backend stable-diffusion")
    set(GGML_CUDA ON)
-    add_definitions(-DSD_USE_CUBLAS)
+    add_definitions(-DSD_USE_CUDA)
endif()
if(SD_METAL)
@@ -54,15 +55,24 @@ endif ()
if (SD_HIPBLAS)
    message("-- Use HIPBLAS as backend stable-diffusion")
    set(GGML_HIPBLAS ON)
-    add_definitions(-DSD_USE_CUBLAS)
+    add_definitions(-DSD_USE_CUDA)
    if(SD_FAST_SOFTMAX)
        set(GGML_CUDA_FAST_SOFTMAX ON)
    endif()
endif ()
if(SD_MUSA)
    message("-- Use MUSA as backend stable-diffusion")
    set(GGML_MUSA ON)
    add_definitions(-DSD_USE_CUBLAS)
    if(SD_FAST_SOFTMAX)
        set(GGML_CUDA_FAST_SOFTMAX ON)
    endif()
endif()
set(SD_LIB stable-diffusion)
file(GLOB SD_LIB_SOURCES
    "*.h"
    "*.cpp"
    "*.hpp"

Dockerfile.musa (new file, 19 lines)

@@ -0,0 +1,19 @@
ARG MUSA_VERSION=rc3.1.0
FROM mthreads/musa:${MUSA_VERSION}-devel-ubuntu22.04 as build
RUN apt-get update && apt-get install -y cmake
WORKDIR /sd.cpp
COPY . .
RUN mkdir build && cd build && \
cmake .. -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DSD_MUSA=ON -DCMAKE_BUILD_TYPE=Release && \
cmake --build . --config Release
FROM mthreads/musa:${MUSA_VERSION}-runtime-ubuntu22.04 as runtime
COPY --from=build /sd.cpp/build/bin/sd /sd
ENTRYPOINT [ "/sd" ]


@@ -113,12 +113,12 @@ cmake .. -DGGML_OPENBLAS=ON
cmake --build . --config Release
```
-##### Using CUBLAS
+##### Using CUDA
This provides BLAS acceleration using the CUDA cores of your Nvidia GPU. Make sure to have the CUDA toolkit installed. You can download it from your Linux distro's package manager (e.g. `apt install nvidia-cuda-toolkit`) or from here: [CUDA Toolkit](https://developer.nvidia.com/cuda-downloads). Recommended to have at least 4 GB of VRAM.
```
-cmake .. -DSD_CUBLAS=ON
+cmake .. -DSD_CUDA=ON
cmake --build . --config Release
```
@@ -132,6 +132,14 @@ cmake .. -G "Ninja" -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DSD_H
cmake --build . --config Release
```
##### Using MUSA
This provides BLAS acceleration using the MUSA cores of your Moore Threads GPU. Make sure to have the MUSA toolkit installed.
```bash
cmake .. -DCMAKE_C_COMPILER=/usr/local/musa/bin/clang -DCMAKE_CXX_COMPILER=/usr/local/musa/bin/clang++ -DSD_MUSA=ON -DCMAKE_BUILD_TYPE=Release
cmake --build . --config Release
```
##### Using Metal
@@ -232,6 +240,10 @@ arguments:
-p, --prompt [PROMPT] the prompt to render
-n, --negative-prompt PROMPT the negative prompt (default: "")
--cfg-scale SCALE unconditional guidance scale: (default: 7.0)
--skip-layers LAYERS Layers to skip for SLG steps: (default: [7,8,9])
--skip-layer-start START SLG enabling point: (default: 0.01)
--skip-layer-end END SLG disabling point: (default: 0.2)
SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END])
--strength STRENGTH strength for noising/unnoising (default: 0.75)
--style-ratio STYLE-RATIO strength for keeping input identity (default: 20%)
--control-strength STRENGTH strength to apply Control Net (default: 0.9)
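The --skip-layers, --skip-layer-start and --skip-layer-end options documented above drive skip layer guidance (SLG). As a quick standalone illustration of the step formula quoted in the help text (this is not code from the repository, and the step count is a made-up value), the start/end fractions map onto sampling steps like this:

```cpp
// Minimal sketch of the SLG step window: enabled at int([STEPS]*[START]),
// disabled at int([STEPS]*[END]), per the help text above.
#include <cstdio>

int main() {
    const int   steps            = 20;    // hypothetical --steps value
    const float skip_layer_start = 0.01f; // default --skip-layer-start
    const float skip_layer_end   = 0.2f;  // default --skip-layer-end

    const int enable_step  = static_cast<int>(steps * skip_layer_start); // 0
    const int disable_step = static_cast<int>(steps * skip_layer_end);   // 4

    // With the default [7,8,9], those layers are skipped for the guidance
    // pass on steps 0..3 in this example.
    std::printf("SLG active for steps [%d, %d)\n", enable_step, disable_step);
    return 0;
}
```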


@@ -61,18 +61,18 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
SDVersion version = VERSION_SD1,
PMVersion pv = PM_VERSION_1,
int clip_skip = -1)
-: version(version), pm_version(pv), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir) {
+: version(version), pm_version(pv), tokenizer(sd_version_is_sd2(version) ? 0 : 49407), embd_dir(embd_dir) {
if (clip_skip <= 0) {
clip_skip = 1;
-if (version == VERSION_SD2 || version == VERSION_SDXL) {
+if (sd_version_is_sd2(version) || sd_version_is_sdxl(version)) {
clip_skip = 2;
}
}
-if (version == VERSION_SD1) {
+if (sd_version_is_sd1(version)) {
text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip);
-} else if (version == VERSION_SD2) {
+} else if (sd_version_is_sd2(version)) {
text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14, clip_skip);
-} else if (version == VERSION_SDXL) {
+} else if (sd_version_is_sdxl(version)) {
text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, false);
text_model2 = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
}
@@ -80,35 +80,35 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
void set_clip_skip(int clip_skip) {
text_model->set_clip_skip(clip_skip);
-if (version == VERSION_SDXL) {
+if (sd_version_is_sdxl(version)) {
text_model2->set_clip_skip(clip_skip);
}
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
text_model->get_param_tensors(tensors, "cond_stage_model.transformer.text_model");
-if (version == VERSION_SDXL) {
+if (sd_version_is_sdxl(version)) {
text_model2->get_param_tensors(tensors, "cond_stage_model.1.transformer.text_model");
}
}
void alloc_params_buffer() {
text_model->alloc_params_buffer();
-if (version == VERSION_SDXL) {
+if (sd_version_is_sdxl(version)) {
text_model2->alloc_params_buffer();
}
}
void free_params_buffer() {
text_model->free_params_buffer();
-if (version == VERSION_SDXL) {
+if (sd_version_is_sdxl(version)) {
text_model2->free_params_buffer();
}
}
size_t get_params_buffer_size() {
size_t buffer_size = text_model->get_params_buffer_size();
-if (version == VERSION_SDXL) {
+if (sd_version_is_sdxl(version)) {
buffer_size += text_model2->get_params_buffer_size();
}
return buffer_size;
@@ -402,7 +402,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
struct ggml_tensor* input_ids2 = NULL;
size_t max_token_idx = 0;
-if (version == VERSION_SDXL) {
+if (sd_version_is_sdxl(version)) {
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), tokenizer.EOS_TOKEN_ID);
if (it != chunk_tokens.end()) {
std::fill(std::next(it), chunk_tokens.end(), 0);
@@ -427,7 +427,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
false,
&chunk_hidden_states1,
work_ctx);
-if (version == VERSION_SDXL) {
+if (sd_version_is_sdxl(version)) {
text_model2->compute(n_threads,
input_ids2,
0,
@@ -486,7 +486,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
ggml_tensor* vec = NULL;
-if (version == VERSION_SDXL) {
+if (sd_version_is_sdxl(version)) {
int out_dim = 256;
vec = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, adm_in_channels);
// [0:1280]


@@ -34,11 +34,11 @@ public:
ControlNetBlock(SDVersion version = VERSION_SD1)
: version(version) {
-if (version == VERSION_SD2) {
+if (sd_version_is_sd2(version)) {
context_dim = 1024;
num_head_channels = 64;
num_heads = -1;
-} else if (version == VERSION_SDXL) {
+} else if (sd_version_is_sdxl(version)) {
context_dim = 2048;
attention_resolutions = {4, 2};
channel_mult = {1, 2, 4};
@@ -58,7 +58,7 @@ public:
// time_embed_1 is nn.SiLU()
blocks["time_embed.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
-if (version == VERSION_SDXL || version == VERSION_SVD) {
+if (sd_version_is_sdxl(version) || version == VERSION_SVD) {
blocks["label_emb.0.0"] = std::shared_ptr<GGMLBlock>(new Linear(adm_in_channels, time_embed_dim));
// label_emb_1 is nn.SiLU()
blocks["label_emb.0.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));


@@ -133,8 +133,9 @@ struct FluxModel : public DiffusionModel {
FluxModel(ggml_backend_t backend,
std::map<std::string, enum ggml_type>& tensor_types,
-bool flash_attn = false)
-: flux(backend, tensor_types, "model.diffusion_model", flash_attn) {
+SDVersion version = VERSION_FLUX,
+bool flash_attn = false)
+: flux(backend, tensor_types, "model.diffusion_model", version, flash_attn) {
}
void alloc_params_buffer() {
@@ -174,7 +175,7 @@ struct FluxModel : public DiffusionModel {
struct ggml_tensor** output = NULL,
struct ggml_context* output_ctx = NULL,
std::vector<int> skip_layers = std::vector<int>()) {
-return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx, skip_layers);
+return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, output, output_ctx, skip_layers);
}
};


@@ -85,6 +85,7 @@ struct SDParams {
std::string lora_model_dir;
std::string output_path = "output.png";
std::string input_path;
std::string mask_path;
std::string control_image_path;
std::string prompt;
@@ -148,6 +149,7 @@ void print_params(SDParams params) {
printf(" normalize input image : %s\n", params.normalize_input ? "true" : "false");
printf(" output_path: %s\n", params.output_path.c_str());
printf(" init_img: %s\n", params.input_path.c_str());
printf(" mask_img: %s\n", params.mask_path.c_str());
printf(" control_image: %s\n", params.control_image_path.c_str());
printf(" clip on cpu: %s\n", params.clip_on_cpu ? "true" : "false");
printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false");
@@ -207,9 +209,9 @@ void print_usage(int argc, const char* argv[]) {
printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n");
printf(" --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n");
printf(" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n");
-printf(" --skip_layers LAYERS Layers to skip for SLG steps: (default: [7,8,9])\n");
-printf(" --skip_layer_start START SLG enabling point: (default: 0.01)\n");
-printf(" --skip_layer_end END SLG disabling point: (default: 0.2)\n");
+printf(" --skip-layers LAYERS Layers to skip for SLG steps: (default: [7,8,9])\n");
+printf(" --skip-layer-start START SLG enabling point: (default: 0.01)\n");
+printf(" --skip-layer-end END SLG disabling point: (default: 0.2)\n");
printf(" SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END])\n");
printf(" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n");
printf(" --style-ratio STYLE-RATIO strength for keeping input identity (default: 20%%)\n");
@@ -384,6 +386,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
break;
}
params.input_path = argv[i];
} else if (arg == "--mask") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.mask_path = argv[i];
} else if (arg == "--control-image") {
if (++i >= argc) {
invalid_arg = true;
@@ -803,6 +811,8 @@ int main(int argc, const char* argv[]) {
bool vae_decode_only = true;
uint8_t* input_image_buffer = NULL;
uint8_t* control_image_buffer = NULL;
uint8_t* mask_image_buffer = NULL;
if (params.mode == IMG2IMG || params.mode == IMG2VID) {
vae_decode_only = false;
@@ -907,6 +917,18 @@ int main(int argc, const char* argv[]) {
}
}
if (params.mask_path != "") {
int c = 0;
mask_image_buffer = stbi_load(params.mask_path.c_str(), &params.width, &params.height, &c, 1);
} else {
std::vector<uint8_t> arr(params.width * params.height, 255);
mask_image_buffer = arr.data();
}
sd_image_t mask_image = {(uint32_t)params.width,
(uint32_t)params.height,
1,
mask_image_buffer};
sd_image_t* results;
if (params.mode == TXT2IMG) {
results = txt2img(sd_ctx,
@@ -976,6 +998,7 @@ int main(int argc, const char* argv[]) {
} else {
results = img2img(sd_ctx,
input_image,
mask_image,
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,


@@ -490,6 +490,7 @@ namespace Flux {
struct FluxParams {
int64_t in_channels = 64;
int64_t out_channels = 64;
int64_t vec_in_dim = 768;
int64_t context_in_dim = 4096;
int64_t hidden_size = 3072;
@@ -642,8 +643,7 @@ namespace Flux {
Flux() {}
Flux(FluxParams params)
: params(params) {
-int64_t out_channels = params.in_channels;
int64_t pe_dim = params.hidden_size / params.num_heads;
blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true));
blocks["time_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size));
@@ -669,7 +669,7 @@ namespace Flux {
params.flash_attn));
}
-blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, out_channels));
+blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, params.out_channels));
}
struct ggml_tensor* patchify(struct ggml_context* ctx,
@@ -789,6 +789,7 @@ namespace Flux {
struct ggml_tensor* x,
struct ggml_tensor* timestep,
struct ggml_tensor* context,
struct ggml_tensor* c_concat,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
struct ggml_tensor* pe,
@@ -797,6 +798,7 @@ namespace Flux {
// x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
// timestep: (N,) tensor of diffusion timesteps
// context: (N, L, D)
// c_concat: NULL, or for (N,C+M, H, W) for Fill
// y: (N, adm_in_channels) tensor of class labels
// guidance: (N,)
// pe: (L, d_head/2, 2, 2)
@@ -806,6 +808,7 @@
int64_t W = x->ne[0];
int64_t H = x->ne[1];
int64_t C = x->ne[2];
int64_t patch_size = 2;
int pad_h = (patch_size - H % patch_size) % patch_size;
int pad_w = (patch_size - W % patch_size) % patch_size;
@@ -814,6 +817,19 @@
// img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size]
if (c_concat != NULL) {
ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0);
mask = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0);
masked = patchify(ctx, masked, patch_size);
mask = patchify(ctx, mask, patch_size);
img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0);
}
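A quick way to sanity-check the channel bookkeeping in the branch above: after the 2x2 patchify, the regular latent tokens, the masked-image view and the 8x8 mask view are concatenated along the feature axis. The standalone sketch below (not code from the repository; it assumes Flux's 16 latent channels) reproduces the arithmetic behind the 384 input channels that the Fill variant configures later in this file:

```cpp
// Illustrative check of the Fill model's 384 input channels,
// assuming 16 latent channels and the 2x2 patchify used in Flux::forward.
#include <cstdio>

int main() {
    const int latent_channels = 16;    // assumed Flux latent channel count
    const int patch_area      = 2 * 2; // patchify packs a 2x2 patch per token
    const int mask_channels   = 8 * 8; // the mask view above uses 8*8 channels

    const int img    = latent_channels * patch_area; // 64: regular latent tokens
    const int masked = latent_channels * patch_area; // 64: masked-image part of c_concat
    const int mask   = mask_channels * patch_area;   // 256: mask part of c_concat

    std::printf("in_channels = %d\n", img + masked + mask); // 384
    return 0;
}
```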
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, h*w, C * patch_size * patch_size]
// rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
@@ -834,12 +850,16 @@
FluxRunner(ggml_backend_t backend,
std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
const std::string prefix = "",
SDVersion version = VERSION_FLUX,
bool flash_attn = false)
: GGMLRunner(backend) {
flux_params.flash_attn = flash_attn;
flux_params.guidance_embed = false;
flux_params.depth = 0;
flux_params.depth_single_blocks = 0;
if (version == VERSION_FLUX_FILL) {
flux_params.in_channels = 384;
}
for (auto pair : tensor_types) {
std::string tensor_name = pair.first;
if (tensor_name.find("model.diffusion_model.") == std::string::npos)
@@ -886,14 +906,18 @@
struct ggml_cgraph* build_graph(struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* c_concat,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
std::vector<int> skip_layers = std::vector<int>()) {
GGML_ASSERT(x->ne[3] == 1);
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);
x = to_backend(x);
context = to_backend(context);
if (c_concat != NULL) {
c_concat = to_backend(c_concat);
}
y = to_backend(y);
timesteps = to_backend(timesteps);
if (flux_params.guidance_embed) {
@@ -913,6 +937,7 @@
x,
timesteps,
context,
c_concat,
y,
guidance,
pe,
@@ -927,6 +952,7 @@
struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* c_concat,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
struct ggml_tensor** output = NULL,
@@ -938,7 +964,7 @@
// y: [N, adm_in_channels] or [1, adm_in_channels]
// guidance: [N, ]
auto get_graph = [&]() -> struct ggml_cgraph* {
-return build_graph(x, timesteps, context, y, guidance, skip_layers);
+return build_graph(x, timesteps, context, c_concat, y, guidance, skip_layers);
};
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@@ -978,7 +1004,7 @@
struct ggml_tensor* out = NULL;
int t0 = ggml_time_ms();
-compute(8, x, timesteps, context, y, guidance, &out, work_ctx);
+compute(8, x, timesteps, context, NULL, y, guidance, &out, work_ctx);
int t1 = ggml_time_ms();
print_ggml_tensor(out);


@@ -27,7 +27,7 @@
#include "model.h"
-#ifdef SD_USE_CUBLAS
+#ifdef SD_USE_CUDA
#include "ggml-cuda.h"
#endif
@@ -290,6 +290,42 @@ __STATIC_INLINE__ void sd_image_to_tensor(const uint8_t* image_data,
}
}
__STATIC_INLINE__ void sd_mask_to_tensor(const uint8_t* image_data,
struct ggml_tensor* output,
bool scale = true) {
int64_t width = output->ne[0];
int64_t height = output->ne[1];
int64_t channels = output->ne[2];
GGML_ASSERT(channels == 1 && output->type == GGML_TYPE_F32);
for (int iy = 0; iy < height; iy++) {
for (int ix = 0; ix < width; ix++) {
float value = *(image_data + iy * width * channels + ix);
if (scale) {
value /= 255.f;
}
ggml_tensor_set_f32(output, value, ix, iy);
}
}
}
__STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
struct ggml_tensor* mask,
struct ggml_tensor* output) {
int64_t width = output->ne[0];
int64_t height = output->ne[1];
int64_t channels = output->ne[2];
GGML_ASSERT(output->type == GGML_TYPE_F32);
for (int ix = 0; ix < width; ix++) {
for (int iy = 0; iy < height; iy++) {
float m = ggml_tensor_get_f32(mask, ix, iy);
for (int k = 0; k < channels; k++) {
float value = ((float)(m < 254.5/255)) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5;
ggml_tensor_set_f32(output, value, ix, iy, k);
}
}
}
}
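The helper above encodes the inpainting convention used by the new mask support: mask pixels at (or very near) pure white flatten the image value to 0.5, marking the region to repaint, while all other pixels pass through unchanged. A standalone sketch of that per-pixel rule (illustrative only, not part of the repository):

```cpp
// Per-pixel rule from sd_apply_mask: value = (m < 254.5/255) * (pixel - 0.5) + 0.5
#include <cstdio>

static float apply_mask_pixel(float pixel, float mask) {
    const float keep = (mask < 254.5f / 255.0f) ? 1.0f : 0.0f;
    return keep * (pixel - 0.5f) + 0.5f;
}

int main() {
    std::printf("%.2f\n", apply_mask_pixel(0.9f, 1.0f)); // 0.50: white mask -> region to repaint
    std::printf("%.2f\n", apply_mask_pixel(0.9f, 0.0f)); // 0.90: black mask -> pixel kept
    return 0;
}
```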
__STATIC_INLINE__ void sd_mul_images_to_tensor(const uint8_t* image_data,
struct ggml_tensor* output,
int idx,
@@ -672,7 +708,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx
struct ggml_tensor* k,
struct ggml_tensor* v,
bool mask = false) {
-#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL) && !defined(SD_USE_VULKAN) && !defined(SD_USE_SYCL)
+#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUDA) && !defined(SD_USE_METAL) && !defined(SD_USE_VULKAN) && !defined(SD_USE_SYCL)
struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, n_token, d_head]
#else
float d_head = (float)q->ne[0];
@@ -828,7 +864,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_group_norm(struct ggml_context* ct
}
__STATIC_INLINE__ void ggml_backend_tensor_get_and_sync(ggml_backend_t backend, const struct ggml_tensor* tensor, void* data, size_t offset, size_t size) {
-#if defined(SD_USE_CUBLAS) || defined(SD_USE_SYCL)
+#if defined(SD_USE_CUDA) || defined(SD_USE_SYCL)
if (!ggml_backend_is_cpu(backend)) {
ggml_backend_tensor_get_async(backend, tensor, data, offset, size);
ggml_backend_synchronize(backend);
@@ -1138,13 +1174,7 @@ public:
ggml_backend_cpu_set_n_threads(backend, n_threads);
}
-#ifdef SD_USE_METAL
-if (ggml_backend_is_metal(backend)) {
-ggml_backend_metal_set_n_cb(backend, n_threads);
-}
-#endif
ggml_backend_graph_compute(backend, gf);
#ifdef GGML_PERF
ggml_graph_print(gf);
#endif

lora.hpp (632 changed lines)

@@ -6,6 +6,90 @@
#define LORA_GRAPH_SIZE 10240
struct LoraModel : public GGMLRunner {
enum lora_t {
REGULAR = 0,
DIFFUSERS = 1,
DIFFUSERS_2 = 2,
DIFFUSERS_3 = 3,
TRANSFORMERS = 4,
LORA_TYPE_COUNT
};
const std::string lora_ups[LORA_TYPE_COUNT] = {
".lora_up",
"_lora.up",
".lora_B",
".lora.up",
".lora_linear_layer.up",
};
const std::string lora_downs[LORA_TYPE_COUNT] = {
".lora_down",
"_lora.down",
".lora_A",
".lora.down",
".lora_linear_layer.down",
};
const std::string lora_pre[LORA_TYPE_COUNT] = {
"lora.",
"",
"",
"",
"",
};
const std::map<std::string, std::string> alt_names = {
// mmdit
{"final_layer.adaLN_modulation.1", "norm_out.linear"},
{"pos_embed", "pos_embed.proj"},
{"final_layer.linear", "proj_out"},
{"y_embedder.mlp.0", "time_text_embed.text_embedder.linear_1"},
{"y_embedder.mlp.2", "time_text_embed.text_embedder.linear_2"},
{"t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1"},
{"t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2"},
{"x_block.mlp.fc1", "ff.net.0.proj"},
{"x_block.mlp.fc2", "ff.net.2"},
{"context_block.mlp.fc1", "ff_context.net.0.proj"},
{"context_block.mlp.fc2", "ff_context.net.2"},
{"x_block.adaLN_modulation.1", "norm1.linear"},
{"context_block.adaLN_modulation.1", "norm1_context.linear"},
{"context_block.attn.proj", "attn.to_add_out"},
{"x_block.attn.proj", "attn.to_out.0"},
{"x_block.attn2.proj", "attn2.to_out.0"},
// flux
// singlestream
{"linear2", "proj_out"},
{"modulation.lin", "norm.linear"},
// doublestream
{"txt_attn.proj", "attn.to_add_out"},
{"img_attn.proj", "attn.to_out.0"},
{"txt_mlp.0", "ff_context.net.0.proj"},
{"txt_mlp.2", "ff_context.net.2"},
{"img_mlp.0", "ff.net.0.proj"},
{"img_mlp.2", "ff.net.2"},
{"txt_mod.lin", "norm1_context.linear"},
{"img_mod.lin", "norm1.linear"},
};
const std::map<std::string, std::string> qkv_prefixes = {
// mmdit
{"context_block.attn.qkv", "attn.add_"}, // suffix "_proj"
{"x_block.attn.qkv", "attn.to_"},
{"x_block.attn2.qkv", "attn2.to_"},
// flux
// doublestream
{"txt_attn.qkv", "attn.add_"}, // suffix "_proj"
{"img_attn.qkv", "attn.to_"},
};
const std::map<std::string, std::string> qkvm_prefixes = {
// flux
// singlestream
{"linear1", ""},
};
const std::string* type_fingerprints = lora_ups;
float multiplier = 1.0f;
std::map<std::string, struct ggml_tensor*> lora_tensors;
std::string file_path;
@@ -14,6 +98,7 @@ struct LoraModel : public GGMLRunner {
bool applied = false;
std::vector<int> zero_index_vec = {0};
ggml_tensor* zero_index = NULL;
enum lora_t type = REGULAR;
LoraModel(ggml_backend_t backend,
const std::string& file_path = "",
@@ -44,6 +129,13 @@ struct LoraModel : public GGMLRunner {
// LOG_INFO("skipping LoRA tesnor '%s'", name.c_str());
return true;
}
// LOG_INFO("%s", name.c_str());
for (int i = 0; i < LORA_TYPE_COUNT; i++) {
if (name.find(type_fingerprints[i]) != std::string::npos) {
type = (lora_t)i;
break;
}
}
if (dry_run) {
struct ggml_tensor* real = ggml_new_tensor(params_ctx,
@@ -61,10 +153,12 @@ struct LoraModel : public GGMLRunner {
model_loader.load_tensors(on_new_tensor_cb, backend);
alloc_params_buffer();
// exit(0);
dry_run = false;
model_loader.load_tensors(on_new_tensor_cb, backend);
LOG_DEBUG("lora type: \"%s\"/\"%s\"", lora_downs[type].c_str(), lora_ups[type].c_str());
LOG_DEBUG("finished loaded lora");
return true;
}
@@ -76,7 +170,66 @@ struct LoraModel : public GGMLRunner {
return out;
}
-struct ggml_cgraph* build_lora_graph(std::map<std::string, struct ggml_tensor*> model_tensors) {
+std::vector<std::string> to_lora_keys(std::string blk_name, SDVersion version) {
std::vector<std::string> keys;
// if (!sd_version_is_sd3(version) || blk_name != "model.diffusion_model.pos_embed") {
size_t k_pos = blk_name.find(".weight");
if (k_pos == std::string::npos) {
return keys;
}
blk_name = blk_name.substr(0, k_pos);
// }
keys.push_back(blk_name);
keys.push_back("lora." + blk_name);
if (sd_version_is_dit(version)) {
if (blk_name.find("model.diffusion_model") != std::string::npos) {
blk_name.replace(blk_name.find("model.diffusion_model"), sizeof("model.diffusion_model") - 1, "transformer");
}
if (blk_name.find(".single_blocks") != std::string::npos) {
blk_name.replace(blk_name.find(".single_blocks"), sizeof(".single_blocks") - 1, ".single_transformer_blocks");
}
if (blk_name.find(".double_blocks") != std::string::npos) {
blk_name.replace(blk_name.find(".double_blocks"), sizeof(".double_blocks") - 1, ".transformer_blocks");
}
if (blk_name.find(".joint_blocks") != std::string::npos) {
blk_name.replace(blk_name.find(".joint_blocks"), sizeof(".joint_blocks") - 1, ".transformer_blocks");
}
for (const auto& item : alt_names) {
size_t match = blk_name.find(item.first);
if (match != std::string::npos) {
blk_name = blk_name.substr(0, match) + item.second;
}
}
for (const auto& prefix : qkv_prefixes) {
size_t match = blk_name.find(prefix.first);
if (match != std::string::npos) {
std::string split_blk = "SPLIT|" + blk_name.substr(0, match) + prefix.second;
keys.push_back(split_blk);
}
}
for (const auto& prefix : qkvm_prefixes) {
size_t match = blk_name.find(prefix.first);
if (match != std::string::npos) {
std::string split_blk = "SPLIT_L|" + blk_name.substr(0, match) + prefix.second;
keys.push_back(split_blk);
}
}
}
keys.push_back(blk_name);
std::vector<std::string> ret;
for (std::string& key : keys) {
ret.push_back(key);
replace_all_chars(key, '.', '_');
ret.push_back(key);
}
return ret;
}
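To make the key rewriting above concrete: for DiT models, to_lora_keys strips the ".weight" suffix, maps the checkpoint prefixes to their diffusers-style counterparts (including the alt_names table), and later also emits dot- and underscore-separated variants. The standalone sketch below traces one hypothetical Flux tensor name through roughly the same substitutions (simplified; not code from the repository):

```cpp
// Rough illustration of the name mapping performed by to_lora_keys above.
#include <cstdio>
#include <string>

static void replace_once(std::string& s, const std::string& from, const std::string& to) {
    size_t pos = s.find(from);
    if (pos != std::string::npos)
        s.replace(pos, from.size(), to);
}

int main() {
    std::string name = "model.diffusion_model.double_blocks.7.img_attn.proj.weight"; // hypothetical
    name = name.substr(0, name.find(".weight"));           // strip ".weight"
    replace_once(name, "model.diffusion_model", "transformer");
    replace_once(name, ".double_blocks", ".transformer_blocks");
    replace_once(name, "img_attn.proj", "attn.to_out.0");  // one alt_names entry
    std::printf("%s\n", name.c_str()); // transformer.transformer_blocks.7.attn.to_out.0
    return 0;
}
```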
struct ggml_cgraph* build_lora_graph(std::map<std::string, struct ggml_tensor*> model_tensors, SDVersion version) {
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, LORA_GRAPH_SIZE, false);
zero_index = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_I32, 1);
@@ -88,91 +241,416 @@ struct LoraModel : public GGMLRunner {
std::string k_tensor = it.first;
struct ggml_tensor* weight = model_tensors[it.first];
-size_t k_pos = k_tensor.find(".weight");
-if (k_pos == std::string::npos) {
+std::vector<std::string> keys = to_lora_keys(k_tensor, version);
+if (keys.size() == 0)
continue;
-}
-k_tensor = k_tensor.substr(0, k_pos);
-replace_all_chars(k_tensor, '.', '_');
-// LOG_DEBUG("k_tensor %s", k_tensor.c_str());
-std::string lora_up_name = "lora." + k_tensor + ".lora_up.weight";
-if (lora_tensors.find(lora_up_name) == lora_tensors.end()) {
-if (k_tensor == "model_diffusion_model_output_blocks_2_2_conv") {
-// fix for some sdxl lora, like lcm-lora-xl
-k_tensor = "model_diffusion_model_output_blocks_2_1_conv";
-lora_up_name = "lora." + k_tensor + ".lora_up.weight";
-}
-}
-std::string lora_down_name = "lora." + k_tensor + ".lora_down.weight";
-std::string alpha_name = "lora." + k_tensor + ".alpha";
-std::string scale_name = "lora." + k_tensor + ".scale";
ggml_tensor* lora_up = NULL;
ggml_tensor* lora_down = NULL;
for (auto& key : keys) {
std::string alpha_name = "";
std::string scale_name = "";
std::string split_q_scale_name = "";
std::string lora_down_name = "";
std::string lora_up_name = "";
-if (lora_tensors.find(lora_up_name) != lora_tensors.end()) {
-lora_up = lora_tensors[lora_up_name];
+if (starts_with(key, "SPLIT|")) {
+key = key.substr(sizeof("SPLIT|") - 1);
// TODO: Handle alphas
std::string suffix = "";
auto split_q_d_name = lora_pre[type] + key + "q" + suffix + lora_downs[type] + ".weight";
if (lora_tensors.find(split_q_d_name) == lora_tensors.end()) {
suffix = "_proj";
split_q_d_name = lora_pre[type] + key + "q" + suffix + lora_downs[type] + ".weight";
}
if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) {
// print_ggml_tensor(it.second, true); //[3072, 21504, 1, 1]
// find qkv and mlp up parts in LoRA model
auto split_k_d_name = lora_pre[type] + key + "k" + suffix + lora_downs[type] + ".weight";
auto split_v_d_name = lora_pre[type] + key + "v" + suffix + lora_downs[type] + ".weight";
auto split_q_u_name = lora_pre[type] + key + "q" + suffix + lora_ups[type] + ".weight";
auto split_k_u_name = lora_pre[type] + key + "k" + suffix + lora_ups[type] + ".weight";
auto split_v_u_name = lora_pre[type] + key + "v" + suffix + lora_ups[type] + ".weight";
auto split_q_scale_name = lora_pre[type] + key + "q" + suffix + ".scale";
auto split_k_scale_name = lora_pre[type] + key + "k" + suffix + ".scale";
auto split_v_scale_name = lora_pre[type] + key + "v" + suffix + ".scale";
auto split_q_alpha_name = lora_pre[type] + key + "q" + suffix + ".alpha";
auto split_k_alpha_name = lora_pre[type] + key + "k" + suffix + ".alpha";
auto split_v_alpha_name = lora_pre[type] + key + "v" + suffix + ".alpha";
ggml_tensor* lora_q_down = NULL;
ggml_tensor* lora_q_up = NULL;
ggml_tensor* lora_k_down = NULL;
ggml_tensor* lora_k_up = NULL;
ggml_tensor* lora_v_down = NULL;
ggml_tensor* lora_v_up = NULL;
lora_q_down = to_f32(compute_ctx, lora_tensors[split_q_d_name]);
if (lora_tensors.find(split_q_u_name) != lora_tensors.end()) {
lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]);
}
if (lora_tensors.find(split_k_d_name) != lora_tensors.end()) {
lora_k_down = to_f32(compute_ctx, lora_tensors[split_k_d_name]);
}
if (lora_tensors.find(split_k_u_name) != lora_tensors.end()) {
lora_k_up = to_f32(compute_ctx, lora_tensors[split_k_u_name]);
}
if (lora_tensors.find(split_v_d_name) != lora_tensors.end()) {
lora_v_down = to_f32(compute_ctx, lora_tensors[split_v_d_name]);
}
if (lora_tensors.find(split_v_u_name) != lora_tensors.end()) {
lora_v_up = to_f32(compute_ctx, lora_tensors[split_v_u_name]);
}
float q_rank = lora_q_up->ne[0];
float k_rank = lora_k_up->ne[0];
float v_rank = lora_v_up->ne[0];
float lora_q_scale = 1;
float lora_k_scale = 1;
float lora_v_scale = 1;
if (lora_tensors.find(split_q_scale_name) != lora_tensors.end()) {
lora_q_scale = ggml_backend_tensor_get_f32(lora_tensors[split_q_scale_name]);
applied_lora_tensors.insert(split_q_scale_name);
}
if (lora_tensors.find(split_k_scale_name) != lora_tensors.end()) {
lora_k_scale = ggml_backend_tensor_get_f32(lora_tensors[split_k_scale_name]);
applied_lora_tensors.insert(split_k_scale_name);
}
if (lora_tensors.find(split_v_scale_name) != lora_tensors.end()) {
lora_v_scale = ggml_backend_tensor_get_f32(lora_tensors[split_v_scale_name]);
applied_lora_tensors.insert(split_v_scale_name);
}
if (lora_tensors.find(split_q_alpha_name) != lora_tensors.end()) {
float lora_q_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_q_alpha_name]);
applied_lora_tensors.insert(split_q_alpha_name);
lora_q_scale = lora_q_alpha / q_rank;
}
if (lora_tensors.find(split_k_alpha_name) != lora_tensors.end()) {
float lora_k_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_k_alpha_name]);
applied_lora_tensors.insert(split_k_alpha_name);
lora_k_scale = lora_k_alpha / k_rank;
}
if (lora_tensors.find(split_v_alpha_name) != lora_tensors.end()) {
float lora_v_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_v_alpha_name]);
applied_lora_tensors.insert(split_v_alpha_name);
lora_v_scale = lora_v_alpha / v_rank;
}
ggml_scale_inplace(compute_ctx, lora_q_down, lora_q_scale);
ggml_scale_inplace(compute_ctx, lora_k_down, lora_k_scale);
ggml_scale_inplace(compute_ctx, lora_v_down, lora_v_scale);
// print_ggml_tensor(lora_q_down, true); //[3072, R, 1, 1]
// print_ggml_tensor(lora_k_down, true); //[3072, R, 1, 1]
// print_ggml_tensor(lora_v_down, true); //[3072, R, 1, 1]
// print_ggml_tensor(lora_q_up, true); //[R, 3072, 1, 1]
// print_ggml_tensor(lora_k_up, true); //[R, 3072, 1, 1]
// print_ggml_tensor(lora_v_up, true); //[R, 3072, 1, 1]
// these need to be stitched together this way:
// |q_up,0 ,0 |
// |0 ,k_up,0 |
// |0 ,0 ,v_up|
// (q_down,k_down,v_down) . (q ,k ,v)
// up_concat will be [9216, R*3, 1, 1]
// down_concat will be [R*3, 3072, 1, 1]
ggml_tensor* lora_down_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_down, lora_k_down, 1), lora_v_down, 1);
ggml_tensor* z = ggml_dup_tensor(compute_ctx, lora_q_up);
ggml_scale(compute_ctx, z, 0);
ggml_tensor* zz = ggml_concat(compute_ctx, z, z, 1);
ggml_tensor* q_up = ggml_concat(compute_ctx, lora_q_up, zz, 1);
ggml_tensor* k_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, z, lora_k_up, 1), z, 1);
ggml_tensor* v_up = ggml_concat(compute_ctx, zz, lora_v_up, 1);
// print_ggml_tensor(q_up, true); //[R, 9216, 1, 1]
// print_ggml_tensor(k_up, true); //[R, 9216, 1, 1]
// print_ggml_tensor(v_up, true); //[R, 9216, 1, 1]
ggml_tensor* lora_up_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, q_up, k_up, 0), v_up, 0);
// print_ggml_tensor(lora_up_concat, true); //[R*3, 9216, 1, 1]
lora_down = ggml_cont(compute_ctx, lora_down_concat);
lora_up = ggml_cont(compute_ctx, lora_up_concat);
applied_lora_tensors.insert(split_q_u_name);
applied_lora_tensors.insert(split_k_u_name);
applied_lora_tensors.insert(split_v_u_name);
applied_lora_tensors.insert(split_q_d_name);
applied_lora_tensors.insert(split_k_d_name);
applied_lora_tensors.insert(split_v_d_name);
}
}
if (starts_with(key, "SPLIT_L|")) {
key = key.substr(sizeof("SPLIT_L|") - 1);
auto split_q_d_name = lora_pre[type] + key + "attn.to_q" + lora_downs[type] + ".weight";
if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) {
// print_ggml_tensor(it.second, true); //[3072, 21504, 1, 1]
// find qkv and mlp up parts in LoRA model
auto split_k_d_name = lora_pre[type] + key + "attn.to_k" + lora_downs[type] + ".weight";
auto split_v_d_name = lora_pre[type] + key + "attn.to_v" + lora_downs[type] + ".weight";
auto split_q_u_name = lora_pre[type] + key + "attn.to_q" + lora_ups[type] + ".weight";
auto split_k_u_name = lora_pre[type] + key + "attn.to_k" + lora_ups[type] + ".weight";
auto split_v_u_name = lora_pre[type] + key + "attn.to_v" + lora_ups[type] + ".weight";
auto split_m_d_name = lora_pre[type] + key + "proj_mlp" + lora_downs[type] + ".weight";
auto split_m_u_name = lora_pre[type] + key + "proj_mlp" + lora_ups[type] + ".weight";
auto split_q_scale_name = lora_pre[type] + key + "attn.to_q" + ".scale";
auto split_k_scale_name = lora_pre[type] + key + "attn.to_k" + ".scale";
auto split_v_scale_name = lora_pre[type] + key + "attn.to_v" + ".scale";
auto split_m_scale_name = lora_pre[type] + key + "proj_mlp" + ".scale";
auto split_q_alpha_name = lora_pre[type] + key + "attn.to_q" + ".alpha";
auto split_k_alpha_name = lora_pre[type] + key + "attn.to_k" + ".alpha";
auto split_v_alpha_name = lora_pre[type] + key + "attn.to_v" + ".alpha";
auto split_m_alpha_name = lora_pre[type] + key + "proj_mlp" + ".alpha";
ggml_tensor* lora_q_down = NULL;
ggml_tensor* lora_q_up = NULL;
ggml_tensor* lora_k_down = NULL;
ggml_tensor* lora_k_up = NULL;
ggml_tensor* lora_v_down = NULL;
ggml_tensor* lora_v_up = NULL;
ggml_tensor* lora_m_down = NULL;
ggml_tensor* lora_m_up = NULL;
lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]);
if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) {
lora_q_down = to_f32(compute_ctx, lora_tensors[split_q_d_name]);
}
if (lora_tensors.find(split_q_u_name) != lora_tensors.end()) {
lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]);
}
if (lora_tensors.find(split_k_d_name) != lora_tensors.end()) {
lora_k_down = to_f32(compute_ctx, lora_tensors[split_k_d_name]);
}
if (lora_tensors.find(split_k_u_name) != lora_tensors.end()) {
lora_k_up = to_f32(compute_ctx, lora_tensors[split_k_u_name]);
}
if (lora_tensors.find(split_v_d_name) != lora_tensors.end()) {
lora_v_down = to_f32(compute_ctx, lora_tensors[split_v_d_name]);
}
if (lora_tensors.find(split_v_u_name) != lora_tensors.end()) {
lora_v_up = to_f32(compute_ctx, lora_tensors[split_v_u_name]);
}
if (lora_tensors.find(split_m_d_name) != lora_tensors.end()) {
lora_m_down = to_f32(compute_ctx, lora_tensors[split_m_d_name]);
}
if (lora_tensors.find(split_m_u_name) != lora_tensors.end()) {
lora_m_up = to_f32(compute_ctx, lora_tensors[split_m_u_name]);
}
float q_rank = lora_q_up->ne[0];
float k_rank = lora_k_up->ne[0];
float v_rank = lora_v_up->ne[0];
float m_rank = lora_v_up->ne[0];
float lora_q_scale = 1;
float lora_k_scale = 1;
float lora_v_scale = 1;
float lora_m_scale = 1;
if (lora_tensors.find(split_q_scale_name) != lora_tensors.end()) {
lora_q_scale = ggml_backend_tensor_get_f32(lora_tensors[split_q_scale_name]);
applied_lora_tensors.insert(split_q_scale_name);
}
if (lora_tensors.find(split_k_scale_name) != lora_tensors.end()) {
lora_k_scale = ggml_backend_tensor_get_f32(lora_tensors[split_k_scale_name]);
applied_lora_tensors.insert(split_k_scale_name);
}
if (lora_tensors.find(split_v_scale_name) != lora_tensors.end()) {
lora_v_scale = ggml_backend_tensor_get_f32(lora_tensors[split_v_scale_name]);
applied_lora_tensors.insert(split_v_scale_name);
}
if (lora_tensors.find(split_m_scale_name) != lora_tensors.end()) {
lora_m_scale = ggml_backend_tensor_get_f32(lora_tensors[split_m_scale_name]);
applied_lora_tensors.insert(split_m_scale_name);
}
if (lora_tensors.find(split_q_alpha_name) != lora_tensors.end()) {
float lora_q_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_q_alpha_name]);
applied_lora_tensors.insert(split_q_alpha_name);
lora_q_scale = lora_q_alpha / q_rank;
}
if (lora_tensors.find(split_k_alpha_name) != lora_tensors.end()) {
float lora_k_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_k_alpha_name]);
applied_lora_tensors.insert(split_k_alpha_name);
lora_k_scale = lora_k_alpha / k_rank;
}
if (lora_tensors.find(split_v_alpha_name) != lora_tensors.end()) {
float lora_v_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_v_alpha_name]);
applied_lora_tensors.insert(split_v_alpha_name);
lora_v_scale = lora_v_alpha / v_rank;
}
if (lora_tensors.find(split_m_alpha_name) != lora_tensors.end()) {
float lora_m_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_m_alpha_name]);
applied_lora_tensors.insert(split_m_alpha_name);
lora_m_scale = lora_m_alpha / m_rank;
}
ggml_scale_inplace(compute_ctx, lora_q_down, lora_q_scale);
ggml_scale_inplace(compute_ctx, lora_k_down, lora_k_scale);
ggml_scale_inplace(compute_ctx, lora_v_down, lora_v_scale);
ggml_scale_inplace(compute_ctx, lora_m_down, lora_m_scale);
// print_ggml_tensor(lora_q_down, true); //[3072, R, 1, 1]
// print_ggml_tensor(lora_k_down, true); //[3072, R, 1, 1]
// print_ggml_tensor(lora_v_down, true); //[3072, R, 1, 1]
// print_ggml_tensor(lora_m_down, true); //[3072, R, 1, 1]
// print_ggml_tensor(lora_q_up, true); //[R, 3072, 1, 1]
// print_ggml_tensor(lora_k_up, true); //[R, 3072, 1, 1]
// print_ggml_tensor(lora_v_up, true); //[R, 3072, 1, 1]
// print_ggml_tensor(lora_m_up, true); //[R, 12288, 1, 1]
// these need to be stitched together this way:
// |q_up,0 ,0 ,0 |
// |0 ,k_up,0 ,0 |
// |0 ,0 ,v_up,0 |
// |0 ,0 ,0 ,m_up|
// (q_down,k_down,v_down,m_down) . (q ,k ,v ,m)
// up_concat will be [21504, R*4, 1, 1]
// down_concat will be [R*4, 3072, 1, 1]
ggml_tensor* lora_down_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_down, lora_k_down, 1), ggml_concat(compute_ctx, lora_v_down, lora_m_down, 1), 1);
// print_ggml_tensor(lora_down_concat, true); //[3072, R*4, 1, 1]
// this also means that if rank is bigger than 672, it is less memory efficient to do it this way (should be fine)
// print_ggml_tensor(lora_q_up, true); //[3072, R, 1, 1]
ggml_tensor* z = ggml_dup_tensor(compute_ctx, lora_q_up);
ggml_tensor* mlp_z = ggml_dup_tensor(compute_ctx, lora_m_up);
ggml_scale(compute_ctx, z, 0);
ggml_scale(compute_ctx, mlp_z, 0);
ggml_tensor* zz = ggml_concat(compute_ctx, z, z, 1);
ggml_tensor* q_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_up, zz, 1), mlp_z, 1);
ggml_tensor* k_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, z, lora_k_up, 1), ggml_concat(compute_ctx, z, mlp_z, 1), 1);
ggml_tensor* v_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, zz, lora_v_up, 1), mlp_z, 1);
ggml_tensor* m_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, zz, z, 1), lora_m_up, 1);
// print_ggml_tensor(q_up, true); //[R, 21504, 1, 1]
// print_ggml_tensor(k_up, true); //[R, 21504, 1, 1]
// print_ggml_tensor(v_up, true); //[R, 21504, 1, 1]
// print_ggml_tensor(m_up, true); //[R, 21504, 1, 1]
ggml_tensor* lora_up_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, q_up, k_up, 0), ggml_concat(compute_ctx, v_up, m_up, 0), 0);
// print_ggml_tensor(lora_up_concat, true); //[R*4, 21504, 1, 1]
lora_down = ggml_cont(compute_ctx, lora_down_concat);
lora_up = ggml_cont(compute_ctx, lora_up_concat);
applied_lora_tensors.insert(split_q_u_name);
applied_lora_tensors.insert(split_k_u_name);
applied_lora_tensors.insert(split_v_u_name);
applied_lora_tensors.insert(split_m_u_name);
applied_lora_tensors.insert(split_q_d_name);
applied_lora_tensors.insert(split_k_d_name);
applied_lora_tensors.insert(split_v_d_name);
applied_lora_tensors.insert(split_m_d_name);
}
}
if (lora_up == NULL || lora_down == NULL) {
lora_up_name = lora_pre[type] + key + lora_ups[type] + ".weight";
if (lora_tensors.find(lora_up_name) == lora_tensors.end()) {
if (key == "model_diffusion_model_output_blocks_2_2_conv") {
// fix for some sdxl lora, like lcm-lora-xl
key = "model_diffusion_model_output_blocks_2_1_conv";
lora_up_name = lora_pre[type] + key + lora_ups[type] + ".weight";
}
}
lora_down_name = lora_pre[type] + key + lora_downs[type] + ".weight";
alpha_name = lora_pre[type] + key + ".alpha";
scale_name = lora_pre[type] + key + ".scale";
if (lora_tensors.find(lora_up_name) != lora_tensors.end()) {
lora_up = lora_tensors[lora_up_name];
}
if (lora_tensors.find(lora_down_name) != lora_tensors.end()) {
lora_down = lora_tensors[lora_down_name];
}
applied_lora_tensors.insert(lora_up_name);
applied_lora_tensors.insert(lora_down_name);
applied_lora_tensors.insert(alpha_name);
applied_lora_tensors.insert(scale_name);
}
if (lora_up == NULL || lora_down == NULL) {
continue;
}
// calc_scale
int64_t dim = lora_down->ne[ggml_n_dims(lora_down) - 1];
float scale_value = 1.0f;
if (lora_tensors.find(scale_name) != lora_tensors.end()) {
scale_value = ggml_backend_tensor_get_f32(lora_tensors[scale_name]);
} else if (lora_tensors.find(alpha_name) != lora_tensors.end()) {
float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]);
scale_value = alpha / dim;
}
scale_value *= multiplier;
// flat lora tensors to multiply it
int64_t lora_up_rows = lora_up->ne[ggml_n_dims(lora_up) - 1];
lora_up = ggml_reshape_2d(compute_ctx, lora_up, ggml_nelements(lora_up) / lora_up_rows, lora_up_rows);
int64_t lora_down_rows = lora_down->ne[ggml_n_dims(lora_down) - 1];
lora_down = ggml_reshape_2d(compute_ctx, lora_down, ggml_nelements(lora_down) / lora_down_rows, lora_down_rows);
// ggml_mul_mat requires tensor b transposed
lora_down = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, lora_down));
struct ggml_tensor* updown = ggml_mul_mat(compute_ctx, lora_up, lora_down);
updown = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, updown));
updown = ggml_reshape(compute_ctx, updown, weight);
GGML_ASSERT(ggml_nelements(updown) == ggml_nelements(weight));
updown = ggml_scale_inplace(compute_ctx, updown, scale_value);
ggml_tensor* final_weight;
if (weight->type != GGML_TYPE_F32 && weight->type != GGML_TYPE_F16) {
// final_weight = ggml_new_tensor(compute_ctx, GGML_TYPE_F32, ggml_n_dims(weight), weight->ne);
// final_weight = ggml_cpy(compute_ctx, weight, final_weight);
final_weight = to_f32(compute_ctx, weight);
final_weight = ggml_add_inplace(compute_ctx, final_weight, updown);
final_weight = ggml_cpy(compute_ctx, final_weight, weight);
} else {
final_weight = ggml_add_inplace(compute_ctx, weight, updown);
}
// final_weight = ggml_add_inplace(compute_ctx, weight, updown); // apply directly
ggml_build_forward_expand(gf, final_weight);
break;
} }
if (lora_tensors.find(lora_down_name) != lora_tensors.end()) {
lora_down = lora_tensors[lora_down_name];
}
if (lora_up == NULL || lora_down == NULL) {
continue;
}
applied_lora_tensors.insert(lora_up_name);
applied_lora_tensors.insert(lora_down_name);
applied_lora_tensors.insert(alpha_name);
applied_lora_tensors.insert(scale_name);
// calc_cale
int64_t dim = lora_down->ne[ggml_n_dims(lora_down) - 1];
float scale_value = 1.0f;
if (lora_tensors.find(scale_name) != lora_tensors.end()) {
scale_value = ggml_backend_tensor_get_f32(lora_tensors[scale_name]);
} else if (lora_tensors.find(alpha_name) != lora_tensors.end()) {
float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]);
scale_value = alpha / dim;
}
scale_value *= multiplier;
// flat lora tensors to multiply it
int64_t lora_up_rows = lora_up->ne[ggml_n_dims(lora_up) - 1];
lora_up = ggml_reshape_2d(compute_ctx, lora_up, ggml_nelements(lora_up) / lora_up_rows, lora_up_rows);
int64_t lora_down_rows = lora_down->ne[ggml_n_dims(lora_down) - 1];
lora_down = ggml_reshape_2d(compute_ctx, lora_down, ggml_nelements(lora_down) / lora_down_rows, lora_down_rows);
// ggml_mul_mat requires tensor b transposed
lora_down = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, lora_down));
struct ggml_tensor* updown = ggml_mul_mat(compute_ctx, lora_up, lora_down);
updown = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, updown));
updown = ggml_reshape(compute_ctx, updown, weight);
GGML_ASSERT(ggml_nelements(updown) == ggml_nelements(weight));
updown = ggml_scale_inplace(compute_ctx, updown, scale_value);
ggml_tensor* final_weight;
if (weight->type != GGML_TYPE_F32 && weight->type != GGML_TYPE_F16) {
// final_weight = ggml_new_tensor(compute_ctx, GGML_TYPE_F32, ggml_n_dims(weight), weight->ne);
// final_weight = ggml_cpy(compute_ctx, weight, final_weight);
final_weight = to_f32(compute_ctx, weight);
final_weight = ggml_add_inplace(compute_ctx, final_weight, updown);
final_weight = ggml_cpy(compute_ctx, final_weight, weight);
} else {
final_weight = ggml_add_inplace(compute_ctx, weight, updown);
}
// final_weight = ggml_add_inplace(compute_ctx, weight, updown); // apply directly
ggml_build_forward_expand(gf, final_weight);
        }
        size_t total_lora_tensors_count   = 0;
        size_t applied_lora_tensors_count = 0;
        for (auto& kv : lora_tensors) {
            total_lora_tensors_count++;
            if (applied_lora_tensors.find(kv.first) == applied_lora_tensors.end()) {
                LOG_WARN("unused lora tensor |%s|", kv.first.c_str());
                print_ggml_tensor(kv.second, true);
                // exit(0);
            } else {
                applied_lora_tensors_count++;
            }
@ -191,9 +669,9 @@ struct LoraModel : public GGMLRunner {
        return gf;
    }

    void apply(std::map<std::string, struct ggml_tensor*> model_tensors, SDVersion version, int n_threads) {
        auto get_graph = [&]() -> struct ggml_cgraph* {
            return build_lora_graph(model_tensors, version);
        };
        GGMLRunner::compute(get_graph, n_threads, true);
    }
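The scale handling above follows the usual LoRA convention: the applied delta is up times down, scaled by alpha / rank (when an alpha tensor is present) and by the user-supplied multiplier, then added to the base weight. A minimal plain-array sketch of that merge, independent of ggml (the helper name and toy sizes are illustrative, not part of the project):

#include <cstdio>
#include <vector>

// Merge a LoRA pair into a dense weight: W += multiplier * (alpha / rank) * (up * down).
// up is (out_dim x rank), down is (rank x in_dim), weight is (out_dim x in_dim), row-major.
static void merge_lora(std::vector<float>& weight,
                       const std::vector<float>& up,
                       const std::vector<float>& down,
                       int out_dim, int in_dim, int rank,
                       float alpha, float multiplier) {
    const float scale = multiplier * (alpha > 0.f ? alpha / rank : 1.f);
    for (int o = 0; o < out_dim; o++) {
        for (int i = 0; i < in_dim; i++) {
            float acc = 0.f;
            for (int r = 0; r < rank; r++) {
                acc += up[o * rank + r] * down[r * in_dim + i];
            }
            weight[o * in_dim + i] += scale * acc;
        }
    }
}

int main() {
    // toy sizes: 4x4 weight, rank-2 LoRA
    std::vector<float> weight(16, 1.0f), up(4 * 2, 0.1f), down(2 * 4, 0.2f);
    merge_lora(weight, up, down, 4, 4, 2, /*alpha=*/2.f, /*multiplier=*/0.8f);
    printf("w[0][0] = %f\n", weight[0]);  // 1 + 0.8 * (2/2) * (0.1*0.2*2) = 1.032
    return 0;
}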

model.cpp

@ -1458,24 +1458,49 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
}

SDVersion ModelLoader::get_sd_version() {
    TensorStorage token_embedding_weight, input_block_weight;
    bool input_block_checked = false;
    for (auto& tensor_storage : tensor_storages) {
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
return VERSION_FLUX;
}
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
return VERSION_SD3;
}
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos) {
return VERSION_SDXL;
}
if (tensor_storage.name.find("cond_stage_model.1") != std::string::npos) {
return VERSION_SDXL;
}
if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) {
return VERSION_SVD;
}
bool has_multiple_encoders = false;
bool is_unet = false;
bool is_xl = false;
bool is_flux = false;
#define found_family (is_xl || is_flux)
for (auto& tensor_storage : tensor_storages) {
if (!found_family) {
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
is_flux = true;
if (input_block_checked) {
break;
}
}
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
return VERSION_SD3;
}
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) {
is_unet = true;
if(has_multiple_encoders){
is_xl = true;
if (input_block_checked) {
break;
}
}
}
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos) {
has_multiple_encoders = true;
if(is_unet){
is_xl = true;
if (input_block_checked) {
break;
}
}
}
if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) {
return VERSION_SVD;
}
}
        if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
            tensor_storage.name == "cond_stage_model.model.token_embedding.weight" ||
            tensor_storage.name == "text_model.embeddings.token_embedding.weight" ||
@ -1485,11 +1510,39 @@ SDVersion ModelLoader::get_sd_version() {
            token_embedding_weight = tensor_storage;
            // break;
        }
if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight") {
input_block_weight = tensor_storage;
input_block_checked = true;
if (found_family) {
break;
}
}
}
bool is_inpaint = input_block_weight.ne[2] == 9;
if (is_xl) {
if (is_inpaint) {
return VERSION_SDXL_INPAINT;
}
return VERSION_SDXL;
}
if (is_flux) {
is_inpaint = input_block_weight.ne[0] == 384;
if (is_inpaint) {
return VERSION_FLUX_FILL;
}
return VERSION_FLUX;
    }
    if (token_embedding_weight.ne[0] == 768) {
        if (is_inpaint) {
            return VERSION_SD1_INPAINT;
        }
        return VERSION_SD1;
    } else if (token_embedding_weight.ne[0] == 1024) {
        if (is_inpaint) {
            return VERSION_SD2_INPAINT;
        }
        return VERSION_SD2;
    }
    return VERSION_COUNT;
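The detection above distinguishes inpaint checkpoints purely by the shape of the first input projection: classic SD inpaint UNets widen input_blocks.0.0 from 4 to 9 input channels (4 latent + 1 mask + 4 masked-image latent), while Flux Fill widens img_in's input features from 64 to 384 to carry the packed masked image and mask. A standalone restatement of that check (hypothetical helper and names, not the loader's API):

#include <cstdint>
#include <cstdio>

enum class Family { SD1, SD2, SDXL, Flux };

// A 9-channel input_blocks.0.0 conv marks a UNet inpaint model,
// a 384-wide img_in marks Flux Fill.
static bool is_inpaint_checkpoint(Family family, const int64_t input_block_ne[4]) {
    if (family == Family::Flux) {
        return input_block_ne[0] == 384;  // packed latent + packed masked image/mask
    }
    return input_block_ne[2] == 9;  // 4 latent + 1 mask + 4 masked-image latent
}

int main() {
    int64_t sd15_inpaint_conv[4] = {3, 3, 9, 320};   // kw, kh, in_ch, out_ch
    int64_t flux_fill_img_in[4]  = {384, 3072, 1, 1};
    printf("%d %d\n",
           (int)is_inpaint_checkpoint(Family::SD1, sd15_inpaint_conv),
           (int)is_inpaint_checkpoint(Family::Flux, flux_fill_img_in));
    return 0;
}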
@ -1695,9 +1748,11 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
            }
            return true;
        };
int tensor_count = 0;
int64_t t1 = ggml_time_ms();
        for (auto& tensor_storage : processed_tensor_storages) {
            if (tensor_storage.file_index != file_index) {
++tensor_count;
                continue;
            }
            ggml_tensor* dst_tensor = NULL;
@ -1709,6 +1764,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
            }
            if (dst_tensor == NULL) {
++tensor_count;
                continue;
            }
@ -1775,6 +1831,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                    ggml_backend_tensor_set(dst_tensor, convert_buffer.data(), 0, ggml_nbytes(dst_tensor));
                }
            }
int64_t t2 = ggml_time_ms();
pretty_progress(++tensor_count, processed_tensor_storages.size(), (t2 - t1) / 1000.0f);
t1 = t2;
        }
        if (zip != NULL) {

model.h

@ -19,16 +19,20 @@
enum SDVersion {
    VERSION_SD1,
    VERSION_SD1_INPAINT,
    VERSION_SD2,
    VERSION_SD2_INPAINT,
    VERSION_SDXL,
    VERSION_SDXL_INPAINT,
    VERSION_SVD,
    VERSION_SD3,
    VERSION_FLUX,
    VERSION_FLUX_FILL,
    VERSION_COUNT,
};

static inline bool sd_version_is_flux(SDVersion version) {
    if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) {
        return true;
    }
    return false;
@ -41,6 +45,34 @@ static inline bool sd_version_is_sd3(SDVersion version) {
    return false;
}
static inline bool sd_version_is_sd1(SDVersion version) {
if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT) {
return true;
}
return false;
}
static inline bool sd_version_is_sd2(SDVersion version) {
if (version == VERSION_SD2 || version == VERSION_SD2_INPAINT) {
return true;
}
return false;
}
static inline bool sd_version_is_sdxl(SDVersion version) {
if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT) {
return true;
}
return false;
}
static inline bool sd_version_is_inpaint(SDVersion version) {
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL) {
return true;
}
return false;
}
static inline bool sd_version_is_dit(SDVersion version) {
    if (sd_version_is_flux(version) || sd_version_is_sd3(version)) {
        return true;
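A hypothetical call-site sketch (assuming model.h is included) of how these family predicates are meant to be used instead of matching individual enum values; the 0.13025f SDXL scale factor mirrors the value used later in stable-diffusion.cpp, the remaining numbers are common defaults rather than values taken from this diff:

// Sketch only: pick per-family parameters via the predicates above.
static int latent_channels(SDVersion version) {
    return sd_version_is_dit(version) ? 16 : 4;  // SD3/Flux latents are 16-channel
}

static int unet_in_channels(SDVersion version) {
    // inpaint UNets concatenate mask + masked-image latent onto the 4 latent channels
    return (sd_version_is_inpaint(version) && !sd_version_is_flux(version)) ? 9 : 4;
}

static float vae_scale_factor(SDVersion version) {
    return sd_version_is_sdxl(version) ? 0.13025f : 0.18215f;  // 0.18215f: SD1/SD2 default
}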

stable-diffusion.cpp

@ -26,11 +26,15 @@
const char* model_version_to_str[] = {
    "SD 1.x",
    "SD 1.x Inpaint",
    "SD 2.x",
    "SD 2.x Inpaint",
    "SDXL",
    "SDXL Inpaint",
    "SVD",
    "SD3.x",
    "Flux",
    "Flux Fill"};

const char* sampling_methods_str[] = {
    "Euler A",
@ -155,13 +159,13 @@ public:
                 bool vae_on_cpu,
                 bool diffusion_flash_attn) {
        use_tiny_autoencoder = taesd_path.size() > 0;
#ifdef SD_USE_CUDA
        LOG_DEBUG("Using CUDA backend");
        backend = ggml_backend_cuda_init(0);
#endif
#ifdef SD_USE_METAL
        LOG_DEBUG("Using Metal backend");
        ggml_log_set(ggml_log_callback_default, nullptr);
        backend = ggml_backend_metal_init();
#endif
#ifdef SD_USE_VULKAN
@ -263,7 +267,7 @@ public:
            model_loader.set_wtype_override(wtype);
        }

        if (sd_version_is_sdxl(version)) {
            vae_wtype = GGML_TYPE_F32;
            model_loader.set_wtype_override(GGML_TYPE_F32, "vae.");
        }
@ -275,7 +279,7 @@ public:
        LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor));

        if (sd_version_is_sdxl(version)) {
            scale_factor = 0.13025f;
            if (vae_path.size() == 0 && taesd_path.size() == 0) {
                LOG_WARN(
@ -329,7 +333,7 @@ public:
            diffusion_model = std::make_shared<MMDiTModel>(backend, model_loader.tensor_storages_types);
        } else if (sd_version_is_flux(version)) {
            cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, model_loader.tensor_storages_types);
            diffusion_model  = std::make_shared<FluxModel>(backend, model_loader.tensor_storages_types, version, diffusion_flash_attn);
        } else {
            if (id_embeddings_path.find("v2") != std::string::npos) {
                cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, model_loader.tensor_storages_types, embeddings_path, version, PM_VERSION_2);
@ -356,7 +360,7 @@ public:
            first_stage_model->alloc_params_buffer();
            first_stage_model->get_param_tensors(tensors, "first_stage_model");
        } else {
            tae_first_stage = std::make_shared<TinyAutoEncoder>(backend, model_loader.tensor_storages_types, "decoder.layers", vae_decode_only, version);
        }
        // first_stage_model->get_param_tensors(tensors, "first_stage_model.");
@ -517,8 +521,8 @@ public:
        // check is_using_v_parameterization_for_sd2
        bool is_using_v_parameterization = false;
        if (sd_version_is_sd2(version)) {
            if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
                is_using_v_parameterization = true;
            }
        } else if (version == VERSION_SVD) {
@ -592,7 +596,7 @@ public:
        return true;
    }

    bool is_using_v_parameterization_for_sd2(ggml_context* work_ctx, bool is_inpaint = false) {
        struct ggml_tensor* x_t = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 4, 1);
        ggml_set_f32(x_t, 0.5);
        struct ggml_tensor* c = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 1024, 2, 1, 1);
@ -600,9 +604,13 @@ public:
        struct ggml_tensor* timesteps = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1);
        ggml_set_f32(timesteps, 999);

        struct ggml_tensor* concat = is_inpaint ? ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 5, 1) : NULL;
        if (concat != NULL) {
            ggml_set_f32(concat, 0);  // guard: concat stays NULL for non-inpaint SD2 checkpoints
        }

        int64_t t0 = ggml_time_ms();
        struct ggml_tensor* out = ggml_dup_tensor(work_ctx, x_t);
        diffusion_model->compute(n_threads, x_t, timesteps, c, concat, NULL, NULL, -1, {}, 0.f, &out);
        diffusion_model->free_compute_buffer();

        double result = 0.f;
@ -642,7 +650,8 @@ public:
        }
        lora.multiplier = multiplier;
        // TODO: send version?
        lora.apply(tensors, version, n_threads);
        lora.free_params_buffer();

        int64_t t1 = ggml_time_ms();
@ -784,7 +793,20 @@ public:
                              std::vector<int> skip_layers = {},
                              float slg_scale = 0,
                              float skip_layer_start = 0.01,
                              float skip_layer_end = 0.2,
ggml_tensor* noise_mask = nullptr) {
LOG_DEBUG("Sample");
struct ggml_init_params params;
size_t data_size = ggml_row_size(init_latent->type, init_latent->ne[0]);
for (int i = 1; i < 4; i++) {
data_size *= init_latent->ne[i];
}
data_size += 1024;
params.mem_size = data_size * 3;
params.mem_buffer = NULL;
params.no_alloc = false;
ggml_context* tmp_ctx = ggml_init(params);
        size_t steps = sigmas.size() - 1;
        // noise = load_tensor_from_file(work_ctx, "./rand0.bin");
        // print_ggml_tensor(noise);
@ -943,6 +965,19 @@ public:
            pretty_progress(step, (int)steps, (t1 - t0) / 1000000.f);
            // LOG_INFO("step %d sampling completed taking %.2fs", step, (t1 - t0) * 1.0f / 1000000);
        }
if (noise_mask != nullptr) {
for (int64_t x = 0; x < denoised->ne[0]; x++) {
for (int64_t y = 0; y < denoised->ne[1]; y++) {
float mask = ggml_tensor_get_f32(noise_mask, x, y);
for (int64_t k = 0; k < denoised->ne[2]; k++) {
float init = ggml_tensor_get_f32(init_latent, x, y, k);
float den = ggml_tensor_get_f32(denoised, x, y, k);
ggml_tensor_set_f32(denoised, init + mask * (den - init), x, y, k);
}
}
}
}
        return denoised;
    };
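The noise_mask pass at the end of sample() is the standard latent blend used when inpainting with a non-inpaint checkpoint: denoised = init + mask * (denoised - init), so masked regions (mask = 1) keep the freshly denoised latent while unmasked regions are restored to the original. A plain 1-D sketch of the same element-wise blend (hypothetical helper, ignoring the per-channel broadcast done above):

#include <cstdio>
#include <vector>

// out[i] = init[i] + mask[i] * (denoised[i] - init[i])
// mask = 1 keeps the denoised value, mask = 0 restores the original latent.
static void blend_with_mask(std::vector<float>& denoised,
                            const std::vector<float>& init,
                            const std::vector<float>& mask) {
    for (size_t i = 0; i < denoised.size(); i++) {
        denoised[i] = init[i] + mask[i] * (denoised[i] - init[i]);
    }
}

int main() {
    std::vector<float> denoised = {0.9f, 0.9f}, init = {0.1f, 0.1f}, mask = {1.f, 0.f};
    blend_with_mask(denoised, init, mask);
    printf("%.1f %.1f\n", denoised[0], denoised[1]);  // 0.9 0.1
    return 0;
}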
@ -1166,7 +1201,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
                           std::vector<int> skip_layers = {},
                           float slg_scale = 0,
                           float skip_layer_start = 0.01,
                           float skip_layer_end = 0.2,
ggml_tensor* masked_image = NULL) {
    if (seed < 0) {
        // Generally, when using the provided command line, the seed is always >0.
        // However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library
@ -1206,7 +1242,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
    if (sd_ctx->sd->stacked_id) {
        if (!sd_ctx->sd->pmid_lora->applied) {
            t0 = ggml_time_ms();
            sd_ctx->sd->pmid_lora->apply(sd_ctx->sd->tensors, sd_ctx->sd->version, sd_ctx->sd->n_threads);
            t1 = ggml_time_ms();
            sd_ctx->sd->pmid_lora->applied = true;
            LOG_INFO("pmid_lora apply completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
@ -1316,7 +1352,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
    SDCondition uncond;
    if (cfg_scale != 1.0) {
        bool force_zero_embeddings = false;
        if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0) {
            force_zero_embeddings = true;
        }
        uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
@ -1353,6 +1389,39 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
    int W = width / 8;
    int H = height / 8;
    LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
ggml_tensor* noise_mask = nullptr;
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
if (masked_image == NULL) {
int64_t mask_channels = 1;
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
mask_channels = 8 * 8; // flatten the whole mask
}
// no mask, set the whole image as masked
masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
for (int64_t x = 0; x < masked_image->ne[0]; x++) {
for (int64_t y = 0; y < masked_image->ne[1]; y++) {
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
// TODO: this might be wrong
for (int64_t c = 0; c < init_latent->ne[2]; c++) {
ggml_tensor_set_f32(masked_image, 0, x, y, c);
}
for (int64_t c = init_latent->ne[2]; c < masked_image->ne[2]; c++) {
ggml_tensor_set_f32(masked_image, 1, x, y, c);
}
} else {
ggml_tensor_set_f32(masked_image, 1, x, y, 0);
for (int64_t c = 1; c < masked_image->ne[2]; c++) {
ggml_tensor_set_f32(masked_image, 0, x, y, c);
}
}
}
}
}
cond.c_concat = masked_image;
uncond.c_concat = masked_image;
} else {
noise_mask = masked_image;
}
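When an inpaint model is run without a mask, the block above builds a "repaint everything" conditioning: the mask channel(s) are set to 1 and the masked-image latent channels to 0, so no source content is preserved. A rough sketch of that default for the classic UNet layout (hypothetical helper; Flux Fill uses 64 flattened mask channels instead of 1):

#include <cstdio>
#include <vector>

// Default "repaint everything" conditioning for a classic inpaint UNet:
// plane 0 is the mask (1 = repaint), planes 1..latent_channels are the
// masked-image latent, zeroed because nothing from the source image is kept.
static std::vector<float> default_inpaint_concat(int w, int h, int latent_channels) {
    std::vector<float> concat((1 + latent_channels) * w * h, 0.0f);
    for (int i = 0; i < w * h; i++) {
        concat[i] = 1.0f;  // mask plane, fully masked
    }
    return concat;
}

int main() {
    auto c = default_inpaint_concat(64, 64, 4);  // 512x512 image at 1/8 latent resolution
    printf("planes=%d, values=%zu\n", 5, c.size());  // 5 planes, 20480 floats
    return 0;
}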
    for (int b = 0; b < batch_count; b++) {
        int64_t sampling_start = ggml_time_ms();
        int64_t cur_seed       = seed + b;
@ -1388,7 +1457,9 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
                                                     skip_layers,
                                                     slg_scale,
                                                     skip_layer_start,
                                                     skip_layer_end,
                                                     noise_mask);
        // struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
        // print_ggml_tensor(x_0);
        int64_t sampling_end = ggml_time_ms();
@ -1510,6 +1581,10 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
        ggml_set_f32(init_latent, 0.f);
    }
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
LOG_WARN("This is an inpainting model, this should only be used in img2img mode with a mask");
}
    sd_image_t* result_images = generate_image(sd_ctx,
                                               work_ctx,
                                               init_latent,
@ -1543,6 +1618,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
sd_image_t* img2img(sd_ctx_t* sd_ctx,
                    sd_image_t init_image,
                    sd_image_t mask,
                    const char* prompt_c_str,
                    const char* negative_prompt_c_str,
                    int clip_skip,
@ -1582,7 +1658,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
    if (sd_ctx->sd->stacked_id) {
        params.mem_size += static_cast<size_t>(10 * 1024 * 1024);  // 10 MB
    }
    params.mem_size += width * height * 3 * sizeof(float) * 3;
    params.mem_size *= batch_count;
    params.mem_buffer = NULL;
    params.no_alloc   = false;
@ -1603,7 +1679,70 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
    sd_ctx->sd->rng->manual_seed(seed);

    ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
    ggml_tensor* mask_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 1, 1);

    sd_mask_to_tensor(mask.data, mask_img);
    sd_image_to_tensor(init_image.data, init_img);
ggml_tensor* masked_image;
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
int64_t mask_channels = 1;
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
mask_channels = 8 * 8; // flatten the whole mask
}
ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
sd_apply_mask(init_img, mask_img, masked_img);
ggml_tensor* masked_image_0 = NULL;
if (!sd_ctx->sd->use_tiny_autoencoder) {
ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
masked_image_0 = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
} else {
masked_image_0 = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
}
masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_image_0->ne[0], masked_image_0->ne[1], mask_channels + masked_image_0->ne[2], 1);
for (int ix = 0; ix < masked_image_0->ne[0]; ix++) {
for (int iy = 0; iy < masked_image_0->ne[1]; iy++) {
int mx = ix * 8;
int my = iy * 8;
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
for (int k = 0; k < masked_image_0->ne[2]; k++) {
float v = ggml_tensor_get_f32(masked_image_0, ix, iy, k);
ggml_tensor_set_f32(masked_image, v, ix, iy, k);
}
// "Encode" 8x8 mask chunks into a flattened 1x64 vector, and concatenate to masked image
for (int x = 0; x < 8; x++) {
for (int y = 0; y < 8; y++) {
float m = ggml_tensor_get_f32(mask_img, mx + x, my + y);
// TODO: check if the way the mask is flattened is correct (is it supposed to be x*8+y or x+8*y?)
// python code was using "b (h 8) (w 8) -> b (8 8) h w"
ggml_tensor_set_f32(masked_image, m, ix, iy, masked_image_0->ne[2] + x * 8 + y);
}
}
} else {
float m = ggml_tensor_get_f32(mask_img, mx, my);
ggml_tensor_set_f32(masked_image, m, ix, iy, 0);
for (int k = 0; k < masked_image_0->ne[2]; k++) {
float v = ggml_tensor_get_f32(masked_image_0, ix, iy, k);
ggml_tensor_set_f32(masked_image, v, ix, iy, k + mask_channels);
}
}
}
}
} else {
// LOG_WARN("Inpainting with a base model is not great");
masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 1, 1);
for (int ix = 0; ix < masked_image->ne[0]; ix++) {
for (int iy = 0; iy < masked_image->ne[1]; iy++) {
int mx = ix * 8;
int my = iy * 8;
float m = ggml_tensor_get_f32(mask_img, mx, my);
ggml_tensor_set_f32(masked_image, m, ix, iy);
}
}
}
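To summarize the layout built above: for SD-style inpaint models the conditioning tensor is mask (1 channel) followed by the VAE encoding of the masked image (4 channels) at 1/8 resolution, while for Flux Fill it is the 16-channel VAE encoding of the masked image followed by each 8x8 mask patch flattened into 64 channels. A shape-only sketch of those two cases (hypothetical struct, no ggml, assumes the usual 8x VAE downscale):

#include <cstdio>

// Shape-only sketch of the concat conditioning built above.
struct InpaintConcatShape {
    int width, height, channels;
};

static InpaintConcatShape concat_shape(int img_w, int img_h, bool flux_fill) {
    int latent_c = flux_fill ? 16 : 4;   // masked-image latent channels
    int mask_c   = flux_fill ? 64 : 1;   // flux fill flattens each 8x8 mask patch
    return {img_w / 8, img_h / 8, latent_c + mask_c};
}

int main() {
    InpaintConcatShape sd = concat_shape(512, 512, false);   // 64x64x5
    InpaintConcatShape ff = concat_shape(1024, 1024, true);  // 128x128x80
    printf("%dx%dx%d, %dx%dx%d\n", sd.width, sd.height, sd.channels,
           ff.width, ff.height, ff.channels);
    return 0;
}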
    ggml_tensor* init_latent = NULL;
    if (!sd_ctx->sd->use_tiny_autoencoder) {
        ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
@ -1611,12 +1750,15 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
    } else {
        init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
    }

    print_ggml_tensor(init_latent, true);
    size_t t1 = ggml_time_ms();
    LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);

    std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
    size_t t_enc = static_cast<size_t>(sample_steps * strength);
    if (t_enc == sample_steps)
        t_enc--;
    LOG_INFO("target t_enc is %zu steps", t_enc);
    std::vector<float> sigma_sched;
    sigma_sched.assign(sigmas.begin() + sample_steps - t_enc - 1, sigmas.end());
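The strength handling above truncates the schedule in the usual img2img way: t_enc = floor(sample_steps * strength), decremented only when it would equal sample_steps so that sample_steps - t_enc - 1 stays a valid index, and sigma_sched then keeps the last t_enc + 2 sigmas (t_enc + 1 denoising intervals). A small worked example with dummy sigma values:

#include <cstdio>
#include <vector>

int main() {
    int sample_steps = 20;
    float strength   = 0.75f;

    size_t t_enc = static_cast<size_t>(sample_steps * strength);  // 15
    if (t_enc == (size_t)sample_steps)
        t_enc--;  // avoid indexing before the start of the schedule when strength == 1

    // denoiser->get_sigmas(n) returns n + 1 sigmas; use a dummy descending schedule here
    std::vector<float> sigmas(sample_steps + 1);
    for (int i = 0; i <= sample_steps; i++)
        sigmas[i] = float(sample_steps - i);

    std::vector<float> sigma_sched(sigmas.begin() + sample_steps - t_enc - 1, sigmas.end());
    printf("t_enc=%zu, sigmas kept=%zu, denoise intervals=%zu\n",
           t_enc, sigma_sched.size(), sigma_sched.size() - 1);  // 15, 17, 16
    return 0;
}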
@ -1643,7 +1785,8 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
                                                     skip_layers_vec,
                                                     slg_scale,
                                                     skip_layer_start,
                                                     skip_layer_end,
                                                     masked_image);
    size_t t2 = ggml_time_ms();

stable-diffusion.h

@ -174,6 +174,7 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx,
SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx,
                           sd_image_t init_image,
                           sd_image_t mask_image,
                           const char* prompt,
                           const char* negative_prompt,
                           int clip_skip,

tae.hpp

@ -62,7 +62,8 @@ class TinyEncoder : public UnaryBlock {
    int num_blocks = 3;

public:
    TinyEncoder(int z_channels = 4)
        : z_channels(z_channels) {
        int index = 0;
        blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, channels, {3, 3}, {1, 1}, {1, 1}));
        blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
@ -106,7 +107,10 @@ class TinyDecoder : public UnaryBlock {
    int num_blocks = 3;

public:
    TinyDecoder(int z_channels = 4)
        : z_channels(z_channels) {
        int index = 0;

        blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, channels, {3, 3}, {1, 1}, {1, 1}));
        index++;  // nn.ReLU()
@ -163,12 +167,16 @@ protected:
    bool decode_only;

public:
    TAESD(bool decode_only = true, SDVersion version = VERSION_SD1)
        : decode_only(decode_only) {
        int z_channels = 4;
        if (sd_version_is_dit(version)) {
            z_channels = 16;
        }
        blocks["decoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyDecoder(z_channels));

        if (!decode_only) {
            blocks["encoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyEncoder(z_channels));
        }
    }
@ -190,9 +198,10 @@ struct TinyAutoEncoder : public GGMLRunner {
    TinyAutoEncoder(ggml_backend_t backend,
                    std::map<std::string, enum ggml_type>& tensor_types,
                    const std::string prefix,
                    bool decoder_only = true,
                    SDVersion version = VERSION_SD1)
        : decode_only(decoder_only),
          taesd(decode_only, version),
          GGMLRunner(backend) {
        taesd.init(params_ctx, tensor_types, prefix);
    }
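With the version forwarded into TAESD, DiT-family checkpoints (SD3 / Flux) automatically get the 16-channel taesd3/taef1 decoder while everything else keeps the 4-channel layout. A hypothetical wiring sketch, assuming the surrounding tae.hpp and ggml declarations are in scope:

// Assumes tae.hpp (TinyAutoEncoder, SDVersion) and a ggml backend are available.
static std::shared_ptr<TinyAutoEncoder> make_tae(ggml_backend_t backend,
                                                 std::map<std::string, enum ggml_type>& tensor_types,
                                                 SDVersion version,
                                                 bool decode_only = true) {
    // The constructor picks z_channels = 16 when sd_version_is_dit(version), 4 otherwise.
    return std::make_shared<TinyAutoEncoder>(backend, tensor_types,
                                             "decoder.layers", decode_only, version);
}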

unet.hpp

@ -166,6 +166,7 @@ public:
// ldm.modules.diffusionmodules.openaimodel.UNetModel
class UnetModelBlock : public GGMLBlock {
protected:
    static std::map<std::string, enum ggml_type> empty_tensor_types;
    SDVersion version = VERSION_SD1;
    // network hparams
    int in_channels = 4;
@ -183,13 +184,13 @@ public:
    int model_channels  = 320;
    int adm_in_channels = 2816;  // only for VERSION_SDXL/SVD

    UnetModelBlock(SDVersion version = VERSION_SD1, std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types, bool flash_attn = false)
        : version(version) {
        if (sd_version_is_sd2(version)) {
            context_dim       = 1024;
            num_head_channels = 64;
            num_heads         = -1;
        } else if (sd_version_is_sdxl(version)) {
            context_dim           = 2048;
            attention_resolutions = {4, 2};
            channel_mult          = {1, 2, 4};
@ -204,6 +205,10 @@ public:
            num_head_channels = 64;
            num_heads         = -1;
        }
if (sd_version_is_inpaint(version)) {
in_channels = 9;
}
        // dims is always 2
        // use_temporal_attention is always True for SVD
@ -211,7 +216,7 @@ public:
        // time_embed_1 is nn.SiLU()
        blocks["time_embed.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));

        if (sd_version_is_sdxl(version) || version == VERSION_SVD) {
            blocks["label_emb.0.0"] = std::shared_ptr<GGMLBlock>(new Linear(adm_in_channels, time_embed_dim));
            // label_emb_1 is nn.SiLU()
            blocks["label_emb.0.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
@ -536,7 +541,7 @@ struct UNetModelRunner : public GGMLRunner {
                    const std::string prefix,
                    SDVersion version = VERSION_SD1,
                    bool flash_attn   = false)
        : GGMLRunner(backend), unet(version, tensor_types, flash_attn) {
        unet.init(params_ctx, tensor_types, prefix);
    }
@ -566,6 +571,7 @@ struct UNetModelRunner : public GGMLRunner {
        context   = to_backend(context);
        y         = to_backend(y);
        timesteps = to_backend(timesteps);
        c_concat  = to_backend(c_concat);

        for (int i = 0; i < controls.size(); i++) {
            controls[i] = to_backend(controls[i]);

upscaler.cpp

@ -15,13 +15,13 @@ struct UpscalerGGML {
    }

    bool load_from_file(const std::string& esrgan_path) {
#ifdef SD_USE_CUDA
        LOG_DEBUG("Using CUDA backend");
        backend = ggml_backend_cuda_init(0);
#endif
#ifdef SD_USE_METAL
        LOG_DEBUG("Using Metal backend");
        ggml_log_set(ggml_log_callback_default, nullptr);
        backend = ggml_backend_metal_init();
#endif
#ifdef SD_USE_VULKAN

util.cpp

@ -348,7 +348,7 @@ void pretty_progress(int step, int steps, float time) {
        }
    }
    progress += "|";
    printf(time > 1.0f ? "\r%s %i/%i - %.2fs/it" : "\r%s %i/%i - %.2fit/s\033[K",
           progress.c_str(), step, steps,
           time > 1.0f || time == 0 ? time : (1.0f / time));
    fflush(stdout);  // for linux
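The appended \033[K is the ANSI "erase to end of line" control sequence; combined with the leading \r it lets each progress update overwrite the previous one without leaving stale characters when the new text is shorter (for example when the display flips between it/s and s/it). A tiny standalone demo of the same redraw idiom:

#include <chrono>
#include <cstdio>
#include <thread>

int main() {
    for (int step = 1; step <= 10; step++) {
        // \r returns to column 0, \033[K clears whatever the previous line left behind
        printf("\r|%.*s%*s| %i/10\033[K", step, "==========", 10 - step, "", step);
        fflush(stdout);
        std::this_thread::sleep_for(std::chrono::milliseconds(100));
    }
    printf("\n");
    return 0;
}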