Skip to content
Open
89 changes: 89 additions & 0 deletions domainlab/algos/seg/builder_dann_seg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""
builder for Domain Adversarial Neural Network: accept different training scheme
"""
from domainlab.algos.a_algo_builder import NodeAlgoBuilder
from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor
from domainlab.algos.msels.c_msel_val import MSelValPerf
from domainlab.algos.observers.b_obvisitor import ObVisitor
from domainlab.algos.observers.c_obvisitor_cleanup import ObVisitorCleanUp
from domainlab.algos.trainers.hyper_scheduler import HyperSchedulerWarmupExponential
from domainlab.algos.trainers.zoo_trainer import TrainerChainNodeGetter
from domainlab.compos.nn_zoo.net_classif import ClassifDropoutReluLinear
from domainlab.compos.utils_conv_get_flat_dim import get_flat_dim
from domainlab.compos.zoo_nn import FeatExtractNNBuilderChainNodeGetter
from domainlab.models.model_dann import mk_dann
from domainlab.utils.utils_cuda import get_device
from domainlab.models.a_model_seg import AModelSeg


class NodeAlgoBuilderDANN(NodeAlgoBuilder):
"""
NodeAlgoBuilderDANN
"""

def init_business(self, exp):
"""
return trainer, model, observer
"""
task = exp.task
self._task = task
args = exp.args
task.get_list_domains_tr_te(args.tr_d, args.te_d)
device = get_device(args)
msel = MSelOracleVisitor(MSelValPerf(max_es=args.es))
observer = ObVisitor(msel)
observer = ObVisitorCleanUp(observer)

builder = FeatExtractNNBuilderChainNodeGetter(
args, arg_name_of_net="nname", arg_path_of_net="npath"
)() # request, @FIXME, constant string

unet, net_encoder = builder.init_business(
flag_pretrain=True,
dim_out=task.dim_y,
remove_last_layer=False,
args=args,
isize=(task.isize.i_c, task.isize.i_w, task.isize.i_h),
)

dim_feat = get_flat_dim(
net_encoder, task.isize.i_c, task.isize.i_h, task.isize.i_w
)

net_classifier = ClassifDropoutReluLinear(dim_feat, task.dim_y)

net_discriminator = self.reset_aux_net(net_encoder)
model = mk_dann(parent_class=AModelSeg, net_seg=unet)(
list_d_tr=task.list_domain_tr,
alpha=args.gamma_reg,
net_encoder=net_encoder,
net_discriminator=net_discriminator,
builder=self)

model = self.init_next_model(model, exp)
trainer = TrainerChainNodeGetter(args.trainer)(default="hyperscheduler")
trainer.init_business(model, task, observer, device, args)
if trainer.name == "hyperscheduler":
trainer.set_scheduler(
HyperSchedulerWarmupExponential,
total_steps=trainer.num_batches * args.warmup,
flag_update_epoch=False,
flag_update_batch=True,
)
return trainer, model, observer, device

def reset_aux_net(self, net_encoder):
"""
reset auxilliary neural network from task
note that net_encoder can also be a method like extract_semantic_feat
"""
dim_feat = get_flat_dim(
net_encoder,
self._task.isize.i_c,
self._task.isize.i_h,
self._task.isize.i_w,
)
net_discriminator = ClassifDropoutReluLinear(
dim_feat, len(self._task.list_domain_tr)
)
return net_discriminator
48 changes: 48 additions & 0 deletions domainlab/algos/seg/builder_erm_seg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@

# CHANGE IT

"""
builder for erm
"""
from domainlab.algos.a_algo_builder import NodeAlgoBuilder
from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor
from domainlab.algos.msels.c_msel_val import MSelValPerf
from domainlab.algos.observers.b_obvisitor import ObVisitor
from domainlab.algos.trainers.zoo_trainer import TrainerChainNodeGetter
from domainlab.algos.utils import split_net_feat_last
from domainlab.compos.zoo_nn import FeatExtractNNBuilderChainNodeGetter
from domainlab.models.model_erm import mk_erm
from domainlab.utils.utils_cuda import get_device
from domainlab.models.a_model_seg import CModelSeg


class NodeAlgoBuilderERM(NodeAlgoBuilder):
"""
builder for erm
"""

def init_business(self, exp):
"""
return trainer, model, observer
"""
task = exp.task
args = exp.args
device = get_device(args)
model_sel = MSelOracleVisitor(MSelValPerf(max_es=args.es))
observer = ObVisitor(model_sel)

builder = FeatExtractNNBuilderChainNodeGetter(
args, arg_name_of_net="nname", arg_path_of_net="npath"
)() # request, # @FIXME, constant string

net, net_feature = builder.init_business(
flag_pretrain=True,
dim_out=task.dim_y,
remove_last_layer=False,
args=args,
isize=(task.isize.i_c, task.isize.i_h, task.isize.i_w),
)

model = CModelSeg(net=net)
trainer = TrainerChainNodeGetter(args.trainer)(default="basic")
return trainer, model, observer, device
127 changes: 127 additions & 0 deletions domainlab/models/a_model_seg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
"""
operations that all claasification model should have
"""

import abc
import math

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.nn import functional as F

from domainlab.models.a_model import AModel
from domainlab.utils.logger import Logger
from domainlab.utils.perf import PerfClassif
from domainlab.utils.perf_metrics import PerfMetricClassif
from domainlab.utils.utils_class import store_args
from domainlab.utils.utils_classif import get_label_na, logit2preds_vpic
from domainlab.utils.utils_seg import DiceLoss


class ModelSeg(AModelSeg):
def __init__(self, unet):
self.unet = unet

def cal_task_loss(self, tensor_x, tensor_y):
y_hat = self.unet(tensor_x)
return DiceLoss(y_hat, tensor_y)


class AModelSeg(AModel, metaclass=abc.ABCMeta):
"""
operations that all classification model should have
"""

match_feat_fun_na = "cal_logit_y"

def extend(self, model):
super().extend(model)
self._net_classifier = model.net_classifier

@property
def metric4msel(self):
return "acc"

@net_classifier.setter
def net_classifier(self, net_classifier):
self._net_classifier = net_classifier

def create_perf_obj(self, task):
"""
for classification, dimension of target can be quieried from task
"""
self.perf_metric = PerfMetricClassif(task.dim_y)
return self.perf_metric

def cal_perf_metric(self, loader, device):
"""
classification performance metric
"""
metric = None
with torch.no_grad():
if loader is not None:
metric = self.perf_metric.cal_metrics(self, loader, device)
confmat = metric.pop("confmat")
logger = Logger.get_logger()
logger.info("scalar performance:")
logger.info(str(metric))
logger.debug("confusion matrix:")
logger.debug(pd.DataFrame(confmat))
metric["confmat"] = confmat
return metric

def evaluate(self, loader_te, device):
"""
for classification task, use the current model to cal acc
"""
acc = PerfClassif.cal_acc(self, loader_te, device)
logger = Logger.get_logger()
logger.info(f"before training, model accuracy: {acc}")

def extract_semantic_feat(self, tensor_x):
"""
flatten the shape of feature tensor from super()
"""
feat_tensor = super().extract_semantic_feat(tensor_x)
feat = feat_tensor.reshape(feat_tensor.shape[0], -1)
return feat

@store_args
def __init__(self, **kwargs):
"""
:param list_str_y: list of fixed order, each element is a class label
"""
super().__init__()
for key, value in kwargs.items():
if key == "net_seg":
net_seg = value

self.seg_cross_entropy_loss = nn.CrossEntropyLoss()
self.net_seg = net_seg

def cal_task_loss(self, tensor_x, tensor_y):
"""
Calculate the task loss. Used within the `cal_loss` methods of models
that are subclasses of `AModelClassif`. Cross entropy loss for
classification is used here by default but could be modified by
subclasses
as necessary.

