feat: add wtype stat (#899)

This commit is contained in:
leejet 2025-10-17 23:40:32 +08:00 committed by GitHub
parent b25785bc10
commit db6f4791b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 60 additions and 71 deletions

View File

@ -1892,24 +1892,25 @@ SDVersion ModelLoader::get_sd_version() {
return VERSION_COUNT;
}
ggml_type ModelLoader::get_sd_wtype() {
std::map<ggml_type, uint32_t> ModelLoader::get_wtype_stat() {
std::map<ggml_type, uint32_t> wtype_stat;
for (auto& tensor_storage : tensor_storages) {
if (is_unused_tensor(tensor_storage.name)) {
continue;
}
if (ggml_is_quantized(tensor_storage.type)) {
return tensor_storage.type;
}
if (tensor_should_be_converted(tensor_storage, GGML_TYPE_Q4_K)) {
return tensor_storage.type;
auto iter = wtype_stat.find(tensor_storage.type);
if (iter != wtype_stat.end()) {
iter->second++;
} else {
wtype_stat[tensor_storage.type] = 1;
}
}
return GGML_TYPE_COUNT;
return wtype_stat;
}
ggml_type ModelLoader::get_conditioner_wtype() {
std::map<ggml_type, uint32_t> ModelLoader::get_conditioner_wtype_stat() {
std::map<ggml_type, uint32_t> wtype_stat;
for (auto& tensor_storage : tensor_storages) {
if (is_unused_tensor(tensor_storage.name)) {
continue;
@ -1922,18 +1923,18 @@ ggml_type ModelLoader::get_conditioner_wtype() {
continue;
}
if (ggml_is_quantized(tensor_storage.type)) {
return tensor_storage.type;
}
if (tensor_should_be_converted(tensor_storage, GGML_TYPE_Q4_K)) {
return tensor_storage.type;
auto iter = wtype_stat.find(tensor_storage.type);
if (iter != wtype_stat.end()) {
iter->second++;
} else {
wtype_stat[tensor_storage.type] = 1;
}
}
return GGML_TYPE_COUNT;
return wtype_stat;
}
ggml_type ModelLoader::get_diffusion_model_wtype() {
std::map<ggml_type, uint32_t> ModelLoader::get_diffusion_model_wtype_stat() {
std::map<ggml_type, uint32_t> wtype_stat;
for (auto& tensor_storage : tensor_storages) {
if (is_unused_tensor(tensor_storage.name)) {
continue;
@ -1943,18 +1944,18 @@ ggml_type ModelLoader::get_diffusion_model_wtype() {
continue;
}
if (ggml_is_quantized(tensor_storage.type)) {
return tensor_storage.type;
}
if (tensor_should_be_converted(tensor_storage, GGML_TYPE_Q4_K)) {
return tensor_storage.type;
auto iter = wtype_stat.find(tensor_storage.type);
if (iter != wtype_stat.end()) {
iter->second++;
} else {
wtype_stat[tensor_storage.type] = 1;
}
}
return GGML_TYPE_COUNT;
return wtype_stat;
}
ggml_type ModelLoader::get_vae_wtype() {
std::map<ggml_type, uint32_t> ModelLoader::get_vae_wtype_stat() {
std::map<ggml_type, uint32_t> wtype_stat;
for (auto& tensor_storage : tensor_storages) {
if (is_unused_tensor(tensor_storage.name)) {
continue;
@ -1965,15 +1966,14 @@ ggml_type ModelLoader::get_vae_wtype() {
continue;
}
if (ggml_is_quantized(tensor_storage.type)) {
return tensor_storage.type;
}
if (tensor_should_be_converted(tensor_storage, GGML_TYPE_Q4_K)) {
return tensor_storage.type;
auto iter = wtype_stat.find(tensor_storage.type);
if (iter != wtype_stat.end()) {
iter->second++;
} else {
wtype_stat[tensor_storage.type] = 1;
}
}
return GGML_TYPE_COUNT;
return wtype_stat;
}
void ModelLoader::set_wtype_override(ggml_type wtype, std::string prefix) {

View File

@ -259,10 +259,10 @@ public:
bool init_from_file(const std::string& file_path, const std::string& prefix = "");
bool model_is_unet();
SDVersion get_sd_version();
ggml_type get_sd_wtype();
ggml_type get_conditioner_wtype();
ggml_type get_diffusion_model_wtype();
ggml_type get_vae_wtype();
std::map<ggml_type, uint32_t> get_wtype_stat();
std::map<ggml_type, uint32_t> get_conditioner_wtype_stat();
std::map<ggml_type, uint32_t> get_diffusion_model_wtype_stat();
std::map<ggml_type, uint32_t> get_vae_wtype_stat();
void set_wtype_override(ggml_type wtype, std::string prefix = "");
bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0);
bool load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,

View File

@ -86,10 +86,6 @@ public:
ggml_backend_t clip_backend = NULL;
ggml_backend_t control_net_backend = NULL;
ggml_backend_t vae_backend = NULL;
ggml_type model_wtype = GGML_TYPE_COUNT;
ggml_type conditioner_wtype = GGML_TYPE_COUNT;
ggml_type diffusion_model_wtype = GGML_TYPE_COUNT;
ggml_type vae_wtype = GGML_TYPE_COUNT;
SDVersion version;
bool vae_decode_only = false;
@ -294,37 +290,33 @@ public:
ggml_type wtype = (int)sd_ctx_params->wtype < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT)
? (ggml_type)sd_ctx_params->wtype
: GGML_TYPE_COUNT;
if (wtype == GGML_TYPE_COUNT) {
model_wtype = model_loader.get_sd_wtype();
if (model_wtype == GGML_TYPE_COUNT) {
model_wtype = GGML_TYPE_F32;
LOG_WARN("can not get mode wtype frome weight, use f32");
}
conditioner_wtype = model_loader.get_conditioner_wtype();
if (conditioner_wtype == GGML_TYPE_COUNT) {
conditioner_wtype = wtype;
}
diffusion_model_wtype = model_loader.get_diffusion_model_wtype();
if (diffusion_model_wtype == GGML_TYPE_COUNT) {
diffusion_model_wtype = wtype;
}
vae_wtype = model_loader.get_vae_wtype();
if (vae_wtype == GGML_TYPE_COUNT) {
vae_wtype = wtype;
}
} else {
model_wtype = wtype;
conditioner_wtype = wtype;
diffusion_model_wtype = wtype;
vae_wtype = wtype;
if (wtype != GGML_TYPE_COUNT) {
model_loader.set_wtype_override(wtype);
}
LOG_INFO("Weight type: %s", ggml_type_name(model_wtype));
LOG_INFO("Conditioner weight type: %s", ggml_type_name(conditioner_wtype));
LOG_INFO("Diffusion model weight type: %s", ggml_type_name(diffusion_model_wtype));
LOG_INFO("VAE weight type: %s", ggml_type_name(vae_wtype));
std::map<ggml_type, uint32_t> wtype_stat = model_loader.get_wtype_stat();
std::map<ggml_type, uint32_t> conditioner_wtype_stat = model_loader.get_conditioner_wtype_stat();
std::map<ggml_type, uint32_t> diffusion_model_wtype_stat = model_loader.get_diffusion_model_wtype_stat();
std::map<ggml_type, uint32_t> vae_wtype_stat = model_loader.get_vae_wtype_stat();
auto wtype_stat_to_str = [](const std::map<ggml_type, uint32_t>& m, int key_width = 8, int value_width = 5) -> std::string {
std::ostringstream oss;
bool first = true;
for (const auto& [type, count] : m) {
if (!first)
oss << "|";
first = false;
oss << std::right << std::setw(key_width) << ggml_type_name(type)
<< ": "
<< std::left << std::setw(value_width) << count;
}
return oss.str();
};
LOG_INFO("Weight type stat: %s", wtype_stat_to_str(wtype_stat).c_str());
LOG_INFO("Conditioner weight type stat: %s", wtype_stat_to_str(conditioner_wtype_stat).c_str());
LOG_INFO("Diffusion model weight type stat: %s", wtype_stat_to_str(diffusion_model_wtype_stat).c_str());
LOG_INFO("VAE weight type stat: %s", wtype_stat_to_str(vae_wtype_stat).c_str());
LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor));
@ -938,9 +930,6 @@ public:
}
void apply_loras(const std::unordered_map<std::string, float>& lora_state) {
if (lora_state.size() > 0 && model_wtype != GGML_TYPE_F16 && model_wtype != GGML_TYPE_F32) {
LOG_WARN("In quantized models when applying LoRA, the images have poor quality.");
}
std::unordered_map<std::string, float> lora_state_diff;
for (auto& kv : lora_state) {
const std::string& lora_name = kv.first;