From 127b2ae79c8f5fd64a3ebc70423198de8d647382 Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Thu, 12 Oct 2023 23:13:54 +0200 Subject: [PATCH 1/3] Update train_mu_controller.py --- domainlab/algos/trainers/train_mu_controller.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/domainlab/algos/trainers/train_mu_controller.py b/domainlab/algos/trainers/train_mu_controller.py index f92632619..1b2447763 100644 --- a/domainlab/algos/trainers/train_mu_controller.py +++ b/domainlab/algos/trainers/train_mu_controller.py @@ -73,8 +73,14 @@ def before_batch(self, epoch, ind_batch): should be set to epoch*self.num_batches + ind_batch """ if self.flag_update_hyper_per_batch: - # NOTE: if not update per_batch, then not updated - self.model.hyper_update(epoch*self.num_batches + ind_batch, self.hyper_scheduler) + self.hyper_scheduler.search_mu( + self.batch_reg_loss_tr, + self.batch_task_loss_tr, + self.batch_loss_tr, + self.list_str_multiplier_na, + dict(self.model.named_parameters()), + miter=epoch) + self.set_model_with_mu() return super().after_batch(epoch, ind_batch) def before_tr(self): From 6de2f35e04376ee8502b4a7f5a535176e3a86004 Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Thu, 12 Oct 2023 23:18:36 +0200 Subject: [PATCH 2/3] Update a_trainer.py --- domainlab/algos/trainers/a_trainer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py index 4d6f2c93f..9ee28b569 100644 --- a/domainlab/algos/trainers/a_trainer.py +++ b/domainlab/algos/trainers/a_trainer.py @@ -45,6 +45,9 @@ def __init__(self, successor_node=None): self.epo_loss_tr = None self.epo_reg_loss_tr = None self.epo_task_loss_tr = None + self.batch_loss_tr = None + self.batch_reg_loss_tr = None + self.batch_task_loss_tr = None self.counter_batch = None self.hyper_scheduler = None self.optimizer = None From cb351f7b4911e2befc14b5d02d5a61700bd084ae Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Thu, 12 Oct 2023 23:20:55 +0200 Subject: [PATCH 3/3] Update train_mu_controller.py --- domainlab/algos/trainers/train_mu_controller.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/domainlab/algos/trainers/train_mu_controller.py b/domainlab/algos/trainers/train_mu_controller.py index 1b2447763..762f34290 100644 --- a/domainlab/algos/trainers/train_mu_controller.py +++ b/domainlab/algos/trainers/train_mu_controller.py @@ -105,6 +105,8 @@ def tr_epoch(self, epoch): """ update multipliers only per epoch """ + if not self.flag_update_hyper_per_batch: + return super().tr_epoch(epoch) self.hyper_scheduler.search_mu( self.epo_reg_loss_tr, self.epo_task_loss_tr,