diff --git a/src/PowerPlatform/Dataverse/client.py b/src/PowerPlatform/Dataverse/client.py index 2fb11e8..e59ad26 100644 --- a/src/PowerPlatform/Dataverse/client.py +++ b/src/PowerPlatform/Dataverse/client.py @@ -3,7 +3,8 @@ from __future__ import annotations -from typing import Any, Dict, Optional, Union, List, Iterable +from typing import Any, Dict, Optional, Union, List, Iterable, Iterator +from contextlib import contextmanager from azure.core.credentials import TokenCredential @@ -99,6 +100,13 @@ def _get_odata(self) -> _ODataClient: ) return self._odata + @contextmanager + def _scoped_odata(self) -> Iterator[_ODataClient]: + """Yield the low-level client while ensuring a correlation scope is active.""" + od = self._get_odata() + with od._call_scope(): + yield od + # ---------------- Unified CRUD: create/update/delete ---------------- def create(self, table_schema_name: str, records: Union[Dict[str, Any], List[Dict[str, Any]]]) -> List[str]: """ @@ -132,19 +140,19 @@ def create(self, table_schema_name: str, records: Union[Dict[str, Any], List[Dic ids = client.create("account", records) print(f"Created {len(ids)} accounts") """ - od = self._get_odata() - entity_set = od._entity_set_from_schema_name(table_schema_name) - if isinstance(records, dict): - rid = od._create(entity_set, table_schema_name, records) - # _create returns str on single input - if not isinstance(rid, str): - raise TypeError("_create (single) did not return GUID string") - return [rid] - if isinstance(records, list): - ids = od._create_multiple(entity_set, table_schema_name, records) - if not isinstance(ids, list) or not all(isinstance(x, str) for x in ids): - raise TypeError("_create (multi) did not return list[str]") - return ids + with self._scoped_odata() as od: + entity_set = od._entity_set_from_schema_name(table_schema_name) + if isinstance(records, dict): + rid = od._create(entity_set, table_schema_name, records) + # _create returns str on single input + if not isinstance(rid, str): + 
raise TypeError("_create (single) did not return GUID string") + return [rid] + if isinstance(records, list): + ids = od._create_multiple(entity_set, table_schema_name, records) + if not isinstance(ids, list) or not all(isinstance(x, str) for x in ids): + raise TypeError("_create (multi) did not return list[str]") + return ids raise TypeError("records must be dict or list[dict]") def update( @@ -192,16 +200,16 @@ def update( ] client.update("account", ids, changes) """ - od = self._get_odata() - if isinstance(ids, str): - if not isinstance(changes, dict): - raise TypeError("For single id, changes must be a dict") - od._update(table_schema_name, ids, changes) # discard representation + with self._scoped_odata() as od: + if isinstance(ids, str): + if not isinstance(changes, dict): + raise TypeError("For single id, changes must be a dict") + od._update(table_schema_name, ids, changes) # discard representation + return None + if not isinstance(ids, list): + raise TypeError("ids must be str or list[str]") + od._update_by_ids(table_schema_name, ids, changes) return None - if not isinstance(ids, list): - raise TypeError("ids must be str or list[str]") - od._update_by_ids(table_schema_name, ids, changes) - return None def delete( self, @@ -235,21 +243,21 @@ def delete( job_id = client.delete("account", [id1, id2, id3]) """ - od = self._get_odata() - if isinstance(ids, str): - od._delete(table_schema_name, ids) - return None - if not isinstance(ids, list): - raise TypeError("ids must be str or list[str]") - if not ids: + with self._scoped_odata() as od: + if isinstance(ids, str): + od._delete(table_schema_name, ids) + return None + if not isinstance(ids, list): + raise TypeError("ids must be str or list[str]") + if not ids: + return None + if not all(isinstance(rid, str) for rid in ids): + raise TypeError("ids must contain string GUIDs") + if use_bulk_delete: + return od._delete_multiple(table_schema_name, ids) + for rid in ids: + od._delete(table_schema_name, rid) return 
None - if not all(isinstance(rid, str) for rid in ids): - raise TypeError("ids must contain string GUIDs") - if use_bulk_delete: - return od._delete_multiple(table_schema_name, ids) - for rid in ids: - od._delete(table_schema_name, rid) - return None def get( self, @@ -328,24 +336,29 @@ def get( ): print(f"Batch size: {len(batch)}") """ - od = self._get_odata() if record_id is not None: if not isinstance(record_id, str): raise TypeError("record_id must be str") - return od._get( - table_schema_name, - record_id, - select=select, - ) - return od._get_multiple( - table_schema_name, - select=select, - filter=filter, - orderby=orderby, - top=top, - expand=expand, - page_size=page_size, - ) + with self._scoped_odata() as od: + return od._get( + table_schema_name, + record_id, + select=select, + ) + + def _paged() -> Iterable[List[Dict[str, Any]]]: + with self._scoped_odata() as od: + yield from od._get_multiple( + table_schema_name, + select=select, + filter=filter, + orderby=orderby, + top=top, + expand=expand, + page_size=page_size, + ) + + return _paged() # SQL via Web API sql parameter def query_sql(self, sql: str): @@ -381,7 +394,8 @@ def query_sql(self, sql: str): sql = "SELECT a.name, a.telephone1 FROM account AS a WHERE a.statecode = 0" results = client.query_sql(sql) """ - return self._get_odata()._query_sql(sql) + with self._scoped_odata() as od: + return od._query_sql(sql) # Table metadata helpers def get_table_info(self, table_schema_name: str) -> Optional[Dict[str, Any]]: @@ -404,7 +418,8 @@ def get_table_info(self, table_schema_name: str) -> Optional[Dict[str, Any]]: print(f"Logical name: {info['table_logical_name']}") print(f"Entity set: {info['entity_set_name']}") """ - return self._get_odata()._get_table_info(table_schema_name) + with self._scoped_odata() as od: + return od._get_table_info(table_schema_name) def create_table( self, @@ -474,12 +489,13 @@ class ItemStatus(IntEnum): primary_column_schema_name="new_ProductName" ) """ - return 
self._get_odata()._create_table( - table_schema_name, - columns, - solution_unique_name, - primary_column_schema_name, - ) + with self._scoped_odata() as od: + return od._create_table( + table_schema_name, + columns, + solution_unique_name, + primary_column_schema_name, + ) def delete_table(self, table_schema_name: str) -> None: """ @@ -499,7 +515,8 @@ def delete_table(self, table_schema_name: str) -> None: client.delete_table("new_MyTestTable") """ - self._get_odata()._delete_table(table_schema_name) + with self._scoped_odata() as od: + od._delete_table(table_schema_name) def list_tables(self) -> list[str]: """ @@ -515,7 +532,8 @@ def list_tables(self) -> list[str]: for table in tables: print(table) """ - return self._get_odata()._list_tables() + with self._scoped_odata() as od: + return od._list_tables() def create_columns( self, @@ -545,10 +563,11 @@ def create_columns( ) print(created) # ['new_Scratch', 'new_Flags'] """ - return self._get_odata()._create_columns( - table_schema_name, - columns, - ) + with self._scoped_odata() as od: + return od._create_columns( + table_schema_name, + columns, + ) def delete_columns( self, @@ -573,10 +592,11 @@ def delete_columns( ) print(removed) # ['new_Scratch', 'new_Flags'] """ - return self._get_odata()._delete_columns( - table_schema_name, - columns, - ) + with self._scoped_odata() as od: + return od._delete_columns( + table_schema_name, + columns, + ) # File upload def upload_file( @@ -640,18 +660,18 @@ def upload_file( mode="auto" ) """ - od = self._get_odata() - entity_set = od._entity_set_from_schema_name(table_schema_name) - od._upload_file( - entity_set, - record_id, - file_name_attribute, - path, - mode=mode, - mime_type=mime_type, - if_none_match=if_none_match, - ) - return None + with self._scoped_odata() as od: + entity_set = od._entity_set_from_schema_name(table_schema_name) + od._upload_file( + entity_set, + record_id, + file_name_attribute, + path, + mode=mode, + mime_type=mime_type, + 
if_none_match=if_none_match, + ) + return None # Cache utilities def flush_cache(self, kind) -> int: @@ -675,7 +695,8 @@ def flush_cache(self, kind) -> int: removed = client.flush_cache("picklist") print(f"Cleared {removed} cached picklist entries") """ - return self._get_odata()._flush_cache(kind) + with self._scoped_odata() as od: + return od._flush_cache(kind) __all__ = ["DataverseClient"] diff --git a/src/PowerPlatform/Dataverse/core/errors.py b/src/PowerPlatform/Dataverse/core/errors.py index 4fd4adb..9ade2fd 100644 --- a/src/PowerPlatform/Dataverse/core/errors.py +++ b/src/PowerPlatform/Dataverse/core/errors.py @@ -141,10 +141,12 @@ class HttpError(DataverseError): :type subcode: :class:`str` | None :param service_error_code: Optional Dataverse-specific error code from the API response. :type service_error_code: :class:`str` | None - :param correlation_id: Optional correlation ID for tracking requests across services. + :param correlation_id: Optional client-generated correlation ID for tracking requests within an SDK call. :type correlation_id: :class:`str` | None - :param request_id: Optional request ID from the API response headers. - :type request_id: :class:`str` | None + :param client_request_id: Optional client-generated request ID injected into outbound headers. + :type client_request_id: :class:`str` | None + :param service_request_id: Optional ``x-ms-service-request-id`` value returned by Dataverse servers. + :type service_request_id: :class:`str` | None :param traceparent: Optional W3C trace context for distributed tracing. :type traceparent: :class:`str` | None :param body_excerpt: Optional excerpt of the response body for diagnostics. 
@@ -163,7 +165,8 @@ def __init__( subcode: Optional[str] = None, service_error_code: Optional[str] = None, correlation_id: Optional[str] = None, - request_id: Optional[str] = None, + client_request_id: Optional[str] = None, + service_request_id: Optional[str] = None, traceparent: Optional[str] = None, body_excerpt: Optional[str] = None, retry_after: Optional[int] = None, @@ -174,8 +177,10 @@ def __init__( d["service_error_code"] = service_error_code if correlation_id is not None: d["correlation_id"] = correlation_id - if request_id is not None: - d["request_id"] = request_id + if client_request_id is not None: + d["client_request_id"] = client_request_id + if service_request_id is not None: + d["service_request_id"] = service_request_id if traceparent is not None: d["traceparent"] = traceparent if body_excerpt is not None: diff --git a/src/PowerPlatform/Dataverse/data/_odata.py b/src/PowerPlatform/Dataverse/data/_odata.py index 8eda7ad..ea31892 100644 --- a/src/PowerPlatform/Dataverse/data/_odata.py +++ b/src/PowerPlatform/Dataverse/data/_odata.py @@ -5,14 +5,18 @@ from __future__ import annotations -from typing import Any, Dict, Optional, List, Union, Iterable +from typing import Any, Dict, Optional, List, Union, Iterable, Callable from enum import Enum +from dataclasses import dataclass, field import unicodedata import time import re import json +import uuid from datetime import datetime, timezone import importlib.resources as ir +from contextlib import contextmanager +from contextvars import ContextVar from ..core._http import _HttpClient from ._upload import _ODataFileUpload @@ -35,6 +39,42 @@ _USER_AGENT = f"DataverseSvcPythonClient:{_SDK_VERSION}" _GUID_RE = re.compile(r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}") +_CALL_SCOPE_CORRELATION_ID: ContextVar[Optional[str]] = ContextVar("_CALL_SCOPE_CORRELATION_ID", default=None) +_DEFAULT_EXPECTED_STATUSES: tuple[int, ...] 
= (200, 201, 202, 204)
+
+
+@dataclass
+class _RequestContext:
+    """Structured request context used by ``_request`` to clarify payload and metadata."""
+
+    method: str
+    url: str
+    expected: tuple[int, ...] = _DEFAULT_EXPECTED_STATUSES
+    headers: Optional[Dict[str, str]] = None
+    kwargs: Dict[str, Any] = field(default_factory=dict)
+
+    @classmethod
+    def build(
+        cls,
+        method: str,
+        url: str,
+        *,
+        expected: tuple[int, ...] = _DEFAULT_EXPECTED_STATUSES,
+        merge_headers: Optional[Callable[[Optional[Dict[str, str]]], Dict[str, str]]] = None,
+        **kwargs: Any,
+    ) -> "_RequestContext":
+        headers = merge_headers(kwargs.get("headers")) if merge_headers else dict(kwargs.get("headers") or {})
+        headers.setdefault("x-ms-client-request-id", str(uuid.uuid4()))  # fresh id for every outbound request
+        if _CALL_SCOPE_CORRELATION_ID.get() is not None:  # outside a call scope, never send a None header
+            headers.setdefault("x-ms-correlation-id", _CALL_SCOPE_CORRELATION_ID.get())
+        kwargs["headers"] = headers
+        return cls(
+            method=method,
+            url=url,
+            expected=expected,
+            headers=headers,
+            kwargs=kwargs,
+        )
 
 
 class _ODataClient(_ODataFileUpload):
@@ -113,6 +153,16 @@ def __init__(
         self._picklist_label_cache = {}
         self._picklist_cache_ttl_seconds = 3600  # 1 hour TTL
 
+    @contextmanager
+    def _call_scope(self):
+        """Context manager to generate a new correlation id for each SDK call scope."""
+        shared_id = str(uuid.uuid4())
+        token = _CALL_SCOPE_CORRELATION_ID.set(shared_id)
+        try:
+            yield shared_id
+        finally:
+            _CALL_SCOPE_CORRELATION_ID.reset(token)
+
     def _headers(self) -> Dict[str, str]:
         """Build standard OData headers with bearer auth."""
         scope = f"{self.base_url}/.default"
@@ -137,13 +187,19 @@ def _merge_headers(self, headers: Optional[Dict[str, str]] = None) -> Dict[str,
     def _raw_request(self, method: str, url: str, **kwargs):
         return self._http._request(method, url, **kwargs)
 
-    def _request(self, method: str, url: str, *, expected: tuple[int, ...] 
= (200, 201, 202, 204), **kwargs): - headers_in = kwargs.pop("headers", None) - kwargs["headers"] = self._merge_headers(headers_in) - r = self._raw_request(method, url, **kwargs) - if r.status_code in expected: + def _request(self, method: str, url: str, *, expected: tuple[int, ...] = _DEFAULT_EXPECTED_STATUSES, **kwargs): + request_context = _RequestContext.build( + method, + url, + expected=expected, + merge_headers=self._merge_headers, + **kwargs, + ) + + r = self._raw_request(request_context.method, request_context.url, **request_context.kwargs) + if r.status_code in request_context.expected: return r - headers = getattr(r, "headers", {}) or {} + response_headers = getattr(r, "headers", {}) or {} body_excerpt = (getattr(r, "text", "") or "")[:200] svc_code = None msg = f"HTTP {r.status_code}" @@ -164,12 +220,13 @@ def _request(self, method: str, url: str, *, expected: tuple[int, ...] = (200, 2 pass sc = r.status_code subcode = _http_subcode(sc) - correlation_id = headers.get("x-ms-correlation-request-id") or headers.get("x-ms-correlation-id") request_id = ( - headers.get("x-ms-client-request-id") or headers.get("request-id") or headers.get("x-ms-request-id") + response_headers.get("x-ms-service-request-id") + or response_headers.get("req_id") + or response_headers.get("x-ms-request-id") ) - traceparent = headers.get("traceparent") - ra = headers.get("Retry-After") + traceparent = response_headers.get("traceparent") + ra = response_headers.get("Retry-After") retry_after = None if ra: try: @@ -182,8 +239,13 @@ def _request(self, method: str, url: str, *, expected: tuple[int, ...] 
= (200, 2
             status_code=sc,
             subcode=subcode,
             service_error_code=svc_code,
-            correlation_id=correlation_id,
-            request_id=request_id,
+            correlation_id=request_context.headers.get(
+                "x-ms-correlation-id"
+            ),  # this is a value set on client side, although it's logged on server side too
+            client_request_id=request_context.headers.get(
+                "x-ms-client-request-id"
+            ),  # this is a value set on client side, although it's logged on server side too
+            service_request_id=request_id,
             traceparent=traceparent,
             body_excerpt=body_excerpt,
             retry_after=retry_after,
diff --git a/tests/unit/core/test_http_errors.py b/tests/unit/core/test_http_errors.py
index 137aea5..729ebae 100644
--- a/tests/unit/core/test_http_errors.py
+++ b/tests/unit/core/test_http_errors.py
@@ -2,6 +2,9 @@
 # Licensed under the MIT license.
 
 import pytest
+from azure.core.credentials import TokenCredential
+from PowerPlatform.Dataverse.client import DataverseClient
+from PowerPlatform.Dataverse.core.config import DataverseConfig
 from PowerPlatform.Dataverse.core.errors import HttpError
 from PowerPlatform.Dataverse.core._error_codes import HTTP_404, HTTP_429, HTTP_500
 from PowerPlatform.Dataverse.data._odata import _ODataClient
@@ -55,6 +58,27 @@ def __init__(self, responses):
         self._http = DummyHTTP(responses)
 
 
+class RecordingHTTP(DummyHTTP):
+    """DummyHTTP variant that records a copy of each request's headers before delegating."""
+    def __init__(self, responses):
+        super().__init__(responses)
+        self.recorded_headers = []
+
+    def _request(self, method, url, **kwargs):
+        headers = (kwargs.get("headers") or {}).copy()
+        self.recorded_headers.append(headers)
+        return super()._request(method, url, **kwargs)
+
+
+class DummyCredential(TokenCredential):
+    """Credential stub whose ``get_token`` returns an object carrying a fixed token string."""
+    def get_token(self, *scopes, **kwargs):
+        class Tok:
+            token = "dummy-token"
+
+        return Tok()
+
+
 # --- Tests ---
 
 
@@ -120,3 +144,43 @@ def test_http_non_mapped_status_code_subcode_fallback():
         c._request("get", c.api + "/accounts")
     err = ei.value.to_dict()
     assert err["subcode"] == "http_418"
+
+
+def test_correlation_id_diff_without_scope():
+    """Outside a call scope: client request ids differ and no correlation id is shared."""
+    responses = [
+        (200, {}, {"value": []}),
+        (200, {}, {"value": []}),
+    ]
+    c = MockClient([])
+    recorder = RecordingHTTP(responses)
+    c._http = recorder
+    c._request("get", c.api + "/accounts")
+    c._request("get", c.api + "/accounts")
+    assert len(recorder.recorded_headers) == 2
+    h1, h2 = recorder.recorded_headers
+    assert h1["x-ms-client-request-id"] != h2["x-ms-client-request-id"]
+    cid1 = h1.get("x-ms-correlation-id")  # the header name the client actually injects
+    cid2 = h2.get("x-ms-correlation-id")
+    if cid1 is not None and cid2 is not None:
+        assert cid1 != cid2
+    else:
+        assert cid1 is cid2 is None
+
+
+def test_correlation_id_shared_inside_call_scope():
+    """Inside one call scope: client request ids differ but the correlation id is the same."""
+    responses = [
+        (200, {}, {"value": []}),
+        (200, {}, {"value": []}),
+    ]
+    c = MockClient([])
+    recorder = RecordingHTTP(responses)
+    c._http = recorder
+    with c._call_scope():
+        c._request("get", c.api + "/accounts")
+        c._request("get", c.api + "/accounts")
+    assert len(recorder.recorded_headers) == 2
+    h1, h2 = recorder.recorded_headers
+    assert h1["x-ms-client-request-id"] != h2["x-ms-client-request-id"]
+    assert h1["x-ms-correlation-id"] == h2["x-ms-correlation-id"]