mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-09 15:56:39 +00:00
feat: add graph cut markers for LTXAV transformer (#1534)
This commit is contained in:
parent
b3374e6a71
commit
ef92a0027e
@ -1487,6 +1487,9 @@ namespace LTXV {
|
|||||||
->forward(ctx, ggml_ext_scale(ctx->ggml_ctx, av_ca_audio_timestep, av_ca_factor))
|
->forward(ctx, ggml_ext_scale(ctx->ggml_ctx, av_ca_audio_timestep, av_ca_factor))
|
||||||
.first;
|
.first;
|
||||||
|
|
||||||
|
sd::ggml_graph_cut::mark_graph_cut(vx, "ltxav.prelude", "vx");
|
||||||
|
sd::ggml_graph_cut::mark_graph_cut(ax, "ltxav.prelude", "ax");
|
||||||
|
|
||||||
for (int i = 0; i < cfg.num_layers; i++) {
|
for (int i = 0; i < cfg.num_layers; i++) {
|
||||||
auto block = std::dynamic_pointer_cast<BasicAVTransformerBlock>(blocks["transformer_blocks." + std::to_string(i)]);
|
auto block = std::dynamic_pointer_cast<BasicAVTransformerBlock>(blocks["transformer_blocks." + std::to_string(i)]);
|
||||||
auto out = block->forward(ctx,
|
auto out = block->forward(ctx,
|
||||||
@ -1509,6 +1512,8 @@ namespace LTXV {
|
|||||||
a_prompt_timestep_mod);
|
a_prompt_timestep_mod);
|
||||||
vx = out.first;
|
vx = out.first;
|
||||||
ax = out.second;
|
ax = out.second;
|
||||||
|
sd::ggml_graph_cut::mark_graph_cut(vx, "ltxav.transformer_blocks." + std::to_string(i), "vx");
|
||||||
|
sd::ggml_graph_cut::mark_graph_cut(ax, "ltxav.transformer_blocks." + std::to_string(i), "ax");
|
||||||
}
|
}
|
||||||
|
|
||||||
auto v_shift_scale = get_output_scale_shift(ctx, params["scale_shift_table"], v_embedded_time, cfg.hidden_size);
|
auto v_shift_scale = get_output_scale_shift(ctx, params["scale_shift_table"], v_embedded_time, cfg.hidden_size);
|
||||||
|
|||||||
@ -681,8 +681,7 @@ public:
|
|||||||
if (sd_version_is_wan(version) ||
|
if (sd_version_is_wan(version) ||
|
||||||
sd_version_is_qwen_image(version) ||
|
sd_version_is_qwen_image(version) ||
|
||||||
sd_version_is_anima(version) ||
|
sd_version_is_anima(version) ||
|
||||||
sd_version_is_ltxav(version)
|
sd_version_is_ltxav(version)) {
|
||||||
) {
|
|
||||||
return std::make_shared<TinyVideoAutoEncoder>(backend_for(SDBackendModule::VAE),
|
return std::make_shared<TinyVideoAutoEncoder>(backend_for(SDBackendModule::VAE),
|
||||||
params_backend_for(SDBackendModule::VAE),
|
params_backend_for(SDBackendModule::VAE),
|
||||||
tensor_storage_map,
|
tensor_storage_map,
|
||||||
|
|||||||
@ -325,7 +325,7 @@ public:
|
|||||||
int t_downscale = 1;
|
int t_downscale = 1;
|
||||||
TinyVideoEncoder(int z_channels = 4, int patch_size = 1, std::vector<bool> time_downscale = {true, true, false})
|
TinyVideoEncoder(int z_channels = 4, int patch_size = 1, std::vector<bool> time_downscale = {true, true, false})
|
||||||
: z_channels(z_channels), patch_size(patch_size) {
|
: z_channels(z_channels), patch_size(patch_size) {
|
||||||
// self.t_downscale = 2**sum(t.stride == 2 for t in self.encoder if isinstance(t, TPool))
|
// self.t_downscale = 2**sum(t.stride == 2 for t in self.encoder if isinstance(t, TPool))
|
||||||
t_downscale = 1;
|
t_downscale = 1;
|
||||||
for (bool downscale : time_downscale) {
|
for (bool downscale : time_downscale) {
|
||||||
if (downscale) {
|
if (downscale) {
|
||||||
@ -383,7 +383,7 @@ class TinyVideoDecoder : public UnaryBlock {
|
|||||||
static const int num_layers = 3;
|
static const int num_layers = 3;
|
||||||
int channels[num_layers + 1] = {256, 128, 64, 64};
|
int channels[num_layers + 1] = {256, 128, 64, 64};
|
||||||
int patch_size = 1;
|
int patch_size = 1;
|
||||||
int t_upscale = 1;
|
int t_upscale = 1;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
TinyVideoDecoder(int z_channels = 4, int patch_size = 1, std::vector<bool> time_upscale = {false, true, true})
|
TinyVideoDecoder(int z_channels = 4, int patch_size = 1, std::vector<bool> time_upscale = {false, true, true})
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user