mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-09 15:56:39 +00:00
feat: stream LTX VAE temporal tile decoding (#1539)
This commit is contained in:
parent
adaa599a3b
commit
449165caf5
153
src/ltx_vae.hpp
153
src/ltx_vae.hpp
@ -1158,6 +1158,27 @@ namespace LTXVAE {
|
|||||||
return WAN::WanVAE::unpatchify(ctx->ggml_ctx, out, patch_size, 1);
|
return WAN::WanVAE::unpatchify(ctx->ggml_ctx, out, patch_size, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ggml_tensor* decode_tiled_chunk(GGMLRunnerContext* ctx,
|
||||||
|
ggml_tensor* z,
|
||||||
|
ggml_tensor* timestep,
|
||||||
|
std::vector<ggml_tensor*>& feat_map,
|
||||||
|
int chunk_idx,
|
||||||
|
int temporal_tile_overlap,
|
||||||
|
int& feat_idx) {
|
||||||
|
auto decoder = std::dynamic_pointer_cast<Decoder>(blocks["decoder"]);
|
||||||
|
auto processor = std::dynamic_pointer_cast<PerChannelStatistics>(blocks["per_channel_statistics"]);
|
||||||
|
auto latents = processor->un_normalize(ctx, z);
|
||||||
|
|
||||||
|
feat_idx = 0;
|
||||||
|
int chunk_overlap = temporal_tile_overlap; // modified by forward_tiled_frame temporal inflation
|
||||||
|
auto out_chunk = decoder->forward_tiled_frame(ctx, latents, timestep,
|
||||||
|
feat_map, feat_idx, chunk_idx, chunk_overlap);
|
||||||
|
if (chunk_overlap > 0) {
|
||||||
|
out_chunk = ggml_ext_slice(ctx->ggml_ctx, out_chunk, 2, 0, out_chunk->ne[2] - chunk_overlap);
|
||||||
|
}
|
||||||
|
return WAN::WanVAE::unpatchify(ctx->ggml_ctx, out_chunk, patch_size, 1);
|
||||||
|
}
|
||||||
|
|
||||||
ggml_tensor* encode(GGMLRunnerContext* ctx,
|
ggml_tensor* encode(GGMLRunnerContext* ctx,
|
||||||
ggml_tensor* x) {
|
ggml_tensor* x) {
|
||||||
GGML_ASSERT(!decode_only);
|
GGML_ASSERT(!decode_only);
|
||||||
@ -1296,6 +1317,41 @@ struct LTXVideoVAE : public VAE {
|
|||||||
vae.get_param_tensors(tensors, prefix);
|
vae.get_param_tensors(tensors, prefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct TemporalTilePlan {
|
||||||
|
int frames = 1;
|
||||||
|
int overlap = 0;
|
||||||
|
int stride = 1;
|
||||||
|
int num_tiles = 1;
|
||||||
|
};
|
||||||
|
|
||||||
|
TemporalTilePlan resolve_temporal_tile_plan(int64_t total_frames) const {
|
||||||
|
TemporalTilePlan plan;
|
||||||
|
plan.frames = std::max(1, temporal_tile_frames);
|
||||||
|
plan.overlap = std::max(0, temporal_tile_overlap);
|
||||||
|
|
||||||
|
if (plan.overlap >= plan.frames) {
|
||||||
|
LOG_WARN("temporal_tile_overlap (%d) is greater than or equal to temporal_tile_frames (%d), adjusting values to avoid empty decode windows",
|
||||||
|
plan.overlap,
|
||||||
|
plan.frames);
|
||||||
|
plan.overlap = plan.frames - 1;
|
||||||
|
}
|
||||||
|
if (total_frames > 1 && plan.overlap >= total_frames) {
|
||||||
|
LOG_WARN("temporal_tile_overlap (%d) is greater than or equal to total latent frames (%lld), adjusting values to decode at least one tile",
|
||||||
|
plan.overlap,
|
||||||
|
(long long)total_frames);
|
||||||
|
plan.overlap = static_cast<int>(total_frames - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
plan.stride = std::max(1, plan.frames - plan.overlap);
|
||||||
|
int64_t tiled_frames = std::max<int64_t>(1, total_frames - plan.overlap);
|
||||||
|
plan.num_tiles = total_frames > 0 ? static_cast<int>((tiled_frames + plan.stride - 1) / plan.stride) : 0;
|
||||||
|
return plan;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string temporal_feat_cache_name(size_t feat_idx) const {
|
||||||
|
return "ltx_vae_temporal_feat:" + std::to_string(feat_idx);
|
||||||
|
}
|
||||||
|
|
||||||
ggml_cgraph* build_graph(const sd::Tensor<float>& z_tensor, bool decode_graph) {
|
ggml_cgraph* build_graph(const sd::Tensor<float>& z_tensor, bool decode_graph) {
|
||||||
ggml_cgraph* gf = new_graph_custom(20480);
|
ggml_cgraph* gf = new_graph_custom(20480);
|
||||||
ggml_tensor* z = make_input(z_tensor);
|
ggml_tensor* z = make_input(z_tensor);
|
||||||
@ -1306,21 +1362,97 @@ struct LTXVideoVAE : public VAE {
|
|||||||
|
|
||||||
auto runner_ctx = get_context();
|
auto runner_ctx = get_context();
|
||||||
ggml_tensor* out;
|
ggml_tensor* out;
|
||||||
bool use_tiled = decode_graph && temporal_tiling_enabled &&
|
|
||||||
z_tensor.dim() == 5 && z_tensor.shape()[2] > 1;
|
|
||||||
if (use_tiled) {
|
|
||||||
LOG_DEBUG("Using LTX VAE temporal tiling params: temporal_tile_frames=%d, temporal_tile_overlap=%d",
|
|
||||||
temporal_tile_frames,
|
|
||||||
temporal_tile_overlap);
|
|
||||||
out = vae.decode_tiled(&runner_ctx, z, timestep, temporal_tile_frames, temporal_tile_overlap);
|
|
||||||
} else {
|
|
||||||
out = decode_graph ? vae.decode(&runner_ctx, z, timestep) : vae.encode(&runner_ctx, z);
|
out = decode_graph ? vae.decode(&runner_ctx, z, timestep) : vae.encode(&runner_ctx, z);
|
||||||
}
|
|
||||||
ggml_build_forward_expand(gf, out);
|
ggml_build_forward_expand(gf, out);
|
||||||
|
|
||||||
return gf;
|
return gf;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ggml_cgraph* build_temporal_tile_graph(const sd::Tensor<float>& z_chunk_tensor,
|
||||||
|
int chunk_idx,
|
||||||
|
int chunk_overlap) {
|
||||||
|
ggml_cgraph* gf = new_graph_custom(20480);
|
||||||
|
ggml_tensor* z = make_input(z_chunk_tensor);
|
||||||
|
ggml_tensor* timestep = nullptr;
|
||||||
|
if (timestep_conditioning) {
|
||||||
|
timestep = make_input(decode_timestep_tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<ggml_tensor*> feat_map(128, nullptr);
|
||||||
|
for (size_t feat_idx = 0; feat_idx < feat_map.size(); ++feat_idx) {
|
||||||
|
feat_map[feat_idx] = get_cache_tensor_by_name(temporal_feat_cache_name(feat_idx));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto runner_ctx = get_context();
|
||||||
|
int feat_count = 0;
|
||||||
|
ggml_tensor* out = vae.decode_tiled_chunk(&runner_ctx,
|
||||||
|
z,
|
||||||
|
timestep,
|
||||||
|
feat_map,
|
||||||
|
chunk_idx,
|
||||||
|
chunk_overlap,
|
||||||
|
feat_count);
|
||||||
|
|
||||||
|
for (int feat_idx = 0; feat_idx < feat_count && feat_idx < static_cast<int>(feat_map.size()); ++feat_idx) {
|
||||||
|
ggml_tensor* feat_cache = feat_map[static_cast<size_t>(feat_idx)];
|
||||||
|
if (feat_cache != nullptr) {
|
||||||
|
cache(temporal_feat_cache_name(static_cast<size_t>(feat_idx)), feat_cache);
|
||||||
|
ggml_build_forward_expand(gf, feat_cache);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_build_forward_expand(gf, out);
|
||||||
|
return gf;
|
||||||
|
}
|
||||||
|
|
||||||
|
sd::Tensor<float> decode_temporal_tiled_streaming(const int n_threads,
|
||||||
|
const sd::Tensor<float>& input,
|
||||||
|
size_t expected_dim) {
|
||||||
|
const int64_t total_frames = input.shape()[2];
|
||||||
|
TemporalTilePlan plan = resolve_temporal_tile_plan(total_frames);
|
||||||
|
|
||||||
|
LOG_DEBUG("Using streaming temporal tiling: temporal_tile_frames=%d, temporal_tile_overlap=%d, total latent frames=%lld, resulting in %d tiles",
|
||||||
|
plan.frames,
|
||||||
|
plan.overlap,
|
||||||
|
(long long)total_frames,
|
||||||
|
plan.num_tiles);
|
||||||
|
|
||||||
|
free_cache_ctx_and_buffer();
|
||||||
|
cache_tensor_map.clear();
|
||||||
|
|
||||||
|
sd::Tensor<float> output;
|
||||||
|
for (int64_t start = 0; start < total_frames - plan.overlap; start += plan.stride) {
|
||||||
|
const int64_t end = std::min<int64_t>(total_frames, start + plan.frames);
|
||||||
|
const int chunk_overlap = end < total_frames ? plan.overlap : 0;
|
||||||
|
auto z_chunk = sd::ops::slice(input, 2, start, end);
|
||||||
|
|
||||||
|
LOG_DEBUG("LTX VAE temporal tile %lld/%d: latent frames [%lld, %lld), overlap=%d",
|
||||||
|
(long long)(start / plan.stride + 1),
|
||||||
|
plan.num_tiles,
|
||||||
|
(long long)start,
|
||||||
|
(long long)end,
|
||||||
|
chunk_overlap);
|
||||||
|
|
||||||
|
auto get_graph = [&]() -> ggml_cgraph* {
|
||||||
|
return build_temporal_tile_graph(z_chunk,
|
||||||
|
static_cast<int>(start),
|
||||||
|
chunk_overlap);
|
||||||
|
};
|
||||||
|
auto chunk = restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, true),
|
||||||
|
expected_dim);
|
||||||
|
if (chunk.empty()) {
|
||||||
|
free_cache_ctx_and_buffer();
|
||||||
|
cache_tensor_map.clear();
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
output = output.empty() ? std::move(chunk) : sd::ops::concat(output, chunk, 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
free_cache_ctx_and_buffer();
|
||||||
|
cache_tensor_map.clear();
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
ggml_cgraph* build_latent_statistics_graph(const sd::Tensor<float>& z_tensor, bool normalize) {
|
ggml_cgraph* build_latent_statistics_graph(const sd::Tensor<float>& z_tensor, bool normalize) {
|
||||||
ggml_cgraph* gf = new_graph_custom(1024);
|
ggml_cgraph* gf = new_graph_custom(1024);
|
||||||
ggml_tensor* z = make_input(z_tensor);
|
ggml_tensor* z = make_input(z_tensor);
|
||||||
@ -1356,6 +1488,9 @@ struct LTXVideoVAE : public VAE {
|
|||||||
input = sd::ops::slice(input, 2, 0, cropped_t);
|
input = sd::ops::slice(input, 2, 0, cropped_t);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (decode_graph && temporal_tiling_enabled && input.dim() == 5 && input.shape()[2] > 1) {
|
||||||
|
return decode_temporal_tiled_streaming(n_threads, input, expected_dim);
|
||||||
|
}
|
||||||
auto get_graph = [&]() -> ggml_cgraph* {
|
auto get_graph = [&]() -> ggml_cgraph* {
|
||||||
return build_graph(input, decode_graph);
|
return build_graph(input, decode_graph);
|
||||||
};
|
};
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user