mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-12 13:28:37 +00:00
feat: add Qwen Image support (#851)
* add qwen tokenizer * add qwen2.5 vl support * mv qwen.hpp -> qwenvl.hpp * add qwen image model * add qwen image t2i pipeline * fix qwen image flash attn * add qwen image i2i pipeline * change encoding of vocab_qwen.hpp to utf8 * fix get_first_stage_encoding * apply jeffbolz f32 patch https://github.com/leejet/stable-diffusion.cpp/pull/851#issuecomment-3335515302 * fix the issue that occurs when using CUDA with k-quants weights * optimize the handling of the FeedForward precision fix * to_add_out precision fix * update docs
This commit is contained in:
parent
aa68b875b9
commit
beb99a2de2
@ -21,6 +21,7 @@ API and command-line option may change frequently.***
|
|||||||
- [SD3/SD3.5](./docs/sd3.md)
|
- [SD3/SD3.5](./docs/sd3.md)
|
||||||
- [Flux-dev/Flux-schnell](./docs/flux.md)
|
- [Flux-dev/Flux-schnell](./docs/flux.md)
|
||||||
- [Chroma](./docs/chroma.md)
|
- [Chroma](./docs/chroma.md)
|
||||||
|
- [Qwen Image](./docs/qwen_image.md)
|
||||||
- Image Edit Models
|
- Image Edit Models
|
||||||
- [FLUX.1-Kontext-dev](./docs/kontext.md)
|
- [FLUX.1-Kontext-dev](./docs/kontext.md)
|
||||||
- Video Models
|
- Video Models
|
||||||
@ -296,6 +297,7 @@ arguments:
|
|||||||
--clip_g path to the clip-g text encoder
|
--clip_g path to the clip-g text encoder
|
||||||
--clip_vision path to the clip-vision encoder
|
--clip_vision path to the clip-vision encoder
|
||||||
--t5xxl path to the t5xxl text encoder
|
--t5xxl path to the t5xxl text encoder
|
||||||
|
--qwen2vl path to the qwen2vl text encoder
|
||||||
--vae [VAE] path to vae
|
--vae [VAE] path to vae
|
||||||
--taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
|
--taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
|
||||||
--control-net [CONTROL_PATH] path to control net model
|
--control-net [CONTROL_PATH] path to control net model
|
||||||
@ -464,6 +466,7 @@ Thank you to all the people who have already contributed to stable-diffusion.cpp
|
|||||||
## References
|
## References
|
||||||
|
|
||||||
- [ggml](https://github.com/ggerganov/ggml)
|
- [ggml](https://github.com/ggerganov/ggml)
|
||||||
|
- [diffusers](https://github.com/huggingface/diffusers)
|
||||||
- [stable-diffusion](https://github.com/CompVis/stable-diffusion)
|
- [stable-diffusion](https://github.com/CompVis/stable-diffusion)
|
||||||
- [sd3-ref](https://github.com/Stability-AI/sd3-ref)
|
- [sd3-ref](https://github.com/Stability-AI/sd3-ref)
|
||||||
- [stable-diffusion-stability-ai](https://github.com/Stability-AI/stablediffusion)
|
- [stable-diffusion-stability-ai](https://github.com/Stability-AI/stablediffusion)
|
||||||
|
|||||||
BIN
assets/qwen/example.png
Normal file
BIN
assets/qwen/example.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.4 MiB |
4
clip.hpp
4
clip.hpp
@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
/*================================================== CLIPTokenizer ===================================================*/
|
/*================================================== CLIPTokenizer ===================================================*/
|
||||||
|
|
||||||
std::pair<std::unordered_map<std::string, float>, std::string> extract_and_remove_lora(std::string text) {
|
__STATIC_INLINE__ std::pair<std::unordered_map<std::string, float>, std::string> extract_and_remove_lora(std::string text) {
|
||||||
std::regex re("<lora:([^:]+):([^>]+)>");
|
std::regex re("<lora:([^:]+):([^>]+)>");
|
||||||
std::smatch matches;
|
std::smatch matches;
|
||||||
std::unordered_map<std::string, float> filename2multiplier;
|
std::unordered_map<std::string, float> filename2multiplier;
|
||||||
@ -31,7 +31,7 @@ std::pair<std::unordered_map<std::string, float>, std::string> extract_and_remov
|
|||||||
return std::make_pair(filename2multiplier, text);
|
return std::make_pair(filename2multiplier, text);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
|
__STATIC_INLINE__ std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
|
||||||
std::vector<std::pair<int, std::u32string>> byte_unicode_pairs;
|
std::vector<std::pair<int, std::u32string>> byte_unicode_pairs;
|
||||||
std::set<int> byte_set;
|
std::set<int> byte_set;
|
||||||
for (int b = static_cast<int>('!'); b <= static_cast<int>('~'); ++b) {
|
for (int b = static_cast<int>('!'); b <= static_cast<int>('~'); ++b) {
|
||||||
|
|||||||
44
common.hpp
44
common.hpp
@ -177,7 +177,7 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class GEGLU : public GGMLBlock {
|
class GEGLU : public UnaryBlock {
|
||||||
protected:
|
protected:
|
||||||
int64_t dim_in;
|
int64_t dim_in;
|
||||||
int64_t dim_out;
|
int64_t dim_out;
|
||||||
@ -216,23 +216,57 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class GELU : public UnaryBlock {
|
||||||
|
public:
|
||||||
|
GELU(int64_t dim_in, int64_t dim_out, bool bias = true) {
|
||||||
|
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim_in, dim_out, bias));
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
||||||
|
// x: [ne3, ne2, ne1, dim_in]
|
||||||
|
// return: [ne3, ne2, ne1, dim_out]
|
||||||
|
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
|
||||||
|
|
||||||
|
x = proj->forward(ctx, x);
|
||||||
|
x = ggml_gelu_inplace(ctx, x);
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
class FeedForward : public GGMLBlock {
|
class FeedForward : public GGMLBlock {
|
||||||
public:
|
public:
|
||||||
|
enum class Activation {
|
||||||
|
GEGLU,
|
||||||
|
GELU
|
||||||
|
};
|
||||||
FeedForward(int64_t dim,
|
FeedForward(int64_t dim,
|
||||||
int64_t dim_out,
|
int64_t dim_out,
|
||||||
int64_t mult = 4) {
|
int64_t mult = 4,
|
||||||
|
Activation activation = Activation::GEGLU,
|
||||||
|
bool precision_fix = false) {
|
||||||
int64_t inner_dim = dim * mult;
|
int64_t inner_dim = dim * mult;
|
||||||
|
if (activation == Activation::GELU) {
|
||||||
|
blocks["net.0"] = std::shared_ptr<GGMLBlock>(new GELU(dim, inner_dim));
|
||||||
|
} else {
|
||||||
|
blocks["net.0"] = std::shared_ptr<GGMLBlock>(new GEGLU(dim, inner_dim));
|
||||||
|
}
|
||||||
|
|
||||||
blocks["net.0"] = std::shared_ptr<GGMLBlock>(new GEGLU(dim, inner_dim));
|
|
||||||
// net_1 is nn.Dropout(), skip for inference
|
// net_1 is nn.Dropout(), skip for inference
|
||||||
blocks["net.2"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim_out));
|
float scale = 1.f;
|
||||||
|
if (precision_fix) {
|
||||||
|
scale = 1.f / 128.f;
|
||||||
|
}
|
||||||
|
// The purpose of the scale here is to prevent NaN issues in certain situations.
|
||||||
|
// For example, when using Vulkan without enabling force_prec_f32,
|
||||||
|
// or when using CUDA but the weights are k-quants.
|
||||||
|
blocks["net.2"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim_out, true, false, false, scale));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
||||||
// x: [ne3, ne2, ne1, dim]
|
// x: [ne3, ne2, ne1, dim]
|
||||||
// return: [ne3, ne2, ne1, dim_out]
|
// return: [ne3, ne2, ne1, dim_out]
|
||||||
|
|
||||||
auto net_0 = std::dynamic_pointer_cast<GEGLU>(blocks["net.0"]);
|
auto net_0 = std::dynamic_pointer_cast<UnaryBlock>(blocks["net.0"]);
|
||||||
auto net_2 = std::dynamic_pointer_cast<Linear>(blocks["net.2"]);
|
auto net_2 = std::dynamic_pointer_cast<Linear>(blocks["net.2"]);
|
||||||
|
|
||||||
x = net_0->forward(ctx, x); // [ne3, ne2, ne1, inner_dim]
|
x = net_0->forward(ctx, x); // [ne3, ne2, ne1, inner_dim]
|
||||||
|
|||||||
192
conditioner.hpp
192
conditioner.hpp
@ -2,6 +2,7 @@
|
|||||||
#define __CONDITIONER_HPP__
|
#define __CONDITIONER_HPP__
|
||||||
|
|
||||||
#include "clip.hpp"
|
#include "clip.hpp"
|
||||||
|
#include "qwenvl.hpp"
|
||||||
#include "t5.hpp"
|
#include "t5.hpp"
|
||||||
|
|
||||||
struct SDCondition {
|
struct SDCondition {
|
||||||
@ -22,11 +23,11 @@ struct Conditioner {
|
|||||||
int width,
|
int width,
|
||||||
int height,
|
int height,
|
||||||
int adm_in_channels = -1,
|
int adm_in_channels = -1,
|
||||||
bool zero_out_masked = false) = 0;
|
bool zero_out_masked = false) = 0;
|
||||||
virtual void alloc_params_buffer() = 0;
|
virtual void alloc_params_buffer() = 0;
|
||||||
virtual void free_params_buffer() = 0;
|
virtual void free_params_buffer() = 0;
|
||||||
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
|
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
|
||||||
virtual size_t get_params_buffer_size() = 0;
|
virtual size_t get_params_buffer_size() = 0;
|
||||||
virtual std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
|
virtual std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
|
||||||
int n_threads,
|
int n_threads,
|
||||||
const std::string& text,
|
const std::string& text,
|
||||||
@ -35,9 +36,13 @@ struct Conditioner {
|
|||||||
int height,
|
int height,
|
||||||
int num_input_imgs,
|
int num_input_imgs,
|
||||||
int adm_in_channels = -1,
|
int adm_in_channels = -1,
|
||||||
bool zero_out_masked = false) = 0;
|
bool zero_out_masked = false) {
|
||||||
|
GGML_ABORT("Not implemented yet!");
|
||||||
|
}
|
||||||
virtual std::string remove_trigger_from_prompt(ggml_context* work_ctx,
|
virtual std::string remove_trigger_from_prompt(ggml_context* work_ctx,
|
||||||
const std::string& prompt) = 0;
|
const std::string& prompt) {
|
||||||
|
GGML_ABORT("Not implemented yet!");
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
// ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||||
@ -978,23 +983,6 @@ struct SD3CLIPEmbedder : public Conditioner {
|
|||||||
auto tokens_and_weights = tokenize(text, 77, true);
|
auto tokens_and_weights = tokenize(text, 77, true);
|
||||||
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
|
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
|
|
||||||
int n_threads,
|
|
||||||
const std::string& text,
|
|
||||||
int clip_skip,
|
|
||||||
int width,
|
|
||||||
int height,
|
|
||||||
int num_input_imgs,
|
|
||||||
int adm_in_channels = -1,
|
|
||||||
bool zero_out_masked = false) {
|
|
||||||
GGML_ASSERT(0 && "Not implemented yet!");
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string remove_trigger_from_prompt(ggml_context* work_ctx,
|
|
||||||
const std::string& prompt) {
|
|
||||||
GGML_ASSERT(0 && "Not implemented yet!");
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct FluxCLIPEmbedder : public Conditioner {
|
struct FluxCLIPEmbedder : public Conditioner {
|
||||||
@ -1195,23 +1183,6 @@ struct FluxCLIPEmbedder : public Conditioner {
|
|||||||
auto tokens_and_weights = tokenize(text, chunk_len, true);
|
auto tokens_and_weights = tokenize(text, chunk_len, true);
|
||||||
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
|
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
|
|
||||||
int n_threads,
|
|
||||||
const std::string& text,
|
|
||||||
int clip_skip,
|
|
||||||
int width,
|
|
||||||
int height,
|
|
||||||
int num_input_imgs,
|
|
||||||
int adm_in_channels = -1,
|
|
||||||
bool zero_out_masked = false) {
|
|
||||||
GGML_ASSERT(0 && "Not implemented yet!");
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string remove_trigger_from_prompt(ggml_context* work_ctx,
|
|
||||||
const std::string& prompt) {
|
|
||||||
GGML_ASSERT(0 && "Not implemented yet!");
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct T5CLIPEmbedder : public Conditioner {
|
struct T5CLIPEmbedder : public Conditioner {
|
||||||
@ -1398,22 +1369,135 @@ struct T5CLIPEmbedder : public Conditioner {
|
|||||||
auto tokens_and_weights = tokenize(text, chunk_len, true);
|
auto tokens_and_weights = tokenize(text, chunk_len, true);
|
||||||
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
|
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
|
struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
|
||||||
int n_threads,
|
Qwen::Qwen2Tokenizer tokenizer;
|
||||||
const std::string& text,
|
std::shared_ptr<Qwen::Qwen2_5_VLRunner> qwenvl;
|
||||||
int clip_skip,
|
int prompt_template_encode_start_idx = 34;
|
||||||
int width,
|
|
||||||
int height,
|
Qwen2_5_VLCLIPEmbedder(ggml_backend_t backend,
|
||||||
int num_input_imgs,
|
bool offload_params_to_cpu,
|
||||||
int adm_in_channels = -1,
|
const String2GGMLType& tensor_types = {},
|
||||||
bool zero_out_masked = false) {
|
const std::string prefix = "") {
|
||||||
GGML_ASSERT(0 && "Not implemented yet!");
|
qwenvl = std::make_shared<Qwen::Qwen2_5_VLRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.qwen2vl");
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string remove_trigger_from_prompt(ggml_context* work_ctx,
|
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
|
||||||
const std::string& prompt) {
|
qwenvl->get_param_tensors(tensors, "text_encoders.qwen2vl");
|
||||||
GGML_ASSERT(0 && "Not implemented yet!");
|
}
|
||||||
|
|
||||||
|
void alloc_params_buffer() {
|
||||||
|
qwenvl->alloc_params_buffer();
|
||||||
|
}
|
||||||
|
|
||||||
|
void free_params_buffer() {
|
||||||
|
qwenvl->free_params_buffer();
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t get_params_buffer_size() {
|
||||||
|
size_t buffer_size = 0;
|
||||||
|
buffer_size += qwenvl->get_params_buffer_size();
|
||||||
|
return buffer_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text,
|
||||||
|
size_t max_length = 0,
|
||||||
|
bool padding = false) {
|
||||||
|
auto parsed_attention = parse_prompt_attention(text);
|
||||||
|
|
||||||
|
{
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "[";
|
||||||
|
for (const auto& item : parsed_attention) {
|
||||||
|
ss << "['" << item.first << "', " << item.second << "], ";
|
||||||
|
}
|
||||||
|
ss << "]";
|
||||||
|
LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> tokens;
|
||||||
|
std::vector<float> weights;
|
||||||
|
for (const auto& item : parsed_attention) {
|
||||||
|
const std::string& curr_text = item.first;
|
||||||
|
float curr_weight = item.second;
|
||||||
|
std::vector<int> curr_tokens = tokenizer.tokenize(curr_text, nullptr);
|
||||||
|
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
|
||||||
|
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenizer.pad_tokens(tokens, weights, max_length, padding);
|
||||||
|
|
||||||
|
// for (int i = 0; i < tokens.size(); i++) {
|
||||||
|
// std::cout << tokens[i] << ":" << weights[i] << ", ";
|
||||||
|
// }
|
||||||
|
// std::cout << std::endl;
|
||||||
|
|
||||||
|
return {tokens, weights};
|
||||||
|
}
|
||||||
|
|
||||||
|
SDCondition get_learned_condition_common(ggml_context* work_ctx,
|
||||||
|
int n_threads,
|
||||||
|
std::tuple<std::vector<int>, std::vector<float>> token_and_weights,
|
||||||
|
int clip_skip,
|
||||||
|
bool zero_out_masked = false) {
|
||||||
|
auto& tokens = std::get<0>(token_and_weights);
|
||||||
|
auto& weights = std::get<1>(token_and_weights);
|
||||||
|
|
||||||
|
int64_t t0 = ggml_time_ms();
|
||||||
|
struct ggml_tensor* hidden_states = NULL; // [N, n_token, 3584]
|
||||||
|
|
||||||
|
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
|
||||||
|
|
||||||
|
qwenvl->compute(n_threads,
|
||||||
|
input_ids,
|
||||||
|
&hidden_states,
|
||||||
|
work_ctx);
|
||||||
|
{
|
||||||
|
auto tensor = hidden_states;
|
||||||
|
float original_mean = ggml_tensor_mean(tensor);
|
||||||
|
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
|
||||||
|
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
|
||||||
|
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
|
||||||
|
float value = ggml_tensor_get_f32(tensor, i0, i1, i2);
|
||||||
|
value *= weights[i1];
|
||||||
|
ggml_tensor_set_f32(tensor, value, i0, i1, i2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
float new_mean = ggml_tensor_mean(tensor);
|
||||||
|
ggml_tensor_scale(tensor, (original_mean / new_mean));
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);
|
||||||
|
|
||||||
|
ggml_tensor* new_hidden_states = ggml_new_tensor_3d(work_ctx,
|
||||||
|
GGML_TYPE_F32,
|
||||||
|
hidden_states->ne[0],
|
||||||
|
hidden_states->ne[1] - prompt_template_encode_start_idx,
|
||||||
|
hidden_states->ne[2]);
|
||||||
|
|
||||||
|
ggml_tensor_iter(new_hidden_states, [&](ggml_tensor* new_hidden_states, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
||||||
|
float value = ggml_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3);
|
||||||
|
ggml_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3);
|
||||||
|
});
|
||||||
|
|
||||||
|
int64_t t1 = ggml_time_ms();
|
||||||
|
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
|
||||||
|
return SDCondition(new_hidden_states, nullptr, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
SDCondition get_learned_condition(ggml_context* work_ctx,
|
||||||
|
int n_threads,
|
||||||
|
const std::string& text,
|
||||||
|
int clip_skip,
|
||||||
|
int width,
|
||||||
|
int height,
|
||||||
|
int adm_in_channels = -1,
|
||||||
|
bool zero_out_masked = false) {
|
||||||
|
std::string prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n" + text + "<|im_end|>\n<|im_start|>assistant\n";
|
||||||
|
auto tokens_and_weights = tokenize(prompt, 0, false);
|
||||||
|
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
#include "flux.hpp"
|
#include "flux.hpp"
|
||||||
#include "mmdit.hpp"
|
#include "mmdit.hpp"
|
||||||
|
#include "qwen_image.hpp"
|
||||||
#include "unet.hpp"
|
#include "unet.hpp"
|
||||||
#include "wan.hpp"
|
#include "wan.hpp"
|
||||||
|
|
||||||
@ -263,4 +264,58 @@ struct WanModel : public DiffusionModel {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct QwenImageModel : public DiffusionModel {
|
||||||
|
std::string prefix;
|
||||||
|
Qwen::QwenImageRunner qwen_image;
|
||||||
|
|
||||||
|
QwenImageModel(ggml_backend_t backend,
|
||||||
|
bool offload_params_to_cpu,
|
||||||
|
const String2GGMLType& tensor_types = {},
|
||||||
|
const std::string prefix = "model.diffusion_model",
|
||||||
|
SDVersion version = VERSION_QWEN_IMAGE,
|
||||||
|
bool flash_attn = false)
|
||||||
|
: prefix(prefix), qwen_image(backend, offload_params_to_cpu, tensor_types, prefix, version, flash_attn) {
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string get_desc() {
|
||||||
|
return qwen_image.get_desc();
|
||||||
|
}
|
||||||
|
|
||||||
|
void alloc_params_buffer() {
|
||||||
|
qwen_image.alloc_params_buffer();
|
||||||
|
}
|
||||||
|
|
||||||
|
void free_params_buffer() {
|
||||||
|
qwen_image.free_params_buffer();
|
||||||
|
}
|
||||||
|
|
||||||
|
void free_compute_buffer() {
|
||||||
|
qwen_image.free_compute_buffer();
|
||||||
|
}
|
||||||
|
|
||||||
|
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
|
||||||
|
qwen_image.get_param_tensors(tensors, prefix);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t get_params_buffer_size() {
|
||||||
|
return qwen_image.get_params_buffer_size();
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t get_adm_in_channels() {
|
||||||
|
return 768;
|
||||||
|
}
|
||||||
|
|
||||||
|
void compute(int n_threads,
|
||||||
|
DiffusionParams diffusion_params,
|
||||||
|
struct ggml_tensor** output = NULL,
|
||||||
|
struct ggml_context* output_ctx = NULL) {
|
||||||
|
return qwen_image.compute(n_threads,
|
||||||
|
diffusion_params.x,
|
||||||
|
diffusion_params.timesteps,
|
||||||
|
diffusion_params.context,
|
||||||
|
output,
|
||||||
|
output_ctx);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
23
docs/qwen_image.md
Normal file
23
docs/qwen_image.md
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
# How to Use
|
||||||
|
|
||||||
|
## Download weights
|
||||||
|
|
||||||
|
- Download Qwen Image
|
||||||
|
- safetensors: https://huggingface.co/Comfy-Org/Qwen-Image_ComfyUI/tree/main/split_files/diffusion_models
|
||||||
|
- gguf: https://huggingface.co/QuantStack/Qwen-Image-GGUF/tree/main
|
||||||
|
- Download vae
|
||||||
|
- safetensors: https://huggingface.co/Comfy-Org/Qwen-Image_ComfyUI/tree/main/split_files/vae
|
||||||
|
- Download qwen_2.5_vl 7b
|
||||||
|
- safetensors: https://huggingface.co/Comfy-Org/Qwen-Image_ComfyUI/tree/main/split_files/text_encoders
|
||||||
|
- gguf: https://huggingface.co/mradermacher/Qwen2.5-VL-7B-Instruct-GGUF/tree/main
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
```
|
||||||
|
.\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\qwen-image-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --qwen2vl ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf -p '一个穿着"QWEN"标志的T恤的中国美女正拿着黑色的马克笔面相镜头微笑。她身后的玻璃板上手写体写着 “一、Qwen-Image的技术路线: 探索视觉生成基础模型的极限,开创理解与生成一体化的未来。二、Qwen-Image的模型特色:1、复杂文字渲染。支持中英渲染、自动布局; 2、精准图像编辑。支持文字编辑、物体增减、风格变换。三、Qwen-Image的未来愿景:赋能专业内容创作、助力生成式AI发展。”' --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu -H 1024 -W 1024 --diffusion-fa --flow-shift 3
|
||||||
|
```
|
||||||
|
|
||||||
|
<img alt="qwen example" src="../assets/qwen/example.png" />
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -61,6 +61,7 @@ struct SDParams {
|
|||||||
std::string clip_g_path;
|
std::string clip_g_path;
|
||||||
std::string clip_vision_path;
|
std::string clip_vision_path;
|
||||||
std::string t5xxl_path;
|
std::string t5xxl_path;
|
||||||
|
std::string qwen2vl_path;
|
||||||
std::string diffusion_model_path;
|
std::string diffusion_model_path;
|
||||||
std::string high_noise_diffusion_model_path;
|
std::string high_noise_diffusion_model_path;
|
||||||
std::string vae_path;
|
std::string vae_path;
|
||||||
@ -146,6 +147,7 @@ void print_params(SDParams params) {
|
|||||||
printf(" clip_g_path: %s\n", params.clip_g_path.c_str());
|
printf(" clip_g_path: %s\n", params.clip_g_path.c_str());
|
||||||
printf(" clip_vision_path: %s\n", params.clip_vision_path.c_str());
|
printf(" clip_vision_path: %s\n", params.clip_vision_path.c_str());
|
||||||
printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str());
|
printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str());
|
||||||
|
printf(" qwen2vl_path: %s\n", params.qwen2vl_path.c_str());
|
||||||
printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str());
|
printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str());
|
||||||
printf(" high_noise_diffusion_model_path: %s\n", params.high_noise_diffusion_model_path.c_str());
|
printf(" high_noise_diffusion_model_path: %s\n", params.high_noise_diffusion_model_path.c_str());
|
||||||
printf(" vae_path: %s\n", params.vae_path.c_str());
|
printf(" vae_path: %s\n", params.vae_path.c_str());
|
||||||
@ -217,6 +219,7 @@ void print_usage(int argc, const char* argv[]) {
|
|||||||
printf(" --clip_g path to the clip-g text encoder\n");
|
printf(" --clip_g path to the clip-g text encoder\n");
|
||||||
printf(" --clip_vision path to the clip-vision encoder\n");
|
printf(" --clip_vision path to the clip-vision encoder\n");
|
||||||
printf(" --t5xxl path to the t5xxl text encoder\n");
|
printf(" --t5xxl path to the t5xxl text encoder\n");
|
||||||
|
printf(" --qwen2vl path to the qwen2vl text encoder\n");
|
||||||
printf(" --vae [VAE] path to vae\n");
|
printf(" --vae [VAE] path to vae\n");
|
||||||
printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
|
printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
|
||||||
printf(" --control-net [CONTROL_PATH] path to control net model\n");
|
printf(" --control-net [CONTROL_PATH] path to control net model\n");
|
||||||
@ -486,6 +489,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
|
|||||||
{"", "--clip_g", "", &params.clip_g_path},
|
{"", "--clip_g", "", &params.clip_g_path},
|
||||||
{"", "--clip_vision", "", &params.clip_vision_path},
|
{"", "--clip_vision", "", &params.clip_vision_path},
|
||||||
{"", "--t5xxl", "", &params.t5xxl_path},
|
{"", "--t5xxl", "", &params.t5xxl_path},
|
||||||
|
{"", "--qwen2vl", "", &params.qwen2vl_path},
|
||||||
{"", "--diffusion-model", "", &params.diffusion_model_path},
|
{"", "--diffusion-model", "", &params.diffusion_model_path},
|
||||||
{"", "--high-noise-diffusion-model", "", &params.high_noise_diffusion_model_path},
|
{"", "--high-noise-diffusion-model", "", &params.high_noise_diffusion_model_path},
|
||||||
{"", "--vae", "", &params.vae_path},
|
{"", "--vae", "", &params.vae_path},
|
||||||
@ -948,7 +952,7 @@ std::string get_image_params(SDParams params, int64_t seed) {
|
|||||||
parameter_string += " " + std::string(sd_schedule_name(params.sample_params.scheduler));
|
parameter_string += " " + std::string(sd_schedule_name(params.sample_params.scheduler));
|
||||||
}
|
}
|
||||||
parameter_string += ", ";
|
parameter_string += ", ";
|
||||||
for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path}) {
|
for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path}) {
|
||||||
if (!te.empty()) {
|
if (!te.empty()) {
|
||||||
parameter_string += "TE: " + sd_basename(te) + ", ";
|
parameter_string += "TE: " + sd_basename(te) + ", ";
|
||||||
}
|
}
|
||||||
@ -1331,6 +1335,7 @@ int main(int argc, const char* argv[]) {
|
|||||||
params.clip_g_path.c_str(),
|
params.clip_g_path.c_str(),
|
||||||
params.clip_vision_path.c_str(),
|
params.clip_vision_path.c_str(),
|
||||||
params.t5xxl_path.c_str(),
|
params.t5xxl_path.c_str(),
|
||||||
|
params.qwen2vl_path.c_str(),
|
||||||
params.diffusion_model_path.c_str(),
|
params.diffusion_model_path.c_str(),
|
||||||
params.high_noise_diffusion_model_path.c_str(),
|
params.high_noise_diffusion_model_path.c_str(),
|
||||||
params.vae_path.c_str(),
|
params.vae_path.c_str(),
|
||||||
|
|||||||
5
flux.hpp
5
flux.hpp
@ -120,14 +120,15 @@ namespace Flux {
|
|||||||
struct ggml_tensor* v,
|
struct ggml_tensor* v,
|
||||||
struct ggml_tensor* pe,
|
struct ggml_tensor* pe,
|
||||||
struct ggml_tensor* mask,
|
struct ggml_tensor* mask,
|
||||||
bool flash_attn) {
|
bool flash_attn,
|
||||||
|
float kv_scale = 1.0f) {
|
||||||
// q,k,v: [N, L, n_head, d_head]
|
// q,k,v: [N, L, n_head, d_head]
|
||||||
// pe: [L, d_head/2, 2, 2]
|
// pe: [L, d_head/2, 2, 2]
|
||||||
// return: [N, L, n_head*d_head]
|
// return: [N, L, n_head*d_head]
|
||||||
q = apply_rope(ctx, q, pe); // [N*n_head, L, d_head]
|
q = apply_rope(ctx, q, pe); // [N*n_head, L, d_head]
|
||||||
k = apply_rope(ctx, k, pe); // [N*n_head, L, d_head]
|
k = apply_rope(ctx, k, pe); // [N*n_head, L, d_head]
|
||||||
|
|
||||||
auto x = ggml_nn_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn); // [N, L, n_head*d_head]
|
auto x = ggml_nn_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn, kv_scale); // [N, L, n_head*d_head]
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
124
ggml_extend.hpp
124
ggml_extend.hpp
@ -56,6 +56,10 @@
|
|||||||
#define __STATIC_INLINE__ static inline
|
#define __STATIC_INLINE__ static inline
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#ifndef SD_UNUSED
|
||||||
|
#define SD_UNUSED(x) (void)(x)
|
||||||
|
#endif
|
||||||
|
|
||||||
__STATIC_INLINE__ void ggml_log_callback_default(ggml_log_level level, const char* text, void*) {
|
__STATIC_INLINE__ void ggml_log_callback_default(ggml_log_level level, const char* text, void*) {
|
||||||
switch (level) {
|
switch (level) {
|
||||||
case GGML_LOG_LEVEL_DEBUG:
|
case GGML_LOG_LEVEL_DEBUG:
|
||||||
@ -939,8 +943,19 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ct
|
|||||||
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_linear(struct ggml_context* ctx,
|
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_linear(struct ggml_context* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* w,
|
struct ggml_tensor* w,
|
||||||
struct ggml_tensor* b) {
|
struct ggml_tensor* b,
|
||||||
|
bool force_prec_f32 = false,
|
||||||
|
float scale = 1.f) {
|
||||||
|
if (scale != 1.f) {
|
||||||
|
x = ggml_scale(ctx, x, scale);
|
||||||
|
}
|
||||||
x = ggml_mul_mat(ctx, w, x);
|
x = ggml_mul_mat(ctx, w, x);
|
||||||
|
if (force_prec_f32) {
|
||||||
|
ggml_mul_mat_set_prec(x, GGML_PREC_F32);
|
||||||
|
}
|
||||||
|
if (scale != 1.f) {
|
||||||
|
x = ggml_scale(ctx, x, 1.f / scale);
|
||||||
|
}
|
||||||
if (b != NULL) {
|
if (b != NULL) {
|
||||||
x = ggml_add_inplace(ctx, x, b);
|
x = ggml_add_inplace(ctx, x, b);
|
||||||
}
|
}
|
||||||
@ -1125,9 +1140,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx
|
|||||||
return kqv;
|
return kqv;
|
||||||
}
|
}
|
||||||
|
|
||||||
// q: [N, L_q, C] or [N*n_head, L_q, d_head]
|
// q: [N, L_q, C(n_head*d_head)] or [N*n_head, L_q, d_head]
|
||||||
// k: [N, L_k, C] or [N*n_head, L_k, d_head]
|
// k: [N, L_k, n_kv_head*d_head] or [N*n_kv_head, L_k, d_head]
|
||||||
// v: [N, L_k, C] or [N, L_k, n_head, d_head]
|
// v: [N, L_k, n_kv_head*d_head] or [N, L_k, n_kv_head, d_head]
|
||||||
// mask: [N, L_q, L_k]
|
// mask: [N, L_q, L_k]
|
||||||
// return: [N, L_q, C]
|
// return: [N, L_q, C]
|
||||||
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* ctx,
|
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* ctx,
|
||||||
@ -1139,33 +1154,38 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
|
|||||||
struct ggml_tensor* mask = NULL,
|
struct ggml_tensor* mask = NULL,
|
||||||
bool diag_mask_inf = false,
|
bool diag_mask_inf = false,
|
||||||
bool skip_reshape = false,
|
bool skip_reshape = false,
|
||||||
bool flash_attn = false) {
|
bool flash_attn = false, // avoid overflow
|
||||||
|
float kv_scale = 1.0f) {
|
||||||
int64_t L_q;
|
int64_t L_q;
|
||||||
int64_t L_k;
|
int64_t L_k;
|
||||||
int64_t C;
|
int64_t C;
|
||||||
int64_t N;
|
int64_t N;
|
||||||
int64_t d_head;
|
int64_t d_head;
|
||||||
|
int64_t n_kv_head;
|
||||||
if (!skip_reshape) {
|
if (!skip_reshape) {
|
||||||
L_q = q->ne[1];
|
L_q = q->ne[1];
|
||||||
L_k = k->ne[1];
|
L_k = k->ne[1];
|
||||||
C = q->ne[0];
|
C = q->ne[0];
|
||||||
N = q->ne[2];
|
N = q->ne[2];
|
||||||
d_head = C / n_head;
|
d_head = C / n_head;
|
||||||
q = ggml_reshape_4d(ctx, q, d_head, n_head, L_q, N); // [N, L_q, n_head, d_head]
|
n_kv_head = k->ne[0] / d_head;
|
||||||
q = ggml_nn_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, L_q, d_head]
|
|
||||||
q = ggml_reshape_3d(ctx, q, d_head, L_q, n_head * N); // [N * n_head, L_q, d_head]
|
|
||||||
|
|
||||||
k = ggml_reshape_4d(ctx, k, d_head, n_head, L_k, N); // [N, L_k, n_head, d_head]
|
q = ggml_reshape_4d(ctx, q, d_head, n_head, L_q, N); // [N, L_q, n_head, d_head]
|
||||||
k = ggml_nn_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, L_k, d_head]
|
q = ggml_nn_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, L_q, d_head]
|
||||||
k = ggml_reshape_3d(ctx, k, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head]
|
q = ggml_reshape_3d(ctx, q, d_head, L_q, n_head * N); // [N * n_head, L_q, d_head]
|
||||||
|
|
||||||
v = ggml_reshape_4d(ctx, v, d_head, n_head, L_k, N); // [N, L_k, n_head, d_head]
|
k = ggml_reshape_4d(ctx, k, d_head, n_kv_head, L_k, N); // [N, L_k, n_kv_head, d_head]
|
||||||
|
k = ggml_nn_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_kv_head, L_k, d_head]
|
||||||
|
k = ggml_reshape_3d(ctx, k, d_head, L_k, n_kv_head * N); // [N * n_kv_head, L_k, d_head]
|
||||||
|
|
||||||
|
v = ggml_reshape_4d(ctx, v, d_head, n_kv_head, L_k, N); // [N, L_k, n_kv_head, d_head]
|
||||||
} else {
|
} else {
|
||||||
L_q = q->ne[1];
|
L_q = q->ne[1];
|
||||||
L_k = k->ne[1];
|
L_k = k->ne[1];
|
||||||
d_head = v->ne[0];
|
d_head = v->ne[0];
|
||||||
N = v->ne[3];
|
N = v->ne[3];
|
||||||
C = d_head * n_head;
|
n_kv_head = k->ne[2] / N;
|
||||||
|
C = d_head * n_head;
|
||||||
}
|
}
|
||||||
|
|
||||||
float scale = (1.0f / sqrt((float)d_head));
|
float scale = (1.0f / sqrt((float)d_head));
|
||||||
@ -1177,13 +1197,19 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
|
|||||||
if (kv_pad != 0) {
|
if (kv_pad != 0) {
|
||||||
k_in = ggml_pad(ctx, k_in, 0, kv_pad, 0, 0);
|
k_in = ggml_pad(ctx, k_in, 0, kv_pad, 0, 0);
|
||||||
}
|
}
|
||||||
|
if (kv_scale != 1.0f) {
|
||||||
|
k_in = ggml_scale(ctx, k_in, kv_scale);
|
||||||
|
}
|
||||||
k_in = ggml_cast(ctx, k_in, GGML_TYPE_F16);
|
k_in = ggml_cast(ctx, k_in, GGML_TYPE_F16);
|
||||||
|
|
||||||
v_in = ggml_nn_cont(ctx, ggml_permute(ctx, v_in, 0, 2, 1, 3));
|
v_in = ggml_nn_cont(ctx, ggml_permute(ctx, v_in, 0, 2, 1, 3));
|
||||||
v_in = ggml_reshape_3d(ctx, v_in, d_head, L_k, n_head * N);
|
v_in = ggml_reshape_3d(ctx, v_in, d_head, L_k, n_kv_head * N);
|
||||||
if (kv_pad != 0) {
|
if (kv_pad != 0) {
|
||||||
v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0);
|
v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0);
|
||||||
}
|
}
|
||||||
|
if (kv_scale != 1.0f) {
|
||||||
|
v_in = ggml_scale(ctx, v_in, kv_scale);
|
||||||
|
}
|
||||||
v_in = ggml_cast(ctx, v_in, GGML_TYPE_F16);
|
v_in = ggml_cast(ctx, v_in, GGML_TYPE_F16);
|
||||||
|
|
||||||
if (mask_in != nullptr) {
|
if (mask_in != nullptr) {
|
||||||
@ -1207,8 +1233,11 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
|
|||||||
mask_in = ggml_cast(ctx, mask_in, GGML_TYPE_F16);
|
mask_in = ggml_cast(ctx, mask_in, GGML_TYPE_F16);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto out = ggml_flash_attn_ext(ctx, q_in, k_in, v_in, mask_in, scale, 0, 0);
|
auto out = ggml_flash_attn_ext(ctx, q_in, k_in, v_in, mask_in, scale / kv_scale, 0, 0);
|
||||||
ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32);
|
ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32);
|
||||||
|
if (kv_scale != 1.0f) {
|
||||||
|
out = ggml_scale(ctx, out, 1.0f / kv_scale);
|
||||||
|
}
|
||||||
return out;
|
return out;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1238,8 +1267,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
|
|||||||
// if (flash_attn) {
|
// if (flash_attn) {
|
||||||
// LOG_DEBUG("fallback to default attention, L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
|
// LOG_DEBUG("fallback to default attention, L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
|
||||||
// }
|
// }
|
||||||
v = ggml_nn_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, L_k]
|
v = ggml_nn_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_kv_head, d_head, L_k]
|
||||||
v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N); // [N * n_head, d_head, L_k]
|
v = ggml_reshape_3d(ctx, v, L_k, d_head, n_kv_head * N); // [N * n_kv_head, d_head, L_k]
|
||||||
|
|
||||||
auto kq = ggml_mul_mat(ctx, k, q); // [N * n_head, L_q, L_k]
|
auto kq = ggml_mul_mat(ctx, k, q); // [N * n_head, L_q, L_k]
|
||||||
kq = ggml_scale_inplace(ctx, kq, scale);
|
kq = ggml_scale_inplace(ctx, kq, scale);
|
||||||
@ -1355,15 +1384,13 @@ __STATIC_INLINE__ std::vector<float> arange(float start, float end, float step =
|
|||||||
// Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
|
// Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
|
||||||
__STATIC_INLINE__ std::vector<float> timestep_embedding(std::vector<float> timesteps,
|
__STATIC_INLINE__ std::vector<float> timestep_embedding(std::vector<float> timesteps,
|
||||||
int dim,
|
int dim,
|
||||||
int max_period = 10000) {
|
int max_period = 10000,
|
||||||
|
bool flip_sin_to_cos = true,
|
||||||
|
float scale = 1.f) {
|
||||||
// timesteps: [N,]
|
// timesteps: [N,]
|
||||||
// embedding: [N, dim]
|
// embedding: [N, dim]
|
||||||
size_t N = timesteps.size();
|
size_t N = timesteps.size();
|
||||||
int acutual_dim = dim;
|
std::vector<float> embedding(N * dim, 0.f);
|
||||||
if (dim % 2 != 0) {
|
|
||||||
acutual_dim = dim + 1;
|
|
||||||
}
|
|
||||||
std::vector<float> embedding(N * acutual_dim, 0.f);
|
|
||||||
int half = dim / 2;
|
int half = dim / 2;
|
||||||
std::vector<float> freqs(half);
|
std::vector<float> freqs(half);
|
||||||
for (int i = 0; i < half; ++i) {
|
for (int i = 0; i < half; ++i) {
|
||||||
@ -1371,9 +1398,14 @@ __STATIC_INLINE__ std::vector<float> timestep_embedding(std::vector<float> times
|
|||||||
}
|
}
|
||||||
for (int i = 0; i < N; ++i) {
|
for (int i = 0; i < N; ++i) {
|
||||||
for (int j = 0; j < half; ++j) {
|
for (int j = 0; j < half; ++j) {
|
||||||
float arg = timesteps[i] * freqs[j];
|
float arg = timesteps[i] * freqs[j] * scale;
|
||||||
embedding[i * acutual_dim + j] = std::cos(arg);
|
if (flip_sin_to_cos) {
|
||||||
embedding[i * acutual_dim + j + half] = std::sin(arg);
|
embedding[i * dim + j] = std::cos(arg);
|
||||||
|
embedding[i * dim + j + half] = std::sin(arg);
|
||||||
|
} else {
|
||||||
|
embedding[i * dim + j] = std::sin(arg);
|
||||||
|
embedding[i * dim + j + half] = std::cos(arg);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return embedding;
|
return embedding;
|
||||||
@ -1394,11 +1426,7 @@ __STATIC_INLINE__ struct ggml_tensor* new_timestep_embedding(struct ggml_context
|
|||||||
// timesteps: [N,]
|
// timesteps: [N,]
|
||||||
// embedding: [N, dim]
|
// embedding: [N, dim]
|
||||||
std::vector<float> embedding_vec = timestep_embedding(timesteps, dim, max_period);
|
std::vector<float> embedding_vec = timestep_embedding(timesteps, dim, max_period);
|
||||||
int acutual_dim = dim;
|
struct ggml_tensor* embedding = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, timesteps.size());
|
||||||
if (dim % 2 != 0) {
|
|
||||||
acutual_dim = dim + 1;
|
|
||||||
}
|
|
||||||
struct ggml_tensor* embedding = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, acutual_dim, timesteps.size());
|
|
||||||
if (embedding->data != NULL) {
|
if (embedding->data != NULL) {
|
||||||
memcpy(((char*)embedding->data), ((char*)embedding_vec.data()), ggml_nbytes(embedding));
|
memcpy(((char*)embedding->data), ((char*)embedding_vec.data()), ggml_nbytes(embedding));
|
||||||
} else {
|
} else {
|
||||||
@ -1940,6 +1968,8 @@ protected:
|
|||||||
int64_t out_features;
|
int64_t out_features;
|
||||||
bool bias;
|
bool bias;
|
||||||
bool force_f32;
|
bool force_f32;
|
||||||
|
bool force_prec_f32;
|
||||||
|
float scale;
|
||||||
|
|
||||||
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
|
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
|
||||||
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
|
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
|
||||||
@ -1956,12 +1986,16 @@ protected:
|
|||||||
public:
|
public:
|
||||||
Linear(int64_t in_features,
|
Linear(int64_t in_features,
|
||||||
int64_t out_features,
|
int64_t out_features,
|
||||||
bool bias = true,
|
bool bias = true,
|
||||||
bool force_f32 = false)
|
bool force_f32 = false,
|
||||||
|
bool force_prec_f32 = false,
|
||||||
|
float scale = 1.f)
|
||||||
: in_features(in_features),
|
: in_features(in_features),
|
||||||
out_features(out_features),
|
out_features(out_features),
|
||||||
bias(bias),
|
bias(bias),
|
||||||
force_f32(force_f32) {}
|
force_f32(force_f32),
|
||||||
|
force_prec_f32(force_prec_f32),
|
||||||
|
scale(scale) {}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
||||||
struct ggml_tensor* w = params["weight"];
|
struct ggml_tensor* w = params["weight"];
|
||||||
@ -1969,7 +2003,7 @@ public:
|
|||||||
if (bias) {
|
if (bias) {
|
||||||
b = params["bias"];
|
b = params["bias"];
|
||||||
}
|
}
|
||||||
return ggml_nn_linear(ctx, x, w, b);
|
return ggml_nn_linear(ctx, x, w, b, force_prec_f32, scale);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
41
model.cpp
41
model.cpp
@ -17,6 +17,7 @@
|
|||||||
#include "stable-diffusion.h"
|
#include "stable-diffusion.h"
|
||||||
#include "util.h"
|
#include "util.h"
|
||||||
#include "vocab.hpp"
|
#include "vocab.hpp"
|
||||||
|
#include "vocab_qwen.hpp"
|
||||||
#include "vocab_umt5.hpp"
|
#include "vocab_umt5.hpp"
|
||||||
|
|
||||||
#include "ggml-alloc.h"
|
#include "ggml-alloc.h"
|
||||||
@ -110,6 +111,9 @@ const char* unused_tensors[] = {
|
|||||||
"embedding_manager",
|
"embedding_manager",
|
||||||
"denoiser.sigmas",
|
"denoiser.sigmas",
|
||||||
"text_encoders.t5xxl.transformer.encoder.embed_tokens.weight", // only used during training
|
"text_encoders.t5xxl.transformer.encoder.embed_tokens.weight", // only used during training
|
||||||
|
"text_encoders.qwen2vl.output.weight",
|
||||||
|
"text_encoders.qwen2vl.lm_head.",
|
||||||
|
"text_encoders.qwen2vl.visual.",
|
||||||
};
|
};
|
||||||
|
|
||||||
bool is_unused_tensor(std::string name) {
|
bool is_unused_tensor(std::string name) {
|
||||||
@ -193,6 +197,21 @@ std::unordered_map<std::string, std::string> pmid_v2_name_map = {
|
|||||||
"pmid.qformer_perceiver.token_proj.fc2.weight"},
|
"pmid.qformer_perceiver.token_proj.fc2.weight"},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
std::unordered_map<std::string, std::string> qwenvl_name_map{
|
||||||
|
{"token_embd.", "model.embed_tokens."},
|
||||||
|
{"blk.", "model.layers."},
|
||||||
|
{"attn_q.", "self_attn.q_proj."},
|
||||||
|
{"attn_k.", "self_attn.k_proj."},
|
||||||
|
{"attn_v.", "self_attn.v_proj."},
|
||||||
|
{"attn_output.", "self_attn.o_proj."},
|
||||||
|
{"attn_norm.", "input_layernorm."},
|
||||||
|
{"ffn_down.", "mlp.down_proj."},
|
||||||
|
{"ffn_gate.", "mlp.gate_proj."},
|
||||||
|
{"ffn_up.", "mlp.up_proj."},
|
||||||
|
{"ffn_norm.", "post_attention_layernorm."},
|
||||||
|
{"output_norm.", "model.norm."},
|
||||||
|
};
|
||||||
|
|
||||||
std::string convert_cond_model_name(const std::string& name) {
|
std::string convert_cond_model_name(const std::string& name) {
|
||||||
std::string new_name = name;
|
std::string new_name = name;
|
||||||
std::string prefix;
|
std::string prefix;
|
||||||
@ -250,6 +269,13 @@ std::string convert_cond_model_name(const std::string& name) {
|
|||||||
if (pos != std::string::npos) {
|
if (pos != std::string::npos) {
|
||||||
new_name.replace(pos, 11, "layer.0.SelfAttention.relative_attention_bias.");
|
new_name.replace(pos, 11, "layer.0.SelfAttention.relative_attention_bias.");
|
||||||
}
|
}
|
||||||
|
} else if (contains(name, "qwen2vl")) {
|
||||||
|
for (auto kv : qwenvl_name_map) {
|
||||||
|
size_t pos = new_name.find(kv.first);
|
||||||
|
if (pos != std::string::npos) {
|
||||||
|
new_name.replace(pos, kv.first.size(), kv.second);
|
||||||
|
}
|
||||||
|
}
|
||||||
} else if (name == "text_encoders.t5xxl.transformer.token_embd.weight") {
|
} else if (name == "text_encoders.t5xxl.transformer.token_embd.weight") {
|
||||||
new_name = "text_encoders.t5xxl.transformer.shared.weight";
|
new_name = "text_encoders.t5xxl.transformer.shared.weight";
|
||||||
}
|
}
|
||||||
@ -580,7 +606,11 @@ std::string convert_tensor_name(std::string name) {
|
|||||||
// name.replace(pos, strlen("lora_B"), "lora_down");
|
// name.replace(pos, strlen("lora_B"), "lora_down");
|
||||||
// }
|
// }
|
||||||
std::string new_name = name;
|
std::string new_name = name;
|
||||||
if (starts_with(name, "cond_stage_model.") || starts_with(name, "conditioner.embedders.") || starts_with(name, "text_encoders.") || ends_with(name, ".vision_model.visual_projection.weight")) {
|
if (starts_with(name, "cond_stage_model.") ||
|
||||||
|
starts_with(name, "conditioner.embedders.") ||
|
||||||
|
starts_with(name, "text_encoders.") ||
|
||||||
|
ends_with(name, ".vision_model.visual_projection.weight") ||
|
||||||
|
starts_with(name, "qwen2vl")) {
|
||||||
new_name = convert_cond_model_name(name);
|
new_name = convert_cond_model_name(name);
|
||||||
} else if (starts_with(name, "first_stage_model.decoder")) {
|
} else if (starts_with(name, "first_stage_model.decoder")) {
|
||||||
new_name = convert_vae_decoder_name(name);
|
new_name = convert_vae_decoder_name(name);
|
||||||
@ -699,6 +729,7 @@ void preprocess_tensor(TensorStorage tensor_storage,
|
|||||||
|
|
||||||
// convert unet transformer linear to conv2d 1x1
|
// convert unet transformer linear to conv2d 1x1
|
||||||
if (starts_with(new_name, "model.diffusion_model.") &&
|
if (starts_with(new_name, "model.diffusion_model.") &&
|
||||||
|
!starts_with(new_name, "model.diffusion_model.proj_out.") &&
|
||||||
(ends_with(new_name, "proj_in.weight") || ends_with(new_name, "proj_out.weight"))) {
|
(ends_with(new_name, "proj_in.weight") || ends_with(new_name, "proj_out.weight"))) {
|
||||||
tensor_storage.unsqueeze();
|
tensor_storage.unsqueeze();
|
||||||
}
|
}
|
||||||
@ -1732,6 +1763,9 @@ SDVersion ModelLoader::get_sd_version() {
|
|||||||
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
|
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
|
||||||
return VERSION_SD3;
|
return VERSION_SD3;
|
||||||
}
|
}
|
||||||
|
if (tensor_storage.name.find("model.diffusion_model.transformer_blocks.0.img_mod.1.weight") != std::string::npos) {
|
||||||
|
return VERSION_QWEN_IMAGE;
|
||||||
|
}
|
||||||
if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) {
|
if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) {
|
||||||
is_wan = true;
|
is_wan = true;
|
||||||
}
|
}
|
||||||
@ -1945,6 +1979,11 @@ std::string ModelLoader::load_merges() {
|
|||||||
return merges_utf8_str;
|
return merges_utf8_str;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string ModelLoader::load_qwen2_merges() {
|
||||||
|
std::string merges_utf8_str(reinterpret_cast<const char*>(qwen2_merges_utf8_c_str), sizeof(qwen2_merges_utf8_c_str));
|
||||||
|
return merges_utf8_str;
|
||||||
|
}
|
||||||
|
|
||||||
std::string ModelLoader::load_t5_tokenizer_json() {
|
std::string ModelLoader::load_t5_tokenizer_json() {
|
||||||
std::string json_str(reinterpret_cast<const char*>(t5_tokenizer_json_str), sizeof(t5_tokenizer_json_str));
|
std::string json_str(reinterpret_cast<const char*>(t5_tokenizer_json_str), sizeof(t5_tokenizer_json_str));
|
||||||
return json_str;
|
return json_str;
|
||||||
|
|||||||
14
model.h
14
model.h
@ -36,6 +36,7 @@ enum SDVersion {
|
|||||||
VERSION_WAN2,
|
VERSION_WAN2,
|
||||||
VERSION_WAN2_2_I2V,
|
VERSION_WAN2_2_I2V,
|
||||||
VERSION_WAN2_2_TI2V,
|
VERSION_WAN2_2_TI2V,
|
||||||
|
VERSION_QWEN_IMAGE,
|
||||||
VERSION_COUNT,
|
VERSION_COUNT,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -81,6 +82,13 @@ static inline bool sd_version_is_wan(SDVersion version) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static inline bool sd_version_is_qwen_image(SDVersion version) {
|
||||||
|
if (version == VERSION_QWEN_IMAGE) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
static inline bool sd_version_is_inpaint(SDVersion version) {
|
static inline bool sd_version_is_inpaint(SDVersion version) {
|
||||||
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2) {
|
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2) {
|
||||||
return true;
|
return true;
|
||||||
@ -89,7 +97,10 @@ static inline bool sd_version_is_inpaint(SDVersion version) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static inline bool sd_version_is_dit(SDVersion version) {
|
static inline bool sd_version_is_dit(SDVersion version) {
|
||||||
if (sd_version_is_flux(version) || sd_version_is_sd3(version) || sd_version_is_wan(version)) {
|
if (sd_version_is_flux(version) ||
|
||||||
|
sd_version_is_sd3(version) ||
|
||||||
|
sd_version_is_wan(version) ||
|
||||||
|
sd_version_is_qwen_image(version)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
@ -272,6 +283,7 @@ public:
|
|||||||
~ModelLoader() = default;
|
~ModelLoader() = default;
|
||||||
|
|
||||||
static std::string load_merges();
|
static std::string load_merges();
|
||||||
|
static std::string load_qwen2_merges();
|
||||||
static std::string load_t5_tokenizer_json();
|
static std::string load_t5_tokenizer_json();
|
||||||
static std::string load_umt5_tokenizer_json();
|
static std::string load_umt5_tokenizer_json();
|
||||||
};
|
};
|
||||||
|
|||||||
643
qwen_image.hpp
Normal file
643
qwen_image.hpp
Normal file
@ -0,0 +1,643 @@
|
|||||||
|
#ifndef __QWEN_IMAGE_HPP__
|
||||||
|
#define __QWEN_IMAGE_HPP__
|
||||||
|
|
||||||
|
#include "common.hpp"
|
||||||
|
#include "flux.hpp"
|
||||||
|
#include "ggml_extend.hpp"
|
||||||
|
|
||||||
|
namespace Qwen {
|
||||||
|
constexpr int QWEN_IMAGE_GRAPH_SIZE = 20480;
|
||||||
|
|
||||||
|
struct TimestepEmbedding : public GGMLBlock {
|
||||||
|
public:
|
||||||
|
TimestepEmbedding(int64_t in_channels,
|
||||||
|
int64_t time_embed_dim,
|
||||||
|
int64_t out_dim = 0,
|
||||||
|
int64_t cond_proj_dim = 0,
|
||||||
|
bool sample_proj_bias = true) {
|
||||||
|
blocks["linear_1"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, time_embed_dim, sample_proj_bias));
|
||||||
|
if (cond_proj_dim > 0) {
|
||||||
|
blocks["cond_proj"] = std::shared_ptr<GGMLBlock>(new Linear(cond_proj_dim, in_channels, false));
|
||||||
|
}
|
||||||
|
if (out_dim <= 0) {
|
||||||
|
out_dim = time_embed_dim;
|
||||||
|
}
|
||||||
|
blocks["linear_2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, out_dim, sample_proj_bias));
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* sample,
|
||||||
|
struct ggml_tensor* condition = nullptr) {
|
||||||
|
if (condition != nullptr) {
|
||||||
|
auto cond_proj = std::dynamic_pointer_cast<Linear>(blocks["cond_proj"]);
|
||||||
|
sample = ggml_add(ctx, sample, cond_proj->forward(ctx, condition));
|
||||||
|
}
|
||||||
|
auto linear_1 = std::dynamic_pointer_cast<Linear>(blocks["linear_1"]);
|
||||||
|
auto linear_2 = std::dynamic_pointer_cast<Linear>(blocks["linear_2"]);
|
||||||
|
|
||||||
|
sample = linear_1->forward(ctx, sample);
|
||||||
|
sample = ggml_silu_inplace(ctx, sample);
|
||||||
|
sample = linear_2->forward(ctx, sample);
|
||||||
|
return sample;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct QwenTimestepProjEmbeddings : public GGMLBlock {
|
||||||
|
public:
|
||||||
|
QwenTimestepProjEmbeddings(int64_t embedding_dim) {
|
||||||
|
blocks["timestep_embedder"] = std::shared_ptr<GGMLBlock>(new TimestepEmbedding(256, embedding_dim));
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* timesteps) {
|
||||||
|
// timesteps: [N,]
|
||||||
|
// return: [N, embedding_dim]
|
||||||
|
auto timestep_embedder = std::dynamic_pointer_cast<TimestepEmbedding>(blocks["timestep_embedder"]);
|
||||||
|
|
||||||
|
auto timesteps_proj = ggml_nn_timestep_embedding(ctx, timesteps, 256, 10000, 1.f);
|
||||||
|
auto timesteps_emb = timestep_embedder->forward(ctx, timesteps_proj);
|
||||||
|
return timesteps_emb;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct QwenImageAttention : public GGMLBlock {
|
||||||
|
protected:
|
||||||
|
int64_t dim_head;
|
||||||
|
bool flash_attn;
|
||||||
|
|
||||||
|
public:
|
||||||
|
QwenImageAttention(int64_t query_dim,
|
||||||
|
int64_t dim_head,
|
||||||
|
int64_t num_heads,
|
||||||
|
int64_t out_dim = 0,
|
||||||
|
int64_t out_context_dim = 0,
|
||||||
|
bool bias = true,
|
||||||
|
bool out_bias = true,
|
||||||
|
float eps = 1e-6,
|
||||||
|
bool flash_attn = false)
|
||||||
|
: dim_head(dim_head), flash_attn(flash_attn) {
|
||||||
|
int64_t inner_dim = out_dim > 0 ? out_dim : dim_head * num_heads;
|
||||||
|
out_dim = out_dim > 0 ? out_dim : query_dim;
|
||||||
|
out_context_dim = out_context_dim > 0 ? out_context_dim : query_dim;
|
||||||
|
|
||||||
|
blocks["to_q"] = std::shared_ptr<GGMLBlock>(new Linear(query_dim, inner_dim, bias));
|
||||||
|
blocks["to_k"] = std::shared_ptr<GGMLBlock>(new Linear(query_dim, inner_dim, bias));
|
||||||
|
blocks["to_v"] = std::shared_ptr<GGMLBlock>(new Linear(query_dim, inner_dim, bias));
|
||||||
|
|
||||||
|
blocks["norm_q"] = std::shared_ptr<GGMLBlock>(new RMSNorm(dim_head, eps));
|
||||||
|
blocks["norm_k"] = std::shared_ptr<GGMLBlock>(new RMSNorm(dim_head, eps));
|
||||||
|
|
||||||
|
blocks["add_q_proj"] = std::shared_ptr<GGMLBlock>(new Linear(query_dim, inner_dim, bias));
|
||||||
|
blocks["add_k_proj"] = std::shared_ptr<GGMLBlock>(new Linear(query_dim, inner_dim, bias));
|
||||||
|
blocks["add_v_proj"] = std::shared_ptr<GGMLBlock>(new Linear(query_dim, inner_dim, bias));
|
||||||
|
|
||||||
|
blocks["norm_added_q"] = std::shared_ptr<GGMLBlock>(new RMSNorm(dim_head, eps));
|
||||||
|
blocks["norm_added_k"] = std::shared_ptr<GGMLBlock>(new RMSNorm(dim_head, eps));
|
||||||
|
|
||||||
|
blocks["to_out.0"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, out_dim, out_bias));
|
||||||
|
// to_out.1 is nn.Dropout
|
||||||
|
|
||||||
|
float scale = 1.f / 32.f;
|
||||||
|
// The purpose of the scale here is to prevent NaN issues in certain situations.
|
||||||
|
// For example when using CUDA but the weights are k-quants (not all prompts).
|
||||||
|
blocks["to_add_out"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, out_context_dim, out_bias, false, false, scale));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<ggml_tensor*, ggml_tensor*> forward(struct ggml_context* ctx,
|
||||||
|
ggml_backend_t backend,
|
||||||
|
struct ggml_tensor* img,
|
||||||
|
struct ggml_tensor* txt,
|
||||||
|
struct ggml_tensor* pe,
|
||||||
|
struct ggml_tensor* mask = nullptr) {
|
||||||
|
// img: [N, n_img_token, hidden_size]
|
||||||
|
// txt: [N, n_txt_token, hidden_size]
|
||||||
|
// pe: [n_img_token + n_txt_token, d_head/2, 2, 2]
|
||||||
|
// return: ([N, n_img_token, hidden_size], [N, n_txt_token, hidden_size])
|
||||||
|
|
||||||
|
auto norm_q = std::dynamic_pointer_cast<UnaryBlock>(blocks["norm_q"]);
|
||||||
|
auto norm_k = std::dynamic_pointer_cast<UnaryBlock>(blocks["norm_k"]);
|
||||||
|
|
||||||
|
auto to_q = std::dynamic_pointer_cast<Linear>(blocks["to_q"]);
|
||||||
|
auto to_k = std::dynamic_pointer_cast<Linear>(blocks["to_k"]);
|
||||||
|
auto to_v = std::dynamic_pointer_cast<Linear>(blocks["to_v"]);
|
||||||
|
auto to_out_0 = std::dynamic_pointer_cast<Linear>(blocks["to_out.0"]);
|
||||||
|
|
||||||
|
auto norm_added_q = std::dynamic_pointer_cast<UnaryBlock>(blocks["norm_added_q"]);
|
||||||
|
auto norm_added_k = std::dynamic_pointer_cast<UnaryBlock>(blocks["norm_added_k"]);
|
||||||
|
|
||||||
|
auto add_q_proj = std::dynamic_pointer_cast<Linear>(blocks["add_q_proj"]);
|
||||||
|
auto add_k_proj = std::dynamic_pointer_cast<Linear>(blocks["add_k_proj"]);
|
||||||
|
auto add_v_proj = std::dynamic_pointer_cast<Linear>(blocks["add_v_proj"]);
|
||||||
|
auto to_add_out = std::dynamic_pointer_cast<Linear>(blocks["to_add_out"]);
|
||||||
|
|
||||||
|
int64_t N = img->ne[2];
|
||||||
|
int64_t n_img_token = img->ne[1];
|
||||||
|
int64_t n_txt_token = txt->ne[1];
|
||||||
|
|
||||||
|
auto img_q = to_q->forward(ctx, img);
|
||||||
|
int64_t num_heads = img_q->ne[0] / dim_head;
|
||||||
|
img_q = ggml_reshape_4d(ctx, img_q, dim_head, num_heads, n_img_token, N); // [N, n_img_token, n_head, d_head]
|
||||||
|
auto img_k = to_k->forward(ctx, img);
|
||||||
|
img_k = ggml_reshape_4d(ctx, img_k, dim_head, num_heads, n_img_token, N); // [N, n_img_token, n_head, d_head]
|
||||||
|
auto img_v = to_v->forward(ctx, img);
|
||||||
|
img_v = ggml_reshape_4d(ctx, img_v, dim_head, num_heads, n_img_token, N); // [N, n_img_token, n_head, d_head]
|
||||||
|
|
||||||
|
img_q = norm_q->forward(ctx, img_q);
|
||||||
|
img_k = norm_k->forward(ctx, img_k);
|
||||||
|
|
||||||
|
auto txt_q = add_q_proj->forward(ctx, txt);
|
||||||
|
txt_q = ggml_reshape_4d(ctx, txt_q, dim_head, num_heads, n_txt_token, N); // [N, n_txt_token, n_head, d_head]
|
||||||
|
auto txt_k = add_k_proj->forward(ctx, txt);
|
||||||
|
txt_k = ggml_reshape_4d(ctx, txt_k, dim_head, num_heads, n_txt_token, N); // [N, n_txt_token, n_head, d_head]
|
||||||
|
auto txt_v = add_v_proj->forward(ctx, txt);
|
||||||
|
txt_v = ggml_reshape_4d(ctx, txt_v, dim_head, num_heads, n_txt_token, N); // [N, n_txt_token, n_head, d_head]
|
||||||
|
|
||||||
|
txt_q = norm_added_q->forward(ctx, txt_q);
|
||||||
|
txt_k = norm_added_k->forward(ctx, txt_k);
|
||||||
|
|
||||||
|
auto q = ggml_concat(ctx, txt_q, img_q, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
|
||||||
|
auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
|
||||||
|
auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
|
||||||
|
|
||||||
|
auto attn = Flux::attention(ctx, backend, q, k, v, pe, mask, flash_attn, (1.0f / 128.f)); // [N, n_txt_token + n_img_token, n_head*d_head]
|
||||||
|
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
|
||||||
|
auto txt_attn_out = ggml_view_3d(ctx,
|
||||||
|
attn,
|
||||||
|
attn->ne[0],
|
||||||
|
attn->ne[1],
|
||||||
|
txt->ne[1],
|
||||||
|
attn->nb[1],
|
||||||
|
attn->nb[2],
|
||||||
|
0); // [n_txt_token, N, hidden_size]
|
||||||
|
txt_attn_out = ggml_cont(ctx, ggml_permute(ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size]
|
||||||
|
auto img_attn_out = ggml_view_3d(ctx,
|
||||||
|
attn,
|
||||||
|
attn->ne[0],
|
||||||
|
attn->ne[1],
|
||||||
|
img->ne[1],
|
||||||
|
attn->nb[1],
|
||||||
|
attn->nb[2],
|
||||||
|
attn->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size]
|
||||||
|
img_attn_out = ggml_cont(ctx, ggml_permute(ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
|
||||||
|
|
||||||
|
img_attn_out = to_out_0->forward(ctx, img_attn_out);
|
||||||
|
txt_attn_out = to_add_out->forward(ctx, txt_attn_out);
|
||||||
|
|
||||||
|
return {img_attn_out, txt_attn_out};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class QwenImageTransformerBlock : public GGMLBlock {
|
||||||
|
public:
|
||||||
|
QwenImageTransformerBlock(int64_t dim,
|
||||||
|
int64_t num_attention_heads,
|
||||||
|
int64_t attention_head_dim,
|
||||||
|
float eps = 1e-6,
|
||||||
|
bool flash_attn = false) {
|
||||||
|
// img_mod.0 is nn.SiLU()
|
||||||
|
blocks["img_mod.1"] = std::shared_ptr<GGMLBlock>(new Linear(dim, 6 * dim, true));
|
||||||
|
|
||||||
|
blocks["img_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
|
||||||
|
blocks["img_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
|
||||||
|
blocks["img_mlp"] = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim, 4, FeedForward::Activation::GELU, true));
|
||||||
|
|
||||||
|
// txt_mod.0 is nn.SiLU()
|
||||||
|
blocks["txt_mod.1"] = std::shared_ptr<GGMLBlock>(new Linear(dim, 6 * dim, true));
|
||||||
|
|
||||||
|
blocks["txt_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
|
||||||
|
blocks["txt_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
|
||||||
|
blocks["txt_mlp"] = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim, 4, FeedForward::Activation::GELU));
|
||||||
|
|
||||||
|
blocks["attn"] = std::shared_ptr<GGMLBlock>(new QwenImageAttention(dim,
|
||||||
|
attention_head_dim,
|
||||||
|
num_attention_heads,
|
||||||
|
0, // out_dim
|
||||||
|
0, // out_context-dim
|
||||||
|
true, // bias
|
||||||
|
true, // out_bias
|
||||||
|
eps,
|
||||||
|
flash_attn));
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual std::pair<ggml_tensor*, ggml_tensor*> forward(struct ggml_context* ctx,
|
||||||
|
ggml_backend_t backend,
|
||||||
|
struct ggml_tensor* img,
|
||||||
|
struct ggml_tensor* txt,
|
||||||
|
struct ggml_tensor* t_emb,
|
||||||
|
struct ggml_tensor* pe) {
|
||||||
|
// img: [N, n_img_token, hidden_size]
|
||||||
|
// txt: [N, n_txt_token, hidden_size]
|
||||||
|
// pe: [n_img_token + n_txt_token, d_head/2, 2, 2]
|
||||||
|
// return: ([N, n_img_token, hidden_size], [N, n_txt_token, hidden_size])
|
||||||
|
|
||||||
|
auto img_mod_1 = std::dynamic_pointer_cast<Linear>(blocks["img_mod.1"]);
|
||||||
|
auto img_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["img_norm1"]);
|
||||||
|
auto img_norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["img_norm2"]);
|
||||||
|
auto img_mlp = std::dynamic_pointer_cast<FeedForward>(blocks["img_mlp"]);
|
||||||
|
|
||||||
|
auto txt_mod_1 = std::dynamic_pointer_cast<Linear>(blocks["txt_mod.1"]);
|
||||||
|
auto txt_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["txt_norm1"]);
|
||||||
|
auto txt_norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["txt_norm2"]);
|
||||||
|
auto txt_mlp = std::dynamic_pointer_cast<FeedForward>(blocks["txt_mlp"]);
|
||||||
|
|
||||||
|
auto attn = std::dynamic_pointer_cast<QwenImageAttention>(blocks["attn"]);
|
||||||
|
|
||||||
|
auto img_mod_params = ggml_silu(ctx, t_emb);
|
||||||
|
img_mod_params = img_mod_1->forward(ctx, img_mod_params);
|
||||||
|
auto img_mod_param_vec = ggml_chunk(ctx, img_mod_params, 6, 0);
|
||||||
|
|
||||||
|
auto txt_mod_params = ggml_silu(ctx, t_emb);
|
||||||
|
txt_mod_params = txt_mod_1->forward(ctx, txt_mod_params);
|
||||||
|
auto txt_mod_param_vec = ggml_chunk(ctx, txt_mod_params, 6, 0);
|
||||||
|
|
||||||
|
auto img_normed = img_norm1->forward(ctx, img);
|
||||||
|
auto img_modulated = Flux::modulate(ctx, img_normed, img_mod_param_vec[0], img_mod_param_vec[1]);
|
||||||
|
auto img_gate1 = img_mod_param_vec[2];
|
||||||
|
|
||||||
|
auto txt_normed = txt_norm1->forward(ctx, txt);
|
||||||
|
auto txt_modulated = Flux::modulate(ctx, txt_normed, txt_mod_param_vec[0], txt_mod_param_vec[1]);
|
||||||
|
auto txt_gate1 = txt_mod_param_vec[2];
|
||||||
|
|
||||||
|
auto [img_attn_output, txt_attn_output] = attn->forward(ctx, backend, img_modulated, txt_modulated, pe);
|
||||||
|
|
||||||
|
img = ggml_add(ctx, img, ggml_mul(ctx, img_attn_output, img_gate1));
|
||||||
|
txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_attn_output, txt_gate1));
|
||||||
|
|
||||||
|
auto img_normed2 = img_norm2->forward(ctx, img);
|
||||||
|
auto img_modulated2 = Flux::modulate(ctx, img_normed2, img_mod_param_vec[3], img_mod_param_vec[4]);
|
||||||
|
auto img_gate2 = img_mod_param_vec[5];
|
||||||
|
|
||||||
|
auto txt_normed2 = txt_norm2->forward(ctx, txt);
|
||||||
|
auto txt_modulated2 = Flux::modulate(ctx, txt_normed2, txt_mod_param_vec[3], txt_mod_param_vec[4]);
|
||||||
|
auto txt_gate2 = txt_mod_param_vec[5];
|
||||||
|
|
||||||
|
auto img_mlp_out = img_mlp->forward(ctx, img_modulated2);
|
||||||
|
auto txt_mlp_out = txt_mlp->forward(ctx, txt_modulated2);
|
||||||
|
|
||||||
|
img = ggml_add(ctx, img, ggml_mul(ctx, img_mlp_out, img_gate2));
|
||||||
|
txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_mlp_out, txt_gate2));
|
||||||
|
|
||||||
|
return {img, txt};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct AdaLayerNormContinuous : public GGMLBlock {
|
||||||
|
public:
|
||||||
|
AdaLayerNormContinuous(int64_t embedding_dim,
|
||||||
|
int64_t conditioning_embedding_dim,
|
||||||
|
bool elementwise_affine = true,
|
||||||
|
float eps = 1e-5f,
|
||||||
|
bool bias = true) {
|
||||||
|
blocks["norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(conditioning_embedding_dim, eps, elementwise_affine, bias));
|
||||||
|
blocks["linear"] = std::shared_ptr<GGMLBlock>(new Linear(conditioning_embedding_dim, embedding_dim * 2, bias));
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
struct ggml_tensor* c) {
|
||||||
|
// x: [N, n_token, hidden_size]
|
||||||
|
// c: [N, hidden_size]
|
||||||
|
// return: [N, n_token, patch_size * patch_size * out_channels]
|
||||||
|
|
||||||
|
auto norm = std::dynamic_pointer_cast<LayerNorm>(blocks["norm"]);
|
||||||
|
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
|
||||||
|
|
||||||
|
auto emb = linear->forward(ctx, ggml_silu(ctx, c));
|
||||||
|
auto mods = ggml_chunk(ctx, emb, 2, 0);
|
||||||
|
auto scale = mods[0];
|
||||||
|
auto shift = mods[1];
|
||||||
|
|
||||||
|
x = norm->forward(ctx, x);
|
||||||
|
x = Flux::modulate(ctx, x, shift, scale);
|
||||||
|
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Hyper-parameters of the Qwen Image diffusion transformer (defaults below).
struct QwenImageParams {
    int64_t patch_size{2};              // side length of a square latent patch
    int64_t in_channels{64};            // input width of the img_in projection
    int64_t out_channels{16};           // channels produced per output pixel
    int64_t num_layers{60};             // number of transformer blocks
    int64_t attention_head_dim{128};    // width of one attention head
    int64_t num_attention_heads{24};    // heads per attention layer
    int64_t joint_attention_dim{3584};  // width of the text-context features
    float theta{10000.0f};              // passed to Rope::gen_qwen_image_pe
    std::vector<int> axes_dim{16, 56, 56};
    int64_t axes_dim_sum{128};  // 16 + 56 + 56
    bool flash_attn{false};
};
|
||||||
|
|
||||||
|
class QwenImageModel : public GGMLBlock {
|
||||||
|
protected:
|
||||||
|
QwenImageParams params;
|
||||||
|
|
||||||
|
public:
|
||||||
|
QwenImageModel() {}
|
||||||
|
QwenImageModel(QwenImageParams params)
    : params(params) {
    // Registers every sub-block under the name used to look it up in the
    // forward pass ("transformer_blocks.<i>" for the repeated blocks).
    const int64_t hidden = params.num_attention_heads * params.attention_head_dim;

    blocks["time_text_embed"] = std::make_shared<QwenTimestepProjEmbeddings>(hidden);
    blocks["txt_norm"]        = std::make_shared<RMSNorm>(params.joint_attention_dim, 1e-6f);
    blocks["img_in"]          = std::make_shared<Linear>(params.in_channels, hidden);
    blocks["txt_in"]          = std::make_shared<Linear>(params.joint_attention_dim, hidden);

    // blocks
    for (int layer = 0; layer < params.num_layers; ++layer) {
        const std::string key = "transformer_blocks." + std::to_string(layer);
        blocks[key]           = std::make_shared<QwenImageTransformerBlock>(hidden,
                                                                  params.num_attention_heads,
                                                                  params.attention_head_dim,
                                                                  1e-6f,
                                                                  params.flash_attn);
    }

    const int64_t patch_values = params.patch_size * params.patch_size * params.out_channels;
    blocks["norm_out"]         = std::make_shared<AdaLayerNormContinuous>(hidden, hidden, false, 1e-6f);
    blocks["proj_out"]         = std::make_shared<Linear>(hidden, patch_values);
}
|
||||||
|
|
||||||
|
struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x) {
|
||||||
|
int64_t W = x->ne[0];
|
||||||
|
int64_t H = x->ne[1];
|
||||||
|
|
||||||
|
int pad_h = (params.patch_size - H % params.patch_size) % params.patch_size;
|
||||||
|
int pad_w = (params.patch_size - W % params.patch_size) % params.patch_size;
|
||||||
|
x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w]
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* patchify(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x) {
|
||||||
|
// x: [N, C, H, W]
|
||||||
|
// return: [N, h*w, C * patch_size * patch_size]
|
||||||
|
int64_t N = x->ne[3];
|
||||||
|
int64_t C = x->ne[2];
|
||||||
|
int64_t H = x->ne[1];
|
||||||
|
int64_t W = x->ne[0];
|
||||||
|
int64_t p = params.patch_size;
|
||||||
|
int64_t h = H / params.patch_size;
|
||||||
|
int64_t w = W / params.patch_size;
|
||||||
|
|
||||||
|
GGML_ASSERT(h * p == H && w * p == W);
|
||||||
|
|
||||||
|
x = ggml_reshape_4d(ctx, x, p, w, p, h * C * N); // [N*C*h, p, w, p]
|
||||||
|
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, w, p, p]
|
||||||
|
x = ggml_reshape_4d(ctx, x, p * p, w * h, C, N); // [N, C, h*w, p*p]
|
||||||
|
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, h*w, C, p*p]
|
||||||
|
x = ggml_reshape_3d(ctx, x, p * p * C, w * h, N); // [N, h*w, C*p*p]
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* unpatchify(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
int64_t h,
|
||||||
|
int64_t w) {
|
||||||
|
// x: [N, h*w, C*patch_size*patch_size]
|
||||||
|
// return: [N, C, H, W]
|
||||||
|
int64_t N = x->ne[2];
|
||||||
|
int64_t C = x->ne[0] / params.patch_size / params.patch_size;
|
||||||
|
int64_t H = h * params.patch_size;
|
||||||
|
int64_t W = w * params.patch_size;
|
||||||
|
int64_t p = params.patch_size;
|
||||||
|
|
||||||
|
GGML_ASSERT(C * p * p == x->ne[0]);
|
||||||
|
|
||||||
|
x = ggml_reshape_4d(ctx, x, p * p, C, w * h, N); // [N, h*w, C, p*p]
|
||||||
|
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, C, h*w, p*p]
|
||||||
|
x = ggml_reshape_4d(ctx, x, p, p, w, h * C * N); // [N*C*h, w, p, p]
|
||||||
|
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, p, w, p]
|
||||||
|
x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, h*p, w*p]
|
||||||
|
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward_orig(struct ggml_context* ctx,
|
||||||
|
ggml_backend_t backend,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
struct ggml_tensor* timestep,
|
||||||
|
struct ggml_tensor* context,
|
||||||
|
struct ggml_tensor* pe) {
|
||||||
|
auto time_text_embed = std::dynamic_pointer_cast<QwenTimestepProjEmbeddings>(blocks["time_text_embed"]);
|
||||||
|
auto txt_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["txt_norm"]);
|
||||||
|
auto img_in = std::dynamic_pointer_cast<Linear>(blocks["img_in"]);
|
||||||
|
auto txt_in = std::dynamic_pointer_cast<Linear>(blocks["txt_in"]);
|
||||||
|
auto norm_out = std::dynamic_pointer_cast<AdaLayerNormContinuous>(blocks["norm_out"]);
|
||||||
|
auto proj_out = std::dynamic_pointer_cast<Linear>(blocks["proj_out"]);
|
||||||
|
|
||||||
|
auto t_emb = time_text_embed->forward(ctx, timestep);
|
||||||
|
auto img = img_in->forward(ctx, x);
|
||||||
|
auto txt = txt_norm->forward(ctx, context);
|
||||||
|
txt = txt_in->forward(ctx, txt);
|
||||||
|
|
||||||
|
for (int i = 0; i < params.num_layers; i++) {
|
||||||
|
auto block = std::dynamic_pointer_cast<QwenImageTransformerBlock>(blocks["transformer_blocks." + std::to_string(i)]);
|
||||||
|
|
||||||
|
auto result = block->forward(ctx, backend, img, txt, t_emb, pe);
|
||||||
|
img = result.first;
|
||||||
|
txt = result.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
img = norm_out->forward(ctx, img, t_emb);
|
||||||
|
img = proj_out->forward(ctx, img);
|
||||||
|
|
||||||
|
return img;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||||
|
ggml_backend_t backend,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
struct ggml_tensor* timestep,
|
||||||
|
struct ggml_tensor* context,
|
||||||
|
struct ggml_tensor* pe) {
|
||||||
|
// Forward pass of DiT.
|
||||||
|
// x: [N, C, H, W]
|
||||||
|
// timestep: [N,]
|
||||||
|
// context: [N, L, D]
|
||||||
|
// pe: [L, d_head/2, 2, 2]
|
||||||
|
// return: [N, C, H, W]
|
||||||
|
|
||||||
|
int64_t W = x->ne[0];
|
||||||
|
int64_t H = x->ne[1];
|
||||||
|
int64_t C = x->ne[2];
|
||||||
|
int64_t N = x->ne[3];
|
||||||
|
|
||||||
|
x = pad_to_patch_size(ctx, x);
|
||||||
|
x = patchify(ctx, x);
|
||||||
|
|
||||||
|
int64_t h_len = ((H + (params.patch_size / 2)) / params.patch_size);
|
||||||
|
int64_t w_len = ((W + (params.patch_size / 2)) / params.patch_size);
|
||||||
|
|
||||||
|
auto out = forward_orig(ctx, backend, x, timestep, context, pe); // [N, h_len*w_len, ph*pw*C]
|
||||||
|
|
||||||
|
out = unpatchify(ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w]
|
||||||
|
|
||||||
|
// slice
|
||||||
|
out = ggml_slice(ctx, out, 1, 0, H); // [N, C, H, W + pad_w]
|
||||||
|
out = ggml_slice(ctx, out, 0, 0, W); // [N, C, H, W]
|
||||||
|
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct QwenImageRunner : public GGMLRunner {
|
||||||
|
public:
|
||||||
|
QwenImageParams qwen_image_params;
|
||||||
|
QwenImageModel qwen_image;
|
||||||
|
std::vector<float> pe_vec;
|
||||||
|
SDVersion version;
|
||||||
|
|
||||||
|
QwenImageRunner(ggml_backend_t backend,
|
||||||
|
bool offload_params_to_cpu,
|
||||||
|
const String2GGMLType& tensor_types = {},
|
||||||
|
const std::string prefix = "",
|
||||||
|
SDVersion version = VERSION_QWEN_IMAGE,
|
||||||
|
bool flash_attn = false)
|
||||||
|
: GGMLRunner(backend, offload_params_to_cpu) {
|
||||||
|
qwen_image_params.flash_attn = flash_attn;
|
||||||
|
qwen_image = QwenImageModel(qwen_image_params);
|
||||||
|
qwen_image.init(params_ctx, tensor_types, prefix);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string get_desc() {
    // Short identifier for this runner.
    static const char* const kDescription = "qwen_image";
    return std::string(kDescription);
}
|
||||||
|
|
||||||
|
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
    // Delegates to the model, which fills `tensors` with its parameter
    // tensors using `prefix` for the names.
    auto& model = qwen_image;
    model.get_param_tensors(tensors, prefix);
}
|
||||||
|
|
||||||
|
struct ggml_cgraph* build_graph(struct ggml_tensor* x,
|
||||||
|
struct ggml_tensor* timesteps,
|
||||||
|
struct ggml_tensor* context) {
|
||||||
|
GGML_ASSERT(x->ne[3] == 1);
|
||||||
|
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, QWEN_IMAGE_GRAPH_SIZE, false);
|
||||||
|
|
||||||
|
x = to_backend(x);
|
||||||
|
context = to_backend(context);
|
||||||
|
timesteps = to_backend(timesteps);
|
||||||
|
|
||||||
|
pe_vec = Rope::gen_qwen_image_pe(x->ne[1],
|
||||||
|
x->ne[0],
|
||||||
|
qwen_image_params.patch_size,
|
||||||
|
x->ne[3],
|
||||||
|
context->ne[1],
|
||||||
|
qwen_image_params.theta,
|
||||||
|
qwen_image_params.axes_dim);
|
||||||
|
int pos_len = pe_vec.size() / qwen_image_params.axes_dim_sum / 2;
|
||||||
|
// LOG_DEBUG("pos_len %d", pos_len);
|
||||||
|
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, qwen_image_params.axes_dim_sum / 2, pos_len);
|
||||||
|
// pe->data = pe_vec.data();
|
||||||
|
// print_ggml_tensor(pe);
|
||||||
|
// pe->data = NULL;
|
||||||
|
set_backend_tensor_data(pe, pe_vec.data());
|
||||||
|
|
||||||
|
struct ggml_tensor* out = qwen_image.forward(compute_ctx,
|
||||||
|
runtime_backend,
|
||||||
|
x,
|
||||||
|
timesteps,
|
||||||
|
context,
|
||||||
|
pe);
|
||||||
|
|
||||||
|
ggml_build_forward_expand(gf, out);
|
||||||
|
|
||||||
|
return gf;
|
||||||
|
}
|
||||||
|
|
||||||
|
void compute(int n_threads,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
struct ggml_tensor* timesteps,
|
||||||
|
struct ggml_tensor* context,
|
||||||
|
struct ggml_tensor** output = NULL,
|
||||||
|
struct ggml_context* output_ctx = NULL) {
|
||||||
|
// x: [N, in_channels, h, w]
|
||||||
|
// timesteps: [N, ]
|
||||||
|
// context: [N, max_position, hidden_size]
|
||||||
|
auto get_graph = [&]() -> struct ggml_cgraph* {
|
||||||
|
return build_graph(x, timesteps, context);
|
||||||
|
};
|
||||||
|
|
||||||
|
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
void test() {
|
||||||
|
struct ggml_init_params params;
|
||||||
|
params.mem_size = static_cast<size_t>(1024 * 1024) * 1024; // 1GB
|
||||||
|
params.mem_buffer = NULL;
|
||||||
|
params.no_alloc = false;
|
||||||
|
|
||||||
|
struct ggml_context* work_ctx = ggml_init(params);
|
||||||
|
GGML_ASSERT(work_ctx != NULL);
|
||||||
|
|
||||||
|
{
|
||||||
|
// auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 16, 16, 16, 1);
|
||||||
|
// ggml_set_f32(x, 0.01f);
|
||||||
|
auto x = load_tensor_from_file(work_ctx, "./qwen_image_x.bin");
|
||||||
|
print_ggml_tensor(x);
|
||||||
|
|
||||||
|
std::vector<float> timesteps_vec(1, 1000.f);
|
||||||
|
auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
|
||||||
|
|
||||||
|
// auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 3584, 256, 1);
|
||||||
|
// ggml_set_f32(context, 0.01f);
|
||||||
|
auto context = load_tensor_from_file(work_ctx, "./qwen_image_context.bin");
|
||||||
|
print_ggml_tensor(context);
|
||||||
|
|
||||||
|
struct ggml_tensor* out = NULL;
|
||||||
|
|
||||||
|
int t0 = ggml_time_ms();
|
||||||
|
compute(8, x, timesteps, context, &out, work_ctx);
|
||||||
|
int t1 = ggml_time_ms();
|
||||||
|
|
||||||
|
print_ggml_tensor(out);
|
||||||
|
LOG_DEBUG("qwen_image test done in %dms", t1 - t0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void load_from_file_and_test(const std::string& file_path) {
|
||||||
|
// cuda q8: pass
|
||||||
|
// cuda q8 fa: nan
|
||||||
|
// ggml_backend_t backend = ggml_backend_cuda_init(0);
|
||||||
|
ggml_backend_t backend = ggml_backend_cpu_init();
|
||||||
|
ggml_type model_data_type = GGML_TYPE_Q8_0;
|
||||||
|
|
||||||
|
ModelLoader model_loader;
|
||||||
|
if (!model_loader.init_from_file(file_path, "model.diffusion_model.")) {
|
||||||
|
LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto tensor_types = model_loader.tensor_storages_types;
|
||||||
|
for (auto& item : tensor_types) {
|
||||||
|
// LOG_DEBUG("%s %u", item.first.c_str(), item.second);
|
||||||
|
if (ends_with(item.first, "weight")) {
|
||||||
|
item.second = model_data_type;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<QwenImageRunner> qwen_image = std::shared_ptr<QwenImageRunner>(new QwenImageRunner(backend,
|
||||||
|
false,
|
||||||
|
tensor_types,
|
||||||
|
"model.diffusion_model",
|
||||||
|
VERSION_QWEN_IMAGE,
|
||||||
|
true));
|
||||||
|
|
||||||
|
qwen_image->alloc_params_buffer();
|
||||||
|
std::map<std::string, ggml_tensor*> tensors;
|
||||||
|
qwen_image->get_param_tensors(tensors, "model.diffusion_model");
|
||||||
|
|
||||||
|
bool success = model_loader.load_tensors(tensors);
|
||||||
|
|
||||||
|
if (!success) {
|
||||||
|
LOG_ERROR("load tensors from model loader failed");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_INFO("qwen_image model loaded");
|
||||||
|
qwen_image->test();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace name
|
||||||
|
|
||||||
|
#endif // __QWEN_IMAGE_HPP__
|
||||||
731
qwenvl.hpp
Normal file
731
qwenvl.hpp
Normal file
@ -0,0 +1,731 @@
|
|||||||
|
#ifndef __QWENVL_HPP__
|
||||||
|
#define __QWENVL_HPP__
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <fstream>
|
||||||
|
#include <iostream>
|
||||||
|
#include <map>
|
||||||
|
#include <optional>
|
||||||
|
#include <regex>
|
||||||
|
#include <set>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "clip.hpp"
|
||||||
|
#include "ggml_extend.hpp"
|
||||||
|
#include "json.hpp"
|
||||||
|
#include "tokenize_util.h"
|
||||||
|
|
||||||
|
namespace Qwen {
|
||||||
|
|
||||||
|
class Qwen2Tokenizer {
|
||||||
|
private:
|
||||||
|
std::map<int, std::u32string> byte_encoder;
|
||||||
|
std::map<std::u32string, int> byte_decoder;
|
||||||
|
std::map<std::u32string, int> encoder;
|
||||||
|
std::map<int, std::u32string> decoder;
|
||||||
|
std::map<std::pair<std::u32string, std::u32string>, int> bpe_ranks;
|
||||||
|
std::regex pat;
|
||||||
|
int encoder_len;
|
||||||
|
int bpe_len;
|
||||||
|
|
||||||
|
public:
|
||||||
|
const std::string UNK_TOKEN = "<|endoftext|>";
|
||||||
|
const std::string EOS_TOKEN = "<|endoftext|>";
|
||||||
|
const std::string PAD_TOKEN = "<|endoftext|>";
|
||||||
|
|
||||||
|
const int UNK_TOKEN_ID = 151643;
|
||||||
|
const int EOS_TOKEN_ID = 151643;
|
||||||
|
const int PAD_TOKEN_ID = 151643;
|
||||||
|
|
||||||
|
std::vector<std::string> special_tokens = {
|
||||||
|
"<|endoftext|>",
|
||||||
|
"<|im_start|>",
|
||||||
|
"<|im_end|>",
|
||||||
|
"<|object_ref_start|>",
|
||||||
|
"<|object_ref_end|>",
|
||||||
|
"<|box_start|>",
|
||||||
|
"<|box_end|>",
|
||||||
|
"<|quad_start|>",
|
||||||
|
"<|quad_end|>",
|
||||||
|
"<|vision_start|>",
|
||||||
|
"<|vision_end|>",
|
||||||
|
"<|vision_pad|>",
|
||||||
|
"<|image_pad|>",
|
||||||
|
"<|video_pad|>",
|
||||||
|
"<tool_call>",
|
||||||
|
"</tool_call>",
|
||||||
|
"<|fim_prefix|>",
|
||||||
|
"<|fim_middle|>",
|
||||||
|
"<|fim_suffix|>",
|
||||||
|
"<|fim_pad|>",
|
||||||
|
"<|repo_name|>",
|
||||||
|
"<|file_sep|>",
|
||||||
|
};
|
||||||
|
|
||||||
|
private:
|
||||||
|
static std::string strip(const std::string& str) {
    // Returns `str` without leading/trailing ASCII whitespace; a string made
    // only of whitespace (or an empty one) yields "".
    const char* const blanks = " \t\n\r\v\f";

    const auto first_kept = str.find_first_not_of(blanks);
    if (first_kept == std::string::npos) {
        return std::string();
    }

    // A non-blank character exists, so this search cannot fail.
    const auto last_kept = str.find_last_not_of(blanks);
    return str.substr(first_kept, last_kept - first_kept + 1);
}
|
||||||
|
|
||||||
|
static std::string whitespace_clean(std::string text) {
    // Collapses every whitespace run into one space, then trims both ends.
    const std::regex whitespace_run(R"(\s+)");
    return strip(std::regex_replace(text, whitespace_run, " "));
}
|
||||||
|
|
||||||
|
static std::set<std::pair<std::u32string, std::u32string>> get_pairs(const std::vector<std::u32string>& subwords) {
    // Collects the distinct (left, right) pairs of neighbouring subwords.
    // Fewer than two subwords produce an empty set.
    std::set<std::pair<std::u32string, std::u32string>> adjacent;
    for (size_t right = 1; right < subwords.size(); ++right) {
        adjacent.emplace(subwords[right - 1], subwords[right]);
    }
    return adjacent;
}
|
||||||
|
|
||||||
|
bool is_special_token(const std::string& token) {
|
||||||
|
for (auto& special_token : special_tokens) {
|
||||||
|
if (special_token == token) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit Qwen2Tokenizer(const std::string& merges_utf8_str = "") {
|
||||||
|
if (merges_utf8_str.size() > 0) {
|
||||||
|
load_from_merges(merges_utf8_str);
|
||||||
|
} else {
|
||||||
|
load_from_merges(ModelLoader::load_qwen2_merges());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void load_from_merges(const std::string& merges_utf8_str) {
|
||||||
|
auto byte_unicode_pairs = bytes_to_unicode();
|
||||||
|
// printf("byte_unicode_pairs have %lu pairs \n", byte_unicode_pairs.size());
|
||||||
|
byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
|
||||||
|
for (auto& pair : byte_unicode_pairs) {
|
||||||
|
byte_decoder[pair.second] = pair.first;
|
||||||
|
}
|
||||||
|
// for (auto & pair: byte_unicode_pairs) {
|
||||||
|
// std::cout << pair.first << ": " << pair.second << std::endl;
|
||||||
|
// }
|
||||||
|
std::vector<std::u32string> merges;
|
||||||
|
size_t start = 0;
|
||||||
|
size_t pos;
|
||||||
|
std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str);
|
||||||
|
while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) {
|
||||||
|
merges.push_back(merges_utf32_str.substr(start, pos - start));
|
||||||
|
start = pos + 1;
|
||||||
|
}
|
||||||
|
LOG_DEBUG("merges size %llu", merges.size());
|
||||||
|
merges = std::vector<std::u32string>(merges.begin(), merges.end());
|
||||||
|
std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
|
||||||
|
for (const auto& merge : merges) {
|
||||||
|
size_t space_pos = merge.find(' ');
|
||||||
|
merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
|
||||||
|
// LOG_DEBUG("%s", utf32_to_utf8(merge.substr(space_pos + 1)).c_str());
|
||||||
|
// printf("%s :: %s | %s \n", utf32_to_utf8(merge).c_str(), utf32_to_utf8(merge.substr(0, space_pos)).c_str(),
|
||||||
|
// utf32_to_utf8(merge.substr(space_pos + 1)).c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::u32string> vocab;
|
||||||
|
for (const auto& pair : byte_unicode_pairs) {
|
||||||
|
vocab.push_back(pair.second);
|
||||||
|
}
|
||||||
|
for (const auto& merge : merge_pairs) {
|
||||||
|
vocab.push_back(merge.first + merge.second);
|
||||||
|
}
|
||||||
|
for (auto& special_token : special_tokens) {
|
||||||
|
vocab.push_back(utf8_to_utf32(special_token));
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_DEBUG("vocab size: %llu", vocab.size());
|
||||||
|
int i = 0;
|
||||||
|
for (const auto& token : vocab) {
|
||||||
|
encoder[token] = i;
|
||||||
|
decoder[i] = token;
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
encoder_len = i;
|
||||||
|
|
||||||
|
int rank = 0;
|
||||||
|
for (const auto& merge : merge_pairs) {
|
||||||
|
bpe_ranks[merge] = rank++;
|
||||||
|
}
|
||||||
|
bpe_len = rank;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::u32string bpe(const std::u32string& token) {
|
||||||
|
std::vector<std::u32string> word;
|
||||||
|
|
||||||
|
for (int i = 0; i < token.size(); i++) {
|
||||||
|
word.emplace_back(1, token[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::set<std::pair<std::u32string, std::u32string>> pairs = get_pairs(word);
|
||||||
|
|
||||||
|
if (pairs.empty()) {
|
||||||
|
return token;
|
||||||
|
}
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
auto min_pair_iter = std::min_element(pairs.begin(),
|
||||||
|
pairs.end(),
|
||||||
|
[&](const std::pair<std::u32string, std::u32string>& a,
|
||||||
|
const std::pair<std::u32string, std::u32string>& b) {
|
||||||
|
if (bpe_ranks.find(a) == bpe_ranks.end()) {
|
||||||
|
return false;
|
||||||
|
} else if (bpe_ranks.find(b) == bpe_ranks.end()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return bpe_ranks.at(a) < bpe_ranks.at(b);
|
||||||
|
});
|
||||||
|
|
||||||
|
const std::pair<std::u32string, std::u32string>& bigram = *min_pair_iter;
|
||||||
|
|
||||||
|
if (bpe_ranks.find(bigram) == bpe_ranks.end()) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::u32string first = bigram.first;
|
||||||
|
std::u32string second = bigram.second;
|
||||||
|
std::vector<std::u32string> new_word;
|
||||||
|
int32_t i = 0;
|
||||||
|
|
||||||
|
while (i < word.size()) {
|
||||||
|
auto it = std::find(word.begin() + i, word.end(), first);
|
||||||
|
if (it == word.end()) {
|
||||||
|
new_word.insert(new_word.end(), word.begin() + i, word.end());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
new_word.insert(new_word.end(), word.begin() + i, it);
|
||||||
|
i = static_cast<int32_t>(std::distance(word.begin(), it));
|
||||||
|
|
||||||
|
if (word[i] == first && i < static_cast<int32_t>(word.size()) - 1 && word[i + 1] == second) {
|
||||||
|
new_word.push_back(first + second);
|
||||||
|
i += 2;
|
||||||
|
} else {
|
||||||
|
new_word.push_back(word[i]);
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
word = new_word;
|
||||||
|
|
||||||
|
if (word.size() == 1) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
pairs = get_pairs(word);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::u32string result;
|
||||||
|
for (int i = 0; i < word.size(); i++) {
|
||||||
|
result += word[i];
|
||||||
|
if (i != word.size() - 1) {
|
||||||
|
result += utf8_to_utf32(" ");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> tokenize(std::string text,
|
||||||
|
on_new_token_cb_t on_new_token_cb = nullptr,
|
||||||
|
size_t max_length = 0,
|
||||||
|
bool padding = false) {
|
||||||
|
std::vector<int32_t> tokens = encode(text, on_new_token_cb);
|
||||||
|
|
||||||
|
if (max_length > 0) {
|
||||||
|
if (tokens.size() < max_length) {
|
||||||
|
tokens.resize(max_length);
|
||||||
|
} else {
|
||||||
|
if (padding) {
|
||||||
|
tokens.insert(tokens.end(), max_length - tokens.size(), PAD_TOKEN_ID);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
void pad_tokens(std::vector<int>& tokens,
|
||||||
|
std::vector<float>& weights,
|
||||||
|
size_t max_length = 0,
|
||||||
|
bool padding = false) {
|
||||||
|
if (max_length > 0 && padding) {
|
||||||
|
size_t n = std::ceil(tokens.size() * 1.0 / max_length);
|
||||||
|
if (n == 0) {
|
||||||
|
n = 1;
|
||||||
|
}
|
||||||
|
size_t length = max_length * n;
|
||||||
|
LOG_DEBUG("token length: %llu", length);
|
||||||
|
tokens.insert(tokens.end(), length - tokens.size(), PAD_TOKEN_ID);
|
||||||
|
weights.insert(weights.end(), length - weights.size(), 1.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> encode(std::string text, on_new_token_cb_t on_new_token_cb = nullptr) {
|
||||||
|
std::string original_text = text;
|
||||||
|
std::vector<int32_t> bpe_tokens;
|
||||||
|
std::vector<std::string> token_strs;
|
||||||
|
|
||||||
|
auto splited_texts = split_with_special_tokens(text, special_tokens);
|
||||||
|
|
||||||
|
for (auto& splited_text : splited_texts) {
|
||||||
|
if (is_special_token(splited_text)) {
|
||||||
|
bpe_tokens.push_back(encoder[utf8_to_utf32(splited_text)]);
|
||||||
|
token_strs.push_back(splited_text);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto tokens = token_split(splited_text);
|
||||||
|
for (auto& token : tokens) {
|
||||||
|
if (on_new_token_cb != nullptr) {
|
||||||
|
bool skip = on_new_token_cb(token, bpe_tokens);
|
||||||
|
if (skip) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string token_str = token;
|
||||||
|
std::u32string utf32_token;
|
||||||
|
for (int i = 0; i < token_str.length(); i++) {
|
||||||
|
unsigned char b = token_str[i];
|
||||||
|
utf32_token += byte_encoder[b];
|
||||||
|
}
|
||||||
|
auto bpe_strs = bpe(utf32_token);
|
||||||
|
size_t start = 0;
|
||||||
|
size_t pos;
|
||||||
|
while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) {
|
||||||
|
auto bpe_str = bpe_strs.substr(start, pos - start);
|
||||||
|
bpe_tokens.push_back(encoder[bpe_str]);
|
||||||
|
token_strs.push_back(utf32_to_utf8(bpe_str));
|
||||||
|
|
||||||
|
start = pos + 1;
|
||||||
|
}
|
||||||
|
auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start);
|
||||||
|
bpe_tokens.push_back(encoder[bpe_str]);
|
||||||
|
token_strs.push_back(utf32_to_utf8(bpe_str));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "[";
|
||||||
|
for (auto token : token_strs) {
|
||||||
|
ss << "\"" << token << "\", ";
|
||||||
|
}
|
||||||
|
ss << "]";
|
||||||
|
// LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str());
|
||||||
|
// printf("split prompt \"%s\" to tokens %s \n", original_text.c_str(), ss.str().c_str());
|
||||||
|
return bpe_tokens;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Qwen2_5_VLMLP : public GGMLBlock {
|
||||||
|
public:
|
||||||
|
Qwen2_5_VLMLP(int64_t hidden_size, int64_t intermediate_size, bool bias = false) {
|
||||||
|
blocks["gate_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, intermediate_size, false));
|
||||||
|
blocks["up_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, intermediate_size, false));
|
||||||
|
blocks["down_proj"] = std::shared_ptr<GGMLBlock>(new Linear(intermediate_size, hidden_size, false));
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
||||||
|
// x: [N, n_token, hidden_size]
|
||||||
|
auto gate_proj = std::dynamic_pointer_cast<Linear>(blocks["gate_proj"]);
|
||||||
|
auto up_proj = std::dynamic_pointer_cast<Linear>(blocks["up_proj"]);
|
||||||
|
auto down_proj = std::dynamic_pointer_cast<Linear>(blocks["down_proj"]);
|
||||||
|
|
||||||
|
auto h = gate_proj->forward(ctx, x);
|
||||||
|
h = ggml_silu_inplace(ctx, h);
|
||||||
|
h = ggml_mul_inplace(ctx, h, up_proj->forward(ctx, x));
|
||||||
|
h = down_proj->forward(ctx, h);
|
||||||
|
return h;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Qwen2_5_VLAttention : public GGMLBlock {
|
||||||
|
protected:
|
||||||
|
int64_t head_dim;
|
||||||
|
int64_t num_heads;
|
||||||
|
int64_t num_kv_heads;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Qwen2_5_VLAttention(int64_t hidden_size,
|
||||||
|
int64_t num_heads,
|
||||||
|
int64_t num_kv_heads)
|
||||||
|
: num_heads(num_heads), num_kv_heads(num_kv_heads) {
|
||||||
|
head_dim = hidden_size / num_heads;
|
||||||
|
GGML_ASSERT(num_heads * head_dim == hidden_size);
|
||||||
|
blocks["q_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, num_heads * head_dim));
|
||||||
|
blocks["k_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, num_kv_heads * head_dim));
|
||||||
|
blocks["v_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, num_kv_heads * head_dim));
|
||||||
|
blocks["o_proj"] = std::shared_ptr<GGMLBlock>(new Linear(num_heads * head_dim, hidden_size, false));
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||||
|
ggml_backend_t backend,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
struct ggml_tensor* input_pos) {
|
||||||
|
// x: [N, n_token, hidden_size]
|
||||||
|
int64_t n_token = x->ne[1];
|
||||||
|
int64_t N = x->ne[2];
|
||||||
|
auto q_proj = std::dynamic_pointer_cast<Linear>(blocks["q_proj"]);
|
||||||
|
auto k_proj = std::dynamic_pointer_cast<Linear>(blocks["k_proj"]);
|
||||||
|
auto v_proj = std::dynamic_pointer_cast<Linear>(blocks["v_proj"]);
|
||||||
|
auto out_proj = std::dynamic_pointer_cast<Linear>(blocks["o_proj"]);
|
||||||
|
|
||||||
|
auto q = q_proj->forward(ctx, x); // [N, n_token, num_heads*head_dim]
|
||||||
|
auto k = k_proj->forward(ctx, x); // [N, n_token, num_kv_heads*head_dim]
|
||||||
|
auto v = v_proj->forward(ctx, x); // [N, n_token, num_kv_heads*head_dim]
|
||||||
|
|
||||||
|
q = ggml_reshape_4d(ctx, q, head_dim, num_heads, n_token, N); // [N, n_token, num_heads, head_dim]
|
||||||
|
k = ggml_reshape_4d(ctx, k, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim]
|
||||||
|
v = ggml_reshape_4d(ctx, v, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim]
|
||||||
|
|
||||||
|
int sections[4] = {16, 24, 24, 0};
|
||||||
|
q = ggml_rope_multi(ctx, q, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
||||||
|
k = ggml_rope_multi(ctx, k, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
||||||
|
|
||||||
|
q = ggml_cont(ctx, ggml_torch_permute(ctx, q, 0, 2, 1, 3)); // [N, num_heads, n_token, head_dim]
|
||||||
|
q = ggml_reshape_3d(ctx, q, q->ne[0], q->ne[1], q->ne[2] * q->ne[3]); // [N*num_heads, n_token, head_dim]
|
||||||
|
|
||||||
|
k = ggml_cont(ctx, ggml_torch_permute(ctx, k, 0, 2, 1, 3)); // [N, num_kv_heads, n_token, head_dim]
|
||||||
|
k = ggml_reshape_3d(ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); // [N*num_kv_heads, n_token, head_dim]
|
||||||
|
|
||||||
|
x = ggml_nn_attention_ext(ctx, backend, q, k, v, num_heads, nullptr, true, true, false); // [N, n_token, hidden_size]
|
||||||
|
|
||||||
|
x = out_proj->forward(ctx, x); // [N, n_token, hidden_size]
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Qwen2_5_VLBlock : public GGMLBlock {
|
||||||
|
public:
|
||||||
|
Qwen2_5_VLBlock(int64_t hidden_size,
|
||||||
|
int64_t intermediate_size,
|
||||||
|
int64_t num_heads,
|
||||||
|
int64_t num_kv_heads,
|
||||||
|
float eps = 1e-6f) {
|
||||||
|
blocks["self_attn"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLAttention(hidden_size, num_heads, num_kv_heads));
|
||||||
|
blocks["mlp"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLMLP(hidden_size, intermediate_size));
|
||||||
|
blocks["input_layernorm"] = std::shared_ptr<GGMLBlock>(new RMSNorm(hidden_size, eps));
|
||||||
|
blocks["post_attention_layernorm"] = std::shared_ptr<GGMLBlock>(new RMSNorm(hidden_size, eps));
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||||
|
ggml_backend_t backend,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
struct ggml_tensor* input_pos) {
|
||||||
|
// x: [N, n_token, hidden_size]
|
||||||
|
auto self_attn = std::dynamic_pointer_cast<Qwen2_5_VLAttention>(blocks["self_attn"]);
|
||||||
|
auto mlp = std::dynamic_pointer_cast<Qwen2_5_VLMLP>(blocks["mlp"]);
|
||||||
|
auto input_layernorm = std::dynamic_pointer_cast<RMSNorm>(blocks["input_layernorm"]);
|
||||||
|
auto post_attention_layernorm = std::dynamic_pointer_cast<RMSNorm>(blocks["post_attention_layernorm"]);
|
||||||
|
|
||||||
|
auto residual = x;
|
||||||
|
x = input_layernorm->forward(ctx, x);
|
||||||
|
x = self_attn->forward(ctx, backend, x, input_pos);
|
||||||
|
x = ggml_add_inplace(ctx, x, residual);
|
||||||
|
|
||||||
|
residual = x;
|
||||||
|
x = post_attention_layernorm->forward(ctx, x);
|
||||||
|
x = mlp->forward(ctx, x);
|
||||||
|
x = ggml_add_inplace(ctx, x, residual);
|
||||||
|
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Qwen2_5_VLTextModel : public GGMLBlock {
|
||||||
|
protected:
|
||||||
|
int64_t num_layers;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Qwen2_5_VLTextModel(int64_t num_layers,
|
||||||
|
int64_t vocab_size,
|
||||||
|
int64_t hidden_size,
|
||||||
|
int64_t intermediate_size,
|
||||||
|
int64_t num_heads,
|
||||||
|
int64_t num_kv_heads,
|
||||||
|
float eps = 1e-6f)
|
||||||
|
: num_layers(num_layers) {
|
||||||
|
blocks["embed_tokens"] = std::shared_ptr<GGMLBlock>(new Embedding(vocab_size, hidden_size));
|
||||||
|
for (int i = 0; i < num_layers; i++) {
|
||||||
|
blocks["layers." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLBlock(hidden_size,
|
||||||
|
intermediate_size,
|
||||||
|
num_heads,
|
||||||
|
num_kv_heads));
|
||||||
|
}
|
||||||
|
blocks["norm"] = std::shared_ptr<GGMLBlock>(new RMSNorm(hidden_size, eps));
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||||
|
ggml_backend_t backend,
|
||||||
|
struct ggml_tensor* input_ids,
|
||||||
|
struct ggml_tensor* input_pos) {
|
||||||
|
// input_ids: [N, n_token]
|
||||||
|
// return: [N, n_token, hidden_size]
|
||||||
|
|
||||||
|
auto embed_tokens = std::dynamic_pointer_cast<Embedding>(blocks["embed_tokens"]);
|
||||||
|
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["norm"]);
|
||||||
|
|
||||||
|
auto x = embed_tokens->forward(ctx, input_ids);
|
||||||
|
|
||||||
|
for (int i = 0; i < num_layers; i++) {
|
||||||
|
auto block = std::dynamic_pointer_cast<Qwen2_5_VLBlock>(blocks["layers." + std::to_string(i)]);
|
||||||
|
|
||||||
|
x = block->forward(ctx, backend, x, input_pos);
|
||||||
|
}
|
||||||
|
|
||||||
|
x = norm->forward(ctx, x);
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Hyper-parameters of the Qwen2.5-VL text model. The defaults below are the
// values used when a Qwen2_5_VLParams is default-constructed.
struct Qwen2_5_VLParams {
    int64_t num_layers{28};            // number of decoder layers
    int64_t hidden_size{3584};         // model width
    int64_t intermediate_size{18944};  // MLP inner width
    int64_t num_heads{28};             // query heads
    int64_t num_kv_heads{4};           // key/value heads
    int64_t vocab_size{152064};        // embedding table rows
    float rms_norm_eps{1e-06f};        // epsilon passed to RMSNorm
};
|
||||||
|
|
||||||
|
struct Qwen2_5_VL : public GGMLBlock {
|
||||||
|
Qwen2_5_VLParams params;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Qwen2_5_VL() {}
|
||||||
|
Qwen2_5_VL(Qwen2_5_VLParams params)
|
||||||
|
: params(params) {
|
||||||
|
blocks["model"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLTextModel(params.num_layers,
|
||||||
|
params.vocab_size,
|
||||||
|
params.hidden_size,
|
||||||
|
params.intermediate_size,
|
||||||
|
params.num_heads,
|
||||||
|
params.num_kv_heads,
|
||||||
|
params.rms_norm_eps));
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||||
|
ggml_backend_t backend,
|
||||||
|
struct ggml_tensor* input_ids,
|
||||||
|
struct ggml_tensor* input_pos) {
|
||||||
|
// input_ids: [N, n_token]
|
||||||
|
auto model = std::dynamic_pointer_cast<Qwen2_5_VLTextModel>(blocks["model"]);
|
||||||
|
|
||||||
|
auto x = model->forward(ctx, backend, input_ids, input_pos);
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Qwen2_5_VLRunner : public GGMLRunner {
|
||||||
|
Qwen2_5_VLParams params;
|
||||||
|
Qwen2_5_VL model;
|
||||||
|
|
||||||
|
std::vector<int> input_pos_vec;
|
||||||
|
|
||||||
|
Qwen2_5_VLRunner(ggml_backend_t backend,
|
||||||
|
bool offload_params_to_cpu,
|
||||||
|
const String2GGMLType& tensor_types,
|
||||||
|
const std::string prefix)
|
||||||
|
: GGMLRunner(backend, offload_params_to_cpu) {
|
||||||
|
model = Qwen2_5_VL(params);
|
||||||
|
model.init(params_ctx, tensor_types, prefix);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string get_desc() {
|
||||||
|
return "qwenvl2.5";
|
||||||
|
}
|
||||||
|
|
||||||
|
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
|
||||||
|
model.get_param_tensors(tensors, prefix);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||||
|
ggml_backend_t backend,
|
||||||
|
struct ggml_tensor* input_ids,
|
||||||
|
struct ggml_tensor* input_pos) {
|
||||||
|
auto hidden_states = model.forward(ctx, backend, input_ids, input_pos); // [N, n_token, hidden_size]
|
||||||
|
return hidden_states;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids) {
|
||||||
|
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
|
||||||
|
|
||||||
|
input_ids = to_backend(input_ids);
|
||||||
|
|
||||||
|
int64_t n_tokens = input_ids->ne[0];
|
||||||
|
input_pos_vec.resize(n_tokens * 4);
|
||||||
|
for (int i = 0; i < n_tokens; ++i) {
|
||||||
|
input_pos_vec[i] = i;
|
||||||
|
input_pos_vec[n_tokens + i] = i;
|
||||||
|
input_pos_vec[2 * n_tokens + i] = i;
|
||||||
|
input_pos_vec[3 * n_tokens + i] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto input_pos = ggml_new_tensor_1d(compute_ctx,
|
||||||
|
GGML_TYPE_I32,
|
||||||
|
n_tokens * 4);
|
||||||
|
set_backend_tensor_data(input_pos, input_pos_vec.data());
|
||||||
|
|
||||||
|
struct ggml_tensor* hidden_states = forward(compute_ctx, runtime_backend, input_ids, input_pos);
|
||||||
|
|
||||||
|
ggml_build_forward_expand(gf, hidden_states);
|
||||||
|
|
||||||
|
return gf;
|
||||||
|
}
|
||||||
|
|
||||||
|
void compute(const int n_threads,
|
||||||
|
struct ggml_tensor* input_ids,
|
||||||
|
ggml_tensor** output,
|
||||||
|
ggml_context* output_ctx = NULL) {
|
||||||
|
auto get_graph = [&]() -> struct ggml_cgraph* {
|
||||||
|
return build_graph(input_ids);
|
||||||
|
};
|
||||||
|
GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Qwen2_5_VLEmbedder {
|
||||||
|
Qwen2Tokenizer tokenizer;
|
||||||
|
Qwen2_5_VLRunner model;
|
||||||
|
|
||||||
|
Qwen2_5_VLEmbedder(ggml_backend_t backend,
|
||||||
|
bool offload_params_to_cpu,
|
||||||
|
const String2GGMLType& tensor_types = {},
|
||||||
|
const std::string prefix = "")
|
||||||
|
: model(backend, offload_params_to_cpu, tensor_types, prefix) {
|
||||||
|
}
|
||||||
|
|
||||||
|
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
|
||||||
|
model.get_param_tensors(tensors, prefix);
|
||||||
|
}
|
||||||
|
|
||||||
|
void alloc_params_buffer() {
|
||||||
|
model.alloc_params_buffer();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text,
|
||||||
|
size_t max_length = 0,
|
||||||
|
bool padding = false) {
|
||||||
|
auto parsed_attention = parse_prompt_attention(text);
|
||||||
|
|
||||||
|
{
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "[";
|
||||||
|
for (const auto& item : parsed_attention) {
|
||||||
|
ss << "['" << item.first << "', " << item.second << "], ";
|
||||||
|
}
|
||||||
|
ss << "]";
|
||||||
|
LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> tokens;
|
||||||
|
std::vector<float> weights;
|
||||||
|
for (const auto& item : parsed_attention) {
|
||||||
|
const std::string& curr_text = item.first;
|
||||||
|
float curr_weight = item.second;
|
||||||
|
std::vector<int> curr_tokens = tokenizer.tokenize(curr_text, nullptr);
|
||||||
|
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
|
||||||
|
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenizer.pad_tokens(tokens, weights, max_length, padding);
|
||||||
|
|
||||||
|
// for (int i = 0; i < tokens.size(); i++) {
|
||||||
|
// std::cout << tokens[i] << ":" << weights[i] << ", ";
|
||||||
|
// }
|
||||||
|
// std::cout << std::endl;
|
||||||
|
|
||||||
|
return {tokens, weights};
|
||||||
|
}
|
||||||
|
|
||||||
|
void test() {
|
||||||
|
struct ggml_init_params params;
|
||||||
|
params.mem_size = static_cast<size_t>(1024 * 1024) * 1024; // 1GB
|
||||||
|
params.mem_buffer = NULL;
|
||||||
|
params.no_alloc = false;
|
||||||
|
|
||||||
|
struct ggml_context* work_ctx = ggml_init(params);
|
||||||
|
GGML_ASSERT(work_ctx != NULL);
|
||||||
|
|
||||||
|
{
|
||||||
|
std::string text("<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\na lovely cat<|im_end|>\n<|im_start|>assistant\n");
|
||||||
|
auto tokens_and_weights = tokenize(text, 0, false);
|
||||||
|
std::vector<int>& tokens = std::get<0>(tokens_and_weights);
|
||||||
|
std::vector<float>& weights = std::get<1>(tokens_and_weights);
|
||||||
|
for (auto token : tokens) {
|
||||||
|
printf("%d ", token);
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
|
||||||
|
struct ggml_tensor* out = NULL;
|
||||||
|
|
||||||
|
int t0 = ggml_time_ms();
|
||||||
|
model.compute(8, input_ids, &out, work_ctx);
|
||||||
|
int t1 = ggml_time_ms();
|
||||||
|
|
||||||
|
print_ggml_tensor(out);
|
||||||
|
LOG_DEBUG("qwen2vl test done in %dms", t1 - t0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void load_from_file_and_test(const std::string& file_path) {
|
||||||
|
// cpu f16: pass
|
||||||
|
// ggml_backend_t backend = ggml_backend_cuda_init(0);
|
||||||
|
ggml_backend_t backend = ggml_backend_cpu_init();
|
||||||
|
ggml_type model_data_type = GGML_TYPE_Q8_0;
|
||||||
|
|
||||||
|
ModelLoader model_loader;
|
||||||
|
if (!model_loader.init_from_file(file_path, "qwen2vl.")) {
|
||||||
|
LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto tensor_types = model_loader.tensor_storages_types;
|
||||||
|
for (auto& item : tensor_types) {
|
||||||
|
// LOG_DEBUG("%s %u", item.first.c_str(), item.second);
|
||||||
|
if (ends_with(item.first, "weight")) {
|
||||||
|
item.second = model_data_type;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<Qwen2_5_VLEmbedder> qwenvl = std::shared_ptr<Qwen2_5_VLEmbedder>(new Qwen2_5_VLEmbedder(backend, false, tensor_types, "qwen2vl"));
|
||||||
|
|
||||||
|
qwenvl->alloc_params_buffer();
|
||||||
|
std::map<std::string, ggml_tensor*> tensors;
|
||||||
|
qwenvl->get_param_tensors(tensors, "qwen2vl");
|
||||||
|
|
||||||
|
bool success = model_loader.load_tensors(tensors);
|
||||||
|
|
||||||
|
if (!success) {
|
||||||
|
LOG_ERROR("load tensors from model loader failed");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_INFO("qwenvl model loaded");
|
||||||
|
qwenvl->test();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
}; // Qwen
|
||||||
|
|
||||||
|
#endif // __QWENVL_HPP__
|
||||||
32
rope.hpp
32
rope.hpp
@ -203,6 +203,38 @@ struct Rope {
|
|||||||
return embed_nd(ids, bs, theta, axes_dim);
|
return embed_nd(ids, bs, theta, axes_dim);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static std::vector<std::vector<float>> gen_qwen_image_ids(int h,
|
||||||
|
int w,
|
||||||
|
int patch_size,
|
||||||
|
int bs,
|
||||||
|
int context_len) {
|
||||||
|
int h_len = (h + (patch_size / 2)) / patch_size;
|
||||||
|
int w_len = (w + (patch_size / 2)) / patch_size;
|
||||||
|
int txt_id_start = std::max(h_len, w_len);
|
||||||
|
auto txt_ids = linspace<float>(txt_id_start, context_len + txt_id_start, context_len);
|
||||||
|
std::vector<std::vector<float>> txt_ids_repeated(bs * context_len, std::vector<float>(3));
|
||||||
|
for (int i = 0; i < bs; ++i) {
|
||||||
|
for (int j = 0; j < txt_ids.size(); ++j) {
|
||||||
|
txt_ids_repeated[i * txt_ids.size() + j] = {txt_ids[j], txt_ids[j], txt_ids[j]};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto img_ids = gen_img_ids(h, w, patch_size, bs);
|
||||||
|
auto ids = concat_ids(txt_ids_repeated, img_ids, bs);
|
||||||
|
return ids;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate qwen_image positional embeddings
|
||||||
|
static std::vector<float> gen_qwen_image_pe(int h,
|
||||||
|
int w,
|
||||||
|
int patch_size,
|
||||||
|
int bs,
|
||||||
|
int context_len,
|
||||||
|
int theta,
|
||||||
|
const std::vector<int>& axes_dim) {
|
||||||
|
std::vector<std::vector<float>> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len);
|
||||||
|
return embed_nd(ids, bs, theta, axes_dim);
|
||||||
|
}
|
||||||
|
|
||||||
static std::vector<std::vector<float>> gen_vid_ids(int t,
|
static std::vector<std::vector<float>> gen_vid_ids(int t,
|
||||||
int h,
|
int h,
|
||||||
int w,
|
int w,
|
||||||
|
|||||||
@ -42,6 +42,7 @@ const char* model_version_to_str[] = {
|
|||||||
"Wan 2.x",
|
"Wan 2.x",
|
||||||
"Wan 2.2 I2V",
|
"Wan 2.2 I2V",
|
||||||
"Wan 2.2 TI2V",
|
"Wan 2.2 TI2V",
|
||||||
|
"Qwen Image",
|
||||||
};
|
};
|
||||||
|
|
||||||
const char* sampling_methods_str[] = {
|
const char* sampling_methods_str[] = {
|
||||||
@ -253,6 +254,13 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (strlen(SAFE_STR(sd_ctx_params->qwen2vl_path)) > 0) {
|
||||||
|
LOG_INFO("loading qwen2vl from '%s'", sd_ctx_params->qwen2vl_path);
|
||||||
|
if (!model_loader.init_from_file(sd_ctx_params->qwen2vl_path, "text_encoders.qwen2vl.")) {
|
||||||
|
LOG_WARN("loading qwen2vl from '%s' failed", sd_ctx_params->qwen2vl_path);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (strlen(SAFE_STR(sd_ctx_params->vae_path)) > 0) {
|
if (strlen(SAFE_STR(sd_ctx_params->vae_path)) > 0) {
|
||||||
LOG_INFO("loading vae from '%s'", sd_ctx_params->vae_path);
|
LOG_INFO("loading vae from '%s'", sd_ctx_params->vae_path);
|
||||||
if (!model_loader.init_from_file(sd_ctx_params->vae_path, "vae.")) {
|
if (!model_loader.init_from_file(sd_ctx_params->vae_path, "vae.")) {
|
||||||
@ -318,7 +326,7 @@ public:
|
|||||||
} else if (sd_version_is_flux(version)) {
|
} else if (sd_version_is_flux(version)) {
|
||||||
scale_factor = 0.3611f;
|
scale_factor = 0.3611f;
|
||||||
// TODO: shift_factor
|
// TODO: shift_factor
|
||||||
} else if (sd_version_is_wan(version)) {
|
} else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
|
||||||
scale_factor = 1.0f;
|
scale_factor = 1.0f;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -332,7 +340,7 @@ public:
|
|||||||
{
|
{
|
||||||
clip_backend = backend;
|
clip_backend = backend;
|
||||||
bool use_t5xxl = false;
|
bool use_t5xxl = false;
|
||||||
if (sd_version_is_dit(version)) {
|
if (sd_version_is_dit(version) && !sd_version_is_qwen_image(version)) {
|
||||||
use_t5xxl = true;
|
use_t5xxl = true;
|
||||||
}
|
}
|
||||||
if (!clip_on_cpu && !ggml_backend_is_cpu(backend) && use_t5xxl) {
|
if (!clip_on_cpu && !ggml_backend_is_cpu(backend) && use_t5xxl) {
|
||||||
@ -418,6 +426,16 @@ public:
|
|||||||
clip_vision->alloc_params_buffer();
|
clip_vision->alloc_params_buffer();
|
||||||
clip_vision->get_param_tensors(tensors);
|
clip_vision->get_param_tensors(tensors);
|
||||||
}
|
}
|
||||||
|
} else if (sd_version_is_qwen_image(version)) {
|
||||||
|
cond_stage_model = std::make_shared<Qwen2_5_VLCLIPEmbedder>(clip_backend,
|
||||||
|
offload_params_to_cpu,
|
||||||
|
model_loader.tensor_storages_types);
|
||||||
|
diffusion_model = std::make_shared<QwenImageModel>(backend,
|
||||||
|
offload_params_to_cpu,
|
||||||
|
model_loader.tensor_storages_types,
|
||||||
|
"model.diffusion_model",
|
||||||
|
version,
|
||||||
|
sd_ctx_params->diffusion_flash_attn);
|
||||||
} else { // SD1.x SD2.x SDXL
|
} else { // SD1.x SD2.x SDXL
|
||||||
if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
|
if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
|
||||||
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
|
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
|
||||||
@ -466,7 +484,7 @@ public:
|
|||||||
vae_backend = backend;
|
vae_backend = backend;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (sd_version_is_wan(version)) {
|
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
|
||||||
first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
|
first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
|
||||||
offload_params_to_cpu,
|
offload_params_to_cpu,
|
||||||
model_loader.tensor_storages_types,
|
model_loader.tensor_storages_types,
|
||||||
@ -711,6 +729,13 @@ public:
|
|||||||
shift = 5.0;
|
shift = 5.0;
|
||||||
}
|
}
|
||||||
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
|
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
|
||||||
|
} else if (sd_version_is_qwen_image(version)) {
|
||||||
|
LOG_INFO("running in FLOW mode");
|
||||||
|
float shift = sd_ctx_params->flow_shift;
|
||||||
|
if (shift == INFINITY) {
|
||||||
|
shift = 3.0;
|
||||||
|
}
|
||||||
|
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
|
||||||
} else if (is_using_v_parameterization) {
|
} else if (is_using_v_parameterization) {
|
||||||
LOG_INFO("running in v-prediction mode");
|
LOG_INFO("running in v-prediction mode");
|
||||||
denoiser = std::make_shared<CompVisVDenoiser>();
|
denoiser = std::make_shared<CompVisVDenoiser>();
|
||||||
@ -989,7 +1014,7 @@ public:
|
|||||||
ggml_tensor_scale(noise, augmentation_level);
|
ggml_tensor_scale(noise, augmentation_level);
|
||||||
ggml_tensor_add(init_img, noise);
|
ggml_tensor_add(init_img, noise);
|
||||||
}
|
}
|
||||||
ggml_tensor* moments = encode_first_stage(work_ctx, init_img);
|
ggml_tensor* moments = vae_encode(work_ctx, init_img);
|
||||||
c_concat = get_first_stage_encoding(work_ctx, moments);
|
c_concat = get_first_stage_encoding(work_ctx, moments);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1298,118 +1323,8 @@ public:
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
// ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding
|
|
||||||
ggml_tensor* get_first_stage_encoding(ggml_context* work_ctx, ggml_tensor* moments) {
|
|
||||||
// ldm.modules.distributions.distributions.DiagonalGaussianDistribution.sample
|
|
||||||
ggml_tensor* latent = ggml_new_tensor_4d(work_ctx, moments->type, moments->ne[0], moments->ne[1], moments->ne[2] / 2, moments->ne[3]);
|
|
||||||
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, latent);
|
|
||||||
ggml_tensor_set_f32_randn(noise, rng);
|
|
||||||
{
|
|
||||||
float mean = 0;
|
|
||||||
float logvar = 0;
|
|
||||||
float value = 0;
|
|
||||||
float std_ = 0;
|
|
||||||
for (int i = 0; i < latent->ne[3]; i++) {
|
|
||||||
for (int j = 0; j < latent->ne[2]; j++) {
|
|
||||||
for (int k = 0; k < latent->ne[1]; k++) {
|
|
||||||
for (int l = 0; l < latent->ne[0]; l++) {
|
|
||||||
mean = ggml_tensor_get_f32(moments, l, k, j, i);
|
|
||||||
logvar = ggml_tensor_get_f32(moments, l, k, j + (int)latent->ne[2], i);
|
|
||||||
logvar = std::max(-30.0f, std::min(logvar, 20.0f));
|
|
||||||
std_ = std::exp(0.5f * logvar);
|
|
||||||
value = mean + std_ * ggml_tensor_get_f32(noise, l, k, j, i);
|
|
||||||
value = value * scale_factor;
|
|
||||||
// printf("%d %d %d %d -> %f\n", i, j, k, l, value);
|
|
||||||
ggml_tensor_set_f32(latent, value, l, k, j, i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return latent;
|
|
||||||
}
|
|
||||||
|
|
||||||
void get_tile_sizes(int& tile_size_x,
|
|
||||||
int& tile_size_y,
|
|
||||||
float& tile_overlap,
|
|
||||||
const sd_tiling_params_t& params,
|
|
||||||
int latent_x,
|
|
||||||
int latent_y,
|
|
||||||
float encoding_factor = 1.0f) {
|
|
||||||
tile_overlap = std::max(std::min(params.target_overlap, 0.5f), 0.0f);
|
|
||||||
auto get_tile_size = [&](int requested_size, float factor, int latent_size) {
|
|
||||||
const int default_tile_size = 32;
|
|
||||||
const int min_tile_dimension = 4;
|
|
||||||
int tile_size = default_tile_size;
|
|
||||||
// factor <= 1 means simple fraction of the latent dimension
|
|
||||||
// factor > 1 means number of tiles across that dimension
|
|
||||||
if (factor > 0.f) {
|
|
||||||
if (factor > 1.0)
|
|
||||||
factor = 1 / (factor - factor * tile_overlap + tile_overlap);
|
|
||||||
tile_size = std::round(latent_size * factor);
|
|
||||||
} else if (requested_size >= min_tile_dimension) {
|
|
||||||
tile_size = requested_size;
|
|
||||||
}
|
|
||||||
tile_size *= encoding_factor;
|
|
||||||
return std::max(std::min(tile_size, latent_size), min_tile_dimension);
|
|
||||||
};
|
|
||||||
|
|
||||||
tile_size_x = get_tile_size(params.tile_size_x, params.rel_size_x, latent_x);
|
|
||||||
tile_size_y = get_tile_size(params.tile_size_y, params.rel_size_y, latent_y);
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool encode_video = false) {
|
|
||||||
int64_t t0 = ggml_time_ms();
|
|
||||||
ggml_tensor* result = NULL;
|
|
||||||
int W = x->ne[0] / 8;
|
|
||||||
int H = x->ne[1] / 8;
|
|
||||||
if (vae_tiling_params.enabled && !encode_video) {
|
|
||||||
// TODO wan2.2 vae support?
|
|
||||||
int C = sd_version_is_dit(version) ? 16 : 4;
|
|
||||||
if (!use_tiny_autoencoder) {
|
|
||||||
C *= 2;
|
|
||||||
}
|
|
||||||
result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, x->ne[3]);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!use_tiny_autoencoder) {
|
|
||||||
process_vae_input_tensor(x);
|
|
||||||
if (vae_tiling_params.enabled && !encode_video) {
|
|
||||||
float tile_overlap;
|
|
||||||
int tile_size_x, tile_size_y;
|
|
||||||
// multiply tile size for encode to keep the compute buffer size consistent
|
|
||||||
get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, vae_tiling_params, W, H, 1.30539f);
|
|
||||||
|
|
||||||
LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
|
|
||||||
|
|
||||||
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
|
||||||
first_stage_model->compute(n_threads, in, false, &out, work_ctx);
|
|
||||||
};
|
|
||||||
sd_tiling_non_square(x, result, 8, tile_size_x, tile_size_y, tile_overlap, on_tiling);
|
|
||||||
} else {
|
|
||||||
first_stage_model->compute(n_threads, x, false, &result, work_ctx);
|
|
||||||
}
|
|
||||||
first_stage_model->free_compute_buffer();
|
|
||||||
} else {
|
|
||||||
if (vae_tiling_params.enabled && !encode_video) {
|
|
||||||
// split latent in 32x32 tiles and compute in several steps
|
|
||||||
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
|
||||||
tae_first_stage->compute(n_threads, in, false, &out, NULL);
|
|
||||||
};
|
|
||||||
sd_tiling(x, result, 8, 64, 0.5f, on_tiling);
|
|
||||||
} else {
|
|
||||||
tae_first_stage->compute(n_threads, x, false, &result, work_ctx);
|
|
||||||
}
|
|
||||||
tae_first_stage->free_compute_buffer();
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t t1 = ggml_time_ms();
|
|
||||||
LOG_DEBUG("computing vae encode graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
void process_latent_in(ggml_tensor* latent) {
|
void process_latent_in(ggml_tensor* latent) {
|
||||||
if (sd_version_is_wan(version)) {
|
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
|
||||||
GGML_ASSERT(latent->ne[3] == 16 || latent->ne[3] == 48);
|
GGML_ASSERT(latent->ne[3] == 16 || latent->ne[3] == 48);
|
||||||
std::vector<float> latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f,
|
std::vector<float> latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f,
|
||||||
0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f};
|
0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f};
|
||||||
@ -1449,7 +1364,7 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
void process_latent_out(ggml_tensor* latent) {
|
void process_latent_out(ggml_tensor* latent) {
|
||||||
if (sd_version_is_wan(version)) {
|
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
|
||||||
GGML_ASSERT(latent->ne[3] == 16 || latent->ne[3] == 48);
|
GGML_ASSERT(latent->ne[3] == 16 || latent->ne[3] == 48);
|
||||||
std::vector<float> latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f,
|
std::vector<float> latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f,
|
||||||
0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f};
|
0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f};
|
||||||
@ -1488,6 +1403,146 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void get_tile_sizes(int& tile_size_x,
|
||||||
|
int& tile_size_y,
|
||||||
|
float& tile_overlap,
|
||||||
|
const sd_tiling_params_t& params,
|
||||||
|
int latent_x,
|
||||||
|
int latent_y,
|
||||||
|
float encoding_factor = 1.0f) {
|
||||||
|
tile_overlap = std::max(std::min(params.target_overlap, 0.5f), 0.0f);
|
||||||
|
auto get_tile_size = [&](int requested_size, float factor, int latent_size) {
|
||||||
|
const int default_tile_size = 32;
|
||||||
|
const int min_tile_dimension = 4;
|
||||||
|
int tile_size = default_tile_size;
|
||||||
|
// factor <= 1 means simple fraction of the latent dimension
|
||||||
|
// factor > 1 means number of tiles across that dimension
|
||||||
|
if (factor > 0.f) {
|
||||||
|
if (factor > 1.0)
|
||||||
|
factor = 1 / (factor - factor * tile_overlap + tile_overlap);
|
||||||
|
tile_size = std::round(latent_size * factor);
|
||||||
|
} else if (requested_size >= min_tile_dimension) {
|
||||||
|
tile_size = requested_size;
|
||||||
|
}
|
||||||
|
tile_size *= encoding_factor;
|
||||||
|
return std::max(std::min(tile_size, latent_size), min_tile_dimension);
|
||||||
|
};
|
||||||
|
|
||||||
|
tile_size_x = get_tile_size(params.tile_size_x, params.rel_size_x, latent_x);
|
||||||
|
tile_size_y = get_tile_size(params.tile_size_y, params.rel_size_y, latent_y);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Runs the encoder (full VAE, or the tiny autoencoder when
// use_tiny_autoencoder is set) on image tensor `x` and returns its raw
// output. Tiling is used when vae_tiling_params.enabled is set and
// `encode_video` is false; in that case the result tensor is pre-allocated
// here. For Qwen Image the input is first reshaped to
// [ne0, ne1, 1, ne2 * ne3].
ggml_tensor* vae_encode(ggml_context* work_ctx, ggml_tensor* x, bool encode_video = false) {
    const int64_t started_ms = ggml_time_ms();
    ggml_tensor* encoded     = NULL;
    int W                    = x->ne[0] / 8;
    int H                    = x->ne[1] / 8;
    const bool tiled         = vae_tiling_params.enabled && !encode_video;

    if (tiled) {
        // TODO wan2.2 vae support?
        int channels = sd_version_is_dit(version) ? 16 : 4;
        if (!use_tiny_autoencoder) {
            // The full VAE output has twice as many channels as the latent.
            channels *= 2;
        }
        encoded = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, channels, x->ne[3]);
    }

    if (sd_version_is_qwen_image(version)) {
        x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]);
    }

    if (use_tiny_autoencoder) {
        if (tiled) {
            // Tiled encode through sd_tiling; each tile goes through the tiny autoencoder.
            auto encode_tile = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
                tae_first_stage->compute(n_threads, in, false, &out, NULL);
            };
            sd_tiling(x, encoded, 8, 64, 0.5f, encode_tile);
        } else {
            tae_first_stage->compute(n_threads, x, false, &encoded, work_ctx);
        }
        tae_first_stage->free_compute_buffer();
    } else {
        process_vae_input_tensor(x);
        if (tiled) {
            float tile_overlap;
            int tile_size_x, tile_size_y;
            // multiply tile size for encode to keep the compute buffer size consistent
            get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, vae_tiling_params, W, H, 1.30539f);

            LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y);

            auto encode_tile = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
                first_stage_model->compute(n_threads, in, false, &out, work_ctx);
            };
            sd_tiling_non_square(x, encoded, 8, tile_size_x, tile_size_y, tile_overlap, encode_tile);
        } else {
            first_stage_model->compute(n_threads, x, false, &encoded, work_ctx);
        }
        first_stage_model->free_compute_buffer();
    }

    const int64_t finished_ms = ggml_time_ms();
    LOG_DEBUG("computing vae encode graph completed, taking %.2fs", (finished_ms - started_ms) * 1.0f / 1000);
    return encoded;
}
|
||||||
|
|
||||||
|
// Draws one sample from the diagonal Gaussian described by `moments`
// (ldm.modules.distributions.distributions.DiagonalGaussianDistribution.sample).
// The first half of moments' dimension 2 is read as the mean and the second
// half as the log-variance, which is clamped to [-30, 20]. The result has
// half as many entries along dimension 2 and is allocated in work_ctx.
ggml_tensor* gaussian_latent_sample(ggml_context* work_ctx, ggml_tensor* moments) {
    ggml_tensor* latent = ggml_new_tensor_4d(work_ctx, moments->type, moments->ne[0], moments->ne[1], moments->ne[2] / 2, moments->ne[3]);

    // Standard-normal noise with the same shape as the latent.
    struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, latent);
    ggml_tensor_set_f32_randn(noise, rng);

    // Offset from a mean channel to its matching log-variance channel.
    const int logvar_offset = (int)latent->ne[2];

    for (int n = 0; n < latent->ne[3]; n++) {
        for (int c = 0; c < latent->ne[2]; c++) {
            for (int y = 0; y < latent->ne[1]; y++) {
                for (int x = 0; x < latent->ne[0]; x++) {
                    const float mean = ggml_tensor_get_f32(moments, x, y, c, n);
                    float logvar     = ggml_tensor_get_f32(moments, x, y, c + logvar_offset, n);
                    logvar           = std::max(-30.0f, std::min(logvar, 20.0f));
                    const float std_ = std::exp(0.5f * logvar);
                    const float draw = mean + std_ * ggml_tensor_get_f32(noise, x, y, c, n);
                    // printf("%d %d %d %d -> %f\n", n, c, y, x, draw);
                    ggml_tensor_set_f32(latent, draw, x, y, c, n);
                }
            }
        }
    }
    return latent;
}
|
||||||
|
|
||||||
|
/**
 * Turn the raw output of vae_encode() into the latent returned to callers.
 *
 * Selection of the latent:
 *  - tiny autoencoder, Qwen Image and Wan: `vae_output` is used as-is;
 *  - VERSION_SD1_PIX2PIX: a 3-D view over the first half of dimension 2 of
 *    `vae_output` (no sampling);
 *  - everything else: a sample drawn by gaussian_latent_sample().
 *
 * The chosen latent is then passed through process_latent_in(). For Qwen Image
 * the result is finally reshaped to (ne[0], ne[1], ne[3], 1), i.e. the extent
 * of dimension 3 is moved into dimension 2.
 * NOTE(review): that reshape presumably relies on ne[2] == 1 for Qwen Image
 * latents — confirm against the VAE output layout.
 *
 * @param work_ctx   context used for any view / tensor created here
 * @param vae_output tensor produced by vae_encode()
 * @return the processed latent tensor
 */
ggml_tensor* get_first_stage_encoding(ggml_context* work_ctx, ggml_tensor* vae_output) {
    ggml_tensor* latent = NULL;

    if (use_tiny_autoencoder || sd_version_is_qwen_image(version) || sd_version_is_wan(version)) {
        latent = vae_output;
    } else if (version == VERSION_SD1_PIX2PIX) {
        // Keep only the first half of dimension 2, without copying.
        latent = ggml_view_3d(work_ctx,
                              vae_output,
                              vae_output->ne[0],
                              vae_output->ne[1],
                              vae_output->ne[2] / 2,
                              vae_output->nb[1],
                              vae_output->nb[2],
                              0);
    } else {
        latent = gaussian_latent_sample(work_ctx, vae_output);
    }

    process_latent_in(latent);

    if (sd_version_is_qwen_image(version)) {
        latent = ggml_reshape_4d(work_ctx, latent, latent->ne[0], latent->ne[1], latent->ne[3], 1);
    }
    return latent;
}
|
||||||
|
|
||||||
|
/**
 * Encode `x` with vae_encode() and convert the result with
 * get_first_stage_encoding(), returning the final latent in one call.
 *
 * @param work_ctx     context forwarded to both helpers
 * @param x            input tensor to encode
 * @param encode_video forwarded unchanged to vae_encode()
 * @return latent produced by get_first_stage_encoding()
 */
ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool encode_video = false) {
    ggml_tensor* vae_output = vae_encode(work_ctx, x, encode_video);
    return get_first_stage_encoding(work_ctx, vae_output);
}
|
||||||
|
|
||||||
ggml_tensor* decode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) {
|
ggml_tensor* decode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) {
|
||||||
int64_t W = x->ne[0] * 8;
|
int64_t W = x->ne[0] * 8;
|
||||||
int64_t H = x->ne[1] * 8;
|
int64_t H = x->ne[1] * 8;
|
||||||
@ -1518,6 +1573,9 @@ public:
|
|||||||
}
|
}
|
||||||
int64_t t0 = ggml_time_ms();
|
int64_t t0 = ggml_time_ms();
|
||||||
if (!use_tiny_autoencoder) {
|
if (!use_tiny_autoencoder) {
|
||||||
|
if (sd_version_is_qwen_image(version)) {
|
||||||
|
x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]);
|
||||||
|
}
|
||||||
process_latent_out(x);
|
process_latent_out(x);
|
||||||
// x = load_tensor_from_file(work_ctx, "wan_vae_z.bin");
|
// x = load_tensor_from_file(work_ctx, "wan_vae_z.bin");
|
||||||
if (vae_tiling_params.enabled && !decode_video) {
|
if (vae_tiling_params.enabled && !decode_video) {
|
||||||
@ -1689,6 +1747,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
|
|||||||
"clip_g_path: %s\n"
|
"clip_g_path: %s\n"
|
||||||
"clip_vision_path: %s\n"
|
"clip_vision_path: %s\n"
|
||||||
"t5xxl_path: %s\n"
|
"t5xxl_path: %s\n"
|
||||||
|
"qwen2vl_path: %s\n"
|
||||||
"diffusion_model_path: %s\n"
|
"diffusion_model_path: %s\n"
|
||||||
"high_noise_diffusion_model_path: %s\n"
|
"high_noise_diffusion_model_path: %s\n"
|
||||||
"vae_path: %s\n"
|
"vae_path: %s\n"
|
||||||
@ -1716,6 +1775,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
|
|||||||
SAFE_STR(sd_ctx_params->clip_g_path),
|
SAFE_STR(sd_ctx_params->clip_g_path),
|
||||||
SAFE_STR(sd_ctx_params->clip_vision_path),
|
SAFE_STR(sd_ctx_params->clip_vision_path),
|
||||||
SAFE_STR(sd_ctx_params->t5xxl_path),
|
SAFE_STR(sd_ctx_params->t5xxl_path),
|
||||||
|
SAFE_STR(sd_ctx_params->qwen2vl_path),
|
||||||
SAFE_STR(sd_ctx_params->diffusion_model_path),
|
SAFE_STR(sd_ctx_params->diffusion_model_path),
|
||||||
SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path),
|
SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path),
|
||||||
SAFE_STR(sd_ctx_params->vae_path),
|
SAFE_STR(sd_ctx_params->vae_path),
|
||||||
@ -2079,6 +2139,8 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
|||||||
C = 16;
|
C = 16;
|
||||||
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
|
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
|
||||||
C = 16;
|
C = 16;
|
||||||
|
} else if (sd_version_is_qwen_image(sd_ctx->sd->version)) {
|
||||||
|
C = 16;
|
||||||
}
|
}
|
||||||
int W = width / 8;
|
int W = width / 8;
|
||||||
int H = height / 8;
|
int H = height / 8;
|
||||||
@ -2086,12 +2148,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
|||||||
|
|
||||||
struct ggml_tensor* control_latent = NULL;
|
struct ggml_tensor* control_latent = NULL;
|
||||||
if (sd_version_is_control(sd_ctx->sd->version) && image_hint != NULL) {
|
if (sd_version_is_control(sd_ctx->sd->version) && image_hint != NULL) {
|
||||||
if (!sd_ctx->sd->use_tiny_autoencoder) {
|
control_latent = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
|
||||||
struct ggml_tensor* control_moments = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
|
|
||||||
control_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
|
|
||||||
} else {
|
|
||||||
control_latent = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
|
|
||||||
}
|
|
||||||
ggml_tensor_scale(control_latent, control_strength);
|
ggml_tensor_scale(control_latent, control_strength);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2278,6 +2335,8 @@ ggml_tensor* generate_init_latent(sd_ctx_t* sd_ctx,
|
|||||||
C = 16;
|
C = 16;
|
||||||
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
|
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
|
||||||
C = 16;
|
C = 16;
|
||||||
|
} else if (sd_version_is_qwen_image(sd_ctx->sd->version)) {
|
||||||
|
C = 16;
|
||||||
} else if (sd_version_is_wan(sd_ctx->sd->version)) {
|
} else if (sd_version_is_wan(sd_ctx->sd->version)) {
|
||||||
C = 16;
|
C = 16;
|
||||||
T = ((T - 1) / 4) + 1;
|
T = ((T - 1) / 4) + 1;
|
||||||
@ -2354,7 +2413,6 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
|
|||||||
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
|
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
|
||||||
|
|
||||||
ggml_tensor* init_latent = NULL;
|
ggml_tensor* init_latent = NULL;
|
||||||
ggml_tensor* init_moments = NULL;
|
|
||||||
ggml_tensor* concat_latent = NULL;
|
ggml_tensor* concat_latent = NULL;
|
||||||
ggml_tensor* denoise_mask = NULL;
|
ggml_tensor* denoise_mask = NULL;
|
||||||
if (sd_img_gen_params->init_image.data) {
|
if (sd_img_gen_params->init_image.data) {
|
||||||
@ -2374,12 +2432,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
|
|||||||
sd_image_to_tensor(sd_img_gen_params->mask_image, mask_img);
|
sd_image_to_tensor(sd_img_gen_params->mask_image, mask_img);
|
||||||
sd_image_to_tensor(sd_img_gen_params->init_image, init_img);
|
sd_image_to_tensor(sd_img_gen_params->init_image, init_img);
|
||||||
|
|
||||||
if (!sd_ctx->sd->use_tiny_autoencoder) {
|
init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
|
||||||
init_moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
|
|
||||||
init_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, init_moments);
|
|
||||||
} else {
|
|
||||||
init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
|
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
|
||||||
int64_t mask_channels = 1;
|
int64_t mask_channels = 1;
|
||||||
@ -2389,16 +2442,12 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
|
|||||||
mask_channels = 1 + init_latent->ne[2];
|
mask_channels = 1 + init_latent->ne[2];
|
||||||
}
|
}
|
||||||
ggml_tensor* masked_latent = NULL;
|
ggml_tensor* masked_latent = NULL;
|
||||||
|
|
||||||
if (sd_ctx->sd->version != VERSION_FLEX_2) {
|
if (sd_ctx->sd->version != VERSION_FLEX_2) {
|
||||||
// most inpaint models mask before vae
|
// most inpaint models mask before vae
|
||||||
ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
|
ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
|
||||||
sd_apply_mask(init_img, mask_img, masked_img);
|
sd_apply_mask(init_img, mask_img, masked_img);
|
||||||
if (!sd_ctx->sd->use_tiny_autoencoder) {
|
masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
|
||||||
ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
|
|
||||||
masked_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
|
|
||||||
} else {
|
|
||||||
masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
// mask after vae
|
// mask after vae
|
||||||
masked_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
|
masked_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
|
||||||
@ -2458,7 +2507,6 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
LOG_INFO("TXT2IMG");
|
LOG_INFO("TXT2IMG");
|
||||||
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
|
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
|
||||||
@ -2497,23 +2545,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
|
|||||||
1);
|
1);
|
||||||
sd_image_to_tensor(*ref_images[i], img);
|
sd_image_to_tensor(*ref_images[i], img);
|
||||||
|
|
||||||
ggml_tensor* latent = NULL;
|
ggml_tensor* latent = sd_ctx->sd->encode_first_stage(work_ctx, img);
|
||||||
if (sd_ctx->sd->use_tiny_autoencoder) {
|
|
||||||
latent = sd_ctx->sd->encode_first_stage(work_ctx, img);
|
|
||||||
} else if (sd_ctx->sd->version == VERSION_SD1_PIX2PIX) {
|
|
||||||
latent = sd_ctx->sd->encode_first_stage(work_ctx, img);
|
|
||||||
latent = ggml_view_3d(work_ctx,
|
|
||||||
latent,
|
|
||||||
latent->ne[0],
|
|
||||||
latent->ne[1],
|
|
||||||
latent->ne[2] / 2,
|
|
||||||
latent->nb[1],
|
|
||||||
latent->nb[2],
|
|
||||||
0);
|
|
||||||
} else {
|
|
||||||
ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, img);
|
|
||||||
latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
|
|
||||||
}
|
|
||||||
ref_latents.push_back(latent);
|
ref_latents.push_back(latent);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2675,8 +2707,6 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
int64_t t2 = ggml_time_ms();
|
int64_t t2 = ggml_time_ms();
|
||||||
LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1);
|
LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1);
|
||||||
|
|
||||||
sd_ctx->sd->process_latent_in(concat_latent);
|
|
||||||
|
|
||||||
ggml_tensor* concat_mask = ggml_new_tensor_4d(work_ctx,
|
ggml_tensor* concat_mask = ggml_new_tensor_4d(work_ctx,
|
||||||
GGML_TYPE_F32,
|
GGML_TYPE_F32,
|
||||||
concat_latent->ne[0],
|
concat_latent->ne[0],
|
||||||
@ -2702,7 +2732,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
sd_image_to_tensor(sd_vid_gen_params->init_image, init_img);
|
sd_image_to_tensor(sd_vid_gen_params->init_image, init_img);
|
||||||
init_img = ggml_reshape_4d(work_ctx, init_img, width, height, 1, 3);
|
init_img = ggml_reshape_4d(work_ctx, init_img, width, height, 1, 3);
|
||||||
|
|
||||||
auto init_image_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img); // [b*c, 1, h/16, w/16]
|
auto init_image_latent = sd_ctx->sd->vae_encode(work_ctx, init_img); // [b*c, 1, h/16, w/16]
|
||||||
|
|
||||||
init_latent = generate_init_latent(sd_ctx, work_ctx, width, height, frames, true);
|
init_latent = generate_init_latent(sd_ctx, work_ctx, width, height, frames, true);
|
||||||
denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
|
denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
|
||||||
@ -2733,7 +2763,6 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
ref_img = ggml_reshape_4d(work_ctx, ref_img, width, height, 1, 3);
|
ref_img = ggml_reshape_4d(work_ctx, ref_img, width, height, 1, 3);
|
||||||
|
|
||||||
ref_image_latent = sd_ctx->sd->encode_first_stage(work_ctx, ref_img); // [b*c, 1, h/16, w/16]
|
ref_image_latent = sd_ctx->sd->encode_first_stage(work_ctx, ref_img); // [b*c, 1, h/16, w/16]
|
||||||
sd_ctx->sd->process_latent_in(ref_image_latent);
|
|
||||||
auto zero_latent = ggml_dup_tensor(work_ctx, ref_image_latent);
|
auto zero_latent = ggml_dup_tensor(work_ctx, ref_image_latent);
|
||||||
ggml_set_f32(zero_latent, 0.f);
|
ggml_set_f32(zero_latent, 0.f);
|
||||||
ref_image_latent = ggml_tensor_concat(work_ctx, ref_image_latent, zero_latent, 3); // [b*2*c, 1, h/16, w/16]
|
ref_image_latent = ggml_tensor_concat(work_ctx, ref_image_latent, zero_latent, 3); // [b*2*c, 1, h/16, w/16]
|
||||||
@ -2765,9 +2794,6 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
inactive = sd_ctx->sd->encode_first_stage(work_ctx, inactive); // [b*c, t, h/8, w/8]
|
inactive = sd_ctx->sd->encode_first_stage(work_ctx, inactive); // [b*c, t, h/8, w/8]
|
||||||
reactive = sd_ctx->sd->encode_first_stage(work_ctx, reactive); // [b*c, t, h/8, w/8]
|
reactive = sd_ctx->sd->encode_first_stage(work_ctx, reactive); // [b*c, t, h/8, w/8]
|
||||||
|
|
||||||
sd_ctx->sd->process_latent_in(inactive);
|
|
||||||
sd_ctx->sd->process_latent_in(reactive);
|
|
||||||
|
|
||||||
int64_t length = inactive->ne[2];
|
int64_t length = inactive->ne[2];
|
||||||
if (ref_image_latent) {
|
if (ref_image_latent) {
|
||||||
length += 1;
|
length += 1;
|
||||||
|
|||||||
@ -131,6 +131,7 @@ typedef struct {
|
|||||||
const char* clip_g_path;
|
const char* clip_g_path;
|
||||||
const char* clip_vision_path;
|
const char* clip_vision_path;
|
||||||
const char* t5xxl_path;
|
const char* t5xxl_path;
|
||||||
|
const char* qwen2vl_path;
|
||||||
const char* diffusion_model_path;
|
const char* diffusion_model_path;
|
||||||
const char* high_noise_diffusion_model_path;
|
const char* high_noise_diffusion_model_path;
|
||||||
const char* vae_path;
|
const char* vae_path;
|
||||||
|
|||||||
985
tokenize_util.cpp
Normal file
985
tokenize_util.cpp
Normal file
@ -0,0 +1,985 @@
|
|||||||
|
#include <algorithm>
#include <cctype>
#include <iostream>
#include <string>
#include <vector>

#include "tokenize_util.h"
|
||||||
|
|
||||||
|
// Returns true only for the ASCII decimal digits U+0030 ('0') .. U+0039 ('9').
// Digits from other scripts are not treated as numbers.
bool is_number(char32_t ch) {
    return ch >= U'0' && ch <= U'9';
}
|
||||||
|
|
||||||
|
// Returns true when `ch` falls inside one of the inclusive code point ranges
// listed below.
// NOTE(review): the table looks like the Unicode "letter" ranges — confirm
// against whatever generated it before regenerating.
//
// The table is sorted by `start` and its ranges do not overlap; the lookup
// relies on that to binary-search instead of scanning every range. Keep the
// ordering intact when editing the table.
bool is_letter(char32_t ch) {
    struct CodepointRange {
        char32_t start, end;
    };
    static const CodepointRange ranges[] = {
        {0x41, 0x5A}, {0x61, 0x7A}, {0xAA, 0xAA}, {0xB5, 0xB5},
        {0xBA, 0xBA}, {0xC0, 0xD6}, {0xD8, 0xF6}, {0xF8, 0x2C1},
        {0x2C6, 0x2D1}, {0x2E0, 0x2E4}, {0x2EC, 0x2EC}, {0x2EE, 0x2EE},
        {0x370, 0x374}, {0x376, 0x377}, {0x37A, 0x37D}, {0x37F, 0x37F},
        {0x386, 0x386}, {0x388, 0x38A}, {0x38C, 0x38C}, {0x38E, 0x3A1},
        {0x3A3, 0x3F5}, {0x3F7, 0x481}, {0x48A, 0x52F}, {0x531, 0x556},
        {0x559, 0x559}, {0x560, 0x588}, {0x5D0, 0x5EA}, {0x5EF, 0x5F2},
        {0x620, 0x64A}, {0x66E, 0x66F}, {0x671, 0x6D3}, {0x6D5, 0x6D5},
        {0x6E5, 0x6E6}, {0x6EE, 0x6EF}, {0x6FA, 0x6FC}, {0x6FF, 0x6FF},
        {0x710, 0x710}, {0x712, 0x72F}, {0x74D, 0x7A5}, {0x7B1, 0x7B1},
        {0x7CA, 0x7EA}, {0x7F4, 0x7F5}, {0x7FA, 0x7FA}, {0x800, 0x815},
        {0x81A, 0x81A}, {0x824, 0x824}, {0x828, 0x828}, {0x840, 0x858},
        {0x860, 0x86A}, {0x870, 0x887}, {0x889, 0x88F}, {0x8A0, 0x8C9},
        {0x904, 0x939}, {0x93D, 0x93D}, {0x950, 0x950}, {0x958, 0x961},
        {0x971, 0x980}, {0x985, 0x98C}, {0x98F, 0x990}, {0x993, 0x9A8},
        {0x9AA, 0x9B0}, {0x9B2, 0x9B2}, {0x9B6, 0x9B9}, {0x9BD, 0x9BD},
        {0x9CE, 0x9CE}, {0x9DC, 0x9DD}, {0x9DF, 0x9E1}, {0x9F0, 0x9F1},
        {0x9FC, 0x9FC}, {0xA05, 0xA0A}, {0xA0F, 0xA10}, {0xA13, 0xA28},
        {0xA2A, 0xA30}, {0xA32, 0xA33}, {0xA35, 0xA36}, {0xA38, 0xA39},
        {0xA59, 0xA5C}, {0xA5E, 0xA5E}, {0xA72, 0xA74}, {0xA85, 0xA8D},
        {0xA8F, 0xA91}, {0xA93, 0xAA8}, {0xAAA, 0xAB0}, {0xAB2, 0xAB3},
        {0xAB5, 0xAB9}, {0xABD, 0xABD}, {0xAD0, 0xAD0}, {0xAE0, 0xAE1},
        {0xAF9, 0xAF9}, {0xB05, 0xB0C}, {0xB0F, 0xB10}, {0xB13, 0xB28},
        {0xB2A, 0xB30}, {0xB32, 0xB33}, {0xB35, 0xB39}, {0xB3D, 0xB3D},
        {0xB5C, 0xB5D}, {0xB5F, 0xB61}, {0xB71, 0xB71}, {0xB83, 0xB83},
        {0xB85, 0xB8A}, {0xB8E, 0xB90}, {0xB92, 0xB95}, {0xB99, 0xB9A},
        {0xB9C, 0xB9C}, {0xB9E, 0xB9F}, {0xBA3, 0xBA4}, {0xBA8, 0xBAA},
        {0xBAE, 0xBB9}, {0xBD0, 0xBD0}, {0xC05, 0xC0C}, {0xC0E, 0xC10},
        {0xC12, 0xC28}, {0xC2A, 0xC39}, {0xC3D, 0xC3D}, {0xC58, 0xC5A},
        {0xC5C, 0xC5D}, {0xC60, 0xC61}, {0xC80, 0xC80}, {0xC85, 0xC8C},
        {0xC8E, 0xC90}, {0xC92, 0xCA8}, {0xCAA, 0xCB3}, {0xCB5, 0xCB9},
        {0xCBD, 0xCBD}, {0xCDC, 0xCDE}, {0xCE0, 0xCE1}, {0xCF1, 0xCF2},
        {0xD04, 0xD0C}, {0xD0E, 0xD10}, {0xD12, 0xD3A}, {0xD3D, 0xD3D},
        {0xD4E, 0xD4E}, {0xD54, 0xD56}, {0xD5F, 0xD61}, {0xD7A, 0xD7F},
        {0xD85, 0xD96}, {0xD9A, 0xDB1}, {0xDB3, 0xDBB}, {0xDBD, 0xDBD},
        {0xDC0, 0xDC6}, {0xE01, 0xE30}, {0xE32, 0xE33}, {0xE40, 0xE46},
        {0xE81, 0xE82}, {0xE84, 0xE84}, {0xE86, 0xE8A}, {0xE8C, 0xEA3},
        {0xEA5, 0xEA5}, {0xEA7, 0xEB0}, {0xEB2, 0xEB3}, {0xEBD, 0xEBD},
        {0xEC0, 0xEC4}, {0xEC6, 0xEC6}, {0xEDC, 0xEDF}, {0xF00, 0xF00},
        {0xF40, 0xF47}, {0xF49, 0xF6C}, {0xF88, 0xF8C}, {0x1000, 0x102A},
        {0x103F, 0x103F}, {0x1050, 0x1055}, {0x105A, 0x105D}, {0x1061, 0x1061},
        {0x1065, 0x1066}, {0x106E, 0x1070}, {0x1075, 0x1081}, {0x108E, 0x108E},
        {0x10A0, 0x10C5}, {0x10C7, 0x10C7}, {0x10CD, 0x10CD}, {0x10D0, 0x10FA},
        {0x10FC, 0x1248}, {0x124A, 0x124D}, {0x1250, 0x1256}, {0x1258, 0x1258},
        {0x125A, 0x125D}, {0x1260, 0x1288}, {0x128A, 0x128D}, {0x1290, 0x12B0},
        {0x12B2, 0x12B5}, {0x12B8, 0x12BE}, {0x12C0, 0x12C0}, {0x12C2, 0x12C5},
        {0x12C8, 0x12D6}, {0x12D8, 0x1310}, {0x1312, 0x1315}, {0x1318, 0x135A},
        {0x1380, 0x138F}, {0x13A0, 0x13F5}, {0x13F8, 0x13FD}, {0x1401, 0x166C},
        {0x166F, 0x167F}, {0x1681, 0x169A}, {0x16A0, 0x16EA}, {0x16F1, 0x16F8},
        {0x1700, 0x1711}, {0x171F, 0x1731}, {0x1740, 0x1751}, {0x1760, 0x176C},
        {0x176E, 0x1770}, {0x1780, 0x17B3}, {0x17D7, 0x17D7}, {0x17DC, 0x17DC},
        {0x1820, 0x1878}, {0x1880, 0x1884}, {0x1887, 0x18A8}, {0x18AA, 0x18AA},
        {0x18B0, 0x18F5}, {0x1900, 0x191E}, {0x1950, 0x196D}, {0x1970, 0x1974},
        {0x1980, 0x19AB}, {0x19B0, 0x19C9}, {0x1A00, 0x1A16}, {0x1A20, 0x1A54},
        {0x1AA7, 0x1AA7}, {0x1B05, 0x1B33}, {0x1B45, 0x1B4C}, {0x1B83, 0x1BA0},
        {0x1BAE, 0x1BAF}, {0x1BBA, 0x1BE5}, {0x1C00, 0x1C23}, {0x1C4D, 0x1C4F},
        {0x1C5A, 0x1C7D}, {0x1C80, 0x1C8A}, {0x1C90, 0x1CBA}, {0x1CBD, 0x1CBF},
        {0x1CE9, 0x1CEC}, {0x1CEE, 0x1CF3}, {0x1CF5, 0x1CF6}, {0x1CFA, 0x1CFA},
        {0x1D00, 0x1DBF}, {0x1E00, 0x1F15}, {0x1F18, 0x1F1D}, {0x1F20, 0x1F45},
        {0x1F48, 0x1F4D}, {0x1F50, 0x1F57}, {0x1F59, 0x1F59}, {0x1F5B, 0x1F5B},
        {0x1F5D, 0x1F5D}, {0x1F5F, 0x1F7D}, {0x1F80, 0x1FB4}, {0x1FB6, 0x1FBC},
        {0x1FBE, 0x1FBE}, {0x1FC2, 0x1FC4}, {0x1FC6, 0x1FCC}, {0x1FD0, 0x1FD3},
        {0x1FD6, 0x1FDB}, {0x1FE0, 0x1FEC}, {0x1FF2, 0x1FF4}, {0x1FF6, 0x1FFC},
        {0x2071, 0x2071}, {0x207F, 0x207F}, {0x2090, 0x209C}, {0x2102, 0x2102},
        {0x2107, 0x2107}, {0x210A, 0x2113}, {0x2115, 0x2115}, {0x2119, 0x211D},
        {0x2124, 0x2124}, {0x2126, 0x2126}, {0x2128, 0x2128}, {0x212A, 0x212D},
        {0x212F, 0x2139}, {0x213C, 0x213F}, {0x2145, 0x2149}, {0x214E, 0x214E},
        {0x2183, 0x2184}, {0x2C00, 0x2CE4}, {0x2CEB, 0x2CEE}, {0x2CF2, 0x2CF3},
        {0x2D00, 0x2D25}, {0x2D27, 0x2D27}, {0x2D2D, 0x2D2D}, {0x2D30, 0x2D67},
        {0x2D6F, 0x2D6F}, {0x2D80, 0x2D96}, {0x2DA0, 0x2DA6}, {0x2DA8, 0x2DAE},
        {0x2DB0, 0x2DB6}, {0x2DB8, 0x2DBE}, {0x2DC0, 0x2DC6}, {0x2DC8, 0x2DCE},
        {0x2DD0, 0x2DD6}, {0x2DD8, 0x2DDE}, {0x2E2F, 0x2E2F}, {0x3005, 0x3006},
        {0x3031, 0x3035}, {0x303B, 0x303C}, {0x3041, 0x3096}, {0x309D, 0x309F},
        {0x30A1, 0x30FA}, {0x30FC, 0x30FF}, {0x3105, 0x312F}, {0x3131, 0x318E},
        {0x31A0, 0x31BF}, {0x31F0, 0x31FF}, {0x3400, 0x4DBF}, {0x4E00, 0xA48C},
        {0xA4D0, 0xA4FD}, {0xA500, 0xA60C}, {0xA610, 0xA61F}, {0xA62A, 0xA62B},
        {0xA640, 0xA66E}, {0xA67F, 0xA69D}, {0xA6A0, 0xA6E5}, {0xA717, 0xA71F},
        {0xA722, 0xA788}, {0xA78B, 0xA7DC}, {0xA7F1, 0xA801}, {0xA803, 0xA805},
        {0xA807, 0xA80A}, {0xA80C, 0xA822}, {0xA840, 0xA873}, {0xA882, 0xA8B3},
        {0xA8F2, 0xA8F7}, {0xA8FB, 0xA8FB}, {0xA8FD, 0xA8FE}, {0xA90A, 0xA925},
        {0xA930, 0xA946}, {0xA960, 0xA97C}, {0xA984, 0xA9B2}, {0xA9CF, 0xA9CF},
        {0xA9E0, 0xA9E4}, {0xA9E6, 0xA9EF}, {0xA9FA, 0xA9FE}, {0xAA00, 0xAA28},
        {0xAA40, 0xAA42}, {0xAA44, 0xAA4B}, {0xAA60, 0xAA76}, {0xAA7A, 0xAA7A},
        {0xAA7E, 0xAAAF}, {0xAAB1, 0xAAB1}, {0xAAB5, 0xAAB6}, {0xAAB9, 0xAABD},
        {0xAAC0, 0xAAC0}, {0xAAC2, 0xAAC2}, {0xAADB, 0xAADD}, {0xAAE0, 0xAAEA},
        {0xAAF2, 0xAAF4}, {0xAB01, 0xAB06}, {0xAB09, 0xAB0E}, {0xAB11, 0xAB16},
        {0xAB20, 0xAB26}, {0xAB28, 0xAB2E}, {0xAB30, 0xAB5A}, {0xAB5C, 0xAB69},
        {0xAB70, 0xABE2}, {0xAC00, 0xD7A3}, {0xD7B0, 0xD7C6}, {0xD7CB, 0xD7FB},
        {0xF900, 0xFA6D}, {0xFA70, 0xFAD9}, {0xFB00, 0xFB06}, {0xFB13, 0xFB17},
        {0xFB1D, 0xFB1D}, {0xFB1F, 0xFB28}, {0xFB2A, 0xFB36}, {0xFB38, 0xFB3C},
        {0xFB3E, 0xFB3E}, {0xFB40, 0xFB41}, {0xFB43, 0xFB44}, {0xFB46, 0xFBB1},
        {0xFBD3, 0xFD3D}, {0xFD50, 0xFD8F}, {0xFD92, 0xFDC7}, {0xFDF0, 0xFDFB},
        {0xFE70, 0xFE74}, {0xFE76, 0xFEFC}, {0xFF21, 0xFF3A}, {0xFF41, 0xFF5A},
        {0xFF66, 0xFFBE}, {0xFFC2, 0xFFC7}, {0xFFCA, 0xFFCF}, {0xFFD2, 0xFFD7},
        {0xFFDA, 0xFFDC}, {0x10000, 0x1000B}, {0x1000D, 0x10026}, {0x10028, 0x1003A},
        {0x1003C, 0x1003D}, {0x1003F, 0x1004D}, {0x10050, 0x1005D}, {0x10080, 0x100FA},
        {0x10280, 0x1029C}, {0x102A0, 0x102D0}, {0x10300, 0x1031F}, {0x1032D, 0x10340},
        {0x10342, 0x10349}, {0x10350, 0x10375}, {0x10380, 0x1039D}, {0x103A0, 0x103C3},
        {0x103C8, 0x103CF}, {0x10400, 0x1049D}, {0x104B0, 0x104D3}, {0x104D8, 0x104FB},
        {0x10500, 0x10527}, {0x10530, 0x10563}, {0x10570, 0x1057A}, {0x1057C, 0x1058A},
        {0x1058C, 0x10592}, {0x10594, 0x10595}, {0x10597, 0x105A1}, {0x105A3, 0x105B1},
        {0x105B3, 0x105B9}, {0x105BB, 0x105BC}, {0x105C0, 0x105F3}, {0x10600, 0x10736},
        {0x10740, 0x10755}, {0x10760, 0x10767}, {0x10780, 0x10785}, {0x10787, 0x107B0},
        {0x107B2, 0x107BA}, {0x10800, 0x10805}, {0x10808, 0x10808}, {0x1080A, 0x10835},
        {0x10837, 0x10838}, {0x1083C, 0x1083C}, {0x1083F, 0x10855}, {0x10860, 0x10876},
        {0x10880, 0x1089E}, {0x108E0, 0x108F2}, {0x108F4, 0x108F5}, {0x10900, 0x10915},
        {0x10920, 0x10939}, {0x10940, 0x10959}, {0x10980, 0x109B7}, {0x109BE, 0x109BF},
        {0x10A00, 0x10A00}, {0x10A10, 0x10A13}, {0x10A15, 0x10A17}, {0x10A19, 0x10A35},
        {0x10A60, 0x10A7C}, {0x10A80, 0x10A9C}, {0x10AC0, 0x10AC7}, {0x10AC9, 0x10AE4},
        {0x10B00, 0x10B35}, {0x10B40, 0x10B55}, {0x10B60, 0x10B72}, {0x10B80, 0x10B91},
        {0x10C00, 0x10C48}, {0x10C80, 0x10CB2}, {0x10CC0, 0x10CF2}, {0x10D00, 0x10D23},
        {0x10D4A, 0x10D65}, {0x10D6F, 0x10D85}, {0x10E80, 0x10EA9}, {0x10EB0, 0x10EB1},
        {0x10EC2, 0x10EC7}, {0x10F00, 0x10F1C}, {0x10F27, 0x10F27}, {0x10F30, 0x10F45},
        {0x10F70, 0x10F81}, {0x10FB0, 0x10FC4}, {0x10FE0, 0x10FF6}, {0x11003, 0x11037},
        {0x11071, 0x11072}, {0x11075, 0x11075}, {0x11083, 0x110AF}, {0x110D0, 0x110E8},
        {0x11103, 0x11126}, {0x11144, 0x11144}, {0x11147, 0x11147}, {0x11150, 0x11172},
        {0x11176, 0x11176}, {0x11183, 0x111B2}, {0x111C1, 0x111C4}, {0x111DA, 0x111DA},
        {0x111DC, 0x111DC}, {0x11200, 0x11211}, {0x11213, 0x1122B}, {0x1123F, 0x11240},
        {0x11280, 0x11286}, {0x11288, 0x11288}, {0x1128A, 0x1128D}, {0x1128F, 0x1129D},
        {0x1129F, 0x112A8}, {0x112B0, 0x112DE}, {0x11305, 0x1130C}, {0x1130F, 0x11310},
        {0x11313, 0x11328}, {0x1132A, 0x11330}, {0x11332, 0x11333}, {0x11335, 0x11339},
        {0x1133D, 0x1133D}, {0x11350, 0x11350}, {0x1135D, 0x11361}, {0x11380, 0x11389},
        {0x1138B, 0x1138B}, {0x1138E, 0x1138E}, {0x11390, 0x113B5}, {0x113B7, 0x113B7},
        {0x113D1, 0x113D1}, {0x113D3, 0x113D3}, {0x11400, 0x11434}, {0x11447, 0x1144A},
        {0x1145F, 0x11461}, {0x11480, 0x114AF}, {0x114C4, 0x114C5}, {0x114C7, 0x114C7},
        {0x11580, 0x115AE}, {0x115D8, 0x115DB}, {0x11600, 0x1162F}, {0x11644, 0x11644},
        {0x11680, 0x116AA}, {0x116B8, 0x116B8}, {0x11700, 0x1171A}, {0x11740, 0x11746},
        {0x11800, 0x1182B}, {0x118A0, 0x118DF}, {0x118FF, 0x11906}, {0x11909, 0x11909},
        {0x1190C, 0x11913}, {0x11915, 0x11916}, {0x11918, 0x1192F}, {0x1193F, 0x1193F},
        {0x11941, 0x11941}, {0x119A0, 0x119A7}, {0x119AA, 0x119D0}, {0x119E1, 0x119E1},
        {0x119E3, 0x119E3}, {0x11A00, 0x11A00}, {0x11A0B, 0x11A32}, {0x11A3A, 0x11A3A},
        {0x11A50, 0x11A50}, {0x11A5C, 0x11A89}, {0x11A9D, 0x11A9D}, {0x11AB0, 0x11AF8},
        {0x11BC0, 0x11BE0}, {0x11C00, 0x11C08}, {0x11C0A, 0x11C2E}, {0x11C40, 0x11C40},
        {0x11C72, 0x11C8F}, {0x11D00, 0x11D06}, {0x11D08, 0x11D09}, {0x11D0B, 0x11D30},
        {0x11D46, 0x11D46}, {0x11D60, 0x11D65}, {0x11D67, 0x11D68}, {0x11D6A, 0x11D89},
        {0x11D98, 0x11D98}, {0x11DB0, 0x11DDB}, {0x11EE0, 0x11EF2}, {0x11F02, 0x11F02},
        {0x11F04, 0x11F10}, {0x11F12, 0x11F33}, {0x11FB0, 0x11FB0}, {0x12000, 0x12399},
        {0x12480, 0x12543}, {0x12F90, 0x12FF0}, {0x13000, 0x1342F}, {0x13441, 0x13446},
        {0x13460, 0x143FA}, {0x14400, 0x14646}, {0x16100, 0x1611D}, {0x16800, 0x16A38},
        {0x16A40, 0x16A5E}, {0x16A70, 0x16ABE}, {0x16AD0, 0x16AED}, {0x16B00, 0x16B2F},
        {0x16B40, 0x16B43}, {0x16B63, 0x16B77}, {0x16B7D, 0x16B8F}, {0x16D40, 0x16D6C},
        {0x16E40, 0x16E7F}, {0x16EA0, 0x16EB8}, {0x16EBB, 0x16ED3}, {0x16F00, 0x16F4A},
        {0x16F50, 0x16F50}, {0x16F93, 0x16F9F}, {0x16FE0, 0x16FE1}, {0x16FE3, 0x16FE3},
        {0x16FF2, 0x16FF3}, {0x17000, 0x18CD5}, {0x18CFF, 0x18D1E}, {0x18D80, 0x18DF2},
        {0x1AFF0, 0x1AFF3}, {0x1AFF5, 0x1AFFB}, {0x1AFFD, 0x1AFFE}, {0x1B000, 0x1B122},
        {0x1B132, 0x1B132}, {0x1B150, 0x1B152}, {0x1B155, 0x1B155}, {0x1B164, 0x1B167},
        {0x1B170, 0x1B2FB}, {0x1BC00, 0x1BC6A}, {0x1BC70, 0x1BC7C}, {0x1BC80, 0x1BC88},
        {0x1BC90, 0x1BC99}, {0x1D400, 0x1D454}, {0x1D456, 0x1D49C}, {0x1D49E, 0x1D49F},
        {0x1D4A2, 0x1D4A2}, {0x1D4A5, 0x1D4A6}, {0x1D4A9, 0x1D4AC}, {0x1D4AE, 0x1D4B9},
        {0x1D4BB, 0x1D4BB}, {0x1D4BD, 0x1D4C3}, {0x1D4C5, 0x1D505}, {0x1D507, 0x1D50A},
        {0x1D50D, 0x1D514}, {0x1D516, 0x1D51C}, {0x1D51E, 0x1D539}, {0x1D53B, 0x1D53E},
        {0x1D540, 0x1D544}, {0x1D546, 0x1D546}, {0x1D54A, 0x1D550}, {0x1D552, 0x1D6A5},
        {0x1D6A8, 0x1D6C0}, {0x1D6C2, 0x1D6DA}, {0x1D6DC, 0x1D6FA}, {0x1D6FC, 0x1D714},
        {0x1D716, 0x1D734}, {0x1D736, 0x1D74E}, {0x1D750, 0x1D76E}, {0x1D770, 0x1D788},
        {0x1D78A, 0x1D7A8}, {0x1D7AA, 0x1D7C2}, {0x1D7C4, 0x1D7CB}, {0x1DF00, 0x1DF1E},
        {0x1DF25, 0x1DF2A}, {0x1E030, 0x1E06D}, {0x1E100, 0x1E12C}, {0x1E137, 0x1E13D},
        {0x1E14E, 0x1E14E}, {0x1E290, 0x1E2AD}, {0x1E2C0, 0x1E2EB}, {0x1E4D0, 0x1E4EB},
        {0x1E5D0, 0x1E5ED}, {0x1E5F0, 0x1E5F0}, {0x1E6C0, 0x1E6DE}, {0x1E6E0, 0x1E6E2},
        {0x1E6E4, 0x1E6E5}, {0x1E6E7, 0x1E6ED}, {0x1E6F0, 0x1E6F4}, {0x1E6FE, 0x1E6FF},
        {0x1E7E0, 0x1E7E6}, {0x1E7E8, 0x1E7EB}, {0x1E7ED, 0x1E7EE}, {0x1E7F0, 0x1E7FE},
        {0x1E800, 0x1E8C4}, {0x1E900, 0x1E943}, {0x1E94B, 0x1E94B}, {0x1EE00, 0x1EE03},
        {0x1EE05, 0x1EE1F}, {0x1EE21, 0x1EE22}, {0x1EE24, 0x1EE24}, {0x1EE27, 0x1EE27},
        {0x1EE29, 0x1EE32}, {0x1EE34, 0x1EE37}, {0x1EE39, 0x1EE39}, {0x1EE3B, 0x1EE3B},
        {0x1EE42, 0x1EE42}, {0x1EE47, 0x1EE47}, {0x1EE49, 0x1EE49}, {0x1EE4B, 0x1EE4B},
        {0x1EE4D, 0x1EE4F}, {0x1EE51, 0x1EE52}, {0x1EE54, 0x1EE54}, {0x1EE57, 0x1EE57},
        {0x1EE59, 0x1EE59}, {0x1EE5B, 0x1EE5B}, {0x1EE5D, 0x1EE5D}, {0x1EE5F, 0x1EE5F},
        {0x1EE61, 0x1EE62}, {0x1EE64, 0x1EE64}, {0x1EE67, 0x1EE6A}, {0x1EE6C, 0x1EE72},
        {0x1EE74, 0x1EE77}, {0x1EE79, 0x1EE7C}, {0x1EE7E, 0x1EE7E}, {0x1EE80, 0x1EE89},
        {0x1EE8B, 0x1EE9B}, {0x1EEA1, 0x1EEA3}, {0x1EEA5, 0x1EEA9}, {0x1EEAB, 0x1EEBB},
        {0x20000, 0x2A6DF}, {0x2A700, 0x2B81D}, {0x2B820, 0x2CEAD}, {0x2CEB0, 0x2EBE0},
        {0x2EBF0, 0x2EE5D}, {0x2F800, 0x2FA1D}, {0x30000, 0x3134A}, {0x31350, 0x33479},
    };

    const CodepointRange* const first = ranges;
    const CodepointRange* const last  = ranges + sizeof(ranges) / sizeof(ranges[0]);

    // First range that starts strictly after `ch`; the only range that can
    // contain `ch` is therefore the one immediately before it.
    const CodepointRange* const next = std::upper_bound(
        first, last, ch,
        [](char32_t value, const CodepointRange& r) { return value < r.start; });
    if (next == first) {
        return false;  // `ch` precedes every range
    }
    return ch <= (next - 1)->end;
}
|
||||||
|
|
||||||
|
// Returns true when `cp` is a Unicode whitespace code point, i.e. a member of
// the Unicode White_Space property (what `\s` matches in a Unicode-aware
// regex engine). Every other code point yields false.
bool is_space(char32_t cp) {
    switch (cp) {
        case 0x0009:  // TAB \t
        case 0x000A:  // LF \n
        case 0x000B:  // VT
        case 0x000C:  // FF
        case 0x000D:  // CR \r
        case 0x0020:  // Space
        case 0x0085:  // Next Line (NEL)
        case 0x00A0:  // No-Break Space
        case 0x1680:  // Ogham Space Mark
        case 0x2000:  // En Quad
        case 0x2001:  // Em Quad
        case 0x2002:  // En Space
        case 0x2003:  // Em Space
        case 0x2004:  // Three-Per-Em Space
        case 0x2005:  // Four-Per-Em Space
        case 0x2006:  // Six-Per-Em Space
        case 0x2007:  // Figure Space
        case 0x2008:  // Punctuation Space
        case 0x2009:  // Thin Space
        case 0x200A:  // Hair Space
        case 0x2028:  // Line Separator
        case 0x2029:  // Paragraph Separator
        case 0x202F:  // Narrow No-Break Space
        case 0x205F:  // Medium Mathematical Space
        case 0x3000:  // Ideographic Space
            return true;
        default:
            return false;
    }
}
|
||||||
|
|
||||||
|
// Returns a copy of `input` in which every byte has been passed through
// std::tolower. The conversion is byte-wise, so only characters the current C
// locale knows how to lower-case (ASCII in the default locale) are affected.
std::string str_to_lower(const std::string& input) {
    std::string lowered;
    lowered.reserve(input.size());
    for (const unsigned char byte : input) {
        // Go through unsigned char so std::tolower never sees a negative value.
        lowered.push_back(static_cast<char>(std::tolower(byte)));
    }
    return lowered;
}
|
||||||
|
|
||||||
|
// UTF-8 -> Unicode code points
|
||||||
|
// UTF-8 -> Unicode code points
//
// Decodes `str` and returns the code points in order. Malformed input never
// aborts decoding: a byte that cannot start a sequence, a sequence cut short
// by the end of the string, or a lead byte whose following bytes are not
// continuation bytes (10xxxxxx) causes only that one byte to be skipped, and
// decoding resumes at the next byte. Overlong encodings and surrogate values
// are not rejected.
std::vector<char32_t> utf8_to_codepoints(const std::string& str) {
    std::vector<char32_t> codepoints;
    size_t i = 0;
    while (i < str.size()) {
        const unsigned char lead = static_cast<unsigned char>(str[i]);
        char32_t cp              = 0;
        size_t extra_bytes       = 0;

        if ((lead & 0x80) == 0) {
            cp = lead;
        } else if ((lead & 0xE0) == 0xC0) {
            cp          = lead & 0x1F;
            extra_bytes = 1;
        } else if ((lead & 0xF0) == 0xE0) {
            cp          = lead & 0x0F;
            extra_bytes = 2;
        } else if ((lead & 0xF8) == 0xF0) {
            cp          = lead & 0x07;
            extra_bytes = 3;
        } else {
            // Invalid UTF-8: stray continuation byte or unsupported lead byte.
            ++i;
            continue;
        }

        // The whole sequence must fit in the string and every trailing byte
        // must be a continuation byte; otherwise the lead byte is bogus.
        bool well_formed = i + extra_bytes < str.size();
        for (size_t j = 1; well_formed && j <= extra_bytes; ++j) {
            const unsigned char cont = static_cast<unsigned char>(str[i + j]);
            if ((cont & 0xC0) != 0x80) {
                well_formed = false;
            } else {
                cp = (cp << 6) | (cont & 0x3F);
            }
        }

        if (!well_formed) {
            // Drop only the lead byte so the bytes after it get a chance to be
            // decoded on their own.
            ++i;
            continue;
        }

        codepoints.push_back(cp);
        i += 1 + extra_bytes;
    }
    return codepoints;
}
|
||||||
|
|
||||||
|
// Unicode code point -> UTF-8
|
||||||
|
// Unicode code point -> UTF-8
//
// Encodes `cp` as 1 byte (<= 0x7F), 2 bytes (<= 0x7FF), 3 bytes (<= 0xFFFF)
// or 4 bytes (everything above). No range or surrogate validation is done.
std::string codepoint_to_utf8(char32_t cp) {
    if (cp <= 0x7F) {
        return std::string(1, static_cast<char>(cp));
    }

    // Number of continuation bytes and the marker bits of the lead byte.
    int trailing    = 3;
    unsigned marker = 0xF0;
    if (cp <= 0x7FF) {
        trailing = 1;
        marker   = 0xC0;
    } else if (cp <= 0xFFFF) {
        trailing = 2;
        marker   = 0xE0;
    }

    std::string bytes;
    bytes.push_back(static_cast<char>(marker | (cp >> (6 * trailing))));
    // Emit the remaining payload six bits at a time, most significant first.
    for (int k = trailing - 1; k >= 0; --k) {
        bytes.push_back(static_cast<char>(0x80 | ((cp >> (6 * k)) & 0x3F)));
    }
    return bytes;
}
|
||||||
|
|
||||||
|
// Reports whether `prefix` occurs in `text` starting exactly at `index`.
// Returns false when `index` lies past the end of `text` or when fewer than
// prefix.size() code points remain from `index`; an empty prefix matches at
// any index up to and including text.size().
bool starts_with(const std::vector<char32_t>& text,
                 const std::vector<char32_t>& prefix,
                 std::size_t index) {
    // The first test short-circuits, so the subtraction cannot wrap around.
    const bool fits = index <= text.size() && prefix.size() <= text.size() - index;
    if (!fits) {
        return false;
    }
    for (std::size_t k = 0; k < prefix.size(); ++k) {
        if (text[index + k] != prefix[k]) {
            return false;
        }
    }
    return true;
}
|
||||||
|
|
||||||
|
std::vector<std::string> token_split(const std::string& text) {
|
||||||
|
std::vector<std::string> tokens;
|
||||||
|
auto cps = utf8_to_codepoints(text);
|
||||||
|
size_t i = 0;
|
||||||
|
|
||||||
|
while (i < cps.size()) {
|
||||||
|
char32_t cp = cps[i];
|
||||||
|
|
||||||
|
// `(?i:'s|'t|'re|'ve|'m|'ll|'d)`
|
||||||
|
if (cp == U'\'' && i + 1 < cps.size()) {
|
||||||
|
std::string next = str_to_lower(codepoint_to_utf8(cps[i + 1]));
|
||||||
|
if (next == "s" || next == "t" || next == "m") {
|
||||||
|
tokens.push_back("'" + next);
|
||||||
|
i += 2;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (i + 2 < cps.size()) {
|
||||||
|
next += str_to_lower(codepoint_to_utf8(cps[i + 2]));
|
||||||
|
if (next == "re" || next == "ve" || next == "ll" || next == "d") {
|
||||||
|
tokens.push_back("'" + next);
|
||||||
|
i += 3;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// `\p{N}`
|
||||||
|
if (is_number(cp)) {
|
||||||
|
tokens.push_back(codepoint_to_utf8(cp));
|
||||||
|
++i;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// `[^\r\n\p{L}\p{N}]?\p{L}+`
|
||||||
|
{
|
||||||
|
// `[^\r\n\p{L}\p{N}]\p{L}+`
|
||||||
|
if (!is_letter(cp) && cp != U'\r' && cp != U'\n' && i + 1 < cps.size() && is_letter(cps[i + 1])) {
|
||||||
|
std::string token = codepoint_to_utf8(cp);
|
||||||
|
++i;
|
||||||
|
|
||||||
|
while (i < cps.size() && is_letter(cps[i])) {
|
||||||
|
token += codepoint_to_utf8(cps[i]);
|
||||||
|
++i;
|
||||||
|
}
|
||||||
|
tokens.push_back(token);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// `\p{L}+`
|
||||||
|
if (is_letter(cp)) {
|
||||||
|
std::string token = codepoint_to_utf8(cp);
|
||||||
|
++i;
|
||||||
|
while (i < cps.size() && is_letter(cps[i])) {
|
||||||
|
token += codepoint_to_utf8(cps[i]);
|
||||||
|
++i;
|
||||||
|
}
|
||||||
|
tokens.push_back(token);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ` ?[^\s\p{L}\p{N}]+[\r\n]*`
|
||||||
|
{
|
||||||
|
// ` [^\s\p{L}\p{N}]+[\r\n]*`
|
||||||
|
if (cp == U' ' && i + 1 < cps.size() && !isspace(cps[i + 1]) && !is_letter(cps[i + 1]) && !is_number(cps[i + 1])) {
|
||||||
|
std::string token = codepoint_to_utf8(cp);
|
||||||
|
token += codepoint_to_utf8(cps[i + 1]);
|
||||||
|
i += 2;
|
||||||
|
|
||||||
|
while (i < cps.size() && !is_letter(cps[i]) && !is_number(cps[i]) && !isspace(cps[i])) {
|
||||||
|
token += codepoint_to_utf8(cps[i]);
|
||||||
|
++i;
|
||||||
|
}
|
||||||
|
|
||||||
|
while (i < cps.size() && (cps[i] == U'\r' || cps[i] == U'\n')) {
|
||||||
|
token += codepoint_to_utf8(cps[i]);
|
||||||
|
++i;
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens.push_back(token);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// `[^\s\p{L}\p{N}]+[\r\n]*`
|
||||||
|
std::string token;
|
||||||
|
if (!is_letter(cps[i]) && !is_number(cps[i]) && !isspace(cps[i])) {
|
||||||
|
std::string token = codepoint_to_utf8(cp);
|
||||||
|
++i;
|
||||||
|
|
||||||
|
while (i < cps.size() && !is_letter(cps[i]) && !is_number(cps[i]) && !isspace(cps[i])) {
|
||||||
|
token += codepoint_to_utf8(cps[i]);
|
||||||
|
++i;
|
||||||
|
}
|
||||||
|
|
||||||
|
while (i < cps.size() && (cps[i] == U'\r' || cps[i] == U'\n')) {
|
||||||
|
token += codepoint_to_utf8(cps[i]);
|
||||||
|
++i;
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens.push_back(token);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// `\s*[\r\n]+|\s+(?!\S)|\s+`
|
||||||
|
if (is_space(cp)) {
|
||||||
|
std::string token = codepoint_to_utf8(cp);
|
||||||
|
++i;
|
||||||
|
|
||||||
|
while (i < cps.size() && is_space(cps[i])) {
|
||||||
|
token += codepoint_to_utf8(cps[i]);
|
||||||
|
++i;
|
||||||
|
if (cps[i] == U'\r' || cps[i] == U'\n') {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens.push_back(token);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// skip
|
||||||
|
++i;
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cuts `text` into alternating ordinary segments and special tokens.
//
// Scanning left to right, the earliest occurrence of any entry of
// `special_tokens` is located; when several entries start at the same
// position, the one listed first wins. The text in front of the match (if
// non-empty) and the matched token are appended to the result, and scanning
// resumes right after the token. Whatever remains after the last match is
// appended as a final segment. Empty entries in `special_tokens` are ignored,
// since an empty string would match everywhere without consuming any text.
std::vector<std::string> split_with_special_tokens(
    const std::string& text,
    const std::vector<std::string>& special_tokens) {
    std::vector<std::string> result;
    size_t pos            = 0;
    const size_t text_len = text.size();

    while (pos < text_len) {
        size_t next_pos            = text_len;
        const std::string* matched = nullptr;

        for (const auto& token : special_tokens) {
            if (token.empty()) {
                continue;
            }
            const size_t token_pos = text.find(token, pos);
            if (token_pos != std::string::npos && token_pos < next_pos) {
                next_pos = token_pos;
                matched  = &token;
            }
        }

        // Ordinary text in front of the match (or the whole tail if no match).
        if (next_pos > pos) {
            result.push_back(text.substr(pos, next_pos - pos));
        }

        if (matched == nullptr) {
            break;
        }
        result.push_back(*matched);
        pos = next_pos + matched->size();
    }

    return result;
}
|
||||||
|
|
||||||
|
// int main() {
|
||||||
|
// std::string text = "I'm testing C++ token_split function. 你好,世界! 123";
|
||||||
|
// auto tokens = token_split(text);
|
||||||
|
|
||||||
|
// for (const auto& t : tokens) {
|
||||||
|
// std::cout << "[" << t << "] ";
|
||||||
|
// }
|
||||||
|
// std::cout << "\n";
|
||||||
|
// return 0;
|
||||||
|
// }
|
||||||
10
tokenize_util.h
Normal file
10
tokenize_util.h
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
#ifndef __TOKENIZE_UTIL__
|
||||||
|
#define __TOKENIZE_UTIL__
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
// Splits UTF-8 `text` into pre-tokenization pieces (contraction suffixes,
// individual numeric code points, letter runs, symbol runs and whitespace
// runs) and returns them in input order as UTF-8 strings.
std::vector<std::string> token_split(const std::string& text);
|
||||||
|
// Cuts `text` at every occurrence of an entry of `special_tokens` and returns
// the ordinary segments and the matched special tokens in input order. The
// earliest match is taken first; on a tie the entry listed first wins.
std::vector<std::string> split_with_special_tokens(const std::string& text, const std::vector<std::string>& special_tokens);
|
||||||
|
|
||||||
|
#endif // __TOKENIZE_UTIL__
|
||||||
139322
vocab_qwen.hpp
Normal file
139322
vocab_qwen.hpp
Normal file
File diff suppressed because it is too large
Load Diff
2
wan.hpp
2
wan.hpp
@ -1833,7 +1833,7 @@ namespace WAN {
|
|||||||
struct ggml_tensor* x) {
|
struct ggml_tensor* x) {
|
||||||
int64_t W = x->ne[0];
|
int64_t W = x->ne[0];
|
||||||
int64_t H = x->ne[1];
|
int64_t H = x->ne[1];
|
||||||
int64_t T = x->ne[1];
|
int64_t T = x->ne[2];
|
||||||
|
|
||||||
int pad_t = (std::get<0>(params.patch_size) - T % std::get<0>(params.patch_size)) % std::get<0>(params.patch_size);
|
int pad_t = (std::get<0>(params.patch_size) - T % std::get<0>(params.patch_size)) % std::get<0>(params.patch_size);
|
||||||
int pad_h = (std::get<1>(params.patch_size) - H % std::get<1>(params.patch_size)) % std::get<1>(params.patch_size);
|
int pad_h = (std::get<1>(params.patch_size) - H % std::get<1>(params.patch_size)) % std::get<1>(params.patch_size);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user