848 lines
40 KiB
C++

#ifndef __SD_MODEL_DIFFUSION_UNET_HPP__
#define __SD_MODEL_DIFFUSION_UNET_HPP__
#include <algorithm>
#include <vector>
#include "model.h"
#include "model/common/block.hpp"
#include "model/diffusion/model.hpp"
/*==================================================== UnetModel =====================================================*/
// Size passed to new_graph_custom() when UNetModelRunner::build_graph creates the UNet graph
// (presumably the graph's node capacity -- TODO confirm against new_graph_custom).
#define UNET_GRAPH_SIZE 102400
// Architecture hyper-parameters of the UNet (ldm.modules.diffusionmodules.openaimodel.UNetModel).
// Field defaults describe SD1.x; detect_from_weights() adjusts them per SDVersion and then
// refines several fields from the shapes of tensors found in the checkpoint.
struct UNetConfig {
    SDVersion version = VERSION_SD1;
    // network hparams
    int in_channels                        = 4;
    int out_channels                       = 4;
    int num_res_blocks                     = 2;
    std::vector<int> attention_resolutions = {4, 2, 1};
    std::vector<int> channel_mult          = {1, 2, 4, 4};
    std::vector<int> transformer_depth     = {1, 1, 1, 1};
    int time_embed_dim                     = 1280;  // model_channels*4
    int num_heads                          = 8;
    int num_head_channels                  = -1;  // -1: head width is derived as channels / num_heads
    int context_dim                        = 768;   // 1024 for VERSION_SD2, 2048 for VERSION_SDXL
    bool use_linear_projection             = false;
    bool tiny_unet                         = false;
    int model_channels                     = 320;
    int adm_in_channels                    = 2816;  // only for VERSION_SDXL/SVD

