Compare commits

...

4 Commits

Author SHA1 Message Date
leejet
f440ad9c29
fix: avoid writable mmap for read-only weights (#1698) 2026-06-23 00:39:31 +08:00
stduhpf
41f7acbfb0
feat: support guidance_schedule (#1684) 2026-06-23 00:05:55 +08:00
leejet
b395a6972d
refactor: add Flux VAE version helper (#1696) 2026-06-22 22:39:42 +08:00
Alex Klinkhamer
854bebfe02
feat: add --prompt-file and --negative-prompt-file flags (#1693) 2026-06-22 22:16:54 +08:00
8 changed files with 194 additions and 24 deletions

View File

@ -6,6 +6,7 @@
#include <cstdlib>
#include <ctime>
#include <filesystem>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <regex>
@ -260,15 +261,15 @@ bool parse_options(int argc, const char** argv, const std::vector<ArgOptions>& o
invalid_arg = true;
return;
}
if(option.concat && !option.target->empty()){
if(option.concat > 0 && option.concat <= 0xff){
if (option.concat && !option.target->empty()) {
if (option.concat > 0 && option.concat <= 0xff) {
*option.target += static_cast<char>(option.concat);
}
*option.target += argv_to_utf8(i, argv);
} else {
*option.target = argv_to_utf8(i, argv);
}
found_arg = true;
found_arg = true;
}))
break;
@ -959,7 +960,7 @@ ArgOptions SDGenerationParams::get_options() {
&hires_upscaler},
{"",
"--extra-sample-args",
"extra sampler/scheduler/guidance args, key=value list. APG supports apg_eta, apg_momentum, apg_norm_threshold, apg_norm_threshold_smoothing; SLG supports slg_uncond; lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma",
"extra sampler/scheduler/guidance args, key=value list. CFG supports guidance_schedule; APG supports apg_eta, apg_momentum, apg_norm_threshold, apg_norm_threshold_smoothing; SLG supports slg_uncond; lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma;",
(int)',',
&extra_sample_args},
{"",
@ -1421,6 +1422,42 @@ ArgOptions SDGenerationParams::get_options() {
return 1;
};
auto on_prompt_file_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) {
return -1;
}
const char* arg = argv[index];
std::ifstream f(arg, std::ios::binary);
try {
prompt = std::string(std::istreambuf_iterator<char>{f}, {});
} catch (const std::ios_base::failure&) {
f.setstate(std::ios_base::failbit);
}
if (f.fail()) {
LOG_ERROR("error: failed to read prompt file '%s'\n", arg);
return -1;
}
return 1;
};
auto on_negative_prompt_file_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) {
return -1;
}
const char* arg = argv[index];
std::ifstream f(arg, std::ios::binary);
try {
negative_prompt = std::string(std::istreambuf_iterator<char>{f}, {});
} catch (const std::ios_base::failure&) {
f.setstate(std::ios_base::failbit);
}
if (f.fail()) {
LOG_ERROR("error: failed to read negative prompt file '%s'\n", arg);
return -1;
}
return 1;
};
options.manual_options = {
{"-s",
"--seed",
@ -1484,6 +1521,14 @@ ArgOptions SDGenerationParams::get_options() {
"--vae-relative-tile-size",
"relative tile size for vae tiling, format [X]x[Y], in fraction of image size if < 1, in number of tiles per dim if >=1 (overrides --vae-tile-size)",
on_relative_tile_size_arg},
{"",
"--prompt-file",
"path to the file containing the prompt to render",
on_prompt_file_arg},
{"",
"--negative-prompt-file",
"path to the file containing the negative prompt",
on_negative_prompt_file_arg},
};

View File

@ -186,6 +186,13 @@ static inline bool sd_version_is_ideogram4(SDVersion version) {
return false;
}
// Returns true when `version` is accepted by any of the sd_version_is_flux,
// sd_version_is_z_image, sd_version_is_boogu_image or sd_version_is_longcat
// predicates, i.e. the versions grouped together as using the Flux VAE.
static inline bool sd_version_uses_flux_vae(SDVersion version) {
    return sd_version_is_flux(version) ||
           sd_version_is_z_image(version) ||
           sd_version_is_boogu_image(version) ||
           sd_version_is_longcat(version);
}
static inline bool sd_version_uses_flux2_vae(SDVersion version) {
if (sd_version_is_flux2(version) || sd_version_is_ernie_image(version) || sd_version_is_lens(version) || sd_version_is_ideogram4(version)) {
return true;

View File

@ -682,7 +682,7 @@ struct AutoEncoderKL : public VAE {
} else if (sd_version_is_sd3(version)) {
scale_factor = 1.5305f;
shift_factor = 0.0609f;
} else if (sd_version_is_flux(version) || sd_version_is_z_image(version) || sd_version_is_boogu_image(version) || sd_version_is_longcat(version)) {
} else if (sd_version_uses_flux_vae(version)) {
scale_factor = 0.3611f;
shift_factor = 0.1159f;
} else if (sd_version_uses_flux2_vae(version)) {

View File

@ -480,7 +480,7 @@ bool ModelManager::mmap_params(const std::vector<TensorState*>& states,
return true;
}
auto mmap_store = model_loader_.mmap_tensors(mmap_candidates, {}, true);
auto mmap_store = model_loader_.mmap_tensors(mmap_candidates, {}, writable_mmap_);
if (mmap_store.empty()) {
return true;
}

View File

@ -69,6 +69,7 @@ private:
uint64_t current_lora_epoch_ = 0;
int n_threads_ = 0;
bool enable_mmap_ = false;
bool writable_mmap_ = false;
void finish_compute_backend_usage(const std::vector<TensorState*>& states);
void release_all();
@ -110,6 +111,7 @@ public:
model_loader_.set_n_threads(n_threads);
}
void set_enable_mmap(bool enable_mmap) { enable_mmap_ = enable_mmap; }
void set_writable_mmap(bool writable_mmap) { writable_mmap_ = writable_mmap; }
void set_common_ignore_tensors(std::set<std::string> ignore_tensors);
void set_loras(std::vector<LoraSpec> loras, SDVersion version);

View File

@ -3,6 +3,7 @@
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <optional>
#include <string>
#include <utility>
@ -63,6 +64,82 @@ namespace sd::guidance {
return uncond;
}
// Expands a guidance schedule spec into one guidance value per entry.
//
// The spec is a '+'-separated list of "<guidance>x<count>" segments; each
// segment appends `count` copies of `guidance`, in order. For example
// "7.5x2+3x1" yields {7.5, 7.5, 3}. An empty spec yields an empty schedule,
// and a single trailing '+' is tolerated.
//
// Any malformed segment (no 'x', a guidance or count that is not fully
// numeric, or a count <= 0) is reported through LOG_ERROR and the whole
// result is discarded: an empty vector is returned.
std::vector<float> parse_guidance_schedule_from_spec(std::string spec) {
    std::vector<float> result;
    size_t cursor = 0;
    while (cursor < spec.size()) {
        // Carve out the next segment without modifying `spec`.
        const size_t plus        = spec.find('+', cursor);
        const bool is_last       = (plus == std::string::npos);
        const std::string segment = is_last ? spec.substr(cursor) : spec.substr(cursor, plus - cursor);

        const size_t split = segment.find('x');
        if (split == std::string::npos) {
            LOG_ERROR("Invalid guidance schedule segment: '%s' (expected <guidance>x<count>)", segment.c_str());
            return {};
        }
        const std::string scale_text  = segment.substr(0, split);
        const std::string repeat_text = segment.substr(split + 1);

        // The guidance value must be consumed entirely by std::stof.
        float scale   = 0.0f;
        bool scale_ok = false;
        try {
            size_t consumed = 0;
            scale           = std::stof(scale_text, &consumed);
            scale_ok        = (consumed == scale_text.size());
        } catch (const std::exception&) {
            scale_ok = false;
        }
        if (!scale_ok) {
            LOG_ERROR("Invalid guidance value in guidance schedule: '%s'", scale_text.c_str());
            return {};
        }

        // Likewise the repeat count must be consumed entirely by std::stoi.
        int repeat     = 0;
        bool repeat_ok = false;
        try {
            size_t consumed = 0;
            repeat          = std::stoi(repeat_text, &consumed);
            repeat_ok       = (consumed == repeat_text.size());
        } catch (const std::exception&) {
            repeat_ok = false;
        }
        if (!repeat_ok) {
            LOG_ERROR("Invalid count in guidance schedule: '%s'", repeat_text.c_str());
            return {};
        }
        if (repeat < 1) {
            LOG_ERROR("Guidance schedule count must be positive");
            return {};
        }

        result.insert(result.end(), repeat, scale);
        if (is_last) {
            break;
        }
        cursor = plus + 1;
    }
    return result;
}
// Looks up the `guidance_schedule` entry in a key=value list of extra sample
// args and expands it into a per-entry guidance vector.
//
// @param extra_sample_args  raw key=value list, forwarded unchanged to
//                           parse_key_value_args.
// @return the schedule produced by parse_guidance_schedule_from_spec, or an
//         empty vector when the key is absent or its value is empty.
std::vector<float> parse_guidance_schedule(const char* extra_sample_args) {
    std::string spec;
    for (const auto& [key, value] : parse_key_value_args(extra_sample_args, "extra sample arg")) {
        if (key == "guidance_schedule") {
            // No early exit: if the key appears more than once, the last value wins.
            spec = value;
        }
    }
    if (spec.empty()) {
        return {};
    }
    return parse_guidance_schedule_from_spec(spec);
}
ClassifierFreeGuidance::ClassifierFreeGuidance(float guidance_scale,
float image_guidance_scale)
: guidance_scale_(guidance_scale),
@ -70,8 +147,10 @@ namespace sd::guidance {
}
GuiderOutput ClassifierFreeGuidance::forward(const GuidanceInput& input,
GuiderOutput previous) const {
GuiderOutput previous,
std::optional<float> scale_override) const {
(void)previous;
float guidance_scale = scale_override.value_or(guidance_scale_);
GuiderOutput output;
if (!has_tensor(input.pred_cond)) {
@ -86,14 +165,14 @@ namespace sd::guidance {
const sd::Tensor<float>& pred_img_uncond = *input.pred_img_uncond;
output.pred = pred_img_uncond +
image_guidance_scale_ * (pred_uncond - pred_img_uncond) +
guidance_scale_ * (pred_cond - pred_uncond);
guidance_scale * (pred_cond - pred_uncond);
} else {
output.pred = pred_uncond + guidance_scale_ * (pred_cond - pred_uncond);
output.pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond);
}
} else if (has_tensor(input.pred_img_uncond)) {
const sd::Tensor<float>& pred_img_uncond = *input.pred_img_uncond;
output.pred = pred_img_uncond + guidance_scale_ * (pred_cond - pred_img_uncond);
output.pred = pred_img_uncond + guidance_scale * (pred_cond - pred_img_uncond);
}
return output;
@ -128,8 +207,10 @@ namespace sd::guidance {
}
GuiderOutput AdaptiveProjectedGuidance::forward(const GuidanceInput& input,
GuiderOutput previous) const {
GuiderOutput previous,
std::optional<float> scale_override) const {
(void)previous;
float guidance_scale = scale_override.value_or(guidance_scale_);
GuiderOutput output;
if (!has_tensor(input.pred_cond)) {
@ -144,13 +225,13 @@ namespace sd::guidance {
const sd::Tensor<float>& pred_img_uncond = *input.pred_img_uncond;
output.pred = pred_img_uncond +
image_guidance_scale_ * (pred_uncond - pred_img_uncond) +
guidance_scale_ * (pred_cond - pred_uncond);
guidance_scale * (pred_cond - pred_uncond);
} else {
output.pred = pred_uncond + guidance_scale_ * (pred_cond - pred_uncond);
output.pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond);
}
} else if (has_tensor(input.pred_img_uncond)) {
const sd::Tensor<float>& pred_img_uncond = *input.pred_img_uncond;
output.pred = pred_img_uncond + guidance_scale_ * (pred_cond - pred_img_uncond);
output.pred = pred_img_uncond + guidance_scale * (pred_cond - pred_img_uncond);
}
if (!has_tensor(input.pred_uncond) && !has_tensor(input.pred_img_uncond)) {
return output;
@ -162,7 +243,7 @@ namespace sd::guidance {
sd::Tensor<float> deltas = calculate_guidance_delta(pred_cond,
pred_uncond,
pred_img_uncond,
guidance_scale_,
guidance_scale,
image_guidance_scale_);
if (params_.momentum != 0.0f) {
if (momentum_buffer_.shape() != deltas.shape()) {
@ -239,7 +320,8 @@ namespace sd::guidance {
}
GuiderOutput SkipLayerGuidance::forward(const GuidanceInput& input,
GuiderOutput output) const {
GuiderOutput output,
std::optional<float> /*scale_override*/) const {
if (scale_ == 0.0f || !is_enabled_for_step(input) || !input.predict_skip_layer) {
return output;
}

View File

@ -3,6 +3,7 @@
#include <cstddef>
#include <functional>
#include <optional>
#include <vector>
#include "core/tensor.hpp"
@ -27,6 +28,7 @@ namespace sd::guidance {
AdaptiveProjectedGuidanceParams parse_adaptive_projected_guidance_args(const char* extra_sample_args);
bool is_adaptive_projected_guidance_enabled(const AdaptiveProjectedGuidanceParams& params);
bool parse_skip_layer_guidance_uncond_arg(const char* extra_sample_args);
std::vector<float> parse_guidance_schedule(const char* extra_sample_args);
struct GuidanceInput {
int step = 0;
@ -40,9 +42,10 @@ namespace sd::guidance {
class BaseGuidance {
public:
virtual ~BaseGuidance() = default;
virtual ~BaseGuidance() = default;
virtual GuiderOutput forward(const GuidanceInput& input,
GuiderOutput previous) const = 0;
GuiderOutput previous,
std::optional<float> scale_override = std::nullopt) const = 0;
};
class ClassifierFreeGuidance : public BaseGuidance {
@ -54,7 +57,8 @@ namespace sd::guidance {
float image_guidance_scale);
GuiderOutput forward(const GuidanceInput& input,
GuiderOutput previous) const override;
GuiderOutput previous,
std::optional<float> scale_override = std::nullopt) const override;
};
class AdaptiveProjectedGuidance : public BaseGuidance {
@ -69,7 +73,8 @@ namespace sd::guidance {
AdaptiveProjectedGuidanceParams params);
GuiderOutput forward(const GuidanceInput& input,
GuiderOutput previous) const override;
GuiderOutput previous,
std::optional<float> scale_override = std::nullopt) const override;
};
class SkipLayerGuidance : public BaseGuidance {
@ -88,7 +93,8 @@ namespace sd::guidance {
const std::vector<int>& layers() const;
GuiderOutput forward(const GuidanceInput& input,
GuiderOutput previous) const override;
GuiderOutput previous,
std::optional<float> scale_override = std::nullopt) const override;
};
} // namespace sd::guidance

View File

@ -532,7 +532,6 @@ public:
if (wtype != GGML_TYPE_COUNT || tensor_type_rules.size() > 0) {
model_loader.set_wtype_override(wtype, tensor_type_rules);
}
model_loader.process_model_files(enable_mmap, true);
std::map<ggml_type, uint32_t> wtype_stat = model_loader.get_wtype_stat();
std::map<ggml_type, uint32_t> conditioner_wtype_stat = model_loader.get_conditioner_wtype_stat();
@ -586,9 +585,12 @@ public:
apply_lora_immediately = false;
}
bool needs_writable_mmap = enable_mmap && apply_lora_immediately;
model_manager->set_writable_mmap(needs_writable_mmap);
if (enable_mmap && apply_lora_immediately) {
LOG_WARN("in mode 'immediately', LoRAs will cause extra memory usage with mmap");
}
model_loader.process_model_files(enable_mmap, needs_writable_mmap);
load_alphas_cumprod(model_loader);
size_t text_encoder_params_mem_size = 0;
@ -1719,7 +1721,7 @@ public:
if (sd_version_is_sd3(version)) {
latent_rgb_proj = sd3_latent_rgb_proj;
latent_rgb_bias = sd3_latent_rgb_bias;
} else if (sd_version_is_flux(version) || sd_version_is_z_image(version) || sd_version_is_boogu_image(version) || sd_version_is_longcat(version)) {
} else if (sd_version_uses_flux_vae(version)) {
latent_rgb_proj = flux_latent_rgb_proj;
latent_rgb_bias = flux_latent_rgb_bias;
} else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
@ -1942,6 +1944,32 @@ public:
float slg_scale = guidance.slg.scale;
bool slg_uncond = sd::guidance::parse_skip_layer_guidance_uncond_arg(extra_sample_args);
std::vector<float> guidance_schedule = sd::guidance::parse_guidance_schedule(extra_sample_args);
// Reconcile the user-supplied schedule with the step count, which this code
// takes to be sigmas.size() - 1.
if (!guidance_schedule.empty()) {
    // Guard the subtraction: size() is unsigned, so an empty `sigmas` would wrap around.
    const size_t schedule_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
    if (guidance_schedule.size() > schedule_steps) {
        // Compare against the step count itself so that a schedule exactly one
        // entry too long is truncated instead of falling into the padding branch.
        LOG_WARN("guidance_schedule length (%zu) is greater than number of steps (%zu)", guidance_schedule.size(), schedule_steps);
        LOG_WARN("truncating guidance_schedule to match step count");
        guidance_schedule.resize(schedule_steps);
    } else if (guidance_schedule.size() < schedule_steps) {
        // Steps not covered by the schedule fall back to cfg_scale.
        LOG_INFO("padding guidance_schedule with cfg_scale");
        guidance_schedule.resize(schedule_steps, cfg_scale);
    }
}
if (!guidance_schedule.empty()) {
    // Render the schedule as "[v0, v1, ...]" for the debug log.
    std::string rendered  = "[";
    const char* separator = "";
    for (const float scale : guidance_schedule) {
        rendered += separator;
        rendered += std::to_string(scale);
        separator = ", ";
    }
    rendered += "]";
    LOG_DEBUG("using guidance schedule: %s", rendered.c_str());
}
sd_sample::SampleCacheRuntime cache_runtime = sd_sample::init_sample_cache_runtime(version,
cache_params,
denoiser.get(),
@ -2182,7 +2210,7 @@ public:
guidance_input.pred_uncond = uncond_out.empty() ? nullptr : &uncond_out;
guidance_input.pred_img_uncond = img_uncond_out.empty() ? nullptr : &img_uncond_out;
sd::guidance::GuiderOutput guided = primary_guidance.forward(guidance_input, {});
sd::guidance::GuiderOutput guided = guidance_schedule.empty() ? primary_guidance.forward(guidance_input, {}) : primary_guidance.forward(guidance_input, {}, guidance_schedule[guidance_schedule.size() - 1 - step]);
if (guided.pred.empty()) {
return {};
}