diff --git a/domainlab/algos/observers/b_obvisitor.py b/domainlab/algos/observers/b_obvisitor.py index 432da9703..d4ebc75d3 100644 --- a/domainlab/algos/observers/b_obvisitor.py +++ b/domainlab/algos/observers/b_obvisitor.py @@ -117,17 +117,7 @@ def cal_oracle_perf(self): """ calculate oracle performance """ - try: - model_or = self.host_trainer.model.load("oracle") - # @FIXME: name "oracle is a strong dependency - model_or = model_or.to(self.device) - model_or.eval() - except FileNotFoundError: - return {"acc_oracle": -1} - logger = Logger.get_logger() - logger.info("oracle model performance metric: \n") - metric_te = model_or.cal_perf_metric(self.loader_te, self.device) - return {"acc_oracle": metric_te["acc"]} + return self.exp.cal_oracle_perf() def dump_prediction(self, model_ld, metric_te): """ diff --git a/domainlab/exp/exp_main.py b/domainlab/exp/exp_main.py index 12ecea5f5..29dc17be6 100755 --- a/domainlab/exp/exp_main.py +++ b/domainlab/exp/exp_main.py @@ -132,3 +132,20 @@ def clean_up(self): self.visitor.remove("ctr") # pylint: disable=E1101 except FileNotFoundError: pass + + def cal_oracle_perf(self): + """ + calculate oracle performance + """ + try: + model_or = self.visitor.load("oracle") + # @FIXME: name "oracle is a strong dependency + model_or = model_or.to(self.device) + model_or.eval() + except FileNotFoundError: + return {"acc_oracle": -1} + logger = Logger.get_logger() + logger.info("oracle model performance metric: \n") + metric_te = model_or.cal_perf_metric(self.observer.loader_te, + self.device) + return {"acc_oracle": metric_te["acc"]}