mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-12 21:38:58 +00:00
feat: add flow shift parameter (for SD3 and Wan) (#780)
* Add flow shift parameter (for SD3 and Wan)
* unify code style and fix some issues

---------

Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
parent
21ce9fe2cf
commit
141a4b4113
@ -382,7 +382,8 @@ struct DiscreteFlowDenoiser : public Denoiser {
|
|||||||
|
|
||||||
float sigma_data = 1.0f;
|
float sigma_data = 1.0f;
|
||||||
|
|
||||||
DiscreteFlowDenoiser() {
|
DiscreteFlowDenoiser(float shift = 3.0f)
|
||||||
|
: shift(shift) {
|
||||||
set_parameters();
|
set_parameters();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -115,6 +115,7 @@ struct SDParams {
|
|||||||
bool chroma_use_dit_mask = true;
|
bool chroma_use_dit_mask = true;
|
||||||
bool chroma_use_t5_mask = false;
|
bool chroma_use_t5_mask = false;
|
||||||
int chroma_t5_mask_pad = 1;
|
int chroma_t5_mask_pad = 1;
|
||||||
|
float flow_shift = INFINITY;
|
||||||
|
|
||||||
SDParams() {
|
SDParams() {
|
||||||
sd_sample_params_init(&sample_params);
|
sd_sample_params_init(&sample_params);
|
||||||
@ -171,6 +172,7 @@ void print_params(SDParams params) {
|
|||||||
printf(" sample_params: %s\n", SAFE_STR(sample_params_str));
|
printf(" sample_params: %s\n", SAFE_STR(sample_params_str));
|
||||||
printf(" high_noise_sample_params: %s\n", SAFE_STR(high_noise_sample_params_str));
|
printf(" high_noise_sample_params: %s\n", SAFE_STR(high_noise_sample_params_str));
|
||||||
printf(" moe_boundary: %.3f\n", params.moe_boundary);
|
printf(" moe_boundary: %.3f\n", params.moe_boundary);
|
||||||
|
printf(" flow_shift: %.2f\n", params.flow_shift);
|
||||||
printf(" strength(img2img): %.2f\n", params.strength);
|
printf(" strength(img2img): %.2f\n", params.strength);
|
||||||
printf(" rng: %s\n", sd_rng_type_name(params.rng_type));
|
printf(" rng: %s\n", sd_rng_type_name(params.rng_type));
|
||||||
printf(" seed: %ld\n", params.seed);
|
printf(" seed: %ld\n", params.seed);
|
||||||
@ -278,8 +280,9 @@ void print_usage(int argc, const char* argv[]) {
|
|||||||
printf(" --chroma-t5-mask-pad PAD_SIZE t5 mask pad size of chroma\n");
|
printf(" --chroma-t5-mask-pad PAD_SIZE t5 mask pad size of chroma\n");
|
||||||
printf(" --video-frames video frames (default: 1)\n");
|
printf(" --video-frames video frames (default: 1)\n");
|
||||||
printf(" --fps fps (default: 24)\n");
|
printf(" --fps fps (default: 24)\n");
|
||||||
printf(" --moe-boundary BOUNDARY Timestep boundary for Wan2.2 MoE model. (default: 0.875)\n");
|
printf(" --moe-boundary BOUNDARY timestep boundary for Wan2.2 MoE model. (default: 0.875)\n");
|
||||||
printf(" Only enabled if `--high-noise-steps` is set to -1\n");
|
printf(" only enabled if `--high-noise-steps` is set to -1\n");
|
||||||
|
printf(" --flow-shift SHIFT shift value for Flow models like SD3.x or WAN (default: auto)\n");
|
||||||
printf(" -v, --verbose print extra info\n");
|
printf(" -v, --verbose print extra info\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -514,6 +517,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
|
|||||||
{"", "--style-ratio", "", &params.style_ratio},
|
{"", "--style-ratio", "", &params.style_ratio},
|
||||||
{"", "--control-strength", "", &params.control_strength},
|
{"", "--control-strength", "", &params.control_strength},
|
||||||
{"", "--moe-boundary", "", &params.moe_boundary},
|
{"", "--moe-boundary", "", &params.moe_boundary},
|
||||||
|
{"", "--flow-shift", "", &params.flow_shift},
|
||||||
};
|
};
|
||||||
|
|
||||||
options.bool_options = {
|
options.bool_options = {
|
||||||
@ -1181,6 +1185,7 @@ int main(int argc, const char* argv[]) {
|
|||||||
params.chroma_use_dit_mask,
|
params.chroma_use_dit_mask,
|
||||||
params.chroma_use_t5_mask,
|
params.chroma_use_t5_mask,
|
||||||
params.chroma_t5_mask_pad,
|
params.chroma_t5_mask_pad,
|
||||||
|
params.flow_shift,
|
||||||
};
|
};
|
||||||
|
|
||||||
sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);
|
sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);
|
||||||
|
|||||||
@ -681,7 +681,11 @@ public:
|
|||||||
|
|
||||||
if (sd_version_is_sd3(version)) {
|
if (sd_version_is_sd3(version)) {
|
||||||
LOG_INFO("running in FLOW mode");
|
LOG_INFO("running in FLOW mode");
|
||||||
denoiser = std::make_shared<DiscreteFlowDenoiser>();
|
float shift = sd_ctx_params->flow_shift;
|
||||||
|
if (shift == INFINITY) {
|
||||||
|
shift = 3.0;
|
||||||
|
}
|
||||||
|
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
|
||||||
} else if (sd_version_is_flux(version)) {
|
} else if (sd_version_is_flux(version)) {
|
||||||
LOG_INFO("running in Flux FLOW mode");
|
LOG_INFO("running in Flux FLOW mode");
|
||||||
float shift = 1.0f; // TODO: validate
|
float shift = 1.0f; // TODO: validate
|
||||||
@ -694,7 +698,11 @@ public:
|
|||||||
denoiser = std::make_shared<FluxFlowDenoiser>(shift);
|
denoiser = std::make_shared<FluxFlowDenoiser>(shift);
|
||||||
} else if (sd_version_is_wan(version)) {
|
} else if (sd_version_is_wan(version)) {
|
||||||
LOG_INFO("running in FLOW mode");
|
LOG_INFO("running in FLOW mode");
|
||||||
denoiser = std::make_shared<DiscreteFlowDenoiser>();
|
float shift = sd_ctx_params->flow_shift;
|
||||||
|
if (shift == INFINITY) {
|
||||||
|
shift = 5.0;
|
||||||
|
}
|
||||||
|
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
|
||||||
} else if (is_using_v_parameterization) {
|
} else if (is_using_v_parameterization) {
|
||||||
LOG_INFO("running in v-prediction mode");
|
LOG_INFO("running in v-prediction mode");
|
||||||
denoiser = std::make_shared<CompVisVDenoiser>();
|
denoiser = std::make_shared<CompVisVDenoiser>();
|
||||||
@ -1553,6 +1561,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
|
|||||||
sd_ctx_params->chroma_use_dit_mask = true;
|
sd_ctx_params->chroma_use_dit_mask = true;
|
||||||
sd_ctx_params->chroma_use_t5_mask = false;
|
sd_ctx_params->chroma_use_t5_mask = false;
|
||||||
sd_ctx_params->chroma_t5_mask_pad = 1;
|
sd_ctx_params->chroma_t5_mask_pad = 1;
|
||||||
|
sd_ctx_params->flow_shift = INFINITY;
|
||||||
}
|
}
|
||||||
|
|
||||||
char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
|
char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
|
||||||
|
|||||||
@ -142,6 +142,7 @@ typedef struct {
|
|||||||
bool chroma_use_dit_mask;
|
bool chroma_use_dit_mask;
|
||||||
bool chroma_use_t5_mask;
|
bool chroma_use_t5_mask;
|
||||||
int chroma_t5_mask_pad;
|
int chroma_t5_mask_pad;
|
||||||
|
float flow_shift;
|
||||||
} sd_ctx_params_t;
|
} sd_ctx_params_t;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user