mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-02-04 19:03:35 +00:00
add support for extra contexts
This commit is contained in:
parent
8004d32de2
commit
190c523cec
202
conditioner.hpp
202
conditioner.hpp
@ -10,9 +10,14 @@ struct SDCondition {
|
||||
struct ggml_tensor* c_vector = nullptr; // aka y
|
||||
struct ggml_tensor* c_concat = nullptr;
|
||||
|
||||
std::vector<struct ggml_tensor*> extra_c_crossattns;
|
||||
|
||||
SDCondition() = default;
|
||||
SDCondition(struct ggml_tensor* c_crossattn, struct ggml_tensor* c_vector, struct ggml_tensor* c_concat)
|
||||
: c_crossattn(c_crossattn), c_vector(c_vector), c_concat(c_concat) {}
|
||||
SDCondition(struct ggml_tensor* c_crossattn,
|
||||
struct ggml_tensor* c_vector,
|
||||
struct ggml_tensor* c_concat,
|
||||
const std::vector<struct ggml_tensor*>& extra_c_crossattns = {})
|
||||
: c_crossattn(c_crossattn), c_vector(c_vector), c_concat(c_concat), extra_c_crossattns(extra_c_crossattns) {}
|
||||
};
|
||||
|
||||
struct ConditionerParams {
|
||||
@ -1657,10 +1662,11 @@ struct LLMEmbedder : public Conditioner {
|
||||
}
|
||||
|
||||
std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text,
|
||||
std::pair<int, int> attn_range,
|
||||
const std::pair<int, int>& attn_range,
|
||||
size_t max_length = 0,
|
||||
bool padding = false) {
|
||||
std::vector<std::pair<std::string, float>> parsed_attention;
|
||||
if (attn_range.first >= 0 && attn_range.second > 0) {
|
||||
parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f);
|
||||
if (attn_range.second - attn_range.first > 0) {
|
||||
auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first));
|
||||
@ -1669,6 +1675,10 @@ struct LLMEmbedder : public Conditioner {
|
||||
new_parsed_attention.end());
|
||||
}
|
||||
parsed_attention.emplace_back(text.substr(attn_range.second), 1.f);
|
||||
} else {
|
||||
parsed_attention.emplace_back(text, 1.f);
|
||||
}
|
||||
|
||||
{
|
||||
std::stringstream ss;
|
||||
ss << "[";
|
||||
@ -1699,16 +1709,88 @@ struct LLMEmbedder : public Conditioner {
|
||||
return {tokens, weights};
|
||||
}
|
||||
|
||||
/**
 * Encode one prompt with the LLM and return its weighted, trimmed and padded hidden states.
 *
 * Steps:
 *   1. tokenize `prompt` (padding is requested from tokenize() when max_length > 0);
 *   2. run llm->compute to obtain the hidden states;
 *   3. multiply every token position (axis 1) by its prompt weight, then rescale the
 *      whole tensor so that its mean matches the mean before weighting;
 *   4. drop the first `prompt_template_encode_start_idx` token positions and, when
 *      `min_length` > 0, zero-pad axis 1 up to `min_length` positions.
 *
 * @param work_ctx    ggml context used for the input-id tensor, the LLM output and the returned tensor.
 * @param n_threads   forwarded to llm->compute.
 * @param prompt      full prompt text, template included.
 * @param prompt_attn_range  [first, second) character range of `prompt` that tokenize()
 *                    parses for attention weights.
 * @param max_length  forwarded to tokenize(); 0 means no padding is requested.
 * @param min_length  minimum length of axis 1 in the result; values <= 0 disable zero-padding.
 * @param image_embeds  forwarded to llm->compute unchanged.
 * @param out_layers  forwarded to llm->compute unchanged.
 * @param prompt_template_encode_start_idx  number of leading token positions to discard;
 *                    must be smaller than the number of encoded positions.
 * @return a new F32 tensor allocated in `work_ctx` with
 *         ne = [ne0, ne1 - prompt_template_encode_start_idx + zero padding, ne2].
 */
ggml_tensor* encode_prompt(ggml_context* work_ctx,
                           int n_threads,
                           const std::string& prompt,
                           const std::pair<int, int>& prompt_attn_range,
                           int max_length,
                           int min_length,
                           std::vector<std::pair<int, ggml_tensor*>> image_embeds,
                           const std::set<int>& out_layers,
                           int prompt_template_encode_start_idx) {
    auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0);
    auto& tokens            = std::get<0>(tokens_and_weights);
    auto& weights           = std::get<1>(tokens_and_weights);

    struct ggml_tensor* hidden_states = nullptr;  // [N, n_token, hidden_size]

    auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);

    llm->compute(n_threads,
                 input_ids,
                 image_embeds,
                 out_layers,
                 &hidden_states,
                 work_ctx);

    // Everything below dereferences the output tensor; fail loudly if compute produced none.
    GGML_ASSERT(hidden_states != nullptr);
    // The weighting loop indexes `weights` by token position, so there must be one weight per position.
    GGML_ASSERT(static_cast<int64_t>(weights.size()) >= hidden_states->ne[1]);

    {
        const float original_mean = ggml_ext_tensor_mean(hidden_states);
        for (int64_t i2 = 0; i2 < hidden_states->ne[2]; i2++) {
            for (int64_t i1 = 0; i1 < hidden_states->ne[1]; i1++) {
                const float token_weight = weights[i1];
                for (int64_t i0 = 0; i0 < hidden_states->ne[0]; i0++) {
                    float value = ggml_ext_tensor_get_f32(hidden_states, i0, i1, i2);
                    value *= token_weight;
                    ggml_ext_tensor_set_f32(hidden_states, value, i0, i1, i2);
                }
            }
        }
        const float new_mean = ggml_ext_tensor_mean(hidden_states);
        // Restore the pre-weighting mean. Skip the rescale when the weighted mean is exactly
        // zero: dividing by it would fill the tensor with inf/NaN.
        if (new_mean != 0.f) {
            ggml_ext_tensor_scale_inplace(hidden_states, (original_mean / new_mean));
        }
    }

    GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);

    // Number of token positions left once the template prefix is removed.
    const int64_t kept_len = hidden_states->ne[1] - prompt_template_encode_start_idx;

    int64_t zero_pad_len = 0;
    if (min_length > 0 && kept_len < min_length) {
        zero_pad_len = min_length - kept_len;
    }

    ggml_tensor* new_hidden_states = ggml_new_tensor_3d(work_ctx,
                                                        GGML_TYPE_F32,
                                                        hidden_states->ne[0],
                                                        kept_len + zero_pad_len,
                                                        hidden_states->ne[2]);

    // Copy the kept positions shifted by the prefix length; positions past the end of the
    // source tensor are the zero padding.
    ggml_ext_tensor_iter(new_hidden_states, [&](ggml_tensor* dst, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
        float value = 0.f;
        if (i1 + prompt_template_encode_start_idx < hidden_states->ne[1]) {
            value = ggml_ext_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3);
        }
        ggml_ext_tensor_set_f32(dst, value, i0, i1, i2, i3);
    });

    return new_hidden_states;
}
|
||||
|
||||
SDCondition get_learned_condition(ggml_context* work_ctx,
|
||||
int n_threads,
|
||||
const ConditionerParams& conditioner_params) override {
|
||||
std::string prompt;
|
||||
std::vector<std::pair<int, ggml_tensor*>> image_embeds;
|
||||
std::pair<int, int> prompt_attn_range;
|
||||
std::vector<std::string> extra_prompts;
|
||||
std::vector<std::pair<int, int>> extra_prompts_attn_range;
|
||||
std::vector<std::pair<int, ggml_tensor*>> image_embeds;
|
||||
int prompt_template_encode_start_idx = 34;
|
||||
int max_length = 0;
|
||||
int min_length = 0;
|
||||
std::set<int> out_layers;
|
||||
if (llm->enable_vision && conditioner_params.ref_images.size() > 0) {
|
||||
|
||||
int64_t t0 = ggml_time_ms();
|
||||
|
||||
if (sd_version_is_qwen_image(version)) {
|
||||
if (llm->enable_vision && !conditioner_params.ref_images.empty() > 0) {
|
||||
LOG_INFO("QwenImageEditPlusPipeline");
|
||||
prompt_template_encode_start_idx = 64;
|
||||
int image_embed_idx = 64 + 6;
|
||||
@ -1771,6 +1853,17 @@ struct LLMEmbedder : public Conditioner {
|
||||
prompt_attn_range.second = static_cast<int>(prompt.size());
|
||||
|
||||
prompt += "<|im_end|>\n<|im_start|>assistant\n";
|
||||
} else {
|
||||
prompt_template_encode_start_idx = 34;
|
||||
|
||||
prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n";
|
||||
|
||||
prompt_attn_range.first = static_cast<int>(prompt.size());
|
||||
prompt += conditioner_params.text;
|
||||
prompt_attn_range.second = static_cast<int>(prompt.size());
|
||||
|
||||
prompt += "<|im_end|>\n<|im_start|>assistant\n";
|
||||
}
|
||||
} else if (sd_version_is_flux2(version)) {
|
||||
prompt_template_encode_start_idx = 0;
|
||||
out_layers = {10, 20, 30};
|
||||
@ -1786,6 +1879,15 @@ struct LLMEmbedder : public Conditioner {
|
||||
prompt_template_encode_start_idx = 0;
|
||||
out_layers = {35}; // -2
|
||||
|
||||
if (!conditioner_params.ref_images.empty()) {
|
||||
LOG_INFO("ZImageOmniPipeline");
|
||||
prompt = "<|im_start|>user\n<|vision_start|>";
|
||||
for (int i = 0; i < conditioner_params.ref_images.size() - 1; i++) {
|
||||
extra_prompts.push_back("<|vision_end|><|vision_start|>");
|
||||
}
|
||||
extra_prompts.push_back("<|vision_end|>" + conditioner_params.text + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>");
|
||||
extra_prompts.push_back("<|vision_end|><|im_end|>");
|
||||
} else {
|
||||
prompt = "<|im_start|>user\n";
|
||||
|
||||
prompt_attn_range.first = static_cast<int>(prompt.size());
|
||||
@ -1793,6 +1895,7 @@ struct LLMEmbedder : public Conditioner {
|
||||
prompt_attn_range.second = static_cast<int>(prompt.size());
|
||||
|
||||
prompt += "<|im_end|>\n<|im_start|>assistant\n";
|
||||
}
|
||||
} else if (sd_version_is_flux2(version)) {
|
||||
prompt_template_encode_start_idx = 0;
|
||||
out_layers = {10, 20, 30};
|
||||
@ -1804,6 +1907,8 @@ struct LLMEmbedder : public Conditioner {
|
||||
prompt_attn_range.second = prompt.size();
|
||||
|
||||
prompt += "[/INST]";
|
||||
|
||||
min_length = 512;
|
||||
} else if (version == VERSION_OVIS_IMAGE) {
|
||||
prompt_template_encode_start_idx = 28;
|
||||
max_length = prompt_template_encode_start_idx + 256;
|
||||
@ -1816,81 +1921,36 @@ struct LLMEmbedder : public Conditioner {
|
||||
|
||||
prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
|
||||
} else {
|
||||
prompt_template_encode_start_idx = 34;
|
||||
|
||||
prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n";
|
||||
|
||||
prompt_attn_range.first = static_cast<int>(prompt.size());
|
||||
prompt += conditioner_params.text;
|
||||
prompt_attn_range.second = static_cast<int>(prompt.size());
|
||||
|
||||
prompt += "<|im_end|>\n<|im_start|>assistant\n";
|
||||
GGML_ABORT("unknown version %d", version);
|
||||
}
|
||||
|
||||
auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0);
|
||||
auto& tokens = std::get<0>(tokens_and_weights);
|
||||
auto& weights = std::get<1>(tokens_and_weights);
|
||||
|
||||
int64_t t0 = ggml_time_ms();
|
||||
struct ggml_tensor* hidden_states = nullptr; // [N, n_token, 3584]
|
||||
|
||||
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
|
||||
|
||||
llm->compute(n_threads,
|
||||
input_ids,
|
||||
auto hidden_states = encode_prompt(work_ctx,
|
||||
n_threads,
|
||||
prompt,
|
||||
prompt_attn_range,
|
||||
max_length,
|
||||
min_length,
|
||||
image_embeds,
|
||||
out_layers,
|
||||
&hidden_states,
|
||||
work_ctx);
|
||||
{
|
||||
auto tensor = hidden_states;
|
||||
float original_mean = ggml_ext_tensor_mean(tensor);
|
||||
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
|
||||
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
|
||||
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
|
||||
float value = ggml_ext_tensor_get_f32(tensor, i0, i1, i2);
|
||||
value *= weights[i1];
|
||||
ggml_ext_tensor_set_f32(tensor, value, i0, i1, i2);
|
||||
}
|
||||
}
|
||||
}
|
||||
float new_mean = ggml_ext_tensor_mean(tensor);
|
||||
ggml_ext_tensor_scale_inplace(tensor, (original_mean / new_mean));
|
||||
}
|
||||
prompt_template_encode_start_idx);
|
||||
|
||||
GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);
|
||||
|
||||
int64_t min_length = 0;
|
||||
if (sd_version_is_flux2(version)) {
|
||||
min_length = 512;
|
||||
std::vector<ggml_tensor*> extra_hidden_states_vec;
|
||||
for (int i = 0; i < extra_prompts.size(); i++) {
|
||||
auto extra_hidden_states = encode_prompt(work_ctx,
|
||||
n_threads,
|
||||
extra_prompts[i],
|
||||
extra_prompts_attn_range[i],
|
||||
max_length,
|
||||
min_length,
|
||||
image_embeds,
|
||||
out_layers,
|
||||
prompt_template_encode_start_idx);
|
||||
extra_hidden_states_vec.push_back(extra_hidden_states);
|
||||
}
|
||||
|
||||
int64_t zero_pad_len = 0;
|
||||
if (min_length > 0) {
|
||||
if (hidden_states->ne[1] - prompt_template_encode_start_idx < min_length) {
|
||||
zero_pad_len = min_length - hidden_states->ne[1] + prompt_template_encode_start_idx;
|
||||
}
|
||||
}
|
||||
|
||||
ggml_tensor* new_hidden_states = ggml_new_tensor_3d(work_ctx,
|
||||
GGML_TYPE_F32,
|
||||
hidden_states->ne[0],
|
||||
hidden_states->ne[1] - prompt_template_encode_start_idx + zero_pad_len,
|
||||
hidden_states->ne[2]);
|
||||
|
||||
ggml_ext_tensor_iter(new_hidden_states, [&](ggml_tensor* new_hidden_states, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
||||
float value = 0.f;
|
||||
if (i1 + prompt_template_encode_start_idx < hidden_states->ne[1]) {
|
||||
value = ggml_ext_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3);
|
||||
}
|
||||
ggml_ext_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3);
|
||||
});
|
||||
|
||||
// print_ggml_tensor(new_hidden_states);
|
||||
|
||||
int64_t t1 = ggml_time_ms();
|
||||
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
|
||||
return {new_hidden_states, nullptr, nullptr};
|
||||
return {hidden_states, nullptr, nullptr, extra_hidden_states_vec};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -23,6 +23,7 @@ struct DiffusionParams {
|
||||
struct ggml_tensor* vace_context = nullptr;
|
||||
float vace_strength = 1.f;
|
||||
std::vector<int> skip_layers = {};
|
||||
std::vector<struct ggml_tensor*> extra_contexts; // for z-image-omni
|
||||
};
|
||||
|
||||
struct DiffusionModel {
|
||||
@ -436,10 +437,12 @@ struct ZImageModel : public DiffusionModel {
|
||||
DiffusionParams diffusion_params,
|
||||
struct ggml_tensor** output = nullptr,
|
||||
struct ggml_context* output_ctx = nullptr) override {
|
||||
std::vector<ggml_tensor*> contexts = {diffusion_params.context};
|
||||
contexts.insert(contexts.end(), diffusion_params.extra_contexts.begin(), diffusion_params.extra_contexts.end());
|
||||
return z_image.compute(n_threads,
|
||||
diffusion_params.x,
|
||||
diffusion_params.timesteps,
|
||||
{diffusion_params.context},
|
||||
contexts,
|
||||
diffusion_params.ref_latents,
|
||||
{},
|
||||
output,
|
||||
|
||||
@ -1932,6 +1932,7 @@ public:
|
||||
if (start_merge_step == -1 || step <= start_merge_step) {
|
||||
// cond
|
||||
diffusion_params.context = cond.c_crossattn;
|
||||
diffusion_params.extra_contexts = cond.extra_c_crossattns;
|
||||
diffusion_params.c_concat = cond.c_concat;
|
||||
diffusion_params.y = cond.c_vector;
|
||||
active_condition = &cond;
|
||||
@ -1968,6 +1969,7 @@ public:
|
||||
current_step_skipped = cache_step_is_skipped();
|
||||
diffusion_params.controls = controls;
|
||||
diffusion_params.context = uncond.c_crossattn;
|
||||
diffusion_params.extra_contexts = uncond.extra_c_crossattns;
|
||||
diffusion_params.c_concat = uncond.c_concat;
|
||||
diffusion_params.y = uncond.c_vector;
|
||||
bool skip_uncond = cache_before_condition(&uncond, out_uncond);
|
||||
@ -1986,6 +1988,7 @@ public:
|
||||
float* img_cond_data = nullptr;
|
||||
if (has_img_cond) {
|
||||
diffusion_params.context = img_cond.c_crossattn;
|
||||
diffusion_params.extra_contexts = img_cond.extra_c_crossattns;
|
||||
diffusion_params.c_concat = img_cond.c_concat;
|
||||
diffusion_params.y = img_cond.c_vector;
|
||||
bool skip_img_cond = cache_before_condition(&img_cond, out_img_cond);
|
||||
|
||||
@ -644,7 +644,7 @@ namespace ZImage {
|
||||
t_clean = t_embedder->forward(ctx,
|
||||
ggml_scale(ctx->ggml_ctx,
|
||||
ggml_ext_ones(ctx->ggml_ctx, timestep->ne[0], timestep->ne[1], timestep->ne[2], timestep->ne[3]),
|
||||
0.f));
|
||||
1000.f));
|
||||
} else {
|
||||
t_emb = t_embedder->forward(ctx, timestep);
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user