add wan2.1 t2i support

This commit is contained in:
leejet 2025-08-10 17:07:17 +08:00
parent bace0a08c4
commit 1d9ccea41a
11 changed files with 503 additions and 331 deletions

View File

@ -733,7 +733,7 @@ public:
if (text_projection != NULL) { if (text_projection != NULL) {
pooled = ggml_nn_linear(ctx, pooled, text_projection, NULL); pooled = ggml_nn_linear(ctx, pooled, text_projection, NULL);
} else { } else {
LOG_DEBUG("Missing text_projection matrix, assuming identity..."); LOG_DEBUG("identity projection");
} }
return pooled; // [hidden_size, 1, 1] return pooled; // [hidden_size, 1, 1]
} }

View File

@ -22,7 +22,7 @@ struct Conditioner {
int width, int width,
int height, int height,
int adm_in_channels = -1, int adm_in_channels = -1,
bool force_zero_embeddings = false) = 0; bool zero_out_masked = false) = 0;
virtual void alloc_params_buffer() = 0; virtual void alloc_params_buffer() = 0;
virtual void free_params_buffer() = 0; virtual void free_params_buffer() = 0;
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0; virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
@ -35,7 +35,7 @@ struct Conditioner {
int height, int height,
int num_input_imgs, int num_input_imgs,
int adm_in_channels = -1, int adm_in_channels = -1,
bool force_zero_embeddings = false) = 0; bool zero_out_masked = false) = 0;
virtual std::string remove_trigger_from_prompt(ggml_context* work_ctx, virtual std::string remove_trigger_from_prompt(ggml_context* work_ctx,
const std::string& prompt) = 0; const std::string& prompt) = 0;
}; };
@ -410,7 +410,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
int width, int width,
int height, int height,
int adm_in_channels = -1, int adm_in_channels = -1,
bool force_zero_embeddings = false) { bool zero_out_masked = false) {
set_clip_skip(clip_skip); set_clip_skip(clip_skip);
int64_t t0 = ggml_time_ms(); int64_t t0 = ggml_time_ms();
struct ggml_tensor* hidden_states = NULL; // [N, n_token, hidden_size] struct ggml_tensor* hidden_states = NULL; // [N, n_token, hidden_size]
@ -499,7 +499,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
float new_mean = ggml_tensor_mean(result); float new_mean = ggml_tensor_mean(result);
ggml_tensor_scale(result, (original_mean / new_mean)); ggml_tensor_scale(result, (original_mean / new_mean));
} }
if (force_zero_embeddings) { if (zero_out_masked) {
float* vec = (float*)result->data; float* vec = (float*)result->data;
for (int i = 0; i < ggml_nelements(result); i++) { for (int i = 0; i < ggml_nelements(result); i++) {
vec[i] = 0; vec[i] = 0;
@ -563,7 +563,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
int height, int height,
int num_input_imgs, int num_input_imgs,
int adm_in_channels = -1, int adm_in_channels = -1,
bool force_zero_embeddings = false) { bool zero_out_masked = false) {
auto image_tokens = convert_token_to_id(trigger_word); auto image_tokens = convert_token_to_id(trigger_word);
// if(image_tokens.size() == 1){ // if(image_tokens.size() == 1){
// printf(" image token id is: %d \n", image_tokens[0]); // printf(" image token id is: %d \n", image_tokens[0]);
@ -584,7 +584,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
// for(int i = 0; i < clsm.size(); ++i) // for(int i = 0; i < clsm.size(); ++i)
// printf("%d ", clsm[i]?1:0); // printf("%d ", clsm[i]?1:0);
// printf("\n"); // printf("\n");
auto cond = get_learned_condition_common(work_ctx, n_threads, tokens, weights, clip_skip, width, height, adm_in_channels, force_zero_embeddings); auto cond = get_learned_condition_common(work_ctx, n_threads, tokens, weights, clip_skip, width, height, adm_in_channels, zero_out_masked);
return std::make_tuple(cond, clsm); return std::make_tuple(cond, clsm);
} }
@ -607,11 +607,11 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
int width, int width,
int height, int height,
int adm_in_channels = -1, int adm_in_channels = -1,
bool force_zero_embeddings = false) { bool zero_out_masked = false) {
auto tokens_and_weights = tokenize(text, true); auto tokens_and_weights = tokenize(text, true);
std::vector<int>& tokens = tokens_and_weights.first; std::vector<int>& tokens = tokens_and_weights.first;
std::vector<float>& weights = tokens_and_weights.second; std::vector<float>& weights = tokens_and_weights.second;
return get_learned_condition_common(work_ctx, n_threads, tokens, weights, clip_skip, width, height, adm_in_channels, force_zero_embeddings); return get_learned_condition_common(work_ctx, n_threads, tokens, weights, clip_skip, width, height, adm_in_channels, zero_out_masked);
} }
}; };
@ -773,7 +773,7 @@ struct SD3CLIPEmbedder : public Conditioner {
int n_threads, int n_threads,
std::vector<std::pair<std::vector<int>, std::vector<float>>> token_and_weights, std::vector<std::pair<std::vector<int>, std::vector<float>>> token_and_weights,
int clip_skip, int clip_skip,
bool force_zero_embeddings = false) { bool zero_out_masked = false) {
set_clip_skip(clip_skip); set_clip_skip(clip_skip);
auto& clip_l_tokens = token_and_weights[0].first; auto& clip_l_tokens = token_and_weights[0].first;
auto& clip_l_weights = token_and_weights[0].second; auto& clip_l_weights = token_and_weights[0].second;
@ -952,7 +952,7 @@ struct SD3CLIPEmbedder : public Conditioner {
int64_t t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0); LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
if (force_zero_embeddings) { if (zero_out_masked) {
float* vec = (float*)chunk_hidden_states->data; float* vec = (float*)chunk_hidden_states->data;
for (int i = 0; i < ggml_nelements(chunk_hidden_states); i++) { for (int i = 0; i < ggml_nelements(chunk_hidden_states); i++) {
vec[i] = 0; vec[i] = 0;
@ -979,9 +979,9 @@ struct SD3CLIPEmbedder : public Conditioner {
int width, int width,
int height, int height,
int adm_in_channels = -1, int adm_in_channels = -1,
bool force_zero_embeddings = false) { bool zero_out_masked = false) {
auto tokens_and_weights = tokenize(text, 77, true); auto tokens_and_weights = tokenize(text, 77, true);
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings); return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
} }
std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx, std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
@ -992,7 +992,7 @@ struct SD3CLIPEmbedder : public Conditioner {
int height, int height,
int num_input_imgs, int num_input_imgs,
int adm_in_channels = -1, int adm_in_channels = -1,
bool force_zero_embeddings = false) { bool zero_out_masked = false) {
GGML_ASSERT(0 && "Not implemented yet!"); GGML_ASSERT(0 && "Not implemented yet!");
} }
@ -1101,7 +1101,7 @@ struct FluxCLIPEmbedder : public Conditioner {
int n_threads, int n_threads,
std::vector<std::pair<std::vector<int>, std::vector<float>>> token_and_weights, std::vector<std::pair<std::vector<int>, std::vector<float>>> token_and_weights,
int clip_skip, int clip_skip,
bool force_zero_embeddings = false) { bool zero_out_masked = false) {
set_clip_skip(clip_skip); set_clip_skip(clip_skip);
auto& clip_l_tokens = token_and_weights[0].first; auto& clip_l_tokens = token_and_weights[0].first;
auto& clip_l_weights = token_and_weights[0].second; auto& clip_l_weights = token_and_weights[0].second;
@ -1173,7 +1173,7 @@ struct FluxCLIPEmbedder : public Conditioner {
int64_t t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0); LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
if (force_zero_embeddings) { if (zero_out_masked) {
float* vec = (float*)chunk_hidden_states->data; float* vec = (float*)chunk_hidden_states->data;
for (int i = 0; i < ggml_nelements(chunk_hidden_states); i++) { for (int i = 0; i < ggml_nelements(chunk_hidden_states); i++) {
vec[i] = 0; vec[i] = 0;
@ -1200,9 +1200,9 @@ struct FluxCLIPEmbedder : public Conditioner {
int width, int width,
int height, int height,
int adm_in_channels = -1, int adm_in_channels = -1,
bool force_zero_embeddings = false) { bool zero_out_masked = false) {
auto tokens_and_weights = tokenize(text, chunk_len, true); auto tokens_and_weights = tokenize(text, chunk_len, true);
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings); return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
} }
std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx, std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
@ -1213,7 +1213,7 @@ struct FluxCLIPEmbedder : public Conditioner {
int height, int height,
int num_input_imgs, int num_input_imgs,
int adm_in_channels = -1, int adm_in_channels = -1,
bool force_zero_embeddings = false) { bool zero_out_masked = false) {
GGML_ASSERT(0 && "Not implemented yet!"); GGML_ASSERT(0 && "Not implemented yet!");
} }
@ -1229,6 +1229,7 @@ struct T5CLIPEmbedder : public Conditioner {
size_t chunk_len = 512; size_t chunk_len = 512;
bool use_mask = false; bool use_mask = false;
int mask_pad = 1; int mask_pad = 1;
bool is_umt5 = false;
T5CLIPEmbedder(ggml_backend_t backend, T5CLIPEmbedder(ggml_backend_t backend,
const String2GGMLType& tensor_types = {}, const String2GGMLType& tensor_types = {},
@ -1318,7 +1319,7 @@ struct T5CLIPEmbedder : public Conditioner {
int n_threads, int n_threads,
std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> token_and_weights, std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> token_and_weights,
int clip_skip, int clip_skip,
bool force_zero_embeddings = false) { bool zero_out_masked = false) {
auto& t5_tokens = std::get<0>(token_and_weights); auto& t5_tokens = std::get<0>(token_and_weights);
auto& t5_weights = std::get<1>(token_and_weights); auto& t5_weights = std::get<1>(token_and_weights);
auto& t5_attn_mask_vec = std::get<2>(token_and_weights); auto& t5_attn_mask_vec = std::get<2>(token_and_weights);
@ -1326,8 +1327,8 @@ struct T5CLIPEmbedder : public Conditioner {
int64_t t0 = ggml_time_ms(); int64_t t0 = ggml_time_ms();
struct ggml_tensor* hidden_states = NULL; // [N, n_token, 4096] struct ggml_tensor* hidden_states = NULL; // [N, n_token, 4096]
struct ggml_tensor* chunk_hidden_states = NULL; // [n_token, 4096] struct ggml_tensor* chunk_hidden_states = NULL; // [n_token, 4096]
struct ggml_tensor* pooled = NULL; // [768,] struct ggml_tensor* pooled = NULL;
struct ggml_tensor* t5_attn_mask = vector_to_ggml_tensor(work_ctx, t5_attn_mask_vec); // [768,] struct ggml_tensor* t5_attn_mask = vector_to_ggml_tensor(work_ctx, t5_attn_mask_vec); // [n_token]
std::vector<float> hidden_states_vec; std::vector<float> hidden_states_vec;
@ -1368,10 +1369,16 @@ struct T5CLIPEmbedder : public Conditioner {
int64_t t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0); LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
if (force_zero_embeddings) { if (zero_out_masked) {
float* vec = (float*)chunk_hidden_states->data; auto tensor = chunk_hidden_states;
for (int i = 0; i < ggml_nelements(chunk_hidden_states); i++) { for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
vec[i] = 0; for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
if (chunk_mask[i1] < 0.f) {
ggml_tensor_set_f32(tensor, 0.f, i0, i1, i2);
}
}
}
} }
} }
@ -1380,16 +1387,12 @@ struct T5CLIPEmbedder : public Conditioner {
((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states)); ((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
} }
if (hidden_states_vec.size() > 0) { GGML_ASSERT(hidden_states_vec.size() > 0);
hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec); hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
hidden_states = ggml_reshape_2d(work_ctx, hidden_states = ggml_reshape_2d(work_ctx,
hidden_states, hidden_states,
chunk_hidden_states->ne[0], chunk_hidden_states->ne[0],
ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]); ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
} else {
hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 256);
ggml_set_f32(hidden_states, 0.f);
}
modify_mask_to_attend_padding(t5_attn_mask, ggml_nelements(t5_attn_mask), mask_pad); modify_mask_to_attend_padding(t5_attn_mask, ggml_nelements(t5_attn_mask), mask_pad);
@ -1403,9 +1406,9 @@ struct T5CLIPEmbedder : public Conditioner {
int width, int width,
int height, int height,
int adm_in_channels = -1, int adm_in_channels = -1,
bool force_zero_embeddings = false) { bool zero_out_masked = false) {
auto tokens_and_weights = tokenize(text, chunk_len, true); auto tokens_and_weights = tokenize(text, chunk_len, true);
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings); return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
} }
std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx, std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
@ -1416,7 +1419,7 @@ struct T5CLIPEmbedder : public Conditioner {
int height, int height,
int num_input_imgs, int num_input_imgs,
int adm_in_channels = -1, int adm_in_channels = -1,
bool force_zero_embeddings = false) { bool zero_out_masked = false) {
GGML_ASSERT(0 && "Not implemented yet!"); GGML_ASSERT(0 && "Not implemented yet!");
} }

View File

@ -4,6 +4,7 @@
#include "flux.hpp" #include "flux.hpp"
#include "mmdit.hpp" #include "mmdit.hpp"
#include "unet.hpp" #include "unet.hpp"
#include "wan.hpp"
struct DiffusionModel { struct DiffusionModel {
virtual void compute(int n_threads, virtual void compute(int n_threads,
@ -184,4 +185,56 @@ struct FluxModel : public DiffusionModel {
} }
}; };
struct WanModel : public DiffusionModel {
WAN::WanRunner wan;
WanModel(ggml_backend_t backend,
const String2GGMLType& tensor_types = {},
SDVersion version = VERSION_FLUX,
bool flash_attn = false)
: wan(backend, tensor_types, "model.diffusion_model", version, flash_attn) {
}
void alloc_params_buffer() {
wan.alloc_params_buffer();
}
void free_params_buffer() {
wan.free_params_buffer();
}
void free_compute_buffer() {
wan.free_compute_buffer();
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
wan.get_param_tensors(tensors, "model.diffusion_model");
}
size_t get_params_buffer_size() {
return wan.get_params_buffer_size();
}
int64_t get_adm_in_channels() {
return 768;
}
void compute(int n_threads,
struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* c_concat,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
std::vector<ggml_tensor*> ref_latents = {},
int num_video_frames = -1,
std::vector<struct ggml_tensor*> controls = {},
float control_strength = 0.f,
struct ggml_tensor** output = NULL,
struct ggml_context* output_ctx = NULL,
std::vector<int> skip_layers = std::vector<int>()) {
return wan.compute(n_threads, x, timesteps, context, NULL, NULL, output, output_ctx);
}
};
#endif #endif

View File

@ -24,11 +24,14 @@
#define STB_IMAGE_RESIZE_STATIC #define STB_IMAGE_RESIZE_STATIC
#include "stb_image_resize.h" #include "stb_image_resize.h"
#if defined(_WIN32)
#define NOMINMAX
#include <windows.h>
#endif // _WIN32
#define SAFE_STR(s) ((s) ? (s) : "") #define SAFE_STR(s) ((s) ? (s) : "")
#define BOOL_STR(b) ((b) ? "true" : "false") #define BOOL_STR(b) ((b) ? "true" : "false")
#include "t5.hpp"
const char* modes_str[] = { const char* modes_str[] = {
"img_gen", "img_gen",
"vid_gen", "vid_gen",
@ -69,7 +72,6 @@ struct SDParams {
std::string prompt; std::string prompt;
std::string negative_prompt; std::string negative_prompt;
float min_cfg = 1.0f;
float cfg_scale = 7.0f; float cfg_scale = 7.0f;
float img_cfg_scale = INFINITY; float img_cfg_scale = INFINITY;
float guidance = 3.5f; float guidance = 3.5f;
@ -80,10 +82,7 @@ struct SDParams {
int height = 512; int height = 512;
int batch_count = 1; int batch_count = 1;
int video_frames = 6; int video_frames = 1;
int motion_bucket_id = 127;
int fps = 6;
float augmentation_level = 0.f;
sample_method_t sample_method = EULER_A; sample_method_t sample_method = EULER_A;
schedule_t schedule = DEFAULT; schedule_t schedule = DEFAULT;
@ -147,7 +146,6 @@ void print_params(SDParams params) {
printf(" strength(control): %.2f\n", params.control_strength); printf(" strength(control): %.2f\n", params.control_strength);
printf(" prompt: %s\n", params.prompt.c_str()); printf(" prompt: %s\n", params.prompt.c_str());
printf(" negative_prompt: %s\n", params.negative_prompt.c_str()); printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
printf(" min_cfg: %.2f\n", params.min_cfg);
printf(" cfg_scale: %.2f\n", params.cfg_scale); printf(" cfg_scale: %.2f\n", params.cfg_scale);
printf(" img_cfg_scale: %.2f\n", params.img_cfg_scale); printf(" img_cfg_scale: %.2f\n", params.img_cfg_scale);
printf(" slg_scale: %.2f\n", params.slg_scale); printf(" slg_scale: %.2f\n", params.slg_scale);
@ -243,6 +241,42 @@ void print_usage(int argc, const char* argv[]) {
printf(" -v, --verbose print extra info\n"); printf(" -v, --verbose print extra info\n");
} }
#if defined(_WIN32)
static std::string utf16_to_utf8(const std::wstring& wstr) {
if (wstr.empty())
return {};
int size_needed = WideCharToMultiByte(CP_UTF8, 0, wstr.data(), (int)wstr.size(),
nullptr, 0, nullptr, nullptr);
if (size_needed <= 0)
throw std::runtime_error("UTF-16 to UTF-8 conversion failed");
std::string utf8(size_needed, 0);
WideCharToMultiByte(CP_UTF8, 0, wstr.data(), (int)wstr.size(),
(char*)utf8.data(), size_needed, nullptr, nullptr);
return utf8;
}
static std::string argv_to_utf8(int index, const char** argv) {
int argc;
wchar_t** argv_w = CommandLineToArgvW(GetCommandLineW(), &argc);
if (!argv_w)
throw std::runtime_error("Failed to parse command line");
std::string result;
if (index < argc) {
result = utf16_to_utf8(argv_w[index]);
}
LocalFree(argv_w);
return result;
}
#else // Linux / macOS
static std::string argv_to_utf8(int index, const char** argv) {
return std::string(argv[index]);
}
#endif
struct StringOption { struct StringOption {
std::string short_name; std::string short_name;
std::string long_name; std::string long_name;
@ -299,7 +333,7 @@ bool parse_options(int argc, const char** argv, ArgOptions& options) {
invalid_arg = true; invalid_arg = true;
break; break;
} }
*option.target = std::string(argv[i]); *option.target = argv_to_utf8(i, argv);
} }
} }
if (invalid_arg) { if (invalid_arg) {
@ -746,17 +780,9 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
int main(int argc, const char* argv[]) { int main(int argc, const char* argv[]) {
SDParams params; SDParams params;
// params.verbose = true;
// sd_set_log_callback(sd_log_cb, (void*)&params);
// T5Embedder::load_from_file_and_test(argv[1]);
// return 0;
parse_args(argc, argv, params); parse_args(argc, argv, params);
sd_guidance_params_t guidance_params = {params.cfg_scale, sd_guidance_params_t guidance_params = {params.cfg_scale,
params.img_cfg_scale, params.img_cfg_scale,
params.min_cfg,
params.guidance, params.guidance,
{ {
params.skip_layers.data(), params.skip_layers.data(),
@ -791,11 +817,6 @@ int main(int argc, const char* argv[]) {
} }
} }
if (params.mode == VID_GEN) {
fprintf(stderr, "SVD support is broken, do not use it!!!\n");
return 1;
}
bool vae_decode_only = true; bool vae_decode_only = true;
uint8_t* input_image_buffer = NULL; uint8_t* input_image_buffer = NULL;
uint8_t* control_image_buffer = NULL; uint8_t* control_image_buffer = NULL;
@ -992,18 +1013,19 @@ int main(int argc, const char* argv[]) {
expected_num_results = params.batch_count; expected_num_results = params.batch_count;
} else if (params.mode == VID_GEN) { } else if (params.mode == VID_GEN) {
sd_vid_gen_params_t vid_gen_params = { sd_vid_gen_params_t vid_gen_params = {
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
guidance_params,
input_image, input_image,
params.width, params.width,
params.height, params.height,
guidance_params,
params.sample_method, params.sample_method,
params.sample_steps, params.sample_steps,
params.eta,
params.strength, params.strength,
params.seed, params.seed,
params.video_frames, params.video_frames,
params.motion_bucket_id,
params.fps,
params.augmentation_level,
}; };
results = generate_video(sd_ctx, &vid_gen_params); results = generate_video(sd_ctx, &vid_gen_params);

View File

@ -323,17 +323,27 @@ __STATIC_INLINE__ uint8_t* sd_tensor_to_image(struct ggml_tensor* input) {
return image_data; return image_data;
} }
__STATIC_INLINE__ uint8_t* sd_tensor_to_mul_image(struct ggml_tensor* input, int idx) { __STATIC_INLINE__ uint8_t* sd_tensor_to_image(struct ggml_tensor* input, int idx, bool video = false) {
int64_t width = input->ne[0]; int64_t width = input->ne[0];
int64_t height = input->ne[1]; int64_t height = input->ne[1];
int64_t channels = input->ne[2]; int64_t channels;
if (video) {
channels = input->ne[3];
} else {
channels = input->ne[2];
}
GGML_ASSERT(channels == 3 && input->type == GGML_TYPE_F32); GGML_ASSERT(channels == 3 && input->type == GGML_TYPE_F32);
uint8_t* image_data = (uint8_t*)malloc(width * height * channels); uint8_t* image_data = (uint8_t*)malloc(width * height * channels);
for (int iy = 0; iy < height; iy++) { for (int ih = 0; ih < height; ih++) {
for (int ix = 0; ix < width; ix++) { for (int iw = 0; iw < width; iw++) {
for (int k = 0; k < channels; k++) { for (int ic = 0; ic < channels; ic++) {
float value = ggml_tensor_get_f32(input, ix, iy, k, idx); float value;
*(image_data + iy * width * channels + ix * channels + k) = (uint8_t)(value * 255.0f); if (video) {
value = ggml_tensor_get_f32(input, iw, ih, idx, ic);
} else {
value = ggml_tensor_get_f32(input, iw, ih, ic, idx);
}
*(image_data + ih * width * channels + iw * channels + ic) = (uint8_t)(value * 255.0f);
} }
} }
} }

View File

@ -1055,7 +1055,11 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s
// LOG_DEBUG("%s", name.c_str()); // LOG_DEBUG("%s", name.c_str());
TensorStorage tensor_storage(prefix + name, dummy->type, dummy->ne, ggml_n_dims(dummy), file_index, offset); if (!starts_with(name, prefix)) {
name = prefix + name;
}
TensorStorage tensor_storage(name, dummy->type, dummy->ne, ggml_n_dims(dummy), file_index, offset);
GGML_ASSERT(ggml_nbytes(dummy) == tensor_storage.nbytes()); GGML_ASSERT(ggml_nbytes(dummy) == tensor_storage.nbytes());
@ -1195,7 +1199,11 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
n_dims = 1; n_dims = 1;
} }
TensorStorage tensor_storage(prefix + name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin); if (!starts_with(name, prefix)) {
name = prefix + name;
}
TensorStorage tensor_storage(name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
tensor_storage.reverse_ne(); tensor_storage.reverse_ne();
size_t tensor_data_size = end - begin; size_t tensor_data_size = end - begin;
@ -1580,7 +1588,11 @@ bool ModelLoader::parse_data_pkl(uint8_t* buffer,
reader.tensor_storage.file_index = file_index; reader.tensor_storage.file_index = file_index;
// if(strcmp(prefix.c_str(), "scarlett") == 0) // if(strcmp(prefix.c_str(), "scarlett") == 0)
// printf(" ZIP got tensor %s \n ", reader.tensor_storage.name.c_str()); // printf(" ZIP got tensor %s \n ", reader.tensor_storage.name.c_str());
reader.tensor_storage.name = prefix + reader.tensor_storage.name; std::string name = reader.tensor_storage.name;
if (!starts_with(name, prefix)) {
name = prefix + name;
}
reader.tensor_storage.name = name;
tensor_storages.push_back(reader.tensor_storage); tensor_storages.push_back(reader.tensor_storage);
add_preprocess_tensor_storage_types(tensor_storages_types, reader.tensor_storage.name, reader.tensor_storage.type); add_preprocess_tensor_storage_types(tensor_storages_types, reader.tensor_storage.name, reader.tensor_storage.type);
@ -1654,10 +1666,10 @@ SDVersion ModelLoader::get_sd_version() {
bool is_xl = false; bool is_xl = false;
bool is_flux = false; bool is_flux = false;
bool is_wan = false;
#define found_family (is_xl || is_flux)
for (auto& tensor_storage : tensor_storages) { for (auto& tensor_storage : tensor_storages) {
if (!found_family) { if (!(is_xl || is_flux)) {
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) { if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
is_flux = true; is_flux = true;
if (input_block_checked) { if (input_block_checked) {
@ -1667,6 +1679,9 @@ SDVersion ModelLoader::get_sd_version() {
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) { if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
return VERSION_SD3; return VERSION_SD3;
} }
if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) {
return VERSION_WAN2;
}
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || tensor_storage.name.find("unet.down_blocks.") != std::string::npos) { if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || tensor_storage.name.find("unet.down_blocks.") != std::string::npos) {
is_unet = true; is_unet = true;
if (has_multiple_encoders) { if (has_multiple_encoders) {
@ -1701,7 +1716,7 @@ SDVersion ModelLoader::get_sd_version() {
if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight" || tensor_storage.name == "unet.conv_in.weight") { if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight" || tensor_storage.name == "unet.conv_in.weight") {
input_block_weight = tensor_storage; input_block_weight = tensor_storage;
input_block_checked = true; input_block_checked = true;
if (found_family) { if (is_xl || is_flux) {
break; break;
} }
} }

View File

@ -31,8 +31,7 @@ enum SDVersion {
VERSION_SD3, VERSION_SD3,
VERSION_FLUX, VERSION_FLUX,
VERSION_FLUX_FILL, VERSION_FLUX_FILL,
VERSION_WAN_2_1, VERSION_WAN2,
VERSION_WAN_2_2,
VERSION_COUNT, VERSION_COUNT,
}; };
@ -72,7 +71,7 @@ static inline bool sd_version_is_flux(SDVersion version) {
} }
static inline bool sd_version_is_wan(SDVersion version) { static inline bool sd_version_is_wan(SDVersion version) {
if (version == VERSION_WAN_2_1 || version == VERSION_WAN_2_2) { if (version == VERSION_WAN2) {
return true; return true;
} }
return false; return false;

View File

@ -36,7 +36,9 @@ const char* model_version_to_str[] = {
"SVD", "SVD",
"SD3.x", "SD3.x",
"Flux", "Flux",
"Flux Fill"}; "Flux Fill",
"Wan 2.x",
};
const char* sampling_methods_str[] = { const char* sampling_methods_str[] = {
"Euler A", "Euler A",
@ -50,7 +52,8 @@ const char* sampling_methods_str[] = {
"iPNDM_v", "iPNDM_v",
"LCM", "LCM",
"DDIM \"trailing\"", "DDIM \"trailing\"",
"TCD"}; "TCD",
};
/*================================================== Helper Functions ================================================*/ /*================================================== Helper Functions ================================================*/
@ -93,7 +96,7 @@ public:
std::shared_ptr<Conditioner> cond_stage_model; std::shared_ptr<Conditioner> cond_stage_model;
std::shared_ptr<FrozenCLIPVisionEmbedder> clip_vision; // for svd std::shared_ptr<FrozenCLIPVisionEmbedder> clip_vision; // for svd
std::shared_ptr<DiffusionModel> diffusion_model; std::shared_ptr<DiffusionModel> diffusion_model;
std::shared_ptr<AutoEncoderKL> first_stage_model; std::shared_ptr<VAE> first_stage_model;
std::shared_ptr<TinyAutoEncoder> tae_first_stage; std::shared_ptr<TinyAutoEncoder> tae_first_stage;
std::shared_ptr<ControlNet> control_net; std::shared_ptr<ControlNet> control_net;
std::shared_ptr<PhotoMakerIDEncoder> pmid_model; std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
@ -274,10 +277,10 @@ public:
model_loader.set_wtype_override(GGML_TYPE_F32, "vae."); model_loader.set_wtype_override(GGML_TYPE_F32, "vae.");
} }
LOG_INFO("Weight type: %s", model_wtype != GGML_TYPE_COUNT ? ggml_type_name(model_wtype) : "??"); LOG_INFO("Weight type: %s", ggml_type_name(model_wtype));
LOG_INFO("Conditioner weight type: %s", conditioner_wtype != GGML_TYPE_COUNT ? ggml_type_name(conditioner_wtype) : "??"); LOG_INFO("Conditioner weight type: %s", ggml_type_name(conditioner_wtype));
LOG_INFO("Diffusion model weight type: %s", diffusion_model_wtype != GGML_TYPE_COUNT ? ggml_type_name(diffusion_model_wtype) : "??"); LOG_INFO("Diffusion model weight type: %s", ggml_type_name(diffusion_model_wtype));
LOG_INFO("VAE weight type: %s", vae_wtype != GGML_TYPE_COUNT ? ggml_type_name(vae_wtype) : "??"); LOG_INFO("VAE weight type: %s", ggml_type_name(vae_wtype));
LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor)); LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor));
@ -293,34 +296,25 @@ public:
} else if (sd_version_is_sd3(version)) { } else if (sd_version_is_sd3(version)) {
scale_factor = 1.5305f; scale_factor = 1.5305f;
} else if (sd_version_is_flux(version)) { } else if (sd_version_is_flux(version)) {
scale_factor = 0.3611; scale_factor = 0.3611f;
// TODO: shift_factor // TODO: shift_factor
} else if (sd_version_is_wan(version)) {
scale_factor = 1.0f;
} }
bool clip_on_cpu = sd_ctx_params->keep_clip_on_cpu; bool clip_on_cpu = sd_ctx_params->keep_clip_on_cpu;
if (version == VERSION_SVD) { {
clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend, model_loader.tensor_storages_types);
clip_vision->alloc_params_buffer();
clip_vision->get_param_tensors(tensors);
diffusion_model = std::make_shared<UNetModel>(backend, model_loader.tensor_storages_types, version);
diffusion_model->alloc_params_buffer();
diffusion_model->get_param_tensors(tensors);
first_stage_model = std::make_shared<AutoEncoderKL>(backend, model_loader.tensor_storages_types, "first_stage_model", vae_decode_only, true, version);
LOG_DEBUG("vae_decode_only %d", vae_decode_only);
first_stage_model->alloc_params_buffer();
first_stage_model->get_param_tensors(tensors, "first_stage_model");
} else {
clip_backend = backend; clip_backend = backend;
bool use_t5xxl = false; bool use_t5xxl = false;
if (sd_version_is_dit(version)) { if (sd_version_is_dit(version)) {
use_t5xxl = true; use_t5xxl = true;
} }
if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) { if (!ggml_backend_is_cpu(backend) && use_t5xxl) {
clip_on_cpu = true; LOG_WARN(
LOG_INFO("set clip_on_cpu to true"); "!!!It appears that you are using the T5 model. Some backends may encounter issues with it."
"If you notice that the generated images are completely black,"
"try running the T5 model on the CPU using the --clip-on-cpu parameter.");
} }
if (clip_on_cpu && !ggml_backend_is_cpu(backend)) { if (clip_on_cpu && !ggml_backend_is_cpu(backend)) {
LOG_INFO("CLIP: Using CPU backend"); LOG_INFO("CLIP: Using CPU backend");
@ -357,7 +351,18 @@ public:
version, version,
sd_ctx_params->diffusion_flash_attn, sd_ctx_params->diffusion_flash_attn,
sd_ctx_params->chroma_use_dit_mask); sd_ctx_params->chroma_use_dit_mask);
} else { } else if (sd_version_is_wan(version)) {
cond_stage_model = std::make_shared<T5CLIPEmbedder>(clip_backend,
model_loader.tensor_storages_types,
-1,
true,
1,
true);
diffusion_model = std::make_shared<WanModel>(backend,
model_loader.tensor_storages_types,
version,
sd_ctx_params->diffusion_flash_attn);
} else { // SD1.x SD2.x SDXL
if (strstr(SAFE_STR(sd_ctx_params->stacked_id_embed_dir), "v2")) { if (strstr(SAFE_STR(sd_ctx_params->stacked_id_embed_dir), "v2")) {
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
model_loader.tensor_storages_types, model_loader.tensor_storages_types,
@ -382,13 +387,21 @@ public:
diffusion_model->alloc_params_buffer(); diffusion_model->alloc_params_buffer();
diffusion_model->get_param_tensors(tensors); diffusion_model->get_param_tensors(tensors);
if (!use_tiny_autoencoder) {
if (sd_ctx_params->keep_vae_on_cpu && !ggml_backend_is_cpu(backend)) { if (sd_ctx_params->keep_vae_on_cpu && !ggml_backend_is_cpu(backend)) {
LOG_INFO("VAE Autoencoder: Using CPU backend"); LOG_INFO("VAE Autoencoder: Using CPU backend");
vae_backend = ggml_backend_cpu_init(); vae_backend = ggml_backend_cpu_init();
} else { } else {
vae_backend = backend; vae_backend = backend;
} }
if (sd_version_is_wan(version)) {
first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
model_loader.tensor_storages_types,
"first_stage_model",
vae_decode_only);
first_stage_model->alloc_params_buffer();
first_stage_model->get_param_tensors(tensors, "first_stage_model");
} else if (!use_tiny_autoencoder) {
first_stage_model = std::make_shared<AutoEncoderKL>(vae_backend, first_stage_model = std::make_shared<AutoEncoderKL>(vae_backend,
model_loader.tensor_storages_types, model_loader.tensor_storages_types,
"first_stage_model", "first_stage_model",
@ -398,7 +411,7 @@ public:
first_stage_model->alloc_params_buffer(); first_stage_model->alloc_params_buffer();
first_stage_model->get_param_tensors(tensors, "first_stage_model"); first_stage_model->get_param_tensors(tensors, "first_stage_model");
} else { } else {
tae_first_stage = std::make_shared<TinyAutoEncoder>(backend, tae_first_stage = std::make_shared<TinyAutoEncoder>(vae_backend,
model_loader.tensor_storages_types, model_loader.tensor_storages_types,
"decoder.layers", "decoder.layers",
vae_decode_only, vae_decode_only,
@ -485,11 +498,7 @@ public:
// LOG_DEBUG("model size = %.2fMB", total_size / 1024.0 / 1024.0); // LOG_DEBUG("model size = %.2fMB", total_size / 1024.0 / 1024.0);
if (version == VERSION_SVD) { {
// diffusion_model->test();
// first_stage_model->test();
// return false;
} else {
size_t clip_params_mem_size = cond_stage_model->get_params_buffer_size(); size_t clip_params_mem_size = cond_stage_model->get_params_buffer_size();
size_t unet_params_mem_size = diffusion_model->get_params_buffer_size(); size_t unet_params_mem_size = diffusion_model->get_params_buffer_size();
size_t vae_params_mem_size = 0; size_t vae_params_mem_size = 0;
@ -594,6 +603,9 @@ public:
} }
} }
denoiser = std::make_shared<FluxFlowDenoiser>(shift); denoiser = std::make_shared<FluxFlowDenoiser>(shift);
} else if (sd_version_is_wan(version)) {
LOG_INFO("running in FLOW mode");
denoiser = std::make_shared<DiscreteFlowDenoiser>();
} else if (is_using_v_parameterization) { } else if (is_using_v_parameterization) {
LOG_INFO("running in v-prediction mode"); LOG_INFO("running in v-prediction mode");
denoiser = std::make_shared<CompVisVDenoiser>(); denoiser = std::make_shared<CompVisVDenoiser>();
@ -733,9 +745,9 @@ public:
size_t rm = lora_state_diff.size() - lora_state.size(); size_t rm = lora_state_diff.size() - lora_state.size();
if (rm != 0) { if (rm != 0) {
LOG_INFO("Attempting to apply %lu LoRAs (removing %lu applied LoRAs)", lora_state.size(), rm); LOG_INFO("attempting to apply %lu LoRAs (removing %lu applied LoRAs)", lora_state.size(), rm);
} else { } else {
LOG_INFO("Attempting to apply %lu LoRAs", lora_state.size()); LOG_INFO("attempting to apply %lu LoRAs", lora_state.size());
} }
for (auto& kv : lora_state_diff) { for (auto& kv : lora_state_diff) {
@ -745,6 +757,21 @@ public:
curr_lora_state = lora_state; curr_lora_state = lora_state;
} }
std::string apply_loras_from_prompt(const std::string& prompt) {
auto result_pair = extract_and_remove_lora(prompt);
std::unordered_map<std::string, float> lora_f2m = result_pair.first; // lora_name -> multiplier
for (auto& kv : lora_f2m) {
LOG_DEBUG("lora %s:%.2f", kv.first.c_str(), kv.second);
}
int64_t t0 = ggml_time_ms();
apply_loras(lora_f2m);
int64_t t1 = ggml_time_ms();
LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
LOG_DEBUG("prompt after extract and remove lora: \"%s\"", result_pair.second.c_str());
return result_pair.second;
}
ggml_tensor* id_encoder(ggml_context* work_ctx, ggml_tensor* id_encoder(ggml_context* work_ctx,
ggml_tensor* init_img, ggml_tensor* init_img,
ggml_tensor* prompts_embeds, ggml_tensor* prompts_embeds,
@ -762,12 +789,12 @@ public:
int fps = 6, int fps = 6,
int motion_bucket_id = 127, int motion_bucket_id = 127,
float augmentation_level = 0.f, float augmentation_level = 0.f,
bool force_zero_embeddings = false) { bool zero_out_masked = false) {
// c_crossattn // c_crossattn
int64_t t0 = ggml_time_ms(); int64_t t0 = ggml_time_ms();
struct ggml_tensor* c_crossattn = NULL; struct ggml_tensor* c_crossattn = NULL;
{ {
if (force_zero_embeddings) { if (zero_out_masked) {
c_crossattn = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, clip_vision->vision_model.projection_dim); c_crossattn = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, clip_vision->vision_model.projection_dim);
ggml_set_f32(c_crossattn, 0.f); ggml_set_f32(c_crossattn, 0.f);
} else { } else {
@ -790,7 +817,7 @@ public:
// c_concat // c_concat
struct ggml_tensor* c_concat = NULL; struct ggml_tensor* c_concat = NULL;
{ {
if (force_zero_embeddings) { if (zero_out_masked) {
c_concat = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 4, 1); c_concat = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 4, 1);
ggml_set_f32(c_concat, 0.f); ggml_set_f32(c_concat, 0.f);
} else { } else {
@ -855,28 +882,14 @@ public:
float img_cfg_scale = guidance.img_cfg; float img_cfg_scale = guidance.img_cfg;
float slg_scale = guidance.slg.scale; float slg_scale = guidance.slg.scale;
float min_cfg = guidance.min_cfg; LOG_DEBUG("cfg_scale %.2f", cfg_scale);
if (img_cfg_scale != cfg_scale && !sd_version_is_inpaint_or_unet_edit(version)) { if (img_cfg_scale != cfg_scale && !sd_version_is_inpaint_or_unet_edit(version)) {
LOG_WARN("2-conditioning CFG is not supported with this model, disabling it for better performance..."); LOG_WARN("2-conditioning CFG is not supported with this model, disabling it for better performance...");
img_cfg_scale = cfg_scale; img_cfg_scale = cfg_scale;
} }
LOG_DEBUG("Sample");
struct ggml_init_params params;
size_t data_size = ggml_row_size(init_latent->type, init_latent->ne[0]);
for (int i = 1; i < 4; i++) {
data_size *= init_latent->ne[i];
}
data_size += 1024;
params.mem_size = data_size * 3;
params.mem_buffer = NULL;
params.no_alloc = false;
ggml_context* tmp_ctx = ggml_init(params);
size_t steps = sigmas.size() - 1; size_t steps = sigmas.size() - 1;
// noise = load_tensor_from_file(work_ctx, "./rand0.bin");
// print_ggml_tensor(noise);
struct ggml_tensor* x = ggml_dup_tensor(work_ctx, init_latent); struct ggml_tensor* x = ggml_dup_tensor(work_ctx, init_latent);
copy_ggml_tensor(x, init_latent); copy_ggml_tensor(x, init_latent);
x = denoiser->noise_scaling(sigmas[0], noise, x); x = denoiser->noise_scaling(sigmas[0], noise, x);
@ -922,9 +935,9 @@ public:
float c_in = scaling[2]; float c_in = scaling[2];
float t = denoiser->sigma_to_t(sigma); float t = denoiser->sigma_to_t(sigma);
std::vector<float> timesteps_vec(x->ne[3], t); // [N, ] std::vector<float> timesteps_vec(1, t); // [N, ]
auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec); auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
std::vector<float> guidance_vec(x->ne[3], guidance.distilled_guidance); std::vector<float> guidance_vec(1, guidance.distilled_guidance);
auto guidance_tensor = vector_to_ggml_tensor(work_ctx, guidance_vec); auto guidance_tensor = vector_to_ggml_tensor(work_ctx, guidance_vec);
copy_ggml_tensor(noised_input, input); copy_ggml_tensor(noised_input, input);
@ -1038,11 +1051,6 @@ public:
float latent_result = positive_data[i]; float latent_result = positive_data[i];
if (has_unconditioned) { if (has_unconditioned) {
// out_uncond + cfg_scale * (out_cond - out_uncond) // out_uncond + cfg_scale * (out_cond - out_uncond)
int64_t ne3 = out_cond->ne[3];
if (min_cfg != cfg_scale && ne3 != 1) {
int64_t i3 = i / out_cond->ne[0] * out_cond->ne[1] * out_cond->ne[2];
float scale = min_cfg + (cfg_scale - min_cfg) * (i3 * 1.0f / ne3);
} else {
if (has_img_cond) { if (has_img_cond) {
// out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond) // out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond)
latent_result = negative_data[i] + img_cfg_scale * (img_cond_data[i] - negative_data[i]) + cfg_scale * (positive_data[i] - img_cond_data[i]); latent_result = negative_data[i] + img_cfg_scale * (img_cond_data[i] - negative_data[i]) + cfg_scale * (positive_data[i] - img_cond_data[i]);
@ -1050,7 +1058,6 @@ public:
// img_cfg_scale == cfg_scale // img_cfg_scale == cfg_scale
latent_result = negative_data[i] + cfg_scale * (positive_data[i] - negative_data[i]); latent_result = negative_data[i] + cfg_scale * (positive_data[i] - negative_data[i]);
} }
}
} else if (has_img_cond) { } else if (has_img_cond) {
// img_cfg_scale == 1 // img_cfg_scale == 1
latent_result = img_cond_data[i] + cfg_scale * (positive_data[i] - img_cond_data[i]); latent_result = img_cond_data[i] + cfg_scale * (positive_data[i] - img_cond_data[i]);
@ -1085,6 +1092,7 @@ public:
sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng, eta); sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng, eta);
LOG_DEBUG("sigmas[sigmas.size() - 1] %f", sigmas[sigmas.size() - 1]);
x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x); x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x);
if (control_net) { if (control_net) {
@ -1101,7 +1109,6 @@ public:
ggml_tensor* latent = ggml_new_tensor_4d(work_ctx, moments->type, moments->ne[0], moments->ne[1], moments->ne[2] / 2, moments->ne[3]); ggml_tensor* latent = ggml_new_tensor_4d(work_ctx, moments->type, moments->ne[0], moments->ne[1], moments->ne[2] / 2, moments->ne[3]);
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, latent); struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, latent);
ggml_tensor_set_f32_randn(noise, rng); ggml_tensor_set_f32_randn(noise, rng);
// noise = load_tensor_from_file(work_ctx, "noise.bin");
{ {
float mean = 0; float mean = 0;
float logvar = 0; float logvar = 0;
@ -1127,9 +1134,9 @@ public:
return latent; return latent;
} }
ggml_tensor* compute_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode) { ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x) {
int64_t W = x->ne[0]; int64_t W = x->ne[0] / 8;
int64_t H = x->ne[1]; int64_t H = x->ne[1] / 8;
int64_t C = 8; int64_t C = 8;
if (use_tiny_autoencoder) { if (use_tiny_autoencoder) {
C = 4; C = 4;
@ -1140,59 +1147,106 @@ public:
C = 32; C = 32;
} }
} }
ggml_tensor* result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, ggml_tensor* result = ggml_new_tensor_4d(work_ctx,
decode ? (W * 8) : (W / 8), // width GGML_TYPE_F32,
decode ? (H * 8) : (H / 8), // height W,
decode ? 3 : C, H,
x->ne[3]); // channels C,
x->ne[3]);
int64_t t0 = ggml_time_ms(); int64_t t0 = ggml_time_ms();
if (!use_tiny_autoencoder) { if (!use_tiny_autoencoder) {
if (decode) {
ggml_tensor_scale(x, 1.0f / scale_factor);
} else {
ggml_tensor_scale_input(x); ggml_tensor_scale_input(x);
first_stage_model->compute(n_threads, x, false, &result, NULL);
first_stage_model->free_compute_buffer();
} else {
tae_first_stage->compute(n_threads, x, false, &result, NULL);
tae_first_stage->free_compute_buffer();
} }
if (vae_tiling && decode) { // TODO: support tiling vae encode
int64_t t1 = ggml_time_ms();
LOG_DEBUG("computing vae encode graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
return result;
}
void process_latent_out(ggml_tensor* latent) {
if (sd_version_is_wan(version)) {
GGML_ASSERT(latent->ne[3] == 16);
std::vector<float> latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f,
0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f};
std::vector<float> latents_std_vec = {2.8184f, 1.4541f, 2.3275f, 2.6558f, 1.2196f, 1.7708f, 2.6052f, 2.0743f,
3.2687f, 2.1526f, 2.8652f, 1.5579f, 1.6382f, 1.1253f, 2.8251f, 1.9160f};
for (int i = 0; i < latent->ne[3]; i++) {
float mean = latents_mean_vec[i];
float std_ = latents_std_vec[i];
for (int j = 0; j < latent->ne[2]; j++) {
for (int k = 0; k < latent->ne[1]; k++) {
for (int l = 0; l < latent->ne[0]; l++) {
float value = ggml_tensor_get_f32(latent, l, k, j, i);
value = value * std_ / scale_factor + mean;
ggml_tensor_set_f32(latent, value, l, k, j, i);
}
}
}
}
} else {
ggml_tensor_scale(latent, 1.0f / scale_factor);
}
}
ggml_tensor* decode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) {
int64_t W = x->ne[0] * 8;
int64_t H = x->ne[1] * 8;
int64_t C = 3;
ggml_tensor* result;
if (decode_video) {
result = ggml_new_tensor_4d(work_ctx,
GGML_TYPE_F32,
W,
H,
x->ne[2],
3);
} else {
result = ggml_new_tensor_4d(work_ctx,
GGML_TYPE_F32,
W,
H,
C,
x->ne[3]);
}
int64_t t0 = ggml_time_ms();
if (!use_tiny_autoencoder) {
LOG_DEBUG("scale_factor %.2f", scale_factor);
process_latent_out(x);
if (vae_tiling && !decode_video) {
// split latent in 32x32 tiles and compute in several steps // split latent in 32x32 tiles and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
first_stage_model->compute(n_threads, in, decode, &out); first_stage_model->compute(n_threads, in, true, &out, NULL);
}; };
sd_tiling(x, result, 8, 32, 0.5f, on_tiling); sd_tiling(x, result, 8, 32, 0.5f, on_tiling);
} else { } else {
first_stage_model->compute(n_threads, x, decode, &result); first_stage_model->compute(n_threads, x, true, &result, NULL);
} }
first_stage_model->free_compute_buffer(); first_stage_model->free_compute_buffer();
if (decode) {
ggml_tensor_scale_output(result); ggml_tensor_scale_output(result);
}
} else { } else {
if (vae_tiling && decode) { // TODO: support tiling vae encode if (vae_tiling && !decode_video) {
// split latent in 64x64 tiles and compute in several steps // split latent in 64x64 tiles and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
tae_first_stage->compute(n_threads, in, decode, &out); tae_first_stage->compute(n_threads, in, true, &out);
}; };
sd_tiling(x, result, 8, 64, 0.5f, on_tiling); sd_tiling(x, result, 8, 64, 0.5f, on_tiling);
} else { } else {
tae_first_stage->compute(n_threads, x, decode, &result); tae_first_stage->compute(n_threads, x, true, &result);
} }
tae_first_stage->free_compute_buffer(); tae_first_stage->free_compute_buffer();
} }
int64_t t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();
LOG_DEBUG("computing vae [mode: %s] graph completed, taking %.2fs", decode ? "DECODE" : "ENCODE", (t1 - t0) * 1.0f / 1000); LOG_DEBUG("computing vae decode graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
if (decode) {
ggml_tensor_clamp(result, 0.0f, 1.0f); ggml_tensor_clamp(result, 0.0f, 1.0f);
}
return result; return result;
} }
ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x) {
return compute_first_stage(work_ctx, x, false);
}
ggml_tensor* decode_first_stage(ggml_context* work_ctx, ggml_tensor* x) {
return compute_first_stage(work_ctx, x, true);
}
}; };
/*================================================= SD API ==================================================*/ /*================================================= SD API ==================================================*/
@ -1373,7 +1427,6 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) {
memset((void*)sd_img_gen_params, 0, sizeof(sd_img_gen_params_t)); memset((void*)sd_img_gen_params, 0, sizeof(sd_img_gen_params_t));
sd_img_gen_params->clip_skip = -1; sd_img_gen_params->clip_skip = -1;
sd_img_gen_params->guidance.txt_cfg = 7.0f; sd_img_gen_params->guidance.txt_cfg = 7.0f;
sd_img_gen_params->guidance.min_cfg = 1.0f;
sd_img_gen_params->guidance.img_cfg = INFINITY; sd_img_gen_params->guidance.img_cfg = INFINITY;
sd_img_gen_params->guidance.distilled_guidance = 3.5f; sd_img_gen_params->guidance.distilled_guidance = 3.5f;
sd_img_gen_params->guidance.slg.layer_count = 0; sd_img_gen_params->guidance.slg.layer_count = 0;
@ -1406,7 +1459,6 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
"clip_skip: %d\n" "clip_skip: %d\n"
"txt_cfg: %.2f\n" "txt_cfg: %.2f\n"
"img_cfg: %.2f\n" "img_cfg: %.2f\n"
"min_cfg: %.2f\n"
"distilled_guidance: %.2f\n" "distilled_guidance: %.2f\n"
"slg.layer_count: %zu\n" "slg.layer_count: %zu\n"
"slg.layer_start: %.2f\n" "slg.layer_start: %.2f\n"
@ -1431,7 +1483,6 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
sd_img_gen_params->clip_skip, sd_img_gen_params->clip_skip,
sd_img_gen_params->guidance.txt_cfg, sd_img_gen_params->guidance.txt_cfg,
sd_img_gen_params->guidance.img_cfg, sd_img_gen_params->guidance.img_cfg,
sd_img_gen_params->guidance.min_cfg,
sd_img_gen_params->guidance.distilled_guidance, sd_img_gen_params->guidance.distilled_guidance,
sd_img_gen_params->guidance.slg.layer_count, sd_img_gen_params->guidance.slg.layer_count,
sd_img_gen_params->guidance.slg.layer_start, sd_img_gen_params->guidance.slg.layer_start,
@ -1457,7 +1508,6 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) { void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
memset((void*)sd_vid_gen_params, 0, sizeof(sd_vid_gen_params_t)); memset((void*)sd_vid_gen_params, 0, sizeof(sd_vid_gen_params_t));
sd_vid_gen_params->guidance.txt_cfg = 7.0f; sd_vid_gen_params->guidance.txt_cfg = 7.0f;
sd_vid_gen_params->guidance.min_cfg = 1.0f;
sd_vid_gen_params->guidance.img_cfg = INFINITY; sd_vid_gen_params->guidance.img_cfg = INFINITY;
sd_vid_gen_params->guidance.distilled_guidance = 3.5f; sd_vid_gen_params->guidance.distilled_guidance = 3.5f;
sd_vid_gen_params->guidance.slg.layer_count = 0; sd_vid_gen_params->guidance.slg.layer_count = 0;
@ -1471,9 +1521,6 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
sd_vid_gen_params->strength = 0.75f; sd_vid_gen_params->strength = 0.75f;
sd_vid_gen_params->seed = -1; sd_vid_gen_params->seed = -1;
sd_vid_gen_params->video_frames = 6; sd_vid_gen_params->video_frames = 6;
sd_vid_gen_params->motion_bucket_id = 127;
sd_vid_gen_params->fps = 6;
sd_vid_gen_params->augmentation_level = 0.f;
} }
struct sd_ctx_t { struct sd_ctx_t {
@ -1545,21 +1592,9 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
int sample_steps = sigmas.size() - 1; int sample_steps = sigmas.size() - 1;
// Apply lora
auto result_pair = extract_and_remove_lora(prompt);
std::unordered_map<std::string, float> lora_f2m = result_pair.first; // lora_name -> multiplier
for (auto& kv : lora_f2m) {
LOG_DEBUG("lora %s:%.2f", kv.first.c_str(), kv.second);
}
prompt = result_pair.second;
LOG_DEBUG("prompt after extract and remove lora: \"%s\"", prompt.c_str());
int64_t t0 = ggml_time_ms(); int64_t t0 = ggml_time_ms();
sd_ctx->sd->apply_loras(lora_f2m); // Apply lora
int64_t t1 = ggml_time_ms(); prompt = sd_ctx->sd->apply_loras_from_prompt(prompt);
LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
// Photo Maker // Photo Maker
std::string prompt_text_only; std::string prompt_text_only;
@ -1568,9 +1603,9 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
std::vector<bool> class_tokens_mask; std::vector<bool> class_tokens_mask;
if (sd_ctx->sd->stacked_id) { if (sd_ctx->sd->stacked_id) {
if (!sd_ctx->sd->pmid_lora->applied) { if (!sd_ctx->sd->pmid_lora->applied) {
t0 = ggml_time_ms(); int64_t t0 = ggml_time_ms();
sd_ctx->sd->pmid_lora->apply(sd_ctx->sd->tensors, sd_ctx->sd->version, sd_ctx->sd->n_threads); sd_ctx->sd->pmid_lora->apply(sd_ctx->sd->tensors, sd_ctx->sd->version, sd_ctx->sd->n_threads);
t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();
sd_ctx->sd->pmid_lora->applied = true; sd_ctx->sd->pmid_lora->applied = true;
LOG_INFO("pmid_lora apply completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); LOG_INFO("pmid_lora apply completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
if (sd_ctx->sd->free_params_immediately) { if (sd_ctx->sd->free_params_immediately) {
@ -1625,7 +1660,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
else else
sd_mul_images_to_tensor(init_image->data, init_img, i, NULL, NULL); sd_mul_images_to_tensor(init_image->data, init_img, i, NULL, NULL);
} }
t0 = ggml_time_ms(); int64_t t0 = ggml_time_ms();
auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx, auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx,
sd_ctx->sd->n_threads, prompt, sd_ctx->sd->n_threads, prompt,
clip_skip, clip_skip,
@ -1642,7 +1677,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
// print_ggml_tensor(id_embeds, true, "id_embeds:"); // print_ggml_tensor(id_embeds, true, "id_embeds:");
} }
id_cond.c_crossattn = sd_ctx->sd->id_encoder(work_ctx, init_img, id_cond.c_crossattn, id_embeds, class_tokens_mask); id_cond.c_crossattn = sd_ctx->sd->id_encoder(work_ctx, init_img, id_cond.c_crossattn, id_embeds, class_tokens_mask);
t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();
LOG_INFO("Photomaker ID Stacking, taking %" PRId64 " ms", t1 - t0); LOG_INFO("Photomaker ID Stacking, taking %" PRId64 " ms", t1 - t0);
if (sd_ctx->sd->free_params_immediately) { if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->pmid_model->free_params_buffer(); sd_ctx->sd->pmid_model->free_params_buffer();
@ -1679,9 +1714,9 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
SDCondition uncond; SDCondition uncond;
if (guidance.txt_cfg != 1.0 || if (guidance.txt_cfg != 1.0 ||
(sd_version_is_inpaint_or_unet_edit(sd_ctx->sd->version) && guidance.txt_cfg != guidance.img_cfg)) { (sd_version_is_inpaint_or_unet_edit(sd_ctx->sd->version) && guidance.txt_cfg != guidance.img_cfg)) {
bool force_zero_embeddings = false; bool zero_out_masked = false;
if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0 && !sd_ctx->sd->is_using_edm_v_parameterization) { if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0 && !sd_ctx->sd->is_using_edm_v_parameterization) {
force_zero_embeddings = true; zero_out_masked = true;
} }
uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx, uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
sd_ctx->sd->n_threads, sd_ctx->sd->n_threads,
@ -1690,9 +1725,9 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
width, width,
height, height,
sd_ctx->sd->diffusion_model->get_adm_in_channels(), sd_ctx->sd->diffusion_model->get_adm_in_channels(),
force_zero_embeddings); zero_out_masked);
} }
t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();
LOG_INFO("get_learned_condition completed, taking %" PRId64 " ms", t1 - t0); LOG_INFO("get_learned_condition completed, taking %" PRId64 " ms", t1 - t0);
if (sd_ctx->sd->free_params_immediately) { if (sd_ctx->sd->free_params_immediately) {
@ -1780,9 +1815,6 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
LOG_INFO("PHOTOMAKER: start_merge_step: %d", start_merge_step); LOG_INFO("PHOTOMAKER: start_merge_step: %d", start_merge_step);
} }
// Disable min_cfg
guidance.min_cfg = guidance.txt_cfg;
struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx, struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx,
x_t, x_t,
noise, noise,
@ -1799,8 +1831,6 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
id_cond, id_cond,
ref_latents, ref_latents,
denoise_mask); denoise_mask);
// struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
// print_ggml_tensor(x_0); // print_ggml_tensor(x_0);
int64_t sampling_end = ggml_time_ms(); int64_t sampling_end = ggml_time_ms();
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000); LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
@ -1852,16 +1882,25 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
ggml_tensor* generate_init_latent(sd_ctx_t* sd_ctx, ggml_tensor* generate_init_latent(sd_ctx_t* sd_ctx,
ggml_context* work_ctx, ggml_context* work_ctx,
int width, int width,
int height) { int height,
int frames = 1,
bool video = false) {
int C = 4; int C = 4;
if (sd_version_is_sd3(sd_ctx->sd->version)) { if (sd_version_is_sd3(sd_ctx->sd->version)) {
C = 16; C = 16;
} else if (sd_version_is_flux(sd_ctx->sd->version)) { } else if (sd_version_is_flux(sd_ctx->sd->version)) {
C = 16; C = 16;
} else if (sd_version_is_wan(sd_ctx->sd->version)) {
C = 16;
} }
int W = width / 8; int W = width / 8;
int H = height / 8; int H = height / 8;
ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1); ggml_tensor* init_latent;
if (video) {
init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, frames, C);
} else {
init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
}
if (sd_version_is_sd3(sd_ctx->sd->version)) { if (sd_version_is_sd3(sd_ctx->sd->version)) {
ggml_set_f32(init_latent, 0.0609f); ggml_set_f32(init_latent, 0.0609f);
} else if (sd_version_is_flux(sd_ctx->sd->version)) { } else if (sd_version_is_flux(sd_ctx->sd->version)) {
@ -1877,11 +1916,17 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
int height = sd_img_gen_params->height; int height = sd_img_gen_params->height;
if (sd_version_is_dit(sd_ctx->sd->version)) { if (sd_version_is_dit(sd_ctx->sd->version)) {
if (width % 16 || height % 16) { if (width % 16 || height % 16) {
LOG_ERROR("Image dimensions must be must be a multiple of 16 on each axis for %s models. (Got %dx%d)", model_version_to_str[sd_ctx->sd->version], width, height); LOG_ERROR("Image dimensions must be must be a multiple of 16 on each axis for %s models. (Got %dx%d)",
model_version_to_str[sd_ctx->sd->version],
width,
height);
return NULL; return NULL;
} }
} else if (width % 64 || height % 64) { } else if (width % 64 || height % 64) {
LOG_ERROR("Image dimensions must be must be a multiple of 64 on each axis for %s models. (Got %dx%d)", model_version_to_str[sd_ctx->sd->version], width, height); LOG_ERROR("Image dimensions must be must be a multiple of 64 on each axis for %s models. (Got %dx%d)",
model_version_to_str[sd_ctx->sd->version],
width,
height);
return NULL; return NULL;
} }
LOG_DEBUG("generate_image %dx%d", width, height); LOG_DEBUG("generate_image %dx%d", width, height);
@ -2095,20 +2140,23 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
return NULL; return NULL;
} }
std::string prompt = SAFE_STR(sd_vid_gen_params->prompt);
std::string negative_prompt = SAFE_STR(sd_vid_gen_params->negative_prompt);
int width = sd_vid_gen_params->width; int width = sd_vid_gen_params->width;
int height = sd_vid_gen_params->height; int height = sd_vid_gen_params->height;
LOG_INFO("img2vid %dx%d", width, height); int frames = sd_vid_gen_params->video_frames;
LOG_INFO("img2vid %dx%dx%d", width, height, frames);
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sd_vid_gen_params->sample_steps); std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sd_vid_gen_params->sample_steps);
struct ggml_init_params params; struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024) * 1024; // 10 MB params.mem_size = static_cast<size_t>(100 * 1024) * 1024; // 50 MB
params.mem_size += width * height * 3 * sizeof(float) * sd_vid_gen_params->video_frames; params.mem_size += width * height * frames * 3 * sizeof(float);
params.mem_buffer = NULL; params.mem_buffer = NULL;
params.no_alloc = false; params.no_alloc = false;
// LOG_DEBUG("mem_size %u ", params.mem_size); // LOG_DEBUG("mem_size %u ", params.mem_size);
// draft context
struct ggml_context* work_ctx = ggml_init(params); struct ggml_context* work_ctx = ggml_init(params);
if (!work_ctx) { if (!work_ctx) {
LOG_ERROR("ggml_init() failed"); LOG_ERROR("ggml_init() failed");
@ -2124,90 +2172,100 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
int64_t t0 = ggml_time_ms(); int64_t t0 = ggml_time_ms();
SDCondition cond = sd_ctx->sd->get_svd_condition(work_ctx, ggml_tensor* init_latent = generate_init_latent(sd_ctx, work_ctx, width, height, frames, true);
sd_vid_gen_params->init_image, int sample_steps = sigmas.size() - 1;
// Apply lora
prompt = sd_ctx->sd->apply_loras_from_prompt(prompt);
// Get learned condition
bool zero_out_masked = true;
t0 = ggml_time_ms();
SDCondition cond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
sd_ctx->sd->n_threads,
prompt,
sd_vid_gen_params->clip_skip,
width, width,
height, height,
sd_vid_gen_params->fps, sd_ctx->sd->diffusion_model->get_adm_in_channels(),
sd_vid_gen_params->motion_bucket_id, zero_out_masked);
sd_vid_gen_params->augmentation_level); SDCondition uncond;
if (sd_vid_gen_params->guidance.txt_cfg != 1.0) {
auto uc_crossattn = ggml_dup_tensor(work_ctx, cond.c_crossattn); uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
ggml_set_f32(uc_crossattn, 0.f); sd_ctx->sd->n_threads,
negative_prompt,
auto uc_concat = ggml_dup_tensor(work_ctx, cond.c_concat); sd_vid_gen_params->clip_skip,
ggml_set_f32(uc_concat, 0.f); width,
height,
auto uc_vector = ggml_dup_tensor(work_ctx, cond.c_vector); sd_ctx->sd->diffusion_model->get_adm_in_channels(),
zero_out_masked);
SDCondition uncond = SDCondition(uc_crossattn, uc_vector, uc_concat); }
int64_t t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();
LOG_INFO("get_learned_condition completed, taking %" PRId64 " ms", t1 - t0); LOG_INFO("get_learned_condition completed, taking %" PRId64 " ms", t1 - t0);
if (sd_ctx->sd->free_params_immediately) { if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->clip_vision->free_params_buffer(); sd_ctx->sd->cond_stage_model->free_params_buffer();
} }
sd_ctx->sd->rng->manual_seed(seed);
int C = 4;
int W = width / 8; int W = width / 8;
int H = height / 8; int H = height / 8;
struct ggml_tensor* x_t = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, sd_vid_gen_params->video_frames); int T = frames;
ggml_set_f32(x_t, 0.f); int C = 16;
struct ggml_tensor* noise = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, sd_vid_gen_params->video_frames); struct ggml_tensor* final_latent;
// Sample
{
int64_t sampling_start = ggml_time_ms();
struct ggml_tensor* x_t = init_latent;
struct ggml_tensor* noise = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, T, C);
ggml_tensor_set_f32_randn(noise, sd_ctx->sd->rng); ggml_tensor_set_f32_randn(noise, sd_ctx->sd->rng);
LOG_INFO("sampling using %s method", sampling_methods_str[sd_vid_gen_params->sample_method]); final_latent = sd_ctx->sd->sample(work_ctx,
struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx,
x_t, x_t,
noise, noise,
cond, cond,
uncond, uncond,
{}, {},
{}, NULL,
0.f, 0,
sd_vid_gen_params->guidance, sd_vid_gen_params->guidance,
0.f, sd_vid_gen_params->eta,
sd_vid_gen_params->sample_method, sd_vid_gen_params->sample_method,
sigmas, sigmas,
-1, -1,
SDCondition(NULL, NULL, NULL)); {});
int64_t sampling_end = ggml_time_ms();
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
}
int64_t t2 = ggml_time_ms();
LOG_INFO("sampling completed, taking %.2fs", (t2 - t1) * 1.0f / 1000);
if (sd_ctx->sd->free_params_immediately) { if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->diffusion_model->free_params_buffer(); sd_ctx->sd->diffusion_model->free_params_buffer();
} }
struct ggml_tensor* img = sd_ctx->sd->decode_first_stage(work_ctx, x_0); int64_t t3 = ggml_time_ms();
LOG_INFO("generating latent video completed, taking %.2fs", (t3 - t1) * 1.0f / 1000);
struct ggml_tensor* vid = sd_ctx->sd->decode_first_stage(work_ctx, final_latent, true);
int64_t t4 = ggml_time_ms();
LOG_INFO("decode_first_stage completed, taking %.2fs", (t4 - t3) * 1.0f / 1000);
if (sd_ctx->sd->free_params_immediately) { if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->first_stage_model->free_params_buffer(); sd_ctx->sd->first_stage_model->free_params_buffer();
} }
if (img == NULL) {
ggml_free(work_ctx);
return NULL;
}
sd_image_t* result_images = (sd_image_t*)calloc(sd_vid_gen_params->video_frames, sizeof(sd_image_t)); sd_image_t* result_images = (sd_image_t*)calloc(T, sizeof(sd_image_t));
if (result_images == NULL) { if (result_images == NULL) {
ggml_free(work_ctx); ggml_free(work_ctx);
return NULL; return NULL;
} }
for (size_t i = 0; i < sd_vid_gen_params->video_frames; i++) { for (size_t i = 0; i < T; i++) {
auto img_i = ggml_view_3d(work_ctx, img, img->ne[0], img->ne[1], img->ne[2], img->nb[1], img->nb[2], img->nb[3] * i); result_images[i].width = final_latent->ne[0] * 8;
result_images[i].height = final_latent->ne[1] * 8;
result_images[i].width = width;
result_images[i].height = height;
result_images[i].channel = 3; result_images[i].channel = 3;
result_images[i].data = sd_tensor_to_image(img_i); result_images[i].data = sd_tensor_to_image(vid, i, true);
} }
ggml_free(work_ctx); ggml_free(work_ctx);
int64_t t3 = ggml_time_ms(); LOG_INFO("img2vid completed in %.2fs", (t4 - t0) * 1.0f / 1000);
LOG_INFO("img2vid completed in %.2fs", (t3 - t0) * 1.0f / 1000);
return result_images; return result_images;
} }

