refactor: centralize runner weight staging and cleanup (#1644)

This commit is contained in:
leejet 2026-06-13 13:19:13 +08:00 committed by GitHub
parent 3a54597776
commit 563137a592
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
43 changed files with 2335 additions and 1729 deletions

View File

@ -113,14 +113,13 @@ struct Conditioner {
public: public:
virtual SDCondition get_learned_condition(int n_threads, virtual SDCondition get_learned_condition(int n_threads,
const ConditionerParams& conditioner_params) = 0; const ConditionerParams& conditioner_params) = 0;
virtual bool alloc_params_buffer() = 0;
virtual void free_params_buffer() = 0;
virtual void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) = 0; virtual void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) = 0;
virtual size_t get_params_buffer_size() = 0;
virtual void set_max_graph_vram_bytes(size_t max_vram_bytes) {} virtual void set_max_graph_vram_bytes(size_t max_vram_bytes) {}
virtual void set_stream_layers_enabled(bool enabled) {} virtual void set_stream_layers_enabled(bool enabled) {}
virtual void set_flash_attention_enabled(bool enabled) = 0; virtual void set_flash_attention_enabled(bool enabled) = 0;
virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {} virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {}
virtual void set_weight_manager(const std::shared_ptr<RunnerWeightManager>& manager) {}
virtual void runner_done() {}
}; };
// ldm.modules.encoders.modules.FrozenCLIPEmbedder // ldm.modules.encoders.modules.FrozenCLIPEmbedder
@ -167,33 +166,6 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
} }
} }
bool alloc_params_buffer() override {
if (!text_model->alloc_params_buffer()) {
return false;
}
if (sd_version_is_sdxl(version)) {
if (!text_model2->alloc_params_buffer()) {
return false;
}
}
return true;
}
void free_params_buffer() override {
text_model->free_params_buffer();
if (sd_version_is_sdxl(version)) {
text_model2->free_params_buffer();
}
}
size_t get_params_buffer_size() override {
size_t buffer_size = text_model->get_params_buffer_size();
if (sd_version_is_sdxl(version)) {
buffer_size += text_model2->get_params_buffer_size();
}
return buffer_size;
}
void set_max_graph_vram_bytes(size_t max_vram_bytes) override { void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
text_model->set_max_graph_vram_bytes(max_vram_bytes); text_model->set_max_graph_vram_bytes(max_vram_bytes);
if (sd_version_is_sdxl(version)) { if (sd_version_is_sdxl(version)) {
@ -222,6 +194,20 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
} }
} }
void set_weight_manager(const std::shared_ptr<RunnerWeightManager>& manager) override {
text_model->set_weight_manager(manager);
if (sd_version_is_sdxl(version)) {
text_model2->set_weight_manager(manager);
}
}
void runner_done() override {
text_model->runner_done();
if (sd_version_is_sdxl(version)) {
text_model2->runner_done();
}
}
bool load_embedding(std::string embd_name, std::string embd_path, std::vector<int32_t>& bpe_tokens) { bool load_embedding(std::string embd_name, std::string embd_path, std::vector<int32_t>& bpe_tokens) {
ModelLoader model_loader; ModelLoader model_loader;
if (!model_loader.init_from_file_and_convert_name(embd_path)) { if (!model_loader.init_from_file_and_convert_name(embd_path)) {
@ -263,7 +249,8 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
} }
return true; return true;
}; };
model_loader.load_tensors(on_load, 1); model_loader.set_n_threads(1);
model_loader.load_tensors(on_load);
int pos_start = num_custom_embeddings; int pos_start = num_custom_embeddings;
if (embd) { if (embd) {
int64_t hidden_size = text_model->model.hidden_size; int64_t hidden_size = text_model->model.hidden_size;
@ -432,7 +419,10 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
token_embed_custom.data(), token_embed_custom.data(),
max_token_idx, max_token_idx,
false, false,
clip_skip); clip_skip,
false,
true,
true);
GGML_ASSERT(!chunk_hidden_states.empty()); GGML_ASSERT(!chunk_hidden_states.empty());
if (sd_version_is_sdxl(version)) { if (sd_version_is_sdxl(version)) {
auto chunk_hidden_states2 = text_model2->compute(n_threads, auto chunk_hidden_states2 = text_model2->compute(n_threads,
@ -441,7 +431,10 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
token_embed_custom.data(), token_embed_custom.data(),
max_token_idx, max_token_idx,
false, false,
clip_skip); clip_skip,
false,
true,
true);
GGML_ASSERT(!chunk_hidden_states2.empty()); GGML_ASSERT(!chunk_hidden_states2.empty());
chunk_hidden_states = sd::ops::concat(chunk_hidden_states, chunk_hidden_states2, 0); chunk_hidden_states = sd::ops::concat(chunk_hidden_states, chunk_hidden_states2, 0);
@ -452,7 +445,10 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
token_embed_custom.data(), token_embed_custom.data(),
max_token_idx, max_token_idx,
true, true,
clip_skip); clip_skip,
false,
true,
true);
GGML_ASSERT(!pooled.empty()); GGML_ASSERT(!pooled.empty());
} }
} }
@ -523,15 +519,15 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
struct FrozenCLIPVisionEmbedder : public GGMLRunner { struct FrozenCLIPVisionEmbedder : public GGMLRunner {
CLIPVisionModelProjection vision_model; CLIPVisionModelProjection vision_model;
std::string weight_prefix = "cond_stage_model.transformer";
FrozenCLIPVisionEmbedder(ggml_backend_t backend, FrozenCLIPVisionEmbedder(ggml_backend_t backend,
ggml_backend_t params_backend, ggml_backend_t params_backend,
const String2TensorStorage& tensor_storage_map = {}) const String2TensorStorage& tensor_storage_map = {})
: GGMLRunner(backend, params_backend) { : GGMLRunner(backend, params_backend) {
std::string prefix = "cond_stage_model.transformer"; bool proj_in = false;
bool proj_in = false;
for (const auto& [name, tensor_storage] : tensor_storage_map) { for (const auto& [name, tensor_storage] : tensor_storage_map) {
if (!starts_with(name, prefix)) { if (!starts_with(name, weight_prefix)) {
continue; continue;
} }
if (contains(name, "self_attn.in_proj")) { if (contains(name, "self_attn.in_proj")) {
@ -540,7 +536,7 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
} }
} }
vision_model = CLIPVisionModelProjection(OPEN_CLIP_VIT_H_14, false, proj_in); vision_model = CLIPVisionModelProjection(OPEN_CLIP_VIT_H_14, false, proj_in);
vision_model.init(params_ctx, tensor_storage_map, prefix); vision_model.init(params_ctx, tensor_storage_map, weight_prefix);
} }
std::string get_desc() override { std::string get_desc() override {
@ -548,7 +544,7 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
} }
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) { void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) {
vision_model.get_param_tensors(tensors, "cond_stage_model.transformer"); vision_model.get_param_tensors(tensors, weight_prefix);
} }
ggml_cgraph* build_graph(const sd::Tensor<float>& pixel_values_tensor, bool return_pooled, int clip_skip) { ggml_cgraph* build_graph(const sd::Tensor<float>& pixel_values_tensor, bool return_pooled, int clip_skip) {
@ -571,7 +567,7 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
auto get_graph = [&]() -> ggml_cgraph* { auto get_graph = [&]() -> ggml_cgraph* {
return build_graph(pixel_values, return_pooled, clip_skip); return build_graph(pixel_values, return_pooled, clip_skip);
}; };
return take_or_empty(GGMLRunner::compute<float>(get_graph, n_threads, true)); return take_or_empty(GGMLRunner::compute<float>(get_graph, n_threads, true, true, true));
} }
}; };
@ -626,51 +622,6 @@ struct SD3CLIPEmbedder : public Conditioner {
} }
} }
bool alloc_params_buffer() override {
if (clip_l) {
if (!clip_l->alloc_params_buffer()) {
return false;
}
}
if (clip_g) {
if (!clip_g->alloc_params_buffer()) {
return false;
}
}
if (t5) {
if (!t5->alloc_params_buffer()) {
return false;
}
}
return true;
}
void free_params_buffer() override {
if (clip_l) {
clip_l->free_params_buffer();
}
if (clip_g) {
clip_g->free_params_buffer();
}
if (t5) {
t5->free_params_buffer();
}
}
size_t get_params_buffer_size() override {
size_t buffer_size = 0;
if (clip_l) {
buffer_size += clip_l->get_params_buffer_size();
}
if (clip_g) {
buffer_size += clip_g->get_params_buffer_size();
}
if (t5) {
buffer_size += t5->get_params_buffer_size();
}
return buffer_size;
}
void set_max_graph_vram_bytes(size_t max_vram_bytes) override { void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
if (clip_l) { if (clip_l) {
clip_l->set_max_graph_vram_bytes(max_vram_bytes); clip_l->set_max_graph_vram_bytes(max_vram_bytes);
@ -719,6 +670,30 @@ struct SD3CLIPEmbedder : public Conditioner {
} }
} }
void set_weight_manager(const std::shared_ptr<RunnerWeightManager>& manager) override {
if (clip_l) {
clip_l->set_weight_manager(manager);
}
if (clip_g) {
clip_g->set_weight_manager(manager);
}
if (t5) {
t5->set_weight_manager(manager);
}
}
void runner_done() override {
if (clip_l) {
clip_l->runner_done();
}
if (clip_g) {
clip_g->runner_done();
}
if (t5) {
t5->runner_done();
}
}
std::vector<std::pair<std::vector<int>, std::vector<float>>> tokenize(std::string text, std::vector<std::pair<std::vector<int>, std::vector<float>>> tokenize(std::string text,
size_t min_length = 0, size_t min_length = 0,
size_t max_length = 0, size_t max_length = 0,
@ -834,7 +809,10 @@ struct SD3CLIPEmbedder : public Conditioner {
nullptr, nullptr,
max_token_idx, max_token_idx,
false, false,
clip_skip); clip_skip,
false,
true,
true);
GGML_ASSERT(!chunk_hidden_states_l.empty()); GGML_ASSERT(!chunk_hidden_states_l.empty());
chunk_hidden_states_l = ::apply_token_weights(std::move(chunk_hidden_states_l), chunk_weights); chunk_hidden_states_l = ::apply_token_weights(std::move(chunk_hidden_states_l), chunk_weights);
@ -847,7 +825,10 @@ struct SD3CLIPEmbedder : public Conditioner {
nullptr, nullptr,
max_token_idx, max_token_idx,
true, true,
clip_skip); clip_skip,
false,
true,
true);
GGML_ASSERT(!pooled_l.empty()); GGML_ASSERT(!pooled_l.empty());
} }
} else { } else {
@ -875,7 +856,10 @@ struct SD3CLIPEmbedder : public Conditioner {
nullptr, nullptr,
max_token_idx, max_token_idx,
false, false,
clip_skip); clip_skip,
false,
true,
true);
GGML_ASSERT(!chunk_hidden_states_g.empty()); GGML_ASSERT(!chunk_hidden_states_g.empty());
chunk_hidden_states_g = ::apply_token_weights(std::move(chunk_hidden_states_g), chunk_weights); chunk_hidden_states_g = ::apply_token_weights(std::move(chunk_hidden_states_g), chunk_weights);
@ -888,7 +872,10 @@ struct SD3CLIPEmbedder : public Conditioner {
nullptr, nullptr,
max_token_idx, max_token_idx,
true, true,
clip_skip); clip_skip,
false,
true,
true);
GGML_ASSERT(!pooled_g.empty()); GGML_ASSERT(!pooled_g.empty());
} }
} else { } else {
@ -910,7 +897,10 @@ struct SD3CLIPEmbedder : public Conditioner {
chunk_hidden_states_t5 = t5->compute(n_threads, chunk_hidden_states_t5 = t5->compute(n_threads,
input_ids, input_ids,
sd::Tensor<float>()); sd::Tensor<float>(),
false,
true,
true);
GGML_ASSERT(!chunk_hidden_states_t5.empty()); GGML_ASSERT(!chunk_hidden_states_t5.empty());
chunk_hidden_states_t5 = ::apply_token_weights(std::move(chunk_hidden_states_t5), chunk_weights); chunk_hidden_states_t5 = ::apply_token_weights(std::move(chunk_hidden_states_t5), chunk_weights);
} else { } else {
@ -1009,40 +999,6 @@ struct FluxCLIPEmbedder : public Conditioner {
} }
} }
bool alloc_params_buffer() override {
if (clip_l) {
if (!clip_l->alloc_params_buffer()) {
return false;
}
}
if (t5) {
if (!t5->alloc_params_buffer()) {
return false;
}
}
return true;
}
void free_params_buffer() override {
if (clip_l) {
clip_l->free_params_buffer();
}
if (t5) {
t5->free_params_buffer();
}
}
size_t get_params_buffer_size() override {
size_t buffer_size = 0;
if (clip_l) {
buffer_size += clip_l->get_params_buffer_size();
}
if (t5) {
buffer_size += t5->get_params_buffer_size();
}
return buffer_size;
}
void set_max_graph_vram_bytes(size_t max_vram_bytes) override { void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
if (clip_l) { if (clip_l) {
clip_l->set_max_graph_vram_bytes(max_vram_bytes); clip_l->set_max_graph_vram_bytes(max_vram_bytes);
@ -1070,7 +1026,7 @@ struct FluxCLIPEmbedder : public Conditioner {
} }
} }
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) { void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
if (clip_l) { if (clip_l) {
clip_l->set_weight_adapter(adapter); clip_l->set_weight_adapter(adapter);
} }
@ -1079,6 +1035,24 @@ struct FluxCLIPEmbedder : public Conditioner {
} }
} }
void set_weight_manager(const std::shared_ptr<RunnerWeightManager>& manager) override {
if (clip_l) {
clip_l->set_weight_manager(manager);
}
if (t5) {
t5->set_weight_manager(manager);
}
}
void runner_done() override {
if (clip_l) {
clip_l->runner_done();
}
if (t5) {
t5->runner_done();
}
}
std::vector<std::pair<std::vector<int>, std::vector<float>>> tokenize(std::string text, std::vector<std::pair<std::vector<int>, std::vector<float>>> tokenize(std::string text,
size_t min_length = 0, size_t min_length = 0,
size_t max_length = 0) { size_t max_length = 0) {
@ -1177,7 +1151,10 @@ struct FluxCLIPEmbedder : public Conditioner {
nullptr, nullptr,
max_token_idx, max_token_idx,
true, true,
clip_skip); clip_skip,
false,
true,
true);
GGML_ASSERT(!pooled.empty()); GGML_ASSERT(!pooled.empty());
} else { } else {
pooled = sd::Tensor<float>::zeros({768}); pooled = sd::Tensor<float>::zeros({768});
@ -1195,7 +1172,10 @@ struct FluxCLIPEmbedder : public Conditioner {
sd::Tensor<int32_t> input_ids({static_cast<int64_t>(chunk_tokens.size())}, chunk_tokens); sd::Tensor<int32_t> input_ids({static_cast<int64_t>(chunk_tokens.size())}, chunk_tokens);
chunk_hidden_states = t5->compute(n_threads, chunk_hidden_states = t5->compute(n_threads,
input_ids, input_ids,
sd::Tensor<float>()); sd::Tensor<float>(),
false,
true,
true);
GGML_ASSERT(!chunk_hidden_states.empty()); GGML_ASSERT(!chunk_hidden_states.empty());
chunk_hidden_states = ::apply_token_weights(std::move(chunk_hidden_states), chunk_weights); chunk_hidden_states = ::apply_token_weights(std::move(chunk_hidden_states), chunk_weights);
if (zero_out_masked) { if (zero_out_masked) {
@ -1266,29 +1246,6 @@ struct T5CLIPEmbedder : public Conditioner {
} }
} }
bool alloc_params_buffer() override {
if (t5) {
if (!t5->alloc_params_buffer()) {
return false;
}
}
return true;
}
void free_params_buffer() override {
if (t5) {
t5->free_params_buffer();
}
}
size_t get_params_buffer_size() override {
size_t buffer_size = 0;
if (t5) {
buffer_size += t5->get_params_buffer_size();
}
return buffer_size;
}
void set_max_graph_vram_bytes(size_t max_vram_bytes) override { void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
if (t5) { if (t5) {
t5->set_max_graph_vram_bytes(max_vram_bytes); t5->set_max_graph_vram_bytes(max_vram_bytes);
@ -1313,6 +1270,18 @@ struct T5CLIPEmbedder : public Conditioner {
} }
} }
void set_weight_manager(const std::shared_ptr<RunnerWeightManager>& manager) override {
if (t5) {
t5->set_weight_manager(manager);
}
}
void runner_done() override {
if (t5) {
t5->runner_done();
}
}
std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> tokenize(std::string text, std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> tokenize(std::string text,
size_t min_length = 0, size_t min_length = 0,
size_t max_length = 0) { size_t max_length = 0) {
@ -1406,7 +1375,10 @@ struct T5CLIPEmbedder : public Conditioner {
auto chunk_hidden_states = t5->compute(n_threads, auto chunk_hidden_states = t5->compute(n_threads,
input_ids, input_ids,
t5_attn_mask_chunk); t5_attn_mask_chunk,
false,
true,
true);
GGML_ASSERT(!chunk_hidden_states.empty()); GGML_ASSERT(!chunk_hidden_states.empty());
chunk_hidden_states = apply_token_weights(std::move(chunk_hidden_states), chunk_weights); chunk_hidden_states = apply_token_weights(std::move(chunk_hidden_states), chunk_weights);
@ -1465,21 +1437,6 @@ struct AnimaConditioner : public Conditioner {
llm->get_param_tensors(tensors, "text_encoders.llm"); llm->get_param_tensors(tensors, "text_encoders.llm");
} }
bool alloc_params_buffer() override {
if (!llm->alloc_params_buffer()) {
return false;
}
return true;
}
void free_params_buffer() override {
llm->free_params_buffer();
}
size_t get_params_buffer_size() override {
return llm->get_params_buffer_size();
}
void set_max_graph_vram_bytes(size_t max_vram_bytes) override { void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
llm->set_max_graph_vram_bytes(max_vram_bytes); llm->set_max_graph_vram_bytes(max_vram_bytes);
} }
@ -1496,6 +1453,14 @@ struct AnimaConditioner : public Conditioner {
llm->set_weight_adapter(adapter); llm->set_weight_adapter(adapter);
} }
void set_weight_manager(const std::shared_ptr<RunnerWeightManager>& manager) override {
llm->set_weight_manager(manager);
}
void runner_done() override {
llm->runner_done();
}
std::tuple<std::vector<int>, std::vector<float>, std::vector<int>, std::vector<float>> tokenize(std::string text) { std::tuple<std::vector<int>, std::vector<float>, std::vector<int>, std::vector<float>> tokenize(std::string text) {
auto parsed_attention = parse_prompt_attention(text); auto parsed_attention = parse_prompt_attention(text);
@ -1553,7 +1518,11 @@ struct AnimaConditioner : public Conditioner {
input_ids, input_ids,
sd::Tensor<float>(), sd::Tensor<float>(),
{}, {},
{}); {},
false,
false,
true,
true);
GGML_ASSERT(!hidden_states.empty()); GGML_ASSERT(!hidden_states.empty());
hidden_states = apply_token_weights(std::move(hidden_states), qwen_weights); hidden_states = apply_token_weights(std::move(hidden_states), qwen_weights);
auto t5_ids_tensor = sd::Tensor<int32_t>::from_vector(t5_tokens); auto t5_ids_tensor = sd::Tensor<int32_t>::from_vector(t5_tokens);
@ -1617,23 +1586,6 @@ struct LLMEmbedder : public Conditioner {
llm->get_param_tensors(tensors, "text_encoders.llm"); llm->get_param_tensors(tensors, "text_encoders.llm");
} }
bool alloc_params_buffer() override {
if (!llm->alloc_params_buffer()) {
return false;
}
return true;
}
void free_params_buffer() override {
llm->free_params_buffer();
}
size_t get_params_buffer_size() override {
size_t buffer_size = 0;
buffer_size += llm->get_params_buffer_size();
return buffer_size;
}
void set_max_graph_vram_bytes(size_t max_vram_bytes) override { void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
llm->set_max_graph_vram_bytes(max_vram_bytes); llm->set_max_graph_vram_bytes(max_vram_bytes);
} }
@ -1652,6 +1604,18 @@ struct LLMEmbedder : public Conditioner {
} }
} }
void set_weight_manager(const std::shared_ptr<RunnerWeightManager>& manager) override {
if (llm) {
llm->set_weight_manager(manager);
}
}
void runner_done() override {
if (llm) {
llm->runner_done();
}
}
std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> tokenize(std::string text, std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> tokenize(std::string text,
const std::pair<int, int>& attn_range, const std::pair<int, int>& attn_range,
size_t min_length = 0, size_t min_length = 0,
@ -1747,7 +1711,11 @@ struct LLMEmbedder : public Conditioner {
input_ids, input_ids,
attention_mask, attention_mask,
image_embeds, image_embeds,
out_layers); out_layers,
false,
false,
true,
true);
GGML_ASSERT(!hidden_states.empty()); GGML_ASSERT(!hidden_states.empty());
hidden_states = apply_token_weights(std::move(hidden_states), weights); hidden_states = apply_token_weights(std::move(hidden_states), weights);
GGML_ASSERT(hidden_states.shape()[1] > prompt_template_encode_start_idx); GGML_ASSERT(hidden_states.shape()[1] > prompt_template_encode_start_idx);
@ -1825,7 +1793,7 @@ struct LLMEmbedder : public Conditioner {
auto resized_image = clip_preprocess(image, w_bar, h_bar); auto resized_image = clip_preprocess(image, w_bar, h_bar);
auto image_embed = llm->encode_image(n_threads, resized_image); auto image_embed = llm->encode_image(n_threads, resized_image, false, true, true);
GGML_ASSERT(!image_embed.empty()); GGML_ASSERT(!image_embed.empty());
image_embeds.emplace_back(image_embed_idx, image_embed); image_embeds.emplace_back(image_embed_idx, image_embed);
image_embed_idx += 1 + static_cast<int>(image_embed.shape()[1]) + 6; image_embed_idx += 1 + static_cast<int>(image_embed.shape()[1]) + 6;
@ -1895,7 +1863,7 @@ struct LLMEmbedder : public Conditioner {
LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, height, width, h_bar, w_bar); LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, height, width, h_bar, w_bar);
auto resized_image = clip_preprocess(image, w_bar, h_bar); auto resized_image = clip_preprocess(image, w_bar, h_bar);
auto image_embed = llm->encode_image(n_threads, resized_image); auto image_embed = llm->encode_image(n_threads, resized_image, false, true, true);
GGML_ASSERT(!image_embed.empty()); GGML_ASSERT(!image_embed.empty());
image_embeds.emplace_back(image_embed_idx, image_embed); image_embeds.emplace_back(image_embed_idx, image_embed);
image_embed_idx += 1 + static_cast<int>(image_embed.shape()[1]) + 6; image_embed_idx += 1 + static_cast<int>(image_embed.shape()[1]) + 6;
@ -2163,11 +2131,15 @@ struct LTXAVTextProjectionRunner : public GGMLRunner {
return gf; return gf;
} }
sd::Tensor<float> compute(int n_threads, const sd::Tensor<float>& x) { sd::Tensor<float> compute(int n_threads,
const sd::Tensor<float>& x,
bool auto_free = true,
bool free_compute_buffer = true,
bool free_compute_params = true) {
auto get_graph = [&]() -> ggml_cgraph* { auto get_graph = [&]() -> ggml_cgraph* {
return build_graph(x); return build_graph(x);
}; };
return take_or_empty(GGMLRunner::compute<float>(get_graph, n_threads, true)); return take_or_empty(GGMLRunner::compute<float>(get_graph, n_threads, auto_free, free_compute_buffer, free_compute_params));
} }
}; };
@ -2205,25 +2177,6 @@ struct LTXAVEmbedder : public Conditioner {
projector->get_param_tensors(tensors, "text_embedding_projection"); projector->get_param_tensors(tensors, "text_embedding_projection");
} }
bool alloc_params_buffer() override {
if (!llm->alloc_params_buffer()) {
return false;
}
if (!projector->alloc_params_buffer()) {
return false;
}
return true;
}
void free_params_buffer() override {
llm->free_params_buffer();
projector->free_params_buffer();
}
size_t get_params_buffer_size() override {
return llm->get_params_buffer_size() + projector->get_params_buffer_size();
}
void set_flash_attention_enabled(bool enabled) override { void set_flash_attention_enabled(bool enabled) override {
llm->set_flash_attention_enabled(enabled); llm->set_flash_attention_enabled(enabled);
projector->set_flash_attention_enabled(enabled); projector->set_flash_attention_enabled(enabled);
@ -2239,6 +2192,16 @@ struct LTXAVEmbedder : public Conditioner {
projector->set_weight_adapter(adapter); projector->set_weight_adapter(adapter);
} }
void set_weight_manager(const std::shared_ptr<RunnerWeightManager>& manager) override {
llm->set_weight_manager(manager);
projector->set_weight_manager(manager);
}
void runner_done() override {
llm->runner_done();
projector->runner_done();
}
std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> tokenize(std::string text, std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> tokenize(std::string text,
const std::pair<int, int>& attn_range) { const std::pair<int, int>& attn_range) {
std::vector<std::pair<std::string, float>> parsed_attention; std::vector<std::pair<std::string, float>> parsed_attention;
@ -2302,6 +2265,9 @@ struct LTXAVEmbedder : public Conditioner {
attention_mask, attention_mask,
{}, {},
{}, {},
true,
false,
true,
true); true);
GGML_ASSERT(!hidden_states.empty()); GGML_ASSERT(!hidden_states.empty());
hidden_states = apply_token_weights(std::move(hidden_states), weights); hidden_states = apply_token_weights(std::move(hidden_states), weights);
@ -2361,7 +2327,7 @@ struct LTXAVEmbedder : public Conditioner {
} }
hidden_states.reshape_({kNumStates * kHiddenSize, valid_tokens}); hidden_states.reshape_({kNumStates * kHiddenSize, valid_tokens});
return projector->compute(n_threads, hidden_states); return projector->compute(n_threads, hidden_states, false, true, true);
} }
SDCondition get_learned_condition(int n_threads, SDCondition get_learned_condition(int n_threads,

File diff suppressed because it is too large Load Diff

View File

@ -44,7 +44,9 @@ namespace sd::ggml_graph_cut {
if (tensor == nullptr) { if (tensor == nullptr) {
return false; return false;
} }
return params_tensor_set.find(tensor) != params_tensor_set.end(); return params_tensor_set.find(tensor) != params_tensor_set.end() ||
(tensor->view_src != nullptr &&
params_tensor_set.find(tensor->view_src) != params_tensor_set.end());
} }
static int graph_node_index_by_name(ggml_cgraph* gf, const char* name) { static int graph_node_index_by_name(ggml_cgraph* gf, const char* name) {
@ -135,6 +137,24 @@ namespace sd::ggml_graph_cut {
return max_vram_bytes_to_gib(resolve_auto_max_vram_bytes(-max_vram, backend)); return max_vram_bytes_to_gib(resolve_auto_max_vram_bytes(-max_vram, backend));
} }
static bool is_segment_output_needed_after(const Plan& plan,
size_t end_segment_index,
int output_node_index) {
if (end_segment_index + 1 >= plan.segments.size()) {
return false;
}
for (size_t seg_idx = end_segment_index + 1; seg_idx < plan.segments.size(); ++seg_idx) {
const auto& segment = plan.segments[seg_idx];
for (const auto& input_ref : segment.input_refs) {
if (input_ref.type == Segment::INPUT_PREVIOUS_CUT &&
input_ref.node_index == output_node_index) {
return true;
}
}
}
return false;
}
static Segment make_segment_seed(const Plan& plan, static Segment make_segment_seed(const Plan& plan,
size_t start_segment_index, size_t start_segment_index,
size_t end_segment_index) { size_t end_segment_index) {
@ -147,8 +167,11 @@ namespace sd::ggml_graph_cut {
const auto& target_segment = plan.segments[end_segment_index]; const auto& target_segment = plan.segments[end_segment_index];
std::unordered_set<int> seen_output_node_indices; std::unordered_set<int> seen_output_node_indices;
for (size_t seg_idx = start_segment_index; seg_idx <= end_segment_index; ++seg_idx) { for (size_t seg_idx = start_segment_index; seg_idx <= end_segment_index; ++seg_idx) {
const bool is_boundary_segment = seg_idx == end_segment_index;
for (int output_node_index : plan.segments[seg_idx].output_node_indices) { for (int output_node_index : plan.segments[seg_idx].output_node_indices) {
if (seen_output_node_indices.insert(output_node_index).second) { if ((is_boundary_segment ||
is_segment_output_needed_after(plan, end_segment_index, output_node_index)) &&
seen_output_node_indices.insert(output_node_index).second) {
seed.output_node_indices.push_back(output_node_index); seed.output_node_indices.push_back(output_node_index);
} }
} }
@ -400,23 +423,6 @@ namespace sd::ggml_graph_cut {
return tensors; return tensors;
} }
std::vector<ggml_tensor*> runtime_param_tensors(ggml_cgraph* gf, const Segment& segment, const char* log_desc) {
std::vector<ggml_tensor*> tensors = param_tensors(gf, segment);
std::vector<ggml_tensor*> filtered_tensors;
filtered_tensors.reserve(tensors.size());
for (ggml_tensor* tensor : tensors) {
if (tensor_buffer(tensor) == nullptr) {
LOG_WARN("%s graph cut skipping param input without buffer: segment=%s tensor=%s",
log_desc == nullptr ? "unknown" : log_desc,
segment.group_name.c_str(),
tensor->name);
continue;
}
filtered_tensors.push_back(tensor);
}
return filtered_tensors;
}
std::unordered_set<std::string> collect_future_input_names(ggml_cgraph* gf, std::unordered_set<std::string> collect_future_input_names(ggml_cgraph* gf,
const Plan& plan, const Plan& plan,
size_t current_segment_index) { size_t current_segment_index) {
@ -487,6 +493,44 @@ namespace sd::ggml_graph_cut {
return 0; return 0;
} }
struct TensorRuntimeBinding {
ggml_backend_buffer_t buffer = nullptr;
void* data = nullptr;
void* extra = nullptr;
};
std::unordered_map<ggml_tensor*, TensorRuntimeBinding> saved_bindings;
auto mark_measurement_external = [&](ggml_tensor* tensor) {
if (tensor == nullptr) {
return;
}
auto save_tensor = [&](ggml_tensor* t) {
if (t == nullptr || saved_bindings.find(t) != saved_bindings.end()) {
return;
}
saved_bindings[t] = {t->buffer, t->data, t->extra};
// During real execution params and previous-cut inputs already
// have backend/cache buffers, so gallocr must not reserve them.
t->data = reinterpret_cast<void*>(static_cast<uintptr_t>(1));
};
save_tensor(tensor);
save_tensor(tensor->view_src);
};
for (const auto& input : segment.input_refs) {
if (input.type != Segment::INPUT_PARAM &&
input.type != Segment::INPUT_PREVIOUS_CUT) {
continue;
}
mark_measurement_external(input_tensor(gf, input));
}
std::unordered_map<ggml_tensor*, int32_t> saved_output_flags;
for (int output_node_index : segment.output_node_indices) {
ggml_tensor* output = ggml_graph_node(gf, output_node_index);
if (output != nullptr && saved_output_flags.find(output) == saved_output_flags.end()) {
saved_output_flags[output] = output->flags;
}
}
ggml_context* graph_ctx = nullptr; ggml_context* graph_ctx = nullptr;
ggml_cgraph* segment_graph = build_segment_graph(gf, segment, &graph_ctx); ggml_cgraph* segment_graph = build_segment_graph(gf, segment, &graph_ctx);
ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
@ -502,6 +546,14 @@ namespace sd::ggml_graph_cut {
ggml_gallocr_free(allocr); ggml_gallocr_free(allocr);
ggml_free(graph_ctx); ggml_free(graph_ctx);
for (const auto& kv : saved_output_flags) {
kv.first->flags = kv.second;
}
for (const auto& kv : saved_bindings) {
kv.first->buffer = kv.second.buffer;
kv.first->data = kv.second.data;
kv.first->extra = kv.second.extra;
}
return buffer_size; return buffer_size;
} }
@ -669,7 +721,8 @@ namespace sd::ggml_graph_cut {
GGML_ASSERT(!candidate_plan.segments.empty()); GGML_ASSERT(!candidate_plan.segments.empty());
const auto& candidate_segment = candidate_plan.segments.back(); const auto& candidate_segment = candidate_plan.segments.back();
if (graph_cut_segment_vram_bytes(candidate_segment) > max_graph_vram_bytes) { const size_t candidate_bytes = graph_cut_segment_vram_bytes(candidate_segment);
if (candidate_bytes > max_graph_vram_bytes) {
break; break;
} }

View File

@ -80,7 +80,6 @@ namespace sd::ggml_graph_cut {
ggml_tensor* output_tensor(ggml_cgraph* gf, const Segment& segment, size_t output_index); ggml_tensor* output_tensor(ggml_cgraph* gf, const Segment& segment, size_t output_index);
ggml_tensor* input_tensor(ggml_cgraph* gf, const Segment::InputRef& input_ref); ggml_tensor* input_tensor(ggml_cgraph* gf, const Segment::InputRef& input_ref);
std::vector<ggml_tensor*> param_tensors(ggml_cgraph* gf, const Segment& segment); std::vector<ggml_tensor*> param_tensors(ggml_cgraph* gf, const Segment& segment);
std::vector<ggml_tensor*> runtime_param_tensors(ggml_cgraph* gf, const Segment& segment, const char* log_desc);
std::unordered_set<std::string> collect_future_input_names(ggml_cgraph* gf, std::unordered_set<std::string> collect_future_input_names(ggml_cgraph* gf,
const Plan& plan, const Plan& plan,
size_t current_segment_index); size_t current_segment_index);

View File

@ -1,132 +0,0 @@
#include "core/layer_registry.h"
#include <utility>
#include "core/util.h"
namespace sd::layer_registry {
void LayerRegistry::register_layer(const std::string& name, ggml_tensor* tensor) {
auto& info = layers_[name];
info.tensors.push_back(tensor);
info.bytes += ggml_nbytes(tensor);
}
bool LayerRegistry::move_layer_to_gpu(const std::string& name) {
auto it = layers_.find(name);
if (it == layers_.end())
return false;
LayerInfo& info = it->second;
if (info.on_gpu)
return true;
if (gpu_backend_ == nullptr || cpu_backend_ == nullptr) {
LOG_ERROR("layer_registry: backends not set; cannot move '%s' to GPU",
name.c_str());
return false;
}
if (info.tensors.empty()) {
info.on_gpu = true;
return true;
}
// 1. Build a no_alloc context big enough to hold one twin tensor per CPU
// tensor, plus a little overhead.
const size_t ctx_size = info.tensors.size() * ggml_tensor_overhead() + 1024;
ggml_init_params ctx_params{ctx_size, /*mem_buffer=*/nullptr, /*no_alloc=*/true};
ggml_context* twin_ctx = ggml_init(ctx_params);
if (twin_ctx == nullptr) {
LOG_ERROR("layer_registry: failed to allocate twin context for '%s'",
name.c_str());
return false;
}
// 2. Create one GPU twin per CPU tensor. The twin shares the original
// name so any name-based lookup keeps working.
std::vector<ggml_tensor*> gpu_twins;
gpu_twins.reserve(info.tensors.size());
for (ggml_tensor* cpu_t : info.tensors) {
ggml_tensor* twin = ggml_dup_tensor(twin_ctx, cpu_t);
if (cpu_t->name[0] != '\0') {
ggml_set_name(twin, cpu_t->name);
}
gpu_twins.push_back(twin);
}
// 3. Back the twins with a GPU buffer in one alloc call.
ggml_backend_buffer_t gpu_buffer = ggml_backend_alloc_ctx_tensors(twin_ctx, gpu_backend_);
if (gpu_buffer == nullptr) {
LOG_ERROR("layer_registry: failed to allocate GPU buffer for '%s'",
name.c_str());
ggml_free(twin_ctx);
return false;
}
// 4. H2D copy + sync.
for (size_t i = 0; i < info.tensors.size(); ++i) {
ggml_backend_tensor_copy(info.tensors[i], gpu_twins[i]);
}
ggml_backend_synchronize(gpu_backend_);
// 5. Swap buffer/data/extra so the originals now point at GPU memory.
for (size_t i = 0; i < info.tensors.size(); ++i) {
std::swap(info.tensors[i]->buffer, gpu_twins[i]->buffer);
std::swap(info.tensors[i]->data, gpu_twins[i]->data);
std::swap(info.tensors[i]->extra, gpu_twins[i]->extra);
}
info.gpu_twins = std::move(gpu_twins);
info.twin_ctx = twin_ctx;
info.gpu_buffer = gpu_buffer;
info.on_gpu = true;
return true;
}
bool LayerRegistry::move_layer_to_cpu(const std::string& name) {
auto it = layers_.find(name);
if (it == layers_.end())
return false;
LayerInfo& info = it->second;
if (!info.on_gpu)
return true;
if (info.tensors.size() != info.gpu_twins.size()) {
LOG_ERROR("layer_registry: twin/tensor count mismatch for '%s'",
name.c_str());
return false;
}
// 1. Swap back: originals point at CPU memory again.
for (size_t i = 0; i < info.tensors.size(); ++i) {
if (info.gpu_twins[i] == nullptr)
continue;
std::swap(info.tensors[i]->buffer, info.gpu_twins[i]->buffer);
std::swap(info.tensors[i]->data, info.gpu_twins[i]->data);
std::swap(info.tensors[i]->extra, info.gpu_twins[i]->extra);
}
// 2. Free the GPU buffer + twin context.
if (info.gpu_buffer != nullptr) {
ggml_backend_buffer_free(info.gpu_buffer);
info.gpu_buffer = nullptr;
}
if (info.twin_ctx != nullptr) {
ggml_free(info.twin_ctx);
info.twin_ctx = nullptr;
}
info.gpu_twins.clear();
info.on_gpu = false;
return true;
}
bool LayerRegistry::is_layer_on_gpu(const std::string& name) const {
auto it = layers_.find(name);
return it != layers_.end() && it->second.on_gpu;
}
size_t LayerRegistry::get_layer_size(const std::string& name) const {
auto it = layers_.find(name);
return it != layers_.end() ? it->second.bytes : 0;
}
} // namespace sd::layer_registry

View File

@ -1,50 +0,0 @@
#ifndef __SD_CORE_LAYER_REGISTRY_H__
#define __SD_CORE_LAYER_REGISTRY_H__
#include <map>
#include <set>
#include <string>
#include <vector>
#include "ggml-backend.h"
#include "ggml.h"
namespace sd::layer_registry {
struct LayerInfo {
std::vector<ggml_tensor*> tensors;
std::vector<ggml_tensor*> gpu_twins;
ggml_context* twin_ctx = nullptr;
ggml_backend_buffer_t gpu_buffer = nullptr;
bool on_gpu = false;
size_t bytes = 0;
};
class LayerRegistry {
public:
LayerRegistry() = default;
LayerRegistry(ggml_backend_t gpu_backend, ggml_backend_t cpu_backend)
: gpu_backend_(gpu_backend), cpu_backend_(cpu_backend) {}
void set_backends(ggml_backend_t gpu_backend, ggml_backend_t cpu_backend) {
gpu_backend_ = gpu_backend;
cpu_backend_ = cpu_backend;
}
void register_layer(const std::string& name, ggml_tensor* tensor);
bool move_layer_to_gpu(const std::string& name);
bool move_layer_to_cpu(const std::string& name);
bool is_layer_on_gpu(const std::string& name) const;
size_t get_layer_size(const std::string& name) const;
size_t get_layer_count() const { return layers_.size(); }
const std::map<std::string, LayerInfo>& layers() const { return layers_; }
private:
ggml_backend_t gpu_backend_ = nullptr;
ggml_backend_t cpu_backend_ = nullptr;
std::map<std::string, LayerInfo> layers_;
};
} // namespace sd::layer_registry
#endif // __SD_CORE_LAYER_REGISTRY_H__

View File

@ -488,7 +488,7 @@ bool parse_strict_bool(const std::string& text, bool& value) {
return false; return false;
} }
static std::string build_progress_bar(int step, int steps) { static std::string build_progress_bar(int step, int steps, char progress_char = '=', bool show_head = true) {
std::string progress = " |"; std::string progress = " |";
int max_progress = 50; int max_progress = 50;
int32_t current = 0; int32_t current = 0;
@ -498,21 +498,21 @@ static std::string build_progress_bar(int step, int steps) {
for (int i = 0; i < 50; i++) { for (int i = 0; i < 50; i++) {
if (i > current) { if (i > current) {
progress += " "; progress += " ";
} else if (i == current && i != max_progress - 1) { } else if (show_head && i == current && i != max_progress - 1) {
progress += ">"; progress += ">";
} else { } else {
progress += "="; progress += progress_char;
} }
} }
progress += "|"; progress += "|";
return progress; return progress;
} }
static void print_progress_line(int step, int steps, const std::string& speed_text) { static void print_progress_line(int step, int steps, const std::string& speed_text, char progress_char = '=', bool show_head = true) {
if (step == 0) { if (step == 0) {
return; return;
} }
std::string progress = build_progress_bar(step, steps); std::string progress = build_progress_bar(step, steps, progress_char, show_head);
const char* lf = (step == steps ? "\n" : ""); const char* lf = (step == steps ? "\n" : "");
printf("\r%s %i/%i - %s\033[K%s", progress.c_str(), step, steps, speed_text.c_str(), lf); printf("\r%s %i/%i - %s\033[K%s", progress.c_str(), step, steps, speed_text.c_str(), lf);
fflush(stdout); // for linux fflush(stdout); // for linux
@ -552,9 +552,9 @@ void pretty_bytes_progress(int step, int steps, uint64_t bytes_processed, float
double speed_mb = bytes_per_second / (1024.0 * 1024.0); double speed_mb = bytes_per_second / (1024.0 * 1024.0);
if (speed_mb >= 1024.0) { if (speed_mb >= 1024.0) {
print_progress_line(step, steps, sd_format("%.2fGB/s", speed_mb / 1024.0)); print_progress_line(step, steps, sd_format("%.2fGB/s", speed_mb / 1024.0), '#', false);
} else { } else {
print_progress_line(step, steps, sd_format("%.2fMB/s", speed_mb)); print_progress_line(step, steps, sd_format("%.2fMB/s", speed_mb), '#', false);
} }
} }

View File

@ -6,10 +6,12 @@
#include <memory> #include <memory>
#include <set> #include <set>
#include <string> #include <string>
#include <vector>
#include "conditioning/conditioner.hpp" #include "conditioning/conditioner.hpp"
#include "core/ggml_extend_backend.h" #include "core/ggml_extend_backend.h"
#include "model_loader.h" #include "model_loader.h"
#include "model_manager.h"
#include "stable-diffusion.h" #include "stable-diffusion.h"
struct GenerationExtensionInitContext { struct GenerationExtensionInitContext {
@ -23,21 +25,12 @@ struct GenerationExtensionInitContext {
std::function<ggml_backend_t(SDBackendModule)> params_backend_for; std::function<ggml_backend_t(SDBackendModule)> params_backend_for;
}; };
struct GenerationExtensionTensorContext {
std::map<std::string, ggml_tensor*>& tensors;
std::map<std::string, ggml_tensor*>& mmap_able_tensors;
std::function<bool(SDBackendModule)> module_can_mmap;
};
struct GenerationExtensionConditionContext { struct GenerationExtensionConditionContext {
Conditioner* conditioner; Conditioner* conditioner;
ConditionerParams& condition_params; ConditionerParams& condition_params;
const sd_pm_params_t& pm_params; const sd_pm_params_t& pm_params;
std::map<std::string, ggml_tensor*>& tensors;
SDVersion version;
int n_threads; int n_threads;
int total_steps; int total_steps;
bool free_params_immediately;
}; };
struct GenerationExtension { struct GenerationExtension {
@ -50,14 +43,11 @@ struct GenerationExtension {
virtual bool init(const GenerationExtensionInitContext&) { virtual bool init(const GenerationExtensionInitContext&) {
return true; return true;
} }
virtual void collect_param_tensors(GenerationExtensionTensorContext&) {} virtual void get_param_tensors(std::map<std::string, ggml_tensor*>&) {}
virtual void collect_loras(std::vector<ModelManager::LoraSpec>&) {}
virtual void add_ignore_tensors(std::set<std::string>&) const {} virtual void add_ignore_tensors(std::set<std::string>&) const {}
virtual bool alloc_params_buffer() { virtual void set_weight_manager(const std::shared_ptr<RunnerWeightManager>&) {}
return true; virtual void runner_done() {}
}
virtual size_t get_params_buffer_size() const {
return 0;
}
virtual void reset_runtime_condition() {} virtual void reset_runtime_condition() {}
virtual bool prepare_condition(GenerationExtensionConditionContext&) { virtual bool prepare_condition(GenerationExtensionConditionContext&) {
return false; return false;

View File

@ -7,7 +7,6 @@
#include "core/tensor_ggml.hpp" #include "core/tensor_ggml.hpp"
#include "core/util.h" #include "core/util.h"
#include "model/adapter/lora.hpp"
#include "model/adapter/pmid.hpp" #include "model/adapter/pmid.hpp"
static std::tuple<std::vector<int>, std::vector<float>, std::vector<bool>> static std::tuple<std::vector<int>, std::vector<float>, std::vector<bool>>
@ -103,7 +102,6 @@ static std::string remove_photomaker_trigger_from_prompt(FrozenCLIPEmbedderWithC
struct PhotoMakerExtension : public GenerationExtension { struct PhotoMakerExtension : public GenerationExtension {
std::shared_ptr<PhotoMakerIDEncoder> pmid_model; std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
std::shared_ptr<LoraModel> pmid_lora;
bool enabled = false; bool enabled = false;
std::string model_path; std::string model_path;
std::string trigger_word = "img"; std::string trigger_word = "img";
@ -129,7 +127,13 @@ struct PhotoMakerExtension : public GenerationExtension {
} }
PMVersion pm_version = std::strstr(model_path.c_str(), "v2") != nullptr ? PM_VERSION_2 : PM_VERSION_1; PMVersion pm_version = std::strstr(model_path.c_str(), "v2") != nullptr ? PM_VERSION_2 : PM_VERSION_1;
pmid_model = std::make_shared<PhotoMakerIDEncoder>(ctx.backend_for(SDBackendModule::PHOTOMAKER), LOG_INFO("loading stacked ID embedding (PHOTOMAKER) model file from '%s'", model_path.c_str());
if (!ctx.model_loader.init_from_file_and_convert_name(model_path, "pmid.")) {
LOG_WARN("loading stacked ID embedding from '%s' failed", model_path.c_str());
return true;
}
pmid_model = std::make_shared<PhotoMakerIDEncoder>(ctx.backend_for(SDBackendModule::PHOTOMAKER),
ctx.params_backend_for(SDBackendModule::PHOTOMAKER), ctx.params_backend_for(SDBackendModule::PHOTOMAKER),
ctx.tensor_storage_map, ctx.tensor_storage_map,
"pmid", "pmid",
@ -139,44 +143,28 @@ struct PhotoMakerExtension : public GenerationExtension {
LOG_INFO("using PhotoMaker Version 2"); LOG_INFO("using PhotoMaker Version 2");
} }
pmid_lora = std::make_shared<LoraModel>("pmid",
ctx.backend_for(SDBackendModule::PHOTOMAKER),
ctx.params_backend_for(SDBackendModule::PHOTOMAKER),
model_path,
"",
ctx.version);
auto lora_tensor_filter = [&](const std::string& tensor_name) {
return starts_with(tensor_name, "lora.model");
};
if (!pmid_lora->load_from_file(ctx.n_threads, lora_tensor_filter)) {
LOG_WARN("load photomaker lora tensors from %s failed", model_path.c_str());
return false;
}
LOG_INFO("loading stacked ID embedding (PHOTOMAKER) model file from '%s'", model_path.c_str());
if (!ctx.model_loader.init_from_file_and_convert_name(model_path, "pmid.")) {
LOG_WARN("loading stacked ID embedding from '%s' failed", model_path.c_str());
return true;
}
enabled = true; enabled = true;
return true; return true;
} }
void collect_param_tensors(GenerationExtensionTensorContext& ctx) override { void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) override {
if (!enabled || pmid_model == nullptr) { if (!enabled || pmid_model == nullptr) {
return; return;
} }
std::map<std::string, ggml_tensor*> temp; pmid_model->get_param_tensors(tensors, "pmid");
pmid_model->get_param_tensors(temp, "pmid"); }
bool do_mmap = ctx.module_can_mmap(SDBackendModule::PHOTOMAKER);
for (const auto& [key, tensor] : temp) { void collect_loras(std::vector<ModelManager::LoraSpec>& loras) override {
ctx.tensors[key] = tensor; if (!enabled || model_path.empty()) {
if (do_mmap) { return;
ctx.mmap_able_tensors[key] = tensor;
}
} }
ModelManager::LoraSpec lora;
lora.path = model_path;
lora.multiplier = 1.0f;
lora.tensor_name_prefix_filter = "lora.model";
lora.required = true;
loras.push_back(std::move(lora));
} }
void add_ignore_tensors(std::set<std::string>& ignore_tensors) const override { void add_ignore_tensors(std::set<std::string>& ignore_tensors) const override {
@ -186,18 +174,16 @@ struct PhotoMakerExtension : public GenerationExtension {
ignore_tensors.insert("pmid.unet."); ignore_tensors.insert("pmid.unet.");
} }
bool alloc_params_buffer() override { void set_weight_manager(const std::shared_ptr<RunnerWeightManager>& manager) override {
if (!enabled || pmid_model == nullptr) { if (pmid_model != nullptr) {
return true; pmid_model->set_weight_manager(manager);
} }
return pmid_model->alloc_params_buffer();
} }
size_t get_params_buffer_size() const override { void runner_done() override {
if (!enabled || pmid_model == nullptr) { if (pmid_model != nullptr) {
return 0; pmid_model->runner_done();
} }
return pmid_model->get_params_buffer_size();
} }
void reset_runtime_condition() override { void reset_runtime_condition() override {
@ -207,21 +193,10 @@ struct PhotoMakerExtension : public GenerationExtension {
bool prepare_condition(GenerationExtensionConditionContext& ctx) override { bool prepare_condition(GenerationExtensionConditionContext& ctx) override {
reset_runtime_condition(); reset_runtime_condition();
if (!enabled || pmid_model == nullptr || pmid_lora == nullptr) { if (!enabled || pmid_model == nullptr) {
return false; return false;
} }
if (!pmid_lora->applied) {
int64_t t0 = ggml_time_ms();
pmid_lora->apply(ctx.tensors, ctx.version, ctx.n_threads);
int64_t t1 = ggml_time_ms();
pmid_lora->applied = true;
LOG_INFO("pmid_lora apply completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
if (ctx.free_params_immediately) {
pmid_lora->free_params_buffer();
}
}
bool pmv2 = pmid_model->get_version() == PM_VERSION_2; bool pmv2 = pmid_model->get_version() == PM_VERSION_2;
if (ctx.pm_params.id_images_count <= 0 || ctx.pm_params.id_images == nullptr) { if (ctx.pm_params.id_images_count <= 0 || ctx.pm_params.id_images == nullptr) {
LOG_WARN("Provided PhotoMaker model file, but NO input ID images"); LOG_WARN("Provided PhotoMaker model file, but NO input ID images");
@ -305,9 +280,6 @@ struct PhotoMakerExtension : public GenerationExtension {
LOG_INFO("Photomaker ID Stacking, taking %" PRId64 " ms", t1 - t0); LOG_INFO("Photomaker ID Stacking, taking %" PRId64 " ms", t1 - t0);
LOG_INFO("PHOTOMAKER: start_merge_step: %d", start_merge_step); LOG_INFO("PHOTOMAKER: start_merge_step: %d", start_merge_step);
if (ctx.free_params_immediately) {
pmid_model->free_params_buffer();
}
return true; return true;
} }

View File

@ -71,7 +71,8 @@ struct LoraModel : public GGMLRunner {
return true; return true;
}; };
model_loader.load_tensors(on_new_tensor_cb, n_threads); model_loader.set_n_threads(n_threads);
model_loader.load_tensors(on_new_tensor_cb);
if (tensors_to_create.empty()) { if (tensors_to_create.empty()) {
return true; return true;
@ -93,19 +94,39 @@ struct LoraModel : public GGMLRunner {
} }
dry_run = false; dry_run = false;
model_loader.load_tensors(on_new_tensor_cb, n_threads); model_loader.load_tensors(on_new_tensor_cb);
LOG_DEBUG("finished loaded lora"); LOG_DEBUG("finished loaded lora");
return true; return true;
} }
void preprocess_lora_tensors(const std::map<std::string, ggml_tensor*>& model_tensors) { void release_loaded_tensors() {
free_compute_buffer();
free_params_buffer();
free_params_ctx();
alloc_params_ctx();
lora_tensors.clear();
original_tensor_to_final_tensor.clear();
applied_lora_tensors.clear();
applied = false;
tensor_preprocessed = false;
}
static std::set<std::string> tensor_names(const std::map<std::string, ggml_tensor*>& model_tensors) {
std::set<std::string> names;
for (const auto& item : model_tensors) {
names.insert(item.first);
}
return names;
}
void preprocess_lora_tensors(const std::set<std::string>& model_tensor_names) {
if (tensor_preprocessed) { if (tensor_preprocessed) {
return; return;
} }
tensor_preprocessed = true; tensor_preprocessed = true;
// I really hate these hardcoded processes. // I really hate these hardcoded processes.
if (model_tensors.find("cond_stage_model.1.transformer.text_model.encoder.layers.0.self_attn.in_proj.weight") != model_tensors.end()) { if (model_tensor_names.find("cond_stage_model.1.transformer.text_model.encoder.layers.0.self_attn.in_proj.weight") != model_tensor_names.end()) {
std::unordered_map<std::string, ggml_tensor*> new_lora_tensors; std::unordered_map<std::string, ggml_tensor*> new_lora_tensors;
for (auto& [old_name, tensor] : lora_tensors) { for (auto& [old_name, tensor] : lora_tensors) {
std::string new_name = old_name; std::string new_name = old_name;
@ -753,11 +774,13 @@ struct LoraModel : public GGMLRunner {
return out_diff; return out_diff;
} }
ggml_cgraph* build_lora_graph(const std::map<std::string, ggml_tensor*>& model_tensors, SDVersion version) { ggml_cgraph* build_lora_graph(const std::map<std::string, ggml_tensor*>& model_tensors,
const std::set<std::string>& model_tensor_names,
SDVersion version) {
size_t lora_graph_size = LORA_GRAPH_BASE_SIZE + lora_tensors.size() * 10; size_t lora_graph_size = LORA_GRAPH_BASE_SIZE + lora_tensors.size() * 10;
ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, lora_graph_size, false); ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, lora_graph_size, false);
preprocess_lora_tensors(model_tensors); preprocess_lora_tensors(model_tensor_names);
original_tensor_to_final_tensor.clear(); original_tensor_to_final_tensor.clear();
applied_lora_tensors.clear(); applied_lora_tensors.clear();
@ -794,12 +817,16 @@ struct LoraModel : public GGMLRunner {
return gf; return gf;
} }
void apply(std::map<std::string, ggml_tensor*> model_tensors, SDVersion version, int n_threads) { void apply(std::map<std::string, ggml_tensor*> model_tensors,
const std::set<std::string>& model_tensor_names,
SDVersion version,
int n_threads,
bool warn_unused = true) {
auto get_graph = [&]() -> ggml_cgraph* { auto get_graph = [&]() -> ggml_cgraph* {
return build_lora_graph(model_tensors, version); return build_lora_graph(model_tensors, model_tensor_names, version);
}; };
GGMLRunner::compute<float>(get_graph, n_threads, false, true); GGMLRunner::compute<float>(get_graph, n_threads, false, false, false, true);
stat(); stat(!warn_unused);
for (auto item : original_tensor_to_final_tensor) { for (auto item : original_tensor_to_final_tensor) {
ggml_tensor* original_tensor = item.first; ggml_tensor* original_tensor = item.first;
ggml_tensor* final_tensor = item.second; ggml_tensor* final_tensor = item.second;
@ -810,6 +837,10 @@ struct LoraModel : public GGMLRunner {
GGMLRunner::free_compute_buffer(); GGMLRunner::free_compute_buffer();
} }
void apply(std::map<std::string, ggml_tensor*> model_tensors, SDVersion version, int n_threads, bool warn_unused = true) {
apply(model_tensors, tensor_names(model_tensors), version, n_threads, warn_unused);
}
void stat(bool at_runntime = false) { void stat(bool at_runntime = false) {
size_t total_lora_tensors_count = 0; size_t total_lora_tensors_count = 0;
size_t applied_lora_tensors_count = 0; size_t applied_lora_tensors_count = 0;

View File

@ -558,7 +558,7 @@ public:
return build_graph(id_pixel_values, prompt_embeds, class_tokens_mask, id_embeds); return build_graph(id_pixel_values, prompt_embeds, class_tokens_mask, id_embeds);
}; };
return take_or_empty(GGMLRunner::compute<float>(get_graph, n_threads, true)); return take_or_empty(GGMLRunner::compute<float>(get_graph, n_threads, true, true, true));
} }
}; };
@ -616,14 +616,15 @@ struct PhotoMakerIDEmbed : public GGMLRunner {
return true; return true;
}; };
model_loader->load_tensors(on_new_tensor_cb, n_threads); model_loader->set_n_threads(n_threads);
model_loader->load_tensors(on_new_tensor_cb);
if (!alloc_params_buffer()) { if (!alloc_params_buffer()) {
LOG_ERROR("PhotoMaker ID embeds buffer allocation failed"); LOG_ERROR("PhotoMaker ID embeds buffer allocation failed");
return false; return false;
} }
dry_run = false; dry_run = false;
model_loader->load_tensors(on_new_tensor_cb, n_threads); model_loader->load_tensors(on_new_tensor_cb);
LOG_DEBUG("finished loading PhotoMaker ID Embeds "); LOG_DEBUG("finished loading PhotoMaker ID Embeds ");
return true; return true;

View File

@ -697,7 +697,7 @@ namespace Anima {
auto get_graph = [&]() -> ggml_cgraph* { auto get_graph = [&]() -> ggml_cgraph* {
return build_graph(x, timesteps, context, t5_ids, t5_weights); return build_graph(x, timesteps, context, t5_ids, t5_weights);
}; };
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), x.dim()); return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false, false, false), x.dim());
} }
sd::Tensor<float> compute(int n_threads, sd::Tensor<float> compute(int n_threads,

View File

@ -309,6 +309,7 @@ public:
struct ControlNet : public GGMLRunner { struct ControlNet : public GGMLRunner {
SDVersion version = VERSION_SD1; SDVersion version = VERSION_SD1;
ControlNetBlock control_net; ControlNetBlock control_net;
std::string weight_prefix;
ggml_backend_buffer_t control_buffer = nullptr; ggml_backend_buffer_t control_buffer = nullptr;
ggml_context* control_ctx = nullptr; ggml_context* control_ctx = nullptr;
@ -321,9 +322,10 @@ struct ControlNet : public GGMLRunner {
ControlNet(ggml_backend_t backend, ControlNet(ggml_backend_t backend,
ggml_backend_t params_backend, ggml_backend_t params_backend,
const String2TensorStorage& tensor_storage_map = {}, const String2TensorStorage& tensor_storage_map = {},
SDVersion version = VERSION_SD1) SDVersion version = VERSION_SD1,
: GGMLRunner(backend, params_backend), control_net(version) { const std::string& prefix = "")
control_net.init(params_ctx, tensor_storage_map, ""); : GGMLRunner(backend, params_backend), version(version), control_net(version), weight_prefix(prefix) {
control_net.init(params_ctx, tensor_storage_map, prefix);
} }
~ControlNet() override { ~ControlNet() override {
@ -374,8 +376,8 @@ struct ControlNet : public GGMLRunner {
return "control_net"; return "control_net";
} }
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) { void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) {
control_net.get_param_tensors(tensors, prefix); control_net.get_param_tensors(tensors, weight_prefix);
} }
ggml_cgraph* build_graph(const sd::Tensor<float>& x_tensor, ggml_cgraph* build_graph(const sd::Tensor<float>& x_tensor,
@ -435,7 +437,7 @@ struct ControlNet : public GGMLRunner {
return build_graph(x, hint, timesteps, context, y); return build_graph(x, hint, timesteps, context, y);
}; };
auto compute_result = GGMLRunner::compute<float>(get_graph, n_threads, false); auto compute_result = GGMLRunner::compute<float>(get_graph, n_threads, false, false, false);
if (!compute_result.has_value()) { if (!compute_result.has_value()) {
return std::nullopt; return std::nullopt;
} }
@ -472,7 +474,8 @@ struct ControlNet : public GGMLRunner {
return false; return false;
} }
bool success = model_loader.load_tensors(tensors, ignore_tensors, n_threads); model_loader.set_n_threads(n_threads);
bool success = model_loader.load_tensors(tensors, ignore_tensors);
if (!success) { if (!success) {
LOG_ERROR("load control net tensors from model loader failed"); LOG_ERROR("load control net tensors from model loader failed");

View File

@ -440,7 +440,7 @@ namespace ErnieImage {
auto get_graph = [&]() -> ggml_cgraph* { auto get_graph = [&]() -> ggml_cgraph* {
return build_graph(x, timesteps, context); return build_graph(x, timesteps, context);
}; };
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), x.dim()); return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false, false, false), x.dim());
} }
sd::Tensor<float> compute(int n_threads, sd::Tensor<float> compute(int n_threads,

View File

@ -1500,7 +1500,7 @@ namespace Flux {
return build_graph(x, timesteps, context, c_concat, y, guidance, ref_latents, increase_ref_index, skip_layers); return build_graph(x, timesteps, context, c_concat, y, guidance, ref_latents, increase_ref_index, skip_layers);
}; };
auto result = restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), x.dim()); auto result = restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false, false, false), x.dim());
return result; return result;
} }

View File

@ -323,11 +323,15 @@ namespace HiDreamO1 {
return gf; return gf;
} }
sd::Tensor<float> compute(int n_threads, const sd::Tensor<float>& image) { sd::Tensor<float> compute(int n_threads,
const sd::Tensor<float>& image,
bool auto_free = true,
bool free_compute_buffer = true,
bool free_compute_params = true) {
auto get_graph = [&]() { auto get_graph = [&]() {
return build_graph(image); return build_graph(image);
}; };
auto output = GGMLRunner::compute<float>(get_graph, n_threads, false); auto output = GGMLRunner::compute<float>(get_graph, n_threads, auto_free, free_compute_buffer, free_compute_params);
return output.has_value() ? std::move(output.value()) : sd::Tensor<float>(); return output.has_value() ? std::move(output.value()) : sd::Tensor<float>();
} }
}; };
@ -455,7 +459,7 @@ namespace HiDreamO1 {
auto get_graph = [&]() { auto get_graph = [&]() {
return build_graph(x, timestep, input_ids, input_pos, token_types, vinput_mask, image_embeds, ref_images); return build_graph(x, timestep, input_ids, input_pos, token_types, vinput_mask, image_embeds, ref_images);
}; };
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), x.dim()); return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false, false, false), x.dim());
} }
sd::Tensor<float> compute(int n_threads, sd::Tensor<float> compute(int n_threads,
@ -494,21 +498,6 @@ namespace HiDreamO1 {
vision_runner->get_param_tensors(tensors); vision_runner->get_param_tensors(tensors);
} }
bool alloc_params_buffer() override {
if (!vision_runner->alloc_params_buffer()) {
return false;
}
return true;
}
void free_params_buffer() override {
vision_runner->free_params_buffer();
}
size_t get_params_buffer_size() override {
return vision_runner->get_params_buffer_size();
}
void set_max_graph_vram_bytes(size_t max_graph_vram_bytes) override { void set_max_graph_vram_bytes(size_t max_graph_vram_bytes) override {
vision_runner->set_max_graph_vram_bytes(max_graph_vram_bytes); vision_runner->set_max_graph_vram_bytes(max_graph_vram_bytes);
} }
@ -521,6 +510,14 @@ namespace HiDreamO1 {
vision_runner->set_weight_adapter(adapter); vision_runner->set_weight_adapter(adapter);
} }
void set_weight_manager(const std::shared_ptr<RunnerWeightManager>& manager) override {
vision_runner->set_weight_manager(manager);
}
void runner_done() override {
vision_runner->runner_done();
}
SDCondition get_learned_condition(int n_threads, SDCondition get_learned_condition(int n_threads,
const ConditionerParams& conditioner_params) override { const ConditionerParams& conditioner_params) override {
SDCondition result; SDCondition result;
@ -666,7 +663,7 @@ namespace HiDreamO1 {
result.c_vinput_mask = sd::Tensor<int32_t>(vinput_mask_shape, std::move(vinput_mask)); result.c_vinput_mask = sd::Tensor<int32_t>(vinput_mask_shape, std::move(vinput_mask));
result.c_image_embeds.reserve(vlm_images.size()); result.c_image_embeds.reserve(vlm_images.size());
for (const auto& vlm_image : vlm_images) { for (const auto& vlm_image : vlm_images) {
auto image_embed = vision_runner->compute(n_threads, vlm_image.second); auto image_embed = vision_runner->compute(n_threads, vlm_image.second, false, true, true);
if (image_embed.empty()) { if (image_embed.empty()) {
LOG_ERROR("hidream_o1 conditioner: encode VLM image failed"); LOG_ERROR("hidream_o1 conditioner: encode VLM image failed");
return SDCondition(); return SDCondition();

View File

@ -537,7 +537,7 @@ namespace Ideogram4 {
auto get_graph = [&]() -> ggml_cgraph* { auto get_graph = [&]() -> ggml_cgraph* {
return build_graph(x, timesteps, context, use_uncond_model); return build_graph(x, timesteps, context, use_uncond_model);
}; };
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), x.dim()); return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false, false, false), x.dim());
} }
sd::Tensor<float> compute(int n_threads, sd::Tensor<float> compute(int n_threads,

View File

@ -408,7 +408,7 @@ namespace Lens {
auto get_graph = [&]() -> ggml_cgraph* { auto get_graph = [&]() -> ggml_cgraph* {
return build_graph(x, timesteps, context); return build_graph(x, timesteps, context);
}; };
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), x.dim()); return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false, false, false), x.dim());
} }
sd::Tensor<float> compute(int n_threads, sd::Tensor<float> compute(int n_threads,

View File

@ -1939,7 +1939,7 @@ namespace LTXV {
auto get_graph = [&]() -> ggml_cgraph* { auto get_graph = [&]() -> ggml_cgraph* {
return build_graph(x, timesteps, context, audio_x, audio_timesteps, audio_length, frame_rate, video_positions); return build_graph(x, timesteps, context, audio_x, audio_timesteps, audio_length, frame_rate, video_positions);
}; };
auto out = restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), x.dim()); auto out = restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false, false, false), x.dim());
return out; return out;
} }

View File

@ -935,7 +935,7 @@ struct MMDiTRunner : public DiffusionModelRunner {
return build_graph(x, timesteps, context, y, skip_layers); return build_graph(x, timesteps, context, y, skip_layers);
}; };
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), x.dim()); return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false, false, false), x.dim());
} }
sd::Tensor<float> compute(int n_threads, sd::Tensor<float> compute(int n_threads,

View File

@ -823,7 +823,7 @@ namespace Pid {
auto get_graph = [&]() -> ggml_cgraph* { auto get_graph = [&]() -> ggml_cgraph* {
return build_graph(x, timesteps, context, lq_latent, degrade_sigma); return build_graph(x, timesteps, context, lq_latent, degrade_sigma);
}; };
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), x.dim()); return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false, false, false), x.dim());
} }
sd::Tensor<float> compute(int n_threads, sd::Tensor<float> compute(int n_threads,

View File

@ -627,7 +627,7 @@ namespace Qwen {
return build_graph(x, timesteps, context, ref_latents, increase_ref_index); return build_graph(x, timesteps, context, ref_latents, increase_ref_index);
}; };
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), x.dim()); return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false, false, false), x.dim());
} }
sd::Tensor<float> compute(int n_threads, sd::Tensor<float> compute(int n_threads,

View File

@ -772,7 +772,7 @@ struct UNetModelRunner : public DiffusionModelRunner {
return build_graph(x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength); return build_graph(x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength);
}; };
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), x.dim()); return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false, false, false), x.dim());
} }
sd::Tensor<float> compute(int n_threads, sd::Tensor<float> compute(int n_threads,

View File

@ -950,7 +950,7 @@ namespace WAN {
return build_graph(x, timesteps, context, clip_fea, c_concat, time_dim_concat, vace_context, vace_strength); return build_graph(x, timesteps, context, clip_fea, c_concat, time_dim_concat, vace_context, vace_strength);
}; };
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), x.dim()); return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false, false, false), x.dim());
} }
sd::Tensor<float> compute(int n_threads, sd::Tensor<float> compute(int n_threads,

View File

@ -634,7 +634,7 @@ namespace ZImage {
return build_graph(x, timesteps, context, ref_latents, increase_ref_index); return build_graph(x, timesteps, context, ref_latents, increase_ref_index);
}; };
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), x.dim()); return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false, false, false), x.dim());
} }
sd::Tensor<float> compute(int n_threads, sd::Tensor<float> compute(int n_threads,

View File

@ -567,11 +567,14 @@ struct CLIPTextModelRunner : public GGMLRunner {
void* custom_embeddings_data, void* custom_embeddings_data,
size_t max_token_idx, size_t max_token_idx,
bool return_pooled, bool return_pooled,
int clip_skip) { int clip_skip,
bool auto_free = true,
bool free_compute_buffer = true,
bool free_compute_params = true) {
auto get_graph = [&]() -> ggml_cgraph* { auto get_graph = [&]() -> ggml_cgraph* {
return build_graph(input_ids, num_custom_embeddings, custom_embeddings_data, max_token_idx, return_pooled, clip_skip); return build_graph(input_ids, num_custom_embeddings, custom_embeddings_data, max_token_idx, return_pooled, clip_skip);
}; };
auto result = GGMLRunner::compute<float>(get_graph, n_threads, true); auto result = GGMLRunner::compute<float>(get_graph, n_threads, auto_free, free_compute_buffer, free_compute_params);
if (return_pooled) { if (return_pooled) {
return take_or_empty(std::move(result)); return take_or_empty(std::move(result));
} }

View File

@ -1733,7 +1733,10 @@ namespace LLM {
const sd::Tensor<float>& attention_mask, const sd::Tensor<float>& attention_mask,
const std::vector<std::pair<int, sd::Tensor<float>>>& image_embeds, const std::vector<std::pair<int, sd::Tensor<float>>>& image_embeds,
std::set<int> out_layers, std::set<int> out_layers,
bool return_all_hidden_states = false) { bool return_all_hidden_states = false,
bool auto_free = true,
bool free_compute_buffer = true,
bool free_compute_params = true) {
auto get_graph = [&]() -> ggml_cgraph* { auto get_graph = [&]() -> ggml_cgraph* {
return build_graph(input_ids, return build_graph(input_ids,
attention_mask, attention_mask,
@ -1741,7 +1744,7 @@ namespace LLM {
out_layers, out_layers,
return_all_hidden_states); return_all_hidden_states);
}; };
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, true), return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, auto_free, free_compute_buffer, free_compute_params),
input_ids.dim() + 1); input_ids.dim() + 1);
} }
@ -1802,11 +1805,14 @@ namespace LLM {
} }
sd::Tensor<float> encode_image(const int n_threads, sd::Tensor<float> encode_image(const int n_threads,
const sd::Tensor<float>& image) { const sd::Tensor<float>& image,
bool auto_free = false,
bool free_compute_buffer = false,
bool free_compute_params = false) {
auto get_graph = [&]() -> ggml_cgraph* { auto get_graph = [&]() -> ggml_cgraph* {
return build_encode_image_graph(image); return build_encode_image_graph(image);
}; };
return take_or_empty(GGMLRunner::compute<float>(get_graph, n_threads, false)); return take_or_empty(GGMLRunner::compute<float>(get_graph, n_threads, auto_free, free_compute_buffer, free_compute_params));
} }
}; };

