Feat: Temporal tile custom size with overlap (#1510)

* Temporal tile size + overlap

* add --extra-tiling-args support

---------

Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
stduhpf 2026-05-21 17:44:12 +02:00 committed by GitHub
parent 2e3514625a
commit adaa599a3b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 210 additions and 39 deletions

View File

@ -107,6 +107,8 @@ Generation Options:
--extra-sample-args <string> extra sampler/scheduler args, key=value list. lcm supports noise_clip_std,
noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift,
stretch, terminal; euler_ge supports gamma
--extra-tiling-args <string> extra VAE tiling args, key=value list. LTX video VAE supports
temporal_tile_frames (default: 4), temporal_tile_overlap (default: 1)
-H, --height <int> image height, in pixel space (default: 512)
-W, --width <int> image width, in pixel space (default: 512)
--steps <int> number of sample steps (default: 20)

View File

@ -835,6 +835,10 @@ ArgOptions SDGenerationParams::get_options() {
"--extra-sample-args",
"extra sampler/scheduler args, key=value list. lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma",
&extra_sample_args},
{"",
"--extra-tiling-args",
"extra VAE tiling args, key=value list. LTX video VAE supports temporal_tile_frames (default: 4), temporal_tile_overlap (default: 1)",
&extra_tiling_args},
};
options.int_options = {
@ -1780,6 +1784,9 @@ bool SDGenerationParams::from_json_str(
if (tiling_json.contains("rel_size_y") && tiling_json["rel_size_y"].is_number()) {
vae_tiling_params.rel_size_y = tiling_json["rel_size_y"];
}
if (tiling_json.contains("extra_tiling_args") && tiling_json["extra_tiling_args"].is_string()) {
extra_tiling_args = tiling_json["extra_tiling_args"].get<std::string>();
}
}
if (!parse_lora_json_field(j, lora_path_resolver, lora_map, high_noise_lora_map)) {
@ -2002,6 +2009,8 @@ bool SDGenerationParams::initialize_cache_params() {
}
bool SDGenerationParams::resolve(const std::string& lora_model_dir, const std::string& hires_upscalers_dir, bool strict) {
vae_tiling_params.extra_tiling_args = extra_tiling_args.empty() ? nullptr : extra_tiling_args.c_str();
if (high_noise_sample_params.sample_steps <= 0) {
high_noise_sample_params.sample_steps = -1;
}
@ -2188,6 +2197,7 @@ sd_img_gen_params_t SDGenerationParams::to_sd_img_gen_params_t() {
sample_params.custom_sigmas_count = static_cast<int>(custom_sigmas.size());
sample_params.extra_sample_args = extra_sample_args.empty() ? nullptr : extra_sample_args.c_str();
high_noise_sample_params.extra_sample_args = high_noise_extra_sample_args.empty() ? nullptr : high_noise_extra_sample_args.c_str();
vae_tiling_params.extra_tiling_args = extra_tiling_args.empty() ? nullptr : extra_tiling_args.c_str();
cache_params.scm_mask = scm_mask.empty() ? nullptr : scm_mask.c_str();
sd_pm_params_t pm_params = {
@ -2261,6 +2271,7 @@ sd_vid_gen_params_t SDGenerationParams::to_sd_vid_gen_params_t() {
sample_params.custom_sigmas_count = static_cast<int>(custom_sigmas.size());
sample_params.extra_sample_args = extra_sample_args.empty() ? nullptr : extra_sample_args.c_str();
high_noise_sample_params.extra_sample_args = high_noise_extra_sample_args.empty() ? nullptr : high_noise_extra_sample_args.c_str();
vae_tiling_params.extra_tiling_args = extra_tiling_args.empty() ? nullptr : extra_tiling_args.c_str();
cache_params.scm_mask = scm_mask.empty() ? nullptr : scm_mask.c_str();
params.loras = lora_vec.empty() ? nullptr : lora_vec.data();
@ -2386,7 +2397,8 @@ std::string SDGenerationParams::to_string() const {
<< vae_tiling_params.tile_size_y << ", "
<< vae_tiling_params.target_overlap << ", "
<< vae_tiling_params.rel_size_x << ", "
<< vae_tiling_params.rel_size_y << " },\n"
<< vae_tiling_params.rel_size_y << ", "
<< "\"" << extra_tiling_args << "\" },\n"
<< "}";
return oss.str();
}
@ -2565,14 +2577,18 @@ std::string build_sdcpp_image_metadata_json(const SDContextParams& ctx_params,
};
}
if (gen_params.vae_tiling_params.enabled) {
if (gen_params.vae_tiling_params.enabled ||
gen_params.vae_tiling_params.temporal_tiling ||
!gen_params.extra_tiling_args.empty()) {
root["vae_tiling"] = {
{"enabled", gen_params.vae_tiling_params.enabled},
{"temporal_tiling", gen_params.vae_tiling_params.temporal_tiling},
{"tile_size_x", gen_params.vae_tiling_params.tile_size_x},
{"tile_size_y", gen_params.vae_tiling_params.tile_size_y},
{"target_overlap", gen_params.vae_tiling_params.target_overlap},
{"rel_size_x", gen_params.vae_tiling_params.rel_size_x},
{"rel_size_y", gen_params.vae_tiling_params.rel_size_y},
{"extra_tiling_args", gen_params.extra_tiling_args},
};
}

View File

@ -189,7 +189,8 @@ struct SDGenerationParams {
int video_frames = 1;
int fps = 16;
float vace_strength = 1.f;
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f};
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr};
std::string extra_tiling_args;
std::string pm_id_images_dir;
std::string pm_id_embed_path;

View File

@ -209,6 +209,8 @@ Default Generation Options:
--extra-sample-args <string> extra sampler/scheduler args, key=value list. lcm supports noise_clip_std,
noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift,
stretch, terminal; euler_ge supports gamma
--extra-tiling-args <string> extra VAE tiling args, key=value list. LTX video VAE supports
temporal_tile_frames (default: 4), temporal_tile_overlap (default: 1)
-H, --height <int> image height, in pixel space (default: 512)
-W, --width <int> image width, in pixel space (default: 512)
--steps <int> number of sample steps (default: 20)
@ -264,6 +266,7 @@ Default Generation Options:
--disable-auto-resize-ref-image disable auto resize of ref images
--disable-image-metadata do not embed generation metadata on image files
--vae-tiling process vae in tiles to reduce memory usage
--temporal-tiling enable temporal tiling for LTX video VAE decode
--hires enable highres fix
-s, --seed RNG seed (default: 42, use random seed for < 0)
--sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m,

View File

@ -504,11 +504,13 @@ Shared default fields used by both `img_gen` and `vid_gen`:
| `sample_params.guidance.slg.scale` | `number` |
| `vae_tiling_params` | `object` |
| `vae_tiling_params.enabled` | `boolean` |
| `vae_tiling_params.temporal_tiling` | `boolean` |
| `vae_tiling_params.tile_size_x` | `integer` |
| `vae_tiling_params.tile_size_y` | `integer` |
| `vae_tiling_params.target_overlap` | `number` |
| `vae_tiling_params.rel_size_x` | `number` |
| `vae_tiling_params.rel_size_y` | `number` |
| `vae_tiling_params.extra_tiling_args` | `string` |
| `cache_mode` | `string` |
| `cache_option` | `string` |
| `scm_mask` | `string` |
@ -516,6 +518,8 @@ Shared default fields used by both `img_gen` and `vid_gen`:
| `output_format` | `string` |
| `output_compression` | `integer` |
`vae_tiling_params.extra_tiling_args` accepts a key=value list. For LTX video VAE temporal tiling, `temporal_tile_frames` defaults to `4` and `temporal_tile_overlap` defaults to `1`.
`img_gen`-specific default fields:
| Field | Type |
@ -692,11 +696,13 @@ Example:
"vae_tiling_params": {
"enabled": false,
"temporal_tiling": false,
"tile_size_x": 0,
"tile_size_y": 0,
"target_overlap": 0.5,
"rel_size_x": 0.0,
"rel_size_y": 0.0
"rel_size_y": 0.0,
"extra_tiling_args": ""
},
"cache_mode": "disabled",
@ -804,6 +810,14 @@ Other native fields:
| `hires.custom_sigmas` | `array<number>` |
| `hires.upscale_tile_size` | `integer` |
| `vae_tiling_params` | `object` |
| `vae_tiling_params.enabled` | `boolean` |
| `vae_tiling_params.temporal_tiling` | `boolean` |
| `vae_tiling_params.tile_size_x` | `integer` |
| `vae_tiling_params.tile_size_y` | `integer` |
| `vae_tiling_params.target_overlap` | `number` |
| `vae_tiling_params.rel_size_x` | `number` |
| `vae_tiling_params.rel_size_y` | `number` |
| `vae_tiling_params.extra_tiling_args` | `string` |
| `cache_mode` | `string` |
| `cache_option` | `string` |
| `scm_mask` | `string` |
@ -1012,11 +1026,13 @@ Example:
"vae_tiling_params": {
"enabled": false,
"temporal_tiling": false,
"tile_size_x": 0,
"tile_size_y": 0,
"target_overlap": 0.5,
"rel_size_x": 0.0,
"rel_size_y": 0.0
"rel_size_y": 0.0,
"extra_tiling_args": ""
},
"cache_mode": "disabled",
@ -1134,6 +1150,14 @@ Other native fields:
| Field | Type |
| --- | --- |
| `vae_tiling_params` | `object` |
| `vae_tiling_params.enabled` | `boolean` |
| `vae_tiling_params.temporal_tiling` | `boolean` |
| `vae_tiling_params.tile_size_x` | `integer` |
| `vae_tiling_params.tile_size_y` | `integer` |
| `vae_tiling_params.target_overlap` | `number` |
| `vae_tiling_params.rel_size_x` | `number` |
| `vae_tiling_params.rel_size_y` | `number` |
| `vae_tiling_params.extra_tiling_args` | `string` |
| `cache_mode` | `string` |
| `cache_option` | `string` |
| `scm_mask` | `string` |

View File

@ -56,11 +56,13 @@ static const char* capability_sample_method_name(enum sample_method_t sample_met
static json make_vae_tiling_json(const sd_tiling_params_t& params) {
return {
{"enabled", params.enabled},
{"temporal_tiling", params.temporal_tiling},
{"tile_size_x", params.tile_size_x},
{"tile_size_y", params.tile_size_y},
{"target_overlap", params.target_overlap},
{"rel_size_x", params.rel_size_x},
{"rel_size_y", params.rel_size_y},
{"extra_tiling_args", params.extra_tiling_args ? params.extra_tiling_args : ""},
};
}

View File

@ -160,6 +160,7 @@ typedef struct {
float target_overlap;
float rel_size_x;
float rel_size_y;
const char* extra_tiling_args;
} sd_tiling_params_t;
typedef struct {

View File

@ -1,6 +1,7 @@
#ifndef __SD_LTX_VAE_HPP__
#define __SD_LTX_VAE_HPP__
#include <algorithm>
#include <fstream>
#include <memory>
#include <string>
@ -143,16 +144,25 @@ namespace LTXVAE {
std::vector<ggml_tensor*>& feat_map,
int& feat_idx,
int chunk_idx,
bool causal = true) {
bool causal = true,
int temporal_pad = 0) {
auto conv = std::dynamic_pointer_cast<Conv3d>(blocks["conv"]);
const int pad = causal ? (time_kernel_size - 1) : (time_kernel_size - 1) / 2;
ggml_tensor* prev = (feat_idx < (int)feat_map.size()) ? feat_map[feat_idx] : nullptr;
GGML_ASSERT(x->ne[2] >= temporal_pad);
int end_idx = x->ne[2] - temporal_pad;
int start_idx = std::max(end_idx - pad, 0);
// Save a contiguous copy of the last `pad` frames so the large `x`
// tensor is not kept alive across iterations by a dangling view.
if (feat_idx < (int)feat_map.size() && pad > 0 && x->ne[2] >= pad) {
auto slice = ggml_ext_slice(ctx->ggml_ctx, x, 2, x->ne[2] - pad, x->ne[2]);
if (feat_idx < (int)feat_map.size() && end_idx - start_idx > 0) {
GGML_ASSERT(start_idx >= 0);
GGML_ASSERT(end_idx > 0);
auto slice = ggml_ext_slice(ctx->ggml_ctx, x, 2, start_idx, end_idx);
feat_map[feat_idx] = ggml_cont(ctx->ggml_ctx, slice);
}
feat_idx++;
@ -284,7 +294,8 @@ namespace LTXVAE {
bool causal,
std::vector<ggml_tensor*>& feat_map,
int& feat_idx,
int chunk_idx) {
int chunk_idx,
int temporal_pad = 0) {
auto norm1 = std::dynamic_pointer_cast<PixelNorm3D>(blocks["norm1"]);
auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
auto norm2 = std::dynamic_pointer_cast<PixelNorm3D>(blocks["norm2"]);
@ -311,14 +322,14 @@ namespace LTXVAE {
h = apply_scale_shift(ctx->ggml_ctx, h, scale1, shift1);
}
h = ggml_silu_inplace(ctx->ggml_ctx, h);
h = conv1->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal);
h = conv1->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal, temporal_pad);
h = norm2->forward(ctx, h);
if (timestep_conditioning) {
h = apply_scale_shift(ctx->ggml_ctx, h, scale2, shift2);
}
h = ggml_silu_inplace(ctx->ggml_ctx, h);
h = conv2->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal);
h = conv2->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal, temporal_pad);
return ggml_add(ctx->ggml_ctx, h, x);
}
@ -367,7 +378,8 @@ namespace LTXVAE {
bool causal,
std::vector<ggml_tensor*>& feat_map,
int& feat_idx,
int chunk_idx) {
int chunk_idx,
int temporal_pad = 0) {
ggml_tensor* timestep_embed = nullptr;
if (timestep_conditioning) {
GGML_ASSERT(timestep != nullptr);
@ -376,7 +388,7 @@ namespace LTXVAE {
}
for (int i = 0; i < num_layers; i++) {
auto resnet = std::dynamic_pointer_cast<ResnetBlock3D>(blocks["res_blocks." + std::to_string(i)]);
x = resnet->forward(ctx, x, timestep_embed, causal, feat_map, feat_idx, chunk_idx);
x = resnet->forward(ctx, x, timestep_embed, causal, feat_map, feat_idx, chunk_idx, temporal_pad);
}
return x;
}
@ -437,7 +449,8 @@ namespace LTXVAE {
bool causal,
std::vector<ggml_tensor*>& feat_map,
int& feat_idx,
int chunk_idx) {
int chunk_idx,
int temporal_pad = 0) {
auto conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv"]);
bool drop_first = (chunk_idx == 0) && (factor_t > 1);
@ -453,7 +466,7 @@ namespace LTXVAE {
x_in = res;
}
x = conv->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal);
x = conv->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal, temporal_pad);
x = depth_to_space_3d(ctx->ggml_ctx, x, get_output_channels(), factor_t, factor_s, drop_first);
if (residual) {
x = ggml_add(ctx->ggml_ctx, x, x_in);
@ -986,7 +999,8 @@ namespace LTXVAE {
ggml_tensor* timestep,
std::vector<ggml_tensor*>& feat_map,
int& feat_idx,
int chunk_idx) {
int chunk_idx,
int& temporal_pad) {
auto conv_in = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv_in"]);
auto conv_norm_out = std::dynamic_pointer_cast<PixelNorm3D>(blocks["conv_norm_out"]);
auto conv_out = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv_out"]);
@ -998,7 +1012,7 @@ namespace LTXVAE {
}
// conv_in with feat_map for left temporal context
x = conv_in->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder);
x = conv_in->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder, temporal_pad);
// up_blocks
int block_idx = 0;
@ -1006,12 +1020,13 @@ namespace LTXVAE {
auto mid_block = std::dynamic_pointer_cast<UNetMidBlock3D>(blocks["up_blocks." + std::to_string(block_idx)]);
if (mid_block) {
x = mid_block->forward(ctx, x, scaled_timestep, causal_decoder,
feat_map, feat_idx, chunk_idx);
feat_map, feat_idx, chunk_idx, temporal_pad);
} else {
auto upsample = std::dynamic_pointer_cast<DepthToSpaceUpsample>(
blocks["up_blocks." + std::to_string(block_idx)]);
x = upsample->forward(ctx, x, causal_decoder,
feat_map, feat_idx, chunk_idx);
feat_map, feat_idx, chunk_idx, temporal_pad);
temporal_pad *= upsample->factor_t;
}
block_idx++;
}
@ -1028,7 +1043,7 @@ namespace LTXVAE {
x = apply_scale_shift(ctx->ggml_ctx, x, scale, shift);
}
x = ggml_silu_inplace(ctx->ggml_ctx, x);
x = conv_out->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder);
x = conv_out->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder, temporal_pad);
return x;
}
};
@ -1084,7 +1099,9 @@ namespace LTXVAE {
// tensors can be freed by GGML before the next iteration starts.
ggml_tensor* decode_tiled(GGMLRunnerContext* ctx,
ggml_tensor* z,
ggml_tensor* timestep) {
ggml_tensor* timestep,
int temporal_window_size = 1,
int temporal_tile_overlap = 0) {
auto decoder = std::dynamic_pointer_cast<Decoder>(blocks["decoder"]);
auto processor = std::dynamic_pointer_cast<PerChannelStatistics>(blocks["per_channel_statistics"]);
auto latents = processor->un_normalize(ctx, z);
@ -1099,13 +1116,43 @@ namespace LTXVAE {
// 128 slots is generous enough for any supported decoder configuration.
std::vector<ggml_tensor*> feat_map(128, nullptr);
// Ensure window size is at least 1
int window = std::max(1, temporal_window_size);
int overlap = std::max(0, temporal_tile_overlap);
if (overlap >= window) {
LOG_WARN("temporal_tile_overlap (%d) is greater than or equal to temporal_tile_frames (%d), adjusting values to avoid empty decode windows",
overlap, window);
overlap = window - 1;
}
LOG_DEBUG("Using temporal tiling: temporal_tile_frames = %d, temporal_tile_overlap = %d, total frames = %d, resulting in %d tiles",
window,
overlap,
(int)T,
(T + window - overlap - 1) / (window - overlap));
ggml_tensor* out = nullptr;
for (int i = 0; i < (int)T; i++) {
for (int i = 0; i < (int)T - overlap; i += (window - overlap)) {
int feat_idx = 0;
auto z_i = ggml_ext_slice(ctx->ggml_ctx, latents, 2, i, i + 1);
auto out_i = decoder->forward_tiled_frame(ctx, z_i, timestep,
feat_map, feat_idx, i);
out = (out == nullptr) ? out_i : ggml_concat(ctx->ggml_ctx, out, out_i, 2);
// Calculate the end index for the current temporal chunk
int end_i = std::min((int)T, i + window);
if (end_i >= (int)T) {
overlap = 0; // avoid overlap issues in the last chunk
}
int chunk_overlap = overlap; // modified by forward_tiled_frame temporal inflation
auto z_chunk = ggml_ext_slice(ctx->ggml_ctx, latents, 2, i, end_i);
auto out_chunk = decoder->forward_tiled_frame(ctx, z_chunk, timestep,
feat_map, feat_idx, i, chunk_overlap);
// discard overlap frames if it's not the final chunk
if (overlap > 0 && end_i < (int)T) {
out_chunk = ggml_ext_slice(ctx->ggml_ctx, out_chunk, 2, 0, out_chunk->ne[2] - chunk_overlap);
}
out = (out == nullptr) ? out_chunk : ggml_concat(ctx->ggml_ctx, out, out_chunk, 2);
}
return WAN::WanVAE::unpatchify(ctx->ggml_ctx, out, patch_size, 1);
@ -1140,8 +1187,13 @@ namespace LTXVAE {
} // namespace LTXVAE
struct LTXVideoVAE : public VAE {
static constexpr int DEFAULT_TEMPORAL_TILE_FRAMES = 4;
static constexpr int DEFAULT_TEMPORAL_TILE_OVERLAP = 1;
bool decode_only;
bool temporal_tiling_enabled = false;
int temporal_tile_frames = DEFAULT_TEMPORAL_TILE_FRAMES;
int temporal_tile_overlap = DEFAULT_TEMPORAL_TILE_OVERLAP;
int ltx_vae_version;
bool timestep_conditioning;
int patch_size;
@ -1178,6 +1230,68 @@ struct LTXVideoVAE : public VAE {
temporal_tiling_enabled = enabled;
}
static std::string trim_tiling_arg(std::string value) {
const char* whitespace = " \t\r\n";
size_t begin = value.find_first_not_of(whitespace);
if (begin == std::string::npos) {
return "";
}
size_t end = value.find_last_not_of(whitespace);
return value.substr(begin, end - begin + 1);
}
static bool parse_tiling_int(const std::string& value, int& parsed) {
try {
size_t consumed = 0;
parsed = std::stoi(value, &consumed);
return trim_tiling_arg(value.substr(consumed)).empty();
} catch (...) {
return false;
}
}
void set_tiling_params(const sd_tiling_params_t& params) override {
temporal_tiling_enabled = params.temporal_tiling;
temporal_tile_frames = DEFAULT_TEMPORAL_TILE_FRAMES;
temporal_tile_overlap = DEFAULT_TEMPORAL_TILE_OVERLAP;
const char* extra_tiling_args = params.extra_tiling_args;
if (extra_tiling_args == nullptr || extra_tiling_args[0] == '\0') {
return;
}
std::string raw(extra_tiling_args);
size_t start = 0;
for (size_t pos = 0; pos <= raw.size(); ++pos) {
if (pos != raw.size() && raw[pos] != ',' && raw[pos] != ';') {
continue;
}
std::string token = trim_tiling_arg(raw.substr(start, pos - start));
if (!token.empty()) {
size_t eq = token.find('=');
if (eq == std::string::npos) {
LOG_WARN("ignoring malformed LTX VAE extra tiling arg '%s'", token.c_str());
} else {
std::string key = trim_tiling_arg(token.substr(0, eq));
std::string value = trim_tiling_arg(token.substr(eq + 1));
int parsed = 0;
if (!parse_tiling_int(value, parsed)) {
LOG_WARN("ignoring invalid LTX VAE extra tiling arg '%s=%s'", key.c_str(), value.c_str());
} else if (key == "temporal_tile_frames") {
temporal_tile_frames = std::max(1, parsed);
} else if (key == "temporal_tile_overlap") {
temporal_tile_overlap = std::max(0, parsed);
} else {
LOG_WARN("ignoring unknown LTX VAE extra tiling arg '%s'", key.c_str());
}
}
}
start = pos + 1;
}
}
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) override {
vae.get_param_tensors(tensors, prefix);
}
@ -1195,7 +1309,10 @@ struct LTXVideoVAE : public VAE {
bool use_tiled = decode_graph && temporal_tiling_enabled &&
z_tensor.dim() == 5 && z_tensor.shape()[2] > 1;
if (use_tiled) {
out = vae.decode_tiled(&runner_ctx, z, timestep);
LOG_DEBUG("Using LTX VAE temporal tiling params: temporal_tile_frames=%d, temporal_tile_overlap=%d",
temporal_tile_frames,
temporal_tile_overlap);
out = vae.decode_tiled(&runner_ctx, z, timestep, temporal_tile_frames, temporal_tile_overlap);
} else {
out = decode_graph ? vae.decode(&runner_ctx, z, timestep) : vae.encode(&runner_ctx, z);
}

View File

@ -151,7 +151,7 @@ public:
bool apply_lora_immediately = false;
std::string taesd_path;
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0, 0};
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0, 0, nullptr};
bool offload_params_to_cpu = false;
float max_vram = 0.f;
bool use_pmid = false;
@ -2679,7 +2679,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) {
sd_img_gen_params->batch_count = 1;
sd_img_gen_params->control_strength = 0.9f;
sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f};
sd_img_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f};
sd_img_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr};
sd_cache_params_init(&sd_img_gen_params->cache);
sd_hires_params_init(&sd_img_gen_params->hires);
}
@ -2708,7 +2708,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
"increase_ref_index: %s\n"
"control_strength: %.2f\n"
"photo maker: {style_strength = %.2f, id_images_count = %d, id_embed_path = %s}\n"
"VAE tiling: %s (temporal=%s)\n"
"VAE tiling: %s (temporal=%s, extra_tiling_args=%s)\n"
"hires: {enabled=%s, upscaler=%s, model_path=%s, scale=%.2f, target=%dx%d, steps=%d, denoising_strength=%.2f}\n",
SAFE_STR(sd_img_gen_params->prompt),
SAFE_STR(sd_img_gen_params->negative_prompt),
@ -2728,6 +2728,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
SAFE_STR(sd_img_gen_params->pm_params.id_embed_path),
BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled),
BOOL_STR(sd_img_gen_params->vae_tiling_params.temporal_tiling),
SAFE_STR(sd_img_gen_params->vae_tiling_params.extra_tiling_args),
BOOL_STR(sd_img_gen_params->hires.enabled),
sd_hires_upscaler_name(sd_img_gen_params->hires.upscaler),
SAFE_STR(sd_img_gen_params->hires.model_path),
@ -2765,7 +2766,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
sd_vid_gen_params->fps = 16;
sd_vid_gen_params->moe_boundary = 0.875f;
sd_vid_gen_params->vace_strength = 1.f;
sd_vid_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f};
sd_vid_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr};
sd_vid_gen_params->hires.enabled = false;
sd_vid_gen_params->hires.upscaler = SD_HIRES_UPSCALER_LATENT;
sd_vid_gen_params->hires.scale = 2.f;

View File

@ -484,7 +484,7 @@ public:
if (is_wide) {
auto block = std::dynamic_pointer_cast<WideMemBlock>(blocks[std::to_string(index++)]);
h = block->forward(ctx, h, mem);
} else{
} else {
auto block = std::dynamic_pointer_cast<MemBlock>(blocks[std::to_string(index++)]);
h = block->forward(ctx, h, mem);
}

View File

@ -167,6 +167,7 @@ public:
int64_t t0 = ggml_time_ms();
sd::Tensor<float> input = x;
sd::Tensor<float> output;
set_tiling_params(tiling_params);
if (tiling_params.enabled) {
const int scale_factor = get_scale_factor();
@ -216,6 +217,9 @@ public:
virtual void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) = 0;
virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); };
virtual void set_temporal_tiling_enabled(bool enabled) { SD_UNUSED(enabled); };
virtual void set_tiling_params(const sd_tiling_params_t& params) {
set_temporal_tiling_enabled(params.temporal_tiling);
};
};
struct FakeVAE : public VAE {