mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-09 15:56:39 +00:00
Feat: Temporal tile custom size with overlap (#1510)
* Temporal tile size + overlap * add --extra-tiling-args support --------- Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
parent
2e3514625a
commit
adaa599a3b
@ -107,6 +107,8 @@ Generation Options:
|
|||||||
--extra-sample-args <string> extra sampler/scheduler args, key=value list. lcm supports noise_clip_std,
|
--extra-sample-args <string> extra sampler/scheduler args, key=value list. lcm supports noise_clip_std,
|
||||||
noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift,
|
noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift,
|
||||||
stretch, terminal; euler_ge supports gamma
|
stretch, terminal; euler_ge supports gamma
|
||||||
|
--extra-tiling-args <string> extra VAE tiling args, key=value list. LTX video VAE supports
|
||||||
|
temporal_tile_frames (default: 4), temporal_tile_overlap (default: 1)
|
||||||
-H, --height <int> image height, in pixel space (default: 512)
|
-H, --height <int> image height, in pixel space (default: 512)
|
||||||
-W, --width <int> image width, in pixel space (default: 512)
|
-W, --width <int> image width, in pixel space (default: 512)
|
||||||
--steps <int> number of sample steps (default: 20)
|
--steps <int> number of sample steps (default: 20)
|
||||||
|
|||||||
@ -835,6 +835,10 @@ ArgOptions SDGenerationParams::get_options() {
|
|||||||
"--extra-sample-args",
|
"--extra-sample-args",
|
||||||
"extra sampler/scheduler args, key=value list. lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma",
|
"extra sampler/scheduler args, key=value list. lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma",
|
||||||
&extra_sample_args},
|
&extra_sample_args},
|
||||||
|
{"",
|
||||||
|
"--extra-tiling-args",
|
||||||
|
"extra VAE tiling args, key=value list. LTX video VAE supports temporal_tile_frames (default: 4), temporal_tile_overlap (default: 1)",
|
||||||
|
&extra_tiling_args},
|
||||||
};
|
};
|
||||||
|
|
||||||
options.int_options = {
|
options.int_options = {
|
||||||
@ -1780,6 +1784,9 @@ bool SDGenerationParams::from_json_str(
|
|||||||
if (tiling_json.contains("rel_size_y") && tiling_json["rel_size_y"].is_number()) {
|
if (tiling_json.contains("rel_size_y") && tiling_json["rel_size_y"].is_number()) {
|
||||||
vae_tiling_params.rel_size_y = tiling_json["rel_size_y"];
|
vae_tiling_params.rel_size_y = tiling_json["rel_size_y"];
|
||||||
}
|
}
|
||||||
|
if (tiling_json.contains("extra_tiling_args") && tiling_json["extra_tiling_args"].is_string()) {
|
||||||
|
extra_tiling_args = tiling_json["extra_tiling_args"].get<std::string>();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!parse_lora_json_field(j, lora_path_resolver, lora_map, high_noise_lora_map)) {
|
if (!parse_lora_json_field(j, lora_path_resolver, lora_map, high_noise_lora_map)) {
|
||||||
@ -2002,6 +2009,8 @@ bool SDGenerationParams::initialize_cache_params() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool SDGenerationParams::resolve(const std::string& lora_model_dir, const std::string& hires_upscalers_dir, bool strict) {
|
bool SDGenerationParams::resolve(const std::string& lora_model_dir, const std::string& hires_upscalers_dir, bool strict) {
|
||||||
|
vae_tiling_params.extra_tiling_args = extra_tiling_args.empty() ? nullptr : extra_tiling_args.c_str();
|
||||||
|
|
||||||
if (high_noise_sample_params.sample_steps <= 0) {
|
if (high_noise_sample_params.sample_steps <= 0) {
|
||||||
high_noise_sample_params.sample_steps = -1;
|
high_noise_sample_params.sample_steps = -1;
|
||||||
}
|
}
|
||||||
@ -2188,6 +2197,7 @@ sd_img_gen_params_t SDGenerationParams::to_sd_img_gen_params_t() {
|
|||||||
sample_params.custom_sigmas_count = static_cast<int>(custom_sigmas.size());
|
sample_params.custom_sigmas_count = static_cast<int>(custom_sigmas.size());
|
||||||
sample_params.extra_sample_args = extra_sample_args.empty() ? nullptr : extra_sample_args.c_str();
|
sample_params.extra_sample_args = extra_sample_args.empty() ? nullptr : extra_sample_args.c_str();
|
||||||
high_noise_sample_params.extra_sample_args = high_noise_extra_sample_args.empty() ? nullptr : high_noise_extra_sample_args.c_str();
|
high_noise_sample_params.extra_sample_args = high_noise_extra_sample_args.empty() ? nullptr : high_noise_extra_sample_args.c_str();
|
||||||
|
vae_tiling_params.extra_tiling_args = extra_tiling_args.empty() ? nullptr : extra_tiling_args.c_str();
|
||||||
cache_params.scm_mask = scm_mask.empty() ? nullptr : scm_mask.c_str();
|
cache_params.scm_mask = scm_mask.empty() ? nullptr : scm_mask.c_str();
|
||||||
|
|
||||||
sd_pm_params_t pm_params = {
|
sd_pm_params_t pm_params = {
|
||||||
@ -2261,6 +2271,7 @@ sd_vid_gen_params_t SDGenerationParams::to_sd_vid_gen_params_t() {
|
|||||||
sample_params.custom_sigmas_count = static_cast<int>(custom_sigmas.size());
|
sample_params.custom_sigmas_count = static_cast<int>(custom_sigmas.size());
|
||||||
sample_params.extra_sample_args = extra_sample_args.empty() ? nullptr : extra_sample_args.c_str();
|
sample_params.extra_sample_args = extra_sample_args.empty() ? nullptr : extra_sample_args.c_str();
|
||||||
high_noise_sample_params.extra_sample_args = high_noise_extra_sample_args.empty() ? nullptr : high_noise_extra_sample_args.c_str();
|
high_noise_sample_params.extra_sample_args = high_noise_extra_sample_args.empty() ? nullptr : high_noise_extra_sample_args.c_str();
|
||||||
|
vae_tiling_params.extra_tiling_args = extra_tiling_args.empty() ? nullptr : extra_tiling_args.c_str();
|
||||||
cache_params.scm_mask = scm_mask.empty() ? nullptr : scm_mask.c_str();
|
cache_params.scm_mask = scm_mask.empty() ? nullptr : scm_mask.c_str();
|
||||||
|
|
||||||
params.loras = lora_vec.empty() ? nullptr : lora_vec.data();
|
params.loras = lora_vec.empty() ? nullptr : lora_vec.data();
|
||||||
@ -2386,7 +2397,8 @@ std::string SDGenerationParams::to_string() const {
|
|||||||
<< vae_tiling_params.tile_size_y << ", "
|
<< vae_tiling_params.tile_size_y << ", "
|
||||||
<< vae_tiling_params.target_overlap << ", "
|
<< vae_tiling_params.target_overlap << ", "
|
||||||
<< vae_tiling_params.rel_size_x << ", "
|
<< vae_tiling_params.rel_size_x << ", "
|
||||||
<< vae_tiling_params.rel_size_y << " },\n"
|
<< vae_tiling_params.rel_size_y << ", "
|
||||||
|
<< "\"" << extra_tiling_args << "\" },\n"
|
||||||
<< "}";
|
<< "}";
|
||||||
return oss.str();
|
return oss.str();
|
||||||
}
|
}
|
||||||
@ -2565,14 +2577,18 @@ std::string build_sdcpp_image_metadata_json(const SDContextParams& ctx_params,
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
if (gen_params.vae_tiling_params.enabled) {
|
if (gen_params.vae_tiling_params.enabled ||
|
||||||
|
gen_params.vae_tiling_params.temporal_tiling ||
|
||||||
|
!gen_params.extra_tiling_args.empty()) {
|
||||||
root["vae_tiling"] = {
|
root["vae_tiling"] = {
|
||||||
{"enabled", gen_params.vae_tiling_params.enabled},
|
{"enabled", gen_params.vae_tiling_params.enabled},
|
||||||
|
{"temporal_tiling", gen_params.vae_tiling_params.temporal_tiling},
|
||||||
{"tile_size_x", gen_params.vae_tiling_params.tile_size_x},
|
{"tile_size_x", gen_params.vae_tiling_params.tile_size_x},
|
||||||
{"tile_size_y", gen_params.vae_tiling_params.tile_size_y},
|
{"tile_size_y", gen_params.vae_tiling_params.tile_size_y},
|
||||||
{"target_overlap", gen_params.vae_tiling_params.target_overlap},
|
{"target_overlap", gen_params.vae_tiling_params.target_overlap},
|
||||||
{"rel_size_x", gen_params.vae_tiling_params.rel_size_x},
|
{"rel_size_x", gen_params.vae_tiling_params.rel_size_x},
|
||||||
{"rel_size_y", gen_params.vae_tiling_params.rel_size_y},
|
{"rel_size_y", gen_params.vae_tiling_params.rel_size_y},
|
||||||
|
{"extra_tiling_args", gen_params.extra_tiling_args},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -189,7 +189,8 @@ struct SDGenerationParams {
|
|||||||
int video_frames = 1;
|
int video_frames = 1;
|
||||||
int fps = 16;
|
int fps = 16;
|
||||||
float vace_strength = 1.f;
|
float vace_strength = 1.f;
|
||||||
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f};
|
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr};
|
||||||
|
std::string extra_tiling_args;
|
||||||
|
|
||||||
std::string pm_id_images_dir;
|
std::string pm_id_images_dir;
|
||||||
std::string pm_id_embed_path;
|
std::string pm_id_embed_path;
|
||||||
|
|||||||
@ -209,6 +209,8 @@ Default Generation Options:
|
|||||||
--extra-sample-args <string> extra sampler/scheduler args, key=value list. lcm supports noise_clip_std,
|
--extra-sample-args <string> extra sampler/scheduler args, key=value list. lcm supports noise_clip_std,
|
||||||
noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift,
|
noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift,
|
||||||
stretch, terminal; euler_ge supports gamma
|
stretch, terminal; euler_ge supports gamma
|
||||||
|
--extra-tiling-args <string> extra VAE tiling args, key=value list. LTX video VAE supports
|
||||||
|
temporal_tile_frames (default: 4), temporal_tile_overlap (default: 1)
|
||||||
-H, --height <int> image height, in pixel space (default: 512)
|
-H, --height <int> image height, in pixel space (default: 512)
|
||||||
-W, --width <int> image width, in pixel space (default: 512)
|
-W, --width <int> image width, in pixel space (default: 512)
|
||||||
--steps <int> number of sample steps (default: 20)
|
--steps <int> number of sample steps (default: 20)
|
||||||
@ -264,6 +266,7 @@ Default Generation Options:
|
|||||||
--disable-auto-resize-ref-image disable auto resize of ref images
|
--disable-auto-resize-ref-image disable auto resize of ref images
|
||||||
--disable-image-metadata do not embed generation metadata on image files
|
--disable-image-metadata do not embed generation metadata on image files
|
||||||
--vae-tiling process vae in tiles to reduce memory usage
|
--vae-tiling process vae in tiles to reduce memory usage
|
||||||
|
--temporal-tiling enable temporal tiling for LTX video VAE decode
|
||||||
--hires enable highres fix
|
--hires enable highres fix
|
||||||
-s, --seed RNG seed (default: 42, use random seed for < 0)
|
-s, --seed RNG seed (default: 42, use random seed for < 0)
|
||||||
--sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m,
|
--sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m,
|
||||||
|
|||||||
@ -504,11 +504,13 @@ Shared default fields used by both `img_gen` and `vid_gen`:
|
|||||||
| `sample_params.guidance.slg.scale` | `number` |
|
| `sample_params.guidance.slg.scale` | `number` |
|
||||||
| `vae_tiling_params` | `object` |
|
| `vae_tiling_params` | `object` |
|
||||||
| `vae_tiling_params.enabled` | `boolean` |
|
| `vae_tiling_params.enabled` | `boolean` |
|
||||||
|
| `vae_tiling_params.temporal_tiling` | `boolean` |
|
||||||
| `vae_tiling_params.tile_size_x` | `integer` |
|
| `vae_tiling_params.tile_size_x` | `integer` |
|
||||||
| `vae_tiling_params.tile_size_y` | `integer` |
|
| `vae_tiling_params.tile_size_y` | `integer` |
|
||||||
| `vae_tiling_params.target_overlap` | `number` |
|
| `vae_tiling_params.target_overlap` | `number` |
|
||||||
| `vae_tiling_params.rel_size_x` | `number` |
|
| `vae_tiling_params.rel_size_x` | `number` |
|
||||||
| `vae_tiling_params.rel_size_y` | `number` |
|
| `vae_tiling_params.rel_size_y` | `number` |
|
||||||
|
| `vae_tiling_params.extra_tiling_args` | `string` |
|
||||||
| `cache_mode` | `string` |
|
| `cache_mode` | `string` |
|
||||||
| `cache_option` | `string` |
|
| `cache_option` | `string` |
|
||||||
| `scm_mask` | `string` |
|
| `scm_mask` | `string` |
|
||||||
@ -516,6 +518,8 @@ Shared default fields used by both `img_gen` and `vid_gen`:
|
|||||||
| `output_format` | `string` |
|
| `output_format` | `string` |
|
||||||
| `output_compression` | `integer` |
|
| `output_compression` | `integer` |
|
||||||
|
|
||||||
|
`vae_tiling_params.extra_tiling_args` accepts a key=value list. For LTX video VAE temporal tiling, `temporal_tile_frames` defaults to `4` and `temporal_tile_overlap` defaults to `1`.
|
||||||
|
|
||||||
`img_gen`-specific default fields:
|
`img_gen`-specific default fields:
|
||||||
|
|
||||||
| Field | Type |
|
| Field | Type |
|
||||||
@ -692,11 +696,13 @@ Example:
|
|||||||
|
|
||||||
"vae_tiling_params": {
|
"vae_tiling_params": {
|
||||||
"enabled": false,
|
"enabled": false,
|
||||||
|
"temporal_tiling": false,
|
||||||
"tile_size_x": 0,
|
"tile_size_x": 0,
|
||||||
"tile_size_y": 0,
|
"tile_size_y": 0,
|
||||||
"target_overlap": 0.5,
|
"target_overlap": 0.5,
|
||||||
"rel_size_x": 0.0,
|
"rel_size_x": 0.0,
|
||||||
"rel_size_y": 0.0
|
"rel_size_y": 0.0,
|
||||||
|
"extra_tiling_args": ""
|
||||||
},
|
},
|
||||||
|
|
||||||
"cache_mode": "disabled",
|
"cache_mode": "disabled",
|
||||||
@ -804,6 +810,14 @@ Other native fields:
|
|||||||
| `hires.custom_sigmas` | `array<number>` |
|
| `hires.custom_sigmas` | `array<number>` |
|
||||||
| `hires.upscale_tile_size` | `integer` |
|
| `hires.upscale_tile_size` | `integer` |
|
||||||
| `vae_tiling_params` | `object` |
|
| `vae_tiling_params` | `object` |
|
||||||
|
| `vae_tiling_params.enabled` | `boolean` |
|
||||||
|
| `vae_tiling_params.temporal_tiling` | `boolean` |
|
||||||
|
| `vae_tiling_params.tile_size_x` | `integer` |
|
||||||
|
| `vae_tiling_params.tile_size_y` | `integer` |
|
||||||
|
| `vae_tiling_params.target_overlap` | `number` |
|
||||||
|
| `vae_tiling_params.rel_size_x` | `number` |
|
||||||
|
| `vae_tiling_params.rel_size_y` | `number` |
|
||||||
|
| `vae_tiling_params.extra_tiling_args` | `string` |
|
||||||
| `cache_mode` | `string` |
|
| `cache_mode` | `string` |
|
||||||
| `cache_option` | `string` |
|
| `cache_option` | `string` |
|
||||||
| `scm_mask` | `string` |
|
| `scm_mask` | `string` |
|
||||||
@ -1012,11 +1026,13 @@ Example:
|
|||||||
|
|
||||||
"vae_tiling_params": {
|
"vae_tiling_params": {
|
||||||
"enabled": false,
|
"enabled": false,
|
||||||
|
"temporal_tiling": false,
|
||||||
"tile_size_x": 0,
|
"tile_size_x": 0,
|
||||||
"tile_size_y": 0,
|
"tile_size_y": 0,
|
||||||
"target_overlap": 0.5,
|
"target_overlap": 0.5,
|
||||||
"rel_size_x": 0.0,
|
"rel_size_x": 0.0,
|
||||||
"rel_size_y": 0.0
|
"rel_size_y": 0.0,
|
||||||
|
"extra_tiling_args": ""
|
||||||
},
|
},
|
||||||
|
|
||||||
"cache_mode": "disabled",
|
"cache_mode": "disabled",
|
||||||
@ -1134,6 +1150,14 @@ Other native fields:
|
|||||||
| Field | Type |
|
| Field | Type |
|
||||||
| --- | --- |
|
| --- | --- |
|
||||||
| `vae_tiling_params` | `object` |
|
| `vae_tiling_params` | `object` |
|
||||||
|
| `vae_tiling_params.enabled` | `boolean` |
|
||||||
|
| `vae_tiling_params.temporal_tiling` | `boolean` |
|
||||||
|
| `vae_tiling_params.tile_size_x` | `integer` |
|
||||||
|
| `vae_tiling_params.tile_size_y` | `integer` |
|
||||||
|
| `vae_tiling_params.target_overlap` | `number` |
|
||||||
|
| `vae_tiling_params.rel_size_x` | `number` |
|
||||||
|
| `vae_tiling_params.rel_size_y` | `number` |
|
||||||
|
| `vae_tiling_params.extra_tiling_args` | `string` |
|
||||||
| `cache_mode` | `string` |
|
| `cache_mode` | `string` |
|
||||||
| `cache_option` | `string` |
|
| `cache_option` | `string` |
|
||||||
| `scm_mask` | `string` |
|
| `scm_mask` | `string` |
|
||||||
|
|||||||
@ -56,11 +56,13 @@ static const char* capability_sample_method_name(enum sample_method_t sample_met
|
|||||||
static json make_vae_tiling_json(const sd_tiling_params_t& params) {
|
static json make_vae_tiling_json(const sd_tiling_params_t& params) {
|
||||||
return {
|
return {
|
||||||
{"enabled", params.enabled},
|
{"enabled", params.enabled},
|
||||||
|
{"temporal_tiling", params.temporal_tiling},
|
||||||
{"tile_size_x", params.tile_size_x},
|
{"tile_size_x", params.tile_size_x},
|
||||||
{"tile_size_y", params.tile_size_y},
|
{"tile_size_y", params.tile_size_y},
|
||||||
{"target_overlap", params.target_overlap},
|
{"target_overlap", params.target_overlap},
|
||||||
{"rel_size_x", params.rel_size_x},
|
{"rel_size_x", params.rel_size_x},
|
||||||
{"rel_size_y", params.rel_size_y},
|
{"rel_size_y", params.rel_size_y},
|
||||||
|
{"extra_tiling_args", params.extra_tiling_args ? params.extra_tiling_args : ""},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -160,6 +160,7 @@ typedef struct {
|
|||||||
float target_overlap;
|
float target_overlap;
|
||||||
float rel_size_x;
|
float rel_size_x;
|
||||||
float rel_size_y;
|
float rel_size_y;
|
||||||
|
const char* extra_tiling_args;
|
||||||
} sd_tiling_params_t;
|
} sd_tiling_params_t;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
|
|||||||
@ -3172,7 +3172,7 @@ protected:
|
|||||||
void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override {
|
void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override {
|
||||||
this->prefix = prefix;
|
this->prefix = prefix;
|
||||||
enum ggml_type wtype = GGML_TYPE_F16;
|
enum ggml_type wtype = GGML_TYPE_F16;
|
||||||
params["weight"] = ggml_new_tensor_4d(ctx, wtype, kernel_size.second, kernel_size.first, in_channels / groups, out_channels);
|
params["weight"] = ggml_new_tensor_4d(ctx, wtype, kernel_size.second, kernel_size.first, in_channels / groups, out_channels);
|
||||||
if (bias) {
|
if (bias) {
|
||||||
enum ggml_type wtype = GGML_TYPE_F32;
|
enum ggml_type wtype = GGML_TYPE_F32;
|
||||||
params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_channels);
|
params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_channels);
|
||||||
|
|||||||
161
src/ltx_vae.hpp
161
src/ltx_vae.hpp
@ -1,6 +1,7 @@
|
|||||||
#ifndef __SD_LTX_VAE_HPP__
|
#ifndef __SD_LTX_VAE_HPP__
|
||||||
#define __SD_LTX_VAE_HPP__
|
#define __SD_LTX_VAE_HPP__
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
@ -143,16 +144,25 @@ namespace LTXVAE {
|
|||||||
std::vector<ggml_tensor*>& feat_map,
|
std::vector<ggml_tensor*>& feat_map,
|
||||||
int& feat_idx,
|
int& feat_idx,
|
||||||
int chunk_idx,
|
int chunk_idx,
|
||||||
bool causal = true) {
|
bool causal = true,
|
||||||
|
int temporal_pad = 0) {
|
||||||
auto conv = std::dynamic_pointer_cast<Conv3d>(blocks["conv"]);
|
auto conv = std::dynamic_pointer_cast<Conv3d>(blocks["conv"]);
|
||||||
const int pad = causal ? (time_kernel_size - 1) : (time_kernel_size - 1) / 2;
|
const int pad = causal ? (time_kernel_size - 1) : (time_kernel_size - 1) / 2;
|
||||||
|
|
||||||
ggml_tensor* prev = (feat_idx < (int)feat_map.size()) ? feat_map[feat_idx] : nullptr;
|
ggml_tensor* prev = (feat_idx < (int)feat_map.size()) ? feat_map[feat_idx] : nullptr;
|
||||||
|
|
||||||
|
GGML_ASSERT(x->ne[2] >= temporal_pad);
|
||||||
|
|
||||||
|
int end_idx = x->ne[2] - temporal_pad;
|
||||||
|
int start_idx = std::max(end_idx - pad, 0);
|
||||||
|
|
||||||
// Save a contiguous copy of the last `pad` frames so the large `x`
|
// Save a contiguous copy of the last `pad` frames so the large `x`
|
||||||
// tensor is not kept alive across iterations by a dangling view.
|
// tensor is not kept alive across iterations by a dangling view.
|
||||||
if (feat_idx < (int)feat_map.size() && pad > 0 && x->ne[2] >= pad) {
|
if (feat_idx < (int)feat_map.size() && end_idx - start_idx > 0) {
|
||||||
auto slice = ggml_ext_slice(ctx->ggml_ctx, x, 2, x->ne[2] - pad, x->ne[2]);
|
GGML_ASSERT(start_idx >= 0);
|
||||||
|
GGML_ASSERT(end_idx > 0);
|
||||||
|
|
||||||
|
auto slice = ggml_ext_slice(ctx->ggml_ctx, x, 2, start_idx, end_idx);
|
||||||
feat_map[feat_idx] = ggml_cont(ctx->ggml_ctx, slice);
|
feat_map[feat_idx] = ggml_cont(ctx->ggml_ctx, slice);
|
||||||
}
|
}
|
||||||
feat_idx++;
|
feat_idx++;
|
||||||
@ -284,7 +294,8 @@ namespace LTXVAE {
|
|||||||
bool causal,
|
bool causal,
|
||||||
std::vector<ggml_tensor*>& feat_map,
|
std::vector<ggml_tensor*>& feat_map,
|
||||||
int& feat_idx,
|
int& feat_idx,
|
||||||
int chunk_idx) {
|
int chunk_idx,
|
||||||
|
int temporal_pad = 0) {
|
||||||
auto norm1 = std::dynamic_pointer_cast<PixelNorm3D>(blocks["norm1"]);
|
auto norm1 = std::dynamic_pointer_cast<PixelNorm3D>(blocks["norm1"]);
|
||||||
auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
|
auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
|
||||||
auto norm2 = std::dynamic_pointer_cast<PixelNorm3D>(blocks["norm2"]);
|
auto norm2 = std::dynamic_pointer_cast<PixelNorm3D>(blocks["norm2"]);
|
||||||
@ -311,14 +322,14 @@ namespace LTXVAE {
|
|||||||
h = apply_scale_shift(ctx->ggml_ctx, h, scale1, shift1);
|
h = apply_scale_shift(ctx->ggml_ctx, h, scale1, shift1);
|
||||||
}
|
}
|
||||||
h = ggml_silu_inplace(ctx->ggml_ctx, h);
|
h = ggml_silu_inplace(ctx->ggml_ctx, h);
|
||||||
h = conv1->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal);
|
h = conv1->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal, temporal_pad);
|
||||||
|
|
||||||
h = norm2->forward(ctx, h);
|
h = norm2->forward(ctx, h);
|
||||||
if (timestep_conditioning) {
|
if (timestep_conditioning) {
|
||||||
h = apply_scale_shift(ctx->ggml_ctx, h, scale2, shift2);
|
h = apply_scale_shift(ctx->ggml_ctx, h, scale2, shift2);
|
||||||
}
|
}
|
||||||
h = ggml_silu_inplace(ctx->ggml_ctx, h);
|
h = ggml_silu_inplace(ctx->ggml_ctx, h);
|
||||||
h = conv2->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal);
|
h = conv2->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal, temporal_pad);
|
||||||
|
|
||||||
return ggml_add(ctx->ggml_ctx, h, x);
|
return ggml_add(ctx->ggml_ctx, h, x);
|
||||||
}
|
}
|
||||||
@ -367,7 +378,8 @@ namespace LTXVAE {
|
|||||||
bool causal,
|
bool causal,
|
||||||
std::vector<ggml_tensor*>& feat_map,
|
std::vector<ggml_tensor*>& feat_map,
|
||||||
int& feat_idx,
|
int& feat_idx,
|
||||||
int chunk_idx) {
|
int chunk_idx,
|
||||||
|
int temporal_pad = 0) {
|
||||||
ggml_tensor* timestep_embed = nullptr;
|
ggml_tensor* timestep_embed = nullptr;
|
||||||
if (timestep_conditioning) {
|
if (timestep_conditioning) {
|
||||||
GGML_ASSERT(timestep != nullptr);
|
GGML_ASSERT(timestep != nullptr);
|
||||||
@ -376,7 +388,7 @@ namespace LTXVAE {
|
|||||||
}
|
}
|
||||||
for (int i = 0; i < num_layers; i++) {
|
for (int i = 0; i < num_layers; i++) {
|
||||||
auto resnet = std::dynamic_pointer_cast<ResnetBlock3D>(blocks["res_blocks." + std::to_string(i)]);
|
auto resnet = std::dynamic_pointer_cast<ResnetBlock3D>(blocks["res_blocks." + std::to_string(i)]);
|
||||||
x = resnet->forward(ctx, x, timestep_embed, causal, feat_map, feat_idx, chunk_idx);
|
x = resnet->forward(ctx, x, timestep_embed, causal, feat_map, feat_idx, chunk_idx, temporal_pad);
|
||||||
}
|
}
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
@ -437,7 +449,8 @@ namespace LTXVAE {
|
|||||||
bool causal,
|
bool causal,
|
||||||
std::vector<ggml_tensor*>& feat_map,
|
std::vector<ggml_tensor*>& feat_map,
|
||||||
int& feat_idx,
|
int& feat_idx,
|
||||||
int chunk_idx) {
|
int chunk_idx,
|
||||||
|
int temporal_pad = 0) {
|
||||||
auto conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv"]);
|
auto conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv"]);
|
||||||
|
|
||||||
bool drop_first = (chunk_idx == 0) && (factor_t > 1);
|
bool drop_first = (chunk_idx == 0) && (factor_t > 1);
|
||||||
@ -453,7 +466,7 @@ namespace LTXVAE {
|
|||||||
x_in = res;
|
x_in = res;
|
||||||
}
|
}
|
||||||
|
|
||||||
x = conv->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal);
|
x = conv->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal, temporal_pad);
|
||||||
x = depth_to_space_3d(ctx->ggml_ctx, x, get_output_channels(), factor_t, factor_s, drop_first);
|
x = depth_to_space_3d(ctx->ggml_ctx, x, get_output_channels(), factor_t, factor_s, drop_first);
|
||||||
if (residual) {
|
if (residual) {
|
||||||
x = ggml_add(ctx->ggml_ctx, x, x_in);
|
x = ggml_add(ctx->ggml_ctx, x, x_in);
|
||||||
@ -986,7 +999,8 @@ namespace LTXVAE {
|
|||||||
ggml_tensor* timestep,
|
ggml_tensor* timestep,
|
||||||
std::vector<ggml_tensor*>& feat_map,
|
std::vector<ggml_tensor*>& feat_map,
|
||||||
int& feat_idx,
|
int& feat_idx,
|
||||||
int chunk_idx) {
|
int chunk_idx,
|
||||||
|
int& temporal_pad) {
|
||||||
auto conv_in = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv_in"]);
|
auto conv_in = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv_in"]);
|
||||||
auto conv_norm_out = std::dynamic_pointer_cast<PixelNorm3D>(blocks["conv_norm_out"]);
|
auto conv_norm_out = std::dynamic_pointer_cast<PixelNorm3D>(blocks["conv_norm_out"]);
|
||||||
auto conv_out = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv_out"]);
|
auto conv_out = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv_out"]);
|
||||||
@ -998,7 +1012,7 @@ namespace LTXVAE {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// conv_in with feat_map for left temporal context
|
// conv_in with feat_map for left temporal context
|
||||||
x = conv_in->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder);
|
x = conv_in->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder, temporal_pad);
|
||||||
|
|
||||||
// up_blocks
|
// up_blocks
|
||||||
int block_idx = 0;
|
int block_idx = 0;
|
||||||
@ -1006,12 +1020,13 @@ namespace LTXVAE {
|
|||||||
auto mid_block = std::dynamic_pointer_cast<UNetMidBlock3D>(blocks["up_blocks." + std::to_string(block_idx)]);
|
auto mid_block = std::dynamic_pointer_cast<UNetMidBlock3D>(blocks["up_blocks." + std::to_string(block_idx)]);
|
||||||
if (mid_block) {
|
if (mid_block) {
|
||||||
x = mid_block->forward(ctx, x, scaled_timestep, causal_decoder,
|
x = mid_block->forward(ctx, x, scaled_timestep, causal_decoder,
|
||||||
feat_map, feat_idx, chunk_idx);
|
feat_map, feat_idx, chunk_idx, temporal_pad);
|
||||||
} else {
|
} else {
|
||||||
auto upsample = std::dynamic_pointer_cast<DepthToSpaceUpsample>(
|
auto upsample = std::dynamic_pointer_cast<DepthToSpaceUpsample>(
|
||||||
blocks["up_blocks." + std::to_string(block_idx)]);
|
blocks["up_blocks." + std::to_string(block_idx)]);
|
||||||
x = upsample->forward(ctx, x, causal_decoder,
|
x = upsample->forward(ctx, x, causal_decoder,
|
||||||
feat_map, feat_idx, chunk_idx);
|
feat_map, feat_idx, chunk_idx, temporal_pad);
|
||||||
|
temporal_pad *= upsample->factor_t;
|
||||||
}
|
}
|
||||||
block_idx++;
|
block_idx++;
|
||||||
}
|
}
|
||||||
@ -1028,7 +1043,7 @@ namespace LTXVAE {
|
|||||||
x = apply_scale_shift(ctx->ggml_ctx, x, scale, shift);
|
x = apply_scale_shift(ctx->ggml_ctx, x, scale, shift);
|
||||||
}
|
}
|
||||||
x = ggml_silu_inplace(ctx->ggml_ctx, x);
|
x = ggml_silu_inplace(ctx->ggml_ctx, x);
|
||||||
x = conv_out->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder);
|
x = conv_out->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder, temporal_pad);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -1084,7 +1099,9 @@ namespace LTXVAE {
|
|||||||
// tensors can be freed by GGML before the next iteration starts.
|
// tensors can be freed by GGML before the next iteration starts.
|
||||||
ggml_tensor* decode_tiled(GGMLRunnerContext* ctx,
|
ggml_tensor* decode_tiled(GGMLRunnerContext* ctx,
|
||||||
ggml_tensor* z,
|
ggml_tensor* z,
|
||||||
ggml_tensor* timestep) {
|
ggml_tensor* timestep,
|
||||||
|
int temporal_window_size = 1,
|
||||||
|
int temporal_tile_overlap = 0) {
|
||||||
auto decoder = std::dynamic_pointer_cast<Decoder>(blocks["decoder"]);
|
auto decoder = std::dynamic_pointer_cast<Decoder>(blocks["decoder"]);
|
||||||
auto processor = std::dynamic_pointer_cast<PerChannelStatistics>(blocks["per_channel_statistics"]);
|
auto processor = std::dynamic_pointer_cast<PerChannelStatistics>(blocks["per_channel_statistics"]);
|
||||||
auto latents = processor->un_normalize(ctx, z);
|
auto latents = processor->un_normalize(ctx, z);
|
||||||
@ -1099,13 +1116,43 @@ namespace LTXVAE {
|
|||||||
// 128 slots is generous enough for any supported decoder configuration.
|
// 128 slots is generous enough for any supported decoder configuration.
|
||||||
std::vector<ggml_tensor*> feat_map(128, nullptr);
|
std::vector<ggml_tensor*> feat_map(128, nullptr);
|
||||||
|
|
||||||
|
// Ensure window size is at least 1
|
||||||
|
int window = std::max(1, temporal_window_size);
|
||||||
|
int overlap = std::max(0, temporal_tile_overlap);
|
||||||
|
|
||||||
|
if (overlap >= window) {
|
||||||
|
LOG_WARN("temporal_tile_overlap (%d) is greater than or equal to temporal_tile_frames (%d), adjusting values to avoid empty decode windows",
|
||||||
|
overlap, window);
|
||||||
|
overlap = window - 1;
|
||||||
|
}
|
||||||
|
LOG_DEBUG("Using temporal tiling: temporal_tile_frames = %d, temporal_tile_overlap = %d, total frames = %d, resulting in %d tiles",
|
||||||
|
window,
|
||||||
|
overlap,
|
||||||
|
(int)T,
|
||||||
|
(T + window - overlap - 1) / (window - overlap));
|
||||||
ggml_tensor* out = nullptr;
|
ggml_tensor* out = nullptr;
|
||||||
for (int i = 0; i < (int)T; i++) {
|
for (int i = 0; i < (int)T - overlap; i += (window - overlap)) {
|
||||||
int feat_idx = 0;
|
int feat_idx = 0;
|
||||||
auto z_i = ggml_ext_slice(ctx->ggml_ctx, latents, 2, i, i + 1);
|
|
||||||
auto out_i = decoder->forward_tiled_frame(ctx, z_i, timestep,
|
// Calculate the end index for the current temporal chunk
|
||||||
feat_map, feat_idx, i);
|
int end_i = std::min((int)T, i + window);
|
||||||
out = (out == nullptr) ? out_i : ggml_concat(ctx->ggml_ctx, out, out_i, 2);
|
if (end_i >= (int)T) {
|
||||||
|
overlap = 0; // avoid overlap issues in the last chunk
|
||||||
|
}
|
||||||
|
|
||||||
|
int chunk_overlap = overlap; // modified by forward_tiled_frame temporal inflation
|
||||||
|
|
||||||
|
auto z_chunk = ggml_ext_slice(ctx->ggml_ctx, latents, 2, i, end_i);
|
||||||
|
|
||||||
|
auto out_chunk = decoder->forward_tiled_frame(ctx, z_chunk, timestep,
|
||||||
|
feat_map, feat_idx, i, chunk_overlap);
|
||||||
|
|
||||||
|
// discard overlap frames if it's not the final chunk
|
||||||
|
if (overlap > 0 && end_i < (int)T) {
|
||||||
|
out_chunk = ggml_ext_slice(ctx->ggml_ctx, out_chunk, 2, 0, out_chunk->ne[2] - chunk_overlap);
|
||||||
|
}
|
||||||
|
|
||||||
|
out = (out == nullptr) ? out_chunk : ggml_concat(ctx->ggml_ctx, out, out_chunk, 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
return WAN::WanVAE::unpatchify(ctx->ggml_ctx, out, patch_size, 1);
|
return WAN::WanVAE::unpatchify(ctx->ggml_ctx, out, patch_size, 1);
|
||||||
@ -1140,8 +1187,13 @@ namespace LTXVAE {
|
|||||||
} // namespace LTXVAE
|
} // namespace LTXVAE
|
||||||
|
|
||||||
struct LTXVideoVAE : public VAE {
|
struct LTXVideoVAE : public VAE {
|
||||||
|
static constexpr int DEFAULT_TEMPORAL_TILE_FRAMES = 4;
|
||||||
|
static constexpr int DEFAULT_TEMPORAL_TILE_OVERLAP = 1;
|
||||||
|
|
||||||
bool decode_only;
|
bool decode_only;
|
||||||
bool temporal_tiling_enabled = false;
|
bool temporal_tiling_enabled = false;
|
||||||
|
int temporal_tile_frames = DEFAULT_TEMPORAL_TILE_FRAMES;
|
||||||
|
int temporal_tile_overlap = DEFAULT_TEMPORAL_TILE_OVERLAP;
|
||||||
int ltx_vae_version;
|
int ltx_vae_version;
|
||||||
bool timestep_conditioning;
|
bool timestep_conditioning;
|
||||||
int patch_size;
|
int patch_size;
|
||||||
@ -1178,6 +1230,68 @@ struct LTXVideoVAE : public VAE {
|
|||||||
temporal_tiling_enabled = enabled;
|
temporal_tiling_enabled = enabled;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static std::string trim_tiling_arg(std::string value) {
|
||||||
|
const char* whitespace = " \t\r\n";
|
||||||
|
size_t begin = value.find_first_not_of(whitespace);
|
||||||
|
if (begin == std::string::npos) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
size_t end = value.find_last_not_of(whitespace);
|
||||||
|
return value.substr(begin, end - begin + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool parse_tiling_int(const std::string& value, int& parsed) {
|
||||||
|
try {
|
||||||
|
size_t consumed = 0;
|
||||||
|
parsed = std::stoi(value, &consumed);
|
||||||
|
return trim_tiling_arg(value.substr(consumed)).empty();
|
||||||
|
} catch (...) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_tiling_params(const sd_tiling_params_t& params) override {
|
||||||
|
temporal_tiling_enabled = params.temporal_tiling;
|
||||||
|
temporal_tile_frames = DEFAULT_TEMPORAL_TILE_FRAMES;
|
||||||
|
temporal_tile_overlap = DEFAULT_TEMPORAL_TILE_OVERLAP;
|
||||||
|
|
||||||
|
const char* extra_tiling_args = params.extra_tiling_args;
|
||||||
|
if (extra_tiling_args == nullptr || extra_tiling_args[0] == '\0') {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string raw(extra_tiling_args);
|
||||||
|
size_t start = 0;
|
||||||
|
for (size_t pos = 0; pos <= raw.size(); ++pos) {
|
||||||
|
if (pos != raw.size() && raw[pos] != ',' && raw[pos] != ';') {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string token = trim_tiling_arg(raw.substr(start, pos - start));
|
||||||
|
if (!token.empty()) {
|
||||||
|
size_t eq = token.find('=');
|
||||||
|
if (eq == std::string::npos) {
|
||||||
|
LOG_WARN("ignoring malformed LTX VAE extra tiling arg '%s'", token.c_str());
|
||||||
|
} else {
|
||||||
|
std::string key = trim_tiling_arg(token.substr(0, eq));
|
||||||
|
std::string value = trim_tiling_arg(token.substr(eq + 1));
|
||||||
|
int parsed = 0;
|
||||||
|
if (!parse_tiling_int(value, parsed)) {
|
||||||
|
LOG_WARN("ignoring invalid LTX VAE extra tiling arg '%s=%s'", key.c_str(), value.c_str());
|
||||||
|
} else if (key == "temporal_tile_frames") {
|
||||||
|
temporal_tile_frames = std::max(1, parsed);
|
||||||
|
} else if (key == "temporal_tile_overlap") {
|
||||||
|
temporal_tile_overlap = std::max(0, parsed);
|
||||||
|
} else {
|
||||||
|
LOG_WARN("ignoring unknown LTX VAE extra tiling arg '%s'", key.c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
start = pos + 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) override {
|
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) override {
|
||||||
vae.get_param_tensors(tensors, prefix);
|
vae.get_param_tensors(tensors, prefix);
|
||||||
}
|
}
|
||||||
@ -1195,7 +1309,10 @@ struct LTXVideoVAE : public VAE {
|
|||||||
bool use_tiled = decode_graph && temporal_tiling_enabled &&
|
bool use_tiled = decode_graph && temporal_tiling_enabled &&
|
||||||
z_tensor.dim() == 5 && z_tensor.shape()[2] > 1;
|
z_tensor.dim() == 5 && z_tensor.shape()[2] > 1;
|
||||||
if (use_tiled) {
|
if (use_tiled) {
|
||||||
out = vae.decode_tiled(&runner_ctx, z, timestep);
|
LOG_DEBUG("Using LTX VAE temporal tiling params: temporal_tile_frames=%d, temporal_tile_overlap=%d",
|
||||||
|
temporal_tile_frames,
|
||||||
|
temporal_tile_overlap);
|
||||||
|
out = vae.decode_tiled(&runner_ctx, z, timestep, temporal_tile_frames, temporal_tile_overlap);
|
||||||
} else {
|
} else {
|
||||||
out = decode_graph ? vae.decode(&runner_ctx, z, timestep) : vae.encode(&runner_ctx, z);
|
out = decode_graph ? vae.decode(&runner_ctx, z, timestep) : vae.encode(&runner_ctx, z);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -151,7 +151,7 @@ public:
|
|||||||
bool apply_lora_immediately = false;
|
bool apply_lora_immediately = false;
|
||||||
|
|
||||||
std::string taesd_path;
|
std::string taesd_path;
|
||||||
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0, 0};
|
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0, 0, nullptr};
|
||||||
bool offload_params_to_cpu = false;
|
bool offload_params_to_cpu = false;
|
||||||
float max_vram = 0.f;
|
float max_vram = 0.f;
|
||||||
bool use_pmid = false;
|
bool use_pmid = false;
|
||||||
@ -2679,7 +2679,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) {
|
|||||||
sd_img_gen_params->batch_count = 1;
|
sd_img_gen_params->batch_count = 1;
|
||||||
sd_img_gen_params->control_strength = 0.9f;
|
sd_img_gen_params->control_strength = 0.9f;
|
||||||
sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f};
|
sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f};
|
||||||
sd_img_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f};
|
sd_img_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr};
|
||||||
sd_cache_params_init(&sd_img_gen_params->cache);
|
sd_cache_params_init(&sd_img_gen_params->cache);
|
||||||
sd_hires_params_init(&sd_img_gen_params->hires);
|
sd_hires_params_init(&sd_img_gen_params->hires);
|
||||||
}
|
}
|
||||||
@ -2708,7 +2708,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
|
|||||||
"increase_ref_index: %s\n"
|
"increase_ref_index: %s\n"
|
||||||
"control_strength: %.2f\n"
|
"control_strength: %.2f\n"
|
||||||
"photo maker: {style_strength = %.2f, id_images_count = %d, id_embed_path = %s}\n"
|
"photo maker: {style_strength = %.2f, id_images_count = %d, id_embed_path = %s}\n"
|
||||||
"VAE tiling: %s (temporal=%s)\n"
|
"VAE tiling: %s (temporal=%s, extra_tiling_args=%s)\n"
|
||||||
"hires: {enabled=%s, upscaler=%s, model_path=%s, scale=%.2f, target=%dx%d, steps=%d, denoising_strength=%.2f}\n",
|
"hires: {enabled=%s, upscaler=%s, model_path=%s, scale=%.2f, target=%dx%d, steps=%d, denoising_strength=%.2f}\n",
|
||||||
SAFE_STR(sd_img_gen_params->prompt),
|
SAFE_STR(sd_img_gen_params->prompt),
|
||||||
SAFE_STR(sd_img_gen_params->negative_prompt),
|
SAFE_STR(sd_img_gen_params->negative_prompt),
|
||||||
@ -2728,6 +2728,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
|
|||||||
SAFE_STR(sd_img_gen_params->pm_params.id_embed_path),
|
SAFE_STR(sd_img_gen_params->pm_params.id_embed_path),
|
||||||
BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled),
|
BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled),
|
||||||
BOOL_STR(sd_img_gen_params->vae_tiling_params.temporal_tiling),
|
BOOL_STR(sd_img_gen_params->vae_tiling_params.temporal_tiling),
|
||||||
|
SAFE_STR(sd_img_gen_params->vae_tiling_params.extra_tiling_args),
|
||||||
BOOL_STR(sd_img_gen_params->hires.enabled),
|
BOOL_STR(sd_img_gen_params->hires.enabled),
|
||||||
sd_hires_upscaler_name(sd_img_gen_params->hires.upscaler),
|
sd_hires_upscaler_name(sd_img_gen_params->hires.upscaler),
|
||||||
SAFE_STR(sd_img_gen_params->hires.model_path),
|
SAFE_STR(sd_img_gen_params->hires.model_path),
|
||||||
@ -2765,7 +2766,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
|
|||||||
sd_vid_gen_params->fps = 16;
|
sd_vid_gen_params->fps = 16;
|
||||||
sd_vid_gen_params->moe_boundary = 0.875f;
|
sd_vid_gen_params->moe_boundary = 0.875f;
|
||||||
sd_vid_gen_params->vace_strength = 1.f;
|
sd_vid_gen_params->vace_strength = 1.f;
|
||||||
sd_vid_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f};
|
sd_vid_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr};
|
||||||
sd_vid_gen_params->hires.enabled = false;
|
sd_vid_gen_params->hires.enabled = false;
|
||||||
sd_vid_gen_params->hires.upscaler = SD_HIRES_UPSCALER_LATENT;
|
sd_vid_gen_params->hires.upscaler = SD_HIRES_UPSCALER_LATENT;
|
||||||
sd_vid_gen_params->hires.scale = 2.f;
|
sd_vid_gen_params->hires.scale = 2.f;
|
||||||
|
|||||||
12
src/tae.hpp
12
src/tae.hpp
@ -265,7 +265,7 @@ class WideMemBlock : public GGMLBlock {
|
|||||||
public:
|
public:
|
||||||
WideMemBlock(int channels, int out_channels)
|
WideMemBlock(int channels, int out_channels)
|
||||||
: has_skip_conv(channels != out_channels) {
|
: has_skip_conv(channels != out_channels) {
|
||||||
int groups = std::max(1, out_channels / 64);
|
int groups = std::max(1, out_channels / 64);
|
||||||
blocks["conv.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels * 2, out_channels, {1, 1}, {1, 1}));
|
blocks["conv.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels * 2, out_channels, {1, 1}, {1, 1}));
|
||||||
blocks["conv.2"] = std::shared_ptr<GGMLBlock>(new Conv2d_grouped(out_channels, out_channels, groups, {3, 3}, {1, 1}, {1, 1}));
|
blocks["conv.2"] = std::shared_ptr<GGMLBlock>(new Conv2d_grouped(out_channels, out_channels, groups, {3, 3}, {1, 1}, {1, 1}));
|
||||||
blocks["conv.4"] = std::shared_ptr<GGMLBlock>(new Conv2d(out_channels, out_channels, {1, 1}, {1, 1}));
|
blocks["conv.4"] = std::shared_ptr<GGMLBlock>(new Conv2d(out_channels, out_channels, {1, 1}, {1, 1}));
|
||||||
@ -479,12 +479,12 @@ public:
|
|||||||
int index = 3;
|
int index = 3;
|
||||||
for (int i = 0; i < num_layers; i++) {
|
for (int i = 0; i < num_layers; i++) {
|
||||||
for (int j = 0; j < num_blocks; j++) {
|
for (int j = 0; j < num_blocks; j++) {
|
||||||
auto mem = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1, 0);
|
auto mem = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1, 0);
|
||||||
mem = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0);
|
mem = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0);
|
||||||
if (is_wide) {
|
if (is_wide) {
|
||||||
auto block = std::dynamic_pointer_cast<WideMemBlock>(blocks[std::to_string(index++)]);
|
auto block = std::dynamic_pointer_cast<WideMemBlock>(blocks[std::to_string(index++)]);
|
||||||
h = block->forward(ctx, h, mem);
|
h = block->forward(ctx, h, mem);
|
||||||
} else{
|
} else {
|
||||||
auto block = std::dynamic_pointer_cast<MemBlock>(blocks[std::to_string(index++)]);
|
auto block = std::dynamic_pointer_cast<MemBlock>(blocks[std::to_string(index++)]);
|
||||||
h = block->forward(ctx, h, mem);
|
h = block->forward(ctx, h, mem);
|
||||||
}
|
}
|
||||||
@ -683,7 +683,7 @@ struct TinyImageAutoEncoder : public VAE {
|
|||||||
struct TinyVideoAutoEncoder : public VAE {
|
struct TinyVideoAutoEncoder : public VAE {
|
||||||
TAEHV taehv;
|
TAEHV taehv;
|
||||||
bool decode_only = false;
|
bool decode_only = false;
|
||||||
bool is_wide = false;
|
bool is_wide = false;
|
||||||
|
|
||||||
TinyVideoAutoEncoder(ggml_backend_t backend,
|
TinyVideoAutoEncoder(ggml_backend_t backend,
|
||||||
ggml_backend_t params_backend,
|
ggml_backend_t params_backend,
|
||||||
@ -699,7 +699,7 @@ struct TinyVideoAutoEncoder : public VAE {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
taehv = TAEHV(decoder_only, version, is_wide);
|
taehv = TAEHV(decoder_only, version, is_wide);
|
||||||
scale_input = false;
|
scale_input = false;
|
||||||
taehv.init(params_ctx, tensor_storage_map, prefix);
|
taehv.init(params_ctx, tensor_storage_map, prefix);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -167,6 +167,7 @@ public:
|
|||||||
int64_t t0 = ggml_time_ms();
|
int64_t t0 = ggml_time_ms();
|
||||||
sd::Tensor<float> input = x;
|
sd::Tensor<float> input = x;
|
||||||
sd::Tensor<float> output;
|
sd::Tensor<float> output;
|
||||||
|
set_tiling_params(tiling_params);
|
||||||
|
|
||||||
if (tiling_params.enabled) {
|
if (tiling_params.enabled) {
|
||||||
const int scale_factor = get_scale_factor();
|
const int scale_factor = get_scale_factor();
|
||||||
@ -216,6 +217,9 @@ public:
|
|||||||
virtual void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) = 0;
|
virtual void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) = 0;
|
||||||
virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); };
|
virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); };
|
||||||
virtual void set_temporal_tiling_enabled(bool enabled) { SD_UNUSED(enabled); };
|
virtual void set_temporal_tiling_enabled(bool enabled) { SD_UNUSED(enabled); };
|
||||||
|
virtual void set_tiling_params(const sd_tiling_params_t& params) {
|
||||||
|
set_temporal_tiling_enabled(params.temporal_tiling);
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
struct FakeVAE : public VAE {
|
struct FakeVAE : public VAE {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user