From 7baea73853334c0e3a0d0c7f52588f09664f9a65 Mon Sep 17 00:00:00 2001
From: matteowohlrapp
Date: Mon, 15 Apr 2024 12:57:41 +0200
Subject: [PATCH 01/21] fixed issue where hyperparameters are colliding

---
 domainlab/models/a_model.py     |  1 +
 domainlab/models/model_dann.py  |  6 ++++--
 domainlab/models/model_diva.py  | 12 ++++++++----
 domainlab/models/model_hduva.py | 20 +++++++++++---------
 4 files changed, 24 insertions(+), 15 deletions(-)

diff --git a/domainlab/models/a_model.py b/domainlab/models/a_model.py
index beb867167..fdc1c1299 100644
--- a/domainlab/models/a_model.py
+++ b/domainlab/models/a_model.py
@@ -20,6 +20,7 @@ def __init__(self):
         self.list_d_tr = None
         self.visitor = None
         self._net_invar_feat = None
+        self.id = str(id(self)) + "_model"
 
     def extend(self, model):
         """
diff --git a/domainlab/models/model_dann.py b/domainlab/models/model_dann.py
index 2abd3feda..8c74a16bc 100644
--- a/domainlab/models/model_dann.py
+++ b/domainlab/models/model_dann.py
@@ -85,13 +85,15 @@ def hyper_update(self, epoch, fun_scheduler):
         dict_rst = fun_scheduler(
             epoch
         )  # the __call__ method of hyperparameter scheduler
-        self.alpha = dict_rst["alpha"]
+        self.alpha = dict_rst[self.id + "_alpha"]
 
     def hyper_init(self, functor_scheduler):
         """hyper_init.
         :param functor_scheduler:
         """
-        return functor_scheduler(trainer=None, alpha=self.alpha)
+        parameters = {}
+        parameters[self.id + "_alpha"] = self.alpha
+        return functor_scheduler(trainer=None, **parameters)
 
     def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others):
         _ = others
diff --git a/domainlab/models/model_diva.py b/domainlab/models/model_diva.py
index b67272784..1be0d2613 100644
--- a/domainlab/models/model_diva.py
+++ b/domainlab/models/model_diva.py
@@ -98,9 +98,9 @@ def hyper_update(self, epoch, fun_scheduler):
         :param fun_scheduler:
         """
         dict_rst = fun_scheduler(epoch)
-        self.beta_d = dict_rst["beta_d"]
-        self.beta_y = dict_rst["beta_y"]
-        self.beta_x = dict_rst["beta_x"]
+        self.beta_d = dict_rst[self.id + "_beta_d"]
+        self.beta_y = dict_rst[self.id + "_beta_y"]
+        self.beta_x = dict_rst[self.id + "_beta_x"]
 
     def hyper_init(self, functor_scheduler):
         """
@@ -108,8 +108,12 @@ def hyper_init(self, functor_scheduler):
 
         :param functor_scheduler: the class name of the scheduler
         """
+        parameters = {}
+        parameters[self.id + "_beta_d"] = self.beta_d
+        parameters[self.id + "_beta_y"] = self.beta_y
+        parameters[self.id + "_beta_x"] = self.beta_x
         return functor_scheduler(
-            trainer=None, beta_d=self.beta_d, beta_y=self.beta_y, beta_x=self.beta_x
+            trainer=None, **parameters
         )
 
     def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
diff --git a/domainlab/models/model_hduva.py b/domainlab/models/model_hduva.py
index 61411982e..2a3edd7cb 100644
--- a/domainlab/models/model_hduva.py
+++ b/domainlab/models/model_hduva.py
@@ -66,10 +66,10 @@ def hyper_update(self, epoch, fun_scheduler):
         dict_rst = fun_scheduler(
             epoch
         )  # the __call__ function of hyper-para-scheduler object
-        self.beta_d = dict_rst["beta_d"]
-        self.beta_y = dict_rst["beta_y"]
-        self.beta_x = dict_rst["beta_x"]
-        self.beta_t = dict_rst["beta_t"]
+        self.beta_d = dict_rst[self.id + "_beta_d"]
+        self.beta_y = dict_rst[self.id + "_beta_y"]
+        self.beta_x = dict_rst[self.id + "_beta_x"]
+        self.beta_t = dict_rst[self.id + "_beta_t"]
 
     def hyper_init(self, functor_scheduler):
         """hyper_init.
@@ -78,12 +78,14 @@ def hyper_init(self, functor_scheduler): # calling the constructor of the hyper-parameter-scheduler class, so that this scheduler # class build a dictionary {"beta_d":self.beta_d, "beta_y":self.beta_y} # constructor signature is def __init__(self, **kwargs): + + parameters = {} + parameters[self.id + "_beta_d"] = self.beta_d + parameters[self.id + "_beta_y"] = self.beta_y + parameters[self.id + "_beta_x"] = self.beta_x + parameters[self.id + "_beta_t"] = self.beta_t return functor_scheduler( - trainer=None, - beta_d=self.beta_d, - beta_y=self.beta_y, - beta_x=self.beta_x, - beta_t=self.beta_t, + trainer=None, **parameters ) @store_args From 07086a2389ff6e90a14e13dbd25d4dd7921fa317 Mon Sep 17 00:00:00 2001 From: matteowohlrapp Date: Tue, 16 Apr 2024 17:01:53 +0200 Subject: [PATCH 02/21] changed id to name for model identification --- domainlab/models/a_model.py | 25 ++++++++++++++++++++++++- domainlab/models/model_dann.py | 4 ++-- domainlab/models/model_diva.py | 12 ++++++------ domainlab/models/model_hduva.py | 17 ++++++++--------- 4 files changed, 40 insertions(+), 18 deletions(-) diff --git a/domainlab/models/a_model.py b/domainlab/models/a_model.py index fdc1c1299..b53800b8e 100644 --- a/domainlab/models/a_model.py +++ b/domainlab/models/a_model.py @@ -20,7 +20,6 @@ def __init__(self): self.list_d_tr = None self.visitor = None self._net_invar_feat = None - self.id = str(id(self)) + "_model" def extend(self, model): """ @@ -179,3 +178,27 @@ def dset_decoration_args_algo(self, args, ddset): if self._decoratee is not None: return self._decoratee.dset_decoration_args_algo(args, ddset) return ddset + + @property + def p_na_prefix(self): + """ + common prefix for Models + """ + return "Model" + + @property + def name(self): + """ + get the name of the algorithm + """ + na_prefix = self.p_na_prefix + len_prefix = len(na_prefix) + na_class = type(self).__name__ + if na_class[:len_prefix] != na_prefix: + raise RuntimeError( + "Trainer builder node class must start with ", + na_prefix, + "the current class is named: ", + na_class, + ) + return type(self).__name__[len_prefix:].lower() diff --git a/domainlab/models/model_dann.py b/domainlab/models/model_dann.py index 8c74a16bc..59a619e7e 100644 --- a/domainlab/models/model_dann.py +++ b/domainlab/models/model_dann.py @@ -85,14 +85,14 @@ def hyper_update(self, epoch, fun_scheduler): dict_rst = fun_scheduler( epoch ) # the __call__ method of hyperparameter scheduler - self.alpha = dict_rst[self.id + "_alpha"] + self.alpha = dict_rst[self.name + "_alpha"] def hyper_init(self, functor_scheduler): """hyper_init. 
         :param functor_scheduler:
         """
         parameters = {}
-        parameters[self.id + "_alpha"] = self.alpha
+        parameters[self.name + "_alpha"] = self.alpha
         return functor_scheduler(trainer=None, **parameters)
 
     def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others):
diff --git a/domainlab/models/model_diva.py b/domainlab/models/model_diva.py
index 1be0d2613..362c14518 100644
--- a/domainlab/models/model_diva.py
+++ b/domainlab/models/model_diva.py
@@ -98,9 +98,9 @@ def hyper_update(self, epoch, fun_scheduler):
         :param fun_scheduler:
         """
         dict_rst = fun_scheduler(epoch)
-        self.beta_d = dict_rst[self.id + "_beta_d"]
-        self.beta_y = dict_rst[self.id + "_beta_y"]
-        self.beta_x = dict_rst[self.id + "_beta_x"]
+        self.beta_d = dict_rst[self.name + "_beta_d"]
+        self.beta_y = dict_rst[self.name + "_beta_y"]
+        self.beta_x = dict_rst[self.name + "_beta_x"]
 
     def hyper_init(self, functor_scheduler):
         """
@@ -109,9 +109,9 @@ def hyper_init(self, functor_scheduler):
         :param functor_scheduler: the class name of the scheduler
         """
         parameters = {}
-        parameters[self.id + "_beta_d"] = self.beta_d
-        parameters[self.id + "_beta_y"] = self.beta_y
-        parameters[self.id + "_beta_x"] = self.beta_x
+        parameters[self.name + "_beta_d"] = self.beta_d
+        parameters[self.name + "_beta_y"] = self.beta_y
+        parameters[self.name + "_beta_x"] = self.beta_x
         return functor_scheduler(
             trainer=None, **parameters
         )
diff --git a/domainlab/models/model_hduva.py b/domainlab/models/model_hduva.py
index 2a3edd7cb..5c7bb290d 100644
--- a/domainlab/models/model_hduva.py
+++ b/domainlab/models/model_hduva.py
@@ -66,10 +66,10 @@ def hyper_update(self, epoch, fun_scheduler):
         dict_rst = fun_scheduler(
             epoch
         )  # the __call__ function of hyper-para-scheduler object
-        self.beta_d = dict_rst[self.id + "_beta_d"]
-        self.beta_y = dict_rst[self.id + "_beta_y"]
-        self.beta_x = dict_rst[self.id + "_beta_x"]
-        self.beta_t = dict_rst[self.id + "_beta_t"]
+        self.beta_d = dict_rst[self.name + "_beta_d"]
+        self.beta_y = dict_rst[self.name + "_beta_y"]
+        self.beta_x = dict_rst[self.name + "_beta_x"]
+        self.beta_t = dict_rst[self.name + "_beta_t"]
 
     def hyper_init(self, functor_scheduler):
         """hyper_init.
@@ -78,12 +78,11 @@ def hyper_init(self, functor_scheduler):
         # calling the constructor of the hyper-parameter-scheduler class, so that this scheduler
         # class build a dictionary {"beta_d":self.beta_d, "beta_y":self.beta_y}
         # constructor signature is def __init__(self, **kwargs):
-
         parameters = {}
-        parameters[self.id + "_beta_d"] = self.beta_d
-        parameters[self.id + "_beta_y"] = self.beta_y
-        parameters[self.id + "_beta_x"] = self.beta_x
-        parameters[self.id + "_beta_t"] = self.beta_t
+        parameters[self.name + "_beta_d"] = self.beta_d
+        parameters[self.name + "_beta_y"] = self.beta_y
+        parameters[self.name + "_beta_x"] = self.beta_x
+        parameters[self.name + "_beta_t"] = self.beta_t
         return functor_scheduler(
             trainer=None, **parameters
         )

From 871374a407b0e75d5dba28478b1eab0a3cdf2b61 Mon Sep 17 00:00:00 2001
From: matteowohlrapp
Date: Mon, 22 Apr 2024 11:34:25 +0200
Subject: [PATCH 03/21] Solved gamma_reg naming collision by introducing
 functionality to pass different gamma values for different trainers and
 models based on the naming of the class.
Also added parameter printing --- docs/doc_usage_cmd.md | 20 +++++++++++++-- domainlab/algos/builder_dann.py | 3 ++- domainlab/algos/builder_jigen1.py | 3 ++- domainlab/algos/trainers/a_trainer.py | 7 ++++++ domainlab/algos/trainers/train_dial.py | 3 ++- domainlab/algos/trainers/train_fishr.py | 3 ++- domainlab/algos/trainers/train_matchdg.py | 3 ++- domainlab/algos/trainers/train_mldg.py | 3 ++- domainlab/arg_parser.py | 18 +++++++++++++- domainlab/models/a_model.py | 7 ++++++ domainlab/utils/hyperparameter_retrieval.py | 20 +++++++++++++++ examples/conf/vlcs_diva_mldg_dial.yaml | 5 +++- tests/test_hyperparameter_retrieval.py | 27 +++++++++++++++++++++ 13 files changed, 112 insertions(+), 10 deletions(-) create mode 100644 domainlab/utils/hyperparameter_retrieval.py create mode 100644 tests/test_hyperparameter_retrieval.py diff --git a/docs/doc_usage_cmd.md b/docs/doc_usage_cmd.md index 0dbb16a0f..e8e2f2480 100644 --- a/docs/doc_usage_cmd.md +++ b/docs/doc_usage_cmd.md @@ -21,8 +21,24 @@ To run DomainLab, the minimum necessary parameters are: ### Advanced Configuration - **Learning Rate (`--lr`):** Set the training learning rate. -- **Regularization (`--gamma_reg`):** Weight of regularization loss. -- **Early Stopping (`--es`):** Steps for early stopping. +- **Regularization (`--gamma_reg`):** Sets the weight of the regularization loss. This parameter can be configured either as a single value applied to individual classes, or using a dictionary to specify different weights for different models and trainers. + + - **Command Line Usage:** + - For a single value: `python script.py --gamma_reg=0.1` + - For multiple values: `python script.py --gamma_reg='default=0.1,dann=0.05,diva=0.2'` + + - **YAML Configuration:** + - For a single value: + ```yaml + gamma_reg: 0.1 + ``` + - For different values: + ```yaml + gamma_reg: + default: 0.1 # every other instance that is not listed below will get this value assigned + dann: 0.05 + diva: 0.2 + ```- **Early Stopping (`--es`):** Steps for early stopping. - **Random Seed (`--seed`):** Seed for reproducibility. - **CUDA Options (`--nocu`, `--device`):** Configure CUDA usage and device settings. - **Generated Images (`--gen`):** Option to save generated images. 
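
[Illustration — not part of the patch: the behavior documented above can be exercised end to end with the two helpers this commit introduces, `mk_parser_main` in `domainlab/arg_parser.py` and `get_gamma_reg` in `domainlab/utils/hyperparameter_retrieval.py`. The sketch below mirrors the unit tests added at the end of this commit.]

```python
from domainlab.arg_parser import mk_parser_main
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg

parser = mk_parser_main()

# a single float still applies to every model and trainer
args = parser.parse_args(["--gamma_reg", "0.1"])
assert get_gamma_reg(args, "dann") == 0.1

# key=value pairs give each component its own weight;
# "default" covers every component not listed explicitly
args = parser.parse_args(["--gamma_reg", "default=0.1,dann=0.05,jigen=0.2"])
assert get_gamma_reg(args, "dann") == 0.05
assert get_gamma_reg(args, "mldg") == 0.1  # falls back to "default"
```
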
diff --git a/domainlab/algos/builder_dann.py b/domainlab/algos/builder_dann.py index 65b26a62a..73b373c15 100644 --- a/domainlab/algos/builder_dann.py +++ b/domainlab/algos/builder_dann.py @@ -13,6 +13,7 @@ from domainlab.compos.zoo_nn import FeatExtractNNBuilderChainNodeGetter from domainlab.models.model_dann import mk_dann from domainlab.utils.utils_cuda import get_device +from domainlab.utils.hyperparameter_retrieval import get_gamma_reg class NodeAlgoBuilderDANN(NodeAlgoBuilder): @@ -55,7 +56,7 @@ def init_business(self, exp): model = mk_dann(list_str_y=task.list_str_y, net_classifier=net_classifier)( list_d_tr=task.list_domain_tr, - alpha=args.gamma_reg, + alpha=get_gamma_reg(args, 'dann'), net_encoder=net_encoder, net_discriminator=net_discriminator, builder=self) diff --git a/domainlab/algos/builder_jigen1.py b/domainlab/algos/builder_jigen1.py index e899e32f2..de671affe 100644 --- a/domainlab/algos/builder_jigen1.py +++ b/domainlab/algos/builder_jigen1.py @@ -15,6 +15,7 @@ from domainlab.dsets.utils_wrapdset_patches import WrapDsetPatches from domainlab.models.model_jigen import mk_jigen from domainlab.utils.utils_cuda import get_device +from domainlab.utils.hyperparameter_retrieval import get_gamma_reg class NodeAlgoBuilderJiGen(NodeAlgoBuilder): @@ -56,7 +57,7 @@ def init_business(self, exp): model = mk_jigen( list_str_y=task.list_str_y, net_classifier=net_classifier)( - coeff_reg=args.gamma_reg, + coeff_reg=get_gamma_reg(args, 'jigen'), net_encoder=net_encoder, net_classifier_permutation=net_classifier_perm, n_perm=args.nperm, diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py index 9bc126705..8d06a448e 100644 --- a/domainlab/algos/trainers/a_trainer.py +++ b/domainlab/algos/trainers/a_trainer.py @@ -256,3 +256,10 @@ def dset_decoration_args_algo(self, args, ddset): if self._decoratee is not None: return self._decoratee.dset_decoration_args_algo(args, ddset) return ddset + + def print_parameters(self): + """ + Function to print all parameters of the object. 
Can also be used to print the parameters of the child classes + """ + params = vars(self) + print(f"Parameters of {type(self).__name__}: {params}") diff --git a/domainlab/algos/trainers/train_dial.py b/domainlab/algos/trainers/train_dial.py index 75a5e34f0..4fe700f45 100644 --- a/domainlab/algos/trainers/train_dial.py +++ b/domainlab/algos/trainers/train_dial.py @@ -5,6 +5,7 @@ from torch.autograd import Variable from domainlab.algos.trainers.train_basic import TrainerBasic +from domainlab.utils.hyperparameter_retrieval import get_gamma_reg class TrainerDIAL(TrainerBasic): @@ -49,4 +50,4 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): tensor_x_adv = self.gen_adversarial(self.device, tensor_x, tensor_y) tensor_x_batch_adv_no_grad = Variable(tensor_x_adv, requires_grad=False) loss_dial = self.model.cal_task_loss(tensor_x_batch_adv_no_grad, tensor_y) - return [loss_dial], [self.aconf.gamma_reg] + return [loss_dial], [get_gamma_reg(self.aconf, self.name)] diff --git a/domainlab/algos/trainers/train_fishr.py b/domainlab/algos/trainers/train_fishr.py index 1a11e3780..3580a0721 100644 --- a/domainlab/algos/trainers/train_fishr.py +++ b/domainlab/algos/trainers/train_fishr.py @@ -13,6 +13,7 @@ backpack = None from domainlab.algos.trainers.train_basic import TrainerBasic +from domainlab.utils.hyperparameter_retrieval import get_gamma_reg class TrainerFishr(TrainerBasic): @@ -39,7 +40,7 @@ def tr_epoch(self, epoch): dict_layerwise_var_var_grads_sum = \ {key: val.sum() for key, val in dict_layerwise_var_var_grads.items()} loss_fishr = sum(dict_layerwise_var_var_grads_sum.values()) - loss = sum(list_loss_erm) + self.aconf.gamma_reg * loss_fishr + loss = sum(list_loss_erm) + get_gamma_reg(self.aconf, self.name) * loss_fishr loss.backward() self.optimizer.step() self.epo_loss_tr += loss.detach().item() diff --git a/domainlab/algos/trainers/train_matchdg.py b/domainlab/algos/trainers/train_matchdg.py index 6a3edd996..72c14ab83 100644 --- a/domainlab/algos/trainers/train_matchdg.py +++ b/domainlab/algos/trainers/train_matchdg.py @@ -13,6 +13,7 @@ ) from domainlab.tasks.utils_task_dset import DsetIndDecorator4XYD from domainlab.utils.logger import Logger +from domainlab.utils.hyperparameter_retrieval import get_gamma_reg class TrainerMatchDG(AbstractTrainer): @@ -36,7 +37,7 @@ def init_business( self.base_domain_size = get_base_domain_size4match_dg(self.task) self.epo_loss_tr = 0 self.flag_erm = flag_erm - self.lambda_ctr = self.aconf.gamma_reg + self.lambda_ctr = get_gamma_reg(aconf, self.name) self.mk_match_tensor(epoch=0) self.flag_match_tensor_sweep_over = False self.tuple_tensor_ref_domain2each_y = None diff --git a/domainlab/algos/trainers/train_mldg.py b/domainlab/algos/trainers/train_mldg.py index 90318286c..e91310adf 100644 --- a/domainlab/algos/trainers/train_mldg.py +++ b/domainlab/algos/trainers/train_mldg.py @@ -10,6 +10,7 @@ from domainlab.algos.trainers.train_basic import TrainerBasic from domainlab.tasks.utils_task import mk_loader from domainlab.tasks.utils_task_dset import DsetZip +from domainlab.utils.hyperparameter_retrieval import get_gamma_reg class TrainerMLDG(AbstractTrainer): @@ -108,7 +109,7 @@ def tr_epoch(self, epoch): loss = ( loss_source_task.sum() + source_reg_tr.sum() - + self.aconf.gamma_reg * loss_look_forward.sum() + + get_gamma_reg(self.aconf, self.name) * loss_look_forward.sum() ) # loss.backward() diff --git a/domainlab/arg_parser.py b/domainlab/arg_parser.py index 7c7e004e7..eb8c6836c 100644 --- a/domainlab/arg_parser.py +++ 
b/domainlab/arg_parser.py @@ -12,6 +12,19 @@ from domainlab.models.args_vae import add_args2parser_vae from domainlab.utils.logger import Logger +class StoreDictKeyPair(argparse.Action): + def __call__(self, parser, namespace, values, option_string=None): + try: + if "=" in values: + my_dict = {} + for kv in values.split(","): + k, v = kv.split("=") + my_dict[k.strip()] = float(v.strip()) # Assuming values are floats + setattr(namespace, self.dest, my_dict) + else: + setattr(namespace, self.dest, float(values)) # Single float value + except ValueError: + raise argparse.ArgumentError(self, f"Invalid value for {self.dest}: {values}") def mk_parser_main(): """ @@ -31,7 +44,10 @@ def mk_parser_main(): parser.add_argument("--lr", type=float, default=1e-4, help="learning rate") parser.add_argument( - "--gamma_reg", type=float, default=0.1, help="weight of regularization loss" + "--gamma_reg", + default=0.1, + help="weight of regularization loss, can specify per model as 'dann=1.0,diva=2.0'", + action=StoreDictKeyPair ) parser.add_argument("--es", type=int, default=1, help="early stop steps") diff --git a/domainlab/models/a_model.py b/domainlab/models/a_model.py index b53800b8e..6ea45747f 100644 --- a/domainlab/models/a_model.py +++ b/domainlab/models/a_model.py @@ -202,3 +202,10 @@ def name(self): na_class, ) return type(self).__name__[len_prefix:].lower() + + def print_parameters(self): + """ + Function to print all parameters of the object. Can also be used to print the parameters of the child classes + """ + params = vars(self) + print(f"Parameters of {type(self).__name__}: {params}") diff --git a/domainlab/utils/hyperparameter_retrieval.py b/domainlab/utils/hyperparameter_retrieval.py new file mode 100644 index 000000000..8e49c356c --- /dev/null +++ b/domainlab/utils/hyperparameter_retrieval.py @@ -0,0 +1,20 @@ +""" +retrieval for hyperparameters +""" + +def get_gamma_reg(args, model_name): + """ + Retrieves either a shared gamma regularization, or individual ones for each specified object + """ + gamma_reg = args.gamma_reg + print(gamma_reg) + if isinstance(gamma_reg, dict): + print("is instance dict") + if model_name in gamma_reg: + return gamma_reg[model_name] + elif 'default' in gamma_reg: + return gamma_reg['default'] + else: + raise ValueError("If a gamma_reg dict is specified, but no value set for every model and trainer, a default value must be specified.") + else: + return gamma_reg # Return the single value if it's not a dictionary diff --git a/examples/conf/vlcs_diva_mldg_dial.yaml b/examples/conf/vlcs_diva_mldg_dial.yaml index ac2d83337..5cf1a8041 100644 --- a/examples/conf/vlcs_diva_mldg_dial.yaml +++ b/examples/conf/vlcs_diva_mldg_dial.yaml @@ -5,7 +5,10 @@ val_threshold: 0.8 model: dann_diva # combine model DANN with DIVA epos: 1 # number of epochs trainer: mldg_dial # combine trainer MLDG and DIAL -gamma_reg: 1.0 # hyperparameter of DANN +gamma_reg: + default: 1.0 + dann: 1.5 + diva: 2.0 gamma_y: 700000.0 # hyperparameter of diva gamma_d: 100000.0 # hyperparameter of diva npath: examples/nets/resnet.py # neural network for class classification diff --git a/tests/test_hyperparameter_retrieval.py b/tests/test_hyperparameter_retrieval.py new file mode 100644 index 000000000..192343617 --- /dev/null +++ b/tests/test_hyperparameter_retrieval.py @@ -0,0 +1,27 @@ +# Assume imports and any necessary setup are already done +import argparse +from domainlab.arg_parser import StoreDictKeyPair, mk_parser_main +from domainlab.utils.hyperparameter_retrieval import get_gamma_reg + +def 
test_store_dict_key_pair_single_value(): + parser = mk_parser_main() + args = parser.parse_args(['--gamma_reg', '0.5']) + assert args.gamma_reg == 0.5 + +def test_store_dict_key_pair_dict_value(): + parser = mk_parser_main() + args = parser.parse_args(['--gamma_reg', 'dann=1.0,diva=2.0']) + assert args.gamma_reg == {'dann': 1.0, 'diva': 2.0} + +def test_get_gamma_reg_single_value(): + parser = mk_parser_main() + args = parser.parse_args(['--gamma_reg', '0.5']) + assert get_gamma_reg(args, 'dann') == 0.5 + +def test_get_gamma_reg_dict_value(): + parser = mk_parser_main() + args = parser.parse_args(['--gamma_reg', 'default=5.0,dann=1.0,diva=2.0']) + print(args) + assert get_gamma_reg(args, 'dann') == 1.0 + assert get_gamma_reg(args, 'diva') == 2.0 + assert get_gamma_reg(args, 'nonexistent') == 5.0 From 4be4a021bdb35374372570b9f6d49a8a47c4b050 Mon Sep 17 00:00:00 2001 From: matteowohlrapp Date: Tue, 23 Apr 2024 17:43:30 +0200 Subject: [PATCH 04/21] fixed codacy --- domainlab/algos/trainers/a_trainer.py | 5 +++-- domainlab/arg_parser.py | 2 ++ domainlab/models/a_model.py | 5 +++-- domainlab/utils/hyperparameter_retrieval.py | 12 +++++++----- tests/test_hyperparameter_retrieval.py | 11 ++++++++--- 5 files changed, 23 insertions(+), 12 deletions(-) diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py index 8d06a448e..0e3d4a168 100644 --- a/domainlab/algos/trainers/a_trainer.py +++ b/domainlab/algos/trainers/a_trainer.py @@ -256,10 +256,11 @@ def dset_decoration_args_algo(self, args, ddset): if self._decoratee is not None: return self._decoratee.dset_decoration_args_algo(args, ddset) return ddset - + def print_parameters(self): """ - Function to print all parameters of the object. Can also be used to print the parameters of the child classes + Function to print all parameters of the object. + Can be used to print the parameters of any child class """ params = vars(self) print(f"Parameters of {type(self).__name__}: {params}") diff --git a/domainlab/arg_parser.py b/domainlab/arg_parser.py index eb8c6836c..50ab707db 100644 --- a/domainlab/arg_parser.py +++ b/domainlab/arg_parser.py @@ -13,6 +13,8 @@ from domainlab.utils.logger import Logger class StoreDictKeyPair(argparse.Action): + """Class used for arg parsing where values are provided in a key value format""" + def __call__(self, parser, namespace, values, option_string=None): try: if "=" in values: diff --git a/domainlab/models/a_model.py b/domainlab/models/a_model.py index 6ea45747f..60571248d 100644 --- a/domainlab/models/a_model.py +++ b/domainlab/models/a_model.py @@ -178,7 +178,7 @@ def dset_decoration_args_algo(self, args, ddset): if self._decoratee is not None: return self._decoratee.dset_decoration_args_algo(args, ddset) return ddset - + @property def p_na_prefix(self): """ @@ -205,7 +205,8 @@ def name(self): def print_parameters(self): """ - Function to print all parameters of the object. Can also be used to print the parameters of the child classes + Function to print all parameters of the object. + Can be used to print the parameters of every child class. 
""" params = vars(self) print(f"Parameters of {type(self).__name__}: {params}") diff --git a/domainlab/utils/hyperparameter_retrieval.py b/domainlab/utils/hyperparameter_retrieval.py index 8e49c356c..c7882cb0b 100644 --- a/domainlab/utils/hyperparameter_retrieval.py +++ b/domainlab/utils/hyperparameter_retrieval.py @@ -10,11 +10,13 @@ def get_gamma_reg(args, model_name): print(gamma_reg) if isinstance(gamma_reg, dict): print("is instance dict") - if model_name in gamma_reg: - return gamma_reg[model_name] - elif 'default' in gamma_reg: + if model_name in gamma_reg: + return gamma_reg[model_name] + if 'default' in gamma_reg: return gamma_reg['default'] - else: - raise ValueError("If a gamma_reg dict is specified, but no value set for every model and trainer, a default value must be specified.") + else: + raise ValueError("""If a gamma_reg dict is specified, + but no value set for every model and trainer, + a default value must be specified.""") else: return gamma_reg # Return the single value if it's not a dictionary diff --git a/tests/test_hyperparameter_retrieval.py b/tests/test_hyperparameter_retrieval.py index 192343617..e166e320b 100644 --- a/tests/test_hyperparameter_retrieval.py +++ b/tests/test_hyperparameter_retrieval.py @@ -1,24 +1,29 @@ -# Assume imports and any necessary setup are already done -import argparse -from domainlab.arg_parser import StoreDictKeyPair, mk_parser_main +""" +unit test for hyperparameter parsing +""" +from domainlab.arg_parser import mk_parser_main from domainlab.utils.hyperparameter_retrieval import get_gamma_reg def test_store_dict_key_pair_single_value(): + """Test to parse a single gamma_reg parameter""" parser = mk_parser_main() args = parser.parse_args(['--gamma_reg', '0.5']) assert args.gamma_reg == 0.5 def test_store_dict_key_pair_dict_value(): + """Test to parse a dict for the gamma_reg""" parser = mk_parser_main() args = parser.parse_args(['--gamma_reg', 'dann=1.0,diva=2.0']) assert args.gamma_reg == {'dann': 1.0, 'diva': 2.0} def test_get_gamma_reg_single_value(): + """Test to retrieve a single gamma_reg parameter which is applied to all objects""" parser = mk_parser_main() args = parser.parse_args(['--gamma_reg', '0.5']) assert get_gamma_reg(args, 'dann') == 0.5 def test_get_gamma_reg_dict_value(): + """Test to retrieve a dict of gamma_reg parameters for different objects""" parser = mk_parser_main() args = parser.parse_args(['--gamma_reg', 'default=5.0,dann=1.0,diva=2.0']) print(args) From 394d23ab9cbfd4d4ab357d05722b4b4e8dc2296d Mon Sep 17 00:00:00 2001 From: matteowohlrapp Date: Tue, 23 Apr 2024 18:00:58 +0200 Subject: [PATCH 05/21] fixed codacy --- domainlab/models/a_model.py | 2 +- domainlab/utils/hyperparameter_retrieval.py | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/domainlab/models/a_model.py b/domainlab/models/a_model.py index 60571248d..1b7e652f4 100644 --- a/domainlab/models/a_model.py +++ b/domainlab/models/a_model.py @@ -202,7 +202,7 @@ def name(self): na_class, ) return type(self).__name__[len_prefix:].lower() - + def print_parameters(self): """ Function to print all parameters of the object. 
diff --git a/domainlab/utils/hyperparameter_retrieval.py b/domainlab/utils/hyperparameter_retrieval.py index c7882cb0b..d47135554 100644 --- a/domainlab/utils/hyperparameter_retrieval.py +++ b/domainlab/utils/hyperparameter_retrieval.py @@ -14,9 +14,7 @@ def get_gamma_reg(args, model_name): return gamma_reg[model_name] if 'default' in gamma_reg: return gamma_reg['default'] - else: - raise ValueError("""If a gamma_reg dict is specified, - but no value set for every model and trainer, + raise ValueError("""If a gamma_reg dict is specified, + but no value set for every model and trainer, a default value must be specified.""") - else: - return gamma_reg # Return the single value if it's not a dictionary + return gamma_reg # Return the single value if it's not a dictionary From cc650addd27be3922bfcee46fdc2d3fe1dc86700 Mon Sep 17 00:00:00 2001 From: matteowohlrapp Date: Tue, 30 Apr 2024 10:48:10 +0200 Subject: [PATCH 06/21] Removed diva from tests and yaml for gamma hyperparam --- domainlab/models/a_model.py | 2 +- domainlab/utils/hyperparameter_retrieval.py | 2 -- examples/conf/vlcs_diva_mldg_dial.yaml | 3 +-- tests/test_hyperparameter_retrieval.py | 8 ++++---- 4 files changed, 6 insertions(+), 9 deletions(-) diff --git a/domainlab/models/a_model.py b/domainlab/models/a_model.py index 1b7e652f4..71b0db334 100644 --- a/domainlab/models/a_model.py +++ b/domainlab/models/a_model.py @@ -196,7 +196,7 @@ def name(self): na_class = type(self).__name__ if na_class[:len_prefix] != na_prefix: raise RuntimeError( - "Trainer builder node class must start with ", + "Model builder node class must start with ", na_prefix, "the current class is named: ", na_class, diff --git a/domainlab/utils/hyperparameter_retrieval.py b/domainlab/utils/hyperparameter_retrieval.py index d47135554..9b3c97e7e 100644 --- a/domainlab/utils/hyperparameter_retrieval.py +++ b/domainlab/utils/hyperparameter_retrieval.py @@ -7,9 +7,7 @@ def get_gamma_reg(args, model_name): Retrieves either a shared gamma regularization, or individual ones for each specified object """ gamma_reg = args.gamma_reg - print(gamma_reg) if isinstance(gamma_reg, dict): - print("is instance dict") if model_name in gamma_reg: return gamma_reg[model_name] if 'default' in gamma_reg: diff --git a/examples/conf/vlcs_diva_mldg_dial.yaml b/examples/conf/vlcs_diva_mldg_dial.yaml index 5cf1a8041..71b376a03 100644 --- a/examples/conf/vlcs_diva_mldg_dial.yaml +++ b/examples/conf/vlcs_diva_mldg_dial.yaml @@ -7,8 +7,7 @@ epos: 1 # number of epochs trainer: mldg_dial # combine trainer MLDG and DIAL gamma_reg: default: 1.0 - dann: 1.5 - diva: 2.0 + dann: 1.5 gamma_y: 700000.0 # hyperparameter of diva gamma_d: 100000.0 # hyperparameter of diva npath: examples/nets/resnet.py # neural network for class classification diff --git a/tests/test_hyperparameter_retrieval.py b/tests/test_hyperparameter_retrieval.py index e166e320b..fa3119e98 100644 --- a/tests/test_hyperparameter_retrieval.py +++ b/tests/test_hyperparameter_retrieval.py @@ -13,8 +13,8 @@ def test_store_dict_key_pair_single_value(): def test_store_dict_key_pair_dict_value(): """Test to parse a dict for the gamma_reg""" parser = mk_parser_main() - args = parser.parse_args(['--gamma_reg', 'dann=1.0,diva=2.0']) - assert args.gamma_reg == {'dann': 1.0, 'diva': 2.0} + args = parser.parse_args(['--gamma_reg', 'dann=1.0,jigen=2.0']) + assert args.gamma_reg == {'dann': 1.0, 'jigen': 2.0} def test_get_gamma_reg_single_value(): """Test to retrieve a single gamma_reg parameter which is applied to all objects""" @@ -25,8 +25,8 @@ 
def test_get_gamma_reg_single_value(): def test_get_gamma_reg_dict_value(): """Test to retrieve a dict of gamma_reg parameters for different objects""" parser = mk_parser_main() - args = parser.parse_args(['--gamma_reg', 'default=5.0,dann=1.0,diva=2.0']) + args = parser.parse_args(['--gamma_reg', 'default=5.0,dann=1.0,jigen=2.0']) print(args) assert get_gamma_reg(args, 'dann') == 1.0 - assert get_gamma_reg(args, 'diva') == 2.0 + assert get_gamma_reg(args, 'jigen') == 2.0 assert get_gamma_reg(args, 'nonexistent') == 5.0 From febe8760366adcd991f66b469284502177e55cd6 Mon Sep 17 00:00:00 2001 From: matteowohlrapp Date: Tue, 7 May 2024 11:16:44 +0200 Subject: [PATCH 07/21] Increased test coverage --- domainlab/arg_parser.py | 29 ++++++------ tests/test_a_model.py | 64 ++++++++++++++++++++++++++ tests/test_a_trainer.py | 34 ++++++++++++++ tests/test_argparse.py | 31 ++++++++++++- tests/test_hyperparameter_retrieval.py | 10 +++- 5 files changed, 153 insertions(+), 15 deletions(-) create mode 100644 tests/test_a_model.py create mode 100644 tests/test_a_trainer.py diff --git a/domainlab/arg_parser.py b/domainlab/arg_parser.py index 50ab707db..32faa48be 100644 --- a/domainlab/arg_parser.py +++ b/domainlab/arg_parser.py @@ -12,21 +12,24 @@ from domainlab.models.args_vae import add_args2parser_vae from domainlab.utils.logger import Logger -class StoreDictKeyPair(argparse.Action): +class ParseValuesOrKeyValuePairs(argparse.Action): """Class used for arg parsing where values are provided in a key value format""" def __call__(self, parser, namespace, values, option_string=None): - try: - if "=" in values: - my_dict = {} - for kv in values.split(","): - k, v = kv.split("=") - my_dict[k.strip()] = float(v.strip()) # Assuming values are floats - setattr(namespace, self.dest, my_dict) - else: - setattr(namespace, self.dest, float(values)) # Single float value - except ValueError: - raise argparse.ArgumentError(self, f"Invalid value for {self.dest}: {values}") + if "=" in values: + my_dict = {} + for kv in values.split(","): + k, v = kv.split("=") + try: + my_dict[k.strip()] = float(v.strip()) + except ValueError: + raise ValueError(f"Invalid value in key-value pair: '{kv}', must be float") + setattr(namespace, self.dest, my_dict) + else: + try: + setattr(namespace, self.dest, float(values)) + except ValueError: + raise ValueError(f"Invalid value for {self.dest}: '{values}', must be float") def mk_parser_main(): """ @@ -49,7 +52,7 @@ def mk_parser_main(): "--gamma_reg", default=0.1, help="weight of regularization loss, can specify per model as 'dann=1.0,diva=2.0'", - action=StoreDictKeyPair + action=ParseValuesOrKeyValuePairs ) parser.add_argument("--es", type=int, default=1, help="early stop steps") diff --git a/tests/test_a_model.py b/tests/test_a_model.py new file mode 100644 index 000000000..c770f271d --- /dev/null +++ b/tests/test_a_model.py @@ -0,0 +1,64 @@ +""" +Test a model functionality +""" + +import pytest +from io import StringIO +import sys +from torch import nn +from domainlab.models.a_model import AModel + +class ModelTest(AModel): + """ + A test model class conforming to model naming + """ + + def __init__(self): + super().__init__() + self.test_param = 42 + + def cal_task_loss(self, tensor_x, tensor_y): + return 0 + + def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): + return 0 + + +class InvalidTest(AModel): + """ + A test model class that does not conform to the "Model" prefix naming convention + """ + + def cal_task_loss(self, tensor_x, tensor_y): + return 0 + + def 
_cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): + return 0 + + +def test_model_name_valid(): + """ + Test a valid model name + """ + model = ModelTest() + assert model.name == "test", f"Expected 'test' but got '{model.name}'" + + +def test_model_name_invalid(): + """ + Test an invalid model name + """ + model = InvalidTest() + with pytest.raises(RuntimeError, match="Model builder node class must start with"): + model.name + + +def test_print_parameters(capsys): + """ + Test the printing of parameters + """ + model = ModelTest() + model.print_parameters() + captured = capsys.readouterr() + assert "Parameters of ModelTest:" in captured.out + assert "'test_param': 42" in captured.out \ No newline at end of file diff --git a/tests/test_a_trainer.py b/tests/test_a_trainer.py new file mode 100644 index 000000000..4543f0d20 --- /dev/null +++ b/tests/test_a_trainer.py @@ -0,0 +1,34 @@ +""" +Test a trainer functionality +""" + +from domainlab.algos.trainers.a_trainer import AbstractTrainer + +class TrainerTest(AbstractTrainer): + """ + A test trainer class conforming to model naming + """ + + def __init__(self): + super().__init__() + self.test_param = 42 + + def tr_epoch(self, epoch): + """ + :param epoch: + """ + + def before_tr(self): + """ + before training, probe model performance + """ + +def test_print_parameters(capsys): + """ + Test the printing of parameters + """ + trainer = TrainerTest() + trainer.print_parameters() + captured = capsys.readouterr() + assert "Parameters of TrainerTest:" in captured.out + assert "'test_param': 42" in captured.out \ No newline at end of file diff --git a/tests/test_argparse.py b/tests/test_argparse.py index 683f66911..199f6d56a 100644 --- a/tests/test_argparse.py +++ b/tests/test_argparse.py @@ -7,7 +7,8 @@ import pytest -from domainlab.arg_parser import apply_dict_to_args, mk_parser_main, parse_cmd_args +from domainlab.arg_parser import apply_dict_to_args, mk_parser_main, parse_cmd_args, ParseValuesOrKeyValuePairs +import argparse def test_parse_cmd_args_warning(): @@ -54,3 +55,31 @@ def test_apply_dict_to_args(): apply_dict_to_args(args, data, extend=True) assert args.a == 1 assert args.model == "diva" + +def test_store_dict_key_value_valid(): + """Testing to parse valid gamma_reg value""" + parser = mk_parser_main() + parser.add_argument("--keypair", action=ParseValuesOrKeyValuePairs) + namespace = parser.parse_args(["--keypair", "1"]) + assert namespace.keypair == 1.0 + +def test_store_dict_key_value_pair_valid(): + """Testing to parse valid gamma_reg key value paris""" + parser = mk_parser_main() + parser.add_argument("--keypair", action=ParseValuesOrKeyValuePairs) + namespace = parser.parse_args(["--keypair", "value1=1,value2=2"]) + assert namespace.keypair == {"value1": 1.0, "value2": 2.0} + +def test_store_dict_key_value_invalid(): + """Testing to parse invalid gamma_reg value""" + parser = mk_parser_main() + parser.add_argument("--keypair", action=ParseValuesOrKeyValuePairs) + with pytest.raises(ValueError): + parser.parse_args(["--keypair", "invalid"]) + +def test_store_dict_key_value_pair_invalid(): + """Testing to parse invalid gamma_reg key value pairs""" + parser = mk_parser_main() + parser.add_argument("--keypair", action=ParseValuesOrKeyValuePairs) + with pytest.raises(ValueError): + parser.parse_args(["--keypair", "value1=1,value2=invalid"]) diff --git a/tests/test_hyperparameter_retrieval.py b/tests/test_hyperparameter_retrieval.py index fa3119e98..ec2f4d2b4 100644 --- a/tests/test_hyperparameter_retrieval.py +++ 
b/tests/test_hyperparameter_retrieval.py @@ -3,6 +3,7 @@ """ from domainlab.arg_parser import mk_parser_main from domainlab.utils.hyperparameter_retrieval import get_gamma_reg +import pytest def test_store_dict_key_pair_single_value(): """Test to parse a single gamma_reg parameter""" @@ -26,7 +27,14 @@ def test_get_gamma_reg_dict_value(): """Test to retrieve a dict of gamma_reg parameters for different objects""" parser = mk_parser_main() args = parser.parse_args(['--gamma_reg', 'default=5.0,dann=1.0,jigen=2.0']) - print(args) assert get_gamma_reg(args, 'dann') == 1.0 assert get_gamma_reg(args, 'jigen') == 2.0 assert get_gamma_reg(args, 'nonexistent') == 5.0 + +def test_exception(): + """Test to not specify a default value""" + parser = mk_parser_main() + args = parser.parse_args(['--gamma_reg', 'dann=1.0']) + + with pytest.raises(ValueError, match="If a gamma_reg dict is specified"): + get_gamma_reg(args, 'jigen') From fd360565d9aab1ca3cb60b90feac0493b6ced126 Mon Sep 17 00:00:00 2001 From: matteowohlrapp Date: Tue, 7 May 2024 11:28:09 +0200 Subject: [PATCH 08/21] fixed codacity --- tests/test_a_model.py | 18 +++++++++++------- tests/test_a_trainer.py | 2 +- tests/test_argparse.py | 8 ++------ tests/test_hyperparameter_retrieval.py | 2 +- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/test_a_model.py b/tests/test_a_model.py index c770f271d..8480cfd55 100644 --- a/tests/test_a_model.py +++ b/tests/test_a_model.py @@ -1,11 +1,7 @@ """ Test a model functionality """ - import pytest -from io import StringIO -import sys -from torch import nn from domainlab.models.a_model import AModel class ModelTest(AModel): @@ -22,6 +18,10 @@ def cal_task_loss(self, tensor_x, tensor_y): def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): return 0 + + @property + def metric4msel(self): + return "" class InvalidTest(AModel): @@ -35,6 +35,10 @@ def cal_task_loss(self, tensor_x, tensor_y): def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): return 0 + @property + def metric4msel(self): + return "" + def test_model_name_valid(): """ @@ -49,8 +53,7 @@ def test_model_name_invalid(): Test an invalid model name """ model = InvalidTest() - with pytest.raises(RuntimeError, match="Model builder node class must start with"): - model.name + with pytest.raises(RuntimeError, match="Model builder node class must start with"): model.name def test_print_parameters(capsys): @@ -61,4 +64,5 @@ def test_print_parameters(capsys): model.print_parameters() captured = capsys.readouterr() assert "Parameters of ModelTest:" in captured.out - assert "'test_param': 42" in captured.out \ No newline at end of file + assert "'test_param': 42" in captured.out + \ No newline at end of file diff --git a/tests/test_a_trainer.py b/tests/test_a_trainer.py index 4543f0d20..d4e073ae3 100644 --- a/tests/test_a_trainer.py +++ b/tests/test_a_trainer.py @@ -31,4 +31,4 @@ def test_print_parameters(capsys): trainer.print_parameters() captured = capsys.readouterr() assert "Parameters of TrainerTest:" in captured.out - assert "'test_param': 42" in captured.out \ No newline at end of file + assert "'test_param': 42" in captured.out diff --git a/tests/test_argparse.py b/tests/test_argparse.py index 199f6d56a..2eaaae10b 100644 --- a/tests/test_argparse.py +++ b/tests/test_argparse.py @@ -1,15 +1,11 @@ """ Test argparser functionality """ - import os import sys - import pytest - -from domainlab.arg_parser import apply_dict_to_args, mk_parser_main, parse_cmd_args, ParseValuesOrKeyValuePairs -import argparse - 
+from domainlab.arg_parser import apply_dict_to_args, mk_parser_main, parse_cmd_args +from domainlab.arg_parser import ParseValuesOrKeyValuePairs def test_parse_cmd_args_warning(): """Call argparser for command line""" diff --git a/tests/test_hyperparameter_retrieval.py b/tests/test_hyperparameter_retrieval.py index ec2f4d2b4..cfc7b28f9 100644 --- a/tests/test_hyperparameter_retrieval.py +++ b/tests/test_hyperparameter_retrieval.py @@ -1,9 +1,9 @@ """ unit test for hyperparameter parsing """ +import pytest from domainlab.arg_parser import mk_parser_main from domainlab.utils.hyperparameter_retrieval import get_gamma_reg -import pytest def test_store_dict_key_pair_single_value(): """Test to parse a single gamma_reg parameter""" From 6520834b1af455710b8b656eb8e7195ada50b84e Mon Sep 17 00:00:00 2001 From: matteowohlrapp Date: Tue, 7 May 2024 11:33:18 +0200 Subject: [PATCH 09/21] fixed codacity --- tests/test_a_model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_a_model.py b/tests/test_a_model.py index 8480cfd55..5b2d7bd0a 100644 --- a/tests/test_a_model.py +++ b/tests/test_a_model.py @@ -18,7 +18,7 @@ def cal_task_loss(self, tensor_x, tensor_y): def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): return 0 - + @property def metric4msel(self): return "" @@ -53,7 +53,8 @@ def test_model_name_invalid(): Test an invalid model name """ model = InvalidTest() - with pytest.raises(RuntimeError, match="Model builder node class must start with"): model.name + with pytest.raises(RuntimeError, match="Model builder node class must start with"): + _ = model.name def test_print_parameters(capsys): From c5e3e6af6bb3480497a9f74dfa9dde8c4b23c3f7 Mon Sep 17 00:00:00 2001 From: matteowohlrapp Date: Tue, 7 May 2024 11:36:31 +0200 Subject: [PATCH 10/21] fixed codacity --- tests/test_a_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_a_model.py b/tests/test_a_model.py index 5b2d7bd0a..fac060f06 100644 --- a/tests/test_a_model.py +++ b/tests/test_a_model.py @@ -53,7 +53,7 @@ def test_model_name_invalid(): Test an invalid model name """ model = InvalidTest() - with pytest.raises(RuntimeError, match="Model builder node class must start with"): + with pytest.raises(RuntimeError, match="Model builder node class must start with"): _ = model.name From 884682b54c1cf9c597498bd93559bf7d09ac2280 Mon Sep 17 00:00:00 2001 From: matteowohlrapp Date: Tue, 7 May 2024 14:12:20 +0200 Subject: [PATCH 11/21] Added quotes around variables --- run_benchmark_slurm.sh | 2 +- run_benchmark_standalone.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/run_benchmark_slurm.sh b/run_benchmark_slurm.sh index 61615198a..8db5b6965 100755 --- a/run_benchmark_slurm.sh +++ b/run_benchmark_slurm.sh @@ -32,4 +32,4 @@ echo "Configuration file: $CONFIGFILE" echo "starting seed is: $DOMAINLAB_CUDA_START_SEED" echo "verbose log: $logfile" # Helmholtz -snakemake --profile "examples/yaml/slurm" --config yaml_file=$CONFIGFILE --keep-going --keep-incomplete --notemp --cores 3 -s "domainlab/exp_protocol/benchmark.smk" --configfile "$CONFIGFILE" 2>&1 | tee "$logfile" +snakemake --profile "examples/yaml/slurm" --config yaml_file="$CONFIGFILE" --keep-going --keep-incomplete --notemp --cores 3 -s "domainlab/exp_protocol/benchmark.smk" --configfile "$CONFIGFILE" 2>&1 | tee "$logfile" diff --git a/run_benchmark_standalone.sh b/run_benchmark_standalone.sh index 8f712d1d1..386b7dab5 100755 --- a/run_benchmark_standalone.sh +++ b/run_benchmark_standalone.sh @@ 
-61,7 +61,7 @@ fi snakemake --rerun-incomplete --cores 1 -s "domainlab/exp_protocol/benchmark.smk" --configfile "$CONFIGFILE" --keep-going --summary # this will give us a clue first what jobs will be run # second submit the jobs, make sure you have more than 4 cores on your laptop, otherwise adjust the cores -snakemake --config yaml_file=$CONFIGFILE --rerun-incomplete --resources nvidia_gpu=$NUMBER_GPUS --cores 4 -s "domainlab/exp_protocol/benchmark.smk" --configfile "$CONFIGFILE" 2>&1 | tee "$logfile" +snakemake --config yaml_file="$CONFIGFILE" --rerun-incomplete --resources nvidia_gpu="$NUMBER_GPUS" --cores 4 -s "domainlab/exp_protocol/benchmark.smk" --configfile "$CONFIGFILE" 2>&1 | tee "$logfile" # snakemake --rerun-incomplete --cores 1 -s "domainlab/exp_protocol/benchmark.smk" --configfile "examples/yaml/demo_benchmark.yaml" From 5d2fb8222cbe9e0e5081f8389acc07c2c0f4778a Mon Sep 17 00:00:00 2001 From: Matteo Wohlrapp Date: Tue, 11 Jun 2024 11:25:20 +0200 Subject: [PATCH 12/21] Fixed codacity --- docs/doc_usage_cmd.md | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/docs/doc_usage_cmd.md b/docs/doc_usage_cmd.md index dfb92b3ca..3c6e53176 100644 --- a/docs/doc_usage_cmd.md +++ b/docs/doc_usage_cmd.md @@ -21,7 +21,9 @@ To run DomainLab, the minimum necessary parameters are: ### Advanced Configuration - **Learning Rate (`--lr`):** Set the training learning rate. -- **Regularization (`--gamma_reg`):** Sets the weight of the regularization loss. This parameter can be configured either as a single value applied to individual classes, or using a dictionary to specify different weights for different models and trainers. +- **Regularization (`--gamma_reg`):** Sets the weight of the regularization loss. +This parameter can be configured either as a single value applied to individual classes, +or using a dictionary to specify different weights for different models and trainers. - **Command Line Usage:** - For a single value: `python script.py --gamma_reg=0.1` @@ -32,13 +34,16 @@ To run DomainLab, the minimum necessary parameters are: ```yaml gamma_reg: 0.1 ``` + - For different values: ```yaml gamma_reg: - default: 0.1 # every other instance that is not listed below will get this value assigned dann: 0.05 diva: 0.2 - ```- **Early Stopping (`--es`):** Steps for early stopping. + default: 0.1 # value for every other instance + ``` + +- **Early Stopping (`--es`):** Steps for early stopping. - **Random Seed (`--seed`):** Seed for reproducibility. - **CUDA Options (`--nocu`, `--device`):** Configure CUDA usage and device settings. - **Generated Images (`--gen`):** Option to save generated images. From 570e8a86826ec31353fe883d6a670e1f2d51583a Mon Sep 17 00:00:00 2001 From: Matteo Wohlrapp Date: Tue, 11 Jun 2024 13:03:37 +0200 Subject: [PATCH 13/21] fixed codacity --- docs/doc_usage_cmd.md | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/docs/doc_usage_cmd.md b/docs/doc_usage_cmd.md index 3c6e53176..ee814fdd3 100644 --- a/docs/doc_usage_cmd.md +++ b/docs/doc_usage_cmd.md @@ -21,9 +21,11 @@ To run DomainLab, the minimum necessary parameters are: ### Advanced Configuration - **Learning Rate (`--lr`):** Set the training learning rate. -- **Regularization (`--gamma_reg`):** Sets the weight of the regularization loss. -This parameter can be configured either as a single value applied to individual classes, -or using a dictionary to specify different weights for different models and trainers. 
+- **Regularization (`--gamma_reg`):** Sets the weight of the regularization + loss. This parameter can be configured either as + a single value applied to individual classes, + or using a dictionary to specify different + weights for different models and trainers. - **Command Line Usage:** - For a single value: `python script.py --gamma_reg=0.1` @@ -31,11 +33,13 @@ or using a dictionary to specify different weights for different models and trai - **YAML Configuration:** - For a single value: + ```yaml gamma_reg: 0.1 ``` - For different values: + ```yaml gamma_reg: dann: 0.05 From 1198eeb452cfbec047161e3d9a8186fb638a6018 Mon Sep 17 00:00:00 2001 From: Matteo Wohlrapp Date: Tue, 2 Jul 2024 12:33:36 +0200 Subject: [PATCH 14/21] Corrected doc for gamma_reg, added docstring to call method in argparser and specified input types --- docs/doc_usage_cmd.md | 4 +++- domainlab/arg_parser.py | 16 +++++++++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/docs/doc_usage_cmd.md b/docs/doc_usage_cmd.md index ee814fdd3..28325b5af 100644 --- a/docs/doc_usage_cmd.md +++ b/docs/doc_usage_cmd.md @@ -43,9 +43,11 @@ To run DomainLab, the minimum necessary parameters are: ```yaml gamma_reg: dann: 0.05 - diva: 0.2 + dial: 0.2 default: 0.1 # value for every other instance ``` +Gamma reg is available for the trainers, as well as the +dann and jigen model. - **Early Stopping (`--es`):** Steps for early stopping. - **Random Seed (`--seed`):** Seed for reproducibility. diff --git a/domainlab/arg_parser.py b/domainlab/arg_parser.py index 32faa48be..1b386c698 100644 --- a/domainlab/arg_parser.py +++ b/domainlab/arg_parser.py @@ -15,8 +15,22 @@ class ParseValuesOrKeyValuePairs(argparse.Action): """Class used for arg parsing where values are provided in a key value format""" - def __call__(self, parser, namespace, values, option_string=None): + def __call__(self, parser: argparse.ArgumentParser, + namespace: argparse.Namespace, values: str, option_string: str = None): + """ + Handle parsing of key value pairs, or a single value instead + + Args: + parser (argparse.ArgumentParser): The ArgumentParser object. + namespace (argparse.Namespace): The namespace object to store parsed values. + values (str): The string containing key=value pairs or a single float value. + option_string (str, optional): The option string that triggered this action (unused). + + Raises: + ValueError: If the values cannot be parsed to float. 
+ """ if "=" in values: + my_dict = {} if "=" in values: my_dict = {} for kv in values.split(","): k, v = kv.split("=") From 97d93209dd7f36124778f324cd4f8a29ea9b714f Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Tue, 2 Jul 2024 13:48:03 +0200 Subject: [PATCH 15/21] Update arg_parser.py, update documentation , remove diva gamma_reg --- domainlab/arg_parser.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/domainlab/arg_parser.py b/domainlab/arg_parser.py index 32faa48be..ce981e9f0 100644 --- a/domainlab/arg_parser.py +++ b/domainlab/arg_parser.py @@ -51,7 +51,10 @@ def mk_parser_main(): parser.add_argument( "--gamma_reg", default=0.1, - help="weight of regularization loss, can specify per model as 'dann=1.0,diva=2.0'", + help="weight of regularization loss in the form of $$\ell(\cdot) + \mu \times R(\cdot)$$ \ + can specify per model as 'default=3.0, dann=1.0,jigen=2.0', where default refer to gamma for trainer \ + note diva is implemented $$\ell(\cdot) + \mu \times R(\cdot)$$ \ + so diva does not have gamma_reg", action=ParseValuesOrKeyValuePairs ) From 05d228b6bc200856e9d3c312a12f840336a57595 Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Tue, 2 Jul 2024 13:55:22 +0200 Subject: [PATCH 16/21] Update doc_usage_cmd.md, remove diva gamma_reg in doc --- docs/doc_usage_cmd.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/doc_usage_cmd.md b/docs/doc_usage_cmd.md index ee814fdd3..eba5720a1 100644 --- a/docs/doc_usage_cmd.md +++ b/docs/doc_usage_cmd.md @@ -29,7 +29,7 @@ To run DomainLab, the minimum necessary parameters are: - **Command Line Usage:** - For a single value: `python script.py --gamma_reg=0.1` - - For multiple values: `python script.py --gamma_reg='default=0.1,dann=0.05,diva=0.2'` + - For multiple values: `python script.py --gamma_reg='default=0.1,dann=0.05,jigen=0.2'` - **YAML Configuration:** - For a single value: From 5c0f6b1ff091709d5cd1b2c43ca7d7e2e16ecd02 Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Tue, 2 Jul 2024 13:56:06 +0200 Subject: [PATCH 17/21] Update doc_usage_cmd.md --- docs/doc_usage_cmd.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/doc_usage_cmd.md b/docs/doc_usage_cmd.md index eba5720a1..e8b4f121b 100644 --- a/docs/doc_usage_cmd.md +++ b/docs/doc_usage_cmd.md @@ -43,7 +43,7 @@ To run DomainLab, the minimum necessary parameters are: ```yaml gamma_reg: dann: 0.05 - diva: 0.2 + jigen: 0.2 default: 0.1 # value for every other instance ``` From 31ce1ecb4f193a11052aa6798b57361d3c56046e Mon Sep 17 00:00:00 2001 From: MatteoWohlrapp Date: Tue, 2 Jul 2024 15:09:11 +0200 Subject: [PATCH 18/21] fixed argparser syntax, renamed argument --- domainlab/arg_parser.py | 1 - domainlab/utils/hyperparameter_retrieval.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/domainlab/arg_parser.py b/domainlab/arg_parser.py index 3abd471bb..1b8593e8f 100644 --- a/domainlab/arg_parser.py +++ b/domainlab/arg_parser.py @@ -30,7 +30,6 @@ def __call__(self, parser: argparse.ArgumentParser, ValueError: If the values cannot be parsed to float. 
""" if "=" in values: - my_dict = {} if "=" in values: my_dict = {} for kv in values.split(","): k, v = kv.split("=") diff --git a/domainlab/utils/hyperparameter_retrieval.py b/domainlab/utils/hyperparameter_retrieval.py index 9b3c97e7e..5a1960e2e 100644 --- a/domainlab/utils/hyperparameter_retrieval.py +++ b/domainlab/utils/hyperparameter_retrieval.py @@ -2,14 +2,14 @@ retrieval for hyperparameters """ -def get_gamma_reg(args, model_name): +def get_gamma_reg(args, component_name): """ Retrieves either a shared gamma regularization, or individual ones for each specified object """ gamma_reg = args.gamma_reg if isinstance(gamma_reg, dict): - if model_name in gamma_reg: - return gamma_reg[model_name] + if component_name in gamma_reg: + return gamma_reg[component_name] if 'default' in gamma_reg: return gamma_reg['default'] raise ValueError("""If a gamma_reg dict is specified, From 6392ac5c27b9e1df2fcb68da6c1fca55dd118e65 Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Tue, 2 Jul 2024 15:36:39 +0200 Subject: [PATCH 19/21] Update vlcs_diva_mldg_dial.yaml --- examples/conf/vlcs_diva_mldg_dial.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/conf/vlcs_diva_mldg_dial.yaml b/examples/conf/vlcs_diva_mldg_dial.yaml index 39286b66b..e24cfd22e 100644 --- a/examples/conf/vlcs_diva_mldg_dial.yaml +++ b/examples/conf/vlcs_diva_mldg_dial.yaml @@ -7,7 +7,8 @@ epos: 1 # number of epochs trainer: mldg_dial # combine trainer MLDG and DIAL gamma_reg: default: 1.0 - dann: 1.5 + dann: 1.5 +# in this case, mldg and dial get the default gamma_reg value 1.0 gamma_y: 700000.0 # hyperparameter of diva gamma_d: 100000.0 # hyperparameter of diva npath: examples/nets/resnet.py # neural network for class classification From 0d629194fa559e63abbc6c1eadbd0685d87f1c2a Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Tue, 2 Jul 2024 15:45:05 +0200 Subject: [PATCH 20/21] Update test_hyperparameter_retrieval.py, comments in unit test --- tests/test_hyperparameter_retrieval.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_hyperparameter_retrieval.py b/tests/test_hyperparameter_retrieval.py index cfc7b28f9..2295f0489 100644 --- a/tests/test_hyperparameter_retrieval.py +++ b/tests/test_hyperparameter_retrieval.py @@ -29,7 +29,9 @@ def test_get_gamma_reg_dict_value(): args = parser.parse_args(['--gamma_reg', 'default=5.0,dann=1.0,jigen=2.0']) assert get_gamma_reg(args, 'dann') == 1.0 assert get_gamma_reg(args, 'jigen') == 2.0 - assert get_gamma_reg(args, 'nonexistent') == 5.0 + assert get_gamma_reg(args, 'nonexistent') == 5.0 # if we implement other model/trainers, + # since not specified in command line arguments, the new model/trainer called "nonexistent" should + # get the default value 5.0. 
def test_exception(): """Test to not specify a default value""" From bdf84c281b58946be0a3eff5dde54230c486bb06 Mon Sep 17 00:00:00 2001 From: smilesun Date: Tue, 2 Jul 2024 16:38:46 +0200 Subject: [PATCH 21/21] style --- tests/test_hyperparameter_retrieval.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/test_hyperparameter_retrieval.py b/tests/test_hyperparameter_retrieval.py index 2295f0489..a9a141329 100644 --- a/tests/test_hyperparameter_retrieval.py +++ b/tests/test_hyperparameter_retrieval.py @@ -29,9 +29,11 @@ def test_get_gamma_reg_dict_value(): args = parser.parse_args(['--gamma_reg', 'default=5.0,dann=1.0,jigen=2.0']) assert get_gamma_reg(args, 'dann') == 1.0 assert get_gamma_reg(args, 'jigen') == 2.0 - assert get_gamma_reg(args, 'nonexistent') == 5.0 # if we implement other model/trainers, - # since not specified in command line arguments, the new model/trainer called "nonexistent" should - # get the default value 5.0. + assert get_gamma_reg(args, 'nonexistent') == 5.0 # if we implement other + # model/trainers, + # since not specified in command line arguments, the new model/trainer + # called "nonexistent" should + # get the default value 5.0. def test_exception(): """Test to not specify a default value"""
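
[Illustration — not part of the series: a short sketch of how the pieces fit together once all 21 patches are applied. Component names come from the class name (the `name` property strips the "Model" prefix and lowercases the rest, so `ModelDann` stores scheduler entries such as "dann_alpha"), and `get_gamma_reg` resolves per-component regularization weights with a "default" fallback. Inside the trainers the call pattern is `get_gamma_reg(self.aconf, self.name)`, as in `train_dial.py`, `train_fishr.py` and `train_mldg.py`; the component names below are spelled out for readability.]

```python
from domainlab.arg_parser import mk_parser_main
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg

parser = mk_parser_main()
args = parser.parse_args(["--gamma_reg", "default=1.0,dann=1.5"])

assert get_gamma_reg(args, "dann") == 1.5  # explicit entry wins
assert get_gamma_reg(args, "mldg") == 1.0  # falls back to "default"
assert get_gamma_reg(args, "dial") == 1.0  # falls back to "default"

# a dict without a "default" key raises for unlisted components,
# as exercised by test_exception() above
args = parser.parse_args(["--gamma_reg", "dann=1.5"])
try:
    get_gamma_reg(args, "mldg")
except ValueError:
    print("a default value must be specified when a dict is used")
```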