    /// @brief Build the UNet configuration for `version`, refined from checkpoint tensor shapes.
    /// @param tensor_storage_map every tensor known to the loader, keyed by full tensor name.
    /// @param prefix name prefix of the UNet tensors, without a trailing '.'; may be empty.
    /// @param version model version used to choose the architectural defaults.
    /// @return the detected configuration (also logged at debug level).
    static UNetConfig detect_from_weights(const String2TensorStorage& tensor_storage_map,
                                          const std::string& prefix,
                                          SDVersion version = VERSION_SD1) {
        UNetConfig config;
        config.version = version;

        // Version-specific architectural defaults.
        if (sd_version_is_sd2(version)) {
            config.context_dim           = 1024;
            config.num_head_channels     = 64;
            config.num_heads             = -1;
            config.use_linear_projection = true;
        } else if (sd_version_is_sdxl(version)) {
            config.context_dim           = 2048;
            config.attention_resolutions = {4, 2};
            config.channel_mult          = {1, 2, 4};
            config.transformer_depth     = {1, 2, 10};
            config.num_head_channels     = 64;
            config.num_heads             = -1;
            config.use_linear_projection = true;
            if (version == VERSION_SDXL_VEGA) {
                config.transformer_depth = {1, 1, 2};
            }
        } else if (version == VERSION_SVD) {
            config.in_channels           = 8;
            config.out_channels          = 4;
            config.context_dim           = 1024;
            config.adm_in_channels       = 768;
            config.num_head_channels     = 64;
            config.num_heads             = -1;
            config.use_linear_projection = true;
        }

        if (sd_version_is_inpaint(version)) {
            config.in_channels = 9;
        } else if (sd_version_is_unet_edit(version)) {
            config.in_channels = 8;
        }

        if (version == VERSION_SD1_TINY_UNET || version == VERSION_SD2_TINY_UNET || version == VERSION_SDXS_512_DS || version == VERSION_SDXS_09) {
            config.num_res_blocks = 1;
            config.channel_mult   = {1, 2, 4};
            config.tiny_unet      = true;
            if (version == VERSION_SDXS_512_DS) {
                config.attention_resolutions = {4, 2};  // here just like SDXL
            }
        }

        // UNet tensors are named "<prefix>.<suffix>", or just "<suffix>" when prefix is empty.
        // The same prefix (including the '.') is used both for direct lookups and for the scan
        // below, so a sibling such as "<prefix>_other.*" is never mistaken for a UNet tensor.
        const std::string name_prefix = prefix.empty() ? std::string() : prefix + ".";

        auto find_weight = [&](const std::string& suffix) -> const TensorStorage* {
            auto it = tensor_storage_map.find(name_prefix + suffix);
            if (it == tensor_storage_map.end()) {
                return nullptr;
            }
            return &it->second;
        };

        // Shape-based overrides. The code reads Conv2d weights as ne[2] = input channels,
        // ne[3] = output channels, and Linear weights as ne[0] = in_features, ne[1] = out_features.
        if (const TensorStorage* input = find_weight("input_blocks.0.0.weight")) {
            if (input->n_dims == 4) {
                config.in_channels    = static_cast<int>(input->ne[2]);
                config.model_channels = static_cast<int>(input->ne[3]);
                config.time_embed_dim = config.model_channels * 4;
            }
        }
        if (const TensorStorage* time_embed = find_weight("time_embed.0.weight")) {
            if (time_embed->n_dims == 2) {
                config.model_channels = static_cast<int>(time_embed->ne[0]);
                config.time_embed_dim = static_cast<int>(time_embed->ne[1]);
            }
        }
        if (const TensorStorage* label_emb = find_weight("label_emb.0.0.weight")) {
            if (label_emb->n_dims == 2) {
                config.adm_in_channels = static_cast<int>(label_emb->ne[0]);
                config.time_embed_dim  = static_cast<int>(label_emb->ne[1]);
            }
        }
        if (const TensorStorage* out = find_weight("out.2.weight")) {
            if (out->n_dims == 4) {
                config.out_channels = static_cast<int>(out->ne[3]);
            }
        }

        // context_dim is the input width of the first cross-attention key projection found.
        for (const auto& [name, tensor_storage] : tensor_storage_map) {
            if (!starts_with(name, name_prefix)) {
                continue;
            }
            if (name.find("attn2.to_k.weight") != std::string::npos && tensor_storage.n_dims == 2) {
                config.context_dim = static_cast<int>(tensor_storage.ne[0]);
                break;
            }
        }

        LOG_DEBUG("unet: in_channels = %d, out_channels = %d, model_channels = %d, time_embed_dim = %d, context_dim = %d, adm_in_channels = %d, num_res_blocks = %d, tiny_unet = %s",
                  config.in_channels,
                  config.out_channels,
                  config.model_channels,
                  config.time_embed_dim,
                  config.context_dim,
                  config.adm_in_channels,
                  config.num_res_blocks,
                  config.tiny_unet ? "true" : "false");
        return config;
    }
};
// Spatio-temporal transformer used by SVD. It is a SpatialTransformer in which every spatial
// transformer block is followed by a temporal block ("time_stack.<i>") that attends across
// frames. The temporal result is merged back into the spatial stream by an AlphaBlender
// ("time_mixer").
class SpatialVideoTransformer : public SpatialTransformer {
protected:
    int64_t time_depth;         // number of time_stack blocks; asserted equal to depth
    int max_time_embed_period;  // period argument for ggml_ext_timestep_embedding over frame indices

public:
    SpatialVideoTransformer(int64_t in_channels,
                            int64_t n_head,
                            int64_t d_head,
                            int64_t depth,
                            int64_t context_dim,
                            bool use_linear,
                            int64_t time_depth        = 1,
                            int max_time_embed_period = 10000)
        : SpatialTransformer(in_channels, n_head, d_head, depth, context_dim, use_linear),
          time_depth(time_depth),  // previously left uninitialized (the parameter shadowed it)
          max_time_embed_period(max_time_embed_period) {
        // We will convert unet transformer linear to conv2d 1x1 when loading the weights, so use_linear is always False
        // use_spatial_context is always True
        // merge_strategy is always learned_with_images
        // merge_factor is loaded from weights
        // time_context_dim is always None
        // ff_in is always True
        // disable_self_attn is always False
        // disable_temporal_crossattention is always False

        int64_t inner_dim = n_head * d_head;

        GGML_ASSERT(depth == time_depth);
        GGML_ASSERT(in_channels == inner_dim);

        // The temporal blocks use the same head layout as the spatial ones, so
        // time_mix_d_head * n_time_mix_heads == inner_dim.
        int64_t time_mix_d_head  = d_head;
        int64_t n_time_mix_heads = n_head;
        int64_t time_context_dim = context_dim;

        for (int64_t i = 0; i < time_depth; i++) {
            std::string name = "time_stack." + std::to_string(i);
            blocks[name]     = std::shared_ptr<GGMLBlock>(new BasicTransformerBlock(inner_dim,
                                                                                    n_time_mix_heads,
                                                                                    time_mix_d_head,
                                                                                    time_context_dim,
                                                                                    true));
        }

        int64_t time_embed_dim           = in_channels * 4;
        blocks["time_pos_embed.0"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, time_embed_dim));
        // time_pos_embed.1 is nn.SiLU()
        blocks["time_pos_embed.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, in_channels));

        blocks["time_mixer"] = std::shared_ptr<GGMLBlock>(new AlphaBlender());
    }

