diff --git a/nvflare/app_opt/pt/model_persistence_format_manager.py b/nvflare/app_opt/pt/model_persistence_format_manager.py index 61f6856732..2f1f6a2fe4 100644 --- a/nvflare/app_opt/pt/model_persistence_format_manager.py +++ b/nvflare/app_opt/pt/model_persistence_format_manager.py @@ -53,23 +53,21 @@ def __init__(self, data: dict, default_train_conf=None, allow_numpy_conversion=T self.train_conf = None self.other_props = {} # other props from the original data that need to be kept - if self.PERSISTENCE_KEY_MODEL not in data: + # Cache attr accesses for performance + pk_model = self.PERSISTENCE_KEY_MODEL + pk_meta = self.PERSISTENCE_KEY_META_PROPS + pk_conf = self.PERSISTENCE_KEY_TRAIN_CONF + + if pk_model not in data: # this is a simple weight dict self.var_dict = data else: # dict of dicts - self.var_dict = data[self.PERSISTENCE_KEY_MODEL] - self.meta = data.get(self.PERSISTENCE_KEY_META_PROPS, None) - self.train_conf = data.get(self.PERSISTENCE_KEY_TRAIN_CONF, None) - - # we need to keep other props, if any, so they can be kept when persisted - for k, v in data.items(): - if k not in [ - self.PERSISTENCE_KEY_MODEL, - self.PERSISTENCE_KEY_META_PROPS, - self.PERSISTENCE_KEY_TRAIN_CONF, - ]: - self.other_props[k] = v + self.var_dict = data[pk_model] + self.meta = data.get(pk_meta, None) + self.train_conf = data.get(pk_conf, None) + # use dict comprehension instead of a for loop for better performance + self.other_props = {k: v for k, v in data.items() if k not in (pk_model, pk_meta, pk_conf)} if not self.train_conf: self.train_conf = default_train_conf @@ -77,21 +75,25 @@ def __init__(self, data: dict, default_train_conf=None, allow_numpy_conversion=T self._allow_numpy_conversion = allow_numpy_conversion def _get_processed_vars(self) -> dict: + # No change - this is already efficient if self.meta: return self.meta.get(MetaKey.PROCESSED_KEYS, {}) - else: - return {} + return {} def to_model_learnable(self, exclude_vars) -> ModelLearnable: processed_vars = self._get_processed_vars() weights = {} - for k, v in self.var_dict.items(): + allow_numpy = self._allow_numpy_conversion + var_items = self.var_dict.items() + get_processed = processed_vars.get + + for k, v in var_items: if exclude_vars and exclude_vars.search(k): continue - is_processed = processed_vars.get(k, False) - if not is_processed and self._allow_numpy_conversion: + is_processed = get_processed(k, False) + if not is_processed and allow_numpy: # convert to numpy weights[k] = v.cpu().numpy() else: