Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 43 additions & 44 deletions src/algos/MetaL2C.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""
This module defines the MetaL2CClient and MetaL2CServer classes for a Meta-learning
framework with collaborative weights in federated learning. It also includes the
This module defines the MetaL2CClient and MetaL2CServer classes for a Meta-learning
framework with collaborative weights in federated learning. It also includes the
ModelEncoder class for encoding model weights.
"""

from typing import Any, Dict
from typing import Any, Dict, List
from utils.communication.comm_utils import CommunicationManager
import math
import torch
Expand All @@ -21,24 +21,24 @@

class ModelEncoder(nn.Module):
"""
A neural network-based model encoder that encodes the weights of ResNet18
convolutional and batch normalization layers into a vector, while ignoring
A neural network-based model encoder that encodes the weights of ResNet18
convolutional and batch normalization layers into a vector, while ignoring
the final linear layer.

Args:
model_dict (dict): Dictionary of model weights to be encoded.
"""
def __init__(self, model_dict):
def __init__(self, model_dict: Dict[str, torch.Tensor]) -> None:
super(ModelEncoder, self).__init__()
self.init_encoder_weights(model_dict)

def weight_key_converter(self, key: str) -> str:
"""
Converts model weight keys by replacing periods with "_dot_".

Args:
key (str): The original weight key.

Returns:
str: The converted weight key.
"""
Expand All @@ -47,7 +47,7 @@ def weight_key_converter(self, key: str) -> str:
def init_encoder_weights(self, model_dict: Dict[str, torch.Tensor]) -> None:
"""
Initializes the encoder's weights for convolutional and batch normalization layers.

Args:
model_dict (Dict[str, torch.Tensor]): Dictionary containing the model's layers and weights.
"""
Expand Down Expand Up @@ -83,14 +83,14 @@ def init_encoder_weights(self, model_dict: Dict[str, torch.Tensor]) -> None:
def forward(self, model_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
Forward pass of the encoder, converting model weights into a vector.

Args:
model_dict (Dict[str, torch.Tensor]): Dictionary containing model weights.

Returns:
torch.Tensor: A concatenated tensor representation of the model's weights.
"""
encoder_outs = []
encoder_outs: list[torch.Tensor] = []
for key in self.ordered_keys:
wts = model_dict[key]
if "conv" in key:
Expand All @@ -109,12 +109,12 @@ def forward(self, model_dict: Dict[str, torch.Tensor]) -> torch.Tensor:

