124 lines
3.5 KiB
C++

#include "gguf_io.h"
#include <cstdint>
#include <fstream>
#include <string>
#include <vector>
#include "gguf.h"
#include "gguf_reader_ext.h"
#include "util.h"
static void set_error(std::string* error, const std::string& message) {
if (error != nullptr) {
*error = message;
}
}
bool is_gguf_file(const std::string& file_path) {
std::ifstream file(file_path, std::ios::binary);
if (!file.is_open()) {
return false;
}
char magic[4];
file.read(magic, sizeof(magic));
if (!file) {
return false;
}
for (uint32_t i = 0; i < sizeof(magic); i++) {
if (magic[i] != GGUF_MAGIC[i]) {
return false;
}
}
return true;
}
bool read_gguf_file(const std::string& file_path,
std::vector<TensorStorage>& tensor_storages,
std::string* error) {
tensor_storages.clear();
gguf_context* ctx_gguf_ = nullptr;
ggml_context* ctx_meta_ = nullptr;
ctx_gguf_ = gguf_init_from_file(file_path.c_str(), {true, &ctx_meta_});
if (!ctx_gguf_) {
GGUFReader gguf_reader;
if (!gguf_reader.load(file_path)) {
set_error(error, "failed to open '" + file_path + "' with GGUFReader");
return false;
}
size_t data_offset = gguf_reader.data_offset();
for (const auto& gguf_tensor_info : gguf_reader.tensors()) {
TensorStorage tensor_storage(
gguf_tensor_info.name,
gguf_tensor_info.type,
gguf_tensor_info.shape.data(),
static_cast<int>(gguf_tensor_info.shape.size()),
0,
data_offset + gguf_tensor_info.offset);
tensor_storages.push_back(tensor_storage);
}
return true;
}
int n_tensors = static_cast<int>(gguf_get_n_tensors(ctx_gguf_));
size_t data_offset = gguf_get_data_offset(ctx_gguf_);
for (int i = 0; i < n_tensors; i++) {
std::string name = gguf_get_tensor_name(ctx_gguf_, i);
ggml_tensor* dummy = ggml_get_tensor(ctx_meta_, name.c_str());
size_t offset = data_offset + gguf_get_tensor_offset(ctx_gguf_, i);
TensorStorage tensor_storage(name, dummy->type, dummy->ne, ggml_n_dims(dummy), 0, offset);
if (ggml_nbytes(dummy) != tensor_storage.nbytes()) {
gguf_free(ctx_gguf_);
ggml_free(ctx_meta_);
set_error(error, "size mismatch for tensor '" + name + "'");
return false;
}
tensor_storages.push_back(tensor_storage);
}
gguf_free(ctx_gguf_);
ggml_free(ctx_meta_);
return true;
}
bool write_gguf_file(const std::string& file_path,
const std::vector<TensorWriteInfo>& tensors,
std::string* error) {
gguf_context* gguf_ctx = gguf_init_empty();
if (gguf_ctx == nullptr) {
set_error(error, "gguf_init_empty failed");
return false;
}
for (const TensorWriteInfo& write_tensor : tensors) {
ggml_tensor* tensor = write_tensor.tensor;
if (tensor == nullptr) {
set_error(error, "null tensor cannot be written to GGUF");
gguf_free(gguf_ctx);
return false;
}
gguf_add_tensor(gguf_ctx, tensor);
}
LOG_INFO("trying to save tensors to %s", file_path.c_str());
bool success = gguf_write_to_file(gguf_ctx, file_path.c_str(), false);
if (!success) {
set_error(error, "failed to write GGUF file '" + file_path + "'");
}
gguf_free(gguf_ctx);
return success;
}