feat: add CosXL support (#683)

This commit is contained in:
stduhpf 2025-07-01 17:13:04 +02:00 committed by GitHub
parent ecf5db97ae
commit 9251756086
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 53 additions and 20 deletions

View File

@ -168,24 +168,21 @@ struct AYSSchedule : SigmaSchedule {
std::vector<float> inputs; std::vector<float> inputs;
std::vector<float> results(n + 1); std::vector<float> results(n + 1);
switch (version) { if (sd_version_is_sd2((SDVersion)version)) {
case VERSION_SD2: /* fallthrough */ LOG_WARN("AYS not designed for SD2.X models");
LOG_WARN("AYS not designed for SD2.X models"); } /* fallthrough */
case VERSION_SD1: else if (sd_version_is_sd1((SDVersion)version)) {
LOG_INFO("AYS using SD1.5 noise levels"); LOG_INFO("AYS using SD1.5 noise levels");
inputs = noise_levels[0]; inputs = noise_levels[0];
break; } else if (sd_version_is_sdxl((SDVersion)version)) {
case VERSION_SDXL: LOG_INFO("AYS using SDXL noise levels");
LOG_INFO("AYS using SDXL noise levels"); inputs = noise_levels[1];
inputs = noise_levels[1]; } else if (version == VERSION_SVD) {
break; LOG_INFO("AYS using SVD noise levels");
case VERSION_SVD: inputs = noise_levels[2];
LOG_INFO("AYS using SVD noise levels"); } else {
inputs = noise_levels[2]; LOG_ERROR("Version not compatable with AYS scheduler");
break; return results;
default:
LOG_ERROR("Version not compatable with AYS scheduler");
return results;
} }
/* Stretches those pre-calculated reference levels out to the desired /* Stretches those pre-calculated reference levels out to the desired
@ -346,6 +343,31 @@ struct CompVisVDenoiser : public CompVisDenoiser {
} }
}; };
struct EDMVDenoiser : public CompVisVDenoiser {
float min_sigma = 0.002;
float max_sigma = 120.0;
EDMVDenoiser(float min_sigma = 0.002, float max_sigma = 120.0) : min_sigma(min_sigma), max_sigma(max_sigma) {
schedule = std::make_shared<ExponentialSchedule>();
}
float t_to_sigma(float t) {
return std::exp(t * 4/(float)TIMESTEPS);
}
float sigma_to_t(float s) {
return 0.25 * std::log(s);
}
float sigma_min() {
return min_sigma;
}
float sigma_max() {
return max_sigma;
}
};
float time_snr_shift(float alpha, float t) { float time_snr_shift(float alpha, float t) {
if (alpha == 1.0f) { if (alpha == 1.0f) {
return t; return t;

View File

@ -103,6 +103,9 @@ public:
bool vae_tiling = false; bool vae_tiling = false;
bool stacked_id = false; bool stacked_id = false;
bool is_using_v_parameterization = false;
bool is_using_edm_v_parameterization = false;
std::map<std::string, struct ggml_tensor*> tensors; std::map<std::string, struct ggml_tensor*> tensors;
std::string lora_model_dir; std::string lora_model_dir;
@ -543,12 +546,17 @@ public:
LOG_INFO("loading model from '%s' completed, taking %.2fs", model_path.c_str(), (t1 - t0) * 1.0f / 1000); LOG_INFO("loading model from '%s' completed, taking %.2fs", model_path.c_str(), (t1 - t0) * 1.0f / 1000);
// check is_using_v_parameterization_for_sd2 // check is_using_v_parameterization_for_sd2
bool is_using_v_parameterization = false;
if (sd_version_is_sd2(version)) { if (sd_version_is_sd2(version)) {
if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) { if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
is_using_v_parameterization = true; is_using_v_parameterization = true;
} }
} else if (sd_version_is_sdxl(version)) { } else if (sd_version_is_sdxl(version)) {
if (model_loader.tensor_storages_types.find("edm_vpred.sigma_max") != model_loader.tensor_storages_types.end()) {
// CosXL models
// TODO: get sigma_min and sigma_max values from file
is_using_edm_v_parameterization = true;
}
if (model_loader.tensor_storages_types.find("v_pred") != model_loader.tensor_storages_types.end()) { if (model_loader.tensor_storages_types.find("v_pred") != model_loader.tensor_storages_types.end()) {
is_using_v_parameterization = true; is_using_v_parameterization = true;
} }
@ -573,6 +581,9 @@ public:
} else if (is_using_v_parameterization) { } else if (is_using_v_parameterization) {
LOG_INFO("running in v-prediction mode"); LOG_INFO("running in v-prediction mode");
denoiser = std::make_shared<CompVisVDenoiser>(); denoiser = std::make_shared<CompVisVDenoiser>();
} else if (is_using_edm_v_parameterization) {
LOG_INFO("running in v-prediction EDM mode");
denoiser = std::make_shared<EDMVDenoiser>();
} else { } else {
LOG_INFO("running in eps-prediction mode"); LOG_INFO("running in eps-prediction mode");
} }
@ -1396,7 +1407,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
SDCondition uncond; SDCondition uncond;
if (cfg_scale != 1.0) { if (cfg_scale != 1.0) {
bool force_zero_embeddings = false; bool force_zero_embeddings = false;
if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0) { if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0 && !sd_ctx->sd->is_using_edm_v_parameterization) {
force_zero_embeddings = true; force_zero_embeddings = true;
} }
uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx, uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,