diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/filter.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/filter.py index d3815a467..e01c460bb 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/filter.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/filter.py @@ -57,7 +57,7 @@ class FilterException(AgentToolkitException): def _get_collection( - request: Union[RequestCollection, RequestRelationCollection] + request: Union[RequestCollection, RequestRelationCollection], ) -> Union[CollectionCustomizer, Collection]: collection = request.collection if isinstance(request, RequestRelationCollection): diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/security/exceptions.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/security/exceptions.py index 42963e1f8..1741a19e2 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/security/exceptions.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/security/exceptions.py @@ -6,8 +6,9 @@ class AuthenticationException(AgentToolkitException): class OpenIdException(AuthenticationException): - def __init__(self, message: str, error: str, error_description: str, state: str) -> None: + def __init__(self, message: str, error: str, error_description: str, state: str, status: int = 401) -> None: super().__init__(message) self.error = error self.error_description = error_description self.state = state + self.STATUS = status diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/utils/http.py b/src/agent_toolkit/forestadmin/agent_toolkit/utils/http.py index 51a454f80..097a1212d 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/utils/http.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/utils/http.py @@ -5,6 +5,7 @@ from aiohttp.web import HTTPException from forestadmin.agent_toolkit.exceptions import AgentToolkitException from forestadmin.agent_toolkit.forest_logger import ForestLogger +from forestadmin.agent_toolkit.resources.security.exceptions import OpenIdException from forestadmin.agent_toolkit.utils.forest_schema.type import ForestSchema @@ -94,6 +95,7 @@ async def post( def _parse_forest_response(error: HTTPException): status = error.status server_message = None + name = None response_content = {} if error.text is not None and len(error.text) > 0: try: @@ -101,11 +103,12 @@ def _parse_forest_response(error: HTTPException): errors = response_content.get("errors", []) if len(errors) > 0: status = errors[0].get("status", status) + name = errors[0].get("name") server_message = errors[0].get("detail") except Exception: pass - return status, response_content, server_message + return status, name, response_content, server_message @staticmethod async def _handle_server_error(endpoint: str, error: Exception) -> Exception: @@ -117,9 +120,13 @@ async def _handle_server_error(endpoint: str, error: Exception) -> Exception: ) elif isinstance(error, HTTPException): - status, response_content, server_message = ForestHttpApi._parse_forest_response(error) + status, name, response_content, server_message = ForestHttpApi._parse_forest_response(error) if status in [-1, 0, 502]: new_error = ForestHttpApiException("Failed to reach ForestAdmin server. Are you online?") + elif status == 403: + new_error = OpenIdException( + response_content, name, server_message, status, status=status #  type:ignore + ) elif status == 404: new_error = ForestHttpApiException( "ForestAdmin server failed to find the project related to the envSecret you configured." diff --git a/src/agent_toolkit/tests/utils/test_http.py b/src/agent_toolkit/tests/utils/test_http.py index 95c06ce9e..c92a1f083 100644 --- a/src/agent_toolkit/tests/utils/test_http.py +++ b/src/agent_toolkit/tests/utils/test_http.py @@ -6,6 +6,7 @@ import aiohttp from aiohttp import client_exceptions from aiohttp.web import HTTPException +from forestadmin.agent_toolkit.resources.security.exceptions import OpenIdException from forestadmin.agent_toolkit.utils.http import ForestHttpApi, ForestHttpApiException, HttpOptions @@ -350,3 +351,32 @@ def test_handle_error_should_handle_not_http_errors(self): error_mock, ), ) + + def test_handle_error_should_wrap_openid_error_for_2fa_error(self): + error_mock = Mock(HTTPException) + error_mock.status = 403 + error_mock.text = ( + '{"errors":[{"status":403,"detail":"Two factor authentication is required to access this ' + + 'project","name":"TwoFactorAuthenticationRequiredForbiddenError"}]}' + ) + try: + self.loop.run_until_complete( + ForestHttpApi._handle_server_error( + "http://endpoint.fr", + error_mock, + ) + ) + except OpenIdException as exc: + self.assertEqual(exc.STATUS, 403) + self.assertEqual(exc.state, 403) + self.assertEqual( + exc.args[0], + ( + "🌳🌳🌳{'errors': [{'status': 403, 'detail': 'Two factor authentication is required to access this " + + "project', 'name': 'TwoFactorAuthenticationRequiredForbiddenError'}]}" + ), + ) + self.assertEqual(exc.error, "TwoFactorAuthenticationRequiredForbiddenError") + self.assertEqual(exc.error_description, "Two factor authentication is required to access this project") + else: + raise Exception("should have been in except block")