View File

@ -157,7 +157,6 @@ typedef struct {
typedef struct { typedef struct {
float txt_cfg; float txt_cfg;
float img_cfg; float img_cfg;
float min_cfg;
float distilled_guidance; float distilled_guidance;
sd_slg_params_t slg; sd_slg_params_t slg;
} sd_guidance_params_t; } sd_guidance_params_t;
@ -187,18 +186,19 @@ typedef struct {
} sd_img_gen_params_t; } sd_img_gen_params_t;
typedef struct { typedef struct {
const char* prompt;
const char* negative_prompt;
int clip_skip;
sd_guidance_params_t guidance;
sd_image_t init_image; sd_image_t init_image;
int width; int width;
int height; int height;
sd_guidance_params_t guidance;
enum sample_method_t sample_method; enum sample_method_t sample_method;
int sample_steps; int sample_steps;
float eta;
float strength; float strength;
int64_t seed; int64_t seed;
int video_frames; int video_frames;
int motion_bucket_id;
int fps;
float augmentation_level;
} sd_vid_gen_params_t; } sd_vid_gen_params_t;
typedef struct sd_ctx_t sd_ctx_t; typedef struct sd_ctx_t sd_ctx_t;

15
vae.hpp
View File

@ -520,7 +520,18 @@ public:
} }
}; };
struct AutoEncoderKL : public GGMLRunner { struct VAE : public GGMLRunner {
VAE(ggml_backend_t backend)
: GGMLRunner(backend) {}
virtual void compute(const int n_threads,
struct ggml_tensor* z,
bool decode_graph,
struct ggml_tensor** output,
struct ggml_context* output_ctx) = 0;
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) = 0;
};
struct AutoEncoderKL : public VAE {
bool decode_only = true; bool decode_only = true;
AutoencodingEngine ae; AutoencodingEngine ae;
@ -530,7 +541,7 @@ struct AutoEncoderKL : public GGMLRunner {
bool decode_only = false, bool decode_only = false,
bool use_video_decoder = false, bool use_video_decoder = false,
SDVersion version = VERSION_SD1) SDVersion version = VERSION_SD1)
: decode_only(decode_only), ae(decode_only, use_video_decoder, version), GGMLRunner(backend) { : decode_only(decode_only), ae(decode_only, use_video_decoder, version), VAE(backend) {
ae.init(params_ctx, tensor_types, prefix); ae.init(params_ctx, tensor_types, prefix);
} }

