mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-24 23:26:43 +00:00
Compare commits
No commits in common. "3a8788cb7d74f185d6b18688e9563015524ecaf5" and "2e3514625a31e2abd00b9f8c8be0dc3517d628a5" have entirely different histories.
3a8788cb7d
...
2e3514625a
@ -107,8 +107,6 @@ Generation Options:
|
||||
--extra-sample-args <string> extra sampler/scheduler args, key=value list. lcm supports noise_clip_std,
|
||||
noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift,
|
||||
stretch, terminal; euler_ge supports gamma
|
||||
--extra-tiling-args <string> extra VAE tiling args, key=value list. LTX video VAE supports
|
||||
temporal_tile_frames (default: 4), temporal_tile_overlap (default: 1)
|
||||
-H, --height <int> image height, in pixel space (default: 512)
|
||||
-W, --width <int> image width, in pixel space (default: 512)
|
||||
--steps <int> number of sample steps (default: 20)
|
||||
|
||||
@ -835,10 +835,6 @@ ArgOptions SDGenerationParams::get_options() {
|
||||
"--extra-sample-args",
|
||||
"extra sampler/scheduler args, key=value list. lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma",
|
||||
&extra_sample_args},
|
||||
{"",
|
||||
"--extra-tiling-args",
|
||||
"extra VAE tiling args, key=value list. LTX video VAE supports temporal_tile_frames (default: 4), temporal_tile_overlap (default: 1)",
|
||||
&extra_tiling_args},
|
||||
};
|
||||
|
||||
options.int_options = {
|
||||
@ -1784,9 +1780,6 @@ bool SDGenerationParams::from_json_str(
|
||||
if (tiling_json.contains("rel_size_y") && tiling_json["rel_size_y"].is_number()) {
|
||||
vae_tiling_params.rel_size_y = tiling_json["rel_size_y"];
|
||||
}
|
||||
if (tiling_json.contains("extra_tiling_args") && tiling_json["extra_tiling_args"].is_string()) {
|
||||
extra_tiling_args = tiling_json["extra_tiling_args"].get<std::string>();
|
||||
}
|
||||
}
|
||||
|
||||
if (!parse_lora_json_field(j, lora_path_resolver, lora_map, high_noise_lora_map)) {
|
||||
@ -2009,8 +2002,6 @@ bool SDGenerationParams::initialize_cache_params() {
|
||||
}
|
||||
|
||||
bool SDGenerationParams::resolve(const std::string& lora_model_dir, const std::string& hires_upscalers_dir, bool strict) {
|
||||
vae_tiling_params.extra_tiling_args = extra_tiling_args.empty() ? nullptr : extra_tiling_args.c_str();
|
||||
|
||||
if (high_noise_sample_params.sample_steps <= 0) {
|
||||
high_noise_sample_params.sample_steps = -1;
|
||||
}
|
||||
@ -2197,7 +2188,6 @@ sd_img_gen_params_t SDGenerationParams::to_sd_img_gen_params_t() {
|
||||
sample_params.custom_sigmas_count = static_cast<int>(custom_sigmas.size());
|
||||
sample_params.extra_sample_args = extra_sample_args.empty() ? nullptr : extra_sample_args.c_str();
|
||||
high_noise_sample_params.extra_sample_args = high_noise_extra_sample_args.empty() ? nullptr : high_noise_extra_sample_args.c_str();
|
||||
vae_tiling_params.extra_tiling_args = extra_tiling_args.empty() ? nullptr : extra_tiling_args.c_str();
|
||||
cache_params.scm_mask = scm_mask.empty() ? nullptr : scm_mask.c_str();
|
||||
|
||||
sd_pm_params_t pm_params = {
|
||||
@ -2271,7 +2261,6 @@ sd_vid_gen_params_t SDGenerationParams::to_sd_vid_gen_params_t() {
|
||||
sample_params.custom_sigmas_count = static_cast<int>(custom_sigmas.size());
|
||||
sample_params.extra_sample_args = extra_sample_args.empty() ? nullptr : extra_sample_args.c_str();
|
||||
high_noise_sample_params.extra_sample_args = high_noise_extra_sample_args.empty() ? nullptr : high_noise_extra_sample_args.c_str();
|
||||
vae_tiling_params.extra_tiling_args = extra_tiling_args.empty() ? nullptr : extra_tiling_args.c_str();
|
||||
cache_params.scm_mask = scm_mask.empty() ? nullptr : scm_mask.c_str();
|
||||
|
||||
params.loras = lora_vec.empty() ? nullptr : lora_vec.data();
|
||||
@ -2397,8 +2386,7 @@ std::string SDGenerationParams::to_string() const {
|
||||
<< vae_tiling_params.tile_size_y << ", "
|
||||
<< vae_tiling_params.target_overlap << ", "
|
||||
<< vae_tiling_params.rel_size_x << ", "
|
||||
<< vae_tiling_params.rel_size_y << ", "
|
||||
<< "\"" << extra_tiling_args << "\" },\n"
|
||||
<< vae_tiling_params.rel_size_y << " },\n"
|
||||
<< "}";
|
||||
return oss.str();
|
||||
}
|
||||
@ -2577,18 +2565,14 @@ std::string build_sdcpp_image_metadata_json(const SDContextParams& ctx_params,
|
||||
};
|
||||
}
|
||||
|
||||
if (gen_params.vae_tiling_params.enabled ||
|
||||
gen_params.vae_tiling_params.temporal_tiling ||
|
||||
!gen_params.extra_tiling_args.empty()) {
|
||||
if (gen_params.vae_tiling_params.enabled) {
|
||||
root["vae_tiling"] = {
|
||||
{"enabled", gen_params.vae_tiling_params.enabled},
|
||||
{"temporal_tiling", gen_params.vae_tiling_params.temporal_tiling},
|
||||
{"tile_size_x", gen_params.vae_tiling_params.tile_size_x},
|
||||
{"tile_size_y", gen_params.vae_tiling_params.tile_size_y},
|
||||
{"target_overlap", gen_params.vae_tiling_params.target_overlap},
|
||||
{"rel_size_x", gen_params.vae_tiling_params.rel_size_x},
|
||||
{"rel_size_y", gen_params.vae_tiling_params.rel_size_y},
|
||||
{"extra_tiling_args", gen_params.extra_tiling_args},
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@ -189,8 +189,7 @@ struct SDGenerationParams {
|
||||
int video_frames = 1;
|
||||
int fps = 16;
|
||||
float vace_strength = 1.f;
|
||||
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr};
|
||||
std::string extra_tiling_args;
|
||||
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f};
|
||||
|
||||
std::string pm_id_images_dir;
|
||||
std::string pm_id_embed_path;
|
||||
|
||||
@ -209,8 +209,6 @@ Default Generation Options:
|
||||
--extra-sample-args <string> extra sampler/scheduler args, key=value list. lcm supports noise_clip_std,
|
||||
noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift,
|
||||
stretch, terminal; euler_ge supports gamma
|
||||
--extra-tiling-args <string> extra VAE tiling args, key=value list. LTX video VAE supports
|
||||
temporal_tile_frames (default: 4), temporal_tile_overlap (default: 1)
|
||||
-H, --height <int> image height, in pixel space (default: 512)
|
||||
-W, --width <int> image width, in pixel space (default: 512)
|
||||
--steps <int> number of sample steps (default: 20)
|
||||
@ -266,7 +264,6 @@ Default Generation Options:
|
||||
--disable-auto-resize-ref-image disable auto resize of ref images
|
||||
--disable-image-metadata do not embed generation metadata on image files
|
||||
--vae-tiling process vae in tiles to reduce memory usage
|
||||
--temporal-tiling enable temporal tiling for LTX video VAE decode
|
||||
--hires enable highres fix
|
||||
-s, --seed RNG seed (default: 42, use random seed for < 0)
|
||||
--sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m,
|
||||
|
||||
@ -504,13 +504,11 @@ Shared default fields used by both `img_gen` and `vid_gen`:
|
||||
| `sample_params.guidance.slg.scale` | `number` |
|
||||
| `vae_tiling_params` | `object` |
|
||||
| `vae_tiling_params.enabled` | `boolean` |
|
||||
| `vae_tiling_params.temporal_tiling` | `boolean` |
|
||||
| `vae_tiling_params.tile_size_x` | `integer` |
|
||||
| `vae_tiling_params.tile_size_y` | `integer` |
|
||||
| `vae_tiling_params.target_overlap` | `number` |
|
||||
| `vae_tiling_params.rel_size_x` | `number` |
|
||||
| `vae_tiling_params.rel_size_y` | `number` |
|
||||
| `vae_tiling_params.extra_tiling_args` | `string` |
|
||||
| `cache_mode` | `string` |
|
||||
| `cache_option` | `string` |
|
||||
| `scm_mask` | `string` |
|
||||
@ -518,8 +516,6 @@ Shared default fields used by both `img_gen` and `vid_gen`:
|
||||
| `output_format` | `string` |
|
||||
| `output_compression` | `integer` |
|
||||
|
||||
`vae_tiling_params.extra_tiling_args` accepts a key=value list. For LTX video VAE temporal tiling, `temporal_tile_frames` defaults to `4` and `temporal_tile_overlap` defaults to `1`.
|
||||
|
||||
`img_gen`-specific default fields:
|
||||
|
||||
| Field | Type |
|
||||
@ -696,13 +692,11 @@ Example:
|
||||
|
||||
"vae_tiling_params": {
|
||||
"enabled": false,
|
||||
"temporal_tiling": false,
|
||||
"tile_size_x": 0,
|
||||
"tile_size_y": 0,
|
||||
"target_overlap": 0.5,
|
||||
"rel_size_x": 0.0,
|
||||
"rel_size_y": 0.0,
|
||||
"extra_tiling_args": ""
|
||||
"rel_size_y": 0.0
|
||||
},
|
||||
|
||||
"cache_mode": "disabled",
|
||||
@ -810,14 +804,6 @@ Other native fields:
|
||||
| `hires.custom_sigmas` | `array<number>` |
|
||||
| `hires.upscale_tile_size` | `integer` |
|
||||
| `vae_tiling_params` | `object` |
|
||||
| `vae_tiling_params.enabled` | `boolean` |
|
||||
| `vae_tiling_params.temporal_tiling` | `boolean` |
|
||||
| `vae_tiling_params.tile_size_x` | `integer` |
|
||||
| `vae_tiling_params.tile_size_y` | `integer` |
|
||||
| `vae_tiling_params.target_overlap` | `number` |
|
||||
| `vae_tiling_params.rel_size_x` | `number` |
|
||||
| `vae_tiling_params.rel_size_y` | `number` |
|
||||
| `vae_tiling_params.extra_tiling_args` | `string` |
|
||||
| `cache_mode` | `string` |
|
||||
| `cache_option` | `string` |
|
||||
| `scm_mask` | `string` |
|
||||
@ -1026,13 +1012,11 @@ Example:
|
||||
|
||||
"vae_tiling_params": {
|
||||
"enabled": false,
|
||||
"temporal_tiling": false,
|
||||
"tile_size_x": 0,
|
||||
"tile_size_y": 0,
|
||||
"target_overlap": 0.5,
|
||||
"rel_size_x": 0.0,
|
||||
"rel_size_y": 0.0,
|
||||
"extra_tiling_args": ""
|
||||
"rel_size_y": 0.0
|
||||
},
|
||||
|
||||
"cache_mode": "disabled",
|
||||
@ -1150,14 +1134,6 @@ Other native fields:
|
||||
| Field | Type |
|
||||
| --- | --- |
|
||||
| `vae_tiling_params` | `object` |
|
||||
| `vae_tiling_params.enabled` | `boolean` |
|
||||
| `vae_tiling_params.temporal_tiling` | `boolean` |
|
||||
| `vae_tiling_params.tile_size_x` | `integer` |
|
||||
| `vae_tiling_params.tile_size_y` | `integer` |
|
||||
| `vae_tiling_params.target_overlap` | `number` |
|
||||
| `vae_tiling_params.rel_size_x` | `number` |
|
||||
| `vae_tiling_params.rel_size_y` | `number` |
|
||||
| `vae_tiling_params.extra_tiling_args` | `string` |
|
||||
| `cache_mode` | `string` |
|
||||
| `cache_option` | `string` |
|
||||
| `scm_mask` | `string` |
|
||||
|
||||
@ -56,13 +56,11 @@ static const char* capability_sample_method_name(enum sample_method_t sample_met
|
||||
static json make_vae_tiling_json(const sd_tiling_params_t& params) {
|
||||
return {
|
||||
{"enabled", params.enabled},
|
||||
{"temporal_tiling", params.temporal_tiling},
|
||||
{"tile_size_x", params.tile_size_x},
|
||||
{"tile_size_y", params.tile_size_y},
|
||||
{"target_overlap", params.target_overlap},
|
||||
{"rel_size_x", params.rel_size_x},
|
||||
{"rel_size_y", params.rel_size_y},
|
||||
{"extra_tiling_args", params.extra_tiling_args ? params.extra_tiling_args : ""},
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@ -160,7 +160,6 @@ typedef struct {
|
||||
float target_overlap;
|
||||
float rel_size_x;
|
||||
float rel_size_y;
|
||||
const char* extra_tiling_args;
|
||||
} sd_tiling_params_t;
|
||||
|
||||
typedef struct {
|
||||
|
||||
158
src/denoiser.hpp
158
src/denoiser.hpp
@ -496,26 +496,84 @@ struct LTX2Scheduler : SigmaScheduler {
|
||||
parse_extra_sample_args(extra_sample_args);
|
||||
}
|
||||
|
||||
static std::string trim(std::string value) {
|
||||
const char* whitespace = " \t\r\n";
|
||||
size_t begin = value.find_first_not_of(whitespace);
|
||||
if (begin == std::string::npos) {
|
||||
return "";
|
||||
}
|
||||
size_t end = value.find_last_not_of(whitespace);
|
||||
return value.substr(begin, end - begin + 1);
|
||||
}
|
||||
|
||||
void parse_extra_sample_args(const char* extra_sample_args) {
|
||||
for (const auto& [key, value] : parse_key_value_args(extra_sample_args, "ltx2 scheduler arg")) {
|
||||
if (key == "max_shift") {
|
||||
if (!parse_strict_float(value, max_shift)) {
|
||||
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str());
|
||||
if (extra_sample_args == nullptr || extra_sample_args[0] == '\0') {
|
||||
return;
|
||||
}
|
||||
|
||||
std::string raw(extra_sample_args);
|
||||
size_t start = 0;
|
||||
auto parse_arg = [&](const std::string& item) {
|
||||
std::string token = trim(item);
|
||||
if (token.empty()) {
|
||||
return;
|
||||
}
|
||||
size_t eq = token.find('=');
|
||||
if (eq == std::string::npos) {
|
||||
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
|
||||
return;
|
||||
}
|
||||
|
||||
std::string key = trim(token.substr(0, eq));
|
||||
std::string value = trim(token.substr(eq + 1));
|
||||
auto parse_float = [&](float* out) -> bool {
|
||||
try {
|
||||
size_t consumed = 0;
|
||||
float parsed = std::stof(value, &consumed);
|
||||
if (!trim(value.substr(consumed)).empty()) {
|
||||
return false;
|
||||
}
|
||||
*out = parsed;
|
||||
return true;
|
||||
} catch (const std::exception&) {
|
||||
return false;
|
||||
}
|
||||
} else if (key == "base_shift") {
|
||||
if (!parse_strict_float(value, base_shift)) {
|
||||
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str());
|
||||
};
|
||||
try {
|
||||
if (key == "max_shift") {
|
||||
if (!parse_float(&max_shift)) {
|
||||
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
|
||||
}
|
||||
} else if (key == "base_shift") {
|
||||
if (!parse_float(&base_shift)) {
|
||||
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
|
||||
}
|
||||
} else if (key == "terminal") {
|
||||
if (!parse_float(&terminal)) {
|
||||
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
|
||||
}
|
||||
} else if (key == "stretch") {
|
||||
std::string v = value;
|
||||
std::transform(v.begin(), v.end(), v.begin(), [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
|
||||
if (v == "1" || v == "true" || v == "yes" || v == "on") {
|
||||
stretch = true;
|
||||
} else if (v == "0" || v == "false" || v == "no" || v == "off") {
|
||||
stretch = false;
|
||||
} else {
|
||||
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
|
||||
}
|
||||
} else {
|
||||
LOG_WARN("ignoring unknown ltx2 scheduler arg '%s'", key.c_str());
|
||||
}
|
||||
} else if (key == "terminal") {
|
||||
if (!parse_strict_float(value, terminal)) {
|
||||
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str());
|
||||
}
|
||||
} else if (key == "stretch") {
|
||||
if (!parse_strict_bool(value, stretch)) {
|
||||
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str());
|
||||
}
|
||||
} else {
|
||||
LOG_WARN("ignoring unknown ltx2 scheduler arg '%s'", key.c_str());
|
||||
} catch (const std::exception&) {
|
||||
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
|
||||
}
|
||||
};
|
||||
|
||||
for (size_t pos = 0; pos <= raw.size(); ++pos) {
|
||||
if (pos == raw.size() || raw[pos] == ',' || raw[pos] == ';') {
|
||||
parse_arg(raw.substr(start, pos - start));
|
||||
start = pos + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1218,7 +1276,7 @@ static sd::Tensor<float> sample_dpmpp_2m_v2(denoise_cb_t model,
|
||||
return x;
|
||||
}
|
||||
|
||||
using SamplerExtraArgs = KeyValueArgs;
|
||||
using SamplerExtraArgs = std::vector<std::pair<std::string, std::string>>;
|
||||
|
||||
static sd::Tensor<float> sample_lcm(denoise_cb_t model,
|
||||
sd::Tensor<float> x,
|
||||
@ -1238,8 +1296,15 @@ static sd::Tensor<float> sample_lcm(denoise_cb_t model,
|
||||
|
||||
for (const auto& [key, value] : extra_sample_args) {
|
||||
float parsed = 0.0f;
|
||||
if (!parse_strict_float(value, parsed)) {
|
||||
LOG_WARN("ignoring invalid lcm extra sample arg '%s=%s'", key.c_str(), value.c_str());
|
||||
try {
|
||||
size_t consumed = 0;
|
||||
parsed = std::stof(value, &consumed);
|
||||
if (trim(value.substr(consumed)).size() != 0) {
|
||||
LOG_WARN("ignoring invalid lcm extra sample arg '%s'", key.c_str());
|
||||
continue;
|
||||
}
|
||||
} catch (const std::exception&) {
|
||||
LOG_WARN("ignoring invalid lcm extra sample arg '%s=%s'", key.c_str());
|
||||
continue;
|
||||
}
|
||||
if (key == "noise_clip_std") {
|
||||
@ -1796,8 +1861,15 @@ static sd::Tensor<float> sample_gradient_estimation(denoise_cb_t model,
|
||||
|
||||
for (const auto& [key, value] : extra_sample_args) {
|
||||
float parsed = 0.0f;
|
||||
if (!parse_strict_float(value, parsed)) {
|
||||
LOG_WARN("ignoring invalid euler_ge extra sample arg '%s=%s'", key.c_str(), value.c_str());
|
||||
try {
|
||||
size_t consumed = 0;
|
||||
parsed = std::stof(value, &consumed);
|
||||
if (trim(value.substr(consumed)).size() != 0) {
|
||||
LOG_WARN("ignoring invalid euler_ge extra sample arg '%s'", key.c_str());
|
||||
continue;
|
||||
}
|
||||
} catch (const std::exception&) {
|
||||
LOG_WARN("ignoring invalid euler_ge extra sample arg '%s'", key.c_str());
|
||||
continue;
|
||||
}
|
||||
if (key == "gamma") {
|
||||
@ -1844,6 +1916,46 @@ static sd::Tensor<float> sample_gradient_estimation(denoise_cb_t model,
|
||||
return x;
|
||||
}
|
||||
|
||||
static SamplerExtraArgs parse_sampler_args(const char* extra_sample_args) {
|
||||
SamplerExtraArgs pairs;
|
||||
|
||||
if (extra_sample_args == nullptr || extra_sample_args[0] == '\0') {
|
||||
return pairs;
|
||||
}
|
||||
|
||||
auto trim = [](std::string value) -> std::string {
|
||||
const char* whitespace = " \t\r\n";
|
||||
size_t begin = value.find_first_not_of(whitespace);
|
||||
if (begin == std::string::npos) {
|
||||
return "";
|
||||
}
|
||||
size_t end = value.find_last_not_of(whitespace);
|
||||
return value.substr(begin, end - begin + 1);
|
||||
};
|
||||
|
||||
std::string raw(extra_sample_args);
|
||||
size_t start = 0;
|
||||
|
||||
for (size_t pos = 0; pos <= raw.size(); ++pos) {
|
||||
if (pos == raw.size() || raw[pos] == ',' || raw[pos] == ';') {
|
||||
std::string item = raw.substr(start, pos - start);
|
||||
std::string token = trim(item);
|
||||
|
||||
if (!token.empty()) {
|
||||
size_t eq = token.find('=');
|
||||
if (eq != std::string::npos) {
|
||||
std::string key = trim(token.substr(0, eq));
|
||||
std::string value = trim(token.substr(eq + 1));
|
||||
pairs.emplace_back(std::move(key), std::move(value));
|
||||
}
|
||||
}
|
||||
start = pos + 1;
|
||||
}
|
||||
}
|
||||
|
||||
return pairs;
|
||||
}
|
||||
|
||||
// k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t
|
||||
static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
|
||||
denoise_cb_t model,
|
||||
@ -1853,7 +1965,7 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
|
||||
float eta,
|
||||
bool is_flow_denoiser,
|
||||
const char* extra_sample_args) {
|
||||
SamplerExtraArgs extra_args = parse_key_value_args(extra_sample_args, "extra sample arg");
|
||||
SamplerExtraArgs extra_args = parse_sampler_args(extra_sample_args);
|
||||
switch (method) {
|
||||
case EULER_A_SAMPLE_METHOD:
|
||||
return sample_euler_ancestral(model, std::move(x), sigmas, rng, is_flow_denoiser, eta);
|
||||
|
||||
@ -3172,7 +3172,7 @@ protected:
|
||||
void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override {
|
||||
this->prefix = prefix;
|
||||
enum ggml_type wtype = GGML_TYPE_F16;
|
||||
params["weight"] = ggml_new_tensor_4d(ctx, wtype, kernel_size.second, kernel_size.first, in_channels / groups, out_channels);
|
||||
params["weight"] = ggml_new_tensor_4d(ctx, wtype, kernel_size.second, kernel_size.first, in_channels / groups, out_channels);
|
||||
if (bias) {
|
||||
enum ggml_type wtype = GGML_TYPE_F32;
|
||||
params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_channels);
|
||||
|
||||
265
src/ltx_vae.hpp
265
src/ltx_vae.hpp
@ -1,7 +1,6 @@
|
||||
#ifndef __SD_LTX_VAE_HPP__
|
||||
#define __SD_LTX_VAE_HPP__
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
@ -144,25 +143,16 @@ namespace LTXVAE {
|
||||
std::vector<ggml_tensor*>& feat_map,
|
||||
int& feat_idx,
|
||||
int chunk_idx,
|
||||
bool causal = true,
|
||||
int temporal_pad = 0) {
|
||||
bool causal = true) {
|
||||
auto conv = std::dynamic_pointer_cast<Conv3d>(blocks["conv"]);
|
||||
const int pad = causal ? (time_kernel_size - 1) : (time_kernel_size - 1) / 2;
|
||||
|
||||
ggml_tensor* prev = (feat_idx < (int)feat_map.size()) ? feat_map[feat_idx] : nullptr;
|
||||
|
||||
GGML_ASSERT(x->ne[2] >= temporal_pad);
|
||||
|
||||
int end_idx = x->ne[2] - temporal_pad;
|
||||
int start_idx = std::max(end_idx - pad, 0);
|
||||
|
||||
// Save a contiguous copy of the last `pad` frames so the large `x`
|
||||
// tensor is not kept alive across iterations by a dangling view.
|
||||
if (feat_idx < (int)feat_map.size() && end_idx - start_idx > 0) {
|
||||
GGML_ASSERT(start_idx >= 0);
|
||||
GGML_ASSERT(end_idx > 0);
|
||||
|
||||
auto slice = ggml_ext_slice(ctx->ggml_ctx, x, 2, start_idx, end_idx);
|
||||
if (feat_idx < (int)feat_map.size() && pad > 0 && x->ne[2] >= pad) {
|
||||
auto slice = ggml_ext_slice(ctx->ggml_ctx, x, 2, x->ne[2] - pad, x->ne[2]);
|
||||
feat_map[feat_idx] = ggml_cont(ctx->ggml_ctx, slice);
|
||||
}
|
||||
feat_idx++;
|
||||
@ -294,8 +284,7 @@ namespace LTXVAE {
|
||||
bool causal,
|
||||
std::vector<ggml_tensor*>& feat_map,
|
||||
int& feat_idx,
|
||||
int chunk_idx,
|
||||
int temporal_pad = 0) {
|
||||
int chunk_idx) {
|
||||
auto norm1 = std::dynamic_pointer_cast<PixelNorm3D>(blocks["norm1"]);
|
||||
auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
|
||||
auto norm2 = std::dynamic_pointer_cast<PixelNorm3D>(blocks["norm2"]);
|
||||
@ -322,14 +311,14 @@ namespace LTXVAE {
|
||||
h = apply_scale_shift(ctx->ggml_ctx, h, scale1, shift1);
|
||||
}
|
||||
h = ggml_silu_inplace(ctx->ggml_ctx, h);
|
||||
h = conv1->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal, temporal_pad);
|
||||
h = conv1->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal);
|
||||
|
||||
h = norm2->forward(ctx, h);
|
||||
if (timestep_conditioning) {
|
||||
h = apply_scale_shift(ctx->ggml_ctx, h, scale2, shift2);
|
||||
}
|
||||
h = ggml_silu_inplace(ctx->ggml_ctx, h);
|
||||
h = conv2->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal, temporal_pad);
|
||||
h = conv2->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal);
|
||||
|
||||
return ggml_add(ctx->ggml_ctx, h, x);
|
||||
}
|
||||
@ -378,8 +367,7 @@ namespace LTXVAE {
|
||||
bool causal,
|
||||
std::vector<ggml_tensor*>& feat_map,
|
||||
int& feat_idx,
|
||||
int chunk_idx,
|
||||
int temporal_pad = 0) {
|
||||
int chunk_idx) {
|
||||
ggml_tensor* timestep_embed = nullptr;
|
||||
if (timestep_conditioning) {
|
||||
GGML_ASSERT(timestep != nullptr);
|
||||
@ -388,7 +376,7 @@ namespace LTXVAE {
|
||||
}
|
||||
for (int i = 0; i < num_layers; i++) {
|
||||
auto resnet = std::dynamic_pointer_cast<ResnetBlock3D>(blocks["res_blocks." + std::to_string(i)]);
|
||||
x = resnet->forward(ctx, x, timestep_embed, causal, feat_map, feat_idx, chunk_idx, temporal_pad);
|
||||
x = resnet->forward(ctx, x, timestep_embed, causal, feat_map, feat_idx, chunk_idx);
|
||||
}
|
||||
return x;
|
||||
}
|
||||
@ -449,8 +437,7 @@ namespace LTXVAE {
|
||||
bool causal,
|
||||
std::vector<ggml_tensor*>& feat_map,
|
||||
int& feat_idx,
|
||||
int chunk_idx,
|
||||
int temporal_pad = 0) {
|
||||
int chunk_idx) {
|
||||
auto conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv"]);
|
||||
|
||||
bool drop_first = (chunk_idx == 0) && (factor_t > 1);
|
||||
@ -466,7 +453,7 @@ namespace LTXVAE {
|
||||
x_in = res;
|
||||
}
|
||||
|
||||
x = conv->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal, temporal_pad);
|
||||
x = conv->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal);
|
||||
x = depth_to_space_3d(ctx->ggml_ctx, x, get_output_channels(), factor_t, factor_s, drop_first);
|
||||
if (residual) {
|
||||
x = ggml_add(ctx->ggml_ctx, x, x_in);
|
||||
@ -999,8 +986,7 @@ namespace LTXVAE {
|
||||
ggml_tensor* timestep,
|
||||
std::vector<ggml_tensor*>& feat_map,
|
||||
int& feat_idx,
|
||||
int chunk_idx,
|
||||
int& temporal_pad) {
|
||||
int chunk_idx) {
|
||||
auto conv_in = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv_in"]);
|
||||
auto conv_norm_out = std::dynamic_pointer_cast<PixelNorm3D>(blocks["conv_norm_out"]);
|
||||
auto conv_out = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv_out"]);
|
||||
@ -1012,7 +998,7 @@ namespace LTXVAE {
|
||||
}
|
||||
|
||||
// conv_in with feat_map for left temporal context
|
||||
x = conv_in->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder, temporal_pad);
|
||||
x = conv_in->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder);
|
||||
|
||||
// up_blocks
|
||||
int block_idx = 0;
|
||||
@ -1020,13 +1006,12 @@ namespace LTXVAE {
|
||||
auto mid_block = std::dynamic_pointer_cast<UNetMidBlock3D>(blocks["up_blocks." + std::to_string(block_idx)]);
|
||||
if (mid_block) {
|
||||
x = mid_block->forward(ctx, x, scaled_timestep, causal_decoder,
|
||||
feat_map, feat_idx, chunk_idx, temporal_pad);
|
||||
feat_map, feat_idx, chunk_idx);
|
||||
} else {
|
||||
auto upsample = std::dynamic_pointer_cast<DepthToSpaceUpsample>(
|
||||
blocks["up_blocks." + std::to_string(block_idx)]);
|
||||
x = upsample->forward(ctx, x, causal_decoder,
|
||||
feat_map, feat_idx, chunk_idx, temporal_pad);
|
||||
temporal_pad *= upsample->factor_t;
|
||||
feat_map, feat_idx, chunk_idx);
|
||||
}
|
||||
block_idx++;
|
||||
}
|
||||
@ -1043,7 +1028,7 @@ namespace LTXVAE {
|
||||
x = apply_scale_shift(ctx->ggml_ctx, x, scale, shift);
|
||||
}
|
||||
x = ggml_silu_inplace(ctx->ggml_ctx, x);
|
||||
x = conv_out->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder, temporal_pad);
|
||||
x = conv_out->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder);
|
||||
return x;
|
||||
}
|
||||
};
|
||||
@ -1099,9 +1084,7 @@ namespace LTXVAE {
|
||||
// tensors can be freed by GGML before the next iteration starts.
|
||||
ggml_tensor* decode_tiled(GGMLRunnerContext* ctx,
|
||||
ggml_tensor* z,
|
||||
ggml_tensor* timestep,
|
||||
int temporal_window_size = 1,
|
||||
int temporal_tile_overlap = 0) {
|
||||
ggml_tensor* timestep) {
|
||||
auto decoder = std::dynamic_pointer_cast<Decoder>(blocks["decoder"]);
|
||||
auto processor = std::dynamic_pointer_cast<PerChannelStatistics>(blocks["per_channel_statistics"]);
|
||||
auto latents = processor->un_normalize(ctx, z);
|
||||
@ -1116,69 +1099,18 @@ namespace LTXVAE {
|
||||
// 128 slots is generous enough for any supported decoder configuration.
|
||||
std::vector<ggml_tensor*> feat_map(128, nullptr);
|
||||
|
||||
// Ensure window size is at least 1
|
||||
int window = std::max(1, temporal_window_size);
|
||||
int overlap = std::max(0, temporal_tile_overlap);
|
||||
|
||||
if (overlap >= window) {
|
||||
LOG_WARN("temporal_tile_overlap (%d) is greater than or equal to temporal_tile_frames (%d), adjusting values to avoid empty decode windows",
|
||||
overlap, window);
|
||||
overlap = window - 1;
|
||||
}
|
||||
LOG_DEBUG("Using temporal tiling: temporal_tile_frames = %d, temporal_tile_overlap = %d, total frames = %d, resulting in %d tiles",
|
||||
window,
|
||||
overlap,
|
||||
(int)T,
|
||||
(T + window - overlap - 1) / (window - overlap));
|
||||
ggml_tensor* out = nullptr;
|
||||
for (int i = 0; i < (int)T - overlap; i += (window - overlap)) {
|
||||
for (int i = 0; i < (int)T; i++) {
|
||||
int feat_idx = 0;
|
||||
|
||||
// Calculate the end index for the current temporal chunk
|
||||
int end_i = std::min((int)T, i + window);
|
||||
if (end_i >= (int)T) {
|
||||
overlap = 0; // avoid overlap issues in the last chunk
|
||||
}
|
||||
|
||||
int chunk_overlap = overlap; // modified by forward_tiled_frame temporal inflation
|
||||
|
||||
auto z_chunk = ggml_ext_slice(ctx->ggml_ctx, latents, 2, i, end_i);
|
||||
|
||||
auto out_chunk = decoder->forward_tiled_frame(ctx, z_chunk, timestep,
|
||||
feat_map, feat_idx, i, chunk_overlap);
|
||||
|
||||
// discard overlap frames if it's not the final chunk
|
||||
if (overlap > 0 && end_i < (int)T) {
|
||||
out_chunk = ggml_ext_slice(ctx->ggml_ctx, out_chunk, 2, 0, out_chunk->ne[2] - chunk_overlap);
|
||||
}
|
||||
|
||||
out = (out == nullptr) ? out_chunk : ggml_concat(ctx->ggml_ctx, out, out_chunk, 2);
|
||||
auto z_i = ggml_ext_slice(ctx->ggml_ctx, latents, 2, i, i + 1);
|
||||
auto out_i = decoder->forward_tiled_frame(ctx, z_i, timestep,
|
||||
feat_map, feat_idx, i);
|
||||
out = (out == nullptr) ? out_i : ggml_concat(ctx->ggml_ctx, out, out_i, 2);
|
||||
}
|
||||
|
||||
return WAN::WanVAE::unpatchify(ctx->ggml_ctx, out, patch_size, 1);
|
||||
}
|
||||
|
||||
ggml_tensor* decode_tiled_chunk(GGMLRunnerContext* ctx,
|
||||
ggml_tensor* z,
|
||||
ggml_tensor* timestep,
|
||||
std::vector<ggml_tensor*>& feat_map,
|
||||
int chunk_idx,
|
||||
int temporal_tile_overlap,
|
||||
int& feat_idx) {
|
||||
auto decoder = std::dynamic_pointer_cast<Decoder>(blocks["decoder"]);
|
||||
auto processor = std::dynamic_pointer_cast<PerChannelStatistics>(blocks["per_channel_statistics"]);
|
||||
auto latents = processor->un_normalize(ctx, z);
|
||||
|
||||
feat_idx = 0;
|
||||
int chunk_overlap = temporal_tile_overlap; // modified by forward_tiled_frame temporal inflation
|
||||
auto out_chunk = decoder->forward_tiled_frame(ctx, latents, timestep,
|
||||
feat_map, feat_idx, chunk_idx, chunk_overlap);
|
||||
if (chunk_overlap > 0) {
|
||||
out_chunk = ggml_ext_slice(ctx->ggml_ctx, out_chunk, 2, 0, out_chunk->ne[2] - chunk_overlap);
|
||||
}
|
||||
return WAN::WanVAE::unpatchify(ctx->ggml_ctx, out_chunk, patch_size, 1);
|
||||
}
|
||||
|
||||
ggml_tensor* encode(GGMLRunnerContext* ctx,
|
||||
ggml_tensor* x) {
|
||||
GGML_ASSERT(!decode_only);
|
||||
@ -1208,13 +1140,8 @@ namespace LTXVAE {
|
||||
} // namespace LTXVAE
|
||||
|
||||
struct LTXVideoVAE : public VAE {
|
||||
static constexpr int DEFAULT_TEMPORAL_TILE_FRAMES = 4;
|
||||
static constexpr int DEFAULT_TEMPORAL_TILE_OVERLAP = 1;
|
||||
|
||||
bool decode_only;
|
||||
bool temporal_tiling_enabled = false;
|
||||
int temporal_tile_frames = DEFAULT_TEMPORAL_TILE_FRAMES;
|
||||
int temporal_tile_overlap = DEFAULT_TEMPORAL_TILE_OVERLAP;
|
||||
int ltx_vae_version;
|
||||
bool timestep_conditioning;
|
||||
int patch_size;
|
||||
@ -1251,64 +1178,10 @@ struct LTXVideoVAE : public VAE {
|
||||
temporal_tiling_enabled = enabled;
|
||||
}
|
||||
|
||||
void set_tiling_params(const sd_tiling_params_t& params) override {
|
||||
temporal_tiling_enabled = params.temporal_tiling;
|
||||
temporal_tile_frames = DEFAULT_TEMPORAL_TILE_FRAMES;
|
||||
temporal_tile_overlap = DEFAULT_TEMPORAL_TILE_OVERLAP;
|
||||
|
||||
for (const auto& [key, value] : parse_key_value_args(params.extra_tiling_args, "LTX VAE extra tiling arg")) {
|
||||
int parsed = 0;
|
||||
if (!parse_strict_int(value, parsed)) {
|
||||
LOG_WARN("ignoring invalid LTX VAE extra tiling arg '%s=%s'", key.c_str(), value.c_str());
|
||||
} else if (key == "temporal_tile_frames") {
|
||||
temporal_tile_frames = std::max(1, parsed);
|
||||
} else if (key == "temporal_tile_overlap") {
|
||||
temporal_tile_overlap = std::max(0, parsed);
|
||||
} else {
|
||||
LOG_WARN("ignoring unknown LTX VAE extra tiling arg '%s'", key.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) override {
|
||||
vae.get_param_tensors(tensors, prefix);
|
||||
}
|
||||
|
||||
struct TemporalTilePlan {
|
||||
int frames = 1;
|
||||
int overlap = 0;
|
||||
int stride = 1;
|
||||
int num_tiles = 1;
|
||||
};
|
||||
|
||||
TemporalTilePlan resolve_temporal_tile_plan(int64_t total_frames) const {
|
||||
TemporalTilePlan plan;
|
||||
plan.frames = std::max(1, temporal_tile_frames);
|
||||
plan.overlap = std::max(0, temporal_tile_overlap);
|
||||
|
||||
if (plan.overlap >= plan.frames) {
|
||||
LOG_WARN("temporal_tile_overlap (%d) is greater than or equal to temporal_tile_frames (%d), adjusting values to avoid empty decode windows",
|
||||
plan.overlap,
|
||||
plan.frames);
|
||||
plan.overlap = plan.frames - 1;
|
||||
}
|
||||
if (total_frames > 1 && plan.overlap >= total_frames) {
|
||||
LOG_WARN("temporal_tile_overlap (%d) is greater than or equal to total latent frames (%lld), adjusting values to decode at least one tile",
|
||||
plan.overlap,
|
||||
(long long)total_frames);
|
||||
plan.overlap = static_cast<int>(total_frames - 1);
|
||||
}
|
||||
|
||||
plan.stride = std::max(1, plan.frames - plan.overlap);
|
||||
int64_t tiled_frames = std::max<int64_t>(1, total_frames - plan.overlap);
|
||||
plan.num_tiles = total_frames > 0 ? static_cast<int>((tiled_frames + plan.stride - 1) / plan.stride) : 0;
|
||||
return plan;
|
||||
}
|
||||
|
||||
std::string temporal_feat_cache_name(size_t feat_idx) const {
|
||||
return "ltx_vae_temporal_feat:" + std::to_string(feat_idx);
|
||||
}
|
||||
|
||||
ggml_cgraph* build_graph(const sd::Tensor<float>& z_tensor, bool decode_graph) {
|
||||
ggml_cgraph* gf = new_graph_custom(20480);
|
||||
ggml_tensor* z = make_input(z_tensor);
|
||||
@ -1319,97 +1192,18 @@ struct LTXVideoVAE : public VAE {
|
||||
|
||||
auto runner_ctx = get_context();
|
||||
ggml_tensor* out;
|
||||
out = decode_graph ? vae.decode(&runner_ctx, z, timestep) : vae.encode(&runner_ctx, z);
|
||||
bool use_tiled = decode_graph && temporal_tiling_enabled &&
|
||||
z_tensor.dim() == 5 && z_tensor.shape()[2] > 1;
|
||||
if (use_tiled) {
|
||||
out = vae.decode_tiled(&runner_ctx, z, timestep);
|
||||
} else {
|
||||
out = decode_graph ? vae.decode(&runner_ctx, z, timestep) : vae.encode(&runner_ctx, z);
|
||||
}
|
||||
ggml_build_forward_expand(gf, out);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
ggml_cgraph* build_temporal_tile_graph(const sd::Tensor<float>& z_chunk_tensor,
|
||||
int chunk_idx,
|
||||
int chunk_overlap) {
|
||||
ggml_cgraph* gf = new_graph_custom(20480);
|
||||
ggml_tensor* z = make_input(z_chunk_tensor);
|
||||
ggml_tensor* timestep = nullptr;
|
||||
if (timestep_conditioning) {
|
||||
timestep = make_input(decode_timestep_tensor);
|
||||
}
|
||||
|
||||
std::vector<ggml_tensor*> feat_map(128, nullptr);
|
||||
for (size_t feat_idx = 0; feat_idx < feat_map.size(); ++feat_idx) {
|
||||
feat_map[feat_idx] = get_cache_tensor_by_name(temporal_feat_cache_name(feat_idx));
|
||||
}
|
||||
|
||||
auto runner_ctx = get_context();
|
||||
int feat_count = 0;
|
||||
ggml_tensor* out = vae.decode_tiled_chunk(&runner_ctx,
|
||||
z,
|
||||
timestep,
|
||||
feat_map,
|
||||
chunk_idx,
|
||||
chunk_overlap,
|
||||
feat_count);
|
||||
|
||||
for (int feat_idx = 0; feat_idx < feat_count && feat_idx < static_cast<int>(feat_map.size()); ++feat_idx) {
|
||||
ggml_tensor* feat_cache = feat_map[static_cast<size_t>(feat_idx)];
|
||||
if (feat_cache != nullptr) {
|
||||
cache(temporal_feat_cache_name(static_cast<size_t>(feat_idx)), feat_cache);
|
||||
ggml_build_forward_expand(gf, feat_cache);
|
||||
}
|
||||
}
|
||||
|
||||
ggml_build_forward_expand(gf, out);
|
||||
return gf;
|
||||
}
|
||||
|
||||
sd::Tensor<float> decode_temporal_tiled_streaming(const int n_threads,
|
||||
const sd::Tensor<float>& input,
|
||||
size_t expected_dim) {
|
||||
const int64_t total_frames = input.shape()[2];
|
||||
TemporalTilePlan plan = resolve_temporal_tile_plan(total_frames);
|
||||
|
||||
LOG_DEBUG("Using streaming temporal tiling: temporal_tile_frames=%d, temporal_tile_overlap=%d, total latent frames=%lld, resulting in %d tiles",
|
||||
plan.frames,
|
||||
plan.overlap,
|
||||
(long long)total_frames,
|
||||
plan.num_tiles);
|
||||
|
||||
free_cache_ctx_and_buffer();
|
||||
cache_tensor_map.clear();
|
||||
|
||||
sd::Tensor<float> output;
|
||||
for (int64_t start = 0; start < total_frames - plan.overlap; start += plan.stride) {
|
||||
const int64_t end = std::min<int64_t>(total_frames, start + plan.frames);
|
||||
const int chunk_overlap = end < total_frames ? plan.overlap : 0;
|
||||
auto z_chunk = sd::ops::slice(input, 2, start, end);
|
||||
|
||||
LOG_DEBUG("LTX VAE temporal tile %lld/%d: latent frames [%lld, %lld), overlap=%d",
|
||||
(long long)(start / plan.stride + 1),
|
||||
plan.num_tiles,
|
||||
(long long)start,
|
||||
(long long)end,
|
||||
chunk_overlap);
|
||||
|
||||
auto get_graph = [&]() -> ggml_cgraph* {
|
||||
return build_temporal_tile_graph(z_chunk,
|
||||
static_cast<int>(start),
|
||||
chunk_overlap);
|
||||
};
|
||||
auto chunk = restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, true),
|
||||
expected_dim);
|
||||
if (chunk.empty()) {
|
||||
free_cache_ctx_and_buffer();
|
||||
cache_tensor_map.clear();
|
||||
return {};
|
||||
}
|
||||
output = output.empty() ? std::move(chunk) : sd::ops::concat(output, chunk, 2);
|
||||
}
|
||||
|
||||
free_cache_ctx_and_buffer();
|
||||
cache_tensor_map.clear();
|
||||
return output;
|
||||
}
|
||||
|
||||
ggml_cgraph* build_latent_statistics_graph(const sd::Tensor<float>& z_tensor, bool normalize) {
|
||||
ggml_cgraph* gf = new_graph_custom(1024);
|
||||
ggml_tensor* z = make_input(z_tensor);
|
||||
@ -1445,9 +1239,6 @@ struct LTXVideoVAE : public VAE {
|
||||
input = sd::ops::slice(input, 2, 0, cropped_t);
|
||||
}
|
||||
}
|
||||
if (decode_graph && temporal_tiling_enabled && input.dim() == 5 && input.shape()[2] > 1) {
|
||||
return decode_temporal_tiled_streaming(n_threads, input, expected_dim);
|
||||
}
|
||||
auto get_graph = [&]() -> ggml_cgraph* {
|
||||
return build_graph(input, decode_graph);
|
||||
};
|
||||
|
||||
@ -151,7 +151,7 @@ public:
|
||||
bool apply_lora_immediately = false;
|
||||
|
||||
std::string taesd_path;
|
||||
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0, 0, nullptr};
|
||||
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0, 0};
|
||||
bool offload_params_to_cpu = false;
|
||||
float max_vram = 0.f;
|
||||
bool use_pmid = false;
|
||||
@ -2679,7 +2679,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) {
|
||||
sd_img_gen_params->batch_count = 1;
|
||||
sd_img_gen_params->control_strength = 0.9f;
|
||||
sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f};
|
||||
sd_img_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr};
|
||||
sd_img_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f};
|
||||
sd_cache_params_init(&sd_img_gen_params->cache);
|
||||
sd_hires_params_init(&sd_img_gen_params->hires);
|
||||
}
|
||||
@ -2708,7 +2708,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
|
||||
"increase_ref_index: %s\n"
|
||||
"control_strength: %.2f\n"
|
||||
"photo maker: {style_strength = %.2f, id_images_count = %d, id_embed_path = %s}\n"
|
||||
"VAE tiling: %s (temporal=%s, extra_tiling_args=%s)\n"
|
||||
"VAE tiling: %s (temporal=%s)\n"
|
||||
"hires: {enabled=%s, upscaler=%s, model_path=%s, scale=%.2f, target=%dx%d, steps=%d, denoising_strength=%.2f}\n",
|
||||
SAFE_STR(sd_img_gen_params->prompt),
|
||||
SAFE_STR(sd_img_gen_params->negative_prompt),
|
||||
@ -2728,7 +2728,6 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
|
||||
SAFE_STR(sd_img_gen_params->pm_params.id_embed_path),
|
||||
BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled),
|
||||
BOOL_STR(sd_img_gen_params->vae_tiling_params.temporal_tiling),
|
||||
SAFE_STR(sd_img_gen_params->vae_tiling_params.extra_tiling_args),
|
||||
BOOL_STR(sd_img_gen_params->hires.enabled),
|
||||
sd_hires_upscaler_name(sd_img_gen_params->hires.upscaler),
|
||||
SAFE_STR(sd_img_gen_params->hires.model_path),
|
||||
@ -2766,7 +2765,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
|
||||
sd_vid_gen_params->fps = 16;
|
||||
sd_vid_gen_params->moe_boundary = 0.875f;
|
||||
sd_vid_gen_params->vace_strength = 1.f;
|
||||
sd_vid_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr};
|
||||
sd_vid_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f};
|
||||
sd_vid_gen_params->hires.enabled = false;
|
||||
sd_vid_gen_params->hires.upscaler = SD_HIRES_UPSCALER_LATENT;
|
||||
sd_vid_gen_params->hires.scale = 2.f;
|
||||
|
||||
14
src/tae.hpp
14
src/tae.hpp
@ -265,7 +265,7 @@ class WideMemBlock : public GGMLBlock {
|
||||
public:
|
||||
WideMemBlock(int channels, int out_channels)
|
||||
: has_skip_conv(channels != out_channels) {
|
||||
int groups = std::max(1, out_channels / 64);
|
||||
int groups = std::max(1, out_channels / 64);
|
||||
blocks["conv.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels * 2, out_channels, {1, 1}, {1, 1}));
|
||||
blocks["conv.2"] = std::shared_ptr<GGMLBlock>(new Conv2d_grouped(out_channels, out_channels, groups, {3, 3}, {1, 1}, {1, 1}));
|
||||
blocks["conv.4"] = std::shared_ptr<GGMLBlock>(new Conv2d(out_channels, out_channels, {1, 1}, {1, 1}));
|
||||
@ -479,12 +479,12 @@ public:
|
||||
int index = 3;
|
||||
for (int i = 0; i < num_layers; i++) {
|
||||
for (int j = 0; j < num_blocks; j++) {
|
||||
auto mem = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1, 0);
|
||||
mem = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0);
|
||||
auto mem = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1, 0);
|
||||
mem = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0);
|
||||
if (is_wide) {
|
||||
auto block = std::dynamic_pointer_cast<WideMemBlock>(blocks[std::to_string(index++)]);
|
||||
h = block->forward(ctx, h, mem);
|
||||
} else {
|
||||
} else{
|
||||
auto block = std::dynamic_pointer_cast<MemBlock>(blocks[std::to_string(index++)]);
|
||||
h = block->forward(ctx, h, mem);
|
||||
}
|
||||
@ -683,8 +683,8 @@ struct TinyImageAutoEncoder : public VAE {
|
||||
struct TinyVideoAutoEncoder : public VAE {
|
||||
TAEHV taehv;
|
||||
bool decode_only = false;
|
||||
bool is_wide = false;
|
||||
|
||||
bool is_wide = false;
|
||||
|
||||
TinyVideoAutoEncoder(ggml_backend_t backend,
|
||||
ggml_backend_t params_backend,
|
||||
const String2TensorStorage& tensor_storage_map,
|
||||
@ -699,7 +699,7 @@ struct TinyVideoAutoEncoder : public VAE {
|
||||
break;
|
||||
}
|
||||
}
|
||||
taehv = TAEHV(decoder_only, version, is_wide);
|
||||
taehv = TAEHV(decoder_only, version, is_wide);
|
||||
scale_input = false;
|
||||
taehv.init(params_ctx, tensor_storage_map, prefix);
|
||||
}
|
||||
|
||||
84
src/util.cpp
84
src/util.cpp
@ -1,10 +1,8 @@
|
||||
#include "util.h"
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <cmath>
|
||||
#include <codecvt>
|
||||
#include <cstdarg>
|
||||
#include <exception>
|
||||
#include <fstream>
|
||||
#include <locale>
|
||||
#include <regex>
|
||||
@ -408,88 +406,6 @@ std::vector<std::string> split_string(const std::string& str, char delimiter) {
|
||||
return result;
|
||||
}
|
||||
|
||||
KeyValueArgs parse_key_value_args(const char* args, const char* context) {
|
||||
KeyValueArgs pairs;
|
||||
|
||||
if (args == nullptr || args[0] == '\0') {
|
||||
return pairs;
|
||||
}
|
||||
|
||||
std::string raw(args);
|
||||
size_t start = 0;
|
||||
for (size_t pos = 0; pos <= raw.size(); ++pos) {
|
||||
if (pos != raw.size() && raw[pos] != ',' && raw[pos] != ';') {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string token = trim(raw.substr(start, pos - start));
|
||||
if (!token.empty()) {
|
||||
size_t eq = token.find('=');
|
||||
if (eq == std::string::npos) {
|
||||
const char* log_context = context ? context : "key=value arg";
|
||||
LOG_WARN("ignoring malformed %s '%s'", log_context, token.c_str());
|
||||
} else {
|
||||
std::string key = trim(token.substr(0, eq));
|
||||
std::string value = trim(token.substr(eq + 1));
|
||||
pairs.emplace_back(std::move(key), std::move(value));
|
||||
}
|
||||
}
|
||||
|
||||
start = pos + 1;
|
||||
}
|
||||
|
||||
return pairs;
|
||||
}
|
||||
|
||||
KeyValueArgs parse_key_value_args(const std::string& args, const char* context) {
|
||||
return parse_key_value_args(args.c_str(), context);
|
||||
}
|
||||
|
||||
bool parse_strict_float(const std::string& text, float& value) {
|
||||
try {
|
||||
size_t consumed = 0;
|
||||
float parsed = std::stof(text, &consumed);
|
||||
if (!trim(text.substr(consumed)).empty()) {
|
||||
return false;
|
||||
}
|
||||
value = parsed;
|
||||
return true;
|
||||
} catch (const std::exception&) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool parse_strict_int(const std::string& text, int& value) {
|
||||
try {
|
||||
size_t consumed = 0;
|
||||
int parsed = std::stoi(text, &consumed);
|
||||
if (!trim(text.substr(consumed)).empty()) {
|
||||
return false;
|
||||
}
|
||||
value = parsed;
|
||||
return true;
|
||||
} catch (const std::exception&) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool parse_strict_bool(const std::string& text, bool& value) {
|
||||
std::string lowered = trim(text);
|
||||
std::transform(lowered.begin(), lowered.end(), lowered.begin(), [](unsigned char c) {
|
||||
return static_cast<char>(std::tolower(c));
|
||||
});
|
||||
|
||||
if (lowered == "1" || lowered == "true" || lowered == "yes" || lowered == "on") {
|
||||
value = true;
|
||||
return true;
|
||||
}
|
||||
if (lowered == "0" || lowered == "false" || lowered == "no" || lowered == "off") {
|
||||
value = false;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static std::string build_progress_bar(int step, int steps) {
|
||||
std::string progress = " |";
|
||||
int max_progress = 50;
|
||||
|
||||
10
src/util.h
10
src/util.h
@ -4,7 +4,6 @@
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "ggml-backend.h"
|
||||
@ -66,15 +65,6 @@ protected:
|
||||
|
||||
std::string path_join(const std::string& p1, const std::string& p2);
|
||||
std::vector<std::string> split_string(const std::string& str, char delimiter);
|
||||
|
||||
using KeyValueArgs = std::vector<std::pair<std::string, std::string>>;
|
||||
|
||||
KeyValueArgs parse_key_value_args(const char* args, const char* context = "key=value arg");
|
||||
KeyValueArgs parse_key_value_args(const std::string& args, const char* context = "key=value arg");
|
||||
bool parse_strict_float(const std::string& text, float& value);
|
||||
bool parse_strict_int(const std::string& text, int& value);
|
||||
bool parse_strict_bool(const std::string& text, bool& value);
|
||||
|
||||
void pretty_progress(int step, int steps, float time);
|
||||
void pretty_bytes_progress(int step, int steps, uint64_t bytes_processed, float elapsed_seconds);
|
||||
|
||||
|
||||
@ -167,7 +167,6 @@ public:
|
||||
int64_t t0 = ggml_time_ms();
|
||||
sd::Tensor<float> input = x;
|
||||
sd::Tensor<float> output;
|
||||
set_tiling_params(tiling_params);
|
||||
|
||||
if (tiling_params.enabled) {
|
||||
const int scale_factor = get_scale_factor();
|
||||
@ -217,9 +216,6 @@ public:
|
||||
virtual void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) = 0;
|
||||
virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); };
|
||||
virtual void set_temporal_tiling_enabled(bool enabled) { SD_UNUSED(enabled); };
|
||||
virtual void set_tiling_params(const sd_tiling_params_t& params) {
|
||||
set_temporal_tiling_enabled(params.temporal_tiling);
|
||||
};
|
||||
};
|
||||
|
||||
struct FakeVAE : public VAE {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user