Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@ expt_dump_old/
di_test/
imgs/
pascal/
<<<<<<< HEAD
data/
!src/inversefed/data/
=======
# data/
>>>>>>> b73ed37 (mia so far)
notes.txt
removeme*.png

Expand Down
310 changes: 250 additions & 60 deletions src/algos/base_class.py

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions src/algos/fl_static.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Module for FedStaticClient and FedStaticServer in Federated Learning.
"""
from typing import Any, Dict, OrderedDict, List
from typing import Any, Dict, OrderedDict, List, Tuple
from collections import OrderedDict, defaultdict

from utils.communication.comm_utils import CommunicationManager
Expand Down Expand Up @@ -38,7 +38,8 @@ def get_neighbors(self) -> List[int]:
self.stats["neighbors"] = neighbors # type: ignore, where the hell self.stats is coming from

return neighbors



def run_protocol(self) -> None:
"""
Runs the federated learning protocol for the client.
Expand All @@ -64,12 +65,12 @@ def run_protocol(self) -> None:

neighbors = self.get_neighbors()
# TODO: Log the neighbors
self.receive_and_aggregate(neighbors)
self.receive_and_aggregate(neighbors, it)
# evaluate the model on the test data
# Inside FedStaticNode.run_protocol()
self.local_test()

self.round_finalize()
# self.local_round_done()

