diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py index f050f88e..35669d1e 100644 --- a/k_diffusion/sampling.py +++ b/k_diffusion/sampling.py @@ -542,7 +542,14 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, @torch.no_grad() def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): """DPM-Solver++ (stochastic).""" - sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + positive_sigmas = sigmas[sigmas > 0] + + if positive_sigmas.numel() > 0: + sigma_min = positive_sigmas.min(dim=0)[0] + else: + sigma_min = 0 + + sigma_max = sigmas.max() noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]])