feat: stream LTX VAE temporal tile decoding (#1539)

This commit is contained in:
leejet 2026-05-22 00:25:04 +08:00 committed by GitHub
parent adaa599a3b
commit 449165caf5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1158,6 +1158,27 @@ namespace LTXVAE {
return WAN::WanVAE::unpatchify(ctx->ggml_ctx, out, patch_size, 1); return WAN::WanVAE::unpatchify(ctx->ggml_ctx, out, patch_size, 1);
} }
ggml_tensor* decode_tiled_chunk(GGMLRunnerContext* ctx,
ggml_tensor* z,
ggml_tensor* timestep,
std::vector<ggml_tensor*>& feat_map,
int chunk_idx,
int temporal_tile_overlap,
int& feat_idx) {
auto decoder = std::dynamic_pointer_cast<Decoder>(blocks["decoder"]);
auto processor = std::dynamic_pointer_cast<PerChannelStatistics>(blocks["per_channel_statistics"]);
auto latents = processor->un_normalize(ctx, z);
feat_idx = 0;
int chunk_overlap = temporal_tile_overlap; // modified by forward_tiled_frame temporal inflation
auto out_chunk = decoder->forward_tiled_frame(ctx, latents, timestep,
feat_map, feat_idx, chunk_idx, chunk_overlap);
if (chunk_overlap > 0) {
out_chunk = ggml_ext_slice(ctx->ggml_ctx, out_chunk, 2, 0, out_chunk->ne[2] - chunk_overlap);
}
return WAN::WanVAE::unpatchify(ctx->ggml_ctx, out_chunk, patch_size, 1);
}
ggml_tensor* encode(GGMLRunnerContext* ctx, ggml_tensor* encode(GGMLRunnerContext* ctx,
ggml_tensor* x) { ggml_tensor* x) {
GGML_ASSERT(!decode_only); GGML_ASSERT(!decode_only);
@ -1296,6 +1317,41 @@ struct LTXVideoVAE : public VAE {
vae.get_param_tensors(tensors, prefix); vae.get_param_tensors(tensors, prefix);
} }
struct TemporalTilePlan {
int frames = 1;
int overlap = 0;
int stride = 1;
int num_tiles = 1;
};
TemporalTilePlan resolve_temporal_tile_plan(int64_t total_frames) const {
TemporalTilePlan plan;
plan.frames = std::max(1, temporal_tile_frames);
plan.overlap = std::max(0, temporal_tile_overlap);
if (plan.overlap >= plan.frames) {
LOG_WARN("temporal_tile_overlap (%d) is greater than or equal to temporal_tile_frames (%d), adjusting values to avoid empty decode windows",
plan.overlap,
plan.frames);
plan.overlap = plan.frames - 1;
}
if (total_frames > 1 && plan.overlap >= total_frames) {
LOG_WARN("temporal_tile_overlap (%d) is greater than or equal to total latent frames (%lld), adjusting values to decode at least one tile",
plan.overlap,
(long long)total_frames);
plan.overlap = static_cast<int>(total_frames - 1);
}
plan.stride = std::max(1, plan.frames - plan.overlap);
int64_t tiled_frames = std::max<int64_t>(1, total_frames - plan.overlap);
plan.num_tiles = total_frames > 0 ? static_cast<int>((tiled_frames + plan.stride - 1) / plan.stride) : 0;
return plan;
}
std::string temporal_feat_cache_name(size_t feat_idx) const {
return "ltx_vae_temporal_feat:" + std::to_string(feat_idx);
}
ggml_cgraph* build_graph(const sd::Tensor<float>& z_tensor, bool decode_graph) { ggml_cgraph* build_graph(const sd::Tensor<float>& z_tensor, bool decode_graph) {
ggml_cgraph* gf = new_graph_custom(20480); ggml_cgraph* gf = new_graph_custom(20480);
ggml_tensor* z = make_input(z_tensor); ggml_tensor* z = make_input(z_tensor);
@ -1306,21 +1362,97 @@ struct LTXVideoVAE : public VAE {
auto runner_ctx = get_context(); auto runner_ctx = get_context();
ggml_tensor* out; ggml_tensor* out;
bool use_tiled = decode_graph && temporal_tiling_enabled &&
z_tensor.dim() == 5 && z_tensor.shape()[2] > 1;
if (use_tiled) {
LOG_DEBUG("Using LTX VAE temporal tiling params: temporal_tile_frames=%d, temporal_tile_overlap=%d",
temporal_tile_frames,
temporal_tile_overlap);
out = vae.decode_tiled(&runner_ctx, z, timestep, temporal_tile_frames, temporal_tile_overlap);
} else {
out = decode_graph ? vae.decode(&runner_ctx, z, timestep) : vae.encode(&runner_ctx, z); out = decode_graph ? vae.decode(&runner_ctx, z, timestep) : vae.encode(&runner_ctx, z);
}
ggml_build_forward_expand(gf, out); ggml_build_forward_expand(gf, out);
return gf; return gf;
} }
ggml_cgraph* build_temporal_tile_graph(const sd::Tensor<float>& z_chunk_tensor,
int chunk_idx,
int chunk_overlap) {
ggml_cgraph* gf = new_graph_custom(20480);
ggml_tensor* z = make_input(z_chunk_tensor);
ggml_tensor* timestep = nullptr;
if (timestep_conditioning) {
timestep = make_input(decode_timestep_tensor);
}
std::vector<ggml_tensor*> feat_map(128, nullptr);
for (size_t feat_idx = 0; feat_idx < feat_map.size(); ++feat_idx) {
feat_map[feat_idx] = get_cache_tensor_by_name(temporal_feat_cache_name(feat_idx));
}
auto runner_ctx = get_context();
int feat_count = 0;
ggml_tensor* out = vae.decode_tiled_chunk(&runner_ctx,
z,
timestep,
feat_map,
chunk_idx,
chunk_overlap,
feat_count);
for (int feat_idx = 0; feat_idx < feat_count && feat_idx < static_cast<int>(feat_map.size()); ++feat_idx) {
ggml_tensor* feat_cache = feat_map[static_cast<size_t>(feat_idx)];
if (feat_cache != nullptr) {
cache(temporal_feat_cache_name(static_cast<size_t>(feat_idx)), feat_cache);
ggml_build_forward_expand(gf, feat_cache);
}
}
ggml_build_forward_expand(gf, out);
return gf;
}
sd::Tensor<float> decode_temporal_tiled_streaming(const int n_threads,
const sd::Tensor<float>& input,
size_t expected_dim) {
const int64_t total_frames = input.shape()[2];
TemporalTilePlan plan = resolve_temporal_tile_plan(total_frames);
LOG_DEBUG("Using streaming temporal tiling: temporal_tile_frames=%d, temporal_tile_overlap=%d, total latent frames=%lld, resulting in %d tiles",
plan.frames,
plan.overlap,
(long long)total_frames,
plan.num_tiles);
free_cache_ctx_and_buffer();
cache_tensor_map.clear();
sd::Tensor<float> output;
for (int64_t start = 0; start < total_frames - plan.overlap; start += plan.stride) {
const int64_t end = std::min<int64_t>(total_frames, start + plan.frames);
const int chunk_overlap = end < total_frames ? plan.overlap : 0;
auto z_chunk = sd::ops::slice(input, 2, start, end);
LOG_DEBUG("LTX VAE temporal tile %lld/%d: latent frames [%lld, %lld), overlap=%d",
(long long)(start / plan.stride + 1),
plan.num_tiles,
(long long)start,
(long long)end,
chunk_overlap);
auto get_graph = [&]() -> ggml_cgraph* {
return build_temporal_tile_graph(z_chunk,
static_cast<int>(start),
chunk_overlap);
};
auto chunk = restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, true),
expected_dim);
if (chunk.empty()) {
free_cache_ctx_and_buffer();
cache_tensor_map.clear();
return {};
}
output = output.empty() ? std::move(chunk) : sd::ops::concat(output, chunk, 2);
}
free_cache_ctx_and_buffer();
cache_tensor_map.clear();
return output;
}
ggml_cgraph* build_latent_statistics_graph(const sd::Tensor<float>& z_tensor, bool normalize) { ggml_cgraph* build_latent_statistics_graph(const sd::Tensor<float>& z_tensor, bool normalize) {
ggml_cgraph* gf = new_graph_custom(1024); ggml_cgraph* gf = new_graph_custom(1024);
ggml_tensor* z = make_input(z_tensor); ggml_tensor* z = make_input(z_tensor);
@ -1356,6 +1488,9 @@ struct LTXVideoVAE : public VAE {
input = sd::ops::slice(input, 2, 0, cropped_t); input = sd::ops::slice(input, 2, 0, cropped_t);
} }
} }
if (decode_graph && temporal_tiling_enabled && input.dim() == 5 && input.shape()[2] > 1) {
return decode_temporal_tiled_streaming(n_threads, input, expected_dim);
}
auto get_graph = [&]() -> ggml_cgraph* { auto get_graph = [&]() -> ggml_cgraph* {
return build_graph(input, decode_graph); return build_graph(input, decode_graph);
}; };