mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-09 15:56:39 +00:00
refactor: split model loader from model definitions (#1619)
This commit is contained in:
parent
2a07540c2a
commit
b3d56d0ba1
@ -9,6 +9,7 @@
|
||||
#include "model/te/clip.hpp"
|
||||
#include "model/te/llm.hpp"
|
||||
#include "model/te/t5.hpp"
|
||||
#include "model_loader.h"
|
||||
|
||||
struct SDCondition {
|
||||
sd::Tensor<float> c_crossattn;
|
||||
|
||||
@ -3,9 +3,9 @@
|
||||
#include <regex>
|
||||
#include <vector>
|
||||
|
||||
#include "model.h"
|
||||
#include "model_io/gguf_io.h"
|
||||
#include "model_io/safetensors_io.h"
|
||||
#include "model_loader.h"
|
||||
#include "util.h"
|
||||
|
||||
#include "ggml_extend_backend.h"
|
||||
|
||||
@ -9,7 +9,7 @@
|
||||
|
||||
#include "conditioning/conditioner.hpp"
|
||||
#include "core/ggml_extend_backend.h"
|
||||
#include "model.h"
|
||||
#include "model_loader.h"
|
||||
#include "stable-diffusion.h"
|
||||
|
||||
struct GenerationExtensionInitContext {
|
||||
|
||||
74
src/model.h
74
src/model.h
@ -1,11 +1,8 @@
|
||||
#ifndef __MODEL_H__
|
||||
#define __MODEL_H__
|
||||
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "core/ordered_map.hpp"
|
||||
@ -238,73 +235,4 @@ enum PMVersion {
|
||||
typedef OrderedMap<std::string, TensorStorage> String2TensorStorage;
|
||||
using TensorTypeRules = std::vector<std::pair<std::string, ggml_type>>;
|
||||
|
||||
// Builds a TensorTypeRules value (a vector of (std::string, ggml_type) pairs)
// from the given string.
// NOTE(review): the accepted syntax of `tensor_type_rules` is defined by the
// implementation, which is not visible here — confirm before relying on it.
TensorTypeRules parse_tensor_type_rules(const std::string& tensor_type_rules);
|
||||
|
||||
class MmapWrapper;
|
||||
|
||||
struct ModelFileData {
|
||||
std::string path;
|
||||
std::vector<TensorStorage> tensors;
|
||||
std::shared_ptr<MmapWrapper> mmapped;
|
||||
std::shared_ptr<struct ggml_backend_buffer> mmbuffer;
|
||||
bool is_zip;
|
||||
};
|
||||
|
||||
struct MmapTensorStore {
|
||||
std::shared_ptr<MmapWrapper> mmapped;
|
||||
std::shared_ptr<struct ggml_backend_buffer> mmbuffer;
|
||||
};
|
||||
|
||||
class ModelLoader {
|
||||
protected:
|
||||
SDVersion version_ = VERSION_COUNT;
|
||||
std::vector<std::string> file_paths_;
|
||||
std::vector<ModelFileData> file_data;
|
||||
bool model_files_processed = false;
|
||||
String2TensorStorage tensor_storage_map;
|
||||
|
||||
void add_tensor_storage(const TensorStorage& tensor_storage);
|
||||
|
||||
bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = "");
|
||||
bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = "");
|
||||
bool init_from_torch_zip_file(const std::string& file_path, const std::string& prefix = "");
|
||||
bool init_from_torch_legacy_file(const std::string& file_path, const std::string& prefix = "");
|
||||
bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = "");
|
||||
|
||||
public:
|
||||
bool init_from_file(const std::string& file_path, const std::string& prefix = "");
|
||||
void convert_tensors_name();
|
||||
bool init_from_file_and_convert_name(const std::string& file_path,
|
||||
const std::string& prefix = "",
|
||||
SDVersion version = VERSION_COUNT);
|
||||
SDVersion get_sd_version();
|
||||
std::map<ggml_type, uint32_t> get_wtype_stat();
|
||||
std::map<ggml_type, uint32_t> get_conditioner_wtype_stat();
|
||||
std::map<ggml_type, uint32_t> get_diffusion_model_wtype_stat();
|
||||
std::map<ggml_type, uint32_t> get_vae_wtype_stat();
|
||||
String2TensorStorage& get_tensor_storage_map() { return tensor_storage_map; }
|
||||
void set_wtype_override(ggml_type wtype, std::string tensor_type_rules = "");
|
||||
void process_model_files(bool enable_mmap = false, bool writable_mmap = true);
|
||||
std::vector<MmapTensorStore> mmap_tensors(std::map<std::string, ggml_tensor*>& tensors,
|
||||
std::set<std::string> ignore_tensors = {},
|
||||
bool writable = true);
|
||||
bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0, bool use_mmap = false);
|
||||
bool load_tensors(std::map<std::string, ggml_tensor*>& tensors,
|
||||
std::set<std::string> ignore_tensors = {},
|
||||
int n_threads = 0,
|
||||
bool use_mmap = false);
|
||||
|
||||
std::vector<std::string> get_tensor_names() const {
|
||||
std::vector<std::string> names;
|
||||
for (const auto& [name, tensor_storage] : tensor_storage_map) {
|
||||
names.push_back(name);
|
||||
}
|
||||
return names;
|
||||
}
|
||||
|
||||
bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type);
|
||||
int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);
|
||||
~ModelLoader() = default;
|
||||
};
|
||||
|
||||
#endif // __MODEL_H__
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
|
||||
#include <mutex>
|
||||
#include "core/ggml_extend.hpp"
|
||||
#include "model_loader.h"
|
||||
|
||||
#define LORA_GRAPH_BASE_SIZE 10240
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@
|
||||
#include "model/adapter/lora.hpp"
|
||||
#include "model/common/block.hpp"
|
||||
#include "model/te/clip.hpp"
|
||||
#include "model_loader.h"
|
||||
|
||||
struct FuseBlock : public GGMLBlock {
|
||||
// network hparams
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
#ifndef __SD_MODEL_DIFFUSION_CONTROL_HPP__
|
||||
#define __SD_MODEL_DIFFUSION_CONTROL_HPP__
|
||||
|
||||
#include "model.h"
|
||||
#include "model/common/block.hpp"
|
||||
#include "model_loader.h"
|
||||
|
||||
#define CONTROL_NET_GRAPH_SIZE 1536
|
||||
|
||||
|
||||
@ -4,10 +4,10 @@
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "model.h"
|
||||
#include "model/common/rope.hpp"
|
||||
#include "model/diffusion/dit.hpp"
|
||||
#include "model/diffusion/model.hpp"
|
||||
#include "model_loader.h"
|
||||
|
||||
#define FLUX_GRAPH_SIZE 10240
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
#include "model/common/rope.hpp"
|
||||
#include "model/diffusion/flux.hpp"
|
||||
#include "model/diffusion/model.hpp"
|
||||
#include "model_loader.h"
|
||||
|
||||
namespace LTXV {
|
||||
|
||||
|
||||
@ -7,9 +7,9 @@
|
||||
#include <vector>
|
||||
|
||||
#include "core/ggml_extend.hpp"
|
||||
#include "model.h"
|
||||
#include "model/common/block.hpp"
|
||||
#include "model/diffusion/model.hpp"
|
||||
#include "model_loader.h"
|
||||
|
||||
#define MMDIT_GRAPH_SIZE 10240
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@
|
||||
#include "model/common/block.hpp"
|
||||
#include "model/diffusion/flux.hpp"
|
||||
#include "model/diffusion/model.hpp"
|
||||
#include "model_loader.h"
|
||||
|
||||
namespace Qwen {
|
||||
constexpr int QWEN_IMAGE_GRAPH_SIZE = 20480;
|
||||
|
||||
@ -9,6 +9,7 @@
|
||||
#include "model/common/rope.hpp"
|
||||
#include "model/diffusion/flux.hpp"
|
||||
#include "model/diffusion/model.hpp"
|
||||
#include "model_loader.h"
|
||||
|
||||
namespace WAN {
|
||||
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
#include "model/diffusion/flux.hpp"
|
||||
#include "model/diffusion/mmdit.hpp"
|
||||
#include "model/diffusion/model.hpp"
|
||||
#include "model_loader.h"
|
||||
|
||||
// Ref: https://github.com/Alpha-VLLM/Lumina-Image-2.0/blob/main/model/model.py
|
||||
// Ref: https://github.com/huggingface/diffusers/pull/12703
|
||||
|
||||
@ -21,6 +21,7 @@
|
||||
#include "core/ggml_extend.hpp"
|
||||
#include "json.hpp"
|
||||
#include "model/common/rope.hpp"
|
||||
#include "model_loader.h"
|
||||
#include "tokenizers/bpe_tokenizer.h"
|
||||
#include "tokenizers/gemma_tokenizer.h"
|
||||
#include "tokenizers/gpt_oss_tokenizer.h"
|
||||
|
||||
@ -11,7 +11,7 @@
|
||||
#include <unordered_map>
|
||||
|
||||
#include "core/ggml_extend.hpp"
|
||||
#include "model.h"
|
||||
#include "model_loader.h"
|
||||
#include "tokenizers/t5_unigram_tokenizer.h"
|
||||
|
||||
struct T5Config {
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
#define __SD_MODEL_UPSCALER_ESRGAN_HPP__
|
||||
|
||||
#include "core/ggml_extend.hpp"
|
||||
#include "model.h"
|
||||
#include "model_loader.h"
|
||||
|
||||
/*
|
||||
=================================== ESRGAN ===================================
|
||||
|
||||
@ -14,8 +14,8 @@
|
||||
#include "core/ggml_extend.hpp"
|
||||
#include "core/ggml_graph_cut.h"
|
||||
#include "core/util.h"
|
||||
#include "model.h"
|
||||
#include "model/diffusion/dit.hpp"
|
||||
#include "model_loader.h"
|
||||
|
||||
namespace LTXVUpsampler {
|
||||
constexpr int LTX_UPSAMPLER_GRAPH_SIZE = 10240;
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "core/ggml_extend.hpp"
|
||||
#include "model_loader.h"
|
||||
|
||||
namespace LTXV {
|
||||
|
||||
|
||||
@ -12,6 +12,7 @@
|
||||
#include "model/diffusion/ltxv.hpp"
|
||||
#include "model/vae/vae.hpp"
|
||||
#include "model/vae/wan_vae.hpp"
|
||||
#include "model_loader.h"
|
||||
|
||||
namespace LTXVAE {
|
||||
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
|
||||
#include "model/common/block.hpp"
|
||||
#include "model/vae/vae.hpp"
|
||||
#include "model_loader.h"
|
||||
|
||||
namespace WAN {
|
||||
|
||||
|
||||
@ -14,11 +14,11 @@
|
||||
#include <vector>
|
||||
|
||||
#include "core/util.h"
|
||||
#include "model.h"
|
||||
#include "model_io/gguf_io.h"
|
||||
#include "model_io/safetensors_io.h"
|
||||
#include "model_io/torch_legacy_io.h"
|
||||
#include "model_io/torch_zip_io.h"
|
||||
#include "model_loader.h"
|
||||
#include "stable-diffusion.h"
|
||||
|
||||
#include "core/ggml_extend_backend.h"
|
||||
82
src/model_loader.h
Normal file
82
src/model_loader.h
Normal file
@ -0,0 +1,82 @@
|
||||
#ifndef __MODEL_LOADER_H__
|
||||
#define __MODEL_LOADER_H__
|
||||
|
||||
#include <cstdint>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "model.h"
|
||||
|
||||
// Builds a TensorTypeRules value from the given string.
// NOTE(review): the accepted syntax of `tensor_type_rules` is defined by the
// implementation, which is not visible here — confirm before relying on it.
TensorTypeRules parse_tensor_type_rules(const std::string& tensor_type_rules);
|
||||
|
||||
class MmapWrapper;
|
||||
|
||||
struct ModelFileData {
|
||||
std::string path;
|
||||
std::vector<TensorStorage> tensors;
|
||||
std::shared_ptr<MmapWrapper> mmapped;
|
||||
std::shared_ptr<struct ggml_backend_buffer> mmbuffer;
|
||||
bool is_zip;
|
||||
};
|
||||
|
||||
struct MmapTensorStore {
|
||||
std::shared_ptr<MmapWrapper> mmapped;
|
||||
std::shared_ptr<struct ggml_backend_buffer> mmbuffer;
|
||||
};
|
||||
|
||||
class ModelLoader {
|
||||
protected:
|
||||
SDVersion version_ = VERSION_COUNT;
|
||||
std::vector<std::string> file_paths_;
|
||||
std::vector<ModelFileData> file_data;
|
||||
bool model_files_processed = false;
|
||||
String2TensorStorage tensor_storage_map;
|
||||
|
||||
void add_tensor_storage(const TensorStorage& tensor_storage);
|
||||
|
||||
bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = "");
|
||||
bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = "");
|
||||
bool init_from_torch_zip_file(const std::string& file_path, const std::string& prefix = "");
|
||||
bool init_from_torch_legacy_file(const std::string& file_path, const std::string& prefix = "");
|
||||
bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = "");
|
||||
|
||||
public:
|
||||
bool init_from_file(const std::string& file_path, const std::string& prefix = "");
|
||||
void convert_tensors_name();
|
||||
bool init_from_file_and_convert_name(const std::string& file_path,
|
||||
const std::string& prefix = "",
|
||||
SDVersion version = VERSION_COUNT);
|
||||
SDVersion get_sd_version();
|
||||
std::map<ggml_type, uint32_t> get_wtype_stat();
|
||||
std::map<ggml_type, uint32_t> get_conditioner_wtype_stat();
|
||||
std::map<ggml_type, uint32_t> get_diffusion_model_wtype_stat();
|
||||
std::map<ggml_type, uint32_t> get_vae_wtype_stat();
|
||||
String2TensorStorage& get_tensor_storage_map() { return tensor_storage_map; }
|
||||
void set_wtype_override(ggml_type wtype, std::string tensor_type_rules = "");
|
||||
void process_model_files(bool enable_mmap = false, bool writable_mmap = true);
|
||||
std::vector<MmapTensorStore> mmap_tensors(std::map<std::string, ggml_tensor*>& tensors,
|
||||
std::set<std::string> ignore_tensors = {},
|
||||
bool writable = true);
|
||||
bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0, bool use_mmap = false);
|
||||
bool load_tensors(std::map<std::string, ggml_tensor*>& tensors,
|
||||
std::set<std::string> ignore_tensors = {},
|
||||
int n_threads = 0,
|
||||
bool use_mmap = false);
|
||||
|
||||
std::vector<std::string> get_tensor_names() const {
|
||||
std::vector<std::string> names;
|
||||
for (const auto& [name, tensor_storage] : tensor_storage_map) {
|
||||
names.push_back(name);
|
||||
}
|
||||
return names;
|
||||
}
|
||||
|
||||
bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type);
|
||||
int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);
|
||||
~ModelLoader() = default;
|
||||
};
|
||||
|
||||
#endif // __MODEL_LOADER_H__
|
||||
@ -1,3 +1,4 @@
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@
|
||||
#include "core/rng_mt19937.hpp"
|
||||
#include "core/rng_philox.hpp"
|
||||
#include "core/util.h"
|
||||
#include "model.h"
|
||||
#include "model_loader.h"
|
||||
#include "stable-diffusion.h"
|
||||
|
||||
#include "conditioning/conditioner.hpp"
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
#include "upscaler.h"
|
||||
#include "core/ggml_extend.hpp"
|
||||
#include "core/util.h"
|
||||
#include "model.h"
|
||||
#include "model_loader.h"
|
||||
#include "stable-diffusion.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user