From 1468e593f2aff1ce82f581a0be64e8de7fb052cf Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Mon, 15 Dec 2025 13:08:12 -0500 Subject: [PATCH 1/7] convert KM jobapi to recipe --- examples/advanced/kaplan-meier-he/README.md | 224 +++++++++++++++++- .../{src/kaplan_meier_train.py => client.py} | 0 .../kaplan_meier_train_he.py => client_he.py} | 0 examples/advanced/kaplan-meier-he/job.py | 189 +++++++++++++++ examples/advanced/kaplan-meier-he/km_job.py | 116 --------- examples/advanced/kaplan-meier-he/project.yml | 52 ++++ .../{src/kaplan_meier_wf.py => server.py} | 2 +- .../kaplan_meier_wf_he.py => server_he.py} | 2 +- .../advanced/kaplan-meier-he/start_all.sh | 73 ++++++ 9 files changed, 532 insertions(+), 126 deletions(-) rename examples/advanced/kaplan-meier-he/{src/kaplan_meier_train.py => client.py} (100%) rename examples/advanced/kaplan-meier-he/{src/kaplan_meier_train_he.py => client_he.py} (100%) create mode 100644 examples/advanced/kaplan-meier-he/job.py delete mode 100644 examples/advanced/kaplan-meier-he/km_job.py create mode 100644 examples/advanced/kaplan-meier-he/project.yml rename examples/advanced/kaplan-meier-he/{src/kaplan_meier_wf.py => server.py} (98%) rename examples/advanced/kaplan-meier-he/{src/kaplan_meier_wf_he.py => server_he.py} (98%) create mode 100755 examples/advanced/kaplan-meier-he/start_all.sh diff --git a/examples/advanced/kaplan-meier-he/README.md b/examples/advanced/kaplan-meier-he/README.md index 7e25710469..90bc31fe67 100644 --- a/examples/advanced/kaplan-meier-he/README.md +++ b/examples/advanced/kaplan-meier-he/README.md @@ -2,7 +2,7 @@ This example illustrates two features: * How to perform Kaplan-Meier survival analysis in federated setting without and with secure features via time-binning and Homomorphic Encryption (HE). -* How to use the Flare ModelController API to contract a workflow to facilitate HE under simulator mode. 
+* How to use the Recipe API with Flare ModelController for job configuration and execution in both simulation and production environments. ## Basics of Kaplan-Meier Analysis Kaplan-Meier survival analysis is a non-parametric statistic used to estimate the survival function from lifetime data. It is used to analyze the time it takes for an event of interest to occur. For example, during a clinical trial, the Kaplan-Meier estimator can be used to estimate the proportion of patients who survive a certain amount of time after treatment. @@ -62,7 +62,7 @@ To run the baseline script, simply execute: ```commandline python utils/baseline_kaplan_meier.py ``` -By default, this will generate a KM curve image `km_curve_baseline.png` under `/tmp` directory. The resulting KM curve is shown below: +By default, this will generate a KM curve image `km_curve_baseline.png` under `/tmp/nvflare/baseline` directory. The resulting KM curve is shown below: ![KM survival baseline](figs/km_curve_baseline.png) Here, we show the survival curve for both daily (without binning) and weekly binning. The two curves aligns well with each other, while the weekly-binned curve has lower resolution. @@ -90,23 +90,231 @@ For the federated analysis with HE, we need to ensure proper HE aggregation usin After these rounds, the federated work is completed. Then at each client, the aggregated histograms will be decrypted and converted back to an event list, and Kaplan-Meier analysis can be performed on the global information. ## Run the job -First, we prepared data for a 5-client federated job. We split and generate the data files for each client with binning interval of 7 days. + +This example supports both **Simulation Mode** (for local testing) and **Production Mode** (for real-world deployment). 
+ +| Feature | Simulation Mode | Production Mode | +|---------|----------------|-----------------| +| **Use Case** | Testing & Development | Real-world Deployment / Production Testing | +| **HE Context** | Manual preparation via script | Auto-provisioned via startup kits | +| **Security** | Single machine, no encryption between processes | Secure startup kits with certificates | +| **Setup** | Quick & simple | Requires provisioning & starting all parties | +| **Startup** | Single command | `start_all.sh` (local) or manual (distributed) | +| **Participants** | All run locally in one process | Distributed servers/clients running separately | +| **Data** | Prepared once, shared by all | Same data reused from simulation | + +### Simulation Mode + +For simulation mode (testing and development), we manually prepare the data and HE context: + +**Step 1: Prepare Data** + +Split and generate data files for each client with binning interval of 7 days: ```commandline python utils/prepare_data.py --site_num 5 --bin_days 7 --out_path "/tmp/nvflare/dataset/km_data" ``` -Then we prepare HE context for clients and server, note that this step is done by secure provisioning for real-life applications, but in this study experimenting with BFV scheme, we use this step to distribute the HE context. +**Step 2: Prepare HE Context (Simulation Only)** + +For simulation mode, manually prepare the HE context with BFV scheme: ```commandline python utils/prepare_he_context.py --out_path "/tmp/nvflare/he_context" ``` -Next, we run the federated training using NVFlare Simulator via [JobAPI](https://nvflare.readthedocs.io/en/main/programming_guide/fed_job_api.html), both without and with HE: +**Step 3: Run the Job** + +Run the job without and with HE: +```commandline +python job.py +python job.py --encryption +``` + +### Production Mode + +For production deployments, the HE context is automatically provisioned through secure startup kits. 
+ +**Quick Start for Local Testing:** +If you want to quickly test production mode on a single machine, use the convenience scripts: +1. Run provisioning: `nvflare provision -p project.yml -w /tmp/nvflare/prod_workspaces` +2. Start all parties: `./start_all.sh` +3. Start admin console and use username `admin@nvidia.com` +4. Submit job: `python job.py --encryption --startup_kit_location /tmp/nvflare/prod_workspaces/km_he_project/prod_00/admin@nvidia.com` +5. Monitor job via admin console (use `list_jobs`, `check_status client`) +6. Stop all parties with admin console + +**Note:** You may see SSL handshake warnings in local testing - these are harmless and can be ignored. + +For detailed steps and distributed deployment, continue below: + +**Step 1: Install NVFlare with HE Support** + +```commandline +pip install nvflare[HE] +``` + +**Step 2: Provision Startup Kits with HE Context** + +The `project.yml` file in this directory is pre-configured with `HEBuilder` using the CKKS scheme. Run provisioning to output to `/tmp/nvflare/prod_workspaces`: + +```commandline +nvflare provision -p project.yml -w /tmp/nvflare/prod_workspaces +``` + +This generates startup kits in `/tmp/nvflare/prod_workspaces/km_he_project/prod_00/`: +- `localhost/` - Server startup kit with `server_context.tenseal` +- `site-1/`, `site-2/`, etc. - Client startup kits, each with `client_context.tenseal` +- `admin@nvidia.com/` - Admin console + +The HE context files are automatically included in each startup kit and do not need to be manually distributed. + +**Step 3: Distribute Startup Kits** + +Securely distribute the startup kits to each participant from `/tmp/nvflare/prod_workspaces/km_he_project/prod_00/`: +- `localhost/` directory is the server (for local testing, no need to send) +- Send `site-1/`, `site-2/`, etc. 
directories to each client host (for distributed deployment) +- Keep `admin@nvidia.com/` directory for the admin user + +**Step 4: Start All Parties** + +**Option A: Quick Start (Local Testing)** + +For local testing where all parties run on the same machine, use the convenience script: + +```commandline +./start_all.sh +``` + +This will start the server and all 5 clients in the background. Logs are saved to `/tmp/nvflare/logs/`. + +Then start the admin console: ```commandline -python km_job.py -python km_job.py --encryption +cd /tmp/nvflare/prod_workspaces/km_he_project/prod_00/admin@nvidia.com +./startup/fl_admin.sh +``` + +**Important:** When prompted for "User Name:", enter `admin@nvidia.com` (this matches the admin defined in project.yml). + +Once connected, check the status of all participants: ``` +> check_status server +> check_status client +``` + +**Option B: Manual Start (Distributed Deployment)** + +For distributed deployment where parties run on different machines: + +**On the Server Host:** +```commandline +cd /tmp/nvflare/prod_workspaces/km_he_project/prod_00/localhost +./startup/start.sh +``` + +Wait for the server to be ready (you should see "Server started" in the logs). + +**On Each Client Host:** +```commandline +# On site-1 +cd /tmp/nvflare/prod_workspaces/km_he_project/prod_00/site-1 +./startup/start.sh + +# On site-2 +cd /tmp/nvflare/prod_workspaces/km_he_project/prod_00/site-2 +./startup/start.sh + +# Repeat for site-3, site-4, and site-5 +``` + +**On the Admin Machine:** +```commandline +cd /tmp/nvflare/prod_workspaces/km_he_project/prod_00/admin@nvidia.com +./startup/fl_admin.sh +# When prompted, use username: admin@nvidia.com +``` + +**Step 5: Submit and Run the Job** + +With all parties running, submit the job using the Recipe API. 
The job will automatically use: +- The provisioned HE context from each participant's startup kit +- The data already prepared in simulation mode at `/tmp/nvflare/dataset/km_data` + +```commandline +python job.py --encryption --startup_kit_location /tmp/nvflare/prod_workspaces/km_he_project/prod_00/admin@nvidia.com +``` + +The job will be submitted to the FL system and executed across all connected clients. + +**Monitoring Job Progress:** + +The job runs asynchronously. To monitor progress, use the admin console: +``` +> list_jobs +> check_status server +> check_status client +``` + +To download job results after completion (use the job ID shown by `list_jobs`): +``` +> download_job JOB_ID +``` + +Results will be saved to each client's workspace directory: +- `/tmp/nvflare/prod_workspaces/km_he_project/prod_00/site-1/` +- Check for `km_curve_fl_he.png` and `km_global.json` in each client's directory + +**Note:** In production mode with HE, the `--he_context_path` parameter is automatically set to use the provisioned `client_context.tenseal` or `server_context.tenseal` from each participant's startup kit. No manual HE context distribution is needed. 
+ +**Step 6: Shut Down All Parties** + +After the job completes, shut down all parties gracefully via the admin console: + +``` +> shutdown all +``` + +### Customization Options + +**Simulation Mode:** +```commandline +# Customize number of clients and threads +python job.py --num_clients 3 --num_threads 2 + +# Customize data paths (simulation only) +python job.py --data_root /custom/data/path --he_context_path /custom/he/path + +# Customize output directories +python job.py --workspace_dir /custom/workspace --job_dir /custom/jobdir + +# Combine options +python job.py --encryption --num_clients 10 --num_threads 5 +``` + +**Production Mode:** +```commandline +# Make sure server and all clients are started first (see Step 4) +# The he_context_path is automatically managed by startup kits +# By default, uses the same data from simulation mode at /tmp/nvflare/dataset/km_data +# Only customize data_root if you've moved the data elsewhere +python job.py --encryption --startup_kit_location /tmp/nvflare/prod_workspaces/km_he_project/prod_00/admin@nvidia.com --data_root /custom/data/path +``` + +By default, this will generate a KM curve image `km_curve_fl.png` (or `km_curve_fl_he.png` with encryption) under each client's directory. 
+ +### HE Context and Data Management + +- **Simulation Mode**: + - Uses **BFV scheme** (integer arithmetic, suitable for histogram aggregation) + - HE context files are manually created via `prepare_he_context.py` + - Data prepared at `/tmp/nvflare/dataset/km_data` + - Paths specified via `--he_context_path` and `--data_root` +- **Production Mode**: + - Uses **CKKS scheme** (approximate arithmetic, easier provisioning) + - HE context is automatically provisioned into startup kits via `nvflare provision` + - Clients use: `/client_context.tenseal` + - Server uses: `/server_context.tenseal` + - **Reuses the same data** from simulation mode at `/tmp/nvflare/dataset/km_data` by default -By default, this will generate a KM curve image `km_curve_fl.png` and `km_curve_fl_he.png` under each client's directory. +**Note:** Both BFV and CKKS schemes provide strong encryption and work well for this Kaplan-Meier analysis. BFV is used in simulation for exact integer operations, while CKKS is used in production for simpler provisioning and broader compatibility. Production mode can reuse the data prepared during simulation mode, eliminating redundant data preparation. 
## Display Result diff --git a/examples/advanced/kaplan-meier-he/src/kaplan_meier_train.py b/examples/advanced/kaplan-meier-he/client.py similarity index 100% rename from examples/advanced/kaplan-meier-he/src/kaplan_meier_train.py rename to examples/advanced/kaplan-meier-he/client.py diff --git a/examples/advanced/kaplan-meier-he/src/kaplan_meier_train_he.py b/examples/advanced/kaplan-meier-he/client_he.py similarity index 100% rename from examples/advanced/kaplan-meier-he/src/kaplan_meier_train_he.py rename to examples/advanced/kaplan-meier-he/client_he.py diff --git a/examples/advanced/kaplan-meier-he/job.py b/examples/advanced/kaplan-meier-he/job.py new file mode 100644 index 0000000000..5db31410f8 --- /dev/null +++ b/examples/advanced/kaplan-meier-he/job.py @@ -0,0 +1,189 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +from server import KM +from server_he import KM_HE + +from nvflare import FedJob +from nvflare.job_config.script_runner import ScriptRunner +from nvflare.recipe import ProdEnv, SimEnv +from nvflare.recipe.spec import Recipe + + +class KMRecipe(Recipe): + """Recipe wrapper around the Kaplan-Meier job configuration. + + This provides a recipe-style API for easy job configuration and execution + in both simulation and production environments. 
+ """ + + def __init__( + self, + *, + num_clients: int, + encryption: bool = False, + data_root: str = "/tmp/nvflare/dataset/km_data", + he_context_path: str = "/tmp/nvflare/he_context/he_context_client.txt", + ): + self.num_clients = num_clients + self.encryption = encryption + self.data_root = data_root + self.he_context_path = he_context_path + + # Set job name and script based on encryption mode + if self.encryption: + job_name = "KM_HE" + train_script = "client_he.py" + script_args = f"--data_root {data_root} --he_context_path {he_context_path}" + controller = KM_HE(min_clients=num_clients, he_context_path=he_context_path) + else: + job_name = "KM" + train_script = "client.py" + script_args = f"--data_root {data_root}" + controller = KM(min_clients=num_clients) + + # Create the FedJob + job = FedJob(name=job_name, min_clients=num_clients) + + # Add controller workflow to server + job.to_server(controller) + + # Add ScriptRunner to all clients + runner = ScriptRunner( + script=train_script, + script_args=script_args, + framework="raw", + launch_external_process=False, + ) + job.to_clients(runner, tasks=["train"]) + + super().__init__(job) + + +def define_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--workspace_dir", + type=str, + default="/tmp/nvflare/workspaces/km", + help="Work directory for simulator runs, default to '/tmp/nvflare/workspaces/km'", + ) + parser.add_argument( + "--job_dir", + type=str, + default="/tmp/nvflare/jobs/km", + help="Directory for job export, default to '/tmp/nvflare/jobs/km'", + ) + parser.add_argument( + "--encryption", + action=argparse.BooleanOptionalAction, + help="Whether to enable encryption, default to False", + ) + parser.add_argument( + "--num_clients", + type=int, + default=5, + help="Number of clients to simulate, default to 5", + ) + parser.add_argument( + "--num_threads", + type=int, + help="Number of threads to use for FL simulation, default to the number of clients if not specified", + ) + 
parser.add_argument( + "--data_root", + type=str, + default="/tmp/nvflare/dataset/km_data", + help="Root directory for KM data, default to '/tmp/nvflare/dataset/km_data'", + ) + parser.add_argument( + "--he_context_path", + type=str, + default="/tmp/nvflare/he_context/he_context_client.txt", + help="Path to HE context file, default to '/tmp/nvflare/he_context/he_context_client.txt'", + ) + parser.add_argument( + "--startup_kit_location", + type=str, + default=None, + help="Startup kit location for production mode, default to None (simulation mode)", + ) + parser.add_argument( + "--username", + type=str, + default="admin@nvidia.com", + help="Username for production mode, default to 'admin@nvidia.com'", + ) + return parser.parse_args() + + +def main(): + print("Starting Kaplan-Meier job...") + args = define_parser() + print("args:", args) + + num_clients = args.num_clients + num_threads = args.num_threads if args.num_threads else num_clients + + # Determine job name for workspace directory + job_name = "KM_HE" if args.encryption else "KM" + workspace_dir = args.workspace_dir.replace("/km/", f"/{job_name}/") + + # Create the recipe + recipe = KMRecipe( + num_clients=num_clients, + encryption=args.encryption, + data_root=args.data_root, + he_context_path=args.he_context_path, + ) + + # Export job + print("Exporting job to", args.job_dir) + recipe.job.export_job(args.job_dir) + + # Run recipe + if args.startup_kit_location: + print("Running job in production mode...") + print("startup_kit_location=", args.startup_kit_location) + print("username=", args.username) + env = ProdEnv(startup_kit_location=args.startup_kit_location, username=args.username) + else: + print("Running job in simulation mode...") + print("workspace_dir=", workspace_dir) + print("num_clients=", num_clients) + print("num_threads=", num_threads) + env = SimEnv(num_clients=num_clients, num_threads=num_threads, workspace_root=workspace_dir) + + run = recipe.execute(env) + print("Job Status is:", 
run.get_status()) + + # In production mode, job runs asynchronously on the FL system + # Check status via admin console instead of waiting for result here + if args.startup_kit_location: + print("\nJob submitted successfully to the FL system!") + print("To monitor job status, use the admin console:") + print(f" cd {args.startup_kit_location}") + print(" ./startup/fl_admin.sh") + print(" > check_status server") + print(" > list_jobs") + print(f" > download_job {run.job_id}") + else: + # In simulation mode, we can get the result synchronously + print("Job Result is:", run.get_result()) + + +if __name__ == "__main__": + main() diff --git a/examples/advanced/kaplan-meier-he/km_job.py b/examples/advanced/kaplan-meier-he/km_job.py deleted file mode 100644 index 1f59ac6a13..0000000000 --- a/examples/advanced/kaplan-meier-he/km_job.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import argparse -import os - -from src.kaplan_meier_wf import KM -from src.kaplan_meier_wf_he import KM_HE - -from nvflare import FedJob -from nvflare.job_config.script_runner import ScriptRunner - - -def main(): - args = define_parser() - # Default paths - data_root = "/tmp/nvflare/dataset/km_data" - he_context_path = "/tmp/nvflare/he_context/he_context_client.txt" - - # Set the script and config - if args.encryption: - job_name = "KM_HE" - train_script = "src/kaplan_meier_train_he.py" - script_args = f"--data_root {data_root} --he_context_path {he_context_path}" - else: - job_name = "KM" - train_script = "src/kaplan_meier_train.py" - script_args = f"--data_root {data_root}" - - # Set the number of clients and threads - num_clients = args.num_clients - if args.num_threads: - num_threads = args.num_threads - else: - num_threads = num_clients - - # Set the output workspace and job directories - workspace_dir = os.path.join(args.workspace_dir, job_name) - job_dir = args.job_dir - - # Create the FedJob - job = FedJob(name=job_name, min_clients=num_clients) - - # Define the KM controller workflow and send to server - if args.encryption: - controller = KM_HE(min_clients=num_clients, he_context_path=he_context_path) - else: - controller = KM(min_clients=num_clients) - job.to_server(controller) - - # Define the ScriptRunner and send to all clients - runner = ScriptRunner( - script=train_script, - script_args=script_args, - framework="raw", - params_exchange_format="raw", - launch_external_process=False, - ) - job.to_clients(runner, tasks=["train"]) - - # Export the job - print("job_dir=", job_dir) - job.export_job(job_dir) - - # Run the job - print("workspace_dir=", workspace_dir) - print("num_threads=", num_threads) - job.simulator_run(workspace_dir, n_clients=num_clients, threads=num_threads) - - -def define_parser(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--workspace_dir", - type=str, - default="/tmp/nvflare/jobs/km/workdir", - help="work 
directory, default to '/tmp/nvflare/jobs/km/workdir'", - ) - parser.add_argument( - "--job_dir", - type=str, - default="/tmp/nvflare/jobs/km/jobdir", - help="directory for job export, default to '/tmp/nvflare/jobs/km/jobdir'", - ) - parser.add_argument( - "--encryption", - action=argparse.BooleanOptionalAction, - help="whether to enable encryption, default to False", - ) - parser.add_argument( - "--num_clients", - type=int, - default=5, - help="number of clients to simulate, default to 5", - ) - parser.add_argument( - "--num_threads", - type=int, - help="number of threads to use for FL simulation, default to the number of clients if not specified", - ) - return parser.parse_args() - - -if __name__ == "__main__": - main() diff --git a/examples/advanced/kaplan-meier-he/project.yml b/examples/advanced/kaplan-meier-he/project.yml new file mode 100644 index 0000000000..368b07338a --- /dev/null +++ b/examples/advanced/kaplan-meier-he/project.yml @@ -0,0 +1,52 @@ +# Sample project.yml for Kaplan-Meier production deployment with HE (CKKS) +# This file is used for provisioning secure startup kits with HE context +# Note: Simulation mode uses BFV scheme, production mode uses CKKS scheme + +api_version: 3 +name: km_he_project +description: Kaplan-Meier Survival Analysis with Homomorphic Encryption + +participants: + - name: localhost + type: server + org: example + fed_learn_port: 8002 + admin_port: 8003 + - name: site-1 + type: client + org: example + - name: site-2 + type: client + org: example + - name: site-3 + type: client + org: example + - name: site-4 + type: client + org: example + - name: site-5 + type: client + org: example + - name: admin@nvidia.com + type: admin + org: nvidia + role: project_admin + +# The same methods in all builders are called in their order defined in builders section +builders: + - path: nvflare.lighter.impl.workspace.WorkspaceBuilder + args: + template_file: master_template.yml + - path: nvflare.lighter.impl.static_file.StaticFileBuilder + 
args: + config_folder: config + - path: nvflare.lighter.impl.cert.CertBuilder + # HEBuilder provisions HE context for CKKS scheme + - path: nvflare.lighter.impl.he.HEBuilder + args: + poly_modulus_degree: 8192 + coeff_mod_bit_sizes: [60, 40, 40] + scale_bits: 40 + scheme: CKKS # Using CKKS scheme for production deployment + - path: nvflare.lighter.impl.signature.SignatureBuilder + diff --git a/examples/advanced/kaplan-meier-he/src/kaplan_meier_wf.py b/examples/advanced/kaplan-meier-he/server.py similarity index 98% rename from examples/advanced/kaplan-meier-he/src/kaplan_meier_wf.py rename to examples/advanced/kaplan-meier-he/server.py index 436778fa57..d6e3c3b401 100644 --- a/examples/advanced/kaplan-meier-he/src/kaplan_meier_wf.py +++ b/examples/advanced/kaplan-meier-he/server.py @@ -21,7 +21,7 @@ # Controller Workflow class KM(ModelController): def __init__(self, min_clients: int): - super(KM, self).__init__() + super(KM, self).__init__(persistor_id="") self.logger = logging.getLogger(self.__class__.__name__) self.min_clients = min_clients self.num_rounds = 2 diff --git a/examples/advanced/kaplan-meier-he/src/kaplan_meier_wf_he.py b/examples/advanced/kaplan-meier-he/server_he.py similarity index 98% rename from examples/advanced/kaplan-meier-he/src/kaplan_meier_wf_he.py rename to examples/advanced/kaplan-meier-he/server_he.py index 820848ea6f..289edc3e6b 100644 --- a/examples/advanced/kaplan-meier-he/src/kaplan_meier_wf_he.py +++ b/examples/advanced/kaplan-meier-he/server_he.py @@ -25,7 +25,7 @@ class KM_HE(ModelController): def __init__(self, min_clients: int, he_context_path: str): - super(KM_HE, self).__init__() + super(KM_HE, self).__init__(persistor_id="") self.logger = logging.getLogger(self.__class__.__name__) self.min_clients = min_clients self.he_context_path = he_context_path diff --git a/examples/advanced/kaplan-meier-he/start_all.sh b/examples/advanced/kaplan-meier-he/start_all.sh new file mode 100755 index 0000000000..1b5856266a --- /dev/null +++ 
b/examples/advanced/kaplan-meier-he/start_all.sh @@ -0,0 +1,73 @@ +#!/bin/bash +# Convenience script to start all parties for local production environment testing +# This starts the server and all 5 clients in the background + +set -e + +WORKSPACE_ROOT="/tmp/nvflare/prod_workspaces/km_he_project/prod_00" + +echo "======================================" +echo "Starting NVFlare Production Environment" +echo "======================================" +echo "" + +# Check if workspace exists +if [ ! -d "$WORKSPACE_ROOT" ]; then + echo "Error: Workspace not found at $WORKSPACE_ROOT" + echo "Please run provisioning first:" + echo " nvflare provision -p project.yml -w /tmp/nvflare/prod_workspaces" + exit 1 +fi + +# Function to start a component +start_component() { + local name=$1 + local path=$2 + local log_file="/tmp/nvflare/logs/${name}.log" + + mkdir -p /tmp/nvflare/logs + + echo "Starting ${name}..." + cd "${path}" + ./startup/start.sh > "${log_file}" 2>&1 & + local pid=$! + echo " PID: ${pid}" + echo " Log: ${log_file}" + echo "${pid}" > "/tmp/nvflare/logs/${name}.pid" +} + +# Start server +echo "" +echo "1. Starting Server..." +start_component "localhost" "${WORKSPACE_ROOT}/localhost" +echo " Waiting for server to be ready..." +sleep 10 + +# Start clients +echo "" +echo "2. Starting Clients..." +for i in {1..5}; do + start_component "site-${i}" "${WORKSPACE_ROOT}/site-${i}" + sleep 2 +done + +echo "" +echo "======================================" +echo "All parties started successfully!" +echo "======================================" +echo "" +echo "Server and clients are running in the background." 
+echo "Logs are available in /tmp/nvflare/logs/" +echo "" +echo "To check status:" +echo " tail -f /tmp/nvflare/logs/server.log" +echo " tail -f /tmp/nvflare/logs/site-1.log" +echo "" +echo "To start admin console:" +echo " cd ${WORKSPACE_ROOT}/admin@nvidia.com" +echo " ./startup/fl_admin.sh" +echo "" +echo "To stop all parties, run:" +echo " ./stop_all.sh" +echo "" + From 0f712b5c68fdf9ede8cdc19a98cfa3f323108d6c Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Mon, 15 Dec 2025 13:16:52 -0500 Subject: [PATCH 2/7] readme update --- examples/advanced/kaplan-meier-he/README.md | 2 -- examples/advanced/kaplan-meier-he/start_all.sh | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/advanced/kaplan-meier-he/README.md b/examples/advanced/kaplan-meier-he/README.md index 90bc31fe67..fd771aa322 100644 --- a/examples/advanced/kaplan-meier-he/README.md +++ b/examples/advanced/kaplan-meier-he/README.md @@ -142,8 +142,6 @@ If you want to quickly test production mode on a single machine, use the conveni 5. Monitor job via admin console (use `list_jobs`, `check_status client`) 6. Stop all parties with admin console -**Note:** You may see SSL handshake warnings in local testing - these are harmless and can be ignored. - For detailed steps and distributed deployment, continue below: **Step 1: Install NVFlare with HE Support** diff --git a/examples/advanced/kaplan-meier-he/start_all.sh b/examples/advanced/kaplan-meier-he/start_all.sh index 1b5856266a..abf70be009 100755 --- a/examples/advanced/kaplan-meier-he/start_all.sh +++ b/examples/advanced/kaplan-meier-he/start_all.sh @@ -60,7 +60,7 @@ echo "Server and clients are running in the background." 
echo "Logs are available in /tmp/nvflare/logs/" echo "" echo "To check status:" -echo " tail -f /tmp/nvflare/logs/server.log" +echo " tail -f /tmp/nvflare/logs/localhost.log" echo " tail -f /tmp/nvflare/logs/site-1.log" echo "" echo "To start admin console:" From f55f3704013c2ce82ac690b4556990260c4fb0ed Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Mon, 15 Dec 2025 13:18:55 -0500 Subject: [PATCH 3/7] readme update --- examples/advanced/kaplan-meier-he/start_all.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/advanced/kaplan-meier-he/start_all.sh b/examples/advanced/kaplan-meier-he/start_all.sh index abf70be009..578e579a9a 100755 --- a/examples/advanced/kaplan-meier-he/start_all.sh +++ b/examples/advanced/kaplan-meier-he/start_all.sh @@ -67,7 +67,6 @@ echo "To start admin console:" echo " cd ${WORKSPACE_ROOT}/admin@nvidia.com" echo " ./startup/fl_admin.sh" echo "" -echo "To stop all parties, run:" -echo " ./stop_all.sh" +echo "To stop all parties, use Admin Console" echo "" From b4836f07abe1241d1cbe88fbb7a5466cfcd14439 Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Mon, 15 Dec 2025 13:41:04 -0500 Subject: [PATCH 4/7] use ckks to ensure consistency --- examples/advanced/kaplan-meier-he/README.md | 10 ++++++---- .../advanced/kaplan-meier-he/client_he.py | 19 ++++++++++--------- .../advanced/kaplan-meier-he/server_he.py | 4 ++-- .../utils/prepare_he_context.py | 4 ++-- 4 files changed, 20 insertions(+), 17 deletions(-) diff --git a/examples/advanced/kaplan-meier-he/README.md b/examples/advanced/kaplan-meier-he/README.md index fd771aa322..d076ea0a6b 100644 --- a/examples/advanced/kaplan-meier-he/README.md +++ b/examples/advanced/kaplan-meier-he/README.md @@ -116,11 +116,13 @@ python utils/prepare_data.py --site_num 5 --bin_days 7 --out_path "/tmp/nvflare/ **Step 2: Prepare HE Context (Simulation Only)** -For simulation mode, manually prepare the HE context with BFV scheme: +For simulation mode, manually prepare the HE context with CKKS 
scheme: ```commandline python utils/prepare_he_context.py --out_path "/tmp/nvflare/he_context" ``` +(By default uses CKKS scheme. To use BFV, add `--scheme BFV`) + **Step 3: Run the Job** Run the job without and with HE: @@ -301,18 +303,18 @@ By default, this will generate a KM curve image `km_curve_fl.png` (or `km_curve_ ### HE Context and Data Management - **Simulation Mode**: - - Uses **BFV scheme** (integer arithmetic, suitable for histogram aggregation) + - Uses **CKKS scheme** (approximate arithmetic, compatible with production) - HE context files are manually created via `prepare_he_context.py` - Data prepared at `/tmp/nvflare/dataset/km_data` - Paths specified via `--he_context_path` and `--data_root` - **Production Mode**: - - Uses **CKKS scheme** (approximate arithmetic, easier provisioning) + - Uses **CKKS scheme** (same as simulation for consistency) - HE context is automatically provisioned into startup kits via `nvflare provision` - Clients use: `/client_context.tenseal` - Server uses: `/server_context.tenseal` - **Reuses the same data** from simulation mode at `/tmp/nvflare/dataset/km_data` by default -**Note:** Both BFV and CKKS schemes provide strong encryption and work well for this Kaplan-Meier analysis. BFV is used in simulation for exact integer operations, while CKKS is used in production for simpler provisioning and broader compatibility. Production mode can reuse the data prepared during simulation mode, eliminating redundant data preparation. +**Note:** CKKS scheme provides strong encryption with approximate arithmetic, which works well for this Kaplan-Meier analysis. The histogram counts are encrypted as floating-point numbers and rounded back to integers after decryption. Both simulation and production modes use the same CKKS scheme for consistency and compatibility. Production mode can reuse the data prepared during simulation mode, eliminating redundant data preparation. 
## Display Result diff --git a/examples/advanced/kaplan-meier-he/client_he.py b/examples/advanced/kaplan-meier-he/client_he.py index 1ff9c69dbb..2b280e5584 100644 --- a/examples/advanced/kaplan-meier-he/client_he.py +++ b/examples/advanced/kaplan-meier-he/client_he.py @@ -136,9 +136,9 @@ def main(): for i in range(len(idx)): hist_obs[idx[i]] = observed[i] hist_cen[idx[i]] = censored[i] - # Encrypt with tenseal using BFV scheme since observations are integers - hist_obs_he = ts.bfv_vector(he_context, list(hist_obs.values())) - hist_cen_he = ts.bfv_vector(he_context, list(hist_cen.values())) + # Encrypt with tenseal using CKKS scheme + hist_obs_he = ts.ckks_vector(he_context, list(hist_obs.values())) + hist_cen_he = ts.ckks_vector(he_context, list(hist_cen.values())) # Serialize for transmission hist_obs_he_serial = hist_obs_he.serialize() hist_cen_he_serial = hist_cen_he.serialize() @@ -153,19 +153,20 @@ def main(): hist_obs_global_serial = global_msg.params["hist_obs_global"] hist_cen_global_serial = global_msg.params["hist_cen_global"] # Deserialize - hist_obs_global = ts.bfv_vector_from(he_context, hist_obs_global_serial) - hist_cen_global = ts.bfv_vector_from(he_context, hist_cen_global_serial) + hist_obs_global = ts.ckks_vector_from(he_context, hist_obs_global_serial) + hist_cen_global = ts.ckks_vector_from(he_context, hist_cen_global_serial) # Decrypt - hist_obs_global = hist_obs_global.decrypt() - hist_cen_global = hist_cen_global.decrypt() + hist_obs_global = [int(round(x)) for x in hist_obs_global.decrypt()] + hist_cen_global = [int(round(x)) for x in hist_cen_global.decrypt()] # Unfold histogram to event list + # CKKS returns floats, so we round to nearest integer time_unfold = [] event_unfold = [] for i in range(max_idx_global): - for j in range(hist_obs_global[i]): + for j in range(int(hist_obs_global[i])): time_unfold.append(i) event_unfold.append(True) - for k in range(hist_cen_global[i]): + for k in range(int(hist_cen_global[i])): 
time_unfold.append(i) event_unfold.append(False) time_unfold = np.array(time_unfold) diff --git a/examples/advanced/kaplan-meier-he/server_he.py b/examples/advanced/kaplan-meier-he/server_he.py index 289edc3e6b..4d3e46f2eb 100644 --- a/examples/advanced/kaplan-meier-he/server_he.py +++ b/examples/advanced/kaplan-meier-he/server_he.py @@ -92,9 +92,9 @@ def aggr_he_hist(self, sag_result: dict[str, dict[str, FLModel]]): for fl_model in sag_result: site = fl_model.meta.get("client_name", None) hist_obs_he_serial = fl_model.params["hist_obs"] - hist_obs_he = ts.bfv_vector_from(he_context, hist_obs_he_serial) + hist_obs_he = ts.ckks_vector_from(he_context, hist_obs_he_serial) hist_cen_he_serial = fl_model.params["hist_cen"] - hist_cen_he = ts.bfv_vector_from(he_context, hist_cen_he_serial) + hist_cen_he = ts.ckks_vector_from(he_context, hist_cen_he_serial) if not hist_obs_global: print(f"assign global hist with result from {site}") diff --git a/examples/advanced/kaplan-meier-he/utils/prepare_he_context.py b/examples/advanced/kaplan-meier-he/utils/prepare_he_context.py index ceedf4c9a4..2ed894e357 100644 --- a/examples/advanced/kaplan-meier-he/utils/prepare_he_context.py +++ b/examples/advanced/kaplan-meier-he/utils/prepare_he_context.py @@ -21,8 +21,8 @@ def data_split_args_parser(): parser = argparse.ArgumentParser(description="Generate HE context") - parser.add_argument("--scheme", type=str, default="BFV", help="HE scheme, default is BFV") - parser.add_argument("--poly_modulus_degree", type=int, default=4096, help="Poly modulus degree, default is 4096") + parser.add_argument("--scheme", type=str, default="CKKS", help="HE scheme, default is CKKS") + parser.add_argument("--poly_modulus_degree", type=int, default=8192, help="Poly modulus degree, default is 8192") parser.add_argument("--out_path", type=str, help="Output root path for HE context files for client and server") return parser From 0ad3b350e357e34c41e81b500bd4ac8787049d2c Mon Sep 17 00:00:00 2001 From: Ziyue Xu 
Date: Mon, 15 Dec 2025 14:00:06 -0500 Subject: [PATCH 5/7] use ckks only --- examples/advanced/kaplan-meier-he/README.md | 13 ++++++++----- examples/advanced/kaplan-meier-he/job.py | 10 +++++----- examples/advanced/kaplan-meier-he/project.yml | 2 +- .../kaplan-meier-he/utils/prepare_he_context.py | 10 ++++++++-- 4 files changed, 22 insertions(+), 13 deletions(-) diff --git a/examples/advanced/kaplan-meier-he/README.md b/examples/advanced/kaplan-meier-he/README.md index d076ea0a6b..cbd908be55 100644 --- a/examples/advanced/kaplan-meier-he/README.md +++ b/examples/advanced/kaplan-meier-he/README.md @@ -72,9 +72,9 @@ We make use of FLARE ModelController API to implement the federated Kaplan-Meier The Flare ModelController API (`ModelController`) provides the functionality of flexible FLModel payloads for each round of federated analysis. This gives us the flexibility of transmitting various information needed by our scheme at different stages of federated learning. -Our [existing HE examples](../cifar10/cifar10-real-world) uses data filter mechanism for HE, provisioning the HE context information (specs and keys) for both client and server of the federated job under [CKKS](../../../nvflare/app_opt/he/model_encryptor.py) scheme. In this example, we would like to illustrate ModelController's capability in supporting customized needs beyond the existing HE functionalities (designed mainly for encrypting deep learning models). -- different HE schemes (BFV) rather than CKKS -- different content at different rounds of federated learning, and only specific payload needs to be encrypted +Our [existing HE examples](../cifar10/cifar10-real-world) uses data filter mechanism for HE, provisioning the HE context information (specs and keys) for both client and server of the federated job under [CKKS](../../../nvflare/app_opt/he/model_encryptor.py) scheme. 
In this example, we would like to illustrate ModelController's capability in supporting customized needs beyond the existing HE functionalities (designed mainly for encrypting deep learning models): +- Different content at different rounds of federated learning, where only specific payloads need to be encrypted +- Flexibility in choosing what to encrypt (histograms) versus what to send in plain text (metadata) With the ModelController API, such "proof of concept" experiment becomes easy. In this example, the federated analysis pipeline includes 2 rounds without HE, or 3 rounds with HE. @@ -82,7 +82,7 @@ For the federated analysis without HE, the detailed steps are as follows: 1. Server sends the simple start message without any payload. 2. Clients submit the local event histograms to server. Server aggregates the histograms with varying lengths by adding event counts of the same slot together, and sends the aggregated histograms back to clients. -For the federated analysis with HE, we need to ensure proper HE aggregation using BFV, and the detailed steps are as follows: +For the federated analysis with HE, we need to ensure proper HE aggregation using CKKS, and the detailed steps are as follows: 1. Server send the simple start message without any payload. 2. Clients collect the information of the local maximum bin number (for event time) and send to server, where server aggregates the information by selecting the maximum among all clients. The global maximum number is then distributed back to clients. This step is necessary because we would like to standardize the histograms generated by all clients, such that they will have the exact same length and can be encrypted as vectors of same size, which will be addable. 3. Clients condense their local raw event lists into two histograms with the global length received, encrypt the histrogram value vectors, and send to server. 
Server aggregated the received histograms by adding the encrypted vectors together, and sends the aggregated histograms back to clients. @@ -118,10 +118,13 @@ python utils/prepare_data.py --site_num 5 --bin_days 7 --out_path "/tmp/nvflare/ For simulation mode, manually prepare the HE context with CKKS scheme: ```commandline +# Remove old HE context if it exists +rm -rf /tmp/nvflare/he_context +# Generate new CKKS HE context python utils/prepare_he_context.py --out_path "/tmp/nvflare/he_context" ``` -(By default uses CKKS scheme. To use BFV, add `--scheme BFV`) +This generates the HE context with CKKS scheme (poly_modulus_degree=8192, global_scale=2^40) compatible with production mode. **Step 3: Run the Job** diff --git a/examples/advanced/kaplan-meier-he/job.py b/examples/advanced/kaplan-meier-he/job.py index 5db31410f8..f007a1d6d9 100644 --- a/examples/advanced/kaplan-meier-he/job.py +++ b/examples/advanced/kaplan-meier-he/job.py @@ -78,14 +78,14 @@ def define_parser(): parser.add_argument( "--workspace_dir", type=str, - default="/tmp/nvflare/workspaces/km", - help="Work directory for simulator runs, default to '/tmp/nvflare/workspaces/km'", + default="/tmp/nvflare/workspaces", + help="Work directory for simulator runs, default to '/tmp/nvflare/workspaces'", ) parser.add_argument( "--job_dir", type=str, - default="/tmp/nvflare/jobs/km", - help="Directory for job export, default to '/tmp/nvflare/jobs/km'", + default="/tmp/nvflare/jobs", + help="Directory for job export, default to '/tmp/nvflare/jobs'", ) parser.add_argument( "--encryption", @@ -140,7 +140,7 @@ def main(): # Determine job name for workspace directory job_name = "KM_HE" if args.encryption else "KM" - workspace_dir = args.workspace_dir.replace("/km/", f"/{job_name}/") + workspace_dir = f"{args.workspace_dir}/{job_name}" # Create the recipe recipe = KMRecipe( diff --git a/examples/advanced/kaplan-meier-he/project.yml b/examples/advanced/kaplan-meier-he/project.yml index 368b07338a..2a6cc01f22 100644 
--- a/examples/advanced/kaplan-meier-he/project.yml +++ b/examples/advanced/kaplan-meier-he/project.yml @@ -1,6 +1,6 @@ # Sample project.yml for Kaplan-Meier production deployment with HE (CKKS) # This file is used for provisioning secure startup kits with HE context -# Note: Simulation mode uses BFV scheme, production mode uses CKKS scheme +# Note: Both simulation and production modes use CKKS scheme for consistency api_version: 3 name: km_he_project diff --git a/examples/advanced/kaplan-meier-he/utils/prepare_he_context.py b/examples/advanced/kaplan-meier-he/utils/prepare_he_context.py index 2ed894e357..d9104ad77a 100644 --- a/examples/advanced/kaplan-meier-he/utils/prepare_he_context.py +++ b/examples/advanced/kaplan-meier-he/utils/prepare_he_context.py @@ -42,8 +42,14 @@ def main(): context = ts.context(scheme, poly_modulus_degree=args.poly_modulus_degree, plain_modulus=1032193) elif args.scheme == "CKKS": scheme = ts.SCHEME_TYPE.CKKS - # Generate HE context, CKKS does not need plain_modulus - context = ts.context(scheme, poly_modulus_degree=args.poly_modulus_degree) + # Generate HE context for CKKS + coeff_mod_bit_sizes = [60, 40, 40] + context = ts.context( + scheme, poly_modulus_degree=args.poly_modulus_degree, coeff_mod_bit_sizes=coeff_mod_bit_sizes + ) + # CKKS requires global scale to be set + context.generate_relin_keys() + context.global_scale = 2**40 else: raise ValueError("HE scheme not supported") From a7aee4d228a159d198274eed76d3fe91734a44c7 Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Mon, 5 Jan 2026 13:48:59 -0500 Subject: [PATCH 6/7] fix tenseal context locations --- examples/advanced/kaplan-meier-he/README.md | 106 +++++++----------- examples/advanced/kaplan-meier-he/client.py | 28 ++++- .../advanced/kaplan-meier-he/client_he.py | 63 +++++++++-- examples/advanced/kaplan-meier-he/job.py | 64 +++++------ .../advanced/kaplan-meier-he/server_he.py | 40 +++++-- 5 files changed, 177 insertions(+), 124 deletions(-) diff --git 
a/examples/advanced/kaplan-meier-he/README.md b/examples/advanced/kaplan-meier-he/README.md index cbd908be55..a4ede52c6b 100644 --- a/examples/advanced/kaplan-meier-he/README.md +++ b/examples/advanced/kaplan-meier-he/README.md @@ -89,6 +89,26 @@ For the federated analysis with HE, we need to ensure proper HE aggregation usin After these rounds, the federated work is completed. Then at each client, the aggregated histograms will be decrypted and converted back to an event list, and Kaplan-Meier analysis can be performed on the global information. +### HE Context and Data Management + +- **Simulation Mode**: + - Uses **CKKS scheme** (approximate arithmetic, compatible with production) + - HE context files are manually created via `prepare_he_context.py`: + - Client context: `/tmp/nvflare/he_context/he_context_client.txt` + - Server context: `/tmp/nvflare/he_context/he_context_server.txt` + - Data prepared at `/tmp/nvflare/dataset/km_data` + - Paths can be customized via `--he_context_path` (for client context) and `--data_root` +- **Production Mode**: + - Uses **CKKS scheme** + - HE context is automatically provisioned into startup kits via `nvflare provision` + - Context files are resolved by NVFlare's SecurityContentService: + - Clients automatically use: `client_context.tenseal` (from their startup kit) + - Server automatically uses: `server_context.tenseal` (from its startup kit) + - The `--he_context_path` parameter is ignored in production mode + - **Reuses the same data** from simulation mode at `/tmp/nvflare/dataset/km_data` by default + +**Note:** CKKS scheme provides strong encryption with approximate arithmetic, which works well for this Kaplan-Meier analysis. The histogram counts are encrypted as floating-point numbers and rounded back to integers after decryption. Both simulation and production modes use the same CKKS scheme for consistency and compatibility. 
Production mode can reuse the data prepared during simulation mode, eliminating redundant data preparation. + ## Run the job This example supports both **Simulation Mode** (for local testing) and **Production Mode** (for real-world deployment). @@ -134,18 +154,20 @@ python job.py python job.py --encryption ``` +The script will execute the job in simulation mode and display the job status. Results (KM curves and analysis details) will be saved to each simulated client's workspace directory under `/tmp/nvflare/workspaces/`. + ### Production Mode For production deployments, the HE context is automatically provisioned through secure startup kits. **Quick Start for Local Testing:** -If you want to quickly test production mode on a single machine, use the convenience scripts: +If you want to quickly test production mode on a single machine: 1. Run provisioning: `nvflare provision -p project.yml -w /tmp/nvflare/prod_workspaces` 2. Start all parties: `./start_all.sh` -3. Start admin console and use username `admin@nvidia.com` +3. Start admin console: `cd /tmp/nvflare/prod_workspaces/km_he_project/prod_00/admin@nvidia.com && ./startup/fl_admin.sh` (use username `admin@nvidia.com`) 4. Submit job: `python job.py --encryption --startup_kit_location /tmp/nvflare/prod_workspaces/km_he_project/prod_00/admin@nvidia.com` -5. Monitor job via admin console (use `list_jobs`, `check_status client`) -6. Stop all parties with admin console +5. Monitor job via admin console: `list_jobs`, `check_status client`, `download_job <job_id>` +6. Shutdown: `shutdown all` in admin console For detailed steps and distributed deployment, continue below: @@ -245,80 +267,38 @@ With all parties running, submit the job using the Recipe API. The job will auto python job.py --encryption --startup_kit_location /tmp/nvflare/prod_workspaces/km_he_project/prod_00/admin@nvidia.com ``` -The job will be submitted to the FL system and executed across all connected clients. +The script will output the job status.
Note the job ID from the output. **Monitoring Job Progress:** -The job runs asynchronously. To monitor progress, use the admin console: -``` -> list_jobs -> check_status server -> check_status client -``` +The job runs asynchronously on the FL system. Use the admin console to monitor progress: -To download job results after completion: -``` -> download_job +```commandline +# In the admin console +> list_jobs # View all jobs +> check_status server # Check server status +> check_status client # Check all clients status +> download_job <job_id> # Download results after completion ``` -Results will be saved to each client's workspace directory: -- `/tmp/nvflare/prod_workspaces/km_he_project/prod_00/site-1/` -- Check for `km_curve_fl_he.png` and `km_global.json` in each client's directory +Results will be saved to each client's workspace directory after the job completes: +- `/tmp/nvflare/prod_workspaces/km_he_project/prod_00/site-1/{JOB_ID}/` +- Look for `km_curve_fl_he.png` and `km_global.json` in each client's job directory + +**Note:** In production mode with HE, the HE context paths are automatically configured to use the provisioned context files from each participant's startup kit: +- Clients use: `client_context.tenseal` +- Server uses: `server_context.tenseal` -**Note:** In production mode with HE, the `--he_context_path` parameter is automatically set to use the provisioned `client_context.tenseal` or `server_context.tenseal` from each participant's startup kit. No manual HE context distribution is needed. +The `--he_context_path` parameter is only used for simulation mode and is ignored in production mode. No manual HE context distribution is needed in production.
**Step 6: Shutdown All Parties** -After the job completes, shut down all parties gracefully Via Admin Console +After the job completes, shut down all parties gracefully via admin console: ``` > shutdown all ``` -### Customization Options - -**Simulation Mode:** -```commandline -# Customize number of clients and threads -python job.py --num_clients 3 --num_threads 2 - -# Customize data paths (simulation only) -python job.py --data_root /custom/data/path --he_context_path /custom/he/path - -# Customize output directories -python job.py --workspace_dir /custom/workspace --job_dir /custom/jobdir - -# Combine options -python job.py --encryption --num_clients 10 --num_threads 5 -``` - -**Production Mode:** -```commandline -# Make sure server and all clients are started first (see Step 4) -# The he_context_path is automatically managed by startup kits -# By default, uses the same data from simulation mode at /tmp/nvflare/dataset/km_data -# Only customize data_root if you've moved the data elsewhere -python job.py --encryption --startup_kit_location /tmp/nvflare/prod_workspaces/km_he_project/prod_00/admin@nvidia.com --data_root /custom/data/path -``` - -By default, this will generate a KM curve image `km_curve_fl.png` (or `km_curve_fl_he.png` with encryption) under each client's directory. 
- -### HE Context and Data Management - -- **Simulation Mode**: - - Uses **CKKS scheme** (approximate arithmetic, compatible with production) - - HE context files are manually created via `prepare_he_context.py` - - Data prepared at `/tmp/nvflare/dataset/km_data` - - Paths specified via `--he_context_path` and `--data_root` -- **Production Mode**: - - Uses **CKKS scheme** (same as simulation for consistency) - - HE context is automatically provisioned into startup kits via `nvflare provision` - - Clients use: `/client_context.tenseal` - - Server uses: `/server_context.tenseal` - - **Reuses the same data** from simulation mode at `/tmp/nvflare/dataset/km_data` by default - -**Note:** CKKS scheme provides strong encryption with approximate arithmetic, which works well for this Kaplan-Meier analysis. The histogram counts are encrypted as floating-point numbers and rounded back to integers after decryption. Both simulation and production modes use the same CKKS scheme for consistency and compatibility. Production mode can reuse the data prepared during simulation mode, eliminating redundant data preparation. 
- ## Display Result By comparing the two curves, we can observe that all curves are identical: diff --git a/examples/advanced/kaplan-meier-he/client.py b/examples/advanced/kaplan-meier-he/client.py index d8d7e55d28..a4a6e1cb73 100644 --- a/examples/advanced/kaplan-meier-he/client.py +++ b/examples/advanced/kaplan-meier-he/client.py @@ -28,7 +28,7 @@ # Client code -def details_save(kmf): +def details_save(kmf, site_name): # Get the survival function at all observed time points survival_function_at_all_times = kmf.survival_function_ # Get the timeline (time points) @@ -46,13 +46,21 @@ def details_save(kmf): "event_count": event_count.tolist(), "survival_rate": survival_rate.tolist(), } - file_path = os.path.join(os.getcwd(), "km_global.json") + + # Save to job-specific directory + # The script is located at: site-X/{JOB_DIR}/app_site-X/custom/client.py (sim) or site-X/{JOB_ID}/app_site-X/custom/client.py (prod) + # We need to navigate up to the {JOB_DIR} directory + script_dir = os.path.dirname(os.path.abspath(__file__)) + # Go up 2 levels: custom -> app_site-X -> {JOB_DIR} + job_dir = os.path.abspath(os.path.join(script_dir, "..", "..")) + + file_path = os.path.join(job_dir, "km_global.json") print(f"save the details of KM analysis result to {file_path} \n") with open(file_path, "w") as json_file: json.dump(results, json_file, indent=4) -def plot_and_save(kmf): +def plot_and_save(kmf, site_name): # Plot and save the Kaplan-Meier survival curve plt.figure() plt.title("Federated") @@ -62,7 +70,15 @@ def plot_and_save(kmf): plt.xlabel("time") plt.legend("", frameon=False) plt.tight_layout() - file_path = os.path.join(os.getcwd(), "km_curve_fl.png") + + # Save to job-specific directory + # The script is located at: site-X/{JOB_DIR}/app_site-X/custom/client.py (sim) or site-X/{JOB_ID}/app_site-X/custom/client.py (prod) + # We need to navigate up to the {JOB_DIR} directory + script_dir = os.path.dirname(os.path.abspath(__file__)) + # Go up 2 levels: custom -> app_site-X 
-> {JOB_DIR} + job_dir = os.path.abspath(os.path.join(script_dir, "..", "..")) + + file_path = os.path.join(job_dir, "km_curve_fl.png") print(f"save the curve plot to {file_path} \n") plt.savefig(file_path) @@ -136,10 +152,10 @@ def main(): kmf.fit(durations=time_unfold, event_observed=event_unfold) # Plot and save the KM curve - plot_and_save(kmf) + plot_and_save(kmf, site_name) # Save details of the KM result to a json file - details_save(kmf) + details_save(kmf, site_name) # Send a simple response to server response = FLModel(params={}, params_type=ParamsType.FULL) diff --git a/examples/advanced/kaplan-meier-he/client_he.py b/examples/advanced/kaplan-meier-he/client_he.py index 2b280e5584..600710d32e 100644 --- a/examples/advanced/kaplan-meier-he/client_he.py +++ b/examples/advanced/kaplan-meier-he/client_he.py @@ -31,12 +31,33 @@ # Client code def read_data(file_name: str): + # Handle both absolute and relative paths + # In production mode, HE context files are in the startup directory + if not os.path.isabs(file_name) and not os.path.exists(file_name): + # Try CWD/startup/ (production deployment location) + cwd = os.getcwd() + startup_path = os.path.join(cwd, "startup", file_name) + if os.path.exists(startup_path): + file_name = startup_path + print(f"Using HE context file from startup directory: {file_name}") + with open(file_name, "rb") as f: data = f.read() - return base64.b64decode(data) + + # Handle both base64-encoded (simulation mode) and raw binary (production mode) formats + # Production mode (HEBuilder): files are raw binary (.tenseal) + # Simulation mode (prepare_he_context.py): files are base64-encoded (.txt) + if file_name.endswith(".tenseal"): + # Production mode: raw binary format + print("Using raw binary HE context (production mode)") + return data + else: + # Simulation mode: base64-encoded format (.txt files) + print("Using base64-encoded HE context (simulation mode)") + return base64.b64decode(data) -def details_save(kmf): +def 
details_save(kmf, site_name): # Get the survival function at all observed time points survival_function_at_all_times = kmf.survival_function_ # Get the timeline (time points) @@ -54,13 +75,21 @@ def details_save(kmf): "event_count": event_count.tolist(), "survival_rate": survival_rate.tolist(), } - file_path = os.path.join(os.getcwd(), "km_global.json") - print(f"save the details of KM analysis result to {file_path} \n") + + # Save to job-specific directory + # The script is located at: site-X/{JOB_ID}/app_site-X/custom/client_he.py + # We need to navigate up to the {JOB_ID} directory + script_dir = os.path.dirname(os.path.abspath(__file__)) + # Go up 2 levels: custom -> app_site-X -> {JOB_ID} + job_dir = os.path.abspath(os.path.join(script_dir, "..", "..")) + + file_path = os.path.join(job_dir, "km_global.json") + print(f"save the details of KM analysis result (cleartext) to {file_path} \n") with open(file_path, "w") as json_file: json.dump(results, json_file, indent=4) -def plot_and_save(kmf): +def plot_and_save(kmf, site_name): # Plot and save the Kaplan-Meier survival curve plt.figure() plt.title("Federated HE") @@ -70,7 +99,15 @@ def plot_and_save(kmf): plt.xlabel("time") plt.legend("", frameon=False) plt.tight_layout() - file_path = os.path.join(os.getcwd(), "km_curve_fl_he.png") + + # Save to job-specific directory + # The script is located at: site-X/{JOB_ID}/app_site-X/custom/client_he.py + # We need to navigate up to the {JOB_ID} directory + script_dir = os.path.dirname(os.path.abspath(__file__)) + # Go up 2 levels: custom -> app_site-X -> {JOB_ID} + job_dir = os.path.abspath(os.path.join(script_dir, "..", "..")) + + file_path = os.path.join(job_dir, "km_curve_fl_he.png") print(f"save the curve plot to {file_path} \n") plt.savefig(file_path) @@ -113,7 +150,7 @@ def main(): max_hist_idx = max(hist_idx) # Send max to server - print(f"send max hist index for site = {flare.get_site_name()}") + print(f"send max hist index (cleartext) for site = 
{flare.get_site_name()}") model = FLModel(params={"max_idx": max_hist_idx}, params_type=ParamsType.FULL) flare.send(model) @@ -121,8 +158,7 @@ def main(): # Second round, get global max index # Organize local histogram and encrypt max_idx_global = global_msg.params["max_idx_global"] - print("Global Max Idx") - print(max_idx_global) + print(f"Received global max idx (cleartext): {max_idx_global}") # Convert local table to uniform histogram hist_obs = {} hist_cen = {} @@ -143,6 +179,7 @@ def main(): hist_obs_he_serial = hist_obs_he.serialize() hist_cen_he_serial = hist_cen_he.serialize() # Send encrypted histograms to server + print("Send encrypted histograms (ciphertext) to server") response = FLModel( params={"hist_obs": hist_obs_he_serial, "hist_cen": hist_cen_he_serial}, params_type=ParamsType.FULL ) @@ -152,10 +189,12 @@ def main(): # Get global histograms hist_obs_global_serial = global_msg.params["hist_obs_global"] hist_cen_global_serial = global_msg.params["hist_cen_global"] + print("Received global accumulated histograms (ciphertext)") # Deserialize hist_obs_global = ts.ckks_vector_from(he_context, hist_obs_global_serial) hist_cen_global = ts.ckks_vector_from(he_context, hist_cen_global_serial) # Decrypt + print("Decrypting histograms to cleartext") hist_obs_global = [int(round(x)) for x in hist_obs_global.decrypt()] hist_cen_global = [int(round(x)) for x in hist_cen_global.decrypt()] # Unfold histogram to event list @@ -180,16 +219,16 @@ def main(): kmf.fit(durations=time_unfold, event_observed=event_unfold) # Plot and save the KM curve - plot_and_save(kmf) + plot_and_save(kmf, site_name) # Save details of the KM result to a json file - details_save(kmf) + details_save(kmf, site_name) # Send a simple response to server response = FLModel(params={}, params_type=ParamsType.FULL) flare.send(response) - print(f"finish send for {site_name}, complete") + print(f"Finish send for {site_name}, complete") if __name__ == "__main__": diff --git 
a/examples/advanced/kaplan-meier-he/job.py b/examples/advanced/kaplan-meier-he/job.py index f007a1d6d9..b581be6119 100644 --- a/examples/advanced/kaplan-meier-he/job.py +++ b/examples/advanced/kaplan-meier-he/job.py @@ -36,19 +36,21 @@ def __init__( num_clients: int, encryption: bool = False, data_root: str = "/tmp/nvflare/dataset/km_data", - he_context_path: str = "/tmp/nvflare/he_context/he_context_client.txt", + he_context_path_client: str = "/tmp/nvflare/he_context/he_context_client.txt", + he_context_path_server: str = "/tmp/nvflare/he_context/he_context_server.txt", ): self.num_clients = num_clients self.encryption = encryption self.data_root = data_root - self.he_context_path = he_context_path + self.he_context_path_client = he_context_path_client + self.he_context_path_server = he_context_path_server # Set job name and script based on encryption mode if self.encryption: job_name = "KM_HE" train_script = "client_he.py" - script_args = f"--data_root {data_root} --he_context_path {he_context_path}" - controller = KM_HE(min_clients=num_clients, he_context_path=he_context_path) + script_args = f"--data_root {data_root} --he_context_path {he_context_path_client}" + controller = KM_HE(min_clients=num_clients, he_context_path=he_context_path_server) else: job_name = "KM" train_script = "client.py" @@ -113,7 +115,7 @@ def define_parser(): "--he_context_path", type=str, default="/tmp/nvflare/he_context/he_context_client.txt", - help="Path to HE context file, default to '/tmp/nvflare/he_context/he_context_client.txt'", + help="Path to HE context file for simulation mode (client context), default to '/tmp/nvflare/he_context/he_context_client.txt'. 
In production mode, context files are auto-provisioned.", ) parser.add_argument( "--startup_kit_location", @@ -131,58 +133,52 @@ def define_parser(): def main(): - print("Starting Kaplan-Meier job...") args = define_parser() - print("args:", args) num_clients = args.num_clients num_threads = args.num_threads if args.num_threads else num_clients - # Determine job name for workspace directory - job_name = "KM_HE" if args.encryption else "KM" - workspace_dir = f"{args.workspace_dir}/{job_name}" + # Use workspace directory directly (SimEnv will create job-specific subdirectories) + workspace_dir = args.workspace_dir + + # Adjust HE context paths based on environment + if args.startup_kit_location and args.encryption: + # In production mode, use just the filename - NVFlare's SecurityContentService + # will resolve the path relative to each participant's workspace + he_context_path_client = "client_context.tenseal" + he_context_path_server = "server_context.tenseal" + elif args.encryption: + # In simulation mode, use the manually prepared context files + he_context_path_client = args.he_context_path + # Derive server path from client path + he_context_path_server = he_context_path_client.replace("he_context_client.txt", "he_context_server.txt") + else: + # No encryption - values won't be used + he_context_path_client = None + he_context_path_server = None # Create the recipe recipe = KMRecipe( num_clients=num_clients, encryption=args.encryption, data_root=args.data_root, - he_context_path=args.he_context_path, + he_context_path_client=he_context_path_client, + he_context_path_server=he_context_path_server, ) # Export job - print("Exporting job to", args.job_dir) recipe.job.export_job(args.job_dir) # Run recipe if args.startup_kit_location: - print("Running job in production mode...") - print("startup_kit_location=", args.startup_kit_location) - print("username=", args.username) + print("\n=== Production Mode ===") env = 
ProdEnv(startup_kit_location=args.startup_kit_location, username=args.username) else: - print("Running job in simulation mode...") - print("workspace_dir=", workspace_dir) - print("num_clients=", num_clients) - print("num_threads=", num_threads) + print("\n=== Simulation Mode ===") env = SimEnv(num_clients=num_clients, num_threads=num_threads, workspace_root=workspace_dir) run = recipe.execute(env) - print("Job Status is:", run.get_status()) - - # In production mode, job runs asynchronously on the FL system - # Check status via admin console instead of waiting for result here - if args.startup_kit_location: - print("\nJob submitted successfully to the FL system!") - print("To monitor job status, use the admin console:") - print(f" cd {args.startup_kit_location}") - print(" ./startup/fl_admin.sh") - print(" > check_status server") - print(" > list_jobs") - print(f" > download_job {run.job_id}") - else: - # In simulation mode, we can get the result synchronously - print("Job Result is:", run.get_result()) + print(f"Job Status: {run.get_status()}") if __name__ == "__main__": diff --git a/examples/advanced/kaplan-meier-he/server_he.py b/examples/advanced/kaplan-meier-he/server_he.py index 4d3e46f2eb..4b3ac40393 100644 --- a/examples/advanced/kaplan-meier-he/server_he.py +++ b/examples/advanced/kaplan-meier-he/server_he.py @@ -14,6 +14,7 @@ import base64 import logging +import os import tenseal as ts @@ -39,9 +40,30 @@ def run(self): _ = self.distribute_global_hist(hist_obs_global, hist_cen_global) def read_data(self, file_name: str): + # Handle both absolute and relative paths + # In production mode, HE context files are in the startup directory + if not os.path.isabs(file_name) and not os.path.exists(file_name): + # Try CWD/startup/ (production deployment location) + cwd = os.getcwd() + startup_path = os.path.join(cwd, "startup", file_name) + if os.path.exists(startup_path): + file_name = startup_path + self.logger.info(f"Using HE context file from startup directory: 
{file_name}") + with open(file_name, "rb") as f: data = f.read() - return base64.b64decode(data) + + # Handle both base64-encoded (simulation mode) and raw binary (production mode) formats + # Production mode (HEBuilder): files are raw binary (.tenseal) + # Simulation mode (prepare_he_context.py): files are base64-encoded (.txt) + if file_name.endswith(".tenseal"): + # Production mode: raw binary format + self.logger.info("Using raw binary HE context (production mode)") + return data + else: + # Simulation mode: base64-encoded format (.txt files) + self.logger.info("Using base64-encoded HE context (simulation mode)") + return base64.b64decode(data) def start_fl_collect_max_idx(self): self.logger.info("send initial message to all sites to start FL \n") @@ -51,7 +73,7 @@ def start_fl_collect_max_idx(self): return results def aggr_max_idx(self, sag_result: dict[str, dict[str, FLModel]]): - self.logger.info("aggregate max histogram index \n") + self.logger.info("aggregate max histogram index (cleartext) \n") if not sag_result: raise RuntimeError("input is None or empty") @@ -64,7 +86,7 @@ def aggr_max_idx(self, sag_result: dict[str, dict[str, FLModel]]): return max(max_idx_global) + 1 def distribute_max_idx_collect_enc_stats(self, result: int): - self.logger.info("send global max_index to all sites \n") + self.logger.info("send global max_index (cleartext) to all sites \n") model = FLModel( params={"max_idx_global": result}, @@ -78,7 +100,7 @@ def distribute_max_idx_collect_enc_stats(self, result: int): return results def aggr_he_hist(self, sag_result: dict[str, dict[str, FLModel]]): - self.logger.info("aggregate histogram within HE \n") + self.logger.info("aggregate histogram (ciphertext) within HE \n") # Load HE context he_context_serial = self.read_data(self.he_context_path) @@ -97,17 +119,17 @@ def aggr_he_hist(self, sag_result: dict[str, dict[str, FLModel]]): hist_cen_he = ts.ckks_vector_from(he_context, hist_cen_he_serial) if not hist_obs_global: - print(f"assign 
global hist with result from {site}") + self.logger.info(f"assign global hist (ciphertext) with result from {site}") hist_obs_global = hist_obs_he else: - print(f"add to global hist with result from {site}") + self.logger.info(f"add ciphertext to global hist with result from {site}") hist_obs_global += hist_obs_he if not hist_cen_global: - print(f"assign global hist with result from {site}") + self.logger.info(f"assign global censored hist (ciphertext) with result from {site}") hist_cen_global = hist_cen_he else: - print(f"add to global hist with result from {site}") + self.logger.info(f"add ciphertext to global censored hist with result from {site}") hist_cen_global += hist_cen_he # return the two accumulated vectors, serialized for transmission @@ -116,7 +138,7 @@ def aggr_he_hist(self, sag_result: dict[str, dict[str, FLModel]]): return hist_obs_global_serial, hist_cen_global_serial def distribute_global_hist(self, hist_obs_global_serial, hist_cen_global_serial): - self.logger.info("send global accumulated histograms within HE to all sites \n") + self.logger.info("send global accumulated histograms (ciphertext) to all sites \n") model = FLModel( params={"hist_obs_global": hist_obs_global_serial, "hist_cen_global": hist_cen_global_serial}, From 690a327769cf8a5d294aada6c989ceefa2307305 Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Mon, 5 Jan 2026 13:52:57 -0500 Subject: [PATCH 7/7] further polish --- examples/advanced/kaplan-meier-he/client.py | 12 ++++++------ examples/advanced/kaplan-meier-he/client_he.py | 13 ++++++------- examples/advanced/kaplan-meier-he/server.py | 5 +++-- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/examples/advanced/kaplan-meier-he/client.py b/examples/advanced/kaplan-meier-he/client.py index a4a6e1cb73..437ab11ae0 100644 --- a/examples/advanced/kaplan-meier-he/client.py +++ b/examples/advanced/kaplan-meier-he/client.py @@ -28,7 +28,7 @@ # Client code -def details_save(kmf, site_name): +def details_save(kmf): # Get the 
survival function at all observed time points survival_function_at_all_times = kmf.survival_function_ # Get the timeline (time points) @@ -60,7 +60,7 @@ def details_save(kmf, site_name): json.dump(results, json_file, indent=4) -def plot_and_save(kmf, site_name): +def plot_and_save(kmf): # Plot and save the Kaplan-Meier survival curve plt.figure() plt.title("Federated") @@ -110,10 +110,10 @@ def main(): # Empty payload from server, send local histogram # Convert local data to histogram event_table = survival_table_from_events(time_local, event_local) - hist_idx = event_table.index.values.astype(int) hist_obs = {} hist_cen = {} - for idx in range(max(hist_idx)): + max_hist_idx = max(event_table.index.values.astype(int)) + for idx in range(max_hist_idx): hist_obs[idx] = 0 hist_cen[idx] = 0 # Assign values @@ -152,10 +152,10 @@ def main(): kmf.fit(durations=time_unfold, event_observed=event_unfold) # Plot and save the KM curve - plot_and_save(kmf, site_name) + plot_and_save(kmf) # Save details of the KM result to a json file - details_save(kmf, site_name) + details_save(kmf) # Send a simple response to server response = FLModel(params={}, params_type=ParamsType.FULL) diff --git a/examples/advanced/kaplan-meier-he/client_he.py b/examples/advanced/kaplan-meier-he/client_he.py index 600710d32e..0dbb30e160 100644 --- a/examples/advanced/kaplan-meier-he/client_he.py +++ b/examples/advanced/kaplan-meier-he/client_he.py @@ -57,7 +57,7 @@ def read_data(file_name: str): return base64.b64decode(data) -def details_save(kmf, site_name): +def details_save(kmf): # Get the survival function at all observed time points survival_function_at_all_times = kmf.survival_function_ # Get the timeline (time points) @@ -89,7 +89,7 @@ def details_save(kmf, site_name): json.dump(results, json_file, indent=4) -def plot_and_save(kmf, site_name): +def plot_and_save(kmf): # Plot and save the Kaplan-Meier survival curve plt.figure() plt.title("Federated HE") @@ -145,12 +145,11 @@ def main(): # Empty 
payload from server, send max index back # Condense local data to histogram event_table = survival_table_from_events(time_local, event_local) - hist_idx = event_table.index.values.astype(int) # Get the max index to be synced globally - max_hist_idx = max(hist_idx) + max_hist_idx = max(event_table.index.values.astype(int)) # Send max to server - print(f"send max hist index (cleartext) for site = {flare.get_site_name()}") + print(f"send max hist index (cleartext) for site = {site_name}") model = FLModel(params={"max_idx": max_hist_idx}, params_type=ParamsType.FULL) flare.send(model) @@ -219,10 +218,10 @@ def main(): kmf.fit(durations=time_unfold, event_observed=event_unfold) # Plot and save the KM curve - plot_and_save(kmf, site_name) + plot_and_save(kmf) # Save details of the KM result to a json file - details_save(kmf, site_name) + details_save(kmf) # Send a simple response to server response = FLModel(params={}, params_type=ParamsType.FULL) diff --git a/examples/advanced/kaplan-meier-he/server.py b/examples/advanced/kaplan-meier-he/server.py index d6e3c3b401..7afe73d06a 100644 --- a/examples/advanced/kaplan-meier-he/server.py +++ b/examples/advanced/kaplan-meier-he/server.py @@ -47,8 +47,9 @@ def aggr_hist(self, sag_result: dict[str, dict[str, FLModel]]): hist_idx_max = 0 for fl_model in sag_result: hist = fl_model.params["hist_obs"] - if hist_idx_max < max(hist.keys()): - hist_idx_max = max(hist.keys()) + max_idx = max(hist.keys()) + if hist_idx_max < max_idx: + hist_idx_max = max_idx hist_idx_max += 1 hist_obs_global = {}