diff --git a/README.md b/README.md index 76ad91c..9418e56 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,10 @@ TODO: write this section ## Installation -`k-diffusion` can be installed via PyPI (`pip install k-diffusion`) but it will not include training and inference scripts, only library code that others can depend on. To run the training and inference scripts, clone this repository and run `pip install -e `. +`k-diffusion` can be installed via PyPI (`pip install k-diffusion`) but it will not include training and inference scripts, only library code that others can depend on. + +To run the training and inference scripts, clone this repository and run `pip install -e [train]` +(to install with the `train` extra that includes additional libraries required for training). ## Training diff --git a/k_diffusion/augmentation.py b/k_diffusion/augmentation.py index 920ef86..581e86c 100644 --- a/k_diffusion/augmentation.py +++ b/k_diffusion/augmentation.py @@ -3,7 +3,10 @@ import operator import numpy as np -from skimage import transform +try: + import skimage.transform as skt +except ImportError: + skt = None import torch from torch import nn @@ -31,6 +34,8 @@ def rotate2d(theta): class KarrasAugmentationPipeline: def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8, disable_all=False): + if not skt: + raise ImportError('Please install scikit-image to use KarrasAugmentationPipeline') self.a_prob = a_prob self.a_scale = a_scale self.a_aniso = a_aniso @@ -78,9 +83,9 @@ def __call__(self, image): image_orig = np.array(image, dtype=np.float32) / 255 if image_orig.ndim == 2: image_orig = image_orig[..., None] - tf = transform.AffineTransform(mat.numpy()) + tf = skt.AffineTransform(mat.numpy()) if not self.disable_all: - image = transform.warp(image_orig, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True) + image = skt.warp(image_orig, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True) else: image = image_orig cond = torch.zeros_like(cond) diff --git a/k_diffusion/evaluation.py b/k_diffusion/evaluation.py index d7cf37c..da0c888 100644 --- a/k_diffusion/evaluation.py +++ b/k_diffusion/evaluation.py @@ -2,7 +2,6 @@ import os from pathlib import Path -from cleanfid.inception_torchscript import InceptionV3W import clip import torch from torch import nn @@ -16,6 +15,10 @@ class InceptionV3FeatureExtractor(nn.Module): def __init__(self, device='cpu'): super().__init__() + try: + from cleanfid.inception_torchscript import InceptionV3W + except ImportError as ie: + raise ImportError('Please install clean-fid to use InceptionV3FeatureExtractor') from ie path = Path(os.environ.get('XDG_CACHE_HOME', Path.home() / '.cache')) / 'k-diffusion' url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' digest = 'f58cb9b6ec323ed63459aa4fb441fe750cfe39fafad6da5cb504a16f19e958f4' diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index c62792c..0000000 --- a/requirements.txt +++ /dev/null @@ -1,17 +0,0 @@ -accelerate -clean-fid -clip-anytorch -dctorch -einops -jsonmerge -kornia -Pillow -safetensors -scikit-image -scipy -torch>=2.1 -torchdiffeq -torchsde -torchvision -tqdm -wandb diff --git a/setup.cfg b/setup.cfg index 1e73b1f..5c14bc3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -11,9 +11,8 @@ license = MIT [options] packages = find: -install_requires = +install_requires = accelerate - clean-fid clip-anytorch dctorch einops @@ -21,11 +20,15 @@ install_requires = kornia Pillow safetensors - scikit-image scipy torch >= 2.1 torchdiffeq torchsde torchvision tqdm + +[options.extras_require] +train = + clean-fid + scikit-image wandb