Compare commits

..

No commits in common. "b0179181069254389ccad604e44f17a2c25b4094" and "abb115cd021fc2beed826604ed1a479b6a77671c" have entirely different histories.

View File

@ -344,6 +344,9 @@ public:
LOG_INFO("Using flash attention in the diffusion model"); LOG_INFO("Using flash attention in the diffusion model");
} }
if (sd_version_is_sd3(version)) { if (sd_version_is_sd3(version)) {
if (sd_ctx_params->diffusion_flash_attn) {
LOG_WARN("flash attention in this diffusion model is currently unsupported!");
}
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend,
offload_params_to_cpu, offload_params_to_cpu,
model_loader.tensor_storages_types); model_loader.tensor_storages_types);
@ -359,15 +362,6 @@ public:
} }
} }
if (is_chroma) { if (is_chroma) {
if (sd_ctx_params->diffusion_flash_attn && sd_ctx_params->chroma_use_dit_mask) {
LOG_WARN(
"!!!It looks like you are using Chroma with flash attention. "
"This is currently unsupported. "
"If you find that the generated images are broken, "
"try either disabling flash attention or specifying "
"--chroma-disable-dit-mask as a workaround.");
}
cond_stage_model = std::make_shared<T5CLIPEmbedder>(clip_backend, cond_stage_model = std::make_shared<T5CLIPEmbedder>(clip_backend,
offload_params_to_cpu, offload_params_to_cpu,
model_loader.tensor_storages_types, model_loader.tensor_storages_types,
@ -1552,7 +1546,7 @@ enum scheduler_t str_to_schedule(const char* str) {
} }
void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) { void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
*sd_ctx_params = {}; memset((void*)sd_ctx_params, 0, sizeof(sd_ctx_params_t));
sd_ctx_params->vae_decode_only = true; sd_ctx_params->vae_decode_only = true;
sd_ctx_params->vae_tiling = false; sd_ctx_params->vae_tiling = false;
sd_ctx_params->free_params_immediately = true; sd_ctx_params->free_params_immediately = true;
@ -1636,7 +1630,6 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
} }
void sd_sample_params_init(sd_sample_params_t* sample_params) { void sd_sample_params_init(sd_sample_params_t* sample_params) {
*sample_params = {};
sample_params->guidance.txt_cfg = 7.0f; sample_params->guidance.txt_cfg = 7.0f;
sample_params->guidance.img_cfg = INFINITY; sample_params->guidance.img_cfg = INFINITY;
sample_params->guidance.distilled_guidance = 3.5f; sample_params->guidance.distilled_guidance = 3.5f;
@ -1683,9 +1676,9 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
} }
void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) { void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) {
*sd_img_gen_params = {}; memset((void*)sd_img_gen_params, 0, sizeof(sd_img_gen_params_t));
sd_img_gen_params->clip_skip = -1;
sd_sample_params_init(&sd_img_gen_params->sample_params); sd_sample_params_init(&sd_img_gen_params->sample_params);
sd_img_gen_params->clip_skip = -1;
sd_img_gen_params->ref_images_count = 0; sd_img_gen_params->ref_images_count = 0;
sd_img_gen_params->width = 512; sd_img_gen_params->width = 512;
sd_img_gen_params->height = 512; sd_img_gen_params->height = 512;
@ -1742,7 +1735,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
} }
void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) { void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
*sd_vid_gen_params = {}; memset((void*)sd_vid_gen_params, 0, sizeof(sd_vid_gen_params_t));
sd_sample_params_init(&sd_vid_gen_params->sample_params); sd_sample_params_init(&sd_vid_gen_params->sample_params);
sd_sample_params_init(&sd_vid_gen_params->high_noise_sample_params); sd_sample_params_init(&sd_vid_gen_params->high_noise_sample_params);
sd_vid_gen_params->high_noise_sample_params.sample_steps = -1; sd_vid_gen_params->high_noise_sample_params.sample_steps = -1;
@ -1766,7 +1759,6 @@ sd_ctx_t* new_sd_ctx(const sd_ctx_params_t* sd_ctx_params) {
sd_ctx->sd = new StableDiffusionGGML(); sd_ctx->sd = new StableDiffusionGGML();
if (sd_ctx->sd == NULL) { if (sd_ctx->sd == NULL) {
free(sd_ctx);
return NULL; return NULL;
} }
@ -2369,7 +2361,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
sd_img_gen_params->control_strength, sd_img_gen_params->control_strength,
sd_img_gen_params->style_strength, sd_img_gen_params->style_strength,
sd_img_gen_params->normalize_input, sd_img_gen_params->normalize_input,
SAFE_STR(sd_img_gen_params->input_id_images_path), sd_img_gen_params->input_id_images_path,
ref_latents, ref_latents,
sd_img_gen_params->increase_ref_index, sd_img_gen_params->increase_ref_index,
concat_latent, concat_latent,