1 change: 1 addition & 0 deletions .gitignore
@@ -20,3 +20,4 @@ logs/
*.sublime*
.python-version
.hatch
venv/
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -1,7 +1,6 @@
## dbt-databricks 1.9.0 (TBD)

### Features

- Add config for generating unique tmp table names to enable parallel merge (thanks @huangxingyi-git!) ([854](https://github.com/databricks/dbt-databricks/pull/854))
- Add support for serverless job clusters on python models ([706](https://github.com/databricks/dbt-databricks/pull/706))
- Add 'user_folder_for_python' behavior to switch writing python model notebooks to the user's folder ([835](https://github.com/databricks/dbt-databricks/pull/835))
@@ -18,6 +17,7 @@
- Add a new `workflow_job` submission method for python, which creates a long-lived Databricks Workflow instead of a one-time run (thanks @kdazzle!) ([762](https://github.com/databricks/dbt-databricks/pull/762))
- Allow for additional options to be passed to the Databricks Job API when using other python submission methods. For example, enable email_notifications (thanks @kdazzle!) ([762](https://github.com/databricks/dbt-databricks/pull/762))
- Support microbatch incremental strategy using replace_where ([825](https://github.com/databricks/dbt-databricks/pull/825))
- Support pyspark session connection ([862](https://github.com/databricks/dbt-databricks/pull/862))

### Fixes

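For context on the new pyspark session connection entry, here is a minimal, hypothetical sketch of opting into it from inside a Databricks runtime. Only the `DBT_DATABRICKS_SESSION_CONNECTION` environment variable comes from the `connections.py` change below; the programmatic `dbtRunner` invocation is illustrative, not part of this PR.

```python
# Hypothetical usage sketch: enable the session connection path, then run dbt
# programmatically against the already-active Spark session.
import os

os.environ["DBT_DATABRICKS_SESSION_CONNECTION"] = "True"

from dbt.cli.main import dbtRunner  # programmatic entry point in dbt-core >= 1.5

result = dbtRunner().invoke(["run"])
print(result.success)
```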
59 changes: 59 additions & 0 deletions dbt/adapters/databricks/connections.py
@@ -91,6 +91,8 @@

# toggle for session management that minimizes the number of sessions opened/closed
USE_LONG_SESSIONS = os.getenv("DBT_DATABRICKS_LONG_SESSIONS", "True").upper() == "TRUE"
# toggle for session management that assumes the adapter is running in a Databricks session
USE_SESSION_CONNECTION = os.getenv("DBT_DATABRICKS_SESSION_CONNECTION", "False").upper() == "TRUE"

# Number of idle seconds before a connection is automatically closed. Only applicable if
# USE_LONG_SESSIONS is true.
@@ -1079,6 +1081,63 @@ def exponential_backoff(attempt: int) -> int:
)


class DatabricksSessionConnectionManager(DatabricksConnectionManager):
def cancel_open(self) -> list[str]:
return SparkConnectionManager.cancel_open(self)

def compare_dbr_version(self, major: int, minor: int) -> int:
version = (major, minor)
connection = self.get_thread_connection().handle
dbr_version = connection.dbr_version
return (dbr_version > version) - (dbr_version < version)

def set_query_header(self, query_header_context: dict[str, Any]) -> None:
SparkConnectionManager.set_query_header(self, query_header_context)

def set_connection_name(
self, name: Optional[str] = None, query_header_context: Any = None
) -> Connection:
return SparkConnectionManager.set_connection_name(self, name)

def add_query(
self,
sql: str,
auto_begin: bool = True,
bindings: Optional[Any] = None,
abridge_sql_log: bool = False,
*,
close_cursor: bool = False,
) -> tuple[Connection, Any]:
return SparkConnectionManager.add_query(self, sql, auto_begin, bindings, abridge_sql_log)

def list_schemas(self, database: str, schema: Optional[str] = None) -> "Table":
raise NotImplementedError(
"list_schemas is not implemented for DatabricksSessionConnectionManager - "
+ "should call the list_schemas macro instead"
)

def list_tables(self, database: str, schema: str, identifier: Optional[str] = None) -> "Table":
raise NotImplementedError(
"list_tables is not implemented for DatabricksSessionConnectionManager - "
+ "should call the list_tables macro instead"
)

@classmethod
def open(cls, connection: Connection) -> Connection:
from dbt.adapters.databricks.session_connection import DatabricksSessionConnectionWrapper
from dbt.adapters.spark.session import Connection

handle = DatabricksSessionConnectionWrapper(Connection())
connection.handle = handle
connection.state = ConnectionState.OPEN
return connection

@classmethod
def get_response(cls, cursor: Any) -> DatabricksAdapterResponse:
response = SparkConnectionManager.get_response(cursor)
return DatabricksAdapterResponse(_message=response._message)


def _get_pipeline_state(session: Session, host: str, pipeline_id: str) -> dict:
pipeline_url = f"https://{host}/api/2.0/pipelines/{pipeline_id}"

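The `compare_dbr_version` override relies on Python tuple ordering to produce a three-way result; a small self-contained sketch of that idiom (the version values are made up for illustration):

```python
# Three-way comparison via tuple ordering, as used by compare_dbr_version:
# returns -1 if the runtime version is older than the requested (major, minor),
# 0 if equal, and 1 if newer.
def three_way_compare(runtime: tuple[int, int], requested: tuple[int, int]) -> int:
    return (runtime > requested) - (runtime < requested)

assert three_way_compare((14, 3), (13, 0)) == 1   # newer runtime
assert three_way_compare((13, 0), (13, 0)) == 0   # exact match
assert three_way_compare((12, 2), (13, 0)) == -1  # older runtime
```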
10 changes: 7 additions & 3 deletions dbt/adapters/databricks/impl.py
@@ -33,9 +33,11 @@
from dbt.adapters.databricks.column import DatabricksColumn
from dbt.adapters.databricks.connections import (
USE_LONG_SESSIONS,
USE_SESSION_CONNECTION,
DatabricksConnectionManager,
DatabricksDBTConnection,
DatabricksSQLConnectionWrapper,
DatabricksSessionConnectionManager,
ExtendedSessionConnectionManager,
)
from dbt.adapters.databricks.python_models.python_submissions import (
@@ -156,7 +158,9 @@ class DatabricksAdapter(SparkAdapter):
Relation = DatabricksRelation
Column = DatabricksColumn

if USE_LONG_SESSIONS:
if USE_SESSION_CONNECTION:
ConnectionManager: type[DatabricksConnectionManager] = DatabricksSessionConnectionManager
elif USE_LONG_SESSIONS:
ConnectionManager: type[DatabricksConnectionManager] = ExtendedSessionConnectionManager
else:
ConnectionManager = DatabricksConnectionManager
@@ -272,7 +276,7 @@ def list_schemas(self, database: Optional[str]) -> list[str]:
If `database` is `None`, fallback to executing `show databases` because
`list_schemas` tries to collect schemas from all catalogs when `database` is `None`.
"""
if database is not None:
if database is not None and not USE_SESSION_CONNECTION:
results = self.connections.list_schemas(database=database)
else:
results = self.execute_macro(LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database})
@@ -346,7 +350,7 @@ def _get_hive_relations(
kwargs = {"relation": relation}

new_rows: list[tuple[str, Optional[str]]]
if all([relation.database, relation.schema]):
if all([relation.database, relation.schema]) and not USE_SESSION_CONNECTION:
tables = self.connections.list_tables(
database=relation.database, # type: ignore[arg-type]
schema=relation.schema, # type: ignore[arg-type]
42 changes: 42 additions & 0 deletions dbt/adapters/databricks/session_connection.py
@@ -0,0 +1,42 @@
import re
import sys
from typing import Tuple

from dbt.adapters.spark.session import Connection
from dbt.adapters.spark.session import SessionConnectionWrapper

DBR_VERSION_REGEX = re.compile(r"([1-9][0-9]*)\.(x|0|[1-9][0-9]*)")


class DatabricksSessionConnectionWrapper(SessionConnectionWrapper):
_is_cluster: bool
_dbr_version: Tuple[int, int]

def __init__(self, handle: Connection) -> None:
super().__init__(handle)
self._is_cluster = True
self.cursor()

@property
def dbr_version(self) -> Tuple[int, int]:
if not hasattr(self, "_dbr_version"):
if self._is_cluster:
with self._cursor() as cursor:
cursor.execute("SET spark.databricks.clusterUsageTags.sparkVersion")
results = cursor.fetchone()
if results:
dbr_version: str = results[1]

m = DBR_VERSION_REGEX.search(dbr_version)
assert m, f"Unknown DBR version: {dbr_version}"
major = int(m.group(1))
try:
minor = int(m.group(2))
except ValueError:
minor = sys.maxsize
self._dbr_version = (major, minor)
else:
# Assuming SQL Warehouse uses the latest version.
self._dbr_version = (sys.maxsize, sys.maxsize)

return self._dbr_version
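To illustrate how `DBR_VERSION_REGEX` and the `ValueError` fallback interact, here is an isolated sketch of the parsing step; the sample version strings are assumptions about the `spark.databricks.clusterUsageTags.sparkVersion` format, not values taken from this PR.

```python
import re
import sys

DBR_VERSION_REGEX = re.compile(r"([1-9][0-9]*)\.(x|0|[1-9][0-9]*)")

def parse_dbr_version(spark_version: str) -> tuple[int, int]:
    # "14.3.x-scala2.12" -> (14, 3); an "x" minor (e.g. "15.x-snapshot")
    # raises ValueError on int() and is treated as the newest possible minor.
    m = DBR_VERSION_REGEX.search(spark_version)
    assert m, f"Unknown DBR version: {spark_version}"
    major = int(m.group(1))
    try:
        minor = int(m.group(2))
    except ValueError:
        minor = sys.maxsize
    return (major, minor)

print(parse_dbr_version("14.3.x-scala2.12"))  # (14, 3)
print(parse_dbr_version("15.x-snapshot"))     # (15, sys.maxsize)
```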
36 changes: 32 additions & 4 deletions dbt/include/databricks/macros/adapters/metadata.sql
@@ -19,9 +19,18 @@
{% endmacro %}

{% macro databricks__show_tables(relation) %}
{% call statement('show_tables', fetch_result=True) -%}
show tables in {{ relation|lower }}
{% endcall %}
{% set database = (relation.database | default('')) | lower | replace('`', '') %}
{% set schema = relation.schema | lower | replace('`', '') %}

{% if database and schema -%}
{% call statement('show_tables', fetch_result=True) -%}
SHOW TABLES IN {{ database }}.{{ schema }}
{% endcall %}
{% else -%}
{% call statement('show_tables', fetch_result=True) -%}
SHOW TABLES IN {{ relation | lower }}
{% endcall %}
{% endif %}

{% do return(load_result('show_tables').table) %}
{% endmacro %}
@@ -103,4 +112,23 @@
{% endcall %}

{% do return(load_result('get_uc_tables').table) %}
{% endmacro %}
{% endmacro %}

{% macro list_schemas(database) %}
{{ return(adapter.dispatch('list_schemas', 'dbt')(database)) }}
{% endmacro %}

{% macro databricks__list_schemas(database) -%}
{% set database_clean = (database | default('')) | replace('`', '') %}
{% if database_clean -%}
{% call statement('list_schemas', fetch_result=True, auto_begin=False) %}
SHOW DATABASES IN {{ database_clean }}
{% endcall %}
{% else -%}
{% call statement('list_schemas', fetch_result=True, auto_begin=False) %}
SHOW DATABASES
{% endcall %}
{% endif -%}

{{ return(load_result('list_schemas').table) }}
{% endmacro %}
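As a rough Python mirror of the Jinja branching in `databricks__show_tables` above (illustrative only; the macro itself is the source of truth), the qualification logic amounts to:

```python
from typing import Optional

def show_tables_sql(database: Optional[str], schema: Optional[str], rendered_relation: str) -> str:
    # Strip backticks and lowercase, then prefer an explicit database.schema
    # qualifier when both parts are available; otherwise fall back to the
    # rendered relation, matching the macro's else branch.
    database = (database or "").lower().replace("`", "")
    schema = (schema or "").lower().replace("`", "")
    if database and schema:
        return f"SHOW TABLES IN {database}.{schema}"
    return f"SHOW TABLES IN {rendered_relation.lower()}"

print(show_tables_sql("`Main`", "`Analytics`", "`main`.`analytics`"))
# SHOW TABLES IN main.analytics
```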
21 changes: 20 additions & 1 deletion pyproject.toml
@@ -91,10 +91,29 @@ uc-cluster-e2e = "pytest --color=yes -v --profile databricks_uc_cluster -n auto
sqlw-e2e = "pytest --color=yes -v --profile databricks_uc_sql_endpoint -n auto --dist=loadscope tests/functional"

[tool.hatch.envs.test.scripts]
unit = "pytest --color=yes -v --profile databricks_cluster -n auto --dist=loadscope tests/unit"
unit = """
if [[ "$DBT_DATABRICKS_SESSION_CONNECTION" == "True" ]]; then
PROFILE="session_connection"
else
PROFILE="databricks_cluster"
fi

pytest --color=yes -v --profile $PROFILE -n auto --dist=loadscope tests/unit
"""

[[tool.hatch.envs.test.matrix]]
python = ["3.9", "3.10", "3.11", "3.12"]
session_support = ["no_session", "session"]


[tool.hatch.envs.test.overrides]
matrix.session_support.env-vars = [
{ key = "DBT_DATABRICKS_SESSION_CONNECTION", value = "False", if = ["no_session"]},
{ key = "DBT_DATABRICKS_SESSION_CONNECTION", value = "True", if = ["session"] }
]




[tool.ruff]
line-length = 100
10 changes: 10 additions & 0 deletions tests/unit/test_adapter.py
@@ -110,6 +110,7 @@ def test_http_headers(http_header):
test_http_headers(["a", "b"])
test_http_headers({"a": 1, "b": 2})

@pytest.mark.skip_profile("session_connection")
def test_invalid_custom_user_agent(self):
with pytest.raises(DbtValidationError) as excinfo:
config = self._get_config()
@@ -120,6 +121,7 @@ def test_invalid_custom_user_agent(self):

assert "Invalid invocation environment" in str(excinfo.value)

@pytest.mark.skip_profile("session_connection")
def test_custom_user_agent(self):
config = self._get_config()
adapter = DatabricksAdapter(config, get_context("spawn"))
@@ -134,12 +136,14 @@ def test_custom_user_agent(self):
connection = adapter.acquire_connection("dummy")
connection.handle # trigger lazy-load

@pytest.mark.skip_profile("session_connection")
def test_environment_single_http_header(self):
self._test_environment_http_headers(
http_headers_str='{"test":{"jobId":1,"runId":12123}}',
expected_http_headers=[("test", '{"jobId": 1, "runId": 12123}')],
)

@pytest.mark.skip_profile("session_connection")
def test_environment_multiple_http_headers(self):
self._test_environment_http_headers(
http_headers_str='{"test":{"jobId":1,"runId":12123},"dummy":{"jobId":1,"runId":12123}}',
@@ -149,6 +153,7 @@ def test_environment_multiple_http_headers(self):
],
)

@pytest.mark.skip_profile("session_connection")
def test_environment_users_http_headers_intersection_error(self):
with pytest.raises(DbtValidationError) as excinfo:
self._test_environment_http_headers(
@@ -159,6 +164,7 @@ def test_environment_users_http_headers_intersection_error(self):

assert "Intersection with reserved http_headers in keys: {'t'}" in str(excinfo.value)

@pytest.mark.skip_profile("session_connection")
def test_environment_users_http_headers_union_success(self):
self._test_environment_http_headers(
http_headers_str='{"t":{"jobId":1,"runId":12123},"d":{"jobId":1,"runId":12123}}',
@@ -170,6 +176,7 @@ def test_environment_users_http_headers_union_success(self):
],
)

@pytest.mark.skip_profile("session_connection")
def test_environment_http_headers_string(self):
self._test_environment_http_headers(
http_headers_str='{"string":"some-string"}',
@@ -269,6 +276,7 @@ def connect(

return connect

@pytest.mark.skip_profile("session_connection")
def test_databricks_sql_connector_connection(self):
self._test_databricks_sql_connector_connection(self._connect_func())

@@ -291,6 +299,7 @@ def _test_databricks_sql_connector_connection(self, connect):
assert len(connection.credentials.session_properties) == 1
assert connection.credentials.session_properties["spark.sql.ansi.enabled"] == "true"

@pytest.mark.skip_profile("session_connection")
def test_databricks_sql_connector_catalog_connection(self):
self._test_databricks_sql_connector_catalog_connection(
self._connect_func(expected_catalog="main")
@@ -314,6 +323,7 @@ def _test_databricks_sql_connector_catalog_connection(self, connect):
assert connection.credentials.schema == "analytics"
assert connection.credentials.database == "main"

@pytest.mark.skip_profile("session_connection")
def test_databricks_sql_connector_http_header_connection(self):
self._test_databricks_sql_connector_http_header_connection(
{"aaa": "xxx"}, self._connect_func(expected_http_headers=[("aaa", "xxx")])
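The new `@pytest.mark.skip_profile("session_connection")` markers rely on profile-aware test collection; a hypothetical `conftest.py` sketch of that mechanism (the project's real test fixtures may differ) looks like this:

```python
# Hypothetical conftest.py sketch: skip tests whose skip_profile marker lists
# the profile selected via --profile. Illustrative only; the adapter's actual
# test plumbing is defined elsewhere in the repository.
import pytest

def pytest_addoption(parser):
    parser.addoption("--profile", action="store", default="databricks_cluster")

def pytest_collection_modifyitems(config, items):
    profile = config.getoption("--profile")
    for item in items:
        for marker in item.iter_markers(name="skip_profile"):
            if profile in marker.args:
                item.add_marker(pytest.mark.skip(reason=f"not run for profile {profile!r}"))
```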