feat: add Qwen Image Edit support (#877)

* add ref latent support for qwen image

* optimize clip_preprocess and fix get_first_stage_encoding

* add qwen2vl vit support

* add qwen image edit support

* fix qwen image edit pipeline

* add mmproj file support

* support dynamic number of Qwen image transformer blocks

* set prompt_template_encode_start_idx every time

* to_add_out precision fix

* to_out.0 precision fix

* update docs
This commit is contained in:
leejet 2025-10-13 23:17:18 +08:00 committed by GitHub
parent c64994dc1d
commit 2e9242e37f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 1339 additions and 365 deletions

View File

@ -24,6 +24,7 @@ API and command-line option may change frequently.***
- [Qwen Image](./docs/qwen_image.md)
- Image Edit Models
- [FLUX.1-Kontext-dev](./docs/kontext.md)
- [Qwen Image Edit/Qwen Image Edit 2509](./docs/qwen_image_edit.md)
- Video Models
- [Wan2.1/Wan2.2](./docs/wan.md)
- [PhotoMaker](https://github.com/TencentARC/PhotoMaker) support.
@ -298,6 +299,7 @@ arguments:
--clip_vision path to the clip-vision encoder
--t5xxl path to the t5xxl text encoder
--qwen2vl path to the qwen2vl text encoder
--qwen2vl_vision path to the qwen2vl vit
--vae [VAE] path to vae
--taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
--control-net [CONTROL_PATH] path to control net model

Binary file not shown.

After

Width:  |  Height:  |  Size: 457 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 415 KiB

View File

@ -15,28 +15,28 @@ struct SDCondition {
: c_crossattn(c_crossattn), c_vector(c_vector), c_concat(c_concat) {}
};
struct ConditionerParams {
std::string text;
int clip_skip = -1;
int width = -1;
int height = -1;
int adm_in_channels = -1;
bool zero_out_masked = false;
int num_input_imgs = 0; // for photomaker
std::vector<sd_image_t*> ref_images = {}; // for qwen image edit
};
struct Conditioner {
virtual SDCondition get_learned_condition(ggml_context* work_ctx,
int n_threads,
const std::string& text,
int clip_skip,
int width,
int height,
int adm_in_channels = -1,
bool zero_out_masked = false) = 0;
virtual void alloc_params_buffer() = 0;
virtual void free_params_buffer() = 0;
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
virtual size_t get_params_buffer_size() = 0;
const ConditionerParams& conditioner_params) = 0;
virtual void alloc_params_buffer() = 0;
virtual void free_params_buffer() = 0;
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
virtual size_t get_params_buffer_size() = 0;
virtual std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
int n_threads,
const std::string& text,
int clip_skip,
int width,
int height,
int num_input_imgs,
int adm_in_channels = -1,
bool zero_out_masked = false) {
const ConditionerParams& conditioner_params) {
GGML_ABORT("Not implemented yet!");
}
virtual std::string remove_trigger_from_prompt(ggml_context* work_ctx,
@ -555,20 +555,14 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
std::tuple<SDCondition, std::vector<bool>>
get_learned_condition_with_trigger(ggml_context* work_ctx,
int n_threads,
const std::string& text,
int clip_skip,
int width,
int height,
int num_input_imgs,
int adm_in_channels = -1,
bool zero_out_masked = false) {
const ConditionerParams& conditioner_params) {
auto image_tokens = convert_token_to_id(trigger_word);
// if(image_tokens.size() == 1){
// printf(" image token id is: %d \n", image_tokens[0]);
// }
GGML_ASSERT(image_tokens.size() == 1);
auto tokens_and_weights = tokenize_with_trigger_token(text,
num_input_imgs,
auto tokens_and_weights = tokenize_with_trigger_token(conditioner_params.text,
conditioner_params.num_input_imgs,
image_tokens[0],
true);
std::vector<int>& tokens = std::get<0>(tokens_and_weights);
@ -582,7 +576,15 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
// for(int i = 0; i < clsm.size(); ++i)
// printf("%d ", clsm[i]?1:0);
// printf("\n");
auto cond = get_learned_condition_common(work_ctx, n_threads, tokens, weights, clip_skip, width, height, adm_in_channels, zero_out_masked);
auto cond = get_learned_condition_common(work_ctx,
n_threads,
tokens,
weights,
conditioner_params.clip_skip,
conditioner_params.width,
conditioner_params.height,
conditioner_params.adm_in_channels,
conditioner_params.zero_out_masked);
return std::make_tuple(cond, clsm);
}
@ -600,16 +602,19 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
SDCondition get_learned_condition(ggml_context* work_ctx,
int n_threads,
const std::string& text,
int clip_skip,
int width,
int height,
int adm_in_channels = -1,
bool zero_out_masked = false) {
auto tokens_and_weights = tokenize(text, true);
const ConditionerParams& conditioner_params) {
auto tokens_and_weights = tokenize(conditioner_params.text, true);
std::vector<int>& tokens = tokens_and_weights.first;
std::vector<float>& weights = tokens_and_weights.second;
return get_learned_condition_common(work_ctx, n_threads, tokens, weights, clip_skip, width, height, adm_in_channels, zero_out_masked);
return get_learned_condition_common(work_ctx,
n_threads,
tokens,
weights,
conditioner_params.clip_skip,
conditioner_params.width,
conditioner_params.height,
conditioner_params.adm_in_channels,
conditioner_params.zero_out_masked);
}
};
@ -974,14 +979,13 @@ struct SD3CLIPEmbedder : public Conditioner {
SDCondition get_learned_condition(ggml_context* work_ctx,
int n_threads,
const std::string& text,
int clip_skip,
int width,
int height,
int adm_in_channels = -1,
bool zero_out_masked = false) {
auto tokens_and_weights = tokenize(text, 77, true);
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
const ConditionerParams& conditioner_params) {
auto tokens_and_weights = tokenize(conditioner_params.text, 77, true);
return get_learned_condition_common(work_ctx,
n_threads,
tokens_and_weights,
conditioner_params.clip_skip,
conditioner_params.zero_out_masked);
}
};
@ -1174,14 +1178,13 @@ struct FluxCLIPEmbedder : public Conditioner {
SDCondition get_learned_condition(ggml_context* work_ctx,
int n_threads,
const std::string& text,
int clip_skip,
int width,
int height,
int adm_in_channels = -1,
bool zero_out_masked = false) {
auto tokens_and_weights = tokenize(text, chunk_len, true);
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
const ConditionerParams& conditioner_params) {
auto tokens_and_weights = tokenize(conditioner_params.text, chunk_len, true);
return get_learned_condition_common(work_ctx,
n_threads,
tokens_and_weights,
conditioner_params.clip_skip,
conditioner_params.zero_out_masked);
}
};
@ -1360,27 +1363,30 @@ struct T5CLIPEmbedder : public Conditioner {
SDCondition get_learned_condition(ggml_context* work_ctx,
int n_threads,
const std::string& text,
int clip_skip,
int width,
int height,
int adm_in_channels = -1,
bool zero_out_masked = false) {
auto tokens_and_weights = tokenize(text, chunk_len, true);
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
const ConditionerParams& conditioner_params) {
auto tokens_and_weights = tokenize(conditioner_params.text, chunk_len, true);
return get_learned_condition_common(work_ctx,
n_threads,
tokens_and_weights,
conditioner_params.clip_skip,
conditioner_params.zero_out_masked);
}
};
struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
Qwen::Qwen2Tokenizer tokenizer;
std::shared_ptr<Qwen::Qwen2_5_VLRunner> qwenvl;
int prompt_template_encode_start_idx = 34;
Qwen2_5_VLCLIPEmbedder(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "") {
qwenvl = std::make_shared<Qwen::Qwen2_5_VLRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.qwen2vl");
const std::string prefix = "",
bool enable_vision = false) {
qwenvl = std::make_shared<Qwen::Qwen2_5_VLRunner>(backend,
offload_params_to_cpu,
tensor_types,
"text_encoders.qwen2vl",
enable_vision);
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
@ -1402,9 +1408,19 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
}
std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text,
size_t max_length = 0,
bool padding = false) {
auto parsed_attention = parse_prompt_attention(text);
size_t max_length = 0,
size_t system_prompt_length = 0,
bool padding = false) {
std::vector<std::pair<std::string, float>> parsed_attention;
if (system_prompt_length > 0) {
parsed_attention.emplace_back(text.substr(0, system_prompt_length), 1.f);
auto new_parsed_attention = parse_prompt_attention(text.substr(system_prompt_length, text.size() - system_prompt_length));
parsed_attention.insert(parsed_attention.end(),
new_parsed_attention.begin(),
new_parsed_attention.end());
} else {
parsed_attention = parse_prompt_attention(text);
}
{
std::stringstream ss;
@ -1429,20 +1445,89 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
tokenizer.pad_tokens(tokens, weights, max_length, padding);
// for (int i = 0; i < tokens.size(); i++) {
// std::cout << tokens[i] << ":" << weights[i] << ", ";
// std::cout << tokens[i] << ":" << weights[i] << ", " << i << std::endl;
// }
// std::cout << std::endl;
return {tokens, weights};
}
SDCondition get_learned_condition_common(ggml_context* work_ctx,
int n_threads,
std::tuple<std::vector<int>, std::vector<float>> token_and_weights,
int clip_skip,
bool zero_out_masked = false) {
auto& tokens = std::get<0>(token_and_weights);
auto& weights = std::get<1>(token_and_weights);
SDCondition get_learned_condition(ggml_context* work_ctx,
int n_threads,
const ConditionerParams& conditioner_params) {
std::string prompt;
std::vector<std::pair<int, ggml_tensor*>> image_embeds;
size_t system_prompt_length = 0;
int prompt_template_encode_start_idx = 34;
if (qwenvl->enable_vision && conditioner_params.ref_images.size() > 0) {
LOG_INFO("QwenImageEditPlusPipeline");
prompt_template_encode_start_idx = 64;
int image_embed_idx = 64 + 6;
int min_pixels = 384 * 384;
int max_pixels = 560 * 560;
std::string placeholder = "<|image_pad|>";
std::string img_prompt;
for (int i = 0; i < conditioner_params.ref_images.size(); i++) {
sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]);
double factor = qwenvl->params.vision.patch_size * qwenvl->params.vision.spatial_merge_size;
int height = image.height;
int width = image.width;
int h_bar = static_cast<int>(std::round(height / factor)) * factor;
int w_bar = static_cast<int>(std::round(width / factor)) * factor;
if (static_cast<double>(h_bar) * w_bar > max_pixels) {
double beta = std::sqrt((height * width) / static_cast<double>(max_pixels));
h_bar = std::max(static_cast<int>(factor),
static_cast<int>(std::floor(height / beta / factor)) * static_cast<int>(factor));
w_bar = std::max(static_cast<int>(factor),
static_cast<int>(std::floor(width / beta / factor)) * static_cast<int>(factor));
} else if (static_cast<double>(h_bar) * w_bar < min_pixels) {
double beta = std::sqrt(static_cast<double>(min_pixels) / (height * width));
h_bar = static_cast<int>(std::ceil(height * beta / factor)) * static_cast<int>(factor);
w_bar = static_cast<int>(std::ceil(width * beta / factor)) * static_cast<int>(factor);
}
LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, image.height, image.width, h_bar, w_bar);
sd_image_f32_t resized_image = clip_preprocess(image, w_bar, h_bar);
free(image.data);
image.data = nullptr;
ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
sd_image_f32_to_tensor(resized_image, image_tensor, false);
free(resized_image.data);
resized_image.data = nullptr;
ggml_tensor* image_embed = nullptr;
qwenvl->encode_image(n_threads, image_tensor, &image_embed, work_ctx);
image_embeds.emplace_back(image_embed_idx, image_embed);
image_embed_idx += 1 + image_embed->ne[1] + 6;
img_prompt += "Picture " + std::to_string(i + 1) + ": <|vision_start|>"; // [24669, 220, index, 25, 220, 151652]
int64_t num_image_tokens = image_embed->ne[1];
img_prompt.reserve(num_image_tokens * placeholder.size());
for (int j = 0; j < num_image_tokens; j++) {
img_prompt += placeholder;
}
img_prompt += "<|vision_end|>";
}
prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n";
system_prompt_length = prompt.size();
prompt += img_prompt;
prompt += conditioner_params.text;
prompt += "<|im_end|>\n<|im_start|>assistant\n";
} else {
prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n" + conditioner_params.text + "<|im_end|>\n<|im_start|>assistant\n";
}
auto tokens_and_weights = tokenize(prompt, 0, system_prompt_length, false);
auto& tokens = std::get<0>(tokens_and_weights);
auto& weights = std::get<1>(tokens_and_weights);
int64_t t0 = ggml_time_ms();
struct ggml_tensor* hidden_states = NULL; // [N, n_token, 3584]
@ -1451,6 +1536,7 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
qwenvl->compute(n_threads,
input_ids,
image_embeds,
&hidden_states,
work_ctx);
{
@ -1486,19 +1572,6 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
return SDCondition(new_hidden_states, nullptr, nullptr);
}
SDCondition get_learned_condition(ggml_context* work_ctx,
int n_threads,
const std::string& text,
int clip_skip,
int width,
int height,
int adm_in_channels = -1,
bool zero_out_masked = false) {
std::string prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n" + text + "<|im_end|>\n<|im_start|>assistant\n";
auto tokens_and_weights = tokenize(prompt, 0, false);
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
}
};
#endif