View File

@ -394,11 +394,14 @@ struct T5Runner : public GGMLRunner {
sd::Tensor<float> compute(const int n_threads, sd::Tensor<float> compute(const int n_threads,
const sd::Tensor<int32_t>& input_ids, const sd::Tensor<int32_t>& input_ids,
const sd::Tensor<float>& attention_mask) { const sd::Tensor<float>& attention_mask,
bool auto_free = true,
bool free_compute_buffer = true,
bool free_compute_params = true) {
auto get_graph = [&]() -> ggml_cgraph* { auto get_graph = [&]() -> ggml_cgraph* {
return build_graph(input_ids, attention_mask); return build_graph(input_ids, attention_mask);
}; };
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, true), 3); return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, auto_free, free_compute_buffer, free_compute_params), 3);
} }
static std::vector<int> _relative_position_bucket(const std::vector<int>& relative_position, static std::vector<int> _relative_position_bucket(const std::vector<int>& relative_position,

View File

@ -336,9 +336,11 @@ struct ESRGAN : public GGMLRunner {
} }
} }
success = model_loader.load_tensors(model_tensors, {}, n_threads); model_loader.set_n_threads(n_threads);
success = model_loader.load_tensors(model_tensors);
} else { } else {
success = model_loader.load_tensors(esrgan_tensors, {}, n_threads); model_loader.set_n_threads(n_threads);
success = model_loader.load_tensors(esrgan_tensors);
} }
if (!success) { if (!success) {
@ -367,7 +369,7 @@ struct ESRGAN : public GGMLRunner {
sd::Tensor<float> compute(const int n_threads, sd::Tensor<float> compute(const int n_threads,
const sd::Tensor<float>& x) { const sd::Tensor<float>& x) {
auto get_graph = [&]() -> ggml_cgraph* { return build_graph(x); }; auto get_graph = [&]() -> ggml_cgraph* { return build_graph(x); };
auto result = restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), x.dim()); auto result = restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false, false, false), x.dim());
return result; return result;
} }
}; };

