mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-13 05:48:56 +00:00
add GGUFReader
This commit is contained in:
parent
2570565dfa
commit
29c61c8c29
231
gguf_reader.hpp
Normal file
231
gguf_reader.hpp
Normal file
@ -0,0 +1,231 @@
|
||||
#ifndef __GGUF_READER_HPP__
|
||||
#define __GGUF_READER_HPP__
|
||||
|
||||
#include <cstdint>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ggml.h"
|
||||
#include "util.h"
|
||||
|
||||
struct GGUFTensorInfo {
|
||||
std::string name;
|
||||
ggml_type type;
|
||||
std::vector<int64_t> shape;
|
||||
size_t offset;
|
||||
};
|
||||
|
||||
enum class GGUFMetadataType : uint32_t {
|
||||
UINT8 = 0,
|
||||
INT8 = 1,
|
||||
UINT16 = 2,
|
||||
INT16 = 3,
|
||||
UINT32 = 4,
|
||||
INT32 = 5,
|
||||
FLOAT32 = 6,
|
||||
BOOL = 7,
|
||||
STRING = 8,
|
||||
ARRAY = 9,
|
||||
UINT64 = 10,
|
||||
INT64 = 11,
|
||||
FLOAT64 = 12,
|
||||
};
|
||||
|
||||
class GGUFReader {
|
||||
private:
|
||||
std::vector<GGUFTensorInfo> tensors_;
|
||||
size_t data_offset_;
|
||||
size_t alignment_ = 32; // default alignment is 32
|
||||
|
||||
template <typename T>
|
||||
bool safe_read(std::ifstream& fin, T& value) {
|
||||
fin.read(reinterpret_cast<char*>(&value), sizeof(T));
|
||||
return fin.good();
|
||||
}
|
||||
|
||||
bool safe_read(std::ifstream& fin, char* buffer, size_t size) {
|
||||
fin.read(buffer, size);
|
||||
return fin.good();
|
||||
}
|
||||
|
||||
bool safe_seek(std::ifstream& fin, std::streamoff offset, std::ios::seekdir dir) {
|
||||
fin.seekg(offset, dir);
|
||||
return fin.good();
|
||||
}
|
||||
|
||||
bool read_metadata(std::ifstream& fin) {
|
||||
uint64_t key_len = 0;
|
||||
if (!safe_read(fin, key_len))
|
||||
return false;
|
||||
|
||||
std::string key(key_len, '\0');
|
||||
if (!safe_read(fin, (char*)key.data(), key_len))
|
||||
return false;
|
||||
|
||||
uint32_t type = 0;
|
||||
if (!safe_read(fin, type))
|
||||
return false;
|
||||
|
||||
if (key == "general.alignment") {
|
||||
uint32_t align_val = 0;
|
||||
if (!safe_read(fin, align_val))
|
||||
return false;
|
||||
|
||||
if (align_val != 0 && (align_val & (align_val - 1)) == 0) {
|
||||
alignment_ = align_val;
|
||||
LOG_DEBUG("Found alignment: %zu", alignment_);
|
||||
} else {
|
||||
LOG_ERROR("Invalid alignment value %u, fallback to default %zu", align_val, alignment_);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
switch (static_cast<GGUFMetadataType>(type)) {
|
||||
case GGUFMetadataType::UINT8:
|
||||
case GGUFMetadataType::INT8:
|
||||
case GGUFMetadataType::BOOL:
|
||||
return safe_seek(fin, 1, std::ios::cur);
|
||||
|
||||
case GGUFMetadataType::UINT16:
|
||||
case GGUFMetadataType::INT16:
|
||||
return safe_seek(fin, 2, std::ios::cur);
|
||||
|
||||
case GGUFMetadataType::UINT32:
|
||||
case GGUFMetadataType::INT32:
|
||||
case GGUFMetadataType::FLOAT32:
|
||||
return safe_seek(fin, 4, std::ios::cur);
|
||||
|
||||
case GGUFMetadataType::UINT64:
|
||||
case GGUFMetadataType::INT64:
|
||||
case GGUFMetadataType::FLOAT64:
|
||||
return safe_seek(fin, 8, std::ios::cur);
|
||||
|
||||
case GGUFMetadataType::STRING: {
|
||||
uint64_t len = 0;
|
||||
if (!safe_read(fin, len))
|
||||
return false;
|
||||
return safe_seek(fin, len, std::ios::cur);
|
||||
}
|
||||
|
||||
case GGUFMetadataType::ARRAY: {
|
||||
uint32_t elem_type = 0;
|
||||
uint64_t len = 0;
|
||||
if (!safe_read(fin, elem_type))
|
||||
return false;
|
||||
if (!safe_read(fin, len))
|
||||
return false;
|
||||
|
||||
for (uint64_t i = 0; i < len; i++) {
|
||||
if (!read_metadata(fin))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
default:
|
||||
LOG_ERROR("Unknown metadata type=%u", type);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
GGUFTensorInfo read_tensor_info(std::ifstream& fin) {
|
||||
GGUFTensorInfo info;
|
||||
|
||||
uint64_t name_len;
|
||||
if (!safe_read(fin, name_len))
|
||||
throw std::runtime_error("read tensor name length failed");
|
||||
|
||||
info.name.resize(name_len);
|
||||
if (!safe_read(fin, (char*)info.name.data(), name_len))
|
||||
throw std::runtime_error("read tensor name failed");
|
||||
|
||||
uint32_t n_dims;
|
||||
if (!safe_read(fin, n_dims))
|
||||
throw std::runtime_error("read tensor dims failed");
|
||||
|
||||
info.shape.resize(n_dims);
|
||||
for (uint32_t i = 0; i < n_dims; i++) {
|
||||
if (!safe_read(fin, info.shape[i]))
|
||||
throw std::runtime_error("read tensor shape failed");
|
||||
}
|
||||
|
||||
if (n_dims > GGML_MAX_DIMS) {
|
||||
for (int i = GGML_MAX_DIMS; i < n_dims; i++) {
|
||||
info.shape[GGML_MAX_DIMS - 1] *= info.shape[i]; // stack to last dim;
|
||||
}
|
||||
info.shape.resize(GGML_MAX_DIMS);
|
||||
n_dims = GGML_MAX_DIMS;
|
||||
}
|
||||
|
||||
uint32_t type;
|
||||
if (!safe_read(fin, type))
|
||||
throw std::runtime_error("read tensor type failed");
|
||||
info.type = static_cast<ggml_type>(type);
|
||||
|
||||
if (!safe_read(fin, info.offset))
|
||||
throw std::runtime_error("read tensor offset failed");
|
||||
|
||||
return info;
|
||||
}
|
||||
|
||||
public:
|
||||
bool load(const std::string& file_path) {
|
||||
std::ifstream fin(file_path, std::ios::binary);
|
||||
if (!fin) {
|
||||
LOG_ERROR("failed to open '%s'", file_path.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
// --- Header ---
|
||||
char magic[4];
|
||||
if (!safe_read(fin, magic, 4) || strncmp(magic, "GGUF", 4) != 0) {
|
||||
LOG_ERROR("not a valid GGUF file");
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t version;
|
||||
if (!safe_read(fin, version))
|
||||
return false;
|
||||
|
||||
uint64_t tensor_count, metadata_kv_count;
|
||||
if (!safe_read(fin, tensor_count))
|
||||
return false;
|
||||
if (!safe_read(fin, metadata_kv_count))
|
||||
return false;
|
||||
|
||||
LOG_DEBUG("GGUF v%u, tensor_count=%llu, metadata_kv_count=%llu",
|
||||
version, (unsigned long long)tensor_count, (unsigned long long)metadata_kv_count);
|
||||
|
||||
// --- Read Metadata ---
|
||||
for (uint64_t i = 0; i < metadata_kv_count; i++) {
|
||||
if (!read_metadata(fin)) {
|
||||
LOG_ERROR("read meta data failed");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tensor Infos ---
|
||||
tensors_.clear();
|
||||
try {
|
||||
for (uint64_t i = 0; i < tensor_count; i++) {
|
||||
tensors_.push_back(read_tensor_info(fin));
|
||||
}
|
||||
} catch (const std::runtime_error& e) {
|
||||
LOG_ERROR("%s", e.what());
|
||||
return false;
|
||||
}
|
||||
|
||||
data_offset_ = static_cast<size_t>(fin.tellg());
|
||||
if ((data_offset_ % alignment_) != 0) {
|
||||
data_offset_ = ((data_offset_ + alignment_ - 1) / alignment_) * alignment_;
|
||||
}
|
||||
fin.close();
|
||||
return true;
|
||||
}
|
||||
|
||||
const std::vector<GGUFTensorInfo>& tensors() const { return tensors_; }
|
||||
size_t data_offset() const { return data_offset_; }
|
||||
};
|
||||
|
||||
#endif // __GGUF_READER_HPP__
|
||||
32
model.cpp
32
model.cpp
@ -6,6 +6,7 @@
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "gguf_reader.hpp"
|
||||
#include "model.h"
|
||||
#include "stable-diffusion.h"
|
||||
#include "util.h"
|
||||
@ -1057,8 +1058,35 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s
|
||||
|
||||
ctx_gguf_ = gguf_init_from_file(file_path.c_str(), {true, &ctx_meta_});
|
||||
if (!ctx_gguf_) {
|
||||
LOG_ERROR("failed to open '%s'", file_path.c_str());
|
||||
return false;
|
||||
LOG_ERROR("failed to open '%s' with gguf_init_from_file. Try to open it with GGUFReader.", file_path.c_str());
|
||||
GGUFReader gguf_reader;
|
||||
if (!gguf_reader.load(file_path)) {
|
||||
LOG_ERROR("failed to open '%s' with GGUFReader.", file_path.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t data_offset = gguf_reader.data_offset();
|
||||
for (const auto& gguf_tensor_info : gguf_reader.tensors()) {
|
||||
std::string name = gguf_tensor_info.name;
|
||||
if (!starts_with(name, prefix)) {
|
||||
name = prefix + name;
|
||||
}
|
||||
|
||||
TensorStorage tensor_storage(
|
||||
name,
|
||||
gguf_tensor_info.type,
|
||||
gguf_tensor_info.shape.data(),
|
||||
gguf_tensor_info.shape.size(),
|
||||
file_index,
|
||||
data_offset + gguf_tensor_info.offset);
|
||||
|
||||
// LOG_DEBUG("%s %s", name.c_str(), tensor_storage.to_string().c_str());
|
||||
|
||||
tensor_storages.push_back(tensor_storage);
|
||||
add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int n_tensors = gguf_get_n_tensors(ctx_gguf_);
|
||||
|
||||
2
model.h
2
model.h
@ -123,7 +123,7 @@ struct TensorStorage {
|
||||
|
||||
TensorStorage() = default;
|
||||
|
||||
TensorStorage(const std::string& name, ggml_type type, int64_t* ne, int n_dims, size_t file_index, size_t offset = 0)
|
||||
TensorStorage(const std::string& name, ggml_type type, const int64_t* ne, int n_dims, size_t file_index, size_t offset = 0)
|
||||
: name(name), type(type), n_dims(n_dims), file_index(file_index), offset(offset) {
|
||||
for (int i = 0; i < n_dims; i++) {
|
||||
this->ne[i] = ne[i];
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user