View File

@ -313,6 +313,8 @@ struct QwenImageModel : public DiffusionModel {
diffusion_params.x,
diffusion_params.timesteps,
diffusion_params.context,
diffusion_params.ref_latents,
true, // increase_ref_index
output,
output_ctx);
}

35
docs/qwen_image_edit.md Normal file
View File

@ -0,0 +1,35 @@
# How to Use
## Download weights
- Download Qwen Image
- Qwen Image Edit
- safetensors: https://huggingface.co/Comfy-Org/Qwen-Image-Edit_ComfyUI/tree/main/split_files/diffusion_models
- gguf: https://huggingface.co/QuantStack/Qwen-Image-Edit-GGUF/tree/main
- Qwen Image Edit 2509
- safetensors: https://huggingface.co/Comfy-Org/Qwen-Image-Edit_ComfyUI/tree/main/split_files/diffusion_models
- gguf: https://huggingface.co/QuantStack/Qwen-Image-Edit-2509-GGUF/tree/main
- Download vae
- safetensors: https://huggingface.co/Comfy-Org/Qwen-Image_ComfyUI/tree/main/split_files/vae
- Download qwen_2.5_vl 7b
- safetensors: https://huggingface.co/Comfy-Org/Qwen-Image_ComfyUI/tree/main/split_files/text_encoders
- gguf: https://huggingface.co/mradermacher/Qwen2.5-VL-7B-Instruct-GGUF/tree/main
## Examples
### Qwen Image Edit
```
.\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\Qwen_Image_Edit-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --qwen2vl ..\..\ComfyUI\models\text_encoders\qwen_2.5_vl_7b.safetensors --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --diffusion-fa --flow-shift 3 -r ..\assets\flux\flux1-dev-q8_0.png -p "change 'flux.cpp' to 'edit.cpp'" --seed 1118877715456453
```
<img alt="qwen_image_edit" src="../assets/qwen/qwen_image_edit.png" />
### Qwen Image Edit 2509
```
.\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\Qwen-Image-Edit-2509-Q4_K_S.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --qwen2vl ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf --qwen2vl_vision ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct.mmproj-Q8_0.gguf --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --diffusion-fa --flow-shift 3 -r ..\assets\flux\flux1-dev-q8_0.png -p "change 'flux.cpp' to 'Qwen Image Edit 2509'"
```
<img alt="qwen_image_edit_2509" src="../assets/qwen/qwen_image_edit_2509.png" />

View File

@ -62,6 +62,7 @@ struct SDParams {
std::string clip_vision_path;
std::string t5xxl_path;
std::string qwen2vl_path;
std::string qwen2vl_vision_path;
std::string diffusion_model_path;
std::string high_noise_diffusion_model_path;
std::string vae_path;
@ -148,6 +149,7 @@ void print_params(SDParams params) {
printf(" clip_vision_path: %s\n", params.clip_vision_path.c_str());
printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str());
printf(" qwen2vl_path: %s\n", params.qwen2vl_path.c_str());
printf(" qwen2vl_vision_path: %s\n", params.qwen2vl_vision_path.c_str());
printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str());
printf(" high_noise_diffusion_model_path: %s\n", params.high_noise_diffusion_model_path.c_str());
printf(" vae_path: %s\n", params.vae_path.c_str());
@ -220,6 +222,7 @@ void print_usage(int argc, const char* argv[]) {
printf(" --clip_vision path to the clip-vision encoder\n");
printf(" --t5xxl path to the t5xxl text encoder\n");
printf(" --qwen2vl path to the qwen2vl text encoder\n");
printf(" --qwen2vl_vision path to the qwen2vl vit\n");
printf(" --vae [VAE] path to vae\n");
printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
printf(" --control-net [CONTROL_PATH] path to control net model\n");
@ -490,6 +493,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
{"", "--clip_vision", "", &params.clip_vision_path},
{"", "--t5xxl", "", &params.t5xxl_path},
{"", "--qwen2vl", "", &params.qwen2vl_path},
{"", "--qwen2vl_vision", "", &params.qwen2vl_vision_path},
{"", "--diffusion-model", "", &params.diffusion_model_path},
{"", "--high-noise-diffusion-model", "", &params.high_noise_diffusion_model_path},
{"", "--vae", "", &params.vae_path},
@ -952,7 +956,7 @@ std::string get_image_params(SDParams params, int64_t seed) {
parameter_string += " " + std::string(sd_schedule_name(params.sample_params.scheduler));
}
parameter_string += ", ";
for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path}) {
for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path, params.qwen2vl_vision_path}) {
if (!te.empty()) {
parameter_string += "TE: " + sd_basename(te) + ", ";
}
@ -1336,6 +1340,7 @@ int main(int argc, const char* argv[]) {
params.clip_vision_path.c_str(),
params.t5xxl_path.c_str(),
params.qwen2vl_path.c_str(),
params.qwen2vl_vision_path.c_str(),
params.diffusion_model_path.c_str(),
params.high_noise_diffusion_model_path.c_str(),
params.vae_path.c_str(),

View File

@ -81,57 +81,6 @@ namespace Flux {
}
};
__STATIC_INLINE__ struct ggml_tensor* apply_rope(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* pe) {
// x: [N, L, n_head, d_head]
// pe: [L, d_head/2, 2, 2]
int64_t d_head = x->ne[0];
int64_t n_head = x->ne[1];
int64_t L = x->ne[2];
int64_t N = x->ne[3];
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, n_head, L, d_head]
x = ggml_reshape_4d(ctx, x, 2, d_head / 2, L, n_head * N); // [N * n_head, L, d_head/2, 2]
x = ggml_cont(ctx, ggml_permute(ctx, x, 3, 0, 1, 2)); // [2, N * n_head, L, d_head/2]
int64_t offset = x->nb[2] * x->ne[2];
auto x_0 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 0); // [N * n_head, L, d_head/2]
auto x_1 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 1); // [N * n_head, L, d_head/2]
x_0 = ggml_reshape_4d(ctx, x_0, 1, x_0->ne[0], x_0->ne[1], x_0->ne[2]); // [N * n_head, L, d_head/2, 1]
x_1 = ggml_reshape_4d(ctx, x_1, 1, x_1->ne[0], x_1->ne[1], x_1->ne[2]); // [N * n_head, L, d_head/2, 1]
auto temp_x = ggml_new_tensor_4d(ctx, x_0->type, 2, x_0->ne[1], x_0->ne[2], x_0->ne[3]);
x_0 = ggml_repeat(ctx, x_0, temp_x); // [N * n_head, L, d_head/2, 2]
x_1 = ggml_repeat(ctx, x_1, temp_x); // [N * n_head, L, d_head/2, 2]
pe = ggml_cont(ctx, ggml_permute(ctx, pe, 3, 0, 1, 2)); // [2, L, d_head/2, 2]
offset = pe->nb[2] * pe->ne[2];
auto pe_0 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 0); // [L, d_head/2, 2]
auto pe_1 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 1); // [L, d_head/2, 2]
auto x_out = ggml_add_inplace(ctx, ggml_mul(ctx, x_0, pe_0), ggml_mul(ctx, x_1, pe_1)); // [N * n_head, L, d_head/2, 2]
x_out = ggml_reshape_3d(ctx, x_out, d_head, L, n_head * N); // [N*n_head, L, d_head]
return x_out;
}
__STATIC_INLINE__ struct ggml_tensor* attention(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* q,
struct ggml_tensor* k,
struct ggml_tensor* v,
struct ggml_tensor* pe,
struct ggml_tensor* mask,
bool flash_attn,
float kv_scale = 1.0f) {
// q,k,v: [N, L, n_head, d_head]
// pe: [L, d_head/2, 2, 2]
// return: [N, L, n_head*d_head]
q = apply_rope(ctx, q, pe); // [N*n_head, L, d_head]
k = apply_rope(ctx, k, pe); // [N*n_head, L, d_head]
auto x = ggml_nn_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn, kv_scale); // [N, L, n_head*d_head]
return x;
}
struct SelfAttention : public GGMLBlock {
public:
int64_t num_heads;
@ -179,9 +128,9 @@ namespace Flux {
// x: [N, n_token, dim]
// pe: [n_token, d_head/2, 2, 2]
// return [N, n_token, dim]
auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head]
x = attention(ctx, backend, qkv[0], qkv[1], qkv[2], pe, mask, flash_attn); // [N, n_token, dim]
x = post_attention(ctx, x); // [N, n_token, dim]
auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head]
x = Rope::attention(ctx, backend, qkv[0], qkv[1], qkv[2], pe, mask, flash_attn); // [N, n_token, dim]
x = post_attention(ctx, x); // [N, n_token, dim]
return x;
}
};
@ -369,8 +318,8 @@ namespace Flux {
auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto attn = attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head]
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head]
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
auto txt_attn_out = ggml_view_3d(ctx,
attn,
attn->ne[0],
@ -504,7 +453,7 @@ namespace Flux {
auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head]
q = norm->query_norm(ctx, q);
k = norm->key_norm(ctx, k);
auto attn = attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_token, hidden_size]
auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_token, hidden_size]
auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim]
auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size]

View File

