add temporal tiling support

This commit is contained in:
leejet 2026-04-29 22:22:34 +08:00
parent 2ca782a65a
commit e744e1e4e2
9 changed files with 304 additions and 30 deletions

View File

@ -156,6 +156,7 @@ Generation Options:
--disable-auto-resize-ref-image disable auto resize of ref images --disable-auto-resize-ref-image disable auto resize of ref images
--disable-image-metadata do not embed generation metadata on image files --disable-image-metadata do not embed generation metadata on image files
--vae-tiling process vae in tiles to reduce memory usage --vae-tiling process vae in tiles to reduce memory usage
--temporal-tiling enable temporal tiling for LTX video VAE decode
--hires enable highres fix --hires enable highres fix
-s, --seed RNG seed (default: 42, use random seed for < 0) -s, --seed RNG seed (default: 42, use random seed for < 0)
--sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, --sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m,

View File

@ -989,6 +989,11 @@ ArgOptions SDGenerationParams::get_options() {
"process vae in tiles to reduce memory usage", "process vae in tiles to reduce memory usage",
true, true,
&vae_tiling_params.enabled}, &vae_tiling_params.enabled},
{"",
"--temporal-tiling",
"enable temporal tiling for LTX video VAE decode",
true,
&vae_tiling_params.temporal_tiling},
{"", {"",
"--hires", "--hires",
"enable highres fix", "enable highres fix",
@ -1681,6 +1686,9 @@ bool SDGenerationParams::from_json_str(
if (tiling_json.contains("enabled") && tiling_json["enabled"].is_boolean()) { if (tiling_json.contains("enabled") && tiling_json["enabled"].is_boolean()) {
vae_tiling_params.enabled = tiling_json["enabled"]; vae_tiling_params.enabled = tiling_json["enabled"];
} }
if (tiling_json.contains("temporal_tiling") && tiling_json["temporal_tiling"].is_boolean()) {
vae_tiling_params.temporal_tiling = tiling_json["temporal_tiling"];
}
if (tiling_json.contains("tile_size_x") && tiling_json["tile_size_x"].is_number_integer()) { if (tiling_json.contains("tile_size_x") && tiling_json["tile_size_x"].is_number_integer()) {
vae_tiling_params.tile_size_x = tiling_json["tile_size_x"]; vae_tiling_params.tile_size_x = tiling_json["tile_size_x"];
} }
@ -2275,6 +2283,7 @@ std::string SDGenerationParams::to_string() const {
<< ", upscale_tile_size: " << hires_upscale_tile_size << " },\n" << ", upscale_tile_size: " << hires_upscale_tile_size << " },\n"
<< " vae_tiling_params: { " << " vae_tiling_params: { "
<< vae_tiling_params.enabled << ", " << vae_tiling_params.enabled << ", "
<< vae_tiling_params.temporal_tiling << ", "
<< vae_tiling_params.tile_size_x << ", " << vae_tiling_params.tile_size_x << ", "
<< vae_tiling_params.tile_size_y << ", " << vae_tiling_params.tile_size_y << ", "
<< vae_tiling_params.target_overlap << ", " << vae_tiling_params.target_overlap << ", "

View File

@ -183,7 +183,7 @@ struct SDGenerationParams {
int video_frames = 1; int video_frames = 1;
int fps = 16; int fps = 16;
float vace_strength = 1.f; float vace_strength = 1.f;
sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f}; sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f};
std::string pm_id_images_dir; std::string pm_id_images_dir;
std::string pm_id_embed_path; std::string pm_id_embed_path;

View File

@ -149,6 +149,7 @@ enum lora_apply_mode_t {
typedef struct { typedef struct {
bool enabled; bool enabled;
bool temporal_tiling;
int tile_size_x; int tile_size_x;
int tile_size_y; int tile_size_y;
float target_overlap; float target_overlap;

View File

@ -1683,6 +1683,8 @@ struct GGMLRunnerContext {
bool circular_y_enabled = false; bool circular_y_enabled = false;
std::shared_ptr<WeightAdapter> weight_adapter = nullptr; std::shared_ptr<WeightAdapter> weight_adapter = nullptr;
std::unordered_map<ggml_tensor*, std::string>* debug_tensors = nullptr; std::unordered_map<ggml_tensor*, std::string>* debug_tensors = nullptr;
std::function<ggml_tensor*(const std::string&)> get_cache_tensor;
std::function<void(const std::string&, ggml_tensor*)> cache_tensor;
void capture_tensor(const std::string& name, ggml_tensor* tensor) { void capture_tensor(const std::string& name, ggml_tensor* tensor) {
if (debug_tensors == nullptr || tensor == nullptr) { if (debug_tensors == nullptr || tensor == nullptr) {
@ -1691,6 +1693,20 @@ struct GGMLRunnerContext {
ggml_set_output(tensor); ggml_set_output(tensor);
(*debug_tensors)[tensor] = name; (*debug_tensors)[tensor] = name;
} }
// Fetch a tensor persisted by a previous graph run.
// Returns nullptr when no cache callback is installed or the name is unknown.
ggml_tensor* load_cache_tensor(const std::string& name) const {
    return get_cache_tensor ? get_cache_tensor(name) : nullptr;
}
// Hand a tensor to the runner's cache callback under `name`.
// Silently a no-op when either the callback is missing or the tensor is null.
void persist_cache_tensor(const std::string& name, ggml_tensor* tensor) const {
    if (cache_tensor && tensor != nullptr) {
        cache_tensor(name, tensor);
    }
}
}; };
struct GGMLRunner { struct GGMLRunner {
@ -1850,6 +1866,11 @@ protected:
ggml_build_forward_expand(gf, entry.first); ggml_build_forward_expand(gf, entry.first);
} }
} }
for (const auto& entry : cache_tensor_map) {
if (entry.second != nullptr) {
ggml_build_forward_expand(gf, entry.second);
}
}
prepare_build_in_tensor_after(gf); prepare_build_in_tensor_after(gf);
return gf; return gf;
} }
@ -2057,6 +2078,12 @@ public:
runner_ctx.circular_y_enabled = circular_y_enabled; runner_ctx.circular_y_enabled = circular_y_enabled;
runner_ctx.weight_adapter = weight_adapter; runner_ctx.weight_adapter = weight_adapter;
runner_ctx.debug_tensors = &debug_tensors; runner_ctx.debug_tensors = &debug_tensors;
runner_ctx.get_cache_tensor = [this](const std::string& name) {
return this->get_cache_tensor_by_name(name);
};
runner_ctx.cache_tensor = [this](const std::string& name, ggml_tensor* tensor) {
this->cache(name, tensor);
};
return runner_ctx; return runner_ctx;
} }
@ -2156,6 +2183,9 @@ public:
} }
void cache(const std::string name, ggml_tensor* tensor) { void cache(const std::string name, ggml_tensor* tensor) {
if (tensor != nullptr && tensor->view_src != nullptr) {
tensor = ggml_cont(compute_ctx, tensor);
}
cache_tensor_map[name] = tensor; cache_tensor_map[name] = tensor;
} }

