diff --git a/configs/base_config.py b/configs/base_config.py
index cd747f7..4897c0f 100644
--- a/configs/base_config.py
+++ b/configs/base_config.py
@@ -23,7 +23,7 @@ def __init__(self):
         parser.add_argument("--optimizer", help="Which optimizer to use", default="sgd")
         parser.add_argument("--set", help="only Flower102Pytorch is currently supported", type=str, default="Flower102Pytorch",
-                            choices=['Flower102Pytorch','Aircraft100Pytorch'])
+                            choices=['Flower102Pytorch','Aircraft100Pytorch','Aircrafts', 'Dog120'])
         parser.add_argument(
             "-a", "--arch", metavar="ARCH", default="Split_ResNet18", help="model architecture",
@@ -276,7 +276,7 @@ def parse(self,args):
         elif self.cfg.set in ['MIT67']:
             self.cfg.num_cls = 67
             self.cfg.eval_tst = False
-        elif self.cfg.set == 'Aircraft100' or self.cfg.set == 'Aircraft100Pytorch':
+        elif self.cfg.set == 'Aircraft100' or self.cfg.set == 'Aircraft100Pytorch' or self.cfg.set == 'Aircrafts':
             self.cfg.num_cls = 100
             self.cfg.eval_tst = True
         else:
diff --git a/data/__init__.py b/data/__init__.py
index 7a86ecc..e93990e 100644
--- a/data/__init__.py
+++ b/data/__init__.py
@@ -1,2 +1,4 @@
 from data.flower import Flower102Pytorch
+from data.aircrafts import Aircrafts
+from data.dog import Dog120
 from data.aircraft import Aircraft100Pytorch
diff --git a/data/aircrafts.py b/data/aircrafts.py
new file mode 100644
index 0000000..e41f95a
--- /dev/null
+++ b/data/aircrafts.py
@@ -0,0 +1,21 @@
+import numpy as np
+import torch
+from data.datasets import load_dataset
+
+
+class Aircrafts:
+    def __init__(self, cfg):
+
+        if cfg.cs_kd:
+            sampler_type = 'pair'
+        else:
+            sampler_type = 'default'
+
+        trainloader, valloader, testloader = load_dataset('Aircrafts',
+                                                          cfg.data,
+                                                          sampler_type, batch_size=cfg.batch_size)
+        self.num_classes = trainloader.dataset.num_classes
+
+        self.train_loader = trainloader
+        self.tst_loader = testloader
+        self.val_loader = valloader
\ No newline at end of file
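
Note: Aircrafts is a thin wrapper that only repackages what load_dataset
returns. A minimal usage sketch, assuming a cfg object that exposes the
three fields the wrapper reads (cs_kd, data, batch_size); the root path
and batch size below are placeholders:

    from types import SimpleNamespace
    from data.aircrafts import Aircrafts

    # Stand-in for cfg; the project builds the real one from argparse
    # in configs/base_config.py.
    cfg = SimpleNamespace(cs_kd=False,       # True selects the 'pair' sampler for CS-KD
                          data='/datasets',  # root expected to contain Aircrafts/{train,val,test}
                          batch_size=32)

    aircrafts = Aircrafts(cfg)
    images, labels = next(iter(aircrafts.train_loader))
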
diff --git a/data/datasets.py b/data/datasets.py
new file mode 100644
index 0000000..bbe885d
--- /dev/null
+++ b/data/datasets.py
@@ -0,0 +1,283 @@
+# This script is borrowed from https://github.com/alinlab/cs-kd/blob/master/datasets.py
+import csv, torchvision, numpy as np, random, os
+import bisect, warnings  # used by ConcatWrapper below
+from PIL import Image
+
+from torch.utils.data import Sampler, Dataset, DataLoader, BatchSampler, SequentialSampler, RandomSampler, Subset
+from torchvision import transforms, datasets
+from collections import defaultdict
+import math
+
+class PairBatchSampler(Sampler):
+    def __init__(self, dataset, batch_size, num_iterations=None):
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.num_iterations = num_iterations
+
+    def __iter__(self):
+        indices = list(range(len(self.dataset)))
+        random.shuffle(indices)
+        for k in range(len(self)):
+            if self.num_iterations is None:
+                offset = k*self.batch_size
+                batch_indices = indices[offset:offset+self.batch_size]
+            else:
+                batch_indices = random.sample(range(len(self.dataset)),
+                                              self.batch_size)
+
+            pair_indices = []
+            for idx in batch_indices:
+                y = self.dataset.get_class(idx)
+                pair_indices.append(random.choice(self.dataset.classwise_indices[y]))
+
+            yield batch_indices + pair_indices
+
+    def __len__(self):
+        if self.num_iterations is None:
+            return (len(self.dataset)+self.batch_size-1) // self.batch_size
+        else:
+            return self.num_iterations
+
+
+class DatasetWrapper(Dataset):
+    # Additional attributes
+    # - indices
+    # - classwise_indices
+    # - num_classes
+    # - get_class
+
+    def __init__(self, dataset, indices=None):
+        self.base_dataset = dataset
+        if indices is None:
+            self.indices = list(range(len(dataset)))
+        else:
+            self.indices = indices
+
+        # torchvision 0.2.0 compatibility
+        if torchvision.__version__.startswith('0.2'):
+            if isinstance(self.base_dataset, datasets.ImageFolder):
+                self.base_dataset.targets = [s[1] for s in self.base_dataset.imgs]
+            else:
+                if self.base_dataset.train:
+                    self.base_dataset.targets = self.base_dataset.train_labels
+                else:
+                    self.base_dataset.targets = self.base_dataset.test_labels
+
+        self.classwise_indices = defaultdict(list)
+        for i in range(len(self)):
+            y = self.base_dataset.targets[self.indices[i]]
+            self.classwise_indices[y].append(i)
+        self.num_classes = max(self.classwise_indices.keys())+1
+
+    def __getitem__(self, i):
+        return self.base_dataset[self.indices[i]]
+
+    def __len__(self):
+        return len(self.indices)
+
+    def get_class(self, i):
+        return self.base_dataset.targets[self.indices[i]]
+
+
+class ConcatWrapper(Dataset): # TODO: Naming
+    @staticmethod
+    def cumsum(sequence):
+        r, s = [], 0
+        for e in sequence:
+            l = len(e)
+            r.append(l + s)
+            s += l
+        return r
+
+    @staticmethod
+    def numcls(sequence):
+        s = 0
+        for e in sequence:
+            l = e.num_classes
+            s += l
+        return s
+
+    @staticmethod
+    def clsidx(sequence):
+        r, s, n = defaultdict(list), 0, 0
+        for e in sequence:
+            l = e.classwise_indices
+            for c in range(s, s + e.num_classes):
+                t = np.asarray(l[c-s]) + n
+                r[c] = t.tolist()
+            s += e.num_classes
+            n += len(e)
+        return r
+
+    def __init__(self, datasets):
+        super(ConcatWrapper, self).__init__()
+        assert len(datasets) > 0, 'datasets should not be an empty iterable'
+        self.datasets = list(datasets)
+        # for d in self.datasets:
+        #     assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
+        self.cumulative_sizes = self.cumsum(self.datasets)
+
+        self.num_classes = self.numcls(self.datasets)
+        self.classwise_indices = self.clsidx(self.datasets)
+
+    def __len__(self):
+        return self.cumulative_sizes[-1]
+
+    def __getitem__(self, idx):
+        if idx < 0:
+            if -idx > len(self):
+                raise ValueError("absolute value of index should not exceed dataset length")
+            idx = len(self) + idx
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+        if dataset_idx == 0:
+            sample_idx = idx
+        else:
+            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+        return self.datasets[dataset_idx][sample_idx]
+
+    def get_class(self, idx):
+        if idx < 0:
+            if -idx > len(self):
+                raise ValueError("absolute value of index should not exceed dataset length")
+            idx = len(self) + idx
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+        if dataset_idx == 0:
+            sample_idx = idx
+        else:
+            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+        true_class = self.datasets[dataset_idx].base_dataset.targets[self.datasets[dataset_idx].indices[sample_idx]]
+        return self.datasets[dataset_idx].base_dataset.target_transform(true_class)
+
+    @property
+    def cummulative_sizes(self):
+        warnings.warn("cummulative_sizes attribute is renamed to "
+                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
+        return self.cumulative_sizes
+
+
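+# load_dataset returns (trainloader, valloader) for most datasets, but
+# (trainloader, valloader, testloader) for 'Aircrafts', which has a
+# separate val/ split on disk; callers must match the arity. The
+# fine-grained sets are read from ImageFolder layouts under
+# <root>/<name>/{train,val,test}.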
+def load_dataset(name, root, sample='default', **kwargs):
+    # Dataset
+    if name in ['imagenet','tinyimagenet', 'CUB200', 'STANFORD120', 'MIT67', 'Aircrafts', 'Dog120']:
+        # TODO
+        if name == 'tinyimagenet':
+            transform_train = transforms.Compose([
+                transforms.RandomResizedCrop(32),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
+            ])
+            transform_test = transforms.Compose([
+                transforms.Resize(32),
+                transforms.ToTensor(),
+                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
+            ])
+
+            train_val_dataset_dir = os.path.join(root, "train")
+            test_dataset_dir = os.path.join(root, "val")
+
+            trainset = DatasetWrapper(datasets.ImageFolder(root=train_val_dataset_dir, transform=transform_train))
+            valset = DatasetWrapper(datasets.ImageFolder(root=test_dataset_dir, transform=transform_test))
+
+        elif name == 'imagenet':
+            transform_train = transforms.Compose([
+                transforms.RandomResizedCrop(224),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
+            ])
+            transform_test = transforms.Compose([
+                transforms.Resize(256),
+                transforms.CenterCrop(224),
+                transforms.ToTensor(),
+                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
+            ])
+            train_val_dataset_dir = os.path.join(root, "train")
+            test_dataset_dir = os.path.join(root, "val")
+
+            trainset = DatasetWrapper(datasets.ImageFolder(root=train_val_dataset_dir, transform=transform_train))
+            valset = DatasetWrapper(datasets.ImageFolder(root=test_dataset_dir, transform=transform_test))
+
+        else:
+            transform_train = transforms.Compose([
+                transforms.Resize((256, 256)),
+                transforms.RandomResizedCrop(224),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
+            ])
+            transform_test = transforms.Compose([
+                transforms.Resize((256, 256)),
+                transforms.CenterCrop(224),
+                transforms.ToTensor(),
+                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
+            ])
+
+            train_val_dataset_dir = os.path.join(root, name, "train")
+            trainset = DatasetWrapper(datasets.ImageFolder(root=train_val_dataset_dir, transform=transform_train))
+
+            if name in ['Aircrafts']:
+                val_dataset_dir = os.path.join(root, name, "val")
+                test_dataset_dir = os.path.join(root, name, "test")
+                valset = DatasetWrapper(datasets.ImageFolder(root=val_dataset_dir, transform=transform_test))
+                testset = DatasetWrapper(datasets.ImageFolder(root=test_dataset_dir, transform=transform_test))
+            else:
+                test_dataset_dir = os.path.join(root, name, "test")
+                valset = DatasetWrapper(datasets.ImageFolder(root=test_dataset_dir, transform=transform_test))
+
+    elif name.startswith('cifar'):
+        transform_train = transforms.Compose([
+            transforms.RandomCrop(32, padding=4),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize((0.4914, 0.4822, 0.4465),
+                                 (0.2023, 0.1994, 0.2010)),
+        ])
+        transform_test = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+        ])
+
+        if name == 'cifar10':
+            CIFAR = datasets.CIFAR10
+        else:
+            CIFAR = datasets.CIFAR100
+
+        trainset = DatasetWrapper(CIFAR(root, train=True, download=True, transform=transform_train))
+        valset = DatasetWrapper(CIFAR(root, train=False, download=True, transform=transform_test))
+    else:
+        raise Exception('Unknown dataset: {}'.format(name))
+
+    # Sampler
+    if sample == 'default':
+        get_train_sampler = lambda d: BatchSampler(RandomSampler(d), kwargs['batch_size'], False)
+        get_test_sampler = lambda d: BatchSampler(SequentialSampler(d), kwargs['batch_size'], False)
+
+    elif sample == 'pair':
+        get_train_sampler = lambda d: PairBatchSampler(d, kwargs['batch_size'])
+        get_test_sampler = lambda d: BatchSampler(SequentialSampler(d), kwargs['batch_size'], False)
+
+    else:
+        raise Exception('Unknown sampling: {}'.format(sample))
+
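+    # With sample='pair', each train batch holds 2*batch_size indices:
+    # the sampled half followed by one random same-class partner per
+    # sample, which is what the CS-KD loss pairs up downstream.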
+    trainloader = DataLoader(trainset, batch_sampler=get_train_sampler(trainset), num_workers=10, pin_memory=True)
+    valloader = DataLoader(valset, batch_sampler=get_test_sampler(valset), num_workers=10, pin_memory=True)
+
+    epoch_size = len(trainset)
+    trainloader.num_batches = math.ceil(epoch_size / kwargs['batch_size'])
+    trainloader.num_files = epoch_size
+
+    epoch_size = len(valset)
+    valloader.num_batches = math.ceil(epoch_size / kwargs['batch_size'])
+    valloader.num_files = epoch_size
+
+    if name in ['Aircrafts']:
+        testloader = DataLoader(testset, batch_sampler=get_test_sampler(testset), num_workers=10, pin_memory=True)
+        epoch_size = len(testset)
+        testloader.num_batches = math.ceil(epoch_size / kwargs['batch_size'])
+        testloader.num_files = epoch_size
+        return trainloader, valloader, testloader
+    else:
+        return trainloader, valloader
diff --git a/data/dog.py b/data/dog.py
new file mode 100644
index 0000000..bd22f38
--- /dev/null
+++ b/data/dog.py
@@ -0,0 +1,21 @@
+import numpy as np
+import torch
+from data.datasets import load_dataset
+
+
+class Dog120:
+    def __init__(self, cfg):
+
+        if cfg.cs_kd:
+            sampler_type = 'pair'
+        else:
+            sampler_type = 'default'
+
+        trainloader, valloader = load_dataset('Dog120',
+                                              cfg.data,
+                                              sampler_type, batch_size=cfg.batch_size)
+        self.num_classes = trainloader.dataset.num_classes
+
+        self.train_loader = trainloader
+        self.tst_loader = valloader
+        self.val_loader = valloader
\ No newline at end of file
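
Note: Dog120 mirrors Aircrafts but has no separate test split, so it
reuses the validation loader as tst_loader. The pairing behaviour that
cs_kd relies on can also be checked in isolation; a self-contained
sketch on a toy dataset (the tensors, labels, and sizes are arbitrary,
and it assumes torchvision >= 0.3 so that DatasetWrapper skips its
0.2.0 compatibility branch):

    import torch
    from torch.utils.data import TensorDataset
    from data.datasets import DatasetWrapper, PairBatchSampler

    # Toy dataset: 12 samples across 3 classes.
    features = torch.arange(12, dtype=torch.float32).unsqueeze(1)
    labels = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])
    base = TensorDataset(features, labels)
    base.targets = labels.tolist()  # attribute DatasetWrapper reads

    wrapped = DatasetWrapper(base)
    sampler = PairBatchSampler(wrapped, batch_size=4)

    batch = next(iter(sampler))     # 8 indices: 4 sampled + 4 partners
    for a, b in zip(batch[:4], batch[4:]):
        assert wrapped.get_class(a) == wrapped.get_class(b)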