feat: support mmap for model loading (#1059)

Wagner Bruna, 2025-12-28 11:38:29 -03:00 (committed by GitHub)
parent a2d83dd0c8
commit d0d836ae74
9 changed files with 175 additions and 7 deletions


@@ -48,6 +48,7 @@ Context Options:
 --vae-tiling process vae in tiles to reduce memory usage
 --force-sdxl-vae-conv-scale force use of conv scale on sdxl vae
 --offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM when needed
+--mmap whether to memory-map model
 --control-net-cpu keep controlnet in cpu (for low vram)
 --clip-on-cpu keep clip in cpu (for low vram)
 --vae-on-cpu keep vae in cpu (for low vram)
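For orientation, this is how the new flag would be passed on the command line. The binary name and the other arguments here are illustrative, not taken from this commit:

    ./sd -m model.safetensors -p "a photo of a cat" --mmap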


@@ -453,6 +453,7 @@ struct SDContextParams {
     rng_type_t rng_type = CUDA_RNG;
     rng_type_t sampler_rng_type = RNG_TYPE_COUNT;
     bool offload_params_to_cpu = false;
+    bool enable_mmap = false;
     bool control_net_cpu = false;
     bool clip_on_cpu = false;
     bool vae_on_cpu = false;
@@ -598,6 +599,10 @@ struct SDContextParams {
          "--offload-to-cpu",
          "place the weights in RAM to save VRAM, and automatically load them into VRAM when needed",
          true, &offload_params_to_cpu},
+        {"",
+         "--mmap",
+         "whether to memory-map model",
+         true, &enable_mmap},
         {"",
          "--control-net-cpu",
          "keep controlnet in cpu (for low vram)",
@@ -895,6 +900,7 @@ struct SDContextParams {
            << " sampler_rng_type: " << sd_rng_type_name(sampler_rng_type) << ",\n"
            << " flow_shift: " << (std::isinf(flow_shift) ? "INF" : std::to_string(flow_shift)) << "\n"
            << " offload_params_to_cpu: " << (offload_params_to_cpu ? "true" : "false") << ",\n"
+           << " enable_mmap: " << (enable_mmap ? "true" : "false") << ",\n"
            << " control_net_cpu: " << (control_net_cpu ? "true" : "false") << ",\n"
            << " clip_on_cpu: " << (clip_on_cpu ? "true" : "false") << ",\n"
            << " vae_on_cpu: " << (vae_on_cpu ? "true" : "false") << ",\n"
@@ -958,6 +964,7 @@ struct SDContextParams {
         prediction,
         lora_apply_mode,
         offload_params_to_cpu,
+        enable_mmap,
         clip_on_cpu,
         control_net_cpu,
         vae_on_cpu,


@@ -43,6 +43,7 @@ Context Options:
 --control-net-cpu keep controlnet in cpu (for low vram)
 --clip-on-cpu keep clip in cpu (for low vram)
 --vae-on-cpu keep vae in cpu (for low vram)
+--mmap whether to memory-map model
 --diffusion-fa use flash attention in the diffusion model
 --diffusion-conv-direct use ggml_conv2d_direct in the diffusion model
 --vae-conv-direct use ggml_conv2d_direct in the vae model


@@ -1340,7 +1340,7 @@ std::string ModelLoader::load_umt5_tokenizer_json() {
     return json_str;
 }
 
-bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads_p) {
+bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads_p, bool enable_mmap) {
     int64_t process_time_ms = 0;
     std::atomic<int64_t> read_time_ms(0);
     std::atomic<int64_t> memcpy_time_ms(0);
@@ -1390,6 +1390,15 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
         }
     }
 
+    std::unique_ptr<MmapWrapper> mmapped;
+    if (enable_mmap && !is_zip) {
+        LOG_DEBUG("using mmap for I/O");
+        mmapped = MmapWrapper::create(file_path);
+        if (!mmapped) {
+            LOG_WARN("failed to memory-map '%s'", file_path.c_str());
+        }
+    }
+
     int n_threads = is_zip ? 1 : std::min(num_threads_to_use, (int)file_tensors.size());
     if (n_threads < 1) {
         n_threads = 1;
@@ -1411,7 +1420,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
                 failed = true;
                 return;
             }
-        } else {
+        } else if (!mmapped) {
             file.open(file_path, std::ios::binary);
             if (!file.is_open()) {
                 LOG_ERROR("failed to open '%s'", file_path.c_str());
@@ -1464,6 +1473,11 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
                     zip_entry_noallocread(zip, (void*)buf, n);
                 }
                 zip_entry_close(zip);
+            } else if (mmapped) {
+                if (!mmapped->copy_data(buf, n, tensor_storage.offset)) {
+                    LOG_ERROR("read tensor data failed: '%s'", file_path.c_str());
+                    failed = true;
+                }
             } else {
                 file.seekg(tensor_storage.offset);
                 file.read(buf, n);
@@ -1588,7 +1602,8 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
 
 bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
                                std::set<std::string> ignore_tensors,
-                               int n_threads) {
+                               int n_threads,
+                               bool enable_mmap) {
     std::set<std::string> tensor_names_in_file;
     std::mutex tensor_names_mutex;
     auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
@@ -1631,7 +1646,7 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tenso
         return true;
     };
 
-    bool success = load_tensors(on_new_tensor_cb, n_threads);
+    bool success = load_tensors(on_new_tensor_cb, n_threads, enable_mmap);
     if (!success) {
         LOG_ERROR("load tensors from file failed");
         return false;
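Pieced together from the hunks above, each worker thread now chooses among three read paths. A condensed sketch, with error handling and the single-threaded zip constraint elided and names as in the diff:

    // Per-tensor read path after this commit (simplified sketch).
    // `mmapped` stays null when --mmap is off, the model is a zip
    // archive, or MmapWrapper::create() failed, so buffered ifstream
    // reads remain the fallback in every failure case.
    if (is_zip) {
        // zip archives keep the old zip_entry_noallocread() path
    } else if (mmapped) {
        // copy directly out of the shared mapping; no per-thread
        // ifstream needs to be opened at all
        if (!mmapped->copy_data(buf, n, tensor_storage.offset)) {
            failed = true;  // offset/size outside the mapping
        }
    } else {
        file.seekg(tensor_storage.offset);
        file.read(buf, n);
    }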


@@ -310,10 +310,11 @@ public:
     std::map<ggml_type, uint32_t> get_vae_wtype_stat();
     String2TensorStorage& get_tensor_storage_map() { return tensor_storage_map; }
     void set_wtype_override(ggml_type wtype, std::string tensor_type_rules = "");
-    bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0);
+    bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0, bool use_mmap = false);
     bool load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
                       std::set<std::string> ignore_tensors = {},
-                      int n_threads = 0);
+                      int n_threads = 0,
+                      bool use_mmap = false);
 
     std::vector<std::string> get_tensor_names() const {
         std::vector<std::string> names;


@@ -766,7 +766,7 @@ public:
         if (version == VERSION_SVD) {
             ignore_tensors.insert("conditioner.embedders.3");
         }
-        bool success = model_loader.load_tensors(tensors, ignore_tensors, n_threads);
+        bool success = model_loader.load_tensors(tensors, ignore_tensors, n_threads, sd_ctx_params->enable_mmap);
         if (!success) {
             LOG_ERROR("load tensors from model loader failed");
             ggml_free(ctx);
@@ -2875,6 +2875,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
     sd_ctx_params->prediction = PREDICTION_COUNT;
     sd_ctx_params->lora_apply_mode = LORA_APPLY_AUTO;
     sd_ctx_params->offload_params_to_cpu = false;
+    sd_ctx_params->enable_mmap = false;
     sd_ctx_params->keep_clip_on_cpu = false;
     sd_ctx_params->keep_control_net_on_cpu = false;
     sd_ctx_params->keep_vae_on_cpu = false;


@@ -182,6 +182,7 @@ typedef struct {
     enum prediction_t prediction;
     enum lora_apply_mode_t lora_apply_mode;
     bool offload_params_to_cpu;
+    bool enable_mmap;
     bool keep_clip_on_cpu;
     bool keep_control_net_on_cpu;
     bool keep_vae_on_cpu;

util.cpp

@@ -95,9 +95,71 @@ bool is_directory(const std::string& path) {
     return (attributes != INVALID_FILE_ATTRIBUTES && (attributes & FILE_ATTRIBUTE_DIRECTORY));
 }
 
+class MmapWrapperImpl : public MmapWrapper {
+public:
+    MmapWrapperImpl(void* data, size_t size, HANDLE hfile, HANDLE hmapping)
+        : MmapWrapper(data, size), hfile_(hfile), hmapping_(hmapping) {}
+
+    ~MmapWrapperImpl() override {
+        UnmapViewOfFile(data_);
+        CloseHandle(hmapping_);
+        CloseHandle(hfile_);
+    }
+
+private:
+    HANDLE hfile_;
+    HANDLE hmapping_;
+};
+
+std::unique_ptr<MmapWrapper> MmapWrapper::create(const std::string& filename) {
+    void* mapped_data = nullptr;
+    size_t file_size = 0;
+
+    HANDLE file_handle = CreateFileA(
+        filename.c_str(),
+        GENERIC_READ,
+        FILE_SHARE_READ,
+        NULL,
+        OPEN_EXISTING,
+        FILE_ATTRIBUTE_NORMAL,
+        NULL);
+    if (file_handle == INVALID_HANDLE_VALUE) {
+        return nullptr;
+    }
+
+    LARGE_INTEGER size;
+    if (!GetFileSizeEx(file_handle, &size)) {
+        CloseHandle(file_handle);
+        return nullptr;
+    }
+    file_size = static_cast<size_t>(size.QuadPart);
+
+    HANDLE mapping_handle = CreateFileMapping(file_handle, NULL, PAGE_READONLY, 0, 0, NULL);
+    if (mapping_handle == NULL) {
+        CloseHandle(file_handle);
+        return nullptr;
+    }
+
+    mapped_data = MapViewOfFile(mapping_handle, FILE_MAP_READ, 0, 0, file_size);
+    if (mapped_data == NULL) {
+        CloseHandle(mapping_handle);
+        CloseHandle(file_handle);
+        return nullptr;
+    }
+
+    return std::make_unique<MmapWrapperImpl>(mapped_data, file_size, file_handle, mapping_handle);
+}
+
 #else // Unix
 
 #include <dirent.h>
+#include <fcntl.h>
+#include <sys/mman.h>
 #include <sys/stat.h>
+#include <unistd.h>
 
 bool file_exists(const std::string& filename) {
     struct stat buffer;
@@ -109,8 +171,64 @@ bool is_directory(const std::string& path) {
     return (stat(path.c_str(), &buffer) == 0 && S_ISDIR(buffer.st_mode));
 }
 
+class MmapWrapperImpl : public MmapWrapper {
+public:
+    MmapWrapperImpl(void* data, size_t size)
+        : MmapWrapper(data, size) {}
+
+    ~MmapWrapperImpl() override {
+        munmap(data_, size_);
+    }
+};
+
+std::unique_ptr<MmapWrapper> MmapWrapper::create(const std::string& filename) {
+    int file_descriptor = open(filename.c_str(), O_RDONLY);
+    if (file_descriptor == -1) {
+        return nullptr;
+    }
+
+    int mmap_flags = MAP_PRIVATE;
+#ifdef __linux__
+    // performance flags used by llama.cpp
+    // posix_fadvise(file_descriptor, 0, 0, POSIX_FADV_SEQUENTIAL);
+    // mmap_flags |= MAP_POPULATE;
+#endif
+
+    struct stat sb;
+    if (fstat(file_descriptor, &sb) == -1) {
+        close(file_descriptor);
+        return nullptr;
+    }
+    size_t file_size = sb.st_size;
+
+    void* mapped_data = mmap(NULL, file_size, PROT_READ, mmap_flags, file_descriptor, 0);
+    close(file_descriptor);
+    if (mapped_data == MAP_FAILED) {
+        return nullptr;
+    }
+
+#ifdef __linux__
+    // performance flags used by llama.cpp
+    // posix_madvise(mapped_data, file_size, POSIX_MADV_WILLNEED);
+#endif
+
+    return std::make_unique<MmapWrapperImpl>(mapped_data, file_size);
+}
+
 #endif
 
+bool MmapWrapper::copy_data(void* buf, size_t n, size_t offset) const {
+    if (offset >= size_ || n > (size_ - offset)) {
+        return false;
+    }
+    std::memcpy(buf, data() + offset, n);
+    return true;
+}
+
 // get_num_physical_cores is copy from
 // https://github.com/ggerganov/llama.cpp/blob/master/examples/common.cpp
 // LICENSE: https://github.com/ggerganov/llama.cpp/blob/master/LICENSE
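The commented-out lines mirror hints that llama.cpp applies around its own mappings. Enabled, the Linux-specific parts of create() would read roughly as below; whether they pay off depends on the access pattern and page-cache state, which is presumably why the commit ships with them off:

    int mmap_flags = MAP_PRIVATE;
    #ifdef __linux__
    // tell the kernel the file will be read roughly front to back
    posix_fadvise(file_descriptor, 0, 0, POSIX_FADV_SEQUENTIAL);
    // pre-fault all pages during mmap() instead of on first access
    mmap_flags |= MAP_POPULATE;
    #endif
    void* mapped_data = mmap(NULL, file_size, PROT_READ, mmap_flags, file_descriptor, 0);
    // ...
    #ifdef __linux__
    // start readahead for the whole mapping right away
    posix_madvise(mapped_data, file_size, POSIX_MADV_WILLNEED);
    #endif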

util.h

@@ -2,6 +2,7 @@
 #define __UTIL_H__
 
 #include <cstdint>
+#include <memory>
 #include <string>
 #include <vector>
@@ -43,6 +44,28 @@ sd_image_f32_t resize_sd_image_f32_t(sd_image_f32_t image, int target_width, int
 sd_image_f32_t clip_preprocess(sd_image_f32_t image, int target_width, int target_height);
 
+class MmapWrapper {
+public:
+    static std::unique_ptr<MmapWrapper> create(const std::string& filename);
+    virtual ~MmapWrapper() = default;
+
+    MmapWrapper(const MmapWrapper&) = delete;
+    MmapWrapper& operator=(const MmapWrapper&) = delete;
+    MmapWrapper(MmapWrapper&&) = delete;
+    MmapWrapper& operator=(MmapWrapper&&) = delete;
+
+    const uint8_t* data() const { return static_cast<uint8_t*>(data_); }
+    size_t size() const { return size_; }
+
+    bool copy_data(void* buf, size_t n, size_t offset) const;
+
+protected:
+    MmapWrapper(void* data, size_t size)
+        : data_(data), size_(size) {}
+
+    void* data_ = nullptr;
+    size_t size_ = 0;
+};
+
 std::string path_join(const std::string& p1, const std::string& p2);
 std::vector<std::string> split_string(const std::string& str, char delimiter);
 void pretty_progress(int step, int steps, float time);
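A standalone usage sketch of the wrapper (the file name is illustrative): create() returns nullptr on any platform-specific failure, copy_data() rejects out-of-range reads, and the destructor releases the mapping and handles via RAII:

    #include "util.h"
    #include <cstdio>

    int main() {
        auto map = MmapWrapper::create("model.safetensors");  // illustrative path
        if (!map) {
            std::fprintf(stderr, "mmap failed; caller falls back to buffered I/O\n");
            return 1;
        }
        char header[8] = {};
        // copies 8 bytes from offset 0; returns false if offset/n
        // would run past the end of the mapping
        if (map->copy_data(header, sizeof(header), 0)) {
            std::printf("mapped %zu bytes\n", map->size());
        }
        return 0;  // ~MmapWrapperImpl() unmaps and closes handles here
    }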