diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 86ca0b445..b6d9e7cde 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,11 +2,10 @@ name: CI on: push: - branches: master + branches: mhof_dev pull_request: - branches: master + branches: mhof_dev workflow_dispatch: - jobs: test: name: Run tests diff --git a/.gitignore b/.gitignore index a351b31fe..b001bad9e 100644 --- a/.gitignore +++ b/.gitignore @@ -6,8 +6,6 @@ tests/__pycache__/ .vscode/ domainlab/zdata/pacs /data/ -/.snakemake/ /dist /domainlab.egg-info /runs -/slurm_errors.txt \ No newline at end of file diff --git a/README.md b/README.md index 6cecda105..e16d02681 100644 --- a/README.md +++ b/README.md @@ -124,6 +124,7 @@ For example, the following result (without any augmentation like flip) is for P Source: https://arxiv.org/pdf/2403.14356.pdf +Citation: ```bibtex @misc{sun2024domainlab, title={DomainLab: A modular Python package for domain generalization in deep learning}, @@ -132,3 +133,71 @@ Source: https://arxiv.org/pdf/2403.14356.pdf year={2024} } ``` + +# M-HOF-Opt: Multi-Objective Hierarchical Output Feedback Optimization via Multiplier Induced Loss Landscape Scheduling +Source: https://arxiv.org/pdf/2403.13728.pdf + +M-HOF-Opt is implemented in [DomainLab](https://github.com/marrlab/DomainLab). If you meet any problems, feel free to report them at https://github.com/marrlab/DomainLab/issues + +## Dependencies and Data Preparation +#### Example dependencies installation +``` +git checkout mhof # switch to mhof branch +conda create --name domainlab_py39 python=3.9 # create a virtual environment +conda activate domainlab_py39 # activate virtual environment +conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.6 -c pytorch -c conda-forge +conda install torchmetrics==0.10.3 +pip install -r requirements_notorch.txt +conda install tensorboard # install tensorboard +``` + +#### Data preparation: download the domain generalization dataset PACS + +step 1: + +use the following script to download PACS to your local laptop and upload it to your cluster + +https://github.com/marrlab/DomainLab/blob/fbopt/data/script/download_pacs.py + +step 2: +make a symbolic link following the example script in https://github.com/marrlab/DomainLab/blob/master/sh_pacs.sh + +where `mkdir -p data/pacs` is executed under the repository directory, + +`ln -s /dir/to/yourdata/pacs/raw ./data/pacs/PACS` +will create a symbolic link under the repository directory + +### M-HOF experiments reproduction + +#### Run the experiment + +To execute a single run of the M-HOF method, from the root folder run the command: + +``` +python main_out.py -c a_reproduce_pacs_diva.yaml +``` + +which uses the configuration file [a_reproduce_pacs_diva.yaml](https://github.com/marrlab/DomainLab/blob/mhof/a_reproduce_pacs_diva.yaml). + +#### Visualization of the results + +The results of the experiment are stored in the `runs` directory generated by Tensorboard. 
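If you want to browse the raw event files interactively, TensorBoard (installed as part of the environment set up above) can simply be pointed at this directory, e.g.:

```
tensorboard --logdir runs
```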
+The various loss curves with the corresponding setpoint change curves, as well as phase-portrait-like figures showing the loss dynamics between the task loss and the various regularization losses, can be obtained by running the script [script_generate_all_figures_diva.sh](https://github.com/marrlab/DomainLab/blob/mhof/script_generate_all_figures_diva.sh): + +``` +bash script_generate_all_figures_diva.sh +``` + +The resulting figures will be stored in the directory `figures_diva`, which can be changed by editing the top of the [script_generate_all_figures_diva.sh](https://github.com/marrlab/DomainLab/blob/mhof/script_generate_all_figures_diva.sh) file if needed. + +Citation: +```bibtex +@misc{sun2024m, + title={M-HOF-Opt: Multi-Objective Hierarchical Output Feedback Optimization via Multiplier Induced Loss Landscape Scheduling}, + author={Sun, Xudong and Chen, Nutan and Gossmann, Alexej and Xing, Yu and Dorigatt, Emilio and Drost, Felix and Feistner, Carla and Scarcella, Daniele and Beer, Lisa and Marr, Carsten}, + journal={https://arxiv.org/pdf/2403.13728.pdf}, + number={2403.13728}, + year={2024}, + publisher={https://arxiv.org/pdf/2403.13728.pdf} +} +``` \ No newline at end of file diff --git a/a_reproduce_pacs_diva.yaml b/a_reproduce_pacs_diva.yaml new file mode 100644 index 000000000..db3c234eb --- /dev/null +++ b/a_reproduce_pacs_diva.yaml @@ -0,0 +1,24 @@ +te_d: sketch +tpath: examples/tasks/task_pacs_aug.py +bs: 32 +model: diva +trainer: fbopt +gamma_y: 1.0 +ini_setpoint_ratio: 0.99 +str_diva_multiplier_type: gammad_recon +coeff_ma_output_state: 0.1 +coeff_ma_setpoint: 0.9 +exp_shoulder_clip: 5 +mu_init: 0.000001 +k_i_gain_ratio: 0.5 +mu_clip: 10 +epos: 1000 +epos_min: 200 +npath: examples/nets/resnet50domainbed.py +npath_dom: examples/nets/resnet50domainbed.py +es: 2 +lr: 0.00005 +zx_dim: 0 +zy_dim: 64 +zd_dim: 64 +force_setpoint_change_once: True diff --git a/a_test_feedforward_irm.sh b/a_test_feedforward_irm.sh new file mode 100644 index 000000000..007f3e9ac --- /dev/null +++ b/a_test_feedforward_irm.sh @@ -0,0 +1 @@ +python main_out.py --te_d=0 --tr_d 1 2 --task=mnistcolor10 --model=erm --nname=conv_bn_pool_2 --trainer=hyperscheduler_irm_dial --k_i_gain_ratio=0.5 --force_setpoint_change_once --epos=10 --epos_min=4 --exp_shoulder_clip=1 --mu_clip=100 --ini_setpoint_ratio=0.99999999 diff --git a/a_test_mhof_irm.sh b/a_test_mhof_irm.sh new file mode 100644 index 000000000..3cf1ae7ed --- /dev/null +++ b/a_test_mhof_irm.sh @@ -0,0 +1 @@ +python main_out.py --te_d=0 --tr_d 1 2 --task=mnistcolor10 --model=erm --nname=conv_bn_pool_2 --trainer=fbopt_irm_dial --k_i_gain_ratio=0.5 --force_setpoint_change_once --epos=500 --epos_min=400 --exp_shoulder_clip=1 --mu_clip=100 --ini_setpoint_ratio=0.9 --nb4reg_over_task_ratio=0 --tr_with_init_mu --coeff_ma_setpoint=0.0 --str_setpoint_ada="SliderAnyComponent()" diff --git a/domainlab/algos/builder_diva.py b/domainlab/algos/builder_diva.py index 5e33826ec..968840d2c 100644 --- a/domainlab/algos/builder_diva.py +++ b/domainlab/algos/builder_diva.py @@ -3,7 +3,9 @@ """ from domainlab.algos.a_algo_builder import NodeAlgoBuilder from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor +from domainlab.algos.msels.c_msel_setpoint_delay import MSelSetpointDelay from domainlab.algos.msels.c_msel_val import MSelValPerf +from domainlab.algos.msels.c_msel_val_top_k import MSelValPerfTopK from domainlab.algos.observers.b_obvisitor import ObVisitor from domainlab.algos.observers.c_obvisitor_cleanup import ObVisitorCleanUp from 
domainlab.algos.observers.c_obvisitor_gen import ObVisitorGen @@ -35,7 +37,8 @@ def init_business(self, exp): request = RequestVAEBuilderCHW(task.isize.c, task.isize.h, task.isize.w, args) node = VAEChainNodeGetter(request)() task.get_list_domains_tr_te(args.tr_d, args.te_d) - model = mk_diva(list_str_y=task.list_str_y)( + model = mk_diva(str_diva_multiplier_type=args.str_diva_multiplier_type, list_str_y=task.list_str_y)( + node, zd_dim=args.zd_dim, zy_dim=args.zy_dim, @@ -48,7 +51,9 @@ def init_business(self, exp): beta_d=args.beta_d, ) device = get_device(args) - model_sel = MSelOracleVisitor(MSelValPerf(max_es=args.es), val_threshold=args.val_threshold) + model_sel = MSelSetpointDelay( + MSelOracleVisitor(MSelValPerfTopK(max_es=args.es)), val_threshold=args.val_threshold + ) if not args.gen: observer = ObVisitor(model_sel) else: diff --git a/domainlab/algos/builder_fbopt_dial.py b/domainlab/algos/builder_fbopt_dial.py new file mode 100644 index 000000000..f1faad96b --- /dev/null +++ b/domainlab/algos/builder_fbopt_dial.py @@ -0,0 +1,21 @@ +""" +builder for feedback optimization of dial +""" +from domainlab.algos.builder_diva import NodeAlgoBuilderDIVA +from domainlab.algos.trainers.train_fbopt_b import TrainerFbOpt + + +class NodeAlgoBuilderFbOptDial(NodeAlgoBuilderDIVA): + """ + builder for feedback optimization for dial + """ + + def init_business(self, exp): + """ + return trainer, model, observer + """ + trainer_in, model, observer, device = super().init_business(exp) + trainer_in.init_business(model, exp.task, observer, device, exp.args) + trainer = TrainerFbOpt() + trainer.init_business(trainer_in, exp.task, observer, device, exp.args) + return trainer, model, observer, device diff --git a/domainlab/algos/builder_jigen1.py b/domainlab/algos/builder_jigen1.py index de671affe..fb1e78214 100644 --- a/domainlab/algos/builder_jigen1.py +++ b/domainlab/algos/builder_jigen1.py @@ -3,7 +3,9 @@ """ from domainlab.algos.a_algo_builder import NodeAlgoBuilder from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor +from domainlab.algos.msels.c_msel_setpoint_delay import MSelSetpointDelay from domainlab.algos.msels.c_msel_val import MSelValPerf +from domainlab.algos.msels.c_msel_val_top_k import MSelValPerfTopK from domainlab.algos.observers.b_obvisitor import ObVisitor from domainlab.algos.observers.c_obvisitor_cleanup import ObVisitorCleanUp from domainlab.algos.trainers.hyper_scheduler import HyperSchedulerWarmupExponential @@ -30,7 +32,7 @@ def init_business(self, exp): task = exp.task args = exp.args device = get_device(args) - msel = MSelOracleVisitor(msel=MSelValPerf(max_es=args.es), val_threshold=args.val_threshold) + msel = MSelSetpointDelay(MSelOracleVisitor(MSelValPerfTopK(max_es=args.es)), val_threshold=args.val_threshold) observer = ObVisitor(msel) observer = ObVisitorCleanUp(observer) diff --git a/domainlab/algos/msels/a_model_sel.py b/domainlab/algos/msels/a_model_sel.py index 6a593c9b3..09b703377 100644 --- a/domainlab/algos/msels/a_model_sel.py +++ b/domainlab/algos/msels/a_model_sel.py @@ -124,6 +124,12 @@ def sel_model_te_acc(self): return self.msel.sel_model_te_acc return -1 + @property + def oracle_last_setpoint_sel_te_acc(self): + if self.msel is not None: + return self.msel.oracle_last_setpoint_sel_te_acc + return -1 + @property def model_selection_epoch(self): """ diff --git a/domainlab/algos/msels/c_msel_setpoint_delay.py b/domainlab/algos/msels/c_msel_setpoint_delay.py new file mode 100644 index 000000000..d62983e89 --- /dev/null +++ 
b/domainlab/algos/msels/c_msel_setpoint_delay.py @@ -0,0 +1,54 @@ +""" +logs the best up-to-event selected model at each event when setpoint shrinks +""" +from domainlab.algos.msels.a_model_sel import AMSel +from domainlab.utils.logger import Logger + + +class MSelSetpointDelay(AMSel): + """ + This class decorate another model selection object, it logs the current + selected performance from the decoratee each time the setpoint shrinks + """ + + def __init__(self, msel, val_threshold = None): + super().__init__(val_threshold) + # NOTE: super() has to come first always otherwise self.msel will be overwritten to be None + self.msel = msel + self._oracle_last_setpoint_sel_te_acc = 0.0 + + @property + def oracle_last_setpoint_sel_te_acc(self): + """ + return the last setpoint best acc + """ + return self._oracle_last_setpoint_sel_te_acc + + def base_update(self, clear_counter=False): + """ + if the best model should be updated + currently, clear_counter is set via + flag = super().tr_epoch(epoch, self.flag_setpoint_updated) + """ + logger = Logger.get_logger() + logger.info( + f"setpoint selected current acc {self._oracle_last_setpoint_sel_te_acc}" + ) + if clear_counter: + # for the current version of code, clear_counter = flag_setpoint_updated + log_message = ( + f"setpoint msel te acc updated from " + # self._oracle_last_setpoint_sel_te_acc start from 0.0, and always saves + # the test acc when last setpoint decrease occurs + f"{self._oracle_last_setpoint_sel_te_acc} to " + # self.sel_model_te_acc defined as a property + # in a_msel, which returns self.msel.sel_model_te_acc + # is the validation acc based model selection, which + # does not take setpoint into account + f"{self.sel_model_te_acc}" + ) + logger.info(log_message) + self._oracle_last_setpoint_sel_te_acc = self.sel_model_te_acc + # let decoratee decide if model should be selected or not + flag = self.msel.update(clear_counter) + return flag diff --git a/domainlab/algos/msels/c_msel_val_top_k.py b/domainlab/algos/msels/c_msel_val_top_k.py new file mode 100644 index 000000000..f557c7dc1 --- /dev/null +++ b/domainlab/algos/msels/c_msel_val_top_k.py @@ -0,0 +1,61 @@ +""" +Model Selection should be decoupled from +""" +from domainlab.algos.msels.c_msel_val import MSelValPerf +from domainlab.utils.logger import Logger + + +class MSelValPerfTopK(MSelValPerf): + """ + 1. Model selection using validation performance + 2. Visitor pattern to trainer + """ + + def __init__(self, max_es, top_k=2): + super().__init__(max_es) # construct self.tr_obs (observer) + self.top_k = top_k + self.list_top_k_acc = [0.0 for _ in range(top_k)] + + def update(self, clear_counter=False): + """ + if the best model should be updated + """ + flag_super = super().update(clear_counter) + metric_val_current = self.tr_obs.metric_val[self.tr_obs.str_metric4msel] + acc_min = min(self.list_top_k_acc) + if metric_val_current > acc_min: + # overwrite + logger = Logger.get_logger() + logger.info( + f"top k validation acc: {self.list_top_k_acc} \ + overwriting/reset counter" + ) + self.es_c = 0 # restore counter + ind = self.list_top_k_acc.index(acc_min) + # avoid having identical values + if metric_val_current not in self.list_top_k_acc: + self.list_top_k_acc[ind] = metric_val_current + logger.info( + f"top k validation acc updated: \ + {self.list_top_k_acc}" + ) + # overwrite to ensure consistency + # issue #569: initially self.list_top_k_acc will be [xx, 0] and it does not matter since 0 will be overwriten by second epoch validation acc. 
+ # actually, after epoch 1, most often, sefl._best_val_acc will be the higher value of self.list_top_k_acc will overwriten by min(self.list_top_k_acc) + logger.info( + f"top-2 val sel: overwriting best val acc from {self._best_val_acc} to " + f"minimum of {self.list_top_k_acc} which is {min(self.list_top_k_acc)} " + f"to ensure consistency" + ) + self._best_val_acc = min(self.list_top_k_acc) + # overwrite test acc, this does not depend on if val top-k acc has been overwritten or not + metric_te_current = self.tr_obs.metric_te[self.tr_obs.str_metric4msel] + if self._sel_model_te_acc != metric_te_current: + # this can only happen if the validation acc has decreased and current val acc is only bigger than min(self.list_top_k_acc} but lower than max(self.list_top_k_acc) + logger.info( + f"top-2 val sel: overwriting selected model test acc from " + f"{self._sel_model_te_acc} to {metric_te_current} to ensure consistency" + ) + self._sel_model_te_acc = metric_te_current + return True # if metric_val_current > acc_min: + return flag_super diff --git a/domainlab/algos/observers/b_obvisitor.py b/domainlab/algos/observers/b_obvisitor.py index 432da9703..4c3d1917c 100644 --- a/domainlab/algos/observers/b_obvisitor.py +++ b/domainlab/algos/observers/b_obvisitor.py @@ -28,6 +28,8 @@ def __init__(self, model_sel): self.metric_val = None self.perf_metric = None + self.flag_setpoint_changed_once = False + @property def str_metric4msel(self): """ @@ -35,7 +37,13 @@ def str_metric4msel(self): """ return self.host_trainer.str_metric4msel - def update(self, epoch): + def reset(self): + """ + reset observer via reset model selector + """ + self.model_sel.reset() + + def update(self, epoch, flag_info=False): logger = Logger.get_logger() logger.info(f"epoch: {epoch}") self.epo = epoch @@ -53,13 +61,18 @@ def update(self, epoch): self.loader_te, self.device ) self.metric_te = metric_te - if self.model_sel.update(epoch): + if self.model_sel.update(epoch, flag_info): logger.info("better model found") self.host_trainer.model.save() logger.info("persisted") acc = self.metric_te.get("acc") flag_stop = self.model_sel.if_stop(acc) flag_enough = epoch >= self.host_trainer.aconf.epos_min + + self.flag_setpoint_changed_once |= flag_info + if self.host_trainer.aconf.force_setpoint_change_once: + return flag_stop & flag_enough & self.flag_setpoint_changed_once + return flag_stop & flag_enough def accept(self, trainer): @@ -106,7 +119,15 @@ def after_all(self): metric_te.update({"model_selection_epoch": self.model_sel.model_selection_epoch}) else: metric_te.update({"acc_val": -1}) - metric_te.update({"model_selection_epoch": -1}) + + if hasattr(self, "model_sel") and hasattr( + self.model_sel, "oracle_last_setpoint_sel_te_acc" + ): + metric_te.update( + {"acc_setpoint": self.model_sel.oracle_last_setpoint_sel_te_acc} + ) + else: + metric_te.update({"acc_setpoint": -1}) self.dump_prediction(model_ld, metric_te) # save metric to one line in csv result file self.host_trainer.model.visitor(metric_te) diff --git a/domainlab/algos/observers/c_obvisitor_cleanup.py b/domainlab/algos/observers/c_obvisitor_cleanup.py index 91ac53216..4de3ef6b4 100644 --- a/domainlab/algos/observers/c_obvisitor_cleanup.py +++ b/domainlab/algos/observers/c_obvisitor_cleanup.py @@ -12,13 +12,13 @@ def __init__(self, observer): def after_all(self): self.observer.after_all() - self.observer.clean_up() + self.observer.clean_up() # FIXME should be self.clean_up??? 
def accept(self, trainer): self.observer.accept(trainer) - def update(self, epoch): - return self.observer.update(epoch) + def update(self, epoch, flag_info=False): + return self.observer.update(epoch, flag_info) def clean_up(self): self.observer.clean_up() @@ -26,3 +26,15 @@ def clean_up(self): @property def model_sel(self): return self.observer.model_sel + + @model_sel.setter + def model_sel(self, model_sel): + self.observer.model_sel = model_sel + + @property + def metric_te(self): + return self.observer.metric_te + + @property + def metric_val(self): + return self.observer.metric_val diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py index 051cc1e6f..e825fff16 100644 --- a/domainlab/algos/trainers/a_trainer.py +++ b/domainlab/algos/trainers/a_trainer.py @@ -60,14 +60,17 @@ def __init__(self, successor_node=None, extend=None): """ super().__init__(successor_node) self._model = None + # decoratee can be both model or trainer self._decoratee = extend self.task = None self.observer = None self.device = None self.aconf = None + self.gamma_reg = None # self.dict_loader_tr = None self.loader_tr = None + self.loader_tr_no_drop = None self.loader_te = None self.num_batches = None self.flag_update_hyper_per_epoch = None @@ -92,12 +95,17 @@ def __init__(self, successor_node=None, extend=None): self.inner_trainer = None self.loader_tr_source_target = None self.flag_initialized = False + # fbopt + self.mu_iter_start = 0 + self.flag_setpoint_updated = False # moving average self.ma_weight_previous_model_params = None self._dict_previous_para_persist = {} self._ma_iter = 0 # self.list_reg_over_task_ratio = None + # mhof + self.dict_multiplier = {} # MIRO self.input_tensor_shape = None # LR-scheduler @@ -154,9 +162,11 @@ def init_business(self, model, task, observer, device, aconf, flag_accept=True): self.observer = observer self.device = device self.aconf = aconf + self.gamma_reg = self.aconf.gamma_reg # self.dict_loader_tr = task.dict_loader_tr self.loader_tr = task.loader_tr + self.loader_tr_no_drop = task._loader_tr_no_drop self.loader_te = task.loader_te if flag_accept: @@ -202,7 +212,13 @@ def before_tr(self): """ before training, probe model performance """ - self.cal_reg_loss_over_task_loss_ratio() + list_mu = self.cal_reg_loss_over_task_loss_ratio() + self.dict_multiplier = {'mu4regloss'+str(i): value for i, value in enumerate(list_mu)} + + @property + def list_str_multiplier_na(self): + list_str = list(self.dict_multiplier.keys()) + return list_str def cal_reg_loss_over_task_loss_ratio(self): """ @@ -211,19 +227,23 @@ def cal_reg_loss_over_task_loss_ratio(self): """ list_accum_reg_loss = [] loss_task_agg = 0 + list_mu = None for ind_batch, (tensor_x, tensor_y, tensor_d, *others) in enumerate( self.loader_tr ): self.input_tensor_shape = tensor_x.shape - if ind_batch >= self.aconf.nb4reg_over_task_ratio: - return + tensor_x, tensor_y, tensor_d = ( tensor_x.to(self.device), tensor_y.to(self.device), tensor_d.to(self.device), ) - list_reg_loss_tensor, _ = \ + list_reg_loss_tensor, list_mu = \ self.cal_reg_loss(tensor_x, tensor_y, tensor_d, others) + + if ind_batch >= self.aconf.nb4reg_over_task_ratio: + return list_mu + list_reg_loss_tensor = [torch.sum(tensor).detach().item() for tensor in list_reg_loss_tensor] if ind_batch == 0: @@ -239,6 +259,7 @@ def cal_reg_loss_over_task_loss_ratio(self): loss_task_agg += tensor_loss_task self.list_reg_over_task_ratio = [reg_loss / loss_task_agg for reg_loss in list_accum_reg_loss] + return list_mu def post_tr(self): 
""" @@ -277,6 +298,15 @@ def get_model(self): return self._model return self._model.get_model() + def as_model(self): + """ + used for decorator pattern + + It is not necessary to write any function that just copies the pattern + self.get_model().do_something() + """ + return self.get_model() + def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): """ decorate trainer regularization loss @@ -296,7 +326,25 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): list_loss_tensor = list_reg_loss_model_tensor + \ list_reg_loss_trainer_tensor list_mu = list_mu_model + list_mu_trainer - return list_loss_tensor, list_mu + # ERM return a tensor of all zeros, delete here + if len(list_mu) > 1 and "ModelERM" == type(self.get_model()).__name__: + list_boolean_zero = [torch.all(torch.eq(list_loss_tensor[i], 0)).item() + for i in range(len(list_mu))] + list_loss_tensor = [list_loss_tensor[i] for (i, flag) in + enumerate(list_boolean_zero) if not flag] + list_mu = [list_mu[i] for (i, flag) in enumerate(list_boolean_zero) if not flag] + if self.dict_multiplier: + list_mu = list(self.dict_multiplier.values()) + + list_loss_tensor_normalized = list_loss_tensor + if self.list_reg_over_task_ratio: + assert len(list_mu) == len(self.list_reg_over_task_ratio) + list_loss_tensor_normalized = \ + [reg_loss / reg_over_task_ratio if reg_over_task_ratio != 0 + else reg_loss for (reg_loss, reg_over_task_ratio) + in zip(list_loss_tensor, self.list_reg_over_task_ratio)] + + return list_loss_tensor_normalized, list_mu def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): """ @@ -321,3 +369,23 @@ def print_parameters(self): """ params = vars(self) print(f"Parameters of {type(self).__name__}: {params}") + + def hyper_init(self, functor_scheduler, trainer): + """ + initialize both trainer's multiplier and model's multiplier + """ + if not self.dict_multiplier: + raise RuntimeError("self.dict_multiplier empty!") + return functor_scheduler( + trainer=trainer, **self.dict_multiplier + ) + + def hyper_update(self, epoch, fun_scheduler): + """hyper_update. + + :param epoch: + :param fun_scheduler: + """ + dict_rst = fun_scheduler(epoch) + for key in self.dict_multiplier: + self.dict_multiplier[key] = dict_rst[key] diff --git a/domainlab/algos/trainers/args_fbopt.py b/domainlab/algos/trainers/args_fbopt.py new file mode 100644 index 000000000..e1fa536ac --- /dev/null +++ b/domainlab/algos/trainers/args_fbopt.py @@ -0,0 +1,156 @@ +""" +feedback opt +""" + + +def add_args2parser_fbopt(parser): + """ + append hyper-parameters to the main argparser + """ + + parser.add_argument( + "--k_i_gain", type=float, default=0.001, + help="PID control gain for integrator, if k_i_gain_ratio is not None, \ + then this value will be overwriten, see doc for k_i_gain_ratio" + ) + + parser.add_argument( + "--k_i_gain_ratio", + type=float, + default=None, + help="set k_i_gain to be ratio of initial saturation k_i_gain \ + which K_I * delta = exp_shoulder_clip(saturation value), solve \ + for K_I, where delta = reg loss - setpoint. \ + Now independent of the scale of delta, the K_I gain will be set so \ + that the multiplier will be increased at a rate defined by \ + exp_shoulder_clip", + ) + + parser.add_argument( + "--mu_clip", type=float, default=1e4, + help="maximum value of mu: mu_clip should be large enough so that the \ + regularization loss as penalty can be weighed superior enough to \ + decrease." 
+ ) + + parser.add_argument( + "--mu_min", type=float, default=1e-6, help="minimum value of mu, mu \ + can not be negative" + ) + + parser.add_argument( + "--mu_init", type=float, default=0.001, + help="initial value for each component of the multiplier vector" + ) + + parser.add_argument( + "--coeff_ma", type=float, default=0.5, + help="exponential moving average of delta \ + (reg minus setpoint as control error): \ + move_ave=move_ave + coeff*delta(current value)" + ) + + parser.add_argument( + "--coeff_ma_output_state", + type=float, + default=0.1, + help="output (reguarization loss) exponential moving average \ + move_ave=move_ave*coeef + reg(current value)", + ) + + parser.add_argument( + "--coeff_ma_setpoint", + type=float, + default=0.9, + help="setpoint average (coeff for previous setpoint)", + ) + + parser.add_argument( + "--exp_shoulder_clip", + type=float, + default=5, + help="clip delta(control error): \ + R(reg-loss)-b(setpoint) before exponential operation: \ + exp[clip(R-b, exp_shoulder_clip)].\ + exponential magnifies control error, so this argument \ + defines the maximum rate of change of multipliers \ + exp(5)=148, exp_shoulder_clip should not be too big, \ + if exp_shoulder_clip is small, then more like exterior point method", + ) + + parser.add_argument( + "--ini_setpoint_ratio", + type=float, + default=0.99, + help="before training start, evaluate reg loss, \ + setpoint will be 0.9 of this loss", + ) + + parser.add_argument( + "--force_feedforward", + action="store_true", + default=False, + help="use feedforward scheduler", + ) + + parser.add_argument( + "--force_setpoint_change_once", + action="store_true", + default=False, + help="continue training until the setpoint changed at least once: \ + up to maximum epos specified", + ) + + parser.add_argument( + "--no_tensorboard", + action="store_true", + default=False, + help="disable tensorboard", + ) + + parser.add_argument( + "--no_setpoint_update", + action="store_true", + default=False, + help="disable setpoint update", + ) + + parser.add_argument( + "--tr_with_init_mu", + action="store_true", + default=False, + help="disable setpoint update", + ) + + parser.add_argument( + "--no_overshoot_rewind", + action="store_true", + default=False, + help="disable overshoot rewind: when reg loss satisfies setpoint \ + already, then set activation=K_I*delta = 0", + ) + + parser.add_argument( + "--setpoint_rewind", + action="store_true", + default=False, + help="rewind setpoint", + ) + + # this arg is only used when model is set to be "diva" + parser.add_argument( + "--str_setpoint_ada", + type=str, + default="DominateAllComponent()", + help="which setpoint adaptation criteria to use", + ) + + # this arg is only used when model is set to be "diva" + parser.add_argument( + "--str_diva_multiplier_type", + type=str, + default="gammad_recon", + help="which penalty to tune, only useful to DIVA model", + ) + + return parser diff --git a/domainlab/algos/trainers/compos/matchdg_match.py b/domainlab/algos/trainers/compos/matchdg_match.py index 78e67abde..8c6b46c90 100644 --- a/domainlab/algos/trainers/compos/matchdg_match.py +++ b/domainlab/algos/trainers/compos/matchdg_match.py @@ -16,6 +16,7 @@ class MatchPair: """ match different input """ + @store_args def __init__( self, diff --git a/domainlab/algos/trainers/fbopt_mu_controller.py b/domainlab/algos/trainers/fbopt_mu_controller.py new file mode 100644 index 000000000..272d34908 --- /dev/null +++ b/domainlab/algos/trainers/fbopt_mu_controller.py @@ -0,0 +1,297 @@ +""" +update 
hyper-parameters during training +""" +import os +import warnings + +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +from domainlab.algos.trainers.fbopt_setpoint_ada import ( + FbOptSetpointController, + if_list_sign_agree, +) +from domainlab.utils.logger import Logger + + +class StubSummaryWriter: + """ + # stub writer for tensorboard that ignores all messages + """ + + def add_scalar(self, *args, **kwargs): + """ + stub, pass do nothing + """ + + def add_scalars(self, *args, **kwargs): + """ + stub, pass, do nothing + """ + + +class HyperSchedulerFeedback: + # pylint: disable=too-many-instance-attributes + """ + design $\\mu$$ sequence based on state of penalized loss + """ + + def __init__(self, trainer, **kwargs): + """ + kwargs is a dictionary with key the hyper-parameter name and its value + """ + self.trainer = trainer + self.init_mu = trainer.aconf.mu_init + self.mu_min = trainer.aconf.mu_min + self.mu_clip = trainer.aconf.mu_clip + + if not kwargs: + raise RuntimeError("feedback scheduler requires **kwargs, the set \ + of multipliers non-empty") + self.mmu = kwargs + # force initial value of mu + self.mmu = {key: self.init_mu for key, val in self.mmu.items()} + self.set_point_controller = FbOptSetpointController(args=self.trainer.aconf) + + self.k_i_control = [trainer.aconf.k_i_gain for i in + range(len(self.mmu))] + self.k_i_gain_ratio = trainer.aconf.k_i_gain_ratio + self.overshoot_rewind = not trainer.aconf.no_overshoot_rewind + self.delta_epsilon_r = None + + # NOTE: this value will be set according to initial evaluation of + # neural network + self.activation_clip = trainer.aconf.exp_shoulder_clip + self.coeff4newval_ma_delta = trainer.aconf.coeff_ma + # NOTE: + # print(copy.deepcopy(self.model)) + # TypeError: cannot pickle '_thread.lock' object + if trainer.aconf.no_tensorboard: + self.writer = StubSummaryWriter() + else: + str_job_id = os.environ.get("SLURM_JOB_ID", "") + self.writer = SummaryWriter(comment=str_job_id) + + def set_k_i_gain(self, epo_reg_loss): + if self.k_i_gain_ratio is None: + if self.k_i_control: + return + raise RuntimeError("set either direct k_i_control value or \ + set k_i_gain_ratio, can not be both empty!") + # NOTE: do not use self.cal_delta4control!!!! which will change + # class member variables self.delta_epsilon_r! + list_setpoint = self.get_setpoint4r() + if_list_sign_agree(epo_reg_loss, list_setpoint) + delta_epsilon_r = [a - b for a, b in zip(epo_reg_loss, list_setpoint)] + + # to calculate self.delta_epsilon_r + list_active = [self.activation_clip for i in range(len(delta_epsilon_r))] + + k_i_gain_saturate = [ + a / b for a, b in zip(list_active, delta_epsilon_r) + ] + + # FIXME: add max K_I gain here if initial delta is too small + + k_i_gain_saturate_min = min(k_i_gain_saturate) + # NOTE: here we override the commandline arguments specification + # for k_i_control, so k_i_control is not a hyperparameter anymore + # self.k_i_control = [self.k_i_gain_ratio * ele for ele in k_i_gain_saturate] + # k_I should be the same for each component, the control error already + # make the multiplier magnification different + self.k_i_control = [self.k_i_gain_ratio * k_i_gain_saturate_min for i + in range(len(delta_epsilon_r))] + warnings.warn( + f"hyperparameter k_i_gain disabled! \ + replace with {self.k_i_control}" + ) + # FIXME: change this to 1-self.ini_setpoint_ratio, i.e. 
the more + # difficult the initial setpoint is, the bigger the k_i_gain should be + + def get_setpoint4r(self): + """ + get setpoint list + """ + return self.set_point_controller.setpoint4R + + def set_setpoint(self, list_setpoint4r, setpoint4ell): + """ + set the setpoint + """ + self.set_point_controller.setpoint4R = list_setpoint4r + self.set_point_controller.setpoint4ell = setpoint4ell + + def cal_delta4control(self, list1, list_setpoint): + """ + list difference + """ + if_list_sign_agree(list1, list_setpoint) + delta_epsilon_r = [a - b for a, b in zip(list1, list_setpoint)] + if self.delta_epsilon_r is None: + self.delta_epsilon_r = delta_epsilon_r + else: + # PI control. + # self.delta_epsilon_r is the previous time step. + # delta_epsilon_r is the current time step + self.delta_epsilon_r = self.cal_delta_integration( + self.delta_epsilon_r, delta_epsilon_r, + self.coeff4newval_ma_delta + ) + + def cal_delta_integration(self, list_old, list_new, coeff): + """ + ma of delta + """ + return [(1 - coeff) * a + coeff * b for a, b in zip(list_old, list_new)] + + def tackle_overshoot(self, activation, epo_reg_loss, list_str_multiplier_na): + """ + tackle overshoot + """ + list_overshoot = [ + i if (a - b) * (self.delta_epsilon_r[i]) < 0 else None + for i, (a, b) in enumerate( + zip(epo_reg_loss, self.set_point_controller.setpoint4R) + ) + ] + for ind in list_overshoot: + if ind is not None: + logger = Logger.get_logger( + logger_name="main_out_logger", loglevel="INFO" + ) + logger.info(f"delta integration: {self.delta_epsilon_r}") + logger.info( + f"overshooting at pos \ + {ind} of activation: {activation}" + ) + logger.info(f"name reg loss:{list_str_multiplier_na}") + if self.overshoot_rewind: + activation[ind] = 0.0 + logger.info( + f"PID controller set to zero now, \ + new activation: {activation}" + ) + return activation + + def cal_activation(self): + """ + calculate activation on exponential shoulder + """ + setpoint = self.get_setpoint4r() + activation = [ + self.k_i_control[i] * val if setpoint[i] > 0 else self.k_i_control[i] * (-val) + for i, val in enumerate(self.delta_epsilon_r) + ] + if self.activation_clip is not None: + activation = [ + np.clip( + val, a_min=-1 * self.activation_clip, a_max=self.activation_clip + ) + for val in activation + ] + return activation + + def search_mu( + self, epo_reg_loss, epo_task_loss, epo_loss_tr, list_str_multiplier_na, miter + ): + # pylint: disable=too-many-locals, too-many-arguments + """ + start from parameter dictionary dict_theta: {"layer":tensor}, + enlarge mu w.r.t. 
its current value + to see if the criteria is met + $$\\mu^{k+1}=mu^{k}exp(rate_mu*[R(\\theta^{k})-ref_R])$$ + """ + logger = Logger.get_logger(logger_name="main_out_logger", loglevel="INFO") + logger.info(f"before controller: current mu: {self.mmu}") + logger.info(f"epo reg loss: {epo_reg_loss}") + logger.info(f"name reg loss:{list_str_multiplier_na}") + self.cal_delta4control(epo_reg_loss, self.get_setpoint4r()) + activation = self.cal_activation() + # overshoot handling + activation = self.tackle_overshoot( + activation, epo_reg_loss, list_str_multiplier_na + ) + list_gain = np.exp(activation) + dict_gain = dict(zip(list_str_multiplier_na, list_gain)) + target = self.dict_multiply(self.mmu, dict_gain) + self.mmu = self.dict_clip(target) + logger = Logger.get_logger(logger_name="main_out_logger", loglevel="INFO") + logger.info(f"after contoller: current mu: {self.mmu}") + + for key, val in self.mmu.items(): + self.writer.add_scalar(f"dyn_mu/{key}", val, miter) + self.writer.add_scalar(f"controller_gain/{key}", dict_gain[key], miter) + ind = list_str_multiplier_na.index(key) + self.writer.add_scalar(f"delta/{key}", self.delta_epsilon_r[ind], miter) + + if list_str_multiplier_na: + for i, (reg_dyn, reg_set) in enumerate( + zip(epo_reg_loss, self.get_setpoint4r()) + ): + + self.writer.add_scalar( + f"lossrd/dyn_{list_str_multiplier_na[i]}", reg_dyn, miter + ) + self.writer.add_scalar( + f"lossrs/setpoint_{list_str_multiplier_na[i]}", reg_set, miter + ) + + self.writer.add_scalars( + f"loss_rds/loss_{list_str_multiplier_na[i]}_w_setpoint", + { + f"lossr/loss_{list_str_multiplier_na[i]}": reg_dyn, + f"lossr/setpoint_{list_str_multiplier_na[i]}": reg_set, + }, + miter, + ) + self.writer.add_scalar( + f"x_ell_y_r/loss_{list_str_multiplier_na[i]}", reg_dyn, epo_task_loss + ) + else: + logger.info("No multiplier provided") + self.writer.add_scalar("loss_task/penalized", epo_loss_tr, miter) + self.writer.add_scalar("loss_task/ell", epo_task_loss, miter) + acc_te = 0 + acc_val = 0 + acc_sel = 0 + acc_set = 0 + + if miter > 1: + acc_te = self.trainer.observer.metric_te["acc"] + acc_val = self.trainer.observer.metric_val["acc"] + acc_sel = self.trainer.observer.model_sel.sel_model_te_acc + acc_set = self.trainer.observer.model_sel.oracle_last_setpoint_sel_te_acc + self.writer.add_scalar("acc/te", acc_te, miter) + self.writer.add_scalar("acc/val", acc_val, miter) + self.writer.add_scalar("acc/sel", acc_sel, miter) + self.writer.add_scalar("acc/setpoint", acc_set, miter) + + def dict_clip(self, dict_base): + """ + clip each entry of the mu according to pre-set self.mu_clip + """ + return { + key: np.clip(val, a_min=self.mu_min, a_max=self.mu_clip) + for key, val in dict_base.items() + } + + def dict_is_zero(self, dict_mu): + """ + check if hyper-parameter start from zero + """ + for key in dict_mu.keys(): + if dict_mu[key] == 0.0: + return True + return False + + def dict_multiply(self, dict_base, dict_multiplier): + """ + multiply a float to each element of a dictionary + """ + return {key: val * dict_multiplier[key] for key, val in dict_base.items()} + + def update_setpoint(self, epo_reg_loss, epo_task_loss): + """ + update setpoint + """ + return self.set_point_controller.observe(epo_reg_loss, epo_task_loss) diff --git a/domainlab/algos/trainers/fbopt_setpoint_ada.py b/domainlab/algos/trainers/fbopt_setpoint_ada.py new file mode 100644 index 000000000..e22e13a75 --- /dev/null +++ b/domainlab/algos/trainers/fbopt_setpoint_ada.py @@ -0,0 +1,324 @@ +""" +update hyper-parameters during training +""" 
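+# NOTE (overview of this module): FbOptSetpointController delegates the decision of
+# when to shrink the setpoints to a state object (state pattern): FixedSetpoint never
+# updates; SliderAnyComponent / SliderAllComponent update when any / all smoothed
+# regularization-loss components fall below their current setpoints; the Dominate*
+# variants additionally require the smoothed task loss to have improved. Accepted
+# updates move the setpoints towards the observed losses via an exponential moving
+# average weighted by coeff_ma_setpoint.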
+import numpy as np + +from domainlab.utils.logger import Logger + + +def list_true(list1): + """ + find out position of a list which has element True + """ + arr_pos = np.arange(len(list1))[list1] + return list(arr_pos) + + +def list_add(list1, list2): + """ + add two lists + """ + return [a + b for a, b in zip(list1, list2)] + + +def list_multiply(list1, coeff): + """ + multiply a scalar to a list + """ + return [ele * coeff for ele in list1] + + +def if_list_sign_agree(list1, list2): + """ + each pair must have the same sign + """ + list_agree = [a * b >= 0 for a, b in zip(list1, list2)] + if not all(list_agree): + raise RuntimeError(f"{list1} and {list2} can not be compared!") + + +def is_less_list_any(list1, list2): + """ + judge if one list is less than the other + """ + if_list_sign_agree(list1, list2) + list_comparison = [ + a < b if a >= 0 and b >= 0 else a > b for a, b in zip(list1, list2) + ] + return any(list_comparison), list_true(list_comparison) + + +def is_less_list_all(list1, list2, flag_eq=False): + """ + judge if one list is less than the other + """ + if_list_sign_agree(list1, list2) + list_comparison = [ + a < b if a >= 0 and b >= 0 else a > b for a, b in zip(list1, list2) + ] + if flag_eq: + list_comparison = [ + a <= b if a >= 0 and b >= 0 else a >= b for a, b in zip(list1, list2) + ] + return all(list_comparison) + + +def list_ma(list_state, list_input, coeff): + """ + moving average of list + """ + return [a * coeff + b * (1 - coeff) for a, b in zip(list_state, list_input)] + + +class SetpointRewinder: + """ + rewind setpoint if current loss exponential moving average is + bigger than setpoint + """ + + def __init__(self, host): + self.host = host + self.counter = None + self.epo_ma = None + self.ref = None + self.coeff_ma_setpoint_rewinder = 0.5 + self.setpoint_rewind = host.flag_setpoint_rewind + + def reset(self, epo_reg_loss): + """ + when setpoint is adjusted + """ + self.counter = 0 + self.epo_ma = [0.0 for _ in range(10)] # FIXME + self.ref = epo_reg_loss + + def observe(self, epo_reg_loss): + """ + update moving average + """ + if self.ref is None: + self.reset(epo_reg_loss) + self.epo_ma = list_ma(self.epo_ma, epo_reg_loss, self.coeff_ma_setpoint_rewinder) + list_comparison_increase = [a < b for a, b in zip(self.ref, self.epo_ma)] + list_comparison_above_setpoint = [ + a < b for a, b in zip(self.host.setpoint4R, self.epo_ma) + ] + flag_increase = any(list_comparison_increase) + flag_above_setpoint = any(list_comparison_above_setpoint) + if flag_increase and flag_above_setpoint: + self.counter += 1 + + else: + self.counter = 0 + self.reset(epo_reg_loss) + + if self.setpoint_rewind: + if self.counter > 2 and self.counter <= 3: + # only allow self.counter = 2, 3 to rewind setpoing twice + list_pos = list_true(list_comparison_above_setpoint) + print(f"\n\n\n!!!!!!!setpoint too low at {list_pos}!\n\n\n") + for pos in list_pos: + print( + f"\n\n\n!!!!!!!rewinding setpoint at pos {pos} \ + from {self.host.setpoint4R[pos]} to \ + {self.epo_ma[pos]}!\n\n\n" + ) + self.host.setpoint4R[pos] = self.epo_ma[pos] + + if self.counter > 3: + self.host.transition_to(FixedSetpoint()) + self.counter = np.inf # FIXME + + +class FbOptSetpointController: + # pylint: disable=too-many-instance-attributes + """ + update setpoint for mu + """ + + def __init__(self, state=None, args=None): + """ + kwargs is a dictionary with key the hyper-parameter name and its value + """ + if state is None: + if args is not None and args.no_setpoint_update: + state = FixedSetpoint() + else: + # 
state = eval('DominateAllComponent()') + # state = DominateAllComponent() + state = eval(args.str_setpoint_ada) + self.transition_to(state) + self.flag_setpoint_rewind = args.setpoint_rewind + self.setpoint_rewinder = SetpointRewinder(self) + self.state_task_loss = 0.0 + self.state_epo_reg_loss = [ + 0.0 for _ in range(10) + ] # FIXME: 10 is the maximum number losses here + self.coeff_ma_setpoint = args.coeff_ma_setpoint + self.coeff_ma_output = args.coeff_ma_output_state + # initial value will be set via trainer + self.setpoint4R = None + self.setpoint4ell = None + self.host = None + + def transition_to(self, state): + """ + change internal state + """ + self.state_updater = state + self.state_updater.accept(self) + + def update_setpoint_ma(self, list_target, list_pos): + """ + using moving average + """ + target_ma = [ + self.coeff_ma_setpoint * a + (1 - self.coeff_ma_setpoint) * b + for a, b in zip(self.setpoint4R, list_target) + ] + self.setpoint4R = [ + target_ma[i] if i in list_pos else self.setpoint4R[i] + for i in range(len(target_ma)) + ] + + def observe(self, epo_reg_loss, epo_task_loss): + """ + read current epo_reg_loss continuously + """ + self.state_epo_reg_loss = [ + self.coeff_ma_output * a + (1 - self.coeff_ma_output) * b if a != 0.0 else b + for a, b in zip(self.state_epo_reg_loss, epo_reg_loss) + ] + if self.state_task_loss == 0.0: + self.state_task_loss = epo_task_loss + self.state_task_loss = ( + self.coeff_ma_output * self.state_task_loss + + (1 - self.coeff_ma_output) * epo_task_loss + ) + self.setpoint_rewinder.observe(self.state_epo_reg_loss) + flag_update, list_pos = self.state_updater.update_setpoint() + if flag_update: + self.setpoint_rewinder.reset(self.state_epo_reg_loss) + logger = Logger.get_logger(logger_name="main_out_logger", loglevel="INFO") + logger.info(f"!!!!!set point old value {self.setpoint4R}!") + self.update_setpoint_ma(self.state_epo_reg_loss, list_pos) + logger.info(f"!!!!!set point updated to {self.setpoint4R}!") + return True + return False + + +class FbOptSetpointControllerState: + # pylint: disable=too-few-public-methods + """ + abstract state pattern + """ + + def __init__(self): + """ """ + self.host = None + + def accept(self, controller): + """ + set host for state + """ + self.host = controller + + +class FixedSetpoint(FbOptSetpointControllerState): + """ + do not update setpoint + """ + + def update_setpoint(self): + """ + always return False so setpoint no update + """ + return False, None + + +class SliderAllComponent(FbOptSetpointControllerState): + """ + concrete state pattern + """ + + def update_setpoint(self): + """ + all components of R descreases regardless if ell decreases or not + """ + logger = Logger.get_logger(logger_name="main_out_logger", loglevel="INFO") + logger.info( + f"comparing output vs setpoint: \n \ + {self.host.state_epo_reg_loss} \n \ + {self.host.setpoint4R}" + ) + if is_less_list_all( + self.host.state_epo_reg_loss, self.host.setpoint4R, flag_eq=True + ): + logger.info( + "!!!!!!!!!In SliderAllComponent: R current value better than current setpoint!" 
+ ) + return True, list(range(len(self.host.setpoint4R))) + return False, None + + +class SliderAnyComponent(FbOptSetpointControllerState): + """ + concrete state pattern + """ + + def update_setpoint(self): + """ + if any component of R has decreased regardless if ell decreases + """ + flag, list_pos = is_less_list_any( + self.host.state_epo_reg_loss, self.host.setpoint4R + ) + return flag, list_pos + + def transit(self): + self.host.transition_to(SliderAllComponent()) + + +class DominateAnyComponent(SliderAnyComponent): + """ + concrete state pattern + """ + + def update_setpoint(self): + """ + if any of the component of R loss has decreased together with ell loss + """ + flag1, list_pos = super().update_setpoint() + flag2 = self.host.state_task_loss < self.host.setpoint4ell + if flag2: + self.host.setpoint4ell = self.host.state_task_loss + return flag1 & flag2, list_pos + + +class DominateAllComponent(SliderAllComponent): + """ + concrete state pattern + """ + + def update_setpoint(self): + """ + if each component of R loss has decreased and ell loss also decreased + """ + flag1, list_pos = super().update_setpoint() + flag2 = self.host.state_task_loss < self.host.setpoint4ell + if flag2: + logger = Logger.get_logger(logger_name="main_out_logger", loglevel="INFO") + logger.info( + f"best ell loss: from {self.host.setpoint4ell} to \ + {self.host.state_task_loss}" + ) + self.host.setpoint4ell = self.host.state_task_loss + + if flag1 & flag2: + logger = Logger.get_logger(logger_name="main_out_logger", loglevel="INFO") + logger.info( + "!!!!!!!!!In DominantAllComponent: \ + besides each components of reg loss shrinks \ + task loss also decreased which forms dominance!" + ) + return flag1 & flag2, list_pos diff --git a/domainlab/algos/trainers/train_basic.py b/domainlab/algos/trainers/train_basic.py index 10ac3b06f..447b446f0 100644 --- a/domainlab/algos/trainers/train_basic.py +++ b/domainlab/algos/trainers/train_basic.py @@ -23,6 +23,7 @@ def before_tr(self): """ self.model.evaluate(self.loader_te, self.device) super().before_tr() + self.before_epoch() def before_epoch(self): """ @@ -35,15 +36,15 @@ def before_epoch(self): self.epo_reg_loss_tr = [0.0 for _ in range(10)] self.epo_task_loss_tr = 0 - def tr_epoch(self, epoch): + def tr_epoch(self, epoch, flag_info=False): self.before_epoch() for ind_batch, (tensor_x, tensor_y, tensor_d, *others) in enumerate( self.loader_tr ): self.tr_batch(tensor_x, tensor_y, tensor_d, others, ind_batch, epoch) - return self.after_epoch(epoch) + return self.after_epoch(epoch, flag_info) - def after_epoch(self, epoch): + def after_epoch(self, epoch, flag_info): """ observer collect information """ @@ -52,7 +53,7 @@ def after_epoch(self, epoch): self.epo_reg_loss_tr = list_divide(self.epo_reg_loss_tr, self.counter_batch) assert self.epo_loss_tr is not None assert not math.isnan(self.epo_loss_tr) - flag_stop = self.observer.update(epoch) # notify observer + flag_stop = self.observer.update(epoch, flag_info) # notify observer assert flag_stop is not None return flag_stop @@ -79,7 +80,7 @@ def tr_batch(self, tensor_x, tensor_y, tensor_d, others, ind_batch, epoch): tensor_d.to(self.device), ) self.optimizer.zero_grad() - loss = self.cal_loss(tensor_x, tensor_y, tensor_d, others) + loss, *_ = self.cal_loss(tensor_x, tensor_y, tensor_d, others) loss.backward() self.optimizer.step() if self.lr_scheduler: @@ -96,17 +97,11 @@ def cal_loss(self, tensor_x, tensor_y, tensor_d, others): list_reg_tr_batch, list_mu_tr = self.cal_reg_loss( tensor_x, tensor_y, tensor_d, others 
) - list_mu_tr_normalized = list_mu_tr - if self.list_reg_over_task_ratio: - assert len(list_mu_tr) == len(self.list_reg_over_task_ratio) - list_mu_tr_normalized = \ - [mu / reg_over_task_ratio if reg_over_task_ratio != 0 - else mu for (mu, reg_over_task_ratio) - in zip(list_mu_tr, self.list_reg_over_task_ratio)] tensor_batch_reg_loss_penalized = self.model.list_inner_product( list_reg_tr_batch, list_mu_tr_normalized ) + assert len(tensor_batch_reg_loss_penalized.shape) == 1 loss_erm_agg = g_tensor_batch_agg(loss_task) loss_reg_penalized_agg = g_tensor_batch_agg(tensor_batch_reg_loss_penalized) @@ -114,4 +109,4 @@ def cal_loss(self, tensor_x, tensor_y, tensor_d, others): self.model.multiplier4task_loss * loss_erm_agg + loss_reg_penalized_agg ) self.log_loss(list_reg_tr_batch, loss_task, loss_penalized) - return loss_penalized + return loss_penalized, list_reg_tr_batch, loss_erm_agg diff --git a/domainlab/algos/trainers/train_causIRL.py b/domainlab/algos/trainers/train_causIRL.py index 085abcbfa..194386b43 100644 --- a/domainlab/algos/trainers/train_causIRL.py +++ b/domainlab/algos/trainers/train_causIRL.py @@ -69,7 +69,7 @@ def tr_batch(self, tensor_x, tensor_y, tensor_d, others, ind_batch, epoch): penalty = torch.nan_to_num(self.mmd(first, second)) else: penalty = torch.tensor(0) - loss = self.cal_loss(tensor_x, tensor_y, tensor_d, others) + loss, *_ = self.cal_loss(tensor_x, tensor_y, tensor_d, others) loss = loss + penalty loss.backward() self.optimizer.step() diff --git a/domainlab/algos/trainers/train_dial.py b/domainlab/algos/trainers/train_dial.py index 4fe700f45..711a83008 100644 --- a/domainlab/algos/trainers/train_dial.py +++ b/domainlab/algos/trainers/train_dial.py @@ -47,7 +47,8 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): """ _ = tensor_d _ = others - tensor_x_adv = self.gen_adversarial(self.device, tensor_x, tensor_y) - tensor_x_batch_adv_no_grad = Variable(tensor_x_adv, requires_grad=False) - loss_dial = self.model.cal_task_loss(tensor_x_batch_adv_no_grad, tensor_y) - return [loss_dial], [get_gamma_reg(self.aconf, self.name)] + with torch.enable_grad(): + tensor_x_adv = self.gen_adversarial(self.device, tensor_x, tensor_y) + tensor_x_batch_adv_no_grad = Variable(tensor_x_adv, requires_grad=False) + loss_dial = self.model.cal_task_loss(tensor_x_batch_adv_no_grad, tensor_y) + return [loss_dial], [get_gamma_reg(self.aconf, self.name)] diff --git a/domainlab/algos/trainers/train_ema.py b/domainlab/algos/trainers/train_ema.py index 8b57368b0..b2ccede40 100644 --- a/domainlab/algos/trainers/train_ema.py +++ b/domainlab/algos/trainers/train_ema.py @@ -57,10 +57,10 @@ def move_average(self, dict_data, epoch): self._ma_iter += 1 return dict_return_ema_para_curr_iter - def after_epoch(self, epoch): + def after_epoch(self, epoch, flag_info=None): torch_model = self.get_model() dict_para = torch_model.state_dict() # only for trainable parameters new_dict_para = self.move_average(dict_para, epoch) # without deepcopy, this seems to work torch_model.load_state_dict(new_dict_para) - super().after_epoch(epoch) + super().after_epoch(epoch, flag_info) diff --git a/domainlab/algos/trainers/train_fbopt_b.py b/domainlab/algos/trainers/train_fbopt_b.py new file mode 100644 index 000000000..b47ffc69e --- /dev/null +++ b/domainlab/algos/trainers/train_fbopt_b.py @@ -0,0 +1,198 @@ +""" +update hyper-parameters during training +""" +from operator import add + +import torch + +from domainlab.algos.trainers.fbopt_mu_controller import HyperSchedulerFeedback +from 
domainlab.algos.trainers.hyper_scheduler import HyperSchedulerWarmupLinear +from domainlab.algos.trainers.train_basic import TrainerBasic +from domainlab.utils.logger import Logger + + +def list_divide(list_val, scalar): + """ + divide a list by a scalar + """ + return [ele / scalar for ele in list_val] + + +class HyperSetter: + # pylint: disable=too-few-public-methods + """ + mock object to force hyper-parameter in the model + """ + + def __init__(self, dict_hyper): + self.dict_hyper = dict_hyper + + def __call__(self, epoch=None): + return self.dict_hyper + + +class TrainerFbOpt(TrainerBasic): + """ + TrainerHyperScheduler + """ + + def set_scheduler(self, scheduler): + """ + Args: + scheduler: The class name of the scheduler, the object corresponding to + this class name will be created inside model + """ + # model.hyper_init will register the hyper-parameters of the model to scheduler + self.hyper_scheduler = self.decoratee.hyper_init(scheduler, trainer=self) + + def eval_r_loss(self): + """ + evaluate the regularization loss and ERM loss with respect ot parameter dict_theta + ERM loss on all available training data + # TODO: normalize loss via batchsize + """ + self.model.eval() + # mock the model hyper-parameter to be from dict4mu + epo_reg_loss = [] + epo_task_loss = 0 + epo_p_loss = 0 + counter = 0.0 + with torch.no_grad(): + for _, (tensor_x, vec_y, vec_d, *others) in enumerate( + self.loader_tr_no_drop + ): + tensor_x, vec_y, vec_d = ( + tensor_x.to(self.device), + vec_y.to(self.device), + vec_d.to(self.device), + ) + tuple_reg_loss = self.decoratee.cal_reg_loss(tensor_x, vec_y, vec_d, others) + p_loss, *_ = self.decoratee.cal_loss(tensor_x, vec_y, vec_d, others) + if p_loss.dim() > 0: + p_loss = p_loss.sum() + # NOTE: first [0] extract the loss, second [0] get the list + list_b_reg_loss = tuple_reg_loss[0] + list_b_reg_loss_sumed = [ + ele.sum().detach().item() for ele in list_b_reg_loss + ] + if len(epo_reg_loss) == 0: + epo_reg_loss = list_b_reg_loss_sumed + else: + epo_reg_loss = list(map(add, epo_reg_loss, list_b_reg_loss_sumed)) + b_task_loss = ( + self.model.cal_task_loss(tensor_x, vec_y).sum().detach().item() + ) + # sum will kill the dimension of the mini batch + epo_task_loss += b_task_loss + epo_p_loss += p_loss.detach().item() + counter += 1.0 + return ( + list_divide(epo_reg_loss, counter), + epo_task_loss / counter, + epo_p_loss / counter, + ) + + def before_batch(self, epoch, ind_batch): + """ + if hyper-parameters should be updated per batch, then step + should be set to epoch*self.num_batches + ind_batch + """ + if self.flag_update_hyper_per_batch: + # NOTE: if not update per_batch, then not updated + self.model.hyper_update( + epoch * self.num_batches + ind_batch, self.hyper_scheduler + ) + return super().after_batch(epoch, ind_batch) + + def before_tr(self): + if hasattr(self.decoratee, "before_tr"): + # initialize self.decoratee.dict_multiplier + self.decoratee.before_tr() + self.flag_setpoint_updated = False + if self.aconf.force_feedforward: + self.set_scheduler(scheduler=HyperSchedulerWarmupLinear) + else: + self.set_scheduler(scheduler=HyperSchedulerFeedback) + + self.set_model_with_mu() # very small value + + # evaluate regularization loss list + ( + self.epo_reg_loss_tr, + self.epo_task_loss_tr, + self.epo_loss_tr, + ) = self.eval_r_loss() + + if self.aconf.tr_with_init_mu: + self.tr_with_init_mu() + # evaluate regularization loss list + ( + self.epo_reg_loss_tr, + self.epo_task_loss_tr, + self.epo_loss_tr, + ) = self.eval_r_loss() + + 
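+        # the initial setpoint of each regularization loss is a fraction
+        # (args.ini_setpoint_ratio, e.g. 0.99) of the corresponding loss evaluated
+        # just above; non-positive components are divided by the ratio instead,
+        # so in both cases the setpoint lies at or slightly below the current loss value.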
self.hyper_scheduler.set_setpoint( + [ + ele * self.aconf.ini_setpoint_ratio + if ele > 0 + else ele / self.aconf.ini_setpoint_ratio + for ele in self.epo_reg_loss_tr + ], + self.epo_task_loss_tr, + ) # setpoint w.r.t. random initialization of neural network + self.hyper_scheduler.set_k_i_gain(self.epo_reg_loss_tr) + + @property + def list_str_multiplier_na(self): + """ + return the name of multipliers + """ + return self.decoratee.list_str_multiplier_na + + def tr_with_init_mu(self): + """ + erm step with very small mu + """ + super().tr_epoch(-1) + + def set_model_with_mu(self): + """ + set model multipliers + """ + self.decoratee.hyper_update( + epoch=None, fun_scheduler=HyperSetter(self.hyper_scheduler.mmu) + ) + + def tr_epoch(self, epoch, flag_info=False): + """ + update multipliers only per epoch + """ + self.hyper_scheduler.search_mu( + self.epo_reg_loss_tr, + self.epo_task_loss_tr, + self.epo_loss_tr, + self.list_str_multiplier_na, + miter=epoch, + ) + self.set_model_with_mu() + if hasattr(self.decoratee, "dict_multiplier"): + logger = Logger.get_logger() + logger.info(f"current multiplier: {self.decoratee.dict_multiplier}") + + if self._decoratee is not None: + flag = self._decoratee.tr_epoch(epoch, self.flag_setpoint_updated) + # self._decoratee.tr_epoch here will call + # self._decoratee.after_epoch to log the losses, but it only sotre + # the value into self._decoratee, + # so we have to mannually copy the value here + self.epo_loss_tr = self._decoratee.epo_loss_tr + self.epo_reg_loss_tr = self._decoratee.epo_reg_loss_tr + self.epo_task_loss_tr = self._decoratee.epo_task_loss_tr + else: + flag = super().tr_epoch(epoch, self.flag_setpoint_updated) + # is it good to update setpoint after we know the new value of each loss? + self.flag_setpoint_updated = self.hyper_scheduler.update_setpoint( + self.epo_reg_loss_tr, self.epo_task_loss_tr + ) + return flag diff --git a/domainlab/algos/trainers/train_fishr.py b/domainlab/algos/trainers/train_fishr.py index 3580a0721..250b4109d 100644 --- a/domainlab/algos/trainers/train_fishr.py +++ b/domainlab/algos/trainers/train_fishr.py @@ -26,7 +26,7 @@ class TrainerFishr(TrainerBasic): "Fishr: Invariant gradient variances for out-of-distribution generalization." International Conference on Machine Learning. PMLR, 2022. 
""" - def tr_epoch(self, epoch): + def tr_epoch(self, epoch, flag_info=False): list_loaders = list(self.dict_loader_tr.values()) loaders_zip = zip(*list_loaders) self.model.train() @@ -46,7 +46,7 @@ def tr_epoch(self, epoch): self.epo_loss_tr += loss.detach().item() self.after_batch(epoch, ind_batch) - flag_stop = self.observer.update(epoch) # notify observer + flag_stop = self.observer.update(epoch, flag_info) # notify observer return flag_stop def var_grads_and_loss(self, tuple_data_domains_batch): @@ -161,10 +161,6 @@ def cal_dict_variance_grads(self, tensor_x, vec_y): inputs=list(self.model.parameters()), retain_graph=True, create_graph=True ) - for name, param in self.model.named_parameters(): - print(name) - print(".grad.shape: ", param.variance.shape) - dict_variance = OrderedDict( [(name, weights.variance.clone()) for name, weights in self.model.named_parameters() diff --git a/domainlab/algos/trainers/train_hyper_scheduler.py b/domainlab/algos/trainers/train_hyper_scheduler.py index 0a89e7691..00b1cb1d3 100644 --- a/domainlab/algos/trainers/train_hyper_scheduler.py +++ b/domainlab/algos/trainers/train_hyper_scheduler.py @@ -25,7 +25,7 @@ def set_scheduler( flag_update_epoch: if hyper-parameters should be changed per epoch flag_update_batch: if hyper-parameters should be changed per batch """ - self.hyper_scheduler = self.model.hyper_init(scheduler) + self.hyper_scheduler = self.decoratee.hyper_init(scheduler, trainer=self) # let model register its hyper-parameters to the scheduler self.flag_update_hyper_per_epoch = flag_update_epoch self.flag_update_hyper_per_batch = flag_update_batch @@ -37,12 +37,14 @@ def before_batch(self, epoch, ind_batch): should be set to epoch*self.num_batches + ind_batch """ if self.flag_update_hyper_per_batch: - self.model.hyper_update( + self.decoratee.hyper_update( epoch * self.num_batches + ind_batch, self.hyper_scheduler ) return super().before_batch(epoch, ind_batch) def before_tr(self): + if hasattr(self.decoratee, "before_tr"): + self.decoratee.before_tr() if self.hyper_scheduler is None: logger = Logger.get_logger() logger.warning( @@ -54,12 +56,14 @@ def before_tr(self): total_steps=self.aconf.warmup, flag_update_epoch=True, ) - super().before_tr() - def tr_epoch(self, epoch): + def tr_epoch(self, epoch, flag_info=False): """ update hyper-parameters only per epoch """ if self.flag_update_hyper_per_epoch: - self.model.hyper_update(epoch, self.hyper_scheduler) + self.decoratee.hyper_update(epoch, self.hyper_scheduler) + if hasattr(self.decoratee, "dict_multiplier"): + logger = Logger.get_logger() + logger.info(f"---before epoch, current multiplier: {self.decoratee.dict_multiplier}") return super().tr_epoch(epoch) diff --git a/domainlab/algos/trainers/train_irm.py b/domainlab/algos/trainers/train_irm.py index 90b16396e..09748ee1f 100644 --- a/domainlab/algos/trainers/train_irm.py +++ b/domainlab/algos/trainers/train_irm.py @@ -19,36 +19,6 @@ class TrainerIRM(TrainerBasic): For more details, see section 3.2 and Appendix D of : Arjovsky et al., “Invariant Risk Minimization.” """ - def tr_epoch(self, epoch): - list_loaders = list(self.dict_loader_tr.values()) - loaders_zip = zip(*list_loaders) - self.model.train() - self.epo_loss_tr = 0 - - for ind_batch, tuple_data_domains_batch in enumerate(loaders_zip): - self.optimizer.zero_grad() - list_domain_loss_erm = [] - list_domain_reg = [] - for batch_domain_e in tuple_data_domains_batch: - tensor_x, tensor_y, tensor_d, *others = batch_domain_e - tensor_x, tensor_y, tensor_d = \ - tensor_x.to(self.device), 
tensor_y.to(self.device), \ - tensor_d.to(self.device) - list_domain_loss_erm.append( - self.model.cal_task_loss(tensor_x, tensor_y)) - list_1ele_loss_irm, _ = \ - self._cal_reg_loss(tensor_x, tensor_y, tensor_d, others) - list_domain_reg += list_1ele_loss_irm - loss = torch.sum(torch.stack(list_domain_loss_erm)) + \ - self.aconf.gamma_reg * torch.sum(torch.stack(list_domain_reg)) - loss.backward() - self.optimizer.step() - self.epo_loss_tr += loss.detach().item() - self.after_batch(epoch, ind_batch) - - flag_stop = self.observer.update(epoch) # notify observer - return flag_stop - def _cal_phi(self, tensor_x): logits = self.model.cal_logit_y(tensor_x) return logits @@ -60,12 +30,15 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): _ = tensor_d _ = others y = tensor_y - phi = self._cal_phi(tensor_x) - dummy_w_scale = torch.tensor(1.).to(tensor_x.device).requires_grad_() - loss_1 = F.cross_entropy(phi[::2] * dummy_w_scale, y[::2]) - loss_2 = F.cross_entropy(phi[1::2] * dummy_w_scale, y[1::2]) - grad_1 = autograd.grad(loss_1, [dummy_w_scale], create_graph=True)[0] - grad_2 = autograd.grad(loss_2, [dummy_w_scale], create_graph=True)[0] - loss_irm_scalar = torch.sum(grad_1 * grad_2) # scalar - loss_irm_tensor = loss_irm_scalar.expand(tensor_x.shape[0]) - return [loss_irm_tensor], [self.aconf.gamma_reg] + with torch.enable_grad(): + phi = self._cal_phi(tensor_x) + dummy_w_scale = torch.tensor(1.).to(tensor_x.device).requires_grad_() + # interleave instances inside a minibatch + loss_1 = F.cross_entropy(phi[::2] * dummy_w_scale, y[::2]) + loss_2 = F.cross_entropy(phi[1::2] * dummy_w_scale, y[1::2]) + grad_1 = autograd.grad(loss_1, [dummy_w_scale], create_graph=True)[0] + grad_2 = autograd.grad(loss_2, [dummy_w_scale], create_graph=True)[0] + loss_irm_scalar = torch.sum(grad_1 * grad_2) # scalar + loss_irm_scalar = torch.square(loss_irm_scalar) + loss_irm_tensor = loss_irm_scalar.expand(tensor_x.shape[0]) + return [loss_irm_tensor], [self.aconf.gamma_reg] diff --git a/domainlab/algos/trainers/train_irm_sep_dom.py b/domainlab/algos/trainers/train_irm_sep_dom.py new file mode 100644 index 000000000..94d3bca79 --- /dev/null +++ b/domainlab/algos/trainers/train_irm_sep_dom.py @@ -0,0 +1,39 @@ +""" +IRM trainer variant that computes the task loss and the IRM penalty separately for each training-domain loader, then sums them over domains +""" +import torch +from torch import autograd +from torch.nn import functional as F +from domainlab.algos.trainers.train_irm import TrainerIRM + + +class TrainerIRMSepDom(TrainerIRM): + def tr_epoch(self, epoch, flag_info=False): + list_loaders = list(self.dict_loader_tr.values()) + loaders_zip = zip(*list_loaders) + self.model.train() + self.epo_loss_tr = 0 + + for ind_batch, tuple_data_domains_batch in enumerate(loaders_zip): + self.optimizer.zero_grad() + list_domain_loss_erm = [] + list_domain_reg = [] + for batch_domain_e in tuple_data_domains_batch: + tensor_x, tensor_y, tensor_d, *others = batch_domain_e + tensor_x, tensor_y, tensor_d = \ + tensor_x.to(self.device), tensor_y.to(self.device), \ + tensor_d.to(self.device) + list_domain_loss_erm.append( + self.model.cal_task_loss(tensor_x, tensor_y)) + list_1ele_loss_irm, _ = \ + self.cal_reg_loss(tensor_x, tensor_y, tensor_d, others) + list_domain_reg += list_1ele_loss_irm + loss = torch.sum(torch.stack(list_domain_loss_erm)) + \ + self.aconf.gamma_reg * torch.sum(torch.stack(list_domain_reg)) + loss.backward() + self.optimizer.step() + self.epo_loss_tr += loss.detach().item() + self.after_batch(epoch, ind_batch) + + flag_stop = self.observer.update(epoch, flag_info) # notify
observer + return flag_stop diff --git a/domainlab/algos/trainers/train_matchdg.py b/domainlab/algos/trainers/train_matchdg.py index 72c14ab83..ef4b1c862 100644 --- a/domainlab/algos/trainers/train_matchdg.py +++ b/domainlab/algos/trainers/train_matchdg.py @@ -43,7 +43,7 @@ def init_business( self.tuple_tensor_ref_domain2each_y = None self.tuple_tensor_refdomain2each = None - def tr_epoch(self, epoch): + def tr_epoch(self, epoch, flag_info=False): """ # data in one batch comes from two sources: one part from loader, # the other part from match tensor @@ -95,7 +95,7 @@ def tr_epoch(self, epoch): logger.info("\n\nPhase erm+ctr \n\n") self.flag_erm = True - flag_stop = self.observer.update(epoch) # notify observer + flag_stop = self.observer.update(epoch, flag_info) # notify observer return flag_stop def tr_batch(self, epoch, batch_idx, x_e, y_e, d_e, others=None): diff --git a/domainlab/algos/trainers/train_mldg.py b/domainlab/algos/trainers/train_mldg.py index 2c7376636..050227999 100644 --- a/domainlab/algos/trainers/train_mldg.py +++ b/domainlab/algos/trainers/train_mldg.py @@ -52,7 +52,7 @@ def prepare_ziped_loader(self): ddset_mix = DsetZip(ddset_source, ddset_target) self.loader_tr_source_target = mk_loader(ddset_mix, self.aconf.bs) - def tr_epoch(self, epoch): + def tr_epoch(self, epoch, flag_info=False): self.model.train() self.epo_loss_tr = 0 self.prepare_ziped_loader() @@ -118,5 +118,5 @@ def tr_epoch(self, epoch): self.optimizer.step() self.epo_loss_tr += loss.detach().item() self.after_batch(epoch, ind_batch) - flag_stop = self.observer.update(epoch) # notify observer + flag_stop = self.observer.update(epoch, flag_info) # notify observer return flag_stop diff --git a/domainlab/algos/trainers/zoo_trainer.py b/domainlab/algos/trainers/zoo_trainer.py index c11049375..802778ccf 100644 --- a/domainlab/algos/trainers/zoo_trainer.py +++ b/domainlab/algos/trainers/zoo_trainer.py @@ -4,12 +4,14 @@ from domainlab.algos.trainers.train_basic import TrainerBasic from domainlab.algos.trainers.train_ema import TrainerMA from domainlab.algos.trainers.train_dial import TrainerDIAL +from domainlab.algos.trainers.train_fbopt_b import TrainerFbOpt from domainlab.algos.trainers.train_hyper_scheduler \ import TrainerHyperScheduler from domainlab.algos.trainers.train_matchdg import TrainerMatchDG from domainlab.algos.trainers.train_mldg import TrainerMLDG from domainlab.algos.trainers.train_fishr import TrainerFishr from domainlab.algos.trainers.train_irm import TrainerIRM +from domainlab.algos.trainers.train_irm_sep_dom import TrainerIRMSepDom from domainlab.algos.trainers.train_causIRL import TrainerCausalIRL from domainlab.algos.trainers.train_coral import TrainerCoral from domainlab.algos.trainers.train_miro import TrainerMiro @@ -57,7 +59,9 @@ def __call__(self, lst_candidates=None, default=None, lst_excludes=None): chain = TrainerMLDG(chain) chain = TrainerFishr(chain) chain = TrainerIRM(chain) + chain = TrainerIRMSepDom(chain) chain = TrainerHyperScheduler(chain) + chain = TrainerFbOpt(chain) chain = TrainerCausalIRL(chain) chain = TrainerCoral(chain) chain = TrainerMiro(chain) diff --git a/domainlab/algos/zoo_algos.py b/domainlab/algos/zoo_algos.py index e067c1326..4b8240387 100644 --- a/domainlab/algos/zoo_algos.py +++ b/domainlab/algos/zoo_algos.py @@ -5,6 +5,7 @@ from domainlab.algos.builder_dann import NodeAlgoBuilderDANN from domainlab.algos.builder_diva import NodeAlgoBuilderDIVA from domainlab.algos.builder_erm import NodeAlgoBuilderERM +from domainlab.algos.builder_fbopt_dial import 
NodeAlgoBuilderFbOptDial from domainlab.algos.builder_hduva import NodeAlgoBuilderHDUVA from domainlab.algos.builder_jigen1 import NodeAlgoBuilderJiGen from domainlab.utils.u_import import import_path @@ -45,6 +46,7 @@ def __call__(self): chain = NodeAlgoBuilderJiGen(chain) chain = NodeAlgoBuilderHDUVA(chain) chain = NodeAlgoBuilderAPIModel(chain) + chain = NodeAlgoBuilderFbOptDial(chain) chain = self.register_external_node(chain) node = chain.handle(self.model) head = node diff --git a/domainlab/arg_parser.py b/domainlab/arg_parser.py index bb7bda2b4..204469c43 100644 --- a/domainlab/arg_parser.py +++ b/domainlab/arg_parser.py @@ -7,6 +7,7 @@ import yaml from domainlab.algos.trainers.args_dial import add_args2parser_dial +from domainlab.algos.trainers.args_fbopt import add_args2parser_fbopt from domainlab.algos.trainers.compos.matchdg_args import add_args2parser_matchdg from domainlab.algos.trainers.args_miro import add_args2parser_miro from domainlab.models.args_jigen import add_args2parser_jigen @@ -256,7 +257,6 @@ def mk_parser_main(): "using --gen_plots. " "Default is zoutput/benchmarks/shell_benchmark", ) - parser.add_argument( "--opt", type=str, @@ -284,13 +284,20 @@ def mk_parser_main(): parser.add_argument( "--msel", - choices=["val", "loss_tr"], + choices=["val", "loss_tr", "last"], default="val", help="model selection for early stop: val, loss_tr, recon, the \ elbo and recon only make sense for vae models,\ will be ignored by other methods", ) + parser.add_argument( + "--msel_tr_loss", + choices=["reg", "task"], + default="task", + help="model selection for tr loss", + ) + parser.add_argument( "--model", metavar="an", type=str, default=None, help="algorithm name" ) @@ -377,6 +384,8 @@ def mk_parser_main(): arg_group_jigen = add_args2parser_jigen(arg_group_jigen) args_group_dial = parser.add_argument_group("dial") args_group_dial = add_args2parser_dial(args_group_dial) + args_group_fbopt = parser.add_argument_group("fbopt") + args_group_fbopt = add_args2parser_fbopt(args_group_fbopt) return parser diff --git a/domainlab/exp_protocol/run_experiment.py b/domainlab/exp_protocol/run_experiment.py index 81cdd091f..f45bc9498 100644 --- a/domainlab/exp_protocol/run_experiment.py +++ b/domainlab/exp_protocol/run_experiment.py @@ -143,6 +143,11 @@ def run_experiment( gpu_ind = param_index % num_gpus args.device = str(gpu_ind) + logger.info("*** begin args") + for k, v in vars(args).items(): + logger.info(f"{k} : {v}") + logger.info("*** end args") + if torch.cuda.is_available(): torch.cuda.init() logger.info("before experiment loop: ") diff --git a/domainlab/models/a_model.py b/domainlab/models/a_model.py index 113b8c456..299292017 100644 --- a/domainlab/models/a_model.py +++ b/domainlab/models/a_model.py @@ -14,6 +14,15 @@ class AModel(nn.Module, metaclass=abc.ABCMeta): operations that all models (classification, segmentation, seq2seq) """ + def set_params(self, dict_params): + """ + set + """ + # FIXME: net1.load_state_dict(net2.state_dict()) contains more information than model.named_parameters() like optimizer status + # but I dont know another method to set neural network weights without using load_state_dict + # FIXME: dict_params lack some keys compared to self.state_dict(), why? 
+ self.load_state_dict(dict_params, strict=False) + def __init__(self): super().__init__() self._decoratee = None diff --git a/domainlab/models/a_model_classif.py b/domainlab/models/a_model_classif.py index 1917f752e..ea9ef5f47 100644 --- a/domainlab/models/a_model_classif.py +++ b/domainlab/models/a_model_classif.py @@ -25,7 +25,6 @@ loss_cross_entropy_extended = extend(nn.CrossEntropyLoss(reduction="none")) - class AModelClassif(AModel, metaclass=abc.ABCMeta): """ operations that all classification model should have @@ -244,3 +243,4 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): device = tensor_x.device bsize = tensor_x.shape[0] return [torch.zeros(bsize).to(device)], [0.0] + # return [], [] diff --git a/domainlab/models/model_dann.py b/domainlab/models/model_dann.py index 59a619e7e..b940c5d71 100644 --- a/domainlab/models/model_dann.py +++ b/domainlab/models/model_dann.py @@ -77,6 +77,13 @@ def reset_aux_net(self): self.extract_semantic_feat ) + @property + def list_str_multiplier_na(self): + """ + name of the multiplier + """ + return [self.name + "_alpha"] + def hyper_update(self, epoch, fun_scheduler): """hyper_update. :param epoch: @@ -87,13 +94,13 @@ def hyper_update(self, epoch, fun_scheduler): ) # the __call__ method of hyperparameter scheduler self.alpha = dict_rst[self.name + "_alpha"] - def hyper_init(self, functor_scheduler): + def hyper_init(self, functor_scheduler, trainer=None): """hyper_init. - :param functor_scheduler: + :param functor_scheduler: name of the scheduler class """ parameters = {} parameters[self.name + "_alpha"] = self.alpha - return functor_scheduler(trainer=None, **parameters) + return functor_scheduler(trainer=trainer, **parameters) def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others): _ = others diff --git a/domainlab/models/model_diva.py b/domainlab/models/model_diva.py index 362c14518..859fb6fa5 100644 --- a/domainlab/models/model_diva.py +++ b/domainlab/models/model_diva.py @@ -9,7 +9,8 @@ from domainlab.utils.utils_class import store_args -def mk_diva(parent_class=VAEXYDClassif, **kwargs): +def mk_diva( + parent_class=VAEXYDClassif, str_diva_multiplier_type="default", **kwargs): # FIXME: should not be default """ Instantiate a domain invariant variational autoencoder (DIVA) with arbitrary task loss. 
@@ -69,8 +70,9 @@ def __init__( beta_d, beta_x, beta_y, - multiplier_recon=1.0, + mu_recon=1.0, ): + # pylint: disable=too-many-arguments, unused-argument """ gamma: classification loss coefficient """ @@ -102,7 +104,7 @@ def hyper_update(self, epoch, fun_scheduler): self.beta_y = dict_rst[self.name + "_beta_x"] self.beta_x = dict_rst[self.name + "_beta_y"] - def hyper_init(self, functor_scheduler): + def hyper_init(self, functor_scheduler, trainer=None): """ initiate a scheduler object via class name and things inside this model @@ -113,9 +115,29 @@ def hyper_init(self, functor_scheduler): parameters[self.name + "_beta_y"] = self.beta_y parameters[self.name + "_beta_x"] = self.beta_x return functor_scheduler( - trainer=None, **parameters + trainer=trainer, **parameters ) + @property + def list_str_multiplier_na(self): + """ + list of multipliers name + """ + return [f"{self.name}_mu_recon", f"{self.name}_beta_d", f"{self.name}_beta_x", f"{self.name}_beta_y", f"{self.name}_gamma_d"] + + @property + def dict_multiplier(self): + """ + list of multipliers name, which correspond to cal_reg_loss + """ + return { + f"{self.name}_mu_recon": self.mu_recon, + f"{self.name}_beta_d": self.beta_d, + f"{self.name}_beta_x": self.beta_x, + f"{self.name}_beta_y": self.beta_y, + f"{self.name}_gamma_d": self.gamma_d, + } + def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): q_zd, zd_q, q_zx, zx_q, q_zy, zy_q = self.encoder(tensor_x) logit_d = self.net_classif_d(zd_q) @@ -147,20 +169,111 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): ) _, d_target = tensor_d.max(dim=1) - lc_d = F.cross_entropy(logit_d, d_target, reduction=g_str_cross_entropy_agg) + lc_d = F.cross_entropy(logit_d, d_target, reduction=g_str_cross_entropy_agg) return [ loss_recon_x, zd_p_minus_zd_q, zx_p_minus_zx_q, zy_p_minus_zy_q, lc_d, - ], [ - self.multiplier_recon, - -self.beta_d, - -self.beta_x, - -self.beta_y, - self.gamma_d, - ] - - return ModelDIVA + ], [self.mu_recon, -self.beta_d, -self.beta_x, -self.beta_y, self.gamma_d] + + class ModelDIVAGammadRecon(ModelDIVA): + def hyper_update(self, epoch, fun_scheduler): + """hyper_update. 
+ + :param epoch: + :param fun_scheduler: + """ + dict_rst = fun_scheduler(epoch) + self.beta_d = dict_rst[self.name + "_beta_d"] + self.beta_y = dict_rst[self.name + "_beta_x"] + self.beta_x = dict_rst[self.name + "_beta_y"] + self.gamma_d = dict_rst[self.name + "_gamma_d"] + self.mu_recon = dict_rst[self.name + "_mu_recon"] + + def hyper_init(self, functor_scheduler, trainer=None): + """ + initiate a scheduler object via class name and things inside this model + + :param functor_scheduler: the class name of the scheduler + """ + parameters = {} + parameters[self.name + "_beta_d"] = self.beta_d + parameters[self.name + "_beta_y"] = self.beta_y + parameters[self.name + "_beta_x"] = self.beta_x + parameters[self.name + "_gamma_d"] = self.gamma_d + parameters[self.name + "_mu_recon"] = self.mu_recon + return functor_scheduler( + trainer=trainer, **parameters + ) + + class ModelDIVAGammadReconPerPixel(ModelDIVAGammadRecon): + def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): + [loss_recon_x, zd_p_minus_zd_q, zx_p_minus_zx_q, zy_p_minus_zy_q, lc_d], [ + mu_recon, + minus_beta_d, + minus_beta_x, + minus_beta_y, + gamma_d, + ] = super().cal_reg_loss(tensor_x, tensor_y, tensor_d, others) + + return [ + torch.div(loss_recon_x, tensor_x.shape[2] * tensor_x.shape[3]), + zd_p_minus_zd_q, + zx_p_minus_zx_q, + zy_p_minus_zy_q, + lc_d, + ], [mu_recon, minus_beta_d, minus_beta_x, minus_beta_y, gamma_d] + + class ModelDIVAGammad(ModelDIVA): + """ + only adjust gammad and beta + """ + + def hyper_update(self, epoch, fun_scheduler): + """hyper_update. + + :param epoch: + :param fun_scheduler: + """ + dict_rst = fun_scheduler(epoch) + self.beta_d = dict_rst[self.name + "_beta_d"] + self.beta_y = dict_rst[self.name + "_beta_x"] + self.beta_x = dict_rst[self.name + "_beta_y"] + self.gamma_d = dict_rst[self.name + "_gamma_d"] + + def hyper_init(self, functor_scheduler, trainer=None): + """ + initiate a scheduler object via class name and things inside this model + + :param functor_scheduler: the class name of the scheduler + """ + parameters = {} + parameters[self.name + "_beta_d"] = self.beta_d + parameters[self.name + "_beta_y"] = self.beta_y + parameters[self.name + "_beta_x"] = self.beta_x + parameters[self.name + "_gamma_d"] = self.gamma_d + parameters[self.name + "_mu_recon"] = self.mu_recon + return functor_scheduler( + trainer=trainer, **parameters + ) + + class ModelDIVADefault(ModelDIVA): + """ + mock + """ + + if str_diva_multiplier_type == "gammad_recon": + return ModelDIVAGammadRecon + if str_diva_multiplier_type == "gammad_recon_per_pixel": + return ModelDIVAGammadReconPerPixel + if str_diva_multiplier_type == "gammad": + return ModelDIVAGammad + if str_diva_multiplier_type == "default": + return ModelDIVADefault + raise RuntimeError( + "not support argument candiates for str_diva_multiplier_type: \ + allowed: default, gammad_recon, gammad_recon_per_pixel, gammad" + ) diff --git a/domainlab/models/model_hduva.py b/domainlab/models/model_hduva.py index 5c7bb290d..f7ab1d48c 100644 --- a/domainlab/models/model_hduva.py +++ b/domainlab/models/model_hduva.py @@ -70,8 +70,9 @@ def hyper_update(self, epoch, fun_scheduler): self.beta_y = dict_rst[self.name + "_beta_y"] self.beta_x = dict_rst[self.name + "_beta_x"] self.beta_t = dict_rst[self.name + "_beta_t"] + self.mu_recon = dict_rst[self.name + "_mu_recon"] - def hyper_init(self, functor_scheduler): + def hyper_init(self, functor_scheduler, trainer=None): """hyper_init. 
:param functor_scheduler: """ @@ -83,8 +84,9 @@ def hyper_init(self, functor_scheduler): parameters[self.name + "_beta_y"] = self.beta_y parameters[self.name + "_beta_x"] = self.beta_x parameters[self.name + "_beta_t"] = self.beta_t + parameters[self.name + "_mu_recon"] = self.mu_recon return functor_scheduler( - trainer=None, **parameters + trainer=trainer, **parameters ) @store_args @@ -102,10 +104,11 @@ def __init__( device, zx_dim=0, topic_dim=3, - multiplier_recon=1.0): + mu_recon=1.0, + ): + # pylint: disable=too-many-arguments, unused-argument """ """ super().__init__(chain_node_builder, zd_dim, zy_dim, zx_dim, **kwargs) - # topic to zd follows Gaussian distribution self.add_module( "net_p_zd", @@ -198,13 +201,27 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d=None, others=None): zy_p_minus_zy_q, zd_p_minus_q, topic_p_minus_q, - ], [ - self.multiplier_recon, - -self.beta_x, - -self.beta_y, - -self.beta_d, - -self.beta_t, - ] + ], [self.mu_recon, -self.beta_x, -self.beta_y, -self.beta_d, -self.beta_t] + + @property + def list_str_multiplier_na(self): + """ + list of multipliers name which matches the order from cal_reg_loss + """ + return [f"{self.name}_mu_recon", f"{self.name}_beta_d", f"{self.name}_beta_x", f"{self.name}_beta_y", f"{self.name}_beta_t"] + + @property + def dict_multiplier(self): + """ + dictionary of multipliers name + """ + return { + f"{self.name}_mu_recon": self.mu_recon, + f"{self.name}_beta_d": self.beta_d, + f"{self.name}_beta_x": self.beta_x, + f"{self.name}_beta_y": self.beta_y, + f"{self.name}_beta_t": self.beta_t, + } def extract_semantic_feat(self, tensor_x): """ diff --git a/domainlab/models/model_jigen.py b/domainlab/models/model_jigen.py index 8d08d9b1b..2d382293f 100644 --- a/domainlab/models/model_jigen.py +++ b/domainlab/models/model_jigen.py @@ -76,6 +76,20 @@ def __init__( self.n_perm = n_perm self.prob_perm = prob_permutation + @property + def list_str_multiplier_na(self): + """ + list of multipliers which match the order in cal_reg_loss + """ + return [self.name + "_alpha"] + + @property + def dict_multiplier(self): + """ + dictionary of multipliers + """ + return {f"{self.name}_alpha": self.alpha} + def dset_decoration_args_algo(self, args, ddset): """ JiGen need to shuffle the tiles of the original image diff --git a/domainlab/tasks/a_task.py b/domainlab/tasks/a_task.py index 6eab450d7..c4104e730 100644 --- a/domainlab/tasks/a_task.py +++ b/domainlab/tasks/a_task.py @@ -17,6 +17,7 @@ class NodeTaskDG(AbstractChainNodeHandler): def __init__(self, succ=None): super().__init__(succ) self._loader_tr = None + self._loader_tr_no_drop = None self._loader_te = None self._loader_val = None self._list_domains = None diff --git a/domainlab/tasks/b_task.py b/domainlab/tasks/b_task.py index 56ac4cd79..b02609039 100644 --- a/domainlab/tasks/b_task.py +++ b/domainlab/tasks/b_task.py @@ -54,9 +54,10 @@ def init_business(self, args, trainer=None): self.dict_loader_tr.update({na_domain: mk_loader(ddset_tr, args.bs)}) self.dict_dset_val.update({na_domain: ddset_val}) ddset_mix = ConcatDataset(tuple(self.dict_dset_tr.values())) - flag_shuffling = not args.shuffling_off - # args.shuffling_off default is False -> not False -> True - self._loader_tr = mk_loader(ddset_mix, args.bs, shuffle=flag_shuffling) + self._loader_tr = mk_loader(ddset_mix, args.bs) + self._loader_tr_no_drop = mk_loader( + ddset_mix, args.bs, drop_last=False, shuffle=False + ) ddset_mix_val = ConcatDataset(tuple(self.dict_dset_val.values())) self._loader_val = mk_loader( diff --git 
a/domainlab/utils/generate_fbopt_phase_portrait.py b/domainlab/utils/generate_fbopt_phase_portrait.py new file mode 100644 index 000000000..ef4b28806 --- /dev/null +++ b/domainlab/utils/generate_fbopt_phase_portrait.py @@ -0,0 +1,402 @@ +""" +This file is used for generating phase portrait from tensorboard event files. +""" +import argparse +import glob +import os +import numpy as np +import re +from tensorboard.backend.event_processing.event_accumulator import EventAccumulator + +import matplotlib +import matplotlib.pyplot as plt +from matplotlib.backends.backend_pdf import PdfPages +matplotlib.rcParams['pdf.fonttype'] = 42 +matplotlib.rcParams['text.usetex'] = True +plt.rcParams['text.latex.preamble'] = r'\usepackage{amsmath}' +font = {'size': 20} +matplotlib.rc('font', **font) + + +def sav2pdfpage(fig, fname): + pdf_page = PdfPages(fname) + pdf_page.savefig(fig, bbox_inches="tight") + pdf_page.close() + +def latex_to_nonlatex(latex_string): + nonlatex_string = re.sub(r'[{$}]', '', latex_string) + nonlatex_string = nonlatex_string.replace("\\", "") + return nonlatex_string + +class ListFileHandler: + def __init__(self, file_path): + self.file_path = file_path + + def write_lists_to_file(self, list1, list2=None): + with open(self.file_path, 'w') as file: + if list2 is None: + for val1 in list1: + file.write(f"{val1}\n") + else: + for val1, val2 in zip(list1, list2): + file.write(f"{val1} {val2}\n") + + def read_lists_from_file(self): + list1 = [] + list2 = [] + with open(self.file_path, 'r') as file: + for line in file: + values = list(map(float, line.strip().split())) + if len(values) == 1: + list1.append(values[0]) + elif len(values) == 2: + list1.append(values[0]) + list2.append(values[1]) + return list1, list2 + + +# pylint: disable=too-many-arguments +def get_xy_from_event_file( + event_file, + plot1, + plot2=None, + tf_size_guidance=None, + sanity_check=False, + verbose=True, +): + """ + extract x and y values from a tensorboard event file + """ + if tf_size_guidance is None: + # settings for which/how much data is loaded from the + # tensorboard event files + tf_size_guidance = { + "compressedHistograms": 0, + "images": 0, + "scalars": 1e10, # keep unlimited number + "histograms": 0, + } + # load event file + event = EventAccumulator(event_file, tf_size_guidance) + event.Reload() + # print names of available plots + if verbose: + print(f"Event file {event_file} -- available plots:") + print(event.Tags()["scalars"]) + if plot2: + # extract the plot2 values (e.g., reg/dyn0) + y_event = event.Scalars(plot2) + y = [s.value for s in y_event] + x_int = [s.step for s in y_event] + # the .step data are saved as ints in tensorboard, + # (so, in case of phase portrait, we re-extact from 'task') + else: + y = None + # extract the corresponding plot1 values (e.g., 'task') + x_event = event.Scalars(plot1) + x = [s.value for s in x_event] + # sanity check (originally added for the reg/dyn0 vs. 
task phase portrait; + # shouldn't be needed if plot1 and plot2 represent something else): + if sanity_check: + for i in range(len(x)): + assert int(x[i]) == x_int[i] + + return x, y + + +# pylint: disable=too-many-arguments, too-many-locals, redefined-outer-name, unused-argument +def phase_portrait_combined( + event_files, + colors, + plot1, + plot2, + legend1=None, + legend2=None, + plot_len=None, + skip_n_steps=1, + output_dir=".", +): + """ + combined phase portrait for multiple (at least one) Tensorboard + event files in the same plot + """ + fig = plt.figure() + + for event_i in range(len(event_files)): + x, y = get_xy_from_event_file(event_files[event_i], plot1=plot1, plot2=plot2) + + assert len(x) == len(y) + if plot_len is None: + plot_len = len(x) + # truncate x and y to the desired length: + x = x[:plot_len] + y = y[:plot_len] + # skip every n steps + x = x[0::skip_n_steps] + y = y[0::skip_n_steps] + + x = [-ele if ele <0 else ele for ele in x ] + y = [-ele if ele <0 else ele for ele in y] + + head_w_glob = min((max(x) - min(x)) / plot_len, (max(y) - min(y)) / plot_len) + head_w_glob *= skip_n_steps + for i in range(len(x) - 1): + xy_dist = np.sqrt((x[i + 1] - x[i]) ** 2 + (y[i + 1] - y[i]) ** 2) + head_l = xy_dist / plot_len * skip_n_steps + # let width be one tenth of length + head_w = min(head_l/10.0, head_w_glob) + plt.arrow( + x[i], + y[i], + (x[i + 1] - x[i]), + (y[i + 1] - y[i]), + head_width=head_w, + head_length=head_l, + length_includes_head=True, + fc=colors[event_i], + ec=colors[event_i], + alpha=0.8, + ) + # the combination of head_width and head_length makes the arrow + # more visible. + # length_includes_head=False makes the arrow stick too far out + # beyond the point, so True is used. + + # use finite color + # colors = ["red", "green", "blue", "yellow", "purple"] + # list_color = [colors[i % len(colors)] for i, h in enumerate(x)] + # use numerical color + colors = np.arange(0, plot_len, skip_n_steps) + plt.plot(x[0], y[0], "ko") + # plt.scatter(x, y, s=1, c=np.array(list_color)) + # size + plt.scatter(x, y, s=10, c=colors, cmap='viridis') + plt.yscale("log") + plt.xscale("log") + plt.colorbar() + + if legend1 is None: + legend1 = plot1 + if legend2 is None: + legend2 = plot2 + plt.xlabel(legend1) + plt.ylabel(legend2) + plt.title("output portrait") + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + legend22 = legend2.split(os.sep)[-1] + + fname_legend = latex_to_nonlatex(legend22) + + # write x and y data to a text file: + txt_name = os.path.join(output_dir, f"phase_portrait_{fname_legend}.txt") + fh = ListFileHandler(txt_name) + fh.write_lists_to_file(x, y) + + # save figures + fname = os.path.join(output_dir, f"phase_portrait_{fname_legend}") + plt.savefig(fname+".png", dpi=300, bbox_inches="tight") + plt.savefig(fname+".pdf", format="pdf", bbox_inches="tight") + plt.savefig(fname+".svg", format="svg", bbox_inches="tight") + sav2pdfpage(fig, fname+"_pdfpage.pdf") + + +def two_curves_combined( + event_files, + colors, + plot1, + plot2, + legend1=None, + legend2=None, + output_dir=".", + title=None, + logscale=False, + neg=False, + prefix="output_r_", + plot_len=None): + """ + FIXME: colors parameter is not used + """ + fig = plt.figure() + for event_i in range(len(event_files)): + x, y = get_xy_from_event_file(event_files[event_i], plot1=plot1, plot2=plot2) + if plot_len is None: + plot_len = len(x) + # truncate x and y to the desired length: + x = x[:plot_len] + y = y[:plot_len] + + if neg: + plt.plot(-np.array(x), color="blue") +
plt.plot(-np.array(y), color="red") + else: + plt.plot(x, color="blue") + plt.plot(y, color="red") + if logscale: + plt.yscale("log") + plt.xlabel("Epoch") + # plt.ylabel("loss") + if title is not None: + plt.title(title) + if legend1 is None: + legend1 = plot1 + if legend2 is None: + legend2 = plot2 + plt.legend([legend1, legend2]) + + legend11 = legend1.replace(os.sep, "_") + legend22 = legend2.replace(os.sep, "_") + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + fname_legend = latex_to_nonlatex(legend11) + fname_legend += latex_to_nonlatex(legend22) + # write x and y data to a text file: + txt_name = os.path.join(output_dir, prefix+f"{fname_legend}.txt") + fh = ListFileHandler(txt_name) + fh.write_lists_to_file(x, y) + + # save figures + fname_logscale = "_logscale" if logscale else "" + fname = os.path.join(output_dir, prefix+f"{fname_legend}") + plt.savefig(fname+fname_logscale+".png", dpi=300, bbox_inches="tight") + plt.savefig(fname+fname_logscale+".pdf", format="pdf", bbox_inches="tight") + plt.savefig(fname+fname_logscale+".svg", format="svg", bbox_inches="tight") + pdf_page = PdfPages(fname+fname_logscale+"_pdfpage.pdf") + pdf_page.savefig(fig, bbox_inches="tight") + pdf_page.close() + + + + +def plot_single_curve(event_files, colors, plot1, legend1=None, output_dir=".", plot_len=None): + """ + FIXME: colors parameter is not used + """ + fig = plt.figure() + for event_i in range(len(event_files)): + x, _ = get_xy_from_event_file(event_files[event_i], plot1=plot1) + if plot_len is None: + plot_len = len(x) + # truncate x and y to the desired length: + x = x[:plot_len] + plt.plot(x) + plt.yscale("log") + plt.xlabel("Epoch") + if legend1 is None: + legend1 = plot1 + plt.ylabel(legend1) + # plt.title("timecourse") + + legend11 = legend1.replace(os.sep, "_") + fname_legend = latex_to_nonlatex(legend11) + + # save figures + if not os.path.exists(output_dir): + os.makedirs(output_dir) + plt.savefig(os.path.join(output_dir, f"single_timecourse_{fname_legend}.png"), dpi=300, bbox_inches="tight") + plt.savefig(os.path.join(output_dir, f"single_timecourse_{fname_legend}.pdf"), format="pdf", bbox_inches="tight") + plt.savefig(os.path.join(output_dir, f"single_timecourse_{fname_legend}.svg"), format="svg", bbox_inches="tight") + pdf_page = PdfPages(os.path.join(output_dir, f"single_timecourse_{fname_legend}_pdfpage.pdf")) + pdf_page.savefig(fig, bbox_inches="tight") + pdf_page.close() + + # write x and y data to a text file: + txt_name = os.path.join(output_dir, f"single_timecourse_{fname_legend}.txt") + fh = ListFileHandler(txt_name) + fh.write_lists_to_file(list(range(len(x))), x) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="plot") + parser.add_argument("-plot1", "--plot1", default=None, type=str) + parser.add_argument("-plot2", "--plot2", default=None, type=str) + parser.add_argument("-legend1", "--legend1", default=None, type=str) + parser.add_argument("-legend2", "--legend2", default=None, type=str) + parser.add_argument("-plot_len", "--plot_len", default=None, type=int) + parser.add_argument("-skip_n_steps", "--skip_n_steps", default=None, type=int) + parser.add_argument("-title", "--title", default=None, type=str) + parser.add_argument("--output_dir", default=".", type=str) + parser.add_argument("--runs_dir", default="runs", type=str) + parser.add_argument( + "--neg", + action="store_true", + help="if true, plot negative of a list", + ) + parser.add_argument( + "--phase_portrait", + action="store_true", + help="if True plots a 
phase portrait,\ + otherwise a curve (default)", + ) + args = parser.parse_args() + + # get event files from all available runs + # Tensorboard: * could be the date information, this intermediate directory + # always exist + # events* means all the event folders + # this would combine plots from all subfolders in the runs directory (i.e., all graphs combined in each plot): + #event_files = glob.glob(f"{args.runs_dir}/*/events*") + # this needs the user to specify a specific run (subfolder in the runs directory): + event_files = glob.glob(f"{args.runs_dir}/events*") + if not os.path.isdir(args.runs_dir): raise RuntimeError("runs_dir should be a directory.") + print( + "Using the following tensorboard event files:\n{}".format( + "\n".join(event_files) + ) + ) + + # Different colors for the different runs + cmap = plt.get_cmap("tab10") # Choose a colormap + colors = [cmap(i) for i in range(len(event_files))] + + if args.phase_portrait: + phase_portrait_combined( + event_files, + colors, + plot1=args.plot1, + plot2=args.plot2, + legend1=args.legend1, + legend2=args.legend2, + plot_len=args.plot_len, + skip_n_steps=args.skip_n_steps, + output_dir=args.output_dir, + ) + else: + if args.plot2: + # two curves per plot + two_curves_combined( + event_files, + colors, + plot1=args.plot1, + plot2=args.plot2, + legend1=args.legend1, + legend2=args.legend2, + output_dir=args.output_dir, + title=args.title, + neg=args.neg + ) + two_curves_combined( + event_files, + colors, + plot1=args.plot1, + plot2=args.plot2, + legend1=args.legend1, + legend2=args.legend2, + output_dir=args.output_dir, + title=args.title, + neg=args.neg, + logscale=True + ) + + else: + # one curve per plot + plot_single_curve( + event_files, + colors, + plot1=args.plot1, + legend1=args.legend1, + output_dir=args.output_dir, + ) diff --git a/examples/benchmark/aistat_dial_erm_mhof.yaml b/examples/benchmark/aistat_dial_erm_mhof.yaml new file mode 100644 index 000000000..ab530d72f --- /dev/null +++ b/examples/benchmark/aistat_dial_erm_mhof.yaml @@ -0,0 +1,84 @@ +mode: grid + +output_dir: zoutput/benchmarks/benchmark_mhof_dial + +sampling_seed: 0 +startseed: 0 +endseed: 2 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_aug.py + dmem: False + lr: 5e-5 + epos: 500 + epos_min: 200 + es: 5 + bs: 32 + san_check: False + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + + + + +Shared params: + nb4reg_over_task_ratio: + distribution: categorical # name of the distribution + datatype: int + values: # concrete values to choose from + - 0 + - 100 + + mu_init: + distribution: categorical + datatype: float + values: + - 0.000001 + - 0.001 + + mu_clip: + distribution: categorical + datatype: float + values: + - 1 + - 10 + - 100 + - 1000 + + k_i_gain_ratio: + min: 0.1 + max: 1 + num: 3 + distribution: uniform + + gamma_reg: + min: 0.01 + max: 1e4 + num: 4 + distribution: loguniform + + +# Test fbopt with different hyperparameter configurations + +fbopt_dial_erm: + model: erm + trainer: fbopt_dial + nb4reg_over_task_ratio: 0 + force_setpoint_change_once: True + shared: + - k_i_gain_ratio + - mu_init + - mu_clip + +dial_erm: + model: erm + nb4reg_over_task_ratio: 0 + trainer: dial + shared: + - gamma_reg diff --git a/examples/benchmark/aistat_irl_erm_mhof.yaml b/examples/benchmark/aistat_irl_erm_mhof.yaml new file mode 100644 index 000000000..9d13e0cf3 --- /dev/null +++ b/examples/benchmark/aistat_irl_erm_mhof.yaml @@ -0,0 +1,64 @@ +mode: grid 
+ +output_dir: zoutput/benchmarks/benchmark_mhof_causal_irl_erm_pacs + +sampling_seed: 0 +startseed: 0 +endseed: 2 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_aug.py + dmem: False + lr: 5e-5 + epos: 500 + epos_min: 200 + es: 5 + bs: 32 + san_check: False + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + + + + +Shared params: + ini_setpoint_ratio: + min: 0.5 + max: 0.99 + num: 2 + step: 0.05 + distribution: uniform + + k_i_gain_ratio: + min: 0.1 + max: 1 + num: 3 + distribution: uniform + + gamma_reg: + min: 0.01 + max: 1e4 + num: 4 + distribution: loguniform + + +# Test fbopt with different hyperparameter configurations + +fbopt_irm_erm: + model: erm + trainer: fbopt_causalirl + shared: + - ini_setpoint_ratio + - k_i_gain_ratio + +irm_erm: + model: erm + trainer: causalirl + shared: + - gamma_reg diff --git a/examples/benchmark/aistat_irm_erm_mhof.yaml b/examples/benchmark/aistat_irm_erm_mhof.yaml new file mode 100644 index 000000000..f4a508896 --- /dev/null +++ b/examples/benchmark/aistat_irm_erm_mhof.yaml @@ -0,0 +1,95 @@ +mode: grid + +output_dir: zoutput/benchmarks/benchmark_mhof_irm_erm_pacs + +sampling_seed: 0 +startseed: 0 +endseed: 2 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_aug.py + dmem: False + lr: 5e-5 + epos: 500 + epos_min: 10 + es: 5 + bs: 32 + san_check: False + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + + + + +Shared params: + nb4reg_over_task_ratio: + distribution: categorical # name of the distribution + datatype: int + values: # concrete values to choose from + - 0 + - 1 + - 100 + + mu_init: + distribution: categorical + datatype: float + values: + - 0.000001 + - 0.001 + + mu_clip: + distribution: categorical + datatype: float + values: + - 1 + - 10 + - 100 + - 1000 + + exp_shoulder_clip: + distribution: categorical + datatype: float + values: + - 1 + - 2 + - 5 + + k_i_gain_ratio: + min: 0.1 + max: 1 + num: 3 + distribution: uniform + + gamma_reg: + min: 0.01 + max: 1e4 + num: 4 + distribution: loguniform + # 1778 is the largest gamma_reg using this sampling + + +# Test fbopt with different hyperparameter configurations + +fbopt_irm_erm: + model: erm + trainer: fbopt_irm + ini_setpoint_ratio: 0.99 + force_setpoint_change_once: True + shared: + - k_i_gain_ratio + - nb4reg_over_task_ratio + - mu_clip + - mu_init + +irm_erm: + model: erm + trainer: irm + shared: + - gamma_reg + - nb4reg_over_task_ratio diff --git a/examples/benchmark/aistat_irm_erm_mhof2.yaml b/examples/benchmark/aistat_irm_erm_mhof2.yaml new file mode 100644 index 000000000..f6bc9467e --- /dev/null +++ b/examples/benchmark/aistat_irm_erm_mhof2.yaml @@ -0,0 +1,99 @@ +mode: grid + +output_dir: zoutput/benchmarks/mhof_irm_erm_pacs + +sampling_seed: 0 +startseed: 0 +endseed: 2 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_aug.py + dmem: False + lr: 5e-5 + epos: 100 + epos_min: 10 + es: 5 + bs: 32 + san_check: False + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + + + + +Shared params: + nb4reg_over_task_ratio: + distribution: categorical # name of the distribution + datatype: int + values: # concrete values to choose from + - 0 + - 100 + + mu_init: + distribution: categorical + datatype: float + values: + - 0.000001 + - 0.001 + + mu_clip: + 
distribution: categorical + datatype: float + values: + - 10 + - 100 + - 1000 + + exp_shoulder_clip: + distribution: categorical + datatype: float + values: + - 1 + - 5 + + k_i_gain_ratio: + min: 0.1 + max: 1 + num: 2 + distribution: uniform + + gamma_reg: + distribution: categorical + datatype: float + values: + - 0.01 + - 0.1 + - 1 + - 10 + - 100 + - 1000 + + # 1778 is the largest gamma_reg using this sampling + + +# Test fbopt with different hyperparameter configurations + +fbopt_irm_erm: + model: erm + trainer: fbopt_irm + ini_setpoint_ratio: 0.99 + force_setpoint_change_once: True + shared: + - k_i_gain_ratio + - nb4reg_over_task_ratio + - mu_init + - mu_clip + - exp_shoulder_clip + +irm_erm: + model: erm + trainer: irm + shared: + - gamma_reg + - nb4reg_over_task_ratio diff --git a/examples/benchmark/aistat_irm_erm_mhof_adadmw.yaml b/examples/benchmark/aistat_irm_erm_mhof_adadmw.yaml new file mode 100644 index 000000000..a0479baec --- /dev/null +++ b/examples/benchmark/aistat_irm_erm_mhof_adadmw.yaml @@ -0,0 +1,100 @@ +mode: grid + +output_dir: zoutput/benchmarks/mhof_irm_erm_pacs + +sampling_seed: 0 +startseed: 0 +endseed: 2 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_aug.py + dmem: False + lr: 5e-5 + epos: 100 + epos_min: 10 + es: 5 + bs: 32 + san_check: False + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + opt: AdamW + + + + +Shared params: + nb4reg_over_task_ratio: + distribution: categorical # name of the distribution + datatype: int + values: # concrete values to choose from + - 0 + - 100 + + mu_init: + distribution: categorical + datatype: float + values: + - 0.000001 + - 0.001 + + mu_clip: + distribution: categorical + datatype: float + values: + - 10 + - 100 + - 1000 + + exp_shoulder_clip: + distribution: categorical + datatype: float + values: + - 1 + - 5 + + k_i_gain_ratio: + min: 0.1 + max: 1 + num: 2 + distribution: uniform + + gamma_reg: + distribution: categorical + datatype: float + values: + - 0.01 + - 0.1 + - 1 + - 10 + - 100 + - 1000 + + # 1778 is the largest gamma_reg using this sampling + + +# Test fbopt with different hyperparameter configurations + +fbopt_irm_erm: + model: erm + trainer: fbopt_irm + ini_setpoint_ratio: 0.99 + force_setpoint_change_once: True + shared: + - k_i_gain_ratio + - nb4reg_over_task_ratio + - mu_init + - mu_clip + - exp_shoulder_clip + +irm_erm: + model: erm + trainer: irm + shared: + - gamma_reg + - nb4reg_over_task_ratio diff --git a/examples/benchmark/aistat_pacs_mhof_dial_diva.yaml b/examples/benchmark/aistat_pacs_mhof_dial_diva.yaml new file mode 100644 index 000000000..ba0753e53 --- /dev/null +++ b/examples/benchmark/aistat_pacs_mhof_dial_diva.yaml @@ -0,0 +1,85 @@ +mode: grid + +output_dir: zoutput/benchmarks/benchmark_fbopt_dial_diva_pacs + +sampling_seed: 0 +startseed: 0 +endseed: 2 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_aug.py + dmem: False + lr: 5e-5 + epos: 500 + epos_min: 20 + es: 5 + bs: 16 + san_check: False + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + + + + +Shared params: + k_i_gain_ratio: + min: 0.01 + max: 0.90 + num: 3 + step: 0.0001 + distribution: uniform + + dial_lr: + min: 1e-5 + max: 1e-3 + num: 2 + step: 1e-5 + distribution: uniform + + dial_epsilon: + min: 1e-5 + max: 1e-3 + num: 2 + step: 1e-5 + distribution: uniform + + gamma_y: + min: 1.0 + max: 1e6 + step: 100 
+ num: 3 + distribution: loguniform + + gamma_d: + min: 1.0 + max: 1e6 + step: 100 + num: 3 + distribution: loguniform + + +# Test fbopt with different hyperparameter configurations + +dial_fbopt: + model: diva + trainer: fbopt_dial + gamma_y: 1.0 + shared: + - k_i_gain_ratio + - dial_lr + - dial_epsilon + +dial_diva: + model: diva + trainer: dial + shared: + - dial_lr + - dial_epsilon + - gamma_y + - gamma_d diff --git a/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof.yaml b/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof.yaml new file mode 100644 index 000000000..8225eed2f --- /dev/null +++ b/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof.yaml @@ -0,0 +1,117 @@ +mode: grid + +output_dir: zoutput/benchmarks/mhof_trainer_combo_irm_dial_pacs + +sampling_seed: 0 +startseed: 0 +endseed: 2 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_aug.py + dmem: False + lr: 5e-5 + epos: 100 + epos_min: 10 + es: 5 + bs: 32 + san_check: False + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + lr_scheduler: CosineAnnealingLR + + +Shared params: + opt: + distribution: categorical + datatype: str + values: + - AdamW + - Adam + + nb4reg_over_task_ratio: + distribution: categorical # name of the distribution + datatype: int + values: # concrete values to choose from + - 0 + - 100 + + mu_init: + distribution: categorical + datatype: float + values: + - 0.000001 + - 0.001 + + mu_clip: + distribution: categorical + datatype: float + values: + - 10 + - 100 + - 1000 + + exp_shoulder_clip: + distribution: categorical + datatype: float + values: + - 1 + - 5 + + k_i_gain_ratio: + min: 0.1 + max: 1 + num: 2 + distribution: uniform + + gamma_reg: + distribution: categorical + datatype: float + values: + - 0.001 + - 0.01 + - 0.1 + - 1 + - 10 + - 100 + - 1000 + +fixed_multplier_irm_dial: + model: erm + trainer: irm_dial + shared: + - gamma_reg + - opt + +feedforward_irm_dial: + model: erm + trainer: hyperscheduler_irm_dial + shared: + - gamma_reg + - opt + +mhof_irm_dial_erm: + model: erm + trainer: fbopt_irm_dial + ini_setpoint_ratio: 0.9 + force_setpoint_change_once: True + coeff_ma_setpoint: 0.0 + tr_with_init_mu: True + nb4reg_over_task_ratio: 0 + str_setpoint_ada: "SliderAnyComponent()" + shared: + - opt + - k_i_gain_ratio + - mu_init + - mu_clip + - exp_shoulder_clip + + erm: + model: erm + shared: + - opt diff --git a/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof_only.yaml b/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof_only.yaml new file mode 100644 index 000000000..d8964c6d9 --- /dev/null +++ b/examples/benchmark/aistat_trainer_combo_dial_irm_erm_mhof_only.yaml @@ -0,0 +1,99 @@ +mode: grid + +output_dir: zoutput/benchmarks/only_mhof_combo_irm_dial_pacs + +sampling_seed: 0 +startseed: 0 +endseed: 2 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_aug.py + dmem: False + lr: 5e-5 + epos: 100 + epos_min: 10 + es: 5 + bs: 32 + san_check: False + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + + + + +Shared params: + nb4reg_over_task_ratio: + distribution: categorical # name of the distribution + datatype: int + values: # concrete values to choose from + - 0 + - 100 + + mu_init: + distribution: categorical + datatype: float + values: + - 0.000001 + - 0.001 + + mu_clip: + distribution: categorical + datatype: float + values: + - 10 + - 100 + - 
1000 + + exp_shoulder_clip: + distribution: categorical + datatype: float + values: + - 1 + - 5 + + k_i_gain_ratio: + min: 0.1 + max: 1 + num: 2 + distribution: uniform + + gamma_reg: + distribution: categorical + datatype: float + values: + - 0.01 + - 0.1 + - 1 + - 10 + - 100 + - 1000 + + str_setpoint_ada: + distribution: categorical + datatype: str + values: + - "SliderAnyComponent()" + - "SliderAllComponent()" + - "DominateAnyComponent()" + + +fbopt_irm_dial_erm: + model: erm + trainer: fbopt_irm_dial + ini_setpoint_ratio: 0.99 + force_setpoint_change_once: True + coeff_ma_setpoint: 0.0 + tr_with_init_mu: True + nb4reg_over_task_ratio: 0 + shared: + - k_i_gain_ratio + - mu_init + - mu_clip + - exp_shoulder_clip + - str_setpoint_ada diff --git a/examples/benchmark/benchmark_pacs_resnet_grid_jigen.yaml b/examples/benchmark/benchmark_pacs_resnet_grid_jigen.yaml new file mode 100644 index 000000000..28c6705e3 --- /dev/null +++ b/examples/benchmark/benchmark_pacs_resnet_grid_jigen.yaml @@ -0,0 +1,52 @@ +# test benchmark config. + +mode: grid + +test_domains: + - sketch + +output_dir: zoutput/benchmarks/pacs_benchmark_grid + +startseed: 0 +endseed: 1 # currently included + + +domainlab_args: + tpath: examples/tasks/task_pacs_path_list.py + dmem: False + lr: 5e-5 + epos: 500 + es: 1 + bs: 32 + npath: examples/nets/resnet50domainbed.py + npath_encoder_x2topic_h: examples/nets/resnet50domainbed.py + npath_encoder_sandwich_x2h4zd: examples/nets/resnet50domainbed.py + san_check: False + + +Shared params: + gamma_reg: + min: 0.01 + max: 10 + step: 0.1 + distribution: loguniform + num: 3 + + +jigen: # name + model: jigen + shared: + - gamma_reg + + hyperparameters: + # probability of permutating the tiles of an image, pperm = 0 -> pure classification + pperm: + min: 0.7 + max: 1 + step: 0.1 + distribution: uniform + num: 3 + + +erm: + model: erm diff --git a/examples/benchmark/mnist_dann_fbopt.yaml b/examples/benchmark/mnist_dann_fbopt.yaml new file mode 100644 index 000000000..8bdbe444c --- /dev/null +++ b/examples/benchmark/mnist_dann_fbopt.yaml @@ -0,0 +1,62 @@ +mode: grid + +output_dir: zoutput/benchmarks/benchmark_fbopt + +sampling_seed: 0 +startseed: 0 +endseed: 2 + +test_domains: + - 0 + + +domainlab_args: + task: mnistcolor10 + tr_d: [1, 2] + dmem: False + lr: 0.0001 + epos: 500 + es: 100 + bs: 64 + nname: conv_bn_pool_2 + san_check: False + exp_shoulder_clip: 10 + mu_clip: 10_000 + coeff_ma: 0.5 + no_tensorboard: False + + + +Shared params: + ini_setpoint_ratio: + min: 0.9 + max: 0.99 + num: 3 + step: 0.05 + distribution: uniform + + k_i_gain: + min: 0.0001 + max: 0.01 + num: 2 + step: 0.0001 + distribution: uniform + + mu_init: + min: 0.000001 + max: 0.00001 + num: 2 + distribution: uniform + +# Test fbopt with different hyperparameter configurations + +dann_fbopt: + model: dann + trainer: fbopt + ini_setpoint_ratio: 0.9 + shared: + - k_i_gain + - mu_init + +erm: + model: erm diff --git a/examples/benchmark/mnist_diva_fbopt_alone.yaml b/examples/benchmark/mnist_diva_fbopt_alone.yaml new file mode 100644 index 000000000..c483b0e68 --- /dev/null +++ b/examples/benchmark/mnist_diva_fbopt_alone.yaml @@ -0,0 +1,92 @@ +mode: grid + +output_dir: zoutput/benchmarks/mnist_diva_fbopt_alone + +sampling_seed: 0 +startseed: 0 +endseed: 10 + +test_domains: + - 0 + + +domainlab_args: + task: mnistcolor10 + tr_d: [1, 2] + dmem: False + lr: 0.001 + epos: 500 + epos_min: 20 + es: 5 + bs: 64 + zx_dim: 0 + zy_dim: 32 + zd_dim: 32 + nname: conv_bn_pool_2 + nname_dom: conv_bn_pool_2 + nname_encoder_x2topic_h: 
conv_bn_pool_2 + nname_encoder_sandwich_x2h4zd: conv_bn_pool_2 + san_check: False + coeff_ma: 0.5 + no_tensorboard: False + + + +Shared params: + ini_setpoint_ratio: + min: 0.9 + max: 0.99 + num: 3 + distribution: uniform + + k_i_gain: + min: 1e-4 + max: 1e-3 + num: 2 + distribution: loguniform + + mu_init: + min: 0.000001 + max: 0.00001 + num: 3 + distribution: loguniform + + gamma_y: + min: 1.0 + max: 1e6 + step: 100 + num: 3 + distribution: loguniform + + gamma_d: + min: 1.0 + max: 1e6 + step: 100 + num: 3 + distribution: loguniform + + mu_clip: + distribution: categorical + datatype: float + values: + - 1000 + - 100 + - 10 + - 1 + +# Test fbopt with different hyperparameter configurations + +diva_fbopt_a: + model: diva + trainer: fbopt + str_diva_multiplier_type: gammad_recon + gamma_y: 1.0 + ini_setpoint_ratio: 0.99 + exp_shoulder_clip: 1 + mu_init: 1e-6 + shared: + - k_i_gain + - mu_clip + +erm: + model: erm diff --git a/examples/benchmark/mnist_diva_fbopt_and_baselines.yaml b/examples/benchmark/mnist_diva_fbopt_and_baselines.yaml new file mode 100644 index 000000000..b687b69f4 --- /dev/null +++ b/examples/benchmark/mnist_diva_fbopt_and_baselines.yaml @@ -0,0 +1,122 @@ +mode: grid + +output_dir: zoutput/benchmarks/mnist_diva_fbopt_and_baselines + +sampling_seed: 0 +startseed: 0 +endseed: 10 + +test_domains: + - 0 + + +domainlab_args: + task: mnistcolor10 + tr_d: [1, 2] + dmem: False + lr: 0.001 + epos: 5000 + epos_min: 500 + es: 5 + bs: 64 + zx_dim: 0 + zy_dim: 32 + zd_dim: 32 + nname: conv_bn_pool_2 + nname_dom: conv_bn_pool_2 + nname_encoder_x2topic_h: conv_bn_pool_2 + nname_encoder_sandwich_x2h4zd: conv_bn_pool_2 + san_check: False + coeff_ma: 0.5 + no_tensorboard: False + + + +Shared params: + ini_setpoint_ratio: + min: 0.9 + max: 0.99 + num: 3 + distribution: uniform + + k_i_gain: + min: 1e-4 + max: 1e-3 + num: 2 + distribution: loguniform + + k_i_gain_ratio: + min: 0.1 + max: 1 + num: 5 + distribution: uniform + + mu_init: + min: 0.000001 + max: 0.00001 + num: 3 + distribution: loguniform + + gamma_y: + min: 1.0 + max: 1e6 + step: 100 + num: 3 + distribution: loguniform + + gamma_d: + min: 1.0 + max: 1e6 + step: 100 + num: 3 + distribution: loguniform + + mu_clip: + distribution: categorical + datatype: float + values: + - 1000 + - 100 + - 10 + - 1 + +# Test fbopt with different hyperparameter configurations + +diva_fbopt_a: + model: diva + trainer: fbopt + str_diva_multiplier_type: gammad_recon + gamma_y: 1.0 + ini_setpoint_ratio: 0.99 + exp_shoulder_clip: 1 + mu_init: 1e-6 + shared: + - k_i_gain_ratio + - mu_clip + +diva_feedforward_a: + model: diva + trainer: hyperscheduler + str_diva_multiplier_type: gammad_recon + gamma_y: 1.0 + shared: + - gamma_d + +diva_default: + model: diva + trainer: hyperscheduler + str_diva_multiplier_type: default + shared: + - gamma_d + - gamma_y + +diva_fixed_penalty: + model: diva + trainer: basic + str_diva_multiplier_type: default + shared: + - gamma_d + - gamma_y + +erm: + model: erm diff --git a/examples/benchmark/mnist_jigen_fbopt_alone.yaml b/examples/benchmark/mnist_jigen_fbopt_alone.yaml new file mode 100644 index 000000000..388bdcf13 --- /dev/null +++ b/examples/benchmark/mnist_jigen_fbopt_alone.yaml @@ -0,0 +1,79 @@ +mode: grid + +output_dir: zoutput/benchmarks/benchmark_fbopt + +sampling_seed: 0 +startseed: 0 +endseed: 4 + +test_domains: + - 0 + + +domainlab_args: + task: mnistcolor10 + tr_d: [1, 2] + dmem: False + lr: 0.001 + epos: 1000 + es: 100 + bs: 64 + nname: conv_bn_pool_2 + san_check: False + exp_shoulder_clip: 10 + 
mu_clip: 10 + coeff_ma: 0.5 + no_tensorboard: False + pperm: 0.5 + + + +Shared params: + coeff_ma_output_state: + distribution: categorical + datatype: float + values: + - 0.1 + - 0.5 + coeff_ma_setpoint: + distribution: uniform + min: 0.0 + max: 0.9 + num: 2 + setpoint_rewind: + distribution: categorical + datatype: str + values: + - "yes" + - "no" + k_i_gain: + min: 0.0001 + max: 0.01 + num: 2 + distribution: uniform + + mu_init: + min: 0.000001 + max: 0.00001 + num: 2 + distribution: loguniform + + gamma_reg: + min: 0.01 + max: 10_000 + num: 10 + distribution: loguniform + + + +# Test fbopt with different hyperparameter configurations + +jigen_feedback: + model: jigen + trainer: fbopt + ini_setpoint_ratio: 0.99 + coeff_ma_output_state: 0.5 + coeff_ma_setpoint: 0.5 + shared: + - k_i_gain + - mu_init diff --git a/examples/benchmark/mnist_jigen_fbopt_and_others.yaml b/examples/benchmark/mnist_jigen_fbopt_and_others.yaml new file mode 100644 index 000000000..bd4857610 --- /dev/null +++ b/examples/benchmark/mnist_jigen_fbopt_and_others.yaml @@ -0,0 +1,80 @@ +mode: grid + +output_dir: zoutput/benchmarks/mnist_fbopt_and_others + +sampling_seed: 0 +startseed: 0 +endseed: 4 + +test_domains: + - 0 + + +domainlab_args: + task: mnistcolor10 + tr_d: [1, 2] + dmem: False + lr: 0.001 + epos: 2000 + epos_min: 100 + es: 1 + bs: 64 + nname: conv_bn_pool_2 + san_check: False + no_tensorboard: False + pperm: 0.5 + + + +Shared params: + k_i_gain: + min: 1e-4 + max: 1e-3 + num: 2 + distribution: uniform + + mu_init: + min: 0.000001 + max: 0.00001 + num: 3 + distribution: loguniform + + gamma_reg: + min: 0.01 + max: 1e4 + num: 3 + distribution: loguniform + + mu_clip: + min: 0.01 + max: 1e4 + num: 3 + distribution: loguniform + + + +# Test fbopt with different hyperparameter configurations + +jigen_feedback: + model: jigen + trainer: fbopt + ini_setpoint_ratio: 0.99 + mu_init: 1e-6 + shared: + - k_i_gain + - mu_clip + +jigen_feedforward: + model: jigen + trainer: hyperscheduler + shared: + - gamma_reg + +jigen_fixed_penalty: + model: jigen + trainer: basic + shared: + - gamma_reg + +erm: + model: erm diff --git a/examples/benchmark/pacs_dann_fbopt.yaml b/examples/benchmark/pacs_dann_fbopt.yaml new file mode 100644 index 000000000..b5c743033 --- /dev/null +++ b/examples/benchmark/pacs_dann_fbopt.yaml @@ -0,0 +1,53 @@ +mode: grid + +output_dir: zoutput/benchmarks/benchmark_fbopt_pacs + +sampling_seed: 0 +startseed: 0 +endseed: 2 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_path_list.py + dmem: False + lr: 5e-5 + epos: 1 + es: 5 + bs: 64 + san_check: True + npath: examples/nets/resnet50domainbed.py + npath_encoder_x2topic_h: examples/nets/resnet50domainbed.py + npath_encoder_sandwich_x2h4zd: examples/nets/resnet50domainbed.py + exp_shoulder_clip: 10 + mu_clip: 10_000 + coeff_ma: 0.5 + no_tensorboard: False + + + +Shared params: + ini_setpoint_ratio: + min: 0.5 + max: 0.99 + num: 2 + step: 0.05 + distribution: uniform + + k_i_gain: + min: 0.0001 + max: 0.01 + num: 2 + step: 0.0001 + distribution: uniform + +# Test fbopt with different hyperparameter configurations + +dann_fbopt: + model: dann + trainer: fbopt + shared: + - ini_setpoint_ratio + - k_i_gain + - es diff --git a/examples/benchmark/pacs_diva_fbopt_alone_es1_autoki.yaml b/examples/benchmark/pacs_diva_fbopt_alone_es1_autoki.yaml new file mode 100644 index 000000000..35c93c236 --- /dev/null +++ b/examples/benchmark/pacs_diva_fbopt_alone_es1_autoki.yaml @@ -0,0 +1,106 @@ +mode: grid + +output_dir: 
zoutput/benchmarks/pacs_diva_fbopt_alone_zx + +sampling_seed: 0 + +startseed: 0 +endseed: 2 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_aug.py + dmem: False + lr: 5e-5 + epos: 500 + epos_min: 200 + es: 5 + bs: 32 + san_check: False + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + zx_dim: 16 + zy_dim: 64 + zd_dim: 64 + + +Shared params: + ini_setpoint_ratio: + min: 0.990 + max: 0.999 + num: 2 + distribution: uniform + + str_diva_multiplier_type: + distribution: categorical + datatype: str + values: + - gammad_recon + - gammad_recon_per_pixel + + coeff_ma_output_state: + distribution: categorical + datatype: float + values: + - 0.1 + - 0.5 + + mu_clip: + distribution: categorical + datatype: float + values: + - 10 + - 1000 + - 1 + - 100 + + k_i_gain: + min: 0.0001 + max: 0.01 + num: 2 + distribution: uniform + + k_i_gain_ratio: + min: 0.1 + max: 1 + num: 3 + distribution: uniform + + mu_init: + min: 0.000001 + max: 0.00001 + step: 0.000001 + num: 3 + distribution: loguniform + + gamma_y: + min: 1.0 + max: 1e6 + step: 100 + num: 3 + distribution: loguniform + + gamma_d: + min: 1.0 + max: 1e6 + step: 100 + num: 3 + distribution: loguniform + + + +# Test fbopt with different hyperparameter configurations, no noeed to tune mu_clip since this is the job of KI gain when mu_init is small + +diva_fbopt_full: + model: diva + trainer: fbopt + force_setpoint_change_once: True + gamma_y: 1.0 + ini_setpoint_ratio: 0.99 + str_diva_multiplier_type: gammad_recon + mu_init: 1e-6 + shared: + - k_i_gain_ratio + - mu_clip diff --git a/examples/benchmark/pacs_diva_fbopt_alone_es1_autoki_1run.yaml b/examples/benchmark/pacs_diva_fbopt_alone_es1_autoki_1run.yaml new file mode 100644 index 000000000..9ec43ac7b --- /dev/null +++ b/examples/benchmark/pacs_diva_fbopt_alone_es1_autoki_1run.yaml @@ -0,0 +1,107 @@ +mode: grid + +output_dir: zoutput/benchmarks/pacs_diva_fbopt_alone_single_run + +sampling_seed: 0 + +startseed: 0 +endseed: 0 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_aug.py + dmem: False + lr: 5e-5 + epos: 500 + epos_min: 200 + es: 5 + bs: 32 + san_check: False + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + zx_dim: 16 + zy_dim: 64 + zd_dim: 64 + + +Shared params: + ini_setpoint_ratio: + min: 0.990 + max: 0.999 + num: 2 + distribution: uniform + + str_diva_multiplier_type: + distribution: categorical + datatype: str + values: + - gammad_recon + - gammad_recon_per_pixel + + coeff_ma_output_state: + distribution: categorical + datatype: float + values: + - 0.1 + - 0.5 + + mu_clip: + distribution: categorical + datatype: float + values: + - 10 + - 1000 + - 1 + - 100 + + k_i_gain: + min: 0.0001 + max: 0.01 + num: 2 + distribution: uniform + + k_i_gain_ratio: + min: 0.1 + max: 1 + num: 3 + distribution: uniform + + mu_init: + min: 0.000001 + max: 0.00001 + step: 0.000001 + num: 3 + distribution: loguniform + + gamma_y: + min: 1.0 + max: 1e6 + step: 100 + num: 3 + distribution: loguniform + + gamma_d: + min: 1.0 + max: 1e6 + step: 100 + num: 3 + distribution: loguniform + + + +# Test fbopt with different hyperparameter configurations, no noeed to tune mu_clip since this is the job of KI gain when mu_init is small + +diva_fbopt_full: + model: diva + trainer: fbopt + force_setpoint_change_once: True + gamma_y: 1.0 + ini_setpoint_ratio: 0.99 + str_diva_multiplier_type: gammad_recon + mu_init: 1e-6 + k_i_gain_ratio: 0.5 + mu_clip: 10 + 
coeff_ma_output_state: 0.0 + coeff_ma_setpoint: 0.0 diff --git a/examples/benchmark/pacs_diva_fbopt_alone_es1_autoki_output_ma_9.yaml b/examples/benchmark/pacs_diva_fbopt_alone_es1_autoki_output_ma_9.yaml new file mode 100644 index 000000000..36fd10554 --- /dev/null +++ b/examples/benchmark/pacs_diva_fbopt_alone_es1_autoki_output_ma_9.yaml @@ -0,0 +1,108 @@ +mode: grid + +output_dir: zoutput/benchmarks/pacs_diva_fbopt_autoki_aug + +sampling_seed: 0 + +startseed: 0 +endseed: 10 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_aug.py + dmem: False + lr: 5e-5 + epos: 500 + epos_min: 200 + es: 1 + bs: 64 + san_check: False + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + npath_encoder_x2topic_h: examples/nets/resnet50domainbed.py + npath_encoder_sandwich_x2h4zd: examples/nets/resnet50domainbed.py + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + + +Shared params: + ini_setpoint_ratio: + min: 0.990 + max: 0.999 + num: 2 + distribution: uniform + + str_diva_multiplier_type: + distribution: categorical + datatype: str + values: + - gammad_recon + - gammad_recon_per_pixel + + coeff_ma_output_state: + distribution: categorical + datatype: float + values: + - 0.1 + - 0.5 + - 0.8 + + mu_clip: + distribution: categorical + datatype: float + values: + - 10 + - 1000 + - 1 + - 100 + + k_i_gain: + min: 0.0001 + max: 0.01 + num: 2 + distribution: uniform + + k_i_gain_ratio: + min: 0.1 + max: 10 + num: 10 + distribution: loguniform + + mu_init: + min: 0.000001 + max: 0.00001 + step: 0.000001 + num: 3 + distribution: loguniform + + gamma_y: + min: 1.0 + max: 1e6 + step: 100 + num: 3 + distribution: loguniform + + gamma_d: + min: 1.0 + max: 1e6 + step: 100 + num: 3 + distribution: loguniform + + + +# Test fbopt with different hyperparameter configurations, no noeed to tune mu_clip since this is the job of KI gain when mu_init is small + +diva_fbopt_full: + model: diva + trainer: fbopt + gamma_y: 1.0 + ini_setpoint_ratio: 0.99 + str_diva_multiplier_type: gammad_recon + mu_init: 1e-6 + shared: + - k_i_gain_ratio + - coeff_ma_output_state diff --git a/examples/benchmark/pacs_diva_fbopt_alone_es1_random_ki.yaml b/examples/benchmark/pacs_diva_fbopt_alone_es1_random_ki.yaml new file mode 100644 index 000000000..24177c0bc --- /dev/null +++ b/examples/benchmark/pacs_diva_fbopt_alone_es1_random_ki.yaml @@ -0,0 +1,102 @@ +mode: grid + +output_dir: zoutput/benchmarks/pacs_diva_fbopt_alone + +sampling_seed: 0 + +startseed: 0 +endseed: 2 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_aug.py + dmem: False + lr: 5e-5 + epos: 200 + epos_min: 20 + es: 1 + bs: 32 + san_check: False + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + npath_encoder_x2topic_h: examples/nets/resnet50domainbed.py + npath_encoder_sandwich_x2h4zd: examples/nets/resnet50domainbed.py + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + + +Shared params: + ini_setpoint_ratio: + min: 0.990 + max: 0.999 + num: 2 + distribution: uniform + + str_diva_multiplier_type: + distribution: categorical + datatype: str + values: + - gammad_recon + - gammad_recon_per_pixel + + coeff_ma_output_state: + distribution: categorical + datatype: float + values: + - 0.1 + - 0.5 + + mu_clip: + distribution: categorical + datatype: float + values: + - 10 + - 1000 + - 1 + - 100 + + k_i_gain: + min: 0.0001 + max: 0.01 + num: 2 + distribution: uniform + + mu_init: + min: 0.000001 + max: 0.00001 + step: 0.000001 + num: 3 + distribution: loguniform + + 
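Each entry under `Shared params` in these benchmark files describes one axis of the hyperparameter grid declared by `mode: grid`: a `min`, `max`, `num` of points and a `distribution` (categorical axes instead list their `values` directly). As a hedged illustration of the semantics these keys suggest — not DomainLab's own sampling code — the `mu_init` axis immediately above (min 1e-6, max 1e-5, num 3, loguniform) would expand roughly like this:

```python
import numpy as np

def loguniform_grid(low, high, num):
    """Grid of `num` points evenly spaced on a log scale between low and high."""
    return np.logspace(np.log10(low), np.log10(high), num)

def uniform_grid(low, high, num):
    """Grid of `num` points evenly spaced on a linear scale between low and high."""
    return np.linspace(low, high, num)

# the mu_init axis above: min=1e-6, max=1e-5, num=3, distribution=loguniform
print(loguniform_grid(1e-6, 1e-5, 3))  # -> approx. [1.0e-06, 3.16e-06, 1.0e-05]
```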
gamma_y: + min: 1.0 + max: 1e6 + step: 100 + num: 3 + distribution: loguniform + + gamma_d: + min: 1.0 + max: 1e6 + step: 100 + num: 3 + distribution: loguniform + + + +# Test fbopt with different hyperparameter configurations, no noeed to tune mu_clip since this is the job of KI gain when mu_init is small + +diva_fbopt_full: + model: diva + trainer: fbopt + gamma_y: 1.0 + ini_setpoint_ratio: 0.99 + str_diva_multiplier_type: gammad_recon + coeff_ma_output_state: 0.1 + mu_init: 0.000001 + shared: + - k_i_gain + - mu_clip diff --git a/examples/benchmark/pacs_diva_fbopt_alone_fixed.yaml b/examples/benchmark/pacs_diva_fbopt_alone_fixed.yaml new file mode 100644 index 000000000..e2a78230a --- /dev/null +++ b/examples/benchmark/pacs_diva_fbopt_alone_fixed.yaml @@ -0,0 +1,97 @@ +mode: grid + +output_dir: zoutput/benchmarks/pacs_diva_fbopt_alone + +sampling_seed: 0 + +startseed: 0 +endseed: 5 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_path_list.py + dmem: False + lr: 5e-5 + epos: 200 + epos_min: 20 + es: 5 + bs: 64 + san_check: False + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + npath_encoder_x2topic_h: examples/nets/resnet50domainbed.py + npath_encoder_sandwich_x2h4zd: examples/nets/resnet50domainbed.py + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + + +Shared params: + ini_setpoint_ratio: + min: 0.990 + max: 0.999 + num: 2 + distribution: uniform + + str_diva_multiplier_type: + distribution: categorical + datatype: str + values: + - gammad_recon + - gammad_recon_per_pixel + + coeff_ma_output_state: + distribution: categorical + datatype: float + values: + - 0.1 + - 0.5 + + mu_clip: + distribution: categorical + datatype: int + values: + - 10 + - 1000 + - 1000_000 + + k_i_gain: + min: 0.0001 + max: 0.01 + num: 3 + distribution: uniform + + mu_init: + min: 0.000001 + max: 0.00001 + step: 0.000001 + num: 3 + distribution: loguniform + + gamma_y: + min: 1.0 + max: 1e6 + step: 100 + num: 3 + distribution: loguniform + + gamma_d: + min: 1.0 + max: 1e6 + step: 100 + num: 3 + distribution: loguniform + + + +# Test fbopt with different hyperparameter configurations, no noeed to tune mu_clip since this is the job of KI gain when mu_init is small + +diva_fbopt_full: + model: diva + trainer: fbopt + exp_shoulder_clip: 5 + gamma_y: 1.0 + ini_setpoint_ratio: 0.99 + mu_init: 0.000001 diff --git a/examples/benchmark/pacs_diva_fbopt_and_baselines.yaml b/examples/benchmark/pacs_diva_fbopt_and_baselines.yaml new file mode 100644 index 000000000..7ea54939e --- /dev/null +++ b/examples/benchmark/pacs_diva_fbopt_and_baselines.yaml @@ -0,0 +1,119 @@ +mode: grid + +output_dir: zoutput/benchmarks/pacs_diva_fbopt_and_baselines_aug + +sampling_seed: 0 + +startseed: 0 +endseed: 6 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_aug.py + dmem: False + lr: 5e-5 + epos: 500 + epos_min: 200 + es: 5 + bs: 32 + san_check: False + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + + + + +Shared params: + ini_setpoint_ratio: + min: 0.9 + max: 0.99 + num: 3 + distribution: uniform + + k_i_gain: + min: 0.0001 + max: 0.01 + num: 3 + distribution: uniform + + k_i_gain_ratio: + min: 0.1 + max: 1 + num: 3 + distribution: uniform + + + mu_init: + min: 0.000001 + max: 0.9 + num: 3 + distribution: loguniform + + gamma_y: + min: 1.0 + max: 1e6 + step: 100 + num: 3 + distribution: loguniform + + gamma_d: + min: 1.0 + max: 1e6 + step: 100 + num: 3 + 
distribution: loguniform + + mu_clip: + distribution: categorical + datatype: float + values: + - 1 + - 10 + - 100 + - 1000 + +# Test fbopt with different hyperparameter configurations +diva_fbopt_a: + model: diva + trainer: fbopt + gamma_y: 1.0 + ini_setpoint_ratio: 0.99 + force_setpoint_change_once: True + str_diva_multiplier_type: gammad_recon + coeff_ma_output_state: 0.1 + mu_init: 1e-6 + shared: + - k_i_gain_ratio + - mu_clip + +diva_feedforward_full: + model: diva + trainer: hyperscheduler + str_diva_multiplier_type: gammad_recon + gamma_y: 1.0 + shared: + - gamma_d + +diva_default: + model: diva + trainer: hyperscheduler + str_diva_multiplier_type: default + shared: + - gamma_d + - gamma_y + +diva_fixed_penalty: + model: diva + trainer: basic + str_diva_multiplier_type: default + shared: + - gamma_d + - gamma_y + +erm: + model: erm diff --git a/examples/benchmark/pacs_diva_others.yaml b/examples/benchmark/pacs_diva_others.yaml new file mode 100644 index 000000000..55d6a7f37 --- /dev/null +++ b/examples/benchmark/pacs_diva_others.yaml @@ -0,0 +1,68 @@ +mode: grid + +output_dir: zoutput/benchmarks/pacs_diva_others + +sampling_seed: 0 + +startseed: 0 +endseed: 6 +test_domains: - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_aug.py + dmem: False + lr: 5e-5 + epos: 500 + epos_min: 200 + es: 5 + bs: 32 + san_check: False + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + + + + +Shared params: + + gamma_y: + min: 1.0 + max: 1e6 + step: 100 + num: 3 + distribution: loguniform + + gamma_d: + min: 1.0 + max: 1e6 + step: 100 + num: 3 + distribution: loguniform + + +diva_feedforward_full: + model: diva + trainer: hyperscheduler + str_diva_multiplier_type: gammad_recon + shared: + - gamma_d + - gamma_y + +diva_default: + model: diva + trainer: hyperscheduler + str_diva_multiplier_type: default + shared: + - gamma_d + - gamma_y + +diva_fixed_penalty: + model: diva + trainer: basic + str_diva_multiplier_type: default + shared: + - gamma_d + - gamma_y diff --git a/examples/benchmark/pacs_fbopt_fishr_erm.yaml b/examples/benchmark/pacs_fbopt_fishr_erm.yaml new file mode 100644 index 000000000..781a2518e --- /dev/null +++ b/examples/benchmark/pacs_fbopt_fishr_erm.yaml @@ -0,0 +1,66 @@ +mode: grid + +output_dir: zoutput/benchmarks/benchmark_fbopt_fishr_erm_pacs + +sampling_seed: 0 +startseed: 0 +endseed: 0 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_aug.py + dmem: False + lr: 5e-5 + epos: 10 + epos_min: 2 + es: 5 + bs: 32 + san_check: False + nname: alexnet + nname_dom: alexnet + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + + + + +Shared params: + ini_setpoint_ratio: + min: 0.5 + max: 0.99 + num: 2 + step: 0.05 + distribution: uniform + + k_i_gain: + min: 0.0001 + max: 0.01 + num: 2 + step: 0.0001 + distribution: uniform + + gamma_reg: + min: 0.01 + max: 1e4 + num: 3 + distribution: loguniform + + +# Test fbopt with different hyperparameter configurations + +fbopt_fishr_erm: + model: erm + trainer: fbopt_fishr + shared: + - ini_setpoint_ratio + - k_i_gain + - gamma_reg + +fishr_erm: + model: erm + trainer: fishr + shared: + - gamma_reg diff --git a/examples/benchmark/pacs_hduva_baselines.yaml b/examples/benchmark/pacs_hduva_baselines.yaml new file mode 100644 index 000000000..cbdb704eb --- /dev/null +++ b/examples/benchmark/pacs_hduva_baselines.yaml @@ -0,0 +1,111 @@ +mode: grid + +output_dir: zoutput/benchmarks/pacs_hduva_fbopt_and_baselines + +sampling_seed: 0 + +startseed: 0 
+endseed: 10 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_aug.py + dmem: False + lr: 5e-5 + epos: 500 + epos_min: 200 + es: 1 + bs: 16 + san_check: False + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + npath_encoder_x2topic_h: examples/nets/resnet50domainbed.py + npath_encoder_sandwich_x2h4zd: examples/nets/resnet50domainbed.py + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + + +Shared params: + ini_setpoint_ratio: + min: 0.990 + max: 0.999 + num: 2 + distribution: uniform + + str_diva_multiplier_type: + distribution: categorical + datatype: str + values: + - gammad_recon + - gammad_recon_per_pixel + + coeff_ma_output_state: + distribution: categorical + datatype: float + values: + - 0.1 + - 0.5 + - 0.8 + + mu_clip: + distribution: categorical + datatype: float + values: + - 10 + - 1000 + - 1 + - 100 + + k_i_gain: + min: 0.0001 + max: 0.01 + num: 2 + distribution: uniform + + k_i_gain_ratio: + min: 0.1 + max: 10 + num: 10 + distribution: loguniform + + mu_init: + min: 0.000001 + max: 0.00001 + step: 0.000001 + num: 3 + distribution: loguniform + + gamma_y: + min: 1.0 + max: 1e6 + step: 100 + num: 3 + distribution: loguniform + + gamma_d: + min: 1.0 + max: 1e6 + step: 100 + num: 3 + distribution: loguniform + + + +# Test fbopt with different hyperparameter configurations, no noeed to tune mu_clip since this is the job of KI gain when mu_init is small + +hduva_beta_warmup: + model: hduva + shared: + - gamma_y + +hduva_fbopt_full: + model: hduva + trainer: fbopt + gamma_y: 1.0 + ini_setpoint_ratio: 0.99 + mu_init: 1e-6 + shared: + - k_i_gain_ratio diff --git a/examples/benchmark/pacs_hduva_fbopt_alone_es1_autoki_aug.yaml b/examples/benchmark/pacs_hduva_fbopt_alone_es1_autoki_aug.yaml new file mode 100644 index 000000000..d773cb25b --- /dev/null +++ b/examples/benchmark/pacs_hduva_fbopt_alone_es1_autoki_aug.yaml @@ -0,0 +1,107 @@ +mode: grid + +output_dir: zoutput/benchmarks/pacs_hduva_fbopt_alone_aug + +sampling_seed: 0 + +startseed: 0 +endseed: 10 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_aug.py + dmem: False + lr: 5e-5 + epos: 500 + epos_min: 100 + es: 10 + bs: 16 + san_check: False + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + npath_encoder_x2topic_h: examples/nets/resnet50domainbed.py + npath_encoder_sandwich_x2h4zd: examples/nets/resnet50domainbed.py + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + + +Shared params: + ini_setpoint_ratio: + min: 0.990 + max: 0.999 + num: 2 + distribution: uniform + + str_diva_multiplier_type: + distribution: categorical + datatype: str + values: + - gammad_recon + - gammad_recon_per_pixel + + coeff_ma_output_state: + distribution: categorical + datatype: float + values: + - 0.1 + - 0.5 + - 0.8 + + mu_clip: + distribution: categorical + datatype: float + values: + - 10 + - 1000 + - 1 + - 100 + + k_i_gain: + min: 0.0001 + max: 0.01 + num: 2 + distribution: uniform + + k_i_gain_ratio: + min: 0.1 + max: 10 + num: 10 + distribution: loguniform + + mu_init: + min: 0.000001 + max: 0.00001 + step: 0.000001 + num: 3 + distribution: loguniform + + gamma_y: + min: 1.0 + max: 1e6 + step: 100 + num: 3 + distribution: loguniform + + gamma_d: + min: 1.0 + max: 1e6 + step: 100 + num: 3 + distribution: loguniform + + + +# Test fbopt with different hyperparameter configurations, no noeed to tune mu_clip since this is the job of KI gain when mu_init is small + +diva_fbopt_full: + model: hduva + trainer: fbopt + gamma_y: 1.0 + 
ini_setpoint_ratio: 0.99 + mu_init: 1e-6 + mu_clip: 10 + shared: + - k_i_gain_ratio diff --git a/examples/benchmark/pacs_hduva_matchdg.yaml b/examples/benchmark/pacs_hduva_matchdg.yaml new file mode 100644 index 000000000..f8c99d6d3 --- /dev/null +++ b/examples/benchmark/pacs_hduva_matchdg.yaml @@ -0,0 +1,112 @@ +mode: grid + +output_dir: zoutput/benchmarks/pacs_hduva_fbopt_alone_aug + +sampling_seed: 0 + +startseed: 0 +endseed: 10 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_aug.py + dmem: False + lr: 5e-5 + epos: 500 + epos_min: 100 + es: 10 + bs: 16 + san_check: False + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + npath_encoder_x2topic_h: examples/nets/resnet50domainbed.py + npath_encoder_sandwich_x2h4zd: examples/nets/resnet50domainbed.py + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + + +Shared params: + ini_setpoint_ratio: + min: 0.990 + max: 0.999 + num: 2 + distribution: uniform + + str_diva_multiplier_type: + distribution: categorical + datatype: str + values: + - gammad_recon + - gammad_recon_per_pixel + + coeff_ma_output_state: + distribution: categorical + datatype: float + values: + - 0.1 + - 0.5 + - 0.8 + + mu_clip: + distribution: categorical + datatype: float + values: + - 10 + - 1000 + - 1 + - 100 + + k_i_gain: + min: 0.0001 + max: 0.01 + num: 2 + distribution: uniform + + k_i_gain_ratio: + min: 0.1 + max: 10 + num: 10 + distribution: loguniform + + mu_init: + min: 0.000001 + max: 0.00001 + step: 0.000001 + num: 3 + distribution: loguniform + + gamma_y: + min: 1.0 + max: 1e6 + step: 100 + num: 3 + distribution: loguniform + + gamma_d: + min: 1.0 + max: 1e6 + step: 100 + num: 3 + distribution: loguniform + + gamma_reg: + min: 0.01 + max: 10 + distribution: loguniform + num: 3 + + + + + +# Test fbopt with different hyperparameter configurations, no noeed to tune mu_clip since this is the job of KI gain when mu_init is small + +match_duva: + model: matchhduva + epochs_ctr: 10 + shared: + - gamma_y + - gamma_reg diff --git a/examples/benchmark/pacs_jigen_baslines4fbopt.yaml b/examples/benchmark/pacs_jigen_baslines4fbopt.yaml new file mode 100644 index 000000000..8c4d99d3d --- /dev/null +++ b/examples/benchmark/pacs_jigen_baslines4fbopt.yaml @@ -0,0 +1,74 @@ +mode: grid + +output_dir: zoutput/benchmarks/pacs_jigen_fbopt_baselines + +sampling_seed: 0 + +startseed: 0 +endseed: 4 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_aug_noflip.py + dmem: False + lr: 5e-5 + epos: 500 + epos_min: 100 + es: 1 + bs: 64 + san_check: True + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + npath_encoder_x2topic_h: examples/nets/resnet50domainbed.py + npath_encoder_sandwich_x2h4zd: examples/nets/resnet50domainbed.py + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + + +Shared params: + k_i_gain: + min: 0.0001 + max: 0.01 + num: 3 + distribution: loguniform + + mu_init: + min: 0.000001 + max: 0.00005 + num: 3 + distribution: loguniform + + pperm: + min: 0.1 + max: 0.9 + num: 3 + distribution: uniform + + gamma_reg: + min: 0.01 + max: 10 + num: 5 + distribution: loguniform + +# Test fbopt with different hyperparameter configurations + + +jigen_feedforward: + model: jigen + trainer: hyperscheduler + shared: + - gamma_reg + - pperm + +jigen_fixed_penalty: + model: jigen + trainer: basic + shared: + - gamma_reg + - pperm + +erm: + model: erm diff --git a/examples/benchmark/pacs_jigen_fbopt_alone.yaml b/examples/benchmark/pacs_jigen_fbopt_alone.yaml new file mode 
100644 index 000000000..3107894ed --- /dev/null +++ b/examples/benchmark/pacs_jigen_fbopt_alone.yaml @@ -0,0 +1,85 @@ +mode: grid + +output_dir: zoutput/benchmarks/pacs_jigen_fbopt_alone + +sampling_seed: 0 + +startseed: 0 +endseed: 3 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_aug_noflip.py + dmem: False + lr: 5e-5 + epos: 500 + epos_min: 200 + es: 1 + bs: 64 + san_check: True + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + npath_encoder_x2topic_h: examples/nets/resnet50domainbed.py + npath_encoder_sandwich_x2h4zd: examples/nets/resnet50domainbed.py + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + + +Shared params: + k_i_gain: + min: 0.0001 + max: 0.01 + num: 3 + distribution: loguniform + + k_i_gain_ratio: + min: 0.1 + max: 10 + num: 3 + distribution: loguniform + + + + mu_init: + min: 0.000001 + max: 0.00005 + num: 3 + distribution: loguniform + + pperm: + min: 0.1 + max: 0.7 + num: 3 + distribution: uniform + + gamma_reg: + min: 0.01 + max: 10_000 + num: 10 + distribution: loguniform + + coeff_ma_output_state: + distribution: categorical + datatype: float + values: + - 0.1 + - 0.5 + - 0.9 + +# Test fbopt with different hyperparameter configurations + +jigen_feedback: + model: jigen + trainer: fbopt + ini_setpoint_ratio: 0.99 + coeff_ma: 0.5 + mu_init: 1e-6 + shared: + - k_i_gain_ratio + - pperm + +erm: + model: erm diff --git a/examples/benchmark/pacs_jigen_fbopt_alone_autoki.yaml b/examples/benchmark/pacs_jigen_fbopt_alone_autoki.yaml new file mode 100644 index 000000000..3c70d07b6 --- /dev/null +++ b/examples/benchmark/pacs_jigen_fbopt_alone_autoki.yaml @@ -0,0 +1,92 @@ +mode: grid + +output_dir: zoutput/benchmarks/pacs_jigen_fbopt_alone + +sampling_seed: 0 + +startseed: 0 +endseed: 3 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_path_list.py + dmem: False + lr: 5e-5 + epos: 500 + epos_min: 50 + force_setpoint_change_once: True + es: 1 + bs: 64 + san_check: False + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + npath_encoder_x2topic_h: examples/nets/resnet50domainbed.py + npath_encoder_sandwich_x2h4zd: examples/nets/resnet50domainbed.py + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + pperm: 0.1 + # pperm correspond to 1-bias_wholeimage in https://github.com/fmcarlucci/JigenDG + + +Shared params: + k_i_gain: + min: 0.0001 + max: 0.01 + num: 3 + distribution: loguniform + + k_i_gain_ratio: + min: 0.1 + max: 1 + num: 5 + distribution: uniform + + mu_init: + min: 0.000001 + max: 0.00005 + num: 3 + distribution: loguniform + + pperm: + min: 0.1 + max: 0.9 + num: 3 + distribution: uniform + + gamma_reg: + min: 0.01 + max: 10_000 + num: 10 + distribution: loguniform + + coeff_ma_setpoint: + distribution: categorical + datatype: float + values: + - 0.0 + - 0.5 + - 0.9 + + coeff_ma_output_state: + distribution: categorical + datatype: float + values: + - 0.0 + - 0.5 + - 0.9 + +# Test fbopt with different hyperparameter configurations + +jigen_feedback: + model: jigen + trainer: fbopt + ini_setpoint_ratio: 0.99 + mu_init: 1e-6 + + shared: + - k_i_gain_ratio + - coeff_ma_output_state + - coeff_ma_setpoint diff --git a/examples/benchmark/pacs_jigen_fbopt_and_baselines.yaml b/examples/benchmark/pacs_jigen_fbopt_and_baselines.yaml new file mode 100644 index 000000000..1421913b3 --- /dev/null +++ b/examples/benchmark/pacs_jigen_fbopt_and_baselines.yaml @@ -0,0 +1,90 @@ +mode: grid + +output_dir: zoutput/benchmarks/benchmark_fbopt_pacs_full + 
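The fbopt-specific keys that recur in these grids — `mu_init`, `k_i_gain` / `k_i_gain_ratio`, `exp_shoulder_clip`, `mu_clip` and the setpoint options — all parameterize the feedback update of the penalty multipliers. A rough, hedged sketch of the kind of update these names suggest (for intuition only; the actual rule lives in the fbopt trainer, not in these configuration files):

```python
import math

def update_multiplier(mu, reg_loss, setpoint, k_i_gain,
                      exp_shoulder_clip=5.0, mu_clip=10.0):
    """Schematic integral-feedback step for one penalty multiplier mu.

    When the regularization loss sits above its setpoint, mu is increased
    multiplicatively; the exponent is clamped to +/- exp_shoulder_clip and
    mu itself is capped at mu_clip.
    """
    error = reg_loss - setpoint
    exponent = max(-exp_shoulder_clip, min(exp_shoulder_clip, k_i_gain * error))
    return min(mu * math.exp(exponent), mu_clip)

# e.g. starting from mu_init=1e-6, a regularization loss 2.0 above its setpoint
mu = update_multiplier(1e-6, reg_loss=3.0, setpoint=1.0, k_i_gain=0.001)
```

Under this reading, the comment repeated in several of these files — that `mu_clip` needs little tuning when `mu_init` is small — is about the integral gain, rather than the clip, governing how fast the multiplier grows.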
+sampling_seed: 0 + +startseed: 0 +endseed: 5 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_path_list.py + dmem: False + lr: 5e-5 + epos: 200 + epos_min: 20 + es: 1 + bs: 64 + san_check: False + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + npath_encoder_x2topic_h: examples/nets/resnet50domainbed.py + npath_encoder_sandwich_x2h4zd: examples/nets/resnet50domainbed.py + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + pperm: 0.5 + + +Shared params: + k_i_gain: + min: 0.0001 + max: 0.01 + num: 2 + distribution: loguniform + + mu_init: + min: 0.000001 + max: 0.00005 + num: 2 + distribution: loguniform + + pperm: + min: 0.1 + max: 0.9 + num: 3 + distribution: uniform + + gamma_reg: + min: 0.01 + max: 10_000 + num: 4 + distribution: loguniform + + mu_clip: + distribution: categorical + datatype: float + values: + - 1 + - 10 + - 100 + - 1000 + +# Test fbopt with different hyperparameter configurations + +jigen_feedback: + model: jigen + trainer: fbopt + ini_setpoint_ratio: 0.99 + mu_init: 0.000001 + shared: + - k_i_gain + - mu_clip + +jigen_feedforward: + model: jigen + trainer: hyperscheduler + shared: + - gamma_reg + +jigen_fixed_penalty: + model: jigen + trainer: basic + shared: + - gamma_reg + +erm: + model: erm diff --git a/examples/benchmark/pacs_jigen_fbopt_and_baselines_aug.yaml b/examples/benchmark/pacs_jigen_fbopt_and_baselines_aug.yaml new file mode 100644 index 000000000..3b0f8dba6 --- /dev/null +++ b/examples/benchmark/pacs_jigen_fbopt_and_baselines_aug.yaml @@ -0,0 +1,105 @@ +mode: grid + +output_dir: zoutput/benchmarks/pacs_aug_jigen + +sampling_seed: 0 + +startseed: 0 +endseed: 3 + +test_domains: + - sketch + +domainlab_args: + tpath: examples/tasks/task_pacs_aug_noflip.py + dmem: False + epos: 500 + epos_min: 200 + es: 1 + bs: 64 + san_check: False + npath: examples/nets/resnet50domainbed.py + npath_dom: examples/nets/resnet50domainbed.py + npath_encoder_x2topic_h: examples/nets/resnet50domainbed.py + npath_encoder_sandwich_x2h4zd: examples/nets/resnet50domainbed.py + zx_dim: 0 + zy_dim: 64 + zd_dim: 64 + pperm: 0.1 + + +Shared params: + lr: + distribution: categorical + values: + - 5e-5 + - 1e-3 + + k_i_gain: + min: 0.0001 + max: 0.01 + num: 2 + distribution: loguniform + + k_i_gain_ratio: + min: 0.1 + max: 1 + num: 4 + distribution: uniform + + mu_init: + min: 0.000001 + max: 0.00005 + num: 2 + distribution: loguniform + + pperm: + min: 0.1 + max: 0.9 + num: 3 + distribution: uniform + + gamma_reg: + min: 0.01 + max: 10 + num: 4 + distribution: loguniform + + mu_clip: + distribution: categorical + datatype: float + values: + - 0.01 + - 0.1 + - 1.0 + - 10 + +# Test fbopt with different hyperparameter configurations + +jigen_feedback: + model: jigen + trainer: fbopt + ini_setpoint_ratio: 0.99 + mu_init: 1e-6 + force_setpoint_change_once: True + shared: + - k_i_gain_ratio + - mu_clip + - lr + +jigen_feedforward: + model: jigen + trainer: hyperscheduler + shared: + - gamma_reg + - lr + +jigen_fixed_penalty: + model: jigen + trainer: basic + shared: + - gamma_reg + - lr + +erm: + model: erm diff --git a/examples/benchmark/test_benchmark_fbopt.yaml b/examples/benchmark/test_benchmark_fbopt.yaml new file mode 100644 index 000000000..87ce24a8d --- /dev/null +++ b/examples/benchmark/test_benchmark_fbopt.yaml @@ -0,0 +1,70 @@ +mode: grid + +output_dir: zoutput/benchmarks/benchmark_fbopt + +num_param_samples: 8 +sampling_seed: 0 +startseed: 0 +endseed: 2 + +test_domains: + - 3 + - 0 + + +domainlab_args: + task: 
mnistcolor10 + tr_d: [1, 2] + dmem: False + lr: 0.001 + epos: 3 + es: 5 + bs: 64 + nname: conv_bn_pool_2 + san_check: True + + +Shared params: + ini_setpoint_ratio: + min: 0.5 + max: 0.99 + num: 2 + step: 0.001 + distribution: uniform + + k_i_gain: + min: 0.0001 + max: 0.01 + num: 2 + step: 0.0001 + distribution: uniform + + exp_shoulder_clip: + min: 5 + max: 10 + num: 2 + step: 1 + distribution: uniform + + mu_clip: + min: 0.001 + max: 1e4 + num: 2 + step: 10 + distribution: loguniform + + coeff_ma: + min: 0.001 + max: 0.99 + num: 2 + step: 0.001 + distribution: uniform + +# Test fbopt with different hyperparameter configurations + +jigen_fbopt: + model: jigen + trainer: fbopt + + shared: + - ini_setpoint_ratio diff --git a/examples/tasks/task_pacs_aug.py b/examples/tasks/task_pacs_aug.py index e971bea8c..0d334a45a 100644 --- a/examples/tasks/task_pacs_aug.py +++ b/examples/tasks/task_pacs_aug.py @@ -11,9 +11,10 @@ from domainlab.tasks.utils_task import ImSize # change this to absolute directory where you have the raw images from PACS, -G_PACS_RAW_PATH = "domainlab/zdata/pacs/PACS" +G_PACS_RAW_PATH = "data/pacs/PACS" # domainlab repository contain already the file names in -# domainlab/zdata/pacs_split folder of domainlab +# domainlab/zdata/pacs_split folder of domainlab, +# but PACS dataset is too big to put into domainlab folder def get_task(na=None): diff --git a/fbopt_mnist_diva_pixel.sh b/fbopt_mnist_diva_pixel.sh new file mode 100644 index 000000000..bac129db9 --- /dev/null +++ b/fbopt_mnist_diva_pixel.sh @@ -0,0 +1,7 @@ +#!/bin/bash +# export CUDA_VISIBLE_DEVICES="" +# although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error +# so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occurring +# pytest -s tests/test_fbopt.py + +python main_out.py --te_d=1 --tr_d 0 3 --task=mnistcolor10 --bs=16 --model=diva --trainer=fbopt --nname=conv_bn_pool_2 --epos=2000 --es=2000 --mu_init=0.00001 --gamma_y=1.0 --mu_clip=10 --str_diva_multiplier_type=gammad_recon_per_pixel diff --git a/pyproject.toml b/pyproject.toml index 8b158da35..964ed0b64 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ pandas = "^1.5.1" scikit-learn = "^1.2.1" pyyaml = "^6.0" gdown = "^4.7.1" +tensorboard = "^2.15.1" backpack-for-pytorch = "^1.6.0" [tool.poetry.scripts] diff --git a/requirements_notorch.txt b/requirements_notorch.txt new file mode 100644 index 000000000..5aec65167 --- /dev/null +++ b/requirements_notorch.txt @@ -0,0 +1,79 @@ +appdirs==1.4.4 ; python_version >= "3.9" and python_version < "4.0" +attrs==23.1.0 ; python_version >= "3.9" and python_version < "4.0" +beautifulsoup4==4.12.2 ; python_version >= "3.9" and python_version < "4.0" +certifi==2023.7.22 ; python_version >= "3.9" and python_version < "4.0" +charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "4.0" +colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and platform_system == "Windows" +configargparse==1.7 ; python_version >= "3.9" and python_version < "4.0" +connection-pool==0.0.3 ; python_version >= "3.9" and python_version < "4.0" +contourpy==1.1.0 ; python_version >= "3.9" and python_version < "4.0" +cycler==0.11.0 ; python_version >= "3.9" and python_version < "4.0" +datrie==0.8.2 ; python_version >= "3.9" and python_version < "4.0" +docutils==0.20.1 ; python_version >= "3.9" and python_version < "4.0" +dpath==2.1.6 ; python_version >= "3.9" and python_version < "4.0" 
+fastjsonschema==2.18.0 ; python_version >= "3.9" and python_version < "4.0" +filelock==3.12.2 ; python_version >= "3.9" and python_version < "4.0" +fonttools==4.42.0 ; python_version >= "3.9" and python_version < "4.0" +gdown==4.7.1 ; python_version >= "3.9" and python_version < "4.0" +gitdb==4.0.10 ; python_version >= "3.9" and python_version < "4.0" +gitpython==3.1.32 ; python_version >= "3.9" and python_version < "4.0" +humanfriendly==10.0 ; python_version >= "3.9" and python_version < "4.0" +idna==3.4 ; python_version >= "3.9" and python_version < "4.0" +importlib-resources==6.0.1 ; python_version >= "3.9" and python_version < "3.10" +jinja2==3.1.2 ; python_version >= "3.9" and python_version < "4.0" +joblib==1.3.2 ; python_version >= "3.9" and python_version < "4.0" +jsonschema-specifications==2023.7.1 ; python_version >= "3.9" and python_version < "4.0" +jsonschema==4.19.0 ; python_version >= "3.9" and python_version < "4.0" +jupyter-core==5.3.1 ; python_version >= "3.9" and python_version < "4.0" +kiwisolver==1.4.4 ; python_version >= "3.9" and python_version < "4.0" +markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "4.0" +markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "4.0" +matplotlib==3.7.2 ; python_version >= "3.9" and python_version < "4.0" +mdurl==0.1.2 ; python_version >= "3.9" and python_version < "4.0" +nbformat==5.9.2 ; python_version >= "3.9" and python_version < "4.0" +numpy==1.25.2 ; python_version < "4.0" and python_version >= "3.9" +packaging==23.1 ; python_version >= "3.9" and python_version < "4.0" +pandas==1.5.3 ; python_version >= "3.9" and python_version < "4.0" +pillow==9.5.0 ; python_version >= "3.9" and python_version < "4.0" +plac==1.3.5 ; python_version >= "3.9" and python_version < "4.0" +platformdirs==3.10.0 ; python_version >= "3.9" and python_version < "4.0" +psutil==5.9.5 ; python_version >= "3.9" and python_version < "4.0" +pulp==2.7.0 ; python_version >= "3.9" and python_version < "4.0" +pygments==2.16.1 ; python_version >= "3.9" and python_version < "4.0" +pyparsing==3.0.9 ; python_version >= "3.9" and python_version < "4.0" +pyreadline3==3.4.1 ; sys_platform == "win32" and python_version >= "3.9" and python_version < "4.0" +pysocks==1.7.1 ; python_version >= "3.9" and python_version < "4.0" +python-dateutil==2.8.2 ; python_version >= "3.9" and python_version < "4.0" +pytz==2023.3 ; python_version >= "3.9" and python_version < "4.0" +pywin32==306 ; sys_platform == "win32" and platform_python_implementation != "PyPy" and python_version >= "3.9" and python_version < "4.0" +pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "4.0" +referencing==0.30.2 ; python_version >= "3.9" and python_version < "4.0" +requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0" +requests[socks]==2.31.0 ; python_version >= "3.9" and python_version < "4.0" +reretry==0.11.8 ; python_version >= "3.9" and python_version < "4.0" +rich==13.5.2 ; python_version >= "3.9" and python_version < "4.0" +rpds-py==0.9.2 ; python_version >= "3.9" and python_version < "4.0" +scikit-learn==1.3.0 ; python_version >= "3.9" and python_version < "4.0" +scipy==1.9.3 ; python_version >= "3.9" and python_version < "4.0" +seaborn==0.12.2 ; python_version >= "3.9" and python_version < "4.0" +setuptools-scm==7.1.0 ; python_version >= "3.9" and python_version < "4.0" +setuptools==68.0.0 ; python_version >= "3.9" and python_version < "4.0" +six==1.16.0 ; python_version >= "3.9" and python_version < "4.0" +smart-open==6.3.0 ; python_version 
>= "3.9" and python_version < "4.0" +smmap==5.0.0 ; python_version >= "3.9" and python_version < "4.0" +snakemake==7.32.4 ; python_version >= "3.9" and python_version < "4.0" +soupsieve==2.4.1 ; python_version >= "3.9" and python_version < "4.0" +stopit==1.1.2 ; python_version >= "3.9" and python_version < "4.0" +tabulate==0.9.0 ; python_version >= "3.9" and python_version < "4.0" +threadpoolctl==3.2.0 ; python_version >= "3.9" and python_version < "4.0" +throttler==1.2.2 ; python_version >= "3.9" and python_version < "4.0" +tomli==2.0.1 ; python_version >= "3.9" and python_version < "3.11" +toposort==1.10 ; python_version >= "3.9" and python_version < "4.0" +tqdm==4.66.1 ; python_version >= "3.9" and python_version < "4.0" +tensorboard==2.14.0 ; python_version >= "3.9" and python_version < "4.0" +traitlets==5.9.0 ; python_version >= "3.9" and python_version < "4.0" +typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "4.0" +urllib3==2.0.4 ; python_version >= "3.9" and python_version < "4.0" +wrapt==1.15.0 ; python_version >= "3.9" and python_version < "4.0" +yte==1.5.1 ; python_version >= "3.9" and python_version < "4.0" +zipp==3.16.2 ; python_version >= "3.9" and python_version < "3.10" diff --git a/run_fbopt_hduva b/run_fbopt_hduva new file mode 100644 index 000000000..c1add075f --- /dev/null +++ b/run_fbopt_hduva @@ -0,0 +1 @@ +python main_out.py --te_d 0 1 2 --tr_d 3 7 --task=mnistcolor10 --bs=8 --model=hduva --trainer=fbopt --nname=conv_bn_pool_2 --gamma_y=7e5 --nname_encoder_x2topic_h=conv_bn_pool_2 --nname_encoder_sandwich_x2h4zd=conv_bn_pool_2 --gamma_y=3 --epos=2 diff --git a/script_generate_all_figures_diva.sh b/script_generate_all_figures_diva.sh new file mode 100755 index 000000000..e7f9617d2 --- /dev/null +++ b/script_generate_all_figures_diva.sh @@ -0,0 +1,76 @@ +#!/bin/bash -x -v + +STR_LOSS_ELL="loss_task/ell" +OUT_DIR="./figures_diva" +# Number of points to plot: +phase_portrait_plot_len=120 + +LOSS_GAMMA_D="$\mathbb{E}_{q_{\phi_d}(z_d|x)}[\log q_{\omega_d}(d|z_d)]$" + + +# README: +# The following scripts will check event files from the 'runs' folder of the working directory. +# To generate example tensorboard 'runs' folder, one could execute e.g. `sh run_fbopt_mnist_diva_autoki.sh` such that there will be 'runs' folder. 
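The plotting commands below pull scalar tags such as `loss_task/ell` and `lossrd/dyn_gamma_d` out of the TensorBoard event files under `runs/`. As a minimal stand-in for what `domainlab/utils/generate_fbopt_phase_portrait.py` consumes (a sketch for manual inspection, not the plotting utility itself), one run directory can be read back in Python like this:

```python
import glob

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# pick the first run directory under runs/ (one directory per training run)
run_dir = sorted(glob.glob("runs/*"))[0]

acc = EventAccumulator(run_dir)
acc.Reload()                        # parse the event file(s) in that directory

print(acc.Tags()["scalars"])        # available scalar tags, e.g. "loss_task/ell"
ell = acc.Scalars("loss_task/ell")  # list of events with .wall_time, .step, .value
steps = [e.step for e in ell]
values = [e.value for e in ell]
```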
+ +if [ -z "$1" ]; then + # Check if an argument is provided + runs_dir="runs/*" +else + # Use the provided argument + runs_dir=$1 +fi + + +# a command line argument can be passed to this script, in order to skip the first few large jumps on the phase plots; if no argument is provided then all points will be plotted: +if [ -z "$2" ]; then + # Check if an argument is provided + skip_n=0 +else + # Use the provided argument + skip_n=$2 +fi + + + + +# Phase portraits +python domainlab/utils/generate_fbopt_phase_portrait.py --runs_dir $runs_dir --plot2="lossrd/dyn_gamma_d" --plot1="loss_task/ell" --legend2="\$R_{\gamma_d}(\cdot)\$" --legend1="\$\ell(\cdot)\$" --plot_len $phase_portrait_plot_len --skip_n_steps $skip_n --output_dir=$OUT_DIR --phase_portrait + +python domainlab/utils/generate_fbopt_phase_portrait.py --runs_dir $runs_dir --plot2="lossrd/dyn_mu_recon" --plot1="loss_task/ell" --legend2="\$R_{\mu_{recon}}(\cdot)\$" --legend1="\$\ell(\cdot)\$" --plot_len $phase_portrait_plot_len --skip_n_steps $skip_n --output_dir=$OUT_DIR --phase_portrait + +python domainlab/utils/generate_fbopt_phase_portrait.py --runs_dir $runs_dir --plot2="lossrd/dyn_beta_d" --plot1="loss_task/ell" --legend2="\$R_{\beta_d}(\cdot)\$" --legend1="\$\ell(\cdot)\$" --plot_len $phase_portrait_plot_len --skip_n_steps $skip_n --output_dir=$OUT_DIR --phase_portrait + +# python domainlab/utils/generate_fbopt_phase_portrait.py --runs_dir $runs_dir --plot2="lossrd/dyn_beta_x" --plot1="loss_task/ell" --legend2="KL (beta_x)" --legend1="ell" --plot_len $phase_portrait_plot_len --skip_n_steps $skip_n --output_dir=$OUT_DIR --phase_portrait + +python domainlab/utils/generate_fbopt_phase_portrait.py --runs_dir $runs_dir --plot2="lossrd/dyn_beta_y" --plot1="loss_task/ell" --legend2="\$R_{beta_y}(\cdot)\$" --legend1="\$\ell(\cdot)\$" --plot_len $phase_portrait_plot_len --skip_n_steps $skip_n --output_dir=$OUT_DIR --phase_portrait + + + + +# Plot R and the corresponding set point curves (both in the same figure) +python domainlab/utils/generate_fbopt_phase_portrait.py --runs_dir $runs_dir --plot1="lossrd/dyn_gamma_d" --plot2="lossrs/setpoint_gamma_d" --legend1="\$R_{\gamma_d}\$" --legend2="setpoint" --output_dir=$OUT_DIR + +python domainlab/utils/generate_fbopt_phase_portrait.py --runs_dir $runs_dir --plot1="lossrd/dyn_mu_recon" --plot2="lossrs/setpoint_mu_recon" --legend1="\$R_{\mu_{recon}}(\cdot)\$" --legend2="setpoint" --output_dir=$OUT_DIR + +python domainlab/utils/generate_fbopt_phase_portrait.py --runs_dir $runs_dir --plot1="lossrd/dyn_beta_d" --plot2="lossrs/setpoint_beta_d" --legend1="\$R_{\beta_d}(\cdot)\$" --legend2="setpoint" --output_dir=$OUT_DIR --neg + +# python domainlab/utils/generate_fbopt_phase_portrait.py --runs_dir $runs_dir --plot1="lossrd/dyn_beta_x" --plot2="lossrs/setpoint_beta_x" --legend1="KL (beta_x)" --legend2="setpoint" --output_dir=$OUT_DIR + +python domainlab/utils/generate_fbopt_phase_portrait.py --runs_dir $runs_dir --plot1="lossrd/dyn_beta_y" --plot2="lossrs/setpoint_beta_y" --legend1="\$R_{\beta_y}(\cdot)\$" --legend2="setpoint" --output_dir=$OUT_DIR --neg + + + # One curve per figure + values=('controller_gain/beta_d' 'controller_gain/beta_y' 'controller_gain/beta_x' 'controller_gain/gamma_d' 'controller_gain/mu_recon' 'dyn_mu/beta_d' 'delta/beta_d' 'dyn_mu/beta_y' 'delta/beta_y' 'dyn_mu/beta_x' 'delta/beta_x' 'dyn_mu/gamma_d' 'delta/gamma_d' 'dyn_mu/mu_recon' 'delta/mu_recon' 'loss_task/penalized' 'loss_task/ell' 'acc/te' 'acc/val' 'acc/sel' 'acc/setpoint') + # Loop over the array + for val in 
"${values[@]}" + do + python domainlab/utils/generate_fbopt_phase_portrait.py --runs_dir $runs_dir --plot1="$val" --legend1="$val" --output_dir=$OUT_DIR + done + + +python domainlab/utils/generate_fbopt_phase_portrait.py --runs_dir $runs_dir --plot1="dyn_mu/mu_recon" --legend1="\$\mu_{recon}\$" --output_dir=$OUT_DIR +python domainlab/utils/generate_fbopt_phase_portrait.py --runs_dir $runs_dir --plot1="dyn_mu/gamma_d" --legend1="\$\gamma_d\$" --output_dir=$OUT_DIR +python domainlab/utils/generate_fbopt_phase_portrait.py --runs_dir $runs_dir --plot1="dyn_mu/beta_y" --legend1="\$\beta_y\$" --output_dir=$OUT_DIR +python domainlab/utils/generate_fbopt_phase_portrait.py --runs_dir $runs_dir --plot1="loss_task/ell" --legend1="\$\ell(\cdot)\$" --output_dir=$OUT_DIR +python domainlab/utils/generate_fbopt_phase_portrait.py --runs_dir $runs_dir --plot1="loss_task/penalized" --legend1="\$\ell(\cdot)+\mu^TR(\cdot)\$" --output_dir=$OUT_DIR +python domainlab/utils/generate_fbopt_phase_portrait.py --runs_dir $runs_dir --plot1="controller_gain/beta_y" --legend1="controller gain for \$\beta_y\$" --output_dir=$OUT_DIR diff --git a/script_jigen_plot.sh b/script_jigen_plot.sh new file mode 100755 index 000000000..5c47a68f8 --- /dev/null +++ b/script_jigen_plot.sh @@ -0,0 +1,4 @@ +python domainlab/utils/generate_fbopt_phase_portrait.py --plot2="lossrd/dyn_alpha" --plot1="loss_task/ell" --legend2="regularization loss jigen" --legend1="classification loss" --output_dir="." --phase_portrait + + +python domainlab/utils/generate_fbopt_phase_portrait.py --plot1="lossrs/setpoint_alpha" --plot2="lossrd/dyn_alpha" --legend2="regularization loss jigen" --legend1="setpoint" --output_dir="." diff --git a/scripts/sh_link_pacs_dataset.sh b/scripts/sh_link_pacs_dataset.sh index f973db02b..60823522b 100644 --- a/scripts/sh_link_pacs_dataset.sh +++ b/scripts/sh_link_pacs_dataset.sh @@ -1,2 +1,5 @@ mkdir -p domainlab/zdata/pacs ln -s ~/Documents/datasets/pacs/raw domainlab/zdata/pacs/PACS + +mkdir -p data/pacs +ln -s ~/Documents/datasets/pacs/raw data/pacs/PACS diff --git a/scripts_fbopt/run_erm.sh b/scripts_fbopt/run_erm.sh new file mode 100644 index 000000000..f5285811f --- /dev/null +++ b/scripts_fbopt/run_erm.sh @@ -0,0 +1 @@ +python main_out.py --te_d=1 --tr_d 0 3 --task=mnistcolor10 --bs=16 --model=erm --nname=conv_bn_pool_2 --epos=10 diff --git a/scripts_fbopt/run_fbopt_dann.sh b/scripts_fbopt/run_fbopt_dann.sh new file mode 100644 index 000000000..c75fb071c --- /dev/null +++ b/scripts_fbopt/run_fbopt_dann.sh @@ -0,0 +1 @@ +python main_out.py --te_d=caltech --task=mini_vlcs --bs=16 --model=dann --trainer=fbopt --nname=alexnet --epos=200 --es=200 --no_setpoint_update diff --git a/scripts_fbopt/run_fbopt_diva.sh b/scripts_fbopt/run_fbopt_diva.sh new file mode 100644 index 000000000..dc48bce9b --- /dev/null +++ b/scripts_fbopt/run_fbopt_diva.sh @@ -0,0 +1,6 @@ +#!/bin/bash +# export CUDA_VISIBLE_DEVICES="" +# although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error +# so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occuring +# pytest -s tests/test_fbopt.py +python main_out.py --te_d=caltech --task=mini_vlcs --bs=8 --model=diva --trainer=fbopt --nname=alexnet --nname_dom=alexnet --gamma_d=3 --gamma_y=3 --epos=200 diff --git a/scripts_fbopt/run_fbopt_diva_cpu.sh b/scripts_fbopt/run_fbopt_diva_cpu.sh new file mode 100644 index 000000000..59d0c592a --- /dev/null +++ b/scripts_fbopt/run_fbopt_diva_cpu.sh @@ -0,0 +1,6 @@ 
+#!/bin/bash +export CUDA_VISIBLE_DEVICES="" +# although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error +# so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occuring +# pytest -s tests/test_fbopt.py +python main_out.py --te_d=caltech --task=mini_vlcs --bs=8 --model=diva --trainer=fbopt --nname=alexnet --nname_dom=alexnet --gamma_d=3 --gamma_y=3 --epos=200 --es=100 diff --git a/scripts_fbopt/run_fbopt_hduva_cpu.sh b/scripts_fbopt/run_fbopt_hduva_cpu.sh new file mode 100644 index 000000000..54b7d5995 --- /dev/null +++ b/scripts_fbopt/run_fbopt_hduva_cpu.sh @@ -0,0 +1,2 @@ +export CUDA_VISIBLE_DEVICES="" +python main_out.py --te_d 0 1 2 --tr_d 3 7 --task=mnistcolor10 --bs=8 --model=hduva --trainer=fbopt --nname=conv_bn_pool_2 --gamma_y=7e5 --nname_encoder_x2topic_h=conv_bn_pool_2 --nname_encoder_sandwich_x2h4zd=conv_bn_pool_2 --gamma_y=3 --epos=2 diff --git a/scripts_fbopt/run_fbopt_match_diva.sh b/scripts_fbopt/run_fbopt_match_diva.sh new file mode 100644 index 000000000..c1547567c --- /dev/null +++ b/scripts_fbopt/run_fbopt_match_diva.sh @@ -0,0 +1,6 @@ +#!/bin/bash +export CUDA_VISIBLE_DEVICES="" +# although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error +# so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occuring +# pytest -s tests/test_fbopt.py +python main_out.py --te_d=caltech --task=mini_vlcs --bs=8 --model=diva --trainer=fbopt_matchdg --nname=alexnet --nname_dom=alexnet --gamma_d=3 --gamma_y=3 --epos=200 --es=100 diff --git a/scripts_fbopt/run_fbopt_mnist.sh b/scripts_fbopt/run_fbopt_mnist.sh new file mode 100644 index 000000000..2e3edc424 --- /dev/null +++ b/scripts_fbopt/run_fbopt_mnist.sh @@ -0,0 +1,7 @@ +#!/bin/bash +# export CUDA_VISIBLE_DEVICES="" +# although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error +# so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occuring +# pytest -s tests/test_fbopt.py + +python main_out.py --te_d=0 --tr_d 1 2 --task=mnistcolor10 --bs=16 --model=jigen --trainer=fbopt --nname=conv_bn_pool_2 --epos=2000 --es=0 --mu_init=0.00001 --coeff_ma_setpoint=0.5 --coeff_ma_output_state=0.99 --force_setpoint_change_once diff --git a/scripts_fbopt/run_fbopt_mnist_diva.sh b/scripts_fbopt/run_fbopt_mnist_diva.sh new file mode 100644 index 000000000..fd5c2b8cf --- /dev/null +++ b/scripts_fbopt/run_fbopt_mnist_diva.sh @@ -0,0 +1,7 @@ +#!/bin/bash +# export CUDA_VISIBLE_DEVICES="" +# although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error +# so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occuring +# pytest -s tests/test_fbopt.py + +python main_out.py --te_d=0 --tr_d 1 2 --task=mnistcolor10 --bs=16 --model=diva --trainer=fbopt --nname=conv_bn_pool_2 --epos=2000 --es=500 --mu_init=0.000001 --gamma_y=1.0 diff --git a/scripts_fbopt/run_fbopt_mnist_diva_autoki.sh b/scripts_fbopt/run_fbopt_mnist_diva_autoki.sh new file mode 100644 index 000000000..64c19e102 --- /dev/null +++ b/scripts_fbopt/run_fbopt_mnist_diva_autoki.sh @@ -0,0 +1,7 @@ +#!/bin/bash +# export CUDA_VISIBLE_DEVICES="" +# although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error +# so it is better not to use GPU to do the pytest to ensure every 
time there is no CUDA out of memory error occuring +# pytest -s tests/test_fbopt.py + +python main_out.py --te_d=0 --tr_d 1 2 --task=mnistcolor10 --bs=16 --model=diva --trainer=fbopt --nname=conv_bn_pool_2 --epos=5000 --es=5 --mu_init=1e-6 --gamma_y=1.0 --k_i_gain_ratio=0.9 --coeff_ma_output_state=0 --coeff_ma_setpoint=0 --epos_min=1000 --force_setpoint_change_once diff --git a/scripts_fbopt/run_fbopt_mnist_feedforward.sh b/scripts_fbopt/run_fbopt_mnist_feedforward.sh new file mode 100644 index 000000000..b04819c61 --- /dev/null +++ b/scripts_fbopt/run_fbopt_mnist_feedforward.sh @@ -0,0 +1,7 @@ +#!/bin/bash +# export CUDA_VISIBLE_DEVICES="" +# although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error +# so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occuring +# pytest -s tests/test_fbopt.py + +python main_out.py --te_d=0 --tr_d 1 2 --task=mnistcolor10 --bs=16 --model=jigen --trainer=fbopt --nname=conv_bn_pool_2 --epos=2000 --epos_min=100 --es=1 --force_feedforward diff --git a/scripts_fbopt/run_fbopt_mnist_jigen_autoki.sh b/scripts_fbopt/run_fbopt_mnist_jigen_autoki.sh new file mode 100644 index 000000000..8b346e011 --- /dev/null +++ b/scripts_fbopt/run_fbopt_mnist_jigen_autoki.sh @@ -0,0 +1,7 @@ +#!/bin/bash +# export CUDA_VISIBLE_DEVICES="" +# although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error +# so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occuring +# pytest -s tests/test_fbopt.py + +python main_out.py --te_d=0 --tr_d 1 2 --task=mnistcolor10 --bs=16 --model=jigen --trainer=fbopt --nname=conv_bn_pool_2 --epos=2000 --es=1 --epos_min=500 --mu_init=1e-6 --coeff_ma_output_state=0.99 --k_i_gain_ratio=0.99 diff --git a/scripts_fbopt/run_fbopt_small_pacs.sh b/scripts_fbopt/run_fbopt_small_pacs.sh new file mode 100644 index 000000000..fc3ab6bc7 --- /dev/null +++ b/scripts_fbopt/run_fbopt_small_pacs.sh @@ -0,0 +1,6 @@ +#!/bin/bash +# export CUDA_VISIBLE_DEVICES="" +# although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error +# so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occuring +# pytest -s tests/test_fbopt.py +python main_out.py --te_d=sketch --tpath=examples/tasks/demo_task_path_list_small.py --bs=16 --model=jigen --trainer=fbopt --nname=alexnet --epos=200 --es=100 --init_mu=0.01 diff --git a/scripts_fbopt/run_mnist_jigen.sh b/scripts_fbopt/run_mnist_jigen.sh new file mode 100644 index 000000000..0bc854c5e --- /dev/null +++ b/scripts_fbopt/run_mnist_jigen.sh @@ -0,0 +1,7 @@ +#!/bin/bash +# export CUDA_VISIBLE_DEVICES="" +# although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error +# so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occuring +# pytest -s tests/test_fbopt.py + +python main_out.py --te_d=caltech --task=mini_vlcs --bs=16 --model=jigen --trainer=fbopt --nname=alexnet --epos=200 --es=200 --mu_init=1.0 --coeff_ma_output=0 --coeff_ma_setpoint=0 --coeff_ma_output=0 diff --git a/scripts_fbopt/run_pacs_diva_fbopt.sh b/scripts_fbopt/run_pacs_diva_fbopt.sh new file mode 100644 index 000000000..74d1f0cd3 --- /dev/null +++ b/scripts_fbopt/run_pacs_diva_fbopt.sh @@ -0,0 +1,6 @@ +#!/bin/bash +# export CUDA_VISIBLE_DEVICES="" +# although garbage collector has been 
explicitly called, sometimes there is still CUDA out of memory error +# so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occuring +# pytest -s tests/test_fbopt.py +python main_out.py --te_d=sketch --bs=32 --model=diva --trainer=fbopt --epos=200 --es=200 --npath_dom=examples/nets/resnet50domainbed.py --tpath=examples/tasks/task_pacs_path_list.py --npath=examples/nets/resnet50domainbed.py --gamma_y=1.0 --mu_init=1e-6 --lr=5e-5 --zx_dim=0 diff --git a/scripts_fbopt/run_pacs_jigen_fbopt.sh b/scripts_fbopt/run_pacs_jigen_fbopt.sh new file mode 100644 index 000000000..99663ee61 --- /dev/null +++ b/scripts_fbopt/run_pacs_jigen_fbopt.sh @@ -0,0 +1,6 @@ +#!/bin/bash +# export CUDA_VISIBLE_DEVICES="" +# although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error +# so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occuring +# pytest -s tests/test_fbopt.py +python main_out.py --te_d=sketch --tpath=examples/tasks/task_pacs_path_list.py --model=jigen --trainer=fbopt --bs=64 --epos=200 --es=200 --npath=examples/nets/resnet50domainbed.py --mu_init=1e-6 --lr=5e-5 --coeff_ma_output_state=0.1 diff --git a/test_fbopt_dial.sh b/test_fbopt_dial.sh new file mode 100644 index 000000000..4bf0c669b --- /dev/null +++ b/test_fbopt_dial.sh @@ -0,0 +1,2 @@ +export CUDA_VISIBLE_DEVICES="" +python main_out.py --te_d=caltech --task=mini_vlcs --bs=16 --model=fboptdial --trainer=dial --nname=alexnet --nname_dom=alexnet --gamma_y=1e6 --gamma_d=1e6 diff --git a/test_match_duva.sh b/test_match_duva.sh new file mode 100644 index 000000000..9f3e9951e --- /dev/null +++ b/test_match_duva.sh @@ -0,0 +1,4 @@ +python main_out.py --te_d 0 1 2 --tr_d 3 7 --task=mnistcolor10 --debug --bs=2 --model=matchhduva \ + --epochs_ctr=3 --epos=6 --nname=conv_bn_pool_2 --gamma_y=7e5 \ + --nname_encoder_x2topic_h=conv_bn_pool_2 \ + --nname_encoder_sandwich_x2h4zd=conv_bn_pool_2 diff --git a/test_match_duva_vlcs.sh b/test_match_duva_vlcs.sh new file mode 100644 index 000000000..a47e76c36 --- /dev/null +++ b/test_match_duva_vlcs.sh @@ -0,0 +1,4 @@ +python main_out.py --te_d=caltech --task=mini_vlcs --debug --bs=2 --model=matchhduva \ + --epochs_ctr=3 --epos=6 --npath=examples/nets/resnet.py --gamma_y=7e5 \ + --npath_encoder_x2topic_h=examples/nets/resnet.py \ + --npath_encoder_sandwich_x2h4zd=examples/nets/resnet.py diff --git a/tests/test_fbopt.py b/tests/test_fbopt.py new file mode 100644 index 000000000..15306ebf4 --- /dev/null +++ b/tests/test_fbopt.py @@ -0,0 +1,29 @@ +""" +unit and end-end test for deep all, mldg +""" +import pytest +from tests.utils_test import utils_test_algo + + +def test_dann_fbopt(): + """ + dann + """ + args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=dann --trainer=fbopt --nname=alexnet --epos=3 --no_dump" + utils_test_algo(args) + + +def test_jigen_fbopt(): + """ + jigen + """ + args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=jigen --trainer=fbopt --nname=alexnet --epos=3 --no_dump" + utils_test_algo(args) + + +def test_diva_fbopt(): + """ + diva + """ + args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=diva --gamma_y=1.0 --trainer=fbopt --nname=alexnet --epos=3 --no_dump" + utils_test_algo(args) diff --git a/tests/test_fbopt2.py b/tests/test_fbopt2.py new file mode 100644 index 000000000..746924386 --- /dev/null +++ b/tests/test_fbopt2.py @@ -0,0 +1,37 @@ +""" +unit and end-end test for deep all, mldg +""" +import pytest 
+from tests.utils_test import utils_test_algo + + +def test_erm_fbopt(): + """ + erm + """ + args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm \ + --trainer=fbopt --nname=alexnet --epos=3 \ + --no_dump" # pylint: disable=line-too-long + with pytest.raises(RuntimeError): + utils_test_algo(args) + + +def test_irm_fbopt(): + """ + irm + """ + args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm \ + --trainer=fbopt_irm --nname=alexnet --epos=3 \ + --no_dump" # pylint: disable=line-too-long + utils_test_algo(args) + + +def test_forcesetpoint_fbopt(): + """ + diva + """ + args = "--te_d=0 --tr_d 1 2 --task=mnistcolor10 --bs=16 --model=jigen \ + --trainer=fbopt --nname=conv_bn_pool_2 --epos=10 --es=0 \ + --mu_init=0.00001 --coeff_ma_setpoint=0.5 \ + --coeff_ma_output_state=0.99 --force_setpoint_change_once --no_dump" + utils_test_algo(args) diff --git a/tests/test_fbopt_irm.py b/tests/test_fbopt_irm.py new file mode 100644 index 000000000..10cfbcb53 --- /dev/null +++ b/tests/test_fbopt_irm.py @@ -0,0 +1,14 @@ +""" + end-end test +""" +from tests.utils_test import utils_test_algo + + +def test_mhof_irm(): + """ + mhof-irm + """ + args = "--te_d=0 --task=mnistcolor10 --model=erm \ + --trainer=fbopt_irm --nname=conv_bn_pool_2 \ + --k_i_gain_ratio=0.5 --no_dump" + utils_test_algo(args) diff --git a/tests/test_fbopt_setpoint_ada.py b/tests/test_fbopt_setpoint_ada.py new file mode 100644 index 000000000..4b8029056 --- /dev/null +++ b/tests/test_fbopt_setpoint_ada.py @@ -0,0 +1,9 @@ +from domainlab.algos.trainers.fbopt_setpoint_ada import is_less_list_all + + +def test_less_than(): + a = [3, 4, -9, -8] + b = [1, 0.5, -1, -0.5] + c = [0.5, 0.25, -0.5, -0.25] + assert not is_less_list_all(a, b) + assert is_less_list_all(c, b) diff --git a/tests/test_fbopt_setpoint_rewind.py b/tests/test_fbopt_setpoint_rewind.py new file mode 100644 index 000000000..3fcc8660e --- /dev/null +++ b/tests/test_fbopt_setpoint_rewind.py @@ -0,0 +1,12 @@ +""" +unit and end-end test for deep all, mldg +""" +from tests.utils_test import utils_test_algo + + +def test_jigen_fbopt(): + """ + jigen + """ + args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=jigen --trainer=fbopt --nname=alexnet --epos=300 --setpoint_rewind --no_dump" + utils_test_algo(args) diff --git a/tests/test_irm.py b/tests/test_irm.py index 235b9e4ce..2cc736374 100644 --- a/tests/test_irm.py +++ b/tests/test_irm.py @@ -9,16 +9,26 @@ def test_irm(): train with Invariant Risk Minimization """ args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm \ - --trainer=irm --nname=alexnet" + --trainer=irm --nname=alexnet --no_dump" utils_test_algo(args) +def test_irm_sepdom(): + """ + train with Invariant Risk Minimization + """ + args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm \ + --trainer=irmsepdom --nname=alexnet --no_dump" + utils_test_algo(args) + + + def test_irm_scheduler(): """ train with Invariant Risk Minimization """ args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm \ - --trainer=hyperscheduler_irm --nname=alexnet" + --trainer=hyperscheduler_irm --nname=alexnet --no_dump" utils_test_algo(args) @@ -29,5 +39,5 @@ def test_irm_mnist(): train with Invariant Risk Minimization """ args = "--te_d=0 --task=mnistcolor10 --keep_model --model=erm \ - --trainer=irm --nname=conv_bn_pool_2" + --trainer=irm --nname=conv_bn_pool_2 --no_dump" utils_test_algo(args) diff --git a/tests/test_mk_exp_jigen.py b/tests/test_mk_exp_jigen.py index 21c533489..a8d7c507d 100644 --- 
a/tests/test_mk_exp_jigen.py +++ b/tests/test_mk_exp_jigen.py @@ -19,7 +19,6 @@ def test_mk_exp_jigen(): """ mk_exp_jigen(trainer="mldg") - mk_exp_jigen(trainer="dial") def mk_exp_jigen(trainer="mldg"):
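For completeness: the assertions in `tests/test_fbopt_setpoint_ada.py` above pin down `is_less_list_all` only through its behaviour on two example lists, and a plain element-wise `<` would fail the second assertion, since the negative entries of `c` are numerically larger than those of `b`. One implementation consistent with both assertions — a guess at the intended contract, not necessarily DomainLab's code — compares magnitudes:

```python
def is_less_list_all(list_a, list_b):
    """True only if every entry of list_a is smaller in magnitude
    than the matching entry of list_b."""
    return all(abs(a) < abs(b) for a, b in zip(list_a, list_b))

# matches the test: is_less_list_all([3, 4, -9, -8], [1, 0.5, -1, -0.5]) is False,
# while is_less_list_all([0.5, 0.25, -0.5, -0.25], [1, 0.5, -1, -0.5]) is True
```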