Compare commits

...

9 Commits

Author SHA1 Message Date
yslai
19d876ee30
feat: implement DDIM with the "trailing" timestep spacing and TCD (#568) 2025-02-22 21:34:22 +08:00
lalala
f27f2b2aa2
docs: add missing --mask and --guidance options to print_usage (#572) 2025-02-22 21:32:37 +08:00
piallai
99609761dc
docs: fix typo in readme (#574) 2025-02-22 21:30:28 +08:00
stduhpf
69c73789fe
fix: force binary mask for inpaint models (#589)
Co-authored-by: leejet <leejet714@gmail.com>
2025-02-22 21:29:57 +08:00
Meng, Hengyu
838beb9b5e
chore: add global SYCL compile flags (#597) 2025-02-22 21:23:58 +08:00
stduhpf
f23b803a6b
fix:: unapply current loras properly (#590) 2025-02-22 21:22:22 +08:00
stduhpf
1be2491dcf
feat: partial LyCORIS support (tucker decomposition for LoCon + LoHa + LoKr) (#577) 2025-02-22 21:19:26 +08:00
Matti Pulkkinen
3753223982
fix: make get_files_from_dir works with absolute path (#598)
Co-authored-by: Matti Pulkkinen <pulkkinen@ultimatium.com>
2025-02-22 21:16:50 +08:00
R0CKSTAR
59ca2b0f16
chore: bump MUSA SDK version to rc3.1.1 (#599)
Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
2025-02-22 21:14:26 +08:00
10 changed files with 1040 additions and 396 deletions

View File

@ -96,6 +96,7 @@ endif()
if(SD_SYCL)
message("-- Use SYCL as backend stable-diffusion")
set(GGML_SYCL ON)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing -fsycl")
add_definitions(-DSD_USE_SYCL)
# disable fast-math on host, see:
# https://www.intel.com/content/www/us/en/docs/cpp-compiler/developer-guide-reference/2021-10/fp-model-fp.html

View File

@ -1,4 +1,4 @@
ARG MUSA_VERSION=rc3.1.0
ARG MUSA_VERSION=rc3.1.1
FROM mthreads/musa:${MUSA_VERSION}-devel-ubuntu22.04 as build

View File

@ -326,7 +326,7 @@ These projects use `stable-diffusion.cpp` as a backend for their image generatio
- [Jellybox](https://jellybox.com)
- [Stable Diffusion GUI](https://github.com/fszontagh/sd.cpp.gui.wx)
- [Stable Diffusion CLI-GUI] (https://github.com/piallai/stable-diffusion.cpp)
- [Stable Diffusion CLI-GUI](https://github.com/piallai/stable-diffusion.cpp)
## Contributors

View File

@ -474,7 +474,8 @@ static void sample_k_diffusion(sample_method_t method,
ggml_context* work_ctx,
ggml_tensor* x,
std::vector<float> sigmas,
std::shared_ptr<RNG> rng) {
std::shared_ptr<RNG> rng,
float eta) {
size_t steps = sigmas.size() - 1;
// sample_euler_ancestral
switch (method) {
@ -1005,6 +1006,374 @@ static void sample_k_diffusion(sample_method_t method,
}
}
} break;
case DDIM_TRAILING: // Denoising Diffusion Implicit Models
// with the "trailing" timestep spacing
{
// See J. Song et al., "Denoising Diffusion Implicit
// Models", arXiv:2010.02502 [cs.LG]
//
// DDIM itself needs alphas_cumprod (DDPM, J. Ho et al.,
// arXiv:2006.11239 [cs.LG] with k-diffusion's start and
// end beta) (which unfortunately k-diffusion's data
// structure hides from the denoiser), and the sigmas are
// also needed to invert the behavior of CompVisDenoiser
// (k-diffusion's LMSDiscreteScheduler)
float beta_start = 0.00085f;
float beta_end = 0.0120f;
std::vector<double> alphas_cumprod;
std::vector<double> compvis_sigmas;
alphas_cumprod.reserve(TIMESTEPS);
compvis_sigmas.reserve(TIMESTEPS);
for (int i = 0; i < TIMESTEPS; i++) {
alphas_cumprod[i] =
(i == 0 ? 1.0f : alphas_cumprod[i - 1]) *
(1.0f -
std::pow(sqrtf(beta_start) +
(sqrtf(beta_end) - sqrtf(beta_start)) *
((float)i / (TIMESTEPS - 1)), 2));
compvis_sigmas[i] =
std::sqrt((1 - alphas_cumprod[i]) /
alphas_cumprod[i]);
}
struct ggml_tensor* pred_original_sample =
ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* variance_noise =
ggml_dup_tensor(work_ctx, x);
for (int i = 0; i < steps; i++) {
// The "trailing" DDIM timestep, see S. Lin et al.,
// "Common Diffusion Noise Schedules and Sample Steps
// are Flawed", arXiv:2305.08891 [cs], p. 4, Table
// 2. Most variables below follow Diffusers naming
//
// Diffuser naming vs. Song et al. (2010), p. 5, (12)
// and p. 16, (16) (<variable name> -> <name in
// paper>):
//
// - pred_noise_t -> epsilon_theta^(t)(x_t)
// - pred_original_sample -> f_theta^(t)(x_t) or x_0
// - std_dev_t -> sigma_t (not the LMS sigma)
// - eta -> eta (set to 0 at the moment)
// - pred_sample_direction -> "direction pointing to
// x_t"
// - pred_prev_sample -> "x_t-1"
int timestep =
roundf(TIMESTEPS -
i * ((float)TIMESTEPS / steps)) - 1;
// 1. get previous step value (=t-1)
int prev_timestep = timestep - TIMESTEPS / steps;
// The sigma here is chosen to cause the
// CompVisDenoiser to produce t = timestep
float sigma = compvis_sigmas[timestep];
if (i == 0) {
// The function add_noise intializes x to
// Diffusers' latents * sigma (as in Diffusers'
// pipeline) or sample * sigma (Diffusers'
// scheduler), where this sigma = init_noise_sigma
// in Diffusers. For DDPM and DDIM however,
// init_noise_sigma = 1. But the k-diffusion
// model() also evaluates F_theta(c_in(sigma) x;
// ...) instead of the bare U-net F_theta, with
// c_in = 1 / sqrt(sigma^2 + 1), as defined in
// T. Karras et al., "Elucidating the Design Space
// of Diffusion-Based Generative Models",
// arXiv:2206.00364 [cs.CV], p. 3, Table 1. Hence
// the first call has to be prescaled as x <- x /
// (c_in * sigma) with the k-diffusion pipeline
// and CompVisDenoiser.
float* vec_x = (float*)x->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] *= std::sqrt(sigma * sigma + 1) /
sigma;
}
}
else {
// For the subsequent steps after the first one,
// at this point x = latents or x = sample, and
// needs to be prescaled with x <- sample / c_in
// to compensate for model() applying the scale
// c_in before the U-net F_theta
float* vec_x = (float*)x->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] *= std::sqrt(sigma * sigma + 1);
}
}
// Note (also noise_pred in Diffuser's pipeline)
// model_output = model() is the D(x, sigma) as
// defined in Karras et al. (2022), p. 3, Table 1 and
// p. 8 (7), compare also p. 38 (226) therein.
struct ggml_tensor* model_output =
model(x, sigma, i + 1);
// Here model_output is still the k-diffusion denoiser
// output, not the U-net output F_theta(c_in(sigma) x;
// ...) in Karras et al. (2022), whereas Diffusers'
// model_output is F_theta(...). Recover the actual
// model_output, which is also referred to as the
// "Karras ODE derivative" d or d_cur in several
// samplers above.
{
float* vec_x = (float*)x->data;
float* vec_model_output =
(float*)model_output->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_model_output[j] =
(vec_x[j] - vec_model_output[j]) *
(1 / sigma);
}
}
// 2. compute alphas, betas
float alpha_prod_t = alphas_cumprod[timestep];
// Note final_alpha_cumprod = alphas_cumprod[0] due to
// trailing timestep spacing
float alpha_prod_t_prev = prev_timestep >= 0 ?
alphas_cumprod[prev_timestep] : alphas_cumprod[0];
float beta_prod_t = 1 - alpha_prod_t;
// 3. compute predicted original sample from predicted
// noise also called "predicted x_0" of formula (12)
// from https://arxiv.org/pdf/2010.02502.pdf
{
float* vec_x = (float*)x->data;
float* vec_model_output =
(float*)model_output->data;
float* vec_pred_original_sample =
(float*)pred_original_sample->data;
// Note the substitution of latents or sample = x
// * c_in = x / sqrt(sigma^2 + 1)
for (int j = 0; j < ggml_nelements(x); j++) {
vec_pred_original_sample[j] =
(vec_x[j] / std::sqrt(sigma * sigma + 1) -
std::sqrt(beta_prod_t) *
vec_model_output[j]) *
(1 / std::sqrt(alpha_prod_t));
}
}
// Assuming the "epsilon" prediction type, where below
// pred_epsilon = model_output is inserted, and is not
// defined/copied explicitly.
//
// 5. compute variance: "sigma_t(eta)" -> see formula
// (16)
//
// sigma_t = sqrt((1 - alpha_t-1)/(1 - alpha_t)) *
// sqrt(1 - alpha_t/alpha_t-1)
float beta_prod_t_prev = 1 - alpha_prod_t_prev;
float variance = (beta_prod_t_prev / beta_prod_t) *
(1 - alpha_prod_t / alpha_prod_t_prev);
float std_dev_t = eta * std::sqrt(variance);
// 6. compute "direction pointing to x_t" of formula
// (12) from https://arxiv.org/pdf/2010.02502.pdf
// 7. compute x_t without "random noise" of formula
// (12) from https://arxiv.org/pdf/2010.02502.pdf
{
float* vec_model_output = (float*)model_output->data;
float* vec_pred_original_sample =
(float*)pred_original_sample->data;
float* vec_x = (float*)x->data;
for (int j = 0; j < ggml_nelements(x); j++) {
// Two step inner loop without an explicit
// tensor
float pred_sample_direction =
std::sqrt(1 - alpha_prod_t_prev -
std::pow(std_dev_t, 2)) *
vec_model_output[j];
vec_x[j] = std::sqrt(alpha_prod_t_prev) *
vec_pred_original_sample[j] +
pred_sample_direction;
}
}
if (eta > 0) {
ggml_tensor_set_f32_randn(variance_noise, rng);
float* vec_variance_noise =
(float*)variance_noise->data;
float* vec_x = (float*)x->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] += std_dev_t * vec_variance_noise[j];
}
}
// See the note above: x = latents or sample here, and
// is not scaled by the c_in. For the final output
// this is correct, but for subsequent iterations, x
// needs to be prescaled again, since k-diffusion's
// model() differes from the bare U-net F_theta by the
// factor c_in.
}
} break;
case TCD: // Strategic Stochastic Sampling (Algorithm 4) in
// Trajectory Consistency Distillation
{
// See J. Zheng et al., "Trajectory Consistency
// Distillation: Improved Latent Consistency Distillation
// by Semi-Linear Consistency Function with Trajectory
// Mapping", arXiv:2402.19159 [cs.CV]
float beta_start = 0.00085f;
float beta_end = 0.0120f;
std::vector<double> alphas_cumprod;
std::vector<double> compvis_sigmas;
alphas_cumprod.reserve(TIMESTEPS);
compvis_sigmas.reserve(TIMESTEPS);
for (int i = 0; i < TIMESTEPS; i++) {
alphas_cumprod[i] =
(i == 0 ? 1.0f : alphas_cumprod[i - 1]) *
(1.0f -
std::pow(sqrtf(beta_start) +
(sqrtf(beta_end) - sqrtf(beta_start)) *
((float)i / (TIMESTEPS - 1)), 2));
compvis_sigmas[i] =
std::sqrt((1 - alphas_cumprod[i]) /
alphas_cumprod[i]);
}
int original_steps = 50;
struct ggml_tensor* pred_original_sample =
ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* noise =
ggml_dup_tensor(work_ctx, x);
for (int i = 0; i < steps; i++) {
// Analytic form for TCD timesteps
int timestep = TIMESTEPS - 1 -
(TIMESTEPS / original_steps) *
(int)floor(i * ((float)original_steps / steps));
// 1. get previous step value
int prev_timestep = i >= steps - 1 ? 0 :
TIMESTEPS - 1 - (TIMESTEPS / original_steps) *
(int)floor((i + 1) *
((float)original_steps / steps));
// Here timestep_s is tau_n' in Algorithm 4. The _s
// notation appears to be that from C. Lu,
// "DPM-Solver: A Fast ODE Solver for Diffusion
// Probabilistic Model Sampling in Around 10 Steps",
// arXiv:2206.00927 [cs.LG], but this notation is not
// continued in Algorithm 4, where _n' is used.
int timestep_s =
(int)floor((1 - eta) * prev_timestep);
// Begin k-diffusion specific workaround for
// evaluating F_theta(x; ...) from D(x, sigma), same
// as in DDIM (and see there for detailed comments)
float sigma = compvis_sigmas[timestep];
if (i == 0) {
float* vec_x = (float*)x->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] *= std::sqrt(sigma * sigma + 1) /
sigma;
}
}
else {
float* vec_x = (float*)x->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] *= std::sqrt(sigma * sigma + 1);
}
}
struct ggml_tensor* model_output =
model(x, sigma, i + 1);
{
float* vec_x = (float*)x->data;
float* vec_model_output =
(float*)model_output->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_model_output[j] =
(vec_x[j] - vec_model_output[j]) *
(1 / sigma);
}
}
// 2. compute alphas, betas
//
// When comparing TCD with DDPM/DDIM note that Zheng
// et al. (2024) follows the DPM-Solver notation for
// alpha. One can find the following comment in the
// original DPM-Solver code
// (https://github.com/LuChengTHU/dpm-solver/):
// "**Important**: Please pay special attention for
// the args for `alphas_cumprod`: The `alphas_cumprod`
// is the \hat{alpha_n} arrays in the notations of
// DDPM. [...] Therefore, the notation \hat{alpha_n}
// is different from the notation alpha_t in
// DPM-Solver. In fact, we have alpha_{t_n} =
// \sqrt{\hat{alpha_n}}, [...]"
float alpha_prod_t = alphas_cumprod[timestep];
float beta_prod_t = 1 - alpha_prod_t;
// Note final_alpha_cumprod = alphas_cumprod[0] since
// TCD is always "trailing"
float alpha_prod_t_prev = prev_timestep >= 0 ?
alphas_cumprod[prev_timestep] : alphas_cumprod[0];
// The subscript _s are the only portion in this
// section (2) unique to TCD
float alpha_prod_s = alphas_cumprod[timestep_s];
float beta_prod_s = 1 - alpha_prod_s;
// 3. Compute the predicted noised sample x_s based on
// the model parameterization
//
// This section is also exactly the same as DDIM
{
float* vec_x = (float*)x->data;
float* vec_model_output =
(float*)model_output->data;
float* vec_pred_original_sample =
(float*)pred_original_sample->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_pred_original_sample[j] =
(vec_x[j] / std::sqrt(sigma * sigma + 1) -
std::sqrt(beta_prod_t) *
vec_model_output[j]) *
(1 / std::sqrt(alpha_prod_t));
}
}
// This consistency function step can be difficult to
// decipher from Algorithm 4, as it is simply stated
// using a consistency function. This step is the
// modified DDIM, i.e. p. 8 (32) in Zheng et
// al. (2024), with eta set to 0 (see the paragraph
// immediately thereafter that states this somewhat
// obliquely).
{
float* vec_pred_original_sample =
(float*)pred_original_sample->data;
float* vec_model_output =
(float*)model_output->data;
float* vec_x = (float*)x->data;
for (int j = 0; j < ggml_nelements(x); j++) {
// Substituting x = pred_noised_sample and
// pred_epsilon = model_output
vec_x[j] =
std::sqrt(alpha_prod_s) *
vec_pred_original_sample[j] +
std::sqrt(beta_prod_s) *
vec_model_output[j];
}
}
// 4. Sample and inject noise z ~ N(0, I) for
// MultiStep Inference Noise is not used on the final
// timestep of the timestep schedule. This also means
// that noise is not used for one-step sampling. Eta
// (referred to as "gamma" in the paper) was
// introduced to control the stochasticity in every
// step. When eta = 0, it represents deterministic
// sampling, whereas eta = 1 indicates full stochastic
// sampling.
if (eta > 0 && i != steps - 1) {
// In this case, x is still pred_noised_sample,
// continue in-place
ggml_tensor_set_f32_randn(noise, rng);
float* vec_x = (float*)x->data;
float* vec_noise = (float*)noise->data;
for (int j = 0; j < ggml_nelements(x); j++) {
// Corresponding to (35) in Zheng et
// al. (2024), substituting x =
// pred_noised_sample
vec_x[j] =
std::sqrt(alpha_prod_t_prev /
alpha_prod_s) *
vec_x[j] +
std::sqrt(1 - alpha_prod_t_prev /
alpha_prod_s) *
vec_noise[j];
}
}
}
} break;
default:
LOG_ERROR("Attempting to sample with nonexisting sample method %i", method);

