mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-05-08 16:28:53 +00:00
Compare commits
3 Commits
7d33d4b2dd
...
66143340b6
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
66143340b6 | ||
|
|
7023fc4cfb | ||
|
|
e77e4c46bf |
@ -156,10 +156,12 @@ endif()
|
|||||||
|
|
||||||
set(SD_LIB stable-diffusion)
|
set(SD_LIB stable-diffusion)
|
||||||
|
|
||||||
file(GLOB SD_LIB_SOURCES
|
file(GLOB SD_LIB_SOURCES CONFIGURE_DEPENDS
|
||||||
"src/*.h"
|
"src/*.h"
|
||||||
"src/*.cpp"
|
"src/*.cpp"
|
||||||
"src/*.hpp"
|
"src/*.hpp"
|
||||||
|
"src/model_io/*.h"
|
||||||
|
"src/model_io/*.cpp"
|
||||||
"src/tokenizers/*.h"
|
"src/tokenizers/*.h"
|
||||||
"src/tokenizers/*.cpp"
|
"src/tokenizers/*.cpp"
|
||||||
"src/tokenizers/vocab/*.h"
|
"src/tokenizers/vocab/*.h"
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
for f in src/*.cpp src/*.h src/*.hpp src/tokenizers/*.h src/tokenizers/*.cpp src/tokenizers/vocab/*.h src/tokenizers/vocab/*.cpp \
|
for f in src/*.cpp src/*.h src/*.hpp src/tokenizers/*.h src/tokenizers/*.cpp src/tokenizers/vocab/*.h src/tokenizers/vocab/*.cpp \
|
||||||
examples/cli/*.cpp examples/cli/*.h examples/server/*.cpp \
|
src/model_io/*.h src/model_io/*.cpp examples/cli/*.cpp examples/cli/*.h examples/server/*.cpp \
|
||||||
examples/common/*.hpp examples/common/*.h examples/common/*.cpp; do
|
examples/common/*.hpp examples/common/*.h examples/common/*.cpp; do
|
||||||
[[ "$f" == vocab* ]] && continue
|
[[ "$f" == vocab* ]] && continue
|
||||||
echo "formatting '$f'"
|
echo "formatting '$f'"
|
||||||
|
|||||||
@ -977,7 +977,7 @@ static sd::Tensor<float> sample_dpmpp_2s_ancestral_flow(denoise_cb_t model,
|
|||||||
float eta = 1.0f) {
|
float eta = 1.0f) {
|
||||||
int steps = static_cast<int>(sigmas.size()) - 1;
|
int steps = static_cast<int>(sigmas.size()) - 1;
|
||||||
for (int i = 0; i < steps; i++) {
|
for (int i = 0; i < steps; i++) {
|
||||||
float sigma = sigmas[i];
|
float sigma = sigmas[i];
|
||||||
float sigma_to = sigmas[i + 1];
|
float sigma_to = sigmas[i + 1];
|
||||||
|
|
||||||
bool opt_first_step = (1.0 - sigma < 1e-6);
|
bool opt_first_step = (1.0 - sigma < 1e-6);
|
||||||
@ -1040,10 +1040,10 @@ static sd::Tensor<float> sample_dpmpp_2s_ancestral_flow(denoise_cb_t model,
|
|||||||
// and sigma_s = sigma_fn(s) = 1.0f / (exp(s) + 1.0f)
|
// and sigma_s = sigma_fn(s) = 1.0f / (exp(s) + 1.0f)
|
||||||
|
|
||||||
float exp_s = std::sqrt(((1 - sigma) / sigma) * ((1 - sigma_down) / sigma_down));
|
float exp_s = std::sqrt(((1 - sigma) / sigma) * ((1 - sigma_down) / sigma_down));
|
||||||
sigma_s = 1.0f / (exp_s + 1.0f);
|
sigma_s = 1.0f / (exp_s + 1.0f);
|
||||||
|
|
||||||
float sigma_s_i_ratio = sigma_s / sigma;
|
float sigma_s_i_ratio = sigma_s / sigma;
|
||||||
sd::Tensor<float> u = (x * sigma_s_i_ratio) + (denoised * (1.0f - sigma_s_i_ratio));
|
sd::Tensor<float> u = (x * sigma_s_i_ratio) + (denoised * (1.0f - sigma_s_i_ratio));
|
||||||
|
|
||||||
auto denoised2_opt = model(u, sigma_s, i + 1);
|
auto denoised2_opt = model(u, sigma_s, i + 1);
|
||||||
if (denoised2_opt.empty()) {
|
if (denoised2_opt.empty()) {
|
||||||
@ -1053,7 +1053,7 @@ static sd::Tensor<float> sample_dpmpp_2s_ancestral_flow(denoise_cb_t model,
|
|||||||
}
|
}
|
||||||
|
|
||||||
float sigma_down_i_ratio = sigma_down / sigma;
|
float sigma_down_i_ratio = sigma_down / sigma;
|
||||||
x = (x * sigma_down_i_ratio) + (D_i * (1.0f - sigma_down_i_ratio));
|
x = (x * sigma_down_i_ratio) + (D_i * (1.0f - sigma_down_i_ratio));
|
||||||
|
|
||||||
if (sigma_to > 0.0f && eta > 0.0f) {
|
if (sigma_to > 0.0f && eta > 0.0f) {
|
||||||
x = alpha_scale * x + sd::Tensor<float>::randn_like(x, rng) * sigma_up;
|
x = alpha_scale * x + sd::Tensor<float>::randn_like(x, rng) * sigma_up;
|
||||||
@ -1064,8 +1064,6 @@ static sd::Tensor<float> sample_dpmpp_2s_ancestral_flow(denoise_cb_t model,
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
static sd::Tensor<float> sample_dpmpp_2m(denoise_cb_t model,
|
static sd::Tensor<float> sample_dpmpp_2m(denoise_cb_t model,
|
||||||
sd::Tensor<float> x,
|
sd::Tensor<float> x,
|
||||||
const std::vector<float>& sigmas) {
|
const std::vector<float>& sigmas) {
|
||||||
@ -1137,7 +1135,8 @@ static sd::Tensor<float> sample_dpmpp_2m_v2(denoise_cb_t model,
|
|||||||
static sd::Tensor<float> sample_lcm(denoise_cb_t model,
|
static sd::Tensor<float> sample_lcm(denoise_cb_t model,
|
||||||
sd::Tensor<float> x,
|
sd::Tensor<float> x,
|
||||||
const std::vector<float>& sigmas,
|
const std::vector<float>& sigmas,
|
||||||
std::shared_ptr<RNG> rng) {
|
std::shared_ptr<RNG> rng,
|
||||||
|
bool is_flow_denoiser) {
|
||||||
int steps = static_cast<int>(sigmas.size()) - 1;
|
int steps = static_cast<int>(sigmas.size()) - 1;
|
||||||
for (int i = 0; i < steps; i++) {
|
for (int i = 0; i < steps; i++) {
|
||||||
auto denoised_opt = model(x, sigmas[i], i + 1);
|
auto denoised_opt = model(x, sigmas[i], i + 1);
|
||||||
@ -1146,6 +1145,9 @@ static sd::Tensor<float> sample_lcm(denoise_cb_t model,
|
|||||||
}
|
}
|
||||||
x = std::move(denoised_opt);
|
x = std::move(denoised_opt);
|
||||||
if (sigmas[i + 1] > 0) {
|
if (sigmas[i + 1] > 0) {
|
||||||
|
if (is_flow_denoiser) {
|
||||||
|
x *= (1 - sigmas[i + 1]);
|
||||||
|
}
|
||||||
x += sd::Tensor<float>::randn_like(x, rng) * sigmas[i + 1];
|
x += sd::Tensor<float>::randn_like(x, rng) * sigmas[i + 1];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1521,32 +1523,12 @@ static sd::Tensor<float> sample_ddim_trailing(denoise_cb_t model,
|
|||||||
const std::vector<float>& sigmas,
|
const std::vector<float>& sigmas,
|
||||||
std::shared_ptr<RNG> rng,
|
std::shared_ptr<RNG> rng,
|
||||||
float eta) {
|
float eta) {
|
||||||
float beta_start = 0.00085f;
|
|
||||||
float beta_end = 0.0120f;
|
|
||||||
std::vector<double> alphas_cumprod(TIMESTEPS);
|
|
||||||
std::vector<double> compvis_sigmas(TIMESTEPS);
|
|
||||||
for (int i = 0; i < TIMESTEPS; i++) {
|
|
||||||
alphas_cumprod[i] =
|
|
||||||
(i == 0 ? 1.0f : alphas_cumprod[i - 1]) *
|
|
||||||
(1.0f -
|
|
||||||
std::pow(sqrtf(beta_start) +
|
|
||||||
(sqrtf(beta_end) - sqrtf(beta_start)) *
|
|
||||||
((float)i / (TIMESTEPS - 1)),
|
|
||||||
2));
|
|
||||||
compvis_sigmas[i] =
|
|
||||||
std::sqrt((1 - alphas_cumprod[i]) / alphas_cumprod[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
int steps = static_cast<int>(sigmas.size()) - 1;
|
int steps = static_cast<int>(sigmas.size()) - 1;
|
||||||
for (int i = 0; i < steps; i++) {
|
for (int i = 0; i < steps; i++) {
|
||||||
int timestep = static_cast<int>(roundf(TIMESTEPS - i * ((float)TIMESTEPS / steps))) - 1;
|
|
||||||
int prev_timestep = timestep - TIMESTEPS / steps;
|
float sigma = sigmas[i];
|
||||||
float sigma = static_cast<float>(compvis_sigmas[timestep]);
|
float sigma_to = sigmas[i + 1];
|
||||||
if (i == 0) {
|
|
||||||
x *= std::sqrt(sigma * sigma + 1) / sigma;
|
|
||||||
} else {
|
|
||||||
x *= std::sqrt(sigma * sigma + 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto model_output_opt = model(x, sigma, i + 1);
|
auto model_output_opt = model(x, sigma, i + 1);
|
||||||
if (model_output_opt.empty()) {
|
if (model_output_opt.empty()) {
|
||||||
@ -1555,8 +1537,8 @@ static sd::Tensor<float> sample_ddim_trailing(denoise_cb_t model,
|
|||||||
sd::Tensor<float> model_output = std::move(model_output_opt);
|
sd::Tensor<float> model_output = std::move(model_output_opt);
|
||||||
model_output = (x - model_output) * (1.0f / sigma);
|
model_output = (x - model_output) * (1.0f / sigma);
|
||||||
|
|
||||||
float alpha_prod_t = static_cast<float>(alphas_cumprod[timestep]);
|
float alpha_prod_t = 1.0f / (sigma * sigma + 1.0f);
|
||||||
float alpha_prod_t_prev = static_cast<float>(prev_timestep >= 0 ? alphas_cumprod[prev_timestep] : alphas_cumprod[0]);
|
float alpha_prod_t_prev = 1.0f / (sigma_to * sigma_to + 1.0f);
|
||||||
float beta_prod_t = 1.0f - alpha_prod_t;
|
float beta_prod_t = 1.0f - alpha_prod_t;
|
||||||
|
|
||||||
sd::Tensor<float> pred_original_sample = ((x / std::sqrt(sigma * sigma + 1)) -
|
sd::Tensor<float> pred_original_sample = ((x / std::sqrt(sigma * sigma + 1)) -
|
||||||
@ -1568,12 +1550,13 @@ static sd::Tensor<float> sample_ddim_trailing(denoise_cb_t model,
|
|||||||
(1.0f - alpha_prod_t / alpha_prod_t_prev);
|
(1.0f - alpha_prod_t / alpha_prod_t_prev);
|
||||||
float std_dev_t = eta * std::sqrt(variance);
|
float std_dev_t = eta * std::sqrt(variance);
|
||||||
|
|
||||||
x = std::sqrt(alpha_prod_t_prev) * pred_original_sample +
|
x = pred_original_sample +
|
||||||
std::sqrt(1.0f - alpha_prod_t_prev - std::pow(std_dev_t, 2)) * model_output;
|
std::sqrt((1.0f - alpha_prod_t_prev - std::pow(std_dev_t, 2))/ alpha_prod_t_prev) * model_output;
|
||||||
|
|
||||||
if (eta > 0) {
|
if (eta > 0) {
|
||||||
x += std_dev_t * sd::Tensor<float>::randn_like(x, rng);
|
x+= std_dev_t / std::sqrt(alpha_prod_t_prev) * sd::Tensor<float>::randn_like(x, rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
@ -1599,19 +1582,25 @@ static sd::Tensor<float> sample_tcd(denoise_cb_t model,
|
|||||||
std::sqrt((1 - alphas_cumprod[i]) / alphas_cumprod[i]);
|
std::sqrt((1 - alphas_cumprod[i]) / alphas_cumprod[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
int original_steps = 50;
|
auto get_timestep_from_sigma = [&](float s) -> int {
|
||||||
int steps = static_cast<int>(sigmas.size()) - 1;
|
auto it = std::lower_bound(compvis_sigmas.begin(), compvis_sigmas.end(), s);
|
||||||
for (int i = 0; i < steps; i++) {
|
if (it == compvis_sigmas.begin()) return 0;
|
||||||
int timestep = TIMESTEPS - 1 - (TIMESTEPS / original_steps) * (int)floor(i * ((float)original_steps / steps));
|
if (it == compvis_sigmas.end()) return TIMESTEPS - 1;
|
||||||
int prev_timestep = i >= steps - 1 ? 0 : TIMESTEPS - 1 - (TIMESTEPS / original_steps) * (int)floor((i + 1) * ((float)original_steps / steps));
|
int idx_high = static_cast<int>(std::distance(compvis_sigmas.begin(), it));
|
||||||
int timestep_s = (int)floor((1 - eta) * prev_timestep);
|
int idx_low = idx_high - 1;
|
||||||
float sigma = static_cast<float>(compvis_sigmas[timestep]);
|
if (std::abs(compvis_sigmas[idx_high] - s) < std::abs(compvis_sigmas[idx_low] - s)) {
|
||||||
|
return idx_high;
|
||||||
if (i == 0) {
|
|
||||||
x *= std::sqrt(sigma * sigma + 1) / sigma;
|
|
||||||
} else {
|
|
||||||
x *= std::sqrt(sigma * sigma + 1);
|
|
||||||
}
|
}
|
||||||
|
return idx_low;
|
||||||
|
};
|
||||||
|
|
||||||
|
int steps = static_cast<int>(sigmas.size()) - 1;
|
||||||
|
for (int i = 0; i < steps; i++) {
|
||||||
|
|
||||||
|
float sigma_to = sigmas[i + 1];
|
||||||
|
int prev_timestep = get_timestep_from_sigma(sigma_to);
|
||||||
|
int timestep_s = (int)floor((1 - eta) * prev_timestep);
|
||||||
|
float sigma = sigmas[i];
|
||||||
|
|
||||||
auto model_output_opt = model(x, sigma, i + 1);
|
auto model_output_opt = model(x, sigma, i + 1);
|
||||||
if (model_output_opt.empty()) {
|
if (model_output_opt.empty()) {
|
||||||
@ -1620,9 +1609,9 @@ static sd::Tensor<float> sample_tcd(denoise_cb_t model,
|
|||||||
sd::Tensor<float> model_output = std::move(model_output_opt);
|
sd::Tensor<float> model_output = std::move(model_output_opt);
|
||||||
model_output = (x - model_output) * (1.0f / sigma);
|
model_output = (x - model_output) * (1.0f / sigma);
|
||||||
|
|
||||||
float alpha_prod_t = static_cast<float>(alphas_cumprod[timestep]);
|
float alpha_prod_t = 1.0f / (sigma * sigma + 1.0f);
|
||||||
float beta_prod_t = 1.0f - alpha_prod_t;
|
float beta_prod_t = 1.0f - alpha_prod_t;
|
||||||
float alpha_prod_t_prev = static_cast<float>(prev_timestep >= 0 ? alphas_cumprod[prev_timestep] : alphas_cumprod[0]);
|
float alpha_prod_t_prev = 1.0f / (sigma_to * sigma_to + 1.0f);
|
||||||
float alpha_prod_s = static_cast<float>(alphas_cumprod[timestep_s]);
|
float alpha_prod_s = static_cast<float>(alphas_cumprod[timestep_s]);
|
||||||
float beta_prod_s = 1.0f - alpha_prod_s;
|
float beta_prod_s = 1.0f - alpha_prod_s;
|
||||||
|
|
||||||
@ -1630,13 +1619,14 @@ static sd::Tensor<float> sample_tcd(denoise_cb_t model,
|
|||||||
std::sqrt(beta_prod_t) * model_output) *
|
std::sqrt(beta_prod_t) * model_output) *
|
||||||
(1.0f / std::sqrt(alpha_prod_t));
|
(1.0f / std::sqrt(alpha_prod_t));
|
||||||
|
|
||||||
x = std::sqrt(alpha_prod_s) * pred_original_sample +
|
x = std::sqrt(alpha_prod_s / alpha_prod_t_prev) * pred_original_sample +
|
||||||
std::sqrt(beta_prod_s) * model_output;
|
std::sqrt(beta_prod_s / alpha_prod_t_prev) * model_output;
|
||||||
|
|
||||||
if (eta > 0 && i != steps - 1) {
|
if (eta > 0 && sigma_to > 0.0f) {
|
||||||
x = std::sqrt(alpha_prod_t_prev / alpha_prod_s) * x +
|
x = std::sqrt(alpha_prod_t_prev / alpha_prod_s) * x +
|
||||||
std::sqrt(1.0f - alpha_prod_t_prev / alpha_prod_s) * sd::Tensor<float>::randn_like(x, rng);
|
std::sqrt(1.0f / alpha_prod_t_prev - 1.0f / alpha_prod_s) * sd::Tensor<float>::randn_like(x, rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
@ -1671,7 +1661,7 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
|
|||||||
case DPMPP2Mv2_SAMPLE_METHOD:
|
case DPMPP2Mv2_SAMPLE_METHOD:
|
||||||
return sample_dpmpp_2m_v2(model, std::move(x), sigmas);
|
return sample_dpmpp_2m_v2(model, std::move(x), sigmas);
|
||||||
case LCM_SAMPLE_METHOD:
|
case LCM_SAMPLE_METHOD:
|
||||||
return sample_lcm(model, std::move(x), sigmas, rng);
|
return sample_lcm(model, std::move(x), sigmas, rng, is_flow_denoiser);
|
||||||
case IPNDM_SAMPLE_METHOD:
|
case IPNDM_SAMPLE_METHOD:
|
||||||
return sample_ipndm(model, std::move(x), sigmas);
|
return sample_ipndm(model, std::move(x), sigmas);
|
||||||
case IPNDM_V_SAMPLE_METHOD:
|
case IPNDM_V_SAMPLE_METHOD:
|
||||||
|
|||||||
762
src/model.cpp
762
src/model.cpp
@ -12,8 +12,10 @@
|
|||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "gguf_reader.hpp"
|
|
||||||
#include "model.h"
|
#include "model.h"
|
||||||
|
#include "model_io/ckpt_io.h"
|
||||||
|
#include "model_io/gguf_io.h"
|
||||||
|
#include "model_io/safetensors_io.h"
|
||||||
#include "stable-diffusion.h"
|
#include "stable-diffusion.h"
|
||||||
#include "util.h"
|
#include "util.h"
|
||||||
|
|
||||||
@ -21,6 +23,7 @@
|
|||||||
#include "ggml-backend.h"
|
#include "ggml-backend.h"
|
||||||
#include "ggml-cpu.h"
|
#include "ggml-cpu.h"
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
|
#include "zip.h"
|
||||||
|
|
||||||
#include "name_conversion.h"
|
#include "name_conversion.h"
|
||||||
#include "stable-diffusion.h"
|
#include "stable-diffusion.h"
|
||||||
@ -37,40 +40,6 @@
|
|||||||
#include "ggml-opencl.h"
|
#include "ggml-opencl.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#define ST_HEADER_SIZE_LEN 8
|
|
||||||
|
|
||||||
uint64_t read_u64(uint8_t* buffer) {
|
|
||||||
// little endian
|
|
||||||
uint64_t value = 0;
|
|
||||||
value |= static_cast<int64_t>(buffer[7]) << 56;
|
|
||||||
value |= static_cast<int64_t>(buffer[6]) << 48;
|
|
||||||
value |= static_cast<int64_t>(buffer[5]) << 40;
|
|
||||||
value |= static_cast<int64_t>(buffer[4]) << 32;
|
|
||||||
value |= static_cast<int64_t>(buffer[3]) << 24;
|
|
||||||
value |= static_cast<int64_t>(buffer[2]) << 16;
|
|
||||||
value |= static_cast<int64_t>(buffer[1]) << 8;
|
|
||||||
value |= static_cast<int64_t>(buffer[0]);
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t read_int(uint8_t* buffer) {
|
|
||||||
// little endian
|
|
||||||
int value = 0;
|
|
||||||
value |= buffer[3] << 24;
|
|
||||||
value |= buffer[2] << 16;
|
|
||||||
value |= buffer[1] << 8;
|
|
||||||
value |= buffer[0];
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint16_t read_short(uint8_t* buffer) {
|
|
||||||
// little endian
|
|
||||||
uint16_t value = 0;
|
|
||||||
value |= buffer[1] << 8;
|
|
||||||
value |= buffer[0];
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
|
|
||||||
/*================================================= Preprocess ==================================================*/
|
/*================================================= Preprocess ==================================================*/
|
||||||
|
|
||||||
const char* unused_tensors[] = {
|
const char* unused_tensors[] = {
|
||||||
@ -250,79 +219,6 @@ void ModelLoader::add_tensor_storage(const TensorStorage& tensor_storage) {
|
|||||||
tensor_storage_map[tensor_storage.name] = tensor_storage;
|
tensor_storage_map[tensor_storage.name] = tensor_storage;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_zip_file(const std::string& file_path) {
|
|
||||||
zip_t* zip = zip_open(file_path.c_str(), 0, 'r');
|
|
||||||
if (zip == nullptr) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
zip_close(zip);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool is_gguf_file(const std::string& file_path) {
|
|
||||||
std::ifstream file(file_path, std::ios::binary);
|
|
||||||
if (!file.is_open()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
char magic[4];
|
|
||||||
|
|
||||||
file.read(magic, sizeof(magic));
|
|
||||||
if (!file) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
for (uint32_t i = 0; i < sizeof(magic); i++) {
|
|
||||||
if (magic[i] != GGUF_MAGIC[i]) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool is_safetensors_file(const std::string& file_path) {
|
|
||||||
std::ifstream file(file_path, std::ios::binary);
|
|
||||||
if (!file.is_open()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// get file size
|
|
||||||
file.seekg(0, file.end);
|
|
||||||
size_t file_size_ = file.tellg();
|
|
||||||
file.seekg(0, file.beg);
|
|
||||||
|
|
||||||
// read header size
|
|
||||||
if (file_size_ <= ST_HEADER_SIZE_LEN) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint8_t header_size_buf[ST_HEADER_SIZE_LEN];
|
|
||||||
file.read((char*)header_size_buf, ST_HEADER_SIZE_LEN);
|
|
||||||
if (!file) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t header_size_ = read_u64(header_size_buf);
|
|
||||||
if (header_size_ >= file_size_ || header_size_ <= 2) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// read header
|
|
||||||
std::vector<char> header_buf;
|
|
||||||
header_buf.resize(header_size_ + 1);
|
|
||||||
header_buf[header_size_] = '\0';
|
|
||||||
file.read(header_buf.data(), header_size_);
|
|
||||||
if (!file) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
try {
|
|
||||||
nlohmann::json header_ = nlohmann::json::parse(header_buf.data());
|
|
||||||
} catch (const std::exception&) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool ModelLoader::init_from_file(const std::string& file_path, const std::string& prefix) {
|
bool ModelLoader::init_from_file(const std::string& file_path, const std::string& prefix) {
|
||||||
if (is_directory(file_path)) {
|
if (is_directory(file_path)) {
|
||||||
LOG_INFO("load %s using diffusers format", file_path.c_str());
|
LOG_INFO("load %s using diffusers format", file_path.c_str());
|
||||||
@ -333,7 +229,7 @@ bool ModelLoader::init_from_file(const std::string& file_path, const std::string
|
|||||||
} else if (is_safetensors_file(file_path)) {
|
} else if (is_safetensors_file(file_path)) {
|
||||||
LOG_INFO("load %s using safetensors format", file_path.c_str());
|
LOG_INFO("load %s using safetensors format", file_path.c_str());
|
||||||
return init_from_safetensors_file(file_path, prefix);
|
return init_from_safetensors_file(file_path, prefix);
|
||||||
} else if (is_zip_file(file_path)) {
|
} else if (is_ckpt_file(file_path)) {
|
||||||
LOG_INFO("load %s using checkpoint format", file_path.c_str());
|
LOG_INFO("load %s using checkpoint format", file_path.c_str());
|
||||||
return init_from_ckpt_file(file_path, prefix);
|
return init_from_ckpt_file(file_path, prefix);
|
||||||
} else {
|
} else {
|
||||||
@ -375,242 +271,59 @@ bool ModelLoader::init_from_file_and_convert_name(const std::string& file_path,
|
|||||||
|
|
||||||
bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::string& prefix) {
|
bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::string& prefix) {
|
||||||
LOG_DEBUG("init from '%s'", file_path.c_str());
|
LOG_DEBUG("init from '%s'", file_path.c_str());
|
||||||
|
|
||||||
|
std::vector<TensorStorage> tensor_storages;
|
||||||
|
std::string error;
|
||||||
|
if (!read_gguf_file(file_path, tensor_storages, &error)) {
|
||||||
|
LOG_ERROR("%s", error.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
file_paths_.push_back(file_path);
|
file_paths_.push_back(file_path);
|
||||||
size_t file_index = file_paths_.size() - 1;
|
size_t file_index = file_paths_.size() - 1;
|
||||||
|
|
||||||
gguf_context* ctx_gguf_ = nullptr;
|
for (auto& tensor_storage : tensor_storages) {
|
||||||
ggml_context* ctx_meta_ = nullptr;
|
// LOG_DEBUG("%s", tensor_storage.name.c_str());
|
||||||
|
|
||||||
ctx_gguf_ = gguf_init_from_file(file_path.c_str(), {true, &ctx_meta_});
|
if (!starts_with(tensor_storage.name, prefix)) {
|
||||||
if (!ctx_gguf_) {
|
tensor_storage.name = prefix + tensor_storage.name;
|
||||||
LOG_ERROR("failed to open '%s' with gguf_init_from_file. Try to open it with GGUFReader.", file_path.c_str());
|
|
||||||
GGUFReader gguf_reader;
|
|
||||||
if (!gguf_reader.load(file_path)) {
|
|
||||||
LOG_ERROR("failed to open '%s' with GGUFReader.", file_path.c_str());
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
tensor_storage.file_index = file_index;
|
||||||
size_t data_offset = gguf_reader.data_offset();
|
|
||||||
for (const auto& gguf_tensor_info : gguf_reader.tensors()) {
|
|
||||||
std::string name = gguf_tensor_info.name;
|
|
||||||
if (!starts_with(name, prefix)) {
|
|
||||||
name = prefix + name;
|
|
||||||
}
|
|
||||||
|
|
||||||
TensorStorage tensor_storage(
|
|
||||||
name,
|
|
||||||
gguf_tensor_info.type,
|
|
||||||
gguf_tensor_info.shape.data(),
|
|
||||||
static_cast<int>(gguf_tensor_info.shape.size()),
|
|
||||||
file_index,
|
|
||||||
data_offset + gguf_tensor_info.offset);
|
|
||||||
|
|
||||||
// LOG_DEBUG("%s %s", name.c_str(), tensor_storage.to_string().c_str());
|
|
||||||
|
|
||||||
add_tensor_storage(tensor_storage);
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
int n_tensors = static_cast<int>(gguf_get_n_tensors(ctx_gguf_));
|
|
||||||
|
|
||||||
size_t total_size = 0;
|
|
||||||
size_t data_offset = gguf_get_data_offset(ctx_gguf_);
|
|
||||||
for (int i = 0; i < n_tensors; i++) {
|
|
||||||
std::string name = gguf_get_tensor_name(ctx_gguf_, i);
|
|
||||||
ggml_tensor* dummy = ggml_get_tensor(ctx_meta_, name.c_str());
|
|
||||||
size_t offset = data_offset + gguf_get_tensor_offset(ctx_gguf_, i);
|
|
||||||
|
|
||||||
// LOG_DEBUG("%s", name.c_str());
|
|
||||||
|
|
||||||
if (!starts_with(name, prefix)) {
|
|
||||||
name = prefix + name;
|
|
||||||
}
|
|
||||||
|
|
||||||
TensorStorage tensor_storage(name, dummy->type, dummy->ne, ggml_n_dims(dummy), file_index, offset);
|
|
||||||
|
|
||||||
GGML_ASSERT(ggml_nbytes(dummy) == tensor_storage.nbytes());
|
|
||||||
|
|
||||||
add_tensor_storage(tensor_storage);
|
add_tensor_storage(tensor_storage);
|
||||||
}
|
}
|
||||||
|
|
||||||
gguf_free(ctx_gguf_);
|
|
||||||
ggml_free(ctx_meta_);
|
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*================================================= SafeTensorsModelLoader ==================================================*/
|
/*================================================= SafeTensorsModelLoader ==================================================*/
|
||||||
|
|
||||||
ggml_type str_to_ggml_type(const std::string& dtype) {
|
|
||||||
ggml_type ttype = GGML_TYPE_COUNT;
|
|
||||||
if (dtype == "F16") {
|
|
||||||
ttype = GGML_TYPE_F16;
|
|
||||||
} else if (dtype == "BF16") {
|
|
||||||
ttype = GGML_TYPE_BF16;
|
|
||||||
} else if (dtype == "F32") {
|
|
||||||
ttype = GGML_TYPE_F32;
|
|
||||||
} else if (dtype == "F64") {
|
|
||||||
ttype = GGML_TYPE_F32;
|
|
||||||
} else if (dtype == "F8_E4M3") {
|
|
||||||
ttype = GGML_TYPE_F16;
|
|
||||||
} else if (dtype == "F8_E5M2") {
|
|
||||||
ttype = GGML_TYPE_F16;
|
|
||||||
} else if (dtype == "I64") {
|
|
||||||
ttype = GGML_TYPE_I32;
|
|
||||||
}
|
|
||||||
return ttype;
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://huggingface.co/docs/safetensors/index
|
|
||||||
bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const std::string& prefix) {
|
bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const std::string& prefix) {
|
||||||
LOG_DEBUG("init from '%s', prefix = '%s'", file_path.c_str(), prefix.c_str());
|
LOG_DEBUG("init from '%s', prefix = '%s'", file_path.c_str(), prefix.c_str());
|
||||||
|
|
||||||
|
std::vector<TensorStorage> tensor_storages;
|
||||||
|
std::string error;
|
||||||
|
if (!read_safetensors_file(file_path, tensor_storages, &error)) {
|
||||||
|
LOG_ERROR("%s", error.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
file_paths_.push_back(file_path);
|
file_paths_.push_back(file_path);
|
||||||
size_t file_index = file_paths_.size() - 1;
|
size_t file_index = file_paths_.size() - 1;
|
||||||
std::ifstream file(file_path, std::ios::binary);
|
|
||||||
if (!file.is_open()) {
|
|
||||||
LOG_ERROR("failed to open '%s'", file_path.c_str());
|
|
||||||
file_paths_.pop_back();
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// get file size
|
for (auto& tensor_storage : tensor_storages) {
|
||||||
file.seekg(0, file.end);
|
if (is_unused_tensor(tensor_storage.name)) {
|
||||||
size_t file_size_ = file.tellg();
|
|
||||||
file.seekg(0, file.beg);
|
|
||||||
|
|
||||||
// read header size
|
|
||||||
if (file_size_ <= ST_HEADER_SIZE_LEN) {
|
|
||||||
LOG_ERROR("invalid safetensor file '%s'", file_path.c_str());
|
|
||||||
file_paths_.pop_back();
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint8_t header_size_buf[ST_HEADER_SIZE_LEN];
|
|
||||||
file.read((char*)header_size_buf, ST_HEADER_SIZE_LEN);
|
|
||||||
if (!file) {
|
|
||||||
LOG_ERROR("read safetensors header size failed: '%s'", file_path.c_str());
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t header_size_ = read_u64(header_size_buf);
|
|
||||||
if (header_size_ >= file_size_) {
|
|
||||||
LOG_ERROR("invalid safetensor file '%s'", file_path.c_str());
|
|
||||||
file_paths_.pop_back();
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// read header
|
|
||||||
std::vector<char> header_buf;
|
|
||||||
header_buf.resize(header_size_ + 1);
|
|
||||||
header_buf[header_size_] = '\0';
|
|
||||||
file.read(header_buf.data(), header_size_);
|
|
||||||
if (!file) {
|
|
||||||
LOG_ERROR("read safetensors header failed: '%s'", file_path.c_str());
|
|
||||||
file_paths_.pop_back();
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
nlohmann::json header_;
|
|
||||||
try {
|
|
||||||
header_ = nlohmann::json::parse(header_buf.data());
|
|
||||||
} catch (const std::exception&) {
|
|
||||||
LOG_ERROR("parsing safetensors header failed", file_path.c_str());
|
|
||||||
file_paths_.pop_back();
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto& item : header_.items()) {
|
|
||||||
std::string name = item.key();
|
|
||||||
nlohmann::json tensor_info = item.value();
|
|
||||||
// LOG_DEBUG("%s %s\n", name.c_str(), tensor_info.dump().c_str());
|
|
||||||
|
|
||||||
if (name == "__metadata__") {
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (is_unused_tensor(name)) {
|
if (!starts_with(tensor_storage.name, prefix)) {
|
||||||
continue;
|
tensor_storage.name = prefix + tensor_storage.name;
|
||||||
}
|
|
||||||
|
|
||||||
std::string dtype = tensor_info["dtype"];
|
|
||||||
nlohmann::json shape = tensor_info["shape"];
|
|
||||||
|
|
||||||
if (dtype == "U8") {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t begin = tensor_info["data_offsets"][0].get<size_t>();
|
|
||||||
size_t end = tensor_info["data_offsets"][1].get<size_t>();
|
|
||||||
|
|
||||||
ggml_type type = str_to_ggml_type(dtype);
|
|
||||||
if (type == GGML_TYPE_COUNT) {
|
|
||||||
LOG_ERROR("unsupported dtype '%s' (tensor '%s')", dtype.c_str(), name.c_str());
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (shape.size() > SD_MAX_DIMS) {
|
|
||||||
LOG_ERROR("invalid tensor '%s'", name.c_str());
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
int n_dims = (int)shape.size();
|
|
||||||
int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
|
|
||||||
for (int i = 0; i < n_dims; i++) {
|
|
||||||
ne[i] = shape[i].get<int64_t>();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (n_dims == 5) {
|
|
||||||
n_dims = 4;
|
|
||||||
ne[0] = ne[0] * ne[1];
|
|
||||||
ne[1] = ne[2];
|
|
||||||
ne[2] = ne[3];
|
|
||||||
ne[3] = ne[4];
|
|
||||||
}
|
|
||||||
|
|
||||||
// ggml_n_dims returns 1 for scalars
|
|
||||||
if (n_dims == 0) {
|
|
||||||
n_dims = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!starts_with(name, prefix)) {
|
|
||||||
name = prefix + name;
|
|
||||||
}
|
|
||||||
|
|
||||||
TensorStorage tensor_storage(name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
|
|
||||||
tensor_storage.reverse_ne();
|
|
||||||
|
|
||||||
size_t tensor_data_size = end - begin;
|
|
||||||
|
|
||||||
bool tensor_size_ok;
|
|
||||||
if (dtype == "F8_E4M3") {
|
|
||||||
tensor_storage.is_f8_e4m3 = true;
|
|
||||||
// f8 -> f16
|
|
||||||
tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size * 2);
|
|
||||||
} else if (dtype == "F8_E5M2") {
|
|
||||||
tensor_storage.is_f8_e5m2 = true;
|
|
||||||
// f8 -> f16
|
|
||||||
tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size * 2);
|
|
||||||
} else if (dtype == "F64") {
|
|
||||||
tensor_storage.is_f64 = true;
|
|
||||||
// f64 -> f32
|
|
||||||
tensor_size_ok = (tensor_storage.nbytes() * 2 == tensor_data_size);
|
|
||||||
} else if (dtype == "I64") {
|
|
||||||
tensor_storage.is_i64 = true;
|
|
||||||
// i64 -> i32
|
|
||||||
tensor_size_ok = (tensor_storage.nbytes() * 2 == tensor_data_size);
|
|
||||||
} else {
|
|
||||||
tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size);
|
|
||||||
}
|
|
||||||
if (!tensor_size_ok) {
|
|
||||||
LOG_ERROR("size mismatch for tensor '%s' (%s)\n", name.c_str(), dtype.c_str());
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
tensor_storage.file_index = file_index;
|
||||||
|
|
||||||
add_tensor_storage(tensor_storage);
|
add_tensor_storage(tensor_storage);
|
||||||
|
|
||||||
// LOG_DEBUG("%s %s", tensor_storage.to_string().c_str(), dtype.c_str());
|
// LOG_DEBUG("%s", tensor_storage.to_string().c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
@ -644,362 +357,30 @@ bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const s
|
|||||||
|
|
||||||
/*================================================= CkptModelLoader ==================================================*/
|
/*================================================= CkptModelLoader ==================================================*/
|
||||||
|
|
||||||
// $ python -m pickletools sd-v1-4/archive/data.pkl | head -n 100
|
bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::string& prefix) {
|
||||||
// 0: \x80 PROTO 2
|
LOG_DEBUG("init from '%s'", file_path.c_str());
|
||||||
// 2: } EMPTY_DICT
|
|
||||||
// 3: q BINPUT 0
|
|
||||||
// 5: ( MARK
|
|
||||||
// 6: X BINUNICODE 'epoch'
|
|
||||||
// 16: q BINPUT 1
|
|
||||||
// 18: K BININT1 6
|
|
||||||
// 20: X BINUNICODE 'global_step'
|
|
||||||
// 36: q BINPUT 2
|
|
||||||
// 38: J BININT 470000
|
|
||||||
// 43: X BINUNICODE 'pytorch-lightning_version'
|
|
||||||
// 73: q BINPUT 3
|
|
||||||
// 75: X BINUNICODE '1.4.2'
|
|
||||||
// 85: q BINPUT 4
|
|
||||||
// 87: X BINUNICODE 'state_dict'
|
|
||||||
// 102: q BINPUT 5
|
|
||||||
// 104: } EMPTY_DICT
|
|
||||||
// 105: q BINPUT 6
|
|
||||||
// 107: ( MARK
|
|
||||||
// 108: X BINUNICODE 'betas'
|
|
||||||
// 118: q BINPUT 7
|
|
||||||
// 120: c GLOBAL 'torch._utils _rebuild_tensor_v2'
|
|
||||||
// 153: q BINPUT 8
|
|
||||||
// 155: ( MARK
|
|
||||||
// 156: ( MARK
|
|
||||||
// 157: X BINUNICODE 'storage'
|
|
||||||
// 169: q BINPUT 9
|
|
||||||
// 171: c GLOBAL 'torch FloatStorage'
|
|
||||||
// 191: q BINPUT 10
|
|
||||||
// 193: X BINUNICODE '0'
|
|
||||||
// 199: q BINPUT 11
|
|
||||||
// 201: X BINUNICODE 'cpu'
|
|
||||||
// 209: q BINPUT 12
|
|
||||||
// 211: M BININT2 1000
|
|
||||||
// 214: t TUPLE (MARK at 156)
|
|
||||||
// 215: q BINPUT 13
|
|
||||||
// 217: Q BINPERSID
|
|
||||||
// 218: K BININT1 0
|
|
||||||
// 220: M BININT2 1000
|
|
||||||
// ...............................
|
|
||||||
// 3201: q BINPUT 250
|
|
||||||
// 3203: R REDUCE
|
|
||||||
// 3204: q BINPUT 251
|
|
||||||
// 3206: X BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.weight'
|
|
||||||
// 3264: q BINPUT 252
|
|
||||||
// 3266: h BINGET 8
|
|
||||||
// 3268: ( MARK
|
|
||||||
// 3269: ( MARK
|
|
||||||
// 3270: h BINGET 9
|
|
||||||
// 3272: h BINGET 10
|
|
||||||
// 3274: X BINUNICODE '30'
|
|
||||||
// 3281: q BINPUT 253
|
|
||||||
// 3283: h BINGET 12
|
|
||||||
// 3285: J BININT 102400
|
|
||||||
// 3290: t TUPLE (MARK at 3269)
|
|
||||||
// 3291: q BINPUT 254
|
|
||||||
// 3293: Q BINPERSID
|
|
||||||
// 3294: K BININT1 0
|
|
||||||
// 3296: ( MARK
|
|
||||||
// 3297: M BININT2 320
|
|
||||||
// 3300: M BININT2 320
|
|
||||||
// 3303: K BININT1 1
|
|
||||||
// 3305: K BININT1 1
|
|
||||||
// 3307: t TUPLE (MARK at 3296)
|
|
||||||
// 3308: q BINPUT 255
|
|
||||||
// 3310: ( MARK
|
|
||||||
// 3311: M BININT2 320
|
|
||||||
// 3314: K BININT1 1
|
|
||||||
// 3316: K BININT1 1
|
|
||||||
// 3318: K BININT1 1
|
|
||||||
// 3320: t TUPLE (MARK at 3310)
|
|
||||||
// 3321: r LONG_BINPUT 256
|
|
||||||
// 3326: \x89 NEWFALSE
|
|
||||||
// 3327: h BINGET 16
|
|
||||||
// 3329: ) EMPTY_TUPLE
|
|
||||||
// 3330: R REDUCE
|
|
||||||
// 3331: r LONG_BINPUT 257
|
|
||||||
// 3336: t TUPLE (MARK at 3268)
|
|
||||||
// 3337: r LONG_BINPUT 258
|
|
||||||
// 3342: R REDUCE
|
|
||||||
// 3343: r LONG_BINPUT 259
|
|
||||||
// 3348: X BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.bias'
|
|
||||||
// 3404: r LONG_BINPUT 260
|
|
||||||
// 3409: h BINGET 8
|
|
||||||
// 3411: ( MARK
|
|
||||||
// 3412: ( MARK
|
|
||||||
// 3413: h BINGET 9
|
|
||||||
// 3415: h BINGET 10
|
|
||||||
// 3417: X BINUNICODE '31'
|
|
||||||
|
|
||||||
struct PickleTensorReader {
|
std::vector<TensorStorage> tensor_storages;
|
||||||
enum ReadPhase {
|
std::string error;
|
||||||
READ_NAME,
|
if (!read_ckpt_file(file_path, tensor_storages, &error)) {
|
||||||
READ_DATA,
|
LOG_ERROR("%s", error.c_str());
|
||||||
CHECK_SIZE,
|
|
||||||
READ_DIMENS
|
|
||||||
};
|
|
||||||
ReadPhase phase = READ_NAME;
|
|
||||||
size_t entry_size = 0;
|
|
||||||
int32_t nelements = 0;
|
|
||||||
|
|
||||||
TensorStorage tensor_storage;
|
|
||||||
|
|
||||||
static ggml_type global_type; // all pickle_tensors data type
|
|
||||||
static bool read_global_type;
|
|
||||||
|
|
||||||
bool read_int_value(uint32_t value) {
|
|
||||||
if (phase == CHECK_SIZE) {
|
|
||||||
if (entry_size == value * ggml_type_size(tensor_storage.type)) {
|
|
||||||
nelements = value;
|
|
||||||
phase = READ_DIMENS;
|
|
||||||
return true;
|
|
||||||
} else {
|
|
||||||
phase = READ_NAME;
|
|
||||||
}
|
|
||||||
} else if (phase == READ_DIMENS) {
|
|
||||||
if (tensor_storage.n_dims + 1 > SD_MAX_DIMS) { // too many dimens
|
|
||||||
phase = READ_NAME;
|
|
||||||
tensor_storage.n_dims = 0;
|
|
||||||
}
|
|
||||||
if (nelements % value == 0) {
|
|
||||||
tensor_storage.ne[tensor_storage.n_dims] = value;
|
|
||||||
tensor_storage.n_dims++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void read_global(const std::string& str) {
|
|
||||||
if (str == "FloatStorage") {
|
|
||||||
if (read_global_type) {
|
|
||||||
global_type = GGML_TYPE_F32;
|
|
||||||
read_global_type = false;
|
|
||||||
}
|
|
||||||
tensor_storage.type = GGML_TYPE_F32;
|
|
||||||
} else if (str == "HalfStorage") {
|
|
||||||
if (read_global_type) {
|
|
||||||
global_type = GGML_TYPE_F16;
|
|
||||||
read_global_type = false;
|
|
||||||
}
|
|
||||||
tensor_storage.type = GGML_TYPE_F16;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void read_string(const std::string& str, zip_t* zip, std::string dir) {
|
|
||||||
if (str == "storage") {
|
|
||||||
read_global_type = true;
|
|
||||||
} else if (str != "state_dict") {
|
|
||||||
if (phase == READ_DATA) {
|
|
||||||
std::string entry_name = dir + "data/" + std::string(str);
|
|
||||||
|
|
||||||
size_t i, n = zip_entries_total(zip);
|
|
||||||
for (i = 0; i < n; ++i) {
|
|
||||||
zip_entry_openbyindex(zip, i);
|
|
||||||
{
|
|
||||||
std::string name = zip_entry_name(zip);
|
|
||||||
if (name == entry_name) {
|
|
||||||
tensor_storage.index_in_zip = (int)i;
|
|
||||||
entry_size = zip_entry_size(zip);
|
|
||||||
zip_entry_close(zip);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
zip_entry_close(zip);
|
|
||||||
}
|
|
||||||
|
|
||||||
phase = entry_size > 0 ? CHECK_SIZE : READ_NAME;
|
|
||||||
}
|
|
||||||
if (!read_global_type && phase == READ_NAME) {
|
|
||||||
tensor_storage.name = str;
|
|
||||||
phase = READ_DATA;
|
|
||||||
tensor_storage.type = global_type;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
ggml_type PickleTensorReader::global_type = GGML_TYPE_F32; // all pickle_tensors data type
|
|
||||||
bool PickleTensorReader::read_global_type = false;
|
|
||||||
|
|
||||||
int find_char(uint8_t* buffer, int len, char c) {
|
|
||||||
for (int pos = 0; pos < len; pos++) {
|
|
||||||
if (buffer[pos] == c) {
|
|
||||||
return pos;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
#define MAX_STRING_BUFFER 512
|
|
||||||
|
|
||||||
bool ModelLoader::parse_data_pkl(uint8_t* buffer,
|
|
||||||
size_t buffer_size,
|
|
||||||
zip_t* zip,
|
|
||||||
std::string dir,
|
|
||||||
size_t file_index,
|
|
||||||
const std::string prefix) {
|
|
||||||
uint8_t* buffer_end = buffer + buffer_size;
|
|
||||||
if (buffer[0] == 0x80) { // proto
|
|
||||||
if (buffer[1] != 2) {
|
|
||||||
LOG_ERROR("Unsupported protocol\n");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
buffer += 2; // 0x80 and version
|
|
||||||
char string_buffer[MAX_STRING_BUFFER];
|
|
||||||
bool finish = false;
|
|
||||||
PickleTensorReader reader;
|
|
||||||
// read pickle binary file
|
|
||||||
while (!finish && buffer < buffer_end) {
|
|
||||||
uint8_t opcode = *buffer;
|
|
||||||
buffer++;
|
|
||||||
// https://github.com/python/cpython/blob/3.7/Lib/pickletools.py#L1048
|
|
||||||
// https://github.com/python/cpython/blob/main/Lib/pickle.py#L105
|
|
||||||
switch (opcode) {
|
|
||||||
case '}': // EMPTY_DICT = b'}' # push empty dict
|
|
||||||
break;
|
|
||||||
case ']': // EMPTY_LIST = b']' # push empty list
|
|
||||||
break;
|
|
||||||
// skip unused sections
|
|
||||||
case 'h': // BINGET = b'h' # " " " " " " ; " " 1-byte arg
|
|
||||||
case 'q': // BINPUT = b'q' # " " " " " ; " " 1-byte arg
|
|
||||||
case 'Q': // BINPERSID = b'Q' # " " " ; " " " " stack
|
|
||||||
buffer++;
|
|
||||||
break;
|
|
||||||
case 'r': // LONG_BINPUT = b'r' # " " " " " ; " " 4-byte arg
|
|
||||||
buffer += 4;
|
|
||||||
break;
|
|
||||||
case 0x95: // FRAME = b'\x95' # indicate the beginning of a new frame
|
|
||||||
buffer += 8;
|
|
||||||
break;
|
|
||||||
case 0x94: // MEMOIZE = b'\x94' # store top of the stack in memo
|
|
||||||
break;
|
|
||||||
case '(': // MARK = b'(' # push special markobject on stack
|
|
||||||
break;
|
|
||||||
case 'K': // BININT1 = b'K' # push 1-byte unsigned int
|
|
||||||
{
|
|
||||||
uint8_t value = *buffer;
|
|
||||||
if (reader.read_int_value(value)) {
|
|
||||||
buffer++;
|
|
||||||
}
|
|
||||||
buffer++;
|
|
||||||
} break;
|
|
||||||
case 'M': // BININT2 = b'M' # push 2-byte unsigned int
|
|
||||||
{
|
|
||||||
uint16_t value = read_short(buffer);
|
|
||||||
if (reader.read_int_value(value)) {
|
|
||||||
buffer++;
|
|
||||||
}
|
|
||||||
buffer += 2;
|
|
||||||
} break;
|
|
||||||
case 'J': // BININT = b'J' # push four-byte signed int
|
|
||||||
{
|
|
||||||
const int32_t value = read_int(buffer);
|
|
||||||
if (reader.read_int_value(value)) {
|
|
||||||
buffer++; // skip tuple after read num_elements
|
|
||||||
}
|
|
||||||
buffer += 4;
|
|
||||||
} break;
|
|
||||||
case 'X': // BINUNICODE = b'X' # " " " ; counted UTF-8 string argument
|
|
||||||
{
|
|
||||||
const int32_t len = read_int(buffer);
|
|
||||||
buffer += 4;
|
|
||||||
memset(string_buffer, 0, MAX_STRING_BUFFER);
|
|
||||||
if (len > MAX_STRING_BUFFER) {
|
|
||||||
LOG_WARN("tensor name very large");
|
|
||||||
}
|
|
||||||
memcpy(string_buffer, buffer, len < MAX_STRING_BUFFER ? len : (MAX_STRING_BUFFER - 1));
|
|
||||||
buffer += len;
|
|
||||||
reader.read_string(string_buffer, zip, dir);
|
|
||||||
} break;
|
|
||||||
case 0x8C: // SHORT_BINUNICODE = b'\x8c' # push short string; UTF-8 length < 256 bytes
|
|
||||||
{
|
|
||||||
const int8_t len = *buffer;
|
|
||||||
buffer++;
|
|
||||||
memset(string_buffer, 0, MAX_STRING_BUFFER);
|
|
||||||
memcpy(string_buffer, buffer, len);
|
|
||||||
buffer += len;
|
|
||||||
// printf("String: '%s'\n", string_buffer);
|
|
||||||
} break;
|
|
||||||
case 'c': // GLOBAL = b'c' # push self.find_class(modname, name); 2 string args
|
|
||||||
{
|
|
||||||
int len = find_char(buffer, MAX_STRING_BUFFER, '\n');
|
|
||||||
|
|
||||||
buffer += len + 1;
|
|
||||||
len = find_char(buffer, MAX_STRING_BUFFER, '\n');
|
|
||||||
|
|
||||||
memset(string_buffer, 0, MAX_STRING_BUFFER);
|
|
||||||
memcpy(string_buffer, buffer, len);
|
|
||||||
buffer += len + 1;
|
|
||||||
reader.read_global(string_buffer);
|
|
||||||
} break;
|
|
||||||
case 0x86: // TUPLE2 = b'\x86' # build 2-tuple from two topmost stack items
|
|
||||||
case 0x85: // TUPLE1 = b'\x85' # build 1-tuple from stack top
|
|
||||||
case 't': // TUPLE = b't' # build tuple from topmost stack items
|
|
||||||
if (reader.phase == PickleTensorReader::READ_DIMENS) {
|
|
||||||
reader.tensor_storage.reverse_ne();
|
|
||||||
reader.tensor_storage.file_index = file_index;
|
|
||||||
// if(strcmp(prefix.c_str(), "scarlett") == 0)
|
|
||||||
// printf(" ZIP got tensor %s \n ", reader.tensor_storage.name.c_str());
|
|
||||||
std::string name = reader.tensor_storage.name;
|
|
||||||
if (!starts_with(name, prefix)) {
|
|
||||||
name = prefix + name;
|
|
||||||
}
|
|
||||||
reader.tensor_storage.name = name;
|
|
||||||
add_tensor_storage(reader.tensor_storage);
|
|
||||||
|
|
||||||
// LOG_DEBUG("%s", reader.tensor_storage.name.c_str());
|
|
||||||
// reset
|
|
||||||
reader = PickleTensorReader();
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case '.': // STOP = b'.' # every pickle ends with STOP
|
|
||||||
finish = true;
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::string& prefix) {
|
|
||||||
LOG_DEBUG("init from '%s'", file_path.c_str());
|
|
||||||
file_paths_.push_back(file_path);
|
file_paths_.push_back(file_path);
|
||||||
size_t file_index = file_paths_.size() - 1;
|
size_t file_index = file_paths_.size() - 1;
|
||||||
|
|
||||||
zip_t* zip = zip_open(file_path.c_str(), 0, 'r');
|
for (auto& tensor_storage : tensor_storages) {
|
||||||
if (zip == nullptr) {
|
if (!starts_with(tensor_storage.name, prefix)) {
|
||||||
LOG_ERROR("failed to open '%s'", file_path.c_str());
|
tensor_storage.name = prefix + tensor_storage.name;
|
||||||
return false;
|
|
||||||
}
|
|
||||||
int n = (int)zip_entries_total(zip);
|
|
||||||
for (int i = 0; i < n; ++i) {
|
|
||||||
zip_entry_openbyindex(zip, i);
|
|
||||||
{
|
|
||||||
std::string name = zip_entry_name(zip);
|
|
||||||
size_t pos = name.find("data.pkl");
|
|
||||||
if (pos != std::string::npos) {
|
|
||||||
std::string dir = name.substr(0, pos);
|
|
||||||
printf("ZIP %d, name = %s, dir = %s \n", i, name.c_str(), dir.c_str());
|
|
||||||
void* pkl_data = nullptr;
|
|
||||||
size_t pkl_size;
|
|
||||||
zip_entry_read(zip, &pkl_data, &pkl_size);
|
|
||||||
|
|
||||||
// LOG_DEBUG("%lld", pkl_size);
|
|
||||||
|
|
||||||
parse_data_pkl((uint8_t*)pkl_data, pkl_size, zip, dir, file_index, prefix);
|
|
||||||
|
|
||||||
free(pkl_data);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
zip_entry_close(zip);
|
tensor_storage.file_index = file_index;
|
||||||
|
|
||||||
|
add_tensor_storage(tensor_storage);
|
||||||
|
|
||||||
|
// LOG_DEBUG("%s", tensor_storage.to_string().c_str());
|
||||||
}
|
}
|
||||||
zip_close(zip);
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1703,19 +1084,8 @@ bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules_str) {
|
bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules_str) {
|
||||||
auto backend = ggml_backend_cpu_init();
|
|
||||||
size_t mem_size = 1 * 1024 * 1024; // for padding
|
|
||||||
mem_size += tensor_storage_map.size() * ggml_tensor_overhead();
|
|
||||||
mem_size += get_params_mem_size(backend, type);
|
|
||||||
LOG_INFO("model tensors mem size: %.2fMB", mem_size / 1024.f / 1024.f);
|
|
||||||
ggml_context* ggml_ctx = ggml_init({mem_size, nullptr, false});
|
|
||||||
|
|
||||||
gguf_context* gguf_ctx = gguf_init_empty();
|
|
||||||
|
|
||||||
auto tensor_type_rules = parse_tensor_type_rules(tensor_type_rules_str);
|
auto tensor_type_rules = parse_tensor_type_rules(tensor_type_rules_str);
|
||||||
|
auto get_tensor_type = [&](const TensorStorage& tensor_storage) -> ggml_type {
|
||||||
std::mutex tensor_mutex;
|
|
||||||
auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
|
|
||||||
const std::string& name = tensor_storage.name;
|
const std::string& name = tensor_storage.name;
|
||||||
ggml_type tensor_type = tensor_storage.type;
|
ggml_type tensor_type = tensor_storage.type;
|
||||||
ggml_type dst_type = type;
|
ggml_type dst_type = type;
|
||||||
@ -1732,6 +1102,28 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type
|
|||||||
tensor_type = dst_type;
|
tensor_type = dst_type;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return tensor_type;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto backend = ggml_backend_cpu_init();
|
||||||
|
size_t mem_size = 1 * 1024 * 1024; // for padding
|
||||||
|
mem_size += tensor_storage_map.size() * ggml_tensor_overhead();
|
||||||
|
mem_size += get_params_mem_size(backend, type);
|
||||||
|
LOG_INFO("model tensors mem size: %.2fMB", mem_size / 1024.f / 1024.f);
|
||||||
|
ggml_context* ggml_ctx = ggml_init({mem_size, nullptr, false});
|
||||||
|
|
||||||
|
if (ggml_ctx == nullptr) {
|
||||||
|
LOG_ERROR("ggml_init failed for GGUF writer");
|
||||||
|
ggml_backend_free(backend);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<ggml_tensor*> tensors;
|
||||||
|
std::mutex tensor_mutex;
|
||||||
|
auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
|
||||||
|
const std::string& name = tensor_storage.name;
|
||||||
|
ggml_type tensor_type = get_tensor_type(tensor_storage);
|
||||||
|
|
||||||
std::lock_guard<std::mutex> lock(tensor_mutex);
|
std::lock_guard<std::mutex> lock(tensor_mutex);
|
||||||
ggml_tensor* tensor = ggml_new_tensor(ggml_ctx, tensor_type, tensor_storage.n_dims, tensor_storage.ne);
|
ggml_tensor* tensor = ggml_new_tensor(ggml_ctx, tensor_type, tensor_storage.n_dims, tensor_storage.ne);
|
||||||
if (tensor == nullptr) {
|
if (tensor == nullptr) {
|
||||||
@ -1754,8 +1146,7 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type
|
|||||||
}
|
}
|
||||||
|
|
||||||
*dst_tensor = tensor;
|
*dst_tensor = tensor;
|
||||||
|
tensors.push_back(tensor);
|
||||||
gguf_add_tensor(gguf_ctx, tensor);
|
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
@ -1763,12 +1154,17 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type
|
|||||||
bool success = load_tensors(on_new_tensor_cb);
|
bool success = load_tensors(on_new_tensor_cb);
|
||||||
ggml_backend_free(backend);
|
ggml_backend_free(backend);
|
||||||
LOG_INFO("load tensors done");
|
LOG_INFO("load tensors done");
|
||||||
LOG_INFO("trying to save tensors to %s", file_path.c_str());
|
|
||||||
|
std::string error;
|
||||||
if (success) {
|
if (success) {
|
||||||
gguf_write_to_file(gguf_ctx, file_path.c_str(), false);
|
success = write_gguf_file(file_path, tensors, &error);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!success && !error.empty()) {
|
||||||
|
LOG_ERROR("%s", error.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
ggml_free(ggml_ctx);
|
ggml_free(ggml_ctx);
|
||||||
gguf_free(gguf_ctx);
|
|
||||||
return success;
|
return success;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
125
src/model.h
125
src/model.h
@ -5,20 +5,13 @@
|
|||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <sstream>
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <tuple>
|
|
||||||
#include <utility>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "ggml-backend.h"
|
#include "ggml-backend.h"
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
#include "gguf.h"
|
#include "model_io/tensor_storage.h"
|
||||||
#include "json.hpp"
|
|
||||||
#include "ordered_map.hpp"
|
#include "ordered_map.hpp"
|
||||||
#include "zip.h"
|
|
||||||
|
|
||||||
#define SD_MAX_DIMS 5
|
|
||||||
|
|
||||||
enum SDVersion {
|
enum SDVersion {
|
||||||
VERSION_SD1,
|
VERSION_SD1,
|
||||||
@ -195,115 +188,6 @@ enum PMVersion {
|
|||||||
PM_VERSION_2,
|
PM_VERSION_2,
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TensorStorage {
|
|
||||||
std::string name;
|
|
||||||
ggml_type type = GGML_TYPE_F32;
|
|
||||||
ggml_type expected_type = GGML_TYPE_COUNT;
|
|
||||||
bool is_f8_e4m3 = false;
|
|
||||||
bool is_f8_e5m2 = false;
|
|
||||||
bool is_f64 = false;
|
|
||||||
bool is_i64 = false;
|
|
||||||
int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
|
|
||||||
int n_dims = 0;
|
|
||||||
|
|
||||||
size_t file_index = 0;
|
|
||||||
int index_in_zip = -1; // >= means stored in a zip file
|
|
||||||
uint64_t offset = 0; // offset in file
|
|
||||||
|
|
||||||
TensorStorage() = default;
|
|
||||||
|
|
||||||
TensorStorage(std::string name, ggml_type type, const int64_t* ne, int n_dims, size_t file_index, size_t offset = 0)
|
|
||||||
: name(std::move(name)), type(type), n_dims(n_dims), file_index(file_index), offset(offset) {
|
|
||||||
for (int i = 0; i < n_dims; i++) {
|
|
||||||
this->ne[i] = ne[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t nelements() const {
|
|
||||||
int64_t n = 1;
|
|
||||||
for (int i = 0; i < SD_MAX_DIMS; i++) {
|
|
||||||
n *= ne[i];
|
|
||||||
}
|
|
||||||
return n;
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t nbytes() const {
|
|
||||||
return nelements() * ggml_type_size(type) / ggml_blck_size(type);
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t nbytes_to_read() const {
|
|
||||||
if (is_f8_e4m3 || is_f8_e5m2) {
|
|
||||||
return nbytes() / 2;
|
|
||||||
} else if (is_f64 || is_i64) {
|
|
||||||
return nbytes() * 2;
|
|
||||||
} else {
|
|
||||||
return nbytes();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void unsqueeze() {
|
|
||||||
if (n_dims == 2) {
|
|
||||||
n_dims = 4;
|
|
||||||
ne[3] = ne[1];
|
|
||||||
ne[2] = ne[0];
|
|
||||||
ne[1] = 1;
|
|
||||||
ne[0] = 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<TensorStorage> chunk(size_t n) {
|
|
||||||
std::vector<TensorStorage> chunks;
|
|
||||||
uint64_t chunk_size = nbytes_to_read() / n;
|
|
||||||
// printf("%d/%d\n", chunk_size, nbytes_to_read());
|
|
||||||
reverse_ne();
|
|
||||||
for (size_t i = 0; i < n; i++) {
|
|
||||||
TensorStorage chunk_i = *this;
|
|
||||||
chunk_i.ne[0] = ne[0] / n;
|
|
||||||
chunk_i.offset = offset + i * chunk_size;
|
|
||||||
chunk_i.reverse_ne();
|
|
||||||
chunks.push_back(chunk_i);
|
|
||||||
}
|
|
||||||
reverse_ne();
|
|
||||||
return chunks;
|
|
||||||
}
|
|
||||||
|
|
||||||
void reverse_ne() {
|
|
||||||
int64_t new_ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
|
|
||||||
for (int i = 0; i < n_dims; i++) {
|
|
||||||
new_ne[i] = ne[n_dims - 1 - i];
|
|
||||||
}
|
|
||||||
for (int i = 0; i < n_dims; i++) {
|
|
||||||
ne[i] = new_ne[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string to_string() const {
|
|
||||||
std::stringstream ss;
|
|
||||||
const char* type_name = ggml_type_name(type);
|
|
||||||
if (is_f8_e4m3) {
|
|
||||||
type_name = "f8_e4m3";
|
|
||||||
} else if (is_f8_e5m2) {
|
|
||||||
type_name = "f8_e5m2";
|
|
||||||
} else if (is_f64) {
|
|
||||||
type_name = "f64";
|
|
||||||
} else if (is_i64) {
|
|
||||||
type_name = "i64";
|
|
||||||
}
|
|
||||||
ss << name << " | " << type_name << " | ";
|
|
||||||
ss << n_dims << " [";
|
|
||||||
for (int i = 0; i < SD_MAX_DIMS; i++) {
|
|
||||||
ss << ne[i];
|
|
||||||
if (i != SD_MAX_DIMS - 1) {
|
|
||||||
ss << ", ";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ss << "]";
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
typedef std::function<bool(const TensorStorage&, ggml_tensor**)> on_new_tensor_cb_t;
|
|
||||||
|
|
||||||
typedef OrderedMap<std::string, TensorStorage> String2TensorStorage;
|
typedef OrderedMap<std::string, TensorStorage> String2TensorStorage;
|
||||||
|
|
||||||
class ModelLoader {
|
class ModelLoader {
|
||||||
@ -314,13 +198,6 @@ protected:
|
|||||||
|
|
||||||
void add_tensor_storage(const TensorStorage& tensor_storage);
|
void add_tensor_storage(const TensorStorage& tensor_storage);
|
||||||
|
|
||||||
bool parse_data_pkl(uint8_t* buffer,
|
|
||||||
size_t buffer_size,
|
|
||||||
zip_t* zip,
|
|
||||||
std::string dir,
|
|
||||||
size_t file_index,
|
|
||||||
const std::string prefix);
|
|
||||||
|
|
||||||
bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = "");
|
bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = "");
|
||||||
bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = "");
|
bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = "");
|
||||||
bool init_from_ckpt_file(const std::string& file_path, const std::string& prefix = "");
|
bool init_from_ckpt_file(const std::string& file_path, const std::string& prefix = "");
|
||||||
|
|||||||
403
src/model_io/ckpt_io.cpp
Normal file
403
src/model_io/ckpt_io.cpp
Normal file
@ -0,0 +1,403 @@
|
|||||||
|
#include "ckpt_io.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <cstring>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "zip.h"
|
||||||
|
|
||||||
|
static constexpr int MAX_STRING_BUFFER = 512;
|
||||||
|
|
||||||
|
static void set_error(std::string* error, const std::string& message) {
|
||||||
|
if (error != nullptr) {
|
||||||
|
*error = message;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static int32_t read_int(const uint8_t* buffer) {
|
||||||
|
// little endian
|
||||||
|
uint32_t value = 0;
|
||||||
|
value |= static_cast<uint32_t>(buffer[3]) << 24;
|
||||||
|
value |= static_cast<uint32_t>(buffer[2]) << 16;
|
||||||
|
value |= static_cast<uint32_t>(buffer[1]) << 8;
|
||||||
|
value |= static_cast<uint32_t>(buffer[0]);
|
||||||
|
return static_cast<int32_t>(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
static uint16_t read_short(const uint8_t* buffer) {
|
||||||
|
// little endian
|
||||||
|
uint16_t value = 0;
|
||||||
|
value |= static_cast<uint16_t>(buffer[1]) << 8;
|
||||||
|
value |= static_cast<uint16_t>(buffer[0]);
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_ckpt_file(const std::string& file_path) {
|
||||||
|
zip_t* zip = zip_open(file_path.c_str(), 0, 'r');
|
||||||
|
if (zip == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
zip_close(zip);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*================================================= CkptModelLoader ==================================================*/
|
||||||
|
|
||||||
|
// $ python -m pickletools sd-v1-4/archive/data.pkl | head -n 100
|
||||||
|
// 0: \x80 PROTO 2
|
||||||
|
// 2: } EMPTY_DICT
|
||||||
|
// 3: q BINPUT 0
|
||||||
|
// 5: ( MARK
|
||||||
|
// 6: X BINUNICODE 'epoch'
|
||||||
|
// 16: q BINPUT 1
|
||||||
|
// 18: K BININT1 6
|
||||||
|
// 20: X BINUNICODE 'global_step'
|
||||||
|
// 36: q BINPUT 2
|
||||||
|
// 38: J BININT 470000
|
||||||
|
// 43: X BINUNICODE 'pytorch-lightning_version'
|
||||||
|
// 73: q BINPUT 3
|
||||||
|
// 75: X BINUNICODE '1.4.2'
|
||||||
|
// 85: q BINPUT 4
|
||||||
|
// 87: X BINUNICODE 'state_dict'
|
||||||
|
// 102: q BINPUT 5
|
||||||
|
// 104: } EMPTY_DICT
|
||||||
|
// 105: q BINPUT 6
|
||||||
|
// 107: ( MARK
|
||||||
|
// 108: X BINUNICODE 'betas'
|
||||||
|
// 118: q BINPUT 7
|
||||||
|
// 120: c GLOBAL 'torch._utils _rebuild_tensor_v2'
|
||||||
|
// 153: q BINPUT 8
|
||||||
|
// 155: ( MARK
|
||||||
|
// 156: ( MARK
|
||||||
|
// 157: X BINUNICODE 'storage'
|
||||||
|
// 169: q BINPUT 9
|
||||||
|
// 171: c GLOBAL 'torch FloatStorage'
|
||||||
|
// 191: q BINPUT 10
|
||||||
|
// 193: X BINUNICODE '0'
|
||||||
|
// 199: q BINPUT 11
|
||||||
|
// 201: X BINUNICODE 'cpu'
|
||||||
|
// 209: q BINPUT 12
|
||||||
|
// 211: M BININT2 1000
|
||||||
|
// 214: t TUPLE (MARK at 156)
|
||||||
|
// 215: q BINPUT 13
|
||||||
|
// 217: Q BINPERSID
|
||||||
|
// 218: K BININT1 0
|
||||||
|
// 220: M BININT2 1000
|
||||||
|
// ...............................
|
||||||
|
// 3201: q BINPUT 250
|
||||||
|
// 3203: R REDUCE
|
||||||
|
// 3204: q BINPUT 251
|
||||||
|
// 3206: X BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.weight'
|
||||||
|
// 3264: q BINPUT 252
|
||||||
|
// 3266: h BINGET 8
|
||||||
|
// 3268: ( MARK
|
||||||
|
// 3269: ( MARK
|
||||||
|
// 3270: h BINGET 9
|
||||||
|
// 3272: h BINGET 10
|
||||||
|
// 3274: X BINUNICODE '30'
|
||||||
|
// 3281: q BINPUT 253
|
||||||
|
// 3283: h BINGET 12
|
||||||
|
// 3285: J BININT 102400
|
||||||
|
// 3290: t TUPLE (MARK at 3269)
|
||||||
|
// 3291: q BINPUT 254
|
||||||
|
// 3293: Q BINPERSID
|
||||||
|
// 3294: K BININT1 0
|
||||||
|
// 3296: ( MARK
|
||||||
|
// 3297: M BININT2 320
|
||||||
|
// 3300: M BININT2 320
|
||||||
|
// 3303: K BININT1 1
|
||||||
|
// 3305: K BININT1 1
|
||||||
|
// 3307: t TUPLE (MARK at 3296)
|
||||||
|
// 3308: q BINPUT 255
|
||||||
|
// 3310: ( MARK
|
||||||
|
// 3311: M BININT2 320
|
||||||
|
// 3314: K BININT1 1
|
||||||
|
// 3316: K BININT1 1
|
||||||
|
// 3318: K BININT1 1
|
||||||
|
// 3320: t TUPLE (MARK at 3310)
|
||||||
|
// 3321: r LONG_BINPUT 256
|
||||||
|
// 3326: \x89 NEWFALSE
|
||||||
|
// 3327: h BINGET 16
|
||||||
|
// 3329: ) EMPTY_TUPLE
|
||||||
|
// 3330: R REDUCE
|
||||||
|
// 3331: r LONG_BINPUT 257
|
||||||
|
// 3336: t TUPLE (MARK at 3268)
|
||||||
|
// 3337: r LONG_BINPUT 258
|
||||||
|
// 3342: R REDUCE
|
||||||
|
// 3343: r LONG_BINPUT 259
|
||||||
|
// 3348: X BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.bias'
|
||||||
|
// 3404: r LONG_BINPUT 260
|
||||||
|
// 3409: h BINGET 8
|
||||||
|
// 3411: ( MARK
|
||||||
|
// 3412: ( MARK
|
||||||
|
// 3413: h BINGET 9
|
||||||
|
// 3415: h BINGET 10
|
||||||
|
// 3417: X BINUNICODE '31'
|
||||||
|
|
||||||
|
struct PickleTensorReader {
|
||||||
|
enum ReadPhase {
|
||||||
|
READ_NAME,
|
||||||
|
READ_DATA,
|
||||||
|
CHECK_SIZE,
|
||||||
|
READ_DIMENS
|
||||||
|
};
|
||||||
|
ReadPhase phase = READ_NAME;
|
||||||
|
size_t entry_size = 0;
|
||||||
|
int32_t nelements = 0;
|
||||||
|
|
||||||
|
TensorStorage tensor_storage;
|
||||||
|
|
||||||
|
static ggml_type global_type; // all pickle_tensors data type
|
||||||
|
static bool read_global_type;
|
||||||
|
|
||||||
|
bool read_int_value(uint32_t value) {
|
||||||
|
if (phase == CHECK_SIZE) {
|
||||||
|
if (entry_size == value * ggml_type_size(tensor_storage.type)) {
|
||||||
|
nelements = value;
|
||||||
|
phase = READ_DIMENS;
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
phase = READ_NAME;
|
||||||
|
}
|
||||||
|
} else if (phase == READ_DIMENS) {
|
||||||
|
if (tensor_storage.n_dims + 1 > SD_MAX_DIMS) { // too many dimens
|
||||||
|
phase = READ_NAME;
|
||||||
|
tensor_storage.n_dims = 0;
|
||||||
|
}
|
||||||
|
if (nelements % value == 0) {
|
||||||
|
tensor_storage.ne[tensor_storage.n_dims] = value;
|
||||||
|
tensor_storage.n_dims++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void read_global(const std::string& str) {
|
||||||
|
if (str == "FloatStorage") {
|
||||||
|
if (read_global_type) {
|
||||||
|
global_type = GGML_TYPE_F32;
|
||||||
|
read_global_type = false;
|
||||||
|
}
|
||||||
|
tensor_storage.type = GGML_TYPE_F32;
|
||||||
|
} else if (str == "HalfStorage") {
|
||||||
|
if (read_global_type) {
|
||||||
|
global_type = GGML_TYPE_F16;
|
||||||
|
read_global_type = false;
|
||||||
|
}
|
||||||
|
tensor_storage.type = GGML_TYPE_F16;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void read_string(const std::string& str, zip_t* zip, std::string dir) {
|
||||||
|
if (str == "storage") {
|
||||||
|
read_global_type = true;
|
||||||
|
} else if (str != "state_dict") {
|
||||||
|
if (phase == READ_DATA) {
|
||||||
|
std::string entry_name = dir + "data/" + std::string(str);
|
||||||
|
|
||||||
|
size_t i, n = zip_entries_total(zip);
|
||||||
|
for (i = 0; i < n; ++i) {
|
||||||
|
zip_entry_openbyindex(zip, i);
|
||||||
|
{
|
||||||
|
std::string name = zip_entry_name(zip);
|
||||||
|
if (name == entry_name) {
|
||||||
|
tensor_storage.index_in_zip = (int)i;
|
||||||
|
entry_size = zip_entry_size(zip);
|
||||||
|
zip_entry_close(zip);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
zip_entry_close(zip);
|
||||||
|
}
|
||||||
|
|
||||||
|
phase = entry_size > 0 ? CHECK_SIZE : READ_NAME;
|
||||||
|
}
|
||||||
|
if (!read_global_type && phase == READ_NAME) {
|
||||||
|
tensor_storage.name = str;
|
||||||
|
phase = READ_DATA;
|
||||||
|
tensor_storage.type = global_type;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
ggml_type PickleTensorReader::global_type = GGML_TYPE_F32; // all pickle_tensors data type
|
||||||
|
bool PickleTensorReader::read_global_type = false;
|
||||||
|
|
||||||
|
static int find_char(uint8_t* buffer, int len, char c) {
|
||||||
|
for (int pos = 0; pos < len; pos++) {
|
||||||
|
if (buffer[pos] == c) {
|
||||||
|
return pos;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool parse_data_pkl(uint8_t* buffer,
|
||||||
|
size_t buffer_size,
|
||||||
|
zip_t* zip,
|
||||||
|
std::string dir,
|
||||||
|
std::vector<TensorStorage>& tensor_storages,
|
||||||
|
std::string* error) {
|
||||||
|
uint8_t* buffer_end = buffer + buffer_size;
|
||||||
|
if (buffer[0] == 0x80) { // proto
|
||||||
|
if (buffer[1] != 2) {
|
||||||
|
set_error(error, "unsupported pickle protocol");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
buffer += 2; // 0x80 and version
|
||||||
|
char string_buffer[MAX_STRING_BUFFER];
|
||||||
|
bool finish = false;
|
||||||
|
PickleTensorReader reader;
|
||||||
|
// read pickle binary file
|
||||||
|
while (!finish && buffer < buffer_end) {
|
||||||
|
uint8_t opcode = *buffer;
|
||||||
|
buffer++;
|
||||||
|
// https://github.com/python/cpython/blob/3.7/Lib/pickletools.py#L1048
|
||||||
|
// https://github.com/python/cpython/blob/main/Lib/pickle.py#L105
|
||||||
|
switch (opcode) {
|
||||||
|
case '}': // EMPTY_DICT = b'}' # push empty dict
|
||||||
|
break;
|
||||||
|
case ']': // EMPTY_LIST = b']' # push empty list
|
||||||
|
break;
|
||||||
|
// skip unused sections
|
||||||
|
case 'h': // BINGET = b'h' # " " " " " " ; " " 1-byte arg
|
||||||
|
case 'q': // BINPUT = b'q' # " " " " " ; " " 1-byte arg
|
||||||
|
case 'Q': // BINPERSID = b'Q' # " " " ; " " " " stack
|
||||||
|
buffer++;
|
||||||
|
break;
|
||||||
|
case 'r': // LONG_BINPUT = b'r' # " " " " " ; " " 4-byte arg
|
||||||
|
buffer += 4;
|
||||||
|
break;
|
||||||
|
case 0x95: // FRAME = b'\x95' # indicate the beginning of a new frame
|
||||||
|
buffer += 8;
|
||||||
|
break;
|
||||||
|
case 0x94: // MEMOIZE = b'\x94' # store top of the stack in memo
|
||||||
|
break;
|
||||||
|
case '(': // MARK = b'(' # push special markobject on stack
|
||||||
|
break;
|
||||||
|
case 'K': // BININT1 = b'K' # push 1-byte unsigned int
|
||||||
|
{
|
||||||
|
uint8_t value = *buffer;
|
||||||
|
if (reader.read_int_value(value)) {
|
||||||
|
buffer++;
|
||||||
|
}
|
||||||
|
buffer++;
|
||||||
|
} break;
|
||||||
|
case 'M': // BININT2 = b'M' # push 2-byte unsigned int
|
||||||
|
{
|
||||||
|
uint16_t value = read_short(buffer);
|
||||||
|
if (reader.read_int_value(value)) {
|
||||||
|
buffer++;
|
||||||
|
}
|
||||||
|
buffer += 2;
|
||||||
|
} break;
|
||||||
|
case 'J': // BININT = b'J' # push four-byte signed int
|
||||||
|
{
|
||||||
|
const int32_t value = read_int(buffer);
|
||||||
|
if (reader.read_int_value(value)) {
|
||||||
|
buffer++; // skip tuple after read num_elements
|
||||||
|
}
|
||||||
|
buffer += 4;
|
||||||
|
} break;
|
||||||
|
case 'X': // BINUNICODE = b'X' # " " " ; counted UTF-8 string argument
|
||||||
|
{
|
||||||
|
const int32_t len = read_int(buffer);
|
||||||
|
buffer += 4;
|
||||||
|
memset(string_buffer, 0, MAX_STRING_BUFFER);
|
||||||
|
if (len > MAX_STRING_BUFFER) {
|
||||||
|
// keep truncated names null-terminated, matching the old parser behavior
|
||||||
|
}
|
||||||
|
memcpy(string_buffer, buffer, len < MAX_STRING_BUFFER ? len : (MAX_STRING_BUFFER - 1));
|
||||||
|
buffer += len;
|
||||||
|
reader.read_string(string_buffer, zip, dir);
|
||||||
|
} break;
|
||||||
|
case 0x8C: // SHORT_BINUNICODE = b'\x8c' # push short string; UTF-8 length < 256 bytes
|
||||||
|
{
|
||||||
|
const int8_t len = *buffer;
|
||||||
|
buffer++;
|
||||||
|
memset(string_buffer, 0, MAX_STRING_BUFFER);
|
||||||
|
memcpy(string_buffer, buffer, len);
|
||||||
|
buffer += len;
|
||||||
|
// printf("String: '%s'\n", string_buffer);
|
||||||
|
} break;
|
||||||
|
case 'c': // GLOBAL = b'c' # push self.find_class(modname, name); 2 string args
|
||||||
|
{
|
||||||
|
int len = find_char(buffer, MAX_STRING_BUFFER, '\n');
|
||||||
|
|
||||||
|
buffer += len + 1;
|
||||||
|
len = find_char(buffer, MAX_STRING_BUFFER, '\n');
|
||||||
|
|
||||||
|
memset(string_buffer, 0, MAX_STRING_BUFFER);
|
||||||
|
memcpy(string_buffer, buffer, len);
|
||||||
|
buffer += len + 1;
|
||||||
|
reader.read_global(string_buffer);
|
||||||
|
} break;
|
||||||
|
case 0x86: // TUPLE2 = b'\x86' # build 2-tuple from two topmost stack items
|
||||||
|
case 0x85: // TUPLE1 = b'\x85' # build 1-tuple from stack top
|
||||||
|
case 't': // TUPLE = b't' # build tuple from topmost stack items
|
||||||
|
if (reader.phase == PickleTensorReader::READ_DIMENS) {
|
||||||
|
reader.tensor_storage.reverse_ne();
|
||||||
|
tensor_storages.push_back(reader.tensor_storage);
|
||||||
|
|
||||||
|
// LOG_DEBUG("%s", reader.tensor_storage.name.c_str());
|
||||||
|
// reset
|
||||||
|
reader = PickleTensorReader();
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case '.': // STOP = b'.' # every pickle ends with STOP
|
||||||
|
finish = true;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool read_ckpt_file(const std::string& file_path,
|
||||||
|
std::vector<TensorStorage>& tensor_storages,
|
||||||
|
std::string* error) {
|
||||||
|
zip_t* zip = zip_open(file_path.c_str(), 0, 'r');
|
||||||
|
if (zip == nullptr) {
|
||||||
|
set_error(error, "failed to open '" + file_path + "'");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
tensor_storages.clear();
|
||||||
|
bool success = true;
|
||||||
|
int n = (int)zip_entries_total(zip);
|
||||||
|
for (int i = 0; i < n; ++i) {
|
||||||
|
zip_entry_openbyindex(zip, i);
|
||||||
|
{
|
||||||
|
std::string name = zip_entry_name(zip);
|
||||||
|
size_t pos = name.find("data.pkl");
|
||||||
|
if (pos != std::string::npos) {
|
||||||
|
std::string dir = name.substr(0, pos);
|
||||||
|
printf("ZIP %d, name = %s, dir = %s \n", i, name.c_str(), dir.c_str());
|
||||||
|
void* pkl_data = nullptr;
|
||||||
|
size_t pkl_size;
|
||||||
|
zip_entry_read(zip, &pkl_data, &pkl_size);
|
||||||
|
|
||||||
|
// LOG_DEBUG("%lld", pkl_size);
|
||||||
|
|
||||||
|
if (!parse_data_pkl((uint8_t*)pkl_data, pkl_size, zip, dir, tensor_storages, error)) {
|
||||||
|
success = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
free(pkl_data);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
zip_entry_close(zip);
|
||||||
|
|
||||||
|
if (!success) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
zip_close(zip);
|
||||||
|
return success;
|
||||||
|
}
|
||||||
14
src/model_io/ckpt_io.h
Normal file
14
src/model_io/ckpt_io.h
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
#ifndef __SD_MODEL_IO_CKPT_IO_H__
|
||||||
|
#define __SD_MODEL_IO_CKPT_IO_H__
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensor_storage.h"
|
||||||
|
|
||||||
|
bool is_ckpt_file(const std::string& file_path);
|
||||||
|
bool read_ckpt_file(const std::string& file_path,
|
||||||
|
std::vector<TensorStorage>& tensor_storages,
|
||||||
|
std::string* error = nullptr);
|
||||||
|
|
||||||
|
#endif // __SD_MODEL_IO_CKPT_IO_H__
|
||||||
122
src/model_io/gguf_io.cpp
Normal file
122
src/model_io/gguf_io.cpp
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
#include "gguf_io.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <fstream>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "gguf.h"
|
||||||
|
#include "gguf_reader_ext.h"
|
||||||
|
#include "util.h"
|
||||||
|
|
||||||
|
static void set_error(std::string* error, const std::string& message) {
|
||||||
|
if (error != nullptr) {
|
||||||
|
*error = message;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_gguf_file(const std::string& file_path) {
|
||||||
|
std::ifstream file(file_path, std::ios::binary);
|
||||||
|
if (!file.is_open()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
char magic[4];
|
||||||
|
|
||||||
|
file.read(magic, sizeof(magic));
|
||||||
|
if (!file) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (uint32_t i = 0; i < sizeof(magic); i++) {
|
||||||
|
if (magic[i] != GGUF_MAGIC[i]) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool read_gguf_file(const std::string& file_path,
|
||||||
|
std::vector<TensorStorage>& tensor_storages,
|
||||||
|
std::string* error) {
|
||||||
|
tensor_storages.clear();
|
||||||
|
|
||||||
|
gguf_context* ctx_gguf_ = nullptr;
|
||||||
|
ggml_context* ctx_meta_ = nullptr;
|
||||||
|
|
||||||
|
ctx_gguf_ = gguf_init_from_file(file_path.c_str(), {true, &ctx_meta_});
|
||||||
|
if (!ctx_gguf_) {
|
||||||
|
GGUFReader gguf_reader;
|
||||||
|
if (!gguf_reader.load(file_path)) {
|
||||||
|
set_error(error, "failed to open '" + file_path + "' with GGUFReader");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t data_offset = gguf_reader.data_offset();
|
||||||
|
for (const auto& gguf_tensor_info : gguf_reader.tensors()) {
|
||||||
|
TensorStorage tensor_storage(
|
||||||
|
gguf_tensor_info.name,
|
||||||
|
gguf_tensor_info.type,
|
||||||
|
gguf_tensor_info.shape.data(),
|
||||||
|
static_cast<int>(gguf_tensor_info.shape.size()),
|
||||||
|
0,
|
||||||
|
data_offset + gguf_tensor_info.offset);
|
||||||
|
|
||||||
|
tensor_storages.push_back(tensor_storage);
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
int n_tensors = static_cast<int>(gguf_get_n_tensors(ctx_gguf_));
|
||||||
|
|
||||||
|
size_t data_offset = gguf_get_data_offset(ctx_gguf_);
|
||||||
|
for (int i = 0; i < n_tensors; i++) {
|
||||||
|
std::string name = gguf_get_tensor_name(ctx_gguf_, i);
|
||||||
|
ggml_tensor* dummy = ggml_get_tensor(ctx_meta_, name.c_str());
|
||||||
|
size_t offset = data_offset + gguf_get_tensor_offset(ctx_gguf_, i);
|
||||||
|
|
||||||
|
TensorStorage tensor_storage(name, dummy->type, dummy->ne, ggml_n_dims(dummy), 0, offset);
|
||||||
|
|
||||||
|
if (ggml_nbytes(dummy) != tensor_storage.nbytes()) {
|
||||||
|
gguf_free(ctx_gguf_);
|
||||||
|
ggml_free(ctx_meta_);
|
||||||
|
set_error(error, "size mismatch for tensor '" + name + "'");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
tensor_storages.push_back(tensor_storage);
|
||||||
|
}
|
||||||
|
|
||||||
|
gguf_free(ctx_gguf_);
|
||||||
|
ggml_free(ctx_meta_);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool write_gguf_file(const std::string& file_path,
|
||||||
|
const std::vector<ggml_tensor*>& tensors,
|
||||||
|
std::string* error) {
|
||||||
|
gguf_context* gguf_ctx = gguf_init_empty();
|
||||||
|
if (gguf_ctx == nullptr) {
|
||||||
|
set_error(error, "gguf_init_empty failed");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (ggml_tensor* tensor : tensors) {
|
||||||
|
if (tensor == nullptr) {
|
||||||
|
set_error(error, "null tensor cannot be written to GGUF");
|
||||||
|
gguf_free(gguf_ctx);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
gguf_add_tensor(gguf_ctx, tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_INFO("trying to save tensors to %s", file_path.c_str());
|
||||||
|
bool success = gguf_write_to_file(gguf_ctx, file_path.c_str(), false);
|
||||||
|
if (!success) {
|
||||||
|
set_error(error, "failed to write GGUF file '" + file_path + "'");
|
||||||
|
}
|
||||||
|
gguf_free(gguf_ctx);
|
||||||
|
return success;
|
||||||
|
}
|
||||||
17
src/model_io/gguf_io.h
Normal file
17
src/model_io/gguf_io.h
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
#ifndef __SD_MODEL_IO_GGUF_IO_H__
|
||||||
|
#define __SD_MODEL_IO_GGUF_IO_H__
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensor_storage.h"
|
||||||
|
|
||||||
|
bool is_gguf_file(const std::string& file_path);
|
||||||
|
bool read_gguf_file(const std::string& file_path,
|
||||||
|
std::vector<TensorStorage>& tensor_storages,
|
||||||
|
std::string* error = nullptr);
|
||||||
|
bool write_gguf_file(const std::string& file_path,
|
||||||
|
const std::vector<ggml_tensor*>& tensors,
|
||||||
|
std::string* error = nullptr);
|
||||||
|
|
||||||
|
#endif // __SD_MODEL_IO_GGUF_IO_H__
|
||||||
@ -1,5 +1,5 @@
|
|||||||
#ifndef __GGUF_READER_HPP__
|
#ifndef __SD_MODEL_IO_GGUF_READER_EXT_H__
|
||||||
#define __GGUF_READER_HPP__
|
#define __SD_MODEL_IO_GGUF_READER_EXT_H__
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
@ -231,4 +231,4 @@ public:
|
|||||||
size_t data_offset() const { return data_offset_; }
|
size_t data_offset() const { return data_offset_; }
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // __GGUF_READER_HPP__
|
#endif // __SD_MODEL_IO_GGUF_READER_EXT_H__
|
||||||
236
src/model_io/safetensors_io.cpp
Normal file
236
src/model_io/safetensors_io.cpp
Normal file
@ -0,0 +1,236 @@
|
|||||||
|
#include "safetensors_io.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <exception>
|
||||||
|
#include <fstream>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "json.hpp"
|
||||||
|
|
||||||
|
static constexpr size_t ST_HEADER_SIZE_LEN = 8;
|
||||||
|
|
||||||
|
static void set_error(std::string* error, const std::string& message) {
|
||||||
|
if (error != nullptr) {
|
||||||
|
*error = message;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static uint64_t read_u64(const uint8_t* buffer) {
|
||||||
|
// little endian
|
||||||
|
uint64_t value = 0;
|
||||||
|
value |= static_cast<uint64_t>(buffer[7]) << 56;
|
||||||
|
value |= static_cast<uint64_t>(buffer[6]) << 48;
|
||||||
|
value |= static_cast<uint64_t>(buffer[5]) << 40;
|
||||||
|
value |= static_cast<uint64_t>(buffer[4]) << 32;
|
||||||
|
value |= static_cast<uint64_t>(buffer[3]) << 24;
|
||||||
|
value |= static_cast<uint64_t>(buffer[2]) << 16;
|
||||||
|
value |= static_cast<uint64_t>(buffer[1]) << 8;
|
||||||
|
value |= static_cast<uint64_t>(buffer[0]);
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_safetensors_file(const std::string& file_path) {
|
||||||
|
std::ifstream file(file_path, std::ios::binary);
|
||||||
|
if (!file.is_open()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// get file size
|
||||||
|
file.seekg(0, file.end);
|
||||||
|
size_t file_size_ = file.tellg();
|
||||||
|
file.seekg(0, file.beg);
|
||||||
|
|
||||||
|
// read header size
|
||||||
|
if (file_size_ <= ST_HEADER_SIZE_LEN) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint8_t header_size_buf[ST_HEADER_SIZE_LEN];
|
||||||
|
file.read((char*)header_size_buf, ST_HEADER_SIZE_LEN);
|
||||||
|
if (!file) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t header_size_ = read_u64(header_size_buf);
|
||||||
|
if (header_size_ >= file_size_ || header_size_ <= 2) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// read header
|
||||||
|
std::vector<char> header_buf;
|
||||||
|
header_buf.resize(header_size_ + 1);
|
||||||
|
header_buf[header_size_] = '\0';
|
||||||
|
file.read(header_buf.data(), header_size_);
|
||||||
|
if (!file) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
nlohmann::json header_ = nlohmann::json::parse(header_buf.data());
|
||||||
|
} catch (const std::exception&) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static ggml_type str_to_ggml_type(const std::string& dtype) {
|
||||||
|
ggml_type ttype = GGML_TYPE_COUNT;
|
||||||
|
if (dtype == "F16") {
|
||||||
|
ttype = GGML_TYPE_F16;
|
||||||
|
} else if (dtype == "BF16") {
|
||||||
|
ttype = GGML_TYPE_BF16;
|
||||||
|
} else if (dtype == "F32") {
|
||||||
|
ttype = GGML_TYPE_F32;
|
||||||
|
} else if (dtype == "F64") {
|
||||||
|
ttype = GGML_TYPE_F32;
|
||||||
|
} else if (dtype == "F8_E4M3") {
|
||||||
|
ttype = GGML_TYPE_F16;
|
||||||
|
} else if (dtype == "F8_E5M2") {
|
||||||
|
ttype = GGML_TYPE_F16;
|
||||||
|
} else if (dtype == "I32") {
|
||||||
|
ttype = GGML_TYPE_I32;
|
||||||
|
} else if (dtype == "I64") {
|
||||||
|
ttype = GGML_TYPE_I32;
|
||||||
|
}
|
||||||
|
return ttype;
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://huggingface.co/docs/safetensors/index
|
||||||
|
bool read_safetensors_file(const std::string& file_path,
|
||||||
|
std::vector<TensorStorage>& tensor_storages,
|
||||||
|
std::string* error) {
|
||||||
|
std::ifstream file(file_path, std::ios::binary);
|
||||||
|
if (!file.is_open()) {
|
||||||
|
set_error(error, "failed to open '" + file_path + "'");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// get file size
|
||||||
|
file.seekg(0, file.end);
|
||||||
|
size_t file_size_ = file.tellg();
|
||||||
|
file.seekg(0, file.beg);
|
||||||
|
|
||||||
|
// read header size
|
||||||
|
if (file_size_ <= ST_HEADER_SIZE_LEN) {
|
||||||
|
set_error(error, "invalid safetensor file '" + file_path + "'");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint8_t header_size_buf[ST_HEADER_SIZE_LEN];
|
||||||
|
file.read((char*)header_size_buf, ST_HEADER_SIZE_LEN);
|
||||||
|
if (!file) {
|
||||||
|
set_error(error, "read safetensors header size failed: '" + file_path + "'");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t header_size_ = read_u64(header_size_buf);
|
||||||
|
if (header_size_ >= file_size_) {
|
||||||
|
set_error(error, "invalid safetensor file '" + file_path + "'");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// read header
|
||||||
|
std::vector<char> header_buf;
|
||||||
|
header_buf.resize(header_size_ + 1);
|
||||||
|
header_buf[header_size_] = '\0';
|
||||||
|
file.read(header_buf.data(), header_size_);
|
||||||
|
if (!file) {
|
||||||
|
set_error(error, "read safetensors header failed: '" + file_path + "'");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
nlohmann::json header_;
|
||||||
|
try {
|
||||||
|
header_ = nlohmann::json::parse(header_buf.data());
|
||||||
|
} catch (const std::exception&) {
|
||||||
|
set_error(error, "parsing safetensors header failed: '" + file_path + "'");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
tensor_storages.clear();
|
||||||
|
for (auto& item : header_.items()) {
|
||||||
|
std::string name = item.key();
|
||||||
|
nlohmann::json tensor_info = item.value();
|
||||||
|
// LOG_DEBUG("%s %s\n", name.c_str(), tensor_info.dump().c_str());
|
||||||
|
|
||||||
|
if (name == "__metadata__") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string dtype = tensor_info["dtype"];
|
||||||
|
nlohmann::json shape = tensor_info["shape"];
|
||||||
|
|
||||||
|
if (dtype == "U8") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t begin = tensor_info["data_offsets"][0].get<size_t>();
|
||||||
|
size_t end = tensor_info["data_offsets"][1].get<size_t>();
|
||||||
|
|
||||||
|
ggml_type type = str_to_ggml_type(dtype);
|
||||||
|
if (type == GGML_TYPE_COUNT) {
|
||||||
|
set_error(error, "unsupported dtype '" + dtype + "' (tensor '" + name + "')");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (shape.size() > SD_MAX_DIMS) {
|
||||||
|
set_error(error, "invalid tensor '" + name + "'");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
int n_dims = (int)shape.size();
|
||||||
|
int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
|
||||||
|
for (int i = 0; i < n_dims; i++) {
|
||||||
|
ne[i] = shape[i].get<int64_t>();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (n_dims == 5) {
|
||||||
|
n_dims = 4;
|
||||||
|
ne[0] = ne[0] * ne[1];
|
||||||
|
ne[1] = ne[2];
|
||||||
|
ne[2] = ne[3];
|
||||||
|
ne[3] = ne[4];
|
||||||
|
}
|
||||||
|
|
||||||
|
// ggml_n_dims returns 1 for scalars
|
||||||
|
if (n_dims == 0) {
|
||||||
|
n_dims = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorStorage tensor_storage(name, type, ne, n_dims, 0, ST_HEADER_SIZE_LEN + header_size_ + begin);
|
||||||
|
tensor_storage.reverse_ne();
|
||||||
|
|
||||||
|
size_t tensor_data_size = end - begin;
|
||||||
|
|
||||||
|
bool tensor_size_ok;
|
||||||
|
if (dtype == "F8_E4M3") {
|
||||||
|
tensor_storage.is_f8_e4m3 = true;
|
||||||
|
// f8 -> f16
|
||||||
|
tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size * 2);
|
||||||
|
} else if (dtype == "F8_E5M2") {
|
||||||
|
tensor_storage.is_f8_e5m2 = true;
|
||||||
|
// f8 -> f16
|
||||||
|
tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size * 2);
|
||||||
|
} else if (dtype == "F64") {
|
||||||
|
tensor_storage.is_f64 = true;
|
||||||
|
// f64 -> f32
|
||||||
|
tensor_size_ok = (tensor_storage.nbytes() * 2 == tensor_data_size);
|
||||||
|
} else if (dtype == "I64") {
|
||||||
|
tensor_storage.is_i64 = true;
|
||||||
|
// i64 -> i32
|
||||||
|
tensor_size_ok = (tensor_storage.nbytes() * 2 == tensor_data_size);
|
||||||
|
} else {
|
||||||
|
tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size);
|
||||||
|
}
|
||||||
|
if (!tensor_size_ok) {
|
||||||
|
set_error(error, "size mismatch for tensor '" + name + "' (" + dtype + ")");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
tensor_storages.push_back(tensor_storage);
|
||||||
|
|
||||||
|
// LOG_DEBUG("%s %s", tensor_storage.to_string().c_str(), dtype.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
14
src/model_io/safetensors_io.h
Normal file
14
src/model_io/safetensors_io.h
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
#ifndef __SD_MODEL_IO_SAFETENSORS_IO_H__
|
||||||
|
#define __SD_MODEL_IO_SAFETENSORS_IO_H__
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensor_storage.h"
|
||||||
|
|
||||||
|
bool is_safetensors_file(const std::string& file_path);
|
||||||
|
bool read_safetensors_file(const std::string& file_path,
|
||||||
|
std::vector<TensorStorage>& tensor_storages,
|
||||||
|
std::string* error = nullptr);
|
||||||
|
|
||||||
|
#endif // __SD_MODEL_IO_SAFETENSORS_IO_H__
|
||||||
125
src/model_io/tensor_storage.h
Normal file
125
src/model_io/tensor_storage.h
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
#ifndef __SD_TENSOR_STORAGE_H__
|
||||||
|
#define __SD_TENSOR_STORAGE_H__
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <functional>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
|
||||||
|
#define SD_MAX_DIMS 5
|
||||||
|
|
||||||
|
struct TensorStorage {
|
||||||
|
std::string name;
|
||||||
|
ggml_type type = GGML_TYPE_F32;
|
||||||
|
ggml_type expected_type = GGML_TYPE_COUNT;
|
||||||
|
bool is_f8_e4m3 = false;
|
||||||
|
bool is_f8_e5m2 = false;
|
||||||
|
bool is_f64 = false;
|
||||||
|
bool is_i64 = false;
|
||||||
|
int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
|
||||||
|
int n_dims = 0;
|
||||||
|
|
||||||
|
size_t file_index = 0;
|
||||||
|
int index_in_zip = -1; // >= means stored in a zip file
|
||||||
|
uint64_t offset = 0; // offset in file
|
||||||
|
|
||||||
|
TensorStorage() = default;
|
||||||
|
|
||||||
|
TensorStorage(std::string name, ggml_type type, const int64_t* ne, int n_dims, size_t file_index, size_t offset = 0)
|
||||||
|
: name(std::move(name)), type(type), n_dims(n_dims), file_index(file_index), offset(offset) {
|
||||||
|
for (int i = 0; i < n_dims; i++) {
|
||||||
|
this->ne[i] = ne[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t nelements() const {
|
||||||
|
int64_t n = 1;
|
||||||
|
for (int i = 0; i < SD_MAX_DIMS; i++) {
|
||||||
|
n *= ne[i];
|
||||||
|
}
|
||||||
|
return n;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t nbytes() const {
|
||||||
|
return nelements() * ggml_type_size(type) / ggml_blck_size(type);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t nbytes_to_read() const {
|
||||||
|
if (is_f8_e4m3 || is_f8_e5m2) {
|
||||||
|
return nbytes() / 2;
|
||||||
|
} else if (is_f64 || is_i64) {
|
||||||
|
return nbytes() * 2;
|
||||||
|
} else {
|
||||||
|
return nbytes();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void unsqueeze() {
|
||||||
|
if (n_dims == 2) {
|
||||||
|
n_dims = 4;
|
||||||
|
ne[3] = ne[1];
|
||||||
|
ne[2] = ne[0];
|
||||||
|
ne[1] = 1;
|
||||||
|
ne[0] = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<TensorStorage> chunk(size_t n) {
|
||||||
|
std::vector<TensorStorage> chunks;
|
||||||
|
uint64_t chunk_size = nbytes_to_read() / n;
|
||||||
|
// printf("%d/%d\n", chunk_size, nbytes_to_read());
|
||||||
|
reverse_ne();
|
||||||
|
for (size_t i = 0; i < n; i++) {
|
||||||
|
TensorStorage chunk_i = *this;
|
||||||
|
chunk_i.ne[0] = ne[0] / n;
|
||||||
|
chunk_i.offset = offset + i * chunk_size;
|
||||||
|
chunk_i.reverse_ne();
|
||||||
|
chunks.push_back(chunk_i);
|
||||||
|
}
|
||||||
|
reverse_ne();
|
||||||
|
return chunks;
|
||||||
|
}
|
||||||
|
|
||||||
|
void reverse_ne() {
|
||||||
|
int64_t new_ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
|
||||||
|
for (int i = 0; i < n_dims; i++) {
|
||||||
|
new_ne[i] = ne[n_dims - 1 - i];
|
||||||
|
}
|
||||||
|
for (int i = 0; i < n_dims; i++) {
|
||||||
|
ne[i] = new_ne[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string to_string() const {
|
||||||
|
std::stringstream ss;
|
||||||
|
const char* type_name = ggml_type_name(type);
|
||||||
|
if (is_f8_e4m3) {
|
||||||
|
type_name = "f8_e4m3";
|
||||||
|
} else if (is_f8_e5m2) {
|
||||||
|
type_name = "f8_e5m2";
|
||||||
|
} else if (is_f64) {
|
||||||
|
type_name = "f64";
|
||||||
|
} else if (is_i64) {
|
||||||
|
type_name = "i64";
|
||||||
|
}
|
||||||
|
ss << name << " | " << type_name << " | ";
|
||||||
|
ss << n_dims << " [";
|
||||||
|
for (int i = 0; i < SD_MAX_DIMS; i++) {
|
||||||
|
ss << ne[i];
|
||||||
|
if (i != SD_MAX_DIMS - 1) {
|
||||||
|
ss << ", ";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ss << "]";
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef std::function<bool(const TensorStorage&, ggml_tensor**)> on_new_tensor_cb_t;
|
||||||
|
|
||||||
|
#endif // __SD_TENSOR_STORAGE_H__
|
||||||
@ -2457,8 +2457,10 @@ enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_me
|
|||||||
return EXPONENTIAL_SCHEDULER;
|
return EXPONENTIAL_SCHEDULER;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (sample_method == LCM_SAMPLE_METHOD) {
|
if (sample_method == LCM_SAMPLE_METHOD || sample_method == TCD_SAMPLE_METHOD) {
|
||||||
return LCM_SCHEDULER;
|
return LCM_SCHEDULER;
|
||||||
|
} else if (sample_method == DDIM_TRAILING_SAMPLE_METHOD) {
|
||||||
|
return SIMPLE_SCHEDULER;
|
||||||
}
|
}
|
||||||
return DISCRETE_SCHEDULER;
|
return DISCRETE_SCHEDULER;
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user