From 43ba198134f328ca15276e0124201789550d996f Mon Sep 17 00:00:00 2001 From: tremblerz Date: Thu, 6 Mar 2025 20:41:33 +0000 Subject: [PATCH 1/2] little bit of typehinting --- src/algos/base_class.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/algos/base_class.py b/src/algos/base_class.py index 8bf9bc42..899fcd47 100644 --- a/src/algos/base_class.py +++ b/src/algos/base_class.py @@ -527,8 +527,8 @@ def set_data_parameters(self, config: ConfigType) -> None: self.classes_of_interest = classes self.train_indices = train_indices self.train_dset = train_dset - self.dloader = DataLoader(train_dset, batch_size=len(train_dset), shuffle=False) - self._test_loader = DataLoader(test_dset, batch_size=len(test_dset), shuffle=False) + self.dloader: DataLoader[Any] = DataLoader(train_dset, batch_size=len(train_dset), shuffle=False) + self._test_loader: DataLoader[Any] = DataLoader(test_dset, batch_size=len(test_dset), shuffle=False) print("Using GIA data setup") print(self.labels) else: From 27b993a2dfe5c9a4c63f54346b90ae25e182de19 Mon Sep 17 00:00:00 2001 From: tremblerz Date: Thu, 6 Mar 2025 20:42:57 +0000 Subject: [PATCH 2/2] dynamic algo and config bug fix and type hint in comm_utils --- src/algos/fl_dynamic.py | 23 ++++++++++------------- src/configs/sys_config.py | 13 ++++++------- src/utils/communication/comm_utils.py | 2 +- 3 files changed, 17 insertions(+), 21 deletions(-) diff --git a/src/algos/fl_dynamic.py b/src/algos/fl_dynamic.py index b96b38e0..8cb2c666 100644 --- a/src/algos/fl_dynamic.py +++ b/src/algos/fl_dynamic.py @@ -74,7 +74,7 @@ def get_neighbor_model_wts(self) -> List[Dict[str, TorchModelType]]: from all the neighbors because that's how most dynamic topologies work. """ - neighbor_models = self.comm_utils.all_gather(ignore_super_node=True) + neighbor_models: List[Dict[str, TorchModelType]] = self.comm_utils.all_gather(ignore_super_node=True) return neighbor_models def get_neighbor_similarity(self, others_wts: List[Dict[str, TorchModelType]]) -> List[float]: @@ -95,11 +95,12 @@ def get_neighbor_similarity(self, others_wts: List[Dict[str, TorchModelType]]) - raise ValueError("Similarity metric {} not implemented".format(self.similarity)) return similarity_wts - def sample_neighbours(self, k: int) -> List[int]: + def sample_neighbours(self, k: int, mode: str|None = None) -> List[int]: """ We perform neighbor sampling after we have the similarity weights of all the neighbors. """ + assert mode is None or mode == "pull", "Only pull mode is supported for dynamic topology" if self.sampling == "closest": return select_smallest_k(self.similarity_wts, k) else: @@ -154,7 +155,7 @@ def __init__( self.topology = DynamicTopology(config, comm_utils, self) self.topology.initialize() - def get_representation(self, **kwargs: Any) -> TorchModelType: + def get_representation(self, **kwargs: Any) -> Dict[str, int|Dict[str, Any]]: """ Returns the model weights as representation. """ @@ -172,6 +173,8 @@ def run_protocol(self) -> None: epochs_per_round = self.config.get("epochs_per_round", 1) for it in range(start_round, total_rounds): + self.round_init() + # Train locally and send the representation to the server stats["train_loss"], stats["train_acc"], stats["train_time"] = self.local_train( it, epochs_per_round @@ -179,17 +182,11 @@ def run_protocol(self) -> None: self.local_round_done() # Collect the representations from all other nodes from the server - neighbors = self.topology.recv_and_agg(self.num_collaborators) - # TODO: Log the neighbors - stats["neighbors"] = neighbors - - stats["bytes_received"], stats["bytes_sent"] = self.comm_utils.get_comm_cost() + collabs = self.topology.recv_and_agg(self.num_collaborators) - # evaluate the model on the test data - # Inside FedStaticNode.run_protocol() - stats["test_loss"], stats["test_acc"] = self.local_test() - stats.update(self.get_memory_metrics()) - self.log_metrics(stats=stats, iteration=it) + self.stats["neighbors"] = collabs + self.local_test() + self.round_finalize() diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py index cbcc255a..60cc7104 100644 --- a/src/configs/sys_config.py +++ b/src/configs/sys_config.py @@ -159,10 +159,10 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): CIFAR10_DSET = "cifar10" CIAR10_DPATH = "./datasets/imgs/cifar10/" -NUM_COLLABORATORS = 1 -DUMP_DIR = "/tmp/" +NUM_COLLABORATORS = 3 +DUMP_DIR = "/tmp/new_sonar/" -num_users = 3 +num_users = 9 mpi_system_config: ConfigType = { "exp_id": "", "comm": {"type": "MPI"}, @@ -318,8 +318,6 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): "exp_keys": [], } -num_users = 4 - dropout_dict: Any = { "distribution_dict": { # leave dict empty to disable dropout "method": "uniform", # "uniform", "normal" @@ -346,9 +344,10 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): "dpath": CIAR10_DPATH, "seed": 2, "device_ids": get_device_ids(num_users, gpu_ids), + "assign_based_on_host": True, # "algos": get_algo_configs(num_users=num_users, algo_configs=default_config_list), # type: ignore - "algos": get_algo_configs(num_users=num_users, algo_configs=[fed_dynamic_loss]), # type: ignore - "samples_per_user": 10000 // num_users, # distributed equally + "algos": get_algo_configs(num_users=num_users, algo_configs=[fed_dynamic_weights]), # type: ignore + "samples_per_user": 500, # distributed equally "train_label_distribution": "non_iid", "alpha_data": 0.1, "test_label_distribution": "iid", diff --git a/src/utils/communication/comm_utils.py b/src/utils/communication/comm_utils.py index 287c0026..5b2ed993 100644 --- a/src/utils/communication/comm_utils.py +++ b/src/utils/communication/comm_utils.py @@ -72,7 +72,7 @@ def receive(self, node_ids: List[int]) -> Any: def broadcast(self, data: Any, tag: int = 0): self.comm.broadcast(data) - def all_gather(self, tag: int = 0, ignore_super_node: bool = False): + def all_gather(self, tag: int = 0, ignore_super_node: bool = False) -> List[Dict[str, Any]]: return self.comm.all_gather(ignore_super_node=ignore_super_node) def send_quorum(self):