Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/algos/attack_add_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,29 @@
noisy_weights = attack.get_representation()
"""

import random
from collections import OrderedDict
from algos.base_attack import BaseAttack
from typing import Dict
from torch import Tensor
from utils.types import ConfigType


class AddNoiseAttack:
class AddNoiseAttack(BaseAttack):
"""
A class that adds Gaussian noise to model weights. This can be used as a form of attack,
often referred to as a 'backdoor attack', where the noise is introduced without causing
significant noticeable changes in model performance.

Attributes:
node_id (int): The unique identifier of the node used to set the seed.
state_dict (OrderedDict[str, Tensor]): A dictionary containing the model's state (weights).
noise_rate (float): The probability that noise will be added to each weight.
noise_mean (float): The mean of the Gaussian noise.
noise_std (float): The standard deviation of the Gaussian noise.
"""

def __init__(
self, config: ConfigType, state_dict: Dict[str, Tensor]
self, node_id: int, config: ConfigType, state_dict: Dict[str, Tensor]
) -> None:
"""
Initializes the AddNoiseAttack class with the provided configuration and model state.
Expand All @@ -48,6 +49,7 @@ def __init__(
state_dict (OrderedDict[str, Tensor]): A dictionary containing the model's state
(weights).
"""
super().__init__(node_id, config)
self.state_dict = state_dict
self.noise_rate = float(config.get("noise_rate", 1)) # type: ignore
self.noise_mean = float(config.get("noise_mean", 0)) # type: ignore
Expand All @@ -67,8 +69,8 @@ def get_representation(self) -> Dict[str, Tensor]:
return OrderedDict(
{
key: (
val + self.noise_std * random.gauss(self.noise_mean, self.noise_std)
if random.random() < self.noise_rate
val + self.noise_std * self.rng.gauss(self.noise_mean, self.noise_std)
if self.rng.random() < self.noise_rate
else val
)
for key, val in self.state_dict.items()
Expand Down
11 changes: 7 additions & 4 deletions src/algos/attack_bad_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,38 +13,41 @@
corrupted_weights = attack.get_representation()
"""

import random
from collections import OrderedDict
from algos.base_attack import BaseAttack
from typing import Dict
from torch import Tensor
from utils.types import ConfigType


class BadWeightsAttack:
class BadWeightsAttack(BaseAttack):
"""
A class that applies corruption to a portion of the model's weights by scaling them with
a predefined factor ('weight'). This can be used to simulate malicious attacks on the model.

Attributes:
node_id (int): The unique identifier of the node used to set the seed.
state_dict (OrderedDict[str, Tensor]): A dictionary containing the model's state (weights).
weight (float): A factor by which corrupted weights are scaled. Default is 0.
corrupt_portion (float): The proportion of weights to corrupt. A float between 0 and 1.
"""

def __init__(
self, config: ConfigType, state_dict: Dict[str, Tensor]
self, node_id: int, config: ConfigType, state_dict: Dict[str, Tensor]
) -> None:
"""
Initializes the BadWeightsAttack class with the provided configuration and model state.

Args:
node_id (int): The unique identifier of the node used to set the seed.
config (ConfigType): A configuration dictionary containing 'weight' and
'corrupt_portion'. 'weight' specifies the factor to
scale corrupted weights, and 'corrupt_portion' defines
the proportion of weights to corrupt.
state_dict (OrderedDict[str, Tensor]): A dictionary containing the model's state
(weights).
"""
super().__init__(node_id, config)
self.state_dict = state_dict
self.weight = config.get("weight", 0)
self.corrupt_portion = float(config.get("corrupt_portion", 1)) # type: ignore
Expand All @@ -62,7 +65,7 @@ def get_representation(self) -> Dict[str, Tensor]:
return OrderedDict(
{
key: (
val * self.weight if random.random() < self.corrupt_portion else val
val * self.weight if self.rng.random() < self.corrupt_portion else val
)
for key, val in self.state_dict.items()
}
Expand Down
11 changes: 7 additions & 4 deletions src/algos/attack_sign_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,38 +13,41 @@
flipped_weights = attack.get_representation()
"""

import random
from collections import OrderedDict
from algos.base_attack import BaseAttack
from typing import Dict
from torch import Tensor
from utils.types import ConfigType


class SignFlipAttack:
class SignFlipAttack(BaseAttack):
"""
A class that flips the sign of a portion of model weights based on a configured
flip rate. This can simulate an adversarial attack that introduces significant
changes to the model's weights.

Attributes:
node_id (int): The unique identifier of the node used to set the seed.
state_dict (OrderedDict[str, Tensor]): A dictionary containing the model's state (weights).
flip_rate (float): The probability that the sign of a weight will be flipped.
A float between 0 and 1, where 1 means all weights are flipped.
"""

def __init__(
self, config: ConfigType, state_dict: Dict[str, Tensor]
self, node_id: int, config: ConfigType, state_dict: Dict[str, Tensor]
) -> None:
"""
Initializes the SignFlipAttack class with the provided configuration and model state.

Args:
node_id (int): The unique identifier of the node used to set the seed.
config (ConfigType): A configuration dictionary that contains 'flip_rate',
which determines the probability of flipping the sign
of a weight.
state_dict (OrderedDict[str, Tensor]): A dictionary containing the model's state
(weights).
"""
super().__init__(node_id, config)
self.state_dict = state_dict
self.flip_rate = float(config.get("flip_rate", 1)) # type: ignore
# TODO: Add conditions such as target label, source label, start/end epochs, or rounds for the attack.
Expand All @@ -60,7 +63,7 @@ def get_representation(self) -> Dict[str, Tensor]:
"""
return OrderedDict(
{
key: -1 * val if random.random() < self.flip_rate else val
key: -1 * val if self.rng.random() < self.flip_rate else val
for key, val in self.state_dict.items()
}
)
34 changes: 34 additions & 0 deletions src/algos/base_attack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""
This module implements the base class for malicious attacks BaseAttack.

Usage:
node_id = self.node_id # base node's id
"""

import random
from typing import Any
from utils.types import ConfigType



class BaseAttack:
    """
    Base class for malicious (Byzantine) attacks on model representations.

    Its only job is to derive a per-node random number generator so that each
    node's attack is reproducible for a fixed configuration, yet distinct from
    every other node's random stream.

    Attributes:
        rng (random.Random): Node-specific RNG, seeded from the config's
            'seed' and 'num_users' values combined with the node id.
    """

    def __init__(
        self, node_id: int, config: ConfigType, *args: Any, **kwargs: Any
    ) -> None:
        """
        Initializes the BaseAttack with a node-specific seeded RNG.

        Args:
            node_id (int): Unique identifier of this node; offsets the seed so
                each node draws an independent, deterministic random stream.
            config (ConfigType): Configuration dictionary; reads 'seed'
                (default 20) and 'num_users' (default 9) to build the seed.
            *args: Unused here; accepted so subclasses can extend the signature.
            **kwargs: Unused here; accepted so subclasses can extend the signature.
        """
        # seed * num_users + node_id yields a distinct seed per node for a
        # fixed (seed, num_users) pair, keeping attacks reproducible per node.
        self.rng = random.Random(int(config.get("seed", 20)) * int(config.get("num_users", 9)) + node_id) # type: ignore
8 changes: 3 additions & 5 deletions src/algos/fl.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,20 +65,18 @@ def get_model_weights(self, **kwargs: Any) -> Dict[str, Tensor]:
elif malicious_type == "bad_weights":
# Corrupt the weights
return BadWeightsAttack(
self.config, self.model.state_dict()
self.node_id, self.config, self.model.state_dict()
).get_representation()
elif malicious_type == "sign_flip":
# Flip the sign of the weights, also TODO: consider label flipping
return SignFlipAttack(
self.config, self.model.state_dict()
self.node_id, self.config, self.model.state_dict()
).get_representation()
elif malicious_type == "add_noise":
# Add noise to the weights
return AddNoiseAttack(
self.config, self.model.state_dict()
self.node_id, self.config, self.model.state_dict()
).get_representation()
else:
return self.model.state_dict() # type: ignore
return self.model.state_dict() # type: ignore

def set_representation(self, representation: OrderedDict[str, Tensor]):
Expand Down