mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-03-24 02:08:51 +00:00
Merge branch 'master' into set_hw_from_image
This commit is contained in:
commit
6c634edd4c
@ -15,6 +15,9 @@ API and command-line option may change frequently.***
|
||||
|
||||
## 🔥Important News
|
||||
|
||||
* **2026/01/18** 🚀 stable-diffusion.cpp now supports **FLUX.2-klein**
|
||||
👉 Details: [PR #1193](https://github.com/leejet/stable-diffusion.cpp/pull/1193)
|
||||
|
||||
* **2025/12/01** 🚀 stable-diffusion.cpp now supports **Z-Image**
|
||||
👉 Details: [PR #1020](https://github.com/leejet/stable-diffusion.cpp/pull/1020)
|
||||
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
# Running distilled models: SSD1B and SDx.x with tiny U-Nets
|
||||
# Running distilled models: SSD1B, Vega and SDx.x with tiny U-Nets
|
||||
|
||||
## Preface
|
||||
|
||||
These models feature a reduced U-Net architecture. Unlike standard SDXL models, the SSD-1B U-Net contains only one middle block and fewer attention layers in its up- and down-blocks, resulting in significantly smaller file sizes. Using these models can reduce inference time by more than 33%. For more details, refer to Segmind's paper: https://arxiv.org/abs/2401.02677v1.
|
||||
These models feature a reduced U-Net architecture. Unlike standard SDXL models, the SSD-1B and Vega U-Net contains only one middle block and fewer attention layers in its up- and down-blocks, resulting in significantly smaller file sizes. Using these models can reduce inference time by more than 33%. For more details, refer to Segmind's paper: https://arxiv.org/abs/2401.02677v1.
|
||||
Similarly, SD1.x- and SD2.x-style models with a tiny U-Net consist of only 6 U-Net blocks, leading to very small files and time savings of up to 50%. For more information, see the paper: https://arxiv.org/pdf/2305.15798.pdf.
|
||||
|
||||
## SSD1B
|
||||
@ -17,7 +17,17 @@ Useful LoRAs are also available:
|
||||
* https://huggingface.co/seungminh/lora-swarovski-SSD-1B/resolve/main/pytorch_lora_weights.safetensors
|
||||
* https://huggingface.co/kylielee505/mylcmlorassd/resolve/main/pytorch_lora_weights.safetensors
|
||||
|
||||
These files can be used out-of-the-box, unlike the models described in the next section.
|
||||
## Vega
|
||||
|
||||
Segmind's Vega model is available online here:
|
||||
|
||||
* https://huggingface.co/segmind/Segmind-Vega/resolve/main/segmind-vega.safetensors
|
||||
|
||||
VegaRT is an example for an LCM-LoRA:
|
||||
|
||||
* https://huggingface.co/segmind/Segmind-VegaRT/resolve/main/pytorch_lora_weights.safetensors
|
||||
|
||||
Both files can be used out-of-the-box, unlike the models described in next sections.
|
||||
|
||||
|
||||
## SD1.x, SD2.x with tiny U-Nets
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
## Using ESRGAN to upscale results
|
||||
|
||||
You can use ESRGAN to upscale the generated images. At the moment, only the [RealESRGAN_x4plus_anime_6B.pth](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth) model is supported. Support for more models of this architecture will be added soon.
|
||||
You can use ESRGAN—such as the model [RealESRGAN_x4plus_anime_6B.pth](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth)—to upscale the generated images and improve their overall resolution and clarity.
|
||||
|
||||
- Specify the model path using the `--upscale-model PATH` parameter. example:
|
||||
|
||||
|
||||
@ -86,21 +86,6 @@ std::vector<uint8_t> base64_decode(const std::string& encoded_string) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::string iso_timestamp_now() {
|
||||
using namespace std::chrono;
|
||||
auto now = system_clock::now();
|
||||
std::time_t t = system_clock::to_time_t(now);
|
||||
std::tm tm{};
|
||||
#ifdef _MSC_VER
|
||||
gmtime_s(&tm, &t);
|
||||
#else
|
||||
gmtime_r(&t, &tm);
|
||||
#endif
|
||||
std::ostringstream oss;
|
||||
oss << std::put_time(&tm, "%Y-%m-%dT%H:%M:%SZ");
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
struct SDSvrParams {
|
||||
std::string listen_ip = "127.0.0.1";
|
||||
int listen_port = 1234;
|
||||
@ -404,7 +389,7 @@ int main(int argc, const char** argv) {
|
||||
}
|
||||
|
||||
json out;
|
||||
out["created"] = iso_timestamp_now();
|
||||
out["created"] = static_cast<long long>(std::time(nullptr));
|
||||
out["data"] = json::array();
|
||||
out["output_format"] = output_format;
|
||||
|
||||
@ -692,7 +677,7 @@ int main(int argc, const char** argv) {
|
||||
}
|
||||
|
||||
json out;
|
||||
out["created"] = iso_timestamp_now();
|
||||
out["created"] = static_cast<long long>(std::time(nullptr));
|
||||
out["data"] = json::array();
|
||||
out["output_format"] = output_format;
|
||||
|
||||
@ -732,6 +717,327 @@ int main(int argc, const char** argv) {
|
||||
}
|
||||
});
|
||||
|
||||
// sdapi endpoints (AUTOMATIC1111 / Forge)
|
||||
|
||||
auto sdapi_any2img = [&](const httplib::Request& req, httplib::Response& res, bool img2img) {
|
||||
try {
|
||||
if (req.body.empty()) {
|
||||
res.status = 400;
|
||||
res.set_content(R"({"error":"empty body"})", "application/json");
|
||||
return;
|
||||
}
|
||||
|
||||
json j = json::parse(req.body);
|
||||
|
||||
std::string prompt = j.value("prompt", "");
|
||||
std::string negative_prompt = j.value("negative_prompt", "");
|
||||
int width = j.value("width", 512);
|
||||
int height = j.value("height", 512);
|
||||
int steps = j.value("steps", -1);
|
||||
float cfg_scale = j.value("cfg_scale", 7.f);
|
||||
int64_t seed = j.value("seed", -1);
|
||||
int batch_size = j.value("batch_size", 1);
|
||||
int clip_skip = j.value("clip_skip", -1);
|
||||
std::string sampler_name = j.value("sampler_name", "");
|
||||
std::string scheduler_name = j.value("scheduler", "");
|
||||
|
||||
auto bad = [&](const std::string& msg) {
|
||||
res.status = 400;
|
||||
res.set_content("{\"error\":\"" + msg + "\"}", "application/json");
|
||||
return;
|
||||
};
|
||||
|
||||
if (width <= 0 || height <= 0) {
|
||||
return bad("width and height must be positive");
|
||||
}
|
||||
|
||||
if (steps < 1 || steps > 150) {
|
||||
return bad("steps must be in range [1, 150]");
|
||||
}
|
||||
|
||||
if (batch_size < 1 || batch_size > 8) {
|
||||
return bad("batch_size must be in range [1, 8]");
|
||||
}
|
||||
|
||||
if (cfg_scale < 0.f) {
|
||||
return bad("cfg_scale must be positive");
|
||||
}
|
||||
|
||||
if (prompt.empty()) {
|
||||
return bad("prompt required");
|
||||
}
|
||||
|
||||
auto get_sample_method = [](std::string name) -> enum sample_method_t {
|
||||
enum sample_method_t result = str_to_sample_method(name.c_str());
|
||||
if (result != SAMPLE_METHOD_COUNT) return result;
|
||||
// some applications use a hardcoded sampler list
|
||||
std::transform(name.begin(), name.end(), name.begin(),
|
||||
[](unsigned char c) { return std::tolower(c); });
|
||||
static const std::unordered_map<std::string_view, sample_method_t> hardcoded{
|
||||
{"euler a", EULER_A_SAMPLE_METHOD},
|
||||
{"k_euler_a", EULER_A_SAMPLE_METHOD},
|
||||
{"euler", EULER_SAMPLE_METHOD},
|
||||
{"k_euler", EULER_SAMPLE_METHOD},
|
||||
{"heun", HEUN_SAMPLE_METHOD},
|
||||
{"k_heun", HEUN_SAMPLE_METHOD},
|
||||
{"dpm2", DPM2_SAMPLE_METHOD},
|
||||
{"k_dpm_2", DPM2_SAMPLE_METHOD},
|
||||
{"lcm", LCM_SAMPLE_METHOD},
|
||||
{"ddim", DDIM_TRAILING_SAMPLE_METHOD},
|
||||
{"dpm++ 2m", DPMPP2M_SAMPLE_METHOD},
|
||||
{"k_dpmpp_2m", DPMPP2M_SAMPLE_METHOD}};
|
||||
auto it = hardcoded.find(name);
|
||||
if (it != hardcoded.end()) return it->second;
|
||||
return SAMPLE_METHOD_COUNT;
|
||||
};
|
||||
|
||||
enum sample_method_t sample_method = get_sample_method(sampler_name);
|
||||
|
||||
enum scheduler_t scheduler = str_to_scheduler(scheduler_name.c_str());
|
||||
|
||||
// avoid excessive resource usage
|
||||
|
||||
SDGenerationParams gen_params = default_gen_params;
|
||||
gen_params.prompt = prompt;
|
||||
gen_params.negative_prompt = negative_prompt;
|
||||
gen_params.width = width;
|
||||
gen_params.height = height;
|
||||
gen_params.seed = seed;
|
||||
gen_params.sample_params.sample_steps = steps;
|
||||
gen_params.batch_count = batch_size;
|
||||
|
||||
if (clip_skip > 0) {
|
||||
gen_params.clip_skip = clip_skip;
|
||||
}
|
||||
|
||||
if (sample_method != SAMPLE_METHOD_COUNT) {
|
||||
gen_params.sample_params.sample_method = sample_method;
|
||||
}
|
||||
|
||||
if (scheduler != SCHEDULER_COUNT) {
|
||||
gen_params.sample_params.scheduler = scheduler;
|
||||
}
|
||||
|
||||
LOG_DEBUG("%s\n", gen_params.to_string().c_str());
|
||||
|
||||
sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
|
||||
sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
|
||||
sd_image_t mask_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 1, nullptr};
|
||||
std::vector<uint8_t> mask_data;
|
||||
std::vector<sd_image_t> pmid_images;
|
||||
std::vector<sd_image_t> ref_images;
|
||||
|
||||
if (img2img) {
|
||||
auto decode_image = [](sd_image_t& image, std::string encoded) -> bool {
|
||||
// remove data URI prefix if present ("data:image/png;base64,")
|
||||
auto comma_pos = encoded.find(',');
|
||||
if (comma_pos != std::string::npos) {
|
||||
encoded = encoded.substr(comma_pos + 1);
|
||||
}
|
||||
std::vector<uint8_t> img_data = base64_decode(encoded);
|
||||
if (!img_data.empty()) {
|
||||
int img_w = image.width;
|
||||
int img_h = image.height;
|
||||
uint8_t* raw_data = load_image_from_memory(
|
||||
(const char*)img_data.data(), (int)img_data.size(),
|
||||
img_w, img_h,
|
||||
image.width, image.height, image.channel);
|
||||
if (raw_data) {
|
||||
image = {(uint32_t)img_w, (uint32_t)img_h, image.channel, raw_data};
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
if (j.contains("init_images") && j["init_images"].is_array() && !j["init_images"].empty()) {
|
||||
std::string encoded = j["init_images"][0].get<std::string>();
|
||||
decode_image(init_image, encoded);
|
||||
}
|
||||
|
||||
if (j.contains("mask") && j["mask"].is_string()) {
|
||||
std::string encoded = j["mask"].get<std::string>();
|
||||
decode_image(mask_image, encoded);
|
||||
bool inpainting_mask_invert = j.value("inpainting_mask_invert", 0) != 0;
|
||||
if (inpainting_mask_invert && mask_image.data != nullptr) {
|
||||
for (uint32_t i = 0; i < mask_image.width * mask_image.height; i++) {
|
||||
mask_image.data[i] = 255 - mask_image.data[i];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mask_data = std::vector<uint8_t>(width * height, 255);
|
||||
mask_image.width = width;
|
||||
mask_image.height = height;
|
||||
mask_image.channel = 1;
|
||||
mask_image.data = mask_data.data();
|
||||
}
|
||||
|
||||
if (j.contains("extra_images") && j["extra_images"].is_array()) {
|
||||
for (auto extra_image : j["extra_images"]) {
|
||||
std::string encoded = extra_image.get<std::string>();
|
||||
sd_image_t tmp_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
|
||||
if (decode_image(tmp_image, encoded)) {
|
||||
ref_images.push_back(tmp_image);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float denoising_strength = j.value("denoising_strength", -1.f);
|
||||
if (denoising_strength >= 0.f) {
|
||||
denoising_strength = std::min(denoising_strength, 1.0f);
|
||||
gen_params.strength = denoising_strength;
|
||||
}
|
||||
}
|
||||
|
||||
sd_img_gen_params_t img_gen_params = {
|
||||
gen_params.lora_vec.data(),
|
||||
static_cast<uint32_t>(gen_params.lora_vec.size()),
|
||||
gen_params.prompt.c_str(),
|
||||
gen_params.negative_prompt.c_str(),
|
||||
gen_params.clip_skip,
|
||||
init_image,
|
||||
ref_images.data(),
|
||||
(int)ref_images.size(),
|
||||
gen_params.auto_resize_ref_image,
|
||||
gen_params.increase_ref_index,
|
||||
mask_image,
|
||||
gen_params.width,
|
||||
gen_params.height,
|
||||
gen_params.sample_params,
|
||||
gen_params.strength,
|
||||
gen_params.seed,
|
||||
gen_params.batch_count,
|
||||
control_image,
|
||||
gen_params.control_strength,
|
||||
{
|
||||
pmid_images.data(),
|
||||
(int)pmid_images.size(),
|
||||
gen_params.pm_id_embed_path.c_str(),
|
||||
gen_params.pm_style_strength,
|
||||
}, // pm_params
|
||||
ctx_params.vae_tiling_params,
|
||||
gen_params.cache_params,
|
||||
};
|
||||
|
||||
sd_image_t* results = nullptr;
|
||||
int num_results = 0;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(sd_ctx_mutex);
|
||||
results = generate_image(sd_ctx, &img_gen_params);
|
||||
num_results = gen_params.batch_count;
|
||||
}
|
||||
|
||||
json out;
|
||||
out["images"] = json::array();
|
||||
out["parameters"] = j; // TODO should return changed defaults
|
||||
out["info"] = "";
|
||||
|
||||
for (int i = 0; i < num_results; i++) {
|
||||
if (results[i].data == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto image_bytes = write_image_to_vector(ImageFormat::PNG,
|
||||
results[i].data,
|
||||
results[i].width,
|
||||
results[i].height,
|
||||
results[i].channel);
|
||||
|
||||
if (image_bytes.empty()) {
|
||||
LOG_ERROR("write image to mem failed");
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string b64 = base64_encode(image_bytes);
|
||||
out["images"].push_back(b64);
|
||||
}
|
||||
|
||||
res.set_content(out.dump(), "application/json");
|
||||
res.status = 200;
|
||||
|
||||
if (init_image.data) {
|
||||
stbi_image_free(init_image.data);
|
||||
}
|
||||
if (mask_image.data && mask_data.empty()) {
|
||||
stbi_image_free(mask_image.data);
|
||||
}
|
||||
for (auto ref_image : ref_images) {
|
||||
stbi_image_free(ref_image.data);
|
||||
}
|
||||
|
||||
} catch (const std::exception& e) {
|
||||
res.status = 500;
|
||||
json err;
|
||||
err["error"] = "server_error";
|
||||
err["message"] = e.what();
|
||||
res.set_content(err.dump(), "application/json");
|
||||
}
|
||||
};
|
||||
|
||||
svr.Post("/sdapi/v1/txt2img", [&](const httplib::Request& req, httplib::Response& res) {
|
||||
sdapi_any2img(req, res, false);
|
||||
});
|
||||
|
||||
svr.Post("/sdapi/v1/img2img", [&](const httplib::Request& req, httplib::Response& res) {
|
||||
sdapi_any2img(req, res, true);
|
||||
});
|
||||
|
||||
svr.Get("/sdapi/v1/samplers", [&](const httplib::Request&, httplib::Response& res) {
|
||||
std::vector<std::string> sampler_names;
|
||||
sampler_names.push_back("default");
|
||||
for (int i = 0; i < SAMPLE_METHOD_COUNT; i++) {
|
||||
sampler_names.push_back(sd_sample_method_name((sample_method_t)i));
|
||||
}
|
||||
json r = json::array();
|
||||
for (auto name : sampler_names) {
|
||||
json entry;
|
||||
entry["name"] = name;
|
||||
entry["aliases"] = json::array({name});
|
||||
entry["options"] = json::object();
|
||||
r.push_back(entry);
|
||||
}
|
||||
res.set_content(r.dump(), "application/json");
|
||||
});
|
||||
|
||||
svr.Get("/sdapi/v1/schedulers", [&](const httplib::Request&, httplib::Response& res) {
|
||||
std::vector<std::string> scheduler_names;
|
||||
scheduler_names.push_back("default");
|
||||
for (int i = 0; i < SCHEDULER_COUNT; i++) {
|
||||
scheduler_names.push_back(sd_scheduler_name((scheduler_t)i));
|
||||
}
|
||||
json r = json::array();
|
||||
for (auto name : scheduler_names) {
|
||||
json entry;
|
||||
entry["name"] = name;
|
||||
entry["label"] = name;
|
||||
r.push_back(entry);
|
||||
}
|
||||
res.set_content(r.dump(), "application/json");
|
||||
});
|
||||
|
||||
svr.Get("/sdapi/v1/sd-models", [&](const httplib::Request&, httplib::Response& res) {
|
||||
fs::path model_path = ctx_params.model_path;
|
||||
json entry;
|
||||
entry["title"] = model_path.stem();
|
||||
entry["model_name"] = model_path.stem();
|
||||
entry["filename"] = model_path.filename();
|
||||
entry["hash"] = "8888888888";
|
||||
entry["sha256"] = "8888888888888888888888888888888888888888888888888888888888888888";
|
||||
entry["config"] = nullptr;
|
||||
json r = json::array();
|
||||
r.push_back(entry);
|
||||
res.set_content(r.dump(), "application/json");
|
||||
});
|
||||
|
||||
svr.Get("/sdapi/v1/options", [&](const httplib::Request&, httplib::Response& res) {
|
||||
fs::path model_path = ctx_params.model_path;
|
||||
json r;
|
||||
r["samples_format"] = "png";
|
||||
r["sd_model_checkpoint"] = model_path.stem();
|
||||
res.set_content(r.dump(), "application/json");
|
||||
});
|
||||
|
||||
LOG_INFO("listening on: %s:%d\n", svr_params.listen_ip.c_str(), svr_params.listen_port);
|
||||
svr.listen(svr_params.listen_ip, svr_params.listen_port);
|
||||
|
||||
|
||||
28
flux.hpp
28
flux.hpp
@ -748,7 +748,7 @@ namespace Flux {
|
||||
int nerf_depth = 4;
|
||||
int nerf_max_freqs = 8;
|
||||
bool use_x0 = false;
|
||||
bool use_patch_size_32 = false;
|
||||
bool fake_patch_size_x2 = false;
|
||||
};
|
||||
|
||||
struct FluxParams {
|
||||
@ -786,8 +786,11 @@ namespace Flux {
|
||||
Flux(FluxParams params)
|
||||
: params(params) {
|
||||
if (params.version == VERSION_CHROMA_RADIANCE) {
|
||||
std::pair<int, int> kernel_size = {16, 16};
|
||||
std::pair<int, int> stride = kernel_size;
|
||||
std::pair<int, int> kernel_size = {params.patch_size, params.patch_size};
|
||||
if (params.chroma_radiance_params.fake_patch_size_x2) {
|
||||
kernel_size = {params.patch_size / 2, params.patch_size / 2};
|
||||
}
|
||||
std::pair<int, int> stride = kernel_size;
|
||||
|
||||
blocks["img_in_patch"] = std::make_shared<Conv2d>(params.in_channels,
|
||||
params.hidden_size,
|
||||
@ -1082,7 +1085,7 @@ namespace Flux {
|
||||
auto img = pad_to_patch_size(ctx, x);
|
||||
auto orig_img = img;
|
||||
|
||||
if (params.chroma_radiance_params.use_patch_size_32) {
|
||||
if (params.chroma_radiance_params.fake_patch_size_x2) {
|
||||
// It's supposed to be using GGML_SCALE_MODE_NEAREST, but this seems more stable
|
||||
// Maybe the implementation of nearest-neighbor interpolation in ggml behaves differently than the one in PyTorch?
|
||||
// img = F.interpolate(img, size=(H//2, W//2), mode="nearest")
|
||||
@ -1303,7 +1306,8 @@ namespace Flux {
|
||||
flux_params.ref_index_scale = 10.f;
|
||||
flux_params.use_mlp_silu_act = true;
|
||||
}
|
||||
int64_t head_dim = 0;
|
||||
int64_t head_dim = 0;
|
||||
int64_t actual_radiance_patch_size = -1;
|
||||
for (auto pair : tensor_storage_map) {
|
||||
std::string tensor_name = pair.first;
|
||||
if (!starts_with(tensor_name, prefix))
|
||||
@ -1316,9 +1320,12 @@ namespace Flux {
|
||||
flux_params.chroma_radiance_params.use_x0 = true;
|
||||
}
|
||||
if (tensor_name.find("__32x32__") != std::string::npos) {
|
||||
LOG_DEBUG("using patch size 32 prediction");
|
||||
flux_params.chroma_radiance_params.use_patch_size_32 = true;
|
||||
flux_params.patch_size = 32;
|
||||
LOG_DEBUG("using patch size 32");
|
||||
flux_params.patch_size = 32;
|
||||
}
|
||||
if (tensor_name.find("img_in_patch.weight") != std::string::npos) {
|
||||
actual_radiance_patch_size = pair.second.ne[0];
|
||||
LOG_DEBUG("actual radiance patch size: %d", actual_radiance_patch_size);
|
||||
}
|
||||
if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
|
||||
// Chroma
|
||||
@ -1351,6 +1358,11 @@ namespace Flux {
|
||||
head_dim = pair.second.ne[0];
|
||||
}
|
||||
}
|
||||
if (actual_radiance_patch_size > 0 && actual_radiance_patch_size != flux_params.patch_size) {
|
||||
GGML_ASSERT(flux_params.patch_size == 2 * actual_radiance_patch_size);
|
||||
LOG_DEBUG("using fake x2 patch size");
|
||||
flux_params.chroma_radiance_params.fake_patch_size_x2 = true;
|
||||
}
|
||||
|
||||
flux_params.num_heads = static_cast<int>(flux_params.hidden_size / head_dim);
|
||||
|
||||
|
||||
13
model.cpp
13
model.cpp
@ -376,7 +376,11 @@ bool ModelLoader::init_from_file(const std::string& file_path, const std::string
|
||||
LOG_INFO("load %s using checkpoint format", file_path.c_str());
|
||||
return init_from_ckpt_file(file_path, prefix);
|
||||
} else {
|
||||
LOG_WARN("unknown format %s", file_path.c_str());
|
||||
if (file_exists(file_path)) {
|
||||
LOG_WARN("unknown format %s", file_path.c_str());
|
||||
} else {
|
||||
LOG_WARN("file %s not found", file_path.c_str());
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@ -1040,6 +1044,7 @@ SDVersion ModelLoader::get_sd_version() {
|
||||
int64_t patch_embedding_channels = 0;
|
||||
bool has_img_emb = false;
|
||||
bool has_middle_block_1 = false;
|
||||
bool has_output_block_311 = false;
|
||||
bool has_output_block_71 = false;
|
||||
|
||||
for (auto& [name, tensor_storage] : tensor_storage_map) {
|
||||
@ -1100,6 +1105,9 @@ SDVersion ModelLoader::get_sd_version() {
|
||||
tensor_storage.name.find("unet.mid_block.resnets.1.") != std::string::npos) {
|
||||
has_middle_block_1 = true;
|
||||
}
|
||||
if (tensor_storage.name.find("model.diffusion_model.output_blocks.3.1.transformer_blocks.1") != std::string::npos) {
|
||||
has_output_block_311 = true;
|
||||
}
|
||||
if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos) {
|
||||
has_output_block_71 = true;
|
||||
}
|
||||
@ -1138,6 +1146,9 @@ SDVersion ModelLoader::get_sd_version() {
|
||||
return VERSION_SDXL_PIX2PIX;
|
||||
}
|
||||
if (!has_middle_block_1) {
|
||||
if (!has_output_block_311) {
|
||||
return VERSION_SDXL_VEGA;
|
||||
}
|
||||
return VERSION_SDXL_SSD1B;
|
||||
}
|
||||
return VERSION_SDXL;
|
||||
|
||||
3
model.h
3
model.h
@ -32,6 +32,7 @@ enum SDVersion {
|
||||
VERSION_SDXL,
|
||||
VERSION_SDXL_INPAINT,
|
||||
VERSION_SDXL_PIX2PIX,
|
||||
VERSION_SDXL_VEGA,
|
||||
VERSION_SDXL_SSD1B,
|
||||
VERSION_SVD,
|
||||
VERSION_SD3,
|
||||
@ -66,7 +67,7 @@ static inline bool sd_version_is_sd2(SDVersion version) {
|
||||
}
|
||||
|
||||
static inline bool sd_version_is_sdxl(SDVersion version) {
|
||||
if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX || version == VERSION_SDXL_SSD1B) {
|
||||
if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX || version == VERSION_SDXL_SSD1B || version == VERSION_SDXL_VEGA) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
||||
@ -35,6 +35,7 @@ const char* model_version_to_str[] = {
|
||||
"SDXL",
|
||||
"SDXL Inpaint",
|
||||
"SDXL Instruct-Pix2Pix",
|
||||
"SDXL (Vega)",
|
||||
"SDXL (SSD1B)",
|
||||
"SVD",
|
||||
"SD3.x",
|
||||
@ -623,7 +624,7 @@ public:
|
||||
LOG_INFO("Using Conv2d direct in the vae model");
|
||||
first_stage_model->set_conv2d_direct_enabled(true);
|
||||
}
|
||||
if (version == VERSION_SDXL &&
|
||||
if (sd_version_is_sdxl(version) &&
|
||||
(strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale)) {
|
||||
float vae_conv_2d_scale = 1.f / 32.f;
|
||||
LOG_WARN(
|
||||
|
||||
56
tae.hpp
56
tae.hpp
@ -17,22 +17,43 @@ class TAEBlock : public UnaryBlock {
|
||||
protected:
|
||||
int n_in;
|
||||
int n_out;
|
||||
bool use_midblock_gn;
|
||||
|
||||
public:
|
||||
TAEBlock(int n_in, int n_out)
|
||||
: n_in(n_in), n_out(n_out) {
|
||||
TAEBlock(int n_in, int n_out, bool use_midblock_gn = false)
|
||||
: n_in(n_in), n_out(n_out), use_midblock_gn(use_midblock_gn) {
|
||||
blocks["conv.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_out, {3, 3}, {1, 1}, {1, 1}));
|
||||
blocks["conv.2"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1}));
|
||||
blocks["conv.4"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1}));
|
||||
if (n_in != n_out) {
|
||||
blocks["skip"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_out, {1, 1}, {1, 1}, {1, 1}, {1, 1}, false));
|
||||
}
|
||||
if (use_midblock_gn) {
|
||||
int n_gn = n_in * 4;
|
||||
blocks["pool.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_gn, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
|
||||
blocks["pool.1"] = std::shared_ptr<GGMLBlock>(new GroupNorm(4, n_gn));
|
||||
// pool.2 is ReLU, handled in forward
|
||||
blocks["pool.3"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_gn, n_in, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||
// x: [n, n_in, h, w]
|
||||
// return: [n, n_out, h, w]
|
||||
|
||||
if (use_midblock_gn) {
|
||||
auto pool_0 = std::dynamic_pointer_cast<Conv2d>(blocks["pool.0"]);
|
||||
auto pool_1 = std::dynamic_pointer_cast<GroupNorm>(blocks["pool.1"]);
|
||||
auto pool_3 = std::dynamic_pointer_cast<Conv2d>(blocks["pool.3"]);
|
||||
|
||||
auto p = pool_0->forward(ctx, x);
|
||||
p = pool_1->forward(ctx, p);
|
||||
p = ggml_relu_inplace(ctx->ggml_ctx, p);
|
||||
p = pool_3->forward(ctx, p);
|
||||
|
||||
x = ggml_add(ctx->ggml_ctx, x, p);
|
||||
}
|
||||
|
||||
auto conv_0 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.0"]);
|
||||
auto conv_2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.2"]);
|
||||
auto conv_4 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.4"]);
|
||||
@ -62,7 +83,7 @@ class TinyEncoder : public UnaryBlock {
|
||||
int num_blocks = 3;
|
||||
|
||||
public:
|
||||
TinyEncoder(int z_channels = 4)
|
||||
TinyEncoder(int z_channels = 4, bool use_midblock_gn = false)
|
||||
: z_channels(z_channels) {
|
||||
int index = 0;
|
||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, channels, {3, 3}, {1, 1}, {1, 1}));
|
||||
@ -80,7 +101,7 @@ public:
|
||||
|
||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false));
|
||||
for (int i = 0; i < num_blocks; i++) {
|
||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
|
||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels, use_midblock_gn));
|
||||
}
|
||||
|
||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, z_channels, {3, 3}, {1, 1}, {1, 1}));
|
||||
@ -107,7 +128,7 @@ class TinyDecoder : public UnaryBlock {
|
||||
int num_blocks = 3;
|
||||
|
||||
public:
|
||||
TinyDecoder(int z_channels = 4)
|
||||
TinyDecoder(int z_channels = 4, bool use_midblock_gn = false)
|
||||
: z_channels(z_channels) {
|
||||
int index = 0;
|
||||
|
||||
@ -115,7 +136,7 @@ public:
|
||||
index++; // nn.ReLU()
|
||||
|
||||
for (int i = 0; i < num_blocks; i++) {
|
||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
|
||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels, use_midblock_gn));
|
||||
}
|
||||
index++; // nn.Upsample()
|
||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, false));
|
||||
@ -470,29 +491,44 @@ public:
|
||||
class TAESD : public GGMLBlock {
|
||||
protected:
|
||||
bool decode_only;
|
||||
bool taef2 = false;
|
||||
|
||||
public:
|
||||
TAESD(bool decode_only = true, SDVersion version = VERSION_SD1)
|
||||
: decode_only(decode_only) {
|
||||
int z_channels = 4;
|
||||
int z_channels = 4;
|
||||
bool use_midblock_gn = false;
|
||||
taef2 = sd_version_is_flux2(version);
|
||||
|
||||
if (sd_version_is_dit(version)) {
|
||||
z_channels = 16;
|
||||
}
|
||||
blocks["decoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyDecoder(z_channels));
|
||||
if (taef2) {
|
||||
z_channels = 32;
|
||||
use_midblock_gn = true;
|
||||
}
|
||||
blocks["decoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyDecoder(z_channels, use_midblock_gn));
|
||||
|
||||
if (!decode_only) {
|
||||
blocks["encoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyEncoder(z_channels));
|
||||
blocks["encoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyEncoder(z_channels, use_midblock_gn));
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
|
||||
auto decoder = std::dynamic_pointer_cast<TinyDecoder>(blocks["decoder.layers"]);
|
||||
if (taef2) {
|
||||
z = unpatchify(ctx->ggml_ctx, z, 2);
|
||||
}
|
||||
return decoder->forward(ctx, z);
|
||||
}
|
||||
|
||||
struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||
auto encoder = std::dynamic_pointer_cast<TinyEncoder>(blocks["encoder.layers"]);
|
||||
return encoder->forward(ctx, x);
|
||||
auto z = encoder->forward(ctx, x);
|
||||
if (taef2) {
|
||||
z = patchify(ctx->ggml_ctx, z, 2);
|
||||
}
|
||||
return z;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
7
unet.hpp
7
unet.hpp
@ -201,6 +201,9 @@ public:
|
||||
num_head_channels = 64;
|
||||
num_heads = -1;
|
||||
use_linear_projection = true;
|
||||
if (version == VERSION_SDXL_VEGA) {
|
||||
transformer_depth = {1, 1, 2};
|
||||
}
|
||||
} else if (version == VERSION_SVD) {
|
||||
in_channels = 8;
|
||||
out_channels = 4;
|
||||
@ -319,7 +322,7 @@ public:
|
||||
}
|
||||
if (!tiny_unet) {
|
||||
blocks["middle_block.0"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
|
||||
if (version != VERSION_SDXL_SSD1B) {
|
||||
if (version != VERSION_SDXL_SSD1B && version != VERSION_SDXL_VEGA) {
|
||||
blocks["middle_block.1"] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
|
||||
n_head,
|
||||
d_head,
|
||||
@ -520,7 +523,7 @@ public:
|
||||
// middle_block
|
||||
if (!tiny_unet) {
|
||||
h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
||||
if (version != VERSION_SDXL_SSD1B) {
|
||||
if (version != VERSION_SDXL_SSD1B && version != VERSION_SDXL_VEGA) {
|
||||
h = attention_layer_forward("middle_block.1", ctx, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
||||
h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user