From 0734396983dd424eb93e527036814f43750eaf04 Mon Sep 17 00:00:00 2001
From: Yian Shang
Date: Thu, 9 May 2024 22:00:27 -0700
Subject: [PATCH] Support selecting wildcards in node queries

---
 .../internal/namespaces.py                    |  27 +--
 .../datajunction_server/internal/nodes.py     |  98 ++++++-----
 .../internal/validation.py                    |  12 +-
 .../datajunction_server/sql/parsing/ast.py    |  87 +++++++++-
 .../sql/parsing/backends/antlr4.py            |   2 +-
 .../tests/api/namespaces_test.py              |  45 +----
 .../tests/api/node_update_test.py             | 117 ++++++++++++-
 .../tests/construction/compile_test.py        | 162 ++++++++++++++++++
 8 files changed, 440 insertions(+), 110 deletions(-)

diff --git a/datajunction-server/datajunction_server/internal/namespaces.py b/datajunction-server/datajunction_server/internal/namespaces.py
index 60f48325a..454755d1e 100644
--- a/datajunction-server/datajunction_server/internal/namespaces.py
+++ b/datajunction-server/datajunction_server/internal/namespaces.py
@@ -8,7 +8,7 @@
 
 from sqlalchemy import or_, select
 from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import joinedload
+from sqlalchemy.orm import joinedload, selectinload
 
 from datajunction_server.api.helpers import get_node_namespace
 from datajunction_server.database.history import ActivityType, EntityType, History
@@ -26,6 +26,7 @@
 )
 from datajunction_server.models.node import NodeMinimumDetail
 from datajunction_server.models.node_type import NodeType
+from datajunction_server.sql.dag import topological_sort
 from datajunction_server.typing import UTCDatetime
 from datajunction_server.utils import SEPARATOR
 
@@ -229,10 +230,10 @@ async def hard_delete_namespace(
     """
     Hard delete a node namespace.
     """
-    node_names = (
+    nodes = (
         (
             await session.execute(
-                select(Node.name)
+                select(Node)
                 .where(
                     or_(
                         Node.namespace.like(
@@ -241,27 +242,33 @@ async def hard_delete_namespace(
                         Node.namespace == namespace,
                     ),
                 )
-                .order_by(Node.name),
+                .order_by(Node.name)
+                .options(
+                    joinedload(Node.current).options(
+                        selectinload(NodeRevision.parents),
+                    ),
+                ),
            )
        )
+        .unique()
         .scalars()
         .all()
     )
-    if not cascade and node_names:
+    if not cascade and nodes:
         raise DJActionNotAllowedException(
             message=(
                 f"Cannot hard delete namespace `{namespace}` as there are still the "
-                f"following nodes under it: `{node_names}`. Set `cascade` to true to "
-                "additionally hard delete the above nodes in this namespace. WARNING:"
+                f"following nodes under it: `{[node.name for node in nodes]}`. Set `cascade` to "
+                "true to additionally hard delete the above nodes in this namespace. WARNING:"
                 " this action cannot be undone."
             ),
         )
 
     impacts = {}
-    for node_name in node_names:
-        impacts[node_name] = await hard_delete_node(
-            node_name,
+    for node in reversed(topological_sort(nodes)):
+        impacts[node.name] = await hard_delete_node(
+            node.name,
             session,
             current_user=current_user,
         )
diff --git a/datajunction-server/datajunction_server/internal/nodes.py b/datajunction-server/datajunction_server/internal/nodes.py
index aa5de5f51..e79f470bf 100644
--- a/datajunction-server/datajunction_server/internal/nodes.py
+++ b/datajunction-server/datajunction_server/internal/nodes.py
@@ -868,7 +868,7 @@ async def propagate_update_downstream(  # pylint: disable=too-many-locals
     )
 
     # The downstreams need to be sorted topologically in order for the updates to be done
-    # in the right order. Otherwise it is possible for a leaf node like a metric to be updated
+    # in the right order. Otherwise, it is possible for a leaf node like a metric to be updated
     # before its upstreams are updated.
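# Editorial sketch, not part of this patch: a minimal Kahn-style topological
# sort of the kind `topological_sort` from datajunction_server.sql.dag provides.
# It assumes each node's current revision exposes `.parents`, which is exactly
# what the selectinload(NodeRevision.parents) in hard_delete_namespace above
# loads. Upstreams sort first, which is why hard deletes iterate the reversed
# order so that leaves (e.g. metrics) are deleted before their upstreams.
from collections import deque
from typing import Deque, Dict, List

def topological_sort_sketch(nodes: List["Node"]) -> List["Node"]:
    names = {node.name for node in nodes}
    by_name = {node.name: node for node in nodes}
    # Count only the parents that are themselves part of the set being sorted
    in_degree: Dict[str, int] = {
        node.name: sum(parent.name in names for parent in node.current.parents)
        for node in nodes
    }
    children: Dict[str, List[str]] = {name: [] for name in names}
    for node in nodes:
        for parent in node.current.parents:
            if parent.name in names:
                children[parent.name].append(node.name)
    queue: Deque[str] = deque(sorted(n for n, d in in_degree.items() if d == 0))
    ordered: List["Node"] = []
    while queue:
        name = queue.popleft()
        ordered.append(by_name[name])
        for child in children[name]:
            in_degree[child] -= 1
            if in_degree[child] == 0:
                queue.append(child)
    return ordered  # upstream-first; callers reverse it to delete leaves first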
     for downstream in downstreams:
         original_node_revision = downstream.current
@@ -1866,14 +1866,13 @@ async def revalidate_node(  # pylint: disable=too-many-locals,too-many-statement
 
     # Check if any columns have been updated
     existing_columns = {col.name: col for col in node.current.columns}  # type: ignore
-    updated_columns = False
+    updated_columns = len(current_node_revision.columns) != len(node_validator.columns)
     for col in node_validator.columns:
         if existing_col := existing_columns.get(col.name):
             if existing_col.type != col.type:
                 existing_col.type = col.type
                 updated_columns = True
         else:
-            node.current.columns.append(col)  # type: ignore  # pragma: no cover
             updated_columns = True  # pragma: no cover
 
     # Only create a new revision if the columns have been updated
@@ -1893,16 +1892,16 @@ async def revalidate_node(  # pylint: disable=too-many-locals,too-many-statement
         node_validator.updated_columns = node_validator.modified_columns(
             new_revision,  # type: ignore
         )
-        new_revision.columns = node_validator.columns
 
         # Save the new revision of the child
+        new_revision.columns = node_validator.columns
         node.current_version = new_revision.version  # type: ignore
         new_revision.node_id = node.id  # type: ignore
-        session.add(node)
         session.add(new_revision)
-        await session.commit()
-        await session.refresh(node.current)  # type: ignore
-        await session.refresh(node, ["current"])
+        session.add(node)
+        await session.commit()
+        await session.refresh(node.current)  # type: ignore
+        await session.refresh(node, ["current"])
 
     return node_validator
 
@@ -1918,7 +1917,14 @@ async def hard_delete_node(
     node = await Node.get_by_name(
         session,
         name,
-        options=[joinedload(Node.current), joinedload(Node.revisions)],
+        options=[
+            joinedload(Node.current),
+            joinedload(Node.revisions).options(
+                selectinload(NodeRevision.columns).options(
+                    joinedload(Column.attributes),
+                ),
+            ),
+        ],
         include_inactive=True,
         raise_if_not_exists=False,
     )
@@ -1946,42 +1952,49 @@ async def hard_delete_node(
             user=current_user.username if current_user else None,
         ),
     )
-        node_validator = await revalidate_node(
-            name=node.name,
-            session=session,
-            current_user=current_user,
-        )
-        impact.append(
-            {
-                "name": node.name,
-                "status": node_validator.status,
-                "effect": "downstream node is now invalid",
-            },
-        )
+        try:
+            node_validator = await revalidate_node(
+                name=node.name,
+                session=session,
+                current_user=current_user,
+            )
+            impact.append(
+                {
+                    "name": node.name,
+                    "status": node_validator.status,
+                    "effect": "downstream node is now invalid",
+                },
+            )
+        except DJNodeNotFound:
+            _logger.warning("Node not found %s", node.name)
 
     # Revalidate all linked nodes
     for node in linked_nodes:
-        session.add(  # Capture this in the downstream node's history
-            History(
-                entity_type=EntityType.LINK,
-                entity_name=name,
-                node=node.name,
-                activity_type=ActivityType.DELETE,
-                user=current_user.username if current_user else None,
-            ),
-        )
-        node_validator = await revalidate_node(
-            name=node.name,
-            session=session,
-            current_user=current_user,
-        )
-        impact.append(
-            {
-                "name": node.name,
-                "status": node_validator.status,
-                "effect": "broken link",
-            },
-        )
+        if node:
+            session.add(  # Capture this in the downstream node's history
+                History(
+                    entity_type=EntityType.LINK,
+                    entity_name=name,
+                    node=node.name,
+                    activity_type=ActivityType.DELETE,
+                    user=current_user.username if current_user else None,
+                ),
+            )
+            try:
+                node_validator = await revalidate_node(
+                    name=node.name,
+                    session=session,
+                    current_user=current_user,
+                )
+                impact.append(
+                    {
+                        "name": node.name,
+                        "status": node_validator.status,
+                        "effect": "broken link",
+                    },
+                )
+            except DJNodeNotFound:
+                _logger.warning("Node not found %s", node.name)
     session.add(  # Capture this in the downstream node's history
         History(
             entity_type=EntityType.NODE,
diff --git a/datajunction-server/datajunction_server/internal/validation.py b/datajunction-server/datajunction_server/internal/validation.py
index 892c02a62..f773b7eed 100644
--- a/datajunction-server/datajunction_server/internal/validation.py
+++ b/datajunction-server/datajunction_server/internal/validation.py
@@ -1,12 +1,12 @@
 """Node validation functions."""
 from dataclasses import dataclass, field
-from typing import Dict, List, Set, Union
+from typing import Dict, List, Optional, Set, Union
 
 from sqlalchemy.exc import MissingGreenlet
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from datajunction_server.api.helpers import find_bound_dimensions
-from datajunction_server.database import Column, Node, NodeRevision
+from datajunction_server.database import Column, ColumnAttribute, Node, NodeRevision
 from datajunction_server.errors import DJError, DJException, ErrorCode
 from datajunction_server.models.base import labelize
 from datajunction_server.models.node import NodeRevisionBase, NodeStatus
@@ -22,6 +22,7 @@ class NodeValidator:  # pylint: disable=too-many-instance-attributes
     Node validation
     """
 
+    query_ast: Optional[ast.Query] = None
     status: NodeStatus = NodeStatus.VALID
     columns: List[Column] = field(default_factory=list)
     required_dimensions: List[Column] = field(default_factory=list)
@@ -128,7 +129,12 @@ async def validate_node_data(  # pylint: disable=too-many-locals,too-many-statem
             name=column_name,
             display_name=labelize(column_name),
             type=column_type,
-            attributes=existing_column.attributes if existing_column else [],
+            attributes=[
+                ColumnAttribute(attribute_type=col_attr.attribute_type)
+                for col_attr in existing_column.attributes
+            ]
+            if existing_column
+            else [],
             dimension=existing_column.dimension if existing_column else None,
             order=idx,
         )
diff --git a/datajunction-server/datajunction_server/sql/parsing/ast.py b/datajunction-server/datajunction_server/sql/parsing/ast.py
index 4f064e2c5..ec7fc0dbe 100644
--- a/datajunction-server/datajunction_server/sql/parsing/ast.py
+++ b/datajunction-server/datajunction_server/sql/parsing/ast.py
@@ -38,6 +38,7 @@
 from datajunction_server.models.column import SemanticType
 from datajunction_server.models.node import BuildCriteria
 from datajunction_server.models.node_type import NodeType as DJNodeType
+from datajunction_server.naming import SEPARATOR
 from datajunction_server.sql.functions import function_registry, table_function_registry
 from datajunction_server.sql.parsing.backends.exceptions import DJParseException
 from datajunction_server.sql.parsing.types import (
@@ -1102,7 +1103,6 @@ class Wildcard(Named, Expression):
     Wildcard or '*' expression
     """
 
-    name: Name = field(init=False, repr=False, default=Name("*"))
     _table: Optional["Table"] = field(repr=False, default=None)
 
     @property
@@ -1121,12 +1121,65 @@ def add_table(self, table: "Table") -> "Wildcard":
         return self
 
     def __str__(self) -> str:
-        return "*"
+        return (self.namespace[0].name + SEPARATOR if self.namespace else "") + "*"
 
     @property
     def type(self) -> ColumnType:
         return WildcardType()
 
+    async def compile(self, ctx: CompileContext):
+        """
+        Compile a Wildcard AST node. If the wildcard is used in a SELECT statement, we
+        replace it with the equivalent of explicitly selecting all upstream columns.
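+
+        For example (as exercised in compile_test.py in this patch), compiling
+
+            SELECT * FROM dbt.source.jaffle_shop.orders
+
+        rewrites the projection to the upstream node's explicit columns:
+
+            SELECT dbt.source.jaffle_shop.orders.id,
+                dbt.source.jaffle_shop.orders.user_id,
+                ...
+            FROM dbt.source.jaffle_shop.orders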
+        """
+        if not isinstance(self.parent, Select):
+            return await super().compile(ctx)
+
+        wildcard_parent = cast(Select, self.parent)
+
+        # If a wildcard is used in a SELECT statement, we pull the columns that it
+        # represents by scanning all relations in the FROM clause (both the primary
+        # tables and any joined tables)
+        wildcard_table_namespace = (
+            self.namespace[0].name
+            if self.namespace and isinstance(self.namespace, list)
+            else None
+        )
+        from_relations = [
+            expr for relation in wildcard_parent.from_.relations
+            for expr in (relation.primary, *(ext.right for ext in relation.extensions))
+        ]
+
+        for relation in from_relations:
+            await relation.compile(ctx)
+            if (
+                not wildcard_table_namespace
+                or relation.alias_or_name.name == wildcard_table_namespace
+            ):
+                # Figure out where the relation's columns are stored depending on the relation type
+                if isinstance(relation, Table):
+                    wildcard_origin = cast(Table, relation)
+                    if wildcard_origin_node := wildcard_origin.dj_node:
+                        relation_columns = wildcard_origin_node.columns
+                    else:
+                        relation_columns = wildcard_origin._cte_columns
+                else:
+                    wildcard_origin = cast(Query, relation)
+                    relation_columns = wildcard_origin.select.projection
+
+                # Use these columns to replace the wildcard
+                for col in relation_columns:
+                    wildcard_parent.projection.append(
+                        Column(
+                            name=Name(col.name)
+                            if isinstance(col.name, str)
+                            else col.name,
+                            _table=wildcard_origin,
+                            _type=col.type,
+                        ),
+                    )
+        wildcard_parent.projection.remove(self)
+
 
 @dataclass(eq=False)
 class TableExpression(Aliasable, Expression):
@@ -1140,6 +1193,7 @@ class TableExpression(Aliasable, Expression):
     )  # all those expressions that can be had from the table; usually derived from dj node metadata for Table
     # ref (referenced) columns are columns used elsewhere from this table
     _ref_columns: List[Column] = field(init=False, repr=False, default_factory=list)
+    _cte_columns: List[Expression] = field(default_factory=list)
 
     @property
     def columns(self) -> List[Expression]:
@@ -1345,8 +1399,12 @@ def set_alias(self: TNode, alias: "Name") -> TNode:
         return self
 
     async def compile(self, ctx: CompileContext):
-        # things we can validate here:
-        # - if the node is a dimension in a groupby, is it joinable?
+        """
+        Compile a Table AST node by finding and saving the columns it references
+        """
+        if self._is_compiled:
+            return
+        self._is_compiled = True
 
         try:
             if not self.dj_node:
                 dj_node = await get_dj_node(
                     ctx.session,
                     self.identifier(),
                     {DJNodeType.SOURCE, DJNodeType.TRANSFORM, DJNodeType.DIMENSION},
                 )
                 self.set_dj_node(dj_node)
+        except DJErrorException as exc:
+            ctx.exception.errors.append(exc.dj_error)
+
+        if self.dj_node:
+            # If the Table object is a reference to a DJ node, save the columns of the
+            # DJ node into self._columns for later use
             self._columns = [
                 Column(Name(col.name), _type=col.type, _table=self)
                 for col in self.dj_node.columns
             ]
-        except DJErrorException as exc:
-            ctx.exception.errors.append(exc.dj_error)
+        elif query := self.get_nearest_parent_of_type(Query):
+            # If the Table object is a reference to a CTE, save the columns output by
+            # the CTE into self._cte_columns for later use
+            for cte in query.ctes:
+                if self.alias_or_name.name == cte.alias_or_name.name:
+                    await cte.compile(ctx)
+                    self._cte_columns = [
+                        Column(col.alias_or_name, _type=col.type, _table=self)
+                        for col in cte._columns
+                    ]
 
 
 class Operation(Expression):
@@ -2565,6 +2637,9 @@ async def compile(self, ctx: CompileContext):
             ),
         )
         await super().compile(ctx)
+        for child in list(self.projection):
+            if isinstance(child, Wildcard):
+                await child.compile(ctx)
 
 
 @dataclass(eq=False)
diff --git a/datajunction-server/datajunction_server/sql/parsing/backends/antlr4.py b/datajunction-server/datajunction_server/sql/parsing/backends/antlr4.py
index c88419c6e..c065c0501 100644
--- a/datajunction-server/datajunction_server/sql/parsing/backends/antlr4.py
+++ b/datajunction-server/datajunction_server/sql/parsing/backends/antlr4.py
@@ -703,7 +703,7 @@ def _(ctx: sbp.StarContext):
     namespace = None
     if qual_name := ctx.qualifiedName():
         namespace = visit(qual_name)
-    star = ast.Wildcard()
+    star = ast.Wildcard(name=ast.Name("*"))
     star.name.namespace = namespace
     return star
diff --git a/datajunction-server/tests/api/namespaces_test.py b/datajunction-server/tests/api/namespaces_test.py
index ec3e33e3a..d1addab42 100644
--- a/datajunction-server/tests/api/namespaces_test.py
+++ b/datajunction-server/tests/api/namespaces_test.py
@@ -446,28 +446,10 @@ async def test_hard_delete_namespace(client_with_examples: AsyncClient):
                 "status": "valid",
             },
         ],
-        "foo.bar.hard_hat_state": [
-            {
-                "effect": "downstream node is now " "invalid",
-                "name": "foo.bar.local_hard_hats",
-                "status": "invalid",
-            },
-        ],
-        "foo.bar.hard_hats": [
-            {
-                "effect": "downstream node is now invalid",
-                "name": "foo.bar.local_hard_hats",
-                "status": "invalid",
-            },
-        ],
+        "foo.bar.hard_hat_state": [],
+        "foo.bar.hard_hats": [],
         "foo.bar.local_hard_hats": [],
-        "foo.bar.municipality": [
-            {
-                "effect": "downstream node is now " "invalid",
-                "name": "foo.bar.municipality_dim",
-                "status": "invalid",
-            },
-        ],
+        "foo.bar.municipality": [],
         "foo.bar.municipality_dim": [
             {
                 "effect": "broken link",
@@ -520,29 +502,12 @@ async def test_hard_delete_namespace(client_with_examples: AsyncClient):
                 "status": "valid",
             },
         ],
-        "foo.bar.repair_order_details": [
-            {
-                "effect": "downstream node is " "now invalid",
-                "name": "foo.bar.total_repair_cost",
-                "status": "invalid",
-            },
-            {
-                "effect": "downstream node is " "now invalid",
-                "name": "foo.bar.total_repair_order_discounts",
-                "status": "invalid",
-            },
-        ],
+        "foo.bar.repair_order_details": [],
         "foo.bar.repair_orders": [],
         "foo.bar.repair_type": [],
         "foo.bar.total_repair_cost": [],
"foo.bar.total_repair_order_discounts": [], - "foo.bar.us_region": [ - { - "effect": "downstream node is now invalid", - "name": "foo.bar.us_state", - "status": "invalid", - }, - ], + "foo.bar.us_region": [], "foo.bar.us_state": [], "foo.bar.us_states": [], }, diff --git a/datajunction-server/tests/api/node_update_test.py b/datajunction-server/tests/api/node_update_test.py index db0c72d2a..eaf054237 100644 --- a/datajunction-server/tests/api/node_update_test.py +++ b/datajunction-server/tests/api/node_update_test.py @@ -48,7 +48,7 @@ async def test_update_source_node( "activity_type": "update", "created_at": mock.ANY, "details": { - "changes": {"updated_columns": []}, + "changes": {"updated_columns": ["total_amount_nationwide"]}, "reason": "Caused by update of `default.repair_order_details` to " "v2.0", "upstream": { @@ -60,7 +60,7 @@ async def test_update_source_node( "entity_type": "node", "id": mock.ANY, "node": "default.national_level_agg", - "post": {"status": "invalid", "version": "v1.0"}, + "post": {"status": "invalid", "version": "v2.0"}, "pre": {"status": "valid", "version": "v1.0"}, "user": mock.ANY, }, @@ -71,7 +71,10 @@ async def test_update_source_node( "created_at": mock.ANY, "details": { "changes": { - "updated_columns": [], + "updated_columns": [ + "avg_repair_amount_in_region", + "total_amount_in_region", + ], }, "reason": "Caused by update of `default.repair_order_details` to v2.0", "upstream": { @@ -83,7 +86,7 @@ async def test_update_source_node( "entity_type": "node", "id": mock.ANY, "node": "default.regional_level_agg", - "post": {"status": "invalid", "version": "v1.0"}, + "post": {"status": "invalid", "version": "v2.0"}, "pre": {"status": "valid", "version": "v1.0"}, "user": mock.ANY, }, @@ -93,7 +96,7 @@ async def test_update_source_node( "activity_type": "update", "created_at": mock.ANY, "details": { - "changes": {"updated_columns": []}, + "changes": {"updated_columns": ["default_DOT_avg_repair_price"]}, "reason": "Caused by update of `default.repair_order_details` to " "v2.0", "upstream": { @@ -105,7 +108,7 @@ async def test_update_source_node( "entity_type": "node", "id": mock.ANY, "node": "default.avg_repair_price", - "post": {"status": "invalid", "version": "v1.0"}, + "post": {"status": "invalid", "version": "v2.0"}, "pre": {"status": "valid", "version": "v1.0"}, "user": mock.ANY, }, @@ -116,7 +119,7 @@ async def test_update_source_node( "created_at": mock.ANY, "details": { "changes": { - "updated_columns": [], + "updated_columns": ["default_DOT_regional_repair_efficiency"], }, "reason": "Caused by update of `default.repair_order_details` to " "v2.0", @@ -129,7 +132,7 @@ async def test_update_source_node( "entity_type": "node", "id": mock.ANY, "node": "default.regional_repair_efficiency", - "post": {"status": "invalid", "version": "v1.0"}, + "post": {"status": "invalid", "version": "v2.0"}, "pre": {"status": "valid", "version": "v1.0"}, "user": mock.ANY, }, @@ -171,3 +174,101 @@ async def test_update_source_node( "type": "double", }, ] + + +@pytest.mark.asyncio +async def test_crud_with_wildcards( + client_with_roads: AsyncClient, +): + """ + Test creating and updating nodes with SELECT * in their queries + """ + response = await client_with_roads.get("/nodes/default.hard_hat") + old_columns = response.json()["columns"] + + # Using a wildcard should result in a node revision with the same columns + await client_with_roads.patch( + "/nodes/default.hard_hat", + json={"query": "SELECT * FROM default.hard_hats"}, + ) + response = await 
+    data = response.json()
+    assert data["columns"] == old_columns
+    assert data["version"] == "v2.0"
+
+    # Create a transform with wildcards
+    response = await client_with_roads.post(
+        "/nodes/transform",
+        json={
+            "name": "default.ny_hard_hats",
+            "display_name": "NY Hard Hats",
+            "description": "Blah",
+            "query": """SELECT
+              hh.*,
+              hhs.*
+            FROM default.hard_hat hh
+            LEFT JOIN default.hard_hat_state hhs
+            ON hh.hard_hat_id = hhs.hard_hat_id
+            WHERE hh.state_id = 'NY'""",
+            "mode": "published",
+        },
+    )
+    data = response.json()
+    joined_columns = [
+        "hard_hat_id",
+        "last_name",
+        "first_name",
+        "title",
+        "birth_date",
+        "hire_date",
+        "address",
+        "city",
+        "state",
+        "postal_code",
+        "country",
+        "manager",
+        "contractor_id",
+        "hard_hat_id",
+        "state_id",
+    ]
+    assert [col["name"] for col in data["columns"]] == joined_columns
+
+    # Create a transform based on the earlier transform
+    response = await client_with_roads.post(
+        "/nodes/transform",
+        json={
+            "name": "default.ny_hard_hats_2",
+            "display_name": "NY Hard Hats 2",
+            "description": "Blah",
+            "query": "SELECT * FROM default.ny_hard_hats",
+            "mode": "published",
+        },
+    )
+    data = response.json()
+    assert [col["name"] for col in data["columns"]] == joined_columns
+
+    # Update original hard_hat, which should trigger cascading updates of the children
+    await client_with_roads.patch(
+        "/nodes/default.hard_hat",
+        json={
+            "query": "SELECT last_name, first_name, birth_date FROM default.hard_hats",
+        },
+    )
+    response = await client_with_roads.get("/nodes/default.ny_hard_hats")
+    data = response.json()
+    assert [col["name"] for col in data["columns"]] == [
+        "last_name",
+        "first_name",
+        "birth_date",
+        "hard_hat_id",
+        "state_id",
+    ]
+    response = await client_with_roads.get("/nodes/default.ny_hard_hats_2")
+    data = response.json()
+    assert [col["name"] for col in data["columns"]] == [
+        "last_name",
+        "first_name",
+        "birth_date",
+        "hard_hat_id",
+        "state_id",
+    ]
diff --git a/datajunction-server/tests/construction/compile_test.py b/datajunction-server/tests/construction/compile_test.py
index b7346cec6..a9a85cfc8 100644
--- a/datajunction-server/tests/construction/compile_test.py
+++ b/datajunction-server/tests/construction/compile_test.py
@@ -218,3 +218,165 @@ async def test_having(construction_session: AsyncSession):
     query_ast = parse(node_a_rev.query)
     ctx = CompileContext(session=construction_session, exception=DJException())
     await query_ast.compile(ctx)
+
+
+@pytest.mark.asyncio
+async def test_wildcard_handling(
+    construction_session: AsyncSession,
+):
+    """
+    Test that it handles wildcards by replacing them with the explicit columns
+    that the wildcard references on the parent node.
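+
+    Covers unqualified wildcards, wildcards in subqueries, qualified wildcards
+    such as ``a.*``, and wildcards inside CTEs.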
+    """
+    query = parse("SELECT * FROM dbt.source.jaffle_shop.orders")
+    context = CompileContext(
+        session=construction_session,
+        exception=DJException(),
+    )
+    await query.compile(context)
+    assert str(query) == str(
+        parse(
+            "SELECT dbt.source.jaffle_shop.orders.id,"
+            "dbt.source.jaffle_shop.orders.user_id,"
+            "dbt.source.jaffle_shop.orders.order_date,"
+            "dbt.source.jaffle_shop.orders.status,"
+            "dbt.source.jaffle_shop.orders._etl_loaded_at "
+            "FROM dbt.source.jaffle_shop.orders",
+        ),
+    )
+
+    expected_nested_select_star = """SELECT
+      country,
+      num_users
+    FROM (
+      SELECT
+        basic.transform.country_agg.country,
+        basic.transform.country_agg.num_users
+      FROM basic.transform.country_agg
+    )"""
+    query = parse(
+        "SELECT * FROM (SELECT * FROM basic.transform.country_agg)",
+    )
+
+    context = CompileContext(
+        session=construction_session,
+        exception=DJException(),
+    )
+    await query.compile(context)
+    assert str(query) == str(parse(expected_nested_select_star))
+
+    query = parse(
+        "SELECT country, num_users FROM (SELECT * FROM basic.transform.country_agg)",
+    )
+
+    context = CompileContext(
+        session=construction_session,
+        exception=DJException(),
+    )
+    await query.compile(context)
+    assert str(query) == str(parse(expected_nested_select_star))
+
+    query = parse(
+        "SELECT * FROM (SELECT country, num_users FROM basic.transform.country_agg)",
+    )
+
+    context = CompileContext(
+        session=construction_session,
+        exception=DJException(),
+    )
+    await query.compile(context)
+    assert str(query) == str(parse(expected_nested_select_star))
+
+    query = parse(
+        "SELECT *, blah FROM "
+        "(SELECT country, num_users FROM basic.transform.country_agg)",
+    )
+
+    context = CompileContext(
+        session=construction_session,
+        exception=DJException(),
+    )
+    await query.compile(context)
+    with pytest.raises(DJParseException) as exc_info:
+        query.select.projection[  # type: ignore  # pylint: disable=pointless-statement
+            0
+        ].type
+    assert "Cannot resolve type of column blah" in str(exc_info.value)
+
+    query = parse(
+        "WITH a AS (SELECT * FROM basic.transform.country_agg),"
+        "b AS (SELECT * FROM dbt.source.jaffle_shop.orders)"
+        "SELECT * FROM a JOIN b ON a.country = b.user_id",
+    )
+
+    context = CompileContext(
+        session=construction_session,
+        exception=DJException(),
+    )
+    await query.compile(context)
+    assert str(query) == str(
+        parse(
+            """
+        WITH a AS (
+            SELECT
+            basic.transform.country_agg.country,
+            basic.transform.country_agg.num_users
+            FROM basic.transform.country_agg
+        ), b AS (
+            SELECT
+            dbt.source.jaffle_shop.orders.id,
+            dbt.source.jaffle_shop.orders.user_id,
+            dbt.source.jaffle_shop.orders.order_date,
+            dbt.source.jaffle_shop.orders.status,
+            dbt.source.jaffle_shop.orders._etl_loaded_at
+            FROM dbt.source.jaffle_shop.orders
+        )
+        SELECT
+            a.country,
+            a.num_users,
+            b.id,
+            b.user_id,
+            b.order_date,
+            b.status,
+            b._etl_loaded_at
+        FROM a JOIN b ON a.country = b.user_id
+        """,
+        ),
+    )
+
+    query = parse(
+        "WITH a AS (SELECT * FROM basic.transform.country_agg),"
+        "b AS (SELECT * FROM dbt.source.jaffle_shop.orders)"
+        "SELECT a.*, b.* FROM a JOIN b ON a.country = b.user_id",
+    )
+
+    await query.compile(context)
+    assert str(query) == str(
+        parse(
+            """
+        WITH a AS (
+            SELECT
+            basic.transform.country_agg.country,
+            basic.transform.country_agg.num_users
+            FROM basic.transform.country_agg
+        ), b AS (
+            SELECT
+            dbt.source.jaffle_shop.orders.id,
+            dbt.source.jaffle_shop.orders.user_id,
+            dbt.source.jaffle_shop.orders.order_date,
+            dbt.source.jaffle_shop.orders.status,
+            dbt.source.jaffle_shop.orders._etl_loaded_at
+            FROM dbt.source.jaffle_shop.orders
+        )
+        SELECT
+            a.country,
+            a.num_users,
+            b.id,
+            b.user_id,
+            b.order_date,
+            b.status,
+            b._etl_loaded_at
+        FROM a JOIN b ON a.country = b.user_id
+        """,
+        ),
+    )
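# ---------------------------------------------------------------------------
# Editorial usage sketch, not part of the patch: creating a transform whose
# query uses wildcards, mirroring test_crud_with_wildcards above. The base URL
# and the presence of the roads example nodes are assumptions, not something
# this patch guarantees.
# ---------------------------------------------------------------------------
import asyncio

import httpx


async def main() -> None:
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        response = await client.post(
            "/nodes/transform",
            json={
                "name": "default.ny_hard_hats",
                "display_name": "NY Hard Hats",
                "description": "Hard hats in NY, selected via wildcards",
                "query": (
                    "SELECT hh.*, hhs.* "
                    "FROM default.hard_hat hh "
                    "LEFT JOIN default.hard_hat_state hhs "
                    "ON hh.hard_hat_id = hhs.hard_hat_id "
                    "WHERE hh.state_id = 'NY'"
                ),
                "mode": "published",
            },
        )
        # The server expands hh.* and hhs.* into the explicit columns of both
        # upstream nodes, so the new node's column list matches their union.
        print([col["name"] for col in response.json()["columns"]])


asyncio.run(main())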