From 4c9d88745070036e57cd023212b4ffd2c1ee11e4 Mon Sep 17 00:00:00 2001 From: geruh Date: Sat, 20 Dec 2025 00:57:21 -0800 Subject: [PATCH 1/5] feat: Allow servers to express supported endpoints with ConfigResponse --- pyiceberg/catalog/rest/__init__.py | 169 +++++++++++++++++++- tests/catalog/test_rest.py | 239 +++++++++++++++++++++++++++-- 2 files changed, 394 insertions(+), 14 deletions(-) diff --git a/pyiceberg/catalog/rest/__init__.py b/pyiceberg/catalog/rest/__init__.py index a28ff562bd..f9a8d0a6fd 100644 --- a/pyiceberg/catalog/rest/__init__.py +++ b/pyiceberg/catalog/rest/__init__.py @@ -21,7 +21,7 @@ Union, ) -from pydantic import Field, field_validator +from pydantic import ConfigDict, Field, field_validator from requests import HTTPError, Session from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt @@ -76,6 +76,43 @@ import pyarrow as pa +class HttpMethod(str, Enum): + GET = "GET" + HEAD = "HEAD" + POST = "POST" + DELETE = "DELETE" + + +class Endpoint(IcebergBaseModel): + model_config = ConfigDict(frozen=True) + + http_method: HttpMethod = Field() + path: str = Field() + + @field_validator("path", mode="before") + @classmethod + def _validate_path(cls, raw_path: str) -> str: + if not raw_path: + raise ValueError("Invalid path: empty") + raw_path = raw_path.strip() + if not raw_path: + raise ValueError("Invalid path: empty") + return raw_path + + def __str__(self) -> str: + """Return the string representation of the Endpoint class.""" + return f"{self.http_method.value} {self.path}" + + @classmethod + def from_string(cls, endpoint: str | None) -> "Endpoint": + if endpoint is None: + raise ValueError("Invalid endpoint (must consist of 'METHOD /path'): None") + elements = endpoint.split(None, 1) + if len(elements) != 2: + raise ValueError(f"Invalid endpoint (must consist of two elements separated by a single space): {endpoint}") + return cls(http_method=HttpMethod(elements[0].upper()), path=elements[1]) + + class Endpoints: get_config: str = "config" list_namespaces: str = "namespaces" @@ -86,7 +123,7 @@ class Endpoints: namespace_exists: str = "namespaces/{namespace}" list_tables: str = "namespaces/{namespace}/tables" create_table: str = "namespaces/{namespace}/tables" - register_table = "namespaces/{namespace}/register" + register_table: str = "namespaces/{namespace}/register" load_table: str = "namespaces/{namespace}/tables/{table}" update_table: str = "namespaces/{namespace}/tables/{table}" drop_table: str = "namespaces/{namespace}/tables/{table}" @@ -100,6 +137,66 @@ class Endpoints: fetch_scan_tasks: str = "namespaces/{namespace}/tables/{table}/tasks" +class Capability: + V1_LIST_NAMESPACES = Endpoint(http_method=HttpMethod.GET, path="/v1/{prefix}/namespaces") + V1_LOAD_NAMESPACE = Endpoint(http_method=HttpMethod.GET, path="/v1/{prefix}/namespaces/{namespace}") + V1_NAMESPACE_EXISTS = Endpoint(http_method=HttpMethod.HEAD, path="/v1/{prefix}/namespaces/{namespace}") + V1_UPDATE_NAMESPACE = Endpoint(http_method=HttpMethod.POST, path="/v1/{prefix}/namespaces/{namespace}/properties") + V1_CREATE_NAMESPACE = Endpoint(http_method=HttpMethod.POST, path="/v1/{prefix}/namespaces") + V1_DELETE_NAMESPACE = Endpoint(http_method=HttpMethod.DELETE, path="/v1/{prefix}/namespaces/{namespace}") + + V1_LIST_TABLES = Endpoint(http_method=HttpMethod.GET, path="/v1/{prefix}/namespaces/{namespace}/tables") + V1_LOAD_TABLE = Endpoint(http_method=HttpMethod.GET, path="/v1/{prefix}/namespaces/{namespace}/tables/{table}") + V1_TABLE_EXISTS = 
Endpoint(http_method=HttpMethod.HEAD, path="/v1/{prefix}/namespaces/{namespace}/tables/{table}") + V1_CREATE_TABLE = Endpoint(http_method=HttpMethod.POST, path="/v1/{prefix}/namespaces/{namespace}/tables") + V1_UPDATE_TABLE = Endpoint(http_method=HttpMethod.POST, path="/v1/{prefix}/namespaces/{namespace}/tables/{table}") + V1_DELETE_TABLE = Endpoint(http_method=HttpMethod.DELETE, path="/v1/{prefix}/namespaces/{namespace}/tables/{table}") + V1_RENAME_TABLE = Endpoint(http_method=HttpMethod.POST, path="/v1/{prefix}/tables/rename") + V1_REGISTER_TABLE = Endpoint(http_method=HttpMethod.POST, path="/v1/{prefix}/namespaces/{namespace}/register") + + V1_LIST_VIEWS = Endpoint(http_method=HttpMethod.GET, path="/v1/{prefix}/namespaces/{namespace}/views") + V1_LOAD_VIEW = Endpoint(http_method=HttpMethod.GET, path="/v1/{prefix}/namespaces/{namespace}/views/{view}") + V1_VIEW_EXISTS = Endpoint(http_method=HttpMethod.HEAD, path="/v1/{prefix}/namespaces/{namespace}/views/{view}") + V1_CREATE_VIEW = Endpoint(http_method=HttpMethod.POST, path="/v1/{prefix}/namespaces/{namespace}/views") + V1_UPDATE_VIEW = Endpoint(http_method=HttpMethod.POST, path="/v1/{prefix}/namespaces/{namespace}/views/{view}") + V1_DELETE_VIEW = Endpoint(http_method=HttpMethod.DELETE, path="/v1/{prefix}/namespaces/{namespace}/views/{view}") + V1_RENAME_VIEW = Endpoint(http_method=HttpMethod.POST, path="/v1/{prefix}/views/rename") + V1_SUBMIT_TABLE_SCAN_PLAN = Endpoint( + http_method=HttpMethod.POST, path="/v1/{prefix}/namespaces/{namespace}/tables/{table}/plan" + ) + V1_TABLE_SCAN_PLAN_TASKS = Endpoint( + http_method=HttpMethod.POST, path="/v1/{prefix}/namespaces/{namespace}/tables/{table}/tasks" + ) + + +# Default endpoints for backwards compatibility with legacy servers that don't return endpoints +# in ConfigResponse. Only includes namespace and table endpoints. +DEFAULT_ENDPOINTS: frozenset[Endpoint] = frozenset( + ( + Capability.V1_LIST_NAMESPACES, + Capability.V1_LOAD_NAMESPACE, + Capability.V1_CREATE_NAMESPACE, + Capability.V1_UPDATE_NAMESPACE, + Capability.V1_DELETE_NAMESPACE, + Capability.V1_LIST_TABLES, + Capability.V1_LOAD_TABLE, + Capability.V1_CREATE_TABLE, + Capability.V1_UPDATE_TABLE, + Capability.V1_DELETE_TABLE, + Capability.V1_RENAME_TABLE, + Capability.V1_REGISTER_TABLE, + ) +) + +# View endpoints conditionally added based on VIEW_ENDPOINTS_SUPPORTED property. 
+VIEW_ENDPOINTS: frozenset[Endpoint] = frozenset( + ( + Capability.V1_LIST_VIEWS, + Capability.V1_DELETE_VIEW, + ) +) + + class IdentifierKind(Enum): TABLE = "table" VIEW = "view" @@ -134,6 +231,8 @@ class IdentifierKind(Enum): CUSTOM = "custom" REST_SCAN_PLANNING_ENABLED = "rest-scan-planning-enabled" REST_SCAN_PLANNING_ENABLED_DEFAULT = False +VIEW_ENDPOINTS_SUPPORTED = "view-endpoints-supported" +VIEW_ENDPOINTS_SUPPORTED_DEFAULT = False NAMESPACE_SEPARATOR = b"\x1f".decode(UTF8) @@ -180,6 +279,14 @@ class RegisterTableRequest(IcebergBaseModel): class ConfigResponse(IcebergBaseModel): defaults: Properties | None = Field(default_factory=dict) overrides: Properties | None = Field(default_factory=dict) + endpoints: set[Endpoint] | None = Field(default=None) + + @field_validator("endpoints", mode="before") + @classmethod + def _parse_endpoints(cls, v: list[str] | None) -> set[Endpoint] | None: + if v is None: + return None + return {Endpoint.from_string(s) for s in v} class ListNamespaceResponse(IcebergBaseModel): @@ -218,6 +325,7 @@ class ListViewsResponse(IcebergBaseModel): class RestCatalog(Catalog): uri: str _session: Session + _supported_endpoints: set[Endpoint] def __init__(self, name: str, **properties: str): """Rest Catalog. @@ -279,7 +387,9 @@ def is_rest_scan_planning_enabled(self) -> bool: Returns: True if enabled, False otherwise. """ - return property_as_bool(self.properties, REST_SCAN_PLANNING_ENABLED, REST_SCAN_PLANNING_ENABLED_DEFAULT) + return Capability.V1_SUBMIT_TABLE_SCAN_PLAN in self._supported_endpoints and property_as_bool( + self.properties, REST_SCAN_PLANNING_ENABLED, REST_SCAN_PLANNING_ENABLED_DEFAULT + ) def _create_legacy_oauth2_auth_manager(self, session: Session) -> AuthManager: """Create the LegacyOAuth2AuthManager by fetching required properties. @@ -327,6 +437,18 @@ def url(self, endpoint: str, prefixed: bool = True, **kwargs: Any) -> str: return url + endpoint.format(**kwargs) + def _check_endpoint(self, endpoint: Endpoint) -> None: + """Check if an endpoint is supported by the server. + + Args: + endpoint: The endpoint to check against the set of supported endpoints + + Raises: + NotImplementedError: If the endpoint is not supported. 
+ """ + if endpoint not in self._supported_endpoints: + raise NotImplementedError(f"Server does not support endpoint: {endpoint}") + @property def auth_url(self) -> str: self._warn_oauth_tokens_deprecation() @@ -384,6 +506,17 @@ def _fetch_config(self) -> None: # Update URI based on overrides self.uri = config[URI] + # Determine supported endpoints + endpoints = config_response.endpoints + if endpoints: + self._supported_endpoints = set(endpoints) + else: + # Use default endpoints for legacy servers that don't return endpoints + self._supported_endpoints = set(DEFAULT_ENDPOINTS) + # Conditionally add view endpoints based on config + if property_as_bool(self.properties, VIEW_ENDPOINTS_SUPPORTED, VIEW_ENDPOINTS_SUPPORTED_DEFAULT): + self._supported_endpoints.update(VIEW_ENDPOINTS) + def _identifier_to_validated_tuple(self, identifier: str | Identifier) -> Identifier: identifier_tuple = self.identifier_to_tuple(identifier) if len(identifier_tuple) <= 1: @@ -503,6 +636,7 @@ def _create_table( properties: Properties = EMPTY_DICT, stage_create: bool = False, ) -> TableResponse: + self._check_endpoint(Capability.V1_CREATE_TABLE) iceberg_schema = self._convert_schema_if_needed( schema, int(properties.get(TableProperties.FORMAT_VERSION, TableProperties.DEFAULT_FORMAT_VERSION)), # type: ignore @@ -591,6 +725,7 @@ def register_table(self, identifier: str | Identifier, metadata_location: str) - Raises: TableAlreadyExistsError: If the table already exists """ + self._check_endpoint(Capability.V1_REGISTER_TABLE) namespace_and_table = self._split_identifier_for_path(identifier) request = RegisterTableRequest( name=namespace_and_table["table"], @@ -611,6 +746,7 @@ def register_table(self, identifier: str | Identifier, metadata_location: str) - @retry(**_RETRY_ARGS) def list_tables(self, namespace: str | Identifier) -> list[Identifier]: + self._check_endpoint(Capability.V1_LIST_TABLES) namespace_tuple = self._check_valid_namespace_identifier(namespace) namespace_concat = NAMESPACE_SEPARATOR.join(namespace_tuple) response = self._session.get(self.url(Endpoints.list_tables, namespace=namespace_concat)) @@ -622,6 +758,7 @@ def list_tables(self, namespace: str | Identifier) -> list[Identifier]: @retry(**_RETRY_ARGS) def load_table(self, identifier: str | Identifier) -> Table: + self._check_endpoint(Capability.V1_LOAD_TABLE) params = {} if mode := self.properties.get(SNAPSHOT_LOADING_MODE): if mode in {"all", "refs"}: @@ -642,6 +779,7 @@ def load_table(self, identifier: str | Identifier) -> Table: @retry(**_RETRY_ARGS) def drop_table(self, identifier: str | Identifier, purge_requested: bool = False) -> None: + self._check_endpoint(Capability.V1_DELETE_TABLE) response = self._session.delete( self.url(Endpoints.drop_table, prefixed=True, **self._split_identifier_for_path(identifier)), params={"purgeRequested": purge_requested}, @@ -657,6 +795,7 @@ def purge_table(self, identifier: str | Identifier) -> None: @retry(**_RETRY_ARGS) def rename_table(self, from_identifier: str | Identifier, to_identifier: str | Identifier) -> Table: + self._check_endpoint(Capability.V1_RENAME_TABLE) payload = { "source": self._split_identifier_for_json(from_identifier), "destination": self._split_identifier_for_json(to_identifier), @@ -692,6 +831,8 @@ def _remove_catalog_name_from_table_request_identifier(self, table_request: Comm @retry(**_RETRY_ARGS) def list_views(self, namespace: str | Identifier) -> list[Identifier]: + if Capability.V1_LIST_VIEWS not in self._supported_endpoints: + return [] namespace_tuple = 
self._check_valid_namespace_identifier(namespace) namespace_concat = NAMESPACE_SEPARATOR.join(namespace_tuple) response = self._session.get(self.url(Endpoints.list_views, namespace=namespace_concat)) @@ -720,6 +861,7 @@ def commit_table( CommitFailedException: Requirement not met, or a conflict with a concurrent commit. CommitStateUnknownException: Failed due to an internal exception on the side of the catalog. """ + self._check_endpoint(Capability.V1_UPDATE_TABLE) identifier = table.name() table_identifier = TableIdentifier(namespace=identifier[:-1], name=identifier[-1]) table_request = CommitTableRequest(identifier=table_identifier, requirements=requirements, updates=updates) @@ -749,6 +891,7 @@ def commit_table( @retry(**_RETRY_ARGS) def create_namespace(self, namespace: str | Identifier, properties: Properties = EMPTY_DICT) -> None: + self._check_endpoint(Capability.V1_CREATE_NAMESPACE) namespace_tuple = self._check_valid_namespace_identifier(namespace) payload = {"namespace": namespace_tuple, "properties": properties} response = self._session.post(self.url(Endpoints.create_namespace), json=payload) @@ -759,6 +902,7 @@ def create_namespace(self, namespace: str | Identifier, properties: Properties = @retry(**_RETRY_ARGS) def drop_namespace(self, namespace: str | Identifier) -> None: + self._check_endpoint(Capability.V1_DELETE_NAMESPACE) namespace_tuple = self._check_valid_namespace_identifier(namespace) namespace = NAMESPACE_SEPARATOR.join(namespace_tuple) response = self._session.delete(self.url(Endpoints.drop_namespace, namespace=namespace)) @@ -769,6 +913,7 @@ def drop_namespace(self, namespace: str | Identifier) -> None: @retry(**_RETRY_ARGS) def list_namespaces(self, namespace: str | Identifier = ()) -> list[Identifier]: + self._check_endpoint(Capability.V1_LIST_NAMESPACES) namespace_tuple = self.identifier_to_tuple(namespace) response = self._session.get( self.url( @@ -786,6 +931,7 @@ def list_namespaces(self, namespace: str | Identifier = ()) -> list[Identifier]: @retry(**_RETRY_ARGS) def load_namespace_properties(self, namespace: str | Identifier) -> Properties: + self._check_endpoint(Capability.V1_LOAD_NAMESPACE) namespace_tuple = self._check_valid_namespace_identifier(namespace) namespace = NAMESPACE_SEPARATOR.join(namespace_tuple) response = self._session.get(self.url(Endpoints.load_namespace_metadata, namespace=namespace)) @@ -800,6 +946,7 @@ def load_namespace_properties(self, namespace: str | Identifier) -> Properties: def update_namespace_properties( self, namespace: str | Identifier, removals: set[str] | None = None, updates: Properties = EMPTY_DICT ) -> PropertiesUpdateSummary: + self._check_endpoint(Capability.V1_UPDATE_NAMESPACE) namespace_tuple = self._check_valid_namespace_identifier(namespace) namespace = NAMESPACE_SEPARATOR.join(namespace_tuple) payload = {"removals": list(removals or []), "updates": updates} @@ -819,6 +966,14 @@ def update_namespace_properties( def namespace_exists(self, namespace: str | Identifier) -> bool: namespace_tuple = self._check_valid_namespace_identifier(namespace) namespace = NAMESPACE_SEPARATOR.join(namespace_tuple) + + if Capability.V1_NAMESPACE_EXISTS not in self._supported_endpoints: + try: + self.load_namespace_properties(namespace_tuple) + return True + except NoSuchNamespaceError: + return False + response = self._session.head(self.url(Endpoints.namespace_exists, namespace=namespace)) if response.status_code == 404: @@ -843,6 +998,13 @@ def table_exists(self, identifier: str | Identifier) -> bool: Returns: bool: True if the 
table exists, False otherwise. """ + if Capability.V1_TABLE_EXISTS not in self._supported_endpoints: + try: + self.load_table(identifier) + return True + except NoSuchTableError: + return False + response = self._session.head( self.url(Endpoints.load_table, prefixed=True, **self._split_identifier_for_path(identifier)) ) @@ -886,6 +1048,7 @@ def view_exists(self, identifier: str | Identifier) -> bool: @retry(**_RETRY_ARGS) def drop_view(self, identifier: str) -> None: + self._check_endpoint(Capability.V1_DELETE_VIEW) response = self._session.delete( self.url(Endpoints.drop_view, prefixed=True, **self._split_identifier_for_path(identifier, IdentifierKind.VIEW)), ) diff --git a/tests/catalog/test_rest.py b/tests/catalog/test_rest.py index 464314f3be..03bef31070 100644 --- a/tests/catalog/test_rest.py +++ b/tests/catalog/test_rest.py @@ -27,7 +27,7 @@ import pyiceberg from pyiceberg.catalog import PropertiesUpdateSummary, load_catalog -from pyiceberg.catalog.rest import OAUTH2_SERVER_URI, SNAPSHOT_LOADING_MODE, RestCatalog +from pyiceberg.catalog.rest import DEFAULT_ENDPOINTS, OAUTH2_SERVER_URI, SNAPSHOT_LOADING_MODE, Capability, RestCatalog from pyiceberg.exceptions import ( AuthorizationExpiredError, NamespaceAlreadyExistsError, @@ -457,7 +457,9 @@ def test_list_views_200(rest_mock: Mocker) -> None: request_headers=TEST_HEADERS, ) - assert RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN).list_views(namespace) == [("examples", "fooshare")] + assert RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN, **{"view-endpoints-supported": "true"}).list_views(namespace) == [ + ("examples", "fooshare") + ] def test_list_views_200_sigv4(rest_mock: Mocker) -> None: @@ -469,9 +471,9 @@ def test_list_views_200_sigv4(rest_mock: Mocker) -> None: request_headers=TEST_HEADERS, ) - assert RestCatalog("rest", **{"uri": TEST_URI, "token": TEST_TOKEN, "rest.sigv4-enabled": "true"}).list_views(namespace) == [ - ("examples", "fooshare") - ] + assert RestCatalog( + "rest", **{"uri": TEST_URI, "token": TEST_TOKEN, "rest.sigv4-enabled": "true", "view-endpoints-supported": "true"} + ).list_views(namespace) == [("examples", "fooshare")] assert rest_mock.called @@ -490,7 +492,7 @@ def test_list_views_404(rest_mock: Mocker) -> None: request_headers=TEST_HEADERS, ) with pytest.raises(NoSuchNamespaceError) as e: - RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN).list_views(namespace) + RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN, **{"view-endpoints-supported": "true"}).list_views(namespace) assert "Namespace does not exist" in str(e.value) @@ -502,7 +504,7 @@ def test_view_exists_204(rest_mock: Mocker) -> None: status_code=204, request_headers=TEST_HEADERS, ) - catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN) + catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN, **{"view-endpoints-supported": "true"}) assert catalog.view_exists((namespace, view)) @@ -514,7 +516,7 @@ def test_view_exists_404(rest_mock: Mocker) -> None: status_code=404, request_headers=TEST_HEADERS, ) - catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN) + catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN, **{"view-endpoints-supported": "true"}) assert not catalog.view_exists((namespace, view)) @@ -782,6 +784,7 @@ def test_namespace_exists_200(rest_mock: Mocker) -> None: request_headers=TEST_HEADERS, ) catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN) + catalog._supported_endpoints.add(Capability.V1_NAMESPACE_EXISTS) assert catalog.namespace_exists("fokko") @@ -793,6 +796,7 @@ def 
test_namespace_exists_204(rest_mock: Mocker) -> None: request_headers=TEST_HEADERS, ) catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN) + catalog._supported_endpoints.add(Capability.V1_NAMESPACE_EXISTS) assert catalog.namespace_exists("fokko") @@ -804,6 +808,7 @@ def test_namespace_exists_404(rest_mock: Mocker) -> None: request_headers=TEST_HEADERS, ) catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN) + catalog._supported_endpoints.add(Capability.V1_NAMESPACE_EXISTS) assert not catalog.namespace_exists("fokko") @@ -815,6 +820,7 @@ def test_namespace_exists_500(rest_mock: Mocker) -> None: request_headers=TEST_HEADERS, ) catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN) + catalog._supported_endpoints.add(Capability.V1_NAMESPACE_EXISTS) with pytest.raises(ServerError): catalog.namespace_exists("fokko") @@ -957,6 +963,15 @@ def test_load_table_404(rest_mock: Mocker) -> None: def test_table_exists_200(rest_mock: Mocker) -> None: + rest_mock.get( + f"{TEST_URI}v1/config", + json={ + "defaults": {}, + "overrides": {}, + "endpoints": [str(Capability.V1_TABLE_EXISTS)], + }, + status_code=200, + ) rest_mock.head( f"{TEST_URI}v1/namespaces/fokko/tables/table", status_code=200, @@ -967,6 +982,15 @@ def test_table_exists_200(rest_mock: Mocker) -> None: def test_table_exists_204(rest_mock: Mocker) -> None: + rest_mock.get( + f"{TEST_URI}v1/config", + json={ + "defaults": {}, + "overrides": {}, + "endpoints": [str(Capability.V1_TABLE_EXISTS)], + }, + status_code=200, + ) rest_mock.head( f"{TEST_URI}v1/namespaces/fokko/tables/table", status_code=204, @@ -977,6 +1001,15 @@ def test_table_exists_204(rest_mock: Mocker) -> None: def test_table_exists_404(rest_mock: Mocker) -> None: + rest_mock.get( + f"{TEST_URI}v1/config", + json={ + "defaults": {}, + "overrides": {}, + "endpoints": [str(Capability.V1_TABLE_EXISTS)], + }, + status_code=200, + ) rest_mock.head( f"{TEST_URI}v1/namespaces/fokko/tables/table", status_code=404, @@ -987,6 +1020,15 @@ def test_table_exists_404(rest_mock: Mocker) -> None: def test_table_exists_500(rest_mock: Mocker) -> None: + rest_mock.get( + f"{TEST_URI}v1/config", + json={ + "defaults": {}, + "overrides": {}, + "endpoints": [str(Capability.V1_TABLE_EXISTS)], + }, + status_code=200, + ) rest_mock.head( f"{TEST_URI}v1/namespaces/fokko/tables/table", status_code=500, @@ -1339,6 +1381,15 @@ def test_delete_table_from_self_identifier_204( def test_rename_table_200(rest_mock: Mocker, example_table_metadata_with_snapshot_v1_rest_json: dict[str, Any]) -> None: + rest_mock.get( + f"{TEST_URI}v1/config", + json={ + "defaults": {}, + "overrides": {}, + "endpoints": [str(Capability.V1_NAMESPACE_EXISTS), str(Capability.V1_RENAME_TABLE), str(Capability.V1_LOAD_TABLE)], + }, + status_code=200, + ) rest_mock.post( f"{TEST_URI}v1/tables/rename", json={ @@ -1377,6 +1428,15 @@ def test_rename_table_200(rest_mock: Mocker, example_table_metadata_with_snapsho def test_rename_table_from_self_identifier_200( rest_mock: Mocker, example_table_metadata_with_snapshot_v1_rest_json: dict[str, Any] ) -> None: + rest_mock.get( + f"{TEST_URI}v1/config", + json={ + "defaults": {}, + "overrides": {}, + "endpoints": [str(Capability.V1_NAMESPACE_EXISTS), str(Capability.V1_RENAME_TABLE), str(Capability.V1_LOAD_TABLE)], + }, + status_code=200, + ) rest_mock.get( f"{TEST_URI}v1/namespaces/pdames/tables/source", json=example_table_metadata_with_snapshot_v1_rest_json, @@ -1420,6 +1480,15 @@ def test_rename_table_from_self_identifier_200( def 
test_rename_table_source_namespace_does_not_exist(rest_mock: Mocker) -> None: + rest_mock.get( + f"{TEST_URI}v1/config", + json={ + "defaults": {}, + "overrides": {}, + "endpoints": [str(Capability.V1_NAMESPACE_EXISTS), str(Capability.V1_RENAME_TABLE), str(Capability.V1_LOAD_TABLE)], + }, + status_code=200, + ) catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN) from_identifier = ("invalid", "source") to_identifier = ("pdames", "destination") @@ -1441,6 +1510,15 @@ def test_rename_table_source_namespace_does_not_exist(rest_mock: Mocker) -> None def test_rename_table_destination_namespace_does_not_exist(rest_mock: Mocker) -> None: + rest_mock.get( + f"{TEST_URI}v1/config", + json={ + "defaults": {}, + "overrides": {}, + "endpoints": [str(Capability.V1_NAMESPACE_EXISTS), str(Capability.V1_RENAME_TABLE), str(Capability.V1_LOAD_TABLE)], + }, + status_code=200, + ) catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN) from_identifier = ("pdames", "source") to_identifier = ("invalid", "destination") @@ -1824,6 +1902,15 @@ def test_table_identifier_in_commit_table_request( def test_drop_view_invalid_namespace(rest_mock: Mocker) -> None: + rest_mock.get( + f"{TEST_URI}v1/config", + json={ + "defaults": {}, + "overrides": {}, + "endpoints": [str(Capability.V1_DELETE_VIEW)], + }, + status_code=200, + ) view = "view" with pytest.raises(NoSuchIdentifierError) as e: # Missing namespace @@ -1833,6 +1920,15 @@ def test_drop_view_invalid_namespace(rest_mock: Mocker) -> None: def test_drop_view_404(rest_mock: Mocker) -> None: + rest_mock.get( + f"{TEST_URI}v1/config", + json={ + "defaults": {}, + "overrides": {}, + "endpoints": [str(Capability.V1_DELETE_VIEW)], + }, + status_code=200, + ) rest_mock.delete( f"{TEST_URI}v1/namespaces/some_namespace/views/does_not_exists", json={ @@ -1852,6 +1948,15 @@ def test_drop_view_404(rest_mock: Mocker) -> None: def test_drop_view_204(rest_mock: Mocker) -> None: + rest_mock.get( + f"{TEST_URI}v1/config", + json={ + "defaults": {}, + "overrides": {}, + "endpoints": [str(Capability.V1_DELETE_VIEW)], + }, + status_code=200, + ) rest_mock.delete( f"{TEST_URI}v1/namespaces/some_namespace/views/some_view", json={}, @@ -2007,7 +2112,11 @@ def test_rest_scan_planning_disabled_by_default(self, rest_mock: Mocker) -> None def test_rest_scan_planning_enabled_by_property(self, rest_mock: Mocker) -> None: rest_mock.get( f"{TEST_URI}v1/config", - json={"defaults": {}, "overrides": {}}, + json={ + "defaults": {}, + "overrides": {}, + "endpoints": ["POST /v1/{prefix}/namespaces/{namespace}/tables/{table}/plan"], + }, status_code=200, ) catalog = RestCatalog( @@ -2019,12 +2128,31 @@ def test_rest_scan_planning_enabled_by_property(self, rest_mock: Mocker) -> None assert catalog.is_rest_scan_planning_enabled() is True - def test_rest_scan_planning_explicitly_disabled(self, rest_mock: Mocker) -> None: + def test_rest_scan_planning_disabled_without_endpoint_support(self, rest_mock: Mocker) -> None: rest_mock.get( f"{TEST_URI}v1/config", json={"defaults": {}, "overrides": {}}, status_code=200, ) + catalog = RestCatalog( + "rest", + uri=TEST_URI, + token=TEST_TOKEN, + **{"rest-scan-planning-enabled": "true"}, + ) + + assert catalog.is_rest_scan_planning_enabled() is False + + def test_rest_scan_planning_explicitly_disabled(self, rest_mock: Mocker) -> None: + rest_mock.get( + f"{TEST_URI}v1/config", + json={ + "defaults": {}, + "overrides": {}, + "endpoints": ["POST /v1/{prefix}/namespaces/{namespace}/tables/{table}/plan"], + }, + status_code=200, + ) catalog = RestCatalog( 
"rest", uri=TEST_URI, @@ -2037,9 +2165,98 @@ def test_rest_scan_planning_explicitly_disabled(self, rest_mock: Mocker) -> None def test_rest_scan_planning_enabled_from_server_config(self, rest_mock: Mocker) -> None: rest_mock.get( f"{TEST_URI}v1/config", - json={"defaults": {"rest-scan-planning-enabled": "true"}, "overrides": {}}, + json={ + "defaults": {"rest-scan-planning-enabled": "true"}, + "overrides": {}, + "endpoints": ["POST /v1/{prefix}/namespaces/{namespace}/tables/{table}/plan"], + }, status_code=200, ) catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN) assert catalog.is_rest_scan_planning_enabled() is True + + def test_supported_endpoint(self, requests_mock: Mocker) -> None: + requests_mock.get( + f"{TEST_URI}v1/config", + json={ + "defaults": {}, + "overrides": {}, + "endpoints": ["GET /v1/{prefix}/namespaces", "GET /v1/{prefix}/namespaces/{namespace}/tables"], + }, + status_code=200, + ) + catalog = RestCatalog("rest", uri=TEST_URI, token="token") + + # Should not raise since these endpoints are in the supported set + catalog._check_endpoint(Capability.V1_LIST_NAMESPACES) + catalog._check_endpoint(Capability.V1_LIST_TABLES) + + def test_unsupported_endpoint(self, requests_mock: Mocker) -> None: + requests_mock.get( + f"{TEST_URI}v1/config", + json={ + "defaults": {}, + "overrides": {}, + "endpoints": ["GET /v1/{prefix}/namespaces"], + }, + status_code=200, + ) + catalog = RestCatalog("rest", uri=TEST_URI, token="token") + + with pytest.raises(NotImplementedError, match="Server does not support endpoint"): + catalog._check_endpoint(Capability.V1_LIST_TABLES) + + def test_config_returns_invalid_endpoint(self, requests_mock: Mocker) -> None: + requests_mock.get( + f"{TEST_URI}v1/config", + json={ + "defaults": {}, + "overrides": {}, + "endpoints": ["INVALID_ENDPOINT"], + }, + status_code=200, + ) + + with pytest.raises(ValueError, match="Invalid endpoint"): + RestCatalog("rest", uri=TEST_URI, token="token") + + def test_default_endpoints_used_when_none_returned(self, requests_mock: Mocker) -> None: + requests_mock.get( + f"{TEST_URI}v1/config", + json={"defaults": {}, "overrides": {}}, + status_code=200, + ) + catalog = RestCatalog("rest", uri=TEST_URI, token="token") + + # Should not raise for default endpoints + for endpoint in DEFAULT_ENDPOINTS: + catalog._check_endpoint(endpoint) + + def test_view_endpoints_not_included_by_default(self, requests_mock: Mocker) -> None: + requests_mock.get( + f"{TEST_URI}v1/config", + json={"defaults": {}, "overrides": {}}, + status_code=200, + ) + catalog = RestCatalog("rest", uri=TEST_URI, token="token") + + with pytest.raises(NotImplementedError, match="Server does not support endpoint"): + catalog._check_endpoint(Capability.V1_LIST_VIEWS) + + def test_view_endpoints_enabled_with_config(self, requests_mock: Mocker) -> None: + requests_mock.get( + f"{TEST_URI}v1/config", + json={"defaults": {}, "overrides": {}}, + status_code=200, + ) + catalog = RestCatalog( + "rest", + uri=TEST_URI, + token="token", + **{"view-endpoints-supported": "true"}, + ) + + # View endpoints should be supported when enabled + catalog._check_endpoint(Capability.V1_LIST_VIEWS) + catalog._check_endpoint(Capability.V1_DELETE_VIEW) From 663be50f1275ca89883b37da828ea9c9f30f0dea Mon Sep 17 00:00:00 2001 From: geruh Date: Tue, 23 Dec 2025 16:11:27 -0800 Subject: [PATCH 2/5] feat: Add models for rest scan planning --- pyiceberg/catalog/rest/scan_planning.py | 208 ++++++++++ tests/catalog/test_scan_planning_models.py | 450 +++++++++++++++++++++ 2 files changed, 658 
insertions(+) create mode 100644 pyiceberg/catalog/rest/scan_planning.py create mode 100644 tests/catalog/test_scan_planning_models.py diff --git a/pyiceberg/catalog/rest/scan_planning.py b/pyiceberg/catalog/rest/scan_planning.py new file mode 100644 index 0000000000..ddccf1d9e3 --- /dev/null +++ b/pyiceberg/catalog/rest/scan_planning.py @@ -0,0 +1,208 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import date, datetime, time +from decimal import Decimal +from typing import Annotated, Generic, Literal, TypeAlias, TypeVar +from uuid import UUID + +from pydantic import Field, model_validator + +from pyiceberg.catalog.rest.response import ErrorResponseMessage +from pyiceberg.expressions import BooleanExpression +from pyiceberg.manifest import FileFormat +from pyiceberg.typedef import IcebergBaseModel + +# Primitive types that can appear in partition values and bounds +PrimitiveTypeValue: TypeAlias = bool | int | float | str | Decimal | UUID | date | time | datetime | bytes + +V = TypeVar("V") + + +class KeyValueMap(IcebergBaseModel, Generic[V]): + """Map serialized as parallel key/value arrays for column statistics.""" + + keys: list[int] = Field(default_factory=list) + values: list[V] = Field(default_factory=list) + + @model_validator(mode="after") + def _validate_lengths_match(self) -> KeyValueMap[V]: + if len(self.keys) != len(self.values): + raise ValueError(f"keys and values must have same length: {len(self.keys)} != {len(self.values)}") + return self + + def to_dict(self) -> dict[int, V]: + """Convert to dictionary mapping field ID to value.""" + return dict(zip(self.keys, self.values, strict=True)) + + +class CountMap(KeyValueMap[int]): + """Map of field IDs to counts.""" + + +class ValueMap(KeyValueMap[PrimitiveTypeValue]): + """Map of field IDs to primitive values (for lower/upper bounds).""" + + +class StorageCredential(IcebergBaseModel): + """Storage credential for accessing content files.""" + + prefix: str = Field(description="Storage location prefix this credential applies to") + config: dict[str, str] = Field(default_factory=dict) + + +class RESTContentFile(IcebergBaseModel): + """Base model for data and delete files from REST API.""" + + spec_id: int = Field(alias="spec-id") + partition: list[PrimitiveTypeValue] = Field(default_factory=list) + content: Literal["data", "position-deletes", "equality-deletes"] + file_path: str = Field(alias="file-path") + file_format: FileFormat = Field(alias="file-format") + file_size_in_bytes: int = Field(alias="file-size-in-bytes") + record_count: int = Field(alias="record-count") + key_metadata: str | None = Field(alias="key-metadata", default=None) + split_offsets: list[int] | None = Field(alias="split-offsets", default=None) + sort_order_id: int | None 
= Field(alias="sort-order-id", default=None) + + +class RESTDataFile(RESTContentFile): + """Data file from REST API.""" + + content: Literal["data"] = Field(default="data") + first_row_id: int | None = Field(alias="first-row-id", default=None) + column_sizes: CountMap | None = Field(alias="column-sizes", default=None) + value_counts: CountMap | None = Field(alias="value-counts", default=None) + null_value_counts: CountMap | None = Field(alias="null-value-counts", default=None) + nan_value_counts: CountMap | None = Field(alias="nan-value-counts", default=None) + lower_bounds: ValueMap | None = Field(alias="lower-bounds", default=None) + upper_bounds: ValueMap | None = Field(alias="upper-bounds", default=None) + + +class RESTPositionDeleteFile(RESTContentFile): + """Position delete file from REST API.""" + + content: Literal["position-deletes"] = Field(default="position-deletes") + referenced_data_file: str | None = Field(alias="referenced-data-file", default=None) + content_offset: int | None = Field(alias="content-offset", default=None) + content_size_in_bytes: int | None = Field(alias="content-size-in-bytes", default=None) + + +class RESTEqualityDeleteFile(RESTContentFile): + """Equality delete file from REST API.""" + + content: Literal["equality-deletes"] = Field(default="equality-deletes") + equality_ids: list[int] | None = Field(alias="equality-ids", default=None) + + +# Discriminated union for delete files +RESTDeleteFile = Annotated[ + RESTPositionDeleteFile | RESTEqualityDeleteFile, + Field(discriminator="content"), +] + + +class RESTFileScanTask(IcebergBaseModel): + """A file scan task from the REST server.""" + + data_file: RESTDataFile = Field(alias="data-file") + delete_file_references: list[int] | None = Field(alias="delete-file-references", default=None) + residual_filter: BooleanExpression | None = Field(alias="residual-filter", default=None) + + +class ScanTasks(IcebergBaseModel): + """Container for scan tasks returned by the server.""" + + delete_files: list[RESTDeleteFile] = Field(alias="delete-files", default_factory=list) + file_scan_tasks: list[RESTFileScanTask] = Field(alias="file-scan-tasks", default_factory=list) + plan_tasks: list[str] = Field(alias="plan-tasks", default_factory=list) + + @model_validator(mode="after") + def _validate_delete_file_references(self) -> ScanTasks: + # validate delete file references are in bounds + max_idx = len(self.delete_files) - 1 + for task in self.file_scan_tasks: + for idx in task.delete_file_references or []: + if idx < 0 or idx > max_idx: + raise ValueError(f"Invalid delete file reference: {idx} (valid range: 0-{max_idx})") + + if self.delete_files and not self.file_scan_tasks: + raise ValueError("Invalid response: deleteFiles should only be returned with fileScanTasks that reference them") + + return self + + +class PlanCompleted(ScanTasks): + """Completed scan plan result.""" + + status: Literal["completed"] = "completed" + plan_id: str | None = Field(alias="plan-id", default=None) + storage_credentials: list[StorageCredential] | None = Field(alias="storage-credentials", default=None) + + +class PlanSubmitted(IcebergBaseModel): + """Scan plan submitted, poll for completion.""" + + status: Literal["submitted"] = "submitted" + plan_id: str | None = Field(alias="plan-id", default=None) + + +class PlanCancelled(IcebergBaseModel): + """Planning was cancelled.""" + + status: Literal["cancelled"] = "cancelled" + + +class PlanFailed(IcebergBaseModel): + """Planning failed with error.""" + + status: Literal["failed"] = "failed" + 
error: ErrorResponseMessage + + +PlanningResponse = Annotated[ + PlanCompleted | PlanSubmitted | PlanCancelled | PlanFailed, + Field(discriminator="status"), +] + + +class PlanTableScanRequest(IcebergBaseModel): + """Request body for planning a REST scan.""" + + snapshot_id: int | None = Field(alias="snapshot-id", default=None) + select: list[str] | None = Field(default=None) + filter: BooleanExpression | None = Field(default=None) + case_sensitive: bool = Field(alias="case-sensitive", default=True) + use_snapshot_schema: bool = Field(alias="use-snapshot-schema", default=False) + start_snapshot_id: int | None = Field(alias="start-snapshot-id", default=None) + end_snapshot_id: int | None = Field(alias="end-snapshot-id", default=None) + stats_fields: list[str] | None = Field(alias="stats-fields", default=None) + + @model_validator(mode="after") + def _validate_snapshot_fields(self) -> PlanTableScanRequest: + if self.start_snapshot_id is not None and self.end_snapshot_id is None: + raise ValueError("end-snapshot-id is required when start-snapshot-id is specified") + if self.snapshot_id is not None and self.start_snapshot_id is not None: + raise ValueError("Cannot specify both snapshot-id and start-snapshot-id") + return self + + +class FetchScanTasksRequest(IcebergBaseModel): + """Request body for fetching scan tasks endpoint.""" + + plan_task: str = Field(alias="plan-task") diff --git a/tests/catalog/test_scan_planning_models.py b/tests/catalog/test_scan_planning_models.py new file mode 100644 index 0000000000..9f03c8f7cd --- /dev/null +++ b/tests/catalog/test_scan_planning_models.py @@ -0,0 +1,450 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from typing import Any + +import pytest +from pydantic import TypeAdapter, ValidationError + +from pyiceberg.catalog.rest.scan_planning import ( + CountMap, + FetchScanTasksRequest, + PlanCancelled, + PlanCompleted, + PlanningResponse, + PlanSubmitted, + PlanTableScanRequest, + RESTDataFile, + RESTDeleteFile, + RESTEqualityDeleteFile, + RESTFileScanTask, + RESTPositionDeleteFile, + ScanTasks, + StorageCredential, + ValueMap, +) +from pyiceberg.expressions import AlwaysTrue, EqualTo, Reference +from pyiceberg.manifest import FileFormat + + +def test_count_map_valid() -> None: + cm = CountMap(keys=[1, 2, 3], values=[100, 200, 300]) + assert cm.to_dict() == {1: 100, 2: 200, 3: 300} + + +def test_count_map_empty() -> None: + cm = CountMap() + assert cm.to_dict() == {} + + +def test_count_map_length_mismatch() -> None: + with pytest.raises(ValidationError) as exc_info: + CountMap(keys=[1, 2, 3], values=[100, 200]) + assert "must have same length" in str(exc_info.value) + + +def test_value_map_mixed_types() -> None: + vm = ValueMap(keys=[1, 2, 3], values=[True, 42, "val"]) + assert vm.to_dict() == {1: True, 2: 42, 3: "val"} + + +def test_data_file_parsing() -> None: + data = { + "spec-id": 0, + "content": "data", + "file-path": "s3://bucket/table/file.parquet", + "file-format": "parquet", + "file-size-in-bytes": 1024, + "record-count": 100, + } + df = RESTDataFile.model_validate(data) + assert df.content == "data" + assert df.file_path == "s3://bucket/table/file.parquet" + assert df.file_format == FileFormat.PARQUET + assert df.file_size_in_bytes == 1024 + + +def test_data_file_with_stats() -> None: + data = { + "spec-id": 0, + "content": "data", + "file-path": "s3://bucket/table/file.parquet", + "file-format": "parquet", + "file-size-in-bytes": 1024, + "record-count": 100, + "column-sizes": {"keys": [1, 2], "values": [500, 524]}, + "value-counts": {"keys": [1, 2], "values": [100, 100]}, + } + df = RESTDataFile.model_validate(data) + assert df.column_sizes is not None + assert df.column_sizes.to_dict() == {1: 500, 2: 524} + + +def test_position_delete_file() -> None: + data = { + "spec-id": 0, + "content": "position-deletes", + "file-path": "s3://bucket/table/delete.parquet", + "file-format": "parquet", + "file-size-in-bytes": 512, + "record-count": 10, + "content-offset": 100, + "content-size-in-bytes": 200, + } + pdf = RESTPositionDeleteFile.model_validate(data) + assert pdf.content == "position-deletes" + assert pdf.content_offset == 100 + assert pdf.content_size_in_bytes == 200 + + +def test_equality_delete_file() -> None: + data = { + "spec-id": 0, + "content": "equality-deletes", + "file-path": "s3://bucket/table/eq-delete.parquet", + "file-format": "parquet", + "file-size-in-bytes": 256, + "record-count": 5, + "equality-ids": [1, 2], + } + edf = RESTEqualityDeleteFile.model_validate(data) + assert edf.content == "equality-deletes" + assert edf.equality_ids == [1, 2] + + +def test_file_format_case_insensitive() -> None: + for fmt in ["parquet", "PARQUET", "Parquet"]: + data = { + "spec-id": 0, + "content": "data", + "file-path": "/path", + "file-format": fmt, + "file-size-in-bytes": 100, + "record-count": 10, + } + df = RESTDataFile.model_validate(data) + assert df.file_format == FileFormat.PARQUET + + +@pytest.mark.parametrize( + "format_str,expected", + [ + ("parquet", FileFormat.PARQUET), + ("avro", FileFormat.AVRO), + ("orc", FileFormat.ORC), + ], +) +def test_file_formats(format_str: str, expected: FileFormat) -> None: + data = { + "spec-id": 0, + "content": "data", + "file-path": 
f"s3://bucket/table/path/file.{format_str}", + "file-format": format_str, + "file-size-in-bytes": 1024, + "record-count": 100, + } + df = RESTDataFile.model_validate(data) + assert df.file_format == expected + + +def test_delete_file_discriminator_position() -> None: + data = { + "spec-id": 0, + "content": "position-deletes", + "file-path": "s3://bucket/table/delete.parquet", + "file-format": "parquet", + "file-size-in-bytes": 256, + "record-count": 5, + } + result = TypeAdapter(RESTDeleteFile).validate_python(data) + assert isinstance(result, RESTPositionDeleteFile) + + +def test_delete_file_discriminator_equality() -> None: + data = { + "spec-id": 0, + "content": "equality-deletes", + "file-path": "s3://bucket/table/delete.parquet", + "file-format": "parquet", + "file-size-in-bytes": 256, + "record-count": 5, + "equality-ids": [1], + } + result = TypeAdapter(RESTDeleteFile).validate_python(data) + assert isinstance(result, RESTEqualityDeleteFile) + + +def test_basic_scan_task() -> None: + data = { + "data-file": { + "spec-id": 0, + "content": "data", + "file-path": "s3://bucket/table/file.parquet", + "file-format": "parquet", + "file-size-in-bytes": 1024, + "record-count": 100, + } + } + task = RESTFileScanTask.model_validate(data) + assert task.data_file.file_path == "s3://bucket/table/file.parquet" + assert task.delete_file_references is None + assert task.residual_filter is None + + +def test_scan_task_with_delete_references() -> None: + data = { + "data-file": { + "spec-id": 0, + "content": "data", + "file-path": "s3://bucket/table/file.parquet", + "file-format": "parquet", + "file-size-in-bytes": 1024, + "record-count": 100, + }, + "delete-file-references": [0, 1, 2], + } + task = RESTFileScanTask.model_validate(data) + assert task.delete_file_references == [0, 1, 2] + + +def test_scan_task_with_residual_filter_true() -> None: + data = { + "data-file": { + "spec-id": 0, + "content": "data", + "file-path": "s3://bucket/table/file.parquet", + "file-format": "parquet", + "file-size-in-bytes": 1024, + "record-count": 100, + }, + "residual-filter": True, + } + task = RESTFileScanTask.model_validate(data) + assert isinstance(task.residual_filter, AlwaysTrue) + + +def test_empty_scan_tasks() -> None: + data: dict[str, Any] = { + "delete-files": [], + "file-scan-tasks": [], + "plan-tasks": [], + } + scan_tasks = ScanTasks.model_validate(data) + assert len(scan_tasks.file_scan_tasks) == 0 + assert len(scan_tasks.delete_files) == 0 + assert len(scan_tasks.plan_tasks) == 0 + + +def test_scan_tasks_with_files() -> None: + data = { + "delete-files": [ + { + "spec-id": 0, + "content": "position-deletes", + "file-path": "s3://bucket/table/delete.parquet", + "file-format": "parquet", + "file-size-in-bytes": 256, + "record-count": 5, + } + ], + "file-scan-tasks": [ + { + "data-file": { + "spec-id": 0, + "content": "data", + "file-path": "s3://bucket/table/data.parquet", + "file-format": "parquet", + "file-size-in-bytes": 1024, + "record-count": 100, + }, + "delete-file-references": [0], + } + ], + "plan-tasks": ["token-1", "token-2"], + } + scan_tasks = ScanTasks.model_validate(data) + assert len(scan_tasks.delete_files) == 1 + assert len(scan_tasks.file_scan_tasks) == 1 + assert len(scan_tasks.plan_tasks) == 2 + + +def test_invalid_delete_file_reference() -> None: + data = { + "delete-files": [], + "file-scan-tasks": [ + { + "data-file": { + "spec-id": 0, + "content": "data", + "file-path": "s3://bucket/table/data.parquet", + "file-format": "parquet", + "file-size-in-bytes": 1024, + "record-count": 
100, + }, + "delete-file-references": [0], + } + ], + "plan-tasks": [], + } + with pytest.raises(ValidationError) as exc_info: + ScanTasks.model_validate(data) + assert "Invalid delete file reference" in str(exc_info.value) + + +def test_delete_files_require_file_scan_tasks() -> None: + data = { + "delete-files": [ + { + "spec-id": 0, + "content": "position-deletes", + "file-path": "s3://bucket/table/delete.parquet", + "file-format": "parquet", + "file-size-in-bytes": 256, + "record-count": 5, + } + ], + "file-scan-tasks": [], + "plan-tasks": [], + } + with pytest.raises(ValidationError) as exc_info: + ScanTasks.model_validate(data) + assert "deleteFiles should only be returned with fileScanTasks" in str(exc_info.value) + + +def test_minimal_request() -> None: + request = PlanTableScanRequest() + dumped = request.model_dump(by_alias=True, exclude_none=True) + assert dumped == {"case-sensitive": True, "use-snapshot-schema": False} + + +def test_request_with_snapshot_id() -> None: + request = PlanTableScanRequest(snapshot_id=12345) + dumped = request.model_dump(by_alias=True, exclude_none=True) + assert dumped["snapshot-id"] == 12345 + + +def test_request_with_select_and_filter() -> None: + request = PlanTableScanRequest( + select=["id", "name"], + filter=EqualTo(Reference("id"), 42), + ) + dumped = request.model_dump(by_alias=True, exclude_none=True) + assert dumped["select"] == ["id", "name"] + assert "filter" in dumped + + +def test_incremental_scan_request() -> None: + request = PlanTableScanRequest( + start_snapshot_id=100, + end_snapshot_id=200, + ) + dumped = request.model_dump(by_alias=True, exclude_none=True) + assert dumped["start-snapshot-id"] == 100 + assert dumped["end-snapshot-id"] == 200 + + +def test_start_snapshot_requires_end_snapshot() -> None: + with pytest.raises(ValidationError) as exc_info: + PlanTableScanRequest(start_snapshot_id=100) + assert "end-snapshot-id is required" in str(exc_info.value) + + +def test_snapshot_id_conflicts_with_start_snapshot() -> None: + with pytest.raises(ValidationError) as exc_info: + PlanTableScanRequest(snapshot_id=50, start_snapshot_id=100, end_snapshot_id=200) + assert "Cannot specify both" in str(exc_info.value) + + +def test_fetch_scan_tasks_request() -> None: + request = FetchScanTasksRequest(plan_task="token-abc-123") + dumped = request.model_dump(by_alias=True) + assert dumped == {"plan-task": "token-abc-123"} + + +def test_completed_response() -> None: + data = { + "status": "completed", + "plan-id": "plan-123", + "delete-files": [], + "file-scan-tasks": [], + "plan-tasks": [], + } + result = TypeAdapter(PlanningResponse).validate_python(data) + assert isinstance(result, PlanCompleted) + assert result.plan_id == "plan-123" + + +def test_completed_response_without_plan_id() -> None: + data = { + "status": "completed", + "delete-files": [], + "file-scan-tasks": [], + "plan-tasks": [], + } + result = TypeAdapter(PlanningResponse).validate_python(data) + assert isinstance(result, PlanCompleted) + assert result.plan_id is None + + +def test_completed_response_with_credentials() -> None: + data = { + "status": "completed", + "delete-files": [], + "file-scan-tasks": [], + "plan-tasks": [], + "storage-credentials": [ + {"prefix": "s3://bucket/", "config": {}}, + ], + } + result = TypeAdapter(PlanningResponse).validate_python(data) + assert isinstance(result, PlanCompleted) + assert result.storage_credentials is not None + assert len(result.storage_credentials) == 1 + + +def test_submitted_response() -> None: + data = { + "status": 
"submitted", + "plan-id": "drus-plan", + } + result = TypeAdapter(PlanningResponse).validate_python(data) + assert isinstance(result, PlanSubmitted) + assert result.plan_id == "drus-plan" + + +def test_submitted_response_without_plan_id() -> None: + data = {"status": "submitted"} + result = TypeAdapter(PlanningResponse).validate_python(data) + assert isinstance(result, PlanSubmitted) + + +def test_cancelled_response() -> None: + data = {"status": "cancelled"} + result = TypeAdapter(PlanningResponse).validate_python(data) + assert isinstance(result, PlanCancelled) + + +def test_storage_credential_parsing() -> None: + data = { + "prefix": "s3://bucket/path/", + "config": { + "s3.access-key-id": "key", + "s3.secret-access-key": "secret", + }, + } + cred = StorageCredential.model_validate(data) + assert cred.prefix == "s3://bucket/path/" + assert cred.config["s3.access-key-id"] == "key" From 3e35bfdaecbf733f77c5a181d9627689439dc5eb Mon Sep 17 00:00:00 2001 From: geruh Date: Wed, 24 Dec 2025 20:04:41 -0800 Subject: [PATCH 3/5] feat: Add support for rest scan planning --- pyiceberg/catalog/__init__.py | 7 + pyiceberg/catalog/rest/__init__.py | 116 ++++++- pyiceberg/catalog/rest/scan_planning.py | 8 +- pyiceberg/exceptions.py | 8 + pyiceberg/table/__init__.py | 113 ++++++- tests/catalog/test_scan_planning_models.py | 284 ++++++++++++++++-- .../test_rest_scan_planning_integration.py | 188 ++++++++++++ 7 files changed, 690 insertions(+), 34 deletions(-) create mode 100644 tests/integration/test_rest_scan_planning_integration.py diff --git a/pyiceberg/catalog/__init__.py b/pyiceberg/catalog/__init__.py index a4f1d47bea..7fc67a33ca 100644 --- a/pyiceberg/catalog/__init__.py +++ b/pyiceberg/catalog/__init__.py @@ -794,6 +794,13 @@ def close(self) -> None: # noqa: B027 Default implementation does nothing. Override in subclasses that need cleanup. """ + def is_rest_scan_planning_enabled(self) -> bool: + """Check if server-side scan planning is enabled. + + Returns False by default. + """ + return False + def __enter__(self) -> Catalog: """Enter the context manager. diff --git a/pyiceberg/catalog/rest/__init__.py b/pyiceberg/catalog/rest/__init__.py index f9a8d0a6fd..b42d51ba1d 100644 --- a/pyiceberg/catalog/rest/__init__.py +++ b/pyiceberg/catalog/rest/__init__.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from collections import deque +from collections.abc import Iterator from enum import Enum from typing import ( TYPE_CHECKING, @@ -21,7 +23,7 @@ Union, ) -from pydantic import ConfigDict, Field, field_validator +from pydantic import ConfigDict, Field, TypeAdapter, field_validator from requests import HTTPError, Session from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt @@ -36,6 +38,16 @@ ) from pyiceberg.catalog.rest.auth import AuthManager, AuthManagerAdapter, AuthManagerFactory, LegacyOAuth2AuthManager from pyiceberg.catalog.rest.response import _handle_non_200_response +from pyiceberg.catalog.rest.scan_planning import ( + FetchScanTasksRequest, + PlanCancelled, + PlanCompleted, + PlanFailed, + PlanningResponse, + PlanSubmitted, + PlanTableScanRequest, + ScanTasks, +) from pyiceberg.exceptions import ( AuthorizationExpiredError, CommitFailedException, @@ -44,6 +56,7 @@ NamespaceNotEmptyError, NoSuchIdentifierError, NoSuchNamespaceError, + NoSuchPlanTaskError, NoSuchTableError, NoSuchViewError, TableAlreadyExistsError, @@ -56,6 +69,7 @@ CommitTableRequest, CommitTableResponse, CreateTableTransaction, + FileScanTask, StagedTable, Table, TableIdentifier, @@ -322,6 +336,9 @@ class ListViewsResponse(IcebergBaseModel): identifiers: list[ListViewResponseEntry] = Field() +_PLANNING_RESPONSE_ADAPTER = TypeAdapter(PlanningResponse) + + class RestCatalog(Catalog): uri: str _session: Session @@ -391,6 +408,103 @@ def is_rest_scan_planning_enabled(self) -> bool: self.properties, REST_SCAN_PLANNING_ENABLED, REST_SCAN_PLANNING_ENABLED_DEFAULT ) + @retry(**_RETRY_ARGS) + def _plan_table_scan(self, identifier: str | Identifier, request: PlanTableScanRequest) -> PlanningResponse: + """Submit a scan plan request to the REST server. + + Args: + identifier: Table identifier. + request: The scan plan request parameters. + + Returns: + PlanningResponse the result of the scan plan request representing the status + Raises: + NoSuchTableError: If a table with the given identifier does not exist. + """ + self._check_endpoint(Capability.V1_SUBMIT_TABLE_SCAN_PLAN) + response = self._session.post( + self.url(Endpoints.plan_table_scan, prefixed=True, **self._split_identifier_for_path(identifier)), + json=request.model_dump(by_alias=True, exclude_none=True), + ) + try: + response.raise_for_status() + except HTTPError as exc: + _handle_non_200_response(exc, {404: NoSuchTableError}) + + return _PLANNING_RESPONSE_ADAPTER.validate_json(response.text) + + @retry(**_RETRY_ARGS) + def _fetch_scan_tasks(self, identifier: str | Identifier, plan_task: str) -> ScanTasks: + """Fetch additional scan tasks using a plan task token. + + Args: + identifier: Table identifier. + plan_task: The plan task token from a previous response. + + Returns: + ScanTasks containing file scan tasks and possibly more plan-task tokens. + + Raises: + NoSuchPlanTaskError: If a plan task with the given identifier or task does not exist. 
+ """ + self._check_endpoint(Capability.V1_TABLE_SCAN_PLAN_TASKS) + request = FetchScanTasksRequest(plan_task=plan_task) + response = self._session.post( + self.url(Endpoints.fetch_scan_tasks, prefixed=True, **self._split_identifier_for_path(identifier)), + json=request.model_dump(by_alias=True), + ) + try: + response.raise_for_status() + except HTTPError as exc: + _handle_non_200_response(exc, {404: NoSuchPlanTaskError}) + + return ScanTasks.model_validate_json(response.text) + + def plan_scan(self, identifier: str | Identifier, request: PlanTableScanRequest) -> Iterator["FileScanTask"]: + """Plan a table scan and yield FileScanTasks. + + Handles the full scan planning lifecycle including pagination. + Each response batch is self-contained, so tasks are yielded as received. + + Args: + identifier: Table identifier. + request: The scan plan request parameters. + + Yields: + FileScanTask objects ready for execution. + + Raises: + RuntimeError: If planning fails, is cancelled, or returns unexpected response. + NotImplementedError: If async planning is required but not yet supported. + """ + response = self._plan_table_scan(identifier, request) + + if isinstance(response, PlanFailed): + raise RuntimeError(f"Received status: failed: {response.error.message}") + + if isinstance(response, PlanCancelled): + raise RuntimeError("Received status: cancelled") + + if isinstance(response, PlanSubmitted): + # TODO: implement polling for async planning + raise NotImplementedError(f"Async scan planning not yet supported for planId: {response.plan_id}") + + if not isinstance(response, PlanCompleted): + raise RuntimeError(f"Invalid planStatus for response: {type(response).__name__}") + + # Yield tasks from initial response + for task in response.file_scan_tasks: + yield FileScanTask.from_rest_response(task, response.delete_files) + + # Fetch and yield from additional batches + pending_tasks = deque(response.plan_tasks) + while pending_tasks: + plan_task = pending_tasks.popleft() + batch = self._fetch_scan_tasks(identifier, plan_task) + for task in batch.file_scan_tasks: + yield FileScanTask.from_rest_response(task, batch.delete_files) + pending_tasks.extend(batch.plan_tasks) + def _create_legacy_oauth2_auth_manager(self, session: Session) -> AuthManager: """Create the LegacyOAuth2AuthManager by fetching required properties. 
diff --git a/pyiceberg/catalog/rest/scan_planning.py b/pyiceberg/catalog/rest/scan_planning.py index ddccf1d9e3..ec12e9a946 100644 --- a/pyiceberg/catalog/rest/scan_planning.py +++ b/pyiceberg/catalog/rest/scan_planning.py @@ -25,9 +25,15 @@ from pyiceberg.catalog.rest.response import ErrorResponseMessage from pyiceberg.expressions import BooleanExpression -from pyiceberg.manifest import FileFormat +from pyiceberg.manifest import DataFileContent, FileFormat from pyiceberg.typedef import IcebergBaseModel +CONTENT_TYPE_MAP: dict[str, DataFileContent] = { + "data": DataFileContent.DATA, + "position-deletes": DataFileContent.POSITION_DELETES, + "equality-deletes": DataFileContent.EQUALITY_DELETES, +} + # Primitive types that can appear in partition values and bounds PrimitiveTypeValue: TypeAlias = bool | int | float | str | Decimal | UUID | date | time | datetime | bytes diff --git a/pyiceberg/exceptions.py b/pyiceberg/exceptions.py index c80f104e46..d64fe8c2ba 100644 --- a/pyiceberg/exceptions.py +++ b/pyiceberg/exceptions.py @@ -52,6 +52,14 @@ class NoSuchNamespaceError(Exception): """Raised when a referenced name-space is not found.""" +class NoSuchPlanError(Exception): + """Raised when a scan plan ID is not found.""" + + +class NoSuchPlanTaskError(Exception): + """Raised when a scan plan task is not found.""" + + class RESTError(Exception): """Raises when there is an unknown response from the REST Catalog.""" diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 2e26a4ccc2..6a1226a9bd 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -145,6 +145,11 @@ from pyiceberg_core.datafusion import IcebergDataFusionTable from pyiceberg.catalog import Catalog + from pyiceberg.catalog.rest.scan_planning import ( + RESTContentFile, + RESTDeleteFile, + RESTFileScanTask, + ) ALWAYS_TRUE = AlwaysTrue() DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE = "downcast-ns-timestamp-to-us-on-write" @@ -1168,6 +1173,8 @@ def scan( snapshot_id=snapshot_id, options=options, limit=limit, + catalog=self.catalog, + table_identifier=self._identifier, ) @property @@ -1684,6 +1691,8 @@ class TableScan(ABC): snapshot_id: int | None options: Properties limit: int | None + catalog: Catalog | None + table_identifier: Identifier | None def __init__( self, @@ -1695,6 +1704,8 @@ def __init__( snapshot_id: int | None = None, options: Properties = EMPTY_DICT, limit: int | None = None, + catalog: Catalog | None = None, + table_identifier: Identifier | None = None, ): self.table_metadata = table_metadata self.io = io @@ -1704,6 +1715,8 @@ def __init__( self.snapshot_id = snapshot_id self.options = options self.limit = limit + self.catalog = catalog + self.table_identifier = table_identifier def snapshot(self) -> Snapshot | None: if self.snapshot_id: @@ -1798,6 +1811,67 @@ def __init__( self.delete_files = delete_files or set() self.residual = residual + @staticmethod + def from_rest_response( + rest_task: RESTFileScanTask, + delete_files: list[RESTDeleteFile], + ) -> FileScanTask: + """Convert a RESTFileScanTask to a FileScanTask. + + Args: + rest_task: The REST file scan task. + delete_files: The list of delete files from the ScanTasks response. + + Returns: + A FileScanTask with the converted data and delete files. + + Raises: + NotImplementedError: If equality delete files are encountered. 
+ """ + from pyiceberg.catalog.rest.scan_planning import RESTEqualityDeleteFile + + data_file = _rest_file_to_data_file(rest_task.data_file, include_stats=True) + + resolved_deletes: set[DataFile] = set() + if rest_task.delete_file_references: + for idx in rest_task.delete_file_references: + delete_file = delete_files[idx] + if isinstance(delete_file, RESTEqualityDeleteFile): + raise NotImplementedError(f"PyIceberg does not yet support equality deletes: {delete_file.file_path}") + resolved_deletes.add(_rest_file_to_data_file(delete_file, include_stats=False)) + + return FileScanTask( + data_file=data_file, + delete_files=resolved_deletes, + residual=rest_task.residual_filter if rest_task.residual_filter else ALWAYS_TRUE, + ) + + +def _rest_file_to_data_file(rest_file: RESTContentFile, *, include_stats: bool) -> DataFile: + """Convert a REST content file to a manifest DataFile.""" + from pyiceberg.catalog.rest.scan_planning import CONTENT_TYPE_MAP + + column_sizes = getattr(rest_file, "column_sizes", None) + value_counts = getattr(rest_file, "value_counts", None) + null_value_counts = getattr(rest_file, "null_value_counts", None) + nan_value_counts = getattr(rest_file, "nan_value_counts", None) + + return DataFile.from_args( + content=CONTENT_TYPE_MAP[rest_file.content], + file_path=rest_file.file_path, + file_format=rest_file.file_format, + partition=Record(*rest_file.partition), + record_count=rest_file.record_count, + file_size_in_bytes=rest_file.file_size_in_bytes, + column_sizes=column_sizes.to_dict() if include_stats and column_sizes else None, + value_counts=value_counts.to_dict() if include_stats and value_counts else None, + null_value_counts=null_value_counts.to_dict() if include_stats and null_value_counts else None, + nan_value_counts=nan_value_counts.to_dict() if include_stats and nan_value_counts else None, + split_offsets=rest_file.split_offsets, + sort_order_id=rest_file.sort_order_id, + spec_id=rest_file.spec_id, + ) + def _open_manifest( io: FileIO, @@ -1970,12 +2044,27 @@ def scan_plan_helper(self) -> Iterator[list[ManifestEntry]]: ], ) - def plan_files(self) -> Iterable[FileScanTask]: - """Plans the relevant files by filtering on the PartitionSpecs. + def _should_use_rest_planning(self) -> bool: + """Check if REST scan planning should be used for this scan.""" + if self.catalog is None: + return False + return self.catalog.is_rest_scan_planning_enabled() + + def _plan_files_rest(self) -> Iterable[FileScanTask]: + """Plan files using REST server-side scan planning.""" + from pyiceberg.catalog.rest.scan_planning import PlanTableScanRequest + + request = PlanTableScanRequest( + snapshot_id=self.snapshot_id, + select=list(self.selected_fields) if self.selected_fields != ("*",) else None, + filter=self.row_filter if self.row_filter != ALWAYS_TRUE else None, + case_sensitive=self.case_sensitive, + ) - Returns: - List of FileScanTasks that contain both data and delete files. - """ + return self.catalog.plan_scan(self.table_identifier, request) # type: ignore[union-attr] + + def _plan_files_local(self) -> Iterable[FileScanTask]: + """Plan files locally by reading manifests.""" data_entries: list[ManifestEntry] = [] positional_delete_entries = SortedList(key=lambda entry: entry.sequence_number or INITIAL_SEQUENCE_NUMBER) @@ -2006,6 +2095,20 @@ def plan_files(self) -> Iterable[FileScanTask]: for data_entry in data_entries ] + def plan_files(self) -> Iterable[FileScanTask]: + """Plans the relevant files by filtering on the PartitionSpecs. 
+ + If the table comes from a REST catalog with scan planning enabled, + this will use server-side scan planning. Otherwise, it falls back + to local planning by reading manifests. + + Returns: + List of FileScanTasks that contain both data and delete files. + """ + if self._should_use_rest_planning(): + return self._plan_files_rest() + return self._plan_files_local() + def to_arrow(self) -> pa.Table: """Read an Arrow table eagerly from this DataScan. diff --git a/tests/catalog/test_scan_planning_models.py b/tests/catalog/test_scan_planning_models.py index 9f03c8f7cd..7a33419665 100644 --- a/tests/catalog/test_scan_planning_models.py +++ b/tests/catalog/test_scan_planning_models.py @@ -18,7 +18,9 @@ import pytest from pydantic import TypeAdapter, ValidationError +from requests_mock import Mocker +from pyiceberg.catalog.rest import RestCatalog from pyiceberg.catalog.rest.scan_planning import ( CountMap, FetchScanTasksRequest, @@ -33,12 +35,13 @@ RESTFileScanTask, RESTPositionDeleteFile, ScanTasks, - StorageCredential, ValueMap, ) from pyiceberg.expressions import AlwaysTrue, EqualTo, Reference from pyiceberg.manifest import FileFormat +TEST_URI = "https://iceberg-test-catalog/" + def test_count_map_valid() -> None: cm = CountMap(keys=[1, 2, 3], values=[100, 200, 300]) @@ -399,22 +402,6 @@ def test_completed_response_without_plan_id() -> None: assert result.plan_id is None -def test_completed_response_with_credentials() -> None: - data = { - "status": "completed", - "delete-files": [], - "file-scan-tasks": [], - "plan-tasks": [], - "storage-credentials": [ - {"prefix": "s3://bucket/", "config": {}}, - ], - } - result = TypeAdapter(PlanningResponse).validate_python(data) - assert isinstance(result, PlanCompleted) - assert result.storage_credentials is not None - assert len(result.storage_credentials) == 1 - - def test_submitted_response() -> None: data = { "status": "submitted", @@ -437,14 +424,257 @@ def test_cancelled_response() -> None: assert isinstance(result, PlanCancelled) -def test_storage_credential_parsing() -> None: - data = { - "prefix": "s3://bucket/path/", - "config": { - "s3.access-key-id": "key", - "s3.secret-access-key": "secret", +@pytest.fixture +def rest_mock(requests_mock: Mocker) -> Mocker: + requests_mock.get( + f"{TEST_URI}v1/config", + json={ + "defaults": {}, + "overrides": {}, + "endpoints": [ + "POST /v1/{prefix}/namespaces/{namespace}/tables/{table}/plan", + "POST /v1/{prefix}/namespaces/{namespace}/tables/{table}/tasks", + ], }, - } - cred = StorageCredential.model_validate(data) - assert cred.prefix == "s3://bucket/path/" - assert cred.config["s3.access-key-id"] == "key" + status_code=200, + ) + return requests_mock + + +def _create_test_catalog() -> RestCatalog: + return RestCatalog( + "test", + uri=TEST_URI, + **{"rest-scan-planning-enabled": "true"}, + ) + + +def test_plan_scan_completed_single_batch(rest_mock: Mocker) -> None: + rest_mock.post( + f"{TEST_URI}v1/namespaces/db/tables/tbl/plan", + json={ + "status": "completed", + "plan-id": "plan-123", + "delete-files": [], + "file-scan-tasks": [ + { + "data-file": { + "spec-id": 0, + "content": "data", + "file-path": "s3://bucket/data/file1.parquet", + "file-format": "parquet", + "file-size-in-bytes": 1024, + "record-count": 67, + } + }, + { + "data-file": { + "spec-id": 0, + "content": "data", + "file-path": "s3://bucket/data/file2.parquet", + "file-format": "parquet", + "file-size-in-bytes": 2048, + "record-count": 200, + } + }, + ], + "plan-tasks": [], + }, + status_code=200, + ) + + catalog = 
_create_test_catalog() + request = PlanTableScanRequest() + + tasks = list(catalog.plan_scan(("db", "tbl"), request)) + + assert len(tasks) == 2 + assert tasks[0].file.file_path == "s3://bucket/data/file1.parquet" + assert tasks[1].file.file_path == "s3://bucket/data/file2.parquet" + + +def test_plan_scan_with_pagination(rest_mock: Mocker) -> None: + rest_mock.post( + f"{TEST_URI}v1/namespaces/db/tables/tbl/plan", + json={ + "status": "completed", + "plan-id": "plan-123", + "delete-files": [], + "file-scan-tasks": [ + { + "data-file": { + "spec-id": 0, + "content": "data", + "file-path": "s3://bucket/data/file1.parquet", + "file-format": "parquet", + "file-size-in-bytes": 1024, + "record-count": 100, + } + } + ], + "plan-tasks": ["token-batch-2"], + }, + status_code=200, + ) + + rest_mock.post( + f"{TEST_URI}v1/namespaces/db/tables/tbl/tasks", + json={ + "delete-files": [], + "file-scan-tasks": [ + { + "data-file": { + "spec-id": 0, + "content": "data", + "file-path": "s3://bucket/data/file2.parquet", + "file-format": "parquet", + "file-size-in-bytes": 2048, + "record-count": 200, + } + } + ], + "plan-tasks": [], + }, + status_code=200, + ) + + catalog = _create_test_catalog() + request = PlanTableScanRequest() + + tasks = list(catalog.plan_scan(("db", "tbl"), request)) + + assert len(tasks) == 2 + assert tasks[0].file.file_path == "s3://bucket/data/file1.parquet" + assert tasks[1].file.file_path == "s3://bucket/data/file2.parquet" + + +def test_plan_scan_with_delete_files(rest_mock: Mocker) -> None: + rest_mock.post( + f"{TEST_URI}v1/namespaces/db/tables/tbl/plan", + json={ + "status": "completed", + "plan-id": "plan-123", + "delete-files": [ + { + "spec-id": 0, + "content": "position-deletes", + "file-path": "s3://bucket/data/delete1.parquet", + "file-format": "parquet", + "file-size-in-bytes": 256, + "record-count": 10, + } + ], + "file-scan-tasks": [ + { + "data-file": { + "spec-id": 0, + "content": "data", + "file-path": "s3://bucket/data/file1.parquet", + "file-format": "parquet", + "file-size-in-bytes": 1024, + "record-count": 100, + }, + "delete-file-references": [0], + } + ], + "plan-tasks": [], + }, + status_code=200, + ) + + catalog = _create_test_catalog() + request = PlanTableScanRequest() + tasks = list(catalog.plan_scan(("db", "tbl"), request)) + + assert len(tasks) == 1 + assert tasks[0].file.file_path == "s3://bucket/data/file1.parquet" + assert len(tasks[0].delete_files) == 1 + + +def test_plan_scan_async_not_supported(rest_mock: Mocker) -> None: + rest_mock.post( + f"{TEST_URI}v1/namespaces/db/tables/tbl/plan", + json={ + "status": "submitted", + "plan-id": "plan-456", + }, + status_code=200, + ) + + catalog = _create_test_catalog() + request = PlanTableScanRequest() + with pytest.raises(NotImplementedError, match="Async scan planning not yet supported"): + list(catalog.plan_scan(("db", "tbl"), request)) + + +def test_plan_scan_empty_result(rest_mock: Mocker) -> None: + rest_mock.post( + f"{TEST_URI}v1/namespaces/db/tables/tbl/plan", + json={ + "status": "completed", + "plan-id": "plan-123", + "delete-files": [], + "file-scan-tasks": [], + "plan-tasks": [], + }, + status_code=200, + ) + + catalog = _create_test_catalog() + request = PlanTableScanRequest() + tasks = list(catalog.plan_scan(("db", "tbl"), request)) + assert len(tasks) == 0 + + +def test_plan_scan_cancelled(rest_mock: Mocker) -> None: + rest_mock.post( + f"{TEST_URI}v1/namespaces/db/tables/tbl/plan", + json={"status": "cancelled"}, + status_code=200, + ) + + catalog = _create_test_catalog() + request = 
PlanTableScanRequest() + with pytest.raises(RuntimeError, match="Scan planning was cancelled"): + list(catalog.plan_scan(("db", "tbl"), request)) + + +def test_plan_scan_equality_deletes_not_supported(rest_mock: Mocker) -> None: + rest_mock.post( + f"{TEST_URI}v1/namespaces/db/tables/tbl/plan", + json={ + "status": "completed", + "plan-id": "plan-123", + "delete-files": [ + { + "spec-id": 0, + "content": "equality-deletes", + "file-path": "s3://bucket/data/eq-delete.parquet", + "file-format": "parquet", + "file-size-in-bytes": 256, + "record-count": 5, + "equality-ids": [1, 2], + } + ], + "file-scan-tasks": [ + { + "data-file": { + "spec-id": 0, + "content": "data", + "file-path": "s3://bucket/data/file1.parquet", + "file-format": "parquet", + "file-size-in-bytes": 1024, + "record-count": 1000, + }, + "delete-file-references": [0], + } + ], + "plan-tasks": [], + }, + status_code=200, + ) + + catalog = _create_test_catalog() + request = PlanTableScanRequest() + with pytest.raises(NotImplementedError, match="PyIceberg does not yet support equality deletes"): + list(catalog.plan_scan(("db", "tbl"), request)) diff --git a/tests/integration/test_rest_scan_planning_integration.py b/tests/integration/test_rest_scan_planning_integration.py new file mode 100644 index 0000000000..be02643f23 --- /dev/null +++ b/tests/integration/test_rest_scan_planning_integration.py @@ -0,0 +1,188 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint:disable=redefined-outer-name +from typing import Any + +import pyarrow as pa +import pytest +from pyspark.sql import SparkSession + +from pyiceberg.catalog import Catalog, load_catalog +from pyiceberg.catalog.rest import RestCatalog +from pyiceberg.expressions import And, BooleanExpression, EqualTo, GreaterThan, LessThan +from pyiceberg.partitioning import PartitionField, PartitionSpec +from pyiceberg.schema import Schema +from pyiceberg.table import ALWAYS_TRUE, Table +from pyiceberg.transforms import IdentityTransform +from pyiceberg.types import LongType, NestedField, StringType + + +@pytest.fixture(scope="session") +def scan_catalog() -> Catalog: + return load_catalog( + "local", + **{ + "type": "rest", + "uri": "http://localhost:8181", + "s3.endpoint": "http://localhost:9000", + "s3.access-key-id": "admin", + "s3.secret-access-key": "password", + "rest-scan-planning-enabled": "true", + }, + ) + + +def recreate_table(catalog: Catalog, identifier: str, **kwargs: Any) -> Table: + """Drop table if exists and create a new one.""" + try: + catalog.drop_table(identifier) + except Exception: + pass + return catalog.create_table(identifier, **kwargs) + + +def _assert_remote_scan_matches_local_scan( + rest_table: Table, + session_catalog: Catalog, + identifier: str, + row_filter: BooleanExpression = ALWAYS_TRUE, +) -> None: + rest_tasks = list(rest_table.scan(row_filter=row_filter).plan_files()) + rest_paths = {task.file.file_path for task in rest_tasks} + + local_table = session_catalog.load_table(identifier) + local_tasks = list(local_table.scan(row_filter=row_filter).plan_files()) + local_paths = {task.file.file_path for task in local_tasks} + + assert rest_paths == local_paths + + +@pytest.mark.integration +def test_rest_scan_matches_local(scan_catalog: RestCatalog, session_catalog: Catalog) -> None: + identifier = "default.test_rest_scan" + + table = recreate_table( + scan_catalog, + identifier, + schema=Schema( + NestedField(1, "id", LongType()), + NestedField(2, "data", StringType()), + NestedField(3, "num", LongType()), + ), + ) + table.append(pa.Table.from_pydict({"id": [1, 2, 3], "data": ["a", "b", "c"], "num": [10, 20, 30]})) + table.append(pa.Table.from_pydict({"id": [4, 5, 6], "data": ["d", "e", "f"], "num": [40, 50, 60]})) + + try: + _assert_remote_scan_matches_local_scan(table, session_catalog, identifier) + finally: + scan_catalog.drop_table(identifier) + + +@pytest.mark.integration +def test_rest_scan_with_filter(scan_catalog: RestCatalog, session_catalog: Catalog) -> None: + identifier = "default.test_rest_scan_filter" + + table = recreate_table( + scan_catalog, + identifier, + schema=Schema( + NestedField(1, "id", LongType()), + NestedField(2, "data", LongType()), + ), + ) + table.append(pa.Table.from_pydict({"id": [1, 2, 3], "data": [10, 20, 30]})) + + try: + _assert_remote_scan_matches_local_scan( + table, + session_catalog, + identifier, + row_filter=And(GreaterThan("data", 5), LessThan("data", 25)), + ) + + _assert_remote_scan_matches_local_scan( + table, + session_catalog, + identifier, + row_filter=EqualTo("id", 1), + ) + finally: + scan_catalog.drop_table(identifier) + + +@pytest.mark.integration +def test_rest_scan_with_deletes(spark: SparkSession, scan_catalog: RestCatalog, session_catalog: Catalog) -> None: + identifier = "default.test_rest_scan_deletes" + + spark.sql(f"DROP TABLE IF EXISTS {identifier}") + spark.sql(f""" + CREATE TABLE {identifier} (id bigint, data bigint) + USING iceberg + TBLPROPERTIES( + 'format-version' = 2, + 
'write.delete.mode'='merge-on-read' + ) + """) + spark.sql(f"INSERT INTO {identifier} VALUES (1, 10), (2, 20), (3, 30)") + spark.sql(f"DELETE FROM {identifier} WHERE id = 2") + + try: + rest_table = scan_catalog.load_table(identifier) + rest_tasks = list(rest_table.scan().plan_files()) + rest_paths = {task.file.file_path for task in rest_tasks} + rest_delete_paths = {delete.file_path for task in rest_tasks for delete in task.delete_files} + + local_table = session_catalog.load_table(identifier) + local_tasks = list(local_table.scan().plan_files()) + local_paths = {task.file.file_path for task in local_tasks} + local_delete_paths = {delete.file_path for task in local_tasks for delete in task.delete_files} + + assert rest_paths == local_paths + assert rest_delete_paths == local_delete_paths + finally: + spark.sql(f"DROP TABLE IF EXISTS {identifier}") + + +@pytest.mark.integration +def test_rest_scan_with_partitioning(scan_catalog: RestCatalog, session_catalog: Catalog) -> None: + identifier = "default.test_rest_scan_partitioned" + + schema = Schema( + NestedField(1, "id", LongType()), + NestedField(2, "category", StringType()), + NestedField(3, "data", LongType()), + ) + partition_spec = PartitionSpec(PartitionField(2, 1000, IdentityTransform(), "category")) + + table = recreate_table(scan_catalog, identifier, schema=schema, partition_spec=partition_spec) + + table.append(pa.Table.from_pydict({"id": [1, 2], "category": ["a", "a"], "data": [10, 20]})) + table.append(pa.Table.from_pydict({"id": [3, 4], "category": ["b", "b"], "data": [30, 40]})) + + try: + _assert_remote_scan_matches_local_scan(table, session_catalog, identifier) + + # test filter against partition + _assert_remote_scan_matches_local_scan( + table, + session_catalog, + identifier, + row_filter=EqualTo("category", "a"), + ) + finally: + scan_catalog.drop_table(identifier) From 9c65059adecde6beb24a478deb5f716cb2f0d069 Mon Sep 17 00:00:00 2001 From: geruh Date: Wed, 24 Dec 2025 23:55:04 -0800 Subject: [PATCH 4/5] Add tests and fix models --- pyiceberg/catalog/rest/__init__.py | 4 +- pyiceberg/catalog/rest/scan_planning.py | 4 +- .../test_rest_scan_planning_integration.py | 169 +++++++++++++++++- 3 files changed, 168 insertions(+), 9 deletions(-) diff --git a/pyiceberg/catalog/rest/__init__.py b/pyiceberg/catalog/rest/__init__.py index b42d51ba1d..0272aac0db 100644 --- a/pyiceberg/catalog/rest/__init__.py +++ b/pyiceberg/catalog/rest/__init__.py @@ -424,7 +424,7 @@ def _plan_table_scan(self, identifier: str | Identifier, request: PlanTableScanR self._check_endpoint(Capability.V1_SUBMIT_TABLE_SCAN_PLAN) response = self._session.post( self.url(Endpoints.plan_table_scan, prefixed=True, **self._split_identifier_for_path(identifier)), - json=request.model_dump(by_alias=True, exclude_none=True), + data=request.model_dump_json(by_alias=True, exclude_none=True).encode(UTF8), ) try: response.raise_for_status() @@ -451,7 +451,7 @@ def _fetch_scan_tasks(self, identifier: str | Identifier, plan_task: str) -> Sca request = FetchScanTasksRequest(plan_task=plan_task) response = self._session.post( self.url(Endpoints.fetch_scan_tasks, prefixed=True, **self._split_identifier_for_path(identifier)), - json=request.model_dump(by_alias=True), + data=request.model_dump_json(by_alias=True).encode(UTF8), ) try: response.raise_for_status() diff --git a/pyiceberg/catalog/rest/scan_planning.py b/pyiceberg/catalog/rest/scan_planning.py index ec12e9a946..3f74ee068f 100644 --- a/pyiceberg/catalog/rest/scan_planning.py +++ 
b/pyiceberg/catalog/rest/scan_planning.py @@ -24,7 +24,7 @@ from pydantic import Field, model_validator from pyiceberg.catalog.rest.response import ErrorResponseMessage -from pyiceberg.expressions import BooleanExpression +from pyiceberg.expressions import BooleanExpression, SerializableBooleanExpression from pyiceberg.manifest import DataFileContent, FileFormat from pyiceberg.typedef import IcebergBaseModel @@ -192,7 +192,7 @@ class PlanTableScanRequest(IcebergBaseModel): snapshot_id: int | None = Field(alias="snapshot-id", default=None) select: list[str] | None = Field(default=None) - filter: BooleanExpression | None = Field(default=None) + filter: SerializableBooleanExpression | None = Field(default=None) case_sensitive: bool = Field(alias="case-sensitive", default=True) use_snapshot_schema: bool = Field(alias="use-snapshot-schema", default=False) start_snapshot_id: int | None = Field(alias="start-snapshot-id", default=None) diff --git a/tests/integration/test_rest_scan_planning_integration.py b/tests/integration/test_rest_scan_planning_integration.py index be02643f23..0a67aff989 100644 --- a/tests/integration/test_rest_scan_planning_integration.py +++ b/tests/integration/test_rest_scan_planning_integration.py @@ -15,7 +15,10 @@ # specific language governing permissions and limitations # under the License. # pylint:disable=redefined-outer-name +from datetime import date, datetime, time, timedelta, timezone +from decimal import Decimal from typing import Any +from uuid import uuid4 import pyarrow as pa import pytest @@ -23,17 +26,50 @@ from pyiceberg.catalog import Catalog, load_catalog from pyiceberg.catalog.rest import RestCatalog -from pyiceberg.expressions import And, BooleanExpression, EqualTo, GreaterThan, LessThan +from pyiceberg.exceptions import NoSuchTableError +from pyiceberg.expressions import ( + And, + BooleanExpression, + EqualTo, + GreaterThan, + GreaterThanOrEqual, + In, + IsNull, + LessThan, + LessThanOrEqual, + Not, + NotEqualTo, + NotIn, + NotNull, + Or, + StartsWith, +) from pyiceberg.partitioning import PartitionField, PartitionSpec from pyiceberg.schema import Schema from pyiceberg.table import ALWAYS_TRUE, Table -from pyiceberg.transforms import IdentityTransform -from pyiceberg.types import LongType, NestedField, StringType +from pyiceberg.transforms import ( + IdentityTransform, +) +from pyiceberg.types import ( + BinaryType, + BooleanType, + DateType, + DecimalType, + DoubleType, + FixedType, + LongType, + NestedField, + StringType, + TimestampType, + TimestamptzType, + TimeType, + UUIDType, +) @pytest.fixture(scope="session") def scan_catalog() -> Catalog: - return load_catalog( + catalog = load_catalog( "local", **{ "type": "rest", @@ -44,13 +80,15 @@ def scan_catalog() -> Catalog: "rest-scan-planning-enabled": "true", }, ) + catalog.create_namespace_if_not_exists("default") + return catalog def recreate_table(catalog: Catalog, identifier: str, **kwargs: Any) -> Table: """Drop table if exists and create a new one.""" try: catalog.drop_table(identifier) - except Exception: + except NoSuchTableError: pass return catalog.create_table(identifier, **kwargs) @@ -186,3 +224,124 @@ def test_rest_scan_with_partitioning(scan_catalog: RestCatalog, session_catalog: ) finally: scan_catalog.drop_table(identifier) + + +@pytest.mark.integration +def test_rest_scan_primitive_types(scan_catalog: RestCatalog, session_catalog: Catalog) -> None: + identifier = "default.test_primitives" + + schema = Schema( + NestedField(1, "bool_col", BooleanType()), + NestedField(2, "long_col", 
LongType()), + NestedField(3, "double_col", DoubleType()), + NestedField(4, "decimal_col", DecimalType(10, 2)), + NestedField(5, "string_col", StringType()), + NestedField(6, "date_col", DateType()), + NestedField(7, "time_col", TimeType()), + NestedField(8, "timestamp_col", TimestampType()), + NestedField(9, "timestamptz_col", TimestamptzType()), + NestedField(10, "uuid_col", UUIDType()), + NestedField(11, "fixed_col", FixedType(16)), + NestedField(12, "binary_col", BinaryType()), + ) + + table = recreate_table(scan_catalog, identifier, schema=schema) + + now = datetime.now() + now_tz = datetime.now(tz=timezone.utc) + today = date.today() + uuid1, uuid2, uuid3 = uuid4(), uuid4(), uuid4() + + arrow_table = pa.Table.from_pydict( + { + "bool_col": [True, False, True], + "long_col": [100, 200, 300], + "double_col": [1.11, 2.22, 3.33], + "decimal_col": [Decimal("1.23"), Decimal("4.56"), Decimal("7.89")], + "string_col": ["a", "b", "c"], + "date_col": [today, today - timedelta(days=1), today - timedelta(days=2)], + "time_col": [time(8, 30, 0), time(12, 0, 0), time(18, 45, 30)], + "timestamp_col": [now, now - timedelta(hours=1), now - timedelta(hours=2)], + "timestamptz_col": [now_tz, now_tz - timedelta(hours=1), now_tz - timedelta(hours=2)], + "uuid_col": [uuid1.bytes, uuid2.bytes, uuid3.bytes], + "fixed_col": [b"0123456789abcdef", b"abcdef0123456789", b"fedcba9876543210"], + "binary_col": [b"hello", b"world", b"test"], + }, + schema=schema.as_arrow(), + ) + table.append(arrow_table) + + try: + _assert_remote_scan_matches_local_scan(table, session_catalog, identifier) + finally: + scan_catalog.drop_table(identifier) + + +@pytest.mark.integration +def test_rest_scan_complex_filters(scan_catalog: RestCatalog, session_catalog: Catalog) -> None: + identifier = "default.test_complex_filters" + + schema = Schema( + NestedField(1, "id", LongType()), + NestedField(2, "name", StringType()), + NestedField(3, "value", LongType()), + NestedField(4, "optional", StringType(), required=False), + ) + + table = recreate_table(scan_catalog, identifier, schema=schema) + + table.append( + pa.Table.from_pydict( + { + "id": list(range(1, 21)), + "name": [f"item_{i}" for i in range(1, 21)], + "value": [i * 100 for i in range(1, 21)], + "optional": [None if i % 3 == 0 else f"opt_{i}" for i in range(1, 21)], + } + ) + ) + + try: + filters = [ + EqualTo("id", 10), + NotEqualTo("id", 10), + GreaterThan("value", 1000), + GreaterThanOrEqual("value", 1000), + LessThan("value", 500), + LessThanOrEqual("value", 500), + In("id", [1, 5, 10, 15]), + NotIn("id", [1, 5, 10, 15]), + IsNull("optional"), + NotNull("optional"), + StartsWith("name", "item_1"), + And(GreaterThan("id", 5), LessThan("id", 15)), + Or(EqualTo("id", 1), EqualTo("id", 20)), + Not(EqualTo("id", 10)), + ] + + for filter_expr in filters: + _assert_remote_scan_matches_local_scan(table, session_catalog, identifier, row_filter=filter_expr) + finally: + scan_catalog.drop_table(identifier) + + +@pytest.mark.integration +def test_rest_scan_empty_table(scan_catalog: RestCatalog, session_catalog: Catalog) -> None: + identifier = "default.test_empty_table" + + schema = Schema( + NestedField(1, "id", LongType()), + NestedField(2, "data", StringType()), + ) + + table = recreate_table(scan_catalog, identifier, schema=schema) + + try: + rest_tasks = list(table.scan().plan_files()) + local_table = session_catalog.load_table(identifier) + local_tasks = list(local_table.scan().plan_files()) + + assert len(rest_tasks) == 0 + assert len(local_tasks) == 0 + finally: + 
scan_catalog.drop_table(identifier)

From 41844610cf0e9403739fcb9b838f95eba54c9d0e Mon Sep 17 00:00:00 2001
From: geruh
Date: Thu, 25 Dec 2025 01:02:49 -0800
Subject: [PATCH 5/5] fix test to align with Java implementation

---
 tests/catalog/test_scan_planning_models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/catalog/test_scan_planning_models.py b/tests/catalog/test_scan_planning_models.py
index 7a33419665..903d688fdf 100644
--- a/tests/catalog/test_scan_planning_models.py
+++ b/tests/catalog/test_scan_planning_models.py
@@ -635,7 +635,7 @@ def test_plan_scan_cancelled(rest_mock: Mocker) -> None:
 
     catalog = _create_test_catalog()
     request = PlanTableScanRequest()
-    with pytest.raises(RuntimeError, match="Scan planning was cancelled"):
+    with pytest.raises(RuntimeError, match="Received status: cancelled"):
         list(catalog.plan_scan(("db", "tbl"), request))
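For completeness, a sketch of the reader-facing flow this series enables (illustrative, not part of the patches): with rest-scan-planning-enabled set, DataScan.plan_files() delegates planning to the server; leaving the property unset keeps the existing local manifest-based planning. The URI, namespace, and table name below mirror the integration test setup and are placeholders.

    from pyiceberg.catalog import load_catalog
    from pyiceberg.expressions import GreaterThan

    catalog = load_catalog(
        "local",
        **{
            "type": "rest",
            "uri": "http://localhost:8181",
            # Omit (or set to "false") to keep local, manifest-based planning.
            "rest-scan-planning-enabled": "true",
        },
    )

    table = catalog.load_table("default.test_rest_scan")
    scan = table.scan(row_filter=GreaterThan("data", 5), selected_fields=("id", "data"))

    # plan_files() checks catalog.is_rest_scan_planning_enabled(); when true it calls
    # catalog.plan_scan(), otherwise it falls back to _plan_files_local(), so readers
    # such as to_arrow()/to_pandas() work the same either way.
    for task in scan.plan_files():
        print(task.file.file_path)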