View File

@ -130,6 +130,56 @@ namespace LTXVAE {
} }
return conv->forward(ctx, x); return conv->forward(ctx, x);
} }
// Chunked forward: uses feat_map to carry temporal context across chunks.
// feat_map[feat_idx] holds the last `pad` frames of this layer's *padded*
// input stream from the previous chunk; nullptr means first chunk, which
// falls back to repeat-first-frame padding. The cache entry is a contiguous
// copy (not a view) so the large intermediate tensor `x` can be freed by
// GGML after this iteration ends. `chunk_idx` is unused here but kept for
// signature parity with the sibling chunked blocks.
ggml_tensor* forward(GGMLRunnerContext* ctx,
                     ggml_tensor* x,
                     std::vector<ggml_tensor*>& feat_map,
                     int& feat_idx,
                     int chunk_idx,
                     bool causal = true) {
    auto conv = std::dynamic_pointer_cast<Conv3d>(blocks["conv"]);
    // Causal convs pad the full receptive field on the left only;
    // non-causal convs pad symmetrically.
    const int pad       = causal ? (time_kernel_size - 1) : (time_kernel_size - 1) / 2;
    const bool has_slot = feat_idx < (int)feat_map.size();
    ggml_tensor* prev   = has_slot ? feat_map[feat_idx] : nullptr;
    if (pad > 0) {
        ggml_tensor* left_pad;
        if (prev != nullptr) {
            // Context frames cached by the previous chunk.
            left_pad = prev;
        } else {
            // First chunk: replicate the first frame `pad` times.
            auto first_frame = ggml_ext_slice(ctx->ggml_ctx, x, 2, 0, 1);
            left_pad = first_frame;
            for (int i = 1; i < pad; i++) {
                left_pad = ggml_concat(ctx->ggml_ctx, left_pad, first_frame, 2);
            }
        }
        x = ggml_concat(ctx->ggml_ctx, left_pad, x, 2);
        // Cache the last `pad` frames of the padded stream for the next
        // chunk. Slicing after padding also handles chunks shorter than
        // `pad` (e.g. single-frame chunks with pad == 2), which the previous
        // `x->ne[2] >= pad` guard silently skipped, leaving the cache
        // stale/null forever. The contiguous copy avoids keeping the large
        // `x` alive across iterations through a dangling view.
        if (has_slot) {
            auto tail = ggml_ext_slice(ctx->ggml_ctx, x, 2, x->ne[2] - pad, x->ne[2]);
            feat_map[feat_idx] = ggml_cont(ctx->ggml_ctx, tail);
        }
    }
    feat_idx++;
    if (!causal && pad > 0) {
        // Symmetric right padding: replicate the last frame `pad` times.
        auto last_frame = ggml_ext_slice(ctx->ggml_ctx, x, 2, x->ne[2] - 1, x->ne[2]);
        auto right_pad  = last_frame;
        for (int i = 1; i < pad; i++) {
            right_pad = ggml_concat(ctx->ggml_ctx, right_pad, last_frame, 2);
        }
        x = ggml_concat(ctx->ggml_ctx, x, right_pad, 2);
    }
    return conv->forward(ctx, x);
}
}; };
struct PixelNorm3D : public UnaryBlock { struct PixelNorm3D : public UnaryBlock {
@ -225,6 +275,51 @@ namespace LTXVAE {
return ggml_add(ctx->ggml_ctx, h, x); return ggml_add(ctx->ggml_ctx, h, x);
} }
// Chunked variant of the resnet block: both causal convolutions route their
// temporal context through feat_map (see the chunked CausalConv3d forward).
// When timestep_conditioning is set, the learned scale_shift_table plus the
// timestep embedding yields (shift1, scale1, shift2, scale2) modulations.
ggml_tensor* forward(GGMLRunnerContext* ctx,
                     ggml_tensor* x,
                     ggml_tensor* timestep,
                     bool causal,
                     std::vector<ggml_tensor*>& feat_map,
                     int& feat_idx,
                     int chunk_idx) {
    auto norm1 = std::dynamic_pointer_cast<PixelNorm3D>(blocks["norm1"]);
    auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
    auto norm2 = std::dynamic_pointer_cast<PixelNorm3D>(blocks["norm2"]);
    auto conv2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv2"]);

    // mods[0..3] = shift1, scale1, shift2, scale2 (null when unconditioned).
    ggml_tensor* mods[4] = {nullptr, nullptr, nullptr, nullptr};
    if (timestep_conditioning) {
        GGML_ASSERT(timestep != nullptr);
        auto values = ggml_add(ctx->ggml_ctx,
                               params["scale_shift_table"],
                               ggml_reshape_2d(ctx->ggml_ctx, timestep, channels, 4));
        auto chunks = ggml_ext_chunk(ctx->ggml_ctx, values, 4, 1, false);
        for (int i = 0; i < 4; i++) {
            mods[i] = reshape_channel_broadcast(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, chunks[i]));
        }
    }