@ -197,8 +197,11 @@ __STATIC_INLINE__ float sd_image_get_f32(sd_image_t image, int iw, int ih, int i
return value;
}
__STATIC_INLINE__ float sd_image_get_f32(sd_image_f32_t image, int iw, int ih, int ic) {
__STATIC_INLINE__ float sd_image_get_f32(sd_image_f32_t image, int iw, int ih, int ic, bool scale = true) {
float value = *(image.data + ih * image.width * image.channel + iw * image.channel + ic);
if (scale) {
value /= 255.f;
}
return value;
}
@ -458,24 +461,18 @@ __STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
}
}
__STATIC_INLINE__ void sd_image_f32_to_tensor(const float* image_data,
struct ggml_tensor* output,
__STATIC_INLINE__ void sd_image_f32_to_tensor(sd_image_f32_t image,
ggml_tensor* tensor,
bool scale = true) {
int64_t width = output->ne[0];
int64_t height = output->ne[1];
int64_t channels = output->ne[2];
GGML_ASSERT(channels == 3 && output->type == GGML_TYPE_F32);
for (int iy = 0; iy < height; iy++) {
for (int ix = 0; ix < width; ix++) {
for (int k = 0; k < channels; k++) {
int value = *(image_data + iy * width * channels + ix * channels + k);
if (scale) {
value /= 255.f;
}
ggml_tensor_set_f32(output, value, ix, iy, k);
}
}
}
GGML_ASSERT(image.width == tensor->ne[0]);
GGML_ASSERT(image.height == tensor->ne[1]);
GGML_ASSERT(image.channel == tensor->ne[2]);
GGML_ASSERT(1 == tensor->ne[3]);
GGML_ASSERT(tensor->type == GGML_TYPE_F32);
ggml_tensor_iter(tensor, [&](ggml_tensor* tensor, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = sd_image_get_f32(image, i0, i1, i2, scale);
ggml_tensor_set_f32(tensor, value, i0, i1, i2, i3);
});
}
__STATIC_INLINE__ void ggml_split_tensor_2d(struct ggml_tensor* input,

View File

@ -113,7 +113,6 @@ const char* unused_tensors[] = {
"text_encoders.t5xxl.transformer.encoder.embed_tokens.weight", // only used during training
"text_encoders.qwen2vl.output.weight",
"text_encoders.qwen2vl.lm_head.",
"text_encoders.qwen2vl.visual.",
};
bool is_unused_tensor(std::string name) {
@ -212,6 +211,24 @@ std::unordered_map<std::string, std::string> qwenvl_name_map{
{"output_norm.", "model.norm."},
};
std::unordered_map<std::string, std::string> qwenvl_vision_name_map{
{"mm.", "merger.mlp."},
{"v.post_ln.", "merger.ln_q."},
{"v.patch_embd.weight", "patch_embed.proj.0.weight"},
{"patch_embed.proj.0.weight.1", "patch_embed.proj.1.weight"},
{"v.patch_embd.weight.1", "patch_embed.proj.1.weight"},
{"v.blk.", "blocks."},
{"attn_q.", "attn.q_proj."},
{"attn_k.", "attn.k_proj."},
{"attn_v.", "attn.v_proj."},
{"attn_out.", "attn.proj."},
{"ffn_down.", "mlp.down_proj."},
{"ffn_gate.", "mlp.gate_proj."},
{"ffn_up.", "mlp.up_proj."},
{"ln1.", "norm1."},
{"ln2.", "norm2."},
};
std::string convert_cond_model_name(const std::string& name) {
std::string new_name = name;
std::string prefix;
@ -270,10 +287,19 @@ std::string convert_cond_model_name(const std::string& name) {
new_name.replace(pos, 11, "layer.0.SelfAttention.relative_attention_bias.");
}
} else if (contains(name, "qwen2vl")) {
for (auto kv : qwenvl_name_map) {
size_t pos = new_name.find(kv.first);
if (pos != std::string::npos) {
new_name.replace(pos, kv.first.size(), kv.second);
if (contains(name, "qwen2vl.visual")) {
for (auto kv : qwenvl_vision_name_map) {
size_t pos = new_name.find(kv.first);
if (pos != std::string::npos) {
new_name.replace(pos, kv.first.size(), kv.second);
}
}
} else {
for (auto kv : qwenvl_name_map) {
size_t pos = new_name.find(kv.first);
if (pos != std::string::npos) {
new_name.replace(pos, kv.first.size(), kv.second);
}
}
}
} else if (name == "text_encoders.t5xxl.transformer.token_embd.weight") {

View File

@ -94,12 +94,12 @@ namespace Qwen {
blocks["norm_added_q"] = std::shared_ptr<GGMLBlock>(new RMSNorm(dim_head, eps));
blocks["norm_added_k"] = std::shared_ptr<GGMLBlock>(new RMSNorm(dim_head, eps));
blocks["to_out.0"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, out_dim, out_bias));
// to_out.1 is nn.Dropout
float scale = 1.f / 32.f;
// The purpose of the scale here is to prevent NaN issues in certain situations.
// For example when using CUDA but the weights are k-quants (not all prompts).
blocks["to_out.0"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, out_dim, out_bias, false, false, scale));
// to_out.1 is nn.Dropout
blocks["to_add_out"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, out_context_dim, out_bias, false, false, scale));
}
@ -159,7 +159,7 @@ namespace Qwen {
auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto attn = Flux::attention(ctx, backend, q, k, v, pe, mask, flash_attn, (1.0f / 128.f)); // [N, n_txt_token + n_img_token, n_head*d_head]
auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn, (1.0f / 128.f)); // [N, n_txt_token + n_img_token, n_head*d_head]
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
auto txt_attn_out = ggml_view_3d(ctx,
attn,
@ -389,6 +389,13 @@ namespace Qwen {
return x;
}
struct ggml_tensor* process_img(struct ggml_context* ctx,
struct ggml_tensor* x) {
x = pad_to_patch_size(ctx, x);
x = patchify(ctx, x);
return x;
}
struct ggml_tensor* unpatchify(struct ggml_context* ctx,
struct ggml_tensor* x,
int64_t h,
@ -449,7 +456,8 @@ namespace Qwen {
struct ggml_tensor* x,
struct ggml_tensor* timestep,
struct ggml_tensor* context,
struct ggml_tensor* pe) {
struct ggml_tensor* pe,
std::vector<ggml_tensor*> ref_latents = {}) {
// Forward pass of DiT.
// x: [N, C, H, W]
// timestep: [N,]
@ -462,13 +470,26 @@ namespace Qwen {
int64_t C = x->ne[2];
int64_t N = x->ne[3];
x = pad_to_patch_size(ctx, x);
x = patchify(ctx, x);
auto img = process_img(ctx, x);
uint64_t img_tokens = img->ne[1];
if (ref_latents.size() > 0) {
for (ggml_tensor* ref : ref_latents) {
ref = process_img(ctx, ref);
img = ggml_concat(ctx, img, ref, 1);
}
}
int64_t h_len = ((H + (params.patch_size / 2)) / params.patch_size);
int64_t w_len = ((W + (params.patch_size / 2)) / params.patch_size);
auto out = forward_orig(ctx, backend, x, timestep, context, pe); // [N, h_len*w_len, ph*pw*C]
auto out = forward_orig(ctx, backend, img, timestep, context, pe); // [N, h_len*w_len, ph*pw*C]
if (out->ne[1] > img_tokens) {
out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size]
out = ggml_view_3d(ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size]
}
out = unpatchify(ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w]
@ -495,6 +516,25 @@ namespace Qwen {
bool flash_attn = false)
: GGMLRunner(backend, offload_params_to_cpu) {
qwen_image_params.flash_attn = flash_attn;
qwen_image_params.num_layers = 0;
for (auto pair : tensor_types) {
std::string tensor_name = pair.first;
if (tensor_name.find(prefix) == std::string::npos)
continue;
size_t pos = tensor_name.find("transformer_blocks.");
if (pos != std::string::npos) {
tensor_name = tensor_name.substr(pos); // remove prefix
auto items = split_string(tensor_name, '.');
if (items.size() > 1) {
int block_index = atoi(items[1].c_str());
if (block_index + 1 > qwen_image_params.num_layers) {
qwen_image_params.num_layers = block_index + 1;
}
}
continue;
}
}
LOG_ERROR("qwen_image_params.num_layers: %ld", qwen_image_params.num_layers);
qwen_image = QwenImageModel(qwen_image_params);
qwen_image.init(params_ctx, tensor_types, prefix);
}
@ -509,7 +549,9 @@ namespace Qwen {
struct ggml_cgraph* build_graph(struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context) {
struct ggml_tensor* context,
std::vector<ggml_tensor*> ref_latents = {},
bool increase_ref_index = false) {
GGML_ASSERT(x->ne[3] == 1);
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, QWEN_IMAGE_GRAPH_SIZE, false);
@ -517,18 +559,24 @@ namespace Qwen {
context = to_backend(context);
timesteps = to_backend(timesteps);
for (int i = 0; i < ref_latents.size(); i++) {
ref_latents[i] = to_backend(ref_latents[i]);
}
pe_vec = Rope::gen_qwen_image_pe(x->ne[1],
x->ne[0],
qwen_image_params.patch_size,
x->ne[3],
context->ne[1],
ref_latents,
increase_ref_index,
qwen_image_params.theta,
qwen_image_params.axes_dim);
int pos_len = pe_vec.size() / qwen_image_params.axes_dim_sum / 2;
// LOG_DEBUG("pos_len %d", pos_len);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, qwen_image_params.axes_dim_sum / 2, pos_len);
// pe->data = pe_vec.data();
// print_ggml_tensor(pe);
// print_ggml_tensor(pe, true, "pe");
// pe->data = NULL;
set_backend_tensor_data(pe, pe_vec.data());
@ -537,7 +585,8 @@ namespace Qwen {
x,
timesteps,
context,
pe);
pe,
ref_latents);
ggml_build_forward_expand(gf, out);
@ -548,13 +597,15 @@ namespace Qwen {
struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor** output = NULL,
struct ggml_context* output_ctx = NULL) {
std::vector<ggml_tensor*> ref_latents = {},
bool increase_ref_index = false,
struct ggml_tensor** output = NULL,
struct ggml_context* output_ctx = NULL) {
// x: [N, in_channels, h, w]
// timesteps: [N, ]
// context: [N, max_position, hidden_size]
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(x, timesteps, context);
return build_graph(x, timesteps, context, ref_latents, increase_ref_index);
};
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@ -586,7 +637,7 @@ namespace Qwen {
struct ggml_tensor* out = NULL;
int t0 = ggml_time_ms();
compute(8, x, timesteps, context, &out, work_ctx);
compute(8, x, timesteps, context, {}, false, &out, work_ctx);
int t1 = ggml_time_ms();
print_ggml_tensor(out);

View File

@ -15,9 +15,11 @@
#include "clip.hpp"
#include "ggml_extend.hpp"
#include "json.hpp"
#include "rope.hpp"
#include "tokenize_util.h"
namespace Qwen {
constexpr int QWENVL_GRAPH_SIZE = 10240;
class Qwen2Tokenizer {
private:
@ -340,9 +342,9 @@ namespace Qwen {
struct Qwen2_5_VLMLP : public GGMLBlock {
public:
Qwen2_5_VLMLP(int64_t hidden_size, int64_t intermediate_size, bool bias = false) {
blocks["gate_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, intermediate_size, false));
blocks["up_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, intermediate_size, false));
blocks["down_proj"] = std::shared_ptr<GGMLBlock>(new Linear(intermediate_size, hidden_size, false));
blocks["gate_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, intermediate_size, bias));
blocks["up_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, intermediate_size, bias));
blocks["down_proj"] = std::shared_ptr<GGMLBlock>(new Linear(intermediate_size, hidden_size, bias));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
@ -359,6 +361,288 @@ namespace Qwen {
}
};
struct Qwen2_5_VisionPatchEmbed : public GGMLBlock {
protected:
bool llama_cpp_style;
int patch_size;
int temporal_patch_size;
int64_t in_channels;
int64_t embed_dim;
public:
Qwen2_5_VisionPatchEmbed(bool llama_cpp_style,
int patch_size = 14,
int temporal_patch_size = 2,
int64_t in_channels = 3,
int64_t embed_dim = 1152)
: llama_cpp_style(llama_cpp_style),
patch_size(patch_size),
temporal_patch_size(temporal_patch_size),
in_channels(in_channels),
embed_dim(embed_dim) {
if (llama_cpp_style) {
blocks["proj.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels,
embed_dim,
{patch_size, patch_size},
{patch_size, patch_size}, // stride
{0, 0}, // padding
{1, 1}, // dilation
false));
blocks["proj.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels,
embed_dim,
{patch_size, patch_size},
{patch_size, patch_size}, // stride
{0, 0}, // padding
{1, 1}, // dilation
false));
} else {
std::tuple<int, int, int> kernel_size = {(int)temporal_patch_size, (int)patch_size, (int)patch_size};
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Conv3d(in_channels,
embed_dim,
kernel_size,
kernel_size, // stride
{0, 0, 0}, // padding
{1, 1, 1}, // dilation
false));
}
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
// x: [N*grid_t*grid_h*grid_w, in_channels, temporal_patch_size*patch_size*patch_size]
// return: [N*grid_t*grid_h*grid_w, embed_dim]
x = ggml_reshape_4d(ctx,
x,
patch_size,
patch_size,
temporal_patch_size,
ggml_nelements(x) / (temporal_patch_size * patch_size * patch_size));
if (llama_cpp_style) {
auto proj_0 = std::dynamic_pointer_cast<Conv2d>(blocks["proj.0"]);
auto proj_1 = std::dynamic_pointer_cast<Conv2d>(blocks["proj.1"]);
auto x0 = ggml_slice(ctx, x, 2, 0, 1);
x0 = ggml_reshape_4d(ctx, x0, x0->ne[0], x0->ne[1], in_channels, x0->ne[3] / in_channels);
x0 = proj_0->forward(ctx, x0);
auto x1 = ggml_slice(ctx, x, 2, 1, 2);
x1 = ggml_reshape_4d(ctx, x1, x1->ne[0], x1->ne[1], in_channels, x1->ne[3] / in_channels);
x1 = proj_1->forward(ctx, x1);
x = ggml_add(ctx, x0, x1);
} else {
auto proj = std::dynamic_pointer_cast<Conv3d>(blocks["proj"]);
x = proj->forward(ctx, x);
}
x = ggml_reshape_2d(ctx, x, embed_dim, ggml_nelements(x) / embed_dim);
return x;
}
};
struct Qwen2_5_VLPatchMerger : public GGMLBlock {
protected:
int64_t hidden_size;
public:
Qwen2_5_VLPatchMerger(int64_t dim,
int64_t context_dim,
int64_t spatial_merge_size) {
hidden_size = context_dim * spatial_merge_size * spatial_merge_size;
blocks["ln_q"] = std::shared_ptr<GGMLBlock>(new RMSNorm(context_dim, 1e-6f));
blocks["mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size));
// mlp.1 is nn.GELU()
blocks["mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, dim));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
auto ln_q = std::dynamic_pointer_cast<RMSNorm>(blocks["ln_q"]);
auto mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["mlp.0"]);
auto mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["mlp.2"]);
x = ln_q->forward(ctx, x);
x = ggml_reshape_2d(ctx, x, hidden_size, ggml_nelements(x) / hidden_size);
x = mlp_0->forward(ctx, x);
x = ggml_gelu(ctx, x);
x = mlp_2->forward(ctx, x);
return x;
}
};
struct Qwen2_5_VLVisionAttention : public GGMLBlock {
protected:
bool llama_cpp_style;
int64_t head_dim;
int64_t num_heads;
public:
Qwen2_5_VLVisionAttention(bool llama_cpp_style,
int64_t hidden_size,
int64_t num_heads)
: llama_cpp_style(llama_cpp_style), num_heads(num_heads) {
head_dim = hidden_size / num_heads;
GGML_ASSERT(num_heads * head_dim == hidden_size);
if (llama_cpp_style) {
blocks["q_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size));
blocks["k_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size));
blocks["v_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size));
} else {
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size * 3));
}
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* pe,
struct ggml_tensor* mask = nullptr) {
// x: [N, n_token, hidden_size]
int64_t n_token = x->ne[1];
int64_t N = x->ne[2];
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
std::vector<ggml_tensor*> qkv_vec;
if (llama_cpp_style) {
auto q_proj = std::dynamic_pointer_cast<Linear>(blocks["q_proj"]);
auto k_proj = std::dynamic_pointer_cast<Linear>(blocks["k_proj"]);
auto v_proj = std::dynamic_pointer_cast<Linear>(blocks["v_proj"]);
auto q = q_proj->forward(ctx, x);
auto k = k_proj->forward(ctx, x);
auto v = v_proj->forward(ctx, x);
qkv_vec = {q, k, v};
} else {
auto qkv_proj = std::dynamic_pointer_cast<Linear>(blocks["qkv"]);
auto qkv = qkv_proj->forward(ctx, x);
qkv_vec = split_qkv(ctx, qkv);
}
auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head]
auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head]
auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head]
x = Rope::attention(ctx, backend, q, k, v, pe, mask, false, 1.f, false); // [N, n_token, hidden_size]
x = proj->forward(ctx, x); // [N, n_token, hidden_size]
return x;
}
};
struct Qwen2_5_VLVisionBlock : public GGMLBlock {
public:
Qwen2_5_VLVisionBlock(bool llama_cpp_style,
int64_t hidden_size,
int64_t intermediate_size,
int64_t num_heads,
float eps = 1e-6f) {
blocks["attn"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLVisionAttention(llama_cpp_style, hidden_size, num_heads));
blocks["mlp"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLMLP(hidden_size, intermediate_size, true));
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new RMSNorm(hidden_size, eps));
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new RMSNorm(hidden_size, eps));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* pe,
struct ggml_tensor* mask = nullptr) {
// x: [N, n_token, hidden_size]
auto attn = std::dynamic_pointer_cast<Qwen2_5_VLVisionAttention>(blocks["attn"]);
auto mlp = std::dynamic_pointer_cast<Qwen2_5_VLMLP>(blocks["mlp"]);
auto norm1 = std::dynamic_pointer_cast<RMSNorm>(blocks["norm1"]);
auto norm2 = std::dynamic_pointer_cast<RMSNorm>(blocks["norm2"]);
auto residual = x;
x = norm1->forward(ctx, x);
x = attn->forward(ctx, backend, x, pe, mask);
x = ggml_add_inplace(ctx, x, residual);
residual = x;
x = norm2->forward(ctx, x);
x = mlp->forward(ctx, x);
x = ggml_add_inplace(ctx, x, residual);
return x;
}
};
struct Qwen2_5_VLVisionModel : public GGMLBlock {
protected:
int64_t num_layers;
int64_t spatial_merge_size;
std::set<int> fullatt_block_indexes;
public:
Qwen2_5_VLVisionModel(bool llama_cpp_style,
int64_t num_layers,
int64_t in_channels,
int64_t hidden_size,
int64_t out_hidden_size,
int64_t intermediate_size,
int64_t num_heads,
int64_t spatial_merge_size,
int64_t patch_size,
int64_t temporal_patch_size,
int64_t window_size,
std::set<int> fullatt_block_indexes = {7, 15, 23, 31},
float eps = 1e-6f)
: num_layers(num_layers), fullatt_block_indexes(fullatt_block_indexes), spatial_merge_size(spatial_merge_size) {
blocks["patch_embed"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VisionPatchEmbed(llama_cpp_style,
patch_size,
temporal_patch_size,
in_channels,
hidden_size));
for (int i = 0; i < num_layers; i++) {
blocks["blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLVisionBlock(llama_cpp_style,
hidden_size,
intermediate_size,
num_heads,
eps));
}
blocks["merger"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLPatchMerger(out_hidden_size, hidden_size, spatial_merge_size));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* pixel_values,
struct ggml_tensor* pe,
struct ggml_tensor* window_index,
struct ggml_tensor* window_inverse_index,
struct ggml_tensor* window_mask) {
// pixel_values: [grid_t*(H/mh/ph)*(W/mw/pw)*mh*mw, C*pt*ph*pw]
// window_index: [grid_t*(H/mh/ph)*(W/mw/pw)]
// window_inverse_index: [grid_t*(H/mh/ph)*(W/mw/pw)]
// window_mask: [grid_h*grid_w, grid_h*grid_w]
auto patch_embed = std::dynamic_pointer_cast<Qwen2_5_VisionPatchEmbed>(blocks["patch_embed"]);
auto merger = std::dynamic_pointer_cast<Qwen2_5_VLPatchMerger>(blocks["merger"]);
auto x = patch_embed->forward(ctx, pixel_values);
x = ggml_reshape_4d(ctx, x, x->ne[0] * spatial_merge_size * spatial_merge_size, x->ne[1] / spatial_merge_size / spatial_merge_size, x->ne[2], x->ne[3]);
x = ggml_get_rows(ctx, x, window_index);
x = ggml_reshape_4d(ctx, x, x->ne[0] / spatial_merge_size / spatial_merge_size, x->ne[1] * spatial_merge_size * spatial_merge_size, x->ne[2], x->ne[3]);
for (int i = 0; i < num_layers; i++) {
auto block = std::dynamic_pointer_cast<Qwen2_5_VLVisionBlock>(blocks["blocks." + std::to_string(i)]);
auto mask = window_mask;
if (fullatt_block_indexes.find(i) != fullatt_block_indexes.end()) {
mask = nullptr;
}
x = block->forward(ctx, backend, x, pe, mask);
}
x = merger->forward(ctx, x);
x = ggml_get_rows(ctx, x, window_inverse_index);
return x;
}
};
struct Qwen2_5_VLAttention : public GGMLBlock {
protected:
int64_t head_dim;
@ -478,7 +762,8 @@ namespace Qwen {
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* input_ids,
struct ggml_tensor* input_pos) {
struct ggml_tensor* input_pos,
std::vector<std::pair<int, ggml_tensor*>> image_embeds) {
// input_ids: [N, n_token]
// return: [N, n_token, hidden_size]
@ -487,6 +772,45 @@ namespace Qwen {
auto x = embed_tokens->forward(ctx, input_ids);
if (image_embeds.size() > 0) {
GGML_ASSERT(x->ne[2] == 1); // N == 1
auto raw_x = ggml_cast(ctx, x, image_embeds[0].second->type);
int64_t txt_token_start = 0;
int64_t txt_token_end = 0;
ggml_tensor* input_embed = nullptr;
for (int i = 0; i < image_embeds.size(); i++) {
if (i == 0) {
txt_token_start = 0;
} else {
txt_token_start = image_embeds[i - 1].first + image_embeds[i - 1].second->ne[1];
}
txt_token_end = image_embeds[i].first;
auto txt_embed = ggml_slice(ctx, raw_x, 1, txt_token_start, txt_token_end);
if (input_embed == nullptr) {
input_embed = txt_embed;
} else {
input_embed = ggml_concat(ctx, input_embed, txt_embed, 1);
}
auto image_embed = image_embeds[i].second;
input_embed = ggml_concat(ctx, input_embed, image_embed, 1);
}
txt_token_start = image_embeds[image_embeds.size() - 1].first + image_embeds[image_embeds.size() - 1].second->ne[1];
txt_token_end = raw_x->ne[1];
auto final_txt_embed = ggml_slice(ctx, raw_x, 1, txt_token_start, txt_token_end);
input_embed = ggml_concat(ctx, input_embed, final_txt_embed, 1);
GGML_ASSERT(raw_x->ne[1] == input_embed->ne[1]);
x = input_embed;
}
for (int i = 0; i < num_layers; i++) {
auto block = std::dynamic_pointer_cast<Qwen2_5_VLBlock>(blocks["layers." + std::to_string(i)]);
@ -498,6 +822,20 @@ namespace Qwen {
}
};
struct Qwen2_5_VLVisionParams {
int64_t num_layers = 32;
int64_t hidden_size = 1280;
int64_t intermediate_size = 3420;
int64_t num_heads = 16;
int64_t in_channels = 3;
int64_t out_hidden_size = 3584;
int64_t temporal_patch_size = 2;
int64_t patch_size = 14;
int64_t spatial_merge_size = 2;
int64_t window_size = 112;
std::set<int> fullatt_block_indexes = {7, 15, 23, 31};
};
struct Qwen2_5_VLParams {
int64_t num_layers = 28;
int64_t hidden_size = 3584;
@ -506,15 +844,17 @@ namespace Qwen {
int64_t num_kv_heads = 4;
int64_t vocab_size = 152064;
float rms_norm_eps = 1e-06f;
Qwen2_5_VLVisionParams vision;
};
struct Qwen2_5_VL : public GGMLBlock {
bool enable_vision;
Qwen2_5_VLParams params;
public:
Qwen2_5_VL() {}
Qwen2_5_VL(Qwen2_5_VLParams params)
: params(params) {
Qwen2_5_VL(Qwen2_5_VLParams params, bool enable_vision = false, bool llama_cpp_style = false)
: enable_vision(enable_vision), params(params) {
blocks["model"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLTextModel(params.num_layers,
params.vocab_size,
params.hidden_size,
@ -522,32 +862,90 @@ namespace Qwen {
params.num_heads,
params.num_kv_heads,
params.rms_norm_eps));
if (enable_vision) {
blocks["visual"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLVisionModel(llama_cpp_style,
params.vision.num_layers,
params.vision.in_channels,
params.vision.hidden_size,
params.vision.out_hidden_size,
params.vision.intermediate_size,
params.vision.num_heads,
params.vision.spatial_merge_size,
params.vision.patch_size,
params.vision.temporal_patch_size,
params.vision.window_size,
params.vision.fullatt_block_indexes));
}
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* input_ids,
struct ggml_tensor* input_pos) {
struct ggml_tensor* input_pos,
std::vector<std::pair<int, ggml_tensor*>> image_embeds) {
// input_ids: [N, n_token]
auto model = std::dynamic_pointer_cast<Qwen2_5_VLTextModel>(blocks["model"]);
auto x = model->forward(ctx, backend, input_ids, input_pos);
auto x = model->forward(ctx, backend, input_ids, input_pos, image_embeds);
return x;
}
struct ggml_tensor* vision_forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* pixel_values,
struct ggml_tensor* pe,
struct ggml_tensor* window_index,
struct ggml_tensor* window_inverse_index,
struct ggml_tensor* window_mask) {
GGML_ASSERT(enable_vision);
auto vision_model = std::dynamic_pointer_cast<Qwen2_5_VLVisionModel>(blocks["visual"]);
return vision_model->forward(ctx, backend, pixel_values, pe, window_index, window_inverse_index, window_mask);
}
};
struct Qwen2_5_VLRunner : public GGMLRunner {
Qwen2_5_VLParams params;
bool enable_vision;
Qwen2_5_VL model;
std::vector<int> input_pos_vec;
std::vector<float> window_mask_vec;
std::vector<int> window_index_vec;
std::vector<int> window_inverse_index_vec;
std::vector<float> pe_vec;
Qwen2_5_VLRunner(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types,
const std::string prefix)
: GGMLRunner(backend, offload_params_to_cpu) {
model = Qwen2_5_VL(params);
const std::string prefix,
bool enable_vision_ = false)
: GGMLRunner(backend, offload_params_to_cpu), enable_vision(enable_vision_) {
bool have_vision_weight = false;
bool llama_cpp_style = false;
for (auto pair : tensor_types) {
std::string tensor_name = pair.first;
if (tensor_name.find(prefix) == std::string::npos)
continue;
size_t pos = tensor_name.find("visual.");
if (pos != std::string::npos) {
have_vision_weight = true;
if (contains(tensor_name, "attn.q_proj")) {
llama_cpp_style = true;
break;
}
}
}
if (enable_vision && !have_vision_weight) {
LOG_WARN("no vision weights detected, vision disabled");
enable_vision = false;
}
if (enable_vision) {
LOG_DEBUG("enable qwen2vl vision");
if (llama_cpp_style) {
LOG_DEBUG("llama.cpp style vision weight");
}
}
model = Qwen2_5_VL(params, enable_vision, llama_cpp_style);
model.init(params_ctx, tensor_types, prefix);
}
@ -562,16 +960,32 @@ namespace Qwen {
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* input_ids,
struct ggml_tensor* input_pos) {
auto hidden_states = model.forward(ctx, backend, input_ids, input_pos); // [N, n_token, hidden_size]
struct ggml_tensor* input_pos,
std::vector<std::pair<int, ggml_tensor*>> image_embeds) {
auto hidden_states = model.forward(ctx, backend, input_ids, input_pos, image_embeds); // [N, n_token, hidden_size]
return hidden_states;
}
struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids) {
struct ggml_tensor* vision_forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* pixel_values,
struct ggml_tensor* input_pos,
struct ggml_tensor* window_index,
struct ggml_tensor* window_inverse_index,
struct ggml_tensor* window_mask) {
auto hidden_states = model.vision_forward(ctx, backend, pixel_values, input_pos, window_index, window_inverse_index, window_mask);
return hidden_states;
}
struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids, std::vector<std::pair<int, ggml_tensor*>> image_embeds) {
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
input_ids = to_backend(input_ids);
for (auto& image_embed : image_embeds) {
image_embed.second = to_backend(image_embed.second);
}
int64_t n_tokens = input_ids->ne[0];
input_pos_vec.resize(n_tokens * 4);
for (int i = 0; i < n_tokens; ++i) {
@ -586,7 +1000,7 @@ namespace Qwen {
n_tokens * 4);
set_backend_tensor_data(input_pos, input_pos_vec.data());
struct ggml_tensor* hidden_states = forward(compute_ctx, runtime_backend, input_ids, input_pos);
struct ggml_tensor* hidden_states = forward(compute_ctx, runtime_backend, input_ids, input_pos, image_embeds);
ggml_build_forward_expand(gf, hidden_states);
@ -595,13 +1009,183 @@ namespace Qwen {
void compute(const int n_threads,
struct ggml_tensor* input_ids,
std::vector<std::pair<int, ggml_tensor*>> image_embeds,
ggml_tensor** output,
ggml_context* output_ctx = NULL) {
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(input_ids);
return build_graph(input_ids, image_embeds);
};
GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
}
int64_t get_num_image_tokens(int64_t t, int64_t h, int64_t w) {
int grid_t = 1;
int grid_h = h / params.vision.patch_size;
int grid_w = w / params.vision.patch_size;
int llm_grid_h = grid_h / params.vision.spatial_merge_size;
int llm_grid_w = grid_w / params.vision.spatial_merge_size;
return grid_t * grid_h * grid_w;
}
struct ggml_tensor* process_image(struct ggml_context* ctx, struct ggml_tensor* image) {
// image: [C, H, W]
// return: [grid_t*(H/mh/ph)*(W/mw/pw)*mh*mw, C*pt*ph*pw], grid_t == 1
int64_t C = image->ne[2];
int64_t H = image->ne[1];
int64_t W = image->ne[0];
int64_t mh = params.vision.spatial_merge_size;
int64_t mw = params.vision.spatial_merge_size;
int64_t pt = params.vision.temporal_patch_size;
int64_t ph = params.vision.patch_size;
int64_t pw = params.vision.patch_size;
image = ggml_reshape_4d(ctx, image, pw, mw, (W / mw / pw), H * C); // [C*H, (W/mw/pw), mw, pw]
image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 3, 1)); // [mw, C*H, (W/mw/pw), pw]
image = ggml_reshape_4d(ctx, image, pw * (W / mw / pw), H, C, mw); // [mw, C, H, (W/mw/pw)*pw]
image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 3, 1)); // [H, mw, C, (W/mw/pw)*pw]
image = ggml_reshape_4d(ctx, image, pw, (W / mw / pw) * C * mw, ph, mh * (H / mh / ph)); // [(H/mh/ph)*mh, ph, mw*C*(W/mw/pw), pw]
image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 1, 3)); // [(H/mh/ph)*mh, mw*C*(W/mw/pw), ph, pw]
image = ggml_reshape_4d(ctx, image, pw * ph, (W / mw / pw), C, mw * mh * (H / mh / ph)); // [(H/mh/ph)*mh*mw, C, (W/mw/pw), ph*pw]
image = ggml_concat(ctx, image, image, 0); // [(H/mh/ph)*mh*mw, C, (W/mw/pw), pt*ph*pw]
image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 1, 3)); // [(H/mh/ph)*mh*mw, (W/mw/pw), C, pt*ph*pw]
image = ggml_reshape_4d(ctx, image, pw * ph * pt * C, (W / mw / pw), mw * mh, (H / mh / ph)); // [(H/mh/ph), mh*mw, (W/mw/pw), C*pt*ph*pw]
image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 1, 3)); // [(H/mh/ph), (W/mw/pw), mh*mw, C*pt*ph*pw]
image = ggml_reshape_2d(ctx, image, pw * ph * pt * C, mw * mh * (W / mw / pw) * (H / mh / ph)); // [(H/mh/ph)*(W/mw/pw)*mh*mw, C*pt*ph*pw]
return image;
}
struct ggml_cgraph* build_encode_image_graph(struct ggml_tensor* image) {
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, QWENVL_GRAPH_SIZE, false);
GGML_ASSERT(image->ne[1] % (params.vision.patch_size * params.vision.spatial_merge_size) == 0);
GGML_ASSERT(image->ne[0] % (params.vision.patch_size * params.vision.spatial_merge_size) == 0);
int grid_t = 1;
int grid_h = image->ne[1] / params.vision.patch_size;
int grid_w = image->ne[0] / params.vision.patch_size;
int llm_grid_h = grid_h / params.vision.spatial_merge_size;
int llm_grid_w = grid_w / params.vision.spatial_merge_size;
int vit_merger_window_size = params.vision.window_size / params.vision.patch_size / params.vision.spatial_merge_size;
image = to_backend(image);
auto pixel_values = process_image(compute_ctx, image);
// window index
int inverse_index = 0;
window_index_vec.resize(llm_grid_h * llm_grid_w);
window_inverse_index_vec.resize(llm_grid_h * llm_grid_w);
std::vector<int> seqlens;
for (int ih = 0; ih < llm_grid_h; ih += vit_merger_window_size) {
for (int iw = 0; iw < llm_grid_w; iw += vit_merger_window_size) {
int win_h = std::min(vit_merger_window_size, llm_grid_h - ih);
int win_w = std::min(vit_merger_window_size, llm_grid_w - iw);
for (int iy = 0; iy < win_h; iy++) {
for (int ix = 0; ix < win_w; ix++) {
int index = (ih + iy) * llm_grid_w + iw + ix;
window_index_vec[inverse_index] = index;
window_inverse_index_vec[index] = inverse_index;
inverse_index++;
}
}
seqlens.push_back(win_h * win_w * params.vision.spatial_merge_size * params.vision.spatial_merge_size);
}
}
// printf("window_index: ");
// for (int i : window_index_vec) {
// printf("%d ", i);
// }
// printf("\n");
// printf("window_inverse_index: ");
// for (int i : window_inverse_index_vec) {
// printf("%d ", i);
// }
// printf("\n");
// printf("seqlens: ");
// for (int i : seqlens) {
// printf("%d ", i);
// }
// printf("\n");
auto window_index = ggml_new_tensor_1d(compute_ctx,
GGML_TYPE_I32,
llm_grid_h * llm_grid_w);
auto window_inverse_index = ggml_new_tensor_1d(compute_ctx,
GGML_TYPE_I32,
llm_grid_h * llm_grid_w);
set_backend_tensor_data(window_index, window_index_vec.data());
set_backend_tensor_data(window_inverse_index, window_inverse_index_vec.data());
// window mask
int seq_window_size = (vit_merger_window_size * params.vision.spatial_merge_size) * (vit_merger_window_size * params.vision.spatial_merge_size);
window_mask_vec.resize((grid_h * grid_w) * (grid_h * grid_w));
int window_start_index = 0;
for (int seq_index = 0; seq_index < seqlens.size(); seq_index++) {
int window_end_index = window_start_index + seqlens[seq_index];
// LOG_DEBUG("%d %d", window_start_index, window_end_index);
GGML_ASSERT(window_end_index <= grid_h * grid_w);
for (int i = window_start_index; i < window_end_index; i++) {
for (int j = 0; j < grid_h * grid_w; j++) {
float mask_value = -INFINITY;
if (j >= window_start_index && j < window_end_index) {
mask_value = 0;
}
GGML_ASSERT((i * (grid_h * grid_w) + j) < window_mask_vec.size());
window_mask_vec[i * (grid_h * grid_w) + j] = mask_value;
}
}
window_start_index = window_end_index;
// printf("\n");
}
// printf("window_mask: \n");
// for (int i = 0; i < grid_h*grid_w; i++) {
// for (int j = 0; j < grid_h*grid_w; j++) {
// printf("%f ", window_mask_vec[i * (grid_h * grid_w) + j]);
// }
// printf("\n");
// }
auto window_mask = ggml_new_tensor_2d(compute_ctx,
GGML_TYPE_F32,
grid_h * grid_w,
grid_h * grid_w);
set_backend_tensor_data(window_mask, window_mask_vec.data());
// pe
int head_dim = params.vision.hidden_size / params.vision.num_heads;
pe_vec = Rope::gen_qwen2vl_pe(grid_h,
grid_w,
params.vision.spatial_merge_size,
window_inverse_index_vec,
10000.f,
{head_dim / 2, head_dim / 2});
int pos_len = pe_vec.size() / head_dim / 2;
// LOG_DEBUG("pos_len %d", pos_len);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, head_dim / 2, pos_len);
// pe->data = pe_vec.data();
// print_ggml_tensor(pe);
// pe->data = NULL;
set_backend_tensor_data(pe, pe_vec.data());
struct ggml_tensor* hidden_states = vision_forward(compute_ctx,
runtime_backend,
pixel_values,
pe,
window_index,
window_inverse_index,
window_mask);
ggml_build_forward_expand(gf, hidden_states);
return gf;
}
void encode_image(const int n_threads,
struct ggml_tensor* image,
ggml_tensor** output,
ggml_context* output_ctx = NULL) {
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_encode_image_graph(image);
};
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
}
};
struct Qwen2_5_VLEmbedder {
@ -611,8 +1195,9 @@ namespace Qwen {
Qwen2_5_VLEmbedder(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "")
: model(backend, offload_params_to_cpu, tensor_types, prefix) {
const std::string prefix = "",
bool enable_vision = false)
: model(backend, offload_params_to_cpu, tensor_types, prefix, enable_vision) {
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
@ -666,8 +1251,76 @@ namespace Qwen {
struct ggml_context* work_ctx = ggml_init(params);
GGML_ASSERT(work_ctx != NULL);
bool test_vit = true;
bool test_decoder_with_vit = true;
{
if (test_decoder_with_vit) {
ggml_tensor* image_embed = nullptr;
{
auto image = load_tensor_from_file(work_ctx, "qwen2vl_normalized.bin");
print_ggml_tensor(image, false, "image");
struct ggml_tensor* out = NULL;
int t0 = ggml_time_ms();
model.encode_image(8, image, &out, work_ctx);
int t1 = ggml_time_ms();
print_ggml_tensor(out, false, "image_embed");
image_embed = out;
LOG_DEBUG("qwen2vl encode_image test done in %dms", t1 - t0);
}
std::string placeholder = "<|image_pad|>";
std::string img_prompt = "Picture 1: <|vision_start|>"; // [24669, 220, 16, 25, 220, 151652]
int64_t num_image_tokens = image_embed->ne[1];
img_prompt.reserve(num_image_tokens * placeholder.size());
for (int i = 0; i < num_image_tokens; i++) {
img_prompt += placeholder;
}
img_prompt += "<|vision_end|>";
std::vector<std::pair<int, ggml_tensor*>> image_embeds;
image_embeds.emplace_back(64, image_embed);
std::string text = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n";
text += img_prompt;
text += "change 'flux.cpp' to 'edit.cpp'";
text += "<|im_end|>\n<|im_start|>assistant\n";
auto tokens_and_weights = tokenize(text, 0, false);
std::vector<int>& tokens = std::get<0>(tokens_and_weights);
std::vector<float>& weights = std::get<1>(tokens_and_weights);
for (auto token : tokens) {
printf("%d ", token);
}
printf("\n");
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
struct ggml_tensor* out = NULL;
int t0 = ggml_time_ms();
model.compute(8, input_ids, image_embeds, &out, work_ctx);
int t1 = ggml_time_ms();
print_ggml_tensor(out);
LOG_DEBUG("qwen2vl test done in %dms", t1 - t0);
} else if (test_vit) {
// auto image = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 280, 280, 3);
// ggml_set_f32(image, 0.f);
auto image = load_tensor_from_file(work_ctx, "qwen2vl_normalized.bin");
print_ggml_tensor(image, false, "image");
struct ggml_tensor* out = NULL;
int t0 = ggml_time_ms();
model.encode_image(8, image, &out, work_ctx);
int t1 = ggml_time_ms();
print_ggml_tensor(out, false, "out");
// auto ref_out = load_tensor_from_file(work_ctx, "qwen2vl.bin");
// ggml_tensor_diff(ref_out, out, 0.01f);
LOG_DEBUG("qwen2vl test done in %dms", t1 - t0);
} else {
std::string text("<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\na lovely cat<|im_end|>\n<|im_start|>assistant\n");
auto tokens_and_weights = tokenize(text, 0, false);
std::vector<int>& tokens = std::get<0>(tokens_and_weights);
@ -680,7 +1333,7 @@ namespace Qwen {
struct ggml_tensor* out = NULL;
int t0 = ggml_time_ms();
model.compute(8, input_ids, &out, work_ctx);
model.compute(8, input_ids, {}, &out, work_ctx);
int t1 = ggml_time_ms();
print_ggml_tensor(out);
@ -692,7 +1345,7 @@ namespace Qwen {
// cpu f16: pass
// ggml_backend_t backend = ggml_backend_cuda_init(0);
ggml_backend_t backend = ggml_backend_cpu_init();
ggml_type model_data_type = GGML_TYPE_Q8_0;
ggml_type model_data_type = GGML_TYPE_F16;
ModelLoader model_loader;
if (!model_loader.init_from_file(file_path, "qwen2vl.")) {
@ -708,7 +1361,11 @@ namespace Qwen {
}
}
std::shared_ptr<Qwen2_5_VLEmbedder> qwenvl = std::shared_ptr<Qwen2_5_VLEmbedder>(new Qwen2_5_VLEmbedder(backend, false, tensor_types, "qwen2vl"));
std::shared_ptr<Qwen2_5_VLEmbedder> qwenvl = std::shared_ptr<Qwen2_5_VLEmbedder>(new Qwen2_5_VLEmbedder(backend,
false,
tensor_types,
"qwen2vl",
true));
qwenvl->alloc_params_buffer();
std::map<std::string, ggml_tensor*> tensors;

251
rope.hpp
View File

@ -4,9 +4,9 @@
#include <vector>
#include "ggml_extend.hpp"
struct Rope {
namespace Rope {
template <class T>
static std::vector<T> linspace(T start, T end, int num) {
__STATIC_INLINE__ std::vector<T> linspace(T start, T end, int num) {
std::vector<T> result(num);
if (num == 1) {
result[0] = start;
@ -19,7 +19,7 @@ struct Rope {
return result;
}
static std::vector<std::vector<float>> transpose(const std::vector<std::vector<float>>& mat) {
__STATIC_INLINE__ std::vector<std::vector<float>> transpose(const std::vector<std::vector<float>>& mat) {
int rows = mat.size();
int cols = mat[0].size();
std::vector<std::vector<float>> transposed(cols, std::vector<float>(rows));
@ -31,7 +31,7 @@ struct Rope {
return transposed;
}
static std::vector<float> flatten(const std::vector<std::vector<float>>& vec) {
__STATIC_INLINE__ std::vector<float> flatten(const std::vector<std::vector<float>>& vec) {
std::vector<float> flat_vec;
for (const auto& sub_vec : vec) {
flat_vec.insert(flat_vec.end(), sub_vec.begin(), sub_vec.end());
@ -39,7 +39,7 @@ struct Rope {
return flat_vec;
}
static std::vector<std::vector<float>> rope(const std::vector<float>& pos, int dim, int theta) {
__STATIC_INLINE__ std::vector<std::vector<float>> rope(const std::vector<float>& pos, int dim, int theta) {
assert(dim % 2 == 0);
int half_dim = dim / 2;
@ -72,11 +72,11 @@ struct Rope {
}
// Generate IDs for image patches and text
static std::vector<std::vector<float>> gen_txt_ids(int bs, int context_len) {
__STATIC_INLINE__ std::vector<std::vector<float>> gen_txt_ids(int bs, int context_len) {
return std::vector<std::vector<float>>(bs * context_len, std::vector<float>(3, 0.0));
}
static std::vector<std::vector<float>> gen_img_ids(int h, int w, int patch_size, int bs, int index = 0, int h_offset = 0, int w_offset = 0) {
__STATIC_INLINE__ std::vector<std::vector<float>> gen_img_ids(int h, int w, int patch_size, int bs, int index = 0, int h_offset = 0, int w_offset = 0) {
int h_len = (h + (patch_size / 2)) / patch_size;
int w_len = (w + (patch_size / 2)) / patch_size;
@ -102,9 +102,9 @@ struct Rope {
return img_ids_repeated;
}
static std::vector<std::vector<float>> concat_ids(const std::vector<std::vector<float>>& a,
const std::vector<std::vector<float>>& b,
int bs) {
__STATIC_INLINE__ std::vector<std::vector<float>> concat_ids(const std::vector<std::vector<float>>& a,
const std::vector<std::vector<float>>& b,
int bs) {
size_t a_len = a.size() / bs;
size_t b_len = b.size() / bs;
std::vector<std::vector<float>> ids(a.size() + b.size(), std::vector<float>(3));
@ -119,10 +119,10 @@ struct Rope {
return ids;
}
static std::vector<float> embed_nd(const std::vector<std::vector<float>>& ids,
int bs,
int theta,
const std::vector<int>& axes_dim) {
__STATIC_INLINE__ std::vector<float> embed_nd(const std::vector<std::vector<float>>& ids,
int bs,
int theta,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> trans_ids = transpose(ids);
size_t pos_len = ids.size() / bs;
int num_axes = axes_dim.size();
@ -151,17 +151,11 @@ struct Rope {
return flatten(emb);
}
static std::vector<std::vector<float>> gen_flux_ids(int h,
int w,
int patch_size,
int bs,
int context_len,
std::vector<ggml_tensor*> ref_latents,
bool increase_ref_index) {
auto txt_ids = gen_txt_ids(bs, context_len);
auto img_ids = gen_img_ids(h, w, patch_size, bs);
auto ids = concat_ids(txt_ids, img_ids, bs);
__STATIC_INLINE__ std::vector<std::vector<float>> gen_refs_ids(int patch_size,
int bs,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index) {
std::vector<std::vector<float>> ids;
uint64_t curr_h_offset = 0;
uint64_t curr_w_offset = 0;
int index = 1;
@ -189,25 +183,45 @@ struct Rope {
return ids;
}
__STATIC_INLINE__ std::vector<std::vector<float>> gen_flux_ids(int h,
int w,
int patch_size,
int bs,
int context_len,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index) {
auto txt_ids = gen_txt_ids(bs, context_len);
auto img_ids = gen_img_ids(h, w, patch_size, bs);
auto ids = concat_ids(txt_ids, img_ids, bs);
if (ref_latents.size() > 0) {
auto refs_ids = gen_refs_ids(patch_size, bs, ref_latents, increase_ref_index);
ids = concat_ids(ids, refs_ids, bs);
}
return ids;
}
// Generate flux positional embeddings
static std::vector<float> gen_flux_pe(int h,
int w,
int patch_size,
int bs,
int context_len,
std::vector<ggml_tensor*> ref_latents,
bool increase_ref_index,
int theta,
const std::vector<int>& axes_dim) {
__STATIC_INLINE__ std::vector<float> gen_flux_pe(int h,
int w,
int patch_size,
int bs,
int context_len,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index,
int theta,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_flux_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
return embed_nd(ids, bs, theta, axes_dim);
}
static std::vector<std::vector<float>> gen_qwen_image_ids(int h,
int w,
int patch_size,
int bs,
int context_len) {
__STATIC_INLINE__ std::vector<std::vector<float>> gen_qwen_image_ids(int h,
int w,
int patch_size,
int bs,
int context_len,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index) {
int h_len = (h + (patch_size / 2)) / patch_size;
int w_len = (w + (patch_size / 2)) / patch_size;
int txt_id_start = std::max(h_len, w_len);
@ -220,31 +234,37 @@ struct Rope {
}
auto img_ids = gen_img_ids(h, w, patch_size, bs);
auto ids = concat_ids(txt_ids_repeated, img_ids, bs);
if (ref_latents.size() > 0) {
auto refs_ids = gen_refs_ids(patch_size, bs, ref_latents, increase_ref_index);
ids = concat_ids(ids, refs_ids, bs);
}
return ids;
}
// Generate qwen_image positional embeddings
static std::vector<float> gen_qwen_image_pe(int h,
int w,
int patch_size,
int bs,
int context_len,
int theta,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len);
__STATIC_INLINE__ std::vector<float> gen_qwen_image_pe(int h,
int w,
int patch_size,
int bs,
int context_len,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index,
int theta,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
return embed_nd(ids, bs, theta, axes_dim);
}
static std::vector<std::vector<float>> gen_vid_ids(int t,
int h,
int w,
int pt,
int ph,
int pw,
int bs,
int t_offset = 0,
int h_offset = 0,
int w_offset = 0) {
__STATIC_INLINE__ std::vector<std::vector<float>> gen_vid_ids(int t,
int h,
int w,
int pt,
int ph,
int pw,
int bs,
int t_offset = 0,
int h_offset = 0,
int w_offset = 0) {
int t_len = (t + (pt / 2)) / pt;
int h_len = (h + (ph / 2)) / ph;
int w_len = (w + (pw / 2)) / pw;
@ -276,18 +296,115 @@ struct Rope {
}
// Generate wan positional embeddings
static std::vector<float> gen_wan_pe(int t,
int h,
int w,
int pt,
int ph,
int pw,
int bs,
int theta,
const std::vector<int>& axes_dim) {
__STATIC_INLINE__ std::vector<float> gen_wan_pe(int t,
int h,
int w,
int pt,
int ph,
int pw,
int bs,
int theta,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_vid_ids(t, h, w, pt, ph, pw, bs);
return embed_nd(ids, bs, theta, axes_dim);
}
}; // struct Rope
__STATIC_INLINE__ std::vector<std::vector<float>> gen_qwen2vl_ids(int grid_h,
int grid_w,
int merge_size,
const std::vector<int>& window_index) {
std::vector<std::vector<float>> ids(grid_h * grid_w, std::vector<float>(2, 0.0));
int index = 0;
for (int ih = 0; ih < grid_h; ih += merge_size) {
for (int iw = 0; iw < grid_w; iw += merge_size) {
for (int iy = 0; iy < merge_size; iy++) {
for (int ix = 0; ix < merge_size; ix++) {
int inverse_index = window_index[index / (merge_size * merge_size)];
int i = inverse_index * (merge_size * merge_size) + index % (merge_size * merge_size);
GGML_ASSERT(i < grid_h * grid_w);
ids[i][0] = ih + iy;
ids[i][1] = iw + ix;
index++;
}
}
}
}
return ids;
}
// Generate qwen2vl positional embeddings
__STATIC_INLINE__ std::vector<float> gen_qwen2vl_pe(int grid_h,
int grid_w,
int merge_size,
const std::vector<int>& window_index,
int theta,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_qwen2vl_ids(grid_h, grid_w, merge_size, window_index);
return embed_nd(ids, 1, theta, axes_dim);
}
__STATIC_INLINE__ struct ggml_tensor* apply_rope(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* pe,
bool rope_interleaved = true) {
// x: [N, L, n_head, d_head]
// pe: [L, d_head/2, 2, 2], [[cos, -sin], [sin, cos]]
int64_t d_head = x->ne[0];
int64_t n_head = x->ne[1];
int64_t L = x->ne[2];
int64_t N = x->ne[3];
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, n_head, L, d_head]
if (rope_interleaved) {
x = ggml_reshape_4d(ctx, x, 2, d_head / 2, L, n_head * N); // [N * n_head, L, d_head/2, 2]
x = ggml_cont(ctx, ggml_permute(ctx, x, 3, 0, 1, 2)); // [2, N * n_head, L, d_head/2]
} else {
x = ggml_reshape_4d(ctx, x, d_head / 2, 2, L, n_head * N); // [N * n_head, L, 2, d_head/2]
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 3, 1)); // [2, N * n_head, L, d_head/2]
}
int64_t offset = x->nb[2] * x->ne[2];
auto x_0 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 0); // [N * n_head, L, d_head/2]
auto x_1 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 1); // [N * n_head, L, d_head/2]
x_0 = ggml_reshape_4d(ctx, x_0, 1, x_0->ne[0], x_0->ne[1], x_0->ne[2]); // [N * n_head, L, d_head/2, 1]
x_1 = ggml_reshape_4d(ctx, x_1, 1, x_1->ne[0], x_1->ne[1], x_1->ne[2]); // [N * n_head, L, d_head/2, 1]
auto temp_x = ggml_new_tensor_4d(ctx, x_0->type, 2, x_0->ne[1], x_0->ne[2], x_0->ne[3]);
x_0 = ggml_repeat(ctx, x_0, temp_x); // [N * n_head, L, d_head/2, 2]
x_1 = ggml_repeat(ctx, x_1, temp_x); // [N * n_head, L, d_head/2, 2]
pe = ggml_cont(ctx, ggml_permute(ctx, pe, 3, 0, 1, 2)); // [2, L, d_head/2, 2]
offset = pe->nb[2] * pe->ne[2];
auto pe_0 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 0); // [L, d_head/2, 2]
auto pe_1 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 1); // [L, d_head/2, 2]
auto x_out = ggml_add_inplace(ctx, ggml_mul(ctx, x_0, pe_0), ggml_mul(ctx, x_1, pe_1)); // [N * n_head, L, d_head/2, 2]
if (!rope_interleaved) {
x_out = ggml_cont(ctx, ggml_permute(ctx, x_out, 1, 0, 2, 3)); // [N * n_head, L, x, d_head/2]
}
x_out = ggml_reshape_3d(ctx, x_out, d_head, L, n_head * N); // [N*n_head, L, d_head]
return x_out;
}
__STATIC_INLINE__ struct ggml_tensor* attention(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* q,
struct ggml_tensor* k,
struct ggml_tensor* v,
struct ggml_tensor* pe,
struct ggml_tensor* mask,
bool flash_attn,
float kv_scale = 1.0f,
bool rope_interleaved = true) {
// q,k,v: [N, L, n_head, d_head]
// pe: [L, d_head/2, 2, 2]
// return: [N, L, n_head*d_head]
q = apply_rope(ctx, q, pe, rope_interleaved); // [N*n_head, L, d_head]
k = apply_rope(ctx, k, pe, rope_interleaved); // [N*n_head, L, d_head]
auto x = ggml_nn_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn, kv_scale); // [N, L, n_head*d_head]
return x;
}
}; // namespace Rope
#endif // __ROPE_HPP__

View File

@ -261,6 +261,13 @@ public:
}
}
if (strlen(SAFE_STR(sd_ctx_params->qwen2vl_vision_path)) > 0) {
LOG_INFO("loading qwen2vl vision from '%s'", sd_ctx_params->qwen2vl_vision_path);
if (!model_loader.init_from_file(sd_ctx_params->qwen2vl_vision_path, "text_encoders.qwen2vl.visual.")) {
LOG_WARN("loading qwen2vl vision from '%s' failed", sd_ctx_params->qwen2vl_vision_path);
}
}
if (strlen(SAFE_STR(sd_ctx_params->vae_path)) > 0) {
LOG_INFO("loading vae from '%s'", sd_ctx_params->vae_path);
if (!model_loader.init_from_file(sd_ctx_params->vae_path, "vae.")) {
@ -274,6 +281,15 @@ public:
return false;
}
auto& tensor_types = model_loader.tensor_storages_types;
for (auto& item : tensor_types) {
// LOG_DEBUG("%s %u", item.first.c_str(), item.second);
if (contains(item.first, "qwen2vl") && ends_with(item.first, "weight") && (item.second == GGML_TYPE_F32 || item.second == GGML_TYPE_BF16)) {
item.second = GGML_TYPE_F16;
// LOG_DEBUG(" change %s %u", item.first.c_str(), item.second);
}
}
LOG_INFO("Version: %s ", model_version_to_str[version]);
ggml_type wtype = (int)sd_ctx_params->wtype < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT)
? (ggml_type)sd_ctx_params->wtype
@ -417,9 +433,15 @@ public:
clip_vision->get_param_tensors(tensors);
}
} else if (sd_version_is_qwen_image(version)) {
bool enable_vision = false;
if (!vae_decode_only) {
enable_vision = true;
}
cond_stage_model = std::make_shared<Qwen2_5_VLCLIPEmbedder>(clip_backend,
offload_params_to_cpu,
model_loader.tensor_storages_types);
model_loader.tensor_storages_types,
"",
enable_vision);
diffusion_model = std::make_shared<QwenImageModel>(backend,
offload_params_to_cpu,
model_loader.tensor_storages_types,
@ -590,7 +612,9 @@ public:
if (vae_decode_only) {
ignore_tensors.insert("first_stage_model.encoder");
ignore_tensors.insert("first_stage_model.conv1");
ignore_tensors.insert("first_stage_model.quant");
ignore_tensors.insert("text_encoders.qwen2vl.visual.");
}
if (version == VERSION_SVD) {
ignore_tensors.insert("conditioner.embedders.3");
@ -949,12 +973,12 @@ public:
ggml_set_f32(output, 0.f);
} else {
sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(init_image);
sd_image_f32_t resized_image = clip_preprocess(image, clip_vision->vision_model.image_size);
sd_image_f32_t resized_image = clip_preprocess(image, clip_vision->vision_model.image_size, clip_vision->vision_model.image_size);
free(image.data);
image.data = NULL;
ggml_tensor* pixel_values = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
sd_image_f32_to_tensor(resized_image.data, pixel_values, false);
sd_image_f32_to_tensor(resized_image, pixel_values, false);
free(resized_image.data);
resized_image.data = NULL;
@ -991,7 +1015,7 @@ public:
sd_image_f32_t resized_image = resize_sd_image_f32_t(image, width, height);
free(image.data);
image.data = NULL;
sd_image_f32_to_tensor(resized_image.data, init_img, false);
sd_image_f32_to_tensor(resized_image, init_img, false);
free(resized_image.data);
resized_image.data = NULL;
} else {
@ -1749,6 +1773,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
"clip_vision_path: %s\n"
"t5xxl_path: %s\n"
"qwen2vl_path: %s\n"
"qwen2vl_vision_path: %s\n"
"diffusion_model_path: %s\n"
"high_noise_diffusion_model_path: %s\n"
"vae_path: %s\n"
@ -1777,6 +1802,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
SAFE_STR(sd_ctx_params->clip_vision_path),
SAFE_STR(sd_ctx_params->t5xxl_path),
SAFE_STR(sd_ctx_params->qwen2vl_path),
SAFE_STR(sd_ctx_params->qwen2vl_vision_path),
SAFE_STR(sd_ctx_params->diffusion_model_path),
SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path),
SAFE_STR(sd_ctx_params->vae_path),
@ -1987,6 +2013,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
sd_image_t control_image,
float control_strength,
sd_pm_params_t pm_params,
std::vector<sd_image_t*> ref_images,
std::vector<ggml_tensor*> ref_latents,
bool increase_ref_index,
ggml_tensor* concat_latent = NULL,
@ -2019,6 +2046,14 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
ggml_tensor* init_img = NULL;
SDCondition id_cond;
std::vector<bool> class_tokens_mask;
ConditionerParams condition_params;
condition_params.clip_skip = clip_skip;
condition_params.width = width;
condition_params.height = height;
condition_params.ref_images = ref_images;
condition_params.adm_in_channels = sd_ctx->sd->diffusion_model->get_adm_in_channels();
if (sd_ctx->sd->stacked_id) {
if (!sd_ctx->sd->pmid_lora->applied) {
int64_t t0 = ggml_time_ms();
@ -2041,7 +2076,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
std::vector<sd_image_f32_t> processed_id_images;
for (int i = 0; i < pm_params.id_images_count; i++) {
sd_image_f32_t id_image = sd_image_t_to_sd_image_f32_t(pm_params.id_images[i]);
sd_image_f32_t processed_id_image = clip_preprocess(id_image, clip_image_size);
sd_image_f32_t processed_id_image = clip_preprocess(id_image, clip_image_size, clip_image_size);
free(id_image.data);
id_image.data = NULL;
processed_id_images.push_back(processed_id_image);
@ -2058,17 +2093,15 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
}
processed_id_images.clear();
int64_t t0 = ggml_time_ms();
auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx,
sd_ctx->sd->n_threads, prompt,
clip_skip,
width,
height,
pm_params.id_images_count,
sd_ctx->sd->diffusion_model->get_adm_in_channels());
id_cond = std::get<0>(cond_tup);
class_tokens_mask = std::get<1>(cond_tup); //
struct ggml_tensor* id_embeds = NULL;
int64_t t0 = ggml_time_ms();
condition_params.text = prompt;
condition_params.num_input_imgs = pm_params.id_images_count;
auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx,
sd_ctx->sd->n_threads,
condition_params);
id_cond = std::get<0>(cond_tup);
class_tokens_mask = std::get<1>(cond_tup); //
struct ggml_tensor* id_embeds = NULL;
if (pmv2 && pm_params.id_embed_path != nullptr) {
id_embeds = load_tensor_from_file(work_ctx, pm_params.id_embed_path);
// print_ggml_tensor(id_embeds, true, "id_embeds:");
@ -2094,14 +2127,12 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
}
// Get learned condition
t0 = ggml_time_ms();
SDCondition cond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
sd_ctx->sd->n_threads,
prompt,
clip_skip,
width,
height,
sd_ctx->sd->diffusion_model->get_adm_in_channels());
t0 = ggml_time_ms();
condition_params.text = prompt;
condition_params.zero_out_masked = false;
SDCondition cond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
sd_ctx->sd->n_threads,
condition_params);
SDCondition uncond;
if (guidance.txt_cfg != 1.0 ||
@ -2110,14 +2141,11 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0 && !sd_ctx->sd->is_using_edm_v_parameterization) {
zero_out_masked = true;
}
uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
sd_ctx->sd->n_threads,
negative_prompt,
clip_skip,
width,
height,
sd_ctx->sd->diffusion_model->get_adm_in_channels(),
zero_out_masked);
condition_params.text = negative_prompt;
condition_params.zero_out_masked = zero_out_masked;
uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
sd_ctx->sd->n_threads,
condition_params);
}
int64_t t1 = ggml_time_ms();
LOG_INFO("get_learned_condition completed, taking %" PRId64 " ms", t1 - t0);
@ -2538,13 +2566,42 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
std::vector<ggml_tensor*> ref_latents;
for (int i = 0; i < ref_images.size(); i++) {
ggml_tensor* img = ggml_new_tensor_4d(work_ctx,
GGML_TYPE_F32,
ref_images[i]->width,
ref_images[i]->height,
3,
1);
sd_image_to_tensor(*ref_images[i], img);
ggml_tensor* img;
if (sd_version_is_qwen_image(sd_ctx->sd->version)) {
sd_image_f32_t ref_image = sd_image_t_to_sd_image_f32_t(*ref_images[i]);
int VAE_IMAGE_SIZE = std::min(1024 * 1024, width * height);
double vae_width = sqrt(VAE_IMAGE_SIZE * ref_image.width / ref_image.height);
double vae_height = vae_width * ref_image.height / ref_image.width;
vae_height = round(vae_height / 32) * 32;
vae_width = round(vae_width / 32) * 32;
sd_image_f32_t resized_image = resize_sd_image_f32_t(ref_image, static_cast<int>(vae_width), static_cast<int>(vae_height));
free(ref_image.data);
ref_image.data = nullptr;
LOG_DEBUG("resize vae ref image %d from %dx%d to %dx%d", i, ref_image.height, ref_image.width, resized_image.height, resized_image.width);
img = ggml_new_tensor_4d(work_ctx,
GGML_TYPE_F32,
resized_image.width,
resized_image.height,
3,
1);
sd_image_f32_to_tensor(resized_image, img);
free(resized_image.data);
resized_image.data = nullptr;
} else {
img = ggml_new_tensor_4d(work_ctx,
GGML_TYPE_F32,
ref_images[i]->width,
ref_images[i]->height,
3,
1);
sd_image_to_tensor(*ref_images[i], img);
}
// print_ggml_tensor(img, false, "img");
ggml_tensor* latent = sd_ctx->sd->encode_first_stage(work_ctx, img);
ref_latents.push_back(latent);
@ -2578,6 +2635,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
sd_img_gen_params->control_image,
sd_img_gen_params->control_strength,
sd_img_gen_params->pm_params,
ref_images,
ref_latents,
sd_img_gen_params->increase_ref_index,
concat_latent,
@ -2835,30 +2893,25 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
}
// Get learned condition
bool zero_out_masked = true;
int64_t t1 = ggml_time_ms();
SDCondition cond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
sd_ctx->sd->n_threads,
prompt,
sd_vid_gen_params->clip_skip,
width,
height,
sd_ctx->sd->diffusion_model->get_adm_in_channels(),
zero_out_masked);
cond.c_concat = concat_latent;
cond.c_vector = clip_vision_output;
ConditionerParams condition_params;
condition_params.clip_skip = sd_vid_gen_params->clip_skip;
condition_params.zero_out_masked = true;
condition_params.text = prompt;
int64_t t1 = ggml_time_ms();
SDCondition cond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
sd_ctx->sd->n_threads,
condition_params);
cond.c_concat = concat_latent;
cond.c_vector = clip_vision_output;
SDCondition uncond;
if (sd_vid_gen_params->sample_params.guidance.txt_cfg != 1.0 || sd_vid_gen_params->high_noise_sample_params.guidance.txt_cfg != 1.0) {
uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
sd_ctx->sd->n_threads,
negative_prompt,
sd_vid_gen_params->clip_skip,
width,
height,
sd_ctx->sd->diffusion_model->get_adm_in_channels(),
zero_out_masked);
uncond.c_concat = concat_latent;
uncond.c_vector = clip_vision_output;
condition_params.text = negative_prompt;
uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
sd_ctx->sd->n_threads,
condition_params);
uncond.c_concat = concat_latent;
uncond.c_vector = clip_vision_output;
}
int64_t t2 = ggml_time_ms();
LOG_INFO("get_learned_condition completed, taking %" PRId64 " ms", t2 - t1);

View File

@ -132,6 +132,7 @@ typedef struct {
const char* clip_vision_path;
const char* t5xxl_path;
const char* qwen2vl_path;
const char* qwen2vl_vision_path;
const char* diffusion_model_path;
const char* high_noise_diffusion_model_path;
const char* vae_path;

View File

@ -84,6 +84,7 @@ int round_up_to(int value, int base) {
}
#ifdef _WIN32 // code for windows
#define NOMINMAX
#include <windows.h>
bool file_exists(const std::string& filename) {
@ -298,7 +299,7 @@ std::string trim(const std::string& s) {
static sd_log_cb_t sd_log_cb = NULL;
void* sd_log_cb_data = NULL;
#define LOG_BUFFER_SIZE 1024
#define LOG_BUFFER_SIZE 4096
void log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...) {
va_list args;
@ -387,10 +388,10 @@ sd_image_f32_t resize_sd_image_f32_t(sd_image_f32_t image, int target_width, int
float original_x = (float)x * image.width / target_width;
float original_y = (float)y * image.height / target_height;
int x1 = (int)original_x;
int y1 = (int)original_y;
int x2 = x1 + 1;
int y2 = y1 + 1;
uint32_t x1 = (uint32_t)original_x;
uint32_t y1 = (uint32_t)original_y;
uint32_t x2 = std::min(x1 + 1, image.width - 1);
uint32_t y2 = std::min(y1 + 1, image.height - 1);
for (int k = 0; k < image.channel; k++) {
float v1 = *(image.data + y1 * image.width * image.channel + x1 * image.channel + k);
@ -427,23 +428,26 @@ float means[3] = {0.48145466, 0.4578275, 0.40821073};
float stds[3] = {0.26862954, 0.26130258, 0.27577711};
// Function to clip and preprocess sd_image_f32_t
sd_image_f32_t clip_preprocess(sd_image_f32_t image, int size) {
float scale = (float)size / fmin(image.width, image.height);
sd_image_f32_t clip_preprocess(sd_image_f32_t image, int target_width, int target_height) {
float width_scale = (float)target_width / image.width;
float height_scale = (float)target_height / image.height;
float scale = std::fmax(width_scale, height_scale);
// Interpolation
int new_width = (int)(scale * image.width);
int new_height = (int)(scale * image.height);
float* resized_data = (float*)malloc(new_width * new_height * image.channel * sizeof(float));
int resized_width = (int)(scale * image.width);
int resized_height = (int)(scale * image.height);
float* resized_data = (float*)malloc(resized_width * resized_height * image.channel * sizeof(float));
for (int y = 0; y < new_height; y++) {
for (int x = 0; x < new_width; x++) {
float original_x = (float)x * image.width / new_width;
float original_y = (float)y * image.height / new_height;
for (int y = 0; y < resized_height; y++) {
for (int x = 0; x < resized_width; x++) {
float original_x = (float)x * image.width / resized_width;
float original_y = (float)y * image.height / resized_height;
int x1 = (int)original_x;
int y1 = (int)original_y;
int x2 = x1 + 1;
int y2 = y1 + 1;
uint32_t x1 = (uint32_t)original_x;
uint32_t y1 = (uint32_t)original_y;
uint32_t x2 = std::min(x1 + 1, image.width - 1);
uint32_t y2 = std::min(y1 + 1, image.height - 1);
for (int k = 0; k < image.channel; k++) {
float v1 = *(image.data + y1 * image.width * image.channel + x1 * image.channel + k);
@ -456,26 +460,28 @@ sd_image_f32_t clip_preprocess(sd_image_f32_t image, int size) {
float value = interpolate(v1, v2, v3, v4, x_ratio, y_ratio);
*(resized_data + y * new_width * image.channel + x * image.channel + k) = value;
*(resized_data + y * resized_width * image.channel + x * image.channel + k) = value;
}
}
}
// Clip and preprocess
int h = (new_height - size) / 2;
int w = (new_width - size) / 2;
int h_offset = std::max((int)(resized_height - target_height) / 2, 0);
int w_offset = std::max((int)(resized_width - target_width) / 2, 0);
sd_image_f32_t result;
result.width = size;
result.height = size;
result.width = target_width;
result.height = target_height;
result.channel = image.channel;
result.data = (float*)malloc(size * size * image.channel * sizeof(float));
result.data = (float*)malloc(target_height * target_width * image.channel * sizeof(float));
for (int k = 0; k < image.channel; k++) {
for (int i = 0; i < size; i++) {
for (int j = 0; j < size; j++) {
*(result.data + i * size * image.channel + j * image.channel + k) =
fmin(fmax(*(resized_data + (i + h) * new_width * image.channel + (j + w) * image.channel + k), 0.0f), 255.0f) / 255.0f;
for (int i = 0; i < result.height; i++) {
for (int j = 0; j < result.width; j++) {
int src_y = std::min(i + h_offset, resized_height - 1);
int src_x = std::min(j + w_offset, resized_width - 1);
*(result.data + i * result.width * image.channel + j * image.channel + k) =
fmin(fmax(*(resized_data + src_y * resized_width * image.channel + src_x * image.channel + k), 0.0f), 255.0f) / 255.0f;
}
}
}
@ -485,10 +491,10 @@ sd_image_f32_t clip_preprocess(sd_image_f32_t image, int size) {
// Normalize
for (int k = 0; k < image.channel; k++) {
for (int i = 0; i < size; i++) {
for (int j = 0; j < size; j++) {
for (int i = 0; i < result.height; i++) {
for (int j = 0; j < result.width; j++) {
// *(result.data + i * size * image.channel + j * image.channel + k) = 0.5f;
int offset = i * size * image.channel + j * image.channel + k;
int offset = i * result.width * image.channel + j * image.channel + k;
float value = *(result.data + offset);
value = (value - means[k]) / stds[k];
// value = 0.5f;

2
util.h
View File

@ -42,7 +42,7 @@ sd_image_f32_t sd_image_t_to_sd_image_f32_t(sd_image_t image);
sd_image_f32_t resize_sd_image_f32_t(sd_image_f32_t image, int target_width, int target_height);
sd_image_f32_t clip_preprocess(sd_image_f32_t image, int size);
sd_image_f32_t clip_preprocess(sd_image_f32_t image, int target_width, int target_height);
std::string path_join(const std::string& p1, const std::string& p2);
std::vector<std::string> split_string(const std::string& str, char delimiter);

View File

@ -1333,7 +1333,7 @@ namespace WAN {
k = ggml_reshape_4d(ctx, k, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head]
v = ggml_reshape_4d(ctx, v, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head]
x = Flux::attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_token, dim]
x = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_token, dim]
x = o_proj->forward(ctx, x); // [N, n_token, dim]
return x;