4 changes: 2 additions & 2 deletions src/algos/base_class.py
@@ -527,8 +527,8 @@ def set_data_parameters(self, config: ConfigType) -> None:
self.classes_of_interest = classes
self.train_indices = train_indices
self.train_dset = train_dset
self.dloader = DataLoader(train_dset, batch_size=len(train_dset), shuffle=False)
self._test_loader = DataLoader(test_dset, batch_size=len(test_dset), shuffle=False)
self.dloader: DataLoader[Any] = DataLoader(train_dset, batch_size=len(train_dset), shuffle=False)
self._test_loader: DataLoader[Any] = DataLoader(test_dset, batch_size=len(test_dset), shuffle=False)
print("Using GIA data setup")
print(self.labels)
else:
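
The two changed lines above only add type annotations to the GIA data setup's full-batch loaders. For context, a minimal sketch (not the repo's code) of why `batch_size=len(train_dset)` with `shuffle=False` yields exactly one batch covering the whole dataset; the toy `TensorDataset` here is illustrative only.

```python
# Minimal sketch: a full-batch DataLoader yields the entire dataset as a single
# batch, which is the effect of batch_size=len(dset) in the annotated lines above.
from typing import Any

import torch
from torch.utils.data import DataLoader, TensorDataset

train_dset = TensorDataset(torch.randn(8, 3), torch.randint(0, 2, (8,)))  # toy data
dloader: DataLoader[Any] = DataLoader(train_dset, batch_size=len(train_dset), shuffle=False)

batch_x, batch_y = next(iter(dloader))
assert batch_x.shape[0] == len(train_dset)  # a single batch covers the whole dataset
```
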
23 changes: 10 additions & 13 deletions src/algos/fl_dynamic.py
@@ -74,7 +74,7 @@ def get_neighbor_model_wts(self) -> List[Dict[str, TorchModelType]]:
from all the neighbors because that's how
most dynamic topologies work.
"""
neighbor_models = self.comm_utils.all_gather(ignore_super_node=True)
neighbor_models: List[Dict[str, TorchModelType]] = self.comm_utils.all_gather(ignore_super_node=True)
return neighbor_models

@@ -95,11 +95,12 @@ def get_neighbor_similarity(self, others_wts: List[Dict[str, TorchModelType]]) -> List[float]:
raise ValueError("Similarity metric {} not implemented".format(self.similarity))
return similarity_wts

def sample_neighbours(self, k: int) -> List[int]:
def sample_neighbours(self, k: int, mode: str|None = None) -> List[int]:
"""
We perform neighbor sampling after
we have the similarity weights of all the neighbors.
"""
assert mode is None or mode == "pull", "Only pull mode is supported for dynamic topology"
if self.sampling == "closest":
return select_smallest_k(self.similarity_wts, k)
else:
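
The new `mode` argument is only allowed to be `None` or `"pull"`, and `"closest"` sampling delegates to `select_smallest_k` over the similarity weights. Below is a hedged sketch of what `select_smallest_k` plausibly does, assuming the weights behave like distances so the smallest values mark the closest neighbors; the helper's real implementation lives elsewhere in the repo and may differ.

```python
# Hedged sketch (assumption, not the repo's implementation): return the indices
# of the k smallest similarity weights, i.e. the "closest" neighbors when the
# weights behave like distances.
from typing import List

def select_smallest_k(similarity_wts: List[float], k: int) -> List[int]:
    ranked = sorted(range(len(similarity_wts)), key=lambda i: similarity_wts[i])
    return ranked[:k]

# With weights [0.9, 0.1, 0.4, 0.7] and k=2, neighbors 1 and 2 are selected.
print(select_smallest_k([0.9, 0.1, 0.4, 0.7], k=2))  # [1, 2]
```
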
@@ -154,7 +155,7 @@ def __init__(
self.topology = DynamicTopology(config, comm_utils, self)
self.topology.initialize()

def get_representation(self, **kwargs: Any) -> TorchModelType:
def get_representation(self, **kwargs: Any) -> Dict[str, int|Dict[str, Any]]:
"""
Returns the model weights as representation.
"""
@@ -172,24 +173,20 @@ def run_protocol(self) -> None:
epochs_per_round = self.config.get("epochs_per_round", 1)

for it in range(start_round, total_rounds):
self.round_init()

# Train locally and send the representation to the server
stats["train_loss"], stats["train_acc"], stats["train_time"] = self.local_train(
it, epochs_per_round
)
self.local_round_done()

# Collect the representations from all other nodes from the server
neighbors = self.topology.recv_and_agg(self.num_collaborators)
# TODO: Log the neighbors
stats["neighbors"] = neighbors

stats["bytes_received"], stats["bytes_sent"] = self.comm_utils.get_comm_cost()
collabs = self.topology.recv_and_agg(self.num_collaborators)

# evaluate the model on the test data
# Inside FedStaticNode.run_protocol()
stats["test_loss"], stats["test_acc"] = self.local_test()
stats.update(self.get_memory_metrics())
self.log_metrics(stats=stats, iteration=it)
self.stats["neighbors"] = collabs
self.local_test()
self.round_finalize()



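
Taken together, the reworked `run_protocol` loop trades the manual stats bookkeeping (comm cost, memory metrics, explicit `log_metrics`) for the `round_init` / `local_round_done` / `round_finalize` hooks and records the pulled collaborators under `self.stats["neighbors"]`. A schematic sketch of the per-round flow the new loop implies; the hook semantics in the comments are assumptions about the base node class, not part of this diff.

```python
# Schematic sketch of the new per-round flow (hook semantics are assumptions).
def run_protocol(self) -> None:
    epochs_per_round = self.config.get("epochs_per_round", 1)
    # start_round and total_rounds are assumed to come from the node's config
    start_round, total_rounds = self.config.get("start_round", 0), self.config["rounds"]

    for it in range(start_round, total_rounds):
        self.round_init()                        # presumably resets per-round stats
        self.local_train(it, epochs_per_round)   # local training; loss/acc presumably logged by the hooks
        self.local_round_done()                  # signal that the local update is ready to share
        collabs = self.topology.recv_and_agg(self.num_collaborators)  # pull and aggregate neighbors
        self.stats["neighbors"] = collabs        # record which collaborators were aggregated
        self.local_test()                        # evaluate on the local test split
        self.round_finalize()                    # presumably flushes stats/metrics for the round
```
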
13 changes: 6 additions & 7 deletions src/configs/sys_config.py
@@ -159,10 +159,10 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
CIFAR10_DSET = "cifar10"
CIAR10_DPATH = "./datasets/imgs/cifar10/"

NUM_COLLABORATORS = 1
DUMP_DIR = "/tmp/"
NUM_COLLABORATORS = 3
DUMP_DIR = "/tmp/new_sonar/"

num_users = 3
num_users = 9
mpi_system_config: ConfigType = {
"exp_id": "",
"comm": {"type": "MPI"},
@@ -318,8 +318,6 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
"exp_keys": [],
}

num_users = 4

dropout_dict: Any = {
"distribution_dict": { # leave dict empty to disable dropout
"method": "uniform", # "uniform", "normal"
@@ -346,9 +344,10 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
"dpath": CIAR10_DPATH,
"seed": 2,
"device_ids": get_device_ids(num_users, gpu_ids),
"assign_based_on_host": True,
# "algos": get_algo_configs(num_users=num_users, algo_configs=default_config_list), # type: ignore
"algos": get_algo_configs(num_users=num_users, algo_configs=[fed_dynamic_loss]), # type: ignore
"samples_per_user": 10000 // num_users, # distributed equally
"algos": get_algo_configs(num_users=num_users, algo_configs=[fed_dynamic_weights]), # type: ignore
"samples_per_user": 500, # distributed equally
"train_label_distribution": "non_iid",
"alpha_data": 0.1,
"test_label_distribution": "iid",
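
With `num_users = 9`, `samples_per_user = 500`, `"train_label_distribution": "non_iid"`, and `"alpha_data": 0.1`, each user sees a heavily skewed label mix. A hedged sketch assuming the non-IID split is Dirichlet-based, which is the usual reading of an alpha parameter; the repo's actual partitioning code is not part of this diff and may differ.

```python
# Hedged sketch (assumption): a Dirichlet(alpha) label partition, the common
# interpretation of "non_iid" + alpha_data. The repo's sampler may differ.
import numpy as np

num_users, num_classes, samples_per_user = 9, 10, 500   # CIFAR-10 has 10 classes
rng = np.random.default_rng(seed=2)                      # "seed": 2 from the config above

for user in range(num_users):
    class_props = rng.dirichlet(alpha=[0.1] * num_classes)   # alpha_data = 0.1 -> very skewed
    counts = rng.multinomial(samples_per_user, class_props)  # 500 samples per user
    print(f"user {user}: per-class counts {counts.tolist()}")
```

With an alpha of 0.1, most of a user's 500 samples typically fall into one or two classes, so each node's local distribution differs sharply from its peers'.
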
2 changes: 1 addition & 1 deletion src/utils/communication/comm_utils.py
@@ -72,7 +72,7 @@ def receive(self, node_ids: List[int]) -> Any:
def broadcast(self, data: Any, tag: int = 0):
self.comm.broadcast(data)

def all_gather(self, tag: int = 0, ignore_super_node: bool = False):
def all_gather(self, tag: int = 0, ignore_super_node: bool = False) -> List[Dict[str, Any]]:
return self.comm.all_gather(ignore_super_node=ignore_super_node)

def send_quorum(self):
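
The annotated `all_gather` now advertises a `List[Dict[str, Any]]` return, matching how `get_neighbor_model_wts` in `fl_dynamic.py` consumes it. A hedged sketch of the shape callers can expect, assuming node 0 is the super node that `ignore_super_node=True` drops; the actual backend gather is not shown in this diff.

```python
# Hedged sketch (assumption): the shape callers see from all_gather -- one
# state-dict-like mapping per node, minus the super node when requested.
from typing import Any, Dict, List

def all_gather_shape(per_node_wts: Dict[int, Dict[str, Any]],
                     ignore_super_node: bool = False) -> List[Dict[str, Any]]:
    # node 0 is treated as the super node purely for illustration
    return [wts for node_id, wts in sorted(per_node_wts.items())
            if not (ignore_super_node and node_id == 0)]

gathered = all_gather_shape({0: {"layer.weight": None}, 1: {"layer.weight": None}},
                            ignore_super_node=True)
assert len(gathered) == 1  # the super node's entry was dropped
```
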