Merge branch 'master' into server

leejet 2025-12-10 23:38:38 +08:00
commit 96aea6340c
38 changed files with 1044 additions and 504 deletions

View File

@ -163,7 +163,7 @@ jobs:
- build: "avx512"
defines: "-DGGML_NATIVE=OFF -DGGML_AVX512=ON -DGGML_AVX=ON -DGGML_AVX2=ON -DSD_BUILD_SHARED_LIBS=ON"
- build: "cuda12"
defines: "-DSD_CUDA=ON -DSD_BUILD_SHARED_LIBS=ON -DCMAKE_CUDA_ARCHITECTURES=90;89;86;80;75"
defines: "-DSD_CUDA=ON -DSD_BUILD_SHARED_LIBS=ON -DCMAKE_CUDA_ARCHITECTURES='61;70;75;80;86;89;90;100;120'"
- build: 'vulkan'
defines: "-DSD_VULKAN=ON -DSD_BUILD_SHARED_LIBS=ON"
steps:
@ -176,9 +176,9 @@ jobs:
- name: Install cuda-toolkit
id: cuda-toolkit
if: ${{ matrix.build == 'cuda12' }}
uses: Jimver/cuda-toolkit@v0.2.19
uses: Jimver/cuda-toolkit@v0.2.22
with:
cuda: "12.6.2"
cuda: "12.8.1"
method: "network"
sub-packages: '["nvcc", "cudart", "cublas", "cublas_dev", "thrust", "visual_studio_integration"]'

View File

@ -87,6 +87,38 @@ file(GLOB SD_LIB_SOURCES
"*.hpp"
)
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
if(GIT_EXE)
execute_process(COMMAND ${GIT_EXE} describe --tags --abbrev=7 --dirty=+
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
OUTPUT_VARIABLE SDCPP_BUILD_VERSION
OUTPUT_STRIP_TRAILING_WHITESPACE
ERROR_QUIET
)
execute_process(COMMAND ${GIT_EXE} rev-parse --short HEAD
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
OUTPUT_VARIABLE SDCPP_BUILD_COMMIT
OUTPUT_STRIP_TRAILING_WHITESPACE
ERROR_QUIET
)
endif()
if(NOT SDCPP_BUILD_VERSION)
set(SDCPP_BUILD_VERSION unknown)
endif()
message(STATUS "stable-diffusion.cpp version ${SDCPP_BUILD_VERSION}")
if(NOT SDCPP_BUILD_COMMIT)
set(SDCPP_BUILD_COMMIT unknown)
endif()
message(STATUS "stable-diffusion.cpp commit ${SDCPP_BUILD_COMMIT}")
set_property(
SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp
APPEND PROPERTY COMPILE_DEFINITIONS
SDCPP_BUILD_COMMIT=${SDCPP_BUILD_COMMIT} SDCPP_BUILD_VERSION=${SDCPP_BUILD_VERSION}
)
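Note that the definitions above are injected as bare tokens, so whatever consumes them in version.cpp has to stringize them before returning them as C strings. A minimal sketch of such a consumer, assuming a hypothetical `SDCPP_STR` helper (the `sd_version()`/`sd_commit()` accessors do appear later in this diff):

```cpp
// version.cpp, hypothetical sketch, not part of this commit.
// SDCPP_BUILD_VERSION / SDCPP_BUILD_COMMIT arrive as bare tokens
// (e.g. SDCPP_BUILD_COMMIT=96aea63), so stringize them before use.
#define SDCPP_STR_IMPL(x) #x
#define SDCPP_STR(x) SDCPP_STR_IMPL(x)

const char* sd_version() { return SDCPP_STR(SDCPP_BUILD_VERSION); }
const char* sd_commit() { return SDCPP_STR(SDCPP_BUILD_COMMIT); }
```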
if(SD_BUILD_SHARED_LIBS)
message("-- Build shared library")
message(${SD_LIB_SOURCES})

View File

@ -1,5 +1,5 @@
<p align="center">
<img src="./assets/cat_with_sd_cpp_42.png" width="360x">
<img src="./assets/logo.png" width="360x">
</p>
# stable-diffusion.cpp
@ -49,6 +49,7 @@ API and command-line option may change frequently.***
- [Chroma1-Radiance](./docs/chroma_radiance.md)
- [Qwen Image](./docs/qwen_image.md)
- [Z-Image](./docs/z_image.md)
- [Ovis-Image](./docs/ovis_image.md)
- Image Edit Models
- [FLUX.1-Kontext-dev](./docs/kontext.md)
- [Qwen Image Edit/Qwen Image Edit 2509](./docs/qwen_image_edit.md)
@ -104,7 +105,7 @@ API and command-line option may change frequently.***
### Download model weights
- download weights (.ckpt, .safetensors, or .gguf). For example:
- Stable Diffusion v1.5 from https://huggingface.co/runwayml/stable-diffusion-v1-5
- Stable Diffusion v1.5 from https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5
```sh
curl -L -O https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors
@ -134,6 +135,7 @@ If you want to improve performance or reduce VRAM/RAM usage, please refer to [pe
- [🔥Qwen Image Edit/Qwen Image Edit 2509](./docs/qwen_image_edit.md)
- [🔥Wan2.1/Wan2.2](./docs/wan.md)
- [🔥Z-Image](./docs/z_image.md)
- [Ovis-Image](./docs/ovis_image.md)
- [LoRA](./docs/lora.md)
- [LCM/LCM-LoRA](./docs/lcm.md)
- [Using PhotoMaker to personalize image generation](./docs/photo_maker.md)

BIN assets/logo.png (new binary file, 1.0 MiB, not shown)

BIN (new binary file, 401 KiB, name not shown)

clip.hpp (107 changed lines)
View File

@ -3,34 +3,10 @@
#include "ggml_extend.hpp"
#include "model.h"
#include "tokenize_util.h"
/*================================================== CLIPTokenizer ===================================================*/
__STATIC_INLINE__ std::pair<std::unordered_map<std::string, float>, std::string> extract_and_remove_lora(std::string text) {
std::regex re("<lora:([^:]+):([^>]+)>");
std::smatch matches;
std::unordered_map<std::string, float> filename2multiplier;
while (std::regex_search(text, matches, re)) {
std::string filename = matches[1].str();
float multiplier = std::stof(matches[2].str());
text = std::regex_replace(text, re, "", std::regex_constants::format_first_only);
if (multiplier == 0.f) {
continue;
}
if (filename2multiplier.find(filename) == filename2multiplier.end()) {
filename2multiplier[filename] = multiplier;
} else {
filename2multiplier[filename] += multiplier;
}
}
return std::make_pair(filename2multiplier, text);
}
__STATIC_INLINE__ std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
std::vector<std::pair<int, std::u32string>> byte_unicode_pairs;
std::set<int> byte_set;
@ -72,6 +48,8 @@ private:
int encoder_len;
int bpe_len;
std::vector<std::string> special_tokens;
public:
const std::string UNK_TOKEN = "<|endoftext|>";
const std::string BOS_TOKEN = "<|startoftext|>";
@ -117,6 +95,15 @@ private:
return pairs;
}
bool is_special_token(const std::string& token) {
for (auto& special_token : special_tokens) {
if (special_token == token) {
return true;
}
}
return false;
}
public:
CLIPTokenizer(int pad_token_id = 49407, const std::string& merges_utf8_str = "")
: PAD_TOKEN_ID(pad_token_id) {
@ -125,6 +112,8 @@ public:
} else {
load_from_merges(ModelLoader::load_merges());
}
add_special_token("<|startoftext|>");
add_special_token("<|endoftext|>");
}
void load_from_merges(const std::string& merges_utf8_str) {
@ -201,6 +190,10 @@ public:
}
}
void add_special_token(const std::string& token) {
special_tokens.push_back(token);
}
std::u32string bpe(const std::u32string& token) {
std::vector<std::u32string> word;
@ -379,25 +372,54 @@ public:
return trim(text);
}
std::vector<std::string> token_split(const std::string& text) {
std::regex pat(R"('s|'t|'re|'ve|'m|'ll|'d|[[:alpha:]]+|[[:digit:]]|[^[:space:][:alpha:][:digit:]]+)",
std::regex::icase);
std::sregex_iterator iter(text.begin(), text.end(), pat);
std::sregex_iterator end;
std::vector<std::string> result;
for (; iter != end; ++iter) {
result.emplace_back(iter->str());
}
return result;
}
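As an aside, the pattern above takes matches in order and silently drops the whitespace between them; a standalone sketch of its behaviour (same regex, outside the class, with a hypothetical input):

```cpp
#include <iostream>
#include <regex>
#include <string>

// standalone illustration of the token_split() pattern above
int main() {
    std::regex pat(R"('s|'t|'re|'ve|'m|'ll|'d|[[:alpha:]]+|[[:digit:]]|[^[:space:][:alpha:][:digit:]]+)",
                   std::regex::icase);
    std::string text = "isn't it 42?!";
    for (std::sregex_iterator it(text.begin(), text.end(), pat), end; it != end; ++it)
        std::cout << '"' << it->str() << "\" ";
    // prints: "isn" "'t" "it" "4" "2" "?!"
}
```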
std::vector<int> encode(std::string text, on_new_token_cb_t on_new_token_cb) {
std::string original_text = text;
std::vector<int32_t> bpe_tokens;
text = whitespace_clean(text);
std::transform(text.begin(), text.end(), text.begin(), [](unsigned char c) { return std::tolower(c); });
std::regex pat(R"(<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[[:alpha:]]+|[[:digit:]]|[^[:space:][:alpha:][:digit:]]+)",
std::regex::icase);
std::smatch matches;
std::string str = text;
std::vector<std::string> token_strs;
while (std::regex_search(str, matches, pat)) {
bool skip = on_new_token_cb(str, bpe_tokens);
if (skip) {
auto splited_texts = split_with_special_tokens(text, special_tokens);
for (auto& splited_text : splited_texts) {
LOG_DEBUG("token %s", splited_text.c_str());
if (is_special_token(splited_text)) {
LOG_DEBUG("special %s", splited_text.c_str());
bool skip = on_new_token_cb(splited_text, bpe_tokens);
if (skip) {
token_strs.push_back(splited_text);
continue;
}
continue;
}
for (auto& token : matches) {
std::string token_str = token.str();
auto tokens = token_split(splited_text);
for (auto& token : tokens) {
if (on_new_token_cb != nullptr) {
bool skip = on_new_token_cb(token, bpe_tokens);
if (skip) {
token_strs.push_back(token);
continue;
}
}
std::string token_str = token;
std::u32string utf32_token;
for (int i = 0; i < token_str.length(); i++) {
unsigned char b = token_str[i];
@ -417,14 +439,13 @@ public:
bpe_tokens.push_back(encoder[bpe_str]);
token_strs.push_back(utf32_to_utf8(bpe_str));
}
str = matches.suffix();
}
std::stringstream ss;
ss << "[";
for (auto token : token_strs) {
ss << "\"" << token << "\", ";
}
ss << "]";
// std::stringstream ss;
// ss << "[";
// for (auto token : token_strs) {
// ss << "\"" << token << "\", ";
// }
// ss << "]";
// LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str());
// printf("split prompt \"%s\" to tokens %s \n", original_text.c_str(), ss.str().c_str());
return bpe_tokens;
@ -963,7 +984,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
return gf;
}
void compute(const int n_threads,
bool compute(const int n_threads,
struct ggml_tensor* input_ids,
int num_custom_embeddings,
void* custom_embeddings_data,
@ -975,7 +996,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(input_ids, num_custom_embeddings, custom_embeddings_data, max_token_idx, return_pooled, clip_skip);
};
GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
}
};

View File

@ -56,7 +56,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
std::shared_ptr<CLIPTextModelRunner> text_model2;
std::string trigger_word = "img"; // should be user settable
std::string embd_dir;
std::map<std::string, std::string> embedding_map;
int32_t num_custom_embeddings = 0;
int32_t num_custom_embeddings_2 = 0;
std::vector<uint8_t> token_embed_custom;
@ -65,11 +65,17 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
FrozenCLIPEmbedderWithCustomWords(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2TensorStorage& tensor_storage_map,
const std::string& embd_dir,
const std::map<std::string, std::string>& orig_embedding_map,
SDVersion version = VERSION_SD1,
PMVersion pv = PM_VERSION_1)
: version(version), pm_version(pv), tokenizer(sd_version_is_sd2(version) ? 0 : 49407), embd_dir(embd_dir) {
bool force_clip_f32 = embd_dir.size() > 0;
: version(version), pm_version(pv), tokenizer(sd_version_is_sd2(version) ? 0 : 49407) {
for (const auto& kv : orig_embedding_map) {
std::string name = kv.first;
std::transform(name.begin(), name.end(), name.begin(), [](unsigned char c) { return std::tolower(c); });
embedding_map[name] = kv.second;
tokenizer.add_special_token(name);
}
bool force_clip_f32 = !embedding_map.empty();
if (sd_version_is_sd1(version)) {
text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_storage_map, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, true, force_clip_f32);
} else if (sd_version_is_sd2(version)) {
@ -196,25 +202,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
std::vector<int> convert_token_to_id(std::string text) {
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
size_t word_end = str.find(",");
std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
embd_name = trim(embd_name);
std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
if (embd_path.size() == 0) {
embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
auto iter = embedding_map.find(str);
if (iter == embedding_map.end()) {
return false;
}
if (embd_path.size() == 0) {
embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
}
if (embd_path.size() > 0) {
if (load_embedding(embd_name, embd_path, bpe_tokens)) {
if (word_end != std::string::npos) {
str = str.substr(word_end);
} else {
str = "";
}
return true;
}
std::string embedding_path = iter->second;
if (load_embedding(str, embedding_path, bpe_tokens)) {
return true;
}
return false;
};
@ -245,25 +239,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
}
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
size_t word_end = str.find(",");
std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
embd_name = trim(embd_name);
std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
if (embd_path.size() == 0) {
embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
auto iter = embedding_map.find(str);
if (iter == embedding_map.end()) {
return false;
}
if (embd_path.size() == 0) {
embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
}
if (embd_path.size() > 0) {
if (load_embedding(embd_name, embd_path, bpe_tokens)) {
if (word_end != std::string::npos) {
str = str.substr(word_end);
} else {
str = "";
}
return true;
}
std::string embedding_path = iter->second;
if (load_embedding(str, embedding_path, bpe_tokens)) {
return true;
}
return false;
};
@ -376,25 +358,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
}
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
size_t word_end = str.find(",");
std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
embd_name = trim(embd_name);
std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
if (embd_path.size() == 0) {
embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
auto iter = embedding_map.find(str);
if (iter == embedding_map.end()) {
return false;
}
if (embd_path.size() == 0) {
embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
}
if (embd_path.size() > 0) {
if (load_embedding(embd_name, embd_path, bpe_tokens)) {
if (word_end != std::string::npos) {
str = str.substr(word_end);
} else {
str = "";
}
return true;
}
std::string embedding_path = iter->second;
if (load_embedding(str, embedding_path, bpe_tokens)) {
return true;
}
return false;
};
@ -703,7 +673,7 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
return gf;
}
void compute(const int n_threads,
bool compute(const int n_threads,
ggml_tensor* pixel_values,
bool return_pooled,
int clip_skip,
@ -712,7 +682,7 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(pixel_values, return_pooled, clip_skip);
};
GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
}
};
@ -1638,7 +1608,7 @@ struct LLMEmbedder : public Conditioner {
LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL;
if (sd_version_is_flux2(version)) {
arch = LLM::LLMArch::MISTRAL_SMALL_3_2;
} else if (sd_version_is_z_image(version)) {
} else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE) {
arch = LLM::LLMArch::QWEN3;
}
if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) {
@ -1728,6 +1698,7 @@ struct LLMEmbedder : public Conditioner {
std::vector<std::pair<int, ggml_tensor*>> image_embeds;
std::pair<int, int> prompt_attn_range;
int prompt_template_encode_start_idx = 34;
int max_length = 0;
std::set<int> out_layers;
if (llm->enable_vision && conditioner_params.ref_images.size() > 0) {
LOG_INFO("QwenImageEditPlusPipeline");
@ -1825,6 +1796,17 @@ struct LLMEmbedder : public Conditioner {
prompt_attn_range.second = prompt.size();
prompt += "[/INST]";
} else if (version == VERSION_OVIS_IMAGE) {
prompt_template_encode_start_idx = 28;
max_length = prompt_template_encode_start_idx + 256;
prompt = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background:";
prompt_attn_range.first = static_cast<int>(prompt.size());
prompt += " " + conditioner_params.text;
prompt_attn_range.second = static_cast<int>(prompt.size());
prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
} else {
prompt_template_encode_start_idx = 34;
@ -1837,7 +1819,7 @@ struct LLMEmbedder : public Conditioner {
prompt += "<|im_end|>\n<|im_start|>assistant\n";
}
auto tokens_and_weights = tokenize(prompt, prompt_attn_range, 0, false);
auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0);
auto& tokens = std::get<0>(tokens_and_weights);
auto& weights = std::get<1>(tokens_and_weights);
@ -1870,9 +1852,13 @@ struct LLMEmbedder : public Conditioner {
GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);
int64_t zero_pad_len = 0;
int64_t min_length = 0;
if (sd_version_is_flux2(version)) {
int64_t min_length = 512;
min_length = 512;
}
int64_t zero_pad_len = 0;
if (min_length > 0) {
if (hidden_states->ne[1] - prompt_template_encode_start_idx < min_length) {
zero_pad_len = min_length - hidden_states->ne[1] + prompt_template_encode_start_idx;
}
@ -1892,6 +1878,8 @@ struct LLMEmbedder : public Conditioner {
ggml_ext_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3);
});
// print_ggml_tensor(new_hidden_states);
int64_t t1 = ggml_time_ms();
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
return {new_hidden_states, nullptr, nullptr};

View File

@ -414,7 +414,7 @@ struct ControlNet : public GGMLRunner {
return gf;
}
void compute(int n_threads,
bool compute(int n_threads,
struct ggml_tensor* x,
struct ggml_tensor* hint,
struct ggml_tensor* timesteps,
@ -430,8 +430,12 @@ struct ControlNet : public GGMLRunner {
return build_graph(x, hint, timesteps, context, y);
};
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
guided_hint_cached = true;
bool res = GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
if (res) {
// cache guided_hint
guided_hint_cached = true;
}
return res;
}
bool load_from_file(const std::string& file_path, int n_threads) {

View File

@ -666,7 +666,7 @@ struct Flux2FlowDenoiser : public FluxFlowDenoiser {
typedef std::function<ggml_tensor*(ggml_tensor*, float, int)> denoise_cb_t;
// k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t
static void sample_k_diffusion(sample_method_t method,
static bool sample_k_diffusion(sample_method_t method,
denoise_cb_t model,
ggml_context* work_ctx,
ggml_tensor* x,
@ -685,6 +685,9 @@ static void sample_k_diffusion(sample_method_t method,
// denoise
ggml_tensor* denoised = model(x, sigma, i + 1);
if (denoised == nullptr) {
return false;
}
// d = (x - denoised) / sigma
{
@ -738,6 +741,9 @@ static void sample_k_diffusion(sample_method_t method,
// denoise
ggml_tensor* denoised = model(x, sigma, i + 1);
if (denoised == nullptr) {
return false;
}
// d = (x - denoised) / sigma
{
@ -769,6 +775,9 @@ static void sample_k_diffusion(sample_method_t method,
for (int i = 0; i < steps; i++) {
// denoise
ggml_tensor* denoised = model(x, sigmas[i], -(i + 1));
if (denoised == nullptr) {
return false;
}
// d = (x - denoised) / sigma
{
@ -803,7 +812,10 @@ static void sample_k_diffusion(sample_method_t method,
}
ggml_tensor* denoised = model(x2, sigmas[i + 1], i + 1);
float* vec_denoised = (float*)denoised->data;
if (denoised == nullptr) {
return false;
}
float* vec_denoised = (float*)denoised->data;
for (int j = 0; j < ggml_nelements(x); j++) {
float d2 = (vec_x2[j] - vec_denoised[j]) / sigmas[i + 1];
vec_d[j] = (vec_d[j] + d2) / 2;
@ -819,6 +831,9 @@ static void sample_k_diffusion(sample_method_t method,
for (int i = 0; i < steps; i++) {
// denoise
ggml_tensor* denoised = model(x, sigmas[i], i + 1);
if (denoised == nullptr) {
return false;
}
// d = (x - denoised) / sigma
{
@ -855,7 +870,10 @@ static void sample_k_diffusion(sample_method_t method,
}
ggml_tensor* denoised = model(x2, sigma_mid, i + 1);
float* vec_denoised = (float*)denoised->data;
if (denoised == nullptr) {
return false;
}
float* vec_denoised = (float*)denoised->data;
for (int j = 0; j < ggml_nelements(x); j++) {
float d2 = (vec_x2[j] - vec_denoised[j]) / sigma_mid;
vec_x[j] = vec_x[j] + d2 * dt_2;
@ -871,6 +889,9 @@ static void sample_k_diffusion(sample_method_t method,
for (int i = 0; i < steps; i++) {
// denoise
ggml_tensor* denoised = model(x, sigmas[i], i + 1);
if (denoised == nullptr) {
return false;
}
// get_ancestral_step
float sigma_up = std::min(sigmas[i + 1],
@ -907,6 +928,9 @@ static void sample_k_diffusion(sample_method_t method,
}
ggml_tensor* denoised = model(x2, sigmas[i + 1], i + 1);
if (denoised == nullptr) {
return false;
}
// Second half-step
for (int j = 0; j < ggml_nelements(x); j++) {
@ -937,6 +961,9 @@ static void sample_k_diffusion(sample_method_t method,
for (int i = 0; i < steps; i++) {
// denoise
ggml_tensor* denoised = model(x, sigmas[i], i + 1);
if (denoised == nullptr) {
return false;
}
float t = t_fn(sigmas[i]);
float t_next = t_fn(sigmas[i + 1]);
@ -976,6 +1003,9 @@ static void sample_k_diffusion(sample_method_t method,
for (int i = 0; i < steps; i++) {
// denoise
ggml_tensor* denoised = model(x, sigmas[i], i + 1);
if (denoised == nullptr) {
return false;
}
float t = t_fn(sigmas[i]);
float t_next = t_fn(sigmas[i + 1]);
@ -1026,7 +1056,10 @@ static void sample_k_diffusion(sample_method_t method,
// Denoising step
ggml_tensor* denoised = model(x_cur, sigma, i + 1);
float* vec_denoised = (float*)denoised->data;
if (denoised == nullptr) {
return false;
}
float* vec_denoised = (float*)denoised->data;
// d_cur = (x_cur - denoised) / sigma
struct ggml_tensor* d_cur = ggml_dup_tensor(work_ctx, x_cur);
float* vec_d_cur = (float*)d_cur->data;
@ -1169,6 +1202,9 @@ static void sample_k_diffusion(sample_method_t method,
// denoise
ggml_tensor* denoised = model(x, sigma, i + 1);
if (denoised == nullptr) {
return false;
}
// x = denoised
{
@ -1561,8 +1597,9 @@ static void sample_k_diffusion(sample_method_t method,
default:
LOG_ERROR("Attempting to sample with nonexisting sample method %i", method);
abort();
return false;
}
return true;
}
#endif // __DENOISER_HPP__

View File

@ -27,7 +27,7 @@ struct DiffusionParams {
struct DiffusionModel {
virtual std::string get_desc() = 0;
virtual void compute(int n_threads,
virtual bool compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
struct ggml_context* output_ctx = nullptr) = 0;
@ -87,7 +87,7 @@ struct UNetModel : public DiffusionModel {
unet.set_flash_attention_enabled(enabled);
}
void compute(int n_threads,
bool compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
struct ggml_context* output_ctx = nullptr) override {
@ -148,7 +148,7 @@ struct MMDiTModel : public DiffusionModel {
mmdit.set_flash_attention_enabled(enabled);
}
void compute(int n_threads,
bool compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
struct ggml_context* output_ctx = nullptr) override {
@ -210,7 +210,7 @@ struct FluxModel : public DiffusionModel {
flux.set_flash_attention_enabled(enabled);
}
void compute(int n_threads,
bool compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
struct ggml_context* output_ctx = nullptr) override {
@ -277,7 +277,7 @@ struct WanModel : public DiffusionModel {
wan.set_flash_attention_enabled(enabled);
}
void compute(int n_threads,
bool compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
struct ggml_context* output_ctx = nullptr) override {
@ -343,7 +343,7 @@ struct QwenImageModel : public DiffusionModel {
qwen_image.set_flash_attention_enabled(enabled);
}
void compute(int n_threads,
bool compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
struct ggml_context* output_ctx = nullptr) override {
@ -406,7 +406,7 @@ struct ZImageModel : public DiffusionModel {
z_image.set_flash_attention_enabled(enabled);
}
void compute(int n_threads,
bool compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
struct ggml_context* output_ctx = nullptr) override {

View File

@ -15,7 +15,7 @@ You can run Flux using stable-diffusion.cpp with a GPU that has 6GB or even 4GB
You can download the preconverted gguf weights from [FLUX.1-dev-gguf](https://huggingface.co/leejet/FLUX.1-dev-gguf) or [FLUX.1-schnell](https://huggingface.co/leejet/FLUX.1-schnell-gguf), this way you don't have to do the conversion yourself.
Using fp16 will lead to overflow, but ggml's support for bf16 is not yet fully developed. Therefore, we need to convert flux to gguf format here, which also saves VRAM. For example:
For example:
```
.\bin\Release\sd.exe -M convert -m ..\..\ComfyUI\models\unet\flux1-dev.sft -o ..\models\flux1-dev-q8_0.gguf -v --type q8_0
```

docs/ovis_image.md (new file, 19 lines)
View File

@ -0,0 +1,19 @@
# How to Use
## Download weights
- Download Ovis-Image-7B
- safetensors: https://huggingface.co/Comfy-Org/Ovis-Image/tree/main/split_files/diffusion_models
- gguf: https://huggingface.co/leejet/Ovis-Image-7B-GGUF
- Download vae
- safetensors: https://huggingface.co/black-forest-labs/FLUX.1-schnell/tree/main
- Download Ovis 2.5
- safetensors: https://huggingface.co/Comfy-Org/Ovis-Image/tree/main/split_files/text_encoders
## Examples
```
.\bin\Release\sd.exe --diffusion-model ovis_image-Q4_0.gguf --vae ..\..\ComfyUI\models\vae\ae.sft --llm ..\..\ComfyUI\models\text_encoders\ovis_2.5.safetensors -p "a lovely cat" --cfg-scale 5.0 -v --offload-to-cpu --diffusion-fa
```
<img alt="ovis image example" src="../assets/ovis_image/example.png" />

View File

@ -156,9 +156,10 @@ struct ESRGAN : public GGMLRunner {
ESRGAN(ggml_backend_t backend,
bool offload_params_to_cpu,
int tile_size = 128,
const String2TensorStorage& tensor_storage_map = {})
: GGMLRunner(backend, offload_params_to_cpu) {
// rrdb_net will be created in load_from_file
this->tile_size = tile_size;
}
std::string get_desc() override {
@ -353,14 +354,14 @@ struct ESRGAN : public GGMLRunner {
return gf;
}
void compute(const int n_threads,
bool compute(const int n_threads,
struct ggml_tensor* x,
ggml_tensor** output,
ggml_context* output_ctx = nullptr) {
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(x);
};
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
}
};

View File

@ -31,8 +31,6 @@
#include "common/common.hpp"
namespace fs = std::filesystem;
const char* previews_str[] = {
"none",
"proj",
@ -45,6 +43,7 @@ struct SDCliParams {
std::string output_path = "output.png";
bool verbose = false;
bool version = false;
bool canny_preprocess = false;
preview_t preview_method = PREVIEW_NONE;
@ -87,6 +86,10 @@ struct SDCliParams {
"--verbose",
"print extra info",
true, &verbose},
{"",
"--version",
"print stable-diffusion.cpp version",
true, &version},
{"",
"--color",
"colors the logging tags according to level",
@ -130,18 +133,18 @@ struct SDCliParams {
return -1;
}
const char* preview = argv[index];
int preview_method = -1;
int preview_found = -1;
for (int m = 0; m < PREVIEW_COUNT; m++) {
if (!strcmp(preview, previews_str[m])) {
preview_method = m;
preview_found = m;
}
}
if (preview_method == -1) {
if (preview_found == -1) {
fprintf(stderr, "error: invalid preview method %s\n",
preview);
return -1;
}
preview_method = (preview_t)preview_method;
preview_method = (preview_t)preview_found;
return 1;
};
@ -202,6 +205,7 @@ struct SDCliParams {
};
void print_usage(int argc, const char* argv[], const std::vector<ArgOptions>& options_list) {
std::cout << version_string() << "\n";
std::cout << "Usage: " << argv[0] << " [options]\n\n";
std::cout << "CLI Options:\n";
options_list[0].print();
@ -219,7 +223,9 @@ void parse_args(int argc, const char** argv, SDCliParams& cli_params, SDContextP
exit(cli_params.normal_exit ? 0 : 1);
}
if (!cli_params.process_and_check() || !ctx_params.process_and_check(cli_params.mode) || !gen_params.process_and_check(cli_params.mode)) {
if (!cli_params.process_and_check() ||
!ctx_params.process_and_check(cli_params.mode) ||
!gen_params.process_and_check(cli_params.mode, ctx_params.lora_model_dir)) {
print_usage(argc, argv, options_vec);
exit(1);
}
@ -484,11 +490,19 @@ void step_callback(int step, int frame_count, sd_image_t* image, bool is_noisy,
}
int main(int argc, const char* argv[]) {
if (argc > 1 && std::string(argv[1]) == "--version") {
std::cout << version_string() << "\n";
return EXIT_SUCCESS;
}
SDCliParams cli_params;
SDContextParams ctx_params;
SDGenerationParams gen_params;
parse_args(argc, argv, cli_params, ctx_params, gen_params);
if (cli_params.verbose || cli_params.version) {
std::cout << version_string() << "\n";
}
if (gen_params.video_frames > 4) {
size_t last_dot_pos = cli_params.preview_path.find_last_of(".");
std::string base_path = cli_params.preview_path;
@ -724,6 +738,8 @@ int main(int argc, const char* argv[]) {
if (cli_params.mode == IMG_GEN) {
sd_img_gen_params_t img_gen_params = {
gen_params.lora_vec.data(),
static_cast<uint32_t>(gen_params.lora_vec.size()),
gen_params.prompt.c_str(),
gen_params.negative_prompt.c_str(),
gen_params.clip_skip,
@ -755,6 +771,8 @@ int main(int argc, const char* argv[]) {
num_results = gen_params.batch_count;
} else if (cli_params.mode == VID_GEN) {
sd_vid_gen_params_t vid_gen_params = {
gen_params.lora_vec.data(),
static_cast<uint32_t>(gen_params.lora_vec.size()),
gen_params.prompt.c_str(),
gen_params.negative_prompt.c_str(),
gen_params.clip_skip,
@ -791,7 +809,8 @@ int main(int argc, const char* argv[]) {
upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(ctx_params.esrgan_path.c_str(),
ctx_params.offload_params_to_cpu,
ctx_params.diffusion_conv_direct,
ctx_params.n_threads);
ctx_params.n_threads,
gen_params.upscale_tile_size);
if (upscaler_ctx == nullptr) {
printf("new_upscaler_ctx failed\n");

View File

@ -6,15 +6,19 @@
#include <sstream>
#include <string>
#include <vector>
#include <filesystem>
#include <json.hpp>
using json = nlohmann::json;
namespace fs = std::filesystem;
#if defined(_WIN32)
#define NOMINMAX
#include <windows.h>
#endif // _WIN32
#include "stable-diffusion.h"
#define SAFE_STR(s) ((s) ? (s) : "")
#define BOOL_STR(b) ((b) ? "true" : "false")
@ -312,6 +316,9 @@ struct SDContextParams {
std::string tensor_type_rules;
std::string lora_model_dir;
std::map<std::string, std::string> embedding_map;
std::vector<sd_embedding_t> embedding_vec;
rng_type_t rng_type = CUDA_RNG;
rng_type_t sampler_rng_type = RNG_TYPE_COUNT;
bool offload_params_to_cpu = false;
@ -326,7 +333,7 @@ struct SDContextParams {
bool chroma_use_t5_mask = false;
int chroma_t5_mask_pad = 1;
prediction_t prediction = DEFAULT_PRED;
prediction_t prediction = PREDICTION_COUNT;
lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO;
sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
@ -639,6 +646,37 @@ struct SDContextParams {
return options;
}
void build_embedding_map() {
static const std::vector<std::string> valid_ext = {".pt", ".safetensors", ".gguf"};
if (!fs::exists(embedding_dir) || !fs::is_directory(embedding_dir)) {
return;
}
for (auto& p : fs::directory_iterator(embedding_dir)) {
if (!p.is_regular_file())
continue;
auto path = p.path();
std::string ext = path.extension().string();
bool valid = false;
for (auto& e : valid_ext) {
if (ext == e) {
valid = true;
break;
}
}
if (!valid)
continue;
std::string key = path.stem().string();
std::string value = path.string();
embedding_map[key] = value;
}
}
bool process_and_check(SDMode mode) {
if (mode != UPSCALE && model_path.length() == 0 && diffusion_model_path.length() == 0) {
fprintf(stderr, "error: the following arguments are required: model_path/diffusion_model\n");
@ -656,10 +694,24 @@ struct SDContextParams {
n_threads = sd_get_num_physical_cores();
}
build_embedding_map();
return true;
}
std::string to_string() const {
std::ostringstream emb_ss;
emb_ss << "{\n";
for (auto it = embedding_map.begin(); it != embedding_map.end(); ++it) {
emb_ss << " \"" << it->first << "\": \"" << it->second << "\"";
if (std::next(it) != embedding_map.end()) {
emb_ss << ",";
}
emb_ss << "\n";
}
emb_ss << " }";
std::string embeddings_str = emb_ss.str();
std::ostringstream oss;
oss << "SDContextParams {\n"
<< " n_threads: " << n_threads << ",\n"
@ -677,6 +729,7 @@ struct SDContextParams {
<< " esrgan_path: \"" << esrgan_path << "\",\n"
<< " control_net_path: \"" << control_net_path << "\",\n"
<< " embedding_dir: \"" << embedding_dir << "\",\n"
<< " embeddings: " << embeddings_str << "\n"
<< " wtype: " << sd_type_name(wtype) << ",\n"
<< " tensor_type_rules: \"" << tensor_type_rules << "\",\n"
<< " lora_model_dir: \"" << lora_model_dir << "\",\n"
@ -709,6 +762,15 @@ struct SDContextParams {
}
sd_ctx_params_t to_sd_ctx_params_t(bool vae_decode_only, bool free_params_immediately, bool taesd_preview) {
embedding_vec.clear();
embedding_vec.reserve(embedding_map.size());
for (const auto& kv : embedding_map) {
sd_embedding_t item;
item.name = kv.first.c_str();
item.path = kv.second.c_str();
embedding_vec.emplace_back(item);
}
sd_ctx_params_t sd_ctx_params = {
model_path.c_str(),
clip_l_path.c_str(),
@ -723,7 +785,8 @@ struct SDContextParams {
taesd_path.c_str(),
control_net_path.c_str(),
lora_model_dir.c_str(),
embedding_dir.c_str(),
embedding_vec.data(),
static_cast<uint32_t>(embedding_vec.size()),
photo_maker_path.c_str(),
tensor_type_rules.c_str(),
vae_decode_only,
@ -777,6 +840,15 @@ static std::string vec_str_to_string(const std::vector<std::string>& v) {
return oss.str();
}
static bool is_absolute_path(const std::string& p) {
#ifdef _WIN32
// Windows: C:/path or C:\path
return p.size() > 1 && std::isalpha(static_cast<unsigned char>(p[0])) && p[1] == ':';
#else
return !p.empty() && p[0] == '/';
#endif
}
struct SDGenerationParams {
std::string prompt;
std::string negative_prompt;
@ -817,7 +889,12 @@ struct SDGenerationParams {
std::string pm_id_embed_path;
float pm_style_strength = 20.f;
int upscale_repeats = 1;
int upscale_repeats = 1;
int upscale_tile_size = 128;
std::map<std::string, float> lora_map;
std::map<std::string, float> high_noise_lora_map;
std::vector<sd_lora_t> lora_vec;
SDGenerationParams() {
sd_sample_params_init(&sample_params);
@ -910,6 +987,10 @@ struct SDGenerationParams {
"--upscale-repeats",
"Run the ESRGAN upscaler this many times (default: 1)",
&upscale_repeats},
{"",
"--upscale-tile-size",
"tile size for ESRGAN upscaling (default: 128)",
&upscale_tile_size},
};
options.float_options = {
@ -1251,7 +1332,88 @@ struct SDGenerationParams {
return true;
}
bool process_and_check(SDMode mode) {
void extract_and_remove_lora(const std::string& lora_model_dir) {
static const std::regex re(R"(<lora:([^:>]+):([^>]+)>)");
static const std::vector<std::string> valid_ext = {".pt", ".safetensors", ".gguf"};
std::smatch m;
std::string tmp = prompt;
while (std::regex_search(tmp, m, re)) {
std::string raw_path = m[1].str();
const std::string raw_mul = m[2].str();
float mul = 0.f;
try {
mul = std::stof(raw_mul);
} catch (...) {
tmp = m.suffix().str();
prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only);
continue;
}
bool is_high_noise = false;
static const std::string prefix = "|high_noise|";
if (raw_path.rfind(prefix, 0) == 0) {
raw_path.erase(0, prefix.size());
is_high_noise = true;
}
fs::path final_path;
if (is_absolute_path(raw_path)) {
final_path = raw_path;
} else {
final_path = fs::path(lora_model_dir) / raw_path;
}
if (!fs::exists(final_path)) {
bool found = false;
for (const auto& ext : valid_ext) {
fs::path try_path = final_path;
try_path += ext;
if (fs::exists(try_path)) {
final_path = try_path;
found = true;
break;
}
}
if (!found) {
printf("can not found lora %s\n", final_path.lexically_normal().string().c_str());
tmp = m.suffix().str();
prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only);
continue;
}
}
const std::string key = final_path.lexically_normal().string();
if (is_high_noise)
high_noise_lora_map[key] += mul;
else
lora_map[key] += mul;
prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only);
tmp = m.suffix().str();
}
for (const auto& kv : lora_map) {
sd_lora_t item;
item.is_high_noise = false;
item.path = kv.first.c_str();
item.multiplier = kv.second;
lora_vec.emplace_back(item);
}
for (const auto& kv : high_noise_lora_map) {
sd_lora_t item;
item.is_high_noise = true;
item.path = kv.first.c_str();
item.multiplier = kv.second;
lora_vec.emplace_back(item);
}
}
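For reference, the tag syntax this parser consumes looks like the example below (the file names are hypothetical; the `|high_noise|` prefix routes a LoRA to the high-noise phase, and each tag is stripped from the prompt after parsing):

```
.\bin\Release\sd.exe -m ..\models\v1-5-pruned-emaonly.safetensors --lora-model-dir ..\models\loras -p "a lovely cat <lora:lcm-lora-sdv1-5:1> <lora:|high_noise|detail-tweaker:0.5>"
```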
bool process_and_check(SDMode mode, const std::string& lora_model_dir) {
if (width <= 0) {
fprintf(stderr, "error: the width must be greater than 0\n");
return false;
@ -1350,6 +1512,10 @@ struct SDGenerationParams {
return false;
}
if (upscale_tile_size < 1) {
return false;
}
if (mode == UPSCALE) {
if (init_image_path.length() == 0) {
fprintf(stderr, "error: upscale mode needs an init image (--init-img)\n");
@ -1362,14 +1528,44 @@ struct SDGenerationParams {
seed = rand();
}
extract_and_remove_lora(lora_model_dir);
return true;
}
std::string to_string() const {
char* sample_params_str = sd_sample_params_to_str(&sample_params);
char* high_noise_sample_params_str = sd_sample_params_to_str(&high_noise_sample_params);
std::ostringstream lora_ss;
lora_ss << "{\n";
for (auto it = lora_map.begin(); it != lora_map.end(); ++it) {
lora_ss << " \"" << it->first << "\": \"" << it->second << "\"";
if (std::next(it) != lora_map.end()) {
lora_ss << ",";
}
lora_ss << "\n";
}
lora_ss << " }";
std::string loras_str = lora_ss.str();
lora_ss = std::ostringstream();
lora_ss << "{\n";
for (auto it = high_noise_lora_map.begin(); it != high_noise_lora_map.end(); ++it) {
lora_ss << " \"" << it->first << "\": \"" << it->second << "\"";
if (std::next(it) != high_noise_lora_map.end()) {
lora_ss << ",";
}
lora_ss << "\n";
}
lora_ss << " }";
std::string high_noise_loras_str = lora_ss.str();
std::ostringstream oss;
oss << "SDGenerationParams {\n"
<< " loras: \"" << loras_str << "\",\n"
<< " high_noise_loras: \"" << high_noise_loras_str << "\",\n"
<< " prompt: \"" << prompt << "\",\n"
<< " negative_prompt: \"" << negative_prompt << "\",\n"
<< " clip_skip: " << clip_skip << ",\n"
@ -1405,9 +1601,14 @@ struct SDGenerationParams {
<< " control_strength: " << control_strength << ",\n"
<< " seed: " << seed << ",\n"
<< " upscale_repeats: " << upscale_repeats << ",\n"
<< " upscale_tile_size: " << upscale_tile_size << ",\n"
<< "}";
free(sample_params_str);
free(high_noise_sample_params_str);
return oss.str();
}
};
static std::string version_string() {
return std::string("stable-diffusion.cpp version ") + sd_version() + ", commit " + sd_commit();
}

View File

@ -134,6 +134,7 @@ struct SDSvrParams {
};
void print_usage(int argc, const char* argv[], const std::vector<ArgOptions>& options_list) {
std::cout << version_string() << "\n";
std::cout << "Usage: " << argv[0] << " [options]\n\n";
std::cout << "Svr Options:\n";
options_list[0].print();
@ -362,7 +363,7 @@ int main(int argc, const char** argv) {
return;
}
if (!gen_params.process_and_check(IMG_GEN)) {
if (!gen_params.process_and_check(IMG_GEN, ctx_params.lora_model_dir)) {
res.status = 400;
res.set_content(R"({"error":"invalid params"})", "application/json");
return;
@ -378,6 +379,8 @@ int main(int argc, const char** argv) {
std::vector<sd_image_t> pmid_images;
sd_img_gen_params_t img_gen_params = {
gen_params.lora_vec.data(),
static_cast<uint32_t>(gen_params.lora_vec.size()),
gen_params.prompt.c_str(),
gen_params.negative_prompt.c_str(),
gen_params.clip_skip,

flux.hpp (138 changed lines)
View File

@ -134,6 +134,54 @@ namespace Flux {
}
};
struct MLP : public UnaryBlock {
bool use_mlp_silu_act;
public:
MLP(int64_t hidden_size, int64_t intermediate_size, bool use_mlp_silu_act = false, bool bias = false)
: use_mlp_silu_act(use_mlp_silu_act) {
int64_t mlp_mult_factor = use_mlp_silu_act ? 2 : 1;
blocks["0"] = std::make_shared<Linear>(hidden_size, intermediate_size * mlp_mult_factor, bias);
blocks["2"] = std::make_shared<Linear>(intermediate_size, hidden_size, bias);
}
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
auto mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["0"]);
auto mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["2"]);
x = mlp_0->forward(ctx, x);
if (use_mlp_silu_act) {
x = ggml_ext_silu_act(ctx->ggml_ctx, x);
} else {
x = ggml_gelu_inplace(ctx->ggml_ctx, x);
}
x = mlp_2->forward(ctx, x);
return x;
}
};
struct YakMLP : public UnaryBlock {
public:
YakMLP(int64_t hidden_size, int64_t intermediate_size, bool bias = true) {
blocks["gate_proj"] = std::make_shared<Linear>(hidden_size, intermediate_size, bias);
blocks["up_proj"] = std::make_shared<Linear>(hidden_size, intermediate_size, bias);
blocks["down_proj"] = std::make_shared<Linear>(intermediate_size, hidden_size, bias);
}
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
auto gate_proj = std::dynamic_pointer_cast<Linear>(blocks["gate_proj"]);
auto up_proj = std::dynamic_pointer_cast<Linear>(blocks["up_proj"]);
auto down_proj = std::dynamic_pointer_cast<Linear>(blocks["down_proj"]);
auto gate = gate_proj->forward(ctx, x);
gate = ggml_silu_inplace(ctx->ggml_ctx, gate);
x = up_proj->forward(ctx, x);
x = ggml_mul(ctx->ggml_ctx, x, gate);
x = down_proj->forward(ctx, x);
return x;
}
};
struct ModulationOut {
ggml_tensor* shift = nullptr;
ggml_tensor* scale = nullptr;
@ -199,7 +247,6 @@ namespace Flux {
struct DoubleStreamBlock : public GGMLBlock {
bool prune_mod;
int idx = 0;
bool use_mlp_silu_act;
public:
DoubleStreamBlock(int64_t hidden_size,
@ -210,10 +257,10 @@ namespace Flux {
bool prune_mod = false,
bool share_modulation = false,
bool mlp_proj_bias = true,
bool use_yak_mlp = false,
bool use_mlp_silu_act = false)
: idx(idx), prune_mod(prune_mod), use_mlp_silu_act(use_mlp_silu_act) {
int64_t mlp_hidden_dim = hidden_size * mlp_ratio;
int64_t mlp_mult_factor = use_mlp_silu_act ? 2 : 1;
: idx(idx), prune_mod(prune_mod) {
int64_t mlp_hidden_dim = hidden_size * mlp_ratio;
if (!prune_mod && !share_modulation) {
blocks["img_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
@ -222,9 +269,11 @@ namespace Flux {
blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias));
blocks["img_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["img_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias));
// img_mlp.1 is nn.GELU(approximate="tanh")
blocks["img_mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(mlp_hidden_dim, hidden_size, mlp_proj_bias));
if (use_yak_mlp) {
blocks["img_mlp"] = std::shared_ptr<GGMLBlock>(new YakMLP(hidden_size, mlp_hidden_dim, mlp_proj_bias));
} else {
blocks["img_mlp"] = std::shared_ptr<GGMLBlock>(new MLP(hidden_size, mlp_hidden_dim, use_mlp_silu_act, mlp_proj_bias));
}
if (!prune_mod && !share_modulation) {
blocks["txt_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
@ -233,9 +282,11 @@ namespace Flux {
blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias));
blocks["txt_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["txt_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias));
// img_mlp.1 is nn.GELU(approximate="tanh")
blocks["txt_mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(mlp_hidden_dim, hidden_size, mlp_proj_bias));
if (use_yak_mlp) {
blocks["txt_mlp"] = std::shared_ptr<GGMLBlock>(new YakMLP(hidden_size, mlp_hidden_dim, mlp_proj_bias));
} else {
blocks["txt_mlp"] = std::shared_ptr<GGMLBlock>(new MLP(hidden_size, mlp_hidden_dim, use_mlp_silu_act, mlp_proj_bias));
}
}
std::vector<ModulationOut> get_distil_img_mod(GGMLRunnerContext* ctx, struct ggml_tensor* vec) {
@ -272,15 +323,13 @@ namespace Flux {
auto img_attn = std::dynamic_pointer_cast<SelfAttention>(blocks["img_attn"]);
auto img_norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["img_norm2"]);
auto img_mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["img_mlp.0"]);
auto img_mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["img_mlp.2"]);
auto img_mlp = std::dynamic_pointer_cast<UnaryBlock>(blocks["img_mlp"]);
auto txt_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["txt_norm1"]);
auto txt_attn = std::dynamic_pointer_cast<SelfAttention>(blocks["txt_attn"]);
auto txt_norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["txt_norm2"]);
auto txt_mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["txt_mlp.0"]);
auto txt_mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["txt_mlp.2"]);
auto txt_mlp = std::dynamic_pointer_cast<UnaryBlock>(blocks["txt_mlp"]);
if (img_mods.empty()) {
if (prune_mod) {
@ -348,27 +397,15 @@ namespace Flux {
// calculate the img blocks
img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate));
auto img_mlp_out = img_mlp_0->forward(ctx, Flux::modulate(ctx->ggml_ctx, img_norm2->forward(ctx, img), img_mod2.shift, img_mod2.scale));
if (use_mlp_silu_act) {
img_mlp_out = ggml_ext_silu_act(ctx->ggml_ctx, img_mlp_out);
} else {
img_mlp_out = ggml_gelu_inplace(ctx->ggml_ctx, img_mlp_out);
}
img_mlp_out = img_mlp_2->forward(ctx, img_mlp_out);
auto img_mlp_out = img_mlp->forward(ctx, Flux::modulate(ctx->ggml_ctx, img_norm2->forward(ctx, img), img_mod2.shift, img_mod2.scale));
img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_mlp_out, img_mod2.gate));
// calculate the txt blocks
txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_attn->post_attention(ctx, txt_attn_out), txt_mod1.gate));
auto txt_mlp_out = txt_mlp_0->forward(ctx, Flux::modulate(ctx->ggml_ctx, txt_norm2->forward(ctx, txt), txt_mod2.shift, txt_mod2.scale));
if (use_mlp_silu_act) {
txt_mlp_out = ggml_ext_silu_act(ctx->ggml_ctx, txt_mlp_out);
} else {
txt_mlp_out = ggml_gelu_inplace(ctx->ggml_ctx, txt_mlp_out);
}
txt_mlp_out = txt_mlp_2->forward(ctx, txt_mlp_out);
txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_mlp_out, txt_mod2.gate));
auto txt_mlp_out = txt_mlp->forward(ctx, Flux::modulate(ctx->ggml_ctx, txt_norm2->forward(ctx, txt), txt_mod2.shift, txt_mod2.scale));
txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_mlp_out, txt_mod2.gate));
return {img, txt};
}
@ -381,6 +418,7 @@ namespace Flux {
int64_t mlp_hidden_dim;
bool prune_mod;
int idx = 0;
bool use_yak_mlp;
bool use_mlp_silu_act;
int64_t mlp_mult_factor;
@ -393,8 +431,9 @@ namespace Flux {
bool prune_mod = false,
bool share_modulation = false,
bool mlp_proj_bias = true,
bool use_yak_mlp = false,
bool use_mlp_silu_act = false)
: hidden_size(hidden_size), num_heads(num_heads), idx(idx), prune_mod(prune_mod), use_mlp_silu_act(use_mlp_silu_act) {
: hidden_size(hidden_size), num_heads(num_heads), idx(idx), prune_mod(prune_mod), use_yak_mlp(use_yak_mlp), use_mlp_silu_act(use_mlp_silu_act) {
int64_t head_dim = hidden_size / num_heads;
float scale = qk_scale;
if (scale <= 0.f) {
@ -402,7 +441,7 @@ namespace Flux {
}
mlp_hidden_dim = hidden_size * mlp_ratio;
mlp_mult_factor = 1;
if (use_mlp_silu_act) {
if (use_yak_mlp || use_mlp_silu_act) {
mlp_mult_factor = 2;
}
@ -481,7 +520,9 @@ namespace Flux {
k = norm->key_norm(ctx, k);
auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_token, hidden_size]
if (use_mlp_silu_act) {
if (use_yak_mlp) {
mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp, false);
} else if (use_mlp_silu_act) {
mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp);
} else {
mlp = ggml_gelu_inplace(ctx->ggml_ctx, mlp);
@ -726,6 +767,8 @@ namespace Flux {
int64_t in_dim = 64;
bool disable_bias = false;
bool share_modulation = false;
bool semantic_txt_norm = false;
bool use_yak_mlp = false;
bool use_mlp_silu_act = false;
float ref_index_scale = 1.f;
ChromaRadianceParams chroma_radiance_params;
@ -759,6 +802,9 @@ namespace Flux {
blocks["guidance_in"] = std::make_shared<MLPEmbedder>(256, params.hidden_size, !params.disable_bias);
}
}
if (params.semantic_txt_norm) {
blocks["txt_norm"] = std::make_shared<RMSNorm>(params.context_in_dim);
}
blocks["txt_in"] = std::make_shared<Linear>(params.context_in_dim, params.hidden_size, !params.disable_bias);
for (int i = 0; i < params.depth; i++) {
@ -770,6 +816,7 @@ namespace Flux {
params.is_chroma,
params.share_modulation,
!params.disable_bias,
params.use_yak_mlp,
params.use_mlp_silu_act);
}
@ -782,6 +829,7 @@ namespace Flux {
params.is_chroma,
params.share_modulation,
!params.disable_bias,
params.use_yak_mlp,
params.use_mlp_silu_act);
}
@ -948,6 +996,12 @@ namespace Flux {
ss_mods = single_stream_modulation->forward(ctx, vec);
}
if (params.semantic_txt_norm) {
auto semantic_txt_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["txt_norm"]);
txt = semantic_txt_norm->forward(ctx, txt);
}
txt = txt_in->forward(ctx, txt);
for (int i = 0; i < params.depth; i++) {
@ -1206,6 +1260,11 @@ namespace Flux {
} else if (version == VERSION_CHROMA_RADIANCE) {
flux_params.in_channels = 3;
flux_params.patch_size = 16;
} else if (version == VERSION_OVIS_IMAGE) {
flux_params.semantic_txt_norm = true;
flux_params.use_yak_mlp = true;
flux_params.context_in_dim = 2048;
flux_params.vec_in_dim = 0;
} else if (sd_version_is_flux2(version)) {
flux_params.context_in_dim = 15360;
flux_params.in_channels = 128;
@ -1364,13 +1423,22 @@ namespace Flux {
ref_latents[i] = to_backend(ref_latents[i]);
}
std::set<int> txt_arange_dims;
if (sd_version_is_flux2(version)) {
txt_arange_dims = {3};
increase_ref_index = true;
} else if (version == VERSION_OVIS_IMAGE) {
txt_arange_dims = {1, 2};
}
pe_vec = Rope::gen_flux_pe(x->ne[1],
x->ne[0],
flux_params.patch_size,
x->ne[3],
context->ne[1],
txt_arange_dims,
ref_latents,
sd_version_is_flux2(version) ? true : increase_ref_index,
increase_ref_index,
flux_params.ref_index_scale,
flux_params.theta,
flux_params.axes_dim);
@ -1413,7 +1481,7 @@ namespace Flux {
return gf;
}
void compute(int n_threads,
bool compute(int n_threads,
struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
@ -1434,7 +1502,7 @@ namespace Flux {
return build_graph(x, timesteps, context, c_concat, y, guidance, ref_latents, increase_ref_index, skip_layers);
};
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
}
void test() {

View File

@ -60,6 +60,14 @@
#define SD_UNUSED(x) (void)(x)
#endif
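// round n up to the next multiple of `multiple`,
// e.g. align_up(130, 64) == 192 and align_up_offset(130, 64) == 62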
__STATIC_INLINE__ int align_up_offset(int n, int multiple) {
return (multiple - n % multiple) % multiple;
}
__STATIC_INLINE__ int align_up(int n, int multiple) {
return n + align_up_offset(n, multiple);
}
__STATIC_INLINE__ void ggml_log_callback_default(ggml_log_level level, const char* text, void*) {
switch (level) {
case GGML_LOG_LEVEL_DEBUG:
@ -760,17 +768,23 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_ext_chunk(struct ggml_co
return chunks;
}
__STATIC_INLINE__ ggml_tensor* ggml_ext_silu_act(ggml_context* ctx, ggml_tensor* x) {
__STATIC_INLINE__ ggml_tensor* ggml_ext_silu_act(ggml_context* ctx, ggml_tensor* x, bool gate_first = true) {
// x: [ne3, ne2, ne1, ne0]
// return: [ne3, ne2, ne1, ne0/2]
auto x_vec = ggml_ext_chunk(ctx, x, 2, 0);
auto x1 = x_vec[0]; // [ne3, ne2, ne1, ne0/2]
auto x2 = x_vec[1]; // [ne3, ne2, ne1, ne0/2]
ggml_tensor* gate;
if (gate_first) {
gate = x_vec[0];
x = x_vec[1];
} else {
x = x_vec[0];
gate = x_vec[1];
}
x1 = ggml_gelu_inplace(ctx, x1);
gate = ggml_silu_inplace(ctx, gate);
x = ggml_mul(ctx, x1, x2); // [ne3, ne2, ne1, ne0/2]
x = ggml_mul(ctx, x, gate); // [ne3, ne2, ne1, ne0/2]
return x;
}
@ -1938,25 +1952,35 @@ public:
return ggml_get_tensor(cache_ctx, name.c_str());
}
void compute(get_graph_cb_t get_graph,
bool compute(get_graph_cb_t get_graph,
int n_threads,
bool free_compute_buffer_immediately = true,
struct ggml_tensor** output = nullptr,
struct ggml_context* output_ctx = nullptr) {
if (!offload_params_to_runtime_backend()) {
LOG_ERROR("%s offload params to runtime backend failed", get_desc().c_str());
return;
return false;
}
if (!alloc_compute_buffer(get_graph)) {
LOG_ERROR("%s alloc compute buffer failed", get_desc().c_str());
return false;
}
alloc_compute_buffer(get_graph);
reset_compute_ctx();
struct ggml_cgraph* gf = get_compute_graph(get_graph);
GGML_ASSERT(ggml_gallocr_alloc_graph(compute_allocr, gf));
if (!ggml_gallocr_alloc_graph(compute_allocr, gf)) {
LOG_ERROR("%s alloc compute graph failed", get_desc().c_str());
return false;
}
copy_data_to_backend_tensor();
if (ggml_backend_is_cpu(runtime_backend)) {
ggml_backend_cpu_set_n_threads(runtime_backend, n_threads);
}
ggml_backend_graph_compute(runtime_backend, gf);
ggml_status status = ggml_backend_graph_compute(runtime_backend, gf);
if (status != GGML_STATUS_SUCCESS) {
LOG_ERROR("%s compute failed: %s", get_desc().c_str(), ggml_status_to_string(status));
return false;
}
#ifdef GGML_PERF
ggml_graph_print(gf);
#endif
@ -1974,6 +1998,7 @@ public:
if (free_compute_buffer_immediately) {
free_compute_buffer();
}
return true;
}
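With compute() now reporting failure instead of silently returning, callers are expected to check the result and propagate it; a hypothetical caller sketch (the `runner`/`out` names are illustrative):

```cpp
// hypothetical caller: propagate compute() failure instead of aborting
ggml_tensor* out = nullptr;
if (!runner.compute(get_graph, n_threads, /*free_compute_buffer_immediately=*/true, &out, output_ctx)) {
    LOG_ERROR("%s compute failed", runner.get_desc().c_str());
    return false;  // bubble the error up to the sampler / pipeline
}
```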
void set_flash_attention_enabled(bool enabled) {

View File

@ -91,6 +91,41 @@ const float flux_latent_rgb_proj[16][3] = {
{-0.111849f, -0.055589f, -0.032361f}};
float flux_latent_rgb_bias[3] = {0.024600f, -0.006937f, -0.008089f};
const float flux2_latent_rgb_proj[32][3] = {
{0.000736f, -0.008385f, -0.019710f},
{-0.001352f, -0.016392f, 0.020693f},
{-0.006376f, 0.002428f, 0.036736f},
{0.039384f, 0.074167f, 0.119789f},
{0.007464f, -0.005705f, -0.004734f},
{-0.004086f, 0.005287f, -0.000409f},
{-0.032835f, 0.050802f, -0.028120f},
{-0.003158f, -0.000835f, 0.000406f},
{-0.112840f, -0.084337f, -0.023083f},
{0.001462f, -0.006656f, 0.000549f},
{-0.009980f, -0.007480f, 0.009702f},
{0.032540f, 0.000214f, -0.061388f},
{0.011023f, 0.000694f, 0.007143f},
{-0.001468f, -0.006723f, -0.001678f},
{-0.005921f, -0.010320f, -0.003907f},
{-0.028434f, 0.027584f, 0.018457f},
{0.014349f, 0.011523f, 0.000441f},
{0.009874f, 0.003081f, 0.001507f},
{0.002218f, 0.005712f, 0.001563f},
{0.053010f, -0.019844f, 0.008683f},
{-0.002507f, 0.005384f, 0.000938f},
{-0.002177f, -0.011366f, 0.003559f},
{-0.000261f, 0.015121f, -0.003240f},
{-0.003944f, -0.002083f, 0.005043f},
{-0.009138f, 0.011336f, 0.003781f},
{0.011429f, 0.003985f, -0.003855f},
{0.010518f, -0.005586f, 0.010131f},
{0.007883f, 0.002912f, -0.001473f},
{-0.003318f, -0.003160f, 0.003684f},
{-0.034560f, -0.008740f, 0.012996f},
{0.000166f, 0.001079f, -0.012153f},
{0.017772f, 0.000937f, -0.011953f}};
float flux2_latent_rgb_bias[3] = {-0.028738f, -0.098463f, -0.107619f};
// This one was taken straight from
// https://github.com/Stability-AI/sd3.5/blob/8565799a3b41eb0c7ba976d18375f0f753f56402/sd3_impls.py#L288-L303
// (MIT License)
@ -128,16 +163,42 @@ const float sd_latent_rgb_proj[4][3] = {
{-0.178022f, -0.200862f, -0.678514f}};
float sd_latent_rgb_bias[3] = {-0.017478f, -0.055834f, -0.105825f};
void preview_latent_video(uint8_t* buffer, struct ggml_tensor* latents, const float (*latent_rgb_proj)[3], const float latent_rgb_bias[3], int width, int height, int frames, int dim) {
void preview_latent_video(uint8_t* buffer, struct ggml_tensor* latents, const float (*latent_rgb_proj)[3], const float latent_rgb_bias[3], int patch_size) {
size_t buffer_head = 0;
uint32_t latent_width = latents->ne[0];
uint32_t latent_height = latents->ne[1];
uint32_t dim = latents->ne[ggml_n_dims(latents) - 1];
uint32_t frames = 1;
if (ggml_n_dims(latents) == 4) {
frames = latents->ne[2];
}
uint32_t rgb_width = latent_width * patch_size;
uint32_t rgb_height = latent_height * patch_size;
uint32_t unpatched_dim = dim / (patch_size * patch_size);
for (int k = 0; k < frames; k++) {
for (int j = 0; j < height; j++) {
for (int i = 0; i < width; i++) {
size_t latent_id = (i * latents->nb[0] + j * latents->nb[1] + k * latents->nb[2]);
for (int rgb_x = 0; rgb_x < rgb_width; rgb_x++) {
for (int rgb_y = 0; rgb_y < rgb_height; rgb_y++) {
int latent_x = rgb_x / patch_size;
int latent_y = rgb_y / patch_size;
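// the pixel's position within its patch selects the channel group:
// patched latents pack each patch_size x patch_size block of pixels
// into consecutive channels (pixel-shuffle layout)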
int channel_offset = 0;
if (patch_size > 1) {
channel_offset = ((rgb_y % patch_size) * patch_size + (rgb_x % patch_size));
}
size_t latent_id = (latent_x * latents->nb[0] + latent_y * latents->nb[1] + k * latents->nb[2]);
// pixel index in row-major order; advances by 1 per pixel
size_t pixel_id = k * rgb_width * rgb_height + rgb_y * rgb_width + rgb_x;
float r = 0, g = 0, b = 0;
if (latent_rgb_proj != nullptr) {
for (int d = 0; d < dim; d++) {
float value = *(float*)((char*)latents->data + latent_id + d * latents->nb[ggml_n_dims(latents) - 1]);
for (int d = 0; d < unpatched_dim; d++) {
float value = *(float*)((char*)latents->data + latent_id + (d * patch_size * patch_size + channel_offset) * latents->nb[ggml_n_dims(latents) - 1]);
r += value * latent_rgb_proj[d][0];
g += value * latent_rgb_proj[d][1];
b += value * latent_rgb_proj[d][2];
@ -164,9 +225,9 @@ void preview_latent_video(uint8_t* buffer, struct ggml_tensor* latents, const fl
g = g >= 0 ? g <= 1 ? g : 1 : 0;
b = b >= 0 ? b <= 1 ? b : 1 : 0;
buffer[buffer_head++] = (uint8_t)(r * 255);
buffer[buffer_head++] = (uint8_t)(g * 255);
buffer[buffer_head++] = (uint8_t)(b * 255);
buffer[pixel_id * 3 + 0] = (uint8_t)(r * 255);
buffer[pixel_id * 3 + 1] = (uint8_t)(g * 255);
buffer[pixel_id * 3 + 2] = (uint8_t)(b * 255);
}
}
}

llm.hpp (75 changed lines)
View File

@ -356,6 +356,10 @@ namespace LLM {
"<|fim_pad|>",
"<|repo_name|>",
"<|file_sep|>",
"<tool_response>",
"</tool_response>",
"<think>",
"</think>",
};
if (merges_utf8_str.size() > 0) {
@ -859,11 +863,11 @@ namespace LLM {
}
if (arch == LLMArch::MISTRAL_SMALL_3_2) {
q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 131072, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 131072, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 8192, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 8192, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
} else if (arch == LLMArch::QWEN3) {
q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 151936, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 151936, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
} else {
int sections[4] = {16, 24, 24, 0};
q = ggml_rope_multi(ctx->ggml_ctx, q, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
@ -1073,29 +1077,22 @@ namespace LLM {
: GGMLRunner(backend, offload_params_to_cpu), enable_vision(enable_vision_) {
params.arch = arch;
if (arch == LLMArch::MISTRAL_SMALL_3_2) {
params.num_layers = 40;
params.hidden_size = 5120;
params.intermediate_size = 32768;
params.head_dim = 128;
params.num_heads = 32;
params.num_kv_heads = 8;
params.qkv_bias = false;
params.vocab_size = 131072;
params.rms_norm_eps = 1e-5f;
params.head_dim = 128;
params.num_heads = 32;
params.num_kv_heads = 8;
params.qkv_bias = false;
params.rms_norm_eps = 1e-5f;
} else if (arch == LLMArch::QWEN3) {
params.num_layers = 36;
params.hidden_size = 2560;
params.intermediate_size = 9728;
params.head_dim = 128;
params.num_heads = 32;
params.num_kv_heads = 8;
params.qkv_bias = false;
params.qk_norm = true;
params.vocab_size = 151936;
params.rms_norm_eps = 1e-6f;
params.head_dim = 128;
params.num_heads = 32;
params.num_kv_heads = 8;
params.qkv_bias = false;
params.qk_norm = true;
params.rms_norm_eps = 1e-6f;
}
bool have_vision_weight = false;
bool llama_cpp_style = false;
params.num_layers = 0;
for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first;
if (tensor_name.find(prefix) == std::string::npos)
@ -1105,10 +1102,36 @@ namespace LLM {
have_vision_weight = true;
if (contains(tensor_name, "attn.q_proj")) {
llama_cpp_style = true;
break;
}
continue;
}
pos = tensor_name.find("layers.");
if (pos != std::string::npos) {
tensor_name = tensor_name.substr(pos); // remove prefix
auto items = split_string(tensor_name, '.');
if (items.size() > 1) {
int block_index = atoi(items[1].c_str());
if (block_index + 1 > params.num_layers) {
params.num_layers = block_index + 1;
}
}
}
if (contains(tensor_name, "embed_tokens.weight")) {
params.hidden_size = pair.second.ne[0];
params.vocab_size = pair.second.ne[1];
}
if (contains(tensor_name, "layers.0.mlp.gate_proj.weight")) {
params.intermediate_size = pair.second.ne[1];
}
}
if (arch == LLMArch::QWEN3 && params.num_layers == 28) { // Qwen3 2B
params.num_heads = 16;
}
LOG_DEBUG("llm: num_layers = %" PRId64 ", vocab_size = %" PRId64 ", hidden_size = %" PRId64 ", intermediate_size = %" PRId64,
params.num_layers,
params.vocab_size,
params.hidden_size,
params.intermediate_size);
if (enable_vision && !have_vision_weight) {
LOG_WARN("no vision weights detected, vision disabled");
enable_vision = false;
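
The hunk above replaces hard-coded per-arch hyperparameters with detection from the tensor storage map: the layer count from the highest layers.N index, hidden/vocab size from embed_tokens.weight, and intermediate size from the first gate_proj. A simplified sketch of the same scan over a plain (name -> shape) map; the container and shapes here are stand-ins, not the real TensorStorage type:

#include <cstdio>
#include <cstdlib>
#include <map>
#include <string>
#include <vector>

int main() {
    std::map<std::string, std::vector<long>> tensors = {
        {"model.embed_tokens.weight", {2560, 151936}},             // ne[0] = hidden, ne[1] = vocab
        {"model.layers.0.mlp.gate_proj.weight", {2560, 9728}},     // ne[1] = intermediate
        {"model.layers.35.self_attn.q_proj.weight", {2560, 4096}},
    };
    long num_layers = 0, hidden = 0, vocab = 0, intermediate = 0;
    for (const auto& [name, ne] : tensors) {
        size_t pos = name.find("layers.");
        if (pos != std::string::npos) {
            long idx = atol(name.c_str() + pos + 7); // digits right after "layers."
            if (idx + 1 > num_layers) num_layers = idx + 1;
        }
        if (name.find("embed_tokens.weight") != std::string::npos) {
            hidden = ne[0];
            vocab  = ne[1];
        }
        if (name.find("layers.0.mlp.gate_proj.weight") != std::string::npos) {
            intermediate = ne[1];
        }
    }
    printf("layers=%ld hidden=%ld vocab=%ld intermediate=%ld\n",
           num_layers, hidden, vocab, intermediate); // layers=36 hidden=2560 vocab=151936 intermediate=9728
    return 0;
}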
@ -1191,7 +1214,7 @@ namespace LLM {
return gf;
}
void compute(const int n_threads,
bool compute(const int n_threads,
struct ggml_tensor* input_ids,
std::vector<std::pair<int, ggml_tensor*>> image_embeds,
std::set<int> out_layers,
@ -1200,7 +1223,7 @@ namespace LLM {
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(input_ids, image_embeds, out_layers);
};
GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
}
int64_t get_num_image_tokens(int64_t t, int64_t h, int64_t w) {
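
This void -> bool change on compute() repeats across nearly every runner in this commit (MMDiT, T5, TAE, UNet, VAE, WAN, Z-Image below): the bool from GGMLRunner::compute() is now surfaced instead of swallowed. A caller-side fragment showing the intended use; the runner and tensor names are hypothetical placeholders, not the exact call sites:

// Before: compute() returned void and backend failures went unnoticed until a
// nullptr/garbage output was consumed. Now the status is propagated:
ggml_tensor* hidden_states = nullptr;
if (!llm_runner->compute(n_threads, input_ids, image_embeds, out_layers,
                         &hidden_states, output_ctx)) {
    LOG_ERROR("llm compute failed");
    return false; // abort instead of reading an invalid output tensor
}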

mmdit.hpp
View File

@ -894,7 +894,7 @@ struct MMDiTRunner : public GGMLRunner {
return gf;
}
void compute(int n_threads,
bool compute(int n_threads,
struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
@ -910,7 +910,7 @@ struct MMDiTRunner : public GGMLRunner {
return build_graph(x, timesteps, context, y, skip_layers);
};
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
}
void test() {

model.cpp
View File

@ -123,11 +123,6 @@ bool is_unused_tensor(std::string name) {
return false;
}
float bf16_to_f32(uint16_t bfloat16) {
uint32_t val_bits = (static_cast<uint32_t>(bfloat16) << 16);
return *reinterpret_cast<float*>(&val_bits);
}
uint16_t f8_e4m3_to_f16(uint8_t f8) {
// do we need to support uz?
@ -210,13 +205,6 @@ uint16_t f8_e5m2_to_f16(uint8_t fp8) {
return fp16_sign | (fp16_exponent << 10) | fp16_mantissa;
}
void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) {
// support inplace op
for (int64_t i = n - 1; i >= 0; i--) {
dst[i] = bf16_to_f32(src[i]);
}
}
void f8_e4m3_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {
// support inplace op
for (int64_t i = n - 1; i >= 0; i--) {
@ -495,7 +483,7 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
if (dtype == "F16") {
ttype = GGML_TYPE_F16;
} else if (dtype == "BF16") {
ttype = GGML_TYPE_F32;
ttype = GGML_TYPE_BF16;
} else if (dtype == "F32") {
ttype = GGML_TYPE_F32;
} else if (dtype == "F64") {
@ -623,10 +611,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
size_t tensor_data_size = end - begin;
if (dtype == "BF16") {
tensor_storage.is_bf16 = true;
GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
} else if (dtype == "F8_E4M3") {
if (dtype == "F8_E4M3") {
tensor_storage.is_f8_e4m3 = true;
// f8 -> f16
GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
@ -1071,6 +1056,9 @@ SDVersion ModelLoader::get_sd_version() {
if (tensor_storage.name.find("model.diffusion_model.double_stream_modulation_img.lin.weight") != std::string::npos) {
return VERSION_FLUX2;
}
if (tensor_storage.name.find("model.diffusion_model.double_blocks.0.img_mlp.gate_proj.weight") != std::string::npos) {
return VERSION_OVIS_IMAGE;
}
if (tensor_storage.name.find("model.diffusion_model.cap_embedder.0.weight") != std::string::npos) {
return VERSION_Z_IMAGE;
}
@ -1522,9 +1510,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
read_time_ms.fetch_add(t1 - t0);
t0 = ggml_time_ms();
if (tensor_storage.is_bf16) {
bf16_to_f32_vec((uint16_t*)read_buf, (float*)target_buf, tensor_storage.nelements());
} else if (tensor_storage.is_f8_e4m3) {
if (tensor_storage.is_f8_e4m3) {
f8_e4m3_to_f16_vec((uint8_t*)read_buf, (uint16_t*)target_buf, tensor_storage.nelements());
} else if (tensor_storage.is_f8_e5m2) {
f8_e5m2_to_f16_vec((uint8_t*)read_buf, (uint16_t*)target_buf, tensor_storage.nelements());
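
The deleted bf16 path is possible because safetensors "BF16" now maps to GGML_TYPE_BF16 (see the str_to_ggml_type fix above), so ggml consumes the data natively instead of the loader widening it to f32. For reference, the removed helper was a 16-bit shift into a float32 bit pattern; a standalone sketch, using memcpy rather than the old reinterpret_cast (which is undefined behaviour under strict aliasing):

#include <cstdint>
#include <cstdio>
#include <cstring>

// bf16 is the high 16 bits of an IEEE-754 float32.
float bf16_to_f32(uint16_t bf16) {
    uint32_t bits = static_cast<uint32_t>(bf16) << 16;
    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}

int main() {
    printf("%f\n", bf16_to_f32(0x3F80)); // 0x3F80 is bf16 for 1.0f -> prints 1.000000
    return 0;
}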

model.h
View File

@ -45,6 +45,7 @@ enum SDVersion {
VERSION_QWEN_IMAGE,
VERSION_FLUX2,
VERSION_Z_IMAGE,
VERSION_OVIS_IMAGE,
VERSION_COUNT,
};
@ -90,6 +91,7 @@ static inline bool sd_version_is_flux(SDVersion version) {
version == VERSION_FLUX_FILL ||
version == VERSION_FLUX_CONTROLS ||
version == VERSION_FLEX_2 ||
version == VERSION_OVIS_IMAGE ||
version == VERSION_CHROMA_RADIANCE) {
return true;
}
@ -168,7 +170,6 @@ struct TensorStorage {
std::string name;
ggml_type type = GGML_TYPE_F32;
ggml_type expected_type = GGML_TYPE_COUNT;
bool is_bf16 = false;
bool is_f8_e4m3 = false;
bool is_f8_e5m2 = false;
bool is_f64 = false;
@ -202,7 +203,7 @@ struct TensorStorage {
}
int64_t nbytes_to_read() const {
if (is_bf16 || is_f8_e4m3 || is_f8_e5m2) {
if (is_f8_e4m3 || is_f8_e5m2) {
return nbytes() / 2;
} else if (is_f64 || is_i64) {
return nbytes() * 2;
@ -250,9 +251,7 @@ struct TensorStorage {
std::string to_string() const {
std::stringstream ss;
const char* type_name = ggml_type_name(type);
if (is_bf16) {
type_name = "bf16";
} else if (is_f8_e4m3) {
if (is_f8_e4m3) {
type_name = "f8_e4m3";
} else if (is_f8_e5m2) {
type_name = "f8_e5m2";

pmid.hpp
View File

@ -548,7 +548,7 @@ public:
return gf;
}
void compute(const int n_threads,
bool compute(const int n_threads,
struct ggml_tensor* id_pixel_values,
struct ggml_tensor* prompt_embeds,
struct ggml_tensor* id_embeds,
@ -561,7 +561,7 @@ public:
};
// GGMLRunner::compute(get_graph, n_threads, updated_prompt_embeds);
GGMLRunner::compute(get_graph, n_threads, true, updated_prompt_embeds, output_ctx);
return GGMLRunner::compute(get_graph, n_threads, true, updated_prompt_embeds, output_ctx);
}
};

qwen_image.hpp
View File

@ -588,7 +588,7 @@ namespace Qwen {
return gf;
}
void compute(int n_threads,
bool compute(int n_threads,
struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
@ -603,7 +603,7 @@ namespace Qwen {
return build_graph(x, timesteps, context, ref_latents, increase_ref_index);
};
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
}
void test() {

rope.hpp
View File

@ -72,11 +72,13 @@ namespace Rope {
}
// Generate IDs for image patches and text
__STATIC_INLINE__ std::vector<std::vector<float>> gen_flux_txt_ids(int bs, int context_len, int axes_dim_num) {
__STATIC_INLINE__ std::vector<std::vector<float>> gen_flux_txt_ids(int bs, int context_len, int axes_dim_num, std::set<int> arange_dims) {
auto txt_ids = std::vector<std::vector<float>>(bs * context_len, std::vector<float>(axes_dim_num, 0.0f));
if (axes_dim_num == 4) {
for (int i = 0; i < bs * context_len; i++) {
txt_ids[i][3] = (i % context_len);
for (int dim = 0; dim < axes_dim_num; dim++) {
if (arange_dims.find(dim) != arange_dims.end()) {
for (int i = 0; i < bs * context_len; i++) {
txt_ids[i][dim] = (i % context_len);
}
}
}
return txt_ids;
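
gen_flux_txt_ids now takes the set of axes that get a per-token arange instead of hard-coding axis 3 of a 4-axis layout, so Flux-family variants can choose which text axes are aranged. A small standalone sketch with assumed parameters reproducing the old behaviour (arange on axis 3 only):

#include <cstdio>
#include <set>
#include <vector>

int main() {
    int bs = 1, context_len = 4, axes_dim_num = 4;
    std::set<int> arange_dims = {3}; // old hard-coded behaviour
    std::vector<std::vector<float>> txt_ids(bs * context_len,
                                            std::vector<float>(axes_dim_num, 0.0f));
    for (int dim = 0; dim < axes_dim_num; dim++) {
        if (arange_dims.count(dim)) {
            for (int i = 0; i < bs * context_len; i++) {
                txt_ids[i][dim] = (float)(i % context_len);
            }
        }
    }
    for (auto& row : txt_ids)
        printf("[%g %g %g %g]\n", row[0], row[1], row[2], row[3]); // [0 0 0 0] [0 0 0 1] [0 0 0 2] [0 0 0 3]
    return 0;
}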
@ -211,10 +213,11 @@ namespace Rope {
int bs,
int axes_dim_num,
int context_len,
std::set<int> txt_arange_dims,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index,
float ref_index_scale) {
auto txt_ids = gen_flux_txt_ids(bs, context_len, axes_dim_num);
auto txt_ids = gen_flux_txt_ids(bs, context_len, axes_dim_num, txt_arange_dims);
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num);
auto ids = concat_ids(txt_ids, img_ids, bs);
@ -231,6 +234,7 @@ namespace Rope {
int patch_size,
int bs,
int context_len,
std::set<int> txt_arange_dims,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index,
float ref_index_scale,
@ -242,6 +246,7 @@ namespace Rope {
bs,
static_cast<int>(axes_dim.size()),
context_len,
txt_arange_dims,
ref_latents,
increase_ref_index,
ref_index_scale);

stable-diffusion.cpp
View File

@ -46,6 +46,7 @@ const char* model_version_to_str[] = {
"Qwen Image",
"Flux.2",
"Z-Image",
"Ovis Image",
};
const char* sampling_methods_str[] = {
@ -307,13 +308,6 @@ public:
}
auto& tensor_storage_map = model_loader.get_tensor_storage_map();
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (contains(name, "llm") &&
ends_with(name, "weight") &&
(tensor_storage.type == GGML_TYPE_F32 || tensor_storage.type == GGML_TYPE_BF16)) {
tensor_storage.expected_type = GGML_TYPE_F16;
}
}
LOG_INFO("Version: %s ", model_version_to_str[version]);
ggml_type wtype = (int)sd_ctx_params->wtype < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT)
@ -431,6 +425,13 @@ public:
tensor_storage_map,
sd_ctx_params->chroma_use_t5_mask,
sd_ctx_params->chroma_t5_mask_pad);
} else if (version == VERSION_OVIS_IMAGE) {
cond_stage_model = std::make_shared<LLMEmbedder>(clip_backend,
offload_params_to_cpu,
tensor_storage_map,
version,
"",
false);
} else {
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend,
offload_params_to_cpu,
@ -507,18 +508,22 @@ public:
"model.diffusion_model",
version);
} else { // SD1.x SD2.x SDXL
std::map<std::string, std::string> embedding_map;
for (int i = 0; i < sd_ctx_params->embedding_count; i++) {
embedding_map.emplace(SAFE_STR(sd_ctx_params->embeddings[i].name), SAFE_STR(sd_ctx_params->embeddings[i].path));
}
if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
offload_params_to_cpu,
tensor_storage_map,
SAFE_STR(sd_ctx_params->embedding_dir),
embedding_map,
version,
PM_VERSION_2);
} else {
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
offload_params_to_cpu,
tensor_storage_map,
SAFE_STR(sd_ctx_params->embedding_dir),
embedding_map,
version);
}
diffusion_model = std::make_shared<UNetModel>(backend,
@ -697,6 +702,11 @@ public:
ignore_tensors.insert("first_stage_model.quant");
ignore_tensors.insert("text_encoders.llm.visual.");
}
if (version == VERSION_OVIS_IMAGE) {
ignore_tensors.insert("text_encoders.llm.vision_model.");
ignore_tensors.insert("text_encoders.llm.visual_tokenizer.");
ignore_tensors.insert("text_encoders.llm.vte.");
}
if (version == VERSION_SVD) {
ignore_tensors.insert("conditioner.embedders.3");
}
@ -707,7 +717,7 @@ public:
return false;
}
// LOG_DEBUG("model size = %.2fMB", total_size / 1024.0 / 1024.0);
LOG_DEBUG("finished loaded file");
{
size_t clip_params_mem_size = cond_stage_model->get_params_buffer_size();
@ -782,8 +792,59 @@ public:
ggml_backend_is_cpu(clip_backend) ? "RAM" : "VRAM");
}
if (sd_ctx_params->prediction != DEFAULT_PRED) {
switch (sd_ctx_params->prediction) {
// init denoiser
{
prediction_t pred_type = sd_ctx_params->prediction;
float flow_shift = sd_ctx_params->flow_shift;
if (pred_type == PREDICTION_COUNT) {
if (sd_version_is_sd2(version)) {
// check is_using_v_parameterization_for_sd2
if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
pred_type = V_PRED;
} else {
pred_type = EPS_PRED;
}
} else if (sd_version_is_sdxl(version)) {
if (tensor_storage_map.find("edm_vpred.sigma_max") != tensor_storage_map.end()) {
// CosXL models
// TODO: get sigma_min and sigma_max values from file
pred_type = EDM_V_PRED;
} else if (tensor_storage_map.find("v_pred") != tensor_storage_map.end()) {
pred_type = V_PRED;
} else {
pred_type = EPS_PRED;
}
} else if (sd_version_is_sd3(version) ||
sd_version_is_wan(version) ||
sd_version_is_qwen_image(version) ||
sd_version_is_z_image(version)) {
pred_type = FLOW_PRED;
if (flow_shift == INFINITY) {
if (sd_version_is_wan(version)) {
flow_shift = 5.f;
} else {
flow_shift = 3.f;
}
}
} else if (sd_version_is_flux(version)) {
pred_type = FLUX_FLOW_PRED;
if (flow_shift == INFINITY) {
flow_shift = 1.0f; // TODO: validate
for (const auto& [name, tensor_storage] : tensor_storage_map) {
if (starts_with(name, "model.diffusion_model.guidance_in.in_layer.weight")) {
flow_shift = 1.15f;
}
}
}
} else if (sd_version_is_flux2(version)) {
pred_type = FLUX2_FLOW_PRED;
} else {
pred_type = EPS_PRED;
}
}
switch (pred_type) {
case EPS_PRED:
LOG_INFO("running in eps-prediction mode");
break;
@ -795,22 +856,14 @@ public:
LOG_INFO("running in v-prediction EDM mode");
denoiser = std::make_shared<EDMVDenoiser>();
break;
case SD3_FLOW_PRED: {
case FLOW_PRED: {
LOG_INFO("running in FLOW mode");
float shift = sd_ctx_params->flow_shift;
if (shift == INFINITY) {
shift = 3.0;
}
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
denoiser = std::make_shared<DiscreteFlowDenoiser>(flow_shift);
break;
}
case FLUX_FLOW_PRED: {
LOG_INFO("running in Flux FLOW mode");
float shift = sd_ctx_params->flow_shift;
if (shift == INFINITY) {
shift = 3.0;
}
denoiser = std::make_shared<FluxFlowDenoiser>(shift);
denoiser = std::make_shared<FluxFlowDenoiser>(flow_shift);
break;
}
case FLUX2_FLOW_PRED: {
@ -819,93 +872,21 @@ public:
break;
}
default: {
LOG_ERROR("Unknown parametrization %i", sd_ctx_params->prediction);
LOG_ERROR("Unknown predition type %i", pred_type);
ggml_free(ctx);
return false;
}
}
} else {
if (sd_version_is_sd2(version)) {
// check is_using_v_parameterization_for_sd2
if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
is_using_v_parameterization = true;
}
} else if (sd_version_is_sdxl(version)) {
if (tensor_storage_map.find("edm_vpred.sigma_max") != tensor_storage_map.end()) {
// CosXL models
// TODO: get sigma_min and sigma_max values from file
is_using_edm_v_parameterization = true;
}
if (tensor_storage_map.find("v_pred") != tensor_storage_map.end()) {
is_using_v_parameterization = true;
}
} else if (version == VERSION_SVD) {
// TODO: V_PREDICTION_EDM
is_using_v_parameterization = true;
}
if (sd_version_is_sd3(version)) {
LOG_INFO("running in FLOW mode");
float shift = sd_ctx_params->flow_shift;
if (shift == INFINITY) {
shift = 3.0;
auto comp_vis_denoiser = std::dynamic_pointer_cast<CompVisDenoiser>(denoiser);
if (comp_vis_denoiser) {
for (int i = 0; i < TIMESTEPS; i++) {
comp_vis_denoiser->sigmas[i] = std::sqrt((1 - ((float*)alphas_cumprod_tensor->data)[i]) / ((float*)alphas_cumprod_tensor->data)[i]);
comp_vis_denoiser->log_sigmas[i] = std::log(comp_vis_denoiser->sigmas[i]);
}
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
} else if (sd_version_is_flux(version)) {
LOG_INFO("running in Flux FLOW mode");
float shift = sd_ctx_params->flow_shift;
if (shift == INFINITY) {
shift = 1.0f; // TODO: validate
for (const auto& [name, tensor_storage] : tensor_storage_map) {
if (starts_with(name, "model.diffusion_model.guidance_in.in_layer.weight")) {
shift = 1.15f;
}
}
}
denoiser = std::make_shared<FluxFlowDenoiser>(shift);
} else if (sd_version_is_flux2(version)) {
LOG_INFO("running in Flux2 FLOW mode");
denoiser = std::make_shared<Flux2FlowDenoiser>();
} else if (sd_version_is_wan(version)) {
LOG_INFO("running in FLOW mode");
float shift = sd_ctx_params->flow_shift;
if (shift == INFINITY) {
shift = 5.0;
}
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
} else if (sd_version_is_qwen_image(version)) {
LOG_INFO("running in FLOW mode");
float shift = sd_ctx_params->flow_shift;
if (shift == INFINITY) {
shift = 3.0;
}
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
} else if (sd_version_is_z_image(version)) {
LOG_INFO("running in FLOW mode");
float shift = sd_ctx_params->flow_shift;
if (shift == INFINITY) {
shift = 3.0f;
}
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
} else if (is_using_v_parameterization) {
LOG_INFO("running in v-prediction mode");
denoiser = std::make_shared<CompVisVDenoiser>();
} else if (is_using_edm_v_parameterization) {
LOG_INFO("running in v-prediction EDM mode");
denoiser = std::make_shared<EDMVDenoiser>();
} else {
LOG_INFO("running in eps-prediction mode");
}
}
auto comp_vis_denoiser = std::dynamic_pointer_cast<CompVisDenoiser>(denoiser);
if (comp_vis_denoiser) {
for (int i = 0; i < TIMESTEPS; i++) {
comp_vis_denoiser->sigmas[i] = std::sqrt((1 - ((float*)alphas_cumprod_tensor->data)[i]) / ((float*)alphas_cumprod_tensor->data)[i]);
comp_vis_denoiser->log_sigmas[i] = std::log(comp_vis_denoiser->sigmas[i]);
}
}
LOG_DEBUG("finished loaded file");
ggml_free(ctx);
use_tiny_autoencoder = use_tiny_autoencoder && !sd_ctx_params->tae_preview_only;
return true;
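
Once a denoiser is chosen, any CompVisDenoiser gets its sigma tables rebuilt from the checkpoint's alphas_cumprod; this is the standard DDPM-to-k-diffusion conversion

\sigma_t = \sqrt{\frac{1 - \bar{\alpha}_t}{\bar{\alpha}_t}}, \qquad t = 0, \dots, \text{TIMESTEPS} - 1,

with \log \sigma_t cached so sigmas can later be mapped back to timesteps by interpolation.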
@ -956,28 +937,17 @@ public:
float multiplier,
ggml_backend_t backend,
LoraModel::filter_t lora_tensor_filter = nullptr) {
std::string lora_name = lora_id;
std::string high_noise_tag = "|high_noise|";
bool is_high_noise = false;
if (starts_with(lora_name, high_noise_tag)) {
lora_name = lora_name.substr(high_noise_tag.size());
std::string lora_path = lora_id;
static std::string high_noise_tag = "|high_noise|";
bool is_high_noise = false;
if (starts_with(lora_path, high_noise_tag)) {
lora_path = lora_path.substr(high_noise_tag.size());
is_high_noise = true;
LOG_DEBUG("high noise lora: %s", lora_name.c_str());
LOG_DEBUG("high noise lora: %s", lora_path.c_str());
}
std::string st_file_path = path_join(lora_model_dir, lora_name + ".safetensors");
std::string ckpt_file_path = path_join(lora_model_dir, lora_name + ".ckpt");
std::string file_path;
if (file_exists(st_file_path)) {
file_path = st_file_path;
} else if (file_exists(ckpt_file_path)) {
file_path = ckpt_file_path;
} else {
LOG_WARN("can not find %s or %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str());
return nullptr;
}
auto lora = std::make_shared<LoraModel>(lora_id, backend, file_path, is_high_noise ? "model.high_noise_" : "", version);
auto lora = std::make_shared<LoraModel>(lora_id, backend, lora_path, is_high_noise ? "model.high_noise_" : "", version);
if (!lora->load_from_file(n_threads, lora_tensor_filter)) {
LOG_WARN("load lora tensors from %s failed", file_path.c_str());
LOG_WARN("load lora tensors from %s failed", lora_path.c_str());
return nullptr;
}
@ -998,6 +968,12 @@ public:
lora_state_diff[lora_name] -= curr_multiplier;
}
if (lora_state_diff.empty()) {
return;
}
LOG_INFO("apply lora immediately");
size_t rm = lora_state_diff.size() - lora_state.size();
if (rm != 0) {
LOG_INFO("attempting to apply %lu LoRAs (removing %lu applied LoRAs)", lora_state.size(), rm);
@ -1027,6 +1003,10 @@ public:
cond_stage_lora_models.clear();
diffusion_lora_models.clear();
first_stage_lora_models.clear();
if (lora_state.empty()) {
return;
}
LOG_INFO("apply lora at runtime");
if (cond_stage_model) {
std::vector<std::shared_ptr<LoraModel>> lora_models;
auto lora_state_diff = lora_state;
@ -1152,27 +1132,26 @@ public:
}
}
std::string apply_loras_from_prompt(const std::string& prompt) {
auto result_pair = extract_and_remove_lora(prompt);
std::unordered_map<std::string, float> lora_f2m = result_pair.first; // lora_name -> multiplier
for (auto& kv : lora_f2m) {
LOG_DEBUG("lora %s:%.2f", kv.first.c_str(), kv.second);
void apply_loras(const sd_lora_t* loras, uint32_t lora_count) {
std::unordered_map<std::string, float> lora_f2m;
for (int i = 0; i < lora_count; i++) {
std::string lora_id = SAFE_STR(loras[i].path);
if (loras[i].is_high_noise) {
lora_id = "|high_noise|" + lora_id;
}
lora_f2m[lora_id] = loras[i].multiplier;
LOG_DEBUG("lora %s:%.2f", lora_id.c_str(), loras[i].multiplier);
}
int64_t t0 = ggml_time_ms();
if (apply_lora_immediately) {
LOG_INFO("apply lora immediately");
apply_loras_immediately(lora_f2m);
} else {
LOG_INFO("apply at runtime");
apply_loras_at_runtime(lora_f2m);
}
int64_t t1 = ggml_time_ms();
if (!lora_f2m.empty()) {
LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
LOG_DEBUG("prompt after extract and remove lora: \"%s\"", result_pair.second.c_str());
}
return result_pair.second;
}
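
With this change LoRAs are no longer parsed out of the prompt or resolved against lora_model_dir; the caller passes explicit sd_lora_t entries (full path, multiplier, optional high-noise flag for Wan2.2-style dual-model setups), matching the new structs in stable-diffusion.h below. A hypothetical caller sketch; the paths are made up and sd_img_gen_params_init is assumed by analogy with sd_ctx_params_init:

sd_lora_t loras[2] = {
    {/*is_high_noise=*/false, /*multiplier=*/0.8f, "/models/loras/detail.safetensors"},
    {/*is_high_noise=*/true,  /*multiplier=*/1.0f, "/models/loras/wan22_high_noise.safetensors"},
};

sd_img_gen_params_t params;
sd_img_gen_params_init(&params); // assumed init helper
params.loras      = loras;
params.lora_count = 2;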
ggml_tensor* id_encoder(ggml_context* work_ctx,
@ -1337,10 +1316,17 @@ public:
uint32_t dim = latents->ne[ggml_n_dims(latents) - 1];
if (preview_mode == PREVIEW_PROJ) {
int64_t patch_sz = 1;
const float(*latent_rgb_proj)[channel] = nullptr;
float* latent_rgb_bias = nullptr;
if (dim == 48) {
if (dim == 128) {
if (sd_version_is_flux2(version)) {
latent_rgb_proj = flux2_latent_rgb_proj;
latent_rgb_bias = flux2_latent_rgb_bias;
patch_sz = 2;
}
} else if (dim == 48) {
if (sd_version_is_wan(version)) {
latent_rgb_proj = wan_22_latent_rgb_proj;
latent_rgb_bias = wan_22_latent_rgb_bias;
@ -1393,12 +1379,15 @@ public:
frames = latents->ne[2];
}
uint8_t* data = (uint8_t*)malloc(frames * width * height * channel * sizeof(uint8_t));
uint32_t img_width = width * patch_sz;
uint32_t img_height = height * patch_sz;
preview_latent_video(data, latents, latent_rgb_proj, latent_rgb_bias, width, height, frames, dim);
uint8_t* data = (uint8_t*)malloc(frames * img_width * img_height * channel * sizeof(uint8_t));
preview_latent_video(data, latents, latent_rgb_proj, latent_rgb_bias, patch_sz);
sd_image_t* images = (sd_image_t*)malloc(frames * sizeof(sd_image_t));
for (int i = 0; i < frames; i++) {
images[i] = {width, height, channel, data + i * width * height * channel};
images[i] = {img_width, img_height, channel, data + i * img_width * img_height * channel};
}
step_callback(step, frames, images, is_noisy, step_callback_data);
free(data);
@ -1683,8 +1672,11 @@ public:
std::vector<struct ggml_tensor*> controls;
if (control_hint != nullptr && control_net != nullptr) {
control_net->compute(n_threads, noised_input, control_hint, timesteps, cond.c_crossattn, cond.c_vector);
controls = control_net->controls;
if (control_net->compute(n_threads, noised_input, control_hint, timesteps, cond.c_crossattn, cond.c_vector)) {
controls = control_net->controls;
} else {
LOG_ERROR("controlnet compute failed");
}
// print_ggml_tensor(controls[12]);
// GGML_ASSERT(0);
}
@ -1716,9 +1708,12 @@ public:
bool skip_model = easycache_before_condition(active_condition, *active_output);
if (!skip_model) {
work_diffusion_model->compute(n_threads,
diffusion_params,
active_output);
if (!work_diffusion_model->compute(n_threads,
diffusion_params,
active_output)) {
LOG_ERROR("diffusion model compute failed");
return nullptr;
}
easycache_after_condition(active_condition, *active_output);
}
@ -1728,8 +1723,11 @@ public:
if (has_unconditioned) {
// uncond
if (!current_step_skipped && control_hint != nullptr && control_net != nullptr) {
control_net->compute(n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn, uncond.c_vector);
controls = control_net->controls;
if (control_net->compute(n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn, uncond.c_vector)) {
controls = control_net->controls;
} else {
LOG_ERROR("controlnet compute failed");
}
}
current_step_skipped = easycache_step_is_skipped();
diffusion_params.controls = controls;
@ -1738,9 +1736,12 @@ public:
diffusion_params.y = uncond.c_vector;
bool skip_uncond = easycache_before_condition(&uncond, out_uncond);
if (!skip_uncond) {
work_diffusion_model->compute(n_threads,
diffusion_params,
&out_uncond);
if (!work_diffusion_model->compute(n_threads,
diffusion_params,
&out_uncond)) {
LOG_ERROR("diffusion model compute failed");
return nullptr;
}
easycache_after_condition(&uncond, out_uncond);
}
negative_data = (float*)out_uncond->data;
@ -1753,9 +1754,12 @@ public:
diffusion_params.y = img_cond.c_vector;
bool skip_img_cond = easycache_before_condition(&img_cond, out_img_cond);
if (!skip_img_cond) {
work_diffusion_model->compute(n_threads,
diffusion_params,
&out_img_cond);
if (!work_diffusion_model->compute(n_threads,
diffusion_params,
&out_img_cond)) {
LOG_ERROR("diffusion model compute failed");
return nullptr;
}
easycache_after_condition(&img_cond, out_img_cond);
}
img_cond_data = (float*)out_img_cond->data;
@ -1772,9 +1776,12 @@ public:
diffusion_params.c_concat = cond.c_concat;
diffusion_params.y = cond.c_vector;
diffusion_params.skip_layers = skip_layers;
work_diffusion_model->compute(n_threads,
diffusion_params,
&out_skip);
if (!work_diffusion_model->compute(n_threads,
diffusion_params,
&out_skip)) {
LOG_ERROR("diffusion model compute failed");
return nullptr;
}
}
skip_layer_data = (float*)out_skip->data;
}
@ -1837,7 +1844,15 @@ public:
return denoised;
};
sample_k_diffusion(method, denoise, work_ctx, x, sigmas, sampler_rng, eta);
if (!sample_k_diffusion(method, denoise, work_ctx, x, sigmas, sampler_rng, eta)) {
LOG_ERROR("Diffusion model sampling failed");
if (control_net) {
control_net->free_control_ctx();
control_net->free_compute_buffer();
}
diffusion_model->free_compute_buffer();
return NULL;
}
if (easycache_enabled) {
size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
@ -1883,6 +1898,18 @@ public:
return vae_scale_factor;
}
int get_diffusion_model_down_factor() {
int down_factor = 8; // unet
if (sd_version_is_dit(version)) {
if (sd_version_is_wan(version)) {
down_factor = 2;
} else {
down_factor = 1;
}
}
return down_factor;
}
int get_latent_channel() {
int latent_channel = 4;
if (sd_version_is_dit(version)) {
@ -2392,7 +2419,6 @@ enum scheduler_t str_to_scheduler(const char* str) {
}
const char* prediction_to_str[] = {
"default",
"eps",
"v",
"edm_v",
@ -2478,7 +2504,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
sd_ctx_params->wtype = SD_TYPE_COUNT;
sd_ctx_params->rng_type = CUDA_RNG;
sd_ctx_params->sampler_rng_type = RNG_TYPE_COUNT;
sd_ctx_params->prediction = DEFAULT_PRED;
sd_ctx_params->prediction = PREDICTION_COUNT;
sd_ctx_params->lora_apply_mode = LORA_APPLY_AUTO;
sd_ctx_params->offload_params_to_cpu = false;
sd_ctx_params->keep_clip_on_cpu = false;
@ -2511,7 +2537,6 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
"taesd_path: %s\n"
"control_net_path: %s\n"
"lora_model_dir: %s\n"
"embedding_dir: %s\n"
"photo_maker_path: %s\n"
"tensor_type_rules: %s\n"
"vae_decode_only: %s\n"
@ -2542,7 +2567,6 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
SAFE_STR(sd_ctx_params->taesd_path),
SAFE_STR(sd_ctx_params->control_net_path),
SAFE_STR(sd_ctx_params->lora_model_dir),
SAFE_STR(sd_ctx_params->embedding_dir),
SAFE_STR(sd_ctx_params->photo_maker_path),
SAFE_STR(sd_ctx_params->tensor_type_rules),
BOOL_STR(sd_ctx_params->vae_decode_only),
@ -2793,8 +2817,6 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
int sample_steps = sigmas.size() - 1;
int64_t t0 = ggml_time_ms();
// Apply lora
prompt = sd_ctx->sd->apply_loras_from_prompt(prompt);
// Photo Maker
std::string prompt_text_only;
@ -3064,10 +3086,14 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
nullptr,
1.0f,
easycache_params);
// print_ggml_tensor(x_0);
int64_t sampling_end = ggml_time_ms();
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
final_latents.push_back(x_0);
int64_t sampling_end = ggml_time_ms();
if (x_0 != nullptr) {
// print_ggml_tensor(x_0);
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
final_latents.push_back(x_0);
} else {
LOG_ERROR("sampling for image %d/%d failed after %.2fs", b + 1, batch_count, (sampling_end - sampling_start) * 1.0f / 1000);
}
}
if (sd_ctx->sd->free_params_immediately) {
@ -3119,22 +3145,19 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params;
int width = sd_img_gen_params->width;
int height = sd_img_gen_params->height;
int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
if (sd_version_is_dit(sd_ctx->sd->version)) {
if (width % 16 || height % 16) {
LOG_ERROR("Image dimensions must be must be a multiple of 16 on each axis for %s models. (Got %dx%d)",
model_version_to_str[sd_ctx->sd->version],
width,
height);
return nullptr;
}
} else if (width % 64 || height % 64) {
LOG_ERROR("Image dimensions must be must be a multiple of 64 on each axis for %s models. (Got %dx%d)",
model_version_to_str[sd_ctx->sd->version],
width,
height);
return nullptr;
int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
int diffusion_model_down_factor = sd_ctx->sd->get_diffusion_model_down_factor();
int spatial_multiple = vae_scale_factor * diffusion_model_down_factor;
int width_offset = align_up_offset(width, spatial_multiple);
int height_offset = align_up_offset(height, spatial_multiple);
if (width_offset > 0 || height_offset > 0) {
width += width_offset;
height += height_offset;
LOG_WARN("align up %dx%d to %dx%d (multiple=%d)", sd_img_gen_params->width, sd_img_gen_params->height, width, height, spatial_multiple);
}
LOG_DEBUG("generate_image %dx%d", width, height);
if (sd_ctx == nullptr || sd_img_gen_params == nullptr) {
return nullptr;
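
Instead of rejecting mis-sized requests, generate_image (and generate_video below) now aligns dimensions up to vae_scale_factor * diffusion_model_down_factor and warns. align_up_offset itself is not part of the hunks shown here; a plausible implementation consistent with its call sites, offered as an assumption:

#include <cstdio>

// Assumed semantics: distance from value up to the next multiple of base.
int align_up_offset(int value, int base) {
    int rem = value % base;
    return rem == 0 ? 0 : base - rem;
}

int main() {
    int spatial_multiple = 16; // e.g. Wan: vae_scale_factor 8 * down_factor 2
    int width = 1000;
    printf("%d -> %d\n", width, width + align_up_offset(width, spatial_multiple)); // 1000 -> 1008
    return 0;
}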
@ -3162,6 +3185,9 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
size_t t0 = ggml_time_ms();
// Apply lora
sd_ctx->sd->apply_loras(sd_img_gen_params->loras, sd_img_gen_params->lora_count);
enum sample_method_t sample_method = sd_img_gen_params->sample_params.sample_method;
if (sample_method == SAMPLE_METHOD_COUNT) {
sample_method = sd_get_default_sample_method(sd_ctx);
@ -3405,9 +3431,19 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
int frames = sd_vid_gen_params->video_frames;
frames = (frames - 1) / 4 * 4 + 1;
int sample_steps = sd_vid_gen_params->sample_params.sample_steps;
LOG_INFO("generate_video %dx%dx%d", width, height, frames);
int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
int diffusion_model_down_factor = sd_ctx->sd->get_diffusion_model_down_factor();
int spatial_multiple = vae_scale_factor * diffusion_model_down_factor;
int width_offset = align_up_offset(width, spatial_multiple);
int height_offset = align_up_offset(height, spatial_multiple);
if (width_offset > 0 || height_offset > 0) {
width += width_offset;
height += height_offset;
LOG_WARN("align up %dx%d to %dx%d (multiple=%d)", sd_vid_gen_params->width, sd_vid_gen_params->height, width, height, spatial_multiple);
}
LOG_INFO("generate_video %dx%dx%d", width, height, frames);
enum sample_method_t sample_method = sd_vid_gen_params->sample_params.sample_method;
if (sample_method == SAMPLE_METHOD_COUNT) {
@ -3461,7 +3497,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
int64_t t0 = ggml_time_ms();
// Apply lora
prompt = sd_ctx->sd->apply_loras_from_prompt(prompt);
sd_ctx->sd->apply_loras(sd_vid_gen_params->loras, sd_vid_gen_params->lora_count);
ggml_tensor* init_latent = nullptr;
ggml_tensor* clip_vision_output = nullptr;

stable-diffusion.h
View File

@ -65,11 +65,10 @@ enum scheduler_t {
};
enum prediction_t {
DEFAULT_PRED,
EPS_PRED,
V_PRED,
EDM_V_PRED,
SD3_FLOW_PRED,
FLOW_PRED,
FLUX_FLOW_PRED,
FLUX2_FLOW_PRED,
PREDICTION_COUNT
@ -151,6 +150,11 @@ typedef struct {
float rel_size_y;
} sd_tiling_params_t;
typedef struct {
const char* name;
const char* path;
} sd_embedding_t;
typedef struct {
const char* model_path;
const char* clip_l_path;
@ -165,7 +169,8 @@ typedef struct {
const char* taesd_path;
const char* control_net_path;
const char* lora_model_dir;
const char* embedding_dir;
const sd_embedding_t* embeddings;
uint32_t embedding_count;
const char* photo_maker_path;
const char* tensor_type_rules;
bool vae_decode_only;
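
embedding_dir (a directory that was scanned by filename, which is also why util.cpp's get_full_path can be dropped below) is replaced by an explicit list of named embeddings. A hypothetical setup sketch; the name and path are made up, while sd_ctx_params_init is the real initializer shown earlier in this diff:

sd_embedding_t embeddings[1] = {
    {"bad_hands", "/models/embeddings/bad_hands.safetensors"},
};

sd_ctx_params_t ctx_params;
sd_ctx_params_init(&ctx_params);
ctx_params.embeddings      = embeddings;
ctx_params.embedding_count = 1;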
@ -237,6 +242,14 @@ typedef struct {
} sd_easycache_params_t;
typedef struct {
bool is_high_noise;
float multiplier;
const char* path;
} sd_lora_t;
typedef struct {
const sd_lora_t* loras;
uint32_t lora_count;
const char* prompt;
const char* negative_prompt;
int clip_skip;
@ -260,6 +273,8 @@ typedef struct {
} sd_img_gen_params_t;
typedef struct {
const sd_lora_t* loras;
uint32_t lora_count;
const char* prompt;
const char* negative_prompt;
int clip_skip;
@ -332,7 +347,8 @@ typedef struct upscaler_ctx_t upscaler_ctx_t;
SD_API upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path,
bool offload_params_to_cpu,
bool direct,
int n_threads);
int n_threads,
int tile_size);
SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx);
SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx,
@ -354,6 +370,9 @@ SD_API bool preprocess_canny(sd_image_t image,
float strong,
bool inverse);
SD_API const char* sd_commit(void);
SD_API const char* sd_version(void);
#ifdef __cplusplus
}
#endif

4
t5.hpp
View File

@ -820,7 +820,7 @@ struct T5Runner : public GGMLRunner {
return gf;
}
void compute(const int n_threads,
bool compute(const int n_threads,
struct ggml_tensor* input_ids,
struct ggml_tensor* attention_mask,
ggml_tensor** output,
@ -828,7 +828,7 @@ struct T5Runner : public GGMLRunner {
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(input_ids, attention_mask);
};
GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
}
static std::vector<int> _relative_position_bucket(const std::vector<int>& relative_position,

tae.hpp
View File

@ -247,7 +247,7 @@ struct TinyAutoEncoder : public GGMLRunner {
return gf;
}
void compute(const int n_threads,
bool compute(const int n_threads,
struct ggml_tensor* z,
bool decode_graph,
struct ggml_tensor** output,
@ -256,7 +256,7 @@ struct TinyAutoEncoder : public GGMLRunner {
return build_graph(z, decode_graph);
};
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
}
};

unet.hpp
View File

@ -645,7 +645,7 @@ struct UNetModelRunner : public GGMLRunner {
return gf;
}
void compute(int n_threads,
bool compute(int n_threads,
struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
@ -665,7 +665,7 @@ struct UNetModelRunner : public GGMLRunner {
return build_graph(x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength);
};
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
}
void test() {

upscaler.cpp
View File

@ -9,12 +9,15 @@ struct UpscalerGGML {
std::shared_ptr<ESRGAN> esrgan_upscaler;
std::string esrgan_path;
int n_threads;
bool direct = false;
bool direct = false;
int tile_size = 128;
UpscalerGGML(int n_threads,
bool direct = false)
bool direct = false,
int tile_size = 128)
: n_threads(n_threads),
direct(direct) {
direct(direct),
tile_size(tile_size) {
}
bool load_from_file(const std::string& esrgan_path,
@ -51,7 +54,7 @@ struct UpscalerGGML {
backend = ggml_backend_cpu_init();
}
LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type));
esrgan_upscaler = std::make_shared<ESRGAN>(backend, offload_params_to_cpu, model_loader.get_tensor_storage_map());
esrgan_upscaler = std::make_shared<ESRGAN>(backend, offload_params_to_cpu, tile_size, model_loader.get_tensor_storage_map());
if (direct) {
esrgan_upscaler->set_conv2d_direct_enabled(true);
}
@ -113,14 +116,15 @@ struct upscaler_ctx_t {
upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path_c_str,
bool offload_params_to_cpu,
bool direct,
int n_threads) {
int n_threads,
int tile_size) {
upscaler_ctx_t* upscaler_ctx = (upscaler_ctx_t*)malloc(sizeof(upscaler_ctx_t));
if (upscaler_ctx == nullptr) {
return nullptr;
}
std::string esrgan_path(esrgan_path_c_str);
upscaler_ctx->upscaler = new UpscalerGGML(n_threads, direct);
upscaler_ctx->upscaler = new UpscalerGGML(n_threads, direct, tile_size);
if (upscaler_ctx->upscaler == nullptr) {
return nullptr;
}
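
new_upscaler_ctx now threads a caller-chosen tile size through to the ESRGAN runner instead of the previous fixed 128. A hypothetical call with a made-up model path:

upscaler_ctx_t* up_ctx = new_upscaler_ctx("/models/RealESRGAN_x4plus.safetensors",
                                          /*offload_params_to_cpu=*/false,
                                          /*direct=*/false,
                                          /*n_threads=*/8,
                                          /*tile_size=*/256); // bigger tiles: fewer seams, more VRAM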

util.cpp
View File

@ -95,20 +95,6 @@ bool is_directory(const std::string& path) {
return (attributes != INVALID_FILE_ATTRIBUTES && (attributes & FILE_ATTRIBUTE_DIRECTORY));
}
std::string get_full_path(const std::string& dir, const std::string& filename) {
std::string full_path = dir + "\\" + filename;
WIN32_FIND_DATA find_file_data;
HANDLE hFind = FindFirstFile(full_path.c_str(), &find_file_data);
if (hFind != INVALID_HANDLE_VALUE) {
FindClose(hFind);
return full_path;
} else {
return "";
}
}
#else // Unix
#include <dirent.h>
#include <sys/stat.h>
@ -123,26 +109,6 @@ bool is_directory(const std::string& path) {
return (stat(path.c_str(), &buffer) == 0 && S_ISDIR(buffer.st_mode));
}
// TODO: add windows version
std::string get_full_path(const std::string& dir, const std::string& filename) {
DIR* dp = opendir(dir.c_str());
if (dp != nullptr) {
struct dirent* entry;
while ((entry = readdir(dp)) != nullptr) {
if (strcasecmp(entry->d_name, filename.c_str()) == 0) {
closedir(dp);
return dir + "/" + entry->d_name;
}
}
closedir(dp);
}
return "";
}
#endif
// get_num_physical_cores is copy from

1
util.h
View File

@ -22,7 +22,6 @@ int round_up_to(int value, int base);
bool file_exists(const std::string& filename);
bool is_directory(const std::string& path);
std::string get_full_path(const std::string& dir, const std::string& filename);
std::u32string utf8_to_utf32(const std::string& utf8_str);
std::string utf32_to_utf8(const std::u32string& utf32_str);

vae.hpp
View File

@ -617,7 +617,7 @@ public:
struct VAE : public GGMLRunner {
VAE(ggml_backend_t backend, bool offload_params_to_cpu)
: GGMLRunner(backend, offload_params_to_cpu) {}
virtual void compute(const int n_threads,
virtual bool compute(const int n_threads,
struct ggml_tensor* z,
bool decode_graph,
struct ggml_tensor** output,
@ -629,7 +629,7 @@ struct VAE : public GGMLRunner {
struct FakeVAE : public VAE {
FakeVAE(ggml_backend_t backend, bool offload_params_to_cpu)
: VAE(backend, offload_params_to_cpu) {}
void compute(const int n_threads,
bool compute(const int n_threads,
struct ggml_tensor* z,
bool decode_graph,
struct ggml_tensor** output,
@ -641,6 +641,7 @@ struct FakeVAE : public VAE {
float value = ggml_ext_tensor_get_f32(z, i0, i1, i2, i3);
ggml_ext_tensor_set_f32(*output, value, i0, i1, i2, i3);
});
return true;
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) override {}
@ -711,7 +712,7 @@ struct AutoEncoderKL : public VAE {
return gf;
}
void compute(const int n_threads,
bool compute(const int n_threads,
struct ggml_tensor* z,
bool decode_graph,
struct ggml_tensor** output,
@ -722,7 +723,7 @@ struct AutoEncoderKL : public VAE {
};
// ggml_set_f32(z, 0.5f);
// print_ggml_tensor(z);
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
}
void test() {

20
version.cpp Normal file
View File

@ -0,0 +1,20 @@
#include "stable-diffusion.h"
#ifndef SDCPP_BUILD_COMMIT
#define SDCPP_BUILD_COMMIT unknown
#endif
#ifndef SDCPP_BUILD_VERSION
#define SDCPP_BUILD_VERSION unknown
#endif
#define STRINGIZE2(x) #x
#define STRINGIZE(x) STRINGIZE2(x)
const char* sd_commit(void) {
return STRINGIZE(SDCPP_BUILD_COMMIT);
}
const char* sd_version(void) {
return STRINGIZE(SDCPP_BUILD_VERSION);
}
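
The two-level macro is required because # stringizes its argument before macro expansion: a single level would yield the literal token name rather than the value injected via COMPILE_DEFINITIONS in CMakeLists.txt. A standalone illustration:

#include <cstdio>

#define VALUE abc123
#define STR_DIRECT(x) #x
#define STRINGIZE2(x) #x
#define STRINGIZE(x) STRINGIZE2(x)

int main() {
    printf("%s\n", STR_DIRECT(VALUE)); // "VALUE"  - argument stringized before expansion
    printf("%s\n", STRINGIZE(VALUE));  // "abc123" - expanded first, then stringized
    return 0;
}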

15
wan.hpp
View File

@ -1175,7 +1175,7 @@ namespace WAN {
return gf;
}
void compute(const int n_threads,
bool compute(const int n_threads,
struct ggml_tensor* z,
bool decode_graph,
struct ggml_tensor** output,
@ -1184,7 +1184,7 @@ namespace WAN {
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(z, decode_graph);
};
GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
} else { // chunk 1 result is weird
ae.clear_cache();
int64_t t = z->ne[2];
@ -1193,11 +1193,11 @@ namespace WAN {
return build_graph_partial(z, decode_graph, i);
};
struct ggml_tensor* out = nullptr;
GGMLRunner::compute(get_graph, n_threads, true, &out, output_ctx);
bool res = GGMLRunner::compute(get_graph, n_threads, true, &out, output_ctx);
ae.clear_cache();
if (t == 1) {
*output = out;
return;
return res;
}
*output = ggml_new_tensor_4d(output_ctx, GGML_TYPE_F32, out->ne[0], out->ne[1], (t - 1) * 4 + 1, out->ne[3]);
@ -1221,11 +1221,12 @@ namespace WAN {
out = ggml_new_tensor_4d(output_ctx, GGML_TYPE_F32, out->ne[0], out->ne[1], 4, out->ne[3]);
for (i = 1; i < t; i++) {
GGMLRunner::compute(get_graph, n_threads, true, &out);
res = GGMLRunner::compute(get_graph, n_threads, true, &out) && res; // && (not ||): every chunk must run and succeed
ae.clear_cache();
copy_to_output();
}
free_cache_ctx_and_buffer();
return res;
}
}
@ -2194,7 +2195,7 @@ namespace WAN {
return gf;
}
void compute(int n_threads,
bool compute(int n_threads,
struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
@ -2209,7 +2210,7 @@ namespace WAN {
return build_graph(x, timesteps, context, clip_fea, c_concat, time_dim_concat, vace_context, vace_strength);
};
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
}
void test() {

zimage.hpp
View File

@ -574,7 +574,7 @@ namespace ZImage {
return gf;
}
void compute(int n_threads,
bool compute(int n_threads,
struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
@ -589,7 +589,7 @@ namespace ZImage {
return build_graph(x, timesteps, context, ref_latents, increase_ref_index);
};
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
}
void test() {