feat: add Qwen Image Edit support (#877)

* add ref latent support for qwen image

* optimize clip_preprocess and fix get_first_stage_encoding

* add qwen2vl vit support

* add qwen image edit support

* fix qwen image edit pipeline

* add mmproj file support

* support dynamic number of Qwen image transformer blocks

* set prompt_template_encode_start_idx every time

* to_add_out precision fix

* to_out.0 precision fix

* update docs
Author: leejet
Date: 2025-10-13 23:17:18 +08:00 (committed via GitHub)
Parent: c64994dc1d
Commit: 2e9242e37f
18 changed files with 1339 additions and 365 deletions


@@ -24,6 +24,7 @@ API and command-line option may change frequently.***
    - [Qwen Image](./docs/qwen_image.md)
- Image Edit Models
    - [FLUX.1-Kontext-dev](./docs/kontext.md)
    - [Qwen Image Edit/Qwen Image Edit 2509](./docs/qwen_image_edit.md)
- Video Models
    - [Wan2.1/Wan2.2](./docs/wan.md)
- [PhotoMaker](https://github.com/TencentARC/PhotoMaker) support.

@@ -298,6 +299,7 @@ arguments:
  --clip_vision                      path to the clip-vision encoder
  --t5xxl                            path to the t5xxl text encoder
  --qwen2vl                          path to the qwen2vl text encoder
  --qwen2vl_vision                   path to the qwen2vl vit
  --vae [VAE]                        path to vae
  --taesd [TAESD_PATH]               path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
  --control-net [CONTROL_PATH]       path to control net model

(Two binary image files added: 457 KiB and 415 KiB; previews not shown.)


@@ -15,28 +15,28 @@ struct SDCondition {
        : c_crossattn(c_crossattn), c_vector(c_vector), c_concat(c_concat) {}
};

struct ConditionerParams {
    std::string text;
    int clip_skip        = -1;
    int width            = -1;
    int height           = -1;
    int adm_in_channels  = -1;
    bool zero_out_masked = false;
    int num_input_imgs   = 0;                  // for photomaker
    std::vector<sd_image_t*> ref_images = {};  // for qwen image edit
};

struct Conditioner {
    virtual SDCondition get_learned_condition(ggml_context* work_ctx,
                                              int n_threads,
                                              const ConditionerParams& conditioner_params) = 0;
    virtual void alloc_params_buffer()                                                     = 0;
    virtual void free_params_buffer()                                                      = 0;
    virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors)    = 0;
    virtual size_t get_params_buffer_size()                                                = 0;
    virtual std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
                                                                                          int n_threads,
                                                                                          const ConditionerParams& conditioner_params) {
        GGML_ABORT("Not implemented yet!");
    }
    virtual std::string remove_trigger_from_prompt(ggml_context* work_ctx,
@@ -555,20 +555,14 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
    std::tuple<SDCondition, std::vector<bool>>
    get_learned_condition_with_trigger(ggml_context* work_ctx,
                                       int n_threads,
                                       const ConditionerParams& conditioner_params) {
        auto image_tokens = convert_token_to_id(trigger_word);
        // if(image_tokens.size() == 1){
        //     printf(" image token id is: %d \n", image_tokens[0]);
        // }
        GGML_ASSERT(image_tokens.size() == 1);
        auto tokens_and_weights = tokenize_with_trigger_token(conditioner_params.text,
                                                              conditioner_params.num_input_imgs,
                                                              image_tokens[0],
                                                              true);
        std::vector<int>& tokens = std::get<0>(tokens_and_weights);
@@ -582,7 +576,15 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
        // for(int i = 0; i < clsm.size(); ++i)
        //     printf("%d ", clsm[i]?1:0);
        // printf("\n");
        auto cond = get_learned_condition_common(work_ctx,
                                                 n_threads,
                                                 tokens,
                                                 weights,
                                                 conditioner_params.clip_skip,
                                                 conditioner_params.width,
                                                 conditioner_params.height,
                                                 conditioner_params.adm_in_channels,
                                                 conditioner_params.zero_out_masked);
        return std::make_tuple(cond, clsm);
    }
@@ -600,16 +602,19 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
    SDCondition get_learned_condition(ggml_context* work_ctx,
                                      int n_threads,
                                      const ConditionerParams& conditioner_params) {
        auto tokens_and_weights     = tokenize(conditioner_params.text, true);
        std::vector<int>& tokens    = tokens_and_weights.first;
        std::vector<float>& weights = tokens_and_weights.second;
        return get_learned_condition_common(work_ctx,
                                            n_threads,
                                            tokens,
                                            weights,
                                            conditioner_params.clip_skip,
                                            conditioner_params.width,
                                            conditioner_params.height,
                                            conditioner_params.adm_in_channels,
                                            conditioner_params.zero_out_masked);
    }
};
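All per-call conditioning options now travel in one `ConditionerParams` struct instead of a long positional list, so each embedder reads only the fields it cares about. A minimal sketch of a call site under that API (the surrounding setup and variable names are illustrative, not part of this commit):

```cpp
// Illustrative call site; work_ctx, n_threads and conditioner are assumed
// to come from the surrounding pipeline.
ConditionerParams cparams;
cparams.text       = "a photo of a cat";
cparams.clip_skip  = 2;
cparams.width      = 1024;
cparams.height     = 1024;
// ref_images is only consulted by the Qwen Image Edit path:
cparams.ref_images = {&input_image};

SDCondition cond = conditioner->get_learned_condition(work_ctx, n_threads, cparams);
```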
@@ -974,14 +979,13 @@ struct SD3CLIPEmbedder : public Conditioner {
    SDCondition get_learned_condition(ggml_context* work_ctx,
                                      int n_threads,
                                      const ConditionerParams& conditioner_params) {
        auto tokens_and_weights = tokenize(conditioner_params.text, 77, true);
        return get_learned_condition_common(work_ctx,
                                            n_threads,
                                            tokens_and_weights,
                                            conditioner_params.clip_skip,
                                            conditioner_params.zero_out_masked);
    }
};
@@ -1174,14 +1178,13 @@ struct FluxCLIPEmbedder : public Conditioner {
    SDCondition get_learned_condition(ggml_context* work_ctx,
                                      int n_threads,
                                      const ConditionerParams& conditioner_params) {
        auto tokens_and_weights = tokenize(conditioner_params.text, chunk_len, true);
        return get_learned_condition_common(work_ctx,
                                            n_threads,
                                            tokens_and_weights,
                                            conditioner_params.clip_skip,
                                            conditioner_params.zero_out_masked);
    }
};
@@ -1360,27 +1363,30 @@ struct T5CLIPEmbedder : public Conditioner {
    SDCondition get_learned_condition(ggml_context* work_ctx,
                                      int n_threads,
                                      const ConditionerParams& conditioner_params) {
        auto tokens_and_weights = tokenize(conditioner_params.text, chunk_len, true);
        return get_learned_condition_common(work_ctx,
                                            n_threads,
                                            tokens_and_weights,
                                            conditioner_params.clip_skip,
                                            conditioner_params.zero_out_masked);
    }
};

struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
    Qwen::Qwen2Tokenizer tokenizer;
    std::shared_ptr<Qwen::Qwen2_5_VLRunner> qwenvl;

    Qwen2_5_VLCLIPEmbedder(ggml_backend_t backend,
                           bool offload_params_to_cpu,
                           const String2GGMLType& tensor_types = {},
                           const std::string prefix = "",
                           bool enable_vision = false) {
        qwenvl = std::make_shared<Qwen::Qwen2_5_VLRunner>(backend,
                                                          offload_params_to_cpu,
                                                          tensor_types,
                                                          "text_encoders.qwen2vl",
                                                          enable_vision);
    }

    void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
@@ -1402,9 +1408,19 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
    }

    std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text,
                                                              size_t max_length = 0,
                                                              size_t system_prompt_length = 0,
                                                              bool padding = false) {
        std::vector<std::pair<std::string, float>> parsed_attention;
        if (system_prompt_length > 0) {
            parsed_attention.emplace_back(text.substr(0, system_prompt_length), 1.f);
            auto new_parsed_attention = parse_prompt_attention(text.substr(system_prompt_length, text.size() - system_prompt_length));
            parsed_attention.insert(parsed_attention.end(),
                                    new_parsed_attention.begin(),
                                    new_parsed_attention.end());
        } else {
            parsed_attention = parse_prompt_attention(text);
        }

        {
            std::stringstream ss;
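The new `system_prompt_length` argument keeps the chat-template system prompt out of prompt-attention parsing, so `(word:1.2)`-style weighting applies only to the user text. Roughly (weights shown are illustrative of A1111-style attention syntax):

```cpp
// tokenize(prompt, 0, system_prompt_length, false) conceptually yields:
//   system-template tokens -> weight 1.0 (copied through verbatim)
//   user tokens            -> weights from parse_prompt_attention(user_text)
// e.g. user text "a (red:1.3) car" weights "red" at 1.3 and the rest at 1.0
```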
@@ -1429,20 +1445,89 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
        tokenizer.pad_tokens(tokens, weights, max_length, padding);

        // for (int i = 0; i < tokens.size(); i++) {
        //     std::cout << tokens[i] << ":" << weights[i] << ", " << i << std::endl;
        // }
        // std::cout << std::endl;

        return {tokens, weights};
    }

    SDCondition get_learned_condition(ggml_context* work_ctx,
                                      int n_threads,
                                      const ConditionerParams& conditioner_params) {
        std::string prompt;
        std::vector<std::pair<int, ggml_tensor*>> image_embeds;
        size_t system_prompt_length          = 0;
        int prompt_template_encode_start_idx = 34;
if (qwenvl->enable_vision && conditioner_params.ref_images.size() > 0) {
LOG_INFO("QwenImageEditPlusPipeline");
prompt_template_encode_start_idx = 64;
int image_embed_idx = 64 + 6;
int min_pixels = 384 * 384;
int max_pixels = 560 * 560;
std::string placeholder = "<|image_pad|>";
std::string img_prompt;
for (int i = 0; i < conditioner_params.ref_images.size(); i++) {
sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]);
double factor = qwenvl->params.vision.patch_size * qwenvl->params.vision.spatial_merge_size;
int height = image.height;
int width = image.width;
int h_bar = static_cast<int>(std::round(height / factor)) * factor;
int w_bar = static_cast<int>(std::round(width / factor)) * factor;
if (static_cast<double>(h_bar) * w_bar > max_pixels) {
double beta = std::sqrt((height * width) / static_cast<double>(max_pixels));
h_bar = std::max(static_cast<int>(factor),
static_cast<int>(std::floor(height / beta / factor)) * static_cast<int>(factor));
w_bar = std::max(static_cast<int>(factor),
static_cast<int>(std::floor(width / beta / factor)) * static_cast<int>(factor));
} else if (static_cast<double>(h_bar) * w_bar < min_pixels) {
double beta = std::sqrt(static_cast<double>(min_pixels) / (height * width));
h_bar = static_cast<int>(std::ceil(height * beta / factor)) * static_cast<int>(factor);
w_bar = static_cast<int>(std::ceil(width * beta / factor)) * static_cast<int>(factor);
}
LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, image.height, image.width, h_bar, w_bar);
sd_image_f32_t resized_image = clip_preprocess(image, w_bar, h_bar);
free(image.data);
image.data = nullptr;
ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
sd_image_f32_to_tensor(resized_image, image_tensor, false);
free(resized_image.data);
resized_image.data = nullptr;
ggml_tensor* image_embed = nullptr;
qwenvl->encode_image(n_threads, image_tensor, &image_embed, work_ctx);
image_embeds.emplace_back(image_embed_idx, image_embed);
image_embed_idx += 1 + image_embed->ne[1] + 6;
img_prompt += "Picture " + std::to_string(i + 1) + ": <|vision_start|>"; // [24669, 220, index, 25, 220, 151652]
int64_t num_image_tokens = image_embed->ne[1];
img_prompt.reserve(num_image_tokens * placeholder.size());
for (int j = 0; j < num_image_tokens; j++) {
img_prompt += placeholder;
}
img_prompt += "<|vision_end|>";
}
prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n";
system_prompt_length = prompt.size();
prompt += img_prompt;
prompt += conditioner_params.text;
prompt += "<|im_end|>\n<|im_start|>assistant\n";
} else {
prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n" + conditioner_params.text + "<|im_end|>\n<|im_start|>assistant\n";
}
auto tokens_and_weights = tokenize(prompt, 0, system_prompt_length, false);
auto& tokens = std::get<0>(tokens_and_weights);
auto& weights = std::get<1>(tokens_and_weights);
        int64_t t0 = ggml_time_ms();
        struct ggml_tensor* hidden_states = NULL;  // [N, n_token, 3584]

@@ -1451,6 +1536,7 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
        qwenvl->compute(n_threads,
                        input_ids,
                        image_embeds,
                        &hidden_states,
                        work_ctx);
        {

@@ -1486,19 +1572,6 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
        LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
        return SDCondition(new_hidden_states, nullptr, nullptr);
    }

Removed (superseded by the ConditionerParams-based get_learned_condition above, which now builds the chat-template prompt itself):
SDCondition get_learned_condition(ggml_context* work_ctx,
int n_threads,
const std::string& text,
int clip_skip,
int width,
int height,
int adm_in_channels = -1,
bool zero_out_masked = false) {
std::string prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n" + text + "<|im_end|>\n<|im_start|>assistant\n";
auto tokens_and_weights = tokenize(prompt, 0, false);
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
}
};

#endif
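The reference-image preprocessing in get_learned_condition snaps each image to a multiple of `patch_size * spatial_merge_size` (14 * 2 = 28) and clamps its area to the [384*384, 560*560] pixel budget. A standalone sketch of the same rounding rule (the function name is made up for illustration; the arithmetic mirrors the diff):

```cpp
#include <algorithm>
#include <cmath>

// Snap (width, height) to multiples of `factor` while keeping the area
// inside [min_pixels, max_pixels]; mirrors the resize logic in
// Qwen2_5_VLCLIPEmbedder::get_learned_condition above.
static void fit_to_patch_grid(int width, int height, int factor,
                              int min_pixels, int max_pixels,
                              int& w_bar, int& h_bar) {
    h_bar = static_cast<int>(std::round(height / (double)factor)) * factor;
    w_bar = static_cast<int>(std::round(width / (double)factor)) * factor;
    if ((double)h_bar * w_bar > max_pixels) {
        double beta = std::sqrt((double)height * width / max_pixels);
        h_bar = std::max(factor, (int)std::floor(height / beta / factor) * factor);
        w_bar = std::max(factor, (int)std::floor(width / beta / factor) * factor);
    } else if ((double)h_bar * w_bar < min_pixels) {
        double beta = std::sqrt((double)min_pixels / ((double)height * width));
        h_bar = (int)std::ceil(height * beta / factor) * factor;
        w_bar = (int)std::ceil(width * beta / factor) * factor;
    }
}

// Example: a 1024x1024 reference first snaps to 1036x1036, which exceeds
// 560*560 = 313600 px, so beta = 1024/560 and the image lands on 560x560.
```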


@@ -313,6 +313,8 @@ struct QwenImageModel : public DiffusionModel {
                                     diffusion_params.x,
                                     diffusion_params.timesteps,
                                     diffusion_params.context,
                                     diffusion_params.ref_latents,
                                     true,  // increase_ref_index
                                     output,
                                     output_ctx);
    }

docs/qwen_image_edit.md (new file, 35 lines):

@@ -0,0 +1,35 @@
# How to Use

## Download weights

- Download Qwen Image Edit
  - Qwen Image Edit
    - safetensors: https://huggingface.co/Comfy-Org/Qwen-Image-Edit_ComfyUI/tree/main/split_files/diffusion_models
    - gguf: https://huggingface.co/QuantStack/Qwen-Image-Edit-GGUF/tree/main
  - Qwen Image Edit 2509
    - safetensors: https://huggingface.co/Comfy-Org/Qwen-Image-Edit_ComfyUI/tree/main/split_files/diffusion_models
    - gguf: https://huggingface.co/QuantStack/Qwen-Image-Edit-2509-GGUF/tree/main
- Download vae
  - safetensors: https://huggingface.co/Comfy-Org/Qwen-Image_ComfyUI/tree/main/split_files/vae
- Download qwen_2.5_vl 7b
  - safetensors: https://huggingface.co/Comfy-Org/Qwen-Image_ComfyUI/tree/main/split_files/text_encoders
  - gguf: https://huggingface.co/mradermacher/Qwen2.5-VL-7B-Instruct-GGUF/tree/main

## Examples

### Qwen Image Edit
```
.\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\Qwen_Image_Edit-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --qwen2vl ..\..\ComfyUI\models\text_encoders\qwen_2.5_vl_7b.safetensors --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --diffusion-fa --flow-shift 3 -r ..\assets\flux\flux1-dev-q8_0.png -p "change 'flux.cpp' to 'edit.cpp'" --seed 1118877715456453
```
<img alt="qwen_image_edit" src="../assets/qwen/qwen_image_edit.png" />
### Qwen Image Edit 2509
```
.\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\Qwen-Image-Edit-2509-Q4_K_S.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --qwen2vl ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf --qwen2vl_vision ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct.mmproj-Q8_0.gguf --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --diffusion-fa --flow-shift 3 -r ..\assets\flux\flux1-dev-q8_0.png -p "change 'flux.cpp' to 'Qwen Image Edit 2509'"
```
<img alt="qwen_image_edit_2509" src="../assets/qwen/qwen_image_edit_2509.png" />
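The 2509 example above is the one that exercises the new `--qwen2vl_vision` flag: the text encoder comes from a llama.cpp-style GGUF and the ViT comes from the matching mmproj file. The same invocation with POSIX paths (model locations are placeholders):

```
./sd --diffusion-model models/Qwen-Image-Edit-2509-Q4_K_S.gguf \
  --vae models/qwen_image_vae.safetensors \
  --qwen2vl models/Qwen2.5-VL-7B-Instruct-Q8_0.gguf \
  --qwen2vl_vision models/Qwen2.5-VL-7B-Instruct.mmproj-Q8_0.gguf \
  --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --diffusion-fa \
  --flow-shift 3 -r input.png -p "change 'flux.cpp' to 'Qwen Image Edit 2509'"
```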


@@ -62,6 +62,7 @@ struct SDParams {
    std::string clip_vision_path;
    std::string t5xxl_path;
    std::string qwen2vl_path;
    std::string qwen2vl_vision_path;
    std::string diffusion_model_path;
    std::string high_noise_diffusion_model_path;
    std::string vae_path;

@@ -148,6 +149,7 @@ void print_params(SDParams params) {
    printf("    clip_vision_path: %s\n", params.clip_vision_path.c_str());
    printf("    t5xxl_path: %s\n", params.t5xxl_path.c_str());
    printf("    qwen2vl_path: %s\n", params.qwen2vl_path.c_str());
    printf("    qwen2vl_vision_path: %s\n", params.qwen2vl_vision_path.c_str());
    printf("    diffusion_model_path: %s\n", params.diffusion_model_path.c_str());
    printf("    high_noise_diffusion_model_path: %s\n", params.high_noise_diffusion_model_path.c_str());
    printf("    vae_path: %s\n", params.vae_path.c_str());

@@ -220,6 +222,7 @@ void print_usage(int argc, const char* argv[]) {
    printf("  --clip_vision                      path to the clip-vision encoder\n");
    printf("  --t5xxl                            path to the t5xxl text encoder\n");
    printf("  --qwen2vl                          path to the qwen2vl text encoder\n");
    printf("  --qwen2vl_vision                   path to the qwen2vl vit\n");
    printf("  --vae [VAE]                        path to vae\n");
    printf("  --taesd [TAESD_PATH]               path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
    printf("  --control-net [CONTROL_PATH]       path to control net model\n");

@@ -490,6 +493,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
    {"", "--clip_vision", "", &params.clip_vision_path},
    {"", "--t5xxl", "", &params.t5xxl_path},
    {"", "--qwen2vl", "", &params.qwen2vl_path},
    {"", "--qwen2vl_vision", "", &params.qwen2vl_vision_path},
    {"", "--diffusion-model", "", &params.diffusion_model_path},
    {"", "--high-noise-diffusion-model", "", &params.high_noise_diffusion_model_path},
    {"", "--vae", "", &params.vae_path},

@@ -952,7 +956,7 @@ std::string get_image_params(SDParams params, int64_t seed) {
        parameter_string += " " + std::string(sd_schedule_name(params.sample_params.scheduler));
    }
    parameter_string += ", ";
    for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path, params.qwen2vl_vision_path}) {
        if (!te.empty()) {
            parameter_string += "TE: " + sd_basename(te) + ", ";
        }
    }

@@ -1336,6 +1340,7 @@ int main(int argc, const char* argv[]) {
        params.clip_vision_path.c_str(),
        params.t5xxl_path.c_str(),
        params.qwen2vl_path.c_str(),
        params.qwen2vl_vision_path.c_str(),
        params.diffusion_model_path.c_str(),
        params.high_noise_diffusion_model_path.c_str(),
        params.vae_path.c_str(),


@@ -81,57 +81,6 @@ namespace Flux {
        }
    };

Removed (the RoPE helpers below moved into rope.hpp; call sites now use Rope::attention):
__STATIC_INLINE__ struct ggml_tensor* apply_rope(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* pe) {
// x: [N, L, n_head, d_head]
// pe: [L, d_head/2, 2, 2]
int64_t d_head = x->ne[0];
int64_t n_head = x->ne[1];
int64_t L = x->ne[2];
int64_t N = x->ne[3];
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, n_head, L, d_head]
x = ggml_reshape_4d(ctx, x, 2, d_head / 2, L, n_head * N); // [N * n_head, L, d_head/2, 2]
x = ggml_cont(ctx, ggml_permute(ctx, x, 3, 0, 1, 2)); // [2, N * n_head, L, d_head/2]
int64_t offset = x->nb[2] * x->ne[2];
auto x_0 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 0); // [N * n_head, L, d_head/2]
auto x_1 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 1); // [N * n_head, L, d_head/2]
x_0 = ggml_reshape_4d(ctx, x_0, 1, x_0->ne[0], x_0->ne[1], x_0->ne[2]); // [N * n_head, L, d_head/2, 1]
x_1 = ggml_reshape_4d(ctx, x_1, 1, x_1->ne[0], x_1->ne[1], x_1->ne[2]); // [N * n_head, L, d_head/2, 1]
auto temp_x = ggml_new_tensor_4d(ctx, x_0->type, 2, x_0->ne[1], x_0->ne[2], x_0->ne[3]);
x_0 = ggml_repeat(ctx, x_0, temp_x); // [N * n_head, L, d_head/2, 2]
x_1 = ggml_repeat(ctx, x_1, temp_x); // [N * n_head, L, d_head/2, 2]
pe = ggml_cont(ctx, ggml_permute(ctx, pe, 3, 0, 1, 2)); // [2, L, d_head/2, 2]
offset = pe->nb[2] * pe->ne[2];
auto pe_0 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 0); // [L, d_head/2, 2]
auto pe_1 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 1); // [L, d_head/2, 2]
auto x_out = ggml_add_inplace(ctx, ggml_mul(ctx, x_0, pe_0), ggml_mul(ctx, x_1, pe_1)); // [N * n_head, L, d_head/2, 2]
x_out = ggml_reshape_3d(ctx, x_out, d_head, L, n_head * N); // [N*n_head, L, d_head]
return x_out;
}
__STATIC_INLINE__ struct ggml_tensor* attention(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* q,
struct ggml_tensor* k,
struct ggml_tensor* v,
struct ggml_tensor* pe,
struct ggml_tensor* mask,
bool flash_attn,
float kv_scale = 1.0f) {
// q,k,v: [N, L, n_head, d_head]
// pe: [L, d_head/2, 2, 2]
// return: [N, L, n_head*d_head]
q = apply_rope(ctx, q, pe); // [N*n_head, L, d_head]
k = apply_rope(ctx, k, pe); // [N*n_head, L, d_head]
auto x = ggml_nn_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn, kv_scale); // [N, L, n_head*d_head]
return x;
}
    struct SelfAttention : public GGMLBlock {
    public:
        int64_t num_heads;

@@ -179,9 +128,9 @@ namespace Flux {
            // x: [N, n_token, dim]
            // pe: [n_token, d_head/2, 2, 2]
            // return [N, n_token, dim]
            auto qkv = pre_attention(ctx, x);                                                        // q,k,v: [N, n_token, n_head, d_head]
            x        = Rope::attention(ctx, backend, qkv[0], qkv[1], qkv[2], pe, mask, flash_attn);  // [N, n_token, dim]
            x        = post_attention(ctx, x);                                                       // [N, n_token, dim]
            return x;
        }
    };

@@ -369,8 +318,8 @@ namespace Flux {
            auto k = ggml_concat(ctx, txt_k, img_k, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
            auto v = ggml_concat(ctx, txt_v, img_v, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]

            auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn);  // [N, n_txt_token + n_img_token, n_head*d_head]
            attn      = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));           // [n_txt_token + n_img_token, N, hidden_size]

            auto txt_attn_out = ggml_view_3d(ctx,
                                             attn,
                                             attn->ne[0],

@@ -504,7 +453,7 @@ namespace Flux {
            auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]);  // [N, n_token, n_head, d_head]
            q      = norm->query_norm(ctx, q);
            k      = norm->key_norm(ctx, k);

            auto attn     = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn);  // [N, n_token, hidden_size]
            auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0);        // [N, n_token, hidden_size + mlp_hidden_dim]
            auto output   = linear2->forward(ctx, attn_mlp);                               // [N, n_token, hidden_size]


@@ -197,8 +197,11 @@ __STATIC_INLINE__ float sd_image_get_f32(sd_image_t image, int iw, int ih, int ic) {
    return value;
}

__STATIC_INLINE__ float sd_image_get_f32(sd_image_f32_t image, int iw, int ih, int ic, bool scale = true) {
    float value = *(image.data + ih * image.width * image.channel + iw * image.channel + ic);
    if (scale) {
        value /= 255.f;
    }
    return value;
}
@@ -458,24 +461,18 @@ __STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
    }
}

__STATIC_INLINE__ void sd_image_f32_to_tensor(sd_image_f32_t image,
                                              ggml_tensor* tensor,
                                              bool scale = true) {
    GGML_ASSERT(image.width == tensor->ne[0]);
    GGML_ASSERT(image.height == tensor->ne[1]);
    GGML_ASSERT(image.channel == tensor->ne[2]);
    GGML_ASSERT(1 == tensor->ne[3]);
    GGML_ASSERT(tensor->type == GGML_TYPE_F32);
    ggml_tensor_iter(tensor, [&](ggml_tensor* tensor, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
        float value = sd_image_get_f32(image, i0, i1, i2, scale);
        ggml_tensor_set_f32(tensor, value, i0, i1, i2, i3);
    });
}

__STATIC_INLINE__ void ggml_split_tensor_2d(struct ggml_tensor* input,


@@ -113,7 +113,6 @@ const char* unused_tensors[] = {
    "text_encoders.t5xxl.transformer.encoder.embed_tokens.weight",  // only used during training
    "text_encoders.qwen2vl.output.weight",
    "text_encoders.qwen2vl.lm_head.",
};

(The "text_encoders.qwen2vl.visual." entry is removed from unused_tensors; the qwen2vl ViT weights are now loaded.)

bool is_unused_tensor(std::string name) {

@@ -212,6 +211,24 @@ std::unordered_map<std::string, std::string> qwenvl_name_map{
    {"output_norm.", "model.norm."},
};
std::unordered_map<std::string, std::string> qwenvl_vision_name_map{
{"mm.", "merger.mlp."},
{"v.post_ln.", "merger.ln_q."},
{"v.patch_embd.weight", "patch_embed.proj.0.weight"},
{"patch_embed.proj.0.weight.1", "patch_embed.proj.1.weight"},
{"v.patch_embd.weight.1", "patch_embed.proj.1.weight"},
{"v.blk.", "blocks."},
{"attn_q.", "attn.q_proj."},
{"attn_k.", "attn.k_proj."},
{"attn_v.", "attn.v_proj."},
{"attn_out.", "attn.proj."},
{"ffn_down.", "mlp.down_proj."},
{"ffn_gate.", "mlp.gate_proj."},
{"ffn_up.", "mlp.up_proj."},
{"ln1.", "norm1."},
{"ln2.", "norm2."},
};
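The new map rewrites llama.cpp mmproj tensor names into the Transformers-style names the vision blocks register. A few worked examples, derived by applying the substitutions above (not an exhaustive list):

```cpp
// "v.blk.0.attn_q.weight" -> "blocks.0.attn.q_proj.weight"
// "mm.0.weight"           -> "merger.mlp.0.weight"
// "v.post_ln.weight"      -> "merger.ln_q.weight"
// "v.patch_embd.weight.1" -> "patch_embed.proj.1.weight"
//   (the extra "patch_embed.proj.0.weight.1" entry catches the case where
//    "v.patch_embd.weight" was already rewritten before the ".1" suffix)
```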
std::string convert_cond_model_name(const std::string& name) {
    std::string new_name = name;
    std::string prefix;

@@ -270,10 +287,19 @@ std::string convert_cond_model_name(const std::string& name) {
            new_name.replace(pos, 11, "layer.0.SelfAttention.relative_attention_bias.");
        }
    } else if (contains(name, "qwen2vl")) {
        if (contains(name, "qwen2vl.visual")) {
            for (auto kv : qwenvl_vision_name_map) {
                size_t pos = new_name.find(kv.first);
                if (pos != std::string::npos) {
                    new_name.replace(pos, kv.first.size(), kv.second);
                }
            }
        } else {
            for (auto kv : qwenvl_name_map) {
                size_t pos = new_name.find(kv.first);
                if (pos != std::string::npos) {
                    new_name.replace(pos, kv.first.size(), kv.second);
                }
            }
        }
    } else if (name == "text_encoders.t5xxl.transformer.token_embd.weight") {


@@ -94,12 +94,12 @@ namespace Qwen {
            blocks["norm_added_q"] = std::shared_ptr<GGMLBlock>(new RMSNorm(dim_head, eps));
            blocks["norm_added_k"] = std::shared_ptr<GGMLBlock>(new RMSNorm(dim_head, eps));

            float scale = 1.f / 32.f;
            // The purpose of the scale here is to prevent NaN issues in certain situations.
            // For example when using CUDA but the weights are k-quants (not all prompts).
            blocks["to_out.0"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, out_dim, out_bias, false, false, scale));
            // to_out.1 is nn.Dropout
            blocks["to_add_out"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, out_context_dim, out_bias, false, false, scale));
        }
@@ -159,7 +159,7 @@ namespace Qwen {
            auto k = ggml_concat(ctx, txt_k, img_k, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
            auto v = ggml_concat(ctx, txt_v, img_v, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]

            auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn, (1.0f / 128.f));  // [N, n_txt_token + n_img_token, n_head*d_head]
            attn      = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));                           // [n_txt_token + n_img_token, N, hidden_size]

            auto txt_attn_out = ggml_view_3d(ctx,
                                             attn,
@@ -389,6 +389,13 @@ namespace Qwen {
            return x;
        }

        struct ggml_tensor* process_img(struct ggml_context* ctx,
                                        struct ggml_tensor* x) {
            x = pad_to_patch_size(ctx, x);
            x = patchify(ctx, x);
            return x;
        }

        struct ggml_tensor* unpatchify(struct ggml_context* ctx,
                                       struct ggml_tensor* x,
                                       int64_t h,
@@ -449,7 +456,8 @@ namespace Qwen {
                                    struct ggml_tensor* x,
                                    struct ggml_tensor* timestep,
                                    struct ggml_tensor* context,
                                    struct ggml_tensor* pe,
                                    std::vector<ggml_tensor*> ref_latents = {}) {
            // Forward pass of DiT.
            // x: [N, C, H, W]
            // timestep: [N,]

@@ -462,13 +470,26 @@ namespace Qwen {
            int64_t C = x->ne[2];
            int64_t N = x->ne[3];

            auto img            = process_img(ctx, x);
            uint64_t img_tokens = img->ne[1];

            if (ref_latents.size() > 0) {
                for (ggml_tensor* ref : ref_latents) {
                    ref = process_img(ctx, ref);
                    img = ggml_concat(ctx, img, ref, 1);
                }
            }

            int64_t h_len = ((H + (params.patch_size / 2)) / params.patch_size);
            int64_t w_len = ((W + (params.patch_size / 2)) / params.patch_size);

            auto out = forward_orig(ctx, backend, img, timestep, context, pe);  // [N, h_len*w_len, ph*pw*C]
            if (out->ne[1] > img_tokens) {
                out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3));  // [num_tokens, N, C * patch_size * patch_size]
                out = ggml_view_3d(ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
                out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3));  // [N, h*w, C * patch_size * patch_size]
            }

            out = unpatchify(ctx, out, h_len, w_len);  // [N, C, H + pad_h, W + pad_w]
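Reference latents ride through the DiT as extra image tokens: each one is padded, patchified, and appended after the noisy-latent tokens, and after `forward_orig` only the first `img_tokens` outputs are kept for unpatchify. Token-dimension layout, schematically:

```cpp
// in : [ x tokens | ref_0 tokens | ref_1 tokens | ... ]  (ggml_concat on dim 1)
// out: [ x tokens ]  (ggml_view_3d trims back to img_tokens before unpatchify)
```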
@@ -495,6 +516,25 @@ namespace Qwen {
                       bool flash_attn = false)
        : GGMLRunner(backend, offload_params_to_cpu) {
        qwen_image_params.flash_attn = flash_attn;
qwen_image_params.num_layers = 0;
for (auto pair : tensor_types) {
std::string tensor_name = pair.first;
if (tensor_name.find(prefix) == std::string::npos)
continue;
size_t pos = tensor_name.find("transformer_blocks.");
if (pos != std::string::npos) {
tensor_name = tensor_name.substr(pos); // remove prefix
auto items = split_string(tensor_name, '.');
if (items.size() > 1) {
int block_index = atoi(items[1].c_str());
if (block_index + 1 > qwen_image_params.num_layers) {
qwen_image_params.num_layers = block_index + 1;
}
}
continue;
}
}
LOG_DEBUG("qwen_image_params.num_layers: %" PRId64, qwen_image_params.num_layers);
        qwen_image = QwenImageModel(qwen_image_params);
        qwen_image.init(params_ctx, tensor_types, prefix);
    }
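Layer count is no longer hard-coded: the constructor scans the checkpoint's tensor names for `transformer_blocks.<i>.` and uses the highest index plus one, so pruned or distilled Qwen Image variants load without a config change. A self-contained sketch of the same rule (the helper name is illustrative):

```cpp
#include <algorithm>
#include <cstdlib>
#include <string>
#include <vector>

// num_layers = max block index + 1 over names like "...transformer_blocks.<i>.<rest>"
static int infer_num_layers(const std::vector<std::string>& tensor_names) {
    int num_layers = 0;
    for (const auto& name : tensor_names) {
        size_t pos = name.find("transformer_blocks.");
        if (pos == std::string::npos)
            continue;
        int idx    = std::atoi(name.c_str() + pos + sizeof("transformer_blocks.") - 1);
        num_layers = std::max(num_layers, idx + 1);
    }
    return num_layers;
}
// infer_num_layers({"transformer_blocks.59.attn.to_q.weight"}) == 60
```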
@@ -509,7 +549,9 @@ namespace Qwen {
        struct ggml_cgraph* build_graph(struct ggml_tensor* x,
                                        struct ggml_tensor* timesteps,
                                        struct ggml_tensor* context,
                                        std::vector<ggml_tensor*> ref_latents = {},
                                        bool increase_ref_index = false) {
            GGML_ASSERT(x->ne[3] == 1);
            struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, QWEN_IMAGE_GRAPH_SIZE, false);

@@ -517,18 +559,24 @@ namespace Qwen {
            context   = to_backend(context);
            timesteps = to_backend(timesteps);

            for (int i = 0; i < ref_latents.size(); i++) {
                ref_latents[i] = to_backend(ref_latents[i]);
            }

            pe_vec = Rope::gen_qwen_image_pe(x->ne[1],
                                             x->ne[0],
                                             qwen_image_params.patch_size,
                                             x->ne[3],
                                             context->ne[1],
                                             ref_latents,
                                             increase_ref_index,
                                             qwen_image_params.theta,
                                             qwen_image_params.axes_dim);
            int pos_len = pe_vec.size() / qwen_image_params.axes_dim_sum / 2;
            // LOG_DEBUG("pos_len %d", pos_len);
            auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, qwen_image_params.axes_dim_sum / 2, pos_len);
            // pe->data = pe_vec.data();
            // print_ggml_tensor(pe, true, "pe");
            // pe->data = NULL;
            set_backend_tensor_data(pe, pe_vec.data());

@@ -537,7 +585,8 @@ namespace Qwen {
                                                 x,
                                                 timesteps,
                                                 context,
                                                 pe,
                                                 ref_latents);

            ggml_build_forward_expand(gf, out);

@@ -548,13 +597,15 @@ namespace Qwen {
                     struct ggml_tensor* x,
                     struct ggml_tensor* timesteps,
                     struct ggml_tensor* context,
                     std::vector<ggml_tensor*> ref_latents = {},
                     bool increase_ref_index = false,
                     struct ggml_tensor** output = NULL,
                     struct ggml_context* output_ctx = NULL) {
            // x: [N, in_channels, h, w]
            // timesteps: [N, ]
            // context: [N, max_position, hidden_size]
            auto get_graph = [&]() -> struct ggml_cgraph* {
                return build_graph(x, timesteps, context, ref_latents, increase_ref_index);
            };

            GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);

@@ -586,7 +637,7 @@ namespace Qwen {
            struct ggml_tensor* out = NULL;

            int t0 = ggml_time_ms();
            compute(8, x, timesteps, context, {}, false, &out, work_ctx);
            int t1 = ggml_time_ms();

            print_ggml_tensor(out);
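A sketch of how a caller threads reference latents through the new `compute` signature (VAE encoding of the reference images is elided; variable names are illustrative):

```cpp
// ref latents come from encoding each reference image with the VAE (elided)
std::vector<ggml_tensor*> ref_latents = {ref_latent_0, ref_latent_1};
ggml_tensor* out = nullptr;
qwen_image.compute(n_threads,
                   x,           // noisy latent [N, C, H, W]
                   timesteps,   // [N,]
                   context,     // qwen2vl text embedding
                   ref_latents,
                   true,        // increase_ref_index: separate pe per reference
                   &out,
                   work_ctx);
```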


@@ -15,9 +15,11 @@
#include "clip.hpp"
#include "ggml_extend.hpp"
#include "json.hpp"
#include "rope.hpp"
#include "tokenize_util.h"

namespace Qwen {

    constexpr int QWENVL_GRAPH_SIZE = 10240;

    class Qwen2Tokenizer {
    private:

@@ -340,9 +342,9 @@ namespace Qwen {
    struct Qwen2_5_VLMLP : public GGMLBlock {
    public:
        Qwen2_5_VLMLP(int64_t hidden_size, int64_t intermediate_size, bool bias = false) {
            blocks["gate_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, intermediate_size, bias));
            blocks["up_proj"]   = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, intermediate_size, bias));
            blocks["down_proj"] = std::shared_ptr<GGMLBlock>(new Linear(intermediate_size, hidden_size, bias));
        }

        struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {

@@ -359,6 +361,288 @@ namespace Qwen {
        }
    };
struct Qwen2_5_VisionPatchEmbed : public GGMLBlock {
protected:
bool llama_cpp_style;
int patch_size;
int temporal_patch_size;
int64_t in_channels;
int64_t embed_dim;
public:
Qwen2_5_VisionPatchEmbed(bool llama_cpp_style,
int patch_size = 14,
int temporal_patch_size = 2,
int64_t in_channels = 3,
int64_t embed_dim = 1152)
: llama_cpp_style(llama_cpp_style),
patch_size(patch_size),
temporal_patch_size(temporal_patch_size),
in_channels(in_channels),
embed_dim(embed_dim) {
if (llama_cpp_style) {
blocks["proj.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels,
embed_dim,
{patch_size, patch_size},
{patch_size, patch_size}, // stride
{0, 0}, // padding
{1, 1}, // dilation
false));
blocks["proj.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels,
embed_dim,
{patch_size, patch_size},
{patch_size, patch_size}, // stride
{0, 0}, // padding
{1, 1}, // dilation
false));
} else {
std::tuple<int, int, int> kernel_size = {(int)temporal_patch_size, (int)patch_size, (int)patch_size};
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Conv3d(in_channels,
embed_dim,
kernel_size,
kernel_size, // stride
{0, 0, 0}, // padding
{1, 1, 1}, // dilation
false));
}
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
// x: [N*grid_t*grid_h*grid_w, in_channels, temporal_patch_size*patch_size*patch_size]
// return: [N*grid_t*grid_h*grid_w, embed_dim]
x = ggml_reshape_4d(ctx,
x,
patch_size,
patch_size,
temporal_patch_size,
ggml_nelements(x) / (temporal_patch_size * patch_size * patch_size));
if (llama_cpp_style) {
auto proj_0 = std::dynamic_pointer_cast<Conv2d>(blocks["proj.0"]);
auto proj_1 = std::dynamic_pointer_cast<Conv2d>(blocks["proj.1"]);
auto x0 = ggml_slice(ctx, x, 2, 0, 1);
x0 = ggml_reshape_4d(ctx, x0, x0->ne[0], x0->ne[1], in_channels, x0->ne[3] / in_channels);
x0 = proj_0->forward(ctx, x0);
auto x1 = ggml_slice(ctx, x, 2, 1, 2);
x1 = ggml_reshape_4d(ctx, x1, x1->ne[0], x1->ne[1], in_channels, x1->ne[3] / in_channels);
x1 = proj_1->forward(ctx, x1);
x = ggml_add(ctx, x0, x1);
} else {
auto proj = std::dynamic_pointer_cast<Conv3d>(blocks["proj"]);
x = proj->forward(ctx, x);
}
x = ggml_reshape_2d(ctx, x, embed_dim, ggml_nelements(x) / embed_dim);
return x;
}
};
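When the weights come from a llama.cpp mmproj file, the temporal Conv3d patch embedding is stored as two Conv2d slabs, so the forward pass convolves each temporal slice separately and sums the results; for a still image both slices see the same frame. In effect:

```cpp
// Conv3d(kernel = {t=2, p, p})([f0, f1]) == proj.0(f0) + proj.1(f1)
// where proj.k is a Conv2d holding the k-th temporal slice of the 3D kernel.
```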
struct Qwen2_5_VLPatchMerger : public GGMLBlock {
protected:
int64_t hidden_size;
public:
Qwen2_5_VLPatchMerger(int64_t dim,
int64_t context_dim,
int64_t spatial_merge_size) {
hidden_size = context_dim * spatial_merge_size * spatial_merge_size;
blocks["ln_q"] = std::shared_ptr<GGMLBlock>(new RMSNorm(context_dim, 1e-6f));
blocks["mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size));
// mlp.1 is nn.GELU()
blocks["mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, dim));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
auto ln_q = std::dynamic_pointer_cast<RMSNorm>(blocks["ln_q"]);
auto mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["mlp.0"]);
auto mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["mlp.2"]);
x = ln_q->forward(ctx, x);
x = ggml_reshape_2d(ctx, x, hidden_size, ggml_nelements(x) / hidden_size);
x = mlp_0->forward(ctx, x);
x = ggml_gelu(ctx, x);
x = mlp_2->forward(ctx, x);
return x;
}
};
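With the default `context_dim = 1280` and `spatial_merge_size = 2`, the merger concatenates every 2x2 group of patch features into a 5120-dim vector and projects it to `dim = 3584`, so four ViT patches become one token on the language-model side. Shape walk-through:

```cpp
// in             : [n_patches, 1280]
// ln_q + reshape : [n_patches / 4, 5120]   // 4 = spatial_merge_size^2
// mlp.0 + GELU   : [n_patches / 4, 5120]
// mlp.2          : [n_patches / 4, 3584]   // image tokens fed to the LM
```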
struct Qwen2_5_VLVisionAttention : public GGMLBlock {
protected:
bool llama_cpp_style;
int64_t head_dim;
int64_t num_heads;
public:
Qwen2_5_VLVisionAttention(bool llama_cpp_style,
int64_t hidden_size,
int64_t num_heads)
: llama_cpp_style(llama_cpp_style), num_heads(num_heads) {
head_dim = hidden_size / num_heads;
GGML_ASSERT(num_heads * head_dim == hidden_size);
if (llama_cpp_style) {
blocks["q_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size));
blocks["k_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size));
blocks["v_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size));
} else {
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size * 3));
}
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* pe,
struct ggml_tensor* mask = nullptr) {
// x: [N, n_token, hidden_size]
int64_t n_token = x->ne[1];
int64_t N = x->ne[2];
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
std::vector<ggml_tensor*> qkv_vec;
if (llama_cpp_style) {
auto q_proj = std::dynamic_pointer_cast<Linear>(blocks["q_proj"]);
auto k_proj = std::dynamic_pointer_cast<Linear>(blocks["k_proj"]);
auto v_proj = std::dynamic_pointer_cast<Linear>(blocks["v_proj"]);
auto q = q_proj->forward(ctx, x);
auto k = k_proj->forward(ctx, x);
auto v = v_proj->forward(ctx, x);
qkv_vec = {q, k, v};
} else {
auto qkv_proj = std::dynamic_pointer_cast<Linear>(blocks["qkv"]);
auto qkv = qkv_proj->forward(ctx, x);
qkv_vec = split_qkv(ctx, qkv);
}
auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head]
auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head]
auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head]
x = Rope::attention(ctx, backend, q, k, v, pe, mask, false, 1.f, false); // [N, n_token, hidden_size]
x = proj->forward(ctx, x); // [N, n_token, hidden_size]
return x;
}
};
struct Qwen2_5_VLVisionBlock : public GGMLBlock {
public:
Qwen2_5_VLVisionBlock(bool llama_cpp_style,
int64_t hidden_size,
int64_t intermediate_size,
int64_t num_heads,
float eps = 1e-6f) {
blocks["attn"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLVisionAttention(llama_cpp_style, hidden_size, num_heads));
blocks["mlp"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLMLP(hidden_size, intermediate_size, true));
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new RMSNorm(hidden_size, eps));
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new RMSNorm(hidden_size, eps));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* pe,
struct ggml_tensor* mask = nullptr) {
// x: [N, n_token, hidden_size]
auto attn = std::dynamic_pointer_cast<Qwen2_5_VLVisionAttention>(blocks["attn"]);
auto mlp = std::dynamic_pointer_cast<Qwen2_5_VLMLP>(blocks["mlp"]);
auto norm1 = std::dynamic_pointer_cast<RMSNorm>(blocks["norm1"]);
auto norm2 = std::dynamic_pointer_cast<RMSNorm>(blocks["norm2"]);
auto residual = x;
x = norm1->forward(ctx, x);
x = attn->forward(ctx, backend, x, pe, mask);
x = ggml_add_inplace(ctx, x, residual);
residual = x;
x = norm2->forward(ctx, x);
x = mlp->forward(ctx, x);
x = ggml_add_inplace(ctx, x, residual);
return x;
}
};
struct Qwen2_5_VLVisionModel : public GGMLBlock {
protected:
int64_t num_layers;
int64_t spatial_merge_size;
std::set<int> fullatt_block_indexes;
public:
Qwen2_5_VLVisionModel(bool llama_cpp_style,
int64_t num_layers,
int64_t in_channels,
int64_t hidden_size,
int64_t out_hidden_size,
int64_t intermediate_size,
int64_t num_heads,
int64_t spatial_merge_size,
int64_t patch_size,
int64_t temporal_patch_size,
int64_t window_size,
std::set<int> fullatt_block_indexes = {7, 15, 23, 31},
float eps = 1e-6f)
: num_layers(num_layers), fullatt_block_indexes(fullatt_block_indexes), spatial_merge_size(spatial_merge_size) {
blocks["patch_embed"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VisionPatchEmbed(llama_cpp_style,
patch_size,
temporal_patch_size,
in_channels,
hidden_size));
for (int i = 0; i < num_layers; i++) {
blocks["blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLVisionBlock(llama_cpp_style,
hidden_size,
intermediate_size,
num_heads,
eps));
}
blocks["merger"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLPatchMerger(out_hidden_size, hidden_size, spatial_merge_size));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* pixel_values,
struct ggml_tensor* pe,
struct ggml_tensor* window_index,
struct ggml_tensor* window_inverse_index,
struct ggml_tensor* window_mask) {
// pixel_values: [grid_t*(H/mh/ph)*(W/mw/pw)*mh*mw, C*pt*ph*pw]
// window_index: [grid_t*(H/mh/ph)*(W/mw/pw)]
// window_inverse_index: [grid_t*(H/mh/ph)*(W/mw/pw)]
// window_mask: [grid_h*grid_w, grid_h*grid_w]
auto patch_embed = std::dynamic_pointer_cast<Qwen2_5_VisionPatchEmbed>(blocks["patch_embed"]);
auto merger = std::dynamic_pointer_cast<Qwen2_5_VLPatchMerger>(blocks["merger"]);
auto x = patch_embed->forward(ctx, pixel_values);
x = ggml_reshape_4d(ctx, x, x->ne[0] * spatial_merge_size * spatial_merge_size, x->ne[1] / spatial_merge_size / spatial_merge_size, x->ne[2], x->ne[3]);
x = ggml_get_rows(ctx, x, window_index);
x = ggml_reshape_4d(ctx, x, x->ne[0] / spatial_merge_size / spatial_merge_size, x->ne[1] * spatial_merge_size * spatial_merge_size, x->ne[2], x->ne[3]);
for (int i = 0; i < num_layers; i++) {
auto block = std::dynamic_pointer_cast<Qwen2_5_VLVisionBlock>(blocks["blocks." + std::to_string(i)]);
auto mask = window_mask;
if (fullatt_block_indexes.find(i) != fullatt_block_indexes.end()) {
mask = nullptr;
}
x = block->forward(ctx, backend, x, pe, mask);
}
x = merger->forward(ctx, x);
x = ggml_get_rows(ctx, x, window_inverse_index);
return x;
}
};
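The vision stack uses Qwen2.5-VL's window attention: `window_index` scatters patch groups into window-major order, every block attends inside its window through `window_mask`, the blocks listed in `fullatt_block_indexes` attend globally, and `window_inverse_index` restores the original order after the merger. Schematically:

```cpp
// x = get_rows(x, window_index)          // group patches window-by-window
// blocks 0..31: masked (windowed) attention, full at {7, 15, 23, 31}
// x = merger(x)                          // 2x2 patch groups -> tokens
// x = get_rows(x, window_inverse_index)  // back to original patch order
```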
    struct Qwen2_5_VLAttention : public GGMLBlock {
    protected:
        int64_t head_dim;

@@ -478,7 +762,8 @@ namespace Qwen {
        struct ggml_tensor* forward(struct ggml_context* ctx,
                                    ggml_backend_t backend,
                                    struct ggml_tensor* input_ids,
                                    struct ggml_tensor* input_pos,
                                    std::vector<std::pair<int, ggml_tensor*>> image_embeds) {
            // input_ids: [N, n_token]
            // return: [N, n_token, hidden_size]

@@ -487,6 +772,45 @@ namespace Qwen {
            auto x = embed_tokens->forward(ctx, input_ids);
if (image_embeds.size() > 0) {
GGML_ASSERT(x->ne[2] == 1); // N == 1
auto raw_x = ggml_cast(ctx, x, image_embeds[0].second->type);
int64_t txt_token_start = 0;
int64_t txt_token_end = 0;
ggml_tensor* input_embed = nullptr;
for (int i = 0; i < image_embeds.size(); i++) {
if (i == 0) {
txt_token_start = 0;
} else {
txt_token_start = image_embeds[i - 1].first + image_embeds[i - 1].second->ne[1];
}
txt_token_end = image_embeds[i].first;
auto txt_embed = ggml_slice(ctx, raw_x, 1, txt_token_start, txt_token_end);
if (input_embed == nullptr) {
input_embed = txt_embed;
} else {
input_embed = ggml_concat(ctx, input_embed, txt_embed, 1);
}
auto image_embed = image_embeds[i].second;
input_embed = ggml_concat(ctx, input_embed, image_embed, 1);
}
txt_token_start = image_embeds[image_embeds.size() - 1].first + image_embeds[image_embeds.size() - 1].second->ne[1];
txt_token_end = raw_x->ne[1];
auto final_txt_embed = ggml_slice(ctx, raw_x, 1, txt_token_start, txt_token_end);
input_embed = ggml_concat(ctx, input_embed, final_txt_embed, 1);
GGML_ASSERT(raw_x->ne[1] == input_embed->ne[1]);
x = input_embed;
}
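Each `image_embeds` entry is a `(token offset, embedding)` pair: text embeddings are copied up to the offset, the image embedding is spliced in over the `<|image_pad|>` placeholders, and the remaining text follows, with an assert that the final length matches `input_ids`. For the first reference image the offset is 64 + 6 = 70 (template prefix plus the six tokens of "Picture 1: <|vision_start|>"):

```cpp
// input_ids : [ template+user tokens | <|image_pad|> x n | <|vision_end|> ... ]
//               positions 0..69        positions 70..70+n-1
// spliced x : [ text_embed(0..69) | image_embed (n tokens) | text_embed(rest) ]
```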
            for (int i = 0; i < num_layers; i++) {
                auto block = std::dynamic_pointer_cast<Qwen2_5_VLBlock>(blocks["layers." + std::to_string(i)]);

@@ -498,6 +822,20 @@ namespace Qwen {
        }
    };
struct Qwen2_5_VLVisionParams {
int64_t num_layers = 32;
int64_t hidden_size = 1280;
int64_t intermediate_size = 3420;
int64_t num_heads = 16;
int64_t in_channels = 3;
int64_t out_hidden_size = 3584;
int64_t temporal_patch_size = 2;
int64_t patch_size = 14;
int64_t spatial_merge_size = 2;
int64_t window_size = 112;
std::set<int> fullatt_block_indexes = {7, 15, 23, 31};
};
    struct Qwen2_5_VLParams {
        int64_t num_layers  = 28;
        int64_t hidden_size = 3584;

@@ -506,15 +844,17 @@ namespace Qwen {
        int64_t num_kv_heads = 4;
        int64_t vocab_size   = 152064;
        float rms_norm_eps   = 1e-06f;
        Qwen2_5_VLVisionParams vision;
    };
struct Qwen2_5_VL : public GGMLBlock { struct Qwen2_5_VL : public GGMLBlock {
bool enable_vision;
Qwen2_5_VLParams params; Qwen2_5_VLParams params;
public: public:
Qwen2_5_VL() {} Qwen2_5_VL() {}
Qwen2_5_VL(Qwen2_5_VLParams params) Qwen2_5_VL(Qwen2_5_VLParams params, bool enable_vision = false, bool llama_cpp_style = false)
: params(params) { : enable_vision(enable_vision), params(params) {
blocks["model"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLTextModel(params.num_layers, blocks["model"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLTextModel(params.num_layers,
params.vocab_size, params.vocab_size,
params.hidden_size, params.hidden_size,
@ -522,32 +862,90 @@ namespace Qwen {
params.num_heads, params.num_heads,
params.num_kv_heads, params.num_kv_heads,
params.rms_norm_eps)); params.rms_norm_eps));
if (enable_vision) {
blocks["visual"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLVisionModel(llama_cpp_style,
params.vision.num_layers,
params.vision.in_channels,
params.vision.hidden_size,
params.vision.out_hidden_size,
params.vision.intermediate_size,
params.vision.num_heads,
params.vision.spatial_merge_size,
params.vision.patch_size,
params.vision.temporal_patch_size,
params.vision.window_size,
params.vision.fullatt_block_indexes));
}
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* input_ids,
struct ggml_tensor* input_pos,
std::vector<std::pair<int, ggml_tensor*>> image_embeds) {
// input_ids: [N, n_token]
auto model = std::dynamic_pointer_cast<Qwen2_5_VLTextModel>(blocks["model"]);
auto x = model->forward(ctx, backend, input_ids, input_pos, image_embeds);
return x;
}
struct ggml_tensor* vision_forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* pixel_values,
struct ggml_tensor* pe,
struct ggml_tensor* window_index,
struct ggml_tensor* window_inverse_index,
struct ggml_tensor* window_mask) {
GGML_ASSERT(enable_vision);
auto vision_model = std::dynamic_pointer_cast<Qwen2_5_VLVisionModel>(blocks["visual"]);
return vision_model->forward(ctx, backend, pixel_values, pe, window_index, window_inverse_index, window_mask);
}
};
struct Qwen2_5_VLRunner : public GGMLRunner {
Qwen2_5_VLParams params;
bool enable_vision;
Qwen2_5_VL model;
std::vector<int> input_pos_vec;
std::vector<float> window_mask_vec;
std::vector<int> window_index_vec;
std::vector<int> window_inverse_index_vec;
std::vector<float> pe_vec;
Qwen2_5_VLRunner(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types,
const std::string prefix,
bool enable_vision_ = false)
: GGMLRunner(backend, offload_params_to_cpu), enable_vision(enable_vision_) {
bool have_vision_weight = false;
bool llama_cpp_style = false;
for (auto pair : tensor_types) {
std::string tensor_name = pair.first;
if (tensor_name.find(prefix) == std::string::npos)
continue;
size_t pos = tensor_name.find("visual.");
if (pos != std::string::npos) {
have_vision_weight = true;
if (contains(tensor_name, "attn.q_proj")) {
llama_cpp_style = true;
break;
}
}
}
if (enable_vision && !have_vision_weight) {
LOG_WARN("no vision weights detected, vision disabled");
enable_vision = false;
}
if (enable_vision) {
LOG_DEBUG("enable qwen2vl vision");
if (llama_cpp_style) {
LOG_DEBUG("llama.cpp style vision weight");
}
}
model = Qwen2_5_VL(params, enable_vision, llama_cpp_style);
model.init(params_ctx, tensor_types, prefix);
}
@@ -562,16 +960,32 @@ namespace Qwen {
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* input_ids,
struct ggml_tensor* input_pos,
std::vector<std::pair<int, ggml_tensor*>> image_embeds) {
auto hidden_states = model.forward(ctx, backend, input_ids, input_pos, image_embeds); // [N, n_token, hidden_size]
return hidden_states;
}
struct ggml_tensor* vision_forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* pixel_values,
struct ggml_tensor* input_pos,
struct ggml_tensor* window_index,
struct ggml_tensor* window_inverse_index,
struct ggml_tensor* window_mask) {
auto hidden_states = model.vision_forward(ctx, backend, pixel_values, input_pos, window_index, window_inverse_index, window_mask);
return hidden_states;
}
struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids, std::vector<std::pair<int, ggml_tensor*>> image_embeds) {
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
input_ids = to_backend(input_ids);
for (auto& image_embed : image_embeds) {
image_embed.second = to_backend(image_embed.second);
}
int64_t n_tokens = input_ids->ne[0];
input_pos_vec.resize(n_tokens * 4);
for (int i = 0; i < n_tokens; ++i) {
@@ -586,7 +1000,7 @@ namespace Qwen {
n_tokens * 4);
set_backend_tensor_data(input_pos, input_pos_vec.data());
struct ggml_tensor* hidden_states = forward(compute_ctx, runtime_backend, input_ids, input_pos, image_embeds);
ggml_build_forward_expand(gf, hidden_states);
@@ -595,13 +1009,183 @@ namespace Qwen {
void compute(const int n_threads,
struct ggml_tensor* input_ids,
std::vector<std::pair<int, ggml_tensor*>> image_embeds,
ggml_tensor** output,
ggml_context* output_ctx = NULL) {
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(input_ids, image_embeds);
};
GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
}
int64_t get_num_image_tokens(int64_t t, int64_t h, int64_t w) {
// number of ViT patch tokens, before the 2x2 spatial merge
int grid_t = 1;
int grid_h = h / params.vision.patch_size;
int grid_w = w / params.vision.patch_size;
return grid_t * grid_h * grid_w;
}
struct ggml_tensor* process_image(struct ggml_context* ctx, struct ggml_tensor* image) {
// image: [C, H, W]
// return: [grid_t*(H/mh/ph)*(W/mw/pw)*mh*mw, C*pt*ph*pw], grid_t == 1
int64_t C = image->ne[2];
int64_t H = image->ne[1];
int64_t W = image->ne[0];
int64_t mh = params.vision.spatial_merge_size;
int64_t mw = params.vision.spatial_merge_size;
int64_t pt = params.vision.temporal_patch_size;
int64_t ph = params.vision.patch_size;
int64_t pw = params.vision.patch_size;
image = ggml_reshape_4d(ctx, image, pw, mw, (W / mw / pw), H * C); // [C*H, (W/mw/pw), mw, pw]
image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 3, 1)); // [mw, C*H, (W/mw/pw), pw]
image = ggml_reshape_4d(ctx, image, pw * (W / mw / pw), H, C, mw); // [mw, C, H, (W/mw/pw)*pw]
image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 3, 1)); // [H, mw, C, (W/mw/pw)*pw]
image = ggml_reshape_4d(ctx, image, pw, (W / mw / pw) * C * mw, ph, mh * (H / mh / ph)); // [(H/mh/ph)*mh, ph, mw*C*(W/mw/pw), pw]
image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 1, 3)); // [(H/mh/ph)*mh, mw*C*(W/mw/pw), ph, pw]
image = ggml_reshape_4d(ctx, image, pw * ph, (W / mw / pw), C, mw * mh * (H / mh / ph)); // [(H/mh/ph)*mh*mw, C, (W/mw/pw), ph*pw]
image = ggml_concat(ctx, image, image, 0); // [(H/mh/ph)*mh*mw, C, (W/mw/pw), pt*ph*pw]
image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 1, 3)); // [(H/mh/ph)*mh*mw, (W/mw/pw), C, pt*ph*pw]
image = ggml_reshape_4d(ctx, image, pw * ph * pt * C, (W / mw / pw), mw * mh, (H / mh / ph)); // [(H/mh/ph), mh*mw, (W/mw/pw), C*pt*ph*pw]
image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 1, 3)); // [(H/mh/ph), (W/mw/pw), mh*mw, C*pt*ph*pw]
image = ggml_reshape_2d(ctx, image, pw * ph * pt * C, mw * mh * (W / mw / pw) * (H / mh / ph)); // [(H/mh/ph)*(W/mw/pw)*mh*mw, C*pt*ph*pw]
return image;
}
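// Worked shape example (assuming the default vision params: patch 14,
// spatial merge 2, temporal patch 2, RGB): a 280x280 input gives
// grid_h = grid_w = 280/14 = 20, so process_image returns
// [20*20, 3*2*14*14] = [400, 1176] -- 400 patch rows of 1176 values each,
// matching the ViT patch-embedding input width.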
struct ggml_cgraph* build_encode_image_graph(struct ggml_tensor* image) {
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, QWENVL_GRAPH_SIZE, false);
GGML_ASSERT(image->ne[1] % (params.vision.patch_size * params.vision.spatial_merge_size) == 0);
GGML_ASSERT(image->ne[0] % (params.vision.patch_size * params.vision.spatial_merge_size) == 0);
int grid_t = 1;
int grid_h = image->ne[1] / params.vision.patch_size;
int grid_w = image->ne[0] / params.vision.patch_size;
int llm_grid_h = grid_h / params.vision.spatial_merge_size;
int llm_grid_w = grid_w / params.vision.spatial_merge_size;
int vit_merger_window_size = params.vision.window_size / params.vision.patch_size / params.vision.spatial_merge_size;
image = to_backend(image);
auto pixel_values = process_image(compute_ctx, image);
// window index
int inverse_index = 0;
window_index_vec.resize(llm_grid_h * llm_grid_w);
window_inverse_index_vec.resize(llm_grid_h * llm_grid_w);
std::vector<int> seqlens;
for (int ih = 0; ih < llm_grid_h; ih += vit_merger_window_size) {
for (int iw = 0; iw < llm_grid_w; iw += vit_merger_window_size) {
int win_h = std::min(vit_merger_window_size, llm_grid_h - ih);
int win_w = std::min(vit_merger_window_size, llm_grid_w - iw);
for (int iy = 0; iy < win_h; iy++) {
for (int ix = 0; ix < win_w; ix++) {
int index = (ih + iy) * llm_grid_w + iw + ix;
window_index_vec[inverse_index] = index;
window_inverse_index_vec[index] = inverse_index;
inverse_index++;
}
}
seqlens.push_back(win_h * win_w * params.vision.spatial_merge_size * params.vision.spatial_merge_size);
}
}
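// Example with the defaults: vit_merger_window_size = 112 / 14 / 2 = 4, so
// an 8x8 llm grid splits into four 4x4 windows, each pushing a seqlen of
// 4*4*2*2 = 64 pre-merge patch tokens.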
// printf("window_index: ");
// for (int i : window_index_vec) {
// printf("%d ", i);
// }
// printf("\n");
// printf("window_inverse_index: ");
// for (int i : window_inverse_index_vec) {
// printf("%d ", i);
// }
// printf("\n");
// printf("seqlens: ");
// for (int i : seqlens) {
// printf("%d ", i);
// }
// printf("\n");
auto window_index = ggml_new_tensor_1d(compute_ctx,
GGML_TYPE_I32,
llm_grid_h * llm_grid_w);
auto window_inverse_index = ggml_new_tensor_1d(compute_ctx,
GGML_TYPE_I32,
llm_grid_h * llm_grid_w);
set_backend_tensor_data(window_index, window_index_vec.data());
set_backend_tensor_data(window_inverse_index, window_inverse_index_vec.data());
// window mask
int seq_window_size = (vit_merger_window_size * params.vision.spatial_merge_size) * (vit_merger_window_size * params.vision.spatial_merge_size);
window_mask_vec.resize((grid_h * grid_w) * (grid_h * grid_w));
int window_start_index = 0;
for (int seq_index = 0; seq_index < seqlens.size(); seq_index++) {
int window_end_index = window_start_index + seqlens[seq_index];
// LOG_DEBUG("%d %d", window_start_index, window_end_index);
GGML_ASSERT(window_end_index <= grid_h * grid_w);
for (int i = window_start_index; i < window_end_index; i++) {
for (int j = 0; j < grid_h * grid_w; j++) {
float mask_value = -INFINITY;
if (j >= window_start_index && j < window_end_index) {
mask_value = 0;
}
GGML_ASSERT((i * (grid_h * grid_w) + j) < window_mask_vec.size());
window_mask_vec[i * (grid_h * grid_w) + j] = mask_value;
}
}
window_start_index = window_end_index;
// printf("\n");
}
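// The result is a block-diagonal mask over the seqlens segments: each patch
// token attends only to tokens inside its own window (0 within the block,
// -INFINITY outside). The layers listed in fullatt_block_indexes are expected
// to skip this mask and attend globally instead.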
// printf("window_mask: \n");
// for (int i = 0; i < grid_h*grid_w; i++) {
// for (int j = 0; j < grid_h*grid_w; j++) {
// printf("%f ", window_mask_vec[i * (grid_h * grid_w) + j]);
// }
// printf("\n");
// }
auto window_mask = ggml_new_tensor_2d(compute_ctx,
GGML_TYPE_F32,
grid_h * grid_w,
grid_h * grid_w);
set_backend_tensor_data(window_mask, window_mask_vec.data());
// pe
int head_dim = params.vision.hidden_size / params.vision.num_heads;
pe_vec = Rope::gen_qwen2vl_pe(grid_h,
grid_w,
params.vision.spatial_merge_size,
window_inverse_index_vec,
10000.f,
{head_dim / 2, head_dim / 2});
int pos_len = pe_vec.size() / head_dim / 2;
// LOG_DEBUG("pos_len %d", pos_len);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, head_dim / 2, pos_len);
// pe->data = pe_vec.data();
// print_ggml_tensor(pe);
// pe->data = NULL;
set_backend_tensor_data(pe, pe_vec.data());
struct ggml_tensor* hidden_states = vision_forward(compute_ctx,
runtime_backend,
pixel_values,
pe,
window_index,
window_inverse_index,
window_mask);
ggml_build_forward_expand(gf, hidden_states);
return gf;
}
void encode_image(const int n_threads,
struct ggml_tensor* image,
ggml_tensor** output,
ggml_context* output_ctx = NULL) {
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_encode_image_graph(image);
};
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
}
};
struct Qwen2_5_VLEmbedder {
@@ -611,8 +1195,9 @@ namespace Qwen {
Qwen2_5_VLEmbedder(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "",
bool enable_vision = false)
: model(backend, offload_params_to_cpu, tensor_types, prefix, enable_vision) {
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
@@ -666,8 +1251,76 @@ namespace Qwen {
struct ggml_context* work_ctx = ggml_init(params);
GGML_ASSERT(work_ctx != NULL);
bool test_vit = true;
bool test_decoder_with_vit = true;
if (test_decoder_with_vit) {
ggml_tensor* image_embed = nullptr;
{
auto image = load_tensor_from_file(work_ctx, "qwen2vl_normalized.bin");
print_ggml_tensor(image, false, "image");
struct ggml_tensor* out = NULL;
int t0 = ggml_time_ms();
model.encode_image(8, image, &out, work_ctx);
int t1 = ggml_time_ms();
print_ggml_tensor(out, false, "image_embed");
image_embed = out;
LOG_DEBUG("qwen2vl encode_image test done in %dms", t1 - t0);
}
std::string placeholder = "<|image_pad|>";
std::string img_prompt = "Picture 1: <|vision_start|>"; // [24669, 220, 16, 25, 220, 151652]
int64_t num_image_tokens = image_embed->ne[1];
img_prompt.reserve(num_image_tokens * placeholder.size());
for (int i = 0; i < num_image_tokens; i++) {
img_prompt += placeholder;
}
img_prompt += "<|vision_end|>";
std::vector<std::pair<int, ggml_tensor*>> image_embeds;
image_embeds.emplace_back(64, image_embed);
std::string text = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n";
text += img_prompt;
text += "change 'flux.cpp' to 'edit.cpp'";
text += "<|im_end|>\n<|im_start|>assistant\n";
auto tokens_and_weights = tokenize(text, 0, false);
std::vector<int>& tokens = std::get<0>(tokens_and_weights);
std::vector<float>& weights = std::get<1>(tokens_and_weights);
for (auto token : tokens) {
printf("%d ", token);
}
printf("\n");
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
struct ggml_tensor* out = NULL;
int t0 = ggml_time_ms();
model.compute(8, input_ids, image_embeds, &out, work_ctx);
int t1 = ggml_time_ms();
print_ggml_tensor(out);
LOG_DEBUG("qwen2vl test done in %dms", t1 - t0);
} else if (test_vit) {
// auto image = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 280, 280, 3);
// ggml_set_f32(image, 0.f);
auto image = load_tensor_from_file(work_ctx, "qwen2vl_normalized.bin");
print_ggml_tensor(image, false, "image");
struct ggml_tensor* out = NULL;
int t0 = ggml_time_ms();
model.encode_image(8, image, &out, work_ctx);
int t1 = ggml_time_ms();
print_ggml_tensor(out, false, "out");
// auto ref_out = load_tensor_from_file(work_ctx, "qwen2vl.bin");
// ggml_tensor_diff(ref_out, out, 0.01f);
LOG_DEBUG("qwen2vl test done in %dms", t1 - t0);
} else {
std::string text("<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\na lovely cat<|im_end|>\n<|im_start|>assistant\n"); std::string text("<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\na lovely cat<|im_end|>\n<|im_start|>assistant\n");
auto tokens_and_weights = tokenize(text, 0, false); auto tokens_and_weights = tokenize(text, 0, false);
std::vector<int>& tokens = std::get<0>(tokens_and_weights); std::vector<int>& tokens = std::get<0>(tokens_and_weights);
@ -680,7 +1333,7 @@ namespace Qwen {
struct ggml_tensor* out = NULL; struct ggml_tensor* out = NULL;
int t0 = ggml_time_ms(); int t0 = ggml_time_ms();
model.compute(8, input_ids, &out, work_ctx); model.compute(8, input_ids, {}, &out, work_ctx);
int t1 = ggml_time_ms(); int t1 = ggml_time_ms();
print_ggml_tensor(out); print_ggml_tensor(out);
@ -692,7 +1345,7 @@ namespace Qwen {
// cpu f16: pass // cpu f16: pass
// ggml_backend_t backend = ggml_backend_cuda_init(0); // ggml_backend_t backend = ggml_backend_cuda_init(0);
ggml_backend_t backend = ggml_backend_cpu_init(); ggml_backend_t backend = ggml_backend_cpu_init();
ggml_type model_data_type = GGML_TYPE_Q8_0; ggml_type model_data_type = GGML_TYPE_F16;
ModelLoader model_loader; ModelLoader model_loader;
if (!model_loader.init_from_file(file_path, "qwen2vl.")) { if (!model_loader.init_from_file(file_path, "qwen2vl.")) {
@ -708,7 +1361,11 @@ namespace Qwen {
} }
} }
std::shared_ptr<Qwen2_5_VLEmbedder> qwenvl = std::shared_ptr<Qwen2_5_VLEmbedder>(new Qwen2_5_VLEmbedder(backend, false, tensor_types, "qwen2vl")); std::shared_ptr<Qwen2_5_VLEmbedder> qwenvl = std::shared_ptr<Qwen2_5_VLEmbedder>(new Qwen2_5_VLEmbedder(backend,
false,
tensor_types,
"qwen2vl",
true));
qwenvl->alloc_params_buffer(); qwenvl->alloc_params_buffer();
std::map<std::string, ggml_tensor*> tensors; std::map<std::string, ggml_tensor*> tensors;
rope.hpp (251 lines changed)
@@ -4,9 +4,9 @@
#include <vector>
#include "ggml_extend.hpp"
namespace Rope {
template <class T>
__STATIC_INLINE__ std::vector<T> linspace(T start, T end, int num) {
std::vector<T> result(num);
if (num == 1) {
result[0] = start;
@@ -19,7 +19,7 @@ struct Rope {
return result;
}
__STATIC_INLINE__ std::vector<std::vector<float>> transpose(const std::vector<std::vector<float>>& mat) {
int rows = mat.size();
int cols = mat[0].size();
std::vector<std::vector<float>> transposed(cols, std::vector<float>(rows));
@@ -31,7 +31,7 @@ struct Rope {
return transposed;
}
__STATIC_INLINE__ std::vector<float> flatten(const std::vector<std::vector<float>>& vec) {
std::vector<float> flat_vec;
for (const auto& sub_vec : vec) {
flat_vec.insert(flat_vec.end(), sub_vec.begin(), sub_vec.end());
@@ -39,7 +39,7 @@ struct Rope {
return flat_vec;
}
__STATIC_INLINE__ std::vector<std::vector<float>> rope(const std::vector<float>& pos, int dim, int theta) {
assert(dim % 2 == 0);
int half_dim = dim / 2;
@@ -72,11 +72,11 @@ struct Rope {
}
// Generate IDs for image patches and text
__STATIC_INLINE__ std::vector<std::vector<float>> gen_txt_ids(int bs, int context_len) {
return std::vector<std::vector<float>>(bs * context_len, std::vector<float>(3, 0.0));
}
__STATIC_INLINE__ std::vector<std::vector<float>> gen_img_ids(int h, int w, int patch_size, int bs, int index = 0, int h_offset = 0, int w_offset = 0) {
int h_len = (h + (patch_size / 2)) / patch_size;
int w_len = (w + (patch_size / 2)) / patch_size;
@@ -102,9 +102,9 @@ struct Rope {
return img_ids_repeated;
}
__STATIC_INLINE__ std::vector<std::vector<float>> concat_ids(const std::vector<std::vector<float>>& a,
const std::vector<std::vector<float>>& b,
int bs) {
size_t a_len = a.size() / bs;
size_t b_len = b.size() / bs;
std::vector<std::vector<float>> ids(a.size() + b.size(), std::vector<float>(3));
@@ -119,10 +119,10 @@ struct Rope {
return ids;
}
__STATIC_INLINE__ std::vector<float> embed_nd(const std::vector<std::vector<float>>& ids,
int bs,
int theta,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> trans_ids = transpose(ids);
size_t pos_len = ids.size() / bs;
int num_axes = axes_dim.size();
@@ -151,17 +151,11 @@ struct Rope {
return flatten(emb);
}
__STATIC_INLINE__ std::vector<std::vector<float>> gen_refs_ids(int patch_size,
int bs,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index) {
std::vector<std::vector<float>> ids;
uint64_t curr_h_offset = 0;
uint64_t curr_w_offset = 0;
int index = 1;
@@ -189,25 +183,45 @@ struct Rope {
return ids;
}
__STATIC_INLINE__ std::vector<std::vector<float>> gen_flux_ids(int h,
int w,
int patch_size,
int bs,
int context_len,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index) {
auto txt_ids = gen_txt_ids(bs, context_len);
auto img_ids = gen_img_ids(h, w, patch_size, bs);
auto ids = concat_ids(txt_ids, img_ids, bs);
if (ref_latents.size() > 0) {
auto refs_ids = gen_refs_ids(patch_size, bs, ref_latents, increase_ref_index);
ids = concat_ids(ids, refs_ids, bs);
}
return ids;
}
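// Sketch of the resulting id layout (sizes hypothetical, they depend on the
// caller): for a 1024x1024 latent (h = w = 128, patch_size = 2) with
// context_len = 256, ids holds 256 text rows (0,0,0), then 64*64 image rows
// (0,y,x), then one (index,y,x) block per ref latent appended by gen_refs_ids.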
// Generate flux positional embeddings
__STATIC_INLINE__ std::vector<float> gen_flux_pe(int h,
int w,
int patch_size,
int bs,
int context_len,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index,
int theta,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_flux_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
return embed_nd(ids, bs, theta, axes_dim);
}
__STATIC_INLINE__ std::vector<std::vector<float>> gen_qwen_image_ids(int h,
int w,
int patch_size,
int bs,
int context_len,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index) {
int h_len = (h + (patch_size / 2)) / patch_size;
int w_len = (w + (patch_size / 2)) / patch_size;
int txt_id_start = std::max(h_len, w_len);
@@ -220,31 +234,37 @@ struct Rope {
}
auto img_ids = gen_img_ids(h, w, patch_size, bs);
auto ids = concat_ids(txt_ids_repeated, img_ids, bs);
if (ref_latents.size() > 0) {
auto refs_ids = gen_refs_ids(patch_size, bs, ref_latents, increase_ref_index);
ids = concat_ids(ids, refs_ids, bs);
}
return ids;
}
// Generate qwen_image positional embeddings
__STATIC_INLINE__ std::vector<float> gen_qwen_image_pe(int h,
int w,
int patch_size,
int bs,
int context_len,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index,
int theta,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
return embed_nd(ids, bs, theta, axes_dim);
}
__STATIC_INLINE__ std::vector<std::vector<float>> gen_vid_ids(int t,
int h,
int w,
int pt,
int ph,
int pw,
int bs,
int t_offset = 0,
int h_offset = 0,
int w_offset = 0) {
int t_len = (t + (pt / 2)) / pt;
int h_len = (h + (ph / 2)) / ph;
int w_len = (w + (pw / 2)) / pw;
@@ -276,18 +296,115 @@ struct Rope {
}
// Generate wan positional embeddings
__STATIC_INLINE__ std::vector<float> gen_wan_pe(int t,
int h,
int w,
int pt,
int ph,
int pw,
int bs,
int theta,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_vid_ids(t, h, w, pt, ph, pw, bs);
return embed_nd(ids, bs, theta, axes_dim);
}
__STATIC_INLINE__ std::vector<std::vector<float>> gen_qwen2vl_ids(int grid_h,
int grid_w,
int merge_size,
const std::vector<int>& window_index) {
std::vector<std::vector<float>> ids(grid_h * grid_w, std::vector<float>(2, 0.0));
int index = 0;
for (int ih = 0; ih < grid_h; ih += merge_size) {
for (int iw = 0; iw < grid_w; iw += merge_size) {
for (int iy = 0; iy < merge_size; iy++) {
for (int ix = 0; ix < merge_size; ix++) {
int inverse_index = window_index[index / (merge_size * merge_size)];
int i = inverse_index * (merge_size * merge_size) + index % (merge_size * merge_size);
GGML_ASSERT(i < grid_h * grid_w);
ids[i][0] = ih + iy;
ids[i][1] = iw + ix;
index++;
}
}
}
}
return ids;
}
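// Example (hypothetical 4x4 grid, merge_size 2, identity window_index): the
// 2x2 merge blocks are visited row-major and each block's four patches stay
// contiguous, giving ids = (0,0),(0,1),(1,0),(1,1),(0,2),(0,3),(1,2),(1,3),
// (2,0),(2,1),(3,0),(3,1),(2,2),(2,3),(3,2),(3,3) -- already in the
// window-permuted token order the vision model consumes.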
// Generate qwen2vl positional embeddings
__STATIC_INLINE__ std::vector<float> gen_qwen2vl_pe(int grid_h,
int grid_w,
int merge_size,
const std::vector<int>& window_index,
int theta,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_qwen2vl_ids(grid_h, grid_w, merge_size, window_index);
return embed_nd(ids, 1, theta, axes_dim);
}
__STATIC_INLINE__ struct ggml_tensor* apply_rope(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* pe,
bool rope_interleaved = true) {
// x: [N, L, n_head, d_head]
// pe: [L, d_head/2, 2, 2], [[cos, -sin], [sin, cos]]
int64_t d_head = x->ne[0];
int64_t n_head = x->ne[1];
int64_t L = x->ne[2];
int64_t N = x->ne[3];
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, n_head, L, d_head]
if (rope_interleaved) {
x = ggml_reshape_4d(ctx, x, 2, d_head / 2, L, n_head * N); // [N * n_head, L, d_head/2, 2]
x = ggml_cont(ctx, ggml_permute(ctx, x, 3, 0, 1, 2)); // [2, N * n_head, L, d_head/2]
} else {
x = ggml_reshape_4d(ctx, x, d_head / 2, 2, L, n_head * N); // [N * n_head, L, 2, d_head/2]
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 3, 1)); // [2, N * n_head, L, d_head/2]
}
int64_t offset = x->nb[2] * x->ne[2];
auto x_0 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 0); // [N * n_head, L, d_head/2]
auto x_1 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 1); // [N * n_head, L, d_head/2]
x_0 = ggml_reshape_4d(ctx, x_0, 1, x_0->ne[0], x_0->ne[1], x_0->ne[2]); // [N * n_head, L, d_head/2, 1]
x_1 = ggml_reshape_4d(ctx, x_1, 1, x_1->ne[0], x_1->ne[1], x_1->ne[2]); // [N * n_head, L, d_head/2, 1]
auto temp_x = ggml_new_tensor_4d(ctx, x_0->type, 2, x_0->ne[1], x_0->ne[2], x_0->ne[3]);
x_0 = ggml_repeat(ctx, x_0, temp_x); // [N * n_head, L, d_head/2, 2]
x_1 = ggml_repeat(ctx, x_1, temp_x); // [N * n_head, L, d_head/2, 2]
pe = ggml_cont(ctx, ggml_permute(ctx, pe, 3, 0, 1, 2)); // [2, L, d_head/2, 2]
offset = pe->nb[2] * pe->ne[2];
auto pe_0 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 0); // [L, d_head/2, 2]
auto pe_1 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 1); // [L, d_head/2, 2]
auto x_out = ggml_add_inplace(ctx, ggml_mul(ctx, x_0, pe_0), ggml_mul(ctx, x_1, pe_1)); // [N * n_head, L, d_head/2, 2]
if (!rope_interleaved) {
x_out = ggml_cont(ctx, ggml_permute(ctx, x_out, 1, 0, 2, 3)); // [N * n_head, L, 2, d_head/2]
}
x_out = ggml_reshape_3d(ctx, x_out, d_head, L, n_head * N); // [N*n_head, L, d_head]
return x_out;
}
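// Per 2-dim sub-pair this evaluates the standard RoPE rotation (a sketch of
// the math, with cos/sin pre-packed into pe as [[cos, -sin], [sin, cos]]):
//   out_0 = x_0 * cos(theta) - x_1 * sin(theta)
//   out_1 = x_0 * sin(theta) + x_1 * cos(theta)
// so no trigonometry has to run inside the compute graph itself.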
__STATIC_INLINE__ struct ggml_tensor* attention(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* q,
struct ggml_tensor* k,
struct ggml_tensor* v,
struct ggml_tensor* pe,
struct ggml_tensor* mask,
bool flash_attn,
float kv_scale = 1.0f,
bool rope_interleaved = true) {
// q,k,v: [N, L, n_head, d_head]
// pe: [L, d_head/2, 2, 2]
// return: [N, L, n_head*d_head]
q = apply_rope(ctx, q, pe, rope_interleaved); // [N*n_head, L, d_head]
k = apply_rope(ctx, k, pe, rope_interleaved); // [N*n_head, L, d_head]
auto x = ggml_nn_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn, kv_scale); // [N, L, n_head*d_head]
return x;
}
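// Note: this is the helper callers previously reached as Flux::attention; the
// WAN block in wan.hpp (see the hunk below) now routes through Rope::attention
// with the same semantics.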
}; // namespace Rope
#endif // __ROPE_HPP__
stable-diffusion.cpp
@@ -261,6 +261,13 @@ public:
}
}
if (strlen(SAFE_STR(sd_ctx_params->qwen2vl_vision_path)) > 0) {
LOG_INFO("loading qwen2vl vision from '%s'", sd_ctx_params->qwen2vl_vision_path);
if (!model_loader.init_from_file(sd_ctx_params->qwen2vl_vision_path, "text_encoders.qwen2vl.visual.")) {
LOG_WARN("loading qwen2vl vision from '%s' failed", sd_ctx_params->qwen2vl_vision_path);
}
}
if (strlen(SAFE_STR(sd_ctx_params->vae_path)) > 0) {
LOG_INFO("loading vae from '%s'", sd_ctx_params->vae_path);
if (!model_loader.init_from_file(sd_ctx_params->vae_path, "vae.")) {
@@ -274,6 +281,15 @@ public:
return false;
}
auto& tensor_types = model_loader.tensor_storages_types;
for (auto& item : tensor_types) {
// LOG_DEBUG("%s %u", item.first.c_str(), item.second);
if (contains(item.first, "qwen2vl") && ends_with(item.first, "weight") && (item.second == GGML_TYPE_F32 || item.second == GGML_TYPE_BF16)) {
item.second = GGML_TYPE_F16;
// LOG_DEBUG(" change %s %u", item.first.c_str(), item.second);
}
}
LOG_INFO("Version: %s ", model_version_to_str[version]); LOG_INFO("Version: %s ", model_version_to_str[version]);
ggml_type wtype = (int)sd_ctx_params->wtype < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT) ggml_type wtype = (int)sd_ctx_params->wtype < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT)
? (ggml_type)sd_ctx_params->wtype ? (ggml_type)sd_ctx_params->wtype
@ -417,9 +433,15 @@ public:
clip_vision->get_param_tensors(tensors); clip_vision->get_param_tensors(tensors);
} }
} else if (sd_version_is_qwen_image(version)) { } else if (sd_version_is_qwen_image(version)) {
bool enable_vision = false;
if (!vae_decode_only) {
enable_vision = true;
}
cond_stage_model = std::make_shared<Qwen2_5_VLCLIPEmbedder>(clip_backend,
offload_params_to_cpu,
model_loader.tensor_storages_types,
"",
enable_vision);
diffusion_model = std::make_shared<QwenImageModel>(backend,
offload_params_to_cpu,
model_loader.tensor_storages_types,
@@ -590,7 +612,9 @@ public:
if (vae_decode_only) {
ignore_tensors.insert("first_stage_model.encoder");
ignore_tensors.insert("first_stage_model.conv1");
ignore_tensors.insert("first_stage_model.quant");
ignore_tensors.insert("text_encoders.qwen2vl.visual.");
}
if (version == VERSION_SVD) {
ignore_tensors.insert("conditioner.embedders.3");
@@ -949,12 +973,12 @@ public:
ggml_set_f32(output, 0.f);
} else {
sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(init_image);
sd_image_f32_t resized_image = clip_preprocess(image, clip_vision->vision_model.image_size, clip_vision->vision_model.image_size);
free(image.data);
image.data = NULL;
ggml_tensor* pixel_values = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
sd_image_f32_to_tensor(resized_image, pixel_values, false);
free(resized_image.data);
resized_image.data = NULL;
@@ -991,7 +1015,7 @@ public:
sd_image_f32_t resized_image = resize_sd_image_f32_t(image, width, height);
free(image.data);
image.data = NULL;
sd_image_f32_to_tensor(resized_image, init_img, false);
free(resized_image.data);
resized_image.data = NULL;
} else {
@@ -1749,6 +1773,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
"clip_vision_path: %s\n"
"t5xxl_path: %s\n"
"qwen2vl_path: %s\n"
"qwen2vl_vision_path: %s\n"
"diffusion_model_path: %s\n"
"high_noise_diffusion_model_path: %s\n"
"vae_path: %s\n"
@@ -1777,6 +1802,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
SAFE_STR(sd_ctx_params->clip_vision_path),
SAFE_STR(sd_ctx_params->t5xxl_path),
SAFE_STR(sd_ctx_params->qwen2vl_path),
SAFE_STR(sd_ctx_params->qwen2vl_vision_path),
SAFE_STR(sd_ctx_params->diffusion_model_path),
SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path),
SAFE_STR(sd_ctx_params->vae_path),
@@ -1987,6 +2013,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
sd_image_t control_image,
float control_strength,
sd_pm_params_t pm_params,
std::vector<sd_image_t*> ref_images,
std::vector<ggml_tensor*> ref_latents,
bool increase_ref_index,
ggml_tensor* concat_latent = NULL,
@@ -2019,6 +2046,14 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
ggml_tensor* init_img = NULL;
SDCondition id_cond;
std::vector<bool> class_tokens_mask;
ConditionerParams condition_params;
condition_params.clip_skip = clip_skip;
condition_params.width = width;
condition_params.height = height;
condition_params.ref_images = ref_images;
condition_params.adm_in_channels = sd_ctx->sd->diffusion_model->get_adm_in_channels();
if (sd_ctx->sd->stacked_id) {
if (!sd_ctx->sd->pmid_lora->applied) {
int64_t t0 = ggml_time_ms();
@@ -2041,7 +2076,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
std::vector<sd_image_f32_t> processed_id_images;
for (int i = 0; i < pm_params.id_images_count; i++) {
sd_image_f32_t id_image = sd_image_t_to_sd_image_f32_t(pm_params.id_images[i]);
sd_image_f32_t processed_id_image = clip_preprocess(id_image, clip_image_size, clip_image_size);
free(id_image.data);
id_image.data = NULL;
processed_id_images.push_back(processed_id_image);
@@ -2058,17 +2093,15 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
}
processed_id_images.clear();
int64_t t0 = ggml_time_ms();
condition_params.text = prompt;
condition_params.num_input_imgs = pm_params.id_images_count;
auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx,
sd_ctx->sd->n_threads,
condition_params);
id_cond = std::get<0>(cond_tup);
class_tokens_mask = std::get<1>(cond_tup); //
struct ggml_tensor* id_embeds = NULL;
if (pmv2 && pm_params.id_embed_path != nullptr) {
id_embeds = load_tensor_from_file(work_ctx, pm_params.id_embed_path);
// print_ggml_tensor(id_embeds, true, "id_embeds:");
@@ -2094,14 +2127,12 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
}
// Get learned condition
t0 = ggml_time_ms();
condition_params.text = prompt;
condition_params.zero_out_masked = false;
SDCondition cond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
sd_ctx->sd->n_threads,
condition_params);
SDCondition uncond;
if (guidance.txt_cfg != 1.0 ||
@@ -2110,14 +2141,11 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0 && !sd_ctx->sd->is_using_edm_v_parameterization) {
zero_out_masked = true;
}
condition_params.text = negative_prompt;
condition_params.zero_out_masked = zero_out_masked;
uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
sd_ctx->sd->n_threads,
condition_params);
}
int64_t t1 = ggml_time_ms();
LOG_INFO("get_learned_condition completed, taking %" PRId64 " ms", t1 - t0);
@@ -2538,13 +2566,42 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
std::vector<ggml_tensor*> ref_latents;
for (int i = 0; i < ref_images.size(); i++) {
ggml_tensor* img;
if (sd_version_is_qwen_image(sd_ctx->sd->version)) {
sd_image_f32_t ref_image = sd_image_t_to_sd_image_f32_t(*ref_images[i]);
int VAE_IMAGE_SIZE = std::min(1024 * 1024, width * height);
// cast to double before multiplying to avoid integer overflow on large refs
double vae_width = sqrt(static_cast<double>(VAE_IMAGE_SIZE) * ref_image.width / ref_image.height);
double vae_height = vae_width * ref_image.height / ref_image.width;
vae_height = round(vae_height / 32) * 32;
vae_width = round(vae_width / 32) * 32;
sd_image_f32_t resized_image = resize_sd_image_f32_t(ref_image, static_cast<int>(vae_width), static_cast<int>(vae_height));
free(ref_image.data);
ref_image.data = nullptr;
LOG_DEBUG("resize vae ref image %d from %dx%d to %dx%d", i, ref_image.height, ref_image.width, resized_image.height, resized_image.width);
img = ggml_new_tensor_4d(work_ctx,
GGML_TYPE_F32,
resized_image.width,
resized_image.height,
3,
1);
sd_image_f32_to_tensor(resized_image, img);
free(resized_image.data);
resized_image.data = nullptr;
} else {
img = ggml_new_tensor_4d(work_ctx,
GGML_TYPE_F32,
ref_images[i]->width,
ref_images[i]->height,
3,
1);
sd_image_to_tensor(*ref_images[i], img);
}
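// Worked example of the qwen-image branch (sizes hypothetical): a 1536x1024
// reference with a 1024x1024 target gives VAE_IMAGE_SIZE = 1048576,
// vae_width = sqrt(1048576 * 1536 / 1024) ~= 1254 and vae_height ~= 836;
// after rounding to multiples of 32 the ref is encoded at 1248x832
// (~1 MP, aspect ratio preserved).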
// print_ggml_tensor(img, false, "img");
ggml_tensor* latent = sd_ctx->sd->encode_first_stage(work_ctx, img);
ref_latents.push_back(latent);
@@ -2578,6 +2635,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
sd_img_gen_params->control_image,
sd_img_gen_params->control_strength,
sd_img_gen_params->pm_params,
ref_images,
ref_latents,
sd_img_gen_params->increase_ref_index,
concat_latent,
@@ -2835,30 +2893,25 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
}
// Get learned condition
ConditionerParams condition_params;
condition_params.clip_skip = sd_vid_gen_params->clip_skip;
condition_params.zero_out_masked = true;
condition_params.text = prompt;
int64_t t1 = ggml_time_ms();
SDCondition cond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
sd_ctx->sd->n_threads,
condition_params);
cond.c_concat = concat_latent;
cond.c_vector = clip_vision_output;
SDCondition uncond;
if (sd_vid_gen_params->sample_params.guidance.txt_cfg != 1.0 || sd_vid_gen_params->high_noise_sample_params.guidance.txt_cfg != 1.0) {
condition_params.text = negative_prompt;
uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
sd_ctx->sd->n_threads,
condition_params);
uncond.c_concat = concat_latent;
uncond.c_vector = clip_vision_output;
}
int64_t t2 = ggml_time_ms();
LOG_INFO("get_learned_condition completed, taking %" PRId64 " ms", t2 - t1);
stable-diffusion.h
@@ -132,6 +132,7 @@ typedef struct {
const char* clip_vision_path;
const char* t5xxl_path;
const char* qwen2vl_path;
const char* qwen2vl_vision_path;
const char* diffusion_model_path;
const char* high_noise_diffusion_model_path;
const char* vae_path;
util.cpp
@@ -84,6 +84,7 @@ int round_up_to(int value, int base) {
}
#ifdef _WIN32 // code for windows
#define NOMINMAX
#include <windows.h>
bool file_exists(const std::string& filename) {
@@ -298,7 +299,7 @@ std::string trim(const std::string& s) {
static sd_log_cb_t sd_log_cb = NULL;
void* sd_log_cb_data = NULL;
#define LOG_BUFFER_SIZE 4096
void log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...) {
va_list args;
@@ -387,10 +388,10 @@ sd_image_f32_t resize_sd_image_f32_t(sd_image_f32_t image, int target_width, int
float original_x = (float)x * image.width / target_width;
float original_y = (float)y * image.height / target_height;
uint32_t x1 = (uint32_t)original_x;
uint32_t y1 = (uint32_t)original_y;
uint32_t x2 = std::min(x1 + 1, image.width - 1);
uint32_t y2 = std::min(y1 + 1, image.height - 1);
for (int k = 0; k < image.channel; k++) {
float v1 = *(image.data + y1 * image.width * image.channel + x1 * image.channel + k);
@@ -427,23 +428,26 @@ float means[3] = {0.48145466, 0.4578275, 0.40821073};
float stds[3] = {0.26862954, 0.26130258, 0.27577711};
// Function to clip and preprocess sd_image_f32_t
sd_image_f32_t clip_preprocess(sd_image_f32_t image, int target_width, int target_height) {
float width_scale = (float)target_width / image.width;
float height_scale = (float)target_height / image.height;
float scale = std::fmax(width_scale, height_scale);
// Interpolation
int resized_width = (int)(scale * image.width);
int resized_height = (int)(scale * image.height);
float* resized_data = (float*)malloc(resized_width * resized_height * image.channel * sizeof(float));
for (int y = 0; y < resized_height; y++) {
for (int x = 0; x < resized_width; x++) {
float original_x = (float)x * image.width / resized_width;
float original_y = (float)y * image.height / resized_height;
uint32_t x1 = (uint32_t)original_x;
uint32_t y1 = (uint32_t)original_y;
uint32_t x2 = std::min(x1 + 1, image.width - 1);
uint32_t y2 = std::min(y1 + 1, image.height - 1);
for (int k = 0; k < image.channel; k++) {
float v1 = *(image.data + y1 * image.width * image.channel + x1 * image.channel + k);
@@ -456,26 +460,28 @@ sd_image_f32_t clip_preprocess(sd_image_f32_t image, int size) {
float value = interpolate(v1, v2, v3, v4, x_ratio, y_ratio);
*(resized_data + y * resized_width * image.channel + x * image.channel + k) = value;
}
}
}
// Clip and preprocess
int h_offset = std::max((int)(resized_height - target_height) / 2, 0);
int w_offset = std::max((int)(resized_width - target_width) / 2, 0);
sd_image_f32_t result;
result.width = target_width;
result.height = target_height;
result.channel = image.channel;
result.data = (float*)malloc(target_height * target_width * image.channel * sizeof(float));
for (int k = 0; k < image.channel; k++) {
for (int i = 0; i < result.height; i++) {
for (int j = 0; j < result.width; j++) {
int src_y = std::min(i + h_offset, resized_height - 1);
int src_x = std::min(j + w_offset, resized_width - 1);
*(result.data + i * result.width * image.channel + j * image.channel + k) =
fmin(fmax(*(resized_data + src_y * resized_width * image.channel + src_x * image.channel + k), 0.0f), 255.0f) / 255.0f;
}
}
}
@@ -485,10 +491,10 @@ sd_image_f32_t clip_preprocess(sd_image_f32_t image, int size) {
// Normalize
for (int k = 0; k < image.channel; k++) {
for (int i = 0; i < result.height; i++) {
for (int j = 0; j < result.width; j++) {
// *(result.data + i * size * image.channel + j * image.channel + k) = 0.5f;
int offset = i * result.width * image.channel + j * image.channel + k;
float value = *(result.data + offset);
value = (value - means[k]) / stds[k];
// value = 0.5f;
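// Quick sanity check of the normalization above (means/stds are the OpenAI
// CLIP constants): a mid-gray pixel of 127.5 in channel 0 becomes
// (127.5/255 - 0.48145466) / 0.26862954 ~= 0.069.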
util.h (2 lines changed)
@@ -42,7 +42,7 @@ sd_image_f32_t sd_image_t_to_sd_image_f32_t(sd_image_t image);
sd_image_f32_t resize_sd_image_f32_t(sd_image_f32_t image, int target_width, int target_height);
sd_image_f32_t clip_preprocess(sd_image_f32_t image, int target_width, int target_height);
std::string path_join(const std::string& p1, const std::string& p2);
std::vector<std::string> split_string(const std::string& str, char delimiter);
wan.hpp
@@ -1333,7 +1333,7 @@ namespace WAN {
k = ggml_reshape_4d(ctx, k, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head]
v = ggml_reshape_4d(ctx, v, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head]
x = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_token, dim]
x = o_proj->forward(ctx, x); // [N, n_token, dim]
return x;