diff --git a/common.hpp b/common.hpp index bfdcc00..b18ee51 100644 --- a/common.hpp +++ b/common.hpp @@ -367,7 +367,7 @@ protected: int64_t n_head; int64_t d_head; int64_t depth = 1; // 1 - int64_t context_dim = 768; // hidden_size, 1024 for VERSION_2_x + int64_t context_dim = 768; // hidden_size, 1024 for VERSION_SD2 public: SpatialTransformer(int64_t in_channels, diff --git a/conditioner.hpp b/conditioner.hpp index e01be2b..0e8f5a3 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -43,7 +43,7 @@ struct Conditioner { // ldm.modules.encoders.modules.FrozenCLIPEmbedder // Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283 struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { - SDVersion version = VERSION_1_x; + SDVersion version = VERSION_SD1; CLIPTokenizer tokenizer; ggml_type wtype; std::shared_ptr text_model; @@ -58,20 +58,20 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { FrozenCLIPEmbedderWithCustomWords(ggml_backend_t backend, ggml_type wtype, const std::string& embd_dir, - SDVersion version = VERSION_1_x, + SDVersion version = VERSION_SD1, int clip_skip = -1) - : version(version), tokenizer(version == VERSION_2_x ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) { + : version(version), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) { if (clip_skip <= 0) { clip_skip = 1; - if (version == VERSION_2_x || version == VERSION_XL) { + if (version == VERSION_SD2 || version == VERSION_SDXL) { clip_skip = 2; } } - if (version == VERSION_1_x) { + if (version == VERSION_SD1) { text_model = std::make_shared(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip); - } else if (version == VERSION_2_x) { + } else if (version == VERSION_SD2) { text_model = std::make_shared(backend, wtype, OPEN_CLIP_VIT_H_14, clip_skip); - } else if (version == VERSION_XL) { + } else if (version == VERSION_SDXL) { text_model = std::make_shared(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip, false); text_model2 = std::make_shared(backend, wtype, OPEN_CLIP_VIT_BIGG_14, clip_skip, false); } @@ -79,35 +79,35 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { void set_clip_skip(int clip_skip) { text_model->set_clip_skip(clip_skip); - if (version == VERSION_XL) { + if (version == VERSION_SDXL) { text_model2->set_clip_skip(clip_skip); } } void get_param_tensors(std::map& tensors) { text_model->get_param_tensors(tensors, "cond_stage_model.transformer.text_model"); - if (version == VERSION_XL) { + if (version == VERSION_SDXL) { text_model2->get_param_tensors(tensors, "cond_stage_model.1.transformer.text_model"); } } void alloc_params_buffer() { text_model->alloc_params_buffer(); - if (version == VERSION_XL) { + if (version == VERSION_SDXL) { text_model2->alloc_params_buffer(); } } void free_params_buffer() { text_model->free_params_buffer(); - if (version == VERSION_XL) { + if (version == VERSION_SDXL) { text_model2->free_params_buffer(); } } size_t get_params_buffer_size() { size_t buffer_size = text_model->get_params_buffer_size(); - if (version == VERSION_XL) { + if (version == VERSION_SDXL) { buffer_size += text_model2->get_params_buffer_size(); } return buffer_size; @@ -398,7 +398,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens); struct ggml_tensor* input_ids2 = NULL; size_t max_token_idx = 0; - if (version == VERSION_XL) { + if (version == VERSION_SDXL) { auto it = 
std::find(chunk_tokens.begin(), chunk_tokens.end(), tokenizer.EOS_TOKEN_ID); if (it != chunk_tokens.end()) { std::fill(std::next(it), chunk_tokens.end(), 0); @@ -423,7 +423,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { false, &chunk_hidden_states1, work_ctx); - if (version == VERSION_XL) { + if (version == VERSION_SDXL) { text_model2->compute(n_threads, input_ids2, 0, @@ -482,7 +482,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]); ggml_tensor* vec = NULL; - if (version == VERSION_XL) { + if (version == VERSION_SDXL) { int out_dim = 256; vec = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, adm_in_channels); // [0:1280] @@ -978,4 +978,230 @@ struct SD3CLIPEmbedder : public Conditioner { } }; + +struct FluxCLIPEmbedder : public Conditioner { + ggml_type wtype; + CLIPTokenizer clip_l_tokenizer; + T5UniGramTokenizer t5_tokenizer; + std::shared_ptr clip_l; + std::shared_ptr t5; + + FluxCLIPEmbedder(ggml_backend_t backend, + ggml_type wtype, + int clip_skip = -1) + : wtype(wtype) { + if (clip_skip <= 0) { + clip_skip = 2; + } + clip_l = std::make_shared(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip, true); + t5 = std::make_shared(backend, wtype); + } + + void set_clip_skip(int clip_skip) { + clip_l->set_clip_skip(clip_skip); + } + + void get_param_tensors(std::map& tensors) { + clip_l->get_param_tensors(tensors, "text_encoders.clip_l.text_model"); + t5->get_param_tensors(tensors, "text_encoders.t5xxl"); + } + + void alloc_params_buffer() { + clip_l->alloc_params_buffer(); + t5->alloc_params_buffer(); + } + + void free_params_buffer() { + clip_l->free_params_buffer(); + t5->free_params_buffer(); + } + + size_t get_params_buffer_size() { + size_t buffer_size = clip_l->get_params_buffer_size(); + buffer_size += t5->get_params_buffer_size(); + return buffer_size; + } + + std::vector, std::vector>> tokenize(std::string text, + size_t max_length = 0, + bool padding = false) { + auto parsed_attention = parse_prompt_attention(text); + + { + std::stringstream ss; + ss << "["; + for (const auto& item : parsed_attention) { + ss << "['" << item.first << "', " << item.second << "], "; + } + ss << "]"; + LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str()); + } + + auto on_new_token_cb = [&](std::string& str, std::vector& bpe_tokens) -> bool { + return false; + }; + + std::vector clip_l_tokens; + std::vector clip_l_weights; + std::vector t5_tokens; + std::vector t5_weights; + for (const auto& item : parsed_attention) { + const std::string& curr_text = item.first; + float curr_weight = item.second; + + std::vector curr_tokens = clip_l_tokenizer.encode(curr_text, on_new_token_cb); + clip_l_tokens.insert(clip_l_tokens.end(), curr_tokens.begin(), curr_tokens.end()); + clip_l_weights.insert(clip_l_weights.end(), curr_tokens.size(), curr_weight); + + curr_tokens = t5_tokenizer.Encode(curr_text, true); + t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end()); + t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight); + } + + clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, 77, padding); + t5_tokenizer.pad_tokens(t5_tokens, t5_weights, max_length, padding); + + // for (int i = 0; i < clip_l_tokens.size(); i++) { + // std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", "; + // } + // std::cout << std::endl; + + // for (int i = 0; i < t5_tokens.size(); i++) { + // std::cout << t5_tokens[i] << ":" << t5_weights[i] << ", "; + // } + // std::cout << std::endl; 
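// [Editorial note, illustrative only, not part of this patch] The two tokenizers
// are padded independently: clip_l always uses its fixed 77-token context, while
// the t5 sequence is padded to `max_length` (256 when called from
// get_learned_condition below). Every token inherits the attention weight of the
// prompt segment it came from, so a hypothetical prompt like "a (cat:1.3)" gives
// weight 1.3 to the tokens produced by "cat" and 1.0 to the rest. The caller
// therefore receives:
//   result[0] = { clip_l_tokens (77 ids),      clip_l_weights (77 floats) }
//   result[1] = { t5_tokens (max_length ids),  t5_weights (max_length floats) }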
+ + return {{clip_l_tokens, clip_l_weights}, {t5_tokens, t5_weights}}; + } + + SDCondition get_learned_condition_common(ggml_context* work_ctx, + int n_threads, + std::vector, std::vector>> token_and_weights, + int clip_skip, + bool force_zero_embeddings = false) { + set_clip_skip(clip_skip); + auto& clip_l_tokens = token_and_weights[0].first; + auto& clip_l_weights = token_and_weights[0].second; + auto& t5_tokens = token_and_weights[1].first; + auto& t5_weights = token_and_weights[1].second; + + int64_t t0 = ggml_time_ms(); + struct ggml_tensor* hidden_states = NULL; // [N, n_token, 4096] + struct ggml_tensor* chunk_hidden_states = NULL; // [n_token, 4096] + struct ggml_tensor* pooled = NULL; // [768,] + std::vector hidden_states_vec; + + size_t chunk_len = 256; + size_t chunk_count = t5_tokens.size() / chunk_len; + for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) { + // clip_l + if (chunk_idx == 0) { + size_t chunk_len_l = 77; + std::vector chunk_tokens(clip_l_tokens.begin(), + clip_l_tokens.begin() + chunk_len_l); + std::vector chunk_weights(clip_l_weights.begin(), + clip_l_weights.begin() + chunk_len_l); + + auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens); + size_t max_token_idx = 0; + + // auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID); + // max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1); + // clip_l->compute(n_threads, + // input_ids, + // 0, + // NULL, + // max_token_idx, + // true, + // &pooled, + // work_ctx); + + // clip_l.transformer.text_model.text_projection no in file, ignore + // TODO: use torch.eye(embed_dim) as default clip_l.transformer.text_model.text_projection + pooled = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768); + ggml_set_f32(pooled, 0.f); + } + + // t5 + { + std::vector chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len, + t5_tokens.begin() + (chunk_idx + 1) * chunk_len); + std::vector chunk_weights(t5_weights.begin() + chunk_idx * chunk_len, + t5_weights.begin() + (chunk_idx + 1) * chunk_len); + + auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens); + + t5->compute(n_threads, + input_ids, + &chunk_hidden_states, + work_ctx); + { + auto tensor = chunk_hidden_states; + float original_mean = ggml_tensor_mean(tensor); + for (int i2 = 0; i2 < tensor->ne[2]; i2++) { + for (int i1 = 0; i1 < tensor->ne[1]; i1++) { + for (int i0 = 0; i0 < tensor->ne[0]; i0++) { + float value = ggml_tensor_get_f32(tensor, i0, i1, i2); + value *= chunk_weights[i1]; + ggml_tensor_set_f32(tensor, value, i0, i1, i2); + } + } + } + float new_mean = ggml_tensor_mean(tensor); + ggml_tensor_scale(tensor, (original_mean / new_mean)); + } + } + + int64_t t1 = ggml_time_ms(); + LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0); + if (force_zero_embeddings) { + float* vec = (float*)chunk_hidden_states->data; + for (int i = 0; i < ggml_nelements(chunk_hidden_states); i++) { + vec[i] = 0; + } + } + + hidden_states_vec.insert(hidden_states_vec.end(), + (float*)chunk_hidden_states->data, + ((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states)); + } + + hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec); + hidden_states = ggml_reshape_2d(work_ctx, + hidden_states, + chunk_hidden_states->ne[0], + ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]); + return SDCondition(hidden_states, pooled, NULL); + } + + SDCondition get_learned_condition(ggml_context* work_ctx, + int n_threads, + const 
std::string& text, + int clip_skip, + int width, + int height, + int adm_in_channels = -1, + bool force_zero_embeddings = false) { + auto tokens_and_weights = tokenize(text, 256, true); + return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings); + } + + std::tuple> get_learned_condition_with_trigger(ggml_context* work_ctx, + int n_threads, + const std::string& text, + int clip_skip, + int width, + int height, + int num_input_imgs, + int adm_in_channels = -1, + bool force_zero_embeddings = false) { + GGML_ASSERT(0 && "Not implemented yet!"); + } + + std::string remove_trigger_from_prompt(ggml_context* work_ctx, + const std::string& prompt) { + GGML_ASSERT(0 && "Not implemented yet!"); + } +}; + #endif \ No newline at end of file diff --git a/control.hpp b/control.hpp index 3375e73..41f31ac 100644 --- a/control.hpp +++ b/control.hpp @@ -14,7 +14,7 @@ */ class ControlNetBlock : public GGMLBlock { protected: - SDVersion version = VERSION_1_x; + SDVersion version = VERSION_SD1; // network hparams int in_channels = 4; int out_channels = 4; @@ -26,19 +26,19 @@ protected: int time_embed_dim = 1280; // model_channels*4 int num_heads = 8; int num_head_channels = -1; // channels // num_heads - int context_dim = 768; // 1024 for VERSION_2_x, 2048 for VERSION_XL + int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL public: int model_channels = 320; - int adm_in_channels = 2816; // only for VERSION_XL + int adm_in_channels = 2816; // only for VERSION_SDXL - ControlNetBlock(SDVersion version = VERSION_1_x) + ControlNetBlock(SDVersion version = VERSION_SD1) : version(version) { - if (version == VERSION_2_x) { + if (version == VERSION_SD2) { context_dim = 1024; num_head_channels = 64; num_heads = -1; - } else if (version == VERSION_XL) { + } else if (version == VERSION_SDXL) { context_dim = 2048; attention_resolutions = {4, 2}; channel_mult = {1, 2, 4}; @@ -58,7 +58,7 @@ public: // time_embed_1 is nn.SiLU() blocks["time_embed.2"] = std::shared_ptr(new Linear(time_embed_dim, time_embed_dim)); - if (version == VERSION_XL || version == VERSION_SVD) { + if (version == VERSION_SDXL || version == VERSION_SVD) { blocks["label_emb.0.0"] = std::shared_ptr(new Linear(adm_in_channels, time_embed_dim)); // label_emb_1 is nn.SiLU() blocks["label_emb.0.2"] = std::shared_ptr(new Linear(time_embed_dim, time_embed_dim)); @@ -307,7 +307,7 @@ public: }; struct ControlNet : public GGMLRunner { - SDVersion version = VERSION_1_x; + SDVersion version = VERSION_SD1; ControlNetBlock control_net; ggml_backend_buffer_t control_buffer = NULL; // keep control output tensors in backend memory @@ -318,7 +318,7 @@ struct ControlNet : public GGMLRunner { ControlNet(ggml_backend_t backend, ggml_type wtype, - SDVersion version = VERSION_1_x) + SDVersion version = VERSION_SD1) : GGMLRunner(backend, wtype), control_net(version) { control_net.init(params_ctx, wtype); } diff --git a/denoiser.hpp b/denoiser.hpp index 26f4c85..85e4a0b 100644 --- a/denoiser.hpp +++ b/denoiser.hpp @@ -8,6 +8,7 @@ // Ref: https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/external.py #define TIMESTEPS 1000 +#define FLUX_TIMESTEPS 1000 struct SigmaSchedule { int version = 0; @@ -144,13 +145,13 @@ struct AYSSchedule : SigmaSchedule { std::vector results(n + 1); switch (version) { - case VERSION_2_x: /* fallthrough */ + case VERSION_SD2: /* fallthrough */ LOG_WARN("AYS not designed for SD2.X models"); - case VERSION_1_x: + case VERSION_SD1: LOG_INFO("AYS using SD1.5 noise levels"); 
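// [Editorial note] VERSION_SD2 intentionally falls through to the VERSION_SD1
// branch: after the warning is printed, SD2.x models reuse the SD1.5 AYS
// noise-level table (noise_levels[0]).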
inputs = noise_levels[0]; break; - case VERSION_XL: + case VERSION_SDXL: LOG_INFO("AYS using SDXL noise levels"); inputs = noise_levels[1]; break; @@ -350,6 +351,66 @@ struct DiscreteFlowDenoiser : public Denoiser { } }; + +float flux_time_shift(float mu, float sigma, float t) { + return std::exp(mu) / (std::exp(mu) + std::pow((1.0 / t - 1.0), sigma)); +} + +struct FluxFlowDenoiser : public Denoiser { + float sigmas[TIMESTEPS]; + float shift = 1.15f; + + float sigma_data = 1.0f; + + FluxFlowDenoiser(float shift = 1.15f) { + set_parameters(shift); + } + + void set_parameters(float shift = 1.15f) { + this->shift = shift; + for (int i = 1; i < TIMESTEPS + 1; i++) { + sigmas[i - 1] = t_to_sigma(i/TIMESTEPS * TIMESTEPS); + } + } + + float sigma_min() { + return sigmas[0]; + } + + float sigma_max() { + return sigmas[TIMESTEPS - 1]; + } + + float sigma_to_t(float sigma) { + return sigma; + } + + float t_to_sigma(float t) { + t = t + 1; + return flux_time_shift(shift, 1.0f, t / TIMESTEPS); + } + + std::vector get_scalings(float sigma) { + float c_skip = 1.0f; + float c_out = -sigma; + float c_in = 1.0f; + return {c_skip, c_out, c_in}; + } + + // this function will modify noise/latent + ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) { + ggml_tensor_scale(noise, sigma); + ggml_tensor_scale(latent, 1.0f - sigma); + ggml_tensor_add(latent, noise); + return latent; + } + + ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) { + ggml_tensor_scale(latent, 1.0f / (1.0f - sigma)); + return latent; + } +}; + typedef std::function denoise_cb_t; // k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t diff --git a/diffusion_model.hpp b/diffusion_model.hpp index fb28494..5c214e1 100644 --- a/diffusion_model.hpp +++ b/diffusion_model.hpp @@ -3,6 +3,7 @@ #include "mmdit.hpp" #include "unet.hpp" +#include "flux.hpp" struct DiffusionModel { virtual void compute(int n_threads, @@ -11,6 +12,7 @@ struct DiffusionModel { struct ggml_tensor* context, struct ggml_tensor* c_concat, struct ggml_tensor* y, + struct ggml_tensor* guidance, int num_video_frames = -1, std::vector controls = {}, float control_strength = 0.f, @@ -29,7 +31,7 @@ struct UNetModel : public DiffusionModel { UNetModel(ggml_backend_t backend, ggml_type wtype, - SDVersion version = VERSION_1_x) + SDVersion version = VERSION_SD1) : unet(backend, wtype, version) { } @@ -63,6 +65,7 @@ struct UNetModel : public DiffusionModel { struct ggml_tensor* context, struct ggml_tensor* c_concat, struct ggml_tensor* y, + struct ggml_tensor* guidance, int num_video_frames = -1, std::vector controls = {}, float control_strength = 0.f, @@ -77,7 +80,7 @@ struct MMDiTModel : public DiffusionModel { MMDiTModel(ggml_backend_t backend, ggml_type wtype, - SDVersion version = VERSION_3_2B) + SDVersion version = VERSION_SD3_2B) : mmdit(backend, wtype, version) { } @@ -111,6 +114,7 @@ struct MMDiTModel : public DiffusionModel { struct ggml_tensor* context, struct ggml_tensor* c_concat, struct ggml_tensor* y, + struct ggml_tensor* guidance, int num_video_frames = -1, std::vector controls = {}, float control_strength = 0.f, @@ -120,4 +124,54 @@ struct MMDiTModel : public DiffusionModel { } }; + +struct FluxModel : public DiffusionModel { + Flux::FluxRunner flux; + + FluxModel(ggml_backend_t backend, + ggml_type wtype, + SDVersion version = VERSION_FLUX_DEV) + : flux(backend, wtype, version) { + } + + void alloc_params_buffer() { + flux.alloc_params_buffer(); + } + + void free_params_buffer() { + 
flux.free_params_buffer(); + } + + void free_compute_buffer() { + flux.free_compute_buffer(); + } + + void get_param_tensors(std::map& tensors) { + flux.get_param_tensors(tensors, "model.diffusion_model"); + } + + size_t get_params_buffer_size() { + return flux.get_params_buffer_size(); + } + + int64_t get_adm_in_channels() { + return 768; + } + + void compute(int n_threads, + struct ggml_tensor* x, + struct ggml_tensor* timesteps, + struct ggml_tensor* context, + struct ggml_tensor* c_concat, + struct ggml_tensor* y, + struct ggml_tensor* guidance, + int num_video_frames = -1, + std::vector controls = {}, + float control_strength = 0.f, + struct ggml_tensor** output = NULL, + struct ggml_context* output_ctx = NULL) { + return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx); + } +}; + #endif \ No newline at end of file diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 6675095..bb773da 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -7,9 +7,8 @@ #include // #include "preprocessing.hpp" -#include "mmdit.hpp" +#include "flux.hpp" #include "stable-diffusion.h" -#include "t5.hpp" #define STB_IMAGE_IMPLEMENTATION #define STB_IMAGE_STATIC @@ -68,6 +67,9 @@ struct SDParams { SDMode mode = TXT2IMG; std::string model_path; + std::string clip_l_path; + std::string t5xxl_path; + std::string diffusion_model_path; std::string vae_path; std::string taesd_path; std::string esrgan_path; @@ -85,6 +87,7 @@ struct SDParams { std::string negative_prompt; float min_cfg = 1.0f; float cfg_scale = 7.0f; + float guidance = 3.5f; float style_ratio = 20.f; int clip_skip = -1; // <= 0 represents unspecified int width = 512; @@ -120,6 +123,9 @@ void print_params(SDParams params) { printf(" mode: %s\n", modes_str[params.mode]); printf(" model_path: %s\n", params.model_path.c_str()); printf(" wtype: %s\n", params.wtype < SD_TYPE_COUNT ? 
sd_type_name(params.wtype) : "unspecified"); + printf(" clip_l_path: %s\n", params.clip_l_path.c_str()); + printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str()); + printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str()); printf(" vae_path: %s\n", params.vae_path.c_str()); printf(" taesd_path: %s\n", params.taesd_path.c_str()); printf(" esrgan_path: %s\n", params.esrgan_path.c_str()); @@ -140,6 +146,7 @@ void print_params(SDParams params) { printf(" negative_prompt: %s\n", params.negative_prompt.c_str()); printf(" min_cfg: %.2f\n", params.min_cfg); printf(" cfg_scale: %.2f\n", params.cfg_scale); + printf(" guidance: %.2f\n", params.guidance); printf(" clip_skip: %d\n", params.clip_skip); printf(" width: %d\n", params.width); printf(" height: %d\n", params.height); @@ -240,6 +247,24 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.model_path = argv[i]; + } else if (arg == "--clip_l") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.clip_l_path = argv[i]; + } else if (arg == "--t5xxl") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.t5xxl_path = argv[i]; + } else if (arg == "--diffusion-model") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.diffusion_model_path = argv[i]; } else if (arg == "--vae") { if (++i >= argc) { invalid_arg = true; @@ -359,6 +384,12 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.cfg_scale = std::stof(argv[i]); + } else if (arg == "--guidance") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.guidance = std::stof(argv[i]); } else if (arg == "--strength") { if (++i >= argc) { invalid_arg = true; @@ -501,8 +532,8 @@ void parse_args(int argc, const char** argv, SDParams& params) { exit(1); } - if (params.model_path.length() == 0) { - fprintf(stderr, "error: the following arguments are required: model_path\n"); + if (params.model_path.length() == 0 && params.diffusion_model_path.length() == 0) { + fprintf(stderr, "error: the following arguments are required: model_path/diffusion_model\n"); print_usage(argc, argv); exit(1); } @@ -570,6 +601,7 @@ std::string get_image_params(SDParams params, int64_t seed) { } parameter_string += "Steps: " + std::to_string(params.sample_steps) + ", "; parameter_string += "CFG scale: " + std::to_string(params.cfg_scale) + ", "; + parameter_string += "Guidance: " + std::to_string(params.guidance) + ", "; parameter_string += "Seed: " + std::to_string(seed) + ", "; parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", "; parameter_string += "Model: " + sd_basename(params.model_path) + ", "; @@ -717,6 +749,9 @@ int main(int argc, const char* argv[]) { } sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(), + params.clip_l_path.c_str(), + params.t5xxl_path.c_str(), + params.diffusion_model_path.c_str(), params.vae_path.c_str(), params.taesd_path.c_str(), params.controlnet_path.c_str(), @@ -770,6 +805,7 @@ int main(int argc, const char* argv[]) { params.negative_prompt.c_str(), params.clip_skip, params.cfg_scale, + params.guidance, params.width, params.height, params.sample_method, @@ -830,6 +866,7 @@ int main(int argc, const char* argv[]) { params.negative_prompt.c_str(), params.clip_skip, params.cfg_scale, + params.guidance, params.width, params.height, params.sample_method, diff --git a/flux.hpp b/flux.hpp new file mode 100644 index 0000000..f8c69de --- /dev/null +++ b/flux.hpp @@ -0,0 +1,963 @@ +#ifndef __FLUX_HPP__ +#define 
__FLUX_HPP__ + +#include + +#include "ggml_extend.hpp" +#include "model.h" + +#define FLUX_GRAPH_SIZE 10240 + +namespace Flux { + +struct MLPEmbedder : public UnaryBlock { +public: + MLPEmbedder(int64_t in_dim, int64_t hidden_dim) { + blocks["in_layer"] = std::shared_ptr(new Linear(in_dim, hidden_dim, true)); + blocks["out_layer"] = std::shared_ptr(new Linear(hidden_dim, hidden_dim, true)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + // x: [..., in_dim] + // return: [..., hidden_dim] + auto in_layer = std::dynamic_pointer_cast(blocks["in_layer"]); + auto out_layer = std::dynamic_pointer_cast(blocks["out_layer"]); + + x = in_layer->forward(ctx, x); + x = ggml_silu_inplace(ctx, x); + x = out_layer->forward(ctx, x); + return x; + } +}; + +class RMSNorm : public UnaryBlock { +protected: + int64_t hidden_size; + float eps; + + void init_params(struct ggml_context* ctx, ggml_type wtype) { + params["scale"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size); + } + +public: + RMSNorm(int64_t hidden_size, + float eps = 1e-06f) + : hidden_size(hidden_size), + eps(eps) {} + + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* w = params["scale"]; + x = ggml_rms_norm(ctx, x, eps); + x = ggml_mul(ctx, x, w); + return x; + } +}; + + +struct QKNorm : public GGMLBlock { +public: + QKNorm(int64_t dim) { + blocks["query_norm"] = std::shared_ptr(new RMSNorm(dim)); + blocks["key_norm"] = std::shared_ptr(new RMSNorm(dim)); + } + + struct ggml_tensor* query_norm(struct ggml_context* ctx, struct ggml_tensor* x) { + // x: [..., dim] + // return: [..., dim] + auto norm = std::dynamic_pointer_cast(blocks["query_norm"]); + + x = norm->forward(ctx, x); + return x; + } + + struct ggml_tensor* key_norm(struct ggml_context* ctx, struct ggml_tensor* x) { + // x: [..., dim] + // return: [..., dim] + auto norm = std::dynamic_pointer_cast(blocks["key_norm"]); + + x = norm->forward(ctx, x); + return x; + } +}; + +__STATIC_INLINE__ struct ggml_tensor* apply_rope(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* pe) { + // x: [N, L, n_head, d_head] + // pe: [L, d_head/2, 2, 2] + int64_t d_head = x->ne[0]; + int64_t n_head = x->ne[1]; + int64_t L = x->ne[2]; + int64_t N = x->ne[3]; + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, n_head, L, d_head] + x = ggml_reshape_4d(ctx, x, 2, d_head/2, L, n_head * N); // [N * n_head, L, d_head/2, 2] + x = ggml_cont(ctx, ggml_permute(ctx, x, 3, 0, 1, 2)); // [2, N * n_head, L, d_head/2] + + int64_t offset = x->nb[2] * x->ne[2]; + auto x_0 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 0); // [N * n_head, L, d_head/2] + auto x_1 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 1); // [N * n_head, L, d_head/2] + x_0 = ggml_reshape_4d(ctx, x_0, 1, x_0->ne[0], x_0->ne[1], x_0->ne[2]); // [N * n_head, L, d_head/2, 1] + x_1 = ggml_reshape_4d(ctx, x_1, 1, x_1->ne[0], x_1->ne[1], x_1->ne[2]); // [N * n_head, L, d_head/2, 1] + auto temp_x = ggml_new_tensor_4d(ctx, x_0->type, 2, x_0->ne[1], x_0->ne[2], x_0->ne[3]); + x_0 = ggml_repeat(ctx, x_0, temp_x); // [N * n_head, L, d_head/2, 2] + x_1 = ggml_repeat(ctx, x_1, temp_x); // [N * n_head, L, d_head/2, 2] + + pe = ggml_cont(ctx, ggml_permute(ctx, pe, 3, 0, 1, 2)); // [2, L, d_head/2, 2] + offset = pe->nb[2] * pe->ne[2]; + auto pe_0 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 0); // [L, d_head/2, 2] + auto 
pe_1 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 1); // [L, d_head/2, 2] + + auto x_out = ggml_add_inplace(ctx, ggml_mul(ctx, x_0, pe_0), ggml_mul(ctx, x_1, pe_1)); // [N * n_head, L, d_head/2, 2] + x_out = ggml_reshape_3d(ctx, x_out, d_head, L, n_head*N); // [N*n_head, L, d_head] + return x_out; +} + +__STATIC_INLINE__ struct ggml_tensor* attention(struct ggml_context* ctx, + struct ggml_tensor* q, + struct ggml_tensor* k, + struct ggml_tensor* v, + struct ggml_tensor* pe) { + // q,k,v: [N, L, n_head, d_head] + // pe: [L, d_head/2, 2, 2] + // return: [N, L, n_head*d_head] + q = apply_rope(ctx, q, pe); // [N*n_head, L, d_head] + k = apply_rope(ctx, k, pe); // [N*n_head, L, d_head] + + auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], NULL, false, true); // [N, L, n_head*d_head] + return x; +} + +struct SelfAttention : public GGMLBlock { +public: + int64_t num_heads; + +public: + SelfAttention(int64_t dim, + int64_t num_heads = 8, + bool qkv_bias = false) + : num_heads(num_heads) { + int64_t head_dim = dim / num_heads; + blocks["qkv"] = std::shared_ptr(new Linear(dim, dim * 3, qkv_bias)); + blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); + blocks["proj"] = std::shared_ptr(new Linear(dim, dim)); + } + + std::vector pre_attention(struct ggml_context* ctx, struct ggml_tensor* x) { + auto qkv_proj = std::dynamic_pointer_cast(blocks["qkv"]); + auto norm = std::dynamic_pointer_cast(blocks["norm"]); + + + auto qkv = qkv_proj->forward(ctx, x); + auto qkv_vec = split_qkv(ctx, qkv); + int64_t head_dim = qkv_vec[0]->ne[0] / num_heads; + auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); + auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); + auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); + q = norm->query_norm(ctx, q); + k = norm->key_norm(ctx, k); + return {q, k, v}; + } + + struct ggml_tensor* post_attention(struct ggml_context* ctx, struct ggml_tensor* x) { + auto proj = std::dynamic_pointer_cast(blocks["proj"]); + + x = proj->forward(ctx, x); // [N, n_token, dim] + return x; + } + + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* pe) { + // x: [N, n_token, dim] + // pe: [n_token, d_head/2, 2, 2] + // return [N, n_token, dim] + auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head] + x = attention(ctx, qkv[0], qkv[1], qkv[2], pe); // [N, n_token, dim] + x = post_attention(ctx, x); // [N, n_token, dim] + return x; + } +}; + + +struct ModulationOut { + ggml_tensor* shift = NULL; + ggml_tensor* scale = NULL; + ggml_tensor* gate = NULL; + + ModulationOut(ggml_tensor* shift = NULL, ggml_tensor* scale = NULL, ggml_tensor* gate = NULL) + : shift(shift), scale(scale), gate(gate) {} +}; + +struct Modulation : public GGMLBlock { +public: + bool is_double; + int multiplier; +public: + Modulation(int64_t dim, bool is_double): is_double(is_double) { + multiplier = is_double? 
6 : 3; + blocks["lin"] = std::shared_ptr(new Linear(dim, dim * multiplier)); + } + + std::vector forward(struct ggml_context* ctx, struct ggml_tensor* vec) { + // x: [N, dim] + // return: [ModulationOut, ModulationOut] + auto lin = std::dynamic_pointer_cast(blocks["lin"]); + + auto out = ggml_silu(ctx, vec); + out = lin->forward(ctx, out); // [N, multiplier*dim] + + auto m = ggml_reshape_3d(ctx, out, vec->ne[0], multiplier, vec->ne[1]); // [N, multiplier, dim] + m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [multiplier, N, dim] + + int64_t offset = m->nb[1] * m->ne[1]; + auto shift_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, dim] + auto scale_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, dim] + auto gate_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, dim] + + if (is_double) { + auto shift_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, dim] + auto scale_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, dim] + auto gate_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, dim] + return {ModulationOut(shift_0, scale_0, gate_0), ModulationOut(shift_1, scale_1, gate_1)}; + } + + return {ModulationOut(shift_0, scale_0, gate_0), ModulationOut()}; + } +}; + +__STATIC_INLINE__ struct ggml_tensor* modulate(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* shift, + struct ggml_tensor* scale) { + // x: [N, L, C] + // scale: [N, C] + // shift: [N, C] + scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]); // [N, 1, C] + shift = ggml_reshape_3d(ctx, shift, shift->ne[0], 1, shift->ne[1]); // [N, 1, C] + x = ggml_add(ctx, x, ggml_mul(ctx, x, scale)); + x = ggml_add(ctx, x, shift); + return x; +} + +struct DoubleStreamBlock : public GGMLBlock { +public: + DoubleStreamBlock(int64_t hidden_size, + int64_t num_heads, + float mlp_ratio, + bool qkv_bias = false) { + int64_t mlp_hidden_dim = hidden_size * mlp_ratio; + blocks["img_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); + blocks["img_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); + blocks["img_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias)); + + blocks["img_norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); + blocks["img_mlp.0"] = std::shared_ptr(new Linear(hidden_size, mlp_hidden_dim)); + // img_mlp.1 is nn.GELU(approximate="tanh") + blocks["img_mlp.2"] = std::shared_ptr(new Linear(mlp_hidden_dim, hidden_size)); + + blocks["txt_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); + blocks["txt_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); + blocks["txt_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias)); + + blocks["txt_norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); + blocks["txt_mlp.0"] = std::shared_ptr(new Linear(hidden_size, mlp_hidden_dim)); + // img_mlp.1 is nn.GELU(approximate="tanh") + blocks["txt_mlp.2"] = std::shared_ptr(new Linear(mlp_hidden_dim, hidden_size)); + } + + std::pair forward(struct ggml_context* ctx, + struct ggml_tensor* img, + struct ggml_tensor* txt, + struct ggml_tensor* vec, + struct ggml_tensor* pe) { + // img: [N, n_img_token, hidden_size] + // txt: [N, n_txt_token, hidden_size] + // pe: [n_img_token + n_txt_token, d_head/2, 2, 2] + // return: ([N, n_img_token, hidden_size], [N, n_txt_token, hidden_size]) + + auto img_mod = 
std::dynamic_pointer_cast(blocks["img_mod"]); + auto img_norm1 = std::dynamic_pointer_cast(blocks["img_norm1"]); + auto img_attn = std::dynamic_pointer_cast(blocks["img_attn"]); + + auto img_norm2 = std::dynamic_pointer_cast(blocks["img_norm2"]); + auto img_mlp_0 = std::dynamic_pointer_cast(blocks["img_mlp.0"]); + auto img_mlp_2 = std::dynamic_pointer_cast(blocks["img_mlp.2"]); + + auto txt_mod = std::dynamic_pointer_cast(blocks["txt_mod"]); + auto txt_norm1 = std::dynamic_pointer_cast(blocks["txt_norm1"]); + auto txt_attn = std::dynamic_pointer_cast(blocks["txt_attn"]); + + auto txt_norm2 = std::dynamic_pointer_cast(blocks["txt_norm2"]); + auto txt_mlp_0 = std::dynamic_pointer_cast(blocks["txt_mlp.0"]); + auto txt_mlp_2 = std::dynamic_pointer_cast(blocks["txt_mlp.2"]); + + + auto img_mods = img_mod->forward(ctx, vec); + ModulationOut img_mod1 = img_mods[0]; + ModulationOut img_mod2 = img_mods[1]; + auto txt_mods = txt_mod->forward(ctx, vec); + ModulationOut txt_mod1 = txt_mods[0]; + ModulationOut txt_mod2 = txt_mods[1]; + + // prepare image for attention + auto img_modulated = img_norm1->forward(ctx, img); + img_modulated = Flux::modulate(ctx, img_modulated, img_mod1.shift, img_mod1.scale); + auto img_qkv = img_attn->pre_attention(ctx, img_modulated); // q,k,v: [N, n_img_token, n_head, d_head] + auto img_q = img_qkv[0]; + auto img_k = img_qkv[1]; + auto img_v = img_qkv[2]; + + // prepare txt for attention + auto txt_modulated = txt_norm1->forward(ctx, txt); + txt_modulated = Flux::modulate(ctx, txt_modulated, txt_mod1.shift, txt_mod1.scale); + auto txt_qkv = txt_attn->pre_attention(ctx, txt_modulated); // q,k,v: [N, n_txt_token, n_head, d_head] + auto txt_q = txt_qkv[0]; + auto txt_k = txt_qkv[1]; + auto txt_v = txt_qkv[2]; + + // run actual attention + auto q = ggml_concat(ctx, txt_q, img_q, 2); // [N, n_txt_token + n_img_token, n_head, d_head] + auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head] + auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head] + + auto attn = attention(ctx, q, k, v, pe); // [N, n_txt_token + n_img_token, n_head*d_head] + attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] + auto txt_attn_out = ggml_view_3d(ctx, + attn, + attn->ne[0], + attn->ne[1], + txt->ne[1], + attn->nb[1], + attn->nb[2], + 0); // [n_txt_token, N, hidden_size] + txt_attn_out = ggml_cont(ctx, ggml_permute(ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size] + auto img_attn_out = ggml_view_3d(ctx, + attn, + attn->ne[0], + attn->ne[1], + img->ne[1], + attn->nb[1], + attn->nb[2], + attn->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size] + img_attn_out = ggml_cont(ctx, ggml_permute(ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size] + + // calculate the img bloks + img = ggml_add(ctx, img, ggml_mul(ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate)); + + auto img_mlp_out = img_mlp_0->forward(ctx, Flux::modulate(ctx, img_norm2->forward(ctx, img), img_mod2.shift, img_mod2.scale)); + img_mlp_out = ggml_gelu_inplace(ctx, img_mlp_out); + img_mlp_out = img_mlp_2->forward(ctx, img_mlp_out); + + img = ggml_add(ctx, img, ggml_mul(ctx, img_mlp_out, img_mod2.gate)); + + // calculate the txt bloks + txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_attn->post_attention(ctx, txt_attn_out), txt_mod1.gate)); + + auto txt_mlp_out = txt_mlp_0->forward(ctx, Flux::modulate(ctx, txt_norm2->forward(ctx, txt), txt_mod2.shift, txt_mod2.scale)); + 
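// [Editorial note, illustrative only] Both streams apply the same gated adaLN
// residual pattern as the attention path above; for the txt stream this is
// conceptually:
//   txt = txt + txt_mod2.gate *
//         txt_mlp.2( gelu( txt_mlp.0( modulate(txt_norm2(txt), txt_mod2.shift, txt_mod2.scale) ) ) )
// where modulate(x, shift, scale) = x * (1 + scale) + shift (see Flux::modulate).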
txt_mlp_out = ggml_gelu_inplace(ctx, txt_mlp_out); + txt_mlp_out = txt_mlp_2->forward(ctx, txt_mlp_out); + + txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_mlp_out, txt_mod2.gate)); + + return {img, txt}; + } +}; + + +struct SingleStreamBlock : public GGMLBlock { +public: + int64_t num_heads; + int64_t hidden_size; + int64_t mlp_hidden_dim; +public: + SingleStreamBlock(int64_t hidden_size, + int64_t num_heads, + float mlp_ratio = 4.0f, + float qk_scale = 0.f) : + hidden_size(hidden_size), num_heads(num_heads) { + int64_t head_dim = hidden_size / num_heads; + float scale = qk_scale; + if (scale <= 0.f) { + scale = 1 / sqrt((float)head_dim); + } + mlp_hidden_dim = hidden_size * mlp_ratio; + + blocks["linear1"] = std::shared_ptr(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim)); + blocks["linear2"] = std::shared_ptr(new Linear(hidden_size + mlp_hidden_dim, hidden_size)); + blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); + blocks["pre_norm"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); + // mlp_act is nn.GELU(approximate="tanh") + blocks["modulation"] = std::shared_ptr(new Modulation(hidden_size, false)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* vec, + struct ggml_tensor* pe) { + // x: [N, n_token, hidden_size] + // pe: [n_token, d_head/2, 2, 2] + // return: [N, n_token, hidden_size] + + auto linear1 = std::dynamic_pointer_cast(blocks["linear1"]); + auto linear2 = std::dynamic_pointer_cast(blocks["linear2"]); + auto norm = std::dynamic_pointer_cast(blocks["norm"]); + auto pre_norm = std::dynamic_pointer_cast(blocks["pre_norm"]); + auto modulation = std::dynamic_pointer_cast(blocks["modulation"]); + + auto mods = modulation->forward(ctx, vec); + ModulationOut mod = mods[0]; + + auto x_mod = Flux::modulate(ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale); + auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim] + qkv_mlp = ggml_cont(ctx, ggml_permute(ctx, qkv_mlp, 2, 0, 1, 3)); // [hidden_size * 3 + mlp_hidden_dim, N, n_token] + + auto qkv = ggml_view_3d(ctx, + qkv_mlp, + qkv_mlp->ne[0], + qkv_mlp->ne[1], + hidden_size * 3, + qkv_mlp->nb[1], + qkv_mlp->nb[2], + 0); // [hidden_size * 3 , N, n_token] + qkv = ggml_cont(ctx, ggml_permute(ctx, qkv, 1, 2, 0, 3)); // [N, n_token, hidden_size * 3] + auto mlp = ggml_view_3d(ctx, + qkv_mlp, + qkv_mlp->ne[0], + qkv_mlp->ne[1], + mlp_hidden_dim, + qkv_mlp->nb[1], + qkv_mlp->nb[2], + qkv_mlp->nb[2] * hidden_size * 3); // [mlp_hidden_dim , N, n_token] + mlp = ggml_cont(ctx, ggml_permute(ctx, mlp, 1, 2, 0, 3)); // [N, n_token, mlp_hidden_dim] + + auto qkv_vec = split_qkv(ctx, qkv); // q,k,v: [N, n_token, hidden_size] + int64_t head_dim = hidden_size / num_heads; + auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head] + auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head] + auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head] + q = norm->query_norm(ctx, q); + k = norm->key_norm(ctx, k); + auto attn = attention(ctx, q, k, v, pe); // [N, n_token, hidden_size] + + auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim] + auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size] + + output = ggml_add(ctx, x, 
ggml_mul(ctx, output, mod.gate)); + return output; + } +}; + + +struct LastLayer : public GGMLBlock { +public: + LastLayer(int64_t hidden_size, + int64_t patch_size, + int64_t out_channels) { + blocks["norm_final"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false)); + blocks["linear"] = std::shared_ptr(new Linear(hidden_size, patch_size * patch_size * out_channels)); + blocks["adaLN_modulation.1"] = std::shared_ptr(new Linear(hidden_size, 2 * hidden_size)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* c) { + // x: [N, n_token, hidden_size] + // c: [N, hidden_size] + // return: [N, n_token, patch_size * patch_size * out_channels] + auto norm_final = std::dynamic_pointer_cast(blocks["norm_final"]); + auto linear = std::dynamic_pointer_cast(blocks["linear"]); + auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); + + auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, 2 * hidden_size] + m = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size] + m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size] + + int64_t offset = m->nb[1] * m->ne[1]; + auto shift = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] + auto scale = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] + + x = Flux::modulate(ctx, norm_final->forward(ctx, x), shift, scale); + x = linear->forward(ctx, x); + + return x; + } +}; + +struct FluxParams { + int64_t in_channels = 64; + int64_t vec_in_dim=768; + int64_t context_in_dim = 4096; + int64_t hidden_size = 3072; + float mlp_ratio = 4.0f; + int64_t num_heads = 24; + int64_t depth = 19; + int64_t depth_single_blocks = 38; + std::vector axes_dim = {16, 56, 56}; + int64_t axes_dim_sum = 128; + int theta = 10000; + bool qkv_bias = true; + bool guidance_embed = true; +}; + + +struct Flux : public GGMLBlock { +public: + std::vector linspace(float start, float end, int num) { + std::vector result(num); + float step = (end - start) / (num - 1); + for (int i = 0; i < num; ++i) { + result[i] = start + i * step; + } + return result; + } + + std::vector> transpose(const std::vector>& mat) { + int rows = mat.size(); + int cols = mat[0].size(); + std::vector> transposed(cols, std::vector(rows)); + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < cols; ++j) { + transposed[j][i] = mat[i][j]; + } + } + return transposed; + } + + std::vector flatten(const std::vector>& vec) { + std::vector flat_vec; + for (const auto& sub_vec : vec) { + flat_vec.insert(flat_vec.end(), sub_vec.begin(), sub_vec.end()); + } + return flat_vec; + } + + std::vector> rope(const std::vector& pos, int dim, int theta) { + assert(dim % 2 == 0); + int half_dim = dim / 2; + + std::vector scale = linspace(0, (dim * 1.0f - 2) / dim, half_dim); + + std::vector omega(half_dim); + for (int i = 0; i < half_dim; ++i) { + omega[i] = 1.0 / std::pow(theta, scale[i]); + } + + int pos_size = pos.size(); + std::vector> out(pos_size, std::vector(half_dim)); + for (int i = 0; i < pos_size; ++i) { + for (int j = 0; j < half_dim; ++j) { + out[i][j] = pos[i] * omega[j]; + } + } + + std::vector> result(pos_size, std::vector(half_dim * 4)); + for (int i = 0; i < pos_size; ++i) { + for (int j = 0; j < half_dim; ++j) { + result[i][4 * j] = std::cos(out[i][j]); + result[i][4 * j + 1] = -std::sin(out[i][j]); + result[i][4 * j + 2] = std::sin(out[i][j]); + result[i][4 * j + 3] = std::cos(out[i][j]); + } + } + + return 
result; + } + + // Generate IDs for image patches and text + std::vector> gen_ids(int h, int w, int patch_size, int bs, int context_len) { + int h_len = (h + (patch_size / 2)) / patch_size; + int w_len = (w + (patch_size / 2)) / patch_size; + + std::vector> img_ids(h_len * w_len, std::vector(3, 0.0)); + + std::vector row_ids = linspace(0, h_len - 1, h_len); + std::vector col_ids = linspace(0, w_len - 1, w_len); + + for (int i = 0; i < h_len; ++i) { + for (int j = 0; j < w_len; ++j) { + img_ids[i * w_len + j][1] = row_ids[i]; + img_ids[i * w_len + j][2] = col_ids[j]; + } + } + + std::vector> img_ids_repeated(bs * img_ids.size(), std::vector(3)); + for (int i = 0; i < bs; ++i) { + for (int j = 0; j < img_ids.size(); ++j) { + img_ids_repeated[i * img_ids.size() + j] = img_ids[j]; + } + } + + std::vector> txt_ids(bs * context_len, std::vector(3, 0.0)); + std::vector> ids(bs * (context_len + img_ids.size()), std::vector(3)); + for (int i = 0; i < bs; ++i) { + for (int j = 0; j < context_len; ++j) { + ids[i * (context_len + img_ids.size()) + j] = txt_ids[j]; + } + for (int j = 0; j < img_ids.size(); ++j) { + ids[i * (context_len + img_ids.size()) + context_len + j] = img_ids_repeated[i * img_ids.size() + j]; + } + } + + return ids; + } + + // Generate positional embeddings + std::vector gen_pe(int h, int w, int patch_size, int bs, int context_len, int theta, const std::vector& axes_dim) { + std::vector> ids = gen_ids(h, w, patch_size, bs, context_len); + std::vector> trans_ids = transpose(ids); + size_t pos_len = ids.size(); + int num_axes = axes_dim.size(); + for (int i = 0; i < pos_len; i++) { + // std::cout << trans_ids[0][i] << " " << trans_ids[1][i] << " " << trans_ids[2][i] << std::endl; + } + + + int emb_dim = 0; + for (int d : axes_dim) emb_dim += d / 2; + + std::vector> emb(bs * pos_len, std::vector(emb_dim * 2 * 2, 0.0)); + int offset = 0; + for (int i = 0; i < num_axes; ++i) { + std::vector> rope_emb = rope(trans_ids[i], axes_dim[i], theta); // [bs*pos_len, axes_dim[i]/2 * 2 * 2] + for (int b = 0; b < bs; ++b) { + for (int j = 0; j < pos_len; ++j) { + for (int k = 0; k < rope_emb[0].size(); ++k) { + emb[b * pos_len + j][offset + k] = rope_emb[j][k]; + } + } + } + offset += rope_emb[0].size(); + } + + return flatten(emb); + } +public: + FluxParams params; + Flux() {} + Flux(FluxParams params) : params(params) { + int64_t out_channels = params.in_channels; + int64_t pe_dim = params.hidden_size / params.num_heads; + + blocks["img_in"] = std::shared_ptr(new Linear(params.in_channels, params.hidden_size)); + blocks["time_in"] = std::shared_ptr(new MLPEmbedder(256, params.hidden_size)); + blocks["vector_in"] = std::shared_ptr(new MLPEmbedder(params.vec_in_dim, params.hidden_size)); + if (params.guidance_embed) { + blocks["guidance_in"] = std::shared_ptr(new MLPEmbedder(256, params.hidden_size)); + } + blocks["txt_in"] = std::shared_ptr(new Linear(params.context_in_dim, params.hidden_size)); + + for (int i = 0; i < params.depth; i++) { + blocks["double_blocks." + std::to_string(i)] = std::shared_ptr(new DoubleStreamBlock(params.hidden_size, + params.num_heads, + params.mlp_ratio, + params.qkv_bias)); + } + + for (int i = 0; i < params.depth_single_blocks; i++) { + blocks["single_blocks." 
+ std::to_string(i)] = std::shared_ptr(new SingleStreamBlock(params.hidden_size, + params.num_heads, + params.mlp_ratio)); + } + + blocks["final_layer"] = std::shared_ptr(new LastLayer(params.hidden_size, 1, out_channels)); + } + + struct ggml_tensor* patchify(struct ggml_context* ctx, + struct ggml_tensor* x, + int64_t patch_size) { + // x: [N, C, H, W] + // return: [N, h*w, C * patch_size * patch_size] + int64_t N = x->ne[3]; + int64_t C = x->ne[2]; + int64_t H = x->ne[1]; + int64_t W = x->ne[0]; + int64_t p = patch_size; + int64_t h = H / patch_size; + int64_t w = W / patch_size; + + GGML_ASSERT(h * p == H && w * p == W); + + x = ggml_reshape_4d(ctx, x, p, w, p, h*C*N); // [N*C*h, p, w, p] + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, w, p, p] + x = ggml_reshape_4d(ctx, x, p * p, w * h, C, N); // [N, C, h*w, p*p] + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, h*w, C, p*p] + x = ggml_reshape_3d(ctx, x, p*p*C, w*h, N); // [N, h*w, C*p*p] + return x; + } + + struct ggml_tensor* unpatchify(struct ggml_context* ctx, + struct ggml_tensor* x, + int64_t h, + int64_t w, + int64_t patch_size) { + // x: [N, h*w, C*patch_size*patch_size] + // return: [N, C, H, W] + int64_t N = x->ne[2]; + int64_t C = x->ne[0] / patch_size / patch_size; + int64_t H = h * patch_size; + int64_t W = w * patch_size; + int64_t p = patch_size; + + GGML_ASSERT(C * p * p == x->ne[0]); + + x = ggml_reshape_4d(ctx, x, p * p, C, w * h, N); // [N, h*w, C, p*p] + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, C, h*w, p*p] + x = ggml_reshape_4d(ctx, x, p, p, w, h * C * N); // [N*C*h, w, p, p] + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, p, w, p] + x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, h*p, w*p] + + return x; + } + + struct ggml_tensor* forward_orig(struct ggml_context* ctx, + struct ggml_tensor* img, + struct ggml_tensor* txt, + struct ggml_tensor* timesteps, + struct ggml_tensor* y, + struct ggml_tensor* guidance, + struct ggml_tensor* pe) { + auto img_in = std::dynamic_pointer_cast(blocks["img_in"]); + auto time_in = std::dynamic_pointer_cast(blocks["time_in"]); + auto vector_in = std::dynamic_pointer_cast(blocks["vector_in"]); + auto txt_in = std::dynamic_pointer_cast(blocks["txt_in"]); + auto final_layer = std::dynamic_pointer_cast(blocks["final_layer"]); + + img = img_in->forward(ctx, img); + auto vec = time_in->forward(ctx, ggml_nn_timestep_embedding(ctx, timesteps, 256, 10000, 1000.f)); + + if (params.guidance_embed) { + GGML_ASSERT(guidance != NULL); + auto guidance_in = std::dynamic_pointer_cast(blocks["guidance_in"]); + // bf16 and fp16 result is different + auto g_in = ggml_nn_timestep_embedding(ctx, guidance, 256, 10000, 1000.f); + vec = ggml_add(ctx, vec, guidance_in->forward(ctx, g_in)); + } + + vec = ggml_add(ctx, vec, vector_in->forward(ctx, y)); + txt = txt_in->forward(ctx, txt); + + for (int i = 0; i < params.depth; i++) { + auto block = std::dynamic_pointer_cast(blocks["double_blocks." + std::to_string(i)]); + + auto img_txt = block->forward(ctx, img, txt, vec, pe); + img = img_txt.first; // [N, n_img_token, hidden_size] + txt = img_txt.second; // [N, n_txt_token, hidden_size] + } + + auto txt_img = ggml_concat(ctx, txt, img, 1); // [N, n_txt_token + n_img_token, hidden_size] + for (int i = 0; i < params.depth_single_blocks; i++) { + auto block = std::dynamic_pointer_cast(blocks["single_blocks." 
+ std::to_string(i)]); + + txt_img = block->forward(ctx, txt_img, vec, pe); + } + + txt_img = ggml_cont(ctx, ggml_permute(ctx, txt_img, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] + img = ggml_view_3d(ctx, + txt_img, + txt_img->ne[0], + txt_img->ne[1], + img->ne[1], + txt_img->nb[1], + txt_img->nb[2], + txt_img->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size] + img = ggml_cont(ctx, ggml_permute(ctx, img, 0, 2, 1, 3)); // [N, n_img_token, hidden_size] + + img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels) + + return img; + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* timestep, + struct ggml_tensor* context, + struct ggml_tensor* y, + struct ggml_tensor* guidance, + struct ggml_tensor* pe) { + // Forward pass of DiT. + // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + // timestep: (N,) tensor of diffusion timesteps + // context: (N, L, D) + // y: (N, adm_in_channels) tensor of class labels + // guidance: (N,) + // pe: (L, d_head/2, 2, 2) + // return: (N, C, H, W) + + GGML_ASSERT(x->ne[3] == 1); + + int64_t W = x->ne[0]; + int64_t H = x->ne[1]; + int64_t patch_size = 2; + int pad_h = (patch_size - H % patch_size) % patch_size; + int pad_w = (patch_size - W % patch_size) % patch_size; + x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w] + + // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) + auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size] + + auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe); // [N, h*w, C * patch_size * patch_size] + + // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2) + out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size); // [N, C, H + pad_h, W + pad_w] + + return out; + } +}; + + +struct FluxRunner : public GGMLRunner { +public: + FluxParams flux_params; + Flux flux; + std::vector pe_vec; // for cache + + FluxRunner(ggml_backend_t backend, + ggml_type wtype, + SDVersion version = VERSION_FLUX_DEV) + : GGMLRunner(backend, wtype) { + if (version == VERSION_FLUX_SCHNELL) { + flux_params.guidance_embed = false; + } + flux = Flux(flux_params); + flux.init(params_ctx, wtype); + } + + std::string get_desc() { + return "flux"; + } + + void get_param_tensors(std::map& tensors, const std::string prefix) { + flux.get_param_tensors(tensors, prefix); + } + + struct ggml_cgraph* build_graph(struct ggml_tensor* x, + struct ggml_tensor* timesteps, + struct ggml_tensor* context, + struct ggml_tensor* y, + struct ggml_tensor* guidance) { + GGML_ASSERT(x->ne[3] == 1); + struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false); + + x = to_backend(x); + context = to_backend(context); + y = to_backend(y); + timesteps = to_backend(timesteps); + guidance = to_backend(guidance); + + pe_vec = flux.gen_pe(x->ne[1], x->ne[0], 2, x->ne[3], context->ne[1], flux_params.theta, flux_params.axes_dim); + int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2; + // LOG_DEBUG("pos_len %d", pos_len); + auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum/2, pos_len); + // pe->data = pe_vec.data(); + // print_ggml_tensor(pe); + // pe->data = NULL; + set_backend_tensor_data(pe, pe_vec.data()); + + + struct ggml_tensor* out = flux.forward(compute_ctx, + x, + timesteps, + context, + y, + guidance, 
+ pe); + + ggml_build_forward_expand(gf, out); + + return gf; + } + + void compute(int n_threads, + struct ggml_tensor* x, + struct ggml_tensor* timesteps, + struct ggml_tensor* context, + struct ggml_tensor* y, + struct ggml_tensor* guidance, + struct ggml_tensor** output = NULL, + struct ggml_context* output_ctx = NULL) { + // x: [N, in_channels, h, w] + // timesteps: [N, ] + // context: [N, max_position, hidden_size] + // y: [N, adm_in_channels] or [1, adm_in_channels] + // guidance: [N, ] + auto get_graph = [&]() -> struct ggml_cgraph* { + return build_graph(x, timesteps, context, y, guidance); + }; + + GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); + } + + void test() { + struct ggml_init_params params; + params.mem_size = static_cast(20 * 1024 * 1024); // 20 MB + params.mem_buffer = NULL; + params.no_alloc = false; + + struct ggml_context* work_ctx = ggml_init(params); + GGML_ASSERT(work_ctx != NULL); + + { + // cpu f16: + // cuda f16: nan + // cuda q8_0: pass + auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 16, 16, 16, 1); + ggml_set_f32(x, 0.01f); + // print_ggml_tensor(x); + + std::vector timesteps_vec(1, 999.f); + auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec); + + std::vector guidance_vec(1, 3.5f); + auto guidance = vector_to_ggml_tensor(work_ctx, guidance_vec); + + auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 4096, 256, 1); + ggml_set_f32(context, 0.01f); + // print_ggml_tensor(context); + + auto y = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 768, 1); + ggml_set_f32(y, 0.01f); + // print_ggml_tensor(y); + + struct ggml_tensor* out = NULL; + + int t0 = ggml_time_ms(); + compute(8, x, timesteps, context, y, guidance, &out, work_ctx); + int t1 = ggml_time_ms(); + + print_ggml_tensor(out); + LOG_DEBUG("flux test done in %dms", t1 - t0); + } + } + + static void load_from_file_and_test(const std::string& file_path) { + ggml_backend_t backend = ggml_backend_cuda_init(0); + // ggml_backend_t backend = ggml_backend_cpu_init(); + ggml_type model_data_type = GGML_TYPE_Q8_0; + std::shared_ptr flux = std::shared_ptr(new FluxRunner(backend, model_data_type)); + { + LOG_INFO("loading from '%s'", file_path.c_str()); + + flux->alloc_params_buffer(); + std::map tensors; + flux->get_param_tensors(tensors, "model.diffusion_model"); + + ModelLoader model_loader; + if (!model_loader.init_from_file(file_path, "model.diffusion_model.")) { + LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str()); + return; + } + + bool success = model_loader.load_tensors(tensors, backend); + + if (!success) { + LOG_ERROR("load tensors from model loader failed"); + return; + } + + LOG_INFO("flux model loaded"); + } + flux->test(); + } +}; + +} // namespace Flux + +#endif // __FLUX_HPP__ \ No newline at end of file diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 14ad37c..3ad9906 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -627,6 +627,20 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_3d_nx1x1(struct ggml_context* return x; // [N, OC, T, OH * OW] } +// qkv: [N, L, 3*C] +// return: ([N, L, C], [N, L, C], [N, L, C]) +__STATIC_INLINE__ std::vector split_qkv(struct ggml_context* ctx, + struct ggml_tensor* qkv) { + qkv = ggml_reshape_4d(ctx, qkv, qkv->ne[0] / 3, 3, qkv->ne[1], qkv->ne[2]); // [N, L, 3, C] + qkv = ggml_cont(ctx, ggml_permute(ctx, qkv, 0, 3, 1, 2)); // [3, N, L, C] + + int64_t offset = qkv->nb[2] * qkv->ne[2]; + auto q = ggml_view_3d(ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], qkv->nb[1], qkv->nb[2], offset * 
0); // [N, L, C] + auto k = ggml_view_3d(ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], qkv->nb[1], qkv->nb[2], offset * 1); // [N, L, C] + auto v = ggml_view_3d(ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], qkv->nb[1], qkv->nb[2], offset * 2); // [N, L, C] + return {q, k, v}; +} + // q: [N * n_head, n_token, d_head] // k: [N * n_head, n_k, d_head] // v: [N * n_head, d_head, n_k] @@ -653,9 +667,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx return kqv; } -// q: [N, L_q, C] -// k: [N, L_k, C] -// v: [N, L_k, C] +// q: [N, L_q, C] or [N*n_head, L_q, d_head] +// k: [N, L_k, C] or [N*n_head, L_k, d_head] +// v: [N, L_k, C] or [N, L_k, n_head, d_head] // return: [N, L_q, C] __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* ctx, struct ggml_tensor* q, @@ -663,38 +677,61 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* struct ggml_tensor* v, int64_t n_head, struct ggml_tensor* mask = NULL, - bool diag_mask_inf = false) { - int64_t L_q = q->ne[1]; - int64_t L_k = k->ne[1]; - int64_t C = q->ne[0]; - int64_t N = q->ne[2]; + bool diag_mask_inf = false, + bool skip_reshape = false) { + int64_t L_q; + int64_t L_k; + int64_t C ; + int64_t N ; + int64_t d_head; + if (!skip_reshape) { + L_q = q->ne[1]; + L_k = k->ne[1]; + C = q->ne[0]; + N = q->ne[2]; + d_head = C / n_head; + q = ggml_reshape_4d(ctx, q, d_head, n_head, L_q, N); // [N, L_q, n_head, d_head] + q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, L_q, d_head] + q = ggml_reshape_3d(ctx, q, d_head, L_q, n_head * N); // [N * n_head, L_q, d_head] + + k = ggml_reshape_4d(ctx, k, d_head, n_head, L_k, N); // [N, L_k, n_head, d_head] + k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, L_k, d_head] + k = ggml_reshape_3d(ctx, k, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head] + + v = ggml_reshape_4d(ctx, v, d_head, n_head, L_k, N); // [N, L_k, n_head, d_head] + } else { + L_q = q->ne[1]; + L_k = k->ne[1]; + d_head = v->ne[0]; + N = v->ne[3]; + C = d_head * n_head; + } - int64_t d_head = C / n_head; float scale = (1.0f / sqrt((float)d_head)); - q = ggml_reshape_4d(ctx, q, d_head, n_head, L_q, N); // [N, L_q, n_head, d_head] - q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, L_q, d_head] - q = ggml_reshape_3d(ctx, q, d_head, L_q, n_head * N); // [N * n_head, L_q, d_head] + bool use_flash_attn = false; + ggml_tensor* kqv = NULL; + if (use_flash_attn) { + v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3)); // [N, n_head, L_k, d_head] + v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head] + LOG_DEBUG("k->ne[1] == %d", k->ne[1]); + kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0); + } else { + v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, L_k] + v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N); // [N * n_head, d_head, L_k] - k = ggml_reshape_4d(ctx, k, d_head, n_head, L_k, N); // [N, L_k, n_head, d_head] - k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, L_k, d_head] - k = ggml_reshape_3d(ctx, k, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head] + auto kq = ggml_mul_mat(ctx, k, q); // [N * n_head, L_q, L_k] + kq = ggml_scale_inplace(ctx, kq, scale); + if (mask) { + kq = ggml_add(ctx, kq, mask); + } + if (diag_mask_inf) { + kq = ggml_diag_mask_inf_inplace(ctx, kq, 0); + } + kq = ggml_soft_max_inplace(ctx, kq); - v = ggml_reshape_4d(ctx, v, d_head, n_head, L_k, N); // [N, L_k, n_head, d_head] - v = 
ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, L_k] - v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N); // [N * n_head, d_head, L_k] - - auto kq = ggml_mul_mat(ctx, k, q); // [N * n_head, L_q, L_k] - kq = ggml_scale_inplace(ctx, kq, scale); - if (mask) { - kq = ggml_add(ctx, kq, mask); + kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, L_q, d_head] } - if (diag_mask_inf) { - kq = ggml_diag_mask_inf_inplace(ctx, kq, 0); - } - kq = ggml_soft_max_inplace(ctx, kq); - - auto kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, L_q, d_head] kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N); // [N, n_head, L_q, d_head] kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); // [N, L_q, n_head, d_head] @@ -846,7 +883,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_timestep_embedding( struct ggml_context* ctx, struct ggml_tensor* timesteps, int dim, - int max_period = 10000) { + int max_period = 10000, + float time_factor = 1.0f) { + timesteps = ggml_scale(ctx, timesteps, time_factor); return ggml_timestep_embedding(ctx, timesteps, dim, max_period); } diff --git a/mmdit.hpp b/mmdit.hpp index 7d7b22d..0a4d831 100644 --- a/mmdit.hpp +++ b/mmdit.hpp @@ -142,20 +142,6 @@ public: } }; -__STATIC_INLINE__ std::vector split_qkv(struct ggml_context* ctx, - struct ggml_tensor* qkv) { - // qkv: [N, L, 3*C] - // return: ([N, L, C], [N, L, C], [N, L, C]) - qkv = ggml_reshape_4d(ctx, qkv, qkv->ne[0] / 3, 3, qkv->ne[1], qkv->ne[2]); // [N, L, 3, C] - qkv = ggml_cont(ctx, ggml_permute(ctx, qkv, 0, 3, 1, 2)); // [3, N, L, C] - - int64_t offset = qkv->nb[2] * qkv->ne[2]; - auto q = ggml_view_3d(ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], qkv->nb[1], qkv->nb[2], offset * 0); // [N, L, C] - auto k = ggml_view_3d(ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], qkv->nb[1], qkv->nb[2], offset * 1); // [N, L, C] - auto v = ggml_view_3d(ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], qkv->nb[1], qkv->nb[2], offset * 2); // [N, L, C] - return {q, k, v}; -} - class SelfAttention : public GGMLBlock { public: int64_t num_heads; @@ -469,7 +455,7 @@ public: struct MMDiT : public GGMLBlock { // Diffusion model with a Transformer backbone. 
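// --- Illustrative aside (not part of the patch) ---------------------------
// The reworked ggml_nn_attention_ext above now owns the per-head bookkeeping:
// unless skip_reshape is set, it views q/k/v of shape [N, L, C] as
// [N * n_head, L, d_head] with d_head = C / n_head, scales the scores by
// 1 / sqrt(d_head), and can optionally route through ggml_flash_attn_ext.
// The standalone sketch below reproduces that per-head layout and scaling on
// plain std::vector buffers (batch size 1, no mask), purely as a reference for
// the indexing; the function name is made up and nothing here uses ggml.
#include <algorithm>
#include <cmath>
#include <vector>

// q: [L_q, C], k/v: [L_k, C], row-major, with C = n_head * d_head.
// returns [L_q, C]: per head, softmax(q k^T / sqrt(d_head)) v, heads concatenated.
static std::vector<float> naive_mha(const std::vector<float>& q,
                                    const std::vector<float>& k,
                                    const std::vector<float>& v,
                                    int L_q, int L_k, int C, int n_head) {
    const int d_head  = C / n_head;
    const float scale = 1.0f / std::sqrt((float)d_head);
    std::vector<float> out(L_q * C, 0.0f);
    for (int h = 0; h < n_head; h++) {
        for (int i = 0; i < L_q; i++) {
            // scaled dot-product scores of query i against every key, head h
            std::vector<float> score(L_k);
            float max_s = -1e30f;
            for (int j = 0; j < L_k; j++) {
                float s = 0.0f;
                for (int d = 0; d < d_head; d++) {
                    s += q[i * C + h * d_head + d] * k[j * C + h * d_head + d];
                }
                score[j] = s * scale;
                max_s    = std::max(max_s, score[j]);
            }
            // softmax over the key dimension
            float sum = 0.0f;
            for (int j = 0; j < L_k; j++) {
                score[j] = std::exp(score[j] - max_s);
                sum += score[j];
            }
            // attention-weighted sum of values, written back into head h's slice
            for (int d = 0; d < d_head; d++) {
                float acc = 0.0f;
                for (int j = 0; j < L_k; j++) {
                    acc += (score[j] / sum) * v[j * C + h * d_head + d];
                }
                out[i * C + h * d_head + d] = acc;
            }
        }
    }
    return out;
}
// --------------------------------------------------------------------------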
protected: - SDVersion version = VERSION_3_2B; + SDVersion version = VERSION_SD3_2B; int64_t input_size = -1; int64_t patch_size = 2; int64_t in_channels = 16; @@ -487,7 +473,7 @@ protected: } public: - MMDiT(SDVersion version = VERSION_3_2B) + MMDiT(SDVersion version = VERSION_SD3_2B) : version(version) { // input_size is always None // learn_sigma is always False @@ -501,7 +487,7 @@ public: // pos_embed_scaling_factor is not used // pos_embed_offset is not used // context_embedder_config is always {'target': 'torch.nn.Linear', 'params': {'in_features': 4096, 'out_features': 1536}} - if (version == VERSION_3_2B) { + if (version == VERSION_SD3_2B) { input_size = -1; patch_size = 2; in_channels = 16; @@ -669,7 +655,7 @@ struct MMDiTRunner : public GGMLRunner { MMDiTRunner(ggml_backend_t backend, ggml_type wtype, - SDVersion version = VERSION_3_2B) + SDVersion version = VERSION_SD3_2B) : GGMLRunner(backend, wtype), mmdit(version) { mmdit.init(params_ctx, wtype); } diff --git a/model.cpp b/model.cpp index 7ab2287..c372b91 100644 --- a/model.cpp +++ b/model.cpp @@ -1291,15 +1291,22 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s SDVersion ModelLoader::get_sd_version() { TensorStorage token_embedding_weight; + bool is_flux = false; for (auto& tensor_storage : tensor_storages) { + if (tensor_storage.name.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) { + return VERSION_FLUX_DEV; + } + if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) { + is_flux = true; + } if (tensor_storage.name.find("model.diffusion_model.joint_blocks.23.") != std::string::npos) { - return VERSION_3_2B; + return VERSION_SD3_2B; } if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos) { - return VERSION_XL; + return VERSION_SDXL; } if (tensor_storage.name.find("cond_stage_model.1") != std::string::npos) { - return VERSION_XL; + return VERSION_SDXL; } if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) { return VERSION_SVD; @@ -1315,10 +1322,13 @@ SDVersion ModelLoader::get_sd_version() { // break; } } + if (is_flux) { + return VERSION_FLUX_SCHNELL; + } if (token_embedding_weight.ne[0] == 768) { - return VERSION_1_x; + return VERSION_SD1; } else if (token_embedding_weight.ne[0] == 1024) { - return VERSION_2_x; + return VERSION_SD2; } return VERSION_COUNT; } @@ -1330,8 +1340,68 @@ ggml_type ModelLoader::get_sd_wtype() { } if (tensor_storage.name.find(".weight") != std::string::npos && - (tensor_storage.name.find("time_embed") != std::string::npos) || - tensor_storage.name.find("context_embedder") != std::string::npos) { + (tensor_storage.name.find("time_embed") != std::string::npos || + tensor_storage.name.find("context_embedder") != std::string::npos || + tensor_storage.name.find("time_in") != std::string::npos)) { + return tensor_storage.type; + } + } + return GGML_TYPE_COUNT; +} + +ggml_type ModelLoader::get_conditioner_wtype() { + for (auto& tensor_storage : tensor_storages) { + if (is_unused_tensor(tensor_storage.name)) { + continue; + } + + if ((tensor_storage.name.find("text_encoders") == std::string::npos && + tensor_storage.name.find("cond_stage_model") == std::string::npos && + tensor_storage.name.find("te.text_model.") == std::string::npos && + tensor_storage.name.find("conditioner") == std::string::npos)) { + continue; + } + + if (tensor_storage.name.find(".weight") != std::string::npos) { + return tensor_storage.type; + } + 
} + return GGML_TYPE_COUNT; +} + + +ggml_type ModelLoader::get_diffusion_model_wtype() { + for (auto& tensor_storage : tensor_storages) { + if (is_unused_tensor(tensor_storage.name)) { + continue; + } + + if (tensor_storage.name.find("model.diffusion_model.") == std::string::npos) { + continue; + } + + if (tensor_storage.name.find(".weight") != std::string::npos && + (tensor_storage.name.find("time_embed") != std::string::npos || + tensor_storage.name.find("context_embedder") != std::string::npos || + tensor_storage.name.find("time_in") != std::string::npos)) { + return tensor_storage.type; + } + } + return GGML_TYPE_COUNT; +} + +ggml_type ModelLoader::get_vae_wtype() { + for (auto& tensor_storage : tensor_storages) { + if (is_unused_tensor(tensor_storage.name)) { + continue; + } + + if (tensor_storage.name.find("vae.") == std::string::npos && + tensor_storage.name.find("first_stage_model") == std::string::npos) { + continue; + } + + if (tensor_storage.name.find(".weight") != std::string::npos) { + return tensor_storage.type; + } } diff --git a/model.h b/model.h index 5bfce30..2f08669 100644 --- a/model.h +++ b/model.h @@ -18,11 +18,13 @@ #define SD_MAX_DIMS 5 enum SDVersion { - VERSION_1_x, - VERSION_2_x, - VERSION_XL, + VERSION_SD1, + VERSION_SD2, + VERSION_SDXL, VERSION_SVD, - VERSION_3_2B, + VERSION_SD3_2B, + VERSION_FLUX_DEV, + VERSION_FLUX_SCHNELL, VERSION_COUNT, }; @@ -144,6 +146,9 @@ public: bool init_from_file(const std::string& file_path, const std::string& prefix = ""); SDVersion get_sd_version(); ggml_type get_sd_wtype(); + ggml_type get_conditioner_wtype(); + ggml_type get_diffusion_model_wtype(); + ggml_type get_vae_wtype(); bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend_t backend); bool load_tensors(std::map& tensors, ggml_backend_t backend, diff --git a/pmid.hpp b/pmid.hpp index d1d8c31..381050f 100644 --- a/pmid.hpp +++ b/pmid.hpp @@ -161,7 +161,7 @@ struct PhotoMakerIDEncoderBlock : public CLIPVisionModelProjection { struct PhotoMakerIDEncoder : public GGMLRunner { public: - SDVersion version = VERSION_XL; + SDVersion version = VERSION_SDXL; PhotoMakerIDEncoderBlock id_encoder; float style_strength; @@ -175,7 +175,7 @@ public: std::vector zeros_right; public: - PhotoMakerIDEncoder(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_XL, float sty = 20.f) + PhotoMakerIDEncoder(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_SDXL, float sty = 20.f) : GGMLRunner(backend, wtype), version(version), style_strength(sty) { diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 34bf8f5..619da29 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -25,11 +25,13 @@ // #include "stb_image_write.h" const char* model_version_to_str[] = { - "1.x", - "2.x", - "XL", + "SD 1.x", + "SD 2.x", + "SDXL", "SVD", - "3 2B"}; + "SD3 2B", + "Flux Dev", + "Flux Schnell"}; const char* sampling_methods_str[] = { "Euler A", @@ -67,7 +69,11 @@ public: ggml_backend_t clip_backend = NULL; ggml_backend_t control_net_backend = NULL; ggml_backend_t vae_backend = NULL; - ggml_type model_data_type = GGML_TYPE_COUNT; + ggml_type model_wtype = GGML_TYPE_COUNT; + ggml_type conditioner_wtype = GGML_TYPE_COUNT; + ggml_type diffusion_model_wtype = GGML_TYPE_COUNT; + ggml_type vae_wtype = GGML_TYPE_COUNT; + SDVersion version; bool vae_decode_only = false; @@ -131,6 +137,9 @@ public: } bool load_from_file(const std::string& model_path, + const std::string& clip_l_path, + const std::string& t5xxl_path, + const std::string& diffusion_model_path, const std::string&
vae_path, const std::string control_net_path, const std::string embeddings_path, @@ -164,14 +173,36 @@ public: LOG_INFO("Flash Attention enabled"); #endif #endif - LOG_INFO("loading model from '%s'", model_path.c_str()); ModelLoader model_loader; vae_tiling = vae_tiling_; - if (!model_loader.init_from_file(model_path)) { - LOG_ERROR("init model loader from file failed: '%s'", model_path.c_str()); - return false; + if (model_path.size() > 0) { + LOG_INFO("loading model from '%s'", model_path.c_str()); + if (!model_loader.init_from_file(model_path)) { + LOG_ERROR("init model loader from file failed: '%s'", model_path.c_str()); + } + } + + if (clip_l_path.size() > 0) { + LOG_INFO("loading clip_l from '%s'", clip_l_path.c_str()); + if (!model_loader.init_from_file(clip_l_path, "text_encoders.clip_l.")) { + LOG_WARN("loading clip_l from '%s' failed", clip_l_path.c_str()); + } + } + + if (t5xxl_path.size() > 0) { + LOG_INFO("loading t5xxl from '%s'", t5xxl_path.c_str()); + if (!model_loader.init_from_file(t5xxl_path, "text_encoders.t5xxl.")) { + LOG_WARN("loading t5xxl from '%s' failed", t5xxl_path.c_str()); + } + } + + if (diffusion_model_path.size() > 0) { + LOG_INFO("loading diffusion model from '%s'", diffusion_model_path.c_str()); + if (!model_loader.init_from_file(diffusion_model_path, "model.diffusion_model.")) { + LOG_WARN("loading diffusion model from '%s' failed", diffusion_model_path.c_str()); + } } if (vae_path.size() > 0) { @@ -187,16 +218,45 @@ public: return false; } - LOG_INFO("Stable Diffusion %s ", model_version_to_str[version]); + LOG_INFO("Version: %s ", model_version_to_str[version]); if (wtype == GGML_TYPE_COUNT) { - model_data_type = model_loader.get_sd_wtype(); + model_wtype = model_loader.get_sd_wtype(); + if (model_wtype == GGML_TYPE_COUNT) { + model_wtype = GGML_TYPE_F32; + LOG_WARN("can not get model wtype from weights, use f32"); + } + conditioner_wtype = model_loader.get_conditioner_wtype(); + if (conditioner_wtype == GGML_TYPE_COUNT) { + conditioner_wtype = wtype; + } + diffusion_model_wtype = model_loader.get_diffusion_model_wtype(); + if (diffusion_model_wtype == GGML_TYPE_COUNT) { + diffusion_model_wtype = wtype; + } + vae_wtype = model_loader.get_vae_wtype(); + + if (vae_wtype == GGML_TYPE_COUNT) { + vae_wtype = wtype; + } } else { - model_data_type = wtype; + model_wtype = wtype; + conditioner_wtype = wtype; + diffusion_model_wtype = wtype; + vae_wtype = wtype; } - LOG_INFO("Stable Diffusion weight type: %s", ggml_type_name(model_data_type)); + + if (version == VERSION_SDXL) { + vae_wtype = GGML_TYPE_F32; + } + + LOG_INFO("Weight type: %s", ggml_type_name(model_wtype)); + LOG_INFO("Conditioner weight type: %s", ggml_type_name(conditioner_wtype)); + LOG_INFO("Diffusion model weight type: %s", ggml_type_name(diffusion_model_wtype)); + LOG_INFO("VAE weight type: %s", ggml_type_name(vae_wtype)); + LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor)); - if (version == VERSION_XL) { + if (version == VERSION_SDXL) { scale_factor = 0.13025f; if (vae_path.size() == 0 && taesd_path.size() == 0) { LOG_WARN( @@ -205,26 +265,33 @@ public: "try specifying SDXL VAE FP16 Fix with the --vae parameter. 
" "You can find it here: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors"); } - } else if (version == VERSION_3_2B) { + } else if (version == VERSION_SD3_2B) { scale_factor = 1.5305f; + } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + scale_factor = 0.3611; + // TODO: shift_factor } if (version == VERSION_SVD) { - clip_vision = std::make_shared(backend, model_data_type); + clip_vision = std::make_shared(backend, conditioner_wtype); clip_vision->alloc_params_buffer(); clip_vision->get_param_tensors(tensors); - diffusion_model = std::make_shared(backend, model_data_type, version); + diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); diffusion_model->alloc_params_buffer(); diffusion_model->get_param_tensors(tensors); - first_stage_model = std::make_shared(backend, model_data_type, vae_decode_only, true, version); + first_stage_model = std::make_shared(backend, vae_wtype, vae_decode_only, true, version); LOG_DEBUG("vae_decode_only %d", vae_decode_only); first_stage_model->alloc_params_buffer(); first_stage_model->get_param_tensors(tensors, "first_stage_model"); } else { clip_backend = backend; - if (!ggml_backend_is_cpu(backend) && version == VERSION_3_2B && model_data_type != GGML_TYPE_F32) { + bool use_t5xxl = false; + if (version == VERSION_SD3_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + use_t5xxl = true; + } + if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) { clip_on_cpu = true; LOG_INFO("set clip_on_cpu to true"); } @@ -232,12 +299,15 @@ public: LOG_INFO("CLIP: Using CPU backend"); clip_backend = ggml_backend_cpu_init(); } - if (version == VERSION_3_2B) { - cond_stage_model = std::make_shared(clip_backend, model_data_type); - diffusion_model = std::make_shared(backend, model_data_type, version); + if (version == VERSION_SD3_2B) { + cond_stage_model = std::make_shared(clip_backend, conditioner_wtype); + diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); + } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + cond_stage_model = std::make_shared(clip_backend, conditioner_wtype); + diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); } else { - cond_stage_model = std::make_shared(clip_backend, model_data_type, embeddings_path, version); - diffusion_model = std::make_shared(backend, model_data_type, version); + cond_stage_model = std::make_shared(clip_backend, conditioner_wtype, embeddings_path, version); + diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); } cond_stage_model->alloc_params_buffer(); cond_stage_model->get_param_tensors(tensors); @@ -245,11 +315,6 @@ public: diffusion_model->alloc_params_buffer(); diffusion_model->get_param_tensors(tensors); - ggml_type vae_type = model_data_type; - if (version == VERSION_XL) { - vae_type = GGML_TYPE_F32; // avoid nan, not work... 
- } - if (!use_tiny_autoencoder) { if (vae_on_cpu && !ggml_backend_is_cpu(backend)) { LOG_INFO("VAE Autoencoder: Using CPU backend"); @@ -257,11 +322,11 @@ public: } else { vae_backend = backend; } - first_stage_model = std::make_shared(vae_backend, vae_type, vae_decode_only, false, version); + first_stage_model = std::make_shared(vae_backend, vae_wtype, vae_decode_only, false, version); first_stage_model->alloc_params_buffer(); first_stage_model->get_param_tensors(tensors, "first_stage_model"); } else { - tae_first_stage = std::make_shared(backend, model_data_type, vae_decode_only); + tae_first_stage = std::make_shared(backend, vae_wtype, vae_decode_only); } // first_stage_model->get_param_tensors(tensors, "first_stage_model."); @@ -273,12 +338,12 @@ public: } else { controlnet_backend = backend; } - control_net = std::make_shared(controlnet_backend, model_data_type, version); + control_net = std::make_shared(controlnet_backend, diffusion_model_wtype, version); } - pmid_model = std::make_shared(clip_backend, model_data_type, version); + pmid_model = std::make_shared(clip_backend, model_wtype, version); if (id_embeddings_path.size() > 0) { - pmid_lora = std::make_shared(backend, model_data_type, id_embeddings_path, ""); + pmid_lora = std::make_shared(backend, model_wtype, id_embeddings_path, ""); if (!pmid_lora->load_from_file(true)) { LOG_WARN("load photomaker lora tensors from %s failed", id_embeddings_path.c_str()); return false; @@ -423,7 +488,7 @@ public: // check is_using_v_parameterization_for_sd2 bool is_using_v_parameterization = false; - if (version == VERSION_2_x) { + if (version == VERSION_SD2) { if (is_using_v_parameterization_for_sd2(ctx)) { is_using_v_parameterization = true; } @@ -432,9 +497,16 @@ public: is_using_v_parameterization = true; } - if (version == VERSION_3_2B) { + if (version == VERSION_SD3_2B) { LOG_INFO("running in FLOW mode"); denoiser = std::make_shared(); + } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + LOG_INFO("running in Flux FLOW mode"); + float shift = 1.15f; + if (version == VERSION_FLUX_SCHNELL) { + shift = 1.0f; // TODO: validate + } + denoiser = std::make_shared(shift); } else if (is_using_v_parameterization) { LOG_INFO("running in v-prediction mode"); denoiser = std::make_shared(); @@ -489,7 +561,7 @@ public: ggml_set_f32(timesteps, 999); int64_t t0 = ggml_time_ms(); struct ggml_tensor* out = ggml_dup_tensor(work_ctx, x_t); - diffusion_model->compute(n_threads, x_t, timesteps, c, NULL, NULL, -1, {}, 0.f, &out); + diffusion_model->compute(n_threads, x_t, timesteps, c, NULL, NULL, NULL, -1, {}, 0.f, &out); diffusion_model->free_compute_buffer(); double result = 0.f; @@ -522,7 +594,7 @@ public: LOG_WARN("can not find %s or %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str()); return; } - LoraModel lora(backend, model_data_type, file_path); + LoraModel lora(backend, model_wtype, file_path); if (!lora.load_from_file()) { LOG_WARN("load lora tensors from %s failed", file_path.c_str()); return; @@ -538,7 +610,7 @@ public: } void apply_loras(const std::unordered_map& lora_state) { - if (lora_state.size() > 0 && model_data_type != GGML_TYPE_F16 && model_data_type != GGML_TYPE_F32) { + if (lora_state.size() > 0 && model_wtype != GGML_TYPE_F16 && model_wtype != GGML_TYPE_F32) { LOG_WARN("In quantized models when applying LoRA, the images have poor quality."); } std::unordered_map lora_state_diff; @@ -663,6 +735,7 @@ public: float control_strength, float min_cfg, float cfg_scale, + float guidance, 
sample_method_t method, const std::vector& sigmas, int start_merge_step, @@ -701,6 +774,8 @@ public: float t = denoiser->sigma_to_t(sigma); std::vector timesteps_vec(x->ne[3], t); // [N, ] auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec); + std::vector guidance_vec(x->ne[3], guidance); + auto guidance_tensor = vector_to_ggml_tensor(work_ctx, guidance_vec); copy_ggml_tensor(noised_input, input); // noised_input = noised_input * c_in @@ -723,6 +798,7 @@ public: cond.c_crossattn, cond.c_concat, cond.c_vector, + guidance_tensor, -1, controls, control_strength, @@ -734,6 +810,7 @@ public: id_cond.c_crossattn, cond.c_concat, id_cond.c_vector, + guidance_tensor, -1, controls, control_strength, @@ -753,6 +830,7 @@ public: uncond.c_crossattn, uncond.c_concat, uncond.c_vector, + guidance_tensor, -1, controls, control_strength, @@ -838,7 +916,9 @@ public: if (use_tiny_autoencoder) { C = 4; } else { - if (version == VERSION_3_2B) { + if (version == VERSION_SD3_2B) { + C = 32; + } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { C = 32; } } @@ -904,6 +984,9 @@ struct sd_ctx_t { }; sd_ctx_t* new_sd_ctx(const char* model_path_c_str, + const char* clip_l_path_c_str, + const char* t5xxl_path_c_str, + const char* diffusion_model_path_c_str, const char* vae_path_c_str, const char* taesd_path_c_str, const char* control_net_path_c_str, @@ -925,6 +1008,9 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str, return NULL; } std::string model_path(model_path_c_str); + std::string clip_l_path(clip_l_path_c_str); + std::string t5xxl_path(t5xxl_path_c_str); + std::string diffusion_model_path(diffusion_model_path_c_str); std::string vae_path(vae_path_c_str); std::string taesd_path(taesd_path_c_str); std::string control_net_path(control_net_path_c_str); @@ -942,6 +1028,9 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str, } if (!sd_ctx->sd->load_from_file(model_path, + clip_l_path, + t5xxl_path_c_str, + diffusion_model_path, vae_path, control_net_path, embd_path, @@ -976,6 +1065,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, std::string negative_prompt, int clip_skip, float cfg_scale, + float guidance, int width, int height, enum sample_method_t sample_method, @@ -1127,7 +1217,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, SDCondition uncond; if (cfg_scale != 1.0) { bool force_zero_embeddings = false; - if (sd_ctx->sd->version == VERSION_XL && negative_prompt.size() == 0) { + if (sd_ctx->sd->version == VERSION_SDXL && negative_prompt.size() == 0) { force_zero_embeddings = true; } uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx, @@ -1156,7 +1246,9 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, // Sample std::vector final_latents; // collect latents to decode int C = 4; - if (sd_ctx->sd->version == VERSION_3_2B) { + if (sd_ctx->sd->version == VERSION_SD3_2B) { + C = 16; + } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { C = 16; } int W = width / 8; @@ -1189,6 +1281,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, control_strength, cfg_scale, cfg_scale, + guidance, sample_method, sigmas, start_merge_step, @@ -1247,6 +1340,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, const char* negative_prompt_c_str, int clip_skip, float cfg_scale, + float guidance, int width, int height, enum sample_method_t sample_method, @@ -1265,9 +1359,12 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, struct ggml_init_params params; params.mem_size = static_cast(10 * 1024 * 1024); // 10 MB - if (sd_ctx->sd->version == VERSION_3_2B) { + 
if (sd_ctx->sd->version == VERSION_SD3_2B) { params.mem_size *= 3; } + if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { + params.mem_size *= 4; + } if (sd_ctx->sd->stacked_id) { params.mem_size += static_cast(10 * 1024 * 1024); // 10 MB } @@ -1288,14 +1385,18 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, std::vector sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps); int C = 4; - if (sd_ctx->sd->version == VERSION_3_2B) { + if (sd_ctx->sd->version == VERSION_SD3_2B) { + C = 16; + } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { C = 16; } int W = width / 8; int H = height / 8; ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1); - if (sd_ctx->sd->version == VERSION_3_2B) { + if (sd_ctx->sd->version == VERSION_SD3_2B) { ggml_set_f32(init_latent, 0.0609f); + } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { + ggml_set_f32(init_latent, 0.1159f); } else { ggml_set_f32(init_latent, 0.f); } @@ -1307,6 +1408,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, negative_prompt_c_str, clip_skip, cfg_scale, + guidance, width, height, sample_method, @@ -1332,6 +1434,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, const char* negative_prompt_c_str, int clip_skip, float cfg_scale, + float guidance, int width, int height, sample_method_t sample_method, @@ -1351,9 +1454,12 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, struct ggml_init_params params; params.mem_size = static_cast(10 * 1024 * 1024); // 10 MB - if (sd_ctx->sd->version == VERSION_3_2B) { + if (sd_ctx->sd->version == VERSION_SD3_2B) { params.mem_size *= 2; } + if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { + params.mem_size *= 3; + } if (sd_ctx->sd->stacked_id) { params.mem_size += static_cast(10 * 1024 * 1024); // 10 MB } @@ -1403,6 +1509,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, negative_prompt_c_str, clip_skip, cfg_scale, + guidance, width, height, sample_method, @@ -1510,6 +1617,7 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx, 0.f, min_cfg, cfg_scale, + 0.f, sample_method, sigmas, -1, diff --git a/stable-diffusion.h b/stable-diffusion.h index f78748f..0225b34 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -119,6 +119,9 @@ typedef struct { typedef struct sd_ctx_t sd_ctx_t; SD_API sd_ctx_t* new_sd_ctx(const char* model_path, + const char* clip_l_path, + const char* t5xxl_path, + const char* diffusion_model_path, const char* vae_path, const char* taesd_path, const char* control_net_path_c_str, @@ -143,6 +146,7 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx, const char* negative_prompt, int clip_skip, float cfg_scale, + float guidance, int width, int height, enum sample_method_t sample_method, @@ -161,6 +165,7 @@ SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx, const char* negative_prompt, int clip_skip, float cfg_scale, + float guidance, int width, int height, enum sample_method_t sample_method, diff --git a/unet.hpp b/unet.hpp index 737a2bb..94a8ba4 100644 --- a/unet.hpp +++ b/unet.hpp @@ -166,7 +166,7 @@ public: // ldm.modules.diffusionmodules.openaimodel.UNetModel class UnetModelBlock : public GGMLBlock { protected: - SDVersion version = VERSION_1_x; + SDVersion version = VERSION_SD1; // network hparams int in_channels = 4; int out_channels = 4; @@ -177,19 +177,19 @@ protected: int time_embed_dim = 1280; // model_channels*4 int num_heads = 8; int num_head_channels = -1; // channels // num_heads - int context_dim = 768; // 1024 
for VERSION_2_x, 2048 for VERSION_XL + int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL public: int model_channels = 320; - int adm_in_channels = 2816; // only for VERSION_XL/SVD + int adm_in_channels = 2816; // only for VERSION_SDXL/SVD - UnetModelBlock(SDVersion version = VERSION_1_x) + UnetModelBlock(SDVersion version = VERSION_SD1) : version(version) { - if (version == VERSION_2_x) { + if (version == VERSION_SD2) { context_dim = 1024; num_head_channels = 64; num_heads = -1; - } else if (version == VERSION_XL) { + } else if (version == VERSION_SDXL) { context_dim = 2048; attention_resolutions = {4, 2}; channel_mult = {1, 2, 4}; @@ -211,7 +211,7 @@ public: // time_embed_1 is nn.SiLU() blocks["time_embed.2"] = std::shared_ptr(new Linear(time_embed_dim, time_embed_dim)); - if (version == VERSION_XL || version == VERSION_SVD) { + if (version == VERSION_SDXL || version == VERSION_SVD) { blocks["label_emb.0.0"] = std::shared_ptr(new Linear(adm_in_channels, time_embed_dim)); // label_emb_1 is nn.SiLU() blocks["label_emb.0.2"] = std::shared_ptr(new Linear(time_embed_dim, time_embed_dim)); @@ -533,7 +533,7 @@ struct UNetModelRunner : public GGMLRunner { UNetModelRunner(ggml_backend_t backend, ggml_type wtype, - SDVersion version = VERSION_1_x) + SDVersion version = VERSION_SD1) : GGMLRunner(backend, wtype), unet(version) { unet.init(params_ctx, wtype); } diff --git a/vae.hpp b/vae.hpp index cb8112d..85319fd 100644 --- a/vae.hpp +++ b/vae.hpp @@ -455,9 +455,9 @@ protected: public: AutoencodingEngine(bool decode_only = true, bool use_video_decoder = false, - SDVersion version = VERSION_1_x) + SDVersion version = VERSION_SD1) : decode_only(decode_only), use_video_decoder(use_video_decoder) { - if (version == VERSION_3_2B) { + if (version == VERSION_SD3_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { dd_config.z_channels = 16; use_quant = false; } @@ -527,7 +527,7 @@ struct AutoEncoderKL : public GGMLRunner { ggml_type wtype, bool decode_only = false, bool use_video_decoder = false, - SDVersion version = VERSION_1_x) + SDVersion version = VERSION_SD1) : decode_only(decode_only), ae(decode_only, use_video_decoder, version), GGMLRunner(backend, wtype) { ae.init(params_ctx, wtype); }
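// --- Illustrative aside (not part of the patch) ---------------------------
// The unet.hpp and vae.hpp hunks above tie the renamed enum values to a few
// architecture hyperparameters: SD2 widens the cross-attention context to
// 1024 (with num_head_channels = 64, num_heads = -1), SDXL uses a 2048-wide
// context and the 2816-channel label embedding, and SD3/Flux raise the
// autoencoder's z_channels to 16. The sketch below just gathers those numbers;
// it assumes the SDVersion enum from model.h, the struct/helper names are made
// up, and the baseline z_channels of 4 is the conventional SD VAE width rather
// than something this patch sets. context_dim/adm_in_channels only matter for
// the UNet-based versions; SD3 and Flux use the MMDiT/Flux backbones instead.
struct VersionHParams {
    int context_dim;      // UnetModelBlock cross-attention context width
    int adm_in_channels;  // label embedding input, used by SDXL/SVD only
    int vae_z_channels;   // AutoencodingEngine dd_config.z_channels
};

static VersionHParams hparams_for(SDVersion version) {
    VersionHParams p{768, 2816, 4};      // SD1 baseline
    if (version == VERSION_SD2) {
        p.context_dim = 1024;
    } else if (version == VERSION_SDXL) {
        p.context_dim = 2048;
    } else if (version == VERSION_SD3_2B ||
               version == VERSION_FLUX_DEV ||
               version == VERSION_FLUX_SCHNELL) {
        p.vae_z_channels = 16;
    }
    return p;
}
// --------------------------------------------------------------------------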