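Add support for ref_clip_feats: compute a CLIP vision output for each reference image and pass it through DiffusionParams to the ZImageModel (for z-image-omni)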

This commit is contained in:
leejet 2025-12-26 00:10:22 +08:00
parent 190c523cec
commit 3d565871d9
2 changed files with 18 additions and 6 deletions

View File

@ -24,6 +24,7 @@ struct DiffusionParams {
float vace_strength = 1.f; float vace_strength = 1.f;
std::vector<int> skip_layers = {}; std::vector<int> skip_layers = {};
std::vector<struct ggml_tensor*> extra_contexts; // for z-image-omni std::vector<struct ggml_tensor*> extra_contexts; // for z-image-omni
std::vector<struct ggml_tensor*> ref_clip_feats; // for z-image-omni
}; };
struct DiffusionModel { struct DiffusionModel {
@ -444,7 +445,7 @@ struct ZImageModel : public DiffusionModel {
diffusion_params.timesteps, diffusion_params.timesteps,
contexts, contexts,
diffusion_params.ref_latents, diffusion_params.ref_latents,
{}, diffusion_params.ref_clip_feats,
output, output,
output_ctx); output_ctx);
} }

View File

@ -1528,6 +1528,7 @@ public:
int start_merge_step, int start_merge_step,
SDCondition id_cond, SDCondition id_cond,
std::vector<ggml_tensor*> ref_latents = {}, std::vector<ggml_tensor*> ref_latents = {},
std::vector<ggml_tensor*> ref_clip_feats = {},
bool increase_ref_index = false, bool increase_ref_index = false,
ggml_tensor* denoise_mask = nullptr, ggml_tensor* denoise_mask = nullptr,
ggml_tensor* vace_context = nullptr, ggml_tensor* vace_context = nullptr,
@ -1921,6 +1922,7 @@ public:
diffusion_params.timesteps = timesteps; diffusion_params.timesteps = timesteps;
diffusion_params.guidance = guidance_tensor; diffusion_params.guidance = guidance_tensor;
diffusion_params.ref_latents = ref_latents; diffusion_params.ref_latents = ref_latents;
diffusion_params.ref_clip_feats = ref_clip_feats;
diffusion_params.increase_ref_index = increase_ref_index; diffusion_params.increase_ref_index = increase_ref_index;
diffusion_params.controls = controls; diffusion_params.controls = controls;
diffusion_params.control_strength = control_strength; diffusion_params.control_strength = control_strength;
@ -3103,6 +3105,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
sd_pm_params_t pm_params, sd_pm_params_t pm_params,
std::vector<sd_image_t*> ref_images, std::vector<sd_image_t*> ref_images,
std::vector<ggml_tensor*> ref_latents, std::vector<ggml_tensor*> ref_latents,
std::vector<ggml_tensor*> ref_clip_feats,
bool increase_ref_index, bool increase_ref_index,
ggml_tensor* concat_latent = nullptr, ggml_tensor* concat_latent = nullptr,
ggml_tensor* denoise_mask = nullptr, ggml_tensor* denoise_mask = nullptr,
@ -3391,6 +3394,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
start_merge_step, start_merge_step,
id_cond, id_cond,
ref_latents, ref_latents,
ref_clip_feats,
increase_ref_index, increase_ref_index,
denoise_mask, denoise_mask,
nullptr, nullptr,
@ -3657,6 +3661,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
} }
std::vector<ggml_tensor*> ref_latents; std::vector<ggml_tensor*> ref_latents;
std::vector<ggml_tensor*> ref_clip_feats;
for (int i = 0; i < ref_images.size(); i++) { for (int i = 0; i < ref_images.size(); i++) {
ggml_tensor* img; ggml_tensor* img;
if (sd_img_gen_params->auto_resize_ref_image) { if (sd_img_gen_params->auto_resize_ref_image) {
@ -3703,6 +3708,9 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
ggml_tensor* latent = sd_ctx->sd->encode_first_stage(work_ctx, img); ggml_tensor* latent = sd_ctx->sd->encode_first_stage(work_ctx, img);
ref_latents.push_back(latent); ref_latents.push_back(latent);
auto clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, *ref_images[i], false, -2);
ref_clip_feats.push_back(clip_vision_output);
} }
if (sd_img_gen_params->init_image.data != nullptr || sd_img_gen_params->ref_images_count > 0) { if (sd_img_gen_params->init_image.data != nullptr || sd_img_gen_params->ref_images_count > 0) {
@ -3730,6 +3738,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
sd_img_gen_params->pm_params, sd_img_gen_params->pm_params,
ref_images, ref_images,
ref_latents, ref_latents,
ref_clip_feats,
sd_img_gen_params->increase_ref_index, sd_img_gen_params->increase_ref_index,
concat_latent, concat_latent,
denoise_mask, denoise_mask,
@ -4098,8 +4107,9 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
high_noise_sample_method, high_noise_sample_method,
high_noise_sigmas, high_noise_sigmas,
-1, -1,
{}, {}, // id_cond
{}, {}, // ref_latents
{}, // ref_clip_feats
false, false,
denoise_mask, denoise_mask,
vace_context, vace_context,
@ -4135,8 +4145,9 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
sample_method, sample_method,
sigmas, sigmas,
-1, -1,
{}, {}, // id_cond
{}, {}, // ref_latents
{}, // ref_clip_feats
false, false,
denoise_mask, denoise_mask,
vace_context, vace_context,