diff --git a/src/algos/attack_add_noise.py b/src/algos/attack_add_noise.py index 2835269c..cf6c54d6 100644 --- a/src/algos/attack_add_noise.py +++ b/src/algos/attack_add_noise.py @@ -15,20 +15,21 @@ noisy_weights = attack.get_representation() """ -import random from collections import OrderedDict +from algos.base_attack import BaseAttack from typing import Dict from torch import Tensor from utils.types import ConfigType -class AddNoiseAttack: +class AddNoiseAttack(BaseAttack): """ A class that adds Gaussian noise to model weights. This can be used as a form of attack, often referred to as a 'backdoor attack', where the noise is introduced without causing significant noticeable changes in model performance. Attributes: + node_id (int): The unique identifier of the node used to set the seed. state_dict (OrderedDict[str, Tensor]): A dictionary containing the model's state (weights). noise_rate (float): The probability that noise will be added to each weight. noise_mean (float): The mean of the Gaussian noise. @@ -36,7 +37,7 @@ class AddNoiseAttack: """ def __init__( - self, config: ConfigType, state_dict: Dict[str, Tensor] + self, node_id: int, config: ConfigType, state_dict: Dict[str, Tensor] ) -> None: """ Initializes the AddNoiseAttack class with the provided configuration and model state. @@ -48,6 +49,7 @@ def __init__( state_dict (OrderedDict[str, Tensor]): A dictionary containing the model's state (weights). 
""" + super().__init__(node_id, config) self.state_dict = state_dict self.noise_rate = float(config.get("noise_rate", 1)) # type: ignore self.noise_mean = float(config.get("noise_mean", 0)) # type: ignore @@ -67,8 +69,8 @@ def get_representation(self) -> Dict[str, Tensor]: return OrderedDict( { key: ( - val + self.noise_std * random.gauss(self.noise_mean, self.noise_std) - if random.random() < self.noise_rate + val + self.noise_std * self.rng.gauss(self.noise_mean, self.noise_std) + if self.rng.random() < self.noise_rate else val ) for key, val in self.state_dict.items() diff --git a/src/algos/attack_bad_weights.py b/src/algos/attack_bad_weights.py index 83e4dcb1..ad4bf822 100644 --- a/src/algos/attack_bad_weights.py +++ b/src/algos/attack_bad_weights.py @@ -13,31 +13,33 @@ corrupted_weights = attack.get_representation() """ -import random from collections import OrderedDict +from algos.base_attack import BaseAttack from typing import Dict from torch import Tensor from utils.types import ConfigType -class BadWeightsAttack: +class BadWeightsAttack(BaseAttack): """ A class that applies corruption to a portion of the model's weights by scaling them with a predefined factor ('weight'). This can be used to simulate malicious attacks on the model. Attributes: + node_id (int): The unique identifier of the node used to set the seed. state_dict (OrderedDict[str, Tensor]): A dictionary containing the model's state (weights). weight (float): A factor by which corrupted weights are scaled. Default is 0. corrupt_portion (float): The proportion of weights to corrupt. A float between 0 and 1. """ def __init__( - self, config: ConfigType, state_dict: Dict[str, Tensor] + self, node_id: int, config: ConfigType, state_dict: Dict[str, Tensor] ) -> None: """ Initializes the BadWeightsAttack class with the provided configuration and model state. Args: + node_id (int): The unique identifier of the node used to set the seed. 
config (ConfigType): A configuration dictionary containing 'weight' and 'corrupt_portion'. 'weight' specifies the factor to scale corrupted weights, and 'corrupt_portion' defines @@ -45,6 +47,7 @@ def __init__( state_dict (OrderedDict[str, Tensor]): A dictionary containing the model's state (weights). """ + super().__init__(node_id, config) self.state_dict = state_dict self.weight = config.get("weight", 0) self.corrupt_portion = float(config.get("corrupt_portion", 1)) # type: ignore @@ -62,7 +65,7 @@ def get_representation(self) -> Dict[str, Tensor]: return OrderedDict( { key: ( - val * self.weight if random.random() < self.corrupt_portion else val + val * self.weight if self.rng.random() < self.corrupt_portion else val ) for key, val in self.state_dict.items() } diff --git a/src/algos/attack_sign_flip.py b/src/algos/attack_sign_flip.py index bb9ef0d4..baa82a99 100644 --- a/src/algos/attack_sign_flip.py +++ b/src/algos/attack_sign_flip.py @@ -13,38 +13,41 @@ flipped_weights = attack.get_representation() """ -import random from collections import OrderedDict +from algos.base_attack import BaseAttack from typing import Dict from torch import Tensor from utils.types import ConfigType -class SignFlipAttack: +class SignFlipAttack(BaseAttack): """ A class that flips the sign of a portion of model weights based on a configured flip rate. This can simulate an adversarial attack that introduces significant changes to the model's weights. Attributes: + node_id (int): The unique identifier of the node used to set the seed. state_dict (OrderedDict[str, Tensor]): A dictionary containing the model's state (weights). flip_rate (float): The probability that the sign of a weight will be flipped. A float between 0 and 1, where 1 means all weights are flipped. 
""" def __init__( - self, config: ConfigType, state_dict: Dict[str, Tensor] + self, node_id: int, config: ConfigType, state_dict: Dict[str, Tensor] ) -> None: """ Initializes the SignFlipAttack class with the provided configuration and model state. Args: + node_id (int): The unique identifier of the node used to set the seed. config (ConfigType): A configuration dictionary that contains 'flip_rate', which determines the probability of flipping the sign of a weight. state_dict (OrderedDict[str, Tensor]): A dictionary containing the model's state (weights). """ + super().__init__(node_id, config) self.state_dict = state_dict self.flip_rate = float(config.get("flip_rate", 1)) # type: ignore # TODO: Add conditions such as target label, source label, start/end epochs, or rounds for the attack. @@ -60,7 +63,7 @@ def get_representation(self) -> Dict[str, Tensor]: """ return OrderedDict( { - key: -1 * val if random.random() < self.flip_rate else val + key: -1 * val if self.rng.random() < self.flip_rate else val for key, val in self.state_dict.items() } ) diff --git a/src/algos/base_attack.py b/src/algos/base_attack.py new file mode 100644 index 00000000..4a60bd9f --- /dev/null +++ b/src/algos/base_attack.py @@ -0,0 +1,34 @@ +""" +This module implements the base class for malicious attacks BaseAttack. + +Usage: + node_id = self.node_id # base node's id +""" + +import random +from typing import Any +from utils.types import ConfigType + + + +class BaseAttack: + """ + A base class for attacks. + + Attributes: + node_id (int): The unique identifier of the node used to set the seed. + config (ConfigType): A configuration dictionary containing attack parameters and random seed. + """ + + def __init__( + self, node_id: int, config: ConfigType, *args: Any, **kwargs: Any + ) -> None: + """ + Initializes the AddNoiseAttack class with the provided configuration and model state. 
+ + Args: + node_id (int): The unique identifier of the node used to set the seed. + config (ConfigType): A configuration dictionary containing 'seed' and + 'num_users'. Default values are used if keys are missing. + """ + self.rng = random.Random(int(config.get("seed", 20)) * int(config.get("num_users", 9)) + node_id) # type: ignore diff --git a/src/algos/fl.py b/src/algos/fl.py index 625a61ae..484919b3 100644 --- a/src/algos/fl.py +++ b/src/algos/fl.py @@ -65,20 +65,18 @@ def get_model_weights(self, **kwargs: Any) -> Dict[str, Tensor]: elif malicious_type == "bad_weights": # Corrupt the weights return BadWeightsAttack( - self.config, self.model.state_dict() + self.node_id, self.config, self.model.state_dict() ).get_representation() elif malicious_type == "sign_flip": # Flip the sign of the weights, also TODO: consider label flipping return SignFlipAttack( - self.config, self.model.state_dict() + self.node_id, self.config, self.model.state_dict() ).get_representation() elif malicious_type == "add_noise": # Add noise to the weights return AddNoiseAttack( - self.config, self.model.state_dict() + self.node_id, self.config, self.model.state_dict() ).get_representation() - else: - return self.model.state_dict() # type: ignore return self.model.state_dict() # type: ignore def set_representation(self, representation: OrderedDict[str, Tensor]):