#ifndef __MODEL_H__ #define __MODEL_H__ #include #include #include #include #include #include #include #include #include #include "ggml-backend.h" #include "ggml.h" #include "gguf.h" #include "json.hpp" #include "ordered_map.hpp" #include "zip.h" #define SD_MAX_DIMS 5 enum SDVersion { VERSION_SD1, VERSION_SD1_INPAINT, VERSION_SD1_PIX2PIX, VERSION_SD1_TINY_UNET, VERSION_SD2, VERSION_SD2_INPAINT, VERSION_SD2_TINY_UNET, VERSION_SDXL, VERSION_SDXL_INPAINT, VERSION_SDXL_PIX2PIX, VERSION_SDXL_SSD1B, VERSION_SVD, VERSION_SD3, VERSION_FLUX, VERSION_FLUX_FILL, VERSION_FLUX_CONTROLS, VERSION_FLEX_2, VERSION_CHROMA_RADIANCE, VERSION_WAN2, VERSION_WAN2_2_I2V, VERSION_WAN2_2_TI2V, VERSION_QWEN_IMAGE, VERSION_FLUX2, VERSION_Z_IMAGE, VERSION_OVIS_IMAGE, VERSION_COUNT, }; static inline bool sd_version_is_sd1(SDVersion version) { if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX || version == VERSION_SD1_TINY_UNET) { return true; } return false; } static inline bool sd_version_is_sd2(SDVersion version) { if (version == VERSION_SD2 || version == VERSION_SD2_INPAINT || version == VERSION_SD2_TINY_UNET) { return true; } return false; } static inline bool sd_version_is_sdxl(SDVersion version) { if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX || version == VERSION_SDXL_SSD1B) { return true; } return false; } static inline bool sd_version_is_unet(SDVersion version) { if (sd_version_is_sd1(version) || sd_version_is_sd2(version) || sd_version_is_sdxl(version)) { return true; } return false; } static inline bool sd_version_is_sd3(SDVersion version) { if (version == VERSION_SD3) { return true; } return false; } static inline bool sd_version_is_flux(SDVersion version) { if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2 || version == VERSION_OVIS_IMAGE || version == VERSION_CHROMA_RADIANCE) { return true; } return false; } static inline bool sd_version_is_flux2(SDVersion version) { if (version == VERSION_FLUX2) { return true; } return false; } static inline bool sd_version_is_wan(SDVersion version) { if (version == VERSION_WAN2 || version == VERSION_WAN2_2_I2V || version == VERSION_WAN2_2_TI2V) { return true; } return false; } static inline bool sd_version_is_qwen_image(SDVersion version) { if (version == VERSION_QWEN_IMAGE) { return true; } return false; } static inline bool sd_version_is_z_image(SDVersion version) { if (version == VERSION_Z_IMAGE) { return true; } return false; } static inline bool sd_version_is_inpaint(SDVersion version) { if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2) { return true; } return false; } static inline bool sd_version_is_dit(SDVersion version) { if (sd_version_is_flux(version) || sd_version_is_flux2(version) || sd_version_is_sd3(version) || sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_z_image(version)) { return true; } return false; } static inline bool sd_version_is_unet_edit(SDVersion version) { return version == VERSION_SD1_PIX2PIX || version == VERSION_SDXL_PIX2PIX; } static inline bool sd_version_is_control(SDVersion version) { return version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2; } static bool sd_version_is_inpaint_or_unet_edit(SDVersion version) { return sd_version_is_unet_edit(version) || sd_version_is_inpaint(version) || sd_version_is_control(version); } enum PMVersion { PM_VERSION_1, PM_VERSION_2, }; struct TensorStorage { std::string name; ggml_type type = GGML_TYPE_F32; ggml_type expected_type = GGML_TYPE_COUNT; bool is_f8_e4m3 = false; bool is_f8_e5m2 = false; bool is_f64 = false; bool is_i64 = false; int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; int n_dims = 0; size_t file_index = 0; int index_in_zip = -1; // >= means stored in a zip file uint64_t offset = 0; // offset in file TensorStorage() = default; TensorStorage(std::string name, ggml_type type, const int64_t* ne, int n_dims, size_t file_index, size_t offset = 0) : name(std::move(name)), type(type), n_dims(n_dims), file_index(file_index), offset(offset) { for (int i = 0; i < n_dims; i++) { this->ne[i] = ne[i]; } } int64_t nelements() const { int64_t n = 1; for (int i = 0; i < SD_MAX_DIMS; i++) { n *= ne[i]; } return n; } int64_t nbytes() const { return nelements() * ggml_type_size(type) / ggml_blck_size(type); } int64_t nbytes_to_read() const { if (is_f8_e4m3 || is_f8_e5m2) { return nbytes() / 2; } else if (is_f64 || is_i64) { return nbytes() * 2; } else { return nbytes(); } } void unsqueeze() { if (n_dims == 2) { n_dims = 4; ne[3] = ne[1]; ne[2] = ne[0]; ne[1] = 1; ne[0] = 1; } } std::vector chunk(size_t n) { std::vector chunks; uint64_t chunk_size = nbytes_to_read() / n; // printf("%d/%d\n", chunk_size, nbytes_to_read()); reverse_ne(); for (size_t i = 0; i < n; i++) { TensorStorage chunk_i = *this; chunk_i.ne[0] = ne[0] / n; chunk_i.offset = offset + i * chunk_size; chunk_i.reverse_ne(); chunks.push_back(chunk_i); } reverse_ne(); return chunks; } void reverse_ne() { int64_t new_ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; for (int i = 0; i < n_dims; i++) { new_ne[i] = ne[n_dims - 1 - i]; } for (int i = 0; i < n_dims; i++) { ne[i] = new_ne[i]; } } std::string to_string() const { std::stringstream ss; const char* type_name = ggml_type_name(type); if (is_f8_e4m3) { type_name = "f8_e4m3"; } else if (is_f8_e5m2) { type_name = "f8_e5m2"; } else if (is_f64) { type_name = "f64"; } else if (is_i64) { type_name = "i64"; } ss << name << " | " << type_name << " | "; ss << n_dims << " ["; for (int i = 0; i < SD_MAX_DIMS; i++) { ss << ne[i]; if (i != SD_MAX_DIMS - 1) { ss << ", "; } } ss << "]"; return ss.str(); } }; typedef std::function on_new_tensor_cb_t; typedef OrderedMap String2TensorStorage; class ModelLoader { protected: SDVersion version_ = VERSION_COUNT; std::vector file_paths_; String2TensorStorage tensor_storage_map; void add_tensor_storage(const TensorStorage& tensor_storage); bool parse_data_pkl(uint8_t* buffer, size_t buffer_size, zip_t* zip, std::string dir, size_t file_index, const std::string prefix); bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = ""); bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = ""); bool init_from_ckpt_file(const std::string& file_path, const std::string& prefix = ""); bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = ""); public: bool init_from_file(const std::string& file_path, const std::string& prefix = ""); void convert_tensors_name(); bool init_from_file_and_convert_name(const std::string& file_path, const std::string& prefix = "", SDVersion version = VERSION_COUNT); SDVersion get_sd_version(); std::map get_wtype_stat(); std::map get_conditioner_wtype_stat(); std::map get_diffusion_model_wtype_stat(); std::map get_vae_wtype_stat(); String2TensorStorage& get_tensor_storage_map() { return tensor_storage_map; } void set_wtype_override(ggml_type wtype, std::string tensor_type_rules = ""); bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0); bool load_tensors(std::map& tensors, std::set ignore_tensors = {}, int n_threads = 0); std::vector get_tensor_names() const { std::vector names; for (const auto& [name, tensor_storage] : tensor_storage_map) { names.push_back(name); } return names; } bool save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules); bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type); int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT); ~ModelLoader() = default; static std::string load_merges(); static std::string load_qwen2_merges(); static std::string load_mistral_merges(); static std::string load_mistral_vocab_json(); static std::string load_t5_tokenizer_json(); static std::string load_umt5_tokenizer_json(); }; #endif // __MODEL_H__