mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-12 21:38:58 +00:00
fix: make weight override more robust against ggml changes (#760)
This commit is contained in:
parent
48956ffb87
commit
5869987fe4
@ -2310,7 +2310,7 @@ std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std
|
|||||||
if (type_name == "f32") {
|
if (type_name == "f32") {
|
||||||
tensor_type = GGML_TYPE_F32;
|
tensor_type = GGML_TYPE_F32;
|
||||||
} else {
|
} else {
|
||||||
for (size_t i = 0; i < SD_TYPE_COUNT; i++) {
|
for (size_t i = 0; i < GGML_TYPE_COUNT; i++) {
|
||||||
auto trait = ggml_get_type_traits((ggml_type)i);
|
auto trait = ggml_get_type_traits((ggml_type)i);
|
||||||
if (trait->to_float && trait->type_size && type_name == trait->type_name) {
|
if (trait->to_float && trait->type_size && type_name == trait->type_name) {
|
||||||
tensor_type = (ggml_type)i;
|
tensor_type = (ggml_type)i;
|
||||||
|
|||||||
@ -265,7 +265,9 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
LOG_INFO("Version: %s ", model_version_to_str[version]);
|
LOG_INFO("Version: %s ", model_version_to_str[version]);
|
||||||
ggml_type wtype = (ggml_type)sd_ctx_params->wtype;
|
ggml_type wtype = (int)sd_ctx_params->wtype < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT)
|
||||||
|
? (ggml_type)sd_ctx_params->wtype
|
||||||
|
: GGML_TYPE_COUNT;
|
||||||
if (wtype == GGML_TYPE_COUNT) {
|
if (wtype == GGML_TYPE_COUNT) {
|
||||||
model_wtype = model_loader.get_sd_wtype();
|
model_wtype = model_loader.get_sd_wtype();
|
||||||
if (model_wtype == GGML_TYPE_COUNT) {
|
if (model_wtype == GGML_TYPE_COUNT) {
|
||||||
@ -1465,11 +1467,14 @@ public:
|
|||||||
#define NONE_STR "NONE"
|
#define NONE_STR "NONE"
|
||||||
|
|
||||||
const char* sd_type_name(enum sd_type_t type) {
|
const char* sd_type_name(enum sd_type_t type) {
|
||||||
return ggml_type_name((ggml_type)type);
|
if ((int)type < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT)) {
|
||||||
|
return ggml_type_name((ggml_type)type);
|
||||||
|
}
|
||||||
|
return NONE_STR;
|
||||||
}
|
}
|
||||||
|
|
||||||
enum sd_type_t str_to_sd_type(const char* str) {
|
enum sd_type_t str_to_sd_type(const char* str) {
|
||||||
for (int i = 0; i < SD_TYPE_COUNT; i++) {
|
for (int i = 0; i < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT); i++) {
|
||||||
auto trait = ggml_get_type_traits((ggml_type)i);
|
auto trait = ggml_get_type_traits((ggml_type)i);
|
||||||
if (!strcmp(str, trait->type_name)) {
|
if (!strcmp(str, trait->type_name)) {
|
||||||
return (enum sd_type_t)i;
|
return (enum sd_type_t)i;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user