feat: support for --tensor-type-rules on generation modes (#932)

This commit is contained in:
Wagner Bruna 2025-11-16 06:07:32 -03:00 committed by GitHub
parent 742a7333c3
commit 199e675cc7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 59 additions and 49 deletions

View File

@ -1241,10 +1241,6 @@ void parse_args(int argc, const char** argv, SDParams& params) {
exit(1);
}
if (params.mode != CONVERT && params.tensor_type_rules.size() > 0) {
fprintf(stderr, "warning: --tensor-type-rules is currently supported only for conversion\n");
}
if (params.mode == VID_GEN && params.video_frames <= 0) {
fprintf(stderr, "warning: --video-frames must be at least 1\n");
exit(1);
@ -1756,6 +1752,7 @@ int main(int argc, const char* argv[]) {
params.lora_model_dir.c_str(),
params.embedding_dir.c_str(),
params.photo_maker_path.c_str(),
params.tensor_type_rules.c_str(),
vae_decode_only,
true,
params.n_threads,

View File

@ -1254,15 +1254,59 @@ std::map<ggml_type, uint32_t> ModelLoader::get_vae_wtype_stat() {
return wtype_stat;
}
void ModelLoader::set_wtype_override(ggml_type wtype, std::string prefix) {
static std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std::string& tensor_type_rules) {
std::vector<std::pair<std::string, ggml_type>> result;
for (const auto& item : split_string(tensor_type_rules, ',')) {
if (item.size() == 0)
continue;
std::string::size_type pos = item.find('=');
if (pos == std::string::npos) {
LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
continue;
}
std::string tensor_pattern = item.substr(0, pos);
std::string type_name = item.substr(pos + 1);
ggml_type tensor_type = GGML_TYPE_COUNT;
if (type_name == "f32") {
tensor_type = GGML_TYPE_F32;
} else {
for (size_t i = 0; i < GGML_TYPE_COUNT; i++) {
auto trait = ggml_get_type_traits((ggml_type)i);
if (trait->to_float && trait->type_size && type_name == trait->type_name) {
tensor_type = (ggml_type)i;
}
}
}
if (tensor_type != GGML_TYPE_COUNT) {
result.emplace_back(tensor_pattern, tensor_type);
} else {
LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
}
}
return result;
}
void ModelLoader::set_wtype_override(ggml_type wtype, std::string tensor_type_rules) {
auto map_rules = parse_tensor_type_rules(tensor_type_rules);
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (!starts_with(name, prefix)) {
ggml_type dst_type = wtype;
for (const auto& tensor_type_rule : map_rules) {
std::regex pattern(tensor_type_rule.first);
if (std::regex_search(name, pattern)) {
dst_type = tensor_type_rule.second;
break;
}
}
if (dst_type == GGML_TYPE_COUNT) {
continue;
}
if (!tensor_should_be_converted(tensor_storage, wtype)) {
if (!tensor_should_be_converted(tensor_storage, dst_type)) {
continue;
}
tensor_storage.expected_type = wtype;
tensor_storage.expected_type = dst_type;
}
}
@ -1603,41 +1647,6 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tenso
return true;
}
std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std::string& tensor_type_rules) {
std::vector<std::pair<std::string, ggml_type>> result;
for (const auto& item : split_string(tensor_type_rules, ',')) {
if (item.size() == 0)
continue;
std::string::size_type pos = item.find('=');
if (pos == std::string::npos) {
LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
continue;
}
std::string tensor_pattern = item.substr(0, pos);
std::string type_name = item.substr(pos + 1);
ggml_type tensor_type = GGML_TYPE_COUNT;
if (type_name == "f32") {
tensor_type = GGML_TYPE_F32;
} else {
for (size_t i = 0; i < GGML_TYPE_COUNT; i++) {
auto trait = ggml_get_type_traits((ggml_type)i);
if (trait->to_float && trait->type_size && type_name == trait->type_name) {
tensor_type = (ggml_type)i;
}
}
}
if (tensor_type != GGML_TYPE_COUNT) {
result.emplace_back(tensor_pattern, tensor_type);
} else {
LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
}
}
return result;
}
bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type) {
const std::string& name = tensor_storage.name;
if (type != GGML_TYPE_COUNT) {

View File

@ -292,7 +292,7 @@ public:
std::map<ggml_type, uint32_t> get_diffusion_model_wtype_stat();
std::map<ggml_type, uint32_t> get_vae_wtype_stat();
String2TensorStorage& get_tensor_storage_map() { return tensor_storage_map; }
void set_wtype_override(ggml_type wtype, std::string prefix = "");
void set_wtype_override(ggml_type wtype, std::string tensor_type_rules = "");
bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0);
bool load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
std::set<std::string> ignore_tensors = {},

View File

@ -307,8 +307,9 @@ public:
ggml_type wtype = (int)sd_ctx_params->wtype < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT)
? (ggml_type)sd_ctx_params->wtype
: GGML_TYPE_COUNT;
if (wtype != GGML_TYPE_COUNT) {
model_loader.set_wtype_override(wtype);
std::string tensor_type_rules = SAFE_STR(sd_ctx_params->tensor_type_rules);
if (wtype != GGML_TYPE_COUNT || tensor_type_rules.size() > 0) {
model_loader.set_wtype_override(wtype, tensor_type_rules);
}
std::map<ggml_type, uint32_t> wtype_stat = model_loader.get_wtype_stat();
@ -2325,6 +2326,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
"lora_model_dir: %s\n"
"embedding_dir: %s\n"
"photo_maker_path: %s\n"
"tensor_type_rules: %s\n"
"vae_decode_only: %s\n"
"free_params_immediately: %s\n"
"n_threads: %d\n"
@ -2354,6 +2356,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
SAFE_STR(sd_ctx_params->lora_model_dir),
SAFE_STR(sd_ctx_params->embedding_dir),
SAFE_STR(sd_ctx_params->photo_maker_path),
SAFE_STR(sd_ctx_params->tensor_type_rules),
BOOL_STR(sd_ctx_params->vae_decode_only),
BOOL_STR(sd_ctx_params->free_params_immediately),
sd_ctx_params->n_threads,

View File

@ -167,6 +167,7 @@ typedef struct {
const char* lora_model_dir;
const char* embedding_dir;
const char* photo_maker_path;
const char* tensor_type_rules;
bool vae_decode_only;
bool free_params_immediately;
int n_threads;