mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-12 13:28:37 +00:00
feat: add support for timestep boundary based automatic expert routing in Wan MoE (#779)
* Wan MoE: Automatic expert routing based on timestep boundary * unify code style and fix some issues --------- Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
parent
cb1d975e96
commit
21ce9fe2cf
@ -89,6 +89,8 @@ struct SDParams {
|
||||
std::vector<int> high_noise_skip_layers = {7, 8, 9};
|
||||
sd_sample_params_t high_noise_sample_params;
|
||||
|
||||
float moe_boundary = 0.875f;
|
||||
|
||||
int video_frames = 1;
|
||||
int fps = 16;
|
||||
|
||||
@ -117,6 +119,7 @@ struct SDParams {
|
||||
SDParams() {
|
||||
sd_sample_params_init(&sample_params);
|
||||
sd_sample_params_init(&high_noise_sample_params);
|
||||
high_noise_sample_params.sample_steps = -1;
|
||||
}
|
||||
};
|
||||
|
||||
@ -167,6 +170,7 @@ void print_params(SDParams params) {
|
||||
printf(" height: %d\n", params.height);
|
||||
printf(" sample_params: %s\n", SAFE_STR(sample_params_str));
|
||||
printf(" high_noise_sample_params: %s\n", SAFE_STR(high_noise_sample_params_str));
|
||||
printf(" moe_boundary: %.3f\n", params.moe_boundary);
|
||||
printf(" strength(img2img): %.2f\n", params.strength);
|
||||
printf(" rng: %s\n", sd_rng_type_name(params.rng_type));
|
||||
printf(" seed: %ld\n", params.seed);
|
||||
@ -243,7 +247,7 @@ void print_usage(int argc, const char* argv[]) {
|
||||
printf(" --high-noise-scheduler {discrete, karras, exponential, ays, gits} Denoiser sigma scheduler (default: discrete)\n");
|
||||
printf(" --high-noise-sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd}\n");
|
||||
printf(" (high noise) sampling method (default: \"euler_a\")\n");
|
||||
printf(" --high-noise-steps STEPS (high noise) number of sample steps (default: 20)\n");
|
||||
printf(" --high-noise-steps STEPS (high noise) number of sample steps (default: -1 = auto)\n");
|
||||
printf(" SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END])\n");
|
||||
printf(" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n");
|
||||
printf(" --style-ratio STYLE-RATIO strength for keeping input identity (default: 20)\n");
|
||||
@ -274,6 +278,8 @@ void print_usage(int argc, const char* argv[]) {
|
||||
printf(" --chroma-t5-mask-pad PAD_SIZE t5 mask pad size of chroma\n");
|
||||
printf(" --video-frames video frames (default: 1)\n");
|
||||
printf(" --fps fps (default: 24)\n");
|
||||
printf(" --moe-boundary BOUNDARY Timestep boundary for Wan2.2 MoE model. (default: 0.875)\n");
|
||||
printf(" Only enabled if `--high-noise-steps` is set to -1\n");
|
||||
printf(" -v, --verbose print extra info\n");
|
||||
}
|
||||
|
||||
@ -362,7 +368,7 @@ bool parse_options(int argc, const char** argv, ArgOptions& options) {
|
||||
std::string arg;
|
||||
for (int i = 1; i < argc; i++) {
|
||||
bool found_arg = false;
|
||||
arg = argv[i];
|
||||
arg = argv[i];
|
||||
|
||||
for (auto& option : options.string_options) {
|
||||
if ((option.short_name.size() > 0 && arg == option.short_name) || (option.long_name.size() > 0 && arg == option.long_name)) {
|
||||
@ -423,7 +429,7 @@ bool parse_options(int argc, const char** argv, ArgOptions& options) {
|
||||
for (auto& option : options.manual_options) {
|
||||
if ((option.short_name.size() > 0 && arg == option.short_name) || (option.long_name.size() > 0 && arg == option.long_name)) {
|
||||
found_arg = true;
|
||||
int ret = option.cb(argc, argv, i);
|
||||
int ret = option.cb(argc, argv, i);
|
||||
if (ret < 0) {
|
||||
invalid_arg = true;
|
||||
break;
|
||||
@ -435,7 +441,7 @@ bool parse_options(int argc, const char** argv, ArgOptions& options) {
|
||||
break;
|
||||
}
|
||||
if (!found_arg) {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@ -507,6 +513,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
|
||||
{"", "--strength", "", ¶ms.strength},
|
||||
{"", "--style-ratio", "", ¶ms.style_ratio},
|
||||
{"", "--control-strength", "", ¶ms.control_strength},
|
||||
{"", "--moe-boundary", "", ¶ms.moe_boundary},
|
||||
};
|
||||
|
||||
options.bool_options = {
|
||||
@ -767,8 +774,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
|
||||
}
|
||||
|
||||
if (params.high_noise_sample_params.sample_steps <= 0) {
|
||||
fprintf(stderr, "error: the high_noise_sample_steps must be greater than 0\n");
|
||||
exit(1);
|
||||
params.high_noise_sample_params.sample_steps = -1;
|
||||
}
|
||||
|
||||
if (params.strength < 0.f || params.strength > 1.f) {
|
||||
@ -1222,6 +1228,7 @@ int main(int argc, const char* argv[]) {
|
||||
params.height,
|
||||
params.sample_params,
|
||||
params.high_noise_sample_params,
|
||||
params.moe_boundary,
|
||||
params.strength,
|
||||
params.seed,
|
||||
params.video_frames,
|
||||
|
||||
@ -1727,11 +1727,13 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
|
||||
memset((void*)sd_vid_gen_params, 0, sizeof(sd_vid_gen_params_t));
|
||||
sd_sample_params_init(&sd_vid_gen_params->sample_params);
|
||||
sd_sample_params_init(&sd_vid_gen_params->high_noise_sample_params);
|
||||
sd_vid_gen_params->width = 512;
|
||||
sd_vid_gen_params->height = 512;
|
||||
sd_vid_gen_params->strength = 0.75f;
|
||||
sd_vid_gen_params->seed = -1;
|
||||
sd_vid_gen_params->video_frames = 6;
|
||||
sd_vid_gen_params->high_noise_sample_params.sample_steps = -1;
|
||||
sd_vid_gen_params->width = 512;
|
||||
sd_vid_gen_params->height = 512;
|
||||
sd_vid_gen_params->strength = 0.75f;
|
||||
sd_vid_gen_params->seed = -1;
|
||||
sd_vid_gen_params->video_frames = 6;
|
||||
sd_vid_gen_params->moe_boundary = 0.875f;
|
||||
}
|
||||
|
||||
struct sd_ctx_t {
|
||||
@ -2381,7 +2383,24 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
high_noise_sample_steps = sd_vid_gen_params->high_noise_sample_params.sample_steps;
|
||||
}
|
||||
|
||||
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps + high_noise_sample_steps);
|
||||
int total_steps = sample_steps;
|
||||
|
||||
if (high_noise_sample_steps > 0) {
|
||||
total_steps += high_noise_sample_steps;
|
||||
}
|
||||
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps);
|
||||
|
||||
if (high_noise_sample_steps < 0) {
|
||||
// timesteps ∝ sigmas for Flow models (like wan2.2 a14b)
|
||||
for (size_t i = 0; i < sigmas.size(); ++i) {
|
||||
if (sigmas[i] < sd_vid_gen_params->moe_boundary) {
|
||||
high_noise_sample_steps = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
LOG_DEBUG("switching from high noise model at step %d", high_noise_sample_steps);
|
||||
sample_steps = total_steps - high_noise_sample_steps;
|
||||
}
|
||||
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = static_cast<size_t>(200 * 1024) * 1024; // 200 MB
|
||||
|
||||
@ -205,6 +205,7 @@ typedef struct {
|
||||
int height;
|
||||
sd_sample_params_t sample_params;
|
||||
sd_sample_params_t high_noise_sample_params;
|
||||
float moe_boundary;
|
||||
float strength;
|
||||
int64_t seed;
|
||||
int video_frames;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user