diff --git a/README.md b/README.md
index 89eb095..295d21b 100644
--- a/README.md
+++ b/README.md
@@ -332,7 +332,7 @@ arguments:
   --rng {std_default, cuda}          RNG (default: cuda)
   -s SEED, --seed SEED               RNG seed (default: 42, use random seed for < 0)
   -b, --batch-count COUNT            number of images to generate
-  --schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete)
+  --scheduler {discrete, karras, exponential, ays, gits} Denoiser sigma scheduler (default: discrete)
   --clip-skip N                      ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)
                                      <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x
   --vae-tiling                       process vae in tiles to reduce memory usage
diff --git a/clip.hpp b/clip.hpp
index 1f64eee..ec2e173 100644
--- a/clip.hpp
+++ b/clip.hpp
@@ -777,7 +777,7 @@ public:
     struct ggml_tensor* forward(struct ggml_context* ctx,
                                 struct ggml_tensor* pixel_values,
                                 bool return_pooled = true,
-                                int clip_skip      = -1) {
+                                int clip_skip = -1) {
         // pixel_values: [N, num_channels, image_size, image_size]
         auto embeddings    = std::dynamic_pointer_cast<CLIPVisionEmbeddings>(blocks["embeddings"]);
         auto pre_layernorm = std::dynamic_pointer_cast<LayerNorm>(blocks["pre_layernorm"]);
@@ -786,7 +786,6 @@ public:
 
         auto x = embeddings->forward(ctx, pixel_values);  // [N, num_positions, embed_dim]
         x      = pre_layernorm->forward(ctx, x);
-        LOG_DEBUG("clip_vison skip %d", clip_skip);
         x = encoder->forward(ctx, x, clip_skip, false);
         // print_ggml_tensor(x, true, "ClipVisionModel x: ");
         auto last_hidden_state = x;
@@ -858,7 +857,7 @@ public:
     struct ggml_tensor* forward(struct ggml_context* ctx,
                                 struct ggml_tensor* pixel_values,
                                 bool return_pooled = true,
-                                int clip_skip      = -1) {
+                                int clip_skip = -1) {
         // pixel_values: [N, num_channels, image_size, image_size]
         // return: [N, projection_dim] if return_pooled else [N, n_token, hidden_size]
         auto vision_model = std::dynamic_pointer_cast<CLIPVisionModel>(blocks["vision_model"]);
diff --git a/denoiser.hpp b/denoiser.hpp
index d4bcec5..385bcfb 100644
--- a/denoiser.hpp
+++ b/denoiser.hpp
@@ -252,7 +252,7 @@ struct KarrasSchedule : SigmaSchedule {
 };
 
 struct Denoiser {
-    std::shared_ptr<SigmaSchedule> schedule  = std::make_shared<DiscreteSchedule>();
+    std::shared_ptr<SigmaSchedule> scheduler = std::make_shared<DiscreteSchedule>();
     virtual float sigma_min()  = 0;
     virtual float sigma_max()  = 0;
     virtual float sigma_to_t(float sigma) = 0;
@@ -263,7 +263,7 @@ struct Denoiser {
 
     virtual std::vector<float> get_sigmas(uint32_t n) {
         auto bound_t_to_sigma = std::bind(&Denoiser::t_to_sigma, this, std::placeholders::_1);
-        return schedule->get_sigmas(n, sigma_min(), sigma_max(), bound_t_to_sigma);
+        return scheduler->get_sigmas(n, sigma_min(), sigma_max(), bound_t_to_sigma);
     }
 };
@@ -349,7 +349,7 @@ struct EDMVDenoiser : public CompVisVDenoiser {
 
     EDMVDenoiser(float min_sigma = 0.002, float max_sigma = 120.0)
         : min_sigma(min_sigma), max_sigma(max_sigma) {
-        schedule = std::make_shared<ExponentialSchedule>();
+        scheduler = std::make_shared<ExponentialSchedule>();
     }
 
     float t_to_sigma(float t) {
diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index 32013ab..bbbc8b1 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -88,7 +88,7 @@ struct SDParams {
     int fps = 16;
 
     sample_method_t sample_method = EULER_A;
-    schedule_t schedule           = DEFAULT;
+    scheduler_t scheduler         = DEFAULT;
     int sample_steps              = 20;
     float strength                = 0.75f;
     float control_strength        = 0.9f;
@@ -161,7 +161,7 @@ void print_params(SDParams params) {
    printf("    width:              %d\n", params.width);
    printf("    height:             %d\n", params.height);
    printf("    sample_method:      %s\n", sd_sample_method_name(params.sample_method));
-    printf("    schedule:           %s\n", sd_schedule_name(params.schedule));
+    printf("    scheduler:          %s\n", sd_schedule_name(params.scheduler));
    printf("    sample_steps:       %d\n", params.sample_steps);
    printf("    strength(img2img):  %.2f\n", params.strength);
    printf("    rng:                %s\n", sd_rng_type_name(params.rng_type));
@@ -232,7 +232,7 @@ void print_usage(int argc, const char* argv[]) {
    printf("  --rng {std_default, cuda}          RNG (default: cuda)\n");
    printf("  -s SEED, --seed SEED               RNG seed (default: 42, use random seed for < 0)\n");
    printf("  -b, --batch-count COUNT            number of images to generate\n");
-    printf("  --schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete)\n");
+    printf("  --scheduler {discrete, karras, exponential, ays, gits} Denoiser sigma scheduler (default: discrete)\n");
    printf("  --clip-skip N                      ignore last_dot_pos layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n");
    printf("                                     <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n");
    printf("  --vae-tiling                       process vae in tiles to reduce memory usage\n");
@@ -535,10 +535,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
         if (++index >= argc) {
             return -1;
         }
-        const char* arg = argv[index];
-        params.schedule = str_to_schedule(arg);
-        if (params.schedule == SCHEDULE_COUNT) {
-            fprintf(stderr, "error: invalid schedule %s\n",
+        const char* arg  = argv[index];
+        params.scheduler = str_to_schedule(arg);
+        if (params.scheduler == SCHEDULE_COUNT) {
+            fprintf(stderr, "error: invalid scheduler %s\n",
                     arg);
             return -1;
         }
@@ -614,7 +614,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
         {"", "--rng", "", on_rng_arg},
         {"-s", "--seed", "", on_seed_arg},
         {"", "--sampling-method", "", on_sample_method_arg},
-        {"", "--schedule", "", on_schedule_arg},
+        {"", "--scheduler", "", on_schedule_arg},
         {"", "--skip-layers", "", on_skip_layers_arg},
         {"-r", "--ref-image", "", on_ref_image_arg},
         {"-h", "--help", "", on_help_arg},
@@ -738,8 +738,8 @@ std::string get_image_params(SDParams params, int64_t seed) {
     parameter_string += "Model: " + sd_basename(params.model_path) + ", ";
     parameter_string += "RNG: " + std::string(sd_rng_type_name(params.rng_type)) + ", ";
     parameter_string += "Sampler: " + std::string(sd_sample_method_name(params.sample_method));
-    if (params.schedule != DEFAULT) {
-        parameter_string += " " + std::string(sd_schedule_name(params.schedule));
+    if (params.scheduler != DEFAULT) {
+        parameter_string += " " + std::string(sd_schedule_name(params.scheduler));
     }
     parameter_string += ", ";
     for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path}) {
@@ -816,6 +816,13 @@ int main(int argc, const char* argv[]) {
         params.skip_layer_end,
         params.slg_scale,
     }};
+    sd_sample_params_t sample_params = {
+        guidance_params,
+        params.scheduler,
+        params.sample_method,
+        params.sample_steps,
+        params.eta,
+    };
 
     sd_set_log_callback(sd_log_cb, (void*)&params);
@@ -988,7 +995,6 @@ int main(int argc, const char* argv[]) {
         params.n_threads,
         params.wtype,
         params.rng_type,
-        params.schedule,
         params.offload_params_to_cpu,
         params.clip_on_cpu,
         params.control_net_cpu,
@@ -1054,16 +1060,13 @@ int main(int argc, const char* argv[]) {
             params.prompt.c_str(),
             params.negative_prompt.c_str(),
             params.clip_skip,
-            guidance_params,
             input_image,
             ref_images.data(),
             (int)ref_images.size(),
             mask_image,
             params.width,
             params.height,
-            params.sample_method,
-            params.sample_steps,
-            params.eta,
+            sample_params,
             params.strength,
             params.seed,
             params.batch_count,
@@ -1081,13 +1084,10 @@ int main(int argc, const char* argv[]) {
             params.prompt.c_str(),
             params.negative_prompt.c_str(),
             params.clip_skip,
-            guidance_params,
             input_image,
             params.width,
             params.height,
-            params.sample_method,
-            params.sample_steps,
-            params.eta,
+            sample_params,
             params.strength,
             params.seed,
             params.video_frames,
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 28fe308..56b7658 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -235,6 +235,8 @@ __STATIC_INLINE__ ggml_tensor* load_tensor_from_file(ggml_context* ctx, const st
     file.read(reinterpret_cast<char*>(&length), sizeof(length));
     file.read(reinterpret_cast<char*>(&ttype), sizeof(ttype));
 
+    LOG_DEBUG("load_tensor_from_file %d %d %d", n_dims, length, ttype);
+
     if (file.eof()) {
         LOG_ERROR("incomplete file '%s'", file_path.c_str());
         return NULL;
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index a06bfdc..64b57f7 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -661,39 +661,6 @@ public:
             LOG_INFO("running in eps-prediction mode");
         }
 
-        if (sd_ctx_params->schedule != DEFAULT) {
-            switch (sd_ctx_params->schedule) {
-                case DISCRETE:
-                    LOG_INFO("running with discrete schedule");
-                    denoiser->schedule = std::make_shared<DiscreteSchedule>();
-                    break;
-                case KARRAS:
-                    LOG_INFO("running with Karras schedule");
-                    denoiser->schedule = std::make_shared<KarrasSchedule>();
-                    break;
-                case EXPONENTIAL:
-                    LOG_INFO("running exponential schedule");
-                    denoiser->schedule = std::make_shared<ExponentialSchedule>();
-                    break;
-                case AYS:
-                    LOG_INFO("Running with Align-Your-Steps schedule");
-                    denoiser->schedule          = std::make_shared<AYSSchedule>();
-                    denoiser->schedule->version = version;
-                    break;
-                case GITS:
-                    LOG_INFO("Running with GITS schedule");
-                    denoiser->schedule          = std::make_shared<GITSSchedule>();
-                    denoiser->schedule->version = version;
-                    break;
-                case DEFAULT:
-                    // Don't touch anything.
-                    break;
-                default:
-                    LOG_ERROR("Unknown schedule %i", sd_ctx_params->schedule);
-                    abort();
-            }
-        }
-
         auto comp_vis_denoiser = std::dynamic_pointer_cast<CompVisDenoiser>(denoiser);
         if (comp_vis_denoiser) {
             for (int i = 0; i < TIMESTEPS; i++) {
@@ -707,6 +674,39 @@ public:
         return true;
     }
 
+    void init_scheduler(scheduler_t scheduler) {
+        switch (scheduler) {
+            case DISCRETE:
+                LOG_INFO("running with discrete scheduler");
+                denoiser->scheduler = std::make_shared<DiscreteSchedule>();
+                break;
+            case KARRAS:
+                LOG_INFO("running with Karras scheduler");
+                denoiser->scheduler = std::make_shared<KarrasSchedule>();
+                break;
+            case EXPONENTIAL:
+                LOG_INFO("running exponential scheduler");
+                denoiser->scheduler = std::make_shared<ExponentialSchedule>();
+                break;
+            case AYS:
+                LOG_INFO("Running with Align-Your-Steps scheduler");
+                denoiser->scheduler          = std::make_shared<AYSSchedule>();
+                denoiser->scheduler->version = version;
+                break;
+            case GITS:
+                LOG_INFO("Running with GITS scheduler");
+                denoiser->scheduler          = std::make_shared<GITSSchedule>();
+                denoiser->scheduler->version = version;
+                break;
+            case DEFAULT:
+                // Don't touch anything.
+                break;
+            default:
+                LOG_ERROR("Unknown scheduler %i", scheduler);
+                abort();
+        }
+    }
+
     bool is_using_v_parameterization_for_sd2(ggml_context* work_ctx, bool is_inpaint = false) {
         struct ggml_tensor* x_t = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 4, 1);
         ggml_set_f32(x_t, 0.5);
@@ -830,7 +830,7 @@ public:
     ggml_tensor* get_clip_vision_output(ggml_context* work_ctx,
                                         sd_image_t init_image,
                                         bool return_pooled = true,
-                                        int clip_skip      = -1,
+                                        int clip_skip = -1,
                                         bool zero_out_masked = false) {
         ggml_tensor* output = NULL;
         if (zero_out_masked) {
@@ -954,7 +954,7 @@ public:
         copy_ggml_tensor(x, init_latent);
         x = denoiser->noise_scaling(sigmas[0], noise, x);
 
-        struct ggml_tensor* noised_input = ggml_dup_tensor(work_ctx, noise);
+        struct ggml_tensor* noised_input = ggml_dup_tensor(work_ctx, x);
 
         bool has_unconditioned = img_cfg_scale != 1.0 && uncond.c_crossattn != NULL;
         bool has_img_cond      = cfg_scale != img_cfg_scale && img_cond.c_crossattn != NULL;
@@ -1399,17 +1399,17 @@ const char* schedule_to_str[] = {
     "gits",
 };
 
-const char* sd_schedule_name(enum schedule_t schedule) {
-    if (schedule < SCHEDULE_COUNT) {
-        return schedule_to_str[schedule];
+const char* sd_schedule_name(enum scheduler_t scheduler) {
+    if (scheduler < SCHEDULE_COUNT) {
+        return schedule_to_str[scheduler];
     }
     return NONE_STR;
 }
 
-enum schedule_t str_to_schedule(const char* str) {
+enum scheduler_t str_to_schedule(const char* str) {
     for (int i = 0; i < SCHEDULE_COUNT; i++) {
         if (!strcmp(str, schedule_to_str[i])) {
-            return (enum schedule_t)i;
+            return (enum scheduler_t)i;
         }
     }
     return SCHEDULE_COUNT;
@@ -1423,7 +1423,6 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
     sd_ctx_params->n_threads               = get_num_physical_cores();
     sd_ctx_params->wtype                   = SD_TYPE_COUNT;
     sd_ctx_params->rng_type                = CUDA_RNG;
-    sd_ctx_params->schedule                = DEFAULT;
     sd_ctx_params->offload_params_to_cpu   = false;
     sd_ctx_params->keep_clip_on_cpu        = false;
     sd_ctx_params->keep_control_net_on_cpu = false;
@@ -1459,7 +1458,6 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
              "n_threads: %d\n"
              "wtype: %s\n"
              "rng_type: %s\n"
-             "schedule: %s\n"
              "offload_params_to_cpu: %s\n"
              "keep_clip_on_cpu: %s\n"
              "keep_control_net_on_cpu: %s\n"
@@ -1486,7 +1484,6 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
              sd_ctx_params->n_threads,
              sd_type_name(sd_ctx_params->wtype),
              sd_rng_type_name(sd_ctx_params->rng_type),
-             sd_schedule_name(sd_ctx_params->schedule),
              BOOL_STR(sd_ctx_params->offload_params_to_cpu),
              BOOL_STR(sd_ctx_params->keep_clip_on_cpu),
              BOOL_STR(sd_ctx_params->keep_control_net_on_cpu),
@@ -1499,28 +1496,65 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
     return buf;
 }
 
+void sd_sample_params_init(sd_sample_params_t* sample_params) {
+    sample_params->guidance.txt_cfg            = 7.0f;
+    sample_params->guidance.img_cfg            = INFINITY;
+    sample_params->guidance.distilled_guidance = 3.5f;
+    sample_params->guidance.slg.layer_count    = 0;
+    sample_params->guidance.slg.layer_start    = 0.01f;
+    sample_params->guidance.slg.layer_end      = 0.2f;
+    sample_params->guidance.slg.scale          = 0.f;
+    sample_params->scheduler                   = DEFAULT;
+    sample_params->sample_method               = EULER_A;
+    sample_params->sample_steps                = 20;
+}
+
+char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
+    char* buf = (char*)malloc(4096);
+    if (!buf)
+        return NULL;
+    buf[0] = '\0';
+
+    snprintf(buf + strlen(buf), 4096 - strlen(buf),
+             "(txt_cfg: %.2f, "
+             "img_cfg: %.2f, "
+             "distilled_guidance: %.2f, "
+             "slg.layer_count: %zu, "
+             "slg.layer_start: %.2f, "
"slg.layer_end: %.2f, " + "slg.scale: %.2f, " + "scheduler: %s, " + "sample_method: %s, " + "sample_steps: %d, " + "eta: %.2f)", + sample_params->guidance.txt_cfg, + sample_params->guidance.img_cfg, + sample_params->guidance.distilled_guidance, + sample_params->guidance.slg.layer_count, + sample_params->guidance.slg.layer_start, + sample_params->guidance.slg.layer_end, + sample_params->guidance.slg.scale, + sd_schedule_name(sample_params->scheduler), + sd_sample_method_name(sample_params->sample_method), + sample_params->sample_steps, + sample_params->eta); + + return buf; +} + void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) { memset((void*)sd_img_gen_params, 0, sizeof(sd_img_gen_params_t)); - sd_img_gen_params->clip_skip = -1; - sd_img_gen_params->guidance.txt_cfg = 7.0f; - sd_img_gen_params->guidance.img_cfg = INFINITY; - sd_img_gen_params->guidance.distilled_guidance = 3.5f; - sd_img_gen_params->guidance.slg.layer_count = 0; - sd_img_gen_params->guidance.slg.layer_start = 0.01f; - sd_img_gen_params->guidance.slg.layer_end = 0.2f; - sd_img_gen_params->guidance.slg.scale = 0.f; - sd_img_gen_params->ref_images_count = 0; - sd_img_gen_params->width = 512; - sd_img_gen_params->height = 512; - sd_img_gen_params->sample_method = EULER_A; - sd_img_gen_params->sample_steps = 20; - sd_img_gen_params->eta = 0.f; - sd_img_gen_params->strength = 0.75f; - sd_img_gen_params->seed = -1; - sd_img_gen_params->batch_count = 1; - sd_img_gen_params->control_strength = 0.9f; - sd_img_gen_params->style_strength = 20.f; - sd_img_gen_params->normalize_input = false; + sd_img_gen_params->clip_skip = -1; + sd_sample_params_init(&sd_img_gen_params->sample_params); + sd_img_gen_params->ref_images_count = 0; + sd_img_gen_params->width = 512; + sd_img_gen_params->height = 512; + sd_img_gen_params->strength = 0.75f; + sd_img_gen_params->seed = -1; + sd_img_gen_params->batch_count = 1; + sd_img_gen_params->control_strength = 0.9f; + sd_img_gen_params->style_strength = 20.f; + sd_img_gen_params->normalize_input = false; } char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) { @@ -1529,22 +1563,15 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) { return NULL; buf[0] = '\0'; + char* sample_params_str = sd_sample_params_to_str(&sd_img_gen_params->sample_params); + snprintf(buf + strlen(buf), 4096 - strlen(buf), "prompt: %s\n" "negative_prompt: %s\n" "clip_skip: %d\n" - "txt_cfg: %.2f\n" - "img_cfg: %.2f\n" - "distilled_guidance: %.2f\n" - "slg.layer_count: %zu\n" - "slg.layer_start: %.2f\n" - "slg.layer_end: %.2f\n" - "slg.scale: %.2f\n" "width: %d\n" "height: %d\n" - "sample_method: %s\n" - "sample_steps: %d\n" - "eta: %.2f\n" + "sample_params: %.2f\n" "strength: %.2f\n" "seed: %" PRId64 "\n" @@ -1557,18 +1584,9 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) { SAFE_STR(sd_img_gen_params->prompt), SAFE_STR(sd_img_gen_params->negative_prompt), sd_img_gen_params->clip_skip, - sd_img_gen_params->guidance.txt_cfg, - sd_img_gen_params->guidance.img_cfg, - sd_img_gen_params->guidance.distilled_guidance, - sd_img_gen_params->guidance.slg.layer_count, - sd_img_gen_params->guidance.slg.layer_start, - sd_img_gen_params->guidance.slg.layer_end, - sd_img_gen_params->guidance.slg.scale, sd_img_gen_params->width, sd_img_gen_params->height, - sd_sample_method_name(sd_img_gen_params->sample_method), - sd_img_gen_params->sample_steps, - sd_img_gen_params->eta, + SAFE_STR(sample_params_str), sd_img_gen_params->strength, 
             sd_img_gen_params->seed,
             sd_img_gen_params->batch_count,
@@ -1577,26 +1595,18 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
              sd_img_gen_params->style_strength,
              BOOL_STR(sd_img_gen_params->normalize_input),
              SAFE_STR(sd_img_gen_params->input_id_images_path));
-
+    free(sample_params_str);
     return buf;
 }
 
 void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
     memset((void*)sd_vid_gen_params, 0, sizeof(sd_vid_gen_params_t));
-    sd_vid_gen_params->guidance.txt_cfg            = 7.0f;
-    sd_vid_gen_params->guidance.img_cfg            = INFINITY;
-    sd_vid_gen_params->guidance.distilled_guidance = 3.5f;
-    sd_vid_gen_params->guidance.slg.layer_count    = 0;
-    sd_vid_gen_params->guidance.slg.layer_start    = 0.01f;
-    sd_vid_gen_params->guidance.slg.layer_end      = 0.2f;
-    sd_vid_gen_params->guidance.slg.scale          = 0.f;
-    sd_vid_gen_params->width                       = 512;
-    sd_vid_gen_params->height                      = 512;
-    sd_vid_gen_params->sample_method               = EULER_A;
-    sd_vid_gen_params->sample_steps                = 20;
-    sd_vid_gen_params->strength                    = 0.75f;
-    sd_vid_gen_params->seed                        = -1;
-    sd_vid_gen_params->video_frames                = 6;
+    sd_sample_params_init(&sd_vid_gen_params->sample_params);
+    sd_vid_gen_params->width        = 512;
+    sd_vid_gen_params->height       = 512;
+    sd_vid_gen_params->strength     = 0.75f;
+    sd_vid_gen_params->seed         = -1;
+    sd_vid_gen_params->video_frames = 6;
 }
 
 struct sd_ctx_t {
@@ -2043,22 +2053,25 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
     }
     sd_ctx->sd->rng->manual_seed(seed);
 
+    int sample_steps = sd_img_gen_params->sample_params.sample_steps;
+
     size_t t0 = ggml_time_ms();
 
+    sd_ctx->sd->init_scheduler(sd_img_gen_params->sample_params.scheduler);
+    std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
+
     ggml_tensor* init_latent   = NULL;
     ggml_tensor* concat_latent = NULL;
     ggml_tensor* denoise_mask  = NULL;
-    std::vector<float> sigmas  = sd_ctx->sd->denoiser->get_sigmas(sd_img_gen_params->sample_steps);
-
     if (sd_img_gen_params->init_image.data) {
         LOG_INFO("IMG2IMG");
 
-        size_t t_enc = static_cast<size_t>(sd_img_gen_params->sample_steps * sd_img_gen_params->strength);
-        if (t_enc == sd_img_gen_params->sample_steps)
+        size_t t_enc = static_cast<size_t>(sample_steps * sd_img_gen_params->strength);
+        if (t_enc == sample_steps)
             t_enc--;
         LOG_INFO("target t_enc is %zu steps", t_enc);
 
         std::vector<float> sigma_sched;
-        sigma_sched.assign(sigmas.begin() + sd_img_gen_params->sample_steps - t_enc - 1, sigmas.end());
+        sigma_sched.assign(sigmas.begin() + sample_steps - t_enc - 1, sigmas.end());
         sigmas = sigma_sched;
 
         ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
@@ -2189,11 +2202,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
                                          SAFE_STR(sd_img_gen_params->prompt),
                                          SAFE_STR(sd_img_gen_params->negative_prompt),
                                          sd_img_gen_params->clip_skip,
-                                         sd_img_gen_params->guidance,
-                                         sd_img_gen_params->eta,
+                                         sd_img_gen_params->sample_params.guidance,
+                                         sd_img_gen_params->sample_params.eta,
                                          width,
                                          height,
-                                         sd_img_gen_params->sample_method,
+                                         sd_img_gen_params->sample_params.sample_method,
                                          sigmas,
                                          seed,
                                          sd_img_gen_params->batch_count,
@@ -2221,13 +2234,16 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
     std::string prompt          = SAFE_STR(sd_vid_gen_params->prompt);
     std::string negative_prompt = SAFE_STR(sd_vid_gen_params->negative_prompt);
 
-    int width  = sd_vid_gen_params->width;
-    int height = sd_vid_gen_params->height;
-    int frames = sd_vid_gen_params->video_frames;
-    frames     = (frames - 1) / 4 * 4 + 1;
+    int width        = sd_vid_gen_params->width;
+    int height       = sd_vid_gen_params->height;
+    int frames       = sd_vid_gen_params->video_frames;
+    frames           = (frames - 1) / 4 * 4 + 1;
+    int sample_steps = sd_vid_gen_params->sample_params.sample_steps;
 
     LOG_INFO("generate_video %dx%dx%d", width, height, frames);
 
-    std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sd_vid_gen_params->sample_steps);
+    sd_ctx->sd->init_scheduler(sd_vid_gen_params->sample_params.scheduler);
+
+    std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
 
     struct ggml_init_params params;
     params.mem_size = static_cast<size_t>(100 * 1024) * 1024;  // 100 MB
@@ -2315,7 +2331,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
     }
 
     ggml_tensor* init_latent = generate_init_latent(sd_ctx, work_ctx, width, height, frames, true);
-    int sample_steps         = sigmas.size() - 1;
+    sample_steps             = sigmas.size() - 1;
 
     // Get learned condition
     bool zero_out_masked = true;
@@ -2331,7 +2347,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
     cond.c_concat = concat_latent;
     cond.c_vector = clip_vision_output;
     SDCondition uncond;
-    if (sd_vid_gen_params->guidance.txt_cfg != 1.0) {
+    if (sd_vid_gen_params->sample_params.guidance.txt_cfg != 1.0) {
         uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
                                                                      sd_ctx->sd->n_threads,
                                                                      negative_prompt,
@@ -2372,9 +2388,9 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
                                                  {},
                                                  NULL,
                                                  0,
-                                                 sd_vid_gen_params->guidance,
-                                                 sd_vid_gen_params->eta,
-                                                 sd_vid_gen_params->sample_method,
+                                                 sd_vid_gen_params->sample_params.guidance,
+                                                 sd_vid_gen_params->sample_params.eta,
+                                                 sd_vid_gen_params->sample_params.sample_method,
                                                  sigmas,
                                                  -1,
                                                  {});
diff --git a/stable-diffusion.h b/stable-diffusion.h
index 732bdd5..63c5265 100644
--- a/stable-diffusion.h
+++ b/stable-diffusion.h
@@ -50,7 +50,7 @@ enum sample_method_t {
     SAMPLE_METHOD_COUNT
 };
 
-enum schedule_t {
+enum scheduler_t {
     DEFAULT,
     DISCRETE,
     KARRAS,
@@ -130,7 +130,6 @@ typedef struct {
     int n_threads;
     enum sd_type_t wtype;
     enum rng_type_t rng_type;
-    enum schedule_t schedule;
     bool offload_params_to_cpu;
     bool keep_clip_on_cpu;
     bool keep_control_net_on_cpu;
@@ -163,20 +162,25 @@ typedef struct {
     sd_slg_params_t slg;
 } sd_guidance_params_t;
 
+typedef struct {
+    sd_guidance_params_t guidance;
+    enum scheduler_t scheduler;
+    enum sample_method_t sample_method;
+    int sample_steps;
+    float eta;
+} sd_sample_params_t;
+
 typedef struct {
     const char* prompt;
     const char* negative_prompt;
     int clip_skip;
-    sd_guidance_params_t guidance;
     sd_image_t init_image;
     sd_image_t* ref_images;
     int ref_images_count;
     sd_image_t mask_image;
     int width;
     int height;
-    enum sample_method_t sample_method;
-    int sample_steps;
-    float eta;
+    sd_sample_params_t sample_params;
     float strength;
     int64_t seed;
     int batch_count;
@@ -191,13 +195,10 @@ typedef struct {
     const char* prompt;
     const char* negative_prompt;
     int clip_skip;
-    sd_guidance_params_t guidance;
     sd_image_t init_image;
     int width;
     int height;
-    enum sample_method_t sample_method;
-    int sample_steps;
-    float eta;
+    sd_sample_params_t sample_params;
     float strength;
     int64_t seed;
     int video_frames;
@@ -219,8 +220,8 @@ SD_API const char* sd_rng_type_name(enum rng_type_t rng_type);
 SD_API enum rng_type_t str_to_rng_type(const char* str);
 SD_API const char* sd_sample_method_name(enum sample_method_t sample_method);
 SD_API enum sample_method_t str_to_sample_method(const char* str);
-SD_API const char* sd_schedule_name(enum schedule_t schedule);
-SD_API enum schedule_t str_to_schedule(const char* str);
+SD_API const char* sd_schedule_name(enum scheduler_t scheduler);
+SD_API enum scheduler_t str_to_schedule(const char* str);
 
 SD_API void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params);
 SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params);
diff --git a/wan.hpp b/wan.hpp
index 9f3cc51..40d7cbc 100644
--- a/wan.hpp
+++ b/wan.hpp
@@ -941,8 +941,8 @@ namespace WAN {
     };
 
     static void load_from_file_and_test(const std::string& file_path) {
-        ggml_backend_t backend = ggml_backend_cuda_init(0);
-        // ggml_backend_t backend = ggml_backend_cpu_init();
+        // ggml_backend_t backend = ggml_backend_cuda_init(0);
+        ggml_backend_t backend = ggml_backend_cpu_init();
         ggml_type model_data_type = GGML_TYPE_F16;
         std::shared_ptr<WanVAERunner> vae = std::shared_ptr<WanVAERunner>(new WanVAERunner(backend, false));
         {
@@ -1099,6 +1099,8 @@ namespace WAN {
 
             if (qk_norm) {
                 blocks["norm_k_img"] = std::shared_ptr<GGMLBlock>(new RMSNorm(dim, eps));
+            } else {
+                blocks["norm_k_img"] = std::shared_ptr<GGMLBlock>(new Identity());
             }
         }
@@ -1705,7 +1707,7 @@ namespace WAN {
 
         void test() {
             struct ggml_init_params params;
-            params.mem_size   = static_cast<size_t>(20 * 1024 * 1024);  // 20 MB
+            params.mem_size   = static_cast<size_t>(200 * 1024 * 1024);  // 200 MB
             params.mem_buffer = NULL;
             params.no_alloc   = false;
@@ -1719,20 +1721,22 @@ namespace WAN {
             // auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 104, 60, 1, 16);
             // ggml_set_f32(x, 0.01f);
             auto x = load_tensor_from_file(work_ctx, "wan_dit_x.bin");
-            // print_ggml_tensor(x);
+            print_ggml_tensor(x);
 
-            std::vector<float> timesteps_vec(1, 999.f);
+            std::vector<float> timesteps_vec(1, 1000.f);
             auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
 
             // auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 4096, 512, 1);
             // ggml_set_f32(context, 0.01f);
             auto context = load_tensor_from_file(work_ctx, "wan_dit_context.bin");
-            // print_ggml_tensor(context);
+            print_ggml_tensor(context);
+
+            auto clip_fea = load_tensor_from_file(work_ctx, "wan_dit_clip_fea.bin");
+            print_ggml_tensor(clip_fea);
 
             struct ggml_tensor* out = NULL;
 
             int t0 = ggml_time_ms();
-            compute(8, x, timesteps, context, NULL, NULL, NULL, &out, work_ctx);
+            compute(8, x, timesteps, context, clip_fea, NULL, NULL, &out, work_ctx);
             int t1 = ggml_time_ms();
             print_ggml_tensor(out);
@@ -1754,7 +1758,7 @@ namespace WAN {
             auto tensor_types = model_loader.tensor_storages_types;
             for (auto& item : tensor_types) {
-                LOG_DEBUG("%s %u", item.first.c_str(), item.second);
+                // LOG_DEBUG("%s %u", item.first.c_str(), item.second);
                 if (ends_with(item.first, "weight")) {
                     item.second = model_data_type;
                 }
             }
@@ -1763,7 +1767,9 @@ namespace WAN {
             std::shared_ptr<WanRunner> wan = std::shared_ptr<WanRunner>(new WanRunner(backend,
                                                                                       false,
                                                                                       tensor_types,
-                                                                                      "model.diffusion_model"));
+                                                                                      "model.diffusion_model",
+                                                                                      VERSION_WAN2,
+                                                                                      true));
 
             wan->alloc_params_buffer();
             std::map<std::string, ggml_tensor*> tensors;
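
Usage note (not part of the diff): with this change, scheduler, sampler, step count and eta travel together in the new sd_sample_params_t struct, and the scheduler is applied per generation call through init_scheduler() instead of once at context creation. Below is a minimal caller sketch against the public header; it assumes the existing new_sd_ctx()/free_sd_ctx() entry points and the model_path field of sd_ctx_params_t (not shown in the hunks above), and the model path and prompt are placeholders.

/* sketch.c - hypothetical caller of the reworked sampling API */
#include "stable-diffusion.h"

int generate_one(void) {
    sd_ctx_params_t ctx_params;
    sd_ctx_params_init(&ctx_params);                 /* no longer carries a schedule field */
    ctx_params.model_path = "model.safetensors";     /* placeholder path (field assumed, not in this diff) */

    sd_ctx_t* ctx = new_sd_ctx(&ctx_params);         /* existing entry point, assumed unchanged */
    if (ctx == NULL)
        return 1;

    sd_img_gen_params_t gen_params;
    sd_img_gen_params_init(&gen_params);             /* calls sd_sample_params_init() internally */
    gen_params.prompt                      = "a photo of a cat";  /* placeholder prompt */
    gen_params.sample_params.scheduler     = KARRAS;  /* picked up by init_scheduler() inside generate_image() */
    gen_params.sample_params.sample_method = EULER_A;
    gen_params.sample_params.sample_steps  = 20;

    sd_image_t* images = generate_image(ctx, &gen_params);

    free_sd_ctx(ctx);
    return images != NULL ? 0 : 1;
}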