diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index c0bd55b..843798e 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -27,6 +27,8 @@
 
 #include "avi_writer.h"
 
+#include "qwenvl.hpp"
+
 #if defined(_WIN32)
 #define NOMINMAX
 #include <windows.h>
@@ -1142,6 +1144,10 @@ bool load_images_from_dir(const std::string dir,
 
 int main(int argc, const char* argv[]) {
     SDParams params;
+    params.verbose = true;
+    sd_set_log_callback(sd_log_cb, (void*)&params);
+    Qwen::Qwen2_5_VLEmbedder::load_from_file_and_test(argv[1]);
+    return 1;
     parse_args(argc, argv, params);
     params.sample_params.guidance.slg.layers      = params.skip_layers.data();
     params.sample_params.guidance.slg.layer_count = params.skip_layers.size();
diff --git a/flux.hpp b/flux.hpp
index 10dba08..4153c6f 100644
--- a/flux.hpp
+++ b/flux.hpp
@@ -81,57 +81,6 @@ namespace Flux {
         }
     };
 
-    __STATIC_INLINE__ struct ggml_tensor* apply_rope(struct ggml_context* ctx,
-                                                     struct ggml_tensor* x,
-                                                     struct ggml_tensor* pe) {
-        // x: [N, L, n_head, d_head]
-        // pe: [L, d_head/2, 2, 2]
-        int64_t d_head = x->ne[0];
-        int64_t n_head = x->ne[1];
-        int64_t L      = x->ne[2];
-        int64_t N      = x->ne[3];
-        x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));       // [N, n_head, L, d_head]
-        x = ggml_reshape_4d(ctx, x, 2, d_head / 2, L, n_head * N);  // [N * n_head, L, d_head/2, 2]
-        x = ggml_cont(ctx, ggml_permute(ctx, x, 3, 0, 1, 2));       // [2, N * n_head, L, d_head/2]
-
-        int64_t offset = x->nb[2] * x->ne[2];
-        auto x_0 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 0);  // [N * n_head, L, d_head/2]
-        auto x_1 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 1);  // [N * n_head, L, d_head/2]
-        x_0 = ggml_reshape_4d(ctx, x_0, 1, x_0->ne[0], x_0->ne[1], x_0->ne[2]);  // [N * n_head, L, d_head/2, 1]
-        x_1 = ggml_reshape_4d(ctx, x_1, 1, x_1->ne[0], x_1->ne[1], x_1->ne[2]);  // [N * n_head, L, d_head/2, 1]
-        auto temp_x = ggml_new_tensor_4d(ctx, x_0->type, 2, x_0->ne[1], x_0->ne[2], x_0->ne[3]);
-        x_0 = ggml_repeat(ctx, x_0, temp_x);  // [N * n_head, L, d_head/2, 2]
-        x_1 = ggml_repeat(ctx, x_1, temp_x);  // [N * n_head, L, d_head/2, 2]
-
-        pe     = ggml_cont(ctx, ggml_permute(ctx, pe, 3, 0, 1, 2));  // [2, L, d_head/2, 2]
-        offset = pe->nb[2] * pe->ne[2];
-        auto pe_0 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 0);  // [L, d_head/2, 2]
-        auto pe_1 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 1);  // [L, d_head/2, 2]
-
-        auto x_out = ggml_add_inplace(ctx, ggml_mul(ctx, x_0, pe_0), ggml_mul(ctx, x_1, pe_1));  // [N * n_head, L, d_head/2, 2]
-        x_out = ggml_reshape_3d(ctx, x_out, d_head, L, n_head * N);  // [N*n_head, L, d_head]
-        return x_out;
-    }
-
-    __STATIC_INLINE__ struct ggml_tensor* attention(struct ggml_context* ctx,
-                                                    ggml_backend_t backend,
-                                                    struct ggml_tensor* q,
-                                                    struct ggml_tensor* k,
-                                                    struct ggml_tensor* v,
-                                                    struct ggml_tensor* pe,
-                                                    struct ggml_tensor* mask,
-                                                    bool flash_attn,
-                                                    float kv_scale = 1.0f) {
-        // q,k,v: [N, L, n_head, d_head]
-        // pe: [L, d_head/2, 2, 2]
-        // return: [N, L, n_head*d_head]
-        q = apply_rope(ctx, q, pe);  // [N*n_head, L, d_head]
-        k = apply_rope(ctx, k, pe);  // [N*n_head, L, d_head]
-
-        auto x = ggml_nn_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn, kv_scale);  // [N, L, n_head*d_head]
-        return x;
-    }
-
     struct SelfAttention : public GGMLBlock {
     public:
         int64_t num_heads;
@@ -179,9 +128,9 @@ namespace Flux {
             // x: [N, n_token, dim]
             // pe: [n_token, d_head/2, 2, 2]
             // return [N, n_token, dim]
-            auto qkv = pre_attention(ctx, x);                                                  // q,k,v: [N, n_token, n_head, d_head]
-            x        = attention(ctx, backend, qkv[0], qkv[1], qkv[2], pe, mask, flash_attn);  // [N, n_token, dim]
-            x        = post_attention(ctx, x);                                                 // [N, n_token, dim]
+            auto qkv = pre_attention(ctx, x);                                                        // q,k,v: [N, n_token, n_head, d_head]
+            x        = Rope::attention(ctx, backend, qkv[0], qkv[1], qkv[2], pe, mask, flash_attn);  // [N, n_token, dim]
+            x        = post_attention(ctx, x);                                                       // [N, n_token, dim]
             return x;
         }
     };
@@ -369,8 +318,8 @@ namespace Flux {
             auto k = ggml_concat(ctx, txt_k, img_k, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
             auto v = ggml_concat(ctx, txt_v, img_v, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
 
-            auto attn = attention(ctx, backend, q, k, v, pe, mask, flash_attn);  // [N, n_txt_token + n_img_token, n_head*d_head]
-            attn      = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));     // [n_txt_token + n_img_token, N, hidden_size]
+            auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn);  // [N, n_txt_token + n_img_token, n_head*d_head]
+            attn      = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));           // [n_txt_token + n_img_token, N, hidden_size]
 
             auto txt_attn_out = ggml_view_3d(ctx,
                                              attn,
                                              attn->ne[0],
@@ -504,7 +453,7 @@ namespace Flux {
             auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]);  // [N, n_token, n_head, d_head]
             q      = norm->query_norm(ctx, q);
             k      = norm->key_norm(ctx, k);
-            auto attn = attention(ctx, backend, q, k, v, pe, mask, flash_attn);  // [N, n_token, hidden_size]
+            auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn);  // [N, n_token, hidden_size]
 
             auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0);  // [N, n_token, hidden_size + mlp_hidden_dim]
             auto output   = linear2->forward(ctx, attn_mlp);                         // [N, n_token, hidden_size]
diff --git a/qwen_image.hpp b/qwen_image.hpp
index 68b481a..a7bdc3b 100644
--- a/qwen_image.hpp
+++ b/qwen_image.hpp
@@ -156,7 +156,7 @@ namespace Qwen {
             auto k = ggml_concat(ctx, txt_k, img_k, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
             auto v = ggml_concat(ctx, txt_v, img_v, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
 
-            auto attn = Flux::attention(ctx, backend, q, k, v, pe, mask, flash_attn, (1.0f / 128.f));  // [N, n_txt_token + n_img_token, n_head*d_head]
+            auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn, (1.0f / 128.f));  // [N, n_txt_token + n_img_token, n_head*d_head]
             attn      = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));  // [n_txt_token + n_img_token, N, hidden_size]
 
             auto txt_attn_out = ggml_view_3d(ctx,
diff --git a/qwenvl.hpp b/qwenvl.hpp
index 228452d..fa88260 100644
--- a/qwenvl.hpp
+++ b/qwenvl.hpp
@@ -15,9 +15,11 @@
 #include "clip.hpp"
 #include "ggml_extend.hpp"
 #include "json.hpp"
+#include "rope.hpp"
 #include "tokenize_util.h"
 
 namespace Qwen {
+    constexpr int QWENVL_GRAPH_SIZE = 10240;
 
     class Qwen2Tokenizer {
     private:
@@ -340,9 +342,9 @@ namespace Qwen {
     struct Qwen2_5_VLMLP : public GGMLBlock {
     public:
         Qwen2_5_VLMLP(int64_t hidden_size, int64_t intermediate_size, bool bias = false) {
-            blocks["gate_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, intermediate_size, false));
-            blocks["up_proj"]   = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, intermediate_size, false));
-            blocks["down_proj"] = std::shared_ptr<GGMLBlock>(new Linear(intermediate_size, hidden_size, false));
+            blocks["gate_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, intermediate_size, bias));
+            blocks["up_proj"]   = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, intermediate_size, bias));
+            blocks["down_proj"] = std::shared_ptr<GGMLBlock>(new Linear(intermediate_size, hidden_size, bias));
         }
 
         struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
@@ -359,6 +361,218 @@ namespace Qwen {
         }
     };
 
+    struct Qwen2_5_VisionPatchEmbed : public GGMLBlock {
+    protected:
+        int64_t patch_size;
+        int64_t temporal_patch_size;
+        int64_t embed_dim;
+
+    public:
+        Qwen2_5_VisionPatchEmbed(int64_t patch_size          = 14,
+                                 int64_t temporal_patch_size = 2,
+                                 int64_t in_channels         = 3,
+                                 int64_t embed_dim           = 1152)
+            : patch_size(patch_size), temporal_patch_size(temporal_patch_size), embed_dim(embed_dim) {
+            std::tuple<int, int, int> kernel_size = {(int)temporal_patch_size, (int)patch_size, (int)patch_size};
+            blocks["proj"] = std::shared_ptr<GGMLBlock>(new Conv3d(in_channels,
+                                                                   embed_dim,
+                                                                   kernel_size,
+                                                                   kernel_size,  // stride
+                                                                   {0, 0, 0},    // padding
+                                                                   {1, 1, 1},    // dilation
+                                                                   false));
+        }
+
+        struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
+            // x: [N*grid_t*grid_h*grid_w, in_channels, temporal_patch_size*patch_size*patch_size]
+            // return: [N*grid_t*grid_h*grid_w, embed_dim]
+            auto proj = std::dynamic_pointer_cast<Conv3d>(blocks["proj"]);
+
+            x = ggml_reshape_4d(ctx,
+                                x,
+                                patch_size,
+                                patch_size,
+                                temporal_patch_size,
+                                ggml_nelements(x) / (temporal_patch_size * patch_size * patch_size));
+            x = proj->forward(ctx, x);
+            x = ggml_reshape_2d(ctx, x, embed_dim, ggml_nelements(x) / embed_dim);
+            return x;
+        }
+    };
+
+    struct Qwen2_5_VLPatchMerger : public GGMLBlock {
+    protected:
+        int64_t hidden_size;
+
+    public:
+        Qwen2_5_VLPatchMerger(int64_t dim,
+                              int64_t context_dim,
+                              int64_t spatial_merge_size) {
+            hidden_size     = context_dim * spatial_merge_size * spatial_merge_size;
+            blocks["ln_q"]  = std::shared_ptr<GGMLBlock>(new RMSNorm(context_dim, 1e-6f));
+            blocks["mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size));
+            // mlp.1 is nn.GELU()
+            blocks["mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, dim));
+        }
+
+        struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
+            auto ln_q  = std::dynamic_pointer_cast<RMSNorm>(blocks["ln_q"]);
+            auto mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["mlp.0"]);
+            auto mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["mlp.2"]);
+
+            x = ln_q->forward(ctx, x);
+            x = ggml_reshape_2d(ctx, x, hidden_size, ggml_nelements(x) / hidden_size);
+            x = mlp_0->forward(ctx, x);
+            x = ggml_gelu(ctx, x);
+            x = mlp_2->forward(ctx, x);
+            return x;
+        }
+    };
+
+    struct Qwen2_5_VLVisionAttention : public GGMLBlock {
+    protected:
+        int64_t head_dim;
+        int64_t num_heads;
+
+    public:
+        Qwen2_5_VLVisionAttention(int64_t hidden_size,
+                                  int64_t num_heads)
+            : num_heads(num_heads) {
+            head_dim = hidden_size / num_heads;
+            GGML_ASSERT(num_heads * head_dim == hidden_size);
+            blocks["qkv"]  = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size * 3));
+            blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size));
+        }
+
+        struct ggml_tensor* forward(struct ggml_context* ctx,
+                                    ggml_backend_t backend,
+                                    struct ggml_tensor* x,
+                                    struct ggml_tensor* pe,
+                                    struct ggml_tensor* mask = nullptr) {
+            // x: [N, n_token, hidden_size]
+            int64_t n_token = x->ne[1];
+            int64_t N       = x->ne[2];
+            auto qkv_proj   = std::dynamic_pointer_cast<Linear>(blocks["qkv"]);
+            auto proj       = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
+
+            auto qkv     = qkv_proj->forward(ctx, x);
+            auto qkv_vec = split_qkv(ctx, qkv);
+            auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]);  // [N, n_token, n_head, d_head]
+            auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]);  // [N, n_token, n_head, d_head]
+            auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]);  // [N, n_token, n_head, d_head]
+
+            x = Rope::attention(ctx, backend, q, k, v, pe, mask, false, 1.f, false);  // [N, n_token, hidden_size]
+
+            x = proj->forward(ctx, x);  // [N, n_token, hidden_size]
+            return x;
+        }
+    };
+
+    struct Qwen2_5_VLVisionBlock : public GGMLBlock {
+    public:
+        Qwen2_5_VLVisionBlock(int64_t hidden_size,
+                              int64_t intermediate_size,
+                              int64_t num_heads,
+                              float eps = 1e-6f) {
+            blocks["attn"]  = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLVisionAttention(hidden_size, num_heads));
+            blocks["mlp"]   = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLMLP(hidden_size, intermediate_size, true));
+            blocks["norm1"] = std::shared_ptr<GGMLBlock>(new RMSNorm(hidden_size, eps));
+            blocks["norm2"] = std::shared_ptr<GGMLBlock>(new RMSNorm(hidden_size, eps));
+        }
+
+        struct ggml_tensor* forward(struct ggml_context* ctx,
+                                    ggml_backend_t backend,
+                                    struct ggml_tensor* x,
+                                    struct ggml_tensor* pe,
+                                    struct ggml_tensor* mask = nullptr) {
+            // x: [N, n_token, hidden_size]
+            auto attn  = std::dynamic_pointer_cast<Qwen2_5_VLVisionAttention>(blocks["attn"]);
+            auto mlp   = std::dynamic_pointer_cast<Qwen2_5_VLMLP>(blocks["mlp"]);
+            auto norm1 = std::dynamic_pointer_cast<RMSNorm>(blocks["norm1"]);
+            auto norm2 = std::dynamic_pointer_cast<RMSNorm>(blocks["norm2"]);
+
+            auto residual = x;
+            x = norm1->forward(ctx, x);
+            x = attn->forward(ctx, backend, x, pe, mask);
+            x = ggml_add_inplace(ctx, x, residual);
+
+            residual = x;
+            x = norm2->forward(ctx, x);
+            x = mlp->forward(ctx, x);
+            x = ggml_add_inplace(ctx, x, residual);
+
+            return x;
+        }
+    };
+
+    struct Qwen2_5_VLVisionModel : public GGMLBlock {
+    protected:
+        int64_t num_layers;
+        int64_t spatial_merge_size;
+        std::set<int> fullatt_block_indexes;
+
+    public:
+        Qwen2_5_VLVisionModel(int64_t num_layers,
+                              int64_t in_channels,
+                              int64_t hidden_size,
+                              int64_t out_hidden_size,
+                              int64_t intermediate_size,
+                              int64_t num_heads,
+                              int64_t spatial_merge_size,
+                              int64_t patch_size,
+                              int64_t temporal_patch_size,
+                              int64_t window_size,
+                              std::set<int> fullatt_block_indexes = {7, 15, 23, 31},
+                              float eps                           = 1e-6f)
+            : num_layers(num_layers), fullatt_block_indexes(fullatt_block_indexes), spatial_merge_size(spatial_merge_size) {
+            blocks["patch_embed"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VisionPatchEmbed(patch_size, temporal_patch_size, in_channels, hidden_size));
+            for (int i = 0; i < num_layers; i++) {
+                blocks["blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLVisionBlock(hidden_size,
+                                                                                                             intermediate_size,
+                                                                                                             num_heads,
+                                                                                                             eps));
+            }
+            blocks["merger"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLPatchMerger(out_hidden_size, hidden_size, spatial_merge_size));
+        }
+
+        struct ggml_tensor* forward(struct ggml_context* ctx,
+                                    ggml_backend_t backend,
+                                    struct ggml_tensor* pixel_values,
+                                    struct ggml_tensor* pe,
+                                    struct ggml_tensor* window_index,
+                                    struct ggml_tensor* window_inverse_index,
+                                    struct ggml_tensor* window_mask) {
+            // pixel_values: [grid_t*(H/mh/ph)*(W/mw/pw)*mh*mw, C*pt*ph*pw]
+            // window_index: [grid_t*(H/mh/ph)*(W/mw/pw)]
+            // window_inverse_index: [grid_t*(H/mh/ph)*(W/mw/pw)]
+            // window_mask: [grid_h*grid_w, grid_h*grid_w]
+            auto patch_embed = std::dynamic_pointer_cast<Qwen2_5_VisionPatchEmbed>(blocks["patch_embed"]);
+            auto merger      = std::dynamic_pointer_cast<Qwen2_5_VLPatchMerger>(blocks["merger"]);
+
+            auto x = patch_embed->forward(ctx, pixel_values);
+
+            x = ggml_reshape_4d(ctx, x, x->ne[0] * spatial_merge_size * spatial_merge_size, x->ne[1] / spatial_merge_size / spatial_merge_size, x->ne[2], x->ne[3]);
+            x = ggml_get_rows(ctx, x, window_index);
+            x = ggml_reshape_4d(ctx, x, x->ne[0] / spatial_merge_size / spatial_merge_size, x->ne[1] * spatial_merge_size * spatial_merge_size, x->ne[2], x->ne[3]);
+
+            for (int i = 0; i < num_layers; i++) {
+                auto block = std::dynamic_pointer_cast<Qwen2_5_VLVisionBlock>(blocks["blocks." + std::to_string(i)]);
+
+                auto mask = window_mask;
+                if (fullatt_block_indexes.find(i) != fullatt_block_indexes.end()) {
+                    mask = nullptr;
+                }
+                x = block->forward(ctx, backend, x, pe, mask);
+            }
+
+            x = merger->forward(ctx, x);
+
+            x = ggml_get_rows(ctx, x, window_inverse_index);
+
+            return x;
+        }
+    };
+
     struct Qwen2_5_VLAttention : public GGMLBlock {
     protected:
         int64_t head_dim;
@@ -498,6 +712,20 @@ namespace Qwen {
         }
     };
 
+    struct Qwen2_5_VLVisionParams {
+        int64_t num_layers          = 32;
+        int64_t hidden_size         = 1280;
+        int64_t intermediate_size   = 3420;
+        int64_t num_heads           = 16;
+        int64_t in_channels         = 3;
+        int64_t out_hidden_size     = 3584;
+        int64_t temporal_patch_size = 2;
+        int64_t patch_size          = 14;
+        int64_t spatial_merge_size  = 2;
+        int64_t window_size         = 112;
+        std::set<int> fullatt_block_indexes = {7, 15, 23, 31};
+    };
+
     struct Qwen2_5_VLParams {
         int64_t num_layers    = 28;
         int64_t hidden_size   = 3584;
@@ -506,15 +734,17 @@ namespace Qwen {
         int64_t num_kv_heads  = 4;
         int64_t vocab_size    = 152064;
         float rms_norm_eps    = 1e-06f;
+        Qwen2_5_VLVisionParams vision;
     };
 
     struct Qwen2_5_VL : public GGMLBlock {
+        bool enable_vision;
         Qwen2_5_VLParams params;
 
     public:
         Qwen2_5_VL() {}
-        Qwen2_5_VL(Qwen2_5_VLParams params)
-            : params(params) {
+        Qwen2_5_VL(Qwen2_5_VLParams params, bool enable_vision = false)
+            : enable_vision(enable_vision), params(params) {
             blocks["model"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLTextModel(params.num_layers,
                                                                                  params.vocab_size,
                                                                                  params.hidden_size,
@@ -522,6 +752,19 @@ namespace Qwen {
                                                                                  params.num_heads,
                                                                                  params.num_kv_heads,
                                                                                  params.rms_norm_eps));
+            if (enable_vision) {
+                blocks["visual"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLVisionModel(params.vision.num_layers,
+                                                                                        params.vision.in_channels,
+                                                                                        params.vision.hidden_size,
+                                                                                        params.vision.out_hidden_size,
+                                                                                        params.vision.intermediate_size,
+                                                                                        params.vision.num_heads,
+                                                                                        params.vision.spatial_merge_size,
+                                                                                        params.vision.patch_size,
+                                                                                        params.vision.temporal_patch_size,
+                                                                                        params.vision.window_size,
+                                                                                        params.vision.fullatt_block_indexes));
+            }
         }
 
         struct ggml_tensor* forward(struct ggml_context* ctx,
@@ -534,6 +777,18 @@ namespace Qwen {
             auto x = model->forward(ctx, backend, input_ids, input_pos);
             return x;
         }
+
+        struct ggml_tensor* vision_forward(struct ggml_context* ctx,
+                                           ggml_backend_t backend,
+                                           struct ggml_tensor* pixel_values,
+                                           struct ggml_tensor* pe,
+                                           struct ggml_tensor* window_index,
+                                           struct ggml_tensor* window_inverse_index,
+                                           struct ggml_tensor* window_mask) {
+            GGML_ASSERT(enable_vision);
+            auto vision_model = std::dynamic_pointer_cast<Qwen2_5_VLVisionModel>(blocks["visual"]);
+            return vision_model->forward(ctx, backend, pixel_values, pe, window_index, window_inverse_index, window_mask);
+        }
     };
 
     struct Qwen2_5_VLRunner : public GGMLRunner {
@@ -541,13 +796,17 @@ namespace Qwen {
         Qwen2_5_VL model;
 
         std::vector<int> input_pos_vec;
+        std::vector<float> window_mask_vec;
+        std::vector<int> window_index_vec;
+        std::vector<int> window_inverse_index_vec;
+        std::vector<float> pe_vec;
 
         Qwen2_5_VLRunner(ggml_backend_t backend,
                          bool offload_params_to_cpu,
                          const String2GGMLType& tensor_types,
-                         const std::string prefix)
-            : GGMLRunner(backend, offload_params_to_cpu) {
-            model = Qwen2_5_VL(params);
+                         const std::string prefix,
+                         bool enable_vision = false)
+            : GGMLRunner(backend, offload_params_to_cpu), model(params, enable_vision) {
             model.init(params_ctx, tensor_types, prefix);
         }
 
@@ -567,6 +826,17 @@ namespace Qwen {
             return hidden_states;
         }
 
+        struct ggml_tensor* vision_forward(struct ggml_context* ctx,
+                                           ggml_backend_t backend,
+                                           struct ggml_tensor* pixel_values,
+                                           struct ggml_tensor* input_pos,
+                                           struct ggml_tensor* window_index,
+                                           struct ggml_tensor* window_inverse_index,
+                                           struct ggml_tensor* window_mask) {
+            auto hidden_states = model.vision_forward(ctx, backend, pixel_values, input_pos, window_index, window_inverse_index, window_mask);
+            return hidden_states;
+        }
+
         struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids) {
             struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
 
@@ -602,6 +872,166 @@ namespace Qwen {
             };
             GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
         }
+
+        struct ggml_tensor* process_image(struct ggml_context* ctx, struct ggml_tensor* image) {
+            // image: [C, H, W]
+            // return: [grid_t*(H/mh/ph)*(W/mw/pw)*mh*mw, C*pt*ph*pw], grid_t == 1
+            int64_t C  = image->ne[2];
+            int64_t H  = image->ne[1];
+            int64_t W  = image->ne[0];
+            int64_t mh = params.vision.spatial_merge_size;
+            int64_t mw = params.vision.spatial_merge_size;
+            int64_t pt = params.vision.temporal_patch_size;
+            int64_t ph = params.vision.patch_size;
+            int64_t pw = params.vision.patch_size;
+
+            image = ggml_reshape_4d(ctx, image, pw, mw, (W / mw / pw), H * C);                               // [C*H, (W/mw/pw), mw, pw]
+            image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 3, 1));                              // [mw, C*H, (W/mw/pw), pw]
+            image = ggml_reshape_4d(ctx, image, pw * (W / mw / pw), H, C, mw);                               // [mw, C, H, (W/mw/pw)*pw]
+            image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 3, 1));                              // [H, mw, C, (W/mw/pw)*pw]
+            image = ggml_reshape_4d(ctx, image, pw, (W / mw / pw) * C * mw, ph, mh * (H / mh / ph));         // [(H/mh/ph)*mh, ph, mw*C*(W/mw/pw), pw]
+            image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 1, 3));                              // [(H/mh/ph)*mh, mw*C*(W/mw/pw), ph, pw]
+            image = ggml_reshape_4d(ctx, image, pw * ph, (W / mw / pw), C, mw * mh * (H / mh / ph));         // [(H/mh/ph)*mh*mw, C, (W/mw/pw), ph*pw]
+            image = ggml_concat(ctx, image, image, 0);                                                       // [(H/mh/ph)*mh*mw, C, (W/mw/pw), pt*ph*pw]
+            image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 1, 3));                              // [(H/mh/ph)*mh*mw, (W/mw/pw), C, pt*ph*pw]
+            image = ggml_reshape_4d(ctx, image, pw * ph * pt * C, (W / mw / pw), mw * mh, (H / mh / ph));    // [(H/mh/ph), mh*mw, (W/mw/pw), C*pt*ph*pw]
+            image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 1, 3));                              // [(H/mh/ph), (W/mw/pw), mh*mw, C*pt*ph*pw]
+            image = ggml_reshape_2d(ctx, image, pw * ph * pt * C, mw * mh * (W / mw / pw) * (H / mh / ph));  // [(H/mh/ph)*(W/mw/pw)*mh*mw, C*pt*ph*pw]
+            return image;
+        }
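+        // Worked example of the layout above (illustrative numbers, assuming the
+        // 280x280 test image): with W = H = 280, C = 3 and the defaults mh = mw = 2,
+        // pt = 2, ph = pw = 14, we get grid_h = grid_w = 20, so the image becomes a
+        // [400, 1176] matrix: 400 = 20*20 patches ordered so that each 2x2 merge
+        // group is contiguous, and 1176 = 3*2*14*14 values per patch (the temporal
+        // dimension is filled by duplicating the frame via ggml_concat, since
+        // grid_t == 1 for still images).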
+
+        struct ggml_cgraph* build_encode_image_graph(struct ggml_tensor* image) {
+            struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, QWENVL_GRAPH_SIZE, false);
+
+            GGML_ASSERT(image->ne[1] % (params.vision.patch_size * params.vision.spatial_merge_size) == 0);
+            GGML_ASSERT(image->ne[0] % (params.vision.patch_size * params.vision.spatial_merge_size) == 0);
+
+            int grid_t     = 1;
+            int grid_h     = image->ne[1] / params.vision.patch_size;
+            int grid_w     = image->ne[0] / params.vision.patch_size;
+            int llm_grid_h = grid_h / params.vision.spatial_merge_size;
+            int llm_grid_w = grid_w / params.vision.spatial_merge_size;
+            int vit_merger_window_size = params.vision.window_size / params.vision.patch_size / params.vision.spatial_merge_size;
+
+            image = to_backend(image);
+
+            auto pixel_values = process_image(compute_ctx, image);
+
+            // window index
+            int inverse_index = 0;
+            window_index_vec.resize(llm_grid_h * llm_grid_w);
+            window_inverse_index_vec.resize(llm_grid_h * llm_grid_w);
+            std::vector<int> seqlens;
+            for (int ih = 0; ih < llm_grid_h; ih += vit_merger_window_size) {
+                for (int iw = 0; iw < llm_grid_w; iw += vit_merger_window_size) {
+                    int win_h = std::min(vit_merger_window_size, llm_grid_h - ih);
+                    int win_w = std::min(vit_merger_window_size, llm_grid_w - iw);
+                    for (int iy = 0; iy < win_h; iy++) {
+                        for (int ix = 0; ix < win_w; ix++) {
+                            int index = (ih + iy) * llm_grid_w + iw + ix;
+                            window_index_vec[inverse_index] = index;
+                            window_inverse_index_vec[index] = inverse_index;
+                            inverse_index++;
+                        }
+                    }
+                    seqlens.push_back(win_h * win_w * params.vision.spatial_merge_size * params.vision.spatial_merge_size);
+                }
+            }
+            // printf("window_index: ");
+            // for (int i : window_index_vec) {
+            //     printf("%d ", i);
+            // }
+            // printf("\n");
+            // printf("window_inverse_index: ");
+            // for (int i : window_inverse_index_vec) {
+            //     printf("%d ", i);
+            // }
+            // printf("\n");
+            // printf("seqlens: ");
+            // for (int i : seqlens) {
+            //     printf("%d ", i);
+            // }
+            // printf("\n");
+            auto window_index = ggml_new_tensor_1d(compute_ctx,
+                                                   GGML_TYPE_I32,
+                                                   llm_grid_h * llm_grid_w);
+            auto window_inverse_index = ggml_new_tensor_1d(compute_ctx,
+                                                           GGML_TYPE_I32,
+                                                           llm_grid_h * llm_grid_w);
+            set_backend_tensor_data(window_index, window_index_vec.data());
+            set_backend_tensor_data(window_inverse_index, window_inverse_index_vec.data());
+
+            // window mask
+            int seq_window_size = (vit_merger_window_size * params.vision.spatial_merge_size) * (vit_merger_window_size * params.vision.spatial_merge_size);
+            window_mask_vec.resize((grid_h * grid_w) * (grid_h * grid_w));
+            int window_start_index = 0;
+            for (int seq_index = 0; seq_index < seqlens.size(); seq_index++) {
+                int window_end_index = window_start_index + seqlens[seq_index];
+                // LOG_DEBUG("%d %d", window_start_index, window_end_index);
+                GGML_ASSERT(window_end_index <= grid_h * grid_w);
+                for (int i = window_start_index; i < window_end_index; i++) {
+                    for (int j = 0; j < grid_h * grid_w; j++) {
+                        float mask_value = -INFINITY;
+                        if (j >= window_start_index && j < window_end_index) {
+                            mask_value = 0;
+                        }
+                        GGML_ASSERT((i * (grid_h * grid_w) + j) < window_mask_vec.size());
+                        window_mask_vec[i * (grid_h * grid_w) + j] = mask_value;
+                    }
+                }
+                window_start_index = window_end_index;
+                // printf("\n");
+            }
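+            // Illustration (hypothetical values, assuming a 4x4 llm grid with
+            // vit_merger_window_size = 2): the index loop above visits the four 2x2
+            // windows in row-major order, so window_index_vec becomes
+            //   [0, 1, 4, 5,  2, 3, 6, 7,  8, 9, 12, 13,  10, 11, 14, 15]
+            // and every seqlens entry is 2*2 * 2*2 = 16 patch tokens. The mask just
+            // built is block-diagonal over that reordered sequence: token i may only
+            // attend to token j when both fall inside the same
+            // [window_start_index, window_end_index) range.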
+            // printf("window_mask: \n");
+            // for (int i = 0; i < grid_h*grid_w; i++) {
+            //     for (int j = 0; j < grid_h*grid_w; j++) {
+            //         printf("%f ", window_mask_vec[i * (grid_h * grid_w) + j]);
+            //     }
+            //     printf("\n");
+            // }
+            auto window_mask = ggml_new_tensor_2d(compute_ctx,
+                                                  GGML_TYPE_F32,
+                                                  grid_h * grid_w,
+                                                  grid_h * grid_w);
+            set_backend_tensor_data(window_mask, window_mask_vec.data());
+
+            // pe
+            int head_dim = params.vision.hidden_size / params.vision.num_heads;
+            pe_vec       = Rope::gen_qwen2vl_pe(grid_h,
+                                                grid_w,
+                                                params.vision.spatial_merge_size,
+                                                window_inverse_index_vec,
+                                                10000.f,
+                                                {head_dim / 2, head_dim / 2});
+            int pos_len = pe_vec.size() / head_dim / 2;
+            // LOG_DEBUG("pos_len %d", pos_len);
+            auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, head_dim / 2, pos_len);
+            // pe->data = pe_vec.data();
+            // print_ggml_tensor(pe);
+            // pe->data = NULL;
+            set_backend_tensor_data(pe, pe_vec.data());
+
+            struct ggml_tensor* hidden_states = vision_forward(compute_ctx,
+                                                               runtime_backend,
+                                                               pixel_values,
+                                                               pe,
+                                                               window_index,
+                                                               window_inverse_index,
+                                                               window_mask);
+            ggml_build_forward_expand(gf, hidden_states);
+
+            return gf;
+        }
+
+        void encode_image(const int n_threads,
+                          struct ggml_tensor* image,
+                          ggml_tensor** output,
+                          ggml_context* output_ctx = NULL) {
+            auto get_graph = [&]() -> struct ggml_cgraph* {
+                return build_encode_image_graph(image);
+            };
+            GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
+        }
     };
 
     struct Qwen2_5_VLEmbedder {
@@ -611,8 +1041,9 @@ namespace Qwen {
         Qwen2_5_VLEmbedder(ggml_backend_t backend,
                            bool offload_params_to_cpu,
                            const String2GGMLType& tensor_types = {},
-                           const std::string prefix            = "")
-            : model(backend, offload_params_to_cpu, tensor_types, prefix) {
+                           const std::string prefix            = "",
+                           bool enable_vision                  = false)
+            : model(backend, offload_params_to_cpu, tensor_types, prefix, enable_vision) {
         }
 
         void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
@@ -666,8 +1097,26 @@ namespace Qwen {
         struct ggml_context* work_ctx = ggml_init(params);
         GGML_ASSERT(work_ctx != NULL);
 
+        bool test_vit = true;
 
-        {
+        if (test_vit) {
+            // auto image = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 280, 280, 3);
+            // ggml_set_f32(image, 0.f);
+            auto image = load_tensor_from_file(work_ctx, "qwen2vl_normalized.bin");
+            print_ggml_tensor(image, false, "image");
+            struct ggml_tensor* out = NULL;
+
+            int t0 = ggml_time_ms();
+            model.encode_image(8, image, &out, work_ctx);
+            int t1 = ggml_time_ms();
+
+            print_ggml_tensor(out, false, "out");
+
+            // auto ref_out = load_tensor_from_file(work_ctx, "qwen2vl.bin");
+            // ggml_tensor_diff(ref_out, out, 0.01f);
+
+            LOG_DEBUG("qwen2vl test done in %dms", t1 - t0);
+        } else {
             std::string text("<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\na lovely cat<|im_end|>\n<|im_start|>assistant\n");
             auto tokens_and_weights  = tokenize(text, 0, false);
             std::vector<int>& tokens = std::get<0>(tokens_and_weights);
@@ -692,7 +1141,7 @@ namespace Qwen {
         // cpu f16: pass
         // ggml_backend_t backend = ggml_backend_cuda_init(0);
         ggml_backend_t backend    = ggml_backend_cpu_init();
-        ggml_type model_data_type = GGML_TYPE_Q8_0;
+        ggml_type model_data_type = GGML_TYPE_F16;
 
         ModelLoader model_loader;
         if (!model_loader.init_from_file(file_path, "qwen2vl.")) {
@@ -708,7 +1157,11 @@ namespace Qwen {
             }
         }
 
-        std::shared_ptr<Qwen2_5_VLEmbedder> qwenvl = std::shared_ptr<Qwen2_5_VLEmbedder>(new Qwen2_5_VLEmbedder(backend, false, tensor_types, "qwen2vl"));
+        std::shared_ptr<Qwen2_5_VLEmbedder> qwenvl = std::shared_ptr<Qwen2_5_VLEmbedder>(new Qwen2_5_VLEmbedder(backend,
+                                                                                                                 false,
+                                                                                                                 tensor_types,
+                                                                                                                 "qwen2vl",
+                                                                                                                 true));
 
         qwenvl->alloc_params_buffer();
         std::map<std::string, struct ggml_tensor*> tensors;
diff --git a/rope.hpp b/rope.hpp
index 8ecd818..295c9a2 100644
--- a/rope.hpp
+++ b/rope.hpp
@@ -4,9 +4,9 @@
 #include <vector>
 
 #include "ggml_extend.hpp"
 
-struct Rope {
+namespace Rope {
     template <typename T>
-    static std::vector<T> linspace(T start, T end, int num) {
+    __STATIC_INLINE__ std::vector<T> linspace(T start, T end, int num) {
         std::vector<T> result(num);
         if (num == 1) {
             result[0] = start;
@@ -19,7 +19,7 @@ struct Rope {
         return result;
     }
 
-    static std::vector<std::vector<float>> transpose(const std::vector<std::vector<float>>& mat) {
+    __STATIC_INLINE__ std::vector<std::vector<float>> transpose(const std::vector<std::vector<float>>& mat) {
         int rows = mat.size();
         int cols = mat[0].size();
         std::vector<std::vector<float>> transposed(cols, std::vector<float>(rows));
@@ -31,7 +31,7 @@ struct Rope {
         return transposed;
     }
 
-    static std::vector<float> flatten(const std::vector<std::vector<float>>& vec) {
+    __STATIC_INLINE__ std::vector<float> flatten(const std::vector<std::vector<float>>& vec) {
         std::vector<float> flat_vec;
         for (const auto& sub_vec : vec) {
             flat_vec.insert(flat_vec.end(), sub_vec.begin(), sub_vec.end());
@@ -39,7 +39,7 @@ struct Rope {
         return flat_vec;
     }
 
-    static std::vector<std::vector<float>> rope(const std::vector<float>& pos, int dim, int theta) {
+    __STATIC_INLINE__ std::vector<std::vector<float>> rope(const std::vector<float>& pos, int dim, int theta) {
         assert(dim % 2 == 0);
         int half_dim = dim / 2;
 
@@ -72,11 +72,11 @@ struct Rope {
     }
 
     // Generate IDs for image patches and text
-    static std::vector<std::vector<float>> gen_txt_ids(int bs, int context_len) {
+    __STATIC_INLINE__ std::vector<std::vector<float>> gen_txt_ids(int bs, int context_len) {
         return std::vector<std::vector<float>>(bs * context_len, std::vector<float>(3, 0.0));
     }
 
-    static std::vector<std::vector<float>> gen_img_ids(int h, int w, int patch_size, int bs, int index = 0, int h_offset = 0, int w_offset = 0) {
+    __STATIC_INLINE__ std::vector<std::vector<float>> gen_img_ids(int h, int w, int patch_size, int bs, int index = 0, int h_offset = 0, int w_offset = 0) {
         int h_len = (h + (patch_size / 2)) / patch_size;
         int w_len = (w + (patch_size / 2)) / patch_size;
 
@@ -102,9 +102,9 @@ struct Rope {
         return img_ids_repeated;
     }
 
-    static std::vector<std::vector<float>> concat_ids(const std::vector<std::vector<float>>& a,
-                                                      const std::vector<std::vector<float>>& b,
-                                                      int bs) {
+    __STATIC_INLINE__ std::vector<std::vector<float>> concat_ids(const std::vector<std::vector<float>>& a,
+                                                                 const std::vector<std::vector<float>>& b,
+                                                                 int bs) {
         size_t a_len = a.size() / bs;
         size_t b_len = b.size() / bs;
         std::vector<std::vector<float>> ids(a.size() + b.size(), std::vector<float>(3));
@@ -119,10 +119,10 @@ struct Rope {
         return ids;
     }
 
-    static std::vector<float> embed_nd(const std::vector<std::vector<float>>& ids,
-                                       int bs,
-                                       int theta,
-                                       const std::vector<int>& axes_dim) {
+    __STATIC_INLINE__ std::vector<float> embed_nd(const std::vector<std::vector<float>>& ids,
+                                                  int bs,
+                                                  int theta,
+                                                  const std::vector<int>& axes_dim) {
         std::vector<std::vector<float>> trans_ids = transpose(ids);
         size_t pos_len = ids.size() / bs;
         int num_axes   = axes_dim.size();
@@ -151,10 +151,10 @@ struct Rope {
         return flatten(emb);
     }
 
-    static std::vector<std::vector<float>> gen_refs_ids(int patch_size,
-                                                        int bs,
-                                                        const std::vector<ggml_tensor*>& ref_latents,
-                                                        bool increase_ref_index) {
+    __STATIC_INLINE__ std::vector<std::vector<float>> gen_refs_ids(int patch_size,
+                                                                   int bs,
+                                                                   const std::vector<ggml_tensor*>& ref_latents,
+                                                                   bool increase_ref_index) {
         std::vector<std::vector<float>> ids;
         uint64_t curr_h_offset = 0;
         uint64_t curr_w_offset = 0;
@@ -183,13 +183,13 @@ struct Rope {
         return ids;
     }
 
-    static std::vector<std::vector<float>> gen_flux_ids(int h,
-                                                        int w,
-                                                        int patch_size,
-                                                        int bs,
-                                                        int context_len,
-                                                        const std::vector<ggml_tensor*>& ref_latents,
-                                                        bool increase_ref_index) {
+    __STATIC_INLINE__ std::vector<std::vector<float>> gen_flux_ids(int h,
+                                                                   int w,
+                                                                   int patch_size,
+                                                                   int bs,
+                                                                   int context_len,
+                                                                   const std::vector<ggml_tensor*>& ref_latents,
+                                                                   bool increase_ref_index) {
         auto txt_ids = gen_txt_ids(bs, context_len);
         auto img_ids = gen_img_ids(h, w, patch_size, bs);
@@ -202,26 +202,26 @@ struct Rope {
     }
 
     // Generate flux positional embeddings
-    static std::vector<float> gen_flux_pe(int h,
-                                          int w,
-                                          int patch_size,
-                                          int bs,
-                                          int context_len,
-                                          const std::vector<ggml_tensor*>& ref_latents,
-                                          bool increase_ref_index,
-                                          int theta,
-                                          const std::vector<int>& axes_dim) {
+    __STATIC_INLINE__ std::vector<float> gen_flux_pe(int h,
+                                                     int w,
+                                                     int patch_size,
+                                                     int bs,
+                                                     int context_len,
+                                                     const std::vector<ggml_tensor*>& ref_latents,
+                                                     bool increase_ref_index,
+                                                     int theta,
+                                                     const std::vector<int>& axes_dim) {
         std::vector<std::vector<float>> ids = gen_flux_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
         return embed_nd(ids, bs, theta, axes_dim);
     }
 
-    static std::vector<std::vector<float>> gen_qwen_image_ids(int h,
-                                                              int w,
-                                                              int patch_size,
-                                                              int bs,
-                                                              int context_len,
-                                                              const std::vector<ggml_tensor*>& ref_latents,
-                                                              bool increase_ref_index) {
+    __STATIC_INLINE__ std::vector<std::vector<float>> gen_qwen_image_ids(int h,
+                                                                         int w,
+                                                                         int patch_size,
+                                                                         int bs,
+                                                                         int context_len,
+                                                                         const std::vector<ggml_tensor*>& ref_latents,
+                                                                         bool increase_ref_index) {
         int h_len        = (h + (patch_size / 2)) / patch_size;
         int w_len        = (w + (patch_size / 2)) / patch_size;
         int txt_id_start = std::max(h_len, w_len);
@@ -242,29 +242,29 @@ struct Rope {
     }
 
     // Generate qwen_image positional embeddings
-    static std::vector<float> gen_qwen_image_pe(int h,
-                                                int w,
-                                                int patch_size,
-                                                int bs,
-                                                int context_len,
-                                                const std::vector<ggml_tensor*>& ref_latents,
-                                                bool increase_ref_index,
-                                                int theta,
-                                                const std::vector<int>& axes_dim) {
+    __STATIC_INLINE__ std::vector<float> gen_qwen_image_pe(int h,
+                                                           int w,
+                                                           int patch_size,
+                                                           int bs,
+                                                           int context_len,
+                                                           const std::vector<ggml_tensor*>& ref_latents,
+                                                           bool increase_ref_index,
+                                                           int theta,
+                                                           const std::vector<int>& axes_dim) {
         std::vector<std::vector<float>> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
         return embed_nd(ids, bs, theta, axes_dim);
     }
 
-    static std::vector<std::vector<float>> gen_vid_ids(int t,
-                                                       int h,
-                                                       int w,
-                                                       int pt,
-                                                       int ph,
-                                                       int pw,
-                                                       int bs,
-                                                       int t_offset = 0,
-                                                       int h_offset = 0,
-                                                       int w_offset = 0) {
+    __STATIC_INLINE__ std::vector<std::vector<float>> gen_vid_ids(int t,
+                                                                  int h,
+                                                                  int w,
+                                                                  int pt,
+                                                                  int ph,
+                                                                  int pw,
+                                                                  int bs,
+                                                                  int t_offset = 0,
+                                                                  int h_offset = 0,
+                                                                  int w_offset = 0) {
         int t_len = (t + (pt / 2)) / pt;
         int h_len = (h + (ph / 2)) / ph;
         int w_len = (w + (pw / 2)) / pw;
@@ -296,18 +296,115 @@ struct Rope {
     }
 
     // Generate wan positional embeddings
-    static std::vector<float> gen_wan_pe(int t,
-                                         int h,
-                                         int w,
-                                         int pt,
-                                         int ph,
-                                         int pw,
-                                         int bs,
-                                         int theta,
-                                         const std::vector<int>& axes_dim) {
+    __STATIC_INLINE__ std::vector<float> gen_wan_pe(int t,
+                                                    int h,
+                                                    int w,
+                                                    int pt,
+                                                    int ph,
+                                                    int pw,
+                                                    int bs,
+                                                    int theta,
+                                                    const std::vector<int>& axes_dim) {
         std::vector<std::vector<float>> ids = gen_vid_ids(t, h, w, pt, ph, pw, bs);
         return embed_nd(ids, bs, theta, axes_dim);
     }
-};  // struct Rope
+
+    __STATIC_INLINE__ std::vector<std::vector<float>> gen_qwen2vl_ids(int grid_h,
+                                                                      int grid_w,
+                                                                      int merge_size,
+                                                                      const std::vector<int>& window_index) {
+        std::vector<std::vector<float>> ids(grid_h * grid_w, std::vector<float>(2, 0.0));
+        int index = 0;
+        for (int ih = 0; ih < grid_h; ih += merge_size) {
+            for (int iw = 0; iw < grid_w; iw += merge_size) {
+                for (int iy = 0; iy < merge_size; iy++) {
+                    for (int ix = 0; ix < merge_size; ix++) {
+                        int inverse_index = window_index[index / (merge_size * merge_size)];
+                        int i             = inverse_index * (merge_size * merge_size) + index % (merge_size * merge_size);
+
+                        GGML_ASSERT(i < grid_h * grid_w);
+
+                        ids[i][0] = ih + iy;
+                        ids[i][1] = iw + ix;
+                        index++;
+                    }
+                }
+            }
+        }
+        return ids;
+    }
+
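+    // Illustration (hypothetical values, grid_h = grid_w = 4, merge_size = 2,
+    // identity window_index): the first merge group lands at positions 0..3 with
+    // ids (0,0), (0,1), (1,0), (1,1); the next at positions 4..7 with (0,2),
+    // (0,3), (1,2), (1,3); and so on. In general window_index redirects each
+    // group of merge_size^2 consecutive entries to its place in the windowed
+    // ordering, matching the window_inverse_index_vec passed in from
+    // build_encode_image_graph.
+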
+    // Generate qwen2vl positional embeddings
+    __STATIC_INLINE__ std::vector<float> gen_qwen2vl_pe(int grid_h,
+                                                        int grid_w,
+                                                        int merge_size,
+                                                        const std::vector<int>& window_index,
+                                                        int theta,
+                                                        const std::vector<int>& axes_dim) {
+        std::vector<std::vector<float>> ids = gen_qwen2vl_ids(grid_h, grid_w, merge_size, window_index);
+        return embed_nd(ids, 1, theta, axes_dim);
+    }
+
+    __STATIC_INLINE__ struct ggml_tensor* apply_rope(struct ggml_context* ctx,
+                                                     struct ggml_tensor* x,
+                                                     struct ggml_tensor* pe,
+                                                     bool rope_interleaved = true) {
+        // x: [N, L, n_head, d_head]
+        // pe: [L, d_head/2, 2, 2], [[cos, -sin], [sin, cos]]
+        int64_t d_head = x->ne[0];
+        int64_t n_head = x->ne[1];
+        int64_t L      = x->ne[2];
+        int64_t N      = x->ne[3];
+        x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));  // [N, n_head, L, d_head]
+        if (rope_interleaved) {
+            x = ggml_reshape_4d(ctx, x, 2, d_head / 2, L, n_head * N);   // [N * n_head, L, d_head/2, 2]
+            x = ggml_cont(ctx, ggml_permute(ctx, x, 3, 0, 1, 2));        // [2, N * n_head, L, d_head/2]
+        } else {
+            x = ggml_reshape_4d(ctx, x, d_head / 2, 2, L, n_head * N);   // [N * n_head, L, 2, d_head/2]
+            x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 3, 1));  // [2, N * n_head, L, d_head/2]
+        }
+
+        int64_t offset = x->nb[2] * x->ne[2];
+        auto x_0 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 0);  // [N * n_head, L, d_head/2]
+        auto x_1 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 1);  // [N * n_head, L, d_head/2]
+        x_0 = ggml_reshape_4d(ctx, x_0, 1, x_0->ne[0], x_0->ne[1], x_0->ne[2]);  // [N * n_head, L, d_head/2, 1]
+        x_1 = ggml_reshape_4d(ctx, x_1, 1, x_1->ne[0], x_1->ne[1], x_1->ne[2]);  // [N * n_head, L, d_head/2, 1]
+        auto temp_x = ggml_new_tensor_4d(ctx, x_0->type, 2, x_0->ne[1], x_0->ne[2], x_0->ne[3]);
+        x_0 = ggml_repeat(ctx, x_0, temp_x);  // [N * n_head, L, d_head/2, 2]
+        x_1 = ggml_repeat(ctx, x_1, temp_x);  // [N * n_head, L, d_head/2, 2]
+
+        pe     = ggml_cont(ctx, ggml_permute(ctx, pe, 3, 0, 1, 2));  // [2, L, d_head/2, 2]
+        offset = pe->nb[2] * pe->ne[2];
+        auto pe_0 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 0);  // [L, d_head/2, 2]
+        auto pe_1 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 1);  // [L, d_head/2, 2]
+
+        auto x_out = ggml_add_inplace(ctx, ggml_mul(ctx, x_0, pe_0), ggml_mul(ctx, x_1, pe_1));  // [N * n_head, L, d_head/2, 2]
+        if (!rope_interleaved) {
+            x_out = ggml_cont(ctx, ggml_permute(ctx, x_out, 1, 0, 2, 3));  // [N * n_head, L, x, d_head/2]
+        }
+        x_out = ggml_reshape_3d(ctx, x_out, d_head, L, n_head * N);  // [N*n_head, L, d_head]
+        return x_out;
+    }
+
+    __STATIC_INLINE__ struct ggml_tensor* attention(struct ggml_context* ctx,
+                                                    ggml_backend_t backend,
+                                                    struct ggml_tensor* q,
+                                                    struct ggml_tensor* k,
+                                                    struct ggml_tensor* v,
+                                                    struct ggml_tensor* pe,
+                                                    struct ggml_tensor* mask,
+                                                    bool flash_attn,
+                                                    float kv_scale        = 1.0f,
+                                                    bool rope_interleaved = true) {
+        // q,k,v: [N, L, n_head, d_head]
+        // pe: [L, d_head/2, 2, 2]
+        // return: [N, L, n_head*d_head]
+        q = apply_rope(ctx, q, pe, rope_interleaved);  // [N*n_head, L, d_head]
+        k = apply_rope(ctx, k, pe, rope_interleaved);  // [N*n_head, L, d_head]
+
+        auto x = ggml_nn_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn, kv_scale);  // [N, L, n_head*d_head]
+        return x;
+    }
+};  // namespace Rope
 
 #endif  // __ROPE_HPP__
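The new rope_interleaved flag in Rope::apply_rope selects how a head vector is split into rotation pairs: the default interleaved layout (kept by Flux, Qwen-Image and WAN) pairs adjacent elements, while the Qwen2.5-VL vision tower passes rope_interleaved = false to pair element i with element i + d_head/2 (NeoX-style). A minimal scalar sketch of the two pairings, independent of ggml; the function name and layout here are illustrative, not part of the patch:

    #include <cstddef>
    #include <vector>

    // Rotate one token's head vector in place. cos_sin holds d_head/2
    // (cos, sin) pairs for this position.
    // interleaved == true  pairs (x[2i], x[2i+1]);         // Flux-style
    // interleaved == false pairs (x[i], x[i + d_head/2]);  // NeoX/Qwen2.5-VL-style
    void rope_ref(std::vector<float>& x, const std::vector<float>& cos_sin, bool interleaved) {
        size_t half = x.size() / 2;
        for (size_t i = 0; i < half; i++) {
            float c   = cos_sin[2 * i];
            float s   = cos_sin[2 * i + 1];
            size_t i0 = interleaved ? 2 * i : i;
            size_t i1 = interleaved ? 2 * i + 1 : i + half;
            float x0  = x[i0], x1 = x[i1];
            x[i0] = x0 * c - x1 * s;  // matches the [[cos, -sin], [sin, cos]] pe layout
            x[i1] = x0 * s + x1 * c;
        }
    }

Both branches apply the same 2-D rotation; only the index pairing differs, which is why apply_rope can share the multiply-add core and merely reshape/permute differently per mode.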
diff --git a/wan.hpp b/wan.hpp
index af829b1..31fa90b 100644
--- a/wan.hpp
+++ b/wan.hpp
@@ -1333,7 +1333,7 @@ namespace WAN {
             k = ggml_reshape_4d(ctx, k, head_dim, num_heads, n_token, N);  // [N, n_token, n_head, d_head]
             v = ggml_reshape_4d(ctx, v, head_dim, num_heads, n_token, N);  // [N, n_token, n_head, d_head]
 
-            x = Flux::attention(ctx, backend, q, k, v, pe, mask, flash_attn);  // [N, n_token, dim]
+            x = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn);  // [N, n_token, dim]
 
             x = o_proj->forward(ctx, x);  // [N, n_token, dim]
             return x;
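For reference, the vision path added in qwenvl.hpp is exercised end to end by load_from_file_and_test above. A condensed sketch of the same flow; the helper name, the 8-thread count, and the external call through the embedder's public model member are illustrative assumptions, not part of the patch:

    // Sketch: encode one preprocessed image with the Qwen2.5-VL vision tower.
    void encode_image_example(const String2GGMLType& tensor_types,
                              ggml_context* work_ctx,
                              ggml_tensor* image /* [W, H, C] f32, e.g. 280x280x3, already normalized */) {
        ggml_backend_t backend = ggml_backend_cpu_init();
        auto qwenvl = std::make_shared<Qwen::Qwen2_5_VLEmbedder>(backend,
                                                                 false,         // offload_params_to_cpu
                                                                 tensor_types,  // from ModelLoader
                                                                 "qwen2vl",
                                                                 true);         // enable_vision
        qwenvl->alloc_params_buffer();
        // ... load weights into the embedder's tensors, as in load_from_file_and_test ...
        ggml_tensor* out = nullptr;
        qwenvl->model.encode_image(8 /* n_threads */, image, &out, work_ctx);
        // out: one out_hidden_size (3584) embedding per 2x2-merged 14x14 patch
    }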