108 changes: 63 additions & 45 deletions examples/advanced/gnn/README.md
@@ -1,5 +1,5 @@
# Federated GNN on Graph Dataset using Inductive Learning
In this example, we will demonstrate how to train a classification model using Graph Neural Network (GNN).
In this example, we will demonstrate how to train a classification model with a Graph Neural Network (GNN) using a **recipe-based approach**.

### Background of Graph Neural Network
Graph Neural Networks (GNNs) show a promising future in research and industry, with potential applications in various domains, including social networks, e-commerce, recommendation systems, and more.
@@ -34,62 +34,80 @@ To support functions of PyTorch Geometric necessary for this example, we need extra dependencies:
```bash
python3 -m pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cpu.html
```
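
As an optional sanity check, you can confirm that the installed PyTorch build matches the wheel index used above; a minimal sketch (the expected version string is an assumption based on the URL):
```python
# Optional sanity check: the wheel URL above assumes a torch 2.1.0 CPU build.
import torch
import torch_geometric

print(torch.__version__)           # expect something like "2.1.0+cpu"
print(torch_geometric.__version__)
```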

#### Job Template
We reuse the job templates from [sag_gnn](../../../job_templates/sag_gnn). Let's set the job template path with the following command.
```bash
nvflare config -jt ../../../job_templates/
```
Then we can check the available templates with the following command.
```bash
nvflare job list_templates
```
We can see the "sag_gnn" template is available.
#### Recipe-based Approach

#### Protein Classification
The PPI dataset is directly available via the torch_geometric library; we randomly split the dataset into 2 subsets, one for each client (`--client_id 1` and `--client_id 2`).
First, we run the local training on each client, as well as on the whole dataset with `--client_id 0`.
```
python3 code/graphsage_protein_local.py --client_id 0
python3 code/graphsage_protein_local.py --client_id 1
python3 code/graphsage_protein_local.py --client_id 2
```
Then, we create the NVFlare job based on the GNN template.
This example uses NVFlare's standard **FedAvgRecipe** to create federated GNN training jobs.

We provide two task-specific job creation functions:
- `create_protein_job()`: For PPI protein classification
- `create_finance_job()`: For Elliptic++ financial transaction classification

Both functions return a configured `FedAvgRecipe` instance that can be executed in simulation or production mode.

The recipe can be used in two ways:
1. **Command-line**: Run `job.py` directly with command-line arguments
2. **Programmatic**: Import and use `create_protein_job()` or `create_finance_job()` from `job.py` in your own Python code (see the sketch below)
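
For the programmatic route, here is a minimal usage sketch in simulation mode. The `SimEnv` import path and arguments, the `recipe.execute()` call, and the exact `create_protein_job()` signature are assumptions modeled on the CLI flags shown later in this README; check `job.py` for the real interface:
```python
# Hypothetical usage sketch; argument names mirror the job.py CLI flags below.
from nvflare.recipe import SimEnv  # assumed import path for the simulation environment

from job import create_protein_job

recipe = create_protein_job(
    num_clients=2,
    num_rounds=7,
    epochs_per_round=10,
    data_path="/tmp/nvflare/datasets/ppi",
)

# Execute the recipe with the NVFlare simulator (2 clients, 2 threads)
env = SimEnv(num_clients=2, num_threads=2)
recipe.execute(env)
```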

#### Folder Structure
```
nvflare job create -force -j "/tmp/nvflare/jobs/gnn_protein" -w "sag_gnn" -sd "code" \
-f app_1/config_fed_client.conf app_script="graphsage_protein_fl.py" app_config="--client_id 1 --epochs 10" \
-f app_2/config_fed_client.conf app_script="graphsage_protein_fl.py" app_config="--client_id 2 --epochs 10" \
-f app_server/config_fed_server.conf num_rounds=7 key_metric="validation_f1" model_class_path="torch_geometric.nn.GraphSAGE" components[0].args.model.args.in_channels=50 components[0].args.model.args.hidden_channels=64 components[0].args.model.args.num_layers=2 components[0].args.model.args.out_channels=64
```
```
gnn/
├── job.py                          # Job creation functions using FedAvgRecipe
├── model.py                        # Custom SAGE model for finance task
├── client_protein.py               # FL client for protein classification
├── client_finance.py               # FL client for financial transaction classification
└── utils/                          # Utilities and local training scripts
    ├── graphsage_protein_local.py  # Local training for protein task
    ├── graphsage_finance_local.py  # Local training for finance task
    └── process_elliptic.py         # Elliptic++ data preprocessing
```

> **Review comments** (on `client_protein.py`):
> - The name should be `client.py`.
> - The two jobs should be in separate folders, each with its own code. Sharing the code might be convenient for us, but it is difficult for the user.
> - Finance users don't need to know about the protein task.
For the client configs, we set the client_id for each client and the number of local epochs per round of local training.

For server configs, we set the number of rounds for federated training, the key metric for model selection, and the model class path with model hyperparameters.
#### Protein Classification
The PPI dataset is directly available via the torch_geometric library. We randomly split the dataset into 2 subsets, one for each client.
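
For reference, a minimal sketch of how such a random two-way split of the PPI training graphs could be done; the fixed seed and the exact split logic in `utils/graphsage_protein_local.py` are assumptions:
```python
import torch
from torch_geometric.datasets import PPI

# Download/load the PPI training graphs (20 graphs in the standard split)
train_dataset = PPI(root="/tmp/nvflare/datasets/ppi", split="train")

# Shuffle graph indices with a fixed seed so every client derives the same split
perm = torch.randperm(len(train_dataset), generator=torch.Generator().manual_seed(77))

client_id, num_clients = 1, 2  # client_id 0 would use the full dataset
chunks = torch.chunk(perm, num_clients)
client_graphs = [train_dataset[int(i)] for i in chunks[client_id - 1]]
```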

With the produced job, we run the federated training on both clients via FedAvg using the NVFlare Simulator.
First, run local training to establish baselines:
```bash
python3 utils/graphsage_protein_local.py --client_id 0
python3 utils/graphsage_protein_local.py --client_id 1
python3 utils/graphsage_protein_local.py --client_id 2
```
```
nvflare simulator -w /tmp/nvflare/gnn/protein_fl_workspace -n 2 -t 2 /tmp/nvflare/jobs/gnn_protein
```

Then, run federated training using the recipe:
```bash
python3 job.py \
--task_type protein \
--num_clients 2 \
--num_rounds 7 \
--epochs_per_round 10 \
--data_path /tmp/nvflare/datasets/ppi \
--workspace_dir /tmp/nvflare/gnn/protein_fl_workspace \
--job_dir /tmp/nvflare/jobs/gnn_protein \
--threads 2
```
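
Under the hood, `create_protein_job()` builds a standard FedAvg recipe. Below is a minimal sketch of what that might look like; the `FedAvgRecipe` import path and constructor arguments are assumptions, so see `job.py` for the actual implementation:
```python
from nvflare.recipe import FedAvgRecipe  # assumed import path

def create_protein_job(num_clients=2, num_rounds=7, epochs_per_round=10,
                       data_path="/tmp/nvflare/datasets/ppi"):
    # One training script is deployed to every client; the per-round epoch
    # count and data path are passed through as script arguments.
    return FedAvgRecipe(
        name="gnn_protein",
        min_clients=num_clients,
        num_rounds=num_rounds,
        train_script="client_protein.py",
        train_args=f"--data_path {data_path} --epochs {epochs_per_round}",
    )
```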

#### Financial Transaction Classification
We first download the Elliptic++ dataset to the `/tmp/nvflare/datasets/elliptic_pp` folder. In this example, we will use the following three files:
First, download the Elliptic++ dataset to the `/tmp/nvflare/datasets/elliptic_pp` folder. In this example, we will use the following three files:
- `txs_classes.csv`: transaction id and its class (licit or illicit)
- `txs_edgelist.csv`: connections for transaction ids
- `txs_features.csv`: transaction id and its features
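
To make the roles of these three files concrete, here is a hypothetical sketch of the kind of preprocessing `utils/process_elliptic.py` performs; the column names, label encoding, and return values are assumptions based on the Elliptic++ format, so refer to the real implementation in the repo:
```python
import numpy as np
import pandas as pd

def process_elliptic_sketch(df_features, df_edges, df_classes):
    # Attach class labels to the feature rows ("unknown" marks unlabeled txs)
    df = df_features.merge(df_classes, on="txId", how="left")
    labels = df["class"].map({"1": 1, "2": 0}).fillna(2).astype(int).to_numpy()

    classified_idx = np.where(labels != 2)[0]    # licit/illicit transactions
    unclassified_idx = np.where(labels == 2)[0]  # unlabeled transactions

    # Re-index transaction ids to contiguous node ids for the edge list
    id_map = {tx: i for i, tx in enumerate(df["txId"])}
    edge_index = np.array(
        [(id_map[s], id_map[t]) for s, t in zip(df_edges["txId1"], df_edges["txId2"])]
    ).T

    node_features = df.drop(columns=["txId", "class"]).to_numpy()
    return node_features, classified_idx, unclassified_idx, edge_index, labels
```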
Then, we run the local training on each client, as well as on the whole dataset. Again, `--client_id 0` uses all data.
```
python3 code/graphsage_finance_local.py --client_id 0
python3 code/graphsage_finance_local.py --client_id 1
python3 code/graphsage_finance_local.py --client_id 2
```
Similarly, we create the NVFlare job based on the GNN template.
```
nvflare job create -force -j "/tmp/nvflare/jobs/gnn_finance" -w "sag_gnn" -sd "code" \
-f app_1/config_fed_client.conf app_script="graphsage_finance_fl.py" app_config="--client_id 1 --epochs 10" \
-f app_2/config_fed_client.conf app_script="graphsage_finance_fl.py" app_config="--client_id 2 --epochs 10" \
-f app_server/config_fed_server.conf num_rounds=7 key_metric="validation_auc" model_class_path="pyg_sage.SAGE" components[0].args.model.args.in_channels=165 components[0].args.model.args.hidden_channels=256 components[0].args.model.args.num_layers=3 components[0].args.model.args.num_classes=2
```
With the produced job, we run the federated training on both clients via FedAvg using the NVFlare Simulator.

Run local training to establish baselines:
```bash
python3 utils/graphsage_finance_local.py --client_id 0
python3 utils/graphsage_finance_local.py --client_id 1
python3 utils/graphsage_finance_local.py --client_id 2
```
```
nvflare simulator -w /tmp/nvflare/gnn/finance_fl_workspace -n 2 -t 2 /tmp/nvflare/jobs/gnn_finance
```

Then, run federated training using the recipe:
```bash
python3 job.py \
--task_type finance \
--num_clients 2 \
--num_rounds 7 \
--epochs_per_round 10 \
--data_path /tmp/nvflare/datasets/elliptic_pp \
--workspace_dir /tmp/nvflare/gnn/finance_fl_workspace \
--job_dir /tmp/nvflare/jobs/gnn_finance \
--threads 2
```

### Results
@@ -137,4 +155,4 @@ BibTeX
journal={arXiv preprint arXiv:2306.06108},
year={2023}
}
```
```
examples/advanced/gnn/client_finance.py
@@ -19,16 +19,16 @@
import pandas as pd
import torch
import torch.nn.functional as F
from process_elliptic import process_ellipitc
from pyg_sage import SAGE
from model import SAGE
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.data import Data
from utils.process_elliptic import process_elliptic

DEVICE = "cuda:0"
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# (1) import nvflare client API
# Import nvflare client API
import nvflare.client as flare


@@ -45,33 +45,35 @@ def main():
default=70,
)
parser.add_argument(
"--total_clients",
"--num_clients",
type=int,
default=2,
)
parser.add_argument(
"--client_id",
type=int,
default=0,
help="0: use all data, 1-N: use data from client N",
)
parser.add_argument(
"--output_path",
type=str,
default="./output",
)
args = parser.parse_args()

# Initialize NVFlare client API first to get site name
flare.init()

# Derive client_id from site name (e.g., "site-1" -> 1)
site_name = flare.get_site_name()
client_id = int(site_name.split("-")[-1])
print(f"Site: {site_name}, Client ID: {client_id}")

# Set up tensorboard
writer = SummaryWriter(os.path.join(args.output_path, str(args.client_id)))
writer = SummaryWriter(os.path.join(args.output_path, str(client_id)))

# Create elliptic dataset for training.
df_classes = pd.read_csv(os.path.join(args.data_path, "txs_classes.csv"))
df_edges = pd.read_csv(os.path.join(args.data_path, "txs_edgelist.csv"))
df_features = pd.read_csv(os.path.join(args.data_path, "txs_features.csv"))

# Preprocess data
node_features, classified_idx, unclassified_idx, edge_index, weights, labels, y_train = process_ellipitc(
node_features, classified_idx, unclassified_idx, edge_index, weights, labels, y_train = process_elliptic(
df_features, df_edges, df_classes
)

@@ -84,45 +86,59 @@ def main():
_, _, y_train, _, train_idx, valid_idx = train_test_split(
node_features[classified_idx], y_train, classified_idx, test_size=0.1, random_state=77, stratify=y_train
)
# Futher split train data into two clients
_, _, _, _, train_1_idx, train_2_idx = train_test_split(
node_features[train_idx], y_train, train_idx, test_size=0.5, random_state=77, stratify=y_train
)

# Get the subgraph index for the client
# note that client 0 uses all data
# client 1 uses data from classified_1 and unclassified data
# client 2 uses data from classified_2 and unclassified data
if args.client_id == 0:
train_data_sub = train_data
if args.client_id == 1:
train_data_sub = train_data.subgraph(torch.tensor(train_1_idx.append(unclassified_idx)))
train_idx = np.arange(len(train_1_idx))
elif args.client_id == 2:
train_data_sub = train_data.subgraph(torch.tensor(train_2_idx.append(unclassified_idx)))
train_idx = np.arange(len(train_2_idx))
# Split train data among clients
np.random.seed(77)
shuffled_train_idx = train_idx.copy()
np.random.shuffle(shuffled_train_idx)
client_train_splits = np.array_split(shuffled_train_idx, args.num_clients)

# Get the subgraph index for the client (client_id is 1-indexed)
# Each client uses their subset of classified data plus all unclassified data
client_subset_idx = client_train_splits[client_id - 1]
combined_idx = np.concatenate([client_subset_idx, unclassified_idx])
train_data_sub = train_data.subgraph(torch.tensor(combined_idx))
train_idx = np.arange(len(client_subset_idx))
train_data = train_data.to(DEVICE)
train_data_sub = train_data_sub.to(DEVICE)

# Train model
model = SAGE(train_data_sub.num_node_features, hidden_channels=256, num_classes=2, num_layers=3)
model.double()

# (2) initializes NVFlare client API
flare.init()
# Define evaluation logic outside the training loop for efficiency
# This function is reused for evaluation on both trained and received model
def evaluate(input_weights, step):
model_eval = SAGE(train_data.num_node_features, hidden_channels=256, num_classes=2, num_layers=3)
model_eval.double()
model_eval.load_state_dict(input_weights)
# (optional) use GPU to speed things up
model_eval.to(DEVICE)

with torch.no_grad():
model_eval.eval()
out = model_eval(train_data)
# Use probability scores for AUC calculation, not binary predictions
y_prob = F.softmax(out, dim=1)[:, 1].detach().cpu().numpy()
y_true = train_data.y.detach().cpu().numpy()
val_auc = roc_auc_score(y_true[valid_idx], y_prob[valid_idx])
print(f"Validation AUC: {val_auc:.4f} ")
writer.add_scalar("val_auc", val_auc, step)
return val_auc

while flare.is_running():
# (3) receives FLModel from NVFlare
# Receives FLModel from NVFlare
input_model = flare.receive()
print(f"current_round={input_model.current_round}")

# (3) loads model from NVFlare
# Loads model from NVFlare
model.load_state_dict(input_model.params)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# (optional) use GPU to speed things up
model.to(DEVICE)

# (optional) calculate total steps
steps = args.epochs * len(train_idx)
for epoch in range(1, args.epochs + 1):
@@ -136,34 +152,16 @@ def main():
print(f"Epoch: {epoch:02d}, Loss: {loss:.4f}")
writer.add_scalar("train_loss", loss.item(), input_model.current_round * args.epochs + epoch)

# (5) wraps evaluation logic into a method to re-use for
# evaluation on both trained and received model
def evaluate(input_weights):
model_eval = SAGE(train_data.num_node_features, hidden_channels=256, num_classes=2, num_layers=3)
model_eval.double()
model_eval.load_state_dict(input_weights)
# (optional) use GPU to speed things up
model_eval.to(DEVICE)

with torch.no_grad():
model_eval.eval()
out = model_eval(train_data)
y_pred = torch.argmax(out, dim=1).detach().cpu().numpy()
y_true = train_data.y.detach().cpu().numpy()
val_auc = roc_auc_score(y_true[valid_idx], y_pred[valid_idx])
print(f"Validation AUC: {val_auc:.4f} ")
writer.add_scalar("val_auc", val_auc, input_model.current_round * args.epochs + epoch)
return val_auc

# (6) evaluate on received model for model selection
global_auc = evaluate(input_model.params)
# (7) construct trained FL model
# Evaluate on received model for model selection
final_step = input_model.current_round * args.epochs + args.epochs
global_auc = evaluate(input_model.params, final_step)
# Construct trained FL model
output_model = flare.FLModel(
params=model.cpu().state_dict(),
metrics={"validation_auc": global_auc},
meta={"NUM_STEPS_CURRENT_ROUND": steps},
)
# (8) send model back to NVFlare
# Send model back to NVFlare
flare.send(output_model)

