feat: add PuLID-Flux identity-injection support (#1595)

RapidMark 2026-06-15 08:33:50 -07:00 committed by GitHub
parent 6e66a1a4a4
commit 93527fda74
11 changed files with 878 additions and 13 deletions

docs/pulid.md Normal file

@ -0,0 +1,195 @@
# PuLID-Flux face-identity preservation
stable-diffusion.cpp supports the [PuLID-Flux](https://github.com/ToTheBeginning/PuLID)
identity-injection technique on top of Flux.1 (schnell or dev) models.
Given a single source portrait, PuLID-Flux produces new generations that
preserve the source person's face across arbitrary scenes, poses, and
prompts.
Unlike PhotoMaker (which extracts the identity inside the inference
process from a directory of images), PuLID-Flux's identity extractor is
a heavy stack (insightface ArcFace + EVA-CLIP-L + IDFormer encoder) that
is impractical to port to C++/ggml. To keep this implementation small and
cross-vendor, **stable-diffusion.cpp consumes a precomputed identity
embedding** produced by an external Python tool that runs once per source
portrait. Everything downstream of that one-shot extraction is C++ and
runs on any backend (Vulkan, CUDA, Metal, ROCm, CPU).
## Architecture summary
The PuLID-Flux contribution to the Flux denoise loop is a stack of 20
small cross-attention modules (`PerceiverAttentionCA`) inserted between
the Flux transformer blocks:
- After every 2nd of the 19 double-stream blocks (10 hook points)
- After every 4th of the 38 single-stream blocks (10 hook points)
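The two hook counts follow from ceiling division of each block count by its interval. A minimal sketch of that arithmetic is below; the helper names are illustrative only, and hooking the blocks whose index is a multiple of the interval is an assumed indexing convention rather than something this page states.

```python
def ceil_div(n: int, d: int) -> int:
    """Ceiling division for positive integers."""
    return (n + d - 1) // d


def hook_blocks(depth: int, interval: int) -> list:
    # Assumption: a hook fires on block indices 0, interval, 2*interval, ...
    return [i for i in range(depth) if i % interval == 0]


double_hooks = hook_blocks(19, 2)   # double-stream blocks
single_hooks = hook_blocks(38, 4)   # single-stream blocks
total_hooks = len(double_hooks) + len(single_hooks)

# The explicit enumeration agrees with the closed-form ceiling division.
assert len(double_hooks) == ceil_div(19, 2)
assert len(single_hooks) == ceil_div(38, 4)
print(len(double_hooks), len(single_hooks), total_hooks)
```

Under that assumption both streams contribute 10 hooks, which matches the 20 modules quoted above.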
Each cross-attention layer takes the current image tokens as query, the
32-token / 2048-dim identity embedding as key+value, and adds its output
(scaled by `id_weight`, typically 1.0) back to the image tokens.
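The residual update can be illustrated with a deliberately tiny, pure-Python toy. It is only a sketch of the "query from image tokens, key/value from identity tokens, scaled add-back" pattern: it uses a single head, 2-dim vectors, and omits the LayerNorms and learned projections of the real `PerceiverAttentionCA`, so the function name and all values are illustrative assumptions.

```python
import math


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention_residual(image_tokens, id_tokens, id_weight=1.0):
    """Toy single-head cross-attention: no norms, no learned projections.

    image_tokens act as queries, id_tokens act as both keys and values.
    Returns image_tokens + id_weight * attention_output.
    """
    dim = len(id_tokens[0])
    scale = 1.0 / math.sqrt(dim)
    out = []
    for q in image_tokens:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in id_tokens]
        weights = softmax(scores)
        attended = [sum(w * v[j] for w, v in zip(weights, id_tokens))
                    for j in range(dim)]
        out.append([x + id_weight * a for x, a in zip(q, attended)])
    return out


image_tokens = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
id_tokens = [[2.0, 0.0], [0.0, 2.0]]

unchanged = cross_attention_residual(image_tokens, id_tokens, id_weight=0.0)
injected = cross_attention_residual(image_tokens, id_tokens, id_weight=1.0)
```

With `id_weight=0.0` the tokens come back unchanged, which is the property the zero-weight verification run relies on.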
## Required weights
Three sets of files are required: the standard Flux weights plus two PuLID-specific files.
1. **Flux base** (transformer + VAE + clip_l + t5xxl) -- exactly as
[docs/flux.md](flux.md) describes.
2. **PuLID weights** -- download from
[guozinan/PuLID](https://huggingface.co/guozinan/PuLID):
- `pulid_flux_v0.9.0.safetensors` or `pulid_flux_v0.9.1.safetensors`
(recommended; this implementation is verified against v0.9.1)
- **v1.1 (`pulid_v1.1.safetensors`) is NOT yet supported** -- it uses
renamed keys (`id_adapter_attn_layers.*` instead of `pulid_ca.*`)
and possibly different module structure. Future PR.
3. **Identity embedding (.pulidembd)** -- produced by the precompute
tool below.
## Precompute the identity embedding
The precompute tool runs the PyTorch identity-extraction stack on a
single portrait image and writes the resulting `(32, 2048)` embedding
to a `.pulidembd` binary file (about 131 KB). Run it once per source
person; the same file is reused for any number of generations.
A reference Python script is provided at
[`scripts/pulid_extract_id.py`](../scripts/pulid_extract_id.py). It
requires:
- A working CUDA / CPU PyTorch + diffusers stack
- `insightface`, `facexlib`, `eva-clip`, `torchvision`
- `gguf`, `numpy`, `Pillow` (imported directly by the script)
- The PuLID weights file (same one stable-diffusion.cpp will load below)
- The ToTheBeginning/PuLID repo's `pulid/pipeline_flux.py` (and its
dependencies under `pulid/` and `flux/`) -- recommended to vendor
rather than pip-install due to upstream packaging quirks
Run it as:
```
python pulid_extract_id.py \
--portrait /path/to/source-photo.jpg \
--pulid-weights /path/to/pulid_flux_v0.9.1.safetensors \
--out /path/to/source.pulidembd
```
## Format (gguf)
The embedding is a standard **gguf** container holding a single tensor:
```
tensor name : "pulid_id"
shape : [token_dim, num_tokens] (ggml order; typically [2048, 32])
type : F16 (also accepts F32 / BF16)
metadata : general.architecture = "pulid", pulid.version = 1
```
stable-diffusion.cpp loads it with the normal gguf reader
(`gguf_init_from_file`) and converts to fp32 at load time -- no bespoke
parser. Total file size for the typical (32, 2048, fp16) case is ~131 KB.
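The ~131 KB figure and the reversed dimension order can be checked with a few lines of arithmetic. This is a sketch only: the helper names are illustrative, and the result covers the tensor payload, not the gguf header and metadata.

```python
BYTES_PER_ELEMENT = {"fp16": 2, "bf16": 2, "fp32": 4}


def ggml_dims(shape):
    """ggml lists the fastest-varying dimension first, i.e. the reverse of a
    row-major (num_tokens, token_dim) array shape."""
    return tuple(reversed(shape))


def payload_bytes(shape, dtype):
    n = 1
    for d in shape:
        n *= d
    return n * BYTES_PER_ELEMENT[dtype]


shape = (32, 2048)  # (num_tokens, token_dim) as written by the Python side
print(ggml_dims(shape), payload_bytes(shape, "fp16"))
```

For the typical case this gives 32 * 2048 * 2 = 131072 bytes of tensor data; an F32 file is twice that.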
## Command-line usage
```
# PowerShell (a trailing backtick continues the line)
.\bin\Release\sd-cli.exe `
--diffusion-model models\flux1-schnell-Q4_K_S.gguf `
--vae models\ae.safetensors `
--clip_l models\clip_l.safetensors `
--t5xxl models\t5xxl_fp16.safetensors `
--pulid-weights models\pulid_flux_v0.9.1.safetensors `
--pulid-id-embedding source.pulidembd `
--pulid-id-weight 1.0 `
-p "candid photograph of a young woman on a beach at sunset" `
--cfg-scale 1.0 --sampling-method euler --steps 4 -W 512 -H 512 `
--seed 42 --clip-on-cpu `
-o out.png
```
For Flux Dev (instead of Schnell), add `--guidance 3.5` and change `--steps 4` to `--steps 20`.
## Flags
| Flag | Purpose |
|----------------------------|-------------------------------------------------------------------|
| `--pulid-weights <path>` | Path to `pulid_flux_v0.9.x.safetensors`. Loaded with the model. |
| `--pulid-id-embedding <p>` | Path to a `.pulidembd` binary produced by the precompute tool. |
| `--pulid-id-weight <f>` | Identity-injection strength. Typical 0.7-1.2; default 1.0. |
`--pulid-weights` and `--pulid-id-embedding` must both be set to activate
PuLID; `--pulid-id-weight` is optional and defaults to 1.0. Setting only
`--pulid-weights` (no embedding) loads the weights but disables injection
at runtime. Setting `--pulid-id-weight 0` zeros out the contribution
(useful for falsification testing: outputs should be byte-identical to
a no-PuLID run with the same seed).
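The flag interaction can be summarized as a small decision function. This is a sketch that mirrors the behavior described here and the checks in `PuLIDExtension` from this commit; the function name and the state labels are hypothetical, not part of the CLI.

```python
from typing import Optional


def pulid_state(weights_path: Optional[str],
                embedding_path: Optional[str],
                id_weight: float = 1.0) -> str:
    """Classify a flag combination (labels are illustrative only)."""
    if not weights_path:
        return "off"            # extension disabled, no PuLID modules used
    if not embedding_path:
        return "weights-only"   # weights loaded, injection skipped at runtime
    if id_weight == 0.0:
        return "zeroed"         # injection path runs, contribution scaled to 0
    return "active"


print(pulid_state("pulid.safetensors", "source.pulidembd"))
```

Note that the default `id_weight` of 1.0 means two flags are enough to reach the active state.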
## Memory budget
At 512x512, 4 steps (Schnell), the 20 cross-attention layers add roughly
10% to denoise time and almost nothing to peak VRAM. Tested on a 12 GB
consumer card alongside Flux Schnell Q4 GGUF + CPU-offloaded clip_l and
t5xxl + GPU-resident VAE.
At 1024x1024 with Flux Dev Q4 + 20 steps + PuLID, the VAE decode compute
buffer doesn't fit on a 12 GB card even with `--vae-on-cpu`. Workaround:
route the VAE to the CPU backend explicitly instead of relying on the
offload flag:
```
--backend "diffusion=vulkan0,vae=cpu"
```
The `--vae-on-cpu` flag offloads VAE weights but leaves the compute graph
on the default backend; this is existing stable-diffusion.cpp behavior,
not a PuLID-specific issue. Documented here because anyone running PuLID
at 1024 will hit it.
## Backend selection
The standard `--backend` flag works as documented. Common patterns:
```
# AMD Vulkan
--backend "diffusion=vulkan0,vae=cpu"
# NVIDIA Vulkan
--backend "diffusion=vulkan1,vae=cpu"
# CUDA
--backend "diffusion=cuda0,vae=cpu"
```
The PuLID cross-attention layers run on the same backend as the main
diffusion model. They have not yet been independently profiled on every
backend; only Vulkan and CPU have been tested by the original contributor.
## Verification
A three-way SHA-256 check is the recommended sanity test when bringing up
a new combination of model + backend + hardware:
| Run | Expected hash relation |
|----------------------------------------------|------------------------------------|
| A: no `--pulid-*` flags | baseline |
| B: PuLID flags, `--pulid-id-weight 0.0` | **byte-identical to A** |
| C: PuLID flags, `--pulid-id-weight 1.0` | **different from A,B**, preserves source identity |
If B differs from A, the injection is allocating or computing something
even at zero weight, which is likely a bug. If C is identical to A, the
injection is not running at all.
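A small helper can automate the comparison. The sketch below uses only `hashlib` and `pathlib`; the function names, messages, and output file names are hypothetical.

```python
import hashlib
from pathlib import Path


def sha256_of(path) -> str:
    """Hex SHA-256 of a file's bytes."""
    return hashlib.sha256(Path(path).read_bytes()).hexdigest()


def three_way_verdict(hash_a: str, hash_b: str, hash_c: str) -> str:
    """A = baseline, B = PuLID at weight 0.0, C = PuLID at weight 1.0."""
    if hash_a != hash_b:
        return "FAIL: zero-weight run differs from baseline"
    if hash_c == hash_a:
        return "FAIL: PuLID had no effect at weight 1.0"
    return "OK"


# Example call, assuming the three runs wrote a.png, b.png and c.png:
# print(three_way_verdict(sha256_of("a.png"), sha256_of("b.png"), sha256_of("c.png")))
```

The hash check only covers the first two expectations in the table; whether run C preserves the source identity still has to be judged by eye.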
## Limitations / not yet supported
- **`--skip-layers` (skip-layer-guidance / SLG) combined with PuLID** is not
supported. The `pulid_ca` index advances per non-skipped block, so a
skipped block silently misaligns the cross-attention weight assignment
vs. the trained intervals. The reference PyTorch implementation does
not have SLG either, so there is no well-defined behavior to emulate.
Use either feature alone.
- **PuLID v1.1 weights** (`pulid_v1.1.safetensors`, renamed key layout).
- **Multiple ID images.** The reference PyTorch implementation can fuse
several portraits into one embedding for stronger identity. This
implementation accepts a single embedding produced from one or more
images by the external precompute tool.
- **Negative-prompt branch of CFG.** PuLID only injects on the positive
conditioning path in the published reference, and the implementation
here follows that. Flux's distilled guidance doesn't run a separate
uncond branch in normal use, so this matters only for `--true-cfg`
workflows that aren't standard for Flux.
- **Backends other than Vulkan and CPU** are untested by the original
contributor. The implementation is pure-ggml and should work on CUDA,
ROCm, and Metal, but verification by users on those backends is
welcomed.
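The `--skip-layers` misalignment from the first limitation can be illustrated with a running-counter model. This is an assumed simplification of how the `pulid_ca` index advances (one counter, incremented each time a non-skipped hooked block runs), not a copy of the actual forward pass; the function name is hypothetical.

```python
def assign_ca_indices(depth, interval, skip_layers=()):
    """Map each hooked block index to the pulid_ca index it would use."""
    mapping = {}
    ca_idx = 0
    for block in range(depth):
        if block in skip_layers:
            continue                 # skipped block: its hook never fires
        if block % interval == 0:    # assumed hook condition
            mapping[block] = ca_idx
            ca_idx += 1
    return mapping


aligned = assign_ca_indices(19, 2)
shifted = assign_ca_indices(19, 2, skip_layers=(2,))
print(aligned[4], shifted[4])
```

In this model, skipping double block 2 makes block 4 pick up the module trained for block 2, and every later hook is shifted by one in the same way.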


@ -415,6 +415,10 @@ ArgOptions SDContextParams::get_options() {
"--photo-maker",
"path to PHOTOMAKER model",
&photo_maker_path},
{"",
"--pulid-weights",
"path to PuLID flux weights (e.g. pulid_flux_v0.9.1.safetensors). Identity is injected during the denoise loop when paired with --pulid-id-embedding.",
&pulid_weights_path},
{"",
"--upscale-model",
"path to esrgan model.",
@ -812,6 +816,7 @@ sd_ctx_params_t SDContextParams::to_sd_ctx_params_t(bool taesd_preview) {
sd_ctx_params.embeddings = embedding_vec.data();
sd_ctx_params.embedding_count = static_cast<uint32_t>(embedding_vec.size());
sd_ctx_params.photo_maker_path = photo_maker_path.c_str();
sd_ctx_params.pulid_weights_path = pulid_weights_path.c_str();
sd_ctx_params.tensor_type_rules = tensor_type_rules.c_str();
sd_ctx_params.n_threads = n_threads;
sd_ctx_params.wtype = wtype;
@ -887,6 +892,10 @@ ArgOptions SDGenerationParams::get_options() {
"--pm-id-embed-path",
"path to PHOTOMAKER v2 id embed",
&pm_id_embed_path},
{"",
"--pulid-id-embedding",
"path to a .pulidembd binary produced by pulid_extract_id.py. Carries a (32, 2048) identity embedding extracted from a source portrait. Pair with --pulid-weights on the context.",
&pulid_id_embedding_path},
{"",
"--hires-upscaler",
"highres fix upscaler, Lanczos, Nearest, Latent, Latent (nearest), Latent (nearest-exact), "
@ -1037,6 +1046,10 @@ ArgOptions SDGenerationParams::get_options() {
"--pm-style-strength",
"",
&pm_style_strength},
{"",
"--pulid-id-weight",
"strength of PuLID identity injection (default: 1.0). 0.7-1.2 are typical; lower lets the prompt override the face more, higher tightens identity match.",
&pulid_id_weight},
{"",
"--control-strength",
"strength to apply Control Net (default: 0.9). 1.0 corresponds to full destruction of information in init image",
@ -2269,6 +2282,11 @@ sd_img_gen_params_t SDGenerationParams::to_sd_img_gen_params_t() {
pm_style_strength,
};
sd_pulid_params_t pulid_params = {
pulid_id_embedding_path.empty() ? nullptr : pulid_id_embedding_path.c_str(),
pulid_id_weight,
};
params.loras = lora_vec.empty() ? nullptr : lora_vec.data();
params.lora_count = static_cast<uint32_t>(lora_vec.size());
params.prompt = prompt.c_str();
@ -2289,6 +2307,7 @@ sd_img_gen_params_t SDGenerationParams::to_sd_img_gen_params_t() {
params.control_image = control_image.get();
params.control_strength = control_strength;
params.pm_params = pm_params;
params.pulid_params = pulid_params;
params.vae_tiling_params = vae_tiling_params;
params.cache = cache_params;


@ -133,6 +133,11 @@ struct SDContextParams {
std::string control_net_path;
std::string embedding_dir;
std::string photo_maker_path;
// PuLID-Flux identity-preservation context path: the safetensors blob
// carrying the PerceiverAttentionCA cross-attention weights. Loaded
// once with the model. Per-generation pulid_id_embedding_path lives in
// SDGenerationParams below.
std::string pulid_weights_path;
sd_type_t wtype = SD_TYPE_COUNT;
std::string tensor_type_rules;
std::string lora_model_dir = ".";
@ -234,6 +239,12 @@ struct SDGenerationParams {
std::string pm_id_embed_path;
float pm_style_strength = 20.f;
// PuLID-Flux: per-generation identity embedding (gguf file produced by
// scripts/pulid_extract_id.py). Format documented in
// include/stable-diffusion.h sd_pulid_params_t.
std::string pulid_id_embedding_path;
float pulid_id_weight = 1.0f;
int upscale_repeats = 1;
int upscale_tile_size = 128;


@ -195,6 +195,16 @@ typedef struct {
const sd_embedding_t* embeddings;
uint32_t embedding_count;
const char* photo_maker_path;
/**
* Path to pulid_flux_v0.9.1.safetensors (the PuLID identity-injection
* cross-attention weights). When set together with sd_img_gen_params_t.
* pulid_params.id_embedding_path, the Flux diffusion model performs PuLID
* cross-attention injection during the denoise loop. Loaded once with
* the model; the embedding is per-generation. Currently only meaningful
* for Flux (depth=19 double, 38 single blocks); silently ignored for
* other model versions.
*/
const char* pulid_weights_path;
const char* tensor_type_rules;
int n_threads;
enum sd_type_t wtype;
@ -272,6 +282,25 @@ typedef struct {
float style_strength;
} sd_pm_params_t; // photo maker
/**
* PuLID-Flux identity preservation params.
*
* Unlike PhotoMaker (which extracts the ID embedding inside the inference
* process from a directory of images), PuLID's ID extraction is a heavy
* Python-only stack (insightface ArcFace + EVA-CLIP-L + IDFormer). To stay
* cross-vendor in C++/Vulkan, sd.cpp consumes a precomputed gguf file
* produced by an external tool (scripts/pulid_extract_id.py).
*
* Format: a gguf container with a single tensor "pulid_id" of shape
* [token_dim, num_tokens] (ggml order; typically [2048, 32]) in F16/F32/BF16.
* Loaded with the standard gguf reader; see docs/pulid.md.
*/
typedef struct {
const char* id_embedding_path; // path to .pulidembd file produced by pulid_extract_id.py
float id_weight; // strength of the ID injection; typical 0.7-1.2, default 1.0
} sd_pulid_params_t;
enum sd_cache_mode_t {
SD_CACHE_DISABLED = 0,
SD_CACHE_EASYCACHE,
@ -364,6 +393,7 @@ typedef struct {
sd_image_t control_image;
float control_strength;
sd_pm_params_t pm_params;
sd_pulid_params_t pulid_params;
sd_tiling_params_t vae_tiling_params;
sd_cache_params_t cache;
sd_hires_params_t hires;

scripts/pulid_extract_id.py Normal file

@ -0,0 +1,161 @@
"""
Precompute a PuLID-Flux identity embedding from a single source portrait.
Writes a gguf file (a single tensor `pulid_id`) that stable-diffusion.cpp's
`--pulid-id-embedding` flag consumes. See docs/pulid.md for the format and
overall PuLID-Flux flow.
This script intentionally lives outside the C++ build: identity extraction
needs insightface + EVA-CLIP-L + IDFormer, which are PyTorch-only stacks
that would be impractical to reimplement in ggml just to run once per
source person. The C++ side downstream of this file is cross-vendor and
backend-agnostic.
Dependencies (recommended: vendor rather than pip-install due to upstream
packaging quirks):
- torch + safetensors
- gguf (the gguf-py package; used below to write the output file)
- The ToTheBeginning/PuLID repository's `pulid/pipeline_flux.py` and
its sibling packages (`flux/`, `eva_clip/`, `models/`). Put them on
PYTHONPATH or sys.path before running this script.
- insightface, facexlib (PuLID pipeline pulls these in)
- numpy, Pillow
Usage:
python pulid_extract_id.py \\
--portrait /path/to/source-photo.jpg \\
--pulid-weights /path/to/pulid_flux_v0.9.1.safetensors \\
--out /path/to/source.pulidembd
The portrait must contain a clearly visible face. insightface's antelopev2
detector will be auto-downloaded on first run.
"""
from __future__ import annotations
import argparse
import os
import sys
def _make_minimal_flux_skeleton(device):
"""PuLIDPipeline expects a `dit` (Flux transformer) to attach its
PerceiverAttentionCA modules to during construction. We never run a
forward pass on it -- the encoders alone (which is what we actually
need) live on the pipeline object, not the dit. So we instantiate a
real Flux skeleton with default params and never load its weights."""
import torch
from flux.model import Flux
from flux.util import configs
with torch.device("cpu"):
model = Flux(configs["flux-dev"].params).to(torch.bfloat16)
return model
def extract(portrait_path: str, pulid_weights: str) -> "torch.Tensor":
import numpy as np
import torch
from PIL import Image
from pulid.pipeline_flux import PuLIDPipeline
if torch.cuda.is_available():
device, onnx_provider = "cuda", "gpu"
else:
device, onnx_provider = "cpu", "cpu"
print(f"device={device}", flush=True)
print("constructing minimal Flux skeleton (no weights loaded)", flush=True)
dit = _make_minimal_flux_skeleton(device)
print("instantiating PuLIDPipeline", flush=True)
pulid = PuLIDPipeline(dit=dit, device=device,
weight_dtype=torch.bfloat16,
onnx_provider=onnx_provider)
print(f"loading PuLID weights from {pulid_weights}", flush=True)
# PuLIDPipeline.load_pretrain expects a "version" string used to construct
# the default filename when pretrain_path is None. We pass the file
# directly so the version string is informational only.
pulid.load_pretrain(pretrain_path=pulid_weights, version="v0.9.1")
print(f"extracting ID embedding from {portrait_path}", flush=True)
face_img = np.array(Image.open(portrait_path).convert("RGB"))
id_embedding, _ = pulid.get_id_embedding(face_img)
print(f"id embedding shape={tuple(id_embedding.shape)} dtype={id_embedding.dtype}",
flush=True)
if id_embedding.ndim == 3 and id_embedding.shape[0] == 1:
id_embedding = id_embedding[0]
return id_embedding
def write_embd(tensor, out_path: str, dtype_choice: str) -> None:
import gguf
import torch
if tensor.ndim != 2:
raise ValueError(f"expected (num_tokens, token_dim); got {tuple(tensor.shape)}")
num_tokens, token_dim = tensor.shape
os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
# The embedding ships as a standard gguf container holding a single tensor
# named "pulid_id". numpy is row-major (num_tokens, token_dim); gguf stores
# dims reversed, so stable-diffusion.cpp reads it back as
# ne[0]=token_dim, ne[1]=num_tokens (see load_pulid_id_embedding).
writer = gguf.GGUFWriter(out_path, arch="pulid")
writer.add_uint32("pulid.version", 1)
if dtype_choice == "fp16":
arr = tensor.to(torch.float16).contiguous().cpu().numpy()
writer.add_tensor("pulid_id", arr)
elif dtype_choice == "fp32":
arr = tensor.to(torch.float32).contiguous().cpu().numpy()
writer.add_tensor("pulid_id", arr)
elif dtype_choice == "bf16":
raw = tensor.to(torch.bfloat16).contiguous().view(torch.uint16).cpu().numpy()
writer.add_tensor("pulid_id", raw,
raw_shape=(int(num_tokens), int(token_dim)),
raw_dtype=gguf.GGMLQuantizationType.BF16)
else:
raise ValueError(f"unknown --dtype {dtype_choice}")
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_tensors_to_file()
writer.close()
print(f"wrote {out_path}: gguf, tensor pulid_id [{token_dim}, {num_tokens}] {dtype_choice}",
flush=True)
def main() -> int:
ap = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter)
ap.add_argument("--portrait", required=True,
help="Path to the source portrait image (JPG/PNG).")
ap.add_argument("--pulid-weights", required=True,
help="Path to pulid_flux_v0.9.x.safetensors.")
ap.add_argument("--out", required=True,
help="Output path for the .pulidembd binary.")
ap.add_argument("--dtype", default="fp16",
choices=["fp16", "bf16", "fp32"],
help="Storage dtype (default fp16; produces ~131 KB).")
args = ap.parse_args()
if not os.path.exists(args.portrait):
print(f"ERROR: portrait not found at {args.portrait}", file=sys.stderr)
return 2
if not os.path.exists(args.pulid_weights):
print(f"ERROR: PuLID weights not found at {args.pulid_weights}", file=sys.stderr)
return 3
embedding = extract(args.portrait, args.pulid_weights)
write_embd(embedding, args.out, args.dtype)
return 0
if __name__ == "__main__":
raise SystemExit(main())


@ -10,6 +10,7 @@
#include "conditioning/conditioner.hpp"
#include "core/ggml_extend_backend.h"
#include "model/diffusion/model.hpp"
#include "model_loader.h"
#include "model_manager.h"
#include "stable-diffusion.h"
@ -30,6 +31,7 @@ struct GenerationExtensionConditionContext {
Conditioner* conditioner;
ConditionerParams& condition_params;
const sd_pm_params_t& pm_params;
const sd_pulid_params_t& pulid_params;
int n_threads;
int total_steps;
};
@ -56,8 +58,20 @@ struct GenerationExtension {
const SDCondition& condition) const {
return condition;
}
// Called in the denoise loop for each enabled extension, after the per-step
// DiffusionParams (including its version-specific `extra`) has been built,
// but before diffusion_model->compute(). Lets an extension feed data into
// the diffusion forward that the conditioning-side hooks can't reach -- it
// can set/override fields on `params` (typically the architecture-specific
// `params.extra`, e.g. a guidance tensor, control payload, or an identity
// embedding for an adapter that injects inside the model's blocks). The
// extension targets whichever `extra` variant matches the active model.
// Mutates `params` only, never the extension. Default no-op.
virtual void before_diffusion(DiffusionParams& /*params*/, int /*step*/) const {}
};
std::shared_ptr<GenerationExtension> create_photomaker_extension();
std::shared_ptr<GenerationExtension> create_pulid_extension();
#endif


@ -0,0 +1,143 @@
#include "extensions/generation_extension.h"
#include <cstring>
#include <variant>
#include "core/tensor_ggml.hpp"
#include "core/util.h"
#include "gguf.h"
// Load the precomputed PuLID identity embedding produced by
// scripts/pulid_extract_id.py into a sd::Tensor<float> (always materialized as
// fp32 for the diffusion path). Returns an empty tensor on any failure (the
// caller treats empty as "PuLID off").
//
// The file is a standard gguf container holding a single tensor named
// "pulid_id" with shape [token_dim, num_tokens] (ggml order; typically
// [2048, 32]) in f16 / bf16 / f32. Using gguf rather than a bespoke header
// means the shape + dtype are self-describing and we reuse ggml's reader.
static sd::Tensor<float> load_pulid_id_embedding(const char* path) {
sd::Tensor<float> empty;
if (path == nullptr || strlen(path) == 0) {
return empty;
}
struct ggml_context* ctx_data = nullptr;
struct gguf_init_params gp = {/*.no_alloc =*/false, /*.ctx =*/&ctx_data};
struct gguf_context* gguf_ctx = gguf_init_from_file(path, gp);
if (gguf_ctx == nullptr || ctx_data == nullptr) {
LOG_WARN("PuLID id-embedding: cannot read gguf '%s'", path);
if (gguf_ctx != nullptr)
gguf_free(gguf_ctx);
if (ctx_data != nullptr)
ggml_free(ctx_data);
return empty;
}
struct ggml_tensor* t = ggml_get_tensor(ctx_data, "pulid_id");
if (t == nullptr) {
LOG_WARN("PuLID id-embedding: no 'pulid_id' tensor in '%s'", path);
gguf_free(gguf_ctx);
ggml_free(ctx_data);
return empty;
}
const int64_t token_dim = t->ne[0];
const int64_t num_tokens = t->ne[1];
if (token_dim <= 0 || num_tokens <= 0 || token_dim > 65536 || num_tokens > 1024 ||
t->ne[2] != 1 || t->ne[3] != 1) {
LOG_WARN("PuLID id-embedding: implausible shape [%lld, %lld] in '%s'",
(long long)token_dim, (long long)num_tokens, path);
gguf_free(gguf_ctx);
ggml_free(ctx_data);
return empty;
}
const size_t n_elem = (size_t)token_dim * (size_t)num_tokens;
sd::Tensor<float> out({token_dim, num_tokens, 1});
float* dst = out.data();
if (t->type == GGML_TYPE_F32) {
memcpy(dst, t->data, n_elem * sizeof(float));
} else if (t->type == GGML_TYPE_F16) {
const ggml_fp16_t* src = reinterpret_cast<const ggml_fp16_t*>(t->data);
for (size_t i = 0; i < n_elem; i++) {
dst[i] = ggml_fp16_to_fp32(src[i]);
}
} else if (t->type == GGML_TYPE_BF16) {
const ggml_bf16_t* src = reinterpret_cast<const ggml_bf16_t*>(t->data);
for (size_t i = 0; i < n_elem; i++) {
dst[i] = ggml_bf16_to_fp32(src[i]);
}
} else {
LOG_WARN("PuLID id-embedding: unsupported tensor type %s in '%s'",
ggml_type_name(t->type), path);
gguf_free(gguf_ctx);
ggml_free(ctx_data);
return empty;
}
LOG_INFO("PuLID id-embedding: loaded [%lld, %lld] type=%s from '%s'",
(long long)token_dim, (long long)num_tokens, ggml_type_name(t->type), path);
gguf_free(gguf_ctx);
ggml_free(ctx_data);
return out;
}
// PuLID-Flux identity injection as a generation extension.
//
// Unlike PhotoMaker, PuLID does NOT modify the conditioning -- it injects an
// identity embedding via cross-attention *inside* the Flux denoise forward (the
// pulid_ca.* blocks). Those cross-attention weights are part of the Flux
// diffusion model and are loaded into the model tensor map before the model is
// constructed (see SDImpl ctor, gated on sd_ctx_params.pulid_weights_path), so
// this extension does not own a separate model. Its job is purely runtime:
// - prepare_condition: load the per-generation id-embedding file.
// - before_diffusion: hand that embedding (+ weight) to FluxDiffusionExtra,
// which flux.hpp reads to drive the pulid_ca injection.
struct PuLIDExtension : public GenerationExtension {
bool enabled = false;
sd::Tensor<float> id_embedding; // per-generation; empty when PuLID is off for this request
float id_weight = 1.0f;
const char* name() const override {
return "pulid";
}
bool is_enabled() const override {
return enabled;
}
bool init(const GenerationExtensionInitContext& ctx) override {
enabled = strlen(SAFE_STR(ctx.params->pulid_weights_path)) > 0;
return true;
}
void reset_runtime_condition() override {
id_embedding = {};
id_weight = 1.0f;
}
bool prepare_condition(GenerationExtensionConditionContext& ctx) override {
reset_runtime_condition();
if (!enabled) {
return false;
}
id_embedding = load_pulid_id_embedding(ctx.pulid_params.id_embedding_path);
id_weight = ctx.pulid_params.id_weight;
return false; // PuLID does not modify the conditioning
}
void before_diffusion(DiffusionParams& params, int /*step*/) const override {
if (!enabled || id_embedding.empty()) {
return;
}
if (auto* flux_extra = std::get_if<FluxDiffusionExtra>(&params.extra)) {
flux_extra->pulid_id = &id_embedding;
flux_extra->pulid_id_weight = id_weight;
}
}
};
std::shared_ptr<GenerationExtension> create_pulid_extension() {
return std::make_shared<PuLIDExtension>();
}

src/model/adapter/pulid.hpp Normal file

@ -0,0 +1,130 @@
#ifndef __PULID_HPP__
#define __PULID_HPP__
#include "core/ggml_extend.hpp"
#include "model/common/block.hpp"
/**
* PuLID-Flux identity injection for stable-diffusion.cpp.
*
* Mirrors the PerceiverAttentionCA module from
* https://github.com/ToTheBeginning/PuLID/blob/main/pulid/encoders_transformer.py
*
* Each instance is a cross-attention layer where:
* Q comes from image tokens (dim = 3072 = Flux hidden_size)
* K, V come from a precomputed ID embedding (kv_dim = 2048, num_tokens = 32)
*
* 20 instances are inserted into the Flux denoise loop at fixed intervals:
* - Every 2nd of the 19 double_blocks: ceil(19 / 2) = 10 hook points
* - Every 4th of the 38 single_blocks: ceil(38 / 4) = 10 hook points
*
* Weight key prefix in pulid_flux_v0.9.1.safetensors:
* pulid_ca.<i>.norm1.{weight,bias}
* pulid_ca.<i>.norm2.{weight,bias}
* pulid_ca.<i>.to_q.weight
* pulid_ca.<i>.to_kv.weight
* pulid_ca.<i>.to_out.weight
*
* Pure-ggml implementation: all ops have Vulkan / CUDA / Metal kernels in
* the upstream ggml backends, so this works cross-vendor by construction.
*/
class PuLIDPerceiverAttentionCA : public GGMLBlock {
public:
static constexpr int64_t DEFAULT_DIM = 3072; // Flux hidden size
static constexpr int64_t DEFAULT_DIM_HEAD = 128;
static constexpr int64_t DEFAULT_HEADS = 16;
static constexpr int64_t DEFAULT_KV_DIM = 2048; // PuLID ID-embedding dim
protected:
int64_t dim;
int64_t dim_head;
int64_t heads;
int64_t kv_dim;
int64_t inner_dim; // dim_head * heads = 2048
public:
PuLIDPerceiverAttentionCA(int64_t dim = DEFAULT_DIM,
int64_t dim_head = DEFAULT_DIM_HEAD,
int64_t heads = DEFAULT_HEADS,
int64_t kv_dim = DEFAULT_KV_DIM)
: dim(dim),
dim_head(dim_head),
heads(heads),
kv_dim(kv_dim),
inner_dim(dim_head * heads) {
// Note the PyTorch reference's surprising signature:
// norm1 operates on x (the id_embedding side, kv_dim wide)
// norm2 operates on latents (the image tokens, dim wide)
// to_q consumes latents (dim -> inner_dim)
// to_kv consumes x (kv_dim -> 2*inner_dim)
// to_out projects (inner_dim -> dim)
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(kv_dim));
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
blocks["to_q"] = std::shared_ptr<GGMLBlock>(new Linear(dim, inner_dim, /*bias=*/false));
blocks["to_kv"] = std::shared_ptr<GGMLBlock>(new Linear(kv_dim, inner_dim * 2, /*bias=*/false));
blocks["to_out"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim, /*bias=*/false));
}
/**
* Compute: residual_to_image = PerceiverAttentionCA(id_embedding, image_tokens)
*
* Inputs:
* id_embedding [N, n_id_tokens=32, kv_dim=2048]
* image_tokens [N, n_img_tokens, dim=3072]
*
* Returns:
* [N, n_img_tokens, dim=3072] -- to be added to image_tokens by the caller,
* scaled by id_weight.
*/
ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* id_embedding,
ggml_tensor* image_tokens) {
auto norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
auto to_q = std::dynamic_pointer_cast<Linear>(blocks["to_q"]);
auto to_kv = std::dynamic_pointer_cast<Linear>(blocks["to_kv"]);
auto to_out = std::dynamic_pointer_cast<Linear>(blocks["to_out"]);
// Normalize each input on its own dim. The PyTorch reference normalizes
// x (id_embedding) and `latents` (image_tokens) separately, then uses
// latents for Q and x for K/V -- mind the unusual cross-attention shape.
ggml_tensor* x_normed = norm1->forward(ctx, id_embedding); // [N, 32, 2048]
ggml_tensor* lat_normed = norm2->forward(ctx, image_tokens); // [N, T_img, 3072]
// Projections. to_q : 3072 -> 2048 ; to_kv : 2048 -> 4096 (k concat v).
ggml_tensor* q = to_q->forward(ctx, lat_normed); // [N, T_img, 2048]
ggml_tensor* kv = to_kv->forward(ctx, x_normed); // [N, 32, 4096]
// Split KV into K (first inner_dim of last axis) and V (second
// inner_dim). ggml_view_3d gives strided views without copying;
// ggml_cont materializes them so ggml_ext_attention_ext sees
// contiguous tensors.
ggml_tensor* k = ggml_view_3d(ctx->ggml_ctx, kv,
inner_dim, kv->ne[1], kv->ne[2],
kv->nb[1], kv->nb[2],
/*offset=*/0); // [N, 32, 2048]
ggml_tensor* v = ggml_view_3d(ctx->ggml_ctx, kv,
inner_dim, kv->ne[1], kv->ne[2],
kv->nb[1], kv->nb[2],
/*offset=*/inner_dim * ggml_element_size(kv)); // [N, 32, 2048]
k = ggml_cont(ctx->ggml_ctx, k);
v = ggml_cont(ctx->ggml_ctx, v);
// Standard multi-head attention. ggml_ext_attention_ext expects
// [N, n_token, embed_dim] and reshapes into heads internally.
// n_head = heads (=16), per-head dim = inner_dim / heads (=128).
ggml_tensor* attn_out = ggml_ext_attention_ext(
ctx->ggml_ctx, ctx->backend,
q, k, v,
heads,
/*mask=*/nullptr,
/*diag_mask_inf=*/false); // [N, T_img, inner_dim=2048]
// Project back to image-token width (3072).
ggml_tensor* out = to_out->forward(ctx, attn_out); // [N, T_img, 3072]
return out;
}
};
#endif // __PULID_HPP__


@ -4,6 +4,7 @@
#include <memory>
#include <vector>
#include "model/adapter/pulid.hpp"
#include "model/common/rope.hpp"
#include "model/diffusion/dit.hpp"
#include "model/diffusion/model.hpp"
@ -49,6 +50,13 @@ namespace Flux {
float ref_index_scale = 1.f;
ChromaRadianceConfig chroma_radiance_params;
// PuLID-Flux identity injection. Set by detect_from_weights() when any
// pulid_ca.* key is present, i.e. when a --pulid-weights file was merged
// into the tensor map. The intervals are fixed by the PuLID v0.9.1
// architecture (every 2nd double, every 4th single).
bool pulid_enabled = false;
int pulid_double_interval = 2;
int pulid_single_interval = 4;
static FluxConfig detect_from_weights(const String2TensorStorage& tensor_storage_map,
const std::string& prefix,
SDVersion version = VERSION_FLUX) {
@ -138,6 +146,13 @@ namespace Flux {
if (ends_with(name, "double_blocks.0.txt_attn.norm.key_norm.scale")) {
head_dim = tensor_storage.ne[0];
}
// PuLID weights live alongside the diffusion model under the same
// prefix (pulid_ca.<i>.<sub>) when the pulid loader merges them in
// (see stable-diffusion.cpp). Spotting any pulid_ca.* key flips the
// flag so the Flux ctor builds the pulid_ca.<i> child blocks.
if (name.find("pulid_ca.") != std::string::npos) {
config.pulid_enabled = true;
}
}
if (actual_radiance_patch_size > 0 && actual_radiance_patch_size != config.patch_size) {
GGML_ASSERT(config.patch_size == 2 * actual_radiance_patch_size);
@ -957,6 +972,29 @@ namespace Flux {
blocks["double_stream_modulation_txt"] = std::make_shared<Modulation>(config.hidden_size, true, !config.disable_bias);
blocks["single_stream_modulation"] = std::make_shared<Modulation>(config.hidden_size, false, !config.disable_bias);
}
// PuLID-Flux identity-injection cross-attention modules. Only constructed
// when config.pulid_enabled is set (detect_from_weights flips it when the
// merged --pulid-weights tensors contain pulid_ca.* keys). Counts follow
// PuLID v0.9.1's pipeline_flux.py: one module per `pulid_double_interval`
// double blocks (=2) and one per `pulid_single_interval` single blocks
// (=4), with ceil-rounding. For a stock Flux Dev (depth=19,
// depth_single_blocks=38) that is ceil(19/2) + ceil(38/4) = 10 + 10 = 20,
// matching the 20 entries in the PuLID v0.9.1 trained weights.
if (config.pulid_enabled) {
int num_double_ca = (config.depth + config.pulid_double_interval - 1) / config.pulid_double_interval;
int num_single_ca = (config.depth_single_blocks + config.pulid_single_interval - 1) / config.pulid_single_interval;
int num_ca = num_double_ca + num_single_ca;
for (int i = 0; i < num_ca; i++) {
blocks["pulid_ca." + std::to_string(i)] =
std::shared_ptr<GGMLBlock>(new PuLIDPerceiverAttentionCA(
/*dim=*/ config.hidden_size,
/*dim_head=*/PuLIDPerceiverAttentionCA::DEFAULT_DIM_HEAD,
/*heads=*/ PuLIDPerceiverAttentionCA::DEFAULT_HEADS,
/*kv_dim=*/ PuLIDPerceiverAttentionCA::DEFAULT_KV_DIM));
}
}
}
ggml_tensor* forward_orig(GGMLRunnerContext* ctx,
@ -967,7 +1005,9 @@ namespace Flux {
ggml_tensor* guidance,
ggml_tensor* pe,
ggml_tensor* mod_index_arange = nullptr,
std::vector<int> skip_layers = {}) {
std::vector<int> skip_layers = {},
ggml_tensor* pulid_id = nullptr,
float pulid_id_weight = 1.0f) {
auto img_in = std::dynamic_pointer_cast<Linear>(blocks["img_in"]);
auto txt_in = std::dynamic_pointer_cast<Linear>(blocks["txt_in"]);
auto final_layer = std::dynamic_pointer_cast<LastLayer>(blocks["final_layer"]);
@ -1044,6 +1084,23 @@ namespace Flux {
sd::ggml_graph_cut::mark_graph_cut(txt, "flux.prelude", "txt");
sd::ggml_graph_cut::mark_graph_cut(vec, "flux.prelude", "vec");
// PuLID identity injection: mirrors ToTheBeginning/PuLID
// pulid/encoders_transformer.py + flux/model.py. The CA layers
// run *between* transformer blocks, with their output added to
// img (scaled by id_weight) at every `pulid_double_interval`-th
// double_block and every `pulid_single_interval`-th single_block.
//
// skip_layers + PuLID is NOT a supported combination -- skipping
// a block at a PuLID-aligned index would either misalign the
// ca_idx assignment (silent quality regression) or require us
// to invent a non-reference index policy. Warn and disable PuLID
// for this generation instead.
const bool pulid_active = config.pulid_enabled && pulid_id != nullptr;
if (pulid_active && !skip_layers.empty()) {
LOG_WARN("PuLID + skip_layers is not supported; disabling PuLID for this generation.");
}
const bool pulid_run = pulid_active && skip_layers.empty();
int ca_idx = 0;
for (int i = 0; i < config.depth; i++) {
if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i) != skip_layers.end()) {
continue;
@ -1056,9 +1113,19 @@ namespace Flux {
txt = img_txt.second; // [N, n_txt_token, hidden_size]
sd::ggml_graph_cut::mark_graph_cut(img, "flux.double_blocks." + std::to_string(i), "img");
sd::ggml_graph_cut::mark_graph_cut(txt, "flux.double_blocks." + std::to_string(i), "txt");
if (pulid_run && (i % config.pulid_double_interval == 0)) {
auto pulid_ca = std::dynamic_pointer_cast<PuLIDPerceiverAttentionCA>(
blocks["pulid_ca." + std::to_string(ca_idx)]);
ggml_tensor* ca_out = pulid_ca->forward(ctx, pulid_id, img); // [N, n_img_token, hidden_size]
img = ggml_add(ctx->ggml_ctx, img, ggml_scale(ctx->ggml_ctx, ca_out, pulid_id_weight));
sd::ggml_graph_cut::mark_graph_cut(img, "flux.pulid_ca." + std::to_string(ca_idx), "img");
ca_idx++;
}
}
auto txt_img = ggml_concat(ctx->ggml_ctx, txt, img, 1); // [N, n_txt_token + n_img_token, hidden_size]
const int64_t n_txt_tok = txt->ne[1]; // for splitting back into img portion below
for (int i = 0; i < config.depth_single_blocks; i++) {
if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i + config.depth) != skip_layers.end()) {
continue;
@ -1067,6 +1134,31 @@ namespace Flux {
txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask, ss_mods);
sd::ggml_graph_cut::mark_graph_cut(txt_img, "flux.single_blocks." + std::to_string(i), "txt_img");
if (pulid_run && (i % config.pulid_single_interval == 0)) {
auto pulid_ca = std::dynamic_pointer_cast<PuLIDPerceiverAttentionCA>(
blocks["pulid_ca." + std::to_string(ca_idx)]);
// Split txt_img into [txt | img], inject ID into the img portion
// only, then concatenate back. Matches the PyTorch reference.
ggml_tensor* txt_part = ggml_view_3d(ctx->ggml_ctx, txt_img,
txt_img->ne[0], n_txt_tok, txt_img->ne[2],
txt_img->nb[1], txt_img->nb[2],
0);
ggml_tensor* img_part = ggml_view_3d(ctx->ggml_ctx, txt_img,
txt_img->ne[0],
txt_img->ne[1] - n_txt_tok,
txt_img->ne[2],
txt_img->nb[1],
txt_img->nb[2],
n_txt_tok * txt_img->nb[1]);
txt_part = ggml_cont(ctx->ggml_ctx, txt_part);
img_part = ggml_cont(ctx->ggml_ctx, img_part);
ggml_tensor* ca_out = pulid_ca->forward(ctx, pulid_id, img_part);
img_part = ggml_add(ctx->ggml_ctx, img_part, ggml_scale(ctx->ggml_ctx, ca_out, pulid_id_weight));
txt_img = ggml_concat(ctx->ggml_ctx, txt_part, img_part, 1);
sd::ggml_graph_cut::mark_graph_cut(txt_img, "flux.pulid_ca." + std::to_string(ca_idx), "txt_img");
ca_idx++;
}
}
img = ggml_view_3d(ctx->ggml_ctx,
@ -1105,7 +1197,9 @@ namespace Flux {
ggml_tensor* mod_index_arange = nullptr,
ggml_tensor* dct = nullptr,
std::vector<ggml_tensor*> ref_latents = {},
std::vector<int> skip_layers = {}) {
std::vector<int> skip_layers = {},
ggml_tensor* pulid_id = nullptr,
float pulid_id_weight = 1.0f) {
GGML_ASSERT(x->ne[3] == 1);
int64_t W = x->ne[0];
@ -1131,7 +1225,8 @@ namespace Flux {
img = ggml_reshape_3d(ctx->ggml_ctx, img, img->ne[0] * img->ne[1], img->ne[2], img->ne[3]); // [N, hidden_size, H/patch_size*W/patch_size]
img = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, img, 1, 0, 2, 3)); // [N, H/patch_size*W/patch_size, hidden_size]
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, n_img_token, hidden_size]
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers,
pulid_id, pulid_id_weight); // [N, n_img_token, hidden_size]
// nerf decode
auto nerf_image_embedder = std::dynamic_pointer_cast<NerfEmbedder>(blocks["nerf_image_embedder"]);
@ -1179,7 +1274,9 @@ namespace Flux {
ggml_tensor* mod_index_arange = nullptr,
ggml_tensor* dct = nullptr,
std::vector<ggml_tensor*> ref_latents = {},
std::vector<int> skip_layers = {}) {
std::vector<int> skip_layers = {},
ggml_tensor* pulid_id = nullptr,
float pulid_id_weight = 1.0f) {
GGML_ASSERT(x->ne[3] == 1);
int64_t W = x->ne[0];
@ -1226,7 +1323,8 @@ namespace Flux {
}
}
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size]
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers,
pulid_id, pulid_id_weight); // [N, num_tokens, C * patch_size * patch_size]
if (out->ne[1] > img_tokens) {
out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], img_tokens, out->ne[2], out->nb[1], out->nb[2], 0);
@ -1248,7 +1346,9 @@ namespace Flux {
ggml_tensor* mod_index_arange = nullptr,
ggml_tensor* dct = nullptr,
std::vector<ggml_tensor*> ref_latents = {},
std::vector<int> skip_layers = {}) {
std::vector<int> skip_layers = {},
ggml_tensor* pulid_id = nullptr,
float pulid_id_weight = 1.0f) {
// Forward pass of DiT.
// x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
// timestep: (N,) tensor of diffusion timesteps
@ -1271,7 +1371,9 @@ namespace Flux {
mod_index_arange,
dct,
ref_latents,
skip_layers);
skip_layers,
pulid_id,
pulid_id_weight);
} else {
return forward_flux_chroma(ctx,
x,
@ -1284,7 +1386,9 @@ namespace Flux {
mod_index_arange,
dct,
ref_latents,
skip_layers);
skip_layers,
pulid_id,
pulid_id_weight);
}
}
};
@ -1384,7 +1488,9 @@ namespace Flux {
const sd::Tensor<float>& guidance_tensor = {},
const std::vector<sd::Tensor<float>>& ref_latents_tensor = {},
bool increase_ref_index = false,
std::vector<int> skip_layers = {}) {
std::vector<int> skip_layers = {},
const sd::Tensor<float>& pulid_id_tensor = {},
float pulid_id_weight = 1.0f) {
ggml_tensor* x = make_input(x_tensor);
ggml_tensor* timesteps = make_input(timesteps_tensor);
ggml_tensor* context = make_optional_input(context_tensor);
@ -1461,6 +1567,13 @@ namespace Flux {
set_backend_tensor_data(dct, dct_vec.data());
}
// Materialize the PuLID id embedding into the compute graph when
// pulid_id_tensor is non-empty. forward() accepts nullptr for the
// no-injection case.
ggml_tensor* pulid_id = pulid_id_tensor.empty()
? nullptr
: make_input(pulid_id_tensor);
auto runner_ctx = get_context();
ggml_tensor* out = flux.forward(&runner_ctx,
@ -1474,7 +1587,9 @@ namespace Flux {
mod_index_arange,
dct,
ref_latents,
skip_layers);
skip_layers,
pulid_id,
pulid_id_weight);
ggml_build_forward_expand(gf, out);
@ -1490,14 +1605,17 @@ namespace Flux {
const sd::Tensor<float>& guidance = {},
const std::vector<sd::Tensor<float>>& ref_latents = {},
bool increase_ref_index = false,
std::vector<int> skip_layers = std::vector<int>()) {
std::vector<int> skip_layers = std::vector<int>(),
const sd::Tensor<float>& pulid_id = {},
float pulid_id_weight = 1.0f) {
// x: [N, in_channels, h, w]
// timesteps: [N, ]
// context: [N, max_position, hidden_size]
// y: [N, adm_in_channels] or [1, adm_in_channels]
// guidance: [N, ]
// pulid_id: empty (no injection) or [N, num_id_tokens=32, kv_dim=2048]
auto get_graph = [&]() -> ggml_cgraph* {
return build_graph(x, timesteps, context, c_concat, y, guidance, ref_latents, increase_ref_index, skip_layers);
return build_graph(x, timesteps, context, c_concat, y, guidance, ref_latents, increase_ref_index, skip_layers, pulid_id, pulid_id_weight);
};
auto result = restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false, false, false), x.dim());
@ -1520,7 +1638,9 @@ namespace Flux {
tensor_or_empty(extra->guidance),
diffusion_params.ref_latents ? *diffusion_params.ref_latents : empty_ref_latents,
diffusion_params.increase_ref_index,
extra->skip_layers ? *extra->skip_layers : empty_skip_layers);
extra->skip_layers ? *extra->skip_layers : empty_skip_layers,
tensor_or_empty(extra->pulid_id),
extra->pulid_id_weight);
}
void test() {

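The hook-point count described in the constructor comment can be checked with a small standalone sketch. It assumes only the `i % interval == 0` test from the `forward_orig` loops and the ceil-division used to size the `pulid_ca.<i>` list; `pulid_hook_indices` and `ceil_div` are hypothetical helper names, not functions in stable-diffusion.cpp.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper (not part of stable-diffusion.cpp): list the block
// indices after which a PuLID cross-attention module runs, using the same
// `i % interval == 0` test as the loops in forward_orig.
std::vector<int> pulid_hook_indices(int depth, int interval) {
    std::vector<int> hooks;
    for (int i = 0; i < depth; i++) {
        if (i % interval == 0) {
            hooks.push_back(i);
        }
    }
    return hooks;
}

// Ceil division, mirroring how the constructor sizes the pulid_ca.<i> list.
int ceil_div(int n, int d) {
    return (n + d - 1) / d;
}
```

For a stock Flux Dev, `pulid_hook_indices(19, 2)` and `pulid_hook_indices(38, 4)` each have as many entries as `ceil_div` predicts, so the running `ca_idx` should never exceed the number of constructed modules as long as no block is skipped.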

@ -22,6 +22,11 @@ struct SkipLayerDiffusionExtra {
struct FluxDiffusionExtra {
const sd::Tensor<float>* guidance = nullptr;
const std::vector<int>* skip_layers = nullptr;
// PuLID-Flux: precomputed (N=1, num_tokens=32, kv_dim=2048) identity embedding
// produced by runtime-scripts/pulid_extract_id.py. nullptr when PuLID is
// disabled. id_weight is per-job (typical 0.7-1.2; default 1.0).
const sd::Tensor<float>* pulid_id = nullptr;
float pulid_id_weight = 1.0f;
};
struct AnimaDiffusionExtra {

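As a rough illustration of how `pulid_id_weight` is applied in the single-stream path, the sketch below adds a scaled cross-attention output to the image rows of a concatenated `[txt | img]` buffer and leaves the text rows alone. It uses a flat `std::vector<float>` in place of ggml tensors, and `inject_identity` is a hypothetical name used only here.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the ggml graph ops: tokens are stored row-major
// as [n_tokens, hidden] in a flat vector, with the text tokens first.
// Adds id_weight * ca_out to the image rows only; text rows are untouched.
void inject_identity(std::vector<float>& txt_img,
                     const std::vector<float>& ca_out,
                     std::size_t n_txt_tok,
                     std::size_t hidden,
                     float id_weight) {
    const std::size_t img_offset = n_txt_tok * hidden;
    assert(txt_img.size() >= img_offset);
    assert(ca_out.size() == txt_img.size() - img_offset);
    for (std::size_t j = 0; j < ca_out.size(); j++) {
        txt_img[img_offset + j] += id_weight * ca_out[j];
    }
}
```

With `id_weight = 0` the buffer is left unchanged, which is one way to sanity-check that the injection is a pure residual on the image tokens.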

@ -428,6 +428,22 @@ public:
}
}
if (strlen(SAFE_STR(sd_ctx_params->pulid_weights_path)) > 0) {
LOG_INFO("loading PuLID weights from '%s'", sd_ctx_params->pulid_weights_path);
// PuLID's cross-attention (pulid_ca.*) weights are part of the Flux
// diffusion model -- its blocks are constructed inside FluxModel when
// the tensor map contains pulid_ca.* keys. So they must be merged into
// the model loader here, BEFORE the diffusion model is built; that is
// why this stays in the ctor rather than in the pulid generation
// extension (whose init runs after model construction). The runtime
// side -- per-generation id-embedding + per-step injection -- lives in
// src/extensions/pulid_extension.cpp.
if (!model_loader.init_from_file(sd_ctx_params->pulid_weights_path,
"model.diffusion_model.")) {
LOG_WARN("loading PuLID weights from '%s' failed", sd_ctx_params->pulid_weights_path);
}
}
if (strlen(SAFE_STR(sd_ctx_params->llm_path)) > 0) {
LOG_INFO("loading llm from '%s'", sd_ctx_params->llm_path);
if (!model_loader.init_from_file(sd_ctx_params->llm_path, "text_encoders.llm.")) {
@ -1012,6 +1028,14 @@ public:
if (photomaker_extension->is_enabled()) {
generation_extensions.push_back(photomaker_extension);
}
auto pulid_extension = create_pulid_extension();
if (!pulid_extension->init(extension_ctx)) {
return false;
}
if (pulid_extension->is_enabled()) {
generation_extensions.push_back(pulid_extension);
}
}
for (auto& extension : generation_extensions) {
if (!register_runner_params(extension->name(),
@ -1522,6 +1546,7 @@ public:
}
void prepare_generation_extensions(const sd_pm_params_t& pm_params,
const sd_pulid_params_t& pulid_params,
ConditionerParams& condition_params,
int total_steps) {
reset_generation_extensions();
@ -1529,6 +1554,7 @@ public:
cond_stage_model.get(),
condition_params,
pm_params,
pulid_params,
n_threads,
total_steps,
};
@ -2043,6 +2069,10 @@ public:
return std::move(cached_output);
}
for (const auto& extension : generation_extensions) {
extension->before_diffusion(diffusion_params, step);
}
auto output_opt = work_diffusion_model->compute(n_threads, diffusion_params);
if (output_opt.empty()) {
LOG_ERROR("diffusion model compute failed");
@ -2642,6 +2672,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
sd_ctx_params->backend = nullptr;
sd_ctx_params->params_backend = nullptr;
sd_ctx_params->rpc_servers = nullptr;
sd_ctx_params->pulid_weights_path = nullptr;
}
char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
@ -2667,6 +2698,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
"taesd_path: %s\n"
"control_net_path: %s\n"
"photo_maker_path: %s\n"
"pulid_weights_path: %s\n"
"tensor_type_rules: %s\n"
"n_threads: %d\n"
"wtype: %s\n"
@ -2701,6 +2733,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
SAFE_STR(sd_ctx_params->taesd_path),
SAFE_STR(sd_ctx_params->control_net_path),
SAFE_STR(sd_ctx_params->photo_maker_path),
SAFE_STR(sd_ctx_params->pulid_weights_path),
SAFE_STR(sd_ctx_params->tensor_type_rules),
sd_ctx_params->n_threads,
sd_type_name(sd_ctx_params->wtype),
@ -2795,6 +2828,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) {
sd_img_gen_params->batch_count = 1;
sd_img_gen_params->control_strength = 0.9f;
sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f};
sd_img_gen_params->pulid_params = {nullptr, 1.0f};
sd_img_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr};
sd_cache_params_init(&sd_img_gen_params->cache);
sd_hires_params_init(&sd_img_gen_params->hires);
@ -3096,6 +3130,7 @@ struct GenerationRequest {
sd_guidance_params_t guidance = {};
sd_guidance_params_t high_noise_guidance = {};
sd_pm_params_t pm_params = {};
sd_pulid_params_t pulid_params = {};
sd_hires_params_t hires = {};
int frames = -1;
int requested_frames = -1;
@ -3121,6 +3156,7 @@ struct GenerationRequest {
has_ref_images = sd_img_gen_params->ref_images_count > 0;
guidance = sd_img_gen_params->sample_params.guidance;
pm_params = sd_img_gen_params->pm_params;
pulid_params = sd_img_gen_params->pulid_params;
hires = sd_img_gen_params->hires;
cache_params = &sd_img_gen_params->cache;
resolve(sd_ctx);
@ -4047,6 +4083,7 @@ static std::optional<ImageGenerationEmbeds> prepare_image_generation_embeds(sd_c
condition_params.ref_images = &latents->ref_images;
sd_ctx->sd->prepare_generation_extensions(request->pm_params,
request->pulid_params,
condition_params,
plan->total_steps);
int64_t prepare_start_ms = ggml_time_ms();