diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py
index 670f5e3fa..cbc57cc77 100644
--- a/domainlab/algos/trainers/a_trainer.py
+++ b/domainlab/algos/trainers/a_trainer.py
@@ -45,6 +45,9 @@ def __init__(self, successor_node=None):
         self.epo_loss_tr = None
         self.epo_reg_loss_tr = None
         self.epo_task_loss_tr = None
+        self.batch_loss_tr = None
+        self.batch_reg_loss_tr = None
+        self.batch_task_loss_tr = None
         self.counter_batch = None
         self.hyper_scheduler = None
         self.optimizer = None
diff --git a/domainlab/algos/trainers/train_fbopt_b.py b/domainlab/algos/trainers/train_fbopt_b.py
index 9bc64785e..2050cb8a6 100644
--- a/domainlab/algos/trainers/train_fbopt_b.py
+++ b/domainlab/algos/trainers/train_fbopt_b.py
@@ -76,8 +76,15 @@ def before_batch(self, epoch, ind_batch):
         should be set to epoch*self.num_batches + ind_batch
         """
         if self.flag_update_hyper_per_batch:
+            self.hyper_scheduler.search_mu(
+                self.epo_reg_loss_tr,
+                self.epo_task_loss_tr,
+                self.epo_loss_tr,
+                self.list_str_multiplier_na,
+                miter=epoch)
+            self.set_model_with_mu()
             # NOTE: if not update per_batch, then not updated
-            self.model.hyper_update(epoch * self.num_batches + ind_batch, self.hyper_scheduler)
+            # self.model.hyper_update(epoch * self.num_batches + ind_batch, self.hyper_scheduler)
         return super().after_batch(epoch, ind_batch)

     def before_tr(self):
@@ -117,6 +124,8 @@ def tr_epoch(self, epoch, flag_info=False):
         """
         update multipliers only per epoch
         """
+        if not self.flag_update_hyper_per_batch:
+            return super().tr_epoch(epoch)
         self.hyper_scheduler.search_mu(
             self.epo_reg_loss_tr,
             self.epo_task_loss_tr,
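The diff moves the multiplier search out of the per-epoch path: when `flag_update_hyper_per_batch` is set, `before_batch` now asks the scheduler to search for new multipliers from the running epoch-level losses and pushes them into the model via `set_model_with_mu()`, replacing the old `model.hyper_update(...)` call. Below is a minimal, self-contained sketch of that control flow. The scheduler internals (`ToyScheduler` and its simple feedback rule) are hypothetical stand-ins invented for illustration; only the method names (`search_mu`, `set_model_with_mu`), the flag, and the attributes passed in come from the diff itself.

```python
class ToyScheduler:
    """Hypothetical scheduler: nudges a single multiplier mu up while the
    regularization loss sits above a setpoint, down otherwise."""

    def __init__(self, mu_init=0.01, setpoint=1.0, factor=1.05):
        self.mu = mu_init
        self.setpoint = setpoint
        self.factor = factor

    def search_mu(self, reg_loss, task_loss, total_loss, names, miter=0):
        # simplified feedback rule; the real search is more involved
        if reg_loss > self.setpoint:
            self.mu *= self.factor
        else:
            self.mu /= self.factor


class ToyTrainer:
    """Sketch of the per-batch update path introduced by the diff."""

    def __init__(self, flag_update_hyper_per_batch=True):
        self.flag_update_hyper_per_batch = flag_update_hyper_per_batch
        self.hyper_scheduler = ToyScheduler()
        # running epoch-level losses (placeholder values)
        self.epo_reg_loss_tr = 2.0
        self.epo_task_loss_tr = 1.0
        self.epo_loss_tr = 3.0
        self.list_str_multiplier_na = ["mu"]
        self.mu_in_model = None

    def set_model_with_mu(self):
        # push the scheduler's current multiplier into the model
        self.mu_in_model = self.hyper_scheduler.mu

    def before_batch(self, epoch, ind_batch):
        # per-batch path: search for new multipliers, then apply them,
        # instead of the old model.hyper_update(...) call
        if self.flag_update_hyper_per_batch:
            self.hyper_scheduler.search_mu(
                self.epo_reg_loss_tr,
                self.epo_task_loss_tr,
                self.epo_loss_tr,
                self.list_str_multiplier_na,
                miter=epoch)
            self.set_model_with_mu()


trainer = ToyTrainer()
trainer.before_batch(epoch=0, ind_batch=0)
print(trainer.mu_in_model)  # mu nudged upward since reg loss > setpoint
```

The design consequence: with the flag set, the multipliers are refreshed from the running epoch-level losses before every batch rather than once per epoch, while `tr_epoch` is bypassed via the early `return super().tr_epoch(epoch)` guard added in the last hunk.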