Compare commits


5 Commits

Author SHA1 Message Date
leejet
f6b9aa1a43 refactor: optimize the usage of tensor_types 2025-07-28 23:18:29 +08:00
Wagner Bruna
7eb30d00e5
feat: add missing models and parameters to image metadata (#743)
* feat: add new scheduler types, clip skip and vae to image embedded params

- If a non-default scheduler is set, include it in the 'Sampler' tag in the data
embedded into the final image.
- If a custom VAE path is set, include the vae name (without path and extension)
in embedded image params under a `VAE:` tag.
- If a custom Clip skip is set, include that Clip skip value in embedded image
params under a `Clip skip:` tag.

* feat: add separate diffusion and text models to metadata

---------

Co-authored-by: one-lithe-rune <skapusniak@lithe-runes.com>
2025-07-28 22:00:27 +08:00
stduhpf
59080d3ce1
feat: change image dimensions requirement for DiT models (#742) 2025-07-28 21:58:17 +08:00
R0CKSTAR
8c3c788f31
feat: upgrade musa sdk to rc4.2.0 (#732) 2025-07-28 21:51:11 +08:00
leejet
f54524f620 sync: update ggml 2025-07-28 21:50:12 +08:00
20 changed files with 153 additions and 121 deletions

Dockerfile.musa

@@ -1,6 +1,7 @@
-ARG MUSA_VERSION=rc3.1.1
+ARG MUSA_VERSION=rc4.2.0
+ARG UBUNTU_VERSION=22.04
-FROM mthreads/musa:${MUSA_VERSION}-devel-ubuntu22.04 as build
+FROM mthreads/musa:${MUSA_VERSION}-devel-ubuntu${UBUNTU_VERSION}-amd64 as build
 RUN apt-get update && apt-get install -y ccache cmake git
@@ -15,7 +16,7 @@ RUN mkdir build && cd build && \
     -DSD_MUSA=ON -DCMAKE_BUILD_TYPE=Release && \
     cmake --build . --config Release
-FROM mthreads/musa:${MUSA_VERSION}-runtime-ubuntu22.04 as runtime
+FROM mthreads/musa:${MUSA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}-amd64 as runtime
 COPY --from=build /sd.cpp/build/bin/sd /sd

clip.hpp

@@ -545,9 +545,9 @@ protected:
     int64_t vocab_size;
     int64_t num_positions;
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
-        enum ggml_type token_wtype    = GGML_TYPE_F32;  //(tensor_types.find(prefix + "token_embedding.weight") != tensor_types.end()) ? tensor_types[prefix + "token_embedding.weight"] : GGML_TYPE_F32;
-        enum ggml_type position_wtype = GGML_TYPE_F32;  //(tensor_types.find(prefix + "position_embedding.weight") != tensor_types.end()) ? tensor_types[prefix + "position_embedding.weight"] : GGML_TYPE_F32;
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
+        enum ggml_type token_wtype    = GGML_TYPE_F32;
+        enum ggml_type position_wtype = GGML_TYPE_F32;
         params["token_embedding.weight"]    = ggml_new_tensor_2d(ctx, token_wtype, embed_dim, vocab_size);
         params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, position_wtype, embed_dim, num_positions);
@@ -594,10 +594,10 @@ protected:
     int64_t image_size;
     int64_t num_patches;
     int64_t num_positions;
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
-        enum ggml_type patch_wtype    = GGML_TYPE_F16;  // tensor_types.find(prefix + "patch_embedding.weight") != tensor_types.end() ? tensor_types[prefix + "patch_embedding.weight"] : GGML_TYPE_F16;
-        enum ggml_type class_wtype    = GGML_TYPE_F32;  // tensor_types.find(prefix + "class_embedding") != tensor_types.end() ? tensor_types[prefix + "class_embedding"] : GGML_TYPE_F32;
-        enum ggml_type position_wtype = GGML_TYPE_F32;  // tensor_types.find(prefix + "position_embedding.weight") != tensor_types.end() ? tensor_types[prefix + "position_embedding.weight"] : GGML_TYPE_F32;
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
+        enum ggml_type patch_wtype    = GGML_TYPE_F16;
+        enum ggml_type class_wtype    = GGML_TYPE_F32;
+        enum ggml_type position_wtype = GGML_TYPE_F32;
         params["patch_embedding.weight"] = ggml_new_tensor_4d(ctx, patch_wtype, patch_size, patch_size, num_channels, embed_dim);
         params["class_embedding"]        = ggml_new_tensor_1d(ctx, class_wtype, embed_dim);
@@ -657,9 +657,9 @@ enum CLIPVersion {
 class CLIPTextModel : public GGMLBlock {
 protected:
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
         if (version == OPEN_CLIP_VIT_BIGG_14) {
-            enum ggml_type wtype = GGML_TYPE_F32;  // tensor_types.find(prefix + "text_projection") != tensor_types.end() ? tensor_types[prefix + "text_projection"] : GGML_TYPE_F32;
+            enum ggml_type wtype = GGML_TYPE_F32;
             params["text_projection"] = ggml_new_tensor_2d(ctx, wtype, projection_dim, hidden_size);
         }
     }
@@ -805,8 +805,8 @@ protected:
     int64_t out_features;
     bool transpose_weight;
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
-        enum ggml_type wtype = tensor_types.find(prefix + "weight") != tensor_types.end() ? tensor_types[prefix + "weight"] : GGML_TYPE_F32;
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
+        enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
         if (transpose_weight) {
             params["weight"] = ggml_new_tensor_2d(ctx, wtype, out_features, in_features);
         } else {
@@ -868,7 +868,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
     CLIPTextModel model;
     CLIPTextModelRunner(ggml_backend_t backend,
-                        std::map<std::string, enum ggml_type>& tensor_types,
+                        const String2GGMLType& tensor_types,
                         const std::string prefix,
                         CLIPVersion version = OPENAI_CLIP_VIT_L_14,
                         bool with_final_ln = true,

common.hpp

@@ -182,9 +182,9 @@ protected:
     int64_t dim_in;
     int64_t dim_out;
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, std::string prefix = "") {
-        enum ggml_type wtype      = (tensor_types.find(prefix + "proj.weight") != tensor_types.end()) ? tensor_types[prefix + "proj.weight"] : GGML_TYPE_F32;
-        enum ggml_type bias_wtype = GGML_TYPE_F32;  //(tensor_types.find(prefix + "proj.bias") != tensor_types.end()) ? tensor_types[prefix + "proj.bias"] : GGML_TYPE_F32;
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
+        enum ggml_type wtype      = get_type(prefix + "proj.weight", tensor_types, GGML_TYPE_F32);
+        enum ggml_type bias_wtype = GGML_TYPE_F32;
         params["proj.weight"] = ggml_new_tensor_2d(ctx, wtype, dim_in, dim_out * 2);
         params["proj.bias"]   = ggml_new_tensor_1d(ctx, bias_wtype, dim_out * 2);
     }
@@ -440,9 +440,9 @@ public:
 class AlphaBlender : public GGMLBlock {
 protected:
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, std::string prefix = "") {
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
         // Get the type of the "mix_factor" tensor from the input tensors map with the specified prefix
-        enum ggml_type wtype = GGML_TYPE_F32;  //(tensor_types.ypes.find(prefix + "mix_factor") != tensor_types.end()) ? tensor_types[prefix + "mix_factor"] : GGML_TYPE_F32;
+        enum ggml_type wtype = GGML_TYPE_F32;
         params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1);
     }

conditioner.hpp

@@ -57,7 +57,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
     std::vector<std::string> readed_embeddings;
     FrozenCLIPEmbedderWithCustomWords(ggml_backend_t backend,
-                                      std::map<std::string, enum ggml_type>& tensor_types,
+                                      const String2GGMLType& tensor_types,
                                       const std::string& embd_dir,
                                       SDVersion version = VERSION_SD1,
                                       PMVersion pv = PM_VERSION_1,
@@ -618,7 +618,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
 struct FrozenCLIPVisionEmbedder : public GGMLRunner {
     CLIPVisionModelProjection vision_model;
-    FrozenCLIPVisionEmbedder(ggml_backend_t backend, std::map<std::string, enum ggml_type>& tensor_types)
+    FrozenCLIPVisionEmbedder(ggml_backend_t backend, const String2GGMLType& tensor_types = {})
         : vision_model(OPEN_CLIP_VIT_H_14, true), GGMLRunner(backend) {
         vision_model.init(params_ctx, tensor_types, "cond_stage_model.transformer");
     }
@@ -663,8 +663,8 @@ struct SD3CLIPEmbedder : public Conditioner {
     std::shared_ptr<T5Runner> t5;
     SD3CLIPEmbedder(ggml_backend_t backend,
-                    std::map<std::string, enum ggml_type>& tensor_types,
-                    int clip_skip = -1)
+                    const String2GGMLType& tensor_types = {},
+                    int clip_skip = -1)
         : clip_g_tokenizer(0) {
         clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, false);
         clip_g = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false);
@@ -1010,8 +1010,8 @@ struct FluxCLIPEmbedder : public Conditioner {
     size_t chunk_len = 256;
     FluxCLIPEmbedder(ggml_backend_t backend,
-                     std::map<std::string, enum ggml_type>& tensor_types,
-                     int clip_skip = -1) {
+                     const String2GGMLType& tensor_types = {},
+                     int clip_skip = -1) {
         clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, true);
         t5     = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
         set_clip_skip(clip_skip);
@@ -1231,10 +1231,10 @@ struct PixArtCLIPEmbedder : public Conditioner {
     int mask_pad = 1;
     PixArtCLIPEmbedder(ggml_backend_t backend,
-                       std::map<std::string, enum ggml_type>& tensor_types,
-                       int clip_skip = -1,
-                       bool use_mask = false,
-                       int mask_pad = 1)
+                       const String2GGMLType& tensor_types = {},
+                       int clip_skip = -1,
+                       bool use_mask = false,
+                       int mask_pad = 1)
         : use_mask(use_mask), mask_pad(mask_pad) {
         t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
     }

control.hpp

@@ -317,8 +317,8 @@ struct ControlNet : public GGMLRunner {
     bool guided_hint_cached = false;
     ControlNet(ggml_backend_t backend,
-               std::map<std::string, enum ggml_type>& tensor_types,
-               SDVersion version = VERSION_SD1)
+               const String2GGMLType& tensor_types = {},
+               SDVersion version = VERSION_SD1)
         : GGMLRunner(backend), control_net(version) {
         control_net.init(params_ctx, tensor_types, "");
     }

diffusion_model.hpp

@@ -32,9 +32,9 @@ struct UNetModel : public DiffusionModel {
     UNetModelRunner unet;
     UNetModel(ggml_backend_t backend,
-              std::map<std::string, enum ggml_type>& tensor_types,
-              SDVersion version = VERSION_SD1,
-              bool flash_attn = false)
+              const String2GGMLType& tensor_types = {},
+              SDVersion version = VERSION_SD1,
+              bool flash_attn = false)
         : unet(backend, tensor_types, "model.diffusion_model", version, flash_attn) {
     }
@@ -85,7 +85,7 @@ struct MMDiTModel : public DiffusionModel {
     MMDiTRunner mmdit;
     MMDiTModel(ggml_backend_t backend,
-               std::map<std::string, enum ggml_type>& tensor_types)
+               const String2GGMLType& tensor_types = {})
         : mmdit(backend, tensor_types, "model.diffusion_model") {
     }
@@ -135,10 +135,10 @@ struct FluxModel : public DiffusionModel {
     Flux::FluxRunner flux;
     FluxModel(ggml_backend_t backend,
-              std::map<std::string, enum ggml_type>& tensor_types,
-              SDVersion version = VERSION_FLUX,
-              bool flash_attn = false,
-              bool use_mask = false)
+              const String2GGMLType& tensor_types = {},
+              SDVersion version = VERSION_FLUX,
+              bool flash_attn = false,
+              bool use_mask = false)
         : flux(backend, tensor_types, "model.diffusion_model", version, flash_attn, use_mask) {
     }

esrgan.hpp

@@ -142,7 +142,7 @@ struct ESRGAN : public GGMLRunner {
     int scale     = 4;
     int tile_size = 128;  // avoid cuda OOM for 4gb VRAM
-    ESRGAN(ggml_backend_t backend, std::map<std::string, enum ggml_type>& tensor_types)
+    ESRGAN(ggml_backend_t backend, const String2GGMLType& tensor_types = {})
         : GGMLRunner(backend) {
         rrdb_net.init(params_ctx, tensor_types, "");
     }

examples/cli/main.cpp

@@ -596,13 +596,13 @@ void parse_args(int argc, const char** argv, SDParams& params) {
         exit(1);
     }
-    if (params.width <= 0 || params.width % 64 != 0) {
-        fprintf(stderr, "error: the width must be a multiple of 64\n");
+    if (params.width <= 0) {
+        fprintf(stderr, "error: the width must be greater than 0\n");
         exit(1);
     }
-    if (params.height <= 0 || params.height % 64 != 0) {
-        fprintf(stderr, "error: the height must be a multiple of 64\n");
+    if (params.height <= 0) {
+        fprintf(stderr, "error: the height must be greater than 0\n");
         exit(1);
     }
@@ -677,10 +677,24 @@ std::string get_image_params(SDParams params, int64_t seed) {
     parameter_string += "Model: " + sd_basename(params.model_path) + ", ";
     parameter_string += "RNG: " + std::string(sd_rng_type_name(params.rng_type)) + ", ";
     parameter_string += "Sampler: " + std::string(sd_sample_method_name(params.sample_method));
-    if (params.schedule == KARRAS) {
-        parameter_string += " karras";
+    if (params.schedule != DEFAULT) {
+        parameter_string += " " + std::string(sd_schedule_name(params.schedule));
     }
     parameter_string += ", ";
+    for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path}) {
+        if (!te.empty()) {
+            parameter_string += "TE: " + sd_basename(te) + ", ";
+        }
+    }
+    if (!params.diffusion_model_path.empty()) {
+        parameter_string += "Unet: " + sd_basename(params.diffusion_model_path) + ", ";
+    }
+    if (!params.vae_path.empty()) {
+        parameter_string += "VAE: " + sd_basename(params.vae_path) + ", ";
+    }
+    if (params.clip_skip != -1) {
+        parameter_string += "Clip skip: " + std::to_string(params.clip_skip) + ", ";
+    }
     parameter_string += "Version: stable-diffusion.cpp";
     return parameter_string;
 }
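With these additions, the metadata embedded in the output image records the scheduler (appended to the Sampler tag), any separately loaded text encoders (TE) and standalone diffusion model (Unet), the VAE, and a non-default clip skip. A plausible tail of the resulting string, with made-up file names as placeholders:

    Model: v1-5-pruned-emaonly.safetensors, RNG: cuda, Sampler: euler karras, TE: clip_l.safetensors, TE: t5xxl_fp16.safetensors, VAE: ae.safetensors, Clip skip: 2, Version: stable-diffusion.cpp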

flux.hpp

@@ -35,8 +35,8 @@ namespace Flux {
     int64_t hidden_size;
     float eps;
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
-        ggml_type wtype = GGML_TYPE_F32;  //(tensor_types.find(prefix + "scale") != tensor_types.end()) ? tensor_types[prefix + "scale"] : GGML_TYPE_F32;
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
+        ggml_type wtype = GGML_TYPE_F32;
         params["scale"] = ggml_new_tensor_1d(ctx, wtype, hidden_size);
     }
@@ -1039,8 +1039,6 @@ namespace Flux {
     };
     struct FluxRunner : public GGMLRunner {
-        static std::map<std::string, enum ggml_type> empty_tensor_types;
-
     public:
         FluxParams flux_params;
         Flux flux;
@@ -1050,11 +1048,11 @@ namespace Flux {
         bool use_mask = false;
         FluxRunner(ggml_backend_t backend,
-                   std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
-                   const std::string prefix = "",
-                   SDVersion version = VERSION_FLUX,
-                   bool flash_attn = false,
-                   bool use_mask = false)
+                   const String2GGMLType& tensor_types = {},
+                   const std::string prefix = "",
+                   SDVersion version = VERSION_FLUX,
+                   bool flash_attn = false,
+                   bool use_mask = false)
             : GGMLRunner(backend), use_mask(use_mask) {
             flux_params.flash_attn     = flash_attn;
             flux_params.guidance_embed = false;

ggml

@@ -1 +1 @@
-Subproject commit 56938c4a3b2d923f42040f9ad32d229c76c466cd
+Subproject commit b96890f3ab5ffbdbe56bc126df5366c34bd08d39

ggml_extend.hpp

@@ -841,21 +841,19 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     float scale = (1.0f / sqrt((float)d_head));
     int kv_pad  = 0;
-    //if (flash_attn) {
-    //    LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
-    //}
-    // is there anything oddly shaped?? ping Green-Sky if you can trip this assert
+    // if (flash_attn) {
+    //     LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
+    // }
+    // is there anything oddly shaped?? ping Green-Sky if you can trip this assert
     GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));
     bool can_use_flash_attn = true;
-    can_use_flash_attn = can_use_flash_attn && (
-        d_head == 64 ||
-        d_head == 80 ||
-        d_head == 96 ||
-        d_head == 112 ||
-        d_head == 128 ||
-        d_head == 256
-    );
+    can_use_flash_attn = can_use_flash_attn && (d_head == 64 ||
+                                                d_head == 80 ||
+                                                d_head == 96 ||
+                                                d_head == 112 ||
+                                                d_head == 128 ||
+                                                d_head == 256);
 #if 0
     can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0;
 #else
@@ -880,9 +878,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     ggml_tensor* kqv = nullptr;
     // GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn);
     if (can_use_flash_attn && flash_attn) {
-        //LOG_DEBUG(" uses flash attention");
+        // LOG_DEBUG(" uses flash attention");
         if (kv_pad != 0) {
-            //LOG_DEBUG(" padding k and v dim1 by %d", kv_pad);
+            // LOG_DEBUG(" padding k and v dim1 by %d", kv_pad);
             k = ggml_pad(ctx, k, 0, kv_pad, 0, 0);
         }
         k = ggml_cast(ctx, k, GGML_TYPE_F16);
@@ -1099,6 +1097,8 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {
 #define MAX_PARAMS_TENSOR_NUM 32768
 #define MAX_GRAPH_SIZE 32768
+typedef std::map<std::string, enum ggml_type> String2GGMLType;
+
 struct GGMLRunner {
 protected:
     typedef std::function<struct ggml_cgraph*()> get_graph_cb_t;
@@ -1310,17 +1310,25 @@ protected:
     GGMLBlockMap blocks;
     ParameterMap params;
-    void init_blocks(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
+    ggml_type get_type(const std::string& name, const String2GGMLType& tensor_types, ggml_type default_type) {
+        auto iter = tensor_types.find(name);
+        if (iter != tensor_types.end()) {
+            return iter->second;
+        }
+        return default_type;
+    }
+
+    void init_blocks(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
         for (auto& pair : blocks) {
             auto& block = pair.second;
             block->init(ctx, tensor_types, prefix + pair.first);
         }
     }
-    virtual void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {}
+    virtual void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {}
 public:
-    void init(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, std::string prefix = "") {
+    void init(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
         if (prefix.size() > 0) {
             prefix = prefix + ".";
         }
@@ -1381,8 +1389,8 @@ protected:
     bool bias;
     bool force_f32;
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
-        enum ggml_type wtype = (tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F32;
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
+        enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
         if (in_features % ggml_blck_size(wtype) != 0 || force_f32) {
             wtype = GGML_TYPE_F32;
         }
@@ -1417,8 +1425,8 @@ class Embedding : public UnaryBlock {
 protected:
     int64_t embedding_dim;
     int64_t num_embeddings;
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
-        enum ggml_type wtype = (tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F32;
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
+        enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
         params["weight"] = ggml_new_tensor_2d(ctx, wtype, embedding_dim, num_embeddings);
     }
@@ -1457,11 +1465,11 @@ protected:
     std::pair<int, int> dilation;
     bool bias;
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
-        enum ggml_type wtype = GGML_TYPE_F16;  //(tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F16;
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
+        enum ggml_type wtype = GGML_TYPE_F16;
         params["weight"] = ggml_new_tensor_4d(ctx, wtype, kernel_size.second, kernel_size.first, in_channels, out_channels);
         if (bias) {
-            enum ggml_type wtype = GGML_TYPE_F32;  // (tensor_types.find(prefix + "bias") != tensor_types.end()) ? tensor_types[prefix + "bias"] : GGML_TYPE_F32;
+            enum ggml_type wtype = GGML_TYPE_F32;
             params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_channels);
         }
     }
@@ -1502,11 +1510,11 @@ protected:
     int64_t dilation;
     bool bias;
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
-        enum ggml_type wtype = GGML_TYPE_F16;  //(tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F16;
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
+        enum ggml_type wtype = GGML_TYPE_F16;
         params["weight"] = ggml_new_tensor_4d(ctx, wtype, 1, kernel_size, in_channels, out_channels);  // 5d => 4d
         if (bias) {
-            enum ggml_type wtype = GGML_TYPE_F32;  //(tensor_types.find(prefix + "bias") != tensor_types.end()) ? tensor_types[prefix + "bias"] : GGML_TYPE_F32;
+            enum ggml_type wtype = GGML_TYPE_F32;
             params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_channels);
         }
     }
@@ -1546,12 +1554,12 @@ protected:
     bool elementwise_affine;
     bool bias;
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
         if (elementwise_affine) {
-            enum ggml_type wtype = GGML_TYPE_F32;  //(tensor_types.ypes.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F32;
+            enum ggml_type wtype = GGML_TYPE_F32;
             params["weight"] = ggml_new_tensor_1d(ctx, wtype, normalized_shape);
             if (bias) {
-                enum ggml_type wtype = GGML_TYPE_F32;  //(tensor_types.ypes.find(prefix + "bias") != tensor_types.end()) ? tensor_types[prefix + "bias"] : GGML_TYPE_F32;
+                enum ggml_type wtype = GGML_TYPE_F32;
                 params["bias"] = ggml_new_tensor_1d(ctx, wtype, normalized_shape);
             }
         }
@@ -1588,10 +1596,10 @@ protected:
     float eps;
     bool affine;
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
         if (affine) {
-            enum ggml_type wtype      = GGML_TYPE_F32;  //(tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F32;
-            enum ggml_type bias_wtype = GGML_TYPE_F32;  //(tensor_types.find(prefix + "bias") != tensor_types.end()) ? tensor_types[prefix + "bias"] : GGML_TYPE_F32;
+            enum ggml_type wtype      = GGML_TYPE_F32;
+            enum ggml_type bias_wtype = GGML_TYPE_F32;
             params["weight"] = ggml_new_tensor_1d(ctx, wtype, num_channels);
             params["bias"]   = ggml_new_tensor_1d(ctx, bias_wtype, num_channels);
         }
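The core of the refactor is the String2GGMLType typedef and the get_type() helper introduced above, which replace the repeated find-then-index lookups, while the `= {}` defaults make the old `static ... empty_tensor_types` members obsolete. A self-contained sketch of the lookup pattern (the stand-in ggml_type enum is ours, just to keep the sketch compilable; the real one comes from ggml.h):

    #include <map>
    #include <string>

    // Stand-in for the real enum in ggml.h.
    enum ggml_type { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q8_0 };
    typedef std::map<std::string, enum ggml_type> String2GGMLType;

    // Same logic as GGMLBlock::get_type in the diff: look the tensor name up,
    // fall back to a caller-supplied default when it is absent.
    static ggml_type get_type(const std::string& name,
                              const String2GGMLType& tensor_types,
                              ggml_type default_type) {
        auto iter = tensor_types.find(name);
        return iter != tensor_types.end() ? iter->second : default_type;
    }

    int main() {
        String2GGMLType types = {{"model.weight", GGML_TYPE_Q8_0}};
        ggml_type a = get_type("model.weight", types, GGML_TYPE_F32);  // Q8_0: found in the map
        ggml_type b = get_type("model.bias", types, GGML_TYPE_F32);    // F32: falls back to the default
        ggml_type c = get_type("model.bias", {}, GGML_TYPE_F16);       // F16: an empty temporary map binds to the const ref
        return a == GGML_TYPE_Q8_0 && b == GGML_TYPE_F32 && c == GGML_TYPE_F16 ? 0 : 1;
    }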

mmdit.hpp

@@ -147,8 +147,8 @@ protected:
     int64_t hidden_size;
     float eps;
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, std::string prefix = "") {
-        enum ggml_type wtype = GGML_TYPE_F32;  //(tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F32;
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
+        enum ggml_type wtype = GGML_TYPE_F32;
         params["weight"] = ggml_new_tensor_1d(ctx, wtype, hidden_size);
     }
@@ -652,13 +652,13 @@ protected:
     int64_t hidden_size;
     std::string qk_norm;
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, std::string prefix = "") {
-        enum ggml_type wtype = GGML_TYPE_F32;  //(tensor_types.find(prefix + "pos_embed") != tensor_types.end()) ? tensor_types[prefix + "pos_embed"] : GGML_TYPE_F32;
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
+        enum ggml_type wtype = GGML_TYPE_F32;
         params["pos_embed"] = ggml_new_tensor_3d(ctx, wtype, hidden_size, num_patchs, 1);
     }
 public:
-    MMDiT(std::map<std::string, enum ggml_type>& tensor_types) {
+    MMDiT(const String2GGMLType& tensor_types = {}) {
         // input_size is always None
         // learn_sigma is always False
         // register_length is alwalys 0
@@ -869,11 +869,9 @@ public:
 struct MMDiTRunner : public GGMLRunner {
     MMDiT mmdit;
-    static std::map<std::string, enum ggml_type> empty_tensor_types;
-
     MMDiTRunner(ggml_backend_t backend,
-                std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
-                const std::string prefix = "")
+                const String2GGMLType& tensor_types = {},
+                const std::string prefix = "")
         : GGMLRunner(backend), mmdit(tensor_types) {
         mmdit.init(params_ctx, tensor_types, prefix);
     }

model.cpp

@@ -648,7 +648,7 @@ std::string convert_tensor_name(std::string name) {
     return new_name;
 }
-void add_preprocess_tensor_storage_types(std::map<std::string, enum ggml_type>& tensor_storages_types, std::string name, enum ggml_type type) {
+void add_preprocess_tensor_storage_types(String2GGMLType& tensor_storages_types, std::string name, enum ggml_type type) {
     std::string new_name = convert_tensor_name(name);
     if (new_name.find("cond_stage_model") != std::string::npos && ends_with(new_name, "attn.in_proj_weight")) {

model.h

@@ -207,6 +207,8 @@ struct TensorStorage {
 typedef std::function<bool(const TensorStorage&, ggml_tensor**)> on_new_tensor_cb_t;
+typedef std::map<std::string, enum ggml_type> String2GGMLType;
+
 class ModelLoader {
 protected:
     std::vector<std::string> file_paths_;
@@ -225,7 +227,7 @@ protected:
     bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = "");
 public:
-    std::map<std::string, enum ggml_type> tensor_storages_types;
+    String2GGMLType tensor_storages_types;
     bool init_from_file(const std::string& file_path, const std::string& prefix = "");
     bool model_is_unet();

pmid.hpp

@@ -623,7 +623,12 @@ public:
     std::vector<float> zeros_right;
 public:
-    PhotoMakerIDEncoder(ggml_backend_t backend, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix, SDVersion version = VERSION_SDXL, PMVersion pm_v = PM_VERSION_1, float sty = 20.f)
+    PhotoMakerIDEncoder(ggml_backend_t backend,
+                        const String2GGMLType& tensor_types,
+                        const std::string prefix,
+                        SDVersion version = VERSION_SDXL,
+                        PMVersion pm_v = PM_VERSION_1,
+                        float sty = 20.f)
         : GGMLRunner(backend),
           version(version),
           pm_version(pm_v),

stable-diffusion.cpp

@@ -1875,6 +1875,15 @@ ggml_tensor* generate_init_latent(sd_ctx_t* sd_ctx,
 sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params) {
     int width  = sd_img_gen_params->width;
     int height = sd_img_gen_params->height;
+    if (sd_version_is_dit(sd_ctx->sd->version)) {
+        if (width % 16 || height % 16) {
+            LOG_ERROR("Image dimensions must be a multiple of 16 on each axis for %s models. (Got %dx%d)", model_version_to_str[sd_ctx->sd->version], width, height);
+            return NULL;
+        }
+    } else if (width % 64 || height % 64) {
+        LOG_ERROR("Image dimensions must be a multiple of 64 on each axis for %s models. (Got %dx%d)", model_version_to_str[sd_ctx->sd->version], width, height);
+        return NULL;
+    }
     LOG_DEBUG("generate_image %dx%d", width, height);
     if (sd_ctx == NULL || sd_img_gen_params == NULL) {
         return NULL;
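This is why the CLI check earlier in the diff was relaxed to a simple greater-than-zero test: the alignment requirement now lives here and depends on the model family. The rule in isolation (dims_ok is an illustrative helper written for this sketch, not part of the patch):

    #include <cstdio>

    // DiT-family models only need 16-pixel alignment;
    // other (UNet-based) models keep the original 64-pixel requirement.
    static bool dims_ok(bool is_dit, int width, int height) {
        const int align = is_dit ? 16 : 64;
        return width > 0 && height > 0 && width % align == 0 && height % align == 0;
    }

    int main() {
        std::printf("%d\n", dims_ok(true, 1008, 1008));   // 1: 1008 = 63 * 16
        std::printf("%d\n", dims_ok(false, 1008, 1008));  // 0: 1008 is not a multiple of 64
        return 0;
    }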

t5.hpp

@@ -457,8 +457,8 @@ protected:
     int64_t hidden_size;
     float eps;
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
-        enum ggml_type wtype = GGML_TYPE_F32;  //(tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F32;
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
+        enum ggml_type wtype = GGML_TYPE_F32;
         params["weight"] = ggml_new_tensor_1d(ctx, wtype, hidden_size);
     }
@@ -735,7 +735,7 @@ struct T5Runner : public GGMLRunner {
     std::vector<int> relative_position_bucket_vec;
     T5Runner(ggml_backend_t backend,
-             std::map<std::string, enum ggml_type>& tensor_types,
+             const String2GGMLType& tensor_types,
              const std::string prefix,
              int64_t num_layers = 24,
             int64_t model_dim = 4096,
@@ -876,16 +876,14 @@ struct T5Embedder {
     T5UniGramTokenizer tokenizer;
     T5Runner model;
-    static std::map<std::string, enum ggml_type> empty_tensor_types;
-
     T5Embedder(ggml_backend_t backend,
-               std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
-               const std::string prefix = "",
-               int64_t num_layers = 24,
-               int64_t model_dim = 4096,
-               int64_t ff_dim = 10240,
-               int64_t num_heads = 64,
-               int64_t vocab_size = 32128)
+               const String2GGMLType& tensor_types = {},
+               const std::string prefix = "",
+               int64_t num_layers = 24,
+               int64_t model_dim = 4096,
+               int64_t ff_dim = 10240,
+               int64_t num_heads = 64,
+               int64_t vocab_size = 32128)
         : model(backend, tensor_types, prefix, num_layers, model_dim, ff_dim, num_heads, vocab_size) {
     }

tae.hpp

@@ -196,7 +196,7 @@ struct TinyAutoEncoder : public GGMLRunner {
     bool decode_only = false;
     TinyAutoEncoder(ggml_backend_t backend,
-                    std::map<std::string, enum ggml_type>& tensor_types,
+                    const String2GGMLType& tensor_types,
                     const std::string prefix,
                     bool decoder_only = true,
                     SDVersion version = VERSION_SD1)

unet.hpp

@@ -166,7 +166,6 @@ public:
 // ldm.modules.diffusionmodules.openaimodel.UNetModel
 class UnetModelBlock : public GGMLBlock {
 protected:
-    static std::map<std::string, enum ggml_type> empty_tensor_types;
     SDVersion version = VERSION_SD1;
     // network hparams
     int in_channels = 4;
@@ -184,7 +183,7 @@ public:
     int model_channels  = 320;
     int adm_in_channels = 2816;  // only for VERSION_SDXL/SVD
-    UnetModelBlock(SDVersion version = VERSION_SD1, std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types, bool flash_attn = false)
+    UnetModelBlock(SDVersion version = VERSION_SD1, const String2GGMLType& tensor_types = {}, bool flash_attn = false)
         : version(version) {
         if (sd_version_is_sd2(version)) {
             context_dim = 1024;
@@ -539,7 +538,7 @@ struct UNetModelRunner : public GGMLRunner {
     UnetModelBlock unet;
     UNetModelRunner(ggml_backend_t backend,
-                    std::map<std::string, enum ggml_type>& tensor_types,
+                    const String2GGMLType& tensor_types,
                     const std::string prefix,
                     SDVersion version = VERSION_SD1,
                     bool flash_attn = false)

vae.hpp

@@ -163,8 +163,8 @@ public:
 class VideoResnetBlock : public ResnetBlock {
 protected:
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
-        enum ggml_type wtype = (tensor_types.find(prefix + "mix_factor") != tensor_types.end()) ? tensor_types[prefix + "mix_factor"] : GGML_TYPE_F32;
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
+        enum ggml_type wtype = get_type(prefix + "mix_factor", tensor_types, GGML_TYPE_F32);
         params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1);
     }
@@ -525,7 +525,7 @@ struct AutoEncoderKL : public GGMLRunner {
     AutoencodingEngine ae;
     AutoEncoderKL(ggml_backend_t backend,
-                  std::map<std::string, enum ggml_type>& tensor_types,
+                  const String2GGMLType& tensor_types,
                   const std::string prefix,
                   bool decode_only = false,
                   bool use_video_decoder = false,