fix: avoid crash when using taesd for preview only (#1141)

Author: leejet (committed via GitHub)
Date:   2025-12-24 23:30:12 +08:00
Parent: a0adcfb148
Commit: 860a78e248
3 changed files with 43 additions and 44 deletions

File 1 of 3:

@@ -105,9 +105,9 @@ struct SDSvrParams {
     std::string listen_ip = "127.0.0.1";
     int listen_port = 1234;
     std::string serve_html_path;
     bool normal_exit = false;
     bool verbose = false;
     bool color = false;

     ArgOptions get_options() {
         ArgOptions options;

File 2 of 3:

@@ -648,7 +648,7 @@ namespace Qwen {
             modulate_index_vec.insert(modulate_index_vec.end(), num_ref_img_tokens, 1.f);
         }
         modulate_index = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, modulate_index_vec.size());
         set_backend_tensor_data(modulate_index, modulate_index_vec.data());
     }

File 3 of 3:

@@ -591,8 +591,8 @@ public:
             vae_backend = backend;
         }
-        if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
-            if (!use_tiny_autoencoder) {
+        if (!use_tiny_autoencoder || sd_ctx_params->tae_preview_only) {
+            if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
                 first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
                                                                         offload_params_to_cpu,
                                                                         tensor_storage_map,
@@ -601,57 +601,56 @@ public:
                                                                         version);
                 first_stage_model->alloc_params_buffer();
                 first_stage_model->get_param_tensors(tensors, "first_stage_model");
+            } else if (version == VERSION_CHROMA_RADIANCE) {
+                first_stage_model = std::make_shared<FakeVAE>(vae_backend,
+                                                              offload_params_to_cpu);
             } else {
+                first_stage_model = std::make_shared<AutoEncoderKL>(vae_backend,
+                                                                    offload_params_to_cpu,
+                                                                    tensor_storage_map,
+                                                                    "first_stage_model",
+                                                                    vae_decode_only,
+                                                                    false,
+                                                                    version);
+                if (sd_ctx_params->vae_conv_direct) {
+                    LOG_INFO("Using Conv2d direct in the vae model");
+                    first_stage_model->set_conv2d_direct_enabled(true);
+                }
+                if (version == VERSION_SDXL &&
+                    (strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale)) {
+                    float vae_conv_2d_scale = 1.f / 32.f;
+                    LOG_WARN(
+                        "No VAE specified with --vae or --force-sdxl-vae-conv-scale flag set, "
+                        "using Conv2D scale %.3f",
+                        vae_conv_2d_scale);
+                    first_stage_model->set_conv2d_scale(vae_conv_2d_scale);
+                }
+                first_stage_model->alloc_params_buffer();
+                first_stage_model->get_param_tensors(tensors, "first_stage_model");
+            }
+        }
+        if (use_tiny_autoencoder) {
+            if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
                 tae_first_stage = std::make_shared<TinyVideoAutoEncoder>(vae_backend,
                                                                          offload_params_to_cpu,
                                                                          tensor_storage_map,
                                                                          "decoder",
                                                                          vae_decode_only,
                                                                          version);
-                if (sd_ctx_params->vae_conv_direct) {
-                    LOG_INFO("Using Conv2d direct in the tae model");
-                    tae_first_stage->set_conv2d_direct_enabled(true);
-                }
+            } else {
+                tae_first_stage = std::make_shared<TinyImageAutoEncoder>(vae_backend,
                                                                          offload_params_to_cpu,
                                                                          tensor_storage_map,
                                                                          "decoder.layers",
                                                                          vae_decode_only,
                                                                          version);
             }
-        } else if (version == VERSION_CHROMA_RADIANCE) {
-            first_stage_model = std::make_shared<FakeVAE>(vae_backend,
-                                                          offload_params_to_cpu);
-        } else if (!use_tiny_autoencoder || sd_ctx_params->tae_preview_only) {
-            first_stage_model = std::make_shared<AutoEncoderKL>(vae_backend,
-                                                                offload_params_to_cpu,
-                                                                tensor_storage_map,
-                                                                "first_stage_model",
-                                                                vae_decode_only,
-                                                                false,
-                                                                version);
-            if (sd_ctx_params->vae_conv_direct) {
-                LOG_INFO("Using Conv2d direct in the vae model");
-                first_stage_model->set_conv2d_direct_enabled(true);
-            }
-            if (version == VERSION_SDXL &&
-                (strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale)) {
-                float vae_conv_2d_scale = 1.f / 32.f;
-                LOG_WARN(
-                    "No VAE specified with --vae or --force-sdxl-vae-conv-scale flag set, "
-                    "using Conv2D scale %.3f",
-                    vae_conv_2d_scale);
-                first_stage_model->set_conv2d_scale(vae_conv_2d_scale);
-            }
-            first_stage_model->alloc_params_buffer();
-            first_stage_model->get_param_tensors(tensors, "first_stage_model");
-        } else if (use_tiny_autoencoder) {
-            tae_first_stage = std::make_shared<TinyImageAutoEncoder>(vae_backend,
-                                                                     offload_params_to_cpu,
-                                                                     tensor_storage_map,
-                                                                     "decoder.layers",
-                                                                     vae_decode_only,
-                                                                     version);
             if (sd_ctx_params->vae_conv_direct) {
                 LOG_INFO("Using Conv2d direct in the tae model");
                 tae_first_stage->set_conv2d_direct_enabled(true);
             }
         }
         // first_stage_model->get_param_tensors(tensors, "first_stage_model.");

         if (strlen(SAFE_STR(sd_ctx_params->control_net_path)) > 0) {
             ggml_backend_t controlnet_backend = nullptr;
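For readers skimming the diff, the sketch below condenses the new selection logic into a standalone, compilable example. The enum, the is_wan_or_qwen helper, and the printed strings are stand-ins invented for illustration; only the branch conditions and the model names mirror the code above.

#include <cstdio>

// Stand-in types for illustration only; the real code uses the SDVersion enum
// and the runner classes from stable-diffusion.cpp.
enum SketchVersion { V_SD1, V_SDXL, V_WAN, V_QWEN_IMAGE, V_CHROMA_RADIANCE };

static bool is_wan_or_qwen(SketchVersion v) {
    return v == V_WAN || v == V_QWEN_IMAGE;
}

// Mirrors the branch structure after the fix: the full VAE is built whenever
// TAESD is absent or used for previews only, and the TAE is built
// independently whenever it is enabled, so neither pointer stays null.
static void pick_first_stage(bool use_tiny_autoencoder, bool tae_preview_only, SketchVersion version) {
    if (!use_tiny_autoencoder || tae_preview_only) {
        if (is_wan_or_qwen(version)) {
            std::puts("first_stage_model = WanVAERunner");
        } else if (version == V_CHROMA_RADIANCE) {
            std::puts("first_stage_model = FakeVAE");
        } else {
            std::puts("first_stage_model = AutoEncoderKL");
        }
    }
    if (use_tiny_autoencoder) {
        if (is_wan_or_qwen(version)) {
            std::puts("tae_first_stage = TinyVideoAutoEncoder");
        } else {
            std::puts("tae_first_stage = TinyImageAutoEncoder");
        }
    }
}

int main() {
    // taesd used for preview only: both the full VAE and the TAE are now
    // created, which is the combination that used to crash.
    pick_first_stage(/*use_tiny_autoencoder=*/true, /*tae_preview_only=*/true, V_SDXL);
    return 0;
}

Before this change, the old control flow picked either the full VAE or the TAE for a given model family, so the taesd-preview-only combination always left one of the two uninitialized.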