refactor: split model loader from model definitions (#1619)

This commit is contained in:
leejet 2026-06-07 23:20:12 +08:00 committed by GitHub
parent 2a07540c2a
commit b3d56d0ba1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 106 additions and 84 deletions

View File

@ -9,6 +9,7 @@
#include "model/te/clip.hpp"
#include "model/te/llm.hpp"
#include "model/te/t5.hpp"
#include "model_loader.h"
struct SDCondition {
sd::Tensor<float> c_crossattn;

View File

@ -3,9 +3,9 @@
#include <regex>
#include <vector>
#include "model.h"
#include "model_io/gguf_io.h"
#include "model_io/safetensors_io.h"
#include "model_loader.h"
#include "util.h"
#include "ggml_extend_backend.h"

View File

@ -9,7 +9,7 @@
#include "conditioning/conditioner.hpp"
#include "core/ggml_extend_backend.h"
#include "model.h"
#include "model_loader.h"
#include "stable-diffusion.h"
struct GenerationExtensionInitContext {

View File

@ -1,11 +1,8 @@
#ifndef __MODEL_H__
#define __MODEL_H__
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "core/ordered_map.hpp"
@ -238,73 +235,4 @@ enum PMVersion {
// Ordered map type keyed by std::string with TensorStorage values.
using String2TensorStorage = OrderedMap<std::string, TensorStorage>;

// List of (string, ggml_type) pairs.
// NOTE(review): the string is presumably a tensor-name pattern — confirm against the implementation.
using TensorTypeRules = std::vector<std::pair<std::string, ggml_type>>;

// Builds a TensorTypeRules list from its textual form.
// NOTE(review): the accepted rule syntax is defined by the implementation, which is not visible here.
TensorTypeRules parse_tensor_type_rules(const std::string& tensor_type_rules);
class MmapWrapper;
// Per-file record; ModelLoader keeps one of these per entry in its file_data vector.
struct ModelFileData {
    // Path of the model file this record describes.
    std::string path;
    // Tensor descriptors associated with this file.
    std::vector<TensorStorage> tensors;
    // NOTE(review): presumably the file's memory mapping and the backend buffer
    // wrapping it, both null when the file is not mapped — confirm.
    std::shared_ptr<MmapWrapper> mmapped;
    std::shared_ptr<struct ggml_backend_buffer> mmbuffer;
    // Initialized so that a default-constructed record never carries an
    // indeterminate flag.
    // NOTE(review): presumably true for torch zip checkpoints — confirm against the loader.
    bool is_zip = false;
};
// Pair of shared handles; ModelLoader::mmap_tensors() returns a vector of these.
// NOTE(review): presumably keeps a file mapping and its backend buffer alive
// while tensors still point into them — confirm.
struct MmapTensorStore {
std::shared_ptr<MmapWrapper> mmapped;
std::shared_ptr<struct ggml_backend_buffer> mmbuffer;
};
// Gathers tensor descriptors from model files into tensor_storage_map and
// exposes entry points for loading the tensor data.
// Only declarations are visible in this header; notes that go beyond what the
// signatures show are marked as assumptions to confirm.
class ModelLoader {
protected:
// Starts as VERSION_COUNT.
// NOTE(review): presumably the "not detected yet" sentinel — confirm.
SDVersion version_ = VERSION_COUNT;
std::vector<std::string> file_paths_;
std::vector<ModelFileData> file_data;
// NOTE(review): presumably set once process_model_files() has run — confirm.
bool model_files_processed = false;
// Tensor name -> TensorStorage; exposed through get_tensor_storage_map().
String2TensorStorage tensor_storage_map;
void add_tensor_storage(const TensorStorage& tensor_storage);
// Format-specific initialisers. Each takes a file path and an optional prefix
// and reports success as a bool.
// NOTE(review): prefix is presumably prepended to tensor names — confirm.
bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = "");
bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = "");
bool init_from_torch_zip_file(const std::string& file_path, const std::string& prefix = "");
bool init_from_torch_legacy_file(const std::string& file_path, const std::string& prefix = "");
bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = "");
public:
// NOTE(review): presumably dispatches to one of the init_from_*_file helpers
// above based on the file's format — confirm.
bool init_from_file(const std::string& file_path, const std::string& prefix = "");
void convert_tensors_name();
// NOTE(review): by its name, init_from_file() followed by convert_tensors_name();
// how `version` is used is not visible here — confirm.
bool init_from_file_and_convert_name(const std::string& file_path,
const std::string& prefix = "",
SDVersion version = VERSION_COUNT);
SDVersion get_sd_version();
// Each *_wtype_stat() returns a ggml_type -> uint32_t map.
// NOTE(review): presumably a per-type tensor count, restricted to the named
// sub-model for the prefixed variants — confirm.
std::map<ggml_type, uint32_t> get_wtype_stat();
std::map<ggml_type, uint32_t> get_conditioner_wtype_stat();
std::map<ggml_type, uint32_t> get_diffusion_model_wtype_stat();
std::map<ggml_type, uint32_t> get_vae_wtype_stat();
// Returns a mutable reference to the internal map, so callers can edit it in place.
String2TensorStorage& get_tensor_storage_map() { return tensor_storage_map; }
void set_wtype_override(ggml_type wtype, std::string tensor_type_rules = "");
void process_model_files(bool enable_mmap = false, bool writable_mmap = true);
// Takes the name -> ggml_tensor* map by non-const reference plus a set of
// names to ignore; returns one MmapTensorStore per entry of the result vector.
std::vector<MmapTensorStore> mmap_tensors(std::map<std::string, ggml_tensor*>& tensors,
std::set<std::string> ignore_tensors = {},
bool writable = true);
// Callback-driven overload: the caller supplies on_new_tensor_cb.
// NOTE(review): n_threads = 0 presumably selects a default thread count — confirm.
bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0, bool use_mmap = false);
// Map-driven overload: takes a name -> ggml_tensor* map and a set of names to ignore.
bool load_tensors(std::map<std::string, ggml_tensor*>& tensors,
std::set<std::string> ignore_tensors = {},
int n_threads = 0,
bool use_mmap = false);
// Returns a copy of every key in tensor_storage_map, in the map's iteration order.
std::vector<std::string> get_tensor_names() const {
std::vector<std::string> names;
for (const auto& [name, tensor_storage] : tensor_storage_map) {
names.push_back(name);
}
return names;
}
bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type);
// NOTE(review): the GGML_TYPE_COUNT default presumably means "no type override" — confirm.
int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);
// Defaulted and non-virtual.
// NOTE(review): the protected members suggest subclassing; if instances are
// ever deleted through a ModelLoader*, this would need to be virtual — confirm.
~ModelLoader() = default;
};
#endif // __MODEL_H__

