44 changes: 33 additions & 11 deletions datajunction-server/datajunction_server/construction/build_v2.py
@@ -287,12 +287,24 @@ async def create(
    ) -> "QueryBuilder":
        """
        Create a QueryBuilder instance for the node revision.
        """
        await refresh_if_needed(
            session,
            node_revision,
            ["required_dimensions", "dimension_links"],
        )

        Note: If the node was loaded with Node.get_by_name_eager(),
        required_dimensions and dimension_links will already be loaded.
        """
        # Only refresh if not already loaded (i.e., not from eager loading)
        # Check if dimension_links is loaded by seeing if it's accessible
        try:
            _ = node_revision.dimension_links
            links_loaded = True
        except Exception:
            links_loaded = False

        if not links_loaded:
            await refresh_if_needed(
                session,
                node_revision,
                ["required_dimensions", "dimension_links"],
            )
        instance = cls(session, node_revision, use_materialized=use_materialized)
        return instance

@@ -460,11 +472,21 @@ async def build(self) -> ast.Query:
        7. Add all requested dimensions to the final select.
        8. Add order by and limit to the final select (TODO)
        """
        await refresh_if_needed(
            self.session,
            self.node_revision,
            ["availability", "columns", "query_ast"],
        )
        # Only refresh if not already loaded (from eager loading)
        # Check if columns are accessible without triggering a query
        try:
            _ = self.node_revision.columns
            columns_loaded = True
        except Exception:
            columns_loaded = False

        if not columns_loaded:
            await refresh_if_needed(
                self.session,
                self.node_revision,
                ["availability", "columns", "query_ast"],
            )

        if self.node_revision.query_ast:
            node_ast = self.node_revision.query_ast  # pragma: no cover
        else:
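Both hunks above use the same probe: under an async session, touching an unloaded relationship raises instead of lazy-loading, so a failed attribute access means a refresh is still needed. An alternative that avoids relying on exceptions is SQLAlchemy's instance inspection API. A minimal sketch, assuming a small helper were added for this purpose (the name relationships_loaded is illustrative, not part of this change):

from sqlalchemy import inspect


def relationships_loaded(obj, *names: str) -> bool:
    """Return True if every named relationship on obj is already loaded."""
    # InstanceState.unloaded is the set of attribute names that have not been
    # loaded yet; reading it never triggers a lazy load.
    unloaded = inspect(obj).unloaded
    return not any(name in unloaded for name in names)


# Hypothetical use in QueryBuilder.create():
# if not relationships_loaded(node_revision, "required_dimensions", "dimension_links"):
#     await refresh_if_needed(
#         session,
#         node_revision,
#         ["required_dimensions", "dimension_links"],
#     )

The same helper would cover the try/except in build(), with the attribute names swapped for "availability", "columns", and "query_ast".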
111 changes: 111 additions & 0 deletions datajunction-server/datajunction_server/database/node.py
@@ -513,6 +513,117 @@ async def get_by_names(
        ordered_nodes = [nodes_by_name[name] for name in names if name in nodes_by_name]
        return ordered_nodes

    @classmethod
    async def get_by_name_eager(
        cls,
        session: AsyncSession,
        name: str,
        load_dimensions: bool = True,
        load_parents: bool = True,
        raise_if_not_exists: bool = False,
    ) -> Optional["Node"]:
        """
        Get a node by name with eager loading to minimize database queries.

        This method loads all commonly needed relationships in 1-2 queries instead
        of the N+1 pattern that happens with lazy loading. This significantly speeds
        up query building by eliminating redundant database round trips.

        Args:
            session: Database session
            name: Node name
            load_dimensions: Whether to eagerly load dimension links and their targets
            load_parents: Whether to eagerly load parent node relationships
            raise_if_not_exists: Whether to raise exception if node not found

        Returns:
            Node with all relationships eagerly loaded, or None if not found

        Example:
            >>> node = await Node.get_by_name_eager(session, "default.revenue", load_dimensions=True)
            >>> # All columns, dimension_links, availability already loaded - no extra queries!
            >>> columns = node.current.columns  # Already loaded
            >>> links = node.current.dimension_links  # Already loaded
        """
        from datajunction_server.database.dimensionlink import DimensionLink

        # Build base query
        statement = (
            select(Node).where(Node.name == name).where(is_(Node.deactivated_at, None))
        )

        # Eagerly load current revision with all its relationships
        options = [
            joinedload(Node.current).options(
                # Load columns with their attributes
                selectinload(NodeRevision.columns).options(
                    selectinload(Column.attributes),
                    joinedload(
                        Column.dimension,
                    ),  # Load dimension references on columns
                ),
                # Load availability state
                joinedload(NodeRevision.availability),
                # Load catalog and its engines
                joinedload(NodeRevision.catalog).options(
                    selectinload(Catalog.engines),
                ),
                # Load metric metadata if exists
                joinedload(NodeRevision.metric_metadata),
                # Load materializations
                selectinload(NodeRevision.materializations),
            ),
            # Load node-level relationships
            selectinload(Node.tags),
            selectinload(Node.created_by),
            selectinload(Node.owners),
        ]

        # Optionally load dimension links (common for query building with dimensions)
        if load_dimensions:
            options.append(
                joinedload(Node.current)
                .selectinload(NodeRevision.dimension_links)
                .options(
                    # Load the target dimension node
                    joinedload(DimensionLink.dimension).options(
                        # Load dimension's current revision
                        joinedload(Node.current).options(
                            selectinload(NodeRevision.columns).options(
                                selectinload(Column.attributes),
                            ),
                            joinedload(NodeRevision.availability),
                        ),
                    ),
                ),
            )

        # Optionally load parent nodes (for building upstream references)
        if load_parents:
            options.append(
                joinedload(Node.current)
                .selectinload(NodeRevision.parents)
                .options(
                    joinedload(Node.current).options(
                        selectinload(NodeRevision.columns),
                        joinedload(NodeRevision.availability),
                    ),
                ),
            )

        statement = statement.options(*options)

        result = await session.execute(statement)
        node = result.unique().scalar_one_or_none()

        if not node and raise_if_not_exists:
            raise DJNodeNotFound(
                message=f"Node `{name}` does not exist.",
                http_status_code=404,
            )

        return node

    @classmethod
    async def get_cube_by_name(
        cls,
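The loader split above follows standard SQLAlchemy guidance: joinedload folds to-one relationships (current, availability, catalog, metric_metadata) into the main SELECT as LEFT OUTER JOINs, while selectinload fetches collections such as columns, materializations, tags, and dimension_links with one extra SELECT ... WHERE ... IN (...) query, avoiding row explosion in the joined result. A standalone sketch of the same pattern, using hypothetical Order/Customer/Item models rather than anything from this codebase:

from sqlalchemy import select
from sqlalchemy.orm import joinedload, selectinload

# Hypothetical models: Order has a to-one relationship `customer`
# and a collection relationship `items`.
# Assumes an AsyncSession named `session`, inside an async function.
stmt = (
    select(Order)
    .where(Order.id == 42)
    .options(
        joinedload(Order.customer),  # to-one: merged into the main query via JOIN
        selectinload(Order.items),   # collection: one extra SELECT ... WHERE order_id IN (...)
    )
)
result = await session.execute(stmt)
order = result.unique().scalar_one_or_none()  # unique() is required whenever a joinedload
                                              # targets a collection, and harmless otherwise

With this shape, the number of round trips is bounded by the number of eager-loaded collections rather than by the number of rows, which is what removes the N+1 pattern described in the docstring.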
28 changes: 25 additions & 3 deletions datajunction-server/datajunction_server/internal/sql.py
@@ -66,12 +66,23 @@ async def build_node_sql(
    """
    Build node SQL and save it to query requests
    """
    import time

    start_time = time.time()

    if orderby:
        validate_orderby(orderby, [node_name], dimensions or [])

    # Use eager loading to minimize database queries
    node = cast(
        Node,
        await Node.get_by_name(session, node_name, raise_if_not_exists=True),
        await Node.get_by_name_eager(
            session,
            node_name,
            load_dimensions=bool(dimensions),  # Only load if we need dimensions
            load_parents=True,  # Always load for upstream references
            raise_if_not_exists=True,
        ),
    )
    if not engine:  # pragma: no cover
        engine = node.current.catalog.engines[0]
@@ -102,7 +113,7 @@ async def build_node_sql(
        return translated_sql

    # For all other nodes, build the node query
    node = await Node.get_by_name(session, node_name, raise_if_not_exists=True)  # type: ignore
    # (node already loaded above with eager loading, reuse it)
    if node.type == NodeType.METRIC:
        translated_sql, engine, _ = await build_sql_for_multiple_metrics(
            session,
@@ -140,12 +151,23 @@ async def build_node_sql(
    ]
    query = str(query_ast)

    return TranslatedSQL.create(
    result = TranslatedSQL.create(
        sql=query,
        columns=columns,
        dialect=engine.dialect if engine else None,
    )

    # Log timing
    build_time = (time.time() - start_time) * 1000
    logger.info(
        "build_node_sql completed: node=%s, dimensions=%d, time=%.2fms",
        node_name,
        len(dimensions or []),
        build_time,
    )

    return result


async def build_sql_for_multiple_metrics(
    session: AsyncSession,
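One way to confirm that the eager loading actually collapses the N+1 pattern is to count statements at the engine level in a test. The sketch below is illustrative only: the count_statements helper and the assertion threshold are assumptions, not part of this change, and the async_engine / session fixtures are presumed to exist in the test suite.

from sqlalchemy import event


def count_statements(async_engine) -> list[str]:
    """Attach a statement recorder to an AsyncEngine and return the tally list."""
    statements: list[str] = []

    # Engine events hang off the underlying sync Engine of an AsyncEngine.
    @event.listens_for(async_engine.sync_engine, "before_cursor_execute")
    def _record(conn, cursor, statement, parameters, context, executemany):
        statements.append(statement)

    return statements


# Hypothetical test usage:
# statements = count_statements(async_engine)
# await build_node_sql(...)  # the call under test, with whatever arguments the test needs
# assert len(statements) <= 5  # eager loading should keep round trips roughly constant

As a side note on the timing block, time.perf_counter() is generally preferred over time.time() for measuring elapsed durations, since it is monotonic and unaffected by wall-clock adjustments.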