diff --git a/src/main_exp.py b/src/main_exp.py index a326024d..e4aaf504 100644 --- a/src/main_exp.py +++ b/src/main_exp.py @@ -19,8 +19,8 @@ post_hoc_plot: bool = True # for each experiment key, write the modifications to the config file -gpu_ids = [2, 3, 5, 6] -exp_dict = { +gpu_ids : List[int] = [2, 3, 5, 6] +exp_dict : Dict[str, Dict[str, Any]] = { "experiment_1": { "algo_config": traditional_fl, "sys_config": grpc_system_config, @@ -46,7 +46,7 @@ } # parse the arguments -parser = argparse.ArgumentParser(description="host address of the nodes") +parser: argparse.ArgumentParser = argparse.ArgumentParser(description="host address of the nodes") parser.add_argument( "-host", nargs="?", @@ -58,11 +58,11 @@ for exp_id, exp_config in exp_dict.items(): # update the algo config with config settings - base_algo_config = exp_config["algo_config"].copy() + base_algo_config: Dict[str, Any] = exp_config["algo_config"].copy() base_algo_config.update(exp_config["algo"]) # update the sys config with config settings - base_sys_config = exp_config["sys_config"].copy() + base_sys_config: Dict[str, Any] = exp_config["sys_config"].copy() base_sys_config.update(exp_config["sys"]) # set up the full config file by combining the algo and sys config @@ -71,7 +71,7 @@ base_sys_config["algos"] = get_algo_configs(num_users=n, algo_configs=[base_algo_config], seed=seed) base_sys_config["device_ids"] = get_device_ids(n, gpu_ids) - full_config = base_sys_config.copy() + full_config: Dict[str, Any] = base_sys_config.copy() full_config["exp_id"] = exp_id # write the config file as python file configs/temp_config.py @@ -97,7 +97,7 @@ # run the post-hoc analysis if post_hoc_plot: full_config = process_config(full_config) # this populates the results path - logs_dir = full_config["results_path"] + '/logs/' + logs_dir:str = full_config["results_path"] + '/logs/' # aggregate metrics across all users aggregate_metrics_across_users(logs_dir)