// Declarations for ModelLoader, which reads model tensor metadata from one or
// more files and loads tensor data, plus the small record types it uses.
//
// NOTE(review): identifiers containing a double underscore (such as this
// include guard) are reserved for the implementation in C++; consider
// renaming the guard (e.g. MODEL_LOADER_H) or switching to #pragma once.
#ifndef __MODEL_LOADER_H__
#define __MODEL_LOADER_H__

// NOTE(review): in this copy the targets of the following #include directives,
// and most template argument lists further down (std::vector, std::shared_ptr,
// std::map, std::set), are missing -- presumably stripped during extraction.
// Restore them from the upstream header before building.
#include
#include
#include
#include
#include
#include
#include "model.h"

// Builds a TensorTypeRules value from the textual rule specification passed in
// tensor_type_rules. Definition lives outside this header; the rule syntax is
// not visible here.
TensorTypeRules parse_tensor_type_rules(const std::string& tensor_type_rules);

// Forward declaration only; this header refers to MmapWrapper without needing
// its definition.
class MmapWrapper;

// Per-file record: the file path, the tensors found in it, optional mmap
// handles, and whether the file is a zip archive.
// NOTE(review): presumably this is the element type of ModelLoader::file_data;
// the template argument is missing in this copy -- confirm.
struct ModelFileData {
    std::string path;
    std::vector tensors;
    // Shared handles so several owners can keep a mapping alive; their pointee
    // types are not visible in this copy.
    std::shared_ptr mmapped;
    std::shared_ptr mmbuffer;
    bool is_zip;
};

// Holds the same pair of shared mmap handles as ModelFileData, without the
// per-file metadata. NOTE(review): presumably returned by
// ModelLoader::mmap_tensors so mapped memory outlives the loader -- confirm.
struct MmapTensorStore {
    std::shared_ptr mmapped;
    std::shared_ptr mmbuffer;
};

// Collects tensor metadata from model files into tensor_storage_map, reports
// per-type statistics, and loads tensor data (optionally via mmap).
class ModelLoader {
protected:
    // VERSION_COUNT is used as the "not yet determined" default.
    SDVersion version_ = VERSION_COUNT;
    std::vector file_paths_;
    std::vector file_data;
    // Starts false; presumably set by process_model_files() -- verify in .cpp.
    bool model_files_processed = false;
    String2TensorStorage tensor_storage_map;

    // Registers one tensor's metadata in tensor_storage_map.
    void add_tensor_storage(const TensorStorage& tensor_storage);

    // Format-specific readers used by init_from_file(). Each takes the file
    // path and an optional name prefix (empty by default) and returns a
    // success flag.
    bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = "");
    bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = "");
    bool init_from_torch_zip_file(const std::string& file_path, const std::string& prefix = "");
    bool init_from_torch_legacy_file(const std::string& file_path, const std::string& prefix = "");
    bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = "");

public:
    // Reads tensor metadata from file_path; prefix is presumably prepended to
    // tensor names. Returns false on failure.
    bool init_from_file(const std::string& file_path, const std::string& prefix = "");
    // Rewrites the keys of tensor_storage_map into the loader's canonical naming.
    void convert_tensors_name();
    // Combination of init_from_file() and convert_tensors_name(); version
    // defaults to VERSION_COUNT (i.e. not specified by the caller).
    bool init_from_file_and_convert_name(const std::string& file_path, const std::string& prefix = "", SDVersion version = VERSION_COUNT);
    SDVersion get_sd_version();

    // Per-weight-type statistics, for all tensors and for the conditioner,
    // diffusion-model and VAE subsets respectively.
    std::map get_wtype_stat();
    std::map get_conditioner_wtype_stat();
    std::map get_diffusion_model_wtype_stat();
    std::map get_vae_wtype_stat();

    // Returns a mutable reference to the internal map; callers may modify it
    // in place.
    String2TensorStorage& get_tensor_storage_map() { return tensor_storage_map; }

    // Overrides tensor weight types with wtype, optionally filtered or
    // refined by tensor_type_rules (empty by default).
    void set_wtype_override(ggml_type wtype, std::string tensor_type_rules = "");

    // Both flags default as shown: mmap disabled, writable mapping requested.
    void process_model_files(bool enable_mmap = false, bool writable_mmap = true);

    // Maps the requested tensors' data, skipping names in ignore_tensors.
    std::vector mmap_tensors(std::map& tensors, std::set ignore_tensors = {}, bool writable = true);

    // Loads tensor data, handing each tensor to on_new_tensor_cb.
    // NOTE(review): the meaning of n_threads == 0 (presumably "auto") is
    // defined in the .cpp -- confirm.
    bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0, bool use_mmap = false);
    // Loads data into the caller-provided tensors, skipping ignore_tensors.
    bool load_tensors(std::map& tensors, std::set ignore_tensors = {}, int n_threads = 0, bool use_mmap = false);

    // Returns the keys of tensor_storage_map in the map's iteration order.
    std::vector get_tensor_names() const {
        std::vector names;
        for (const auto& [name, tensor_storage] : tensor_storage_map) {
            names.push_back(name);
        }
        return names;
    }

    bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type);
    // type defaults to GGML_TYPE_COUNT, presumably meaning "keep each
    // tensor's own type" -- confirm against the .cpp.
    int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);

    // Non-virtual; this header declares no virtual member functions.
    ~ModelLoader() = default;
};

#endif // __MODEL_LOADER_H__