View File

@ -39,6 +39,8 @@ const char* sample_method_str[] = {
"ipndm",
"ipndm_v",
"lcm",
"ddim_trailing",
"tcd",
};
// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
@ -93,6 +95,7 @@ struct SDParams {
float min_cfg = 1.0f;
float cfg_scale = 7.0f;
float guidance = 3.5f;
float eta = 0.f;
float style_ratio = 20.f;
int clip_skip = -1; // <= 0 represents unspecified
int width = 512;
@ -162,6 +165,7 @@ void print_params(SDParams params) {
printf(" cfg_scale: %.2f\n", params.cfg_scale);
printf(" slg_scale: %.2f\n", params.slg_scale);
printf(" guidance: %.2f\n", params.guidance);
printf(" eta: %.2f\n", params.eta);
printf(" clip_skip: %d\n", params.clip_skip);
printf(" width: %d\n", params.width);
printf(" height: %d\n", params.height);
@ -202,13 +206,16 @@ void print_usage(int argc, const char* argv[]) {
printf(" If not specified, the default is the type of the weight file\n");
printf(" --lora-model-dir [DIR] lora model directory\n");
printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n");
printf(" --mask [MASK] path to the mask image, required by img2img with mask\n");
printf(" --control-image [IMAGE] path to image condition, control net\n");
printf(" -o, --output OUTPUT path to write result image to (default: ./output.png)\n");
printf(" -p, --prompt [PROMPT] the prompt to render\n");
printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n");
printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n");
printf(" --guidance SCALE guidance scale for img2img (default: 3.5)\n");
printf(" --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n");
printf(" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n");
printf(" --eta SCALE eta in DDIM, only for DDIM and TCD: (default: 0)\n");
printf(" --skip-layers LAYERS Layers to skip for SLG steps: (default: [7,8,9])\n");
printf(" --skip-layer-start START SLG enabling point: (default: 0.01)\n");
printf(" --skip-layer-end END SLG disabling point: (default: 0.2)\n");
@ -219,7 +226,7 @@ void print_usage(int argc, const char* argv[]) {
printf(" 1.0 corresponds to full destruction of information in init image\n");
printf(" -H, --height H image height, in pixel space (default: 512)\n");
printf(" -W, --width W image width, in pixel space (default: 512)\n");
printf(" --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm}\n");
printf(" --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd}\n");
printf(" sampling method (default: \"euler_a\")\n");
printf(" --steps STEPS number of sample steps (default: 20)\n");
printf(" --rng {std_default, cuda} RNG (default: cuda)\n");
@ -438,6 +445,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
break;
}
params.guidance = std::stof(argv[i]);
} else if (arg == "--eta") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.eta = std::stof(argv[i]);
} else if (arg == "--strength") {
if (++i >= argc) {
invalid_arg = true;
@ -717,6 +730,7 @@ std::string get_image_params(SDParams params, int64_t seed) {
parameter_string += "Skip layer end: " + std::to_string(params.skip_layer_end) + ", ";
}
parameter_string += "Guidance: " + std::to_string(params.guidance) + ", ";
parameter_string += "Eta: " + std::to_string(params.eta) + ", ";
parameter_string += "Seed: " + std::to_string(seed) + ", ";
parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", ";
parameter_string += "Model: " + sd_basename(params.model_path) + ", ";
@ -937,6 +951,7 @@ int main(int argc, const char* argv[]) {
params.clip_skip,
params.cfg_scale,
params.guidance,
params.eta,
params.width,
params.height,
params.sample_method,
@ -1004,6 +1019,7 @@ int main(int argc, const char* argv[]) {
params.clip_skip,
params.cfg_scale,
params.guidance,
params.eta,
params.width,
params.height,
params.sample_method,

View File

@ -52,6 +52,71 @@
#define __STATIC_INLINE__ static inline
#endif
// n-mode trensor-matrix product
// example: 2-mode product
// A: [ne03, k, ne01, ne00]
// B: k rows, m columns => [k, m]
// result is [ne03, m, ne01, ne00]
__STATIC_INLINE__ struct ggml_tensor* ggml_mul_n_mode(struct ggml_context* ctx, struct ggml_tensor* a, struct ggml_tensor* b, int mode = 0) {
// reshape A
// swap 0th and nth axis
a = ggml_cont(ctx, ggml_permute(ctx, a, mode, mode != 1 ? 1 : 0, mode != 2 ? 2 : 0, mode != 3 ? 3 : 0));
int ne1 = a->ne[1];
int ne2 = a->ne[2];
int ne3 = a->ne[3];
// make 2D
a = ggml_cont(ctx, ggml_reshape_2d(ctx, a, a->ne[0], (ne3 * ne2 * ne1)));
struct ggml_tensor* result = ggml_cont(ctx, ggml_transpose(ctx, ggml_mul_mat(ctx, a, b)));
// reshape output (same shape as a after permutation except first dim)
result = ggml_reshape_4d(ctx, result, result->ne[0], ne1, ne2, ne3);
// swap back 0th and nth axis
result = ggml_permute(ctx, result, mode, mode != 1 ? 1 : 0, mode != 2 ? 2 : 0, mode != 3 ? 3 : 0);
return result;
}
__STATIC_INLINE__ struct ggml_tensor* ggml_merge_lora(ggml_context* ctx, struct ggml_tensor* lora_down, struct ggml_tensor* lora_up, struct ggml_tensor* lora_mid = NULL) {
struct ggml_tensor* updown;
// flat lora tensors to multiply it
int64_t lora_up_rows = lora_up->ne[ggml_n_dims(lora_up) - 1];
lora_up = ggml_reshape_2d(ctx, lora_up, ggml_nelements(lora_up) / lora_up_rows, lora_up_rows);
auto lora_down_n_dims = ggml_n_dims(lora_down);
// assume n_dims should always be a multiple of 2 (otherwise rank 1 doesn't work)
lora_down_n_dims = (lora_down_n_dims + lora_down_n_dims % 2);
int64_t lora_down_rows = lora_down->ne[lora_down_n_dims - 1];
lora_down = ggml_reshape_2d(ctx, lora_down, ggml_nelements(lora_down) / lora_down_rows, lora_down_rows);
// ggml_mul_mat requires tensor b transposed
lora_down = ggml_cont(ctx, ggml_transpose(ctx, lora_down));
if (lora_mid == NULL) {
updown = ggml_mul_mat(ctx, lora_up, lora_down);
updown = ggml_cont(ctx, ggml_transpose(ctx, updown));
} else {
// undoing tucker decomposition for conv layers.
// lora_mid has shape (3, 3, Rank, Rank)
// lora_down has shape (Rank, In, 1, 1)
// lora_up has shape (Rank, Out, 1, 1)
// conv layer shape is (3, 3, Out, In)
updown = ggml_mul_n_mode(ctx, ggml_mul_n_mode(ctx, lora_mid, lora_down, 3), lora_up, 2);
updown = ggml_cont(ctx, updown);
}
return updown;
}
// Kronecker product
// [ne03,ne02,ne01,ne00] x [ne13,ne12,ne11,ne10] => [ne03*ne13,ne02*ne12,ne01*ne11,ne00*ne10]
__STATIC_INLINE__ struct ggml_tensor* ggml_kronecker(ggml_context* ctx, struct ggml_tensor* a, struct ggml_tensor* b) {
return ggml_mul(ctx,
ggml_upscale_ext(ctx,
a,
a->ne[0] * b->ne[0],
a->ne[1] * b->ne[1],
a->ne[2] * b->ne[2],
a->ne[3] * b->ne[3]),
b);
}
__STATIC_INLINE__ void ggml_log_callback_default(ggml_log_level level, const char* text, void* user_data) {
(void)level;
(void)user_data;
@ -318,8 +383,10 @@ __STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
for (int ix = 0; ix < width; ix++) {
for (int iy = 0; iy < height; iy++) {
float m = ggml_tensor_get_f32(mask, ix, iy);
m = round(m); // inpaint models need binary masks
ggml_tensor_set_f32(mask, m, ix, iy);
for (int k = 0; k < channels; k++) {
float value = ((float)(m < 254.5/255)) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5;
float value = (1 - m) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5;
ggml_tensor_set_f32(output, value, ix, iy, k);
}
}
@ -987,8 +1054,8 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {
}
/* SDXL with LoRA requires more space */
#define MAX_PARAMS_TENSOR_NUM 15360
#define MAX_GRAPH_SIZE 15360
#define MAX_PARAMS_TENSOR_NUM 32768
#define MAX_GRAPH_SIZE 32768
struct GGMLRunner {
protected:

907
lora.hpp
View File

@ -197,6 +197,10 @@ struct LoraModel : public GGMLRunner {
blk_name.replace(blk_name.find(".joint_blocks"), sizeof(".joint_blocks") - 1, ".transformer_blocks");
}
if (blk_name.find("text_encoders.clip_l") != std::string::npos) {
blk_name.replace(blk_name.find("text_encoders.clip_l"), sizeof("text_encoders.clip_l") - 1, "cond_stage_model");
}
for (const auto& item : alt_names) {
size_t match = blk_name.find(item.first);
if (match != std::string::npos) {
@ -217,13 +221,17 @@ struct LoraModel : public GGMLRunner {
keys.push_back(split_blk);
}
}
keys.push_back(blk_name);
}
keys.push_back(blk_name);
std::vector<std::string> ret;
for (std::string& key : keys) {
ret.push_back(key);
replace_all_chars(key, '.', '_');
// fix for some sdxl lora, like lcm-lora-xl
if (key == "model_diffusion_model_output_blocks_2_2_conv") {
ret.push_back("model_diffusion_model_output_blocks_2_1_conv");
}
ret.push_back(key);
}
return ret;
@ -244,390 +252,545 @@ struct LoraModel : public GGMLRunner {
std::vector<std::string> keys = to_lora_keys(k_tensor, version);
if (keys.size() == 0)
continue;
ggml_tensor* lora_up = NULL;
ggml_tensor* lora_down = NULL;
for (auto& key : keys) {
std::string alpha_name = "";
std::string scale_name = "";
std::string split_q_scale_name = "";
std::string lora_down_name = "";
std::string lora_up_name = "";
if (starts_with(key, "SPLIT|")) {
bool is_qkv_split = starts_with(key, "SPLIT|");
if (is_qkv_split) {
key = key.substr(sizeof("SPLIT|") - 1);
// TODO: Handle alphas
std::string suffix = "";
auto split_q_d_name = lora_pre[type] + key + "q" + suffix + lora_downs[type] + ".weight";
if (lora_tensors.find(split_q_d_name) == lora_tensors.end()) {
suffix = "_proj";
split_q_d_name = lora_pre[type] + key + "q" + suffix + lora_downs[type] + ".weight";
}
if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) {
// print_ggml_tensor(it.second, true); //[3072, 21504, 1, 1]
// find qkv and mlp up parts in LoRA model
auto split_k_d_name = lora_pre[type] + key + "k" + suffix + lora_downs[type] + ".weight";
auto split_v_d_name = lora_pre[type] + key + "v" + suffix + lora_downs[type] + ".weight";
auto split_q_u_name = lora_pre[type] + key + "q" + suffix + lora_ups[type] + ".weight";
auto split_k_u_name = lora_pre[type] + key + "k" + suffix + lora_ups[type] + ".weight";
auto split_v_u_name = lora_pre[type] + key + "v" + suffix + lora_ups[type] + ".weight";
auto split_q_scale_name = lora_pre[type] + key + "q" + suffix + ".scale";
auto split_k_scale_name = lora_pre[type] + key + "k" + suffix + ".scale";
auto split_v_scale_name = lora_pre[type] + key + "v" + suffix + ".scale";
auto split_q_alpha_name = lora_pre[type] + key + "q" + suffix + ".alpha";
auto split_k_alpha_name = lora_pre[type] + key + "k" + suffix + ".alpha";
auto split_v_alpha_name = lora_pre[type] + key + "v" + suffix + ".alpha";
ggml_tensor* lora_q_down = NULL;
ggml_tensor* lora_q_up = NULL;
ggml_tensor* lora_k_down = NULL;
ggml_tensor* lora_k_up = NULL;
ggml_tensor* lora_v_down = NULL;
ggml_tensor* lora_v_up = NULL;
lora_q_down = to_f32(compute_ctx, lora_tensors[split_q_d_name]);
if (lora_tensors.find(split_q_u_name) != lora_tensors.end()) {
lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]);
}
if (lora_tensors.find(split_k_d_name) != lora_tensors.end()) {
lora_k_down = to_f32(compute_ctx, lora_tensors[split_k_d_name]);
}
if (lora_tensors.find(split_k_u_name) != lora_tensors.end()) {
lora_k_up = to_f32(compute_ctx, lora_tensors[split_k_u_name]);
}
if (lora_tensors.find(split_v_d_name) != lora_tensors.end()) {
lora_v_down = to_f32(compute_ctx, lora_tensors[split_v_d_name]);
}
if (lora_tensors.find(split_v_u_name) != lora_tensors.end()) {
lora_v_up = to_f32(compute_ctx, lora_tensors[split_v_u_name]);
}
float q_rank = lora_q_up->ne[0];
float k_rank = lora_k_up->ne[0];
float v_rank = lora_v_up->ne[0];
float lora_q_scale = 1;
float lora_k_scale = 1;
float lora_v_scale = 1;
if (lora_tensors.find(split_q_scale_name) != lora_tensors.end()) {
lora_q_scale = ggml_backend_tensor_get_f32(lora_tensors[split_q_scale_name]);
applied_lora_tensors.insert(split_q_scale_name);
}
if (lora_tensors.find(split_k_scale_name) != lora_tensors.end()) {
lora_k_scale = ggml_backend_tensor_get_f32(lora_tensors[split_k_scale_name]);
applied_lora_tensors.insert(split_k_scale_name);
}
if (lora_tensors.find(split_v_scale_name) != lora_tensors.end()) {
lora_v_scale = ggml_backend_tensor_get_f32(lora_tensors[split_v_scale_name]);
applied_lora_tensors.insert(split_v_scale_name);
}
if (lora_tensors.find(split_q_alpha_name) != lora_tensors.end()) {
float lora_q_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_q_alpha_name]);
applied_lora_tensors.insert(split_q_alpha_name);
lora_q_scale = lora_q_alpha / q_rank;
}
if (lora_tensors.find(split_k_alpha_name) != lora_tensors.end()) {
float lora_k_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_k_alpha_name]);
applied_lora_tensors.insert(split_k_alpha_name);
lora_k_scale = lora_k_alpha / k_rank;
}
if (lora_tensors.find(split_v_alpha_name) != lora_tensors.end()) {
float lora_v_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_v_alpha_name]);
applied_lora_tensors.insert(split_v_alpha_name);
lora_v_scale = lora_v_alpha / v_rank;
}
ggml_scale_inplace(compute_ctx, lora_q_down, lora_q_scale);
ggml_scale_inplace(compute_ctx, lora_k_down, lora_k_scale);
ggml_scale_inplace(compute_ctx, lora_v_down, lora_v_scale);
// print_ggml_tensor(lora_q_down, true); //[3072, R, 1, 1]
// print_ggml_tensor(lora_k_down, true); //[3072, R, 1, 1]
// print_ggml_tensor(lora_v_down, true); //[3072, R, 1, 1]
// print_ggml_tensor(lora_q_up, true); //[R, 3072, 1, 1]
// print_ggml_tensor(lora_k_up, true); //[R, 3072, 1, 1]
// print_ggml_tensor(lora_v_up, true); //[R, 3072, 1, 1]
// these need to be stitched together this way:
// |q_up,0 ,0 |
// |0 ,k_up,0 |
// |0 ,0 ,v_up|
// (q_down,k_down,v_down) . (q ,k ,v)
// up_concat will be [9216, R*3, 1, 1]
// down_concat will be [R*3, 3072, 1, 1]
ggml_tensor* lora_down_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_down, lora_k_down, 1), lora_v_down, 1);
ggml_tensor* z = ggml_dup_tensor(compute_ctx, lora_q_up);
ggml_scale(compute_ctx, z, 0);
ggml_tensor* zz = ggml_concat(compute_ctx, z, z, 1);
ggml_tensor* q_up = ggml_concat(compute_ctx, lora_q_up, zz, 1);
ggml_tensor* k_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, z, lora_k_up, 1), z, 1);
ggml_tensor* v_up = ggml_concat(compute_ctx, zz, lora_v_up, 1);
// print_ggml_tensor(q_up, true); //[R, 9216, 1, 1]
// print_ggml_tensor(k_up, true); //[R, 9216, 1, 1]
// print_ggml_tensor(v_up, true); //[R, 9216, 1, 1]
ggml_tensor* lora_up_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, q_up, k_up, 0), v_up, 0);
// print_ggml_tensor(lora_up_concat, true); //[R*3, 9216, 1, 1]
lora_down = ggml_cont(compute_ctx, lora_down_concat);
lora_up = ggml_cont(compute_ctx, lora_up_concat);
applied_lora_tensors.insert(split_q_u_name);
applied_lora_tensors.insert(split_k_u_name);
applied_lora_tensors.insert(split_v_u_name);
applied_lora_tensors.insert(split_q_d_name);
applied_lora_tensors.insert(split_k_d_name);
applied_lora_tensors.insert(split_v_d_name);
}
}
if (starts_with(key, "SPLIT_L|")) {
bool is_qkvm_split = starts_with(key, "SPLIT_L|");
if (is_qkvm_split) {
key = key.substr(sizeof("SPLIT_L|") - 1);
auto split_q_d_name = lora_pre[type] + key + "attn.to_q" + lora_downs[type] + ".weight";
if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) {
// print_ggml_tensor(it.second, true); //[3072, 21504, 1, 1]
// find qkv and mlp up parts in LoRA model
auto split_k_d_name = lora_pre[type] + key + "attn.to_k" + lora_downs[type] + ".weight";
auto split_v_d_name = lora_pre[type] + key + "attn.to_v" + lora_downs[type] + ".weight";
auto split_q_u_name = lora_pre[type] + key + "attn.to_q" + lora_ups[type] + ".weight";
auto split_k_u_name = lora_pre[type] + key + "attn.to_k" + lora_ups[type] + ".weight";
auto split_v_u_name = lora_pre[type] + key + "attn.to_v" + lora_ups[type] + ".weight";
auto split_m_d_name = lora_pre[type] + key + "proj_mlp" + lora_downs[type] + ".weight";
auto split_m_u_name = lora_pre[type] + key + "proj_mlp" + lora_ups[type] + ".weight";
auto split_q_scale_name = lora_pre[type] + key + "attn.to_q" + ".scale";
auto split_k_scale_name = lora_pre[type] + key + "attn.to_k" + ".scale";
auto split_v_scale_name = lora_pre[type] + key + "attn.to_v" + ".scale";
auto split_m_scale_name = lora_pre[type] + key + "proj_mlp" + ".scale";
auto split_q_alpha_name = lora_pre[type] + key + "attn.to_q" + ".alpha";
auto split_k_alpha_name = lora_pre[type] + key + "attn.to_k" + ".alpha";
auto split_v_alpha_name = lora_pre[type] + key + "attn.to_v" + ".alpha";
auto split_m_alpha_name = lora_pre[type] + key + "proj_mlp" + ".alpha";
ggml_tensor* lora_q_down = NULL;
ggml_tensor* lora_q_up = NULL;
ggml_tensor* lora_k_down = NULL;
ggml_tensor* lora_k_up = NULL;
ggml_tensor* lora_v_down = NULL;
ggml_tensor* lora_v_up = NULL;
ggml_tensor* lora_m_down = NULL;
ggml_tensor* lora_m_up = NULL;
lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]);
if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) {
lora_q_down = to_f32(compute_ctx, lora_tensors[split_q_d_name]);
}
if (lora_tensors.find(split_q_u_name) != lora_tensors.end()) {
lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]);
}
if (lora_tensors.find(split_k_d_name) != lora_tensors.end()) {
lora_k_down = to_f32(compute_ctx, lora_tensors[split_k_d_name]);
}
if (lora_tensors.find(split_k_u_name) != lora_tensors.end()) {
lora_k_up = to_f32(compute_ctx, lora_tensors[split_k_u_name]);
}
if (lora_tensors.find(split_v_d_name) != lora_tensors.end()) {
lora_v_down = to_f32(compute_ctx, lora_tensors[split_v_d_name]);
}
if (lora_tensors.find(split_v_u_name) != lora_tensors.end()) {
lora_v_up = to_f32(compute_ctx, lora_tensors[split_v_u_name]);
}
if (lora_tensors.find(split_m_d_name) != lora_tensors.end()) {
lora_m_down = to_f32(compute_ctx, lora_tensors[split_m_d_name]);
}
if (lora_tensors.find(split_m_u_name) != lora_tensors.end()) {
lora_m_up = to_f32(compute_ctx, lora_tensors[split_m_u_name]);
}
float q_rank = lora_q_up->ne[0];
float k_rank = lora_k_up->ne[0];
float v_rank = lora_v_up->ne[0];
float m_rank = lora_v_up->ne[0];
float lora_q_scale = 1;
float lora_k_scale = 1;
float lora_v_scale = 1;
float lora_m_scale = 1;
if (lora_tensors.find(split_q_scale_name) != lora_tensors.end()) {
lora_q_scale = ggml_backend_tensor_get_f32(lora_tensors[split_q_scale_name]);
applied_lora_tensors.insert(split_q_scale_name);
}
if (lora_tensors.find(split_k_scale_name) != lora_tensors.end()) {
lora_k_scale = ggml_backend_tensor_get_f32(lora_tensors[split_k_scale_name]);
applied_lora_tensors.insert(split_k_scale_name);
}
if (lora_tensors.find(split_v_scale_name) != lora_tensors.end()) {
lora_v_scale = ggml_backend_tensor_get_f32(lora_tensors[split_v_scale_name]);
applied_lora_tensors.insert(split_v_scale_name);
}
if (lora_tensors.find(split_m_scale_name) != lora_tensors.end()) {
lora_m_scale = ggml_backend_tensor_get_f32(lora_tensors[split_m_scale_name]);
applied_lora_tensors.insert(split_m_scale_name);
}
if (lora_tensors.find(split_q_alpha_name) != lora_tensors.end()) {
float lora_q_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_q_alpha_name]);
applied_lora_tensors.insert(split_q_alpha_name);
lora_q_scale = lora_q_alpha / q_rank;
}
if (lora_tensors.find(split_k_alpha_name) != lora_tensors.end()) {
float lora_k_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_k_alpha_name]);
applied_lora_tensors.insert(split_k_alpha_name);
lora_k_scale = lora_k_alpha / k_rank;
}
if (lora_tensors.find(split_v_alpha_name) != lora_tensors.end()) {
float lora_v_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_v_alpha_name]);
applied_lora_tensors.insert(split_v_alpha_name);
lora_v_scale = lora_v_alpha / v_rank;
}
if (lora_tensors.find(split_m_alpha_name) != lora_tensors.end()) {
float lora_m_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_m_alpha_name]);
applied_lora_tensors.insert(split_m_alpha_name);
lora_m_scale = lora_m_alpha / m_rank;
}
ggml_scale_inplace(compute_ctx, lora_q_down, lora_q_scale);
ggml_scale_inplace(compute_ctx, lora_k_down, lora_k_scale);
ggml_scale_inplace(compute_ctx, lora_v_down, lora_v_scale);
ggml_scale_inplace(compute_ctx, lora_m_down, lora_m_scale);
// print_ggml_tensor(lora_q_down, true); //[3072, R, 1, 1]
// print_ggml_tensor(lora_k_down, true); //[3072, R, 1, 1]
// print_ggml_tensor(lora_v_down, true); //[3072, R, 1, 1]
// print_ggml_tensor(lora_m_down, true); //[3072, R, 1, 1]
// print_ggml_tensor(lora_q_up, true); //[R, 3072, 1, 1]
// print_ggml_tensor(lora_k_up, true); //[R, 3072, 1, 1]
// print_ggml_tensor(lora_v_up, true); //[R, 3072, 1, 1]
// print_ggml_tensor(lora_m_up, true); //[R, 12288, 1, 1]
// these need to be stitched together this way:
// |q_up,0 ,0 ,0 |
// |0 ,k_up,0 ,0 |
// |0 ,0 ,v_up,0 |
// |0 ,0 ,0 ,m_up|
// (q_down,k_down,v_down,m_down) . (q ,k ,v ,m)
// up_concat will be [21504, R*4, 1, 1]
// down_concat will be [R*4, 3072, 1, 1]
ggml_tensor* lora_down_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_down, lora_k_down, 1), ggml_concat(compute_ctx, lora_v_down, lora_m_down, 1), 1);
// print_ggml_tensor(lora_down_concat, true); //[3072, R*4, 1, 1]
// this also means that if rank is bigger than 672, it is less memory efficient to do it this way (should be fine)
// print_ggml_tensor(lora_q_up, true); //[3072, R, 1, 1]
ggml_tensor* z = ggml_dup_tensor(compute_ctx, lora_q_up);
ggml_tensor* mlp_z = ggml_dup_tensor(compute_ctx, lora_m_up);
ggml_scale(compute_ctx, z, 0);
ggml_scale(compute_ctx, mlp_z, 0);
ggml_tensor* zz = ggml_concat(compute_ctx, z, z, 1);
ggml_tensor* q_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_up, zz, 1), mlp_z, 1);
ggml_tensor* k_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, z, lora_k_up, 1), ggml_concat(compute_ctx, z, mlp_z, 1), 1);
ggml_tensor* v_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, zz, lora_v_up, 1), mlp_z, 1);
ggml_tensor* m_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, zz, z, 1), lora_m_up, 1);
// print_ggml_tensor(q_up, true); //[R, 21504, 1, 1]
// print_ggml_tensor(k_up, true); //[R, 21504, 1, 1]
// print_ggml_tensor(v_up, true); //[R, 21504, 1, 1]
// print_ggml_tensor(m_up, true); //[R, 21504, 1, 1]
ggml_tensor* lora_up_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, q_up, k_up, 0), ggml_concat(compute_ctx, v_up, m_up, 0), 0);
// print_ggml_tensor(lora_up_concat, true); //[R*4, 21504, 1, 1]
lora_down = ggml_cont(compute_ctx, lora_down_concat);
lora_up = ggml_cont(compute_ctx, lora_up_concat);
applied_lora_tensors.insert(split_q_u_name);
applied_lora_tensors.insert(split_k_u_name);
applied_lora_tensors.insert(split_v_u_name);
applied_lora_tensors.insert(split_m_u_name);
applied_lora_tensors.insert(split_q_d_name);
applied_lora_tensors.insert(split_k_d_name);
applied_lora_tensors.insert(split_v_d_name);
applied_lora_tensors.insert(split_m_d_name);
}
}
if (lora_up == NULL || lora_down == NULL) {
lora_up_name = lora_pre[type] + key + lora_ups[type] + ".weight";
if (lora_tensors.find(lora_up_name) == lora_tensors.end()) {
if (key == "model_diffusion_model_output_blocks_2_2_conv") {
// fix for some sdxl lora, like lcm-lora-xl
key = "model_diffusion_model_output_blocks_2_1_conv";
lora_up_name = lora_pre[type] + key + lora_ups[type] + ".weight";
}
struct ggml_tensor* updown = NULL;
float scale_value = 1.0f;
std::string fk = lora_pre[type] + key;
if (lora_tensors.find(fk + ".hada_w1_a") != lora_tensors.end()) {
// LoHa mode
// TODO: split qkv convention for LoHas (is it ever used?)
if (is_qkv_split || is_qkvm_split) {
LOG_ERROR("Split qkv isn't supported for LoHa models.");
break;
}
std::string alpha_name = "";
ggml_tensor* hada_1_mid = NULL; // tau for tucker decomposition
ggml_tensor* hada_1_up = NULL;
ggml_tensor* hada_1_down = NULL;
ggml_tensor* hada_2_mid = NULL; // tau for tucker decomposition
ggml_tensor* hada_2_up = NULL;
ggml_tensor* hada_2_down = NULL;
std::string hada_1_mid_name = "";
std::string hada_1_down_name = "";
std::string hada_1_up_name = "";
std::string hada_2_mid_name = "";
std::string hada_2_down_name = "";
std::string hada_2_up_name = "";
hada_1_down_name = fk + ".hada_w1_b";
hada_1_up_name = fk + ".hada_w1_a";
hada_1_mid_name = fk + ".hada_t1";
if (lora_tensors.find(hada_1_down_name) != lora_tensors.end()) {
hada_1_down = to_f32(compute_ctx, lora_tensors[hada_1_down_name]);
}
if (lora_tensors.find(hada_1_up_name) != lora_tensors.end()) {
hada_1_up = to_f32(compute_ctx, lora_tensors[hada_1_up_name]);
}
if (lora_tensors.find(hada_1_mid_name) != lora_tensors.end()) {
hada_1_mid = to_f32(compute_ctx, lora_tensors[hada_1_mid_name]);
applied_lora_tensors.insert(hada_1_mid_name);
hada_1_up = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, hada_1_up));
}
lora_down_name = lora_pre[type] + key + lora_downs[type] + ".weight";
alpha_name = lora_pre[type] + key + ".alpha";
scale_name = lora_pre[type] + key + ".scale";
if (lora_tensors.find(lora_up_name) != lora_tensors.end()) {
lora_up = lora_tensors[lora_up_name];
hada_2_down_name = fk + ".hada_w2_b";
hada_2_up_name = fk + ".hada_w2_a";
hada_2_mid_name = fk + ".hada_t2";
if (lora_tensors.find(hada_2_down_name) != lora_tensors.end()) {
hada_2_down = to_f32(compute_ctx, lora_tensors[hada_2_down_name]);
}
if (lora_tensors.find(hada_2_up_name) != lora_tensors.end()) {
hada_2_up = to_f32(compute_ctx, lora_tensors[hada_2_up_name]);
}
if (lora_tensors.find(hada_2_mid_name) != lora_tensors.end()) {
hada_2_mid = to_f32(compute_ctx, lora_tensors[hada_2_mid_name]);
applied_lora_tensors.insert(hada_2_mid_name);
hada_2_up = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, hada_2_up));
}
if (lora_tensors.find(lora_down_name) != lora_tensors.end()) {
lora_down = lora_tensors[lora_down_name];
}
applied_lora_tensors.insert(lora_up_name);
applied_lora_tensors.insert(lora_down_name);
alpha_name = fk + ".alpha";
applied_lora_tensors.insert(hada_1_down_name);
applied_lora_tensors.insert(hada_1_up_name);
applied_lora_tensors.insert(hada_2_down_name);
applied_lora_tensors.insert(hada_2_up_name);
applied_lora_tensors.insert(alpha_name);
applied_lora_tensors.insert(scale_name);
}
if (hada_1_up == NULL || hada_1_down == NULL || hada_2_up == NULL || hada_2_down == NULL) {
continue;
}
if (lora_up == NULL || lora_down == NULL) {
continue;
}
// calc_scale
int64_t dim = lora_down->ne[ggml_n_dims(lora_down) - 1];
float scale_value = 1.0f;
if (lora_tensors.find(scale_name) != lora_tensors.end()) {
scale_value = ggml_backend_tensor_get_f32(lora_tensors[scale_name]);
} else if (lora_tensors.find(alpha_name) != lora_tensors.end()) {
float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]);
scale_value = alpha / dim;
struct ggml_tensor* updown_1 = ggml_merge_lora(compute_ctx, hada_1_down, hada_1_up, hada_1_mid);
struct ggml_tensor* updown_2 = ggml_merge_lora(compute_ctx, hada_2_down, hada_2_up, hada_2_mid);
updown = ggml_mul_inplace(compute_ctx, updown_1, updown_2);
// calc_scale
// TODO: .dora_scale?
int64_t rank = hada_1_down->ne[ggml_n_dims(hada_1_down) - 1];
if (lora_tensors.find(alpha_name) != lora_tensors.end()) {
float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]);
scale_value = alpha / rank;
}
} else if (lora_tensors.find(fk + ".lokr_w1") != lora_tensors.end() || lora_tensors.find(fk + ".lokr_w1_a") != lora_tensors.end()) {
// LoKr mode
// TODO: split qkv convention for LoKrs (is it ever used?)
if (is_qkv_split || is_qkvm_split) {
LOG_ERROR("Split qkv isn't supported for LoKr models.");
break;
}
std::string alpha_name = fk + ".alpha";
ggml_tensor* lokr_w1 = NULL;
ggml_tensor* lokr_w2 = NULL;
std::string lokr_w1_name = "";
std::string lokr_w2_name = "";
lokr_w1_name = fk + ".lokr_w1";
lokr_w2_name = fk + ".lokr_w2";
if (lora_tensors.find(lokr_w1_name) != lora_tensors.end()) {
lokr_w1 = to_f32(compute_ctx, lora_tensors[lokr_w1_name]);
applied_lora_tensors.insert(lokr_w1_name);
} else {
ggml_tensor* down = NULL;
ggml_tensor* up = NULL;
std::string down_name = lokr_w1_name + "_b";
std::string up_name = lokr_w1_name + "_a";
if (lora_tensors.find(down_name) != lora_tensors.end()) {
// w1 should not be low rank normally, sometimes w1 and w2 are swapped
down = to_f32(compute_ctx, lora_tensors[down_name]);
applied_lora_tensors.insert(down_name);
int64_t rank = down->ne[ggml_n_dims(down) - 1];
if (lora_tensors.find(alpha_name) != lora_tensors.end()) {
float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]);
scale_value = alpha / rank;
}
}
if (lora_tensors.find(up_name) != lora_tensors.end()) {
up = to_f32(compute_ctx, lora_tensors[up_name]);
applied_lora_tensors.insert(up_name);
}
lokr_w1 = ggml_merge_lora(compute_ctx, down, up);
}
if (lora_tensors.find(lokr_w2_name) != lora_tensors.end()) {
lokr_w2 = to_f32(compute_ctx, lora_tensors[lokr_w2_name]);
applied_lora_tensors.insert(lokr_w2_name);
} else {
ggml_tensor* down = NULL;
ggml_tensor* up = NULL;
std::string down_name = lokr_w2_name + "_b";
std::string up_name = lokr_w2_name + "_a";
if (lora_tensors.find(down_name) != lora_tensors.end()) {
down = to_f32(compute_ctx, lora_tensors[down_name]);
applied_lora_tensors.insert(down_name);
int64_t rank = down->ne[ggml_n_dims(down) - 1];
if (lora_tensors.find(alpha_name) != lora_tensors.end()) {
float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]);
scale_value = alpha / rank;
}
}
if (lora_tensors.find(up_name) != lora_tensors.end()) {
up = to_f32(compute_ctx, lora_tensors[up_name]);
applied_lora_tensors.insert(up_name);
}
lokr_w2 = ggml_merge_lora(compute_ctx, down, up);
}
// Technically it might be unused, but I believe it's the expected behavior
applied_lora_tensors.insert(alpha_name);
updown = ggml_kronecker(compute_ctx, lokr_w1, lokr_w2);
} else {
// LoRA mode
ggml_tensor* lora_mid = NULL; // tau for tucker decomposition
ggml_tensor* lora_up = NULL;
ggml_tensor* lora_down = NULL;
std::string alpha_name = "";
std::string scale_name = "";
std::string split_q_scale_name = "";
std::string lora_mid_name = "";
std::string lora_down_name = "";
std::string lora_up_name = "";
if (is_qkv_split) {
std::string suffix = "";
auto split_q_d_name = fk + "q" + suffix + lora_downs[type] + ".weight";
if (lora_tensors.find(split_q_d_name) == lora_tensors.end()) {
suffix = "_proj";
split_q_d_name = fk + "q" + suffix + lora_downs[type] + ".weight";
}
if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) {
// print_ggml_tensor(it.second, true); //[3072, 21504, 1, 1]
// find qkv and mlp up parts in LoRA model
auto split_k_d_name = fk + "k" + suffix + lora_downs[type] + ".weight";
auto split_v_d_name = fk + "v" + suffix + lora_downs[type] + ".weight";
auto split_q_u_name = fk + "q" + suffix + lora_ups[type] + ".weight";
auto split_k_u_name = fk + "k" + suffix + lora_ups[type] + ".weight";
auto split_v_u_name = fk + "v" + suffix + lora_ups[type] + ".weight";
auto split_q_scale_name = fk + "q" + suffix + ".scale";
auto split_k_scale_name = fk + "k" + suffix + ".scale";
auto split_v_scale_name = fk + "v" + suffix + ".scale";
auto split_q_alpha_name = fk + "q" + suffix + ".alpha";
auto split_k_alpha_name = fk + "k" + suffix + ".alpha";
auto split_v_alpha_name = fk + "v" + suffix + ".alpha";
ggml_tensor* lora_q_down = NULL;
ggml_tensor* lora_q_up = NULL;
ggml_tensor* lora_k_down = NULL;
ggml_tensor* lora_k_up = NULL;
ggml_tensor* lora_v_down = NULL;
ggml_tensor* lora_v_up = NULL;
lora_q_down = to_f32(compute_ctx, lora_tensors[split_q_d_name]);
if (lora_tensors.find(split_q_u_name) != lora_tensors.end()) {
lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]);
}
if (lora_tensors.find(split_k_d_name) != lora_tensors.end()) {
lora_k_down = to_f32(compute_ctx, lora_tensors[split_k_d_name]);
}
if (lora_tensors.find(split_k_u_name) != lora_tensors.end()) {
lora_k_up = to_f32(compute_ctx, lora_tensors[split_k_u_name]);
}
if (lora_tensors.find(split_v_d_name) != lora_tensors.end()) {
lora_v_down = to_f32(compute_ctx, lora_tensors[split_v_d_name]);
}
if (lora_tensors.find(split_v_u_name) != lora_tensors.end()) {
lora_v_up = to_f32(compute_ctx, lora_tensors[split_v_u_name]);
}
float q_rank = lora_q_up->ne[0];
float k_rank = lora_k_up->ne[0];
float v_rank = lora_v_up->ne[0];
float lora_q_scale = 1;
float lora_k_scale = 1;
float lora_v_scale = 1;
if (lora_tensors.find(split_q_scale_name) != lora_tensors.end()) {
lora_q_scale = ggml_backend_tensor_get_f32(lora_tensors[split_q_scale_name]);
applied_lora_tensors.insert(split_q_scale_name);
}
if (lora_tensors.find(split_k_scale_name) != lora_tensors.end()) {
lora_k_scale = ggml_backend_tensor_get_f32(lora_tensors[split_k_scale_name]);
applied_lora_tensors.insert(split_k_scale_name);
}
if (lora_tensors.find(split_v_scale_name) != lora_tensors.end()) {
lora_v_scale = ggml_backend_tensor_get_f32(lora_tensors[split_v_scale_name]);
applied_lora_tensors.insert(split_v_scale_name);
}
if (lora_tensors.find(split_q_alpha_name) != lora_tensors.end()) {
float lora_q_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_q_alpha_name]);
applied_lora_tensors.insert(split_q_alpha_name);
lora_q_scale = lora_q_alpha / q_rank;
}
if (lora_tensors.find(split_k_alpha_name) != lora_tensors.end()) {
float lora_k_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_k_alpha_name]);
applied_lora_tensors.insert(split_k_alpha_name);
lora_k_scale = lora_k_alpha / k_rank;
}
if (lora_tensors.find(split_v_alpha_name) != lora_tensors.end()) {
float lora_v_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_v_alpha_name]);
applied_lora_tensors.insert(split_v_alpha_name);
lora_v_scale = lora_v_alpha / v_rank;
}
ggml_scale_inplace(compute_ctx, lora_q_down, lora_q_scale);
ggml_scale_inplace(compute_ctx, lora_k_down, lora_k_scale);
ggml_scale_inplace(compute_ctx, lora_v_down, lora_v_scale);
// print_ggml_tensor(lora_q_down, true); //[3072, R, 1, 1]
// print_ggml_tensor(lora_k_down, true); //[3072, R, 1, 1]
// print_ggml_tensor(lora_v_down, true); //[3072, R, 1, 1]
// print_ggml_tensor(lora_q_up, true); //[R, 3072, 1, 1]
// print_ggml_tensor(lora_k_up, true); //[R, 3072, 1, 1]
// print_ggml_tensor(lora_v_up, true); //[R, 3072, 1, 1]
// these need to be stitched together this way:
// |q_up,0 ,0 |
// |0 ,k_up,0 |
// |0 ,0 ,v_up|
// (q_down,k_down,v_down) . (q ,k ,v)
// up_concat will be [9216, R*3, 1, 1]
// down_concat will be [R*3, 3072, 1, 1]
ggml_tensor* lora_down_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_down, lora_k_down, 1), lora_v_down, 1);
ggml_tensor* z = ggml_dup_tensor(compute_ctx, lora_q_up);
ggml_scale(compute_ctx, z, 0);
ggml_tensor* zz = ggml_concat(compute_ctx, z, z, 1);
ggml_tensor* q_up = ggml_concat(compute_ctx, lora_q_up, zz, 1);
ggml_tensor* k_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, z, lora_k_up, 1), z, 1);
ggml_tensor* v_up = ggml_concat(compute_ctx, zz, lora_v_up, 1);
// print_ggml_tensor(q_up, true); //[R, 9216, 1, 1]
// print_ggml_tensor(k_up, true); //[R, 9216, 1, 1]
// print_ggml_tensor(v_up, true); //[R, 9216, 1, 1]
ggml_tensor* lora_up_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, q_up, k_up, 0), v_up, 0);
// print_ggml_tensor(lora_up_concat, true); //[R*3, 9216, 1, 1]
lora_down = ggml_cont(compute_ctx, lora_down_concat);
lora_up = ggml_cont(compute_ctx, lora_up_concat);
applied_lora_tensors.insert(split_q_u_name);
applied_lora_tensors.insert(split_k_u_name);
applied_lora_tensors.insert(split_v_u_name);
applied_lora_tensors.insert(split_q_d_name);
applied_lora_tensors.insert(split_k_d_name);
applied_lora_tensors.insert(split_v_d_name);
}
} else if (is_qkvm_split) {
auto split_q_d_name = fk + "attn.to_q" + lora_downs[type] + ".weight";
if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) {
// print_ggml_tensor(it.second, true); //[3072, 21504, 1, 1]
// find qkv and mlp up parts in LoRA model
auto split_k_d_name = fk + "attn.to_k" + lora_downs[type] + ".weight";
auto split_v_d_name = fk + "attn.to_v" + lora_downs[type] + ".weight";
auto split_q_u_name = fk + "attn.to_q" + lora_ups[type] + ".weight";
auto split_k_u_name = fk + "attn.to_k" + lora_ups[type] + ".weight";
auto split_v_u_name = fk + "attn.to_v" + lora_ups[type] + ".weight";
auto split_m_d_name = fk + "proj_mlp" + lora_downs[type] + ".weight";
auto split_m_u_name = fk + "proj_mlp" + lora_ups[type] + ".weight";
auto split_q_scale_name = fk + "attn.to_q" + ".scale";
auto split_k_scale_name = fk + "attn.to_k" + ".scale";
auto split_v_scale_name = fk + "attn.to_v" + ".scale";
auto split_m_scale_name = fk + "proj_mlp" + ".scale";
auto split_q_alpha_name = fk + "attn.to_q" + ".alpha";
auto split_k_alpha_name = fk + "attn.to_k" + ".alpha";
auto split_v_alpha_name = fk + "attn.to_v" + ".alpha";
auto split_m_alpha_name = fk + "proj_mlp" + ".alpha";
ggml_tensor* lora_q_down = NULL;
ggml_tensor* lora_q_up = NULL;
ggml_tensor* lora_k_down = NULL;
ggml_tensor* lora_k_up = NULL;
ggml_tensor* lora_v_down = NULL;
ggml_tensor* lora_v_up = NULL;
ggml_tensor* lora_m_down = NULL;
ggml_tensor* lora_m_up = NULL;
lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]);
if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) {
lora_q_down = to_f32(compute_ctx, lora_tensors[split_q_d_name]);
}
if (lora_tensors.find(split_q_u_name) != lora_tensors.end()) {
lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]);
}
if (lora_tensors.find(split_k_d_name) != lora_tensors.end()) {
lora_k_down = to_f32(compute_ctx, lora_tensors[split_k_d_name]);
}
if (lora_tensors.find(split_k_u_name) != lora_tensors.end()) {
lora_k_up = to_f32(compute_ctx, lora_tensors[split_k_u_name]);
}
if (lora_tensors.find(split_v_d_name) != lora_tensors.end()) {
lora_v_down = to_f32(compute_ctx, lora_tensors[split_v_d_name]);
}
if (lora_tensors.find(split_v_u_name) != lora_tensors.end()) {
lora_v_up = to_f32(compute_ctx, lora_tensors[split_v_u_name]);
}
if (lora_tensors.find(split_m_d_name) != lora_tensors.end()) {
lora_m_down = to_f32(compute_ctx, lora_tensors[split_m_d_name]);
}
if (lora_tensors.find(split_m_u_name) != lora_tensors.end()) {
lora_m_up = to_f32(compute_ctx, lora_tensors[split_m_u_name]);
}
float q_rank = lora_q_up->ne[0];
float k_rank = lora_k_up->ne[0];
float v_rank = lora_v_up->ne[0];
float m_rank = lora_v_up->ne[0];
float lora_q_scale = 1;
float lora_k_scale = 1;
float lora_v_scale = 1;
float lora_m_scale = 1;
if (lora_tensors.find(split_q_scale_name) != lora_tensors.end()) {
lora_q_scale = ggml_backend_tensor_get_f32(lora_tensors[split_q_scale_name]);
applied_lora_tensors.insert(split_q_scale_name);
}
if (lora_tensors.find(split_k_scale_name) != lora_tensors.end()) {
lora_k_scale = ggml_backend_tensor_get_f32(lora_tensors[split_k_scale_name]);
applied_lora_tensors.insert(split_k_scale_name);
}
if (lora_tensors.find(split_v_scale_name) != lora_tensors.end()) {
lora_v_scale = ggml_backend_tensor_get_f32(lora_tensors[split_v_scale_name]);
applied_lora_tensors.insert(split_v_scale_name);
}
if (lora_tensors.find(split_m_scale_name) != lora_tensors.end()) {
lora_m_scale = ggml_backend_tensor_get_f32(lora_tensors[split_m_scale_name]);
applied_lora_tensors.insert(split_m_scale_name);
}
if (lora_tensors.find(split_q_alpha_name) != lora_tensors.end()) {
float lora_q_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_q_alpha_name]);
applied_lora_tensors.insert(split_q_alpha_name);
lora_q_scale = lora_q_alpha / q_rank;
}
if (lora_tensors.find(split_k_alpha_name) != lora_tensors.end()) {
float lora_k_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_k_alpha_name]);
applied_lora_tensors.insert(split_k_alpha_name);
lora_k_scale = lora_k_alpha / k_rank;
}
if (lora_tensors.find(split_v_alpha_name) != lora_tensors.end()) {
float lora_v_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_v_alpha_name]);
applied_lora_tensors.insert(split_v_alpha_name);
lora_v_scale = lora_v_alpha / v_rank;
}
if (lora_tensors.find(split_m_alpha_name) != lora_tensors.end()) {
float lora_m_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_m_alpha_name]);
applied_lora_tensors.insert(split_m_alpha_name);
lora_m_scale = lora_m_alpha / m_rank;
}
ggml_scale_inplace(compute_ctx, lora_q_down, lora_q_scale);
ggml_scale_inplace(compute_ctx, lora_k_down, lora_k_scale);
ggml_scale_inplace(compute_ctx, lora_v_down, lora_v_scale);
ggml_scale_inplace(compute_ctx, lora_m_down, lora_m_scale);
// print_ggml_tensor(lora_q_down, true); //[3072, R, 1, 1]
// print_ggml_tensor(lora_k_down, true); //[3072, R, 1, 1]
// print_ggml_tensor(lora_v_down, true); //[3072, R, 1, 1]
// print_ggml_tensor(lora_m_down, true); //[3072, R, 1, 1]
// print_ggml_tensor(lora_q_up, true); //[R, 3072, 1, 1]
// print_ggml_tensor(lora_k_up, true); //[R, 3072, 1, 1]
// print_ggml_tensor(lora_v_up, true); //[R, 3072, 1, 1]
// print_ggml_tensor(lora_m_up, true); //[R, 12288, 1, 1]
// these need to be stitched together this way:
// |q_up,0 ,0 ,0 |
// |0 ,k_up,0 ,0 |
// |0 ,0 ,v_up,0 |
// |0 ,0 ,0 ,m_up|
// (q_down,k_down,v_down,m_down) . (q ,k ,v ,m)
// up_concat will be [21504, R*4, 1, 1]
// down_concat will be [R*4, 3072, 1, 1]
ggml_tensor* lora_down_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_down, lora_k_down, 1), ggml_concat(compute_ctx, lora_v_down, lora_m_down, 1), 1);
// print_ggml_tensor(lora_down_concat, true); //[3072, R*4, 1, 1]
// this also means that if rank is bigger than 672, it is less memory efficient to do it this way (should be fine)
// print_ggml_tensor(lora_q_up, true); //[3072, R, 1, 1]
ggml_tensor* z = ggml_dup_tensor(compute_ctx, lora_q_up);
ggml_tensor* mlp_z = ggml_dup_tensor(compute_ctx, lora_m_up);
ggml_scale(compute_ctx, z, 0);
ggml_scale(compute_ctx, mlp_z, 0);
ggml_tensor* zz = ggml_concat(compute_ctx, z, z, 1);
ggml_tensor* q_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_up, zz, 1), mlp_z, 1);
ggml_tensor* k_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, z, lora_k_up, 1), ggml_concat(compute_ctx, z, mlp_z, 1), 1);
ggml_tensor* v_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, zz, lora_v_up, 1), mlp_z, 1);
ggml_tensor* m_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, zz, z, 1), lora_m_up, 1);
// print_ggml_tensor(q_up, true); //[R, 21504, 1, 1]
// print_ggml_tensor(k_up, true); //[R, 21504, 1, 1]
// print_ggml_tensor(v_up, true); //[R, 21504, 1, 1]
// print_ggml_tensor(m_up, true); //[R, 21504, 1, 1]
ggml_tensor* lora_up_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, q_up, k_up, 0), ggml_concat(compute_ctx, v_up, m_up, 0), 0);
// print_ggml_tensor(lora_up_concat, true); //[R*4, 21504, 1, 1]
lora_down = ggml_cont(compute_ctx, lora_down_concat);
lora_up = ggml_cont(compute_ctx, lora_up_concat);
applied_lora_tensors.insert(split_q_u_name);
applied_lora_tensors.insert(split_k_u_name);
applied_lora_tensors.insert(split_v_u_name);
applied_lora_tensors.insert(split_m_u_name);
applied_lora_tensors.insert(split_q_d_name);
applied_lora_tensors.insert(split_k_d_name);
applied_lora_tensors.insert(split_v_d_name);
applied_lora_tensors.insert(split_m_d_name);
}
} else {
lora_up_name = fk + lora_ups[type] + ".weight";
lora_down_name = fk + lora_downs[type] + ".weight";
lora_mid_name = fk + ".lora_mid.weight";
alpha_name = fk + ".alpha";
scale_name = fk + ".scale";
if (lora_tensors.find(lora_up_name) != lora_tensors.end()) {
lora_up = to_f32(compute_ctx, lora_tensors[lora_up_name]);
}
if (lora_tensors.find(lora_down_name) != lora_tensors.end()) {
lora_down = to_f32(compute_ctx, lora_tensors[lora_down_name]);
}
if (lora_tensors.find(lora_mid_name) != lora_tensors.end()) {
lora_mid = to_f32(compute_ctx, lora_tensors[lora_mid_name]);
applied_lora_tensors.insert(lora_mid_name);
}
applied_lora_tensors.insert(lora_up_name);
applied_lora_tensors.insert(lora_down_name);
applied_lora_tensors.insert(alpha_name);
applied_lora_tensors.insert(scale_name);
}
if (lora_up == NULL || lora_down == NULL) {
continue;
}
// calc_scale
// TODO: .dora_scale?
int64_t rank = lora_down->ne[ggml_n_dims(lora_down) - 1];
if (lora_tensors.find(scale_name) != lora_tensors.end()) {
scale_value = ggml_backend_tensor_get_f32(lora_tensors[scale_name]);
} else if (lora_tensors.find(alpha_name) != lora_tensors.end()) {
float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]);
scale_value = alpha / rank;
}
updown = ggml_merge_lora(compute_ctx, lora_down, lora_up, lora_mid);
}
scale_value *= multiplier;
// flat lora tensors to multiply it
int64_t lora_up_rows = lora_up->ne[ggml_n_dims(lora_up) - 1];
lora_up = ggml_reshape_2d(compute_ctx, lora_up, ggml_nelements(lora_up) / lora_up_rows, lora_up_rows);
auto lora_down_n_dims = ggml_n_dims(lora_down);
// assume n_dims should always be a multiple of 2 (otherwise rank 1 doesn't work)
lora_down_n_dims = (lora_down_n_dims + lora_down_n_dims % 2);
int64_t lora_down_rows = lora_down->ne[lora_down_n_dims - 1];
lora_down = ggml_reshape_2d(compute_ctx, lora_down, ggml_nelements(lora_down) / lora_down_rows, lora_down_rows);
// ggml_mul_mat requires tensor b transposed
lora_down = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, lora_down));
struct ggml_tensor* updown = ggml_mul_mat(compute_ctx, lora_up, lora_down);
updown = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, updown));
updown = ggml_reshape(compute_ctx, updown, weight);
updown = ggml_reshape(compute_ctx, updown, weight);
GGML_ASSERT(ggml_nelements(updown) == ggml_nelements(weight));
updown = ggml_scale_inplace(compute_ctx, updown, scale_value);
ggml_tensor* final_weight;

