feat: add CosXL support (#683)

This commit is contained in:
stduhpf 2025-07-01 17:13:04 +02:00 committed by GitHub
parent ecf5db97ae
commit 9251756086
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 53 additions and 20 deletions

View File

@ -168,24 +168,21 @@ struct AYSSchedule : SigmaSchedule {
std::vector<float> inputs;
std::vector<float> results(n + 1);
switch (version) {
case VERSION_SD2: /* fallthrough */
LOG_WARN("AYS not designed for SD2.X models");
case VERSION_SD1:
LOG_INFO("AYS using SD1.5 noise levels");
inputs = noise_levels[0];
break;
case VERSION_SDXL:
LOG_INFO("AYS using SDXL noise levels");
inputs = noise_levels[1];
break;
case VERSION_SVD:
LOG_INFO("AYS using SVD noise levels");
inputs = noise_levels[2];
break;
default:
LOG_ERROR("Version not compatable with AYS scheduler");
return results;
if (sd_version_is_sd2((SDVersion)version)) {
LOG_WARN("AYS not designed for SD2.X models");
} /* fallthrough */
else if (sd_version_is_sd1((SDVersion)version)) {
LOG_INFO("AYS using SD1.5 noise levels");
inputs = noise_levels[0];
} else if (sd_version_is_sdxl((SDVersion)version)) {
LOG_INFO("AYS using SDXL noise levels");
inputs = noise_levels[1];
} else if (version == VERSION_SVD) {
LOG_INFO("AYS using SVD noise levels");
inputs = noise_levels[2];
} else {
LOG_ERROR("Version not compatable with AYS scheduler");
return results;
}
/* Stretches those pre-calculated reference levels out to the desired
@ -346,6 +343,31 @@ struct CompVisVDenoiser : public CompVisDenoiser {
}
};
/**
 * V-prediction denoiser variant with a configurable sigma range.
 *
 * Builds on CompVisVDenoiser, swaps the schedule for an ExponentialSchedule,
 * and converts between timestep and sigma with an exp/log pair.
 *
 * NOTE(review): as written, sigma_to_t(t_to_sigma(t)) evaluates to
 * t / TIMESTEPS rather than t, so the two conversions are not exact
 * inverses of each other. Presumably intentional -- confirm against the
 * callers before relying on a round trip.
 */
struct EDMVDenoiser : public CompVisVDenoiser {
    // Bounds handed back by sigma_min() and sigma_max().
    float min_sigma = 0.002f;
    float max_sigma = 120.0f;

    /**
     * @param min_sigma value later returned by sigma_min()
     * @param max_sigma value later returned by sigma_max()
     */
    EDMVDenoiser(float min_sigma = 0.002f, float max_sigma = 120.0f)
        : min_sigma(min_sigma),
          max_sigma(max_sigma) {
        schedule = std::make_shared<ExponentialSchedule>();
    }

    /// Maps a timestep to sigma: exp(4 * t / TIMESTEPS).
    float t_to_sigma(float t) {
        const float scaled_t = t * 4.0f;
        return std::exp(scaled_t / static_cast<float>(TIMESTEPS));
    }

    /// Maps sigma to a timestep value: log(s) / 4.
    float sigma_to_t(float s) {
        const float log_sigma = std::log(s);
        return 0.25f * log_sigma;
    }

    /// Lower sigma bound configured at construction.
    float sigma_min() {
        return min_sigma;
    }

    /// Upper sigma bound configured at construction.
    float sigma_max() {
        return max_sigma;
    }
};
float time_snr_shift(float alpha, float t) {
if (alpha == 1.0f) {
return t;

View File

@ -103,6 +103,9 @@ public:
bool vae_tiling = false;
bool stacked_id = false;
bool is_using_v_parameterization = false;
bool is_using_edm_v_parameterization = false;
std::map<std::string, struct ggml_tensor*> tensors;
std::string lora_model_dir;
@ -543,12 +546,17 @@ public:
LOG_INFO("loading model from '%s' completed, taking %.2fs", model_path.c_str(), (t1 - t0) * 1.0f / 1000);
// check is_using_v_parameterization_for_sd2
bool is_using_v_parameterization = false;
if (sd_version_is_sd2(version)) {
if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
is_using_v_parameterization = true;
}
} else if (sd_version_is_sdxl(version)) {
if (model_loader.tensor_storages_types.find("edm_vpred.sigma_max") != model_loader.tensor_storages_types.end()) {
// CosXL models
// TODO: get sigma_min and sigma_max values from file
is_using_edm_v_parameterization = true;
}
if (model_loader.tensor_storages_types.find("v_pred") != model_loader.tensor_storages_types.end()) {
is_using_v_parameterization = true;
}
@ -573,6 +581,9 @@ public:
} else if (is_using_v_parameterization) {
LOG_INFO("running in v-prediction mode");
denoiser = std::make_shared<CompVisVDenoiser>();
} else if (is_using_edm_v_parameterization) {
LOG_INFO("running in v-prediction EDM mode");
denoiser = std::make_shared<EDMVDenoiser>();
} else {
LOG_INFO("running in eps-prediction mode");
}
@ -1396,7 +1407,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
SDCondition uncond;
if (cfg_scale != 1.0) {
bool force_zero_embeddings = false;
if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0) {
if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0 && !sd_ctx->sd->is_using_edm_v_parameterization) {
force_zero_embeddings = true;
}
uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,