mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-12 13:28:37 +00:00
feat: add CosXL support (#683)
This commit is contained in:
parent
ecf5db97ae
commit
9251756086
58
denoiser.hpp
58
denoiser.hpp
@ -168,24 +168,21 @@ struct AYSSchedule : SigmaSchedule {
|
||||
std::vector<float> inputs;
|
||||
std::vector<float> results(n + 1);
|
||||
|
||||
switch (version) {
|
||||
case VERSION_SD2: /* fallthrough */
|
||||
LOG_WARN("AYS not designed for SD2.X models");
|
||||
case VERSION_SD1:
|
||||
LOG_INFO("AYS using SD1.5 noise levels");
|
||||
inputs = noise_levels[0];
|
||||
break;
|
||||
case VERSION_SDXL:
|
||||
LOG_INFO("AYS using SDXL noise levels");
|
||||
inputs = noise_levels[1];
|
||||
break;
|
||||
case VERSION_SVD:
|
||||
LOG_INFO("AYS using SVD noise levels");
|
||||
inputs = noise_levels[2];
|
||||
break;
|
||||
default:
|
||||
LOG_ERROR("Version not compatable with AYS scheduler");
|
||||
return results;
|
||||
if (sd_version_is_sd2((SDVersion)version)) {
|
||||
LOG_WARN("AYS not designed for SD2.X models");
|
||||
} /* fallthrough */
|
||||
else if (sd_version_is_sd1((SDVersion)version)) {
|
||||
LOG_INFO("AYS using SD1.5 noise levels");
|
||||
inputs = noise_levels[0];
|
||||
} else if (sd_version_is_sdxl((SDVersion)version)) {
|
||||
LOG_INFO("AYS using SDXL noise levels");
|
||||
inputs = noise_levels[1];
|
||||
} else if (version == VERSION_SVD) {
|
||||
LOG_INFO("AYS using SVD noise levels");
|
||||
inputs = noise_levels[2];
|
||||
} else {
|
||||
LOG_ERROR("Version not compatable with AYS scheduler");
|
||||
return results;
|
||||
}
|
||||
|
||||
/* Stretches those pre-calculated reference levels out to the desired
|
||||
@ -346,6 +343,31 @@ struct CompVisVDenoiser : public CompVisDenoiser {
|
||||
}
|
||||
};
|
||||
|
||||
struct EDMVDenoiser : public CompVisVDenoiser {
|
||||
float min_sigma = 0.002;
|
||||
float max_sigma = 120.0;
|
||||
|
||||
EDMVDenoiser(float min_sigma = 0.002, float max_sigma = 120.0) : min_sigma(min_sigma), max_sigma(max_sigma) {
|
||||
schedule = std::make_shared<ExponentialSchedule>();
|
||||
}
|
||||
|
||||
float t_to_sigma(float t) {
|
||||
return std::exp(t * 4/(float)TIMESTEPS);
|
||||
}
|
||||
|
||||
float sigma_to_t(float s) {
|
||||
return 0.25 * std::log(s);
|
||||
}
|
||||
|
||||
float sigma_min() {
|
||||
return min_sigma;
|
||||
}
|
||||
|
||||
float sigma_max() {
|
||||
return max_sigma;
|
||||
}
|
||||
};
|
||||
|
||||
float time_snr_shift(float alpha, float t) {
|
||||
if (alpha == 1.0f) {
|
||||
return t;
|
||||
|
||||
@ -103,6 +103,9 @@ public:
|
||||
bool vae_tiling = false;
|
||||
bool stacked_id = false;
|
||||
|
||||
bool is_using_v_parameterization = false;
|
||||
bool is_using_edm_v_parameterization = false;
|
||||
|
||||
std::map<std::string, struct ggml_tensor*> tensors;
|
||||
|
||||
std::string lora_model_dir;
|
||||
@ -543,12 +546,17 @@ public:
|
||||
LOG_INFO("loading model from '%s' completed, taking %.2fs", model_path.c_str(), (t1 - t0) * 1.0f / 1000);
|
||||
|
||||
// check is_using_v_parameterization_for_sd2
|
||||
bool is_using_v_parameterization = false;
|
||||
|
||||
if (sd_version_is_sd2(version)) {
|
||||
if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
|
||||
is_using_v_parameterization = true;
|
||||
}
|
||||
} else if (sd_version_is_sdxl(version)) {
|
||||
if (model_loader.tensor_storages_types.find("edm_vpred.sigma_max") != model_loader.tensor_storages_types.end()) {
|
||||
// CosXL models
|
||||
// TODO: get sigma_min and sigma_max values from file
|
||||
is_using_edm_v_parameterization = true;
|
||||
}
|
||||
if (model_loader.tensor_storages_types.find("v_pred") != model_loader.tensor_storages_types.end()) {
|
||||
is_using_v_parameterization = true;
|
||||
}
|
||||
@ -573,6 +581,9 @@ public:
|
||||
} else if (is_using_v_parameterization) {
|
||||
LOG_INFO("running in v-prediction mode");
|
||||
denoiser = std::make_shared<CompVisVDenoiser>();
|
||||
} else if (is_using_edm_v_parameterization) {
|
||||
LOG_INFO("running in v-prediction EDM mode");
|
||||
denoiser = std::make_shared<EDMVDenoiser>();
|
||||
} else {
|
||||
LOG_INFO("running in eps-prediction mode");
|
||||
}
|
||||
@ -1396,7 +1407,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
|
||||
SDCondition uncond;
|
||||
if (cfg_scale != 1.0) {
|
||||
bool force_zero_embeddings = false;
|
||||
if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0) {
|
||||
if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0 && !sd_ctx->sd->is_using_edm_v_parameterization) {
|
||||
force_zero_embeddings = true;
|
||||
}
|
||||
uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user