class FedStaticServer(BaseFedAvgClient):
"""
Expand Down
32 changes: 32 additions & 0 deletions src/algos/topologies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,16 @@ def initialize(self) -> None:
self.generate_graph()
self._convert_labels_to_int()

# If the graph is not connected (some node cannot reach another), print a warning in red
if not nx.is_connected(self.graph):
print("\033[91m" + "Warning: Graph is not fully connected" + "\033[0m")

def get_all_neighbours(self) -> List[int]:
    """
    Return every node adjacent to this node (``self.rank``) in ``self.graph``.

    Raises:
        ValueError: If the graph has not been generated yet.
    """
    topology = self.graph
    if topology is None:
        raise ValueError("Graph not initialized")
    # NOTE(review): an earlier note here stated "graph is 1-indexed, but our
    # node IDs are 0-indexed", yet the lookup uses self.rank unchanged —
    # confirm which indexing the graph labels actually use.
    adjacent = topology.neighbors(self.rank)  # type: ignore
    return [*adjacent]
Expand Down Expand Up @@ -104,3 +109,30 @@ def get_neighbourhood_size(self) -> int:
if self.graph is None:
raise ValueError("Graph not initialized")
return len(self.get_all_neighbours())

def calculate_graph_metrics(
    self, target_node: int
) -> dict[str, float | dict[int, int] | dict[int, float]]:
    """
    Compute, print and return summary metrics of the topology graph.

    Args:
        target_node: Node to which shortest-path lengths are measured.

    Returns:
        A dict with the keys:
          - "density": graph density (float).
          - "shortest_paths": shortest-path length to ``target_node``,
            keyed by source node.
          - "eigenvector_centrality", "closeness_centrality",
            "degree_centrality": per-node centrality scores.

    Raises:
        ValueError: If the graph has not been generated yet.
        Exceptions raised by networkx (e.g. ``target_node`` missing from the
        graph) propagate unchanged.
    """
    # Same guard as the other accessors of this class: fail with a clear
    # message instead of an opaque networkx error on a None graph.
    if self.graph is None:
        raise ValueError("Graph not initialized")
    graph = self.graph

    # Calculate graph density
    density = nx.density(graph)
    print(f"Graph Density: {density}")

    # Calculate shortest path lengths from every node to the target node
    shortest_paths = nx.shortest_path_length(graph, target=target_node)
    print(f"Shortest Path Lengths to Target Node {target_node}: {shortest_paths}")

    # Calculate centrality metrics
    eigenvector_centrality = nx.eigenvector_centrality(graph)
    closeness_centrality = nx.closeness_centrality(graph)
    degree_centrality = nx.degree_centrality(graph)

    print("Eigenvector Centrality:", eigenvector_centrality)
    print("Closeness Centrality:", closeness_centrality)
    print("Degree Centrality:", degree_centrality)

    return {
        "density": density,
        "shortest_paths": shortest_paths,
        "eigenvector_centrality": eigenvector_centrality,
        "closeness_centrality": closeness_centrality,
        "degree_centrality": degree_centrality,
    }
25 changes: 22 additions & 3 deletions src/algos/topologies/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,15 @@ def __init__(self, config: ConfigType, rank: int):
self.p = p
self.seed = config["seed"]

def generate_graph(self) -> None:
self.graph = nx.erdos_renyi_graph(self.num_users, self.p, self.seed)

def generate_graph(self, full: bool = False) -> None:
    """
    Generate an Erdos-Renyi graph over ``self.num_users`` nodes with edge
    probability ``self.p`` and store it in ``self.graph``.

    Args:
        full: If True, keep sampling until the graph is connected (every
            node reachable from every other node). If False, the first
            sample is used even when it is disconnected.

    Raises:
        RuntimeError: If ``full`` is True and no connected graph was found
            within the retry budget.
    """
    if not full:
        self.graph = nx.erdos_renyi_graph(self.num_users, self.p, seed=self.seed)
        return

    # Re-sampling with an unchanged integer seed reproduces the exact same
    # graph, so a disconnected first draw used to spin forever. Offsetting
    # the seed by the attempt number keeps the result deterministic for a
    # given seed (attempt 0 is the original draw) while actually trying new
    # graphs. Non-integer seeds (None / RNG objects) already vary per call.
    max_attempts = 1000
    for attempt in range(max_attempts):
        if isinstance(self.seed, int):
            attempt_seed = self.seed + attempt
        else:
            attempt_seed = self.seed
        candidate = nx.erdos_renyi_graph(self.num_users, self.p, seed=attempt_seed)
        if nx.is_connected(candidate):
            self.graph = candidate
            return

    # Bounded retries: a very small p may make a connected sample practically
    # unreachable; fail loudly rather than hang.
    raise RuntimeError(
        f"Could not generate a connected Erdos-Renyi graph "
        f"(n={self.num_users}, p={self.p}) in {max_attempts} attempts"
    )

class WattsStrogatzTopology(BaseTopology):
def __init__(self, config: ConfigType, rank: int):
Expand All @@ -145,6 +151,17 @@ def __init__(self, config: ConfigType, rank: int):
def generate_graph(self) -> None:
    """Sample a random ``self.d``-regular graph on ``self.num_users`` nodes into ``self.graph``."""
    regular_graph = nx.random_regular_graph(self.d, self.num_users, seed=self.seed)  # type: ignore
    self.graph = regular_graph

class LineTopology(BaseTopology):
    """
    Line (path) topology: node ``i`` is linked to node ``i - 1``.

    Test topology for debugging gradient disambiguation attack.
    """

    def __init__(self, config: ConfigType, rank: int):
        super().__init__(config, rank)

    def generate_graph(self) -> None:
        """Build the chain 0 - 1 - ... - (num_users - 1) and store it in ``self.graph``."""
        chain = nx.Graph()
        # Node 0 is added explicitly so it exists even when no edge is created.
        chain.add_node(0)
        # Each edge (i - 1, i) also inserts node i, in increasing order.
        chain.add_edges_from((i - 1, i) for i in range(1, self.num_users))
        self.graph = chain

class DynamicGraph(BaseTopology):
def __init__(self, config: ConfigType, rank: int):
Expand Down Expand Up @@ -277,6 +294,8 @@ def select_topology(config: ConfigType, rank: int) -> BaseTopology:
return RandomRegularTopology(config, rank)
if topology_name == "barbell":
return BarbellTopology(config, rank)
if topology_name == "line":
return LineTopology(config, rank)
if topology_name == "one_peer_exponential":
return OnePeerExponentialTopology(config, rank)
if topology_name == "hyper_hypercube":
Expand Down
9 changes: 4 additions & 5 deletions src/configs/algo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ def get_malicious_types(malicious_config_list: List[ConfigType]) -> Dict[str, st
traditional_fl: ConfigType = {
# Collaboration setup
"algo": "fedavg",
"rounds": 2,

"rounds": 5,
# Model parameters
"model": "resnet10",
"model_lr": 3e-4,
Expand Down Expand Up @@ -197,7 +196,7 @@ def get_malicious_types(malicious_config_list: List[ConfigType]) -> Dict[str, st
"average_last_layer": True,
"mask_finetune_last_layer": False,
# params for model
"position": 0,
"position": 0,
}

fedstatic: ConfigType = {
Expand All @@ -207,9 +206,9 @@ def get_malicious_types(malicious_config_list: List[ConfigType]) -> Dict[str, st
# "topology": {"name": "base_graph", "max_degree": 2}, # type: ignore
"rounds": 3,
# Model parameters
"optimizer": "sgd", # TODO comment out for real training
"optimizer": "sgd",
"model": "resnet10",
"model_lr": 3e-4,
"model_lr": 0.1, # lr for sgd
"batch_size": 256,
}

Expand Down
60 changes: 52 additions & 8 deletions src/configs/sys_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):

CIFAR10_DSET = "cifar10"
CIAR10_DPATH = "./datasets/imgs/cifar10/"
MNIST_DEST = "mnist"
MNIST_DPATH = "./datasets/imgs/mnist/"

NUM_COLLABORATORS = 3
DUMP_DIR = "/tmp/new_sonar/"
Expand Down Expand Up @@ -340,7 +342,7 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
"num_collaborators": NUM_COLLABORATORS,
"comm": {"type": "GRPC", "synchronous": True, "peer_ids": ["localhost:32048"]}, # The super-node
"dset": CIFAR10_DSET,
"dump_dir": DUMP_DIR,
"dump_dir": f"{DUMP_DIR}test/ring_",
"dpath": CIAR10_DPATH,
"seed": 2,
"device_ids": get_device_ids(num_users, gpu_ids),
Expand All @@ -361,21 +363,63 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
"num_users": num_users,
"num_collaborators": NUM_COLLABORATORS,
"comm": {"type": "GRPC", "synchronous": True, "peer_ids": ["localhost:50048"]}, # The super-node
"dset": CIFAR10_DSET,
"dump_dir": DUMP_DIR,
"dpath": CIAR10_DPATH,
"dset": MNIST_DEST,
"dump_dir": f"{DUMP_DIR}disambiguate/torus_",
# "dpath": CIAR10_DPATH,
"dpath": MNIST_DPATH,
"seed": 2,
"device_ids": get_device_ids(num_users, gpu_ids),
# "algos": get_algo_configs(num_users=num_users, algo_configs=default_config_list), # type: ignore
"algos": get_algo_configs(num_users=num_users, algo_configs=[fedstatic]), # type: ignore
"samples_per_user": 50000 // num_users, # distributed equally
"train_label_distribution": "iid",
"test_label_distribution": "iid",
"exp_keys": [],
"assign_based_on_host": True,
"hostname_to_device_ids": {
"matlaber1": [2, 3, 4, 5, 6, 7],
"matlaber12": [0, 1, 2, 3],
"matlaber3": [0, 1, 2, 3],
"matlaber4": [0, 2, 3, 4, 5, 6, 7],
},
"dropout_dicts": dropout_dicts,
"gia":True,
"gia_attackers":[1]
"gia_attackers":[1],
"disaggregate": True
}

# gRPC system configuration with the "mia" flag enabled: CIFAR-10, the
# fedstatic algorithm for every user, and a non-IID training label split.
grpc_system_config_mia: ConfigType = {
    "exp_id": "static",
    "num_users": num_users,
    "num_collaborators": NUM_COLLABORATORS,
    "comm": {"type": "GRPC", "synchronous": True, "peer_ids": ["localhost:50048"]},  # The super-node
    "dset": CIFAR10_DSET,
    "dump_dir": f"{DUMP_DIR}mia/SGD/1.0/er_",
    # Dataset directory, not the dataset name: this previously pointed at
    # CIFAR10_DSET ("cifar10"), unlike the other CIFAR-10 configs in this file.
    "dpath": CIAR10_DPATH,
    "seed": 2,
    "device_ids": get_device_ids(num_users, gpu_ids),
    "algos": get_algo_configs(num_users=num_users, algo_configs=[fedstatic]),  # type: ignore
    "samples_per_user": 50000 // num_users,  # distributed equally
    "train_label_distribution": "non_iid",
    "test_label_distribution": "iid",
    "alpha_data": 1.0,
    "test_samples_per_user": 500,
    "exp_keys": [],
    "assign_based_on_host": True,
    # GPU ids available on each known host.
    "hostname_to_device_ids": {
        "matlaber1": [2, 3, 4, 5, 6, 7],
        "matlaber12": [0, 1, 2, 3],
        "matlaber3": [0, 1, 2, 3],
        "matlaber4": [0, 2, 3, 4, 5, 6, 7],
        "matlaberp4": [0, 1],
        "matlaberp1": [0, 1],
    },
    "dropout_dicts": dropout_dicts,
    "mia": True,
}

current_config = grpc_system_config
# current_config = mpi_system_config

# current_config = grpc_system_config
current_config = mpi_system_config
# current_config = grpc_system_config_gia
# current_config = grpc_system_config_mia
5 changes: 3 additions & 2 deletions src/data_loaders/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class MNISTDataset:
"""

