diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py index f050f88e..afe00c24 100644 --- a/k_diffusion/sampling.py +++ b/k_diffusion/sampling.py @@ -77,7 +77,7 @@ def __init__(self, x, t0, t1, seed=None, **kwargs): except TypeError: seed = [seed] self.batched = False - self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed] + self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=int(s), **kwargs) for s in seed] @staticmethod def sort(a, b):