Skip to content
12 changes: 1 addition & 11 deletions domainlab/algos/observers/b_obvisitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,17 +117,7 @@ def cal_oracle_perf(self):
"""
calculate oracle performance
"""
try:
model_or = self.host_trainer.model.load("oracle")
# @FIXME: name "oracle is a strong dependency
model_or = model_or.to(self.device)
model_or.eval()
except FileNotFoundError:
return {"acc_oracle": -1}
logger = Logger.get_logger()
logger.info("oracle model performance metric: \n")
metric_te = model_or.cal_perf_metric(self.loader_te, self.device)
return {"acc_oracle": metric_te["acc"]}
return self.exp.cal_oracle_perf()

def dump_prediction(self, model_ld, metric_te):
"""
Expand Down
17 changes: 17 additions & 0 deletions domainlab/exp/exp_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,20 @@ def clean_up(self):
self.visitor.remove("ctr") # pylint: disable=E1101
except FileNotFoundError:
pass

def cal_oracle_perf(self):
"""
calculate oracle performance
"""
try:
model_or = self.visitor.load("oracle")
# @FIXME: name "oracle is a strong dependency
model_or = model_or.to(self.device)
model_or.eval()
except FileNotFoundError:
return {"acc_oracle": -1}
logger = Logger.get_logger()
logger.info("oracle model performance metric: \n")
metric_te = model_or.cal_perf_metric(self.observer.loader_te,
self.device)
return {"acc_oracle": metric_te["acc"]}