Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/utils/community_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def get_random_communities(
assert num_clients % num_communities == 0

clients_per_community = num_clients // num_communities
indices = np.random.permutation(range(1, num_clients + 1))
indices: np.ndarray = np.random.permutation(range(1, num_clients + 1))

support: Dict[int, List[int]] = {}
for i, c_id in enumerate(indices):
Expand Down Expand Up @@ -40,7 +40,7 @@ def get_dset_balanced_communities(
num_assigned = 0
communities: Dict[int, List[int]] = {k: [] for k in range(num_communities)}
while num_assigned < num_clients:
communities_random_order = {
communities_random_order: Dict[int, int] = {
k: v for k, v in enumerate(np.random.permutation(range(num_communities)))
}
communities_id = sorted(
Expand Down
6 changes: 3 additions & 3 deletions src/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def process_config(config: Dict[str, Any]) -> Dict[str, Any]:
# experiment_name += "_{}_{}".format(key, item)
experiment_name += "_{}".format(item)

experiments_folder = config["dump_dir"]
results_path = experiments_folder + experiment_name + f"_seed{config['seed']}"
experiments_folder: str = config["dump_dir"]
results_path: str = experiments_folder + experiment_name + f"_seed{config['seed']}"

log_path = results_path + "/logs/"
images_path = results_path + "/images/"
Expand Down Expand Up @@ -79,7 +79,7 @@ def get_device_ids(num_users: int, num_client_per_gpu: int, available_gpus: list
assert num_users <= len(available_gpus) * num_client_per_gpu
device_ids: Dict[str, List[int]] = {}

gpu_id = 0
gpu_id: int = 0
for i in range(1, num_users + 1):
device_ids[f"node_{i}"] = [available_gpus[gpu_id]]
gpu_id = (gpu_id + 1) % len(available_gpus) # Alternate GPU assignment
Expand Down
10 changes: 5 additions & 5 deletions src/utils/data_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import importlib
from typing import Any, List, Sequence, Tuple, Optional
from typing import Any, List, Sequence, Tuple, Optional, Dict
import numpy as np
import torch
import torchvision.transforms as T
Expand Down Expand Up @@ -41,7 +41,7 @@ def __getitem__(self, index: int):
img = self.transform(img)
return img, label

def __len__(self):
def __len__(self) -> int:
return len(self.dset)

# Custom dataset wrapper to apply corruption
Expand Down Expand Up @@ -111,7 +111,7 @@ def get_dataset(dname: str, dpath: str):
return dataset_class(dpath)


def filter_by_class(dataset: Subset[Any], classes: List[str]):
def filter_by_class(dataset: Subset[Any], classes: List[str]) -> Tuple[Subset[Any], np.ndarray]:
"""
Filters the dataset by specified classes.
"""
Expand All @@ -129,7 +129,7 @@ def random_samples(
return Subset(dataset, indices), indices


def extr_noniid(train_dataset: Any, samples_per_user: int, classes: Sequence[int]):
def extr_noniid(train_dataset: Any, samples_per_user: int, classes: Sequence[int]) -> Tuple[Subset[Any]]:
"""
Extracts non-IID data from the training dataset.
"""
Expand All @@ -147,7 +147,7 @@ def cifar_extr_noniid(
n_class: int,
num_samples: int,
rate_unbalance: float,
):
) -> Tuple[Dict[int, np.ndarray], Dict[int, np.ndarray]]:
"""
Extracts non-IID data for CIFAR-10 dataset.
"""
Expand Down
2 changes: 1 addition & 1 deletion src/utils/dropout_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, node_id: int, dropout_dict: dict, rng: random.Random) -> None
dropout_method_params = dropout_distribution_dict.get("parameters", {})
dropout_method = dropout_distribution_dict.get("method", "uniform")
self.rng = rng
self.dropped_recently = False
self.dropped_recently: bool = False


if dropout_method == "uniform":
Expand Down
10 changes: 5 additions & 5 deletions src/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,14 +541,14 @@ def substract_model_weights(self, model1, model2):
res[key] = param - model2[key]
return res

def get_last_layer_keys(self, model_wts: OrderedDict[str, Tensor]):
def get_last_layer_keys(self, model_wts: OrderedDict[str, Tensor]) -> List[str]:
# Assume one layer is composed of multiple weights named as "layer_name.weight_name"

reversed_model_wts = reversed(model_wts)
last_key = next(reversed_model_wts)
last_layer = last_key.split(".")[0]

last_layer_keys = []
last_layer_keys: List[str] = []

while last_key is not None and last_key.startswith(last_layer + "."):
last_layer_keys.append(last_key)
Expand All @@ -557,16 +557,16 @@ def get_last_layer_keys(self, model_wts: OrderedDict[str, Tensor]):

def filter_model_weights(
self, model_wts: OrderedDict[str, Tensor], key_to_ignore: List[str]
):
) -> OrderedDict[str, Tensor]:
# Assume one layer is composed of multiple weights named as "layer_name.weight_name"

filtered_model_wts = OrderedDict()
filtered_model_wts: OrderedDict[str, Tensor] = OrderedDict()
for key, param in model_wts.items():
if key not in key_to_ignore:
filtered_model_wts[key] = param
return filtered_model_wts

def get_memory_usage(self):
def get_memory_usage(self) -> int:
"""
Get the memory usage
"""
Expand Down