From b95784fd4227fd3c9c86b5ed33ba342786031cf1 Mon Sep 17 00:00:00 2001
From: Matt Parnell
Date: Fri, 19 May 2023 23:41:25 -0500
Subject: [PATCH 1/2] add missing check for empty tensors

---
 k_diffusion/sampling.py | 22 ++++++++++++++++++----
 1 file changed, 18 insertions(+), 4 deletions(-)

diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py
index f050f88e..1cd9634a 100644
--- a/k_diffusion/sampling.py
+++ b/k_diffusion/sampling.py
@@ -506,9 +506,16 @@ def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callbac
 
 
 @torch.no_grad()
-def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
-    """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
-    extra_args = {} if extra_args is None else extra_args
+def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None)
+    """DPM-Solver++ (stochastic)."""
+    positive_sigmas = sigmas[sigmas > 0]
+
+    if positive_sigmas.numel() > 0:
+        sigma_min = positive_sigmas.min(dim=0)[0]
+    else:
+        sigma_min = 0
+
+    sigma_max = sigmas.max()
     noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
     sigma_fn = lambda t: t.neg().exp()
@@ -542,7 +549,14 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
 @torch.no_grad()
 def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
     """DPM-Solver++ (stochastic)."""
-    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+    positive_sigmas = sigmas[sigmas > 0]
+
+    if positive_sigmas.numel() > 0:
+        sigma_min = positive_sigmas.min(dim=0)[0]
+    else:
+        sigma_min = 0
+
+    sigma_max = sigmas.max()
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])

From d4cf787f40dfeff8832aeca2527592778a081590 Mon Sep 17 00:00:00 2001
From: Matt Parnell
Date: Fri, 19 May 2023 23:48:54 -0500
Subject: [PATCH 2/2] oops

---
 k_diffusion/sampling.py | 13 +++----------
 1 file changed, 3 insertions(+), 10 deletions(-)

diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py
index 1cd9634a..35669d1e 100644
--- a/k_diffusion/sampling.py
+++ b/k_diffusion/sampling.py
@@ -506,16 +506,9 @@ def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callbac
 
 
 @torch.no_grad()
-def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None)
-    """DPM-Solver++ (stochastic)."""
-    positive_sigmas = sigmas[sigmas > 0]
-
-    if positive_sigmas.numel() > 0:
-        sigma_min = positive_sigmas.min(dim=0)[0]
-    else:
-        sigma_min = 0
-
-    sigma_max = sigmas.max()
+def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
+    """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
+    extra_args = {} if extra_args is None else extra_args
     noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
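
Note (not part of the patches): a minimal standalone sketch of the guard PATCH 1/2 introduces, shown outside the samplers. The helper name safe_min_sigma is hypothetical and exists only to illustrate why the numel() check matters: sigmas[sigmas > 0] can be empty (e.g. an all-zero schedule), and calling .min() on an empty tensor raises a RuntimeError, so the code falls back to 0 in that case.

import torch

def safe_min_sigma(sigmas):
    # Hypothetical helper mirroring the patched logic: filter the positive
    # sigmas, and only call .min() when the filtered tensor is non-empty.
    positive_sigmas = sigmas[sigmas > 0]
    if positive_sigmas.numel() > 0:
        return positive_sigmas.min()
    # Fall back to 0 (as a tensor matching the input's dtype/device).
    return torch.tensor(0., device=sigmas.device, dtype=sigmas.dtype)

print(safe_min_sigma(torch.tensor([1.0, 0.5, 0.0])))  # tensor(0.5000)
print(safe_min_sigma(torch.zeros(3)))                 # tensor(0.) instead of a crash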