mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-13 05:48:56 +00:00
feat: add SSD1B and tiny-sd support (#897)
* feat: add code and doc for running SSD1B models * Added some more lines to support SD1.x with TINY U-Nets too. * support SSD-1B.safetensors * fix sdv1.5 diffusers format loader --------- Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
parent
faabc5ad3c
commit
062490aa7c
@ -35,6 +35,7 @@ API and command-line option may change frequently.***
|
|||||||
- Image Models
|
- Image Models
|
||||||
- SD1.x, SD2.x, [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo)
|
- SD1.x, SD2.x, [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo)
|
||||||
- SDXL, [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo)
|
- SDXL, [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo)
|
||||||
|
- [some SD1.x and SDXL distilled models](./docs/distilled_sd.md)
|
||||||
- [SD3/SD3.5](./docs/sd3.md)
|
- [SD3/SD3.5](./docs/sd3.md)
|
||||||
- [Flux-dev/Flux-schnell](./docs/flux.md)
|
- [Flux-dev/Flux-schnell](./docs/flux.md)
|
||||||
- [Chroma](./docs/chroma.md)
|
- [Chroma](./docs/chroma.md)
|
||||||
|
|||||||
86
docs/distilled_sd.md
Normal file
86
docs/distilled_sd.md
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
# Running distilled models: SSD1B and SD1.x with tiny U-Nets
|
||||||
|
|
||||||
|
## Preface
|
||||||
|
|
||||||
|
This kind of models have a reduced U-Net part.
|
||||||
|
Unlike other SDXL models the U-Net of SSD1B has only one middle block and lesser attention layers in up and down blocks, resulting in relatively smaller files. Running these models saves more than 33% of the time. For more details, refer to Segmind's paper on https://arxiv.org/abs/2401.02677v1 .
|
||||||
|
Unlike other SD 1.x models Tiny-UNet models consist of only 6 U-Net blocks, resulting in relatively smaller files (approximately 1 GB). Running these models saves almost 50% of the time. For more details, refer to the paper: https://arxiv.org/pdf/2305.15798.pdf .
|
||||||
|
|
||||||
|
## SSD1B
|
||||||
|
|
||||||
|
Unfortunately not all of this models follow the standard model parameter naming mapping.
|
||||||
|
Anyway there are some very useful SSD1B models available online, such as:
|
||||||
|
|
||||||
|
* https://huggingface.co/segmind/SSD-1B/resolve/main/SSD-1B-A1111.safetensors
|
||||||
|
* https://huggingface.co/hassenhamdi/SSD-1B-fp8_e4m3fn/resolve/main/SSD-1B_fp8_e4m3fn.safetensors
|
||||||
|
|
||||||
|
Also there are useful LORAs available:
|
||||||
|
|
||||||
|
* https://huggingface.co/seungminh/lora-swarovski-SSD-1B/resolve/main/pytorch_lora_weights.safetensors
|
||||||
|
* https://huggingface.co/kylielee505/mylcmlorassd/resolve/main/pytorch_lora_weights.safetensors
|
||||||
|
|
||||||
|
You can use this files **out-of-the-box** - unlike models in next section.
|
||||||
|
|
||||||
|
|
||||||
|
## SD1.x with tiny U-Nets
|
||||||
|
|
||||||
|
There are some Tiny SD 1.x models available online, such as:
|
||||||
|
|
||||||
|
* https://huggingface.co/segmind/tiny-sd
|
||||||
|
* https://huggingface.co/segmind/portrait-finetuned
|
||||||
|
* https://huggingface.co/nota-ai/bk-sdm-tiny
|
||||||
|
|
||||||
|
These models need some conversion, for example because partially tensors are **non contiguous** stored. To create a usable checkpoint file, follow these **easy** steps:
|
||||||
|
|
||||||
|
### Download model from Hugging Face
|
||||||
|
|
||||||
|
Download the model using Python on your computer, for example this way:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffusers import StableDiffusionPipeline
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained("segmind/tiny-sd")
|
||||||
|
unet=pipe.unet
|
||||||
|
for param in unet.parameters():
|
||||||
|
param.data = param.data.contiguous() # <- important here
|
||||||
|
pipe.save_pretrained("segmindtiny-sd", safe_serialization=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Convert that to a ckpt file
|
||||||
|
|
||||||
|
To convert the downloaded model to a checkpoint file, you need another Python script. Download the conversion script from here:
|
||||||
|
|
||||||
|
* https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/scripts/convert_diffusers_to_original_stable_diffusion.py
|
||||||
|
|
||||||
|
|
||||||
|
### Run convert script
|
||||||
|
|
||||||
|
Now, run that conversion script:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python convert_diffusers_to_original_stable_diffusion.py \
|
||||||
|
--model_path ./segmindtiny-sd \
|
||||||
|
--checkpoint_path ./segmind_tiny-sd.ckpt --half
|
||||||
|
```
|
||||||
|
|
||||||
|
The file **segmind_tiny-sd.ckpt** will be generated and is now ready to use with sd.cpp
|
||||||
|
|
||||||
|
You can follow a similar process for other models mentioned above from Hugging Face.
|
||||||
|
|
||||||
|
|
||||||
|
### Another ckpt file on the net
|
||||||
|
|
||||||
|
There is another model file available online:
|
||||||
|
|
||||||
|
* https://huggingface.co/ClashSAN/small-sd/resolve/main/tinySDdistilled.ckpt
|
||||||
|
|
||||||
|
If you want to use that, you have to adjust some **non-contiguous tensors** first:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
ckpt = torch.load("tinySDdistilled.ckpt", map_location=torch.device('cpu'))
|
||||||
|
for key, value in ckpt['state_dict'].items():
|
||||||
|
if isinstance(value, torch.Tensor):
|
||||||
|
ckpt['state_dict'][key] = value.contiguous()
|
||||||
|
torch.save(ckpt, "tinySDdistilled_fixed.ckpt")
|
||||||
|
```
|
||||||
25
model.cpp
25
model.cpp
@ -330,6 +330,10 @@ std::string convert_cond_model_name(const std::string& name) {
|
|||||||
return new_name;
|
return new_name;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (new_name == "model.text_projection.weight") {
|
||||||
|
new_name = "transformer.text_model.text_projection";
|
||||||
|
}
|
||||||
|
|
||||||
if (open_clip_to_hf_clip_model.find(new_name) != open_clip_to_hf_clip_model.end()) {
|
if (open_clip_to_hf_clip_model.find(new_name) != open_clip_to_hf_clip_model.end()) {
|
||||||
new_name = open_clip_to_hf_clip_model[new_name];
|
new_name = open_clip_to_hf_clip_model[new_name];
|
||||||
}
|
}
|
||||||
@ -623,6 +627,14 @@ std::string convert_tensor_name(std::string name) {
|
|||||||
if (starts_with(name, "diffusion_model")) {
|
if (starts_with(name, "diffusion_model")) {
|
||||||
name = "model." + name;
|
name = "model." + name;
|
||||||
}
|
}
|
||||||
|
if (starts_with(name, "model.diffusion_model.up_blocks.0.attentions.0.")) {
|
||||||
|
name.replace(0, sizeof("model.diffusion_model.up_blocks.0.attentions.0.") - 1,
|
||||||
|
"model.diffusion_model.output_blocks.0.1.");
|
||||||
|
}
|
||||||
|
if (starts_with(name, "model.diffusion_model.up_blocks.0.attentions.1.")) {
|
||||||
|
name.replace(0, sizeof("model.diffusion_model.up_blocks.0.attentions.1.") - 1,
|
||||||
|
"model.diffusion_model.output_blocks.1.1.");
|
||||||
|
}
|
||||||
// size_t pos = name.find("lora_A");
|
// size_t pos = name.find("lora_A");
|
||||||
// if (pos != std::string::npos) {
|
// if (pos != std::string::npos) {
|
||||||
// name.replace(pos, strlen("lora_A"), "lora_up");
|
// name.replace(pos, strlen("lora_A"), "lora_up");
|
||||||
@ -1776,6 +1788,7 @@ SDVersion ModelLoader::get_sd_version() {
|
|||||||
bool is_wan = false;
|
bool is_wan = false;
|
||||||
int64_t patch_embedding_channels = 0;
|
int64_t patch_embedding_channels = 0;
|
||||||
bool has_img_emb = false;
|
bool has_img_emb = false;
|
||||||
|
bool has_middle_block_1 = false;
|
||||||
|
|
||||||
for (auto& tensor_storage : tensor_storages) {
|
for (auto& tensor_storage : tensor_storages) {
|
||||||
if (!(is_xl || is_flux)) {
|
if (!(is_xl || is_flux)) {
|
||||||
@ -1822,6 +1835,10 @@ SDVersion ModelLoader::get_sd_version() {
|
|||||||
return VERSION_SVD;
|
return VERSION_SVD;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (tensor_storage.name.find("model.diffusion_model.middle_block.1.") != std::string::npos ||
|
||||||
|
tensor_storage.name.find("unet.mid_block.resnets.1.") != std::string::npos) {
|
||||||
|
has_middle_block_1 = true;
|
||||||
|
}
|
||||||
if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
|
if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
|
||||||
tensor_storage.name == "cond_stage_model.model.token_embedding.weight" ||
|
tensor_storage.name == "cond_stage_model.model.token_embedding.weight" ||
|
||||||
tensor_storage.name == "text_model.embeddings.token_embedding.weight" ||
|
tensor_storage.name == "text_model.embeddings.token_embedding.weight" ||
|
||||||
@ -1834,7 +1851,7 @@ SDVersion ModelLoader::get_sd_version() {
|
|||||||
if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight" || tensor_storage.name == "unet.conv_in.weight") {
|
if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight" || tensor_storage.name == "unet.conv_in.weight") {
|
||||||
input_block_weight = tensor_storage;
|
input_block_weight = tensor_storage;
|
||||||
input_block_checked = true;
|
input_block_checked = true;
|
||||||
if (is_xl || is_flux) {
|
if (is_flux) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1858,6 +1875,9 @@ SDVersion ModelLoader::get_sd_version() {
|
|||||||
if (is_ip2p) {
|
if (is_ip2p) {
|
||||||
return VERSION_SDXL_PIX2PIX;
|
return VERSION_SDXL_PIX2PIX;
|
||||||
}
|
}
|
||||||
|
if (!has_middle_block_1) {
|
||||||
|
return VERSION_SDXL_SSD1B;
|
||||||
|
}
|
||||||
return VERSION_SDXL;
|
return VERSION_SDXL;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1881,6 +1901,9 @@ SDVersion ModelLoader::get_sd_version() {
|
|||||||
if (is_ip2p) {
|
if (is_ip2p) {
|
||||||
return VERSION_SD1_PIX2PIX;
|
return VERSION_SD1_PIX2PIX;
|
||||||
}
|
}
|
||||||
|
if (!has_middle_block_1) {
|
||||||
|
return VERSION_SD1_TINY_UNET;
|
||||||
|
}
|
||||||
return VERSION_SD1;
|
return VERSION_SD1;
|
||||||
} else if (token_embedding_weight.ne[0] == 1024) {
|
} else if (token_embedding_weight.ne[0] == 1024) {
|
||||||
if (is_inpaint) {
|
if (is_inpaint) {
|
||||||
|
|||||||
6
model.h
6
model.h
@ -23,11 +23,13 @@ enum SDVersion {
|
|||||||
VERSION_SD1,
|
VERSION_SD1,
|
||||||
VERSION_SD1_INPAINT,
|
VERSION_SD1_INPAINT,
|
||||||
VERSION_SD1_PIX2PIX,
|
VERSION_SD1_PIX2PIX,
|
||||||
|
VERSION_SD1_TINY_UNET,
|
||||||
VERSION_SD2,
|
VERSION_SD2,
|
||||||
VERSION_SD2_INPAINT,
|
VERSION_SD2_INPAINT,
|
||||||
VERSION_SDXL,
|
VERSION_SDXL,
|
||||||
VERSION_SDXL_INPAINT,
|
VERSION_SDXL_INPAINT,
|
||||||
VERSION_SDXL_PIX2PIX,
|
VERSION_SDXL_PIX2PIX,
|
||||||
|
VERSION_SDXL_SSD1B,
|
||||||
VERSION_SVD,
|
VERSION_SVD,
|
||||||
VERSION_SD3,
|
VERSION_SD3,
|
||||||
VERSION_FLUX,
|
VERSION_FLUX,
|
||||||
@ -42,7 +44,7 @@ enum SDVersion {
|
|||||||
};
|
};
|
||||||
|
|
||||||
static inline bool sd_version_is_sd1(SDVersion version) {
|
static inline bool sd_version_is_sd1(SDVersion version) {
|
||||||
if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX) {
|
if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX || version == VERSION_SD1_TINY_UNET) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
@ -56,7 +58,7 @@ static inline bool sd_version_is_sd2(SDVersion version) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static inline bool sd_version_is_sdxl(SDVersion version) {
|
static inline bool sd_version_is_sdxl(SDVersion version) {
|
||||||
if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX) {
|
if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX || version == VERSION_SDXL_SSD1B) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
|||||||
@ -28,11 +28,13 @@ const char* model_version_to_str[] = {
|
|||||||
"SD 1.x",
|
"SD 1.x",
|
||||||
"SD 1.x Inpaint",
|
"SD 1.x Inpaint",
|
||||||
"Instruct-Pix2Pix",
|
"Instruct-Pix2Pix",
|
||||||
|
"SD 1.x Tiny UNet",
|
||||||
"SD 2.x",
|
"SD 2.x",
|
||||||
"SD 2.x Inpaint",
|
"SD 2.x Inpaint",
|
||||||
"SDXL",
|
"SDXL",
|
||||||
"SDXL Inpaint",
|
"SDXL Inpaint",
|
||||||
"SDXL Instruct-Pix2Pix",
|
"SDXL Instruct-Pix2Pix",
|
||||||
|
"SDXL (SSD1B)",
|
||||||
"SVD",
|
"SVD",
|
||||||
"SD3.x",
|
"SD3.x",
|
||||||
"Flux",
|
"Flux",
|
||||||
|
|||||||
50
unet.hpp
50
unet.hpp
@ -204,6 +204,9 @@ public:
|
|||||||
adm_in_channels = 768;
|
adm_in_channels = 768;
|
||||||
num_head_channels = 64;
|
num_head_channels = 64;
|
||||||
num_heads = -1;
|
num_heads = -1;
|
||||||
|
} else if (version == VERSION_SD1_TINY_UNET) {
|
||||||
|
num_res_blocks = 1;
|
||||||
|
channel_mult = {1, 2, 4};
|
||||||
}
|
}
|
||||||
if (sd_version_is_inpaint(version)) {
|
if (sd_version_is_inpaint(version)) {
|
||||||
in_channels = 9;
|
in_channels = 9;
|
||||||
@ -270,13 +273,22 @@ public:
|
|||||||
n_head = ch / d_head;
|
n_head = ch / d_head;
|
||||||
}
|
}
|
||||||
std::string name = "input_blocks." + std::to_string(input_block_idx) + ".1";
|
std::string name = "input_blocks." + std::to_string(input_block_idx) + ".1";
|
||||||
|
int td = transformer_depth[i];
|
||||||
|
if (version == VERSION_SDXL_SSD1B) {
|
||||||
|
if (i == 2) {
|
||||||
|
td = 4;
|
||||||
|
}
|
||||||
|
}
|
||||||
blocks[name] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
|
blocks[name] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
|
||||||
n_head,
|
n_head,
|
||||||
d_head,
|
d_head,
|
||||||
transformer_depth[i],
|
td,
|
||||||
context_dim));
|
context_dim));
|
||||||
}
|
}
|
||||||
input_block_chans.push_back(ch);
|
input_block_chans.push_back(ch);
|
||||||
|
if (version == VERSION_SD1_TINY_UNET) {
|
||||||
|
input_block_idx++;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (i != len_mults - 1) {
|
if (i != len_mults - 1) {
|
||||||
input_block_idx += 1;
|
input_block_idx += 1;
|
||||||
@ -295,14 +307,17 @@ public:
|
|||||||
d_head = num_head_channels;
|
d_head = num_head_channels;
|
||||||
n_head = ch / d_head;
|
n_head = ch / d_head;
|
||||||
}
|
}
|
||||||
|
if (version != VERSION_SD1_TINY_UNET) {
|
||||||
blocks["middle_block.0"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
|
blocks["middle_block.0"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
|
||||||
|
if (version != VERSION_SDXL_SSD1B) {
|
||||||
blocks["middle_block.1"] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
|
blocks["middle_block.1"] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
|
||||||
n_head,
|
n_head,
|
||||||
d_head,
|
d_head,
|
||||||
transformer_depth[transformer_depth.size() - 1],
|
transformer_depth[transformer_depth.size() - 1],
|
||||||
context_dim));
|
context_dim));
|
||||||
blocks["middle_block.2"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
|
blocks["middle_block.2"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
|
||||||
|
}
|
||||||
|
}
|
||||||
// output_blocks
|
// output_blocks
|
||||||
int output_block_idx = 0;
|
int output_block_idx = 0;
|
||||||
for (int i = (int)len_mults - 1; i >= 0; i--) {
|
for (int i = (int)len_mults - 1; i >= 0; i--) {
|
||||||
@ -324,12 +339,27 @@ public:
|
|||||||
n_head = ch / d_head;
|
n_head = ch / d_head;
|
||||||
}
|
}
|
||||||
std::string name = "output_blocks." + std::to_string(output_block_idx) + ".1";
|
std::string name = "output_blocks." + std::to_string(output_block_idx) + ".1";
|
||||||
blocks[name] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch, n_head, d_head, transformer_depth[i], context_dim));
|
int td = transformer_depth[i];
|
||||||
|
if (version == VERSION_SDXL_SSD1B) {
|
||||||
|
if (i == 2 && (j == 0 || j == 1)) {
|
||||||
|
td = 4;
|
||||||
|
}
|
||||||
|
if (i == 1 && (j == 1 || j == 2)) {
|
||||||
|
td = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
blocks[name] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch, n_head, d_head, td, context_dim));
|
||||||
|
|
||||||
up_sample_idx++;
|
up_sample_idx++;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (i > 0 && j == num_res_blocks) {
|
if (i > 0 && j == num_res_blocks) {
|
||||||
|
if (version == VERSION_SD1_TINY_UNET) {
|
||||||
|
output_block_idx++;
|
||||||
|
if (output_block_idx == 2) {
|
||||||
|
up_sample_idx = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
std::string name = "output_blocks." + std::to_string(output_block_idx) + "." + std::to_string(up_sample_idx);
|
std::string name = "output_blocks." + std::to_string(output_block_idx) + "." + std::to_string(up_sample_idx);
|
||||||
blocks[name] = std::shared_ptr<GGMLBlock>(new UpSampleBlock(ch, ch));
|
blocks[name] = std::shared_ptr<GGMLBlock>(new UpSampleBlock(ch, ch));
|
||||||
|
|
||||||
@ -463,6 +493,9 @@ public:
|
|||||||
}
|
}
|
||||||
hs.push_back(h);
|
hs.push_back(h);
|
||||||
}
|
}
|
||||||
|
if (version == VERSION_SD1_TINY_UNET) {
|
||||||
|
input_block_idx++;
|
||||||
|
}
|
||||||
if (i != len_mults - 1) {
|
if (i != len_mults - 1) {
|
||||||
ds *= 2;
|
ds *= 2;
|
||||||
input_block_idx += 1;
|
input_block_idx += 1;
|
||||||
@ -477,10 +510,13 @@ public:
|
|||||||
// [N, 4*model_channels, h/8, w/8]
|
// [N, 4*model_channels, h/8, w/8]
|
||||||
|
|
||||||
// middle_block
|
// middle_block
|
||||||
|
if (version != VERSION_SD1_TINY_UNET) {
|
||||||
h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
||||||
|
if (version != VERSION_SDXL_SSD1B) {
|
||||||
h = attention_layer_forward("middle_block.1", ctx, backend, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
h = attention_layer_forward("middle_block.1", ctx, backend, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
||||||
h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
||||||
|
}
|
||||||
|
}
|
||||||
if (controls.size() > 0) {
|
if (controls.size() > 0) {
|
||||||
auto cs = ggml_scale_inplace(ctx, controls[controls.size() - 1], control_strength);
|
auto cs = ggml_scale_inplace(ctx, controls[controls.size() - 1], control_strength);
|
||||||
h = ggml_add(ctx, h, cs); // middle control
|
h = ggml_add(ctx, h, cs); // middle control
|
||||||
@ -516,6 +552,12 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (i > 0 && j == num_res_blocks) {
|
if (i > 0 && j == num_res_blocks) {
|
||||||
|
if (version == VERSION_SD1_TINY_UNET) {
|
||||||
|
output_block_idx++;
|
||||||
|
if (output_block_idx == 2) {
|
||||||
|
up_sample_idx = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
std::string name = "output_blocks." + std::to_string(output_block_idx) + "." + std::to_string(up_sample_idx);
|
std::string name = "output_blocks." + std::to_string(output_block_idx) + "." + std::to_string(up_sample_idx);
|
||||||
auto block = std::dynamic_pointer_cast<UpSampleBlock>(blocks[name]);
|
auto block = std::dynamic_pointer_cast<UpSampleBlock>(blocks[name]);
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user