Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 20 additions & 18 deletions nvflare/app_opt/pt/model_persistence_format_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,45 +53,47 @@ def __init__(self, data: dict, default_train_conf=None, allow_numpy_conversion=T
self.train_conf = None
self.other_props = {} # other props from the original data that need to be kept

if self.PERSISTENCE_KEY_MODEL not in data:
# Cache attr accesses for performance
pk_model = self.PERSISTENCE_KEY_MODEL
pk_meta = self.PERSISTENCE_KEY_META_PROPS
pk_conf = self.PERSISTENCE_KEY_TRAIN_CONF

if pk_model not in data:
# this is a simple weight dict
self.var_dict = data
else:
# dict of dicts
self.var_dict = data[self.PERSISTENCE_KEY_MODEL]
self.meta = data.get(self.PERSISTENCE_KEY_META_PROPS, None)
self.train_conf = data.get(self.PERSISTENCE_KEY_TRAIN_CONF, None)

# we need to keep other props, if any, so they can be kept when persisted
for k, v in data.items():
if k not in [
self.PERSISTENCE_KEY_MODEL,
self.PERSISTENCE_KEY_META_PROPS,
self.PERSISTENCE_KEY_TRAIN_CONF,
]:
self.other_props[k] = v
self.var_dict = data[pk_model]
self.meta = data.get(pk_meta, None)
self.train_conf = data.get(pk_conf, None)
# use dict comprehension instead of a for loop for better performance
self.other_props = {k: v for k, v in data.items() if k not in (pk_model, pk_meta, pk_conf)}

if not self.train_conf:
self.train_conf = default_train_conf

self._allow_numpy_conversion = allow_numpy_conversion

def _get_processed_vars(self) -> dict:
# No change - this is already efficient
if self.meta:
return self.meta.get(MetaKey.PROCESSED_KEYS, {})
else:
return {}
return {}

def to_model_learnable(self, exclude_vars) -> ModelLearnable:
processed_vars = self._get_processed_vars()

weights = {}
for k, v in self.var_dict.items():
allow_numpy = self._allow_numpy_conversion
var_items = self.var_dict.items()
get_processed = processed_vars.get

for k, v in var_items:
if exclude_vars and exclude_vars.search(k):
continue

is_processed = processed_vars.get(k, False)
if not is_processed and self._allow_numpy_conversion:
is_processed = get_processed(k, False)
if not is_processed and allow_numpy:
# convert to numpy
weights[k] = v.cpu().numpy()
else:
Expand Down