    ggml_tensor* forward(GGMLRunnerContext* ctx,
                         ggml_tensor* x,
                         ggml_tensor* context,
                         int timesteps) {
        // x: [N, in_channels, h, w] aka [b*t, in_channels, h, w], t == timesteps
        // context: [N, max_position(aka n_context), hidden_size(aka context_dim)] aka [b*t, n_context, context_dim], t == timesteps
        // t_emb: [N, in_channels] aka [b*t, in_channels]
        // timesteps is num_frames
        // time_context is always None
        // image_only_indicator is always tensor([0.])
        // transformer_options is not used
        // GGML_ASSERT(ggml_n_dims(context) == 3);
        auto norm             = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm"]);
        auto proj_in          = std::dynamic_pointer_cast<Conv2d>(blocks["proj_in"]);
        auto proj_out         = std::dynamic_pointer_cast<Conv2d>(blocks["proj_out"]);
        auto time_pos_embed_0 = std::dynamic_pointer_cast<Linear>(blocks["time_pos_embed.0"]);
        auto time_pos_embed_2 = std::dynamic_pointer_cast<Linear>(blocks["time_pos_embed.2"]);
        auto time_mixer       = std::dynamic_pointer_cast<AlphaBlender>(blocks["time_mixer"]);

        auto x_in         = x;
        int64_t n         = x->ne[3];
        int64_t h         = x->ne[1];
        int64_t w         = x->ne[0];
        int64_t inner_dim = n_head * d_head;

        GGML_ASSERT(n == timesteps);  // We compute cond and uncond separately, so batch_size==1

        auto time_context    = context;  // [b*t, n_context, context_dim]
        auto spatial_context = context;
        // time_context_first_timestep = time_context[::timesteps]
        auto time_context_first_timestep = ggml_view_3d(ctx->ggml_ctx,
                                                        time_context,
                                                        time_context->ne[0],
                                                        time_context->ne[1],
                                                        1,
                                                        time_context->nb[1],
                                                        time_context->nb[2],
                                                        0);  // [b, n_context, context_dim]
        time_context = ggml_new_tensor_3d(ctx->ggml_ctx, GGML_TYPE_F32,
                                          time_context_first_timestep->ne[0],
                                          time_context_first_timestep->ne[1],
                                          time_context_first_timestep->ne[2] * h * w);
        time_context = ggml_repeat(ctx->ggml_ctx, time_context_first_timestep, time_context);  // [b*h*w, n_context, context_dim]

        x = norm->forward(ctx, x);
        x = proj_in->forward(ctx, x);  // [N, inner_dim, h, w]

        x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 2, 0, 3));  // [N, h, w, inner_dim]
        x = ggml_reshape_3d(ctx->ggml_ctx, x, inner_dim, w * h, n);                 // [N, h * w, inner_dim]

        // Embedding of the frame indices 0..timesteps-1.
        auto num_frames = ggml_arange(ctx->ggml_ctx, 0.f, static_cast<float>(timesteps), 1.f);
        // since b is 1, no need to do repeat
        auto t_emb = ggml_ext_timestep_embedding(ctx->ggml_ctx, num_frames, static_cast<int>(in_channels), max_time_embed_period);  // [N, in_channels]

        auto emb = time_pos_embed_0->forward(ctx, t_emb);
        emb      = ggml_silu_inplace(ctx->ggml_ctx, emb);
        emb      = time_pos_embed_2->forward(ctx, emb);                              // [N, in_channels]
        emb      = ggml_reshape_3d(ctx->ggml_ctx, emb, emb->ne[0], 1, emb->ne[1]);  // [N, 1, in_channels]

        for (int64_t i = 0; i < depth; i++) {
            std::string transformer_name = "transformer_blocks." + std::to_string(i);
            std::string time_stack_name  = "time_stack." + std::to_string(i);

            auto block     = std::dynamic_pointer_cast<BasicTransformerBlock>(blocks[transformer_name]);
            auto mix_block = std::dynamic_pointer_cast<BasicTransformerBlock>(blocks[time_stack_name]);

            x = block->forward(ctx, x, spatial_context);  // [N, h * w, inner_dim]

            // in_channels == inner_dim
            auto x_mix = x;
            x_mix      = ggml_add(ctx->ggml_ctx, x_mix, emb);  // [N, h * w, inner_dim]

            int64_t N = x_mix->ne[2];
            int64_t T = timesteps;
            int64_t B = N / T;
            int64_t S = x_mix->ne[1];
            int64_t C = x_mix->ne[0];

            // Regroup so the temporal block attends over the frame axis at every spatial position.
            x_mix = ggml_reshape_4d(ctx->ggml_ctx, x_mix, C, S, T, B);                          // (b t) s c -> b t s c
            x_mix = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x_mix, 0, 2, 1, 3));  // b t s c -> b s t c
            x_mix = ggml_reshape_3d(ctx->ggml_ctx, x_mix, C, T, S * B);                         // b s t c -> (b s) t c

            x_mix = mix_block->forward(ctx, x_mix, time_context);  // [B * h * w, T, inner_dim]

            x_mix = ggml_reshape_4d(ctx->ggml_ctx, x_mix, C, T, S, B);                          // (b s) t c -> b s t c
            x_mix = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x_mix, 0, 2, 1, 3));  // b s t c -> b t s c
            x_mix = ggml_reshape_3d(ctx->ggml_ctx, x_mix, C, S, T * B);                         // b t s c -> (b t) s c

            x = time_mixer->forward(ctx, x, x_mix);  // [N, h * w, inner_dim]
        }

        x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3));  // [N, inner_dim, h * w]
        x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, inner_dim, n);                  // [N, inner_dim, h, w]

        // proj_out
        x = proj_out->forward(ctx, x);  // [N, in_channels, h, w]

        x = ggml_add(ctx->ggml_ctx, x, x_in);  // residual connection
        return x;
    }
};
// ldm.modules.diffusionmodules.openaimodel.UNetModel
//
// Blocks are registered under keys such as "time_embed.<k>", "label_emb.0.<k>",
// "input_blocks.<idx>.<sub>", "middle_block.<k>", "output_blocks.<idx>.<sub>" and "out.<k>".
// Which keys exist depends on the UNetConfig (version, tiny_unet, attention_resolutions, ...).
class UnetModelBlock : public GGMLBlock {
public:
    UNetConfig config;

    /// Creates every sub-block of the UNet described by `config`.
    explicit UnetModelBlock(UNetConfig config = {})
        : config(config) {
        const SDVersion version            = this->config.version;
        const int in_channels              = this->config.in_channels;
        const int out_channels             = this->config.out_channels;
        const int num_res_blocks           = this->config.num_res_blocks;
        const auto& attention_resolutions  = this->config.attention_resolutions;
        const auto& channel_mult           = this->config.channel_mult;
        const auto& transformer_depth      = this->config.transformer_depth;
        const int time_embed_dim           = this->config.time_embed_dim;
        const int num_heads                = this->config.num_heads;
        const int num_head_channels        = this->config.num_head_channels;
        const int context_dim              = this->config.context_dim;
        const bool use_linear_projection   = this->config.use_linear_projection;
        const bool tiny_unet               = this->config.tiny_unet;
        const int model_channels           = this->config.model_channels;
        const int adm_in_channels          = this->config.adm_in_channels;

        // dims is always 2
        // use_temporal_attention is always True for SVD

        blocks["time_embed.0"] = std::shared_ptr<GGMLBlock>(new Linear(model_channels, time_embed_dim));
        // time_embed_1 is nn.SiLU()
        blocks["time_embed.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));

        if (sd_version_is_sdxl(version) || version == VERSION_SVD) {
            blocks["label_emb.0.0"] = std::shared_ptr<GGMLBlock>(new Linear(adm_in_channels, time_embed_dim));
            // label_emb_1 is nn.SiLU()
            blocks["label_emb.0.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
        }

        // input_blocks
        blocks["input_blocks.0.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, model_channels, {3, 3}, {1, 1}, {1, 1}));

        // Channel count of every skip connection pushed by the input blocks; consumed in
        // reverse order by the output blocks.
        std::vector<int> input_block_chans;
        input_block_chans.push_back(model_channels);
        int ch              = model_channels;
        int input_block_idx = 0;
        int ds              = 1;  // current downsampling factor

        auto get_resblock = [&](int64_t channels, int64_t emb_channels, int64_t out_channels) -> ResBlock* {
            if (version == VERSION_SVD) {
                return new VideoResBlock(channels, emb_channels, out_channels);
            } else {
                return new ResBlock(channels, emb_channels, out_channels);
            }
        };

        auto get_attention_layer = [&](int64_t in_channels,
                                       int64_t n_head,
                                       int64_t d_head,
                                       int64_t depth,
                                       int64_t context_dim) -> SpatialTransformer* {
            if (version == VERSION_SVD) {
                return new SpatialVideoTransformer(in_channels, n_head, d_head, depth, context_dim, use_linear_projection);
            } else {
                if (version == VERSION_SDXS_09 && n_head == 5) {
                    n_head = 1;    // to carry a special case of sdxs_09 into CrossAttentionLayer,
                    d_head = 320;  // works as long the product remains equal (5*64 == 1*320)
                }
                return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim, use_linear_projection);
            }
        };

        // True when a transformer follows the resblock at downsampling factor `factor`.
        auto has_attention = [&](int factor) {
            return std::find(attention_resolutions.begin(), attention_resolutions.end(), factor) != attention_resolutions.end();
        };

        // Head layout for a transformer over `channels`: num_heads heads by default, or heads of
        // width num_head_channels when that is set (!= -1).
        auto head_layout = [&](int channels, int& n_head, int& d_head) {
            if (num_head_channels != -1) {
                d_head = num_head_channels;
                n_head = channels / d_head;
            } else {
                n_head = num_heads;
                d_head = channels / num_heads;
            }
        };

        const int len_mults = static_cast<int>(channel_mult.size());
        for (int i = 0; i < len_mults; i++) {
            int mult = channel_mult[i];
            for (int j = 0; j < num_res_blocks; j++) {
                input_block_idx += 1;
                std::string name = "input_blocks." + std::to_string(input_block_idx) + ".0";
                blocks[name]     = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, mult * model_channels));

                ch = mult * model_channels;
                if (has_attention(ds)) {
                    int n_head = 0;
                    int d_head = 0;
                    head_layout(ch, n_head, d_head);

                    std::string attn_name = "input_blocks." + std::to_string(input_block_idx) + ".1";
                    int td                = transformer_depth[i];
                    if (version == VERSION_SDXL_SSD1B && i == 2) {
                        td = 4;
                    }
                    blocks[attn_name] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
                                                                                       n_head,
                                                                                       d_head,
                                                                                       td,
                                                                                       context_dim));
                }
                input_block_chans.push_back(ch);
                if (tiny_unet) {
                    // Tiny UNets skip one block index here so later keys line up with their numbering.
                    input_block_idx++;
                }
            }
            if (i != len_mults - 1) {
                input_block_idx += 1;
                std::string name = "input_blocks." + std::to_string(input_block_idx) + ".0";
                blocks[name]     = std::shared_ptr<GGMLBlock>(new DownSampleBlock(ch, ch));

                input_block_chans.push_back(ch);
                ds *= 2;
            }
        }

        // middle blocks
        int mid_n_head = 0;
        int mid_d_head = 0;
        head_layout(ch, mid_n_head, mid_d_head);
        if (!tiny_unet) {
            blocks["middle_block.0"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
            if (version != VERSION_SDXL_SSD1B && version != VERSION_SDXL_VEGA) {
                blocks["middle_block.1"] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
                                                                                          mid_n_head,
                                                                                          mid_d_head,
                                                                                          transformer_depth.back(),
                                                                                          context_dim));
                blocks["middle_block.2"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
            }
        }

        // output_blocks
        int output_block_idx = 0;
        for (int i = len_mults - 1; i >= 0; i--) {
            int mult = channel_mult[i];
            for (int j = 0; j < num_res_blocks + 1; j++) {
                int ich = input_block_chans.back();
                input_block_chans.pop_back();

                std::string name = "output_blocks." + std::to_string(output_block_idx) + ".0";
                blocks[name]     = std::shared_ptr<GGMLBlock>(get_resblock(ch + ich, time_embed_dim, mult * model_channels));

                ch                = mult * model_channels;
                int up_sample_idx = 1;
                if (has_attention(ds)) {
                    int n_head = 0;
                    int d_head = 0;
                    head_layout(ch, n_head, d_head);

                    std::string attn_name = "output_blocks." + std::to_string(output_block_idx) + ".1";
                    int td                = transformer_depth[i];
                    if (version == VERSION_SDXL_SSD1B) {
                        if (i == 2 && (j == 0 || j == 1)) {
                            td = 4;
                        }
                        if (i == 1 && (j == 1 || j == 2)) {
                            td = 1;
                        }
                    }
                    blocks[attn_name] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch, n_head, d_head, td, context_dim));
                    up_sample_idx++;
                }

                if (i > 0 && j == num_res_blocks) {
                    if (tiny_unet) {
                        output_block_idx++;
                        if (output_block_idx == 2) {
                            up_sample_idx = 1;
                        }
                    }
                    std::string up_name = "output_blocks." + std::to_string(output_block_idx) + "." + std::to_string(up_sample_idx);
                    blocks[up_name]     = std::shared_ptr<GGMLBlock>(new UpSampleBlock(ch, ch));

                    ds /= 2;
                }

                output_block_idx += 1;
            }
        }

        // out
        blocks["out.0"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(ch));  // ch == model_channels
        // out_1 is nn.SiLU()
        blocks["out.2"] = std::shared_ptr<GGMLBlock>(new Conv2d(model_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
    }

    /// Runs the resblock registered under `name`; SVD resblocks additionally take the frame count.
    ggml_tensor* resblock_forward(const std::string& name,
                                  GGMLRunnerContext* ctx,
                                  ggml_tensor* x,
                                  ggml_tensor* emb,
                                  int num_video_frames) {
        if (config.version == VERSION_SVD) {
            auto block = std::dynamic_pointer_cast<VideoResBlock>(blocks[name]);
            return block->forward(ctx, x, emb, num_video_frames);
        } else {
            auto block = std::dynamic_pointer_cast<ResBlock>(blocks[name]);
            return block->forward(ctx, x, emb);
        }
    }

    /// Runs the transformer registered under `name`; SVD transformers additionally take the frame count.
    ggml_tensor* attention_layer_forward(const std::string& name,
                                         GGMLRunnerContext* ctx,
                                         ggml_tensor* x,
                                         ggml_tensor* context,
                                         int timesteps) {
        if (config.version == VERSION_SVD) {
            auto block = std::dynamic_pointer_cast<SpatialVideoTransformer>(blocks[name]);
            return block->forward(ctx, x, context, timesteps);
        } else {
            auto block = std::dynamic_pointer_cast<SpatialTransformer>(blocks[name]);
            return block->forward(ctx, x, context);
        }
    }

    /// Builds the UNet forward pass.
    /// `controls`, when non-empty, holds one residual per skip connection (in input-block
    /// order) followed by the middle-block residual as its last element. Each residual is
    /// scaled by `control_strength` before being added.
    ggml_tensor* forward(GGMLRunnerContext* ctx,
                         ggml_tensor* x,
                         ggml_tensor* timesteps,
                         ggml_tensor* context,
                         ggml_tensor* c_concat                     = nullptr,
                         ggml_tensor* y                            = nullptr,
                         int num_video_frames                      = -1,
                         const std::vector<ggml_tensor*>& controls = {},
                         float control_strength                    = 0.f) {
        // x: [N, in_channels, h, w] or [N, in_channels/2, h, w]
        // timesteps: [N,]
        // context: [N, max_position, hidden_size] or [1, max_position, hidden_size]. for example, [N, 77, 768]
        // c_concat: [N, in_channels, h, w] or [1, in_channels, h, w]
        // y: [N, adm_in_channels] or [1, adm_in_channels]
        // return: [N, out_channels, h, w]
        const SDVersion version           = config.version;
        const int model_channels          = config.model_channels;
        const int num_res_blocks          = config.num_res_blocks;
        const auto& attention_resolutions = config.attention_resolutions;
        const auto& channel_mult          = config.channel_mult;
        const bool tiny_unet              = config.tiny_unet;

        auto has_attention = [&](int factor) {
            return std::find(attention_resolutions.begin(), attention_resolutions.end(), factor) != attention_resolutions.end();
        };

        // Broadcast batch-1 conditioning inputs to x's batch size (x->ne[3]).
        if (context != nullptr && context->ne[2] != x->ne[3]) {
            context = ggml_repeat(ctx->ggml_ctx, context, ggml_new_tensor_3d(ctx->ggml_ctx, GGML_TYPE_F32, context->ne[0], context->ne[1], x->ne[3]));
        }

        if (c_concat != nullptr) {
            if (c_concat->ne[3] != x->ne[3]) {
                // Repeat along the batch dim only. c_concat keeps its own channel count, which may
                // differ from x's (e.g. inpaint models have in_channels = 9), so repeating to x's
                // full shape would be wrong.
                c_concat = ggml_repeat(ctx->ggml_ctx,
                                       c_concat,
                                       ggml_new_tensor_4d(ctx->ggml_ctx, GGML_TYPE_F32,
                                                          c_concat->ne[0], c_concat->ne[1], c_concat->ne[2], x->ne[3]));
            }
            x = ggml_concat(ctx->ggml_ctx, x, c_concat, 2);
        }

        if (y != nullptr && y->ne[1] != x->ne[3]) {
            y = ggml_repeat(ctx->ggml_ctx, y, ggml_new_tensor_2d(ctx->ggml_ctx, GGML_TYPE_F32, y->ne[0], x->ne[3]));
        }

        auto time_embed_0     = std::dynamic_pointer_cast<Linear>(blocks["time_embed.0"]);
        auto time_embed_2     = std::dynamic_pointer_cast<Linear>(blocks["time_embed.2"]);
        auto input_blocks_0_0 = std::dynamic_pointer_cast<Conv2d>(blocks["input_blocks.0.0"]);

        auto out_0 = std::dynamic_pointer_cast<GroupNorm32>(blocks["out.0"]);
        auto out_2 = std::dynamic_pointer_cast<Conv2d>(blocks["out.2"]);

        auto t_emb = ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, model_channels);  // [N, model_channels]

        auto emb = time_embed_0->forward(ctx, t_emb);
        emb      = ggml_silu_inplace(ctx->ggml_ctx, emb);
        emb      = time_embed_2->forward(ctx, emb);  // [N, time_embed_dim]

        // SDXL/SVD
        if (y != nullptr) {
            auto label_embed_0 = std::dynamic_pointer_cast<Linear>(blocks["label_emb.0.0"]);
            auto label_embed_2 = std::dynamic_pointer_cast<Linear>(blocks["label_emb.0.2"]);

            auto label_emb = label_embed_0->forward(ctx, y);
            label_emb      = ggml_silu_inplace(ctx->ggml_ctx, label_emb);
            label_emb      = label_embed_2->forward(ctx, label_emb);  // [N, time_embed_dim]

            emb = ggml_add(ctx->ggml_ctx, emb, label_emb);  // [N, time_embed_dim]
        }
        // sd::ggml_graph_cut::mark_graph_cut(emb, "unet.prelude", "emb");

        // input_blocks
        std::vector<ggml_tensor*> hs;  // skip connections, consumed in reverse by the output blocks

        // input block 0
        auto h = input_blocks_0_0->forward(ctx, x);
        sd::ggml_graph_cut::mark_graph_cut(h, "unet.input_blocks.0", "h");

        ggml_set_name(h, "bench-start");
        hs.push_back(h);

        // input block 1-11
        const int len_mults = static_cast<int>(channel_mult.size());
        int input_block_idx = 0;
        int ds              = 1;  // current downsampling factor
        for (int i = 0; i < len_mults; i++) {
            int mult = channel_mult[i];
            for (int j = 0; j < num_res_blocks; j++) {
                input_block_idx += 1;
                std::string name = "input_blocks." + std::to_string(input_block_idx) + ".0";
                h                = resblock_forward(name, ctx, h, emb, num_video_frames);  // [N, mult*model_channels, h, w]
                if (has_attention(ds)) {
                    std::string attn_name = "input_blocks." + std::to_string(input_block_idx) + ".1";
                    h                     = attention_layer_forward(attn_name, ctx, h, context, num_video_frames);  // [N, mult*model_channels, h, w]
                }
                sd::ggml_graph_cut::mark_graph_cut(h, "unet.input_blocks." + std::to_string(input_block_idx), "h");
                hs.push_back(h);
            }
            if (tiny_unet) {
                // The constructor skips this index inside the resblock loop; the two placements
                // agree because tiny UNets are configured with num_res_blocks == 1.
                input_block_idx++;
            }
            if (i != len_mults - 1) {
                ds *= 2;
                input_block_idx += 1;

                std::string name = "input_blocks." + std::to_string(input_block_idx) + ".0";
                auto block       = std::dynamic_pointer_cast<DownSampleBlock>(blocks[name]);

                h = block->forward(ctx, h);  // [N, mult*model_channels, h/(2^(i+1)), w/(2^(i+1))]
                // sd::ggml_graph_cut::mark_graph_cut(h, "unet.input_blocks." + std::to_string(input_block_idx), "h");
                hs.push_back(h);
            }
            (void)mult;
        }
        // [N, 4*model_channels, h/8, w/8]

        // middle_block
        if (!tiny_unet) {
            h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames);  // [N, 4*model_channels, h/8, w/8]
            if (version != VERSION_SDXL_SSD1B && version != VERSION_SDXL_VEGA) {
                h = attention_layer_forward("middle_block.1", ctx, h, context, num_video_frames);  // [N, 4*model_channels, h/8, w/8]
                h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames);              // [N, 4*model_channels, h/8, w/8]
            }
        }
        sd::ggml_graph_cut::mark_graph_cut(h, "unet.middle_block", "h");

        if (!controls.empty()) {
            // Every skip connection consumes one control, plus the last one for the middle block;
            // fail loudly instead of indexing past the start of `controls`.
            GGML_ASSERT(controls.size() >= hs.size() + 1);
            auto cs = ggml_ext_scale(ctx->ggml_ctx, controls.back(), control_strength, true);
            h       = ggml_add(ctx->ggml_ctx, h, cs);  // middle control
        }
        // Index of the control for the next skip connection, walking backwards (skips are popped LIFO).
        int control_offset = static_cast<int>(controls.size()) - 2;

        // output_blocks
        int output_block_idx = 0;
        for (int i = len_mults - 1; i >= 0; i--) {
            for (int j = 0; j < num_res_blocks + 1; j++) {
                auto h_skip = hs.back();
                hs.pop_back();

                if (!controls.empty()) {
                    auto cs = ggml_ext_scale(ctx->ggml_ctx, controls[control_offset], control_strength, true);
                    h_skip  = ggml_add(ctx->ggml_ctx, h_skip, cs);  // control net condition
                    control_offset--;
                }

                h = ggml_concat(ctx->ggml_ctx, h, h_skip, 2);

                std::string name = "output_blocks." + std::to_string(output_block_idx) + ".0";
                h                = resblock_forward(name, ctx, h, emb, num_video_frames);

                int up_sample_idx = 1;
                if (has_attention(ds)) {
                    std::string attn_name = "output_blocks." + std::to_string(output_block_idx) + ".1";
                    h                     = attention_layer_forward(attn_name, ctx, h, context, num_video_frames);
                    up_sample_idx++;
                }

                if (i > 0 && j == num_res_blocks) {
                    if (tiny_unet) {
                        output_block_idx++;
                        if (output_block_idx == 2) {
                            up_sample_idx = 1;
                        }
                    }
                    std::string up_name = "output_blocks." + std::to_string(output_block_idx) + "." + std::to_string(up_sample_idx);
                    auto block          = std::dynamic_pointer_cast<UpSampleBlock>(blocks[up_name]);

                    h = block->forward(ctx, h);

                    ds /= 2;
                }

                output_block_idx += 1;
                sd::ggml_graph_cut::mark_graph_cut(h, "unet.output_blocks." + std::to_string(output_block_idx - 1), "h");
            }
        }

        // out
        h = out_0->forward(ctx, h);
        h = ggml_silu_inplace(ctx->ggml_ctx, h);
        h = out_2->forward(ctx, h);
        ggml_set_name(h, "bench-end");
        return h;  // [N, out_channels, h, w]
    }
};
// DiffusionModelRunner for UNet models. It owns the UNetConfig detected from the checkpoint and
// the UnetModelBlock built from it, and it builds and executes the ggml graph for one call.
struct UNetModelRunner : public DiffusionModelRunner {
    UNetConfig config;    // declared before `unet`, which is constructed from it
    UnetModelBlock unet;

