class NodeAlgoBuilderDANN(NodeAlgoBuilder):
    """
    Builder that wires a segmentation network into a DANN model and
    attaches a trainer and observers to it.
    """

    def init_business(self, exp):
        """
        Create the trainer, model and observer for one experiment.

        :param exp: experiment object; ``exp.task`` and ``exp.args`` are read
        :return: tuple ``(trainer, model, observer, device)``
        """
        task = exp.task
        # remembered so that reset_aux_net can size the discriminator later
        self._task = task
        args = exp.args
        task.get_list_domains_tr_te(args.tr_d, args.te_d)
        device = get_device(args)
        msel = MSelOracleVisitor(MSelValPerf(max_es=args.es))
        observer = ObVisitor(msel)
        observer = ObVisitorCleanUp(observer)

        builder = FeatExtractNNBuilderChainNodeGetter(
            args, arg_name_of_net="nname", arg_path_of_net="npath"
        )()  # request, @FIXME, constant string

        # Image size is given as (channel, height, width), the same order
        # in which reset_aux_net hands the sizes to get_flat_dim.
        unet, net_encoder = builder.init_business(
            flag_pretrain=True,
            dim_out=task.dim_y,
            remove_last_layer=False,
            args=args,
            isize=(task.isize.i_c, task.isize.i_h, task.isize.i_w),
        )

        # Auxiliary network with one output per training domain.
        net_discriminator = self.reset_aux_net(net_encoder)
        model = mk_dann(parent_class=AModelSeg, net_seg=unet)(
            list_d_tr=task.list_domain_tr,
            alpha=args.gamma_reg,
            net_encoder=net_encoder,
            net_discriminator=net_discriminator,
            builder=self,
        )

        model = self.init_next_model(model, exp)
        trainer = TrainerChainNodeGetter(args.trainer)(default="hyperscheduler")
        trainer.init_business(model, task, observer, device, args)
        if trainer.name == "hyperscheduler":
            # The scheduler is stepped per batch, so the warm-up length is
            # expressed in batches: batches per epoch times warm-up epochs.
            trainer.set_scheduler(
                HyperSchedulerWarmupExponential,
                total_steps=trainer.num_batches * args.warmup,
                flag_update_epoch=False,
                flag_update_batch=True,
            )
        return trainer, model, observer, device

    def reset_aux_net(self, net_encoder):
        """
        Build the auxiliary (domain discriminator) network for the task.

        Note that net_encoder can also be a method like
        extract_semantic_feat.

        :param net_encoder: callable whose flattened output size is measured
            with ``get_flat_dim``
        :return: classifier with one output per training domain
        """
        dim_feat = get_flat_dim(
            net_encoder,
            self._task.isize.i_c,
            self._task.isize.i_h,
            self._task.isize.i_w,
        )
        net_discriminator = ClassifDropoutReluLinear(
            dim_feat, len(self._task.list_domain_tr)
        )
        return net_discriminator
class ModelSeg(AModelSeg):
    """
    Segmentation model that scores a network's output with a Dice loss.

    NOTE(review): this class has to be defined after ``AModelSeg`` in the
    module; placed before it, the class statement raises ``NameError``
    when the module is imported.
    """

    def __init__(self, unet):
        """
        :param unet: network mapping an input batch to segmentation logits
        """
        # The parent must be initialised before attributes are assigned;
        # AModelSeg picks the network up from the ``net_seg`` keyword.
        super().__init__(net_seg=unet)
        # Same object as ``net_seg``; kept under the name this class uses.
        self.unet = unet
        # DiceLoss is a module: build it once, call it per batch.
        self.dice_loss = DiceLoss()

    def cal_task_loss(self, tensor_x, tensor_y):
        """
        Compute the Dice loss of the network prediction.

        :param tensor_x: input batch
        :param tensor_y: target masks
        :return: value returned by ``DiceLoss`` for the prediction
        """
        y_hat = self.unet(tensor_x)
        return self.dice_loss(y_hat, tensor_y)
@property
def net_classifier(self):
    """
    Return the classifier network stored on this model.

    The getter has to exist before ``@net_classifier.setter`` can be used:
    without it the name ``net_classifier`` is undefined while the class
    body executes.
    """
    return self._net_classifier

@net_classifier.setter
def net_classifier(self, net_classifier):
    """
    Store ``net_classifier`` as the classifier network of this model.

    :param net_classifier: network to store
    """
    self._net_classifier = net_classifier
