mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-12 13:28:37 +00:00
feat: support for --tensor-type-rules on generation modes (#932)
parent 742a7333c3
commit 199e675cc7
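Summary: this change lets --tensor-type-rules take effect in the generation modes instead of only in convert mode. A rule string is a comma-separated list of <regex>=<ggml type> entries; each tensor whose name matches a rule's regex is loaded with the requested type, and tensors matching no rule fall back to the global --type setting. For example (the tensor-name patterns here are illustrative, not taken from this commit), --tensor-type-rules "attn.*weight=q8_0,ffn=f16" would load matching attention weights as q8_0 and matching feed-forward weights as f16. The CLI warning that restricted the flag to conversion is removed, the rule string is threaded through sd_ctx_params_t, and ModelLoader::set_wtype_override() now takes the rule string instead of a name prefix.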
@@ -1241,10 +1241,6 @@ void parse_args(int argc, const char** argv, SDParams& params) {
         exit(1);
     }
 
-    if (params.mode != CONVERT && params.tensor_type_rules.size() > 0) {
-        fprintf(stderr, "warning: --tensor-type-rules is currently supported only for conversion\n");
-    }
-
     if (params.mode == VID_GEN && params.video_frames <= 0) {
         fprintf(stderr, "warning: --video-frames must be at least 1\n");
         exit(1);

@@ -1756,6 +1752,7 @@ int main(int argc, const char* argv[]) {
         params.lora_model_dir.c_str(),
         params.embedding_dir.c_str(),
         params.photo_maker_path.c_str(),
+        params.tensor_type_rules.c_str(),
         vae_decode_only,
         true,
         params.n_threads,

87 model.cpp
@@ -1254,15 +1254,59 @@ std::map<ggml_type, uint32_t> ModelLoader::get_vae_wtype_stat() {
     return wtype_stat;
 }
 
-void ModelLoader::set_wtype_override(ggml_type wtype, std::string prefix) {
+static std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std::string& tensor_type_rules) {
+    std::vector<std::pair<std::string, ggml_type>> result;
+    for (const auto& item : split_string(tensor_type_rules, ',')) {
+        if (item.size() == 0)
+            continue;
+        std::string::size_type pos = item.find('=');
+        if (pos == std::string::npos) {
+            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
+            continue;
+        }
+        std::string tensor_pattern = item.substr(0, pos);
+        std::string type_name      = item.substr(pos + 1);
+
+        ggml_type tensor_type = GGML_TYPE_COUNT;
+
+        if (type_name == "f32") {
+            tensor_type = GGML_TYPE_F32;
+        } else {
+            for (size_t i = 0; i < GGML_TYPE_COUNT; i++) {
+                auto trait = ggml_get_type_traits((ggml_type)i);
+                if (trait->to_float && trait->type_size && type_name == trait->type_name) {
+                    tensor_type = (ggml_type)i;
+                }
+            }
+        }
+
+        if (tensor_type != GGML_TYPE_COUNT) {
+            result.emplace_back(tensor_pattern, tensor_type);
+        } else {
+            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
+        }
+    }
+    return result;
+}
+
+void ModelLoader::set_wtype_override(ggml_type wtype, std::string tensor_type_rules) {
+    auto map_rules = parse_tensor_type_rules(tensor_type_rules);
     for (auto& [name, tensor_storage] : tensor_storage_map) {
-        if (!starts_with(name, prefix)) {
+        ggml_type dst_type = wtype;
+        for (const auto& tensor_type_rule : map_rules) {
+            std::regex pattern(tensor_type_rule.first);
+            if (std::regex_search(name, pattern)) {
+                dst_type = tensor_type_rule.second;
+                break;
+            }
+        }
+        if (dst_type == GGML_TYPE_COUNT) {
             continue;
         }
-        if (!tensor_should_be_converted(tensor_storage, wtype)) {
+        if (!tensor_should_be_converted(tensor_storage, dst_type)) {
             continue;
         }
-        tensor_storage.expected_type = wtype;
+        tensor_storage.expected_type = dst_type;
     }
 }
 
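For reference, a minimal standalone sketch of the matching behaviour the new code implements: parse "<regex>=<type>" rules and let the first matching rule pick a tensor's destination type, falling back to the global choice otherwise. Plain strings stand in for ggml_type, split_csv stands in for the project's split_string helper, and the tensor names and rule patterns are illustrative only.

#include <iostream>
#include <regex>
#include <string>
#include <utility>
#include <vector>

// Minimal comma splitter, standing in for split_string() from the project.
static std::vector<std::string> split_csv(const std::string& s) {
    std::vector<std::string> out;
    std::string cur;
    for (char c : s) {
        if (c == ',') {
            out.push_back(cur);
            cur.clear();
        } else {
            cur += c;
        }
    }
    out.push_back(cur);
    return out;
}

int main() {
    const std::string rules_str = "attn.*weight=q8_0,ffn=f16";  // illustrative rules
    const std::string fallback  = "q4_0";                       // stands in for the global wtype

    // Parse "pattern=type" pairs, skipping malformed items (mirrors parse_tensor_type_rules).
    std::vector<std::pair<std::string, std::string>> rules;
    for (const auto& item : split_csv(rules_str)) {
        auto pos = item.find('=');
        if (item.empty() || pos == std::string::npos)
            continue;
        rules.emplace_back(item.substr(0, pos), item.substr(pos + 1));
    }

    // First matching rule wins; tensors with no matching rule keep the fallback type.
    for (const std::string name : {"model.attn_q.weight", "model.ffn_up.weight", "model.norm.weight"}) {
        std::string dst = fallback;
        for (const auto& rule : rules) {
            if (std::regex_search(name, std::regex(rule.first))) {
                dst = rule.second;
                break;
            }
        }
        std::cout << name << " -> " << dst << "\n";
    }
    return 0;
}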
@@ -1603,41 +1647,6 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
     return true;
 }
 
-std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std::string& tensor_type_rules) {
-    std::vector<std::pair<std::string, ggml_type>> result;
-    for (const auto& item : split_string(tensor_type_rules, ',')) {
-        if (item.size() == 0)
-            continue;
-        std::string::size_type pos = item.find('=');
-        if (pos == std::string::npos) {
-            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
-            continue;
-        }
-        std::string tensor_pattern = item.substr(0, pos);
-        std::string type_name = item.substr(pos + 1);
-
-        ggml_type tensor_type = GGML_TYPE_COUNT;
-
-        if (type_name == "f32") {
-            tensor_type = GGML_TYPE_F32;
-        } else {
-            for (size_t i = 0; i < GGML_TYPE_COUNT; i++) {
-                auto trait = ggml_get_type_traits((ggml_type)i);
-                if (trait->to_float && trait->type_size && type_name == trait->type_name) {
-                    tensor_type = (ggml_type)i;
-                }
-            }
-        }
-
-        if (tensor_type != GGML_TYPE_COUNT) {
-            result.emplace_back(tensor_pattern, tensor_type);
-        } else {
-            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
-        }
-    }
-    return result;
-}
-
 bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type) {
     const std::string& name = tensor_storage.name;
     if (type != GGML_TYPE_COUNT) {

2 model.h
@@ -292,7 +292,7 @@ public:
     std::map<ggml_type, uint32_t> get_diffusion_model_wtype_stat();
     std::map<ggml_type, uint32_t> get_vae_wtype_stat();
     String2TensorStorage& get_tensor_storage_map() { return tensor_storage_map; }
-    void set_wtype_override(ggml_type wtype, std::string prefix = "");
+    void set_wtype_override(ggml_type wtype, std::string tensor_type_rules = "");
     bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0);
     bool load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
                       std::set<std::string> ignore_tensors = {},

@@ -307,8 +307,9 @@
     ggml_type wtype = (int)sd_ctx_params->wtype < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT)
                           ? (ggml_type)sd_ctx_params->wtype
                           : GGML_TYPE_COUNT;
-    if (wtype != GGML_TYPE_COUNT) {
-        model_loader.set_wtype_override(wtype);
+    std::string tensor_type_rules = SAFE_STR(sd_ctx_params->tensor_type_rules);
+    if (wtype != GGML_TYPE_COUNT || tensor_type_rules.size() > 0) {
+        model_loader.set_wtype_override(wtype, tensor_type_rules);
     }
 
     std::map<ggml_type, uint32_t> wtype_stat = model_loader.get_wtype_stat();

@@ -2325,6 +2326,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
         "lora_model_dir: %s\n"
         "embedding_dir: %s\n"
         "photo_maker_path: %s\n"
+        "tensor_type_rules: %s\n"
         "vae_decode_only: %s\n"
         "free_params_immediately: %s\n"
         "n_threads: %d\n"

@@ -2354,6 +2356,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
         SAFE_STR(sd_ctx_params->lora_model_dir),
         SAFE_STR(sd_ctx_params->embedding_dir),
         SAFE_STR(sd_ctx_params->photo_maker_path),
+        SAFE_STR(sd_ctx_params->tensor_type_rules),
         BOOL_STR(sd_ctx_params->vae_decode_only),
         BOOL_STR(sd_ctx_params->free_params_immediately),
         sd_ctx_params->n_threads,

@@ -167,6 +167,7 @@ typedef struct {
     const char* lora_model_dir;
    const char* embedding_dir;
     const char* photo_maker_path;
+    const char* tensor_type_rules;
     bool vae_decode_only;
     bool free_params_immediately;
     int n_threads;
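From the caller's side, the new field slots in between photo_maker_path and vae_decode_only. A rough sketch of how an API user might set it: field names follow the header hunk above, but the include, the function name example_params, the rule patterns, and the other fields a real setup needs are assumptions to check against stable-diffusion.h. Passing a wtype at or beyond SD_TYPE_COUNT maps to "no global override" in the new_sd_ctx change above, so rules alone are now enough to trigger the override path.

#include "stable-diffusion.h"  // public C API header of this repo (assumed include name)

// Sketch only: model_path, n_threads and the rest of the setup are omitted.
sd_ctx_params_t example_params() {
    sd_ctx_params_t p = {};
    p.wtype             = SD_TYPE_COUNT;                // no blanket override: tensors keep their stored type...
    p.tensor_type_rules = "attn.*weight=q8_0,ffn=f16";  // ...except where a rule matches (patterns illustrative)
    return p;
}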