View File

@ -240,20 +240,25 @@ namespace LTXVUpsampler {
protected: protected:
int64_t channels; int64_t channels;
int stride; int stride;
ggml_tensor* kernel = nullptr;
std::vector<float> kernel_data; std::vector<float> kernel_data;
std::string kernel_name;
void init_params(ggml_context* ctx, void init_params(ggml_context* ctx,
const String2TensorStorage& tensor_storage_map = {}, const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "") override { const std::string prefix = "") override {
SD_UNUSED(ctx);
SD_UNUSED(tensor_storage_map); SD_UNUSED(tensor_storage_map);
if (stride == 1) { if (stride == 1) {
return; return;
} }
kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 5, 5, 1, channels); kernel_name = prefix + "kernel";
std::string name = prefix + "kernel"; }
ggml_set_name(kernel, name.c_str());
public:
BlurDownsample(int64_t channels, int stride)
: channels(channels),
stride(stride) {
GGML_ASSERT(stride >= 1);
static const float binomial[5] = {1.f, 4.f, 6.f, 4.f, 1.f}; static const float binomial[5] = {1.f, 4.f, 6.f, 4.f, 1.f};
kernel_data.resize(static_cast<size_t>(5 * 5 * channels)); kernel_data.resize(static_cast<size_t>(5 * 5 * channels));
for (int64_t c = 0; c < channels; ++c) { for (int64_t c = 0; c < channels; ++c) {
@ -266,26 +271,16 @@ namespace LTXVUpsampler {
} }
} }
public:
BlurDownsample(int64_t channels, int stride)
: channels(channels),
stride(stride) {
GGML_ASSERT(stride >= 1);
}
void load_fixed_tensors() {
if (kernel == nullptr || kernel_data.empty()) {
return;
}
ggml_backend_tensor_set(kernel, kernel_data.data(), 0, kernel_data.size() * sizeof(float));
}
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
if (stride == 1) { if (stride == 1) {
return x; return x;
} }
GGML_ASSERT(kernel != nullptr); GGML_ASSERT(ctx != nullptr);
GGML_ASSERT(!kernel_data.empty());
GGML_ASSERT(x->ne[2] == channels); GGML_ASSERT(x->ne[2] == channels);
ggml_tensor* kernel = ggml_new_tensor_4d(ctx->ggml_ctx, GGML_TYPE_F32, 5, 5, 1, channels);
ggml_set_name(kernel, kernel_name.empty() ? "blur_down.kernel" : kernel_name.c_str());
ctx->bind_backend_tensor_data(kernel, kernel_data.data());
if (ctx->conv2d_direct_enabled) { if (ctx->conv2d_direct_enabled) {
return ggml_conv_2d_dw_direct(ctx->ggml_ctx, kernel, x, stride, stride, 2, 2, 1, 1); return ggml_conv_2d_dw_direct(ctx->ggml_ctx, kernel, x, stride, stride, 2, 2, 1, 1);
} }
@ -311,11 +306,6 @@ namespace LTXVUpsampler {
blocks["blur_down"] = std::shared_ptr<GGMLBlock>(new BlurDownsample(mid_channels, den)); blocks["blur_down"] = std::shared_ptr<GGMLBlock>(new BlurDownsample(mid_channels, den));
} }
void load_fixed_tensors() {
auto blur_down = std::dynamic_pointer_cast<BlurDownsample>(blocks["blur_down"]);
blur_down->load_fixed_tensors();
}
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]); auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
auto pixel_shuffle = std::dynamic_pointer_cast<PixelShuffleND>(blocks["pixel_shuffle"]); auto pixel_shuffle = std::dynamic_pointer_cast<PixelShuffleND>(blocks["pixel_shuffle"]);
@ -426,14 +416,6 @@ namespace LTXVUpsampler {
sd::ggml_graph_cut::mark_graph_cut(x, "ltx_latent_upsampler.final", "x"); sd::ggml_graph_cut::mark_graph_cut(x, "ltx_latent_upsampler.final", "x");
return x; return x;
} }
void load_fixed_tensors() {
if (!config.rational_resampler) {
return;
}
auto upsampler = std::dynamic_pointer_cast<SpatialRationalResampler>(blocks["upsampler"]);
upsampler->load_fixed_tensors();
}
}; };
struct LatentUpsamplerRunner : public GGMLRunner { struct LatentUpsamplerRunner : public GGMLRunner {
@ -490,12 +472,11 @@ namespace LTXVUpsampler {
if (config.rational_resampler) { if (config.rational_resampler) {
ignore_tensors.insert("upsampler.blur_down.kernel"); ignore_tensors.insert("upsampler.blur_down.kernel");
} }
if (!model_loader.load_tensors(tensors, ignore_tensors, n_threads)) { model_loader.set_n_threads(n_threads);
if (!model_loader.load_tensors(tensors, ignore_tensors)) {
LOG_ERROR("load LTX latent upsampler tensors failed"); LOG_ERROR("load LTX latent upsampler tensors failed");
return false; return false;
} }
model->load_fixed_tensors();
LOG_INFO("LTX latent upsampler loaded: in_channels=%" PRId64 ", mid_channels=%" PRId64 ", blocks=%d, scale=%.3f, temporal_factor=%d, rational=%d", LOG_INFO("LTX latent upsampler loaded: in_channels=%" PRId64 ", mid_channels=%" PRId64 ", blocks=%d, scale=%.3f, temporal_factor=%d, rational=%d",
config.in_channels, config.in_channels,
config.mid_channels, config.mid_channels,
@ -542,7 +523,7 @@ namespace LTXVUpsampler {
} }
size_t expected_dim = static_cast<size_t>(x.dim()); size_t expected_dim = static_cast<size_t>(x.dim());
auto get_graph = [&]() -> ggml_cgraph* { return build_graph(x); }; auto get_graph = [&]() -> ggml_cgraph* { return build_graph(x); };
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), expected_dim); return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false, false, false), expected_dim);
} }
}; };

