Mirror of https://github.com/leejet/stable-diffusion.cpp.git
add high noise lora support
commit eb3fed8b52 (parent 6de680a94c)
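Summary of the change: prefixing a LoRA name with the |high_noise| tag now strips the tag before the file lookup and applies the LoRA's tensors under the model.high_noise_ prefix, so the same .safetensors/.ckpt file can target a model's high-noise weights. Supporting changes: ggml_backend_tensor_get_f32 (in ggml_extend.hpp) additionally accepts I32 tensors, convert_tensor_name (in model.cpp) now recognizes .alpha tensors, and the applied_lora_tensors bookkeeping in lora.hpp only records tensors that are actually present in the file.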
ggml_extend.hpp

@@ -1113,14 +1113,18 @@ __STATIC_INLINE__ void ggml_backend_tensor_get_and_sync(ggml_backend_t backend,
 }
 
 __STATIC_INLINE__ float ggml_backend_tensor_get_f32(ggml_tensor* tensor) {
-    GGML_ASSERT(tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_F16);
+    GGML_ASSERT(tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_I32);
     float value;
     if (tensor->type == GGML_TYPE_F32) {
         ggml_backend_tensor_get(tensor, &value, 0, sizeof(value));
-    } else { // GGML_TYPE_F16
+    } else if (tensor->type == GGML_TYPE_F16) {
         ggml_fp16_t f16_value;
         ggml_backend_tensor_get(tensor, &f16_value, 0, sizeof(f16_value));
         value = ggml_fp16_to_fp32(f16_value);
+    } else { // GGML_TYPE_I32
+        int int32_value;
+        ggml_backend_tensor_get(tensor, &int32_value, 0, sizeof(int32_value));
+        value = (float)int32_value;
     }
     return value;
 }
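A minimal usage sketch of the extended helper (assumptions: ggml's CPU backend API as bundled with this project — exact headers depend on the bundled ggml revision — and the helper being visible via ggml_extend.hpp). An I32 scalar, such as a LoRA alpha stored as int32, now reads back as a float instead of tripping the assert:

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml_extend.hpp"  // ggml_backend_tensor_get_f32
#include <cstdio>

int main() {
    ggml_backend_t backend = ggml_backend_cpu_init();

    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,  // tensor data will live in the backend buffer
    };
    struct ggml_context* ctx = ggml_init(params);

    // A one-element I32 tensor, e.g. a LoRA alpha stored as int32.
    ggml_tensor* t            = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

    int32_t alpha = 16;
    ggml_backend_tensor_set(t, &alpha, 0, sizeof(alpha));

    // Before this commit the helper asserted on I32; now it converts.
    printf("alpha = %.1f\n", ggml_backend_tensor_get_f32(t));  // prints 16.0

    ggml_backend_buffer_free(buf);
    ggml_free(ctx);
    ggml_backend_free(backend);
    return 0;
}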
lora.hpp (12 changed lines)
@@ -130,7 +130,7 @@ struct LoraModel : public GGMLRunner {
                 // LOG_INFO("skipping LoRA tesnor '%s'", name.c_str());
                 return true;
             }
-            // LOG_INFO("%s", name.c_str());
+            // LOG_INFO("lora_tensor %s", name.c_str());
             for (int i = 0; i < LORA_TYPE_COUNT; i++) {
                 if (name.find(type_fingerprints[i]) != std::string::npos) {
                     type = (lora_t)i;
@@ -781,21 +781,18 @@ struct LoraModel : public GGMLRunner {
 
             if (lora_tensors.find(lora_up_name) != lora_tensors.end()) {
                 lora_up = to_f32(compute_ctx, lora_tensors[lora_up_name]);
+                applied_lora_tensors.insert(lora_up_name);
             }
 
             if (lora_tensors.find(lora_down_name) != lora_tensors.end()) {
                 lora_down = to_f32(compute_ctx, lora_tensors[lora_down_name]);
+                applied_lora_tensors.insert(lora_down_name);
             }
 
             if (lora_tensors.find(lora_mid_name) != lora_tensors.end()) {
                 lora_mid = to_f32(compute_ctx, lora_tensors[lora_mid_name]);
                 applied_lora_tensors.insert(lora_mid_name);
             }
-
-            applied_lora_tensors.insert(lora_up_name);
-            applied_lora_tensors.insert(lora_down_name);
-            applied_lora_tensors.insert(alpha_name);
-            applied_lora_tensors.insert(scale_name);
         }
 
         if (lora_up == NULL || lora_down == NULL) {
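The point of this hunk: each name is now recorded in applied_lora_tensors only inside the branch that actually found the tensor, instead of unconditionally afterwards, so names absent from the file are no longer counted as applied; alpha_name and scale_name move to the hunk below, next to where they are read.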
@@ -806,9 +803,12 @@ struct LoraModel : public GGMLRunner {
             int64_t rank = lora_down->ne[ggml_n_dims(lora_down) - 1];
             if (lora_tensors.find(scale_name) != lora_tensors.end()) {
                 scale_value = ggml_backend_tensor_get_f32(lora_tensors[scale_name]);
+                applied_lora_tensors.insert(scale_name);
             } else if (lora_tensors.find(alpha_name) != lora_tensors.end()) {
                 float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]);
                 scale_value = alpha / rank;
+                // LOG_DEBUG("rank %s %ld %.2f %.2f", alpha_name.c_str(), rank, alpha, scale_value);
+                applied_lora_tensors.insert(alpha_name);
             }
 
             updown = ggml_merge_lora(compute_ctx, lora_down, lora_up, lora_mid);
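The scale logic itself is unchanged: an explicit scale tensor wins, otherwise scale_value = alpha / rank (for example, alpha = 16 at rank 16 gives scale_value = 1.0), and this factor scales the merged up·down delta; the hunk only moves the bookkeeping next to the reads and adds a debug log.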
model.cpp

@@ -607,7 +607,7 @@ std::string convert_tensor_name(std::string name) {
         new_name = "lora." + name;
     } else if (contains(name, "lora_up") || contains(name, "lora_down") ||
                contains(name, "lora.up") || contains(name, "lora.down") ||
-               contains(name, "lora_linear")) {
+               contains(name, "lora_linear") || ends_with(name, ".alpha")) {
         size_t pos = new_name.find(".processor");
         if (pos != std::string::npos) {
             new_name.replace(pos, strlen(".processor"), "");

@@ -615,7 +615,11 @@ std::string convert_tensor_name(std::string name) {
         // if (starts_with(new_name, "transformer.transformer_blocks") || starts_with(new_name, "transformer.single_transformer_blocks")) {
         //     new_name = "model.diffusion_model." + new_name;
         // }
-        pos = new_name.rfind("lora");
+        if (ends_with(name, ".alpha")) {
+            pos = new_name.rfind("alpha");
+        } else {
+            pos = new_name.rfind("lora");
+        }
         if (pos != std::string::npos) {
             std::string name_without_network_parts = new_name.substr(0, pos - 1);
             std::string network_part = new_name.substr(pos);
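A standalone illustration of the new split (the tensor name is hypothetical, and convert_tensor_name goes on to remap both parts afterwards): for an .alpha tensor the name is now split at "alpha" rather than at "lora", so the network part comes out as "alpha" instead of swallowing the whole name.

#include <cassert>
#include <string>

int main() {
    std::string new_name = "lora_unet_blocks_0_attn_to_q.alpha";  // hypothetical .alpha tensor
    size_t pos = new_name.rfind("alpha");                         // the new ".alpha" branch
    assert(pos != std::string::npos);
    std::string name_without_network_parts = new_name.substr(0, pos - 1);  // drops the '.'
    std::string network_part               = new_name.substr(pos);
    assert(name_without_network_parts == "lora_unet_blocks_0_attn_to_q");
    assert(network_part == "alpha");
    return 0;
}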
stable-diffusion.cpp

@@ -771,8 +771,15 @@ public:
         return result < -1;
     }
 
-    void apply_lora(const std::string& lora_name, float multiplier) {
+    void apply_lora(std::string lora_name, float multiplier) {
         int64_t t0 = ggml_time_ms();
+        std::string high_noise_tag = "|high_noise|";
+        bool is_high_noise = false;
+        if (starts_with(lora_name, high_noise_tag)) {
+            lora_name = lora_name.substr(high_noise_tag.size());
+            is_high_noise = true;
+            LOG_DEBUG("high noise lora: %s", lora_name.c_str());
+        }
         std::string st_file_path = path_join(lora_model_dir, lora_name + ".safetensors");
         std::string ckpt_file_path = path_join(lora_model_dir, lora_name + ".ckpt");
         std::string file_path;

@@ -784,7 +791,7 @@ public:
             LOG_WARN("can not find %s or %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str());
             return;
         }
-        LoraModel lora(backend, file_path);
+        LoraModel lora(backend, file_path, is_high_noise ? "model.high_noise_" : "");
         if (!lora.load_from_file()) {
             LOG_WARN("load lora tensors from %s failed", file_path.c_str());
             return;
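A standalone sketch of the new tag handling (string logic only; starts_with mirrors the project helper, and the prefix shown is what LoraModel now receives as its third constructor argument):

#include <cstdio>
#include <string>

static bool starts_with(const std::string& s, const std::string& prefix) {
    return s.compare(0, prefix.size(), prefix) == 0;
}

int main() {
    std::string lora_name = "|high_noise|some_lora";  // hypothetical tagged name
    const std::string high_noise_tag = "|high_noise|";

    bool is_high_noise = false;
    if (starts_with(lora_name, high_noise_tag)) {
        lora_name = lora_name.substr(high_noise_tag.size());  // strip tag before file lookup
        is_high_noise = true;
    }

    // The prefix below is what apply_lora passes to LoraModel.
    const char* prefix = is_high_noise ? "model.high_noise_" : "";
    printf("file: %s.safetensors, tensor prefix: \"%s\"\n", lora_name.c_str(), prefix);
    return 0;
}

So apply_lora("|high_noise|some_lora", multiplier) loads the same some_lora.safetensors as the untagged name, but applies its tensors under the model.high_noise_ prefix.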