mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-13 05:48:56 +00:00
add wan2.1 t2i support
This commit is contained in:
parent
bace0a08c4
commit
1d9ccea41a
6
clip.hpp
6
clip.hpp
@ -179,9 +179,9 @@ public:
|
|||||||
|
|
||||||
auto it = encoder.find(utf8_to_utf32("img</w>"));
|
auto it = encoder.find(utf8_to_utf32("img</w>"));
|
||||||
if (it != encoder.end()) {
|
if (it != encoder.end()) {
|
||||||
LOG_DEBUG(" trigger word img already in vocab");
|
LOG_DEBUG("trigger word img already in vocab");
|
||||||
} else {
|
} else {
|
||||||
LOG_DEBUG(" trigger word img not in vocab yet");
|
LOG_DEBUG("trigger word img not in vocab yet");
|
||||||
}
|
}
|
||||||
|
|
||||||
int rank = 0;
|
int rank = 0;
|
||||||
@ -733,7 +733,7 @@ public:
|
|||||||
if (text_projection != NULL) {
|
if (text_projection != NULL) {
|
||||||
pooled = ggml_nn_linear(ctx, pooled, text_projection, NULL);
|
pooled = ggml_nn_linear(ctx, pooled, text_projection, NULL);
|
||||||
} else {
|
} else {
|
||||||
LOG_DEBUG("Missing text_projection matrix, assuming identity...");
|
LOG_DEBUG("identity projection");
|
||||||
}
|
}
|
||||||
return pooled; // [hidden_size, 1, 1]
|
return pooled; // [hidden_size, 1, 1]
|
||||||
}
|
}
|
||||||
|
|||||||
@ -22,7 +22,7 @@ struct Conditioner {
|
|||||||
int width,
|
int width,
|
||||||
int height,
|
int height,
|
||||||
int adm_in_channels = -1,
|
int adm_in_channels = -1,
|
||||||
bool force_zero_embeddings = false) = 0;
|
bool zero_out_masked = false) = 0;
|
||||||
virtual void alloc_params_buffer() = 0;
|
virtual void alloc_params_buffer() = 0;
|
||||||
virtual void free_params_buffer() = 0;
|
virtual void free_params_buffer() = 0;
|
||||||
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
|
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
|
||||||
@ -35,7 +35,7 @@ struct Conditioner {
|
|||||||
int height,
|
int height,
|
||||||
int num_input_imgs,
|
int num_input_imgs,
|
||||||
int adm_in_channels = -1,
|
int adm_in_channels = -1,
|
||||||
bool force_zero_embeddings = false) = 0;
|
bool zero_out_masked = false) = 0;
|
||||||
virtual std::string remove_trigger_from_prompt(ggml_context* work_ctx,
|
virtual std::string remove_trigger_from_prompt(ggml_context* work_ctx,
|
||||||
const std::string& prompt) = 0;
|
const std::string& prompt) = 0;
|
||||||
};
|
};
|
||||||
@ -410,7 +410,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
|||||||
int width,
|
int width,
|
||||||
int height,
|
int height,
|
||||||
int adm_in_channels = -1,
|
int adm_in_channels = -1,
|
||||||
bool force_zero_embeddings = false) {
|
bool zero_out_masked = false) {
|
||||||
set_clip_skip(clip_skip);
|
set_clip_skip(clip_skip);
|
||||||
int64_t t0 = ggml_time_ms();
|
int64_t t0 = ggml_time_ms();
|
||||||
struct ggml_tensor* hidden_states = NULL; // [N, n_token, hidden_size]
|
struct ggml_tensor* hidden_states = NULL; // [N, n_token, hidden_size]
|
||||||
@ -499,7 +499,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
|||||||
float new_mean = ggml_tensor_mean(result);
|
float new_mean = ggml_tensor_mean(result);
|
||||||
ggml_tensor_scale(result, (original_mean / new_mean));
|
ggml_tensor_scale(result, (original_mean / new_mean));
|
||||||
}
|
}
|
||||||
if (force_zero_embeddings) {
|
if (zero_out_masked) {
|
||||||
float* vec = (float*)result->data;
|
float* vec = (float*)result->data;
|
||||||
for (int i = 0; i < ggml_nelements(result); i++) {
|
for (int i = 0; i < ggml_nelements(result); i++) {
|
||||||
vec[i] = 0;
|
vec[i] = 0;
|
||||||
@ -563,7 +563,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
|||||||
int height,
|
int height,
|
||||||
int num_input_imgs,
|
int num_input_imgs,
|
||||||
int adm_in_channels = -1,
|
int adm_in_channels = -1,
|
||||||
bool force_zero_embeddings = false) {
|
bool zero_out_masked = false) {
|
||||||
auto image_tokens = convert_token_to_id(trigger_word);
|
auto image_tokens = convert_token_to_id(trigger_word);
|
||||||
// if(image_tokens.size() == 1){
|
// if(image_tokens.size() == 1){
|
||||||
// printf(" image token id is: %d \n", image_tokens[0]);
|
// printf(" image token id is: %d \n", image_tokens[0]);
|
||||||
@ -584,7 +584,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
|||||||
// for(int i = 0; i < clsm.size(); ++i)
|
// for(int i = 0; i < clsm.size(); ++i)
|
||||||
// printf("%d ", clsm[i]?1:0);
|
// printf("%d ", clsm[i]?1:0);
|
||||||
// printf("\n");
|
// printf("\n");
|
||||||
auto cond = get_learned_condition_common(work_ctx, n_threads, tokens, weights, clip_skip, width, height, adm_in_channels, force_zero_embeddings);
|
auto cond = get_learned_condition_common(work_ctx, n_threads, tokens, weights, clip_skip, width, height, adm_in_channels, zero_out_masked);
|
||||||
return std::make_tuple(cond, clsm);
|
return std::make_tuple(cond, clsm);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -607,11 +607,11 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
|
|||||||
int width,
|
int width,
|
||||||
int height,
|
int height,
|
||||||
int adm_in_channels = -1,
|
int adm_in_channels = -1,
|
||||||
bool force_zero_embeddings = false) {
|
bool zero_out_masked = false) {
|
||||||
auto tokens_and_weights = tokenize(text, true);
|
auto tokens_and_weights = tokenize(text, true);
|
||||||
std::vector<int>& tokens = tokens_and_weights.first;
|
std::vector<int>& tokens = tokens_and_weights.first;
|
||||||
std::vector<float>& weights = tokens_and_weights.second;
|
std::vector<float>& weights = tokens_and_weights.second;
|
||||||
return get_learned_condition_common(work_ctx, n_threads, tokens, weights, clip_skip, width, height, adm_in_channels, force_zero_embeddings);
|
return get_learned_condition_common(work_ctx, n_threads, tokens, weights, clip_skip, width, height, adm_in_channels, zero_out_masked);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -773,7 +773,7 @@ struct SD3CLIPEmbedder : public Conditioner {
|
|||||||
int n_threads,
|
int n_threads,
|
||||||
std::vector<std::pair<std::vector<int>, std::vector<float>>> token_and_weights,
|
std::vector<std::pair<std::vector<int>, std::vector<float>>> token_and_weights,
|
||||||
int clip_skip,
|
int clip_skip,
|
||||||
bool force_zero_embeddings = false) {
|
bool zero_out_masked = false) {
|
||||||
set_clip_skip(clip_skip);
|
set_clip_skip(clip_skip);
|
||||||
auto& clip_l_tokens = token_and_weights[0].first;
|
auto& clip_l_tokens = token_and_weights[0].first;
|
||||||
auto& clip_l_weights = token_and_weights[0].second;
|
auto& clip_l_weights = token_and_weights[0].second;
|
||||||
@ -952,7 +952,7 @@ struct SD3CLIPEmbedder : public Conditioner {
|
|||||||
|
|
||||||
int64_t t1 = ggml_time_ms();
|
int64_t t1 = ggml_time_ms();
|
||||||
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
|
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
|
||||||
if (force_zero_embeddings) {
|
if (zero_out_masked) {
|
||||||
float* vec = (float*)chunk_hidden_states->data;
|
float* vec = (float*)chunk_hidden_states->data;
|
||||||
for (int i = 0; i < ggml_nelements(chunk_hidden_states); i++) {
|
for (int i = 0; i < ggml_nelements(chunk_hidden_states); i++) {
|
||||||
vec[i] = 0;
|
vec[i] = 0;
|
||||||
@ -979,9 +979,9 @@ struct SD3CLIPEmbedder : public Conditioner {
|
|||||||
int width,
|
int width,
|
||||||
int height,
|
int height,
|
||||||
int adm_in_channels = -1,
|
int adm_in_channels = -1,
|
||||||
bool force_zero_embeddings = false) {
|
bool zero_out_masked = false) {
|
||||||
auto tokens_and_weights = tokenize(text, 77, true);
|
auto tokens_and_weights = tokenize(text, 77, true);
|
||||||
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings);
|
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
|
std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
|
||||||
@ -992,7 +992,7 @@ struct SD3CLIPEmbedder : public Conditioner {
|
|||||||
int height,
|
int height,
|
||||||
int num_input_imgs,
|
int num_input_imgs,
|
||||||
int adm_in_channels = -1,
|
int adm_in_channels = -1,
|
||||||
bool force_zero_embeddings = false) {
|
bool zero_out_masked = false) {
|
||||||
GGML_ASSERT(0 && "Not implemented yet!");
|
GGML_ASSERT(0 && "Not implemented yet!");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1101,7 +1101,7 @@ struct FluxCLIPEmbedder : public Conditioner {
|
|||||||
int n_threads,
|
int n_threads,
|
||||||
std::vector<std::pair<std::vector<int>, std::vector<float>>> token_and_weights,
|
std::vector<std::pair<std::vector<int>, std::vector<float>>> token_and_weights,
|
||||||
int clip_skip,
|
int clip_skip,
|
||||||
bool force_zero_embeddings = false) {
|
bool zero_out_masked = false) {
|
||||||
set_clip_skip(clip_skip);
|
set_clip_skip(clip_skip);
|
||||||
auto& clip_l_tokens = token_and_weights[0].first;
|
auto& clip_l_tokens = token_and_weights[0].first;
|
||||||
auto& clip_l_weights = token_and_weights[0].second;
|
auto& clip_l_weights = token_and_weights[0].second;
|
||||||
@ -1173,7 +1173,7 @@ struct FluxCLIPEmbedder : public Conditioner {
|
|||||||
|
|
||||||
int64_t t1 = ggml_time_ms();
|
int64_t t1 = ggml_time_ms();
|
||||||
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
|
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
|
||||||
if (force_zero_embeddings) {
|
if (zero_out_masked) {
|
||||||
float* vec = (float*)chunk_hidden_states->data;
|
float* vec = (float*)chunk_hidden_states->data;
|
||||||
for (int i = 0; i < ggml_nelements(chunk_hidden_states); i++) {
|
for (int i = 0; i < ggml_nelements(chunk_hidden_states); i++) {
|
||||||
vec[i] = 0;
|
vec[i] = 0;
|
||||||
@ -1200,9 +1200,9 @@ struct FluxCLIPEmbedder : public Conditioner {
|
|||||||
int width,
|
int width,
|
||||||
int height,
|
int height,
|
||||||
int adm_in_channels = -1,
|
int adm_in_channels = -1,
|
||||||
bool force_zero_embeddings = false) {
|
bool zero_out_masked = false) {
|
||||||
auto tokens_and_weights = tokenize(text, chunk_len, true);
|
auto tokens_and_weights = tokenize(text, chunk_len, true);
|
||||||
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings);
|
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
|
std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
|
||||||
@ -1213,7 +1213,7 @@ struct FluxCLIPEmbedder : public Conditioner {
|
|||||||
int height,
|
int height,
|
||||||
int num_input_imgs,
|
int num_input_imgs,
|
||||||
int adm_in_channels = -1,
|
int adm_in_channels = -1,
|
||||||
bool force_zero_embeddings = false) {
|
bool zero_out_masked = false) {
|
||||||
GGML_ASSERT(0 && "Not implemented yet!");
|
GGML_ASSERT(0 && "Not implemented yet!");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1229,6 +1229,7 @@ struct T5CLIPEmbedder : public Conditioner {
|
|||||||
size_t chunk_len = 512;
|
size_t chunk_len = 512;
|
||||||
bool use_mask = false;
|
bool use_mask = false;
|
||||||
int mask_pad = 1;
|
int mask_pad = 1;
|
||||||
|
bool is_umt5 = false;
|
||||||
|
|
||||||
T5CLIPEmbedder(ggml_backend_t backend,
|
T5CLIPEmbedder(ggml_backend_t backend,
|
||||||
const String2GGMLType& tensor_types = {},
|
const String2GGMLType& tensor_types = {},
|
||||||
@ -1318,7 +1319,7 @@ struct T5CLIPEmbedder : public Conditioner {
|
|||||||
int n_threads,
|
int n_threads,
|
||||||
std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> token_and_weights,
|
std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> token_and_weights,
|
||||||
int clip_skip,
|
int clip_skip,
|
||||||
bool force_zero_embeddings = false) {
|
bool zero_out_masked = false) {
|
||||||
auto& t5_tokens = std::get<0>(token_and_weights);
|
auto& t5_tokens = std::get<0>(token_and_weights);
|
||||||
auto& t5_weights = std::get<1>(token_and_weights);
|
auto& t5_weights = std::get<1>(token_and_weights);
|
||||||
auto& t5_attn_mask_vec = std::get<2>(token_and_weights);
|
auto& t5_attn_mask_vec = std::get<2>(token_and_weights);
|
||||||
@ -1326,8 +1327,8 @@ struct T5CLIPEmbedder : public Conditioner {
|
|||||||
int64_t t0 = ggml_time_ms();
|
int64_t t0 = ggml_time_ms();
|
||||||
struct ggml_tensor* hidden_states = NULL; // [N, n_token, 4096]
|
struct ggml_tensor* hidden_states = NULL; // [N, n_token, 4096]
|
||||||
struct ggml_tensor* chunk_hidden_states = NULL; // [n_token, 4096]
|
struct ggml_tensor* chunk_hidden_states = NULL; // [n_token, 4096]
|
||||||
struct ggml_tensor* pooled = NULL; // [768,]
|
struct ggml_tensor* pooled = NULL;
|
||||||
struct ggml_tensor* t5_attn_mask = vector_to_ggml_tensor(work_ctx, t5_attn_mask_vec); // [768,]
|
struct ggml_tensor* t5_attn_mask = vector_to_ggml_tensor(work_ctx, t5_attn_mask_vec); // [n_token]
|
||||||
|
|
||||||
std::vector<float> hidden_states_vec;
|
std::vector<float> hidden_states_vec;
|
||||||
|
|
||||||
@ -1368,10 +1369,16 @@ struct T5CLIPEmbedder : public Conditioner {
|
|||||||
|
|
||||||
int64_t t1 = ggml_time_ms();
|
int64_t t1 = ggml_time_ms();
|
||||||
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
|
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
|
||||||
if (force_zero_embeddings) {
|
if (zero_out_masked) {
|
||||||
float* vec = (float*)chunk_hidden_states->data;
|
auto tensor = chunk_hidden_states;
|
||||||
for (int i = 0; i < ggml_nelements(chunk_hidden_states); i++) {
|
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
|
||||||
vec[i] = 0;
|
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
|
||||||
|
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
|
||||||
|
if (chunk_mask[i1] < 0.f) {
|
||||||
|
ggml_tensor_set_f32(tensor, 0.f, i0, i1, i2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1380,16 +1387,12 @@ struct T5CLIPEmbedder : public Conditioner {
|
|||||||
((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
|
((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (hidden_states_vec.size() > 0) {
|
GGML_ASSERT(hidden_states_vec.size() > 0);
|
||||||
hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
|
hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
|
||||||
hidden_states = ggml_reshape_2d(work_ctx,
|
hidden_states = ggml_reshape_2d(work_ctx,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
chunk_hidden_states->ne[0],
|
chunk_hidden_states->ne[0],
|
||||||
ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
|
ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
|
||||||
} else {
|
|
||||||
hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 256);
|
|
||||||
ggml_set_f32(hidden_states, 0.f);
|
|
||||||
}
|
|
||||||
|
|
||||||
modify_mask_to_attend_padding(t5_attn_mask, ggml_nelements(t5_attn_mask), mask_pad);
|
modify_mask_to_attend_padding(t5_attn_mask, ggml_nelements(t5_attn_mask), mask_pad);
|
||||||
|
|
||||||
@ -1403,9 +1406,9 @@ struct T5CLIPEmbedder : public Conditioner {
|
|||||||
int width,
|
int width,
|
||||||
int height,
|
int height,
|
||||||
int adm_in_channels = -1,
|
int adm_in_channels = -1,
|
||||||
bool force_zero_embeddings = false) {
|
bool zero_out_masked = false) {
|
||||||
auto tokens_and_weights = tokenize(text, chunk_len, true);
|
auto tokens_and_weights = tokenize(text, chunk_len, true);
|
||||||
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings);
|
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
|
std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
|
||||||
@ -1416,7 +1419,7 @@ struct T5CLIPEmbedder : public Conditioner {
|
|||||||
int height,
|
int height,
|
||||||
int num_input_imgs,
|
int num_input_imgs,
|
||||||
int adm_in_channels = -1,
|
int adm_in_channels = -1,
|
||||||
bool force_zero_embeddings = false) {
|
bool zero_out_masked = false) {
|
||||||
GGML_ASSERT(0 && "Not implemented yet!");
|
GGML_ASSERT(0 && "Not implemented yet!");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -4,6 +4,7 @@
|
|||||||
#include "flux.hpp"
|
#include "flux.hpp"
|
||||||
#include "mmdit.hpp"
|
#include "mmdit.hpp"
|
||||||
#include "unet.hpp"
|
#include "unet.hpp"
|
||||||
|
#include "wan.hpp"
|
||||||
|
|
||||||
struct DiffusionModel {
|
struct DiffusionModel {
|
||||||
virtual void compute(int n_threads,
|
virtual void compute(int n_threads,
|
||||||
@ -184,4 +185,56 @@ struct FluxModel : public DiffusionModel {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct WanModel : public DiffusionModel {
|
||||||
|
WAN::WanRunner wan;
|
||||||
|
|
||||||
|
WanModel(ggml_backend_t backend,
|
||||||
|
const String2GGMLType& tensor_types = {},
|
||||||
|
SDVersion version = VERSION_FLUX,
|
||||||
|
bool flash_attn = false)
|
||||||
|
: wan(backend, tensor_types, "model.diffusion_model", version, flash_attn) {
|
||||||
|
}
|
||||||
|
|
||||||
|
void alloc_params_buffer() {
|
||||||
|
wan.alloc_params_buffer();
|
||||||
|
}
|
||||||
|
|
||||||
|
void free_params_buffer() {
|
||||||
|
wan.free_params_buffer();
|
||||||
|
}
|
||||||
|
|
||||||
|
void free_compute_buffer() {
|
||||||
|
wan.free_compute_buffer();
|
||||||
|
}
|
||||||
|
|
||||||
|
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
|
||||||
|
wan.get_param_tensors(tensors, "model.diffusion_model");
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t get_params_buffer_size() {
|
||||||
|
return wan.get_params_buffer_size();
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t get_adm_in_channels() {
|
||||||
|
return 768;
|
||||||
|
}
|
||||||
|
|
||||||
|
void compute(int n_threads,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
struct ggml_tensor* timesteps,
|
||||||
|
struct ggml_tensor* context,
|
||||||
|
struct ggml_tensor* c_concat,
|
||||||
|
struct ggml_tensor* y,
|
||||||
|
struct ggml_tensor* guidance,
|
||||||
|
std::vector<ggml_tensor*> ref_latents = {},
|
||||||
|
int num_video_frames = -1,
|
||||||
|
std::vector<struct ggml_tensor*> controls = {},
|
||||||
|
float control_strength = 0.f,
|
||||||
|
struct ggml_tensor** output = NULL,
|
||||||
|
struct ggml_context* output_ctx = NULL,
|
||||||
|
std::vector<int> skip_layers = std::vector<int>()) {
|
||||||
|
return wan.compute(n_threads, x, timesteps, context, NULL, NULL, output, output_ctx);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@ -24,11 +24,14 @@
|
|||||||
#define STB_IMAGE_RESIZE_STATIC
|
#define STB_IMAGE_RESIZE_STATIC
|
||||||
#include "stb_image_resize.h"
|
#include "stb_image_resize.h"
|
||||||
|
|
||||||
|
#if defined(_WIN32)
|
||||||
|
#define NOMINMAX
|
||||||
|
#include <windows.h>
|
||||||
|
#endif // _WIN32
|
||||||
|
|
||||||
#define SAFE_STR(s) ((s) ? (s) : "")
|
#define SAFE_STR(s) ((s) ? (s) : "")
|
||||||
#define BOOL_STR(b) ((b) ? "true" : "false")
|
#define BOOL_STR(b) ((b) ? "true" : "false")
|
||||||
|
|
||||||
#include "t5.hpp"
|
|
||||||
|
|
||||||
const char* modes_str[] = {
|
const char* modes_str[] = {
|
||||||
"img_gen",
|
"img_gen",
|
||||||
"vid_gen",
|
"vid_gen",
|
||||||
@ -69,7 +72,6 @@ struct SDParams {
|
|||||||
|
|
||||||
std::string prompt;
|
std::string prompt;
|
||||||
std::string negative_prompt;
|
std::string negative_prompt;
|
||||||
float min_cfg = 1.0f;
|
|
||||||
float cfg_scale = 7.0f;
|
float cfg_scale = 7.0f;
|
||||||
float img_cfg_scale = INFINITY;
|
float img_cfg_scale = INFINITY;
|
||||||
float guidance = 3.5f;
|
float guidance = 3.5f;
|
||||||
@ -80,10 +82,7 @@ struct SDParams {
|
|||||||
int height = 512;
|
int height = 512;
|
||||||
int batch_count = 1;
|
int batch_count = 1;
|
||||||
|
|
||||||
int video_frames = 6;
|
int video_frames = 1;
|
||||||
int motion_bucket_id = 127;
|
|
||||||
int fps = 6;
|
|
||||||
float augmentation_level = 0.f;
|
|
||||||
|
|
||||||
sample_method_t sample_method = EULER_A;
|
sample_method_t sample_method = EULER_A;
|
||||||
schedule_t schedule = DEFAULT;
|
schedule_t schedule = DEFAULT;
|
||||||
@ -147,7 +146,6 @@ void print_params(SDParams params) {
|
|||||||
printf(" strength(control): %.2f\n", params.control_strength);
|
printf(" strength(control): %.2f\n", params.control_strength);
|
||||||
printf(" prompt: %s\n", params.prompt.c_str());
|
printf(" prompt: %s\n", params.prompt.c_str());
|
||||||
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
|
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
|
||||||
printf(" min_cfg: %.2f\n", params.min_cfg);
|
|
||||||
printf(" cfg_scale: %.2f\n", params.cfg_scale);
|
printf(" cfg_scale: %.2f\n", params.cfg_scale);
|
||||||
printf(" img_cfg_scale: %.2f\n", params.img_cfg_scale);
|
printf(" img_cfg_scale: %.2f\n", params.img_cfg_scale);
|
||||||
printf(" slg_scale: %.2f\n", params.slg_scale);
|
printf(" slg_scale: %.2f\n", params.slg_scale);
|
||||||
@ -243,6 +241,42 @@ void print_usage(int argc, const char* argv[]) {
|
|||||||
printf(" -v, --verbose print extra info\n");
|
printf(" -v, --verbose print extra info\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if defined(_WIN32)
|
||||||
|
static std::string utf16_to_utf8(const std::wstring& wstr) {
|
||||||
|
if (wstr.empty())
|
||||||
|
return {};
|
||||||
|
int size_needed = WideCharToMultiByte(CP_UTF8, 0, wstr.data(), (int)wstr.size(),
|
||||||
|
nullptr, 0, nullptr, nullptr);
|
||||||
|
if (size_needed <= 0)
|
||||||
|
throw std::runtime_error("UTF-16 to UTF-8 conversion failed");
|
||||||
|
|
||||||
|
std::string utf8(size_needed, 0);
|
||||||
|
WideCharToMultiByte(CP_UTF8, 0, wstr.data(), (int)wstr.size(),
|
||||||
|
(char*)utf8.data(), size_needed, nullptr, nullptr);
|
||||||
|
return utf8;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string argv_to_utf8(int index, const char** argv) {
|
||||||
|
int argc;
|
||||||
|
wchar_t** argv_w = CommandLineToArgvW(GetCommandLineW(), &argc);
|
||||||
|
if (!argv_w)
|
||||||
|
throw std::runtime_error("Failed to parse command line");
|
||||||
|
|
||||||
|
std::string result;
|
||||||
|
if (index < argc) {
|
||||||
|
result = utf16_to_utf8(argv_w[index]);
|
||||||
|
}
|
||||||
|
LocalFree(argv_w);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
#else // Linux / macOS
|
||||||
|
static std::string argv_to_utf8(int index, const char** argv) {
|
||||||
|
return std::string(argv[index]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
struct StringOption {
|
struct StringOption {
|
||||||
std::string short_name;
|
std::string short_name;
|
||||||
std::string long_name;
|
std::string long_name;
|
||||||
@ -299,7 +333,7 @@ bool parse_options(int argc, const char** argv, ArgOptions& options) {
|
|||||||
invalid_arg = true;
|
invalid_arg = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
*option.target = std::string(argv[i]);
|
*option.target = argv_to_utf8(i, argv);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (invalid_arg) {
|
if (invalid_arg) {
|
||||||
@ -746,17 +780,9 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
|
|||||||
|
|
||||||
int main(int argc, const char* argv[]) {
|
int main(int argc, const char* argv[]) {
|
||||||
SDParams params;
|
SDParams params;
|
||||||
// params.verbose = true;
|
|
||||||
// sd_set_log_callback(sd_log_cb, (void*)¶ms);
|
|
||||||
|
|
||||||
// T5Embedder::load_from_file_and_test(argv[1]);
|
|
||||||
// return 0;
|
|
||||||
|
|
||||||
parse_args(argc, argv, params);
|
parse_args(argc, argv, params);
|
||||||
|
|
||||||
sd_guidance_params_t guidance_params = {params.cfg_scale,
|
sd_guidance_params_t guidance_params = {params.cfg_scale,
|
||||||
params.img_cfg_scale,
|
params.img_cfg_scale,
|
||||||
params.min_cfg,
|
|
||||||
params.guidance,
|
params.guidance,
|
||||||
{
|
{
|
||||||
params.skip_layers.data(),
|
params.skip_layers.data(),
|
||||||
@ -791,11 +817,6 @@ int main(int argc, const char* argv[]) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.mode == VID_GEN) {
|
|
||||||
fprintf(stderr, "SVD support is broken, do not use it!!!\n");
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool vae_decode_only = true;
|
bool vae_decode_only = true;
|
||||||
uint8_t* input_image_buffer = NULL;
|
uint8_t* input_image_buffer = NULL;
|
||||||
uint8_t* control_image_buffer = NULL;
|
uint8_t* control_image_buffer = NULL;
|
||||||
@ -992,18 +1013,19 @@ int main(int argc, const char* argv[]) {
|
|||||||
expected_num_results = params.batch_count;
|
expected_num_results = params.batch_count;
|
||||||
} else if (params.mode == VID_GEN) {
|
} else if (params.mode == VID_GEN) {
|
||||||
sd_vid_gen_params_t vid_gen_params = {
|
sd_vid_gen_params_t vid_gen_params = {
|
||||||
|
params.prompt.c_str(),
|
||||||
|
params.negative_prompt.c_str(),
|
||||||
|
params.clip_skip,
|
||||||
|
guidance_params,
|
||||||
input_image,
|
input_image,
|
||||||
params.width,
|
params.width,
|
||||||
params.height,
|
params.height,
|
||||||
guidance_params,
|
|
||||||
params.sample_method,
|
params.sample_method,
|
||||||
params.sample_steps,
|
params.sample_steps,
|
||||||
|
params.eta,
|
||||||
params.strength,
|
params.strength,
|
||||||
params.seed,
|
params.seed,
|
||||||
params.video_frames,
|
params.video_frames,
|
||||||
params.motion_bucket_id,
|
|
||||||
params.fps,
|
|
||||||
params.augmentation_level,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
results = generate_video(sd_ctx, &vid_gen_params);
|
results = generate_video(sd_ctx, &vid_gen_params);
|
||||||
|
|||||||
@ -323,17 +323,27 @@ __STATIC_INLINE__ uint8_t* sd_tensor_to_image(struct ggml_tensor* input) {
|
|||||||
return image_data;
|
return image_data;
|
||||||
}
|
}
|
||||||
|
|
||||||
__STATIC_INLINE__ uint8_t* sd_tensor_to_mul_image(struct ggml_tensor* input, int idx) {
|
__STATIC_INLINE__ uint8_t* sd_tensor_to_image(struct ggml_tensor* input, int idx, bool video = false) {
|
||||||
int64_t width = input->ne[0];
|
int64_t width = input->ne[0];
|
||||||
int64_t height = input->ne[1];
|
int64_t height = input->ne[1];
|
||||||
int64_t channels = input->ne[2];
|
int64_t channels;
|
||||||
|
if (video) {
|
||||||
|
channels = input->ne[3];
|
||||||
|
} else {
|
||||||
|
channels = input->ne[2];
|
||||||
|
}
|
||||||
GGML_ASSERT(channels == 3 && input->type == GGML_TYPE_F32);
|
GGML_ASSERT(channels == 3 && input->type == GGML_TYPE_F32);
|
||||||
uint8_t* image_data = (uint8_t*)malloc(width * height * channels);
|
uint8_t* image_data = (uint8_t*)malloc(width * height * channels);
|
||||||
for (int iy = 0; iy < height; iy++) {
|
for (int ih = 0; ih < height; ih++) {
|
||||||
for (int ix = 0; ix < width; ix++) {
|
for (int iw = 0; iw < width; iw++) {
|
||||||
for (int k = 0; k < channels; k++) {
|
for (int ic = 0; ic < channels; ic++) {
|
||||||
float value = ggml_tensor_get_f32(input, ix, iy, k, idx);
|
float value;
|
||||||
*(image_data + iy * width * channels + ix * channels + k) = (uint8_t)(value * 255.0f);
|
if (video) {
|
||||||
|
value = ggml_tensor_get_f32(input, iw, ih, idx, ic);
|
||||||
|
} else {
|
||||||
|
value = ggml_tensor_get_f32(input, iw, ih, ic, idx);
|
||||||
|
}
|
||||||
|
*(image_data + ih * width * channels + iw * channels + ic) = (uint8_t)(value * 255.0f);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
27
model.cpp
27
model.cpp
@ -1055,7 +1055,11 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s
|
|||||||
|
|
||||||
// LOG_DEBUG("%s", name.c_str());
|
// LOG_DEBUG("%s", name.c_str());
|
||||||
|
|
||||||
TensorStorage tensor_storage(prefix + name, dummy->type, dummy->ne, ggml_n_dims(dummy), file_index, offset);
|
if (!starts_with(name, prefix)) {
|
||||||
|
name = prefix + name;
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorStorage tensor_storage(name, dummy->type, dummy->ne, ggml_n_dims(dummy), file_index, offset);
|
||||||
|
|
||||||
GGML_ASSERT(ggml_nbytes(dummy) == tensor_storage.nbytes());
|
GGML_ASSERT(ggml_nbytes(dummy) == tensor_storage.nbytes());
|
||||||
|
|
||||||
@ -1195,7 +1199,11 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
|
|||||||
n_dims = 1;
|
n_dims = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
TensorStorage tensor_storage(prefix + name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
|
if (!starts_with(name, prefix)) {
|
||||||
|
name = prefix + name;
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorStorage tensor_storage(name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
|
||||||
tensor_storage.reverse_ne();
|
tensor_storage.reverse_ne();
|
||||||
|
|
||||||
size_t tensor_data_size = end - begin;
|
size_t tensor_data_size = end - begin;
|
||||||
@ -1580,7 +1588,11 @@ bool ModelLoader::parse_data_pkl(uint8_t* buffer,
|
|||||||
reader.tensor_storage.file_index = file_index;
|
reader.tensor_storage.file_index = file_index;
|
||||||
// if(strcmp(prefix.c_str(), "scarlett") == 0)
|
// if(strcmp(prefix.c_str(), "scarlett") == 0)
|
||||||
// printf(" ZIP got tensor %s \n ", reader.tensor_storage.name.c_str());
|
// printf(" ZIP got tensor %s \n ", reader.tensor_storage.name.c_str());
|
||||||
reader.tensor_storage.name = prefix + reader.tensor_storage.name;
|
std::string name = reader.tensor_storage.name;
|
||||||
|
if (!starts_with(name, prefix)) {
|
||||||
|
name = prefix + name;
|
||||||
|
}
|
||||||
|
reader.tensor_storage.name = name;
|
||||||
tensor_storages.push_back(reader.tensor_storage);
|
tensor_storages.push_back(reader.tensor_storage);
|
||||||
add_preprocess_tensor_storage_types(tensor_storages_types, reader.tensor_storage.name, reader.tensor_storage.type);
|
add_preprocess_tensor_storage_types(tensor_storages_types, reader.tensor_storage.name, reader.tensor_storage.type);
|
||||||
|
|
||||||
@ -1654,10 +1666,10 @@ SDVersion ModelLoader::get_sd_version() {
|
|||||||
|
|
||||||
bool is_xl = false;
|
bool is_xl = false;
|
||||||
bool is_flux = false;
|
bool is_flux = false;
|
||||||
|
bool is_wan = false;
|
||||||
|
|
||||||
#define found_family (is_xl || is_flux)
|
|
||||||
for (auto& tensor_storage : tensor_storages) {
|
for (auto& tensor_storage : tensor_storages) {
|
||||||
if (!found_family) {
|
if (!(is_xl || is_flux)) {
|
||||||
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
|
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
|
||||||
is_flux = true;
|
is_flux = true;
|
||||||
if (input_block_checked) {
|
if (input_block_checked) {
|
||||||
@ -1667,6 +1679,9 @@ SDVersion ModelLoader::get_sd_version() {
|
|||||||
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
|
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
|
||||||
return VERSION_SD3;
|
return VERSION_SD3;
|
||||||
}
|
}
|
||||||
|
if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) {
|
||||||
|
return VERSION_WAN2;
|
||||||
|
}
|
||||||
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || tensor_storage.name.find("unet.down_blocks.") != std::string::npos) {
|
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || tensor_storage.name.find("unet.down_blocks.") != std::string::npos) {
|
||||||
is_unet = true;
|
is_unet = true;
|
||||||
if (has_multiple_encoders) {
|
if (has_multiple_encoders) {
|
||||||
@ -1701,7 +1716,7 @@ SDVersion ModelLoader::get_sd_version() {
|
|||||||
if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight" || tensor_storage.name == "unet.conv_in.weight") {
|
if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight" || tensor_storage.name == "unet.conv_in.weight") {
|
||||||
input_block_weight = tensor_storage;
|
input_block_weight = tensor_storage;
|
||||||
input_block_checked = true;
|
input_block_checked = true;
|
||||||
if (found_family) {
|
if (is_xl || is_flux) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
5
model.h
5
model.h
@ -31,8 +31,7 @@ enum SDVersion {
|
|||||||
VERSION_SD3,
|
VERSION_SD3,
|
||||||
VERSION_FLUX,
|
VERSION_FLUX,
|
||||||
VERSION_FLUX_FILL,
|
VERSION_FLUX_FILL,
|
||||||
VERSION_WAN_2_1,
|
VERSION_WAN2,
|
||||||
VERSION_WAN_2_2,
|
|
||||||
VERSION_COUNT,
|
VERSION_COUNT,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -72,7 +71,7 @@ static inline bool sd_version_is_flux(SDVersion version) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static inline bool sd_version_is_wan(SDVersion version) {
|
static inline bool sd_version_is_wan(SDVersion version) {
|
||||||
if (version == VERSION_WAN_2_1 || version == VERSION_WAN_2_2) {
|
if (version == VERSION_WAN2) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
|||||||
@ -36,7 +36,9 @@ const char* model_version_to_str[] = {
|
|||||||
"SVD",
|
"SVD",
|
||||||
"SD3.x",
|
"SD3.x",
|
||||||
"Flux",
|
"Flux",
|
||||||
"Flux Fill"};
|
"Flux Fill",
|
||||||
|
"Wan 2.x",
|
||||||
|
};
|
||||||
|
|
||||||
const char* sampling_methods_str[] = {
|
const char* sampling_methods_str[] = {
|
||||||
"Euler A",
|
"Euler A",
|
||||||
@ -50,7 +52,8 @@ const char* sampling_methods_str[] = {
|
|||||||
"iPNDM_v",
|
"iPNDM_v",
|
||||||
"LCM",
|
"LCM",
|
||||||
"DDIM \"trailing\"",
|
"DDIM \"trailing\"",
|
||||||
"TCD"};
|
"TCD",
|
||||||
|
};
|
||||||
|
|
||||||
/*================================================== Helper Functions ================================================*/
|
/*================================================== Helper Functions ================================================*/
|
||||||
|
|
||||||
@ -93,7 +96,7 @@ public:
|
|||||||
std::shared_ptr<Conditioner> cond_stage_model;
|
std::shared_ptr<Conditioner> cond_stage_model;
|
||||||
std::shared_ptr<FrozenCLIPVisionEmbedder> clip_vision; // for svd
|
std::shared_ptr<FrozenCLIPVisionEmbedder> clip_vision; // for svd
|
||||||
std::shared_ptr<DiffusionModel> diffusion_model;
|
std::shared_ptr<DiffusionModel> diffusion_model;
|
||||||
std::shared_ptr<AutoEncoderKL> first_stage_model;
|
std::shared_ptr<VAE> first_stage_model;
|
||||||
std::shared_ptr<TinyAutoEncoder> tae_first_stage;
|
std::shared_ptr<TinyAutoEncoder> tae_first_stage;
|
||||||
std::shared_ptr<ControlNet> control_net;
|
std::shared_ptr<ControlNet> control_net;
|
||||||
std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
|
std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
|
||||||
@ -274,10 +277,10 @@ public:
|
|||||||
model_loader.set_wtype_override(GGML_TYPE_F32, "vae.");
|
model_loader.set_wtype_override(GGML_TYPE_F32, "vae.");
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG_INFO("Weight type: %s", model_wtype != GGML_TYPE_COUNT ? ggml_type_name(model_wtype) : "??");
|
LOG_INFO("Weight type: %s", ggml_type_name(model_wtype));
|
||||||
LOG_INFO("Conditioner weight type: %s", conditioner_wtype != GGML_TYPE_COUNT ? ggml_type_name(conditioner_wtype) : "??");
|
LOG_INFO("Conditioner weight type: %s", ggml_type_name(conditioner_wtype));
|
||||||
LOG_INFO("Diffusion model weight type: %s", diffusion_model_wtype != GGML_TYPE_COUNT ? ggml_type_name(diffusion_model_wtype) : "??");
|
LOG_INFO("Diffusion model weight type: %s", ggml_type_name(diffusion_model_wtype));
|
||||||
LOG_INFO("VAE weight type: %s", vae_wtype != GGML_TYPE_COUNT ? ggml_type_name(vae_wtype) : "??");
|
LOG_INFO("VAE weight type: %s", ggml_type_name(vae_wtype));
|
||||||
|
|
||||||
LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor));
|
LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor));
|
||||||
|
|
||||||
@ -293,34 +296,25 @@ public:
|
|||||||
} else if (sd_version_is_sd3(version)) {
|
} else if (sd_version_is_sd3(version)) {
|
||||||
scale_factor = 1.5305f;
|
scale_factor = 1.5305f;
|
||||||
} else if (sd_version_is_flux(version)) {
|
} else if (sd_version_is_flux(version)) {
|
||||||
scale_factor = 0.3611;
|
scale_factor = 0.3611f;
|
||||||
// TODO: shift_factor
|
// TODO: shift_factor
|
||||||
|
} else if (sd_version_is_wan(version)) {
|
||||||
|
scale_factor = 1.0f;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool clip_on_cpu = sd_ctx_params->keep_clip_on_cpu;
|
bool clip_on_cpu = sd_ctx_params->keep_clip_on_cpu;
|
||||||
|
|
||||||
if (version == VERSION_SVD) {
|
{
|
||||||
clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend, model_loader.tensor_storages_types);
|
|
||||||
clip_vision->alloc_params_buffer();
|
|
||||||
clip_vision->get_param_tensors(tensors);
|
|
||||||
|
|
||||||
diffusion_model = std::make_shared<UNetModel>(backend, model_loader.tensor_storages_types, version);
|
|
||||||
diffusion_model->alloc_params_buffer();
|
|
||||||
diffusion_model->get_param_tensors(tensors);
|
|
||||||
|
|
||||||
first_stage_model = std::make_shared<AutoEncoderKL>(backend, model_loader.tensor_storages_types, "first_stage_model", vae_decode_only, true, version);
|
|
||||||
LOG_DEBUG("vae_decode_only %d", vae_decode_only);
|
|
||||||
first_stage_model->alloc_params_buffer();
|
|
||||||
first_stage_model->get_param_tensors(tensors, "first_stage_model");
|
|
||||||
} else {
|
|
||||||
clip_backend = backend;
|
clip_backend = backend;
|
||||||
bool use_t5xxl = false;
|
bool use_t5xxl = false;
|
||||||
if (sd_version_is_dit(version)) {
|
if (sd_version_is_dit(version)) {
|
||||||
use_t5xxl = true;
|
use_t5xxl = true;
|
||||||
}
|
}
|
||||||
if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) {
|
if (!ggml_backend_is_cpu(backend) && use_t5xxl) {
|
||||||
clip_on_cpu = true;
|
LOG_WARN(
|
||||||
LOG_INFO("set clip_on_cpu to true");
|
"!!!It appears that you are using the T5 model. Some backends may encounter issues with it."
|
||||||
|
"If you notice that the generated images are completely black,"
|
||||||
|
"try running the T5 model on the CPU using the --clip-on-cpu parameter.");
|
||||||
}
|
}
|
||||||
if (clip_on_cpu && !ggml_backend_is_cpu(backend)) {
|
if (clip_on_cpu && !ggml_backend_is_cpu(backend)) {
|
||||||
LOG_INFO("CLIP: Using CPU backend");
|
LOG_INFO("CLIP: Using CPU backend");
|
||||||
@ -357,7 +351,18 @@ public:
|
|||||||
version,
|
version,
|
||||||
sd_ctx_params->diffusion_flash_attn,
|
sd_ctx_params->diffusion_flash_attn,
|
||||||
sd_ctx_params->chroma_use_dit_mask);
|
sd_ctx_params->chroma_use_dit_mask);
|
||||||
} else {
|
} else if (sd_version_is_wan(version)) {
|
||||||
|
cond_stage_model = std::make_shared<T5CLIPEmbedder>(clip_backend,
|
||||||
|
model_loader.tensor_storages_types,
|
||||||
|
-1,
|
||||||
|
true,
|
||||||
|
1,
|
||||||
|
true);
|
||||||
|
diffusion_model = std::make_shared<WanModel>(backend,
|
||||||
|
model_loader.tensor_storages_types,
|
||||||
|
version,
|
||||||
|
sd_ctx_params->diffusion_flash_attn);
|
||||||
|
} else { // SD1.x SD2.x SDXL
|
||||||
if (strstr(SAFE_STR(sd_ctx_params->stacked_id_embed_dir), "v2")) {
|
if (strstr(SAFE_STR(sd_ctx_params->stacked_id_embed_dir), "v2")) {
|
||||||
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
|
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
|
||||||
model_loader.tensor_storages_types,
|
model_loader.tensor_storages_types,
|
||||||
@ -382,13 +387,21 @@ public:
|
|||||||
diffusion_model->alloc_params_buffer();
|
diffusion_model->alloc_params_buffer();
|
||||||
diffusion_model->get_param_tensors(tensors);
|
diffusion_model->get_param_tensors(tensors);
|
||||||
|
|
||||||
if (!use_tiny_autoencoder) {
|
|
||||||
if (sd_ctx_params->keep_vae_on_cpu && !ggml_backend_is_cpu(backend)) {
|
if (sd_ctx_params->keep_vae_on_cpu && !ggml_backend_is_cpu(backend)) {
|
||||||
LOG_INFO("VAE Autoencoder: Using CPU backend");
|
LOG_INFO("VAE Autoencoder: Using CPU backend");
|
||||||
vae_backend = ggml_backend_cpu_init();
|
vae_backend = ggml_backend_cpu_init();
|
||||||
} else {
|
} else {
|
||||||
vae_backend = backend;
|
vae_backend = backend;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (sd_version_is_wan(version)) {
|
||||||
|
first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
|
||||||
|
model_loader.tensor_storages_types,
|
||||||
|
"first_stage_model",
|
||||||
|
vae_decode_only);
|
||||||
|
first_stage_model->alloc_params_buffer();
|
||||||
|
first_stage_model->get_param_tensors(tensors, "first_stage_model");
|
||||||
|
} else if (!use_tiny_autoencoder) {
|
||||||
first_stage_model = std::make_shared<AutoEncoderKL>(vae_backend,
|
first_stage_model = std::make_shared<AutoEncoderKL>(vae_backend,
|
||||||
model_loader.tensor_storages_types,
|
model_loader.tensor_storages_types,
|
||||||
"first_stage_model",
|
"first_stage_model",
|
||||||
@ -398,7 +411,7 @@ public:
|
|||||||
first_stage_model->alloc_params_buffer();
|
first_stage_model->alloc_params_buffer();
|
||||||
first_stage_model->get_param_tensors(tensors, "first_stage_model");
|
first_stage_model->get_param_tensors(tensors, "first_stage_model");
|
||||||
} else {
|
} else {
|
||||||
tae_first_stage = std::make_shared<TinyAutoEncoder>(backend,
|
tae_first_stage = std::make_shared<TinyAutoEncoder>(vae_backend,
|
||||||
model_loader.tensor_storages_types,
|
model_loader.tensor_storages_types,
|
||||||
"decoder.layers",
|
"decoder.layers",
|
||||||
vae_decode_only,
|
vae_decode_only,
|
||||||
@ -485,11 +498,7 @@ public:
|
|||||||
|
|
||||||
// LOG_DEBUG("model size = %.2fMB", total_size / 1024.0 / 1024.0);
|
// LOG_DEBUG("model size = %.2fMB", total_size / 1024.0 / 1024.0);
|
||||||
|
|
||||||
if (version == VERSION_SVD) {
|
{
|
||||||
// diffusion_model->test();
|
|
||||||
// first_stage_model->test();
|
|
||||||
// return false;
|
|
||||||
} else {
|
|
||||||
size_t clip_params_mem_size = cond_stage_model->get_params_buffer_size();
|
size_t clip_params_mem_size = cond_stage_model->get_params_buffer_size();
|
||||||
size_t unet_params_mem_size = diffusion_model->get_params_buffer_size();
|
size_t unet_params_mem_size = diffusion_model->get_params_buffer_size();
|
||||||
size_t vae_params_mem_size = 0;
|
size_t vae_params_mem_size = 0;
|
||||||
@ -594,6 +603,9 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
denoiser = std::make_shared<FluxFlowDenoiser>(shift);
|
denoiser = std::make_shared<FluxFlowDenoiser>(shift);
|
||||||
|
} else if (sd_version_is_wan(version)) {
|
||||||
|
LOG_INFO("running in FLOW mode");
|
||||||
|
denoiser = std::make_shared<DiscreteFlowDenoiser>();
|
||||||
} else if (is_using_v_parameterization) {
|
} else if (is_using_v_parameterization) {
|
||||||
LOG_INFO("running in v-prediction mode");
|
LOG_INFO("running in v-prediction mode");
|
||||||
denoiser = std::make_shared<CompVisVDenoiser>();
|
denoiser = std::make_shared<CompVisVDenoiser>();
|
||||||
@ -733,9 +745,9 @@ public:
|
|||||||
|
|
||||||
size_t rm = lora_state_diff.size() - lora_state.size();
|
size_t rm = lora_state_diff.size() - lora_state.size();
|
||||||
if (rm != 0) {
|
if (rm != 0) {
|
||||||
LOG_INFO("Attempting to apply %lu LoRAs (removing %lu applied LoRAs)", lora_state.size(), rm);
|
LOG_INFO("attempting to apply %lu LoRAs (removing %lu applied LoRAs)", lora_state.size(), rm);
|
||||||
} else {
|
} else {
|
||||||
LOG_INFO("Attempting to apply %lu LoRAs", lora_state.size());
|
LOG_INFO("attempting to apply %lu LoRAs", lora_state.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto& kv : lora_state_diff) {
|
for (auto& kv : lora_state_diff) {
|
||||||
@ -745,6 +757,21 @@ public:
|
|||||||
curr_lora_state = lora_state;
|
curr_lora_state = lora_state;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string apply_loras_from_prompt(const std::string& prompt) {
|
||||||
|
auto result_pair = extract_and_remove_lora(prompt);
|
||||||
|
std::unordered_map<std::string, float> lora_f2m = result_pair.first; // lora_name -> multiplier
|
||||||
|
|
||||||
|
for (auto& kv : lora_f2m) {
|
||||||
|
LOG_DEBUG("lora %s:%.2f", kv.first.c_str(), kv.second);
|
||||||
|
}
|
||||||
|
int64_t t0 = ggml_time_ms();
|
||||||
|
apply_loras(lora_f2m);
|
||||||
|
int64_t t1 = ggml_time_ms();
|
||||||
|
LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
|
||||||
|
LOG_DEBUG("prompt after extract and remove lora: \"%s\"", result_pair.second.c_str());
|
||||||
|
return result_pair.second;
|
||||||
|
}
|
||||||
|
|
||||||
ggml_tensor* id_encoder(ggml_context* work_ctx,
|
ggml_tensor* id_encoder(ggml_context* work_ctx,
|
||||||
ggml_tensor* init_img,
|
ggml_tensor* init_img,
|
||||||
ggml_tensor* prompts_embeds,
|
ggml_tensor* prompts_embeds,
|
||||||
@ -762,12 +789,12 @@ public:
|
|||||||
int fps = 6,
|
int fps = 6,
|
||||||
int motion_bucket_id = 127,
|
int motion_bucket_id = 127,
|
||||||
float augmentation_level = 0.f,
|
float augmentation_level = 0.f,
|
||||||
bool force_zero_embeddings = false) {
|
bool zero_out_masked = false) {
|
||||||
// c_crossattn
|
// c_crossattn
|
||||||
int64_t t0 = ggml_time_ms();
|
int64_t t0 = ggml_time_ms();
|
||||||
struct ggml_tensor* c_crossattn = NULL;
|
struct ggml_tensor* c_crossattn = NULL;
|
||||||
{
|
{
|
||||||
if (force_zero_embeddings) {
|
if (zero_out_masked) {
|
||||||
c_crossattn = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, clip_vision->vision_model.projection_dim);
|
c_crossattn = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, clip_vision->vision_model.projection_dim);
|
||||||
ggml_set_f32(c_crossattn, 0.f);
|
ggml_set_f32(c_crossattn, 0.f);
|
||||||
} else {
|
} else {
|
||||||
@ -790,7 +817,7 @@ public:
|
|||||||
// c_concat
|
// c_concat
|
||||||
struct ggml_tensor* c_concat = NULL;
|
struct ggml_tensor* c_concat = NULL;
|
||||||
{
|
{
|
||||||
if (force_zero_embeddings) {
|
if (zero_out_masked) {
|
||||||
c_concat = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 4, 1);
|
c_concat = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 4, 1);
|
||||||
ggml_set_f32(c_concat, 0.f);
|
ggml_set_f32(c_concat, 0.f);
|
||||||
} else {
|
} else {
|
||||||
@ -855,28 +882,14 @@ public:
|
|||||||
float img_cfg_scale = guidance.img_cfg;
|
float img_cfg_scale = guidance.img_cfg;
|
||||||
float slg_scale = guidance.slg.scale;
|
float slg_scale = guidance.slg.scale;
|
||||||
|
|
||||||
float min_cfg = guidance.min_cfg;
|
LOG_DEBUG("cfg_scale %.2f", cfg_scale);
|
||||||
|
|
||||||
if (img_cfg_scale != cfg_scale && !sd_version_is_inpaint_or_unet_edit(version)) {
|
if (img_cfg_scale != cfg_scale && !sd_version_is_inpaint_or_unet_edit(version)) {
|
||||||
LOG_WARN("2-conditioning CFG is not supported with this model, disabling it for better performance...");
|
LOG_WARN("2-conditioning CFG is not supported with this model, disabling it for better performance...");
|
||||||
img_cfg_scale = cfg_scale;
|
img_cfg_scale = cfg_scale;
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG_DEBUG("Sample");
|
|
||||||
struct ggml_init_params params;
|
|
||||||
size_t data_size = ggml_row_size(init_latent->type, init_latent->ne[0]);
|
|
||||||
for (int i = 1; i < 4; i++) {
|
|
||||||
data_size *= init_latent->ne[i];
|
|
||||||
}
|
|
||||||
data_size += 1024;
|
|
||||||
params.mem_size = data_size * 3;
|
|
||||||
params.mem_buffer = NULL;
|
|
||||||
params.no_alloc = false;
|
|
||||||
ggml_context* tmp_ctx = ggml_init(params);
|
|
||||||
|
|
||||||
size_t steps = sigmas.size() - 1;
|
size_t steps = sigmas.size() - 1;
|
||||||
// noise = load_tensor_from_file(work_ctx, "./rand0.bin");
|
|
||||||
// print_ggml_tensor(noise);
|
|
||||||
struct ggml_tensor* x = ggml_dup_tensor(work_ctx, init_latent);
|
struct ggml_tensor* x = ggml_dup_tensor(work_ctx, init_latent);
|
||||||
copy_ggml_tensor(x, init_latent);
|
copy_ggml_tensor(x, init_latent);
|
||||||
x = denoiser->noise_scaling(sigmas[0], noise, x);
|
x = denoiser->noise_scaling(sigmas[0], noise, x);
|
||||||
@ -922,9 +935,9 @@ public:
|
|||||||
float c_in = scaling[2];
|
float c_in = scaling[2];
|
||||||
|
|
||||||
float t = denoiser->sigma_to_t(sigma);
|
float t = denoiser->sigma_to_t(sigma);
|
||||||
std::vector<float> timesteps_vec(x->ne[3], t); // [N, ]
|
std::vector<float> timesteps_vec(1, t); // [N, ]
|
||||||
auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
|
auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
|
||||||
std::vector<float> guidance_vec(x->ne[3], guidance.distilled_guidance);
|
std::vector<float> guidance_vec(1, guidance.distilled_guidance);
|
||||||
auto guidance_tensor = vector_to_ggml_tensor(work_ctx, guidance_vec);
|
auto guidance_tensor = vector_to_ggml_tensor(work_ctx, guidance_vec);
|
||||||
|
|
||||||
copy_ggml_tensor(noised_input, input);
|
copy_ggml_tensor(noised_input, input);
|
||||||
@ -1038,11 +1051,6 @@ public:
|
|||||||
float latent_result = positive_data[i];
|
float latent_result = positive_data[i];
|
||||||
if (has_unconditioned) {
|
if (has_unconditioned) {
|
||||||
// out_uncond + cfg_scale * (out_cond - out_uncond)
|
// out_uncond + cfg_scale * (out_cond - out_uncond)
|
||||||
int64_t ne3 = out_cond->ne[3];
|
|
||||||
if (min_cfg != cfg_scale && ne3 != 1) {
|
|
||||||
int64_t i3 = i / out_cond->ne[0] * out_cond->ne[1] * out_cond->ne[2];
|
|
||||||
float scale = min_cfg + (cfg_scale - min_cfg) * (i3 * 1.0f / ne3);
|
|
||||||
} else {
|
|
||||||
if (has_img_cond) {
|
if (has_img_cond) {
|
||||||
// out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond)
|
// out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond)
|
||||||
latent_result = negative_data[i] + img_cfg_scale * (img_cond_data[i] - negative_data[i]) + cfg_scale * (positive_data[i] - img_cond_data[i]);
|
latent_result = negative_data[i] + img_cfg_scale * (img_cond_data[i] - negative_data[i]) + cfg_scale * (positive_data[i] - img_cond_data[i]);
|
||||||
@ -1050,7 +1058,6 @@ public:
|
|||||||
// img_cfg_scale == cfg_scale
|
// img_cfg_scale == cfg_scale
|
||||||
latent_result = negative_data[i] + cfg_scale * (positive_data[i] - negative_data[i]);
|
latent_result = negative_data[i] + cfg_scale * (positive_data[i] - negative_data[i]);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
} else if (has_img_cond) {
|
} else if (has_img_cond) {
|
||||||
// img_cfg_scale == 1
|
// img_cfg_scale == 1
|
||||||
latent_result = img_cond_data[i] + cfg_scale * (positive_data[i] - img_cond_data[i]);
|
latent_result = img_cond_data[i] + cfg_scale * (positive_data[i] - img_cond_data[i]);
|
||||||
@ -1085,6 +1092,7 @@ public:
|
|||||||
|
|
||||||
sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng, eta);
|
sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng, eta);
|
||||||
|
|
||||||
|
LOG_DEBUG("sigmas[sigmas.size() - 1] %f", sigmas[sigmas.size() - 1]);
|
||||||
x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x);
|
x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x);
|
||||||
|
|
||||||
if (control_net) {
|
if (control_net) {
|
||||||
@ -1101,7 +1109,6 @@ public:
|
|||||||
ggml_tensor* latent = ggml_new_tensor_4d(work_ctx, moments->type, moments->ne[0], moments->ne[1], moments->ne[2] / 2, moments->ne[3]);
|
ggml_tensor* latent = ggml_new_tensor_4d(work_ctx, moments->type, moments->ne[0], moments->ne[1], moments->ne[2] / 2, moments->ne[3]);
|
||||||
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, latent);
|
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, latent);
|
||||||
ggml_tensor_set_f32_randn(noise, rng);
|
ggml_tensor_set_f32_randn(noise, rng);
|
||||||
// noise = load_tensor_from_file(work_ctx, "noise.bin");
|
|
||||||
{
|
{
|
||||||
float mean = 0;
|
float mean = 0;
|
||||||
float logvar = 0;
|
float logvar = 0;
|
||||||
@ -1127,9 +1134,9 @@ public:
|
|||||||
return latent;
|
return latent;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* compute_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode) {
|
ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x) {
|
||||||
int64_t W = x->ne[0];
|
int64_t W = x->ne[0] / 8;
|
||||||
int64_t H = x->ne[1];
|
int64_t H = x->ne[1] / 8;
|
||||||
int64_t C = 8;
|
int64_t C = 8;
|
||||||
if (use_tiny_autoencoder) {
|
if (use_tiny_autoencoder) {
|
||||||
C = 4;
|
C = 4;
|
||||||
@ -1140,59 +1147,106 @@ public:
|
|||||||
C = 32;
|
C = 32;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ggml_tensor* result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32,
|
ggml_tensor* result = ggml_new_tensor_4d(work_ctx,
|
||||||
decode ? (W * 8) : (W / 8), // width
|
GGML_TYPE_F32,
|
||||||
decode ? (H * 8) : (H / 8), // height
|
W,
|
||||||
decode ? 3 : C,
|
H,
|
||||||
x->ne[3]); // channels
|
C,
|
||||||
|
x->ne[3]);
|
||||||
int64_t t0 = ggml_time_ms();
|
int64_t t0 = ggml_time_ms();
|
||||||
if (!use_tiny_autoencoder) {
|
if (!use_tiny_autoencoder) {
|
||||||
if (decode) {
|
|
||||||
ggml_tensor_scale(x, 1.0f / scale_factor);
|
|
||||||
} else {
|
|
||||||
ggml_tensor_scale_input(x);
|
ggml_tensor_scale_input(x);
|
||||||
|
first_stage_model->compute(n_threads, x, false, &result, NULL);
|
||||||
|
first_stage_model->free_compute_buffer();
|
||||||
|
} else {
|
||||||
|
tae_first_stage->compute(n_threads, x, false, &result, NULL);
|
||||||
|
tae_first_stage->free_compute_buffer();
|
||||||
}
|
}
|
||||||
if (vae_tiling && decode) { // TODO: support tiling vae encode
|
|
||||||
|
int64_t t1 = ggml_time_ms();
|
||||||
|
LOG_DEBUG("computing vae encode graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
void process_latent_out(ggml_tensor* latent) {
|
||||||
|
if (sd_version_is_wan(version)) {
|
||||||
|
GGML_ASSERT(latent->ne[3] == 16);
|
||||||
|
std::vector<float> latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f,
|
||||||
|
0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f};
|
||||||
|
std::vector<float> latents_std_vec = {2.8184f, 1.4541f, 2.3275f, 2.6558f, 1.2196f, 1.7708f, 2.6052f, 2.0743f,
|
||||||
|
3.2687f, 2.1526f, 2.8652f, 1.5579f, 1.6382f, 1.1253f, 2.8251f, 1.9160f};
|
||||||
|
for (int i = 0; i < latent->ne[3]; i++) {
|
||||||
|
float mean = latents_mean_vec[i];
|
||||||
|
float std_ = latents_std_vec[i];
|
||||||
|
for (int j = 0; j < latent->ne[2]; j++) {
|
||||||
|
for (int k = 0; k < latent->ne[1]; k++) {
|
||||||
|
for (int l = 0; l < latent->ne[0]; l++) {
|
||||||
|
float value = ggml_tensor_get_f32(latent, l, k, j, i);
|
||||||
|
value = value * std_ / scale_factor + mean;
|
||||||
|
ggml_tensor_set_f32(latent, value, l, k, j, i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ggml_tensor_scale(latent, 1.0f / scale_factor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* decode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) {
|
||||||
|
int64_t W = x->ne[0] * 8;
|
||||||
|
int64_t H = x->ne[1] * 8;
|
||||||
|
int64_t C = 3;
|
||||||
|
ggml_tensor* result;
|
||||||
|
if (decode_video) {
|
||||||
|
result = ggml_new_tensor_4d(work_ctx,
|
||||||
|
GGML_TYPE_F32,
|
||||||
|
W,
|
||||||
|
H,
|
||||||
|
x->ne[2],
|
||||||
|
3);
|
||||||
|
} else {
|
||||||
|
result = ggml_new_tensor_4d(work_ctx,
|
||||||
|
GGML_TYPE_F32,
|
||||||
|
W,
|
||||||
|
H,
|
||||||
|
C,
|
||||||
|
x->ne[3]);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t t0 = ggml_time_ms();
|
||||||
|
if (!use_tiny_autoencoder) {
|
||||||
|
LOG_DEBUG("scale_factor %.2f", scale_factor);
|
||||||
|
process_latent_out(x);
|
||||||
|
if (vae_tiling && !decode_video) {
|
||||||
// split latent in 32x32 tiles and compute in several steps
|
// split latent in 32x32 tiles and compute in several steps
|
||||||
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
||||||
first_stage_model->compute(n_threads, in, decode, &out);
|
first_stage_model->compute(n_threads, in, true, &out, NULL);
|
||||||
};
|
};
|
||||||
sd_tiling(x, result, 8, 32, 0.5f, on_tiling);
|
sd_tiling(x, result, 8, 32, 0.5f, on_tiling);
|
||||||
} else {
|
} else {
|
||||||
first_stage_model->compute(n_threads, x, decode, &result);
|
first_stage_model->compute(n_threads, x, true, &result, NULL);
|
||||||
}
|
}
|
||||||
first_stage_model->free_compute_buffer();
|
first_stage_model->free_compute_buffer();
|
||||||
if (decode) {
|
|
||||||
ggml_tensor_scale_output(result);
|
ggml_tensor_scale_output(result);
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
if (vae_tiling && decode) { // TODO: support tiling vae encode
|
if (vae_tiling && !decode_video) {
|
||||||
// split latent in 64x64 tiles and compute in several steps
|
// split latent in 64x64 tiles and compute in several steps
|
||||||
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
||||||
tae_first_stage->compute(n_threads, in, decode, &out);
|
tae_first_stage->compute(n_threads, in, true, &out);
|
||||||
};
|
};
|
||||||
sd_tiling(x, result, 8, 64, 0.5f, on_tiling);
|
sd_tiling(x, result, 8, 64, 0.5f, on_tiling);
|
||||||
} else {
|
} else {
|
||||||
tae_first_stage->compute(n_threads, x, decode, &result);
|
tae_first_stage->compute(n_threads, x, true, &result);
|
||||||
}
|
}
|
||||||
tae_first_stage->free_compute_buffer();
|
tae_first_stage->free_compute_buffer();
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t t1 = ggml_time_ms();
|
int64_t t1 = ggml_time_ms();
|
||||||
LOG_DEBUG("computing vae [mode: %s] graph completed, taking %.2fs", decode ? "DECODE" : "ENCODE", (t1 - t0) * 1.0f / 1000);
|
LOG_DEBUG("computing vae decode graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
|
||||||
if (decode) {
|
|
||||||
ggml_tensor_clamp(result, 0.0f, 1.0f);
|
ggml_tensor_clamp(result, 0.0f, 1.0f);
|
||||||
}
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x) {
|
|
||||||
return compute_first_stage(work_ctx, x, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_tensor* decode_first_stage(ggml_context* work_ctx, ggml_tensor* x) {
|
|
||||||
return compute_first_stage(work_ctx, x, true);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/*================================================= SD API ==================================================*/
|
/*================================================= SD API ==================================================*/
|
||||||
@ -1373,7 +1427,6 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) {
|
|||||||
memset((void*)sd_img_gen_params, 0, sizeof(sd_img_gen_params_t));
|
memset((void*)sd_img_gen_params, 0, sizeof(sd_img_gen_params_t));
|
||||||
sd_img_gen_params->clip_skip = -1;
|
sd_img_gen_params->clip_skip = -1;
|
||||||
sd_img_gen_params->guidance.txt_cfg = 7.0f;
|
sd_img_gen_params->guidance.txt_cfg = 7.0f;
|
||||||
sd_img_gen_params->guidance.min_cfg = 1.0f;
|
|
||||||
sd_img_gen_params->guidance.img_cfg = INFINITY;
|
sd_img_gen_params->guidance.img_cfg = INFINITY;
|
||||||
sd_img_gen_params->guidance.distilled_guidance = 3.5f;
|
sd_img_gen_params->guidance.distilled_guidance = 3.5f;
|
||||||
sd_img_gen_params->guidance.slg.layer_count = 0;
|
sd_img_gen_params->guidance.slg.layer_count = 0;
|
||||||
@ -1406,7 +1459,6 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
|
|||||||
"clip_skip: %d\n"
|
"clip_skip: %d\n"
|
||||||
"txt_cfg: %.2f\n"
|
"txt_cfg: %.2f\n"
|
||||||
"img_cfg: %.2f\n"
|
"img_cfg: %.2f\n"
|
||||||
"min_cfg: %.2f\n"
|
|
||||||
"distilled_guidance: %.2f\n"
|
"distilled_guidance: %.2f\n"
|
||||||
"slg.layer_count: %zu\n"
|
"slg.layer_count: %zu\n"
|
||||||
"slg.layer_start: %.2f\n"
|
"slg.layer_start: %.2f\n"
|
||||||
@ -1431,7 +1483,6 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
|
|||||||
sd_img_gen_params->clip_skip,
|
sd_img_gen_params->clip_skip,
|
||||||
sd_img_gen_params->guidance.txt_cfg,
|
sd_img_gen_params->guidance.txt_cfg,
|
||||||
sd_img_gen_params->guidance.img_cfg,
|
sd_img_gen_params->guidance.img_cfg,
|
||||||
sd_img_gen_params->guidance.min_cfg,
|
|
||||||
sd_img_gen_params->guidance.distilled_guidance,
|
sd_img_gen_params->guidance.distilled_guidance,
|
||||||
sd_img_gen_params->guidance.slg.layer_count,
|
sd_img_gen_params->guidance.slg.layer_count,
|
||||||
sd_img_gen_params->guidance.slg.layer_start,
|
sd_img_gen_params->guidance.slg.layer_start,
|
||||||
@ -1457,7 +1508,6 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
|
|||||||
void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
|
void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
|
||||||
memset((void*)sd_vid_gen_params, 0, sizeof(sd_vid_gen_params_t));
|
memset((void*)sd_vid_gen_params, 0, sizeof(sd_vid_gen_params_t));
|
||||||
sd_vid_gen_params->guidance.txt_cfg = 7.0f;
|
sd_vid_gen_params->guidance.txt_cfg = 7.0f;
|
||||||
sd_vid_gen_params->guidance.min_cfg = 1.0f;
|
|
||||||
sd_vid_gen_params->guidance.img_cfg = INFINITY;
|
sd_vid_gen_params->guidance.img_cfg = INFINITY;
|
||||||
sd_vid_gen_params->guidance.distilled_guidance = 3.5f;
|
sd_vid_gen_params->guidance.distilled_guidance = 3.5f;
|
||||||
sd_vid_gen_params->guidance.slg.layer_count = 0;
|
sd_vid_gen_params->guidance.slg.layer_count = 0;
|
||||||
@ -1471,9 +1521,6 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
|
|||||||
sd_vid_gen_params->strength = 0.75f;
|
sd_vid_gen_params->strength = 0.75f;
|
||||||
sd_vid_gen_params->seed = -1;
|
sd_vid_gen_params->seed = -1;
|
||||||
sd_vid_gen_params->video_frames = 6;
|
sd_vid_gen_params->video_frames = 6;
|
||||||
sd_vid_gen_params->motion_bucket_id = 127;
|
|
||||||
sd_vid_gen_params->fps = 6;
|
|
||||||
sd_vid_gen_params->augmentation_level = 0.f;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct sd_ctx_t {
|
struct sd_ctx_t {
|
||||||
@ -1545,21 +1592,9 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
|||||||
|
|
||||||
int sample_steps = sigmas.size() - 1;
|
int sample_steps = sigmas.size() - 1;
|
||||||
|
|
||||||
// Apply lora
|
|
||||||
auto result_pair = extract_and_remove_lora(prompt);
|
|
||||||
std::unordered_map<std::string, float> lora_f2m = result_pair.first; // lora_name -> multiplier
|
|
||||||
|
|
||||||
for (auto& kv : lora_f2m) {
|
|
||||||
LOG_DEBUG("lora %s:%.2f", kv.first.c_str(), kv.second);
|
|
||||||
}
|
|
||||||
|
|
||||||
prompt = result_pair.second;
|
|
||||||
LOG_DEBUG("prompt after extract and remove lora: \"%s\"", prompt.c_str());
|
|
||||||
|
|
||||||
int64_t t0 = ggml_time_ms();
|
int64_t t0 = ggml_time_ms();
|
||||||
sd_ctx->sd->apply_loras(lora_f2m);
|
// Apply lora
|
||||||
int64_t t1 = ggml_time_ms();
|
prompt = sd_ctx->sd->apply_loras_from_prompt(prompt);
|
||||||
LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
|
|
||||||
|
|
||||||
// Photo Maker
|
// Photo Maker
|
||||||
std::string prompt_text_only;
|
std::string prompt_text_only;
|
||||||
@ -1568,9 +1603,9 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
|||||||
std::vector<bool> class_tokens_mask;
|
std::vector<bool> class_tokens_mask;
|
||||||
if (sd_ctx->sd->stacked_id) {
|
if (sd_ctx->sd->stacked_id) {
|
||||||
if (!sd_ctx->sd->pmid_lora->applied) {
|
if (!sd_ctx->sd->pmid_lora->applied) {
|
||||||
t0 = ggml_time_ms();
|
int64_t t0 = ggml_time_ms();
|
||||||
sd_ctx->sd->pmid_lora->apply(sd_ctx->sd->tensors, sd_ctx->sd->version, sd_ctx->sd->n_threads);
|
sd_ctx->sd->pmid_lora->apply(sd_ctx->sd->tensors, sd_ctx->sd->version, sd_ctx->sd->n_threads);
|
||||||
t1 = ggml_time_ms();
|
int64_t t1 = ggml_time_ms();
|
||||||
sd_ctx->sd->pmid_lora->applied = true;
|
sd_ctx->sd->pmid_lora->applied = true;
|
||||||
LOG_INFO("pmid_lora apply completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
|
LOG_INFO("pmid_lora apply completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
|
||||||
if (sd_ctx->sd->free_params_immediately) {
|
if (sd_ctx->sd->free_params_immediately) {
|
||||||
@ -1625,7 +1660,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
|||||||
else
|
else
|
||||||
sd_mul_images_to_tensor(init_image->data, init_img, i, NULL, NULL);
|
sd_mul_images_to_tensor(init_image->data, init_img, i, NULL, NULL);
|
||||||
}
|
}
|
||||||
t0 = ggml_time_ms();
|
int64_t t0 = ggml_time_ms();
|
||||||
auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx,
|
auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx,
|
||||||
sd_ctx->sd->n_threads, prompt,
|
sd_ctx->sd->n_threads, prompt,
|
||||||
clip_skip,
|
clip_skip,
|
||||||
@ -1642,7 +1677,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
|||||||
// print_ggml_tensor(id_embeds, true, "id_embeds:");
|
// print_ggml_tensor(id_embeds, true, "id_embeds:");
|
||||||
}
|
}
|
||||||
id_cond.c_crossattn = sd_ctx->sd->id_encoder(work_ctx, init_img, id_cond.c_crossattn, id_embeds, class_tokens_mask);
|
id_cond.c_crossattn = sd_ctx->sd->id_encoder(work_ctx, init_img, id_cond.c_crossattn, id_embeds, class_tokens_mask);
|
||||||
t1 = ggml_time_ms();
|
int64_t t1 = ggml_time_ms();
|
||||||
LOG_INFO("Photomaker ID Stacking, taking %" PRId64 " ms", t1 - t0);
|
LOG_INFO("Photomaker ID Stacking, taking %" PRId64 " ms", t1 - t0);
|
||||||
if (sd_ctx->sd->free_params_immediately) {
|
if (sd_ctx->sd->free_params_immediately) {
|
||||||
sd_ctx->sd->pmid_model->free_params_buffer();
|
sd_ctx->sd->pmid_model->free_params_buffer();
|
||||||
@ -1679,9 +1714,9 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
|||||||
SDCondition uncond;
|
SDCondition uncond;
|
||||||
if (guidance.txt_cfg != 1.0 ||
|
if (guidance.txt_cfg != 1.0 ||
|
||||||
(sd_version_is_inpaint_or_unet_edit(sd_ctx->sd->version) && guidance.txt_cfg != guidance.img_cfg)) {
|
(sd_version_is_inpaint_or_unet_edit(sd_ctx->sd->version) && guidance.txt_cfg != guidance.img_cfg)) {
|
||||||
bool force_zero_embeddings = false;
|
bool zero_out_masked = false;
|
||||||
if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0 && !sd_ctx->sd->is_using_edm_v_parameterization) {
|
if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0 && !sd_ctx->sd->is_using_edm_v_parameterization) {
|
||||||
force_zero_embeddings = true;
|
zero_out_masked = true;
|
||||||
}
|
}
|
||||||
uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
|
uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
|
||||||
sd_ctx->sd->n_threads,
|
sd_ctx->sd->n_threads,
|
||||||
@ -1690,9 +1725,9 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
|||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
sd_ctx->sd->diffusion_model->get_adm_in_channels(),
|
sd_ctx->sd->diffusion_model->get_adm_in_channels(),
|
||||||
force_zero_embeddings);
|
zero_out_masked);
|
||||||
}
|
}
|
||||||
t1 = ggml_time_ms();
|
int64_t t1 = ggml_time_ms();
|
||||||
LOG_INFO("get_learned_condition completed, taking %" PRId64 " ms", t1 - t0);
|
LOG_INFO("get_learned_condition completed, taking %" PRId64 " ms", t1 - t0);
|
||||||
|
|
||||||
if (sd_ctx->sd->free_params_immediately) {
|
if (sd_ctx->sd->free_params_immediately) {
|
||||||
@ -1780,9 +1815,6 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
|||||||
LOG_INFO("PHOTOMAKER: start_merge_step: %d", start_merge_step);
|
LOG_INFO("PHOTOMAKER: start_merge_step: %d", start_merge_step);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disable min_cfg
|
|
||||||
guidance.min_cfg = guidance.txt_cfg;
|
|
||||||
|
|
||||||
struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx,
|
struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx,
|
||||||
x_t,
|
x_t,
|
||||||
noise,
|
noise,
|
||||||
@ -1799,8 +1831,6 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
|||||||
id_cond,
|
id_cond,
|
||||||
ref_latents,
|
ref_latents,
|
||||||
denoise_mask);
|
denoise_mask);
|
||||||
|
|
||||||
// struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
|
|
||||||
// print_ggml_tensor(x_0);
|
// print_ggml_tensor(x_0);
|
||||||
int64_t sampling_end = ggml_time_ms();
|
int64_t sampling_end = ggml_time_ms();
|
||||||
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
|
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
|
||||||
@ -1852,16 +1882,25 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
|||||||
ggml_tensor* generate_init_latent(sd_ctx_t* sd_ctx,
|
ggml_tensor* generate_init_latent(sd_ctx_t* sd_ctx,
|
||||||
ggml_context* work_ctx,
|
ggml_context* work_ctx,
|
||||||
int width,
|
int width,
|
||||||
int height) {
|
int height,
|
||||||
|
int frames = 1,
|
||||||
|
bool video = false) {
|
||||||
int C = 4;
|
int C = 4;
|
||||||
if (sd_version_is_sd3(sd_ctx->sd->version)) {
|
if (sd_version_is_sd3(sd_ctx->sd->version)) {
|
||||||
C = 16;
|
C = 16;
|
||||||
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
|
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
|
||||||
C = 16;
|
C = 16;
|
||||||
|
} else if (sd_version_is_wan(sd_ctx->sd->version)) {
|
||||||
|
C = 16;
|
||||||
}
|
}
|
||||||
int W = width / 8;
|
int W = width / 8;
|
||||||
int H = height / 8;
|
int H = height / 8;
|
||||||
ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
|
ggml_tensor* init_latent;
|
||||||
|
if (video) {
|
||||||
|
init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, frames, C);
|
||||||
|
} else {
|
||||||
|
init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
|
||||||
|
}
|
||||||
if (sd_version_is_sd3(sd_ctx->sd->version)) {
|
if (sd_version_is_sd3(sd_ctx->sd->version)) {
|
||||||
ggml_set_f32(init_latent, 0.0609f);
|
ggml_set_f32(init_latent, 0.0609f);
|
||||||
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
|
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
|
||||||
@ -1877,11 +1916,17 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
|
|||||||
int height = sd_img_gen_params->height;
|
int height = sd_img_gen_params->height;
|
||||||
if (sd_version_is_dit(sd_ctx->sd->version)) {
|
if (sd_version_is_dit(sd_ctx->sd->version)) {
|
||||||
if (width % 16 || height % 16) {
|
if (width % 16 || height % 16) {
|
||||||
LOG_ERROR("Image dimensions must be must be a multiple of 16 on each axis for %s models. (Got %dx%d)", model_version_to_str[sd_ctx->sd->version], width, height);
|
LOG_ERROR("Image dimensions must be must be a multiple of 16 on each axis for %s models. (Got %dx%d)",
|
||||||
|
model_version_to_str[sd_ctx->sd->version],
|
||||||
|
width,
|
||||||
|
height);
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
} else if (width % 64 || height % 64) {
|
} else if (width % 64 || height % 64) {
|
||||||
LOG_ERROR("Image dimensions must be must be a multiple of 64 on each axis for %s models. (Got %dx%d)", model_version_to_str[sd_ctx->sd->version], width, height);
|
LOG_ERROR("Image dimensions must be must be a multiple of 64 on each axis for %s models. (Got %dx%d)",
|
||||||
|
model_version_to_str[sd_ctx->sd->version],
|
||||||
|
width,
|
||||||
|
height);
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
LOG_DEBUG("generate_image %dx%d", width, height);
|
LOG_DEBUG("generate_image %dx%d", width, height);
|
||||||
@ -2095,20 +2140,23 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string prompt = SAFE_STR(sd_vid_gen_params->prompt);
|
||||||
|
std::string negative_prompt = SAFE_STR(sd_vid_gen_params->negative_prompt);
|
||||||
|
|
||||||
int width = sd_vid_gen_params->width;
|
int width = sd_vid_gen_params->width;
|
||||||
int height = sd_vid_gen_params->height;
|
int height = sd_vid_gen_params->height;
|
||||||
LOG_INFO("img2vid %dx%d", width, height);
|
int frames = sd_vid_gen_params->video_frames;
|
||||||
|
LOG_INFO("img2vid %dx%dx%d", width, height, frames);
|
||||||
|
|
||||||
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sd_vid_gen_params->sample_steps);
|
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sd_vid_gen_params->sample_steps);
|
||||||
|
|
||||||
struct ggml_init_params params;
|
struct ggml_init_params params;
|
||||||
params.mem_size = static_cast<size_t>(10 * 1024) * 1024; // 10 MB
|
params.mem_size = static_cast<size_t>(100 * 1024) * 1024; // 50 MB
|
||||||
params.mem_size += width * height * 3 * sizeof(float) * sd_vid_gen_params->video_frames;
|
params.mem_size += width * height * frames * 3 * sizeof(float);
|
||||||
params.mem_buffer = NULL;
|
params.mem_buffer = NULL;
|
||||||
params.no_alloc = false;
|
params.no_alloc = false;
|
||||||
// LOG_DEBUG("mem_size %u ", params.mem_size);
|
// LOG_DEBUG("mem_size %u ", params.mem_size);
|
||||||
|
|
||||||
// draft context
|
|
||||||
struct ggml_context* work_ctx = ggml_init(params);
|
struct ggml_context* work_ctx = ggml_init(params);
|
||||||
if (!work_ctx) {
|
if (!work_ctx) {
|
||||||
LOG_ERROR("ggml_init() failed");
|
LOG_ERROR("ggml_init() failed");
|
||||||
@ -2124,90 +2172,100 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
|
|
||||||
int64_t t0 = ggml_time_ms();
|
int64_t t0 = ggml_time_ms();
|
||||||
|
|
||||||
SDCondition cond = sd_ctx->sd->get_svd_condition(work_ctx,
|
ggml_tensor* init_latent = generate_init_latent(sd_ctx, work_ctx, width, height, frames, true);
|
||||||
sd_vid_gen_params->init_image,
|
int sample_steps = sigmas.size() - 1;
|
||||||
|
// Apply lora
|
||||||
|
prompt = sd_ctx->sd->apply_loras_from_prompt(prompt);
|
||||||
|
|
||||||
|
// Get learned condition
|
||||||
|
bool zero_out_masked = true;
|
||||||
|
t0 = ggml_time_ms();
|
||||||
|
SDCondition cond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
|
||||||
|
sd_ctx->sd->n_threads,
|
||||||
|
prompt,
|
||||||
|
sd_vid_gen_params->clip_skip,
|
||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
sd_vid_gen_params->fps,
|
sd_ctx->sd->diffusion_model->get_adm_in_channels(),
|
||||||
sd_vid_gen_params->motion_bucket_id,
|
zero_out_masked);
|
||||||
sd_vid_gen_params->augmentation_level);
|
SDCondition uncond;
|
||||||
|
if (sd_vid_gen_params->guidance.txt_cfg != 1.0) {
|
||||||
auto uc_crossattn = ggml_dup_tensor(work_ctx, cond.c_crossattn);
|
uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
|
||||||
ggml_set_f32(uc_crossattn, 0.f);
|
sd_ctx->sd->n_threads,
|
||||||
|
negative_prompt,
|
||||||
auto uc_concat = ggml_dup_tensor(work_ctx, cond.c_concat);
|
sd_vid_gen_params->clip_skip,
|
||||||
ggml_set_f32(uc_concat, 0.f);
|
width,
|
||||||
|
height,
|
||||||
auto uc_vector = ggml_dup_tensor(work_ctx, cond.c_vector);
|
sd_ctx->sd->diffusion_model->get_adm_in_channels(),
|
||||||
|
zero_out_masked);
|
||||||
SDCondition uncond = SDCondition(uc_crossattn, uc_vector, uc_concat);
|
}
|
||||||
|
|
||||||
int64_t t1 = ggml_time_ms();
|
int64_t t1 = ggml_time_ms();
|
||||||
LOG_INFO("get_learned_condition completed, taking %" PRId64 " ms", t1 - t0);
|
LOG_INFO("get_learned_condition completed, taking %" PRId64 " ms", t1 - t0);
|
||||||
|
|
||||||
if (sd_ctx->sd->free_params_immediately) {
|
if (sd_ctx->sd->free_params_immediately) {
|
||||||
sd_ctx->sd->clip_vision->free_params_buffer();
|
sd_ctx->sd->cond_stage_model->free_params_buffer();
|
||||||
}
|
}
|
||||||
|
|
||||||
sd_ctx->sd->rng->manual_seed(seed);
|
|
||||||
int C = 4;
|
|
||||||
int W = width / 8;
|
int W = width / 8;
|
||||||
int H = height / 8;
|
int H = height / 8;
|
||||||
struct ggml_tensor* x_t = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, sd_vid_gen_params->video_frames);
|
int T = frames;
|
||||||
ggml_set_f32(x_t, 0.f);
|
int C = 16;
|
||||||
|
|
||||||
struct ggml_tensor* noise = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, sd_vid_gen_params->video_frames);
|
struct ggml_tensor* final_latent;
|
||||||
|
// Sample
|
||||||
|
{
|
||||||
|
int64_t sampling_start = ggml_time_ms();
|
||||||
|
struct ggml_tensor* x_t = init_latent;
|
||||||
|
struct ggml_tensor* noise = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, T, C);
|
||||||
ggml_tensor_set_f32_randn(noise, sd_ctx->sd->rng);
|
ggml_tensor_set_f32_randn(noise, sd_ctx->sd->rng);
|
||||||
|
|
||||||
LOG_INFO("sampling using %s method", sampling_methods_str[sd_vid_gen_params->sample_method]);
|
final_latent = sd_ctx->sd->sample(work_ctx,
|
||||||
struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx,
|
|
||||||
x_t,
|
x_t,
|
||||||
noise,
|
noise,
|
||||||
cond,
|
cond,
|
||||||
uncond,
|
uncond,
|
||||||
{},
|
{},
|
||||||
{},
|
NULL,
|
||||||
0.f,
|
0,
|
||||||
sd_vid_gen_params->guidance,
|
sd_vid_gen_params->guidance,
|
||||||
0.f,
|
sd_vid_gen_params->eta,
|
||||||
sd_vid_gen_params->sample_method,
|
sd_vid_gen_params->sample_method,
|
||||||
sigmas,
|
sigmas,
|
||||||
-1,
|
-1,
|
||||||
SDCondition(NULL, NULL, NULL));
|
{});
|
||||||
|
|
||||||
|
int64_t sampling_end = ggml_time_ms();
|
||||||
|
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
|
||||||
|
}
|
||||||
|
|
||||||
int64_t t2 = ggml_time_ms();
|
|
||||||
LOG_INFO("sampling completed, taking %.2fs", (t2 - t1) * 1.0f / 1000);
|
|
||||||
if (sd_ctx->sd->free_params_immediately) {
|
if (sd_ctx->sd->free_params_immediately) {
|
||||||
sd_ctx->sd->diffusion_model->free_params_buffer();
|
sd_ctx->sd->diffusion_model->free_params_buffer();
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* img = sd_ctx->sd->decode_first_stage(work_ctx, x_0);
|
int64_t t3 = ggml_time_ms();
|
||||||
|
LOG_INFO("generating latent video completed, taking %.2fs", (t3 - t1) * 1.0f / 1000);
|
||||||
|
struct ggml_tensor* vid = sd_ctx->sd->decode_first_stage(work_ctx, final_latent, true);
|
||||||
|
int64_t t4 = ggml_time_ms();
|
||||||
|
LOG_INFO("decode_first_stage completed, taking %.2fs", (t4 - t3) * 1.0f / 1000);
|
||||||
if (sd_ctx->sd->free_params_immediately) {
|
if (sd_ctx->sd->free_params_immediately) {
|
||||||
sd_ctx->sd->first_stage_model->free_params_buffer();
|
sd_ctx->sd->first_stage_model->free_params_buffer();
|
||||||
}
|
}
|
||||||
if (img == NULL) {
|
|
||||||
ggml_free(work_ctx);
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
sd_image_t* result_images = (sd_image_t*)calloc(sd_vid_gen_params->video_frames, sizeof(sd_image_t));
|
sd_image_t* result_images = (sd_image_t*)calloc(T, sizeof(sd_image_t));
|
||||||
if (result_images == NULL) {
|
if (result_images == NULL) {
|
||||||
ggml_free(work_ctx);
|
ggml_free(work_ctx);
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t i = 0; i < sd_vid_gen_params->video_frames; i++) {
|
for (size_t i = 0; i < T; i++) {
|
||||||
auto img_i = ggml_view_3d(work_ctx, img, img->ne[0], img->ne[1], img->ne[2], img->nb[1], img->nb[2], img->nb[3] * i);
|
result_images[i].width = final_latent->ne[0] * 8;
|
||||||
|
result_images[i].height = final_latent->ne[1] * 8;
|
||||||
result_images[i].width = width;
|
|
||||||
result_images[i].height = height;
|
|
||||||
result_images[i].channel = 3;
|
result_images[i].channel = 3;
|
||||||
result_images[i].data = sd_tensor_to_image(img_i);
|
result_images[i].data = sd_tensor_to_image(vid, i, true);
|
||||||
}
|
}
|
||||||
ggml_free(work_ctx);
|
ggml_free(work_ctx);
|
||||||
|
|
||||||
int64_t t3 = ggml_time_ms();
|
LOG_INFO("img2vid completed in %.2fs", (t4 - t0) * 1.0f / 1000);
|
||||||
|
|
||||||
LOG_INFO("img2vid completed in %.2fs", (t3 - t0) * 1.0f / 1000);
|
|
||||||
|
|
||||||
return result_images;
|
return result_images;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -157,7 +157,6 @@ typedef struct {
|
|||||||
typedef struct {
|
typedef struct {
|
||||||
float txt_cfg;
|
float txt_cfg;
|
||||||
float img_cfg;
|
float img_cfg;
|
||||||
float min_cfg;
|
|
||||||
float distilled_guidance;
|
float distilled_guidance;
|
||||||
sd_slg_params_t slg;
|
sd_slg_params_t slg;
|
||||||
} sd_guidance_params_t;
|
} sd_guidance_params_t;
|
||||||
@ -187,18 +186,19 @@ typedef struct {
|
|||||||
} sd_img_gen_params_t;
|
} sd_img_gen_params_t;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
|
const char* prompt;
|
||||||
|
const char* negative_prompt;
|
||||||
|
int clip_skip;
|
||||||
|
sd_guidance_params_t guidance;
|
||||||
sd_image_t init_image;
|
sd_image_t init_image;
|
||||||
int width;
|
int width;
|
||||||
int height;
|
int height;
|
||||||
sd_guidance_params_t guidance;
|
|
||||||
enum sample_method_t sample_method;
|
enum sample_method_t sample_method;
|
||||||
int sample_steps;
|
int sample_steps;
|
||||||
|
float eta;
|
||||||
float strength;
|
float strength;
|
||||||
int64_t seed;
|
int64_t seed;
|
||||||
int video_frames;
|
int video_frames;
|
||||||
int motion_bucket_id;
|
|
||||||
int fps;
|
|
||||||
float augmentation_level;
|
|
||||||
} sd_vid_gen_params_t;
|
} sd_vid_gen_params_t;
|
||||||
|
|
||||||
typedef struct sd_ctx_t sd_ctx_t;
|
typedef struct sd_ctx_t sd_ctx_t;
|
||||||
|
|||||||
15
vae.hpp
15
vae.hpp
@ -520,7 +520,18 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct AutoEncoderKL : public GGMLRunner {
|
struct VAE : public GGMLRunner {
|
||||||
|
VAE(ggml_backend_t backend)
|
||||||
|
: GGMLRunner(backend) {}
|
||||||
|
virtual void compute(const int n_threads,
|
||||||
|
struct ggml_tensor* z,
|
||||||
|
bool decode_graph,
|
||||||
|
struct ggml_tensor** output,
|
||||||
|
struct ggml_context* output_ctx) = 0;
|
||||||
|
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct AutoEncoderKL : public VAE {
|
||||||
bool decode_only = true;
|
bool decode_only = true;
|
||||||
AutoencodingEngine ae;
|
AutoencodingEngine ae;
|
||||||
|
|
||||||
@ -530,7 +541,7 @@ struct AutoEncoderKL : public GGMLRunner {
|
|||||||
bool decode_only = false,
|
bool decode_only = false,
|
||||||
bool use_video_decoder = false,
|
bool use_video_decoder = false,
|
||||||
SDVersion version = VERSION_SD1)
|
SDVersion version = VERSION_SD1)
|
||||||
: decode_only(decode_only), ae(decode_only, use_video_decoder, version), GGMLRunner(backend) {
|
: decode_only(decode_only), ae(decode_only, use_video_decoder, version), VAE(backend) {
|
||||||
ae.init(params_ctx, tensor_types, prefix);
|
ae.init(params_ctx, tensor_types, prefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
19
wan.hpp
19
wan.hpp
@ -7,6 +7,7 @@
|
|||||||
#include "flux.hpp"
|
#include "flux.hpp"
|
||||||
#include "ggml_extend.hpp"
|
#include "ggml_extend.hpp"
|
||||||
#include "rope.hpp"
|
#include "rope.hpp"
|
||||||
|
#include "vae.hpp"
|
||||||
|
|
||||||
namespace WAN {
|
namespace WAN {
|
||||||
|
|
||||||
@ -522,7 +523,6 @@ namespace WAN {
|
|||||||
for (int i = 0; i < dims.size() - 1; i++) {
|
for (int i = 0; i < dims.size() - 1; i++) {
|
||||||
in_dim = dims[i];
|
in_dim = dims[i];
|
||||||
out_dim = dims[i + 1];
|
out_dim = dims[i + 1];
|
||||||
LOG_DEBUG("in_dim %u out_dim %u", in_dim, out_dim);
|
|
||||||
if (i == 1 || i == 2 || i == 3) {
|
if (i == 1 || i == 2 || i == 3) {
|
||||||
in_dim = in_dim / 2;
|
in_dim = in_dim / 2;
|
||||||
}
|
}
|
||||||
@ -726,7 +726,7 @@ namespace WAN {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct WanVAERunner : public GGMLRunner {
|
struct WanVAERunner : public VAE {
|
||||||
bool decode_only = true;
|
bool decode_only = true;
|
||||||
WanVAE ae;
|
WanVAE ae;
|
||||||
|
|
||||||
@ -734,7 +734,7 @@ namespace WAN {
|
|||||||
const String2GGMLType& tensor_types = {},
|
const String2GGMLType& tensor_types = {},
|
||||||
const std::string prefix = "",
|
const std::string prefix = "",
|
||||||
bool decode_only = false)
|
bool decode_only = false)
|
||||||
: decode_only(decode_only), ae(decode_only), GGMLRunner(backend) {
|
: decode_only(decode_only), ae(decode_only), VAE(backend) {
|
||||||
ae.init(params_ctx, tensor_types, prefix);
|
ae.init(params_ctx, tensor_types, prefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1217,13 +1217,13 @@ namespace WAN {
|
|||||||
int64_t axes_dim_sum = 128;
|
int64_t axes_dim_sum = 128;
|
||||||
};
|
};
|
||||||
|
|
||||||
class WanModel : public GGMLBlock {
|
class Wan : public GGMLBlock {
|
||||||
protected:
|
protected:
|
||||||
WanParams params;
|
WanParams params;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
WanModel() {}
|
Wan() {}
|
||||||
WanModel(WanParams params)
|
Wan(WanParams params)
|
||||||
: params(params) {
|
: params(params) {
|
||||||
// patch_embedding
|
// patch_embedding
|
||||||
blocks["patch_embedding"] = std::shared_ptr<GGMLBlock>(new Conv3d(params.in_dim, params.dim, params.patch_size, params.patch_size));
|
blocks["patch_embedding"] = std::shared_ptr<GGMLBlock>(new Conv3d(params.in_dim, params.dim, params.patch_size, params.patch_size));
|
||||||
@ -1418,14 +1418,15 @@ namespace WAN {
|
|||||||
struct WanRunner : public GGMLRunner {
|
struct WanRunner : public GGMLRunner {
|
||||||
public:
|
public:
|
||||||
WanParams wan_params;
|
WanParams wan_params;
|
||||||
WanModel wan;
|
Wan wan;
|
||||||
std::vector<float> pe_vec;
|
std::vector<float> pe_vec;
|
||||||
SDVersion version;
|
SDVersion version;
|
||||||
|
|
||||||
WanRunner(ggml_backend_t backend,
|
WanRunner(ggml_backend_t backend,
|
||||||
const String2GGMLType& tensor_types = {},
|
const String2GGMLType& tensor_types = {},
|
||||||
const std::string prefix = "",
|
const std::string prefix = "",
|
||||||
SDVersion version = VERSION_WAN_2_1)
|
SDVersion version = VERSION_WAN2,
|
||||||
|
bool flash_attn = false)
|
||||||
: GGMLRunner(backend) {
|
: GGMLRunner(backend) {
|
||||||
wan_params.num_layers = 0;
|
wan_params.num_layers = 0;
|
||||||
for (auto pair : tensor_types) {
|
for (auto pair : tensor_types) {
|
||||||
@ -1476,7 +1477,7 @@ namespace WAN {
|
|||||||
GGML_ABORT("invalid num_layers(%d) of wan", wan_params.num_layers);
|
GGML_ABORT("invalid num_layers(%d) of wan", wan_params.num_layers);
|
||||||
}
|
}
|
||||||
|
|
||||||
wan = WanModel(wan_params);
|
wan = Wan(wan_params);
|
||||||
wan.init(params_ctx, tensor_types, prefix);
|
wan.init(params_ctx, tensor_types, prefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user