add wan2.2 t2v support

This commit is contained in:
leejet 2025-08-25 00:10:16 +08:00
parent afef8cef9e
commit 079b393b6e
5 changed files with 333 additions and 173 deletions

View File

@ -202,14 +202,16 @@ struct FluxModel : public DiffusionModel {
}; };
struct WanModel : public DiffusionModel { struct WanModel : public DiffusionModel {
std::string prefix;
WAN::WanRunner wan; WAN::WanRunner wan;
WanModel(ggml_backend_t backend, WanModel(ggml_backend_t backend,
bool offload_params_to_cpu, bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {}, const String2GGMLType& tensor_types = {},
SDVersion version = VERSION_FLUX, const std::string prefix = "model.diffusion_model",
SDVersion version = VERSION_WAN2,
bool flash_attn = false) bool flash_attn = false)
: wan(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, flash_attn) { : prefix(prefix), wan(backend, offload_params_to_cpu, tensor_types, prefix, version, flash_attn) {
} }
std::string get_desc() { std::string get_desc() {
@ -229,7 +231,7 @@ struct WanModel : public DiffusionModel {
} }
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) { void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
wan.get_param_tensors(tensors, "model.diffusion_model"); wan.get_param_tensors(tensors, prefix);
} }
size_t get_params_buffer_size() { size_t get_params_buffer_size() {

View File

@ -56,6 +56,7 @@ struct SDParams {
std::string clip_vision_path; std::string clip_vision_path;
std::string t5xxl_path; std::string t5xxl_path;
std::string diffusion_model_path; std::string diffusion_model_path;
std::string high_noise_diffusion_model_path;
std::string vae_path; std::string vae_path;
std::string taesd_path; std::string taesd_path;
std::string esrgan_path; std::string esrgan_path;
@ -74,22 +75,21 @@ struct SDParams {
std::string prompt; std::string prompt;
std::string negative_prompt; std::string negative_prompt;
float cfg_scale = 7.0f;
float img_cfg_scale = INFINITY;
float guidance = 3.5f;
float eta = 0.f;
float style_ratio = 20.f; float style_ratio = 20.f;
int clip_skip = -1; // <= 0 represents unspecified int clip_skip = -1; // <= 0 represents unspecified
int width = 512; int width = 512;
int height = 512; int height = 512;
int batch_count = 1; int batch_count = 1;
std::vector<int> skip_layers = {7, 8, 9};
sd_sample_params_t sample_params;
std::vector<int> high_noise_skip_layers = {7, 8, 9};
sd_sample_params_t high_noise_sample_params;
int video_frames = 1; int video_frames = 1;
int fps = 16; int fps = 16;
sample_method_t sample_method = EULER_A;
scheduler_t scheduler = DEFAULT;
int sample_steps = 20;
float strength = 0.75f; float strength = 0.75f;
float control_strength = 0.9f; float control_strength = 0.9f;
rng_type_t rng_type = CUDA_RNG; rng_type_t rng_type = CUDA_RNG;
@ -106,17 +106,19 @@ struct SDParams {
bool color = false; bool color = false;
int upscale_repeats = 1; int upscale_repeats = 1;
std::vector<int> skip_layers = {7, 8, 9};
float slg_scale = 0.f;
float skip_layer_start = 0.01f;
float skip_layer_end = 0.2f;
bool chroma_use_dit_mask = true; bool chroma_use_dit_mask = true;
bool chroma_use_t5_mask = false; bool chroma_use_t5_mask = false;
int chroma_t5_mask_pad = 1; int chroma_t5_mask_pad = 1;
SDParams() {
sd_sample_params_init(&sample_params);
sd_sample_params_init(&high_noise_sample_params);
}
}; };
void print_params(SDParams params) { void print_params(SDParams params) {
char* sample_params_str = sd_sample_params_to_str(&params.sample_params);
char* high_noise_sample_params_str = sd_sample_params_to_str(&params.high_noise_sample_params);
printf("Option: \n"); printf("Option: \n");
printf(" n_threads: %d\n", params.n_threads); printf(" n_threads: %d\n", params.n_threads);
printf(" mode: %s\n", modes_str[params.mode]); printf(" mode: %s\n", modes_str[params.mode]);
@ -127,6 +129,7 @@ void print_params(SDParams params) {
printf(" clip_vision_path: %s\n", params.clip_vision_path.c_str()); printf(" clip_vision_path: %s\n", params.clip_vision_path.c_str());
printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str()); printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str());
printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str()); printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str());
printf(" high_noise_diffusion_model_path: %s\n", params.high_noise_diffusion_model_path.c_str());
printf(" vae_path: %s\n", params.vae_path.c_str()); printf(" vae_path: %s\n", params.vae_path.c_str());
printf(" taesd_path: %s\n", params.taesd_path.c_str()); printf(" taesd_path: %s\n", params.taesd_path.c_str());
printf(" esrgan_path: %s\n", params.esrgan_path.c_str()); printf(" esrgan_path: %s\n", params.esrgan_path.c_str());
@ -152,17 +155,11 @@ void print_params(SDParams params) {
printf(" strength(control): %.2f\n", params.control_strength); printf(" strength(control): %.2f\n", params.control_strength);
printf(" prompt: %s\n", params.prompt.c_str()); printf(" prompt: %s\n", params.prompt.c_str());
printf(" negative_prompt: %s\n", params.negative_prompt.c_str()); printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
printf(" cfg_scale: %.2f\n", params.cfg_scale);
printf(" img_cfg_scale: %.2f\n", params.img_cfg_scale);
printf(" slg_scale: %.2f\n", params.slg_scale);
printf(" guidance: %.2f\n", params.guidance);
printf(" eta: %.2f\n", params.eta);
printf(" clip_skip: %d\n", params.clip_skip); printf(" clip_skip: %d\n", params.clip_skip);
printf(" width: %d\n", params.width); printf(" width: %d\n", params.width);
printf(" height: %d\n", params.height); printf(" height: %d\n", params.height);
printf(" sample_method: %s\n", sd_sample_method_name(params.sample_method)); printf(" sample_params: %s\n", SAFE_STR(sample_params_str));
printf(" scheduler: %s\n", sd_schedule_name(params.scheduler)); printf(" high_noise_sample_params: %s\n", SAFE_STR(high_noise_sample_params_str));
printf(" sample_steps: %d\n", params.sample_steps);
printf(" strength(img2img): %.2f\n", params.strength); printf(" strength(img2img): %.2f\n", params.strength);
printf(" rng: %s\n", sd_rng_type_name(params.rng_type)); printf(" rng: %s\n", sd_rng_type_name(params.rng_type));
printf(" seed: %ld\n", params.seed); printf(" seed: %ld\n", params.seed);
@ -174,6 +171,8 @@ void print_params(SDParams params) {
printf(" chroma_t5_mask_pad: %d\n", params.chroma_t5_mask_pad); printf(" chroma_t5_mask_pad: %d\n", params.chroma_t5_mask_pad);
printf(" video_frames: %d\n", params.video_frames); printf(" video_frames: %d\n", params.video_frames);
printf(" fps: %d\n", params.fps); printf(" fps: %d\n", params.fps);
free(sample_params_str);
free(high_noise_sample_params_str);
} }
void print_usage(int argc, const char* argv[]) { void print_usage(int argc, const char* argv[]) {
@ -186,6 +185,7 @@ void print_usage(int argc, const char* argv[]) {
printf(" If threads <= 0, then threads will be set to the number of CPU physical cores\n"); printf(" If threads <= 0, then threads will be set to the number of CPU physical cores\n");
printf(" -m, --model [MODEL] path to full model\n"); printf(" -m, --model [MODEL] path to full model\n");
printf(" --diffusion-model path to the standalone diffusion model\n"); printf(" --diffusion-model path to the standalone diffusion model\n");
printf(" --high-noise-diffusion-model path to the standalone high noise diffusion model\n");
printf(" --clip_l path to the clip-l text encoder\n"); printf(" --clip_l path to the clip-l text encoder\n");
printf(" --clip_g path to the clip-g text encoder\n"); printf(" --clip_g path to the clip-g text encoder\n");
printf(" --clip_vision path to the clip-vision encoder\n"); printf(" --clip_vision path to the clip-vision encoder\n");
@ -219,6 +219,23 @@ void print_usage(int argc, const char* argv[]) {
printf(" --skip-layers LAYERS Layers to skip for SLG steps: (default: [7,8,9])\n"); printf(" --skip-layers LAYERS Layers to skip for SLG steps: (default: [7,8,9])\n");
printf(" --skip-layer-start START SLG enabling point: (default: 0.01)\n"); printf(" --skip-layer-start START SLG enabling point: (default: 0.01)\n");
printf(" --skip-layer-end END SLG disabling point: (default: 0.2)\n"); printf(" --skip-layer-end END SLG disabling point: (default: 0.2)\n");
printf(" --scheduler {discrete, karras, exponential, ays, gits} Denoiser sigma scheduler (default: discrete)\n");
printf(" --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd}\n");
printf(" sampling method (default: \"euler_a\")\n");
printf(" --steps STEPS number of sample steps (default: 20)\n");
printf(" --high-noise-cfg-scale SCALE (high noise) unconditional guidance scale: (default: 7.0)\n");
printf(" --high-noise-img-cfg-scale SCALE (high noise) image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale)\n");
printf(" --high-noise-guidance SCALE (high noise) distilled guidance scale for models with guidance input (default: 3.5)\n");
printf(" --high-noise-slg-scale SCALE (high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n");
printf(" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n");
printf(" --high-noise-eta SCALE (high noise) eta in DDIM, only for DDIM and TCD: (default: 0)\n");
printf(" --high-noise-skip-layers LAYERS (high noise) Layers to skip for SLG steps: (default: [7,8,9])\n");
printf(" --high-noise-skip-layer-start (high noise) SLG enabling point: (default: 0.01)\n");
printf(" --high-noise-skip-layer-end END (high noise) SLG disabling point: (default: 0.2)\n");
printf(" --high-noise-scheduler {discrete, karras, exponential, ays, gits} Denoiser sigma scheduler (default: discrete)\n");
printf(" --high-noise-sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd}\n");
printf(" (high noise) sampling method (default: \"euler_a\")\n");
printf(" --high-noise-steps STEPS (high noise) number of sample steps (default: 20)\n");
printf(" SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END])\n"); printf(" SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END])\n");
printf(" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n"); printf(" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n");
printf(" --style-ratio STYLE-RATIO strength for keeping input identity (default: 20)\n"); printf(" --style-ratio STYLE-RATIO strength for keeping input identity (default: 20)\n");
@ -226,13 +243,9 @@ void print_usage(int argc, const char* argv[]) {
printf(" 1.0 corresponds to full destruction of information in init image\n"); printf(" 1.0 corresponds to full destruction of information in init image\n");
printf(" -H, --height H image height, in pixel space (default: 512)\n"); printf(" -H, --height H image height, in pixel space (default: 512)\n");
printf(" -W, --width W image width, in pixel space (default: 512)\n"); printf(" -W, --width W image width, in pixel space (default: 512)\n");
printf(" --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd}\n");
printf(" sampling method (default: \"euler_a\")\n");
printf(" --steps STEPS number of sample steps (default: 20)\n");
printf(" --rng {std_default, cuda} RNG (default: cuda)\n"); printf(" --rng {std_default, cuda} RNG (default: cuda)\n");
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n"); printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
printf(" -b, --batch-count COUNT number of images to generate\n"); printf(" -b, --batch-count COUNT number of images to generate\n");
printf(" --scheduler {discrete, karras, exponential, ays, gits} Denoiser sigma scheduler (default: discrete)\n");
printf(" --clip-skip N ignore last_dot_pos layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n"); printf(" --clip-skip N ignore last_dot_pos layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n");
printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n"); printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n");
printf(" --vae-tiling process vae in tiles to reduce memory usage\n"); printf(" --vae-tiling process vae in tiles to reduce memory usage\n");
@ -420,6 +433,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
{"", "--clip_vision", "", &params.clip_vision_path}, {"", "--clip_vision", "", &params.clip_vision_path},
{"", "--t5xxl", "", &params.t5xxl_path}, {"", "--t5xxl", "", &params.t5xxl_path},
{"", "--diffusion-model", "", &params.diffusion_model_path}, {"", "--diffusion-model", "", &params.diffusion_model_path},
{"", "--high-noise-diffusion-model", "", &params.high_noise_diffusion_model_path},
{"", "--vae", "", &params.vae_path}, {"", "--vae", "", &params.vae_path},
{"", "--taesd", "", &params.taesd_path}, {"", "--taesd", "", &params.taesd_path},
{"", "--control-net", "", &params.control_net_path}, {"", "--control-net", "", &params.control_net_path},
@ -443,7 +457,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
{"", "--upscale-repeats", "", &params.upscale_repeats}, {"", "--upscale-repeats", "", &params.upscale_repeats},
{"-H", "--height", "", &params.height}, {"-H", "--height", "", &params.height},
{"-W", "--width", "", &params.width}, {"-W", "--width", "", &params.width},
{"", "--steps", "", &params.sample_steps}, {"", "--steps", "", &params.sample_params.sample_steps},
{"", "--high-noise-steps", "", &params.high_noise_sample_params.sample_steps},
{"", "--clip-skip", "", &params.clip_skip}, {"", "--clip-skip", "", &params.clip_skip},
{"-b", "--batch-count", "", &params.batch_count}, {"-b", "--batch-count", "", &params.batch_count},
{"", "--chroma-t5-mask-pad", "", &params.chroma_t5_mask_pad}, {"", "--chroma-t5-mask-pad", "", &params.chroma_t5_mask_pad},
@ -452,17 +467,23 @@ void parse_args(int argc, const char** argv, SDParams& params) {
}; };
options.float_options = { options.float_options = {
{"", "--cfg-scale", "", &params.cfg_scale}, {"", "--cfg-scale", "", &params.sample_params.guidance.txt_cfg},
{"", "--img-cfg-scale", "", &params.img_cfg_scale}, {"", "--img-cfg-scale", "", &params.sample_params.guidance.img_cfg},
{"", "--guidance", "", &params.guidance}, {"", "--guidance", "", &params.sample_params.guidance.distilled_guidance},
{"", "--eta", "", &params.eta}, {"", "--slg-scale", "", &params.sample_params.guidance.slg.scale},
{"", "--skip-layer-start", "", &params.sample_params.guidance.slg.layer_start},
{"", "--skip-layer-end", "", &params.sample_params.guidance.slg.layer_end},
{"", "--eta", "", &params.sample_params.eta},
{"", "--high-noise-cfg-scale", "", &params.high_noise_sample_params.guidance.txt_cfg},
{"", "--high-noise-img-cfg-scale", "", &params.high_noise_sample_params.guidance.img_cfg},
{"", "--high-noise-guidance", "", &params.high_noise_sample_params.guidance.distilled_guidance},
{"", "--high-noise-slg-scale", "", &params.high_noise_sample_params.guidance.slg.scale},
{"", "--high-noise-skip-layer-start", "", &params.high_noise_sample_params.guidance.slg.layer_start},
{"", "--high-noise-skip-layer-end", "", &params.high_noise_sample_params.guidance.slg.layer_end},
{"", "--high-noise-eta", "", &params.high_noise_sample_params.eta},
{"", "--strength", "", &params.strength}, {"", "--strength", "", &params.strength},
{"", "--style-ratio", "", &params.style_ratio}, {"", "--style-ratio", "", &params.style_ratio},
{"", "--control-strength", "", &params.control_strength}, {"", "--control-strength", "", &params.control_strength},
{"", "--slg-scale", "", &params.slg_scale},
{"", "--skip-layer-start", "", &params.skip_layer_start},
{"", "--skip-layer-end", "", &params.skip_layer_end},
}; };
options.bool_options = { options.bool_options = {
@ -536,8 +557,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
return -1; return -1;
} }
const char* arg = argv[index]; const char* arg = argv[index];
params.scheduler = str_to_schedule(arg); params.sample_params.scheduler = str_to_schedule(arg);
if (params.scheduler == SCHEDULE_COUNT) { if (params.sample_params.scheduler == SCHEDULE_COUNT) {
fprintf(stderr, "error: invalid scheduler %s\n", fprintf(stderr, "error: invalid scheduler %s\n",
arg); arg);
return -1; return -1;
@ -545,13 +566,27 @@ void parse_args(int argc, const char** argv, SDParams& params) {
return 1; return 1;
}; };
auto on_high_noise_schedule_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) {
return -1;
}
const char* arg = argv[index];
params.high_noise_sample_params.scheduler = str_to_schedule(arg);
if (params.high_noise_sample_params.scheduler == SCHEDULE_COUNT) {
fprintf(stderr, "error: invalid high noise scheduler %s\n",
arg);
return -1;
}
return 1;
};
auto on_sample_method_arg = [&](int argc, const char** argv, int index) { auto on_sample_method_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) { if (++index >= argc) {
return -1; return -1;
} }
const char* arg = argv[index]; const char* arg = argv[index];
params.sample_method = str_to_sample_method(arg); params.sample_params.sample_method = str_to_sample_method(arg);
if (params.sample_method == SAMPLE_METHOD_COUNT) { if (params.sample_params.sample_method == SAMPLE_METHOD_COUNT) {
fprintf(stderr, "error: invalid sample method %s\n", fprintf(stderr, "error: invalid sample method %s\n",
arg); arg);
return -1; return -1;
@ -559,6 +594,20 @@ void parse_args(int argc, const char** argv, SDParams& params) {
return 1; return 1;
}; };
auto on_high_noise_sample_method_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) {
return -1;
}
const char* arg = argv[index];
params.high_noise_sample_params.sample_method = str_to_sample_method(arg);
if (params.high_noise_sample_params.sample_method == SAMPLE_METHOD_COUNT) {
fprintf(stderr, "error: invalid high noise sample method %s\n",
arg);
return -1;
}
return 1;
};
auto on_seed_arg = [&](int argc, const char** argv, int index) { auto on_seed_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) { if (++index >= argc) {
return -1; return -1;
@ -600,6 +649,33 @@ void parse_args(int argc, const char** argv, SDParams& params) {
return 1; return 1;
}; };
auto on_high_noise_skip_layers_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) {
return -1;
}
std::string layers_str = argv[index];
if (layers_str[0] != '[' || layers_str[layers_str.size() - 1] != ']') {
return -1;
}
layers_str = layers_str.substr(1, layers_str.size() - 2);
std::regex regex("[, ]+");
std::sregex_token_iterator iter(layers_str.begin(), layers_str.end(), regex, -1);
std::sregex_token_iterator end;
std::vector<std::string> tokens(iter, end);
std::vector<int> layers;
for (const auto& token : tokens) {
try {
layers.push_back(std::stoi(token));
} catch (const std::invalid_argument& e) {
return -1;
}
}
params.high_noise_skip_layers = layers;
return 1;
};
auto on_ref_image_arg = [&](int argc, const char** argv, int index) { auto on_ref_image_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) { if (++index >= argc) {
return -1; return -1;
@ -616,6 +692,9 @@ void parse_args(int argc, const char** argv, SDParams& params) {
{"", "--sampling-method", "", on_sample_method_arg}, {"", "--sampling-method", "", on_sample_method_arg},
{"", "--scheduler", "", on_schedule_arg}, {"", "--scheduler", "", on_schedule_arg},
{"", "--skip-layers", "", on_skip_layers_arg}, {"", "--skip-layers", "", on_skip_layers_arg},
{"", "--high-noise-sampling-method", "", on_high_noise_sample_method_arg},
{"", "--high-noise-scheduler", "", on_high_noise_schedule_arg},
{"", "--high-noise-skip-layers", "", on_high_noise_skip_layers_arg},
{"-r", "--ref-image", "", on_ref_image_arg}, {"-r", "--ref-image", "", on_ref_image_arg},
{"-h", "--help", "", on_help_arg}, {"-h", "--help", "", on_help_arg},
}; };
@ -657,11 +736,16 @@ void parse_args(int argc, const char** argv, SDParams& params) {
exit(1); exit(1);
} }
if (params.sample_steps <= 0) { if (params.sample_params.sample_steps <= 0) {
fprintf(stderr, "error: the sample_steps must be greater than 0\n"); fprintf(stderr, "error: the sample_steps must be greater than 0\n");
exit(1); exit(1);
} }
if (params.high_noise_sample_params.sample_steps <= 0) {
fprintf(stderr, "error: the high_noise_sample_steps must be greater than 0\n");
exit(1);
}
if (params.strength < 0.f || params.strength > 1.f) { if (params.strength < 0.f || params.strength > 1.f) {
fprintf(stderr, "error: can only work with strength in [0.0, 1.0]\n"); fprintf(stderr, "error: can only work with strength in [0.0, 1.0]\n");
exit(1); exit(1);
@ -697,8 +781,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
} }
} }
if (!isfinite(params.img_cfg_scale)) { if (!isfinite(params.sample_params.guidance.img_cfg)) {
params.img_cfg_scale = params.cfg_scale; params.sample_params.guidance.img_cfg = params.sample_params.guidance.txt_cfg;
}
if (!isfinite(params.high_noise_sample_params.guidance.img_cfg)) {
params.high_noise_sample_params.guidance.img_cfg = params.high_noise_sample_params.guidance.txt_cfg;
} }
} }
@ -719,27 +807,27 @@ std::string get_image_params(SDParams params, int64_t seed) {
if (params.negative_prompt.size() != 0) { if (params.negative_prompt.size() != 0) {
parameter_string += "Negative prompt: " + params.negative_prompt + "\n"; parameter_string += "Negative prompt: " + params.negative_prompt + "\n";
} }
parameter_string += "Steps: " + std::to_string(params.sample_steps) + ", "; parameter_string += "Steps: " + std::to_string(params.sample_params.sample_steps) + ", ";
parameter_string += "CFG scale: " + std::to_string(params.cfg_scale) + ", "; parameter_string += "CFG scale: " + std::to_string(params.sample_params.guidance.txt_cfg) + ", ";
if (params.slg_scale != 0 && params.skip_layers.size() != 0) { if (params.sample_params.guidance.slg.scale != 0 && params.skip_layers.size() != 0) {
parameter_string += "SLG scale: " + std::to_string(params.cfg_scale) + ", "; parameter_string += "SLG scale: " + std::to_string(params.sample_params.guidance.txt_cfg) + ", ";
parameter_string += "Skip layers: ["; parameter_string += "Skip layers: [";
for (const auto& layer : params.skip_layers) { for (const auto& layer : params.skip_layers) {
parameter_string += std::to_string(layer) + ", "; parameter_string += std::to_string(layer) + ", ";
} }
parameter_string += "], "; parameter_string += "], ";
parameter_string += "Skip layer start: " + std::to_string(params.skip_layer_start) + ", "; parameter_string += "Skip layer start: " + std::to_string(params.sample_params.guidance.slg.layer_start) + ", ";
parameter_string += "Skip layer end: " + std::to_string(params.skip_layer_end) + ", "; parameter_string += "Skip layer end: " + std::to_string(params.sample_params.guidance.slg.layer_end) + ", ";
} }
parameter_string += "Guidance: " + std::to_string(params.guidance) + ", "; parameter_string += "Guidance: " + std::to_string(params.sample_params.guidance.distilled_guidance) + ", ";
parameter_string += "Eta: " + std::to_string(params.eta) + ", "; parameter_string += "Eta: " + std::to_string(params.sample_params.eta) + ", ";
parameter_string += "Seed: " + std::to_string(seed) + ", "; parameter_string += "Seed: " + std::to_string(seed) + ", ";
parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", "; parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", ";
parameter_string += "Model: " + sd_basename(params.model_path) + ", "; parameter_string += "Model: " + sd_basename(params.model_path) + ", ";
parameter_string += "RNG: " + std::string(sd_rng_type_name(params.rng_type)) + ", "; parameter_string += "RNG: " + std::string(sd_rng_type_name(params.rng_type)) + ", ";
parameter_string += "Sampler: " + std::string(sd_sample_method_name(params.sample_method)); parameter_string += "Sampler: " + std::string(sd_sample_method_name(params.sample_params.sample_method));
if (params.scheduler != DEFAULT) { if (params.sample_params.scheduler != DEFAULT) {
parameter_string += " " + std::string(sd_schedule_name(params.scheduler)); parameter_string += " " + std::string(sd_schedule_name(params.sample_params.scheduler));
} }
parameter_string += ", "; parameter_string += ", ";
for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path}) { for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path}) {
@ -806,23 +894,10 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
int main(int argc, const char* argv[]) { int main(int argc, const char* argv[]) {
SDParams params; SDParams params;
parse_args(argc, argv, params); parse_args(argc, argv, params);
sd_guidance_params_t guidance_params = {params.cfg_scale, params.sample_params.guidance.slg.layers = params.skip_layers.data();
params.img_cfg_scale, params.sample_params.guidance.slg.layer_count = params.skip_layers.size();
params.guidance, params.high_noise_sample_params.guidance.slg.layers = params.high_noise_skip_layers.data();
{ params.high_noise_sample_params.guidance.slg.layer_count = params.high_noise_skip_layers.size();
params.skip_layers.data(),
params.skip_layers.size(),
params.skip_layer_start,
params.skip_layer_end,
params.slg_scale,
}};
sd_sample_params_t sample_params = {
guidance_params,
params.scheduler,
params.sample_method,
params.sample_steps,
params.eta,
};
sd_set_log_callback(sd_log_cb, (void*)&params); sd_set_log_callback(sd_log_cb, (void*)&params);
@ -983,6 +1058,7 @@ int main(int argc, const char* argv[]) {
params.clip_vision_path.c_str(), params.clip_vision_path.c_str(),
params.t5xxl_path.c_str(), params.t5xxl_path.c_str(),
params.diffusion_model_path.c_str(), params.diffusion_model_path.c_str(),
params.high_noise_diffusion_model_path.c_str(),
params.vae_path.c_str(), params.vae_path.c_str(),
params.taesd_path.c_str(), params.taesd_path.c_str(),
params.control_net_path.c_str(), params.control_net_path.c_str(),
@ -1066,7 +1142,7 @@ int main(int argc, const char* argv[]) {
mask_image, mask_image,
params.width, params.width,
params.height, params.height,
sample_params, params.sample_params,
params.strength, params.strength,
params.seed, params.seed,
params.batch_count, params.batch_count,
@ -1087,7 +1163,8 @@ int main(int argc, const char* argv[]) {
input_image, input_image,
params.width, params.width,
params.height, params.height,
sample_params, params.sample_params,
params.high_noise_sample_params,
params.strength, params.strength,
params.seed, params.seed,
params.video_frames, params.video_frames,

View File

@ -96,6 +96,7 @@ public:
std::shared_ptr<Conditioner> cond_stage_model; std::shared_ptr<Conditioner> cond_stage_model;
std::shared_ptr<FrozenCLIPVisionEmbedder> clip_vision; // for svd or wan2.1 i2v std::shared_ptr<FrozenCLIPVisionEmbedder> clip_vision; // for svd or wan2.1 i2v
std::shared_ptr<DiffusionModel> diffusion_model; std::shared_ptr<DiffusionModel> diffusion_model;
std::shared_ptr<DiffusionModel> high_noise_diffusion_model;
std::shared_ptr<VAE> first_stage_model; std::shared_ptr<VAE> first_stage_model;
std::shared_ptr<TinyAutoEncoder> tae_first_stage; std::shared_ptr<TinyAutoEncoder> tae_first_stage;
std::shared_ptr<ControlNet> control_net; std::shared_ptr<ControlNet> control_net;
@ -207,6 +208,13 @@ public:
} }
} }
if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) {
LOG_INFO("loading high noise diffusion model from '%s'", sd_ctx_params->high_noise_diffusion_model_path);
if (!model_loader.init_from_file(sd_ctx_params->high_noise_diffusion_model_path, "model.high_noise_diffusion_model.")) {
LOG_WARN("loading diffusion model from '%s' failed", sd_ctx_params->high_noise_diffusion_model_path);
}
}
bool is_unet = model_loader.model_is_unet(); bool is_unet = model_loader.model_is_unet();
if (strlen(SAFE_STR(sd_ctx_params->clip_l_path)) > 0) { if (strlen(SAFE_STR(sd_ctx_params->clip_l_path)) > 0) {
@ -380,8 +388,17 @@ public:
diffusion_model = std::make_shared<WanModel>(backend, diffusion_model = std::make_shared<WanModel>(backend,
offload_params_to_cpu, offload_params_to_cpu,
model_loader.tensor_storages_types, model_loader.tensor_storages_types,
"model.diffusion_model",
version, version,
sd_ctx_params->diffusion_flash_attn); sd_ctx_params->diffusion_flash_attn);
if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) {
high_noise_diffusion_model = std::make_shared<WanModel>(backend,
offload_params_to_cpu,
model_loader.tensor_storages_types,
"model.high_noise_diffusion_model",
version,
sd_ctx_params->diffusion_flash_attn);
}
if (diffusion_model->get_desc() == "Wan2.1-I2V-14B") { if (diffusion_model->get_desc() == "Wan2.1-I2V-14B") {
clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend, clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend,
offload_params_to_cpu, offload_params_to_cpu,
@ -417,6 +434,11 @@ public:
diffusion_model->alloc_params_buffer(); diffusion_model->alloc_params_buffer();
diffusion_model->get_param_tensors(tensors); diffusion_model->get_param_tensors(tensors);
if (high_noise_diffusion_model) {
high_noise_diffusion_model->alloc_params_buffer();
high_noise_diffusion_model->get_param_tensors(tensors);
}
if (sd_ctx_params->keep_vae_on_cpu && !ggml_backend_is_cpu(backend)) { if (sd_ctx_params->keep_vae_on_cpu && !ggml_backend_is_cpu(backend)) {
LOG_INFO("VAE Autoencoder: Using CPU backend"); LOG_INFO("VAE Autoencoder: Using CPU backend");
vae_backend = ggml_backend_cpu_init(); vae_backend = ggml_backend_cpu_init();
@ -546,6 +568,9 @@ public:
{ {
size_t clip_params_mem_size = cond_stage_model->get_params_buffer_size(); size_t clip_params_mem_size = cond_stage_model->get_params_buffer_size();
size_t unet_params_mem_size = diffusion_model->get_params_buffer_size(); size_t unet_params_mem_size = diffusion_model->get_params_buffer_size();
if (high_noise_diffusion_model) {
unet_params_mem_size += high_noise_diffusion_model->get_params_buffer_size();
}
size_t vae_params_mem_size = 0; size_t vae_params_mem_size = 0;
if (!use_tiny_autoencoder) { if (!use_tiny_autoencoder) {
vae_params_mem_size = first_stage_model->get_params_buffer_size(); vae_params_mem_size = first_stage_model->get_params_buffer_size();
@ -923,6 +948,8 @@ public:
} }
ggml_tensor* sample(ggml_context* work_ctx, ggml_tensor* sample(ggml_context* work_ctx,
std::shared_ptr<DiffusionModel> work_diffusion_model,
bool inverse_noise_scaling,
ggml_tensor* init_latent, ggml_tensor* init_latent,
ggml_tensor* noise, ggml_tensor* noise,
SDCondition cond, SDCondition cond,
@ -952,7 +979,10 @@ public:
size_t steps = sigmas.size() - 1; size_t steps = sigmas.size() - 1;
struct ggml_tensor* x = ggml_dup_tensor(work_ctx, init_latent); struct ggml_tensor* x = ggml_dup_tensor(work_ctx, init_latent);
copy_ggml_tensor(x, init_latent); copy_ggml_tensor(x, init_latent);
if (noise) {
x = denoiser->noise_scaling(sigmas[0], noise, x); x = denoiser->noise_scaling(sigmas[0], noise, x);
}
struct ggml_tensor* noised_input = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* noised_input = ggml_dup_tensor(work_ctx, x);
@ -1015,7 +1045,7 @@ public:
if (start_merge_step == -1 || step <= start_merge_step) { if (start_merge_step == -1 || step <= start_merge_step) {
// cond // cond
diffusion_model->compute(n_threads, work_diffusion_model->compute(n_threads,
noised_input, noised_input,
timesteps, timesteps,
cond.c_crossattn, cond.c_crossattn,
@ -1028,7 +1058,7 @@ public:
control_strength, control_strength,
&out_cond); &out_cond);
} else { } else {
diffusion_model->compute(n_threads, work_diffusion_model->compute(n_threads,
noised_input, noised_input,
timesteps, timesteps,
id_cond.c_crossattn, id_cond.c_crossattn,
@ -1049,7 +1079,7 @@ public:
control_net->compute(n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn, uncond.c_vector); control_net->compute(n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn, uncond.c_vector);
controls = control_net->controls; controls = control_net->controls;
} }
diffusion_model->compute(n_threads, work_diffusion_model->compute(n_threads,
noised_input, noised_input,
timesteps, timesteps,
uncond.c_crossattn, uncond.c_crossattn,
@ -1066,7 +1096,7 @@ public:
float* img_cond_data = NULL; float* img_cond_data = NULL;
if (has_img_cond) { if (has_img_cond) {
diffusion_model->compute(n_threads, work_diffusion_model->compute(n_threads,
noised_input, noised_input,
timesteps, timesteps,
img_cond.c_crossattn, img_cond.c_crossattn,
@ -1087,7 +1117,7 @@ public:
if (is_skiplayer_step) { if (is_skiplayer_step) {
LOG_DEBUG("Skipping layers at step %d\n", step); LOG_DEBUG("Skipping layers at step %d\n", step);
// skip layer (same as conditionned) // skip layer (same as conditionned)
diffusion_model->compute(n_threads, work_diffusion_model->compute(n_threads,
noised_input, noised_input,
timesteps, timesteps,
cond.c_crossattn, cond.c_crossattn,
@ -1152,13 +1182,15 @@ public:
sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng, eta); sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng, eta);
if (inverse_noise_scaling) {
x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x); x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x);
}
if (control_net) { if (control_net) {
control_net->free_control_ctx(); control_net->free_control_ctx();
control_net->free_compute_buffer(); control_net->free_compute_buffer();
} }
diffusion_model->free_compute_buffer(); work_diffusion_model->free_compute_buffer();
return x; return x;
} }
@ -1446,6 +1478,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
"clip_vision_path: %s\n" "clip_vision_path: %s\n"
"t5xxl_path: %s\n" "t5xxl_path: %s\n"
"diffusion_model_path: %s\n" "diffusion_model_path: %s\n"
"high_noise_diffusion_model_path: %s\n"
"vae_path: %s\n" "vae_path: %s\n"
"taesd_path: %s\n" "taesd_path: %s\n"
"control_net_path: %s\n" "control_net_path: %s\n"
@ -1472,6 +1505,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
SAFE_STR(sd_ctx_params->clip_vision_path), SAFE_STR(sd_ctx_params->clip_vision_path),
SAFE_STR(sd_ctx_params->t5xxl_path), SAFE_STR(sd_ctx_params->t5xxl_path),
SAFE_STR(sd_ctx_params->diffusion_model_path), SAFE_STR(sd_ctx_params->diffusion_model_path),
SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path),
SAFE_STR(sd_ctx_params->vae_path), SAFE_STR(sd_ctx_params->vae_path),
SAFE_STR(sd_ctx_params->taesd_path), SAFE_STR(sd_ctx_params->taesd_path),
SAFE_STR(sd_ctx_params->control_net_path), SAFE_STR(sd_ctx_params->control_net_path),
@ -1602,6 +1636,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) { void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
memset((void*)sd_vid_gen_params, 0, sizeof(sd_vid_gen_params_t)); memset((void*)sd_vid_gen_params, 0, sizeof(sd_vid_gen_params_t));
sd_sample_params_init(&sd_vid_gen_params->sample_params); sd_sample_params_init(&sd_vid_gen_params->sample_params);
sd_sample_params_init(&sd_vid_gen_params->high_noise_sample_params);
sd_vid_gen_params->width = 512; sd_vid_gen_params->width = 512;
sd_vid_gen_params->height = 512; sd_vid_gen_params->height = 512;
sd_vid_gen_params->strength = 0.75f; sd_vid_gen_params->strength = 0.75f;
@ -1902,6 +1937,8 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
} }
struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx, struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx,
sd_ctx->sd->diffusion_model,
true,
x_t, x_t,
noise, noise,
cond, cond,
@ -2243,7 +2280,13 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
sd_ctx->sd->init_scheduler(sd_vid_gen_params->sample_params.scheduler); sd_ctx->sd->init_scheduler(sd_vid_gen_params->sample_params.scheduler);
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps); int high_noise_sample_steps = 0;
if (sd_ctx->sd->high_noise_diffusion_model) {
sd_ctx->sd->init_scheduler(sd_vid_gen_params->high_noise_sample_params.scheduler);
high_noise_sample_steps = sd_vid_gen_params->high_noise_sample_params.sample_steps;
}
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps + high_noise_sample_steps);
struct ggml_init_params params; struct ggml_init_params params;
params.mem_size = static_cast<size_t>(100 * 1024) * 1024; // 100 MB params.mem_size = static_cast<size_t>(100 * 1024) * 1024; // 100 MB
@ -2331,7 +2374,6 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
} }
ggml_tensor* init_latent = generate_init_latent(sd_ctx, work_ctx, width, height, frames, true); ggml_tensor* init_latent = generate_init_latent(sd_ctx, work_ctx, width, height, frames, true);
sample_steps = sigmas.size() - 1;
// Get learned condition // Get learned condition
bool zero_out_masked = true; bool zero_out_masked = true;
@ -2347,7 +2389,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
cond.c_concat = concat_latent; cond.c_concat = concat_latent;
cond.c_vector = clip_vision_output; cond.c_vector = clip_vision_output;
SDCondition uncond; SDCondition uncond;
if (sd_vid_gen_params->sample_params.guidance.txt_cfg != 1.0) { if (sd_vid_gen_params->sample_params.guidance.txt_cfg != 1.0 || sd_vid_gen_params->high_noise_sample_params.guidance.txt_cfg != 1.0) {
uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx, uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
sd_ctx->sd->n_threads, sd_ctx->sd->n_threads,
negative_prompt, negative_prompt,
@ -2372,15 +2414,50 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
int C = 16; int C = 16;
struct ggml_tensor* final_latent; struct ggml_tensor* final_latent;
struct ggml_tensor* x_t = init_latent;
struct ggml_tensor* noise = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, T, C);
ggml_tensor_set_f32_randn(noise, sd_ctx->sd->rng);
// High Noise Sample
if (high_noise_sample_steps > 0) {
LOG_DEBUG("sample(high noise) %dx%dx%d", W, H, T);
int64_t sampling_start = ggml_time_ms();
std::vector<float> high_noise_sigmas = std::vector<float>(sigmas.begin(), sigmas.begin() + high_noise_sample_steps + 1);
sigmas = std::vector<float>(sigmas.begin() + high_noise_sample_steps, sigmas.end());
x_t = sd_ctx->sd->sample(work_ctx,
sd_ctx->sd->high_noise_diffusion_model,
false,
x_t,
noise,
cond,
uncond,
{},
NULL,
0,
sd_vid_gen_params->high_noise_sample_params.guidance,
sd_vid_gen_params->high_noise_sample_params.eta,
sd_vid_gen_params->high_noise_sample_params.sample_method,
high_noise_sigmas,
-1,
{});
int64_t sampling_end = ggml_time_ms();
LOG_INFO("sampling(high noise) completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->high_noise_diffusion_model->free_params_buffer();
}
noise = NULL;
}
// Sample // Sample
{ {
LOG_DEBUG("sample %dx%dx%d", W, H, T); LOG_DEBUG("sample %dx%dx%d", W, H, T);
int64_t sampling_start = ggml_time_ms(); int64_t sampling_start = ggml_time_ms();
struct ggml_tensor* x_t = init_latent;
struct ggml_tensor* noise = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, T, C);
ggml_tensor_set_f32_randn(noise, sd_ctx->sd->rng);
final_latent = sd_ctx->sd->sample(work_ctx, final_latent = sd_ctx->sd->sample(work_ctx,
sd_ctx->sd->diffusion_model,
true,
x_t, x_t,
noise, noise,
cond, cond,
@ -2397,11 +2474,10 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
int64_t sampling_end = ggml_time_ms(); int64_t sampling_end = ggml_time_ms();
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000); LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
}
if (sd_ctx->sd->free_params_immediately) { if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->diffusion_model->free_params_buffer(); sd_ctx->sd->diffusion_model->free_params_buffer();
} }
}
int64_t t4 = ggml_time_ms(); int64_t t4 = ggml_time_ms();
LOG_INFO("generating latent video completed, taking %.2fs", (t4 - t2) * 1.0f / 1000); LOG_INFO("generating latent video completed, taking %.2fs", (t4 - t2) * 1.0f / 1000);

View File

@ -118,6 +118,7 @@ typedef struct {
const char* clip_vision_path; const char* clip_vision_path;
const char* t5xxl_path; const char* t5xxl_path;
const char* diffusion_model_path; const char* diffusion_model_path;
const char* high_noise_diffusion_model_path;
const char* vae_path; const char* vae_path;
const char* taesd_path; const char* taesd_path;
const char* control_net_path; const char* control_net_path;
@ -199,6 +200,7 @@ typedef struct {
int width; int width;
int height; int height;
sd_sample_params_t sample_params; sd_sample_params_t sample_params;
sd_sample_params_t high_noise_sample_params;
float strength; float strength;
int64_t seed; int64_t seed;
int video_frames; int video_frames;
@ -229,6 +231,9 @@ SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params);
SD_API sd_ctx_t* new_sd_ctx(const sd_ctx_params_t* sd_ctx_params); SD_API sd_ctx_t* new_sd_ctx(const sd_ctx_params_t* sd_ctx_params);
SD_API void free_sd_ctx(sd_ctx_t* sd_ctx); SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);
SD_API void sd_sample_params_init(sd_sample_params_t* sample_params);
SD_API char* sd_sample_params_to_str(const sd_sample_params_t* sample_params);
SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params); SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params);
SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params); SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params);
SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params); SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params);

View File

@ -1579,7 +1579,7 @@ namespace WAN {
wan_params.num_layers = 0; wan_params.num_layers = 0;
for (auto pair : tensor_types) { for (auto pair : tensor_types) {
std::string tensor_name = pair.first; std::string tensor_name = pair.first;
if (tensor_name.find("model.diffusion_model.") == std::string::npos) if (tensor_name.find(prefix) == std::string::npos)
continue; continue;
size_t pos = tensor_name.find("blocks."); size_t pos = tensor_name.find("blocks.");
if (pos != std::string::npos) { if (pos != std::string::npos) {