43 changes: 13 additions & 30 deletions examples/hello-world/hello-pt/README.md
@@ -123,33 +123,7 @@ The FedAvg controller implements these main steps:
- FL server sends the global model to clients and waits for their updates using `self.send_model_and_wait()`.
- FL server aggregates all the `results` and produces a new global model using `self.update_model()`.

```python
class FedAvg(BaseFedAvg):
def run(self) -> None:
self.info("Start FedAvg.")

model = self.load_model()
model.start_round = self.start_round
model.total_rounds = self.num_rounds

for self.current_round in range(self.start_round, self.start_round + self.num_rounds):
self.info(f"Round {self.current_round} started.")
model.current_round = self.current_round

clients = self.sample_clients(self.num_clients)

results = self.send_model_and_wait(targets=clients, data=model)

aggregate_results = self.aggregate(results)

model = self.update_model(model, aggregate_results)

self.save_model(model)

self.info("Finished FedAvg.")
```

In this example, we will directly use the default federated averaging algorithm provided by NVFlare. The FedAvg class is defined in nvflare.app_common.workflows.fedavg.FedAvg
In this example, we will directly use the default federated averaging algorithm provided by NVFlare, implemented using the `ScatterAndGather` controller.
There is no need to define customized server code for this example.

## Job Recipe Code
@@ -168,22 +142,31 @@ Job Recipe contains the client.py and the built-in FedAvg algorithm.
```python
recipe.execute(env=env)
```
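
For orientation, the parts of the recipe collapsed in the hunk above can be sketched from the job.py change later in this diff. This is a minimal sketch, not the exact file: the `SimpleNetwork` import path is hypothetical (it is elided in the diff), and any other required arguments hidden by the collapsed lines are omitted.

```python
from nvflare.app_opt.pt.recipes.fedavg import FedAvgRecipe
from nvflare.recipe import SimEnv

from model import SimpleNetwork  # hypothetical import path; the real one is elided in this diff

# Mirror the arguments visible in job.py later in this PR.
recipe = FedAvgRecipe(
    min_clients=2,
    num_rounds=2,
    initial_model=SimpleNetwork(),
    train_script="client.py",
    train_args="--batch_size 16",
)

# Run the job in the FL simulator with two clients.
env = SimEnv(num_clients=2)
run = recipe.execute(env)
```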

To include both training and evaluation, you can change the recipe's training script
To include both training and evaluation using the `CrossSiteModelEval` controller, set the recipe's training script:

```python

train_script="client_with_eval.py",

```
or simply overwrite client.py with client_with_eval.py
and set `cross_site_eval=True` in [job.py](job.py) using the command-line arguments shown below.

## Run Job
From a terminal, run:


```
python job.py
```

To run with cross-site evaluation, execute
```
python job.py --train_script="client_with_eval.py" --cross_site_eval
```
The cross-site evaluation results can be viewed with
```
cat /tmp/nvflare/simulation/hello-pt/server/simulate_job/cross_site_val/cross_val_results.json
```
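
To inspect those results programmatically rather than with `cat`, a small sketch like the following works; the JSON schema is not documented in this README, so it simply pretty-prints whatever the file contains.

```python
import json

# Output path produced by the simulated run above.
path = "/tmp/nvflare/simulation/hello-pt/server/simulate_job/cross_site_val/cross_val_results.json"

with open(path) as f:
    results = json.load(f)

# Pretty-print the cross-site validation results.
print(json.dumps(results, indent=2))
```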

> Note: Depending on the number of clients, you might run into errors because several clients try to download the data at the same time. We suggest pre-downloading the data to avoid such errors.
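
As a sketch, pre-downloading could look like the following, assuming the example trains on CIFAR-10 via torchvision; the dataset and the data root are not stated in this README, so both are assumptions to adapt to your client script.

```python
from torchvision import datasets

# Hypothetical data root; match whatever path the client script uses.
root = "/tmp/nvflare/data"

# Download both splits once, before starting the simulation, so that
# concurrent clients don't race on the same download.
datasets.CIFAR10(root=root, train=True, download=True)
datasets.CIFAR10(root=root, train=False, download=True)
```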
9 changes: 8 additions & 1 deletion examples/hello-world/hello-pt/job.py
@@ -21,13 +21,16 @@

from nvflare.app_opt.pt.recipes.fedavg import FedAvgRecipe
from nvflare.recipe import SimEnv, add_experiment_tracking
from nvflare.recipe.utils import add_cross_site_evaluation


def define_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--n_clients", type=int, default=2)
parser.add_argument("--num_rounds", type=int, default=2)
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--train_script", type=str, default="client.py")
**Collaborator:** Why make this an argument? Are we going to keep changing this file?

**Collaborator (author):** Makes it easier to use. Previous instructions required the user to change or "overwrite" the code in client.py...

parser.add_argument("--cross_site_eval", action="store_true")

return parser.parse_args()

@@ -44,11 +47,15 @@ def main():
min_clients=n_clients,
num_rounds=num_rounds,
initial_model=SimpleNetwork(),
train_script="client.py",
train_script=args.train_script,
train_args=f"--batch_size {batch_size}",
)
add_experiment_tracking(recipe, tracking_type="tensorboard")

if args.cross_site_eval:
add_cross_site_evaluation(recipe, model_locator_type="pytorch")

# Run FL simulation
env = SimEnv(num_clients=n_clients)
run = recipe.execute(env)
print()
67 changes: 67 additions & 0 deletions nvflare/recipe/utils.py
@@ -35,6 +35,19 @@
},
}

MODEL_LOCATOR_REGISTRY = {
"pytorch": {
"locator_module": "nvflare.app_opt.pt.file_model_locator",
"locator_class": "PTFileModelLocator",
"persistor_param": "pt_persistor_id",
},
"numpy": {
"locator_module": "nvflare.app_common.np.np_model_locator",
"locator_class": "NPModelLocator",
"persistor_param": None, # NPModelLocator doesn't use persistor_id
},
}


def add_experiment_tracking(recipe: Recipe, tracking_type: str, tracking_config: dict = None):
"""Enable experiment tracking.
@@ -57,3 +70,57 @@ def add_experiment_tracking(recipe: Recipe, tracking_type: str, tracking_config: dict = None):
receiver_class = getattr(module, TRACKING_REGISTRY[tracking_type]["receiver_class"])
receiver = receiver_class(**tracking_config)
recipe.job.to_server(receiver, "receiver")


def add_cross_site_evaluation(
recipe: Recipe,
model_locator_type: str = "pytorch",
persistor_id: str = None,
submit_model_timeout: int = 600,
validation_timeout: int = 6000,
):
"""Enable cross-site model evaluation.

Args:
recipe: Recipe object to add cross-site evaluation to
model_locator_type: The type of model locator to use ("pytorch" or "numpy")
persistor_id: The persistor ID to use for model location. If None, uses the default persistor_id from job.comp_ids
submit_model_timeout: Timeout for model submission in seconds
validation_timeout: Timeout for validation in seconds
"""
from nvflare.app_common.workflows.cross_site_model_eval import CrossSiteModelEval

if model_locator_type not in MODEL_LOCATOR_REGISTRY:
raise ValueError(
f"Invalid model locator type: {model_locator_type}. Available types: {list(MODEL_LOCATOR_REGISTRY.keys())}"
)

# Use provided persistor_id or default from job.comp_ids
if persistor_id is None:
persistor_id = recipe.job.comp_ids["persistor_id"]

# Get model locator configuration from registry
locator_config = MODEL_LOCATOR_REGISTRY[model_locator_type]

# Import and create model locator
module = importlib.import_module(locator_config["locator_module"])
locator_class = getattr(module, locator_config["locator_class"])

# Create model locator with appropriate parameters
if locator_config["persistor_param"] is not None:
# For PyTorch locator, use persistor_id
locator_kwargs = {locator_config["persistor_param"]: persistor_id}
model_locator = locator_class(**locator_kwargs)
else:
# For Numpy locator, use default parameters (no persistor_id needed)
model_locator = locator_class()

model_locator_id = recipe.job.to_server(model_locator)

# Create and add cross-site evaluation controller
eval_controller = CrossSiteModelEval(
model_locator_id=model_locator_id,
submit_model_timeout=submit_model_timeout,
validation_timeout=validation_timeout,
)
recipe.job.to_server(eval_controller)
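
Taken together with the job.py change above, a typical call is sketched below; it assumes `recipe` is a FedAvgRecipe whose `job.comp_ids` already contains a `persistor_id`, as the PyTorch recipe here does.

```python
from nvflare.recipe.utils import add_cross_site_evaluation

# `recipe` is the FedAvgRecipe built in job.py above; its construction
# is omitted here to avoid repetition.
add_cross_site_evaluation(
    recipe,
    model_locator_type="pytorch",   # selects PTFileModelLocator
    submit_model_timeout=600,       # seconds; the default shown in the signature
    validation_timeout=6000,        # seconds; the default shown in the signature
)
# Internally this resolves recipe.job.comp_ids["persistor_id"], builds the
# model locator, and adds a CrossSiteModelEval controller to the server.
```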