# implementation of Dice Loss
class DiceLoss(nn.Module):
    """
    Soft Dice loss computed over all elements of a batch at once.

    ``weight`` and ``size_average`` are accepted for interface
    compatibility only; neither is used.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """
        :param inputs: raw logits; a sigmoid is applied here
        :param targets: target mask with as many elements as ``inputs``
        :param smooth: constant added to numerator and denominator so the
            ratio stays defined when both tensors sum to zero
        :return: scalar tensor ``1 - dice``
        """
        # remove the sigmoid if your model already ends with a sigmoid or
        # an equivalent activation layer
        probs = torch.sigmoid(inputs)

        # flatten label and prediction tensors; reshape (unlike view) also
        # accepts non-contiguous tensors such as transposed masks
        probs = probs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (probs * targets).sum()
        dice = (2.0 * intersection + smooth) / (
            probs.sum() + targets.sum() + smooth
        )

        return 1 - dice
class UNET_Feature_Extractor(nn.Module):
    """
    Expose the ``go_down`` pass of a network as a stand-alone module.
    """

    # The network is shared between unet feature extractor and the unet
    def __init__(self, net):
        """
        :param net: network providing a ``go_down`` method; it is stored
            by reference, not copied
        """
        # nn.Module refuses to register a submodule before its own
        # __init__ has run, so this call must come first.
        super().__init__()
        self.net = net

    def forward(self, x):
        """
        :param x: input batch
        :return: whatever ``net.go_down`` returns for ``x``
        """
        return self.net.go_down(x)
def go_down(self, x):
    """
    Run the contracting path and cache every intermediate feature map.

    The maps ``x1`` .. ``x4`` are stored on the instance because
    ``forward`` reads them back as skip connections.

    :param x: input batch
    :return: output of the last down-sampling stage (also kept as
        ``self.x5``)
    """
    self.x1 = self.inc(x)
    self.x2 = self.down1(self.x1)
    self.x3 = self.down2(self.x2)
    self.x4 = self.down3(self.x3)
    self.x5 = self.down4(self.x4)
    # Callers consume the return value, so it must not be left as None.
    return self.x5
class DoubleConv(nn.Module):
    """Two consecutive stages of 3x3 convolution, batch norm and ReLU."""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        """
        :param in_channels: channels of the input feature map
        :param out_channels: channels of the output feature map
        :param mid_channels: channels between the two stages; falls back
            to ``out_channels`` when not given
        """
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels
        stages = []
        # Built in the order conv, norm, activation so the layer indices
        # inside the Sequential stay 0..5.
        for n_in, n_out in ((in_channels, hidden), (hidden, out_channels)):
            stages.append(
                nn.Conv2d(n_in, n_out, kernel_size=3, padding=1, bias=False)
            )
            stages.append(nn.BatchNorm2d(n_out))
            stages.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both stages to ``x`` and return the result."""
        out = self.double_conv(x)
        return out
class ToySegmentationDataset(Dataset):
    """
    Random images paired with masks that each contain one square segment.
    """

    def __init__(self, num_samples=1000, image_size=(64, 64), noise=0.1,
                 num_channels=1):
        """
        :param num_samples: number of (image, mask) pairs to generate
        :param image_size: ``(height, width)`` of images and masks
        :param noise: scale of the Gaussian noise added to the masks
        :param num_channels: channels of the generated images; masks
            always have a single channel
        """
        self.num_samples = num_samples
        self.image_size = image_size
        self.noise = noise
        self.num_channels = num_channels
        self.generate_data()

    def generate_data(self):
        """Fill ``self.images`` and ``self.masks`` with random data."""
        height, width = self.image_size
        # Generate random images
        self.images = torch.rand(
            self.num_samples, self.num_channels, height, width
        )
        # Generate random segmentation masks
        self.masks = torch.zeros(self.num_samples, 1, height, width)
        for i in range(self.num_samples):
            # Randomly select a segment size
            segment_size = np.random.randint(10, 20)
            # Clamp to the image so that small images do not make the
            # position range empty (randint raises when low >= high)
            seg_h = min(segment_size, height)
            seg_w = min(segment_size, width)
            # Randomly select a position for the segment; the upper bound
            # is exclusive, hence the + 1 to allow touching the border
            x = np.random.randint(0, height - seg_h + 1)
            y = np.random.randint(0, width - seg_w + 1)
            # Add the segment to the mask
            self.masks[i, 0, x:x + seg_h, y:y + seg_w] = 1
        # Add noise to the masks
        self.masks += self.noise * torch.randn_like(self.masks)

    def __len__(self):
        """Return the number of samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair at position ``idx``."""
        return self.images[idx], self.masks[idx]