diff --git a/.gitignore b/.gitignore
index 0725a26b2..3912f818d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -20,3 +20,4 @@ logs/
 *.sublime*
 .python-version
 .hatch
+venv/
diff --git a/CHANGELOG.md b/CHANGELOG.md
index dfce1d4e6..a746b439f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,7 +1,6 @@
 ## dbt-databricks 1.9.0 (TBD)
 
 ### Features
 
-- Add config for generating unique tmp table names for enabling parralel merge (thanks @huangxingyi-git!) ([854](https://github.com/databricks/dbt-databricks/pull/854))
 - Add support for serverless job clusters on python models ([706](https://github.com/databricks/dbt-databricks/pull/706))
 - Add 'user_folder_for_python' behavior to switch writing python model notebooks to the user's folder ([835](https://github.com/databricks/dbt-databricks/pull/835))
@@ -18,6 +17,7 @@
 - Add a new `workflow_job` submission method for python, which creates a long-lived Databricks Workflow instead of a one-time run (thanks @kdazzle!) ([762](https://github.com/databricks/dbt-databricks/pull/762))
 - Allow for additional options to be passed to the Databricks Job API when using other python submission methods. For example, enable email_notifications (thanks @kdazzle!) ([762](https://github.com/databricks/dbt-databricks/pull/762))
 - Support microbatch incremental strategy using replace_where ([825](https://github.com/databricks/dbt-databricks/pull/825))
+- Support pyspark session connection ([862](https://github.com/databricks/dbt-databricks/pull/862))
 
 ### Fixes
diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py
index 49b1ba82b..5e2465c9f 100644
--- a/dbt/adapters/databricks/connections.py
+++ b/dbt/adapters/databricks/connections.py
@@ -91,6 +91,8 @@
 # toggle for session managements that minimizes the number of sessions opened/closed
 USE_LONG_SESSIONS = os.getenv("DBT_DATABRICKS_LONG_SESSIONS", "True").upper() == "TRUE"
+# toggle for session managements that assumes the adapter is running in a Databricks session
+USE_SESSION_CONNECTION = os.getenv("DBT_DATABRICKS_SESSION_CONNECTION", "False").upper() == "TRUE"
 
 # Number of idle seconds before a connection is automatically closed. Only applicable if
 # USE_LONG_SESSIONS is true.
@@ -1079,6 +1081,63 @@ def exponential_backoff(attempt: int) -> int:
         )
 
 
+class DatabricksSessionConnectionManager(DatabricksConnectionManager):
+    def cancel_open(self) -> list[str]:
+        return SparkConnectionManager.cancel_open(self)
+
+    def compare_dbr_version(self, major: int, minor: int) -> int:
+        version = (major, minor)
+        connection = self.get_thread_connection().handle
+        dbr_version = connection.dbr_version
+        return (dbr_version > version) - (dbr_version < version)
+
+    def set_query_header(self, query_header_context: dict[str, Any]) -> None:
+        SparkConnectionManager.set_query_header(self, query_header_context)
+
+    def set_connection_name(
+        self, name: Optional[str] = None, query_header_context: Any = None
+    ) -> Connection:
+        return SparkConnectionManager.set_connection_name(self, name)
+
+    def add_query(
+        self,
+        sql: str,
+        auto_begin: bool = True,
+        bindings: Optional[Any] = None,
+        abridge_sql_log: bool = False,
+        *,
+        close_cursor: bool = False,
+    ) -> tuple[Connection, Any]:
+        return SparkConnectionManager.add_query(self, sql, auto_begin, bindings, abridge_sql_log)
+
+    def list_schemas(self, database: str, schema: Optional[str] = None) -> "Table":
+        raise NotImplementedError(
+            "list_schemas is not implemented for DatabricksSessionConnectionManager - "
+            + "should call the list_schemas macro instead"
+        )
+
+    def list_tables(self, database: str, schema: str, identifier: Optional[str] = None) -> "Table":
+        raise NotImplementedError(
+            "list_tables is not implemented for DatabricksSessionConnectionManager - "
+            + "should call the list_tables macro instead"
+        )
+
+    @classmethod
+    def open(cls, connection: Connection) -> Connection:
+        from dbt.adapters.databricks.session_connection import DatabricksSessionConnectionWrapper
+        from dbt.adapters.spark.session import Connection
+
+        handle = DatabricksSessionConnectionWrapper(Connection())
+        connection.handle = handle
+        connection.state = ConnectionState.OPEN
+        return connection
+
+    @classmethod
+    def get_response(cls, cursor: Any) -> DatabricksAdapterResponse:
+        response = SparkConnectionManager.get_response(cursor)
+        return DatabricksAdapterResponse(_message=response._message)
+
+
 def _get_pipeline_state(session: Session, host: str, pipeline_id: str) -> dict:
     pipeline_url = f"https://{host}/api/2.0/pipelines/{pipeline_id}"
diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py
index 3e0288b09..03df68920 100644
--- a/dbt/adapters/databricks/impl.py
+++ b/dbt/adapters/databricks/impl.py
@@ -33,9 +33,11 @@
 from dbt.adapters.databricks.column import DatabricksColumn
 from dbt.adapters.databricks.connections import (
     USE_LONG_SESSIONS,
+    USE_SESSION_CONNECTION,
     DatabricksConnectionManager,
     DatabricksDBTConnection,
     DatabricksSQLConnectionWrapper,
+    DatabricksSessionConnectionManager,
     ExtendedSessionConnectionManager,
 )
 from dbt.adapters.databricks.python_models.python_submissions import (
@@ -156,7 +158,9 @@ class DatabricksAdapter(SparkAdapter):
     Relation = DatabricksRelation
     Column = DatabricksColumn
 
-    if USE_LONG_SESSIONS:
+    if USE_SESSION_CONNECTION:
+        ConnectionManager: type[DatabricksConnectionManager] = DatabricksSessionConnectionManager
+    elif USE_LONG_SESSIONS:
         ConnectionManager: type[DatabricksConnectionManager] = ExtendedSessionConnectionManager
     else:
         ConnectionManager = DatabricksConnectionManager
@@ -272,7 +276,7 @@ def list_schemas(self, database: Optional[str]) -> list[str]:
         If `database` is `None`, fallback to executing `show databases` because `list_schemas`
         tries to collect schemas from all catalogs when `database` is `None`.
         """
-        if database is not None:
+        if database is not None and not USE_SESSION_CONNECTION:
             results = self.connections.list_schemas(database=database)
         else:
             results = self.execute_macro(LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database})
@@ -346,7 +350,7 @@ def _get_hive_relations(
         kwargs = {"relation": relation}
 
         new_rows: list[tuple[str, Optional[str]]]
-        if all([relation.database, relation.schema]):
+        if all([relation.database, relation.schema]) and not USE_SESSION_CONNECTION:
             tables = self.connections.list_tables(
                 database=relation.database,  # type: ignore[arg-type]
                 schema=relation.schema,  # type: ignore[arg-type]
diff --git a/dbt/adapters/databricks/session_connection.py b/dbt/adapters/databricks/session_connection.py
new file mode 100644
index 000000000..9fc998406
--- /dev/null
+++ b/dbt/adapters/databricks/session_connection.py
@@ -0,0 +1,42 @@
+import re
+import sys
+from typing import Tuple
+
+from dbt.adapters.spark.session import Connection
+from dbt.adapters.spark.session import SessionConnectionWrapper
+
+DBR_VERSION_REGEX = re.compile(r"([1-9][0-9]*)\.(x|0|[1-9][0-9]*)")
+
+
+class DatabricksSessionConnectionWrapper(SessionConnectionWrapper):
+    _is_cluster: bool
+    _dbr_version: Tuple[int, int]
+
+    def __init__(self, handle: Connection) -> None:
+        super().__init__(handle)
+        self._is_cluster = True
+        self.cursor()
+
+    @property
+    def dbr_version(self) -> Tuple[int, int]:
+        if not hasattr(self, "_dbr_version"):
+            if self._is_cluster:
+                with self._cursor() as cursor:
+                    cursor.execute("SET spark.databricks.clusterUsageTags.sparkVersion")
+                    results = cursor.fetchone()
+                    if results:
+                        dbr_version: str = results[1]
+
+                        m = DBR_VERSION_REGEX.search(dbr_version)
+                        assert m, f"Unknown DBR version: {dbr_version}"
+                        major = int(m.group(1))
+                        try:
+                            minor = int(m.group(2))
+                        except ValueError:
+                            minor = sys.maxsize
+                        self._dbr_version = (major, minor)
+            else:
+                # Assuming SQL Warehouse uses the latest version.
+                self._dbr_version = (sys.maxsize, sys.maxsize)
+
+        return self._dbr_version
diff --git a/dbt/include/databricks/macros/adapters/metadata.sql b/dbt/include/databricks/macros/adapters/metadata.sql
index fd79a58ce..355638736 100644
--- a/dbt/include/databricks/macros/adapters/metadata.sql
+++ b/dbt/include/databricks/macros/adapters/metadata.sql
@@ -19,9 +19,18 @@
 {% endmacro %}
 
 {% macro databricks__show_tables(relation) %}
-  {% call statement('show_tables', fetch_result=True) -%}
-    show tables in {{ relation|lower }}
-  {% endcall %}
+  {% set database = (relation.database | default('')) | lower | replace('`', '') %}
+  {% set schema = relation.schema | lower | replace('`', '') %}
+
+  {% if database and schema -%}
+    {% call statement('show_tables', fetch_result=True) -%}
+      SHOW TABLES IN {{ database }}.{{ schema }}
+    {% endcall %}
+  {% else -%}
+    {% call statement('show_tables', fetch_result=True) -%}
+      SHOW TABLES IN {{ relation | lower }}
+    {% endcall %}
+  {% endif %}
 
   {% do return(load_result('show_tables').table) %}
 {% endmacro %}
@@ -103,4 +112,23 @@
   {% endcall %}
 
   {% do return(load_result('get_uc_tables').table) %}
-{% endmacro %}
\ No newline at end of file
+{% endmacro %}
+
+{% macro list_schemas(database) %}
+  {{ return(adapter.dispatch('list_schemas', 'dbt')(database)) }}
+{% endmacro %}
+
+{% macro databricks__list_schemas(database) -%}
+  {% set database_clean = (database | default('')) | replace('`', '') %}
+  {% if database_clean -%}
+    {% call statement('list_schemas', fetch_result=True, auto_begin=False) %}
+      SHOW DATABASES IN {{ database_clean }}
+    {% endcall %}
+  {% else -%}
+    {% call statement('list_schemas', fetch_result=True, auto_begin=False) %}
+      SHOW DATABASES
+    {% endcall %}
+  {% endif -%}
+
+  {{ return(load_result('list_schemas').table) }}
+{% endmacro %}
diff --git a/pyproject.toml b/pyproject.toml
index f1f680ead..15efc9e4f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -91,10 +91,29 @@ uc-cluster-e2e = "pytest --color=yes -v --profile databricks_uc_cluster -n auto
 sqlw-e2e = "pytest --color=yes -v --profile databricks_uc_sql_endpoint -n auto --dist=loadscope tests/functional"
 
 [tool.hatch.envs.test.scripts]
-unit = "pytest --color=yes -v --profile databricks_cluster -n auto --dist=loadscope tests/unit"
+unit = """
+if [[ "$DBT_DATABRICKS_SESSION_CONNECTION" == "True" ]]; then
+    PROFILE="session_connection"
+else
+    PROFILE="databricks_cluster"
+fi
+
+pytest --color=yes -v --profile $PROFILE -n auto --dist=loadscope tests/unit
+"""
 
 [[tool.hatch.envs.test.matrix]]
 python = ["3.9", "3.10", "3.11", "3.12"]
+session_support = ["no_session", "session"]
+
+
+[tool.hatch.envs.test.overrides]
+matrix.session_support.env-vars = [
+    { key = "DBT_DATABRICKS_SESSION_CONNECTION", value = "False", if = ["no_session"]},
+    { key = "DBT_DATABRICKS_SESSION_CONNECTION", value = "True", if = ["session"] }
+]
+
+
+
 
 [tool.ruff]
 line-length = 100
diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py
index b26d5f80a..e707dcbb3 100644
--- a/tests/unit/test_adapter.py
+++ b/tests/unit/test_adapter.py
@@ -110,6 +110,7 @@ def test_http_headers(http_header):
         test_http_headers(["a", "b"])
         test_http_headers({"a": 1, "b": 2})
 
+    @pytest.mark.skip_profile("session_connection")
     def test_invalid_custom_user_agent(self):
         with pytest.raises(DbtValidationError) as excinfo:
             config = self._get_config()
@@ -120,6 +121,7 @@ def test_invalid_custom_user_agent(self):
 
         assert "Invalid invocation environment" in str(excinfo.value)
 
+    @pytest.mark.skip_profile("session_connection")
    def test_custom_user_agent(self):
         config = self._get_config()
         adapter = DatabricksAdapter(config, get_context("spawn"))
@@ -134,12 +136,14 @@ def test_custom_user_agent(self):
             connection = adapter.acquire_connection("dummy")
             connection.handle  # trigger lazy-load
 
+    @pytest.mark.skip_profile("session_connection")
     def test_environment_single_http_header(self):
         self._test_environment_http_headers(
             http_headers_str='{"test":{"jobId":1,"runId":12123}}',
             expected_http_headers=[("test", '{"jobId": 1, "runId": 12123}')],
         )
 
+    @pytest.mark.skip_profile("session_connection")
     def test_environment_multiple_http_headers(self):
         self._test_environment_http_headers(
             http_headers_str='{"test":{"jobId":1,"runId":12123},"dummy":{"jobId":1,"runId":12123}}',
@@ -149,6 +153,7 @@ def test_environment_multiple_http_headers(self):
             ],
         )
 
+    @pytest.mark.skip_profile("session_connection")
     def test_environment_users_http_headers_intersection_error(self):
         with pytest.raises(DbtValidationError) as excinfo:
             self._test_environment_http_headers(
@@ -159,6 +164,7 @@ def test_environment_users_http_headers_intersection_error(self):
 
         assert "Intersection with reserved http_headers in keys: {'t'}" in str(excinfo.value)
 
+    @pytest.mark.skip_profile("session_connection")
     def test_environment_users_http_headers_union_success(self):
         self._test_environment_http_headers(
             http_headers_str='{"t":{"jobId":1,"runId":12123},"d":{"jobId":1,"runId":12123}}',
@@ -170,6 +176,7 @@ def test_environment_users_http_headers_union_success(self):
             ],
         )
 
+    @pytest.mark.skip_profile("session_connection")
     def test_environment_http_headers_string(self):
         self._test_environment_http_headers(
             http_headers_str='{"string":"some-string"}',
@@ -269,6 +276,7 @@ def connect(
 
             return connect
 
+    @pytest.mark.skip_profile("session_connection")
     def test_databricks_sql_connector_connection(self):
         self._test_databricks_sql_connector_connection(self._connect_func())
 
@@ -291,6 +299,7 @@ def _test_databricks_sql_connector_connection(self, connect):
             assert len(connection.credentials.session_properties) == 1
             assert connection.credentials.session_properties["spark.sql.ansi.enabled"] == "true"
 
+    @pytest.mark.skip_profile("session_connection")
     def test_databricks_sql_connector_catalog_connection(self):
         self._test_databricks_sql_connector_catalog_connection(
             self._connect_func(expected_catalog="main")
         )
@@ -314,6 +323,7 @@ def _test_databricks_sql_connector_catalog_connection(self, connect):
             assert connection.credentials.schema == "analytics"
             assert connection.credentials.database == "main"
 
+    @pytest.mark.skip_profile("session_connection")
     def test_databricks_sql_connector_http_header_connection(self):
         self._test_databricks_sql_connector_http_header_connection(
             {"aaa": "xxx"}, self._connect_func(expected_http_headers=[("aaa", "xxx")])