Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions datajunction-server/datajunction_server/api/cubes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
from datajunction_server.database.user import User
from datajunction_server.errors import DJInvalidInputException
from datajunction_server.internal.access.authentication.http import SecureAPIRouter
from datajunction_server.internal.access.authorization import validate_access
from datajunction_server.internal.access.authorization import (
AccessChecker,
get_access_checker,
)
from datajunction_server.internal.materializations import build_cube_materialization
from datajunction_server.internal.nodes import (
get_all_cube_revisions_metadata,
get_single_cube_revision_metadata,
)
from datajunction_server.models import access
from datajunction_server.models.cube import (
CubeRevisionMetadata,
DimensionValue,
Expand Down Expand Up @@ -208,9 +210,7 @@ async def get_cube_dimension_sql(
include_counts: bool = False,
session: AsyncSession = Depends(get_session),
current_user: User = Depends(get_current_user),
validate_access: access.ValidateAccessFn = Depends(
validate_access,
),
access_checker: AccessChecker = Depends(get_access_checker),
) -> TranslatedSQL:
"""
Generates SQL to retrieve all unique values of a dimension for the cube
Expand All @@ -222,7 +222,7 @@ async def get_cube_dimension_sql(
node_revision,
dimensions,
current_user,
validate_access,
access_checker,
filters,
limit,
include_counts,
Expand Down Expand Up @@ -251,9 +251,7 @@ async def get_cube_dimension_values(
request: Request,
query_service_client: QueryServiceClient = Depends(get_query_service_client),
current_user: User = Depends(get_current_user),
validate_access: access.ValidateAccessFn = Depends(
validate_access,
),
access_checker: AccessChecker = Depends(get_access_checker),
) -> DimensionValues:
"""
All unique values of a dimension from the cube
Expand All @@ -266,7 +264,7 @@ async def get_cube_dimension_values(
cube,
dimensions,
current_user,
validate_access,
access_checker,
filters,
limit,
include_counts,
Expand Down
75 changes: 19 additions & 56 deletions datajunction-server/datajunction_server/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@
)
from datajunction_server.internal.access.authentication.http import SecureAPIRouter
from datajunction_server.internal.access.authorization import (
validate_access,
validate_access_requests,
AccessChecker,
AccessDenialMode,
get_access_checker,
)
from datajunction_server.internal.history import ActivityType, EntityType
from datajunction_server.models import access
Expand Down Expand Up @@ -66,9 +67,7 @@ async def add_availability_state(
*,
session: AsyncSession = Depends(get_session),
current_user: User = Depends(get_current_user),
validate_access: access.ValidateAccessFn = Depends(
validate_access,
),
access_checker: AccessChecker = Depends(get_access_checker),
save_history: Callable = Depends(get_save_history),
) -> JSONResponse:
"""
Expand All @@ -90,17 +89,13 @@ async def add_availability_state(

# Source nodes require that any availability states set are for one of the defined tables
node_revision = node.current # type: ignore
validate_access_requests(
validate_access,
current_user,
[
access.ResourceRequest(
verb=access.ResourceAction.WRITE,
access_object=access.Resource.from_node(node_revision),
),
],
True,
access_checker.add_request(
access.ResourceRequest(
verb=access.ResourceAction.WRITE,
access_object=access.Resource.from_node(node_revision),
),
)
await access_checker.check(on_denied=AccessDenialMode.RAISE)

if node.current.type == NodeType.SOURCE: # type: ignore
if (
Expand Down Expand Up @@ -190,9 +185,7 @@ async def remove_availability_state(
*,
session: AsyncSession = Depends(get_session),
current_user: User = Depends(get_current_user),
validate_access: access.ValidateAccessFn = Depends(
validate_access,
),
access_checker: AccessChecker = Depends(get_access_checker),
save_history: Callable = Depends(get_save_history),
) -> JSONResponse:
"""
Expand All @@ -215,17 +208,13 @@ async def remove_availability_state(
),
)

