Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/main_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
post_hoc_plot: bool = True

# for each experiment key, write the modifications to the config file
gpu_ids = [2, 3, 5, 6]
exp_dict = {
gpu_ids : List[int] = [2, 3, 5, 6]
exp_dict : Dict[str, Dict[str, Any]] = {
"experiment_1": {
"algo_config": traditional_fl,
"sys_config": grpc_system_config,
Expand All @@ -46,7 +46,7 @@
}

# parse the arguments
parser = argparse.ArgumentParser(description="host address of the nodes")
parser: argparse.ArgumentParser = argparse.ArgumentParser(description="host address of the nodes")
parser.add_argument(
"-host",
nargs="?",
Expand All @@ -58,11 +58,11 @@

for exp_id, exp_config in exp_dict.items():
# update the algo config with config settings
base_algo_config = exp_config["algo_config"].copy()
base_algo_config: Dict[str, Any] = exp_config["algo_config"].copy()
base_algo_config.update(exp_config["algo"])

# update the sys config with config settings
base_sys_config = exp_config["sys_config"].copy()
base_sys_config: Dict[str, Any] = exp_config["sys_config"].copy()
base_sys_config.update(exp_config["sys"])

# set up the full config file by combining the algo and sys config
Expand All @@ -71,7 +71,7 @@
base_sys_config["algos"] = get_algo_configs(num_users=n, algo_configs=[base_algo_config], seed=seed)
base_sys_config["device_ids"] = get_device_ids(n, gpu_ids)

full_config = base_sys_config.copy()
full_config: Dict[str, Any] = base_sys_config.copy()
full_config["exp_id"] = exp_id

# write the config file as python file configs/temp_config.py
Expand All @@ -97,7 +97,7 @@
# run the post-hoc analysis
if post_hoc_plot:
full_config = process_config(full_config) # this populates the results path
logs_dir = full_config["results_path"] + '/logs/'
logs_dir:str = full_config["results_path"] + '/logs/'

# aggregate metrics across all users
aggregate_metrics_across_users(logs_dir)
Expand Down