mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-09 15:56:39 +00:00
feat: add ltx2.3 support (#1463)
* add GemmaTokenizer * add basic ltx2.3 support * change vocab file encoding * fix ci * fix ubuntu build * add temporal tiling support * add ltx audio support * update ggml submodule url * fix generate_video * add i2v support * minify bundled Gemma tokenizer vocab sources * pass video fps into temporal rope embeddings * fix av_ca_timestep_scale_multiplier * add LTX2Scheduler support * update docs * fix ci
This commit is contained in:
parent
3b4d26f3d9
commit
67dda3f897
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
@ -135,7 +135,7 @@ jobs:
|
|||||||
id: depends
|
id: depends
|
||||||
run: |
|
run: |
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install build-essential libvulkan-dev glslc
|
sudo apt-get install build-essential libvulkan-dev glslc spirv-headers
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
id: cmake_build
|
id: cmake_build
|
||||||
|
|||||||
2
.gitmodules
vendored
2
.gitmodules
vendored
@ -1,6 +1,6 @@
|
|||||||
[submodule "ggml"]
|
[submodule "ggml"]
|
||||||
path = ggml
|
path = ggml
|
||||||
url = https://github.com/ggml-org/ggml.git
|
url = https://github.com/leejet/ggml.git
|
||||||
[submodule "examples/server/frontend"]
|
[submodule "examples/server/frontend"]
|
||||||
path = examples/server/frontend
|
path = examples/server/frontend
|
||||||
url = https://github.com/leejet/sdcpp-webui.git
|
url = https://github.com/leejet/sdcpp-webui.git
|
||||||
|
|||||||
@ -13,7 +13,9 @@ if (MSVC)
|
|||||||
add_compile_definitions(_SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING)
|
add_compile_definitions(_SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING)
|
||||||
add_compile_options(
|
add_compile_options(
|
||||||
$<$<COMPILE_LANGUAGE:C>:/MP>
|
$<$<COMPILE_LANGUAGE:C>:/MP>
|
||||||
|
$<$<COMPILE_LANGUAGE:C>:/utf-8>
|
||||||
$<$<COMPILE_LANGUAGE:CXX>:/MP>
|
$<$<COMPILE_LANGUAGE:CXX>:/MP>
|
||||||
|
$<$<COMPILE_LANGUAGE:CXX>:/utf-8>
|
||||||
)
|
)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|||||||
@ -2,7 +2,7 @@ ARG UBUNTU_VERSION=24.04
|
|||||||
|
|
||||||
FROM ubuntu:$UBUNTU_VERSION AS build
|
FROM ubuntu:$UBUNTU_VERSION AS build
|
||||||
|
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends build-essential git cmake libvulkan-dev glslc
|
RUN apt-get update && apt-get install -y --no-install-recommends build-essential git cmake libvulkan-dev glslc spirv-headers
|
||||||
|
|
||||||
WORKDIR /sd.cpp
|
WORKDIR /sd.cpp
|
||||||
|
|
||||||
|
|||||||
@ -64,6 +64,7 @@ API and command-line option may change frequently.***
|
|||||||
- [Qwen Image Edit series](./docs/qwen_image_edit.md)
|
- [Qwen Image Edit series](./docs/qwen_image_edit.md)
|
||||||
- Video Models
|
- Video Models
|
||||||
- [Wan2.1/Wan2.2](./docs/wan.md)
|
- [Wan2.1/Wan2.2](./docs/wan.md)
|
||||||
|
- [LTX-2.3](./docs/ltx2.md)
|
||||||
- [PhotoMaker](https://github.com/TencentARC/PhotoMaker) support.
|
- [PhotoMaker](https://github.com/TencentARC/PhotoMaker) support.
|
||||||
- Control Net support with SD 1.5
|
- Control Net support with SD 1.5
|
||||||
- LoRA support, same as [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#lora)
|
- LoRA support, same as [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#lora)
|
||||||
@ -147,6 +148,7 @@ For runtime and parameter backend placement, see the [backend selection guide](.
|
|||||||
- [🔥Qwen Image](./docs/qwen_image.md)
|
- [🔥Qwen Image](./docs/qwen_image.md)
|
||||||
- [🔥Qwen Image Edit series](./docs/qwen_image_edit.md)
|
- [🔥Qwen Image Edit series](./docs/qwen_image_edit.md)
|
||||||
- [🔥Wan2.1/Wan2.2](./docs/wan.md)
|
- [🔥Wan2.1/Wan2.2](./docs/wan.md)
|
||||||
|
- [🔥LTX-2.3](./docs/ltx2.md)
|
||||||
- [🔥Z-Image](./docs/z_image.md)
|
- [🔥Z-Image](./docs/z_image.md)
|
||||||
- [Ovis-Image](./docs/ovis_image.md)
|
- [Ovis-Image](./docs/ovis_image.md)
|
||||||
- [Anima](./docs/anima.md)
|
- [Anima](./docs/anima.md)
|
||||||
|
|||||||
BIN
assets/ltx2/i2v.webm
Normal file
BIN
assets/ltx2/i2v.webm
Normal file
Binary file not shown.
BIN
assets/ltx2/t2v.webm
Normal file
BIN
assets/ltx2/t2v.webm
Normal file
Binary file not shown.
@ -102,6 +102,11 @@ cmake --build . --config Release
|
|||||||
## Build with Vulkan
|
## Build with Vulkan
|
||||||
|
|
||||||
Install Vulkan SDK from https://www.lunarg.com/vulkan-sdk/.
|
Install Vulkan SDK from https://www.lunarg.com/vulkan-sdk/.
|
||||||
|
On Ubuntu, install the Vulkan development packages and SPIR-V headers:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
sudo apt-get install build-essential libvulkan-dev glslc spirv-headers
|
||||||
|
```
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
mkdir build && cd build
|
mkdir build && cd build
|
||||||
|
|||||||
41
docs/ltx2.md
Normal file
41
docs/ltx2.md
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
# How to Use
|
||||||
|
|
||||||
|
## Download weights
|
||||||
|
|
||||||
|
- Download LTX-2.3
|
||||||
|
- safetensors: https://huggingface.co/Kijai/LTX2.3_comfy/tree/main/diffusion_models
|
||||||
|
- gguf: https://huggingface.co/unsloth/LTX-2.3-GGUF/tree/main
|
||||||
|
- Download gemma-3-12b-it
|
||||||
|
- gguf: https://huggingface.co/unsloth/gemma-3-12b-it-GGUF/tree/main
|
||||||
|
- Download embeddings connectors
|
||||||
|
- safetensors: https://huggingface.co/unsloth/LTX-2.3-GGUF/tree/main/text_encoders
|
||||||
|
- Download vae
|
||||||
|
- safetensors: https://huggingface.co/unsloth/LTX-2.3-GGUF/tree/main/vae
|
||||||
|
- Download audio vae
|
||||||
|
- safetensors: https://huggingface.co/unsloth/LTX-2.3-GGUF/tree/main/vae
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
### LTX-2.3 dev T2V
|
||||||
|
|
||||||
|
```
|
||||||
|
.\bin\Release\sd-cli.exe -M vid_gen --diffusion-model ..\..\ComfyUI\models\diffusion_models\ltx-2.3-22b-dev-UD-Q4_K_M.gguf --vae ..\..\ComfyUI\models\vae\ltx-2.3-22b-dev_video_vae.safetensors --audio-vae ..\..\ComfyUI\models\vae\ltx-2.3-22b-dev_audio_vae.safetensors --llm ..\..\ComfyUI\models\text_encoders\gemma-3-12b-it-qat-UD-Q4_K_XL.gguf --embeddings-connectors ..\..\ComfyUI\models\text_encoders\ltx-2.3-22b-dev_embeddings_connectors.safetensors -p "a lovely cat" --cfg-scale 6.0 --sampling-method euler -v -n "worst quality, low quality, blurry, distorted, artifacts" -W 1280 -H 720 --diffusion-fa --offload-to-cpu --video-frames 33 --fps 24 -o t2v.webm
|
||||||
|
```
|
||||||
|
|
||||||
|
<video
|
||||||
|
src="../assets/ltx2/t2v.webm"
|
||||||
|
controls
|
||||||
|
muted
|
||||||
|
style="max-width: 100%; height: auto;"></video>
|
||||||
|
|
||||||
|
### LTX-2.3 dev I2V
|
||||||
|
|
||||||
|
```
|
||||||
|
.\bin\Release\sd-cli.exe -M vid_gen --diffusion-model ..\..\ComfyUI\models\diffusion_models\ltx-2.3-22b-dev-UD-Q4_K_M.gguf --vae ..\..\ComfyUI\models\vae\ltx-2.3-22b-dev_video_vae.safetensors --audio-vae ..\..\ComfyUI\models\vae\ltx-2.3-22b-dev_audio_vae.safetensors --llm ..\..\ComfyUI\models\text_encoders\gemma-3-12b-it-qat-UD-Q4_K_XL.gguf --embeddings-connectors ..\..\ComfyUI\models\text_encoders\ltx-2.3-22b-dev_embeddings_connectors.safetensors -p "a lovely cat" --cfg-scale 6.0 --sampling-method euler -v -W 1280 -H 720 --diffusion-fa --offload-to-cpu --video-frames 33 -i ..\assets\ernie_image\turbo_example.png -o i2v.webm
|
||||||
|
```
|
||||||
|
|
||||||
|
<video
|
||||||
|
src="../assets/ltx2/i2v.webm"
|
||||||
|
controls
|
||||||
|
muted
|
||||||
|
style="max-width: 100%; height: auto;"></video>
|
||||||
@ -7,6 +7,10 @@ add_executable(${TARGET}
|
|||||||
image_metadata.cpp
|
image_metadata.cpp
|
||||||
main.cpp
|
main.cpp
|
||||||
)
|
)
|
||||||
|
target_include_directories(${TARGET} PRIVATE
|
||||||
|
"${CMAKE_CURRENT_SOURCE_DIR}/.."
|
||||||
|
"${PROJECT_SOURCE_DIR}/src"
|
||||||
|
)
|
||||||
install(TARGETS ${TARGET} RUNTIME)
|
install(TARGETS ${TARGET} RUNTIME)
|
||||||
target_link_libraries(${TARGET} PRIVATE stable-diffusion zip ${CMAKE_THREAD_LIBS_INIT})
|
target_link_libraries(${TARGET} PRIVATE stable-diffusion zip ${CMAKE_THREAD_LIBS_INIT})
|
||||||
if(SD_WEBP)
|
if(SD_WEBP)
|
||||||
|
|||||||
@ -103,8 +103,9 @@ Generation Options:
|
|||||||
--hires-upscaler <string> highres fix upscaler, Lanczos, Nearest, Latent, Latent (nearest), Latent
|
--hires-upscaler <string> highres fix upscaler, Lanczos, Nearest, Latent, Latent (nearest), Latent
|
||||||
(nearest-exact), Latent (antialiased), Latent (bicubic), Latent (bicubic
|
(nearest-exact), Latent (antialiased), Latent (bicubic), Latent (bicubic
|
||||||
antialiased), or a model name under --hires-upscalers-dir (default: Latent)
|
antialiased), or a model name under --hires-upscalers-dir (default: Latent)
|
||||||
--extra-sample-args <string> extra sampler args, key=value list. Currently lcm supports noise_clip_std,
|
--extra-sample-args <string> extra sampler/scheduler args, key=value list. lcm supports noise_clip_std,
|
||||||
noise_scale_start, noise_scale_end
|
noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift,
|
||||||
|
stretch, terminal
|
||||||
-H, --height <int> image height, in pixel space (default: 512)
|
-H, --height <int> image height, in pixel space (default: 512)
|
||||||
-W, --width <int> image width, in pixel space (default: 512)
|
-W, --width <int> image width, in pixel space (default: 512)
|
||||||
--steps <int> number of sample steps (default: 20)
|
--steps <int> number of sample steps (default: 20)
|
||||||
@ -160,6 +161,7 @@ Generation Options:
|
|||||||
--disable-auto-resize-ref-image disable auto resize of ref images
|
--disable-auto-resize-ref-image disable auto resize of ref images
|
||||||
--disable-image-metadata do not embed generation metadata on image files
|
--disable-image-metadata do not embed generation metadata on image files
|
||||||
--vae-tiling process vae in tiles to reduce memory usage
|
--vae-tiling process vae in tiles to reduce memory usage
|
||||||
|
--temporal-tiling enable temporal tiling for LTX video VAE decode
|
||||||
--hires enable highres fix
|
--hires enable highres fix
|
||||||
-s, --seed RNG seed (default: 42, use random seed for < 0)
|
-s, --seed RNG seed (default: 42, use random seed for < 0)
|
||||||
--sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m,
|
--sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m,
|
||||||
@ -169,8 +171,8 @@ Generation Options:
|
|||||||
dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep,
|
dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep,
|
||||||
res_2s, er_sde, euler_cfg_pp, euler_a_cfg_pp] default: euler for Flux/SD3/Wan, euler_a otherwise
|
res_2s, er_sde, euler_cfg_pp, euler_a_cfg_pp] default: euler for Flux/SD3/Wan, euler_a otherwise
|
||||||
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits,
|
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits,
|
||||||
smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent], default:
|
smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent, ltx2], default:
|
||||||
discrete
|
model-specific
|
||||||
--sigmas custom sigma values for the sampler, comma-separated (e.g.,
|
--sigmas custom sigma values for the sampler, comma-separated (e.g.,
|
||||||
"14.61,7.8,3.5,0.0").
|
"14.61,7.8,3.5,0.0").
|
||||||
--skip-layers layers to skip for SLG steps (default: [7,8,9])
|
--skip-layers layers to skip for SLG steps (default: [7,8,9])
|
||||||
|
|||||||
@ -385,11 +385,32 @@ std::string format_frame_idx(std::string pattern, int frame_idx) {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Derives the path of the ".wav" sidecar file for cli_params.output_path.
//
// If the output path carries an extension that is either classified as
// JPEG/PNG/WEBP by encoded_image_format_from_path() or equals ".avi"/".webm"
// (compared case-insensitively), that extension is removed first. ".wav" is
// then appended to whatever remains, so an extension not matched above stays
// part of the sidecar name.
static fs::path get_video_audio_sidecar_path(const SDCliParams& cli_params) {
    const fs::path out_path = cli_params.output_path;
    fs::path base_path      = out_path;

    // extension() already yields an empty path when there is no extension.
    std::string ext_lower = out_path.extension().string();
    // ::tolower has undefined behaviour for negative arguments other than EOF,
    // so each byte goes through unsigned char; this matters for non-ASCII
    // bytes on platforms where plain char is signed.
    std::transform(ext_lower.begin(), ext_lower.end(), ext_lower.begin(),
                   [](unsigned char c) { return static_cast<char>(::tolower(c)); });

    const EncodedImageFormat output_format = encoded_image_format_from_path(out_path.string());
    const bool is_image_ext = output_format == EncodedImageFormat::JPEG ||
                              output_format == EncodedImageFormat::PNG ||
                              output_format == EncodedImageFormat::WEBP;
    const bool is_video_ext = ext_lower == ".avi" || ext_lower == ".webm";

    if (!ext_lower.empty() && (is_image_ext || is_video_ext)) {
        base_path.replace_extension();
    }

    base_path += ".wav";
    return base_path;
}
|
||||||
|
|
||||||
bool save_results(const SDCliParams& cli_params,
|
bool save_results(const SDCliParams& cli_params,
|
||||||
const SDContextParams& ctx_params,
|
const SDContextParams& ctx_params,
|
||||||
const SDGenerationParams& gen_params,
|
const SDGenerationParams& gen_params,
|
||||||
sd_image_t* results,
|
sd_image_t* results,
|
||||||
int num_results) {
|
int num_results,
|
||||||
|
const sd_audio_t* generated_audio = nullptr) {
|
||||||
if (results == nullptr || num_results <= 0) {
|
if (results == nullptr || num_results <= 0) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -442,6 +463,21 @@ bool save_results(const SDCliParams& cli_params,
|
|||||||
return ok;
|
return ok;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
auto write_audio_sidecar = [&](const fs::path& wav_path) {
|
||||||
|
if (generated_audio == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (write_wav_to_file(wav_path.string(),
|
||||||
|
generated_audio->data,
|
||||||
|
generated_audio->sample_count,
|
||||||
|
generated_audio->channels,
|
||||||
|
generated_audio->sample_rate)) {
|
||||||
|
LOG_INFO("save result audio to '%s'", wav_path.string().c_str());
|
||||||
|
} else {
|
||||||
|
LOG_WARN("failed to save result audio to '%s'", wav_path.string().c_str());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
int sucessful_reults = 0;
|
int sucessful_reults = 0;
|
||||||
|
|
||||||
if (std::regex_search(cli_params.output_path, format_specifier_regex)) {
|
if (std::regex_search(cli_params.output_path, format_specifier_regex)) {
|
||||||
@ -465,8 +501,16 @@ bool save_results(const SDCliParams& cli_params,
|
|||||||
ext = ".avi";
|
ext = ".avi";
|
||||||
fs::path video_path = base_path;
|
fs::path video_path = base_path;
|
||||||
video_path += ext;
|
video_path += ext;
|
||||||
if (create_video_from_sd_images(video_path.string().c_str(), results, num_results, gen_params.fps) == 0) {
|
std::string final_ext_lower = ext.string();
|
||||||
|
std::transform(final_ext_lower.begin(), final_ext_lower.end(), final_ext_lower.begin(), ::tolower);
|
||||||
|
const bool mux_audio = generated_audio != nullptr && (final_ext_lower == ".avi" || final_ext_lower == ".webm");
|
||||||
|
if (create_video_from_sd_images(video_path.string().c_str(), results, num_results, gen_params.fps, 90, mux_audio ? generated_audio : nullptr) == 0) {
|
||||||
LOG_INFO("save result video to '%s'", video_path.string().c_str());
|
LOG_INFO("save result video to '%s'", video_path.string().c_str());
|
||||||
|
if (generated_audio != nullptr && !mux_audio) {
|
||||||
|
fs::path wav_path = video_path;
|
||||||
|
wav_path.replace_extension(".wav");
|
||||||
|
write_audio_sidecar(wav_path);
|
||||||
|
}
|
||||||
return true;
|
return true;
|
||||||
} else {
|
} else {
|
||||||
LOG_ERROR("Failed to save result video to '%s'", video_path.string().c_str());
|
LOG_ERROR("Failed to save result video to '%s'", video_path.string().c_str());
|
||||||
@ -488,6 +532,9 @@ bool save_results(const SDCliParams& cli_params,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
LOG_INFO("%d/%d images saved", sucessful_reults, num_results);
|
LOG_INFO("%d/%d images saved", sucessful_reults, num_results);
|
||||||
|
if (generated_audio != nullptr) {
|
||||||
|
write_audio_sidecar(get_video_audio_sidecar_path(cli_params));
|
||||||
|
}
|
||||||
return sucessful_reults != 0;
|
return sucessful_reults != 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -701,7 +748,8 @@ int main(int argc, const char* argv[]) {
|
|||||||
sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(vae_decode_only, true, cli_params.taesd_preview);
|
sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(vae_decode_only, true, cli_params.taesd_preview);
|
||||||
|
|
||||||
SDImageVec results;
|
SDImageVec results;
|
||||||
int num_results = 0;
|
int num_results = 0;
|
||||||
|
sd_audio_t* generated_audio = nullptr;
|
||||||
|
|
||||||
if (cli_params.mode == UPSCALE) {
|
if (cli_params.mode == UPSCALE) {
|
||||||
num_results = 1;
|
num_results = 1;
|
||||||
@ -733,7 +781,10 @@ int main(int argc, const char* argv[]) {
|
|||||||
results.adopt(generate_image(sd_ctx.get(), &img_gen_params), num_results);
|
results.adopt(generate_image(sd_ctx.get(), &img_gen_params), num_results);
|
||||||
} else if (cli_params.mode == VID_GEN) {
|
} else if (cli_params.mode == VID_GEN) {
|
||||||
sd_vid_gen_params_t vid_gen_params = gen_params.to_sd_vid_gen_params_t();
|
sd_vid_gen_params_t vid_gen_params = gen_params.to_sd_vid_gen_params_t();
|
||||||
sd_image_t* generated_video = generate_video(sd_ctx.get(), &vid_gen_params, &num_results);
|
sd_image_t* generated_video = nullptr;
|
||||||
|
if (!generate_video(sd_ctx.get(), &vid_gen_params, &generated_video, &num_results, &generated_audio)) {
|
||||||
|
generated_video = nullptr;
|
||||||
|
}
|
||||||
results.adopt(generated_video, num_results);
|
results.adopt(generated_video, num_results);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -775,9 +826,12 @@ int main(int argc, const char* argv[]) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!save_results(cli_params, ctx_params, gen_params, results.data(), num_results)) {
|
if (!save_results(cli_params, ctx_params, gen_params, results.data(), num_results, generated_audio)) {
|
||||||
|
free_sd_audio(generated_audio);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
free_sd_audio(generated_audio);
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -340,10 +340,18 @@ ArgOptions SDContextParams::get_options() {
|
|||||||
"--high-noise-diffusion-model",
|
"--high-noise-diffusion-model",
|
||||||
"path to the standalone high noise diffusion model",
|
"path to the standalone high noise diffusion model",
|
||||||
&high_noise_diffusion_model_path},
|
&high_noise_diffusion_model_path},
|
||||||
|
{"",
|
||||||
|
"--embeddings-connectors",
|
||||||
|
"path to LTXAV embeddings connectors",
|
||||||
|
&embeddings_connectors_path},
|
||||||
{"",
|
{"",
|
||||||
"--vae",
|
"--vae",
|
||||||
"path to standalone vae model",
|
"path to standalone vae model",
|
||||||
&vae_path},
|
&vae_path},
|
||||||
|
{"",
|
||||||
|
"--audio-vae",
|
||||||
|
"path to standalone LTX audio vae model",
|
||||||
|
&audio_vae_path},
|
||||||
{"",
|
{"",
|
||||||
"--taesd",
|
"--taesd",
|
||||||
"path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)",
|
"path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)",
|
||||||
@ -669,7 +677,9 @@ std::string SDContextParams::to_string() const {
|
|||||||
<< " llm_vision_path: \"" << llm_vision_path << "\",\n"
|
<< " llm_vision_path: \"" << llm_vision_path << "\",\n"
|
||||||
<< " diffusion_model_path: \"" << diffusion_model_path << "\",\n"
|
<< " diffusion_model_path: \"" << diffusion_model_path << "\",\n"
|
||||||
<< " high_noise_diffusion_model_path: \"" << high_noise_diffusion_model_path << "\",\n"
|
<< " high_noise_diffusion_model_path: \"" << high_noise_diffusion_model_path << "\",\n"
|
||||||
|
<< " embeddings_connectors_path: \"" << embeddings_connectors_path << "\",\n"
|
||||||
<< " vae_path: \"" << vae_path << "\",\n"
|
<< " vae_path: \"" << vae_path << "\",\n"
|
||||||
|
<< " audio_vae_path: \"" << audio_vae_path << "\",\n"
|
||||||
<< " taesd_path: \"" << taesd_path << "\",\n"
|
<< " taesd_path: \"" << taesd_path << "\",\n"
|
||||||
<< " esrgan_path: \"" << esrgan_path << "\",\n"
|
<< " esrgan_path: \"" << esrgan_path << "\",\n"
|
||||||
<< " control_net_path: \"" << control_net_path << "\",\n"
|
<< " control_net_path: \"" << control_net_path << "\",\n"
|
||||||
@ -728,7 +738,9 @@ sd_ctx_params_t SDContextParams::to_sd_ctx_params_t(bool vae_decode_only, bool f
|
|||||||
llm_vision_path.c_str(),
|
llm_vision_path.c_str(),
|
||||||
diffusion_model_path.c_str(),
|
diffusion_model_path.c_str(),
|
||||||
high_noise_diffusion_model_path.c_str(),
|
high_noise_diffusion_model_path.c_str(),
|
||||||
|
embeddings_connectors_path.c_str(),
|
||||||
vae_path.c_str(),
|
vae_path.c_str(),
|
||||||
|
audio_vae_path.c_str(),
|
||||||
taesd_path.c_str(),
|
taesd_path.c_str(),
|
||||||
control_net_path.c_str(),
|
control_net_path.c_str(),
|
||||||
embedding_vec.data(),
|
embedding_vec.data(),
|
||||||
@ -821,7 +833,7 @@ ArgOptions SDGenerationParams::get_options() {
|
|||||||
&hires_upscaler},
|
&hires_upscaler},
|
||||||
{"",
|
{"",
|
||||||
"--extra-sample-args",
|
"--extra-sample-args",
|
||||||
"extra sampler args, key=value list. Currently lcm supports noise_clip_std, noise_scale_start, noise_scale_end",
|
"extra sampler/scheduler args, key=value list. lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal",
|
||||||
&extra_sample_args},
|
&extra_sample_args},
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1006,6 +1018,11 @@ ArgOptions SDGenerationParams::get_options() {
|
|||||||
"process vae in tiles to reduce memory usage",
|
"process vae in tiles to reduce memory usage",
|
||||||
true,
|
true,
|
||||||
&vae_tiling_params.enabled},
|
&vae_tiling_params.enabled},
|
||||||
|
{"",
|
||||||
|
"--temporal-tiling",
|
||||||
|
"enable temporal tiling for LTX video VAE decode",
|
||||||
|
true,
|
||||||
|
&vae_tiling_params.temporal_tiling},
|
||||||
{"",
|
{"",
|
||||||
"--hires",
|
"--hires",
|
||||||
"enable highres fix",
|
"enable highres fix",
|
||||||
@ -1270,7 +1287,7 @@ ArgOptions SDGenerationParams::get_options() {
|
|||||||
on_high_noise_sample_method_arg},
|
on_high_noise_sample_method_arg},
|
||||||
{"",
|
{"",
|
||||||
"--scheduler",
|
"--scheduler",
|
||||||
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent], default: discrete",
|
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent, ltx2], default: model-specific",
|
||||||
on_scheduler_arg},
|
on_scheduler_arg},
|
||||||
{"",
|
{"",
|
||||||
"--sigmas",
|
"--sigmas",
|
||||||
@ -1703,6 +1720,9 @@ bool SDGenerationParams::from_json_str(
|
|||||||
if (tiling_json.contains("enabled") && tiling_json["enabled"].is_boolean()) {
|
if (tiling_json.contains("enabled") && tiling_json["enabled"].is_boolean()) {
|
||||||
vae_tiling_params.enabled = tiling_json["enabled"];
|
vae_tiling_params.enabled = tiling_json["enabled"];
|
||||||
}
|
}
|
||||||
|
if (tiling_json.contains("temporal_tiling") && tiling_json["temporal_tiling"].is_boolean()) {
|
||||||
|
vae_tiling_params.temporal_tiling = tiling_json["temporal_tiling"];
|
||||||
|
}
|
||||||
if (tiling_json.contains("tile_size_x") && tiling_json["tile_size_x"].is_number_integer()) {
|
if (tiling_json.contains("tile_size_x") && tiling_json["tile_size_x"].is_number_integer()) {
|
||||||
vae_tiling_params.tile_size_x = tiling_json["tile_size_x"];
|
vae_tiling_params.tile_size_x = tiling_json["tile_size_x"];
|
||||||
}
|
}
|
||||||
@ -2212,6 +2232,7 @@ sd_vid_gen_params_t SDGenerationParams::to_sd_vid_gen_params_t() {
|
|||||||
params.strength = strength;
|
params.strength = strength;
|
||||||
params.seed = seed;
|
params.seed = seed;
|
||||||
params.video_frames = video_frames;
|
params.video_frames = video_frames;
|
||||||
|
params.fps = fps;
|
||||||
params.vace_strength = vace_strength;
|
params.vace_strength = vace_strength;
|
||||||
params.vae_tiling_params = vae_tiling_params;
|
params.vae_tiling_params = vae_tiling_params;
|
||||||
params.cache = cache_params;
|
params.cache = cache_params;
|
||||||
@ -2300,6 +2321,7 @@ std::string SDGenerationParams::to_string() const {
|
|||||||
<< ", upscale_tile_size: " << hires_upscale_tile_size << " },\n"
|
<< ", upscale_tile_size: " << hires_upscale_tile_size << " },\n"
|
||||||
<< " vae_tiling_params: { "
|
<< " vae_tiling_params: { "
|
||||||
<< vae_tiling_params.enabled << ", "
|
<< vae_tiling_params.enabled << ", "
|
||||||
|
<< vae_tiling_params.temporal_tiling << ", "
|
||||||
<< vae_tiling_params.tile_size_x << ", "
|
<< vae_tiling_params.tile_size_x << ", "
|
||||||
<< vae_tiling_params.tile_size_y << ", "
|
<< vae_tiling_params.tile_size_y << ", "
|
||||||
<< vae_tiling_params.target_overlap << ", "
|
<< vae_tiling_params.target_overlap << ", "
|
||||||
|
|||||||
@ -92,7 +92,9 @@ struct SDContextParams {
|
|||||||
std::string llm_vision_path;
|
std::string llm_vision_path;
|
||||||
std::string diffusion_model_path;
|
std::string diffusion_model_path;
|
||||||
std::string high_noise_diffusion_model_path;
|
std::string high_noise_diffusion_model_path;
|
||||||
|
std::string embeddings_connectors_path;
|
||||||
std::string vae_path;
|
std::string vae_path;
|
||||||
|
std::string audio_vae_path;
|
||||||
std::string taesd_path;
|
std::string taesd_path;
|
||||||
std::string esrgan_path;
|
std::string esrgan_path;
|
||||||
std::string control_net_path;
|
std::string control_net_path;
|
||||||
@ -187,7 +189,7 @@ struct SDGenerationParams {
|
|||||||
int video_frames = 1;
|
int video_frames = 1;
|
||||||
int fps = 16;
|
int fps = 16;
|
||||||
float vace_strength = 1.f;
|
float vace_strength = 1.f;
|
||||||
sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
|
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f};
|
||||||
|
|
||||||
std::string pm_id_images_dir;
|
std::string pm_id_images_dir;
|
||||||
std::string pm_id_embed_path;
|
std::string pm_id_embed_path;
|
||||||
|
|||||||
@ -613,6 +613,13 @@ typedef struct {
|
|||||||
uint32_t size;
|
uint32_t size;
|
||||||
} avi_index_entry;
|
} avi_index_entry;
|
||||||
|
|
||||||
|
// One index record describing a chunk by its four-character code.
// NOTE(review): presumably mirrors an AVI "idx1" entry alongside
// avi_index_entry; confirm against the code that fills it in.
struct avi_chunk_index_entry {
    char fourcc[4];   // four-character chunk identifier, not NUL-terminated
    uint32_t flags;   // per-chunk flag bits
    uint32_t offset;  // chunk position
    uint32_t size;    // chunk size
};
|
||||||
|
|
||||||
void write_u32_le(FILE* f, uint32_t val) {
|
void write_u32_le(FILE* f, uint32_t val) {
|
||||||
fwrite(&val, 4, 1, f);
|
fwrite(&val, 4, 1, f);
|
||||||
}
|
}
|
||||||
@ -647,6 +654,33 @@ void write_fourcc(std::vector<uint8_t>& data, const char* fourcc) {
|
|||||||
data.insert(data.end(), fourcc, fourcc + 4);
|
data.insert(data.end(), fourcc, fourcc + 4);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static std::vector<uint8_t> audio_to_pcm16_bytes(const sd_audio_t* audio) {
|
||||||
|
if (audio == nullptr || audio->data == nullptr || audio->sample_count == 0 || audio->channels == 0 || audio->sample_rate == 0) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
const size_t pcm_samples = static_cast<size_t>(audio->sample_count) * static_cast<size_t>(audio->channels);
|
||||||
|
std::vector<uint8_t> bytes(pcm_samples * sizeof(int16_t));
|
||||||
|
auto* pcm = reinterpret_cast<int16_t*>(bytes.data());
|
||||||
|
for (size_t i = 0; i < pcm_samples; ++i) {
|
||||||
|
const float sample = std::clamp(audio->data[i], -1.0f, 1.0f);
|
||||||
|
pcm[i] = static_cast<int16_t>(std::lrint(sample * 32767.0f));
|
||||||
|
}
|
||||||
|
return bytes;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::pair<uint64_t, uint64_t> audio_sample_range_for_video_frame(const sd_audio_t* audio, int frame_idx, int num_frames, int fps) {
|
||||||
|
if (audio == nullptr || fps <= 0 || num_frames <= 0) {
|
||||||
|
return {0, 0};
|
||||||
|
}
|
||||||
|
const uint64_t total = audio->sample_count;
|
||||||
|
const uint64_t start = static_cast<uint64_t>((static_cast<long double>(frame_idx) * total) / num_frames);
|
||||||
|
const uint64_t end = frame_idx + 1 == num_frames
|
||||||
|
? total
|
||||||
|
: static_cast<uint64_t>((static_cast<long double>(frame_idx + 1) * total) / num_frames);
|
||||||
|
return {start, std::max(start, end)};
|
||||||
|
}
|
||||||
|
|
||||||
EncodedImageFormat encoded_image_format_from_path(const std::string& path) {
|
EncodedImageFormat encoded_image_format_from_path(const std::string& path) {
|
||||||
std::string ext = fs::path(path).extension().string();
|
std::string ext = fs::path(path).extension().string();
|
||||||
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
|
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
|
||||||
@ -776,7 +810,7 @@ uint8_t* load_image_from_memory(const char* image_bytes,
|
|||||||
return load_image_common(true, image_bytes, len, width, height, expected_width, expected_height, expected_channel);
|
return load_image_common(true, image_bytes, len, width, height, expected_width, expected_height, expected_channel);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images, int num_images, int fps, int quality) {
|
std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images, int num_images, int fps, int quality, const sd_audio_t* audio) {
|
||||||
if (num_images == 0) {
|
if (num_images == 0) {
|
||||||
fprintf(stderr, "Error: Image array is empty.\n");
|
fprintf(stderr, "Error: Image array is empty.\n");
|
||||||
return {};
|
return {};
|
||||||
@ -793,7 +827,13 @@ std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images
|
|||||||
// stb_image_write changes JPEG sampling behavior above quality 90.
|
// stb_image_write changes JPEG sampling behavior above quality 90.
|
||||||
// MJPG AVI playback is more compatible when we keep the encoder on the
|
// MJPG AVI playback is more compatible when we keep the encoder on the
|
||||||
// <= 90 path.
|
// <= 90 path.
|
||||||
const int mjpg_quality = std::clamp(quality, 1, 90);
|
const int mjpg_quality = std::clamp(quality, 1, 90);
|
||||||
|
const bool has_audio = audio != nullptr && audio->data != nullptr && audio->sample_count > 0 && audio->channels > 0 && audio->sample_rate > 0;
|
||||||
|
const std::vector<uint8_t> audio_pcm = audio_to_pcm16_bytes(audio);
|
||||||
|
const uint16_t audio_bits_per_sample = 16;
|
||||||
|
const uint16_t audio_block_align = has_audio ? static_cast<uint16_t>(audio->channels * (audio_bits_per_sample / 8)) : 0;
|
||||||
|
const uint32_t audio_byte_rate = has_audio ? static_cast<uint32_t>(audio->sample_rate * audio_block_align) : 0;
|
||||||
|
const uint32_t audio_data_size = has_audio ? static_cast<uint32_t>(audio_pcm.size()) : 0;
|
||||||
|
|
||||||
std::vector<uint8_t> avi_data;
|
std::vector<uint8_t> avi_data;
|
||||||
avi_data.reserve(static_cast<size_t>(num_images) * 1024);
|
avi_data.reserve(static_cast<size_t>(num_images) * 1024);
|
||||||
@ -804,7 +844,11 @@ std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images
|
|||||||
write_fourcc(avi_data, "AVI ");
|
write_fourcc(avi_data, "AVI ");
|
||||||
|
|
||||||
write_fourcc(avi_data, "LIST");
|
write_fourcc(avi_data, "LIST");
|
||||||
write_u32_le(avi_data, 4 + 8 + 56 + 8 + 4 + 8 + 56 + 8 + 40);
|
uint32_t hdrl_size = 4 + 8 + 56 + 8 + 4 + 8 + 56 + 8 + 40;
|
||||||
|
if (has_audio) {
|
||||||
|
hdrl_size += 8 + (4 + 8 + 56 + 8 + 16);
|
||||||
|
}
|
||||||
|
write_u32_le(avi_data, hdrl_size);
|
||||||
write_fourcc(avi_data, "hdrl");
|
write_fourcc(avi_data, "hdrl");
|
||||||
|
|
||||||
write_fourcc(avi_data, "avih");
|
write_fourcc(avi_data, "avih");
|
||||||
@ -815,7 +859,7 @@ std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images
|
|||||||
write_u32_le(avi_data, 0x110);
|
write_u32_le(avi_data, 0x110);
|
||||||
write_u32_le(avi_data, num_images);
|
write_u32_le(avi_data, num_images);
|
||||||
write_u32_le(avi_data, 0);
|
write_u32_le(avi_data, 0);
|
||||||
write_u32_le(avi_data, 1);
|
write_u32_le(avi_data, has_audio ? 2 : 1);
|
||||||
write_u32_le(avi_data, width * height * 3);
|
write_u32_le(avi_data, width * height * 3);
|
||||||
write_u32_le(avi_data, width);
|
write_u32_le(avi_data, width);
|
||||||
write_u32_le(avi_data, height);
|
write_u32_le(avi_data, height);
|
||||||
@ -862,12 +906,48 @@ std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images
|
|||||||
write_u32_le(avi_data, 0);
|
write_u32_le(avi_data, 0);
|
||||||
write_u32_le(avi_data, 0);
|
write_u32_le(avi_data, 0);
|
||||||
|
|
||||||
|
if (has_audio) {
|
||||||
|
write_fourcc(avi_data, "LIST");
|
||||||
|
write_u32_le(avi_data, 4 + 8 + 56 + 8 + 16);
|
||||||
|
write_fourcc(avi_data, "strl");
|
||||||
|
|
||||||
|
write_fourcc(avi_data, "strh");
|
||||||
|
write_u32_le(avi_data, 56);
|
||||||
|
write_fourcc(avi_data, "auds");
|
||||||
|
write_u32_le(avi_data, 0);
|
||||||
|
write_u32_le(avi_data, 0);
|
||||||
|
write_u16_le(avi_data, 0);
|
||||||
|
write_u16_le(avi_data, 0);
|
||||||
|
write_u32_le(avi_data, 0);
|
||||||
|
write_u32_le(avi_data, audio_block_align);
|
||||||
|
write_u32_le(avi_data, audio_byte_rate);
|
||||||
|
write_u32_le(avi_data, 0);
|
||||||
|
write_u32_le(avi_data, static_cast<uint32_t>(audio->sample_count));
|
||||||
|
write_u32_le(avi_data, audio_data_size);
|
||||||
|
write_u32_le(avi_data, static_cast<uint32_t>(-1));
|
||||||
|
write_u32_le(avi_data, audio_block_align);
|
||||||
|
write_u16_le(avi_data, 0);
|
||||||
|
write_u16_le(avi_data, 0);
|
||||||
|
write_u16_le(avi_data, 0);
|
||||||
|
write_u16_le(avi_data, 0);
|
||||||
|
|
||||||
|
write_fourcc(avi_data, "strf");
|
||||||
|
write_u32_le(avi_data, 16);
|
||||||
|
write_u16_le(avi_data, 1);
|
||||||
|
write_u16_le(avi_data, static_cast<uint16_t>(audio->channels));
|
||||||
|
write_u32_le(avi_data, audio->sample_rate);
|
||||||
|
write_u32_le(avi_data, audio_byte_rate);
|
||||||
|
write_u16_le(avi_data, audio_block_align);
|
||||||
|
write_u16_le(avi_data, audio_bits_per_sample);
|
||||||
|
}
|
||||||
|
|
||||||
write_fourcc(avi_data, "LIST");
|
write_fourcc(avi_data, "LIST");
|
||||||
const size_t movi_size_pos = avi_data.size();
|
const size_t movi_size_pos = avi_data.size();
|
||||||
write_u32_le(avi_data, 0);
|
write_u32_le(avi_data, 0);
|
||||||
write_fourcc(avi_data, "movi");
|
write_fourcc(avi_data, "movi");
|
||||||
|
|
||||||
std::vector<avi_index_entry> index(static_cast<size_t>(num_images));
|
std::vector<avi_chunk_index_entry> index;
|
||||||
|
index.reserve(static_cast<size_t>(num_images) + (has_audio ? 1 : 0));
|
||||||
std::vector<uint8_t> jpeg_data;
|
std::vector<uint8_t> jpeg_data;
|
||||||
|
|
||||||
for (int i = 0; i < num_images; i++) {
|
for (int i = 0; i < num_images; i++) {
|
||||||
@ -884,27 +964,46 @@ std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images
|
|||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
index[i].offset = static_cast<uint32_t>(avi_data.size());
|
avi_chunk_index_entry video_entry = {};
|
||||||
|
memcpy(video_entry.fourcc, "00dc", 4);
|
||||||
|
video_entry.flags = 0x10;
|
||||||
|
video_entry.offset = static_cast<uint32_t>(avi_data.size());
|
||||||
write_fourcc(avi_data, "00dc");
|
write_fourcc(avi_data, "00dc");
|
||||||
write_u32_le(avi_data, static_cast<uint32_t>(jpeg_data.size()));
|
write_u32_le(avi_data, static_cast<uint32_t>(jpeg_data.size()));
|
||||||
index[i].size = (uint32_t)jpeg_data.size();
|
video_entry.size = static_cast<uint32_t>(jpeg_data.size());
|
||||||
avi_data.insert(avi_data.end(), jpeg_data.begin(), jpeg_data.end());
|
avi_data.insert(avi_data.end(), jpeg_data.begin(), jpeg_data.end());
|
||||||
|
index.push_back(video_entry);
|
||||||
|
|
||||||
if (jpeg_data.size() % 2) {
|
if (jpeg_data.size() % 2) {
|
||||||
avi_data.push_back(0);
|
avi_data.push_back(0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (has_audio && !audio_pcm.empty()) {
|
||||||
|
avi_chunk_index_entry audio_entry = {};
|
||||||
|
memcpy(audio_entry.fourcc, "01wb", 4);
|
||||||
|
audio_entry.flags = 0;
|
||||||
|
audio_entry.offset = static_cast<uint32_t>(avi_data.size());
|
||||||
|
audio_entry.size = static_cast<uint32_t>(audio_pcm.size());
|
||||||
|
write_fourcc(avi_data, "01wb");
|
||||||
|
write_u32_le(avi_data, static_cast<uint32_t>(audio_pcm.size()));
|
||||||
|
avi_data.insert(avi_data.end(), audio_pcm.begin(), audio_pcm.end());
|
||||||
|
index.push_back(audio_entry);
|
||||||
|
if (audio_pcm.size() % 2 != 0) {
|
||||||
|
avi_data.push_back(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const size_t movi_size = avi_data.size() - movi_size_pos - 4;
|
const size_t movi_size = avi_data.size() - movi_size_pos - 4;
|
||||||
patch_u32_le(avi_data, movi_size_pos, static_cast<uint32_t>(movi_size));
|
patch_u32_le(avi_data, movi_size_pos, static_cast<uint32_t>(movi_size));
|
||||||
|
|
||||||
write_fourcc(avi_data, "idx1");
|
write_fourcc(avi_data, "idx1");
|
||||||
write_u32_le(avi_data, num_images * 16);
|
write_u32_le(avi_data, static_cast<uint32_t>(index.size() * 16));
|
||||||
for (int i = 0; i < num_images; i++) {
|
for (const auto& entry : index) {
|
||||||
write_fourcc(avi_data, "00dc");
|
write_fourcc(avi_data, entry.fourcc);
|
||||||
write_u32_le(avi_data, 0x10);
|
write_u32_le(avi_data, entry.flags);
|
||||||
write_u32_le(avi_data, index[i].offset);
|
write_u32_le(avi_data, entry.offset);
|
||||||
write_u32_le(avi_data, index[i].size);
|
write_u32_le(avi_data, entry.size);
|
||||||
}
|
}
|
||||||
|
|
||||||
const size_t file_size = avi_data.size() - riff_size_pos - 4;
|
const size_t file_size = avi_data.size() - riff_size_pos - 4;
|
||||||
@ -913,8 +1012,8 @@ std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images
|
|||||||
return avi_data;
|
return avi_data;
|
||||||
}
|
}
|
||||||
|
|
||||||
int create_mjpg_avi_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality) {
|
int create_mjpg_avi_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality, const sd_audio_t* audio) {
|
||||||
std::vector<uint8_t> avi_data = create_mjpg_avi_from_sd_images_to_vector(images, num_images, fps, quality);
|
std::vector<uint8_t> avi_data = create_mjpg_avi_from_sd_images_to_vector(images, num_images, fps, quality, audio);
|
||||||
if (avi_data.empty()) {
|
if (avi_data.empty()) {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
@ -1044,7 +1143,7 @@ int create_animated_webp_from_sd_images(const char* filename, sd_image_t* images
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef SD_USE_WEBM
|
#ifdef SD_USE_WEBM
|
||||||
std::vector<uint8_t> create_webm_from_sd_images_to_vector(sd_image_t* images, int num_images, int fps, int quality) {
|
std::vector<uint8_t> create_webm_from_sd_images_to_vector(sd_image_t* images, int num_images, int fps, int quality, const sd_audio_t* audio) {
|
||||||
if (num_images == 0) {
|
if (num_images == 0) {
|
||||||
fprintf(stderr, "Error: Image array is empty.\n");
|
fprintf(stderr, "Error: Image array is empty.\n");
|
||||||
return {};
|
return {};
|
||||||
@ -1089,6 +1188,25 @@ std::vector<uint8_t> create_webm_from_sd_images_to_vector(sd_image_t* images, in
|
|||||||
video_track->set_display_height(static_cast<uint64_t>(height));
|
video_track->set_display_height(static_cast<uint64_t>(height));
|
||||||
video_track->set_frame_rate(static_cast<double>(fps));
|
video_track->set_frame_rate(static_cast<double>(fps));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
uint64_t audio_track_number = 0;
|
||||||
|
std::vector<uint8_t> audio_pcm = audio_to_pcm16_bytes(audio);
|
||||||
|
if (audio != nullptr && !audio_pcm.empty()) {
|
||||||
|
audio_track_number = segment.AddAudioTrack(static_cast<int32_t>(audio->sample_rate), static_cast<int32_t>(audio->channels), 0);
|
||||||
|
if (audio_track_number == 0) {
|
||||||
|
fprintf(stderr, "Error: Failed to add audio track.\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
auto* audio_track = static_cast<mkvmuxer::AudioTrack*>(segment.GetTrackByNumber(audio_track_number));
|
||||||
|
if (audio_track == nullptr) {
|
||||||
|
fprintf(stderr, "Error: Failed to get audio track.\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
audio_track->set_codec_id("A_PCM/INT/LIT");
|
||||||
|
audio_track->set_bit_depth(16);
|
||||||
|
audio_track->set_sample_rate(static_cast<double>(audio->sample_rate));
|
||||||
|
audio_track->set_channels(audio->channels);
|
||||||
|
}
|
||||||
segment.GetSegmentInfo()->set_writing_app("stable-diffusion.cpp");
|
segment.GetSegmentInfo()->set_writing_app("stable-diffusion.cpp");
|
||||||
segment.GetSegmentInfo()->set_muxing_app("stable-diffusion.cpp");
|
segment.GetSegmentInfo()->set_muxing_app("stable-diffusion.cpp");
|
||||||
|
|
||||||
@ -1118,6 +1236,23 @@ std::vector<uint8_t> create_webm_from_sd_images_to_vector(sd_image_t* images, in
|
|||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (audio_track_number != 0) {
|
||||||
|
auto [audio_begin, audio_end] = audio_sample_range_for_video_frame(audio, i, num_images, fps);
|
||||||
|
const uint64_t frame_samples = audio_end - audio_begin;
|
||||||
|
if (frame_samples > 0) {
|
||||||
|
const uint64_t frame_bytes = frame_samples * audio->channels * sizeof(int16_t);
|
||||||
|
const uint8_t* frame_ptr = audio_pcm.data() + audio_begin * audio->channels * sizeof(int16_t);
|
||||||
|
if (!segment.AddFrame(frame_ptr,
|
||||||
|
frame_bytes,
|
||||||
|
audio_track_number,
|
||||||
|
timestamp_ns,
|
||||||
|
true)) {
|
||||||
|
fprintf(stderr, "Error: Failed to mux audio chunk %d into WebM.\n", i);
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
timestamp_ns += frame_duration_ns;
|
timestamp_ns += frame_duration_ns;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1133,8 +1268,8 @@ std::vector<uint8_t> create_webm_from_sd_images_to_vector(sd_image_t* images, in
|
|||||||
return writer.data();
|
return writer.data();
|
||||||
}
|
}
|
||||||
|
|
||||||
int create_webm_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality) {
|
int create_webm_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality, const sd_audio_t* audio) {
|
||||||
std::vector<uint8_t> webm_data = create_webm_from_sd_images_to_vector(images, num_images, fps, quality);
|
std::vector<uint8_t> webm_data = create_webm_from_sd_images_to_vector(images, num_images, fps, quality, audio);
|
||||||
if (webm_data.empty()) {
|
if (webm_data.empty()) {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
@ -1150,7 +1285,8 @@ std::vector<uint8_t> create_video_from_sd_images_to_vector(const std::string& ou
|
|||||||
sd_image_t* images,
|
sd_image_t* images,
|
||||||
int num_images,
|
int num_images,
|
||||||
int fps,
|
int fps,
|
||||||
int quality) {
|
int quality,
|
||||||
|
const sd_audio_t* audio) {
|
||||||
std::string format = output_format;
|
std::string format = output_format;
|
||||||
std::transform(format.begin(), format.end(), format.begin(),
|
std::transform(format.begin(), format.end(), format.begin(),
|
||||||
[](unsigned char c) { return static_cast<char>(tolower(c)); });
|
[](unsigned char c) { return static_cast<char>(tolower(c)); });
|
||||||
@ -1160,7 +1296,7 @@ std::vector<uint8_t> create_video_from_sd_images_to_vector(const std::string& ou
|
|||||||
|
|
||||||
#ifdef SD_USE_WEBM
|
#ifdef SD_USE_WEBM
|
||||||
if (format == "webm") {
|
if (format == "webm") {
|
||||||
return create_webm_from_sd_images_to_vector(images, num_images, fps, quality);
|
return create_webm_from_sd_images_to_vector(images, num_images, fps, quality, audio);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@ -1170,14 +1306,14 @@ std::vector<uint8_t> create_video_from_sd_images_to_vector(const std::string& ou
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
return create_mjpg_avi_from_sd_images_to_vector(images, num_images, fps, quality);
|
return create_mjpg_avi_from_sd_images_to_vector(images, num_images, fps, quality, audio);
|
||||||
}
|
}
|
||||||
|
|
||||||
int create_video_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality) {
|
int create_video_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality, const sd_audio_t* audio) {
|
||||||
std::string path = filename ? filename : "";
|
std::string path = filename ? filename : "";
|
||||||
auto pos = path.find_last_of('.');
|
auto pos = path.find_last_of('.');
|
||||||
std::string ext = pos == std::string::npos ? "" : path.substr(pos);
|
std::string ext = pos == std::string::npos ? "" : path.substr(pos);
|
||||||
std::vector<uint8_t> video_data = create_video_from_sd_images_to_vector(ext, images, num_images, fps, quality);
|
std::vector<uint8_t> video_data = create_video_from_sd_images_to_vector(ext, images, num_images, fps, quality, audio);
|
||||||
if (video_data.empty()) {
|
if (video_data.empty()) {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
@ -1187,3 +1323,54 @@ int create_video_from_sd_images(const char* filename, sd_image_t* images, int nu
|
|||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Writes interleaved float samples to `path` as a 16-bit PCM RIFF/WAVE file
// with the canonical 44-byte header.
//
//   path                 destination file; created or truncated
//   interleaved_samples  sample_count * channels floats; values outside
//                        [-1, 1] are clamped before quantisation
//   sample_count         number of samples per channel
//   channels             channel count
//   sample_rate          samples per second, per channel
//
// Returns true when the complete file was written and flushed. Returns false
// when an argument is null/zero, when the stream cannot be described by the
// 16/32-bit header fields (checked before the file is opened, so an existing
// file is not clobbered), when the file cannot be opened, or on I/O failure.
bool write_wav_to_file(const std::string& path,
                       const float* interleaved_samples,
                       uint64_t sample_count,
                       uint32_t channels,
                       uint32_t sample_rate) {
    if (interleaved_samples == nullptr || sample_count == 0 || channels == 0 || sample_rate == 0) {
        return false;
    }

    const uint32_t bits_per_sample  = 16;
    const uint32_t bytes_per_sample = bits_per_sample / 8;

    // The fmt chunk stores the channel count and block alignment as 16-bit values.
    if (channels > 0xFFFFu / bytes_per_sample) {
        return false;
    }
    const uint32_t block_align = channels * bytes_per_sample;

    // The average byte rate is a 32-bit field.
    const uint64_t byte_rate_wide = static_cast<uint64_t>(sample_rate) * block_align;
    if (byte_rate_wide > 0xFFFFFFFFull) {
        return false;
    }
    const uint32_t byte_rate = static_cast<uint32_t>(byte_rate_wide);

    // The RIFF size field holds 36 + data bytes in 32 bits. Reject oversized
    // input instead of writing a header that disagrees with the payload.
    const uint64_t max_data_bytes = 0xFFFFFFFFull - 36;
    if (sample_count > max_data_bytes / block_align) {
        return false;
    }
    const uint32_t data_size = static_cast<uint32_t>(sample_count * block_align);
    const uint32_t riff_size = 36 + data_size;

    std::ofstream file(path, std::ios::binary);
    if (!file.is_open()) {
        return false;
    }

    // Serialise every field byte by byte: WAV is little-endian regardless of
    // the byte order of the machine running this code.
    std::vector<uint8_t> header;
    header.reserve(44);
    auto put_tag = [&header](const char* tag) {
        header.insert(header.end(), tag, tag + 4);
    };
    auto put_u16 = [&header](uint32_t value) {
        header.push_back(static_cast<uint8_t>(value & 0xFF));
        header.push_back(static_cast<uint8_t>((value >> 8) & 0xFF));
    };
    auto put_u32 = [&header](uint32_t value) {
        for (int shift = 0; shift < 32; shift += 8) {
            header.push_back(static_cast<uint8_t>((value >> shift) & 0xFF));
        }
    };

    put_tag("RIFF");
    put_u32(riff_size);
    put_tag("WAVE");
    put_tag("fmt ");
    put_u32(16);  // size of the PCM fmt chunk body
    put_u16(1);   // format tag 1 = integer PCM
    put_u16(channels);
    put_u32(sample_rate);
    put_u32(byte_rate);
    put_u16(block_align);
    put_u16(bits_per_sample);
    put_tag("data");
    put_u32(data_size);
    file.write(reinterpret_cast<const char*>(header.data()), static_cast<std::streamsize>(header.size()));

    // Clamp to [-1, 1], scale to the int16 range and store low byte first.
    const size_t value_count = static_cast<size_t>(sample_count * channels);
    std::vector<uint8_t> pcm(static_cast<size_t>(data_size));
    for (size_t i = 0; i < value_count; ++i) {
        const float sample  = std::max(-1.0f, std::min(1.0f, interleaved_samples[i]));
        const uint16_t bits = static_cast<uint16_t>(static_cast<int16_t>(std::lrint(sample * 32767.0f)));
        pcm[2 * i]          = static_cast<uint8_t>(bits & 0xFF);
        pcm[2 * i + 1]      = static_cast<uint8_t>(bits >> 8);
    }
    file.write(reinterpret_cast<const char*>(pcm.data()), static_cast<std::streamsize>(pcm.size()));

    // Flush so that a failure while draining the buffer is reported here
    // rather than being lost in the destructor.
    file.flush();
    return file.good();
}
|
||||||
|
|||||||
@ -57,11 +57,13 @@ int create_mjpg_avi_from_sd_images(const char* filename,
|
|||||||
sd_image_t* images,
|
sd_image_t* images,
|
||||||
int num_images,
|
int num_images,
|
||||||
int fps,
|
int fps,
|
||||||
int quality = 90);
|
int quality = 90,
|
||||||
|
const sd_audio_t* audio = nullptr);
|
||||||
std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images,
|
std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images,
|
||||||
int num_images,
|
int num_images,
|
||||||
int fps,
|
int fps,
|
||||||
int quality = 90);
|
int quality = 90,
|
||||||
|
const sd_audio_t* audio = nullptr);
|
||||||
|
|
||||||
#ifdef SD_USE_WEBP
|
#ifdef SD_USE_WEBP
|
||||||
int create_animated_webp_from_sd_images(const char* filename,
|
int create_animated_webp_from_sd_images(const char* filename,
|
||||||
@ -80,22 +82,32 @@ int create_webm_from_sd_images(const char* filename,
|
|||||||
sd_image_t* images,
|
sd_image_t* images,
|
||||||
int num_images,
|
int num_images,
|
||||||
int fps,
|
int fps,
|
||||||
int quality = 90);
|
int quality = 90,
|
||||||
|
const sd_audio_t* audio = nullptr);
|
||||||
std::vector<uint8_t> create_webm_from_sd_images_to_vector(sd_image_t* images,
|
std::vector<uint8_t> create_webm_from_sd_images_to_vector(sd_image_t* images,
|
||||||
int num_images,
|
int num_images,
|
||||||
int fps,
|
int fps,
|
||||||
int quality = 90);
|
int quality = 90,
|
||||||
|
const sd_audio_t* audio = nullptr);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
int create_video_from_sd_images(const char* filename,
|
int create_video_from_sd_images(const char* filename,
|
||||||
sd_image_t* images,
|
sd_image_t* images,
|
||||||
int num_images,
|
int num_images,
|
||||||
int fps,
|
int fps,
|
||||||
int quality = 90);
|
int quality = 90,
|
||||||
|
const sd_audio_t* audio = nullptr);
|
||||||
std::vector<uint8_t> create_video_from_sd_images_to_vector(const std::string& output_format,
|
std::vector<uint8_t> create_video_from_sd_images_to_vector(const std::string& output_format,
|
||||||
sd_image_t* images,
|
sd_image_t* images,
|
||||||
int num_images,
|
int num_images,
|
||||||
int fps,
|
int fps,
|
||||||
int quality = 90);
|
int quality = 90,
|
||||||
|
const sd_audio_t* audio = nullptr);
|
||||||
|
|
||||||
|
// Writes `sample_count * channels` interleaved float samples to `path` as a
// 16-bit PCM WAV file. Samples are clamped to [-1, 1] before conversion.
// Returns false if `interleaved_samples` is null, if any of `sample_count`,
// `channels` or `sample_rate` is zero, or if the file cannot be opened;
// otherwise returns whether the output stream is still good after writing.
bool write_wav_to_file(const std::string& path,
|
||||||
|
const float* interleaved_samples,
|
||||||
|
uint64_t sample_count,
|
||||||
|
uint32_t channels,
|
||||||
|
uint32_t sample_rate);
|
||||||
|
|
||||||
#endif // __MEDIA_IO_H__
|
#endif // __MEDIA_IO_H__
|
||||||
|
|||||||
@ -205,8 +205,9 @@ Default Generation Options:
|
|||||||
--hires-upscaler <string> highres fix upscaler, Lanczos, Nearest, Latent, Latent (nearest), Latent
|
--hires-upscaler <string> highres fix upscaler, Lanczos, Nearest, Latent, Latent (nearest), Latent
|
||||||
(nearest-exact), Latent (antialiased), Latent (bicubic), Latent (bicubic
|
(nearest-exact), Latent (antialiased), Latent (bicubic), Latent (bicubic
|
||||||
antialiased), or a model name under --hires-upscalers-dir (default: Latent)
|
antialiased), or a model name under --hires-upscalers-dir (default: Latent)
|
||||||
--extra-sample-args <string> extra sampler args, key=value list. Currently lcm supports noise_clip_std,
|
--extra-sample-args <string> extra sampler/scheduler args, key=value list. lcm supports noise_clip_std,
|
||||||
noise_scale_start, noise_scale_end
|
noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift,
|
||||||
|
stretch, terminal
|
||||||
-H, --height <int> image height, in pixel space (default: 512)
|
-H, --height <int> image height, in pixel space (default: 512)
|
||||||
-W, --width <int> image width, in pixel space (default: 512)
|
-W, --width <int> image width, in pixel space (default: 512)
|
||||||
--steps <int> number of sample steps (default: 20)
|
--steps <int> number of sample steps (default: 20)
|
||||||
@ -271,8 +272,8 @@ Default Generation Options:
|
|||||||
dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep,
|
dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep,
|
||||||
res_2s, er_sde, euler_cfg_pp, euler_a_cfg_pp] default: euler for Flux/SD3/Wan, euler_a otherwise
|
res_2s, er_sde, euler_cfg_pp, euler_a_cfg_pp] default: euler for Flux/SD3/Wan, euler_a otherwise
|
||||||
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits,
|
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits,
|
||||||
smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent], default:
|
smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent, ltx2], default:
|
||||||
discrete
|
model-specific
|
||||||
--sigmas custom sigma values for the sampler, comma-separated (e.g.,
|
--sigmas custom sigma values for the sampler, comma-separated (e.g.,
|
||||||
"14.61,7.8,3.5,0.0").
|
"14.61,7.8,3.5,0.0").
|
||||||
--skip-layers layers to skip for SLG steps (default: [7,8,9])
|
--skip-layers layers to skip for SLG steps (default: [7,8,9])
|
||||||
|
|||||||
@ -231,16 +231,21 @@ bool execute_vid_gen_job(ServerRuntime& runtime,
|
|||||||
sd_vid_gen_params_t params = job.vid_gen.to_sd_vid_gen_params_t();
|
sd_vid_gen_params_t params = job.vid_gen.to_sd_vid_gen_params_t();
|
||||||
|
|
||||||
SDImageVec results;
|
SDImageVec results;
|
||||||
int num_results = 0;
|
int num_results = 0;
|
||||||
|
sd_audio_t* generated_audio = nullptr;
|
||||||
|
|
||||||
{
|
{
|
||||||
std::lock_guard<std::mutex> lock(*runtime.sd_ctx_mutex);
|
std::lock_guard<std::mutex> lock(*runtime.sd_ctx_mutex);
|
||||||
sd_image_t* raw_results = generate_video(runtime.sd_ctx, ¶ms, &num_results);
|
sd_image_t* raw_results = nullptr;
|
||||||
|
if (!generate_video(runtime.sd_ctx, ¶ms, &raw_results, &num_results, &generated_audio)) {
|
||||||
|
raw_results = nullptr;
|
||||||
|
}
|
||||||
results.adopt(raw_results, num_results);
|
results.adopt(raw_results, num_results);
|
||||||
}
|
}
|
||||||
|
|
||||||
num_results = results.count();
|
num_results = results.count();
|
||||||
if (num_results <= 0) {
|
if (num_results <= 0) {
|
||||||
|
free_sd_audio(generated_audio);
|
||||||
error_message = "generate_video returned no results";
|
error_message = "generate_video returned no results";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -249,7 +254,9 @@ bool execute_vid_gen_job(ServerRuntime& runtime,
|
|||||||
results.data(),
|
results.data(),
|
||||||
num_results,
|
num_results,
|
||||||
job.vid_gen.gen_params.fps,
|
job.vid_gen.gen_params.fps,
|
||||||
job.vid_gen.output_compression);
|
job.vid_gen.output_compression,
|
||||||
|
generated_audio);
|
||||||
|
free_sd_audio(generated_audio);
|
||||||
if (video_bytes.empty()) {
|
if (video_bytes.empty()) {
|
||||||
error_message = "failed to encode generated video container";
|
error_message = "failed to encode generated video container";
|
||||||
return false;
|
return false;
|
||||||
|
|||||||
2
ggml
2
ggml
@ -1 +1 @@
|
|||||||
Subproject commit 404fcb9d7c96989569e68c9e7881ee3465a05c50
|
Subproject commit 7f4ab364b2843921e795d6890d0f42dd5e5d6b63
|
||||||
@ -68,6 +68,7 @@ enum scheduler_t {
|
|||||||
KL_OPTIMAL_SCHEDULER,
|
KL_OPTIMAL_SCHEDULER,
|
||||||
LCM_SCHEDULER,
|
LCM_SCHEDULER,
|
||||||
BONG_TANGENT_SCHEDULER,
|
BONG_TANGENT_SCHEDULER,
|
||||||
|
LTX2_SCHEDULER,
|
||||||
SCHEDULER_COUNT
|
SCHEDULER_COUNT
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -151,6 +152,7 @@ enum lora_apply_mode_t {
|
|||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
bool enabled;
|
bool enabled;
|
||||||
|
bool temporal_tiling;
|
||||||
int tile_size_x;
|
int tile_size_x;
|
||||||
int tile_size_y;
|
int tile_size_y;
|
||||||
float target_overlap;
|
float target_overlap;
|
||||||
@ -173,7 +175,9 @@ typedef struct {
|
|||||||
const char* llm_vision_path;
|
const char* llm_vision_path;
|
||||||
const char* diffusion_model_path;
|
const char* diffusion_model_path;
|
||||||
const char* high_noise_diffusion_model_path;
|
const char* high_noise_diffusion_model_path;
|
||||||
|
const char* embeddings_connectors_path;
|
||||||
const char* vae_path;
|
const char* vae_path;
|
||||||
|
const char* audio_vae_path;
|
||||||
const char* taesd_path;
|
const char* taesd_path;
|
||||||
const char* control_net_path;
|
const char* control_net_path;
|
||||||
const sd_embedding_t* embeddings;
|
const sd_embedding_t* embeddings;
|
||||||
@ -210,6 +214,13 @@ typedef struct {
|
|||||||
const char* params_backend;
|
const char* params_backend;
|
||||||
} sd_ctx_params_t;
|
} sd_ctx_params_t;
|
||||||
|
|
||||||
|
// Audio buffer handed back through generate_video()'s `audio_out` parameter
// and released with free_sd_audio().
typedef struct {
|
||||||
|
uint32_t sample_rate;  // samples per second
|
||||||
|
uint32_t channels;  // number of audio channels
|
||||||
|
uint64_t sample_count;  // NOTE(review): presumably samples per channel, not total floats - confirm
|
||||||
|
float* data;  // float samples; NOTE(review): presumably channel-interleaved - confirm against the producer
|
||||||
|
} sd_audio_t;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
uint32_t width;
|
uint32_t width;
|
||||||
uint32_t height;
|
uint32_t height;
|
||||||
@ -365,6 +376,7 @@ typedef struct {
|
|||||||
float strength;
|
float strength;
|
||||||
int64_t seed;
|
int64_t seed;
|
||||||
int video_frames;
|
int video_frames;
|
||||||
|
int fps;
|
||||||
float vace_strength;
|
float vace_strength;
|
||||||
sd_tiling_params_t vae_tiling_params;
|
sd_tiling_params_t vae_tiling_params;
|
||||||
sd_cache_params_t cache;
|
sd_cache_params_t cache;
|
||||||
@ -409,6 +421,7 @@ SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params);
|
|||||||
|
|
||||||
SD_API sd_ctx_t* new_sd_ctx(const sd_ctx_params_t* sd_ctx_params);
|
SD_API sd_ctx_t* new_sd_ctx(const sd_ctx_params_t* sd_ctx_params);
|
||||||
SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);
|
SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);
|
||||||
|
SD_API void free_sd_audio(sd_audio_t* audio);
|
||||||
|
|
||||||
SD_API void sd_sample_params_init(sd_sample_params_t* sample_params);
|
SD_API void sd_sample_params_init(sd_sample_params_t* sample_params);
|
||||||
SD_API char* sd_sample_params_to_str(const sd_sample_params_t* sample_params);
|
SD_API char* sd_sample_params_to_str(const sd_sample_params_t* sample_params);
|
||||||
@ -421,7 +434,11 @@ SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_para
|
|||||||
SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params);
|
SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params);
|
||||||
|
|
||||||
SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params);
|
SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params);
|
||||||
SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params, int* num_frames_out);
|
SD_API bool generate_video(sd_ctx_t* sd_ctx,
|
||||||
|
const sd_vid_gen_params_t* sd_vid_gen_params,
|
||||||
|
sd_image_t** frames_out,
|
||||||
|
int* num_frames_out,
|
||||||
|
sd_audio_t** audio_out);
|
||||||
|
|
||||||
typedef struct upscaler_ctx_t upscaler_ctx_t;
|
typedef struct upscaler_ctx_t upscaler_ctx_t;
|
||||||
|
|
||||||
|
|||||||
@ -103,6 +103,64 @@ namespace DiT {
|
|||||||
x = ggml_ext_slice(ctx, x, 0, 0, W); // [N, C, H, W]
|
x = ggml_ext_slice(ctx, x, 0, 0, W); // [N, C, H, W]
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Splits a video latent into non-overlapping (pt x ph x pw) patches and
// flattens each patch, together with the channel axis, into one feature vector.
//
//   x      : [N*C, T, H, W]
//   return : [N, t_len*h_len*w_len, C*pt*ph*pw]
//
// T, H and W must be exact multiples of pt, ph and pw, and the outermost
// dimension of x must be a multiple of N; both conditions are asserted.
inline ggml_tensor* patchify(ggml_context* ctx,
                             ggml_tensor* x,
                             int pt,
                             int ph,
                             int pw,
                             int64_t N = 1) {
    const int64_t width    = x->ne[0];
    const int64_t height   = x->ne[1];
    const int64_t frames   = x->ne[2];
    const int64_t channels = x->ne[3] / N;

    const int64_t t_len = frames / pt;
    const int64_t h_len = height / ph;
    const int64_t w_len = width / pw;

    GGML_ASSERT(channels * N == x->ne[3]);
    GGML_ASSERT(t_len * pt == frames && h_len * ph == height && w_len * pw == width);

    // Every step below is "reshape, then swap the two middle axes".
    auto swap_middle = [ctx](ggml_tensor* t) {
        return ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, t, 0, 2, 1, 3));
    };

    const int64_t row_elems    = pw * w_len;                       // one full row of W
    const int64_t patch_plane  = ph * pt;                          // pt*ph
    const int64_t patch_volume = pw * ph * pt;                     // pt*ph*pw
    const int64_t token_count  = w_len * h_len * t_len;            // patches per sample
    const int64_t outer_t      = t_len * channels * N;             // N*C*t_len
    const int64_t outer_th     = h_len * t_len * channels * N;     // N*C*t_len*h_len

    ggml_tensor* out = ggml_reshape_4d(ctx, x, row_elems, ph * h_len, pt, outer_t);  // [N*C*t_len, pt, h_len*ph, w_len*pw]
    out              = swap_middle(out);                                            // [N*C*t_len, h_len*ph, pt, w_len*pw]
    out              = ggml_reshape_4d(ctx, out, row_elems, pt, ph, outer_th);      // [N*C*t_len*h_len, ph, pt, w_len*pw]
    out              = swap_middle(out);                                            // [N*C*t_len*h_len, pt, ph, w_len*pw]
    out              = ggml_reshape_4d(ctx, out, pw, w_len, patch_plane, outer_th); // [N*C*t_len*h_len, pt*ph, w_len, pw]
    out              = swap_middle(out);                                            // [N*C*t_len*h_len, w_len, pt*ph, pw]
    out              = ggml_reshape_4d(ctx, out, patch_volume, token_count, channels, N);  // [N, C, t_len*h_len*w_len, pt*ph*pw]
    out              = swap_middle(out);                                                   // [N, t_len*h_len*w_len, C, pt*ph*pw]

    // Merge the channel and patch axes into the final feature axis.
    return ggml_reshape_4d(ctx, out, patch_volume * channels, token_count, N, 1);  // [N, t_len*h_len*w_len, C*pt*ph*pw]
}
|
||||||
|
|
||||||
|
// Reassembles a sequence of flattened patches into a dense video latent.
//
//   x      : [N, t_len*h_len*w_len, pt*ph*pw*C]
//   return : [N*C, t_len*pt, h_len*ph, w_len*pw]
//
// Within the feature axis of x the channel index varies fastest (the first
// reshape peels C off as the innermost dimension). The feature size must be
// an exact multiple of pt*ph*pw; this is asserted.
inline ggml_tensor* unpatchify(ggml_context* ctx,
                               ggml_tensor* x,
                               int64_t t_len,
                               int64_t h_len,
                               int64_t w_len,
                               int pt,
                               int ph,
                               int pw) {
    const int64_t batch    = x->ne[3];
    const int64_t channels = x->ne[0] / pt / ph / pw;

    GGML_ASSERT(channels * pt * ph * pw == x->ne[0]);

    // All steps after the first are "reshape, then swap the two middle axes".
    auto swap_middle = [ctx](ggml_tensor* t) {
        return ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, t, 0, 2, 1, 3));
    };

    const int64_t row_elems   = pw * w_len;                         // one full row of W
    const int64_t token_count = w_len * h_len * t_len;              // patches per sample
    const int64_t outer_th    = h_len * t_len * channels * batch;   // N*C*t_len*h_len
    const int64_t outer_t     = t_len * channels * batch;           // N*C*t_len

    ggml_tensor* out = ggml_reshape_4d(ctx, x, channels, pw * ph * pt, token_count, batch);  // [N, t_len*h_len*w_len, pt*ph*pw, C]
    // Move the channel axis outward, ahead of the token axis.
    out = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, out, 1, 2, 0, 3));                  // [N, C, t_len*h_len*w_len, pt*ph*pw]
    out = ggml_reshape_4d(ctx, out, pw, ph * pt, w_len, outer_th);                           // [N*C*t_len*h_len, w_len, pt*ph, pw]
    out = swap_middle(out);                                                                  // [N*C*t_len*h_len, pt*ph, w_len, pw]
    out = ggml_reshape_4d(ctx, out, row_elems, ph, pt, outer_th);                            // [N*C*t_len*h_len, pt, ph, w_len*pw]
    out = swap_middle(out);                                                                  // [N*C*t_len*h_len, ph, pt, w_len*pw]
    out = ggml_reshape_4d(ctx, out, row_elems, pt, ph * h_len, outer_t);                     // [N*C*t_len, h_len*ph, pt, w_len*pw]
    out = swap_middle(out);                                                                  // [N*C*t_len, pt, h_len*ph, w_len*pw]

    // Fold the patch-local axes back into the full T and H extents.
    return ggml_reshape_4d(ctx, out, row_elems, ph * h_len, pt * t_len, channels * batch);   // [N*C, t_len*pt, h_len*ph, w_len*pw]
}
|
||||||
} // namespace DiT
|
} // namespace DiT
|
||||||
|
|
||||||
#endif // __COMMON_DIT_HPP__
|
#endif // __COMMON_DIT_HPP__
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
#ifndef __CONDITIONER_HPP__
|
#ifndef __CONDITIONER_HPP__
|
||||||
#define __CONDITIONER_HPP__
|
#define __CONDITIONER_HPP__
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <limits>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
|
|
||||||
#include "clip.hpp"
|
#include "clip.hpp"
|
||||||
@ -66,6 +68,17 @@ static inline sd::Tensor<float> apply_token_weights(sd::Tensor<float> hidden_sta
|
|||||||
return hidden_states;
|
return hidden_states;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool all_one = true;
|
||||||
|
for (float weight : weights) {
|
||||||
|
if (weight != 1.0f) {
|
||||||
|
all_one = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (all_one) {
|
||||||
|
return hidden_states;
|
||||||
|
}
|
||||||
|
|
||||||
if (hidden_states.dim() == 1) {
|
if (hidden_states.dim() == 1) {
|
||||||
hidden_states.unsqueeze_(1);
|
hidden_states.unsqueeze_(1);
|
||||||
}
|
}
|
||||||
@ -77,7 +90,7 @@ static inline sd::Tensor<float> apply_token_weights(sd::Tensor<float> hidden_sta
|
|||||||
chunk_weights.reshape_({1, static_cast<int64_t>(weights.size())});
|
chunk_weights.reshape_({1, static_cast<int64_t>(weights.size())});
|
||||||
hidden_states *= chunk_weights;
|
hidden_states *= chunk_weights;
|
||||||
float new_mean = hidden_states.mean();
|
float new_mean = hidden_states.mean();
|
||||||
if (new_mean != 0.0f) {
|
if (std::isfinite(original_mean) && std::isfinite(new_mean) && new_mean != 0.0f) {
|
||||||
hidden_states *= (original_mean / new_mean);
|
hidden_states *= (original_mean / new_mean);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2022,4 +2035,277 @@ struct LLMEmbedder : public Conditioner {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct LTXAVTextProjection : public GGMLBlock {
|
||||||
|
static constexpr int64_t kHiddenSize = 3840;
|
||||||
|
static constexpr int64_t kNumStates = 49;
|
||||||
|
bool dual_projection = false;
|
||||||
|
|
||||||
|
LTXAVTextProjection(bool dual_projection = false)
|
||||||
|
: dual_projection(dual_projection) {
|
||||||
|
if (dual_projection) {
|
||||||
|
blocks["video_aggregate_embed"] = std::make_shared<Linear>(kHiddenSize * kNumStates, 4096, true);
|
||||||
|
blocks["audio_aggregate_embed"] = std::make_shared<Linear>(kHiddenSize * kNumStates, 2048, true);
|
||||||
|
} else {
|
||||||
|
blocks["projection"] = std::make_shared<Linear>(kHiddenSize * kNumStates, kHiddenSize, false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
|
||||||
|
if (!dual_projection) {
|
||||||
|
auto projection = std::dynamic_pointer_cast<Linear>(blocks["projection"]);
|
||||||
|
return projection->forward(ctx, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto video_projection = std::dynamic_pointer_cast<Linear>(blocks["video_aggregate_embed"]);
|
||||||
|
auto audio_projection = std::dynamic_pointer_cast<Linear>(blocks["audio_aggregate_embed"]);
|
||||||
|
auto video_in = ggml_ext_scale(ctx->ggml_ctx, x, std::sqrt(4096.f / static_cast<float>(kHiddenSize)));
|
||||||
|
auto audio_in = ggml_ext_scale(ctx->ggml_ctx, x, std::sqrt(2048.f / static_cast<float>(kHiddenSize)));
|
||||||
|
auto video = video_projection->forward(ctx, video_in);
|
||||||
|
auto audio = audio_projection->forward(ctx, audio_in);
|
||||||
|
return ggml_concat(ctx->ggml_ctx, video, audio, 0);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct LTXAVTextProjectionRunner : public GGMLRunner {
|
||||||
|
LTXAVTextProjection model;
|
||||||
|
|
||||||
|
LTXAVTextProjectionRunner(ggml_backend_t backend,
|
||||||
|
ggml_backend_t params_backend,
|
||||||
|
const String2TensorStorage& tensor_storage_map = {},
|
||||||
|
const std::string& prefix = "")
|
||||||
|
: GGMLRunner(backend, params_backend),
|
||||||
|
model(tensor_storage_map.find(prefix + ".video_aggregate_embed.weight") != tensor_storage_map.end()) {
|
||||||
|
model.init(params_ctx, tensor_storage_map, prefix);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string get_desc() override {
|
||||||
|
return "ltxav_text_projection";
|
||||||
|
}
|
||||||
|
|
||||||
|
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string& prefix) {
|
||||||
|
model.get_param_tensors(tensors, prefix);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_cgraph* build_graph(const sd::Tensor<float>& x_tensor) {
|
||||||
|
ggml_cgraph* gf = ggml_new_graph(compute_ctx);
|
||||||
|
auto x = make_input(x_tensor);
|
||||||
|
auto runner_ctx = get_context();
|
||||||
|
auto out = model.forward(&runner_ctx, x);
|
||||||
|
ggml_build_forward_expand(gf, out);
|
||||||
|
return gf;
|
||||||
|
}
|
||||||
|
|
||||||
|
sd::Tensor<float> compute(int n_threads, const sd::Tensor<float>& x) {
|
||||||
|
auto get_graph = [&]() -> ggml_cgraph* {
|
||||||
|
return build_graph(x);
|
||||||
|
};
|
||||||
|
return take_or_empty(GGMLRunner::compute<float>(get_graph, n_threads, true));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct LTXAVEmbedder : public Conditioner {
|
||||||
|
static constexpr int64_t kHiddenSize = 3840;
|
||||||
|
static constexpr int64_t kNumStates = 49;
|
||||||
|
static constexpr int64_t kMinLength = 1024;
|
||||||
|
|
||||||
|
std::shared_ptr<GemmaTokenizer> tokenizer;
|
||||||
|
std::shared_ptr<LLM::LLMRunner> llm;
|
||||||
|
std::shared_ptr<LTXAVTextProjectionRunner> projector;
|
||||||
|
bool dual_projection = false;
|
||||||
|
|
||||||
|
LTXAVEmbedder(ggml_backend_t backend,
|
||||||
|
ggml_backend_t params_backend,
|
||||||
|
const String2TensorStorage& tensor_storage_map = {},
|
||||||
|
const std::string& llm_prefix = "text_encoders.llm",
|
||||||
|
const std::string& projector_prefix = "text_embedding_projection") {
|
||||||
|
tokenizer = std::make_shared<GemmaTokenizer>();
|
||||||
|
llm = std::make_shared<LLM::LLMRunner>(LLM::LLMArch::GEMMA3_12B,
|
||||||
|
backend,
|
||||||
|
params_backend,
|
||||||
|
tensor_storage_map,
|
||||||
|
llm_prefix,
|
||||||
|
false);
|
||||||
|
dual_projection = tensor_storage_map.find(projector_prefix + ".video_aggregate_embed.weight") != tensor_storage_map.end();
|
||||||
|
projector = std::make_shared<LTXAVTextProjectionRunner>(backend,
|
||||||
|
params_backend,
|
||||||
|
tensor_storage_map,
|
||||||
|
projector_prefix);
|
||||||
|
}
|
||||||
|
|
||||||
|
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) override {
|
||||||
|
llm->get_param_tensors(tensors, "text_encoders.llm");
|
||||||
|
projector->get_param_tensors(tensors, "text_embedding_projection");
|
||||||
|
}
|
||||||
|
|
||||||
|
void alloc_params_buffer() override {
|
||||||
|
llm->alloc_params_buffer();
|
||||||
|
projector->alloc_params_buffer();
|
||||||
|
}
|
||||||
|
|
||||||
|
void free_params_buffer() override {
|
||||||
|
llm->free_params_buffer();
|
||||||
|
projector->free_params_buffer();
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t get_params_buffer_size() override {
|
||||||
|
return llm->get_params_buffer_size() + projector->get_params_buffer_size();
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_flash_attention_enabled(bool enabled) override {
|
||||||
|
llm->set_flash_attention_enabled(enabled);
|
||||||
|
projector->set_flash_attention_enabled(enabled);
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
|
||||||
|
llm->set_weight_adapter(adapter);
|
||||||
|
projector->set_weight_adapter(adapter);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> tokenize(std::string text,
|
||||||
|
const std::pair<int, int>& attn_range) {
|
||||||
|
std::vector<std::pair<std::string, float>> parsed_attention;
|
||||||
|
if (attn_range.first >= 0 && attn_range.second > 0) {
|
||||||
|
if (attn_range.first > 0) {
|
||||||
|
parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f);
|
||||||
|
}
|
||||||
|
if (attn_range.second - attn_range.first > 0) {
|
||||||
|
auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first));
|
||||||
|
parsed_attention.insert(parsed_attention.end(), new_parsed_attention.begin(), new_parsed_attention.end());
|
||||||
|
}
|
||||||
|
if (static_cast<size_t>(attn_range.second) < text.size()) {
|
||||||
|
parsed_attention.emplace_back(text.substr(attn_range.second), 1.f);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
parsed_attention.emplace_back(text, 1.f);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> tokens;
|
||||||
|
std::vector<float> weights;
|
||||||
|
for (const auto& item : parsed_attention) {
|
||||||
|
auto curr_tokens = tokenizer->encode(item.first, nullptr);
|
||||||
|
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
|
||||||
|
weights.insert(weights.end(), curr_tokens.size(), item.second);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> mask;
|
||||||
|
tokenizer->pad_tokens(tokens, &weights, &mask, kMinLength);
|
||||||
|
return {tokens, weights, mask};
|
||||||
|
}
|
||||||
|
|
||||||
|
sd::Tensor<float> encode_prompt(int n_threads,
|
||||||
|
const std::string& prompt,
|
||||||
|
const std::pair<int, int>& prompt_attn_range) {
|
||||||
|
auto tokens_weights_mask = tokenize(prompt, prompt_attn_range);
|
||||||
|
auto& tokens = std::get<0>(tokens_weights_mask);
|
||||||
|
auto& weights = std::get<1>(tokens_weights_mask);
|
||||||
|
auto& mask = std::get<2>(tokens_weights_mask);
|
||||||
|
|
||||||
|
sd::Tensor<int32_t> input_ids({static_cast<int64_t>(tokens.size())}, std::vector<int32_t>(tokens.begin(), tokens.end()));
|
||||||
|
sd::Tensor<float> attention_mask;
|
||||||
|
if (!mask.empty()) {
|
||||||
|
const float mask_min = std::numeric_limits<float>::lowest() / 4.0f;
|
||||||
|
attention_mask = sd::Tensor<float>({static_cast<int64_t>(mask.size()), static_cast<int64_t>(mask.size())});
|
||||||
|
for (size_t i1 = 0; i1 < mask.size(); ++i1) {
|
||||||
|
for (size_t i0 = 0; i0 < mask.size(); ++i0) {
|
||||||
|
float value = 0.0f;
|
||||||
|
if (mask[i0] == 0.0f) {
|
||||||
|
value += mask_min;
|
||||||
|
}
|
||||||
|
if (i0 > i1) {
|
||||||
|
value += mask_min;
|
||||||
|
}
|
||||||
|
attention_mask[static_cast<int64_t>(i0 + mask.size() * i1)] = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto hidden_states = llm->compute(n_threads,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
{},
|
||||||
|
{},
|
||||||
|
true);
|
||||||
|
GGML_ASSERT(!hidden_states.empty());
|
||||||
|
hidden_states = apply_token_weights(std::move(hidden_states), weights);
|
||||||
|
|
||||||
|
int64_t valid_tokens = 0;
|
||||||
|
for (float value : mask) {
|
||||||
|
valid_tokens += static_cast<int64_t>(value > 0.0f);
|
||||||
|
}
|
||||||
|
GGML_ASSERT(valid_tokens > 0);
|
||||||
|
|
||||||
|
hidden_states = sd::ops::slice(hidden_states,
|
||||||
|
1,
|
||||||
|
hidden_states.shape()[1] - valid_tokens,
|
||||||
|
hidden_states.shape()[1]);
|
||||||
|
hidden_states.reshape_({kHiddenSize, kNumStates, valid_tokens});
|
||||||
|
hidden_states = hidden_states.permute({1, 0, 2});
|
||||||
|
|
||||||
|
if (dual_projection) {
|
||||||
|
for (int64_t state_idx = 0; state_idx < kNumStates; ++state_idx) {
|
||||||
|
for (int64_t token_idx = 0; token_idx < valid_tokens; ++token_idx) {
|
||||||
|
double sq_sum = 0.0;
|
||||||
|
for (int64_t hidden_idx = 0; hidden_idx < kHiddenSize; ++hidden_idx) {
|
||||||
|
float value = hidden_states.index(state_idx, hidden_idx, token_idx);
|
||||||
|
sq_sum += static_cast<double>(value) * static_cast<double>(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
float inv_rms = 1.0f / std::sqrt(static_cast<float>(sq_sum / static_cast<double>(kHiddenSize)) + 1e-6f);
|
||||||
|
for (int64_t hidden_idx = 0; hidden_idx < kHiddenSize; ++hidden_idx) {
|
||||||
|
hidden_states.index(state_idx, hidden_idx, token_idx) *= inv_rms;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int64_t state_idx = 0; state_idx < kNumStates; ++state_idx) {
|
||||||
|
double sum = 0.0;
|
||||||
|
float min_value = std::numeric_limits<float>::infinity();
|
||||||
|
float max_value = -std::numeric_limits<float>::infinity();
|
||||||
|
for (int64_t token_idx = 0; token_idx < valid_tokens; ++token_idx) {
|
||||||
|
for (int64_t hidden_idx = 0; hidden_idx < kHiddenSize; ++hidden_idx) {
|
||||||
|
float value = hidden_states.index(state_idx, hidden_idx, token_idx);
|
||||||
|
sum += value;
|
||||||
|
min_value = std::min(min_value, value);
|
||||||
|
max_value = std::max(max_value, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
float mean_value = static_cast<float>(sum / static_cast<double>(kHiddenSize * valid_tokens));
|
||||||
|
float denom = max_value - min_value + 1e-6f;
|
||||||
|
float scale_value = 8.0f / denom;
|
||||||
|
for (int64_t token_idx = 0; token_idx < valid_tokens; ++token_idx) {
|
||||||
|
for (int64_t hidden_idx = 0; hidden_idx < kHiddenSize; ++hidden_idx) {
|
||||||
|
float value = hidden_states.index(state_idx, hidden_idx, token_idx);
|
||||||
|
hidden_states.index(state_idx, hidden_idx, token_idx) = (value - mean_value) * scale_value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
hidden_states.reshape_({kNumStates * kHiddenSize, valid_tokens});
|
||||||
|
return projector->compute(n_threads, hidden_states);
|
||||||
|
}
|
||||||
|
|
||||||
|
SDCondition get_learned_condition(int n_threads,
|
||||||
|
const ConditionerParams& conditioner_params) override {
|
||||||
|
int64_t t0 = ggml_time_ms();
|
||||||
|
|
||||||
|
std::string prompt;
|
||||||
|
std::pair<int, int> prompt_attn_range;
|
||||||
|
prompt_attn_range.first = static_cast<int>(prompt.size());
|
||||||
|
prompt += conditioner_params.text;
|
||||||
|
prompt_attn_range.second = static_cast<int>(prompt.size());
|
||||||
|
|
||||||
|
auto hidden_states = encode_prompt(n_threads, prompt, prompt_attn_range);
|
||||||
|
GGML_ASSERT(!hidden_states.empty());
|
||||||
|
|
||||||
|
int64_t t1 = ggml_time_ms();
|
||||||
|
LOG_DEBUG("computing LTXAV condition graph completed, taking %" PRId64 " ms", t1 - t0);
|
||||||
|
|
||||||
|
SDCondition result;
|
||||||
|
result.c_crossattn = std::move(hidden_states);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
147
src/denoiser.hpp
147
src/denoiser.hpp
@ -1,6 +1,8 @@
|
|||||||
#ifndef __DENOISER_HPP__
|
#ifndef __DENOISER_HPP__
|
||||||
#define __DENOISER_HPP__
|
#define __DENOISER_HPP__
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cctype>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
@ -480,6 +482,141 @@ struct KLOptimalScheduler : SigmaScheduler {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct LTX2Scheduler : SigmaScheduler {
|
||||||
|
int token_count = 4096;
|
||||||
|
float max_shift = 2.05f;
|
||||||
|
float base_shift = 0.95f;
|
||||||
|
bool stretch = true;
|
||||||
|
float terminal = 0.1f;
|
||||||
|
|
||||||
|
explicit LTX2Scheduler(int token_count, const char* extra_sample_args = nullptr)
|
||||||
|
: token_count(token_count > 0 ? token_count : 4096) {
|
||||||
|
parse_extra_sample_args(extra_sample_args);
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string trim(std::string value) {
|
||||||
|
const char* whitespace = " \t\r\n";
|
||||||
|
size_t begin = value.find_first_not_of(whitespace);
|
||||||
|
if (begin == std::string::npos) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
size_t end = value.find_last_not_of(whitespace);
|
||||||
|
return value.substr(begin, end - begin + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
void parse_extra_sample_args(const char* extra_sample_args) {
|
||||||
|
if (extra_sample_args == nullptr || extra_sample_args[0] == '\0') {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string raw(extra_sample_args);
|
||||||
|
size_t start = 0;
|
||||||
|
auto parse_arg = [&](const std::string& item) {
|
||||||
|
std::string token = trim(item);
|
||||||
|
if (token.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
size_t eq = token.find('=');
|
||||||
|
if (eq == std::string::npos) {
|
||||||
|
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string key = trim(token.substr(0, eq));
|
||||||
|
std::string value = trim(token.substr(eq + 1));
|
||||||
|
auto parse_float = [&](float* out) -> bool {
|
||||||
|
try {
|
||||||
|
size_t consumed = 0;
|
||||||
|
float parsed = std::stof(value, &consumed);
|
||||||
|
if (!trim(value.substr(consumed)).empty()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
*out = parsed;
|
||||||
|
return true;
|
||||||
|
} catch (const std::exception&) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
try {
|
||||||
|
if (key == "max_shift") {
|
||||||
|
if (!parse_float(&max_shift)) {
|
||||||
|
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
|
||||||
|
}
|
||||||
|
} else if (key == "base_shift") {
|
||||||
|
if (!parse_float(&base_shift)) {
|
||||||
|
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
|
||||||
|
}
|
||||||
|
} else if (key == "terminal") {
|
||||||
|
if (!parse_float(&terminal)) {
|
||||||
|
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
|
||||||
|
}
|
||||||
|
} else if (key == "stretch") {
|
||||||
|
std::string v = value;
|
||||||
|
std::transform(v.begin(), v.end(), v.begin(), [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
|
||||||
|
if (v == "1" || v == "true" || v == "yes" || v == "on") {
|
||||||
|
stretch = true;
|
||||||
|
} else if (v == "0" || v == "false" || v == "no" || v == "off") {
|
||||||
|
stretch = false;
|
||||||
|
} else {
|
||||||
|
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
LOG_WARN("ignoring unknown ltx2 scheduler arg '%s'", key.c_str());
|
||||||
|
}
|
||||||
|
} catch (const std::exception&) {
|
||||||
|
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
for (size_t pos = 0; pos <= raw.size(); ++pos) {
|
||||||
|
if (pos == raw.size() || raw[pos] == ',' || raw[pos] == ';') {
|
||||||
|
parse_arg(raw.substr(start, pos - start));
|
||||||
|
start = pos + 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> get_sigmas(uint32_t n, float /*sigma_min*/, float /*sigma_max*/, t_to_sigma_t /*t_to_sigma*/) override {
|
||||||
|
std::vector<float> sigmas;
|
||||||
|
if (n == 0) {
|
||||||
|
sigmas.push_back(0.0f);
|
||||||
|
return sigmas;
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr float base_shift_anchor = 1024.0f;
|
||||||
|
constexpr float max_shift_anchor = 4096.0f;
|
||||||
|
float m = (max_shift - base_shift) / (max_shift_anchor - base_shift_anchor);
|
||||||
|
float b = base_shift - m * base_shift_anchor;
|
||||||
|
float sigma_shift = static_cast<float>(token_count) * m + b;
|
||||||
|
float exp_shift = std::exp(sigma_shift);
|
||||||
|
float target_terminal = std::clamp(terminal, 0.0f, 0.99f);
|
||||||
|
|
||||||
|
LOG_DEBUG("LTX2 scheduler: tokens=%d, shift=%.4f, stretch=%d, terminal=%.4f", token_count, sigma_shift, stretch ? 1 : 0, target_terminal);
|
||||||
|
|
||||||
|
sigmas.reserve(n + 1);
|
||||||
|
for (uint32_t i = 0; i <= n; ++i) {
|
||||||
|
float sigma = 1.0f - static_cast<float>(i) / static_cast<float>(n);
|
||||||
|
if (sigma != 0.0f) {
|
||||||
|
sigma = exp_shift / (exp_shift + (1.0f / sigma - 1.0f));
|
||||||
|
}
|
||||||
|
sigmas.push_back(sigma);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (stretch && sigmas.size() > 2) {
|
||||||
|
float one_minus_last = 1.0f - sigmas[n - 1];
|
||||||
|
float scale_factor = one_minus_last / (1.0f - target_terminal);
|
||||||
|
if (scale_factor > 1e-8f) {
|
||||||
|
for (uint32_t i = 0; i < n; ++i) {
|
||||||
|
sigmas[i] = 1.0f - (1.0f - sigmas[i]) / scale_factor;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sigmas[n] = 0.0f;
|
||||||
|
return sigmas;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct Denoiser {
|
struct Denoiser {
|
||||||
virtual float sigma_min() = 0;
|
virtual float sigma_min() = 0;
|
||||||
virtual float sigma_max() = 0;
|
virtual float sigma_max() = 0;
|
||||||
@ -492,7 +629,7 @@ struct Denoiser {
|
|||||||
virtual sd::Tensor<float> inverse_noise_scaling(float sigma,
|
virtual sd::Tensor<float> inverse_noise_scaling(float sigma,
|
||||||
const sd::Tensor<float>& latent) = 0;
|
const sd::Tensor<float>& latent) = 0;
|
||||||
|
|
||||||
virtual std::vector<float> get_sigmas(uint32_t n, int /*image_seq_len*/, scheduler_t scheduler_type, SDVersion version) {
|
virtual std::vector<float> get_sigmas(uint32_t n, int image_seq_len, scheduler_t scheduler_type, SDVersion version, const char* extra_sample_args = nullptr) {
|
||||||
auto bound_t_to_sigma = std::bind(&Denoiser::t_to_sigma, this, std::placeholders::_1);
|
auto bound_t_to_sigma = std::bind(&Denoiser::t_to_sigma, this, std::placeholders::_1);
|
||||||
std::shared_ptr<SigmaScheduler> scheduler;
|
std::shared_ptr<SigmaScheduler> scheduler;
|
||||||
switch (scheduler_type) {
|
switch (scheduler_type) {
|
||||||
@ -540,6 +677,10 @@ struct Denoiser {
|
|||||||
LOG_INFO("get_sigmas with LCM scheduler");
|
LOG_INFO("get_sigmas with LCM scheduler");
|
||||||
scheduler = std::make_shared<LCMScheduler>();
|
scheduler = std::make_shared<LCMScheduler>();
|
||||||
break;
|
break;
|
||||||
|
case LTX2_SCHEDULER:
|
||||||
|
LOG_INFO("get_sigmas with LTX2 scheduler");
|
||||||
|
scheduler = std::make_shared<LTX2Scheduler>(image_seq_len, extra_sample_args);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
LOG_INFO("get_sigmas with discrete scheduler (default)");
|
LOG_INFO("get_sigmas with discrete scheduler (default)");
|
||||||
scheduler = std::make_shared<DiscreteScheduler>();
|
scheduler = std::make_shared<DiscreteScheduler>();
|
||||||
@ -745,11 +886,11 @@ struct Flux2FlowDenoiser : public FluxFlowDenoiser {
|
|||||||
return mu;
|
return mu;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<float> get_sigmas(uint32_t n, int image_seq_len, scheduler_t scheduler_type, SDVersion version) override {
|
std::vector<float> get_sigmas(uint32_t n, int image_seq_len, scheduler_t scheduler_type, SDVersion version, const char* extra_sample_args = nullptr) override {
|
||||||
float mu = compute_empirical_mu(n, image_seq_len);
|
float mu = compute_empirical_mu(n, image_seq_len);
|
||||||
LOG_DEBUG("Flux2FlowDenoiser: set shift to %.3f", mu);
|
LOG_DEBUG("Flux2FlowDenoiser: set shift to %.3f", mu);
|
||||||
set_shift(mu);
|
set_shift(mu);
|
||||||
return Denoiser::get_sigmas(n, image_seq_len, scheduler_type, version);
|
return Denoiser::get_sigmas(n, image_seq_len, scheduler_type, version, extra_sample_args);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -6,6 +6,7 @@
|
|||||||
#include "ernie_image.hpp"
|
#include "ernie_image.hpp"
|
||||||
#include "flux.hpp"
|
#include "flux.hpp"
|
||||||
#include "hidream_o1.hpp"
|
#include "hidream_o1.hpp"
|
||||||
|
#include "ltxv.hpp"
|
||||||
#include "mmdit.hpp"
|
#include "mmdit.hpp"
|
||||||
#include "qwen_image.hpp"
|
#include "qwen_image.hpp"
|
||||||
#include "tensor_ggml.hpp"
|
#include "tensor_ggml.hpp"
|
||||||
@ -16,6 +17,8 @@
|
|||||||
struct DiffusionParams {
|
struct DiffusionParams {
|
||||||
const sd::Tensor<float>* x = nullptr;
|
const sd::Tensor<float>* x = nullptr;
|
||||||
const sd::Tensor<float>* timesteps = nullptr;
|
const sd::Tensor<float>* timesteps = nullptr;
|
||||||
|
const sd::Tensor<float>* audio_x = nullptr;
|
||||||
|
const sd::Tensor<float>* audio_timesteps = nullptr;
|
||||||
const sd::Tensor<float>* context = nullptr;
|
const sd::Tensor<float>* context = nullptr;
|
||||||
const sd::Tensor<float>* c_concat = nullptr;
|
const sd::Tensor<float>* c_concat = nullptr;
|
||||||
const sd::Tensor<float>* y = nullptr;
|
const sd::Tensor<float>* y = nullptr;
|
||||||
@ -35,6 +38,8 @@ struct DiffusionParams {
|
|||||||
float control_strength = 0.f;
|
float control_strength = 0.f;
|
||||||
const sd::Tensor<float>* vace_context = nullptr;
|
const sd::Tensor<float>* vace_context = nullptr;
|
||||||
float vace_strength = 1.f;
|
float vace_strength = 1.f;
|
||||||
|
int audio_length = 0;
|
||||||
|
float frame_rate = 24.f;
|
||||||
const std::vector<int>* skip_layers = nullptr;
|
const std::vector<int>* skip_layers = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -695,4 +700,74 @@ struct ErnieImageModel : public DiffusionModel {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct LTXAVModel : public DiffusionModel {
|
||||||
|
std::string prefix;
|
||||||
|
LTXV::LTXAVRunner ltxav;
|
||||||
|
|
||||||
|
LTXAVModel(ggml_backend_t backend,
|
||||||
|
ggml_backend_t params_backend,
|
||||||
|
const String2TensorStorage& tensor_storage_map = {},
|
||||||
|
const std::string prefix = "model.diffusion_model")
|
||||||
|
: prefix(prefix), ltxav(backend, params_backend, tensor_storage_map, prefix) {
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string get_desc() override {
|
||||||
|
return ltxav.get_desc();
|
||||||
|
}
|
||||||
|
|
||||||
|
void alloc_params_buffer() override {
|
||||||
|
ltxav.alloc_params_buffer();
|
||||||
|
}
|
||||||
|
|
||||||
|
void free_params_buffer() override {
|
||||||
|
ltxav.free_params_buffer();
|
||||||
|
}
|
||||||
|
|
||||||
|
void free_compute_buffer() override {
|
||||||
|
ltxav.free_compute_buffer();
|
||||||
|
}
|
||||||
|
|
||||||
|
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) override {
|
||||||
|
ltxav.get_param_tensors(tensors, prefix);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t get_params_buffer_size() override {
|
||||||
|
return ltxav.get_params_buffer_size();
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
|
||||||
|
ltxav.set_weight_adapter(adapter);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t get_adm_in_channels() override {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_flash_attention_enabled(bool enabled) override {
|
||||||
|
ltxav.set_flash_attention_enabled(enabled);
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
|
||||||
|
ltxav.set_max_graph_vram_bytes(max_vram_bytes);
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_circular_axes(bool circular_x, bool circular_y) override {
|
||||||
|
ltxav.set_circular_axes(circular_x, circular_y);
|
||||||
|
}
|
||||||
|
|
||||||
|
sd::Tensor<float> compute(int n_threads,
|
||||||
|
const DiffusionParams& diffusion_params) override {
|
||||||
|
GGML_ASSERT(diffusion_params.x != nullptr);
|
||||||
|
GGML_ASSERT(diffusion_params.timesteps != nullptr);
|
||||||
|
return ltxav.compute(n_threads,
|
||||||
|
*diffusion_params.x,
|
||||||
|
*diffusion_params.timesteps,
|
||||||
|
tensor_or_empty(diffusion_params.context),
|
||||||
|
tensor_or_empty(diffusion_params.audio_x),
|
||||||
|
tensor_or_empty(diffusion_params.audio_timesteps),
|
||||||
|
diffusion_params.audio_length,
|
||||||
|
diffusion_params.frame_rate);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@ -1127,18 +1127,33 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_conv_3d(ggml_context* ctx,
|
|||||||
ggml_tensor* w,
|
ggml_tensor* w,
|
||||||
ggml_tensor* b,
|
ggml_tensor* b,
|
||||||
int64_t IC,
|
int64_t IC,
|
||||||
int s0 = 1,
|
int s0 = 1,
|
||||||
int s1 = 1,
|
int s1 = 1,
|
||||||
int s2 = 1,
|
int s2 = 1,
|
||||||
int p0 = 0,
|
int p0 = 0,
|
||||||
int p1 = 0,
|
int p1 = 0,
|
||||||
int p2 = 0,
|
int p2 = 0,
|
||||||
int d0 = 1,
|
int d0 = 1,
|
||||||
int d1 = 1,
|
int d1 = 1,
|
||||||
int d2 = 1) {
|
int d2 = 1,
|
||||||
int64_t OC = w->ne[3] / IC;
|
bool force_prec_f32 = false) {
|
||||||
int64_t N = x->ne[3] / IC;
|
if (force_prec_f32) {
|
||||||
x = ggml_conv_3d(ctx, w, x, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2);
|
ggml_tensor* im2col = ggml_im2col_3d(ctx, w, x, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, w->type);
|
||||||
|
|
||||||
|
int64_t OC = w->ne[3] / IC;
|
||||||
|
int64_t N = x->ne[3] / IC;
|
||||||
|
x = ggml_mul_mat(ctx,
|
||||||
|
ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]),
|
||||||
|
ggml_reshape_2d(ctx, w, w->ne[0] * w->ne[1] * w->ne[2] * IC, OC));
|
||||||
|
ggml_mul_mat_set_prec(x, GGML_PREC_F32);
|
||||||
|
|
||||||
|
int64_t OD = im2col->ne[3] / N;
|
||||||
|
x = ggml_reshape_4d(ctx, x, im2col->ne[1] * im2col->ne[2], OD, N, OC);
|
||||||
|
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 1, 3, 2));
|
||||||
|
x = ggml_reshape_4d(ctx, x, im2col->ne[1], im2col->ne[2], OD, OC * N);
|
||||||
|
} else {
|
||||||
|
x = ggml_conv_3d(ctx, w, x, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2);
|
||||||
|
}
|
||||||
|
|
||||||
if (b != nullptr) {
|
if (b != nullptr) {
|
||||||
b = ggml_reshape_4d(ctx, b, 1, 1, 1, b->ne[0]); // [OC, 1, 1, 1]
|
b = ggml_reshape_4d(ctx, b, 1, 1, 1, b->ne[0]); // [OC, 1, 1, 1]
|
||||||
@ -3133,6 +3148,7 @@ protected:
|
|||||||
std::tuple<int, int, int> padding;
|
std::tuple<int, int, int> padding;
|
||||||
std::tuple<int, int, int> dilation;
|
std::tuple<int, int, int> dilation;
|
||||||
bool bias;
|
bool bias;
|
||||||
|
bool force_prec_f32;
|
||||||
std::string prefix;
|
std::string prefix;
|
||||||
|
|
||||||
void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override {
|
void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override {
|
||||||
@ -3156,14 +3172,16 @@ public:
|
|||||||
std::tuple<int, int, int> stride = {1, 1, 1},
|
std::tuple<int, int, int> stride = {1, 1, 1},
|
||||||
std::tuple<int, int, int> padding = {0, 0, 0},
|
std::tuple<int, int, int> padding = {0, 0, 0},
|
||||||
std::tuple<int, int, int> dilation = {1, 1, 1},
|
std::tuple<int, int, int> dilation = {1, 1, 1},
|
||||||
bool bias = true)
|
bool bias = true,
|
||||||
|
bool force_prec_f32 = false)
|
||||||
: in_channels(in_channels),
|
: in_channels(in_channels),
|
||||||
out_channels(out_channels),
|
out_channels(out_channels),
|
||||||
kernel_size(kernel_size),
|
kernel_size(kernel_size),
|
||||||
stride(stride),
|
stride(stride),
|
||||||
padding(padding),
|
padding(padding),
|
||||||
dilation(dilation),
|
dilation(dilation),
|
||||||
bias(bias) {}
|
bias(bias),
|
||||||
|
force_prec_f32(force_prec_f32) {}
|
||||||
|
|
||||||
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
|
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
|
||||||
ggml_tensor* w = params["weight"];
|
ggml_tensor* w = params["weight"];
|
||||||
@ -3183,7 +3201,8 @@ public:
|
|||||||
return ggml_ext_conv_3d(ctx->ggml_ctx, x, w, b, in_channels,
|
return ggml_ext_conv_3d(ctx->ggml_ctx, x, w, b, in_channels,
|
||||||
std::get<2>(stride), std::get<1>(stride), std::get<0>(stride),
|
std::get<2>(stride), std::get<1>(stride), std::get<0>(stride),
|
||||||
std::get<2>(padding), std::get<1>(padding), std::get<0>(padding),
|
std::get<2>(padding), std::get<1>(padding), std::get<0>(padding),
|
||||||
std::get<2>(dilation), std::get<1>(dilation), std::get<0>(dilation));
|
std::get<2>(dilation), std::get<1>(dilation), std::get<0>(dilation),
|
||||||
|
force_prec_f32);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
364
src/llm.hpp
364
src/llm.hpp
@ -7,6 +7,7 @@
|
|||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
#include <limits>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
@ -21,6 +22,7 @@
|
|||||||
#include "json.hpp"
|
#include "json.hpp"
|
||||||
#include "rope.hpp"
|
#include "rope.hpp"
|
||||||
#include "tokenizers/bpe_tokenizer.h"
|
#include "tokenizers/bpe_tokenizer.h"
|
||||||
|
#include "tokenizers/gemma_tokenizer.h"
|
||||||
#include "tokenizers/mistral_tokenizer.h"
|
#include "tokenizers/mistral_tokenizer.h"
|
||||||
#include "tokenizers/qwen2_tokenizer.h"
|
#include "tokenizers/qwen2_tokenizer.h"
|
||||||
|
|
||||||
@ -33,6 +35,7 @@ namespace LLM {
|
|||||||
QWEN3_VL,
|
QWEN3_VL,
|
||||||
MISTRAL_SMALL_3_2,
|
MISTRAL_SMALL_3_2,
|
||||||
MINISTRAL_3_3B,
|
MINISTRAL_3_3B,
|
||||||
|
GEMMA3_12B,
|
||||||
ARCH_COUNT,
|
ARCH_COUNT,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -42,6 +45,12 @@ namespace LLM {
|
|||||||
"qwen3vl",
|
"qwen3vl",
|
||||||
"mistral_small3.2",
|
"mistral_small3.2",
|
||||||
"ministral3.3b",
|
"ministral3.3b",
|
||||||
|
"gemma3_12b",
|
||||||
|
};
|
||||||
|
|
||||||
|
enum class MLPActivation {
|
||||||
|
SILU,
|
||||||
|
GELU_TANH,
|
||||||
};
|
};
|
||||||
|
|
||||||
enum class LLMVisionArch {
|
enum class LLMVisionArch {
|
||||||
@ -66,23 +75,71 @@ namespace LLM {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct LLMParams {
|
struct LLMParams {
|
||||||
LLMArch arch = LLMArch::QWEN2_5_VL;
|
LLMArch arch = LLMArch::QWEN2_5_VL;
|
||||||
int64_t num_layers = 28;
|
int64_t num_layers = 28;
|
||||||
int64_t hidden_size = 3584;
|
int64_t hidden_size = 3584;
|
||||||
int64_t intermediate_size = 18944;
|
int64_t intermediate_size = 18944;
|
||||||
int num_heads = 28;
|
int num_heads = 28;
|
||||||
int num_kv_heads = 4;
|
int num_kv_heads = 4;
|
||||||
int head_dim = 128;
|
int head_dim = 128;
|
||||||
bool qkv_bias = true;
|
bool qkv_bias = true;
|
||||||
bool qk_norm = false;
|
bool qk_norm = false;
|
||||||
int64_t vocab_size = 152064;
|
bool rms_norm_add = false;
|
||||||
float rms_norm_eps = 1e-06f;
|
bool normalize_input = false;
|
||||||
|
int64_t vocab_size = 152064;
|
||||||
|
int64_t max_position_embeddings = 128000;
|
||||||
|
float rms_norm_eps = 1e-06f;
|
||||||
|
MLPActivation mlp_activation = MLPActivation::SILU;
|
||||||
|
std::vector<float> rope_thetas = {1000000.f};
|
||||||
|
std::vector<float> rope_scales = {1.f};
|
||||||
|
std::vector<int> sliding_attention;
|
||||||
LLMVisionParams vision;
|
LLMVisionParams vision;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct MLP : public GGMLBlock {
|
struct LLMRMSNorm : public UnaryBlock {
|
||||||
|
protected:
|
||||||
|
int64_t hidden_size;
|
||||||
|
float eps;
|
||||||
|
bool add_unit_offset;
|
||||||
|
std::string prefix;
|
||||||
|
|
||||||
|
void init_params(ggml_context* ctx,
|
||||||
|
const String2TensorStorage& tensor_storage_map = {},
|
||||||
|
std::string prefix = "") override {
|
||||||
|
this->prefix = prefix;
|
||||||
|
params["weight"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size);
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
MLP(int64_t hidden_size, int64_t intermediate_size, bool bias = false) {
|
LLMRMSNorm(int64_t hidden_size,
|
||||||
|
float eps = 1e-06f,
|
||||||
|
bool add_unit_offset = false)
|
||||||
|
: hidden_size(hidden_size), eps(eps), add_unit_offset(add_unit_offset) {}
|
||||||
|
|
||||||
|
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
|
||||||
|
ggml_tensor* w = params["weight"];
|
||||||
|
if (ctx->weight_adapter) {
|
||||||
|
w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, ctx->backend, w, prefix + "weight");
|
||||||
|
}
|
||||||
|
x = ggml_rms_norm(ctx->ggml_ctx, x, eps);
|
||||||
|
auto scaled = ggml_mul(ctx->ggml_ctx, x, w);
|
||||||
|
if (add_unit_offset) {
|
||||||
|
scaled = ggml_add_inplace(ctx->ggml_ctx, scaled, x);
|
||||||
|
}
|
||||||
|
return scaled;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct MLP : public GGMLBlock {
|
||||||
|
protected:
|
||||||
|
MLPActivation activation;
|
||||||
|
|
||||||
|
public:
|
||||||
|
MLP(int64_t hidden_size,
|
||||||
|
int64_t intermediate_size,
|
||||||
|
bool bias = false,
|
||||||
|
MLPActivation activation_ = MLPActivation::SILU)
|
||||||
|
: activation(activation_) {
|
||||||
blocks["gate_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, intermediate_size, bias));
|
blocks["gate_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, intermediate_size, bias));
|
||||||
blocks["up_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, intermediate_size, bias));
|
blocks["up_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, intermediate_size, bias));
|
||||||
blocks["down_proj"] = std::shared_ptr<GGMLBlock>(new Linear(intermediate_size, hidden_size, bias));
|
blocks["down_proj"] = std::shared_ptr<GGMLBlock>(new Linear(intermediate_size, hidden_size, bias));
|
||||||
@ -95,9 +152,13 @@ namespace LLM {
|
|||||||
auto down_proj = std::dynamic_pointer_cast<Linear>(blocks["down_proj"]);
|
auto down_proj = std::dynamic_pointer_cast<Linear>(blocks["down_proj"]);
|
||||||
|
|
||||||
auto h = gate_proj->forward(ctx, x);
|
auto h = gate_proj->forward(ctx, x);
|
||||||
h = ggml_silu_inplace(ctx->ggml_ctx, h);
|
if (activation == MLPActivation::GELU_TANH) {
|
||||||
h = ggml_mul_inplace(ctx->ggml_ctx, h, up_proj->forward(ctx, x));
|
h = ggml_ext_gelu(ctx->ggml_ctx, h, true);
|
||||||
h = down_proj->forward(ctx, h);
|
} else {
|
||||||
|
h = ggml_silu_inplace(ctx->ggml_ctx, h);
|
||||||
|
}
|
||||||
|
h = ggml_mul_inplace(ctx->ggml_ctx, h, up_proj->forward(ctx, x));
|
||||||
|
h = down_proj->forward(ctx, h);
|
||||||
return h;
|
return h;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -537,24 +598,35 @@ namespace LLM {
|
|||||||
int64_t num_heads;
|
int64_t num_heads;
|
||||||
int64_t num_kv_heads;
|
int64_t num_kv_heads;
|
||||||
bool qk_norm;
|
bool qk_norm;
|
||||||
|
int64_t max_position_embeddings;
|
||||||
|
std::vector<float> rope_thetas;
|
||||||
|
std::vector<float> rope_scales;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Attention(const LLMParams& params)
|
Attention(const LLMParams& params)
|
||||||
: arch(params.arch), num_heads(params.num_heads), num_kv_heads(params.num_kv_heads), head_dim(params.head_dim), qk_norm(params.qk_norm) {
|
: arch(params.arch),
|
||||||
|
num_heads(params.num_heads),
|
||||||
|
num_kv_heads(params.num_kv_heads),
|
||||||
|
head_dim(params.head_dim),
|
||||||
|
qk_norm(params.qk_norm),
|
||||||
|
max_position_embeddings(params.max_position_embeddings),
|
||||||
|
rope_thetas(params.rope_thetas),
|
||||||
|
rope_scales(params.rope_scales) {
|
||||||
blocks["q_proj"] = std::make_shared<Linear>(params.hidden_size, num_heads * head_dim, params.qkv_bias);
|
blocks["q_proj"] = std::make_shared<Linear>(params.hidden_size, num_heads * head_dim, params.qkv_bias);
|
||||||
blocks["k_proj"] = std::make_shared<Linear>(params.hidden_size, num_kv_heads * head_dim, params.qkv_bias);
|
blocks["k_proj"] = std::make_shared<Linear>(params.hidden_size, num_kv_heads * head_dim, params.qkv_bias);
|
||||||
blocks["v_proj"] = std::make_shared<Linear>(params.hidden_size, num_kv_heads * head_dim, params.qkv_bias);
|
blocks["v_proj"] = std::make_shared<Linear>(params.hidden_size, num_kv_heads * head_dim, params.qkv_bias);
|
||||||
blocks["o_proj"] = std::make_shared<Linear>(num_heads * head_dim, params.hidden_size, false);
|
blocks["o_proj"] = std::make_shared<Linear>(num_heads * head_dim, params.hidden_size, false);
|
||||||
if (params.qk_norm) {
|
if (params.qk_norm) {
|
||||||
blocks["q_norm"] = std::make_shared<RMSNorm>(head_dim, params.rms_norm_eps);
|
blocks["q_norm"] = std::make_shared<LLMRMSNorm>(head_dim, params.rms_norm_eps, params.rms_norm_add);
|
||||||
blocks["k_norm"] = std::make_shared<RMSNorm>(head_dim, params.rms_norm_eps);
|
blocks["k_norm"] = std::make_shared<LLMRMSNorm>(head_dim, params.rms_norm_eps, params.rms_norm_add);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* forward(GGMLRunnerContext* ctx,
|
ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_tensor* x,
|
ggml_tensor* x,
|
||||||
ggml_tensor* input_pos,
|
ggml_tensor* input_pos,
|
||||||
ggml_tensor* attention_mask = nullptr) {
|
ggml_tensor* attention_mask = nullptr,
|
||||||
|
int rope_index = 0) {
|
||||||
// x: [N, n_token, hidden_size]
|
// x: [N, n_token, hidden_size]
|
||||||
int64_t n_token = x->ne[1];
|
int64_t n_token = x->ne[1];
|
||||||
int64_t N = x->ne[2];
|
int64_t N = x->ne[2];
|
||||||
@ -572,8 +644,8 @@ namespace LLM {
|
|||||||
v = ggml_reshape_4d(ctx->ggml_ctx, v, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim]
|
v = ggml_reshape_4d(ctx->ggml_ctx, v, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim]
|
||||||
|
|
||||||
if (qk_norm) {
|
if (qk_norm) {
|
||||||
auto q_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["q_norm"]);
|
auto q_norm = std::dynamic_pointer_cast<LLMRMSNorm>(blocks["q_norm"]);
|
||||||
auto k_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["k_norm"]);
|
auto k_norm = std::dynamic_pointer_cast<LLMRMSNorm>(blocks["k_norm"]);
|
||||||
|
|
||||||
q = q_norm->forward(ctx, q);
|
q = q_norm->forward(ctx, q);
|
||||||
k = k_norm->forward(ctx, k);
|
k = k_norm->forward(ctx, k);
|
||||||
@ -588,6 +660,36 @@ namespace LLM {
|
|||||||
} else if (arch == LLMArch::QWEN3) {
|
} else if (arch == LLMArch::QWEN3) {
|
||||||
q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
||||||
k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
||||||
|
} else if (arch == LLMArch::GEMMA3_12B) {
|
||||||
|
float rope_theta = (rope_index == 1 ? 10000.0f : 1000000.0f);
|
||||||
|
float rope_scale = (rope_index == 1 ? 1.f : 8.f);
|
||||||
|
float freq_scale = 1.f / rope_scale;
|
||||||
|
q = ggml_rope_ext(ctx->ggml_ctx,
|
||||||
|
q,
|
||||||
|
input_pos,
|
||||||
|
nullptr,
|
||||||
|
head_dim,
|
||||||
|
GGML_ROPE_TYPE_NORMAL,
|
||||||
|
0,
|
||||||
|
rope_theta,
|
||||||
|
freq_scale,
|
||||||
|
0.f,
|
||||||
|
1.f,
|
||||||
|
32.f,
|
||||||
|
1.f);
|
||||||
|
k = ggml_rope_ext(ctx->ggml_ctx,
|
||||||
|
k,
|
||||||
|
input_pos,
|
||||||
|
nullptr,
|
||||||
|
head_dim,
|
||||||
|
GGML_ROPE_TYPE_NORMAL,
|
||||||
|
0,
|
||||||
|
rope_theta,
|
||||||
|
freq_scale,
|
||||||
|
0.f,
|
||||||
|
1.f,
|
||||||
|
32.f,
|
||||||
|
1.f);
|
||||||
} else if (arch == LLMArch::QWEN3_VL) {
|
} else if (arch == LLMArch::QWEN3_VL) {
|
||||||
int sections[4] = {24, 20, 20, 0};
|
int sections[4] = {24, 20, 20, 0};
|
||||||
q = ggml_rope_multi(ctx->ggml_ctx, q, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_IMROPE, 262144, 5000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
q = ggml_rope_multi(ctx->ggml_ctx, q, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_IMROPE, 262144, 5000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
||||||
@ -612,33 +714,76 @@ namespace LLM {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct TransformerBlock : public GGMLBlock {
|
struct TransformerBlock : public GGMLBlock {
|
||||||
|
protected:
|
||||||
|
LLMArch arch;
|
||||||
|
int sliding_attention;
|
||||||
|
bool has_post_attention_norm;
|
||||||
|
bool has_post_ffw_norm;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
TransformerBlock(const LLMParams& params) {
|
TransformerBlock(const LLMParams& params, int layer_index)
|
||||||
|
: arch(params.arch),
|
||||||
|
sliding_attention(0),
|
||||||
|
has_post_attention_norm(params.arch == LLMArch::GEMMA3_12B),
|
||||||
|
has_post_ffw_norm(params.arch == LLMArch::GEMMA3_12B) {
|
||||||
blocks["self_attn"] = std::make_shared<Attention>(params);
|
blocks["self_attn"] = std::make_shared<Attention>(params);
|
||||||
blocks["mlp"] = std::make_shared<MLP>(params.hidden_size, params.intermediate_size);
|
blocks["mlp"] = std::make_shared<MLP>(params.hidden_size,
|
||||||
blocks["input_layernorm"] = std::make_shared<RMSNorm>(params.hidden_size, params.rms_norm_eps);
|
params.intermediate_size,
|
||||||
blocks["post_attention_layernorm"] = std::make_shared<RMSNorm>(params.hidden_size, params.rms_norm_eps);
|
false,
|
||||||
|
params.mlp_activation);
|
||||||
|
blocks["input_layernorm"] = std::make_shared<LLMRMSNorm>(params.hidden_size, params.rms_norm_eps, params.rms_norm_add);
|
||||||
|
blocks["post_attention_layernorm"] = std::make_shared<LLMRMSNorm>(params.hidden_size, params.rms_norm_eps, params.rms_norm_add);
|
||||||
|
if (has_post_attention_norm) {
|
||||||
|
blocks["post_attention_norm"] = std::make_shared<LLMRMSNorm>(params.hidden_size, params.rms_norm_eps, params.rms_norm_add);
|
||||||
|
}
|
||||||
|
if (has_post_ffw_norm) {
|
||||||
|
blocks["post_ffw_norm"] = std::make_shared<LLMRMSNorm>(params.hidden_size, params.rms_norm_eps, params.rms_norm_add);
|
||||||
|
}
|
||||||
|
if (!params.sliding_attention.empty()) {
|
||||||
|
sliding_attention = params.sliding_attention[layer_index % params.sliding_attention.size()];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* forward(GGMLRunnerContext* ctx,
|
ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_tensor* x,
|
ggml_tensor* x,
|
||||||
ggml_tensor* input_pos,
|
ggml_tensor* input_pos,
|
||||||
ggml_tensor* attention_mask = nullptr) {
|
ggml_tensor* attention_mask = nullptr,
|
||||||
|
ggml_tensor* sliding_attention_mask = nullptr) {
|
||||||
// x: [N, n_token, hidden_size]
|
// x: [N, n_token, hidden_size]
|
||||||
auto self_attn = std::dynamic_pointer_cast<Attention>(blocks["self_attn"]);
|
auto self_attn = std::dynamic_pointer_cast<Attention>(blocks["self_attn"]);
|
||||||
auto mlp = std::dynamic_pointer_cast<MLP>(blocks["mlp"]);
|
auto mlp = std::dynamic_pointer_cast<MLP>(blocks["mlp"]);
|
||||||
auto input_layernorm = std::dynamic_pointer_cast<RMSNorm>(blocks["input_layernorm"]);
|
auto input_layernorm = std::dynamic_pointer_cast<LLMRMSNorm>(blocks["input_layernorm"]);
|
||||||
auto post_attention_layernorm = std::dynamic_pointer_cast<RMSNorm>(blocks["post_attention_layernorm"]);
|
auto post_attention_layernorm = std::dynamic_pointer_cast<LLMRMSNorm>(blocks["post_attention_layernorm"]);
|
||||||
|
std::shared_ptr<LLMRMSNorm> post_attention_norm = nullptr;
|
||||||
|
std::shared_ptr<LLMRMSNorm> post_ffw_norm = nullptr;
|
||||||
|
if (has_post_attention_norm) {
|
||||||
|
post_attention_norm = std::dynamic_pointer_cast<LLMRMSNorm>(blocks["post_attention_norm"]);
|
||||||
|
}
|
||||||
|
if (has_post_ffw_norm) {
|
||||||
|
post_ffw_norm = std::dynamic_pointer_cast<LLMRMSNorm>(blocks["post_ffw_norm"]);
|
||||||
|
}
|
||||||
|
ggml_tensor* block_attention_mask = attention_mask;
|
||||||
|
int rope_index = 0;
|
||||||
|
if (arch == LLMArch::GEMMA3_12B && sliding_attention > 0) {
|
||||||
|
block_attention_mask = sliding_attention_mask;
|
||||||
|
rope_index = 1;
|
||||||
|
}
|
||||||
|
|
||||||
auto residual = x;
|
auto residual = x;
|
||||||
x = input_layernorm->forward(ctx, x);
|
x = input_layernorm->forward(ctx, x);
|
||||||
x = self_attn->forward(ctx, x, input_pos, attention_mask);
|
x = self_attn->forward(ctx, x, input_pos, block_attention_mask, rope_index);
|
||||||
x = ggml_add_inplace(ctx->ggml_ctx, x, residual);
|
if (post_attention_norm != nullptr) {
|
||||||
|
x = post_attention_norm->forward(ctx, x);
|
||||||
|
}
|
||||||
|
x = ggml_add_inplace(ctx->ggml_ctx, x, residual);
|
||||||
|
|
||||||
residual = x;
|
residual = x;
|
||||||
x = post_attention_layernorm->forward(ctx, x);
|
x = post_attention_layernorm->forward(ctx, x);
|
||||||
x = mlp->forward(ctx, x);
|
x = mlp->forward(ctx, x);
|
||||||
x = ggml_add_inplace(ctx->ggml_ctx, x, residual);
|
if (post_ffw_norm != nullptr) {
|
||||||
|
x = post_ffw_norm->forward(ctx, x);
|
||||||
|
}
|
||||||
|
x = ggml_add_inplace(ctx->ggml_ctx, x, residual);
|
||||||
|
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
@ -654,9 +799,9 @@ namespace LLM {
|
|||||||
: num_layers(params.num_layers), params(params) {
|
: num_layers(params.num_layers), params(params) {
|
||||||
blocks["embed_tokens"] = std::shared_ptr<GGMLBlock>(new Embedding(params.vocab_size, params.hidden_size));
|
blocks["embed_tokens"] = std::shared_ptr<GGMLBlock>(new Embedding(params.vocab_size, params.hidden_size));
|
||||||
for (int i = 0; i < num_layers; i++) {
|
for (int i = 0; i < num_layers; i++) {
|
||||||
blocks["layers." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new TransformerBlock(params));
|
blocks["layers." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new TransformerBlock(params, i));
|
||||||
}
|
}
|
||||||
blocks["norm"] = std::shared_ptr<GGMLBlock>(new RMSNorm(params.hidden_size, params.rms_norm_eps));
|
blocks["norm"] = std::shared_ptr<GGMLBlock>(new LLMRMSNorm(params.hidden_size, params.rms_norm_eps, params.rms_norm_add));
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* embed(GGMLRunnerContext* ctx,
|
ggml_tensor* embed(GGMLRunnerContext* ctx,
|
||||||
@ -670,46 +815,78 @@ namespace LLM {
|
|||||||
ggml_tensor* x,
|
ggml_tensor* x,
|
||||||
ggml_tensor* input_pos,
|
ggml_tensor* input_pos,
|
||||||
ggml_tensor* attention_mask,
|
ggml_tensor* attention_mask,
|
||||||
std::set<int> out_layers) {
|
std::set<int> out_layers,
|
||||||
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["norm"]);
|
ggml_tensor* sliding_attention_mask = nullptr,
|
||||||
|
bool return_all_hidden_states = false) {
|
||||||
|
auto norm = std::dynamic_pointer_cast<LLMRMSNorm>(blocks["norm"]);
|
||||||
std::vector<ggml_tensor*> intermediate_outputs;
|
std::vector<ggml_tensor*> intermediate_outputs;
|
||||||
|
|
||||||
|
if (params.normalize_input) {
|
||||||
|
x = ggml_ext_scale(ctx->ggml_ctx, x, std::sqrt(static_cast<float>(params.hidden_size)), true);
|
||||||
|
}
|
||||||
|
if (return_all_hidden_states) {
|
||||||
|
intermediate_outputs.push_back(x);
|
||||||
|
}
|
||||||
|
|
||||||
sd::ggml_graph_cut::mark_graph_cut(x, "llm.text.prelude", "x");
|
sd::ggml_graph_cut::mark_graph_cut(x, "llm.text.prelude", "x");
|
||||||
for (int i = 0; i < num_layers; i++) {
|
for (int i = 0; i < num_layers; i++) {
|
||||||
auto block = std::dynamic_pointer_cast<TransformerBlock>(blocks["layers." + std::to_string(i)]);
|
auto block = std::dynamic_pointer_cast<TransformerBlock>(blocks["layers." + std::to_string(i)]);
|
||||||
|
|
||||||
x = block->forward(ctx, x, input_pos, attention_mask);
|
x = block->forward(ctx, x, input_pos, attention_mask, sliding_attention_mask);
|
||||||
if (out_layers.size() > 1) {
|
if (return_all_hidden_states || out_layers.size() > 1) {
|
||||||
x = ggml_cont(ctx->ggml_ctx, x);
|
x = ggml_cont(ctx->ggml_ctx, x);
|
||||||
}
|
}
|
||||||
sd::ggml_graph_cut::mark_graph_cut(x, "llm.text.layers." + std::to_string(i), "x");
|
sd::ggml_graph_cut::mark_graph_cut(x, "llm.text.layers." + std::to_string(i), "x");
|
||||||
if (out_layers.find(i + 1) != out_layers.end()) {
|
if (return_all_hidden_states) {
|
||||||
|
if (i + 1 < num_layers) {
|
||||||
|
intermediate_outputs.push_back(x);
|
||||||
|
}
|
||||||
|
} else if (out_layers.find(i + 1) != out_layers.end()) {
|
||||||
intermediate_outputs.push_back(x);
|
intermediate_outputs.push_back(x);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!intermediate_outputs.empty()) {
|
auto normed_x = norm->forward(ctx, x);
|
||||||
|
if (return_all_hidden_states) {
|
||||||
|
intermediate_outputs.push_back(normed_x);
|
||||||
x = intermediate_outputs[0];
|
x = intermediate_outputs[0];
|
||||||
for (int i = 1; i < intermediate_outputs.size(); i++) {
|
for (int i = 1; i < intermediate_outputs.size(); i++) {
|
||||||
x = ggml_concat(ctx->ggml_ctx, x, intermediate_outputs[i], 0);
|
x = ggml_concat(ctx->ggml_ctx, x, intermediate_outputs[i], 0);
|
||||||
}
|
}
|
||||||
return x;
|
} else if (!intermediate_outputs.empty()) {
|
||||||
|
if (out_layers.find(static_cast<int>(num_layers + 1)) != out_layers.end()) {
|
||||||
|
intermediate_outputs.push_back(normed_x);
|
||||||
|
}
|
||||||
|
x = intermediate_outputs[0];
|
||||||
|
for (int i = 1; i < intermediate_outputs.size(); i++) {
|
||||||
|
x = ggml_concat(ctx->ggml_ctx, x, intermediate_outputs[i], 0);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
x = normed_x;
|
||||||
}
|
}
|
||||||
|
|
||||||
return norm->forward(ctx, x);
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* forward(GGMLRunnerContext* ctx,
|
ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_tensor* input_ids,
|
ggml_tensor* input_ids,
|
||||||
ggml_tensor* input_pos,
|
ggml_tensor* input_pos,
|
||||||
ggml_tensor* attention_mask,
|
ggml_tensor* attention_mask,
|
||||||
|
ggml_tensor* sliding_attention_mask,
|
||||||
std::vector<std::pair<int, ggml_tensor*>> image_embeds,
|
std::vector<std::pair<int, ggml_tensor*>> image_embeds,
|
||||||
std::set<int> out_layers) {
|
std::set<int> out_layers,
|
||||||
|
bool return_all_hidden_states = false) {
|
||||||
// input_ids: [N, n_token]
|
// input_ids: [N, n_token]
|
||||||
// return: [N, n_token, hidden_size]
|
// return: [N, n_token, hidden_size]
|
||||||
auto x = embed(ctx, input_ids);
|
auto x = embed(ctx, input_ids);
|
||||||
x = splice_image_embeds(ctx, x, image_embeds);
|
x = splice_image_embeds(ctx, x, image_embeds);
|
||||||
return forward_embeds(ctx, x, input_pos, attention_mask, std::move(out_layers));
|
return forward_embeds(ctx,
|
||||||
|
x,
|
||||||
|
input_pos,
|
||||||
|
attention_mask,
|
||||||
|
std::move(out_layers),
|
||||||
|
sliding_attention_mask,
|
||||||
|
return_all_hidden_states);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -731,12 +908,21 @@ namespace LLM {
|
|||||||
ggml_tensor* input_ids,
|
ggml_tensor* input_ids,
|
||||||
ggml_tensor* input_pos,
|
ggml_tensor* input_pos,
|
||||||
ggml_tensor* attention_mask,
|
ggml_tensor* attention_mask,
|
||||||
|
ggml_tensor* sliding_attention_mask,
|
||||||
std::vector<std::pair<int, ggml_tensor*>> image_embeds,
|
std::vector<std::pair<int, ggml_tensor*>> image_embeds,
|
||||||
std::set<int> out_layers) {
|
std::set<int> out_layers,
|
||||||
|
bool return_all_hidden_states = false) {
|
||||||
// input_ids: [N, n_token]
|
// input_ids: [N, n_token]
|
||||||
auto model = std::dynamic_pointer_cast<TextModel>(blocks["model"]);
|
auto model = std::dynamic_pointer_cast<TextModel>(blocks["model"]);
|
||||||
|
|
||||||
auto x = model->forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers);
|
auto x = model->forward(ctx,
|
||||||
|
input_ids,
|
||||||
|
input_pos,
|
||||||
|
attention_mask,
|
||||||
|
sliding_attention_mask,
|
||||||
|
image_embeds,
|
||||||
|
out_layers,
|
||||||
|
return_all_hidden_states);
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -764,6 +950,7 @@ namespace LLM {
|
|||||||
|
|
||||||
std::vector<int> input_pos_vec;
|
std::vector<int> input_pos_vec;
|
||||||
std::vector<float> attention_mask_vec;
|
std::vector<float> attention_mask_vec;
|
||||||
|
std::vector<float> sliding_attention_mask_vec;
|
||||||
std::vector<float> window_mask_vec;
|
std::vector<float> window_mask_vec;
|
||||||
std::vector<int> window_index_vec;
|
std::vector<int> window_index_vec;
|
||||||
std::vector<int> window_inverse_index_vec;
|
std::vector<int> window_inverse_index_vec;
|
||||||
@ -998,6 +1185,23 @@ namespace LLM {
|
|||||||
params.qkv_bias = false;
|
params.qkv_bias = false;
|
||||||
params.qk_norm = true;
|
params.qk_norm = true;
|
||||||
params.rms_norm_eps = 1e-6f;
|
params.rms_norm_eps = 1e-6f;
|
||||||
|
} else if (arch == LLMArch::GEMMA3_12B) {
|
||||||
|
params.head_dim = 256;
|
||||||
|
params.num_heads = 16;
|
||||||
|
params.num_kv_heads = 8;
|
||||||
|
params.qkv_bias = false;
|
||||||
|
params.qk_norm = true;
|
||||||
|
params.rms_norm_eps = 1e-6f;
|
||||||
|
// llama.cpp adds +1 to Gemma3 norm.weight when exporting GGUF, so GGUF loading
|
||||||
|
// must keep rms_norm_add disabled here or the offset gets applied twice.
|
||||||
|
// Convenient for the converter, less convenient for whoever gets to debug it later.
|
||||||
|
params.rms_norm_add = false;
|
||||||
|
params.normalize_input = true;
|
||||||
|
params.max_position_embeddings = 131072;
|
||||||
|
params.mlp_activation = MLPActivation::GELU_TANH;
|
||||||
|
params.rope_thetas = {1000000.f, 10000.f};
|
||||||
|
params.rope_scales = {8.f, 1.f};
|
||||||
|
params.sliding_attention = {1024, 1024, 1024, 1024, 1024, 0};
|
||||||
}
|
}
|
||||||
bool have_vision_weight = false;
|
bool have_vision_weight = false;
|
||||||
bool llama_cpp_style = false;
|
bool llama_cpp_style = false;
|
||||||
@ -1067,9 +1271,18 @@ namespace LLM {
|
|||||||
ggml_tensor* input_ids,
|
ggml_tensor* input_ids,
|
||||||
ggml_tensor* input_pos,
|
ggml_tensor* input_pos,
|
||||||
ggml_tensor* attention_mask,
|
ggml_tensor* attention_mask,
|
||||||
|
ggml_tensor* sliding_attention_mask,
|
||||||
std::vector<std::pair<int, ggml_tensor*>> image_embeds,
|
std::vector<std::pair<int, ggml_tensor*>> image_embeds,
|
||||||
std::set<int> out_layers) {
|
std::set<int> out_layers,
|
||||||
auto hidden_states = model.forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers); // [N, n_token, hidden_size]
|
bool return_all_hidden_states = false) {
|
||||||
|
auto hidden_states = model.forward(ctx,
|
||||||
|
input_ids,
|
||||||
|
input_pos,
|
||||||
|
attention_mask,
|
||||||
|
sliding_attention_mask,
|
||||||
|
image_embeds,
|
||||||
|
out_layers,
|
||||||
|
return_all_hidden_states); // [N, n_token, hidden_size]
|
||||||
return hidden_states;
|
return hidden_states;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1087,8 +1300,9 @@ namespace LLM {
|
|||||||
ggml_cgraph* build_graph(const sd::Tensor<int32_t>& input_ids_tensor,
|
ggml_cgraph* build_graph(const sd::Tensor<int32_t>& input_ids_tensor,
|
||||||
const sd::Tensor<float>& attention_mask_tensor,
|
const sd::Tensor<float>& attention_mask_tensor,
|
||||||
const std::vector<std::pair<int, sd::Tensor<float>>>& image_embeds_tensor,
|
const std::vector<std::pair<int, sd::Tensor<float>>>& image_embeds_tensor,
|
||||||
std::set<int> out_layers) {
|
std::set<int> out_layers,
|
||||||
ggml_cgraph* gf = ggml_new_graph(compute_ctx);
|
bool return_all_hidden_states = false) {
|
||||||
|
ggml_cgraph* gf = new_graph_custom(LLM_GRAPH_SIZE);
|
||||||
ggml_tensor* input_ids = make_input(input_ids_tensor);
|
ggml_tensor* input_ids = make_input(input_ids_tensor);
|
||||||
std::vector<std::pair<int, ggml_tensor*>> image_embeds;
|
std::vector<std::pair<int, ggml_tensor*>> image_embeds;
|
||||||
image_embeds.reserve(image_embeds_tensor.size());
|
image_embeds.reserve(image_embeds_tensor.size());
|
||||||
@ -1098,7 +1312,10 @@ namespace LLM {
|
|||||||
}
|
}
|
||||||
|
|
||||||
int64_t n_tokens = input_ids->ne[0];
|
int64_t n_tokens = input_ids->ne[0];
|
||||||
if (params.arch == LLMArch::MISTRAL_SMALL_3_2 || params.arch == LLMArch::MINISTRAL_3_3B || params.arch == LLMArch::QWEN3) {
|
if (params.arch == LLMArch::MISTRAL_SMALL_3_2 ||
|
||||||
|
params.arch == LLMArch::MINISTRAL_3_3B ||
|
||||||
|
params.arch == LLMArch::QWEN3 ||
|
||||||
|
params.arch == LLMArch::GEMMA3_12B) {
|
||||||
input_pos_vec.resize(n_tokens);
|
input_pos_vec.resize(n_tokens);
|
||||||
for (int i = 0; i < n_tokens; ++i) {
|
for (int i = 0; i < n_tokens; ++i) {
|
||||||
input_pos_vec[i] = i;
|
input_pos_vec[i] = i;
|
||||||
@ -1118,7 +1335,8 @@ namespace LLM {
|
|||||||
input_pos_vec.size());
|
input_pos_vec.size());
|
||||||
set_backend_tensor_data(input_pos, input_pos_vec.data());
|
set_backend_tensor_data(input_pos, input_pos_vec.data());
|
||||||
|
|
||||||
ggml_tensor* attention_mask = nullptr;
|
ggml_tensor* attention_mask = nullptr;
|
||||||
|
ggml_tensor* sliding_attention_mask = nullptr;
|
||||||
if (!attention_mask_tensor.empty()) {
|
if (!attention_mask_tensor.empty()) {
|
||||||
attention_mask = make_input(attention_mask_tensor);
|
attention_mask = make_input(attention_mask_tensor);
|
||||||
} else {
|
} else {
|
||||||
@ -1136,9 +1354,36 @@ namespace LLM {
|
|||||||
set_backend_tensor_data(attention_mask, attention_mask_vec.data());
|
set_backend_tensor_data(attention_mask, attention_mask_vec.data());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (params.arch == LLMArch::GEMMA3_12B) {
|
||||||
|
sliding_attention_mask_vec.resize(n_tokens * n_tokens);
|
||||||
|
if (!attention_mask_tensor.empty()) {
|
||||||
|
GGML_ASSERT(attention_mask_tensor.numel() == n_tokens * n_tokens);
|
||||||
|
sliding_attention_mask_vec = attention_mask_tensor.values();
|
||||||
|
} else {
|
||||||
|
sliding_attention_mask_vec = attention_mask_vec;
|
||||||
|
}
|
||||||
|
for (int i0 = 0; i0 < n_tokens; i0++) {
|
||||||
|
for (int i1 = 0; i1 < n_tokens; i1++) {
|
||||||
|
if (i0 + 1024 <= i1) {
|
||||||
|
LOG_DEBUG("xxxxxxxxxxxxxx");
|
||||||
|
sliding_attention_mask_vec[i1 * n_tokens + i0] = -INFINITY;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sliding_attention_mask = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, n_tokens, n_tokens);
|
||||||
|
set_backend_tensor_data(sliding_attention_mask, sliding_attention_mask_vec.data());
|
||||||
|
}
|
||||||
|
|
||||||
auto runner_ctx = get_context();
|
auto runner_ctx = get_context();
|
||||||
|
|
||||||
ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers);
|
ggml_tensor* hidden_states = forward(&runner_ctx,
|
||||||
|
input_ids,
|
||||||
|
input_pos,
|
||||||
|
attention_mask,
|
||||||
|
sliding_attention_mask,
|
||||||
|
image_embeds,
|
||||||
|
out_layers,
|
||||||
|
return_all_hidden_states);
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, hidden_states);
|
ggml_build_forward_expand(gf, hidden_states);
|
||||||
|
|
||||||
@ -1149,9 +1394,14 @@ namespace LLM {
|
|||||||
const sd::Tensor<int32_t>& input_ids,
|
const sd::Tensor<int32_t>& input_ids,
|
||||||
const sd::Tensor<float>& attention_mask,
|
const sd::Tensor<float>& attention_mask,
|
||||||
const std::vector<std::pair<int, sd::Tensor<float>>>& image_embeds,
|
const std::vector<std::pair<int, sd::Tensor<float>>>& image_embeds,
|
||||||
std::set<int> out_layers) {
|
std::set<int> out_layers,
|
||||||
|
bool return_all_hidden_states = false) {
|
||||||
auto get_graph = [&]() -> ggml_cgraph* {
|
auto get_graph = [&]() -> ggml_cgraph* {
|
||||||
return build_graph(input_ids, attention_mask, image_embeds, out_layers);
|
return build_graph(input_ids,
|
||||||
|
attention_mask,
|
||||||
|
image_embeds,
|
||||||
|
out_layers,
|
||||||
|
return_all_hidden_states);
|
||||||
};
|
};
|
||||||
return take_or_empty(GGMLRunner::compute<float>(get_graph, n_threads, true));
|
return take_or_empty(GGMLRunner::compute<float>(get_graph, n_threads, true));
|
||||||
}
|
}
|
||||||
|
|||||||
1109
src/ltx_audio_vae.h
Normal file
1109
src/ltx_audio_vae.h
Normal file
File diff suppressed because it is too large
Load Diff
1299
src/ltx_vae.hpp
Normal file
1299
src/ltx_vae.hpp
Normal file
File diff suppressed because it is too large
Load Diff
1955
src/ltxv.hpp
1955
src/ltxv.hpp
File diff suppressed because it is too large
Load Diff
@ -462,6 +462,9 @@ SDVersion ModelLoader::get_sd_version() {
|
|||||||
if (tensor_storage.name.find("model.diffusion_model.layers.0.adaLN_sa_ln.weight") != std::string::npos) {
|
if (tensor_storage.name.find("model.diffusion_model.layers.0.adaLN_sa_ln.weight") != std::string::npos) {
|
||||||
return VERSION_ERNIE_IMAGE;
|
return VERSION_ERNIE_IMAGE;
|
||||||
}
|
}
|
||||||
|
if (tensor_storage.name.find("model.diffusion_model.adaln_single.emb.timestep_embedder.linear_1.bias") != std::string::npos) {
|
||||||
|
return VERSION_LTXAV;
|
||||||
|
}
|
||||||
if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) {
|
if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) {
|
||||||
is_wan = true;
|
is_wan = true;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -42,6 +42,7 @@ enum SDVersion {
|
|||||||
VERSION_ANIMA,
|
VERSION_ANIMA,
|
||||||
VERSION_FLUX2,
|
VERSION_FLUX2,
|
||||||
VERSION_FLUX2_KLEIN,
|
VERSION_FLUX2_KLEIN,
|
||||||
|
VERSION_LTXAV,
|
||||||
VERSION_HIDREAM_O1,
|
VERSION_HIDREAM_O1,
|
||||||
VERSION_Z_IMAGE,
|
VERSION_Z_IMAGE,
|
||||||
VERSION_OVIS_IMAGE,
|
VERSION_OVIS_IMAGE,
|
||||||
@ -105,6 +106,13 @@ static inline bool sd_version_is_flux2(SDVersion version) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static inline bool sd_version_is_ltxav(SDVersion version) {
|
||||||
|
if (version == VERSION_LTXAV) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
static inline bool sd_version_is_wan(SDVersion version) {
|
static inline bool sd_version_is_wan(SDVersion version) {
|
||||||
if (version == VERSION_WAN2 || version == VERSION_WAN2_2_I2V || version == VERSION_WAN2_2_TI2V) {
|
if (version == VERSION_WAN2 || version == VERSION_WAN2_2_I2V || version == VERSION_WAN2_2_TI2V) {
|
||||||
return true;
|
return true;
|
||||||
@ -161,6 +169,7 @@ static inline bool sd_version_is_inpaint(SDVersion version) {
|
|||||||
static inline bool sd_version_is_dit(SDVersion version) {
|
static inline bool sd_version_is_dit(SDVersion version) {
|
||||||
if (sd_version_is_flux(version) ||
|
if (sd_version_is_flux(version) ||
|
||||||
sd_version_is_flux2(version) ||
|
sd_version_is_flux2(version) ||
|
||||||
|
sd_version_is_ltxav(version) ||
|
||||||
sd_version_is_sd3(version) ||
|
sd_version_is_sd3(version) ||
|
||||||
sd_version_is_wan(version) ||
|
sd_version_is_wan(version) ||
|
||||||
sd_version_is_qwen_image(version) ||
|
sd_version_is_qwen_image(version) ||
|
||||||
|
|||||||
@ -15,6 +15,8 @@
|
|||||||
#include "diffusion_model.hpp"
|
#include "diffusion_model.hpp"
|
||||||
#include "esrgan.hpp"
|
#include "esrgan.hpp"
|
||||||
#include "lora.hpp"
|
#include "lora.hpp"
|
||||||
|
#include "ltx_audio_vae.h"
|
||||||
|
#include "ltx_vae.hpp"
|
||||||
#include "pmid.hpp"
|
#include "pmid.hpp"
|
||||||
#include "sample-cache.h"
|
#include "sample-cache.h"
|
||||||
#include "tae.hpp"
|
#include "tae.hpp"
|
||||||
@ -53,6 +55,7 @@ const char* model_version_to_str[] = {
|
|||||||
"Anima",
|
"Anima",
|
||||||
"Flux.2",
|
"Flux.2",
|
||||||
"Flux.2 klein",
|
"Flux.2 klein",
|
||||||
|
"LTXAV",
|
||||||
"HiDream O1",
|
"HiDream O1",
|
||||||
"Z-Image",
|
"Z-Image",
|
||||||
"Ovis Image",
|
"Ovis Image",
|
||||||
@ -134,6 +137,7 @@ public:
|
|||||||
std::shared_ptr<DiffusionModel> high_noise_diffusion_model;
|
std::shared_ptr<DiffusionModel> high_noise_diffusion_model;
|
||||||
std::shared_ptr<VAE> first_stage_model;
|
std::shared_ptr<VAE> first_stage_model;
|
||||||
std::shared_ptr<VAE> preview_vae;
|
std::shared_ptr<VAE> preview_vae;
|
||||||
|
std::shared_ptr<LTXV::LTXAudioVAERunner> audio_vae_model;
|
||||||
std::shared_ptr<ControlNet> control_net;
|
std::shared_ptr<ControlNet> control_net;
|
||||||
std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
|
std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
|
||||||
std::shared_ptr<LoraModel> pmid_lora;
|
std::shared_ptr<LoraModel> pmid_lora;
|
||||||
@ -144,7 +148,7 @@ public:
|
|||||||
bool apply_lora_immediately = false;
|
bool apply_lora_immediately = false;
|
||||||
|
|
||||||
std::string taesd_path;
|
std::string taesd_path;
|
||||||
sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0, 0};
|
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0, 0};
|
||||||
bool offload_params_to_cpu = false;
|
bool offload_params_to_cpu = false;
|
||||||
float max_vram = 0.f;
|
float max_vram = 0.f;
|
||||||
bool use_pmid = false;
|
bool use_pmid = false;
|
||||||
@ -222,7 +226,8 @@ public:
|
|||||||
backend_spec = SAFE_STR(sd_ctx_params->backend);
|
backend_spec = SAFE_STR(sd_ctx_params->backend);
|
||||||
params_backend_spec = SAFE_STR(sd_ctx_params->params_backend);
|
params_backend_spec = SAFE_STR(sd_ctx_params->params_backend);
|
||||||
|
|
||||||
bool use_tae = false;
|
bool use_tae = false;
|
||||||
|
bool use_audio_vae = false;
|
||||||
|
|
||||||
rng = get_rng(sd_ctx_params->rng_type);
|
rng = get_rng(sd_ctx_params->rng_type);
|
||||||
if (sd_ctx_params->sampler_rng_type != RNG_TYPE_COUNT && sd_ctx_params->sampler_rng_type != sd_ctx_params->rng_type) {
|
if (sd_ctx_params->sampler_rng_type != RNG_TYPE_COUNT && sd_ctx_params->sampler_rng_type != sd_ctx_params->rng_type) {
|
||||||
@ -324,6 +329,22 @@ public:
|
|||||||
use_tae = true;
|
use_tae = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (strlen(SAFE_STR(sd_ctx_params->embeddings_connectors_path)) > 0) {
|
||||||
|
LOG_INFO("loading embeddings connectors from '%s'", sd_ctx_params->embeddings_connectors_path);
|
||||||
|
if (!model_loader.init_from_file(sd_ctx_params->embeddings_connectors_path)) {
|
||||||
|
LOG_WARN("loading embeddings connectors from '%s' failed", sd_ctx_params->embeddings_connectors_path);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (strlen(SAFE_STR(sd_ctx_params->audio_vae_path)) > 0) {
|
||||||
|
LOG_INFO("loading LTX audio VAE from '%s'", sd_ctx_params->audio_vae_path);
|
||||||
|
if (!model_loader.init_from_file(sd_ctx_params->audio_vae_path)) {
|
||||||
|
LOG_WARN("loading LTX audio VAE weights from '%s' failed", sd_ctx_params->audio_vae_path);
|
||||||
|
} else {
|
||||||
|
use_audio_vae = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
model_loader.convert_tensors_name();
|
model_loader.convert_tensors_name();
|
||||||
|
|
||||||
version = model_loader.get_sd_version();
|
version = model_loader.get_sd_version();
|
||||||
@ -437,7 +458,6 @@ public:
|
|||||||
// Might need vae encode for control cond
|
// Might need vae encode for control cond
|
||||||
vae_decode_only = false;
|
vae_decode_only = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool tae_preview_only = sd_ctx_params->tae_preview_only;
|
bool tae_preview_only = sd_ctx_params->tae_preview_only;
|
||||||
if (version == VERSION_SDXS_512_DS || version == VERSION_SDXS_09) {
|
if (version == VERSION_SDXS_512_DS || version == VERSION_SDXS_09) {
|
||||||
tae_preview_only = false;
|
tae_preview_only = false;
|
||||||
@ -514,6 +534,14 @@ public:
|
|||||||
tensor_storage_map,
|
tensor_storage_map,
|
||||||
version,
|
version,
|
||||||
sd_ctx_params->chroma_use_dit_mask);
|
sd_ctx_params->chroma_use_dit_mask);
|
||||||
|
} else if (sd_version_is_ltxav(version)) {
|
||||||
|
cond_stage_model = std::make_shared<LTXAVEmbedder>(backend_for(SDBackendModule::TE),
|
||||||
|
params_backend_for(SDBackendModule::TE),
|
||||||
|
tensor_storage_map);
|
||||||
|
diffusion_model = std::make_shared<LTXAVModel>(backend_for(SDBackendModule::DIFFUSION),
|
||||||
|
params_backend_for(SDBackendModule::DIFFUSION),
|
||||||
|
tensor_storage_map,
|
||||||
|
"model.diffusion_model");
|
||||||
} else if (sd_version_is_wan(version)) {
|
} else if (sd_version_is_wan(version)) {
|
||||||
cond_stage_model = std::make_shared<T5CLIPEmbedder>(backend_for(SDBackendModule::TE),
|
cond_stage_model = std::make_shared<T5CLIPEmbedder>(backend_for(SDBackendModule::TE),
|
||||||
params_backend_for(SDBackendModule::TE),
|
params_backend_for(SDBackendModule::TE),
|
||||||
@ -668,9 +696,16 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
auto create_vae = [&]() -> std::shared_ptr<VAE> {
|
auto create_vae = [&]() -> std::shared_ptr<VAE> {
|
||||||
if (sd_version_is_wan(version) ||
|
if (sd_version_is_ltxav(version)) {
|
||||||
sd_version_is_qwen_image(version) ||
|
return std::make_shared<LTXVideoVAE>(backend_for(SDBackendModule::VAE),
|
||||||
sd_version_is_anima(version)) {
|
params_backend_for(SDBackendModule::VAE),
|
||||||
|
tensor_storage_map,
|
||||||
|
"first_stage_model",
|
||||||
|
vae_decode_only,
|
||||||
|
version);
|
||||||
|
} else if (sd_version_is_wan(version) ||
|
||||||
|
sd_version_is_qwen_image(version) ||
|
||||||
|
sd_version_is_anima(version)) {
|
||||||
return std::make_shared<WAN::WanVAERunner>(backend_for(SDBackendModule::VAE),
|
return std::make_shared<WAN::WanVAERunner>(backend_for(SDBackendModule::VAE),
|
||||||
params_backend_for(SDBackendModule::VAE),
|
params_backend_for(SDBackendModule::VAE),
|
||||||
tensor_storage_map,
|
tensor_storage_map,
|
||||||
@ -723,6 +758,13 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (use_audio_vae) {
|
||||||
|
audio_vae_model = std::make_shared<LTXV::LTXAudioVAERunner>(backend_for(SDBackendModule::VAE),
|
||||||
|
params_backend_for(SDBackendModule::VAE),
|
||||||
|
tensor_storage_map);
|
||||||
|
get_param_tensors_p(audio_vae_model, vae_mmap, "");
|
||||||
|
}
|
||||||
|
|
||||||
if (sd_ctx_params->vae_conv_direct) {
|
if (sd_ctx_params->vae_conv_direct) {
|
||||||
LOG_INFO("Using Conv2d direct in the vae model");
|
LOG_INFO("Using Conv2d direct in the vae model");
|
||||||
first_stage_model->set_conv2d_direct_enabled(true);
|
first_stage_model->set_conv2d_direct_enabled(true);
|
||||||
@ -856,6 +898,9 @@ public:
|
|||||||
ignore_tensors.insert("tae.encoder");
|
ignore_tensors.insert("tae.encoder");
|
||||||
ignore_tensors.insert("text_encoders.llm.visual.");
|
ignore_tensors.insert("text_encoders.llm.visual.");
|
||||||
}
|
}
|
||||||
|
if (audio_vae_model) {
|
||||||
|
ignore_tensors.insert("audio_vae.encoder");
|
||||||
|
}
|
||||||
if (version == VERSION_OVIS_IMAGE) {
|
if (version == VERSION_OVIS_IMAGE) {
|
||||||
ignore_tensors.insert("text_encoders.llm.vision_model.");
|
ignore_tensors.insert("text_encoders.llm.vision_model.");
|
||||||
ignore_tensors.insert("text_encoders.llm.visual_tokenizer.");
|
ignore_tensors.insert("text_encoders.llm.visual_tokenizer.");
|
||||||
@ -905,6 +950,11 @@ public:
|
|||||||
ggml_free(ctx);
|
ggml_free(ctx);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
if (audio_vae_model && !audio_vae_model->alloc_params_buffer()) {
|
||||||
|
LOG_ERROR("LTX audio VAE params buffer allocation failed");
|
||||||
|
ggml_free(ctx);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
if (use_pmid && pmid_model && !pmid_model->alloc_params_buffer()) {
|
if (use_pmid && pmid_model && !pmid_model->alloc_params_buffer()) {
|
||||||
LOG_ERROR("PhotoMaker params buffer allocation failed");
|
LOG_ERROR("PhotoMaker params buffer allocation failed");
|
||||||
ggml_free(ctx);
|
ggml_free(ctx);
|
||||||
@ -931,6 +981,9 @@ public:
|
|||||||
if (preview_vae) {
|
if (preview_vae) {
|
||||||
vae_params_mem_size += preview_vae->get_params_buffer_size();
|
vae_params_mem_size += preview_vae->get_params_buffer_size();
|
||||||
}
|
}
|
||||||
|
if (audio_vae_model) {
|
||||||
|
vae_params_mem_size += audio_vae_model->get_params_buffer_size();
|
||||||
|
}
|
||||||
size_t control_net_params_mem_size = 0;
|
size_t control_net_params_mem_size = 0;
|
||||||
if (control_net) {
|
if (control_net) {
|
||||||
if (!control_net->load_from_file(SAFE_STR(sd_ctx_params->control_net_path), n_threads)) {
|
if (!control_net->load_from_file(SAFE_STR(sd_ctx_params->control_net_path), n_threads)) {
|
||||||
@ -1023,6 +1076,7 @@ public:
|
|||||||
pred_type = EPS_PRED;
|
pred_type = EPS_PRED;
|
||||||
}
|
}
|
||||||
} else if (sd_version_is_sd3(version) ||
|
} else if (sd_version_is_sd3(version) ||
|
||||||
|
sd_version_is_ltxav(version) ||
|
||||||
sd_version_is_wan(version) ||
|
sd_version_is_wan(version) ||
|
||||||
sd_version_is_qwen_image(version) ||
|
sd_version_is_qwen_image(version) ||
|
||||||
version == VERSION_HIDREAM_O1 ||
|
version == VERSION_HIDREAM_O1 ||
|
||||||
@ -1030,7 +1084,9 @@ public:
|
|||||||
sd_version_is_ernie_image(version) ||
|
sd_version_is_ernie_image(version) ||
|
||||||
sd_version_is_z_image(version)) {
|
sd_version_is_z_image(version)) {
|
||||||
pred_type = FLOW_PRED;
|
pred_type = FLOW_PRED;
|
||||||
if (sd_version_is_wan(version)) {
|
if (sd_version_is_ltxav(version)) {
|
||||||
|
default_flow_shift = 2.37f;
|
||||||
|
} else if (sd_version_is_wan(version)) {
|
||||||
default_flow_shift = 5.f;
|
default_flow_shift = 5.f;
|
||||||
} else if (sd_version_is_ernie_image(version)) {
|
} else if (sd_version_is_ernie_image(version)) {
|
||||||
default_flow_shift = 4.f;
|
default_flow_shift = 4.f;
|
||||||
@ -1067,8 +1123,13 @@ public:
|
|||||||
denoiser = std::make_shared<EDMVDenoiser>();
|
denoiser = std::make_shared<EDMVDenoiser>();
|
||||||
break;
|
break;
|
||||||
case FLOW_PRED: {
|
case FLOW_PRED: {
|
||||||
LOG_INFO("running in FLOW mode");
|
if (sd_version_is_ltxav(version)) {
|
||||||
denoiser = std::make_shared<DiscreteFlowDenoiser>();
|
LOG_INFO("running in LTXAV FLOW mode");
|
||||||
|
denoiser = std::make_shared<FluxFlowDenoiser>();
|
||||||
|
} else {
|
||||||
|
LOG_INFO("running in FLOW mode");
|
||||||
|
denoiser = std::make_shared<DiscreteFlowDenoiser>();
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case FLUX_FLOW_PRED: {
|
case FLUX_FLOW_PRED: {
|
||||||
@ -1505,6 +1566,38 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<float> process_ltxav_video_timesteps(const std::vector<float>& timesteps,
|
||||||
|
const sd::Tensor<float>& init_latent,
|
||||||
|
const sd::Tensor<float>& denoise_mask) {
|
||||||
|
if (timesteps.empty() || denoise_mask.empty() || init_latent.dim() < 4 || denoise_mask.dim() < 4) {
|
||||||
|
return timesteps;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t width = init_latent.shape()[0];
|
||||||
|
int64_t height = init_latent.shape()[1];
|
||||||
|
int64_t frames = init_latent.shape()[2];
|
||||||
|
if (denoise_mask.shape()[0] != width ||
|
||||||
|
denoise_mask.shape()[1] != height ||
|
||||||
|
denoise_mask.shape()[2] != frames ||
|
||||||
|
denoise_mask.shape()[3] < 1) {
|
||||||
|
LOG_WARN("unexpected LTXAV denoise mask shape for timestep processing");
|
||||||
|
return timesteps;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> video_timesteps(static_cast<size_t>(width * height * frames));
|
||||||
|
size_t idx = 0;
|
||||||
|
for (int64_t t = 0; t < frames; ++t) {
|
||||||
|
for (int64_t h = 0; h < height; ++h) {
|
||||||
|
for (int64_t w = 0; w < width; ++w) {
|
||||||
|
float mask = denoise_mask.dim() == 5 ? denoise_mask.index(w, h, t, 0, 0)
|
||||||
|
: denoise_mask.index(w, h, t, 0);
|
||||||
|
video_timesteps[idx++] = mask * timesteps[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return video_timesteps;
|
||||||
|
}
|
||||||
|
|
||||||
void preview_image(int step,
|
void preview_image(int step,
|
||||||
const sd::Tensor<float>& latents,
|
const sd::Tensor<float>& latents,
|
||||||
enum SDVersion version,
|
enum SDVersion version,
|
||||||
@ -1586,9 +1679,11 @@ public:
|
|||||||
sd::Tensor<float> decoded;
|
sd::Tensor<float> decoded;
|
||||||
bool is_video = preview_latent_tensor_is_video(latents);
|
bool is_video = preview_latent_tensor_is_video(latents);
|
||||||
if (preview_vae) {
|
if (preview_vae) {
|
||||||
|
preview_vae->set_temporal_tiling_enabled(vae_tiling_params.temporal_tiling);
|
||||||
vae_latents = preview_vae->diffusion_to_vae_latents(latents);
|
vae_latents = preview_vae->diffusion_to_vae_latents(latents);
|
||||||
decoded = preview_vae->decode(n_threads, vae_latents, vae_tiling_params, is_video, circular_x, circular_y, true);
|
decoded = preview_vae->decode(n_threads, vae_latents, vae_tiling_params, is_video, circular_x, circular_y, true);
|
||||||
} else {
|
} else {
|
||||||
|
first_stage_model->set_temporal_tiling_enabled(vae_tiling_params.temporal_tiling);
|
||||||
vae_latents = first_stage_model->diffusion_to_vae_latents(latents);
|
vae_latents = first_stage_model->diffusion_to_vae_latents(latents);
|
||||||
decoded = first_stage_model->decode(n_threads, vae_latents, vae_tiling_params, is_video, circular_x, circular_y, true);
|
decoded = first_stage_model->decode(n_threads, vae_latents, vae_tiling_params, is_video, circular_x, circular_y, true);
|
||||||
}
|
}
|
||||||
@ -1730,6 +1825,8 @@ public:
|
|||||||
const sd::Tensor<float>& denoise_mask,
|
const sd::Tensor<float>& denoise_mask,
|
||||||
const sd::Tensor<float>& vace_context,
|
const sd::Tensor<float>& vace_context,
|
||||||
float vace_strength,
|
float vace_strength,
|
||||||
|
int audio_length,
|
||||||
|
float frame_rate,
|
||||||
const sd_cache_params_t* cache_params) {
|
const sd_cache_params_t* cache_params) {
|
||||||
std::vector<int> skip_layers(guidance.slg.layers, guidance.slg.layers + guidance.slg.layer_count);
|
std::vector<int> skip_layers(guidance.slg.layers, guidance.slg.layers + guidance.slg.layer_count);
|
||||||
float cfg_scale = guidance.txt_cfg;
|
float cfg_scale = guidance.txt_cfg;
|
||||||
@ -1778,14 +1875,24 @@ public:
|
|||||||
float c_out = scaling[1];
|
float c_out = scaling[1];
|
||||||
float c_in = scaling[2];
|
float c_in = scaling[2];
|
||||||
|
|
||||||
std::vector<float> timesteps_vec = prepare_sample_timesteps(sigma, shifted_timestep);
|
std::vector<float> base_timesteps_vec = prepare_sample_timesteps(sigma, shifted_timestep);
|
||||||
timesteps_vec = process_timesteps(timesteps_vec, init_latent, denoise_mask);
|
std::vector<float> timesteps_vec = base_timesteps_vec;
|
||||||
adjust_sample_step_scalings(shifted_timestep, timesteps_vec, c_in, &c_skip, &c_out);
|
sd::Tensor<float> audio_timesteps_tensor;
|
||||||
|
if (sd_version_is_ltxav(version) && !denoise_mask.empty()) {
|
||||||
|
timesteps_vec = process_ltxav_video_timesteps(base_timesteps_vec, init_latent, denoise_mask);
|
||||||
|
audio_timesteps_tensor = sd::Tensor<float>({static_cast<int64_t>(base_timesteps_vec.size())}, base_timesteps_vec);
|
||||||
|
} else {
|
||||||
|
timesteps_vec = process_timesteps(timesteps_vec, init_latent, denoise_mask);
|
||||||
|
}
|
||||||
|
const std::vector<float>& scaling_timesteps_vec = (sd_version_is_ltxav(version) && !denoise_mask.empty())
|
||||||
|
? base_timesteps_vec
|
||||||
|
: timesteps_vec;
|
||||||
|
adjust_sample_step_scalings(shifted_timestep, scaling_timesteps_vec, c_in, &c_skip, &c_out);
|
||||||
|
|
||||||
sd::Tensor<float> timesteps_tensor({static_cast<int64_t>(timesteps_vec.size())}, timesteps_vec);
|
sd::Tensor<float> timesteps_tensor({static_cast<int64_t>(timesteps_vec.size())}, timesteps_vec);
|
||||||
sd::Tensor<float> guidance_tensor({1}, std::vector<float>{guidance.distilled_guidance});
|
sd::Tensor<float> guidance_tensor({1}, std::vector<float>{guidance.distilled_guidance});
|
||||||
sd::Tensor<float> noised_input = x * c_in;
|
sd::Tensor<float> noised_input = x * c_in;
|
||||||
if (!denoise_mask.empty() && version == VERSION_WAN2_2_TI2V) {
|
if (!denoise_mask.empty() && (version == VERSION_WAN2_2_TI2V || sd_version_is_ltxav(version))) {
|
||||||
noised_input = noised_input * denoise_mask + init_latent * (1.0f - denoise_mask);
|
noised_input = noised_input * denoise_mask + init_latent * (1.0f - denoise_mask);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1816,6 +1923,7 @@ public:
|
|||||||
DiffusionParams diffusion_params;
|
DiffusionParams diffusion_params;
|
||||||
diffusion_params.x = &noised_input;
|
diffusion_params.x = &noised_input;
|
||||||
diffusion_params.timesteps = &timesteps_tensor;
|
diffusion_params.timesteps = &timesteps_tensor;
|
||||||
|
diffusion_params.audio_timesteps = audio_timesteps_tensor.empty() ? nullptr : &audio_timesteps_tensor;
|
||||||
diffusion_params.guidance = &guidance_tensor;
|
diffusion_params.guidance = &guidance_tensor;
|
||||||
diffusion_params.ref_latents = &ref_latents;
|
diffusion_params.ref_latents = &ref_latents;
|
||||||
diffusion_params.increase_ref_index = increase_ref_index;
|
diffusion_params.increase_ref_index = increase_ref_index;
|
||||||
@ -1823,6 +1931,8 @@ public:
|
|||||||
diffusion_params.control_strength = control_strength;
|
diffusion_params.control_strength = control_strength;
|
||||||
diffusion_params.vace_context = vace_context.empty() ? nullptr : &vace_context;
|
diffusion_params.vace_context = vace_context.empty() ? nullptr : &vace_context;
|
||||||
diffusion_params.vace_strength = vace_strength;
|
diffusion_params.vace_strength = vace_strength;
|
||||||
|
diffusion_params.audio_length = audio_length;
|
||||||
|
diffusion_params.frame_rate = frame_rate;
|
||||||
diffusion_params.skip_layers = nullptr;
|
diffusion_params.skip_layers = nullptr;
|
||||||
|
|
||||||
compute_sample_controls(control_image,
|
compute_sample_controls(control_image,
|
||||||
@ -1994,7 +2104,9 @@ public:
|
|||||||
int get_latent_channel() {
|
int get_latent_channel() {
|
||||||
int latent_channel = 4;
|
int latent_channel = 4;
|
||||||
if (sd_version_is_dit(version)) {
|
if (sd_version_is_dit(version)) {
|
||||||
if (version == VERSION_WAN2_2_TI2V) {
|
if (sd_version_is_ltxav(version)) {
|
||||||
|
latent_channel = 128;
|
||||||
|
} else if (version == VERSION_WAN2_2_TI2V) {
|
||||||
latent_channel = 48;
|
latent_channel = 48;
|
||||||
} else if (version == VERSION_HIDREAM_O1) {
|
} else if (version == VERSION_HIDREAM_O1) {
|
||||||
latent_channel = 3;
|
latent_channel = 3;
|
||||||
@ -2022,7 +2134,9 @@ public:
|
|||||||
int W = width / vae_scale_factor;
|
int W = width / vae_scale_factor;
|
||||||
int H = height / vae_scale_factor;
|
int H = height / vae_scale_factor;
|
||||||
int T = frames;
|
int T = frames;
|
||||||
if (sd_version_is_wan(version)) {
|
if (sd_version_is_ltxav(version)) {
|
||||||
|
T = ((T - 1) / 8) + 1;
|
||||||
|
} else if (sd_version_is_wan(version)) {
|
||||||
T = ((T - 1) / 4) + 1;
|
T = ((T - 1) / 4) + 1;
|
||||||
}
|
}
|
||||||
int C = get_latent_channel();
|
int C = get_latent_channel();
|
||||||
@ -2054,9 +2168,21 @@ public:
|
|||||||
|
|
||||||
sd::Tensor<float> decode_first_stage(const sd::Tensor<float>& x, bool decode_video = false) {
|
sd::Tensor<float> decode_first_stage(const sd::Tensor<float>& x, bool decode_video = false) {
|
||||||
auto latents = first_stage_model->diffusion_to_vae_latents(x);
|
auto latents = first_stage_model->diffusion_to_vae_latents(x);
|
||||||
|
first_stage_model->set_temporal_tiling_enabled(vae_tiling_params.temporal_tiling);
|
||||||
return first_stage_model->decode(n_threads, latents, vae_tiling_params, decode_video, circular_x, circular_y);
|
return first_stage_model->decode(n_threads, latents, vae_tiling_params, decode_video, circular_x, circular_y);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Runs the LTX audio VAE decoder on `audio_latent` and returns its output.
// Yields an empty tensor when no audio VAE is loaded or the latent is empty.
sd::Tensor<float> decode_ltx_audio_latent(const sd::Tensor<float>& audio_latent) {
    const bool can_decode = audio_vae_model != nullptr && !audio_latent.empty();
    if (!can_decode) {
        return {};
    }

    auto decoded = audio_vae_model->decode(n_threads, audio_latent);

    // With free_params_immediately set, release the audio VAE's params buffer
    // as soon as decoding is done.
    if (free_params_immediately) {
        audio_vae_model->free_params_buffer();
    }
    return decoded;
}
|
||||||
|
|
||||||
void set_flow_shift(float flow_shift = INFINITY) {
|
void set_flow_shift(float flow_shift = INFINITY) {
|
||||||
auto flow_denoiser = std::dynamic_pointer_cast<DiscreteFlowDenoiser>(denoiser);
|
auto flow_denoiser = std::dynamic_pointer_cast<DiscreteFlowDenoiser>(denoiser);
|
||||||
if (flow_denoiser) {
|
if (flow_denoiser) {
|
||||||
@ -2164,6 +2290,7 @@ const char* scheduler_to_str[] = {
|
|||||||
"kl_optimal",
|
"kl_optimal",
|
||||||
"lcm",
|
"lcm",
|
||||||
"bong_tangent",
|
"bong_tangent",
|
||||||
|
"ltx2",
|
||||||
};
|
};
|
||||||
|
|
||||||
const char* sd_scheduler_name(enum scheduler_t scheduler) {
|
const char* sd_scheduler_name(enum scheduler_t scheduler) {
|
||||||
@ -2364,7 +2491,9 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
|
|||||||
"llm_vision_path: %s\n"
|
"llm_vision_path: %s\n"
|
||||||
"diffusion_model_path: %s\n"
|
"diffusion_model_path: %s\n"
|
||||||
"high_noise_diffusion_model_path: %s\n"
|
"high_noise_diffusion_model_path: %s\n"
|
||||||
|
"embeddings_connectors_path: %s\n"
|
||||||
"vae_path: %s\n"
|
"vae_path: %s\n"
|
||||||
|
"audio_vae_path: %s\n"
|
||||||
"taesd_path: %s\n"
|
"taesd_path: %s\n"
|
||||||
"control_net_path: %s\n"
|
"control_net_path: %s\n"
|
||||||
"photo_maker_path: %s\n"
|
"photo_maker_path: %s\n"
|
||||||
@ -2399,7 +2528,9 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
|
|||||||
SAFE_STR(sd_ctx_params->llm_vision_path),
|
SAFE_STR(sd_ctx_params->llm_vision_path),
|
||||||
SAFE_STR(sd_ctx_params->diffusion_model_path),
|
SAFE_STR(sd_ctx_params->diffusion_model_path),
|
||||||
SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path),
|
SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path),
|
||||||
|
SAFE_STR(sd_ctx_params->embeddings_connectors_path),
|
||||||
SAFE_STR(sd_ctx_params->vae_path),
|
SAFE_STR(sd_ctx_params->vae_path),
|
||||||
|
SAFE_STR(sd_ctx_params->audio_vae_path),
|
||||||
SAFE_STR(sd_ctx_params->taesd_path),
|
SAFE_STR(sd_ctx_params->taesd_path),
|
||||||
SAFE_STR(sd_ctx_params->control_net_path),
|
SAFE_STR(sd_ctx_params->control_net_path),
|
||||||
SAFE_STR(sd_ctx_params->photo_maker_path),
|
SAFE_STR(sd_ctx_params->photo_maker_path),
|
||||||
@ -2501,7 +2632,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) {
|
|||||||
sd_img_gen_params->batch_count = 1;
|
sd_img_gen_params->batch_count = 1;
|
||||||
sd_img_gen_params->control_strength = 0.9f;
|
sd_img_gen_params->control_strength = 0.9f;
|
||||||
sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f};
|
sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f};
|
||||||
sd_img_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
|
sd_img_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f};
|
||||||
sd_cache_params_init(&sd_img_gen_params->cache);
|
sd_cache_params_init(&sd_img_gen_params->cache);
|
||||||
sd_hires_params_init(&sd_img_gen_params->hires);
|
sd_hires_params_init(&sd_img_gen_params->hires);
|
||||||
}
|
}
|
||||||
@ -2530,7 +2661,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
|
|||||||
"increase_ref_index: %s\n"
|
"increase_ref_index: %s\n"
|
||||||
"control_strength: %.2f\n"
|
"control_strength: %.2f\n"
|
||||||
"photo maker: {style_strength = %.2f, id_images_count = %d, id_embed_path = %s}\n"
|
"photo maker: {style_strength = %.2f, id_images_count = %d, id_embed_path = %s}\n"
|
||||||
"VAE tiling: %s\n"
|
"VAE tiling: %s (temporal=%s)\n"
|
||||||
"hires: {enabled=%s, upscaler=%s, model_path=%s, scale=%.2f, target=%dx%d, steps=%d, denoising_strength=%.2f}\n",
|
"hires: {enabled=%s, upscaler=%s, model_path=%s, scale=%.2f, target=%dx%d, steps=%d, denoising_strength=%.2f}\n",
|
||||||
SAFE_STR(sd_img_gen_params->prompt),
|
SAFE_STR(sd_img_gen_params->prompt),
|
||||||
SAFE_STR(sd_img_gen_params->negative_prompt),
|
SAFE_STR(sd_img_gen_params->negative_prompt),
|
||||||
@ -2549,6 +2680,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
|
|||||||
sd_img_gen_params->pm_params.id_images_count,
|
sd_img_gen_params->pm_params.id_images_count,
|
||||||
SAFE_STR(sd_img_gen_params->pm_params.id_embed_path),
|
SAFE_STR(sd_img_gen_params->pm_params.id_embed_path),
|
||||||
BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled),
|
BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled),
|
||||||
|
BOOL_STR(sd_img_gen_params->vae_tiling_params.temporal_tiling),
|
||||||
BOOL_STR(sd_img_gen_params->hires.enabled),
|
BOOL_STR(sd_img_gen_params->hires.enabled),
|
||||||
sd_hires_upscaler_name(sd_img_gen_params->hires.upscaler),
|
sd_hires_upscaler_name(sd_img_gen_params->hires.upscaler),
|
||||||
SAFE_STR(sd_img_gen_params->hires.model_path),
|
SAFE_STR(sd_img_gen_params->hires.model_path),
|
||||||
@ -2583,9 +2715,10 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
|
|||||||
sd_vid_gen_params->strength = 0.75f;
|
sd_vid_gen_params->strength = 0.75f;
|
||||||
sd_vid_gen_params->seed = -1;
|
sd_vid_gen_params->seed = -1;
|
||||||
sd_vid_gen_params->video_frames = 6;
|
sd_vid_gen_params->video_frames = 6;
|
||||||
|
sd_vid_gen_params->fps = 16;
|
||||||
sd_vid_gen_params->moe_boundary = 0.875f;
|
sd_vid_gen_params->moe_boundary = 0.875f;
|
||||||
sd_vid_gen_params->vace_strength = 1.f;
|
sd_vid_gen_params->vace_strength = 1.f;
|
||||||
sd_vid_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
|
sd_vid_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f};
|
||||||
sd_cache_params_init(&sd_vid_gen_params->cache);
|
sd_cache_params_init(&sd_vid_gen_params->cache);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2594,7 +2727,7 @@ struct sd_ctx_t {
|
|||||||
};
|
};
|
||||||
|
|
||||||
static bool sd_version_supports_video_generation(SDVersion version) {
|
static bool sd_version_supports_video_generation(SDVersion version) {
|
||||||
return version == VERSION_SVD || sd_version_is_wan(version);
|
return version == VERSION_SVD || sd_version_is_wan(version) || sd_version_is_ltxav(version);
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool sd_version_supports_image_generation(SDVersion version) {
|
static bool sd_version_supports_image_generation(SDVersion version) {
|
||||||
@ -2630,6 +2763,45 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {
|
|||||||
free(sd_ctx);
|
free(sd_ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Copies a waveform tensor into a freshly malloc'ed sd_audio_t.
//
// shape()[0] is taken as the sample count and shape()[1], when present, as the
// channel count (otherwise 1). The sample rate comes from the audio VAE config
// when an audio VAE is loaded, else 0. Returns nullptr when `sd` is null, the
// waveform is empty, either count is non-positive, or an allocation fails.
// Both the struct and its data buffer come from malloc; free_sd_audio()
// releases them.
static sd_audio_t* waveform_to_sd_audio(const StableDiffusionGGML* sd,
                                        const sd::Tensor<float>& waveform) {
    if (sd == nullptr || waveform.empty()) {
        return nullptr;
    }

    const int64_t n_samples = waveform.shape()[0];
    const int64_t n_channels = waveform.shape().size() > 1 ? waveform.shape()[1] : 1;
    if (n_samples <= 0 || n_channels <= 0) {
        return nullptr;
    }

    auto* result = static_cast<sd_audio_t*>(malloc(sizeof(sd_audio_t)));
    if (result == nullptr) {
        return nullptr;
    }

    const auto rate = sd->audio_vae_model != nullptr ? sd->audio_vae_model->config.output_sample_rate() : 0;
    result->sample_rate  = static_cast<uint32_t>(rate);
    result->channels     = static_cast<uint32_t>(n_channels);
    result->sample_count = static_cast<uint64_t>(n_samples);

    // The buffer holds every element of the tensor, not just samples * channels.
    const size_t byte_count = waveform.numel() * sizeof(float);
    result->data = static_cast<float*>(malloc(byte_count));
    if (result->data == nullptr) {
        // Do not hand back a half-built struct.
        free(result);
        return nullptr;
    }

    std::memcpy(result->data, waveform.data(), byte_count);
    return result;
}
|
||||||
|
|
||||||
|
// Releases an sd_audio_t and the sample buffer it points to.
// Passing nullptr is a no-op.
void free_sd_audio(sd_audio_t* audio) {
    if (audio != nullptr) {
        free(audio->data);
        audio->data = nullptr;
        free(audio);
    }
}
|
||||||
|
|
||||||
SD_API bool sd_ctx_supports_image_generation(const sd_ctx_t* sd_ctx) {
|
SD_API bool sd_ctx_supports_image_generation(const sd_ctx_t* sd_ctx) {
|
||||||
if (sd_ctx == nullptr || sd_ctx->sd == nullptr) {
|
if (sd_ctx == nullptr || sd_ctx->sd == nullptr) {
|
||||||
return false;
|
return false;
|
||||||
@ -2664,6 +2836,8 @@ enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_me
|
|||||||
return LCM_SCHEDULER;
|
return LCM_SCHEDULER;
|
||||||
} else if (sample_method == DDIM_TRAILING_SAMPLE_METHOD) {
|
} else if (sample_method == DDIM_TRAILING_SAMPLE_METHOD) {
|
||||||
return SIMPLE_SCHEDULER;
|
return SIMPLE_SCHEDULER;
|
||||||
|
} else if (sd_ctx != nullptr && sd_ctx->sd != nullptr && sd_version_is_ltxav(sd_ctx->sd->version)) {
|
||||||
|
return LTX2_SCHEDULER;
|
||||||
}
|
}
|
||||||
return DISCRETE_SCHEDULER;
|
return DISCRETE_SCHEDULER;
|
||||||
}
|
}
|
||||||
@ -2743,6 +2917,8 @@ struct GenerationRequest {
|
|||||||
sd_pm_params_t pm_params = {};
|
sd_pm_params_t pm_params = {};
|
||||||
sd_hires_params_t hires = {};
|
sd_hires_params_t hires = {};
|
||||||
int frames = -1;
|
int frames = -1;
|
||||||
|
int requested_frames = -1;
|
||||||
|
int fps = 16;
|
||||||
float vace_strength = 1.f;
|
float vace_strength = 1.f;
|
||||||
|
|
||||||
GenerationRequest(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params) {
|
GenerationRequest(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params) {
|
||||||
@ -2769,20 +2945,33 @@ struct GenerationRequest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
GenerationRequest(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params) {
|
GenerationRequest(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params) {
|
||||||
prompt = SAFE_STR(sd_vid_gen_params->prompt);
|
prompt = SAFE_STR(sd_vid_gen_params->prompt);
|
||||||
negative_prompt = SAFE_STR(sd_vid_gen_params->negative_prompt);
|
negative_prompt = SAFE_STR(sd_vid_gen_params->negative_prompt);
|
||||||
width = sd_vid_gen_params->width;
|
width = sd_vid_gen_params->width;
|
||||||
height = sd_vid_gen_params->height;
|
height = sd_vid_gen_params->height;
|
||||||
frames = (sd_vid_gen_params->video_frames - 1) / 4 * 4 + 1;
|
requested_frames = std::max(1, sd_vid_gen_params->video_frames);
|
||||||
|
if (sd_version_is_ltxav(sd_ctx->sd->version)) {
|
||||||
|
frames = ((requested_frames - 1 + 7) / 8) * 8 + 1;
|
||||||
|
} else {
|
||||||
|
frames = (requested_frames - 1) / 4 * 4 + 1;
|
||||||
|
}
|
||||||
clip_skip = sd_vid_gen_params->clip_skip;
|
clip_skip = sd_vid_gen_params->clip_skip;
|
||||||
|
fps = std::max(1, sd_vid_gen_params->fps);
|
||||||
vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
|
vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
|
||||||
diffusion_model_down_factor = sd_ctx->sd->get_diffusion_model_down_factor();
|
diffusion_model_down_factor = sd_ctx->sd->get_diffusion_model_down_factor();
|
||||||
seed = sd_vid_gen_params->seed;
|
seed = sd_vid_gen_params->seed;
|
||||||
|
strength = sd_vid_gen_params->strength;
|
||||||
cache_params = &sd_vid_gen_params->cache;
|
cache_params = &sd_vid_gen_params->cache;
|
||||||
vace_strength = sd_vid_gen_params->vace_strength;
|
vace_strength = sd_vid_gen_params->vace_strength;
|
||||||
guidance = sd_vid_gen_params->sample_params.guidance;
|
guidance = sd_vid_gen_params->sample_params.guidance;
|
||||||
high_noise_guidance = sd_vid_gen_params->high_noise_sample_params.guidance;
|
high_noise_guidance = sd_vid_gen_params->high_noise_sample_params.guidance;
|
||||||
resolve(sd_ctx);
|
resolve(sd_ctx);
|
||||||
|
if (frames != requested_frames) {
|
||||||
|
LOG_WARN("align video frames from %d to %d for %s",
|
||||||
|
requested_frames,
|
||||||
|
frames,
|
||||||
|
model_version_to_str[sd_ctx->sd->version]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void align_generation_request_size() {
|
void align_generation_request_size() {
|
||||||
@ -2980,10 +3169,16 @@ struct SamplePlan {
|
|||||||
scheduler_t scheduler = resolve_scheduler(sd_ctx,
|
scheduler_t scheduler = resolve_scheduler(sd_ctx,
|
||||||
sample_params->scheduler,
|
sample_params->scheduler,
|
||||||
sample_method);
|
sample_method);
|
||||||
sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps,
|
int sample_seq_len = sd_ctx->sd->get_image_seq_len(request->height, request->width);
|
||||||
sd_ctx->sd->get_image_seq_len(request->height, request->width),
|
if (sd_version_is_ltxav(sd_ctx->sd->version) && request->frames > 0) {
|
||||||
scheduler,
|
int latent_frames = ((request->frames - 1) / 8) + 1;
|
||||||
sd_ctx->sd->version);
|
sample_seq_len *= latent_frames;
|
||||||
|
}
|
||||||
|
sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps,
|
||||||
|
sample_seq_len,
|
||||||
|
scheduler,
|
||||||
|
sd_ctx->sd->version,
|
||||||
|
sample_params->extra_sample_args);
|
||||||
}
|
}
|
||||||
|
|
||||||
eta = resolve_eta(sd_ctx, eta, sample_method);
|
eta = resolve_eta(sd_ctx, eta, sample_method);
|
||||||
@ -3017,6 +3212,7 @@ struct ImageGenerationLatents {
|
|||||||
sd::Tensor<float> init_latent;
|
sd::Tensor<float> init_latent;
|
||||||
sd::Tensor<float> concat_latent;
|
sd::Tensor<float> concat_latent;
|
||||||
sd::Tensor<float> uncond_concat_latent;
|
sd::Tensor<float> uncond_concat_latent;
|
||||||
|
sd::Tensor<float> audio_latent;
|
||||||
sd::Tensor<float> control_image;
|
sd::Tensor<float> control_image;
|
||||||
std::vector<sd::Tensor<float>> ref_images;
|
std::vector<sd::Tensor<float>> ref_images;
|
||||||
std::vector<sd::Tensor<float>> ref_latents;
|
std::vector<sd::Tensor<float>> ref_latents;
|
||||||
@ -3024,8 +3220,131 @@ struct ImageGenerationLatents {
|
|||||||
sd::Tensor<float> clip_vision_output;
|
sd::Tensor<float> clip_vision_output;
|
||||||
sd::Tensor<float> vace_context;
|
sd::Tensor<float> vace_context;
|
||||||
int64_t ref_image_num = 0;
|
int64_t ref_image_num = 0;
|
||||||
|
int audio_length = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
 * Pack an LTX-AV audio latent into the video latent as additional channels.
 *
 * The result has the same shape as `video_latent` except that dimension 3
 * grows by however many extra channels are required to hold every audio
 * value (rounded up to a whole channel). The video values are copied first,
 * the audio values are written flat directly after them, and any remaining
 * padding stays zero.
 *
 * @param video_latent 4-D video latent, or 5-D with a trailing dimension of 1.
 * @param audio_latent 3-D audio latent, or 4-D with a trailing dimension of 1.
 *                     When empty, `video_latent` is returned unchanged.
 * @return The packed latent tensor.
 */
static sd::Tensor<float> pack_ltxav_audio_and_video_latents(const sd::Tensor<float>& video_latent,
                                                            const sd::Tensor<float>& audio_latent) {
    // No audio to pack: hand the video latent back untouched.
    if (audio_latent.empty()) {
        return video_latent;
    }

    GGML_ASSERT(video_latent.dim() == 4 || video_latent.dim() == 5);
    GGML_ASSERT(audio_latent.dim() == 3 || audio_latent.dim() == 4);
    if (video_latent.dim() == 5) {
        GGML_ASSERT(video_latent.shape()[4] == 1);
    }
    if (audio_latent.dim() == 4) {
        GGML_ASSERT(audio_latent.shape()[3] == 1);
    }

    const int64_t width        = video_latent.shape()[0];
    const int64_t height       = video_latent.shape()[1];
    const int64_t frames       = video_latent.shape()[2];
    const int64_t video_ch     = video_latent.shape()[3];
    const int64_t spatial_size = width * height * frames;

    // spatial_size is the divisor below; a zero-sized video latent would
    // otherwise trigger an integer division by zero.
    GGML_ASSERT(spatial_size > 0);

    // Number of whole extra channels needed to store all audio values.
    const int64_t audio_values = audio_latent.numel();
    const int64_t extra_ch     = (audio_values + spatial_size - 1) / spatial_size;

    std::vector<int64_t> packed_shape = video_latent.shape();
    packed_shape[3]                   = video_ch + extra_ch;
    sd::Tensor<float> packed          = sd::zeros<float>(packed_shape);

    // Video values first, audio values immediately after in the flat buffer.
    std::copy_n(video_latent.data(), video_latent.numel(), packed.data());
    std::copy_n(audio_latent.data(), audio_latent.numel(), packed.data() + video_latent.numel());
    return packed;
}
|
||||||
|
|
||||||
|
/**
 * Extend a video denoise mask so that it also covers the audio channels that
 * pack_ltxav_audio_and_video_latents() appends to the video latent.
 *
 * The video part of the mask is broadcast to the full video channel count
 * when it has a single channel, and the appended audio channels are filled
 * with 1.0 (presumably "fully denoised" - confirm against the sampler).
 *
 * @param video_mask   Mask with the same rank and width/height/frames as
 *                     `video_latent`; its channel count must be 1, the video
 *                     channel count, or already the packed channel count.
 * @param video_latent Unpacked video latent (4-D, or 5-D with trailing 1).
 * @param audio_latent Audio latent (3-D, or 4-D with trailing 1).
 * @return `video_mask` unchanged when either the mask or the audio latent is
 *         empty, or when the mask already has the packed channel count;
 *         otherwise the mask concatenated with the audio mask on dimension 3.
 */
static sd::Tensor<float> pack_ltxav_audio_and_video_denoise_mask(const sd::Tensor<float>& video_mask,
                                                                 const sd::Tensor<float>& video_latent,
                                                                 const sd::Tensor<float>& audio_latent) {
    if (video_mask.empty() || audio_latent.empty()) {
        return video_mask;
    }

    GGML_ASSERT(video_latent.dim() == 4 || video_latent.dim() == 5);
    GGML_ASSERT(audio_latent.dim() == 3 || audio_latent.dim() == 4);
    if (video_latent.dim() == 5) {
        GGML_ASSERT(video_latent.shape()[4] == 1);
    }
    if (audio_latent.dim() == 4) {
        GGML_ASSERT(audio_latent.shape()[3] == 1);
    }

    const int64_t width        = video_latent.shape()[0];
    const int64_t height       = video_latent.shape()[1];
    const int64_t frames       = video_latent.shape()[2];
    const int64_t video_ch     = video_latent.shape()[3];
    const int64_t spatial_size = width * height * frames;

    // spatial_size is the divisor below; a zero-sized video latent would
    // otherwise trigger an integer division by zero.
    GGML_ASSERT(spatial_size > 0);

    // Same channel rounding as the latent packing routine.
    const int64_t audio_values = audio_latent.numel();
    const int64_t extra_ch     = (audio_values + spatial_size - 1) / spatial_size;

    // The mask must line up with the video latent on every non-channel axis.
    GGML_ASSERT(video_mask.dim() == video_latent.dim());
    GGML_ASSERT(video_mask.shape()[0] == width);
    GGML_ASSERT(video_mask.shape()[1] == height);
    GGML_ASSERT(video_mask.shape()[2] == frames);
    if (video_mask.dim() == 5) {
        GGML_ASSERT(video_mask.shape()[4] == video_latent.shape()[4]);
    }

    const int64_t mask_ch = video_mask.shape()[3];
    if (mask_ch == video_ch + extra_ch) {
        // Already covers video + audio channels.
        return video_mask;
    }
    GGML_ASSERT(mask_ch == 1 || mask_ch == video_ch);

    // Broadcast a single-channel mask across all video channels.
    sd::Tensor<float> video_mask_full = video_mask;
    if (mask_ch == 1 && video_ch != 1) {
        video_mask_full = video_mask * sd::Tensor<float>::ones(video_latent.shape());
    }

    // Audio channels get an all-ones mask.
    std::vector<int64_t> audio_mask_shape = video_latent.shape();
    audio_mask_shape[3]                   = extra_ch;
    auto audio_mask                       = sd::Tensor<float>::ones(audio_mask_shape);
    return sd::ops::concat(video_mask_full, audio_mask, 3);
}
|
||||||
|
|
||||||
|
/**
 * Extract the audio latent from a packed LTX-AV latent.
 *
 * The audio values are read from the flat buffer starting right after the
 * first `video_channels` channels and copied into a tensor of shape
 * {16, audio_length, 8, 1}.
 *
 * @param packed_latent  Packed latent (4-D or 5-D).
 * @param audio_length   Number of audio latent steps to recover.
 * @param video_channels Number of leading channels that belong to the video.
 * @return The audio latent, or an empty tensor when the input is empty,
 *         `audio_length` is not positive, there are no channels beyond the
 *         video ones, or the extra channels hold fewer values than required.
 */
static sd::Tensor<float> unpack_ltxav_audio_latent(const sd::Tensor<float>& packed_latent,
                                                   int audio_length,
                                                   int video_channels) {
    constexpr int kLtxavAudioFrequencyBins = 16;
    constexpr int kLtxavAudioChannels      = 8;

    if (packed_latent.empty() || audio_length <= 0) {
        return {};
    }

    GGML_ASSERT(packed_latent.dim() == 4 || packed_latent.dim() == 5);

    const int64_t plane_values  = packed_latent.shape()[0] * packed_latent.shape()[1] * packed_latent.shape()[2];
    const int64_t channel_count = packed_latent.shape()[3];
    if (channel_count <= video_channels) {
        // Nothing was packed after the video channels.
        return {};
    }

    const int64_t values_needed    = static_cast<int64_t>(audio_length) * kLtxavAudioFrequencyBins * kLtxavAudioChannels;
    const int64_t values_available = (channel_count - video_channels) * plane_values;
    if (values_available < values_needed) {
        return {};
    }

    // Audio data begins where the video channels end in the flat buffer.
    const size_t audio_offset = static_cast<size_t>(video_channels) * static_cast<size_t>(plane_values);

    sd::Tensor<float> audio_latent({kLtxavAudioFrequencyBins, audio_length, kLtxavAudioChannels, 1});
    std::copy_n(packed_latent.data() + audio_offset,
                static_cast<size_t>(values_needed),
                audio_latent.data());
    return audio_latent;
}
|
||||||
|
|
||||||
|
/**
 * Number of audio latent steps covering `frames` video frames at `fps`.
 *
 * Computed as ceil(frames / fps * latents_per_second), where
 * latents_per_second = 16000 / 160 / 4.
 *
 * @param frames Video frame count; must be positive.
 * @param fps    Video frame rate; must be positive.
 */
static int get_ltxav_num_audio_latents(int frames, int fps) {
    GGML_ASSERT(frames > 0);
    GGML_ASSERT(fps > 0);

    constexpr float kSampleRate            = 16000.0f;
    constexpr float kMelHopLength          = 160.0f;
    constexpr float kAudioLatentDownsample = 4.0f;
    constexpr float kLatentsPerSecond      = kSampleRate / kMelHopLength / kAudioLatentDownsample;

    const float duration_seconds = static_cast<float>(frames) / static_cast<float>(fps);
    const float latent_count     = duration_seconds * kLatentsPerSecond;
    return static_cast<int>(std::ceil(latent_count));
}
|
||||||
|
|
||||||
struct ImageGenerationEmbeds {
|
struct ImageGenerationEmbeds {
|
||||||
SDCondition cond;
|
SDCondition cond;
|
||||||
SDCondition uncond;
|
SDCondition uncond;
|
||||||
@ -3617,6 +3936,8 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s
|
|||||||
latents.denoise_mask,
|
latents.denoise_mask,
|
||||||
sd::Tensor<float>(),
|
sd::Tensor<float>(),
|
||||||
1.f,
|
1.f,
|
||||||
|
0,
|
||||||
|
static_cast<float>(request.fps),
|
||||||
request.cache_params);
|
request.cache_params);
|
||||||
int64_t sampling_end = ggml_time_ms();
|
int64_t sampling_end = ggml_time_ms();
|
||||||
if (!x_0.empty()) {
|
if (!x_0.empty()) {
|
||||||
@ -3676,7 +3997,8 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s
|
|||||||
hires_steps,
|
hires_steps,
|
||||||
sd_ctx->sd->get_image_seq_len(request.hires.target_height, request.hires.target_width),
|
sd_ctx->sd->get_image_seq_len(request.hires.target_height, request.hires.target_width),
|
||||||
sd_img_gen_params->sample_params.scheduler,
|
sd_img_gen_params->sample_params.scheduler,
|
||||||
sd_ctx->sd->version);
|
sd_ctx->sd->version,
|
||||||
|
sd_img_gen_params->sample_params.extra_sample_args);
|
||||||
|
|
||||||
size_t t_enc = static_cast<size_t>(hires_steps * request.hires.denoising_strength);
|
size_t t_enc = static_cast<size_t>(hires_steps * request.hires.denoising_strength);
|
||||||
if (t_enc >= static_cast<size_t>(hires_steps)) {
|
if (t_enc >= static_cast<size_t>(hires_steps)) {
|
||||||
@ -3743,6 +4065,8 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s
|
|||||||
hires_denoise_mask,
|
hires_denoise_mask,
|
||||||
sd::Tensor<float>(),
|
sd::Tensor<float>(),
|
||||||
1.f,
|
1.f,
|
||||||
|
0,
|
||||||
|
static_cast<float>(request.fps),
|
||||||
request.cache_params);
|
request.cache_params);
|
||||||
int64_t hires_sample_end = ggml_time_ms();
|
int64_t hires_sample_end = ggml_time_ms();
|
||||||
if (!x_0.empty()) {
|
if (!x_0.empty()) {
|
||||||
@ -3801,6 +4125,57 @@ static std::optional<ImageGenerationLatents> prepare_video_generation_latents(sd
|
|||||||
end_image = sd_image_to_tensor(sd_vid_gen_params->end_image, request->width, request->height);
|
end_image = sd_image_to_tensor(sd_vid_gen_params->end_image, request->width, request->height);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (sd_version_is_ltxav(sd_ctx->sd->version)) {
|
||||||
|
constexpr int kLtxavAudioFrequencyBins = 16;
|
||||||
|
constexpr int kLtxavAudioChannels = 8;
|
||||||
|
latents.audio_length = get_ltxav_num_audio_latents(request->frames, request->fps);
|
||||||
|
latents.audio_latent = sd::zeros<float>({kLtxavAudioFrequencyBins, latents.audio_length, kLtxavAudioChannels, 1});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sd_version_is_ltxav(sd_ctx->sd->version)) {
|
||||||
|
if (!end_image.empty() || sd_vid_gen_params->control_frames_size > 0) {
|
||||||
|
LOG_ERROR("LTXAV currently supports txt2vid and init_image i2v only; end_image and control_frames are not implemented");
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!start_image.empty()) {
|
||||||
|
if (sd_ctx->sd->vae_decode_only) {
|
||||||
|
LOG_ERROR("LTXAV init_image i2v requires VAE encoder weights; create the context with vae_decode_only=false");
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_INFO("IMG2VID");
|
||||||
|
|
||||||
|
int64_t t1 = ggml_time_ms();
|
||||||
|
auto init_img = start_image.reshape({start_image.shape()[0],
|
||||||
|
start_image.shape()[1],
|
||||||
|
1,
|
||||||
|
start_image.shape()[2],
|
||||||
|
start_image.shape()[3]});
|
||||||
|
auto init_image_latent = sd_ctx->sd->encode_first_stage(init_img);
|
||||||
|
if (init_image_latent.empty()) {
|
||||||
|
LOG_ERROR("failed to encode LTXAV init image");
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
latents.init_latent = sd_ctx->sd->generate_init_latent(request->width, request->height, request->frames, true);
|
||||||
|
sd::ops::slice_assign(&latents.init_latent, 2, 0, init_image_latent.shape()[2], init_image_latent);
|
||||||
|
|
||||||
|
float conditioning_strength = std::clamp(request->strength, 0.f, 1.f);
|
||||||
|
float conditioned_mask = 1.0f - conditioning_strength;
|
||||||
|
latents.denoise_mask = sd::full<float>({latents.init_latent.shape()[0],
|
||||||
|
latents.init_latent.shape()[1],
|
||||||
|
latents.init_latent.shape()[2],
|
||||||
|
1,
|
||||||
|
1},
|
||||||
|
1.f);
|
||||||
|
sd::ops::fill_slice(&latents.denoise_mask, 2, 0, init_image_latent.shape()[2], conditioned_mask);
|
||||||
|
|
||||||
|
int64_t t2 = ggml_time_ms();
|
||||||
|
LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" ||
|
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" ||
|
||||||
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-I2V-14B" ||
|
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-I2V-14B" ||
|
||||||
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-1.3B" ||
|
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-1.3B" ||
|
||||||
@ -3971,6 +4346,15 @@ static std::optional<ImageGenerationLatents> prepare_video_generation_latents(sd
|
|||||||
latents.init_latent = sd_ctx->sd->generate_init_latent(request->width, request->height, request->frames, true);
|
latents.init_latent = sd_ctx->sd->generate_init_latent(request->width, request->height, request->frames, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (sd_version_is_ltxav(sd_ctx->sd->version) && !latents.audio_latent.empty()) {
|
||||||
|
if (!latents.denoise_mask.empty()) {
|
||||||
|
latents.denoise_mask = pack_ltxav_audio_and_video_denoise_mask(latents.denoise_mask,
|
||||||
|
latents.init_latent,
|
||||||
|
latents.audio_latent);
|
||||||
|
}
|
||||||
|
latents.init_latent = pack_ltxav_audio_and_video_latents(latents.init_latent, latents.audio_latent);
|
||||||
|
}
|
||||||
|
|
||||||
return latents;
|
return latents;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -4007,14 +4391,26 @@ static ImageGenerationEmbeds prepare_video_generation_embeds(sd_ctx_t* sd_ctx,
|
|||||||
}
|
}
|
||||||
|
|
||||||
static sd_image_t* decode_video_outputs(sd_ctx_t* sd_ctx,
|
static sd_image_t* decode_video_outputs(sd_ctx_t* sd_ctx,
|
||||||
|
const GenerationRequest& request,
|
||||||
const sd::Tensor<float>& final_latent,
|
const sd::Tensor<float>& final_latent,
|
||||||
int* num_frames_out) {
|
int* num_frames_out) {
|
||||||
if (final_latent.empty()) {
|
if (final_latent.empty()) {
|
||||||
LOG_ERROR("no latent video to decode");
|
LOG_ERROR("no latent video to decode");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
sd::Tensor<float> video_latent = final_latent;
|
||||||
|
if (sd_version_is_ltxav(sd_ctx->sd->version) &&
|
||||||
|
video_latent.shape()[3] > sd_ctx->sd->get_latent_channel()) {
|
||||||
|
video_latent = sd::ops::slice(video_latent, 3, 0, sd_ctx->sd->get_latent_channel());
|
||||||
|
}
|
||||||
|
LOG_DEBUG("decode_video_outputs latent %dx%dx%dx%d",
|
||||||
|
(int)video_latent.shape()[0],
|
||||||
|
(int)video_latent.shape()[1],
|
||||||
|
(int)video_latent.shape()[2],
|
||||||
|
(int)video_latent.shape()[3]);
|
||||||
|
// auto z = sd::load_tensor_from_file_as_tensor<float>("ltx_vae_z.bin");
|
||||||
int64_t t4 = ggml_time_ms();
|
int64_t t4 = ggml_time_ms();
|
||||||
sd::Tensor<float> vid = sd_ctx->sd->decode_first_stage(final_latent, true);
|
sd::Tensor<float> vid = sd_ctx->sd->decode_first_stage(video_latent, true);
|
||||||
int64_t t5 = ggml_time_ms();
|
int64_t t5 = ggml_time_ms();
|
||||||
LOG_INFO("decode_first_stage completed, taking %.2fs", (t5 - t4) * 1.0f / 1000);
|
LOG_INFO("decode_first_stage completed, taking %.2fs", (t5 - t4) * 1.0f / 1000);
|
||||||
if (sd_ctx->sd->free_params_immediately) {
|
if (sd_ctx->sd->free_params_immediately) {
|
||||||
@ -4024,6 +4420,15 @@ static sd_image_t* decode_video_outputs(sd_ctx_t* sd_ctx,
|
|||||||
LOG_ERROR("decode_first_stage failed for video");
|
LOG_ERROR("decode_first_stage failed for video");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
LOG_DEBUG("decode_video_outputs decoded %dx%dx%dx%d",
|
||||||
|
(int)vid.shape()[0],
|
||||||
|
(int)vid.shape()[1],
|
||||||
|
(int)vid.shape()[2],
|
||||||
|
(int)vid.shape()[3]);
|
||||||
|
if (request.requested_frames > 0 &&
|
||||||
|
vid.shape()[2] > request.requested_frames) {
|
||||||
|
vid = sd::ops::slice(vid, 2, 0, request.requested_frames);
|
||||||
|
}
|
||||||
|
|
||||||
sd_image_t* result_images = (sd_image_t*)calloc(vid.shape()[2], sizeof(sd_image_t));
|
sd_image_t* result_images = (sd_image_t*)calloc(vid.shape()[2], sizeof(sd_image_t));
|
||||||
if (result_images == nullptr) {
|
if (result_images == nullptr) {
|
||||||
@ -4040,9 +4445,19 @@ static sd_image_t* decode_video_outputs(sd_ctx_t* sd_ctx,
|
|||||||
return result_images;
|
return result_images;
|
||||||
}
|
}
|
||||||
|
|
||||||
SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params, int* num_frames_out) {
|
SD_API bool generate_video(sd_ctx_t* sd_ctx,
|
||||||
|
const sd_vid_gen_params_t* sd_vid_gen_params,
|
||||||
|
sd_image_t** frames_out,
|
||||||
|
int* num_frames_out,
|
||||||
|
sd_audio_t** audio_out) {
|
||||||
if (sd_ctx == nullptr || sd_vid_gen_params == nullptr) {
|
if (sd_ctx == nullptr || sd_vid_gen_params == nullptr) {
|
||||||
return nullptr;
|
return false;
|
||||||
|
}
|
||||||
|
if (frames_out != nullptr) {
|
||||||
|
*frames_out = nullptr;
|
||||||
|
}
|
||||||
|
if (audio_out != nullptr) {
|
||||||
|
*audio_out = nullptr;
|
||||||
}
|
}
|
||||||
if (num_frames_out != nullptr) {
|
if (num_frames_out != nullptr) {
|
||||||
*num_frames_out = 0;
|
*num_frames_out = 0;
|
||||||
@ -4058,7 +4473,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
SamplePlan plan(sd_ctx, sd_vid_gen_params, request);
|
SamplePlan plan(sd_ctx, sd_vid_gen_params, request);
|
||||||
auto latent_inputs_opt = prepare_video_generation_latents(sd_ctx, sd_vid_gen_params, &request);
|
auto latent_inputs_opt = prepare_video_generation_latents(sd_ctx, sd_vid_gen_params, &request);
|
||||||
if (!latent_inputs_opt.has_value()) {
|
if (!latent_inputs_opt.has_value()) {
|
||||||
return nullptr;
|
return false;
|
||||||
}
|
}
|
||||||
ImageGenerationLatents latents = std::move(*latent_inputs_opt);
|
ImageGenerationLatents latents = std::move(*latent_inputs_opt);
|
||||||
ImageGenerationEmbeds embeds = prepare_video_generation_embeds(sd_ctx,
|
ImageGenerationEmbeds embeds = prepare_video_generation_embeds(sd_ctx,
|
||||||
@ -4108,6 +4523,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
latents.denoise_mask,
|
latents.denoise_mask,
|
||||||
latents.vace_context,
|
latents.vace_context,
|
||||||
request.vace_strength,
|
request.vace_strength,
|
||||||
|
latents.audio_length,
|
||||||
|
static_cast<float>(request.fps),
|
||||||
request.cache_params);
|
request.cache_params);
|
||||||
int64_t sampling_end = ggml_time_ms();
|
int64_t sampling_end = ggml_time_ms();
|
||||||
if (x_t_sampled.empty()) {
|
if (x_t_sampled.empty()) {
|
||||||
@ -4115,7 +4532,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
if (sd_ctx->sd->free_params_immediately) {
|
if (sd_ctx->sd->free_params_immediately) {
|
||||||
sd_ctx->sd->high_noise_diffusion_model->free_params_buffer();
|
sd_ctx->sd->high_noise_diffusion_model->free_params_buffer();
|
||||||
}
|
}
|
||||||
return nullptr;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
x_t = std::move(x_t_sampled);
|
x_t = std::move(x_t_sampled);
|
||||||
@ -4151,6 +4568,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
latents.denoise_mask,
|
latents.denoise_mask,
|
||||||
latents.vace_context,
|
latents.vace_context,
|
||||||
request.vace_strength,
|
request.vace_strength,
|
||||||
|
latents.audio_length,
|
||||||
|
static_cast<float>(request.fps),
|
||||||
request.cache_params);
|
request.cache_params);
|
||||||
|
|
||||||
int64_t sampling_end = ggml_time_ms();
|
int64_t sampling_end = ggml_time_ms();
|
||||||
@ -4159,10 +4578,27 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
}
|
}
|
||||||
if (final_latent.empty()) {
|
if (final_latent.empty()) {
|
||||||
LOG_ERROR("sampling failed after %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
|
LOG_ERROR("sampling failed after %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
|
||||||
return nullptr;
|
return false;
|
||||||
}
|
}
|
||||||
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
|
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
|
||||||
|
|
||||||
|
sd_audio_t* generated_audio = nullptr;
|
||||||
|
if (sd_version_is_ltxav(sd_ctx->sd->version) &&
|
||||||
|
latents.audio_length > 0 &&
|
||||||
|
sd_ctx->sd->audio_vae_model != nullptr) {
|
||||||
|
auto audio_latent = unpack_ltxav_audio_latent(final_latent,
|
||||||
|
latents.audio_length,
|
||||||
|
sd_ctx->sd->get_latent_channel());
|
||||||
|
if (!audio_latent.empty()) {
|
||||||
|
auto waveform = sd_ctx->sd->decode_ltx_audio_latent(audio_latent);
|
||||||
|
if (!waveform.empty()) {
|
||||||
|
generated_audio = waveform_to_sd_audio(sd_ctx->sd, waveform);
|
||||||
|
} else {
|
||||||
|
LOG_WARN("LTX audio latent decode failed; continuing with silent video output");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (latents.ref_image_num > 0) {
|
if (latents.ref_image_num > 0) {
|
||||||
final_latent = sd::ops::slice(final_latent, 2, latents.ref_image_num, final_latent.shape()[2]);
|
final_latent = sd::ops::slice(final_latent, 2, latents.ref_image_num, final_latent.shape()[2]);
|
||||||
}
|
}
|
||||||
@ -4170,14 +4606,23 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
int64_t latent_end = ggml_time_ms();
|
int64_t latent_end = ggml_time_ms();
|
||||||
LOG_INFO("generating latent video completed, taking %.2fs", (latent_end - latent_start) * 1.0f / 1000);
|
LOG_INFO("generating latent video completed, taking %.2fs", (latent_end - latent_start) * 1.0f / 1000);
|
||||||
|
|
||||||
auto result = decode_video_outputs(sd_ctx, final_latent, num_frames_out);
|
auto result = decode_video_outputs(sd_ctx, request, final_latent, num_frames_out);
|
||||||
if (result == nullptr) {
|
if (result == nullptr) {
|
||||||
return nullptr;
|
free_sd_audio(generated_audio);
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
sd_ctx->sd->lora_stat();
|
sd_ctx->sd->lora_stat();
|
||||||
|
|
||||||
int64_t t1 = ggml_time_ms();
|
int64_t t1 = ggml_time_ms();
|
||||||
LOG_INFO("generate_video completed in %.2fs", (t1 - t0) * 1.0f / 1000);
|
LOG_INFO("generate_video completed in %.2fs", (t1 - t0) * 1.0f / 1000);
|
||||||
return result;
|
if (frames_out != nullptr) {
|
||||||
|
*frames_out = result;
|
||||||
|
}
|
||||||
|
if (audio_out != nullptr) {
|
||||||
|
*audio_out = generated_audio;
|
||||||
|
} else {
|
||||||
|
free_sd_audio(generated_audio);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,7 +2,6 @@
|
|||||||
#define __TAE_HPP__
|
#define __TAE_HPP__
|
||||||
|
|
||||||
#include "ggml_extend.hpp"
|
#include "ggml_extend.hpp"
|
||||||
|
|
||||||
#include "model.h"
|
#include "model.h"
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
|||||||
@ -104,7 +104,7 @@ namespace sd {
|
|||||||
throw std::invalid_argument("tensor file type does not match requested sd::Tensor type");
|
throw std::invalid_argument("tensor file type does not match requested sd::Tensor type");
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int64_t> shape(4, 1);
|
std::vector<int64_t> shape(n_dims, 1);
|
||||||
for (int i = 0; i < n_dims; ++i) {
|
for (int i = 0; i < n_dims; ++i) {
|
||||||
int32_t dim = 1;
|
int32_t dim = 1;
|
||||||
file.read(reinterpret_cast<char*>(&dim), sizeof(dim));
|
file.read(reinterpret_cast<char*>(&dim), sizeof(dim));
|
||||||
|
|||||||
@ -162,13 +162,37 @@ std::vector<int> BPETokenizer::encode(const std::string& text, on_new_token_cb_t
|
|||||||
|
|
||||||
std::string token_str = token;
|
std::string token_str = token;
|
||||||
std::u32string utf32_token;
|
std::u32string utf32_token;
|
||||||
for (int i = 0; i < static_cast<int>(token_str.length()); i++) {
|
if (byte_level_bpe) {
|
||||||
unsigned char b = token_str[i];
|
for (int i = 0; i < token_str.length(); i++) {
|
||||||
utf32_token += byte_encoder[b];
|
unsigned char b = token_str[i];
|
||||||
|
utf32_token += byte_encoder[b];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
utf32_token = utf8_to_utf32(token_str);
|
||||||
}
|
}
|
||||||
auto bpe_strs = bpe(utf32_token);
|
auto bpe_strs = bpe(utf32_token);
|
||||||
for (auto bpe_str : bpe_strs) {
|
for (auto bpe_str : bpe_strs) {
|
||||||
bpe_tokens.push_back(encoder[bpe_str]);
|
int token_id;
|
||||||
|
auto iter = encoder.find(bpe_str);
|
||||||
|
if (iter != encoder.end()) {
|
||||||
|
token_id = iter->second;
|
||||||
|
} else {
|
||||||
|
if (byte_fallback) {
|
||||||
|
auto utf8_token_str = utf32_to_utf8(bpe_str);
|
||||||
|
for (int i = 0; i < utf8_token_str.length(); i++) {
|
||||||
|
unsigned char b = utf8_token_str[i];
|
||||||
|
char hex_buf[16];
|
||||||
|
snprintf(hex_buf, sizeof(hex_buf), "<0x%02X>", b);
|
||||||
|
iter = encoder.find(utf8_to_utf32(hex_buf));
|
||||||
|
bpe_tokens.push_back(token_id);
|
||||||
|
token_strs.push_back(hex_buf);
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
token_id = UNK_TOKEN_ID;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bpe_tokens.push_back(token_id);
|
||||||
token_strs.push_back(utf32_to_utf8(bpe_str));
|
token_strs.push_back(utf32_to_utf8(bpe_str));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -20,8 +20,10 @@ protected:
|
|||||||
std::map<std::u32string, int> encoder;
|
std::map<std::u32string, int> encoder;
|
||||||
std::map<int, std::u32string> decoder;
|
std::map<int, std::u32string> decoder;
|
||||||
std::map<std::pair<std::u32string, std::u32string>, int> bpe_ranks;
|
std::map<std::pair<std::u32string, std::u32string>, int> bpe_ranks;
|
||||||
int encoder_len = 0;
|
int encoder_len = 0;
|
||||||
int bpe_len = 0;
|
int bpe_len = 0;
|
||||||
|
bool byte_level_bpe = true;
|
||||||
|
bool byte_fallback = false;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
static std::vector<std::pair<int, std::u32string>> bytes_to_unicode();
|
static std::vector<std::pair<int, std::u32string>> bytes_to_unicode();
|
||||||
|
|||||||
191
src/tokenizers/gemma_tokenizer.cpp
Normal file
191
src/tokenizers/gemma_tokenizer.cpp
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
#include "gemma_tokenizer.h"
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
#include "json.hpp"
|
||||||
|
#include "util.h"
|
||||||
|
#include "vocab/vocab.h"
|
||||||
|
|
||||||
|
// Replace every ASCII space in `text` with the UTF-8 bytes E2 96 81
// (U+2581); every other byte is copied through unchanged.
std::string GemmaTokenizer::normalize(const std::string& text) const {
    static const char kSpaceMarker[] = "\xE2\x96\x81";

    std::string result;
    result.reserve(text.size());
    for (const char ch : text) {
        if (ch == ' ') {
            result += kSpaceMarker;
        } else {
            result += ch;
        }
    }
    return result;
}
|
||||||
|
|
||||||
|
void GemmaTokenizer::load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) {
|
||||||
|
nlohmann::json vocab;
|
||||||
|
try {
|
||||||
|
vocab = nlohmann::json::parse(vocab_utf8_str);
|
||||||
|
} catch (const nlohmann::json::parse_error&) {
|
||||||
|
GGML_ABORT("invalid vocab json str");
|
||||||
|
}
|
||||||
|
for (const auto& [key, value] : vocab.items()) {
|
||||||
|
std::u32string token = utf8_to_utf32(key);
|
||||||
|
int i = value;
|
||||||
|
encoder[token] = i;
|
||||||
|
decoder[i] = token;
|
||||||
|
}
|
||||||
|
encoder_len = static_cast<int>(vocab.size());
|
||||||
|
LOG_DEBUG("vocab size: %d", encoder_len);
|
||||||
|
|
||||||
|
std::vector<std::u32string> merges = split_utf32(merges_utf8_str);
|
||||||
|
std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
|
||||||
|
for (const auto& merge : merges) {
|
||||||
|
size_t space_pos = merge.find(' ');
|
||||||
|
merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
|
||||||
|
}
|
||||||
|
LOG_DEBUG("merges size %zu", merge_pairs.size());
|
||||||
|
|
||||||
|
int rank = 0;
|
||||||
|
for (const auto& merge : merge_pairs) {
|
||||||
|
bpe_ranks[merge] = rank++;
|
||||||
|
}
|
||||||
|
bpe_len = rank;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure the tokenizer flags and special tokens, then load the vocabulary
// and merges. When either argument is empty, the data returned by
// load_gemma_merges() / load_gemma_vocab_json() is used instead.
GemmaTokenizer::GemmaTokenizer(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) {
    byte_level_bpe = false;
    byte_fallback  = true;
    add_bos_token  = true;
    pad_left       = true;

    PAD_TOKEN = "<pad>";
    EOS_TOKEN = "<eos>";
    BOS_TOKEN = "<bos>";
    UNK_TOKEN = "<unk>";

    PAD_TOKEN_ID = 0;
    EOS_TOKEN_ID = 1;
    BOS_TOKEN_ID = 2;
    UNK_TOKEN_ID = 3;

    static const char* const kHtmlTokens[] = {
        "<table>", "<caption>", "<thead>", "<tbody>", "<tfoot>", "<tr>", "<th>", "<td>",
        "</table>", "</caption>", "</thead>", "</tbody>", "</tfoot>", "</tr>", "</th>", "</td>",
        "<h1>", "<h2>", "<h3>", "<h4>", "<h5>", "<h6>", "<blockquote>",
        "</h1>", "</h2>", "</h3>", "</h4>", "</h5>", "</h6>", "</blockquote>",
        "<strong>", "<em>", "<b>", "<i>", "<u>", "<s>", "<sub>", "<sup>", "<code>",
        "</strong>", "</em>", "</b>", "</i>", "</u>", "</s>", "</sub>", "</sup>", "</code>",
        "<a>", "<html>", "<body>", "<img>", "<span>", "<bbox>", "<ul>", "<li>", "<div>", "<iframe>", "<footer>",
        "</a>", "</html>", "</body>", "</img>", "</span>", "</bbox>", "</ul>", "</li>", "</div>", "</iframe>", "</footer>",
    };

    // The list is assembled directly in its final order.
    std::vector<std::string> all_special_tokens;

    // --- first group ---
    all_special_tokens.push_back(PAD_TOKEN);
    all_special_tokens.push_back(EOS_TOKEN);
    all_special_tokens.push_back(BOS_TOKEN);
    all_special_tokens.push_back(UNK_TOKEN);
    all_special_tokens.push_back("<mask>");
    all_special_tokens.push_back("[multimodal]");
    for (int n = 0; n <= 98; ++n) {
        all_special_tokens.push_back("<unused" + std::to_string(n) + ">");
    }
    all_special_tokens.push_back("<start_of_turn>");
    all_special_tokens.push_back("<end_of_turn>");

    // Runs of 1..31 newlines.
    for (int count = 1; count <= 31; ++count) {
        all_special_tokens.push_back(std::string(count, '\n'));
    }

    // Runs of 2..31 U+2581 markers (UTF-8 bytes E2 96 81).
    std::string marker_run = "\xE2\x96\x81";
    for (int count = 2; count <= 31; ++count) {
        marker_run += "\xE2\x96\x81";
        all_special_tokens.push_back(marker_run);
    }

    for (const char* tag : kHtmlTokens) {
        all_special_tokens.push_back(tag);
    }

    // One "<0xNN>" token per byte value.
    for (int byte_value = 0; byte_value <= 0xFF; ++byte_value) {
        char hex_buf[16];
        snprintf(hex_buf, sizeof(hex_buf), "<0x%02X>", byte_value);
        all_special_tokens.push_back(hex_buf);
    }

    // --- second group ---
    // Runs of 1..31 tabs come before the image markers.
    for (int count = 1; count <= 31; ++count) {
        all_special_tokens.push_back(std::string(count, '\t'));
    }
    all_special_tokens.push_back("<start_of_image>");
    all_special_tokens.push_back("<end_of_image>");
    for (int n = 99; n <= 6241; ++n) {
        all_special_tokens.push_back("<unused" + std::to_string(n) + ">");
    }
    all_special_tokens.push_back("<image_soft_token>");

    special_tokens = all_special_tokens;

    const bool has_custom_data = !merges_utf8_str.empty() && !vocab_utf8_str.empty();
    if (has_custom_data) {
        load_from_merges(merges_utf8_str, vocab_utf8_str);
    } else {
        load_from_merges(load_gemma_merges(), load_gemma_vocab_json());
    }
}
|
||||||
17
src/tokenizers/gemma_tokenizer.h
Normal file
17
src/tokenizers/gemma_tokenizer.h
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
#ifndef __SD_TOKENIZERS_GEMMA_TOKENIZER_H__
|
||||||
|
#define __SD_TOKENIZERS_GEMMA_TOKENIZER_H__
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "bpe_tokenizer.h"
|
||||||
|
|
||||||
|
class GemmaTokenizer : public BPETokenizer {
|
||||||
|
protected:
|
||||||
|
void load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str);
|
||||||
|
std::string normalize(const std::string& text) const override;
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit GemmaTokenizer(const std::string& merges_utf8_str = "", const std::string& vocab_utf8_str = "");
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // __SD_TOKENIZERS_GEMMA_TOKENIZER_H__
|
||||||
3
src/tokenizers/vocab/gemma_merges.hpp
Normal file
3
src/tokenizers/vocab/gemma_merges.hpp
Normal file
File diff suppressed because one or more lines are too long
3
src/tokenizers/vocab/gemma_vocab.hpp
Normal file
3
src/tokenizers/vocab/gemma_vocab.hpp
Normal file
File diff suppressed because one or more lines are too long
@ -1,5 +1,7 @@
|
|||||||
#include "vocab.h"
|
#include "vocab.h"
|
||||||
#include "clip_t5.hpp"
|
#include "clip_t5.hpp"
|
||||||
|
#include "gemma_merges.hpp"
|
||||||
|
#include "gemma_vocab.hpp"
|
||||||
#include "mistral.hpp"
|
#include "mistral.hpp"
|
||||||
#include "qwen.hpp"
|
#include "qwen.hpp"
|
||||||
#include "umt5.hpp"
|
#include "umt5.hpp"
|
||||||
@ -33,3 +35,13 @@ std::string load_umt5_tokenizer_json() {
|
|||||||
std::string json_str(reinterpret_cast<const char*>(umt5_tokenizer_json_str), sizeof(umt5_tokenizer_json_str));
|
std::string json_str(reinterpret_cast<const char*>(umt5_tokenizer_json_str), sizeof(umt5_tokenizer_json_str));
|
||||||
return json_str;
|
return json_str;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string load_gemma_merges() {
|
||||||
|
std::string merges_utf8_str(reinterpret_cast<const char*>(gemma_merges_utf8_c_str), sizeof(gemma_merges_utf8_c_str));
|
||||||
|
return merges_utf8_str;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string load_gemma_vocab_json() {
|
||||||
|
std::string json_str(reinterpret_cast<const char*>(gemma_vocab_json_utf8_c_str), sizeof(gemma_vocab_json_utf8_c_str));
|
||||||
|
return json_str;
|
||||||
|
}
|
||||||
@ -9,5 +9,7 @@ std::string load_mistral_merges();
|
|||||||
std::string load_mistral_vocab_json();
|
std::string load_mistral_vocab_json();
|
||||||
std::string load_t5_tokenizer_json();
|
std::string load_t5_tokenizer_json();
|
||||||
std::string load_umt5_tokenizer_json();
|
std::string load_umt5_tokenizer_json();
|
||||||
|
std::string load_gemma_merges();
|
||||||
|
std::string load_gemma_vocab_json();
|
||||||
|
|
||||||
#endif // __SD_TOKENIZERS_VOCAB_VOCAB_H__
|
#endif // __SD_TOKENIZERS_VOCAB_VOCAB_H__
|
||||||
@ -67,7 +67,9 @@ public:
|
|||||||
|
|
||||||
int get_scale_factor() {
|
int get_scale_factor() {
|
||||||
int scale_factor = 8;
|
int scale_factor = 8;
|
||||||
if (version == VERSION_WAN2_2_TI2V) {
|
if (version == VERSION_LTXAV) {
|
||||||
|
scale_factor = 32;
|
||||||
|
} else if (version == VERSION_WAN2_2_TI2V) {
|
||||||
scale_factor = 16;
|
scale_factor = 16;
|
||||||
} else if (sd_version_uses_flux2_vae(version)) {
|
} else if (sd_version_uses_flux2_vae(version)) {
|
||||||
scale_factor = 16;
|
scale_factor = 16;
|
||||||
@ -213,6 +215,7 @@ public:
|
|||||||
virtual sd::Tensor<float> vae_to_diffusion_latents(const sd::Tensor<float>& latents) = 0;
|
virtual sd::Tensor<float> vae_to_diffusion_latents(const sd::Tensor<float>& latents) = 0;
|
||||||
virtual void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) = 0;
|
virtual void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) = 0;
|
||||||
virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); };
|
virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); };
|
||||||
|
// Hook for toggling temporal tiling. This base version only passes `enabled`
// to SD_UNUSED and does nothing else; being virtual, derived classes can
// override it. NOTE(review): presumably only video VAEs override this - confirm.
virtual void set_temporal_tiling_enabled(bool enabled) { SD_UNUSED(enabled); };
|
||||||
};
|
};
|
||||||
|
|
||||||
struct FakeVAE : public VAE {
|
struct FakeVAE : public VAE {
|
||||||
|
|||||||
16
src/wan.hpp
16
src/wan.hpp
@ -972,10 +972,10 @@ namespace WAN {
|
|||||||
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim, z_dim, {1, 1, 1}));
|
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim, z_dim, {1, 1, 1}));
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* patchify(ggml_context* ctx,
|
static ggml_tensor* patchify(ggml_context* ctx,
|
||||||
ggml_tensor* x,
|
ggml_tensor* x,
|
||||||
int64_t patch_size,
|
int64_t patch_size,
|
||||||
int64_t b = 1) {
|
int64_t b = 1) {
|
||||||
// x: [b*c, f, h*q, w*r]
|
// x: [b*c, f, h*q, w*r]
|
||||||
// return: [b*c*r*q, f, h, w]
|
// return: [b*c*r*q, f, h, w]
|
||||||
if (patch_size == 1) {
|
if (patch_size == 1) {
|
||||||
@ -999,10 +999,10 @@ namespace WAN {
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* unpatchify(ggml_context* ctx,
|
static ggml_tensor* unpatchify(ggml_context* ctx,
|
||||||
ggml_tensor* x,
|
ggml_tensor* x,
|
||||||
int64_t patch_size,
|
int64_t patch_size,
|
||||||
int64_t b = 1) {
|
int64_t b = 1) {
|
||||||
// x: [b*c*r*q, f, h, w]
|
// x: [b*c*r*q, f, h, w]
|
||||||
// return: [b*c, f, h*q, w*r]
|
// return: [b*c, f, h*q, w*r]
|
||||||
if (patch_size == 1) {
|
if (patch_size == 1) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user