Mirror of https://github.com/leejet/stable-diffusion.cpp.git, synced 2025-12-12 21:38:58 +00:00

Compare commits: 4 commits, 5865b5e703 ... 985aedda32
| Author | SHA1 | Date |
|---|---|---|
| | 985aedda32 | |
| | 3f3610b5cd | |
| | 118683de8a | |
| | bcc9c0d0b3 | |
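The common thread in all four commits is an error-propagation refactor: `GGMLRunner::compute` and every runner built on it change from `void` to `bool`, so backend failures (parameter offload, buffer and graph allocation, kernel execution) are logged and returned to the caller instead of being ignored or ending in `abort()`. A minimal sketch of the resulting convention; the runner and tensor names here are illustrative, not from the diff:

```cpp
// Hedged sketch of the new contract: compute() reports failure via its
// return value, so callers can unwind instead of continuing with garbage.
ggml_tensor* output = nullptr;
if (!some_runner.compute(n_threads, input, &output, output_ctx)) {
    LOG_ERROR("compute failed");  // each layer logs and propagates
    return nullptr;
}
```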
clip.hpp (4 changed lines)

@@ -963,7 +963,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
         return gf;
     }
 
-    void compute(const int n_threads,
+    bool compute(const int n_threads,
                  struct ggml_tensor* input_ids,
                  int num_custom_embeddings,
                  void* custom_embeddings_data,
@@ -975,7 +975,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
         auto get_graph = [&]() -> struct ggml_cgraph* {
            return build_graph(input_ids, num_custom_embeddings, custom_embeddings_data, max_token_idx, return_pooled, clip_skip);
        };
-        GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
+        return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
    }
};
conditioner.hpp

@@ -703,7 +703,7 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
         return gf;
     }
 
-    void compute(const int n_threads,
+    bool compute(const int n_threads,
                  ggml_tensor* pixel_values,
                  bool return_pooled,
                  int clip_skip,
@@ -712,7 +712,7 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
         auto get_graph = [&]() -> struct ggml_cgraph* {
            return build_graph(pixel_values, return_pooled, clip_skip);
        };
-        GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
+        return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
    }
};
control.hpp (10 changed lines)

@@ -414,7 +414,7 @@ struct ControlNet : public GGMLRunner {
         return gf;
     }
 
-    void compute(int n_threads,
+    bool compute(int n_threads,
                 struct ggml_tensor* x,
                 struct ggml_tensor* hint,
                 struct ggml_tensor* timesteps,
@@ -430,8 +430,12 @@ struct ControlNet : public GGMLRunner {
            return build_graph(x, hint, timesteps, context, y);
        };
 
-        GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
-        guided_hint_cached = true;
+        bool res = GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
+        if (res) {
+            // cache guided_hint
+            guided_hint_cached = true;
+        }
+        return res;
    }
 
    bool load_from_file(const std::string& file_path, int n_threads) {
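Note that `guided_hint_cached` is now set only on success, so a failed ControlNet pass does not mark a half-computed hint as reusable; the next call rebuilds it. A hedged caller sketch, with argument names taken from the call sites later in this compare:

```cpp
// If compute() fails, guided_hint_cached stays false and a retry
// recomputes the guided hint instead of reusing a partial result.
if (!control_net->compute(n_threads, noised_input, control_hint,
                          timesteps, cond.c_crossattn, cond.c_vector)) {
    LOG_ERROR("controlnet compute failed");  // controls stay empty this step
}
```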
denoiser.hpp (47 changed lines)

@@ -666,7 +666,7 @@ struct Flux2FlowDenoiser : public FluxFlowDenoiser {
 typedef std::function<ggml_tensor*(ggml_tensor*, float, int)> denoise_cb_t;
 
 // k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t
-static void sample_k_diffusion(sample_method_t method,
+static bool sample_k_diffusion(sample_method_t method,
                                denoise_cb_t model,
                                ggml_context* work_ctx,
                                ggml_tensor* x,
@@ -685,6 +685,9 @@ static void sample_k_diffusion(sample_method_t method,
 
                // denoise
                ggml_tensor* denoised = model(x, sigma, i + 1);
+               if (denoised == nullptr) {
+                   return false;
+               }
 
                // d = (x - denoised) / sigma
                {
@@ -738,6 +741,9 @@ static void sample_k_diffusion(sample_method_t method,
 
                // denoise
                ggml_tensor* denoised = model(x, sigma, i + 1);
+               if (denoised == nullptr) {
+                   return false;
+               }
 
                // d = (x - denoised) / sigma
                {
@@ -769,6 +775,9 @@ static void sample_k_diffusion(sample_method_t method,
            for (int i = 0; i < steps; i++) {
                // denoise
                ggml_tensor* denoised = model(x, sigmas[i], -(i + 1));
+               if (denoised == nullptr) {
+                   return false;
+               }
 
                // d = (x - denoised) / sigma
                {
@@ -803,7 +812,10 @@ static void sample_k_diffusion(sample_method_t method,
                }
 
                ggml_tensor* denoised = model(x2, sigmas[i + 1], i + 1);
-               float* vec_denoised = (float*)denoised->data;
+               if (denoised == nullptr) {
+                   return false;
+               }
+               float* vec_denoised = (float*)denoised->data;
                for (int j = 0; j < ggml_nelements(x); j++) {
                    float d2 = (vec_x2[j] - vec_denoised[j]) / sigmas[i + 1];
                    vec_d[j] = (vec_d[j] + d2) / 2;
@@ -819,6 +831,9 @@ static void sample_k_diffusion(sample_method_t method,
            for (int i = 0; i < steps; i++) {
                // denoise
                ggml_tensor* denoised = model(x, sigmas[i], i + 1);
+               if (denoised == nullptr) {
+                   return false;
+               }
 
                // d = (x - denoised) / sigma
                {
@@ -855,7 +870,10 @@ static void sample_k_diffusion(sample_method_t method,
                }
 
                ggml_tensor* denoised = model(x2, sigma_mid, i + 1);
-               float* vec_denoised = (float*)denoised->data;
+               if (denoised == nullptr) {
+                   return false;
+               }
+               float* vec_denoised = (float*)denoised->data;
                for (int j = 0; j < ggml_nelements(x); j++) {
                    float d2 = (vec_x2[j] - vec_denoised[j]) / sigma_mid;
                    vec_x[j] = vec_x[j] + d2 * dt_2;
@@ -871,6 +889,9 @@ static void sample_k_diffusion(sample_method_t method,
            for (int i = 0; i < steps; i++) {
                // denoise
                ggml_tensor* denoised = model(x, sigmas[i], i + 1);
+               if (denoised == nullptr) {
+                   return false;
+               }
 
                // get_ancestral_step
                float sigma_up = std::min(sigmas[i + 1],
@@ -907,6 +928,9 @@ static void sample_k_diffusion(sample_method_t method,
                }
 
                ggml_tensor* denoised = model(x2, sigmas[i + 1], i + 1);
+               if (denoised == nullptr) {
+                   return false;
+               }
 
                // Second half-step
                for (int j = 0; j < ggml_nelements(x); j++) {
@@ -937,6 +961,9 @@ static void sample_k_diffusion(sample_method_t method,
            for (int i = 0; i < steps; i++) {
                // denoise
                ggml_tensor* denoised = model(x, sigmas[i], i + 1);
+               if (denoised == nullptr) {
+                   return false;
+               }
 
                float t = t_fn(sigmas[i]);
                float t_next = t_fn(sigmas[i + 1]);
@@ -976,6 +1003,9 @@ static void sample_k_diffusion(sample_method_t method,
            for (int i = 0; i < steps; i++) {
                // denoise
                ggml_tensor* denoised = model(x, sigmas[i], i + 1);
+               if (denoised == nullptr) {
+                   return false;
+               }
 
                float t = t_fn(sigmas[i]);
                float t_next = t_fn(sigmas[i + 1]);
@@ -1026,7 +1056,10 @@ static void sample_k_diffusion(sample_method_t method,
 
                // Denoising step
                ggml_tensor* denoised = model(x_cur, sigma, i + 1);
-               float* vec_denoised = (float*)denoised->data;
+               if (denoised == nullptr) {
+                   return false;
+               }
+               float* vec_denoised = (float*)denoised->data;
                // d_cur = (x_cur - denoised) / sigma
                struct ggml_tensor* d_cur = ggml_dup_tensor(work_ctx, x_cur);
                float* vec_d_cur = (float*)d_cur->data;
@@ -1169,6 +1202,9 @@ static void sample_k_diffusion(sample_method_t method,
 
                // denoise
                ggml_tensor* denoised = model(x, sigma, i + 1);
+               if (denoised == nullptr) {
+                   return false;
+               }
 
                // x = denoised
                {
@@ -1561,8 +1597,9 @@ static void sample_k_diffusion(sample_method_t method,
 
        default:
            LOG_ERROR("Attempting to sample with nonexisting sample method %i", method);
-           abort();
+           return false;
    }
+    return true;
 }
 
 #endif // __DENOISER_HPP__
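`sample_k_diffusion` now has a uniform contract with its callback: `denoise_cb_t` returns `nullptr` on failure, every sampler branch checks for it immediately after the model call, and the sampler itself returns `false` so the caller can free buffers. A hedged sketch of a conforming callback; the lambda body is illustrative, the real one lives in the sampling code shown further down:

```cpp
// Minimal conforming callback: translate the model's bool result into the
// nullptr convention that sample_k_diffusion() checks after every step.
denoise_cb_t denoise = [&](ggml_tensor* input, float sigma, int step) -> ggml_tensor* {
    ggml_tensor* out = nullptr;
    DiffusionParams diffusion_params;  // assumed to be filled in by the caller
    if (!diffusion_model->compute(n_threads, diffusion_params, &out)) {
        return nullptr;  // surfaces as sample_k_diffusion() == false
    }
    return out;
};
```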
diffusion_model.hpp

@@ -27,7 +27,7 @@ struct DiffusionParams {
 
 struct DiffusionModel {
    virtual std::string get_desc() = 0;
-    virtual void compute(int n_threads,
+    virtual bool compute(int n_threads,
                         DiffusionParams diffusion_params,
                         struct ggml_tensor** output = nullptr,
                         struct ggml_context* output_ctx = nullptr) = 0;
@@ -87,7 +87,7 @@ struct UNetModel : public DiffusionModel {
        unet.set_flash_attention_enabled(enabled);
    }
 
-    void compute(int n_threads,
+    bool compute(int n_threads,
                 DiffusionParams diffusion_params,
                 struct ggml_tensor** output = nullptr,
                 struct ggml_context* output_ctx = nullptr) override {
@@ -148,7 +148,7 @@ struct MMDiTModel : public DiffusionModel {
        mmdit.set_flash_attention_enabled(enabled);
    }
 
-    void compute(int n_threads,
+    bool compute(int n_threads,
                 DiffusionParams diffusion_params,
                 struct ggml_tensor** output = nullptr,
                 struct ggml_context* output_ctx = nullptr) override {
@@ -210,7 +210,7 @@ struct FluxModel : public DiffusionModel {
        flux.set_flash_attention_enabled(enabled);
    }
 
-    void compute(int n_threads,
+    bool compute(int n_threads,
                 DiffusionParams diffusion_params,
                 struct ggml_tensor** output = nullptr,
                 struct ggml_context* output_ctx = nullptr) override {
@@ -277,7 +277,7 @@ struct WanModel : public DiffusionModel {
        wan.set_flash_attention_enabled(enabled);
    }
 
-    void compute(int n_threads,
+    bool compute(int n_threads,
                 DiffusionParams diffusion_params,
                 struct ggml_tensor** output = nullptr,
                 struct ggml_context* output_ctx = nullptr) override {
@@ -343,7 +343,7 @@ struct QwenImageModel : public DiffusionModel {
        qwen_image.set_flash_attention_enabled(enabled);
    }
 
-    void compute(int n_threads,
+    bool compute(int n_threads,
                 DiffusionParams diffusion_params,
                 struct ggml_tensor** output = nullptr,
                 struct ggml_context* output_ctx = nullptr) override {
@@ -406,7 +406,7 @@ struct ZImageModel : public DiffusionModel {
        z_image.set_flash_attention_enabled(enabled);
    }
 
-    void compute(int n_threads,
+    bool compute(int n_threads,
                 DiffusionParams diffusion_params,
                 struct ggml_tensor** output = nullptr,
                 struct ggml_context* output_ctx = nullptr) override {
esrgan.hpp

@@ -353,14 +353,14 @@ struct ESRGAN : public GGMLRunner {
         return gf;
     }
 
-    void compute(const int n_threads,
+    bool compute(const int n_threads,
                 struct ggml_tensor* x,
                 ggml_tensor** output,
                 ggml_context* output_ctx = nullptr) {
        auto get_graph = [&]() -> struct ggml_cgraph* {
            return build_graph(x);
        };
-        GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
+        return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
    }
};
@@ -409,18 +409,18 @@ struct SDCliParams {
            return -1;
        }
        const char* preview = argv[index];
-        int preview_method = -1;
+        int preview_found = -1;
        for (int m = 0; m < PREVIEW_COUNT; m++) {
            if (!strcmp(preview, previews_str[m])) {
-                preview_method = m;
+                preview_found = m;
            }
        }
-        if (preview_method == -1) {
+        if (preview_found == -1) {
            fprintf(stderr, "error: preview method %s\n",
                    preview);
            return -1;
        }
-        preview_method = (preview_t)preview_method;
+        preview_method = (preview_t)preview_found;
        return 1;
    };
@@ -515,7 +515,7 @@ struct SDContextParams {
    bool chroma_use_t5_mask = false;
    int chroma_t5_mask_pad = 1;
 
-    prediction_t prediction = DEFAULT_PRED;
+    prediction_t prediction = PREDICTION_COUNT;
    lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO;
 
    sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
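The `SDCliParams` fix is a shadowing bug: the old local `int preview_method` shadowed the outer variable, so the final `preview_method = (preview_t)preview_method;` was effectively a self-assignment and the parsed value never escaped the parser. Renaming the local to `preview_found` makes the last assignment reach the outer variable. A tiny standalone reproduction (hypothetical values; `-Wshadow` flags this class of bug at compile time):

```cpp
#include <cstdio>

int main() {
    int preview_method = 0;      // outer value that should receive the result
    {
        int preview_method = -1; // -Wshadow: declaration shadows outer variable
        preview_method = 3;      // only updates the inner copy
    }
    std::printf("%d\n", preview_method);  // prints 0: the parsed value was lost
    return 0;
}
```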
flux.hpp (4 changed lines)

@@ -1413,7 +1413,7 @@ namespace Flux {
         return gf;
     }
 
-    void compute(int n_threads,
+    bool compute(int n_threads,
                 struct ggml_tensor* x,
                 struct ggml_tensor* timesteps,
                 struct ggml_tensor* context,
@@ -1434,7 +1434,7 @@ namespace Flux {
            return build_graph(x, timesteps, context, c_concat, y, guidance, ref_latents, increase_ref_index, skip_layers);
        };
 
-        GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
+        return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
    }
 
    void test() {
ggml_extend.hpp

@@ -1938,25 +1938,35 @@ public:
         return ggml_get_tensor(cache_ctx, name.c_str());
     }
 
-    void compute(get_graph_cb_t get_graph,
+    bool compute(get_graph_cb_t get_graph,
                 int n_threads,
                 bool free_compute_buffer_immediately = true,
                 struct ggml_tensor** output = nullptr,
                 struct ggml_context* output_ctx = nullptr) {
        if (!offload_params_to_runtime_backend()) {
            LOG_ERROR("%s offload params to runtime backend failed", get_desc().c_str());
-            return;
+            return false;
        }
-        alloc_compute_buffer(get_graph);
+        if (!alloc_compute_buffer(get_graph)) {
+            LOG_ERROR("%s alloc compute buffer failed", get_desc().c_str());
+            return false;
+        }
        reset_compute_ctx();
        struct ggml_cgraph* gf = get_compute_graph(get_graph);
-        GGML_ASSERT(ggml_gallocr_alloc_graph(compute_allocr, gf));
+        if (!ggml_gallocr_alloc_graph(compute_allocr, gf)) {
+            LOG_ERROR("%s alloc compute graph failed", get_desc().c_str());
+            return false;
+        }
        copy_data_to_backend_tensor();
        if (ggml_backend_is_cpu(runtime_backend)) {
            ggml_backend_cpu_set_n_threads(runtime_backend, n_threads);
        }
 
-        ggml_backend_graph_compute(runtime_backend, gf);
+        ggml_status status = ggml_backend_graph_compute(runtime_backend, gf);
+        if (status != GGML_STATUS_SUCCESS) {
+            LOG_ERROR("%s compute failed: %s", get_desc().c_str(), ggml_status_to_string(status));
+            return false;
+        }
 #ifdef GGML_PERF
        ggml_graph_print(gf);
 #endif
@@ -1974,6 +1984,7 @@ public:
        if (free_compute_buffer_immediately) {
            free_compute_buffer();
        }
+        return true;
    }
 
    void set_flash_attention_enabled(bool enabled) {
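With every failure path inside `GGMLRunner::compute` logging and returning `false`, the per-model wrappers in this compare reduce to one-line forwards. A hedged sketch of the pattern; the subclass, its `build_graph`, and the `get_desc` body are illustrative (the constructor shape mirrors the `VAE` wrapper later in this compare):

```cpp
// Illustrative subclass only: shows how the new bool result is forwarded.
struct ToyRunner : public GGMLRunner {
    ToyRunner(ggml_backend_t backend, bool offload_params_to_cpu)
        : GGMLRunner(backend, offload_params_to_cpu) {}
    std::string get_desc() { return "toy"; }  // assumed virtual on GGMLRunner
    struct ggml_cgraph* build_graph();        // hypothetical graph builder

    bool compute(int n_threads, ggml_tensor** output, ggml_context* output_ctx) {
        auto get_graph = [&]() -> struct ggml_cgraph* {
            return build_graph();
        };
        // true => free the compute buffer immediately after this run
        return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
    }
};
```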
llm.hpp (4 changed lines)

@@ -1191,7 +1191,7 @@ namespace LLM {
         return gf;
     }
 
-    void compute(const int n_threads,
+    bool compute(const int n_threads,
                 struct ggml_tensor* input_ids,
                 std::vector<std::pair<int, ggml_tensor*>> image_embeds,
                 std::set<int> out_layers,
@@ -1200,7 +1200,7 @@ namespace LLM {
        auto get_graph = [&]() -> struct ggml_cgraph* {
            return build_graph(input_ids, image_embeds, out_layers);
        };
-        GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
+        return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
    }
 
    int64_t get_num_image_tokens(int64_t t, int64_t h, int64_t w) {
mmdit.hpp

@@ -894,7 +894,7 @@ struct MMDiTRunner : public GGMLRunner {
         return gf;
     }
 
-    void compute(int n_threads,
+    bool compute(int n_threads,
                 struct ggml_tensor* x,
                 struct ggml_tensor* timesteps,
                 struct ggml_tensor* context,
@@ -910,7 +910,7 @@ struct MMDiTRunner : public GGMLRunner {
            return build_graph(x, timesteps, context, y, skip_layers);
        };
 
-        GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
+        return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
    }
 
    void test() {
model.cpp

@@ -104,8 +104,8 @@ const char* unused_tensors[] = {
    "embedding_manager",
    "denoiser.sigmas",
    "text_encoders.t5xxl.transformer.encoder.embed_tokens.weight", // only used during training
-    "ztsnr", // Found in some SDXL vpred models
-    "edm_vpred.sigma_min", // Found in CosXL
+    "ztsnr",               // Found in some SDXL vpred models
+    "edm_vpred.sigma_min", // Found in CosXL
    // TODO: find another way to avoid the "unknown tensor" for these two
    // "edm_vpred.sigma_max", // Used to detect CosXL
    // "v_pred", // Used to detect SDXL vpred models
pmid.hpp (4 changed lines)

@@ -548,7 +548,7 @@ public:
         return gf;
     }
 
-    void compute(const int n_threads,
+    bool compute(const int n_threads,
                 struct ggml_tensor* id_pixel_values,
                 struct ggml_tensor* prompt_embeds,
                 struct ggml_tensor* id_embeds,
@@ -561,7 +561,7 @@ public:
        };
 
        // GGMLRunner::compute(get_graph, n_threads, updated_prompt_embeds);
-        GGMLRunner::compute(get_graph, n_threads, true, updated_prompt_embeds, output_ctx);
+        return GGMLRunner::compute(get_graph, n_threads, true, updated_prompt_embeds, output_ctx);
    }
};
qwen_image.hpp

@@ -588,7 +588,7 @@ namespace Qwen {
         return gf;
     }
 
-    void compute(int n_threads,
+    bool compute(int n_threads,
                 struct ggml_tensor* x,
                 struct ggml_tensor* timesteps,
                 struct ggml_tensor* context,
@@ -603,7 +603,7 @@ namespace Qwen {
            return build_graph(x, timesteps, context, ref_latents, increase_ref_index);
        };
 
-        GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
+        return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
    }
 
    void test() {
stable-diffusion.cpp

@@ -707,7 +707,7 @@ public:
            return false;
        }
 
-        // LOG_DEBUG("model size = %.2fMB", total_size / 1024.0 / 1024.0);
+        LOG_DEBUG("finished loaded file");
 
        {
            size_t clip_params_mem_size = cond_stage_model->get_params_buffer_size();
@@ -782,8 +782,59 @@ public:
                     ggml_backend_is_cpu(clip_backend) ? "RAM" : "VRAM");
        }
 
-        if (sd_ctx_params->prediction != DEFAULT_PRED) {
-            switch (sd_ctx_params->prediction) {
+        // init denoiser
+        {
+            prediction_t pred_type = sd_ctx_params->prediction;
+            float flow_shift = sd_ctx_params->flow_shift;
+
+            if (pred_type == PREDICTION_COUNT) {
+                if (sd_version_is_sd2(version)) {
+                    // check is_using_v_parameterization_for_sd2
+                    if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
+                        pred_type = V_PRED;
+                    } else {
+                        pred_type = EPS_PRED;
+                    }
+                } else if (sd_version_is_sdxl(version)) {
+                    if (tensor_storage_map.find("edm_vpred.sigma_max") != tensor_storage_map.end()) {
+                        // CosXL models
+                        // TODO: get sigma_min and sigma_max values from file
+                        pred_type = EDM_V_PRED;
+                    } else if (tensor_storage_map.find("v_pred") != tensor_storage_map.end()) {
+                        pred_type = V_PRED;
+                    } else {
+                        pred_type = EPS_PRED;
+                    }
+                } else if (sd_version_is_sd3(version) ||
+                           sd_version_is_wan(version) ||
+                           sd_version_is_qwen_image(version) ||
+                           sd_version_is_z_image(version)) {
+                    pred_type = FLOW_PRED;
+                    if (flow_shift == INFINITY) {
+                        if (sd_version_is_wan(version)) {
+                            flow_shift = 5.f;
+                        } else {
+                            flow_shift = 3.f;
+                        }
+                    }
+                } else if (sd_version_is_flux(version)) {
+                    pred_type = FLUX_FLOW_PRED;
+                    if (flow_shift == INFINITY) {
+                        flow_shift = 1.0f; // TODO: validate
+                        for (const auto& [name, tensor_storage] : tensor_storage_map) {
+                            if (starts_with(name, "model.diffusion_model.guidance_in.in_layer.weight")) {
+                                flow_shift = 1.15f;
+                            }
+                        }
+                    }
+                } else if (sd_version_is_flux2(version)) {
+                    pred_type = FLUX2_FLOW_PRED;
+                } else {
+                    pred_type = EPS_PRED;
+                }
+            }
+
+            switch (pred_type) {
                case EPS_PRED:
                    LOG_INFO("running in eps-prediction mode");
                    break;
@@ -795,22 +846,14 @@ public:
                    LOG_INFO("running in v-prediction EDM mode");
                    denoiser = std::make_shared<EDMVDenoiser>();
                    break;
-                case SD3_FLOW_PRED: {
+                case FLOW_PRED: {
                    LOG_INFO("running in FLOW mode");
-                    float shift = sd_ctx_params->flow_shift;
-                    if (shift == INFINITY) {
-                        shift = 3.0;
-                    }
-                    denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
+                    denoiser = std::make_shared<DiscreteFlowDenoiser>(flow_shift);
                    break;
                }
                case FLUX_FLOW_PRED: {
                    LOG_INFO("running in Flux FLOW mode");
-                    float shift = sd_ctx_params->flow_shift;
-                    if (shift == INFINITY) {
-                        shift = 3.0;
-                    }
-                    denoiser = std::make_shared<FluxFlowDenoiser>(shift);
+                    denoiser = std::make_shared<FluxFlowDenoiser>(flow_shift);
                    break;
                }
                case FLUX2_FLOW_PRED: {
@@ -819,93 +862,21 @@ public:
                    break;
                }
                default: {
-                    LOG_ERROR("Unknown parametrization %i", sd_ctx_params->prediction);
+                    LOG_ERROR("Unknown predition type %i", pred_type);
                    ggml_free(ctx);
                    return false;
                }
            }
-        } else {
-            if (sd_version_is_sd2(version)) {
-                // check is_using_v_parameterization_for_sd2
-                if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
-                    is_using_v_parameterization = true;
-                }
-            } else if (sd_version_is_sdxl(version)) {
-                if (tensor_storage_map.find("edm_vpred.sigma_max") != tensor_storage_map.end()) {
-                    // CosXL models
-                    // TODO: get sigma_min and sigma_max values from file
-                    is_using_edm_v_parameterization = true;
-                }
-                if (tensor_storage_map.find("v_pred") != tensor_storage_map.end()) {
-                    is_using_v_parameterization = true;
-                }
-            } else if (version == VERSION_SVD) {
-                // TODO: V_PREDICTION_EDM
-                is_using_v_parameterization = true;
-            }
-
-            if (sd_version_is_sd3(version)) {
-                LOG_INFO("running in FLOW mode");
-                float shift = sd_ctx_params->flow_shift;
-                if (shift == INFINITY) {
-                    shift = 3.0;
-                }
-                denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
-            } else if (sd_version_is_flux(version)) {
-                LOG_INFO("running in Flux FLOW mode");
-                float shift = sd_ctx_params->flow_shift;
-                if (shift == INFINITY) {
-                    shift = 1.0f; // TODO: validate
-                    for (const auto& [name, tensor_storage] : tensor_storage_map) {
-                        if (starts_with(name, "model.diffusion_model.guidance_in.in_layer.weight")) {
-                            shift = 1.15f;
-                        }
-                    }
-                }
-                denoiser = std::make_shared<FluxFlowDenoiser>(shift);
-            } else if (sd_version_is_flux2(version)) {
-                LOG_INFO("running in Flux2 FLOW mode");
-                denoiser = std::make_shared<Flux2FlowDenoiser>();
-            } else if (sd_version_is_wan(version)) {
-                LOG_INFO("running in FLOW mode");
-                float shift = sd_ctx_params->flow_shift;
-                if (shift == INFINITY) {
-                    shift = 5.0;
-                }
-                denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
-            } else if (sd_version_is_qwen_image(version)) {
-                LOG_INFO("running in FLOW mode");
-                float shift = sd_ctx_params->flow_shift;
-                if (shift == INFINITY) {
-                    shift = 3.0;
-                }
-                denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
-            } else if (sd_version_is_z_image(version)) {
-                LOG_INFO("running in FLOW mode");
-                float shift = sd_ctx_params->flow_shift;
-                if (shift == INFINITY) {
-                    shift = 3.0f;
-                }
-                denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
-            } else if (is_using_v_parameterization) {
-                LOG_INFO("running in v-prediction mode");
-                denoiser = std::make_shared<CompVisVDenoiser>();
-            } else if (is_using_edm_v_parameterization) {
-                LOG_INFO("running in v-prediction EDM mode");
-                denoiser = std::make_shared<EDMVDenoiser>();
-            } else {
-                LOG_INFO("running in eps-prediction mode");
-            }
-        }
-
-        auto comp_vis_denoiser = std::dynamic_pointer_cast<CompVisDenoiser>(denoiser);
-        if (comp_vis_denoiser) {
-            for (int i = 0; i < TIMESTEPS; i++) {
-                comp_vis_denoiser->sigmas[i] = std::sqrt((1 - ((float*)alphas_cumprod_tensor->data)[i]) / ((float*)alphas_cumprod_tensor->data)[i]);
-                comp_vis_denoiser->log_sigmas[i] = std::log(comp_vis_denoiser->sigmas[i]);
-            }
-        }
+
+            auto comp_vis_denoiser = std::dynamic_pointer_cast<CompVisDenoiser>(denoiser);
+            if (comp_vis_denoiser) {
+                for (int i = 0; i < TIMESTEPS; i++) {
+                    comp_vis_denoiser->sigmas[i] = std::sqrt((1 - ((float*)alphas_cumprod_tensor->data)[i]) / ((float*)alphas_cumprod_tensor->data)[i]);
+                    comp_vis_denoiser->log_sigmas[i] = std::log(comp_vis_denoiser->sigmas[i]);
+                }
+            }
+        }
 
-        LOG_DEBUG("finished loaded file");
        ggml_free(ctx);
        use_tiny_autoencoder = use_tiny_autoencoder && !sd_ctx_params->tae_preview_only;
        return true;
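The load-time refactor collapses the old two-path logic (explicit prediction vs. a long auto-detection `else` branch) into a single pass: resolve `pred_type` first, treating `PREDICTION_COUNT` as "auto-detect", resolve the `flow_shift` default alongside it, then build the denoiser exactly once in one `switch`. A condensed restatement of the new flow, not a drop-in copy; the names are from the diff above:

```cpp
// Condensed sketch of the new init-denoiser flow.
prediction_t pred_type = sd_ctx_params->prediction;  // PREDICTION_COUNT == auto
float flow_shift       = sd_ctx_params->flow_shift;  // INFINITY == use default
if (pred_type == PREDICTION_COUNT) {
    // Detect from model version and tensor names, per the diff above:
    //   sd2 with v-parameterization    -> V_PRED
    //   sdxl with edm_vpred.* / v_pred -> EDM_V_PRED / V_PRED
    //   sd3, wan, qwen-image, z-image  -> FLOW_PRED (shift 5.0 for wan, else 3.0)
    //   flux / flux2                   -> FLUX_FLOW_PRED / FLUX2_FLOW_PRED
    //   anything else                  -> EPS_PRED
}
switch (pred_type) { /* construct the matching denoiser exactly once */ }
```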
@@ -998,6 +969,12 @@ public:
            lora_state_diff[lora_name] -= curr_multiplier;
        }
 
+        if (lora_state_diff.empty()) {
+            return;
+        }
+
+        LOG_INFO("apply lora immediately");
+
        size_t rm = lora_state_diff.size() - lora_state.size();
        if (rm != 0) {
            LOG_INFO("attempting to apply %lu LoRAs (removing %lu applied LoRAs)", lora_state.size(), rm);
@@ -1027,6 +1004,10 @@ public:
        cond_stage_lora_models.clear();
        diffusion_lora_models.clear();
        first_stage_lora_models.clear();
+        if (lora_state.empty()) {
+            return;
+        }
+        LOG_INFO("apply lora at runtime");
        if (cond_stage_model) {
            std::vector<std::shared_ptr<LoraModel>> lora_models;
            auto lora_state_diff = lora_state;
@@ -1161,10 +1142,8 @@ public:
        }
        int64_t t0 = ggml_time_ms();
        if (apply_lora_immediately) {
-            LOG_INFO("apply lora immediately");
            apply_loras_immediately(lora_f2m);
        } else {
-            LOG_INFO("apply at runtime");
            apply_loras_at_runtime(lora_f2m);
        }
        int64_t t1 = ggml_time_ms();
@@ -1683,8 +1662,11 @@ public:
        std::vector<struct ggml_tensor*> controls;
 
        if (control_hint != nullptr && control_net != nullptr) {
-            control_net->compute(n_threads, noised_input, control_hint, timesteps, cond.c_crossattn, cond.c_vector);
-            controls = control_net->controls;
+            if (control_net->compute(n_threads, noised_input, control_hint, timesteps, cond.c_crossattn, cond.c_vector)) {
+                controls = control_net->controls;
+            } else {
+                LOG_ERROR("controlnet compute failed");
+            }
            // print_ggml_tensor(controls[12]);
            // GGML_ASSERT(0);
        }
@@ -1716,9 +1698,12 @@ public:
 
        bool skip_model = easycache_before_condition(active_condition, *active_output);
        if (!skip_model) {
-            work_diffusion_model->compute(n_threads,
-                                          diffusion_params,
-                                          active_output);
+            if (!work_diffusion_model->compute(n_threads,
+                                               diffusion_params,
+                                               active_output)) {
+                LOG_ERROR("diffusion model compute failed");
+                return nullptr;
+            }
            easycache_after_condition(active_condition, *active_output);
        }
 
@@ -1728,8 +1713,11 @@ public:
        if (has_unconditioned) {
            // uncond
            if (!current_step_skipped && control_hint != nullptr && control_net != nullptr) {
-                control_net->compute(n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn, uncond.c_vector);
-                controls = control_net->controls;
+                if (control_net->compute(n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn, uncond.c_vector)) {
+                    controls = control_net->controls;
+                } else {
+                    LOG_ERROR("controlnet compute failed");
+                }
            }
            current_step_skipped = easycache_step_is_skipped();
            diffusion_params.controls = controls;
@@ -1738,9 +1726,12 @@ public:
            diffusion_params.y = uncond.c_vector;
            bool skip_uncond = easycache_before_condition(&uncond, out_uncond);
            if (!skip_uncond) {
-                work_diffusion_model->compute(n_threads,
-                                              diffusion_params,
-                                              &out_uncond);
+                if (!work_diffusion_model->compute(n_threads,
+                                                   diffusion_params,
+                                                   &out_uncond)) {
+                    LOG_ERROR("diffusion model compute failed");
+                    return nullptr;
+                }
                easycache_after_condition(&uncond, out_uncond);
            }
            negative_data = (float*)out_uncond->data;
@@ -1753,9 +1744,12 @@ public:
            diffusion_params.y = img_cond.c_vector;
            bool skip_img_cond = easycache_before_condition(&img_cond, out_img_cond);
            if (!skip_img_cond) {
-                work_diffusion_model->compute(n_threads,
-                                              diffusion_params,
-                                              &out_img_cond);
+                if (!work_diffusion_model->compute(n_threads,
+                                                   diffusion_params,
+                                                   &out_img_cond)) {
+                    LOG_ERROR("diffusion model compute failed");
+                    return nullptr;
+                }
                easycache_after_condition(&img_cond, out_img_cond);
            }
            img_cond_data = (float*)out_img_cond->data;
@@ -1772,9 +1766,12 @@ public:
            diffusion_params.c_concat = cond.c_concat;
            diffusion_params.y = cond.c_vector;
            diffusion_params.skip_layers = skip_layers;
-            work_diffusion_model->compute(n_threads,
-                                          diffusion_params,
-                                          &out_skip);
+            if (!work_diffusion_model->compute(n_threads,
+                                               diffusion_params,
+                                               &out_skip)) {
+                LOG_ERROR("diffusion model compute failed");
+                return nullptr;
+            }
        }
        skip_layer_data = (float*)out_skip->data;
    }
@@ -1837,7 +1834,15 @@ public:
        return denoised;
    };
 
-    sample_k_diffusion(method, denoise, work_ctx, x, sigmas, sampler_rng, eta);
+    if (!sample_k_diffusion(method, denoise, work_ctx, x, sigmas, sampler_rng, eta)) {
+        LOG_ERROR("Diffusion model sampling failed");
+        if (control_net) {
+            control_net->free_control_ctx();
+            control_net->free_compute_buffer();
+        }
+        diffusion_model->free_compute_buffer();
+        return NULL;
+    }
 
    if (easycache_enabled) {
        size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
@@ -2392,7 +2397,6 @@ enum scheduler_t str_to_scheduler(const char* str) {
 }
 
 const char* prediction_to_str[] = {
-    "default",
    "eps",
    "v",
    "edm_v",
@@ -2478,7 +2482,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
    sd_ctx_params->wtype = SD_TYPE_COUNT;
    sd_ctx_params->rng_type = CUDA_RNG;
    sd_ctx_params->sampler_rng_type = RNG_TYPE_COUNT;
-    sd_ctx_params->prediction = DEFAULT_PRED;
+    sd_ctx_params->prediction = PREDICTION_COUNT;
    sd_ctx_params->lora_apply_mode = LORA_APPLY_AUTO;
    sd_ctx_params->offload_params_to_cpu = false;
    sd_ctx_params->keep_clip_on_cpu = false;
@@ -3064,10 +3068,14 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
                                    nullptr,
                                    1.0f,
                                    easycache_params);
-        // print_ggml_tensor(x_0);
        int64_t sampling_end = ggml_time_ms();
-        LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
-        final_latents.push_back(x_0);
+        if (x_0 != nullptr) {
+            // print_ggml_tensor(x_0);
+            LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
+            final_latents.push_back(x_0);
+        } else {
+            LOG_ERROR("sampling for image %d/%d failed after %.2fs", b + 1, batch_count, (sampling_end - sampling_start) * 1.0f / 1000);
+        }
    }
 
    if (sd_ctx->sd->free_params_immediately) {
stable-diffusion.h

@@ -65,11 +65,10 @@ enum scheduler_t {
 };
 
 enum prediction_t {
-    DEFAULT_PRED,
    EPS_PRED,
    V_PRED,
    EDM_V_PRED,
-    SD3_FLOW_PRED,
+    FLOW_PRED,
    FLUX_FLOW_PRED,
    FLUX2_FLOW_PRED,
    PREDICTION_COUNT
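Dropping `DEFAULT_PRED` and reusing the trailing `PREDICTION_COUNT` sentinel as "unset" keeps the enum dense, which matters because `prediction_to_str[]` is a parallel array indexed in declaration order; removing `"default"` from that table (see the stable-diffusion.cpp hunk above) keeps string and enum in lockstep starting at `EPS_PRED`. A sketch of the invariant worth asserting; the entries past `"edm_v"` are assumed here, since only the first three appear in this compare:

```cpp
// Hypothetical guard for the parallel-array invariant.
const char* prediction_to_str_sketch[] = {
    "eps", "v", "edm_v",
    /* remaining flow entries, assumed */ "flow", "flux_flow", "flux2_flow",
};
static_assert(sizeof(prediction_to_str_sketch) / sizeof(const char*) == PREDICTION_COUNT,
              "prediction_to_str must stay in sync with prediction_t");
```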
t5.hpp (4 changed lines)

@@ -820,7 +820,7 @@ struct T5Runner : public GGMLRunner {
         return gf;
     }
 
-    void compute(const int n_threads,
+    bool compute(const int n_threads,
                 struct ggml_tensor* input_ids,
                 struct ggml_tensor* attention_mask,
                 ggml_tensor** output,
@@ -828,7 +828,7 @@ struct T5Runner : public GGMLRunner {
        auto get_graph = [&]() -> struct ggml_cgraph* {
            return build_graph(input_ids, attention_mask);
        };
-        GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
+        return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
    }
 
    static std::vector<int> _relative_position_bucket(const std::vector<int>& relative_position,
tae.hpp (4 changed lines)

@@ -247,7 +247,7 @@ struct TinyAutoEncoder : public GGMLRunner {
         return gf;
     }
 
-    void compute(const int n_threads,
+    bool compute(const int n_threads,
                 struct ggml_tensor* z,
                 bool decode_graph,
                 struct ggml_tensor** output,
@@ -256,7 +256,7 @@ struct TinyAutoEncoder : public GGMLRunner {
            return build_graph(z, decode_graph);
        };
 
-        GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
+        return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
    }
};
unet.hpp (4 changed lines)

@@ -645,7 +645,7 @@ struct UNetModelRunner : public GGMLRunner {
         return gf;
     }
 
-    void compute(int n_threads,
+    bool compute(int n_threads,
                 struct ggml_tensor* x,
                 struct ggml_tensor* timesteps,
                 struct ggml_tensor* context,
@@ -665,7 +665,7 @@ struct UNetModelRunner : public GGMLRunner {
            return build_graph(x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength);
        };
 
-        GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
+        return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
    }
 
    void test() {
vae.hpp (9 changed lines)

@@ -617,7 +617,7 @@ public:
 struct VAE : public GGMLRunner {
    VAE(ggml_backend_t backend, bool offload_params_to_cpu)
        : GGMLRunner(backend, offload_params_to_cpu) {}
-    virtual void compute(const int n_threads,
+    virtual bool compute(const int n_threads,
                         struct ggml_tensor* z,
                         bool decode_graph,
                         struct ggml_tensor** output,
@@ -629,7 +629,7 @@ struct VAE : public GGMLRunner {
 struct FakeVAE : public VAE {
    FakeVAE(ggml_backend_t backend, bool offload_params_to_cpu)
        : VAE(backend, offload_params_to_cpu) {}
-    void compute(const int n_threads,
+    bool compute(const int n_threads,
                 struct ggml_tensor* z,
                 bool decode_graph,
                 struct ggml_tensor** output,
@@ -641,6 +641,7 @@ struct FakeVAE : public VAE {
            float value = ggml_ext_tensor_get_f32(z, i0, i1, i2, i3);
            ggml_ext_tensor_set_f32(*output, value, i0, i1, i2, i3);
        });
+        return true;
    }
 
    void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) override {}
@@ -711,7 +712,7 @@ struct AutoEncoderKL : public VAE {
        return gf;
    }
 
-    void compute(const int n_threads,
+    bool compute(const int n_threads,
                 struct ggml_tensor* z,
                 bool decode_graph,
                 struct ggml_tensor** output,
@@ -722,7 +723,7 @@ struct AutoEncoderKL : public VAE {
        };
        // ggml_set_f32(z, 0.5f);
        // print_ggml_tensor(z);
-        GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
+        return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
    }
 
    void test() {
wan.hpp (15 changed lines)

@@ -1175,7 +1175,7 @@ namespace WAN {
         return gf;
     }
 
-    void compute(const int n_threads,
+    bool compute(const int n_threads,
                 struct ggml_tensor* z,
                 bool decode_graph,
                 struct ggml_tensor** output,
@@ -1184,7 +1184,7 @@ namespace WAN {
            auto get_graph = [&]() -> struct ggml_cgraph* {
                return build_graph(z, decode_graph);
            };
-            GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
+            return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
        } else {  // chunk 1 result is weird
            ae.clear_cache();
            int64_t t = z->ne[2];
@@ -1193,11 +1193,11 @@ namespace WAN {
                return build_graph_partial(z, decode_graph, i);
            };
            struct ggml_tensor* out = nullptr;
-            GGMLRunner::compute(get_graph, n_threads, true, &out, output_ctx);
+            bool res = GGMLRunner::compute(get_graph, n_threads, true, &out, output_ctx);
            ae.clear_cache();
            if (t == 1) {
                *output = out;
-                return;
+                return res;
            }
 
            *output = ggml_new_tensor_4d(output_ctx, GGML_TYPE_F32, out->ne[0], out->ne[1], (t - 1) * 4 + 1, out->ne[3]);
@@ -1221,11 +1221,12 @@ namespace WAN {
            out = ggml_new_tensor_4d(output_ctx, GGML_TYPE_F32, out->ne[0], out->ne[1], 4, out->ne[3]);
 
            for (i = 1; i < t; i++) {
-                GGMLRunner::compute(get_graph, n_threads, true, &out);
+                res = res || GGMLRunner::compute(get_graph, n_threads, true, &out);
                ae.clear_cache();
                copy_to_output();
            }
            free_cache_ctx_and_buffer();
+            return res;
        }
    }
 
@@ -2194,7 +2195,7 @@ namespace WAN {
        return gf;
    }
 
-    void compute(int n_threads,
+    bool compute(int n_threads,
                 struct ggml_tensor* x,
                 struct ggml_tensor* timesteps,
                 struct ggml_tensor* context,
@@ -2209,7 +2210,7 @@ namespace WAN {
            return build_graph(x, timesteps, context, clip_fea, c_concat, time_dim_concat, vace_context, vace_strength);
        };
 
-        GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
+        return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
    }
 
    void test() {
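One caveat in the WAN chunked-decode path: `res = res || GGMLRunner::compute(...)` short-circuits, so once `res` is `true` (the usual case after a successful first chunk) the remaining chunk computes are skipped entirely, and it also inverts the all-chunks-must-succeed aggregation. A hedged alternative that always runs each chunk and folds failures:

```cpp
// Sketch only: evaluate compute() unconditionally, then fold the result,
// so a true `res` cannot short-circuit away the remaining chunks.
for (i = 1; i < t; i++) {
    bool ok = GGMLRunner::compute(get_graph, n_threads, true, &out);
    res = res && ok;  // overall success only if every chunk succeeded
    ae.clear_cache();
    copy_to_output();
}
```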
z_image.hpp

@@ -30,7 +30,7 @@ namespace ZImage {
    JointAttention(int64_t hidden_size, int64_t head_dim, int64_t num_heads, int64_t num_kv_heads, bool qk_norm)
        : head_dim(head_dim), num_heads(num_heads), num_kv_heads(num_kv_heads), qk_norm(qk_norm) {
        blocks["qkv"] = std::make_shared<Linear>(hidden_size, (num_heads + num_kv_heads * 2) * head_dim, false);
-        float scale = 1.f;
+        float scale = 1.f;
 #if GGML_USE_HIP
        // Prevent NaN issues with certain ROCm setups
        scale = 1.f / 16.f;
@@ -574,7 +574,7 @@ namespace ZImage {
        return gf;
    }
 
-    void compute(int n_threads,
+    bool compute(int n_threads,
                 struct ggml_tensor* x,
                 struct ggml_tensor* timesteps,
                 struct ggml_tensor* context,
@@ -589,7 +589,7 @@ namespace ZImage {
            return build_graph(x, timesteps, context, ref_latents, increase_ref_index);
        };
 
-        GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
+        return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
    }
 
    void test() {