mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-05-08 08:18:51 +00:00
feat: support safetensors export in convert mode (#1444)
This commit is contained in:
parent
0a7ae07f94
commit
44cca3d626
@ -77,9 +77,10 @@ API and command-line option may change frequently.***
|
|||||||
- OpenCL
|
- OpenCL
|
||||||
- SYCL
|
- SYCL
|
||||||
- Supported weight formats
|
- Supported weight formats
|
||||||
- Pytorch checkpoint (`.ckpt` or `.pth`)
|
- PyTorch checkpoint (`.ckpt`, `.pth`, or `.pt`)
|
||||||
- Safetensors (`.safetensors`)
|
- Safetensors (`.safetensors`)
|
||||||
- GGUF (`.gguf`)
|
- GGUF (`.gguf`)
|
||||||
|
- Convert mode supports converting model weights to `.gguf` or `.safetensors`
|
||||||
- Supported platforms
|
- Supported platforms
|
||||||
- Linux
|
- Linux
|
||||||
- Mac OS
|
- Mac OS
|
||||||
|
|||||||
@ -14,6 +14,9 @@ CLI Options:
|
|||||||
--metadata-format <string> metadata output format, one of [text, json] (default: text)
|
--metadata-format <string> metadata output format, one of [text, json] (default: text)
|
||||||
--canny apply canny preprocessor (edge detection)
|
--canny apply canny preprocessor (edge detection)
|
||||||
--convert-name convert tensor name (for convert mode)
|
--convert-name convert tensor name (for convert mode)
|
||||||
|
Convert mode writes a `.gguf` or `.safetensors` file based on the output path's extension.
|
||||||
|
`.safetensors` export currently supports f16, bf16, f32, and i32 tensor types only.
|
||||||
|
i32 tensors are passed through unchanged; no f32 <-> i32 conversion is performed.
|
||||||
-v, --verbose print extra info
|
-v, --verbose print extra info
|
||||||
--color colors the logging tags according to level
|
--color colors the logging tags according to level
|
||||||
--taesd-preview-only prevents usage of taesd for decoding the final image. (for use with --preview tae)
|
--taesd-preview-only prevents usage of taesd for decoding the final image. (for use with --preview tae)
|
||||||
|
|||||||
138
src/convert.cpp
Normal file
138
src/convert.cpp
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
#include <cstring>
|
||||||
|
#include <mutex>
|
||||||
|
#include <regex>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "model.h"
|
||||||
|
#include "model_io/gguf_io.h"
|
||||||
|
#include "model_io/safetensors_io.h"
|
||||||
|
#include "util.h"
|
||||||
|
|
||||||
|
#include "ggml-cpu.h"
|
||||||
|
|
||||||
|
// Resolves the on-disk type for one tensor being exported.
//
// The first tensor_type_rule whose regex matches the tensor name overrides the
// requested default `type`; the resolved type is only applied when the model
// loader reports the tensor as convertible, otherwise the tensor keeps its
// original storage type.
//
// model_loader:      used to check whether the tensor may be converted.
// tensor_storage:    metadata (name, current type) of the tensor.
// type:              default destination type requested by the caller.
// tensor_type_rules: ordered (regex pattern, type) overrides.
static ggml_type get_export_tensor_type(ModelLoader& model_loader,
                                        const TensorStorage& tensor_storage,
                                        ggml_type type,
                                        const TensorTypeRules& tensor_type_rules) {
    const std::string& name = tensor_storage.name;
    ggml_type tensor_type   = tensor_storage.type;
    ggml_type dst_type      = type;

    for (const auto& tensor_type_rule : tensor_type_rules) {
        // Rule patterns come from user input; an invalid regex must not
        // terminate the whole conversion, so catch and skip the bad rule.
        try {
            std::regex pattern(tensor_type_rule.first);
            if (std::regex_search(name, pattern)) {
                dst_type = tensor_type_rule.second;
                break;
            }
        } catch (const std::regex_error& e) {
            LOG_ERROR("invalid tensor type rule pattern '%s': %s",
                      tensor_type_rule.first.c_str(), e.what());
        }
    }

    if (model_loader.tensor_should_be_converted(tensor_storage, dst_type)) {
        tensor_type = dst_type;
    }

    return tensor_type;
}
|
||||||
|
|
||||||
|
// Loads every model tensor into ggml_ctx, converting each to the type picked
// by get_export_tensor_type, and records per-tensor write metadata for the
// output writers.
//
// model_loader:      source of tensor data; its loader may invoke the callback
//                    from multiple threads.
// ggml_ctx:          pre-sized context that owns all destination tensors.
// type:              default destination type for convertible tensors.
// tensor_type_rules: per-tensor-name type overrides (regex -> type).
// tensors:           out-param; receives one TensorWriteInfo per tensor.
//
// Returns true when every tensor was loaded successfully.
static bool load_tensors_for_export(ModelLoader& model_loader,
                                    ggml_context* ggml_ctx,
                                    ggml_type type,
                                    const TensorTypeRules& tensor_type_rules,
                                    std::vector<TensorWriteInfo>& tensors) {
    std::mutex mutex;  // serializes ggml_ctx allocation and `tensors` growth

    auto on_new_tensor = [&](const TensorStorage& storage, ggml_tensor** dst_tensor) -> bool {
        const std::string& name  = storage.name;
        const ggml_type out_type = get_export_tensor_type(model_loader, storage, type, tensor_type_rules);

        // The callback may run concurrently; everything below touches shared state.
        std::lock_guard<std::mutex> guard(mutex);
        ggml_tensor* t = ggml_new_tensor(ggml_ctx, out_type, storage.n_dims, storage.ne);
        if (t == nullptr) {
            LOG_ERROR("ggml_new_tensor failed");
            return false;
        }
        ggml_set_name(t, name.c_str());

        if (!t->data) {
            GGML_ASSERT(ggml_nelements(t) == 0);
            // Zero-sized tensors receive no allocation; aim their data pointer
            // at the context buffer so writers never dereference null.
            LOG_DEBUG("setting dummy pointer for zero-sized tensor %s", name.c_str());
            t->data = ggml_get_mem_buffer(ggml_ctx);
        }

        TensorWriteInfo info;
        info.tensor = t;
        info.n_dims = storage.n_dims;
        for (int d = 0; d < storage.n_dims; ++d) {
            info.ne[d] = storage.ne[d];
        }

        *dst_tensor = t;
        tensors.push_back(std::move(info));
        return true;
    };

    const bool ok = model_loader.load_tensors(on_new_tensor);
    LOG_INFO("load tensors done");
    return ok;
}
|
||||||
|
|
||||||
|
bool convert(const char* input_path,
|
||||||
|
const char* vae_path,
|
||||||
|
const char* output_path,
|
||||||
|
sd_type_t output_type,
|
||||||
|
const char* tensor_type_rules,
|
||||||
|
bool convert_name) {
|
||||||
|
ModelLoader model_loader;
|
||||||
|
|
||||||
|
if (!model_loader.init_from_file(input_path)) {
|
||||||
|
LOG_ERROR("init model loader from file failed: '%s'", input_path);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (vae_path != nullptr && strlen(vae_path) > 0) {
|
||||||
|
if (!model_loader.init_from_file(vae_path, "vae.")) {
|
||||||
|
LOG_ERROR("init model loader from file failed: '%s'", vae_path);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (convert_name) {
|
||||||
|
model_loader.convert_tensors_name();
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_type type = (ggml_type)output_type;
|
||||||
|
bool output_is_safetensors = ends_with(output_path, ".safetensors");
|
||||||
|
TensorTypeRules type_rules = parse_tensor_type_rules(tensor_type_rules);
|
||||||
|
|
||||||
|
auto backend = ggml_backend_cpu_init();
|
||||||
|
size_t mem_size = 1 * 1024 * 1024; // for padding
|
||||||
|
mem_size += model_loader.get_tensor_storage_map().size() * ggml_tensor_overhead();
|
||||||
|
mem_size += model_loader.get_params_mem_size(backend, type);
|
||||||
|
LOG_INFO("model tensors mem size: %.2fMB", mem_size / 1024.f / 1024.f);
|
||||||
|
ggml_context* ggml_ctx = ggml_init({mem_size, nullptr, false});
|
||||||
|
|
||||||
|
if (ggml_ctx == nullptr) {
|
||||||
|
LOG_ERROR("ggml_init failed for converter");
|
||||||
|
ggml_backend_free(backend);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<TensorWriteInfo> tensors;
|
||||||
|
bool success = load_tensors_for_export(model_loader, ggml_ctx, type, type_rules, tensors);
|
||||||
|
ggml_backend_free(backend);
|
||||||
|
|
||||||
|
std::string error;
|
||||||
|
if (success) {
|
||||||
|
if (output_is_safetensors) {
|
||||||
|
success = write_safetensors_file(output_path, tensors, &error);
|
||||||
|
} else {
|
||||||
|
success = write_gguf_file(output_path, tensors, &error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!success && !error.empty()) {
|
||||||
|
LOG_ERROR("%s", error.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_free(ggml_ctx);
|
||||||
|
return success;
|
||||||
|
}
|
||||||
116
src/model.cpp
116
src/model.cpp
@ -81,7 +81,7 @@ const char* unused_tensors[] = {
|
|||||||
"first_stage_model.bn.",
|
"first_stage_model.bn.",
|
||||||
};
|
};
|
||||||
|
|
||||||
bool is_unused_tensor(std::string name) {
|
bool is_unused_tensor(const std::string& name) {
|
||||||
for (size_t i = 0; i < sizeof(unused_tensors) / sizeof(const char*); i++) {
|
for (size_t i = 0; i < sizeof(unused_tensors) / sizeof(const char*); i++) {
|
||||||
if (starts_with(name, unused_tensors[i])) {
|
if (starts_with(name, unused_tensors[i])) {
|
||||||
return true;
|
return true;
|
||||||
@ -687,8 +687,8 @@ std::map<ggml_type, uint32_t> ModelLoader::get_vae_wtype_stat() {
|
|||||||
return wtype_stat;
|
return wtype_stat;
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std::string& tensor_type_rules) {
|
TensorTypeRules parse_tensor_type_rules(const std::string& tensor_type_rules) {
|
||||||
std::vector<std::pair<std::string, ggml_type>> result;
|
TensorTypeRules result;
|
||||||
for (const auto& item : split_string(tensor_type_rules, ',')) {
|
for (const auto& item : split_string(tensor_type_rules, ',')) {
|
||||||
if (item.size() == 0)
|
if (item.size() == 0)
|
||||||
continue;
|
continue;
|
||||||
@ -1121,91 +1121,6 @@ bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Legacy single-format export path: converts all tensors to `type` (subject
// to tensor_type_rules_str overrides) and writes them to a GGUF file.
// Returns true on success; errors are logged.
bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules_str) {
    auto tensor_type_rules = parse_tensor_type_rules(tensor_type_rules_str);
    // Resolve the destination type of one tensor: first matching regex rule
    // wins, falling back to `type`; non-convertible tensors keep their type.
    auto get_tensor_type = [&](const TensorStorage& tensor_storage) -> ggml_type {
        const std::string& name = tensor_storage.name;
        ggml_type tensor_type = tensor_storage.type;
        ggml_type dst_type = type;

        for (const auto& tensor_type_rule : tensor_type_rules) {
            std::regex pattern(tensor_type_rule.first);
            if (std::regex_search(name, pattern)) {
                dst_type = tensor_type_rule.second;
                break;
            }
        }

        if (tensor_should_be_converted(tensor_storage, dst_type)) {
            tensor_type = dst_type;
        }

        return tensor_type;
    };

    // Size the ggml context up front: per-tensor overhead plus converted
    // parameter bytes, with 1MB of slack for alignment padding.
    auto backend = ggml_backend_cpu_init();
    size_t mem_size = 1 * 1024 * 1024;  // for padding
    mem_size += tensor_storage_map.size() * ggml_tensor_overhead();
    mem_size += get_params_mem_size(backend, type);
    LOG_INFO("model tensors mem size: %.2fMB", mem_size / 1024.f / 1024.f);
    ggml_context* ggml_ctx = ggml_init({mem_size, nullptr, false});

    if (ggml_ctx == nullptr) {
        LOG_ERROR("ggml_init failed for GGUF writer");
        ggml_backend_free(backend);
        return false;
    }

    std::vector<ggml_tensor*> tensors;
    std::mutex tensor_mutex;
    // Loader callback — NOTE(review): load_tensors appears to invoke this
    // from multiple threads, hence the mutex around context allocation.
    auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
        const std::string& name = tensor_storage.name;
        ggml_type tensor_type = get_tensor_type(tensor_storage);

        std::lock_guard<std::mutex> lock(tensor_mutex);
        ggml_tensor* tensor = ggml_new_tensor(ggml_ctx, tensor_type, tensor_storage.n_dims, tensor_storage.ne);
        if (tensor == nullptr) {
            LOG_ERROR("ggml_new_tensor failed");
            return false;
        }
        ggml_set_name(tensor, name.c_str());

        // LOG_DEBUG("%s %d %s %d[%d %d %d %d] %d[%d %d %d %d]", name.c_str(),
        // ggml_nbytes(tensor), ggml_type_name(tensor_type),
        // tensor_storage.n_dims,
        // tensor_storage.ne[0], tensor_storage.ne[1], tensor_storage.ne[2], tensor_storage.ne[3],
        // tensor->n_dims, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);

        if (!tensor->data) {
            GGML_ASSERT(ggml_nelements(tensor) == 0);
            // avoid crashing the gguf writer by setting a dummy pointer for zero-sized tensors
            LOG_DEBUG("setting dummy pointer for zero-sized tensor %s", name.c_str());
            tensor->data = ggml_get_mem_buffer(ggml_ctx);
        }

        *dst_tensor = tensor;
        tensors.push_back(tensor);

        return true;
    };

    bool success = load_tensors(on_new_tensor_cb);
    ggml_backend_free(backend);
    LOG_INFO("load tensors done");

    std::string error;
    if (success) {
        success = write_gguf_file(file_path, tensors, &error);
    }

    if (!success && !error.empty()) {
        LOG_ERROR("%s", error.c_str());
    }

    ggml_free(ggml_ctx);
    return success;
}
|
|
||||||
|
|
||||||
int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type) {
|
int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type) {
|
||||||
size_t alignment = 128;
|
size_t alignment = 128;
|
||||||
if (backend != nullptr) {
|
if (backend != nullptr) {
|
||||||
@ -1225,28 +1140,3 @@ int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type)
|
|||||||
|
|
||||||
return mem_size;
|
return mem_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Converts a model (optionally merged with an external VAE) and saves it as
// a GGUF file via ModelLoader::save_to_gguf_file.
// Returns true on success; errors are logged.
bool convert(const char* input_path,
             const char* vae_path,
             const char* output_path,
             sd_type_t output_type,
             const char* tensor_type_rules,
             bool convert_name) {
    ModelLoader model_loader;

    if (!model_loader.init_from_file(input_path)) {
        LOG_ERROR("init model loader from file failed: '%s'", input_path);
        return false;
    }

    // Merge external VAE weights under the "vae." name prefix when provided.
    if (vae_path != nullptr && strlen(vae_path) > 0) {
        if (!model_loader.init_from_file(vae_path, "vae.")) {
            LOG_ERROR("init model loader from file failed: '%s'", vae_path);
            return false;
        }
    }
    if (convert_name) {
        model_loader.convert_tensors_name();
    }
    return model_loader.save_to_gguf_file(output_path, (ggml_type)output_type, tensor_type_rules);
}
|
|
||||||
|
|||||||
@ -189,6 +189,9 @@ enum PMVersion {
|
|||||||
};
|
};
|
||||||
|
|
||||||
typedef OrderedMap<std::string, TensorStorage> String2TensorStorage;
|
typedef OrderedMap<std::string, TensorStorage> String2TensorStorage;
|
||||||
|
using TensorTypeRules = std::vector<std::pair<std::string, ggml_type>>;
|
||||||
|
|
||||||
|
TensorTypeRules parse_tensor_type_rules(const std::string& tensor_type_rules);
|
||||||
|
|
||||||
class ModelLoader {
|
class ModelLoader {
|
||||||
protected:
|
protected:
|
||||||
@ -231,7 +234,6 @@ public:
|
|||||||
return names;
|
return names;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules);
|
|
||||||
bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type);
|
bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type);
|
||||||
int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);
|
int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);
|
||||||
~ModelLoader() = default;
|
~ModelLoader() = default;
|
||||||
|
|||||||
@ -95,7 +95,7 @@ bool read_gguf_file(const std::string& file_path,
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool write_gguf_file(const std::string& file_path,
|
bool write_gguf_file(const std::string& file_path,
|
||||||
const std::vector<ggml_tensor*>& tensors,
|
const std::vector<TensorWriteInfo>& tensors,
|
||||||
std::string* error) {
|
std::string* error) {
|
||||||
gguf_context* gguf_ctx = gguf_init_empty();
|
gguf_context* gguf_ctx = gguf_init_empty();
|
||||||
if (gguf_ctx == nullptr) {
|
if (gguf_ctx == nullptr) {
|
||||||
@ -103,7 +103,8 @@ bool write_gguf_file(const std::string& file_path,
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (ggml_tensor* tensor : tensors) {
|
for (const TensorWriteInfo& write_tensor : tensors) {
|
||||||
|
ggml_tensor* tensor = write_tensor.tensor;
|
||||||
if (tensor == nullptr) {
|
if (tensor == nullptr) {
|
||||||
set_error(error, "null tensor cannot be written to GGUF");
|
set_error(error, "null tensor cannot be written to GGUF");
|
||||||
gguf_free(gguf_ctx);
|
gguf_free(gguf_ctx);
|
||||||
|
|||||||
@ -11,7 +11,7 @@ bool read_gguf_file(const std::string& file_path,
|
|||||||
std::vector<TensorStorage>& tensor_storages,
|
std::vector<TensorStorage>& tensor_storages,
|
||||||
std::string* error = nullptr);
|
std::string* error = nullptr);
|
||||||
bool write_gguf_file(const std::string& file_path,
|
bool write_gguf_file(const std::string& file_path,
|
||||||
const std::vector<ggml_tensor*>& tensors,
|
const std::vector<TensorWriteInfo>& tensors,
|
||||||
std::string* error = nullptr);
|
std::string* error = nullptr);
|
||||||
|
|
||||||
#endif // __SD_MODEL_IO_GGUF_IO_H__
|
#endif // __SD_MODEL_IO_GGUF_IO_H__
|
||||||
|
|||||||
@ -8,6 +8,7 @@
|
|||||||
|
|
||||||
#include "binary_io.h"
|
#include "binary_io.h"
|
||||||
#include "json.hpp"
|
#include "json.hpp"
|
||||||
|
#include "util.h"
|
||||||
|
|
||||||
static constexpr size_t ST_HEADER_SIZE_LEN = 8;
|
static constexpr size_t ST_HEADER_SIZE_LEN = 8;
|
||||||
|
|
||||||
@ -60,7 +61,7 @@ bool is_safetensors_file(const std::string& file_path) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
static ggml_type str_to_ggml_type(const std::string& dtype) {
|
static ggml_type safetensors_dtype_to_ggml_type(const std::string& dtype) {
|
||||||
ggml_type ttype = GGML_TYPE_COUNT;
|
ggml_type ttype = GGML_TYPE_COUNT;
|
||||||
if (dtype == "F16") {
|
if (dtype == "F16") {
|
||||||
ttype = GGML_TYPE_F16;
|
ttype = GGML_TYPE_F16;
|
||||||
@ -154,7 +155,7 @@ bool read_safetensors_file(const std::string& file_path,
|
|||||||
size_t begin = tensor_info["data_offsets"][0].get<size_t>();
|
size_t begin = tensor_info["data_offsets"][0].get<size_t>();
|
||||||
size_t end = tensor_info["data_offsets"][1].get<size_t>();
|
size_t end = tensor_info["data_offsets"][1].get<size_t>();
|
||||||
|
|
||||||
ggml_type type = str_to_ggml_type(dtype);
|
ggml_type type = safetensors_dtype_to_ggml_type(dtype);
|
||||||
if (type == GGML_TYPE_COUNT) {
|
if (type == GGML_TYPE_COUNT) {
|
||||||
set_error(error, "unsupported dtype '" + dtype + "' (tensor '" + name + "')");
|
set_error(error, "unsupported dtype '" + dtype + "' (tensor '" + name + "')");
|
||||||
return false;
|
return false;
|
||||||
@ -221,3 +222,95 @@ bool read_safetensors_file(const std::string& file_path,
|
|||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Maps a ggml tensor type to the corresponding safetensors dtype string.
// Only f16/bf16/f32/i32 are representable; for any other type *dtype is
// left untouched and false is returned.
static bool ggml_type_to_safetensors_dtype(ggml_type type, std::string* dtype) {
    const char* name = nullptr;
    if (type == GGML_TYPE_F16) {
        name = "F16";
    } else if (type == GGML_TYPE_BF16) {
        name = "BF16";
    } else if (type == GGML_TYPE_F32) {
        name = "F32";
    } else if (type == GGML_TYPE_I32) {
        name = "I32";
    }
    if (name == nullptr) {
        return false;
    }
    *dtype = name;
    return true;
}
|
||||||
|
|
||||||
|
bool write_safetensors_file(const std::string& file_path,
|
||||||
|
const std::vector<TensorWriteInfo>& tensors,
|
||||||
|
std::string* error) {
|
||||||
|
nlohmann::ordered_json header = nlohmann::ordered_json::object();
|
||||||
|
|
||||||
|
uint64_t data_offset = 0;
|
||||||
|
for (const TensorWriteInfo& write_tensor : tensors) {
|
||||||
|
ggml_tensor* tensor = write_tensor.tensor;
|
||||||
|
if (tensor == nullptr) {
|
||||||
|
set_error(error, "null tensor cannot be written to safetensors");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string name = ggml_get_name(tensor);
|
||||||
|
std::string dtype;
|
||||||
|
if (!ggml_type_to_safetensors_dtype(tensor->type, &dtype)) {
|
||||||
|
set_error(error,
|
||||||
|
"unsupported safetensors dtype '" + std::string(ggml_type_name(tensor->type)) +
|
||||||
|
"' for tensor '" + name + "'");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint64_t tensor_nbytes = ggml_nbytes(tensor);
|
||||||
|
|
||||||
|
nlohmann::ordered_json json_tensor_info = nlohmann::ordered_json::object();
|
||||||
|
json_tensor_info["dtype"] = dtype;
|
||||||
|
|
||||||
|
nlohmann::ordered_json shape = nlohmann::ordered_json::array();
|
||||||
|
for (int i = 0; i < write_tensor.n_dims; ++i) {
|
||||||
|
shape.push_back(write_tensor.ne[write_tensor.n_dims - 1 - i]);
|
||||||
|
}
|
||||||
|
json_tensor_info["shape"] = shape;
|
||||||
|
|
||||||
|
nlohmann::ordered_json data_offsets = nlohmann::ordered_json::array();
|
||||||
|
data_offsets.push_back(data_offset);
|
||||||
|
data_offsets.push_back(data_offset + tensor_nbytes);
|
||||||
|
json_tensor_info["data_offsets"] = data_offsets;
|
||||||
|
|
||||||
|
header[name] = json_tensor_info;
|
||||||
|
data_offset += tensor_nbytes;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string header_str = header.dump();
|
||||||
|
|
||||||
|
std::ofstream file(file_path, std::ios::binary);
|
||||||
|
if (!file.is_open()) {
|
||||||
|
set_error(error, "failed to open '" + file_path + "' for writing");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_INFO("trying to save tensors to %s", file_path.c_str());
|
||||||
|
model_io::write_u64(file, header_str.size());
|
||||||
|
file.write(header_str.data(), header_str.size());
|
||||||
|
if (!file) {
|
||||||
|
set_error(error, "failed to write safetensors header to '" + file_path + "'");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const TensorWriteInfo& write_tensor : tensors) {
|
||||||
|
ggml_tensor* tensor = write_tensor.tensor;
|
||||||
|
const std::string name = ggml_get_name(tensor);
|
||||||
|
const size_t tensor_nbytes = ggml_nbytes(tensor);
|
||||||
|
file.write((const char*)tensor->data, tensor_nbytes);
|
||||||
|
if (!file) {
|
||||||
|
set_error(error,
|
||||||
|
"failed to write tensor '" + name + "' to '" + file_path + "'");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|||||||
@ -10,5 +10,8 @@ bool is_safetensors_file(const std::string& file_path);
|
|||||||
bool read_safetensors_file(const std::string& file_path,
|
bool read_safetensors_file(const std::string& file_path,
|
||||||
std::vector<TensorStorage>& tensor_storages,
|
std::vector<TensorStorage>& tensor_storages,
|
||||||
std::string* error = nullptr);
|
std::string* error = nullptr);
|
||||||
|
bool write_safetensors_file(const std::string& file_path,
|
||||||
|
const std::vector<TensorWriteInfo>& tensors,
|
||||||
|
std::string* error = nullptr);
|
||||||
|
|
||||||
#endif // __SD_MODEL_IO_SAFETENSORS_IO_H__
|
#endif // __SD_MODEL_IO_SAFETENSORS_IO_H__
|
||||||
|
|||||||
@ -121,6 +121,12 @@ struct TensorStorage {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Metadata for one tensor queued for export: the ggml tensor holding the
// (possibly type-converted) data plus the logical shape recorded from the
// original TensorStorage at load time.
struct TensorWriteInfo {
    int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};  // logical extents; unused dims stay 1
    int n_dims = 0;  // number of valid entries in ne
    ggml_tensor* tensor = nullptr;  // allocated in the exporter's ggml_context — not owned here
};
|
||||||
|
|
||||||
typedef std::function<bool(const TensorStorage&, ggml_tensor**)> on_new_tensor_cb_t;
|
typedef std::function<bool(const TensorStorage&, ggml_tensor**)> on_new_tensor_cb_t;
|
||||||
|
|
||||||
#endif // __SD_TENSOR_STORAGE_H__
|
#endif // __SD_TENSOR_STORAGE_H__
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user