diff --git a/src/convert.cpp b/src/convert.cpp index 5ad066c1..27d377ec 100644 --- a/src/convert.cpp +++ b/src/convert.cpp @@ -99,7 +99,7 @@ bool convert(const char* input_path, model_loader.convert_tensors_name(); } - ggml_type type = (ggml_type)output_type; + ggml_type type = sd_type_to_ggml_type(output_type); bool output_is_safetensors = ends_with(output_path, ".safetensors"); TensorTypeRules type_rules = parse_tensor_type_rules(tensor_type_rules); diff --git a/src/core/util.cpp b/src/core/util.cpp index 7325607e..b844d29e 100644 --- a/src/core/util.cpp +++ b/src/core/util.cpp @@ -406,6 +406,15 @@ std::vector split_string(const std::string& str, char delimiter) { return result; } +ggml_type sd_type_to_ggml_type(sd_type_t sdtype) { + const int type_value = static_cast(sdtype); + if (type_value < std::min(SD_TYPE_COUNT, GGML_TYPE_COUNT)) { + return static_cast(type_value); + } else { + return GGML_TYPE_COUNT; + } +} + KeyValueArgs parse_key_value_args(const char* args, const char* context) { KeyValueArgs pairs; diff --git a/src/core/util.h b/src/core/util.h index 44ea4174..8213dfab 100644 --- a/src/core/util.h +++ b/src/core/util.h @@ -80,6 +80,8 @@ void pretty_bytes_progress(int step, int steps, uint64_t bytes_processed, float void log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...); +ggml_type sd_type_to_ggml_type(sd_type_t sdtype); + std::string trim(const std::string& s); std::vector> parse_prompt_attention(const std::string& text); diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index c2a1974b..4df047dd 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -522,9 +522,7 @@ public: auto& tensor_storage_map = model_loader.get_tensor_storage_map(); LOG_INFO("Version: %s ", model_version_to_str[version]); - ggml_type wtype = (int)sd_ctx_params->wtype < std::min(SD_TYPE_COUNT, GGML_TYPE_COUNT) - ? (ggml_type)sd_ctx_params->wtype - : GGML_TYPE_COUNT; + ggml_type wtype = sd_type_to_ggml_type(sd_ctx_params->wtype); std::string tensor_type_rules = SAFE_STR(sd_ctx_params->tensor_type_rules); if (wtype != GGML_TYPE_COUNT || tensor_type_rules.size() > 0) { model_loader.set_wtype_override(wtype, tensor_type_rules);