View File

@ -47,6 +47,8 @@ const char* sampling_methods_str[] = {
"iPNDM",
"iPNDM_v",
"LCM",
"DDIM \"trailing\"",
"TCD"
};
/*================================================== Helper Functions ================================================*/
@ -673,19 +675,20 @@ public:
for (auto& kv : lora_state) {
const std::string& lora_name = kv.first;
float multiplier = kv.second;
if (curr_lora_state.find(lora_name) != curr_lora_state.end()) {
float curr_multiplier = curr_lora_state[lora_name];
float multiplier_diff = multiplier - curr_multiplier;
if (multiplier_diff != 0.f) {
lora_state_diff[lora_name] = multiplier_diff;
}
} else {
lora_state_diff[lora_name] = multiplier;
}
lora_state_diff[lora_name] += multiplier;
}
for (auto& kv : curr_lora_state) {
const std::string& lora_name = kv.first;
float curr_multiplier = kv.second;
lora_state_diff[lora_name] -= curr_multiplier;
}
size_t rm = lora_state_diff.size() - lora_state.size();
if (rm != 0) {
LOG_INFO("Attempting to apply %lu LoRAs (removing %lu applied LoRAs)", lora_state.size(), rm);
} else {
LOG_INFO("Attempting to apply %lu LoRAs", lora_state.size());
}
LOG_INFO("Attempting to apply %lu LoRAs", lora_state.size());
for (auto& kv : lora_state_diff) {
apply_lora(kv.first, kv.second);
@ -792,6 +795,7 @@ public:
float min_cfg,
float cfg_scale,
float guidance,
float eta,
sample_method_t method,
const std::vector<float>& sigmas,
int start_merge_step,
@ -987,7 +991,7 @@ public:
return denoised;
};
sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng);
sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng, eta);
x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x);
@ -1193,6 +1197,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
int clip_skip,
float cfg_scale,
float guidance,
float eta,
int width,
int height,
enum sample_method_t sample_method,
@ -1456,6 +1461,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
cfg_scale,
cfg_scale,
guidance,
eta,
sample_method,
sigmas,
start_merge_step,
@ -1521,6 +1527,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
int clip_skip,
float cfg_scale,
float guidance,
float eta,
int width,
int height,
enum sample_method_t sample_method,
@ -1599,6 +1606,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
clip_skip,
cfg_scale,
guidance,
eta,
width,
height,
sample_method,
@ -1630,6 +1638,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
int clip_skip,
float cfg_scale,
float guidance,
float eta,
int width,
int height,
sample_method_t sample_method,
@ -1777,6 +1786,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
clip_skip,
cfg_scale,
guidance,
eta,
width,
height,
sample_method,
@ -1890,6 +1900,7 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
min_cfg,
cfg_scale,
0.f,
0.f,
sample_method,
sigmas,
-1,

