mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-09 15:56:39 +00:00
Feat: Temporal tile custom size with overlap (#1510)
* Temporal tile size + overlap * add --extra-tiling-args support --------- Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
parent
2e3514625a
commit
adaa599a3b
@ -107,6 +107,8 @@ Generation Options:
|
||||
--extra-sample-args <string> extra sampler/scheduler args, key=value list. lcm supports noise_clip_std,
|
||||
noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift,
|
||||
stretch, terminal; euler_ge supports gamma
|
||||
--extra-tiling-args <string> extra VAE tiling args, key=value list. LTX video VAE supports
|
||||
temporal_tile_frames (default: 4), temporal_tile_overlap (default: 1)
|
||||
-H, --height <int> image height, in pixel space (default: 512)
|
||||
-W, --width <int> image width, in pixel space (default: 512)
|
||||
--steps <int> number of sample steps (default: 20)
|
||||
|
||||
@ -835,6 +835,10 @@ ArgOptions SDGenerationParams::get_options() {
|
||||
"--extra-sample-args",
|
||||
"extra sampler/scheduler args, key=value list. lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma",
|
||||
&extra_sample_args},
|
||||
{"",
|
||||
"--extra-tiling-args",
|
||||
"extra VAE tiling args, key=value list. LTX video VAE supports temporal_tile_frames (default: 4), temporal_tile_overlap (default: 1)",
|
||||
&extra_tiling_args},
|
||||
};
|
||||
|
||||
options.int_options = {
|
||||
@ -1780,6 +1784,9 @@ bool SDGenerationParams::from_json_str(
|
||||
if (tiling_json.contains("rel_size_y") && tiling_json["rel_size_y"].is_number()) {
|
||||
vae_tiling_params.rel_size_y = tiling_json["rel_size_y"];
|
||||
}
|
||||
if (tiling_json.contains("extra_tiling_args") && tiling_json["extra_tiling_args"].is_string()) {
|
||||
extra_tiling_args = tiling_json["extra_tiling_args"].get<std::string>();
|
||||
}
|
||||
}
|
||||
|
||||
if (!parse_lora_json_field(j, lora_path_resolver, lora_map, high_noise_lora_map)) {
|
||||
@ -2002,6 +2009,8 @@ bool SDGenerationParams::initialize_cache_params() {
|
||||
}
|
||||
|
||||
bool SDGenerationParams::resolve(const std::string& lora_model_dir, const std::string& hires_upscalers_dir, bool strict) {
|
||||
vae_tiling_params.extra_tiling_args = extra_tiling_args.empty() ? nullptr : extra_tiling_args.c_str();
|
||||
|
||||
if (high_noise_sample_params.sample_steps <= 0) {
|
||||
high_noise_sample_params.sample_steps = -1;
|
||||
}
|
||||
@ -2188,6 +2197,7 @@ sd_img_gen_params_t SDGenerationParams::to_sd_img_gen_params_t() {
|
||||
sample_params.custom_sigmas_count = static_cast<int>(custom_sigmas.size());
|
||||
sample_params.extra_sample_args = extra_sample_args.empty() ? nullptr : extra_sample_args.c_str();
|
||||
high_noise_sample_params.extra_sample_args = high_noise_extra_sample_args.empty() ? nullptr : high_noise_extra_sample_args.c_str();
|
||||
vae_tiling_params.extra_tiling_args = extra_tiling_args.empty() ? nullptr : extra_tiling_args.c_str();
|
||||
cache_params.scm_mask = scm_mask.empty() ? nullptr : scm_mask.c_str();
|
||||
|
||||
sd_pm_params_t pm_params = {
|
||||
@ -2261,6 +2271,7 @@ sd_vid_gen_params_t SDGenerationParams::to_sd_vid_gen_params_t() {
|
||||
sample_params.custom_sigmas_count = static_cast<int>(custom_sigmas.size());
|
||||
sample_params.extra_sample_args = extra_sample_args.empty() ? nullptr : extra_sample_args.c_str();
|
||||
high_noise_sample_params.extra_sample_args = high_noise_extra_sample_args.empty() ? nullptr : high_noise_extra_sample_args.c_str();
|
||||
vae_tiling_params.extra_tiling_args = extra_tiling_args.empty() ? nullptr : extra_tiling_args.c_str();
|
||||
cache_params.scm_mask = scm_mask.empty() ? nullptr : scm_mask.c_str();
|
||||
|
||||
params.loras = lora_vec.empty() ? nullptr : lora_vec.data();
|
||||
@ -2386,7 +2397,8 @@ std::string SDGenerationParams::to_string() const {
|
||||
<< vae_tiling_params.tile_size_y << ", "
|
||||
<< vae_tiling_params.target_overlap << ", "
|
||||
<< vae_tiling_params.rel_size_x << ", "
|
||||
<< vae_tiling_params.rel_size_y << " },\n"
|
||||
<< vae_tiling_params.rel_size_y << ", "
|
||||
<< "\"" << extra_tiling_args << "\" },\n"
|
||||
<< "}";
|
||||
return oss.str();
|
||||
}
|
||||
@ -2565,14 +2577,18 @@ std::string build_sdcpp_image_metadata_json(const SDContextParams& ctx_params,
|
||||
};
|
||||
}
|
||||
|
||||
if (gen_params.vae_tiling_params.enabled) {
|
||||
if (gen_params.vae_tiling_params.enabled ||
|
||||
gen_params.vae_tiling_params.temporal_tiling ||
|
||||
!gen_params.extra_tiling_args.empty()) {
|
||||
root["vae_tiling"] = {
|
||||
{"enabled", gen_params.vae_tiling_params.enabled},
|
||||
{"temporal_tiling", gen_params.vae_tiling_params.temporal_tiling},
|
||||
{"tile_size_x", gen_params.vae_tiling_params.tile_size_x},
|
||||
{"tile_size_y", gen_params.vae_tiling_params.tile_size_y},
|
||||
{"target_overlap", gen_params.vae_tiling_params.target_overlap},
|
||||
{"rel_size_x", gen_params.vae_tiling_params.rel_size_x},
|
||||
{"rel_size_y", gen_params.vae_tiling_params.rel_size_y},
|
||||
{"extra_tiling_args", gen_params.extra_tiling_args},
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@ -189,7 +189,8 @@ struct SDGenerationParams {
|
||||
int video_frames = 1;
|
||||
int fps = 16;
|
||||
float vace_strength = 1.f;
|
||||
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f};
|
||||
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr};
|
||||
std::string extra_tiling_args;
|
||||
|
||||
std::string pm_id_images_dir;
|
||||
std::string pm_id_embed_path;
|
||||
|
||||
@ -209,6 +209,8 @@ Default Generation Options:
|
||||
--extra-sample-args <string> extra sampler/scheduler args, key=value list. lcm supports noise_clip_std,
|
||||
noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift,
|
||||
stretch, terminal; euler_ge supports gamma
|
||||
--extra-tiling-args <string> extra VAE tiling args, key=value list. LTX video VAE supports
|
||||
temporal_tile_frames (default: 4), temporal_tile_overlap (default: 1)
|
||||
-H, --height <int> image height, in pixel space (default: 512)
|
||||
-W, --width <int> image width, in pixel space (default: 512)
|
||||
--steps <int> number of sample steps (default: 20)
|
||||
@ -264,6 +266,7 @@ Default Generation Options:
|
||||
--disable-auto-resize-ref-image disable auto resize of ref images
|
||||
--disable-image-metadata do not embed generation metadata on image files
|
||||
--vae-tiling process vae in tiles to reduce memory usage
|
||||
--temporal-tiling enable temporal tiling for LTX video VAE decode
|
||||
--hires enable highres fix
|
||||
-s, --seed RNG seed (default: 42, use random seed for < 0)
|
||||
--sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m,
|
||||
|
||||
@ -504,11 +504,13 @@ Shared default fields used by both `img_gen` and `vid_gen`:
|
||||
| `sample_params.guidance.slg.scale` | `number` |
|
||||
| `vae_tiling_params` | `object` |
|
||||
| `vae_tiling_params.enabled` | `boolean` |
|
||||
| `vae_tiling_params.temporal_tiling` | `boolean` |
|
||||
| `vae_tiling_params.tile_size_x` | `integer` |
|
||||
| `vae_tiling_params.tile_size_y` | `integer` |
|
||||
| `vae_tiling_params.target_overlap` | `number` |
|
||||
| `vae_tiling_params.rel_size_x` | `number` |
|
||||
| `vae_tiling_params.rel_size_y` | `number` |
|
||||
| `vae_tiling_params.extra_tiling_args` | `string` |
|
||||
| `cache_mode` | `string` |
|
||||
| `cache_option` | `string` |
|
||||
| `scm_mask` | `string` |
|
||||
@ -516,6 +518,8 @@ Shared default fields used by both `img_gen` and `vid_gen`:
|
||||
| `output_format` | `string` |
|
||||
| `output_compression` | `integer` |
|
||||
|
||||
`vae_tiling_params.extra_tiling_args` accepts a key=value list. For LTX video VAE temporal tiling, `temporal_tile_frames` defaults to `4` and `temporal_tile_overlap` defaults to `1`.
|
||||
|
||||
`img_gen`-specific default fields:
|
||||
|
||||
| Field | Type |
|
||||
@ -692,11 +696,13 @@ Example:
|
||||
|
||||
"vae_tiling_params": {
|
||||
"enabled": false,
|
||||
"temporal_tiling": false,
|
||||
"tile_size_x": 0,
|
||||
"tile_size_y": 0,
|
||||
"target_overlap": 0.5,
|
||||
"rel_size_x": 0.0,
|
||||
"rel_size_y": 0.0
|
||||
"rel_size_y": 0.0,
|
||||
"extra_tiling_args": ""
|
||||
},
|
||||
|
||||
"cache_mode": "disabled",
|
||||
@ -804,6 +810,14 @@ Other native fields:
|
||||
| `hires.custom_sigmas` | `array<number>` |
|
||||
| `hires.upscale_tile_size` | `integer` |
|
||||
| `vae_tiling_params` | `object` |
|
||||
| `vae_tiling_params.enabled` | `boolean` |
|
||||
| `vae_tiling_params.temporal_tiling` | `boolean` |
|
||||
| `vae_tiling_params.tile_size_x` | `integer` |
|
||||
| `vae_tiling_params.tile_size_y` | `integer` |
|
||||
| `vae_tiling_params.target_overlap` | `number` |
|
||||
| `vae_tiling_params.rel_size_x` | `number` |
|
||||
| `vae_tiling_params.rel_size_y` | `number` |
|
||||
| `vae_tiling_params.extra_tiling_args` | `string` |
|
||||
| `cache_mode` | `string` |
|
||||
| `cache_option` | `string` |
|
||||
| `scm_mask` | `string` |
|
||||
@ -1012,11 +1026,13 @@ Example:
|
||||
|
||||
"vae_tiling_params": {
|
||||
"enabled": false,
|
||||
"temporal_tiling": false,
|
||||
"tile_size_x": 0,
|
||||
"tile_size_y": 0,
|
||||
"target_overlap": 0.5,
|
||||
"rel_size_x": 0.0,
|
||||
"rel_size_y": 0.0
|
||||
"rel_size_y": 0.0,
|
||||
"extra_tiling_args": ""
|
||||
},
|
||||
|
||||
"cache_mode": "disabled",
|
||||
@ -1134,6 +1150,14 @@ Other native fields:
|
||||
| Field | Type |
|
||||
| --- | --- |
|
||||
| `vae_tiling_params` | `object` |
|
||||
| `vae_tiling_params.enabled` | `boolean` |
|
||||
| `vae_tiling_params.temporal_tiling` | `boolean` |
|
||||
| `vae_tiling_params.tile_size_x` | `integer` |
|
||||
| `vae_tiling_params.tile_size_y` | `integer` |
|
||||
| `vae_tiling_params.target_overlap` | `number` |
|
||||
| `vae_tiling_params.rel_size_x` | `number` |
|
||||
| `vae_tiling_params.rel_size_y` | `number` |
|
||||
| `vae_tiling_params.extra_tiling_args` | `string` |
|
||||
| `cache_mode` | `string` |
|
||||
| `cache_option` | `string` |
|
||||
| `scm_mask` | `string` |
|
||||
|
||||
@ -56,11 +56,13 @@ static const char* capability_sample_method_name(enum sample_method_t sample_met
|
||||
static json make_vae_tiling_json(const sd_tiling_params_t& params) {
|
||||
return {
|
||||
{"enabled", params.enabled},
|
||||
{"temporal_tiling", params.temporal_tiling},
|
||||
{"tile_size_x", params.tile_size_x},
|
||||
{"tile_size_y", params.tile_size_y},
|
||||
{"target_overlap", params.target_overlap},
|
||||
{"rel_size_x", params.rel_size_x},
|
||||
{"rel_size_y", params.rel_size_y},
|
||||
{"extra_tiling_args", params.extra_tiling_args ? params.extra_tiling_args : ""},
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@ -160,6 +160,7 @@ typedef struct {
|
||||
float target_overlap;
|
||||
float rel_size_x;
|
||||
float rel_size_y;
|
||||
const char* extra_tiling_args;
|
||||
} sd_tiling_params_t;
|
||||
|
||||
typedef struct {
|
||||
|
||||
161
src/ltx_vae.hpp
161
src/ltx_vae.hpp
@ -1,6 +1,7 @@
|
||||
#ifndef __SD_LTX_VAE_HPP__
|
||||
#define __SD_LTX_VAE_HPP__
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
@ -143,16 +144,25 @@ namespace LTXVAE {
|
||||
std::vector<ggml_tensor*>& feat_map,
|
||||
int& feat_idx,
|
||||
int chunk_idx,
|
||||
bool causal = true) {
|
||||
bool causal = true,
|
||||
int temporal_pad = 0) {
|
||||
auto conv = std::dynamic_pointer_cast<Conv3d>(blocks["conv"]);
|
||||
const int pad = causal ? (time_kernel_size - 1) : (time_kernel_size - 1) / 2;
|
||||
|
||||
ggml_tensor* prev = (feat_idx < (int)feat_map.size()) ? feat_map[feat_idx] : nullptr;
|
||||
|
||||
GGML_ASSERT(x->ne[2] >= temporal_pad);
|
||||
|
||||
int end_idx = x->ne[2] - temporal_pad;
|
||||
int start_idx = std::max(end_idx - pad, 0);
|
||||
|
||||
// Save a contiguous copy of the last `pad` frames so the large `x`
|
||||
// tensor is not kept alive across iterations by a dangling view.
|
||||
if (feat_idx < (int)feat_map.size() && pad > 0 && x->ne[2] >= pad) {
|
||||
auto slice = ggml_ext_slice(ctx->ggml_ctx, x, 2, x->ne[2] - pad, x->ne[2]);
|
||||
if (feat_idx < (int)feat_map.size() && end_idx - start_idx > 0) {
|
||||
GGML_ASSERT(start_idx >= 0);
|
||||
GGML_ASSERT(end_idx > 0);
|
||||
|
||||
auto slice = ggml_ext_slice(ctx->ggml_ctx, x, 2, start_idx, end_idx);
|
||||
feat_map[feat_idx] = ggml_cont(ctx->ggml_ctx, slice);
|
||||
}
|
||||
feat_idx++;
|
||||
@ -284,7 +294,8 @@ namespace LTXVAE {
|
||||
bool causal,
|
||||
std::vector<ggml_tensor*>& feat_map,
|
||||
int& feat_idx,
|
||||
int chunk_idx) {
|
||||
int chunk_idx,
|
||||
int temporal_pad = 0) {
|
||||
auto norm1 = std::dynamic_pointer_cast<PixelNorm3D>(blocks["norm1"]);
|
||||
auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
|
||||
auto norm2 = std::dynamic_pointer_cast<PixelNorm3D>(blocks["norm2"]);
|
||||
@ -311,14 +322,14 @@ namespace LTXVAE {
|
||||
h = apply_scale_shift(ctx->ggml_ctx, h, scale1, shift1);
|
||||
}
|
||||
h = ggml_silu_inplace(ctx->ggml_ctx, h);
|
||||
h = conv1->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal);
|
||||
h = conv1->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal, temporal_pad);
|
||||
|
||||
h = norm2->forward(ctx, h);
|
||||
if (timestep_conditioning) {
|
||||
h = apply_scale_shift(ctx->ggml_ctx, h, scale2, shift2);
|
||||
}
|
||||
h = ggml_silu_inplace(ctx->ggml_ctx, h);
|
||||
h = conv2->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal);
|
||||
h = conv2->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal, temporal_pad);
|
||||
|
||||
return ggml_add(ctx->ggml_ctx, h, x);
|
||||
}
|
||||
@ -367,7 +378,8 @@ namespace LTXVAE {
|
||||
bool causal,
|
||||
std::vector<ggml_tensor*>& feat_map,
|
||||
int& feat_idx,
|
||||
int chunk_idx) {
|
||||
int chunk_idx,
|
||||
int temporal_pad = 0) {
|
||||
ggml_tensor* timestep_embed = nullptr;
|
||||
if (timestep_conditioning) {
|
||||
GGML_ASSERT(timestep != nullptr);
|
||||
@ -376,7 +388,7 @@ namespace LTXVAE {
|
||||
}
|
||||
for (int i = 0; i < num_layers; i++) {
|
||||
auto resnet = std::dynamic_pointer_cast<ResnetBlock3D>(blocks["res_blocks." + std::to_string(i)]);
|
||||
x = resnet->forward(ctx, x, timestep_embed, causal, feat_map, feat_idx, chunk_idx);
|
||||
x = resnet->forward(ctx, x, timestep_embed, causal, feat_map, feat_idx, chunk_idx, temporal_pad);
|
||||
}
|
||||
return x;
|
||||
}
|
||||
@ -437,7 +449,8 @@ namespace LTXVAE {
|
||||
bool causal,
|
||||
std::vector<ggml_tensor*>& feat_map,
|
||||
int& feat_idx,
|
||||
int chunk_idx) {
|
||||
int chunk_idx,
|
||||
int temporal_pad = 0) {
|
||||
auto conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv"]);
|
||||
|
||||
bool drop_first = (chunk_idx == 0) && (factor_t > 1);
|
||||
@ -453,7 +466,7 @@ namespace LTXVAE {
|
||||
x_in = res;
|
||||
}
|
||||
|
||||
x = conv->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal);
|
||||
x = conv->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal, temporal_pad);
|
||||
x = depth_to_space_3d(ctx->ggml_ctx, x, get_output_channels(), factor_t, factor_s, drop_first);
|
||||
if (residual) {
|
||||
x = ggml_add(ctx->ggml_ctx, x, x_in);
|
||||
@ -986,7 +999,8 @@ namespace LTXVAE {
|
||||
ggml_tensor* timestep,
|
||||
std::vector<ggml_tensor*>& feat_map,
|
||||
int& feat_idx,
|
||||
int chunk_idx) {
|
||||
int chunk_idx,
|
||||
int& temporal_pad) {
|
||||
auto conv_in = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv_in"]);
|
||||
auto conv_norm_out = std::dynamic_pointer_cast<PixelNorm3D>(blocks["conv_norm_out"]);
|
||||
auto conv_out = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv_out"]);
|
||||
@ -998,7 +1012,7 @@ namespace LTXVAE {
|
||||
}
|
||||
|
||||
// conv_in with feat_map for left temporal context
|
||||
x = conv_in->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder);
|
||||
x = conv_in->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder, temporal_pad);
|
||||
|
||||
// up_blocks
|
||||
int block_idx = 0;
|
||||
@ -1006,12 +1020,13 @@ namespace LTXVAE {
|
||||
auto mid_block = std::dynamic_pointer_cast<UNetMidBlock3D>(blocks["up_blocks." + std::to_string(block_idx)]);
|
||||
if (mid_block) {
|
||||
x = mid_block->forward(ctx, x, scaled_timestep, causal_decoder,
|
||||
feat_map, feat_idx, chunk_idx);
|
||||
feat_map, feat_idx, chunk_idx, temporal_pad);
|
||||
} else {
|
||||
auto upsample = std::dynamic_pointer_cast<DepthToSpaceUpsample>(
|
||||
blocks["up_blocks." + std::to_string(block_idx)]);
|
||||
x = upsample->forward(ctx, x, causal_decoder,
|
||||
feat_map, feat_idx, chunk_idx);
|
||||
feat_map, feat_idx, chunk_idx, temporal_pad);
|
||||
temporal_pad *= upsample->factor_t;
|
||||
}
|
||||
block_idx++;
|
||||
}
|
||||
@ -1028,7 +1043,7 @@ namespace LTXVAE {
|
||||
x = apply_scale_shift(ctx->ggml_ctx, x, scale, shift);
|
||||
}
|
||||
x = ggml_silu_inplace(ctx->ggml_ctx, x);
|
||||
x = conv_out->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder);
|
||||
x = conv_out->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder, temporal_pad);
|
||||
return x;
|
||||
}
|
||||
};
|
||||
@ -1084,7 +1099,9 @@ namespace LTXVAE {
|
||||
// tensors can be freed by GGML before the next iteration starts.
|
||||
ggml_tensor* decode_tiled(GGMLRunnerContext* ctx,
|
||||
ggml_tensor* z,
|
||||
ggml_tensor* timestep) {
|
||||
ggml_tensor* timestep,
|
||||
int temporal_window_size = 1,
|
||||
int temporal_tile_overlap = 0) {
|
||||
auto decoder = std::dynamic_pointer_cast<Decoder>(blocks["decoder"]);
|
||||
auto processor = std::dynamic_pointer_cast<PerChannelStatistics>(blocks["per_channel_statistics"]);
|
||||
auto latents = processor->un_normalize(ctx, z);
|
||||
@ -1099,13 +1116,43 @@ namespace LTXVAE {
|
||||
// 128 slots is generous enough for any supported decoder configuration.
|
||||
std::vector<ggml_tensor*> feat_map(128, nullptr);
|
||||
|
||||
// Ensure window size is at least 1
|
||||
int window = std::max(1, temporal_window_size);
|
||||
int overlap = std::max(0, temporal_tile_overlap);
|
||||
|
||||
if (overlap >= window) {
|
||||
LOG_WARN("temporal_tile_overlap (%d) is greater than or equal to temporal_tile_frames (%d), adjusting values to avoid empty decode windows",
|
||||
overlap, window);
|
||||
overlap = window - 1;
|
||||
}
|
||||
LOG_DEBUG("Using temporal tiling: temporal_tile_frames = %d, temporal_tile_overlap = %d, total frames = %d, resulting in %d tiles",
|
||||
window,
|
||||
overlap,
|
||||
(int)T,
|
||||
(T + window - overlap - 1) / (window - overlap));
|
||||
ggml_tensor* out = nullptr;
|
||||
for (int i = 0; i < (int)T; i++) {
|
||||
for (int i = 0; i < (int)T - overlap; i += (window - overlap)) {
|
||||
int feat_idx = 0;
|
||||
auto z_i = ggml_ext_slice(ctx->ggml_ctx, latents, 2, i, i + 1);
|
||||
auto out_i = decoder->forward_tiled_frame(ctx, z_i, timestep,
|
||||
feat_map, feat_idx, i);
|
||||
out = (out == nullptr) ? out_i : ggml_concat(ctx->ggml_ctx, out, out_i, 2);
|
||||
|
||||
// Calculate the end index for the current temporal chunk
|
||||
int end_i = std::min((int)T, i + window);
|
||||
if (end_i >= (int)T) {
|
||||
overlap = 0; // avoid overlap issues in the last chunk
|
||||
}
|
||||
|
||||
int chunk_overlap = overlap; // modified by forward_tiled_frame temporal inflation
|
||||
|
||||
auto z_chunk = ggml_ext_slice(ctx->ggml_ctx, latents, 2, i, end_i);
|
||||
|
||||
auto out_chunk = decoder->forward_tiled_frame(ctx, z_chunk, timestep,
|
||||
feat_map, feat_idx, i, chunk_overlap);
|
||||
|
||||
// discard overlap frames if it's not the final chunk
|
||||
if (overlap > 0 && end_i < (int)T) {
|
||||
out_chunk = ggml_ext_slice(ctx->ggml_ctx, out_chunk, 2, 0, out_chunk->ne[2] - chunk_overlap);
|
||||
}
|
||||
|
||||
out = (out == nullptr) ? out_chunk : ggml_concat(ctx->ggml_ctx, out, out_chunk, 2);
|
||||
}
|
||||
|
||||
return WAN::WanVAE::unpatchify(ctx->ggml_ctx, out, patch_size, 1);
|
||||
@ -1140,8 +1187,13 @@ namespace LTXVAE {
|
||||
} // namespace LTXVAE
|
||||
|
||||
struct LTXVideoVAE : public VAE {
|
||||
static constexpr int DEFAULT_TEMPORAL_TILE_FRAMES = 4;
|
||||
static constexpr int DEFAULT_TEMPORAL_TILE_OVERLAP = 1;
|
||||
|
||||
bool decode_only;
|
||||
bool temporal_tiling_enabled = false;
|
||||
int temporal_tile_frames = DEFAULT_TEMPORAL_TILE_FRAMES;
|
||||
int temporal_tile_overlap = DEFAULT_TEMPORAL_TILE_OVERLAP;
|
||||
int ltx_vae_version;
|
||||
bool timestep_conditioning;
|
||||
int patch_size;
|
||||
@ -1178,6 +1230,68 @@ struct LTXVideoVAE : public VAE {
|
||||
temporal_tiling_enabled = enabled;
|
||||
}
|
||||
|
||||
static std::string trim_tiling_arg(std::string value) {
|
||||
const char* whitespace = " \t\r\n";
|
||||
size_t begin = value.find_first_not_of(whitespace);
|
||||
if (begin == std::string::npos) {
|
||||
return "";
|
||||
}
|
||||
size_t end = value.find_last_not_of(whitespace);
|
||||
return value.substr(begin, end - begin + 1);
|
||||
}
|
||||
|
||||
static bool parse_tiling_int(const std::string& value, int& parsed) {
|
||||
try {
|
||||
size_t consumed = 0;
|
||||
parsed = std::stoi(value, &consumed);
|
||||
return trim_tiling_arg(value.substr(consumed)).empty();
|
||||
} catch (...) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
void set_tiling_params(const sd_tiling_params_t& params) override {
|
||||
temporal_tiling_enabled = params.temporal_tiling;
|
||||
temporal_tile_frames = DEFAULT_TEMPORAL_TILE_FRAMES;
|
||||
temporal_tile_overlap = DEFAULT_TEMPORAL_TILE_OVERLAP;
|
||||
|
||||
const char* extra_tiling_args = params.extra_tiling_args;
|
||||
if (extra_tiling_args == nullptr || extra_tiling_args[0] == '\0') {
|
||||
return;
|
||||
}
|
||||
|
||||
std::string raw(extra_tiling_args);
|
||||
size_t start = 0;
|
||||
for (size_t pos = 0; pos <= raw.size(); ++pos) {
|
||||
if (pos != raw.size() && raw[pos] != ',' && raw[pos] != ';') {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string token = trim_tiling_arg(raw.substr(start, pos - start));
|
||||
if (!token.empty()) {
|
||||
size_t eq = token.find('=');
|
||||
if (eq == std::string::npos) {
|
||||
LOG_WARN("ignoring malformed LTX VAE extra tiling arg '%s'", token.c_str());
|
||||
} else {
|
||||
std::string key = trim_tiling_arg(token.substr(0, eq));
|
||||
std::string value = trim_tiling_arg(token.substr(eq + 1));
|
||||
int parsed = 0;
|
||||
if (!parse_tiling_int(value, parsed)) {
|
||||
LOG_WARN("ignoring invalid LTX VAE extra tiling arg '%s=%s'", key.c_str(), value.c_str());
|
||||
} else if (key == "temporal_tile_frames") {
|
||||
temporal_tile_frames = std::max(1, parsed);
|
||||
} else if (key == "temporal_tile_overlap") {
|
||||
temporal_tile_overlap = std::max(0, parsed);
|
||||
} else {
|
||||
LOG_WARN("ignoring unknown LTX VAE extra tiling arg '%s'", key.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
start = pos + 1;
|
||||
}
|
||||
}
|
||||
|
||||
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) override {
|
||||
vae.get_param_tensors(tensors, prefix);
|
||||
}
|
||||
@ -1195,7 +1309,10 @@ struct LTXVideoVAE : public VAE {
|
||||
bool use_tiled = decode_graph && temporal_tiling_enabled &&
|
||||
z_tensor.dim() == 5 && z_tensor.shape()[2] > 1;
|
||||
if (use_tiled) {
|
||||
out = vae.decode_tiled(&runner_ctx, z, timestep);
|
||||
LOG_DEBUG("Using LTX VAE temporal tiling params: temporal_tile_frames=%d, temporal_tile_overlap=%d",
|
||||
temporal_tile_frames,
|
||||
temporal_tile_overlap);
|
||||
out = vae.decode_tiled(&runner_ctx, z, timestep, temporal_tile_frames, temporal_tile_overlap);
|
||||
} else {
|
||||
out = decode_graph ? vae.decode(&runner_ctx, z, timestep) : vae.encode(&runner_ctx, z);
|
||||
}
|
||||
|
||||
@ -151,7 +151,7 @@ public:
|
||||
bool apply_lora_immediately = false;
|
||||
|
||||
std::string taesd_path;
|
||||
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0, 0};
|
||||
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0, 0, nullptr};
|
||||
bool offload_params_to_cpu = false;
|
||||
float max_vram = 0.f;
|
||||
bool use_pmid = false;
|
||||
@ -2679,7 +2679,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) {
|
||||
sd_img_gen_params->batch_count = 1;
|
||||
sd_img_gen_params->control_strength = 0.9f;
|
||||
sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f};
|
||||
sd_img_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f};
|
||||
sd_img_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr};
|
||||
sd_cache_params_init(&sd_img_gen_params->cache);
|
||||
sd_hires_params_init(&sd_img_gen_params->hires);
|
||||
}
|
||||
@ -2708,7 +2708,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
|
||||
"increase_ref_index: %s\n"
|
||||
"control_strength: %.2f\n"
|
||||
"photo maker: {style_strength = %.2f, id_images_count = %d, id_embed_path = %s}\n"
|
||||
"VAE tiling: %s (temporal=%s)\n"
|
||||
"VAE tiling: %s (temporal=%s, extra_tiling_args=%s)\n"
|
||||
"hires: {enabled=%s, upscaler=%s, model_path=%s, scale=%.2f, target=%dx%d, steps=%d, denoising_strength=%.2f}\n",
|
||||
SAFE_STR(sd_img_gen_params->prompt),
|
||||
SAFE_STR(sd_img_gen_params->negative_prompt),
|
||||
@ -2728,6 +2728,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
|
||||
SAFE_STR(sd_img_gen_params->pm_params.id_embed_path),
|
||||
BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled),
|
||||
BOOL_STR(sd_img_gen_params->vae_tiling_params.temporal_tiling),
|
||||
SAFE_STR(sd_img_gen_params->vae_tiling_params.extra_tiling_args),
|
||||
BOOL_STR(sd_img_gen_params->hires.enabled),
|
||||
sd_hires_upscaler_name(sd_img_gen_params->hires.upscaler),
|
||||
SAFE_STR(sd_img_gen_params->hires.model_path),
|
||||
@ -2765,7 +2766,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
|
||||
sd_vid_gen_params->fps = 16;
|
||||
sd_vid_gen_params->moe_boundary = 0.875f;
|
||||
sd_vid_gen_params->vace_strength = 1.f;
|
||||
sd_vid_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f};
|
||||
sd_vid_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr};
|
||||
sd_vid_gen_params->hires.enabled = false;
|
||||
sd_vid_gen_params->hires.upscaler = SD_HIRES_UPSCALER_LATENT;
|
||||
sd_vid_gen_params->hires.scale = 2.f;
|
||||
|
||||
@ -484,7 +484,7 @@ public:
|
||||
if (is_wide) {
|
||||
auto block = std::dynamic_pointer_cast<WideMemBlock>(blocks[std::to_string(index++)]);
|
||||
h = block->forward(ctx, h, mem);
|
||||
} else{
|
||||
} else {
|
||||
auto block = std::dynamic_pointer_cast<MemBlock>(blocks[std::to_string(index++)]);
|
||||
h = block->forward(ctx, h, mem);
|
||||
}
|
||||
|
||||
@ -167,6 +167,7 @@ public:
|
||||
int64_t t0 = ggml_time_ms();
|
||||
sd::Tensor<float> input = x;
|
||||
sd::Tensor<float> output;
|
||||
set_tiling_params(tiling_params);
|
||||
|
||||
if (tiling_params.enabled) {
|
||||
const int scale_factor = get_scale_factor();
|
||||
@ -216,6 +217,9 @@ public:
|
||||
virtual void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) = 0;
|
||||
virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); };
|
||||
virtual void set_temporal_tiling_enabled(bool enabled) { SD_UNUSED(enabled); };
|
||||
virtual void set_tiling_params(const sd_tiling_params_t& params) {
|
||||
set_temporal_tiling_enabled(params.temporal_tiling);
|
||||
};
|
||||
};
|
||||
|
||||
struct FakeVAE : public VAE {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user