mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-05-08 08:18:51 +00:00
chore: harden safetensors and gguf loading code (#1404)
Co-authored-by: professor-moody <keys@nimbus.lan>
This commit is contained in:
parent
be9f51b25c
commit
118489eb5c
@ -59,6 +59,9 @@ private:
|
||||
if (!safe_read(fin, key_len))
|
||||
return false;
|
||||
|
||||
if (key_len > 4096)
|
||||
return false;
|
||||
|
||||
std::string key(key_len, '\0');
|
||||
if (!safe_read(fin, (char*)key.data(), key_len))
|
||||
return false;
|
||||
|
||||
@ -315,8 +315,9 @@ bool is_safetensors_file(const std::string& file_path) {
|
||||
if (!file) {
|
||||
return false;
|
||||
}
|
||||
nlohmann::json header_ = nlohmann::json::parse(header_buf.data());
|
||||
if (header_.is_discarded()) {
|
||||
try {
|
||||
nlohmann::json header_ = nlohmann::json::parse(header_buf.data());
|
||||
} catch (const std::exception&) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
@ -511,7 +512,14 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
|
||||
return false;
|
||||
}
|
||||
|
||||
nlohmann::json header_ = nlohmann::json::parse(header_buf.data());
|
||||
nlohmann::json header_;
|
||||
try {
|
||||
header_ = nlohmann::json::parse(header_buf.data());
|
||||
} catch (const std::exception&) {
|
||||
LOG_ERROR("parsing safetensors header failed", file_path.c_str());
|
||||
file_paths_.pop_back();
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto& item : header_.items()) {
|
||||
std::string name = item.key();
|
||||
@ -575,24 +583,29 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
|
||||
|
||||
size_t tensor_data_size = end - begin;
|
||||
|
||||
bool tensor_size_ok;
|
||||
if (dtype == "F8_E4M3") {
|
||||
tensor_storage.is_f8_e4m3 = true;
|
||||
// f8 -> f16
|
||||
GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
|
||||
tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size * 2);
|
||||
} else if (dtype == "F8_E5M2") {
|
||||
tensor_storage.is_f8_e5m2 = true;
|
||||
// f8 -> f16
|
||||
GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
|
||||
tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size * 2);
|
||||
} else if (dtype == "F64") {
|
||||
tensor_storage.is_f64 = true;
|
||||
// f64 -> f32
|
||||
GGML_ASSERT(tensor_storage.nbytes() * 2 == tensor_data_size);
|
||||
tensor_size_ok = (tensor_storage.nbytes() * 2 == tensor_data_size);
|
||||
} else if (dtype == "I64") {
|
||||
tensor_storage.is_i64 = true;
|
||||
// i64 -> i32
|
||||
GGML_ASSERT(tensor_storage.nbytes() * 2 == tensor_data_size);
|
||||
tensor_size_ok = (tensor_storage.nbytes() * 2 == tensor_data_size);
|
||||
} else {
|
||||
GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size);
|
||||
tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size);
|
||||
}
|
||||
if (!tensor_size_ok) {
|
||||
LOG_ERROR("size mismatch for tensor '%s' (%s)\n", name.c_str(), dtype.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
add_tensor_storage(tensor_storage);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user