diff --git a/infrahub_sdk/spec/object.py b/infrahub_sdk/spec/object.py index 0df3f95c..548b5fc3 100644 --- a/infrahub_sdk/spec/object.py +++ b/infrahub_sdk/spec/object.py @@ -7,6 +7,7 @@ from ..exceptions import ObjectValidationError, ValidationError from ..schema import GenericSchemaAPI, RelationshipKind, RelationshipSchema +from ..utils import is_valid_uuid from ..yaml import InfrahubFile, InfrahubFileKind from .models import InfrahubObjectParameters from .processors.factory import DataProcessorFactory @@ -33,6 +34,32 @@ def validate_list_of_objects(value: list[Any]) -> bool: return all(isinstance(item, dict) for item in value) +def normalize_hfid_reference(value: str | list[str]) -> str | list[str]: + """Normalize a reference value to HFID format. + + Args: + value: Either a string (ID or single-component HFID) or a list of strings (multi-component HFID). + + Returns: + - If value is already a list: returns it unchanged as list[str] + - If value is a valid UUID string: returns it unchanged as str (will be treated as an ID) + - If value is a non-UUID string: wraps it in a list as list[str] (single-component HFID) + """ + if isinstance(value, list): + return value + if is_valid_uuid(value): + return value + return [value] + + +def normalize_hfid_references(values: list[str | list[str]]) -> list[str | list[str]]: + """Normalize a list of reference values to HFID format. + + Each string that is not a valid UUID will be wrapped in a list to treat it as a single-component HFID. + """ + return [normalize_hfid_reference(v) for v in values] + + class RelationshipDataFormat(str, Enum): UNKNOWN = "unknown" @@ -444,10 +471,13 @@ async def create_node( # - if the relationship is bidirectional and is mandatory on the other side, then we need to create this object First # - if the relationship is bidirectional and is not mandatory on the other side, then we need should create the related object First # - if the relationship is not bidirectional, then we need to create the related object First - if rel_info.is_reference and isinstance(value, list): - clean_data[key] = value - elif rel_info.format == RelationshipDataFormat.ONE_REF and isinstance(value, str): - clean_data[key] = [value] + if rel_info.format == RelationshipDataFormat.MANY_REF and isinstance(value, list): + # Cardinality-many: normalize each string HFID to list format: "name" -> ["name"] + # UUIDs are left as-is since they are treated as IDs + clean_data[key] = normalize_hfid_references(value) + elif rel_info.format == RelationshipDataFormat.ONE_REF: + # Cardinality-one: normalize string to HFID list format: "name" -> ["name"] or keep as string (UUID) + clean_data[key] = normalize_hfid_reference(value) elif not rel_info.is_reference and rel_info.is_bidirectional and rel_info.is_mandatory: remaining_rels.append(key) elif not rel_info.is_reference and not rel_info.is_mandatory: diff --git a/tests/unit/sdk/spec/test_object.py b/tests/unit/sdk/spec/test_object.py index 1af02ac3..581b2572 100644 --- a/tests/unit/sdk/spec/test_object.py +++ b/tests/unit/sdk/spec/test_object.py @@ -1,6 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any +from unittest.mock import AsyncMock, patch import pytest @@ -9,6 +11,7 @@ if TYPE_CHECKING: from infrahub_sdk.client import InfrahubClient + from infrahub_sdk.node import InfrahubNode @pytest.fixture @@ -263,3 +266,105 @@ async def test_parameters_non_dict(client_with_schema_01: InfrahubClient, locati obj = ObjectFile(location="some/path", content=location_with_non_dict_parameters) with pytest.raises(ValidationError): await obj.validate_format(client=client_with_schema_01) + + +@dataclass +class HfidLoadTestCase: + """Test case for HFID normalization in object loading.""" + + name: str + data: list[dict[str, Any]] + expected_primary_tag: str | list[str] | None + expected_tags: list[str] | list[list[str]] | None + + +HFID_NORMALIZATION_TEST_CASES = [ + HfidLoadTestCase( + name="cardinality_one_string_hfid_normalized", + data=[{"name": "Mexico", "type": "Country", "primary_tag": "Important"}], + expected_primary_tag=["Important"], + expected_tags=None, + ), + HfidLoadTestCase( + name="cardinality_one_list_hfid_unchanged", + data=[{"name": "Mexico", "type": "Country", "primary_tag": ["Important"]}], + expected_primary_tag=["Important"], + expected_tags=None, + ), + HfidLoadTestCase( + name="cardinality_one_uuid_unchanged", + data=[{"name": "Mexico", "type": "Country", "primary_tag": "550e8400-e29b-41d4-a716-446655440000"}], + expected_primary_tag="550e8400-e29b-41d4-a716-446655440000", + expected_tags=None, + ), + HfidLoadTestCase( + name="cardinality_many_string_hfids_normalized", + data=[{"name": "Mexico", "type": "Country", "tags": ["Important", "Active"]}], + expected_primary_tag=None, + expected_tags=[["Important"], ["Active"]], + ), + HfidLoadTestCase( + name="cardinality_many_list_hfids_unchanged", + data=[{"name": "Mexico", "type": "Country", "tags": [["Important"], ["Active"]]}], + expected_primary_tag=None, + expected_tags=[["Important"], ["Active"]], + ), + HfidLoadTestCase( + name="cardinality_many_mixed_hfids_normalized", + data=[{"name": "Mexico", "type": "Country", "tags": ["Important", ["namespace", "name"]]}], + expected_primary_tag=None, + expected_tags=[["Important"], ["namespace", "name"]], + ), + HfidLoadTestCase( + name="cardinality_many_uuids_unchanged", + data=[ + { + "name": "Mexico", + "type": "Country", + "tags": ["550e8400-e29b-41d4-a716-446655440000", "6ba7b810-9dad-11d1-80b4-00c04fd430c8"], + } + ], + expected_primary_tag=None, + expected_tags=["550e8400-e29b-41d4-a716-446655440000", "6ba7b810-9dad-11d1-80b4-00c04fd430c8"], + ), +] + + +@pytest.mark.parametrize("test_case", HFID_NORMALIZATION_TEST_CASES, ids=lambda tc: tc.name) +async def test_hfid_normalization_in_object_loading( + client_with_schema_01: InfrahubClient, test_case: HfidLoadTestCase +) -> None: + """Test that HFIDs are normalized correctly based on cardinality and format.""" + + root_location = {"apiVersion": "infrahub.app/v1", "kind": "Object", "spec": {"kind": "BuiltinLocation", "data": []}} + location = { + "apiVersion": root_location["apiVersion"], + "kind": root_location["kind"], + "spec": {"kind": root_location["spec"]["kind"], "data": test_case.data}, + } + + obj = ObjectFile(location="some/path", content=location) + await obj.validate_format(client=client_with_schema_01) + + create_calls: list[dict[str, Any]] = [] + + async def mock_create( + kind: str, + branch: str | None = None, + data: dict | None = None, + **kwargs: Any, # noqa: ANN401 + ) -> InfrahubNode: + create_calls.append({"kind": kind, "data": data}) + original_create = client_with_schema_01.__class__.create + return await original_create(client_with_schema_01, kind=kind, branch=branch, data=data, **kwargs) + + client_with_schema_01.create = mock_create + + with patch("infrahub_sdk.node.InfrahubNode.save", new_callable=AsyncMock): + await obj.process(client=client_with_schema_01) + + assert len(create_calls) == 1 + if test_case.expected_primary_tag is not None: + assert create_calls[0]["data"]["primary_tag"] == test_case.expected_primary_tag + if test_case.expected_tags is not None: + assert create_calls[0]["data"]["tags"] == test_case.expected_tags