BUG Entra auth with synchronous token providers #1252

@romanlutz

Description

Is your feature request related to a problem? Please describe.

Since we now support auth token providers via the api_key argument, just like the OpenAI SDK, it is somewhat unintuitive that synchronous providers won't work. A cryptic error is shown if you try one.

E.g., people may try the following with any OpenAI target:

from azure.identity import DefaultAzureCredential, get_bearer_token_provider

api_key=get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")

which for me yields:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File /workspace/pyrit/prompt_normalizer/prompt_normalizer.py:95, in PromptNormalizer.send_prompt_async(self, seed_group, target, conversation_id, request_converter_configurations, response_converter_configurations, labels, attack_identifier)
     94 try:
---> 95     responses = await target.send_prompt_async(message=request)
     96     self._memory.add_message_to_memory(request=request)

File /workspace/pyrit/prompt_target/common/utils.py:56, in limit_requests_per_minute.<locals>.set_max_rpm(*args, **kwargs)
     54     await asyncio.sleep(60 / rpm)
---> 56 return await func(*args, **kwargs)

File ~/.conda/envs/pyrit-release-test1/lib/python3.12/site-packages/tenacity/asyncio/__init__.py:189, in AsyncRetrying.wraps.<locals>.async_wrapped(*args, **kwargs)
    188 async_wrapped.statistics = copy.statistics  # type: ignore[attr-defined]
--> 189 return await copy(fn, *args, **kwargs)

File ~/.conda/envs/pyrit-release-test1/lib/python3.12/site-packages/tenacity/asyncio/__init__.py:111, in AsyncRetrying.__call__(self, fn, *args, **kwargs)
    110 while True:
--> 111     do = await self.iter(retry_state=retry_state)
    112     if isinstance(do, DoAttempt):

File ~/.conda/envs/pyrit-release-test1/lib/python3.12/site-packages/tenacity/asyncio/__init__.py:153, in AsyncRetrying.iter(self, retry_state)
    152 for action in self.iter_state.actions:
--> 153     result = await action(retry_state)
    154 return result

File ~/.conda/envs/pyrit-release-test1/lib/python3.12/site-packages/tenacity/_utils.py:99, in wrap_to_async_func.<locals>.inner(*args, **kwargs)
     98 async def inner(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
---> 99     return call(*args, **kwargs)

File ~/.conda/envs/pyrit-release-test1/lib/python3.12/site-packages/tenacity/__init__.py:400, in BaseRetrying._post_retry_check_actions.<locals>.<lambda>(rs)
    399 if not (self.iter_state.is_explicit_retry or self.iter_state.retry_run_result):
--> 400     self._add_action_func(lambda rs: rs.outcome.result())
    401     return

File ~/.conda/envs/pyrit-release-test1/lib/python3.12/concurrent/futures/_base.py:449, in Future.result(self, timeout)
    448 elif self._state == FINISHED:
--> 449     return self.__get_result()
    451 self._condition.wait(timeout)

File ~/.conda/envs/pyrit-release-test1/lib/python3.12/concurrent/futures/_base.py:401, in Future.__get_result(self)
    400 try:
--> 401     raise self._exception
    402 finally:
    403     # Break a reference cycle with the exception in self._exception

File ~/.conda/envs/pyrit-release-test1/lib/python3.12/site-packages/tenacity/asyncio/__init__.py:114, in AsyncRetrying.__call__(self, fn, *args, **kwargs)
    113 try:
--> 114     result = await fn(*args, **kwargs)
    115 except BaseException:  # noqa: B902

File /workspace/pyrit/prompt_target/openai/openai_chat_target.py:197, in OpenAIChatTarget.send_prompt_async(self, message)
    196 # Use unified error handling - automatically detects ChatCompletion and validates
--> 197 response = await self._handle_openai_request(
    198     api_call=lambda: self._async_client.chat.completions.create(**body),
    199     request=message,
    200 )
    201 return [response]

File /workspace/pyrit/prompt_target/openai/openai_target.py:372, in OpenAITarget._handle_openai_request(self, api_call, request)
    370 try:
    371     # Execute the API call
--> 372     response = await api_call()
    374     # Extract MessagePiece for validation and construction (most targets use single piece)

File ~/.conda/envs/pyrit-release-test1/lib/python3.12/site-packages/openai/resources/chat/completions/completions.py:2678, in AsyncCompletions.create(self, messages, model, audio, frequency_penalty, function_call, functions, logit_bias, logprobs, max_completion_tokens, max_tokens, metadata, modalities, n, parallel_tool_calls, prediction, presence_penalty, prompt_cache_key, prompt_cache_retention, reasoning_effort, response_format, safety_identifier, seed, service_tier, stop, store, stream, stream_options, temperature, tool_choice, tools, top_logprobs, top_p, user, verbosity, web_search_options, extra_headers, extra_query, extra_body, timeout)
   2677 validate_response_format(response_format)
-> 2678 return await self._post(
   2679     "/chat/completions",
   2680     body=await async_maybe_transform(
   2681         {
   2682             "messages": messages,
   2683             "model": model,
   2684             "audio": audio,
   2685             "frequency_penalty": frequency_penalty,
   2686             "function_call": function_call,
   2687             "functions": functions,
   2688             "logit_bias": logit_bias,
   2689             "logprobs": logprobs,
   2690             "max_completion_tokens": max_completion_tokens,
   2691             "max_tokens": max_tokens,
   2692             "metadata": metadata,
   2693             "modalities": modalities,
   2694             "n": n,
   2695             "parallel_tool_calls": parallel_tool_calls,
   2696             "prediction": prediction,
   2697             "presence_penalty": presence_penalty,
   2698             "prompt_cache_key": prompt_cache_key,
   2699             "prompt_cache_retention": prompt_cache_retention,
   2700             "reasoning_effort": reasoning_effort,
   2701             "response_format": response_format,
   2702             "safety_identifier": safety_identifier,
   2703             "seed": seed,
   2704             "service_tier": service_tier,
   2705             "stop": stop,
   2706             "store": store,
   2707             "stream": stream,
   2708             "stream_options": stream_options,
   2709             "temperature": temperature,
   2710             "tool_choice": tool_choice,
   2711             "tools": tools,
   2712             "top_logprobs": top_logprobs,
   2713             "top_p": top_p,
   2714             "user": user,
   2715             "verbosity": verbosity,
   2716             "web_search_options": web_search_options,
   2717         },
   2718         completion_create_params.CompletionCreateParamsStreaming
   2719         if stream
   2720         else completion_create_params.CompletionCreateParamsNonStreaming,
   2721     ),
   2722     options=make_request_options(
   2723         extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
   2724     ),
   2725     cast_to=ChatCompletion,
   2726     stream=stream or False,
   2727     stream_cls=AsyncStream[ChatCompletionChunk],
   2728 )

File ~/.conda/envs/pyrit-release-test1/lib/python3.12/site-packages/openai/_base_client.py:1794, in AsyncAPIClient.post(self, path, cast_to, body, files, options, stream, stream_cls)
   1791 opts = FinalRequestOptions.construct(
   1792     method="post", url=path, json_data=body, files=await async_to_httpx_files(files), **options
   1793 )
-> 1794 return await self.request(cast_to, opts, stream=stream, stream_cls=stream_cls)

File ~/.conda/envs/pyrit-release-test1/lib/python3.12/site-packages/openai/_base_client.py:1512, in AsyncAPIClient.request(self, cast_to, options, stream, stream_cls)
   1511 options = model_copy(input_options)
-> 1512 options = await self._prepare_options(options)
   1514 remaining_retries = max_retries - retries_taken

File ~/.conda/envs/pyrit-release-test1/lib/python3.12/site-packages/openai/_client.py:669, in AsyncOpenAI._prepare_options(self, options)
    667 @override
    668 async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
--> 669     await self._refresh_api_key()
    670     return await super()._prepare_options(options)

File ~/.conda/envs/pyrit-release-test1/lib/python3.12/site-packages/openai/_client.py:665, in AsyncOpenAI._refresh_api_key(self)
    664 if self._api_key_provider:
--> 665     self.api_key = await self._api_key_provider()

TypeError: object str can't be used in 'await' expression

The above exception was the direct cause of the following exception:

Exception                                 Traceback (most recent call last)
File /workspace/pyrit/executor/core/strategy.py:345, in Strategy.execute_with_context_async(self, context)
    344 await self._handle_event(event=StrategyEvent.ON_PRE_EXECUTE, context=context)
--> 345 result = await self._perform_async(context=context)
    346 await self._handle_event(event=StrategyEvent.ON_POST_EXECUTE, context=context, result=result)

File /workspace/pyrit/executor/attack/single_turn/prompt_sending.py:187, in PromptSendingAttack._perform_async(self, context)
    186 # Send the prompt
--> 187 response = await self._send_prompt_to_objective_target_async(prompt_group=prompt_group, context=context)
    188 if not response:

File /workspace/pyrit/executor/attack/single_turn/prompt_sending.py:300, in PromptSendingAttack._send_prompt_to_objective_target_async(self, prompt_group, context)
    289 """
    290 Send the prompt to the target and return the response.
    291 
   (...)    298         the request was filtered, blocked, or encountered an error.
    299 """
--> 300 return await self._prompt_normalizer.send_prompt_async(
    301     seed_group=prompt_group,
    302     target=self._objective_target,
    303     conversation_id=context.conversation_id,
    304     request_converter_configurations=self._request_converters,
    305     response_converter_configurations=self._response_converters,
    306     labels=context.memory_labels,  # combined with strategy labels at _setup()
    307     attack_identifier=self.get_identifier(),
    308 )

File /workspace/pyrit/prompt_normalizer/prompt_normalizer.py:124, in PromptNormalizer.send_prompt_async(self, seed_group, target, conversation_id, request_converter_configurations, response_converter_configurations, labels, attack_identifier)
    123     cid = request.message_pieces[0].conversation_id if request and request.message_pieces else None
--> 124     raise Exception(f"Error sending prompt with conversation ID: {cid}") from ex
    126 # handling empty responses message list and None responses

Exception: Error sending prompt with conversation ID: 64477327-d2e6-4d1c-905b-b02627012c5e

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
Cell In[4], line 26
     18 target = OpenAIChatTarget(
     19     endpoint=endpoint,
     20     api_key=api_key,
     21     model_name="<omitted>"
     22 )
     24 attack = PromptSendingAttack(objective_target=target)
---> 26 result = await attack.execute_async(objective=jailbreak_prompt)  # type: ignore
     27 await ConsoleAttackResultPrinter().print_conversation_async(result=result)  # type: ignore

File /workspace/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py:112, in SingleTurnAttackStrategy.execute_async(self, **kwargs)
    104     raise ValueError(
    105         "Attack can only specify one objective per turn. Objective parameter '%s' and seed"
    106         " prompt group objective '%s' are both defined",
    107         objective,
    108         seed_group.objective.value,
    109     )
    111 system_prompt = get_kwarg_param(kwargs=kwargs, param_name="system_prompt", expected_type=str, required=False)
--> 112 return await super().execute_async(
    113     **kwargs, seed_group=seed_group, system_prompt=system_prompt, objective=objective
    114 )

File /workspace/pyrit/executor/attack/core/attack_strategy.py:262, in AttackStrategy.execute_async(self, **kwargs)
    257     prepended_conversation = get_kwarg_param(
    258         kwargs=kwargs, param_name="prepended_conversation", expected_type=list, required=False
    259     )
    260     kwargs["prepended_conversation"] = prepended_conversation
--> 262 return await super().execute_async(**kwargs, objective=objective, memory_labels=memory_labels)

File /workspace/pyrit/executor/core/strategy.py:362, in Strategy.execute_async(self, **kwargs)
    355 """
    356 Execute the strategy asynchronously with the given keyword arguments.
    357 
    358 Returns:
    359     StrategyResultT: The result of the strategy execution.
    360 """
    361 context = self._context_type(**kwargs)
--> 362 return await self.execute_with_context_async(context=context)

File /workspace/pyrit/executor/core/strategy.py:352, in Strategy.execute_with_context_async(self, context)
    350 await self._handle_event(event=StrategyEvent.ON_ERROR, context=context, error=e)
    351 # Raise a specific execution error
--> 352 raise RuntimeError(f"Strategy execution failed for {self.__class__.__name__}: {str(e)}") from e

RuntimeError: Strategy execution failed for PromptSendingAttack: Error sending prompt with conversation ID: 64477327-d2e6-4d1c-905b-b02627012c5e

The root cause is visible at the bottom of the first trace: AsyncOpenAI awaits the value returned by the provider (self.api_key = await self._api_key_provider()), and a synchronous provider returns a plain str, which cannot be awaited. The correct way would be to use get_azure_openai_auth(endpoint_url), which calls get_async_bearer_token_provider(AsyncDefaultAzureCredential(), scope) under the hood, i.e., the async version of the token provider.
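
For reference, a minimal sketch of the working async setup, assuming azure-identity's aio module (where get_bearer_token_provider returns an async callable when given an async credential); the endpoint value is a placeholder carried over from the repro cell above:

from pyrit.prompt_target import OpenAIChatTarget
from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider

# Async credential + async provider: the returned callable can be awaited
# by the OpenAI SDK's async client when it refreshes the key.
token_provider = get_bearer_token_provider(
    DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
)

target = OpenAIChatTarget(
    endpoint=endpoint,  # placeholder, as in the repro cell above
    api_key=token_provider,
    model_name="<omitted>",
)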

Describe the solution you'd like

It would be nice if we could check whether the token provider is async and warn users if it isn't. Alternatively, if that's not possible or too complicated, we could catch the exception from the OpenAI SDK and add some context for users. In any case, all docstrings should state that only async token providers are accepted. Taking a step back, perhaps it's even possible to detect a synchronous provider, warn the user that we expect an async one, and then wrap it in an async token provider? A sketch of that idea follows.
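
A minimal sketch of the detect-and-wrap idea, assuming the provider is a plain callable; ensure_async_token_provider is a hypothetical helper name, and inspect.iscoroutinefunction can miss some callables (e.g., partials or objects with an async __call__), so a real implementation would need broader checks:

import asyncio
import inspect
import warnings
from typing import Awaitable, Callable, Union

def ensure_async_token_provider(
    provider: Union[Callable[[], str], Callable[[], Awaitable[str]]],
) -> Callable[[], Awaitable[str]]:
    # Already async: pass through unchanged.
    if inspect.iscoroutinefunction(provider):
        return provider

    warnings.warn(
        "Received a synchronous token provider; an async one is expected. "
        "Wrapping it so the token is fetched in a worker thread."
    )

    async def async_provider() -> str:
        # Run the blocking token fetch off the event loop so it doesn't stall it.
        return await asyncio.to_thread(provider)

    return async_provider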
