Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,16 @@ def build_component_expression(component: MetricComponent) -> ast.Expression:

For simple aggregations like SUM, this is: SUM(expression)
For templates like "SUM(POWER({}, 2))", expands to: SUM(POWER(expression, 2))

Note: Templates may be pre-expanded (e.g., "SUM(POWER(match_score, 2))")
by the decomposition phase, so we detect this by checking for parentheses
without template placeholders.
"""
if not component.aggregation: # pragma: no cover
# No aggregation - just return the expression as a column
return ast.Column(name=ast.Name(component.expression))

# Check if it's a template with {}
# Check if it's an unexpanded template with {}
if "{" in component.aggregation: # pragma: no cover
# Template like "SUM(POWER({}, 2))" - expand it
expanded = component.aggregation.replace("{}", component.expression)
Expand All @@ -214,6 +218,16 @@ def build_component_expression(component: MetricComponent) -> ast.Expression:
expr_ast = expr_ast.child
expr_ast.clear_parent()
return cast(ast.Expression, expr_ast)

# Check if it's a pre-expanded template (contains parentheses, like "SUM(POWER(x, 2))")
# vs a simple function name (like "SUM")
if "(" in component.aggregation:
# Pre-expanded template - parse it directly as a complete expression
expr_ast = parse(f"SELECT {component.aggregation}").select.projection[0]
if isinstance(expr_ast, ast.Alias):
expr_ast = expr_ast.child # pragma: no cover
expr_ast.clear_parent()
return cast(ast.Expression, expr_ast)
else:
# Simple function name like "SUM" - build SUM(expression)
arg_expr = parse(f"SELECT {component.expression}").select.projection[0]
Expand Down
201 changes: 158 additions & 43 deletions datajunction-server/datajunction_server/sql/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,28 +1073,94 @@ async def get_shared_dimensions(
) -> List[DimensionAttributeOutput]:
"""
Return a list of dimensions that are common between the metric nodes.

For each individual metric:
- If it has multiple parents (e.g., derived metric referencing multiple base metrics),
returns the union of dimensions from all its parents.

Across multiple metrics:
- Returns the intersection of dimensions available for each metric.

This allows derived metrics to use dimensions from any of their base metrics,
while still ensuring compatibility when querying multiple metrics together.
"""
parents = await get_metric_parents(session, metric_nodes)
return await get_common_dimensions(session, parents)
if not metric_nodes:
return []

# Get the per-metric parent mapping (batched for efficiency)
metric_to_parents = await get_metric_parents_map(session, metric_nodes)

async def get_metric_parents(
# For each metric, compute the UNION of dimensions from its parents
per_metric_dimensions: List[Dict[str, List[DimensionAttributeOutput]]] = []

for metric_node in metric_nodes:
parents = metric_to_parents.get(metric_node.name, [])
if not parents:
continue # pragma: no cover

# Compute union of dimensions from all parents
dims_by_name: Dict[str, List[DimensionAttributeOutput]] = {}
for parent in parents:
parent_dims = await group_dimensions_by_name(session, parent)
for dim_name, dim_list in parent_dims.items():
if dim_name not in dims_by_name:
dims_by_name[dim_name] = dim_list
# If already present, keep existing (they should be equivalent)

per_metric_dimensions.append(dims_by_name)

if not per_metric_dimensions:
return [] # pragma: no cover

if len(per_metric_dimensions) == 1:
# Single metric - return all its dimensions
return sorted(
[dim for dims in per_metric_dimensions[0].values() for dim in dims],
key=lambda x: (x.name, x.path),
)

# Multiple metrics - find intersection across metrics
common_names = set(per_metric_dimensions[0].keys())
for dims_by_name in per_metric_dimensions[1:]:
common_names &= set(dims_by_name.keys())

if not common_names:
return []

# Return dimensions from first metric that are in the intersection
return sorted(
[
dim
for name, dims in per_metric_dimensions[0].items()
if name in common_names
for dim in dims
],
key=lambda x: (x.name, x.path),
)


async def get_metric_parents_map(
session: AsyncSession,
metric_nodes: list[Node],
) -> list[Node]:
) -> Dict[str, List[Node]]:
"""
Return a list of non-metric parent nodes of the metrics.
Return a mapping from metric name to its non-metric parent nodes.

For derived metrics (metrics that reference base metrics), returns the
non-metric parents of those base metrics.

Note: Only 1 level of metric nesting is supported. Derived metrics can
reference base metrics, but not other derived metrics.
This batched version maintains the relationship between metrics and their
parents, which is needed to compute per-metric dimension unions.

