introduce sd_sample_params_t

leejet 2025-08-24 17:20:41 +08:00
parent cf48441345
commit afef8cef9e
8 changed files with 187 additions and 163 deletions
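In short: the per-call sampling options that generate_image() and generate_video() previously took as separate arguments (guidance, sample_method, sample_steps, eta), together with the schedule that used to live in sd_ctx_params_t, are grouped into a single sd_sample_params_t, and enum schedule_t is renamed to scheduler_t. The scheduler is now applied per generation via the new init_scheduler() instead of at context creation. The struct as added to the public header, with comments added here for orientation:

typedef struct {
    sd_guidance_params_t guidance;       // txt/img cfg, distilled guidance and SLG settings
    enum scheduler_t scheduler;          // sigma scheduler, selected per generation
    enum sample_method_t sample_method;  // e.g. EULER_A
    int sample_steps;
    float eta;
} sd_sample_params_t;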

View File

@ -332,7 +332,7 @@ arguments:
--rng {std_default, cuda} RNG (default: cuda)
-s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)
-b, --batch-count COUNT number of images to generate
--schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete)
--scheduler {discrete, karras, exponential, ays, gits} Denoiser sigma scheduler (default: discrete)
--clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)
<= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x
--vae-tiling process vae in tiles to reduce memory usage

View File

@ -777,7 +777,7 @@ public:
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* pixel_values,
bool return_pooled = true,
int clip_skip = -1) {
int clip_skip = -1) {
// pixel_values: [N, num_channels, image_size, image_size]
auto embeddings = std::dynamic_pointer_cast<CLIPVisionEmbeddings>(blocks["embeddings"]);
auto pre_layernorm = std::dynamic_pointer_cast<LayerNorm>(blocks["pre_layernorm"]);
@ -786,7 +786,6 @@ public:
auto x = embeddings->forward(ctx, pixel_values); // [N, num_positions, embed_dim]
x = pre_layernorm->forward(ctx, x);
LOG_DEBUG("clip_vison skip %d", clip_skip);
x = encoder->forward(ctx, x, clip_skip, false);
// print_ggml_tensor(x, true, "ClipVisionModel x: ");
auto last_hidden_state = x;
@ -858,7 +857,7 @@ public:
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* pixel_values,
bool return_pooled = true,
int clip_skip = -1) {
int clip_skip = -1) {
// pixel_values: [N, num_channels, image_size, image_size]
// return: [N, projection_dim] if return_pooled else [N, n_token, hidden_size]
auto vision_model = std::dynamic_pointer_cast<CLIPVisionModel>(blocks["vision_model"]);

View File

@ -252,7 +252,7 @@ struct KarrasSchedule : SigmaSchedule {
};
struct Denoiser {
std::shared_ptr<SigmaSchedule> schedule = std::make_shared<DiscreteSchedule>();
std::shared_ptr<SigmaSchedule> scheduler = std::make_shared<DiscreteSchedule>();
virtual float sigma_min() = 0;
virtual float sigma_max() = 0;
virtual float sigma_to_t(float sigma) = 0;
@ -263,7 +263,7 @@ struct Denoiser {
virtual std::vector<float> get_sigmas(uint32_t n) {
auto bound_t_to_sigma = std::bind(&Denoiser::t_to_sigma, this, std::placeholders::_1);
return schedule->get_sigmas(n, sigma_min(), sigma_max(), bound_t_to_sigma);
return scheduler->get_sigmas(n, sigma_min(), sigma_max(), bound_t_to_sigma);
}
};
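After this rename, callers choose a sigma schedule by assigning to Denoiser::scheduler before asking for sigmas. A minimal sketch, assuming denoiser.hpp from this repository; the helper name is hypothetical:

#include <cstdint>
#include <memory>
#include <vector>
#include "denoiser.hpp"

// Hypothetical helper: select the Karras scheduler at runtime and fetch n sigmas.
static std::vector<float> sigmas_with_karras(std::shared_ptr<Denoiser> denoiser, uint32_t n) {
    denoiser->scheduler = std::make_shared<KarrasSchedule>();  // member renamed from `schedule`
    return denoiser->get_sigmas(n);                            // delegates to scheduler->get_sigmas(...)
}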
@ -349,7 +349,7 @@ struct EDMVDenoiser : public CompVisVDenoiser {
EDMVDenoiser(float min_sigma = 0.002, float max_sigma = 120.0)
: min_sigma(min_sigma), max_sigma(max_sigma) {
schedule = std::make_shared<ExponentialSchedule>();
scheduler = std::make_shared<ExponentialSchedule>();
}
float t_to_sigma(float t) {

View File

@ -88,7 +88,7 @@ struct SDParams {
int fps = 16;
sample_method_t sample_method = EULER_A;
schedule_t schedule = DEFAULT;
scheduler_t scheduler = DEFAULT;
int sample_steps = 20;
float strength = 0.75f;
float control_strength = 0.9f;
@ -161,7 +161,7 @@ void print_params(SDParams params) {
printf(" width: %d\n", params.width);
printf(" height: %d\n", params.height);
printf(" sample_method: %s\n", sd_sample_method_name(params.sample_method));
printf(" schedule: %s\n", sd_schedule_name(params.schedule));
printf(" scheduler: %s\n", sd_schedule_name(params.scheduler));
printf(" sample_steps: %d\n", params.sample_steps);
printf(" strength(img2img): %.2f\n", params.strength);
printf(" rng: %s\n", sd_rng_type_name(params.rng_type));
@ -232,7 +232,7 @@ void print_usage(int argc, const char* argv[]) {
printf(" --rng {std_default, cuda} RNG (default: cuda)\n");
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
printf(" -b, --batch-count COUNT number of images to generate\n");
printf(" --schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete)\n");
printf(" --scheduler {discrete, karras, exponential, ays, gits} Denoiser sigma scheduler (default: discrete)\n");
printf(" --clip-skip N ignore last_dot_pos layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n");
printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n");
printf(" --vae-tiling process vae in tiles to reduce memory usage\n");
@ -535,10 +535,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
if (++index >= argc) {
return -1;
}
const char* arg = argv[index];
params.schedule = str_to_schedule(arg);
if (params.schedule == SCHEDULE_COUNT) {
fprintf(stderr, "error: invalid schedule %s\n",
const char* arg = argv[index];
params.scheduler = str_to_schedule(arg);
if (params.scheduler == SCHEDULE_COUNT) {
fprintf(stderr, "error: invalid scheduler %s\n",
arg);
return -1;
}
@ -614,7 +614,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
{"", "--rng", "", on_rng_arg},
{"-s", "--seed", "", on_seed_arg},
{"", "--sampling-method", "", on_sample_method_arg},
{"", "--schedule", "", on_schedule_arg},
{"", "--scheduler", "", on_schedule_arg},
{"", "--skip-layers", "", on_skip_layers_arg},
{"-r", "--ref-image", "", on_ref_image_arg},
{"-h", "--help", "", on_help_arg},
@ -738,8 +738,8 @@ std::string get_image_params(SDParams params, int64_t seed) {
parameter_string += "Model: " + sd_basename(params.model_path) + ", ";
parameter_string += "RNG: " + std::string(sd_rng_type_name(params.rng_type)) + ", ";
parameter_string += "Sampler: " + std::string(sd_sample_method_name(params.sample_method));
if (params.schedule != DEFAULT) {
parameter_string += " " + std::string(sd_schedule_name(params.schedule));
if (params.scheduler != DEFAULT) {
parameter_string += " " + std::string(sd_schedule_name(params.scheduler));
}
parameter_string += ", ";
for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path}) {
@ -816,6 +816,13 @@ int main(int argc, const char* argv[]) {
params.skip_layer_end,
params.slg_scale,
}};
sd_sample_params_t sample_params = {
guidance_params,
params.scheduler,
params.sample_method,
params.sample_steps,
params.eta,
};
sd_set_log_callback(sd_log_cb, (void*)&params);
@ -988,7 +995,6 @@ int main(int argc, const char* argv[]) {
params.n_threads,
params.wtype,
params.rng_type,
params.schedule,
params.offload_params_to_cpu,
params.clip_on_cpu,
params.control_net_cpu,
@ -1054,16 +1060,13 @@ int main(int argc, const char* argv[]) {
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
guidance_params,
input_image,
ref_images.data(),
(int)ref_images.size(),
mask_image,
params.width,
params.height,
params.sample_method,
params.sample_steps,
params.eta,
sample_params,
params.strength,
params.seed,
params.batch_count,
@ -1081,13 +1084,10 @@ int main(int argc, const char* argv[]) {
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
guidance_params,
input_image,
params.width,
params.height,
params.sample_method,
params.sample_steps,
params.eta,
sample_params,
params.strength,
params.seed,
params.video_frames,

View File

@ -235,6 +235,8 @@ __STATIC_INLINE__ ggml_tensor* load_tensor_from_file(ggml_context* ctx, const st
file.read(reinterpret_cast<char*>(&length), sizeof(length));
file.read(reinterpret_cast<char*>(&ttype), sizeof(ttype));
LOG_DEBUG("load_tensor_from_file %d %d %d", n_dims, length, ttype);
if (file.eof()) {
LOG_ERROR("incomplete file '%s'", file_path.c_str());
return NULL;

View File

@ -661,39 +661,6 @@ public:
LOG_INFO("running in eps-prediction mode");
}
if (sd_ctx_params->schedule != DEFAULT) {
switch (sd_ctx_params->schedule) {
case DISCRETE:
LOG_INFO("running with discrete schedule");
denoiser->schedule = std::make_shared<DiscreteSchedule>();
break;
case KARRAS:
LOG_INFO("running with Karras schedule");
denoiser->schedule = std::make_shared<KarrasSchedule>();
break;
case EXPONENTIAL:
LOG_INFO("running exponential schedule");
denoiser->schedule = std::make_shared<ExponentialSchedule>();
break;
case AYS:
LOG_INFO("Running with Align-Your-Steps schedule");
denoiser->schedule = std::make_shared<AYSSchedule>();
denoiser->schedule->version = version;
break;
case GITS:
LOG_INFO("Running with GITS schedule");
denoiser->schedule = std::make_shared<GITSSchedule>();
denoiser->schedule->version = version;
break;
case DEFAULT:
// Don't touch anything.
break;
default:
LOG_ERROR("Unknown schedule %i", sd_ctx_params->schedule);
abort();
}
}
auto comp_vis_denoiser = std::dynamic_pointer_cast<CompVisDenoiser>(denoiser);
if (comp_vis_denoiser) {
for (int i = 0; i < TIMESTEPS; i++) {
@ -707,6 +674,39 @@ public:
return true;
}
void init_scheduler(scheduler_t scheduler) {
switch (scheduler) {
case DISCRETE:
LOG_INFO("running with discrete scheduler");
denoiser->scheduler = std::make_shared<DiscreteSchedule>();
break;
case KARRAS:
LOG_INFO("running with Karras scheduler");
denoiser->scheduler = std::make_shared<KarrasSchedule>();
break;
case EXPONENTIAL:
LOG_INFO("running exponential scheduler");
denoiser->scheduler = std::make_shared<ExponentialSchedule>();
break;
case AYS:
LOG_INFO("Running with Align-Your-Steps scheduler");
denoiser->scheduler = std::make_shared<AYSSchedule>();
denoiser->scheduler->version = version;
break;
case GITS:
LOG_INFO("Running with GITS scheduler");
denoiser->scheduler = std::make_shared<GITSSchedule>();
denoiser->scheduler->version = version;
break;
case DEFAULT:
// Don't touch anything.
break;
default:
LOG_ERROR("Unknown scheduler %i", scheduler);
abort();
}
}
bool is_using_v_parameterization_for_sd2(ggml_context* work_ctx, bool is_inpaint = false) {
struct ggml_tensor* x_t = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 4, 1);
ggml_set_f32(x_t, 0.5);
@ -830,7 +830,7 @@ public:
ggml_tensor* get_clip_vision_output(ggml_context* work_ctx,
sd_image_t init_image,
bool return_pooled = true,
int clip_skip = -1,
int clip_skip = -1,
bool zero_out_masked = false) {
ggml_tensor* output = NULL;
if (zero_out_masked) {
@ -954,7 +954,7 @@ public:
copy_ggml_tensor(x, init_latent);
x = denoiser->noise_scaling(sigmas[0], noise, x);
struct ggml_tensor* noised_input = ggml_dup_tensor(work_ctx, noise);
struct ggml_tensor* noised_input = ggml_dup_tensor(work_ctx, x);
bool has_unconditioned = img_cfg_scale != 1.0 && uncond.c_crossattn != NULL;
bool has_img_cond = cfg_scale != img_cfg_scale && img_cond.c_crossattn != NULL;
@ -1399,17 +1399,17 @@ const char* schedule_to_str[] = {
"gits",
};
const char* sd_schedule_name(enum schedule_t schedule) {
if (schedule < SCHEDULE_COUNT) {
return schedule_to_str[schedule];
const char* sd_schedule_name(enum scheduler_t scheduler) {
if (scheduler < SCHEDULE_COUNT) {
return schedule_to_str[scheduler];
}
return NONE_STR;
}
enum schedule_t str_to_schedule(const char* str) {
enum scheduler_t str_to_schedule(const char* str) {
for (int i = 0; i < SCHEDULE_COUNT; i++) {
if (!strcmp(str, schedule_to_str[i])) {
return (enum schedule_t)i;
return (enum scheduler_t)i;
}
}
return SCHEDULE_COUNT;
@ -1423,7 +1423,6 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
sd_ctx_params->n_threads = get_num_physical_cores();
sd_ctx_params->wtype = SD_TYPE_COUNT;
sd_ctx_params->rng_type = CUDA_RNG;
sd_ctx_params->schedule = DEFAULT;
sd_ctx_params->offload_params_to_cpu = false;
sd_ctx_params->keep_clip_on_cpu = false;
sd_ctx_params->keep_control_net_on_cpu = false;
@ -1459,7 +1458,6 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
"n_threads: %d\n"
"wtype: %s\n"
"rng_type: %s\n"
"schedule: %s\n"
"offload_params_to_cpu: %s\n"
"keep_clip_on_cpu: %s\n"
"keep_control_net_on_cpu: %s\n"
@ -1486,7 +1484,6 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
sd_ctx_params->n_threads,
sd_type_name(sd_ctx_params->wtype),
sd_rng_type_name(sd_ctx_params->rng_type),
sd_schedule_name(sd_ctx_params->schedule),
BOOL_STR(sd_ctx_params->offload_params_to_cpu),
BOOL_STR(sd_ctx_params->keep_clip_on_cpu),
BOOL_STR(sd_ctx_params->keep_control_net_on_cpu),
@ -1499,28 +1496,65 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
return buf;
}
void sd_sample_params_init(sd_sample_params_t* sample_params) {
sample_params->guidance.txt_cfg = 7.0f;
sample_params->guidance.img_cfg = INFINITY;
sample_params->guidance.distilled_guidance = 3.5f;
sample_params->guidance.slg.layer_count = 0;
sample_params->guidance.slg.layer_start = 0.01f;
sample_params->guidance.slg.layer_end = 0.2f;
sample_params->guidance.slg.scale = 0.f;
sample_params->scheduler = DEFAULT;
sample_params->sample_method = EULER_A;
sample_params->sample_steps = 20;
}
char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
char* buf = (char*)malloc(4096);
if (!buf)
return NULL;
buf[0] = '\0';
snprintf(buf + strlen(buf), 4096 - strlen(buf),
"(txt_cfg: %.2f, "
"img_cfg: %.2f, "
"distilled_guidance: %.2f, "
"slg.layer_count: %zu, "
"slg.layer_start: %.2f, "
"slg.layer_end: %.2f, "
"slg.scale: %.2f, "
"scheduler: %s, "
"sample_method: %s, "
"sample_steps: %d, "
"eta: %.2f)",
sample_params->guidance.txt_cfg,
sample_params->guidance.img_cfg,
sample_params->guidance.distilled_guidance,
sample_params->guidance.slg.layer_count,
sample_params->guidance.slg.layer_start,
sample_params->guidance.slg.layer_end,
sample_params->guidance.slg.scale,
sd_schedule_name(sample_params->scheduler),
sd_sample_method_name(sample_params->sample_method),
sample_params->sample_steps,
sample_params->eta);
return buf;
}
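A minimal usage sketch of the two helpers above, assuming the public header stable-diffusion.h; note that sd_sample_params_init() as shown does not assign eta, so the sketch sets it explicitly:

#include <stdio.h>
#include <stdlib.h>
#include "stable-diffusion.h"   // assumed public header of this repo

int main(void) {
    sd_sample_params_t sp;
    sd_sample_params_init(&sp);  // defaults per the function above: EULER_A, 20 steps, DEFAULT scheduler
    sp.scheduler    = KARRAS;    // pick a non-default sigma scheduler
    sp.sample_steps = 30;
    sp.eta          = 0.f;       // not set by sd_sample_params_init, so initialize it here
    char* s = sd_sample_params_to_str(&sp);
    if (s) {
        printf("%s\n", s);       // e.g. "(txt_cfg: 7.00, ..., scheduler: karras, sample_steps: 30, ...)"
        free(s);                 // the description string is malloc'd by sd_sample_params_to_str
    }
    return 0;
}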
void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) {
memset((void*)sd_img_gen_params, 0, sizeof(sd_img_gen_params_t));
sd_img_gen_params->clip_skip = -1;
sd_img_gen_params->guidance.txt_cfg = 7.0f;
sd_img_gen_params->guidance.img_cfg = INFINITY;
sd_img_gen_params->guidance.distilled_guidance = 3.5f;
sd_img_gen_params->guidance.slg.layer_count = 0;
sd_img_gen_params->guidance.slg.layer_start = 0.01f;
sd_img_gen_params->guidance.slg.layer_end = 0.2f;
sd_img_gen_params->guidance.slg.scale = 0.f;
sd_img_gen_params->ref_images_count = 0;
sd_img_gen_params->width = 512;
sd_img_gen_params->height = 512;
sd_img_gen_params->sample_method = EULER_A;
sd_img_gen_params->sample_steps = 20;
sd_img_gen_params->eta = 0.f;
sd_img_gen_params->strength = 0.75f;
sd_img_gen_params->seed = -1;
sd_img_gen_params->batch_count = 1;
sd_img_gen_params->control_strength = 0.9f;
sd_img_gen_params->style_strength = 20.f;
sd_img_gen_params->normalize_input = false;
sd_img_gen_params->clip_skip = -1;
sd_sample_params_init(&sd_img_gen_params->sample_params);
sd_img_gen_params->ref_images_count = 0;
sd_img_gen_params->width = 512;
sd_img_gen_params->height = 512;
sd_img_gen_params->strength = 0.75f;
sd_img_gen_params->seed = -1;
sd_img_gen_params->batch_count = 1;
sd_img_gen_params->control_strength = 0.9f;
sd_img_gen_params->style_strength = 20.f;
sd_img_gen_params->normalize_input = false;
}
char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
@ -1529,22 +1563,15 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
return NULL;
buf[0] = '\0';
char* sample_params_str = sd_sample_params_to_str(&sd_img_gen_params->sample_params);
snprintf(buf + strlen(buf), 4096 - strlen(buf),
"prompt: %s\n"
"negative_prompt: %s\n"
"clip_skip: %d\n"
"txt_cfg: %.2f\n"
"img_cfg: %.2f\n"
"distilled_guidance: %.2f\n"
"slg.layer_count: %zu\n"
"slg.layer_start: %.2f\n"
"slg.layer_end: %.2f\n"
"slg.scale: %.2f\n"
"width: %d\n"
"height: %d\n"
"sample_method: %s\n"
"sample_steps: %d\n"
"eta: %.2f\n"
"sample_params: %.2f\n"
"strength: %.2f\n"
"seed: %" PRId64
"\n"
@ -1557,18 +1584,9 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
SAFE_STR(sd_img_gen_params->prompt),
SAFE_STR(sd_img_gen_params->negative_prompt),
sd_img_gen_params->clip_skip,
sd_img_gen_params->guidance.txt_cfg,
sd_img_gen_params->guidance.img_cfg,
sd_img_gen_params->guidance.distilled_guidance,
sd_img_gen_params->guidance.slg.layer_count,
sd_img_gen_params->guidance.slg.layer_start,
sd_img_gen_params->guidance.slg.layer_end,
sd_img_gen_params->guidance.slg.scale,
sd_img_gen_params->width,
sd_img_gen_params->height,
sd_sample_method_name(sd_img_gen_params->sample_method),
sd_img_gen_params->sample_steps,
sd_img_gen_params->eta,
SAFE_STR(sample_params_str),
sd_img_gen_params->strength,
sd_img_gen_params->seed,
sd_img_gen_params->batch_count,
@ -1577,26 +1595,18 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
sd_img_gen_params->style_strength,
BOOL_STR(sd_img_gen_params->normalize_input),
SAFE_STR(sd_img_gen_params->input_id_images_path));
free(sample_params_str);
return buf;
}
void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
memset((void*)sd_vid_gen_params, 0, sizeof(sd_vid_gen_params_t));
sd_vid_gen_params->guidance.txt_cfg = 7.0f;
sd_vid_gen_params->guidance.img_cfg = INFINITY;
sd_vid_gen_params->guidance.distilled_guidance = 3.5f;
sd_vid_gen_params->guidance.slg.layer_count = 0;
sd_vid_gen_params->guidance.slg.layer_start = 0.01f;
sd_vid_gen_params->guidance.slg.layer_end = 0.2f;
sd_vid_gen_params->guidance.slg.scale = 0.f;
sd_vid_gen_params->width = 512;
sd_vid_gen_params->height = 512;
sd_vid_gen_params->sample_method = EULER_A;
sd_vid_gen_params->sample_steps = 20;
sd_vid_gen_params->strength = 0.75f;
sd_vid_gen_params->seed = -1;
sd_vid_gen_params->video_frames = 6;
sd_sample_params_init(&sd_vid_gen_params->sample_params);
sd_vid_gen_params->width = 512;
sd_vid_gen_params->height = 512;
sd_vid_gen_params->strength = 0.75f;
sd_vid_gen_params->seed = -1;
sd_vid_gen_params->video_frames = 6;
}
struct sd_ctx_t {
@ -2043,22 +2053,25 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
}
sd_ctx->sd->rng->manual_seed(seed);
int sample_steps = sd_img_gen_params->sample_params.sample_steps;
size_t t0 = ggml_time_ms();
sd_ctx->sd->init_scheduler(sd_img_gen_params->sample_params.scheduler);
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
ggml_tensor* init_latent = NULL;
ggml_tensor* concat_latent = NULL;
ggml_tensor* denoise_mask = NULL;
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sd_img_gen_params->sample_steps);
if (sd_img_gen_params->init_image.data) {
LOG_INFO("IMG2IMG");
size_t t_enc = static_cast<size_t>(sd_img_gen_params->sample_steps * sd_img_gen_params->strength);
if (t_enc == sd_img_gen_params->sample_steps)
size_t t_enc = static_cast<size_t>(sample_steps * sd_img_gen_params->strength);
if (t_enc == sample_steps)
t_enc--;
LOG_INFO("target t_enc is %zu steps", t_enc);
std::vector<float> sigma_sched;
sigma_sched.assign(sigmas.begin() + sd_img_gen_params->sample_steps - t_enc - 1, sigmas.end());
sigma_sched.assign(sigmas.begin() + sample_steps - t_enc - 1, sigmas.end());
sigmas = sigma_sched;
ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
@ -2189,11 +2202,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
SAFE_STR(sd_img_gen_params->prompt),
SAFE_STR(sd_img_gen_params->negative_prompt),
sd_img_gen_params->clip_skip,
sd_img_gen_params->guidance,
sd_img_gen_params->eta,
sd_img_gen_params->sample_params.guidance,
sd_img_gen_params->sample_params.eta,
width,
height,
sd_img_gen_params->sample_method,
sd_img_gen_params->sample_params.sample_method,
sigmas,
seed,
sd_img_gen_params->batch_count,
@ -2221,13 +2234,16 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
std::string prompt = SAFE_STR(sd_vid_gen_params->prompt);
std::string negative_prompt = SAFE_STR(sd_vid_gen_params->negative_prompt);
int width = sd_vid_gen_params->width;
int height = sd_vid_gen_params->height;
int frames = sd_vid_gen_params->video_frames;
frames = (frames - 1) / 4 * 4 + 1;
int width = sd_vid_gen_params->width;
int height = sd_vid_gen_params->height;
int frames = sd_vid_gen_params->video_frames;
frames = (frames - 1) / 4 * 4 + 1;
int sample_steps = sd_vid_gen_params->sample_params.sample_steps;
LOG_INFO("generate_video %dx%dx%d", width, height, frames);
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sd_vid_gen_params->sample_steps);
sd_ctx->sd->init_scheduler(sd_vid_gen_params->sample_params.scheduler);
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
struct ggml_init_params params;
params.mem_size = static_cast<size_t>(100 * 1024) * 1024; // 100 MB
@ -2315,7 +2331,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
}
ggml_tensor* init_latent = generate_init_latent(sd_ctx, work_ctx, width, height, frames, true);
int sample_steps = sigmas.size() - 1;
sample_steps = sigmas.size() - 1;
// Get learned condition
bool zero_out_masked = true;
@ -2331,7 +2347,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
cond.c_concat = concat_latent;
cond.c_vector = clip_vision_output;
SDCondition uncond;
if (sd_vid_gen_params->guidance.txt_cfg != 1.0) {
if (sd_vid_gen_params->sample_params.guidance.txt_cfg != 1.0) {
uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
sd_ctx->sd->n_threads,
negative_prompt,
@ -2372,9 +2388,9 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
{},
NULL,
0,
sd_vid_gen_params->guidance,
sd_vid_gen_params->eta,
sd_vid_gen_params->sample_method,
sd_vid_gen_params->sample_params.guidance,
sd_vid_gen_params->sample_params.eta,
sd_vid_gen_params->sample_params.sample_method,
sigmas,
-1,
{});

View File

@ -50,7 +50,7 @@ enum sample_method_t {
SAMPLE_METHOD_COUNT
};
enum schedule_t {
enum scheduler_t {
DEFAULT,
DISCRETE,
KARRAS,
@ -130,7 +130,6 @@ typedef struct {
int n_threads;
enum sd_type_t wtype;
enum rng_type_t rng_type;
enum schedule_t schedule;
bool offload_params_to_cpu;
bool keep_clip_on_cpu;
bool keep_control_net_on_cpu;
@ -163,20 +162,25 @@ typedef struct {
sd_slg_params_t slg;
} sd_guidance_params_t;
typedef struct {
sd_guidance_params_t guidance;
enum scheduler_t scheduler;
enum sample_method_t sample_method;
int sample_steps;
float eta;
} sd_sample_params_t;
typedef struct {
const char* prompt;
const char* negative_prompt;
int clip_skip;
sd_guidance_params_t guidance;
sd_image_t init_image;
sd_image_t* ref_images;
int ref_images_count;
sd_image_t mask_image;
int width;
int height;
enum sample_method_t sample_method;
int sample_steps;
float eta;
sd_sample_params_t sample_params;
float strength;
int64_t seed;
int batch_count;
@ -191,13 +195,10 @@ typedef struct {
const char* prompt;
const char* negative_prompt;
int clip_skip;
sd_guidance_params_t guidance;
sd_image_t init_image;
int width;
int height;
enum sample_method_t sample_method;
int sample_steps;
float eta;
sd_sample_params_t sample_params;
float strength;
int64_t seed;
int video_frames;
@ -219,8 +220,8 @@ SD_API const char* sd_rng_type_name(enum rng_type_t rng_type);
SD_API enum rng_type_t str_to_rng_type(const char* str);
SD_API const char* sd_sample_method_name(enum sample_method_t sample_method);
SD_API enum sample_method_t str_to_sample_method(const char* str);
SD_API const char* sd_schedule_name(enum schedule_t schedule);
SD_API enum schedule_t str_to_schedule(const char* str);
SD_API const char* sd_schedule_name(enum scheduler_t scheduler);
SD_API enum scheduler_t str_to_schedule(const char* str);
SD_API void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params);
SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params);
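Call sites migrate accordingly: rather than passing guidance, sample_method, sample_steps and eta individually, they fill the nested sample_params. A hedged sketch of the new call shape for image generation; the helper and the sd_ctx argument are illustrative only, while the field names come from the typedefs above:

#include "stable-diffusion.h"

// Hypothetical helper showing the updated parameter layout (not part of this commit).
sd_image_t* txt2img_with_ays(sd_ctx_t* sd_ctx, const char* prompt) {
    sd_img_gen_params_t p;
    sd_img_gen_params_init(&p);                       // fills defaults, including p.sample_params
    p.prompt                         = prompt;
    p.sample_params.scheduler        = AYS;           // per-generation sigma scheduler
    p.sample_params.sample_method    = EULER_A;
    p.sample_params.sample_steps     = 24;
    p.sample_params.guidance.txt_cfg = 7.0f;
    return generate_image(sd_ctx, &p);                // returns the generated image(s)
}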

wan.hpp
View File

@ -941,8 +941,8 @@ namespace WAN {
};
static void load_from_file_and_test(const std::string& file_path) {
ggml_backend_t backend = ggml_backend_cuda_init(0);
// ggml_backend_t backend = ggml_backend_cpu_init();
// ggml_backend_t backend = ggml_backend_cuda_init(0);
ggml_backend_t backend = ggml_backend_cpu_init();
ggml_type model_data_type = GGML_TYPE_F16;
std::shared_ptr<WanVAERunner> vae = std::shared_ptr<WanVAERunner>(new WanVAERunner(backend, false));
{
@ -1099,6 +1099,8 @@ namespace WAN {
if (qk_norm) {
blocks["norm_k_img"] = std::shared_ptr<GGMLBlock>(new RMSNorm(dim, eps));
} else {
blocks["norm_k_img"] = std::shared_ptr<GGMLBlock>(new Identity());
}
}
@ -1705,7 +1707,7 @@ namespace WAN {
void test() {
struct ggml_init_params params;
params.mem_size = static_cast<size_t>(20 * 1024 * 1024); // 20 MB
params.mem_size = static_cast<size_t>(200 * 1024 * 1024); // 200 MB
params.mem_buffer = NULL;
params.no_alloc = false;
@ -1719,20 +1721,22 @@ namespace WAN {
// auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 104, 60, 1, 16);
// ggml_set_f32(x, 0.01f);
auto x = load_tensor_from_file(work_ctx, "wan_dit_x.bin");
// print_ggml_tensor(x);
print_ggml_tensor(x);
std::vector<float> timesteps_vec(1, 999.f);
std::vector<float> timesteps_vec(1, 1000.f);
auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
// auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 4096, 512, 1);
// ggml_set_f32(context, 0.01f);
auto context = load_tensor_from_file(work_ctx, "wan_dit_context.bin");
// print_ggml_tensor(context);
print_ggml_tensor(context);
auto clip_fea = load_tensor_from_file(work_ctx, "wan_dit_clip_fea.bin");
print_ggml_tensor(clip_fea);
struct ggml_tensor* out = NULL;
int t0 = ggml_time_ms();
compute(8, x, timesteps, context, NULL, NULL, NULL, &out, work_ctx);
compute(8, x, timesteps, context, clip_fea, NULL, NULL, &out, work_ctx);
int t1 = ggml_time_ms();
print_ggml_tensor(out);
@ -1754,7 +1758,7 @@ namespace WAN {
auto tensor_types = model_loader.tensor_storages_types;
for (auto& item : tensor_types) {
LOG_DEBUG("%s %u", item.first.c_str(), item.second);
// LOG_DEBUG("%s %u", item.first.c_str(), item.second);
if (ends_with(item.first, "weight")) {
item.second = model_data_type;
}
@ -1763,7 +1767,9 @@ namespace WAN {
std::shared_ptr<WanRunner> wan = std::shared_ptr<WanRunner>(new WanRunner(backend,
false,
tensor_types,
"model.diffusion_model"));
"model.diffusion_model",
VERSION_WAN2,
true));
wan->alloc_params_buffer();
std::map<std::string, ggml_tensor*> tensors;