feat: overriding quant types for specific tensors on model conversion (#724)

This commit is contained in:
Wagner Bruna 2025-07-07 13:11:38 -03:00 committed by GitHub
parent dafc32d0dd
commit 6d84a30c66
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 71 additions and 14 deletions

View File

@ -87,6 +87,7 @@ struct SDParams {
std::string stacked_id_embeddings_path;
std::string input_id_images_path;
sd_type_t wtype = SD_TYPE_COUNT;
std::string tensor_type_rules;
std::string lora_model_dir;
std::string output_path = "output.png";
std::string input_path;
@ -223,6 +224,7 @@ void print_usage(int argc, const char* argv[]) {
printf(" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n");
printf(" --type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)\n");
printf(" If not specified, the default is the type of the weight file\n");
printf(" --tensor-type-rules [EXPRESSION] weight type per tensor pattern (example: \"^vae\\.=f16,model\\.=q8_0\")\n");
printf(" --lora-model-dir [DIR] lora model directory\n");
printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n");
printf(" --mask [MASK] path to the mask image, required by img2img with mask\n");
@ -404,6 +406,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
valid_types.c_str());
exit(1);
}
} else if (arg == "--tensor-type-rules") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.tensor_type_rules = argv[i];
} else if (arg == "--lora-model-dir") {
if (++i >= argc) {
invalid_arg = true;
@ -733,6 +741,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
exit(1);
}
if (params.mode != CONVERT && params.tensor_type_rules.size() > 0) {
fprintf(stderr, "warning: --tensor-type-rules is currently supported only for conversion\n");
}
if (params.seed < 0) {
srand((int)time(NULL));
params.seed = rand();
@ -845,7 +857,7 @@ int main(int argc, const char* argv[]) {
}
if (params.mode == CONVERT) {
bool success = convert(params.model_path.c_str(), params.vae_path.c_str(), params.output_path.c_str(), params.wtype);
bool success = convert(params.model_path.c_str(), params.vae_path.c_str(), params.output_path.c_str(), params.wtype, params.tensor_type_rules.c_str());
if (!success) {
fprintf(stderr,
"convert '%s'/'%s' to '%s' failed\n",

View File

@ -100,7 +100,7 @@ const char* unused_tensors[] = {
"model_ema.diffusion_model",
"embedding_manager",
"denoiser.sigmas",
"text_encoders.t5xxl.transformer.encoder.embed_tokens.weight", // only used during training
"text_encoders.t5xxl.transformer.encoder.embed_tokens.weight", // only used during training
};
bool is_unused_tensor(std::string name) {
@ -1169,7 +1169,6 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
n_dims = 1;
}
TensorStorage tensor_storage(prefix + name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
tensor_storage.reverse_ne();
@ -1914,7 +1913,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
};
int tensor_count = 0;
int64_t t1 = ggml_time_ms();
bool partial = false;
bool partial = false;
for (auto& tensor_storage : processed_tensor_storages) {
if (tensor_storage.file_index != file_index) {
++tensor_count;
@ -1997,9 +1996,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
}
}
size_t tensor_max = processed_tensor_storages.size();
int64_t t2 = ggml_time_ms();
int64_t t2 = ggml_time_ms();
pretty_progress(++tensor_count, tensor_max, (t2 - t1) / 1000.0f);
t1 = t2;
t1 = t2;
partial = tensor_count != tensor_max;
}
@ -2088,6 +2087,41 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tenso
return true;
}
std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std::string& tensor_type_rules) {
std::vector<std::pair<std::string, ggml_type>> result;
for (const auto& item : splitString(tensor_type_rules, ',')) {
if (item.size() == 0)
continue;
std::string::size_type pos = item.find('=');
if (pos == std::string::npos) {
LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
continue;
}
std::string tensor_pattern = item.substr(0, pos);
std::string type_name = item.substr(pos + 1);
ggml_type tensor_type = GGML_TYPE_COUNT;
if (type_name == "f32") {
tensor_type = GGML_TYPE_F32;
} else {
for (size_t i = 0; i < SD_TYPE_COUNT; i++) {
auto trait = ggml_get_type_traits((ggml_type)i);
if (trait->to_float && trait->type_size && type_name == trait->type_name) {
tensor_type = (ggml_type)i;
}
}
}
if (tensor_type != GGML_TYPE_COUNT) {
result.emplace_back(tensor_pattern, tensor_type);
} else {
LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
}
}
return result;
}
bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type) {
const std::string& name = tensor_storage.name;
if (type != GGML_TYPE_COUNT) {
@ -2119,7 +2153,7 @@ bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage
return false;
}
bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type) {
bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules_str) {
auto backend = ggml_backend_cpu_init();
size_t mem_size = 1 * 1024 * 1024; // for padding
mem_size += tensor_storages.size() * ggml_tensor_overhead();
@ -2129,12 +2163,23 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type
gguf_context* gguf_ctx = gguf_init_empty();
auto tensor_type_rules = parse_tensor_type_rules(tensor_type_rules_str);
auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
const std::string& name = tensor_storage.name;
ggml_type tensor_type = tensor_storage.type;
ggml_type dst_type = type;
ggml_type tensor_type = tensor_storage.type;
if (tensor_should_be_converted(tensor_storage, type)) {
tensor_type = type;
for (const auto& tensor_type_rule : tensor_type_rules) {
std::regex pattern(tensor_type_rule.first);
if (std::regex_search(name, pattern)) {
dst_type = tensor_type_rule.second;
break;
}
}
if (tensor_should_be_converted(tensor_storage, dst_type)) {
tensor_type = dst_type;
}
ggml_tensor* tensor = ggml_new_tensor(ggml_ctx, tensor_type, tensor_storage.n_dims, tensor_storage.ne);
@ -2193,7 +2238,7 @@ int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type)
return mem_size;
}
bool convert(const char* input_path, const char* vae_path, const char* output_path, sd_type_t output_type) {
bool convert(const char* input_path, const char* vae_path, const char* output_path, sd_type_t output_type, const char* tensor_type_rules) {
ModelLoader model_loader;
if (!model_loader.init_from_file(input_path)) {
@ -2207,6 +2252,6 @@ bool convert(const char* input_path, const char* vae_path, const char* output_pa
return false;
}
}
bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type);
bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type, tensor_type_rules);
return success;
}

View File

@ -222,7 +222,7 @@ public:
ggml_backend_t backend,
std::set<std::string> ignore_tensors = {});
bool save_to_gguf_file(const std::string& file_path, ggml_type type);
bool save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules);
bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type);
int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);
~ModelLoader() = default;

View File

@ -257,7 +257,7 @@ SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx);
SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_t upscale_factor);
SD_API bool convert(const char* input_path, const char* vae_path, const char* output_path, enum sd_type_t output_type);
SD_API bool convert(const char* input_path, const char* vae_path, const char* output_path, enum sd_type_t output_type, const char* tensor_type_rules);
SD_API uint8_t* preprocess_canny(uint8_t* img,
int width,