mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-05-08 08:18:51 +00:00
Compare commits
2 Commits
66143340b6
...
44cca3d626
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
44cca3d626 | ||
|
|
0a7ae07f94 |
@ -77,9 +77,10 @@ API and command-line option may change frequently.***
|
||||
- OpenCL
|
||||
- SYCL
|
||||
- Supported weight formats
|
||||
- Pytorch checkpoint (`.ckpt` or `.pth`)
|
||||
- Pytorch checkpoint (`.ckpt` or `.pth` or `.pt`)
|
||||
- Safetensors (`.safetensors`)
|
||||
- GGUF (`.gguf`)
|
||||
- Convert mode supports converting model weights to `.gguf` or `.safetensors`
|
||||
- Supported platforms
|
||||
- Linux
|
||||
- Mac OS
|
||||
|
||||
@ -14,6 +14,9 @@ CLI Options:
|
||||
--metadata-format <string> metadata output format, one of [text, json] (default: text)
|
||||
--canny apply canny preprocessor (edge detection)
|
||||
--convert-name convert tensor name (for convert mode)
|
||||
convert mode writes `.gguf` or `.safetensors` based on the output extension.
|
||||
`.safetensors` export currently supports f16, bf16, f32, and i32 tensor types only.
|
||||
i32 is passthrough only; no f32 <-> i32 conversion is performed
|
||||
-v, --verbose print extra info
|
||||
--color colors the logging tags according to level
|
||||
--taesd-preview-only prevents usage of taesd for decoding the final image. (for use with --preview tae)
|
||||
|
||||
138
src/convert.cpp
Normal file
138
src/convert.cpp
Normal file
@ -0,0 +1,138 @@
|
||||
#include <cstring>
|
||||
#include <mutex>
|
||||
#include <regex>
|
||||
#include <vector>
|
||||
|
||||
#include "model.h"
|
||||
#include "model_io/gguf_io.h"
|
||||
#include "model_io/safetensors_io.h"
|
||||
#include "util.h"
|
||||
|
||||
#include "ggml-cpu.h"
|
||||
|
||||
static ggml_type get_export_tensor_type(ModelLoader& model_loader,
|
||||
const TensorStorage& tensor_storage,
|
||||
ggml_type type,
|
||||
const TensorTypeRules& tensor_type_rules) {
|
||||
const std::string& name = tensor_storage.name;
|
||||
ggml_type tensor_type = tensor_storage.type;
|
||||
ggml_type dst_type = type;
|
||||
|
||||
for (const auto& tensor_type_rule : tensor_type_rules) {
|
||||
std::regex pattern(tensor_type_rule.first);
|
||||
if (std::regex_search(name, pattern)) {
|
||||
dst_type = tensor_type_rule.second;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (model_loader.tensor_should_be_converted(tensor_storage, dst_type)) {
|
||||
tensor_type = dst_type;
|
||||
}
|
||||
|
||||
return tensor_type;
|
||||
}
|
||||
|
||||
static bool load_tensors_for_export(ModelLoader& model_loader,
|
||||
ggml_context* ggml_ctx,
|
||||
ggml_type type,
|
||||
const TensorTypeRules& tensor_type_rules,
|
||||
std::vector<TensorWriteInfo>& tensors) {
|
||||
std::mutex tensor_mutex;
|
||||
auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
|
||||
const std::string& name = tensor_storage.name;
|
||||
ggml_type tensor_type = get_export_tensor_type(model_loader, tensor_storage, type, tensor_type_rules);
|
||||
|
||||
std::lock_guard<std::mutex> lock(tensor_mutex);
|
||||
ggml_tensor* tensor = ggml_new_tensor(ggml_ctx, tensor_type, tensor_storage.n_dims, tensor_storage.ne);
|
||||
if (tensor == nullptr) {
|
||||
LOG_ERROR("ggml_new_tensor failed");
|
||||
return false;
|
||||
}
|
||||
ggml_set_name(tensor, name.c_str());
|
||||
|
||||
if (!tensor->data) {
|
||||
GGML_ASSERT(ggml_nelements(tensor) == 0);
|
||||
// Avoid crashing writers by setting a dummy pointer for zero-sized tensors.
|
||||
LOG_DEBUG("setting dummy pointer for zero-sized tensor %s", name.c_str());
|
||||
tensor->data = ggml_get_mem_buffer(ggml_ctx);
|
||||
}
|
||||
|
||||
TensorWriteInfo write_info;
|
||||
write_info.tensor = tensor;
|
||||
write_info.n_dims = tensor_storage.n_dims;
|
||||
for (int i = 0; i < tensor_storage.n_dims; ++i) {
|
||||
write_info.ne[i] = tensor_storage.ne[i];
|
||||
}
|
||||
|
||||
*dst_tensor = tensor;
|
||||
tensors.push_back(std::move(write_info));
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
bool success = model_loader.load_tensors(on_new_tensor_cb);
|
||||
LOG_INFO("load tensors done");
|
||||
return success;
|
||||
}
|
||||
|
||||
bool convert(const char* input_path,
|
||||
const char* vae_path,
|
||||
const char* output_path,
|
||||
sd_type_t output_type,
|
||||
const char* tensor_type_rules,
|
||||
bool convert_name) {
|
||||
ModelLoader model_loader;
|
||||
|
||||
if (!model_loader.init_from_file(input_path)) {
|
||||
LOG_ERROR("init model loader from file failed: '%s'", input_path);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (vae_path != nullptr && strlen(vae_path) > 0) {
|
||||
if (!model_loader.init_from_file(vae_path, "vae.")) {
|
||||
LOG_ERROR("init model loader from file failed: '%s'", vae_path);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (convert_name) {
|
||||
model_loader.convert_tensors_name();
|
||||
}
|
||||
|
||||
ggml_type type = (ggml_type)output_type;
|
||||
bool output_is_safetensors = ends_with(output_path, ".safetensors");
|
||||
TensorTypeRules type_rules = parse_tensor_type_rules(tensor_type_rules);
|
||||
|
||||
auto backend = ggml_backend_cpu_init();
|
||||
size_t mem_size = 1 * 1024 * 1024; // for padding
|
||||
mem_size += model_loader.get_tensor_storage_map().size() * ggml_tensor_overhead();
|
||||
mem_size += model_loader.get_params_mem_size(backend, type);
|
||||
LOG_INFO("model tensors mem size: %.2fMB", mem_size / 1024.f / 1024.f);
|
||||
ggml_context* ggml_ctx = ggml_init({mem_size, nullptr, false});
|
||||
|
||||
if (ggml_ctx == nullptr) {
|
||||
LOG_ERROR("ggml_init failed for converter");
|
||||
ggml_backend_free(backend);
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<TensorWriteInfo> tensors;
|
||||
bool success = load_tensors_for_export(model_loader, ggml_ctx, type, type_rules, tensors);
|
||||
ggml_backend_free(backend);
|
||||
|
||||
std::string error;
|
||||
if (success) {
|
||||
if (output_is_safetensors) {
|
||||
success = write_safetensors_file(output_path, tensors, &error);
|
||||
} else {
|
||||
success = write_gguf_file(output_path, tensors, &error);
|
||||
}
|
||||
}
|
||||
|
||||
if (!success && !error.empty()) {
|
||||
LOG_ERROR("%s", error.c_str());
|
||||
}
|
||||
|
||||
ggml_free(ggml_ctx);
|
||||
return success;
|
||||
}
|
||||
@ -1523,12 +1523,10 @@ static sd::Tensor<float> sample_ddim_trailing(denoise_cb_t model,
|
||||
const std::vector<float>& sigmas,
|
||||
std::shared_ptr<RNG> rng,
|
||||
float eta) {
|
||||
|
||||
int steps = static_cast<int>(sigmas.size()) - 1;
|
||||
for (int i = 0; i < steps; i++) {
|
||||
|
||||
float sigma = sigmas[i];
|
||||
float sigma_to = sigmas[i + 1];
|
||||
float sigma = sigmas[i];
|
||||
float sigma_to = sigmas[i + 1];
|
||||
|
||||
auto model_output_opt = model(x, sigma, i + 1);
|
||||
if (model_output_opt.empty()) {
|
||||
@ -1551,12 +1549,11 @@ static sd::Tensor<float> sample_ddim_trailing(denoise_cb_t model,
|
||||
float std_dev_t = eta * std::sqrt(variance);
|
||||
|
||||
x = pred_original_sample +
|
||||
std::sqrt((1.0f - alpha_prod_t_prev - std::pow(std_dev_t, 2))/ alpha_prod_t_prev) * model_output;
|
||||
std::sqrt((1.0f - alpha_prod_t_prev - std::pow(std_dev_t, 2)) / alpha_prod_t_prev) * model_output;
|
||||
|
||||
if (eta > 0) {
|
||||
x+= std_dev_t / std::sqrt(alpha_prod_t_prev) * sd::Tensor<float>::randn_like(x, rng);
|
||||
x += std_dev_t / std::sqrt(alpha_prod_t_prev) * sd::Tensor<float>::randn_like(x, rng);
|
||||
}
|
||||
|
||||
}
|
||||
return x;
|
||||
}
|
||||
@ -1584,8 +1581,10 @@ static sd::Tensor<float> sample_tcd(denoise_cb_t model,
|
||||
|
||||
auto get_timestep_from_sigma = [&](float s) -> int {
|
||||
auto it = std::lower_bound(compvis_sigmas.begin(), compvis_sigmas.end(), s);
|
||||
if (it == compvis_sigmas.begin()) return 0;
|
||||
if (it == compvis_sigmas.end()) return TIMESTEPS - 1;
|
||||
if (it == compvis_sigmas.begin())
|
||||
return 0;
|
||||
if (it == compvis_sigmas.end())
|
||||
return TIMESTEPS - 1;
|
||||
int idx_high = static_cast<int>(std::distance(compvis_sigmas.begin(), it));
|
||||
int idx_low = idx_high - 1;
|
||||
if (std::abs(compvis_sigmas[idx_high] - s) < std::abs(compvis_sigmas[idx_low] - s)) {
|
||||
@ -1596,7 +1595,6 @@ static sd::Tensor<float> sample_tcd(denoise_cb_t model,
|
||||
|
||||
int steps = static_cast<int>(sigmas.size()) - 1;
|
||||
for (int i = 0; i < steps; i++) {
|
||||
|
||||
float sigma_to = sigmas[i + 1];
|
||||
int prev_timestep = get_timestep_from_sigma(sigma_to);
|
||||
int timestep_s = (int)floor((1 - eta) * prev_timestep);
|
||||
@ -1626,7 +1624,6 @@ static sd::Tensor<float> sample_tcd(denoise_cb_t model,
|
||||
x = std::sqrt(alpha_prod_t_prev / alpha_prod_s) * x +
|
||||
std::sqrt(1.0f / alpha_prod_t_prev - 1.0f / alpha_prod_s) * sd::Tensor<float>::randn_like(x, rng);
|
||||
}
|
||||
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
221
src/model.cpp
221
src/model.cpp
@ -2,6 +2,7 @@
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <cstdarg>
|
||||
#include <cstdlib>
|
||||
#include <fstream>
|
||||
#include <functional>
|
||||
#include <mutex>
|
||||
@ -13,9 +14,10 @@
|
||||
#include <vector>
|
||||
|
||||
#include "model.h"
|
||||
#include "model_io/ckpt_io.h"
|
||||
#include "model_io/gguf_io.h"
|
||||
#include "model_io/safetensors_io.h"
|
||||
#include "model_io/torch_legacy_io.h"
|
||||
#include "model_io/torch_zip_io.h"
|
||||
#include "stable-diffusion.h"
|
||||
#include "util.h"
|
||||
|
||||
@ -79,7 +81,7 @@ const char* unused_tensors[] = {
|
||||
"first_stage_model.bn.",
|
||||
};
|
||||
|
||||
bool is_unused_tensor(std::string name) {
|
||||
bool is_unused_tensor(const std::string& name) {
|
||||
for (size_t i = 0; i < sizeof(unused_tensors) / sizeof(const char*); i++) {
|
||||
if (starts_with(name, unused_tensors[i])) {
|
||||
return true;
|
||||
@ -229,9 +231,12 @@ bool ModelLoader::init_from_file(const std::string& file_path, const std::string
|
||||
} else if (is_safetensors_file(file_path)) {
|
||||
LOG_INFO("load %s using safetensors format", file_path.c_str());
|
||||
return init_from_safetensors_file(file_path, prefix);
|
||||
} else if (is_ckpt_file(file_path)) {
|
||||
LOG_INFO("load %s using checkpoint format", file_path.c_str());
|
||||
return init_from_ckpt_file(file_path, prefix);
|
||||
} else if (is_torch_zip_file(file_path)) {
|
||||
LOG_INFO("load %s using torch zip format", file_path.c_str());
|
||||
return init_from_torch_zip_file(file_path, prefix);
|
||||
} else if (init_from_torch_legacy_file(file_path, prefix)) {
|
||||
LOG_INFO("load %s using torch legacy format", file_path.c_str());
|
||||
return true;
|
||||
} else {
|
||||
if (file_exists(file_path)) {
|
||||
LOG_WARN("unknown format %s", file_path.c_str());
|
||||
@ -329,6 +334,68 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
|
||||
return true;
|
||||
}
|
||||
|
||||
/*================================================= TorchLegacyModelLoader ==================================================*/
|
||||
|
||||
bool ModelLoader::init_from_torch_legacy_file(const std::string& file_path, const std::string& prefix) {
|
||||
LOG_DEBUG("init from torch legacy '%s'", file_path.c_str());
|
||||
|
||||
std::vector<TensorStorage> tensor_storages;
|
||||
std::string error;
|
||||
if (!read_torch_legacy_file(file_path, tensor_storages, &error)) {
|
||||
if ((!error.empty()) && (ends_with(file_path, ".pt") || ends_with(file_path, ".pth"))) {
|
||||
LOG_WARN("%s", error.c_str());
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
file_paths_.push_back(file_path);
|
||||
size_t file_index = file_paths_.size() - 1;
|
||||
|
||||
for (auto& tensor_storage : tensor_storages) {
|
||||
if (is_unused_tensor(tensor_storage.name)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!starts_with(tensor_storage.name, prefix)) {
|
||||
tensor_storage.name = prefix + tensor_storage.name;
|
||||
}
|
||||
tensor_storage.file_index = file_index;
|
||||
|
||||
add_tensor_storage(tensor_storage);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/*================================================= TorchZipModelLoader ==================================================*/
|
||||
|
||||
bool ModelLoader::init_from_torch_zip_file(const std::string& file_path, const std::string& prefix) {
|
||||
LOG_DEBUG("init from '%s'", file_path.c_str());
|
||||
|
||||
std::vector<TensorStorage> tensor_storages;
|
||||
std::string error;
|
||||
if (!read_torch_zip_file(file_path, tensor_storages, &error)) {
|
||||
LOG_ERROR("%s", error.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
file_paths_.push_back(file_path);
|
||||
size_t file_index = file_paths_.size() - 1;
|
||||
|
||||
for (auto& tensor_storage : tensor_storages) {
|
||||
if (!starts_with(tensor_storage.name, prefix)) {
|
||||
tensor_storage.name = prefix + tensor_storage.name;
|
||||
}
|
||||
tensor_storage.file_index = file_index;
|
||||
|
||||
add_tensor_storage(tensor_storage);
|
||||
|
||||
// LOG_DEBUG("%s", tensor_storage.to_string().c_str());
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/*================================================= DiffusersModelLoader ==================================================*/
|
||||
|
||||
bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const std::string& prefix) {
|
||||
@ -355,35 +422,6 @@ bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const s
|
||||
return true;
|
||||
}
|
||||
|
||||
/*================================================= CkptModelLoader ==================================================*/
|
||||
|
||||
bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::string& prefix) {
|
||||
LOG_DEBUG("init from '%s'", file_path.c_str());
|
||||
|
||||
std::vector<TensorStorage> tensor_storages;
|
||||
std::string error;
|
||||
if (!read_ckpt_file(file_path, tensor_storages, &error)) {
|
||||
LOG_ERROR("%s", error.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
file_paths_.push_back(file_path);
|
||||
size_t file_index = file_paths_.size() - 1;
|
||||
|
||||
for (auto& tensor_storage : tensor_storages) {
|
||||
if (!starts_with(tensor_storage.name, prefix)) {
|
||||
tensor_storage.name = prefix + tensor_storage.name;
|
||||
}
|
||||
tensor_storage.file_index = file_index;
|
||||
|
||||
add_tensor_storage(tensor_storage);
|
||||
|
||||
// LOG_DEBUG("%s", tensor_storage.to_string().c_str());
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
SDVersion ModelLoader::get_sd_version() {
|
||||
TensorStorage token_embedding_weight, input_block_weight;
|
||||
|
||||
@ -649,8 +687,8 @@ std::map<ggml_type, uint32_t> ModelLoader::get_vae_wtype_stat() {
|
||||
return wtype_stat;
|
||||
}
|
||||
|
||||
static std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std::string& tensor_type_rules) {
|
||||
std::vector<std::pair<std::string, ggml_type>> result;
|
||||
TensorTypeRules parse_tensor_type_rules(const std::string& tensor_type_rules) {
|
||||
TensorTypeRules result;
|
||||
for (const auto& item : split_string(tensor_type_rules, ',')) {
|
||||
if (item.size() == 0)
|
||||
continue;
|
||||
@ -1083,91 +1121,6 @@ bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules_str) {
|
||||
auto tensor_type_rules = parse_tensor_type_rules(tensor_type_rules_str);
|
||||
auto get_tensor_type = [&](const TensorStorage& tensor_storage) -> ggml_type {
|
||||
const std::string& name = tensor_storage.name;
|
||||
ggml_type tensor_type = tensor_storage.type;
|
||||
ggml_type dst_type = type;
|
||||
|
||||
for (const auto& tensor_type_rule : tensor_type_rules) {
|
||||
std::regex pattern(tensor_type_rule.first);
|
||||
if (std::regex_search(name, pattern)) {
|
||||
dst_type = tensor_type_rule.second;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (tensor_should_be_converted(tensor_storage, dst_type)) {
|
||||
tensor_type = dst_type;
|
||||
}
|
||||
|
||||
return tensor_type;
|
||||
};
|
||||
|
||||
auto backend = ggml_backend_cpu_init();
|
||||
size_t mem_size = 1 * 1024 * 1024; // for padding
|
||||
mem_size += tensor_storage_map.size() * ggml_tensor_overhead();
|
||||
mem_size += get_params_mem_size(backend, type);
|
||||
LOG_INFO("model tensors mem size: %.2fMB", mem_size / 1024.f / 1024.f);
|
||||
ggml_context* ggml_ctx = ggml_init({mem_size, nullptr, false});
|
||||
|
||||
if (ggml_ctx == nullptr) {
|
||||
LOG_ERROR("ggml_init failed for GGUF writer");
|
||||
ggml_backend_free(backend);
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<ggml_tensor*> tensors;
|
||||
std::mutex tensor_mutex;
|
||||
auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
|
||||
const std::string& name = tensor_storage.name;
|
||||
ggml_type tensor_type = get_tensor_type(tensor_storage);
|
||||
|
||||
std::lock_guard<std::mutex> lock(tensor_mutex);
|
||||
ggml_tensor* tensor = ggml_new_tensor(ggml_ctx, tensor_type, tensor_storage.n_dims, tensor_storage.ne);
|
||||
if (tensor == nullptr) {
|
||||
LOG_ERROR("ggml_new_tensor failed");
|
||||
return false;
|
||||
}
|
||||
ggml_set_name(tensor, name.c_str());
|
||||
|
||||
// LOG_DEBUG("%s %d %s %d[%d %d %d %d] %d[%d %d %d %d]", name.c_str(),
|
||||
// ggml_nbytes(tensor), ggml_type_name(tensor_type),
|
||||
// tensor_storage.n_dims,
|
||||
// tensor_storage.ne[0], tensor_storage.ne[1], tensor_storage.ne[2], tensor_storage.ne[3],
|
||||
// tensor->n_dims, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
|
||||
|
||||
if (!tensor->data) {
|
||||
GGML_ASSERT(ggml_nelements(tensor) == 0);
|
||||
// avoid crashing the gguf writer by setting a dummy pointer for zero-sized tensors
|
||||
LOG_DEBUG("setting dummy pointer for zero-sized tensor %s", name.c_str());
|
||||
tensor->data = ggml_get_mem_buffer(ggml_ctx);
|
||||
}
|
||||
|
||||
*dst_tensor = tensor;
|
||||
tensors.push_back(tensor);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
bool success = load_tensors(on_new_tensor_cb);
|
||||
ggml_backend_free(backend);
|
||||
LOG_INFO("load tensors done");
|
||||
|
||||
std::string error;
|
||||
if (success) {
|
||||
success = write_gguf_file(file_path, tensors, &error);
|
||||
}
|
||||
|
||||
if (!success && !error.empty()) {
|
||||
LOG_ERROR("%s", error.c_str());
|
||||
}
|
||||
|
||||
ggml_free(ggml_ctx);
|
||||
return success;
|
||||
}
|
||||
|
||||
int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type) {
|
||||
size_t alignment = 128;
|
||||
if (backend != nullptr) {
|
||||
@ -1187,29 +1140,3 @@ int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type)
|
||||
|
||||
return mem_size;
|
||||
}
|
||||
|
||||
bool convert(const char* input_path,
|
||||
const char* vae_path,
|
||||
const char* output_path,
|
||||
sd_type_t output_type,
|
||||
const char* tensor_type_rules,
|
||||
bool convert_name) {
|
||||
ModelLoader model_loader;
|
||||
|
||||
if (!model_loader.init_from_file(input_path)) {
|
||||
LOG_ERROR("init model loader from file failed: '%s'", input_path);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (vae_path != nullptr && strlen(vae_path) > 0) {
|
||||
if (!model_loader.init_from_file(vae_path, "vae.")) {
|
||||
LOG_ERROR("init model loader from file failed: '%s'", vae_path);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (convert_name) {
|
||||
model_loader.convert_tensors_name();
|
||||
}
|
||||
bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type, tensor_type_rules);
|
||||
return success;
|
||||
}
|
||||
|
||||
@ -189,6 +189,9 @@ enum PMVersion {
|
||||
};
|
||||
|
||||
typedef OrderedMap<std::string, TensorStorage> String2TensorStorage;
|
||||
using TensorTypeRules = std::vector<std::pair<std::string, ggml_type>>;
|
||||
|
||||
TensorTypeRules parse_tensor_type_rules(const std::string& tensor_type_rules);
|
||||
|
||||
class ModelLoader {
|
||||
protected:
|
||||
@ -200,7 +203,8 @@ protected:
|
||||
|
||||
bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = "");
|
||||
bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = "");
|
||||
bool init_from_ckpt_file(const std::string& file_path, const std::string& prefix = "");
|
||||
bool init_from_torch_zip_file(const std::string& file_path, const std::string& prefix = "");
|
||||
bool init_from_torch_legacy_file(const std::string& file_path, const std::string& prefix = "");
|
||||
bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = "");
|
||||
|
||||
public:
|
||||
@ -230,7 +234,6 @@ public:
|
||||
return names;
|
||||
}
|
||||
|
||||
bool save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules);
|
||||
bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type);
|
||||
int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);
|
||||
~ModelLoader() = default;
|
||||
|
||||
57
src/model_io/binary_io.h
Normal file
57
src/model_io/binary_io.h
Normal file
@ -0,0 +1,57 @@
|
||||
#ifndef __SD_MODEL_IO_BINARY_IO_H__
|
||||
#define __SD_MODEL_IO_BINARY_IO_H__
|
||||
|
||||
#include <cstdint>
|
||||
#include <ostream>
|
||||
|
||||
namespace model_io {
|
||||
|
||||
inline int32_t read_int(const uint8_t* buffer) {
|
||||
uint32_t value = 0;
|
||||
value |= static_cast<uint32_t>(buffer[3]) << 24;
|
||||
value |= static_cast<uint32_t>(buffer[2]) << 16;
|
||||
value |= static_cast<uint32_t>(buffer[1]) << 8;
|
||||
value |= static_cast<uint32_t>(buffer[0]);
|
||||
return static_cast<int32_t>(value);
|
||||
}
|
||||
|
||||
inline uint16_t read_short(const uint8_t* buffer) {
|
||||
uint16_t value = 0;
|
||||
value |= static_cast<uint16_t>(buffer[1]) << 8;
|
||||
value |= static_cast<uint16_t>(buffer[0]);
|
||||
return value;
|
||||
}
|
||||
|
||||
inline uint64_t read_u64(const uint8_t* buffer) {
|
||||
uint64_t value = 0;
|
||||
value |= static_cast<uint64_t>(buffer[7]) << 56;
|
||||
value |= static_cast<uint64_t>(buffer[6]) << 48;
|
||||
value |= static_cast<uint64_t>(buffer[5]) << 40;
|
||||
value |= static_cast<uint64_t>(buffer[4]) << 32;
|
||||
value |= static_cast<uint64_t>(buffer[3]) << 24;
|
||||
value |= static_cast<uint64_t>(buffer[2]) << 16;
|
||||
value |= static_cast<uint64_t>(buffer[1]) << 8;
|
||||
value |= static_cast<uint64_t>(buffer[0]);
|
||||
return value;
|
||||
}
|
||||
|
||||
inline void write_u64(std::ostream& stream, uint64_t value) {
|
||||
uint8_t buffer[8];
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
buffer[i] = static_cast<uint8_t>((value >> (8 * i)) & 0xFF);
|
||||
}
|
||||
stream.write((const char*)buffer, sizeof(buffer));
|
||||
}
|
||||
|
||||
inline int find_char(const uint8_t* buffer, int len, char c) {
|
||||
for (int pos = 0; pos < len; pos++) {
|
||||
if (buffer[pos] == (uint8_t)c) {
|
||||
return pos;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
} // namespace model_io
|
||||
|
||||
#endif // __SD_MODEL_IO_BINARY_IO_H__
|
||||
@ -1,403 +0,0 @@
|
||||
#include "ckpt_io.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "zip.h"
|
||||
|
||||
static constexpr int MAX_STRING_BUFFER = 512;
|
||||
|
||||
static void set_error(std::string* error, const std::string& message) {
|
||||
if (error != nullptr) {
|
||||
*error = message;
|
||||
}
|
||||
}
|
||||
|
||||
static int32_t read_int(const uint8_t* buffer) {
|
||||
// little endian
|
||||
uint32_t value = 0;
|
||||
value |= static_cast<uint32_t>(buffer[3]) << 24;
|
||||
value |= static_cast<uint32_t>(buffer[2]) << 16;
|
||||
value |= static_cast<uint32_t>(buffer[1]) << 8;
|
||||
value |= static_cast<uint32_t>(buffer[0]);
|
||||
return static_cast<int32_t>(value);
|
||||
}
|
||||
|
||||
static uint16_t read_short(const uint8_t* buffer) {
|
||||
// little endian
|
||||
uint16_t value = 0;
|
||||
value |= static_cast<uint16_t>(buffer[1]) << 8;
|
||||
value |= static_cast<uint16_t>(buffer[0]);
|
||||
return value;
|
||||
}
|
||||
|
||||
bool is_ckpt_file(const std::string& file_path) {
|
||||
zip_t* zip = zip_open(file_path.c_str(), 0, 'r');
|
||||
if (zip == nullptr) {
|
||||
return false;
|
||||
}
|
||||
zip_close(zip);
|
||||
return true;
|
||||
}
|
||||
|
||||
/*================================================= CkptModelLoader ==================================================*/
|
||||
|
||||
// $ python -m pickletools sd-v1-4/archive/data.pkl | head -n 100
|
||||
// 0: \x80 PROTO 2
|
||||
// 2: } EMPTY_DICT
|
||||
// 3: q BINPUT 0
|
||||
// 5: ( MARK
|
||||
// 6: X BINUNICODE 'epoch'
|
||||
// 16: q BINPUT 1
|
||||
// 18: K BININT1 6
|
||||
// 20: X BINUNICODE 'global_step'
|
||||
// 36: q BINPUT 2
|
||||
// 38: J BININT 470000
|
||||
// 43: X BINUNICODE 'pytorch-lightning_version'
|
||||
// 73: q BINPUT 3
|
||||
// 75: X BINUNICODE '1.4.2'
|
||||
// 85: q BINPUT 4
|
||||
// 87: X BINUNICODE 'state_dict'
|
||||
// 102: q BINPUT 5
|
||||
// 104: } EMPTY_DICT
|
||||
// 105: q BINPUT 6
|
||||
// 107: ( MARK
|
||||
// 108: X BINUNICODE 'betas'
|
||||
// 118: q BINPUT 7
|
||||
// 120: c GLOBAL 'torch._utils _rebuild_tensor_v2'
|
||||
// 153: q BINPUT 8
|
||||
// 155: ( MARK
|
||||
// 156: ( MARK
|
||||
// 157: X BINUNICODE 'storage'
|
||||
// 169: q BINPUT 9
|
||||
// 171: c GLOBAL 'torch FloatStorage'
|
||||
// 191: q BINPUT 10
|
||||
// 193: X BINUNICODE '0'
|
||||
// 199: q BINPUT 11
|
||||
// 201: X BINUNICODE 'cpu'
|
||||
// 209: q BINPUT 12
|
||||
// 211: M BININT2 1000
|
||||
// 214: t TUPLE (MARK at 156)
|
||||
// 215: q BINPUT 13
|
||||
// 217: Q BINPERSID
|
||||
// 218: K BININT1 0
|
||||
// 220: M BININT2 1000
|
||||
// ...............................
|
||||
// 3201: q BINPUT 250
|
||||
// 3203: R REDUCE
|
||||
// 3204: q BINPUT 251
|
||||
// 3206: X BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.weight'
|
||||
// 3264: q BINPUT 252
|
||||
// 3266: h BINGET 8
|
||||
// 3268: ( MARK
|
||||
// 3269: ( MARK
|
||||
// 3270: h BINGET 9
|
||||
// 3272: h BINGET 10
|
||||
// 3274: X BINUNICODE '30'
|
||||
// 3281: q BINPUT 253
|
||||
// 3283: h BINGET 12
|
||||
// 3285: J BININT 102400
|
||||
// 3290: t TUPLE (MARK at 3269)
|
||||
// 3291: q BINPUT 254
|
||||
// 3293: Q BINPERSID
|
||||
// 3294: K BININT1 0
|
||||
// 3296: ( MARK
|
||||
// 3297: M BININT2 320
|
||||
// 3300: M BININT2 320
|
||||
// 3303: K BININT1 1
|
||||
// 3305: K BININT1 1
|
||||
// 3307: t TUPLE (MARK at 3296)
|
||||
// 3308: q BINPUT 255
|
||||
// 3310: ( MARK
|
||||
// 3311: M BININT2 320
|
||||
// 3314: K BININT1 1
|
||||
// 3316: K BININT1 1
|
||||
// 3318: K BININT1 1
|
||||
// 3320: t TUPLE (MARK at 3310)
|
||||
// 3321: r LONG_BINPUT 256
|
||||
// 3326: \x89 NEWFALSE
|
||||
// 3327: h BINGET 16
|
||||
// 3329: ) EMPTY_TUPLE
|
||||
// 3330: R REDUCE
|
||||
// 3331: r LONG_BINPUT 257
|
||||
// 3336: t TUPLE (MARK at 3268)
|
||||
// 3337: r LONG_BINPUT 258
|
||||
// 3342: R REDUCE
|
||||
// 3343: r LONG_BINPUT 259
|
||||
// 3348: X BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.bias'
|
||||
// 3404: r LONG_BINPUT 260
|
||||
// 3409: h BINGET 8
|
||||
// 3411: ( MARK
|
||||
// 3412: ( MARK
|
||||
// 3413: h BINGET 9
|
||||
// 3415: h BINGET 10
|
||||
// 3417: X BINUNICODE '31'
|
||||
|
||||
struct PickleTensorReader {
|
||||
enum ReadPhase {
|
||||
READ_NAME,
|
||||
READ_DATA,
|
||||
CHECK_SIZE,
|
||||
READ_DIMENS
|
||||
};
|
||||
ReadPhase phase = READ_NAME;
|
||||
size_t entry_size = 0;
|
||||
int32_t nelements = 0;
|
||||
|
||||
TensorStorage tensor_storage;
|
||||
|
||||
static ggml_type global_type; // all pickle_tensors data type
|
||||
static bool read_global_type;
|
||||
|
||||
bool read_int_value(uint32_t value) {
|
||||
if (phase == CHECK_SIZE) {
|
||||
if (entry_size == value * ggml_type_size(tensor_storage.type)) {
|
||||
nelements = value;
|
||||
phase = READ_DIMENS;
|
||||
return true;
|
||||
} else {
|
||||
phase = READ_NAME;
|
||||
}
|
||||
} else if (phase == READ_DIMENS) {
|
||||
if (tensor_storage.n_dims + 1 > SD_MAX_DIMS) { // too many dimens
|
||||
phase = READ_NAME;
|
||||
tensor_storage.n_dims = 0;
|
||||
}
|
||||
if (nelements % value == 0) {
|
||||
tensor_storage.ne[tensor_storage.n_dims] = value;
|
||||
tensor_storage.n_dims++;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void read_global(const std::string& str) {
|
||||
if (str == "FloatStorage") {
|
||||
if (read_global_type) {
|
||||
global_type = GGML_TYPE_F32;
|
||||
read_global_type = false;
|
||||
}
|
||||
tensor_storage.type = GGML_TYPE_F32;
|
||||
} else if (str == "HalfStorage") {
|
||||
if (read_global_type) {
|
||||
global_type = GGML_TYPE_F16;
|
||||
read_global_type = false;
|
||||
}
|
||||
tensor_storage.type = GGML_TYPE_F16;
|
||||
}
|
||||
}
|
||||
|
||||
void read_string(const std::string& str, zip_t* zip, std::string dir) {
|
||||
if (str == "storage") {
|
||||
read_global_type = true;
|
||||
} else if (str != "state_dict") {
|
||||
if (phase == READ_DATA) {
|
||||
std::string entry_name = dir + "data/" + std::string(str);
|
||||
|
||||
size_t i, n = zip_entries_total(zip);
|
||||
for (i = 0; i < n; ++i) {
|
||||
zip_entry_openbyindex(zip, i);
|
||||
{
|
||||
std::string name = zip_entry_name(zip);
|
||||
if (name == entry_name) {
|
||||
tensor_storage.index_in_zip = (int)i;
|
||||
entry_size = zip_entry_size(zip);
|
||||
zip_entry_close(zip);
|
||||
break;
|
||||
}
|
||||
}
|
||||
zip_entry_close(zip);
|
||||
}
|
||||
|
||||
phase = entry_size > 0 ? CHECK_SIZE : READ_NAME;
|
||||
}
|
||||
if (!read_global_type && phase == READ_NAME) {
|
||||
tensor_storage.name = str;
|
||||
phase = READ_DATA;
|
||||
tensor_storage.type = global_type;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
ggml_type PickleTensorReader::global_type = GGML_TYPE_F32; // all pickle_tensors data type
|
||||
bool PickleTensorReader::read_global_type = false;
|
||||
|
||||
static int find_char(uint8_t* buffer, int len, char c) {
|
||||
for (int pos = 0; pos < len; pos++) {
|
||||
if (buffer[pos] == c) {
|
||||
return pos;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
static bool parse_data_pkl(uint8_t* buffer,
|
||||
size_t buffer_size,
|
||||
zip_t* zip,
|
||||
std::string dir,
|
||||
std::vector<TensorStorage>& tensor_storages,
|
||||
std::string* error) {
|
||||
uint8_t* buffer_end = buffer + buffer_size;
|
||||
if (buffer[0] == 0x80) { // proto
|
||||
if (buffer[1] != 2) {
|
||||
set_error(error, "unsupported pickle protocol");
|
||||
return false;
|
||||
}
|
||||
buffer += 2; // 0x80 and version
|
||||
char string_buffer[MAX_STRING_BUFFER];
|
||||
bool finish = false;
|
||||
PickleTensorReader reader;
|
||||
// read pickle binary file
|
||||
while (!finish && buffer < buffer_end) {
|
||||
uint8_t opcode = *buffer;
|
||||
buffer++;
|
||||
// https://github.com/python/cpython/blob/3.7/Lib/pickletools.py#L1048
|
||||
// https://github.com/python/cpython/blob/main/Lib/pickle.py#L105
|
||||
switch (opcode) {
|
||||
case '}': // EMPTY_DICT = b'}' # push empty dict
|
||||
break;
|
||||
case ']': // EMPTY_LIST = b']' # push empty list
|
||||
break;
|
||||
// skip unused sections
|
||||
case 'h': // BINGET = b'h' # " " " " " " ; " " 1-byte arg
|
||||
case 'q': // BINPUT = b'q' # " " " " " ; " " 1-byte arg
|
||||
case 'Q': // BINPERSID = b'Q' # " " " ; " " " " stack
|
||||
buffer++;
|
||||
break;
|
||||
case 'r': // LONG_BINPUT = b'r' # " " " " " ; " " 4-byte arg
|
||||
buffer += 4;
|
||||
break;
|
||||
case 0x95: // FRAME = b'\x95' # indicate the beginning of a new frame
|
||||
buffer += 8;
|
||||
break;
|
||||
case 0x94: // MEMOIZE = b'\x94' # store top of the stack in memo
|
||||
break;
|
||||
case '(': // MARK = b'(' # push special markobject on stack
|
||||
break;
|
||||
case 'K': // BININT1 = b'K' # push 1-byte unsigned int
|
||||
{
|
||||
uint8_t value = *buffer;
|
||||
if (reader.read_int_value(value)) {
|
||||
buffer++;
|
||||
}
|
||||
buffer++;
|
||||
} break;
|
||||
case 'M': // BININT2 = b'M' # push 2-byte unsigned int
|
||||
{
|
||||
uint16_t value = read_short(buffer);
|
||||
if (reader.read_int_value(value)) {
|
||||
buffer++;
|
||||
}
|
||||
buffer += 2;
|
||||
} break;
|
||||
case 'J': // BININT = b'J' # push four-byte signed int
|
||||
{
|
||||
const int32_t value = read_int(buffer);
|
||||
if (reader.read_int_value(value)) {
|
||||
buffer++; // skip tuple after read num_elements
|
||||
}
|
||||
buffer += 4;
|
||||
} break;
|
||||
case 'X': // BINUNICODE = b'X' # " " " ; counted UTF-8 string argument
|
||||
{
|
||||
const int32_t len = read_int(buffer);
|
||||
buffer += 4;
|
||||
memset(string_buffer, 0, MAX_STRING_BUFFER);
|
||||
if (len > MAX_STRING_BUFFER) {
|
||||
// keep truncated names null-terminated, matching the old parser behavior
|
||||
}
|
||||
memcpy(string_buffer, buffer, len < MAX_STRING_BUFFER ? len : (MAX_STRING_BUFFER - 1));
|
||||
buffer += len;
|
||||
reader.read_string(string_buffer, zip, dir);
|
||||
} break;
|
||||
case 0x8C: // SHORT_BINUNICODE = b'\x8c' # push short string; UTF-8 length < 256 bytes
|
||||
{
|
||||
const int8_t len = *buffer;
|
||||
buffer++;
|
||||
memset(string_buffer, 0, MAX_STRING_BUFFER);
|
||||
memcpy(string_buffer, buffer, len);
|
||||
buffer += len;
|
||||
// printf("String: '%s'\n", string_buffer);
|
||||
} break;
|
||||
case 'c': // GLOBAL = b'c' # push self.find_class(modname, name); 2 string args
|
||||
{
|
||||
int len = find_char(buffer, MAX_STRING_BUFFER, '\n');
|
||||
|
||||
buffer += len + 1;
|
||||
len = find_char(buffer, MAX_STRING_BUFFER, '\n');
|
||||
|
||||
memset(string_buffer, 0, MAX_STRING_BUFFER);
|
||||
memcpy(string_buffer, buffer, len);
|
||||
buffer += len + 1;
|
||||
reader.read_global(string_buffer);
|
||||
} break;
|
||||
case 0x86: // TUPLE2 = b'\x86' # build 2-tuple from two topmost stack items
|
||||
case 0x85: // TUPLE1 = b'\x85' # build 1-tuple from stack top
|
||||
case 't': // TUPLE = b't' # build tuple from topmost stack items
|
||||
if (reader.phase == PickleTensorReader::READ_DIMENS) {
|
||||
reader.tensor_storage.reverse_ne();
|
||||
tensor_storages.push_back(reader.tensor_storage);
|
||||
|
||||
// LOG_DEBUG("%s", reader.tensor_storage.name.c_str());
|
||||
// reset
|
||||
reader = PickleTensorReader();
|
||||
}
|
||||
break;
|
||||
case '.': // STOP = b'.' # every pickle ends with STOP
|
||||
finish = true;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool read_ckpt_file(const std::string& file_path,
|
||||
std::vector<TensorStorage>& tensor_storages,
|
||||
std::string* error) {
|
||||
zip_t* zip = zip_open(file_path.c_str(), 0, 'r');
|
||||
if (zip == nullptr) {
|
||||
set_error(error, "failed to open '" + file_path + "'");
|
||||
return false;
|
||||
}
|
||||
|
||||
tensor_storages.clear();
|
||||
bool success = true;
|
||||
int n = (int)zip_entries_total(zip);
|
||||
for (int i = 0; i < n; ++i) {
|
||||
zip_entry_openbyindex(zip, i);
|
||||
{
|
||||
std::string name = zip_entry_name(zip);
|
||||
size_t pos = name.find("data.pkl");
|
||||
if (pos != std::string::npos) {
|
||||
std::string dir = name.substr(0, pos);
|
||||
printf("ZIP %d, name = %s, dir = %s \n", i, name.c_str(), dir.c_str());
|
||||
void* pkl_data = nullptr;
|
||||
size_t pkl_size;
|
||||
zip_entry_read(zip, &pkl_data, &pkl_size);
|
||||
|
||||
// LOG_DEBUG("%lld", pkl_size);
|
||||
|
||||
if (!parse_data_pkl((uint8_t*)pkl_data, pkl_size, zip, dir, tensor_storages, error)) {
|
||||
success = false;
|
||||
}
|
||||
|
||||
free(pkl_data);
|
||||
}
|
||||
}
|
||||
zip_entry_close(zip);
|
||||
|
||||
if (!success) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
zip_close(zip);
|
||||
return success;
|
||||
}
|
||||
@ -1,14 +0,0 @@
|
||||
#ifndef __SD_MODEL_IO_CKPT_IO_H__
|
||||
#define __SD_MODEL_IO_CKPT_IO_H__
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensor_storage.h"
|
||||
|
||||
bool is_ckpt_file(const std::string& file_path);
|
||||
bool read_ckpt_file(const std::string& file_path,
|
||||
std::vector<TensorStorage>& tensor_storages,
|
||||
std::string* error = nullptr);
|
||||
|
||||
#endif // __SD_MODEL_IO_CKPT_IO_H__
|
||||
@ -95,7 +95,7 @@ bool read_gguf_file(const std::string& file_path,
|
||||
}
|
||||
|
||||
bool write_gguf_file(const std::string& file_path,
|
||||
const std::vector<ggml_tensor*>& tensors,
|
||||
const std::vector<TensorWriteInfo>& tensors,
|
||||
std::string* error) {
|
||||
gguf_context* gguf_ctx = gguf_init_empty();
|
||||
if (gguf_ctx == nullptr) {
|
||||
@ -103,7 +103,8 @@ bool write_gguf_file(const std::string& file_path,
|
||||
return false;
|
||||
}
|
||||
|
||||
for (ggml_tensor* tensor : tensors) {
|
||||
for (const TensorWriteInfo& write_tensor : tensors) {
|
||||
ggml_tensor* tensor = write_tensor.tensor;
|
||||
if (tensor == nullptr) {
|
||||
set_error(error, "null tensor cannot be written to GGUF");
|
||||
gguf_free(gguf_ctx);
|
||||
|
||||
@ -11,7 +11,7 @@ bool read_gguf_file(const std::string& file_path,
|
||||
std::vector<TensorStorage>& tensor_storages,
|
||||
std::string* error = nullptr);
|
||||
bool write_gguf_file(const std::string& file_path,
|
||||
const std::vector<ggml_tensor*>& tensors,
|
||||
const std::vector<TensorWriteInfo>& tensors,
|
||||
std::string* error = nullptr);
|
||||
|
||||
#endif // __SD_MODEL_IO_GGUF_IO_H__
|
||||
|
||||
1064
src/model_io/pickle_io.cpp
Normal file
1064
src/model_io/pickle_io.cpp
Normal file
File diff suppressed because it is too large
Load Diff
21
src/model_io/pickle_io.h
Normal file
21
src/model_io/pickle_io.h
Normal file
@ -0,0 +1,21 @@
|
||||
#ifndef __SD_MODEL_IO_PICKLE_IO_H__
|
||||
#define __SD_MODEL_IO_PICKLE_IO_H__
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "tensor_storage.h"
|
||||
|
||||
bool skip_pickle_object(const uint8_t* buffer, size_t buffer_size, size_t* object_size);
|
||||
bool pickle_object_is_torch_magic_number(const uint8_t* buffer, size_t buffer_size);
|
||||
bool parse_pickle_uint32_object(const uint8_t* buffer, size_t buffer_size, uint32_t* value);
|
||||
bool parse_torch_state_dict_pickle(const uint8_t* buffer,
|
||||
size_t buffer_size,
|
||||
std::vector<TensorStorage>& tensor_storages,
|
||||
std::unordered_map<std::string, uint64_t>& storage_nbytes,
|
||||
std::string* error = nullptr);
|
||||
|
||||
#endif // __SD_MODEL_IO_PICKLE_IO_H__
|
||||
@ -6,7 +6,9 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "binary_io.h"
|
||||
#include "json.hpp"
|
||||
#include "util.h"
|
||||
|
||||
static constexpr size_t ST_HEADER_SIZE_LEN = 8;
|
||||
|
||||
@ -16,20 +18,6 @@ static void set_error(std::string* error, const std::string& message) {
|
||||
}
|
||||
}
|
||||
|
||||
static uint64_t read_u64(const uint8_t* buffer) {
|
||||
// little endian
|
||||
uint64_t value = 0;
|
||||
value |= static_cast<uint64_t>(buffer[7]) << 56;
|
||||
value |= static_cast<uint64_t>(buffer[6]) << 48;
|
||||
value |= static_cast<uint64_t>(buffer[5]) << 40;
|
||||
value |= static_cast<uint64_t>(buffer[4]) << 32;
|
||||
value |= static_cast<uint64_t>(buffer[3]) << 24;
|
||||
value |= static_cast<uint64_t>(buffer[2]) << 16;
|
||||
value |= static_cast<uint64_t>(buffer[1]) << 8;
|
||||
value |= static_cast<uint64_t>(buffer[0]);
|
||||
return value;
|
||||
}
|
||||
|
||||
bool is_safetensors_file(const std::string& file_path) {
|
||||
std::ifstream file(file_path, std::ios::binary);
|
||||
if (!file.is_open()) {
|
||||
@ -52,7 +40,7 @@ bool is_safetensors_file(const std::string& file_path) {
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t header_size_ = read_u64(header_size_buf);
|
||||
size_t header_size_ = model_io::read_u64(header_size_buf);
|
||||
if (header_size_ >= file_size_ || header_size_ <= 2) {
|
||||
return false;
|
||||
}
|
||||
@ -73,7 +61,7 @@ bool is_safetensors_file(const std::string& file_path) {
|
||||
return true;
|
||||
}
|
||||
|
||||
static ggml_type str_to_ggml_type(const std::string& dtype) {
|
||||
static ggml_type safetensors_dtype_to_ggml_type(const std::string& dtype) {
|
||||
ggml_type ttype = GGML_TYPE_COUNT;
|
||||
if (dtype == "F16") {
|
||||
ttype = GGML_TYPE_F16;
|
||||
@ -123,7 +111,7 @@ bool read_safetensors_file(const std::string& file_path,
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t header_size_ = read_u64(header_size_buf);
|
||||
size_t header_size_ = model_io::read_u64(header_size_buf);
|
||||
if (header_size_ >= file_size_) {
|
||||
set_error(error, "invalid safetensor file '" + file_path + "'");
|
||||
return false;
|
||||
@ -167,7 +155,7 @@ bool read_safetensors_file(const std::string& file_path,
|
||||
size_t begin = tensor_info["data_offsets"][0].get<size_t>();
|
||||
size_t end = tensor_info["data_offsets"][1].get<size_t>();
|
||||
|
||||
ggml_type type = str_to_ggml_type(dtype);
|
||||
ggml_type type = safetensors_dtype_to_ggml_type(dtype);
|
||||
if (type == GGML_TYPE_COUNT) {
|
||||
set_error(error, "unsupported dtype '" + dtype + "' (tensor '" + name + "')");
|
||||
return false;
|
||||
@ -234,3 +222,95 @@ bool read_safetensors_file(const std::string& file_path,
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_type_to_safetensors_dtype(ggml_type type, std::string* dtype) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_F16:
|
||||
*dtype = "F16";
|
||||
return true;
|
||||
case GGML_TYPE_BF16:
|
||||
*dtype = "BF16";
|
||||
return true;
|
||||
case GGML_TYPE_F32:
|
||||
*dtype = "F32";
|
||||
return true;
|
||||
case GGML_TYPE_I32:
|
||||
*dtype = "I32";
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool write_safetensors_file(const std::string& file_path,
|
||||
const std::vector<TensorWriteInfo>& tensors,
|
||||
std::string* error) {
|
||||
nlohmann::ordered_json header = nlohmann::ordered_json::object();
|
||||
|
||||
uint64_t data_offset = 0;
|
||||
for (const TensorWriteInfo& write_tensor : tensors) {
|
||||
ggml_tensor* tensor = write_tensor.tensor;
|
||||
if (tensor == nullptr) {
|
||||
set_error(error, "null tensor cannot be written to safetensors");
|
||||
return false;
|
||||
}
|
||||
|
||||
const std::string name = ggml_get_name(tensor);
|
||||
std::string dtype;
|
||||
if (!ggml_type_to_safetensors_dtype(tensor->type, &dtype)) {
|
||||
set_error(error,
|
||||
"unsupported safetensors dtype '" + std::string(ggml_type_name(tensor->type)) +
|
||||
"' for tensor '" + name + "'");
|
||||
return false;
|
||||
}
|
||||
|
||||
const uint64_t tensor_nbytes = ggml_nbytes(tensor);
|
||||
|
||||
nlohmann::ordered_json json_tensor_info = nlohmann::ordered_json::object();
|
||||
json_tensor_info["dtype"] = dtype;
|
||||
|
||||
nlohmann::ordered_json shape = nlohmann::ordered_json::array();
|
||||
for (int i = 0; i < write_tensor.n_dims; ++i) {
|
||||
shape.push_back(write_tensor.ne[write_tensor.n_dims - 1 - i]);
|
||||
}
|
||||
json_tensor_info["shape"] = shape;
|
||||
|
||||
nlohmann::ordered_json data_offsets = nlohmann::ordered_json::array();
|
||||
data_offsets.push_back(data_offset);
|
||||
data_offsets.push_back(data_offset + tensor_nbytes);
|
||||
json_tensor_info["data_offsets"] = data_offsets;
|
||||
|
||||
header[name] = json_tensor_info;
|
||||
data_offset += tensor_nbytes;
|
||||
}
|
||||
|
||||
const std::string header_str = header.dump();
|
||||
|
||||
std::ofstream file(file_path, std::ios::binary);
|
||||
if (!file.is_open()) {
|
||||
set_error(error, "failed to open '" + file_path + "' for writing");
|
||||
return false;
|
||||
}
|
||||
|
||||
LOG_INFO("trying to save tensors to %s", file_path.c_str());
|
||||
model_io::write_u64(file, header_str.size());
|
||||
file.write(header_str.data(), header_str.size());
|
||||
if (!file) {
|
||||
set_error(error, "failed to write safetensors header to '" + file_path + "'");
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const TensorWriteInfo& write_tensor : tensors) {
|
||||
ggml_tensor* tensor = write_tensor.tensor;
|
||||
const std::string name = ggml_get_name(tensor);
|
||||
const size_t tensor_nbytes = ggml_nbytes(tensor);
|
||||
file.write((const char*)tensor->data, tensor_nbytes);
|
||||
if (!file) {
|
||||
set_error(error,
|
||||
"failed to write tensor '" + name + "' to '" + file_path + "'");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -10,5 +10,8 @@ bool is_safetensors_file(const std::string& file_path);
|
||||
bool read_safetensors_file(const std::string& file_path,
|
||||
std::vector<TensorStorage>& tensor_storages,
|
||||
std::string* error = nullptr);
|
||||
bool write_safetensors_file(const std::string& file_path,
|
||||
const std::vector<TensorWriteInfo>& tensors,
|
||||
std::string* error = nullptr);
|
||||
|
||||
#endif // __SD_MODEL_IO_SAFETENSORS_IO_H__
|
||||
|
||||
@ -24,6 +24,7 @@ struct TensorStorage {
|
||||
int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
|
||||
int n_dims = 0;
|
||||
|
||||
std::string storage_key;
|
||||
size_t file_index = 0;
|
||||
int index_in_zip = -1; // >= means stored in a zip file
|
||||
uint64_t offset = 0; // offset in file
|
||||
@ -120,6 +121,12 @@ struct TensorStorage {
|
||||
}
|
||||
};
|
||||
|
||||
struct TensorWriteInfo {
|
||||
int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
|
||||
int n_dims = 0;
|
||||
ggml_tensor* tensor = nullptr;
|
||||
};
|
||||
|
||||
typedef std::function<bool(const TensorStorage&, ggml_tensor**)> on_new_tensor_cb_t;
|
||||
|
||||
#endif // __SD_TENSOR_STORAGE_H__
|
||||
|
||||
252
src/model_io/torch_legacy_io.cpp
Normal file
252
src/model_io/torch_legacy_io.cpp
Normal file
@ -0,0 +1,252 @@
|
||||
#include "torch_legacy_io.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "pickle_io.h"
|
||||
#include "util.h"
|
||||
|
||||
// torch.save format background:
|
||||
//
|
||||
// - Before PyTorch 1.6.0, torch.save used this legacy non-zip format by
|
||||
// default.
|
||||
// - Since PyTorch 1.6.0, torch.save defaults to an uncompressed ZIP64 archive
|
||||
// containing data.pkl, data/, version, and, since PyTorch 2.1.0, byteorder.
|
||||
// - The old format can still be produced explicitly with:
|
||||
// torch.save(obj, path, _use_new_zipfile_serialization=False)
|
||||
//
|
||||
// Whether obj is a state_dict or a whole nn.Module does not change the outer
|
||||
// container format selected by torch.save. It changes the pickled object inside:
|
||||
//
|
||||
// - state_dict: usually an OrderedDict[str, Tensor]. pickle_io.cpp supports a
|
||||
// restricted subset of this layout because tensor metadata and raw storages
|
||||
// can be recovered without executing pickle callables.
|
||||
// - whole module/checkpoint object: arbitrary Python object graph. This may
|
||||
// require importing user classes and executing pickle GLOBAL/REDUCE rebuild
|
||||
// logic, so it is intentionally not supported here.
|
||||
//
|
||||
// Legacy non-zip PyTorch files are not a single pickle object:
|
||||
//
|
||||
// 1. pickle object: PyTorch legacy magic number
|
||||
// 2. pickle object: legacy protocol version, expected to be 1001
|
||||
// 3. pickle object: sys_info metadata, ignored by this reader
|
||||
// 4. pickle object: state_dict metadata, parsed by pickle_io.cpp
|
||||
// 5. pickle object: serialized storage key list, skipped here
|
||||
// 6. raw storage data payloads
|
||||
// - PyTorch writes storages after the pickles, ordered by storage key
|
||||
// - each storage has an 8-byte legacy storage header followed by raw bytes
|
||||
static constexpr size_t LEGACY_STORAGE_HEADER_SIZE = 8;
|
||||
|
||||
static void set_error(std::string* error, const std::string& message) {
|
||||
if (error != nullptr) {
|
||||
*error = message;
|
||||
}
|
||||
}
|
||||
|
||||
static std::string bytes_to_hex(const std::vector<uint8_t>& bytes) {
|
||||
static const char* hex = "0123456789ABCDEF";
|
||||
std::string result;
|
||||
result.reserve(bytes.size() * 3);
|
||||
for (size_t i = 0; i < bytes.size(); ++i) {
|
||||
if (i > 0) {
|
||||
result.push_back('-');
|
||||
}
|
||||
result.push_back(hex[(bytes[i] >> 4) & 0x0F]);
|
||||
result.push_back(hex[bytes[i] & 0x0F]);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static bool is_probably_tar_file(const std::vector<uint8_t>& header) {
|
||||
return header.size() >= 262 &&
|
||||
header[257] == 'u' &&
|
||||
header[258] == 's' &&
|
||||
header[259] == 't' &&
|
||||
header[260] == 'a' &&
|
||||
header[261] == 'r';
|
||||
}
|
||||
|
||||
static std::string torch_legacy_diagnostics(const std::string& file_path, const std::vector<uint8_t>& buffer) {
|
||||
if (!ends_with(file_path, ".pt") && !ends_with(file_path, ".pth")) {
|
||||
return "";
|
||||
}
|
||||
if (buffer.empty()) {
|
||||
return "unsupported PyTorch file '" + file_path + "': empty file";
|
||||
}
|
||||
|
||||
size_t short_len = std::min<size_t>(buffer.size(), 32);
|
||||
std::vector<uint8_t> short_header(buffer.begin(), buffer.begin() + short_len);
|
||||
const bool raw_pickle = buffer[0] == 0x80;
|
||||
const bool tar_file = is_probably_tar_file(buffer);
|
||||
|
||||
std::string message = "unsupported PyTorch file '" + file_path + "': first bytes " +
|
||||
bytes_to_hex(short_header) +
|
||||
", raw_pickle=" + (raw_pickle ? "true" : "false") +
|
||||
", tar=" + (tar_file ? "true" : "false");
|
||||
if (raw_pickle) {
|
||||
message += "; raw pickle did not match the restricted state_dict layouts currently supported";
|
||||
} else if (tar_file) {
|
||||
message += "; legacy tar PyTorch checkpoints are not supported yet";
|
||||
}
|
||||
return message;
|
||||
}
|
||||
|
||||
bool read_torch_legacy_file(const std::string& file_path,
|
||||
std::vector<TensorStorage>& tensor_storages,
|
||||
std::string* error) {
|
||||
std::ifstream file(file_path, std::ios::binary);
|
||||
if (!file.is_open()) {
|
||||
set_error(error, "failed to open '" + file_path + "'");
|
||||
return false;
|
||||
}
|
||||
|
||||
file.seekg(0, file.end);
|
||||
size_t file_size = (size_t)file.tellg();
|
||||
file.seekg(0, file.beg);
|
||||
if (file_size == 0) {
|
||||
set_error(error, "empty file '" + file_path + "'");
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> buffer(file_size);
|
||||
file.read((char*)buffer.data(), file_size);
|
||||
if (!file) {
|
||||
set_error(error, "failed to read '" + file_path + "'");
|
||||
return false;
|
||||
}
|
||||
|
||||
auto finalize_tensor_offsets = [&](size_t storage_data_offset,
|
||||
const std::unordered_map<std::string, uint64_t>& legacy_storage_map) -> bool {
|
||||
if (storage_data_offset > file_size) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<std::string> storage_keys;
|
||||
storage_keys.reserve(legacy_storage_map.size());
|
||||
for (const auto& [storage_key, _] : legacy_storage_map) {
|
||||
storage_keys.push_back(storage_key);
|
||||
}
|
||||
std::sort(storage_keys.begin(), storage_keys.end());
|
||||
|
||||
std::unordered_map<std::string, uint64_t> storage_offsets;
|
||||
uint64_t current_offset = storage_data_offset;
|
||||
for (const auto& storage_key : storage_keys) {
|
||||
auto it = legacy_storage_map.find(storage_key);
|
||||
if (it == legacy_storage_map.end()) {
|
||||
return false;
|
||||
}
|
||||
if (current_offset + LEGACY_STORAGE_HEADER_SIZE + it->second > file_size) {
|
||||
return false;
|
||||
}
|
||||
storage_offsets[storage_key] = current_offset + LEGACY_STORAGE_HEADER_SIZE;
|
||||
current_offset += LEGACY_STORAGE_HEADER_SIZE + it->second;
|
||||
}
|
||||
|
||||
for (auto& tensor_storage : tensor_storages) {
|
||||
if (tensor_storage.storage_key.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto it_offset = storage_offsets.find(tensor_storage.storage_key);
|
||||
auto it_size = legacy_storage_map.find(tensor_storage.storage_key);
|
||||
if (it_offset == storage_offsets.end() || it_size == legacy_storage_map.end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
uint64_t base_offset = it_offset->second;
|
||||
uint64_t storage_nbytes = it_size->second;
|
||||
uint64_t tensor_nbytes = tensor_storage.nbytes_to_read();
|
||||
if (tensor_storage.offset + tensor_nbytes > storage_nbytes) {
|
||||
return false;
|
||||
}
|
||||
|
||||
tensor_storage.offset = base_offset + tensor_storage.offset;
|
||||
tensor_storage.storage_key.clear();
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto parse_state_dict_at = [&](size_t state_dict_offset, size_t state_dict_size, size_t* storage_data_offset) -> bool {
|
||||
tensor_storages.clear();
|
||||
std::unordered_map<std::string, uint64_t> legacy_storage_map;
|
||||
if (!parse_torch_state_dict_pickle(buffer.data() + state_dict_offset,
|
||||
state_dict_size,
|
||||
tensor_storages,
|
||||
legacy_storage_map,
|
||||
error)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t offset_after_state_dict = state_dict_offset + state_dict_size;
|
||||
size_t storage_keys_size = 0;
|
||||
if (!skip_pickle_object(buffer.data() + offset_after_state_dict,
|
||||
buffer.size() - offset_after_state_dict,
|
||||
&storage_keys_size)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
*storage_data_offset = offset_after_state_dict + storage_keys_size;
|
||||
return finalize_tensor_offsets(*storage_data_offset, legacy_storage_map);
|
||||
};
|
||||
|
||||
size_t object_size_1 = 0;
|
||||
size_t offset = 0;
|
||||
|
||||
if (skip_pickle_object(buffer.data(), buffer.size(), &object_size_1) &&
|
||||
pickle_object_is_torch_magic_number(buffer.data(), object_size_1)) {
|
||||
offset += object_size_1;
|
||||
|
||||
size_t object_size_2 = 0;
|
||||
if (!skip_pickle_object(buffer.data() + offset, buffer.size() - offset, &object_size_2)) {
|
||||
set_error(error, torch_legacy_diagnostics(file_path, buffer));
|
||||
return false;
|
||||
}
|
||||
uint32_t protocol_version = 0;
|
||||
if (!parse_pickle_uint32_object(buffer.data() + offset, object_size_2, &protocol_version) || protocol_version != 1001) {
|
||||
set_error(error, torch_legacy_diagnostics(file_path, buffer));
|
||||
return false;
|
||||
}
|
||||
offset += object_size_2;
|
||||
|
||||
size_t object_size_3 = 0;
|
||||
if (!skip_pickle_object(buffer.data() + offset, buffer.size() - offset, &object_size_3)) {
|
||||
set_error(error, torch_legacy_diagnostics(file_path, buffer));
|
||||
return false;
|
||||
}
|
||||
offset += object_size_3;
|
||||
|
||||
size_t state_dict_size = 0;
|
||||
if (!skip_pickle_object(buffer.data() + offset, buffer.size() - offset, &state_dict_size)) {
|
||||
set_error(error, torch_legacy_diagnostics(file_path, buffer));
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t storage_data_offset = 0;
|
||||
if (parse_state_dict_at(offset, state_dict_size, &storage_data_offset)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (error != nullptr && error->empty()) {
|
||||
set_error(error, torch_legacy_diagnostics(file_path, buffer));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t state_dict_size = 0;
|
||||
if (skip_pickle_object(buffer.data(), buffer.size(), &state_dict_size)) {
|
||||
size_t storage_data_offset = 0;
|
||||
if (parse_state_dict_at(0, state_dict_size, &storage_data_offset)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (error != nullptr && error->empty()) {
|
||||
set_error(error, torch_legacy_diagnostics(file_path, buffer));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
13
src/model_io/torch_legacy_io.h
Normal file
13
src/model_io/torch_legacy_io.h
Normal file
@ -0,0 +1,13 @@
|
||||
#ifndef __SD_MODEL_IO_TORCH_LEGACY_IO_H__
|
||||
#define __SD_MODEL_IO_TORCH_LEGACY_IO_H__
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensor_storage.h"
|
||||
|
||||
bool read_torch_legacy_file(const std::string& file_path,
|
||||
std::vector<TensorStorage>& tensor_storages,
|
||||
std::string* error = nullptr);
|
||||
|
||||
#endif // __SD_MODEL_IO_TORCH_LEGACY_IO_H__
|
||||
140
src/model_io/torch_zip_io.cpp
Normal file
140
src/model_io/torch_zip_io.cpp
Normal file
@ -0,0 +1,140 @@
|
||||
#include "torch_zip_io.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "pickle_io.h"
|
||||
|
||||
#include "zip.h"
|
||||
|
||||
static void set_error(std::string* error, const std::string& message) {
|
||||
if (error != nullptr) {
|
||||
*error = message;
|
||||
}
|
||||
}
|
||||
|
||||
bool is_torch_zip_file(const std::string& file_path) {
|
||||
zip_t* zip = zip_open(file_path.c_str(), 0, 'r');
|
||||
if (zip == nullptr) {
|
||||
return false;
|
||||
}
|
||||
zip_close(zip);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool find_zip_entry(zip_t* zip, const std::string& entry_name, int* index, uint64_t* size) {
|
||||
size_t n = zip_entries_total(zip);
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
zip_entry_openbyindex(zip, i);
|
||||
std::string name = zip_entry_name(zip);
|
||||
if (name == entry_name) {
|
||||
*index = (int)i;
|
||||
*size = zip_entry_size(zip);
|
||||
zip_entry_close(zip);
|
||||
return true;
|
||||
}
|
||||
zip_entry_close(zip);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool parse_zip_data_pkl(const uint8_t* buffer,
|
||||
size_t buffer_size,
|
||||
zip_t* zip,
|
||||
const std::string& dir,
|
||||
std::vector<TensorStorage>& tensor_storages,
|
||||
std::string* error) {
|
||||
std::vector<TensorStorage> parsed_tensors;
|
||||
std::unordered_map<std::string, uint64_t> storage_nbytes;
|
||||
if (!parse_torch_state_dict_pickle(buffer, buffer_size, parsed_tensors, storage_nbytes, error)) {
|
||||
if (error != nullptr && error->empty()) {
|
||||
*error = "failed to parse torch zip pickle metadata";
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto& tensor_storage : parsed_tensors) {
|
||||
if (tensor_storage.storage_key.empty()) {
|
||||
set_error(error, "tensor '" + tensor_storage.name + "' has no storage key");
|
||||
return false;
|
||||
}
|
||||
|
||||
const std::string entry_name = dir + "data/" + tensor_storage.storage_key;
|
||||
int zip_index = -1;
|
||||
uint64_t entry_size = 0;
|
||||
if (!find_zip_entry(zip, entry_name, &zip_index, &entry_size)) {
|
||||
set_error(error, "storage entry '" + entry_name + "' was not found");
|
||||
return false;
|
||||
}
|
||||
|
||||
auto it_storage_size = storage_nbytes.find(tensor_storage.storage_key);
|
||||
if (it_storage_size != storage_nbytes.end() && entry_size < it_storage_size->second) {
|
||||
set_error(error, "storage entry '" + entry_name + "' is smaller than pickle metadata");
|
||||
return false;
|
||||
}
|
||||
|
||||
uint64_t tensor_nbytes = tensor_storage.nbytes_to_read();
|
||||
if (tensor_storage.offset + tensor_nbytes > entry_size) {
|
||||
set_error(error, "tensor '" + tensor_storage.name + "' exceeds storage entry '" + entry_name + "'");
|
||||
return false;
|
||||
}
|
||||
|
||||
tensor_storage.index_in_zip = zip_index;
|
||||
tensor_storage.storage_key.clear();
|
||||
tensor_storages.push_back(tensor_storage);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool read_torch_zip_file(const std::string& file_path,
|
||||
std::vector<TensorStorage>& tensor_storages,
|
||||
std::string* error) {
|
||||
zip_t* zip = zip_open(file_path.c_str(), 0, 'r');
|
||||
if (zip == nullptr) {
|
||||
set_error(error, "failed to open '" + file_path + "'");
|
||||
return false;
|
||||
}
|
||||
|
||||
tensor_storages.clear();
|
||||
bool success = true;
|
||||
bool found_data_pkl = false;
|
||||
int n = (int)zip_entries_total(zip);
|
||||
for (int i = 0; i < n; ++i) {
|
||||
zip_entry_openbyindex(zip, i);
|
||||
std::string name = zip_entry_name(zip);
|
||||
size_t pos = name.find("data.pkl");
|
||||
if (pos != std::string::npos) {
|
||||
found_data_pkl = true;
|
||||
std::string dir = name.substr(0, pos);
|
||||
void* pkl_data = nullptr;
|
||||
size_t pkl_size = 0;
|
||||
zip_entry_read(zip, &pkl_data, &pkl_size);
|
||||
|
||||
if (pkl_data == nullptr || pkl_size == 0) {
|
||||
set_error(error, "failed to read '" + name + "' from '" + file_path + "'");
|
||||
success = false;
|
||||
} else if (!parse_zip_data_pkl((const uint8_t*)pkl_data, pkl_size, zip, dir, tensor_storages, error)) {
|
||||
success = false;
|
||||
}
|
||||
|
||||
free(pkl_data);
|
||||
}
|
||||
zip_entry_close(zip);
|
||||
|
||||
if (!success) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (success && !found_data_pkl) {
|
||||
set_error(error, "data.pkl was not found in '" + file_path + "'");
|
||||
success = false;
|
||||
}
|
||||
|
||||
zip_close(zip);
|
||||
return success;
|
||||
}
|
||||
14
src/model_io/torch_zip_io.h
Normal file
14
src/model_io/torch_zip_io.h
Normal file
@ -0,0 +1,14 @@
|
||||
#ifndef __SD_MODEL_IO_TORCH_ZIP_IO_H__
|
||||
#define __SD_MODEL_IO_TORCH_ZIP_IO_H__
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensor_storage.h"
|
||||
|
||||
bool is_torch_zip_file(const std::string& file_path);
|
||||
bool read_torch_zip_file(const std::string& file_path,
|
||||
std::vector<TensorStorage>& tensor_storages,
|
||||
std::string* error = nullptr);
|
||||
|
||||
#endif // __SD_MODEL_IO_TORCH_ZIP_IO_H__
|
||||
Loading…
x
Reference in New Issue
Block a user