Compare commits

..

3 Commits

2 changed files with 20 additions and 2 deletions

View File

@ -2470,10 +2470,26 @@ protected:
*effective_budget_out = effective_budget; *effective_budget_out = effective_budget;
} }
// When streaming and the parameters exceed three quarters of the budget,
// cap the planner at a quarter of the budget so it builds smaller merged
// segments and chunk-K can fit alongside. Without streaming the cap only
// adds dispatch overhead.
size_t planner_budget = effective_budget;
if (stream_layers_enabled) {
size_t total_params_bytes = 0;
for (const ggml_tensor* t : params_tensor_set_) {
if (t != nullptr) {
total_params_bytes += ggml_nbytes(t);
}
}
if (total_params_bytes * 4 > effective_budget * 3) {
planner_budget = effective_budget / 4;
}
}
*plan_out = sd::ggml_graph_cut::resolve_plan(runtime_backend, *plan_out = sd::ggml_graph_cut::resolve_plan(runtime_backend,
gf, gf,
&graph_cut_plan_cache_, &graph_cut_plan_cache_,
effective_budget, planner_budget,
params_tensor_set_, params_tensor_set_,
get_desc().c_str()); get_desc().c_str());
if (stream_layers_enabled) { if (stream_layers_enabled) {
@ -3311,6 +3327,7 @@ public:
for (auto& pair : params) { for (auto& pair : params) {
ggml_tensor* param = pair.second; ggml_tensor* param = pair.second;
tensors[prefix + pair.first] = pair.second; tensors[prefix + pair.first] = pair.second;
ggml_set_name(param, (prefix + pair.first).c_str());
} }
} }

View File

@ -173,8 +173,9 @@ namespace sd::guidance {
} }
float diff_norm = 0.0f; float diff_norm = 0.0f;
const int standard_res = 2 * 1024 / 8; // Use SDXL as the standard resolution: 1024x1024 with 8x8 patches and 4 channels, so sqrt(numel) = sqrt(4) * 1024 / 8 = 256
if (params_.norm_threshold > 0.0f) { if (params_.norm_threshold > 0.0f) {
diff_norm = std::sqrt((deltas * deltas).sum()); diff_norm = std::sqrt((deltas * deltas).sum()) * standard_res / std::sqrt(static_cast<float>(deltas.numel()));
} }
float apg_scale_factor = 1.0f; float apg_scale_factor = 1.0f;