mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-13 05:48:56 +00:00
support forward with lora
This commit is contained in:
parent
9a35003e7f
commit
7195efa9a2
@ -959,12 +959,15 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx,
|
||||
int64_t ne3 = x->ne[3];
|
||||
x = ggml_reshape_2d(ctx, x, x->ne[0], x->ne[1] * x->ne[2] * x->ne[3]);
|
||||
x = ggml_mul_mat(ctx, w, x);
|
||||
x = ggml_reshape_4d(ctx, x, x->ne[0], x->ne[1] / ne2 / ne3, ne2, ne3);
|
||||
if (force_prec_f32) {
|
||||
ggml_mul_mat_set_prec(x, GGML_PREC_F32);
|
||||
}
|
||||
x = ggml_reshape_4d(ctx, x, x->ne[0], x->ne[1] / ne2 / ne3, ne2, ne3);
|
||||
} else {
|
||||
x = ggml_mul_mat(ctx, w, x);
|
||||
}
|
||||
if (force_prec_f32) {
|
||||
ggml_mul_mat_set_prec(x, GGML_PREC_F32);
|
||||
if (force_prec_f32) {
|
||||
ggml_mul_mat_set_prec(x, GGML_PREC_F32);
|
||||
}
|
||||
}
|
||||
if (scale != 1.f) {
|
||||
x = ggml_scale(ctx, x, 1.f / scale);
|
||||
@ -1473,8 +1476,34 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {
|
||||
#define MAX_GRAPH_SIZE 327680
|
||||
|
||||
struct WeightAdapter {
|
||||
virtual ggml_tensor* patch_weight(ggml_context* ggml_ctx, ggml_tensor* weight, const std::string& weight_name) = 0;
|
||||
virtual size_t get_extra_graph_size() = 0;
|
||||
struct ForwardParams {
|
||||
enum class op_type_t {
|
||||
OP_LINEAR,
|
||||
OP_CONV2D,
|
||||
} op_type;
|
||||
struct {
|
||||
bool force_prec_f32 = false;
|
||||
float scale = 1.f;
|
||||
} linear;
|
||||
struct {
|
||||
int s0 = 1;
|
||||
int s1 = 1;
|
||||
int p0 = 0;
|
||||
int p1 = 0;
|
||||
int d0 = 1;
|
||||
int d1 = 1;
|
||||
bool direct = false;
|
||||
float scale = 1.f;
|
||||
} conv2d;
|
||||
};
|
||||
virtual ggml_tensor* patch_weight(ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name) = 0;
|
||||
virtual ggml_tensor* forward_with_lora(ggml_context* ctx,
|
||||
ggml_tensor* x,
|
||||
ggml_tensor* w,
|
||||
ggml_tensor* b,
|
||||
const std::string& prefix,
|
||||
ForwardParams forward_params) = 0;
|
||||
virtual size_t get_extra_graph_size() = 0;
|
||||
};
|
||||
|
||||
struct GGMLRunnerContext {
|
||||
@ -2070,14 +2099,15 @@ public:
|
||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||
struct ggml_tensor* w = params["weight"];
|
||||
struct ggml_tensor* b = nullptr;
|
||||
if (ctx->weight_adapter) {
|
||||
w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, w, prefix + "weight");
|
||||
}
|
||||
if (bias) {
|
||||
b = params["bias"];
|
||||
if (ctx->weight_adapter) {
|
||||
b = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, b, prefix + "bias");
|
||||
}
|
||||
}
|
||||
if (ctx->weight_adapter) {
|
||||
WeightAdapter::ForwardParams forward_params;
|
||||
forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_LINEAR;
|
||||
forward_params.linear.force_prec_f32 = force_prec_f32;
|
||||
forward_params.linear.scale = scale;
|
||||
return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params);
|
||||
}
|
||||
return ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale);
|
||||
}
|
||||
@ -2177,17 +2207,21 @@ public:
|
||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||
struct ggml_tensor* w = params["weight"];
|
||||
struct ggml_tensor* b = nullptr;
|
||||
if (ctx->weight_adapter) {
|
||||
w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, w, prefix + "weight");
|
||||
if (w->type != GGML_TYPE_F16) {
|
||||
w = ggml_cast(ctx->ggml_ctx, w, GGML_TYPE_F16);
|
||||
}
|
||||
}
|
||||
if (bias) {
|
||||
b = params["bias"];
|
||||
if (ctx->weight_adapter) {
|
||||
b = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, b, prefix + "bias");
|
||||
}
|
||||
}
|
||||
if (ctx->weight_adapter) {
|
||||
WeightAdapter::ForwardParams forward_params;
|
||||
forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_CONV2D;
|
||||
forward_params.conv2d.s0 = stride.second;
|
||||
forward_params.conv2d.s1 = stride.first;
|
||||
forward_params.conv2d.p0 = padding.second;
|
||||
forward_params.conv2d.p1 = padding.first;
|
||||
forward_params.conv2d.d0 = dilation.second;
|
||||
forward_params.conv2d.d1 = dilation.first;
|
||||
forward_params.conv2d.direct = ctx->conv2d_direct_enabled;
|
||||
forward_params.conv2d.scale = scale;
|
||||
return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params);
|
||||
}
|
||||
return ggml_ext_conv_2d(ctx->ggml_ctx,
|
||||
x,
|
||||
|
||||
257
lora.hpp
257
lora.hpp
@ -9,7 +9,7 @@
|
||||
struct LoraModel : public GGMLRunner {
|
||||
std::string lora_id;
|
||||
float multiplier = 1.0f;
|
||||
std::map<std::string, struct ggml_tensor*> lora_tensors;
|
||||
std::unordered_map<std::string, struct ggml_tensor*> lora_tensors;
|
||||
std::map<ggml_tensor*, ggml_tensor*> original_tensor_to_final_tensor;
|
||||
std::set<std::string> applied_lora_tensors;
|
||||
std::string file_path;
|
||||
@ -71,6 +71,10 @@ struct LoraModel : public GGMLRunner {
|
||||
|
||||
model_loader.load_tensors(on_new_tensor_cb, n_threads);
|
||||
|
||||
if (tensors_to_create.empty()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
for (const auto& pair : tensors_to_create) {
|
||||
const auto& name = pair.first;
|
||||
const auto& ts = pair.second;
|
||||
@ -97,7 +101,7 @@ struct LoraModel : public GGMLRunner {
|
||||
tensor_preprocessed = true;
|
||||
// I really hate these hardcoded processes.
|
||||
if (model_tensors.find("cond_stage_model.1.transformer.text_model.encoder.layers.0.self_attn.in_proj.weight") != model_tensors.end()) {
|
||||
std::map<std::string, ggml_tensor*> new_lora_tensors;
|
||||
std::unordered_map<std::string, ggml_tensor*> new_lora_tensors;
|
||||
for (auto& [old_name, tensor] : lora_tensors) {
|
||||
std::string new_name = old_name;
|
||||
|
||||
@ -125,7 +129,7 @@ struct LoraModel : public GGMLRunner {
|
||||
}
|
||||
}
|
||||
|
||||
ggml_tensor* get_lora_diff(const std::string& model_tensor_name, ggml_context* ctx) {
|
||||
ggml_tensor* get_lora_weight_diff(const std::string& model_tensor_name, ggml_context* ctx) {
|
||||
ggml_tensor* updown = nullptr;
|
||||
int index = 0;
|
||||
while (true) {
|
||||
@ -201,21 +205,50 @@ struct LoraModel : public GGMLRunner {
|
||||
|
||||
index++;
|
||||
}
|
||||
|
||||
// diff
|
||||
if (updown == nullptr) {
|
||||
std::string lora_diff_name = "lora." + model_tensor_name + ".diff";
|
||||
|
||||
if (lora_tensors.find(lora_diff_name) != lora_tensors.end()) {
|
||||
updown = ggml_ext_cast_f32(ctx, lora_tensors[lora_diff_name]);
|
||||
applied_lora_tensors.insert(lora_diff_name);
|
||||
}
|
||||
}
|
||||
|
||||
return updown;
|
||||
}
|
||||
|
||||
ggml_tensor* get_loha_diff(const std::string& model_tensor_name, ggml_context* ctx) {
|
||||
ggml_tensor* get_raw_weight_diff(const std::string& model_tensor_name, ggml_context* ctx) {
|
||||
ggml_tensor* updown = nullptr;
|
||||
int index = 0;
|
||||
while (true) {
|
||||
std::string key;
|
||||
if (index == 0) {
|
||||
key = model_tensor_name;
|
||||
} else {
|
||||
key = model_tensor_name + "." + std::to_string(index);
|
||||
}
|
||||
|
||||
std::string diff_name = "lora." + key + ".diff";
|
||||
|
||||
ggml_tensor* curr_updown = nullptr;
|
||||
|
||||
auto iter = lora_tensors.find(diff_name);
|
||||
if (iter != lora_tensors.end()) {
|
||||
curr_updown = ggml_ext_cast_f32(ctx, iter->second);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
|
||||
applied_lora_tensors.insert(diff_name);
|
||||
|
||||
float scale_value = 1.0f;
|
||||
scale_value *= multiplier;
|
||||
|
||||
curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value);
|
||||
|
||||
if (updown == nullptr) {
|
||||
updown = curr_updown;
|
||||
} else {
|
||||
updown = ggml_concat(ctx, updown, curr_updown, ggml_n_dims(updown) - 1);
|
||||
}
|
||||
|
||||
index++;
|
||||
}
|
||||
return updown;
|
||||
}
|
||||
|
||||
ggml_tensor* get_loha_weight_diff(const std::string& model_tensor_name, ggml_context* ctx) {
|
||||
ggml_tensor* updown = nullptr;
|
||||
int index = 0;
|
||||
while (true) {
|
||||
@ -318,7 +351,7 @@ struct LoraModel : public GGMLRunner {
|
||||
return updown;
|
||||
}
|
||||
|
||||
ggml_tensor* get_lokr_diff(const std::string& model_tensor_name, ggml_context* ctx) {
|
||||
ggml_tensor* get_lokr_weight_diff(const std::string& model_tensor_name, ggml_context* ctx) {
|
||||
ggml_tensor* updown = nullptr;
|
||||
int index = 0;
|
||||
while (true) {
|
||||
@ -435,16 +468,23 @@ struct LoraModel : public GGMLRunner {
|
||||
return updown;
|
||||
}
|
||||
|
||||
ggml_tensor* get_diff(const std::string& model_tensor_name, ggml_context* ctx, ggml_tensor* model_tensor) {
|
||||
ggml_tensor* get_weight_diff(const std::string& model_tensor_name, ggml_context* ctx, ggml_tensor* model_tensor, bool with_lora = true) {
|
||||
// lora
|
||||
ggml_tensor* diff = get_lora_diff(model_tensor_name, ctx);
|
||||
ggml_tensor* diff = nullptr;
|
||||
if (with_lora) {
|
||||
diff = get_lora_weight_diff(model_tensor_name, ctx);
|
||||
}
|
||||
// diff
|
||||
if (diff == nullptr) {
|
||||
diff = get_raw_weight_diff(model_tensor_name, ctx);
|
||||
}
|
||||
// loha
|
||||
if (diff == nullptr) {
|
||||
diff = get_loha_diff(model_tensor_name, ctx);
|
||||
diff = get_loha_weight_diff(model_tensor_name, ctx);
|
||||
}
|
||||
// lokr
|
||||
if (diff == nullptr) {
|
||||
diff = get_lokr_diff(model_tensor_name, ctx);
|
||||
diff = get_lokr_weight_diff(model_tensor_name, ctx);
|
||||
}
|
||||
if (diff != nullptr) {
|
||||
if (ggml_nelements(diff) < ggml_nelements(model_tensor)) {
|
||||
@ -461,6 +501,135 @@ struct LoraModel : public GGMLRunner {
|
||||
return diff;
|
||||
}
|
||||
|
||||
ggml_tensor* get_out_diff(ggml_context* ctx,
|
||||
ggml_tensor* x,
|
||||
WeightAdapter::ForwardParams forward_params,
|
||||
const std::string& model_tensor_name) {
|
||||
ggml_tensor* out_diff = nullptr;
|
||||
int index = 0;
|
||||
while (true) {
|
||||
std::string key;
|
||||
if (index == 0) {
|
||||
key = model_tensor_name;
|
||||
} else {
|
||||
key = model_tensor_name + "." + std::to_string(index);
|
||||
}
|
||||
|
||||
std::string lora_down_name = "lora." + key + ".lora_down";
|
||||
std::string lora_up_name = "lora." + key + ".lora_up";
|
||||
std::string lora_mid_name = "lora." + key + ".lora_mid";
|
||||
std::string scale_name = "lora." + key + ".scale";
|
||||
std::string alpha_name = "lora." + key + ".alpha";
|
||||
|
||||
ggml_tensor* lora_up = nullptr;
|
||||
ggml_tensor* lora_mid = nullptr;
|
||||
ggml_tensor* lora_down = nullptr;
|
||||
|
||||
auto iter = lora_tensors.find(lora_up_name);
|
||||
if (iter != lora_tensors.end()) {
|
||||
lora_up = iter->second;
|
||||
}
|
||||
|
||||
iter = lora_tensors.find(lora_mid_name);
|
||||
if (iter != lora_tensors.end()) {
|
||||
lora_mid = iter->second;
|
||||
}
|
||||
|
||||
iter = lora_tensors.find(lora_down_name);
|
||||
if (iter != lora_tensors.end()) {
|
||||
lora_down = iter->second;
|
||||
}
|
||||
|
||||
if (lora_up == nullptr || lora_down == nullptr) {
|
||||
break;
|
||||
}
|
||||
|
||||
applied_lora_tensors.insert(lora_up_name);
|
||||
applied_lora_tensors.insert(lora_down_name);
|
||||
|
||||
if (lora_mid) {
|
||||
applied_lora_tensors.insert(lora_mid_name);
|
||||
}
|
||||
|
||||
float scale_value = 1.0f;
|
||||
|
||||
int64_t rank = lora_down->ne[ggml_n_dims(lora_down) - 1];
|
||||
iter = lora_tensors.find(scale_name);
|
||||
if (iter != lora_tensors.end()) {
|
||||
scale_value = ggml_ext_backend_tensor_get_f32(iter->second);
|
||||
applied_lora_tensors.insert(scale_name);
|
||||
} else {
|
||||
iter = lora_tensors.find(alpha_name);
|
||||
if (iter != lora_tensors.end()) {
|
||||
float alpha = ggml_ext_backend_tensor_get_f32(iter->second);
|
||||
scale_value = alpha / rank;
|
||||
// LOG_DEBUG("rank %s %ld %.2f %.2f", alpha_name.c_str(), rank, alpha, scale_value);
|
||||
applied_lora_tensors.insert(alpha_name);
|
||||
}
|
||||
}
|
||||
scale_value *= multiplier;
|
||||
|
||||
ggml_tensor* lx;
|
||||
if (forward_params.op_type == WeightAdapter::ForwardParams::op_type_t::OP_LINEAR) {
|
||||
lx = ggml_ext_linear(ctx, x, lora_down, nullptr, forward_params.linear.force_prec_f32, forward_params.linear.scale);
|
||||
if (lora_mid) {
|
||||
lx = ggml_ext_linear(ctx, lx, lora_mid, nullptr, forward_params.linear.force_prec_f32, forward_params.linear.scale);
|
||||
}
|
||||
lx = ggml_ext_linear(ctx, lx, lora_up, nullptr, forward_params.linear.force_prec_f32, forward_params.linear.scale);
|
||||
} else { // OP_CONV2D
|
||||
lx = ggml_ext_conv_2d(ctx,
|
||||
x,
|
||||
lora_down,
|
||||
nullptr,
|
||||
forward_params.conv2d.s0,
|
||||
forward_params.conv2d.s1,
|
||||
forward_params.conv2d.p0,
|
||||
forward_params.conv2d.p1,
|
||||
forward_params.conv2d.d0,
|
||||
forward_params.conv2d.d1,
|
||||
forward_params.conv2d.direct,
|
||||
forward_params.conv2d.scale);
|
||||
if (lora_mid) {
|
||||
lx = ggml_ext_conv_2d(ctx,
|
||||
lx,
|
||||
lora_mid,
|
||||
nullptr,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
forward_params.conv2d.direct,
|
||||
forward_params.conv2d.scale);
|
||||
}
|
||||
lx = lx = ggml_ext_conv_2d(ctx,
|
||||
lx,
|
||||
lora_up,
|
||||
nullptr,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
forward_params.conv2d.direct,
|
||||
forward_params.conv2d.scale);
|
||||
}
|
||||
|
||||
auto curr_out_diff = ggml_scale_inplace(ctx, lx, scale_value);
|
||||
|
||||
if (out_diff == nullptr) {
|
||||
out_diff = curr_out_diff;
|
||||
} else {
|
||||
out_diff = ggml_concat(ctx, out_diff, curr_out_diff, ggml_n_dims(out_diff) - 1);
|
||||
}
|
||||
|
||||
index++;
|
||||
}
|
||||
return out_diff;
|
||||
}
|
||||
|
||||
struct ggml_cgraph* build_lora_graph(const std::map<std::string, ggml_tensor*>& model_tensors, SDVersion version) {
|
||||
size_t lora_graph_size = LORA_GRAPH_BASE_SIZE + lora_tensors.size() * 10;
|
||||
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, lora_graph_size, false);
|
||||
@ -475,7 +644,7 @@ struct LoraModel : public GGMLRunner {
|
||||
ggml_tensor* model_tensor = it.second;
|
||||
|
||||
// lora
|
||||
ggml_tensor* diff = get_diff(model_tensor_name, compute_ctx, model_tensor);
|
||||
ggml_tensor* diff = get_weight_diff(model_tensor_name, compute_ctx, model_tensor);
|
||||
if (diff == nullptr) {
|
||||
continue;
|
||||
}
|
||||
@ -555,10 +724,9 @@ public:
|
||||
: lora_models(lora_models) {
|
||||
}
|
||||
|
||||
// TODO: cache result for multi run?
|
||||
ggml_tensor* patch_weight(ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name) override {
|
||||
ggml_tensor* patch_weight(ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name, bool with_lora) {
|
||||
for (auto& lora_model : lora_models) {
|
||||
ggml_tensor* diff = lora_model->get_diff(weight_name, ctx, weight);
|
||||
ggml_tensor* diff = lora_model->get_weight_diff(weight_name, ctx, weight, with_lora);
|
||||
if (diff == nullptr) {
|
||||
continue;
|
||||
}
|
||||
@ -571,6 +739,47 @@ public:
|
||||
return weight;
|
||||
}
|
||||
|
||||
ggml_tensor* patch_weight(ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name) override {
|
||||
return patch_weight(ctx, weight, weight_name, true);
|
||||
}
|
||||
|
||||
ggml_tensor* forward_with_lora(ggml_context* ctx,
|
||||
ggml_tensor* x,
|
||||
ggml_tensor* w,
|
||||
ggml_tensor* b,
|
||||
const std::string& prefix,
|
||||
WeightAdapter::ForwardParams forward_params) override {
|
||||
w = patch_weight(ctx, w, prefix + "weight", false);
|
||||
if (b) {
|
||||
b = patch_weight(ctx, b, prefix + "bias", false);
|
||||
}
|
||||
ggml_tensor* out;
|
||||
if (forward_params.op_type == ForwardParams::op_type_t::OP_LINEAR) {
|
||||
out = ggml_ext_linear(ctx, x, w, b, forward_params.linear.force_prec_f32, forward_params.linear.scale);
|
||||
} else { // OP_CONV2D
|
||||
out = ggml_ext_conv_2d(ctx,
|
||||
x,
|
||||
w,
|
||||
b,
|
||||
forward_params.conv2d.s0,
|
||||
forward_params.conv2d.s1,
|
||||
forward_params.conv2d.p0,
|
||||
forward_params.conv2d.p1,
|
||||
forward_params.conv2d.d0,
|
||||
forward_params.conv2d.d1,
|
||||
forward_params.conv2d.direct,
|
||||
forward_params.conv2d.scale);
|
||||
}
|
||||
for (auto& lora_model : lora_models) {
|
||||
ggml_tensor* out_diff = lora_model->get_out_diff(ctx, x, forward_params, prefix + "weight");
|
||||
if (out_diff == nullptr) {
|
||||
continue;
|
||||
}
|
||||
out = ggml_add_inplace(ctx, out, out_diff);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
size_t get_extra_graph_size() override {
|
||||
size_t lora_tensor_num = 0;
|
||||
for (auto& lora_model : lora_models) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user