mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-03-24 18:28:57 +00:00
Compare commits
6 Commits
2efd19978d
...
a48b4a3ade
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a48b4a3ade | ||
|
|
b87fe13afd | ||
|
|
e50e1f253d | ||
|
|
c6206fb351 | ||
|
|
639091fbe9 | ||
|
|
9293016c9d |
@ -15,6 +15,9 @@ API and command-line option may change frequently.***
|
|||||||
|
|
||||||
## 🔥Important News
|
## 🔥Important News
|
||||||
|
|
||||||
|
* **2026/01/18** 🚀 stable-diffusion.cpp now supports **FLUX.2-klein**
|
||||||
|
👉 Details: [PR #1193](https://github.com/leejet/stable-diffusion.cpp/pull/1193)
|
||||||
|
|
||||||
* **2025/12/01** 🚀 stable-diffusion.cpp now supports **Z-Image**
|
* **2025/12/01** 🚀 stable-diffusion.cpp now supports **Z-Image**
|
||||||
👉 Details: [PR #1020](https://github.com/leejet/stable-diffusion.cpp/pull/1020)
|
👉 Details: [PR #1020](https://github.com/leejet/stable-diffusion.cpp/pull/1020)
|
||||||
|
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
# Running distilled models: SSD1B and SDx.x with tiny U-Nets
|
# Running distilled models: SSD1B, Vega and SDx.x with tiny U-Nets
|
||||||
|
|
||||||
## Preface
|
## Preface
|
||||||
|
|
||||||
These models feature a reduced U-Net architecture. Unlike standard SDXL models, the SSD-1B U-Net contains only one middle block and fewer attention layers in its up- and down-blocks, resulting in significantly smaller file sizes. Using these models can reduce inference time by more than 33%. For more details, refer to Segmind's paper: https://arxiv.org/abs/2401.02677v1.
|
These models feature a reduced U-Net architecture. Unlike standard SDXL models, the SSD-1B and Vega U-Nets contain only one middle block and fewer attention layers in their up- and down-blocks, resulting in significantly smaller file sizes. Using these models can reduce inference time by more than 33%. For more details, refer to Segmind's paper: https://arxiv.org/abs/2401.02677v1.
|
||||||
Similarly, SD1.x- and SD2.x-style models with a tiny U-Net consist of only 6 U-Net blocks, leading to very small files and time savings of up to 50%. For more information, see the paper: https://arxiv.org/pdf/2305.15798.pdf.
|
Similarly, SD1.x- and SD2.x-style models with a tiny U-Net consist of only 6 U-Net blocks, leading to very small files and time savings of up to 50%. For more information, see the paper: https://arxiv.org/pdf/2305.15798.pdf.
|
||||||
|
|
||||||
## SSD1B
|
## SSD1B
|
||||||
@ -17,7 +17,17 @@ Useful LoRAs are also available:
|
|||||||
* https://huggingface.co/seungminh/lora-swarovski-SSD-1B/resolve/main/pytorch_lora_weights.safetensors
|
* https://huggingface.co/seungminh/lora-swarovski-SSD-1B/resolve/main/pytorch_lora_weights.safetensors
|
||||||
* https://huggingface.co/kylielee505/mylcmlorassd/resolve/main/pytorch_lora_weights.safetensors
|
* https://huggingface.co/kylielee505/mylcmlorassd/resolve/main/pytorch_lora_weights.safetensors
|
||||||
|
|
||||||
These files can be used out-of-the-box, unlike the models described in the next section.
|
## Vega
|
||||||
|
|
||||||
|
Segmind's Vega model is available online here:
|
||||||
|
|
||||||
|
* https://huggingface.co/segmind/Segmind-Vega/resolve/main/segmind-vega.safetensors
|
||||||
|
|
||||||
|
VegaRT is an example of an LCM-LoRA:
|
||||||
|
|
||||||
|
* https://huggingface.co/segmind/Segmind-VegaRT/resolve/main/pytorch_lora_weights.safetensors
|
||||||
|
|
||||||
|
Both files can be used out-of-the-box, unlike the models described in the following sections.
|
||||||
|
|
||||||
|
|
||||||
## SD1.x, SD2.x with tiny U-Nets
|
## SD1.x, SD2.x with tiny U-Nets
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
## Using ESRGAN to upscale results
|
## Using ESRGAN to upscale results
|
||||||
|
|
||||||
You can use ESRGAN to upscale the generated images. At the moment, only the [RealESRGAN_x4plus_anime_6B.pth](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth) model is supported. Support for more models of this architecture will be added soon.
|
You can use ESRGAN—such as the model [RealESRGAN_x4plus_anime_6B.pth](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth)—to upscale the generated images and improve their overall resolution and clarity.
|
||||||
|
|
||||||
- Specify the model path using the `--upscale-model PATH` parameter. Example:
|
- Specify the model path using the `--upscale-model PATH` parameter. Example:
|
||||||
|
|
||||||
|
|||||||
22
flux.hpp
22
flux.hpp
@ -748,7 +748,7 @@ namespace Flux {
|
|||||||
int nerf_depth = 4;
|
int nerf_depth = 4;
|
||||||
int nerf_max_freqs = 8;
|
int nerf_max_freqs = 8;
|
||||||
bool use_x0 = false;
|
bool use_x0 = false;
|
||||||
bool use_patch_size_32 = false;
|
bool fake_patch_size_x2 = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct FluxParams {
|
struct FluxParams {
|
||||||
@ -786,7 +786,10 @@ namespace Flux {
|
|||||||
Flux(FluxParams params)
|
Flux(FluxParams params)
|
||||||
: params(params) {
|
: params(params) {
|
||||||
if (params.version == VERSION_CHROMA_RADIANCE) {
|
if (params.version == VERSION_CHROMA_RADIANCE) {
|
||||||
std::pair<int, int> kernel_size = {16, 16};
|
std::pair<int, int> kernel_size = {params.patch_size, params.patch_size};
|
||||||
|
if (params.chroma_radiance_params.fake_patch_size_x2) {
|
||||||
|
kernel_size = {params.patch_size / 2, params.patch_size / 2};
|
||||||
|
}
|
||||||
std::pair<int, int> stride = kernel_size;
|
std::pair<int, int> stride = kernel_size;
|
||||||
|
|
||||||
blocks["img_in_patch"] = std::make_shared<Conv2d>(params.in_channels,
|
blocks["img_in_patch"] = std::make_shared<Conv2d>(params.in_channels,
|
||||||
@ -1082,7 +1085,7 @@ namespace Flux {
|
|||||||
auto img = pad_to_patch_size(ctx, x);
|
auto img = pad_to_patch_size(ctx, x);
|
||||||
auto orig_img = img;
|
auto orig_img = img;
|
||||||
|
|
||||||
if (params.chroma_radiance_params.use_patch_size_32) {
|
if (params.chroma_radiance_params.fake_patch_size_x2) {
|
||||||
// It's supposed to be using GGML_SCALE_MODE_NEAREST, but this seems more stable
|
// It's supposed to be using GGML_SCALE_MODE_NEAREST, but this seems more stable
|
||||||
// Maybe the implementation of nearest-neighbor interpolation in ggml behaves differently than the one in PyTorch?
|
// Maybe the implementation of nearest-neighbor interpolation in ggml behaves differently than the one in PyTorch?
|
||||||
// img = F.interpolate(img, size=(H//2, W//2), mode="nearest")
|
// img = F.interpolate(img, size=(H//2, W//2), mode="nearest")
|
||||||
@ -1304,6 +1307,7 @@ namespace Flux {
|
|||||||
flux_params.use_mlp_silu_act = true;
|
flux_params.use_mlp_silu_act = true;
|
||||||
}
|
}
|
||||||
int64_t head_dim = 0;
|
int64_t head_dim = 0;
|
||||||
|
int64_t actual_radiance_patch_size = -1;
|
||||||
for (auto pair : tensor_storage_map) {
|
for (auto pair : tensor_storage_map) {
|
||||||
std::string tensor_name = pair.first;
|
std::string tensor_name = pair.first;
|
||||||
if (!starts_with(tensor_name, prefix))
|
if (!starts_with(tensor_name, prefix))
|
||||||
@ -1316,10 +1320,13 @@ namespace Flux {
|
|||||||
flux_params.chroma_radiance_params.use_x0 = true;
|
flux_params.chroma_radiance_params.use_x0 = true;
|
||||||
}
|
}
|
||||||
if (tensor_name.find("__32x32__") != std::string::npos) {
|
if (tensor_name.find("__32x32__") != std::string::npos) {
|
||||||
LOG_DEBUG("using patch size 32 prediction");
|
LOG_DEBUG("using patch size 32");
|
||||||
flux_params.chroma_radiance_params.use_patch_size_32 = true;
|
|
||||||
flux_params.patch_size = 32;
|
flux_params.patch_size = 32;
|
||||||
}
|
}
|
||||||
|
if (tensor_name.find("img_in_patch.weight") != std::string::npos) {
|
||||||
|
actual_radiance_patch_size = pair.second.ne[0];
|
||||||
|
LOG_DEBUG("actual radiance patch size: %d", actual_radiance_patch_size);
|
||||||
|
}
|
||||||
if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
|
if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
|
||||||
// Chroma
|
// Chroma
|
||||||
flux_params.is_chroma = true;
|
flux_params.is_chroma = true;
|
||||||
@ -1351,6 +1358,11 @@ namespace Flux {
|
|||||||
head_dim = pair.second.ne[0];
|
head_dim = pair.second.ne[0];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (actual_radiance_patch_size > 0 && actual_radiance_patch_size != flux_params.patch_size) {
|
||||||
|
GGML_ASSERT(flux_params.patch_size == 2 * actual_radiance_patch_size);
|
||||||
|
LOG_DEBUG("using fake x2 patch size");
|
||||||
|
flux_params.chroma_radiance_params.fake_patch_size_x2 = true;
|
||||||
|
}
|
||||||
|
|
||||||
flux_params.num_heads = static_cast<int>(flux_params.hidden_size / head_dim);
|
flux_params.num_heads = static_cast<int>(flux_params.hidden_size / head_dim);
|
||||||
|
|
||||||
|
|||||||
@ -1040,6 +1040,7 @@ SDVersion ModelLoader::get_sd_version() {
|
|||||||
int64_t patch_embedding_channels = 0;
|
int64_t patch_embedding_channels = 0;
|
||||||
bool has_img_emb = false;
|
bool has_img_emb = false;
|
||||||
bool has_middle_block_1 = false;
|
bool has_middle_block_1 = false;
|
||||||
|
bool has_output_block_311 = false;
|
||||||
bool has_output_block_71 = false;
|
bool has_output_block_71 = false;
|
||||||
|
|
||||||
for (auto& [name, tensor_storage] : tensor_storage_map) {
|
for (auto& [name, tensor_storage] : tensor_storage_map) {
|
||||||
@ -1100,6 +1101,9 @@ SDVersion ModelLoader::get_sd_version() {
|
|||||||
tensor_storage.name.find("unet.mid_block.resnets.1.") != std::string::npos) {
|
tensor_storage.name.find("unet.mid_block.resnets.1.") != std::string::npos) {
|
||||||
has_middle_block_1 = true;
|
has_middle_block_1 = true;
|
||||||
}
|
}
|
||||||
|
if (tensor_storage.name.find("model.diffusion_model.output_blocks.3.1.transformer_blocks.1") != std::string::npos) {
|
||||||
|
has_output_block_311 = true;
|
||||||
|
}
|
||||||
if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos) {
|
if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos) {
|
||||||
has_output_block_71 = true;
|
has_output_block_71 = true;
|
||||||
}
|
}
|
||||||
@ -1138,6 +1142,9 @@ SDVersion ModelLoader::get_sd_version() {
|
|||||||
return VERSION_SDXL_PIX2PIX;
|
return VERSION_SDXL_PIX2PIX;
|
||||||
}
|
}
|
||||||
if (!has_middle_block_1) {
|
if (!has_middle_block_1) {
|
||||||
|
if (!has_output_block_311) {
|
||||||
|
return VERSION_SDXL_VEGA;
|
||||||
|
}
|
||||||
return VERSION_SDXL_SSD1B;
|
return VERSION_SDXL_SSD1B;
|
||||||
}
|
}
|
||||||
return VERSION_SDXL;
|
return VERSION_SDXL;
|
||||||
|
|||||||
3
model.h
3
model.h
@ -32,6 +32,7 @@ enum SDVersion {
|
|||||||
VERSION_SDXL,
|
VERSION_SDXL,
|
||||||
VERSION_SDXL_INPAINT,
|
VERSION_SDXL_INPAINT,
|
||||||
VERSION_SDXL_PIX2PIX,
|
VERSION_SDXL_PIX2PIX,
|
||||||
|
VERSION_SDXL_VEGA,
|
||||||
VERSION_SDXL_SSD1B,
|
VERSION_SDXL_SSD1B,
|
||||||
VERSION_SVD,
|
VERSION_SVD,
|
||||||
VERSION_SD3,
|
VERSION_SD3,
|
||||||
@ -66,7 +67,7 @@ static inline bool sd_version_is_sd2(SDVersion version) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static inline bool sd_version_is_sdxl(SDVersion version) {
|
static inline bool sd_version_is_sdxl(SDVersion version) {
|
||||||
if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX || version == VERSION_SDXL_SSD1B) {
|
if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX || version == VERSION_SDXL_SSD1B || version == VERSION_SDXL_VEGA) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
|||||||
@ -35,6 +35,7 @@ const char* model_version_to_str[] = {
|
|||||||
"SDXL",
|
"SDXL",
|
||||||
"SDXL Inpaint",
|
"SDXL Inpaint",
|
||||||
"SDXL Instruct-Pix2Pix",
|
"SDXL Instruct-Pix2Pix",
|
||||||
|
"SDXL (Vega)",
|
||||||
"SDXL (SSD1B)",
|
"SDXL (SSD1B)",
|
||||||
"SVD",
|
"SVD",
|
||||||
"SD3.x",
|
"SD3.x",
|
||||||
@ -623,7 +624,7 @@ public:
|
|||||||
LOG_INFO("Using Conv2d direct in the vae model");
|
LOG_INFO("Using Conv2d direct in the vae model");
|
||||||
first_stage_model->set_conv2d_direct_enabled(true);
|
first_stage_model->set_conv2d_direct_enabled(true);
|
||||||
}
|
}
|
||||||
if (version == VERSION_SDXL &&
|
if (sd_version_is_sdxl(version) &&
|
||||||
(strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale)) {
|
(strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale)) {
|
||||||
float vae_conv_2d_scale = 1.f / 32.f;
|
float vae_conv_2d_scale = 1.f / 32.f;
|
||||||
LOG_WARN(
|
LOG_WARN(
|
||||||
|
|||||||
54
tae.hpp
54
tae.hpp
@ -17,22 +17,43 @@ class TAEBlock : public UnaryBlock {
|
|||||||
protected:
|
protected:
|
||||||
int n_in;
|
int n_in;
|
||||||
int n_out;
|
int n_out;
|
||||||
|
bool use_midblock_gn;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
TAEBlock(int n_in, int n_out)
|
TAEBlock(int n_in, int n_out, bool use_midblock_gn = false)
|
||||||
: n_in(n_in), n_out(n_out) {
|
: n_in(n_in), n_out(n_out), use_midblock_gn(use_midblock_gn) {
|
||||||
blocks["conv.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_out, {3, 3}, {1, 1}, {1, 1}));
|
blocks["conv.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_out, {3, 3}, {1, 1}, {1, 1}));
|
||||||
blocks["conv.2"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1}));
|
blocks["conv.2"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1}));
|
||||||
blocks["conv.4"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1}));
|
blocks["conv.4"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1}));
|
||||||
if (n_in != n_out) {
|
if (n_in != n_out) {
|
||||||
blocks["skip"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_out, {1, 1}, {1, 1}, {1, 1}, {1, 1}, false));
|
blocks["skip"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_out, {1, 1}, {1, 1}, {1, 1}, {1, 1}, false));
|
||||||
}
|
}
|
||||||
|
if (use_midblock_gn) {
|
||||||
|
int n_gn = n_in * 4;
|
||||||
|
blocks["pool.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_gn, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
|
||||||
|
blocks["pool.1"] = std::shared_ptr<GGMLBlock>(new GroupNorm(4, n_gn));
|
||||||
|
// pool.2 is ReLU, handled in forward
|
||||||
|
blocks["pool.3"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_gn, n_in, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||||
// x: [n, n_in, h, w]
|
// x: [n, n_in, h, w]
|
||||||
// return: [n, n_out, h, w]
|
// return: [n, n_out, h, w]
|
||||||
|
|
||||||
|
if (use_midblock_gn) {
|
||||||
|
auto pool_0 = std::dynamic_pointer_cast<Conv2d>(blocks["pool.0"]);
|
||||||
|
auto pool_1 = std::dynamic_pointer_cast<GroupNorm>(blocks["pool.1"]);
|
||||||
|
auto pool_3 = std::dynamic_pointer_cast<Conv2d>(blocks["pool.3"]);
|
||||||
|
|
||||||
|
auto p = pool_0->forward(ctx, x);
|
||||||
|
p = pool_1->forward(ctx, p);
|
||||||
|
p = ggml_relu_inplace(ctx->ggml_ctx, p);
|
||||||
|
p = pool_3->forward(ctx, p);
|
||||||
|
|
||||||
|
x = ggml_add(ctx->ggml_ctx, x, p);
|
||||||
|
}
|
||||||
|
|
||||||
auto conv_0 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.0"]);
|
auto conv_0 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.0"]);
|
||||||
auto conv_2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.2"]);
|
auto conv_2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.2"]);
|
||||||
auto conv_4 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.4"]);
|
auto conv_4 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.4"]);
|
||||||
@ -62,7 +83,7 @@ class TinyEncoder : public UnaryBlock {
|
|||||||
int num_blocks = 3;
|
int num_blocks = 3;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
TinyEncoder(int z_channels = 4)
|
TinyEncoder(int z_channels = 4, bool use_midblock_gn = false)
|
||||||
: z_channels(z_channels) {
|
: z_channels(z_channels) {
|
||||||
int index = 0;
|
int index = 0;
|
||||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, channels, {3, 3}, {1, 1}, {1, 1}));
|
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, channels, {3, 3}, {1, 1}, {1, 1}));
|
||||||
@ -80,7 +101,7 @@ public:
|
|||||||
|
|
||||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false));
|
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false));
|
||||||
for (int i = 0; i < num_blocks; i++) {
|
for (int i = 0; i < num_blocks; i++) {
|
||||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
|
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels, use_midblock_gn));
|
||||||
}
|
}
|
||||||
|
|
||||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, z_channels, {3, 3}, {1, 1}, {1, 1}));
|
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, z_channels, {3, 3}, {1, 1}, {1, 1}));
|
||||||
@ -107,7 +128,7 @@ class TinyDecoder : public UnaryBlock {
|
|||||||
int num_blocks = 3;
|
int num_blocks = 3;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
TinyDecoder(int z_channels = 4)
|
TinyDecoder(int z_channels = 4, bool use_midblock_gn = false)
|
||||||
: z_channels(z_channels) {
|
: z_channels(z_channels) {
|
||||||
int index = 0;
|
int index = 0;
|
||||||
|
|
||||||
@ -115,7 +136,7 @@ public:
|
|||||||
index++; // nn.ReLU()
|
index++; // nn.ReLU()
|
||||||
|
|
||||||
for (int i = 0; i < num_blocks; i++) {
|
for (int i = 0; i < num_blocks; i++) {
|
||||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
|
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels, use_midblock_gn));
|
||||||
}
|
}
|
||||||
index++; // nn.Upsample()
|
index++; // nn.Upsample()
|
||||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, false));
|
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, false));
|
||||||
@ -470,29 +491,44 @@ public:
|
|||||||
class TAESD : public GGMLBlock {
|
class TAESD : public GGMLBlock {
|
||||||
protected:
|
protected:
|
||||||
bool decode_only;
|
bool decode_only;
|
||||||
|
bool taef2 = false;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
TAESD(bool decode_only = true, SDVersion version = VERSION_SD1)
|
TAESD(bool decode_only = true, SDVersion version = VERSION_SD1)
|
||||||
: decode_only(decode_only) {
|
: decode_only(decode_only) {
|
||||||
int z_channels = 4;
|
int z_channels = 4;
|
||||||
|
bool use_midblock_gn = false;
|
||||||
|
taef2 = sd_version_is_flux2(version);
|
||||||
|
|
||||||
if (sd_version_is_dit(version)) {
|
if (sd_version_is_dit(version)) {
|
||||||
z_channels = 16;
|
z_channels = 16;
|
||||||
}
|
}
|
||||||
blocks["decoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyDecoder(z_channels));
|
if (taef2) {
|
||||||
|
z_channels = 32;
|
||||||
|
use_midblock_gn = true;
|
||||||
|
}
|
||||||
|
blocks["decoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyDecoder(z_channels, use_midblock_gn));
|
||||||
|
|
||||||
if (!decode_only) {
|
if (!decode_only) {
|
||||||
blocks["encoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyEncoder(z_channels));
|
blocks["encoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyEncoder(z_channels, use_midblock_gn));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
|
struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
|
||||||
auto decoder = std::dynamic_pointer_cast<TinyDecoder>(blocks["decoder.layers"]);
|
auto decoder = std::dynamic_pointer_cast<TinyDecoder>(blocks["decoder.layers"]);
|
||||||
|
if (taef2) {
|
||||||
|
z = unpatchify(ctx->ggml_ctx, z, 2);
|
||||||
|
}
|
||||||
return decoder->forward(ctx, z);
|
return decoder->forward(ctx, z);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
auto encoder = std::dynamic_pointer_cast<TinyEncoder>(blocks["encoder.layers"]);
|
auto encoder = std::dynamic_pointer_cast<TinyEncoder>(blocks["encoder.layers"]);
|
||||||
return encoder->forward(ctx, x);
|
auto z = encoder->forward(ctx, x);
|
||||||
|
if (taef2) {
|
||||||
|
z = patchify(ctx->ggml_ctx, z, 2);
|
||||||
|
}
|
||||||
|
return z;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
7
unet.hpp
7
unet.hpp
@ -201,6 +201,9 @@ public:
|
|||||||
num_head_channels = 64;
|
num_head_channels = 64;
|
||||||
num_heads = -1;
|
num_heads = -1;
|
||||||
use_linear_projection = true;
|
use_linear_projection = true;
|
||||||
|
if (version == VERSION_SDXL_VEGA) {
|
||||||
|
transformer_depth = {1, 1, 2};
|
||||||
|
}
|
||||||
} else if (version == VERSION_SVD) {
|
} else if (version == VERSION_SVD) {
|
||||||
in_channels = 8;
|
in_channels = 8;
|
||||||
out_channels = 4;
|
out_channels = 4;
|
||||||
@ -319,7 +322,7 @@ public:
|
|||||||
}
|
}
|
||||||
if (!tiny_unet) {
|
if (!tiny_unet) {
|
||||||
blocks["middle_block.0"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
|
blocks["middle_block.0"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
|
||||||
if (version != VERSION_SDXL_SSD1B) {
|
if (version != VERSION_SDXL_SSD1B && version != VERSION_SDXL_VEGA) {
|
||||||
blocks["middle_block.1"] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
|
blocks["middle_block.1"] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
|
||||||
n_head,
|
n_head,
|
||||||
d_head,
|
d_head,
|
||||||
@ -520,7 +523,7 @@ public:
|
|||||||
// middle_block
|
// middle_block
|
||||||
if (!tiny_unet) {
|
if (!tiny_unet) {
|
||||||
h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
||||||
if (version != VERSION_SDXL_SSD1B) {
|
if (version != VERSION_SDXL_SSD1B && version != VERSION_SDXL_VEGA) {
|
||||||
h = attention_layer_forward("middle_block.1", ctx, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
h = attention_layer_forward("middle_block.1", ctx, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
||||||
h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user