mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-03-24 02:08:51 +00:00
Compare commits
2 Commits
65891d74cc
...
f0f641a142
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f0f641a142 | ||
|
|
9f56833e14 |
@ -263,6 +263,11 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
|
|||||||
log_print(level, log, svr_params->verbose, svr_params->color);
|
log_print(level, log, svr_params->verbose, svr_params->color);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct LoraEntry {
|
||||||
|
std::string name;
|
||||||
|
std::string path;
|
||||||
|
};
|
||||||
|
|
||||||
int main(int argc, const char** argv) {
|
int main(int argc, const char** argv) {
|
||||||
if (argc > 1 && std::string(argv[1]) == "--version") {
|
if (argc > 1 && std::string(argv[1]) == "--version") {
|
||||||
std::cout << version_string() << "\n";
|
std::cout << version_string() << "\n";
|
||||||
@ -293,6 +298,54 @@ int main(int argc, const char** argv) {
|
|||||||
|
|
||||||
std::mutex sd_ctx_mutex;
|
std::mutex sd_ctx_mutex;
|
||||||
|
|
||||||
|
std::vector<LoraEntry> lora_cache;
|
||||||
|
std::mutex lora_mutex;
|
||||||
|
|
||||||
|
auto refresh_lora_cache = [&]() {
|
||||||
|
std::vector<LoraEntry> new_cache;
|
||||||
|
|
||||||
|
fs::path lora_dir = ctx_params.lora_model_dir;
|
||||||
|
if (fs::exists(lora_dir) && fs::is_directory(lora_dir)) {
|
||||||
|
auto is_lora_ext = [](const fs::path& p) {
|
||||||
|
auto ext = p.extension().string();
|
||||||
|
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
|
||||||
|
return ext == ".gguf" || ext == ".pt" || ext == ".pth" || ext == ".safetensors";
|
||||||
|
};
|
||||||
|
|
||||||
|
for (auto& entry : fs::recursive_directory_iterator(lora_dir)) {
|
||||||
|
if (!entry.is_regular_file())
|
||||||
|
continue;
|
||||||
|
const fs::path& p = entry.path();
|
||||||
|
if (!is_lora_ext(p))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
LoraEntry e;
|
||||||
|
e.name = p.stem().u8string();
|
||||||
|
std::string rel = fs::relative(p, lora_dir).u8string();
|
||||||
|
std::replace(rel.begin(), rel.end(), '\\', '/');
|
||||||
|
e.path = rel;
|
||||||
|
|
||||||
|
new_cache.push_back(std::move(e));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::sort(new_cache.begin(), new_cache.end(),
|
||||||
|
[](const LoraEntry& a, const LoraEntry& b) {
|
||||||
|
return a.path < b.path;
|
||||||
|
});
|
||||||
|
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> lock(lora_mutex);
|
||||||
|
lora_cache = std::move(new_cache);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
auto is_valid_lora_path = [&](const std::string& path) -> bool {
|
||||||
|
std::lock_guard<std::mutex> lock(lora_mutex);
|
||||||
|
return std::any_of(lora_cache.begin(), lora_cache.end(),
|
||||||
|
[&](const LoraEntry& e) { return e.path == path; });
|
||||||
|
};
|
||||||
|
|
||||||
httplib::Server svr;
|
httplib::Server svr;
|
||||||
|
|
||||||
svr.set_pre_routing_handler([](const httplib::Request& req, httplib::Response& res) {
|
svr.set_pre_routing_handler([](const httplib::Request& req, httplib::Response& res) {
|
||||||
@ -312,7 +365,7 @@ int main(int argc, const char** argv) {
|
|||||||
return httplib::Server::HandlerResponse::Unhandled;
|
return httplib::Server::HandlerResponse::Unhandled;
|
||||||
});
|
});
|
||||||
|
|
||||||
// health
|
// root
|
||||||
svr.Get("/", [&](const httplib::Request&, httplib::Response& res) {
|
svr.Get("/", [&](const httplib::Request&, httplib::Response& res) {
|
||||||
if (!svr_params.serve_html_path.empty()) {
|
if (!svr_params.serve_html_path.empty()) {
|
||||||
std::ifstream file(svr_params.serve_html_path);
|
std::ifstream file(svr_params.serve_html_path);
|
||||||
@ -767,6 +820,37 @@ int main(int argc, const char** argv) {
|
|||||||
return bad("prompt required");
|
return bad("prompt required");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<sd_lora_t> sd_loras;
|
||||||
|
std::vector<std::string> lora_path_storage;
|
||||||
|
|
||||||
|
if (j.contains("lora") && j["lora"].is_array()) {
|
||||||
|
for (const auto& item : j["lora"]) {
|
||||||
|
if (!item.is_object()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string path = item.value("path", "");
|
||||||
|
float multiplier = item.value("multiplier", 1.0f);
|
||||||
|
bool is_high_noise = item.value("is_high_noise", false);
|
||||||
|
|
||||||
|
if (path.empty()) {
|
||||||
|
return bad("lora.path required");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!is_valid_lora_path(path)) {
|
||||||
|
return bad("invalid lora path: " + path);
|
||||||
|
}
|
||||||
|
|
||||||
|
lora_path_storage.push_back(path);
|
||||||
|
sd_lora_t l;
|
||||||
|
l.is_high_noise = is_high_noise;
|
||||||
|
l.multiplier = multiplier;
|
||||||
|
l.path = lora_path_storage.back().c_str();
|
||||||
|
|
||||||
|
sd_loras.push_back(l);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
auto get_sample_method = [](std::string name) -> enum sample_method_t {
|
auto get_sample_method = [](std::string name) -> enum sample_method_t {
|
||||||
enum sample_method_t result = str_to_sample_method(name.c_str());
|
enum sample_method_t result = str_to_sample_method(name.c_str());
|
||||||
if (result != SAMPLE_METHOD_COUNT) return result;
|
if (result != SAMPLE_METHOD_COUNT) return result;
|
||||||
@ -894,8 +978,8 @@ int main(int argc, const char** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
sd_img_gen_params_t img_gen_params = {
|
sd_img_gen_params_t img_gen_params = {
|
||||||
gen_params.lora_vec.data(),
|
sd_loras.data(),
|
||||||
static_cast<uint32_t>(gen_params.lora_vec.size()),
|
static_cast<uint32_t>(sd_loras.size()),
|
||||||
gen_params.prompt.c_str(),
|
gen_params.prompt.c_str(),
|
||||||
gen_params.negative_prompt.c_str(),
|
gen_params.negative_prompt.c_str(),
|
||||||
gen_params.clip_skip,
|
gen_params.clip_skip,
|
||||||
@ -987,6 +1071,23 @@ int main(int argc, const char** argv) {
|
|||||||
sdapi_any2img(req, res, true);
|
sdapi_any2img(req, res, true);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
svr.Get("/sdapi/v1/loras", [&](const httplib::Request&, httplib::Response& res) {
|
||||||
|
refresh_lora_cache();
|
||||||
|
|
||||||
|
json result = json::array();
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> lock(lora_mutex);
|
||||||
|
for (const auto& e : lora_cache) {
|
||||||
|
json item;
|
||||||
|
item["name"] = e.name;
|
||||||
|
item["path"] = e.path;
|
||||||
|
result.push_back(item);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
res.set_content(result.dump(), "application/json");
|
||||||
|
});
|
||||||
|
|
||||||
svr.Get("/sdapi/v1/samplers", [&](const httplib::Request&, httplib::Response& res) {
|
svr.Get("/sdapi/v1/samplers", [&](const httplib::Request&, httplib::Response& res) {
|
||||||
std::vector<std::string> sampler_names;
|
std::vector<std::string> sampler_names;
|
||||||
sampler_names.push_back("default");
|
sampler_names.push_back("default");
|
||||||
|
|||||||
171
ggml_extend.hpp
171
ggml_extend.hpp
@ -1577,7 +1577,7 @@ struct WeightAdapter {
|
|||||||
bool force_prec_f32 = false;
|
bool force_prec_f32 = false;
|
||||||
float scale = 1.f;
|
float scale = 1.f;
|
||||||
} linear;
|
} linear;
|
||||||
struct {
|
struct conv2d_params_t {
|
||||||
int s0 = 1;
|
int s0 = 1;
|
||||||
int s1 = 1;
|
int s1 = 1;
|
||||||
int p0 = 0;
|
int p0 = 0;
|
||||||
@ -2630,4 +2630,173 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_lokr_forward(
|
||||||
|
struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* h, // Input: [q, batch] or [W, H, q, batch]
|
||||||
|
struct ggml_tensor* w1, // Outer C (Full rank)
|
||||||
|
struct ggml_tensor* w1a, // Outer A (Low rank part 1)
|
||||||
|
struct ggml_tensor* w1b, // Outer B (Low rank part 2)
|
||||||
|
struct ggml_tensor* w2, // Inner BA (Full rank)
|
||||||
|
struct ggml_tensor* w2a, // Inner A (Low rank part 1)
|
||||||
|
struct ggml_tensor* w2b, // Inner B (Low rank part 2)
|
||||||
|
bool is_conv,
|
||||||
|
WeightAdapter::ForwardParams::conv2d_params_t conv_params,
|
||||||
|
float scale) {
|
||||||
|
GGML_ASSERT((w1 != NULL || (w1a != NULL && w1b != NULL)));
|
||||||
|
GGML_ASSERT((w2 != NULL || (w2a != NULL && w2b != NULL)));
|
||||||
|
|
||||||
|
int uq = (w1 != NULL) ? (int)w1->ne[0] : (int)w1a->ne[0];
|
||||||
|
int up = (w1 != NULL) ? (int)w1->ne[1] : (int)w1b->ne[1];
|
||||||
|
|
||||||
|
int q_actual = is_conv ? (int)h->ne[2] : (int)h->ne[0];
|
||||||
|
int vq = q_actual / uq;
|
||||||
|
|
||||||
|
int vp = (w2 != NULL) ? (is_conv ? (int)w2->ne[3] : (int)w2->ne[1])
|
||||||
|
: (int)w2a->ne[1];
|
||||||
|
GGML_ASSERT(q_actual == (uq * vq) && "Input dimension mismatch for LoKR split");
|
||||||
|
|
||||||
|
struct ggml_tensor* hb;
|
||||||
|
|
||||||
|
if (!is_conv) {
|
||||||
|
int batch = (int)h->ne[1];
|
||||||
|
int merge_batch_uq = batch;
|
||||||
|
int merge_batch_vp = batch;
|
||||||
|
|
||||||
|
#if SD_USE_VULKAN
|
||||||
|
if (batch > 1) {
|
||||||
|
// no access to backend here, worst case is slightly worse perfs for other backends when built alongside Vulkan backend
|
||||||
|
int max_batch = 65535;
|
||||||
|
int max_batch_uq = max_batch / uq;
|
||||||
|
merge_batch_uq = 1;
|
||||||
|
for (int i = max_batch_uq; i > 0; i--) {
|
||||||
|
if (batch % i == 0) {
|
||||||
|
merge_batch_uq = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int max_batch_vp = max_batch / vp;
|
||||||
|
merge_batch_vp = 1;
|
||||||
|
for (int i = max_batch_vp; i > 0; i--) {
|
||||||
|
if (batch % i == 0) {
|
||||||
|
merge_batch_vp = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
struct ggml_tensor* h_split = ggml_reshape_3d(ctx, h, vq, uq * merge_batch_uq, batch / merge_batch_uq);
|
||||||
|
if (w2 != NULL) {
|
||||||
|
hb = ggml_mul_mat(ctx, w2, h_split);
|
||||||
|
} else {
|
||||||
|
hb = ggml_mul_mat(ctx, w2b, ggml_mul_mat(ctx, w2a, h_split));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (batch > 1) {
|
||||||
|
hb = ggml_reshape_3d(ctx, hb, vp, uq, batch);
|
||||||
|
}
|
||||||
|
struct ggml_tensor* hb_t = ggml_cont(ctx, ggml_transpose(ctx, hb));
|
||||||
|
hb_t = ggml_reshape_3d(ctx, hb_t, uq, vp * merge_batch_vp, batch / merge_batch_vp);
|
||||||
|
|
||||||
|
struct ggml_tensor* hc_t;
|
||||||
|
if (w1 != NULL) {
|
||||||
|
hc_t = ggml_mul_mat(ctx, w1, hb_t);
|
||||||
|
} else {
|
||||||
|
hc_t = ggml_mul_mat(ctx, w1b, ggml_mul_mat(ctx, w1a, hb_t));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (batch > 1) {
|
||||||
|
hc_t = ggml_reshape_3d(ctx, hc_t, up, vp, batch);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* hc = ggml_transpose(ctx, hc_t);
|
||||||
|
struct ggml_tensor* out = ggml_reshape_2d(ctx, ggml_cont(ctx, hc), up * vp, batch);
|
||||||
|
return ggml_scale(ctx, out, scale);
|
||||||
|
} else {
|
||||||
|
int batch = (int)h->ne[3];
|
||||||
|
// 1. Reshape input: [W, H, vq*uq, batch] -> [W, H, vq, uq * batch]
|
||||||
|
struct ggml_tensor* h_split = ggml_reshape_4d(ctx, h, h->ne[0], h->ne[1], vq, uq * batch);
|
||||||
|
|
||||||
|
if (w2 != NULL) {
|
||||||
|
hb = ggml_ext_conv_2d(ctx, h_split, w2, nullptr,
|
||||||
|
conv_params.s0,
|
||||||
|
conv_params.s1,
|
||||||
|
conv_params.p0,
|
||||||
|
conv_params.p1,
|
||||||
|
conv_params.d0,
|
||||||
|
conv_params.d1,
|
||||||
|
conv_params.direct,
|
||||||
|
conv_params.circular_x,
|
||||||
|
conv_params.circular_y,
|
||||||
|
conv_params.scale);
|
||||||
|
} else {
|
||||||
|
// swap a and b order for conv lora
|
||||||
|
struct ggml_tensor* a = w2b;
|
||||||
|
struct ggml_tensor* b = w2a;
|
||||||
|
|
||||||
|
// unpack conv2d weights if needed
|
||||||
|
if (ggml_n_dims(a) < 4) {
|
||||||
|
int k = (int)sqrt(a->ne[0] / h_split->ne[2]);
|
||||||
|
GGML_ASSERT(k * k * h_split->ne[2] == a->ne[0]);
|
||||||
|
a = ggml_reshape_4d(ctx, a, k, k, a->ne[0] / (k * k), a->ne[1]);
|
||||||
|
} else if (a->ne[2] != h_split->ne[2]) {
|
||||||
|
int k = (int)sqrt(a->ne[2] / h_split->ne[2]);
|
||||||
|
GGML_ASSERT(k * k * h_split->ne[2] == a->ne[2]);
|
||||||
|
a = ggml_reshape_4d(ctx, a, a->ne[0] * k, a->ne[1] * k, a->ne[2] / (k * k), a->ne[3]);
|
||||||
|
}
|
||||||
|
struct ggml_tensor* ha = ggml_ext_conv_2d(ctx, h_split, a, nullptr,
|
||||||
|
conv_params.s0,
|
||||||
|
conv_params.s1,
|
||||||
|
conv_params.p0,
|
||||||
|
conv_params.p1,
|
||||||
|
conv_params.d0,
|
||||||
|
conv_params.d1,
|
||||||
|
conv_params.direct,
|
||||||
|
conv_params.circular_x,
|
||||||
|
conv_params.circular_y,
|
||||||
|
conv_params.scale);
|
||||||
|
|
||||||
|
// not supporting lora_mid here
|
||||||
|
hb = ggml_ext_conv_2d(ctx,
|
||||||
|
ha,
|
||||||
|
b,
|
||||||
|
nullptr,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
conv_params.direct,
|
||||||
|
conv_params.circular_x,
|
||||||
|
conv_params.circular_y,
|
||||||
|
conv_params.scale);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Current hb shape: [W_out, H_out, vp, uq * batch]
|
||||||
|
int w_out = (int)hb->ne[0];
|
||||||
|
int h_out = (int)hb->ne[1];
|
||||||
|
|
||||||
|
// struct ggml_tensor* hb_cat = ggml_reshape_4d(ctx, hb, w_out , h_out , vp * uq, batch);
|
||||||
|
// [W_out, H_out, vp * uq, batch]
|
||||||
|
// Now left to compute (W1 kr Id) * hb_cat == (W1 kr W2) cv h
|
||||||
|
|
||||||
|
// merge the uq groups of size vp*w_out*h_out
|
||||||
|
struct ggml_tensor* hb_merged = ggml_reshape_2d(ctx, hb, w_out * h_out * vp, uq * batch);
|
||||||
|
struct ggml_tensor* hc_t;
|
||||||
|
struct ggml_tensor* hb_merged_t = ggml_cont(ctx, ggml_transpose(ctx, hb_merged));
|
||||||
|
if (w1 != NULL) {
|
||||||
|
// Would be great to be able to transpose w1 instead to avoid transposing both hb and hc
|
||||||
|
hc_t = ggml_mul_mat(ctx, w1, hb_merged_t);
|
||||||
|
} else {
|
||||||
|
hc_t = ggml_mul_mat(ctx, w1b, ggml_mul_mat(ctx, w1a, hb_merged_t));
|
||||||
|
}
|
||||||
|
struct ggml_tensor* hc = ggml_transpose(ctx, hc_t);
|
||||||
|
// ungroup
|
||||||
|
struct ggml_tensor* out = ggml_reshape_4d(ctx, ggml_cont(ctx, hc), w_out, h_out, up * vp, batch);
|
||||||
|
return ggml_scale(ctx, out, scale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#endif // __GGML_EXTEND__HPP__
|
#endif // __GGML_EXTEND__HPP__
|
||||||
|
|||||||
116
lora.hpp
116
lora.hpp
@ -468,10 +468,10 @@ struct LoraModel : public GGMLRunner {
|
|||||||
return updown;
|
return updown;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* get_weight_diff(const std::string& model_tensor_name, ggml_context* ctx, ggml_tensor* model_tensor, bool with_lora = true) {
|
ggml_tensor* get_weight_diff(const std::string& model_tensor_name, ggml_context* ctx, ggml_tensor* model_tensor, bool with_lora_and_lokr = true) {
|
||||||
// lora
|
// lora
|
||||||
ggml_tensor* diff = nullptr;
|
ggml_tensor* diff = nullptr;
|
||||||
if (with_lora) {
|
if (with_lora_and_lokr) {
|
||||||
diff = get_lora_weight_diff(model_tensor_name, ctx);
|
diff = get_lora_weight_diff(model_tensor_name, ctx);
|
||||||
}
|
}
|
||||||
// diff
|
// diff
|
||||||
@ -483,7 +483,7 @@ struct LoraModel : public GGMLRunner {
|
|||||||
diff = get_loha_weight_diff(model_tensor_name, ctx);
|
diff = get_loha_weight_diff(model_tensor_name, ctx);
|
||||||
}
|
}
|
||||||
// lokr
|
// lokr
|
||||||
if (diff == nullptr) {
|
if (diff == nullptr && with_lora_and_lokr) {
|
||||||
diff = get_lokr_weight_diff(model_tensor_name, ctx);
|
diff = get_lokr_weight_diff(model_tensor_name, ctx);
|
||||||
}
|
}
|
||||||
if (diff != nullptr) {
|
if (diff != nullptr) {
|
||||||
@ -514,6 +514,108 @@ struct LoraModel : public GGMLRunner {
|
|||||||
} else {
|
} else {
|
||||||
key = model_tensor_name + "." + std::to_string(index);
|
key = model_tensor_name + "." + std::to_string(index);
|
||||||
}
|
}
|
||||||
|
bool is_conv2d = forward_params.op_type == WeightAdapter::ForwardParams::op_type_t::OP_CONV2D;
|
||||||
|
|
||||||
|
std::string lokr_w1_name = "lora." + key + ".lokr_w1";
|
||||||
|
std::string lokr_w1_a_name = "lora." + key + ".lokr_w1_a";
|
||||||
|
// if either of these is found, then we have a lokr lora
|
||||||
|
auto iter = lora_tensors.find(lokr_w1_name);
|
||||||
|
auto iter_a = lora_tensors.find(lokr_w1_a_name);
|
||||||
|
if (iter != lora_tensors.end() || iter_a != lora_tensors.end()) {
|
||||||
|
std::string lokr_w1_b_name = "lora." + key + ".lokr_w1_b";
|
||||||
|
std::string lokr_w2_name = "lora." + key + ".lokr_w2";
|
||||||
|
std::string lokr_w2_a_name = "lora." + key + ".lokr_w2_a";
|
||||||
|
std::string lokr_w2_b_name = "lora." + key + ".lokr_w2_b";
|
||||||
|
std::string alpha_name = "lora." + key + ".alpha";
|
||||||
|
|
||||||
|
ggml_tensor* lokr_w1 = nullptr;
|
||||||
|
ggml_tensor* lokr_w1_a = nullptr;
|
||||||
|
ggml_tensor* lokr_w1_b = nullptr;
|
||||||
|
ggml_tensor* lokr_w2 = nullptr;
|
||||||
|
ggml_tensor* lokr_w2_a = nullptr;
|
||||||
|
ggml_tensor* lokr_w2_b = nullptr;
|
||||||
|
|
||||||
|
if (iter != lora_tensors.end()) {
|
||||||
|
lokr_w1 = iter->second;
|
||||||
|
}
|
||||||
|
iter = iter_a;
|
||||||
|
if (iter != lora_tensors.end()) {
|
||||||
|
lokr_w1_a = iter->second;
|
||||||
|
}
|
||||||
|
iter = lora_tensors.find(lokr_w1_b_name);
|
||||||
|
if (iter != lora_tensors.end()) {
|
||||||
|
lokr_w1_b = iter->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
iter = lora_tensors.find(lokr_w2_name);
|
||||||
|
if (iter != lora_tensors.end()) {
|
||||||
|
lokr_w2 = iter->second;
|
||||||
|
if (is_conv2d && lokr_w2->type != GGML_TYPE_F16) {
|
||||||
|
lokr_w2 = ggml_cast(ctx, lokr_w2, GGML_TYPE_F16);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
iter = lora_tensors.find(lokr_w2_a_name);
|
||||||
|
if (iter != lora_tensors.end()) {
|
||||||
|
lokr_w2_a = iter->second;
|
||||||
|
if (is_conv2d && lokr_w2_a->type != GGML_TYPE_F16) {
|
||||||
|
lokr_w2_a = ggml_cast(ctx, lokr_w2_a, GGML_TYPE_F16);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
iter = lora_tensors.find(lokr_w2_b_name);
|
||||||
|
if (iter != lora_tensors.end()) {
|
||||||
|
lokr_w2_b = iter->second;
|
||||||
|
if (is_conv2d && lokr_w2_b->type != GGML_TYPE_F16) {
|
||||||
|
lokr_w2_b = ggml_cast(ctx, lokr_w2_b, GGML_TYPE_F16);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int rank = 1;
|
||||||
|
if (lokr_w1_b) {
|
||||||
|
rank = (int)lokr_w1_b->ne[ggml_n_dims(lokr_w1_b) - 1];
|
||||||
|
}
|
||||||
|
if (lokr_w2_b) {
|
||||||
|
rank = (int)lokr_w2_b->ne[ggml_n_dims(lokr_w2_b) - 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
float scale_value = 1.0f;
|
||||||
|
iter = lora_tensors.find(alpha_name);
|
||||||
|
if (iter != lora_tensors.end()) {
|
||||||
|
float alpha = ggml_ext_backend_tensor_get_f32(iter->second);
|
||||||
|
scale_value = alpha / rank;
|
||||||
|
applied_lora_tensors.insert(alpha_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (rank == 1) {
|
||||||
|
scale_value = 1.0f;
|
||||||
|
}
|
||||||
|
scale_value *= multiplier;
|
||||||
|
|
||||||
|
auto curr_out_diff = ggml_ext_lokr_forward(ctx, x, lokr_w1, lokr_w1_a, lokr_w1_b, lokr_w2, lokr_w2_a, lokr_w2_b, is_conv2d, forward_params.conv2d, scale_value);
|
||||||
|
if (out_diff == nullptr) {
|
||||||
|
out_diff = curr_out_diff;
|
||||||
|
} else {
|
||||||
|
out_diff = ggml_concat(ctx, out_diff, curr_out_diff, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (lokr_w1)
|
||||||
|
applied_lora_tensors.insert(lokr_w1_name);
|
||||||
|
if (lokr_w1_a)
|
||||||
|
applied_lora_tensors.insert(lokr_w1_a_name);
|
||||||
|
if (lokr_w1_b)
|
||||||
|
applied_lora_tensors.insert(lokr_w1_b_name);
|
||||||
|
if (lokr_w2)
|
||||||
|
applied_lora_tensors.insert(lokr_w2_name);
|
||||||
|
if (lokr_w2_a)
|
||||||
|
applied_lora_tensors.insert(lokr_w2_name);
|
||||||
|
if (lokr_w2_b)
|
||||||
|
applied_lora_tensors.insert(lokr_w2_b_name);
|
||||||
|
applied_lora_tensors.insert(alpha_name);
|
||||||
|
|
||||||
|
index++;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// not a lokr, normal lora path
|
||||||
|
|
||||||
std::string lora_down_name = "lora." + key + ".lora_down";
|
std::string lora_down_name = "lora." + key + ".lora_down";
|
||||||
std::string lora_up_name = "lora." + key + ".lora_up";
|
std::string lora_up_name = "lora." + key + ".lora_up";
|
||||||
@ -525,9 +627,7 @@ struct LoraModel : public GGMLRunner {
|
|||||||
ggml_tensor* lora_mid = nullptr;
|
ggml_tensor* lora_mid = nullptr;
|
||||||
ggml_tensor* lora_down = nullptr;
|
ggml_tensor* lora_down = nullptr;
|
||||||
|
|
||||||
bool is_conv2d = forward_params.op_type == WeightAdapter::ForwardParams::op_type_t::OP_CONV2D;
|
iter = lora_tensors.find(lora_up_name);
|
||||||
|
|
||||||
auto iter = lora_tensors.find(lora_up_name);
|
|
||||||
if (iter != lora_tensors.end()) {
|
if (iter != lora_tensors.end()) {
|
||||||
lora_up = iter->second;
|
lora_up = iter->second;
|
||||||
if (is_conv2d && lora_up->type != GGML_TYPE_F16) {
|
if (is_conv2d && lora_up->type != GGML_TYPE_F16) {
|
||||||
@ -741,9 +841,9 @@ public:
|
|||||||
: lora_models(lora_models) {
|
: lora_models(lora_models) {
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* patch_weight(ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name, bool with_lora) {
|
ggml_tensor* patch_weight(ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name, bool with_lora_and_lokr) {
|
||||||
for (auto& lora_model : lora_models) {
|
for (auto& lora_model : lora_models) {
|
||||||
ggml_tensor* diff = lora_model->get_weight_diff(weight_name, ctx, weight, with_lora);
|
ggml_tensor* diff = lora_model->get_weight_diff(weight_name, ctx, weight, with_lora_and_lokr);
|
||||||
if (diff == nullptr) {
|
if (diff == nullptr) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user