codeflash-ai bot commented Oct 29, 2025

📄 8% (0.08x) speedup for PTModelPersistenceFormatManager.to_model_learnable in nvflare/app_opt/pt/model_persistence_format_manager.py

⏱️ Runtime: 906 microseconds → 840 microseconds (best of 527 runs)

📝 Explanation and details

The optimization achieves a 7% speedup through several targeted performance improvements focused on the hot path in to_model_learnable:

Key Optimizations:

  1. Cached Attribute Lookups in Hot Loop: The main performance gain comes from caching frequently accessed attributes as local variables:

    • allow_numpy = self._allow_numpy_conversion
    • var_items = self.var_dict.items()
    • get_processed = processed_vars.get

    This eliminates repeated attribute lookups during loop iteration, which is particularly beneficial when processing large numbers of model weights (see the sketch after this list).

  2. Reduced Dict Method Lookups: By storing processed_vars.get as a local variable get_processed, the optimization avoids the overhead of method resolution on each loop iteration. The line profiler shows this reduces time spent on the is_processed = processed_vars.get(k, False) line from 1.09M to 1.02M nanoseconds.

  3. Class-Level Constant Declaration: Moving the PERSISTENCE_KEY_* constants to class attributes reduces attribute lookup overhead during initialization, though this has minimal impact on the hot path.

  4. Dictionary Comprehension for other_props: Replaced the loop-based construction with a more efficient dictionary comprehension, though this only affects initialization time.
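
Taken together, points 1, 2, and 4 amount to the pattern sketched below. This is a minimal illustration under assumptions, not the actual NVFlare source: the attribute names mirror the bullets above, and the key strings ("processed_keys", "model", "train_conf", "meta_props", "weights", "meta") follow the values exercised by the regression tests further down.

```python
# Sketch of the cached-lookup hot loop (illustrative only, not the NVFlare code).
# `mgr` is assumed to expose var_dict, meta, and _allow_numpy_conversion, as named
# in the bullets above; tensors are assumed to support the .cpu().numpy() chain
# used by the test mocks below.

def to_model_learnable_sketch(mgr, exclude_vars):
    processed_vars = mgr.meta.get("processed_keys", {}) if mgr.meta else {}

    # Hoist attribute lookups and the bound dict.get into locals once,
    # so the loop body only touches fast local names.
    allow_numpy = mgr._allow_numpy_conversion
    var_items = mgr.var_dict.items()
    get_processed = processed_vars.get

    weights = {}
    for k, v in var_items:
        if exclude_vars is not None and exclude_vars.search(k):
            continue  # excluded variables are dropped from the result
        if get_processed(k, False) or not allow_numpy:
            weights[k] = v                # already processed, or conversion disabled
        else:
            weights[k] = v.cpu().numpy()  # move to CPU and convert to a numpy array
    return {"weights": weights, "meta": mgr.meta}


def collect_other_props(data, reserved_keys=("model", "train_conf", "meta_props")):
    # Point 4: build other_props with a comprehension instead of a key-by-key loop.
    return {k: v for k, v in data.items() if k not in reserved_keys}
```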

Performance Impact by Test Case:
The optimizations are most effective for large-scale scenarios:

  • Large weight dictionaries (1000+ keys): 9-12% speedup
  • Complex metadata processing: 7-12% speedup
  • Simple cases show little change, or a slight regression, due to the overhead of setting up the local variables

The optimizations primarily benefit workloads with many model parameters where the loop overhead becomes significant, making this particularly valuable for deep learning model persistence scenarios.
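
The bound-method effect from point 2 is easy to observe in isolation. The snippet below is a standalone timing harness (illustrative only, not part of the PR); absolute numbers will vary by machine and Python version, but the hoisted variant is typically measurably faster on large dicts.

```python
# Micro-benchmark: re-resolving dict.get on every iteration vs. hoisting the
# bound method once. Purely illustrative; not part of the optimized code.
import timeit

d = {f"k{i}": (i % 2 == 0) for i in range(1000)}

def lookup_each_iteration():
    for k in d:
        d.get(k, False)       # method resolved from the dict on every pass

def lookup_hoisted():
    get = d.get               # bound method cached in a local variable
    for k in d:
        get(k, False)

print("per-iteration lookup:", timeit.timeit(lookup_each_iteration, number=2000))
print("hoisted lookup:      ", timeit.timeit(lookup_hoisted, number=2000))
```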

Correctness verification report:

| Test | Status |
| --- | --- |
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 60 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Generated Regression Tests and Runtime
import pytest
from nvflare.app_opt.pt.model_persistence_format_manager import \
    PTModelPersistenceFormatManager

# --- Mocks and minimal stand-ins for external dependencies ---

# Simulate MetaKey with a simple Enum-like class
class MetaKey:
    PROCESSED_KEYS = "processed_keys"

# Simulate ModelLearnableKey with a simple Enum-like class
class ModelLearnableKey:
    WEIGHTS = "weights"
    META = "meta"

# Simulate a PyTorch tensor for testing
class FakeTensor:
    def __init__(self, data, device="cpu"):
        self.data = data
        self.device = device
        self.numpy_called = False

    def cpu(self):
        # Simulate moving to CPU (no-op if already on CPU)
        return self

    def numpy(self):
        # Simulate conversion to numpy array
        self.numpy_called = True
        return self.data

    def __eq__(self, other):
        # For test comparison
        if isinstance(other, FakeTensor):
            return self.data == other.data
        return self.data == other

    def __repr__(self):
        return f"FakeTensor(data={self.data!r}, device={self.device!r})"

# Simulate an exclude_vars object with a .search() method
class ExcludeVars:
    def __init__(self, exclude_keys):
        self.exclude_keys = set(exclude_keys)

    def search(self, key):
        return key in self.exclude_keys

# --- Unit tests ---

# ========== BASIC TEST CASES ==========

def test_simple_weight_dict_no_exclude():
    # Test with a simple weight dict, no excludes, no meta, all tensors unprocessed
    weights = {"w1": FakeTensor(1), "w2": FakeTensor(2)}
    manager = PTModelPersistenceFormatManager(weights)
    exclude_vars = None
    codeflash_output = manager.to_model_learnable(exclude_vars); ml = codeflash_output # 2.75μs -> 2.92μs (5.69% slower)

def test_dict_of_dicts_with_meta_and_exclude():
    # Test with dict of dicts, meta present, some variables excluded
    weights = {"w1": FakeTensor(10), "w2": FakeTensor(20), "w3": FakeTensor(30)}
    meta = {MetaKey.PROCESSED_KEYS: {"w2": True}}
    data = {
        PTModelPersistenceFormatManager.PERSISTENCE_KEY_MODEL: weights,
        PTModelPersistenceFormatManager.PERSISTENCE_KEY_META_PROPS: meta,
    }
    manager = PTModelPersistenceFormatManager(data)
    exclude_vars = ExcludeVars(["w3"])
    codeflash_output = manager.to_model_learnable(exclude_vars); ml = codeflash_output # 3.64μs -> 3.27μs (11.3% faster)

def test_allow_numpy_conversion_false():
    # Test with allow_numpy_conversion=False, so no conversion should happen
    weights = {"a": FakeTensor(5), "b": FakeTensor(6)}
    manager = PTModelPersistenceFormatManager(weights, allow_numpy_conversion=False)
    codeflash_output = manager.to_model_learnable(None); ml = codeflash_output # 1.96μs -> 2.00μs (1.95% slower)

def test_processed_keys_handling():
    # Test that processed keys are not converted even if allow_numpy_conversion is True
    weights = {"x": FakeTensor(99), "y": FakeTensor(100)}
    meta = {MetaKey.PROCESSED_KEYS: {"x": True, "y": False}}
    data = {
        PTModelPersistenceFormatManager.PERSISTENCE_KEY_MODEL: weights,
        PTModelPersistenceFormatManager.PERSISTENCE_KEY_META_PROPS: meta,
    }
    manager = PTModelPersistenceFormatManager(data, allow_numpy_conversion=True)
    codeflash_output = manager.to_model_learnable(None); ml = codeflash_output # 2.70μs -> 2.62μs (2.86% faster)

def test_default_train_conf():
    # Test that default_train_conf is set if not present in data
    weights = {"a": FakeTensor(1)}
    manager = PTModelPersistenceFormatManager(weights, default_train_conf={"lr": 0.1})

# ========== EDGE TEST CASES ==========

def test_empty_weights_dict():
    # Test with empty weight dict
    manager = PTModelPersistenceFormatManager({})
    codeflash_output = manager.to_model_learnable(None); ml = codeflash_output # 1.57μs -> 1.75μs (10.4% slower)

def test_none_data_raises_typeerror():
    # Test that passing None as data raises TypeError
    with pytest.raises(TypeError):
        PTModelPersistenceFormatManager(None)

def test_exclude_vars_excludes_all():
    # Test where exclude_vars excludes all keys
    weights = {"a": FakeTensor(1), "b": FakeTensor(2)}
    manager = PTModelPersistenceFormatManager(weights)
    exclude_vars = ExcludeVars(["a", "b"])
    codeflash_output = manager.to_model_learnable(exclude_vars); ml = codeflash_output # 2.14μs -> 2.35μs (9.13% slower)

def test_processed_keys_empty_dict():
    # Test meta with empty processed_keys dict
    weights = {"a": FakeTensor(1)}
    meta = {MetaKey.PROCESSED_KEYS: {}}
    data = {
        PTModelPersistenceFormatManager.PERSISTENCE_KEY_MODEL: weights,
        PTModelPersistenceFormatManager.PERSISTENCE_KEY_META_PROPS: meta,
    }
    manager = PTModelPersistenceFormatManager(data)
    codeflash_output = manager.to_model_learnable(None); ml = codeflash_output # 2.57μs -> 2.54μs (1.06% faster)

def test_exclude_vars_is_none():
    # Test with exclude_vars=None
    weights = {"a": FakeTensor(1)}
    manager = PTModelPersistenceFormatManager(weights)
    codeflash_output = manager.to_model_learnable(None); ml = codeflash_output # 2.15μs -> 2.24μs (4.32% slower)

def test_other_props_are_preserved():
    # Test that other_props are correctly set
    weights = {"a": FakeTensor(1)}
    data = {
        PTModelPersistenceFormatManager.PERSISTENCE_KEY_MODEL: weights,
        "custom_prop": 42,
    }
    manager = PTModelPersistenceFormatManager(data)

def test_train_conf_from_data():
    # Test that train_conf is set from data if present
    weights = {"a": FakeTensor(1)}
    data = {
        PTModelPersistenceFormatManager.PERSISTENCE_KEY_MODEL: weights,
        PTModelPersistenceFormatManager.PERSISTENCE_KEY_TRAIN_CONF: {"epochs": 5},
    }
    manager = PTModelPersistenceFormatManager(data, default_train_conf={"epochs": 10})

# ========== LARGE SCALE TEST CASES ==========

def test_large_number_of_weights():
    # Test with a large number of weights (up to 1000)
    weights = {f"k{i}": FakeTensor(i) for i in range(1000)}
    manager = PTModelPersistenceFormatManager(weights)
    codeflash_output = manager.to_model_learnable(None); ml = codeflash_output # 156μs -> 143μs (9.10% faster)
    # All weights should be converted to their .data (ints)
    for i in range(1000):
        pass

def test_large_number_of_processed_keys():
    # Test with a large number of processed keys, half processed, half not
    weights = {f"k{i}": FakeTensor(i) for i in range(1000)}
    processed_keys = {f"k{i}": (i % 2 == 0) for i in range(1000)}
    meta = {MetaKey.PROCESSED_KEYS: processed_keys}
    data = {
        PTModelPersistenceFormatManager.PERSISTENCE_KEY_MODEL: weights,
        PTModelPersistenceFormatManager.PERSISTENCE_KEY_META_PROPS: meta,
    }
    manager = PTModelPersistenceFormatManager(data)
    codeflash_output = manager.to_model_learnable(None); ml = codeflash_output # 156μs -> 142μs (9.72% faster)
    # Even keys should remain as FakeTensor, odd keys should be ints
    for i in range(1000):
        val = ml[ModelLearnableKey.WEIGHTS][f"k{i}"]
        if i % 2 == 0:
            pass
        else:
            pass

def test_large_exclude_vars():
    # Test with a large number of weights and exclude half
    weights = {f"k{i}": FakeTensor(i) for i in range(1000)}
    exclude_keys = [f"k{i}" for i in range(0, 1000, 2)]  # Exclude even keys
    manager = PTModelPersistenceFormatManager(weights)
    exclude_vars = ExcludeVars(exclude_keys)
    codeflash_output = manager.to_model_learnable(exclude_vars); ml = codeflash_output # 144μs -> 136μs (5.79% faster)
    # Only odd keys should remain
    for i in range(1, 1000, 2):
        pass
    for i in range(0, 1000, 2):
        pass

def test_large_scale_with_meta_and_other_props():
    # Large scale with meta and other_props
    weights = {f"w{i}": FakeTensor(i) for i in range(500)}
    meta = {MetaKey.PROCESSED_KEYS: {f"w{i}": False for i in range(500)}}
    data = {
        PTModelPersistenceFormatManager.PERSISTENCE_KEY_MODEL: weights,
        PTModelPersistenceFormatManager.PERSISTENCE_KEY_META_PROPS: meta,
        "extra": "value"
    }
    manager = PTModelPersistenceFormatManager(data)
    codeflash_output = manager.to_model_learnable(None); ml = codeflash_output # 78.6μs -> 73.1μs (7.55% faster)
    # All weights should be converted to ints
    for i in range(500):
        pass
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
import pytest
from nvflare.app_opt.pt.model_persistence_format_manager import \
    PTModelPersistenceFormatManager

# ---- Mocks and minimal implementations for dependencies ----

# Simulate MetaKey with just the needed attribute
class MetaKey:
    PROCESSED_KEYS = "processed_keys"

# Simulate ModelLearnableKey with needed keys
class ModelLearnableKey:
    WEIGHTS = "weights"
    META = "meta"

# ---- Helper classes for tests ----

# Simulate a minimal tensor-like object with .cpu().numpy() chain
class DummyTensor:
    def __init__(self, value):
        self.value = value
    def cpu(self):
        return self
    def numpy(self):
        # Return a tuple to simulate numpy array for test purposes
        return ('numpy', self.value)
    # For identity checks
    def __eq__(self, other):
        if isinstance(other, DummyTensor):
            return self.value == other.value
        return False
    def __repr__(self):
        return f"DummyTensor({self.value!r})"

# Simulate an exclude_vars object with a .search() method
class DummyExcludeVars:
    def __init__(self, exclude_keys):
        self.exclude_keys = set(exclude_keys)
    def search(self, key):
        return key in self.exclude_keys

# ---- Unit Tests ----

# ----------- Basic Test Cases -----------

def test_basic_simple_weight_dict_conversion():
    # Test with a simple weight dict, no meta, no exclude, all tensors
    data = {
        "w1": DummyTensor(1),
        "w2": DummyTensor(2),
    }
    mgr = PTModelPersistenceFormatManager(data)
    codeflash_output = mgr.to_model_learnable(exclude_vars=None); ml = codeflash_output # 3.34μs -> 3.37μs (1.01% slower)

def test_basic_dict_of_dicts_with_meta_and_train_conf():
    # Test with dict of dicts, with meta and train_conf present
    meta = {MetaKey.PROCESSED_KEYS: {"w1": True}}
    data = {
        "model": {"w1": DummyTensor(1), "w2": DummyTensor(2)},
        "meta_props": meta,
        "train_conf": {"lr": 0.1},
    }
    mgr = PTModelPersistenceFormatManager(data)
    codeflash_output = mgr.to_model_learnable(exclude_vars=None); ml = codeflash_output # 3.17μs -> 2.95μs (7.64% faster)

def test_basic_exclude_vars():
    # Test that exclude_vars works to exclude keys
    data = {"w1": DummyTensor(1), "w2": DummyTensor(2)}
    mgr = PTModelPersistenceFormatManager(data)
    exclude = DummyExcludeVars(["w1"])
    codeflash_output = mgr.to_model_learnable(exclude_vars=exclude); ml = codeflash_output # 2.84μs -> 2.98μs (4.92% slower)

def test_basic_allow_numpy_conversion_false():
    # Test that allow_numpy_conversion disables conversion
    data = {"w1": DummyTensor(1)}
    mgr = PTModelPersistenceFormatManager(data, allow_numpy_conversion=False)
    codeflash_output = mgr.to_model_learnable(exclude_vars=None); ml = codeflash_output # 2.01μs -> 2.01μs (0.099% slower)

# ----------- Edge Test Cases -----------

def test_edge_empty_data_dict():
    # Empty dict should result in empty weights, no meta
    data = {}
    mgr = PTModelPersistenceFormatManager(data)
    codeflash_output = mgr.to_model_learnable(exclude_vars=None); ml = codeflash_output # 1.60μs -> 1.69μs (5.04% slower)

def test_edge_single_key_processed_true():
    # Only one key, processed True, should not convert
    meta = {MetaKey.PROCESSED_KEYS: {"w": True}}
    data = {"model": {"w": DummyTensor(7)}, "meta_props": meta}
    mgr = PTModelPersistenceFormatManager(data)
    codeflash_output = mgr.to_model_learnable(exclude_vars=None); ml = codeflash_output # 2.67μs -> 2.53μs (5.33% faster)

def test_edge_single_key_processed_false():
    # Only one key, processed False, should convert
    meta = {MetaKey.PROCESSED_KEYS: {"w": False}}
    data = {"model": {"w": DummyTensor(8)}, "meta_props": meta}
    mgr = PTModelPersistenceFormatManager(data)
    codeflash_output = mgr.to_model_learnable(exclude_vars=None); ml = codeflash_output # 2.49μs -> 2.43μs (2.47% faster)

def test_edge_exclude_vars_excludes_all():
    # All keys are excluded
    data = {"a": DummyTensor(1), "b": DummyTensor(2)}
    exclude = DummyExcludeVars(["a", "b"])
    mgr = PTModelPersistenceFormatManager(data)
    codeflash_output = mgr.to_model_learnable(exclude_vars=exclude); ml = codeflash_output # 2.18μs -> 2.28μs (4.34% slower)

def test_edge_exclude_vars_is_none():
    # exclude_vars is None, should include all keys
    data = {"x": DummyTensor(4)}
    mgr = PTModelPersistenceFormatManager(data)
    codeflash_output = mgr.to_model_learnable(exclude_vars=None); ml = codeflash_output # 2.21μs -> 2.27μs (2.73% slower)

def test_edge_non_dict_data_raises_typeerror():
    # Non-dict data should raise TypeError
    with pytest.raises(TypeError):
        PTModelPersistenceFormatManager(["not", "a", "dict"])

def test_edge_meta_without_processed_keys():
    # meta present but no processed_keys
    meta = {"something_else": 123}
    data = {"model": {"w": DummyTensor(9)}, "meta_props": meta}
    mgr = PTModelPersistenceFormatManager(data)
    codeflash_output = mgr.to_model_learnable(exclude_vars=None); ml = codeflash_output # 2.68μs -> 2.61μs (2.61% faster)

def test_edge_var_dict_is_empty():
    # model dict present but empty
    data = {"model": {}}
    mgr = PTModelPersistenceFormatManager(data)
    codeflash_output = mgr.to_model_learnable(exclude_vars=None); ml = codeflash_output # 1.51μs -> 1.76μs (14.1% slower)

def test_edge_default_train_conf_applied():
    # If no train_conf, should use default_train_conf
    data = {"model": {"w": DummyTensor(1)}}
    mgr = PTModelPersistenceFormatManager(data, default_train_conf={"lr": 0.1})

# ----------- Large Scale Test Cases -----------

def test_large_scale_many_weights():
    # Test with 1000 keys, half processed, half not
    N = 1000
    model_dict = {}
    processed_keys = {}
    for i in range(N):
        model_dict[f"w{i}"] = DummyTensor(i)
        processed_keys[f"w{i}"] = (i % 2 == 0)
    meta = {MetaKey.PROCESSED_KEYS: processed_keys}
    data = {"model": model_dict, "meta_props": meta}
    mgr = PTModelPersistenceFormatManager(data)
    codeflash_output = mgr.to_model_learnable(exclude_vars=None); ml = codeflash_output # 150μs -> 135μs (11.8% faster)
    weights = ml[ModelLearnableKey.WEIGHTS]
    # Even keys should be DummyTensor, odd keys should be numpy
    for i in range(N):
        key = f"w{i}"
        if i % 2 == 0:
            pass
        else:
            pass

def test_large_scale_all_excluded():
    # All keys are excluded in a large dict
    N = 500
    data = {f"k{i}": DummyTensor(i) for i in range(N)}
    exclude = DummyExcludeVars([f"k{i}" for i in range(N)])
    mgr = PTModelPersistenceFormatManager(data)
    codeflash_output = mgr.to_model_learnable(exclude_vars=exclude); ml = codeflash_output # 35.2μs -> 35.1μs (0.327% faster)

def test_large_scale_no_conversion():
    # Large dict, allow_numpy_conversion=False, nothing is converted
    N = 500
    data = {f"k{i}": DummyTensor(i) for i in range(N)}
    mgr = PTModelPersistenceFormatManager(data, allow_numpy_conversion=False)
    codeflash_output = mgr.to_model_learnable(exclude_vars=None); ml = codeflash_output # 38.1μs -> 32.5μs (17.1% faster)
    for i in range(N):
        pass

def test_large_scale_exclude_some():
    # Exclude every 10th key in a large dict
    N = 500
    keys = [f"k{i}" for i in range(N)]
    exclude_keys = [k for i, k in enumerate(keys) if i % 10 == 0]
    exclude = DummyExcludeVars(exclude_keys)
    data = {k: DummyTensor(i) for i, k in enumerate(keys)}
    mgr = PTModelPersistenceFormatManager(data)
    codeflash_output = mgr.to_model_learnable(exclude_vars=exclude); ml = codeflash_output # 99.5μs -> 94.9μs (4.94% faster)
    weights = ml[ModelLearnableKey.WEIGHTS]
    for i, k in enumerate(keys):
        if k in exclude_keys:
            pass
        else:
            pass
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, `git checkout codeflash/optimize-PTModelPersistenceFormatManager.to_model_learnable-mhcd5awr` and push.

codeflash-ai bot requested a review from mashraf-222 on October 29, 2025 at 19:03
codeflash-ai bot added the ⚡️ codeflash (Optimization PR opened by Codeflash AI) and 🎯 Quality: High (Optimization Quality according to Codeflash) labels on Oct 29, 2025