class MetaL2CClient(BaseFedAvgClient):
"""
A federated learning client that uses model weight encoding and knowledge
A federated learning client that uses model weight encoding and knowledge
sharing with collaborative learning based on Meta-L2C.

Args:
config (Dict[str, Any]): Configuration parameters for the client.
comm_utils (CommunicationManager): A communication manager for sending
comm_utils (CommunicationManager): A communication manager for sending
and receiving data.
"""
def __init__(
Expand All @@ -127,20 +127,20 @@ def __init__(
self.encoder.parameters(), lr=self.config["alpha_lr"]
)

self.model_keys_to_ignore = []
self.model_keys_to_ignore: list[str] = []
if not self.config.get("average_last_layer", True): # By default include last layer
keys = self.model_utils.get_last_layer_keys(self.get_model_weights())
keys: List[str] = self.model_utils.get_last_layer_keys(self.get_model_weights())
self.model_keys_to_ignore.extend(keys)

self.sharing_mode = self.config["sharing"]
self.neighbors_ids = list(range(1, self.config["num_users"] + 1))
self.sharing_mode: str = self.config["sharing"]
self.neighbors_ids: list[int] = list(range(1, self.config["num_users"] + 1))

def get_representation(self) -> Dict[str, torch.Tensor]:
"""
Retrieves the representation of the client's current model weights.

Returns:
Dict[str, torch.Tensor]: A dictionary containing the model's weight
Dict[str, torch.Tensor]: A dictionary containing the model's weight
representations.
"""
return self.model_utils.substract_model_weights(
Expand All @@ -150,9 +150,9 @@ def get_representation(self) -> Dict[str, torch.Tensor]:
def get_knowledge_sharing_artifact(self) -> Dict[str, torch.Tensor]:
"""
Retrieves the knowledge sharing artifact based on the sharing mode.

Returns:
Dict[str, torch.Tensor]: Knowledge sharing artifact (either model
Dict[str, torch.Tensor]: Knowledge sharing artifact (either model
updates or model weights).
"""
if self.sharing_mode == "updates":
Expand All @@ -166,13 +166,13 @@ def get_knowledge_sharing_artifact(self) -> Dict[str, torch.Tensor]:

def get_collaborator_weights(self, reprs_dict: Dict[int, torch.Tensor]) -> Dict[int, torch.Tensor]:
"""
Computes the collaboration weights based on the representations and
Computes the collaboration weights based on the representations and
the encoder.

Args:
reprs_dict (Dict[int, torch.Tensor]): A dictionary containing representations
reprs_dict (Dict[int, torch.Tensor]): A dictionary containing representations
from neighboring clients.

Returns:
Dict[int, torch.Tensor]: Collaboration weights for each client.
"""
Expand All @@ -189,19 +189,19 @@ def get_collaborator_weights(self, reprs_dict: Dict[int, torch.Tensor]) -> Dict[
return collab_weights_dict

def learn_collab_weights(self, models_update_wts: Dict[int, Dict[str, torch.Tensor]],
collab_weights_tensor_dict: Dict[int, torch.Tensor]) -> tuple:
collab_weights_tensor_dict: Dict[int, torch.Tensor]) -> tuple[float, float]:
"""
Learns the collaboration weights and updates them during validation.

Args:
models_update_wts (Dict[int, Dict[str, torch.Tensor]]): Weights updates from models.
collab_weights_tensor_dict (Dict[int, torch.Tensor]): Collaboration weights.

Returns:
tuple: Validation loss and accuracy after learning the collaborative weights.
"""
self.model.eval()
val_loss, correct = 0, 0
val_loss, correct = 0.0, 0

for data, target in self.val_dloader:
data, target = data.to(self.device), target.to(self.device)
Expand Down Expand Up @@ -242,12 +242,12 @@ def learn_collab_weights(self, models_update_wts: Dict[int, Dict[str, torch.Tens

def filter_out_worse_neighbors(self, num_neighbors_to_keep: int, collab_weights_dict: Dict[int, torch.Tensor]) -> None:
"""
Filters out neighbors with lower collaboration weights, keeping only
Filters out neighbors with lower collaboration weights, keeping only
the top k neighbors.

Args:
num_neighbors_to_keep (int): Number of neighbors to keep.
collab_weights_dict (Dict[int, torch.Tensor]): Dictionary of
collab_weights_dict (Dict[int, torch.Tensor]): Dictionary of
collaboration weights.
"""
if num_neighbors_to_keep <= 0:
Expand All @@ -258,7 +258,7 @@ def filter_out_worse_neighbors(self, num_neighbors_to_keep: int, collab_weights_

def run_protocol(self) -> None:
"""
Runs the federated learning protocol for this client, which includes local
Runs the federated learning protocol for this client, which includes local
training, knowledge sharing, and weight aggregation.
"""
self.model_init = self.get_model_weights()
Expand Down Expand Up @@ -320,12 +320,12 @@ def run_protocol(self) -> None:

class MetaL2CServer(BaseFedAvgServer):
"""
A federated learning server that coordinates training across clients, averages
A federated learning server that coordinates training across clients, averages
model updates, and aggregates collaboration weights.

Args:
config (Dict[str, Any]): Configuration parameters for the server.
comm_utils (CommunicationManager): Communication manager for coordinating
comm_utils (CommunicationManager): Communication manager for coordinating
the protocol with clients.
"""
def __init__(
Expand All @@ -335,7 +335,7 @@ def __init__(
self.config = config
self.set_model_parameters(config)
self.model_save_path = f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt"

def average_state_dicts(self, state_dicts: list) -> Dict[str, torch.Tensor]:
"""
Averages a list of model state dictionaries.
Expand All @@ -357,11 +357,11 @@ def average_state_dicts(self, state_dicts: list) -> Dict[str, torch.Tensor]:

def single_round(self, avg_alpha: Dict[str, torch.Tensor]) -> tuple:
"""
Runs a single round of federated learning by coordinating client training,
Runs a single round of federated learning by coordinating client training,
gathering model representations, and aggregating collaboration weights.

Args:
avg_alpha (Dict[str, torch.Tensor]): The averaged model weights to send
avg_alpha (Dict[str, torch.Tensor]): The averaged model weights to send
to clients for the next round.

Returns:
Expand Down Expand Up @@ -390,7 +390,7 @@ def single_round(self, avg_alpha: Dict[str, torch.Tensor]) -> tuple:

def run_protocol(self) -> None:
"""
Runs the federated learning protocol for the server, coordinating training
Runs the federated learning protocol for the server, coordinating training
across multiple rounds and clients.
"""
self.log_utils.log_console("Starting Meta L2C")
Expand All @@ -402,12 +402,11 @@ def run_protocol(self) -> None:
for cur_round in range(start_round, total_rounds):
self.round = cur_round
self.log_utils.log_console(f"Starting round {cur_round}")

round_stats, avg_alpha = self.single_round(avg_alpha)
stats.append(round_stats)

stats_dict = from_round_stats_per_round_per_client_to_dict_arrays(stats)
stats_dict["round_step"] = 1
self.log_utils.log_experiments_stats(stats_dict)
self.plot_utils.plot_experiments_stats(stats_dict)

20 changes: 10 additions & 10 deletions src/algos/topologies/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,31 @@


class RingTopology(BaseTopology):
def __init__(self, config: ConfigType, rank: int):
def __init__(self, config: ConfigType, rank: int) -> None:
super().__init__(config, rank)

def generate_graph(self) -> None:
self.graph = nx.cycle_graph(self.num_users) # type: ignore


class StarTopology(BaseTopology):
def __init__(self, config: ConfigType, rank: int):
def __init__(self, config: ConfigType, rank: int) -> None:
super().__init__(config, rank)

def generate_graph(self) -> None:
self.graph = nx.star_graph(self.num_users - 1) # type: ignore


class FullyConnectedTopology(BaseTopology):
def __init__(self, config: ConfigType, rank: int):
def __init__(self, config: ConfigType, rank: int) -> None:
super().__init__(config, rank)

def generate_graph(self) -> None:
self.graph = nx.complete_graph(self.num_users) # type: ignore


class GridTopology(BaseTopology):
def __init__(self, config: ConfigType, rank: int):
def __init__(self, config: ConfigType, rank: int) -> None:
super().__init__(config, rank)
if self.num_users**0.5 != int(self.num_users**0.5):
raise ValueError("Number of users should be a perfect square for grid topology")
Expand All @@ -40,7 +40,7 @@ def generate_graph(self) -> None:


class TorusTopology(BaseTopology):
def __init__(self, config: ConfigType, rank: int):
def __init__(self, config: ConfigType, rank: int) -> None:
super().__init__(config, rank)
if self.num_users**0.5 != int(self.num_users**0.5):
raise ValueError("Number of users should be a perfect square for grid topology")
Expand All @@ -49,14 +49,14 @@ def generate_graph(self) -> None:
self.graph = nx.grid_2d_graph(ceil(self.num_users**0.5), ceil(self.num_users**0.5), periodic=True) # type: ignore

class CircleLadderTopology(BaseTopology):
def __init__(self, config: ConfigType, rank: int):
def __init__(self, config: ConfigType, rank: int) -> None:
super().__init__(config, rank)

def generate_graph(self) -> None:
self.graph = nx.circular_ladder_graph(self.num_users) # type: ignore

class TreeTopology(BaseTopology):
def __init__(self, config: ConfigType, rank: int, children: int = 2):
def __init__(self, config: ConfigType, rank: int, children: int = 2) -> None:
super().__init__(config, rank)
self.children = children

Expand All @@ -67,7 +67,7 @@ def generate_graph(self) -> None:

######### Random Graphs #########
class ErdosRenyiTopology(BaseTopology):
def __init__(self, config: ConfigType, rank: int):
def __init__(self, config: ConfigType, rank: int) -> None:
p: float = config["topology"]["p"] # type: ignore
super().__init__(config, rank)
self.p = p
Expand All @@ -78,7 +78,7 @@ def generate_graph(self) -> None:


class WattsStrogatzTopology(BaseTopology):
def __init__(self, config: ConfigType, rank: int):
def __init__(self, config: ConfigType, rank: int) -> None:
k: int = config["topology"]["k"] # type: ignore
p: float = config["topology"]["p"] # type: ignore
super().__init__(config, rank)
Expand Down Expand Up @@ -113,4 +113,4 @@ def select_topology(config: ConfigType, rank: int) -> BaseTopology:
return WattsStrogatzTopology(config, rank)
if topology_name == "tree":
return TreeTopology(config, rank)
raise ValueError(f"Topology {topology_name} not implemented")
raise ValueError(f"Topology {topology_name} not implemented")
3 changes: 3 additions & 0 deletions src/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
num_users: 5
host: localhost
port: 50048
1 change: 1 addition & 0 deletions src/configs/algo_config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

from typing import Dict, List
from .malicious_config import malicious_config_list
import random
Expand Down
1 change: 1 addition & 0 deletions src/configs/sys_config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

# System Configuration
# TODO: Set up multiple non-iid configurations here. The goal of a separate system config
# is to simulate different real-world scenarios without changing the algorithm configuration.
Expand Down
Loading