diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 42b909e4f..a1b92f77e 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -579,7 +579,7 @@ int main(int argc, const char* argv[]) { } if (gen_params.sample_params.scheduler == SCHEDULER_COUNT) { - gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx); + gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx, gen_params.sample_params.sample_method); } if (cli_params.mode == IMG_GEN) { @@ -752,4 +752,4 @@ int main(int argc, const char* argv[]) { release_all_resources(); return 0; -} \ No newline at end of file +} diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 44bd3ccac..b7947c801 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -2773,13 +2773,16 @@ enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) { return EULER_A_SAMPLE_METHOD; } -enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx) { +enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_method_t sample_method) { if (sd_ctx != nullptr && sd_ctx->sd != nullptr) { auto edm_v_denoiser = std::dynamic_pointer_cast(sd_ctx->sd->denoiser); if (edm_v_denoiser) { return EXPONENTIAL_SCHEDULER; } } + if (sample_method == LCM_SAMPLE_METHOD) { + return LCM_SCHEDULER; + } return DISCRETE_SCHEDULER; } @@ -3214,6 +3217,10 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g LOG_WARN("sample_steps != custom_sigmas_count - 1, set sample_steps to %d", sample_steps); } } else { + scheduler_t scheduler = sd_img_gen_params->sample_params.scheduler; + if (scheduler == SCHEDULER_COUNT) { + scheduler = sd_get_default_scheduler(sd_ctx, sample_method); + } sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps, sd_ctx->sd->get_image_seq_len(height, width), sd_img_gen_params->sample_params.scheduler, diff --git a/stable-diffusion.h b/stable-diffusion.h index 9266ba437..10a170495 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -334,7 +334,7 @@ SD_API void sd_sample_params_init(sd_sample_params_t* sample_params); SD_API char* sd_sample_params_to_str(const sd_sample_params_t* sample_params); SD_API enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx); -SD_API enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx); +SD_API enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_method_t sample_method); SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params); SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params);