View File

@ -3,6 +3,7 @@
#include <mutex>
#include "core/ggml_extend.hpp"
#include "model_loader.h"
#define LORA_GRAPH_BASE_SIZE 10240

View File

@ -6,6 +6,7 @@
#include "model/adapter/lora.hpp"
#include "model/common/block.hpp"
#include "model/te/clip.hpp"
#include "model_loader.h"
struct FuseBlock : public GGMLBlock {
// network hparams

View File

@ -1,8 +1,8 @@
#ifndef __SD_MODEL_DIFFUSION_CONTROL_HPP__
#define __SD_MODEL_DIFFUSION_CONTROL_HPP__
#include "model.h"
#include "model/common/block.hpp"
#include "model_loader.h"
#define CONTROL_NET_GRAPH_SIZE 1536

View File

@ -4,10 +4,10 @@
#include <memory>
#include <vector>
#include "model.h"
#include "model/common/rope.hpp"
#include "model/diffusion/dit.hpp"
#include "model/diffusion/model.hpp"
#include "model_loader.h"
#define FLUX_GRAPH_SIZE 10240

View File

@ -13,6 +13,7 @@
#include "model/common/rope.hpp"
#include "model/diffusion/flux.hpp"
#include "model/diffusion/model.hpp"
#include "model_loader.h"
namespace LTXV {

View File

@ -7,9 +7,9 @@
#include <vector>
#include "core/ggml_extend.hpp"
#include "model.h"
#include "model/common/block.hpp"
#include "model/diffusion/model.hpp"
#include "model_loader.h"
#define MMDIT_GRAPH_SIZE 10240

View File

@ -6,6 +6,7 @@
#include "model/common/block.hpp"
#include "model/diffusion/flux.hpp"
#include "model/diffusion/model.hpp"
#include "model_loader.h"
namespace Qwen {
constexpr int QWEN_IMAGE_GRAPH_SIZE = 20480;

View File

@ -9,6 +9,7 @@
#include "model/common/rope.hpp"
#include "model/diffusion/flux.hpp"
#include "model/diffusion/model.hpp"
#include "model_loader.h"
namespace WAN {

View File

@ -7,6 +7,7 @@
#include "model/diffusion/flux.hpp"
#include "model/diffusion/mmdit.hpp"
#include "model/diffusion/model.hpp"
#include "model_loader.h"
// Ref: https://github.com/Alpha-VLLM/Lumina-Image-2.0/blob/main/model/model.py
// Ref: https://github.com/huggingface/diffusers/pull/12703

View File

@ -21,6 +21,7 @@
#include "core/ggml_extend.hpp"
#include "json.hpp"
#include "model/common/rope.hpp"
#include "model_loader.h"
#include "tokenizers/bpe_tokenizer.h"
#include "tokenizers/gemma_tokenizer.h"
#include "tokenizers/gpt_oss_tokenizer.h"

View File

@ -11,7 +11,7 @@
#include <unordered_map>
#include "core/ggml_extend.hpp"
#include "model.h"
#include "model_loader.h"
#include "tokenizers/t5_unigram_tokenizer.h"
struct T5Config {

View File

@ -2,7 +2,7 @@
#define __SD_MODEL_UPSCALER_ESRGAN_HPP__
#include "core/ggml_extend.hpp"
#include "model.h"
#include "model_loader.h"
/*
=================================== ESRGAN ===================================

View File

@ -14,8 +14,8 @@
#include "core/ggml_extend.hpp"
#include "core/ggml_graph_cut.h"
#include "core/util.h"
#include "model.h"
#include "model/diffusion/dit.hpp"
#include "model_loader.h"
namespace LTXVUpsampler {
constexpr int LTX_UPSAMPLER_GRAPH_SIZE = 10240;

View File

@ -8,6 +8,7 @@
#include <vector>
#include "core/ggml_extend.hpp"
#include "model_loader.h"
namespace LTXV {

View File

@ -12,6 +12,7 @@
#include "model/diffusion/ltxv.hpp"
#include "model/vae/vae.hpp"
#include "model/vae/wan_vae.hpp"
#include "model_loader.h"
namespace LTXVAE {

View File

@ -7,6 +7,7 @@
#include "model/common/block.hpp"
#include "model/vae/vae.hpp"
#include "model_loader.h"
namespace WAN {

View File

@ -14,11 +14,11 @@
#include <vector>
#include "core/util.h"
#include "model.h"
#include "model_io/gguf_io.h"
#include "model_io/safetensors_io.h"
#include "model_io/torch_legacy_io.h"
#include "model_io/torch_zip_io.h"
#include "model_loader.h"
#include "stable-diffusion.h"
#include "core/ggml_extend_backend.h"

82
src/model_loader.h Normal file
View File

@ -0,0 +1,82 @@
#ifndef __MODEL_LOADER_H__
#define __MODEL_LOADER_H__
#include <cstdint>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "model.h"
// Builds a TensorTypeRules list from its textual form.
// NOTE(review): the accepted rule syntax is defined by the implementation, which is not visible here.
TensorTypeRules parse_tensor_type_rules(const std::string& tensor_type_rules);
class MmapWrapper;
// Per-file record; ModelLoader keeps one of these per entry in its file_data vector.
struct ModelFileData {
    // Path of the model file this record describes.
    std::string path;
    // Tensor descriptors associated with this file.
    std::vector<TensorStorage> tensors;
    // NOTE(review): presumably the file's memory mapping and the backend buffer
    // wrapping it, both null when the file is not mapped — confirm.
    std::shared_ptr<MmapWrapper> mmapped;
    std::shared_ptr<struct ggml_backend_buffer> mmbuffer;
    // Initialized so that a default-constructed record never carries an
    // indeterminate flag.
    // NOTE(review): presumably true for torch zip checkpoints — confirm against the loader.
    bool is_zip = false;
};
// Pair of shared handles; ModelLoader::mmap_tensors() returns a vector of these.
// NOTE(review): presumably keeps a file mapping and its backend buffer alive
// while tensors still point into them — confirm.
struct MmapTensorStore {
std::shared_ptr<MmapWrapper> mmapped;
std::shared_ptr<struct ggml_backend_buffer> mmbuffer;
};
// Gathers tensor descriptors from model files into tensor_storage_map and
// exposes entry points for loading the tensor data.
// Only declarations are visible in this header; notes that go beyond what the
// signatures show are marked as assumptions to confirm.
class ModelLoader {
protected:
// Starts as VERSION_COUNT.
// NOTE(review): presumably the "not detected yet" sentinel — confirm.
SDVersion version_ = VERSION_COUNT;
std::vector<std::string> file_paths_;
std::vector<ModelFileData> file_data;
// NOTE(review): presumably set once process_model_files() has run — confirm.
bool model_files_processed = false;
// Tensor name -> TensorStorage; exposed through get_tensor_storage_map().
String2TensorStorage tensor_storage_map;
void add_tensor_storage(const TensorStorage& tensor_storage);
// Format-specific initialisers. Each takes a file path and an optional prefix
// and reports success as a bool.
// NOTE(review): prefix is presumably prepended to tensor names — confirm.
bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = "");
bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = "");
bool init_from_torch_zip_file(const std::string& file_path, const std::string& prefix = "");
bool init_from_torch_legacy_file(const std::string& file_path, const std::string& prefix = "");
bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = "");
public:
// NOTE(review): presumably dispatches to one of the init_from_*_file helpers
// above based on the file's format — confirm.
bool init_from_file(const std::string& file_path, const std::string& prefix = "");
void convert_tensors_name();
// NOTE(review): by its name, init_from_file() followed by convert_tensors_name();
// how `version` is used is not visible here — confirm.
bool init_from_file_and_convert_name(const std::string& file_path,
const std::string& prefix = "",
SDVersion version = VERSION_COUNT);
SDVersion get_sd_version();
// Each *_wtype_stat() returns a ggml_type -> uint32_t map.
// NOTE(review): presumably a per-type tensor count, restricted to the named
// sub-model for the prefixed variants — confirm.
std::map<ggml_type, uint32_t> get_wtype_stat();
std::map<ggml_type, uint32_t> get_conditioner_wtype_stat();
std::map<ggml_type, uint32_t> get_diffusion_model_wtype_stat();
std::map<ggml_type, uint32_t> get_vae_wtype_stat();
// Returns a mutable reference to the internal map, so callers can edit it in place.
String2TensorStorage& get_tensor_storage_map() { return tensor_storage_map; }
void set_wtype_override(ggml_type wtype, std::string tensor_type_rules = "");
void process_model_files(bool enable_mmap = false, bool writable_mmap = true);
// Takes the name -> ggml_tensor* map by non-const reference plus a set of
// names to ignore; returns one MmapTensorStore per entry of the result vector.
std::vector<MmapTensorStore> mmap_tensors(std::map<std::string, ggml_tensor*>& tensors,
std::set<std::string> ignore_tensors = {},
bool writable = true);
// Callback-driven overload: the caller supplies on_new_tensor_cb.
// NOTE(review): n_threads = 0 presumably selects a default thread count — confirm.
bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0, bool use_mmap = false);
// Map-driven overload: takes a name -> ggml_tensor* map and a set of names to ignore.
bool load_tensors(std::map<std::string, ggml_tensor*>& tensors,
std::set<std::string> ignore_tensors = {},
int n_threads = 0,
bool use_mmap = false);
// Returns a copy of every key in tensor_storage_map, in the map's iteration order.
std::vector<std::string> get_tensor_names() const {
std::vector<std::string> names;
for (const auto& [name, tensor_storage] : tensor_storage_map) {
names.push_back(name);
}
return names;
}
bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type);
// NOTE(review): the GGML_TYPE_COUNT default presumably means "no type override" — confirm.
int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);
// Defaulted and non-virtual.
// NOTE(review): the protected members suggest subclassing; if instances are
// ever deleted through a ModelLoader*, this would need to be virtual — confirm.
~ModelLoader() = default;
};
#endif // __MODEL_LOADER_H__

View File

@ -1,3 +1,4 @@
#include <map>
#include <unordered_map>
#include <unordered_set>

View File

@ -10,7 +10,7 @@
#include "core/rng_mt19937.hpp"
#include "core/rng_philox.hpp"
#include "core/util.h"
#include "model.h"
#include "model_loader.h"
#include "stable-diffusion.h"
#include "conditioning/conditioner.hpp"

View File

@ -1,7 +1,7 @@
#include "upscaler.h"
#include "core/ggml_extend.hpp"
#include "core/util.h"
#include "model.h"
#include "model_loader.h"
#include "stable-diffusion.h"
#include <utility>