mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-02-04 02:43:36 +00:00
feat: enable vae tiling for vid gen (#1152)
* enable vae tiling for vid gen * format code * eliminate compilation warning --------- Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
parent
27b5f17401
commit
0e52afc651
@ -757,6 +757,7 @@ int main(int argc, const char* argv[]) {
|
||||
gen_params.seed,
|
||||
gen_params.video_frames,
|
||||
gen_params.vace_strength,
|
||||
ctx_params.vae_tiling_params,
|
||||
gen_params.cache_params,
|
||||
};
|
||||
|
||||
|
||||
@ -2489,10 +2489,15 @@ public:
|
||||
ne2 = 1;
|
||||
ne3 = C * x->ne[3];
|
||||
} else {
|
||||
if (!use_tiny_autoencoder) {
|
||||
C *= 2;
|
||||
int64_t out_channels = C;
|
||||
bool encode_outputs_mu = use_tiny_autoencoder ||
|
||||
sd_version_is_wan(version) ||
|
||||
sd_version_is_flux2(version) ||
|
||||
version == VERSION_CHROMA_RADIANCE;
|
||||
if (!encode_outputs_mu) {
|
||||
out_channels *= 2;
|
||||
}
|
||||
ne2 = C;
|
||||
ne2 = out_channels;
|
||||
ne3 = x->ne[3];
|
||||
}
|
||||
result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, ne2, ne3);
|
||||
@ -2633,7 +2638,7 @@ public:
|
||||
}
|
||||
process_latent_out(x);
|
||||
// x = load_tensor_from_file(work_ctx, "wan_vae_z.bin");
|
||||
if (vae_tiling_params.enabled && !decode_video) {
|
||||
if (vae_tiling_params.enabled) {
|
||||
float tile_overlap;
|
||||
int tile_size_x, tile_size_y;
|
||||
get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, vae_tiling_params, x->ne[0], x->ne[1]);
|
||||
@ -2651,7 +2656,7 @@ public:
|
||||
first_stage_model->free_compute_buffer();
|
||||
process_vae_output_tensor(result);
|
||||
} else {
|
||||
if (vae_tiling_params.enabled && !decode_video) {
|
||||
if (vae_tiling_params.enabled) {
|
||||
// split latent in 64x64 tiles and compute in several steps
|
||||
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
||||
tae_first_stage->compute(n_threads, in, true, &out);
|
||||
@ -3046,7 +3051,8 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
|
||||
"height: %d\n"
|
||||
"sample_params: %s\n"
|
||||
"strength: %.2f\n"
|
||||
"seed: %" PRId64 "\n"
|
||||
"seed: %" PRId64
|
||||
"\n"
|
||||
"batch_count: %d\n"
|
||||
"ref_images_count: %d\n"
|
||||
"auto_resize_ref_image: %s\n"
|
||||
@ -3099,6 +3105,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
|
||||
sd_vid_gen_params->video_frames = 6;
|
||||
sd_vid_gen_params->moe_boundary = 0.875f;
|
||||
sd_vid_gen_params->vace_strength = 1.f;
|
||||
sd_vid_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
|
||||
sd_cache_params_init(&sd_vid_gen_params->cache);
|
||||
}
|
||||
|
||||
@ -3728,6 +3735,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
if (sd_ctx == nullptr || sd_vid_gen_params == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
sd_ctx->sd->vae_tiling_params = sd_vid_gen_params->vae_tiling_params;
|
||||
|
||||
std::string prompt = SAFE_STR(sd_vid_gen_params->prompt);
|
||||
std::string negative_prompt = SAFE_STR(sd_vid_gen_params->negative_prompt);
|
||||
|
||||
@ -319,6 +319,7 @@ typedef struct {
|
||||
int64_t seed;
|
||||
int video_frames;
|
||||
float vace_strength;
|
||||
sd_tiling_params_t vae_tiling_params;
|
||||
sd_cache_params_t cache;
|
||||
} sd_vid_gen_params_t;
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user