:param tensor_x: input
:param tensor_y: label
:return: task loss
"""
masks_pred = self.net_seg(tensor_x)
loss = self.seg_cross_entropy_loss(masks_pred.squeeze(1), tensor_y.float())
loss += DiceLoss(masks_pred, tensor_y)
return loss

def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
"""
for ERM to adapt to the interface of other regularized learners
"""
device = tensor_x.device
bsize = tensor_x.shape[0]
return [torch.zeros(bsize).to(device)], [0.0]
69 changes: 69 additions & 0 deletions domainlab/utils/utils_seg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F


# implementation of Dice Loss
class DiceLoss(nn.Module):
def __init__(self, weight=None, size_average=True):
super(DiceLoss, self).__init__()

def forward(self, inputs, targets, smooth=1):

#comment out if your model contains a sigmoid or equivalent activation layer
inputs = F.sigmoid(inputs)

#flatten label and prediction tensors
inputs = inputs.view(-1)
targets = targets.view(-1)

intersection = (inputs * targets).sum()
dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)

return 1 - dice


class DiceBCELoss(nn.Module):
def __init__(self, weight=None, size_average=True):
super(DiceBCELoss, self).__init__()

def forward(self, inputs, targets, smooth=1):

#comment out if your model contains a sigmoid or equivalent activation layer
inputs = F.sigmoid(inputs)

#flatten label and prediction tensors
inputs = inputs.view(-1)
targets = targets.view(-1)

intersection = (inputs * targets).sum()
dice_loss = 1 - (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)
BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
Dice_BCE = BCE + dice_loss

return Dice_BCE


class IoULoss(nn.Module):
def __init__(self, weight=None, size_average=True):
super(IoULoss, self).__init__()

def forward(self, inputs, targets, smooth=1):

#comment out if your model contains a sigmoid or equivalent activation layer
inputs = F.sigmoid(inputs)

#flatten label and prediction tensors
inputs = inputs.view(-1)
targets = targets.view(-1)

#intersection is equivalent to True Positive count
#union is the mutually inclusive area of all labels & predictions
intersection = (inputs * targets).sum()
total = (inputs + targets).sum()
union = total - intersection

IoU = (intersection + smooth)/(union + smooth)

return 1 - IoU
6 changes: 6 additions & 0 deletions examples/conf/toy_seg_erm.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
te_d: domain1 # domain name of test domain
tpath: examples/tasks/task_seg_toy.py # python file path to specify the task
bs: 8 # batch size
model: seg #
epos: 1 # number of epochs
npath: examples/nets/seg/unet.py # neural network for class classification
Loading