MAINT: Updating AttackExecutor to more generically call attacks #1270
Co-authored-by: hannahwestra25 <hannahwestra@microsoft.com>
Co-authored-by: Roman Lutz <romanlutz13@gmail.com>
bashirpartovi left a comment
Thanks @rlundeen2 for tackling this and fixing the scenarios bug. The overall direction makes sense to me: standardizing attack invocation and making the seed group flow more consistent. I do have a few design concerns though, mostly around how parameters and context are modeled. I think we should address them now, otherwise this will be painful to debug later.
The current design introduces accepted_context_parameters and _excluded_context_parameters. I get why they were added (different attacks accept different parameters), but the approach relies on runtime introspection and implicit contracts. In my opinion, that makes the code more error-prone and the API harder to understand/maintain.
Looking back at how AttackContext is designed, it is currently mixing two different things:
- Inputs from the caller (objective, next_message, prepended_conversation, memory_labels)
- Execution state used internally during the attack (start_time, related_conversations)
Because these are mixed, the executor has to introspect dataclass fields to figure out which ones are valid inputs. That is probably why accepted_context_parameters exists.
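To make the introspection problem concrete, here is a minimal sketch. MixedContext is a hypothetical stand-in, not PyRIT's actual AttackContext, and the exclusion set is an assumed workaround. It shows that field introspection alone cannot tell caller inputs from internal state, so a hand-maintained list is needed.

```python
from dataclasses import dataclass, field, fields
from typing import Dict, List, Optional


@dataclass
class MixedContext:
    """Hypothetical context that mixes caller inputs with execution state."""

    # Caller inputs
    objective: str
    next_message: Optional[str] = None
    memory_labels: Optional[Dict[str, str]] = None
    # Internal execution state
    start_time: float = 0.0
    related_conversations: List[str] = field(default_factory=list)


# Introspection sees every field; nothing marks which ones are caller inputs.
all_fields = {f.name for f in fields(MixedContext)}

# Assumed workaround: a hand-maintained set of internal fields to exclude.
internal_fields = {"start_time", "related_conversations"}
accepted = all_fields - internal_fields

print(sorted(accepted))
```

Every new internal field has to be added to the exclusion set by hand, which is the implicit contract the comment objects to.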
I think these should be separated into different types:
- AttackParameters has immutable inputs from the caller
- AttackContext has mutable execution state that holds a reference to the parameters
(see below)
The second issue is that _excluded_context_parameters is pretty fragile. Some attacks (e.g. RolePlay) generate next_message and prepended_conversation internally. To prevent a caller from overwriting those, the attack has to override _excluded_context_parameters, which is easy to forget. New attacks have to remember to add exclusions, and those exclusions are not visible to the caller: they can pass next_message and it just silently gets dropped. It is also kind of inverted, because attacks should declare what they accept, not what they reject.
What I am proposing below solves this more cleanly: attacks explicitly define the params they accept, and if a field is not in the param class, the caller cannot pass it.
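A minimal sketch of that guarantee, using plain dataclasses. The class names here (BaseParams, MessageParams, RolePlayLikeParams) are hypothetical and only illustrate the idea; they are not types from PyRIT.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class BaseParams:
    """Hypothetical base: the inputs every attack accepts."""

    objective: str


@dataclass(frozen=True)
class MessageParams(BaseParams):
    """Hypothetical params for an attack that lets the caller supply next_message."""

    next_message: Optional[str] = None


@dataclass(frozen=True)
class RolePlayLikeParams(BaseParams):
    """Hypothetical params for an attack that generates next_message internally."""


# Accepted: the field is declared on the params class.
ok = MessageParams(objective="test", next_message="hello")

# Rejected: the field is not declared, so construction fails loudly.
try:
    RolePlayLikeParams(objective="test", next_message="hello")
    rejected = False
except TypeError:
    rejected = True

print(rejected)
```

The attack that generates its own next_message does nothing special; it simply does not declare the field, and the dataclass constructor enforces the rest.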
The 3rd issue is silent param dropping in the executor. With strict_param_matching=False (the default), unsupported parameters are silently ignored. That makes debugging a bit painful because the caller thinks they passed something, but it gets filtered out. If we use an explicit parameters type, this becomes much harder to mess up.
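The failure mode can be sketched with a small filter function. filter_params is a hypothetical helper that only mimics the behavior described above; it is not the executor's actual code, and its strict flag stands in for strict_param_matching.

```python
from typing import Any, Dict, Set


def filter_params(accepted: Set[str], params: Dict[str, Any], *, strict: bool = False) -> Dict[str, Any]:
    """Hypothetical helper: keep accepted keys, optionally failing on the rest."""
    unsupported = set(params) - accepted
    if unsupported and strict:
        raise ValueError(f"Unsupported parameters: {sorted(unsupported)}")
    return {k: v for k, v in params.items() if k in accepted}


accepted = {"objective", "memory_labels"}
passed = {"objective": "test", "next_message": "hello"}

# With the lenient default, next_message vanishes without any warning.
lenient = filter_params(accepted, passed)
print(lenient)
```

With strict=True the same call raises instead, but callers only get that protection if they remember to opt in.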
The 4th issue is the executor having knowledge of attack types. _build_per_attack_params_from_seed_groups hardcodes field names that are specific to certain attacks. If we add more attack types, this method grows with more hardcoded mappings. The executor should not need to know attack-specific field names. The approach I propose below fixes this by adding from_seed_group() to AttackParameters, so subclasses handle their own extraction. The executor stays generic and just calls params_type.from_seed_group(sg, **overrides).
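As a self-contained illustration of that split, here is a minimal sketch. FakeSeedGroup, Params, SingleTurnParams, and build_params are all hypothetical stand-ins; the real SeedGroup has a different shape, and the field mapping shown is an assumption.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence, Type


@dataclass
class FakeSeedGroup:
    """Hypothetical stand-in for SeedGroup."""

    objective: str
    first_prompt: Optional[str] = None


@dataclass(frozen=True)
class Params(ABC):
    objective: str

    @classmethod
    @abstractmethod
    def from_seed_group(cls, seed_group: FakeSeedGroup, **overrides: Any) -> "Params":
        ...


@dataclass(frozen=True)
class SingleTurnParams(Params):
    next_message: Optional[str] = None

    @classmethod
    def from_seed_group(cls, seed_group: FakeSeedGroup, **overrides: Any) -> "SingleTurnParams":
        # Attack-specific extraction lives here, not in the executor.
        values = {"objective": seed_group.objective, "next_message": seed_group.first_prompt}
        values.update(overrides)
        # An override for an undeclared field raises TypeError here.
        return cls(**values)


def build_params(params_type: Type[Params], seed_groups: Sequence[FakeSeedGroup]) -> List[Params]:
    """Generic executor-side loop: no attack-specific field names appear."""
    return [params_type.from_seed_group(sg) for sg in seed_groups]


params_list = build_params(SingleTurnParams, [FakeSeedGroup("a", "hi"), FakeSeedGroup("b")])
print(params_list)
```

Adding a new attack type means adding a new params subclass; build_params does not change.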
Proposed design
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, Optional, TypeVar

# Bounds are quoted so the TypeVars can be declared before the classes exist.
AttackParamsT = TypeVar("AttackParamsT", bound="AttackParameters")
AttackStrategyContextT = TypeVar("AttackStrategyContextT", bound="AttackContext")
AttackStrategyResultT = TypeVar("AttackStrategyResultT", bound="AttackResult")

@dataclass(frozen=True)
class AttackParameters(ABC):
    objective: str
    memory_labels: Optional[Dict[str, str]] = None

    @classmethod
    @abstractmethod
    def from_seed_group(cls, seed_group: SeedGroup, **overrides) -> "AttackParameters":
        ...

Then AttackContext becomes:
@dataclass
class AttackContext(StrategyContext, ABC, Generic[AttackParamsT]):
    params: AttackParamsT
    ...

If we want to minimize the changes needed to implement this, we can add properties so existing call sites don't have to change immediately.
@dataclass
class AttackContext(StrategyContext, ABC, Generic[AttackParamsT]):
    params: AttackParamsT

    @property
    def objective(self) -> str:
        return self.params.objective

    @property
    def memory_labels(self) -> Dict[str, str]:
        return self.params.memory_labels or {}

Then we change AttackStrategy to expose the attack params type:
class AttackStrategy(Strategy[AttackStrategyContextT, AttackStrategyResultT], ABC):
    def __init__(
        self,
        *,
        objective_target: PromptTarget,
        context_type: type[AttackStrategyContextT],
        params_type: type[AttackParamsT],
        logger: logging.Logger = logger,
    ):
        super().__init__(
            context_type=context_type,
            event_handler=_DefaultAttackStrategyEventHandler[AttackStrategyContextT, AttackStrategyResultT](
                logger=logger
            ),
            logger=logger,
        )
        self._objective_target = objective_target
        self._params_type = params_type

If an attack needs extra params, it just extends AttackParameters, e.g.:
@dataclass(frozen=True)
class SingleTurnAttackParameters(AttackParameters):
    @classmethod
    def from_seed_group(cls, seed_group: SeedGroup, **overrides) -> "SingleTurnAttackParameters":
        # extraction logic specific to this attack
        return ...

@dataclass
class SingleTurnAttackContext(AttackContext[SingleTurnAttackParameters]):
    ...

class SingleTurnAttackStrategy(AttackStrategy[SingleTurnAttackContext, AttackResult], ABC):
    def __init__(
        self,
        *,
        objective_target: PromptTarget,
        context_type: type[SingleTurnAttackContext],
        params_type: type[SingleTurnAttackParameters] = SingleTurnAttackParameters,
        logger: logging.Logger = logger,
    ):
        super().__init__(
            objective_target=objective_target,
            context_type=context_type,
            params_type=params_type,
            logger=logger,
        )

This makes the executor code much more intuitive and decoupled:
class AttackExecutor:
    async def execute_attack_from_seed_groups_async(
        self,
        *,
        attack: AttackStrategy[AttackStrategyContextT, AttackStrategyResultT],
        seed_groups: Sequence[SeedGroup],
        field_overrides: Optional[Sequence[Dict[str, Any]]] = None,
        return_partial_on_failure: bool = False,
    ) -> AttackExecutorResult[AttackStrategyResultT]:
        if not seed_groups:
            raise ValueError("At least one seed_group must be provided")
        if field_overrides and len(field_overrides) != len(seed_groups):
            raise ValueError("field_overrides length must match seed_groups length")

        # Assumes AttackStrategy exposes params_type publicly (e.g. a property over _params_type).
        params_type = attack.params_type
        params_list: List[AttackParameters] = []
        for i, sg in enumerate(seed_groups):
            overrides = field_overrides[i] if field_overrides else {}
            # Each params subclass handles its own extraction from the seed group.
            params = params_type.from_seed_group(sg, **overrides)
            params_list.append(params)

        return await self._execute_with_params_list_async(
            attack=attack,
            params_list=params_list,
            return_partial_on_failure=return_partial_on_failure,
        )

    async def _execute_with_params_list_async(
        self,
        *,
        attack: AttackStrategy,
        params_list: Sequence[AttackParameters],
        return_partial_on_failure: bool = False,
    ) -> AttackExecutorResult:
        # Cap the number of attacks running at once.
        semaphore = asyncio.Semaphore(self._max_concurrency)

        async def run_one(params: AttackParameters):
            async with semaphore:
                context = attack._context_type(params=params)
                return await attack.execute_with_context_async(context=context)

        tasks = [run_one(p) for p in params_list]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        return self._process_execution_results(
            objectives=[p.objective for p in params_list],
            results_or_exceptions=results,
            return_partial_on_failure=return_partial_on_failure,
        )

…ithub.com/rlundeen2/PyRIT into users/rlundeen/2025_12_19_executor_update
This PR updates AttackExecutor to call attacks more generally. It also standardizes ways to translate from SeedGroups to attack parameters.
- Added AttackExecutor.execute_attack_async and AttackExecutor.execute_attack_from_seed_groups_async, which take attack parameters and call attack strategies generally, passing the parameters that work for the attack used.
- Added AttackParameters to strategies so the AttackExecutor can know which parameters the strategy expects and how to deserialize them from SeedGroups.
- Updated SeedGroup to add convenience properties that reference the pieces needed for an attack.
- Updated AtomicAttack to use the new paradigm, and fixed bugs where pieces were called incorrectly.

Tests:
One of my recent PRs broke two scenarios, but I didn't catch it until the e2e tests. I added applicable unit tests and ran the e2e tests.
On commit 8366277: all e2e tests pass
There are integration test failures related to ScorerIdentity, but I think all relevant tests are working.