diff --git a/examples/advanced/kaplan-meier-he/README.md b/examples/advanced/kaplan-meier-he/README.md index 7e25710469..a4ede52c6b 100644 --- a/examples/advanced/kaplan-meier-he/README.md +++ b/examples/advanced/kaplan-meier-he/README.md @@ -2,7 +2,7 @@ This example illustrates two features: * How to perform Kaplan-Meier survival analysis in a federated setting without and with secure features via time-binning and Homomorphic Encryption (HE). -* How to use the Flare ModelController API to contract a workflow to facilitate HE under simulator mode. +* How to use the Recipe API with Flare ModelController for job configuration and execution in both simulation and production environments. ## Basics of Kaplan-Meier Analysis Kaplan-Meier survival analysis is a non-parametric statistic used to estimate the survival function from lifetime data. It is used to analyze the time it takes for an event of interest to occur. For example, during a clinical trial, the Kaplan-Meier estimator can be used to estimate the proportion of patients who survive a certain amount of time after treatment. @@ -62,7 +62,7 @@ To run the baseline script, simply execute: ```commandline python utils/baseline_kaplan_meier.py ``` -By default, this will generate a KM curve image `km_curve_baseline.png` under `/tmp` directory. The resulting KM curve is shown below: +By default, this will generate a KM curve image `km_curve_baseline.png` under the `/tmp/nvflare/baseline` directory. The resulting KM curve is shown below: ![KM survival baseline](figs/km_curve_baseline.png) Here, we show the survival curve for both daily (without binning) and weekly binning. The two curves align well with each other, while the weekly-binned curve has lower resolution. @@ -72,9 +72,9 @@ We make use of FLARE ModelController API to implement the federated Kaplan-Meier The Flare ModelController API (`ModelController`) provides the functionality of flexible FLModel payloads for each round of federated analysis.
This gives us the flexibility of transmitting various information needed by our scheme at different stages of federated learning. -Our [existing HE examples](../cifar10/cifar10-real-world) uses data filter mechanism for HE, provisioning the HE context information (specs and keys) for both client and server of the federated job under [CKKS](../../../nvflare/app_opt/he/model_encryptor.py) scheme. In this example, we would like to illustrate ModelController's capability in supporting customized needs beyond the existing HE functionalities (designed mainly for encrypting deep learning models). -- different HE schemes (BFV) rather than CKKS -- different content at different rounds of federated learning, and only specific payload needs to be encrypted +Our [existing HE examples](../cifar10/cifar10-real-world) use a data filter mechanism for HE, provisioning the HE context information (specs and keys) for both client and server of the federated job under the [CKKS](../../../nvflare/app_opt/he/model_encryptor.py) scheme. In this example, we would like to illustrate ModelController's capability in supporting customized needs beyond the existing HE functionalities (designed mainly for encrypting deep learning models): +- Different content at different rounds of federated learning, where only specific payloads need to be encrypted +- Flexibility in choosing what to encrypt (histograms) versus what to send in plain text (metadata) With the ModelController API, such a "proof of concept" experiment becomes easy. In this example, the federated analysis pipeline includes 2 rounds without HE, or 3 rounds with HE. @@ -82,31 +82,222 @@ For the federated analysis without HE, the detailed steps are as follows: 1. Server sends the simple start message without any payload. 2. Clients submit the local event histograms to the server. Server aggregates the histograms with varying lengths by adding event counts of the same slot together, and sends the aggregated histograms back to clients.
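The slot-wise aggregation in step 2 above can be sketched in plain Python. This is a minimal illustration of the described scheme, not the actual server controller code, and the function name is hypothetical:

```python
def aggregate_histograms(histograms):
    """Add event counts of the same slot together across client
    histograms, which may have different lengths."""
    global_len = max(len(hist) for hist in histograms)
    agg = [0] * global_len
    for hist in histograms:
        for slot, count in enumerate(hist):
            agg[slot] += count
    return agg


# Two clients with different observation horizons:
print(aggregate_histograms([[2, 0, 1], [1, 3]]))  # [3, 3, 1]
```

Because slots beyond a client's local horizon simply contribute zero, histograms of varying lengths can be merged without any pre-negotiation in the non-HE case.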
-For the federated analysis with HE, we need to ensure proper HE aggregation using BFV, and the detailed steps are as follows: +For the federated analysis with HE, we need to ensure proper HE aggregation using CKKS, and the detailed steps are as follows: 1. Server sends the simple start message without any payload. 2. Clients collect the local maximum bin number (for event time) and send it to the server, where the server aggregates the information by selecting the maximum among all clients. The global maximum number is then distributed back to clients. This step is necessary because we would like to standardize the histograms generated by all clients, such that they will have the exact same length and can be encrypted as vectors of the same size, which will be addable. 3. Clients condense their local raw event lists into two histograms with the global length received, encrypt the histogram value vectors, and send them to the server. Server aggregates the received histograms by adding the encrypted vectors together, and sends the aggregated histograms back to clients. After these rounds, the federated work is completed. Then at each client, the aggregated histograms will be decrypted and converted back to an event list, and Kaplan-Meier analysis can be performed on the global information.
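The final client-side step after decryption — converting the aggregated histograms back to an event list and estimating survival — can be sketched as follows. This is a simplified pure-Python stand-in for the lifelines fitter the example actually uses; the helper names are hypothetical:

```python
def unfold(hist_obs, hist_cen):
    """Expand aggregated (observed, censored) histograms into (time, event) pairs."""
    times, events = [], []
    for t, (obs, cen) in enumerate(zip(hist_obs, hist_cen)):
        times += [t] * (obs + cen)
        events += [True] * obs + [False] * cen
    return times, events


def km_survival(times, events):
    """Kaplan-Meier estimate: S(t) = prod over event times t_i <= t of (1 - d_i / n_i)."""
    n = len(times)  # at-risk count starts at the full cohort
    surv, s = {}, 1.0
    for t in sorted(set(times)):
        d = sum(1 for ti, ei in zip(times, events) if ti == t and ei)  # events at t
        s *= 1.0 - d / n
        surv[t] = s
        n -= sum(1 for ti in times if ti == t)  # drop subjects leaving the risk set
    return surv


times, events = unfold([2, 1], [0, 1])
print(km_survival(times, events))  # {0: 0.5, 1: 0.25}
```

Since every client holds the same decrypted global histograms, each client independently recovers the identical global survival curve.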
+### HE Context and Data Management + +- **Simulation Mode**: + - Uses **CKKS scheme** (approximate arithmetic, compatible with production) + - HE context files are manually created via `prepare_he_context.py`: + - Client context: `/tmp/nvflare/he_context/he_context_client.txt` + - Server context: `/tmp/nvflare/he_context/he_context_server.txt` + - Data prepared at `/tmp/nvflare/dataset/km_data` + - Paths can be customized via `--he_context_path` (for client context) and `--data_root` +- **Production Mode**: + - Uses **CKKS scheme** + - HE context is automatically provisioned into startup kits via `nvflare provision` + - Context files are resolved by NVFlare's SecurityContentService: + - Clients automatically use: `client_context.tenseal` (from their startup kit) + - Server automatically uses: `server_context.tenseal` (from its startup kit) + - The `--he_context_path` parameter is ignored in production mode + - **Reuses the same data** from simulation mode at `/tmp/nvflare/dataset/km_data` by default + +**Note:** CKKS scheme provides strong encryption with approximate arithmetic, which works well for this Kaplan-Meier analysis. The histogram counts are encrypted as floating-point numbers and rounded back to integers after decryption. Both simulation and production modes use the same CKKS scheme for consistency and compatibility. Production mode can reuse the data prepared during simulation mode, eliminating redundant data preparation. + ## Run the job -First, we prepared data for a 5-client federated job. We split and generate the data files for each client with binning interval of 7 days. + +This example supports both **Simulation Mode** (for local testing) and **Production Mode** (for real-world deployment). 
+ +| Feature | Simulation Mode | Production Mode | +|---------|----------------|-----------------| +| **Use Case** | Testing & Development | Real-world Deployment / Production Testing | +| **HE Context** | Manual preparation via script | Auto-provisioned via startup kits | +| **Security** | Single machine, no encryption between processes | Secure startup kits with certificates | +| **Setup** | Quick & simple | Requires provisioning & starting all parties | +| **Startup** | Single command | `start_all.sh` (local) or manual (distributed) | +| **Participants** | All run locally in one process | Distributed servers/clients running separately | +| **Data** | Prepared once, shared by all | Same data reused from simulation | + +### Simulation Mode + +For simulation mode (testing and development), we manually prepare the data and HE context: + +**Step 1: Prepare Data** + +Split and generate data files for each client with a binning interval of 7 days: ```commandline python utils/prepare_data.py --site_num 5 --bin_days 7 --out_path "/tmp/nvflare/dataset/km_data" ``` -Then we prepare HE context for clients and server, note that this step is done by secure provisioning for real-life applications, but in this study experimenting with BFV scheme, we use this step to distribute the HE context. +**Step 2: Prepare HE Context (Simulation Only)** + +For simulation mode, manually prepare the HE context with the CKKS scheme: ```commandline +# Remove old HE context if it exists +rm -rf /tmp/nvflare/he_context +# Generate new CKKS HE context python utils/prepare_he_context.py --out_path "/tmp/nvflare/he_context" ``` -Next, we run the federated training using NVFlare Simulator via [JobAPI](https://nvflare.readthedocs.io/en/main/programming_guide/fed_job_api.html), both without and with HE: +This generates the HE context with the CKKS scheme (poly_modulus_degree=8192, global_scale=2^40) compatible with production mode.
+ +**Step 3: Run the Job** + +Run the job without and with HE: +```commandline +python job.py +python job.py --encryption +``` + +The script will execute the job in simulation mode and display the job status. Results (KM curves and analysis details) will be saved to each simulated client's workspace directory under `/tmp/nvflare/workspaces/`. + +### Production Mode + +For production deployments, the HE context is automatically provisioned through secure startup kits. + +**Quick Start for Local Testing:** +If you want to quickly test production mode on a single machine: +1. Run provisioning: `nvflare provision -p project.yml -w /tmp/nvflare/prod_workspaces` +2. Start all parties: `./start_all.sh` +3. Start admin console: `cd /tmp/nvflare/prod_workspaces/km_he_project/prod_00/admin@nvidia.com && ./startup/fl_admin.sh` (use username `admin@nvidia.com`) +4. Submit job: `python job.py --encryption --startup_kit_location /tmp/nvflare/prod_workspaces/km_he_project/prod_00/admin@nvidia.com` +5. Monitor job via admin console: `list_jobs`, `check_status client`, `download_job <job_id>` +6. Shutdown: `shutdown all` in the admin console + +For detailed steps and distributed deployment, continue below: + +**Step 1: Install NVFlare with HE Support** + +```commandline +pip install nvflare[HE] +``` + +**Step 2: Provision Startup Kits with HE Context** + +The `project.yml` file in this directory is pre-configured with `HEBuilder` using the CKKS scheme. Run provisioning to output to `/tmp/nvflare/prod_workspaces`: + +```commandline +nvflare provision -p project.yml -w /tmp/nvflare/prod_workspaces +``` + +This generates startup kits in `/tmp/nvflare/prod_workspaces/km_he_project/prod_00/`: +- `localhost/` - Server startup kit with `server_context.tenseal` +- `site-1/`, `site-2/`, etc. - Client startup kits, each with `client_context.tenseal` +- `admin@nvidia.com/` - Admin console + +The HE context files are automatically included in each startup kit and do not need to be manually distributed.
+ +**Step 3: Distribute Startup Kits** + +Securely distribute the startup kits to each participant from `/tmp/nvflare/prod_workspaces/km_he_project/prod_00/`: +- `localhost/` directory is the server (for local testing, no need to send) +- Send `site-1/`, `site-2/`, etc. directories to each client host (for distributed deployment) +- Keep `admin@nvidia.com/` directory for the admin user + +**Step 4: Start All Parties** + +**Option A: Quick Start (Local Testing)** + +For local testing where all parties run on the same machine, use the convenience script: + +```commandline +./start_all.sh +``` + +This will start the server and all 5 clients in the background. Logs are saved to `/tmp/nvflare/logs/`. + +Then start the admin console: +```commandline +cd /tmp/nvflare/prod_workspaces/km_he_project/prod_00/admin@nvidia.com +./startup/fl_admin.sh +``` + +**Important:** When prompted for "User Name:", enter `admin@nvidia.com` (this matches the admin defined in project.yml). + +Once connected, check the status of all participants: +``` +> check_status server +> check_status client +``` + +**Option B: Manual Start (Distributed Deployment)** + +For distributed deployment where parties run on different machines: + +**On the Server Host:** +```commandline +cd /tmp/nvflare/prod_workspaces/km_he_project/prod_00/localhost +./startup/start.sh +``` + +Wait for the server to be ready (you should see "Server started" in the logs). 
+ +**On Each Client Host:** +```commandline +# On site-1 +cd /tmp/nvflare/prod_workspaces/km_he_project/prod_00/site-1 +./startup/start.sh + +# On site-2 +cd /tmp/nvflare/prod_workspaces/km_he_project/prod_00/site-2 +./startup/start.sh + +# Repeat for site-3, site-4, and site-5 +``` + +**On the Admin Machine:** ```commandline -python km_job.py -python km_job.py --encryption +cd /tmp/nvflare/prod_workspaces/km_he_project/prod_00/admin@nvidia.com +./startup/fl_admin.sh +# When prompted, use username: admin@nvidia.com ``` -By default, this will generate a KM curve image `km_curve_fl.png` and `km_curve_fl_he.png` under each client's directory. +**Step 5: Submit and Run the Job** + +With all parties running, submit the job using the Recipe API. The job will automatically use: +- The provisioned HE context from each participant's startup kit +- The data already prepared in simulation mode at `/tmp/nvflare/dataset/km_data` + +```commandline +python job.py --encryption --startup_kit_location /tmp/nvflare/prod_workspaces/km_he_project/prod_00/admin@nvidia.com +``` + +The script will output the job status. Note the job ID from the output. + +**Monitoring Job Progress:** + +The job runs asynchronously on the FL system. 
Use the admin console to monitor progress: + +```commandline +# In the admin console +> list_jobs # View all jobs +> check_status server # Check server status +> check_status client # Check all clients' status +> download_job <job_id> # Download results after completion +``` + +Results will be saved to each client's workspace directory after the job completes: +- `/tmp/nvflare/prod_workspaces/km_he_project/prod_00/site-1/{JOB_ID}/` +- Look for `km_curve_fl_he.png` and `km_global.json` in each client's job directory + +**Note:** In production mode with HE, the HE context paths are automatically configured to use the provisioned context files from each participant's startup kit: +- Clients use: `client_context.tenseal` +- Server uses: `server_context.tenseal` + +The `--he_context_path` parameter is only used for simulation mode and is ignored in production mode. No manual HE context distribution is needed in production. + +**Step 6: Shutdown All Parties** + +After the job completes, shut down all parties gracefully via the admin console: + +``` +> shutdown all +``` ## Display Result diff --git a/examples/advanced/kaplan-meier-he/src/kaplan_meier_train.py b/examples/advanced/kaplan-meier-he/client.py similarity index 83% rename from examples/advanced/kaplan-meier-he/src/kaplan_meier_train.py rename to examples/advanced/kaplan-meier-he/client.py index d8d7e55d28..437ab11ae0 100644 --- a/examples/advanced/kaplan-meier-he/src/kaplan_meier_train.py +++ b/examples/advanced/kaplan-meier-he/client.py @@ -46,7 +46,15 @@ def details_save(kmf): "event_count": event_count.tolist(), "survival_rate": survival_rate.tolist(), } - file_path = os.path.join(os.getcwd(), "km_global.json") + + # Save to job-specific directory + # The script is located at: site-X/{JOB_DIR}/app_site-X/custom/client.py (sim) or site-X/{JOB_ID}/app_site-X/custom/client.py (prod) + # We need to navigate up to the {JOB_DIR} directory + script_dir = os.path.dirname(os.path.abspath(__file__)) + # Go up 2 levels: custom ->
app_site-X -> {JOB_DIR} + job_dir = os.path.abspath(os.path.join(script_dir, "..", "..")) + + file_path = os.path.join(job_dir, "km_global.json") print(f"save the details of KM analysis result to {file_path} \n") with open(file_path, "w") as json_file: json.dump(results, json_file, indent=4) @@ -62,7 +70,15 @@ def plot_and_save(kmf): plt.xlabel("time") plt.legend("", frameon=False) plt.tight_layout() - file_path = os.path.join(os.getcwd(), "km_curve_fl.png") + + # Save to job-specific directory + # The script is located at: site-X/{JOB_DIR}/app_site-X/custom/client.py (sim) or site-X/{JOB_ID}/app_site-X/custom/client.py (prod) + # We need to navigate up to the {JOB_DIR} directory + script_dir = os.path.dirname(os.path.abspath(__file__)) + # Go up 2 levels: custom -> app_site-X -> {JOB_DIR} + job_dir = os.path.abspath(os.path.join(script_dir, "..", "..")) + + file_path = os.path.join(job_dir, "km_curve_fl.png") print(f"save the curve plot to {file_path} \n") plt.savefig(file_path) @@ -94,10 +110,10 @@ def main(): # Empty payload from server, send local histogram # Convert local data to histogram event_table = survival_table_from_events(time_local, event_local) - hist_idx = event_table.index.values.astype(int) hist_obs = {} hist_cen = {} - for idx in range(max(hist_idx)): + max_hist_idx = max(event_table.index.values.astype(int)) + for idx in range(max_hist_idx): hist_obs[idx] = 0 hist_cen[idx] = 0 # Assign values diff --git a/examples/advanced/kaplan-meier-he/src/kaplan_meier_train_he.py b/examples/advanced/kaplan-meier-he/client_he.py similarity index 66% rename from examples/advanced/kaplan-meier-he/src/kaplan_meier_train_he.py rename to examples/advanced/kaplan-meier-he/client_he.py index 1ff9c69dbb..0dbb30e160 100644 --- a/examples/advanced/kaplan-meier-he/src/kaplan_meier_train_he.py +++ b/examples/advanced/kaplan-meier-he/client_he.py @@ -31,9 +31,30 @@ # Client code def read_data(file_name: str): + # Handle both absolute and relative paths + # In production 
mode, HE context files are in the startup directory + if not os.path.isabs(file_name) and not os.path.exists(file_name): + # Try CWD/startup/ (production deployment location) + cwd = os.getcwd() + startup_path = os.path.join(cwd, "startup", file_name) + if os.path.exists(startup_path): + file_name = startup_path + print(f"Using HE context file from startup directory: {file_name}") + with open(file_name, "rb") as f: data = f.read() - return base64.b64decode(data) + + # Handle both base64-encoded (simulation mode) and raw binary (production mode) formats + # Production mode (HEBuilder): files are raw binary (.tenseal) + # Simulation mode (prepare_he_context.py): files are base64-encoded (.txt) + if file_name.endswith(".tenseal"): + # Production mode: raw binary format + print("Using raw binary HE context (production mode)") + return data + else: + # Simulation mode: base64-encoded format (.txt files) + print("Using base64-encoded HE context (simulation mode)") + return base64.b64decode(data) def details_save(kmf): @@ -54,8 +75,16 @@ def details_save(kmf): "event_count": event_count.tolist(), "survival_rate": survival_rate.tolist(), } - file_path = os.path.join(os.getcwd(), "km_global.json") - print(f"save the details of KM analysis result to {file_path} \n") + + # Save to job-specific directory + # The script is located at: site-X/{JOB_ID}/app_site-X/custom/client_he.py + # We need to navigate up to the {JOB_ID} directory + script_dir = os.path.dirname(os.path.abspath(__file__)) + # Go up 2 levels: custom -> app_site-X -> {JOB_ID} + job_dir = os.path.abspath(os.path.join(script_dir, "..", "..")) + + file_path = os.path.join(job_dir, "km_global.json") + print(f"save the details of KM analysis result (cleartext) to {file_path} \n") with open(file_path, "w") as json_file: json.dump(results, json_file, indent=4) @@ -70,7 +99,15 @@ def plot_and_save(kmf): plt.xlabel("time") plt.legend("", frameon=False) plt.tight_layout() - file_path = os.path.join(os.getcwd(), 
"km_curve_fl_he.png") + + # Save to job-specific directory + # The script is located at: site-X/{JOB_ID}/app_site-X/custom/client_he.py + # We need to navigate up to the {JOB_ID} directory + script_dir = os.path.dirname(os.path.abspath(__file__)) + # Go up 2 levels: custom -> app_site-X -> {JOB_ID} + job_dir = os.path.abspath(os.path.join(script_dir, "..", "..")) + + file_path = os.path.join(job_dir, "km_curve_fl_he.png") print(f"save the curve plot to {file_path} \n") plt.savefig(file_path) @@ -108,12 +145,11 @@ def main(): # Empty payload from server, send max index back # Condense local data to histogram event_table = survival_table_from_events(time_local, event_local) - hist_idx = event_table.index.values.astype(int) # Get the max index to be synced globally - max_hist_idx = max(hist_idx) + max_hist_idx = max(event_table.index.values.astype(int)) # Send max to server - print(f"send max hist index for site = {flare.get_site_name()}") + print(f"send max hist index (cleartext) for site = {site_name}") model = FLModel(params={"max_idx": max_hist_idx}, params_type=ParamsType.FULL) flare.send(model) @@ -121,8 +157,7 @@ def main(): # Second round, get global max index # Organize local histogram and encrypt max_idx_global = global_msg.params["max_idx_global"] - print("Global Max Idx") - print(max_idx_global) + print(f"Received global max idx (cleartext): {max_idx_global}") # Convert local table to uniform histogram hist_obs = {} hist_cen = {} @@ -136,13 +171,14 @@ def main(): for i in range(len(idx)): hist_obs[idx[i]] = observed[i] hist_cen[idx[i]] = censored[i] - # Encrypt with tenseal using BFV scheme since observations are integers - hist_obs_he = ts.bfv_vector(he_context, list(hist_obs.values())) - hist_cen_he = ts.bfv_vector(he_context, list(hist_cen.values())) + # Encrypt with tenseal using CKKS scheme + hist_obs_he = ts.ckks_vector(he_context, list(hist_obs.values())) + hist_cen_he = ts.ckks_vector(he_context, list(hist_cen.values())) # Serialize for 
transmission hist_obs_he_serial = hist_obs_he.serialize() hist_cen_he_serial = hist_cen_he.serialize() # Send encrypted histograms to server + print("Send encrypted histograms (ciphertext) to server") response = FLModel( params={"hist_obs": hist_obs_he_serial, "hist_cen": hist_cen_he_serial}, params_type=ParamsType.FULL ) @@ -152,20 +188,23 @@ def main(): # Get global histograms hist_obs_global_serial = global_msg.params["hist_obs_global"] hist_cen_global_serial = global_msg.params["hist_cen_global"] + print("Received global accumulated histograms (ciphertext)") # Deserialize - hist_obs_global = ts.bfv_vector_from(he_context, hist_obs_global_serial) - hist_cen_global = ts.bfv_vector_from(he_context, hist_cen_global_serial) + hist_obs_global = ts.ckks_vector_from(he_context, hist_obs_global_serial) + hist_cen_global = ts.ckks_vector_from(he_context, hist_cen_global_serial) # Decrypt - hist_obs_global = hist_obs_global.decrypt() - hist_cen_global = hist_cen_global.decrypt() + print("Decrypting histograms to cleartext") + hist_obs_global = [int(round(x)) for x in hist_obs_global.decrypt()] + hist_cen_global = [int(round(x)) for x in hist_cen_global.decrypt()] # Unfold histogram to event list + # CKKS returns floats, so we round to nearest integer time_unfold = [] event_unfold = [] for i in range(max_idx_global): - for j in range(hist_obs_global[i]): + for j in range(int(hist_obs_global[i])): time_unfold.append(i) event_unfold.append(True) - for k in range(hist_cen_global[i]): + for k in range(int(hist_cen_global[i])): time_unfold.append(i) event_unfold.append(False) time_unfold = np.array(time_unfold) @@ -188,7 +227,7 @@ def main(): response = FLModel(params={}, params_type=ParamsType.FULL) flare.send(response) - print(f"finish send for {site_name}, complete") + print(f"Finish send for {site_name}, complete") if __name__ == "__main__": diff --git a/examples/advanced/kaplan-meier-he/job.py b/examples/advanced/kaplan-meier-he/job.py new file mode 100644 index 
0000000000..b581be6119 --- /dev/null +++ b/examples/advanced/kaplan-meier-he/job.py @@ -0,0 +1,185 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +from server import KM +from server_he import KM_HE + +from nvflare import FedJob +from nvflare.job_config.script_runner import ScriptRunner +from nvflare.recipe import ProdEnv, SimEnv +from nvflare.recipe.spec import Recipe + + +class KMRecipe(Recipe): + """Recipe wrapper around the Kaplan-Meier job configuration. + + This provides a recipe-style API for easy job configuration and execution + in both simulation and production environments. 
+ """ + + def __init__( + self, + *, + num_clients: int, + encryption: bool = False, + data_root: str = "/tmp/nvflare/dataset/km_data", + he_context_path_client: str = "/tmp/nvflare/he_context/he_context_client.txt", + he_context_path_server: str = "/tmp/nvflare/he_context/he_context_server.txt", + ): + self.num_clients = num_clients + self.encryption = encryption + self.data_root = data_root + self.he_context_path_client = he_context_path_client + self.he_context_path_server = he_context_path_server + + # Set job name and script based on encryption mode + if self.encryption: + job_name = "KM_HE" + train_script = "client_he.py" + script_args = f"--data_root {data_root} --he_context_path {he_context_path_client}" + controller = KM_HE(min_clients=num_clients, he_context_path=he_context_path_server) + else: + job_name = "KM" + train_script = "client.py" + script_args = f"--data_root {data_root}" + controller = KM(min_clients=num_clients) + + # Create the FedJob + job = FedJob(name=job_name, min_clients=num_clients) + + # Add controller workflow to server + job.to_server(controller) + + # Add ScriptRunner to all clients + runner = ScriptRunner( + script=train_script, + script_args=script_args, + framework="raw", + launch_external_process=False, + ) + job.to_clients(runner, tasks=["train"]) + + super().__init__(job) + + +def define_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--workspace_dir", + type=str, + default="/tmp/nvflare/workspaces", + help="Work directory for simulator runs, default to '/tmp/nvflare/workspaces'", + ) + parser.add_argument( + "--job_dir", + type=str, + default="/tmp/nvflare/jobs", + help="Directory for job export, default to '/tmp/nvflare/jobs'", + ) + parser.add_argument( + "--encryption", + action=argparse.BooleanOptionalAction, + help="Whether to enable encryption, default to False", + ) + parser.add_argument( + "--num_clients", + type=int, + default=5, + help="Number of clients to simulate, default to 5", + ) + 
parser.add_argument( + "--num_threads", + type=int, + help="Number of threads to use for FL simulation, default to the number of clients if not specified", + ) + parser.add_argument( + "--data_root", + type=str, + default="/tmp/nvflare/dataset/km_data", + help="Root directory for KM data, default to '/tmp/nvflare/dataset/km_data'", + ) + parser.add_argument( + "--he_context_path", + type=str, + default="/tmp/nvflare/he_context/he_context_client.txt", + help="Path to HE context file for simulation mode (client context), default to '/tmp/nvflare/he_context/he_context_client.txt'. In production mode, context files are auto-provisioned.", + ) + parser.add_argument( + "--startup_kit_location", + type=str, + default=None, + help="Startup kit location for production mode, default to None (simulation mode)", + ) + parser.add_argument( + "--username", + type=str, + default="admin@nvidia.com", + help="Username for production mode, default to 'admin@nvidia.com'", + ) + return parser.parse_args() + + +def main(): + args = define_parser() + + num_clients = args.num_clients + num_threads = args.num_threads if args.num_threads else num_clients + + # Use workspace directory directly (SimEnv will create job-specific subdirectories) + workspace_dir = args.workspace_dir + + # Adjust HE context paths based on environment + if args.startup_kit_location and args.encryption: + # In production mode, use just the filename - NVFlare's SecurityContentService + # will resolve the path relative to each participant's workspace + he_context_path_client = "client_context.tenseal" + he_context_path_server = "server_context.tenseal" + elif args.encryption: + # In simulation mode, use the manually prepared context files + he_context_path_client = args.he_context_path + # Derive server path from client path + he_context_path_server = he_context_path_client.replace("he_context_client.txt", "he_context_server.txt") + else: + # No encryption - values won't be used + he_context_path_client = None + 
he_context_path_server = None + + # Create the recipe + recipe = KMRecipe( + num_clients=num_clients, + encryption=args.encryption, + data_root=args.data_root, + he_context_path_client=he_context_path_client, + he_context_path_server=he_context_path_server, + ) + + # Export job + recipe.job.export_job(args.job_dir) + + # Run recipe + if args.startup_kit_location: + print("\n=== Production Mode ===") + env = ProdEnv(startup_kit_location=args.startup_kit_location, username=args.username) + else: + print("\n=== Simulation Mode ===") + env = SimEnv(num_clients=num_clients, num_threads=num_threads, workspace_root=workspace_dir) + + run = recipe.execute(env) + print(f"Job Status: {run.get_status()}") + + +if __name__ == "__main__": + main() diff --git a/examples/advanced/kaplan-meier-he/km_job.py b/examples/advanced/kaplan-meier-he/km_job.py deleted file mode 100644 index 1f59ac6a13..0000000000 --- a/examples/advanced/kaplan-meier-he/km_job.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import argparse -import os - -from src.kaplan_meier_wf import KM -from src.kaplan_meier_wf_he import KM_HE - -from nvflare import FedJob -from nvflare.job_config.script_runner import ScriptRunner - - -def main(): - args = define_parser() - # Default paths - data_root = "/tmp/nvflare/dataset/km_data" - he_context_path = "/tmp/nvflare/he_context/he_context_client.txt" - - # Set the script and config - if args.encryption: - job_name = "KM_HE" - train_script = "src/kaplan_meier_train_he.py" - script_args = f"--data_root {data_root} --he_context_path {he_context_path}" - else: - job_name = "KM" - train_script = "src/kaplan_meier_train.py" - script_args = f"--data_root {data_root}" - - # Set the number of clients and threads - num_clients = args.num_clients - if args.num_threads: - num_threads = args.num_threads - else: - num_threads = num_clients - - # Set the output workspace and job directories - workspace_dir = os.path.join(args.workspace_dir, job_name) - job_dir = args.job_dir - - # Create the FedJob - job = FedJob(name=job_name, min_clients=num_clients) - - # Define the KM controller workflow and send to server - if args.encryption: - controller = KM_HE(min_clients=num_clients, he_context_path=he_context_path) - else: - controller = KM(min_clients=num_clients) - job.to_server(controller) - - # Define the ScriptRunner and send to all clients - runner = ScriptRunner( - script=train_script, - script_args=script_args, - framework="raw", - params_exchange_format="raw", - launch_external_process=False, - ) - job.to_clients(runner, tasks=["train"]) - - # Export the job - print("job_dir=", job_dir) - job.export_job(job_dir) - - # Run the job - print("workspace_dir=", workspace_dir) - print("num_threads=", num_threads) - job.simulator_run(workspace_dir, n_clients=num_clients, threads=num_threads) - - -def define_parser(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--workspace_dir", - type=str, - default="/tmp/nvflare/jobs/km/workdir", - help="work 
directory, default to '/tmp/nvflare/jobs/km/workdir'", - ) - parser.add_argument( - "--job_dir", - type=str, - default="/tmp/nvflare/jobs/km/jobdir", - help="directory for job export, default to '/tmp/nvflare/jobs/km/jobdir'", - ) - parser.add_argument( - "--encryption", - action=argparse.BooleanOptionalAction, - help="whether to enable encryption, default to False", - ) - parser.add_argument( - "--num_clients", - type=int, - default=5, - help="number of clients to simulate, default to 5", - ) - parser.add_argument( - "--num_threads", - type=int, - help="number of threads to use for FL simulation, default to the number of clients if not specified", - ) - return parser.parse_args() - - -if __name__ == "__main__": - main() diff --git a/examples/advanced/kaplan-meier-he/project.yml b/examples/advanced/kaplan-meier-he/project.yml new file mode 100644 index 0000000000..2a6cc01f22 --- /dev/null +++ b/examples/advanced/kaplan-meier-he/project.yml @@ -0,0 +1,52 @@ +# Sample project.yml for Kaplan-Meier production deployment with HE (CKKS) +# This file is used for provisioning secure startup kits with HE context +# Note: Both simulation and production modes use CKKS scheme for consistency + +api_version: 3 +name: km_he_project +description: Kaplan-Meier Survival Analysis with Homomorphic Encryption + +participants: + - name: localhost + type: server + org: example + fed_learn_port: 8002 + admin_port: 8003 + - name: site-1 + type: client + org: example + - name: site-2 + type: client + org: example + - name: site-3 + type: client + org: example + - name: site-4 + type: client + org: example + - name: site-5 + type: client + org: example + - name: admin@nvidia.com + type: admin + org: nvidia + role: project_admin + +# The same methods in all builders are called in their order defined in builders section +builders: + - path: nvflare.lighter.impl.workspace.WorkspaceBuilder + args: + template_file: master_template.yml + - path: nvflare.lighter.impl.static_file.StaticFileBuilder 
+ args: + config_folder: config + - path: nvflare.lighter.impl.cert.CertBuilder + # HEBuilder provisions HE context for CKKS scheme + - path: nvflare.lighter.impl.he.HEBuilder + args: + poly_modulus_degree: 8192 + coeff_mod_bit_sizes: [60, 40, 40] + scale_bits: 40 + scheme: CKKS # Using CKKS scheme for production deployment + - path: nvflare.lighter.impl.signature.SignatureBuilder + diff --git a/examples/advanced/kaplan-meier-he/src/kaplan_meier_wf.py b/examples/advanced/kaplan-meier-he/server.py similarity index 94% rename from examples/advanced/kaplan-meier-he/src/kaplan_meier_wf.py rename to examples/advanced/kaplan-meier-he/server.py index 436778fa57..7afe73d06a 100644 --- a/examples/advanced/kaplan-meier-he/src/kaplan_meier_wf.py +++ b/examples/advanced/kaplan-meier-he/server.py @@ -21,7 +21,7 @@ # Controller Workflow class KM(ModelController): def __init__(self, min_clients: int): - super(KM, self).__init__() + super(KM, self).__init__(persistor_id="") self.logger = logging.getLogger(self.__class__.__name__) self.min_clients = min_clients self.num_rounds = 2 @@ -47,8 +47,9 @@ def aggr_hist(self, sag_result: dict[str, dict[str, FLModel]]): hist_idx_max = 0 for fl_model in sag_result: hist = fl_model.params["hist_obs"] - if hist_idx_max < max(hist.keys()): - hist_idx_max = max(hist.keys()) + max_idx = max(hist.keys()) + if hist_idx_max < max_idx: + hist_idx_max = max_idx hist_idx_max += 1 hist_obs_global = {} diff --git a/examples/advanced/kaplan-meier-he/src/kaplan_meier_wf_he.py b/examples/advanced/kaplan-meier-he/server_he.py similarity index 66% rename from examples/advanced/kaplan-meier-he/src/kaplan_meier_wf_he.py rename to examples/advanced/kaplan-meier-he/server_he.py index 820848ea6f..4b3ac40393 100644 --- a/examples/advanced/kaplan-meier-he/src/kaplan_meier_wf_he.py +++ b/examples/advanced/kaplan-meier-he/server_he.py @@ -14,6 +14,7 @@ import base64 import logging +import os import tenseal as ts @@ -25,7 +26,7 @@ class KM_HE(ModelController): def 
__init__(self, min_clients: int, he_context_path: str): - super(KM_HE, self).__init__() + super(KM_HE, self).__init__(persistor_id="") self.logger = logging.getLogger(self.__class__.__name__) self.min_clients = min_clients self.he_context_path = he_context_path @@ -39,9 +40,30 @@ def run(self): _ = self.distribute_global_hist(hist_obs_global, hist_cen_global) def read_data(self, file_name: str): + # Handle both absolute and relative paths + # In production mode, HE context files are in the startup directory + if not os.path.isabs(file_name) and not os.path.exists(file_name): + # Try CWD/startup/ (production deployment location) + cwd = os.getcwd() + startup_path = os.path.join(cwd, "startup", file_name) + if os.path.exists(startup_path): + file_name = startup_path + self.logger.info(f"Using HE context file from startup directory: {file_name}") + with open(file_name, "rb") as f: data = f.read() - return base64.b64decode(data) + + # Handle both base64-encoded (simulation mode) and raw binary (production mode) formats + # Production mode (HEBuilder): files are raw binary (.tenseal) + # Simulation mode (prepare_he_context.py): files are base64-encoded (.txt) + if file_name.endswith(".tenseal"): + # Production mode: raw binary format + self.logger.info("Using raw binary HE context (production mode)") + return data + else: + # Simulation mode: base64-encoded format (.txt files) + self.logger.info("Using base64-encoded HE context (simulation mode)") + return base64.b64decode(data) def start_fl_collect_max_idx(self): self.logger.info("send initial message to all sites to start FL \n") @@ -51,7 +73,7 @@ def start_fl_collect_max_idx(self): return results def aggr_max_idx(self, sag_result: dict[str, dict[str, FLModel]]): - self.logger.info("aggregate max histogram index \n") + self.logger.info("aggregate max histogram index (cleartext) \n") if not sag_result: raise RuntimeError("input is None or empty") @@ -64,7 +86,7 @@ def aggr_max_idx(self, sag_result: dict[str, dict[str, 
FLModel]]): return max(max_idx_global) + 1 def distribute_max_idx_collect_enc_stats(self, result: int): - self.logger.info("send global max_index to all sites \n") + self.logger.info("send global max_index (cleartext) to all sites \n") model = FLModel( params={"max_idx_global": result}, @@ -78,7 +100,7 @@ def distribute_max_idx_collect_enc_stats(self, result: int): return results def aggr_he_hist(self, sag_result: dict[str, dict[str, FLModel]]): - self.logger.info("aggregate histogram within HE \n") + self.logger.info("aggregate histogram (ciphertext) within HE \n") # Load HE context he_context_serial = self.read_data(self.he_context_path) @@ -92,22 +114,22 @@ def aggr_he_hist(self, sag_result: dict[str, dict[str, FLModel]]): for fl_model in sag_result: site = fl_model.meta.get("client_name", None) hist_obs_he_serial = fl_model.params["hist_obs"] - hist_obs_he = ts.bfv_vector_from(he_context, hist_obs_he_serial) + hist_obs_he = ts.ckks_vector_from(he_context, hist_obs_he_serial) hist_cen_he_serial = fl_model.params["hist_cen"] - hist_cen_he = ts.bfv_vector_from(he_context, hist_cen_he_serial) + hist_cen_he = ts.ckks_vector_from(he_context, hist_cen_he_serial) if not hist_obs_global: - print(f"assign global hist with result from {site}") + self.logger.info(f"assign global hist (ciphertext) with result from {site}") hist_obs_global = hist_obs_he else: - print(f"add to global hist with result from {site}") + self.logger.info(f"add ciphertext to global hist with result from {site}") hist_obs_global += hist_obs_he if not hist_cen_global: - print(f"assign global hist with result from {site}") + self.logger.info(f"assign global censored hist (ciphertext) with result from {site}") hist_cen_global = hist_cen_he else: - print(f"add to global hist with result from {site}") + self.logger.info(f"add ciphertext to global censored hist with result from {site}") hist_cen_global += hist_cen_he # return the two accumulated vectors, serialized for transmission @@ -116,7 +138,7 @@ def 
aggr_he_hist(self, sag_result: dict[str, dict[str, FLModel]]): return hist_obs_global_serial, hist_cen_global_serial def distribute_global_hist(self, hist_obs_global_serial, hist_cen_global_serial): - self.logger.info("send global accumulated histograms within HE to all sites \n") + self.logger.info("send global accumulated histograms (ciphertext) to all sites \n") model = FLModel( params={"hist_obs_global": hist_obs_global_serial, "hist_cen_global": hist_cen_global_serial}, diff --git a/examples/advanced/kaplan-meier-he/start_all.sh b/examples/advanced/kaplan-meier-he/start_all.sh new file mode 100755 index 0000000000..578e579a9a --- /dev/null +++ b/examples/advanced/kaplan-meier-he/start_all.sh @@ -0,0 +1,72 @@ +#!/bin/bash +# Convenience script to start all parties for local production environment testing +# This starts the server and all 5 clients in the background + +set -e + +WORKSPACE_ROOT="/tmp/nvflare/prod_workspaces/km_he_project/prod_00" + +echo "======================================" +echo "Starting NVFlare Production Environment" +echo "======================================" +echo "" + +# Check if workspace exists +if [ ! -d "$WORKSPACE_ROOT" ]; then + echo "Error: Workspace not found at $WORKSPACE_ROOT" + echo "Please run provisioning first:" + echo " nvflare provision -p project.yml -w /tmp/nvflare/prod_workspaces" + exit 1 +fi + +# Function to start a component +start_component() { + local name=$1 + local path=$2 + local log_file="/tmp/nvflare/logs/${name}.log" + + mkdir -p /tmp/nvflare/logs + + echo "Starting ${name}..." + cd "${path}" + ./startup/start.sh > "${log_file}" 2>&1 & + local pid=$! + echo " PID: ${pid}" + echo " Log: ${log_file}" + echo "${pid}" > "/tmp/nvflare/logs/${name}.pid" +} + +# Start server +echo "" +echo "1. Starting Server..." +start_component "localhost" "${WORKSPACE_ROOT}/localhost" +echo " Waiting for server to be ready..." +sleep 10 + +# Start clients +echo "" +echo "2. Starting Clients..." 
+for i in {1..5}; do + start_component "site-${i}" "${WORKSPACE_ROOT}/site-${i}" + sleep 2 +done + +echo "" +echo "======================================" +echo "All parties started successfully!" +echo "======================================" +echo "" +echo "Server and clients are running in the background." +echo "Logs are available in /tmp/nvflare/logs/" +echo "" +echo "To check status:" +echo " tail -f /tmp/nvflare/logs/localhost.log" +echo " tail -f /tmp/nvflare/logs/site-1.log" +echo "" +echo "To start admin console:" +echo " cd ${WORKSPACE_ROOT}/admin@nvidia.com" +echo " ./startup/fl_admin.sh" +echo "" +echo "To stop all parties, use Admin Console" +echo "" + diff --git a/examples/advanced/kaplan-meier-he/utils/prepare_he_context.py b/examples/advanced/kaplan-meier-he/utils/prepare_he_context.py index ceedf4c9a4..d9104ad77a 100644 --- a/examples/advanced/kaplan-meier-he/utils/prepare_he_context.py +++ b/examples/advanced/kaplan-meier-he/utils/prepare_he_context.py @@ -21,8 +21,8 @@ def data_split_args_parser(): parser = argparse.ArgumentParser(description="Generate HE context") - parser.add_argument("--scheme", type=str, default="BFV", help="HE scheme, default is BFV") - parser.add_argument("--poly_modulus_degree", type=int, default=4096, help="Poly modulus degree, default is 4096") + parser.add_argument("--scheme", type=str, default="CKKS", help="HE scheme, default is CKKS") + parser.add_argument("--poly_modulus_degree", type=int, default=8192, help="Poly modulus degree, default is 8192") parser.add_argument("--out_path", type=str, help="Output root path for HE context files for client and server") return parser @@ -42,8 +42,14 @@ def main(): context = ts.context(scheme, poly_modulus_degree=args.poly_modulus_degree, plain_modulus=1032193) elif args.scheme == "CKKS": scheme = ts.SCHEME_TYPE.CKKS - # Generate HE context, CKKS does not need plain_modulus - context = ts.context(scheme, poly_modulus_degree=args.poly_modulus_degree) + # Generate HE context for 
CKKS + coeff_mod_bit_sizes = [60, 40, 40] + context = ts.context( + scheme, poly_modulus_degree=args.poly_modulus_degree, coeff_mod_bit_sizes=coeff_mod_bit_sizes + ) + # CKKS requires global scale to be set + context.generate_relin_keys() + context.global_scale = 2**40 else: raise ValueError("HE scheme not supported")
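The `read_data` change in `server_he.py` above introduces a dual-format convention for the serialized HE context: `.tenseal` files provisioned by `HEBuilder` in production mode are raw binary, while files written by `utils/prepare_he_context.py` for simulation mode are base64-encoded text. A minimal stand-alone sketch of that branching logic, under the assumption that the file extension is the only discriminator (the helper name `load_he_context_bytes` is hypothetical, not part of the example code):

```python
import base64


def load_he_context_bytes(file_name: str, data: bytes) -> bytes:
    """Return raw serialized HE context bytes, mirroring the
    extension-based convention in server_he.py:
    - .tenseal (production, HEBuilder): already raw binary
    - anything else (simulation, prepare_he_context.py): base64 text
    """
    if file_name.endswith(".tenseal"):
        return data  # production mode: raw binary, use as-is
    return base64.b64decode(data)  # simulation mode: decode base64


# Both paths should yield identical bytes for the same underlying context
raw = b"\x01serialized-he-context"
assert load_he_context_bytes("server_context.tenseal", raw) == raw
assert load_he_context_bytes("he_context_server.txt", base64.b64encode(raw)) == raw
```

Either way, the resulting bytes are what gets passed to `ts.context_from(...)` on the server side; only the on-disk encoding differs between the two deployment modes.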