diff --git a/src/utils/community_utils.py b/src/utils/community_utils.py
index 2c6f8873..7398f605 100644
--- a/src/utils/community_utils.py
+++ b/src/utils/community_utils.py
@@ -8,7 +8,7 @@ def get_random_communities(
     assert num_clients % num_communities == 0
     clients_per_community = num_clients // num_communities
 
-    indices = np.random.permutation(range(1, num_clients + 1))
+    indices: np.ndarray = np.random.permutation(range(1, num_clients + 1))
     support: Dict[int, List[int]] = {}
     for i, c_id in enumerate(indices):
@@ -40,7 +40,7 @@ def get_dset_balanced_communities(
     num_assigned = 0
     communities: Dict[int, List[int]] = {k: [] for k in range(num_communities)}
     while num_assigned < num_clients:
-        communities_random_order = {
+        communities_random_order: Dict[int, int] = {
             k: v for k, v in enumerate(np.random.permutation(range(num_communities)))
         }
         communities_id = sorted(
diff --git a/src/utils/config_utils.py b/src/utils/config_utils.py
index d18107b6..211cda3f 100644
--- a/src/utils/config_utils.py
+++ b/src/utils/config_utils.py
@@ -43,8 +43,8 @@ def process_config(config: Dict[str, Any]) -> Dict[str, Any]:
             # experiment_name += "_{}_{}".format(key, item)
             experiment_name += "_{}".format(item)
 
-    experiments_folder = config["dump_dir"]
-    results_path = experiments_folder + experiment_name + f"_seed{config['seed']}"
+    experiments_folder: str = config["dump_dir"]
+    results_path: str = experiments_folder + experiment_name + f"_seed{config['seed']}"
 
     log_path = results_path + "/logs/"
     images_path = results_path + "/images/"
@@ -79,7 +79,7 @@ def get_device_ids(num_users: int, num_client_per_gpu: int, available_gpus: list
     assert num_users <= len(available_gpus) * num_client_per_gpu
 
     device_ids: Dict[str, List[int]] = {}
-    gpu_id = 0
+    gpu_id: int = 0
     for i in range(1, num_users + 1):
         device_ids[f"node_{i}"] = [available_gpus[gpu_id]]
         gpu_id = (gpu_id + 1) % len(available_gpus)  # Alternate GPU assignment
diff --git a/src/utils/data_utils.py b/src/utils/data_utils.py
index 28c3040d..14070590 100644
--- a/src/utils/data_utils.py
+++ b/src/utils/data_utils.py
@@ -1,5 +1,5 @@
 import importlib
-from typing import Any, List, Sequence, Tuple, Optional
+from typing import Any, List, Sequence, Tuple, Optional, Dict
 import numpy as np
 import torch
 import torchvision.transforms as T
@@ -41,7 +41,7 @@ def __getitem__(self, index: int):
             img = self.transform(img)
         return img, label
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.dset)
 
 # Custom dataset wrapper to apply corruption
@@ -111,7 +111,7 @@ def get_dataset(dname: str, dpath: str):
     return dataset_class(dpath)
 
 
-def filter_by_class(dataset: Subset[Any], classes: List[str]):
+def filter_by_class(dataset: Subset[Any], classes: List[str]) -> Tuple[Subset[Any], np.ndarray]:
     """
     Filters the dataset by specified classes.
     """
@@ -129,7 +129,7 @@ def random_samples(
     return Subset(dataset, indices), indices
 
 
-def extr_noniid(train_dataset: Any, samples_per_user: int, classes: Sequence[int]):
+def extr_noniid(train_dataset: Any, samples_per_user: int, classes: Sequence[int]) -> Tuple[Subset[Any]]:
     """
     Extracts non-IID data from the training dataset.
     """
@@ -147,7 +147,7 @@ def cifar_extr_noniid(
     n_class: int,
     num_samples: int,
     rate_unbalance: float,
-):
+) -> Tuple[Dict[int, np.ndarray], Dict[int, np.ndarray]]:
     """
     Extracts non-IID data for CIFAR-10 dataset.
     """
diff --git a/src/utils/dropout_utils.py b/src/utils/dropout_utils.py
index fa70db1a..b64e723e 100644
--- a/src/utils/dropout_utils.py
+++ b/src/utils/dropout_utils.py
@@ -32,7 +32,7 @@ def __init__(self, node_id: int, dropout_dict: dict, rng: random.Random) -> None
         dropout_method_params = dropout_distribution_dict.get("parameters", {})
         dropout_method = dropout_distribution_dict.get("method", "uniform")
         self.rng = rng
-        self.dropped_recently = False
+        self.dropped_recently: bool = False
 
         if dropout_method == "uniform":
diff --git a/src/utils/model_utils.py b/src/utils/model_utils.py
index 97c639eb..0fc0f354 100644
--- a/src/utils/model_utils.py
+++ b/src/utils/model_utils.py
@@ -541,14 +541,14 @@ def substract_model_weights(self, model1, model2):
             res[key] = param - model2[key]
         return res
 
-    def get_last_layer_keys(self, model_wts: OrderedDict[str, Tensor]):
+    def get_last_layer_keys(self, model_wts: OrderedDict[str, Tensor]) -> List[str]:
         # Assume one layer is composed of multiple weights named as "layer_name.weight_name"
         reversed_model_wts = reversed(model_wts)
         last_key = next(reversed_model_wts)
         last_layer = last_key.split(".")[0]
 
-        last_layer_keys = []
+        last_layer_keys: List[str] = []
 
         while last_key is not None and last_key.startswith(last_layer + "."):
             last_layer_keys.append(last_key)
@@ -557,16 +557,16 @@ def get_last_layer_keys(self, model_wts: OrderedDict[str, Tensor]):
     def filter_model_weights(
         self, model_wts: OrderedDict[str, Tensor], key_to_ignore: List[str]
-    ):
+    ) -> OrderedDict[str, Tensor]:
         # Assume one layer is composed of multiple weights named as "layer_name.weight_name"
-        filtered_model_wts = OrderedDict()
+        filtered_model_wts: OrderedDict[str, Tensor] = OrderedDict()
         for key, param in model_wts.items():
             if key not in key_to_ignore:
                 filtered_model_wts[key] = param
         return filtered_model_wts
 
-    def get_memory_usage(self):
+    def get_memory_usage(self) -> int:
         """
         Get the memory usage
         """