    // First norm → modulate → SiLU → conv.
    auto out = norm1->forward(ctx, x);
    if (timestep_conditioning) {
        out = apply_scale_shift(ctx->ggml_ctx, out, mods[1], mods[0]);
    }
    out = ggml_silu_inplace(ctx->ggml_ctx, out);
    out = conv1->forward(ctx, out, feat_map, feat_idx, chunk_idx, causal);

    // Second norm → modulate → SiLU → conv.
    out = norm2->forward(ctx, out);
    if (timestep_conditioning) {
        out = apply_scale_shift(ctx->ggml_ctx, out, mods[3], mods[2]);
    }
    out = ggml_silu_inplace(ctx->ggml_ctx, out);
    out = conv2->forward(ctx, out, feat_map, feat_idx, chunk_idx, causal);

    // Residual connection.
    return ggml_add(ctx->ggml_ctx, out, x);
}
}; };
struct UNetMidBlock3D : public GGMLBlock { struct UNetMidBlock3D : public GGMLBlock {
@ -263,6 +358,26 @@ namespace LTXVAE {
} }
return x; return x;
} }
// Chunked mid-block: embed the timestep once, then run every resnet layer,
// threading the shared per-layer temporal cache (feat_map/feat_idx) through.
ggml_tensor* forward(GGMLRunnerContext* ctx,
                     ggml_tensor* x,
                     ggml_tensor* timestep,
                     bool causal,
                     std::vector<ggml_tensor*>& feat_map,
                     int& feat_idx,
                     int chunk_idx) {
    ggml_tensor* timestep_embed = nullptr;
    if (timestep_conditioning) {
        GGML_ASSERT(timestep != nullptr);
        auto time_embedder = std::dynamic_pointer_cast<PixArtAlphaCombinedTimestepSizeEmbeddings>(blocks["time_embedder"]);
        timestep_embed = time_embedder->forward(ctx, timestep);
    }
    for (int layer = 0; layer < num_layers; layer++) {
        auto res_block = std::dynamic_pointer_cast<ResnetBlock3D>(blocks["res_blocks." + std::to_string(layer)]);
        x = res_block->forward(ctx, x, timestep_embed, causal, feat_map, feat_idx, chunk_idx);
    }
    return x;
}
}; };
struct DepthToSpaceUpsample : public GGMLBlock { struct DepthToSpaceUpsample : public GGMLBlock {
@ -314,6 +429,35 @@ namespace LTXVAE {
} }
return x; return x;
} }
// Chunked upsample: the causal conv routes its temporal context through
// feat_map, then depth_to_space_3d expands channels into space/time.
ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* x,
bool causal,
std::vector<ggml_tensor*>& feat_map,
int& feat_idx,
int chunk_idx) {
auto conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv"]);
// Only the first chunk drops the leading frame produced by temporal
// upsampling — presumably later chunks keep all frames so the output
// stream stays contiguous; TODO confirm against the non-chunked path.
bool drop_first = (chunk_idx == 0) && (factor_t > 1);
ggml_tensor* x_in = nullptr;
if (residual) {
// Residual path: rearrange the raw input the same way and repeat it
// along the channel axis (dim 3) so it matches the output shape.
x_in = depth_to_space_3d(ctx->ggml_ctx, x, in_channels / (factor_t * factor_s * factor_s), factor_t, factor_s, drop_first);
int repeat = (factor_t * factor_s * factor_s) / out_channels_reduction_factor;
auto res = x_in;
for (int i = 1; i < repeat; i++) {
res = ggml_concat(ctx->ggml_ctx, res, x_in, 3);
}
x_in = res;
}
// Conv with per-layer temporal cache, then pixel-shuffle-style expansion.
x = conv->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal);
x = depth_to_space_3d(ctx->ggml_ctx, x, get_output_channels(), factor_t, factor_s, drop_first);
if (residual) {
x = ggml_add(ctx->ggml_ctx, x, x_in);
}
return x;
}
}; };
struct SpaceToDepthDownsample : public GGMLBlock { struct SpaceToDepthDownsample : public GGMLBlock {
@ -735,6 +879,61 @@ namespace LTXVAE {
x = conv_out->forward(ctx, x, causal_decoder); x = conv_out->forward(ctx, x, causal_decoder);
return x; return x;
} }
// Process a single latent frame through the complete decoder (conv_in → up_blocks
// → final layers), using feat_map to carry per-layer causal context from the
// previous frame. Designed for tiled temporal decode: each iteration receives
// 1 latent frame so that intermediate tensors can be freed between iterations.
// feat_idx is advanced by every conv layer in traversal order, so the traversal
// must be identical for every chunk for the cache slots to line up.
ggml_tensor* forward_tiled_frame(GGMLRunnerContext* ctx,
ggml_tensor* x,
ggml_tensor* timestep,
std::vector<ggml_tensor*>& feat_map,
int& feat_idx,
int chunk_idx) {
auto conv_in = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv_in"]);
auto conv_norm_out = std::dynamic_pointer_cast<PixelNorm3D>(blocks["conv_norm_out"]);
auto conv_out = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv_out"]);
// Scale the conditioning timestep by the model's learned multiplier,
// read back from the backend as a scalar.
ggml_tensor* scaled_timestep = timestep;
if (timestep_conditioning && timestep != nullptr) {
auto multiplier = ggml_ext_backend_tensor_get_f32(params["timestep_scale_multiplier"]);
scaled_timestep = ggml_ext_scale(ctx->ggml_ctx, timestep, multiplier);
}
// conv_in with feat_map for left temporal context
x = conv_in->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder);
// up_blocks: walk indices sequentially until the first missing one.
// NOTE(review): assumes every up_block is either a UNetMidBlock3D or a
// DepthToSpaceUpsample — any third type would leave `upsample` null and
// crash; confirm against the decoder's constructor.
int block_idx = 0;
while (blocks.find("up_blocks." + std::to_string(block_idx)) != blocks.end()) {
auto mid_block = std::dynamic_pointer_cast<UNetMidBlock3D>(blocks["up_blocks." + std::to_string(block_idx)]);
if (mid_block) {
x = mid_block->forward(ctx, x, scaled_timestep, causal_decoder,
feat_map, feat_idx, chunk_idx);
} else {
auto upsample = std::dynamic_pointer_cast<DepthToSpaceUpsample>(
blocks["up_blocks." + std::to_string(block_idx)]);
x = upsample->forward(ctx, x, causal_decoder,
feat_map, feat_idx, chunk_idx);
}
block_idx++;
}
// Final head: norm → (optional) adaptive scale/shift → SiLU → conv_out.
x = conv_norm_out->forward(ctx, x);
if (timestep_conditioning) {
auto last_time_embedder = std::dynamic_pointer_cast<PixArtAlphaCombinedTimestepSizeEmbeddings>(blocks["last_time_embedder"]);
auto timestep_embed = last_time_embedder->forward(ctx, scaled_timestep);
auto [shift, scale] = get_shift_scale(ctx->ggml_ctx,
params["last_scale_shift_table"],
timestep_embed,
hidden_channels,
2);
x = apply_scale_shift(ctx->ggml_ctx, x, scale, shift);
}
x = ggml_silu_inplace(ctx->ggml_ctx, x);
x = conv_out->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder);
return x;
}
}; };
struct VideoVAE : public GGMLBlock { struct VideoVAE : public GGMLBlock {
@ -779,6 +978,39 @@ namespace LTXVAE {
return out; return out;
} }
// Tiled temporal decode: each latent frame is processed through the COMPLETE
// decoder individually. Per-layer causal context is passed via feat_map
// (contiguous copies, not views) so that each iteration's large intermediate
// tensors can be freed by GGML before the next iteration starts.
// Falls back to the regular full decode when there is at most one latent frame.
ggml_tensor* decode_tiled(GGMLRunnerContext* ctx,
ggml_tensor* z,
ggml_tensor* timestep) {
auto decoder = std::dynamic_pointer_cast<Decoder>(blocks["decoder"]);
auto processor = std::dynamic_pointer_cast<PerChannelStatistics>(blocks["per_channel_statistics"]);
// Undo the per-channel latent normalization before decoding.
auto latents = processor->un_normalize(ctx, z);
const int64_t T = z->ne[2];
if (T <= 1) {
// Single frame: no temporal context to carry across chunks.
auto out = decoder->forward(ctx, latents, timestep);
return WAN::WanVAE::unpatchify(ctx->ggml_ctx, out, patch_size, 1);
}
// feat_map holds ggml_tensor* nodes (contiguous copies at each conv layer).
// 128 slots is generous enough for any supported decoder configuration.
// NOTE(review): if a decoder ever carries more than 128 temporally-cached
// conv layers the extras silently lose their context — consider sizing
// this from the decoder instead of a constant.
std::vector<ggml_tensor*> feat_map(128, nullptr);
ggml_tensor* out = nullptr;
for (int i = 0; i < (int)T; i++) {
// feat_idx restarts at 0 every frame so cache slots line up per layer.
int feat_idx = 0;
auto z_i = ggml_ext_slice(ctx->ggml_ctx, latents, 2, i, i + 1);
auto out_i = decoder->forward_tiled_frame(ctx, z_i, timestep,
feat_map, feat_idx, i);
// Stitch decoded frames back together along the time axis (dim 2).
out = (out == nullptr) ? out_i : ggml_concat(ctx->ggml_ctx, out, out_i, 2);
}
return WAN::WanVAE::unpatchify(ctx->ggml_ctx, out, patch_size, 1);
}
ggml_tensor* encode(GGMLRunnerContext* ctx, ggml_tensor* encode(GGMLRunnerContext* ctx,
ggml_tensor* x) { ggml_tensor* x) {
GGML_ASSERT(!decode_only); GGML_ASSERT(!decode_only);
@ -797,6 +1029,7 @@ namespace LTXVAE {
struct LTXVideoVAE : public VAE { struct LTXVideoVAE : public VAE {
bool decode_only; bool decode_only;
bool temporal_tiling_enabled = false;
int ltx_vae_version; int ltx_vae_version;
bool timestep_conditioning; bool timestep_conditioning;
int patch_size; int patch_size;
@ -829,30 +1062,31 @@ struct LTXVideoVAE : public VAE {
return "ltx_video_vae"; return "ltx_video_vae";
} }
// Toggle temporal (frame-by-frame) tiled decode; build_graph reads this flag
// to choose decode_tiled over the regular decode path.
void set_temporal_tiling_enabled(bool enabled) override {
temporal_tiling_enabled = enabled;
}
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) override { void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) override {
vae.get_param_tensors(tensors, prefix); vae.get_param_tensors(tensors, prefix);
} }
ggml_cgraph* build_graph(const sd::Tensor<float>& z_tensor, bool decode_graph) { ggml_cgraph* build_graph(const sd::Tensor<float>& z_tensor, bool decode_graph) {
LOG_DEBUG("ltx_video_vae build_graph input %dx%dx%dx%d", ggml_cgraph* gf = new_graph_custom(20480);
(int)z_tensor.shape()[0],
(int)z_tensor.shape()[1],
(int)z_tensor.shape()[2],
(int)z_tensor.shape()[3]);
ggml_cgraph* gf = ggml_new_graph(compute_ctx);
ggml_tensor* z = make_input(z_tensor); ggml_tensor* z = make_input(z_tensor);
ggml_tensor* timestep = nullptr; ggml_tensor* timestep = nullptr;
if (timestep_conditioning) { if (timestep_conditioning) {
timestep = make_input(decode_timestep_tensor); timestep = make_input(decode_timestep_tensor);
} }
auto runner_ctx = get_context(); auto runner_ctx = get_context();
ggml_tensor* out = decode_graph ? vae.decode(&runner_ctx, z, timestep) : vae.encode(&runner_ctx, z); ggml_tensor* out;
LOG_DEBUG("ltx_video_vae build_graph output ne=[%lld,%lld,%lld,%lld]", bool use_tiled = decode_graph && temporal_tiling_enabled &&
(long long)out->ne[0], z_tensor.dim() == 5 && z_tensor.shape()[2] > 1;
(long long)out->ne[1], if (use_tiled) {
(long long)out->ne[2], out = vae.decode_tiled(&runner_ctx, z, timestep);
(long long)out->ne[3]); } else {
out = decode_graph ? vae.decode(&runner_ctx, z, timestep) : vae.encode(&runner_ctx, z);
}
ggml_build_forward_expand(gf, out); ggml_build_forward_expand(gf, out);
return gf; return gf;
@ -889,12 +1123,6 @@ struct LTXVideoVAE : public VAE {
if (result.empty()) { if (result.empty()) {
return {}; return {};
} }
LOG_DEBUG("ltx_video_vae host output shape=[%lld,%lld,%lld,%lld] dim=%lld",
(long long)(result.shape().size() > 0 ? result.shape()[0] : 0),
(long long)(result.shape().size() > 1 ? result.shape()[1] : 0),
(long long)(result.shape().size() > 2 ? result.shape()[2] : 0),
(long long)(result.shape().size() > 3 ? result.shape()[3] : 0),
(long long)result.dim());
return result; return result;
} }

View File

@ -924,12 +924,12 @@ namespace LTXV {
if (count < 0) { if (count < 0) {
count = coeff - start; count = coeff - start;
} }
auto t = ggml_reshape_3d(ctx->ggml_ctx, timestep, dim, coeff, timestep->ne[1]); auto t = ggml_reshape_3d(ctx->ggml_ctx, timestep, dim, coeff, timestep->ne[1]);
auto s = ggml_reshape_3d(ctx->ggml_ctx, table, dim, coeff, 1); auto s = ggml_reshape_3d(ctx->ggml_ctx, table, dim, coeff, 1);
auto e = ggml_new_tensor_3d(ctx->ggml_ctx, timestep->type, dim, coeff, timestep->ne[1]); auto e = ggml_new_tensor_3d(ctx->ggml_ctx, timestep->type, dim, coeff, timestep->ne[1]);
t = ggml_repeat(ctx->ggml_ctx, t, e); t = ggml_repeat(ctx->ggml_ctx, t, e);
s = ggml_repeat(ctx->ggml_ctx, s, e); s = ggml_repeat(ctx->ggml_ctx, s, e);
auto out = ggml_add(ctx->ggml_ctx, s, t); auto out = ggml_add(ctx->ggml_ctx, s, t);
auto chunks = ggml_ext_chunk(ctx->ggml_ctx, out, static_cast<int>(coeff), 1); auto chunks = ggml_ext_chunk(ctx->ggml_ctx, out, static_cast<int>(coeff), 1);
return std::vector<ggml_tensor*>(chunks.begin() + start, chunks.begin() + start + count); return std::vector<ggml_tensor*>(chunks.begin() + start, chunks.begin() + start + count);
} }

View File

@ -144,7 +144,7 @@ public:
bool apply_lora_immediately = false; bool apply_lora_immediately = false;
std::string taesd_path; std::string taesd_path;
sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0, 0}; sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0, 0};
bool offload_params_to_cpu = false; bool offload_params_to_cpu = false;
bool use_pmid = false; bool use_pmid = false;
@ -1520,9 +1520,11 @@ public:
sd::Tensor<float> decoded; sd::Tensor<float> decoded;
bool is_video = preview_latent_tensor_is_video(latents); bool is_video = preview_latent_tensor_is_video(latents);
if (preview_vae) { if (preview_vae) {
preview_vae->set_temporal_tiling_enabled(vae_tiling_params.temporal_tiling);
vae_latents = preview_vae->diffusion_to_vae_latents(latents); vae_latents = preview_vae->diffusion_to_vae_latents(latents);
decoded = preview_vae->decode(n_threads, vae_latents, vae_tiling_params, is_video, circular_x, circular_y, true); decoded = preview_vae->decode(n_threads, vae_latents, vae_tiling_params, is_video, circular_x, circular_y, true);
} else { } else {
first_stage_model->set_temporal_tiling_enabled(vae_tiling_params.temporal_tiling);
vae_latents = first_stage_model->diffusion_to_vae_latents(latents); vae_latents = first_stage_model->diffusion_to_vae_latents(latents);
decoded = first_stage_model->decode(n_threads, vae_latents, vae_tiling_params, is_video, circular_x, circular_y, true); decoded = first_stage_model->decode(n_threads, vae_latents, vae_tiling_params, is_video, circular_x, circular_y, true);
} }
@ -1963,6 +1965,7 @@ public:
sd::Tensor<float> decode_first_stage(const sd::Tensor<float>& x, bool decode_video = false) { sd::Tensor<float> decode_first_stage(const sd::Tensor<float>& x, bool decode_video = false) {
auto latents = first_stage_model->diffusion_to_vae_latents(x); auto latents = first_stage_model->diffusion_to_vae_latents(x);
first_stage_model->set_temporal_tiling_enabled(vae_tiling_params.temporal_tiling);
return first_stage_model->decode(n_threads, latents, vae_tiling_params, decode_video, circular_x, circular_y); return first_stage_model->decode(n_threads, latents, vae_tiling_params, decode_video, circular_x, circular_y);
} }
@ -2398,7 +2401,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) {
sd_img_gen_params->batch_count = 1; sd_img_gen_params->batch_count = 1;
sd_img_gen_params->control_strength = 0.9f; sd_img_gen_params->control_strength = 0.9f;
sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f}; sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f};
sd_img_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f}; sd_img_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f};
sd_cache_params_init(&sd_img_gen_params->cache); sd_cache_params_init(&sd_img_gen_params->cache);
sd_hires_params_init(&sd_img_gen_params->hires); sd_hires_params_init(&sd_img_gen_params->hires);
} }
@ -2427,7 +2430,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
"increase_ref_index: %s\n" "increase_ref_index: %s\n"
"control_strength: %.2f\n" "control_strength: %.2f\n"
"photo maker: {style_strength = %.2f, id_images_count = %d, id_embed_path = %s}\n" "photo maker: {style_strength = %.2f, id_images_count = %d, id_embed_path = %s}\n"
"VAE tiling: %s\n" "VAE tiling: %s (temporal=%s)\n"
"hires: {enabled=%s, upscaler=%s, model_path=%s, scale=%.2f, target=%dx%d, steps=%d, denoising_strength=%.2f}\n", "hires: {enabled=%s, upscaler=%s, model_path=%s, scale=%.2f, target=%dx%d, steps=%d, denoising_strength=%.2f}\n",
SAFE_STR(sd_img_gen_params->prompt), SAFE_STR(sd_img_gen_params->prompt),
SAFE_STR(sd_img_gen_params->negative_prompt), SAFE_STR(sd_img_gen_params->negative_prompt),
@ -2446,6 +2449,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
sd_img_gen_params->pm_params.id_images_count, sd_img_gen_params->pm_params.id_images_count,
SAFE_STR(sd_img_gen_params->pm_params.id_embed_path), SAFE_STR(sd_img_gen_params->pm_params.id_embed_path),
BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled), BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled),
BOOL_STR(sd_img_gen_params->vae_tiling_params.temporal_tiling),
BOOL_STR(sd_img_gen_params->hires.enabled), BOOL_STR(sd_img_gen_params->hires.enabled),
sd_hires_upscaler_name(sd_img_gen_params->hires.upscaler), sd_hires_upscaler_name(sd_img_gen_params->hires.upscaler),
SAFE_STR(sd_img_gen_params->hires.model_path), SAFE_STR(sd_img_gen_params->hires.model_path),
@ -2483,7 +2487,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
sd_vid_gen_params->fps = 16; sd_vid_gen_params->fps = 16;
sd_vid_gen_params->moe_boundary = 0.875f; sd_vid_gen_params->moe_boundary = 0.875f;
sd_vid_gen_params->vace_strength = 1.f; sd_vid_gen_params->vace_strength = 1.f;
sd_vid_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f}; sd_vid_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f};
sd_cache_params_init(&sd_vid_gen_params->cache); sd_cache_params_init(&sd_vid_gen_params->cache);
} }

View File

@ -214,6 +214,7 @@ public:
virtual sd::Tensor<float> vae_to_diffusion_latents(const sd::Tensor<float>& latents) = 0; virtual sd::Tensor<float> vae_to_diffusion_latents(const sd::Tensor<float>& latents) = 0;
virtual void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) = 0; virtual void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) = 0;
virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); }; virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); };
virtual void set_temporal_tiling_enabled(bool enabled) { SD_UNUSED(enabled); };
}; };
struct FakeVAE : public VAE { struct FakeVAE : public VAE {