validate_access_requests(
validate_access,
current_user,
[
access.ResourceRequest(
verb=access.ResourceAction.WRITE,
access_object=access.Resource.from_node(node),
),
],
True,
access_checker.add_request(
access.ResourceRequest(
verb=access.ResourceAction.WRITE,
access_object=access.Resource.from_node(node.current), # type: ignore
),
)
await access_checker.check(on_denied=AccessDenialMode.RAISE)

# Save the old availability state for history record
old_availability = (
Expand Down Expand Up @@ -281,10 +270,6 @@ async def get_data(
query_service_client: QueryServiceClient = Depends(get_query_service_client),
engine_name: Optional[str] = None,
engine_version: Optional[str] = None,
current_user: User = Depends(get_current_user),
validate_access: access.ValidateAccessFn = Depends(
validate_access,
),
background_tasks: BackgroundTasks,
cache: Cache = Depends(get_cache),
) -> QueryWithResults:
Expand All @@ -310,8 +295,6 @@ async def get_data(
engine_version=engine_version,
use_materialized=use_materialized,
ignore_errors=ignore_errors,
current_user=current_user,
validate_access=validate_access,
),
)

Expand Down Expand Up @@ -362,10 +345,6 @@ async def get_data_stream_for_node(
query_service_client: QueryServiceClient = Depends(get_query_service_client),
engine_name: Optional[str] = None,
engine_version: Optional[str] = None,
current_user: User = Depends(get_current_user),
validate_access: access.ValidateAccessFn = Depends(
validate_access,
),
background_tasks: BackgroundTasks,
cache: Cache = Depends(get_cache),
) -> QueryWithResults:
Expand Down Expand Up @@ -401,8 +380,6 @@ async def get_data_stream_for_node(
engine_version=engine_version,
use_materialized=True,
ignore_errors=False,
current_user=current_user,
validate_access=validate_access,
),
)
query_create = QueryCreate(
Expand Down Expand Up @@ -468,10 +445,6 @@ async def get_data_for_metrics(
query_service_client: QueryServiceClient = Depends(get_query_service_client),
engine_name: Optional[str] = None,
engine_version: Optional[str] = None,
current_user: User = Depends(get_current_user),
validate_access: access.ValidateAccessFn = Depends(
validate_access,
),
cache: Cache = Depends(get_cache),
background_tasks: BackgroundTasks,
) -> QueryWithResults:
Expand All @@ -497,8 +470,6 @@ async def get_data_for_metrics(
engine_version=engine_version,
use_materialized=True,
ignore_errors=False,
current_user=current_user,
validate_access=validate_access,
),
)
node = cast(
Expand Down Expand Up @@ -544,20 +515,12 @@ async def get_data_stream_for_metrics(
engine_name: Optional[str] = None,
engine_version: Optional[str] = None,
current_user: User = Depends(get_current_user),
validate_access: access.ValidateAccessFn = Depends(
validate_access,
),
access_checker: AccessChecker = Depends(get_access_checker),
) -> QueryWithResults:
"""
Return data for a set of metrics with dimensions and filters using server sent events
"""
request_headers = dict(request.headers)
access_control = access.AccessControlStore(
validate_access=validate_access,
user=current_user,
base_verb=access.ResourceAction.READ,
)

translated_sql, engine, catalog = await build_sql_for_multiple_metrics(
session,
metrics,
Expand All @@ -567,7 +530,7 @@ async def get_data_stream_for_metrics(
limit,
engine_name,
engine_version,
access_control,
access_checker,
)

query_create = QueryCreate(
Expand Down
20 changes: 15 additions & 5 deletions datajunction-server/datajunction_server/api/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
from datajunction_server.internal.deployment.utils import DeploymentContext
from datajunction_server.internal.access.authentication.http import SecureAPIRouter
from datajunction_server.internal.access.authorization import (
validate_access,
AccessChecker,
AccessDenialMode,
get_access_checker,
)
from datajunction_server.models import access
from datajunction_server.models.deployment import DeploymentStatus
Expand Down Expand Up @@ -158,22 +160,30 @@ async def create_deployment(
current_user: User = Depends(get_current_user),
query_service_client: QueryServiceClient = Depends(get_query_service_client),
cache: Cache = Depends(get_cache),
validate_access: access.ValidateAccessFn = Depends(
validate_access,
),
access_checker: AccessChecker = Depends(get_access_checker),
) -> DeploymentInfo:
"""
This endpoint takes a deployment specification (namespace, nodes, tags), topologically
sorts and validates the deployable objects, and deploys the nodes in parallel where
possible. It returns a summary of the deployment.
"""
access_checker.add_request(
access.ResourceRequest(
verb=access.ResourceAction.WRITE,
access_object=access.Resource(
resource_type=access.ResourceType.NAMESPACE,
name=deployment_spec.namespace,
),
),
)
await access_checker.check(on_denied=AccessDenialMode.RAISE)

deployment_id = await executor.submit(
spec=deployment_spec,
context=DeploymentContext(
current_user=current_user,
request=request,
query_service_client=query_service_client,
validate_access=validate_access,
background_tasks=background_tasks,
cache=cache,
),
Expand Down
Loading
Loading