diff --git a/src/algos/MetaL2C.py b/src/algos/MetaL2C.py index beb267de..01064b8b 100644 --- a/src/algos/MetaL2C.py +++ b/src/algos/MetaL2C.py @@ -1,10 +1,10 @@ """ -This module defines the MetaL2CClient and MetaL2CServer classes for a Meta-learning -framework with collaborative weights in federated learning. It also includes the +This module defines the MetaL2CClient and MetaL2CServer classes for a Meta-learning +framework with collaborative weights in federated learning. It also includes the ModelEncoder class for encoding model weights. """ -from typing import Any, Dict +from typing import Any, Dict, List from utils.communication.comm_utils import CommunicationManager import math import torch @@ -21,24 +21,24 @@ class ModelEncoder(nn.Module): """ - A neural network-based model encoder that encodes the weights of ResNet18 - convolutional and batch normalization layers into a vector, while ignoring + A neural network-based model encoder that encodes the weights of ResNet18 + convolutional and batch normalization layers into a vector, while ignoring the final linear layer. - + Args: model_dict (dict): Dictionary of model weights to be encoded. """ - def __init__(self, model_dict): + def __init__(self, model_dict: Dict[str, torch.Tensor]) -> None: super(ModelEncoder, self).__init__() self.init_encoder_weights(model_dict) def weight_key_converter(self, key: str) -> str: """ Converts model weight keys by replacing periods with "_dot_". - + Args: key (str): The original weight key. - + Returns: str: The converted weight key. """ @@ -47,7 +47,7 @@ def weight_key_converter(self, key: str) -> str: def init_encoder_weights(self, model_dict: Dict[str, torch.Tensor]) -> None: """ Initializes the encoder's weights for convolutional and batch normalization layers. - + Args: model_dict (Dict[str, torch.Tensor]): Dictionary containing the model's layers and weights. """ @@ -83,14 +83,14 @@ def init_encoder_weights(self, model_dict: Dict[str, torch.Tensor]) -> None: def forward(self, model_dict: Dict[str, torch.Tensor]) -> torch.Tensor: """ Forward pass of the encoder, converting model weights into a vector. - + Args: model_dict (Dict[str, torch.Tensor]): Dictionary containing model weights. - + Returns: torch.Tensor: A concatenated tensor representation of the model's weights. """ - encoder_outs = [] + encoder_outs: list[torch.Tensor] = [] for key in self.ordered_keys: wts = model_dict[key] if "conv" in key: @@ -109,12 +109,12 @@ def forward(self, model_dict: Dict[str, torch.Tensor]) -> torch.Tensor: class MetaL2CClient(BaseFedAvgClient): """ - A federated learning client that uses model weight encoding and knowledge + A federated learning client that uses model weight encoding and knowledge sharing with collaborative learning based on Meta-L2C. - + Args: config (Dict[str, Any]): Configuration parameters for the client. - comm_utils (CommunicationManager): A communication manager for sending + comm_utils (CommunicationManager): A communication manager for sending and receiving data. """ def __init__( @@ -127,20 +127,20 @@ def __init__( self.encoder.parameters(), lr=self.config["alpha_lr"] ) - self.model_keys_to_ignore = [] + self.model_keys_to_ignore: list[str] = [] if not self.config.get("average_last_layer", True): # By default include last layer - keys = self.model_utils.get_last_layer_keys(self.get_model_weights()) + keys: List[str] = self.model_utils.get_last_layer_keys(self.get_model_weights()) self.model_keys_to_ignore.extend(keys) - self.sharing_mode = self.config["sharing"] - self.neighbors_ids = list(range(1, self.config["num_users"] + 1)) + self.sharing_mode: str = self.config["sharing"] + self.neighbors_ids: list[int] = list(range(1, self.config["num_users"] + 1)) def get_representation(self) -> Dict[str, torch.Tensor]: """ Retrieves the representation of the client's current model weights. - + Returns: - Dict[str, torch.Tensor]: A dictionary containing the model's weight + Dict[str, torch.Tensor]: A dictionary containing the model's weight representations. """ return self.model_utils.substract_model_weights( @@ -150,9 +150,9 @@ def get_representation(self) -> Dict[str, torch.Tensor]: def get_knowledge_sharing_artifact(self) -> Dict[str, torch.Tensor]: """ Retrieves the knowledge sharing artifact based on the sharing mode. - + Returns: - Dict[str, torch.Tensor]: Knowledge sharing artifact (either model + Dict[str, torch.Tensor]: Knowledge sharing artifact (either model updates or model weights). """ if self.sharing_mode == "updates": @@ -166,13 +166,13 @@ def get_knowledge_sharing_artifact(self) -> Dict[str, torch.Tensor]: def get_collaborator_weights(self, reprs_dict: Dict[int, torch.Tensor]) -> Dict[int, torch.Tensor]: """ - Computes the collaboration weights based on the representations and + Computes the collaboration weights based on the representations and the encoder. - + Args: - reprs_dict (Dict[int, torch.Tensor]): A dictionary containing representations + reprs_dict (Dict[int, torch.Tensor]): A dictionary containing representations from neighboring clients. - + Returns: Dict[int, torch.Tensor]: Collaboration weights for each client. """ @@ -189,19 +189,19 @@ def get_collaborator_weights(self, reprs_dict: Dict[int, torch.Tensor]) -> Dict[ return collab_weights_dict def learn_collab_weights(self, models_update_wts: Dict[int, Dict[str, torch.Tensor]], - collab_weights_tensor_dict: Dict[int, torch.Tensor]) -> tuple: + collab_weights_tensor_dict: Dict[int, torch.Tensor]) -> tuple[float, float]: """ Learns the collaboration weights and updates them during validation. - + Args: models_update_wts (Dict[int, Dict[str, torch.Tensor]]): Weights updates from models. collab_weights_tensor_dict (Dict[int, torch.Tensor]): Collaboration weights. - + Returns: tuple: Validation loss and accuracy after learning the collaborative weights. """ self.model.eval() - val_loss, correct = 0, 0 + val_loss, correct = 0.0, 0 for data, target in self.val_dloader: data, target = data.to(self.device), target.to(self.device) @@ -242,12 +242,12 @@ def learn_collab_weights(self, models_update_wts: Dict[int, Dict[str, torch.Tens def filter_out_worse_neighbors(self, num_neighbors_to_keep: int, collab_weights_dict: Dict[int, torch.Tensor]) -> None: """ - Filters out neighbors with lower collaboration weights, keeping only + Filters out neighbors with lower collaboration weights, keeping only the top k neighbors. - + Args: num_neighbors_to_keep (int): Number of neighbors to keep. - collab_weights_dict (Dict[int, torch.Tensor]): Dictionary of + collab_weights_dict (Dict[int, torch.Tensor]): Dictionary of collaboration weights. """ if num_neighbors_to_keep <= 0: @@ -258,7 +258,7 @@ def filter_out_worse_neighbors(self, num_neighbors_to_keep: int, collab_weights_ def run_protocol(self) -> None: """ - Runs the federated learning protocol for this client, which includes local + Runs the federated learning protocol for this client, which includes local training, knowledge sharing, and weight aggregation. """ self.model_init = self.get_model_weights() @@ -320,12 +320,12 @@ def run_protocol(self) -> None: class MetaL2CServer(BaseFedAvgServer): """ - A federated learning server that coordinates training across clients, averages + A federated learning server that coordinates training across clients, averages model updates, and aggregates collaboration weights. Args: config (Dict[str, Any]): Configuration parameters for the server. - comm_utils (CommunicationManager): Communication manager for coordinating + comm_utils (CommunicationManager): Communication manager for coordinating the protocol with clients. """ def __init__( @@ -335,7 +335,7 @@ def __init__( self.config = config self.set_model_parameters(config) self.model_save_path = f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt" - + def average_state_dicts(self, state_dicts: list) -> Dict[str, torch.Tensor]: """ Averages a list of model state dictionaries. @@ -357,11 +357,11 @@ def average_state_dicts(self, state_dicts: list) -> Dict[str, torch.Tensor]: def single_round(self, avg_alpha: Dict[str, torch.Tensor]) -> tuple: """ - Runs a single round of federated learning by coordinating client training, + Runs a single round of federated learning by coordinating client training, gathering model representations, and aggregating collaboration weights. Args: - avg_alpha (Dict[str, torch.Tensor]): The averaged model weights to send + avg_alpha (Dict[str, torch.Tensor]): The averaged model weights to send to clients for the next round. Returns: @@ -390,7 +390,7 @@ def single_round(self, avg_alpha: Dict[str, torch.Tensor]) -> tuple: def run_protocol(self) -> None: """ - Runs the federated learning protocol for the server, coordinating training + Runs the federated learning protocol for the server, coordinating training across multiple rounds and clients. """ self.log_utils.log_console("Starting Meta L2C") @@ -402,7 +402,7 @@ def run_protocol(self) -> None: for cur_round in range(start_round, total_rounds): self.round = cur_round self.log_utils.log_console(f"Starting round {cur_round}") - + round_stats, avg_alpha = self.single_round(avg_alpha) stats.append(round_stats) @@ -410,4 +410,3 @@ def run_protocol(self) -> None: stats_dict["round_step"] = 1 self.log_utils.log_experiments_stats(stats_dict) self.plot_utils.plot_experiments_stats(stats_dict) - diff --git a/src/algos/topologies/collections.py b/src/algos/topologies/collections.py index 632f42b4..358394b4 100644 --- a/src/algos/topologies/collections.py +++ b/src/algos/topologies/collections.py @@ -6,7 +6,7 @@ class RingTopology(BaseTopology): - def __init__(self, config: ConfigType, rank: int): + def __init__(self, config: ConfigType, rank: int) -> None: super().__init__(config, rank) def generate_graph(self) -> None: @@ -14,7 +14,7 @@ def generate_graph(self) -> None: class StarTopology(BaseTopology): - def __init__(self, config: ConfigType, rank: int): + def __init__(self, config: ConfigType, rank: int) -> None: super().__init__(config, rank) def generate_graph(self) -> None: @@ -22,7 +22,7 @@ def generate_graph(self) -> None: class FullyConnectedTopology(BaseTopology): - def __init__(self, config: ConfigType, rank: int): + def __init__(self, config: ConfigType, rank: int) -> None: super().__init__(config, rank) def generate_graph(self) -> None: @@ -30,7 +30,7 @@ def generate_graph(self) -> None: class GridTopology(BaseTopology): - def __init__(self, config: ConfigType, rank: int): + def __init__(self, config: ConfigType, rank: int) -> None: super().__init__(config, rank) if self.num_users**0.5 != int(self.num_users**0.5): raise ValueError("Number of users should be a perfect square for grid topology") @@ -40,7 +40,7 @@ def generate_graph(self) -> None: class TorusTopology(BaseTopology): - def __init__(self, config: ConfigType, rank: int): + def __init__(self, config: ConfigType, rank: int) -> None: super().__init__(config, rank) if self.num_users**0.5 != int(self.num_users**0.5): raise ValueError("Number of users should be a perfect square for grid topology") @@ -49,14 +49,14 @@ def generate_graph(self) -> None: self.graph = nx.grid_2d_graph(ceil(self.num_users**0.5), ceil(self.num_users**0.5), periodic=True) # type: ignore class CircleLadderTopology(BaseTopology): - def __init__(self, config: ConfigType, rank: int): + def __init__(self, config: ConfigType, rank: int) -> None: super().__init__(config, rank) def generate_graph(self) -> None: self.graph = nx.circular_ladder_graph(self.num_users) # type: ignore class TreeTopology(BaseTopology): - def __init__(self, config: ConfigType, rank: int, children: int = 2): + def __init__(self, config: ConfigType, rank: int, children: int = 2) -> None: super().__init__(config, rank) self.children = children @@ -67,7 +67,7 @@ def generate_graph(self) -> None: ######### Random Graphs ######### class ErdosRenyiTopology(BaseTopology): - def __init__(self, config: ConfigType, rank: int): + def __init__(self, config: ConfigType, rank: int) -> None: p: float = config["topology"]["p"] # type: ignore super().__init__(config, rank) self.p = p @@ -78,7 +78,7 @@ def generate_graph(self) -> None: class WattsStrogatzTopology(BaseTopology): - def __init__(self, config: ConfigType, rank: int): + def __init__(self, config: ConfigType, rank: int) -> None: k: int = config["topology"]["k"] # type: ignore p: float = config["topology"]["p"] # type: ignore super().__init__(config, rank) @@ -113,4 +113,4 @@ def select_topology(config: ConfigType, rank: int) -> BaseTopology: return WattsStrogatzTopology(config, rank) if topology_name == "tree": return TreeTopology(config, rank) - raise ValueError(f"Topology {topology_name} not implemented") \ No newline at end of file + raise ValueError(f"Topology {topology_name} not implemented") diff --git a/src/config.yaml b/src/config.yaml new file mode 100644 index 00000000..1020b2c4 --- /dev/null +++ b/src/config.yaml @@ -0,0 +1,3 @@ +num_users: 5 +host: localhost +port: 50048 diff --git a/src/configs/algo_config.py b/src/configs/algo_config.py index a2242c62..b017840e 100644 --- a/src/configs/algo_config.py +++ b/src/configs/algo_config.py @@ -1,3 +1,4 @@ + from typing import Dict, List from .malicious_config import malicious_config_list import random diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py index 738f4c93..43d63ae0 100644 --- a/src/configs/sys_config.py +++ b/src/configs/sys_config.py @@ -1,3 +1,4 @@ + # System Configuration # TODO: Set up multiple non-iid configurations here. The goal of a separate system config # is to simulate different real-world scenarios without changing the algorithm configuration. diff --git a/src/main_exp.py b/src/main_exp.py index a326024d..95fc9d60 100644 --- a/src/main_exp.py +++ b/src/main_exp.py @@ -6,7 +6,7 @@ import argparse import subprocess -from typing import List +from typing import List, Dict, Any from utils.types import ConfigType from utils.config_utils import process_config @@ -19,8 +19,8 @@ post_hoc_plot: bool = True # for each experiment key, write the modifications to the config file -gpu_ids = [2, 3, 5, 6] -exp_dict = { +gpu_ids: List[int] = [2, 3, 5, 6] +exp_dict: Dict[str, Dict[str, Any]] = { "experiment_1": { "algo_config": traditional_fl, "sys_config": grpc_system_config, @@ -54,15 +54,15 @@ help=f"host address of the nodes", ) -args = parser.parse_args() +args: argparse.Namespace = parser.parse_args() for exp_id, exp_config in exp_dict.items(): # update the algo config with config settings - base_algo_config = exp_config["algo_config"].copy() + base_algo_config: Dict[str, Any] = exp_config["algo_config"].copy() base_algo_config.update(exp_config["algo"]) # update the sys config with config settings - base_sys_config = exp_config["sys_config"].copy() + base_sys_config: Dict[str, Any] = exp_config["sys_config"].copy() base_sys_config.update(exp_config["sys"]) # set up the full config file by combining the algo and sys config @@ -71,7 +71,7 @@ base_sys_config["algos"] = get_algo_configs(num_users=n, algo_configs=[base_algo_config], seed=seed) base_sys_config["device_ids"] = get_device_ids(n, gpu_ids) - full_config = base_sys_config.copy() + full_config: Dict[str, Any] = base_sys_config.copy() full_config["exp_id"] = exp_id # write the config file as python file configs/temp_config.py @@ -97,7 +97,7 @@ # run the post-hoc analysis if post_hoc_plot: full_config = process_config(full_config) # this populates the results path - logs_dir = full_config["results_path"] + '/logs/' + logs_dir: str = full_config["results_path"] + '/logs/' # aggregate metrics across all users aggregate_metrics_across_users(logs_dir) @@ -105,4 +105,4 @@ plot_all_metrics(logs_dir) # Continue with the next set of commands after supernode finishes - print(f"Supernode process {exp_id} finished. Proceeding to next set of commands.") \ No newline at end of file + print(f"Supernode process {exp_id} finished. Proceeding to next set of commands.")