add qwen2vl vit support

leejet 2025-09-29 23:05:30 +08:00
parent 95cae28465
commit 58e81adf61
6 changed files with 647 additions and 142 deletions

View File

@@ -27,6 +27,8 @@
#include "avi_writer.h"
#include "qwenvl.hpp"
#if defined(_WIN32)
#define NOMINMAX
#include <windows.h>
@@ -1142,6 +1144,10 @@ bool load_images_from_dir(const std::string dir,
int main(int argc, const char* argv[]) {
SDParams params;
params.verbose = true;
sd_set_log_callback(sd_log_cb, (void*)&params);
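// temporary debug hook: load the Qwen2.5-VL model from argv[1], run the embedder test, and exit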
Qwen::Qwen2_5_VLEmbedder::load_from_file_and_test(argv[1]);
return 1;
parse_args(argc, argv, params);
params.sample_params.guidance.slg.layers = params.skip_layers.data();
params.sample_params.guidance.slg.layer_count = params.skip_layers.size();

View File

@@ -81,57 +81,6 @@ namespace Flux {
}
};
__STATIC_INLINE__ struct ggml_tensor* apply_rope(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* pe) {
// x: [N, L, n_head, d_head]
// pe: [L, d_head/2, 2, 2]
int64_t d_head = x->ne[0];
int64_t n_head = x->ne[1];
int64_t L = x->ne[2];
int64_t N = x->ne[3];
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, n_head, L, d_head]
x = ggml_reshape_4d(ctx, x, 2, d_head / 2, L, n_head * N); // [N * n_head, L, d_head/2, 2]
x = ggml_cont(ctx, ggml_permute(ctx, x, 3, 0, 1, 2)); // [2, N * n_head, L, d_head/2]
int64_t offset = x->nb[2] * x->ne[2];
auto x_0 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 0); // [N * n_head, L, d_head/2]
auto x_1 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 1); // [N * n_head, L, d_head/2]
x_0 = ggml_reshape_4d(ctx, x_0, 1, x_0->ne[0], x_0->ne[1], x_0->ne[2]); // [N * n_head, L, d_head/2, 1]
x_1 = ggml_reshape_4d(ctx, x_1, 1, x_1->ne[0], x_1->ne[1], x_1->ne[2]); // [N * n_head, L, d_head/2, 1]
auto temp_x = ggml_new_tensor_4d(ctx, x_0->type, 2, x_0->ne[1], x_0->ne[2], x_0->ne[3]);
x_0 = ggml_repeat(ctx, x_0, temp_x); // [N * n_head, L, d_head/2, 2]
x_1 = ggml_repeat(ctx, x_1, temp_x); // [N * n_head, L, d_head/2, 2]
pe = ggml_cont(ctx, ggml_permute(ctx, pe, 3, 0, 1, 2)); // [2, L, d_head/2, 2]
offset = pe->nb[2] * pe->ne[2];
auto pe_0 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 0); // [L, d_head/2, 2]
auto pe_1 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 1); // [L, d_head/2, 2]
auto x_out = ggml_add_inplace(ctx, ggml_mul(ctx, x_0, pe_0), ggml_mul(ctx, x_1, pe_1)); // [N * n_head, L, d_head/2, 2]
x_out = ggml_reshape_3d(ctx, x_out, d_head, L, n_head * N); // [N*n_head, L, d_head]
return x_out;
}
__STATIC_INLINE__ struct ggml_tensor* attention(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* q,
struct ggml_tensor* k,
struct ggml_tensor* v,
struct ggml_tensor* pe,
struct ggml_tensor* mask,
bool flash_attn,
float kv_scale = 1.0f) {
// q,k,v: [N, L, n_head, d_head]
// pe: [L, d_head/2, 2, 2]
// return: [N, L, n_head*d_head]
q = apply_rope(ctx, q, pe); // [N*n_head, L, d_head]
k = apply_rope(ctx, k, pe); // [N*n_head, L, d_head]
auto x = ggml_nn_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn, kv_scale); // [N, L, n_head*d_head]
return x;
}
struct SelfAttention : public GGMLBlock {
public:
int64_t num_heads;
@@ -179,9 +128,9 @@ namespace Flux {
// x: [N, n_token, dim]
// pe: [n_token, d_head/2, 2, 2]
// return [N, n_token, dim]
auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head]
x = Rope::attention(ctx, backend, qkv[0], qkv[1], qkv[2], pe, mask, flash_attn); // [N, n_token, dim]
x = post_attention(ctx, x); // [N, n_token, dim]
return x;
}
};
@@ -369,8 +318,8 @@ namespace Flux {
auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head]
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
auto txt_attn_out = ggml_view_3d(ctx,
attn,
attn->ne[0],
@@ -504,7 +453,7 @@ namespace Flux {
auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head]
q = norm->query_norm(ctx, q);
k = norm->key_norm(ctx, k);
auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_token, hidden_size]
auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim]
auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size]

View File

@@ -156,7 +156,7 @@ namespace Qwen {
auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn, (1.0f / 128.f)); // [N, n_txt_token + n_img_token, n_head*d_head]
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
auto txt_attn_out = ggml_view_3d(ctx,
attn,

View File

@@ -15,9 +15,11 @@
#include "clip.hpp"
#include "ggml_extend.hpp"
#include "json.hpp"
#include "rope.hpp"
#include "tokenize_util.h" #include "tokenize_util.h"
namespace Qwen { namespace Qwen {
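// graph capacity for the ViT forward pass, which needs far more nodes than ggml's default graph size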
constexpr int QWENVL_GRAPH_SIZE = 10240;
class Qwen2Tokenizer {
private:
@@ -340,9 +342,9 @@ namespace Qwen {
struct Qwen2_5_VLMLP : public GGMLBlock {
public:
Qwen2_5_VLMLP(int64_t hidden_size, int64_t intermediate_size, bool bias = false) {
blocks["gate_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, intermediate_size, bias));
blocks["up_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, intermediate_size, bias));
blocks["down_proj"] = std::shared_ptr<GGMLBlock>(new Linear(intermediate_size, hidden_size, bias));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
@@ -359,6 +361,218 @@ namespace Qwen {
}
};
struct Qwen2_5_VisionPatchEmbed : public GGMLBlock {
protected:
int64_t patch_size;
int64_t temporal_patch_size;
int64_t embed_dim;
public:
Qwen2_5_VisionPatchEmbed(int64_t patch_size = 14,
int64_t temporal_patch_size = 2,
int64_t in_channels = 3,
int64_t embed_dim = 1152)
: patch_size(patch_size), temporal_patch_size(temporal_patch_size), embed_dim(embed_dim) {
std::tuple<int, int, int> kernel_size = {(int)temporal_patch_size, (int)patch_size, (int)patch_size};
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Conv3d(in_channels,
embed_dim,
kernel_size,
kernel_size, // stride
{0, 0, 0}, // padding
{1, 1, 1}, // dilation
false));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
// x: [N*grid_t*grid_h*grid_w, in_channels, temporal_patch_size*patch_size*patch_size]
// return: [N*grid_t*grid_h*grid_w, embed_dim]
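// note: proj's stride equals its kernel size (set in the constructor), so the Conv3d acts as a non-overlapping patch projection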
auto proj = std::dynamic_pointer_cast<Conv3d>(blocks["proj"]);
x = ggml_reshape_4d(ctx,
x,
patch_size,
patch_size,
temporal_patch_size,
ggml_nelements(x) / (temporal_patch_size * patch_size * patch_size));
x = proj->forward(ctx, x);
x = ggml_reshape_2d(ctx, x, embed_dim, ggml_nelements(x) / embed_dim);
return x;
}
};
struct Qwen2_5_VLPatchMerger : public GGMLBlock {
protected:
int64_t hidden_size;
public:
Qwen2_5_VLPatchMerger(int64_t dim,
int64_t context_dim,
int64_t spatial_merge_size) {
hidden_size = context_dim * spatial_merge_size * spatial_merge_size;
blocks["ln_q"] = std::shared_ptr<GGMLBlock>(new RMSNorm(context_dim, 1e-6f));
blocks["mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size));
// mlp.1 is nn.GELU()
blocks["mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, dim));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
auto ln_q = std::dynamic_pointer_cast<RMSNorm>(blocks["ln_q"]);
auto mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["mlp.0"]);
auto mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["mlp.2"]);
x = ln_q->forward(ctx, x);
x = ggml_reshape_2d(ctx, x, hidden_size, ggml_nelements(x) / hidden_size);
x = mlp_0->forward(ctx, x);
x = ggml_gelu(ctx, x);
x = mlp_2->forward(ctx, x);
return x;
}
};
struct Qwen2_5_VLVisionAttention : public GGMLBlock {
protected:
int64_t head_dim;
int64_t num_heads;
public:
Qwen2_5_VLVisionAttention(int64_t hidden_size,
int64_t num_heads)
: num_heads(num_heads) {
head_dim = hidden_size / num_heads;
GGML_ASSERT(num_heads * head_dim == hidden_size);
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size * 3));
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* pe,
struct ggml_tensor* mask = nullptr) {
// x: [N, n_token, hidden_size]
int64_t n_token = x->ne[1];
int64_t N = x->ne[2];
auto qkv_proj = std::dynamic_pointer_cast<Linear>(blocks["qkv"]);
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
auto qkv = qkv_proj->forward(ctx, x);
auto qkv_vec = split_qkv(ctx, qkv);
auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head]
auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head]
auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head]
x = Rope::attention(ctx, backend, q, k, v, pe, mask, false, 1.f, false); // [N, n_token, hidden_size]
x = proj->forward(ctx, x); // [N, n_token, hidden_size]
return x;
}
};
struct Qwen2_5_VLVisionBlock : public GGMLBlock {
public:
Qwen2_5_VLVisionBlock(int64_t hidden_size,
int64_t intermediate_size,
int64_t num_heads,
float eps = 1e-6f) {
blocks["attn"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLVisionAttention(hidden_size, num_heads));
blocks["mlp"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLMLP(hidden_size, intermediate_size, true));
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new RMSNorm(hidden_size, eps));
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new RMSNorm(hidden_size, eps));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* pe,
struct ggml_tensor* mask = nullptr) {
// x: [N, n_token, hidden_size]
auto attn = std::dynamic_pointer_cast<Qwen2_5_VLVisionAttention>(blocks["attn"]);
auto mlp = std::dynamic_pointer_cast<Qwen2_5_VLMLP>(blocks["mlp"]);
auto norm1 = std::dynamic_pointer_cast<RMSNorm>(blocks["norm1"]);
auto norm2 = std::dynamic_pointer_cast<RMSNorm>(blocks["norm2"]);
auto residual = x;
x = norm1->forward(ctx, x);
x = attn->forward(ctx, backend, x, pe, mask);
x = ggml_add_inplace(ctx, x, residual);
residual = x;
x = norm2->forward(ctx, x);
x = mlp->forward(ctx, x);
x = ggml_add_inplace(ctx, x, residual);
return x;
}
};
struct Qwen2_5_VLVisionModel : public GGMLBlock {
protected:
int64_t num_layers;
int64_t spatial_merge_size;
std::set<int> fullatt_block_indexes;
public:
Qwen2_5_VLVisionModel(int64_t num_layers,
int64_t in_channels,
int64_t hidden_size,
int64_t out_hidden_size,
int64_t intermediate_size,
int64_t num_heads,
int64_t spatial_merge_size,
int64_t patch_size,
int64_t temporal_patch_size,
int64_t window_size,
std::set<int> fullatt_block_indexes = {7, 15, 23, 31},
float eps = 1e-6f)
: num_layers(num_layers), fullatt_block_indexes(fullatt_block_indexes), spatial_merge_size(spatial_merge_size) {
blocks["patch_embed"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VisionPatchEmbed(patch_size, temporal_patch_size, in_channels, hidden_size));
for (int i = 0; i < num_layers; i++) {
blocks["blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLVisionBlock(hidden_size,
intermediate_size,
num_heads,
eps));
}
blocks["merger"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLPatchMerger(out_hidden_size, hidden_size, spatial_merge_size));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* pixel_values,
struct ggml_tensor* pe,
struct ggml_tensor* window_index,
struct ggml_tensor* window_inverse_index,
struct ggml_tensor* window_mask) {
// pixel_values: [grid_t*(H/mh/ph)*(W/mw/pw)*mh*mw, C*pt*ph*pw]
// window_index: [grid_t*(H/mh/ph)*(W/mw/pw)]
// window_inverse_index: [grid_t*(H/mh/ph)*(W/mw/pw)]
// window_mask: [grid_h*grid_w, grid_h*grid_w]
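// (mh/mw = spatial_merge_size, pt = temporal_patch_size, ph/pw = patch_size, matching process_image below)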
auto patch_embed = std::dynamic_pointer_cast<Qwen2_5_VisionPatchEmbed>(blocks["patch_embed"]);
auto merger = std::dynamic_pointer_cast<Qwen2_5_VLPatchMerger>(blocks["merger"]);
auto x = patch_embed->forward(ctx, pixel_values);
x = ggml_reshape_4d(ctx, x, x->ne[0] * spatial_merge_size * spatial_merge_size, x->ne[1] / spatial_merge_size / spatial_merge_size, x->ne[2], x->ne[3]);
x = ggml_get_rows(ctx, x, window_index);
x = ggml_reshape_4d(ctx, x, x->ne[0] / spatial_merge_size / spatial_merge_size, x->ne[1] * spatial_merge_size * spatial_merge_size, x->ne[2], x->ne[3]);
for (int i = 0; i < num_layers; i++) {
auto block = std::dynamic_pointer_cast<Qwen2_5_VLVisionBlock>(blocks["blocks." + std::to_string(i)]);
auto mask = window_mask;
if (fullatt_block_indexes.find(i) != fullatt_block_indexes.end()) {
mask = nullptr;
}
x = block->forward(ctx, backend, x, pe, mask);
}
x = merger->forward(ctx, x);
x = ggml_get_rows(ctx, x, window_inverse_index);
return x;
}
};
struct Qwen2_5_VLAttention : public GGMLBlock {
protected:
int64_t head_dim;
@@ -498,6 +712,20 @@ namespace Qwen {
}
};
struct Qwen2_5_VLVisionParams {
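// the defaults below appear to match the Qwen2.5-VL-7B vision tower config; other model sizes would presumably need different values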
int64_t num_layers = 32;
int64_t hidden_size = 1280;
int64_t intermediate_size = 3420;
int64_t num_heads = 16;
int64_t in_channels = 3;
int64_t out_hidden_size = 3584;
int64_t temporal_patch_size = 2;
int64_t patch_size = 14;
int64_t spatial_merge_size = 2;
int64_t window_size = 112;
std::set<int> fullatt_block_indexes = {7, 15, 23, 31};
};
struct Qwen2_5_VLParams {
int64_t num_layers = 28;
int64_t hidden_size = 3584;
@@ -506,15 +734,17 @@ namespace Qwen {
int64_t num_kv_heads = 4;
int64_t vocab_size = 152064;
float rms_norm_eps = 1e-06f;
Qwen2_5_VLVisionParams vision;
};
struct Qwen2_5_VL : public GGMLBlock {
bool enable_vision;
Qwen2_5_VLParams params;
public:
Qwen2_5_VL() {}
Qwen2_5_VL(Qwen2_5_VLParams params, bool enable_vision = false)
: enable_vision(enable_vision), params(params) {
blocks["model"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLTextModel(params.num_layers, blocks["model"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLTextModel(params.num_layers,
params.vocab_size, params.vocab_size,
params.hidden_size, params.hidden_size,
@ -522,6 +752,19 @@ namespace Qwen {
params.num_heads, params.num_heads,
params.num_kv_heads, params.num_kv_heads,
params.rms_norm_eps)); params.rms_norm_eps));
if (enable_vision) {
blocks["visual"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLVisionModel(params.vision.num_layers,
params.vision.in_channels,
params.vision.hidden_size,
params.vision.out_hidden_size,
params.vision.intermediate_size,
params.vision.num_heads,
params.vision.spatial_merge_size,
params.vision.patch_size,
params.vision.temporal_patch_size,
params.vision.window_size,
params.vision.fullatt_block_indexes));
}
}
struct ggml_tensor* forward(struct ggml_context* ctx,
@@ -534,6 +777,18 @@ namespace Qwen {
auto x = model->forward(ctx, backend, input_ids, input_pos);
return x;
}
struct ggml_tensor* vision_forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* pixel_values,
struct ggml_tensor* pe,
struct ggml_tensor* window_index,
struct ggml_tensor* window_inverse_index,
struct ggml_tensor* window_mask) {
GGML_ASSERT(enable_vision);
auto vision_model = std::dynamic_pointer_cast<Qwen2_5_VLVisionModel>(blocks["visual"]);
return vision_model->forward(ctx, backend, pixel_values, pe, window_index, window_inverse_index, window_mask);
}
};
struct Qwen2_5_VLRunner : public GGMLRunner {
@@ -541,13 +796,17 @@ namespace Qwen {
Qwen2_5_VL model;
std::vector<int> input_pos_vec;
std::vector<float> window_mask_vec;
std::vector<int> window_index_vec;
std::vector<int> window_inverse_index_vec;
std::vector<float> pe_vec;
Qwen2_5_VLRunner(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types,
const std::string prefix,
bool enable_vision = false)
: GGMLRunner(backend, offload_params_to_cpu), model(params, enable_vision) {
model.init(params_ctx, tensor_types, prefix);
}
@@ -567,6 +826,17 @@ namespace Qwen {
return hidden_states;
}
struct ggml_tensor* vision_forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* pixel_values,
struct ggml_tensor* input_pos,
struct ggml_tensor* window_index,
struct ggml_tensor* window_inverse_index,
struct ggml_tensor* window_mask) {
auto hidden_states = model.vision_forward(ctx, backend, pixel_values, input_pos, window_index, window_inverse_index, window_mask);
return hidden_states;
}
struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids) {
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
@@ -602,6 +872,166 @@ namespace Qwen {
};
GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
}
struct ggml_tensor* process_image(struct ggml_context* ctx, struct ggml_tensor* image) {
// image: [C, H, W]
// return: [grid_t*(H/mh/ph)*(W/mw/pw)*mh*mw, C*pt*ph*pw], grid_t == 1
int64_t C = image->ne[2];
int64_t H = image->ne[1];
int64_t W = image->ne[0];
int64_t mh = params.vision.spatial_merge_size;
int64_t mw = params.vision.spatial_merge_size;
int64_t pt = params.vision.temporal_patch_size;
int64_t ph = params.vision.patch_size;
int64_t pw = params.vision.patch_size;
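// the permute/reshape chain below regroups pixels so that the mh*mw patches of each merge block end up adjacent,
// producing the [grid_t*(H/mh/ph)*(W/mw/pw)*mh*mw, C*pt*ph*pw] layout the patch embed consumes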
image = ggml_reshape_4d(ctx, image, pw, mw, (W / mw / pw), H * C); // [C*H, (W/mw/pw), mw, pw]
image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 3, 1)); // [mw, C*H, (W/mw/pw), pw]
image = ggml_reshape_4d(ctx, image, pw * (W / mw / pw), H, C, mw); // [mw, C, H, (W/mw/pw)*pw]
image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 3, 1)); // [H, mw, C, (W/mw/pw)*pw]
image = ggml_reshape_4d(ctx, image, pw, (W / mw / pw) * C * mw, ph, mh * (H / mh / ph)); // [(H/mh/ph)*mh, ph, mw*C*(W/mw/pw), pw]
image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 1, 3)); // [(H/mh/ph)*mh, mw*C*(W/mw/pw), ph, pw]
image = ggml_reshape_4d(ctx, image, pw * ph, (W / mw / pw), C, mw * mh * (H / mh / ph)); // [(H/mh/ph)*mh*mw, C, (W/mw/pw), ph*pw]
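// a still image has grid_t == 1, so the frame is duplicated to fill temporal_patch_size (pt == 2) identical temporal patches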
image = ggml_concat(ctx, image, image, 0); // [(H/mh/ph)*mh*mw, C, (W/mw/pw), pt*ph*pw]
image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 1, 3)); // [(H/mh/ph)*mh*mw, (W/mw/pw), C, pt*ph*pw]
image = ggml_reshape_4d(ctx, image, pw * ph * pt * C, (W / mw / pw), mw * mh, (H / mh / ph)); // [(H/mh/ph), mh*mw, (W/mw/pw), C*pt*ph*pw]
image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 1, 3)); // [(H/mh/ph), (W/mw/pw), mh*mw, C*pt*ph*pw]
image = ggml_reshape_2d(ctx, image, pw * ph * pt * C, mw * mh * (W / mw / pw) * (H / mh / ph)); // [(H/mh/ph)*(W/mw/pw)*mh*mw, C*pt*ph*pw]
return image;
}
struct ggml_cgraph* build_encode_image_graph(struct ggml_tensor* image) {
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, QWENVL_GRAPH_SIZE, false);
GGML_ASSERT(image->ne[1] % (params.vision.patch_size * params.vision.spatial_merge_size) == 0);
GGML_ASSERT(image->ne[0] % (params.vision.patch_size * params.vision.spatial_merge_size) == 0);
int grid_t = 1;
int grid_h = image->ne[1] / params.vision.patch_size;
int grid_w = image->ne[0] / params.vision.patch_size;
int llm_grid_h = grid_h / params.vision.spatial_merge_size;
int llm_grid_w = grid_w / params.vision.spatial_merge_size;
int vit_merger_window_size = params.vision.window_size / params.vision.patch_size / params.vision.spatial_merge_size;
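// with the default vision params: 112 / 14 / 2 = 4 merged tokens per window side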
image = to_backend(image);
auto pixel_values = process_image(compute_ctx, image);
// window index
int inverse_index = 0;
window_index_vec.resize(llm_grid_h * llm_grid_w);
window_inverse_index_vec.resize(llm_grid_h * llm_grid_w);
std::vector<int> seqlens;
for (int ih = 0; ih < llm_grid_h; ih += vit_merger_window_size) {
for (int iw = 0; iw < llm_grid_w; iw += vit_merger_window_size) {
int win_h = std::min(vit_merger_window_size, llm_grid_h - ih);
int win_w = std::min(vit_merger_window_size, llm_grid_w - iw);
for (int iy = 0; iy < win_h; iy++) {
for (int ix = 0; ix < win_w; ix++) {
int index = (ih + iy) * llm_grid_w + iw + ix;
window_index_vec[inverse_index] = index;
window_inverse_index_vec[index] = inverse_index;
inverse_index++;
}
}
seqlens.push_back(win_h * win_w * params.vision.spatial_merge_size * params.vision.spatial_merge_size);
}
}
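// e.g. an 8x8 llm grid with a window side of 4 yields four 4x4 windows; each window contributes
// win_h * win_w * spatial_merge_size^2 patch tokens (here 16 * 4 = 64) to seqlens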
// printf("window_index: ");
// for (int i : window_index_vec) {
// printf("%d ", i);
// }
// printf("\n");
// printf("window_inverse_index: ");
// for (int i : window_inverse_index_vec) {
// printf("%d ", i);
// }
// printf("\n");
// printf("seqlens: ");
// for (int i : seqlens) {
// printf("%d ", i);
// }
// printf("\n");
auto window_index = ggml_new_tensor_1d(compute_ctx,
GGML_TYPE_I32,
llm_grid_h * llm_grid_w);
auto window_inverse_index = ggml_new_tensor_1d(compute_ctx,
GGML_TYPE_I32,
llm_grid_h * llm_grid_w);
set_backend_tensor_data(window_index, window_index_vec.data());
set_backend_tensor_data(window_inverse_index, window_inverse_index_vec.data());
// window mask
int seq_window_size = (vit_merger_window_size * params.vision.spatial_merge_size) * (vit_merger_window_size * params.vision.spatial_merge_size);
window_mask_vec.resize((grid_h * grid_w) * (grid_h * grid_w));
int window_start_index = 0;
for (int seq_index = 0; seq_index < seqlens.size(); seq_index++) {
int window_end_index = window_start_index + seqlens[seq_index];
// LOG_DEBUG("%d %d", window_start_index, window_end_index);
GGML_ASSERT(window_end_index <= grid_h * grid_w);
for (int i = window_start_index; i < window_end_index; i++) {
for (int j = 0; j < grid_h * grid_w; j++) {
float mask_value = -INFINITY;
if (j >= window_start_index && j < window_end_index) {
mask_value = 0;
}
GGML_ASSERT((i * (grid_h * grid_w) + j) < window_mask_vec.size());
window_mask_vec[i * (grid_h * grid_w) + j] = mask_value;
}
}
window_start_index = window_end_index;
// printf("\n");
}
// printf("window_mask: \n");
// for (int i = 0; i < grid_h*grid_w; i++) {
// for (int j = 0; j < grid_h*grid_w; j++) {
// printf("%f ", window_mask_vec[i * (grid_h * grid_w) + j]);
// }
// printf("\n");
// }
auto window_mask = ggml_new_tensor_2d(compute_ctx,
GGML_TYPE_F32,
grid_h * grid_w,
grid_h * grid_w);
set_backend_tensor_data(window_mask, window_mask_vec.data());
// pe
int head_dim = params.vision.hidden_size / params.vision.num_heads;
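// 2D rope over (h, w): head_dim is split evenly between the height and width axes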
pe_vec = Rope::gen_qwen2vl_pe(grid_h,
grid_w,
params.vision.spatial_merge_size,
window_inverse_index_vec,
10000.f,
{head_dim / 2, head_dim / 2});
int pos_len = pe_vec.size() / head_dim / 2;
// LOG_DEBUG("pos_len %d", pos_len);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, head_dim / 2, pos_len);
// pe->data = pe_vec.data();
// print_ggml_tensor(pe);
// pe->data = NULL;
set_backend_tensor_data(pe, pe_vec.data());
struct ggml_tensor* hidden_states = vision_forward(compute_ctx,
runtime_backend,
pixel_values,
pe,
window_index,
window_inverse_index,
window_mask);
ggml_build_forward_expand(gf, hidden_states);
return gf;
}
void encode_image(const int n_threads,
struct ggml_tensor* image,
ggml_tensor** output,
ggml_context* output_ctx = NULL) {
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_encode_image_graph(image);
};
GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
}
};
struct Qwen2_5_VLEmbedder {
@@ -611,8 +1041,9 @@ namespace Qwen {
Qwen2_5_VLEmbedder(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "",
bool enable_vision = false)
: model(backend, offload_params_to_cpu, tensor_types, prefix, enable_vision) {
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
@@ -666,8 +1097,26 @@ namespace Qwen {
struct ggml_context* work_ctx = ggml_init(params);
GGML_ASSERT(work_ctx != NULL);
bool test_vit = true;
if (test_vit) {
// auto image = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 280, 280, 3);
// ggml_set_f32(image, 0.f);
auto image = load_tensor_from_file(work_ctx, "qwen2vl_normalized.bin");
print_ggml_tensor(image, false, "image");
struct ggml_tensor* out = NULL;
int t0 = ggml_time_ms();
model.encode_image(8, image, &out, work_ctx);
int t1 = ggml_time_ms();
print_ggml_tensor(out, false, "out");
// auto ref_out = load_tensor_from_file(work_ctx, "qwen2vl.bin");
// ggml_tensor_diff(ref_out, out, 0.01f);
LOG_DEBUG("qwen2vl test done in %dms", t1 - t0);
} else {
std::string text("<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\na lovely cat<|im_end|>\n<|im_start|>assistant\n");
auto tokens_and_weights = tokenize(text, 0, false);
std::vector<int>& tokens = std::get<0>(tokens_and_weights);
@@ -692,7 +1141,7 @@ namespace Qwen {
// cpu f16: pass
// ggml_backend_t backend = ggml_backend_cuda_init(0);
ggml_backend_t backend = ggml_backend_cpu_init();
ggml_type model_data_type = GGML_TYPE_F16;
ModelLoader model_loader;
if (!model_loader.init_from_file(file_path, "qwen2vl.")) {
@@ -708,7 +1157,11 @@ namespace Qwen {
}
}
std::shared_ptr<Qwen2_5_VLEmbedder> qwenvl = std::shared_ptr<Qwen2_5_VLEmbedder>(new Qwen2_5_VLEmbedder(backend,
false,
tensor_types,
"qwen2vl",
true));
qwenvl->alloc_params_buffer();
std::map<std::string, ggml_tensor*> tensors;

rope.hpp (237 changed lines)
View File

@@ -4,9 +4,9 @@
#include <vector>
#include "ggml_extend.hpp"
namespace Rope {
template <class T>
__STATIC_INLINE__ std::vector<T> linspace(T start, T end, int num) {
std::vector<T> result(num);
if (num == 1) {
result[0] = start;
@@ -19,7 +19,7 @@ struct Rope {
return result;
}
__STATIC_INLINE__ std::vector<std::vector<float>> transpose(const std::vector<std::vector<float>>& mat) {
int rows = mat.size();
int cols = mat[0].size();
std::vector<std::vector<float>> transposed(cols, std::vector<float>(rows));
@@ -31,7 +31,7 @@ struct Rope {
return transposed;
}
__STATIC_INLINE__ std::vector<float> flatten(const std::vector<std::vector<float>>& vec) {
std::vector<float> flat_vec;
for (const auto& sub_vec : vec) {
flat_vec.insert(flat_vec.end(), sub_vec.begin(), sub_vec.end());
@@ -39,7 +39,7 @@ struct Rope {
return flat_vec;
}
__STATIC_INLINE__ std::vector<std::vector<float>> rope(const std::vector<float>& pos, int dim, int theta) {
assert(dim % 2 == 0);
int half_dim = dim / 2;
@@ -72,11 +72,11 @@ struct Rope {
}
// Generate IDs for image patches and text
__STATIC_INLINE__ std::vector<std::vector<float>> gen_txt_ids(int bs, int context_len) {
return std::vector<std::vector<float>>(bs * context_len, std::vector<float>(3, 0.0));
}
__STATIC_INLINE__ std::vector<std::vector<float>> gen_img_ids(int h, int w, int patch_size, int bs, int index = 0, int h_offset = 0, int w_offset = 0) {
int h_len = (h + (patch_size / 2)) / patch_size;
int w_len = (w + (patch_size / 2)) / patch_size;
@@ -102,9 +102,9 @@ struct Rope {
return img_ids_repeated;
}
__STATIC_INLINE__ std::vector<std::vector<float>> concat_ids(const std::vector<std::vector<float>>& a,
const std::vector<std::vector<float>>& b,
int bs) {
size_t a_len = a.size() / bs;
size_t b_len = b.size() / bs;
std::vector<std::vector<float>> ids(a.size() + b.size(), std::vector<float>(3));
@@ -119,10 +119,10 @@ struct Rope {
return ids;
}
__STATIC_INLINE__ std::vector<float> embed_nd(const std::vector<std::vector<float>>& ids,
int bs,
int theta,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> trans_ids = transpose(ids);
size_t pos_len = ids.size() / bs;
int num_axes = axes_dim.size();
@@ -151,10 +151,10 @@ struct Rope {
return flatten(emb);
}
__STATIC_INLINE__ std::vector<std::vector<float>> gen_refs_ids(int patch_size,
int bs,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index) {
std::vector<std::vector<float>> ids;
uint64_t curr_h_offset = 0;
uint64_t curr_w_offset = 0;
@@ -183,13 +183,13 @@ struct Rope {
return ids;
}
__STATIC_INLINE__ std::vector<std::vector<float>> gen_flux_ids(int h,
int w,
int patch_size,
int bs,
int context_len,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index) {
auto txt_ids = gen_txt_ids(bs, context_len);
auto img_ids = gen_img_ids(h, w, patch_size, bs);
@@ -202,26 +202,26 @@ struct Rope {
}
// Generate flux positional embeddings
__STATIC_INLINE__ std::vector<float> gen_flux_pe(int h,
int w,
int patch_size,
int bs,
int context_len,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index,
int theta,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_flux_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
return embed_nd(ids, bs, theta, axes_dim);
}
__STATIC_INLINE__ std::vector<std::vector<float>> gen_qwen_image_ids(int h,
int w,
int patch_size,
int bs,
int context_len,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index) {
int h_len = (h + (patch_size / 2)) / patch_size;
int w_len = (w + (patch_size / 2)) / patch_size;
int txt_id_start = std::max(h_len, w_len);
@@ -242,29 +242,29 @@ struct Rope {
}
// Generate qwen_image positional embeddings
__STATIC_INLINE__ std::vector<float> gen_qwen_image_pe(int h,
int w,
int patch_size,
int bs,
int context_len,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index,
int theta,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
return embed_nd(ids, bs, theta, axes_dim);
}
__STATIC_INLINE__ std::vector<std::vector<float>> gen_vid_ids(int t,
int h,
int w,
int pt,
int ph,
int pw,
int bs,
int t_offset = 0,
int h_offset = 0,
int w_offset = 0) {
int t_len = (t + (pt / 2)) / pt;
int h_len = (h + (ph / 2)) / ph;
int w_len = (w + (pw / 2)) / pw;
@@ -296,18 +296,115 @@ struct Rope {
}
// Generate wan positional embeddings
__STATIC_INLINE__ std::vector<float> gen_wan_pe(int t,
int h,
int w,
int pt,
int ph,
int pw,
int bs,
int theta,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_vid_ids(t, h, w, pt, ph, pw, bs);
return embed_nd(ids, bs, theta, axes_dim);
}
__STATIC_INLINE__ std::vector<std::vector<float>> gen_qwen2vl_ids(int grid_h,
int grid_w,
int merge_size,
const std::vector<int>& window_index) {
std::vector<std::vector<float>> ids(grid_h * grid_w, std::vector<float>(2, 0.0));
int index = 0;
for (int ih = 0; ih < grid_h; ih += merge_size) {
for (int iw = 0; iw < grid_w; iw += merge_size) {
for (int iy = 0; iy < merge_size; iy++) {
for (int ix = 0; ix < merge_size; ix++) {
int inverse_index = window_index[index / (merge_size * merge_size)];
int i = inverse_index * (merge_size * merge_size) + index % (merge_size * merge_size);
GGML_ASSERT(i < grid_h * grid_w);
ids[i][0] = ih + iy;
ids[i][1] = iw + ix;
index++;
}
}
}
}
return ids;
}
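// e.g. grid_h = grid_w = 4, merge_size = 2, identity window_index {0, 1, 2, 3}:
// tokens are emitted per 2x2 merge block, so ids[0..3] = (0,0),(0,1),(1,0),(1,1) for block 0,
// ids[4..7] = (0,2),(0,3),(1,2),(1,3) for block 1, etc.; window_index reorders whole merge blocks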
// Generate qwen2vl positional embeddings
__STATIC_INLINE__ std::vector<float> gen_qwen2vl_pe(int grid_h,
int grid_w,
int merge_size,
const std::vector<int>& window_index,
int theta,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_qwen2vl_ids(grid_h, grid_w, merge_size, window_index);
return embed_nd(ids, 1, theta, axes_dim);
}
__STATIC_INLINE__ struct ggml_tensor* apply_rope(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* pe,
bool rope_interleaved = true) {
// x: [N, L, n_head, d_head]
// pe: [L, d_head/2, 2, 2], [[cos, -sin], [sin, cos]]
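// interleaved (flux style) rotates adjacent pairs (x[2i], x[2i+1]); non-interleaved (qwen2vl vit style)
// rotates split halves (x[i], x[i + d_head/2]); in both cases each pair maps to
// (x0 * cos - x1 * sin, x0 * sin + x1 * cos)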
int64_t d_head = x->ne[0];
int64_t n_head = x->ne[1];
int64_t L = x->ne[2];
int64_t N = x->ne[3];
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, n_head, L, d_head]
if (rope_interleaved) {
x = ggml_reshape_4d(ctx, x, 2, d_head / 2, L, n_head * N); // [N * n_head, L, d_head/2, 2]
x = ggml_cont(ctx, ggml_permute(ctx, x, 3, 0, 1, 2)); // [2, N * n_head, L, d_head/2]
} else {
x = ggml_reshape_4d(ctx, x, d_head / 2, 2, L, n_head * N); // [N * n_head, L, 2, d_head/2]
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 3, 1)); // [2, N * n_head, L, d_head/2]
}
int64_t offset = x->nb[2] * x->ne[2];
auto x_0 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 0); // [N * n_head, L, d_head/2]
auto x_1 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 1); // [N * n_head, L, d_head/2]
x_0 = ggml_reshape_4d(ctx, x_0, 1, x_0->ne[0], x_0->ne[1], x_0->ne[2]); // [N * n_head, L, d_head/2, 1]
x_1 = ggml_reshape_4d(ctx, x_1, 1, x_1->ne[0], x_1->ne[1], x_1->ne[2]); // [N * n_head, L, d_head/2, 1]
auto temp_x = ggml_new_tensor_4d(ctx, x_0->type, 2, x_0->ne[1], x_0->ne[2], x_0->ne[3]);
x_0 = ggml_repeat(ctx, x_0, temp_x); // [N * n_head, L, d_head/2, 2]
x_1 = ggml_repeat(ctx, x_1, temp_x); // [N * n_head, L, d_head/2, 2]
pe = ggml_cont(ctx, ggml_permute(ctx, pe, 3, 0, 1, 2)); // [2, L, d_head/2, 2]
offset = pe->nb[2] * pe->ne[2];
auto pe_0 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 0); // [L, d_head/2, 2]
auto pe_1 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 1); // [L, d_head/2, 2]
auto x_out = ggml_add_inplace(ctx, ggml_mul(ctx, x_0, pe_0), ggml_mul(ctx, x_1, pe_1)); // [N * n_head, L, d_head/2, 2]
if (!rope_interleaved) {
x_out = ggml_cont(ctx, ggml_permute(ctx, x_out, 1, 0, 2, 3)); // [N * n_head, L, 2, d_head/2]
}
x_out = ggml_reshape_3d(ctx, x_out, d_head, L, n_head * N); // [N*n_head, L, d_head]
return x_out;
}
__STATIC_INLINE__ struct ggml_tensor* attention(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* q,
struct ggml_tensor* k,
struct ggml_tensor* v,
struct ggml_tensor* pe,
struct ggml_tensor* mask,
bool flash_attn,
float kv_scale = 1.0f,
bool rope_interleaved = true) {
// q,k,v: [N, L, n_head, d_head]
// pe: [L, d_head/2, 2, 2]
// return: [N, L, n_head*d_head]
q = apply_rope(ctx, q, pe, rope_interleaved); // [N*n_head, L, d_head]
k = apply_rope(ctx, k, pe, rope_interleaved); // [N*n_head, L, d_head]
auto x = ggml_nn_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn, kv_scale); // [N, L, n_head*d_head]
return x;
}
}; // namespace Rope
#endif // __ROPE_HPP__

View File

@@ -1333,7 +1333,7 @@ namespace WAN {
k = ggml_reshape_4d(ctx, k, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head]
v = ggml_reshape_4d(ctx, v, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head]
x = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_token, dim]
x = o_proj->forward(ctx, x); // [N, n_token, dim]
return x;