Description
Is your feature request related to a problem? Please describe.
Since we now support auth token providers via the api_key argument, just like the openai SDK, it is somewhat unintuitive that synchronous ones don't work. A cryptic error is shown if you try one.
E.g., people may try the following with any OpenAI target:
api_key=get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")
which for me yields:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
File /workspace/pyrit/prompt_normalizer/prompt_normalizer.py:95, in PromptNormalizer.send_prompt_async(self, seed_group, target, conversation_id, request_converter_configurations, response_converter_configurations, labels, attack_identifier)
94 try:
---> 95 responses = await target.send_prompt_async(message=request)
96 self._memory.add_message_to_memory(request=request)
File /workspace/pyrit/prompt_target/common/utils.py:56, in limit_requests_per_minute.<locals>.set_max_rpm(*args, **kwargs)
54 await asyncio.sleep(60 / rpm)
---> 56 return await func(*args, **kwargs)
File ~/.conda/envs/pyrit-release-test1/lib/python3.12/site-packages/tenacity/asyncio/__init__.py:189, in AsyncRetrying.wraps.<locals>.async_wrapped(*args, **kwargs)
188 async_wrapped.statistics = copy.statistics # type: ignore[attr-defined]
--> 189 return await copy(fn, *args, **kwargs)
File ~/.conda/envs/pyrit-release-test1/lib/python3.12/site-packages/tenacity/asyncio/__init__.py:111, in AsyncRetrying.__call__(self, fn, *args, **kwargs)
110 while True:
--> 111 do = await self.iter(retry_state=retry_state)
112 if isinstance(do, DoAttempt):
File ~/.conda/envs/pyrit-release-test1/lib/python3.12/site-packages/tenacity/asyncio/__init__.py:153, in AsyncRetrying.iter(self, retry_state)
152 for action in self.iter_state.actions:
--> 153 result = await action(retry_state)
154 return result
File ~/.conda/envs/pyrit-release-test1/lib/python3.12/site-packages/tenacity/_utils.py:99, in wrap_to_async_func.<locals>.inner(*args, **kwargs)
98 async def inner(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
---> 99 return call(*args, **kwargs)
File ~/.conda/envs/pyrit-release-test1/lib/python3.12/site-packages/tenacity/__init__.py:400, in BaseRetrying._post_retry_check_actions.<locals>.<lambda>(rs)
399 if not (self.iter_state.is_explicit_retry or self.iter_state.retry_run_result):
--> 400 self._add_action_func(lambda rs: rs.outcome.result())
401 return
File ~/.conda/envs/pyrit-release-test1/lib/python3.12/concurrent/futures/_base.py:449, in Future.result(self, timeout)
448 elif self._state == FINISHED:
--> 449 return self.__get_result()
451 self._condition.wait(timeout)
File ~/.conda/envs/pyrit-release-test1/lib/python3.12/concurrent/futures/_base.py:401, in Future.__get_result(self)
400 try:
--> 401 raise self._exception
402 finally:
403 # Break a reference cycle with the exception in self._exception
File ~/.conda/envs/pyrit-release-test1/lib/python3.12/site-packages/tenacity/asyncio/__init__.py:114, in AsyncRetrying.__call__(self, fn, *args, **kwargs)
113 try:
--> 114 result = await fn(*args, **kwargs)
115 except BaseException: # noqa: B902
File /workspace/pyrit/prompt_target/openai/openai_chat_target.py:197, in OpenAIChatTarget.send_prompt_async(self, message)
196 # Use unified error handling - automatically detects ChatCompletion and validates
--> 197 response = await self._handle_openai_request(
198 api_call=lambda: self._async_client.chat.completions.create(**body),
199 request=message,
200 )
201 return [response]
File /workspace/pyrit/prompt_target/openai/openai_target.py:372, in OpenAITarget._handle_openai_request(self, api_call, request)
370 try:
371 # Execute the API call
--> 372 response = await api_call()
374 # Extract MessagePiece for validation and construction (most targets use single piece)
File ~/.conda/envs/pyrit-release-test1/lib/python3.12/site-packages/openai/resources/chat/completions/completions.py:2678, in AsyncCompletions.create(self, messages, model, audio, frequency_penalty, function_call, functions, logit_bias, logprobs, max_completion_tokens, max_tokens, metadata, modalities, n, parallel_tool_calls, prediction, presence_penalty, prompt_cache_key, prompt_cache_retention, reasoning_effort, response_format, safety_identifier, seed, service_tier, stop, store, stream, stream_options, temperature, tool_choice, tools, top_logprobs, top_p, user, verbosity, web_search_options, extra_headers, extra_query, extra_body, timeout)
2677 validate_response_format(response_format)
-> 2678 return await self._post(
2679 "/chat/completions",
2680 body=await async_maybe_transform(
2681 {
2682 "messages": messages,
2683 "model": model,
2684 "audio": audio,
2685 "frequency_penalty": frequency_penalty,
2686 "function_call": function_call,
2687 "functions": functions,
2688 "logit_bias": logit_bias,
2689 "logprobs": logprobs,
2690 "max_completion_tokens": max_completion_tokens,
2691 "max_tokens": max_tokens,
2692 "metadata": metadata,
2693 "modalities": modalities,
2694 "n": n,
2695 "parallel_tool_calls": parallel_tool_calls,
2696 "prediction": prediction,
2697 "presence_penalty": presence_penalty,
2698 "prompt_cache_key": prompt_cache_key,
2699 "prompt_cache_retention": prompt_cache_retention,
2700 "reasoning_effort": reasoning_effort,
2701 "response_format": response_format,
2702 "safety_identifier": safety_identifier,
2703 "seed": seed,
2704 "service_tier": service_tier,
2705 "stop": stop,
2706 "store": store,
2707 "stream": stream,
2708 "stream_options": stream_options,
2709 "temperature": temperature,
2710 "tool_choice": tool_choice,
2711 "tools": tools,
2712 "top_logprobs": top_logprobs,
2713 "top_p": top_p,
2714 "user": user,
2715 "verbosity": verbosity,
2716 "web_search_options": web_search_options,
2717 },
2718 completion_create_params.CompletionCreateParamsStreaming
2719 if stream
2720 else completion_create_params.CompletionCreateParamsNonStreaming,
2721 ),
2722 options=make_request_options(
2723 extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
2724 ),
2725 cast_to=ChatCompletion,
2726 stream=stream or False,
2727 stream_cls=AsyncStream[ChatCompletionChunk],
2728 )
File ~/.conda/envs/pyrit-release-test1/lib/python3.12/site-packages/openai/_base_client.py:1794, in AsyncAPIClient.post(self, path, cast_to, body, files, options, stream, stream_cls)
1791 opts = FinalRequestOptions.construct(
1792 method="post", url=path, json_data=body, files=await async_to_httpx_files(files), **options
1793 )
-> 1794 return await self.request(cast_to, opts, stream=stream, stream_cls=stream_cls)
File ~/.conda/envs/pyrit-release-test1/lib/python3.12/site-packages/openai/_base_client.py:1512, in AsyncAPIClient.request(self, cast_to, options, stream, stream_cls)
1511 options = model_copy(input_options)
-> 1512 options = await self._prepare_options(options)
1514 remaining_retries = max_retries - retries_taken
File ~/.conda/envs/pyrit-release-test1/lib/python3.12/site-packages/openai/_client.py:669, in AsyncOpenAI._prepare_options(self, options)
667 @override
668 async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
--> 669 await self._refresh_api_key()
670 return await super()._prepare_options(options)
File ~/.conda/envs/pyrit-release-test1/lib/python3.12/site-packages/openai/_client.py:665, in AsyncOpenAI._refresh_api_key(self)
664 if self._api_key_provider:
--> 665 self.api_key = await self._api_key_provider()
TypeError: object str can't be used in 'await' expression
The above exception was the direct cause of the following exception:
Exception Traceback (most recent call last)
File /workspace/pyrit/executor/core/strategy.py:345, in Strategy.execute_with_context_async(self, context)
344 await self._handle_event(event=StrategyEvent.ON_PRE_EXECUTE, context=context)
--> 345 result = await self._perform_async(context=context)
346 await self._handle_event(event=StrategyEvent.ON_POST_EXECUTE, context=context, result=result)
File /workspace/pyrit/executor/attack/single_turn/prompt_sending.py:187, in PromptSendingAttack._perform_async(self, context)
186 # Send the prompt
--> 187 response = await self._send_prompt_to_objective_target_async(prompt_group=prompt_group, context=context)
188 if not response:
File /workspace/pyrit/executor/attack/single_turn/prompt_sending.py:300, in PromptSendingAttack._send_prompt_to_objective_target_async(self, prompt_group, context)
289 """
290 Send the prompt to the target and return the response.
291
(...) 298 the request was filtered, blocked, or encountered an error.
299 """
--> 300 return await self._prompt_normalizer.send_prompt_async(
301 seed_group=prompt_group,
302 target=self._objective_target,
303 conversation_id=context.conversation_id,
304 request_converter_configurations=self._request_converters,
305 response_converter_configurations=self._response_converters,
306 labels=context.memory_labels, # combined with strategy labels at _setup()
307 attack_identifier=self.get_identifier(),
308 )
File /workspace/pyrit/prompt_normalizer/prompt_normalizer.py:124, in PromptNormalizer.send_prompt_async(self, seed_group, target, conversation_id, request_converter_configurations, response_converter_configurations, labels, attack_identifier)
123 cid = request.message_pieces[0].conversation_id if request and request.message_pieces else None
--> 124 raise Exception(f"Error sending prompt with conversation ID: {cid}") from ex
126 # handling empty responses message list and None responses
Exception: Error sending prompt with conversation ID: 64477327-d2e6-4d1c-905b-b02627012c5e
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
Cell In[4], line 26
18 target = OpenAIChatTarget(
19 endpoint=endpoint,
20 api_key=api_key,
21 model_name="<omitted>"
22 )
24 attack = PromptSendingAttack(objective_target=target)
---> 26 result = await attack.execute_async(objective=jailbreak_prompt) # type: ignore
27 await ConsoleAttackResultPrinter().print_conversation_async(result=result) # type: ignore
File /workspace/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py:112, in SingleTurnAttackStrategy.execute_async(self, **kwargs)
104 raise ValueError(
105 "Attack can only specify one objective per turn. Objective parameter '%s' and seed"
106 " prompt group objective '%s' are both defined",
107 objective,
108 seed_group.objective.value,
109 )
111 system_prompt = get_kwarg_param(kwargs=kwargs, param_name="system_prompt", expected_type=str, required=False)
--> 112 return await super().execute_async(
113 **kwargs, seed_group=seed_group, system_prompt=system_prompt, objective=objective
114 )
File /workspace/pyrit/executor/attack/core/attack_strategy.py:262, in AttackStrategy.execute_async(self, **kwargs)
257 prepended_conversation = get_kwarg_param(
258 kwargs=kwargs, param_name="prepended_conversation", expected_type=list, required=False
259 )
260 kwargs["prepended_conversation"] = prepended_conversation
--> 262 return await super().execute_async(**kwargs, objective=objective, memory_labels=memory_labels)
File /workspace/pyrit/executor/core/strategy.py:362, in Strategy.execute_async(self, **kwargs)
355 """
356 Execute the strategy asynchronously with the given keyword arguments.
357
358 Returns:
359 StrategyResultT: The result of the strategy execution.
360 """
361 context = self._context_type(**kwargs)
--> 362 return await self.execute_with_context_async(context=context)
File /workspace/pyrit/executor/core/strategy.py:352, in Strategy.execute_with_context_async(self, context)
350 await self._handle_event(event=StrategyEvent.ON_ERROR, context=context, error=e)
351 # Raise a specific execution error
--> 352 raise RuntimeError(f"Strategy execution failed for {self.__class__.__name__}: {str(e)}") from e
RuntimeError: Strategy execution failed for PromptSendingAttack: Error sending prompt with conversation ID: 64477327-d2e6-4d1c-905b-b02627012c5e
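The root cause is visible in the last frame of the original traceback: AsyncOpenAI._refresh_api_key awaits whatever the api_key provider returns, so a synchronous provider that returns a plain str triggers the TypeError. A minimal sketch of that failure mode (provider name hypothetical):
```python
import asyncio


def sync_token_provider() -> str:
    # stand-in for the sync callable returned by get_bearer_token_provider(...)
    return "fake-token"


async def refresh_api_key() -> str:
    # mirrors AsyncOpenAI._refresh_api_key: `await self._api_key_provider()`
    return await sync_token_provider()  # TypeError: object str can't be used in 'await' expression


asyncio.run(refresh_api_key())
```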
The correct way would be to use get_azure_openai_auth(endpoint_url), which calls get_async_bearer_token_provider(AsyncDefaultAzureCredential(), scope) under the hood, i.e., the async version of the token provider.
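For reference, a sketch of the working setup built directly from azure.identity's async helpers rather than the get_azure_openai_auth convenience wrapper; the endpoint is a placeholder:
```python
from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider

from pyrit.prompt_target import OpenAIChatTarget

api_key = get_bearer_token_provider(
    DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
)  # async callable, so AsyncOpenAI can await it during _refresh_api_key

target = OpenAIChatTarget(
    endpoint="https://<resource>.openai.azure.com/",  # placeholder
    api_key=api_key,
    model_name="<omitted>",
)
```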
Describe the solution you'd like
It would be nice if we could check whether the token provider is async and warn users if it isn't. Alternatively, if that's not possible or too complicated, we could catch the exception from the openai SDK and add some context for users. In any case, all the docstrings should state that only async token providers are accepted. Taking a step back, perhaps it's possible to detect synchronous providers, warn users that we expect async ones, and then wrap them in an async token provider class?
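A minimal sketch of that last option; ensure_async_token_provider is a hypothetical helper name, and the extra __call__ check is there because inspect.iscoroutinefunction does not see through callable objects whose __call__ is async:
```python
import asyncio
import inspect
import warnings
from typing import Awaitable, Callable, Union

TokenProvider = Union[Callable[[], str], Callable[[], Awaitable[str]]]


def ensure_async_token_provider(provider: TokenProvider) -> Callable[[], Awaitable[str]]:
    # hypothetical helper: accept either flavor of provider, normalize to async
    if inspect.iscoroutinefunction(provider) or inspect.iscoroutinefunction(
        getattr(provider, "__call__", None)
    ):
        return provider  # already async; pass through unchanged

    warnings.warn(
        "api_key received a synchronous token provider; an async one is expected. "
        "Wrapping it so each token refresh runs in a worker thread.",
        UserWarning,
    )

    async def async_wrapper() -> str:
        # run the sync provider off the event loop so token refreshes don't block it
        return await asyncio.to_thread(provider)

    return async_wrapper
```
The OpenAI targets could call such a helper once in __init__ so the warning fires at construction time rather than on the first request.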