From aa9e21e71220cbdb4bd77d2f8c6cd23a356c1b26 Mon Sep 17 00:00:00 2001 From: Holger Roth Date: Tue, 16 Sep 2025 22:59:11 -0400 Subject: [PATCH 1/5] add cross site eval to fed avg recipe. --- examples/hello-world/hello-pt/README.md | 43 ++++++++----------------- examples/hello-world/hello-pt/job.py | 5 ++- nvflare/app_opt/pt/recipes/fedavg.py | 19 ++++++++++- 3 files changed, 35 insertions(+), 32 deletions(-) diff --git a/examples/hello-world/hello-pt/README.md b/examples/hello-world/hello-pt/README.md index bc769b19e0..1ddba14218 100644 --- a/examples/hello-world/hello-pt/README.md +++ b/examples/hello-world/hello-pt/README.md @@ -123,33 +123,7 @@ The FedAvg controller implements these main steps: - FL server sends the global model to clients and waits for their updates using `self.send_model_and_wait()`. - FL server aggregates all the `results` and produces a new global model using `self.update_model()`. -```python -class FedAvg(BaseFedAvg): - def run(self) -> None: - self.info("Start FedAvg.") - - model = self.load_model() - model.start_round = self.start_round - model.total_rounds = self.num_rounds - - for self.current_round in range(self.start_round, self.start_round + self.num_rounds): - self.info(f"Round {self.current_round} started.") - model.current_round = self.current_round - - clients = self.sample_clients(self.num_clients) - - results = self.send_model_and_wait(targets=clients, data=model) - - aggregate_results = self.aggregate(results) - - model = self.update_model(model, aggregate_results) - - self.save_model(model) - - self.info("Finished FedAvg.") -``` - -In this example, we will directly use the default federated averaging algorithm provided by NVFlare. The FedAvg class is defined in nvflare.app_common.workflows.fedavg.FedAvg +In this example, we will directly use the default federated averaging algorithm provided by NVFlare, implemented using the `ScatterAndGather` controller. There is no need to define customized server code for this example. 
## Job Recipe Code @@ -168,22 +142,31 @@ Job Recipe contains the client.py and built-in Fed average algorithm. recipe.execute(env=env) ``` -To include both training and evaluation, you can change the recipe's training script +To include both training and evaluation using the `CrossSiteModelEval` controller, we use the recipe's training script ```python train_script="client_with_eval.py", ``` -or simply overwrite client.py with client_with_eval.py +and set `cross_site_eval=True` in [job.py](job.py) using the command-line arguments below. ## Run Job from terminal try to run the code - ``` python job.py ``` + +To run with cross-site evaluation, execute +``` + python job.py --train_script="client_with_eval.py" --cross_site_eval +``` +The cross-site evaluation results can be shown via +``` +cat /tmp/nvflare/simulation/hello-pt/server/simulate_job/cross_site_val/cross_val_results.json +``` + > Note: >> depends on the number of clients, you might run into error due to several client try to download the data at the same time. >> suggest to pre-download the data to avoid such errors. 
diff --git a/examples/hello-world/hello-pt/job.py b/examples/hello-world/hello-pt/job.py index 0c05e093ae..47a13e9a01 100644 --- a/examples/hello-world/hello-pt/job.py +++ b/examples/hello-world/hello-pt/job.py @@ -28,6 +28,8 @@ def define_parser(): parser.add_argument("--n_clients", type=int, default=2) parser.add_argument("--num_rounds", type=int, default=2) parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--train_script", type=str, default="client.py") + parser.add_argument("--cross_site_eval", action="store_true") return parser.parse_args() @@ -44,8 +46,9 @@ def main(): min_clients=n_clients, num_rounds=num_rounds, initial_model=SimpleNetwork(), - train_script="client.py", + train_script=args.train_script, train_args=f"--batch_size {batch_size}", + cross_site_eval=args.cross_site_eval, ) add_experiment_tracking(recipe, tracking_type="tensorboard") diff --git a/nvflare/app_opt/pt/recipes/fedavg.py b/nvflare/app_opt/pt/recipes/fedavg.py index fb3b40b581..4a718fe3ab 100644 --- a/nvflare/app_opt/pt/recipes/fedavg.py +++ b/nvflare/app_opt/pt/recipes/fedavg.py @@ -20,7 +20,10 @@ from nvflare.app_common.abstract.aggregator import Aggregator from nvflare.app_common.aggregators import InTimeAccumulateWeightedAggregator from nvflare.app_common.shareablegenerators import FullModelShareableGenerator +from nvflare.app_common.widgets.validation_json_generator import ValidationJsonGenerator +from nvflare.app_common.workflows.cross_site_model_eval import CrossSiteModelEval from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather +from nvflare.app_opt.pt.file_model_locator import PTFileModelLocator from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob from nvflare.client.config import ExchangeFormat, TransferType from nvflare.job_config.script_runner import FrameworkType, ScriptRunner @@ -43,6 +46,7 @@ class _FedAvgValidator(BaseModel): command: str = "python3 -u" server_expected_format: ExchangeFormat = 
ExchangeFormat.NUMPY params_transfer_type: TransferType = TransferType.FULL + cross_site_eval: bool = False class FedAvgRecipe(Recipe): @@ -75,6 +79,7 @@ class FedAvgRecipe(Recipe): server_expected_format (str): What format to exchange the parameters between server and client. params_transfer_type (str): How to transfer the parameters. FULL means the whole model parameters are sent. DIFF means that only the difference is sent. Defaults to TransferType.FULL. + cross_site_eval (bool): Whether to enable cross site model evaluation. Defaults to False. Example: ```python @@ -112,6 +117,7 @@ def __init__( command: str = "python3 -u", server_expected_format: ExchangeFormat = ExchangeFormat.NUMPY, params_transfer_type: TransferType = TransferType.FULL, + cross_site_eval: bool = False, ): # Validate inputs internally v = _FedAvgValidator( @@ -127,6 +133,7 @@ def __init__( command=command, server_expected_format=server_expected_format, params_transfer_type=params_transfer_type, + cross_site_eval=cross_site_eval, ) self.name = v.name @@ -141,7 +148,7 @@ def __init__( self.command = v.command self.server_expected_format: ExchangeFormat = v.server_expected_format self.params_transfer_type: TransferType = v.params_transfer_type - + self.cross_site_eval = v.cross_site_eval # Create BaseFedJob with initial model job = BaseFedJob( initial_model=self.initial_model, @@ -172,6 +179,16 @@ def __init__( # Send the controller to the server job.to_server(controller) + if self.cross_site_eval: + model_locator_id = job.to_server(PTFileModelLocator(pt_persistor_id=job.comp_ids["persistor_id"])) + eval_controller = CrossSiteModelEval( + model_locator_id=model_locator_id, + submit_model_timeout=600, + validation_timeout=6000, + ) + job.to_server(eval_controller) + job.to_server(ValidationJsonGenerator()) + # Add clients executor = ScriptRunner( script=self.train_script, From 201bd00f30f863f2c9293ed208888f9c380e34a4 Mon Sep 17 00:00:00 2001 From: Holger Roth Date: Tue, 16 Sep 2025 23:03:06 
-0400 Subject: [PATCH 2/5] formatting --- nvflare/app_opt/pt/recipes/fedavg.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nvflare/app_opt/pt/recipes/fedavg.py b/nvflare/app_opt/pt/recipes/fedavg.py index 4a718fe3ab..f785d95924 100644 --- a/nvflare/app_opt/pt/recipes/fedavg.py +++ b/nvflare/app_opt/pt/recipes/fedavg.py @@ -149,6 +149,7 @@ def __init__( self.server_expected_format: ExchangeFormat = v.server_expected_format self.params_transfer_type: TransferType = v.params_transfer_type self.cross_site_eval = v.cross_site_eval + # Create BaseFedJob with initial model job = BaseFedJob( initial_model=self.initial_model, From ae04c0c51a9ac3478348290d4c1a79001ed2f31b Mon Sep 17 00:00:00 2001 From: Holger Roth Date: Wed, 17 Sep 2025 10:44:45 -0400 Subject: [PATCH 3/5] move to utils --- examples/hello-world/hello-pt/job.py | 6 +++++- nvflare/app_opt/pt/recipes/fedavg.py | 18 ----------------- nvflare/recipe/utils.py | 29 ++++++++++++++++++++++++++++ 3 files changed, 34 insertions(+), 19 deletions(-) diff --git a/examples/hello-world/hello-pt/job.py b/examples/hello-world/hello-pt/job.py index 47a13e9a01..4d011167a2 100644 --- a/examples/hello-world/hello-pt/job.py +++ b/examples/hello-world/hello-pt/job.py @@ -21,6 +21,7 @@ from nvflare.app_opt.pt.recipes.fedavg import FedAvgRecipe from nvflare.recipe import SimEnv, add_experiment_tracking +from nvflare.recipe.utils import add_cross_site_evaluation def define_parser(): @@ -48,10 +49,13 @@ def main(): initial_model=SimpleNetwork(), train_script=args.train_script, train_args=f"--batch_size {batch_size}", - cross_site_eval=args.cross_site_eval, ) add_experiment_tracking(recipe, tracking_type="tensorboard") + if args.cross_site_eval: + add_cross_site_evaluation(recipe) + + # Run FL simulation env = SimEnv(num_clients=n_clients) run = recipe.execute(env) print() diff --git a/nvflare/app_opt/pt/recipes/fedavg.py b/nvflare/app_opt/pt/recipes/fedavg.py index f785d95924..fb3b40b581 100644 --- 
a/nvflare/app_opt/pt/recipes/fedavg.py +++ b/nvflare/app_opt/pt/recipes/fedavg.py @@ -20,10 +20,7 @@ from nvflare.app_common.abstract.aggregator import Aggregator from nvflare.app_common.aggregators import InTimeAccumulateWeightedAggregator from nvflare.app_common.shareablegenerators import FullModelShareableGenerator -from nvflare.app_common.widgets.validation_json_generator import ValidationJsonGenerator -from nvflare.app_common.workflows.cross_site_model_eval import CrossSiteModelEval from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather -from nvflare.app_opt.pt.file_model_locator import PTFileModelLocator from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob from nvflare.client.config import ExchangeFormat, TransferType from nvflare.job_config.script_runner import FrameworkType, ScriptRunner @@ -46,7 +43,6 @@ class _FedAvgValidator(BaseModel): command: str = "python3 -u" server_expected_format: ExchangeFormat = ExchangeFormat.NUMPY params_transfer_type: TransferType = TransferType.FULL - cross_site_eval: bool = False class FedAvgRecipe(Recipe): @@ -79,7 +75,6 @@ class FedAvgRecipe(Recipe): server_expected_format (str): What format to exchange the parameters between server and client. params_transfer_type (str): How to transfer the parameters. FULL means the whole model parameters are sent. DIFF means that only the difference is sent. Defaults to TransferType.FULL. - cross_site_eval (bool): Whether to enable cross site model evaluation. Defaults to False. 
Example: ```python @@ -117,7 +112,6 @@ def __init__( command: str = "python3 -u", server_expected_format: ExchangeFormat = ExchangeFormat.NUMPY, params_transfer_type: TransferType = TransferType.FULL, - cross_site_eval: bool = False, ): # Validate inputs internally v = _FedAvgValidator( @@ -133,7 +127,6 @@ def __init__( command=command, server_expected_format=server_expected_format, params_transfer_type=params_transfer_type, - cross_site_eval=cross_site_eval, ) self.name = v.name @@ -148,7 +141,6 @@ def __init__( self.command = v.command self.server_expected_format: ExchangeFormat = v.server_expected_format self.params_transfer_type: TransferType = v.params_transfer_type - self.cross_site_eval = v.cross_site_eval # Create BaseFedJob with initial model job = BaseFedJob( @@ -180,16 +172,6 @@ def __init__( # Send the controller to the server job.to_server(controller) - if self.cross_site_eval: - model_locator_id = job.to_server(PTFileModelLocator(pt_persistor_id=job.comp_ids["persistor_id"])) - eval_controller = CrossSiteModelEval( - model_locator_id=model_locator_id, - submit_model_timeout=600, - validation_timeout=6000, - ) - job.to_server(eval_controller) - job.to_server(ValidationJsonGenerator()) - # Add clients executor = ScriptRunner( script=self.train_script, diff --git a/nvflare/recipe/utils.py b/nvflare/recipe/utils.py index e0eba4da68..89d38a00b4 100644 --- a/nvflare/recipe/utils.py +++ b/nvflare/recipe/utils.py @@ -57,3 +57,32 @@ def add_experiment_tracking(recipe: Recipe, tracking_type: str, tracking_config: receiver_class = getattr(module, TRACKING_REGISTRY[tracking_type]["receiver_class"]) receiver = receiver_class(**tracking_config) recipe.job.to_server(receiver, "receiver") + + +def add_cross_site_evaluation(recipe: Recipe, persistor_id: str = None, submit_model_timeout: int = 600, validation_timeout: int = 6000): + """Enable cross-site model evaluation. 
+ + Args: + recipe_or_job: Recipe object to add cross-site evaluation to + persistor_id: The persistor ID to use for model location. If None, uses the default persistor_id from job.comp_ids + submit_model_timeout: Timeout for model submission in seconds + validation_timeout: Timeout for validation in seconds + """ + from nvflare.app_common.workflows.cross_site_model_eval import CrossSiteModelEval + from nvflare.app_opt.pt.file_model_locator import PTFileModelLocator + + # Use provided persistor_id or default from job.comp_ids + if persistor_id is None: + persistor_id = recipe.job.comp_ids["persistor_id"] + + # Create and add model locator + model_locator_id = recipe.job.to_server(PTFileModelLocator(pt_persistor_id=persistor_id)) + + # Create and add cross-site evaluation controller + eval_controller = CrossSiteModelEval( + model_locator_id=model_locator_id, + submit_model_timeout=submit_model_timeout, + validation_timeout=validation_timeout, + ) + recipe.job.to_server(eval_controller) + From c2cb9ad97ebf561dcee3899c9fda6105c043d03c Mon Sep 17 00:00:00 2001 From: Holger Roth Date: Wed, 17 Sep 2025 10:48:48 -0400 Subject: [PATCH 4/5] formatting --- nvflare/recipe/utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/nvflare/recipe/utils.py b/nvflare/recipe/utils.py index 89d38a00b4..ba6c115464 100644 --- a/nvflare/recipe/utils.py +++ b/nvflare/recipe/utils.py @@ -59,7 +59,9 @@ def add_experiment_tracking(recipe: Recipe, tracking_type: str, tracking_config: recipe.job.to_server(receiver, "receiver") -def add_cross_site_evaluation(recipe: Recipe, persistor_id: str = None, submit_model_timeout: int = 600, validation_timeout: int = 6000): +def add_cross_site_evaluation( + recipe: Recipe, persistor_id: str = None, submit_model_timeout: int = 600, validation_timeout: int = 6000 +): """Enable cross-site model evaluation. 
Args: @@ -70,14 +72,14 @@ def add_cross_site_evaluation(recipe: Recipe, persistor_id: str = None, submit_m """ from nvflare.app_common.workflows.cross_site_model_eval import CrossSiteModelEval from nvflare.app_opt.pt.file_model_locator import PTFileModelLocator - + # Use provided persistor_id or default from job.comp_ids if persistor_id is None: persistor_id = recipe.job.comp_ids["persistor_id"] - + # Create and add model locator model_locator_id = recipe.job.to_server(PTFileModelLocator(pt_persistor_id=persistor_id)) - + # Create and add cross-site evaluation controller eval_controller = CrossSiteModelEval( model_locator_id=model_locator_id, @@ -85,4 +87,3 @@ def add_cross_site_evaluation(recipe: Recipe, persistor_id: str = None, submit_m validation_timeout=validation_timeout, ) recipe.job.to_server(eval_controller) - From 6e05c390e3b63830bd0bdd9b8d65ef21e5681ae8 Mon Sep 17 00:00:00 2001 From: Holger Roth Date: Fri, 19 Sep 2025 16:26:16 -0400 Subject: [PATCH 5/5] enable numpy model locator --- examples/hello-world/hello-pt/job.py | 2 +- nvflare/recipe/utils.py | 47 +++++++++++++++++++++++++--- 2 files changed, 43 insertions(+), 6 deletions(-) diff --git a/examples/hello-world/hello-pt/job.py b/examples/hello-world/hello-pt/job.py index 4d011167a2..f97440548c 100644 --- a/examples/hello-world/hello-pt/job.py +++ b/examples/hello-world/hello-pt/job.py @@ -53,7 +53,7 @@ def main(): add_experiment_tracking(recipe, tracking_type="tensorboard") if args.cross_site_eval: - add_cross_site_evaluation(recipe) + add_cross_site_evaluation(recipe, model_locator_type="pytorch") # Run FL simulation env = SimEnv(num_clients=n_clients) diff --git a/nvflare/recipe/utils.py b/nvflare/recipe/utils.py index ba6c115464..a2e5ba8f14 100644 --- a/nvflare/recipe/utils.py +++ b/nvflare/recipe/utils.py @@ -35,6 +35,19 @@ }, } +MODEL_LOCATOR_REGISTRY = { + "pytorch": { + "locator_module": "nvflare.app_opt.pt.file_model_locator", + "locator_class": "PTFileModelLocator", + "persistor_param": 
"pt_persistor_id", + }, + "numpy": { + "locator_module": "nvflare.app_common.np.np_model_locator", + "locator_class": "NPModelLocator", + "persistor_param": None, # NPModelLocator doesn't use persistor_id + }, +} + def add_experiment_tracking(recipe: Recipe, tracking_type: str, tracking_config: dict = None): """Enable experiment tracking. @@ -60,25 +73,49 @@ def add_experiment_tracking(recipe: Recipe, tracking_type: str, tracking_config: def add_cross_site_evaluation( - recipe: Recipe, persistor_id: str = None, submit_model_timeout: int = 600, validation_timeout: int = 6000 + recipe: Recipe, + model_locator_type: str = "pytorch", + persistor_id: str = None, + submit_model_timeout: int = 600, + validation_timeout: int = 6000, ): """Enable cross-site model evaluation. Args: - recipe_or_job: Recipe object to add cross-site evaluation to + recipe: Recipe object to add cross-site evaluation to + model_locator_type: The type of model locator to use ("pytorch" or "numpy") persistor_id: The persistor ID to use for model location. If None, uses the default persistor_id from job.comp_ids submit_model_timeout: Timeout for model submission in seconds validation_timeout: Timeout for validation in seconds """ from nvflare.app_common.workflows.cross_site_model_eval import CrossSiteModelEval - from nvflare.app_opt.pt.file_model_locator import PTFileModelLocator + + if model_locator_type not in MODEL_LOCATOR_REGISTRY: + raise ValueError( + f"Invalid model locator type: {model_locator_type}. 
Available types: {list(MODEL_LOCATOR_REGISTRY.keys())}" + ) # Use provided persistor_id or default from job.comp_ids if persistor_id is None: persistor_id = recipe.job.comp_ids["persistor_id"] - # Create and add model locator - model_locator_id = recipe.job.to_server(PTFileModelLocator(pt_persistor_id=persistor_id)) + # Get model locator configuration from registry + locator_config = MODEL_LOCATOR_REGISTRY[model_locator_type] + + # Import and create model locator + module = importlib.import_module(locator_config["locator_module"]) + locator_class = getattr(module, locator_config["locator_class"]) + + # Create model locator with appropriate parameters + if locator_config["persistor_param"] is not None: + # For PyTorch locator, use persistor_id + locator_kwargs = {locator_config["persistor_param"]: persistor_id} + model_locator = locator_class(**locator_kwargs) + else: + # For Numpy locator, use default parameters (no persistor_id needed) + model_locator = locator_class() + + model_locator_id = recipe.job.to_server(model_locator) # Create and add cross-site evaluation controller eval_controller = CrossSiteModelEval(