    UNetModelRunner(ggml_backend_t backend,
                    ggml_backend_t params_backend,
                    const String2TensorStorage& tensor_storage_map,
                    const std::string prefix,
                    SDVersion version = VERSION_SD1)
        : DiffusionModelRunner(backend, params_backend, prefix),
          config(UNetConfig::detect_from_weights(tensor_storage_map, prefix, version)),
          unet(config) {
        unet.init(params_ctx, tensor_storage_map, prefix);
    }

    std::string get_desc() override {
        return "unet";
    }

    void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string& prefix) override {
        unet.get_param_tensors(tensors, prefix);
    }

    /// Creates the graph inputs from the given tensors and builds the UNet forward graph.
    /// Empty optional tensors become absent (nullptr) inputs. num_video_frames == -1 means
    /// "use the batch size of x".
    ggml_cgraph* build_graph(const sd::Tensor<float>& x_tensor,
                             const sd::Tensor<float>& timesteps_tensor,
                             const sd::Tensor<float>& context_tensor                = {},
                             const sd::Tensor<float>& c_concat_tensor               = {},
                             const sd::Tensor<float>& y_tensor                      = {},
                             int num_video_frames                                   = -1,
                             const std::vector<sd::Tensor<float>>& controls_tensor = {},
                             float control_strength                                 = 0.f) {
        ggml_cgraph* gf = new_graph_custom(UNET_GRAPH_SIZE);

        ggml_tensor* x         = make_input(x_tensor);
        ggml_tensor* timesteps = make_input(timesteps_tensor);
        ggml_tensor* context   = make_optional_input(context_tensor);
        ggml_tensor* c_concat  = make_optional_input(c_concat_tensor);
        ggml_tensor* y         = make_optional_input(y_tensor);

        std::vector<ggml_tensor*> controls;
        controls.reserve(controls_tensor.size());
        for (const auto& control_tensor : controls_tensor) {
            controls.push_back(make_input(control_tensor));
        }

        if (num_video_frames == -1) {
            num_video_frames = static_cast<int>(x->ne[3]);
        }

        auto runner_ctx = get_context();

        ggml_tensor* out = unet.forward(&runner_ctx,
                                        x,
                                        timesteps,
                                        context,
                                        c_concat,
                                        y,
                                        num_video_frames,
                                        controls,
                                        control_strength);

        ggml_build_forward_expand(gf, out);

        return gf;
    }

    sd::Tensor<float> compute(int n_threads,
                              const sd::Tensor<float>& x,
                              const sd::Tensor<float>& timesteps,
                              const sd::Tensor<float>& context                = {},
                              const sd::Tensor<float>& c_concat               = {},
                              const sd::Tensor<float>& y                      = {},
                              int num_video_frames                            = -1,
                              const std::vector<sd::Tensor<float>>& controls = {},
                              float control_strength                          = 0.f) {
        // x: [N, in_channels, h, w]
        // timesteps: [N, ]
        // context: [N, max_position, hidden_size]([N, 77, 768]) or [1, max_position, hidden_size]
        // c_concat: [N, in_channels, h, w] or [1, in_channels, h, w]
        // y: [N, adm_in_channels] or [1, adm_in_channels]
        auto get_graph = [&]() -> ggml_cgraph* {
            return build_graph(x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength);
        };

        return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), x.dim());
    }

    sd::Tensor<float> compute(int n_threads,
                              const DiffusionParams& diffusion_params) override {
        GGML_ASSERT(diffusion_params.x != nullptr);
        GGML_ASSERT(diffusion_params.timesteps != nullptr);
        const auto* extra = diffusion_extra_as<UNetDiffusionExtra>(diffusion_params);
        static const std::vector<sd::Tensor<float>> empty_controls;

        // Without UNet-specific extras, fall back to the defaults of the tensor-level compute()
        // instead of dereferencing a null pointer.
        const int num_video_frames   = extra != nullptr ? extra->num_video_frames : -1;
        const float control_strength = extra != nullptr ? extra->control_strength : 0.f;
        const auto& controls         = (extra != nullptr && extra->controls) ? *extra->controls : empty_controls;

        return compute(n_threads,
                       *diffusion_params.x,
                       *diffusion_params.timesteps,
                       tensor_or_empty(diffusion_params.context),
                       tensor_or_empty(diffusion_params.c_concat),
                       tensor_or_empty(diffusion_params.y),
                       num_video_frames,
                       controls,
                       control_strength);
    }

    void test() {
        ggml_init_params params;
        params.mem_size   = static_cast<size_t>(10 * 1024 * 1024);  // 10 MB
        params.mem_buffer = nullptr;
        params.no_alloc   = false;

        ggml_context* ctx = ggml_init(params);
        GGML_ASSERT(ctx != nullptr);

        {
            // CPU, num_video_frames = 1, x{num_video_frames, 8, 8, 8}: Pass
            // CUDA, num_video_frames = 1, x{num_video_frames, 8, 8, 8}: Pass
            // CPU, num_video_frames = 3, x{num_video_frames, 8, 8, 8}: Wrong result
            // CUDA, num_video_frames = 3, x{num_video_frames, 8, 8, 8}: nan
            int num_video_frames = 3;

            sd::Tensor<float> x({8, 8, 8, num_video_frames});
            std::vector<float> timesteps_vec(num_video_frames, 999.f);
            auto timesteps = sd::Tensor<float>::from_vector(timesteps_vec);
            x.fill_(0.5f);
            // print_ggml_tensor(x);

            sd::Tensor<float> context({1024, 1, num_video_frames});
            context.fill_(0.5f);
            // print_ggml_tensor(context);

            sd::Tensor<float> y({768, num_video_frames});
            y.fill_(0.5f);
            // print_ggml_tensor(y);

            sd::Tensor<float> out;

            int64_t t0   = ggml_time_ms();
            auto out_opt = compute(8,
                                   x,
                                   timesteps,
                                   context,
                                   {},
                                   y,
                                   num_video_frames,
                                   {},
                                   0.f);
            int64_t t1 = ggml_time_ms();

            GGML_ASSERT(!out_opt.empty());
            out = std::move(out_opt);
            print_sd_tensor(out);
            // int64_t is not necessarily long long; cast so it matches %lld.
            LOG_DEBUG("unet test done in %lldms", static_cast<long long>(t1 - t0));
        }

        // The test body never allocates from this context; release it so its buffer is not leaked.
        ggml_free(ctx);
    }
};
#endif // __SD_MODEL_DIFFUSION_UNET_HPP__