Compare commits

...

3 Commits

2 changed files with 20 additions and 2 deletions

View File

@ -2470,10 +2470,26 @@ protected:
*effective_budget_out = effective_budget;
}
// When streaming and the model dwarfs the budget, cap the planner at
// a quarter so it builds smaller merged segments and chunk-K can fit
// alongside. Without streaming the cap only adds dispatch overhead.
size_t planner_budget = effective_budget;
if (stream_layers_enabled) {
size_t total_params_bytes = 0;
for (const ggml_tensor* t : params_tensor_set_) {
if (t != nullptr) {
total_params_bytes += ggml_nbytes(t);
}
}
if (total_params_bytes * 4 > effective_budget * 3) {
planner_budget = effective_budget / 4;
}
}
*plan_out = sd::ggml_graph_cut::resolve_plan(runtime_backend,
gf,
&graph_cut_plan_cache_,
effective_budget,
planner_budget,
params_tensor_set_,
get_desc().c_str());
if (stream_layers_enabled) {
@ -3311,6 +3327,7 @@ public:
for (auto& pair : params) {
ggml_tensor* param = pair.second;
tensors[prefix + pair.first] = pair.second;
ggml_set_name(param, (prefix + pair.first).c_str());
}
}

View File

@ -173,8 +173,9 @@ namespace sd::guidance {
}
float diff_norm = 0.0f;
const int standard_res = 2 * 1024 / 8; // Use SDXL as the standard resolution (1024x1024, 8x8 patches, 4=2x2 channels)
if (params_.norm_threshold > 0.0f) {
diff_norm = std::sqrt((deltas * deltas).sum());
diff_norm = std::sqrt((deltas * deltas).sum()) * standard_res / std::sqrt(static_cast<float>(deltas.numel()));
}
float apg_scale_factor = 1.0f;