From b3d56d0ba1bd437886079e339118e8e75bb79ee7 Mon Sep 17 00:00:00 2001 From: leejet Date: Sun, 7 Jun 2026 23:20:12 +0800 Subject: [PATCH] refactor: split model loader from model definitions (#1619) --- src/conditioning/conditioner.hpp | 1 + src/convert.cpp | 2 +- src/extensions/generation_extension.h | 2 +- src/model.h | 74 +------------------ src/model/adapter/lora.hpp | 1 + src/model/adapter/pmid.hpp | 1 + src/model/diffusion/control.hpp | 2 +- src/model/diffusion/flux.hpp | 2 +- src/model/diffusion/ltxv.hpp | 1 + src/model/diffusion/mmdit.hpp | 2 +- src/model/diffusion/qwen_image.hpp | 1 + src/model/diffusion/wan.hpp | 1 + src/model/diffusion/z_image.hpp | 1 + src/model/te/llm.hpp | 1 + src/model/te/t5.hpp | 2 +- src/model/upscaler/esrgan.hpp | 2 +- src/model/upscaler/ltx_latent_upscaler.hpp | 2 +- src/model/vae/ltx_audio_vae.hpp | 1 + src/model/vae/ltx_vae.hpp | 1 + src/model/vae/wan_vae.hpp | 1 + src/{model.cpp => model_loader.cpp} | 2 +- src/model_loader.h | 82 ++++++++++++++++++++++ src/name_conversion.cpp | 1 + src/stable-diffusion.cpp | 2 +- src/upscaler.cpp | 2 +- 25 files changed, 106 insertions(+), 84 deletions(-) rename src/{model.cpp => model_loader.cpp} (99%) create mode 100644 src/model_loader.h diff --git a/src/conditioning/conditioner.hpp b/src/conditioning/conditioner.hpp index 217658bf..0cb3172b 100644 --- a/src/conditioning/conditioner.hpp +++ b/src/conditioning/conditioner.hpp @@ -9,6 +9,7 @@ #include "model/te/clip.hpp" #include "model/te/llm.hpp" #include "model/te/t5.hpp" +#include "model_loader.h" struct SDCondition { sd::Tensor c_crossattn; diff --git a/src/convert.cpp b/src/convert.cpp index cc1cdd7e..5ad066c1 100644 --- a/src/convert.cpp +++ b/src/convert.cpp @@ -3,9 +3,9 @@ #include #include -#include "model.h" #include "model_io/gguf_io.h" #include "model_io/safetensors_io.h" +#include "model_loader.h" #include "util.h" #include "ggml_extend_backend.h" diff --git a/src/extensions/generation_extension.h b/src/extensions/generation_extension.h index 0c895b87..1e6d1341 100644 --- a/src/extensions/generation_extension.h +++ b/src/extensions/generation_extension.h @@ -9,7 +9,7 @@ #include "conditioning/conditioner.hpp" #include "core/ggml_extend_backend.h" -#include "model.h" +#include "model_loader.h" #include "stable-diffusion.h" struct GenerationExtensionInitContext { diff --git a/src/model.h b/src/model.h index 26451e07..d037705e 100644 --- a/src/model.h +++ b/src/model.h @@ -1,11 +1,8 @@ #ifndef __MODEL_H__ #define __MODEL_H__ -#include -#include -#include -#include #include +#include #include #include "core/ordered_map.hpp" @@ -238,73 +235,4 @@ enum PMVersion { typedef OrderedMap String2TensorStorage; using TensorTypeRules = std::vector>; -TensorTypeRules parse_tensor_type_rules(const std::string& tensor_type_rules); - -class MmapWrapper; - -struct ModelFileData { - std::string path; - std::vector tensors; - std::shared_ptr mmapped; - std::shared_ptr mmbuffer; - bool is_zip; -}; - -struct MmapTensorStore { - std::shared_ptr mmapped; - std::shared_ptr mmbuffer; -}; - -class ModelLoader { -protected: - SDVersion version_ = VERSION_COUNT; - std::vector file_paths_; - std::vector file_data; - bool model_files_processed = false; - String2TensorStorage tensor_storage_map; - - void add_tensor_storage(const TensorStorage& tensor_storage); - - bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = ""); - bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = ""); - bool init_from_torch_zip_file(const std::string& file_path, const std::string& prefix = ""); - bool init_from_torch_legacy_file(const std::string& file_path, const std::string& prefix = ""); - bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = ""); - -public: - bool init_from_file(const std::string& file_path, const std::string& prefix = ""); - void convert_tensors_name(); - bool init_from_file_and_convert_name(const std::string& file_path, - const std::string& prefix = "", - SDVersion version = VERSION_COUNT); - SDVersion get_sd_version(); - std::map get_wtype_stat(); - std::map get_conditioner_wtype_stat(); - std::map get_diffusion_model_wtype_stat(); - std::map get_vae_wtype_stat(); - String2TensorStorage& get_tensor_storage_map() { return tensor_storage_map; } - void set_wtype_override(ggml_type wtype, std::string tensor_type_rules = ""); - void process_model_files(bool enable_mmap = false, bool writable_mmap = true); - std::vector mmap_tensors(std::map& tensors, - std::set ignore_tensors = {}, - bool writable = true); - bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0, bool use_mmap = false); - bool load_tensors(std::map& tensors, - std::set ignore_tensors = {}, - int n_threads = 0, - bool use_mmap = false); - - std::vector get_tensor_names() const { - std::vector names; - for (const auto& [name, tensor_storage] : tensor_storage_map) { - names.push_back(name); - } - return names; - } - - bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type); - int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT); - ~ModelLoader() = default; -}; - #endif // __MODEL_H__ diff --git a/src/model/adapter/lora.hpp b/src/model/adapter/lora.hpp index 8132110f..7ab8a8de 100644 --- a/src/model/adapter/lora.hpp +++ b/src/model/adapter/lora.hpp @@ -3,6 +3,7 @@ #include #include "core/ggml_extend.hpp" +#include "model_loader.h" #define LORA_GRAPH_BASE_SIZE 10240 diff --git a/src/model/adapter/pmid.hpp b/src/model/adapter/pmid.hpp index 410ca726..3cf59a47 100644 --- a/src/model/adapter/pmid.hpp +++ b/src/model/adapter/pmid.hpp @@ -6,6 +6,7 @@ #include "model/adapter/lora.hpp" #include "model/common/block.hpp" #include "model/te/clip.hpp" +#include "model_loader.h" struct FuseBlock : public GGMLBlock { // network hparams diff --git a/src/model/diffusion/control.hpp b/src/model/diffusion/control.hpp index c442fb92..2f5eb574 100644 --- a/src/model/diffusion/control.hpp +++ b/src/model/diffusion/control.hpp @@ -1,8 +1,8 @@ #ifndef __SD_MODEL_DIFFUSION_CONTROL_HPP__ #define __SD_MODEL_DIFFUSION_CONTROL_HPP__ -#include "model.h" #include "model/common/block.hpp" +#include "model_loader.h" #define CONTROL_NET_GRAPH_SIZE 1536 diff --git a/src/model/diffusion/flux.hpp b/src/model/diffusion/flux.hpp index a337f99c..1d01041b 100644 --- a/src/model/diffusion/flux.hpp +++ b/src/model/diffusion/flux.hpp @@ -4,10 +4,10 @@ #include #include -#include "model.h" #include "model/common/rope.hpp" #include "model/diffusion/dit.hpp" #include "model/diffusion/model.hpp" +#include "model_loader.h" #define FLUX_GRAPH_SIZE 10240 diff --git a/src/model/diffusion/ltxv.hpp b/src/model/diffusion/ltxv.hpp index 947a2de3..a86b4cf5 100644 --- a/src/model/diffusion/ltxv.hpp +++ b/src/model/diffusion/ltxv.hpp @@ -13,6 +13,7 @@ #include "model/common/rope.hpp" #include "model/diffusion/flux.hpp" #include "model/diffusion/model.hpp" +#include "model_loader.h" namespace LTXV { diff --git a/src/model/diffusion/mmdit.hpp b/src/model/diffusion/mmdit.hpp index 3c234054..84433945 100644 --- a/src/model/diffusion/mmdit.hpp +++ b/src/model/diffusion/mmdit.hpp @@ -7,9 +7,9 @@ #include #include "core/ggml_extend.hpp" -#include "model.h" #include "model/common/block.hpp" #include "model/diffusion/model.hpp" +#include "model_loader.h" #define MMDIT_GRAPH_SIZE 10240 diff --git a/src/model/diffusion/qwen_image.hpp b/src/model/diffusion/qwen_image.hpp index e481ca59..678c3467 100644 --- a/src/model/diffusion/qwen_image.hpp +++ b/src/model/diffusion/qwen_image.hpp @@ -6,6 +6,7 @@ #include "model/common/block.hpp" #include "model/diffusion/flux.hpp" #include "model/diffusion/model.hpp" +#include "model_loader.h" namespace Qwen { constexpr int QWEN_IMAGE_GRAPH_SIZE = 20480; diff --git a/src/model/diffusion/wan.hpp b/src/model/diffusion/wan.hpp index f3410956..92f49dc2 100644 --- a/src/model/diffusion/wan.hpp +++ b/src/model/diffusion/wan.hpp @@ -9,6 +9,7 @@ #include "model/common/rope.hpp" #include "model/diffusion/flux.hpp" #include "model/diffusion/model.hpp" +#include "model_loader.h" namespace WAN { diff --git a/src/model/diffusion/z_image.hpp b/src/model/diffusion/z_image.hpp index 9eeec07f..c35c4495 100644 --- a/src/model/diffusion/z_image.hpp +++ b/src/model/diffusion/z_image.hpp @@ -7,6 +7,7 @@ #include "model/diffusion/flux.hpp" #include "model/diffusion/mmdit.hpp" #include "model/diffusion/model.hpp" +#include "model_loader.h" // Ref: https://github.com/Alpha-VLLM/Lumina-Image-2.0/blob/main/model/model.py // Ref: https://github.com/huggingface/diffusers/pull/12703 diff --git a/src/model/te/llm.hpp b/src/model/te/llm.hpp index 9e9bfe19..3a22e881 100644 --- a/src/model/te/llm.hpp +++ b/src/model/te/llm.hpp @@ -21,6 +21,7 @@ #include "core/ggml_extend.hpp" #include "json.hpp" #include "model/common/rope.hpp" +#include "model_loader.h" #include "tokenizers/bpe_tokenizer.h" #include "tokenizers/gemma_tokenizer.h" #include "tokenizers/gpt_oss_tokenizer.h" diff --git a/src/model/te/t5.hpp b/src/model/te/t5.hpp index 335969bb..41d9978e 100644 --- a/src/model/te/t5.hpp +++ b/src/model/te/t5.hpp @@ -11,7 +11,7 @@ #include #include "core/ggml_extend.hpp" -#include "model.h" +#include "model_loader.h" #include "tokenizers/t5_unigram_tokenizer.h" struct T5Config { diff --git a/src/model/upscaler/esrgan.hpp b/src/model/upscaler/esrgan.hpp index 9d9d0416..a56ebfe5 100644 --- a/src/model/upscaler/esrgan.hpp +++ b/src/model/upscaler/esrgan.hpp @@ -2,7 +2,7 @@ #define __SD_MODEL_UPSCALER_ESRGAN_HPP__ #include "core/ggml_extend.hpp" -#include "model.h" +#include "model_loader.h" /* =================================== ESRGAN =================================== diff --git a/src/model/upscaler/ltx_latent_upscaler.hpp b/src/model/upscaler/ltx_latent_upscaler.hpp index b608d6cd..b411e8aa 100644 --- a/src/model/upscaler/ltx_latent_upscaler.hpp +++ b/src/model/upscaler/ltx_latent_upscaler.hpp @@ -14,8 +14,8 @@ #include "core/ggml_extend.hpp" #include "core/ggml_graph_cut.h" #include "core/util.h" -#include "model.h" #include "model/diffusion/dit.hpp" +#include "model_loader.h" namespace LTXVUpsampler { constexpr int LTX_UPSAMPLER_GRAPH_SIZE = 10240; diff --git a/src/model/vae/ltx_audio_vae.hpp b/src/model/vae/ltx_audio_vae.hpp index 662c4122..d41a79a4 100644 --- a/src/model/vae/ltx_audio_vae.hpp +++ b/src/model/vae/ltx_audio_vae.hpp @@ -8,6 +8,7 @@ #include #include "core/ggml_extend.hpp" +#include "model_loader.h" namespace LTXV { diff --git a/src/model/vae/ltx_vae.hpp b/src/model/vae/ltx_vae.hpp index 17d1f9b2..86fcdcb0 100644 --- a/src/model/vae/ltx_vae.hpp +++ b/src/model/vae/ltx_vae.hpp @@ -12,6 +12,7 @@ #include "model/diffusion/ltxv.hpp" #include "model/vae/vae.hpp" #include "model/vae/wan_vae.hpp" +#include "model_loader.h" namespace LTXVAE { diff --git a/src/model/vae/wan_vae.hpp b/src/model/vae/wan_vae.hpp index 4bc21a84..c20764cd 100644 --- a/src/model/vae/wan_vae.hpp +++ b/src/model/vae/wan_vae.hpp @@ -7,6 +7,7 @@ #include "model/common/block.hpp" #include "model/vae/vae.hpp" +#include "model_loader.h" namespace WAN { diff --git a/src/model.cpp b/src/model_loader.cpp similarity index 99% rename from src/model.cpp rename to src/model_loader.cpp index 0953a827..9c2d5cef 100644 --- a/src/model.cpp +++ b/src/model_loader.cpp @@ -14,11 +14,11 @@ #include #include "core/util.h" -#include "model.h" #include "model_io/gguf_io.h" #include "model_io/safetensors_io.h" #include "model_io/torch_legacy_io.h" #include "model_io/torch_zip_io.h" +#include "model_loader.h" #include "stable-diffusion.h" #include "core/ggml_extend_backend.h" diff --git a/src/model_loader.h b/src/model_loader.h new file mode 100644 index 00000000..8e0f4198 --- /dev/null +++ b/src/model_loader.h @@ -0,0 +1,82 @@ +#ifndef __MODEL_LOADER_H__ +#define __MODEL_LOADER_H__ + +#include +#include +#include +#include +#include +#include + +#include "model.h" + +TensorTypeRules parse_tensor_type_rules(const std::string& tensor_type_rules); + +class MmapWrapper; + +struct ModelFileData { + std::string path; + std::vector tensors; + std::shared_ptr mmapped; + std::shared_ptr mmbuffer; + bool is_zip; +}; + +struct MmapTensorStore { + std::shared_ptr mmapped; + std::shared_ptr mmbuffer; +}; + +class ModelLoader { +protected: + SDVersion version_ = VERSION_COUNT; + std::vector file_paths_; + std::vector file_data; + bool model_files_processed = false; + String2TensorStorage tensor_storage_map; + + void add_tensor_storage(const TensorStorage& tensor_storage); + + bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = ""); + bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = ""); + bool init_from_torch_zip_file(const std::string& file_path, const std::string& prefix = ""); + bool init_from_torch_legacy_file(const std::string& file_path, const std::string& prefix = ""); + bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = ""); + +public: + bool init_from_file(const std::string& file_path, const std::string& prefix = ""); + void convert_tensors_name(); + bool init_from_file_and_convert_name(const std::string& file_path, + const std::string& prefix = "", + SDVersion version = VERSION_COUNT); + SDVersion get_sd_version(); + std::map get_wtype_stat(); + std::map get_conditioner_wtype_stat(); + std::map get_diffusion_model_wtype_stat(); + std::map get_vae_wtype_stat(); + String2TensorStorage& get_tensor_storage_map() { return tensor_storage_map; } + void set_wtype_override(ggml_type wtype, std::string tensor_type_rules = ""); + void process_model_files(bool enable_mmap = false, bool writable_mmap = true); + std::vector mmap_tensors(std::map& tensors, + std::set ignore_tensors = {}, + bool writable = true); + bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0, bool use_mmap = false); + bool load_tensors(std::map& tensors, + std::set ignore_tensors = {}, + int n_threads = 0, + bool use_mmap = false); + + std::vector get_tensor_names() const { + std::vector names; + for (const auto& [name, tensor_storage] : tensor_storage_map) { + names.push_back(name); + } + return names; + } + + bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type); + int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT); + ~ModelLoader() = default; +}; + +#endif // __MODEL_LOADER_H__ diff --git a/src/name_conversion.cpp b/src/name_conversion.cpp index 55628b0f..e316f8c4 100644 --- a/src/name_conversion.cpp +++ b/src/name_conversion.cpp @@ -1,3 +1,4 @@ +#include #include #include diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index 266618c4..8ba4a463 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -10,7 +10,7 @@ #include "core/rng_mt19937.hpp" #include "core/rng_philox.hpp" #include "core/util.h" -#include "model.h" +#include "model_loader.h" #include "stable-diffusion.h" #include "conditioning/conditioner.hpp" diff --git a/src/upscaler.cpp b/src/upscaler.cpp index 54c56850..8635f677 100644 --- a/src/upscaler.cpp +++ b/src/upscaler.cpp @@ -1,7 +1,7 @@ #include "upscaler.h" #include "core/ggml_extend.hpp" #include "core/util.h" -#include "model.h" +#include "model_loader.h" #include "stable-diffusion.h" #include