19
wan.hpp
View File

@ -7,6 +7,7 @@
#include "flux.hpp" #include "flux.hpp"
#include "ggml_extend.hpp" #include "ggml_extend.hpp"
#include "rope.hpp" #include "rope.hpp"
#include "vae.hpp"
namespace WAN { namespace WAN {
@ -522,7 +523,6 @@ namespace WAN {
for (int i = 0; i < dims.size() - 1; i++) { for (int i = 0; i < dims.size() - 1; i++) {
in_dim = dims[i]; in_dim = dims[i];
out_dim = dims[i + 1]; out_dim = dims[i + 1];
LOG_DEBUG("in_dim %u out_dim %u", in_dim, out_dim);
if (i == 1 || i == 2 || i == 3) { if (i == 1 || i == 2 || i == 3) {
in_dim = in_dim / 2; in_dim = in_dim / 2;
} }
@ -726,7 +726,7 @@ namespace WAN {
} }
}; };
struct WanVAERunner : public GGMLRunner { struct WanVAERunner : public VAE {
bool decode_only = true; bool decode_only = true;
WanVAE ae; WanVAE ae;
@ -734,7 +734,7 @@ namespace WAN {
const String2GGMLType& tensor_types = {}, const String2GGMLType& tensor_types = {},
const std::string prefix = "", const std::string prefix = "",
bool decode_only = false) bool decode_only = false)
: decode_only(decode_only), ae(decode_only), GGMLRunner(backend) { : decode_only(decode_only), ae(decode_only), VAE(backend) {
ae.init(params_ctx, tensor_types, prefix); ae.init(params_ctx, tensor_types, prefix);
} }
@ -1217,13 +1217,13 @@ namespace WAN {
int64_t axes_dim_sum = 128; int64_t axes_dim_sum = 128;
}; };
class WanModel : public GGMLBlock { class Wan : public GGMLBlock {
protected: protected:
WanParams params; WanParams params;
public: public:
WanModel() {} Wan() {}
WanModel(WanParams params) Wan(WanParams params)
: params(params) { : params(params) {
// patch_embedding // patch_embedding
blocks["patch_embedding"] = std::shared_ptr<GGMLBlock>(new Conv3d(params.in_dim, params.dim, params.patch_size, params.patch_size)); blocks["patch_embedding"] = std::shared_ptr<GGMLBlock>(new Conv3d(params.in_dim, params.dim, params.patch_size, params.patch_size));
@ -1418,14 +1418,15 @@ namespace WAN {
struct WanRunner : public GGMLRunner { struct WanRunner : public GGMLRunner {
public: public:
WanParams wan_params; WanParams wan_params;
WanModel wan; Wan wan;
std::vector<float> pe_vec; std::vector<float> pe_vec;
SDVersion version; SDVersion version;
WanRunner(ggml_backend_t backend, WanRunner(ggml_backend_t backend,
const String2GGMLType& tensor_types = {}, const String2GGMLType& tensor_types = {},
const std::string prefix = "", const std::string prefix = "",
SDVersion version = VERSION_WAN_2_1) SDVersion version = VERSION_WAN2,
bool flash_attn = false)
: GGMLRunner(backend) { : GGMLRunner(backend) {
wan_params.num_layers = 0; wan_params.num_layers = 0;
for (auto pair : tensor_types) { for (auto pair : tensor_types) {
@ -1476,7 +1477,7 @@ namespace WAN {
GGML_ABORT("invalid num_layers(%d) of wan", wan_params.num_layers); GGML_ABORT("invalid num_layers(%d) of wan", wan_params.num_layers);
} }
wan = WanModel(wan_params); wan = Wan(wan_params);
wan.init(params_ctx, tensor_types, prefix); wan.init(params_ctx, tensor_types, prefix);
} }