View File

@ -670,7 +670,7 @@ struct AutoEncoderKL : public VAE {
bool decode_only = false, bool decode_only = false,
bool use_video_decoder = false, bool use_video_decoder = false,
SDVersion version = VERSION_SD1) SDVersion version = VERSION_SD1)
: decode_only(decode_only), VAE(version, backend, params_backend) { : VAE(version, backend, params_backend, prefix), decode_only(decode_only) {
if (sd_version_is_sd1(version) || sd_version_is_sd2(version)) { if (sd_version_is_sd1(version) || sd_version_is_sd2(version)) {
scale_factor = 0.18215f; scale_factor = 0.18215f;
shift_factor = 0.f; shift_factor = 0.f;
@ -718,8 +718,8 @@ struct AutoEncoderKL : public VAE {
return "vae"; return "vae";
} }
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) override { void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) override {
ae.get_param_tensors(tensors, prefix); ae.get_param_tensors(tensors, weight_prefix);
} }
ggml_cgraph* build_graph(const sd::Tensor<float>& z_tensor, bool decode_graph) { ggml_cgraph* build_graph(const sd::Tensor<float>& z_tensor, bool decode_graph) {
@ -742,7 +742,7 @@ struct AutoEncoderKL : public VAE {
auto get_graph = [&]() -> ggml_cgraph* { auto get_graph = [&]() -> ggml_cgraph* {
return build_graph(z, decode_graph); return build_graph(z, decode_graph);
}; };
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), z.dim()); return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false, false, false), z.dim());
} }
sd::Tensor<float> gaussian_latent_sample(const sd::Tensor<float>& moments, std::shared_ptr<RNG> rng) { sd::Tensor<float> gaussian_latent_sample(const sd::Tensor<float>& moments, std::shared_ptr<RNG> rng) {

View File

@ -997,6 +997,7 @@ namespace LTXV {
struct LTXAudioVAERunner : public GGMLRunner { struct LTXAudioVAERunner : public GGMLRunner {
LTXAudioVAEConfig config; LTXAudioVAEConfig config;
LTXAudioVAE model; LTXAudioVAE model;
std::string weight_prefix;
sd::Tensor<float> bwe_skip_filter_tensor; sd::Tensor<float> bwe_skip_filter_tensor;
LTXAudioVAERunner(ggml_backend_t backend, LTXAudioVAERunner(ggml_backend_t backend,
@ -1004,6 +1005,7 @@ namespace LTXV {
const String2TensorStorage& tensor_storage_map, const String2TensorStorage& tensor_storage_map,
const std::string& prefix = "") const std::string& prefix = "")
: GGMLRunner(backend, params_backend), : GGMLRunner(backend, params_backend),
weight_prefix(prefix),
config(LTXAudioVAEConfig::detect_from_weights(tensor_storage_map)), config(LTXAudioVAEConfig::detect_from_weights(tensor_storage_map)),
model(config) { model(config) {
model.init(params_ctx, tensor_storage_map, prefix); model.init(params_ctx, tensor_storage_map, prefix);
@ -1013,8 +1015,8 @@ namespace LTXV {
} }
} }
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) { void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) {
model.get_param_tensors(tensors, prefix); model.get_param_tensors(tensors, weight_prefix);
} }
size_t get_params_buffer_size() { size_t get_params_buffer_size() {
@ -1037,7 +1039,7 @@ namespace LTXV {
ggml_build_forward_expand(gf, waveform); ggml_build_forward_expand(gf, waveform);
return gf; return gf;
}; };
auto result = restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), 4); auto result = restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false, false, false), 4);
int64_t t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();
LOG_INFO("ltx audio vae decode completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); LOG_INFO("ltx audio vae decode completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
return result; return result;
@ -1082,7 +1084,7 @@ namespace LTXV {
} }
std::map<std::string, ggml_tensor*> tensors; std::map<std::string, ggml_tensor*> tensors;
ltx_audio_vae->get_param_tensors(tensors, ""); ltx_audio_vae->get_param_tensors(tensors);
if (!model_loader.load_tensors(tensors)) { if (!model_loader.load_tensors(tensors)) {
LOG_ERROR("load tensors from model loader failed"); LOG_ERROR("load tensors from model loader failed");

View File

@ -1239,7 +1239,7 @@ struct LTXVideoVAE : public VAE {
patch_size, patch_size,
tensor_storage_map, tensor_storage_map,
prefix), prefix),
VAE(version, backend, params_backend) { VAE(version, backend, params_backend, prefix) {
vae.init(params_ctx, tensor_storage_map, prefix); vae.init(params_ctx, tensor_storage_map, prefix);
decode_timestep_tensor.values()[0] = vae.decode_timestep; decode_timestep_tensor.values()[0] = vae.decode_timestep;
} }
@ -1271,8 +1271,8 @@ struct LTXVideoVAE : public VAE {
} }
} }
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) override { void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) override {
vae.get_param_tensors(tensors, prefix); vae.get_param_tensors(tensors, weight_prefix);
} }
struct TemporalTilePlan { struct TemporalTilePlan {
@ -1396,7 +1396,7 @@ struct LTXVideoVAE : public VAE {
static_cast<int>(start), static_cast<int>(start),
chunk_overlap); chunk_overlap);
}; };
auto chunk = restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, true), auto chunk = restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, true, true, true),
expected_dim); expected_dim);
if (chunk.empty()) { if (chunk.empty()) {
free_cache_ctx_and_buffer(); free_cache_ctx_and_buffer();
@ -1452,7 +1452,7 @@ struct LTXVideoVAE : public VAE {
auto get_graph = [&]() -> ggml_cgraph* { auto get_graph = [&]() -> ggml_cgraph* {
return build_graph(input, decode_graph); return build_graph(input, decode_graph);
}; };
auto result = restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), expected_dim); auto result = restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false, false, false), expected_dim);
if (result.empty()) { if (result.empty()) {
return {}; return {};
} }
@ -1465,7 +1465,7 @@ struct LTXVideoVAE : public VAE {
auto get_graph = [&]() -> ggml_cgraph* { auto get_graph = [&]() -> ggml_cgraph* {
return build_latent_statistics_graph(z, normalize); return build_latent_statistics_graph(z, normalize);
}; };
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false, false, false),
static_cast<size_t>(z.dim())); static_cast<size_t>(z.dim()));
} }
@ -1541,7 +1541,7 @@ struct LTXVideoVAE : public VAE {
} }
std::map<std::string, ggml_tensor*> tensors; std::map<std::string, ggml_tensor*> tensors;
vae->get_param_tensors(tensors, "first_stage_model"); vae->get_param_tensors(tensors);
if (!model_loader.load_tensors(tensors)) { if (!model_loader.load_tensors(tensors)) {
LOG_ERROR("load tensors from model loader failed"); LOG_ERROR("load tensors from model loader failed");

View File

@ -628,9 +628,9 @@ struct TinyImageAutoEncoder : public VAE {
const std::string prefix, const std::string prefix,
bool decoder_only = true, bool decoder_only = true,
SDVersion version = VERSION_SD1) SDVersion version = VERSION_SD1)
: decode_only(decoder_only), : VAE(version, backend, params_backend, "tae"),
taesd(decoder_only, version), decode_only(decoder_only),
VAE(version, backend, params_backend) { taesd(decoder_only, version) {
scale_input = false; scale_input = false;
taesd.init(params_ctx, tensor_storage_map, prefix); taesd.init(params_ctx, tensor_storage_map, prefix);
} }
@ -639,8 +639,8 @@ struct TinyImageAutoEncoder : public VAE {
return "taesd"; return "taesd";
} }
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) { void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) override {
taesd.get_param_tensors(tensors, prefix); taesd.get_param_tensors(tensors, weight_prefix);
} }
sd::Tensor<float> vae_output_to_latents(const sd::Tensor<float>& vae_output, std::shared_ptr<RNG> rng) override { sd::Tensor<float> vae_output_to_latents(const sd::Tensor<float>& vae_output, std::shared_ptr<RNG> rng) override {
@ -676,7 +676,7 @@ struct TinyImageAutoEncoder : public VAE {
return build_graph(z_tensor, decode_graph); return build_graph(z_tensor, decode_graph);
}; };
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), z_tensor.dim()); return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false, false, false), z_tensor.dim());
} }
}; };
@ -691,8 +691,8 @@ struct TinyVideoAutoEncoder : public VAE {
const std::string prefix, const std::string prefix,
bool decoder_only = true, bool decoder_only = true,
SDVersion version = VERSION_WAN2) SDVersion version = VERSION_WAN2)
: decode_only(decoder_only), : VAE(version, backend, params_backend, "tae"),
VAE(version, backend, params_backend) { decode_only(decoder_only) {
for (auto tensor_storage : tensor_storage_map) { for (auto tensor_storage : tensor_storage_map) {
if (tensor_storage.first.find(prefix + ".3.conv.6.weight") != std::string::npos) { if (tensor_storage.first.find(prefix + ".3.conv.6.weight") != std::string::npos) {
is_wide = true; is_wide = true;
@ -708,8 +708,8 @@ struct TinyVideoAutoEncoder : public VAE {
return "taehv"; return "taehv";
} }
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) { void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) override {
taehv.get_param_tensors(tensors, prefix); taehv.get_param_tensors(tensors, weight_prefix);
} }
sd::Tensor<float> vae_output_to_latents(const sd::Tensor<float>& vae_output, std::shared_ptr<RNG> rng) override { sd::Tensor<float> vae_output_to_latents(const sd::Tensor<float>& vae_output, std::shared_ptr<RNG> rng) override {
@ -746,7 +746,7 @@ struct TinyVideoAutoEncoder : public VAE {
return build_graph(z_tensor, decode_graph); return build_graph(z_tensor, decode_graph);
}; };
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), z_tensor.dim()); return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false, false, false), z_tensor.dim());
} }
}; };