def __init__(self, dpath: str) -> None:
print(f"DEBUG dp: {dpath}")
self.image_size = 28
self.num_cls = 10
self.mean = 0.1307
Expand All @@ -27,8 +28,8 @@ def __init__(self, dpath: str) -> None:
]
)
self.train_dset = MNIST(
root=dpath, train=True, download=True, transform=self.train_transform
root=dpath, train=True, download=False, transform=self.train_transform
)
self.test_dset = MNIST(
root=dpath, train=False, download=True, transform=self.test_transform
root=dpath, train=False, download=False, transform=self.test_transform
)
2 changes: 1 addition & 1 deletion src/inversefed/data/data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,4 +206,4 @@ def _get_meanstd(dataset):
cc = torch.cat([trainset[i][0].reshape(3, -1) for i in range(len(trainset))], dim=1)
data_mean = torch.mean(cc, dim=1).tolist()
data_std = torch.std(cc, dim=1).tolist()
return data_mean, data_std
return data_mean, data_std
2 changes: 1 addition & 1 deletion src/inversefed/data/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,4 @@ def metric(self, x=None, y=None):
return name, format
else:
value = (x.data.argmax(dim=1) == y).sum().float() / y.shape[0]
return value.detach(), name, format
return value.detach(), name, format
13 changes: 12 additions & 1 deletion src/inversefed/reconstruction_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def loss_fn(pred, labels):
return torch.mean(torch.sum(- labels * torch.nn.functional.log_softmax(pred, dim=-1), 1))
self.loss_fn = loss_fn
else:
assert labels.shape[0] == self.num_images
assert labels.shape[0] == self.num_images, f"Labels shape {labels.shape} does not match num_images {self.num_images}."
self.reconstruct_label = False