Note: Only 1 level of metric nesting is supported.
"""
if not metric_nodes:
return [] # pragma: no cover
return {}

metric_names = {m.name for m in metric_nodes}
result: Dict[str, List[Node]] = {name: [] for name in metric_names}

# Query 1: Get all immediate parents for the input metrics
# Get all immediate parents for the input metrics WITH the child metric name
find_latest_node_revisions = [
and_(
NodeRevision.name == metric_node.name,
Expand All @@ -1103,7 +1169,7 @@ async def get_metric_parents(
for metric_node in metric_nodes
]
statement = (
select(Node)
select(NodeRevision.name.label("metric_name"), Node)
.where(or_(*find_latest_node_revisions))
.select_from(
join(
Expand All @@ -1116,42 +1182,91 @@ async def get_metric_parents(
),
)
)
immediate_parents = list(set((await session.execute(statement)).scalars().all()))

# Separate metric and non-metric parents
metric_parents = [p for p in immediate_parents if p.type == NodeType.METRIC]
non_metric_parents = [p for p in immediate_parents if p.type != NodeType.METRIC]

# Query 2: For metric parents (base metrics), get their parents
# With 1-level nesting, these must be non-metrics
if metric_parents:
find_base_metric_revisions = [
and_(
NodeRevision.name == m.name,
NodeRevision.version == m.current_version,
)
for m in metric_parents
]
statement = (
rows = (await session.execute(statement)).all()

# Build mapping and track metric parents that need further resolution
metric_parents_to_resolve: Dict[
str,
List[str],
] = {} # base_metric_name -> [derived_metric_names]

for metric_name, parent_node in rows:
if parent_node.type == NodeType.METRIC:
# This is a derived metric - need to get the base metric's parents
if parent_node.name not in metric_parents_to_resolve:
metric_parents_to_resolve[parent_node.name] = []
metric_parents_to_resolve[parent_node.name].append(metric_name)
else:
# Non-metric parent - add directly
result[metric_name].append(parent_node)

# For metric parents (base metrics), get their parents
if metric_parents_to_resolve:
base_metric_names = list(metric_parents_to_resolve.keys())
# Get the base metrics' current versions
base_metrics_stmt = (
select(Node)
.where(or_(*find_base_metric_revisions))
.select_from(
join(
.where(Node.name.in_(base_metric_names))
.where(is_(Node.deactivated_at, None))
)
base_metrics = list((await session.execute(base_metrics_stmt)).scalars().all())

if base_metrics: # pragma: no branch
find_base_metric_revisions = [
and_(
NodeRevision.name == m.name,
NodeRevision.version == m.current_version,
)
for m in base_metrics
]
statement = (
select(NodeRevision.name.label("base_metric_name"), Node)
.where(or_(*find_base_metric_revisions))
.select_from(
join(
NodeRevision,
NodeRelationship,
join(
NodeRevision,
NodeRelationship,
),
Node,
NodeRelationship.parent_id == Node.id,
),
Node,
NodeRelationship.parent_id == Node.id,
),
)
)
)
base_metric_parents = list(
set((await session.execute(statement)).scalars().all()),
)
non_metric_parents.extend(base_metric_parents)
base_rows = (await session.execute(statement)).all()

return list(set(non_metric_parents))
# Map base metric parents back to the derived metrics
for base_metric_name, parent_node in base_rows:
if parent_node.type != NodeType.METRIC: # pragma: no branch
# Add to all derived metrics that reference this base metric
for derived_metric_name in metric_parents_to_resolve.get(
base_metric_name,
[],
):
result[derived_metric_name].append(parent_node)

# Deduplicate parents for each metric
return {name: list(set(parents)) for name, parents in result.items()}


async def get_metric_parents(
    session: AsyncSession,
    metric_nodes: list[Node],
) -> list[Node]:
    """
    Collect the distinct non-metric parent nodes of the given metrics.

    The per-metric lookup is delegated to ``get_metric_parents_map``; this
    function flattens that mapping into one de-duplicated list, discarding
    the association between each metric and its parents.

    For derived metrics (metrics that reference base metrics), the result
    contains the non-metric parents of those base metrics.

    Note: Only 1 level of metric nesting is supported. Derived metrics can
    reference base metrics, but not other derived metrics.

    The order of the returned list is unspecified, since it is produced
    from a set.
    """
    parents_by_metric = await get_metric_parents_map(session, metric_nodes)
    # A parent shared by several metrics shows up once per metric in the
    # mapping; building a set collapses those repeats.
    unique_parents = {
        parent
        for parent_group in parents_by_metric.values()
        for parent in parent_group
    }
    return list(unique_parents)


async def get_common_dimensions(session: AsyncSession, nodes: list[Node]):
Expand All @@ -1160,7 +1275,7 @@ async def get_common_dimensions(session: AsyncSession, nodes: list[Node]):
"""
metric_nodes = [node for node in nodes if node.type == NodeType.METRIC]
other_nodes = [node for node in nodes if node.type != NodeType.METRIC]
if metric_nodes:
if metric_nodes: # pragma: no branch
nodes = list(set(other_nodes + await get_metric_parents(session, metric_nodes)))

common = await group_dimensions_by_name(session, nodes[0])
Expand All @@ -1172,7 +1287,7 @@ async def get_common_dimensions(session: AsyncSession, nodes: list[Node]):
common_dim_keys = common.keys() & list(node_dimensions.keys())
if not common_dim_keys:
return []
for dim_key in to_delete:
for dim_key in to_delete: # pragma: no cover
del common[dim_key] # pragma: no cover
return sorted(
[y for x in common.values() for y in x],
Expand Down
Loading
Loading