feat: support independent sampler rng (#978)

This commit is contained in:
leejet 2025-11-16 17:11:02 +08:00 committed by GitHub
parent 6d6dc1b8ed
commit d5b05f70c6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 68 additions and 26 deletions

View File

@ -95,6 +95,7 @@ Options:
--type weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K). If not specified, the default is the --type weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K). If not specified, the default is the
type of the weight file type of the weight file
--rng RNG, one of [std_default, cuda, cpu], default: cuda(sd-webui), cpu(comfyui) --rng RNG, one of [std_default, cuda, cpu], default: cuda(sd-webui), cpu(comfyui)
--sampler-rng sampler RNG, one of [std_default, cuda, cpu]. If not specified, use --rng
-s, --seed RNG seed (default: 42, use random seed for < 0) -s, --seed RNG seed (default: 42, use random seed for < 0)
--sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, --sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing,
tcd] (default: euler for Flux/SD3/Wan, euler_a otherwise) tcd] (default: euler for Flux/SD3/Wan, euler_a otherwise)

View File

@ -110,21 +110,22 @@ struct SDParams {
int fps = 16; int fps = 16;
float vace_strength = 1.f; float vace_strength = 1.f;
float strength = 0.75f; float strength = 0.75f;
float control_strength = 0.9f; float control_strength = 0.9f;
rng_type_t rng_type = CUDA_RNG; rng_type_t rng_type = CUDA_RNG;
int64_t seed = 42; rng_type_t sampler_rng_type = RNG_TYPE_COUNT;
bool verbose = false; int64_t seed = 42;
bool offload_params_to_cpu = false; bool verbose = false;
bool control_net_cpu = false; bool offload_params_to_cpu = false;
bool clip_on_cpu = false; bool control_net_cpu = false;
bool vae_on_cpu = false; bool clip_on_cpu = false;
bool diffusion_flash_attn = false; bool vae_on_cpu = false;
bool diffusion_conv_direct = false; bool diffusion_flash_attn = false;
bool vae_conv_direct = false; bool diffusion_conv_direct = false;
bool canny_preprocess = false; bool vae_conv_direct = false;
bool color = false; bool canny_preprocess = false;
int upscale_repeats = 1; bool color = false;
int upscale_repeats = 1;
// Photo Maker // Photo Maker
std::string photo_maker_path; std::string photo_maker_path;
@ -214,6 +215,7 @@ void print_params(SDParams params) {
printf(" flow_shift: %.2f\n", params.flow_shift); printf(" flow_shift: %.2f\n", params.flow_shift);
printf(" strength(img2img): %.2f\n", params.strength); printf(" strength(img2img): %.2f\n", params.strength);
printf(" rng: %s\n", sd_rng_type_name(params.rng_type)); printf(" rng: %s\n", sd_rng_type_name(params.rng_type));
printf(" sampler rng: %s\n", sd_rng_type_name(params.sampler_rng_type));
printf(" seed: %zd\n", params.seed); printf(" seed: %zd\n", params.seed);
printf(" batch_count: %d\n", params.batch_count); printf(" batch_count: %d\n", params.batch_count);
printf(" vae_tiling: %s\n", params.vae_tiling_params.enabled ? "true" : "false"); printf(" vae_tiling: %s\n", params.vae_tiling_params.enabled ? "true" : "false");
@ -886,6 +888,20 @@ void parse_args(int argc, const char** argv, SDParams& params) {
return 1; return 1;
}; };
auto on_sampler_rng_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) {
return -1;
}
const char* arg = argv[index];
params.sampler_rng_type = str_to_rng_type(arg);
if (params.sampler_rng_type == RNG_TYPE_COUNT) {
fprintf(stderr, "error: invalid sampler rng type %s\n",
arg);
return -1;
}
return 1;
};
auto on_schedule_arg = [&](int argc, const char** argv, int index) { auto on_schedule_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) { if (++index >= argc) {
return -1; return -1;
@ -1126,6 +1142,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
"--rng", "--rng",
"RNG, one of [std_default, cuda, cpu], default: cuda(sd-webui), cpu(comfyui)", "RNG, one of [std_default, cuda, cpu], default: cuda(sd-webui), cpu(comfyui)",
on_rng_arg}, on_rng_arg},
{"",
"--sampler-rng",
"sampler RNG, one of [std_default, cuda, cpu]. If not specified, use --rng",
on_sampler_rng_arg},
{"-s", {"-s",
"--seed", "--seed",
"RNG seed (default: 42, use random seed for < 0)", "RNG seed (default: 42, use random seed for < 0)",
@ -1319,6 +1339,9 @@ std::string get_image_params(SDParams params, int64_t seed) {
parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", "; parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", ";
parameter_string += "Model: " + sd_basename(params.model_path) + ", "; parameter_string += "Model: " + sd_basename(params.model_path) + ", ";
parameter_string += "RNG: " + std::string(sd_rng_type_name(params.rng_type)) + ", "; parameter_string += "RNG: " + std::string(sd_rng_type_name(params.rng_type)) + ", ";
if (params.sampler_rng_type != RNG_TYPE_COUNT) {
parameter_string += "Sampler RNG: " + std::string(sd_rng_type_name(params.sampler_rng_type)) + ", ";
}
parameter_string += "Sampler: " + std::string(sd_sample_method_name(params.sample_params.sample_method)); parameter_string += "Sampler: " + std::string(sd_sample_method_name(params.sample_params.sample_method));
if (params.sample_params.scheduler != DEFAULT) { if (params.sample_params.scheduler != DEFAULT) {
parameter_string += " " + std::string(sd_schedule_name(params.sample_params.scheduler)); parameter_string += " " + std::string(sd_schedule_name(params.sample_params.scheduler));
@ -1758,6 +1781,7 @@ int main(int argc, const char* argv[]) {
params.n_threads, params.n_threads,
params.wtype, params.wtype,
params.rng_type, params.rng_type,
params.sampler_rng_type,
params.prediction, params.prediction,
params.lora_apply_mode, params.lora_apply_mode,
params.offload_params_to_cpu, params.offload_params_to_cpu,

View File

@ -99,10 +99,11 @@ public:
bool vae_decode_only = false; bool vae_decode_only = false;
bool free_params_immediately = false; bool free_params_immediately = false;
std::shared_ptr<RNG> rng = std::make_shared<STDDefaultRNG>(); std::shared_ptr<RNG> rng = std::make_shared<PhiloxRNG>();
int n_threads = -1; std::shared_ptr<RNG> sampler_rng = nullptr;
float scale_factor = 0.18215f; int n_threads = -1;
float shift_factor = 0.f; float scale_factor = 0.18215f;
float shift_factor = 0.f;
std::shared_ptr<Conditioner> cond_stage_model; std::shared_ptr<Conditioner> cond_stage_model;
std::shared_ptr<FrozenCLIPVisionEmbedder> clip_vision; // for svd or wan2.1 i2v std::shared_ptr<FrozenCLIPVisionEmbedder> clip_vision; // for svd or wan2.1 i2v
@ -188,6 +189,16 @@ public:
} }
} }
std::shared_ptr<RNG> get_rng(rng_type_t rng_type) {
if (rng_type == STD_DEFAULT_RNG) {
return std::make_shared<STDDefaultRNG>();
} else if (rng_type == CPU_RNG) {
return std::make_shared<MT19937RNG>();
} else { // default: CUDA_RNG
return std::make_shared<PhiloxRNG>();
}
}
bool init(const sd_ctx_params_t* sd_ctx_params) { bool init(const sd_ctx_params_t* sd_ctx_params) {
n_threads = sd_ctx_params->n_threads; n_threads = sd_ctx_params->n_threads;
vae_decode_only = sd_ctx_params->vae_decode_only; vae_decode_only = sd_ctx_params->vae_decode_only;
@ -197,12 +208,11 @@ public:
use_tiny_autoencoder = taesd_path.size() > 0; use_tiny_autoencoder = taesd_path.size() > 0;
offload_params_to_cpu = sd_ctx_params->offload_params_to_cpu; offload_params_to_cpu = sd_ctx_params->offload_params_to_cpu;
if (sd_ctx_params->rng_type == STD_DEFAULT_RNG) { rng = get_rng(sd_ctx_params->rng_type);
rng = std::make_shared<STDDefaultRNG>(); if (sd_ctx_params->sampler_rng_type != RNG_TYPE_COUNT) {
} else if (sd_ctx_params->rng_type == CUDA_RNG) { sampler_rng = get_rng(sd_ctx_params->sampler_rng_type);
rng = std::make_shared<PhiloxRNG>(); } else {
} else if (sd_ctx_params->rng_type == CPU_RNG) { sampler_rng = rng;
rng = std::make_shared<MT19937RNG>();
} }
ggml_log_set(ggml_log_callback_default, nullptr); ggml_log_set(ggml_log_callback_default, nullptr);
@ -1736,7 +1746,7 @@ public:
return denoised; return denoised;
}; };
sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng, eta); sample_k_diffusion(method, denoise, work_ctx, x, sigmas, sampler_rng, eta);
if (inverse_noise_scaling) { if (inverse_noise_scaling) {
x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x); x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x);
@ -2291,6 +2301,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
sd_ctx_params->n_threads = get_num_physical_cores(); sd_ctx_params->n_threads = get_num_physical_cores();
sd_ctx_params->wtype = SD_TYPE_COUNT; sd_ctx_params->wtype = SD_TYPE_COUNT;
sd_ctx_params->rng_type = CUDA_RNG; sd_ctx_params->rng_type = CUDA_RNG;
sd_ctx_params->sampler_rng_type = RNG_TYPE_COUNT;
sd_ctx_params->prediction = DEFAULT_PRED; sd_ctx_params->prediction = DEFAULT_PRED;
sd_ctx_params->lora_apply_mode = LORA_APPLY_AUTO; sd_ctx_params->lora_apply_mode = LORA_APPLY_AUTO;
sd_ctx_params->offload_params_to_cpu = false; sd_ctx_params->offload_params_to_cpu = false;
@ -2332,6 +2343,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
"n_threads: %d\n" "n_threads: %d\n"
"wtype: %s\n" "wtype: %s\n"
"rng_type: %s\n" "rng_type: %s\n"
"sampler_rng_type: %s\n"
"prediction: %s\n" "prediction: %s\n"
"offload_params_to_cpu: %s\n" "offload_params_to_cpu: %s\n"
"keep_clip_on_cpu: %s\n" "keep_clip_on_cpu: %s\n"
@ -2362,6 +2374,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
sd_ctx_params->n_threads, sd_ctx_params->n_threads,
sd_type_name(sd_ctx_params->wtype), sd_type_name(sd_ctx_params->wtype),
sd_rng_type_name(sd_ctx_params->rng_type), sd_rng_type_name(sd_ctx_params->rng_type),
sd_rng_type_name(sd_ctx_params->sampler_rng_type),
sd_prediction_name(sd_ctx_params->prediction), sd_prediction_name(sd_ctx_params->prediction),
BOOL_STR(sd_ctx_params->offload_params_to_cpu), BOOL_STR(sd_ctx_params->offload_params_to_cpu),
BOOL_STR(sd_ctx_params->keep_clip_on_cpu), BOOL_STR(sd_ctx_params->keep_clip_on_cpu),
@ -2823,6 +2836,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
LOG_INFO("generating image: %i/%i - seed %" PRId64, b + 1, batch_count, cur_seed); LOG_INFO("generating image: %i/%i - seed %" PRId64, b + 1, batch_count, cur_seed);
sd_ctx->sd->rng->manual_seed(cur_seed); sd_ctx->sd->rng->manual_seed(cur_seed);
sd_ctx->sd->sampler_rng->manual_seed(cur_seed);
struct ggml_tensor* x_t = init_latent; struct ggml_tensor* x_t = init_latent;
struct ggml_tensor* noise = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1); struct ggml_tensor* noise = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
ggml_ext_im_set_randn_f32(noise, sd_ctx->sd->rng); ggml_ext_im_set_randn_f32(noise, sd_ctx->sd->rng);
@ -2949,6 +2963,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
seed = rand(); seed = rand();
} }
sd_ctx->sd->rng->manual_seed(seed); sd_ctx->sd->rng->manual_seed(seed);
sd_ctx->sd->sampler_rng->manual_seed(seed);
int sample_steps = sd_img_gen_params->sample_params.sample_steps; int sample_steps = sd_img_gen_params->sample_params.sample_steps;
@ -3240,6 +3255,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
} }
sd_ctx->sd->rng->manual_seed(seed); sd_ctx->sd->rng->manual_seed(seed);
sd_ctx->sd->sampler_rng->manual_seed(seed);
int64_t t0 = ggml_time_ms(); int64_t t0 = ggml_time_ms();

View File

@ -173,6 +173,7 @@ typedef struct {
int n_threads; int n_threads;
enum sd_type_t wtype; enum sd_type_t wtype;
enum rng_type_t rng_type; enum rng_type_t rng_type;
enum rng_type_t sampler_rng_type;
enum prediction_t prediction; enum prediction_t prediction;
enum lora_apply_mode_t lora_apply_mode; enum lora_apply_mode_t lora_apply_mode;
bool offload_params_to_cpu; bool offload_params_to_cpu;