Mirror of https://github.com/leejet/stable-diffusion.cpp.git (synced 2025-12-12 13:28:37 +00:00)
feat: add wan2.1/2.2 support (#778)
* add wan vae support
* add wan model support
* add umt5 support
* add wan2.1 t2i support
* make flash attn work with wan
* make wan a little faster
* add wan2.1 t2v support
* add wan gguf support
* add offload params to cpu support
* add wan2.1 i2v support
* crop image before resize
* set default fps to 16
* add diff lora support
* fix wan2.1 i2v
* introduce sd_sample_params_t
* add wan2.2 t2v support
* add wan2.2 14B i2v support
* add wan2.2 ti2v support
* add high noise lora support
* sync: update ggml submodule url
* avoid build failure on linux
* avoid build failure
* update ggml
* update ggml
* fix sd_version_is_wan
* update ggml, fix cpu im2col_3d
* fix ggml_nn_attention_ext mask
* add cache support to ggml runner
* fix the issue of illegal memory access
* unify image loading processing
* add wan2.1/2.2 FLF2V support
* fix end_image mask
* update to latest ggml
* add GGUFReader
* update docs
Parent: 2eb3845df5
Commit: cb1d975e96
.gitmodules (vendored) — 2 lines changed

@@ -1,3 +1,3 @@
[submodule "ggml"]
	path = ggml
	url = https://github.com/ggerganov/ggml.git
	url = https://github.com/ggml-org/ggml.git
README.md — 62 lines changed

@@ -4,19 +4,33 @@
# stable-diffusion.cpp

Inference of Stable Diffusion and Flux in pure C/C++
Diffusion model(SD,Flux,Wan,...) inference in pure C/C++

***Note that this project is under active development. \
API and command-line parameters may change frequently.***

## Features

- Plain C/C++ implementation based on [ggml](https://github.com/ggerganov/ggml), working in the same way as [llama.cpp](https://github.com/ggerganov/llama.cpp)
- Super lightweight and without external dependencies
- SD1.x, SD2.x, SDXL and [SD3/SD3.5](./docs/sd3.md) support
- Supported models
    - Image Models
        - SD1.x, SD2.x, [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo)
        - SDXL, [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo)
            - !!!The VAE in SDXL encounters NaN issues under FP16, but unfortunately, the ggml_conv_2d only operates under FP16. Hence, a parameter is needed to specify the VAE that has fixed the FP16 NaN issue. You can find it here: [SDXL VAE FP16 Fix](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors).
- [Flux-dev/Flux-schnell Support](./docs/flux.md)
- [FLUX.1-Kontext-dev](./docs/kontext.md)
        - [SD3/SD3.5](./docs/sd3.md)
        - [Flux-dev/Flux-schnell](./docs/flux.md)
        - [Chroma](./docs/chroma.md)
- [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo) and [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo) support
    - Image Edit Models
        - [FLUX.1-Kontext-dev](./docs/kontext.md)
    - Video Models
        - [Wan2.1/Wan2.2](./docs/wan.md)
- [PhotoMaker](https://github.com/TencentARC/PhotoMaker) support.
- Control Net support with SD 1.5
- LoRA support, same as [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#lora)
- Latent Consistency Models support (LCM/LCM-LoRA)
- Faster and memory efficient latent decoding with [TAESD](https://github.com/madebyollin/taesd)
- Upscale images generated with [ESRGAN](https://github.com/xinntao/Real-ESRGAN)
- 16-bit, 32-bit float support
- 2-bit, 3-bit, 4-bit, 5-bit and 8-bit integer quantization support
- Accelerated memory-efficient CPU inference
@@ -26,15 +40,9 @@ Inference of Stable Diffusion and Flux in pure C/C++
- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs models
    - No need to convert to `.ggml` or `.gguf` anymore!
- Flash Attention for memory usage optimization
- Original `txt2img` and `img2img` mode
- Negative prompt
- [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) style tokenizer (not all the features, only token weighting for now)
- LoRA support, same as [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#lora)
- Latent Consistency Models support (LCM/LCM-LoRA)
- Faster and memory efficient latent decoding with [TAESD](https://github.com/madebyollin/taesd)
- Upscale images generated with [ESRGAN](https://github.com/xinntao/Real-ESRGAN)
- VAE tiling processing for reduce memory usage
- Control Net support with SD 1.5
- Sampling method
    - `Euler A`
    - `Euler`
@@ -287,8 +295,10 @@ arguments:
                                     If threads <= 0, then threads will be set to the number of CPU physical cores
  -m, --model [MODEL]                path to full model
  --diffusion-model                  path to the standalone diffusion model
  --high-noise-diffusion-model       path to the standalone high noise diffusion model
  --clip_l                           path to the clip-l text encoder
  --clip_g                           path to the clip-g text encoder
  --clip_vision                      path to the clip-vision encoder
  --t5xxl                            path to the t5xxl text encoder
  --vae [VAE]                        path to vae
  --taesd [TAESD_PATH]               path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
@@ -303,8 +313,9 @@ arguments:
                                     If not specified, the default is the type of the weight file
  --tensor-type-rules [EXPRESSION]   weight type per tensor pattern (example: "^vae\.=f16,model\.=q8_0")
  --lora-model-dir [DIR]             lora model directory
  -i, --init-img [IMAGE]             path to the input image, required by img2img
  -i, --init-img [IMAGE]             path to the init image, required by img2img
  --mask [MASK]                      path to the mask image, required by img2img with mask
  -i, --end-img [IMAGE]              path to the end image, required by flf2v
  --control-image [IMAGE]            path to image condition, control net
  -r, --ref-image [PATH]             reference image for Flux Kontext models (can be used multiple times)
  -o, --output OUTPUT                path to write result image to (default: ./output.png)
@@ -319,6 +330,23 @@ arguments:
  --skip-layers LAYERS               Layers to skip for SLG steps: (default: [7,8,9])
  --skip-layer-start START           SLG enabling point: (default: 0.01)
  --skip-layer-end END               SLG disabling point: (default: 0.2)
  --scheduler {discrete, karras, exponential, ays, gits}  Denoiser sigma scheduler (default: discrete)
  --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd}
                                     sampling method (default: "euler_a")
  --steps STEPS                      number of sample steps (default: 20)
  --high-noise-cfg-scale SCALE       (high noise) unconditional guidance scale: (default: 7.0)
  --high-noise-img-cfg-scale SCALE   (high noise) image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale)
  --high-noise-guidance SCALE        (high noise) distilled guidance scale for models with guidance input (default: 3.5)
  --high-noise-slg-scale SCALE       (high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0)
                                     0 means disabled, a value of 2.5 is nice for sd3.5 medium
  --high-noise-eta SCALE             (high noise) eta in DDIM, only for DDIM and TCD: (default: 0)
  --high-noise-skip-layers LAYERS    (high noise) Layers to skip for SLG steps: (default: [7,8,9])
  --high-noise-skip-layer-start      (high noise) SLG enabling point: (default: 0.01)
  --high-noise-skip-layer-end END    (high noise) SLG disabling point: (default: 0.2)
  --high-noise-scheduler {discrete, karras, exponential, ays, gits}  Denoiser sigma scheduler (default: discrete)
  --high-noise-sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd}
                                     (high noise) sampling method (default: "euler_a")
  --high-noise-steps STEPS           (high noise) number of sample steps (default: 20)
                                     SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END])
  --strength STRENGTH                strength for noising/unnoising (default: 0.75)
  --style-ratio STYLE-RATIO          strength for keeping input identity (default: 20)
@@ -326,14 +354,10 @@ arguments:
                                     1.0 corresponds to full destruction of information in init image
  -H, --height H                     image height, in pixel space (default: 512)
  -W, --width W                      image width, in pixel space (default: 512)
  --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd}
                                     sampling method (default: "euler_a")
  --steps STEPS                      number of sample steps (default: 20)
  --rng {std_default, cuda}          RNG (default: cuda)
  -s SEED, --seed SEED               RNG seed (default: 42, use random seed for < 0)
  -b, --batch-count COUNT            number of images to generate
  --schedule {discrete, karras, exponential, ays, gits}  Denoiser sigma schedule (default: discrete)
  --clip-skip N                      ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)
  --clip-skip N                      ignore last_dot_pos layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)
                                     <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x
  --vae-tiling                       process vae in tiles to reduce memory usage
  --vae-on-cpu                       keep vae in cpu (for low vram)
@@ -351,6 +375,8 @@ arguments:
  --chroma-disable-dit-mask          disable dit mask for chroma
  --chroma-enable-t5-mask            enable t5 mask for chroma
  --chroma-t5-mask-pad PAD_SIZE      t5 mask pad size of chroma
  --video-frames                     video frames (default: 1)
  --fps                              fps (default: 24)
  -v, --verbose                      print extra info
```
@@ -438,3 +464,5 @@ Thank you to all the people who have already contributed to stable-diffusion.cpp
- [latent-consistency-model](https://github.com/luosiallen/latent-consistency-model)
- [generative-models](https://github.com/Stability-AI/generative-models/)
- [PhotoMaker](https://github.com/TencentARC/PhotoMaker)
- [Wan2.1](https://github.com/Wan-Video/Wan2.1)
- [Wan2.2](https://github.com/Wan-Video/Wan2.2)
New binary files (content not shown):

- assets/wan/Wan2.1_1.3B_t2v.mp4
- assets/wan/Wan2.1_14B_flf2v.mp4
- assets/wan/Wan2.1_14B_i2v.mp4
- assets/wan/Wan2.1_14B_t2v.mp4
- assets/wan/Wan2.2_14B_flf2v.mp4
- assets/wan/Wan2.2_14B_i2v.mp4
- assets/wan/Wan2.2_14B_t2i.png (594 KiB)
- assets/wan/Wan2.2_14B_t2v.mp4
- assets/wan/Wan2.2_14B_t2v_lora.mp4
- assets/wan/Wan2.2_5B_i2v.mp4
- assets/wan/Wan2.2_5B_t2v.mp4
clip.hpp — 28 lines changed

@@ -733,7 +733,7 @@ public:
        if (text_projection != NULL) {
            pooled = ggml_nn_linear(ctx, pooled, text_projection, NULL);
        } else {
            LOG_DEBUG("Missing text_projection matrix, assuming identity...");
            LOG_DEBUG("identity projection");
        }
        return pooled;  // [hidden_size, 1, 1]
    }

@@ -774,7 +774,10 @@ public:
        blocks["post_layernorm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size));
    }

    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* pixel_values, bool return_pooled = true) {
    struct ggml_tensor* forward(struct ggml_context* ctx,
                                struct ggml_tensor* pixel_values,
                                bool return_pooled = true,
                                int clip_skip = -1) {
        // pixel_values: [N, num_channels, image_size, image_size]
        auto embeddings    = std::dynamic_pointer_cast<CLIPVisionEmbeddings>(blocks["embeddings"]);
        auto pre_layernorm = std::dynamic_pointer_cast<LayerNorm>(blocks["pre_layernorm"]);

@@ -783,7 +786,7 @@ public:
        auto x = embeddings->forward(ctx, pixel_values);  // [N, num_positions, embed_dim]
        x      = pre_layernorm->forward(ctx, x);
        x      = encoder->forward(ctx, x, -1, false);
        x      = encoder->forward(ctx, x, clip_skip, false);
        // print_ggml_tensor(x, true, "ClipVisionModel x: ");
        auto last_hidden_state = x;
        x                      = post_layernorm->forward(ctx, x);  // [N, n_token, hidden_size]

@@ -851,16 +854,22 @@ public:
        blocks["visual_projection"] = std::shared_ptr<GGMLBlock>(new CLIPProjection(hidden_size, projection_dim, transpose_proj_w));
    }

    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* pixel_values) {
    struct ggml_tensor* forward(struct ggml_context* ctx,
                                struct ggml_tensor* pixel_values,
                                bool return_pooled = true,
                                int clip_skip = -1) {
        // pixel_values: [N, num_channels, image_size, image_size]
        // return: [N, projection_dim]
        // return: [N, projection_dim] if return_pooled else [N, n_token, hidden_size]
        auto vision_model      = std::dynamic_pointer_cast<CLIPVisionModel>(blocks["vision_model"]);
        auto visual_projection = std::dynamic_pointer_cast<CLIPProjection>(blocks["visual_projection"]);

        auto x = vision_model->forward(ctx, pixel_values);  // [N, hidden_size]
        x      = visual_projection->forward(ctx, x);        // [N, projection_dim]
        auto x = vision_model->forward(ctx, pixel_values, return_pooled, clip_skip);  // [N, hidden_size] or [N, n_token, hidden_size]

        return x;  // [N, projection_dim]
        if (return_pooled) {
            x = visual_projection->forward(ctx, x);  // [N, projection_dim]
        }

        return x;
    }
};

@@ -868,12 +877,13 @@ struct CLIPTextModelRunner : public GGMLRunner {
    CLIPTextModel model;

    CLIPTextModelRunner(ggml_backend_t backend,
                        bool offload_params_to_cpu,
                        const String2GGMLType& tensor_types,
                        const std::string prefix,
                        CLIPVersion version = OPENAI_CLIP_VIT_L_14,
                        bool with_final_ln  = true,
                        int clip_skip_value = -1)
        : GGMLRunner(backend), model(version, with_final_ln, clip_skip_value) {
        : GGMLRunner(backend, offload_params_to_cpu), model(version, with_final_ln, clip_skip_value) {
        model.init(params_ctx, tensor_types, prefix);
    }
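The runner constructors now take an `offload_params_to_cpu` flag ahead of the tensor-type map. A minimal sketch of a call site under the new signature (assumes `backend` and `tensor_types` are already set up; not taken verbatim from this diff):

```
// Hedged sketch: CLIP-L text-encoder runner whose weights are kept on the
// CPU between computations; argument order follows the constructor above.
CLIPTextModelRunner clip_l(backend,
                           /*offload_params_to_cpu=*/true,
                           tensor_types,
                           "cond_stage_model.transformer.text_model",
                           OPENAI_CLIP_VIT_L_14);
```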
conditioner.hpp — 118 lines changed

@@ -22,7 +22,7 @@ struct Conditioner {
                 int width,
                 int height,
                 int adm_in_channels = -1,
                 bool force_zero_embeddings = false) = 0;
                 bool zero_out_masked = false) = 0;
    virtual void alloc_params_buffer() = 0;
    virtual void free_params_buffer() = 0;
    virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;

@@ -35,7 +35,7 @@ struct Conditioner {
                 int height,
                 int num_input_imgs,
                 int adm_in_channels = -1,
                 bool force_zero_embeddings = false) = 0;
                 bool zero_out_masked = false) = 0;
    virtual std::string remove_trigger_from_prompt(ggml_context* work_ctx,
                                                   const std::string& prompt) = 0;
};

@@ -57,6 +57,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
    std::vector<std::string> readed_embeddings;

    FrozenCLIPEmbedderWithCustomWords(ggml_backend_t backend,
                                      bool offload_params_to_cpu,
                                      const String2GGMLType& tensor_types,
                                      const std::string& embd_dir,
                                      SDVersion version = VERSION_SD1,

@@ -64,12 +65,12 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
                                      int clip_skip = -1)
        : version(version), pm_version(pv), tokenizer(sd_version_is_sd2(version) ? 0 : 49407), embd_dir(embd_dir) {
        if (sd_version_is_sd1(version)) {
            text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14);
            text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14);
        } else if (sd_version_is_sd2(version)) {
            text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14);
            text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14);
        } else if (sd_version_is_sdxl(version)) {
            text_model  = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, false);
            text_model2 = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false);
            text_model  = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, false);
            text_model2 = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false);
        }
        set_clip_skip(clip_skip);
    }

@@ -154,7 +155,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
            }
            return true;
        };
        model_loader.load_tensors(on_load, NULL);
        model_loader.load_tensors(on_load);
        readed_embeddings.push_back(embd_name);
        if (embd) {
            int64_t hidden_size = text_model->model.hidden_size;

@@ -410,7 +411,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
                 int width,
                 int height,
                 int adm_in_channels = -1,
                 bool force_zero_embeddings = false) {
                 bool zero_out_masked = false) {
        set_clip_skip(clip_skip);
        int64_t t0 = ggml_time_ms();
        struct ggml_tensor* hidden_states = NULL;  // [N, n_token, hidden_size]

@@ -499,7 +500,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
            float new_mean = ggml_tensor_mean(result);
            ggml_tensor_scale(result, (original_mean / new_mean));
        }
        if (force_zero_embeddings) {
        if (zero_out_masked) {
            float* vec = (float*)result->data;
            for (int i = 0; i < ggml_nelements(result); i++) {
                vec[i] = 0;

@@ -563,7 +564,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
                 int height,
                 int num_input_imgs,
                 int adm_in_channels = -1,
                 bool force_zero_embeddings = false) {
                 bool zero_out_masked = false) {
        auto image_tokens = convert_token_to_id(trigger_word);
        // if(image_tokens.size() == 1){
        //     printf(" image token id is: %d \n", image_tokens[0]);

@@ -584,7 +585,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
        // for(int i = 0; i < clsm.size(); ++i)
        //     printf("%d ", clsm[i]?1:0);
        // printf("\n");
        auto cond = get_learned_condition_common(work_ctx, n_threads, tokens, weights, clip_skip, width, height, adm_in_channels, force_zero_embeddings);
        auto cond = get_learned_condition_common(work_ctx, n_threads, tokens, weights, clip_skip, width, height, adm_in_channels, zero_out_masked);
        return std::make_tuple(cond, clsm);
    }

@@ -607,19 +608,21 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
                 int width,
                 int height,
                 int adm_in_channels = -1,
                 bool force_zero_embeddings = false) {
                 bool zero_out_masked = false) {
        auto tokens_and_weights     = tokenize(text, true);
        std::vector<int>& tokens    = tokens_and_weights.first;
        std::vector<float>& weights = tokens_and_weights.second;
        return get_learned_condition_common(work_ctx, n_threads, tokens, weights, clip_skip, width, height, adm_in_channels, force_zero_embeddings);
        return get_learned_condition_common(work_ctx, n_threads, tokens, weights, clip_skip, width, height, adm_in_channels, zero_out_masked);
    }
};

struct FrozenCLIPVisionEmbedder : public GGMLRunner {
    CLIPVisionModelProjection vision_model;

    FrozenCLIPVisionEmbedder(ggml_backend_t backend, const String2GGMLType& tensor_types = {})
        : vision_model(OPEN_CLIP_VIT_H_14, true), GGMLRunner(backend) {
    FrozenCLIPVisionEmbedder(ggml_backend_t backend,
                             bool offload_params_to_cpu,
                             const String2GGMLType& tensor_types = {})
        : vision_model(OPEN_CLIP_VIT_H_14), GGMLRunner(backend, offload_params_to_cpu) {
        vision_model.init(params_ctx, tensor_types, "cond_stage_model.transformer");
    }

@@ -631,12 +634,12 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
        vision_model.get_param_tensors(tensors, "cond_stage_model.transformer");
    }

    struct ggml_cgraph* build_graph(struct ggml_tensor* pixel_values) {
    struct ggml_cgraph* build_graph(struct ggml_tensor* pixel_values, bool return_pooled, int clip_skip) {
        struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);

        pixel_values = to_backend(pixel_values);

        struct ggml_tensor* hidden_states = vision_model.forward(compute_ctx, pixel_values);
        struct ggml_tensor* hidden_states = vision_model.forward(compute_ctx, pixel_values, return_pooled, clip_skip);

        ggml_build_forward_expand(gf, hidden_states);

@@ -645,10 +648,12 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {

    void compute(const int n_threads,
                 ggml_tensor* pixel_values,
                 bool return_pooled,
                 int clip_skip,
                 ggml_tensor** output,
                 ggml_context* output_ctx) {
        auto get_graph = [&]() -> struct ggml_cgraph* {
            return build_graph(pixel_values);
            return build_graph(pixel_values, return_pooled, clip_skip);
        };
        GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
    }
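With `return_pooled` and `clip_skip` now threaded through the vision embedder, a caller can request per-token hidden states instead of the projected pooled vector, e.g. for image-conditioned video models. A minimal sketch (assumes an existing `FrozenCLIPVisionEmbedder` named `clip_vision`, a `work_ctx`, and preprocessed `pixel_values`; none of these names come from this diff):

```
// Hedged sketch: fetch [N, n_token, hidden_size] hidden states rather than
// the pooled [N, projection_dim] output; clip_skip = -1 keeps the default.
ggml_tensor* hidden_states = NULL;
clip_vision->compute(n_threads,
                     pixel_values,
                     /*return_pooled=*/false,
                     /*clip_skip=*/-1,
                     &hidden_states,
                     work_ctx);
```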
@@ -663,12 +668,13 @@ struct SD3CLIPEmbedder : public Conditioner {
    std::shared_ptr<T5Runner> t5;

    SD3CLIPEmbedder(ggml_backend_t backend,
                    bool offload_params_to_cpu,
                    const String2GGMLType& tensor_types = {},
                    int clip_skip = -1)
        : clip_g_tokenizer(0) {
        clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, false);
        clip_g = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false);
        t5     = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
        clip_l = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, false);
        clip_g = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false);
        t5     = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.t5xxl.transformer");
        set_clip_skip(clip_skip);
    }

@@ -773,7 +779,7 @@ struct SD3CLIPEmbedder : public Conditioner {
                 int n_threads,
                 std::vector<std::pair<std::vector<int>, std::vector<float>>> token_and_weights,
                 int clip_skip,
                 bool force_zero_embeddings = false) {
                 bool zero_out_masked = false) {
        set_clip_skip(clip_skip);
        auto& clip_l_tokens  = token_and_weights[0].first;
        auto& clip_l_weights = token_and_weights[0].second;

@@ -952,7 +958,7 @@ struct SD3CLIPEmbedder : public Conditioner {

        int64_t t1 = ggml_time_ms();
        LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
        if (force_zero_embeddings) {
        if (zero_out_masked) {
            float* vec = (float*)chunk_hidden_states->data;
            for (int i = 0; i < ggml_nelements(chunk_hidden_states); i++) {
                vec[i] = 0;

@@ -979,9 +985,9 @@ struct SD3CLIPEmbedder : public Conditioner {
                 int width,
                 int height,
                 int adm_in_channels = -1,
                 bool force_zero_embeddings = false) {
                 bool zero_out_masked = false) {
        auto tokens_and_weights = tokenize(text, 77, true);
        return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings);
        return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
    }

    std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,

@@ -992,7 +998,7 @@ struct SD3CLIPEmbedder : public Conditioner {
                 int height,
                 int num_input_imgs,
                 int adm_in_channels = -1,
                 bool force_zero_embeddings = false) {
                 bool zero_out_masked = false) {
        GGML_ASSERT(0 && "Not implemented yet!");
    }

@@ -1010,10 +1016,11 @@ struct FluxCLIPEmbedder : public Conditioner {
    size_t chunk_len = 256;

    FluxCLIPEmbedder(ggml_backend_t backend,
                     bool offload_params_to_cpu,
                     const String2GGMLType& tensor_types = {},
                     int clip_skip = -1) {
        clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, true);
        t5     = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
        clip_l = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, true);
        t5     = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.t5xxl.transformer");
        set_clip_skip(clip_skip);
    }

@@ -1101,7 +1108,7 @@ struct FluxCLIPEmbedder : public Conditioner {
                 int n_threads,
                 std::vector<std::pair<std::vector<int>, std::vector<float>>> token_and_weights,
                 int clip_skip,
                 bool force_zero_embeddings = false) {
                 bool zero_out_masked = false) {
        set_clip_skip(clip_skip);
        auto& clip_l_tokens  = token_and_weights[0].first;
        auto& clip_l_weights = token_and_weights[0].second;

@@ -1173,7 +1180,7 @@ struct FluxCLIPEmbedder : public Conditioner {

        int64_t t1 = ggml_time_ms();
        LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
        if (force_zero_embeddings) {
        if (zero_out_masked) {
            float* vec = (float*)chunk_hidden_states->data;
            for (int i = 0; i < ggml_nelements(chunk_hidden_states); i++) {
                vec[i] = 0;

@@ -1200,9 +1207,9 @@ struct FluxCLIPEmbedder : public Conditioner {
                 int width,
                 int height,
                 int adm_in_channels = -1,
                 bool force_zero_embeddings = false) {
                 bool zero_out_masked = false) {
        auto tokens_and_weights = tokenize(text, chunk_len, true);
        return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings);
        return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
    }

    std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,

@@ -1213,7 +1220,7 @@ struct FluxCLIPEmbedder : public Conditioner {
                 int height,
                 int num_input_imgs,
                 int adm_in_channels = -1,
                 bool force_zero_embeddings = false) {
                 bool zero_out_masked = false) {
        GGML_ASSERT(0 && "Not implemented yet!");
    }

@@ -1223,20 +1230,23 @@ struct FluxCLIPEmbedder : public Conditioner {
    }
};

struct PixArtCLIPEmbedder : public Conditioner {
struct T5CLIPEmbedder : public Conditioner {
    T5UniGramTokenizer t5_tokenizer;
    std::shared_ptr<T5Runner> t5;
    size_t chunk_len = 512;
    bool use_mask    = false;
    int mask_pad     = 1;
    bool is_umt5     = false;

    PixArtCLIPEmbedder(ggml_backend_t backend,
    T5CLIPEmbedder(ggml_backend_t backend,
                   bool offload_params_to_cpu,
                   const String2GGMLType& tensor_types = {},
                   int clip_skip = -1,
                   bool use_mask = false,
                   int mask_pad = 1)
        : use_mask(use_mask), mask_pad(mask_pad) {
        t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
                   int mask_pad = 1,
                   bool is_umt5 = false)
        : use_mask(use_mask), mask_pad(mask_pad), t5_tokenizer(is_umt5) {
        t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.t5xxl.transformer", is_umt5);
    }
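The renamed `T5CLIPEmbedder` doubles as the UMT5-XXL encoder for Wan via the new `is_umt5` flag. A hypothetical construction (backend and tensor types assumed to exist; not taken verbatim from this diff):

```
// Hedged sketch: UMT5-XXL text encoder for a Wan pipeline.
auto text_encoder = std::make_shared<T5CLIPEmbedder>(backend,
                                                     /*offload_params_to_cpu=*/true,
                                                     tensor_types,
                                                     /*clip_skip=*/-1,
                                                     /*use_mask=*/false,
                                                     /*mask_pad=*/1,
                                                     /*is_umt5=*/true);
```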
    void set_clip_skip(int clip_skip) {

@@ -1317,7 +1327,7 @@ struct PixArtCLIPEmbedder : public Conditioner {
                 int n_threads,
                 std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> token_and_weights,
                 int clip_skip,
                 bool force_zero_embeddings = false) {
                 bool zero_out_masked = false) {
        auto& t5_tokens        = std::get<0>(token_and_weights);
        auto& t5_weights       = std::get<1>(token_and_weights);
        auto& t5_attn_mask_vec = std::get<2>(token_and_weights);

@@ -1325,8 +1335,8 @@ struct PixArtCLIPEmbedder : public Conditioner {
        int64_t t0 = ggml_time_ms();
        struct ggml_tensor* hidden_states       = NULL;  // [N, n_token, 4096]
        struct ggml_tensor* chunk_hidden_states = NULL;  // [n_token, 4096]
        struct ggml_tensor* pooled              = NULL;  // [768,]
        struct ggml_tensor* t5_attn_mask        = vector_to_ggml_tensor(work_ctx, t5_attn_mask_vec);  // [768,]
        struct ggml_tensor* pooled              = NULL;
        struct ggml_tensor* t5_attn_mask        = vector_to_ggml_tensor(work_ctx, t5_attn_mask_vec);  // [n_token]

        std::vector<float> hidden_states_vec;

@@ -1367,10 +1377,16 @@ struct PixArtCLIPEmbedder : public Conditioner {

        int64_t t1 = ggml_time_ms();
        LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
        if (force_zero_embeddings) {
            float* vec = (float*)chunk_hidden_states->data;
            for (int i = 0; i < ggml_nelements(chunk_hidden_states); i++) {
                vec[i] = 0;
        if (zero_out_masked) {
            auto tensor = chunk_hidden_states;
            for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
                for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
                    for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
                        if (chunk_mask[i1] < 0.f) {
                            ggml_tensor_set_f32(tensor, 0.f, i0, i1, i2);
                        }
                    }
                }
            }
        }

@@ -1379,16 +1395,12 @@ struct PixArtCLIPEmbedder : public Conditioner {
                               ((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
        }

        if (hidden_states_vec.size() > 0) {
        GGML_ASSERT(hidden_states_vec.size() > 0);
        hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
        hidden_states = ggml_reshape_2d(work_ctx,
                                        hidden_states,
                                        chunk_hidden_states->ne[0],
                                        ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
        } else {
            hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 256);
            ggml_set_f32(hidden_states, 0.f);
        }

        modify_mask_to_attend_padding(t5_attn_mask, ggml_nelements(t5_attn_mask), mask_pad);

@@ -1402,9 +1414,9 @@ struct PixArtCLIPEmbedder : public Conditioner {
                 int width,
                 int height,
                 int adm_in_channels = -1,
                 bool force_zero_embeddings = false) {
                 bool zero_out_masked = false) {
        auto tokens_and_weights = tokenize(text, chunk_len, true);
        return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings);
        return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
    }

    std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,

@@ -1415,7 +1427,7 @@ struct PixArtCLIPEmbedder : public Conditioner {
                 int height,
                 int num_input_imgs,
                 int adm_in_channels = -1,
                 bool force_zero_embeddings = false) {
                 bool zero_out_masked = false) {
        GGML_ASSERT(0 && "Not implemented yet!");
    }
@@ -317,9 +317,10 @@ struct ControlNet : public GGMLRunner {
    bool guided_hint_cached = false;

    ControlNet(ggml_backend_t backend,
               bool offload_params_to_cpu,
               const String2GGMLType& tensor_types = {},
               SDVersion version = VERSION_SD1)
        : GGMLRunner(backend), control_net(version) {
        : GGMLRunner(backend, offload_params_to_cpu), control_net(version) {
        control_net.init(params_ctx, tensor_types, "");
    }

@@ -357,7 +358,7 @@ struct ControlNet : public GGMLRunner {
            control_buffer_size += ggml_nbytes(controls[i]);
        }

        control_buffer = ggml_backend_alloc_ctx_tensors(control_ctx, backend);
        control_buffer = ggml_backend_alloc_ctx_tensors(control_ctx, runtime_backend);

        LOG_DEBUG("control buffer size %.2fMB", control_buffer_size * 1.f / 1024.f / 1024.f);
    }

@@ -454,7 +455,7 @@ struct ControlNet : public GGMLRunner {
            return false;
        }

        bool success = model_loader.load_tensors(tensors, backend, ignore_tensors);
        bool success = model_loader.load_tensors(tensors, ignore_tensors);

        if (!success) {
            LOG_ERROR("load control net tensors from model loader failed");
@@ -252,7 +252,7 @@ struct KarrasSchedule : SigmaSchedule {
};

struct Denoiser {
    std::shared_ptr<SigmaSchedule> schedule = std::make_shared<DiscreteSchedule>();
    std::shared_ptr<SigmaSchedule> scheduler = std::make_shared<DiscreteSchedule>();
    virtual float sigma_min() = 0;
    virtual float sigma_max() = 0;
    virtual float sigma_to_t(float sigma) = 0;

@@ -263,7 +263,7 @@ struct Denoiser {

    virtual std::vector<float> get_sigmas(uint32_t n) {
        auto bound_t_to_sigma = std::bind(&Denoiser::t_to_sigma, this, std::placeholders::_1);
        return schedule->get_sigmas(n, sigma_min(), sigma_max(), bound_t_to_sigma);
        return scheduler->get_sigmas(n, sigma_min(), sigma_max(), bound_t_to_sigma);
    }
};

@@ -349,7 +349,7 @@ struct EDMVDenoiser : public CompVisVDenoiser {

    EDMVDenoiser(float min_sigma = 0.002, float max_sigma = 120.0)
        : min_sigma(min_sigma), max_sigma(max_sigma) {
        schedule = std::make_shared<ExponentialSchedule>();
        scheduler = std::make_shared<ExponentialSchedule>();
    }

    float t_to_sigma(float t) {
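The `schedule` member becomes `scheduler`, matching the renamed `--scheduler` CLI flag; swapping sigma schedules stays a one-line assignment. A minimal sketch (assuming an existing `denoiser` instance):

```
// Hedged sketch: switch the denoiser to the Karras sigma schedule,
// then derive the sigma sequence for a 20-step sampling run.
denoiser->scheduler = std::make_shared<KarrasSchedule>();
std::vector<float> sigmas = denoiser->get_sigmas(20);
```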
@@ -4,8 +4,10 @@
#include "flux.hpp"
#include "mmdit.hpp"
#include "unet.hpp"
#include "wan.hpp"

struct DiffusionModel {
    virtual std::string get_desc() = 0;
    virtual void compute(int n_threads,
                         struct ggml_tensor* x,
                         struct ggml_tensor* timesteps,

@@ -32,10 +34,15 @@ struct UNetModel : public DiffusionModel {
    UNetModelRunner unet;

    UNetModel(ggml_backend_t backend,
              bool offload_params_to_cpu,
              const String2GGMLType& tensor_types = {},
              SDVersion version = VERSION_SD1,
              bool flash_attn = false)
        : unet(backend, tensor_types, "model.diffusion_model", version, flash_attn) {
        : unet(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, flash_attn) {
    }

    std::string get_desc() {
        return unet.get_desc();
    }

    void alloc_params_buffer() {

@@ -85,8 +92,13 @@ struct MMDiTModel : public DiffusionModel {
    MMDiTRunner mmdit;

    MMDiTModel(ggml_backend_t backend,
               bool offload_params_to_cpu,
               const String2GGMLType& tensor_types = {})
        : mmdit(backend, tensor_types, "model.diffusion_model") {
        : mmdit(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model") {
    }

    std::string get_desc() {
        return mmdit.get_desc();
    }

    void alloc_params_buffer() {

@@ -135,11 +147,16 @@ struct FluxModel : public DiffusionModel {
    Flux::FluxRunner flux;

    FluxModel(ggml_backend_t backend,
              bool offload_params_to_cpu,
              const String2GGMLType& tensor_types = {},
              SDVersion version = VERSION_FLUX,
              bool flash_attn = false,
              bool use_mask = false)
        : flux(backend, tensor_types, "model.diffusion_model", version, flash_attn, use_mask) {
        : flux(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, flash_attn, use_mask) {
    }

    std::string get_desc() {
        return flux.get_desc();
    }

    void alloc_params_buffer() {

@@ -184,4 +201,63 @@ struct FluxModel : public DiffusionModel {
    }
};

struct WanModel : public DiffusionModel {
    std::string prefix;
    WAN::WanRunner wan;

    WanModel(ggml_backend_t backend,
             bool offload_params_to_cpu,
             const String2GGMLType& tensor_types = {},
             const std::string prefix = "model.diffusion_model",
             SDVersion version = VERSION_WAN2,
             bool flash_attn = false)
        : prefix(prefix), wan(backend, offload_params_to_cpu, tensor_types, prefix, version, flash_attn) {
    }

    std::string get_desc() {
        return wan.get_desc();
    }

    void alloc_params_buffer() {
        wan.alloc_params_buffer();
    }

    void free_params_buffer() {
        wan.free_params_buffer();
    }

    void free_compute_buffer() {
        wan.free_compute_buffer();
    }

    void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
        wan.get_param_tensors(tensors, prefix);
    }

    size_t get_params_buffer_size() {
        return wan.get_params_buffer_size();
    }

    int64_t get_adm_in_channels() {
        return 768;
    }

    void compute(int n_threads,
                 struct ggml_tensor* x,
                 struct ggml_tensor* timesteps,
                 struct ggml_tensor* context,
                 struct ggml_tensor* c_concat,
                 struct ggml_tensor* y,
                 struct ggml_tensor* guidance,
                 std::vector<ggml_tensor*> ref_latents = {},
                 int num_video_frames = -1,
                 std::vector<struct ggml_tensor*> controls = {},
                 float control_strength = 0.f,
                 struct ggml_tensor** output = NULL,
                 struct ggml_context* output_ctx = NULL,
                 std::vector<int> skip_layers = std::vector<int>()) {
        return wan.compute(n_threads, x, timesteps, context, y, c_concat, NULL, output, output_ctx);
    }
};

#endif
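`WanModel` is a thin adapter around `WAN::WanRunner`; the `prefix` parameter is what allows a second, high-noise expert to be instantiated next to the low-noise one for Wan2.2 A14B. A hypothetical two-expert setup — the high-noise prefix string below is an illustration only, not taken from this diff:

```
// Hedged sketch: low- and high-noise Wan2.2 experts side by side.
// "model.high_noise_diffusion_model" is a hypothetical prefix.
auto low_noise  = std::make_shared<WanModel>(backend, /*offload_params_to_cpu=*/true,
                                             tensor_types, "model.diffusion_model");
auto high_noise = std::make_shared<WanModel>(backend, /*offload_params_to_cpu=*/true,
                                             tensor_types, "model.high_noise_diffusion_model");
```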
docs/wan.md — 141 lines (new file)

@@ -0,0 +1,141 @@
# How to Use

## Download weights

- Download Wan
  - Wan2.1
    - Wan2.1 T2V 1.3B
      - safetensors: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/diffusion_models
    - Wan2.1 T2V 14B
      - safetensors: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/diffusion_models
      - gguf: https://huggingface.co/city96/Wan2.1-T2V-14B-gguf/tree/main
    - Wan2.1 I2V 14B 480P
      - safetensors: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/diffusion_models
      - gguf: https://huggingface.co/city96/Wan2.1-I2V-14B-480P-gguf/tree/main
    - Wan2.1 I2V 14B 720P
      - safetensors: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/diffusion_models
      - gguf: https://huggingface.co/city96/Wan2.1-I2V-14B-720P-gguf/tree/main
    - Wan2.1 FLF2V 14B 720P
      - safetensors: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/diffusion_models
      - gguf: https://huggingface.co/city96/Wan2.1-FLF2V-14B-720P-gguf/tree/main
  - Wan2.2
    - Wan2.2 TI2V 5B
      - safetensors: https://huggingface.co/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/tree/main/split_files/diffusion_models
      - gguf: https://huggingface.co/QuantStack/Wan2.2-TI2V-5B-GGUF/tree/main
    - Wan2.2 T2V A14B
      - safetensors: https://huggingface.co/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/tree/main/split_files/diffusion_models
      - gguf: https://huggingface.co/QuantStack/Wan2.2-T2V-A14B-GGUF/tree/main
    - Wan2.2 I2V A14B
      - safetensors: https://huggingface.co/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/tree/main/split_files/diffusion_models
      - gguf: https://huggingface.co/QuantStack/Wan2.2-I2V-A14B-GGUF/tree/main
- Download vae
  - wan_2.1_vae (for all Wan models except Wan2.2 TI2V 5B)
    - safetensors: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors
  - wan_2.2_vae (for Wan2.2 TI2V 5B only)
    - safetensors: https://huggingface.co/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/blob/main/split_files/vae/wan2.2_vae.safetensors
- Download umt5_xxl
  - safetensors: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/text_encoders/umt5_xxl_fp16.safetensors
  - gguf: https://huggingface.co/city96/umt5-xxl-encoder-gguf/tree/main
- Download clip_vision_h (for Wan2.1 I2V/FLF2V only)
  - safetensors: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/clip_vision/clip_vision_h.safetensors

## Examples

Since GitHub does not support AVI files, the videos below were converted from AVI to MP4.

### Wan2.1 T2V 1.3B

```
.\bin\Release\sd.exe -M vid_gen --diffusion-model ..\..\ComfyUI\models\diffusion_models\wan2.1_t2v_1.3B_fp16.safetensors --vae ..\..\ComfyUI\models\vae\wan_2.1_vae.safetensors --t5xxl ..\..\ComfyUI\models\text_encoders\umt5-xxl-encoder-Q8_0.gguf -p "a lovely cat" --cfg-scale 6.0 --sampling-method euler -v -n "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" -W 832 -H 480 --diffusion-fa --video-frames 33
```

<video src=../assets/wan/Wan2.1_1.3B_t2v.mp4 controls="controls" muted="muted" type="video/mp4"></video>

### Wan2.1 T2V 14B

```
.\bin\Release\sd.exe -M vid_gen --diffusion-model ..\..\ComfyUI\models\diffusion_models\wan2.1-t2v-14b-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\wan_2.1_vae.safetensors --t5xxl ..\..\ComfyUI\models\text_encoders\umt5-xxl-encoder-Q8_0.gguf -p "a lovely cat" --cfg-scale 6.0 --sampling-method euler -v -n "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" -W 832 -H 480 --diffusion-fa --offload-to-cpu --video-frames 33
```

<video src=../assets/wan/Wan2.1_14B_t2v.mp4 controls="controls" muted="muted" type="video/mp4"></video>

### Wan2.1 I2V 14B

```
.\bin\Release\sd.exe -M vid_gen --diffusion-model ..\..\ComfyUI\models\diffusion_models\wan2.1-i2v-14b-480p-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\wan_2.1_vae.safetensors --t5xxl ..\..\ComfyUI\models\text_encoders\umt5-xxl-encoder-Q8_0.gguf --clip_vision ..\..\ComfyUI\models\clip_vision\clip_vision_h.safetensors -p "a lovely cat" --cfg-scale 6.0 --sampling-method euler -v -n "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" -W 480 -H 832 --diffusion-fa --video-frames 33 --offload-to-cpu -i ..\assets\cat_with_sd_cpp_42.png
```

<video src=../assets/wan/Wan2.1_14B_i2v.mp4 controls="controls" muted="muted" type="video/mp4"></video>

### Wan2.2 T2V A14B

```
.\bin\Release\sd.exe -M vid_gen --diffusion-model ..\..\ComfyUI\models\diffusion_models\Wan2.2-T2V-A14B-LowNoise-Q8_0.gguf --high-noise-diffusion-model ..\..\ComfyUI\models\diffusion_models\Wan2.2-T2V-A14B-HighNoise-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\wan_2.1_vae.safetensors --t5xxl ..\..\ComfyUI\models\text_encoders\umt5-xxl-encoder-Q8_0.gguf -p "a lovely cat" --cfg-scale 3.5 --sampling-method euler --steps 10 --high-noise-cfg-scale 3.5 --high-noise-sampling-method euler --high-noise-steps 8 -v -n "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" -W 832 -H 480 --diffusion-fa --offload-to-cpu --video-frames 33
```

<video src=../assets/wan/Wan2.2_14B_t2v.mp4 controls="controls" muted="muted" type="video/mp4"></video>

### Wan2.2 I2V A14B

```
.\bin\Release\sd.exe -M vid_gen --diffusion-model ..\..\ComfyUI\models\diffusion_models\Wan2.2-I2V-A14B-LowNoise-Q8_0.gguf --high-noise-diffusion-model ..\..\ComfyUI\models\diffusion_models\Wan2.2-I2V-A14B-HighNoise-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\wan_2.1_vae.safetensors --t5xxl ..\..\ComfyUI\models\text_encoders\umt5-xxl-encoder-Q8_0.gguf -p "a lovely cat" --cfg-scale 3.5 --sampling-method euler --steps 10 --high-noise-cfg-scale 3.5 --high-noise-sampling-method euler --high-noise-steps 8 -v -n "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" -W 832 -H 480 --diffusion-fa --offload-to-cpu --video-frames 33 -i ..\assets\cat_with_sd_cpp_42.png
```

<video src=../assets/wan/Wan2.2_14B_i2v.mp4 controls="controls" muted="muted" type="video/mp4"></video>

### Wan2.2 T2V A14B T2I

```
.\bin\Release\sd.exe -M vid_gen --diffusion-model ..\..\ComfyUI\models\diffusion_models\Wan2.2-T2V-A14B-LowNoise-Q8_0.gguf --high-noise-diffusion-model ..\..\ComfyUI\models\diffusion_models\Wan2.2-T2V-A14B-HighNoise-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\wan_2.1_vae.safetensors --t5xxl ..\..\ComfyUI\models\text_encoders\umt5-xxl-encoder-Q8_0.gguf -p "a lovely cat" --cfg-scale 3.5 --sampling-method euler --steps 10 --high-noise-cfg-scale 3.5 --high-noise-sampling-method euler --high-noise-steps 8 -v -n "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" -W 832 -H 480 --diffusion-fa --offload-to-cpu
```

<img width="832" height="480" alt="Wan2 2_14B_t2i" src="../assets/wan/Wan2.2_14B_t2i.png" />

### Wan2.2 T2V 14B with LoRA

```
.\bin\Release\sd.exe -M vid_gen --diffusion-model ..\..\ComfyUI\models\diffusion_models\Wan2.2-T2V-A14B-LowNoise-Q8_0.gguf --high-noise-diffusion-model ..\..\ComfyUI\models\diffusion_models\Wan2.2-T2V-A14B-HighNoise-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\wan_2.1_vae.safetensors --t5xxl ..\..\ComfyUI\models\text_encoders\umt5-xxl-encoder-Q8_0.gguf -p "a lovely cat<lora:wan2.2_t2v_lightx2v_4steps_lora_v1.1_low_noise:1><lora:|high_noise|wan2.2_t2v_lightx2v_4steps_lora_v1.1_high_noise:1>" --cfg-scale 3.5 --sampling-method euler --steps 4 --high-noise-cfg-scale 3.5 --high-noise-sampling-method euler --high-noise-steps 4 -v -n "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" -W 832 -H 480 --diffusion-fa --offload-to-cpu --lora-model-dir ..\..\ComfyUI\models\loras --video-frames 33
```

<video src=../assets/wan/Wan2.2_14B_t2v_lora.mp4 controls="controls" muted="muted" type="video/mp4"></video>

### Wan2.2 TI2V 5B

#### T2V

```
.\bin\Release\sd.exe -M vid_gen --diffusion-model ..\..\ComfyUI\models\diffusion_models\wan2.2_ti2v_5B_fp16.safetensors --vae ..\..\ComfyUI\models\vae\wan2.2_vae.safetensors --t5xxl ..\..\ComfyUI\models\text_encoders\umt5-xxl-encoder-Q8_0.gguf -p "a lovely cat" --cfg-scale 6.0 --sampling-method euler -v -n "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" -W 480 -H 832 --diffusion-fa --offload-to-cpu --video-frames 33
```

<video src=../assets/wan/Wan2.2_5B_t2v.mp4 controls="controls" muted="muted" type="video/mp4"></video>

#### I2V

```
.\bin\Release\sd.exe -M vid_gen --diffusion-model ..\..\ComfyUI\models\diffusion_models\wan2.2_ti2v_5B_fp16.safetensors --vae ..\..\ComfyUI\models\vae\wan2.2_vae.safetensors --t5xxl ..\..\ComfyUI\models\text_encoders\umt5-xxl-encoder-Q8_0.gguf -p "a lovely cat" --cfg-scale 6.0 --sampling-method euler -v -n "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" -W 480 -H 832 --diffusion-fa --offload-to-cpu --video-frames 33 -i ..\assets\cat_with_sd_cpp_42.png
```

<video src=../assets/wan/Wan2.2_5B_i2v.mp4 controls="controls" muted="muted" type="video/mp4"></video>

### Wan2.1 FLF2V 14B

```
.\bin\Release\sd.exe -M vid_gen --diffusion-model ..\..\ComfyUI\models\diffusion_models\wan2.1-flf2v-14b-720p-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\wan_2.1_vae.safetensors --t5xxl ..\..\ComfyUI\models\text_encoders\umt5-xxl-encoder-Q8_0.gguf --clip_vision ..\..\ComfyUI\models\clip_vision\clip_vision_h.safetensors -p "glass flower blossom" --cfg-scale 6.0 --sampling-method euler -v -n "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" -W 480 -H 832 --diffusion-fa --video-frames 33 --offload-to-cpu --init-img ..\..\ComfyUI\input\start_image.png --end-img ..\..\ComfyUI\input\end_image.png
```

<video src=../assets/wan/Wan2.1_14B_flf2v.mp4 controls="controls" muted="muted" type="video/mp4"></video>

### Wan2.2 FLF2V 14B

```
.\bin\Release\sd.exe -M vid_gen --diffusion-model ..\..\ComfyUI\models\diffusion_models\Wan2.2-I2V-A14B-LowNoise-Q8_0.gguf --high-noise-diffusion-model ..\..\ComfyUI\models\diffusion_models\Wan2.2-I2V-A14B-HighNoise-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\wan_2.1_vae.safetensors --t5xxl ..\..\ComfyUI\models\text_encoders\umt5-xxl-encoder-Q8_0.gguf --cfg-scale 3.5 --sampling-method euler --steps 10 --high-noise-cfg-scale 3.5 --high-noise-sampling-method euler --high-noise-steps 8 -v -p "glass flower blossom" -n "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" -W 480 -H 832 --diffusion-fa --video-frames 33 --offload-to-cpu --init-img ..\..\ComfyUI\input\start_image.png --end-img ..\..\ComfyUI\input\end_image.png
```

<video src=../assets/wan/Wan2.2_14B_flf2v.mp4 controls="controls" muted="muted" type="video/mp4"></video>
@@ -142,8 +142,10 @@ struct ESRGAN : public GGMLRunner {
    int scale     = 4;
    int tile_size = 128;  // avoid cuda OOM for 4gb VRAM

    ESRGAN(ggml_backend_t backend, const String2GGMLType& tensor_types = {})
        : GGMLRunner(backend) {
    ESRGAN(ggml_backend_t backend,
           bool offload_params_to_cpu,
           const String2GGMLType& tensor_types = {})
        : GGMLRunner(backend, offload_params_to_cpu) {
        rrdb_net.init(params_ctx, tensor_types, "");
    }

@@ -175,7 +177,7 @@ struct ESRGAN : public GGMLRunner {
            return false;
        }

        bool success = model_loader.load_tensors(esrgan_tensors, backend);
        bool success = model_loader.load_tensors(esrgan_tensors);

        if (!success) {
            LOG_ERROR("load esrgan tensors from model loader failed");
217
examples/cli/avi_writer.h
Normal file
217
examples/cli/avi_writer.h
Normal file
@ -0,0 +1,217 @@
|
||||
#ifndef __AVI_WRITER_H__
|
||||
#define __AVI_WRITER_H__
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "stable-diffusion.h"
|
||||
|
||||
#ifndef INCLUDE_STB_IMAGE_WRITE_H
|
||||
#include "stb_image_write.h"
|
||||
#endif
|
||||
|
||||
typedef struct {
|
||||
    uint32_t offset;
    uint32_t size;
} avi_index_entry;

// Write 32-bit little-endian integer
void write_u32_le(FILE* f, uint32_t val) {
    fwrite(&val, 4, 1, f);
}

// Write 16-bit little-endian integer
void write_u16_le(FILE* f, uint16_t val) {
    fwrite(&val, 2, 1, f);
}

/**
 * Create an MJPG AVI file from an array of sd_image_t images.
 * Images are encoded to JPEG using stb_image_write.
 *
 * @param filename   Output AVI file name.
 * @param images     Array of input images.
 * @param num_images Number of images in the array.
 * @param fps        Frames per second for the video.
 * @param quality    JPEG quality (0-100).
 * @return 0 on success, -1 on failure.
 */
int create_mjpg_avi_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality = 90) {
    if (num_images == 0) {
        fprintf(stderr, "Error: Image array is empty.\n");
        return -1;
    }

    FILE* f = fopen(filename, "wb");
    if (!f) {
        perror("Error opening file for writing");
        return -1;
    }

    uint32_t width    = images[0].width;
    uint32_t height   = images[0].height;
    uint32_t channels = images[0].channel;
    if (channels != 3 && channels != 4) {
        fprintf(stderr, "Error: Unsupported channel count: %u\n", channels);
        fclose(f);
        return -1;
    }

    // --- RIFF AVI Header ---
    fwrite("RIFF", 4, 1, f);
    long riff_size_pos = ftell(f);
    write_u32_le(f, 0); // Placeholder for file size
    fwrite("AVI ", 4, 1, f);

    // 'hdrl' LIST (header list)
    fwrite("LIST", 4, 1, f);
    write_u32_le(f, 4 + 8 + 56 + 8 + 4 + 8 + 56 + 8 + 40);
    fwrite("hdrl", 4, 1, f);

    // 'avih' chunk (AVI main header)
    fwrite("avih", 4, 1, f);
    write_u32_le(f, 56);
    write_u32_le(f, 1000000 / fps);      // Microseconds per frame
    write_u32_le(f, 0);                  // Max bytes per second
    write_u32_le(f, 0);                  // Padding granularity
    write_u32_le(f, 0x110);              // Flags (HASINDEX | ISINTERLEAVED)
    write_u32_le(f, num_images);         // Total frames
    write_u32_le(f, 0);                  // Initial frames
    write_u32_le(f, 1);                  // Number of streams
    write_u32_le(f, width * height * 3); // Suggested buffer size
    write_u32_le(f, width);
    write_u32_le(f, height);
    write_u32_le(f, 0); // Reserved
    write_u32_le(f, 0); // Reserved
    write_u32_le(f, 0); // Reserved
    write_u32_le(f, 0); // Reserved

    // 'strl' LIST (stream list)
    fwrite("LIST", 4, 1, f);
    write_u32_le(f, 4 + 8 + 56 + 8 + 40);
    fwrite("strl", 4, 1, f);

    // 'strh' chunk (stream header)
    fwrite("strh", 4, 1, f);
    write_u32_le(f, 56);
    fwrite("vids", 4, 1, f);             // Stream type: video
    fwrite("MJPG", 4, 1, f);             // Codec: Motion JPEG
    write_u32_le(f, 0);                  // Flags
    write_u16_le(f, 0);                  // Priority
    write_u16_le(f, 0);                  // Language
    write_u32_le(f, 0);                  // Initial frames
    write_u32_le(f, 1);                  // Scale
    write_u32_le(f, fps);                // Rate
    write_u32_le(f, 0);                  // Start
    write_u32_le(f, num_images);         // Length
    write_u32_le(f, width * height * 3); // Suggested buffer size
    write_u32_le(f, (uint32_t)-1);       // Quality
    write_u32_le(f, 0);                  // Sample size
    write_u16_le(f, 0);                  // rcFrame.left
    write_u16_le(f, 0);                  // rcFrame.top
    write_u16_le(f, 0);                  // rcFrame.right
    write_u16_le(f, 0);                  // rcFrame.bottom

    // 'strf' chunk (stream format: BITMAPINFOHEADER)
    fwrite("strf", 4, 1, f);
    write_u32_le(f, 40);
    write_u32_le(f, 40); // biSize
    write_u32_le(f, width);
    write_u32_le(f, height);
    write_u16_le(f, 1);                  // biPlanes
    write_u16_le(f, 24);                 // biBitCount
    fwrite("MJPG", 4, 1, f);             // biCompression (FOURCC)
    write_u32_le(f, width * height * 3); // biSizeImage
    write_u32_le(f, 0);                  // XPelsPerMeter
    write_u32_le(f, 0);                  // YPelsPerMeter
    write_u32_le(f, 0);                  // Colors used
    write_u32_le(f, 0);                  // Colors important

    // 'movi' LIST (video frames)
    long movi_list_pos = ftell(f);
    fwrite("LIST", 4, 1, f);
    long movi_size_pos = ftell(f);
    write_u32_le(f, 0); // Placeholder for movi size
    fwrite("movi", 4, 1, f);

    avi_index_entry* index = (avi_index_entry*)malloc(sizeof(avi_index_entry) * num_images);
    if (!index) {
        fclose(f);
        return -1;
    }

    // Encode and write each frame as JPEG
    struct {
        uint8_t* buf;
        size_t size;
    } jpeg_data;

    for (int i = 0; i < num_images; i++) {
        jpeg_data.buf  = NULL;
        jpeg_data.size = 0;

        // Callback function to collect JPEG data into memory
        auto write_to_buf = [](void* context, void* data, int size) {
            auto jd  = (decltype(jpeg_data)*)context;
            jd->buf  = (uint8_t*)realloc(jd->buf, jd->size + size);
            memcpy(jd->buf + jd->size, data, size);
            jd->size += size;
        };

        // Encode to JPEG in memory
        stbi_write_jpg_to_func(
            write_to_buf,
            &jpeg_data,
            images[i].width,
            images[i].height,
            channels,
            images[i].data,
            quality);

        // Write '00dc' chunk (video frame)
        fwrite("00dc", 4, 1, f);
        write_u32_le(f, jpeg_data.size);
        index[i].offset = ftell(f) - 8;
        index[i].size   = jpeg_data.size;
        fwrite(jpeg_data.buf, 1, jpeg_data.size, f);

        // Align to even byte size
        if (jpeg_data.size % 2)
            fputc(0, f);

        free(jpeg_data.buf);
    }

    // Finalize 'movi' size
    long cur_pos   = ftell(f);
    long movi_size = cur_pos - movi_size_pos - 4;
    fseek(f, movi_size_pos, SEEK_SET);
    write_u32_le(f, movi_size);
    fseek(f, cur_pos, SEEK_SET);

    // Write 'idx1' index
    fwrite("idx1", 4, 1, f);
    write_u32_le(f, num_images * 16);
    for (int i = 0; i < num_images; i++) {
        fwrite("00dc", 4, 1, f);
        write_u32_le(f, 0x10);
        write_u32_le(f, index[i].offset);
        write_u32_le(f, index[i].size);
    }

    // Finalize RIFF size
    cur_pos        = ftell(f);
    long file_size = cur_pos - riff_size_pos - 4;
    fseek(f, riff_size_pos, SEEK_SET);
    write_u32_le(f, file_size);
    fseek(f, cur_pos, SEEK_SET);

    fclose(f);
    free(index);

    return 0;
}

#endif // __AVI_WRITER_H__
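A minimal usage sketch for the writer above (illustrative, not part of the commit; it assumes sd_image_t is the {width, height, channel, data} struct from stable-diffusion.h):

    // Write 16 solid grayscale 256x256 frames to out.avi at the project's default 16 fps.
    std::vector<std::vector<uint8_t>> pixels;
    std::vector<sd_image_t> frames;
    for (int i = 0; i < 16; i++) {
        pixels.emplace_back(256 * 256 * 3, (uint8_t)(i * 16)); // one gray level per RGB frame
        frames.push_back(sd_image_t{256, 256, 3, pixels.back().data()});
    }
    if (create_mjpg_avi_from_sd_images("out.avi", frames.data(), (int)frames.size(), 16) != 0) {
        fprintf(stderr, "failed to write out.avi\n");
    }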
File diff suppressed because it is too large

187
flux.hpp
@@ -5,6 +5,7 @@

#include "ggml_extend.hpp"
#include "model.h"
#include "rope.hpp"

#define FLUX_GRAPH_SIZE 10240

@@ -610,179 +611,11 @@ namespace Flux {
    };

    struct Flux : public GGMLBlock {
    public:
        std::vector<float> linspace(float start, float end, int num) {
            std::vector<float> result(num);
            float step = (end - start) / (num - 1);
            for (int i = 0; i < num; ++i) {
                result[i] = start + i * step;
            }
            return result;
        }

        std::vector<std::vector<float>> transpose(const std::vector<std::vector<float>>& mat) {
            int rows = mat.size();
            int cols = mat[0].size();
            std::vector<std::vector<float>> transposed(cols, std::vector<float>(rows));
            for (int i = 0; i < rows; ++i) {
                for (int j = 0; j < cols; ++j) {
                    transposed[j][i] = mat[i][j];
                }
            }
            return transposed;
        }

        std::vector<float> flatten(const std::vector<std::vector<float>>& vec) {
            std::vector<float> flat_vec;
            for (const auto& sub_vec : vec) {
                flat_vec.insert(flat_vec.end(), sub_vec.begin(), sub_vec.end());
            }
            return flat_vec;
        }

        std::vector<std::vector<float>> rope(const std::vector<float>& pos, int dim, int theta) {
            assert(dim % 2 == 0);
            int half_dim = dim / 2;

            std::vector<float> scale = linspace(0, (dim * 1.0f - 2) / dim, half_dim);

            std::vector<float> omega(half_dim);
            for (int i = 0; i < half_dim; ++i) {
                omega[i] = 1.0 / std::pow(theta, scale[i]);
            }

            int pos_size = pos.size();
            std::vector<std::vector<float>> out(pos_size, std::vector<float>(half_dim));
            for (int i = 0; i < pos_size; ++i) {
                for (int j = 0; j < half_dim; ++j) {
                    out[i][j] = pos[i] * omega[j];
                }
            }

            std::vector<std::vector<float>> result(pos_size, std::vector<float>(half_dim * 4));
            for (int i = 0; i < pos_size; ++i) {
                for (int j = 0; j < half_dim; ++j) {
                    result[i][4 * j]     = std::cos(out[i][j]);
                    result[i][4 * j + 1] = -std::sin(out[i][j]);
                    result[i][4 * j + 2] = std::sin(out[i][j]);
                    result[i][4 * j + 3] = std::cos(out[i][j]);
                }
            }

            return result;
        }

        // Generate IDs for image patches and text
        std::vector<std::vector<float>> gen_txt_ids(int bs, int context_len) {
            return std::vector<std::vector<float>>(bs * context_len, std::vector<float>(3, 0.0));
        }

        std::vector<std::vector<float>> gen_img_ids(int h, int w, int patch_size, int bs, int index = 0, int h_offset = 0, int w_offset = 0) {
            int h_len = (h + (patch_size / 2)) / patch_size;
            int w_len = (w + (patch_size / 2)) / patch_size;

            std::vector<std::vector<float>> img_ids(h_len * w_len, std::vector<float>(3, 0.0));

            std::vector<float> row_ids = linspace(h_offset, h_len - 1 + h_offset, h_len);
            std::vector<float> col_ids = linspace(w_offset, w_len - 1 + w_offset, w_len);

            for (int i = 0; i < h_len; ++i) {
                for (int j = 0; j < w_len; ++j) {
                    img_ids[i * w_len + j][0] = index;
                    img_ids[i * w_len + j][1] = row_ids[i];
                    img_ids[i * w_len + j][2] = col_ids[j];
                }
            }

            std::vector<std::vector<float>> img_ids_repeated(bs * img_ids.size(), std::vector<float>(3));
            for (int i = 0; i < bs; ++i) {
                for (int j = 0; j < img_ids.size(); ++j) {
                    img_ids_repeated[i * img_ids.size() + j] = img_ids[j];
                }
            }
            return img_ids_repeated;
        }

        std::vector<std::vector<float>> concat_ids(const std::vector<std::vector<float>>& a,
                                                   const std::vector<std::vector<float>>& b,
                                                   int bs) {
            size_t a_len = a.size() / bs;
            size_t b_len = b.size() / bs;
            std::vector<std::vector<float>> ids(a.size() + b.size(), std::vector<float>(3));
            for (int i = 0; i < bs; ++i) {
                for (int j = 0; j < a_len; ++j) {
                    ids[i * (a_len + b_len) + j] = a[i * a_len + j];
                }
                for (int j = 0; j < b_len; ++j) {
                    ids[i * (a_len + b_len) + a_len + j] = b[i * b_len + j];
                }
            }
            return ids;
        }

        std::vector<std::vector<float>> gen_ids(int h, int w, int patch_size, int bs, int context_len, std::vector<ggml_tensor*> ref_latents) {
            auto txt_ids = gen_txt_ids(bs, context_len);
            auto img_ids = gen_img_ids(h, w, patch_size, bs);

            auto ids               = concat_ids(txt_ids, img_ids, bs);
            uint64_t curr_h_offset = 0;
            uint64_t curr_w_offset = 0;
            for (ggml_tensor* ref : ref_latents) {
                uint64_t h_offset = 0;
                uint64_t w_offset = 0;
                if (ref->ne[1] + curr_h_offset > ref->ne[0] + curr_w_offset) {
                    w_offset = curr_w_offset;
                } else {
                    h_offset = curr_h_offset;
                }

                auto ref_ids = gen_img_ids(ref->ne[1], ref->ne[0], patch_size, bs, 1, h_offset, w_offset);
                ids          = concat_ids(ids, ref_ids, bs);

                curr_h_offset = std::max(curr_h_offset, ref->ne[1] + h_offset);
                curr_w_offset = std::max(curr_w_offset, ref->ne[0] + w_offset);
            }
            return ids;
        }

        // Generate positional embeddings
        std::vector<float> gen_pe(int h, int w, int patch_size, int bs, int context_len, std::vector<ggml_tensor*> ref_latents, int theta, const std::vector<int>& axes_dim) {
            std::vector<std::vector<float>> ids       = gen_ids(h, w, patch_size, bs, context_len, ref_latents);
            std::vector<std::vector<float>> trans_ids = transpose(ids);
            size_t pos_len                            = ids.size();
            int num_axes                              = axes_dim.size();
            for (int i = 0; i < pos_len; i++) {
                // std::cout << trans_ids[0][i] << " " << trans_ids[1][i] << " " << trans_ids[2][i] << std::endl;
            }

            int emb_dim = 0;
            for (int d : axes_dim)
                emb_dim += d / 2;

            std::vector<std::vector<float>> emb(bs * pos_len, std::vector<float>(emb_dim * 2 * 2, 0.0));
            int offset = 0;
            for (int i = 0; i < num_axes; ++i) {
                std::vector<std::vector<float>> rope_emb = rope(trans_ids[i], axes_dim[i], theta); // [bs*pos_len, axes_dim[i]/2 * 2 * 2]
                for (int b = 0; b < bs; ++b) {
                    for (int j = 0; j < pos_len; ++j) {
                        for (int k = 0; k < rope_emb[0].size(); ++k) {
                            emb[b * pos_len + j][offset + k] = rope_emb[j][k];
                        }
                    }
                }
                offset += rope_emb[0].size();
            }

            return flatten(emb);
        }

    public:
        FluxParams params;
        Flux() {}
        Flux(FluxParams params)
            : params(params) {
            int64_t pe_dim = params.hidden_size / params.num_heads;

            blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true));
            if (params.is_chroma) {
                blocks["distilled_guidance_layer"] = std::shared_ptr<GGMLBlock>(new ChromaApproximator(params.in_channels, params.hidden_size));

@@ -1048,12 +881,13 @@ namespace Flux {
        bool use_mask = false;

        FluxRunner(ggml_backend_t backend,
                   bool offload_params_to_cpu,
                   const String2GGMLType& tensor_types = {},
                   const std::string prefix            = "",
                   SDVersion version                   = VERSION_FLUX,
                   bool flash_attn                     = false,
                   bool use_mask                       = false)
            : GGMLRunner(backend), use_mask(use_mask) {
            : GGMLRunner(backend, offload_params_to_cpu), use_mask(use_mask) {
            flux_params.flash_attn     = flash_attn;
            flux_params.guidance_embed = false;
            flux_params.depth          = 0;
@@ -1063,7 +897,7 @@
            }
            for (auto pair : tensor_types) {
                std::string tensor_name = pair.first;
                if (tensor_name.find("model.diffusion_model.") == std::string::npos)
                if (!starts_with(tensor_name, prefix))
                    continue;
                if (tensor_name.find("guidance_in.in_layer.weight") != std::string::npos) {
                    // not schnell
@@ -1150,7 +984,14 @@
                ref_latents[i] = to_backend(ref_latents[i]);
            }

            pe_vec = flux.gen_pe(x->ne[1], x->ne[0], 2, x->ne[3], context->ne[1], ref_latents, flux_params.theta, flux_params.axes_dim);
            pe_vec = Rope::gen_flux_pe(x->ne[1],
                                       x->ne[0],
                                       2,
                                       x->ne[3],
                                       context->ne[1],
                                       ref_latents,
                                       flux_params.theta,
                                       flux_params.axes_dim);
            int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
            // LOG_DEBUG("pos_len %d", pos_len);
            auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len);
@@ -1245,7 +1086,7 @@
        // ggml_backend_t backend = ggml_backend_cuda_init(0);
        ggml_backend_t backend    = ggml_backend_cpu_init();
        ggml_type model_data_type = GGML_TYPE_Q8_0;
        std::shared_ptr<FluxRunner> flux = std::shared_ptr<FluxRunner>(new FluxRunner(backend));
        std::shared_ptr<FluxRunner> flux = std::shared_ptr<FluxRunner>(new FluxRunner(backend, false));
        {
            LOG_INFO("loading from '%s'", file_path.c_str());

@@ -1259,7 +1100,7 @@
            return;
        }

        bool success = model_loader.load_tensors(tensors, backend);
        bool success = model_loader.load_tensors(tensors);

        if (!success) {
            LOG_ERROR("load tensors from model loader failed");
@@ -1,2 +1,5 @@
clang-format -style=file -i *.cpp *.h *.hpp
clang-format -style=file -i examples/cli/*.cpp
for f in *.cpp *.h *.hpp examples/cli/*.cpp examples/cli/*.h; do
    [[ "$f" == vocab* ]] && continue
    echo "formatting '$f'"
    clang-format -style=file -i "$f"
done
2
ggml
@@ -1 +1 @@
Subproject commit 7dee1d6a1e7611f238d09be96738388da97c88ed
Subproject commit 5fdc78fff274094e2a1b155928131983362d8a71
652
ggml_extend.hpp
@@ -212,7 +212,7 @@ __STATIC_INLINE__ void print_ggml_tensor(struct ggml_tensor* tensor, bool shape_
    if (tensor->type == GGML_TYPE_F32) {
        printf("  [%d, %d, %d, %d] = %f\n", i, j, k, l, ggml_tensor_get_f32(tensor, l, k, j, i));
    } else if (tensor->type == GGML_TYPE_F16) {
        printf("  [%d, %d, %d, %d] = %i\n", i, j, k, l, ggml_tensor_get_f16(tensor, l, k, j, i));
        printf("  [%d, %d, %d, %d] = %f\n", i, j, k, l, ggml_fp16_to_fp32(ggml_tensor_get_f16(tensor, l, k, j, i)));
    } else if (tensor->type == GGML_TYPE_I32) {
        printf("  [%d, %d, %d, %d] = %i\n", i, j, k, l, ggml_tensor_get_i32(tensor, l, k, j, i));
    }
@@ -237,6 +237,8 @@ __STATIC_INLINE__ ggml_tensor* load_tensor_from_file(ggml_context* ctx, const st
    file.read(reinterpret_cast<char*>(&length), sizeof(length));
    file.read(reinterpret_cast<char*>(&ttype), sizeof(ttype));

    LOG_DEBUG("load_tensor_from_file %d %d %d", n_dims, length, ttype);

    if (file.eof()) {
        LOG_ERROR("incomplete file '%s'", file_path.c_str());
        return NULL;
@@ -325,17 +327,27 @@ __STATIC_INLINE__ uint8_t* sd_tensor_to_image(struct ggml_tensor* input) {
    return image_data;
}

__STATIC_INLINE__ uint8_t* sd_tensor_to_mul_image(struct ggml_tensor* input, int idx) {
__STATIC_INLINE__ uint8_t* sd_tensor_to_image(struct ggml_tensor* input, int idx, bool video = false) {
    int64_t width  = input->ne[0];
    int64_t height = input->ne[1];
    int64_t channels = input->ne[2];
    int64_t channels;
    if (video) {
        channels = input->ne[3];
    } else {
        channels = input->ne[2];
    }
    GGML_ASSERT(channels == 3 && input->type == GGML_TYPE_F32);
    uint8_t* image_data = (uint8_t*)malloc(width * height * channels);
    for (int iy = 0; iy < height; iy++) {
        for (int ix = 0; ix < width; ix++) {
            for (int k = 0; k < channels; k++) {
                float value = ggml_tensor_get_f32(input, ix, iy, k, idx);
                *(image_data + iy * width * channels + ix * channels + k) = (uint8_t)(value * 255.0f);
    for (int ih = 0; ih < height; ih++) {
        for (int iw = 0; iw < width; iw++) {
            for (int ic = 0; ic < channels; ic++) {
                float value;
                if (video) {
                    value = ggml_tensor_get_f32(input, iw, ih, idx, ic);
                } else {
                    value = ggml_tensor_get_f32(input, iw, ih, ic, idx);
                }
                *(image_data + ih * width * channels + iw * channels + ic) = (uint8_t)(value * 255.0f);
            }
        }
    }
@@ -581,7 +593,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_tensor_concat(struct ggml_context* ct
}

// convert values from [0, 1] to [-1, 1]
__STATIC_INLINE__ void ggml_tensor_scale_input(struct ggml_tensor* src) {
__STATIC_INLINE__ void process_vae_input_tensor(struct ggml_tensor* src) {
    int64_t nelements = ggml_nelements(src);
    float* data       = (float*)src->data;
    for (int i = 0; i < nelements; i++) {
@@ -591,7 +603,7 @@ __STATIC_INLINE__ void ggml_tensor_scale_input(struct ggml_tensor* src) {
}

// convert values from [-1, 1] to [0, 1]
__STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) {
__STATIC_INLINE__ void process_vae_output_tensor(struct ggml_tensor* src) {
    int64_t nelements = ggml_nelements(src);
    float* data       = (float*)src->data;
    for (int i = 0; i < nelements; i++) {
@@ -600,6 +612,125 @@ __STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) {
    }
}

__STATIC_INLINE__ struct ggml_tensor* ggml_nn_cont(struct ggml_context* ctx,
                                                   struct ggml_tensor* x) {
    if (ggml_is_contiguous(x)) {
        return x;
    }
    return ggml_cont(ctx, x);
}

// torch like permute
__STATIC_INLINE__ struct ggml_tensor* ggml_torch_permute(struct ggml_context* ctx,
                                                         struct ggml_tensor* x,
                                                         int axis0,
                                                         int axis1,
                                                         int axis2,
                                                         int axis3) {
    int torch_axes[4] = {axis0, axis1, axis2, axis3};

    int ggml_axes[4] = {0};
    for (int i = 0; i < 4; ++i) {
        int found = 0;
        for (int j = 0; j < 4; ++j) {
            if (torch_axes[j] == i) {
                ggml_axes[i] = j;
                found        = 1;
                break;
            }
        }
        GGML_ASSERT(found && "Invalid permute input: must be a permutation of 0-3");
    }

    return ggml_permute(ctx, x, ggml_axes[0], ggml_axes[1], ggml_axes[2], ggml_axes[3]);
}
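A note on the convention implemented above (a reading of the code, not part of the commit): the axes arguments are torch-style "gather" indices over ggml dims, i.e. the result satisfies y->ne[i] == x->ne[axes[i]], while ggml_permute expects, for each source dim, its destination slot (a "scatter"); the helper therefore inverts the permutation before delegating. The two conventions only coincide when the permutation is its own inverse:

    // given some 4-D tensor x: swap dims 2 and 3 (self-inverse, so
    // ggml_permute would take the same arguments here)
    struct ggml_tensor* y = ggml_torch_permute(ctx, x, 0, 1, 3, 2);
    // for a cycle such as (1, 2, 0, 3) the inversion matters:
    // ggml_torch_permute(ctx, x, 1, 2, 0, 3) lowers to ggml_permute(ctx, x, 2, 0, 1, 3)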
__STATIC_INLINE__ struct ggml_tensor* ggml_slice(struct ggml_context* ctx,
                                                 struct ggml_tensor* x,
                                                 int64_t dim,
                                                 int64_t start,
                                                 int64_t end) {
    GGML_ASSERT(dim >= 0 && dim < 4);
    if (x->ne[dim] == 1) {
        return x;
    }
    while (start < 0) {
        start = x->ne[dim] + start;
    }
    while (end < 0) {
        end = x->ne[dim] + end;
    }
    GGML_ASSERT(end > start);
    GGML_ASSERT(start >= 0 && start < x->ne[dim]);
    GGML_ASSERT(end > start && end <= x->ne[dim]);

    int perm[4] = {0, 1, 2, 3};
    for (int i = dim; i < 3; ++i)
        perm[i] = perm[i + 1];
    perm[3] = dim;

    int inv_perm[4];
    for (int i = 0; i < 4; ++i)
        inv_perm[perm[i]] = i;

    if (dim != 3) {
        x = ggml_torch_permute(ctx, x, perm[0], perm[1], perm[2], perm[3]);
        x = ggml_cont(ctx, x);
    }

    x = ggml_view_4d(
        ctx, x,
        x->ne[0], x->ne[1], x->ne[2], end - start,
        x->nb[1], x->nb[2], x->nb[3], x->nb[3] * start);

    if (dim != 3) {
        x = ggml_torch_permute(ctx, x, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]);
        x = ggml_cont(ctx, x);
    }

    return x;
}
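A usage sketch for ggml_slice (hypothetical shapes, not part of the commit); start/end follow python slicing, with negative values wrapping once:

    // x->ne == {W, H, C, N}: drop the last row along dim 1
    struct ggml_tensor* body = ggml_slice(ctx, x, 1, 0, -1);        // rows [0, H-1)
    // keep only the final batch element
    struct ggml_tensor* last = ggml_slice(ctx, x, 3, -1, x->ne[3]); // [N-1, N)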
// example: [N, 3*C, H, W] => ([N, C, H, W], [N, C, H, W], [N, C, H, W])
__STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_chunk(struct ggml_context* ctx,
                                                              struct ggml_tensor* x,
                                                              int num,
                                                              int64_t dim) {
    GGML_ASSERT(dim >= 0 && dim < 4);
    GGML_ASSERT(x->ne[dim] % num == 0);

    int perm[4] = {0, 1, 2, 3};
    for (int i = dim; i < 3; ++i)
        perm[i] = perm[i + 1];
    perm[3] = dim;

    int inv_perm[4];
    for (int i = 0; i < 4; ++i)
        inv_perm[perm[i]] = i;

    if (dim != 3) {
        x = ggml_torch_permute(ctx, x, perm[0], perm[1], perm[2], perm[3]);
        x = ggml_cont(ctx, x);
    }

    std::vector<struct ggml_tensor*> chunks;
    int64_t chunk_size = x->ne[3] / num;
    for (int i = 0; i < num; i++) {
        auto chunk = ggml_view_4d(
            ctx, x,
            x->ne[0], x->ne[1], x->ne[2], chunk_size,
            x->nb[1], x->nb[2], x->nb[3], x->nb[3] * i * chunk_size);

        if (dim != 3) {
            chunk = ggml_torch_permute(ctx, chunk, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]);
            chunk = ggml_cont(ctx, chunk);
        }
        chunks.push_back(chunk);
    }

    return chunks;
}
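And a matching sketch for ggml_chunk (hypothetical shapes): splitting a fused projection along the channel dim, the same pattern the split_qkv helpers below rely on:

    // qkv->ne == {W, H, 3*C, N}: three equal parts along dim 2
    std::vector<struct ggml_tensor*> parts = ggml_chunk(ctx, qkv, 3, 2);
    struct ggml_tensor* q = parts[0]; // {W, H, C, N}
    struct ggml_tensor* k = parts[1];
    struct ggml_tensor* v = parts[2];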
typedef std::function<void(ggml_tensor*, ggml_tensor*, bool)> on_tile_process;

// Tiling
@@ -680,7 +811,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_linear(struct ggml_context* ctx,
                                                     struct ggml_tensor* b) {
    x = ggml_mul_mat(ctx, w, x);
    if (b != NULL) {
        x = ggml_add(ctx, x, b);
        x = ggml_add_inplace(ctx, x, b);
    }
    return x;
}
@@ -703,11 +834,13 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d(struct ggml_context* ctx,
    if (b != NULL) {
        b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
        // b = ggml_repeat(ctx, b, x);
        x = ggml_add(ctx, x, b);
        x = ggml_add_inplace(ctx, x, b);
    }
    return x;
}

// w: [OC*IC, KD, KH, KW]
// x: [N*IC, ID, IH, IW]
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d_direct(struct ggml_context* ctx,
                                                             struct ggml_tensor* x,
                                                             struct ggml_tensor* w,
@@ -730,35 +863,30 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d_direct(struct ggml_context
// w: [OC,IC, KD, 1 * 1]
// x: [N, IC, IH, IW]
// b: [OC,]
// result: [N, OC, OH, OW]
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_3d_nx1x1_bak(struct ggml_context* ctx,
// result: [N*OC, OD, OH, OW]
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_3d(struct ggml_context* ctx,
                                                      struct ggml_tensor* x,
                                                      struct ggml_tensor* w,
                                                      struct ggml_tensor* b,
                                                      int64_t IC,
                                                      int s0 = 1,
                                                      int s1 = 1,
                                                      int s2 = 1,
                                                      int p2 = 1,
                                                      int p0 = 0,
                                                      int p1 = 0,
                                                      int p2 = 0,
                                                      int d0 = 1,
                                                      int d1 = 1,
                                                      int d2 = 1) {
    GGML_ASSERT(w->ne[0] == 1);
    // timesteps = x.shape[0]
    // x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
    // x = conv3d(x)
    // return rearrange(x, "b c t h w -> (b t) c h w")
    int64_t T  = x->ne[3];
    int64_t B  = x->ne[3] / T;
    int64_t C  = x->ne[2];
    int64_t H  = x->ne[1];
    int64_t W  = x->ne[0];
    int64_t OC = w->ne[3] / IC;
    int64_t N  = x->ne[3] / IC;
    x = ggml_conv_3d(ctx, w, x, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2);

    x = ggml_reshape_4d(ctx, x, W * H, C, T, B);          // (b t) c h w -> b t c (h w)
    x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
    x = ggml_conv_2d(ctx, w, x, 1, s2, 0, p2, 1, d2);     // [B, OC, T, OH * OW]
    if (b != NULL) {
        b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
        x = ggml_add(ctx, x, b);
        b = ggml_reshape_4d(ctx, b, 1, 1, 1, b->ne[0]); // [OC, 1, 1, 1]
        x = ggml_add_inplace(ctx, x, b);
    }
    x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
    x = ggml_reshape_4d(ctx, x, W, H, C, T * B);          // b t c (h w) -> (b t) c h w
    return x; // [B*T, OC, OH, OW]
    return x;
}

// w: [OC,IC, KD, 1 * 1]
@@ -794,6 +922,54 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> split_qkv(struct ggml_context
    return {q, k, v};
}

// qkv: [N, 3*C, H, W]
// return: ([N, C, H, W], [N, C, H, W], [N, C, H, W])
__STATIC_INLINE__ std::vector<struct ggml_tensor*> split_image_qkv(struct ggml_context* ctx,
                                                                   struct ggml_tensor* qkv) {
    int64_t W   = qkv->ne[0];
    int64_t H   = qkv->ne[1];
    int64_t C   = qkv->ne[2] / 3;
    int64_t N   = qkv->ne[3];
    int64_t nb1 = qkv->nb[1];
    int64_t nb2 = qkv->nb[2];
    qkv = ggml_reshape_4d(ctx, qkv, W * H, C, 3, N);                 // [N, 3, C, H*W]
    qkv = ggml_cont(ctx, ggml_torch_permute(ctx, qkv, 0, 1, 3, 2)); // [3, N, C, H*W]

    int64_t offset = qkv->nb[2] * qkv->ne[2];
    auto q = ggml_view_4d(ctx, qkv, W, H, C, N, nb1, nb2, qkv->nb[3], offset * 0); // [N, C, H, W]
    auto k = ggml_view_4d(ctx, qkv, W, H, C, N, nb1, nb2, qkv->nb[3], offset * 1); // [N, C, H, W]
    auto v = ggml_view_4d(ctx, qkv, W, H, C, N, nb1, nb2, qkv->nb[3], offset * 2); // [N, C, H, W]
    return {q, k, v};
}

__STATIC_INLINE__ struct ggml_tensor* ggml_full(struct ggml_context* ctx,
                                                float value,
                                                int64_t ne0,
                                                int64_t ne1,
                                                int64_t ne2,
                                                int64_t ne3) {
    auto one = ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:one");
    auto t   = ggml_scale(ctx, one, value);          // [1,]
    t        = ggml_repeat_4d(ctx, t, ne0, ne1, ne2, ne3); // [ne0, ne1, ne2, ne3]
    return t;
}

__STATIC_INLINE__ struct ggml_tensor* ggml_zeros(struct ggml_context* ctx,
                                                 int64_t ne0,
                                                 int64_t ne1,
                                                 int64_t ne2,
                                                 int64_t ne3) {
    return ggml_full(ctx, 0.f, ne0, ne1, ne2, ne3);
}

__STATIC_INLINE__ struct ggml_tensor* ggml_ones(struct ggml_context* ctx,
                                                int64_t ne0,
                                                int64_t ne1,
                                                int64_t ne2,
                                                int64_t ne3) {
    return ggml_full(ctx, 1.f, ne0, ne1, ne2, ne3);
}
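ggml_full builds a constant tensor without an extra host-to-device upload: it scales the one-element tensor that GGMLRunner registers under "ggml_runner_build_in_tensor:one" (see prepare_build_in_tensor_before further down) and repeats it to the requested shape, so it only works inside a GGMLRunner compute graph. A sketch of the -INFINITY mask padding it enables, mirroring its use in ggml_nn_attention_ext below:

    // zero attention bias for the real keys, -inf for the padded tail
    struct ggml_tensor* mask = ggml_zeros(ctx, L_k, L_q, 1, 1);              // [L_q, L_k]
    struct ggml_tensor* pad  = ggml_full(ctx, -INFINITY, kv_pad, L_q, 1, 1); // [L_q, kv_pad]
    mask = ggml_concat(ctx, mask, pad, 0);                                   // [L_q, L_k + kv_pad]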
// q: [N * n_head, n_token, d_head]
// k: [N * n_head, n_k, d_head]
// v: [N * n_head, d_head, n_k]
@@ -821,6 +997,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx
// q: [N, L_q, C] or [N*n_head, L_q, d_head]
// k: [N, L_k, C] or [N*n_head, L_k, d_head]
// v: [N, L_k, C] or [N, L_k, n_head, d_head]
// mask: [N, L_q, L_k]
// return: [N, L_q, C]
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* ctx,
                                                            struct ggml_tensor* q,
@@ -843,11 +1020,11 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
    N      = q->ne[2];
    d_head = C / n_head;
    q = ggml_reshape_4d(ctx, q, d_head, n_head, L_q, N);     // [N, L_q, n_head, d_head]
    q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3));    // [N, n_head, L_q, d_head]
    q = ggml_nn_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, L_q, d_head]
    q = ggml_reshape_3d(ctx, q, d_head, L_q, n_head * N);    // [N * n_head, L_q, d_head]

    k = ggml_reshape_4d(ctx, k, d_head, n_head, L_k, N);     // [N, L_k, n_head, d_head]
    k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3));    // [N, n_head, L_k, d_head]
    k = ggml_nn_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, L_k, d_head]
    k = ggml_reshape_3d(ctx, k, d_head, L_k, n_head * N);    // [N * n_head, L_k, d_head]

    v = ggml_reshape_4d(ctx, v, d_head, n_head, L_k, N);     // [N, L_k, n_head, d_head]
@@ -862,43 +1039,25 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
    float scale = (1.0f / sqrt((float)d_head));

    int kv_pad = 0;
    // if (flash_attn) {
    if (flash_attn) {
        // LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
        // }
        // is there anything oddly shaped?? ping Green-Sky if you can trip this assert
        GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));

        bool can_use_flash_attn = true;
        can_use_flash_attn      = can_use_flash_attn && (d_head == 64 ||
                                                         d_head == 80 ||
                                                         d_head == 96 ||
                                                         d_head == 112 ||
                                                         d_head == 128 ||
                                                         d_head == 256);
#if 0
        can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0;
#else
        if (can_use_flash_attn && L_k % 256 != 0) {
            // TODO(Green-Sky): might be worth just padding by default
            if (L_k == 77 || L_k == 4208 || L_k == 3952) {
                kv_pad = GGML_PAD(L_k, 256) - L_k;
            } else {
                can_use_flash_attn = false;
            }
        }
#endif

        if (mask != nullptr) {
            // TODO(Green-Sky): figure out if we can bend t5 to work too
            can_use_flash_attn = can_use_flash_attn && mask->ne[2] == 1;
            can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1;
        }

        // TODO(Green-Sky): more pad or disable for funny tensor shapes
        if (!can_use_flash_attn) {
            flash_attn = false;
        }
    }

    ggml_tensor* kqv = nullptr;
    // GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn);
    if (can_use_flash_attn && flash_attn) {
    if (flash_attn) {
        // LOG_DEBUG(" uses flash attention");
        if (kv_pad != 0) {
            // LOG_DEBUG(" padding k and v dim1 by %d", kv_pad);
@@ -906,7 +1065,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
        }
        k = ggml_cast(ctx, k, GGML_TYPE_F16);

        v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));    // [N, n_head, L_k, d_head]
        v = ggml_nn_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3)); // [N, n_head, L_k, d_head]
        v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N);    // [N * n_head, L_k, d_head]
        if (kv_pad != 0) {
            v = ggml_pad(ctx, v, 0, kv_pad, 0, 0);
@@ -915,14 +1074,25 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*

        if (mask != nullptr) {
            mask = ggml_transpose(ctx, mask);

            if (mask->ne[1] < GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD)) {
                LOG_DEBUG("mask dims %ld, %ld, %ld, %ld\n", mask->ne[0], mask->ne[1], mask->ne[2], mask->ne[3]);
                LOG_DEBUG("needs padding, padding from %ld to %ld\n", mask->ne[1], GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD));
                mask = ggml_pad(ctx, mask, 0, GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) - mask->ne[1], 0, 0);
        } else {
            if (kv_pad > 0) {
                mask            = ggml_zeros(ctx, L_k, L_q, 1, 1);              // [L_q, L_k]
                auto pad_tensor = ggml_full(ctx, -INFINITY, kv_pad, L_q, 1, 1); // [L_q, kv_pad]
                mask            = ggml_concat(ctx, mask, pad_tensor, 0);        // [L_q, L_k + kv_pad]
            }
        }

        // mask pad
        if (mask != nullptr) {
            int mask_pad = 0;
            if (mask->ne[1] % GGML_KQ_MASK_PAD != 0) {
                mask_pad = GGML_PAD(L_q, GGML_KQ_MASK_PAD) - mask->ne[1];
            }
            if (mask_pad > 0) {
                mask = ggml_pad(ctx, mask, 0, mask_pad, 0, 0); // [L_q + mask_pad, L_k + kv_pad]
            }
            mask = ggml_cast(ctx, mask, GGML_TYPE_F16);
            // LOG_DEBUG("L_k: %ld, L_q: %ld, mask->ne[1]: %ld, mask_pad: %d, kv_pad: %d", L_k, L_q, mask->ne[1], mask_pad, kv_pad);
        }

        kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0);
@@ -931,7 +1101,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
        // kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_k, kqv->nb[1], kqv->nb[2], 0);
        kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_q, kqv->nb[1], kqv->nb[2], 0);
    } else {
        v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3));    // [N, n_head, d_head, L_k]
        v = ggml_nn_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, L_k]
        v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N);    // [N * n_head, d_head, L_k]

        auto kq = ggml_mul_mat(ctx, k, q); // [N * n_head, L_q, L_k]
@@ -950,7 +1120,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
        kqv = ggml_permute(ctx, kqv, 0, 2, 1, 3); // [N, L_q, n_head, d_head]
    }

    kqv = ggml_cont(ctx, kqv);
    kqv = ggml_nn_cont(ctx, kqv);
    kqv = ggml_reshape_3d(ctx, kqv, d_head * n_head, L_q, N); // [N, L_q, C]

    return kqv;
@@ -963,9 +1133,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_layer_norm(struct ggml_context* ct
                                                         float eps = EPS) {
    x = ggml_norm(ctx, x, eps);
    if (w != NULL) {
        x = ggml_mul(ctx, x, w);
        x = ggml_mul_inplace(ctx, x, w);
        if (b != NULL) {
            x = ggml_add(ctx, x, b);
            x = ggml_add_inplace(ctx, x, b);
        }
    }
    return x;
@@ -984,9 +1154,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_group_norm(struct ggml_context* ct
    const float eps = 1e-6f; // default eps parameter
    x = ggml_group_norm(ctx, x, num_groups, eps);
    if (w != NULL && b != NULL) {
        x = ggml_mul(ctx, x, w);
        x = ggml_mul_inplace(ctx, x, w);
        // b = ggml_repeat(ctx, b, x);
        x = ggml_add(ctx, x, b);
        x = ggml_add_inplace(ctx, x, b);
    }
    return x;
}
@@ -1005,14 +1175,18 @@ __STATIC_INLINE__ void ggml_backend_tensor_get_and_sync(ggml_backend_t backend,
}

__STATIC_INLINE__ float ggml_backend_tensor_get_f32(ggml_tensor* tensor) {
    GGML_ASSERT(tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_F16);
    GGML_ASSERT(tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_I32);
    float value;
    if (tensor->type == GGML_TYPE_F32) {
        ggml_backend_tensor_get(tensor, &value, 0, sizeof(value));
    } else { // GGML_TYPE_F16
    } else if (tensor->type == GGML_TYPE_F16) {
        ggml_fp16_t f16_value;
        ggml_backend_tensor_get(tensor, &f16_value, 0, sizeof(f16_value));
        value = ggml_fp16_to_fp32(f16_value);
    } else { // GGML_TYPE_I32
        int int32_value;
        ggml_backend_tensor_get(tensor, &int32_value, 0, sizeof(int32_value));
        value = (float)int32_value;
    }
    return value;
}
@@ -1116,7 +1290,7 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {

/* SDXL with LoRA requires more space */
#define MAX_PARAMS_TENSOR_NUM 32768
#define MAX_GRAPH_SIZE 32768
#define MAX_GRAPH_SIZE 327680

typedef std::map<std::string, enum ggml_type> String2GGMLType;

@@ -1124,15 +1298,27 @@ struct GGMLRunner {
protected:
    typedef std::function<struct ggml_cgraph*()> get_graph_cb_t;

    ggml_backend_t params_backend  = NULL;
    ggml_backend_t runtime_backend = NULL;

    struct ggml_context* params_ctx             = NULL;
    ggml_backend_buffer_t params_buffer         = NULL;
    struct ggml_context* offload_ctx            = NULL;
    ggml_backend_buffer_t runtime_params_buffer = NULL;
    bool params_on_runtime_backend              = false;

    struct ggml_context* cache_ctx      = NULL;
    ggml_backend_buffer_t cache_buffer  = NULL;

    struct ggml_context* compute_ctx    = NULL;
    struct ggml_gallocr* compute_allocr = NULL;

    std::map<struct ggml_tensor*, const void*> backend_tensor_data_map;
    std::vector<float> one_vec = {1.f};
    ggml_tensor* one_tensor    = NULL;

    ggml_backend_t backend = NULL;
    std::map<struct ggml_tensor*, const void*> backend_tensor_data_map;
    std::map<std::string, struct ggml_tensor*> cache_tensor_map; // name -> tensor
    const std::string final_result_name = "ggml_runner_final_result_tensor";

    void alloc_params_ctx() {
        struct ggml_init_params params;
@@ -1142,6 +1328,10 @@ protected:

        params_ctx = ggml_init(params);
        GGML_ASSERT(params_ctx != NULL);
        if (params_backend != runtime_backend) {
            offload_ctx = ggml_init(params);
            GGML_ASSERT(offload_ctx != NULL);
        }
    }

    void free_params_ctx() {
@@ -1149,6 +1339,27 @@ protected:
            ggml_free(params_ctx);
            params_ctx = NULL;
        }
        if (offload_ctx != NULL) {
            ggml_free(offload_ctx);
            offload_ctx = NULL;
        }
    }

    void alloc_cache_ctx() {
        struct ggml_init_params params;
        params.mem_size   = static_cast<size_t>(MAX_PARAMS_TENSOR_NUM * ggml_tensor_overhead());
        params.mem_buffer = NULL;
        params.no_alloc   = true;

        cache_ctx = ggml_init(params);
        GGML_ASSERT(cache_ctx != NULL);
    }

    void free_cache_ctx() {
        if (cache_ctx != NULL) {
            ggml_free(cache_ctx);
            cache_ctx = NULL;
        }
    }

    void alloc_compute_ctx() {
@@ -1168,14 +1379,33 @@ protected:
        }
    }

    void prepare_build_in_tensor_before() {
        one_tensor = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, 1);
        ggml_set_name(one_tensor, "ggml_runner_build_in_tensor:one");
        set_backend_tensor_data(one_tensor, one_vec.data());
    }

    void prepare_build_in_tensor_after(struct ggml_cgraph* gf) {
        ggml_build_forward_expand(gf, one_tensor);
    }

    struct ggml_cgraph* get_compute_graph(get_graph_cb_t get_graph) {
        prepare_build_in_tensor_before();
        struct ggml_cgraph* gf = get_graph();
        auto result            = ggml_graph_node(gf, -1);
        ggml_set_name(result, final_result_name.c_str());
        prepare_build_in_tensor_after(gf);
        return gf;
    }

    bool alloc_compute_buffer(get_graph_cb_t get_graph) {
        if (compute_allocr != NULL) {
            return true;
        }
        reset_compute_ctx();
        struct ggml_cgraph* gf = get_graph();
        struct ggml_cgraph* gf = get_compute_graph(get_graph);
        backend_tensor_data_map.clear();
        compute_allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
        compute_allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(runtime_backend));

        if (!ggml_gallocr_reserve(compute_allocr, gf)) {
            // failed to allocate the compute buffer
@@ -1189,11 +1419,47 @@ protected:
        LOG_DEBUG("%s compute buffer size: %.2f MB(%s)",
                  get_desc().c_str(),
                  compute_buffer_size / 1024.0 / 1024.0,
                  ggml_backend_is_cpu(backend) ? "RAM" : "VRAM");
                  ggml_backend_is_cpu(runtime_backend) ? "RAM" : "VRAM");
        return true;
    }

    void cpy_data_to_backend_tensor() {
    void free_cache_buffer() {
        if (cache_buffer != NULL) {
            ggml_backend_buffer_free(cache_buffer);
            cache_buffer = NULL;
        }
    }

    void copy_cache_tensors_to_cache_buffer() {
        if (cache_tensor_map.size() == 0) {
            return;
        }
        free_cache_ctx_and_buffer();
        alloc_cache_ctx();
        GGML_ASSERT(cache_buffer == NULL);
        std::map<ggml_tensor*, ggml_tensor*> runtime_tensor_to_cache_tensor;
        for (auto kv : cache_tensor_map) {
            auto cache_tensor = ggml_dup_tensor(cache_ctx, kv.second);
            ggml_set_name(cache_tensor, kv.first.c_str());
            runtime_tensor_to_cache_tensor[kv.second] = cache_tensor;
        }
        size_t num_tensors = ggml_tensor_num(cache_ctx);
        cache_buffer       = ggml_backend_alloc_ctx_tensors(cache_ctx, runtime_backend);
        GGML_ASSERT(cache_buffer != NULL);
        for (auto kv : runtime_tensor_to_cache_tensor) {
            ggml_backend_tensor_copy(kv.first, kv.second);
        }
        ggml_backend_synchronize(runtime_backend);
        cache_tensor_map.clear();
        size_t cache_buffer_size = ggml_backend_buffer_get_size(cache_buffer);
        LOG_DEBUG("%s cache backend buffer size = % 6.2f MB(%s) (%i tensors)",
                  get_desc().c_str(),
                  cache_buffer_size / (1024.f * 1024.f),
                  ggml_backend_is_cpu(runtime_backend) ? "RAM" : "VRAM",
                  num_tensors);
    }

    void copy_data_to_backend_tensor() {
        for (auto& kv : backend_tensor_data_map) {
            auto tensor = kv.first;
            auto data   = kv.second;
@@ -1204,12 +1470,96 @@ protected:
        backend_tensor_data_map.clear();
    }

    bool offload_params_to_runtime_backend() {
        if (params_backend == runtime_backend) {
            return true;
        }
        if (params_on_runtime_backend) {
            return true;
        }
        GGML_ASSERT(runtime_params_buffer == NULL);
        int64_t t0         = ggml_time_ms();
        size_t num_tensors = ggml_tensor_num(offload_ctx);
        if (num_tensors == 0) {
            for (ggml_tensor* t = ggml_get_first_tensor(params_ctx); t != NULL; t = ggml_get_next_tensor(params_ctx, t)) {
                GGML_ASSERT(t->view_src == NULL);
                ggml_dup_tensor(offload_ctx, t);
            }
        }
        num_tensors = ggml_tensor_num(offload_ctx);
        GGML_ASSERT(num_tensors == ggml_tensor_num(params_ctx));

        runtime_params_buffer = ggml_backend_alloc_ctx_tensors(offload_ctx, runtime_backend);

        if (runtime_params_buffer == NULL) {
            LOG_ERROR("%s alloc runtime params backend buffer failed, num_tensors = %i",
                      get_desc().c_str(),
                      num_tensors);
            return false;
        }

        ggml_tensor* t         = ggml_get_first_tensor(params_ctx);
        ggml_tensor* offload_t = ggml_get_first_tensor(offload_ctx);

        while (t != NULL && offload_t != NULL) {
            ggml_backend_tensor_copy(t, offload_t);
            std::swap(t->buffer, offload_t->buffer);
            std::swap(t->data, offload_t->data);

            t         = ggml_get_next_tensor(params_ctx, t);
            offload_t = ggml_get_next_tensor(offload_ctx, offload_t);
        }

        int64_t t1 = ggml_time_ms();

        size_t params_buffer_size = ggml_backend_buffer_get_size(runtime_params_buffer);
        LOG_INFO("%s offload params (%6.2f MB, %i tensors) to runtime backend (%s), taking %.2fs",
                 get_desc().c_str(),
                 params_buffer_size / (1024.f * 1024.f),
                 num_tensors,
                 ggml_backend_name(runtime_backend),
                 (t1 - t0) * 1.0f / 1000);

        params_on_runtime_backend = true;

        return true;
    }

    void offload_params_to_params_backend() {
        if (!params_on_runtime_backend) {
            return;
        }
        ggml_tensor* t         = ggml_get_first_tensor(params_ctx);
        ggml_tensor* offload_t = ggml_get_first_tensor(offload_ctx);

        while (t != NULL && offload_t != NULL) {
            t->buffer         = offload_t->buffer;
            t->data           = offload_t->data;
            offload_t->buffer = NULL;
            offload_t->data   = NULL;

            t         = ggml_get_next_tensor(params_ctx, t);
            offload_t = ggml_get_next_tensor(offload_ctx, offload_t);
        }

        if (runtime_params_buffer != NULL) {
            ggml_backend_buffer_free(runtime_params_buffer);
            runtime_params_buffer = NULL;
        }
        params_on_runtime_backend = false;
    }
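    // Explanatory note (not in the original commit): offload_params_to_runtime_backend()
    // copies each weight into a twin tensor allocated on the runtime backend, then swaps
    // the buffer/data pointers so existing graphs keep referencing the params_ctx tensors
    // while the bytes live on the runtime device; offload_params_to_params_backend()
    // restores the original pointers and frees the runtime copy, which is why
    // free_compute_buffer() below can call it after each generation to release VRAM.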
public:
    virtual std::string get_desc() = 0;

    GGMLRunner(ggml_backend_t backend)
        : backend(backend) {
    GGMLRunner(ggml_backend_t backend, bool offload_params_to_cpu = false)
        : runtime_backend(backend) {
        alloc_params_ctx();
        if (!ggml_backend_is_cpu(runtime_backend) && offload_params_to_cpu) {
            params_backend = ggml_backend_cpu_init();
        } else {
            params_backend = runtime_backend;
        }
    }

    virtual ~GGMLRunner() {
@@ -1217,6 +1567,10 @@ public:
        free_compute_buffer();
        free_params_ctx();
        free_compute_ctx();
        if (params_backend != runtime_backend) {
            ggml_backend_free(params_backend);
        }
        free_cache_ctx_and_buffer();
    }

    void reset_compute_ctx() {
@@ -1226,7 +1580,7 @@ public:

    bool alloc_params_buffer() {
        size_t num_tensors = ggml_tensor_num(params_ctx);
        params_buffer      = ggml_backend_alloc_ctx_tensors(params_ctx, backend);
        params_buffer      = ggml_backend_alloc_ctx_tensors(params_ctx, params_backend);
        if (params_buffer == NULL) {
            LOG_ERROR("%s alloc params backend buffer failed, num_tensors = %i",
                      get_desc().c_str(),
@@ -1236,14 +1590,9 @@ public:
        size_t params_buffer_size = ggml_backend_buffer_get_size(params_buffer);
        LOG_DEBUG("%s params backend buffer size = % 6.2f MB(%s) (%i tensors)",
                  get_desc().c_str(),
                  params_buffer_size / (1024.0 * 1024.0),
                  ggml_backend_is_cpu(backend) ? "RAM" : "VRAM",
                  params_buffer_size / (1024.f * 1024.f),
                  ggml_backend_is_cpu(params_backend) ? "RAM" : "VRAM",
                  num_tensors);
        // printf("%s params backend buffer size = % 6.2f MB(%s) (%i tensors)\n",
        //        get_desc().c_str(),
        //        params_buffer_size / (1024.0 * 1024.0),
        //        ggml_backend_is_cpu(backend) ? "RAM" : "VRAM",
        //        num_tensors);
        return true;
    }

@@ -1261,11 +1610,17 @@ public:
        return 0;
    }

    void free_cache_ctx_and_buffer() {
        free_cache_buffer();
        free_cache_ctx();
    }

    void free_compute_buffer() {
        if (compute_allocr != NULL) {
            ggml_gallocr_free(compute_allocr);
            compute_allocr = NULL;
        }
        offload_params_to_params_backend();
    }

    // do copy after alloc graph
@@ -1279,7 +1634,7 @@ public:
            return NULL;
        }
        // it's performing a compute, check if backend isn't cpu
        if (!ggml_backend_is_cpu(backend) && (tensor->buffer == NULL || ggml_backend_buffer_is_host(tensor->buffer))) {
        if (!ggml_backend_is_cpu(runtime_backend) && (tensor->buffer == NULL || ggml_backend_buffer_is_host(tensor->buffer))) {
            // pass input tensors to gpu memory
            auto backend_tensor = ggml_dup_tensor(compute_ctx, tensor);

@@ -1290,31 +1645,47 @@ public:
        }
    }

    void cache(const std::string name, struct ggml_tensor* tensor) {
        cache_tensor_map[name] = tensor;
    }

    struct ggml_tensor* get_cache_tensor_by_name(const std::string& name) {
        if (cache_ctx == NULL) {
            return NULL;
        }
        return ggml_get_tensor(cache_ctx, name.c_str());
    }
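    // Usage sketch for the cache API above (illustrative names, not in the original
    // commit). Inside a subclass's graph builder, stash an intermediate under a name:
    //     cache("first_frame_feat", feat);
    // copy_cache_tensors_to_cache_buffer() then duplicates it into cache_ctx right
    // after graph execution, so on a later compute() the persistent copy survives
    // compute-buffer resets:
    //     ggml_tensor* prev = get_cache_tensor_by_name("first_frame_feat");
    //     if (prev == NULL) { /* first call: nothing cached yet */ }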
    void compute(get_graph_cb_t get_graph,
                 int n_threads,
                 bool free_compute_buffer_immediately = true,
                 struct ggml_tensor** output          = NULL,
                 struct ggml_context* output_ctx      = NULL) {
        if (!offload_params_to_runtime_backend()) {
            LOG_ERROR("%s offload params to runtime backend failed", get_desc().c_str());
            return;
        }
        alloc_compute_buffer(get_graph);
        reset_compute_ctx();
        struct ggml_cgraph* gf = get_graph();
        struct ggml_cgraph* gf = get_compute_graph(get_graph);
        GGML_ASSERT(ggml_gallocr_alloc_graph(compute_allocr, gf));
        cpy_data_to_backend_tensor();
        if (ggml_backend_is_cpu(backend)) {
            ggml_backend_cpu_set_n_threads(backend, n_threads);
        copy_data_to_backend_tensor();
        if (ggml_backend_is_cpu(runtime_backend)) {
            ggml_backend_cpu_set_n_threads(runtime_backend, n_threads);
        }

        ggml_backend_graph_compute(backend, gf);
        ggml_backend_graph_compute(runtime_backend, gf);
#ifdef GGML_PERF
        ggml_graph_print(gf);
#endif
        copy_cache_tensors_to_cache_buffer();
        if (output != NULL) {
            auto result = ggml_graph_node(gf, -1);
            auto result = ggml_get_tensor(compute_ctx, final_result_name.c_str());
            if (*output == NULL && output_ctx != NULL) {
                *output = ggml_dup_tensor(output_ctx, result);
            }
            if (*output != NULL) {
                ggml_backend_tensor_get_and_sync(backend, result, (*output)->data, 0, ggml_nbytes(*output));
                ggml_backend_tensor_get_and_sync(runtime_backend, result, (*output)->data, 0, ggml_nbytes(*output));
            }
        }

@@ -1416,6 +1787,13 @@ public:
    virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) = 0;
};

class Identity : public UnaryBlock {
public:
    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
        return x;
    }
};

class Linear : public UnaryBlock {
protected:
    int64_t in_features;
@@ -1430,7 +1808,7 @@ protected:
        }
        params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features);
        if (bias) {
            enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.ypes.find(prefix + "bias") != tensor_types.end()) ? tensor_types[prefix + "bias"] : GGML_TYPE_F32;
            enum ggml_type wtype = GGML_TYPE_F32;
            params["bias"]       = ggml_new_tensor_1d(ctx, wtype, out_features);
        }
    }
@@ -1594,6 +1972,58 @@ public:
    }
};

class Conv3d : public UnaryBlock {
protected:
    int64_t in_channels;
    int64_t out_channels;
    std::tuple<int, int, int> kernel_size;
    std::tuple<int, int, int> stride;
    std::tuple<int, int, int> padding;
    std::tuple<int, int, int> dilation;
    bool bias;

    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
        enum ggml_type wtype = GGML_TYPE_F16;
        params["weight"]     = ggml_new_tensor_4d(ctx,
                                                  wtype,
                                                  std::get<2>(kernel_size),
                                                  std::get<1>(kernel_size),
                                                  std::get<0>(kernel_size),
                                                  in_channels * out_channels);
        if (bias) {
            params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
        }
    }

public:
    Conv3d(int64_t in_channels,
           int64_t out_channels,
           std::tuple<int, int, int> kernel_size,
           std::tuple<int, int, int> stride   = {1, 1, 1},
           std::tuple<int, int, int> padding  = {0, 0, 0},
           std::tuple<int, int, int> dilation = {1, 1, 1},
           bool bias                          = true)
        : in_channels(in_channels),
          out_channels(out_channels),
          kernel_size(kernel_size),
          stride(stride),
          padding(padding),
          dilation(dilation),
          bias(bias) {}

    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
        struct ggml_tensor* w = params["weight"];
        struct ggml_tensor* b = NULL;
        if (bias) {
            b = params["bias"];
        }
        return ggml_nn_conv_3d(ctx, x, w, b, in_channels,
                               std::get<2>(stride), std::get<1>(stride), std::get<0>(stride),
                               std::get<2>(padding), std::get<1>(padding), std::get<0>(padding),
                               std::get<2>(dilation), std::get<1>(dilation), std::get<0>(dilation));
    }
};
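A construction sketch for the Conv3d block above (hypothetical channel counts): the kernel/stride/padding/dilation tuples are ordered (D, H, W), matching how init_params maps std::get<0/1/2> onto KD/KH/KW:

    // 3x1x1 temporal convolution, padded by one frame on each side in depth only
    Conv3d temporal_conv(64, 64, {3, 1, 1}, {1, 1, 1}, {1, 0, 0});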
class LayerNorm : public UnaryBlock {
protected:
    int64_t normalized_shape;
@@ -1679,6 +2109,30 @@ public:
        : GroupNorm(32, num_channels, 1e-06f) {}
};

class RMSNorm : public UnaryBlock {
protected:
    int64_t hidden_size;
    float eps;

    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
        enum ggml_type wtype = GGML_TYPE_F32;
        params["weight"]     = ggml_new_tensor_1d(ctx, wtype, hidden_size);
    }

public:
    RMSNorm(int64_t hidden_size,
            float eps = 1e-06f)
        : hidden_size(hidden_size),
          eps(eps) {}

    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
        struct ggml_tensor* w = params["weight"];
        x = ggml_rms_norm(ctx, x, eps);
        x = ggml_mul_inplace(ctx, x, w);
        return x;
    }
};
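For reference, the per-row semantics of the block above (assuming ggml_rms_norm implements the usual RMSNorm, which its use here relies on):

    // y[i] = w[i] * x[i] / sqrt(mean_j(x[j]^2) + eps)
    // no mean subtraction and no bias, which is what distinguishes it from LayerNorm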
class MultiheadAttention : public GGMLBlock {
protected:
    int64_t embed_dim;
231
gguf_reader.hpp
Normal file
231
gguf_reader.hpp
Normal file
@ -0,0 +1,231 @@
|
||||
#ifndef __GGUF_READER_HPP__
|
||||
#define __GGUF_READER_HPP__
|
||||
|
||||
#include <cstdint>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ggml.h"
|
||||
#include "util.h"
|
||||
|
||||
struct GGUFTensorInfo {
|
||||
std::string name;
|
||||
ggml_type type;
|
||||
std::vector<int64_t> shape;
|
||||
size_t offset;
|
||||
};
|
||||
|
||||
enum class GGUFMetadataType : uint32_t {
|
||||
UINT8 = 0,
|
||||
INT8 = 1,
|
||||
UINT16 = 2,
|
||||
INT16 = 3,
|
||||
UINT32 = 4,
|
||||
INT32 = 5,
|
||||
FLOAT32 = 6,
|
||||
BOOL = 7,
|
||||
STRING = 8,
|
||||
ARRAY = 9,
|
||||
UINT64 = 10,
|
||||
INT64 = 11,
|
||||
FLOAT64 = 12,
|
||||
};
|
||||
|
||||
class GGUFReader {
|
||||
private:
|
||||
std::vector<GGUFTensorInfo> tensors_;
|
||||
size_t data_offset_;
|
||||
size_t alignment_ = 32; // default alignment is 32
|
||||
|
||||
template <typename T>
|
||||
bool safe_read(std::ifstream& fin, T& value) {
|
||||
fin.read(reinterpret_cast<char*>(&value), sizeof(T));
|
||||
return fin.good();
|
||||
}
|
||||
|
||||
bool safe_read(std::ifstream& fin, char* buffer, size_t size) {
|
||||
fin.read(buffer, size);
|
||||
return fin.good();
|
||||
}
|
||||
|
||||
bool safe_seek(std::ifstream& fin, std::streamoff offset, std::ios::seekdir dir) {
|
||||
fin.seekg(offset, dir);
|
||||
return fin.good();
|
||||
}
|
||||
|
||||
bool read_metadata(std::ifstream& fin) {
|
||||
uint64_t key_len = 0;
|
||||
if (!safe_read(fin, key_len))
|
||||
return false;
|
||||
|
||||
std::string key(key_len, '\0');
|
||||
if (!safe_read(fin, (char*)key.data(), key_len))
|
||||
return false;
|
||||
|
||||
uint32_t type = 0;
|
||||
if (!safe_read(fin, type))
|
||||
return false;
|
||||
|
||||
if (key == "general.alignment") {
|
||||
uint32_t align_val = 0;
|
||||
if (!safe_read(fin, align_val))
|
||||
return false;
|
||||
|
||||
if (align_val != 0 && (align_val & (align_val - 1)) == 0) {
|
||||
                alignment_ = align_val;
                LOG_DEBUG("Found alignment: %zu", alignment_);
            } else {
                LOG_ERROR("Invalid alignment value %u, fallback to default %zu", align_val, alignment_);
            }
            return true;
        }

        switch (static_cast<GGUFMetadataType>(type)) {
            case GGUFMetadataType::UINT8:
            case GGUFMetadataType::INT8:
            case GGUFMetadataType::BOOL:
                return safe_seek(fin, 1, std::ios::cur);

            case GGUFMetadataType::UINT16:
            case GGUFMetadataType::INT16:
                return safe_seek(fin, 2, std::ios::cur);

            case GGUFMetadataType::UINT32:
            case GGUFMetadataType::INT32:
            case GGUFMetadataType::FLOAT32:
                return safe_seek(fin, 4, std::ios::cur);

            case GGUFMetadataType::UINT64:
            case GGUFMetadataType::INT64:
            case GGUFMetadataType::FLOAT64:
                return safe_seek(fin, 8, std::ios::cur);

            case GGUFMetadataType::STRING: {
                uint64_t len = 0;
                if (!safe_read(fin, len))
                    return false;
                return safe_seek(fin, len, std::ios::cur);
            }

            case GGUFMetadataType::ARRAY: {
                uint32_t elem_type = 0;
                uint64_t len = 0;
                if (!safe_read(fin, elem_type))
                    return false;
                if (!safe_read(fin, len))
                    return false;

                for (uint64_t i = 0; i < len; i++) {
                    if (!read_metadata(fin))
                        return false;
                }
                return true;
            }

            default:
                LOG_ERROR("Unknown metadata type=%u", type);
                return false;
        }
    }

    GGUFTensorInfo read_tensor_info(std::ifstream& fin) {
        GGUFTensorInfo info;

        uint64_t name_len;
        if (!safe_read(fin, name_len))
            throw std::runtime_error("read tensor name length failed");

        info.name.resize(name_len);
        if (!safe_read(fin, (char*)info.name.data(), name_len))
            throw std::runtime_error("read tensor name failed");

        uint32_t n_dims;
        if (!safe_read(fin, n_dims))
            throw std::runtime_error("read tensor dims failed");

        info.shape.resize(n_dims);
        for (uint32_t i = 0; i < n_dims; i++) {
            if (!safe_read(fin, info.shape[i]))
                throw std::runtime_error("read tensor shape failed");
        }

        if (n_dims > GGML_MAX_DIMS) {
            for (int i = GGML_MAX_DIMS; i < n_dims; i++) {
                info.shape[GGML_MAX_DIMS - 1] *= info.shape[i];  // stack to last dim
            }
            info.shape.resize(GGML_MAX_DIMS);
            n_dims = GGML_MAX_DIMS;
        }

        uint32_t type;
        if (!safe_read(fin, type))
            throw std::runtime_error("read tensor type failed");
        info.type = static_cast<ggml_type>(type);

        if (!safe_read(fin, info.offset))
            throw std::runtime_error("read tensor offset failed");

        return info;
    }

public:
    bool load(const std::string& file_path) {
        std::ifstream fin(file_path, std::ios::binary);
        if (!fin) {
            LOG_ERROR("failed to open '%s'", file_path.c_str());
            return false;
        }

        // --- Header ---
        char magic[4];
        if (!safe_read(fin, magic, 4) || strncmp(magic, "GGUF", 4) != 0) {
            LOG_ERROR("not a valid GGUF file");
            return false;
        }

        uint32_t version;
        if (!safe_read(fin, version))
            return false;

        uint64_t tensor_count, metadata_kv_count;
        if (!safe_read(fin, tensor_count))
            return false;
        if (!safe_read(fin, metadata_kv_count))
            return false;

        LOG_DEBUG("GGUF v%u, tensor_count=%llu, metadata_kv_count=%llu",
                  version, (unsigned long long)tensor_count, (unsigned long long)metadata_kv_count);

        // --- Read Metadata ---
        for (uint64_t i = 0; i < metadata_kv_count; i++) {
            if (!read_metadata(fin)) {
                LOG_ERROR("read meta data failed");
                return false;
            }
        }

        // --- Tensor Infos ---
        tensors_.clear();
        try {
            for (uint64_t i = 0; i < tensor_count; i++) {
                tensors_.push_back(read_tensor_info(fin));
            }
        } catch (const std::runtime_error& e) {
            LOG_ERROR("%s", e.what());
            return false;
        }

        data_offset_ = static_cast<size_t>(fin.tellg());
        if ((data_offset_ % alignment_) != 0) {
            data_offset_ = ((data_offset_ + alignment_ - 1) / alignment_) * alignment_;
        }
        fin.close();
        return true;
    }

    const std::vector<GGUFTensorInfo>& tensors() const { return tensors_; }
    size_t data_offset() const { return data_offset_; }
};

#endif  // __GGUF_READER_HPP__
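For orientation, here is a minimal sketch of how this fallback reader might be driven; the loop mirrors what `ModelLoader::init_from_gguf_file` does further down, and `dump_gguf_tensors` is a hypothetical helper, not part of the diff:

#include "gguf_reader.hpp"

// Hypothetical helper: enumerate all tensors in a GGUF file.
bool dump_gguf_tensors(const std::string& path) {
    GGUFReader reader;
    if (!reader.load(path)) {
        return false;  // not a readable GGUF file
    }
    // Tensor data begins at the aligned data offset; per-tensor offsets
    // recorded in the header are relative to that base.
    size_t base = reader.data_offset();
    for (const auto& info : reader.tensors()) {
        LOG_DEBUG("%s: %zu dims, data at file offset %zu",
                  info.name.c_str(), info.shape.size(), base + (size_t)info.offset);
    }
    return true;
}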
lora.hpp (191 changed lines)
@@ -92,6 +92,7 @@ struct LoraModel : public GGMLRunner {
     float multiplier = 1.0f;
     std::map<std::string, struct ggml_tensor*> lora_tensors;
+    std::map<ggml_tensor*, ggml_tensor*> original_tensor_to_final_tensor;
     std::string file_path;
     ModelLoader model_loader;
     bool load_failed = false;
@@ -103,7 +104,7 @@ struct LoraModel : public GGMLRunner {
     LoraModel(ggml_backend_t backend,
               const std::string& file_path = "",
               const std::string prefix = "")
-        : file_path(file_path), GGMLRunner(backend) {
+        : file_path(file_path), GGMLRunner(backend, false) {
         if (!model_loader.init_from_file(file_path, prefix)) {
             load_failed = true;
         }
@@ -129,7 +130,7 @@ struct LoraModel : public GGMLRunner {
             // LOG_INFO("skipping LoRA tesnor '%s'", name.c_str());
             return true;
         }
-        // LOG_INFO("%s", name.c_str());
+        // LOG_INFO("lora_tensor %s", name.c_str());
         for (int i = 0; i < LORA_TYPE_COUNT; i++) {
             if (name.find(type_fingerprints[i]) != std::string::npos) {
                 type = (lora_t)i;
@@ -151,11 +152,11 @@ struct LoraModel : public GGMLRunner {
             return true;
         };

-        model_loader.load_tensors(on_new_tensor_cb, backend);
+        model_loader.load_tensors(on_new_tensor_cb);
         alloc_params_buffer();
         // exit(0);
         dry_run = false;
-        model_loader.load_tensors(on_new_tensor_cb, backend);
+        model_loader.load_tensors(on_new_tensor_cb);

         LOG_DEBUG("lora type: \"%s\"/\"%s\"", lora_downs[type].c_str(), lora_ups[type].c_str());
@@ -167,6 +168,7 @@ struct LoraModel : public GGMLRunner {
         auto out = ggml_reshape_1d(ctx, a, ggml_nelements(a));
         out = ggml_get_rows(ctx, out, zero_index);
         out = ggml_reshape(ctx, out, a);
+        // auto out = ggml_cast(ctx, a, GGML_TYPE_F32);
         return out;
     }
@@ -245,14 +247,22 @@ struct LoraModel : public GGMLRunner {
         set_backend_tensor_data(zero_index, zero_index_vec.data());
         ggml_build_forward_expand(gf, zero_index);

+        original_tensor_to_final_tensor.clear();
+
         std::set<std::string> applied_lora_tensors;
         for (auto it : model_tensors) {
-            std::string k_tensor = it.first;
-            struct ggml_tensor* weight = model_tensors[it.first];
+            std::string model_tensor_name = it.first;
+            struct ggml_tensor* model_tensor = model_tensors[it.first];

-            std::vector<std::string> keys = to_lora_keys(k_tensor, version);
-            if (keys.size() == 0)
+            std::vector<std::string> keys = to_lora_keys(model_tensor_name, version);
+            bool is_bias = ends_with(model_tensor_name, ".bias");
+            if (keys.size() == 0) {
+                if (is_bias) {
+                    keys.push_back(model_tensor_name.substr(0, model_tensor_name.size() - 5));  // remove .bias
+                } else {
+                    continue;
+                }
+            }

             for (auto& key : keys) {
                 bool is_qkv_split = starts_with(key, "SPLIT|");
@@ -265,8 +275,22 @@ struct LoraModel : public GGMLRunner {
                 }
                 struct ggml_tensor* updown = NULL;
                 float scale_value = 1.0f;
-                std::string fk = lora_pre[type] + key;
-                if (lora_tensors.find(fk + ".hada_w1_a") != lora_tensors.end()) {
+                std::string full_key = lora_pre[type] + key;
+                if (is_bias) {
+                    if (lora_tensors.find(full_key + ".diff_b") != lora_tensors.end()) {
+                        std::string diff_name = full_key + ".diff_b";
+                        ggml_tensor* diff = lora_tensors[diff_name];
+                        updown = to_f32(compute_ctx, diff);
+                        applied_lora_tensors.insert(diff_name);
+                    } else {
+                        continue;
+                    }
+                } else if (lora_tensors.find(full_key + ".diff") != lora_tensors.end()) {
+                    std::string diff_name = full_key + ".diff";
+                    ggml_tensor* diff = lora_tensors[diff_name];
+                    updown = to_f32(compute_ctx, diff);
+                    applied_lora_tensors.insert(diff_name);
+                } else if (lora_tensors.find(full_key + ".hada_w1_a") != lora_tensors.end()) {
                     // LoHa mode

                     // TODO: split qkv convention for LoHas (is it ever used?)
@@ -292,9 +316,9 @@ struct LoraModel : public GGMLRunner {
                     std::string hada_2_down_name = "";
                     std::string hada_2_up_name = "";

-                    hada_1_down_name = fk + ".hada_w1_b";
-                    hada_1_up_name = fk + ".hada_w1_a";
-                    hada_1_mid_name = fk + ".hada_t1";
+                    hada_1_down_name = full_key + ".hada_w1_b";
+                    hada_1_up_name = full_key + ".hada_w1_a";
+                    hada_1_mid_name = full_key + ".hada_t1";
                     if (lora_tensors.find(hada_1_down_name) != lora_tensors.end()) {
                         hada_1_down = to_f32(compute_ctx, lora_tensors[hada_1_down_name]);
                     }
@@ -307,9 +331,9 @@ struct LoraModel : public GGMLRunner {
                         hada_1_up = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, hada_1_up));
                     }

-                    hada_2_down_name = fk + ".hada_w2_b";
-                    hada_2_up_name = fk + ".hada_w2_a";
-                    hada_2_mid_name = fk + ".hada_t2";
+                    hada_2_down_name = full_key + ".hada_w2_b";
+                    hada_2_up_name = full_key + ".hada_w2_a";
+                    hada_2_mid_name = full_key + ".hada_t2";
                     if (lora_tensors.find(hada_2_down_name) != lora_tensors.end()) {
                         hada_2_down = to_f32(compute_ctx, lora_tensors[hada_2_down_name]);
                     }
@@ -322,7 +346,7 @@ struct LoraModel : public GGMLRunner {
                         hada_2_up = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, hada_2_up));
                     }

-                    alpha_name = fk + ".alpha";
+                    alpha_name = full_key + ".alpha";

                     applied_lora_tensors.insert(hada_1_down_name);
                     applied_lora_tensors.insert(hada_1_up_name);
@@ -345,7 +369,7 @@ struct LoraModel : public GGMLRunner {
                         float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]);
                         scale_value = alpha / rank;
                     }
-                } else if (lora_tensors.find(fk + ".lokr_w1") != lora_tensors.end() || lora_tensors.find(fk + ".lokr_w1_a") != lora_tensors.end()) {
+                } else if (lora_tensors.find(full_key + ".lokr_w1") != lora_tensors.end() || lora_tensors.find(full_key + ".lokr_w1_a") != lora_tensors.end()) {
                     // LoKr mode

                     // TODO: split qkv convention for LoKrs (is it ever used?)
@@ -354,7 +378,7 @@ struct LoraModel : public GGMLRunner {
                         break;
                     }

-                    std::string alpha_name = fk + ".alpha";
+                    std::string alpha_name = full_key + ".alpha";

                     ggml_tensor* lokr_w1 = NULL;
                     ggml_tensor* lokr_w2 = NULL;
@@ -362,8 +386,8 @@ struct LoraModel : public GGMLRunner {
                     std::string lokr_w1_name = "";
                     std::string lokr_w2_name = "";

-                    lokr_w1_name = fk + ".lokr_w1";
-                    lokr_w2_name = fk + ".lokr_w2";
+                    lokr_w1_name = full_key + ".lokr_w1";
+                    lokr_w2_name = full_key + ".lokr_w2";

                     if (lora_tensors.find(lokr_w1_name) != lora_tensors.end()) {
                         lokr_w1 = to_f32(compute_ctx, lora_tensors[lokr_w1_name]);
@@ -435,29 +459,29 @@ struct LoraModel : public GGMLRunner {

                     if (is_qkv_split) {
                         std::string suffix = "";
-                        auto split_q_d_name = fk + "q" + suffix + lora_downs[type] + ".weight";
+                        auto split_q_d_name = full_key + "q" + suffix + lora_downs[type] + ".weight";

                         if (lora_tensors.find(split_q_d_name) == lora_tensors.end()) {
                             suffix = "_proj";
-                            split_q_d_name = fk + "q" + suffix + lora_downs[type] + ".weight";
+                            split_q_d_name = full_key + "q" + suffix + lora_downs[type] + ".weight";
                         }
                         if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) {
                             // print_ggml_tensor(it.second, true); //[3072, 21504, 1, 1]
                             // find qkv and mlp up parts in LoRA model
-                            auto split_k_d_name = fk + "k" + suffix + lora_downs[type] + ".weight";
-                            auto split_v_d_name = fk + "v" + suffix + lora_downs[type] + ".weight";
+                            auto split_k_d_name = full_key + "k" + suffix + lora_downs[type] + ".weight";
+                            auto split_v_d_name = full_key + "v" + suffix + lora_downs[type] + ".weight";

-                            auto split_q_u_name = fk + "q" + suffix + lora_ups[type] + ".weight";
-                            auto split_k_u_name = fk + "k" + suffix + lora_ups[type] + ".weight";
-                            auto split_v_u_name = fk + "v" + suffix + lora_ups[type] + ".weight";
+                            auto split_q_u_name = full_key + "q" + suffix + lora_ups[type] + ".weight";
+                            auto split_k_u_name = full_key + "k" + suffix + lora_ups[type] + ".weight";
+                            auto split_v_u_name = full_key + "v" + suffix + lora_ups[type] + ".weight";

-                            auto split_q_scale_name = fk + "q" + suffix + ".scale";
-                            auto split_k_scale_name = fk + "k" + suffix + ".scale";
-                            auto split_v_scale_name = fk + "v" + suffix + ".scale";
+                            auto split_q_scale_name = full_key + "q" + suffix + ".scale";
+                            auto split_k_scale_name = full_key + "k" + suffix + ".scale";
+                            auto split_v_scale_name = full_key + "v" + suffix + ".scale";

-                            auto split_q_alpha_name = fk + "q" + suffix + ".alpha";
-                            auto split_k_alpha_name = fk + "k" + suffix + ".alpha";
-                            auto split_v_alpha_name = fk + "v" + suffix + ".alpha";
+                            auto split_q_alpha_name = full_key + "q" + suffix + ".alpha";
+                            auto split_k_alpha_name = full_key + "k" + suffix + ".alpha";
+                            auto split_v_alpha_name = full_key + "v" + suffix + ".alpha";

                             ggml_tensor* lora_q_down = NULL;
                             ggml_tensor* lora_q_up = NULL;
@@ -571,29 +595,29 @@ struct LoraModel : public GGMLRunner {
                             applied_lora_tensors.insert(split_v_d_name);
                         }
                     } else if (is_qkvm_split) {
-                        auto split_q_d_name = fk + "attn.to_q" + lora_downs[type] + ".weight";
+                        auto split_q_d_name = full_key + "attn.to_q" + lora_downs[type] + ".weight";
                         if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) {
                             // print_ggml_tensor(it.second, true); //[3072, 21504, 1, 1]
                             // find qkv and mlp up parts in LoRA model
-                            auto split_k_d_name = fk + "attn.to_k" + lora_downs[type] + ".weight";
-                            auto split_v_d_name = fk + "attn.to_v" + lora_downs[type] + ".weight";
+                            auto split_k_d_name = full_key + "attn.to_k" + lora_downs[type] + ".weight";
+                            auto split_v_d_name = full_key + "attn.to_v" + lora_downs[type] + ".weight";

-                            auto split_q_u_name = fk + "attn.to_q" + lora_ups[type] + ".weight";
-                            auto split_k_u_name = fk + "attn.to_k" + lora_ups[type] + ".weight";
-                            auto split_v_u_name = fk + "attn.to_v" + lora_ups[type] + ".weight";
+                            auto split_q_u_name = full_key + "attn.to_q" + lora_ups[type] + ".weight";
+                            auto split_k_u_name = full_key + "attn.to_k" + lora_ups[type] + ".weight";
+                            auto split_v_u_name = full_key + "attn.to_v" + lora_ups[type] + ".weight";

-                            auto split_m_d_name = fk + "proj_mlp" + lora_downs[type] + ".weight";
-                            auto split_m_u_name = fk + "proj_mlp" + lora_ups[type] + ".weight";
+                            auto split_m_d_name = full_key + "proj_mlp" + lora_downs[type] + ".weight";
+                            auto split_m_u_name = full_key + "proj_mlp" + lora_ups[type] + ".weight";

-                            auto split_q_scale_name = fk + "attn.to_q" + ".scale";
-                            auto split_k_scale_name = fk + "attn.to_k" + ".scale";
-                            auto split_v_scale_name = fk + "attn.to_v" + ".scale";
-                            auto split_m_scale_name = fk + "proj_mlp" + ".scale";
+                            auto split_q_scale_name = full_key + "attn.to_q" + ".scale";
+                            auto split_k_scale_name = full_key + "attn.to_k" + ".scale";
+                            auto split_v_scale_name = full_key + "attn.to_v" + ".scale";
+                            auto split_m_scale_name = full_key + "proj_mlp" + ".scale";

-                            auto split_q_alpha_name = fk + "attn.to_q" + ".alpha";
-                            auto split_k_alpha_name = fk + "attn.to_k" + ".alpha";
-                            auto split_v_alpha_name = fk + "attn.to_v" + ".alpha";
-                            auto split_m_alpha_name = fk + "proj_mlp" + ".alpha";
+                            auto split_q_alpha_name = full_key + "attn.to_q" + ".alpha";
+                            auto split_k_alpha_name = full_key + "attn.to_k" + ".alpha";
+                            auto split_v_alpha_name = full_key + "attn.to_v" + ".alpha";
+                            auto split_m_alpha_name = full_key + "proj_mlp" + ".alpha";

                             ggml_tensor* lora_q_down = NULL;
                             ggml_tensor* lora_q_up = NULL;
@@ -748,30 +772,27 @@ struct LoraModel : public GGMLRunner {
                             applied_lora_tensors.insert(split_m_d_name);
                         }
                     } else {
-                        lora_up_name = fk + lora_ups[type] + ".weight";
-                        lora_down_name = fk + lora_downs[type] + ".weight";
-                        lora_mid_name = fk + ".lora_mid.weight";
+                        lora_up_name = full_key + lora_ups[type] + ".weight";
+                        lora_down_name = full_key + lora_downs[type] + ".weight";
+                        lora_mid_name = full_key + ".lora_mid.weight";

-                        alpha_name = fk + ".alpha";
-                        scale_name = fk + ".scale";
+                        alpha_name = full_key + ".alpha";
+                        scale_name = full_key + ".scale";

                         if (lora_tensors.find(lora_up_name) != lora_tensors.end()) {
                             lora_up = to_f32(compute_ctx, lora_tensors[lora_up_name]);
+                            applied_lora_tensors.insert(lora_up_name);
                         }

                         if (lora_tensors.find(lora_down_name) != lora_tensors.end()) {
                             lora_down = to_f32(compute_ctx, lora_tensors[lora_down_name]);
+                            applied_lora_tensors.insert(lora_down_name);
                         }

                         if (lora_tensors.find(lora_mid_name) != lora_tensors.end()) {
                             lora_mid = to_f32(compute_ctx, lora_tensors[lora_mid_name]);
                             applied_lora_tensors.insert(lora_mid_name);
                         }

-                        applied_lora_tensors.insert(lora_up_name);
-                        applied_lora_tensors.insert(lora_down_name);
-                        applied_lora_tensors.insert(alpha_name);
-                        applied_lora_tensors.insert(scale_name);
                     }

                     if (lora_up == NULL || lora_down == NULL) {
@@ -782,29 +803,37 @@ struct LoraModel : public GGMLRunner {
                 int64_t rank = lora_down->ne[ggml_n_dims(lora_down) - 1];
                 if (lora_tensors.find(scale_name) != lora_tensors.end()) {
                     scale_value = ggml_backend_tensor_get_f32(lora_tensors[scale_name]);
+                    applied_lora_tensors.insert(scale_name);
                 } else if (lora_tensors.find(alpha_name) != lora_tensors.end()) {
                     float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]);
                     scale_value = alpha / rank;
+                    // LOG_DEBUG("rank %s %ld %.2f %.2f", alpha_name.c_str(), rank, alpha, scale_value);
+                    applied_lora_tensors.insert(alpha_name);
                 }

                 updown = ggml_merge_lora(compute_ctx, lora_down, lora_up, lora_mid);
                 }
                 scale_value *= multiplier;
-                updown = ggml_reshape(compute_ctx, updown, weight);
-                GGML_ASSERT(ggml_nelements(updown) == ggml_nelements(weight));
-                updown = ggml_scale_inplace(compute_ctx, updown, scale_value);
-                ggml_tensor* final_weight;
-                if (weight->type != GGML_TYPE_F32 && weight->type != GGML_TYPE_F16) {
-                    // final_weight = ggml_new_tensor(compute_ctx, GGML_TYPE_F32, ggml_n_dims(weight), weight->ne);
-                    // final_weight = ggml_cpy(compute_ctx, weight, final_weight);
-                    final_weight = to_f32(compute_ctx, weight);
-                    final_weight = ggml_add_inplace(compute_ctx, final_weight, updown);
-                    final_weight = ggml_cpy(compute_ctx, final_weight, weight);
-                } else {
-                    final_weight = ggml_add_inplace(compute_ctx, weight, updown);
+                ggml_tensor* original_tensor = model_tensor;
+                if (!ggml_backend_is_cpu(runtime_backend) && ggml_backend_buffer_is_host(original_tensor->buffer)) {
+                    model_tensor = ggml_dup_tensor(compute_ctx, model_tensor);
+                    set_backend_tensor_data(model_tensor, original_tensor->data);
+                }
+                updown = ggml_reshape(compute_ctx, updown, model_tensor);
+                GGML_ASSERT(ggml_nelements(updown) == ggml_nelements(model_tensor));
+                updown = ggml_scale_inplace(compute_ctx, updown, scale_value);
+                ggml_tensor* final_tensor;
+                if (model_tensor->type != GGML_TYPE_F32 && model_tensor->type != GGML_TYPE_F16) {
+                    final_tensor = to_f32(compute_ctx, model_tensor);
+                    final_tensor = ggml_add_inplace(compute_ctx, final_tensor, updown);
+                    final_tensor = ggml_cpy(compute_ctx, final_tensor, model_tensor);
+                } else {
+                    final_tensor = ggml_add_inplace(compute_ctx, model_tensor, updown);
                 }
+                ggml_build_forward_expand(gf, final_tensor);
+                if (!ggml_backend_is_cpu(runtime_backend) && ggml_backend_buffer_is_host(original_tensor->buffer)) {
+                    original_tensor_to_final_tensor[original_tensor] = final_tensor;
+                }
-                // final_weight = ggml_add_inplace(compute_ctx, weight, updown); // apply directly
-                ggml_build_forward_expand(gf, final_weight);
                 break;
             }
         }
@@ -825,10 +854,10 @@ struct LoraModel : public GGMLRunner {
         * this function is called once to calculate the required buffer size
         * and then again to actually generate a graph to be used */
        if (applied_lora_tensors_count != total_lora_tensors_count) {
-            LOG_WARN("Only (%lu / %lu) LoRA tensors have been applied",
+            LOG_WARN("Only (%lu / %lu) LoRA tensors will be applied",
                     applied_lora_tensors_count, total_lora_tensors_count);
        } else {
-            LOG_DEBUG("(%lu / %lu) LoRA tensors applied successfully",
+            LOG_DEBUG("(%lu / %lu) LoRA tensors will be applied",
                     applied_lora_tensors_count, total_lora_tensors_count);
        }

@@ -839,7 +868,15 @@ struct LoraModel : public GGMLRunner {
        auto get_graph = [&]() -> struct ggml_cgraph* {
            return build_lora_graph(model_tensors, version);
        };
-        GGMLRunner::compute(get_graph, n_threads, true);
+        GGMLRunner::compute(get_graph, n_threads, false);
+        for (auto item : original_tensor_to_final_tensor) {
+            ggml_tensor* original_tensor = item.first;
+            ggml_tensor* final_tensor = item.second;
+
+            ggml_backend_tensor_copy(final_tensor, original_tensor);
+        }
+        original_tensor_to_final_tensor.clear();
        GGMLRunner::free_compute_buffer();
    }
};
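The merge these graphs implement is the usual low-rank update W' = W + multiplier * scale * (up · down), where scale comes from an explicit `.scale` tensor if present and otherwise from `alpha / rank`. As a sketch of that precedence (a hypothetical helper, not a function in the diff):

// Effective scale used when merging a LoRA delta into a weight.
float effective_lora_scale(bool has_scale, float scale, float alpha, int64_t rank, float multiplier) {
    float s = has_scale ? scale : alpha / (float)rank;  // .scale wins over alpha/rank
    return s * multiplier;                              // user-facing LoRA strength
}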
ltxv.hpp (new file, 74 lines)
@@ -0,0 +1,74 @@
#ifndef __LTXV_HPP__
#define __LTXV_HPP__

#include "common.hpp"
#include "ggml_extend.hpp"

namespace LTXV {

class CausalConv3d : public GGMLBlock {
protected:
    int time_kernel_size;

public:
    CausalConv3d(int64_t in_channels,
                 int64_t out_channels,
                 int kernel_size = 3,
                 std::tuple<int, int, int> stride = {1, 1, 1},
                 int dilation = 1,
                 bool bias = true) {
        time_kernel_size = kernel_size / 2;
        blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv3d(in_channels,
                                                               out_channels,
                                                               {kernel_size, kernel_size, kernel_size},
                                                               stride,
                                                               {0, kernel_size / 2, kernel_size / 2},
                                                               {dilation, 1, 1},
                                                               bias));
    }

    struct ggml_tensor* forward(struct ggml_context* ctx,
                                struct ggml_tensor* x,
                                bool causal = true) {
        // x: [N*IC, ID, IH, IW]
        // result: [N*OC, OD, OH, OW]
        auto conv = std::dynamic_pointer_cast<Conv3d>(blocks["conv"]);
        if (causal) {
            auto h = ggml_cont(ctx, ggml_permute(ctx, x, 0, 1, 3, 2));  // [ID, N*IC, IH, IW]
            auto first_frame = ggml_view_3d(ctx, h, h->ne[0], h->ne[1], h->ne[2], h->nb[1], h->nb[2], 0);  // [N*IC, IH, IW]
            first_frame = ggml_reshape_4d(ctx, first_frame, first_frame->ne[0], first_frame->ne[1], 1, first_frame->ne[2]);  // [N*IC, 1, IH, IW]
            auto first_frame_pad = first_frame;
            for (int i = 1; i < time_kernel_size - 1; i++) {
                first_frame_pad = ggml_concat(ctx, first_frame_pad, first_frame, 2);
            }
            x = ggml_concat(ctx, first_frame_pad, x, 2);
        } else {
            auto h = ggml_cont(ctx, ggml_permute(ctx, x, 0, 1, 3, 2));  // [ID, N*IC, IH, IW]
            int64_t offset = h->nb[2] * h->ne[2];

            auto first_frame = ggml_view_3d(ctx, h, h->ne[0], h->ne[1], h->ne[2], h->nb[1], h->nb[2], 0);  // [N*IC, IH, IW]
            first_frame = ggml_reshape_4d(ctx, first_frame, first_frame->ne[0], first_frame->ne[1], 1, first_frame->ne[2]);  // [N*IC, 1, IH, IW]
            auto first_frame_pad = first_frame;
            for (int i = 1; i < (time_kernel_size - 1) / 2; i++) {
                first_frame_pad = ggml_concat(ctx, first_frame_pad, first_frame, 2);
            }

            auto last_frame = ggml_view_3d(ctx, h, h->ne[0], h->ne[1], h->ne[2], h->nb[1], h->nb[2], offset * (h->ne[3] - 1));  // [N*IC, IH, IW]
            last_frame = ggml_reshape_4d(ctx, last_frame, last_frame->ne[0], last_frame->ne[1], 1, last_frame->ne[2]);  // [N*IC, 1, IH, IW]
            auto last_frame_pad = last_frame;
            for (int i = 1; i < (time_kernel_size - 1) / 2; i++) {
                last_frame_pad = ggml_concat(ctx, last_frame_pad, last_frame, 2);
            }

            x = ggml_concat(ctx, first_frame_pad, x, 2);
            x = ggml_concat(ctx, x, last_frame_pad, 2);
        }

        x = conv->forward(ctx, x);
        return x;
    }
};

};  // namespace LTXV

#endif
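The causal branch above repeats the first frame in front of the sequence so the temporal kernel never reads future frames, while the non-causal branch pads both ends with edge frames. A sketch of the underlying arithmetic (standard stride-1 convolution length, stated as an assumption rather than taken from the diff):

// Output depth of a stride-1 temporal conv with kernel k, `front` frames
// prepended and `back` frames appended (hypothetical helper for intuition).
int conv_output_depth(int input_depth, int k, int front, int back) {
    return input_depth + front + back - (k - 1);  // equals input_depth when front + back == k - 1
}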
mmdit.hpp (31 changed lines)
@@ -142,30 +142,6 @@ public:
     }
 };

-class RMSNorm : public UnaryBlock {
-protected:
-    int64_t hidden_size;
-    float eps;
-
-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
-        enum ggml_type wtype = GGML_TYPE_F32;
-        params["weight"] = ggml_new_tensor_1d(ctx, wtype, hidden_size);
-    }
-
-public:
-    RMSNorm(int64_t hidden_size,
-            float eps = 1e-06f)
-        : hidden_size(hidden_size),
-          eps(eps) {}
-
-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
-        struct ggml_tensor* w = params["weight"];
-        x = ggml_rms_norm(ctx, x, eps);
-        x = ggml_mul(ctx, x, w);
-        return x;
-    }
-};
-
 class SelfAttention : public GGMLBlock {
 public:
     int64_t num_heads;
@@ -870,9 +846,10 @@ struct MMDiTRunner : public GGMLRunner {
     MMDiT mmdit;

     MMDiTRunner(ggml_backend_t backend,
+                bool offload_params_to_cpu,
                 const String2GGMLType& tensor_types = {},
                 const std::string prefix = "")
-        : GGMLRunner(backend), mmdit(tensor_types) {
+        : GGMLRunner(backend, offload_params_to_cpu), mmdit(tensor_types) {
         mmdit.init(params_ctx, tensor_types, prefix);
     }
@@ -970,7 +947,7 @@ struct MMDiTRunner : public GGMLRunner {
     // ggml_backend_t backend = ggml_backend_cuda_init(0);
     ggml_backend_t backend = ggml_backend_cpu_init();
     ggml_type model_data_type = GGML_TYPE_F16;
-    std::shared_ptr<MMDiTRunner> mmdit = std::shared_ptr<MMDiTRunner>(new MMDiTRunner(backend));
+    std::shared_ptr<MMDiTRunner> mmdit = std::shared_ptr<MMDiTRunner>(new MMDiTRunner(backend, false));
     {
         LOG_INFO("loading from '%s'", file_path.c_str());

@@ -984,7 +961,7 @@ struct MMDiTRunner : public GGMLRunner {
             return;
         }

-        bool success = model_loader.load_tensors(tensors, backend);
+        bool success = model_loader.load_tensors(tensors);

         if (!success) {
             LOG_ERROR("load tensors from model loader failed");
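For reference, the `RMSNorm` block removed above computed the standard root-mean-square normalization with a learned per-channel weight, matching its `ggml_rms_norm` followed by `ggml_mul`:

\mathrm{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} \odot w, \qquad \epsilon = 10^{-6}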
model.cpp (135 changed lines)
@@ -6,10 +6,12 @@
 #include <unordered_map>
 #include <vector>

+#include "gguf_reader.hpp"
 #include "model.h"
 #include "stable-diffusion.h"
 #include "util.h"
 #include "vocab.hpp"
+#include "vocab_umt5.hpp"

 #include "ggml-alloc.h"
 #include "ggml-backend.h"
@@ -88,6 +90,7 @@ const char* unused_tensors[] = {
     "posterior_mean_coef1",
     "posterior_mean_coef2",
     "cond_stage_model.transformer.text_model.embeddings.position_ids",
+    "cond_stage_model.transformer.vision_model.embeddings.position_ids",
     "cond_stage_model.model.logit_scale",
     "cond_stage_model.model.text_projection",
     "conditioner.embedders.0.transformer.text_model.embeddings.position_ids",
@@ -141,6 +144,11 @@ std::unordered_map<std::string, std::string> open_clip_to_hk_clip_resblock = {
     {"mlp.c_proj.weight", "mlp.fc2.weight"},
 };

+std::unordered_map<std::string, std::string> cond_model_name_map = {
+    {"transformer.vision_model.pre_layrnorm.weight", "transformer.vision_model.pre_layernorm.weight"},
+    {"transformer.vision_model.pre_layrnorm.bias", "transformer.vision_model.pre_layernorm.bias"},
+};
+
 std::unordered_map<std::string, std::string> vae_decoder_name_map = {
     {"first_stage_model.decoder.mid.attn_1.to_k.bias", "first_stage_model.decoder.mid.attn_1.k.bias"},
     {"first_stage_model.decoder.mid.attn_1.to_k.weight", "first_stage_model.decoder.mid.attn_1.k.weight"},
@@ -179,7 +187,7 @@ std::unordered_map<std::string, std::string> pmid_v2_name_map = {
      "pmid.qformer_perceiver.token_proj.fc2.weight"},
 };

-std::string convert_open_clip_to_hf_clip(const std::string& name) {
+std::string convert_cond_model_name(const std::string& name) {
     std::string new_name = name;
     std::string prefix;
     if (contains(new_name, ".enc.")) {
@@ -268,6 +276,10 @@ std::string convert_cond_model_name(const std::string& name) {
         new_name = open_clip_to_hf_clip_model[new_name];
     }

+    if (cond_model_name_map.find(new_name) != cond_model_name_map.end()) {
+        new_name = cond_model_name_map[new_name];
+    }
+
     std::string open_clip_resblock_prefix = "model.transformer.resblocks.";
     std::string hf_clip_resblock_prefix = "transformer.text_model.encoder.layers.";

@@ -563,7 +575,7 @@ std::string convert_tensor_name(std::string name) {
     // }
     std::string new_name = name;
     if (starts_with(name, "cond_stage_model.") || starts_with(name, "conditioner.embedders.") || starts_with(name, "text_encoders.") || ends_with(name, ".vision_model.visual_projection.weight")) {
-        new_name = convert_open_clip_to_hf_clip(name);
+        new_name = convert_cond_model_name(name);
     } else if (starts_with(name, "first_stage_model.decoder")) {
         new_name = convert_vae_decoder_name(name);
     } else if (starts_with(name, "pmid.qformer_perceiver")) {
@@ -592,9 +604,11 @@ std::string convert_tensor_name(std::string name) {
         } else {
             new_name = name;
         }
+    } else if (ends_with(name, ".diff") || ends_with(name, ".diff_b")) {
+        new_name = "lora." + name;
     } else if (contains(name, "lora_up") || contains(name, "lora_down") ||
                contains(name, "lora.up") || contains(name, "lora.down") ||
-               contains(name, "lora_linear")) {
+               contains(name, "lora_linear") || ends_with(name, ".alpha")) {
         size_t pos = new_name.find(".processor");
         if (pos != std::string::npos) {
             new_name.replace(pos, strlen(".processor"), "");
@@ -602,7 +616,11 @@ std::string convert_tensor_name(std::string name) {
         // if (starts_with(new_name, "transformer.transformer_blocks") || starts_with(new_name, "transformer.single_transformer_blocks")) {
         //     new_name = "model.diffusion_model." + new_name;
         // }
+        if (ends_with(name, ".alpha")) {
+            pos = new_name.rfind("alpha");
+        } else {
             pos = new_name.rfind("lora");
+        }
         if (pos != std::string::npos) {
             std::string name_without_network_parts = new_name.substr(0, pos - 1);
             std::string network_part = new_name.substr(pos);
@@ -684,6 +702,13 @@ void preprocess_tensor(TensorStorage tensor_storage,
         tensor_storage.unsqueeze();
     }

+    // wan vae
+    if (ends_with(new_name, "gamma")) {
+        tensor_storage.reverse_ne();
+        tensor_storage.n_dims = 1;
+        tensor_storage.reverse_ne();
+    }
+
     tensor_storage.name = new_name;

     if (new_name.find("cond_stage_model") != std::string::npos &&
@@ -1030,12 +1055,40 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s
     gguf_context* ctx_gguf_ = NULL;
     ggml_context* ctx_meta_ = NULL;

     ctx_gguf_ = gguf_init_from_file(file_path.c_str(), {true, &ctx_meta_});
     if (!ctx_gguf_) {
-        LOG_ERROR("failed to open '%s'", file_path.c_str());
+        LOG_ERROR("failed to open '%s' with gguf_init_from_file. Try to open it with GGUFReader.", file_path.c_str());
+        GGUFReader gguf_reader;
+        if (!gguf_reader.load(file_path)) {
+            LOG_ERROR("failed to open '%s' with GGUFReader.", file_path.c_str());
             return false;
         }
+
+        size_t data_offset = gguf_reader.data_offset();
+        for (const auto& gguf_tensor_info : gguf_reader.tensors()) {
+            std::string name = gguf_tensor_info.name;
+            if (!starts_with(name, prefix)) {
+                name = prefix + name;
+            }
+
+            TensorStorage tensor_storage(
+                name,
+                gguf_tensor_info.type,
+                gguf_tensor_info.shape.data(),
+                gguf_tensor_info.shape.size(),
+                file_index,
+                data_offset + gguf_tensor_info.offset);
+
+            // LOG_DEBUG("%s %s", name.c_str(), tensor_storage.to_string().c_str());
+
+            tensor_storages.push_back(tensor_storage);
+            add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
+        }
+
+        return true;
+    }

     int n_tensors = gguf_get_n_tensors(ctx_gguf_);

     size_t total_size = 0;
@@ -1047,7 +1100,11 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s

     // LOG_DEBUG("%s", name.c_str());

-    TensorStorage tensor_storage(prefix + name, dummy->type, dummy->ne, ggml_n_dims(dummy), file_index, offset);
+    if (!starts_with(name, prefix)) {
+        name = prefix + name;
+    }
+
+    TensorStorage tensor_storage(name, dummy->type, dummy->ne, ggml_n_dims(dummy), file_index, offset);

     GGML_ASSERT(ggml_nbytes(dummy) == tensor_storage.nbytes());

@@ -1085,7 +1142,7 @@ ggml_type str_to_ggml_type(const std::string& dtype) {

 // https://huggingface.co/docs/safetensors/index
 bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const std::string& prefix) {
-    LOG_DEBUG("init from '%s'", file_path.c_str());
+    LOG_DEBUG("init from '%s', prefix = '%s'", file_path.c_str(), prefix.c_str());
     file_paths_.push_back(file_path);
     size_t file_index = file_paths_.size() - 1;
     std::ifstream file(file_path, std::ios::binary);
@@ -1150,6 +1207,10 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
         std::string dtype = tensor_info["dtype"];
         nlohmann::json shape = tensor_info["shape"];

+        if (dtype == "U8") {
+            continue;
+        }
+
         size_t begin = tensor_info["data_offsets"][0].get<size_t>();
         size_t end = tensor_info["data_offsets"][1].get<size_t>();

@@ -1171,12 +1232,11 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
         }

         if (n_dims == 5) {
-            if (ne[3] == 1 && ne[4] == 1) {
                 n_dims = 4;
-            } else {
-                LOG_ERROR("invalid tensor '%s'", name.c_str());
-                return false;
-            }
+            ne[0] = ne[0] * ne[1];
+            ne[1] = ne[2];
+            ne[2] = ne[3];
+            ne[3] = ne[4];
         }

         // ggml_n_dims returns 1 for scalars
@@ -1184,7 +1244,11 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
             n_dims = 1;
         }

-        TensorStorage tensor_storage(prefix + name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
+        if (!starts_with(name, prefix)) {
+            name = prefix + name;
+        }
+
+        TensorStorage tensor_storage(name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
         tensor_storage.reverse_ne();

         size_t tensor_data_size = end - begin;
@@ -1569,7 +1633,11 @@ bool ModelLoader::parse_data_pkl(uint8_t* buffer,
         reader.tensor_storage.file_index = file_index;
         // if(strcmp(prefix.c_str(), "scarlett") == 0)
         //     printf(" ZIP got tensor %s \n ", reader.tensor_storage.name.c_str());
-        reader.tensor_storage.name = prefix + reader.tensor_storage.name;
+        std::string name = reader.tensor_storage.name;
+        if (!starts_with(name, prefix)) {
+            name = prefix + name;
+        }
+        reader.tensor_storage.name = name;
         tensor_storages.push_back(reader.tensor_storage);
         add_preprocess_tensor_storage_types(tensor_storages_types, reader.tensor_storage.name, reader.tensor_storage.type);

@@ -1643,10 +1711,12 @@ SDVersion ModelLoader::get_sd_version() {

     bool is_xl = false;
     bool is_flux = false;
+    bool is_wan = false;
+    int64_t patch_embedding_channels = 0;
+    bool has_img_emb = false;

-#define found_family (is_xl || is_flux)
     for (auto& tensor_storage : tensor_storages) {
-        if (!found_family) {
+        if (!(is_xl || is_flux)) {
             if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
                 is_flux = true;
                 if (input_block_checked) {
@@ -1656,6 +1726,15 @@ SDVersion ModelLoader::get_sd_version() {
             if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
                 return VERSION_SD3;
             }
+            if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) {
+                is_wan = true;
+            }
+            if (tensor_storage.name.find("model.diffusion_model.patch_embedding.weight") != std::string::npos) {
+                patch_embedding_channels = tensor_storage.ne[3];
+            }
+            if (tensor_storage.name.find("model.diffusion_model.img_emb") != std::string::npos) {
+                has_img_emb = true;
+            }
             if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || tensor_storage.name.find("unet.down_blocks.") != std::string::npos) {
                 is_unet = true;
                 if (has_multiple_encoders) {
@@ -1690,11 +1769,21 @@ SDVersion ModelLoader::get_sd_version() {
         if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight" || tensor_storage.name == "unet.conv_in.weight") {
             input_block_weight = tensor_storage;
             input_block_checked = true;
-            if (found_family) {
+            if (is_xl || is_flux) {
                 break;
             }
         }
     }
+    if (is_wan) {
+        LOG_DEBUG("patch_embedding_channels %d", patch_embedding_channels);
+        if (patch_embedding_channels == 184320 && !has_img_emb) {
+            return VERSION_WAN2_2_I2V;
+        }
+        if (patch_embedding_channels == 147456 && !has_img_emb) {
+            return VERSION_WAN2_2_TI2V;
+        }
+        return VERSION_WAN2;
+    }
     bool is_inpaint = input_block_weight.ne[2] == 9;
     bool is_ip2p = input_block_weight.ne[2] == 8;
     if (is_xl) {
@@ -1850,6 +1939,11 @@ std::string ModelLoader::load_t5_tokenizer_json() {
     return json_str;
 }

+std::string ModelLoader::load_umt5_tokenizer_json() {
+    std::string json_str(reinterpret_cast<const char*>(umt5_tokenizer_json_str), sizeof(umt5_tokenizer_json_str));
+    return json_str;
+}
+
 std::vector<TensorStorage> remove_duplicates(const std::vector<TensorStorage>& vec) {
     std::vector<TensorStorage> res;
     std::unordered_map<std::string, size_t> name_to_index_map;
@@ -1871,7 +1965,7 @@ std::vector<TensorStorage> remove_duplicates(const std::vector<TensorStorage>& v
     return res;
 }

-bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend_t backend) {
+bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {
     std::vector<TensorStorage> processed_tensor_storages;
     for (auto& tensor_storage : tensor_storages) {
         // LOG_DEBUG("%s", name.c_str());
@@ -2080,7 +2174,6 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb)
 }

 bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
-                               ggml_backend_t backend,
                                std::set<std::string> ignore_tensors) {
     std::set<std::string> tensor_names_in_file;
     auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
@@ -2120,7 +2213,7 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tenso
         return true;
     };

-    bool success = load_tensors(on_new_tensor_cb, backend);
+    bool success = load_tensors(on_new_tensor_cb);
     if (!success) {
         LOG_ERROR("load tensors from file failed");
         return false;
@@ -2151,7 +2244,7 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tenso

 std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std::string& tensor_type_rules) {
     std::vector<std::pair<std::string, ggml_type>> result;
-    for (const auto& item : splitString(tensor_type_rules, ',')) {
+    for (const auto& item : split_string(tensor_type_rules, ',')) {
         if (item.size() == 0)
             continue;
         std::string::size_type pos = item.find('=');
@@ -2264,7 +2357,7 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type
         return true;
     };

-    bool success = load_tensors(on_new_tensor_cb, backend);
+    bool success = load_tensors(on_new_tensor_cb);
     ggml_backend_free(backend);
     LOG_INFO("load tensors done");
     LOG_INFO("trying to save tensors to %s", file_path.c_str());
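The same `starts_with` guard now appears in every loader (gguf, safetensors, pkl). As a design note, it could be factored into a tiny helper; this sketch uses a hypothetical name, since the diff inlines the check at each call site instead:

// Hypothetical helper: prepend the prefix only when it is not already there,
// so re-prefixed tensor names stay stable across loaders.
static std::string ensure_prefixed(const std::string& name, const std::string& prefix) {
    return starts_with(name, prefix) ? name : prefix + name;
}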
model.h (46 changed lines)
@@ -31,23 +31,12 @@ enum SDVersion {
     VERSION_SD3,
     VERSION_FLUX,
     VERSION_FLUX_FILL,
+    VERSION_WAN2,
+    VERSION_WAN2_2_I2V,
+    VERSION_WAN2_2_TI2V,
     VERSION_COUNT,
 };

-static inline bool sd_version_is_flux(SDVersion version) {
-    if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) {
-        return true;
-    }
-    return false;
-}
-
-static inline bool sd_version_is_sd3(SDVersion version) {
-    if (version == VERSION_SD3) {
-        return true;
-    }
-    return false;
-}
-
 static inline bool sd_version_is_sd1(SDVersion version) {
     if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX) {
         return true;
@@ -69,6 +58,27 @@ static inline bool sd_version_is_sdxl(SDVersion version) {
     return false;
 }

+static inline bool sd_version_is_sd3(SDVersion version) {
+    if (version == VERSION_SD3) {
+        return true;
+    }
+    return false;
+}
+
+static inline bool sd_version_is_flux(SDVersion version) {
+    if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) {
+        return true;
+    }
+    return false;
+}
+
+static inline bool sd_version_is_wan(SDVersion version) {
+    if (version == VERSION_WAN2 || version == VERSION_WAN2_2_I2V || version == VERSION_WAN2_2_TI2V) {
+        return true;
+    }
+    return false;
+}
+
 static inline bool sd_version_is_inpaint(SDVersion version) {
     if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL) {
         return true;
@@ -77,7 +87,7 @@ static inline bool sd_version_is_inpaint(SDVersion version) {
 }

 static inline bool sd_version_is_dit(SDVersion version) {
-    if (sd_version_is_flux(version) || sd_version_is_sd3(version)) {
+    if (sd_version_is_flux(version) || sd_version_is_sd3(version) || sd_version_is_wan(version)) {
         return true;
     }
     return false;
@@ -113,7 +123,7 @@ struct TensorStorage {

     TensorStorage() = default;

-    TensorStorage(const std::string& name, ggml_type type, int64_t* ne, int n_dims, size_t file_index, size_t offset = 0)
+    TensorStorage(const std::string& name, ggml_type type, const int64_t* ne, int n_dims, size_t file_index, size_t offset = 0)
         : name(name), type(type), n_dims(n_dims), file_index(file_index), offset(offset) {
         for (int i = 0; i < n_dims; i++) {
             this->ne[i] = ne[i];
@@ -237,9 +247,8 @@ public:
     ggml_type get_diffusion_model_wtype();
     ggml_type get_vae_wtype();
     void set_wtype_override(ggml_type wtype, std::string prefix = "");
-    bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend_t backend);
+    bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb);
     bool load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
-                      ggml_backend_t backend,
                       std::set<std::string> ignore_tensors = {});

     bool save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules);
@@ -249,6 +258,7 @@ public:

     static std::string load_merges();
     static std::string load_t5_tokenizer_json();
+    static std::string load_umt5_tokenizer_json();
 };

 #endif  // __MODEL_H__
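Call-site sketch for the new family predicate; `version` comes from `ModelLoader::get_sd_version()` as shown in model.cpp above:

SDVersion version = model_loader.get_sd_version();
if (sd_version_is_wan(version)) {
    // Wan2.x takes the DiT path: sd_version_is_dit(version) is also true here.
}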
pmid.hpp (10 changed lines)
@@ -624,12 +624,13 @@ public:

 public:
     PhotoMakerIDEncoder(ggml_backend_t backend,
+                        bool offload_params_to_cpu,
                         const String2GGMLType& tensor_types,
                         const std::string prefix,
                         SDVersion version = VERSION_SDXL,
                         PMVersion pm_v = PM_VERSION_1,
                         float sty = 20.f)
-        : GGMLRunner(backend),
+        : GGMLRunner(backend, offload_params_to_cpu),
           version(version),
           pm_version(pm_v),
           style_strength(sty) {
@@ -785,10 +786,11 @@ struct PhotoMakerIDEmbed : public GGMLRunner {
     bool applied = false;

     PhotoMakerIDEmbed(ggml_backend_t backend,
+                      bool offload_params_to_cpu,
                       ModelLoader* ml,
                       const std::string& file_path = "",
                       const std::string& prefix = "")
-        : file_path(file_path), GGMLRunner(backend), model_loader(ml) {
+        : file_path(file_path), GGMLRunner(backend, offload_params_to_cpu), model_loader(ml) {
         if (!model_loader->init_from_file(file_path, prefix)) {
             load_failed = true;
         }
@@ -828,11 +830,11 @@ struct PhotoMakerIDEmbed : public GGMLRunner {
         return true;
     };

-    model_loader->load_tensors(on_new_tensor_cb, backend);
+    model_loader->load_tensors(on_new_tensor_cb);
     alloc_params_buffer();

     dry_run = false;
-    model_loader->load_tensors(on_new_tensor_cb, backend);
+    model_loader->load_tensors(on_new_tensor_cb);

     LOG_DEBUG("finished loading PhotoMaker ID Embeds ");
     return true;
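Every `GGMLRunner` subclass now takes an explicit offload flag. A call-site sketch of the new constructor invocation (the `backend`, `tensor_types`, and prefix values are illustrative):

PhotoMakerIDEncoder id_encoder(backend,
                               /*offload_params_to_cpu=*/false,
                               tensor_types,
                               "pmid",  // prefix, illustrative
                               VERSION_SDXL);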
rope.hpp (new file, 252 lines)
@@ -0,0 +1,252 @@
#ifndef __ROPE_HPP__
#define __ROPE_HPP__

#include <vector>

#include "ggml_extend.hpp"

struct Rope {
    template <class T>
    static std::vector<T> linspace(T start, T end, int num) {
        std::vector<T> result(num);
        if (num == 1) {
            result[0] = start;
            return result;
        }
        T step = (end - start) / (num - 1);
        for (int i = 0; i < num; ++i) {
            result[i] = start + i * step;
        }
        return result;
    }

    static std::vector<std::vector<float>> transpose(const std::vector<std::vector<float>>& mat) {
        int rows = mat.size();
        int cols = mat[0].size();
        std::vector<std::vector<float>> transposed(cols, std::vector<float>(rows));
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < cols; ++j) {
                transposed[j][i] = mat[i][j];
            }
        }
        return transposed;
    }

    static std::vector<float> flatten(const std::vector<std::vector<float>>& vec) {
        std::vector<float> flat_vec;
        for (const auto& sub_vec : vec) {
            flat_vec.insert(flat_vec.end(), sub_vec.begin(), sub_vec.end());
        }
        return flat_vec;
    }

    static std::vector<std::vector<float>> rope(const std::vector<float>& pos, int dim, int theta) {
        assert(dim % 2 == 0);
        int half_dim = dim / 2;

        std::vector<float> scale = linspace(0.f, (dim * 1.f - 2) / dim, half_dim);

        std::vector<float> omega(half_dim);
        for (int i = 0; i < half_dim; ++i) {
            omega[i] = 1.0 / std::pow(theta, scale[i]);
        }

        int pos_size = pos.size();
        std::vector<std::vector<float>> out(pos_size, std::vector<float>(half_dim));
        for (int i = 0; i < pos_size; ++i) {
            for (int j = 0; j < half_dim; ++j) {
                out[i][j] = pos[i] * omega[j];
            }
        }

        std::vector<std::vector<float>> result(pos_size, std::vector<float>(half_dim * 4));
        for (int i = 0; i < pos_size; ++i) {
            for (int j = 0; j < half_dim; ++j) {
                result[i][4 * j] = std::cos(out[i][j]);
                result[i][4 * j + 1] = -std::sin(out[i][j]);
                result[i][4 * j + 2] = std::sin(out[i][j]);
                result[i][4 * j + 3] = std::cos(out[i][j]);
            }
        }

        return result;
    }

    // Generate IDs for image patches and text
    static std::vector<std::vector<float>> gen_txt_ids(int bs, int context_len) {
        return std::vector<std::vector<float>>(bs * context_len, std::vector<float>(3, 0.0));
    }

    static std::vector<std::vector<float>> gen_img_ids(int h, int w, int patch_size, int bs, int index = 0, int h_offset = 0, int w_offset = 0) {
        int h_len = (h + (patch_size / 2)) / patch_size;
        int w_len = (w + (patch_size / 2)) / patch_size;

        std::vector<std::vector<float>> img_ids(h_len * w_len, std::vector<float>(3, 0.0));

        std::vector<float> row_ids = linspace<float>(h_offset, h_len - 1 + h_offset, h_len);
        std::vector<float> col_ids = linspace<float>(w_offset, w_len - 1 + w_offset, w_len);

        for (int i = 0; i < h_len; ++i) {
            for (int j = 0; j < w_len; ++j) {
                img_ids[i * w_len + j][0] = index;
                img_ids[i * w_len + j][1] = row_ids[i];
                img_ids[i * w_len + j][2] = col_ids[j];
            }
        }

        std::vector<std::vector<float>> img_ids_repeated(bs * img_ids.size(), std::vector<float>(3));
        for (int i = 0; i < bs; ++i) {
            for (int j = 0; j < img_ids.size(); ++j) {
                img_ids_repeated[i * img_ids.size() + j] = img_ids[j];
            }
        }
        return img_ids_repeated;
    }

    static std::vector<std::vector<float>> concat_ids(const std::vector<std::vector<float>>& a,
                                                      const std::vector<std::vector<float>>& b,
                                                      int bs) {
        size_t a_len = a.size() / bs;
        size_t b_len = b.size() / bs;
        std::vector<std::vector<float>> ids(a.size() + b.size(), std::vector<float>(3));
        for (int i = 0; i < bs; ++i) {
            for (int j = 0; j < a_len; ++j) {
                ids[i * (a_len + b_len) + j] = a[i * a_len + j];
            }
            for (int j = 0; j < b_len; ++j) {
                ids[i * (a_len + b_len) + a_len + j] = b[i * b_len + j];
            }
        }
        return ids;
    }

    static std::vector<float> embed_nd(const std::vector<std::vector<float>>& ids,
                                       int bs,
                                       int theta,
                                       const std::vector<int>& axes_dim) {
        std::vector<std::vector<float>> trans_ids = transpose(ids);
        size_t pos_len = ids.size() / bs;
        int num_axes = axes_dim.size();
        // for (int i = 0; i < pos_len; i++) {
        //     std::cout << trans_ids[0][i] << " " << trans_ids[1][i] << " " << trans_ids[2][i] << std::endl;
        // }

        int emb_dim = 0;
        for (int d : axes_dim)
            emb_dim += d / 2;

        std::vector<std::vector<float>> emb(bs * pos_len, std::vector<float>(emb_dim * 2 * 2, 0.0));
        int offset = 0;
        for (int i = 0; i < num_axes; ++i) {
            std::vector<std::vector<float>> rope_emb = rope(trans_ids[i], axes_dim[i], theta);  // [bs*pos_len, axes_dim[i]/2 * 2 * 2]
            for (int b = 0; b < bs; ++b) {
                for (int j = 0; j < pos_len; ++j) {
                    for (int k = 0; k < rope_emb[0].size(); ++k) {
                        emb[b * pos_len + j][offset + k] = rope_emb[j][k];
                    }
                }
            }
            offset += rope_emb[0].size();
        }

        return flatten(emb);
    }

    static std::vector<std::vector<float>> gen_flux_ids(int h,
                                                        int w,
                                                        int patch_size,
                                                        int bs,
                                                        int context_len,
                                                        std::vector<ggml_tensor*> ref_latents) {
        auto txt_ids = gen_txt_ids(bs, context_len);
        auto img_ids = gen_img_ids(h, w, patch_size, bs);

        auto ids = concat_ids(txt_ids, img_ids, bs);
        uint64_t curr_h_offset = 0;
        uint64_t curr_w_offset = 0;
        for (ggml_tensor* ref : ref_latents) {
            uint64_t h_offset = 0;
            uint64_t w_offset = 0;
            if (ref->ne[1] + curr_h_offset > ref->ne[0] + curr_w_offset) {
                w_offset = curr_w_offset;
            } else {
                h_offset = curr_h_offset;
            }

            auto ref_ids = gen_img_ids(ref->ne[1], ref->ne[0], patch_size, bs, 1, h_offset, w_offset);
            ids = concat_ids(ids, ref_ids, bs);

            curr_h_offset = std::max(curr_h_offset, ref->ne[1] + h_offset);
            curr_w_offset = std::max(curr_w_offset, ref->ne[0] + w_offset);
        }
        return ids;
    }

    // Generate flux positional embeddings
    static std::vector<float> gen_flux_pe(int h,
                                          int w,
                                          int patch_size,
                                          int bs,
                                          int context_len,
                                          std::vector<ggml_tensor*> ref_latents,
                                          int theta,
                                          const std::vector<int>& axes_dim) {
        std::vector<std::vector<float>> ids = gen_flux_ids(h, w, patch_size, bs, context_len, ref_latents);
        return embed_nd(ids, bs, theta, axes_dim);
    }

    static std::vector<std::vector<float>> gen_vid_ids(int t,
                                                       int h,
                                                       int w,
                                                       int pt,
                                                       int ph,
                                                       int pw,
                                                       int bs,
                                                       int t_offset = 0,
                                                       int h_offset = 0,
                                                       int w_offset = 0) {
        int t_len = (t + (pt / 2)) / pt;
        int h_len = (h + (ph / 2)) / ph;
        int w_len = (w + (pw / 2)) / pw;

        std::vector<std::vector<float>> vid_ids(t_len * h_len * w_len, std::vector<float>(3, 0.0));

        std::vector<float> t_ids = linspace<float>(t_offset, t_len - 1 + t_offset, t_len);
        std::vector<float> h_ids = linspace<float>(h_offset, h_len - 1 + h_offset, h_len);
        std::vector<float> w_ids = linspace<float>(w_offset, w_len - 1 + w_offset, w_len);

        for (int i = 0; i < t_len; ++i) {
            for (int j = 0; j < h_len; ++j) {
                for (int k = 0; k < w_len; ++k) {
                    int idx = i * h_len * w_len + j * w_len + k;
                    vid_ids[idx][0] = t_ids[i];
                    vid_ids[idx][1] = h_ids[j];
                    vid_ids[idx][2] = w_ids[k];
                }
            }
        }

        std::vector<std::vector<float>> vid_ids_repeated(bs * vid_ids.size(), std::vector<float>(3));
        for (int i = 0; i < bs; ++i) {
            for (int j = 0; j < vid_ids.size(); ++j) {
                vid_ids_repeated[i * vid_ids.size() + j] = vid_ids[j];
            }
        }
        return vid_ids_repeated;
    }

    // Generate wan positional embeddings
    static std::vector<float> gen_wan_pe(int t,
                                         int h,
                                         int w,
                                         int pt,
                                         int ph,
                                         int pw,
                                         int bs,
                                         int theta,
                                         const std::vector<int>& axes_dim) {
        std::vector<std::vector<float>> ids = gen_vid_ids(t, h, w, pt, ph, pw, bs);
        return embed_nd(ids, bs, theta, axes_dim);
    }
};  // struct Rope

#endif  // __ROPE_HPP__
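Usage sketch for the new helpers; the concrete `theta` and `axes_dim` values below are illustrative assumptions, not taken from the diff:

// Positional embeddings for a 64x64 latent with patch size 2 and a
// 256-token text context (Flux-style 3-axis rope; values assumed).
std::vector<int> axes_dim = {16, 56, 56};
std::vector<float> pe = Rope::gen_flux_pe(64, 64, /*patch_size=*/2, /*bs=*/1,
                                          /*context_len=*/256, /*ref_latents=*/{},
                                          /*theta=*/10000, axes_dim);
// Each position gets sum(axes_dim) * 2 floats (cos/sin pairs, duplicated),
// so pe.size() == 1 * (256 + 32 * 32) * 256 here.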
stable-diffusion.cpp (1112 changed lines): file diff suppressed because it is too large.
@ -50,7 +50,7 @@ enum sample_method_t {
|
||||
SAMPLE_METHOD_COUNT
|
||||
};
|
||||
|
||||
enum schedule_t {
enum scheduler_t {
    DEFAULT,
    DISCRETE,
    KARRAS,
@@ -101,7 +101,8 @@ enum sd_type_t {
    // SD_TYPE_IQ4_NL_4_4 = 36,
    // SD_TYPE_IQ4_NL_4_8 = 37,
    // SD_TYPE_IQ4_NL_8_8 = 38,
    SD_TYPE_COUNT = 39,
    SD_TYPE_MXFP4 = 39, // MXFP4 (1 block)
    SD_TYPE_COUNT = 40,
};

enum sd_log_level_t {
@@ -115,8 +116,10 @@ typedef struct {
    const char* model_path;
    const char* clip_l_path;
    const char* clip_g_path;
    const char* clip_vision_path;
    const char* t5xxl_path;
    const char* diffusion_model_path;
    const char* high_noise_diffusion_model_path;
    const char* vae_path;
    const char* taesd_path;
    const char* control_net_path;
@@ -129,7 +132,7 @@ typedef struct {
    int n_threads;
    enum sd_type_t wtype;
    enum rng_type_t rng_type;
    enum schedule_t schedule;
    bool offload_params_to_cpu;
    bool keep_clip_on_cpu;
    bool keep_control_net_on_cpu;
    bool keep_vae_on_cpu;
@@ -159,29 +162,33 @@ typedef struct {
typedef struct {
    float txt_cfg;
    float img_cfg;
    float min_cfg;
    float distilled_guidance;
    sd_slg_params_t slg;
} sd_guidance_params_t;

typedef struct {
    sd_guidance_params_t guidance;
    enum scheduler_t scheduler;
    enum sample_method_t sample_method;
    int sample_steps;
    float eta;
} sd_sample_params_t;

typedef struct {
    const char* prompt;
    const char* negative_prompt;
    int clip_skip;
    sd_guidance_params_t guidance;
    sd_image_t init_image;
    sd_image_t* ref_images;
    int ref_images_count;
    sd_image_t mask_image;
    int width;
    int height;
    enum sample_method_t sample_method;
    int sample_steps;
    float eta;
    sd_sample_params_t sample_params;
    float strength;
    int64_t seed;
    int batch_count;
    const sd_image_t* control_cond;
    sd_image_t control_image;
    float control_strength;
    float style_strength;
    bool normalize_input;
@@ -189,18 +196,18 @@ typedef struct {
} sd_img_gen_params_t;

typedef struct {
    const char* prompt;
    const char* negative_prompt;
    int clip_skip;
    sd_image_t init_image;
    sd_image_t end_image;
    int width;
    int height;
    sd_guidance_params_t guidance;
    enum sample_method_t sample_method;
    int sample_steps;
    sd_sample_params_t sample_params;
    sd_sample_params_t high_noise_sample_params;
    float strength;
    int64_t seed;
    int video_frames;
    int motion_bucket_id;
    int fps;
    float augmentation_level;
} sd_vid_gen_params_t;

typedef struct sd_ctx_t sd_ctx_t;
@@ -219,8 +226,8 @@ SD_API const char* sd_rng_type_name(enum rng_type_t rng_type);
SD_API enum rng_type_t str_to_rng_type(const char* str);
SD_API const char* sd_sample_method_name(enum sample_method_t sample_method);
SD_API enum sample_method_t str_to_sample_method(const char* str);
SD_API const char* sd_schedule_name(enum schedule_t schedule);
SD_API enum schedule_t str_to_schedule(const char* str);
SD_API const char* sd_schedule_name(enum scheduler_t scheduler);
SD_API enum scheduler_t str_to_schedule(const char* str);

SD_API void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params);
SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params);
@@ -228,21 +235,27 @@ SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params);
SD_API sd_ctx_t* new_sd_ctx(const sd_ctx_params_t* sd_ctx_params);
SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);

SD_API void sd_sample_params_init(sd_sample_params_t* sample_params);
SD_API char* sd_sample_params_to_str(const sd_sample_params_t* sample_params);

SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params);
SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params);
SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params);

SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params);
SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params); // broken
SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params, int* num_frames_out);

typedef struct upscaler_ctx_t upscaler_ctx_t;

SD_API upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path,
                                        int n_threads,
                                        bool direct);
                                        bool offload_params_to_cpu,
                                        bool direct,
                                        int n_threads);
SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx);

SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_t upscale_factor);
SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx,
                          sd_image_t input_image,
                          uint32_t upscale_factor);

SD_API bool convert(const char* input_path,
                    const char* vae_path,
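The public-header changes above fit together: per-pass sampling options (guidance, scheduler, sampler, steps, eta) now live in `sd_sample_params_t`, `sd_vid_gen_params_t` gains `end_image` plus a second `high_noise_sample_params` block for Wan2.2's two-expert denoising, and `generate_video` reports the actual frame count through an out-parameter instead of being flagged broken. A minimal caller sketch against the new API; the model file names and parameter values are placeholders, not part of this commit:

```cpp
#include "stable-diffusion.h"

int main() {
    sd_ctx_params_t ctx_params;
    sd_ctx_params_init(&ctx_params);
    ctx_params.diffusion_model_path            = "wan2.2_t2v_low_noise.gguf";   // placeholder path
    ctx_params.high_noise_diffusion_model_path = "wan2.2_t2v_high_noise.gguf";  // placeholder path
    ctx_params.t5xxl_path                      = "umt5_xxl.gguf";               // placeholder path
    ctx_params.vae_path                        = "wan_vae.gguf";                // placeholder path
    ctx_params.offload_params_to_cpu           = true;  // new flag: keep weights in host RAM
    ctx_params.n_threads                       = 8;
    sd_ctx_t* ctx = new_sd_ctx(&ctx_params);

    sd_vid_gen_params_t vid_params;
    sd_vid_gen_params_init(&vid_params);
    vid_params.prompt       = "a cat walking through snow";
    vid_params.width        = 832;
    vid_params.height       = 480;
    vid_params.video_frames = 33;
    vid_params.fps          = 16;
    vid_params.sample_params.sample_steps            = 20;  // low-noise expert
    vid_params.high_noise_sample_params.sample_steps = 20;  // Wan2.2 high-noise expert

    int num_frames     = 0;  // filled in via the new out-parameter
    sd_image_t* frames = generate_video(ctx, &vid_params, &num_frames);
    // ... consume frames[0..num_frames-1] and free them, then:
    free_sd_ctx(ctx);
    return (frames != NULL && num_frames > 0) ? 0 : 1;
}
```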
t5.hpp (113 changed lines)
@@ -124,6 +124,9 @@ protected:
                return;
            }
            std::string piece = item[0];
            if (piece.empty()) {
                piece = "<empty_token>";
            }
            float score = item[1];
            piece_score_pairs.emplace_back(piece, score);
        }
@@ -147,6 +150,7 @@ protected:
        std::vector<const char*> key(pieces->size());
        std::vector<int> value(pieces->size());
        for (size_t i = 0; i < pieces->size(); ++i) {
            // LOG_DEBUG("%s %d", (*pieces)[i].first.c_str(), (*pieces)[i].second);
            key[i]   = (*pieces)[i].first.data();  // sorted piece.
            value[i] = (*pieces)[i].second;        // vocab_id
        }
@@ -335,9 +339,9 @@ protected:
    }

public:
    explicit T5UniGramTokenizer(const std::string& json_str = "") {
        if (json_str.size() != 0) {
            InitializePieces(json_str);
    explicit T5UniGramTokenizer(bool is_umt5 = false) {
        if (is_umt5) {
            InitializePieces(ModelLoader::load_umt5_tokenizer_json());
        } else {
            InitializePieces(ModelLoader::load_t5_tokenizer_json());
        }
@@ -673,10 +677,11 @@ public:
            int64_t model_dim,
            int64_t inner_dim,
            int64_t ff_dim,
            int64_t num_heads)
            int64_t num_heads,
            bool relative_attention = true)
        : num_layers(num_layers) {
        for (int i = 0; i < num_layers; i++) {
            blocks["block." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new T5Block(model_dim, inner_dim, ff_dim, num_heads, i == 0));
            blocks["block." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new T5Block(model_dim, inner_dim, ff_dim, num_heads, (!relative_attention || i == 0)));
        }

        blocks["final_layer_norm"] = std::shared_ptr<GGMLBlock>(new T5LayerNorm(model_dim));
@@ -703,15 +708,30 @@ public:
    }
};

struct T5Params {
    int64_t num_layers      = 24;
    int64_t model_dim       = 4096;
    int64_t ff_dim          = 10240;
    int64_t num_heads       = 64;
    int64_t vocab_size      = 32128;
    bool relative_attention = true;
};

struct T5 : public GGMLBlock {
    T5Params params;

public:
    T5(int64_t num_layers,
       int64_t model_dim,
       int64_t ff_dim,
       int64_t num_heads,
       int64_t vocab_size) {
        blocks["encoder"] = std::shared_ptr<GGMLBlock>(new T5Stack(num_layers, model_dim, model_dim, ff_dim, num_heads));
        blocks["shared"]  = std::shared_ptr<GGMLBlock>(new Embedding(vocab_size, model_dim));
    T5() {}
    T5(T5Params params)
        : params(params) {
        blocks["encoder"] = std::shared_ptr<GGMLBlock>(new T5Stack(params.num_layers,
                                                                   params.model_dim,
                                                                   params.model_dim,
                                                                   params.ff_dim,
                                                                   params.num_heads,
                                                                   params.relative_attention));
        blocks["shared"] = std::shared_ptr<GGMLBlock>(new Embedding(params.vocab_size,
                                                                    params.model_dim));
    }

    struct ggml_tensor* forward(struct ggml_context* ctx,
@@ -731,18 +751,21 @@ public:
};

struct T5Runner : public GGMLRunner {
    T5Params params;
    T5 model;
    std::vector<int> relative_position_bucket_vec;

    T5Runner(ggml_backend_t backend,
             bool offload_params_to_cpu,
             const String2GGMLType& tensor_types,
             const std::string prefix,
             int64_t num_layers = 24,
             int64_t model_dim = 4096,
             int64_t ff_dim = 10240,
             int64_t num_heads = 64,
             int64_t vocab_size = 32128)
        : GGMLRunner(backend), model(num_layers, model_dim, ff_dim, num_heads, vocab_size) {
             bool is_umt5 = false)
        : GGMLRunner(backend, offload_params_to_cpu) {
        if (is_umt5) {
            params.vocab_size         = 256384;
            params.relative_attention = false;
        }
        model = T5(params);
        model.init(params_ctx, tensor_types, prefix);
    }

@@ -770,6 +793,7 @@ struct T5Runner : public GGMLRunner {
        struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);

        input_ids      = to_backend(input_ids);
        attention_mask = to_backend(attention_mask);

        relative_position_bucket_vec = compute_relative_position_bucket(input_ids->ne[0], input_ids->ne[0]);

@@ -877,14 +901,11 @@ struct T5Embedder {
    T5Runner model;

    T5Embedder(ggml_backend_t backend,
               bool offload_params_to_cpu,
               const String2GGMLType& tensor_types = {},
               const std::string prefix = "",
               int64_t num_layers = 24,
               int64_t model_dim = 4096,
               int64_t ff_dim = 10240,
               int64_t num_heads = 64,
               int64_t vocab_size = 32128)
        : model(backend, tensor_types, prefix, num_layers, model_dim, ff_dim, num_heads, vocab_size) {
               bool is_umt5 = false)
        : model(backend, offload_params_to_cpu, tensor_types, prefix, is_umt5), tokenizer(is_umt5) {
    }

    void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
@@ -946,25 +967,22 @@ struct T5Embedder {
        GGML_ASSERT(work_ctx != NULL);

        {
            // cpu f16: pass
            // cpu f32: pass
            // cuda f16: nan
            // cuda f32: pass
            // cuda q8_0: nan
            // TODO: fix cuda nan
            std::string text("a lovely cat");
            auto tokens_and_weights = tokenize(text, 77, true);
            // std::string text("一只可爱的猫"); // umt5 Chinese test
            auto tokens_and_weights = tokenize(text, 512, true);
            std::vector<int>& tokens    = std::get<0>(tokens_and_weights);
            std::vector<float>& weights = std::get<1>(tokens_and_weights);
            std::vector<float>& masks   = std::get<2>(tokens_and_weights);
            for (auto token : tokens) {
                printf("%d ", token);
            }
            printf("\n");
            auto input_ids      = vector_to_ggml_tensor_i32(work_ctx, tokens);
            auto attention_mask = vector_to_ggml_tensor(work_ctx, masks);
            struct ggml_tensor* out = NULL;

            int t0 = ggml_time_ms();
            model.compute(8, input_ids, NULL, &out, work_ctx);
            model.compute(8, input_ids, attention_mask, &out, work_ctx);
            int t1 = ggml_time_ms();

            print_ggml_tensor(out);
@@ -973,16 +991,14 @@ struct T5Embedder {
    }

    static void load_from_file_and_test(const std::string& file_path) {
        // cpu f16: pass
        // cpu f32: pass
        // cuda f16: pass
        // cuda f32: pass
        // cuda q8_0: pass
        // ggml_backend_t backend = ggml_backend_cuda_init(0);
        ggml_backend_t backend = ggml_backend_cpu_init();
        ggml_type model_data_type = GGML_TYPE_F32;
        std::shared_ptr<T5Embedder> t5 = std::shared_ptr<T5Embedder>(new T5Embedder(backend));
        {
            LOG_INFO("loading from '%s'", file_path.c_str());

            t5->alloc_params_buffer();
            std::map<std::string, ggml_tensor*> tensors;
            t5->get_param_tensors(tensors, "");
        ggml_type model_data_type = GGML_TYPE_F16;

            ModelLoader model_loader;
            if (!model_loader.init_from_file(file_path)) {
@@ -990,7 +1006,21 @@ struct T5Embedder {
                return;
            }

            bool success = model_loader.load_tensors(tensors, backend);
            auto tensor_types = model_loader.tensor_storages_types;
            for (auto& item : tensor_types) {
                // LOG_DEBUG("%s %u", item.first.c_str(), item.second);
                if (ends_with(item.first, "weight")) {
                    item.second = model_data_type;
                }
            }

            std::shared_ptr<T5Embedder> t5 = std::shared_ptr<T5Embedder>(new T5Embedder(backend, false, tensor_types, "", true));

            t5->alloc_params_buffer();
            std::map<std::string, ggml_tensor*> tensors;
            t5->get_param_tensors(tensors, "");

            bool success = model_loader.load_tensors(tensors);

            if (!success) {
                LOG_ERROR("load tensors from model loader failed");
@@ -998,7 +1028,6 @@ struct T5Embedder {
            }

            LOG_INFO("t5 model loaded");
        }
        t5->test();
    }
};
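Net effect of the t5.hpp changes: the tokenizer and runner are configured by a single `is_umt5` switch instead of raw dimension arguments, so Wan's UMT5-XXL encoder (256k vocabulary, no relative attention bias) and the classic T5-XXL share one code path. A hedged construction sketch; the CPU backend choice and empty prefix are illustrative only:

```cpp
ggml_backend_t backend = ggml_backend_cpu_init();

// One flag selects the whole UMT5 configuration:
// vocab_size = 256384 and relative_attention = false, per the diff above.
T5Embedder umt5(backend,
                /*offload_params_to_cpu=*/false,
                /*tensor_types=*/{},
                /*prefix=*/"",
                /*is_umt5=*/true);
```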
tae.hpp (5 changed lines)
@@ -196,13 +196,14 @@ struct TinyAutoEncoder : public GGMLRunner {
    bool decode_only = false;

    TinyAutoEncoder(ggml_backend_t backend,
                    bool offload_params_to_cpu,
                    const String2GGMLType& tensor_types,
                    const std::string prefix,
                    bool decoder_only = true,
                    SDVersion version = VERSION_SD1)
        : decode_only(decoder_only),
          taesd(decoder_only, version),
          GGMLRunner(backend) {
          GGMLRunner(backend, offload_params_to_cpu) {
        taesd.init(params_ctx, tensor_types, prefix);
    }

@@ -237,7 +238,7 @@ struct TinyAutoEncoder : public GGMLRunner {
            return false;
        }

        bool success = model_loader.load_tensors(taesd_tensors, backend, ignore_tensors);
        bool success = model_loader.load_tensors(taesd_tensors, ignore_tensors);

        if (!success) {
            LOG_ERROR("load tae tensors from model loader failed");
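The same two-part pattern repeats in every runner this commit touches: `offload_params_to_cpu` is forwarded to the `GGMLRunner` base, and `ModelLoader::load_tensors` no longer takes an explicit backend argument (placement presumably follows the runner's own buffers now). An illustrative construction under the new signature, with prefix and version as placeholder values:

```cpp
ggml_backend_t backend = ggml_backend_cpu_init();

// offload_params_to_cpu now sits right after the backend in every runner ctor.
TinyAutoEncoder taesd(backend,
                      /*offload_params_to_cpu=*/true,
                      /*tensor_types=*/{},
                      /*prefix=*/"",  // placeholder
                      /*decoder_only=*/true,
                      VERSION_SD1);
```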
thirdparty/darts.h (vendored, 3 changed lines)
@@ -4,6 +4,7 @@
#include <cstdio>
#include <exception>
#include <new>
#include <iostream>

#define DARTS_VERSION "0.32"

@@ -1140,9 +1141,11 @@ inline void DawgBuilder::insert(const char *key, std::size_t length,
  if (value < 0) {
    DARTS_THROW("failed to insert key: negative value");
  } else if (length == 0) {
    std::cout << value << std::endl;
    DARTS_THROW("failed to insert key: zero-length key");
  }


  id_type id = 0;
  std::size_t key_pos = 0;
unet.hpp (3 changed lines)
@@ -538,11 +538,12 @@ struct UNetModelRunner : public GGMLRunner {
    UnetModelBlock unet;

    UNetModelRunner(ggml_backend_t backend,
                    bool offload_params_to_cpu,
                    const String2GGMLType& tensor_types,
                    const std::string prefix,
                    SDVersion version = VERSION_SD1,
                    bool flash_attn = false)
        : GGMLRunner(backend), unet(version, tensor_types, flash_attn) {
        : GGMLRunner(backend, offload_params_to_cpu), unet(version, tensor_types, flash_attn) {
        unet.init(params_ctx, tensor_types, prefix);
    }
upscaler.cpp (12 changed lines)
@@ -17,7 +17,8 @@ struct UpscalerGGML {
          direct(direct) {
    }

    bool load_from_file(const std::string& esrgan_path) {
    bool load_from_file(const std::string& esrgan_path,
                        bool offload_params_to_cpu) {
#ifdef SD_USE_CUDA
        LOG_DEBUG("Using CUDA backend");
        backend = ggml_backend_cuda_init(0);
@@ -49,7 +50,7 @@ struct UpscalerGGML {
            backend = ggml_backend_cpu_init();
        }
        LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type));
        esrgan_upscaler = std::make_shared<ESRGAN>(backend, model_loader.tensor_storages_types);
        esrgan_upscaler = std::make_shared<ESRGAN>(backend, offload_params_to_cpu, model_loader.tensor_storages_types);
        if (direct) {
            esrgan_upscaler->enable_conv2d_direct();
        }
@@ -110,8 +111,9 @@ struct upscaler_ctx_t {
};

upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path_c_str,
                                 int n_threads,
                                 bool direct = false) {
                                 bool offload_params_to_cpu,
                                 bool direct,
                                 int n_threads) {
    upscaler_ctx_t* upscaler_ctx = (upscaler_ctx_t*)malloc(sizeof(upscaler_ctx_t));
    if (upscaler_ctx == NULL) {
        return NULL;
@@ -123,7 +125,7 @@ upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path_c_str,
        return NULL;
    }

    if (!upscaler_ctx->upscaler->load_from_file(esrgan_path)) {
    if (!upscaler_ctx->upscaler->load_from_file(esrgan_path, offload_params_to_cpu)) {
        delete upscaler_ctx->upscaler;
        upscaler_ctx->upscaler = NULL;
        free(upscaler_ctx);
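With the reordered signature, existing callers of `new_upscaler_ctx` must be updated: `n_threads` moved last and the two flags now come right after the path. A sketch under the new signature, with a placeholder model path and an `input_image` assumed to be loaded by the caller:

```cpp
upscaler_ctx_t* up = new_upscaler_ctx("realesrgan_x4.safetensors",  // placeholder path
                                      /*offload_params_to_cpu=*/false,
                                      /*direct=*/true,
                                      /*n_threads=*/8);
if (up != NULL) {
    sd_image_t upscaled = upscale(up, input_image, 4);  // input_image: caller-provided
    free_upscaler_ctx(up);
}
```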
util.cpp (13 changed lines)
@@ -72,6 +72,17 @@ std::string format(const char* fmt, ...) {
    return std::string(buf.data(), size);
}

int round_up_to(int value, int base) {
    if (base <= 0) {
        return value;
    }
    if (value % base == 0) {
        return value;
    } else {
        return ((value / base) + 1) * base;
    }
}

#ifdef _WIN32  // code for windows
#include <windows.h>

@@ -290,7 +301,7 @@ std::string path_join(const std::string& p1, const std::string& p2) {
    return p1 + "/" + p2;
}

std::vector<std::string> splitString(const std::string& str, char delimiter) {
std::vector<std::string> split_string(const std::string& str, char delimiter) {
    std::vector<std::string> result;
    size_t start = 0;
    size_t end   = str.find(delimiter);
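`round_up_to` snaps a value up to the next multiple of `base`, presumably so Wan's width/height/frame counts can be aligned to the multiples the model expects. Its contract, by the definition above:

```cpp
round_up_to(480, 16);  // -> 480 (already a multiple of 16)
round_up_to(481, 16);  // -> 496 (rounded up to the next multiple)
round_up_to(100, 0);   // -> 100 (non-positive base returns the value unchanged)
```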
util.h (4 changed lines)
@@ -18,6 +18,8 @@ std::string format(const char* fmt, ...);

void replace_all_chars(std::string& str, char target, char replacement);

int round_up_to(int value, int base);

bool file_exists(const std::string& filename);
bool is_directory(const std::string& path);
std::string get_full_path(const std::string& dir, const std::string& filename);
@@ -48,7 +50,7 @@ sd_image_f32_t resize_sd_image_f32_t(sd_image_f32_t image, int target_width, int
sd_image_f32_t clip_preprocess(sd_image_f32_t image, int size);

std::string path_join(const std::string& p1, const std::string& p2);
std::vector<std::string> splitString(const std::string& str, char delimiter);
std::vector<std::string> split_string(const std::string& str, char delimiter);
void pretty_progress(int step, int steps, float time);

void log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...);
vae.hpp (17 changed lines)
@@ -520,17 +520,30 @@ public:
    }
};

struct AutoEncoderKL : public GGMLRunner {
struct VAE : public GGMLRunner {
    VAE(ggml_backend_t backend, bool offload_params_to_cpu)
        : GGMLRunner(backend, offload_params_to_cpu) {}
    virtual void compute(const int n_threads,
                         struct ggml_tensor* z,
                         bool decode_graph,
                         struct ggml_tensor** output,
                         struct ggml_context* output_ctx) = 0;
    virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) = 0;
    virtual void enable_conv2d_direct(){};
};

struct AutoEncoderKL : public VAE {
    bool decode_only = true;
    AutoencodingEngine ae;

    AutoEncoderKL(ggml_backend_t backend,
                  bool offload_params_to_cpu,
                  const String2GGMLType& tensor_types,
                  const std::string prefix,
                  bool decode_only = false,
                  bool use_video_decoder = false,
                  SDVersion version = VERSION_SD1)
        : decode_only(decode_only), ae(decode_only, use_video_decoder, version), GGMLRunner(backend) {
        : decode_only(decode_only), ae(decode_only, use_video_decoder, version), VAE(backend, offload_params_to_cpu) {
        ae.init(params_ctx, tensor_types, prefix);
    }
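Pulling `compute`/`get_param_tensors` into an abstract `VAE` base lets the pipeline hold any first-stage model behind one pointer. A sketch of the dispatch this enables; `WanVAE` stands in for the video VAE added elsewhere in this PR, and its constructor arguments are assumed to mirror `AutoEncoderKL`'s:

```cpp
std::shared_ptr<VAE> build_first_stage(ggml_backend_t backend,
                                       bool offload_params_to_cpu,
                                       const String2GGMLType& tensor_types,
                                       SDVersion version) {
    if (sd_version_is_wan(version)) {
        // WanVAE: assumed name/signature for the Wan video VAE runner.
        return std::make_shared<WanVAE>(backend, offload_params_to_cpu,
                                        tensor_types, "first_stage_model");
    }
    return std::make_shared<AutoEncoderKL>(backend, offload_params_to_cpu,
                                           tensor_types, "first_stage_model");
}
```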
vocab_umt5.hpp (new file, 762304 lines)
File diff suppressed because it is too large