feat: add graph cut markers for LTXAV transformer (#1534)

This commit is contained in:
leejet 2026-05-20 23:22:10 +08:00 committed by GitHub
parent b3374e6a71
commit ef92a0027e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 9 additions and 5 deletions

View File

@ -1487,6 +1487,9 @@ namespace LTXV {
->forward(ctx, ggml_ext_scale(ctx->ggml_ctx, av_ca_audio_timestep, av_ca_factor))
.first;
sd::ggml_graph_cut::mark_graph_cut(vx, "ltxav.prelude", "vx");
sd::ggml_graph_cut::mark_graph_cut(ax, "ltxav.prelude", "ax");
for (int i = 0; i < cfg.num_layers; i++) {
auto block = std::dynamic_pointer_cast<BasicAVTransformerBlock>(blocks["transformer_blocks." + std::to_string(i)]);
auto out = block->forward(ctx,
@ -1509,6 +1512,8 @@ namespace LTXV {
a_prompt_timestep_mod);
vx = out.first;
ax = out.second;
sd::ggml_graph_cut::mark_graph_cut(vx, "ltxav.transformer_blocks." + std::to_string(i), "vx");
sd::ggml_graph_cut::mark_graph_cut(ax, "ltxav.transformer_blocks." + std::to_string(i), "ax");
}
auto v_shift_scale = get_output_scale_shift(ctx, params["scale_shift_table"], v_embedded_time, cfg.hidden_size);

View File

@ -681,8 +681,7 @@ public:
if (sd_version_is_wan(version) ||
sd_version_is_qwen_image(version) ||
sd_version_is_anima(version) ||
sd_version_is_ltxav(version)
) {
sd_version_is_ltxav(version)) {
return std::make_shared<TinyVideoAutoEncoder>(backend_for(SDBackendModule::VAE),
params_backend_for(SDBackendModule::VAE),
tensor_storage_map,