View File

@ -7,6 +7,7 @@
struct VAE : public GGMLRunner { struct VAE : public GGMLRunner {
protected: protected:
SDVersion version; SDVersion version;
std::string weight_prefix;
bool scale_input = true; bool scale_input = true;
virtual sd::Tensor<float> _compute(const int n_threads, virtual sd::Tensor<float> _compute(const int n_threads,
const sd::Tensor<float>& z, const sd::Tensor<float>& z,
@ -62,8 +63,8 @@ protected:
} }
public: public:
VAE(SDVersion version, ggml_backend_t backend, ggml_backend_t params_backend) VAE(SDVersion version, ggml_backend_t backend, ggml_backend_t params_backend, const std::string& weight_prefix = "")
: version(version), GGMLRunner(backend, params_backend) {} : version(version), weight_prefix(weight_prefix), GGMLRunner(backend, params_backend) {}
int get_scale_factor() { int get_scale_factor() {
int scale_factor = 8; int scale_factor = 8;
@ -214,7 +215,7 @@ public:
virtual sd::Tensor<float> vae_output_to_latents(const sd::Tensor<float>& vae_output, std::shared_ptr<RNG> rng) = 0; virtual sd::Tensor<float> vae_output_to_latents(const sd::Tensor<float>& vae_output, std::shared_ptr<RNG> rng) = 0;
virtual sd::Tensor<float> diffusion_to_vae_latents(const sd::Tensor<float>& latents) = 0; virtual sd::Tensor<float> diffusion_to_vae_latents(const sd::Tensor<float>& latents) = 0;
virtual sd::Tensor<float> vae_to_diffusion_latents(const sd::Tensor<float>& latents) = 0; virtual sd::Tensor<float> vae_to_diffusion_latents(const sd::Tensor<float>& latents) = 0;
virtual void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) = 0; virtual void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) = 0;
virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); }; virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); };
virtual void set_temporal_tiling_enabled(bool enabled) { SD_UNUSED(enabled); }; virtual void set_temporal_tiling_enabled(bool enabled) { SD_UNUSED(enabled); };
virtual void set_tiling_params(const sd_tiling_params_t& params) { virtual void set_tiling_params(const sd_tiling_params_t& params) {
@ -251,7 +252,7 @@ struct FakeVAE : public VAE {
return latents; return latents;
} }
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) override {} void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) override {}
std::string get_desc() override { std::string get_desc() override {
return "fake_vae"; return "fake_vae";

View File

@ -1129,7 +1129,7 @@ namespace WAN {
const std::string prefix = "", const std::string prefix = "",
bool decode_only = false, bool decode_only = false,
SDVersion version = VERSION_WAN2) SDVersion version = VERSION_WAN2)
: decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V), VAE(version, backend, params_backend) { : VAE(version, backend, params_backend, prefix), decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V) {
ae.init(params_ctx, tensor_storage_map, prefix); ae.init(params_ctx, tensor_storage_map, prefix);
} }
@ -1137,8 +1137,8 @@ namespace WAN {
return "wan_vae"; return "wan_vae";
} }
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) override { void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) override {
ae.get_param_tensors(tensors, prefix); ae.get_param_tensors(tensors, weight_prefix);
} }
sd::Tensor<float> vae_output_to_latents(const sd::Tensor<float>& vae_output, std::shared_ptr<RNG> rng) override { sd::Tensor<float> vae_output_to_latents(const sd::Tensor<float>& vae_output, std::shared_ptr<RNG> rng) override {
@ -1255,7 +1255,7 @@ namespace WAN {
return build_graph(input, decode_graph); return build_graph(input, decode_graph);
} }
}; };
auto result = restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, true), auto result = restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, true, true, true),
input.empty() ? z.dim() : input.dim()); input.empty() ? z.dim() : input.dim());
if (!result.empty() && z.dim() == 4) { if (!result.empty() && z.dim() == 4) {
result.squeeze_(2); result.squeeze_(2);
@ -1268,7 +1268,7 @@ namespace WAN {
auto get_graph = [&]() -> ggml_cgraph* { auto get_graph = [&]() -> ggml_cgraph* {
return build_graph_partial(z, decode_graph, i); return build_graph_partial(z, decode_graph, i);
}; };
auto out_opt = GGMLRunner::compute<float>(get_graph, n_threads, true); auto out_opt = GGMLRunner::compute<float>(get_graph, n_threads, true, true, true);
if (!out_opt.has_value()) { if (!out_opt.has_value()) {
return {}; return {};
} }
@ -1281,7 +1281,7 @@ namespace WAN {
sd::Tensor<float> output = std::move(out); sd::Tensor<float> output = std::move(out);
for (i = 1; i < t; i++) { for (i = 1; i < t; i++) {
auto chunk_opt = GGMLRunner::compute<float>(get_graph, n_threads, true); auto chunk_opt = GGMLRunner::compute<float>(get_graph, n_threads, true, true, true);
if (!chunk_opt.has_value()) { if (!chunk_opt.has_value()) {
return {}; return {};
} }
@ -1327,7 +1327,7 @@ namespace WAN {
// ggml_backend_t backend = ggml_backend_cuda_init(0); // ggml_backend_t backend = ggml_backend_cuda_init(0);
ggml_backend_t backend = sd_backend_cpu_init(); ggml_backend_t backend = sd_backend_cpu_init();
ggml_type model_data_type = GGML_TYPE_F16; ggml_type model_data_type = GGML_TYPE_F16;
std::shared_ptr<WanVAERunner> vae = std::make_shared<WanVAERunner>(backend, backend, String2TensorStorage{}, "", false, VERSION_WAN2_2_TI2V); std::shared_ptr<WanVAERunner> vae = std::make_shared<WanVAERunner>(backend, backend, String2TensorStorage{}, "first_stage_model", false, VERSION_WAN2_2_TI2V);
{ {
LOG_INFO("loading from '%s'", file_path.c_str()); LOG_INFO("loading from '%s'", file_path.c_str());
@ -1336,7 +1336,7 @@ namespace WAN {
return; return;
} }
std::map<std::string, ggml_tensor*> tensors; std::map<std::string, ggml_tensor*> tensors;
vae->get_param_tensors(tensors, "first_stage_model"); vae->get_param_tensors(tensors);
ModelLoader model_loader; ModelLoader model_loader;
if (!model_loader.init_from_file_and_convert_name(file_path, "vae.")) { if (!model_loader.init_from_file_and_convert_name(file_path, "vae.")) {

View File

@ -1,6 +1,7 @@
#include <algorithm> #include <algorithm>
#include <atomic> #include <atomic>
#include <chrono> #include <chrono>
#include <cinttypes>
#include <cstdarg> #include <cstdarg>
#include <cstdlib> #include <cstdlib>
#include <fstream> #include <fstream>
@ -204,10 +205,28 @@ void convert_tensor(void* src,
/*================================================= ModelLoader ==================================================*/ /*================================================= ModelLoader ==================================================*/
ModelLoader::ModelLoader()
: n_threads_(sd_get_num_physical_cores()) {
}
size_t ModelLoader::add_file_path(const std::string& file_path) {
if (model_files_processed) {
file_data.clear();
model_files_processed = false;
}
file_paths_.push_back(file_path);
return file_paths_.size() - 1;
}
void ModelLoader::add_tensor_storage(const TensorStorage& tensor_storage) { void ModelLoader::add_tensor_storage(const TensorStorage& tensor_storage) {
tensor_storage_map[tensor_storage.name] = tensor_storage; tensor_storage_map[tensor_storage.name] = tensor_storage;
} }
void ModelLoader::set_n_threads(int n_threads) {
n_threads_ = n_threads > 0 ? n_threads : sd_get_num_physical_cores();
LOG_DEBUG("using %d threads for model loading", n_threads_);
}
bool ModelLoader::init_from_file(const std::string& file_path, const std::string& prefix) { bool ModelLoader::init_from_file(const std::string& file_path, const std::string& prefix) {
if (is_directory(file_path)) { if (is_directory(file_path)) {
LOG_INFO("load %s using diffusers format", file_path.c_str()); LOG_INFO("load %s using diffusers format", file_path.c_str());
@ -271,8 +290,7 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s
return false; return false;
} }
file_paths_.push_back(file_path); size_t file_index = add_file_path(file_path);
size_t file_index = file_paths_.size() - 1;
for (auto& tensor_storage : tensor_storages) { for (auto& tensor_storage : tensor_storages) {
// LOG_DEBUG("%s", tensor_storage.name.c_str()); // LOG_DEBUG("%s", tensor_storage.name.c_str());
@ -300,8 +318,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
return false; return false;
} }
file_paths_.push_back(file_path); size_t file_index = add_file_path(file_path);
size_t file_index = file_paths_.size() - 1;
for (auto& tensor_storage : tensor_storages) { for (auto& tensor_storage : tensor_storages) {
if (is_unused_tensor(tensor_storage.name)) { if (is_unused_tensor(tensor_storage.name)) {
@ -335,8 +352,7 @@ bool ModelLoader::init_from_torch_legacy_file(const std::string& file_path, cons
return false; return false;
} }
file_paths_.push_back(file_path); size_t file_index = add_file_path(file_path);
size_t file_index = file_paths_.size() - 1;
for (auto& tensor_storage : tensor_storages) { for (auto& tensor_storage : tensor_storages) {
if (is_unused_tensor(tensor_storage.name)) { if (is_unused_tensor(tensor_storage.name)) {
@ -366,8 +382,7 @@ bool ModelLoader::init_from_torch_zip_file(const std::string& file_path, const s
return false; return false;
} }
file_paths_.push_back(file_path); size_t file_index = add_file_path(file_path);
size_t file_index = file_paths_.size() - 1;
for (auto& tensor_storage : tensor_storages) { for (auto& tensor_storage : tensor_storages) {
if (!starts_with(tensor_storage.name, prefix)) { if (!starts_with(tensor_storage.name, prefix)) {
@ -760,8 +775,6 @@ void ModelLoader::process_model_files(bool enable_mmap, bool writable_mmap) {
return; return;
} }
int64_t start_time = ggml_time_ms();
std::vector<TensorStorage> processed_tensor_storages; std::vector<TensorStorage> processed_tensor_storages;
for (const auto& [name, tensor_storage] : tensor_storage_map) { for (const auto& [name, tensor_storage] : tensor_storage_map) {
if (is_unused_tensor(tensor_storage.name)) { if (is_unused_tensor(tensor_storage.name)) {
@ -812,20 +825,12 @@ void ModelLoader::process_model_files(bool enable_mmap, bool writable_mmap) {
} else { } else {
LOG_WARN("failed to memory-map '%s' (falling back to read())", file_path.c_str()); LOG_WARN("failed to memory-map '%s' (falling back to read())", file_path.c_str());
} }
} else if (!is_zip) {
LOG_INFO("NOT using mmap for '%s' (mmap disabled by caller)",
file_path.c_str());
} }
file_data.push_back(std::move(fdata)); file_data.push_back(std::move(fdata));
} }
model_files_processed = true; model_files_processed = true;
int64_t end_time = ggml_time_ms();
int64_t process_time_ms = end_time - start_time;
LOG_INFO("model files processing completed in %.2fs", process_time_ms / 1000.f);
} }
std::vector<MmapTensorStore> ModelLoader::mmap_tensors(std::map<std::string, ggml_tensor*>& tensors, std::vector<MmapTensorStore> ModelLoader::mmap_tensors(std::map<std::string, ggml_tensor*>& tensors,
@ -919,7 +924,9 @@ std::vector<MmapTensorStore> ModelLoader::mmap_tensors(std::map<std::string, ggm
return result; return result;
} }
bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads_p, bool enable_mmap) { bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb,
bool enable_mmap,
const std::set<std::string>* target_tensor_names) {
process_model_files(enable_mmap, false); process_model_files(enable_mmap, false);
std::atomic<int64_t> read_time_ms(0); std::atomic<int64_t> read_time_ms(0);
@ -928,14 +935,26 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
std::atomic<int64_t> convert_time_ms(0); std::atomic<int64_t> convert_time_ms(0);
std::atomic<uint64_t> bytes_processed(0); std::atomic<uint64_t> bytes_processed(0);
int num_threads_to_use = n_threads_p > 0 ? n_threads_p : sd_get_num_physical_cores(); int num_threads_to_use = n_threads_;
LOG_DEBUG("using %d threads for model loading", num_threads_to_use);
int64_t start_time = ggml_time_ms(); int64_t start_time = ggml_time_ms();
size_t total_tensors_to_process = 0; size_t total_tensors_to_process = 0;
std::vector<size_t> file_tensors_to_process_counts;
file_tensors_to_process_counts.reserve(file_data.size());
for (const auto& fdata : file_data) { for (const auto& fdata : file_data) {
total_tensors_to_process += fdata.tensors.size(); size_t file_tensors_to_process = 0;
if (target_tensor_names == nullptr) {
file_tensors_to_process = fdata.tensors.size();
} else {
for (const TensorStorage& tensor_storage : fdata.tensors) {
if (target_tensor_names->find(tensor_storage.name) != target_tensor_names->end()) {
file_tensors_to_process++;
}
}
}
file_tensors_to_process_counts.push_back(file_tensors_to_process);
total_tensors_to_process += file_tensors_to_process;
} }
bool success = true; bool success = true;
@ -943,17 +962,38 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
const int64_t t_start = start_time; const int64_t t_start = start_time;
int last_n_threads = 1; int last_n_threads = 1;
for (auto& fdata : file_data) { for (size_t file_index = 0; file_index < file_data.size(); ++file_index) {
auto& fdata = file_data[file_index];
const std::string& file_path = fdata.path; const std::string& file_path = fdata.path;
LOG_DEBUG("loading tensors from %s", file_path.c_str());
const std::vector<TensorStorage>& file_tensors = fdata.tensors; const std::vector<TensorStorage>& file_tensors = fdata.tensors;
std::vector<const TensorStorage*> tensors_to_process;
size_t file_tensors_to_process = file_tensors_to_process_counts[file_index];
tensors_to_process.reserve(file_tensors_to_process);
if (target_tensor_names == nullptr) {
for (const TensorStorage& tensor_storage : file_tensors) {
tensors_to_process.push_back(&tensor_storage);
}
} else {
for (const TensorStorage& tensor_storage : file_tensors) {
if (target_tensor_names->find(tensor_storage.name) != target_tensor_names->end()) {
tensors_to_process.push_back(&tensor_storage);
}
}
}
if (tensors_to_process.empty()) {
continue;
}
LOG_DEBUG("loading %zu/%zu tensors from %s",
tensors_to_process.size(),
file_tensors.size(),
file_path.c_str());
bool is_zip = fdata.is_zip; bool is_zip = fdata.is_zip;
std::shared_ptr<MmapWrapper> mmapped = fdata.mmapped; std::shared_ptr<MmapWrapper> mmapped = fdata.mmapped;
int n_threads = is_zip ? 1 : std::min(num_threads_to_use, (int)file_tensors.size()); int n_threads = is_zip ? 1 : std::min(num_threads_to_use, (int)tensors_to_process.size());
if (n_threads < 1) { if (n_threads < 1) {
n_threads = 1; n_threads = 1;
} }
@ -989,11 +1029,11 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
while (true) { while (true) {
int64_t t0, t1; int64_t t0, t1;
size_t idx = tensor_idx.fetch_add(1); size_t idx = tensor_idx.fetch_add(1);
if (idx >= file_tensors.size() || failed) { if (idx >= tensors_to_process.size() || failed) {
break; break;
} }
const TensorStorage& tensor_storage = file_tensors[idx]; const TensorStorage& tensor_storage = *tensors_to_process[idx];
ggml_tensor* dst_tensor = nullptr; ggml_tensor* dst_tensor = nullptr;
t0 = ggml_time_ms(); t0 = ggml_time_ms();
@ -1133,16 +1173,18 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
while (true) { while (true) {
size_t current_idx = tensor_idx.load(); size_t current_idx = tensor_idx.load();
if (current_idx >= file_tensors.size() || failed) { if (current_idx >= tensors_to_process.size() || failed) {
break; break;
} }
size_t curr_num = total_tensors_processed + current_idx; size_t curr_num = total_tensors_processed + current_idx;
float elapsed_seconds = (ggml_time_ms() - t_start) / 1000.0f; float elapsed_seconds = (ggml_time_ms() - t_start) / 1000.0f;
pretty_bytes_progress(static_cast<int>(curr_num), if (total_tensors_to_process > 0) {
static_cast<int>(total_tensors_to_process), pretty_bytes_progress(static_cast<int>(curr_num),
bytes_processed.load(), static_cast<int>(total_tensors_to_process),
elapsed_seconds); bytes_processed.load(),
std::this_thread::sleep_for(std::chrono::milliseconds(200)); elapsed_seconds);
}
std::this_thread::sleep_for(std::chrono::milliseconds(total_tensors_to_process <= 4 ? 10 : 200));
} }
for (auto& w : workers) { for (auto& w : workers) {
@ -1153,12 +1195,14 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
success = false; success = false;
break; break;
} }
total_tensors_processed += file_tensors.size(); total_tensors_processed += tensors_to_process.size();
pretty_bytes_progress(static_cast<int>(total_tensors_processed), if (total_tensors_to_process > 0) {
static_cast<int>(total_tensors_to_process), pretty_bytes_progress(static_cast<int>(total_tensors_processed),
bytes_processed.load(), static_cast<int>(total_tensors_to_process),
(ggml_time_ms() - t_start) / 1000.0f); bytes_processed.load(),
if (total_tensors_processed < total_tensors_to_process) { (ggml_time_ms() - t_start) / 1000.0f);
}
if (total_tensors_processed < total_tensors_to_process && total_tensors_to_process > 0) {
printf("\n"); printf("\n");
} }
} }
@ -1173,9 +1217,77 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
return success; return success;
} }
bool ModelLoader::load_float_tensor(const std::string& name,
std::vector<float>& data,
int n_threads,
bool use_mmap) {
data.clear();
auto tensor_storage_it = tensor_storage_map.find(name);
if (tensor_storage_it == tensor_storage_map.end()) {
return false;
}
const TensorStorage& tensor_storage = tensor_storage_it->second;
int64_t n_elements = tensor_storage.nelements();
if (n_elements <= 0) {
LOG_ERROR("tensor '%s' has invalid element count: %" PRId64, name.c_str(), n_elements);
return false;
}
if (tensor_storage.n_dims <= 0 || tensor_storage.n_dims > GGML_MAX_DIMS) {
LOG_ERROR("tensor '%s' has unsupported dims: %d", name.c_str(), tensor_storage.n_dims);
return false;
}
std::vector<float> loaded_data(static_cast<size_t>(n_elements));
ggml_init_params params;
params.mem_size = ggml_tensor_overhead();
params.mem_buffer = nullptr;
params.no_alloc = true;
ggml_context* ctx = ggml_init(params);
if (ctx == nullptr) {
LOG_ERROR("failed to create context for tensor '%s'", name.c_str());
return false;
}
ggml_tensor* tensor = ggml_new_tensor(ctx, GGML_TYPE_F32, tensor_storage.n_dims, tensor_storage.ne);
ggml_set_name(tensor, name.c_str());
tensor->data = loaded_data.data();
bool loaded = false;
auto on_new_tensor_cb = [&](const TensorStorage& current_tensor_storage, ggml_tensor** dst_tensor) -> bool {
*dst_tensor = nullptr;
if (current_tensor_storage.name != name) {
return true;
}
if (current_tensor_storage.nelements() != n_elements) {
LOG_ERROR("tensor '%s' element count changed during load", name.c_str());
return false;
}
*dst_tensor = tensor;
loaded = true;
return true;
};
std::set<std::string> target_tensor_names{name};
if (n_threads > 0) {
set_n_threads(n_threads);
}
bool success = load_tensors(on_new_tensor_cb, use_mmap, &target_tensor_names);
ggml_free(ctx);
if (!success || !loaded) {
data.clear();
return false;
}
data = std::move(loaded_data);
return true;
}
bool ModelLoader::load_tensors(std::map<std::string, ggml_tensor*>& tensors, bool ModelLoader::load_tensors(std::map<std::string, ggml_tensor*>& tensors,
std::set<std::string> ignore_tensors, std::set<std::string> ignore_tensors,
int n_threads,
bool enable_mmap) { bool enable_mmap) {
std::set<std::string> tensor_names_in_file; std::set<std::string> tensor_names_in_file;
std::mutex tensor_names_mutex; std::mutex tensor_names_mutex;
@ -1219,7 +1331,7 @@ bool ModelLoader::load_tensors(std::map<std::string, ggml_tensor*>& tensors,
return true; return true;
}; };
bool success = load_tensors(on_new_tensor_cb, n_threads, enable_mmap); bool success = load_tensors(on_new_tensor_cb, enable_mmap);
if (!success) { if (!success) {
LOG_ERROR("load tensors from file failed"); LOG_ERROR("load tensors from file failed");
return false; return false;

View File

@ -34,7 +34,9 @@ protected:
std::vector<ModelFileData> file_data; std::vector<ModelFileData> file_data;
bool model_files_processed = false; bool model_files_processed = false;
String2TensorStorage tensor_storage_map; String2TensorStorage tensor_storage_map;
int n_threads_;
size_t add_file_path(const std::string& file_path);
void add_tensor_storage(const TensorStorage& tensor_storage); void add_tensor_storage(const TensorStorage& tensor_storage);
bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = ""); bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = "");
@ -44,6 +46,8 @@ protected:
bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = ""); bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = "");
public: public:
ModelLoader();
bool init_from_file(const std::string& file_path, const std::string& prefix = ""); bool init_from_file(const std::string& file_path, const std::string& prefix = "");
void convert_tensors_name(); void convert_tensors_name();
bool init_from_file_and_convert_name(const std::string& file_path, bool init_from_file_and_convert_name(const std::string& file_path,
@ -55,16 +59,23 @@ public:
std::map<ggml_type, uint32_t> get_diffusion_model_wtype_stat(); std::map<ggml_type, uint32_t> get_diffusion_model_wtype_stat();
std::map<ggml_type, uint32_t> get_vae_wtype_stat(); std::map<ggml_type, uint32_t> get_vae_wtype_stat();
String2TensorStorage& get_tensor_storage_map() { return tensor_storage_map; } String2TensorStorage& get_tensor_storage_map() { return tensor_storage_map; }
const String2TensorStorage& get_tensor_storage_map() const { return tensor_storage_map; }
void set_n_threads(int n_threads);
void set_wtype_override(ggml_type wtype, std::string tensor_type_rules = ""); void set_wtype_override(ggml_type wtype, std::string tensor_type_rules = "");
void process_model_files(bool enable_mmap = false, bool writable_mmap = true); void process_model_files(bool enable_mmap = false, bool writable_mmap = true);
std::vector<MmapTensorStore> mmap_tensors(std::map<std::string, ggml_tensor*>& tensors, std::vector<MmapTensorStore> mmap_tensors(std::map<std::string, ggml_tensor*>& tensors,
std::set<std::string> ignore_tensors = {}, std::set<std::string> ignore_tensors = {},
bool writable = true); bool writable = true);
bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0, bool use_mmap = false); bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb,
bool use_mmap = false,
const std::set<std::string>* target_tensor_names = nullptr);
bool load_tensors(std::map<std::string, ggml_tensor*>& tensors, bool load_tensors(std::map<std::string, ggml_tensor*>& tensors,
std::set<std::string> ignore_tensors = {}, std::set<std::string> ignore_tensors = {},
int n_threads = 0,
bool use_mmap = false); bool use_mmap = false);
bool load_float_tensor(const std::string& name,
std::vector<float>& data,
int n_threads = 0,
bool use_mmap = false);
std::vector<std::string> get_tensor_names() const { std::vector<std::string> get_tensor_names() const {
std::vector<std::string> names; std::vector<std::string> names;

944
src/model_manager.cpp Normal file
View File

@ -0,0 +1,944 @@
#include "model_manager.h"
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <mutex>
#include <unordered_set>
#include "core/ggml_extend_backend.h"
#include "core/util.h"
#include "model/adapter/lora.hpp"
static size_t aligned_offset(const void* buffer, size_t offset, size_t alignment) {
GGML_ASSERT(alignment != 0 && (alignment & (alignment - 1)) == 0);
size_t align = (alignment - ((reinterpret_cast<uintptr_t>(buffer) + offset) % alignment)) % alignment;
return offset + align;
}
static bool lora_specs_equal(const std::vector<ModelManager::LoraSpec>& lhs,
const std::vector<ModelManager::LoraSpec>& rhs) {
if (lhs.size() != rhs.size()) {
return false;
}
for (size_t i = 0; i < lhs.size(); ++i) {
if (lhs[i].path != rhs[i].path ||
lhs[i].multiplier != rhs[i].multiplier ||
lhs[i].is_high_noise != rhs[i].is_high_noise ||
lhs[i].tensor_name_prefix_filter != rhs[i].tensor_name_prefix_filter ||
lhs[i].required != rhs[i].required) {
return false;
}
}
return true;
}
static std::string lora_id(const ModelManager::LoraSpec& lora) {
return lora.is_high_noise ? "|high_noise|" + lora.path : lora.path;
}
static bool backend_supports_host_buffer(ggml_backend_t backend) {
if (backend == nullptr) {
return false;
}
if (sd_backend_is_cpu(backend)) {
return true;
}
ggml_backend_dev_t dev = ggml_backend_get_device(backend);
if (dev == nullptr) {
return false;
}
ggml_backend_dev_props props;
ggml_backend_dev_get_props(dev, &props);
return props.caps.buffer_from_host_ptr;
}
ModelManager::~ModelManager() {
release_all();
}
void ModelManager::set_common_ignore_tensors(std::set<std::string> ignore_tensors) {
common_ignore_tensors_ = std::move(ignore_tensors);
}
void ModelManager::set_loras(std::vector<LoraSpec> loras, SDVersion version) {
if (loras.empty() && loras_.empty()) {
lora_version_ = version;
return;
}
if (lora_version_ == version && lora_specs_equal(loras_, loras)) {
return;
}
loras_ = std::move(loras);
lora_version_ = version;
current_lora_epoch_++;
reset_lora_applied_params();
}
std::set<std::string> ModelManager::tensor_names() const {
std::set<std::string> names;
for (const auto& state : tensor_states_) {
if (state != nullptr) {
names.insert(state->name);
}
}
return names;
}
size_t estimate_tensors_size(const std::map<std::string, ggml_tensor*>& tensors) {
size_t size = 0;
std::unordered_set<ggml_tensor*> seen;
for (const auto& pair : tensors) {
ggml_tensor* tensor = pair.second;
if (tensor == nullptr || seen.find(tensor) != seen.end()) {
continue;
}
seen.insert(tensor);
size += ggml_nbytes(tensor);
}
return size;
}
bool ModelManager::register_param_tensors(const std::string& desc,
std::map<std::string, ggml_tensor*> tensors,
ResidencyMode residency_mode,
ggml_backend_t compute_backend,
ggml_backend_t params_backend,
size_t* registered_tensor_size) {
if (desc.empty()) {
LOG_ERROR("model manager tensor desc is empty");
return false;
}
if (registered_tensor_size != nullptr) {
*registered_tensor_size += estimate_tensors_size(tensors);
}
std::vector<std::unique_ptr<TensorState>> new_states;
new_states.reserve(tensors.size());
for (const auto& pair : tensors) {
const std::string& name = pair.first;
ggml_tensor* tensor = pair.second;
if (tensor == nullptr) {
continue;
}
if (tensor_states_by_name_.find(name) != tensor_states_by_name_.end()) {
LOG_ERROR("model manager tensor name '%s' is already registered", name.c_str());
return false;
}
ggml_set_name(tensor, name.c_str());
auto state = std::make_unique<TensorState>();
state->name = name;
state->tensor = tensor;
state->desc = desc;
state->residency_mode = residency_mode;
state->compute_backend = compute_backend;
state->params_backend = params_backend;
new_states.push_back(std::move(state));
}
for (auto& state : new_states) {
TensorState* registered_state = state.get();
tensor_states_by_name_[registered_state->name] = registered_state;
tensor_states_.push_back(std::move(state));
}
return true;
}
bool ModelManager::validate_registered_tensors() {
bool ok = true;
for (const auto& state : tensor_states_) {
if (state == nullptr) {
ok = false;
continue;
}
bool state_ok = validate_tensor(*state);
if (state_ok) {
state->metadata_validated = true;
}
ok = state_ok && ok;
}
return ok;
}
bool ModelManager::load_tensors_to_params_backend(const std::vector<TensorState*>& states) {
std::vector<TensorState*> need_load;
need_load.reserve(states.size());
for (TensorState* state : states) {
if (state == nullptr || should_ignore(*state) || is_optional_missing_tensor(state->name)) {
continue;
}
if (!state->metadata_validated) {
if (!validate_tensor(*state)) {
return false;
}
state->metadata_validated = true;
}
if (!state->loaded_to_params_backend) {
need_load.push_back(state);
}
}
if (need_load.empty()) {
return true;
}
std::vector<ParamsStorageBlock*> created_storage_blocks;
if (!mmap_params(need_load, created_storage_blocks)) {
for (ParamsStorageBlock* block : created_storage_blocks) {
if (block != nullptr) {
free_params_storage_block(*block);
erase_params_storage_block(block);
}
}
return false;
}
std::vector<TensorState*> need_alloc;
need_alloc.reserve(need_load.size());
for (TensorState* state : need_load) {
if (state->tensor != nullptr && state->tensor->data == nullptr && state->tensor->view_src == nullptr) {
need_alloc.push_back(state);
}
}
if (!alloc_params_buffers(need_alloc, created_storage_blocks) ||
!load_tensors(need_load)) {
for (ParamsStorageBlock* block : created_storage_blocks) {
if (block != nullptr) {
free_params_storage_block(*block);
erase_params_storage_block(block);
}
}
return false;
}
for (ParamsStorageBlock* block : created_storage_blocks) {
if (block != nullptr && block->buffer != nullptr) {
LOG_DEBUG("model manager prepared params backend buffer (%6.2f MB, %zu tensors, %s)",
ggml_backend_buffer_get_size(block->buffer) / (1024.f * 1024.f),
block->states.size(),
ggml_backend_buffer_is_host(block->buffer) ? "RAM" : "VRAM");
}
}
return true;
}
bool ModelManager::stage_tensors_to_compute_backend(const std::vector<TensorState*>& states) {
std::map<ggml_backend_t, std::vector<TensorState*>> states_by_compute_backend;
for (TensorState* state : states) {
if (state == nullptr || should_ignore(*state) || is_optional_missing_tensor(state->name)) {
continue;
}
if (state->compute_backend == nullptr) {
LOG_ERROR("model manager compute backend is null for tensor '%s'", state->name.c_str());
return false;
}
if (state->params_backend == nullptr) {
LOG_ERROR("model manager params backend is null for tensor '%s'", state->name.c_str());
return false;
}
if (state->compute_backend == state->params_backend || state->staged_to_compute_backend) {
continue;
}
if (!state->loaded_to_params_backend || state->tensor == nullptr || state->tensor->data == nullptr) {
LOG_ERROR("model manager tensor '%s' is not loaded to params backend", state->name.c_str());
return false;
}
states_by_compute_backend[state->compute_backend].push_back(state);
}
for (const auto& pair : states_by_compute_backend) {
ggml_backend_t compute_backend = pair.first;
const std::vector<TensorState*>& states = pair.second;
if (states.empty()) {
continue;
}
int64_t t0 = ggml_time_ms();
ggml_init_params init_params;
init_params.mem_size = std::max<size_t>(1, states.size()) * ggml_tensor_overhead();
init_params.mem_buffer = nullptr;
init_params.no_alloc = true;
ggml_context* staging_ctx = ggml_init(init_params);
GGML_ASSERT(staging_ctx != nullptr);
std::vector<std::pair<TensorState*, ggml_tensor*>> staged_tensors;
staged_tensors.reserve(states.size());
for (TensorState* state : states) {
ggml_tensor* staging_tensor = ggml_dup_tensor(staging_ctx, state->tensor);
ggml_set_name(staging_tensor, state->tensor->name);
staged_tensors.push_back({state, staging_tensor});
}
ggml_backend_buffer_t compute_buffer = ggml_backend_alloc_ctx_tensors(staging_ctx, compute_backend);
if (compute_buffer == nullptr) {
LOG_ERROR("model manager alloc compute params backend buffer failed, num_tensors = %zu",
staged_tensors.size());
ggml_free(staging_ctx);
return false;
}
ggml_backend_buffer_set_usage(compute_buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
for (auto& staged_tensor : staged_tensors) {
TensorState* state = staged_tensor.first;
ggml_tensor* managed_tensor = state->tensor;
ggml_tensor* staging_tensor = staged_tensor.second;
ggml_backend_tensor_copy(managed_tensor, staging_tensor);
std::swap(managed_tensor->buffer, staging_tensor->buffer);
std::swap(managed_tensor->data, staging_tensor->data);
std::swap(managed_tensor->extra, staging_tensor->extra);
}
ggml_backend_synchronize(compute_backend);
auto block = std::make_unique<ComputeStagingBlock>();
block->compute_backend = compute_backend;
block->buffer = compute_buffer;
block->staging_ctx = staging_ctx;
block->staged_tensors = std::move(staged_tensors);
for (auto& staged_tensor : block->staged_tensors) {
TensorState* state = staged_tensor.first;
state->staged_to_compute_backend = true;
}
compute_staging_blocks_.push_back(std::move(block));
int64_t t1 = ggml_time_ms();
LOG_DEBUG("model manager staged compute params (%6.2f MB, %zu tensors) to %s, taking %.2fs",
ggml_backend_buffer_get_size(compute_buffer) / (1024.f * 1024.f),
states.size(),
ggml_backend_name(compute_backend),
(t1 - t0) * 1.0f / 1000);
}
return true;
}
bool ModelManager::apply_loras_to_params(const std::vector<TensorState*>& states) {
if (loras_.empty()) {
return true;
}
struct LoraApplyGroup {
std::map<std::string, ggml_tensor*> model_tensors;
std::vector<TensorState*> states;
};
std::map<ggml_backend_t, LoraApplyGroup> groups;
for (TensorState* state : states) {
if (state == nullptr || state->tensor == nullptr ||
should_ignore(*state) || is_optional_missing_tensor(state->name)) {
continue;
}
if (state->applied_lora_epoch == current_lora_epoch_) {
continue;
}
if (state->compute_backend == nullptr) {
LOG_ERROR("model manager compute backend is null for lora target tensor '%s'", state->name.c_str());
return false;
}
if (state->tensor->data == nullptr) {
LOG_ERROR("model manager lora target tensor '%s' is not prepared", state->name.c_str());
return false;
}
LoraApplyGroup& group = groups[state->compute_backend];
group.model_tensors[state->name] = state->tensor;
group.states.push_back(state);
}
if (groups.empty()) {
return true;
}
std::set<std::string> all_tensor_names = tensor_names();
for (auto& group_pair : groups) {
ggml_backend_t compute_backend = group_pair.first;
LoraApplyGroup& group = group_pair.second;
for (const LoraSpec& lora_spec : loras_) {
if (group.model_tensors.empty()) {
continue;
}
std::string id = lora_id(lora_spec);
auto lora = std::make_shared<LoraModel>(id,
compute_backend,
compute_backend,
lora_spec.path,
lora_spec.is_high_noise ? "model.high_noise_" : "",
lora_version_);
LoraModel::filter_t lora_tensor_filter = nullptr;
if (!lora_spec.tensor_name_prefix_filter.empty()) {
lora_tensor_filter = [&](const std::string& tensor_name) {
return starts_with(tensor_name, lora_spec.tensor_name_prefix_filter);
};
}
if (!lora->load_from_file(n_threads_, lora_tensor_filter)) {
LOG_WARN("load lora tensors from %s failed", lora_spec.path.c_str());
if (lora_spec.required) {
return false;
}
continue;
}
if (lora->lora_tensors.empty()) {
if (lora_spec.required) {
LOG_ERROR("required lora has no tensors: %s", lora_spec.path.c_str());
return false;
}
continue;
}
lora->multiplier = lora_spec.multiplier;
lora->apply(group.model_tensors, all_tensor_names, lora_version_, n_threads_, false);
lora->release_loaded_tensors();
}
for (TensorState* state : group.states) {
if (state != nullptr) {
state->applied_lora_epoch = current_lora_epoch_;
}
}
}
return true;
}
void ModelManager::reset_lora_applied_params() {
release_compute_staging_blocks(true);
release_params_storage_blocks(true);
for (auto& state : tensor_states_) {
state->applied_lora_epoch = UINT64_MAX;
}
}
bool ModelManager::should_ignore(const TensorState& state) const {
for (const auto& ignore_prefix : common_ignore_tensors_) {
if (starts_with(state.name, ignore_prefix)) {
return true;
}
}
return false;
}
bool ModelManager::is_optional_missing_tensor(const std::string& name) const {
return name.find("cond_stage_model.transformer.text_model.encoder.layers.23") != std::string::npos ||
name.find("alphas_cumprod") != std::string::npos;
}
bool ModelManager::validate_tensor(const TensorState& state) const {
if (state.tensor == nullptr || should_ignore(state) || is_optional_missing_tensor(state.name)) {
return true;
}
const auto& tensor_storage_map = model_loader_.get_tensor_storage_map();
auto ts_it = tensor_storage_map.find(state.name);
if (ts_it == tensor_storage_map.end()) {
LOG_ERROR("%s tensor '%s' not in model metadata", state.desc.c_str(), state.name.c_str());
return false;
}
const TensorStorage& tensor_storage = ts_it->second;
if (state.tensor->ne[0] != tensor_storage.ne[0] ||
state.tensor->ne[1] != tensor_storage.ne[1] ||
state.tensor->ne[2] != tensor_storage.ne[2] ||
state.tensor->ne[3] != tensor_storage.ne[3]) {
LOG_ERROR(
"%s tensor '%s' has wrong shape in model metadata: got [%d, %d, %d, %d], expected [%d, %d, %d, %d]",
state.desc.c_str(),
state.name.c_str(),
(int)tensor_storage.ne[0], (int)tensor_storage.ne[1], (int)tensor_storage.ne[2], (int)tensor_storage.ne[3],
(int)state.tensor->ne[0], (int)state.tensor->ne[1], (int)state.tensor->ne[2], (int)state.tensor->ne[3]);
return false;
}
return true;
}
bool ModelManager::mmap_params(const std::vector<TensorState*>& states,
std::vector<ParamsStorageBlock*>& created_storage_blocks) {
std::map<std::string, ggml_tensor*> mmap_candidates;
std::map<std::string, TensorState*> mmap_states;
for (TensorState* state : states) {
if (state == nullptr || !can_mmap_storage(*state) || state->tensor == nullptr ||
state->tensor->data != nullptr || state->tensor->view_src != nullptr) {
continue;
}
mmap_candidates[state->name] = state->tensor;
mmap_states[state->name] = state;
}
if (mmap_candidates.empty()) {
return true;
}
auto mmap_store = model_loader_.mmap_tensors(mmap_candidates, {}, true);
if (mmap_store.empty()) {
return true;
}
auto block = std::make_unique<ParamsStorageBlock>();
block->mmap_tensor_stores = std::move(mmap_store);
ParamsStorageBlock* raw = block.get();
for (const auto& pair : mmap_states) {
TensorState* state = pair.second;
if (state != nullptr && state->tensor != nullptr && state->tensor->data != nullptr) {
block->states.push_back(state);
}
}
if (!block->states.empty()) {
params_storage_blocks_.push_back(std::move(block));
created_storage_blocks.push_back(raw);
}
return true;
}
bool ModelManager::can_mmap_storage(const TensorState& state) const {
if (!enable_mmap_ || state.residency_mode != ResidencyMode::Resident) {
return false;
}
if (state.compute_backend == nullptr || state.params_backend == nullptr) {
return false;
}
return sd_backend_is_cpu(state.compute_backend) ||
sd_backend_is_cpu(state.params_backend) ||
backend_supports_host_buffer(state.compute_backend);
}
bool ModelManager::alloc_params_buffers(const std::vector<TensorState*>& states,
std::vector<ParamsStorageBlock*>& created_storage_blocks) {
std::map<std::pair<ggml_backend_buffer_type_t, int>, std::vector<TensorState*>> states_by_buffer_type;
for (TensorState* state : states) {
if (state == nullptr || state->tensor == nullptr) {
continue;
}
ggml_backend_buffer_type_t params_buft = params_buffer_type_for(*state);
if (params_buft == nullptr) {
return false;
}
states_by_buffer_type[{params_buft, static_cast<int>(state->residency_mode)}].push_back(state);
}
for (const auto& pair : states_by_buffer_type) {
ggml_backend_buffer_type_t params_buft = pair.first.first;
const std::vector<TensorState*>& states = pair.second;
size_t alignment = ggml_backend_buft_get_alignment(params_buft);
size_t max_size = ggml_backend_buft_get_max_size(params_buft);
auto alloc_chunk = [&](const std::vector<TensorState*>& chunk, size_t chunk_size) -> bool {
if (chunk.empty() || chunk_size == 0) {
return true;
}
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(params_buft, chunk_size);
if (buffer == nullptr) {
LOG_ERROR("model manager alloc params backend buffer failed, size = %.2fMB",
chunk_size / (1024.0 * 1024.0));
return false;
}
ggml_backend_buffer_set_usage(buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
std::vector<ggml_tensor*> initialized_tensors;
void* base = ggml_backend_buffer_get_base(buffer);
size_t offset = aligned_offset(base, 0, ggml_backend_buffer_get_alignment(buffer));
for (TensorState* state : chunk) {
ggml_tensor* tensor = state->tensor;
size_t tensor_size = GGML_PAD(ggml_backend_buffer_get_alloc_size(buffer, tensor),
ggml_backend_buffer_get_alignment(buffer));
enum ggml_status status = ggml_backend_tensor_alloc(buffer, tensor, static_cast<char*>(base) + offset);
if (status != GGML_STATUS_SUCCESS) {
LOG_ERROR("model manager failed to initialize params tensor '%s'", ggml_get_name(tensor));
for (ggml_tensor* initialized : initialized_tensors) {
initialized->buffer = nullptr;
initialized->data = nullptr;
initialized->extra = nullptr;
}
LOG_DEBUG("model manager releasing params backend buffer (%6.2f MB, %zu tensors, %s)",
ggml_backend_buffer_get_size(buffer) / (1024.f * 1024.f),
initialized_tensors.size(),
ggml_backend_buffer_is_host(buffer) ? "RAM" : "VRAM");
ggml_backend_buffer_free(buffer);
return false;
}
initialized_tensors.push_back(tensor);
offset += tensor_size;
}
auto block = std::make_unique<ParamsStorageBlock>();
block->buffer = buffer;
block->states = chunk;
ParamsStorageBlock* raw = block.get();
params_storage_blocks_.push_back(std::move(block));
created_storage_blocks.push_back(raw);
return true;
};
std::vector<TensorState*> chunk;
size_t chunk_size = 0;
for (TensorState* state : states) {
ggml_tensor* tensor = state->tensor;
size_t tensor_size = GGML_PAD(ggml_backend_buft_get_alloc_size(params_buft, tensor), alignment);
if (max_size > 0 && tensor_size > max_size) {
LOG_ERROR("model manager tensor '%s' is too large for params buffer: %zu > %zu",
ggml_get_name(tensor),
tensor_size,
max_size);
return false;
}
if (!chunk.empty() && max_size > 0 && chunk_size + tensor_size > max_size) {
if (!alloc_chunk(chunk, chunk_size)) {
return false;
}
chunk.clear();
chunk_size = 0;
}
chunk.push_back(state);
chunk_size += tensor_size;
}
if (!alloc_chunk(chunk, chunk_size)) {
return false;
}
}
return true;
}
bool ModelManager::load_tensors(const std::vector<TensorState*>& states) {
std::map<std::string, TensorState*> states_by_name;
std::set<std::string> target_tensor_names;
for (TensorState* state : states) {
if (state == nullptr) {
continue;
}
states_by_name[state->name] = state;
target_tensor_names.insert(state->name);
}
if (states_by_name.empty()) {
return true;
}
std::set<std::string> loaded_names;
std::mutex loaded_names_mutex;
auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
const std::string& name = tensor_storage.name;
*dst_tensor = nullptr;
auto state_it = states_by_name.find(name);
if (state_it == states_by_name.end()) {
return true;
}
TensorState* state = state_it->second;
if (state == nullptr || state->tensor == nullptr) {
LOG_ERROR("model manager tensor '%s' is null", name.c_str());
return false;
}
if (state->tensor->ne[0] != tensor_storage.ne[0] ||
state->tensor->ne[1] != tensor_storage.ne[1] ||
state->tensor->ne[2] != tensor_storage.ne[2] ||
state->tensor->ne[3] != tensor_storage.ne[3]) {
LOG_ERROR(
"model manager tensor '%s' has wrong shape in model file: got [%d, %d, %d, %d], expected [%d, %d, %d, %d]",
name.c_str(),
(int)tensor_storage.ne[0], (int)tensor_storage.ne[1], (int)tensor_storage.ne[2], (int)tensor_storage.ne[3],
(int)state->tensor->ne[0], (int)state->tensor->ne[1], (int)state->tensor->ne[2], (int)state->tensor->ne[3]);
return false;
}
{
std::lock_guard<std::mutex> lock(loaded_names_mutex);
loaded_names.insert(name);
}
*dst_tensor = state->tensor;
return true;
};
if (!model_loader_.load_tensors(on_new_tensor_cb, enable_mmap_, &target_tensor_names)) {
LOG_ERROR("model manager load tensors failed");
return false;
}
bool missing = false;
for (const auto& pair : states_by_name) {
const std::string& name = pair.first;
if (loaded_names.find(name) == loaded_names.end()) {
LOG_ERROR("model manager tensor '%s' was not loaded", name.c_str());
missing = true;
}
}
if (missing) {
return false;
}
for (const auto& pair : states_by_name) {
pair.second->loaded_to_params_backend = true;
}
return true;
}
ggml_backend_buffer_type_t ModelManager::params_buffer_type_for(const TensorState& state) const {
if (state.params_backend == nullptr) {
LOG_ERROR("model manager params backend is null for tensor '%s'", state.name.c_str());
return nullptr;
}
ggml_backend_buffer_type_t params_buft = nullptr;
if (state.compute_backend != nullptr && state.params_backend != state.compute_backend) {
ggml_backend_dev_t compute_dev = ggml_backend_get_device(state.compute_backend);
if (compute_dev != nullptr) {
params_buft = ggml_backend_dev_host_buffer_type(compute_dev);
}
}
if (params_buft == nullptr) {
params_buft = ggml_backend_get_default_buffer_type(state.params_backend);
}
return params_buft;
}
void ModelManager::free_compute_staging_block(ComputeStagingBlock& block) {
for (auto& staged_tensor : block.staged_tensors) {
TensorState* state = staged_tensor.first;
ggml_tensor* staging_tensor = staged_tensor.second;
if (state == nullptr || state->tensor == nullptr || staging_tensor == nullptr) {
continue;
}
ggml_tensor* managed_tensor = state->tensor;
managed_tensor->buffer = staging_tensor->buffer;
managed_tensor->data = staging_tensor->data;
managed_tensor->extra = staging_tensor->extra;
staging_tensor->buffer = nullptr;
staging_tensor->data = nullptr;
staging_tensor->extra = nullptr;
state->staged_to_compute_backend = false;
state->applied_lora_epoch = UINT64_MAX;
}
if (block.buffer != nullptr) {
LOG_DEBUG("model manager releasing compute params (%6.2f MB, %zu tensors) from %s",
ggml_backend_buffer_get_size(block.buffer) / (1024.f * 1024.f),
block.staged_tensors.size(),
block.compute_backend != nullptr ? ggml_backend_name(block.compute_backend) : "unknown");
ggml_backend_buffer_free(block.buffer);
block.buffer = nullptr;
}
if (block.staging_ctx != nullptr) {
ggml_free(block.staging_ctx);
block.staging_ctx = nullptr;
}
block.staged_tensors.clear();
}
void ModelManager::release_compute_staging_blocks(bool force,
const std::unordered_set<TensorState*>* target_states) {
for (auto it = compute_staging_blocks_.begin(); it != compute_staging_blocks_.end();) {
ComputeStagingBlock* block = it->get();
bool can_release = force;
if (!can_release) {
can_release = std::all_of(block->staged_tensors.begin(),
block->staged_tensors.end(),
[target_states](const std::pair<TensorState*, ggml_tensor*>& pair) {
TensorState* state = pair.first;
if (state == nullptr) {
return true;
}
if (target_states != nullptr &&
target_states->find(state) == target_states->end()) {
return false;
}
return state->active_prepare_count == 0;
});
}
if (can_release) {
free_compute_staging_block(*block);
it = compute_staging_blocks_.erase(it);
} else {
++it;
}
}
}
void ModelManager::free_params_storage_block(ParamsStorageBlock& block) {
if (block.buffer != nullptr) {
LOG_DEBUG("model manager releasing params backend buffer (%6.2f MB, %zu tensors, %s)",
ggml_backend_buffer_get_size(block.buffer) / (1024.f * 1024.f),
block.states.size(),
ggml_backend_buffer_is_host(block.buffer) ? "RAM" : "VRAM");
ggml_backend_buffer_free(block.buffer);
block.buffer = nullptr;
}
block.mmap_tensor_stores.clear();
for (TensorState* state : block.states) {
if (state == nullptr || state->tensor == nullptr) {
continue;
}
state->tensor->buffer = nullptr;
state->tensor->data = nullptr;
state->tensor->extra = nullptr;
state->loaded_to_params_backend = false;
state->applied_lora_epoch = UINT64_MAX;
}
block.states.clear();
}
void ModelManager::release_params_storage_blocks(bool force,
const std::unordered_set<TensorState*>* target_states) {
for (auto it = params_storage_blocks_.begin(); it != params_storage_blocks_.end();) {
ParamsStorageBlock* block = it->get();
bool can_release = force;
if (!can_release) {
can_release = std::all_of(block->states.begin(),
block->states.end(),
[target_states](TensorState* state) {
if (state == nullptr) {
return true;
}
if (target_states != nullptr &&
target_states->find(state) == target_states->end()) {
return false;
}
return state->active_prepare_count == 0 &&
!state->staged_to_compute_backend &&
state->residency_mode == ResidencyMode::Disk;
});
}
if (can_release) {
free_params_storage_block(*block);
it = params_storage_blocks_.erase(it);
} else {
++it;
}
}
}
void ModelManager::erase_params_storage_block(ParamsStorageBlock* block) {
auto it = std::find_if(params_storage_blocks_.begin(),
params_storage_blocks_.end(),
[block](const std::unique_ptr<ParamsStorageBlock>& item) {
return item.get() == block;
});
if (it != params_storage_blocks_.end()) {
params_storage_blocks_.erase(it);
}
}
void ModelManager::release_all() {
for (auto& state : tensor_states_) {
state->active_prepare_count = 0;
state->applied_lora_epoch = UINT64_MAX;
}
release_compute_staging_blocks(true);
release_params_storage_blocks(true);
}
bool ModelManager::resolve_required_tensor_states(const std::vector<ggml_tensor*>& tensors,
std::vector<TensorState*>& required_states) const {
required_states.clear();
std::unordered_set<TensorState*> seen;
for (ggml_tensor* tensor : tensors) {
if (tensor == nullptr) {
continue;
}
const char* raw_name = ggml_get_name(tensor);
if (raw_name == nullptr || raw_name[0] == '\0') {
LOG_ERROR("model manager unnamed tensor is not registered");
return false;
}
auto state_it = tensor_states_by_name_.find(raw_name);
if (state_it == tensor_states_by_name_.end()) {
LOG_ERROR("model manager tensor '%s' is not registered", raw_name);
return false;
}
TensorState* state = state_it->second;
if (state == nullptr) {
LOG_ERROR("model manager tensor '%s' has no tensor state", raw_name);
return false;
}
if (seen.insert(state).second) {
required_states.push_back(state);
}
}
return true;
}
bool ModelManager::prepare_params(const std::vector<ggml_tensor*>& tensors) {
if (tensors.empty()) {
return true;
}
std::vector<TensorState*> required_states;
if (!resolve_required_tensor_states(tensors, required_states)) {
return false;
}
if (!load_tensors_to_params_backend(required_states)) {
return false;
}
if (!stage_tensors_to_compute_backend(required_states)) {
release_compute_staging_blocks(false);
release_params_storage_blocks(false);
return false;
}
if (!apply_loras_to_params(required_states)) {
release_compute_staging_blocks(false);
release_params_storage_blocks(false);
return false;
}
for (TensorState* state : required_states) {
if (state == nullptr) {
continue;
}
state->active_prepare_count++;
}
return true;
}
void ModelManager::finish_compute_backend_usage(const std::vector<TensorState*>& states) {
if (states.empty()) {
return;
}
std::unordered_set<TensorState*> target_states;
for (TensorState* state : states) {
if (state == nullptr || !target_states.insert(state).second) {
continue;
}
if (state->active_prepare_count > 0) {
state->active_prepare_count--;
}
}
release_compute_staging_blocks(false, &target_states);
}
void ModelManager::release_compute_backend_params(const std::vector<ggml_tensor*>& tensors) {
if (tensors.empty()) {
return;
}
std::vector<TensorState*> required_states;
if (!resolve_required_tensor_states(tensors, required_states)) {
return;
}
finish_compute_backend_usage(required_states);
}
void ModelManager::release_params_backend_params(const std::vector<ggml_tensor*>& tensors) {
if (tensors.empty()) {
return;
}
std::vector<TensorState*> required_states;
if (!resolve_required_tensor_states(tensors, required_states)) {
return;
}
if (required_states.empty()) {
return;
}
std::unordered_set<TensorState*> target_states(required_states.begin(), required_states.end());
release_params_storage_blocks(false, &target_states);
}

131
src/model_manager.h Normal file
View File

@ -0,0 +1,131 @@
#ifndef __MODEL_MANAGER_H__
#define __MODEL_MANAGER_H__
#include <cstdint>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_set>
#include <vector>
#include "model_loader.h"
#include "weight_manager.h"
class ModelManager : public RunnerWeightManager {
public:
enum class ResidencyMode {
Disk,
Resident,
};
struct LoraSpec {
std::string path;
float multiplier = 1.0f;
bool is_high_noise = false;
std::string tensor_name_prefix_filter;
bool required = false;
};
private:
struct TensorState {
std::string name;
ggml_tensor* tensor = nullptr;
std::string desc;
ResidencyMode residency_mode = ResidencyMode::Resident;
ggml_backend_t compute_backend = nullptr;
ggml_backend_t params_backend = nullptr;
bool metadata_validated = false;
int active_prepare_count = 0;
bool loaded_to_params_backend = false;
bool staged_to_compute_backend = false;
uint64_t applied_lora_epoch = UINT64_MAX;
};
struct ParamsStorageBlock {
ggml_backend_buffer_t buffer = nullptr;
std::vector<MmapTensorStore> mmap_tensor_stores;
std::vector<TensorState*> states;
};
struct ComputeStagingBlock {
ggml_backend_t compute_backend = nullptr;
ggml_backend_buffer_t buffer = nullptr;
ggml_context* staging_ctx = nullptr;
std::vector<std::pair<TensorState*, ggml_tensor*>> staged_tensors;
};
ModelLoader model_loader_;
std::vector<std::unique_ptr<TensorState>> tensor_states_;
std::map<std::string, TensorState*> tensor_states_by_name_;
std::vector<std::unique_ptr<ParamsStorageBlock>> params_storage_blocks_;
std::vector<std::unique_ptr<ComputeStagingBlock>> compute_staging_blocks_;
std::set<std::string> common_ignore_tensors_;
std::vector<LoraSpec> loras_;
SDVersion lora_version_ = VERSION_COUNT;
uint64_t current_lora_epoch_ = 0;
int n_threads_ = 0;
bool enable_mmap_ = false;
void finish_compute_backend_usage(const std::vector<TensorState*>& states);
void release_all();
bool resolve_required_tensor_states(const std::vector<ggml_tensor*>& tensors,
std::vector<TensorState*>& required_states) const;
bool should_ignore(const TensorState& state) const;
bool is_optional_missing_tensor(const std::string& name) const;
bool validate_tensor(const TensorState& state) const;
bool load_tensors_to_params_backend(const std::vector<TensorState*>& states);
bool apply_loras_to_params(const std::vector<TensorState*>& states);
bool mmap_params(const std::vector<TensorState*>& states,
std::vector<ParamsStorageBlock*>& created_storage_blocks);
bool can_mmap_storage(const TensorState& state) const;
bool alloc_params_buffers(const std::vector<TensorState*>& states,
std::vector<ParamsStorageBlock*>& created_storage_blocks);
bool load_tensors(const std::vector<TensorState*>& states);
bool stage_tensors_to_compute_backend(const std::vector<TensorState*>& states);
ggml_backend_buffer_type_t params_buffer_type_for(const TensorState& state) const;
void release_compute_staging_blocks(bool force = false,
const std::unordered_set<TensorState*>* target_states = nullptr);
void release_params_storage_blocks(bool force = false,
const std::unordered_set<TensorState*>* target_states = nullptr);
void free_compute_staging_block(ComputeStagingBlock& block);
void free_params_storage_block(ParamsStorageBlock& block);
void erase_params_storage_block(ParamsStorageBlock* block);
void reset_lora_applied_params();
public:
~ModelManager() override;
ModelLoader& loader() { return model_loader_; }
const ModelLoader& loader() const { return model_loader_; }
void set_n_threads(int n_threads) {
n_threads_ = n_threads;
model_loader_.set_n_threads(n_threads);
}
void set_enable_mmap(bool enable_mmap) { enable_mmap_ = enable_mmap; }
void set_common_ignore_tensors(std::set<std::string> ignore_tensors);
void set_loras(std::vector<LoraSpec> loras, SDVersion version);
std::set<std::string> tensor_names() const;
bool register_param_tensors(const std::string& desc,
std::map<std::string, ggml_tensor*> tensors,
ResidencyMode residency_mode,
ggml_backend_t compute_backend,
ggml_backend_t params_backend,
size_t* registered_tensor_size = nullptr);
bool validate_registered_tensors();
bool prepare_params(const std::vector<ggml_tensor*>& tensors) override;
void release_compute_backend_params(const std::vector<ggml_tensor*>& tensors) override;
void release_params_backend_params(const std::vector<ggml_tensor*>& tensors) override;
};
#endif // __MODEL_MANAGER_H__

View File

@ -172,8 +172,8 @@ namespace sd::guidance {
momentum_buffer_ = deltas; momentum_buffer_ = deltas;
} }
float diff_norm = 0.0f; float diff_norm = 0.0f;
const int standard_res = 2 * 1024 / 8; // Use SDXL as the standard resolution (1024x1024, 8x8 patches, 4=2x2 channels) const int standard_res = 2 * 1024 / 8; // Use SDXL as the standard resolution (1024x1024, 8x8 patches, 4=2x2 channels)
if (params_.norm_threshold > 0.0f) { if (params_.norm_threshold > 0.0f) {
diff_norm = std::sqrt((deltas * deltas).sum()) * standard_res / std::sqrt(static_cast<float>(deltas.numel())); diff_norm = std::sqrt((deltas * deltas).sum()) * standard_res / std::sqrt(static_cast<float>(deltas.numel()));
} }

File diff suppressed because it is too large Load Diff

15
src/weight_manager.h Normal file
View File

@ -0,0 +1,15 @@
#ifndef __WEIGHT_MANAGER_H__
#define __WEIGHT_MANAGER_H__
#include <vector>
struct ggml_tensor;
struct RunnerWeightManager {
virtual ~RunnerWeightManager() = default;
virtual bool prepare_params(const std::vector<ggml_tensor*>& tensors) = 0;
virtual void release_compute_backend_params(const std::vector<ggml_tensor*>& tensors) = 0;
virtual void release_params_backend_params(const std::vector<ggml_tensor*>& tensors) = 0;
};
#endif // __WEIGHT_MANAGER_H__