mirror of https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-13 05:48:56 +00:00

add qwen image edit support

This commit is contained in:
parent 58e81adf61
commit 40752b629f

conditioner.hpp (212 changed lines)
@@ -15,28 +15,28 @@ struct SDCondition {
         : c_crossattn(c_crossattn), c_vector(c_vector), c_concat(c_concat) {}
 };
 
+struct ConditionerParams {
+    std::string text;
+    int clip_skip            = -1;
+    int width                = -1;
+    int height               = -1;
+    int adm_in_channels      = -1;
+    bool zero_out_masked     = false;
+    int num_input_imgs       = 0;   // for photomaker
+    std::vector<sd_image_t*> ref_images = {};  // for qwen image edit
+};
+
 struct Conditioner {
     virtual SDCondition get_learned_condition(ggml_context* work_ctx,
                                               int n_threads,
-                                              const std::string& text,
-                                              int clip_skip,
-                                              int width,
-                                              int height,
-                                              int adm_in_channels  = -1,
-                                              bool zero_out_masked = false) = 0;
+                                              const ConditionerParams& conditioner_params) = 0;
     virtual void alloc_params_buffer()                                                     = 0;
     virtual void free_params_buffer()                                                      = 0;
     virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors)    = 0;
     virtual size_t get_params_buffer_size()                                                = 0;
     virtual std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
                                                                                           int n_threads,
-                                                                                          const std::string& text,
-                                                                                          int clip_skip,
-                                                                                          int width,
-                                                                                          int height,
-                                                                                          int num_input_imgs,
-                                                                                          int adm_in_channels  = -1,
-                                                                                          bool zero_out_masked = false) {
+                                                                                          const ConditionerParams& conditioner_params) {
         GGML_ABORT("Not implemented yet!");
     }
     virtual std::string remove_trigger_from_prompt(ggml_context* work_ctx,
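
The practical effect of this hunk is that per-call conditioning options (prompt text, clip skip, target size, reference images for Qwen image edit, and so on) now travel in a single ConditionerParams struct instead of a growing argument list. A minimal usage sketch, mirroring the call sites updated later in this commit (illustrative only, not part of the patch):

    ConditionerParams condition_params;
    condition_params.text            = prompt;
    condition_params.clip_skip       = clip_skip;
    condition_params.width           = width;
    condition_params.height          = height;
    condition_params.ref_images      = ref_images;  // only consumed by the Qwen image edit path
    condition_params.adm_in_channels = diffusion_model->get_adm_in_channels();

    SDCondition cond = cond_stage_model->get_learned_condition(work_ctx, n_threads, condition_params);
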
@@ -555,20 +555,14 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
     std::tuple<SDCondition, std::vector<bool>>
     get_learned_condition_with_trigger(ggml_context* work_ctx,
                                        int n_threads,
-                                       const std::string& text,
-                                       int clip_skip,
-                                       int width,
-                                       int height,
-                                       int num_input_imgs,
-                                       int adm_in_channels  = -1,
-                                       bool zero_out_masked = false) {
+                                       const ConditionerParams& conditioner_params) {
         auto image_tokens = convert_token_to_id(trigger_word);
         // if(image_tokens.size() == 1){
         //     printf(" image token id is: %d \n", image_tokens[0]);
         // }
         GGML_ASSERT(image_tokens.size() == 1);
-        auto tokens_and_weights = tokenize_with_trigger_token(text,
-                                                              num_input_imgs,
+        auto tokens_and_weights = tokenize_with_trigger_token(conditioner_params.text,
+                                                              conditioner_params.num_input_imgs,
                                                               image_tokens[0],
                                                               true);
         std::vector<int>& tokens = std::get<0>(tokens_and_weights);
@@ -582,7 +576,15 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
         // for(int i = 0; i < clsm.size(); ++i)
         //     printf("%d ", clsm[i]?1:0);
         // printf("\n");
-        auto cond = get_learned_condition_common(work_ctx, n_threads, tokens, weights, clip_skip, width, height, adm_in_channels, zero_out_masked);
+        auto cond = get_learned_condition_common(work_ctx,
+                                                 n_threads,
+                                                 tokens,
+                                                 weights,
+                                                 conditioner_params.clip_skip,
+                                                 conditioner_params.width,
+                                                 conditioner_params.height,
+                                                 conditioner_params.adm_in_channels,
+                                                 conditioner_params.zero_out_masked);
         return std::make_tuple(cond, clsm);
     }
 
@@ -600,16 +602,19 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
 
     SDCondition get_learned_condition(ggml_context* work_ctx,
                                       int n_threads,
-                                      const std::string& text,
-                                      int clip_skip,
-                                      int width,
-                                      int height,
-                                      int adm_in_channels  = -1,
-                                      bool zero_out_masked = false) {
-        auto tokens_and_weights     = tokenize(text, true);
+                                      const ConditionerParams& conditioner_params) {
+        auto tokens_and_weights     = tokenize(conditioner_params.text, true);
         std::vector<int>& tokens    = tokens_and_weights.first;
         std::vector<float>& weights = tokens_and_weights.second;
-        return get_learned_condition_common(work_ctx, n_threads, tokens, weights, clip_skip, width, height, adm_in_channels, zero_out_masked);
+        return get_learned_condition_common(work_ctx,
+                                            n_threads,
+                                            tokens,
+                                            weights,
+                                            conditioner_params.clip_skip,
+                                            conditioner_params.width,
+                                            conditioner_params.height,
+                                            conditioner_params.adm_in_channels,
+                                            conditioner_params.zero_out_masked);
     }
 };
 
@@ -974,14 +979,13 @@ struct SD3CLIPEmbedder : public Conditioner {
 
     SDCondition get_learned_condition(ggml_context* work_ctx,
                                       int n_threads,
-                                      const std::string& text,
-                                      int clip_skip,
-                                      int width,
-                                      int height,
-                                      int adm_in_channels  = -1,
-                                      bool zero_out_masked = false) {
-        auto tokens_and_weights = tokenize(text, 77, true);
-        return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
+                                      const ConditionerParams& conditioner_params) {
+        auto tokens_and_weights = tokenize(conditioner_params.text, 77, true);
+        return get_learned_condition_common(work_ctx,
+                                            n_threads,
+                                            tokens_and_weights,
+                                            conditioner_params.clip_skip,
+                                            conditioner_params.zero_out_masked);
     }
 };
 
@@ -1174,14 +1178,13 @@ struct FluxCLIPEmbedder : public Conditioner {
 
     SDCondition get_learned_condition(ggml_context* work_ctx,
                                       int n_threads,
-                                      const std::string& text,
-                                      int clip_skip,
-                                      int width,
-                                      int height,
-                                      int adm_in_channels  = -1,
-                                      bool zero_out_masked = false) {
-        auto tokens_and_weights = tokenize(text, chunk_len, true);
-        return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
+                                      const ConditionerParams& conditioner_params) {
+        auto tokens_and_weights = tokenize(conditioner_params.text, chunk_len, true);
+        return get_learned_condition_common(work_ctx,
+                                            n_threads,
+                                            tokens_and_weights,
+                                            conditioner_params.clip_skip,
+                                            conditioner_params.zero_out_masked);
     }
 };
 
@@ -1360,14 +1363,13 @@ struct T5CLIPEmbedder : public Conditioner {
 
     SDCondition get_learned_condition(ggml_context* work_ctx,
                                       int n_threads,
-                                      const std::string& text,
-                                      int clip_skip,
-                                      int width,
-                                      int height,
-                                      int adm_in_channels  = -1,
-                                      bool zero_out_masked = false) {
-        auto tokens_and_weights = tokenize(text, chunk_len, true);
-        return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
+                                      const ConditionerParams& conditioner_params) {
+        auto tokens_and_weights = tokenize(conditioner_params.text, chunk_len, true);
+        return get_learned_condition_common(work_ctx,
+                                            n_threads,
+                                            tokens_and_weights,
+                                            conditioner_params.clip_skip,
+                                            conditioner_params.zero_out_masked);
     }
 };
 
@@ -1379,8 +1381,13 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
     Qwen2_5_VLCLIPEmbedder(ggml_backend_t backend,
                            bool offload_params_to_cpu,
                            const String2GGMLType& tensor_types = {},
-                           const std::string prefix            = "") {
-        qwenvl = std::make_shared<Qwen::Qwen2_5_VLRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.qwen2vl");
+                           const std::string prefix            = "",
+                           bool enable_vision                  = false) {
+        qwenvl = std::make_shared<Qwen::Qwen2_5_VLRunner>(backend,
+                                                          offload_params_to_cpu,
+                                                          tensor_types,
+                                                          "text_encoders.qwen2vl",
+                                                          enable_vision);
     }
 
     void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
@@ -1436,13 +1443,78 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
         return {tokens, weights};
     }
 
-    SDCondition get_learned_condition_common(ggml_context* work_ctx,
+    SDCondition get_learned_condition(ggml_context* work_ctx,
                                       int n_threads,
-                                              std::tuple<std::vector<int>, std::vector<float>> token_and_weights,
-                                              int clip_skip,
-                                              bool zero_out_masked = false) {
-        auto& tokens  = std::get<0>(token_and_weights);
-        auto& weights = std::get<1>(token_and_weights);
+                                      const ConditionerParams& conditioner_params) {
+        std::string prompt;
+        std::vector<std::pair<int, ggml_tensor*>> image_embeds;
+        if (qwenvl->enable_vision && conditioner_params.ref_images.size() > 0) {
+            LOG_INFO("QwenImageEditPlusPipeline");
+            prompt_template_encode_start_idx = 64;
+            int image_embed_idx              = 64 + 6;
+
+            int min_pixels          = 56 * 56;
+            int max_pixels          = 560 * 560;
+            std::string placeholder = "<|image_pad|>";
+            std::string img_prompt;
+
+            for (int i = 0; i < conditioner_params.ref_images.size(); i++) {
+                sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]);
+                double factor        = qwenvl->params.vision.patch_size * qwenvl->params.vision.spatial_merge_size;
+                int height           = image.height;
+                int width            = image.width;
+                int h_bar            = static_cast<int>(std::round(height / factor)) * factor;
+                int w_bar            = static_cast<int>(std::round(width / factor)) * factor;
+
+                if (static_cast<double>(h_bar) * w_bar > max_pixels) {
+                    double beta = std::sqrt((height * width) / static_cast<double>(max_pixels));
+                    h_bar       = std::max(static_cast<int>(factor),
+                                           static_cast<int>(std::floor(height / beta / factor)) * static_cast<int>(factor));
+                    w_bar       = std::max(static_cast<int>(factor),
+                                           static_cast<int>(std::floor(width / beta / factor)) * static_cast<int>(factor));
+                } else if (static_cast<double>(h_bar) * w_bar < min_pixels) {
+                    double beta = std::sqrt(static_cast<double>(min_pixels) / (height * width));
+                    h_bar       = static_cast<int>(std::ceil(height * beta / factor)) * static_cast<int>(factor);
+                    w_bar       = static_cast<int>(std::ceil(width * beta / factor)) * static_cast<int>(factor);
+                }
+
+                LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, image.height, image.width, h_bar, w_bar);
+
+                sd_image_f32_t resized_image = clip_preprocess(image, w_bar, h_bar);
+                free(image.data);
+                image.data = nullptr;
+
+                ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
+                sd_image_f32_to_tensor(resized_image.data, image_tensor, false);
+                free(resized_image.data);
+                resized_image.data = nullptr;
+
+                ggml_tensor* image_embed = nullptr;
+                qwenvl->encode_image(n_threads, image_tensor, &image_embed, work_ctx);
+                image_embeds.emplace_back(image_embed_idx, image_embed);
+                image_embed_idx += 1 + image_embed->ne[1] + 6;
+
+                img_prompt += "Picture " + std::to_string(i + 1) + ": <|vision_start|>";  // [24669, 220, index, 25, 220, 151652]
+                int64_t num_image_tokens = image_embed->ne[1];
+                img_prompt.reserve(num_image_tokens * placeholder.size());
+                for (int j = 0; j < num_image_tokens; j++) {
+                    img_prompt += placeholder;
+                }
+                img_prompt += "<|vision_end|>";
+            }
+
+            prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n";
+
+            prompt += img_prompt;
+            prompt += conditioner_params.text;
+            prompt += "<|im_end|>\n<|im_start|>assistant\n";
+        } else {
+            prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n" + conditioner_params.text + "<|im_end|>\n<|im_start|>assistant\n";
+        }
+
+        auto tokens_and_weights = tokenize(prompt, 0, false);
+        auto& tokens            = std::get<0>(tokens_and_weights);
+        auto& weights           = std::get<1>(tokens_and_weights);
 
         int64_t t0                        = ggml_time_ms();
         struct ggml_tensor* hidden_states = NULL;  // [N, n_token, 3584]
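
To make the resize arithmetic in the hunk above concrete: each reference image is snapped to a multiple of factor = patch_size * spatial_merge_size and its area is clamped to [min_pixels, max_pixels] = [56*56, 560*560]. Assuming the usual Qwen2.5-VL vision settings of patch_size = 14 and spatial_merge_size = 2, so factor = 28 (an assumption, not stated in this hunk): a 640x480 reference rounds to 644x476 (306,544 px, already inside the bounds, no rescale), while a 1024x1024 reference first rounds to 1036x1036, exceeds max_pixels, and gets beta = sqrt(1024*1024 / 313600) = 1024/560, so both sides become floor(1024 / beta / 28) * 28 = 560, exactly the 560x560 cap.
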
@@ -1451,6 +1523,7 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
 
         qwenvl->compute(n_threads,
                         input_ids,
+                        image_embeds,
                         &hidden_states,
                         work_ctx);
         {
@@ -1486,19 +1559,6 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
         LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
         return SDCondition(new_hidden_states, nullptr, nullptr);
     }
-
-    SDCondition get_learned_condition(ggml_context* work_ctx,
-                                      int n_threads,
-                                      const std::string& text,
-                                      int clip_skip,
-                                      int width,
-                                      int height,
-                                      int adm_in_channels  = -1,
-                                      bool zero_out_masked = false) {
-        std::string prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n" + text + "<|im_end|>\n<|im_start|>assistant\n";
-        auto tokens_and_weights = tokenize(prompt, 0, false);
-        return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
-    }
 };
 
 #endif
@@ -1144,10 +1144,6 @@ bool load_images_from_dir(const std::string dir,
 
 int main(int argc, const char* argv[]) {
     SDParams params;
-    params.verbose = true;
-    sd_set_log_callback(sd_log_cb, (void*)&params);
-    Qwen::Qwen2_5_VLEmbedder::load_from_file_and_test(argv[1]);
-    return 1;
     parse_args(argc, argv, params);
     params.sample_params.guidance.slg.layers      = params.skip_layers.data();
     params.sample_params.guidance.slg.layer_count = params.skip_layers.size();
@@ -113,7 +113,6 @@ const char* unused_tensors[] = {
     "text_encoders.t5xxl.transformer.encoder.embed_tokens.weight",  // only used during training
     "text_encoders.qwen2vl.output.weight",
     "text_encoders.qwen2vl.lm_head.",
-    "text_encoders.qwen2vl.visual.",
 };
 
 bool is_unused_tensor(std::string name) {
qwenvl.hpp (153 changed lines)

@@ -692,7 +692,8 @@ namespace Qwen {
        struct ggml_tensor* forward(struct ggml_context* ctx,
                                    ggml_backend_t backend,
                                    struct ggml_tensor* input_ids,
-                                   struct ggml_tensor* input_pos) {
+                                   struct ggml_tensor* input_pos,
+                                   std::vector<std::pair<int, ggml_tensor*>> image_embeds) {
            // input_ids: [N, n_token]
            // return: [N, n_token, hidden_size]
 
@@ -701,6 +702,46 @@ namespace Qwen {
 
            auto x = embed_tokens->forward(ctx, input_ids);
 
+           if (image_embeds.size() > 0) {
+               GGML_ASSERT(x->ne[2] == 1);  // N == 1
+
+               auto raw_x              = ggml_cast(ctx, x, image_embeds[0].second->type);
+               int64_t txt_token_start = 0;
+               int64_t txt_token_end   = 0;
+
+               ggml_tensor* input_embed = nullptr;
+
+               for (int i = 0; i < image_embeds.size(); i++) {
+                   if (i == 0) {
+                       txt_token_start = 0;
+                   } else {
+                       txt_token_start = image_embeds[i - 1].first + image_embeds[i - 1].second->ne[1];
+                   }
+                   txt_token_end = image_embeds[i].first;
+
+                   auto txt_embed = ggml_slice(ctx, raw_x, 1, txt_token_start, txt_token_end);
+                   if (input_embed == nullptr) {
+                       input_embed = txt_embed;
+                   } else {
+                       input_embed = ggml_concat(ctx, input_embed, txt_embed, 1);
+                   }
+
+                   auto image_embed = image_embeds[i].second;
+                   input_embed      = ggml_concat(ctx, input_embed, image_embed, 1);
+               }
+
+               auto final_txt_embed = ggml_slice(ctx,
+                                                 raw_x,
+                                                 1,
+                                                 image_embeds[image_embeds.size() - 1].first + image_embeds[image_embeds.size() - 1].second->ne[1],
+                                                 raw_x->ne[1]);
+
+               input_embed = ggml_concat(ctx, input_embed, final_txt_embed, 1);
+               GGML_ASSERT(raw_x->ne[1] == input_embed->ne[1]);
+
+               x = input_embed;
+           }
+
            for (int i = 0; i < num_layers; i++) {
                auto block = std::dynamic_pointer_cast<Qwen2_5_VLBlock>(blocks["layers." + std::to_string(i)]);
 
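
The splice above is the model-side counterpart of the prompt built in conditioner.hpp: each entry of image_embeds is a (token offset, embedding) pair, and the text model re-assembles the sequence as text slice, image embedding, text slice, and so on, asserting at the end that the total token count is unchanged. The offsets follow the commit's own constants; an illustrative layout for a single reference image (offsets only, token counts per the inline comments in the diff):

    // [  0 ..  63]   system template                 (prompt_template_encode_start_idx = 64)
    // [ 64 ..  69]   "Picture 1: <|vision_start|>"   (6 tokens)
    // [ 70 .. 70+n)  n = image_embed->ne[1] <|image_pad|> placeholders, replaced by image_embeds[0] = {70, embed}
    // [70+n]         <|vision_end|>, followed by the user instruction and the assistant turn
    // a second image would start at 70 + n + 1 + 6, matching image_embed_idx += 1 + image_embed->ne[1] + 6
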
@@ -770,11 +811,12 @@ namespace Qwen {
        struct ggml_tensor* forward(struct ggml_context* ctx,
                                    ggml_backend_t backend,
                                    struct ggml_tensor* input_ids,
-                                   struct ggml_tensor* input_pos) {
+                                   struct ggml_tensor* input_pos,
+                                   std::vector<std::pair<int, ggml_tensor*>> image_embeds) {
            // input_ids: [N, n_token]
            auto model = std::dynamic_pointer_cast<Qwen2_5_VLTextModel>(blocks["model"]);
 
-           auto x = model->forward(ctx, backend, input_ids, input_pos);
+           auto x = model->forward(ctx, backend, input_ids, input_pos, image_embeds);
            return x;
        }
 
@@ -793,6 +835,7 @@ namespace Qwen {
 
    struct Qwen2_5_VLRunner : public GGMLRunner {
        Qwen2_5_VLParams params;
+       bool enable_vision;
        Qwen2_5_VL model;
 
        std::vector<int> input_pos_vec;
@@ -805,8 +848,27 @@ namespace Qwen {
                         bool offload_params_to_cpu,
                         const String2GGMLType& tensor_types,
                         const std::string prefix,
-                        bool enable_vision = false)
-           : GGMLRunner(backend, offload_params_to_cpu), model(params, enable_vision) {
+                        bool enable_vision_ = false)
+           : GGMLRunner(backend, offload_params_to_cpu), enable_vision(enable_vision_) {
+           bool have_vision_weight = false;
+           for (auto pair : tensor_types) {
+               std::string tensor_name = pair.first;
+               if (tensor_name.find(prefix) == std::string::npos)
+                   continue;
+               size_t pos = tensor_name.find("visual.");
+               if (pos != std::string::npos) {
+                   have_vision_weight = true;
+                   break;
+               }
+           }
+           if (enable_vision && !have_vision_weight) {
+               LOG_WARN("no vision weights detected, vision disabled");
+               enable_vision = false;
+           }
+           if (enable_vision) {
+               LOG_DEBUG("enable qwen2vl vision");
+           }
+           model = Qwen2_5_VL(params, enable_vision);
            model.init(params_ctx, tensor_types, prefix);
        }
 
@@ -821,8 +883,9 @@ namespace Qwen {
        struct ggml_tensor* forward(struct ggml_context* ctx,
                                    ggml_backend_t backend,
                                    struct ggml_tensor* input_ids,
-                                   struct ggml_tensor* input_pos) {
-           auto hidden_states = model.forward(ctx, backend, input_ids, input_pos);  // [N, n_token, hidden_size]
+                                   struct ggml_tensor* input_pos,
+                                   std::vector<std::pair<int, ggml_tensor*>> image_embeds) {
+           auto hidden_states = model.forward(ctx, backend, input_ids, input_pos, image_embeds);  // [N, n_token, hidden_size]
            return hidden_states;
        }
 
@@ -837,11 +900,15 @@ namespace Qwen {
            return hidden_states;
        }
 
-       struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids) {
+       struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids, std::vector<std::pair<int, ggml_tensor*>> image_embeds) {
            struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
 
            input_ids = to_backend(input_ids);
+
+           for (auto& image_embed : image_embeds) {
+               image_embed.second = to_backend(image_embed.second);
+           }
 
            int64_t n_tokens = input_ids->ne[0];
            input_pos_vec.resize(n_tokens * 4);
            for (int i = 0; i < n_tokens; ++i) {
@@ -856,7 +923,7 @@ namespace Qwen {
                                                    n_tokens * 4);
            set_backend_tensor_data(input_pos, input_pos_vec.data());
 
-           struct ggml_tensor* hidden_states = forward(compute_ctx, runtime_backend, input_ids, input_pos);
+           struct ggml_tensor* hidden_states = forward(compute_ctx, runtime_backend, input_ids, input_pos, image_embeds);
 
            ggml_build_forward_expand(gf, hidden_states);
 
@@ -865,14 +932,24 @@ namespace Qwen {
 
        void compute(const int n_threads,
                     struct ggml_tensor* input_ids,
+                    std::vector<std::pair<int, ggml_tensor*>> image_embeds,
                     ggml_tensor** output,
                     ggml_context* output_ctx = NULL) {
            auto get_graph = [&]() -> struct ggml_cgraph* {
-               return build_graph(input_ids);
+               return build_graph(input_ids, image_embeds);
            };
            GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
        }
 
+       int64_t get_num_image_tokens(int64_t t, int64_t h, int64_t w) {
+           int grid_t     = 1;
+           int grid_h     = h / params.vision.patch_size;
+           int grid_w     = w / params.vision.patch_size;
+           int llm_grid_h = grid_h / params.vision.spatial_merge_size;
+           int llm_grid_w = grid_w / params.vision.spatial_merge_size;
+           return grid_t * grid_h * grid_w;
+       }
+
        struct ggml_tensor* process_image(struct ggml_context* ctx, struct ggml_tensor* image) {
            // image: [C, H, W]
            // return: [grid_t*(H/mh/ph)*(W/mw/pw)*mh*mw, C*pt*ph*pw], grid_t == 1
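
As written, the new get_num_image_tokens helper returns the raw patch count grid_t * grid_h * grid_w; the llm_grid_h / llm_grid_w values are computed but do not enter the return value. A quick worked example, assuming patch_size = 14 as in Qwen2.5-VL (an assumption, not stated in this hunk): a 560x560 image gives grid_h = grid_w = 560 / 14 = 40, so the helper reports 1 * 40 * 40 = 1600 tokens.
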
@@ -1030,7 +1107,7 @@ namespace Qwen {
            auto get_graph = [&]() -> struct ggml_cgraph* {
                return build_encode_image_graph(image);
            };
-           GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
+           GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
        }
    };
 
@@ -1098,8 +1175,58 @@ namespace Qwen {
        struct ggml_context* work_ctx = ggml_init(params);
        GGML_ASSERT(work_ctx != NULL);
        bool test_vit              = true;
+       bool test_decoder_with_vit = true;
 
-       if (test_vit) {
+       if (test_decoder_with_vit) {
+           ggml_tensor* image_embed = nullptr;
+           {
+               auto image = load_tensor_from_file(work_ctx, "qwen2vl_normalized.bin");
+               print_ggml_tensor(image, false, "image");
+               struct ggml_tensor* out = NULL;
+
+               int t0 = ggml_time_ms();
+               model.encode_image(8, image, &out, work_ctx);
+               int t1 = ggml_time_ms();
+
+               print_ggml_tensor(out, false, "image_embed");
+               image_embed = out;
+               LOG_DEBUG("qwen2vl encode_image test done in %dms", t1 - t0);
+           }
+
+           std::string placeholder = "<|image_pad|>";
+           std::string img_prompt  = "Picture 1: <|vision_start|>";  // [24669, 220, 16, 25, 220, 151652]
+           int64_t num_image_tokens = image_embed->ne[1];
+           img_prompt.reserve(num_image_tokens * placeholder.size());
+           for (int i = 0; i < num_image_tokens; i++) {
+               img_prompt += placeholder;
+           }
+           img_prompt += "<|vision_end|>";
+
+           std::vector<std::pair<int, ggml_tensor*>> image_embeds;
+           image_embeds.emplace_back(64, image_embed);
+
+           std::string text = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n";
+           text += img_prompt;
+           text += "change 'flux.cpp' to 'edit.cpp'";
+           text += "<|im_end|>\n<|im_start|>assistant\n";
+
+           auto tokens_and_weights     = tokenize(text, 0, false);
+           std::vector<int>& tokens    = std::get<0>(tokens_and_weights);
+           std::vector<float>& weights = std::get<1>(tokens_and_weights);
+           for (auto token : tokens) {
+               printf("%d ", token);
+           }
+           printf("\n");
+           auto input_ids          = vector_to_ggml_tensor_i32(work_ctx, tokens);
+           struct ggml_tensor* out = NULL;
+
+           int t0 = ggml_time_ms();
+           model.compute(8, input_ids, image_embeds, &out, work_ctx);
+           int t1 = ggml_time_ms();
+
+           print_ggml_tensor(out);
+           LOG_DEBUG("qwen2vl test done in %dms", t1 - t0);
+       } else if (test_vit) {
            // auto image = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 280, 280, 3);
            // ggml_set_f32(image, 0.f);
            auto image = load_tensor_from_file(work_ctx, "qwen2vl_normalized.bin");
@@ -1129,7 +1256,7 @@ namespace Qwen {
            struct ggml_tensor* out = NULL;
 
            int t0 = ggml_time_ms();
-           model.compute(8, input_ids, &out, work_ctx);
+           model.compute(8, input_ids, {}, &out, work_ctx);
            int t1 = ggml_time_ms();
 
            print_ggml_tensor(out);
@@ -272,6 +272,15 @@ public:
            return false;
        }
 
+       auto& tensor_types = model_loader.tensor_storages_types;
+       for (auto& item : tensor_types) {
+           // LOG_DEBUG("%s %u", item.first.c_str(), item.second);
+           if (contains(item.first, "qwen2vl") && ends_with(item.first, "weight") && (item.second == GGML_TYPE_F32 || item.second == GGML_TYPE_BF16)) {
+               item.second = GGML_TYPE_F16;
+               // LOG_DEBUG(" change %s %u", item.first.c_str(), item.second);
+           }
+       }
+
        LOG_INFO("Version: %s ", model_version_to_str[version]);
        ggml_type wtype = (int)sd_ctx_params->wtype < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT)
                              ? (ggml_type)sd_ctx_params->wtype
@@ -420,9 +429,15 @@ public:
                clip_vision->get_param_tensors(tensors);
            }
        } else if (sd_version_is_qwen_image(version)) {
+           bool enable_vision = false;
+           if (!vae_decode_only) {
+               enable_vision = true;
+           }
            cond_stage_model = std::make_shared<Qwen2_5_VLCLIPEmbedder>(clip_backend,
                                                                        offload_params_to_cpu,
-                                                                       model_loader.tensor_storages_types);
+                                                                       model_loader.tensor_storages_types,
+                                                                       "",
+                                                                       enable_vision);
            diffusion_model = std::make_shared<QwenImageModel>(backend,
                                                               offload_params_to_cpu,
                                                               model_loader.tensor_storages_types,
@@ -594,6 +609,7 @@ public:
        if (vae_decode_only) {
            ignore_tensors.insert("first_stage_model.encoder");
            ignore_tensors.insert("first_stage_model.quant");
+           ignore_tensors.insert("text_encoders.qwen2vl.visual.");
        }
        if (version == VERSION_SVD) {
            ignore_tensors.insert("conditioner.embedders.3");
@@ -1977,6 +1993,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
                                    sd_image_t control_image,
                                    float control_strength,
                                    sd_pm_params_t pm_params,
+                                   std::vector<sd_image_t*> ref_images,
                                    std::vector<ggml_tensor*> ref_latents,
                                    bool increase_ref_index,
                                    ggml_tensor* concat_latent = NULL,
@@ -2007,6 +2024,14 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
    ggml_tensor* init_img = NULL;
    SDCondition id_cond;
    std::vector<bool> class_tokens_mask;
+
+   ConditionerParams condition_params;
+   condition_params.clip_skip       = clip_skip;
+   condition_params.width           = width;
+   condition_params.height          = height;
+   condition_params.ref_images      = ref_images;
+   condition_params.adm_in_channels = sd_ctx->sd->diffusion_model->get_adm_in_channels();
+
    if (sd_ctx->sd->stacked_id) {
        if (!sd_ctx->sd->pmid_lora->applied) {
            int64_t t0 = ggml_time_ms();
@@ -2047,13 +2072,11 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
        processed_id_images.clear();
 
        int64_t t0 = ggml_time_ms();
+       condition_params.text           = prompt;
+       condition_params.num_input_imgs = pm_params.id_images_count;
        auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx,
-                                                                                        sd_ctx->sd->n_threads, prompt,
-                                                                                        clip_skip,
-                                                                                        width,
-                                                                                        height,
-                                                                                        pm_params.id_images_count,
-                                                                                        sd_ctx->sd->diffusion_model->get_adm_in_channels());
+                                                                                        sd_ctx->sd->n_threads,
+                                                                                        condition_params);
        id_cond           = std::get<0>(cond_tup);
        class_tokens_mask = std::get<1>(cond_tup);  //
        struct ggml_tensor* id_embeds = NULL;
@@ -2083,13 +2106,11 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
 
    // Get learned condition
    t0 = ggml_time_ms();
+   condition_params.text            = prompt;
+   condition_params.zero_out_masked = false;
    SDCondition cond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
                                                                           sd_ctx->sd->n_threads,
-                                                                          prompt,
-                                                                          clip_skip,
-                                                                          width,
-                                                                          height,
-                                                                          sd_ctx->sd->diffusion_model->get_adm_in_channels());
+                                                                          condition_params);
 
    SDCondition uncond;
    if (guidance.txt_cfg != 1.0 ||
@@ -2098,14 +2119,11 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
        if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0 && !sd_ctx->sd->is_using_edm_v_parameterization) {
            zero_out_masked = true;
        }
+       condition_params.text            = negative_prompt;
+       condition_params.zero_out_masked = zero_out_masked;
        uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
                                                                     sd_ctx->sd->n_threads,
-                                                                    negative_prompt,
-                                                                    clip_skip,
-                                                                    width,
-                                                                    height,
-                                                                    sd_ctx->sd->diffusion_model->get_adm_in_channels(),
-                                                                    zero_out_masked);
+                                                                    condition_params);
    }
    int64_t t1 = ggml_time_ms();
    LOG_INFO("get_learned_condition completed, taking %" PRId64 " ms", t1 - t0);
@@ -2507,6 +2525,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params) {
                                       sd_img_gen_params->control_image,
                                       sd_img_gen_params->control_strength,
                                       sd_img_gen_params->pm_params,
+                                      ref_images,
                                       ref_latents,
                                       sd_img_gen_params->increase_ref_index,
                                       concat_latent,
@@ -2764,28 +2783,23 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params) {
    }
 
    // Get learned condition
-   bool zero_out_masked = true;
+   ConditionerParams condition_params;
+   condition_params.clip_skip       = sd_vid_gen_params->clip_skip;
+   condition_params.zero_out_masked = true;
+   condition_params.text            = prompt;
+
    int64_t t1       = ggml_time_ms();
    SDCondition cond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
                                                                           sd_ctx->sd->n_threads,
-                                                                          prompt,
-                                                                          sd_vid_gen_params->clip_skip,
-                                                                          width,
-                                                                          height,
-                                                                          sd_ctx->sd->diffusion_model->get_adm_in_channels(),
-                                                                          zero_out_masked);
+                                                                          condition_params);
    cond.c_concat = concat_latent;
    cond.c_vector = clip_vision_output;
    SDCondition uncond;
    if (sd_vid_gen_params->sample_params.guidance.txt_cfg != 1.0 || sd_vid_gen_params->high_noise_sample_params.guidance.txt_cfg != 1.0) {
+       condition_params.text = negative_prompt;
        uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
                                                                     sd_ctx->sd->n_threads,
-                                                                    negative_prompt,
-                                                                    sd_vid_gen_params->clip_skip,
-                                                                    width,
-                                                                    height,
-                                                                    sd_ctx->sd->diffusion_model->get_adm_in_channels(),
-                                                                    zero_out_masked);
+                                                                    condition_params);
        uncond.c_concat = concat_latent;
        uncond.c_vector = clip_vision_output;
    }