try:
Expand All @@ -109,6 +109,10 @@ def loss_fn(pred, labels):
else:
print('Choosing optimal result ...')
scores = scores[torch.isfinite(scores)] # guard against NaN/-Inf scores?
if scores.numel() == 0:
print("All scores are NaN or -Inf. Cannot determine the optimal result.")
return None # or handle the situation as needed

optimal_index = torch.argmin(scores)
print(f'Optimal result score: {scores[optimal_index]:2.4f}')
stats['opt'] = scores[optimal_index].item()
Expand All @@ -128,6 +132,13 @@ def _init_images(self, img_shape):
raise ValueError()

def _run_trial(self, x_trial, input_data, labels, dryrun=False):
# Print data types
print(f"Input data type: {type(input_data)}")
if isinstance(input_data, torch.Tensor):
print(f"Input data tensor dtype: {input_data.dtype}, shape: {input_data.shape}")
if isinstance(labels, torch.Tensor):
print(f"Labels tensor dtype: {labels.dtype}, shape: {labels.shape}, labels value: {labels}")

x_trial.requires_grad = True
if self.reconstruct_label:
output_test = self.model(x_trial)
Expand Down
4 changes: 2 additions & 2 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@
)

args: argparse.Namespace = parser.parse_args()

scheduler: Scheduler = Scheduler()

# Assign the configuration from the file paths provided via arguments
print(f"arg s {args.s}")
print(f"all args {args}")
scheduler.assign_config_by_path(args.s, args.b, args.super, args.host)
print("Config loaded")

Expand Down
4 changes: 2 additions & 2 deletions src/main_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@
# Command for opening each process
command_list: List[str] = ["python", "main.py", "-host", args.host]
if args.dev == True:
command_list: List[str] = ["python", "main.py", "-b", "./configs/algo_config_test.py", "-s", "./configs/sys_config_test.py", "-host", args.host]
print("here!")
command_list: List[str] = ["python", "main.py", "-b", "./configs/algo_config_2.py", "-s", "./configs/sys_config_2.py", "-host", args.host]

# Start process for each user
for i in range(args.n):
print(f"Starting process for user {i}")
# start a Popen process
subprocess.Popen(command_list)
Loading