From 347710f68f6c6c8e243496957f056a4b9f271d24 Mon Sep 17 00:00:00 2001 From: leejet Date: Thu, 13 Nov 2025 21:48:44 +0800 Subject: [PATCH] feat: support applying LoRA at runtime (#969) --- clip.hpp | 2 +- common.hpp | 26 +-- conditioner.hpp | 41 ++++ control.hpp | 2 +- diffusion_model.hpp | 25 ++- docs/lora.md | 41 +--- esrgan.hpp | 2 +- examples/cli/README.md | 6 + examples/cli/main.cpp | 27 ++- flux.hpp | 2 +- ggml_extend.hpp | 135 +++++++++++- lora.hpp | 489 +++++++++++++++++++++++++++++++---------- mmdit.hpp | 2 +- name_conversion.cpp | 56 +++-- name_conversion.h | 4 + qwen_image.hpp | 2 +- qwenvl.hpp | 2 +- stable-diffusion.cpp | 244 ++++++++++++++++++-- stable-diffusion.h | 10 + unet.hpp | 4 +- wan.hpp | 6 +- 21 files changed, 901 insertions(+), 227 deletions(-) diff --git a/clip.hpp b/clip.hpp index eb37638..e2a892c 100644 --- a/clip.hpp +++ b/clip.hpp @@ -936,7 +936,7 @@ struct CLIPTextModelRunner : public GGMLRunner { size_t max_token_idx = 0, bool return_pooled = false, int clip_skip = -1) { - struct ggml_cgraph* gf = ggml_new_graph(compute_ctx); + struct ggml_cgraph* gf = new_graph_custom(2048); input_ids = to_backend(input_ids); diff --git a/common.hpp b/common.hpp index c68ddaf..dd8281f 100644 --- a/common.hpp +++ b/common.hpp @@ -182,31 +182,21 @@ protected: int64_t dim_in; int64_t dim_out; - void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override { - enum ggml_type wtype = get_type(prefix + "proj.weight", tensor_storage_map, GGML_TYPE_F32); - enum ggml_type bias_wtype = GGML_TYPE_F32; - params["proj.weight"] = ggml_new_tensor_2d(ctx, wtype, dim_in, dim_out * 2); - params["proj.bias"] = ggml_new_tensor_1d(ctx, bias_wtype, dim_out * 2); - } - public: GEGLU(int64_t dim_in, int64_t dim_out) - : dim_in(dim_in), dim_out(dim_out) {} + : dim_in(dim_in), dim_out(dim_out) { + blocks["proj"] = std::shared_ptr(new Linear(dim_in, dim_out * 2)); + } struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { // x: [ne3, ne2, ne1, dim_in] // return: [ne3, ne2, ne1, dim_out] - struct ggml_tensor* w = params["proj.weight"]; - struct ggml_tensor* b = params["proj.bias"]; + auto proj = std::dynamic_pointer_cast(blocks["proj"]); - auto x_w = ggml_view_2d(ctx->ggml_ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], 0); // [dim_out, dim_in] - auto x_b = ggml_view_1d(ctx->ggml_ctx, b, b->ne[0] / 2, 0); // [dim_out, dim_in] - auto gate_w = ggml_view_2d(ctx->ggml_ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], w->nb[1] * w->ne[1] / 2); // [dim_out, ] - auto gate_b = ggml_view_1d(ctx->ggml_ctx, b, b->ne[0] / 2, b->nb[0] * b->ne[0] / 2); // [dim_out, ] - - auto x_in = x; - x = ggml_ext_linear(ctx->ggml_ctx, x_in, x_w, x_b); // [ne3, ne2, ne1, dim_out] - auto gate = ggml_ext_linear(ctx->ggml_ctx, x_in, gate_w, gate_b); // [ne3, ne2, ne1, dim_out] + x = proj->forward(ctx, x); // [ne3, ne2, ne1, dim_out*2] + auto x_vec = ggml_ext_chunk(ctx->ggml_ctx, x, 2, 0); + x = x_vec[0]; // [ne3, ne2, ne1, dim_out] + auto gate = x_vec[1]; // [ne3, ne2, ne1, dim_out] gate = ggml_gelu_inplace(ctx->ggml_ctx, gate); diff --git a/conditioner.hpp b/conditioner.hpp index 93e0c28..27d367a 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -34,6 +34,7 @@ struct Conditioner { virtual void free_params_buffer() = 0; virtual void get_param_tensors(std::map& tensors) = 0; virtual size_t get_params_buffer_size() = 0; + virtual void set_weight_adapter(const std::shared_ptr& adapter) {} virtual std::tuple> get_learned_condition_with_trigger(ggml_context* work_ctx, int n_threads, const ConditionerParams& conditioner_params) { @@ -108,6 +109,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { return buffer_size; } + void set_weight_adapter(const std::shared_ptr& adapter) override { + text_model->set_weight_adapter(adapter); + if (sd_version_is_sdxl(version)) { + text_model2->set_weight_adapter(adapter); + } + } + bool load_embedding(std::string embd_name, std::string embd_path, std::vector& bpe_tokens) { // the order matters ModelLoader model_loader; @@ -764,6 +772,18 @@ struct SD3CLIPEmbedder : public Conditioner { return buffer_size; } + void set_weight_adapter(const std::shared_ptr& adapter) override { + if (clip_l) { + clip_l->set_weight_adapter(adapter); + } + if (clip_g) { + clip_g->set_weight_adapter(adapter); + } + if (t5) { + t5->set_weight_adapter(adapter); + } + } + std::vector, std::vector>> tokenize(std::string text, size_t max_length = 0, bool padding = false) { @@ -1160,6 +1180,15 @@ struct FluxCLIPEmbedder : public Conditioner { return buffer_size; } + void set_weight_adapter(const std::shared_ptr& adapter) { + if (clip_l) { + clip_l->set_weight_adapter(adapter); + } + if (t5) { + t5->set_weight_adapter(adapter); + } + } + std::vector, std::vector>> tokenize(std::string text, size_t max_length = 0, bool padding = false) { @@ -1400,6 +1429,12 @@ struct T5CLIPEmbedder : public Conditioner { return buffer_size; } + void set_weight_adapter(const std::shared_ptr& adapter) override { + if (t5) { + t5->set_weight_adapter(adapter); + } + } + std::tuple, std::vector, std::vector> tokenize(std::string text, size_t max_length = 0, bool padding = false) { @@ -1589,6 +1624,12 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner { return buffer_size; } + void set_weight_adapter(const std::shared_ptr& adapter) override { + if (qwenvl) { + qwenvl->set_weight_adapter(adapter); + } + } + std::tuple, std::vector> tokenize(std::string text, size_t max_length = 0, size_t system_prompt_length = 0, diff --git a/control.hpp b/control.hpp index b34140e..d86f64c 100644 --- a/control.hpp +++ b/control.hpp @@ -380,7 +380,7 @@ struct ControlNet : public GGMLRunner { struct ggml_tensor* timesteps, struct ggml_tensor* context, struct ggml_tensor* y = nullptr) { - struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, CONTROL_NET_GRAPH_SIZE, false); + struct ggml_cgraph* gf = new_graph_custom(CONTROL_NET_GRAPH_SIZE); x = to_backend(x); if (guided_hint_cached) { diff --git a/diffusion_model.hpp b/diffusion_model.hpp index 3070498..0a3914e 100644 --- a/diffusion_model.hpp +++ b/diffusion_model.hpp @@ -35,8 +35,9 @@ struct DiffusionModel { virtual void free_compute_buffer() = 0; virtual void get_param_tensors(std::map& tensors) = 0; virtual size_t get_params_buffer_size() = 0; - virtual int64_t get_adm_in_channels() = 0; - virtual void set_flash_attn_enabled(bool enabled) = 0; + virtual void set_weight_adapter(const std::shared_ptr& adapter){}; + virtual int64_t get_adm_in_channels() = 0; + virtual void set_flash_attn_enabled(bool enabled) = 0; }; struct UNetModel : public DiffusionModel { @@ -73,6 +74,10 @@ struct UNetModel : public DiffusionModel { return unet.get_params_buffer_size(); } + void set_weight_adapter(const std::shared_ptr& adapter) override { + unet.set_weight_adapter(adapter); + } + int64_t get_adm_in_channels() override { return unet.unet.adm_in_channels; } @@ -130,6 +135,10 @@ struct MMDiTModel : public DiffusionModel { return mmdit.get_params_buffer_size(); } + void set_weight_adapter(const std::shared_ptr& adapter) override { + mmdit.set_weight_adapter(adapter); + } + int64_t get_adm_in_channels() override { return 768 + 1280; } @@ -188,6 +197,10 @@ struct FluxModel : public DiffusionModel { return flux.get_params_buffer_size(); } + void set_weight_adapter(const std::shared_ptr& adapter) override { + flux.set_weight_adapter(adapter); + } + int64_t get_adm_in_channels() override { return 768; } @@ -251,6 +264,10 @@ struct WanModel : public DiffusionModel { return wan.get_params_buffer_size(); } + void set_weight_adapter(const std::shared_ptr& adapter) override { + wan.set_weight_adapter(adapter); + } + int64_t get_adm_in_channels() override { return 768; } @@ -313,6 +330,10 @@ struct QwenImageModel : public DiffusionModel { return qwen_image.get_params_buffer_size(); } + void set_weight_adapter(const std::shared_ptr& adapter) override { + qwen_image.set_weight_adapter(adapter); + } + int64_t get_adm_in_channels() override { return 768; } diff --git a/docs/lora.md b/docs/lora.md index 9885ae5..fe4fbc0 100644 --- a/docs/lora.md +++ b/docs/lora.md @@ -12,38 +12,15 @@ Here's a simple example: `../models/marblesh.safetensors` or `../models/marblesh.ckpt` will be applied to the model -# Support matrix +# Lora Apply Mode -> ℹ️ CUDA `get_rows` support is defined here: -> [ggml-org/ggml/src/ggml-cuda/getrows.cu#L156](https://github.com/ggml-org/ggml/blob/7dee1d6a1e7611f238d09be96738388da97c88ed/src/ggml-cuda/getrows.cu#L156) -> Currently only the basic types + Q4/Q5/Q8 are implemented. K-quants are **not** supported. +There are two ways to apply LoRA: **immediately** and **at_runtime**. You can specify it using the `--lora-apply-mode` parameter. -NOTE: The other backends may have different support. +By default, the mode is selected automatically: + +* If the model weights contain any quantized parameters, the **at_runtime** mode is used; +* Otherwise, the **immediately** mode is used. + +The **immediately** mode may have precision and compatibility issues with quantized parameters, but it usually offers faster inference speed and, in some cases, lower memory usage. +In contrast, the **at_runtime** mode provides better compatibility and higher precision, but inference may be slower and memory usage may be higher in some cases. -| Quant / Type | CUDA | Vulkan | -|--------------|------|--------| -| F32 | ✔️ | ✔️ | -| F16 | ✔️ | ✔️ | -| BF16 | ✔️ | ✔️ | -| I32 | ✔️ | ❌ | -| Q4_0 | ✔️ | ✔️ | -| Q4_1 | ✔️ | ✔️ | -| Q5_0 | ✔️ | ✔️ | -| Q5_1 | ✔️ | ✔️ | -| Q8_0 | ✔️ | ✔️ | -| Q2_K | ❌ | ❌ | -| Q3_K | ❌ | ❌ | -| Q4_K | ❌ | ❌ | -| Q5_K | ❌ | ❌ | -| Q6_K | ❌ | ❌ | -| Q8_K | ❌ | ❌ | -| IQ1_S | ❌ | ✔️ | -| IQ1_M | ❌ | ✔️ | -| IQ2_XXS | ❌ | ✔️ | -| IQ2_XS | ❌ | ✔️ | -| IQ2_S | ❌ | ✔️ | -| IQ3_XXS | ❌ | ✔️ | -| IQ3_S | ❌ | ✔️ | -| IQ4_XS | ❌ | ✔️ | -| IQ4_NL | ❌ | ✔️ | -| MXFP4 | ❌ | ✔️ | diff --git a/esrgan.hpp b/esrgan.hpp index adce623..fb09544 100644 --- a/esrgan.hpp +++ b/esrgan.hpp @@ -344,7 +344,7 @@ struct ESRGAN : public GGMLRunner { if (!rrdb_net) return nullptr; constexpr int kGraphNodes = 1 << 16; // 65k - struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, kGraphNodes, /*grads*/ false); + struct ggml_cgraph* gf = new_graph_custom(kGraphNodes); x = to_backend(x); auto runner_ctx = get_context(); diff --git a/examples/cli/README.md b/examples/cli/README.md index 00e0942..84df1a1 100644 --- a/examples/cli/README.md +++ b/examples/cli/README.md @@ -99,6 +99,12 @@ Options: --sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd] (default: euler for Flux/SD3/Wan, euler_a otherwise) --prediction prediction type override, one of [eps, v, edm_v, sd3_flow, flux_flow] + --lora-apply-mode the way to apply LoRA, one of [auto, immediately, at_runtime], default is auto. In auto mode, if the model weights + contain any quantized parameters, the at_runtime mode will be used; otherwise, + immediately will be used.The immediately mode may have precision and + compatibility issues with quantized parameters, but it usually offers faster inference + speed and, in some cases, lower memory usageThe at_runtime mode, on the other + hand, is exactly the opposite. --scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple], default: discrete --skip-layers layers to skip for SLG steps (default: [7,8,9]) diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 619c428..a2df094 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -137,7 +137,8 @@ struct SDParams { int chroma_t5_mask_pad = 1; float flow_shift = INFINITY; - prediction_t prediction = DEFAULT_PRED; + prediction_t prediction = DEFAULT_PRED; + lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO; sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f}; bool force_sdxl_vae_conv_scale = false; @@ -209,6 +210,7 @@ void print_params(SDParams params) { printf(" high_noise_sample_params: %s\n", SAFE_STR(high_noise_sample_params_str)); printf(" moe_boundary: %.3f\n", params.moe_boundary); printf(" prediction: %s\n", sd_prediction_name(params.prediction)); + printf(" lora_apply_mode: %s\n", sd_lora_apply_mode_name(params.lora_apply_mode)); printf(" flow_shift: %.2f\n", params.flow_shift); printf(" strength(img2img): %.2f\n", params.strength); printf(" rng: %s\n", sd_rng_type_name(params.rng_type)); @@ -926,6 +928,20 @@ void parse_args(int argc, const char** argv, SDParams& params) { return 1; }; + auto on_lora_apply_mode_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + const char* arg = argv[index]; + params.lora_apply_mode = str_to_lora_apply_mode(arg); + if (params.lora_apply_mode == LORA_APPLY_MODE_COUNT) { + fprintf(stderr, "error: invalid lora apply model %s\n", + arg); + return -1; + } + return 1; + }; + auto on_sample_method_arg = [&](int argc, const char** argv, int index) { if (++index >= argc) { return -1; @@ -1123,6 +1139,14 @@ void parse_args(int argc, const char** argv, SDParams& params) { "--prediction", "prediction type override, one of [eps, v, edm_v, sd3_flow, flux_flow]", on_prediction_arg}, + {"", + "--lora-apply-mode", + "the way to apply LoRA, one of [auto, immediately, at_runtime], default is auto. " + "In auto mode, if the model weights contain any quantized parameters, the at_runtime mode will be used; otherwise, immediately will be used." + "The immediately mode may have precision and compatibility issues with quantized parameters, " + "but it usually offers faster inference speed and, in some cases, lower memory usage" + "The at_runtime mode, on the other hand, is exactly the opposite.", + on_lora_apply_mode_arg}, {"", "--scheduler", "denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple], default: discrete", @@ -1738,6 +1762,7 @@ int main(int argc, const char* argv[]) { params.wtype, params.rng_type, params.prediction, + params.lora_apply_mode, params.offload_params_to_cpu, params.clip_on_cpu, params.control_net_cpu, diff --git a/flux.hpp b/flux.hpp index 8a255aa..2f85cf8 100644 --- a/flux.hpp +++ b/flux.hpp @@ -1243,7 +1243,7 @@ namespace Flux { bool increase_ref_index = false, std::vector skip_layers = {}) { GGML_ASSERT(x->ne[3] == 1); - struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false); + struct ggml_cgraph* gf = new_graph_custom(FLUX_GRAPH_SIZE); struct ggml_tensor* mod_index_arange = nullptr; struct ggml_tensor* dct = nullptr; // for chroma radiance diff --git a/ggml_extend.hpp b/ggml_extend.hpp index eaf5016..aa16645 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -959,12 +959,15 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx, int64_t ne3 = x->ne[3]; x = ggml_reshape_2d(ctx, x, x->ne[0], x->ne[1] * x->ne[2] * x->ne[3]); x = ggml_mul_mat(ctx, w, x); - x = ggml_reshape_4d(ctx, x, x->ne[0], x->ne[1] / ne2 / ne3, ne2, ne3); + if (force_prec_f32) { + ggml_mul_mat_set_prec(x, GGML_PREC_F32); + } + x = ggml_reshape_4d(ctx, x, x->ne[0], x->ne[1] / ne2 / ne3, ne2, ne3); } else { x = ggml_mul_mat(ctx, w, x); - } - if (force_prec_f32) { - ggml_mul_mat_set_prec(x, GGML_PREC_F32); + if (force_prec_f32) { + ggml_mul_mat_set_prec(x, GGML_PREC_F32); + } } if (scale != 1.f) { x = ggml_scale(ctx, x, 1.f / scale); @@ -1119,6 +1122,18 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_ones(struct ggml_context* ctx, return ggml_ext_full(ctx, 1.f, ne0, ne1, ne2, ne3); } +__STATIC_INLINE__ ggml_tensor* ggml_ext_cast_f32(ggml_context* ctx, ggml_tensor* a) { + auto out = ggml_reshape_2d(ctx, a, 1, ggml_nelements(a)); + ggml_tensor* one = ggml_ext_ones(ctx, 1, 1, 1, 1); // [1,] + if (ggml_is_transposed(out)) { + out = ggml_mul_mat(ctx, one, out); + } else { + out = ggml_mul_mat(ctx, out, one); + } + out = ggml_reshape(ctx, out, a); + return out; +} + // q: [N * n_head, n_token, d_head] // k: [N * n_head, n_k, d_head] // v: [N * n_head, d_head, n_k] @@ -1460,11 +1475,43 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) { #define MAX_PARAMS_TENSOR_NUM 32768 #define MAX_GRAPH_SIZE 327680 +struct WeightAdapter { + struct ForwardParams { + enum class op_type_t { + OP_LINEAR, + OP_CONV2D, + } op_type; + struct { + bool force_prec_f32 = false; + float scale = 1.f; + } linear; + struct { + int s0 = 1; + int s1 = 1; + int p0 = 0; + int p1 = 0; + int d0 = 1; + int d1 = 1; + bool direct = false; + float scale = 1.f; + } conv2d; + }; + virtual ggml_tensor* patch_weight(ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name) = 0; + virtual ggml_tensor* forward_with_lora(ggml_context* ctx, + ggml_tensor* x, + ggml_tensor* w, + ggml_tensor* b, + const std::string& prefix, + ForwardParams forward_params) = 0; + virtual size_t get_extra_graph_size() = 0; +}; + struct GGMLRunnerContext { - ggml_backend_t backend = nullptr; - ggml_context* ggml_ctx = nullptr; - bool flash_attn_enabled = false; - bool conv2d_direct_enabled = false; + ggml_backend_t backend = nullptr; + ggml_context* ggml_ctx = nullptr; + bool flash_attn_enabled = false; + bool conv2d_direct_enabled = false; + std::shared_ptr weight_adapter = nullptr; }; struct GGMLRunner { @@ -1486,6 +1533,8 @@ protected: struct ggml_context* compute_ctx = nullptr; struct ggml_gallocr* compute_allocr = nullptr; + std::shared_ptr weight_adapter = nullptr; + std::vector one_vec = {1.f}; ggml_tensor* one_tensor = nullptr; @@ -1565,6 +1614,13 @@ protected: ggml_build_forward_expand(gf, one_tensor); } + struct ggml_cgraph* new_graph_custom(size_t graph_size) { + if (weight_adapter) { + graph_size += weight_adapter->get_extra_graph_size(); + } + return ggml_new_graph_custom(compute_ctx, graph_size, false); + } + struct ggml_cgraph* get_compute_graph(get_graph_cb_t get_graph) { prepare_build_in_tensor_before(); struct ggml_cgraph* gf = get_graph(); @@ -1760,6 +1816,7 @@ public: runner_ctx.backend = runtime_backend; runner_ctx.flash_attn_enabled = flash_attn_enabled; runner_ctx.conv2d_direct_enabled = conv2d_direct_enabled; + runner_ctx.weight_adapter = weight_adapter; return runner_ctx; } @@ -1891,6 +1948,10 @@ public: void set_conv2d_direct_enabled(bool enabled) { conv2d_direct_enabled = enabled; } + + void set_weight_adapter(const std::shared_ptr& adapter) { + weight_adapter = adapter; + } }; class GGMLBlock { @@ -2006,8 +2067,10 @@ protected: bool force_f32; bool force_prec_f32; float scale; + std::string prefix; void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override { + this->prefix = prefix; enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32); if (in_features % ggml_blck_size(wtype) != 0 || force_f32) { wtype = GGML_TYPE_F32; @@ -2039,6 +2102,13 @@ public: if (bias) { b = params["bias"]; } + if (ctx->weight_adapter) { + WeightAdapter::ForwardParams forward_params; + forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_LINEAR; + forward_params.linear.force_prec_f32 = force_prec_f32; + forward_params.linear.scale = scale; + return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params); + } return ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale); } }; @@ -2098,8 +2168,10 @@ protected: std::pair dilation; bool bias; float scale = 1.f; + std::string prefix; void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override { + this->prefix = prefix; enum ggml_type wtype = GGML_TYPE_F16; params["weight"] = ggml_new_tensor_4d(ctx, wtype, kernel_size.second, kernel_size.first, in_channels, out_channels); if (bias) { @@ -2138,6 +2210,19 @@ public: if (bias) { b = params["bias"]; } + if (ctx->weight_adapter) { + WeightAdapter::ForwardParams forward_params; + forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_CONV2D; + forward_params.conv2d.s0 = stride.second; + forward_params.conv2d.s1 = stride.first; + forward_params.conv2d.p0 = padding.second; + forward_params.conv2d.p1 = padding.first; + forward_params.conv2d.d0 = dilation.second; + forward_params.conv2d.d1 = dilation.first; + forward_params.conv2d.direct = ctx->conv2d_direct_enabled; + forward_params.conv2d.scale = scale; + return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params); + } return ggml_ext_conv_2d(ctx->ggml_ctx, x, w, @@ -2209,8 +2294,10 @@ protected: std::tuple padding; std::tuple dilation; bool bias; + std::string prefix; void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override { + this->prefix = prefix; enum ggml_type wtype = GGML_TYPE_F16; params["weight"] = ggml_new_tensor_4d(ctx, wtype, @@ -2242,8 +2329,17 @@ public: struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) { struct ggml_tensor* w = params["weight"]; struct ggml_tensor* b = nullptr; + if (ctx->weight_adapter) { + w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, w, prefix + "weight"); + if (w->type != GGML_TYPE_F16) { + w = ggml_cast(ctx->ggml_ctx, w, GGML_TYPE_F16); + } + } if (bias) { b = params["bias"]; + if (ctx->weight_adapter) { + b = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, b, prefix + "bias"); + } } return ggml_ext_conv_3d(ctx->ggml_ctx, x, w, b, in_channels, std::get<2>(stride), std::get<1>(stride), std::get<0>(stride), @@ -2258,8 +2354,10 @@ protected: float eps; bool elementwise_affine; bool bias; + std::string prefix; void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override { + this->prefix = prefix; if (elementwise_affine) { enum ggml_type wtype = GGML_TYPE_F32; params["weight"] = ggml_new_tensor_1d(ctx, wtype, normalized_shape); @@ -2286,8 +2384,14 @@ public: if (elementwise_affine) { w = params["weight"]; + if (ctx->weight_adapter) { + w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, w, prefix + "weight"); + } if (bias) { b = params["bias"]; + if (ctx->weight_adapter) { + b = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, b, prefix + "bias"); + } } } return ggml_ext_layer_norm(ctx->ggml_ctx, x, w, b, eps); @@ -2300,8 +2404,10 @@ protected: int64_t num_channels; float eps; bool affine; + std::string prefix; void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override { + this->prefix = prefix; if (affine) { enum ggml_type wtype = GGML_TYPE_F32; enum ggml_type bias_wtype = GGML_TYPE_F32; @@ -2326,6 +2432,10 @@ public: if (affine) { w = params["weight"]; b = params["bias"]; + if (ctx->weight_adapter) { + w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, w, prefix + "weight"); + b = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, b, prefix + "bias"); + } } return ggml_ext_group_norm(ctx->ggml_ctx, x, w, b, num_groups); } @@ -2341,8 +2451,10 @@ class RMSNorm : public UnaryBlock { protected: int64_t hidden_size; float eps; + std::string prefix; void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override { + this->prefix = prefix; enum ggml_type wtype = GGML_TYPE_F32; params["weight"] = ggml_new_tensor_1d(ctx, wtype, hidden_size); } @@ -2355,8 +2467,11 @@ public: struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) { struct ggml_tensor* w = params["weight"]; - x = ggml_rms_norm(ctx->ggml_ctx, x, eps); - x = ggml_mul_inplace(ctx->ggml_ctx, x, w); + if (ctx->weight_adapter) { + w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, w, prefix + "weight"); + } + x = ggml_rms_norm(ctx->ggml_ctx, x, eps); + x = ggml_mul_inplace(ctx->ggml_ctx, x, w); return x; } }; diff --git a/lora.hpp b/lora.hpp index 6da9d83..daabd4a 100644 --- a/lora.hpp +++ b/lora.hpp @@ -7,22 +7,25 @@ #define LORA_GRAPH_BASE_SIZE 10240 struct LoraModel : public GGMLRunner { + std::string lora_id; float multiplier = 1.0f; - std::map lora_tensors; + std::unordered_map lora_tensors; std::map original_tensor_to_final_tensor; + std::set applied_lora_tensors; std::string file_path; ModelLoader model_loader; - bool load_failed = false; - bool applied = false; - bool tensor_preprocessed = false; - std::vector zero_index_vec = {0}; - ggml_tensor* zero_index = nullptr; + bool load_failed = false; + bool applied = false; + bool tensor_preprocessed = false; - LoraModel(ggml_backend_t backend, + typedef std::function filter_t; + + LoraModel(const std::string& lora_id, + ggml_backend_t backend, const std::string& file_path = "", std::string prefix = "", SDVersion version = VERSION_COUNT) - : file_path(file_path), GGMLRunner(backend, false) { + : lora_id(lora_id), file_path(file_path), GGMLRunner(backend, false) { prefix = "lora." + prefix; if (!model_loader.init_from_file_and_convert_name(file_path, prefix, version)) { load_failed = true; @@ -33,7 +36,7 @@ struct LoraModel : public GGMLRunner { return "lora"; } - bool load_from_file(bool filter_tensor, int n_threads) { + bool load_from_file(int n_threads, filter_t filter = nullptr) { LOG_INFO("loading LoRA from '%s'", file_path.c_str()); if (load_failed) { @@ -48,7 +51,7 @@ struct LoraModel : public GGMLRunner { if (dry_run) { const std::string& name = tensor_storage.name; - if (filter_tensor && !contains(name, "lora.model")) { + if (filter && !filter(name)) { return true; } @@ -68,6 +71,10 @@ struct LoraModel : public GGMLRunner { model_loader.load_tensors(on_new_tensor_cb, n_threads); + if (tensors_to_create.empty()) { + return true; + } + for (const auto& pair : tensors_to_create) { const auto& name = pair.first; const auto& ts = pair.second; @@ -87,14 +94,6 @@ struct LoraModel : public GGMLRunner { return true; } - ggml_tensor* to_f32(ggml_context* ctx, ggml_tensor* a) { - auto out = ggml_reshape_1d(ctx, a, ggml_nelements(a)); - out = ggml_get_rows(ctx, out, zero_index); - out = ggml_reshape(ctx, out, a); - // auto out = ggml_cast(ctx, a, GGML_TYPE_F32); - return out; - } - void preprocess_lora_tensors(const std::map& model_tensors) { if (tensor_preprocessed) { return; @@ -102,7 +101,7 @@ struct LoraModel : public GGMLRunner { tensor_preprocessed = true; // I really hate these hardcoded processes. if (model_tensors.find("cond_stage_model.1.transformer.text_model.encoder.layers.0.self_attn.in_proj.weight") != model_tensors.end()) { - std::map new_lora_tensors; + std::unordered_map new_lora_tensors; for (auto& [old_name, tensor] : lora_tensors) { std::string new_name = old_name; @@ -130,7 +129,7 @@ struct LoraModel : public GGMLRunner { } } - ggml_tensor* get_lora_diff(const std::string& model_tensor_name, std::set& applied_lora_tensors) { + ggml_tensor* get_lora_weight_diff(const std::string& model_tensor_name, ggml_context* ctx) { ggml_tensor* updown = nullptr; int index = 0; while (true) { @@ -153,17 +152,17 @@ struct LoraModel : public GGMLRunner { auto iter = lora_tensors.find(lora_up_name); if (iter != lora_tensors.end()) { - lora_up = to_f32(compute_ctx, iter->second); + lora_up = ggml_ext_cast_f32(ctx, iter->second); } iter = lora_tensors.find(lora_mid_name); if (iter != lora_tensors.end()) { - lora_mid = to_f32(compute_ctx, iter->second); + lora_mid = ggml_ext_cast_f32(ctx, iter->second); } iter = lora_tensors.find(lora_down_name); if (iter != lora_tensors.end()) { - lora_down = to_f32(compute_ctx, iter->second); + lora_down = ggml_ext_cast_f32(ctx, iter->second); } if (lora_up == nullptr || lora_down == nullptr) { @@ -195,32 +194,61 @@ struct LoraModel : public GGMLRunner { } scale_value *= multiplier; - auto curr_updown = ggml_ext_merge_lora(compute_ctx, lora_down, lora_up, lora_mid); - curr_updown = ggml_scale_inplace(compute_ctx, curr_updown, scale_value); + auto curr_updown = ggml_ext_merge_lora(ctx, lora_down, lora_up, lora_mid); + curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value); if (updown == nullptr) { updown = curr_updown; } else { - updown = ggml_concat(compute_ctx, updown, curr_updown, ggml_n_dims(updown) - 1); + updown = ggml_concat(ctx, updown, curr_updown, ggml_n_dims(updown) - 1); } index++; } - - // diff - if (updown == nullptr) { - std::string lora_diff_name = "lora." + model_tensor_name + ".diff"; - - if (lora_tensors.find(lora_diff_name) != lora_tensors.end()) { - updown = to_f32(compute_ctx, lora_tensors[lora_diff_name]); - applied_lora_tensors.insert(lora_diff_name); - } - } - return updown; } - ggml_tensor* get_loha_diff(const std::string& model_tensor_name, std::set& applied_lora_tensors) { + ggml_tensor* get_raw_weight_diff(const std::string& model_tensor_name, ggml_context* ctx) { + ggml_tensor* updown = nullptr; + int index = 0; + while (true) { + std::string key; + if (index == 0) { + key = model_tensor_name; + } else { + key = model_tensor_name + "." + std::to_string(index); + } + + std::string diff_name = "lora." + key + ".diff"; + + ggml_tensor* curr_updown = nullptr; + + auto iter = lora_tensors.find(diff_name); + if (iter != lora_tensors.end()) { + curr_updown = ggml_ext_cast_f32(ctx, iter->second); + } else { + break; + } + + applied_lora_tensors.insert(diff_name); + + float scale_value = 1.0f; + scale_value *= multiplier; + + curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value); + + if (updown == nullptr) { + updown = curr_updown; + } else { + updown = ggml_concat(ctx, updown, curr_updown, ggml_n_dims(updown) - 1); + } + + index++; + } + return updown; + } + + ggml_tensor* get_loha_weight_diff(const std::string& model_tensor_name, ggml_context* ctx) { ggml_tensor* updown = nullptr; int index = 0; while (true) { @@ -248,34 +276,34 @@ struct LoraModel : public GGMLRunner { auto iter = lora_tensors.find(hada_1_down_name); if (iter != lora_tensors.end()) { - hada_1_down = to_f32(compute_ctx, iter->second); + hada_1_down = ggml_ext_cast_f32(ctx, iter->second); } iter = lora_tensors.find(hada_1_up_name); if (iter != lora_tensors.end()) { - hada_1_up = to_f32(compute_ctx, iter->second); + hada_1_up = ggml_ext_cast_f32(ctx, iter->second); } iter = lora_tensors.find(hada_1_mid_name); if (iter != lora_tensors.end()) { - hada_1_mid = to_f32(compute_ctx, iter->second); - hada_1_up = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, hada_1_up)); + hada_1_mid = ggml_ext_cast_f32(ctx, iter->second); + hada_1_up = ggml_cont(ctx, ggml_transpose(ctx, hada_1_up)); } iter = lora_tensors.find(hada_2_down_name); if (iter != lora_tensors.end()) { - hada_2_down = to_f32(compute_ctx, iter->second); + hada_2_down = ggml_ext_cast_f32(ctx, iter->second); } iter = lora_tensors.find(hada_2_up_name); if (iter != lora_tensors.end()) { - hada_2_up = to_f32(compute_ctx, iter->second); + hada_2_up = ggml_ext_cast_f32(ctx, iter->second); } iter = lora_tensors.find(hada_2_mid_name); if (iter != lora_tensors.end()) { - hada_2_mid = to_f32(compute_ctx, iter->second); - hada_2_up = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, hada_2_up)); + hada_2_mid = ggml_ext_cast_f32(ctx, iter->second); + hada_2_up = ggml_cont(ctx, ggml_transpose(ctx, hada_2_up)); } if (hada_1_up == nullptr || hada_1_down == nullptr || hada_2_up == nullptr || hada_2_down == nullptr) { @@ -309,21 +337,21 @@ struct LoraModel : public GGMLRunner { } scale_value *= multiplier; - struct ggml_tensor* updown_1 = ggml_ext_merge_lora(compute_ctx, hada_1_down, hada_1_up, hada_1_mid); - struct ggml_tensor* updown_2 = ggml_ext_merge_lora(compute_ctx, hada_2_down, hada_2_up, hada_2_mid); - auto curr_updown = ggml_mul_inplace(compute_ctx, updown_1, updown_2); - curr_updown = ggml_scale_inplace(compute_ctx, curr_updown, scale_value); + struct ggml_tensor* updown_1 = ggml_ext_merge_lora(ctx, hada_1_down, hada_1_up, hada_1_mid); + struct ggml_tensor* updown_2 = ggml_ext_merge_lora(ctx, hada_2_down, hada_2_up, hada_2_mid); + auto curr_updown = ggml_mul_inplace(ctx, updown_1, updown_2); + curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value); if (updown == nullptr) { updown = curr_updown; } else { - updown = ggml_concat(compute_ctx, updown, curr_updown, ggml_n_dims(updown) - 1); + updown = ggml_concat(ctx, updown, curr_updown, ggml_n_dims(updown) - 1); } index++; } return updown; } - ggml_tensor* get_lokr_diff(const std::string& model_tensor_name, std::set& applied_lora_tensors) { + ggml_tensor* get_lokr_weight_diff(const std::string& model_tensor_name, ggml_context* ctx) { ggml_tensor* updown = nullptr; int index = 0; while (true) { @@ -350,24 +378,24 @@ struct LoraModel : public GGMLRunner { auto iter = lora_tensors.find(lokr_w1_name); if (iter != lora_tensors.end()) { - lokr_w1 = to_f32(compute_ctx, iter->second); + lokr_w1 = ggml_ext_cast_f32(ctx, iter->second); } iter = lora_tensors.find(lokr_w2_name); if (iter != lora_tensors.end()) { - lokr_w2 = to_f32(compute_ctx, iter->second); + lokr_w2 = ggml_ext_cast_f32(ctx, iter->second); } int64_t rank = 1; if (lokr_w1 == nullptr) { iter = lora_tensors.find(lokr_w1_a_name); if (iter != lora_tensors.end()) { - lokr_w1_a = to_f32(compute_ctx, iter->second); + lokr_w1_a = ggml_ext_cast_f32(ctx, iter->second); } iter = lora_tensors.find(lokr_w1_b_name); if (iter != lora_tensors.end()) { - lokr_w1_b = to_f32(compute_ctx, iter->second); + lokr_w1_b = ggml_ext_cast_f32(ctx, iter->second); } if (lokr_w1_a == nullptr || lokr_w1_b == nullptr) { @@ -376,18 +404,18 @@ struct LoraModel : public GGMLRunner { rank = lokr_w1_b->ne[ggml_n_dims(lokr_w1_b) - 1]; - lokr_w1 = ggml_ext_merge_lora(compute_ctx, lokr_w1_b, lokr_w1_a); + lokr_w1 = ggml_ext_merge_lora(ctx, lokr_w1_b, lokr_w1_a); } if (lokr_w2 == nullptr) { iter = lora_tensors.find(lokr_w2_a_name); if (iter != lora_tensors.end()) { - lokr_w2_a = to_f32(compute_ctx, iter->second); + lokr_w2_a = ggml_ext_cast_f32(ctx, iter->second); } iter = lora_tensors.find(lokr_w2_b_name); if (iter != lora_tensors.end()) { - lokr_w2_b = to_f32(compute_ctx, iter->second); + lokr_w2_b = ggml_ext_cast_f32(ctx, iter->second); } if (lokr_w2_a == nullptr || lokr_w2_b == nullptr) { @@ -396,7 +424,7 @@ struct LoraModel : public GGMLRunner { rank = lokr_w2_b->ne[ggml_n_dims(lokr_w2_b) - 1]; - lokr_w2 = ggml_ext_merge_lora(compute_ctx, lokr_w2_b, lokr_w2_a); + lokr_w2 = ggml_ext_merge_lora(ctx, lokr_w2_b, lokr_w2_a); } if (!lokr_w1_a) { @@ -427,49 +455,208 @@ struct LoraModel : public GGMLRunner { scale_value *= multiplier; - auto curr_updown = ggml_ext_kronecker(compute_ctx, lokr_w1, lokr_w2); - curr_updown = ggml_scale_inplace(compute_ctx, curr_updown, scale_value); + auto curr_updown = ggml_ext_kronecker(ctx, lokr_w1, lokr_w2); + curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value); if (updown == nullptr) { updown = curr_updown; } else { - updown = ggml_concat(compute_ctx, updown, curr_updown, ggml_n_dims(updown) - 1); + updown = ggml_concat(ctx, updown, curr_updown, ggml_n_dims(updown) - 1); } index++; } return updown; } + ggml_tensor* get_weight_diff(const std::string& model_tensor_name, ggml_context* ctx, ggml_tensor* model_tensor, bool with_lora = true) { + // lora + ggml_tensor* diff = nullptr; + if (with_lora) { + diff = get_lora_weight_diff(model_tensor_name, ctx); + } + // diff + if (diff == nullptr) { + diff = get_raw_weight_diff(model_tensor_name, ctx); + } + // loha + if (diff == nullptr) { + diff = get_loha_weight_diff(model_tensor_name, ctx); + } + // lokr + if (diff == nullptr) { + diff = get_lokr_weight_diff(model_tensor_name, ctx); + } + if (diff != nullptr) { + if (ggml_nelements(diff) < ggml_nelements(model_tensor)) { + if (ggml_n_dims(diff) == 2 && ggml_n_dims(model_tensor) == 2 && diff->ne[0] == model_tensor->ne[0]) { + LOG_WARN("pad for %s", model_tensor_name.c_str()); + auto pad_tensor = ggml_ext_zeros(ctx, diff->ne[0], model_tensor->ne[1] - diff->ne[1], 1, 1); + diff = ggml_concat(ctx, diff, pad_tensor, 1); + } + } + + GGML_ASSERT(ggml_nelements(diff) == ggml_nelements(model_tensor)); + diff = ggml_reshape(ctx, diff, model_tensor); + } + return diff; + } + + ggml_tensor* get_out_diff(ggml_context* ctx, + ggml_tensor* x, + WeightAdapter::ForwardParams forward_params, + const std::string& model_tensor_name) { + ggml_tensor* out_diff = nullptr; + int index = 0; + while (true) { + std::string key; + if (index == 0) { + key = model_tensor_name; + } else { + key = model_tensor_name + "." + std::to_string(index); + } + + std::string lora_down_name = "lora." + key + ".lora_down"; + std::string lora_up_name = "lora." + key + ".lora_up"; + std::string lora_mid_name = "lora." + key + ".lora_mid"; + std::string scale_name = "lora." + key + ".scale"; + std::string alpha_name = "lora." + key + ".alpha"; + + ggml_tensor* lora_up = nullptr; + ggml_tensor* lora_mid = nullptr; + ggml_tensor* lora_down = nullptr; + + bool is_conv2d = forward_params.op_type == WeightAdapter::ForwardParams::op_type_t::OP_CONV2D; + + auto iter = lora_tensors.find(lora_up_name); + if (iter != lora_tensors.end()) { + lora_up = iter->second; + if (is_conv2d && lora_up->type != GGML_TYPE_F16) { + lora_up = ggml_cast(ctx, lora_up, GGML_TYPE_F16); + } + } + + iter = lora_tensors.find(lora_mid_name); + if (iter != lora_tensors.end()) { + lora_mid = iter->second; + if (is_conv2d && lora_mid->type != GGML_TYPE_F16) { + lora_mid = ggml_cast(ctx, lora_mid, GGML_TYPE_F16); + } + } + + iter = lora_tensors.find(lora_down_name); + if (iter != lora_tensors.end()) { + lora_down = iter->second; + if (is_conv2d && lora_down->type != GGML_TYPE_F16) { + lora_down = ggml_cast(ctx, lora_down, GGML_TYPE_F16); + } + } + + if (lora_up == nullptr || lora_down == nullptr) { + break; + } + + applied_lora_tensors.insert(lora_up_name); + applied_lora_tensors.insert(lora_down_name); + + if (lora_mid) { + applied_lora_tensors.insert(lora_mid_name); + } + + float scale_value = 1.0f; + + int64_t rank = lora_down->ne[ggml_n_dims(lora_down) - 1]; + iter = lora_tensors.find(scale_name); + if (iter != lora_tensors.end()) { + scale_value = ggml_ext_backend_tensor_get_f32(iter->second); + applied_lora_tensors.insert(scale_name); + } else { + iter = lora_tensors.find(alpha_name); + if (iter != lora_tensors.end()) { + float alpha = ggml_ext_backend_tensor_get_f32(iter->second); + scale_value = alpha / rank; + // LOG_DEBUG("rank %s %ld %.2f %.2f", alpha_name.c_str(), rank, alpha, scale_value); + applied_lora_tensors.insert(alpha_name); + } + } + scale_value *= multiplier; + + ggml_tensor* lx; + if (!is_conv2d) { + lx = ggml_ext_linear(ctx, x, lora_down, nullptr, forward_params.linear.force_prec_f32, forward_params.linear.scale); + if (lora_mid) { + lx = ggml_ext_linear(ctx, lx, lora_mid, nullptr, forward_params.linear.force_prec_f32, forward_params.linear.scale); + } + lx = ggml_ext_linear(ctx, lx, lora_up, nullptr, forward_params.linear.force_prec_f32, forward_params.linear.scale); + } else { // OP_CONV2D + lx = ggml_ext_conv_2d(ctx, + x, + lora_down, + nullptr, + forward_params.conv2d.s0, + forward_params.conv2d.s1, + forward_params.conv2d.p0, + forward_params.conv2d.p1, + forward_params.conv2d.d0, + forward_params.conv2d.d1, + forward_params.conv2d.direct, + forward_params.conv2d.scale); + if (lora_mid) { + lx = ggml_ext_conv_2d(ctx, + lx, + lora_mid, + nullptr, + 1, + 1, + 0, + 0, + 1, + 1, + forward_params.conv2d.direct, + forward_params.conv2d.scale); + } + lx = ggml_ext_conv_2d(ctx, + lx, + lora_up, + nullptr, + 1, + 1, + 0, + 0, + 1, + 1, + forward_params.conv2d.direct, + forward_params.conv2d.scale); + } + + auto curr_out_diff = ggml_scale_inplace(ctx, lx, scale_value); + + if (out_diff == nullptr) { + out_diff = curr_out_diff; + } else { + out_diff = ggml_concat(ctx, out_diff, curr_out_diff, ggml_n_dims(out_diff) - 1); + } + + index++; + } + return out_diff; + } + struct ggml_cgraph* build_lora_graph(const std::map& model_tensors, SDVersion version) { size_t lora_graph_size = LORA_GRAPH_BASE_SIZE + lora_tensors.size() * 10; struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, lora_graph_size, false); - zero_index = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_I32, 1); - set_backend_tensor_data(zero_index, zero_index_vec.data()); - ggml_build_forward_expand(gf, zero_index); - preprocess_lora_tensors(model_tensors); original_tensor_to_final_tensor.clear(); + applied_lora_tensors.clear(); - std::set applied_lora_tensors; for (auto it : model_tensors) { std::string model_tensor_name = it.first; ggml_tensor* model_tensor = it.second; // lora - ggml_tensor* updown = get_lora_diff(model_tensor_name, applied_lora_tensors); - // loha - if (updown == nullptr) { - updown = get_loha_diff(model_tensor_name, applied_lora_tensors); - } - - // lokr - if (updown == nullptr) { - updown = get_lokr_diff(model_tensor_name, applied_lora_tensors); - } - - if (updown == nullptr) { + ggml_tensor* diff = get_weight_diff(model_tensor_name, compute_ctx, model_tensor); + if (diff == nullptr) { continue; } @@ -479,53 +666,19 @@ struct LoraModel : public GGMLRunner { set_backend_tensor_data(model_tensor, original_tensor->data); } - if (ggml_nelements(updown) < ggml_nelements(model_tensor)) { - if (ggml_n_dims(updown) == 2 && ggml_n_dims(model_tensor) == 2 && updown->ne[0] == model_tensor->ne[0]) { - LOG_WARN("pad for %s", model_tensor_name.c_str()); - auto pad_tensor = ggml_ext_zeros(compute_ctx, updown->ne[0], model_tensor->ne[1] - updown->ne[1], 1, 1); - updown = ggml_concat(compute_ctx, updown, pad_tensor, 1); - } - } - - GGML_ASSERT(ggml_nelements(updown) == ggml_nelements(model_tensor)); - updown = ggml_reshape(compute_ctx, updown, model_tensor); ggml_tensor* final_tensor; if (model_tensor->type != GGML_TYPE_F32 && model_tensor->type != GGML_TYPE_F16) { - final_tensor = to_f32(compute_ctx, model_tensor); - final_tensor = ggml_add_inplace(compute_ctx, final_tensor, updown); + final_tensor = ggml_ext_cast_f32(compute_ctx, model_tensor); + final_tensor = ggml_add_inplace(compute_ctx, final_tensor, diff); final_tensor = ggml_cpy(compute_ctx, final_tensor, model_tensor); } else { - final_tensor = ggml_add_inplace(compute_ctx, model_tensor, updown); + final_tensor = ggml_add_inplace(compute_ctx, model_tensor, diff); } ggml_build_forward_expand(gf, final_tensor); if (!ggml_backend_is_cpu(runtime_backend) && ggml_backend_buffer_is_host(original_tensor->buffer)) { original_tensor_to_final_tensor[original_tensor] = final_tensor; } } - size_t total_lora_tensors_count = 0; - size_t applied_lora_tensors_count = 0; - - for (auto& kv : lora_tensors) { - total_lora_tensors_count++; - if (applied_lora_tensors.find(kv.first) == applied_lora_tensors.end()) { - LOG_WARN("unused lora tensor |%s|", kv.first.c_str()); - print_ggml_tensor(kv.second, true); - // exit(0); - } else { - applied_lora_tensors_count++; - } - } - /* Don't worry if this message shows up twice in the logs per LoRA, - * this function is called once to calculate the required buffer size - * and then again to actually generate a graph to be used */ - if (applied_lora_tensors_count != total_lora_tensors_count) { - LOG_WARN("Only (%lu / %lu) LoRA tensors will be applied", - applied_lora_tensors_count, total_lora_tensors_count); - } else { - LOG_DEBUG("(%lu / %lu) LoRA tensors will be applied", - applied_lora_tensors_count, total_lora_tensors_count); - } - return gf; } @@ -534,6 +687,7 @@ struct LoraModel : public GGMLRunner { return build_lora_graph(model_tensors, version); }; GGMLRunner::compute(get_graph, n_threads, false); + stat(); for (auto item : original_tensor_to_final_tensor) { ggml_tensor* original_tensor = item.first; ggml_tensor* final_tensor = item.second; @@ -543,6 +697,107 @@ struct LoraModel : public GGMLRunner { original_tensor_to_final_tensor.clear(); GGMLRunner::free_compute_buffer(); } + + void stat(bool at_runntime = false) { + size_t total_lora_tensors_count = 0; + size_t applied_lora_tensors_count = 0; + + for (auto& kv : lora_tensors) { + total_lora_tensors_count++; + if (applied_lora_tensors.find(kv.first) == applied_lora_tensors.end()) { + if (!at_runntime) { + LOG_WARN("unused lora tensor |%s|", kv.first.c_str()); + print_ggml_tensor(kv.second, true); + } + } else { + applied_lora_tensors_count++; + } + } + /* Don't worry if this message shows up twice in the logs per LoRA, + * this function is called once to calculate the required buffer size + * and then again to actually generate a graph to be used */ + if (!at_runntime && applied_lora_tensors_count != total_lora_tensors_count) { + LOG_WARN("Only (%lu / %lu) LoRA tensors have been applied, lora_file_path = %s", + applied_lora_tensors_count, total_lora_tensors_count, file_path.c_str()); + } else { + LOG_INFO("(%lu / %lu) LoRA tensors have been applied, lora_file_path = %s", + applied_lora_tensors_count, total_lora_tensors_count, file_path.c_str()); + } + } +}; + +struct MultiLoraAdapter : public WeightAdapter { +protected: + std::vector> lora_models; + +public: + explicit MultiLoraAdapter(const std::vector>& lora_models) + : lora_models(lora_models) { + } + + ggml_tensor* patch_weight(ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name, bool with_lora) { + for (auto& lora_model : lora_models) { + ggml_tensor* diff = lora_model->get_weight_diff(weight_name, ctx, weight, with_lora); + if (diff == nullptr) { + continue; + } + + if (weight->type != GGML_TYPE_F32 && weight->type != GGML_TYPE_F16) { + weight = ggml_ext_cast_f32(ctx, weight); + } + weight = ggml_add(ctx, weight, diff); + } + return weight; + } + + ggml_tensor* patch_weight(ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name) override { + return patch_weight(ctx, weight, weight_name, true); + } + + ggml_tensor* forward_with_lora(ggml_context* ctx, + ggml_tensor* x, + ggml_tensor* w, + ggml_tensor* b, + const std::string& prefix, + WeightAdapter::ForwardParams forward_params) override { + w = patch_weight(ctx, w, prefix + "weight", false); + if (b) { + b = patch_weight(ctx, b, prefix + "bias", false); + } + ggml_tensor* out; + if (forward_params.op_type == ForwardParams::op_type_t::OP_LINEAR) { + out = ggml_ext_linear(ctx, x, w, b, forward_params.linear.force_prec_f32, forward_params.linear.scale); + } else { // OP_CONV2D + out = ggml_ext_conv_2d(ctx, + x, + w, + b, + forward_params.conv2d.s0, + forward_params.conv2d.s1, + forward_params.conv2d.p0, + forward_params.conv2d.p1, + forward_params.conv2d.d0, + forward_params.conv2d.d1, + forward_params.conv2d.direct, + forward_params.conv2d.scale); + } + for (auto& lora_model : lora_models) { + ggml_tensor* out_diff = lora_model->get_out_diff(ctx, x, forward_params, prefix + "weight"); + if (out_diff == nullptr) { + continue; + } + out = ggml_add_inplace(ctx, out, out_diff); + } + return out; + } + + size_t get_extra_graph_size() override { + size_t lora_tensor_num = 0; + for (auto& lora_model : lora_models) { + lora_tensor_num += lora_model->lora_tensors.size(); + } + return LORA_GRAPH_BASE_SIZE + lora_tensor_num * 10; + } }; #endif // __LORA_HPP__ diff --git a/mmdit.hpp b/mmdit.hpp index 3ca01d9..c243e03 100644 --- a/mmdit.hpp +++ b/mmdit.hpp @@ -870,7 +870,7 @@ struct MMDiTRunner : public GGMLRunner { struct ggml_tensor* context, struct ggml_tensor* y, std::vector skip_layers = std::vector()) { - struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, MMDIT_GRAPH_SIZE, false); + struct ggml_cgraph* gf = new_graph_custom(MMDIT_GRAPH_SIZE); x = to_backend(x); context = to_backend(context); diff --git a/name_conversion.cpp b/name_conversion.cpp index ea2702a..c50baa5 100644 --- a/name_conversion.cpp +++ b/name_conversion.cpp @@ -855,6 +855,49 @@ std::string convert_sep_to_dot(std::string name) { return name; } +std::vector cond_stage_model_prefix_vec = { + "cond_stage_model.1.", + "cond_stage_model.", + "conditioner.embedders.", + "text_encoders.", +}; + +std::vector diffuison_model_prefix_vec = { + "model.diffusion_model.", +}; + +std::vector first_stage_model_prefix_vec = { + "first_stage_model.", + "vae.", +}; + +bool is_cond_stage_model_name(const std::string& name) { + for (const auto& prefix : cond_stage_model_prefix_vec) { + if (starts_with(name, prefix) || starts_with(name, "lora." + prefix)) { + return true; + } + } + return false; +} + +bool is_diffusion_model_name(const std::string& name) { + for (const auto& prefix : diffuison_model_prefix_vec) { + if (starts_with(name, prefix) || starts_with(name, "lora." + prefix)) { + return true; + } + } + return false; +} + +bool is_first_stage_model_name(const std::string& name) { + for (const auto& prefix : first_stage_model_prefix_vec) { + if (starts_with(name, prefix) || starts_with(name, "lora." + prefix)) { + return true; + } + } + return false; +} + std::string convert_tensor_name(std::string name, SDVersion version) { bool is_lora = false; bool is_lycoris_underline = false; @@ -956,9 +999,6 @@ std::string convert_tensor_name(std::string name, SDVersion version) { // diffusion model { - std::vector diffuison_model_prefix_vec = { - "model.diffusion_model.", - }; for (const auto& prefix : diffuison_model_prefix_vec) { if (starts_with(name, prefix)) { name = convert_diffusion_model_name(name.substr(prefix.size()), prefix, version); @@ -970,12 +1010,6 @@ std::string convert_tensor_name(std::string name, SDVersion version) { // cond_stage_model { - std::vector cond_stage_model_prefix_vec = { - "cond_stage_model.1.", - "cond_stage_model.", - "conditioner.embedders.", - "text_encoders.", - }; for (const auto& prefix : cond_stage_model_prefix_vec) { if (starts_with(name, prefix)) { name = convert_cond_stage_model_name(name.substr(prefix.size()), prefix); @@ -987,10 +1021,6 @@ std::string convert_tensor_name(std::string name, SDVersion version) { // first_stage_model { - std::vector first_stage_model_prefix_vec = { - "first_stage_model.", - "vae.", - }; for (const auto& prefix : first_stage_model_prefix_vec) { if (starts_with(name, prefix)) { name = convert_first_stage_model_name(name.substr(prefix.size()), prefix); diff --git a/name_conversion.h b/name_conversion.h index eb3d1a9..3fefcf7 100644 --- a/name_conversion.h +++ b/name_conversion.h @@ -5,6 +5,10 @@ #include "model.h" +bool is_cond_stage_model_name(const std::string& name); +bool is_diffusion_model_name(const std::string& name); +bool is_first_stage_model_name(const std::string& name); + std::string convert_tensor_name(std::string name, SDVersion version); #endif // __NAME_CONVERSTION_H__ \ No newline at end of file diff --git a/qwen_image.hpp b/qwen_image.hpp index 87d2fb9..94ada47 100644 --- a/qwen_image.hpp +++ b/qwen_image.hpp @@ -543,7 +543,7 @@ namespace Qwen { std::vector ref_latents = {}, bool increase_ref_index = false) { GGML_ASSERT(x->ne[3] == 1); - struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, QWEN_IMAGE_GRAPH_SIZE, false); + struct ggml_cgraph* gf = new_graph_custom(QWEN_IMAGE_GRAPH_SIZE); x = to_backend(x); context = to_backend(context); diff --git a/qwenvl.hpp b/qwenvl.hpp index 0a914f6..9bc2684 100644 --- a/qwenvl.hpp +++ b/qwenvl.hpp @@ -1049,7 +1049,7 @@ namespace Qwen { } struct ggml_cgraph* build_encode_image_graph(struct ggml_tensor* image) { - struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, QWENVL_GRAPH_SIZE, false); + struct ggml_cgraph* gf = new_graph_custom(QWENVL_GRAPH_SIZE); GGML_ASSERT(image->ne[1] % (params.vision.patch_size * params.vision.spatial_merge_size) == 0); GGML_ASSERT(image->ne[0] % (params.vision.patch_size * params.vision.spatial_merge_size) == 0); diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 4cea83a..3e71ec0 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -17,6 +17,7 @@ #include "vae.hpp" #include "latent-preview.h" +#include "name_conversion.h" const char* model_version_to_str[] = { "SD 1.x", @@ -108,10 +109,14 @@ public: std::shared_ptr high_noise_diffusion_model; std::shared_ptr first_stage_model; std::shared_ptr tae_first_stage; - std::shared_ptr control_net = nullptr; + std::shared_ptr control_net; std::shared_ptr pmid_model; std::shared_ptr pmid_lora; std::shared_ptr pmid_id_embeds; + std::vector> cond_stage_lora_models; + std::vector> diffusion_lora_models; + std::vector> first_stage_lora_models; + bool apply_lora_immediately = false; std::string taesd_path; bool use_tiny_autoencoder = false; @@ -329,6 +334,25 @@ public: LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor)); + if (sd_ctx_params->lora_apply_mode == LORA_APPLY_AUTO) { + bool have_quantized_weight = false; + for (const auto& [type, _] : wtype_stat) { + if (ggml_is_quantized(type)) { + have_quantized_weight = true; + break; + } + } + if (have_quantized_weight) { + apply_lora_immediately = false; + } else { + apply_lora_immediately = true; + } + } else if (sd_ctx_params->lora_apply_mode == LORA_APPLY_IMMEDIATELY) { + apply_lora_immediately = true; + } else { + apply_lora_immediately = false; + } + if (sd_version_is_sdxl(version)) { scale_factor = 0.13025f; } else if (sd_version_is_sd3(version)) { @@ -571,8 +595,14 @@ public: version); } if (strlen(SAFE_STR(sd_ctx_params->photo_maker_path)) > 0) { - pmid_lora = std::make_shared(backend, sd_ctx_params->photo_maker_path, "", version); - if (!pmid_lora->load_from_file(true, n_threads)) { + pmid_lora = std::make_shared("pmid", backend, sd_ctx_params->photo_maker_path, "", version); + auto lora_tensor_filter = [&](const std::string& tensor_name) { + if (starts_with(tensor_name, "lora.model")) { + return true; + } + return false; + }; + if (!pmid_lora->load_from_file(n_threads, lora_tensor_filter)) { LOG_WARN("load photomaker lora tensors from %s failed", sd_ctx_params->photo_maker_path); return false; } @@ -907,8 +937,11 @@ public: return result < -1; } - void apply_lora(std::string lora_name, float multiplier) { - int64_t t0 = ggml_time_ms(); + std::shared_ptr load_lora_model_from_file(const std::string& lora_id, + float multiplier, + ggml_backend_t backend, + LoraModel::filter_t lora_tensor_filter = nullptr) { + std::string lora_name = lora_id; std::string high_noise_tag = "|high_noise|"; bool is_high_noise = false; if (starts_with(lora_name, high_noise_tag)) { @@ -925,25 +958,19 @@ public: file_path = ckpt_file_path; } else { LOG_WARN("can not find %s or %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str()); - return; + return nullptr; } - LoraModel lora(backend, file_path, is_high_noise ? "model.high_noise_" : "", version); - if (!lora.load_from_file(false, n_threads)) { + auto lora = std::make_shared(lora_id, backend, file_path, is_high_noise ? "model.high_noise_" : "", version); + if (!lora->load_from_file(n_threads, lora_tensor_filter)) { LOG_WARN("load lora tensors from %s failed", file_path.c_str()); - return; + return nullptr; } - lora.multiplier = multiplier; - // TODO: send version? - lora.apply(tensors, version, n_threads); - lora.free_params_buffer(); - - int64_t t1 = ggml_time_ms(); - - LOG_INFO("lora '%s' applied, taking %.2fs", lora_name.c_str(), (t1 - t0) * 1.0f / 1000); + lora->multiplier = multiplier; + return lora; } - void apply_loras(const std::unordered_map& lora_state) { + void apply_loras_immediately(const std::unordered_map& lora_state) { std::unordered_map lora_state_diff; for (auto& kv : lora_state) { const std::string& lora_name = kv.first; @@ -964,12 +991,149 @@ public: } for (auto& kv : lora_state_diff) { - apply_lora(kv.first, kv.second); + int64_t t0 = ggml_time_ms(); + + auto lora = load_lora_model_from_file(kv.first, kv.second, backend); + lora->apply(tensors, version, n_threads); + lora->free_params_buffer(); + + int64_t t1 = ggml_time_ms(); + + LOG_INFO("lora '%s' applied, taking %.2fs", kv.first.c_str(), (t1 - t0) * 1.0f / 1000); } curr_lora_state = lora_state; } + void apply_loras_at_runtime(const std::unordered_map& lora_state) { + cond_stage_lora_models.clear(); + diffusion_lora_models.clear(); + first_stage_lora_models.clear(); + if (cond_stage_model) { + std::vector> lora_models; + auto lora_state_diff = lora_state; + for (auto& lora_model : cond_stage_lora_models) { + auto iter = lora_state_diff.find(lora_model->lora_id); + + if (iter != lora_state_diff.end()) { + lora_model->multiplier = iter->second; + lora_models.push_back(lora_model); + lora_state_diff.erase(iter); + } + } + cond_stage_lora_models = lora_models; + auto lora_tensor_filter = [&](const std::string& tensor_name) { + if (is_cond_stage_model_name(tensor_name)) { + return true; + } + return false; + }; + for (auto& kv : lora_state_diff) { + const std::string& lora_id = kv.first; + float multiplier = kv.second; + + auto lora = load_lora_model_from_file(lora_id, multiplier, clip_backend, lora_tensor_filter); + if (lora && !lora->lora_tensors.empty()) { + lora->preprocess_lora_tensors(tensors); + cond_stage_lora_models.push_back(lora); + } + } + auto multi_lora_adapter = std::make_shared(cond_stage_lora_models); + cond_stage_model->set_weight_adapter(multi_lora_adapter); + } + if (diffusion_model) { + std::vector> lora_models; + auto lora_state_diff = lora_state; + for (auto& lora_model : diffusion_lora_models) { + auto iter = lora_state_diff.find(lora_model->lora_id); + + if (iter != lora_state_diff.end()) { + lora_model->multiplier = iter->second; + lora_models.push_back(lora_model); + lora_state_diff.erase(iter); + } + } + diffusion_lora_models = lora_models; + auto lora_tensor_filter = [&](const std::string& tensor_name) { + if (is_diffusion_model_name(tensor_name)) { + return true; + } + return false; + }; + for (auto& kv : lora_state_diff) { + const std::string& lora_name = kv.first; + float multiplier = kv.second; + + auto lora = load_lora_model_from_file(lora_name, multiplier, backend, lora_tensor_filter); + if (lora && !lora->lora_tensors.empty()) { + lora->preprocess_lora_tensors(tensors); + diffusion_lora_models.push_back(lora); + } + } + auto multi_lora_adapter = std::make_shared(diffusion_lora_models); + diffusion_model->set_weight_adapter(multi_lora_adapter); + if (high_noise_diffusion_model) { + high_noise_diffusion_model->set_weight_adapter(multi_lora_adapter); + } + } + + if (first_stage_model) { + std::vector> lora_models; + auto lora_state_diff = lora_state; + for (auto& lora_model : first_stage_lora_models) { + auto iter = lora_state_diff.find(lora_model->lora_id); + + if (iter != lora_state_diff.end()) { + lora_model->multiplier = iter->second; + lora_models.push_back(lora_model); + lora_state_diff.erase(iter); + } + } + first_stage_lora_models = lora_models; + auto lora_tensor_filter = [&](const std::string& tensor_name) { + if (is_first_stage_model_name(tensor_name)) { + return true; + } + return false; + }; + for (auto& kv : lora_state_diff) { + const std::string& lora_name = kv.first; + float multiplier = kv.second; + + auto lora = load_lora_model_from_file(lora_name, multiplier, vae_backend, lora_tensor_filter); + if (lora && !lora->lora_tensors.empty()) { + lora->preprocess_lora_tensors(tensors); + first_stage_lora_models.push_back(lora); + } + } + auto multi_lora_adapter = std::make_shared(first_stage_lora_models); + first_stage_model->set_weight_adapter(multi_lora_adapter); + } + } + + void lora_stat() { + if (!cond_stage_lora_models.empty()) { + LOG_INFO("cond_stage_lora_models:"); + for (auto& lora_model : cond_stage_lora_models) { + lora_model->stat(); + } + } + + if (!diffusion_lora_models.empty()) { + LOG_INFO("diffusion_lora_models:"); + for (auto& lora_model : diffusion_lora_models) { + lora_model->stat(); + } + } + + if (!first_stage_lora_models.empty()) { + LOG_INFO("first_stage_lora_models:"); + for (auto& lora_model : first_stage_lora_models) { + lora_model->stat(); + } + } + } + std::string apply_loras_from_prompt(const std::string& prompt) { auto result_pair = extract_and_remove_lora(prompt); std::unordered_map lora_f2m = result_pair.first; // lora_name -> multiplier @@ -978,10 +1142,18 @@ public: LOG_DEBUG("lora %s:%.2f", kv.first.c_str(), kv.second); } int64_t t0 = ggml_time_ms(); - apply_loras(lora_f2m); + if (apply_lora_immediately) { + LOG_INFO("apply lora immediately"); + apply_loras_immediately(lora_f2m); + } else { + LOG_INFO("apply at runtime"); + apply_loras_at_runtime(lora_f2m); + } int64_t t1 = ggml_time_ms(); - LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); - LOG_DEBUG("prompt after extract and remove lora: \"%s\"", result_pair.second.c_str()); + if (!lora_f2m.empty()) { + LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); + LOG_DEBUG("prompt after extract and remove lora: \"%s\"", result_pair.second.c_str()); + } return result_pair.second; } @@ -2081,6 +2253,28 @@ enum preview_t str_to_preview(const char* str) { return PREVIEW_COUNT; } +const char* lora_apply_mode_to_str[] = { + "auto", + "immediately", + "at_runtime", +}; + +const char* sd_lora_apply_mode_name(enum lora_apply_mode_t mode) { + if (mode < LORA_APPLY_MODE_COUNT) { + return lora_apply_mode_to_str[mode]; + } + return NONE_STR; +} + +enum lora_apply_mode_t str_to_lora_apply_mode(const char* str) { + for (int i = 0; i < LORA_APPLY_MODE_COUNT; i++) { + if (!strcmp(str, lora_apply_mode_to_str[i])) { + return (enum lora_apply_mode_t)i; + } + } + return LORA_APPLY_MODE_COUNT; +} + void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) { *sd_ctx_params = {}; sd_ctx_params->vae_decode_only = true; @@ -2089,6 +2283,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) { sd_ctx_params->wtype = SD_TYPE_COUNT; sd_ctx_params->rng_type = CUDA_RNG; sd_ctx_params->prediction = DEFAULT_PRED; + sd_ctx_params->lora_apply_mode = LORA_APPLY_AUTO; sd_ctx_params->offload_params_to_cpu = false; sd_ctx_params->keep_clip_on_cpu = false; sd_ctx_params->keep_control_net_on_cpu = false; @@ -2674,6 +2869,9 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, if (sd_ctx->sd->free_params_immediately && !sd_ctx->sd->use_tiny_autoencoder) { sd_ctx->sd->first_stage_model->free_params_buffer(); } + + sd_ctx->sd->lora_stat(); + sd_image_t* result_images = (sd_image_t*)calloc(batch_count, sizeof(sd_image_t)); if (result_images == nullptr) { ggml_free(work_ctx); @@ -3343,6 +3541,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s sd_ctx->sd->first_stage_model->free_params_buffer(); } + sd_ctx->sd->lora_stat(); + sd_image_t* result_images = (sd_image_t*)calloc(vid->ne[2], sizeof(sd_image_t)); if (result_images == nullptr) { ggml_free(work_ctx); diff --git a/stable-diffusion.h b/stable-diffusion.h index c5db361..5cb2394 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -134,6 +134,13 @@ enum preview_t { PREVIEW_COUNT }; +enum lora_apply_mode_t { + LORA_APPLY_AUTO, + LORA_APPLY_IMMEDIATELY, + LORA_APPLY_AT_RUNTIME, + LORA_APPLY_MODE_COUNT, +}; + typedef struct { bool enabled; int tile_size_x; @@ -165,6 +172,7 @@ typedef struct { enum sd_type_t wtype; enum rng_type_t rng_type; enum prediction_t prediction; + enum lora_apply_mode_t lora_apply_mode; bool offload_params_to_cpu; bool keep_clip_on_cpu; bool keep_control_net_on_cpu; @@ -283,6 +291,8 @@ SD_API const char* sd_prediction_name(enum prediction_t prediction); SD_API enum prediction_t str_to_prediction(const char* str); SD_API const char* sd_preview_name(enum preview_t preview); SD_API enum preview_t str_to_preview(const char* str); +SD_API const char* sd_lora_apply_mode_name(enum lora_apply_mode_t mode); +SD_API enum lora_apply_mode_t str_to_lora_apply_mode(const char* str); SD_API void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params); SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params); diff --git a/unet.hpp b/unet.hpp index 8f0adf3..de05f46 100644 --- a/unet.hpp +++ b/unet.hpp @@ -7,7 +7,7 @@ /*==================================================== UnetModel =====================================================*/ -#define UNET_GRAPH_SIZE 10240 +#define UNET_GRAPH_SIZE 102400 class SpatialVideoTransformer : public SpatialTransformer { protected: @@ -612,7 +612,7 @@ struct UNetModelRunner : public GGMLRunner { int num_video_frames = -1, std::vector controls = {}, float control_strength = 0.f) { - struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, UNET_GRAPH_SIZE, false); + struct ggml_cgraph* gf = new_graph_custom(UNET_GRAPH_SIZE); if (num_video_frames == -1) { num_video_frames = x->ne[3]; diff --git a/wan.hpp b/wan.hpp index 91a2e92..41882e7 100644 --- a/wan.hpp +++ b/wan.hpp @@ -1133,7 +1133,7 @@ namespace WAN { } struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) { - struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, 10240 * z->ne[2], false); + struct ggml_cgraph* gf = new_graph_custom(10240 * z->ne[2]); z = to_backend(z); @@ -1147,7 +1147,7 @@ namespace WAN { } struct ggml_cgraph* build_graph_partial(struct ggml_tensor* z, bool decode_graph, int64_t i) { - struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, 20480, false); + struct ggml_cgraph* gf = new_graph_custom(20480); ae.clear_cache(); @@ -2142,7 +2142,7 @@ namespace WAN { struct ggml_tensor* time_dim_concat = nullptr, struct ggml_tensor* vace_context = nullptr, float vace_strength = 1.f) { - struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, WAN_GRAPH_SIZE, false); + struct ggml_cgraph* gf = new_graph_custom(WAN_GRAPH_SIZE); x = to_backend(x); timesteps = to_backend(timesteps);