mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-13 05:48:56 +00:00
support forward with lora
This commit is contained in:
parent
9a35003e7f
commit
7195efa9a2
@ -959,12 +959,15 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx,
|
|||||||
int64_t ne3 = x->ne[3];
|
int64_t ne3 = x->ne[3];
|
||||||
x = ggml_reshape_2d(ctx, x, x->ne[0], x->ne[1] * x->ne[2] * x->ne[3]);
|
x = ggml_reshape_2d(ctx, x, x->ne[0], x->ne[1] * x->ne[2] * x->ne[3]);
|
||||||
x = ggml_mul_mat(ctx, w, x);
|
x = ggml_mul_mat(ctx, w, x);
|
||||||
x = ggml_reshape_4d(ctx, x, x->ne[0], x->ne[1] / ne2 / ne3, ne2, ne3);
|
if (force_prec_f32) {
|
||||||
|
ggml_mul_mat_set_prec(x, GGML_PREC_F32);
|
||||||
|
}
|
||||||
|
x = ggml_reshape_4d(ctx, x, x->ne[0], x->ne[1] / ne2 / ne3, ne2, ne3);
|
||||||
} else {
|
} else {
|
||||||
x = ggml_mul_mat(ctx, w, x);
|
x = ggml_mul_mat(ctx, w, x);
|
||||||
}
|
if (force_prec_f32) {
|
||||||
if (force_prec_f32) {
|
ggml_mul_mat_set_prec(x, GGML_PREC_F32);
|
||||||
ggml_mul_mat_set_prec(x, GGML_PREC_F32);
|
}
|
||||||
}
|
}
|
||||||
if (scale != 1.f) {
|
if (scale != 1.f) {
|
||||||
x = ggml_scale(ctx, x, 1.f / scale);
|
x = ggml_scale(ctx, x, 1.f / scale);
|
||||||
@ -1473,8 +1476,34 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {
|
|||||||
#define MAX_GRAPH_SIZE 327680
|
#define MAX_GRAPH_SIZE 327680
|
||||||
|
|
||||||
struct WeightAdapter {
|
struct WeightAdapter {
|
||||||
virtual ggml_tensor* patch_weight(ggml_context* ggml_ctx, ggml_tensor* weight, const std::string& weight_name) = 0;
|
struct ForwardParams {
|
||||||
virtual size_t get_extra_graph_size() = 0;
|
enum class op_type_t {
|
||||||
|
OP_LINEAR,
|
||||||
|
OP_CONV2D,
|
||||||
|
} op_type;
|
||||||
|
struct {
|
||||||
|
bool force_prec_f32 = false;
|
||||||
|
float scale = 1.f;
|
||||||
|
} linear;
|
||||||
|
struct {
|
||||||
|
int s0 = 1;
|
||||||
|
int s1 = 1;
|
||||||
|
int p0 = 0;
|
||||||
|
int p1 = 0;
|
||||||
|
int d0 = 1;
|
||||||
|
int d1 = 1;
|
||||||
|
bool direct = false;
|
||||||
|
float scale = 1.f;
|
||||||
|
} conv2d;
|
||||||
|
};
|
||||||
|
virtual ggml_tensor* patch_weight(ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name) = 0;
|
||||||
|
virtual ggml_tensor* forward_with_lora(ggml_context* ctx,
|
||||||
|
ggml_tensor* x,
|
||||||
|
ggml_tensor* w,
|
||||||
|
ggml_tensor* b,
|
||||||
|
const std::string& prefix,
|
||||||
|
ForwardParams forward_params) = 0;
|
||||||
|
virtual size_t get_extra_graph_size() = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct GGMLRunnerContext {
|
struct GGMLRunnerContext {
|
||||||
@ -2070,14 +2099,15 @@ public:
|
|||||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
struct ggml_tensor* w = params["weight"];
|
struct ggml_tensor* w = params["weight"];
|
||||||
struct ggml_tensor* b = nullptr;
|
struct ggml_tensor* b = nullptr;
|
||||||
if (ctx->weight_adapter) {
|
|
||||||
w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, w, prefix + "weight");
|
|
||||||
}
|
|
||||||
if (bias) {
|
if (bias) {
|
||||||
b = params["bias"];
|
b = params["bias"];
|
||||||
if (ctx->weight_adapter) {
|
}
|
||||||
b = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, b, prefix + "bias");
|
if (ctx->weight_adapter) {
|
||||||
}
|
WeightAdapter::ForwardParams forward_params;
|
||||||
|
forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_LINEAR;
|
||||||
|
forward_params.linear.force_prec_f32 = force_prec_f32;
|
||||||
|
forward_params.linear.scale = scale;
|
||||||
|
return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params);
|
||||||
}
|
}
|
||||||
return ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale);
|
return ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale);
|
||||||
}
|
}
|
||||||
@ -2177,17 +2207,21 @@ public:
|
|||||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
struct ggml_tensor* w = params["weight"];
|
struct ggml_tensor* w = params["weight"];
|
||||||
struct ggml_tensor* b = nullptr;
|
struct ggml_tensor* b = nullptr;
|
||||||
if (ctx->weight_adapter) {
|
|
||||||
w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, w, prefix + "weight");
|
|
||||||
if (w->type != GGML_TYPE_F16) {
|
|
||||||
w = ggml_cast(ctx->ggml_ctx, w, GGML_TYPE_F16);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (bias) {
|
if (bias) {
|
||||||
b = params["bias"];
|
b = params["bias"];
|
||||||
if (ctx->weight_adapter) {
|
}
|
||||||
b = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, b, prefix + "bias");
|
if (ctx->weight_adapter) {
|
||||||
}
|
WeightAdapter::ForwardParams forward_params;
|
||||||
|
forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_CONV2D;
|
||||||
|
forward_params.conv2d.s0 = stride.second;
|
||||||
|
forward_params.conv2d.s1 = stride.first;
|
||||||
|
forward_params.conv2d.p0 = padding.second;
|
||||||
|
forward_params.conv2d.p1 = padding.first;
|
||||||
|
forward_params.conv2d.d0 = dilation.second;
|
||||||
|
forward_params.conv2d.d1 = dilation.first;
|
||||||
|
forward_params.conv2d.direct = ctx->conv2d_direct_enabled;
|
||||||
|
forward_params.conv2d.scale = scale;
|
||||||
|
return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params);
|
||||||
}
|
}
|
||||||
return ggml_ext_conv_2d(ctx->ggml_ctx,
|
return ggml_ext_conv_2d(ctx->ggml_ctx,
|
||||||
x,
|
x,
|
||||||
|
|||||||
257
lora.hpp
257
lora.hpp
@ -9,7 +9,7 @@
|
|||||||
struct LoraModel : public GGMLRunner {
|
struct LoraModel : public GGMLRunner {
|
||||||
std::string lora_id;
|
std::string lora_id;
|
||||||
float multiplier = 1.0f;
|
float multiplier = 1.0f;
|
||||||
std::map<std::string, struct ggml_tensor*> lora_tensors;
|
std::unordered_map<std::string, struct ggml_tensor*> lora_tensors;
|
||||||
std::map<ggml_tensor*, ggml_tensor*> original_tensor_to_final_tensor;
|
std::map<ggml_tensor*, ggml_tensor*> original_tensor_to_final_tensor;
|
||||||
std::set<std::string> applied_lora_tensors;
|
std::set<std::string> applied_lora_tensors;
|
||||||
std::string file_path;
|
std::string file_path;
|
||||||
@ -71,6 +71,10 @@ struct LoraModel : public GGMLRunner {
|
|||||||
|
|
||||||
model_loader.load_tensors(on_new_tensor_cb, n_threads);
|
model_loader.load_tensors(on_new_tensor_cb, n_threads);
|
||||||
|
|
||||||
|
if (tensors_to_create.empty()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
for (const auto& pair : tensors_to_create) {
|
for (const auto& pair : tensors_to_create) {
|
||||||
const auto& name = pair.first;
|
const auto& name = pair.first;
|
||||||
const auto& ts = pair.second;
|
const auto& ts = pair.second;
|
||||||
@ -97,7 +101,7 @@ struct LoraModel : public GGMLRunner {
|
|||||||
tensor_preprocessed = true;
|
tensor_preprocessed = true;
|
||||||
// I really hate these hardcoded processes.
|
// I really hate these hardcoded processes.
|
||||||
if (model_tensors.find("cond_stage_model.1.transformer.text_model.encoder.layers.0.self_attn.in_proj.weight") != model_tensors.end()) {
|
if (model_tensors.find("cond_stage_model.1.transformer.text_model.encoder.layers.0.self_attn.in_proj.weight") != model_tensors.end()) {
|
||||||
std::map<std::string, ggml_tensor*> new_lora_tensors;
|
std::unordered_map<std::string, ggml_tensor*> new_lora_tensors;
|
||||||
for (auto& [old_name, tensor] : lora_tensors) {
|
for (auto& [old_name, tensor] : lora_tensors) {
|
||||||
std::string new_name = old_name;
|
std::string new_name = old_name;
|
||||||
|
|
||||||
@ -125,7 +129,7 @@ struct LoraModel : public GGMLRunner {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* get_lora_diff(const std::string& model_tensor_name, ggml_context* ctx) {
|
ggml_tensor* get_lora_weight_diff(const std::string& model_tensor_name, ggml_context* ctx) {
|
||||||
ggml_tensor* updown = nullptr;
|
ggml_tensor* updown = nullptr;
|
||||||
int index = 0;
|
int index = 0;
|
||||||
while (true) {
|
while (true) {
|
||||||
@ -201,21 +205,50 @@ struct LoraModel : public GGMLRunner {
|
|||||||
|
|
||||||
index++;
|
index++;
|
||||||
}
|
}
|
||||||
|
|
||||||
// diff
|
|
||||||
if (updown == nullptr) {
|
|
||||||
std::string lora_diff_name = "lora." + model_tensor_name + ".diff";
|
|
||||||
|
|
||||||
if (lora_tensors.find(lora_diff_name) != lora_tensors.end()) {
|
|
||||||
updown = ggml_ext_cast_f32(ctx, lora_tensors[lora_diff_name]);
|
|
||||||
applied_lora_tensors.insert(lora_diff_name);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return updown;
|
return updown;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* get_loha_diff(const std::string& model_tensor_name, ggml_context* ctx) {
|
ggml_tensor* get_raw_weight_diff(const std::string& model_tensor_name, ggml_context* ctx) {
|
||||||
|
ggml_tensor* updown = nullptr;
|
||||||
|
int index = 0;
|
||||||
|
while (true) {
|
||||||
|
std::string key;
|
||||||
|
if (index == 0) {
|
||||||
|
key = model_tensor_name;
|
||||||
|
} else {
|
||||||
|
key = model_tensor_name + "." + std::to_string(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string diff_name = "lora." + key + ".diff";
|
||||||
|
|
||||||
|
ggml_tensor* curr_updown = nullptr;
|
||||||
|
|
||||||
|
auto iter = lora_tensors.find(diff_name);
|
||||||
|
if (iter != lora_tensors.end()) {
|
||||||
|
curr_updown = ggml_ext_cast_f32(ctx, iter->second);
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
applied_lora_tensors.insert(diff_name);
|
||||||
|
|
||||||
|
float scale_value = 1.0f;
|
||||||
|
scale_value *= multiplier;
|
||||||
|
|
||||||
|
curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value);
|
||||||
|
|
||||||
|
if (updown == nullptr) {
|
||||||
|
updown = curr_updown;
|
||||||
|
} else {
|
||||||
|
updown = ggml_concat(ctx, updown, curr_updown, ggml_n_dims(updown) - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
index++;
|
||||||
|
}
|
||||||
|
return updown;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* get_loha_weight_diff(const std::string& model_tensor_name, ggml_context* ctx) {
|
||||||
ggml_tensor* updown = nullptr;
|
ggml_tensor* updown = nullptr;
|
||||||
int index = 0;
|
int index = 0;
|
||||||
while (true) {
|
while (true) {
|
||||||
@ -318,7 +351,7 @@ struct LoraModel : public GGMLRunner {
|
|||||||
return updown;
|
return updown;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* get_lokr_diff(const std::string& model_tensor_name, ggml_context* ctx) {
|
ggml_tensor* get_lokr_weight_diff(const std::string& model_tensor_name, ggml_context* ctx) {
|
||||||
ggml_tensor* updown = nullptr;
|
ggml_tensor* updown = nullptr;
|
||||||
int index = 0;
|
int index = 0;
|
||||||
while (true) {
|
while (true) {
|
||||||
@ -435,16 +468,23 @@ struct LoraModel : public GGMLRunner {
|
|||||||
return updown;
|
return updown;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* get_diff(const std::string& model_tensor_name, ggml_context* ctx, ggml_tensor* model_tensor) {
|
ggml_tensor* get_weight_diff(const std::string& model_tensor_name, ggml_context* ctx, ggml_tensor* model_tensor, bool with_lora = true) {
|
||||||
// lora
|
// lora
|
||||||
ggml_tensor* diff = get_lora_diff(model_tensor_name, ctx);
|
ggml_tensor* diff = nullptr;
|
||||||
|
if (with_lora) {
|
||||||
|
diff = get_lora_weight_diff(model_tensor_name, ctx);
|
||||||
|
}
|
||||||
|
// diff
|
||||||
|
if (diff == nullptr) {
|
||||||
|
diff = get_raw_weight_diff(model_tensor_name, ctx);
|
||||||
|
}
|
||||||
// loha
|
// loha
|
||||||
if (diff == nullptr) {
|
if (diff == nullptr) {
|
||||||
diff = get_loha_diff(model_tensor_name, ctx);
|
diff = get_loha_weight_diff(model_tensor_name, ctx);
|
||||||
}
|
}
|
||||||
// lokr
|
// lokr
|
||||||
if (diff == nullptr) {
|
if (diff == nullptr) {
|
||||||
diff = get_lokr_diff(model_tensor_name, ctx);
|
diff = get_lokr_weight_diff(model_tensor_name, ctx);
|
||||||
}
|
}
|
||||||
if (diff != nullptr) {
|
if (diff != nullptr) {
|
||||||
if (ggml_nelements(diff) < ggml_nelements(model_tensor)) {
|
if (ggml_nelements(diff) < ggml_nelements(model_tensor)) {
|
||||||
@ -461,6 +501,135 @@ struct LoraModel : public GGMLRunner {
|
|||||||
return diff;
|
return diff;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ggml_tensor* get_out_diff(ggml_context* ctx,
|
||||||
|
ggml_tensor* x,
|
||||||
|
WeightAdapter::ForwardParams forward_params,
|
||||||
|
const std::string& model_tensor_name) {
|
||||||
|
ggml_tensor* out_diff = nullptr;
|
||||||
|
int index = 0;
|
||||||
|
while (true) {
|
||||||
|
std::string key;
|
||||||
|
if (index == 0) {
|
||||||
|
key = model_tensor_name;
|
||||||
|
} else {
|
||||||
|
key = model_tensor_name + "." + std::to_string(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string lora_down_name = "lora." + key + ".lora_down";
|
||||||
|
std::string lora_up_name = "lora." + key + ".lora_up";
|
||||||
|
std::string lora_mid_name = "lora." + key + ".lora_mid";
|
||||||
|
std::string scale_name = "lora." + key + ".scale";
|
||||||
|
std::string alpha_name = "lora." + key + ".alpha";
|
||||||
|
|
||||||
|
ggml_tensor* lora_up = nullptr;
|
||||||
|
ggml_tensor* lora_mid = nullptr;
|
||||||
|
ggml_tensor* lora_down = nullptr;
|
||||||
|
|
||||||
|
auto iter = lora_tensors.find(lora_up_name);
|
||||||
|
if (iter != lora_tensors.end()) {
|
||||||
|
lora_up = iter->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
iter = lora_tensors.find(lora_mid_name);
|
||||||
|
if (iter != lora_tensors.end()) {
|
||||||
|
lora_mid = iter->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
iter = lora_tensors.find(lora_down_name);
|
||||||
|
if (iter != lora_tensors.end()) {
|
||||||
|
lora_down = iter->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (lora_up == nullptr || lora_down == nullptr) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
applied_lora_tensors.insert(lora_up_name);
|
||||||
|
applied_lora_tensors.insert(lora_down_name);
|
||||||
|
|
||||||
|
if (lora_mid) {
|
||||||
|
applied_lora_tensors.insert(lora_mid_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
float scale_value = 1.0f;
|
||||||
|
|
||||||
|
int64_t rank = lora_down->ne[ggml_n_dims(lora_down) - 1];
|
||||||
|
iter = lora_tensors.find(scale_name);
|
||||||
|
if (iter != lora_tensors.end()) {
|
||||||
|
scale_value = ggml_ext_backend_tensor_get_f32(iter->second);
|
||||||
|
applied_lora_tensors.insert(scale_name);
|
||||||
|
} else {
|
||||||
|
iter = lora_tensors.find(alpha_name);
|
||||||
|
if (iter != lora_tensors.end()) {
|
||||||
|
float alpha = ggml_ext_backend_tensor_get_f32(iter->second);
|
||||||
|
scale_value = alpha / rank;
|
||||||
|
// LOG_DEBUG("rank %s %ld %.2f %.2f", alpha_name.c_str(), rank, alpha, scale_value);
|
||||||
|
applied_lora_tensors.insert(alpha_name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
scale_value *= multiplier;
|
||||||
|
|
||||||
|
ggml_tensor* lx;
|
||||||
|
if (forward_params.op_type == WeightAdapter::ForwardParams::op_type_t::OP_LINEAR) {
|
||||||
|
lx = ggml_ext_linear(ctx, x, lora_down, nullptr, forward_params.linear.force_prec_f32, forward_params.linear.scale);
|
||||||
|
if (lora_mid) {
|
||||||
|
lx = ggml_ext_linear(ctx, lx, lora_mid, nullptr, forward_params.linear.force_prec_f32, forward_params.linear.scale);
|
||||||
|
}
|
||||||
|
lx = ggml_ext_linear(ctx, lx, lora_up, nullptr, forward_params.linear.force_prec_f32, forward_params.linear.scale);
|
||||||
|
} else { // OP_CONV2D
|
||||||
|
lx = ggml_ext_conv_2d(ctx,
|
||||||
|
x,
|
||||||
|
lora_down,
|
||||||
|
nullptr,
|
||||||
|
forward_params.conv2d.s0,
|
||||||
|
forward_params.conv2d.s1,
|
||||||
|
forward_params.conv2d.p0,
|
||||||
|
forward_params.conv2d.p1,
|
||||||
|
forward_params.conv2d.d0,
|
||||||
|
forward_params.conv2d.d1,
|
||||||
|
forward_params.conv2d.direct,
|
||||||
|
forward_params.conv2d.scale);
|
||||||
|
if (lora_mid) {
|
||||||
|
lx = ggml_ext_conv_2d(ctx,
|
||||||
|
lx,
|
||||||
|
lora_mid,
|
||||||
|
nullptr,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
forward_params.conv2d.direct,
|
||||||
|
forward_params.conv2d.scale);
|
||||||
|
}
|
||||||
|
lx = lx = ggml_ext_conv_2d(ctx,
|
||||||
|
lx,
|
||||||
|
lora_up,
|
||||||
|
nullptr,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
forward_params.conv2d.direct,
|
||||||
|
forward_params.conv2d.scale);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto curr_out_diff = ggml_scale_inplace(ctx, lx, scale_value);
|
||||||
|
|
||||||
|
if (out_diff == nullptr) {
|
||||||
|
out_diff = curr_out_diff;
|
||||||
|
} else {
|
||||||
|
out_diff = ggml_concat(ctx, out_diff, curr_out_diff, ggml_n_dims(out_diff) - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
index++;
|
||||||
|
}
|
||||||
|
return out_diff;
|
||||||
|
}
|
||||||
|
|
||||||
struct ggml_cgraph* build_lora_graph(const std::map<std::string, ggml_tensor*>& model_tensors, SDVersion version) {
|
struct ggml_cgraph* build_lora_graph(const std::map<std::string, ggml_tensor*>& model_tensors, SDVersion version) {
|
||||||
size_t lora_graph_size = LORA_GRAPH_BASE_SIZE + lora_tensors.size() * 10;
|
size_t lora_graph_size = LORA_GRAPH_BASE_SIZE + lora_tensors.size() * 10;
|
||||||
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, lora_graph_size, false);
|
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, lora_graph_size, false);
|
||||||
@ -475,7 +644,7 @@ struct LoraModel : public GGMLRunner {
|
|||||||
ggml_tensor* model_tensor = it.second;
|
ggml_tensor* model_tensor = it.second;
|
||||||
|
|
||||||
// lora
|
// lora
|
||||||
ggml_tensor* diff = get_diff(model_tensor_name, compute_ctx, model_tensor);
|
ggml_tensor* diff = get_weight_diff(model_tensor_name, compute_ctx, model_tensor);
|
||||||
if (diff == nullptr) {
|
if (diff == nullptr) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -555,10 +724,9 @@ public:
|
|||||||
: lora_models(lora_models) {
|
: lora_models(lora_models) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: cache result for multi run?
|
ggml_tensor* patch_weight(ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name, bool with_lora) {
|
||||||
ggml_tensor* patch_weight(ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name) override {
|
|
||||||
for (auto& lora_model : lora_models) {
|
for (auto& lora_model : lora_models) {
|
||||||
ggml_tensor* diff = lora_model->get_diff(weight_name, ctx, weight);
|
ggml_tensor* diff = lora_model->get_weight_diff(weight_name, ctx, weight, with_lora);
|
||||||
if (diff == nullptr) {
|
if (diff == nullptr) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -571,6 +739,47 @@ public:
|
|||||||
return weight;
|
return weight;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ggml_tensor* patch_weight(ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name) override {
|
||||||
|
return patch_weight(ctx, weight, weight_name, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* forward_with_lora(ggml_context* ctx,
|
||||||
|
ggml_tensor* x,
|
||||||
|
ggml_tensor* w,
|
||||||
|
ggml_tensor* b,
|
||||||
|
const std::string& prefix,
|
||||||
|
WeightAdapter::ForwardParams forward_params) override {
|
||||||
|
w = patch_weight(ctx, w, prefix + "weight", false);
|
||||||
|
if (b) {
|
||||||
|
b = patch_weight(ctx, b, prefix + "bias", false);
|
||||||
|
}
|
||||||
|
ggml_tensor* out;
|
||||||
|
if (forward_params.op_type == ForwardParams::op_type_t::OP_LINEAR) {
|
||||||
|
out = ggml_ext_linear(ctx, x, w, b, forward_params.linear.force_prec_f32, forward_params.linear.scale);
|
||||||
|
} else { // OP_CONV2D
|
||||||
|
out = ggml_ext_conv_2d(ctx,
|
||||||
|
x,
|
||||||
|
w,
|
||||||
|
b,
|
||||||
|
forward_params.conv2d.s0,
|
||||||
|
forward_params.conv2d.s1,
|
||||||
|
forward_params.conv2d.p0,
|
||||||
|
forward_params.conv2d.p1,
|
||||||
|
forward_params.conv2d.d0,
|
||||||
|
forward_params.conv2d.d1,
|
||||||
|
forward_params.conv2d.direct,
|
||||||
|
forward_params.conv2d.scale);
|
||||||
|
}
|
||||||
|
for (auto& lora_model : lora_models) {
|
||||||
|
ggml_tensor* out_diff = lora_model->get_out_diff(ctx, x, forward_params, prefix + "weight");
|
||||||
|
if (out_diff == nullptr) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
out = ggml_add_inplace(ctx, out, out_diff);
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
size_t get_extra_graph_size() override {
|
size_t get_extra_graph_size() override {
|
||||||
size_t lora_tensor_num = 0;
|
size_t lora_tensor_num = 0;
|
||||||
for (auto& lora_model : lora_models) {
|
for (auto& lora_model : lora_models) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user