mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-17 11:46:38 +00:00
fix: correct conversion from sd_type_t to ggml_type (#1519)
This commit is contained in:
parent
5a34bc7f6e
commit
710bc91c8f
@ -99,7 +99,7 @@ bool convert(const char* input_path,
|
|||||||
model_loader.convert_tensors_name();
|
model_loader.convert_tensors_name();
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_type type = (ggml_type)output_type;
|
ggml_type type = sd_type_to_ggml_type(output_type);
|
||||||
bool output_is_safetensors = ends_with(output_path, ".safetensors");
|
bool output_is_safetensors = ends_with(output_path, ".safetensors");
|
||||||
TensorTypeRules type_rules = parse_tensor_type_rules(tensor_type_rules);
|
TensorTypeRules type_rules = parse_tensor_type_rules(tensor_type_rules);
|
||||||
|
|
||||||
|
|||||||
@ -406,6 +406,15 @@ std::vector<std::string> split_string(const std::string& str, char delimiter) {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ggml_type sd_type_to_ggml_type(sd_type_t sdtype) {
|
||||||
|
const int type_value = static_cast<int>(sdtype);
|
||||||
|
if (type_value < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT)) {
|
||||||
|
return static_cast<ggml_type>(type_value);
|
||||||
|
} else {
|
||||||
|
return GGML_TYPE_COUNT;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
KeyValueArgs parse_key_value_args(const char* args, const char* context) {
|
KeyValueArgs parse_key_value_args(const char* args, const char* context) {
|
||||||
KeyValueArgs pairs;
|
KeyValueArgs pairs;
|
||||||
|
|
||||||
|
|||||||
@ -80,6 +80,8 @@ void pretty_bytes_progress(int step, int steps, uint64_t bytes_processed, float
|
|||||||
|
|
||||||
void log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...);
|
void log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...);
|
||||||
|
|
||||||
|
ggml_type sd_type_to_ggml_type(sd_type_t sdtype);
|
||||||
|
|
||||||
std::string trim(const std::string& s);
|
std::string trim(const std::string& s);
|
||||||
|
|
||||||
std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::string& text);
|
std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::string& text);
|
||||||
|
|||||||
@ -522,9 +522,7 @@ public:
|
|||||||
auto& tensor_storage_map = model_loader.get_tensor_storage_map();
|
auto& tensor_storage_map = model_loader.get_tensor_storage_map();
|
||||||
|
|
||||||
LOG_INFO("Version: %s ", model_version_to_str[version]);
|
LOG_INFO("Version: %s ", model_version_to_str[version]);
|
||||||
ggml_type wtype = (int)sd_ctx_params->wtype < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT)
|
ggml_type wtype = sd_type_to_ggml_type(sd_ctx_params->wtype);
|
||||||
? (ggml_type)sd_ctx_params->wtype
|
|
||||||
: GGML_TYPE_COUNT;
|
|
||||||
std::string tensor_type_rules = SAFE_STR(sd_ctx_params->tensor_type_rules);
|
std::string tensor_type_rules = SAFE_STR(sd_ctx_params->tensor_type_rules);
|
||||||
if (wtype != GGML_TYPE_COUNT || tensor_type_rules.size() > 0) {
|
if (wtype != GGML_TYPE_COUNT || tensor_type_rules.size() > 0) {
|
||||||
model_loader.set_wtype_override(wtype, tensor_type_rules);
|
model_loader.set_wtype_override(wtype, tensor_type_rules);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user