27 changes: 17 additions & 10 deletions datajunction-server/datajunction_server/internal/namespaces.py
@@ -8,7 +8,7 @@

from sqlalchemy import or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import joinedload, selectinload

from datajunction_server.api.helpers import get_node_namespace
from datajunction_server.database.history import ActivityType, EntityType, History
@@ -26,6 +26,7 @@
)
from datajunction_server.models.node import NodeMinimumDetail
from datajunction_server.models.node_type import NodeType
from datajunction_server.sql.dag import topological_sort
from datajunction_server.typing import UTCDatetime
from datajunction_server.utils import SEPARATOR

@@ -229,10 +230,10 @@ async def hard_delete_namespace(
"""
Hard delete a node namespace.
"""
node_names = (
nodes = (
(
await session.execute(
select(Node.name)
select(Node)
.where(
or_(
Node.namespace.like(
Expand All @@ -241,27 +242,33 @@ async def hard_delete_namespace(
Node.namespace == namespace,
),
)
.order_by(Node.name),
.order_by(Node.name)
.options(
joinedload(Node.current).options(
selectinload(NodeRevision.parents),
),
),
)
)
.unique()
.scalars()
.all()
)

if not cascade and node_names:
if not cascade and nodes:
raise DJActionNotAllowedException(
message=(
f"Cannot hard delete namespace `{namespace}` as there are still the "
f"following nodes under it: `{node_names}`. Set `cascade` to true to "
"additionally hard delete the above nodes in this namespace. WARNING:"
f"following nodes under it: `{[node.name for node in nodes]}`. Set `cascade` to "
"true to additionally hard delete the above nodes in this namespace. WARNING:"
" this action cannot be undone."
),
)

impacts = {}
for node_name in node_names:
impacts[node_name] = await hard_delete_node(
node_name,
for node in reversed(topological_sort(nodes)):
Review comment (Contributor Author): While we don't have to do a topological sort here, doing so means we process the deletes for nodes in the namespace in an order that minimizes unnecessary invalidation followed by deletion.

impacts[node.name] = await hard_delete_node(
node.name,
session,
current_user=current_user,
)
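To make the ordering concrete: below is a minimal sketch of a topological sort over DJ nodes, assuming each node exposes its upstream parents via `node.current.parents`. The `topo_sort` and `by_name` names are illustrative only; the real implementation is `datajunction_server.sql.dag.topological_sort`.

```python
from collections import deque

def topo_sort(nodes):
    """Order nodes so each one appears after all of its in-namespace parents."""
    by_name = {node.name: node for node in nodes}
    # Count, per node, how many of its parents also live in this namespace.
    in_degree = {
        node.name: sum(1 for parent in node.current.parents if parent.name in by_name)
        for node in nodes
    }
    queue = deque(sorted(name for name, degree in in_degree.items() if degree == 0))
    ordered = []
    while queue:
        name = queue.popleft()
        ordered.append(by_name[name])
        # Unblock children: nodes that list the just-emitted node as a parent.
        for node in nodes:
            if any(parent.name == name for parent in node.current.parents):
                in_degree[node.name] -= 1
                if in_degree[node.name] == 0:
                    queue.append(node.name)
    return ordered
```

Iterating over `reversed(ordered)` then deletes the most-downstream nodes (metrics and other leaves) first, so no delete has to invalidate a node that is itself about to be deleted.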
98 changes: 56 additions & 42 deletions datajunction-server/datajunction_server/internal/nodes.py
@@ -868,7 +868,7 @@ async def propagate_update_downstream( # pylint: disable=too-many-locals
)

# The downstreams need to be sorted topologically in order for the updates to be done
# in the right order. Otherwise it is possible for a leaf node like a metric to be updated
# in the right order. Otherwise, it is possible for a leaf node like a metric to be updated
# before its upstreams are updated.
for downstream in downstreams:
original_node_revision = downstream.current
@@ -1866,14 +1866,13 @@ async def revalidate_node( # pylint: disable=too-many-locals,too-many-statement

# Check if any columns have been updated
existing_columns = {col.name: col for col in node.current.columns} # type: ignore
updated_columns = False
updated_columns = len(current_node_revision.columns) != len(node_validator.columns)
Review comment (Member): nice catch!

for col in node_validator.columns:
if existing_col := existing_columns.get(col.name):
if existing_col.type != col.type:
existing_col.type = col.type
updated_columns = True
else:
node.current.columns.append(col) # type: ignore # pragma: no cover
updated_columns = True # pragma: no cover

# Only create a new revision if the columns have been updated
Expand All @@ -1893,16 +1892,16 @@ async def revalidate_node( # pylint: disable=too-many-locals,too-many-statement
node_validator.updated_columns = node_validator.modified_columns(
new_revision, # type: ignore
)
new_revision.columns = node_validator.columns

# Save the new revision of the child
new_revision.columns = node_validator.columns
node.current_version = new_revision.version # type: ignore
new_revision.node_id = node.id # type: ignore
session.add(node)
session.add(new_revision)
await session.commit()
await session.refresh(node.current) # type: ignore
await session.refresh(node, ["current"])
session.add(node)
await session.commit()
await session.refresh(node.current) # type: ignore
await session.refresh(node, ["current"])
return node_validator
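
The length comparison added above is what catches dropped columns: the per-column loop only ever sees columns that still exist in the validator's output, so a pure removal would otherwise go undetected. A self-contained sketch of the combined check, using a hypothetical `columns_changed` helper and simplified column objects:

```python
def columns_changed(old_columns, new_columns) -> bool:
    # A differing count means a column was added or dropped outright.
    changed = len(old_columns) != len(new_columns)
    existing = {col.name: col for col in old_columns}
    for col in new_columns:
        if (prior := existing.get(col.name)) is not None:
            # Same name but a different type: an in-place type change.
            changed = changed or prior.type != col.type
        else:
            # A brand-new column also counts as a change.
            changed = True
    return changed
```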


@@ -1918,7 +1917,14 @@ async def hard_delete_node(
node = await Node.get_by_name(
session,
name,
options=[joinedload(Node.current), joinedload(Node.revisions)],
options=[
joinedload(Node.current),
joinedload(Node.revisions).options(
selectinload(NodeRevision.columns).options(
joinedload(Column.attributes),
),
),
],
include_inactive=True,
raise_if_not_exists=False,
)
@@ -1946,42 +1952,50 @@
user=current_user.username if current_user else None,
),
)
node_validator = await revalidate_node(
name=node.name,
session=session,
current_user=current_user,
)
impact.append(
{
"name": node.name,
"status": node_validator.status,
"effect": "downstream node is now invalid",
},
)
try:
node_validator = await revalidate_node(
name=node.name,
session=session,
current_user=current_user,
)
impact.append(
{
"name": node.name,
"status": node_validator.status,
"effect": "downstream node is now invalid",
},
)
except DJNodeNotFound:
_logger.warning("Node not found %s", node.name)

# Revalidate all linked nodes
for node in linked_nodes:
session.add( # Capture this in the downstream node's history
History(
entity_type=EntityType.LINK,
entity_name=name,
node=node.name,
activity_type=ActivityType.DELETE,
user=current_user.username if current_user else None,
),
)
node_validator = await revalidate_node(
name=node.name,
session=session,
current_user=current_user,
)
impact.append(
{
"name": node.name,
"status": node_validator.status,
"effect": "broken link",
},
)
if node:
session.add( # Capture this in the downstream node's history
History(
entity_type=EntityType.LINK,
entity_name=name,
node=node.name,
activity_type=ActivityType.DELETE,
user=current_user.username if current_user else None,
),
)
try:
node_validator = await revalidate_node(
name=node.name,
session=session,
current_user=current_user,
# update=False,
)
impact.append(
{
"name": node.name,
"status": node_validator.status,
"effect": "broken link",
},
)
except DJNodeNotFound:
_logger.warning("Node not found %s", node.name)
session.add( # Capture this in the downstream node's history
History(
entity_type=EntityType.NODE,
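Both revalidation loops above now share the same guard: if a downstream or linked node was removed earlier in the same cascade, `revalidate_node` raises `DJNodeNotFound`, and the delete logs a warning and moves on instead of failing the whole request. A condensed sketch of that pattern, with a hypothetical `revalidate_survivors` helper reusing the names from the diff:

```python
async def revalidate_survivors(names, session, current_user, effect):
    impact = []
    for name in names:
        try:
            validator = await revalidate_node(
                name=name,
                session=session,
                current_user=current_user,
            )
            impact.append(
                {"name": name, "status": validator.status, "effect": effect},
            )
        except DJNodeNotFound:
            # The node may itself have been hard deleted in this cascade.
            _logger.warning("Node not found %s", name)
    return impact
```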
12 changes: 9 additions & 3 deletions datajunction-server/datajunction_server/internal/validation.py
@@ -1,12 +1,12 @@
"""Node validation functions."""
from dataclasses import dataclass, field
from typing import Dict, List, Set, Union
from typing import Dict, List, Optional, Set, Union

from sqlalchemy.exc import MissingGreenlet
from sqlalchemy.ext.asyncio import AsyncSession

from datajunction_server.api.helpers import find_bound_dimensions
from datajunction_server.database import Column, Node, NodeRevision
from datajunction_server.database import Column, ColumnAttribute, Node, NodeRevision
from datajunction_server.errors import DJError, DJException, ErrorCode
from datajunction_server.models.base import labelize
from datajunction_server.models.node import NodeRevisionBase, NodeStatus
@@ -22,6 +22,7 @@ class NodeValidator: # pylint: disable=too-many-instance-attributes
Node validation
"""

query_ast: Optional[ast.Query] = None
status: NodeStatus = NodeStatus.VALID
columns: List[Column] = field(default_factory=list)
required_dimensions: List[Column] = field(default_factory=list)
@@ -128,7 +129,12 @@ async def validate_node_data( # pylint: disable=too-many-locals,too-many-statem
name=column_name,
display_name=labelize(column_name),
type=column_type,
attributes=existing_column.attributes if existing_column else [],
attributes=[
Review comment (Contributor Author): This was an annoying issue: the attributes field on the Column database object actually holds ColumnAttribute objects, not Attribute objects. Since ColumnAttribute is the relational link between an Attribute and a Column, we need to recreate it for any copied columns; otherwise it may keep referencing a column that no longer exists and cause problems downstream.

ColumnAttribute(attribute_type=col_attr.attribute_type)
for col_attr in existing_column.attributes
]
if existing_column
else [],
dimension=existing_column.dimension if existing_column else None,
order=idx,
)
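As the review comment above explains, ColumnAttribute is an association row linking exactly one Column to one AttributeType, so the link objects cannot be shared between an old column and its copy. A minimal sketch of the copy step, with a hypothetical `copy_attributes` helper:

```python
def copy_attributes(existing_column):
    """Mint fresh ColumnAttribute links that reuse only the AttributeType."""
    if existing_column is None:
        return []
    return [
        ColumnAttribute(attribute_type=col_attr.attribute_type)
        for col_attr in existing_column.attributes
    ]
```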
87 changes: 81 additions & 6 deletions datajunction-server/datajunction_server/sql/parsing/ast.py
@@ -38,6 +38,7 @@
from datajunction_server.models.column import SemanticType
from datajunction_server.models.node import BuildCriteria
from datajunction_server.models.node_type import NodeType as DJNodeType
from datajunction_server.naming import SEPARATOR
from datajunction_server.sql.functions import function_registry, table_function_registry
from datajunction_server.sql.parsing.backends.exceptions import DJParseException
from datajunction_server.sql.parsing.types import (
@@ -1102,7 +1103,6 @@ class Wildcard(Named, Expression):
Wildcard or '*' expression
"""

name: Name = field(init=False, repr=False, default=Name("*"))
_table: Optional["Table"] = field(repr=False, default=None)

@property
@@ -1121,12 +1121,65 @@ def add_table(self, table: "Table") -> "Wildcard":
return self

def __str__(self) -> str:
return "*"
return (self.namespace[0].name + SEPARATOR if self.namespace else "") + "*"

@property
def type(self) -> ColumnType:
return WildcardType()

async def compile(self, ctx: CompileContext):
"""
Compile a Wildcard AST node. If the wildcard is used in a SELECT statement, we
replace it with the equivalent of explicitly selecting all upstream columns.
"""
if not isinstance(self.parent, Select):
return await super().compile(ctx)

wildcard_parent = cast(Select, self.parent)

# If a wildcard is used in a SELECT statement, we pull the columns that it
# represents by scanning all relations in the FROM clause (both the primary table
# and any joined tables)
wildcard_table_namespace = (
self.namespace[0].name
if self.namespace and isinstance(self.namespace, list)
else None
)
from_relations = [
wildcard_parent.from_.relations[0].primary,
*[ext.right for ext in wildcard_parent.from_.relations[0].extensions],
]

for relation in from_relations:
await relation.compile(ctx)
if (
not wildcard_table_namespace
or relation.alias_or_name.name == wildcard_table_namespace
):
# Figure out where the relation's columns are stored depending on the relation type
if isinstance(relation, Table):
wildcard_origin = cast(Table, relation)
if wildcard_origin_node := wildcard_origin.dj_node:
relation_columns = wildcard_origin_node.columns
else:
relation_columns = wildcard_origin._cte_columns
else:
wildcard_origin = cast(Query, relation)
relation_columns = wildcard_origin.select.projection

# Use these columns to replace the wildcard
for col in relation_columns:
wildcard_parent.projection.append(
Column(
name=Name(col.name)
if isinstance(col.name, str)
else col.name,
_table=wildcard_origin,
_type=col.type,
),
)
wildcard_parent.projection.remove(self)
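
For intuition about what Wildcard.compile produces, here is a standalone sketch of qualified-wildcard expansion over plain dictionaries; the data shapes and the `expand_wildcard` name are hypothetical stand-ins, not the AST classes above:

```python
def expand_wildcard(
    relations: dict[str, list[str]],  # FROM-clause alias -> column names
    qualifier: str | None,            # the `t` in `t.*`, or None for bare `*`
) -> list[str]:
    expanded = []
    for alias, columns in relations.items():
        # A bare `*` pulls from every relation; `alias.*` only from the match.
        if qualifier is None or alias == qualifier:
            expanded.extend(f"{alias}.{column}" for column in columns)
    return expanded

# expand_wildcard({"t": ["a", "b"], "o": ["x"]}, None) -> ['t.a', 't.b', 'o.x']
# expand_wildcard({"t": ["a", "b"], "o": ["x"]}, "t")  -> ['t.a', 't.b']
```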


@dataclass(eq=False)
class TableExpression(Aliasable, Expression):
@@ -1140,6 +1193,7 @@ class TableExpression(Aliasable, Expression):
) # all those expressions that can be had from the table; usually derived from dj node metadata for Table
# ref (referenced) columns are columns used elsewhere from this table
_ref_columns: List[Column] = field(init=False, repr=False, default_factory=list)
_cte_columns: List[Expression] = field(default_factory=list)

@property
def columns(self) -> List[Expression]:
@@ -1345,8 +1399,12 @@ def set_alias(self: TNode, alias: "Name") -> TNode:
return self

async def compile(self, ctx: CompileContext):
# things we can validate here:
# - if the node is a dimension in a groupby, is it joinable?
"""
Compile a Table AST node by finding and saving the columns it references
"""
if self._is_compiled:
return

self._is_compiled = True
try:
if not self.dj_node:
Expand All @@ -1356,12 +1414,26 @@ async def compile(self, ctx: CompileContext):
{DJNodeType.SOURCE, DJNodeType.TRANSFORM, DJNodeType.DIMENSION},
)
self.set_dj_node(dj_node)
except DJErrorException as exc:
ctx.exception.errors.append(exc.dj_error)

if self.dj_node:
# If the Table object is a reference to a DJ node, save the columns of the
# DJ node into self._columns for later use
self._columns = [
Column(Name(col.name), _type=col.type, _table=self)
for col in self.dj_node.columns
]
except DJErrorException as exc:
ctx.exception.errors.append(exc.dj_error)
elif query := self.get_nearest_parent_of_type(Query):
# If the Table object is a reference to a CTE, save the columns output by
# the CTE into self._columns for later use
for cte in query.ctes:
if self.alias_or_name.name == cte.alias_or_name.name:
await cte.compile(ctx)
self._cte_columns = [
Column(col.alias_or_name, _type=col.type, _table=self)
for col in cte._columns
]
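
The elif branch above covers table references that point at a WITH clause instead of a DJ node: in `WITH cte AS (SELECT a, b FROM src) SELECT cte.a FROM cte`, the reference `cte` resolves against the CTE's projection. A toy stand-in for that two-way lookup (hypothetical shapes and function name):

```python
def resolve_reference_columns(dj_columns, ctes, table_name):
    # Prefer catalog metadata when the reference resolves to a DJ node...
    if dj_columns is not None:
        return dj_columns
    # ...otherwise fall back to the output columns of a matching CTE.
    return ctes.get(table_name, [])

# resolve_reference_columns(None, {"cte": ["a", "b"]}, "cte") -> ['a', 'b']
```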


class Operation(Expression):
@@ -2565,6 +2637,9 @@ async def compile(self, ctx: CompileContext):
),
)
await super().compile(ctx)
for child in self.projection:
if isinstance(child, Wildcard):
await child.compile(ctx)


@dataclass(eq=False)