diff --git a/datajunction-server/datajunction_server/api/cubes.py b/datajunction-server/datajunction_server/api/cubes.py index c565c6c37..339b82f71 100644 --- a/datajunction-server/datajunction_server/api/cubes.py +++ b/datajunction-server/datajunction_server/api/cubes.py @@ -14,13 +14,15 @@ from datajunction_server.database.user import User from datajunction_server.errors import DJInvalidInputException from datajunction_server.internal.access.authentication.http import SecureAPIRouter -from datajunction_server.internal.access.authorization import validate_access +from datajunction_server.internal.access.authorization import ( + AccessChecker, + get_access_checker, +) from datajunction_server.internal.materializations import build_cube_materialization from datajunction_server.internal.nodes import ( get_all_cube_revisions_metadata, get_single_cube_revision_metadata, ) -from datajunction_server.models import access from datajunction_server.models.cube import ( CubeRevisionMetadata, DimensionValue, @@ -208,9 +210,7 @@ async def get_cube_dimension_sql( include_counts: bool = False, session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), ) -> TranslatedSQL: """ Generates SQL to retrieve all unique values of a dimension for the cube @@ -222,7 +222,7 @@ async def get_cube_dimension_sql( node_revision, dimensions, current_user, - validate_access, + access_checker, filters, limit, include_counts, @@ -251,9 +251,7 @@ async def get_cube_dimension_values( request: Request, query_service_client: QueryServiceClient = Depends(get_query_service_client), current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), ) -> DimensionValues: """ All unique values of a dimension from the cube @@ -266,7 +264,7 @@ async def get_cube_dimension_values( cube, dimensions, current_user, - validate_access, + access_checker, filters, limit, include_counts, diff --git a/datajunction-server/datajunction_server/api/data.py b/datajunction-server/datajunction_server/api/data.py index abe5c0629..089374ec3 100644 --- a/datajunction-server/datajunction_server/api/data.py +++ b/datajunction-server/datajunction_server/api/data.py @@ -29,8 +29,9 @@ ) from datajunction_server.internal.access.authentication.http import SecureAPIRouter from datajunction_server.internal.access.authorization import ( - validate_access, - validate_access_requests, + AccessChecker, + AccessDenialMode, + get_access_checker, ) from datajunction_server.internal.history import ActivityType, EntityType from datajunction_server.models import access @@ -66,9 +67,7 @@ async def add_availability_state( *, session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), save_history: Callable = Depends(get_save_history), ) -> JSONResponse: """ @@ -90,17 +89,13 @@ async def add_availability_state( # Source nodes require that any availability states set are for one of the defined tables node_revision = node.current # type: ignore - validate_access_requests( - validate_access, - current_user, - [ - access.ResourceRequest( - verb=access.ResourceAction.WRITE, - 
access_object=access.Resource.from_node(node_revision), - ), - ], - True, + access_checker.add_request( + access.ResourceRequest( + verb=access.ResourceAction.WRITE, + access_object=access.Resource.from_node(node_revision), + ), ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) if node.current.type == NodeType.SOURCE: # type: ignore if ( @@ -190,9 +185,7 @@ async def remove_availability_state( *, session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), save_history: Callable = Depends(get_save_history), ) -> JSONResponse: """ @@ -215,17 +208,13 @@ async def remove_availability_state( ), ) - validate_access_requests( - validate_access, - current_user, - [ - access.ResourceRequest( - verb=access.ResourceAction.WRITE, - access_object=access.Resource.from_node(node), - ), - ], - True, + access_checker.add_request( + access.ResourceRequest( + verb=access.ResourceAction.WRITE, + access_object=access.Resource.from_node(node.current), # type: ignore + ), ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) # Save the old availability state for history record old_availability = ( @@ -281,10 +270,6 @@ async def get_data( query_service_client: QueryServiceClient = Depends(get_query_service_client), engine_name: Optional[str] = None, engine_version: Optional[str] = None, - current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), background_tasks: BackgroundTasks, cache: Cache = Depends(get_cache), ) -> QueryWithResults: @@ -310,8 +295,6 @@ async def get_data( engine_version=engine_version, use_materialized=use_materialized, ignore_errors=ignore_errors, - current_user=current_user, - validate_access=validate_access, ), ) @@ -362,10 +345,6 @@ async def get_data_stream_for_node( query_service_client: QueryServiceClient = Depends(get_query_service_client), engine_name: Optional[str] = None, engine_version: Optional[str] = None, - current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), background_tasks: BackgroundTasks, cache: Cache = Depends(get_cache), ) -> QueryWithResults: @@ -401,8 +380,6 @@ async def get_data_stream_for_node( engine_version=engine_version, use_materialized=True, ignore_errors=False, - current_user=current_user, - validate_access=validate_access, ), ) query_create = QueryCreate( @@ -468,10 +445,6 @@ async def get_data_for_metrics( query_service_client: QueryServiceClient = Depends(get_query_service_client), engine_name: Optional[str] = None, engine_version: Optional[str] = None, - current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), cache: Cache = Depends(get_cache), background_tasks: BackgroundTasks, ) -> QueryWithResults: @@ -497,8 +470,6 @@ async def get_data_for_metrics( engine_version=engine_version, use_materialized=True, ignore_errors=False, - current_user=current_user, - validate_access=validate_access, ), ) node = cast( @@ -544,20 +515,12 @@ async def get_data_stream_for_metrics( engine_name: Optional[str] = None, engine_version: Optional[str] = None, current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), ) -> QueryWithResults: """ 
Return data for a set of metrics with dimensions and filters using server sent events """ request_headers = dict(request.headers) - access_control = access.AccessControlStore( - validate_access=validate_access, - user=current_user, - base_verb=access.ResourceAction.READ, - ) - translated_sql, engine, catalog = await build_sql_for_multiple_metrics( session, metrics, @@ -567,7 +530,7 @@ async def get_data_stream_for_metrics( limit, engine_name, engine_version, - access_control, + access_checker, ) query_create = QueryCreate( diff --git a/datajunction-server/datajunction_server/api/deployments.py b/datajunction-server/datajunction_server/api/deployments.py index 4ec7cc945..03b05bb89 100644 --- a/datajunction-server/datajunction_server/api/deployments.py +++ b/datajunction-server/datajunction_server/api/deployments.py @@ -24,7 +24,9 @@ from datajunction_server.internal.deployment.utils import DeploymentContext from datajunction_server.internal.access.authentication.http import SecureAPIRouter from datajunction_server.internal.access.authorization import ( - validate_access, + AccessChecker, + AccessDenialMode, + get_access_checker, ) from datajunction_server.models import access from datajunction_server.models.deployment import DeploymentStatus @@ -158,22 +160,30 @@ async def create_deployment( current_user: User = Depends(get_current_user), query_service_client: QueryServiceClient = Depends(get_query_service_client), cache: Cache = Depends(get_cache), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), ) -> DeploymentInfo: """ This endpoint takes a deployment specification (namespace, nodes, tags), topologically sorts and validates the deployable objects, and deploys the nodes in parallel where possible. It returns a summary of the deployment. 
""" + access_checker.add_request( + access.ResourceRequest( + verb=access.ResourceAction.WRITE, + access_object=access.Resource( + resource_type=access.ResourceType.NAMESPACE, + name=deployment_spec.namespace, + ), + ), + ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + deployment_id = await executor.submit( spec=deployment_spec, context=DeploymentContext( current_user=current_user, request=request, query_service_client=query_service_client, - validate_access=validate_access, background_tasks=background_tasks, cache=cache, ), diff --git a/datajunction-server/datajunction_server/api/dimensions.py b/datajunction-server/datajunction_server/api/dimensions.py index 2ebb7b5ca..2e8b72152 100644 --- a/datajunction-server/datajunction_server/api/dimensions.py +++ b/datajunction-server/datajunction_server/api/dimensions.py @@ -3,7 +3,7 @@ """ import logging -from typing import List, Optional +from typing import List, Optional, cast from fastapi import Depends, Query from sqlalchemy.ext.asyncio import AsyncSession @@ -11,11 +11,11 @@ from datajunction_server.models.node import NodeNameOutput from datajunction_server.api.helpers import get_node_by_name from datajunction_server.api.nodes import list_nodes -from datajunction_server.database.user import User +from datajunction_server.database.node import Node from datajunction_server.internal.access.authentication.http import SecureAPIRouter from datajunction_server.internal.access.authorization import ( - validate_access, - validate_access_requests, + AccessChecker, + get_access_checker, ) from datajunction_server.models import access from datajunction_server.models.node import NodeIndegreeOutput @@ -25,7 +25,6 @@ get_nodes_with_common_dimensions, ) from datajunction_server.utils import ( - get_current_user, get_session, get_settings, ) @@ -40,10 +39,7 @@ async def list_dimensions( prefix: Optional[str] = None, *, session: AsyncSession = Depends(get_session), - current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), ) -> List[NodeIndegreeOutput]: """ List all available dimensions. 
@@ -52,8 +48,7 @@ async def list_dimensions( node_type=NodeType.DIMENSION, prefix=prefix, session=session, - current_user=current_user, - validate_access=validate_access, + access_checker=access_checker, ) node_indegrees = await get_dimension_dag_indegree(session, node_names) return sorted( @@ -71,39 +66,25 @@ async def find_nodes_with_dimension( *, node_type: List[NodeType] = Query([]), session: AsyncSession = Depends(get_session), - current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), ) -> List[NodeNameOutput]: """ List all nodes that have the specified dimension """ - dimension_node = await get_node_by_name(session, name) + dimension_node = cast( + Node, + await Node.get_by_name(session, name, raise_if_not_exists=True), + ) + access_checker.add_node(dimension_node, access.ResourceAction.READ) + nodes = await get_nodes_with_common_dimensions( session, [dimension_node], node_type if node_type else None, ) - approvals = [ - approval.access_object.name - for approval in validate_access_requests( - validate_access, - current_user, - [ - access.ResourceRequest( - verb=access.ResourceAction.READ, - access_object=access.Resource( - name=node.name, - resource_type=access.ResourceType.NODE, - owner="", - ), - ) - for node in nodes - ], - ) - ] - return [NodeNameOutput(name=node.name) for node in nodes if node.name in approvals] + access_checker.add_nodes(nodes, access.ResourceAction.READ) + approved_nodes = await access_checker.approved_resource_names() + return [node for node in nodes if node.name in approved_nodes] @router.get("/dimensions/common/", response_model=List[NodeNameOutput]) @@ -112,10 +93,7 @@ async def find_nodes_with_common_dimensions( node_type: List[NodeType] = Query([]), *, session: AsyncSession = Depends(get_session), - current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), ) -> List[NodeNameOutput]: """ Find all nodes that have the list of common dimensions @@ -125,22 +103,6 @@ async def find_nodes_with_common_dimensions( [await get_node_by_name(session, dim) for dim in dimension], # type: ignore node_type, ) - approvals = [ - approval.access_object.name - for approval in validate_access_requests( - validate_access, - current_user, - [ - access.ResourceRequest( - verb=access.ResourceAction.READ, - access_object=access.Resource( - name=node.name, - resource_type=access.ResourceType.NODE, - owner="", - ), - ) - for node in nodes - ], - ) - ] - return [NodeNameOutput(name=node.name) for node in nodes if node.name in approvals] + access_checker.add_nodes(nodes, access.ResourceAction.READ) + approved_resource_names = await access_checker.approved_resource_names() + return [node for node in nodes if node.name in approved_resource_names] diff --git a/datajunction-server/datajunction_server/api/djsql.py b/datajunction-server/datajunction_server/api/djsql.py index 6f01edb84..3de60edf2 100644 --- a/datajunction-server/datajunction_server/api/djsql.py +++ b/datajunction-server/datajunction_server/api/djsql.py @@ -9,14 +9,14 @@ from sse_starlette.sse import EventSourceResponse from datajunction_server.api.helpers import build_sql_for_dj_query, query_event_stream -from datajunction_server.database.user import User from datajunction_server.internal.access.authentication.http import SecureAPIRouter -from 
datajunction_server.internal.access.authorization import validate_access -from datajunction_server.models import access +from datajunction_server.internal.access.authorization import ( + AccessChecker, + get_access_checker, +) from datajunction_server.models.query import QueryCreate, QueryWithResults from datajunction_server.service_clients import QueryServiceClient from datajunction_server.utils import ( - get_current_user, get_query_service_client, get_session, get_settings, @@ -36,24 +36,16 @@ async def get_data_for_djsql( query_service_client: QueryServiceClient = Depends(get_query_service_client), engine_name: Optional[str] = None, engine_version: Optional[str] = None, - current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), ) -> QueryWithResults: """ Return data for a DJ SQL query """ request_headers = dict(request.headers) - access_control = access.AccessControlStore( - validate_access=validate_access, - user=current_user, - base_verb=access.ResourceAction.EXECUTE, - ) translated_sql, engine, catalog = await build_sql_for_dj_query( session, query, - access_control, + access_checker, engine_name, engine_version, ) @@ -86,24 +78,16 @@ async def get_data_stream_for_djsql( query_service_client: QueryServiceClient = Depends(get_query_service_client), engine_name: Optional[str] = None, engine_version: Optional[str] = None, - current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), ) -> QueryWithResults: # pragma: no cover """ Return data for a DJ SQL query using server side events """ request_headers = dict(request.headers) - access_control = access.AccessControlStore( - validate_access=validate_access, - user=current_user, - base_verb=access.ResourceAction.EXECUTE, - ) translated_sql, engine, catalog = await build_sql_for_dj_query( session, query, - access_control, + access_checker, engine_name, engine_version, ) diff --git a/datajunction-server/datajunction_server/api/graphql/main.py b/datajunction-server/datajunction_server/api/graphql/main.py index 088ba548b..fa3b0f345 100644 --- a/datajunction-server/datajunction_server/api/graphql/main.py +++ b/datajunction-server/datajunction_server/api/graphql/main.py @@ -9,6 +9,7 @@ from strawberry.types import Info from datajunction_server.internal.caching.cachelib_cache import get_cache +from datajunction_server.internal.access.authentication.http import DJHTTPBearer from datajunction_server.api.graphql.queries.catalogs import list_catalogs from datajunction_server.api.graphql.queries.dag import ( common_dimensions, @@ -82,6 +83,7 @@ async def get_context( background_tasks: BackgroundTasks, db_session=Depends(get_session), cache=Depends(get_cache), + _auth=Depends(DJHTTPBearer(auto_error=False)), ): """ Provides the context for graphql requests diff --git a/datajunction-server/datajunction_server/api/helpers.py b/datajunction-server/datajunction_server/api/helpers.py index 5305d97e6..96d4ab41f 100644 --- a/datajunction-server/datajunction_server/api/helpers.py +++ b/datajunction-server/datajunction_server/api/helpers.py @@ -18,6 +18,10 @@ from sqlalchemy.orm import defer, joinedload, selectinload from sqlalchemy.sql.operators import and_, is_ +from datajunction_server.internal.access.authorization import ( + AccessChecker, + AccessDenialMode, +) from 
datajunction_server.api.notifications import get_notifier from datajunction_server.construction.build import ( get_default_criteria, @@ -207,7 +211,8 @@ async def get_query( orderby: List[str], limit: Optional[int] = None, engine: Optional[Engine] = None, - access_control: Optional[access.AccessControlStore] = None, + *, + access_checker: AccessChecker, use_materialized: bool = True, query_parameters: Optional[Dict[str, str]] = None, ignore_errors: bool = True, @@ -227,7 +232,7 @@ async def get_query( if ignore_errors: query_builder.ignore_errors() query_ast = await ( - query_builder.with_access_control(access_control) + query_builder.with_access_control(access_checker) .with_build_criteria(build_criteria) .add_dimensions(dimensions) .add_filters(filters) @@ -776,7 +781,7 @@ async def query_event_stream( async def build_sql_for_dj_query( # pragma: no cover session: AsyncSession, query: str, - access_control: access.AccessControl, + access_checker: AccessChecker, engine_name: Optional[str] = None, engine_version: Optional[str] = None, ) -> Tuple[TranslatedSQL, Engine, Catalog]: @@ -787,11 +792,12 @@ async def build_sql_for_dj_query( # pragma: no cover query_ast, dj_nodes = await build_dj_query(session, query) for node in dj_nodes: # pragma: no cover - access_control.add_request_by_node( # pragma: no cover + access_checker.add_node( # pragma: no cover node.current, + access.ResourceAction.READ, ) - access_control.validate_and_raise() # pragma: no cover + await access_checker.check(on_denied=AccessDenialMode.RAISE) # pragma: no cover leading_metric_node = dj_nodes[0] # pragma: no cover available_engines = ( # pragma: no cover diff --git a/datajunction-server/datajunction_server/api/materializations.py b/datajunction-server/datajunction_server/api/materializations.py index d2399729d..6bfbc0cc7 100644 --- a/datajunction-server/datajunction_server/api/materializations.py +++ b/datajunction-server/datajunction_server/api/materializations.py @@ -22,14 +22,16 @@ from datajunction_server.database.user import User from datajunction_server.errors import DJDoesNotExistException, DJInvalidInputException from datajunction_server.internal.access.authentication.http import SecureAPIRouter -from datajunction_server.internal.access.authorization import validate_access +from datajunction_server.internal.access.authorization import ( + AccessChecker, + get_access_checker, +) from datajunction_server.internal.history import ActivityType, EntityType from datajunction_server.internal.materializations import ( create_new_materialization, schedule_materialization_jobs, ) from datajunction_server.materialization.jobs import MaterializationJob -from datajunction_server.models import access from datajunction_server.models.base import labelize from datajunction_server.models.cube_materialization import UpsertCubeMaterialization from datajunction_server.models.node import AvailabilityStateInfo @@ -97,9 +99,7 @@ async def upsert_materialization( query_service_client: QueryServiceClient = Depends(get_query_service_client), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), ) -> JSONResponse: """ Add or update a materialization of the specified node. 
If a node_name is specified @@ -136,7 +136,7 @@ async def upsert_materialization( session, current_revision, materialization, - validate_access, # type: ignore + access_checker, # type: ignore current_user=current_user, ) diff --git a/datajunction-server/datajunction_server/api/metrics.py b/datajunction-server/datajunction_server/api/metrics.py index 417d4c344..206fb54b2 100644 --- a/datajunction-server/datajunction_server/api/metrics.py +++ b/datajunction-server/datajunction_server/api/metrics.py @@ -13,13 +13,14 @@ from datajunction_server.api.nodes import list_nodes from datajunction_server.database.node import Node, NodeRevision -from datajunction_server.database.user import User from datajunction_server.errors import DJError, DJInvalidInputException, ErrorCode from datajunction_server.internal.caching.cachelib_cache import get_cache from datajunction_server.internal.caching.interface import Cache from datajunction_server.internal.access.authentication.http import SecureAPIRouter -from datajunction_server.internal.access.authorization import validate_access -from datajunction_server.models import access +from datajunction_server.internal.access.authorization import ( + AccessChecker, + get_access_checker, +) from datajunction_server.models.metric import Metric from datajunction_server.models.node import ( DimensionAttributeOutput, @@ -30,7 +31,6 @@ from datajunction_server.models.node_type import NodeType from datajunction_server.sql.dag import get_dimensions, get_shared_dimensions from datajunction_server.utils import ( - get_current_user, get_session, get_settings, ) @@ -67,10 +67,7 @@ async def list_metrics( prefix: Optional[str] = None, *, session: AsyncSession = Depends(get_session), - current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), cache: Cache = Depends(get_cache), background_tasks: BackgroundTasks, ) -> List[str]: @@ -83,8 +80,7 @@ async def list_metrics( node_type=NodeType.METRIC, prefix=prefix, session=session, - current_user=current_user, - validate_access=validate_access, + access_checker=access_checker, ) background_tasks.add_task(cache.set, "metrics", metrics) return metrics diff --git a/datajunction-server/datajunction_server/api/namespaces.py b/datajunction-server/datajunction_server/api/namespaces.py index 42a072bcf..d2011af76 100644 --- a/datajunction-server/datajunction_server/api/namespaces.py +++ b/datajunction-server/datajunction_server/api/namespaces.py @@ -15,12 +15,14 @@ from datajunction_server.database.node import Node from datajunction_server.database.user import User from datajunction_server.errors import DJAlreadyExistsException +from datajunction_server.models.access import ResourceAction from datajunction_server.models.deployment import CubeSpec, DeploymentSpec from datajunction_server.models.dimensionlink import LinkType from datajunction_server.internal.access.authentication.http import SecureAPIRouter from datajunction_server.internal.access.authorization import ( - validate_access, - validate_access_requests, + AccessChecker, + get_access_checker, + AccessDenialMode, ) from datajunction_server.internal.namespaces import ( create_namespace, @@ -109,30 +111,17 @@ async def create_node_namespace( ) async def list_namespaces( session: AsyncSession = Depends(get_session), - current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: 
AccessChecker = Depends(get_access_checker), ) -> List[NamespaceOutput]: """ List namespaces with the number of nodes contained in them """ results = await NodeNamespace.get_all_with_node_count(session) - resource_requests = [ - access.ResourceRequest( - verb=access.ResourceAction.READ, - access_object=access.Resource.from_namespace(record.namespace), - ) - for record in results - ] - approvals = validate_access_requests( - validate_access, - current_user, - resource_requests=resource_requests, + access_checker.add_namespaces( + [record.namespace for record in results], + access.ResourceAction.READ, ) - approved_namespaces: List[str] = [ - request.access_object.name for request in approvals - ] + approved_namespaces = await access_checker.approved_resource_names() return [ NamespaceOutput(namespace=record.namespace, num_nodes=record.num_nodes) for record in results @@ -156,17 +145,38 @@ async def list_nodes_in_namespace( description="Whether to include a list of users who edited each node", ), session: AsyncSession = Depends(get_session), + access_checker: AccessChecker = Depends(get_access_checker), ) -> List[NodeMinimumDetail]: """ List node names in namespace, filterable to a given type if desired. """ - return await NodeNamespace.list_nodes( + # Check that the user has namespace-level READ access + access_checker.add_namespace(namespace, access.ResourceAction.READ) + namespace_decisions = await access_checker.check( + on_denied=AccessDenialMode.FILTER, + ) + if not namespace_decisions: + # User has no access to this namespace at all + return [] + + # Get all nodes in namespace + nodes = await NodeNamespace.list_nodes( session, namespace, type_, with_edited_by=with_edited_by, ) + # Filter to nodes the user has READ access to + access_checker.add_nodes(nodes=nodes, action=access.ResourceAction.READ) + node_decisions = await access_checker.check(on_denied=AccessDenialMode.RETURN) + approved_names = { + decision.request.access_object.name + for decision in node_decisions + if decision.approved + } + return [node for node in nodes if node.name in approved_names] + @router.delete("/namespaces/{namespace}/", status_code=HTTPStatus.OK) async def deactivate_a_namespace( @@ -182,10 +192,14 @@ async def deactivate_a_namespace( query_service_client: QueryServiceClient = Depends(get_query_service_client), background_tasks: BackgroundTasks, request: Request, + access_checker: AccessChecker = Depends(get_access_checker), ) -> JSONResponse: """ Deactivates a node namespace """ + access_checker.add_namespace(namespace, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node_namespace = await NodeNamespace.get( session, namespace, @@ -266,10 +280,14 @@ async def restore_a_namespace( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> JSONResponse: """ Restores a node namespace """ + access_checker.add_namespace(namespace, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node_namespace = await get_node_namespace( session=session, namespace=namespace, @@ -340,6 +358,7 @@ async def hard_delete_node_namespace( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> JSONResponse: """ Hard delete a namespace, which will 
completely remove the namespace. Additionally, @@ -347,6 +366,9 @@ async def hard_delete_node_namespace( is set to true. If cascade is set to false, we'll raise an error. This should be used with caution, as the impact may be large. """ + access_checker.add_namespace(namespace, ResourceAction.DELETE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + impacts = await hard_delete_namespace( session=session, namespace=namespace, @@ -371,11 +393,15 @@ async def export_a_namespace( namespace: str, *, session: AsyncSession = Depends(get_session), + access_checker: AccessChecker = Depends(get_access_checker), ) -> List[Dict]: """ Generates a zip of YAML files for the contents of the given namespace as well as a project definition file. """ + access_checker.add_namespace(namespace, ResourceAction.READ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + return await get_project_config( session=session, nodes=await get_nodes_in_namespace_detailed(session, namespace), @@ -400,10 +426,14 @@ async def export_namespace_spec( namespace: str, *, session: AsyncSession = Depends(get_session), + access_checker: AccessChecker = Depends(get_access_checker), ) -> DeploymentSpec: """ Generates a deployment spec for a namespace """ + access_checker.add_namespace(namespace, ResourceAction.READ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + nodes = await NodeNamespace.list_all_nodes( session, namespace, diff --git a/datajunction-server/datajunction_server/api/nodes.py b/datajunction-server/datajunction_server/api/nodes.py index 0ead3ca85..343867416 100644 --- a/datajunction-server/datajunction_server/api/nodes.py +++ b/datajunction-server/datajunction_server/api/nodes.py @@ -16,6 +16,7 @@ from sqlalchemy.sql.operators import is_ from starlette.requests import Request + from datajunction_server.api.helpers import ( get_catalog_by_name, get_column, @@ -42,9 +43,11 @@ ) from datajunction_server.internal.access.authentication.http import SecureAPIRouter from datajunction_server.internal.access.authorization import ( - validate_access, - validate_access_requests, + AccessChecker, + get_access_checker, + AccessDenialMode, ) +from datajunction_server.models.access import ResourceAction from datajunction_server.internal.history import ActivityType, EntityType from datajunction_server.internal.nodes import ( activate_node, @@ -159,10 +162,14 @@ async def revalidate( save_history: Callable = Depends(get_save_history), *, background_tasks: BackgroundTasks, + access_checker: AccessChecker = Depends(get_access_checker), ) -> NodeStatusDetails: """ Revalidate a single existing node and update its status appropriately """ + access_checker.add_request_by_node_name(name, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node_validator = await revalidate_node( name=name, session=session, @@ -209,10 +216,14 @@ async def set_column_attributes( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> List[ColumnOutput]: """ Set column attributes for the node. 
""" + access_checker.add_request_by_node_name(node_name, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node = await Node.get_by_name( session, node_name, @@ -237,29 +248,14 @@ async def list_nodes( prefix: Optional[str] = None, *, session: AsyncSession = Depends(get_session), - current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), ) -> List[str]: """ List the available nodes. """ nodes = await Node.find(session, prefix, node_type) # type: ignore - return [ - approval.access_object.name - for approval in validate_access_requests( - validate_access, - current_user, - [ - access.ResourceRequest( - verb=access.ResourceAction.READ, - access_object=access.Resource.from_node(node), - ) - for node in nodes - ], - ) - ] + access_checker.add_nodes(nodes, access.ResourceAction.READ) + return await access_checker.approved_resource_names() @router.get("/nodes/details/", response_model=List[NodeIndexItem]) @@ -268,10 +264,7 @@ async def list_all_nodes_with_details( node_type: Optional[NodeType] = None, *, session: AsyncSession = Depends(get_session), - current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), ) -> List[NodeIndexItem]: """ List the available nodes. @@ -301,24 +294,19 @@ async def list_all_nodes_with_details( "%s limit reached when returning all nodes, all nodes may not be captured in results", settings.node_list_max, ) - approvals = [ - approval.access_object.name - for approval in validate_access_requests( - validate_access, - current_user, - [ - access.ResourceRequest( - verb=access.ResourceAction.READ, - access_object=access.Resource( - name=row.name, - resource_type=access.ResourceType.NODE, - owner="", - ), - ) - for row in results - ], - ) - ] + access_checker.add_requests( + [ + access.ResourceRequest( + verb=access.ResourceAction.READ, + access_object=access.Resource( + name=row.name, + resource_type=access.ResourceType.NODE, + ), + ) + for row in results + ], + ) + approvals = await access_checker.approved_resource_names() return [row for row in results if row.name in approvals] @@ -327,10 +315,14 @@ async def get_node( name: str, *, session: AsyncSession = Depends(get_session), + access_checker: AccessChecker = Depends(get_access_checker), ) -> NodeOutput: """ Show the active version of the specified node. """ + access_checker.add_request_by_node_name(name, ResourceAction.READ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node = await Node.get_by_name( session, name, @@ -350,10 +342,14 @@ async def delete_node( query_service_client: QueryServiceClient = Depends(get_query_service_client), background_tasks: BackgroundTasks, request: Request, + access_checker: AccessChecker = Depends(get_access_checker), ): """ Delete (aka deactivate) the specified node. 
""" + access_checker.add_request_by_node_name(name, ResourceAction.DELETE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + await deactivate_node( session=session, name=name, @@ -375,11 +371,15 @@ async def hard_delete( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> JSONResponse: """ Hard delete a node, destroying all links and invalidating all downstream nodes. This should be used with caution, deactivating a node is preferred. """ + access_checker.add_request_by_node_name(name, ResourceAction.DELETE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + impact = await hard_delete_node( name=name, session=session, @@ -402,10 +402,14 @@ async def restore_node( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ): """ Restore (aka re-activate) the specified node. """ + access_checker.add_request_by_node_name(name, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + await activate_node( session=session, name=name, @@ -423,10 +427,14 @@ async def list_node_revisions( name: str, *, session: AsyncSession = Depends(get_session), + access_checker: AccessChecker = Depends(get_access_checker), ) -> List[NodeRevisionOutput]: """ List all revisions for the node. """ + access_checker.add_request_by_node_name(name, ResourceAction.READ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node = await Node.get_by_name( session, name, @@ -446,9 +454,7 @@ async def create_source( current_user: User = Depends(get_current_user), request: Request, query_service_client: QueryServiceClient = Depends(get_query_service_client), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), background_tasks: BackgroundTasks, save_history: Callable = Depends(get_save_history), ) -> NodeOutput: @@ -456,13 +462,17 @@ async def create_source( Create a source node. If columns are not provided, the source node's schema will be inferred using the configured query service. """ + namespace = data.namespace or data.name.rsplit(".", 1)[0] + access_checker.add_namespace(namespace, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + return await create_a_source_node( data=data, request=request, session=session, current_user=current_user, query_service_client=query_service_client, - validate_access=validate_access, + access_checker=access_checker, background_tasks=background_tasks, save_history=save_history, ) @@ -494,9 +504,7 @@ async def create_node( current_user: User = Depends(get_current_user), query_service_client: QueryServiceClient = Depends(get_query_service_client), background_tasks: BackgroundTasks, - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), save_history: Callable = Depends(get_save_history), cache: Cache = Depends(get_cache), ) -> NodeOutput: @@ -504,6 +512,11 @@ async def create_node( Create a node. 
""" node_type = NodeType(os.path.basename(os.path.normpath(request.url.path))) + + namespace = data.namespace or data.name.rsplit(".", 1)[0] + access_checker.add_namespace(namespace, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + return await create_a_node( data=data, request=request, @@ -512,7 +525,7 @@ async def create_node( current_user=current_user, query_service_client=query_service_client, background_tasks=background_tasks, - validate_access=validate_access, + access_checker=access_checker, save_history=save_history, cache=cache, ) @@ -532,14 +545,28 @@ async def create_cube( query_service_client: QueryServiceClient = Depends(get_query_service_client), current_user: User = Depends(get_current_user), background_tasks: BackgroundTasks, - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), save_history: Callable = Depends(get_save_history), ) -> NodeOutput: """ Create a cube node. """ + # Check WRITE access on the namespace for creating the cube + namespace = data.namespace or data.name.rsplit(".", 1)[0] + access_checker.add_namespace(namespace, ResourceAction.WRITE) + + # Check READ access on all metrics and dimensions being included in the cube + if data.metrics: + for metric_name in data.metrics: + access_checker.add_request_by_node_name(metric_name, ResourceAction.READ) + if data.dimensions: + for dim_attr in data.dimensions: + # Dimension attributes are in format "node_name.column_name" + dim_node_name = dim_attr.rsplit(".", 1)[0] + access_checker.add_request_by_node_name(dim_node_name, ResourceAction.READ) + + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node = await create_a_cube( data=data, request=request, @@ -547,7 +574,7 @@ async def create_cube( current_user=current_user, query_service_client=query_service_client, background_tasks=background_tasks, - validate_access=validate_access, + access_checker=access_checker, save_history=save_history, ) @@ -575,6 +602,7 @@ async def register_table( current_user: User = Depends(get_current_user), background_tasks: BackgroundTasks, save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> NodeOutput: """ Register a table. This creates a source node in the SOURCE_NODE_NAMESPACE and @@ -602,6 +630,8 @@ async def register_table( current_user=current_user, save_history=save_history, ) + access_checker.add_namespace(namespace, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) # Use reflection to get column names and types _catalog = await get_catalog_by_name(session=session, name=catalog) @@ -628,6 +658,7 @@ async def register_table( background_tasks=background_tasks, save_history=save_history, request=request, + access_checker=access_checker, ) @@ -649,6 +680,7 @@ async def register_view( current_user: User = Depends(get_current_user), background_tasks: BackgroundTasks, save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> NodeOutput: """ Register a view by creating the view in the database and adding a source node for it. 
@@ -666,6 +698,9 @@ async def register_view( view_name = f"{schema_}.{view}" await raise_if_node_exists(session, node_name) + access_checker.add_namespace(namespace, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + # Re-create the view in the database _catalog = await get_catalog_by_name(session=session, name=catalog) or_replace = "OR REPLACE" if replace else "" @@ -717,6 +752,7 @@ async def register_view( background_tasks=background_tasks, save_history=save_history, request=request, + access_checker=access_checker, ) @@ -729,6 +765,7 @@ async def link_dimension( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> JSONResponse: """ Add a simple dimension link from a node column to a dimension node. @@ -736,6 +773,10 @@ async def link_dimension( 2. If no `dimension_column` is provided, the primary key column of the dimension node will be used as the join column for the link. """ + access_checker.add_request_by_node_name(name, ResourceAction.WRITE) + access_checker.add_request_by_node_name(dimension, ResourceAction.READ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + activity_type = await upsert_simple_dimension_link( session, name, @@ -771,10 +812,15 @@ async def add_reference_dimension_link( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> JSONResponse: """ Add reference dimension link to a node column """ + access_checker.add_request_by_node_name(node_name, ResourceAction.WRITE) + access_checker.add_request_by_node_name(dimension_node, ResourceAction.READ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + await upsert_reference_dimension_link( session=session, node_name=node_name, @@ -803,10 +849,14 @@ async def remove_reference_dimension_link( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> JSONResponse: """ Remove reference dimension link from a node column """ + access_checker.add_request_by_node_name(node_name, ResourceAction.DELETE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node = await Node.get_by_name(session, node_name, raise_if_not_exists=True) target_column = await get_column(session, node.current, node_column) # type: ignore if target_column.dimension_id or target_column.dimension_column: @@ -853,11 +903,19 @@ async def add_complex_dimension_link( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> JSONResponse: """ Links a source, dimension, or transform node to a dimension with a custom join query. If a link already exists, updates the link definition. 
""" + access_checker.add_request_by_node_name(node_name, ResourceAction.WRITE) + access_checker.add_request_by_node_name( + link_input.dimension_node, + ResourceAction.READ, + ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + activity_type = await upsert_complex_dimension_link( session, node_name, @@ -888,10 +946,16 @@ async def remove_complex_dimension_link( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> JSONResponse: """ Removes a complex dimension link based on the dimension node and its role (if any). """ + access_checker.add_request_by_node_name(node_name, ResourceAction.WRITE) + access_checker.add_request_by_node_name( + link_identifier.dimension_node, + ResourceAction.READ, + ) return await remove_dimension_link( session, node_name, @@ -910,10 +974,15 @@ async def delete_dimension_link( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> JSONResponse: """ Remove the link between a node column and a dimension node """ + access_checker.add_request_by_node_name(name, ResourceAction.WRITE) + access_checker.add_request_by_node_name(dimension, ResourceAction.READ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + return await remove_dimension_link( session, name, @@ -936,10 +1005,14 @@ async def tags_node( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> JSONResponse: """ Add a tag to a node """ + access_checker.add_request_by_node_name(name, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node = await Node.get_by_name(session=session, name=name) existing_tags = {tag.name for tag in node.tags} # type: ignore if not tag_names: @@ -990,10 +1063,14 @@ async def refresh_source_node( query_service_client: QueryServiceClient = Depends(get_query_service_client), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> NodeOutput: """ Refresh a source node with the latest columns from the query service. """ + access_checker.add_request_by_node_name(name, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + return await refresh_source( # type: ignore name=name, session=session, @@ -1015,15 +1092,29 @@ async def update_node( query_service_client: QueryServiceClient = Depends(get_query_service_client), current_user: User = Depends(get_current_user), background_tasks: BackgroundTasks, - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), save_history: Callable = Depends(get_save_history), cache: Cache = Depends(get_cache), ) -> NodeOutput: """ Update a node. 
""" + # Check WRITE access on the node being updated + access_checker.add_request_by_node_name(name, ResourceAction.WRITE) + + # For cube updates: check READ access on any metrics/dimensions being added + # (user must have access to read nodes they're including in the cube) + if data.metrics: + for metric_name in data.metrics: + access_checker.add_request_by_node_name(metric_name, ResourceAction.READ) + if data.dimensions: + for dim_attr in data.dimensions: + # Dimension attributes are in format "node_name.column_name" + dim_node_name = dim_attr.rsplit(".", 1)[0] + access_checker.add_request_by_node_name(dim_node_name, ResourceAction.READ) + + await access_checker.check(on_denied=AccessDenialMode.RAISE) + request_headers = dict(request.headers) await update_any_node( name, @@ -1032,7 +1123,7 @@ async def update_node( query_service_client=query_service_client, current_user=current_user, background_tasks=background_tasks, - validate_access=validate_access, + access_checker=access_checker, request_headers=request_headers, save_history=save_history, refresh_materialization=refresh_materialization, @@ -1053,10 +1144,15 @@ async def calculate_node_similarity( node2_name: str, *, session: AsyncSession = Depends(get_session), + access_checker: AccessChecker = Depends(get_access_checker), ) -> JSONResponse: """ Compare two nodes by how similar their queries are """ + access_checker.add_request_by_node_name(node1_name, ResourceAction.READ) + access_checker.add_request_by_node_name(node2_name, ResourceAction.READ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node1 = await Node.get_by_name( session, node1_name, @@ -1091,12 +1187,16 @@ async def list_downstream_nodes( node_type: NodeType = None, depth: int = -1, session: AsyncSession = Depends(get_session), + access_checker: AccessChecker = Depends(get_access_checker), ) -> List[DAGNodeOutput]: """ List all nodes that are downstream from the given node, filterable by type and max depth. Setting a max depth of -1 will include all downstream nodes. """ - return await get_downstream_nodes( + access_checker.add_request_by_node_name(name, ResourceAction.READ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + + downstreams = await get_downstream_nodes( session=session, node_name=name, node_type=node_type, @@ -1104,6 +1204,11 @@ async def list_downstream_nodes( depth=depth, ) + for node in downstreams: + access_checker.add_request_by_node_name(node.name, ResourceAction.READ) + accessible = await access_checker.approved_resource_names() + return [node for node in downstreams if node.name in accessible] + @router.get( "/nodes/{name}/upstream/", @@ -1117,10 +1222,14 @@ async def list_upstream_nodes( cache: Cache = Depends(get_cache), background_tasks: BackgroundTasks, session: AsyncSession = Depends(get_session), + access_checker: AccessChecker = Depends(get_access_checker), ) -> List[DAGNodeOutput]: """ List all nodes that are upstream from the given node, filterable by type. 
""" + access_checker.add_request_by_node_name(name, ResourceAction.READ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node = cast(Node, await Node.get_by_name(session, name, raise_if_not_exists=True)) upstream_cache_key = node.upstream_cache_key() results = cache.get(upstream_cache_key) @@ -1132,7 +1241,11 @@ async def list_upstream_nodes( results, timeout=settings.query_cache_timeout, ) - return results + + for node in results: + access_checker.add_request_by_node_name(node.name, ResourceAction.READ) + accessible = await access_checker.approved_resource_names() + return [node for node in results if node.name in accessible] @router.get( @@ -1143,11 +1256,15 @@ async def list_node_dag( name: str, *, session: AsyncSession = Depends(get_session), + access_checker: AccessChecker = Depends(get_access_checker), ) -> List[DAGNodeOutput]: """ List all nodes that are part of the DAG of the given node. This means getting all upstreams, downstreams, and linked dimension nodes. """ + access_checker.add_request_by_node_name(name, ResourceAction.READ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node = await Node.get_by_name( session, name, @@ -1183,10 +1300,14 @@ async def list_all_dimension_attributes( *, depth: int = 30, session: AsyncSession = Depends(get_session), + access_checker: AccessChecker = Depends(get_access_checker), ) -> list[DimensionAttributeOutput]: """ List all available dimension attributes for the given node. """ + access_checker.add_request_by_node_name(name, ResourceAction.READ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + dimensions = await get_dimension_attributes(session, name) filter_only_dimensions = await get_filter_only_dimensions(session, name) return dimensions + filter_only_dimensions @@ -1201,10 +1322,13 @@ async def column_lineage( name: str, *, session: AsyncSession = Depends(get_session), + access_checker: AccessChecker = Depends(get_access_checker), ) -> List[LineageColumn]: """ List column-level lineage of a node in a graph """ + access_checker.add_request_by_node_name(name, ResourceAction.READ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) node = await Node.get_by_name( session, @@ -1239,10 +1363,14 @@ async def set_column_display_name( save_history: Callable = Depends(get_save_history), *, session: AsyncSession = Depends(get_session), + access_checker: AccessChecker = Depends(get_access_checker), ) -> ColumnOutput: """ Set column name for the node """ + access_checker.add_request_by_node_name(node_name, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node = await Node.get_by_name( session, node_name, @@ -1281,10 +1409,14 @@ async def set_column_description( save_history: Callable = Depends(get_save_history), *, session: AsyncSession = Depends(get_session), + access_checker: AccessChecker = Depends(get_access_checker), ) -> ColumnOutput: """ Set column description for the node """ + access_checker.add_request_by_node_name(node_name, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node = await Node.get_by_name( session, node_name, @@ -1324,10 +1456,14 @@ async def set_column_partition( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> ColumnOutput: """ Add or update partition columns for the specified node. 
""" + access_checker.add_request_by_node_name(node_name, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node = await Node.get_by_name( session, node_name, @@ -1394,6 +1530,7 @@ async def copy_node( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> DAGNodeOutput: """ Copy this node to a new name. @@ -1402,6 +1539,11 @@ async def copy_node( new_node_namespace = ".".join(new_name.split(".")[:-1]) await get_node_namespace(session, new_node_namespace, raise_if_not_exists=True) + # Check that the user has access to read the existing node and write to the new namespace + access_checker.add_request_by_node_name(node_name, ResourceAction.READ) + access_checker.add_namespace(new_name, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + # Check if there is already a node with the new name existing_new_node = await get_node_by_name( session, diff --git a/datajunction-server/datajunction_server/api/sql.py b/datajunction-server/datajunction_server/api/sql.py index 12699c190..3c49a6fb7 100644 --- a/datajunction-server/datajunction_server/api/sql.py +++ b/datajunction-server/datajunction_server/api/sql.py @@ -9,6 +9,7 @@ from fastapi import BackgroundTasks, Depends, Query, Request from sqlalchemy.ext.asyncio import AsyncSession +from datajunction_server.utils import get_current_user from datajunction_server.construction.build_v3 import ( build_metrics_sql, build_measures_sql, @@ -25,12 +26,10 @@ from datajunction_server.internal.caching.cachelib_cache import get_cache from datajunction_server.internal.caching.interface import Cache from datajunction_server.database import Node -from datajunction_server.database.queryrequest import QueryBuildType from datajunction_server.database.user import User +from datajunction_server.database.queryrequest import QueryBuildType from datajunction_server.errors import DJInvalidInputException from datajunction_server.internal.access.authentication.http import SecureAPIRouter -from datajunction_server.internal.access.authorization import validate_access -from datajunction_server.models import access from datajunction_server.models.metric import TranslatedSQL, V3TranslatedSQL from datajunction_server.models.node_type import NodeType from datajunction_server.models.query import V3ColumnMetadata @@ -42,7 +41,6 @@ ) from datajunction_server.models.sql import GeneratedSQL from datajunction_server.utils import ( - get_current_user, get_session, get_settings, ) @@ -79,13 +77,8 @@ async def get_measures_sql_for_cube_v2( ), ), cache: Cache = Depends(get_cache), - session: AsyncSession = Depends(get_session), engine_name: Optional[str] = None, engine_version: Optional[str] = None, - current_user: Optional[User] = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), use_materialized: bool = True, background_tasks: BackgroundTasks, request: Request, @@ -119,8 +112,6 @@ async def get_measures_sql_for_cube_v2( include_all_columns=include_all_columns, preaggregate=preaggregate, use_materialized=use_materialized, - current_user=current_user, - validate_access=validate_access, ), ) @@ -138,13 +129,8 @@ async def get_sql( limit: Optional[int] = None, query_params: str = Query("{}", description="Query parameters"), *, - session: AsyncSession = Depends(get_session), engine_name: Optional[str] = None, 
engine_version: Optional[str] = None, - current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), background_tasks: BackgroundTasks, ignore_errors: Optional[bool] = True, use_materialized: Optional[bool] = True, @@ -172,8 +158,6 @@ async def get_sql( engine_version=engine_version, use_materialized=use_materialized, ignore_errors=ignore_errors, - current_user=current_user, - validate_access=validate_access, ), ) @@ -404,10 +388,6 @@ async def get_sql_for_metrics( session: AsyncSession = Depends(get_session), engine_name: Optional[str] = None, engine_version: Optional[str] = None, - current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), ignore_errors: Optional[bool] = True, use_materialized: Optional[bool] = True, background_tasks: BackgroundTasks, @@ -449,7 +429,5 @@ async def get_sql_for_metrics( engine_version=engine_version, use_materialized=use_materialized, ignore_errors=ignore_errors, - current_user=current_user, - validate_access=validate_access, ), ) diff --git a/datajunction-server/datajunction_server/api/system.py b/datajunction-server/datajunction_server/api/system.py index c8c3407ea..835cb1d85 100644 --- a/datajunction-server/datajunction_server/api/system.py +++ b/datajunction-server/datajunction_server/api/system.py @@ -7,10 +7,6 @@ from sqlalchemy import select, text from sqlalchemy.ext.asyncio import AsyncSession -from datajunction_server.internal.access.authorization import ( - validate_access, -) -from datajunction_server.models import access from datajunction_server.models.system import DimensionStats, RowOutput from datajunction_server.sql.dag import ( get_cubes_using_dimensions, @@ -18,12 +14,10 @@ ) from datajunction_server.internal.caching.cachelib_cache import get_cache from datajunction_server.internal.caching.interface import Cache -from datajunction_server.database.user import User from datajunction_server.database.node import Node from datajunction_server.internal.access.authentication.http import SecureAPIRouter from datajunction_server.models.node_type import NodeType from datajunction_server.utils import ( - get_current_user, get_session, get_settings, ) @@ -63,11 +57,7 @@ async def get_data_for_system_metric( limit: int | None = None, session: AsyncSession = Depends(get_session), *, - current_user: User = Depends(get_current_user), background_tasks: BackgroundTasks, - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), cache: Cache = Depends(get_cache), request: Request, ) -> list[list[RowOutput]]: @@ -94,8 +84,6 @@ async def get_data_for_system_metric( filters=filters, orderby=orderby, limit=limit, - current_user=current_user, - validate_access=validate_access, ), ) results = await session.execute(text(translated_sql.sql)) diff --git a/datajunction-server/datajunction_server/config.py b/datajunction-server/datajunction_server/config.py index ed6941ca6..9e0e2bd73 100644 --- a/datajunction-server/datajunction_server/config.py +++ b/datajunction-server/datajunction_server/config.py @@ -157,6 +157,18 @@ class Settings(BaseSettings): # pragma: no cover # or a custom implementation of the GroupMembershipProvider interface group_membership_provider: str = "postgres" + # Authorization configuration + # Provider for authorization checks: + # - "rbac": Role-based access control (default) + # - "passthrough": Always approve (testing/development) + # - Custom implementations can be plugged in + 
authorization_provider: str = "rbac" + + # Default access policy when no explicit RBAC rule exists: + # - "permissive": Allow by default + # - "restrictive": Deny by default + default_access_policy: str = "permissive" # or "restrictive" + # Interval in seconds with which to expire caching of any indexes index_cache_expire: int = 60 diff --git a/datajunction-server/datajunction_server/construction/build.py b/datajunction-server/datajunction_server/construction/build.py index 7588f61a0..6344214f5 100755 --- a/datajunction-server/datajunction_server/construction/build.py +++ b/datajunction-server/datajunction_server/construction/build.py @@ -12,7 +12,7 @@ from datajunction_server.database.node import Node, NodeRevision from datajunction_server.errors import DJError, DJInvalidInputException, ErrorCode from datajunction_server.internal.engines import get_engine -from datajunction_server.models import access +from datajunction_server.internal.access.authorization import AccessChecker from datajunction_server.models.cube_materialization import MetricComponent from datajunction_server.models.engine import Dialect from datajunction_server.models.materialization import GenericCubeConfig @@ -176,7 +176,7 @@ async def build_metric_nodes( engine_name: Optional[str] = None, engine_version: Optional[str] = None, build_criteria: Optional[BuildCriteria] = None, - access_control: Optional[access.AccessControlStore] = None, + access_checker: AccessChecker | None = None, ignore_errors: bool = True, query_parameters: Optional[dict[str, Any]] = None, ): @@ -214,7 +214,7 @@ async def build_metric_nodes( .order_by(orderby) .limit(limit) .with_build_criteria(build_criteria) - .with_access_control(access_control) + .with_access_control(access_checker) ) if ignore_errors: builder = builder.ignore_errors() diff --git a/datajunction-server/datajunction_server/construction/build_v2.py b/datajunction-server/datajunction_server/construction/build_v2.py index 7ab01224c..3ac4a1d39 100644 --- a/datajunction-server/datajunction_server/construction/build_v2.py +++ b/datajunction-server/datajunction_server/construction/build_v2.py @@ -18,6 +18,10 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload, selectinload +from datajunction_server.internal.access.authorization import ( + AccessChecker, + AccessDenialMode, +) from datajunction_server.construction.utils import to_namespaced_name from datajunction_server.database import Engine from datajunction_server.database.attributetype import ColumnAttribute @@ -274,7 +278,7 @@ def __init__( self._orderby: list[str] = [] self._limit: Optional[int] = None self._build_criteria: Optional[BuildCriteria] = self.get_default_criteria() - self._access_control: Optional[access.AccessControlStore] = None + self._access_checker: Optional[AccessChecker] = None self._ignore_errors: bool = False # The following attributes will be modified as the query gets built. @@ -384,14 +388,14 @@ def with_build_criteria(self, build_criteria: Optional[BuildCriteria] = None): def with_access_control( self, - access_control: Optional[access.AccessControlStore] = None, + access_checker: AccessChecker, ): """ Set access control for the query builder. 
""" - if access_control: # pragma: no cover - access_control.add_request_by_node(self.node_revision) - self._access_control = access_control + if access_checker: # pragma: no cover + access_checker.add_node(self.node_revision, access.ResourceAction.READ) + self._access_checker = access_checker return self @property @@ -572,7 +576,7 @@ async def build(self) -> ast.Query: ) # Error validation - self.validate_access() + await self.validate_access() if self.errors and not self._ignore_errors: raise DJQueryBuildException(errors=self.errors) return self.final_ast # type: ignore @@ -991,23 +995,18 @@ def set_dimension_aliases(self): node_col.set_semantic_entity(dim_name) node_col.set_semantic_type(SemanticType.DIMENSION) - async def add_request_by_node_name(self, node_name): + def add_request_by_node_name(self, node_name: str): """Add a node request to the access control validator.""" - if self._access_control: # pragma: no cover - # Use cached node if available to avoid DB lookup - cached_node = self.dependencies_cache.get(node_name) - if cached_node: # pragma: no cover - self._access_control.add_request_by_node(cached_node) - else: # pragma: no cover - await self._access_control.add_request_by_node_name( - self.session, - node_name, - ) + if self._access_checker: # pragma: no cover + self._access_checker.add_request_by_node_name( + node_name, + access.ResourceAction.READ, + ) - def validate_access(self): + async def validate_access(self): """Validates access""" - if self._access_control: - self._access_control.validate_and_raise() + if self._access_checker: + await self._access_checker.check(on_denied=AccessDenialMode.RAISE) async def find_dimension_node_joins( self, @@ -1059,7 +1058,7 @@ async def find_dimension_node_joins( # Build DimensionJoin objects using preloaded paths for attr in non_local_dimensions: - await self.add_request_by_node_name(attr.node_name) + self.add_request_by_node_name(attr.node_name) if attr.join_key not in dimension_node_joins: # Find matching path - try exact role match first, then no-role match @@ -1177,8 +1176,8 @@ def __init__( self._orderby: list[str] = [] self._limit: Optional[int] = None self._parameters: dict[str, ast.Value] = {} - self._build_criteria: Optional[BuildCriteria] = self.get_default_criteria() - self._access_control: Optional[access.AccessControlStore] = None + self._build_criteria: BuildCriteria | None = self.get_default_criteria() + self._access_checker: AccessChecker | None = None self._ignore_errors: bool = False # The following attributes will be modified as the query gets built. @@ -1293,14 +1292,13 @@ def with_build_criteria(self, build_criteria: Optional[BuildCriteria] = None): def with_access_control( self, - access_control: Optional[access.AccessControlStore] = None, + access_checker: AccessChecker, ): """ Set access control for the query builder. 
""" - if access_control: # pragma: no cover - access_control.add_request_by_nodes(self.metric_nodes) - self._access_control = access_control + access_checker.add_nodes(self.metric_nodes, access.ResourceAction.READ) + self._access_checker = access_checker return self @property @@ -1398,15 +1396,15 @@ async def build(self) -> ast.Query: self.final_ast.select.limit = ast.Number(value=self._limit) # Error validation - self.validate_access() + await self.validate_access() if self.errors and not self._ignore_errors: raise DJQueryBuildException(errors=self.errors) # pragma: no cover return self.final_ast - def validate_access(self): + async def validate_access(self): """Validates access""" - if self._access_control: - self._access_control.validate_and_raise() + if self._access_checker: # pragma: no cover + await self._access_checker.check(on_denied=AccessDenialMode.RAISE) async def build_measures_queries(self): """ @@ -1424,7 +1422,7 @@ async def build_measures_queries(self): if self._ignore_errors: query_builder = query_builder.ignore_errors() parent_ast = await ( - query_builder.with_access_control(self._access_control) + query_builder.with_access_control(self._access_checker) .with_build_criteria(self._build_criteria) .add_dimensions(self.dimensions) .add_filters(self.filters) @@ -1483,18 +1481,20 @@ async def build_metric_agg( """ Build the metric's aggregate expression. """ - if self._access_control: - self._access_control.add_request_by_node(metric_node) # type: ignore + if self._access_checker: # pragma: no cover + self._access_checker.add_node(metric_node, access.ResourceAction.READ) # type: ignore metric_query_builder = await QueryBuilder.create(self.session, metric_node) if self._ignore_errors: metric_query_builder = ( # pragma: no cover metric_query_builder.ignore_errors() ) - metric_query = await ( - metric_query_builder.with_access_control(self._access_control) - .with_build_criteria(self._build_criteria) - .build() - ) + if self._access_checker: # pragma: no cover + metric_query_builder = metric_query_builder.with_access_control( + self._access_checker, + ) + metric_query = await metric_query_builder.with_build_criteria( + self._build_criteria, + ).build() self.errors.extend(metric_query_builder.errors) metric_query.ctes[-1].select.projection[0].set_semantic_entity( # type: ignore f"{metric_node.name}.{amenable_name(metric_node.name)}", diff --git a/datajunction-server/datajunction_server/construction/dimensions.py b/datajunction-server/datajunction_server/construction/dimensions.py index dceb343fc..463c1db38 100644 --- a/datajunction-server/datajunction_server/construction/dimensions.py +++ b/datajunction-server/datajunction_server/construction/dimensions.py @@ -11,7 +11,7 @@ from datajunction_server.database.node import NodeRevision from datajunction_server.database.user import User from datajunction_server.errors import DJInvalidInputException -from datajunction_server.models import access +from datajunction_server.internal.access.authorization import AccessChecker from datajunction_server.models.column import SemanticType from datajunction_server.models.metric import TranslatedSQL from datajunction_server.models.query import ColumnMetadata @@ -27,7 +27,7 @@ async def build_dimensions_from_cube_query( cube: NodeRevision, dimensions: List[str], current_user: User, - validate_access: access.ValidateAccessFn, + access_checker: AccessChecker, filters: Optional[str] = None, limit: Optional[int] = 50000, include_counts: bool = False, @@ -101,8 +101,7 @@ async def 
build_dimensions_from_cube_query( metrics=[metric.name for metric in cube.cube_metrics()], dimensions=dimensions, filters=[filters] if filters else [], - current_user=current_user, - validate_access=validate_access, + access_checker=access_checker, ) measures_query_ast = parse(measures_query[0].sql) measures_query_ast.bake_ctes() diff --git a/datajunction-server/datajunction_server/database/user.py b/datajunction-server/datajunction_server/database/user.py index 5b8d56eef..e64f60eed 100644 --- a/datajunction-server/datajunction_server/database/user.py +++ b/datajunction-server/datajunction_server/database/user.py @@ -34,6 +34,7 @@ from datajunction_server.database.notification_preference import ( NotificationPreference, ) + from datajunction_server.database.rbac import RoleAssignment from datajunction_server.database.tag import Tag logger = logging.getLogger(__name__) @@ -144,6 +145,7 @@ class User(Base): ) # Group membership relationships (for kind=GROUP) + # Groups that this user owns (for kind=GROUP) group_members: Mapped[list["GroupMember"]] = relationship( "GroupMember", foreign_keys="GroupMember.group_id", @@ -156,6 +158,13 @@ class User(Base): viewonly=True, ) + # RBAC role assignments (for authorization) + role_assignments: Mapped[list["RoleAssignment"]] = relationship( + "RoleAssignment", + foreign_keys="RoleAssignment.principal_id", + viewonly=True, + ) + @classmethod async def get_by_username( cls, diff --git a/datajunction-server/datajunction_server/internal/access/authentication/basic.py b/datajunction-server/datajunction_server/internal/access/authentication/basic.py index 2026e069d..7bc332151 100644 --- a/datajunction-server/datajunction_server/internal/access/authentication/basic.py +++ b/datajunction-server/datajunction_server/internal/access/authentication/basic.py @@ -5,10 +5,11 @@ import logging from passlib.context import CryptContext -from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.sql.base import ExecutableOption +from sqlalchemy.orm import selectinload +from datajunction_server.database.rbac import RoleAssignment, Role from datajunction_server.database.user import User from datajunction_server.errors import DJAuthenticationException, DJError, ErrorCode @@ -33,19 +34,29 @@ def get_password_hash(password) -> str: async def get_user( username: str, session: AsyncSession, - *options: ExecutableOption, + options: list[ExecutableOption] | None = None, ) -> User: """ Get a DJ user """ - user = ( - ( - await session.execute( - select(User).where(User.username == username).options(*options), - ) - ) - .unique() - .scalar_one_or_none() + from datajunction_server.database.group_member import GroupMember + + user = await User.get_by_username( + session=session, + username=username, + options=options + or [ + # Load user's direct role assignments + selectinload(User.role_assignments) + .selectinload(RoleAssignment.role) + .selectinload(Role.scopes), + # Load user's group memberships and the groups' role assignments + selectinload(User.member_of) + .selectinload(GroupMember.group) + .selectinload(User.role_assignments) + .selectinload(RoleAssignment.role) + .selectinload(Role.scopes), + ], ) if not user: raise DJAuthenticationException( diff --git a/datajunction-server/datajunction_server/internal/access/authorization.py b/datajunction-server/datajunction_server/internal/access/authorization.py deleted file mode 100644 index 8fd26aa71..000000000 --- a/datajunction-server/datajunction_server/internal/access/authorization.py +++ 
/dev/null @@ -1,75 +0,0 @@ -""" -Authorization related functionality -""" - -from typing import Iterable, List, Union - -from datajunction_server.database.node import Node, NodeRevision -from datajunction_server.database.user import User -from datajunction_server.models.access import ( - AccessControl, - AccessControlStore, - ResourceRequest, - ValidateAccessFn, -) -from datajunction_server.models.user import UserOutput - - -def validate_access_requests( - validate_access: ValidateAccessFn, - user: User, - resource_requests: Iterable[ResourceRequest], - raise_: bool = False, -) -> List[Union[NodeRevision, Node, ResourceRequest]]: - """ - Validate a set of access requests. Only approved requests are returned. - """ - if user is None: - return list(resource_requests) # pragma: no cover - access_control = AccessControlStore( - validate_access=validate_access, - user=UserOutput( - id=user.id, - username=user.username, - oauth_provider=user.oauth_provider, - ), - ) - - for request in resource_requests: - access_control.add_request(request) - - validation_results = access_control.validate() - if raise_: - access_control.raise_if_invalid_requests() # pragma: no cover - return [result for result in validation_results if result.approved] - - -def validate_access() -> ValidateAccessFn: - """ - A placeholder validate access dependency injected function - that returns a ValidateAccessFn that approves all requests - """ - - def _(access_control: AccessControl): - """ - Examines all requests in the AccessControl - and approves or denies each - - Args: - access_control (AccessControl): The access control object - containing the access control state and requests. - - Example: - if access_control.state == 'direct': - access_control.approve_all() - return - - if access_control.user=='dj': - request.approve_all() - return - - request.deny_all() - """ - access_control.approve_all() - - return _ diff --git a/datajunction-server/datajunction_server/internal/access/authorization/__init__.py b/datajunction-server/datajunction_server/internal/access/authorization/__init__.py new file mode 100644 index 000000000..7daa449cb --- /dev/null +++ b/datajunction-server/datajunction_server/internal/access/authorization/__init__.py @@ -0,0 +1,31 @@ +"""All authorization functions.""" + +__all__ = [ + "AuthContext", + "get_auth_context", + "AccessChecker", + "get_access_checker", + "AccessDenialMode", + "AuthorizationService", + "RBACAuthorizationService", + "PassthroughAuthorizationService", + "get_authorization_service", +] + +from datajunction_server.internal.access.authorization.context import ( + AuthContext, + get_auth_context, +) + +from datajunction_server.internal.access.authorization.validator import ( + AccessChecker, + get_access_checker, + AccessDenialMode, +) + +from datajunction_server.internal.access.authorization.service import ( + AuthorizationService, + RBACAuthorizationService, + PassthroughAuthorizationService, + get_authorization_service, +) diff --git a/datajunction-server/datajunction_server/internal/access/authorization/context.py b/datajunction-server/datajunction_server/internal/access/authorization/context.py new file mode 100644 index 000000000..7eec35b43 --- /dev/null +++ b/datajunction-server/datajunction_server/internal/access/authorization/context.py @@ -0,0 +1,130 @@ +""" +Authorization context for a user, pre-loaded with all roles. 
+""" + +from fastapi import Depends +from dataclasses import dataclass +from typing import List, Optional + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + + +from datajunction_server.internal.access.group_membership import ( + get_group_membership_service, +) +from datajunction_server.database.rbac import RoleAssignment, Role +from datajunction_server.database.user import User +from datajunction_server.utils import ( + get_current_user, + get_session, + get_settings, +) + +settings = get_settings() + + +@dataclass(frozen=True) +class AuthContext: + """ + Authorization context for a user. + + Contains all data needed to make authorization decisions, + pre-loaded and ready for fast in-memory checks. + + This separates authorization data from the full User model, + allowing for clean caching, testing, and type safety. + """ + + user_id: int + username: str + oauth_provider: Optional[str] + role_assignments: List[RoleAssignment] # Direct + groups, flattened + + @classmethod + async def from_user( + cls, + session: AsyncSession, + user: User, + ) -> "AuthContext": + """ + Build authorization context from a User object. + + This loads all effective role assignments (direct + group-based) + for the user using the configured GroupMembershipService. + + Args: + session: db session + user: user to build context for + + Returns: + AuthContext ready for authorization checks + """ + assignments = await cls.get_effective_assignments( + session=session, + user=user, + ) + + return cls( + user_id=user.id, + username=user.username, + oauth_provider=user.oauth_provider, + role_assignments=assignments, + ) + + @classmethod + async def get_effective_assignments( + cls, + session: AsyncSession, + user: User, + ) -> List[RoleAssignment]: + """ + Get all effective role assignments for a user (direct + group-based). + + Args: + session: db session + user: user to get assignments for + Returns: + list of all role assignments that apply to this user + """ + group_membership_service = get_group_membership_service() + + # Start with user's direct assignments + assignments = list(user.role_assignments) + + # Get groups from service (could be LDAP, local DB, etc.) 
+ group_usernames = await group_membership_service.get_user_groups( + session, + user.username, + ) + + if not group_usernames: + return assignments # No groups + + # Load groups from DJ database with their role_assignments + stmt = ( + select(User) + .where(User.username.in_(group_usernames)) + .options( + selectinload(User.role_assignments) + .selectinload(RoleAssignment.role) + .selectinload(Role.scopes), + ) + ) + result = await session.execute(stmt) + groups = result.scalars().all() + + # Flatten group assignments into the list + for group in groups: + assignments.extend(group.role_assignments) + + return assignments + + +async def get_auth_context( + session: AsyncSession = Depends(get_session), + current_user: User = Depends(get_current_user), +) -> AuthContext: + """Build authorization context with user + group assignments.""" + return await AuthContext.from_user(session, current_user) diff --git a/datajunction-server/datajunction_server/internal/access/authorization/service.py b/datajunction-server/datajunction_server/internal/access/authorization/service.py new file mode 100644 index 000000000..6abd28f12 --- /dev/null +++ b/datajunction-server/datajunction_server/internal/access/authorization/service.py @@ -0,0 +1,329 @@ +""" +Authorization service implementations for access control. +""" + +from abc import ABC, abstractmethod +from datetime import datetime, timezone +from functools import lru_cache +from typing import List + + +from datajunction_server.models.access import ( + AccessDecision, + ResourceAction, + ResourceRequest, + ResourceType, +) +from datajunction_server.internal.access.authorization.context import ( + AuthContext, +) +from datajunction_server.utils import ( + SEPARATOR, + get_settings, +) + +settings = get_settings() + + +class AuthorizationService(ABC): + """ + Abstract base class for authorization strategies. + + Authorization is performed on a pre-loaded authorization context. + + Implementations of this base class decide exactly how to authorize requests: + - RBACAuthorizationService: Uses pre-loaded roles/scopes (default) + - PassthroughAuthorizationService: Always approve (testing/permissive) + - Custom: Your own authorization logic + + Each implementation should define a `name` class attribute to register itself. + """ + + name: str # Subclasses must define this + + @abstractmethod + def authorize( + self, + auth_context: AuthContext, + requests: list[ResourceRequest], + ) -> list[AccessDecision]: + """ + Authorize resource requests for a user. + + This method should mutate the `approved` field on each request + to indicate whether access is granted. + + Args: + auth_context: Pre-loaded authorization context with all needed data + requests: List of resource requests to authorize + + Returns: + The same list of requests with approved=True/False set on each + """ + + +class RBACAuthorizationService(AuthorizationService): + """ + Default RBAC implementation using pre-loaded roles and scopes. + + This implementation: + 1. Works on AuthContext with pre-loaded role_assignments (direct + groups) + 2. Falls back to default_access_policy if no explicit rule exists + 3. Respects role expiration + 4. Synchronous - works on eagerly loaded data + + Group Membership Integration: + - Supports pluggable GroupMembershipService (LDAP, local DB, etc.) 
+ - Groups are loaded when building AuthContext via from_user() + - No DB queries during authorization - all data pre-loaded + """ + + name = "rbac" + + PERMISSION_HIERARCHY = { + ResourceAction.MANAGE: { + ResourceAction.MANAGE, + ResourceAction.DELETE, + ResourceAction.WRITE, + ResourceAction.EXECUTE, + ResourceAction.READ, + }, + ResourceAction.DELETE: { + ResourceAction.DELETE, + ResourceAction.WRITE, + ResourceAction.READ, + }, + ResourceAction.WRITE: { + ResourceAction.WRITE, + ResourceAction.READ, + }, + ResourceAction.EXECUTE: { + ResourceAction.EXECUTE, + ResourceAction.READ, + }, + ResourceAction.READ: { + ResourceAction.READ, + }, + } + + def authorize( + self, + auth_context: AuthContext, + requests: list[ResourceRequest], + ) -> list[AccessDecision]: + """ + Authorize using pre-loaded RBAC roles and scopes (sync). + + Args: + auth_context: Pre-loaded authorization context with role assignments + requests: Resource requests to authorize + + Returns: + Same list of requests with approved=True/False set + """ + return [self._make_decision(auth_context, request) for request in requests] + + def _make_decision( + self, + auth_context: AuthContext, + request: ResourceRequest, + ) -> AccessDecision: + """ + Convert ResourceRequest to AccessDecision. + """ + has_grant = self.has_permission( + assignments=auth_context.role_assignments, + action=request.verb, + resource_type=request.access_object.resource_type, + resource_name=request.access_object.name, + ) + return AccessDecision( + request=request, + approved=(has_grant or settings.default_access_policy == "permissive"), + ) + + @classmethod + def resource_matches_pattern(cls, resource_name: str, pattern: str) -> bool: + """ + Check if resource name matches a pattern with wildcard support. + + resource_matches_pattern("finance.revenue", "finance.*") --> True + resource_matches_pattern("finance.quarterly.revenue", "finance.*") --> True + resource_matches_pattern("users.alice.dashboard", "users.alice.*") --> True + resource_matches_pattern("marketing.revenue", "finance.*") --> False + resource_matches_pattern("anything", "*") --> True + resource_matches_pattern("finance", "finance.*") --> False + """ + if pattern == "*": + return True # Match everything + + if "*" not in pattern: + return resource_name == pattern # Exact match + + # Wildcard pattern: finance.* matches finance.revenue and finance.quarterly.revenue + # But NOT just "finance" (must have something after the dot) + pattern_prefix = pattern.rstrip("*").rstrip(SEPARATOR) + + if not pattern_prefix: + return True # Pattern was just "*" + + # Resource must start with pattern_prefix followed by a dot + # (not an exact match to pattern_prefix, that would be handled by exact pattern) + return resource_name.startswith(pattern_prefix + SEPARATOR) + + @classmethod + def has_permission( + cls, + assignments: List, + action: ResourceAction, + resource_type: ResourceType, + resource_name: str, + ) -> bool: + """ + Determine if a list of role assignments grants the requested permission. + + This method iterates through all provided role assignments, checking if any + grant the specified action on the given resource. Expired assignments are + automatically skipped. Returns True if at least one valid assignment grants + access, False otherwise. + + Args: + assignments: List of role assignments to check + action: The action being requested (READ, WRITE, etc.) + resource_type: Type of resource (NODE, NAMESPACE, etc.) 
+ resource_name: Full name/identifier of the resource + + Returns: + True if permission is granted, False otherwise + """ + for assignment in assignments: + # Skip expired assignments + if assignment.expires_at and assignment.expires_at < datetime.now( + timezone.utc, + ): + continue + + # Check each scope in the role + for scope in assignment.role.scopes: + # Check if scope grants permission for this resource + if cls._scope_grants_permission( + scope, + action, + resource_type, + resource_name, + ): + return True + + return False + + @classmethod + def _scope_grants_permission( + cls, + scope, + action: ResourceAction, + resource_type: ResourceType, + resource_name: str, + ) -> bool: + """ + Check if a scope grants permission for a resource. + + Handles: + 1. Permission hierarchy (MANAGE > DELETE > WRITE > READ, EXECUTE > READ) + 2. Empty/None scope_value or "*" = global access + 3. Wildcard pattern matching (finance.*) + 4. Cross-resource-type: namespace scope covers nodes in that namespace + """ + # Check permission hierarchy: does scope.action grant the requested action? + granted_actions = cls.PERMISSION_HIERARCHY.get(scope.action, {scope.action}) + if action not in granted_actions: + return False + + # Handle global access (empty string, None, or "*" scope_value) + if not scope.scope_value or scope.scope_value == "" or scope.scope_value == "*": + # Global scope matches any resource of the same type + return scope.scope_type == resource_type + + # Same resource type - use pattern matching + if scope.scope_type == resource_type: + return cls.resource_matches_pattern(resource_name, scope.scope_value) + + # Cross-resource-type: namespace scope can cover nodes + if ( + scope.scope_type == ResourceType.NAMESPACE + and resource_type == ResourceType.NODE + ): + # Check if node name matches the namespace pattern + return cls.resource_matches_pattern(resource_name, scope.scope_value) + + # No match + return False + + +class PassthroughAuthorizationService(AuthorizationService): + """ + Always approves all requests (for testing or permissive environments). + """ + + name = "passthrough" + + def authorize( + self, + auth_context: AuthContext, + requests: list[ResourceRequest], + ) -> list[AccessDecision]: + """Approve all requests without checks (sync).""" + return [AccessDecision(request=request, approved=True) for request in requests] + + +@lru_cache(maxsize=None) +def get_authorization_service() -> AuthorizationService: + """ + Factory function to get the configured authorization service. + + This is used as a FastAPI dependency. The service can be overridden + via app.dependency_overrides for testing or custom deployments. + + Built-in providers: + - "rbac": Role-based access control using roles/scopes tables (default) + - "passthrough": Always approve all requests + + Configure via environment variable: + ```bash + AUTHORIZATION_PROVIDER=rbac # or passthrough + ``` + + Custom providers can be added by: + 1. Subclassing AuthorizationService + 2. Defining a `name` class attribute + 3. 
Importing the class before app starts + + Example: + ```python + class ExampleAuthService(AuthorizationService): + name = "example" + + def authorize(self, user, requests): + # Your sync authorization logic + return requests + ``` + + Returns: + AuthorizationService implementation + + Raises: + ValueError: If the configured provider is unknown + """ + provider = getattr(settings, "authorization_provider", "rbac") + + # Discover all subclasses + providers = {} + for subclass in AuthorizationService.__subclasses__(): + providers[subclass.name] = subclass + if subclass.name == provider: + return subclass() # type: ignore[abstract] + + available = ", ".join(sorted(providers.keys())) + raise ValueError( + f"Unknown authorization_provider: '{provider}'. " + f"Available providers: {available}", + ) diff --git a/datajunction-server/datajunction_server/internal/access/authorization/validator.py b/datajunction-server/datajunction_server/internal/access/authorization/validator.py new file mode 100644 index 000000000..4ffab7dd0 --- /dev/null +++ b/datajunction-server/datajunction_server/internal/access/authorization/validator.py @@ -0,0 +1,164 @@ +""" +Access validation collection and helper functions. +""" + +from fastapi import Depends +from enum import Enum + + +from datajunction_server.internal.access.authorization.service import ( + get_authorization_service, +) +from datajunction_server.database.node import Node +from datajunction_server.models.access import ( + AccessDecision, + Resource, + ResourceAction, + ResourceRequest, + ResourceType, +) +from datajunction_server.internal.access.authorization.context import ( + AuthContext, + get_auth_context, +) +from datajunction_server.utils import ( + get_settings, +) + +settings = get_settings() + + +class AccessDenialMode(Enum): + """ + How to handle denied access requests. 
+ """ + + FILTER = "filter" # Return only approved requests + RAISE = "raise" # Raise exception if any denied + RETURN = "return" # Return all requests with approved field set + + +class AccessChecker: + """Collects authorization requests and validates them.""" + + def __init__(self, auth_context: AuthContext): + self.auth_context = auth_context + self.requests: list[ResourceRequest] = [] + + def add_request(self, request: ResourceRequest): + """Add a request to check.""" + self.requests.append(request) + + def add_requests(self, requests: list[ResourceRequest]): + """Add requests to check.""" + self.requests.extend(requests) + + @classmethod + def resource_request_from_node( + cls, + node: Node, + action: ResourceAction, + ) -> ResourceRequest: + """Create ResourceRequest from a Node.""" + return ResourceRequest( + verb=action, + access_object=Resource.from_node(node), + ) + + def add_request_by_node_name(self, node_name: str, action: ResourceAction): + """Add request by node name.""" + self.requests.append( + ResourceRequest( + verb=action, + access_object=Resource(name=node_name, resource_type=ResourceType.NODE), + ), + ) + + def add_node(self, node: Node, action: ResourceAction): + """Add request for a node.""" + node_request = self.resource_request_from_node(node, action) + self.add_request(node_request) + + def add_nodes(self, nodes: list[Node], action: ResourceAction): + """Add requests for multiple nodes.""" + self.requests.extend( + self.resource_request_from_node(node, action) for node in nodes + ) + + @classmethod + def resource_request_from_namespace( + cls, + namespace: str, + action: ResourceAction, + ) -> ResourceRequest: + """Create ResourceRequest from a namespace.""" + return ResourceRequest( + verb=action, + access_object=Resource.from_namespace(namespace), + ) + + def add_namespace(self, namespace: str, action: ResourceAction): + """Add request for a namespace.""" + namespace_request = self.resource_request_from_namespace(namespace, action) + self.add_request(namespace_request) + + def add_namespaces(self, namespaces: list[str], action: ResourceAction): + """Add requests for multiple namespaces.""" + self.requests.extend( + self.resource_request_from_namespace(namespace, action) + for namespace in namespaces + ) + + async def check( + self, + on_denied: AccessDenialMode = AccessDenialMode.FILTER, + ) -> list[AccessDecision]: + """ + Validate all requests using AuthorizationService. 
+ + Args: + on_denied: How to handle denied requests + - FILTER: Return only approved (default) + - RAISE: Raise exception if any denied + - RETURN_ALL: Return all with approved field set + """ + auth_service = get_authorization_service() + access_decisions = auth_service.authorize(self.auth_context, self.requests) + + if on_denied == AccessDenialMode.RETURN: + return access_decisions + elif on_denied == AccessDenialMode.RAISE: + denied: list[AccessDecision] = [ + decision for decision in access_decisions if not decision.approved + ] + if denied: + from datajunction_server.errors import DJAuthorizationException + + # Show first 5 denied resources + denied_names = [d.request.access_object.name for d in denied[:5]] + more_count = max(0, len(denied) - 5) + + raise DJAuthorizationException( + message=( + f"Access denied to {len(denied)} resource(s): " + f"{', '.join(denied_names)}" + + (f" and {more_count} more" if more_count else "") + ), + ) + return access_decisions + # Default: FILTER + return [decision for decision in access_decisions if decision.approved] + + async def approved_resource_names(self) -> list[str]: + """Get approved resource names.""" + return [ + decision.request.access_object.name + for decision in await self.check(on_denied=AccessDenialMode.FILTER) + ] + + +def get_access_checker( + auth_context: AuthContext = Depends(get_auth_context), +) -> AccessChecker: + """Provide AccessChecker with pre-loaded context.""" + return AccessChecker(auth_context) diff --git a/datajunction-server/datajunction_server/internal/caching/query_cache_manager.py b/datajunction-server/datajunction_server/internal/caching/query_cache_manager.py index 8e431d774..0dc70501f 100644 --- a/datajunction-server/datajunction_server/internal/caching/query_cache_manager.py +++ b/datajunction-server/datajunction_server/internal/caching/query_cache_manager.py @@ -12,11 +12,14 @@ QueryBuildType, VersionedQueryKey, ) +from datajunction_server.internal.access.authorization import ( + AccessChecker, + AuthContext, + get_access_checker, +) from datajunction_server.internal.sql import build_sql_for_multiple_metrics -from datajunction_server.database.user import User -from datajunction_server.models import access from datajunction_server.models.sql import GeneratedSQL -from datajunction_server.utils import session_context, get_settings +from datajunction_server.utils import get_current_user, session_context, get_settings from datajunction_server.internal.sql import get_measures_query from datajunction_server.internal.sql import build_node_sql from datajunction_server.internal.engines import get_engine @@ -40,8 +43,6 @@ class QueryRequestParams: limit: int | None = None orderby: list[str] | None = None other_args: dict[str, Any] | None = None - current_user: User | None = None - validate_access: access.ValidateAccessFn | None = None include_all_columns: bool = False use_materialized: bool = False preaggregate: bool = False @@ -58,6 +59,16 @@ def __repr__(self): ) +async def build_access_checker_from_request( + request: Request, + session: AsyncSession, +) -> AccessChecker: + """Helper to build checker from request + session.""" + current_user = await get_current_user(request) + auth_context = await AuthContext.from_user(session, current_user) + return get_access_checker(auth_context) + + class QueryCacheManager(RefreshAheadCacheManager): """ A generic manager for handling caching operations. 
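The endpoint hunks earlier in this diff all repeat the same pattern enabled by the `validator.py` module above: build an `AccessChecker` from the request's auth context via `get_access_checker`, queue one `ResourceRequest` per resource the endpoint touches, then either raise on denial or filter the results. Below is a minimal sketch of that pattern under stated assumptions: the route and `APIRouter` are hypothetical (the real endpoints use `SecureAPIRouter`), while `AccessChecker`, `get_access_checker`, `AccessDenialMode`, and `ResourceAction` are the names introduced by this diff.

```python
from fastapi import APIRouter, Depends

from datajunction_server.internal.access.authorization import (
    AccessChecker,
    AccessDenialMode,
    get_access_checker,
)
from datajunction_server.models.access import ResourceAction

router = APIRouter()  # hypothetical; the real routes are registered on SecureAPIRouter


@router.get("/example/{name}")
async def read_example(
    name: str,
    access_checker: AccessChecker = Depends(get_access_checker),
) -> dict:
    """Hypothetical endpoint showing the new authorization flow."""
    # Queue one request per resource this endpoint reads.
    access_checker.add_request_by_node_name(name, ResourceAction.READ)

    # Fail fast with a DJAuthorizationException if anything is denied ...
    await access_checker.check(on_denied=AccessDenialMode.RAISE)

    # ... or, for list-style endpoints, keep only the approved resources:
    # accessible = await access_checker.approved_resource_names()

    return {"node": name}
```

Because the checker only collects requests until `check()` is awaited, an endpoint can accumulate node and namespace requests from several code paths and authorize them in a single pass against the pre-loaded `AuthContext`, with no extra database queries at decision time.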
@@ -85,39 +96,32 @@ async def fallback( """ params = deepcopy(params) async with session_context(request) as session: + access_checker = await build_access_checker_from_request(request, session) params.nodes = list(OrderedDict.fromkeys(params.nodes)) query_parameters = ( json.loads(params.query_params) if params.query_params else {} ) - access_control_store = ( - access.AccessControlStore( - validate_access=params.validate_access, - user=params.current_user, - base_verb=access.ResourceAction.READ, - ) - if params.validate_access - else None - ) match self.query_type: case QueryBuildType.MEASURES: return await self._build_measures_query( session, params, query_parameters, + access_checker, ) case QueryBuildType.NODE: return await self._build_node_query( session, params, query_parameters, - access_control_store, + access_checker, ) case QueryBuildType.METRICS: # pragma: no cover return await self._build_metrics_query( session, params, query_parameters, - access_control_store, + access_checker, ) async def build_cache_key( @@ -155,6 +159,7 @@ async def _build_measures_query( session: AsyncSession, params: QueryRequestParams, query_parameters: dict[str, Any], + access_checker: AccessChecker, ) -> list[GeneratedSQL]: return await get_measures_query( session=session, @@ -164,8 +169,7 @@ async def _build_measures_query( orderby=params.orderby or [], engine_name=params.engine_name, engine_version=params.engine_version, - current_user=params.current_user, - validate_access=params.validate_access, + access_checker=access_checker, include_all_columns=params.include_all_columns, use_materialized=params.use_materialized, preagg_requested=params.preaggregate, @@ -177,7 +181,7 @@ async def _build_node_query( session: AsyncSession, params: QueryRequestParams, query_parameters: dict[str, Any], - access_control_store: access.AccessControlStore | None = None, + access_checker: AccessChecker, ) -> TranslatedSQL: engine = ( await get_engine(session, params.engine_name, params.engine_version) # type: ignore @@ -195,7 +199,7 @@ async def _build_node_query( ignore_errors=params.ignore_errors, use_materialized=params.use_materialized, query_parameters=query_parameters, - access_control=access_control_store, + access_checker=access_checker, ) return TranslatedSQL.create( sql=built_sql.sql, @@ -208,7 +212,7 @@ async def _build_metrics_query( session: AsyncSession, params: QueryRequestParams, query_parameters: dict[str, Any], - access_control_store: access.AccessControlStore | None = None, + access_checker: AccessChecker, ) -> TranslatedSQL: built_sql, _, _ = await build_sql_for_multiple_metrics( session=session, @@ -219,7 +223,7 @@ async def _build_metrics_query( limit=params.limit, engine_name=params.engine_name, engine_version=params.engine_version, - access_control=access_control_store, + access_checker=access_checker, ignore_errors=params.ignore_errors, # type: ignore query_parameters=query_parameters, use_materialized=params.use_materialized, # type: ignore diff --git a/datajunction-server/datajunction_server/internal/deployment/utils.py b/datajunction-server/datajunction_server/internal/deployment/utils.py index 7e77c4bf7..c9aeac4ac 100644 --- a/datajunction-server/datajunction_server/internal/deployment/utils.py +++ b/datajunction-server/datajunction_server/internal/deployment/utils.py @@ -6,7 +6,6 @@ from datajunction_server.internal.caching.interface import Cache from datajunction_server.service_clients import QueryServiceClient from datajunction_server.database.user import User -from 
datajunction_server.models import access from datajunction_server.models.deployment import ( NodeSpec, CubeSpec, @@ -114,6 +113,5 @@ class DeploymentContext: current_user: User request: Request query_service_client: QueryServiceClient - validate_access: access.ValidateAccessFn background_tasks: BackgroundTasks cache: Cache diff --git a/datajunction-server/datajunction_server/internal/materializations.py b/datajunction-server/datajunction_server/internal/materializations.py index b6705da5b..e0dd124df 100644 --- a/datajunction-server/datajunction_server/internal/materializations.py +++ b/datajunction-server/datajunction_server/internal/materializations.py @@ -20,7 +20,7 @@ build_cube_materialization, ) from datajunction_server.materialization.jobs import MaterializationJob -from datajunction_server.models import access +from datajunction_server.internal.access.authorization import AccessChecker from datajunction_server.models.column import SemanticType from datajunction_server.models.cube_materialization import UpsertCubeMaterialization from datajunction_server.models.materialization import ( @@ -104,7 +104,7 @@ async def build_cube_materialization_config( session: AsyncSession, current_revision: NodeRevision, upsert_input: UpsertMaterialization, - validate_access: access.ValidateAccessFn, + access_checker: AccessChecker, current_user: User, ) -> DruidMeasuresCubeConfig: """ @@ -129,6 +129,7 @@ async def build_cube_materialization_config( metrics=[node.name for node in current_revision.cube_metrics()], dimensions=current_revision.cube_dimensions(), use_materialized=False, + access_checker=access_checker, ) generic_config = DruidMetricsCubeConfig( lookback_window=upsert_input.config.lookback_window, @@ -156,8 +157,7 @@ async def build_cube_materialization_config( metrics=[node.name for node in current_revision.cube_metrics()], dimensions=current_revision.cube_dimensions(), filters=[], - current_user=current_user, - validate_access=validate_access, + access_checker=access_checker, ) for measures_query in measures_queries: metrics_expressions = await rewrite_metrics_expressions( @@ -192,7 +192,7 @@ async def build_cube_materialization_config( f"node `{current_revision.name}` and job " f"`{upsert_input.job.name}` as" # type: ignore " the config does not have valid configuration for " - f"engine `{upsert_input.job.name}`." + f"engine `{upsert_input.job.name}`." 
# type: ignore ), ) from exc @@ -238,7 +238,7 @@ async def create_new_materialization( session: AsyncSession, current_revision: NodeRevision, upsert: UpsertCubeMaterialization | UpsertMaterialization, - validate_access: access.ValidateAccessFn, + access_checker: AccessChecker, current_user: User, ) -> Materialization: """ @@ -284,7 +284,7 @@ async def create_new_materialization( session, current_revision, upsert, - validate_access, + access_checker, current_user=current_user, ) materialization_name = ( diff --git a/datajunction-server/datajunction_server/internal/nodes.py b/datajunction-server/datajunction_server/internal/nodes.py index e14eac31c..19254077f 100644 --- a/datajunction-server/datajunction_server/internal/nodes.py +++ b/datajunction-server/datajunction_server/internal/nodes.py @@ -12,6 +12,10 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload, selectinload + +from datajunction_server.internal.access.authorization import ( + AccessChecker, +) from datajunction_server.internal.caching.interface import Cache from datajunction_server.models.query import QueryCreate from datajunction_server.api.helpers import ( @@ -53,7 +57,6 @@ ) from datajunction_server.internal.history import ActivityType, EntityType from datajunction_server.internal.validation import NodeValidator, validate_node_data -from datajunction_server.models import access from datajunction_server.models.attribute import ( AttributeTypeIdentifier, ColumnAttributes, @@ -118,7 +121,7 @@ async def create_a_source_node( current_user: User, query_service_client: QueryServiceClient, background_tasks: BackgroundTasks, - validate_access: access.ValidateAccessFn, + access_checker: AccessChecker, save_history: Callable, ): request_headers = dict(request.headers) @@ -132,7 +135,7 @@ async def create_a_source_node( current_user=current_user, request_headers=request_headers, query_service_client=query_service_client, - validate_access=validate_access, + access_checker=access_checker, background_tasks=background_tasks, save_history=save_history, ): @@ -212,7 +215,7 @@ async def create_a_node( current_user: User, query_service_client: QueryServiceClient, background_tasks: BackgroundTasks, - validate_access: access.ValidateAccessFn, + access_checker: AccessChecker, save_history: Callable, cache: Cache, ) -> Node: @@ -231,7 +234,7 @@ async def create_a_node( request_headers=request_headers, query_service_client=query_service_client, background_tasks=background_tasks, - validate_access=validate_access, + access_checker=access_checker, save_history=save_history, cache=cache, ): @@ -303,7 +306,7 @@ async def create_a_cube( current_user: User, query_service_client: QueryServiceClient, background_tasks: BackgroundTasks, - validate_access: access.ValidateAccessFn, + access_checker: AccessChecker, save_history: Callable, ) -> Node: request_headers = dict(request.headers) @@ -318,7 +321,7 @@ async def create_a_cube( request_headers=request_headers, query_service_client=query_service_client, background_tasks=background_tasks, - validate_access=validate_access, + access_checker=access_checker, save_history=save_history, ): return recreated_node # pragma: no cover @@ -887,7 +890,7 @@ async def update_any_node( current_user: User, save_history: Callable, background_tasks: BackgroundTasks = None, - validate_access: access.ValidateAccessFn = None, + access_checker: AccessChecker = None, refresh_materialization: bool = False, cache: Cache | None = None, ) -> Node: @@ -907,15 +910,6 @@ async def update_any_node( ) 
node = cast(Node, node) - # Check that the user has access to modify this node - access_control = access.AccessControlStore( - validate_access=validate_access, - user=current_user, - base_verb=access.ResourceAction.WRITE, - ) - access_control.add_request_by_node(node) - access_control.validate_and_raise() - if data.owners and data.owners != [owner.username for owner in node.owners]: await update_owners(session, node, data.owners, current_user, save_history) @@ -929,7 +923,7 @@ async def update_any_node( query_service_client=query_service_client, current_user=current_user, background_tasks=background_tasks, - validate_access=validate_access, # type: ignore + access_checker=access_checker, # type: ignore save_history=save_history, refresh_materialization=refresh_materialization, ) @@ -942,7 +936,7 @@ async def update_any_node( query_service_client=query_service_client, current_user=current_user, background_tasks=background_tasks, - validate_access=validate_access, # type: ignore + access_checker=access_checker, # type: ignore save_history=save_history, cache=cache, ) @@ -957,7 +951,7 @@ async def update_node_with_query( query_service_client: QueryServiceClient, current_user: User, background_tasks: BackgroundTasks, - validate_access: access.ValidateAccessFn, + access_checker: AccessChecker, save_history: Callable, cache: Cache, ) -> Node: @@ -1049,7 +1043,7 @@ async def update_node_with_query( lookback_window=old.lookback_window, ) ), - validate_access, + access_checker, current_user=current_user, ), ) @@ -1193,7 +1187,7 @@ async def update_cube_node( query_service_client: QueryServiceClient, current_user: User, background_tasks: BackgroundTasks = None, - validate_access: access.ValidateAccessFn, + access_checker: AccessChecker, save_history: Callable, refresh_materialization: bool = False, ) -> Optional[NodeRevision]: @@ -1296,7 +1290,7 @@ async def update_cube_node( ), job=MaterializationJobTypeEnum.find_match(old.job).value.name, ), - validate_access, + access_checker, current_user=current_user, ), ) @@ -1509,7 +1503,7 @@ async def create_node_from_inactive( query_service_client: QueryServiceClient, save_history: Callable, background_tasks: BackgroundTasks = None, - validate_access: access.ValidateAccessFn = None, + access_checker: AccessChecker = None, cache: Cache | None = None, ) -> Optional[Node]: """ @@ -1559,7 +1553,7 @@ async def create_node_from_inactive( query_service_client=query_service_client, current_user=current_user, background_tasks=background_tasks, - validate_access=validate_access, # type: ignore + access_checker=access_checker, # type: ignore save_history=save_history, cache=cache, ) @@ -1572,7 +1566,7 @@ async def create_node_from_inactive( query_service_client=query_service_client, current_user=current_user, background_tasks=background_tasks, - validate_access=validate_access, # type: ignore + access_checker=access_checker, # type: ignore save_history=save_history, ) try: diff --git a/datajunction-server/datajunction_server/internal/sql.py b/datajunction-server/datajunction_server/internal/sql.py index b3d9e0b50..da0ab8a82 100644 --- a/datajunction-server/datajunction_server/internal/sql.py +++ b/datajunction-server/datajunction_server/internal/sql.py @@ -5,6 +5,10 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload +from datajunction_server.internal.access.authorization import ( + AccessChecker, + AccessDenialMode, +) from datajunction_server.api.helpers import ( assemble_column_metadata, find_existing_cube, @@ -29,7 +33,6 @@ ) 
from datajunction_server.database import Engine from datajunction_server.database.node import Node, NodeRevision -from datajunction_server.database.user import User from datajunction_server.database.catalog import Catalog from datajunction_server.errors import DJInvalidInputException, DJException from datajunction_server.internal.engines import get_engine @@ -58,7 +61,8 @@ async def build_node_sql( orderby: list[str] | None = None, limit: int | None = None, engine: Engine | None = None, - access_control: access.AccessControlStore | None = None, + *, + access_checker: AccessChecker, ignore_errors: bool = True, use_materialized: bool = True, query_parameters: dict[str, Any] | None = None, @@ -95,7 +99,7 @@ async def build_node_sql( limit=limit, engine_name=engine.name if engine else None, engine_version=engine.version if engine else None, - access_control=access_control, + access_checker=access_checker, use_materialized=use_materialized, query_parameters=query_parameters, ) @@ -113,7 +117,7 @@ async def build_node_sql( limit, engine.name if engine else None, engine.version if engine else None, - access_control=access_control, + access_checker=access_checker, ignore_errors=ignore_errors, use_materialized=use_materialized, query_parameters=query_parameters, @@ -129,7 +133,7 @@ async def build_node_sql( orderby=orderby or [], limit=limit, engine=engine, - access_control=access_control, + access_checker=access_checker, use_materialized=use_materialized, query_parameters=query_parameters, ignore_errors=ignore_errors, @@ -156,7 +160,7 @@ async def build_sql_for_multiple_metrics( limit: int | None = None, engine_name: str | None = None, engine_version: str | None = None, - access_control: access.AccessControlStore | None = None, + access_checker: AccessChecker | None = None, ignore_errors: bool = True, use_materialized: bool = True, query_parameters: dict[str, str] | None = None, @@ -184,8 +188,8 @@ async def build_sql_for_multiple_metrics( ), ], ) - if access_control: - access_control.add_request_by_node(leading_metric_node.current) # type: ignore + if access_checker: + access_checker.add_node(leading_metric_node.current, access.ResourceAction.READ) # type: ignore available_engines = ( leading_metric_node.current.catalog.engines # type: ignore if leading_metric_node.current.catalog # type: ignore @@ -233,10 +237,9 @@ async def build_sql_for_multiple_metrics( validate_orderby(orderby, metrics, dimensions) if cube and cube.availability and use_materialized and materialized_cube_catalog: - if access_control: # pragma: no cover - access_control.add_request_by_node(cube) - access_control.state = access.AccessControlState.INDIRECT - access_control.raise_if_invalid_requests() + if access_checker: # pragma: no cover + access_checker.add_node(cube, access.ResourceAction.READ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) query_ast = build_materialized_cube_node( metric_columns, dimension_columns, @@ -284,10 +287,15 @@ async def build_sql_for_multiple_metrics( dimensions=dimensions or [], orderby=orderby or [], limit=limit, - access_control=access_control, + access_checker=access_checker, ignore_errors=ignore_errors, query_parameters=query_parameters, ) + + # Check authorization for all discovered nodes + if access_checker: # pragma: no cover + await access_checker.check(on_denied=AccessDenialMode.RAISE) + columns = [ assemble_column_metadata(col, use_semantic_metadata=True) # type: ignore for col in query_ast.select.projection @@ -322,8 +330,7 @@ async def get_measures_query( orderby: 
list[str] = None, engine_name: str | None = None, engine_version: str | None = None, - current_user: User | None = None, - validate_access: access.ValidateAccessFn = None, + access_checker: AccessChecker = None, include_all_columns: bool = False, use_materialized: bool = True, preagg_requested: bool = False, @@ -343,15 +350,6 @@ async def get_measures_query( build_criteria = BuildCriteria( dialect=engine.dialect if engine and engine.dialect else Dialect.SPARK, ) - access_control = ( - access.AccessControlStore( - validate_access=validate_access, - user=current_user, - base_verb=access.ResourceAction.READ, - ) - if validate_access - else None - ) if not filters: filters = [] @@ -405,7 +403,7 @@ async def get_measures_query( ) parent_ast = await ( query_builder.ignore_errors() - .with_access_control(access_control) + .with_access_control(access_checker) .with_build_criteria(build_criteria) .add_dimensions(dimensions) .add_filters(filters) diff --git a/datajunction-server/datajunction_server/models/access.py b/datajunction-server/datajunction_server/models/access.py index fea00d6bd..a74a99d9b 100644 --- a/datajunction-server/datajunction_server/models/access.py +++ b/datajunction-server/datajunction_server/models/access.py @@ -2,21 +2,10 @@ Models for authorization """ -from copy import deepcopy -from enum import Enum -from datajunction_server.typing import StrEnum -from typing import TYPE_CHECKING, Callable, Iterable, Optional, Set, Union - -from pydantic import BaseModel, Field -from sqlalchemy.ext.asyncio import AsyncSession +from dataclasses import dataclass -from datajunction_server.construction.utils import try_get_dj_node +from datajunction_server.typing import StrEnum from datajunction_server.database.node import Node, NodeRevision -from datajunction_server.errors import DJAuthorizationException, DJError, ErrorCode -from datajunction_server.models.user import UserOutput - -if TYPE_CHECKING: - from datajunction_server.sql.parsing.ast import Column class ResourceType(StrEnum): @@ -40,39 +29,36 @@ class ResourceAction(StrEnum): MANAGE = "manage" # Grant/revoke permissions (RBAC-specific) -class Resource(BaseModel): +@dataclass(frozen=True) +class Resource: """ Base class for resource objects that are passed to injected validation logic """ - name: str # name of the node + name: str resource_type: ResourceType - owner: str def __hash__(self) -> int: - return hash((self.name, self.resource_type, self.owner)) + return hash((self.name, self.resource_type)) @classmethod - def from_node(cls, node: Union[NodeRevision, Node]) -> "Resource": + def from_node(cls, node: NodeRevision | Node) -> "Resource": """ Create a resource object from a DJ Node """ - return cls(name=node.name, resource_type=ResourceType.NODE, owner="") + return cls(name=node.name, resource_type=ResourceType.NODE) @classmethod def from_namespace(cls, namespace: str) -> "Resource": """ Create a resource object from a namespace """ - return cls( - name=namespace, - resource_type=ResourceType.NAMESPACE, - owner="", - ) + return cls(name=namespace, resource_type=ResourceType.NAMESPACE) -class ResourceRequest(BaseModel): +@dataclass(frozen=True) +class ResourceRequest: """ Resource Requests provide the information that is available to grant access to a resource @@ -80,19 +66,6 @@ class ResourceRequest(BaseModel): verb: ResourceAction access_object: Resource - approved: Optional[bool] = None - - def approve(self): - """ - Approve the request - """ - self.approved = True - - def deny(self): - """ - Deny the request - """ - 
self.approved = False def __hash__(self) -> int: return hash((self.verb, self.access_object)) @@ -101,188 +74,24 @@ def __eq__(self, other) -> bool: return self.verb == other.verb and self.access_object == other.access_object def __str__(self) -> str: - return ( # pragma: no cover + return ( f"{self.verb.value}:" f"{self.access_object.resource_type.value}/" f"{self.access_object.name}" ) -class AccessControlState(Enum): - """ - State values used by the ACS function to track when - """ - - DIRECT = "direct" - INDIRECT = "indirect" - - -class AccessControl(BaseModel): - """ - An access control provides all the information - necessary to deny or approve a request +@dataclass(frozen=True) +class AccessDecision: """ + The result of an access control check for a resource request. - user: str - state: AccessControlState - direct_requests: Set[ResourceRequest] - indirect_requests: Set[ResourceRequest] - validation_request_count: int - - @property - def requests(self) -> Set[ResourceRequest]: - """ - Get all direct and indirect requests as a single set - """ - return self.direct_requests | self.indirect_requests - - def approve_all(self): - """ - Approve all requests - """ - for request in self.requests: - request.approve() - - def deny_all(self): - """ - Deny all requests - """ - for request in self.requests: - request.deny() - - -ValidateAccessFn = Callable[[AccessControl], None] - - -class AccessControlStore(BaseModel): + Attributes: + request: The resource request that was checked + approved: Whether access was granted + reason: Optional explanation if access was denied """ - An access control store tracks all ResourceRequests - """ - - validate_access: Callable[["AccessControl"], bool] - user: Optional[UserOutput] - base_verb: Optional[ResourceAction] = None - state: AccessControlState = AccessControlState.DIRECT - direct_requests: Set[ResourceRequest] = Field(default_factory=set) - indirect_requests: Set[ResourceRequest] = Field(default_factory=set) - validation_request_count: int = 0 - validation_results: Set[ResourceRequest] = Field(default_factory=set) - - def add_request(self, request: ResourceRequest): - """ - Add a resource request to the store - """ - if self.state == AccessControlState.DIRECT: - self.direct_requests.add(request) - else: - self.indirect_requests.add(request) # pragma: no cover - - async def add_request_by_node_name( - self, - session: AsyncSession, - node_name: Union[str, "Column"], - verb: Optional[ResourceAction] = None, - ): - """ - Add a request using a node's name - """ - node = await try_get_dj_node(session, node_name) - if node is not None: - self.add_request_by_node(node, verb) - return node - - def add_request_by_node( - self, - node: Union[NodeRevision, Node], - verb: Optional[ResourceAction] = None, - ): - """ - Add a request using a node - """ - self.add_request( - ResourceRequest( - verb=verb or self.base_verb, - access_object=Resource.from_node(node), - ), - ) - - def add_request_by_nodes( - self, - nodes: Iterable[Union[NodeRevision, Node]], - verb: Optional[ResourceAction] = None, - ): - """ - Add a request using a node - """ - for node in nodes: # pragma: no cover - self.add_request( # pragma: no cover - ResourceRequest( - verb=verb or self.base_verb, - access_object=Resource.from_node(node), - ), - ) - def raise_if_invalid_requests(self, show_denials: bool = True): - """ - Raises if validate has ever given any invalid requests - """ - denied = ", ".join( - [ - str(request) - for request in self.validation_results - if not request.approved - ], - ) - 
if denied: - message = ( - f"Authorization of User `{self.user.username if self.user else 'no user'}` " - "for this request failed." - f"\nThe following requests were denied:\n{denied}." - if show_denials - else "" - ) - raise DJAuthorizationException( - errors=[ - DJError( - code=ErrorCode.UNAUTHORIZED_ACCESS, - message=message, - ), - ], - ) - - def validate(self) -> Set[ResourceRequest]: - """ - Checks with ACS and stores any returned invalid requests - """ - self.validation_request_count += 1 - - access_control = AccessControl( - user=self.user.username if self.user is not None else "", - state=self.state, - direct_requests=deepcopy(self.direct_requests), - indirect_requests=deepcopy(self.indirect_requests), - validation_request_count=self.validation_request_count, - ) - - self.validate_access(access_control) # type: ignore - - self.validation_results = access_control.requests - - if any((result.approved is None for result in self.validation_results)): - raise DJAuthorizationException( - errors=[ - DJError( - code=ErrorCode.INCOMPLETE_AUTHORIZATION, - message="Injected `validate_access` must approve or deny all requests.", - ), - ], - ) - - return self.validation_results - - def validate_and_raise(self): - """ - Validates with ACS and raises if any resources were denied - """ - self.validate() - self.raise_if_invalid_requests() + request: ResourceRequest + approved: bool + reason: str | None = None diff --git a/datajunction-server/tests/api/access_test.py b/datajunction-server/tests/api/access_test.py index 6d7749c05..64ba10f69 100644 --- a/datajunction-server/tests/api/access_test.py +++ b/datajunction-server/tests/api/access_test.py @@ -1,13 +1,73 @@ """ -Tests for the data API. +Tests for access control across APIs. """ +from http import HTTPStatus import pytest from httpx import AsyncClient -from datajunction_server.api.main import app -from datajunction_server.internal.access.authorization import validate_access +from datajunction_server.internal.access.authorization import AuthorizationService from datajunction_server.models import access +from datajunction_server.models.access import ResourceType + + +class DenyAllAuthorizationService(AuthorizationService): + """ + Custom authorization service that denies all access. + """ + + name = "deny_all" + + def authorize(self, auth_context, requests): + return [ + access.AccessDecision(request=request, approved=False) + for request in requests + ] + + +class NamespaceOnlyAuthorizationService(AuthorizationService): + """ + Authorization service that allows namespace access but denies all node access. + """ + + name = "namespace_only" + + def __init__(self, allowed_namespaces: list[str]): + self.allowed_namespaces = allowed_namespaces + + def authorize(self, auth_context, requests): + decisions = [] + for request in requests: + approved = False + if request.access_object.resource_type == ResourceType.NAMESPACE: + # Allow access to specified namespaces + approved = request.access_object.name in self.allowed_namespaces + # Deny all NODE access + decisions.append(access.AccessDecision(request=request, approved=approved)) + return decisions + + +class PartialNodeAuthorizationService(AuthorizationService): + """ + Authorization service that allows access to specific namespaces and nodes. 
+ """ + + name = "partial_node" + + def __init__(self, allowed_namespaces: list[str], allowed_nodes: list[str]): + self.allowed_namespaces = allowed_namespaces + self.allowed_nodes = allowed_nodes + + def authorize(self, auth_context, requests): + decisions = [] + for request in requests: + approved = False + if request.access_object.resource_type == ResourceType.NAMESPACE: + approved = request.access_object.name in self.allowed_namespaces + elif request.access_object.resource_type == ResourceType.NODE: + approved = request.access_object.name in self.allowed_nodes + decisions.append(access.AccessDecision(request=request, approved=approved)) + return decisions class TestDataAccessControl: @@ -19,43 +79,43 @@ class TestDataAccessControl: async def test_get_metric_data_unauthorized( self, module__client_with_examples: AsyncClient, + mocker, ) -> None: """ Test retrieving data for a metric """ - def validate_access_override(): - def _validate_access(access_control: access.AccessControl): - access_control.deny_all() - - return _validate_access + def get_deny_all_service(): + return DenyAllAuthorizationService() - app.dependency_overrides[validate_access] = validate_access_override + mocker.patch( + "datajunction_server.internal.access.authorization.validator.get_authorization_service", + get_deny_all_service, + ) response = await module__client_with_examples.get("/data/basic.num_comments/") data = response.json() - assert "Authorization of User `dj` for this request failed" in data["message"] - assert "read:node/basic.num_comments" in data["message"] - assert "read:node/basic.source.comments" in data["message"] - assert response.status_code == 403 - app.dependency_overrides.clear() + assert "Access denied to" in data["message"] + assert "basic.num_comments" in data["message"] + assert response.status_code == HTTPStatus.FORBIDDEN @pytest.mark.asyncio async def test_sql_with_filters_orderby_no_access( self, module__client_with_examples: AsyncClient, + mocker, ): """ Test ``GET /sql/{node_name}/`` with various filters and dimensions using a version of the DJ roads database with namespaces. """ - def validate_access_override(): - def _validate_access(access_control: access.AccessControl): - access_control.deny_all() - - return _validate_access + def get_deny_all_service(): + return DenyAllAuthorizationService() - app.dependency_overrides[validate_access] = validate_access_override + mocker.patch( + "datajunction_server.internal.access.authorization.validator.get_authorization_service", + get_deny_all_service, + ) node_name = "foo.bar.num_repair_orders" dimensions = [ @@ -76,13 +136,110 @@ def _validate_access(access_control: access.AccessControl): params={"dimensions": dimensions, "filters": filters, "orderby": orderby}, ) data = response.json() - assert sorted(list(data["message"])) == sorted( - list( - "Authorization of User `dj` for this request failed." - "\nThe following requests were denied:\nread:node/foo.bar.dispatcher, " - "read:node/foo.bar.repair_orders, read:node/foo.bar.municipality_dim, " - "read:node/foo.bar.num_repair_orders, read:node/foo.bar.hard_hat.", - ), + assert "Access denied to" in data["message"] + assert "foo.bar" in data["message"] + assert response.status_code == HTTPStatus.FORBIDDEN + + +class TestNamespaceAccessControl: + """ + Test access control for the ``GET /namespaces/{namespace}/`` endpoint. 
+ """ + + @pytest.mark.asyncio + async def test_list_nodes_with_no_namespace_access( + self, + module__client_with_examples: AsyncClient, + mocker, + ): + """ + User with no namespace READ access should get empty list. + """ + + def get_deny_all_service(): + return DenyAllAuthorizationService() + + mocker.patch( + "datajunction_server.internal.access.authorization.validator.get_authorization_service", + get_deny_all_service, ) - assert data["errors"][0]["code"] == 500 - app.dependency_overrides.clear() + + response = await module__client_with_examples.get("/namespaces/default/") + assert response.status_code == HTTPStatus.OK + data = response.json() + assert data == [] + + @pytest.mark.asyncio + async def test_list_nodes_with_namespace_access_but_no_node_access( + self, + module__client_with_examples: AsyncClient, + mocker, + ): + """ + User with namespace READ access but no node READ access should get empty list. + """ + + def get_namespace_only_service(): + return NamespaceOnlyAuthorizationService(allowed_namespaces=["default"]) + + mocker.patch( + "datajunction_server.internal.access.authorization.validator.get_authorization_service", + get_namespace_only_service, + ) + + response = await module__client_with_examples.get("/namespaces/default/") + assert response.status_code == HTTPStatus.OK + data = response.json() + assert data == [] + + @pytest.mark.asyncio + async def test_list_nodes_with_partial_node_access( + self, + module__client_with_examples: AsyncClient, + mocker, + ): + """ + User with namespace access and partial node access should get filtered list. + """ + allowed_nodes = [ + "default.repair_orders", + "default.hard_hat", + ] + + def get_partial_service(): + return PartialNodeAuthorizationService( + allowed_namespaces=["default"], + allowed_nodes=allowed_nodes, + ) + + mocker.patch( + "datajunction_server.internal.access.authorization.validator.get_authorization_service", + get_partial_service, + ) + + response = await module__client_with_examples.get("/namespaces/default/") + assert response.status_code == HTTPStatus.OK + data = response.json() + + # Should only return the allowed nodes + returned_names = [node["name"] for node in data] + assert set(returned_names) == set(allowed_nodes) + + @pytest.mark.asyncio + async def test_list_nodes_with_full_access( + self, + module__client_with_examples: AsyncClient, + ): + """ + User with full access (PassthroughAuthorizationService) should get all nodes. + Default test client uses PassthroughAuthorizationService. 
+ """ + response = await module__client_with_examples.get("/namespaces/default/") + assert response.status_code == HTTPStatus.OK + data = response.json() + + # Should return multiple nodes (the roads example has many) + assert len(data) > 0 + # Verify we get node details + assert all("name" in node for node in data) + assert all(node["name"].startswith("default.") for node in data) diff --git a/datajunction-server/tests/api/dimensions_access_test.py b/datajunction-server/tests/api/dimensions_access_test.py index f235be485..8c7f1381d 100644 --- a/datajunction-server/tests/api/dimensions_access_test.py +++ b/datajunction-server/tests/api/dimensions_access_test.py @@ -5,30 +5,43 @@ import pytest from httpx import AsyncClient -from datajunction_server.api.main import app +from datajunction_server.internal.access.authorization import AuthorizationService +from datajunction_server.models import access + + +class RepairOnlyAuthorizationService(AuthorizationService): + """ + Authorization service that only approves nodes with 'repair' in the name. + """ + + name = "repair_only" + + def authorize(self, auth_context, requests): + return [ + access.AccessDecision( + request=request, + approved="repair" in request.access_object.name, + ) + for request in requests + ] @pytest.mark.asyncio async def test_list_nodes_with_dimension_access_limited( module__client_with_roads: AsyncClient, + mocker, ) -> None: """ Test ``GET /dimensions/{name}/nodes/``. """ - from datajunction_server.internal.access.authorization import validate_access - from datajunction_server.models import access - def validate_access_override(): - def _validate_access(access_control: access.AccessControl): - for request in access_control.requests: - if "repair" in request.access_object.name: - request.approve() - else: - request.deny() + def get_repair_only_service(): + return RepairOnlyAuthorizationService() - return _validate_access - - app.dependency_overrides[validate_access] = validate_access_override + mocker.patch( + "datajunction_server.internal.access.authorization.validator.get_authorization_service", + get_repair_only_service, + ) response = await module__client_with_roads.get( "/dimensions/default.hard_hat/nodes/", @@ -47,4 +60,3 @@ def _validate_access(access_control: access.AccessControl): "default.avg_repair_order_discounts", } assert {node["name"] for node in data} == roads_repair_nodes - app.dependency_overrides.clear() diff --git a/datajunction-server/tests/api/graphql/common_dimensions_test.py b/datajunction-server/tests/api/graphql/common_dimensions_test.py index 2948c2cc4..74cfdbe34 100644 --- a/datajunction-server/tests/api/graphql/common_dimensions_test.py +++ b/datajunction-server/tests/api/graphql/common_dimensions_test.py @@ -95,7 +95,7 @@ async def test_get_common_dimensions( "role": None, "type": "int", } in data["data"]["commonDimensions"] - assert len(capture_queries) <= 18 # type: ignore + assert len(capture_queries) <= 28 # type: ignore @pytest.mark.asyncio diff --git a/datajunction-server/tests/api/namespaces_test.py b/datajunction-server/tests/api/namespaces_test.py index 6be0ee497..16d1136ba 100644 --- a/datajunction-server/tests/api/namespaces_test.py +++ b/datajunction-server/tests/api/namespaces_test.py @@ -8,8 +8,9 @@ import pytest from httpx import AsyncClient -from datajunction_server.api.main import app -from datajunction_server.internal.access.authorization import validate_access +from datajunction_server.internal.access.authorization import ( + AuthorizationService, +) from 
datajunction_server.models import access @@ -831,28 +832,42 @@ async def test_export_namespaces_deployment(client_with_roads: AsyncClient): ] +class DbtOnlyAuthorizationService(AuthorizationService): + """ + Authorization service that only approves namespaces containing 'dbt'. + """ + + name = "dbt_only" + + def authorize(self, auth_context, requests): + return [ + access.AccessDecision( + request=request, + approved=( + request.access_object.resource_type == access.ResourceType.NAMESPACE + and "dbt" in request.access_object.name + ), + ) + for request in requests + ] + + @pytest.mark.asyncio async def test_list_all_namespaces_access_limited( client_with_dbt: AsyncClient, + mocker, ) -> None: """ Test ``GET /namespaces/``. """ - def validate_access_override(): - def _validate_access(access_control: access.AccessControl): - for request in access_control.requests: - if ( - request.access_object.resource_type == access.ResourceType.NAMESPACE - and "dbt" in request.access_object.name - ): - request.approve() - else: - request.deny() + def get_dbt_only_service(): + return DbtOnlyAuthorizationService() - return _validate_access - - app.dependency_overrides[validate_access] = validate_access_override + mocker.patch( + "datajunction_server.internal.access.authorization.validator.get_authorization_service", + get_dbt_only_service, + ) response = await client_with_dbt.get("/namespaces/") @@ -864,63 +879,42 @@ def _validate_access(access_control: access.AccessControl): {"namespace": "dbt.source.stripe", "num_nodes": 1}, {"namespace": "dbt.transform", "num_nodes": 1}, ] - app.dependency_overrides.clear() -@pytest.mark.asyncio -async def test_list_all_namespaces_access_bad_injection( - client_with_service_setup: AsyncClient, -) -> None: +class DenyAllAuthorizationService(AuthorizationService): """ - Test ``GET /namespaces/``. + Authorization service that denies all access requests. """ - def validate_access_override(): - def _validate_access(access_control: access.AccessControl): - for i, request in enumerate(access_control.requests): - if i != 0: - request.approve() - - return _validate_access + name = "deny_all" - app.dependency_overrides[validate_access] = validate_access_override - - response = await client_with_service_setup.get("/namespaces/") - - assert response.status_code == 403 - assert response.json() == { - "message": "Injected `validate_access` must approve or deny all requests.", - "errors": [ - { - "code": 501, - "message": "Injected `validate_access` must approve or deny all requests.", - "debug": None, - "context": "", - }, - ], - "warnings": [], - } - app.dependency_overrides.clear() + def authorize(self, auth_context, requests): + return [ + access.AccessDecision( + request=request, + approved=False, + ) + for request in requests + ] @pytest.mark.asyncio async def test_list_all_namespaces_deny_all( client_with_service_setup: AsyncClient, + mocker, ) -> None: """ Test ``GET /namespaces/``. 
""" - def validate_access_override(): - def _validate_access(access_control: access.AccessControl): - access_control.deny_all() - - return _validate_access - - app.dependency_overrides[validate_access] = validate_access_override + def get_deny_all_service(): + return DenyAllAuthorizationService() + mocker.patch( + "datajunction_server.internal.access.authorization.validator.get_authorization_service", + get_deny_all_service, + ) response = await client_with_service_setup.get("/namespaces/") assert response.status_code in (200, 201) assert response.json() == [] - app.dependency_overrides.clear() diff --git a/datajunction-server/tests/api/sql_test.py b/datajunction-server/tests/api/sql_test.py index a1063e25b..ca4c742f5 100644 --- a/datajunction-server/tests/api/sql_test.py +++ b/datajunction-server/tests/api/sql_test.py @@ -11,7 +11,7 @@ from datajunction_server.database.node import Node, NodeRevision from datajunction_server.database.queryrequest import QueryBuildType, QueryRequest from datajunction_server.database.user import User -from datajunction_server.internal.access.authorization import validate_access +from datajunction_server.internal.access.authorization import AuthorizationService from datajunction_server.models import access from datajunction_server.models.node_type import NodeType from datajunction_server.sql.parsing.backends.antlr4 import parse @@ -2726,24 +2726,31 @@ async def test_get_sql_for_metrics_failures(module__client_with_examples: AsyncC @pytest.mark.asyncio -async def test_get_sql_for_metrics_no_access(module__client_with_examples: AsyncClient): +async def test_get_sql_for_metrics_no_access( + module__client_with_examples: AsyncClient, + mocker, +): """ - Test getting sql for multiple metrics. + Test getting sql for multiple metrics with denied access. 
""" - def validate_access_override(): - def _validate_access(access_control: access.AccessControl): - if access_control.state == "direct": - access_control.approve_all() - else: - access_control.deny_all() + # Custom authorization service that denies all requests + class DenyAllAuthorizationService(AuthorizationService): + name = "deny_all" - return _validate_access + def authorize(self, auth_context, requests): + return [ + access.AccessDecision(request=request, approved=False) + for request in requests + ] - module__client_with_examples.app.dependency_overrides[validate_access] = ( - validate_access_override - ) + def get_deny_all_service(): + return DenyAllAuthorizationService() + mocker.patch( + "datajunction_server.internal.access.authorization.validator.get_authorization_service", + get_deny_all_service, + ) response = await module__client_with_examples.get( "/sql/", params={ @@ -2762,20 +2769,12 @@ def _validate_access(access_control: access.AccessControl): }, ) data = response.json() - # assert "Authorization of User `dj` for this request failed.\n" in data["message"] - assert "The following requests were denied:\n" in data["message"] - assert "read:node/default.municipality_dim" in data["message"] - assert "read:node/default.dispatcher" in data["message"] - assert "read:node/default.repair_orders_fact" in data["message"] - assert "read:node/default.hard_hat" in data["message"] - assert data["errors"][0]["code"] == 500 - - module__client_with_examples.app.dependency_overrides[validate_access] = ( - validate_access + assert data["message"] == ( + "Access denied to 10 resource(s): default.discounted_orders_rate, " + "default.discounted_orders_rate, default.num_repair_orders, " + "default.repair_orders_fact, default.hard_hat and 5 more" ) - module__client_with_examples.app.dependency_overrides.clear() - @pytest.mark.asyncio async def test_get_sql_for_metrics2(client_with_examples: AsyncClient): @@ -3369,31 +3368,39 @@ async def test_get_sql_for_metrics_orderby_not_in_dimensions( @pytest.mark.asyncio async def test_get_sql_for_metrics_orderby_not_in_dimensions_no_access( module__client_with_examples: AsyncClient, + mocker, ): """ Test that we extract the columns from filters to validate that they are from shared dimensions """ - def validate_access_override(): - def _validate_access(access_control: access.AccessControl): - for request in access_control.requests: + # Custom authorization service that denies specific nodes + class SelectiveDenialAuthorizationService(AuthorizationService): + name = "selective_denial" + + def authorize(self, auth_context, requests): + denied_nodes = { + "foo.bar.avg_repair_price", + "default.hard_hat.city", + } + return [ + access.AccessDecision(request=request, approved=False) if ( request.access_object.resource_type == access.ResourceType.NODE - and request.access_object.name - in ( - "foo.bar.avg_repair_price", - "default.hard_hat.city", - ) - ): - request.deny() - else: - request.approve() - - return _validate_access - - module__client_with_examples.app.dependency_overrides[validate_access] = ( - validate_access_override + and request.access_object.name in denied_nodes + ) + else access.AccessDecision(request=request, approved=True) + for request in requests + ] + + def get_selective_denial_service(): + return SelectiveDenialAuthorizationService() + + mocker.patch( + "datajunction_server.internal.access.authorization.validator.get_authorization_service", + return_value=SelectiveDenialAuthorizationService(), ) + response = await 
module__client_with_examples.get( "/sql/", params={ @@ -3410,7 +3417,6 @@ def _validate_access(access_control: access.AccessControl): "Columns ['default.hard_hat.city'] in order by " "clause must also be specified in the metrics or dimensions" ) - module__client_with_examples.app.dependency_overrides.clear() @pytest.mark.asyncio diff --git a/datajunction-server/tests/conftest.py b/datajunction-server/tests/conftest.py index 2728dbb20..fdd216967 100644 --- a/datajunction-server/tests/conftest.py +++ b/datajunction-server/tests/conftest.py @@ -55,8 +55,10 @@ from datajunction_server.database.engine import Engine from datajunction_server.database.user import User from datajunction_server.errors import DJQueryServiceClientEntityNotFound -from datajunction_server.internal.access.authorization import validate_access -from datajunction_server.models.access import AccessControl, ValidateAccessFn +from datajunction_server.internal.access.authorization import ( + get_authorization_service, + PassthroughAuthorizationService, +) from datajunction_server.models.materialization import MaterializationInfo from datajunction_server.models.query import QueryCreate, QueryWithResults from datajunction_server.models.user import OAuthProvider @@ -523,16 +525,14 @@ def get_session_override() -> AsyncSession: def get_settings_override() -> Settings: return settings_no_qs - def default_validate_access() -> ValidateAccessFn: - def _(access_control: AccessControl): - access_control.approve_all() - - return _ + def get_passthrough_auth_service(): + """Override to approve all requests in tests.""" + return PassthroughAuthorizationService() if use_patch: app.dependency_overrides[get_session] = get_session_override app.dependency_overrides[get_settings] = get_settings_override - app.dependency_overrides[validate_access] = default_validate_access + app.dependency_overrides[get_authorization_service] = get_passthrough_auth_service async with AsyncClient( transport=httpx.ASGITransport(app=app), @@ -805,18 +805,16 @@ def get_query_service_client_override( def get_settings_override() -> Settings: return settings - def default_validate_access() -> ValidateAccessFn: - def _(access_control: AccessControl): - access_control.approve_all() - - return _ + def get_passthrough_auth_service(): + """Override to approve all requests in tests.""" + return PassthroughAuthorizationService() def get_session_override() -> AsyncSession: return session app.dependency_overrides[get_session] = get_session_override app.dependency_overrides[get_settings] = get_settings_override - app.dependency_overrides[validate_access] = default_validate_access + app.dependency_overrides[get_authorization_service] = get_passthrough_auth_service app.dependency_overrides[get_query_service_client] = ( get_query_service_client_override ) @@ -944,6 +942,15 @@ async def create_default_user(session: AsyncSession) -> User: return user +@pytest_asyncio.fixture +async def default_user(session: AsyncSession): + """ + Create a default user for testing. 
+ """ + user = await create_default_user(session) + yield user + + @pytest_asyncio.fixture(scope="module") async def module__client( request, @@ -975,11 +982,9 @@ def get_session_override() -> AsyncSession: def get_settings_override() -> Settings: return module__settings - def default_validate_access() -> ValidateAccessFn: - def _(access_control: AccessControl): - access_control.approve_all() - - return _ + def get_passthrough_auth_service(): + """Override to approve all requests in tests.""" + return PassthroughAuthorizationService() module_mocker.patch( "datajunction_server.api.materializations.get_query_service_client", @@ -988,7 +993,7 @@ def _(access_control: AccessControl): app.dependency_overrides[get_session] = get_session_override app.dependency_overrides[get_settings] = get_settings_override - app.dependency_overrides[validate_access] = default_validate_access + app.dependency_overrides[get_authorization_service] = get_passthrough_auth_service app.dependency_overrides[get_query_service_client] = ( get_query_service_client_override ) diff --git a/datajunction-server/tests/construction/build_test.py b/datajunction-server/tests/construction/build_test.py index a0b66d2f5..458e1f8e8 100644 --- a/datajunction-server/tests/construction/build_test.py +++ b/datajunction-server/tests/construction/build_test.py @@ -5,7 +5,7 @@ import pytest from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession - +import pytest_asyncio import datajunction_server.sql.parsing.types as ct from datajunction_server.construction.build import ( build_materialized_cube_node, @@ -23,10 +23,54 @@ from datajunction_server.models.node_type import NodeType from datajunction_server.naming import amenable_name from datajunction_server.sql.parsing.backends.antlr4 import ast, parse +from datajunction_server.internal.access.authorization.service import ( + AuthorizationService, +) +from datajunction_server.internal.access.authorization.validator import AccessChecker +from datajunction_server.internal.access.authorization.context import AuthContext +from datajunction_server.internal.access.authentication.basic import get_user +from datajunction_server.models.access import AccessDecision + + +class AllowAllAuthorizationService(AuthorizationService): + """ + Custom authorization service that allows all access. + """ + + name = "allow_all" + + def authorize(self, auth_context, requests): + return [AccessDecision(request=request, approved=True) for request in requests] + + +@pytest_asyncio.fixture +async def access_checker( + construction_session: AsyncSession, + default_user: User, + mocker, +) -> AccessChecker: + """ + Fixture to mock access checker to allow all access. 
+ """ + user = await get_user(default_user.username, construction_session) + + def mock_get_allow_all_service(): + return AllowAllAuthorizationService() + + mocker.patch( + "datajunction_server.internal.access.authorization.validator.get_authorization_service", + mock_get_allow_all_service, + ) + return AccessChecker( + await AuthContext.from_user(construction_session, user), + ) @pytest.mark.asyncio -async def test_build_metric_with_dimensions_aggs(construction_session: AsyncSession): +async def test_build_metric_with_dimensions_aggs( + construction_session: AsyncSession, + access_checker: AccessChecker, +): """ Test building metric with dimensions """ @@ -40,6 +84,7 @@ async def test_build_metric_with_dimensions_aggs(construction_session: AsyncSess filters=[], dimensions=["basic.dimension.users.country", "basic.dimension.users.gender"], orderby=[], + access_checker=access_checker, ) expected = """ WITH basic_DOT_source_DOT_comments AS ( @@ -84,6 +129,7 @@ async def test_build_metric_with_dimensions_aggs(construction_session: AsyncSess @pytest.mark.asyncio async def test_build_metric_with_required_dimensions( construction_session: AsyncSession, + access_checker: AccessChecker, ): """ Test building metric with bound dimensions @@ -99,6 +145,7 @@ async def test_build_metric_with_required_dimensions( filters=[], dimensions=["basic.dimension.users.country", "basic.dimension.users.gender"], orderby=[], + access_checker=access_checker, ) expected = """ WITH basic_DOT_source_DOT_comments AS ( @@ -247,7 +294,10 @@ async def test_raise_on_build_without_required_dimension_column( @pytest.mark.asyncio -async def test_build_metric_with_dimensions_filters(construction_session: AsyncSession): +async def test_build_metric_with_dimensions_filters( + construction_session: AsyncSession, + access_checker: AccessChecker, +): """ Test building metric with dimension filters """ @@ -264,6 +314,7 @@ async def test_build_metric_with_dimensions_filters(construction_session: AsyncS ], dimensions=[], orderby=[], + access_checker=access_checker, ) expected = """ WITH basic_DOT_source_DOT_comments AS ( diff --git a/datajunction-server/tests/internal/authorization_test.py b/datajunction-server/tests/internal/authorization_test.py new file mode 100644 index 000000000..5e2b61b92 --- /dev/null +++ b/datajunction-server/tests/internal/authorization_test.py @@ -0,0 +1,2006 @@ +"""Tests for RBAC authorization logic.""" + +from datetime import datetime, timedelta, timezone + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from datajunction_server.database.group_member import GroupMember +from datajunction_server.database.rbac import Role, RoleAssignment, RoleScope +from datajunction_server.database.user import PrincipalKind, User +from datajunction_server.internal.access.authorization import ( + AccessChecker, + AccessDenialMode, + AuthContext, + PassthroughAuthorizationService, + RBACAuthorizationService, + get_authorization_service, +) +from datajunction_server.errors import DJAuthorizationException +from datajunction_server.internal.access.authentication.basic import get_user +from datajunction_server.models.access import ( + Resource, + ResourceAction, + ResourceRequest, + ResourceType, +) +from datajunction_server.internal.access.group_membership import ( + GroupMembershipService, +) + + +class TestResourceMatching: + """Tests for wildcard pattern matching.""" + + def test_exact_match(self): + """Test exact string matching (no wildcards).""" + assert RBACAuthorizationService.resource_matches_pattern( + 
"finance.revenue", + "finance.revenue", + ) + assert not RBACAuthorizationService.resource_matches_pattern( + "finance.revenue", + "finance.cost", + ) + + def test_wildcard_all(self): + """Test the universal wildcard *.""" + assert RBACAuthorizationService.resource_matches_pattern("anything", "*") + assert RBACAuthorizationService.resource_matches_pattern( + "finance.revenue.quarterly", + "*", + ) + assert RBACAuthorizationService.resource_matches_pattern("", "*") + + def test_namespace_wildcard(self): + """Test namespace wildcard patterns.""" + # finance.* matches finance.revenue + assert RBACAuthorizationService.resource_matches_pattern( + "finance.revenue", + "finance.*", + ) + + # finance.* matches finance.quarterly.revenue + assert RBACAuthorizationService.resource_matches_pattern( + "finance.quarterly.revenue", + "finance.*", + ) + + # finance.* does NOT match finance (exact namespace) + assert not RBACAuthorizationService.resource_matches_pattern( + "finance", + "finance.*", + ) + + # finance.* does NOT match marketing.revenue + assert not RBACAuthorizationService.resource_matches_pattern( + "marketing.revenue", + "finance.*", + ) + + def test_nested_namespace_wildcard(self): + """Test nested namespace patterns.""" + # users.alice.* matches users.alice.dashboard + assert RBACAuthorizationService.resource_matches_pattern( + "users.alice.dashboard", + "users.alice.*", + ) + + # users.alice.* matches users.alice.metrics.revenue + assert RBACAuthorizationService.resource_matches_pattern( + "users.alice.metrics.revenue", + "users.alice.*", + ) + + # users.alice.* does NOT match users.bob.dashboard + assert not RBACAuthorizationService.resource_matches_pattern( + "users.bob.dashboard", + "users.alice.*", + ) + + def test_edge_case_patterns(self): + """Test edge case patterns that reach the fallback logic (line 167-168). + + These patterns are unusual but should be handled gracefully: + - ".*" -> strips to empty string + - "**" -> strips to empty string + These are treated as global wildcards (match everything). + """ + # ".*" pattern - after stripping "*" and ".", becomes empty + assert RBACAuthorizationService.resource_matches_pattern( + "anything", + ".*", + ) + assert RBACAuthorizationService.resource_matches_pattern( + "finance.revenue", + ".*", + ) + assert RBACAuthorizationService.resource_matches_pattern( + "", + ".*", + ) + + # "**" pattern - after stripping "*", becomes empty + assert RBACAuthorizationService.resource_matches_pattern( + "anything", + "**", + ) + assert RBACAuthorizationService.resource_matches_pattern( + "deeply.nested.resource.name", + "**", + ) + + def test_wildcard_in_middle_not_supported(self): + """Test that wildcards in the middle of patterns don't work as expected. + + Note: The current implementation only supports trailing wildcards. + Patterns like "finance.*.revenue" are NOT supported as glob patterns. + """ + # "finance.*.revenue" - contains * but not at end + # This will strip trailing * (none) and compare as prefix + # So it won't match "finance.quarterly.revenue" + assert not RBACAuthorizationService.resource_matches_pattern( + "finance.quarterly.revenue", + "finance.*.revenue", + ) + + # It would only match if resource literally starts with "finance.*.revenue." 
+ # which is unlikely in practice + + +@pytest.mark.asyncio +class TestRBACPermissionChecks: + """Tests for RBAC permission checking.""" + + async def test_no_roles_returns_false( + self, + default_user: User, + session: AsyncSession, + ): + """Test that user with no roles gets False (no explicit rule).""" + user = await get_user(username=default_user.username, session=session) + result = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + + assert result is False + + async def test_explicit_grant_exact_match( + self, + default_user: User, + session: AsyncSession, + ): + """Test explicit permission grant with exact resource match.""" + # Create role with exact scope + role = Role( + name="test-role", + created_by_id=default_user.id, + ) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.revenue", + ) + session.add(scope) + + # Assign role to user + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Check permission + user = await get_user(username=default_user.username, session=session) + result = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + assert result is True + + async def test_explicit_grant_wildcard_match( + self, + default_user: User, + session: AsyncSession, + ): + """Test permission grant via wildcard pattern.""" + # Create role with wildcard scope + role = Role( + name="finance-reader", + created_by_id=default_user.id, + ) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", # Wildcard + ) + session.add(scope) + + # Assign role + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Check permissions on various resources + user = await get_user(username=default_user.username, session=session) + result1 = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + assert result1 is True + + result2 = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.quarterly.revenue", + ) + assert result2 is True + + # Different namespace - no match + result3 = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="marketing.revenue", + ) + assert result3 is False + + async def test_wrong_action_no_match( + self, + default_user: User, + session: AsyncSession, + ): + """Test that wrong action doesn't grant permission.""" + role = Role(name="reader-role", created_by_id=default_user.id) + session.add(role) + await session.flush() + + # Only READ permission + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + 
scope_value="finance.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Check READ - should be granted + user = await get_user(username=default_user.username, session=session) + result_read = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + assert result_read is True + + # Check WRITE - should be None (no explicit rule) + result_write = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.WRITE, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + assert result_write is False + + async def test_expired_assignment_ignored( + self, + default_user: User, + session: AsyncSession, + ): + """Test that expired role assignments are ignored.""" + role = Role(name="temp-role", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + session.add(scope) + + # Assignment expired 1 hour ago + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + expires_at=datetime.now(timezone.utc) - timedelta(hours=1), + ) + session.add(assignment) + await session.commit() + + # Should not grant permission (expired) + user = await get_user(username=default_user.username, session=session) + result = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + assert not result + + async def test_multiple_roles_any_grants( + self, + default_user: User, + session: AsyncSession, + ): + """Test that having ANY role that grants permission is sufficient.""" + # Role 1: No matching scope + role1 = Role(name="marketing-role", created_by_id=default_user.id) + session.add(role1) + await session.flush() + + scope1 = RoleScope( + role_id=role1.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="marketing.*", + ) + session.add(scope1) + + assignment1 = RoleAssignment( + principal_id=default_user.id, + role_id=role1.id, + granted_by_id=default_user.id, + ) + session.add(assignment1) + + # Role 2: Matching scope + role2 = Role(name="finance-role", created_by_id=default_user.id) + session.add(role2) + await session.flush() + + scope2 = RoleScope( + role_id=role2.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + session.add(scope2) + + assignment2 = RoleAssignment( + principal_id=default_user.id, + role_id=role2.id, + granted_by_id=default_user.id, + ) + session.add(assignment2) + await session.commit() + + # Should grant because role2 matches + user = await get_user(username=default_user.username, session=session) + result = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + assert result is True + + async def test_universal_wildcard( + self, + default_user: User, + session: AsyncSession, + ): + """Test that * wildcard grants access to everything.""" + role = Role(name="super-admin", 
created_by_id=default_user.id) + session.add(role) + await session.flush() + + # Universal wildcard + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Should grant for anything + user = await get_user(username=default_user.username, session=session) + result1 = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + assert result1 is True + + result2 = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="anything.at.all", + ) + assert result2 is True + + +@pytest.mark.asyncio +class TestAuthorizationService: + """Tests for the synchronous AuthorizationService.""" + + async def test_passthrough_service_approves_all( + self, + default_user: User, + session: AsyncSession, + ): + """Test that PassthroughAuthorizationService approves everything.""" + # Get existing user + user = await get_user(username=default_user.username, session=session) + + service = PassthroughAuthorizationService() + + requests = [ + ResourceRequest( + verb=ResourceAction.WRITE, + access_object=Resource( + name="finance.revenue", + resource_type=ResourceType.NAMESPACE, + ), + ), + ResourceRequest( + verb=ResourceAction.DELETE, + access_object=Resource( + name="secret.data", + resource_type=ResourceType.NODE, + ), + ), + ] + + result = service.authorize(user, requests) # Now sync! 
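+        # One AccessDecision is expected per request, each approved, since the
+        # passthrough service performs no permission checks.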
+ + assert len(result) == 2 + assert all(req.approved for req in result) + + async def test_rbac_service_with_permissions( + self, + session: AsyncSession, + default_user: User, + mocker, + ): + """Test RBACAuthorizationService with granted permissions.""" + mock_settings = mocker.patch( + "datajunction_server.internal.access.authorization.service.settings", + ) + mock_settings.authorization_provider = "rbac" + mock_settings.default_access_policy = "restrictive" + + role = Role(name="test-role", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + user = await get_user(username=default_user.username, session=session) + requests = [ + ResourceRequest( + verb=ResourceAction.READ, + access_object=Resource( + name="finance.revenue", + resource_type=ResourceType.NAMESPACE, + ), + ), + ResourceRequest( + verb=ResourceAction.WRITE, # Not granted + access_object=Resource( + name="finance.revenue", + resource_type=ResourceType.NAMESPACE, + ), + ), + ] + + access_checker = AccessChecker( + auth_context=await AuthContext.from_user(user=user, session=session), + ) + access_checker.add_requests(requests) + result = await access_checker.check(on_denied=AccessDenialMode.RETURN) + assert len(result) == 2 + assert result[0].approved is True # READ granted + assert result[1].approved is False # WRITE not granted + + async def test_get_authorization_service_factory(self, mocker): + """Test the factory function returns correct service.""" + mock_settings = mocker.patch( + "datajunction_server.internal.access.authorization.service.settings", + ) + mock_settings.authorization_provider = "rbac" + mock_settings.default_access_policy = "restrictive" + + service = get_authorization_service() + assert isinstance(service, RBACAuthorizationService) + + # Test passthrough provider + mock_settings.authorization_provider = "passthrough" + service = get_authorization_service() + + # Cached instance, so need to clear cache + assert isinstance(service, RBACAuthorizationService) + + # Clear LRU cache to test different provider + get_authorization_service.cache_clear() + service = get_authorization_service() + assert isinstance(service, PassthroughAuthorizationService) + + # Test unknown provider + mock_settings.authorization_provider = "unknown" + get_authorization_service.cache_clear() + with pytest.raises(ValueError) as exc_info: + get_authorization_service() + assert "unknown" in str(exc_info.value).lower() + assert "rbac" in str(exc_info.value).lower() + assert "passthrough" in str(exc_info.value).lower() + + +@pytest.mark.asyncio +class TestGroupBasedPermissions: + """Tests for group-based role assignments.""" + + async def test_user_inherits_group_permissions( + self, + session: AsyncSession, + default_user: User, + mocker, + ): + """Test that users inherit permissions from groups they belong to.""" + # Create a group + group = User( + username="finance-team", + kind=PrincipalKind.GROUP, + oauth_provider="basic", + ) + session.add(group) + await session.flush() + + # Add user to group + membership = GroupMember( + group_id=group.id, + member_id=default_user.id, + ) + session.add(membership) + await session.flush() + + # Create role and assign to group + role = 
Role(name="finance-reader", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + session.add(scope) + + # Assign role to group + assignment = RoleAssignment( + principal_id=group.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Expire the user object so we get a fresh load + await session.refresh(default_user) + user = await get_user(username=default_user.username, session=session) + + mock_settings = mocker.patch( + "datajunction_server.internal.access.authorization.service.settings", + ) + mock_settings.authorization_provider = "rbac" + mock_settings.default_access_policy = "restrictive" + + # Check permission - should be granted via group + access_checker = AccessChecker( + auth_context=await AuthContext.from_user(user=user, session=session), + ) + access_checker.add_requests( + [ + ResourceRequest( + verb=ResourceAction.READ, + access_object=Resource( + name="finance.revenue.something", + resource_type=ResourceType.NODE, + ), + ), + ResourceRequest( + verb=ResourceAction.READ, + access_object=Resource( + resource_type=ResourceType.NAMESPACE, + name="finance.revenue", + ), + ), + ], + ) + results = await access_checker.check(on_denied=AccessDenialMode.RETURN) + assert results[0].approved is True + assert results[1].approved is True + + async def test_user_no_permission_without_group( + self, + session: AsyncSession, + default_user: User, + mocker, + ): + """Test that user without group membership doesn't get permission.""" + # Create a group with permissions + group = User( + username="marketing-team", + kind=PrincipalKind.GROUP, + oauth_provider="basic", + ) + session.add(group) + await session.flush() + + # Create role and assign to GROUP + role = Role(name="marketing-reader", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="marketing.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=group.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Reload user without adding them to the group + await session.refresh(default_user) + user = await get_user(username=default_user.username, session=session) + + mock_settings = mocker.patch( + "datajunction_server.internal.access.authorization.service.settings", + ) + mock_settings.authorization_provider = "rbac" + mock_settings.default_access_policy = "restrictive" + + # Check permission - should NOT be granted (user not in group) + access_checker = AccessChecker( + auth_context=await AuthContext.from_user(user=user, session=session), + ) + access_checker.add_requests( + [ + ResourceRequest( + verb=ResourceAction.READ, + access_object=Resource( + name="marketing.revenue", + resource_type=ResourceType.NAMESPACE, + ), + ), + ], + ) + results = await access_checker.check(on_denied=AccessDenialMode.RETURN) + assert results[0].approved is False + + +@pytest.mark.asyncio +class TestCrossResourceTypePermissions: + """Tests for namespace scopes covering nodes.""" + + async def test_namespace_scope_covers_nodes( + self, + # client_with_basic: AsyncClient, + session: AsyncSession, + default_user: User, + ): + """Test that namespace scope grants permission for nodes in that namespace.""" + # 
Create role with NAMESPACE scope + role = Role(name="finance-ns-reader", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, # Namespace scope + scope_value="finance.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Reload user with roles + await session.refresh(default_user) + + # Reload user with member_of and group role_assignments eagerly loaded + user = await get_user(username=default_user.username, session=session) + + # Check permission for NAMESPACE resource - should be granted + result_namespace = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + assert result_namespace is True + + # Check permission for NODE resource in that namespace - should ALSO be granted! + result_node = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NODE, + resource_name="finance.revenue.total", # Node in finance namespace + ) + assert result_node is True + + # Node in different namespace - should NOT be granted + result_other = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NODE, + resource_name="marketing.revenue.total", + ) + assert result_other is False + + async def test_namespace_scope_nested_namespaces( + self, + # client_with_basic: AsyncClient, + session: AsyncSession, + default_user: User, + ): + """Test namespace scope with nested namespaces.""" + # Create role with wildcard namespace scope + role = Role(name="finance-all", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", # finance.* + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Reload user with roles + await session.refresh(default_user) + user = await get_user(username=default_user.username, session=session) + + # finance.quarterly.revenue node should match finance.* namespace + result = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NODE, + resource_name="finance.quarterly.revenue", + ) + assert result is True + + async def test_node_scope_does_not_cover_namespace( + self, + # client_with_basic: AsyncClient, + session: AsyncSession, + default_user: User, + ): + """Test that NODE scope does NOT grant permission for NAMESPACE resources.""" + # Create role with NODE scope + role = Role(name="specific-node-reader", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NODE, # NODE scope + scope_value="finance.revenue", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Reload user 
with roles + await session.refresh(default_user) + user = await get_user(username=default_user.username, session=session) + + # Check permission for NODE - should be granted + result_node = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NODE, + resource_name="finance.revenue", + ) + assert result_node is True + + # Check permission for NAMESPACE - should NOT be granted (cross-type only works one way) + result_namespace = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + assert result_namespace is False + + +@pytest.mark.asyncio +class TestGlobalAccessScope: + """Tests for global access (empty or * scope_value).""" + + async def test_empty_scope_grants_global_access( + self, + # client_with_basic: AsyncClient, + session: AsyncSession, + default_user: User, + ): + """Test that empty scope_value grants access to all resources of that type.""" + # Create role with empty scope_value (global) + role = Role(name="global-reader", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="", # Global! (empty string) + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Reload user with roles + await session.refresh(default_user) + user = await get_user(username=default_user.username, session=session) + + # Should grant for any namespace + result1 = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + assert result1 is True + + result2 = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="marketing.anything", + ) + assert result2 is True + + # Should NOT grant for different resource type + result3 = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NODE, # Different type + resource_name="finance.revenue", + ) + assert result3 is False + + async def test_star_scope_grants_global_access( + self, + # client_with_basic: AsyncClient, + session: AsyncSession, + default_user: User, + ): + """Test that "*" scope_value grants access to all resources of that type.""" + # Create role with "*" scope_value + role = Role(name="star-reader", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NODE, + scope_value="*", # Wildcard for all + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Reload user with roles + await session.refresh(default_user) + user = await get_user(username=default_user.username, session=session) + + # Should grant for any node + result = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NODE, + 
resource_name="anything.anywhere.node", + ) + assert result is True + + +@pytest.mark.asyncio +class TestPermissionHierarchy: + """Tests for permission hierarchy (MANAGE > DELETE > WRITE > READ).""" + + async def test_manage_implies_all_permissions( + self, + # client_with_basic: AsyncClient, + session: AsyncSession, + default_user: User, + ): + """Test that MANAGE permission grants all other permissions.""" + # Create role with MANAGE permission + role = Role(name="finance-manager", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.MANAGE, # Top-level permission + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Reload user with roles + await session.refresh(default_user) + user = await get_user(username=default_user.username, session=session) + + # MANAGE should grant READ + assert ( + RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + is True + ) + + # MANAGE should grant WRITE + assert ( + RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.WRITE, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + is True + ) + + # MANAGE should grant DELETE + assert ( + RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.DELETE, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + is True + ) + + # MANAGE should grant EXECUTE + assert ( + RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.EXECUTE, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + is True + ) + + # MANAGE should grant MANAGE + assert ( + RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.MANAGE, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + is True + ) + + async def test_write_implies_read( + self, + # client_with_basic: AsyncClient, + session: AsyncSession, + default_user: User, + ): + """Test that WRITE permission implies READ.""" + # Create role with WRITE permission + role = Role(name="finance-writer", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.WRITE, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Reload user with roles + await session.refresh(default_user) + user = await get_user(username=default_user.username, session=session) + + # WRITE should grant READ + assert ( + RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + is True + ) + + # WRITE should grant WRITE + assert ( + RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.WRITE, + resource_type=ResourceType.NAMESPACE, + 
resource_name="finance.revenue", + ) + is True + ) + + # WRITE should NOT grant DELETE + assert ( + RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.DELETE, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + is False + ) + + async def test_read_does_not_imply_write( + self, + # client_with_basic: AsyncClient, + session: AsyncSession, + default_user: User, + ): + """Test that READ permission does NOT imply WRITE.""" + # Create role with only READ permission + role = Role(name="readonly-role", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Reload user with roles + await session.refresh(default_user) + user = await get_user(username=default_user.username, session=session) + + # READ should grant READ + assert ( + RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + is True + ) + + # READ should NOT grant WRITE + assert ( + RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.WRITE, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + is False + ) + + async def test_execute_implies_read( + self, + # client_with_basic: AsyncClient, + session: AsyncSession, + default_user: User, + ): + """Test that EXECUTE permission implies READ.""" + # Create role with EXECUTE permission + role = Role(name="query-executor", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.EXECUTE, + scope_type=ResourceType.NODE, + scope_value="finance.revenue", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Reload user with roles + await session.refresh(default_user) + user = await get_user(username=default_user.username, session=session) + + # EXECUTE should grant READ + assert ( + RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NODE, + resource_name="finance.revenue", + ) + is True + ) + + # EXECUTE should grant EXECUTE + assert ( + RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.EXECUTE, + resource_type=ResourceType.NODE, + resource_name="finance.revenue", + ) + is True + ) + + # EXECUTE should NOT grant WRITE + assert ( + RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.WRITE, + resource_type=ResourceType.NODE, + resource_name="finance.revenue", + ) + is False + ) + + +@pytest.mark.asyncio +class TestAuthContext: + """Tests for AuthContext and effective assignments.""" + + async def test_auth_context_from_user_direct_assignments_only( + self, + default_user: User, + session: AsyncSession, + ): + """AuthContext includes user's direct role assignments.""" + # Create role and assign to user + role = Role(name="test-role", 
created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Reload user with assignments + user = await get_user(username=default_user.username, session=session) + + # Build AuthContext + auth_context = await AuthContext.from_user(session, user) + + assert auth_context.user_id == user.id + assert auth_context.username == user.username + assert len(auth_context.role_assignments) == 1 + assert auth_context.role_assignments[0].role.name == "test-role" + + async def test_auth_context_includes_group_assignments( + self, + default_user: User, + session: AsyncSession, + ): + """AuthContext flattens user's + groups' assignments.""" + # Create a group + group = User( + username="finance-team", + kind=PrincipalKind.GROUP, + oauth_provider="basic", + ) + session.add(group) + await session.flush() + + # Add user to group + membership = GroupMember( + group_id=group.id, + member_id=default_user.id, + ) + session.add(membership) + + # Create role for user (direct) + user_role = Role(name="user-role", created_by_id=default_user.id) + session.add(user_role) + await session.flush() + + user_scope = RoleScope( + role_id=user_role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="personal.*", + ) + session.add(user_scope) + + user_assignment = RoleAssignment( + principal_id=default_user.id, + role_id=user_role.id, + granted_by_id=default_user.id, + ) + session.add(user_assignment) + + # Create role for group + group_role = Role(name="group-role", created_by_id=default_user.id) + session.add(group_role) + await session.flush() + + group_scope = RoleScope( + role_id=group_role.id, + action=ResourceAction.WRITE, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + session.add(group_scope) + + group_assignment = RoleAssignment( + principal_id=group.id, + role_id=group_role.id, + granted_by_id=default_user.id, + ) + session.add(group_assignment) + await session.commit() + + # Reload user + user = await get_user(username=default_user.username, session=session) + + # Build AuthContext (should include both) + auth_context = await AuthContext.from_user(session, user) + + assert auth_context.user_id == user.id + assert len(auth_context.role_assignments) == 2 # User's + group's + + role_names = {a.role.name for a in auth_context.role_assignments} + assert role_names == {"user-role", "group-role"} + + async def test_auth_context_with_multiple_groups( + self, + default_user: User, + session: AsyncSession, + ): + """User in multiple groups gets all group assignments.""" + # Create two groups + group1 = User( + username="finance-team", + kind=PrincipalKind.GROUP, + oauth_provider="basic", + ) + group2 = User( + username="data-eng-team", + kind=PrincipalKind.GROUP, + oauth_provider="basic", + ) + session.add_all([group1, group2]) + await session.flush() + + # Add user to both groups + membership1 = GroupMember(group_id=group1.id, member_id=default_user.id) + membership2 = GroupMember(group_id=group2.id, member_id=default_user.id) + session.add_all([membership1, membership2]) + + # Give each group a role + role1 = Role(name="finance-role", created_by_id=default_user.id) + role2 = Role(name="data-eng-role", 
created_by_id=default_user.id) + session.add_all([role1, role2]) + await session.flush() + + scope1 = RoleScope( + role_id=role1.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + scope2 = RoleScope( + role_id=role2.id, + action=ResourceAction.WRITE, + scope_type=ResourceType.NAMESPACE, + scope_value="analytics.*", + ) + session.add_all([scope1, scope2]) + + assignment1 = RoleAssignment( + principal_id=group1.id, + role_id=role1.id, + granted_by_id=default_user.id, + ) + assignment2 = RoleAssignment( + principal_id=group2.id, + role_id=role2.id, + granted_by_id=default_user.id, + ) + session.add_all([assignment1, assignment2]) + await session.commit() + + # Reload user + user = await get_user(username=default_user.username, session=session) + + # Build AuthContext + auth_context = await AuthContext.from_user(session, user) + + # Should have assignments from both groups + assert len(auth_context.role_assignments) == 2 + role_names = {a.role.name for a in auth_context.role_assignments} + assert role_names == {"finance-role", "data-eng-role"} + + +@pytest.mark.asyncio +class TestCheckAccess: + """Tests for authorize() function with different denial modes.""" + + async def test_check_access_filter_mode_returns_only_approved( + self, + default_user: User, + session: AsyncSession, + mocker, + ): + """FILTER mode returns only approved requests (default).""" + # Give user access to finance.* but not marketing.* + role = Role(name="finance-reader", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Reload user + await session.refresh(default_user) + user = await get_user(username=default_user.username, session=session) + + # Request access to 3 nodes: 2 accessible, 1 not + requests = [ + ResourceRequest( + verb=ResourceAction.READ, + access_object=Resource( + name="finance.revenue", + resource_type=ResourceType.NODE, + ), + ), + ResourceRequest( + verb=ResourceAction.READ, + access_object=Resource( + name="finance.cost", + resource_type=ResourceType.NODE, + ), + ), + ResourceRequest( + verb=ResourceAction.READ, + access_object=Resource( + name="marketing.revenue", + resource_type=ResourceType.NODE, + ), + ), + ] + + mock_settings = mocker.patch( + "datajunction_server.internal.access.authorization.service.settings", + ) + mock_settings.authorization_provider = "rbac" + mock_settings.default_access_policy = "restrictive" + + # Check access (default FILTER mode) + access_checker = AccessChecker( + auth_context=await AuthContext.from_user(user=user, session=session), + ) + access_checker.add_requests(requests) + approved = await access_checker.approved_resource_names() + + # Should only return the 2 approved (finance.* nodes) + assert len(approved) == 2 + assert approved == ["finance.revenue", "finance.cost"] + + async def test_check_access_raise_mode_throws_on_denial( + self, + default_user: User, + session: AsyncSession, + mocker, + ): + """ + Raise mode throws DJAuthorizationException when access denied + for a user with no permissions. 
+ """ + user = await get_user(username=default_user.username, session=session) + + request = ResourceRequest( + verb=ResourceAction.WRITE, + access_object=Resource( + name="finance.revenue", + resource_type=ResourceType.NODE, + ), + ) + + mock_settings = mocker.patch( + "datajunction_server.internal.access.authorization.service.settings", + ) + mock_settings.authorization_provider = "rbac" + mock_settings.default_access_policy = "restrictive" + + with pytest.raises(DJAuthorizationException) as exc_info: + access_checker = AccessChecker( + auth_context=await AuthContext.from_user(user=user, session=session), + ) + access_checker.add_request(request) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + + # Check exception message + assert "Access denied to 1 resource(s): finance.revenue" in str(exc_info.value) + + async def test_check_access_raise_mode_succeeds_when_approved( + self, + default_user: User, + session: AsyncSession, + ): + """RAISE mode succeeds without exception when all approved.""" + # Give user access + role = Role(name="finance-writer", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.WRITE, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + user = await get_user(username=default_user.username, session=session) + + request = ResourceRequest( + verb=ResourceAction.WRITE, + access_object=Resource( + name="finance.revenue", + resource_type=ResourceType.NODE, + ), + ) + + # Should NOT raise + access_checker = AccessChecker( + auth_context=await AuthContext.from_user(user=user, session=session), + ) + access_checker.add_request(request) + result = await access_checker.check(on_denied=AccessDenialMode.RAISE) + assert len(result) == 1 + assert result[0].approved is True + + async def test_check_access_return_mode( + self, + default_user: User, + session: AsyncSession, + mocker, + ): + """RETURN_ALL mode returns all requests with approved field set.""" + # Give user access to finance.* only + role = Role(name="finance-reader", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + user = await get_user(username=default_user.username, session=session) + + # Request access to 3 nodes: 2 accessible, 1 not + requests = [ + ResourceRequest( + verb=ResourceAction.READ, + access_object=Resource( + name="finance.revenue", + resource_type=ResourceType.NODE, + ), + ), + ResourceRequest( + verb=ResourceAction.READ, + access_object=Resource( + name="finance.cost", + resource_type=ResourceType.NODE, + ), + ), + ResourceRequest( + verb=ResourceAction.READ, + access_object=Resource( + name="marketing.revenue", + resource_type=ResourceType.NODE, + ), + ), + ] + + mock_settings = mocker.patch( + "datajunction_server.internal.access.authorization.service.settings", + ) + mock_settings.authorization_provider = "rbac" + mock_settings.default_access_policy = "restrictive" + + # Check access with RETURN_ALL + access_checker = AccessChecker( 
+ auth_context=await AuthContext.from_user(user=user, session=session), + ) + access_checker.add_requests(requests) + + all_requests = await access_checker.check(on_denied=AccessDenialMode.RETURN) + + # Should return all 3 requests + assert len(all_requests) == 3 + + # 2 approved, 1 denied + approved = [r for r in all_requests if r.approved] + denied = [r for r in all_requests if not r.approved] + + assert len(approved) == 2 + assert len(denied) == 1 + assert denied[0].request.access_object.name == "marketing.revenue" + + +@pytest.mark.asyncio +class TestGetEffectiveAssignments: + """Tests for get_effective_assignments() with GroupMembershipService.""" + + async def test_effective_assignments_user_only( + self, + default_user: User, + session: AsyncSession, + ): + """User with no groups gets only direct assignments.""" + # Give user a direct assignment + role = Role(name="personal-role", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="personal.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + user = await get_user(username=default_user.username, session=session) + + # Get effective assignments + assignments = await AuthContext.get_effective_assignments(session, user) + + assert len(assignments) == 1 + assert assignments[0].role.name == "personal-role" + + async def test_effective_assignments_with_postgres_groups( + self, + default_user: User, + session: AsyncSession, + ): + """Effective assignments includes groups from PostgresGroupMembershipService.""" + # Create group + group = User( + username="test-group", + kind=PrincipalKind.GROUP, + oauth_provider="basic", + ) + session.add(group) + await session.flush() + + # Add user to group via GroupMember table + membership = GroupMember( + group_id=group.id, + member_id=default_user.id, + ) + session.add(membership) + + # Create role and assign to GROUP + role = Role(name="group-role", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.WRITE, + scope_type=ResourceType.NAMESPACE, + scope_value="shared.*", + ) + session.add(scope) + + group_assignment = RoleAssignment( + principal_id=group.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(group_assignment) + await session.commit() + + user = await get_user(username=default_user.username, session=session) + + # Get effective assignments (should use PostgresGroupMembershipService by default) + assignments = await AuthContext.get_effective_assignments(session, user) + + # Should include group's assignment + assert len(assignments) >= 1 + role_names = {a.role.name for a in assignments} + assert "group-role" in role_names + + async def test_effective_assignments_with_custom_service( + self, + default_user: User, + session: AsyncSession, + mocker, + ): + """Custom GroupMembershipService can be provided.""" + + # Create a mock service that returns a specific group + class MockGroupService(GroupMembershipService): + name = "mock" + + async def is_user_in_group(self, session, username, group_name): + return group_name == "mock-group" + + async def get_user_groups(self, session, username): + return ["mock-group"] + + async def add_user_to_group(self, session, username, group_name): + pass + + 
async def remove_user_from_group(self, session, username, group_name): + pass + + # Create the mock group in DB + group = User( + username="mock-group", + kind=PrincipalKind.GROUP, + oauth_provider="basic", + ) + session.add(group) + await session.flush() + + # Assign role to mock group + role = Role(name="mock-role", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.EXECUTE, + scope_type=ResourceType.NAMESPACE, + scope_value="special.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=group.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + user = await get_user(username=default_user.username, session=session) + + # Use custom service + mock_service = MockGroupService() + mocker.patch( + "datajunction_server.internal.access.authorization.context.get_group_membership_service", + lambda: mock_service, + ) + assignments = await AuthContext.get_effective_assignments( + session, + user, + ) + + # Should include mock group's assignment + role_names = {a.role.name for a in assignments} + assert "mock-role" in role_names + + +@pytest.mark.asyncio +class TestCheckAccessIntegration: + """Integration tests for authorize() with real authorization flow.""" + + async def test_check_access_with_group_based_permissions( + self, + default_user: User, + session: AsyncSession, + ): + """End-to-end: User gets access via group membership.""" + # Create group + group = User( + username="data-team", + kind=PrincipalKind.GROUP, + oauth_provider="basic", + ) + session.add(group) + await session.flush() + + # Add user to group + membership = GroupMember( + group_id=group.id, + member_id=default_user.id, + ) + session.add(membership) + + # Give group permission + role = Role(name="data-team-role", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="data.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=group.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + user = await get_user(username=default_user.username, session=session) + + # Request access to data.* node + request = ResourceRequest( + verb=ResourceAction.READ, + access_object=Resource( + name="data.user_events", + resource_type=ResourceType.NODE, + ), + ) + + # Should be approved via group + access_checker = AccessChecker( + auth_context=await AuthContext.from_user(user=user, session=session), + ) + access_checker.add_request(request) + approved = await access_checker.check(on_denied=AccessDenialMode.RETURN) + + assert len(approved) == 1 + assert approved[0].approved is True + + async def test_check_access_with_mixed_approval( + self, + default_user: User, + session: AsyncSession, + mocker, + ): + """Some requests approved, some denied.""" + # Give access to finance.* only + role = Role(name="finance-reader", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + user = await 
get_user(username=default_user.username, session=session) + + # Mix of accessible and inaccessible + requests = [ + ResourceRequest( + verb=ResourceAction.READ, + access_object=Resource( + name="finance.revenue", + resource_type=ResourceType.NODE, + ), + ), + ResourceRequest( + verb=ResourceAction.READ, + access_object=Resource( + name="marketing.revenue", + resource_type=ResourceType.NODE, + ), + ), + ] + mock_settings = mocker.patch( + "datajunction_server.internal.access.authorization.service.settings", + ) + mock_settings.authorization_provider = "rbac" + mock_settings.default_access_policy = "restrictive" + + # FILTER mode - returns only approved + access_checker = AccessChecker( + auth_context=await AuthContext.from_user(user=user, session=session), + ) + access_checker.add_requests(requests) + filtered = await access_checker.check(on_denied=AccessDenialMode.FILTER) + assert len(filtered) == 1 + assert filtered[0].request.access_object.name == "finance.revenue" + + # RETURN_ALL mode - returns both + all_results = await access_checker.check(on_denied=AccessDenialMode.RETURN) + assert len(all_results) == 2 + assert all_results[0].approved is True + assert all_results[1].approved is False + + # RAISE mode - should raise + with pytest.raises(DJAuthorizationException): + await access_checker.check(on_denied=AccessDenialMode.RAISE) diff --git a/datajunction-server/tests/internal/caching/query_cache_manager_test.py b/datajunction-server/tests/internal/caching/query_cache_manager_test.py index b86796edb..a31773744 100644 --- a/datajunction-server/tests/internal/caching/query_cache_manager_test.py +++ b/datajunction-server/tests/internal/caching/query_cache_manager_test.py @@ -1,4 +1,5 @@ import asyncio +from types import SimpleNamespace from unittest import mock from unittest.mock import patch @@ -13,6 +14,8 @@ QueryRequestParams, ) from datajunction_server.database.queryrequest import QueryBuildType +from datajunction_server.database.user import User, OAuthProvider +from datajunction_server.internal.access.authorization import AccessChecker class DummyRequest: @@ -27,6 +30,15 @@ def __init__(self, cache_control: str | None = None): self.headers = Headers(headers) self.method = "GET" + # Add state with a dummy user for get_current_user + self.state = SimpleNamespace( + user=User( + username="testuser", + email="test@example.com", + oauth_provider=OAuthProvider.BASIC, + ), + ) + @pytest.mark.asyncio async def test_cache_key_prefix_uses_query_type(): @@ -66,10 +78,17 @@ async def test_fallback_calls_get_measures_query(): """ Should call get_measures_query with correct args. """ - with patch( - "datajunction_server.internal.caching.query_cache_manager.get_measures_query", - return_value=[{"sql": "SELECT * FROM test"}], - ) as get_measures_query_mock: + mock_access_checker = mock.AsyncMock(spec=AccessChecker) + with ( + patch( + "datajunction_server.internal.caching.query_cache_manager.get_measures_query", + return_value=[{"sql": "SELECT * FROM test"}], + ) as get_measures_query_mock, + patch( + "datajunction_server.internal.caching.query_cache_manager.build_access_checker_from_request", + return_value=mock_access_checker, + ), + ): cache = CachelibCache() manager = QueryCacheManager(cache, QueryBuildType.MEASURES) params = QueryRequestParams( @@ -89,49 +108,56 @@ async def test_get_or_load_respects_cache_control(): """ Full flow test to ensure Cache-Control is respected. 
""" - with patch( - "datajunction_server.internal.caching.query_cache_manager.get_measures_query", - return_value=[{"sql": "SELECT * FROM test"}], - ): - with patch( + mock_access_checker = mock.AsyncMock(spec=AccessChecker) + with ( + patch( + "datajunction_server.internal.caching.query_cache_manager.get_measures_query", + return_value=[{"sql": "SELECT * FROM test"}], + ), + patch( "datajunction_server.internal.caching.query_cache_manager.VersionedQueryKey.version_query_request", return_value="versioned123", - ): - cache = CachelibCache() - manager = QueryCacheManager(cache, QueryBuildType.MEASURES) - params = QueryRequestParams( - nodes=["foo"], - dimensions=["dim1"], - filters=[], - ) - - # Put stale value in cache to test hit vs miss - key = await manager.build_cache_key(DummyRequest(), params) - cache.set(key, [{"sql": "CACHED"}]) - - background = BackgroundTasks() - - # `no-cache` => should bypass cache - request = DummyRequest(cache_control="no-cache") - result = await manager.get_or_load(background, request, params) - assert result == [{"sql": "SELECT * FROM test"}] - - # Run tasks, should store - for task in background.tasks: - await task() - assert cache.get(key) == [{"sql": "SELECT * FROM test"}] - - # `no-store` => should hit cache, but not store - cache.set(key, [{"sql": "CACHED"}]) - request = DummyRequest(cache_control="no-store") - result = await manager.get_or_load(background, request, params) - assert result == [{"sql": "CACHED"}] # hits stale - - # `no-cache, no-store` => should always fallback but never store - request = DummyRequest(cache_control="no-cache, no-store") - result = await manager.get_or_load(background, request, params) - assert result == [{"sql": "SELECT * FROM test"}] - cache.get(key) == [{"sql": "CACHED"}] # still stale + ), + patch( + "datajunction_server.internal.caching.query_cache_manager.build_access_checker_from_request", + return_value=mock_access_checker, + ), + ): + cache = CachelibCache() + manager = QueryCacheManager(cache, QueryBuildType.MEASURES) + params = QueryRequestParams( + nodes=["foo"], + dimensions=["dim1"], + filters=[], + ) + + # Put stale value in cache to test hit vs miss + key = await manager.build_cache_key(DummyRequest(), params) + cache.set(key, [{"sql": "CACHED"}]) + + background = BackgroundTasks() + + # `no-cache` => should bypass cache + request = DummyRequest(cache_control="no-cache") + result = await manager.get_or_load(background, request, params) + assert result == [{"sql": "SELECT * FROM test"}] + + # Run tasks, should store + for task in background.tasks: + await task() + assert cache.get(key) == [{"sql": "SELECT * FROM test"}] + + # `no-store` => should hit cache, but not store + cache.set(key, [{"sql": "CACHED"}]) + request = DummyRequest(cache_control="no-store") + result = await manager.get_or_load(background, request, params) + assert result == [{"sql": "CACHED"}] # hits stale + + # `no-cache, no-store` => should always fallback but never store + request = DummyRequest(cache_control="no-cache, no-store") + result = await manager.get_or_load(background, request, params) + assert result == [{"sql": "SELECT * FROM test"}] + cache.get(key) == [{"sql": "CACHED"}] # still stale @pytest.mark.asyncio diff --git a/datajunction-server/tests/internal/deployment/orchestration_test.py b/datajunction-server/tests/internal/deployment/orchestration_test.py index 1cf30eac7..dd8e8fe31 100644 --- a/datajunction-server/tests/internal/deployment/orchestration_test.py +++ 
b/datajunction-server/tests/internal/deployment/orchestration_test.py @@ -42,7 +42,6 @@ def mock_deployment_context(current_user: User): context.current_user = current_user context.request = Mock() context.query_service_client = Mock() - context.validate_access = AsyncMock(return_value=True) context.background_tasks = Mock() context.save_history = AsyncMock() context.cache = Mock() diff --git a/datajunction-server/tests/models/access_test.py b/datajunction-server/tests/models/access_test.py new file mode 100644 index 000000000..1781e1e32 --- /dev/null +++ b/datajunction-server/tests/models/access_test.py @@ -0,0 +1,128 @@ +""" +Tests for ``datajunction_server.models.access``. +""" + +from datajunction_server.models.access import ( + Resource, + ResourceAction, + ResourceRequest, + ResourceType, +) + + +class TestResource: + """Tests for Resource dataclass""" + + def test_resource_hash(self) -> None: + """Test Resource.__hash__ method (line 43)""" + resource1 = Resource(name="test.node", resource_type=ResourceType.NODE) + resource2 = Resource(name="test.node", resource_type=ResourceType.NODE) + resource3 = Resource(name="other.node", resource_type=ResourceType.NODE) + resource4 = Resource(name="test.node", resource_type=ResourceType.NAMESPACE) + + # Same name and type should have same hash + assert hash(resource1) == hash(resource2) + + # Different name should have different hash + assert hash(resource1) != hash(resource3) + + # Different type should have different hash + assert hash(resource1) != hash(resource4) + + # Resources can be used in sets and dicts + resource_set = {resource1, resource2, resource3} + assert len(resource_set) == 2 # resource1 and resource2 are same + + resource_dict = {resource1: "value1"} + assert resource_dict[resource2] == "value1" # Same hash, same key + + +class TestResourceRequest: + """Tests for ResourceRequest dataclass""" + + def test_resource_request_hash(self) -> None: + """Test ResourceRequest.__hash__ method (line 71)""" + resource = Resource(name="test.node", resource_type=ResourceType.NODE) + + request1 = ResourceRequest(verb=ResourceAction.READ, access_object=resource) + request2 = ResourceRequest(verb=ResourceAction.READ, access_object=resource) + request3 = ResourceRequest(verb=ResourceAction.WRITE, access_object=resource) + + other_resource = Resource(name="other.node", resource_type=ResourceType.NODE) + request4 = ResourceRequest( + verb=ResourceAction.READ, + access_object=other_resource, + ) + + # Same verb and resource should have same hash + assert hash(request1) == hash(request2) + + # Different verb should have different hash + assert hash(request1) != hash(request3) + + # Different resource should have different hash + assert hash(request1) != hash(request4) + + # ResourceRequests can be used in sets and dicts + request_set = {request1, request2, request3} + assert len(request_set) == 2 # request1 and request2 are same + + def test_resource_request_eq(self) -> None: + """Test ResourceRequest.__eq__ method (line 74)""" + resource = Resource(name="test.node", resource_type=ResourceType.NODE) + other_resource = Resource(name="other.node", resource_type=ResourceType.NODE) + + request1 = ResourceRequest(verb=ResourceAction.READ, access_object=resource) + request2 = ResourceRequest(verb=ResourceAction.READ, access_object=resource) + request3 = ResourceRequest(verb=ResourceAction.WRITE, access_object=resource) + request4 = ResourceRequest( + verb=ResourceAction.READ, + access_object=other_resource, + ) + + # Same verb and access_object + assert 
request1 == request2 + + # Different verb + assert request1 != request3 + + # Different access_object + assert request1 != request4 + + def test_resource_request_str(self) -> None: + """Test ResourceRequest.__str__ method (line 77)""" + node_resource = Resource(name="test.node", resource_type=ResourceType.NODE) + namespace_resource = Resource( + name="test.namespace", + resource_type=ResourceType.NAMESPACE, + ) + + read_request = ResourceRequest( + verb=ResourceAction.READ, + access_object=node_resource, + ) + assert str(read_request) == "read:node/test.node" + + write_request = ResourceRequest( + verb=ResourceAction.WRITE, + access_object=namespace_resource, + ) + assert str(write_request) == "write:namespace/test.namespace" + + execute_request = ResourceRequest( + verb=ResourceAction.EXECUTE, + access_object=node_resource, + ) + assert str(execute_request) == "execute:node/test.node" + + delete_request = ResourceRequest( + verb=ResourceAction.DELETE, + access_object=node_resource, + ) + assert str(delete_request) == "delete:node/test.node" + + manage_request = ResourceRequest( + verb=ResourceAction.MANAGE, + access_object=namespace_resource, + ) + assert str(manage_request) == "manage:namespace/test.namespace"
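
The RBAC tests above pin down a small set of authorization semantics: an action hierarchy in which MANAGE implies everything, WRITE and EXECUTE each imply READ, and READ implies nothing further; namespace scopes such as `finance.*` that cover both the namespaces and the nodes under that prefix, while a NODE scope never covers a NAMESPACE; and `""` or `"*"` acting as a global grant limited to the scope's own resource type. The sketch below restates those rules as standalone functions so they can be read in one place. It is an illustration only, not the `RBACAuthorizationService` implementation (which this diff does not include); the `Action` enum, `IMPLIES` table, `scope_covers`, and `has_permission` names are hypothetical, and the assumption that DELETE also implies WRITE is inferred solely from the "MANAGE > DELETE > WRITE > READ" docstring.

```python
# Illustrative sketch of the semantics asserted by the tests above.
# Not the project's actual authorization code; names here are hypothetical.
from enum import Enum


class Action(str, Enum):
    READ = "read"
    WRITE = "write"
    DELETE = "delete"
    EXECUTE = "execute"
    MANAGE = "manage"


# Actions implied by each granted action. WRITE -> READ, EXECUTE -> READ, and
# MANAGE -> everything are asserted by TestPermissionHierarchy; DELETE -> WRITE
# is an assumption taken from the "MANAGE > DELETE > WRITE > READ" docstring.
IMPLIES: dict[Action, set[Action]] = {
    Action.MANAGE: {
        Action.MANAGE, Action.DELETE, Action.WRITE, Action.EXECUTE, Action.READ,
    },
    Action.DELETE: {Action.DELETE, Action.WRITE, Action.READ},
    Action.WRITE: {Action.WRITE, Action.READ},
    Action.EXECUTE: {Action.EXECUTE, Action.READ},
    Action.READ: {Action.READ},
}


def scope_covers(
    scope_type: str,
    scope_value: str,
    resource_type: str,
    resource_name: str,
) -> bool:
    """Does a role scope cover the requested resource?"""
    if scope_value in ("", "*"):
        # Global grant, but only for the scope's own resource type: the tests show
        # an empty NAMESPACE scope does not approve NODE requests.
        return scope_type == resource_type
    if scope_type == "namespace":
        # "finance.*" covers the finance namespaces themselves and any node beneath.
        prefix = scope_value.rstrip("*").rstrip(".")
        return resource_name == prefix or resource_name.startswith(prefix + ".")
    # A NODE scope covers only that exact node, never a namespace.
    return resource_type == "node" and resource_name == scope_value


def has_permission(
    granted_action: Action,
    scope_type: str,
    scope_value: str,
    requested_action: Action,
    resource_type: str,
    resource_name: str,
) -> bool:
    """One grant approving (or not) one request, per the rules the tests assert."""
    return requested_action in IMPLIES[granted_action] and scope_covers(
        scope_type,
        scope_value,
        resource_type,
        resource_name,
    )


# Mirrors the namespace-scope tests: a READ grant on "finance.*" approves reading
# a node inside that namespace but not one under "marketing".
assert has_permission(
    Action.READ, "namespace", "finance.*", Action.READ, "node", "finance.revenue.total",
)
assert not has_permission(
    Action.READ, "namespace", "finance.*", Action.READ, "node", "marketing.revenue.total",
)
```

Under these rules a WRITE grant on the `finance.*` namespace would approve `write:node/finance.revenue` but not `delete:node/finance.revenue`, matching the hierarchy tests above.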