From abb185ffc22e908094bec4ac4f2e8d664c650b66 Mon Sep 17 00:00:00 2001 From: smilesun Date: Wed, 25 Oct 2023 10:22:59 +0200 Subject: [PATCH] orcale model load move from observer to exp --- domainlab/algos/observers/b_obvisitor.py | 14 ++------------ domainlab/compos/exp/exp_main.py | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/domainlab/algos/observers/b_obvisitor.py b/domainlab/algos/observers/b_obvisitor.py index 7fec2eccc..9f23db9e8 100644 --- a/domainlab/algos/observers/b_obvisitor.py +++ b/domainlab/algos/observers/b_obvisitor.py @@ -112,17 +112,7 @@ def cal_oracle_perf(self): """ calculate oracle performance """ - try: - model_or = self.exp.visitor.load("oracle") - # @FIXME: name "oracle is a strong dependency - model_or = model_or.to(self.device) - model_or.eval() - except FileNotFoundError: - return {"acc_oracle": -1} - logger = Logger.get_logger() - logger.info("oracle model performance metric: \n") - metric_te = model_or.cal_perf_metric(self.loader_te, self.device) - return {"acc_oracle": metric_te["acc"]} + return self.exp.cal_oracle_perf() def dump_prediction(self, model_ld, metric_te): """ @@ -152,4 +142,4 @@ def clean_up(self): to be called by a decorator """ if not self.keep_model: - self.exp.clean_up() \ No newline at end of file + self.exp.clean_up() diff --git a/domainlab/compos/exp/exp_main.py b/domainlab/compos/exp/exp_main.py index 74ffe86da..aa5118219 100755 --- a/domainlab/compos/exp/exp_main.py +++ b/domainlab/compos/exp/exp_main.py @@ -40,6 +40,8 @@ def __init__(self, args, task=None, model=None, observer=None, visitor=AggWriter # jigen algorithm builder has method dset_decoration_args_algo, which could AOP # into the task intilization process self.trainer, self.model, observer_default, device = algo_builder.init_business(self) + self.device = device + self.observer = observer if model is not None: self.model = model self.visitor = visitor(self) # visitor depends on task initialization first @@ -118,3 +120,20 @@ def clean_up(self): self.visitor.remove("ctr") # pylint: disable=E1101 except FileNotFoundError: pass + + def cal_oracle_perf(self): + """ + calculate oracle performance + """ + try: + model_or = self.visitor.load("oracle") + # @FIXME: name "oracle is a strong dependency + model_or = model_or.to(self.device) + model_or.eval() + except FileNotFoundError: + return {"acc_oracle": -1} + logger = Logger.get_logger() + logger.info("oracle model performance metric: \n") + metric_te = model_or.cal_perf_metric(self.observer.loader_te, + self.device) + return {"acc_oracle": metric_te["acc"]}