Mirror of https://github.com/leejet/stable-diffusion.cpp.git
commit 29ec31644a
parent e91ce4f103

    add support for applying lora to quantized tensors

 lora.hpp | 31
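This change lets LoraModel merge a LoRA delta into weights stored in quantized formats. Previously the scaled delta (updown) was added into the weight tensor in place regardless of its type; now, when the weight is neither F32 nor F16, it is first expanded to F32 with a ggml_get_rows trick (get_rows dequantizes the rows it gathers), the delta is added in F32, and ggml_cpy writes the result back into the original tensor, re-quantizing it. A minimal sketch of the pattern, assuming a valid ggml_context and an F32 delta with the same shape as the weight; the function name and layout here are illustrative, not the repository's code:

    #include "ggml.h"

    // Illustrative sketch only (not code from the repository): merging a scaled LoRA
    // delta into a possibly-quantized weight inside a ggml graph.
    // `zero_index` is a one-element I32 tensor containing 0.
    static ggml_tensor* merge_lora_delta(ggml_context* ctx,
                                         ggml_tensor* weight,      // F32, F16 or quantized (e.g. Q4_0)
                                         ggml_tensor* delta,       // scaled LoRA update, F32, same shape as weight
                                         ggml_tensor* zero_index)  // I32 index tensor holding {0}
    {
        if (weight->type != GGML_TYPE_F32 && weight->type != GGML_TYPE_F16) {
            // ggml_get_rows dequantizes the rows it gathers, so viewing the whole tensor as a
            // single row and fetching row 0 produces an F32 copy of the quantized data.
            ggml_tensor* w = ggml_reshape_1d(ctx, weight, ggml_nelements(weight));
            w = ggml_get_rows(ctx, w, zero_index);
            w = ggml_reshape(ctx, w, weight);
            w = ggml_add_inplace(ctx, w, delta);
            // ggml_cpy converts to the destination tensor's type, i.e. it re-quantizes the
            // merged result back into the original weight storage.
            return ggml_cpy(ctx, w, weight);
        }
        // F32/F16 weights can take the delta in place, as before.
        return ggml_add_inplace(ctx, weight, delta);
    }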
@@ -12,6 +12,8 @@ struct LoraModel : public GGMLRunner {
     ModelLoader model_loader;
     bool load_failed = false;
     bool applied = false;
+    std::vector<int> zero_index_vec = {0};
+    ggml_tensor* zero_index = NULL;
 
     LoraModel(ggml_backend_t backend,
               ggml_type wtype,
@@ -68,9 +70,19 @@ struct LoraModel : public GGMLRunner {
         return true;
     }
 
+    ggml_tensor* to_f32(ggml_context* ctx, ggml_tensor* a) {
+        auto out = ggml_reshape_1d(ctx, a, ggml_nelements(a));
+        out = ggml_get_rows(ctx, out, zero_index);
+        out = ggml_reshape(ctx, out, a);
+        return out;
+    }
+
     struct ggml_cgraph* build_lora_graph(std::map<std::string, struct ggml_tensor*> model_tensors) {
         struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, LORA_GRAPH_SIZE, false);
 
+        zero_index = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_I32, 1);
+        set_backend_tensor_data(zero_index, zero_index_vec.data());
+
         std::set<std::string> applied_lora_tensors;
         for (auto it : model_tensors) {
             std::string k_tensor = it.first;
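Note: ggml_get_rows takes its row indices from a separate I32 tensor, which is why the runner keeps zero_index_vec = {0} alive as a member and binds it to a freshly created one-element zero_index tensor each time the graph is built. A rough sketch of the same setup using the generic ggml-backend upload call; ggml_backend_tensor_set is an assumption about the equivalent of the set_backend_tensor_data helper used in this diff:

    #include "ggml.h"
    #include "ggml-backend.h"

    // Illustrative sketch: a one-element I32 index tensor whose single value is 0,
    // so ggml_get_rows can gather "row 0", i.e. the whole flattened weight.
    // Assumes the tensor is backed by an allocated backend buffer before the upload.
    static ggml_tensor* make_zero_index(ggml_context* ctx) {
        static const int32_t zero = 0;  // index data must stay alive until the graph has run
        ggml_tensor* idx = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
        ggml_backend_tensor_set(idx, &zero, 0, sizeof(zero));  // upload {0} into the backend buffer
        return idx;
    }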
@@ -141,15 +153,16 @@ struct LoraModel : public GGMLRunner {
             GGML_ASSERT(ggml_nelements(updown) == ggml_nelements(weight));
             updown = ggml_scale_inplace(compute_ctx, updown, scale_value);
             ggml_tensor* final_weight;
-            // if (weight->type != GGML_TYPE_F32 && weight->type != GGML_TYPE_F16) {
-            //     final_weight = ggml_new_tensor(compute_ctx, GGML_TYPE_F32, weight->n_dims, weight->ne);
-            //     final_weight = ggml_cpy_inplace(compute_ctx, weight, final_weight);
-            //     final_weight = ggml_add_inplace(compute_ctx, final_weight, updown);
-            //     final_weight = ggml_cpy_inplace(compute_ctx, final_weight, weight);
-            // } else {
-            //     final_weight = ggml_add_inplace(compute_ctx, weight, updown);
-            // }
-            final_weight = ggml_add_inplace(compute_ctx, weight, updown);  // apply directly
+            if (weight->type != GGML_TYPE_F32 && weight->type != GGML_TYPE_F16) {
+                // final_weight = ggml_new_tensor(compute_ctx, GGML_TYPE_F32, ggml_n_dims(weight), weight->ne);
+                // final_weight = ggml_cpy(compute_ctx, weight, final_weight);
+                final_weight = to_f32(compute_ctx, weight);
+                final_weight = ggml_add_inplace(compute_ctx, final_weight, updown);
+                final_weight = ggml_cpy(compute_ctx, final_weight, weight);
+            } else {
+                final_weight = ggml_add_inplace(compute_ctx, weight, updown);
+            }
+            // final_weight = ggml_add_inplace(compute_ctx, weight, updown);  // apply directly
             ggml_build_forward_expand(gf, final_weight);
         }
 
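Note: because the merged F32 result is copied back into the original weight tensor, the weight keeps its quantized type, so memory use is unchanged, but the applied LoRA is subject to the re-quantization error of that format.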