Mirror of https://github.com/leejet/stable-diffusion.cpp.git, synced 2026-01-02 18:53:36 +00:00
feat: add taehv support for Wan/Qwen (#937)
This commit is contained in:
parent a23262dfde
commit 9fa7f415df
@@ -31,6 +31,7 @@ Context Options:
   --high-noise-diffusion-model <string>  path to the standalone high noise diffusion model
   --vae <string>                          path to standalone vae model
   --taesd <string>                        path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
+  --tae <string>                          alias of --taesd
   --control-net <string>                  path to control net model
   --embd-dir <string>                     embeddings directory
   --lora-model-dir <string>               lora model directory
@@ -406,6 +406,10 @@ struct SDContextParams {
          "--taesd",
          "path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)",
          &taesd_path},
+        {"",
+         "--tae",
+         "alias of --taesd",
+         &taesd_path},
         {"",
          "--control-net",
          "path to control net model",
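Note that the alias needs no new field: both entries point at the same &taesd_path destination, so whichever flag the user passes fills the same string. A minimal standalone sketch of that option-table pattern (the types and names here are illustrative, not taken from the repo):

    #include <cstring>
    #include <string>

    struct Option {
        const char* flag;
        std::string* dst;  // two flags may share one destination
    };

    int main(int argc, char** argv) {
        std::string taesd_path;
        Option options[] = {
            {"--taesd", &taesd_path},
            {"--tae", &taesd_path},  // alias: same destination as --taesd
        };
        for (int i = 1; i + 1 < argc; i += 2) {
            for (const Option& o : options) {
                if (std::strcmp(argv[i], o.flag) == 0) {
                    *o.dst = argv[i + 1];
                }
            }
        }
        return taesd_path.empty() ? 1 : 0;
    }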
@@ -24,6 +24,7 @@ Context Options:
   --high-noise-diffusion-model <string>  path to the standalone high noise diffusion model
   --vae <string>                          path to standalone vae model
   --taesd <string>                        path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
+  --tae <string>                          alias of --taesd
   --control-net <string>                  path to control net model
   --embd-dir <string>                     embeddings directory
   --lora-model-dir <string>               lora model directory
@@ -562,14 +562,27 @@ public:
         }

         if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
-            first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
-                                                                    offload_params_to_cpu,
-                                                                    tensor_storage_map,
-                                                                    "first_stage_model",
-                                                                    vae_decode_only,
-                                                                    version);
-            first_stage_model->alloc_params_buffer();
-            first_stage_model->get_param_tensors(tensors, "first_stage_model");
+            if (!use_tiny_autoencoder) {
+                first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
+                                                                        offload_params_to_cpu,
+                                                                        tensor_storage_map,
+                                                                        "first_stage_model",
+                                                                        vae_decode_only,
+                                                                        version);
+                first_stage_model->alloc_params_buffer();
+                first_stage_model->get_param_tensors(tensors, "first_stage_model");
+            } else {
+                tae_first_stage = std::make_shared<TinyVideoAutoEncoder>(vae_backend,
+                                                                         offload_params_to_cpu,
+                                                                         tensor_storage_map,
+                                                                         "decoder",
+                                                                         vae_decode_only,
+                                                                         version);
+                if (sd_ctx_params->vae_conv_direct) {
+                    LOG_INFO("Using Conv2d direct in the tae model");
+                    tae_first_stage->set_conv2d_direct_enabled(true);
+                }
+            }
         } else if (version == VERSION_CHROMA_RADIANCE) {
             first_stage_model = std::make_shared<FakeVAE>(vae_backend,
                                                           offload_params_to_cpu);
@@ -596,14 +609,13 @@ public:
             }
             first_stage_model->alloc_params_buffer();
             first_stage_model->get_param_tensors(tensors, "first_stage_model");
-        }
-        if (use_tiny_autoencoder) {
-            tae_first_stage = std::make_shared<TinyAutoEncoder>(vae_backend,
-                                                                offload_params_to_cpu,
-                                                                tensor_storage_map,
-                                                                "decoder.layers",
-                                                                vae_decode_only,
-                                                                version);
-
+        } else if (use_tiny_autoencoder) {
+            tae_first_stage = std::make_shared<TinyImageAutoEncoder>(vae_backend,
+                                                                     offload_params_to_cpu,
+                                                                     tensor_storage_map,
+                                                                     "decoder.layers",
+                                                                     vae_decode_only,
+                                                                     version);
             if (sd_ctx_params->vae_conv_direct) {
                 LOG_INFO("Using Conv2d direct in the tae model");
                 tae_first_stage->set_conv2d_direct_enabled(true);
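Taken together, the two hunks above make the tiny-autoencoder path mutually exclusive with the full VAE: Wan/Qwen models get either WanVAERunner or the new TinyVideoAutoEncoder, and the formerly unconditional `if (use_tiny_autoencoder)` becomes an `else if`, so image models get TinyImageAutoEncoder only when no other first stage was chosen. A condensed, approximate restatement of the resulting selection (the Standard fallback summarizes code outside these hunks):

    #include <cstdio>

    enum class FirstStage { WanVAE, TinyVideo, TinyImage, Fake, Standard };

    // Approximate restatement of the branching after this change; real
    // constructor arguments and types are elided.
    FirstStage pick_first_stage(bool wan_or_qwen, bool chroma_radiance, bool use_tiny_autoencoder) {
        if (wan_or_qwen)
            return use_tiny_autoencoder ? FirstStage::TinyVideo : FirstStage::WanVAE;
        if (chroma_radiance)
            return FirstStage::Fake;  // pixel-space model, no real VAE
        return use_tiny_autoencoder ? FirstStage::TinyImage : FirstStage::Standard;
    }

    int main() {
        printf("%d\n", (int)pick_first_stage(true, false, true));  // 1 == TinyVideo
    }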
@@ -3614,7 +3626,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
             denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
             ggml_set_f32(denoise_mask, 1.f);

-            sd_ctx->sd->process_latent_out(init_latent);
+            if (!sd_ctx->sd->use_tiny_autoencoder)
+                sd_ctx->sd->process_latent_out(init_latent);

             ggml_ext_tensor_iter(init_image_latent, [&](ggml_tensor* t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
                 float value = ggml_ext_tensor_get_f32(t, i0, i1, i2, i3);
@@ -3624,7 +3637,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
                 }
             });

-            sd_ctx->sd->process_latent_in(init_latent);
+            if (!sd_ctx->sd->use_tiny_autoencoder)
+                sd_ctx->sd->process_latent_in(init_latent);

             int64_t t2 = ggml_time_ms();
             LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1);
@@ -3847,7 +3861,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
     struct ggml_tensor* vid = sd_ctx->sd->decode_first_stage(work_ctx, final_latent, true);
     int64_t t5 = ggml_time_ms();
     LOG_INFO("decode_first_stage completed, taking %.2fs", (t5 - t4) * 1.0f / 1000);
-    if (sd_ctx->sd->free_params_immediately) {
+    if (sd_ctx->sd->free_params_immediately && !sd_ctx->sd->use_tiny_autoencoder) {
         sd_ctx->sd->first_stage_model->free_params_buffer();
     }
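The extra condition matters because on the tiny-autoencoder path `first_stage_model` is never assigned (only `tae_first_stage` is), so the old unconditional free would call through an empty shared_ptr. The guard in isolation:

    #include <memory>

    struct Runner {
        void free_params_buffer() {}
    };

    int main() {
        std::shared_ptr<Runner> first_stage_model;  // stays empty when the tae path was taken
        bool free_params_immediately = true;
        bool use_tiny_autoencoder    = true;
        // mirrors the diff: skip the free when the tae holds the weights instead
        if (free_params_immediately && !use_tiny_autoencoder) {
            first_stage_model->free_params_buffer();
        }
        return 0;
    }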
tae.hpp (400 lines changed)
@@ -162,6 +162,311 @@ public:
     }
 };

+class TPool : public UnaryBlock {
+    int stride;
+
+public:
+    TPool(int channels, int stride)
+        : stride(stride) {
+        blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels * stride, channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
+    }
+
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
+        auto conv = std::dynamic_pointer_cast<UnaryBlock>(blocks["conv"]);
+        auto h = x;
+        if (stride != 1) {
+            h = ggml_reshape_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2] * stride, h->ne[3] / stride);
+        }
+        h = conv->forward(ctx, h);
+        return h;
+    }
+};
+
+class TGrow : public UnaryBlock {
+    int stride;
+
+public:
+    TGrow(int channels, int stride)
+        : stride(stride) {
+        blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels * stride, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
+    }
+
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
+        auto conv = std::dynamic_pointer_cast<UnaryBlock>(blocks["conv"]);
+        auto h = conv->forward(ctx, x);
+        if (stride != 1) {
+            h = ggml_reshape_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2] / stride, h->ne[3] * stride);
+        }
+        return h;
+    }
+};
+
+class MemBlock : public GGMLBlock {
+    bool has_skip_conv = false;
+
+public:
+    MemBlock(int channels, int out_channels)
+        : has_skip_conv(channels != out_channels) {
+        blocks["conv.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels * 2, out_channels, {3, 3}, {1, 1}, {1, 1}));
+        blocks["conv.2"] = std::shared_ptr<GGMLBlock>(new Conv2d(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
+        blocks["conv.4"] = std::shared_ptr<GGMLBlock>(new Conv2d(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
+        if (has_skip_conv) {
+            blocks["skip"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
+        }
+    }
+
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* past) {
+        // x: [n, channels, h, w]
+        auto conv0 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.0"]);
+        auto conv1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.2"]);
+        auto conv2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.4"]);
+
+        auto h = ggml_concat(ctx->ggml_ctx, x, past, 2);
+        h = conv0->forward(ctx, h);
+        h = ggml_relu_inplace(ctx->ggml_ctx, h);
+        h = conv1->forward(ctx, h);
+        h = ggml_relu_inplace(ctx->ggml_ctx, h);
+        h = conv2->forward(ctx, h);
+
+        auto skip = x;
+        if (has_skip_conv) {
+            auto skip_conv = std::dynamic_pointer_cast<Conv2d>(blocks["skip"]);
+            skip = skip_conv->forward(ctx, x);
+        }
+        h = ggml_add_inplace(ctx->ggml_ctx, h, skip);
+        h = ggml_relu_inplace(ctx->ggml_ctx, h);
+        return h;
+    }
+};
+
+struct ggml_tensor* patchify(struct ggml_context* ctx,
+                             struct ggml_tensor* x,
+                             int64_t patch_size,
+                             int64_t b = 1) {
+    // x: [f, b*c, h*q, w*r]
+    // return: [f, b*c*r*q, h, w]
+    if (patch_size == 1) {
+        return x;
+    }
+    int64_t r = patch_size;
+    int64_t q = patch_size;
+
+    int64_t W = x->ne[0];
+    int64_t H = x->ne[1];
+    int64_t C = x->ne[2];
+    int64_t f = x->ne[3];
+
+    int64_t w = W / r;
+    int64_t h = H / q;
+
+    x = ggml_reshape_4d(ctx, x, W, q, h, C * f);                         // [W, q, h, C*f]
+    x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3));  // [W, h, q, C*f]
+    x = ggml_reshape_4d(ctx, x, r, w, h, q * C * f);                     // [r, w, h, q*C*f]
+    x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 2, 0, 3));  // [w, h, r, q*C*f]
+    x = ggml_reshape_4d(ctx, x, w, h, r * q * C, f);                     // [f, b*c*r*q, h, w]
+
+    return x;
+}
+
+struct ggml_tensor* unpatchify(struct ggml_context* ctx,
+                               struct ggml_tensor* x,
+                               int64_t patch_size,
+                               int64_t b = 1) {
+    // x: [f, b*c*r*q, h, w]
+    // return: [f, b*c, h*q, w*r]
+    if (patch_size == 1) {
+        return x;
+    }
+    int64_t r = patch_size;
+    int64_t q = patch_size;
+    int64_t c = x->ne[2] / b / q / r;
+    int64_t f = x->ne[3];
+    int64_t h = x->ne[1];
+    int64_t w = x->ne[0];
+
+    x = ggml_reshape_4d(ctx, x, w, h, r, q * c * b * f);                 // [q*c*b*f, r, h, w]
+    x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3));  // [r, w, h, q*c*b*f]
+    x = ggml_reshape_4d(ctx, x, r * w, h, q, c * b * f);                 // [c*b*f, q, h, r*w]
+    x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3));  // [r*w, q, h, c*b*f]
+    x = ggml_reshape_4d(ctx, x, r * w, q * h, c * b, f);
+
+    return x;
+}
+
+class TinyVideoEncoder : public UnaryBlock {
+    int in_channels = 3;
+    int hidden = 64;
+    int z_channels = 4;
+    int num_blocks = 3;
+    int num_layers = 3;
+    int patch_size = 1;
+
+public:
+    TinyVideoEncoder(int z_channels = 4, int patch_size = 1)
+        : z_channels(z_channels), patch_size(patch_size) {
+        int index = 0;
+        blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels * patch_size * patch_size, hidden, {3, 3}, {1, 1}, {1, 1}));
+        index++;  // nn.ReLU()
+        for (int i = 0; i < num_layers; i++) {
+            int stride = i == num_layers - 1 ? 1 : 2;
+            blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TPool(hidden, stride));
+            blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(hidden, hidden, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false));
+            for (int j = 0; j < num_blocks; j++) {
+                blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new MemBlock(hidden, hidden));
+            }
+        }
+        blocks[std::to_string(index)] = std::shared_ptr<GGMLBlock>(new Conv2d(hidden, z_channels, {3, 3}, {1, 1}, {1, 1}));
+    }
+
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) override {
+        auto first_conv = std::dynamic_pointer_cast<Conv2d>(blocks["0"]);
+
+        if (patch_size > 1) {
+            z = patchify(ctx->ggml_ctx, z, patch_size, 1);
+        }
+
+        auto h = first_conv->forward(ctx, z);
+        h = ggml_relu_inplace(ctx->ggml_ctx, h);
+
+        int index = 2;
+        for (int i = 0; i < num_layers; i++) {
+            auto pool = std::dynamic_pointer_cast<UnaryBlock>(blocks[std::to_string(index++)]);
+            auto conv = std::dynamic_pointer_cast<UnaryBlock>(blocks[std::to_string(index++)]);
+
+            h = pool->forward(ctx, h);
+            h = conv->forward(ctx, h);
+            for (int j = 0; j < num_blocks; j++) {
+                auto block = std::dynamic_pointer_cast<MemBlock>(blocks[std::to_string(index++)]);
+                auto mem = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1, 0);
+                mem = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0);
+                h = block->forward(ctx, h, mem);
+            }
+        }
+        auto last_conv = std::dynamic_pointer_cast<Conv2d>(blocks[std::to_string(index)]);
+        h = last_conv->forward(ctx, h);
+        return h;
+    }
+};
+
+class TinyVideoDecoder : public UnaryBlock {
+    int z_channels = 4;
+    int out_channels = 3;
+    int num_blocks = 3;
+    static const int num_layers = 3;
+    int channels[num_layers + 1] = {256, 128, 64, 64};
+    int patch_size = 1;
+
+public:
+    TinyVideoDecoder(int z_channels = 4, int patch_size = 1)
+        : z_channels(z_channels), patch_size(patch_size) {
+        int index = 1;  // Clamp()
+        blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, channels[0], {3, 3}, {1, 1}, {1, 1}));
+        index++;  // nn.ReLU()
+        for (int i = 0; i < num_layers; i++) {
+            int stride = i == 0 ? 1 : 2;
+            for (int j = 0; j < num_blocks; j++) {
+                blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new MemBlock(channels[i], channels[i]));
+            }
+            index++;  // nn.Upsample()
+            blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TGrow(channels[i], stride));
+            blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels[i], channels[i + 1], {3, 3}, {1, 1}, {1, 1}, {1, 1}, false));
+        }
+        index++;  // nn.ReLU()
+        blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels[num_layers], out_channels * patch_size * patch_size, {3, 3}, {1, 1}, {1, 1}));
+    }
+
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) override {
+        auto first_conv = std::dynamic_pointer_cast<Conv2d>(blocks["1"]);
+
+        // Clamp()
+        auto h = ggml_scale_inplace(ctx->ggml_ctx,
+                                    ggml_tanh_inplace(ctx->ggml_ctx,
+                                                      ggml_scale(ctx->ggml_ctx, z, 1.0f / 3.0f)),
+                                    3.0f);
+
+        h = first_conv->forward(ctx, h);
+        h = ggml_relu_inplace(ctx->ggml_ctx, h);
+        int index = 3;
+        for (int i = 0; i < num_layers; i++) {
+            for (int j = 0; j < num_blocks; j++) {
+                auto block = std::dynamic_pointer_cast<MemBlock>(blocks[std::to_string(index++)]);
+                auto mem = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1, 0);
+                mem = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0);
+                h = block->forward(ctx, h, mem);
+            }
+            // upsample
+            index++;
+            h = ggml_upscale(ctx->ggml_ctx, h, 2, GGML_SCALE_MODE_NEAREST);
+            auto block = std::dynamic_pointer_cast<UnaryBlock>(blocks[std::to_string(index++)]);
+            h = block->forward(ctx, h);
+            block = std::dynamic_pointer_cast<UnaryBlock>(blocks[std::to_string(index++)]);
+            h = block->forward(ctx, h);
+        }
+        h = ggml_relu_inplace(ctx->ggml_ctx, h);
+
+        auto last_conv = std::dynamic_pointer_cast<Conv2d>(blocks[std::to_string(++index)]);
+        h = last_conv->forward(ctx, h);
+        if (patch_size > 1) {
+            h = unpatchify(ctx->ggml_ctx, h, patch_size, 1);
+        }
+        // shape(W, H, 3, 3 + T) => shape(W, H, 3, T)
+        h = ggml_view_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2], h->ne[3] - 3, h->nb[1], h->nb[2], h->nb[3], 3 * h->nb[3]);
+        return h;
+    }
+};
+
+class TAEHV : public GGMLBlock {
+protected:
+    bool decode_only;
+    SDVersion version;
+
+public:
+    TAEHV(bool decode_only = true, SDVersion version = VERSION_WAN2)
+        : decode_only(decode_only), version(version) {
+        int z_channels = 16;
+        int patch = 1;
+        if (version == VERSION_WAN2_2_TI2V) {
+            z_channels = 48;
+            patch = 2;
+        }
+        blocks["decoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoDecoder(z_channels, patch));
+        if (!decode_only) {
+            blocks["encoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoEncoder(z_channels, patch));
+        }
+    }
+
+    struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
+        auto decoder = std::dynamic_pointer_cast<TinyVideoDecoder>(blocks["decoder"]);
+        if (sd_version_is_wan(version)) {
+            // (W, H, C, T) -> (W, H, T, C)
+            z = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, z, 0, 1, 3, 2));
+        }
+        auto result = decoder->forward(ctx, z);
+        if (sd_version_is_wan(version)) {
+            // (W, H, C, T) -> (W, H, T, C)
+            result = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, result, 0, 1, 3, 2));
+        }
+        return result;
+    }
+
+    struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
+        auto encoder = std::dynamic_pointer_cast<TinyVideoEncoder>(blocks["encoder"]);
+        // (W, H, T, C) -> (W, H, C, T)
+        x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2));
+        int64_t num_frames = x->ne[3];
+        if (num_frames % 4) {
+            // pad to multiple of 4 at the end
+            auto last_frame = ggml_view_4d(ctx->ggml_ctx, x, x->ne[0], x->ne[1], x->ne[2], 1, x->nb[1], x->nb[2], x->nb[3], (num_frames - 1) * x->nb[3]);
+            for (int i = 0; i < 4 - num_frames % 4; i++) {
+                x = ggml_concat(ctx->ggml_ctx, x, last_frame, 3);
+            }
+        }
+        x = encoder->forward(ctx, x);
+        x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2));
+        return x;
+    }
+};
+
 class TAESD : public GGMLBlock {
 protected:
     bool decode_only;
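The least obvious ggml idiom above is the ggml_pad_ext / ggml_view_4d pair feeding every MemBlock: padding one slot at the front of the frame dimension (lp3 = 1) and then viewing the first ne[3] frames yields a tensor whose frame t holds the activations of frame t - 1, with zeros at t = 0. That is the "memory" each MemBlock concatenates with the current frame, i.e. a causal one-frame shift. A plain C++ sketch of the same semantics, independent of ggml:

    #include <cstdio>
    #include <vector>

    // Shift a [frames][features] buffer back by one frame, zero-filling frame 0.
    // Mirrors ggml_pad_ext(ctx, h, 0,0, 0,0, 0,0, /*lp3=*/1, /*rp3=*/0) followed
    // by a ggml_view_4d over the first ne[3] frames of the padded tensor.
    std::vector<std::vector<float>> shift_back_one_frame(const std::vector<std::vector<float>>& x) {
        size_t feats = x.empty() ? 0 : x[0].size();
        std::vector<std::vector<float>> mem(x.size(), std::vector<float>(feats, 0.0f));
        for (size_t t = 1; t < x.size(); t++) {
            mem[t] = x[t - 1];  // frame t sees frame t-1
        }
        return mem;  // frame 0 keeps zeros: no past yet
    }

    int main() {
        std::vector<std::vector<float>> x = {{1, 1}, {2, 2}, {3, 3}};
        for (auto& f : shift_back_one_frame(x))
            printf("%g %g\n", f[0], f[1]);  // prints 0 0 / 1 1 / 2 2
    }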
@@ -192,18 +497,30 @@ public:
 };

 struct TinyAutoEncoder : public GGMLRunner {
+    TinyAutoEncoder(ggml_backend_t backend, bool offload_params_to_cpu)
+        : GGMLRunner(backend, offload_params_to_cpu) {}
+    virtual bool compute(const int n_threads,
+                         struct ggml_tensor* z,
+                         bool decode_graph,
+                         struct ggml_tensor** output,
+                         struct ggml_context* output_ctx = nullptr) = 0;
+
+    virtual bool load_from_file(const std::string& file_path, int n_threads) = 0;
+};
+
+struct TinyImageAutoEncoder : public TinyAutoEncoder {
     TAESD taesd;
     bool decode_only = false;

-    TinyAutoEncoder(ggml_backend_t backend,
+    TinyImageAutoEncoder(ggml_backend_t backend,
                     bool offload_params_to_cpu,
                     const String2TensorStorage& tensor_storage_map,
                     const std::string prefix,
                     bool decoder_only = true,
                     SDVersion version = VERSION_SD1)
         : decode_only(decoder_only),
           taesd(decoder_only, version),
-          GGMLRunner(backend, offload_params_to_cpu) {
+          TinyAutoEncoder(backend, offload_params_to_cpu) {
         taesd.init(params_ctx, tensor_storage_map, prefix);
     }

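The point of this refactor is that callers can now hold one polymorphic pointer: `tae_first_stage` is a TinyAutoEncoder and dispatches to either the image (TAESD) or video (TAEHV) runner through the virtual compute / load_from_file interface. The pattern reduced to its skeleton (class names here are illustrative stand-ins, not from the repo):

    #include <cstdio>
    #include <memory>

    struct Tiny {  // stand-in for the abstract TinyAutoEncoder
        virtual bool compute(bool decode_graph) = 0;
        virtual ~Tiny() = default;
    };

    struct ImageTae : Tiny {  // stand-in for TinyImageAutoEncoder
        bool compute(bool) override { puts("taesd"); return true; }
    };

    struct VideoTae : Tiny {  // stand-in for TinyVideoAutoEncoder
        bool compute(bool) override { puts("taehv"); return true; }
    };

    int main() {
        bool video_model = true;
        std::shared_ptr<Tiny> tae;
        if (video_model)
            tae = std::make_shared<VideoTae>();
        else
            tae = std::make_shared<ImageTae>();
        tae->compute(/*decode_graph=*/true);  // caller never cares which one
    }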
@@ -260,4 +577,73 @@ struct TinyAutoEncoder : public GGMLRunner {
     }
 };

+struct TinyVideoAutoEncoder : public TinyAutoEncoder {
+    TAEHV taehv;
+    bool decode_only = false;
+
+    TinyVideoAutoEncoder(ggml_backend_t backend,
+                         bool offload_params_to_cpu,
+                         const String2TensorStorage& tensor_storage_map,
+                         const std::string prefix,
+                         bool decoder_only = true,
+                         SDVersion version = VERSION_WAN2)
+        : decode_only(decoder_only),
+          taehv(decoder_only, version),
+          TinyAutoEncoder(backend, offload_params_to_cpu) {
+        taehv.init(params_ctx, tensor_storage_map, prefix);
+    }
+
+    std::string get_desc() override {
+        return "taehv";
+    }
+
+    bool load_from_file(const std::string& file_path, int n_threads) {
+        LOG_INFO("loading taehv from '%s', decode_only = %s", file_path.c_str(), decode_only ? "true" : "false");
+        alloc_params_buffer();
+        std::map<std::string, ggml_tensor*> taehv_tensors;
+        taehv.get_param_tensors(taehv_tensors);
+        std::set<std::string> ignore_tensors;
+        if (decode_only) {
+            ignore_tensors.insert("encoder.");
+        }
+
+        ModelLoader model_loader;
+        if (!model_loader.init_from_file(file_path)) {
+            LOG_ERROR("init taehv model loader from file failed: '%s'", file_path.c_str());
+            return false;
+        }
+
+        bool success = model_loader.load_tensors(taehv_tensors, ignore_tensors, n_threads);
+
+        if (!success) {
+            LOG_ERROR("load tae tensors from model loader failed");
+            return false;
+        }
+
+        LOG_INFO("taehv model loaded");
+        return success;
+    }
+
+    struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
+        struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
+        z = to_backend(z);
+        auto runner_ctx = get_context();
+        struct ggml_tensor* out = decode_graph ? taehv.decode(&runner_ctx, z) : taehv.encode(&runner_ctx, z);
+        ggml_build_forward_expand(gf, out);
+        return gf;
+    }
+
+    bool compute(const int n_threads,
+                 struct ggml_tensor* z,
+                 bool decode_graph,
+                 struct ggml_tensor** output,
+                 struct ggml_context* output_ctx = nullptr) {
+        auto get_graph = [&]() -> struct ggml_cgraph* {
+            return build_graph(z, decode_graph);
+        };
+
+        return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
+    }
+};
+
 #endif // __TAE_HPP__
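One arithmetic detail worth spelling out: TAEHV::encode repeats the last frame until the frame count is a multiple of 4, which matches the x4 temporal reduction of the encoder's two stride-2 TPool stages. The rounding in isolation:

    #include <cstdio>

    // Same bound as the diff's `for (int i = 0; i < 4 - num_frames % 4; i++)`,
    // guarded so that exact multiples of 4 get no padding.
    int padded_frames(int num_frames) {
        int pad = num_frames % 4 ? 4 - num_frames % 4 : 0;
        return num_frames + pad;
    }

    int main() {
        for (int n : {1, 4, 5, 17})
            printf("%2d -> %2d\n", n, padded_frames(n));  // 1->4, 4->4, 5->8, 17->20
    }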