View File

@ -44,6 +44,8 @@ enum sample_method_t {
IPNDM,
IPNDM_V,
LCM,
DDIM_TRAILING,
TCD,
N_SAMPLE_METHODS
};
@ -155,6 +157,7 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx,
int clip_skip,
float cfg_scale,
float guidance,
float eta,
int width,
int height,
enum sample_method_t sample_method,
@ -180,6 +183,7 @@ SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx,
int clip_skip,
float cfg_scale,
float guidance,
float eta,
int width,
int height,
enum sample_method_t sample_method,

View File

@ -113,18 +113,31 @@ std::vector<std::string> get_files_from_dir(const std::string& dir) {
// Find the first file in the directory
hFind = FindFirstFile(directoryPath, &findFileData);
bool isAbsolutePath = false;
// Check if the directory was found
if (hFind == INVALID_HANDLE_VALUE) {
printf("Unable to find directory.\n");
return files;
printf("Unable to find directory. Try with original path \n");
char directoryPathAbsolute[MAX_PATH];
sprintf(directoryPathAbsolute, "%s*", dir.c_str());
hFind = FindFirstFile(directoryPathAbsolute, &findFileData);
isAbsolutePath = true;
if (hFind == INVALID_HANDLE_VALUE) {
printf("Absolute path was also wrong.\n");
return files;
}
}
// Loop through all files in the directory
do {
// Check if the found file is a regular file (not a directory)
if (!(findFileData.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY)) {
files.push_back(std::string(currentDirectory) + "\\" + dir + "\\" + std::string(findFileData.cFileName));
if (isAbsolutePath) {
files.push_back(dir + "\\" + std::string(findFileData.cFileName));
} else {
files.push_back(std::string(currentDirectory) + "\\" + dir + "\\" + std::string(findFileData.cFileName));
}
}
} while (FindNextFile(hFind, &findFileData) != 0);