Skip to content

Avoid loading data at every round for flower #3834

@conanhujinming

Description

@conanhujinming

Is your feature request related to a problem? Please describe.
I ran the official hello-flower tutorial from this repo and found that, in every round, NVFlare reloads the training and testing data for every client, which is unnecessary. It should be possible to load the data only in the first round and keep it in memory for the rest of the training process.

Describe the solution you'd like
Load each client's training and testing data only once, in the first round, and reuse it in all later rounds.

Additional context
I added logging to the data-loading functions, and it confirms that NVFlare reloads the data in every round:

Image

Running this tutorial takes about 130 s on my A800 GPU.

I also rewrote the same logic in pure Flower, and there the data is loaded only once, in the first round. This may partly explain why Flower finishes the training in under 90 s. Here is the Flower code:

import os
from collections import OrderedDict
from typing import List

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import flwr as fl
from flwr.common import Context, Metrics, ndarrays_to_parameters

import random

from task import load_data, Net, train, DEVICE

# Device configuration: use the first CUDA GPU when one is visible, otherwise the CPU.
# NOTE: this rebinds the DEVICE name imported from `task` earlier in the file.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client that trains ``net`` on data loaders supplied at construction.

    The loaders are stored on the instance and reused by ``fit``; this class never
    rebuilds them itself.
    """

    def __init__(self, net, trainloader, testloader):
        self.net = net
        self.trainloader = trainloader
        self.testloader = testloader

    def get_parameters(self, config):
        """Return the model's ``state_dict`` tensors, in order, as NumPy arrays.

        ``config`` is accepted to satisfy Flower's client interface; it is unused here.
        """
        # .cpu() first so GPU-resident tensors can be converted to NumPy.
        return [tensor.cpu().numpy() for tensor in self.net.state_dict().values()]

    def set_parameters(self, parameters):
        """Load ``parameters`` into ``self.net``.

        ``parameters`` must be ordered like ``self.net.state_dict()``. ``strict=True``
        makes ``load_state_dict`` raise if any key is missing.
        """
        keys = self.net.state_dict().keys()
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in zip(keys, parameters)})
        self.net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Train one local epoch starting from ``parameters``.

        Returns ``(updated_parameters, num_train_examples, metrics)``.
        """
        self.set_parameters(parameters)
        # Presumably a metrics dict (train/val loss and accuracy, as consumed by
        # weighted_average) -- confirm against task.train.
        results = train(self.net, self.trainloader, self.testloader, epochs=1, device=DEVICE)
        # Forward the round's config. The original passed self.net here by mistake.
        return self.get_parameters(config), len(self.trainloader.dataset), results
        
       

def torch_fix_seed(seed=42):
    """Seed every RNG this script relies on with the same ``seed``.

    Covers Python's ``random``, NumPy's global RNG, PyTorch's default generator
    and the current CUDA device's generator (a no-op when CUDA is unavailable).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)

# Define metric aggregation function
def weighted_average(metrics: list[tuple[int, "Metrics"]]) -> "Metrics":
    """Aggregate per-client fit metrics as an example-count-weighted average.

    Args:
        metrics: ``(num_examples, metrics_dict)`` pairs, one per client. Each dict
            must contain ``train_loss``, ``train_accuracy``, ``val_loss`` and
            ``val_accuracy``.

    Returns:
        A dict holding the same four keys, each weighted by ``num_examples``.
        Returns ``{}`` when there are no clients or no examples in total, which
        avoids a division by zero.
    """
    keys = ("train_loss", "train_accuracy", "val_loss", "val_accuracy")
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        return {}
    # Weight every metric (losses and accuracies alike) by its client's example count.
    return {
        key: sum(num_examples * m[key] for num_examples, m in metrics) / total_examples
        for key in keys
    }

def main() -> None:
    """Run a 2-client, 3-round FedAvg simulation with Flower.

    The data loaders are built once here and captured by ``client_fn``. Every
    client instance the simulation creates therefore shares the same loader
    objects instead of calling ``load_data`` again.
    """
    torch_fix_seed(seed=42)

    # Loaded exactly once; the closure below hands these to every client.
    train_loader, test_loader = load_data()

    def client_fn(cid: str) -> FlowerClient:
        # NOTE(review): recent Flower releases prefer client_fn to return a Client
        # (e.g. FlowerClient(...).to_client()) -- confirm against the installed flwr.
        model = Net().to(DEVICE)  # fresh model per client; loaders are shared
        return FlowerClient(model, train_loader, test_loader)

    fedavg = fl.server.strategy.FedAvg(
        fraction_fit=1.0,
        fraction_evaluate=0.0,  # evaluation rounds are disabled
        min_available_clients=2,
        fit_metrics_aggregation_fn=weighted_average,
    )
    server_config = fl.server.ServerConfig(num_rounds=3)
    per_client_resources = {"num_cpus": 1, "num_gpus": 0.1}

    fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=2,
        config=server_config,
        strategy=fedavg,
        client_resources=per_client_resources,
    )

    print("--- Experiment Finished ---")


if __name__ == "__main__":
    main()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions