diff --git a/README.md b/README.md index 517ae4f..203c6b1 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ - +

diff --git a/VERSION b/VERSION index 1809198..1545d96 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -3.4.0 +3.5.0 diff --git a/src/glassflow/etl/models/__init__.py b/src/glassflow/etl/models/__init__.py index 4f39346..91c3333 100644 --- a/src/glassflow/etl/models/__init__.py +++ b/src/glassflow/etl/models/__init__.py @@ -5,7 +5,6 @@ JoinConfigPatch, JoinOrientation, JoinSourceConfig, - JoinSourceConfigPatch, JoinType, ) from .pipeline import PipelineConfig, PipelineConfigPatch, PipelineStatus @@ -24,7 +23,6 @@ SourceConfigPatch, SourceType, TopicConfig, - TopicConfigPatch, ) __all__ = [ @@ -42,6 +40,7 @@ "PipelineConfigPatch", "PipelineStatus", "SinkConfig", + "SinkConfigPatch", "SinkType", "TableMapping", "Schema", @@ -52,10 +51,7 @@ "TopicConfig", "GlassFlowConfig", "SourceConfigPatch", - "TopicConfigPatch", "KafkaConnectionParamsPatch", "DeduplicationConfigPatch", "JoinConfigPatch", - "JoinSourceConfigPatch", - "SinkConfigPatch", ] diff --git a/src/glassflow/etl/models/join.py b/src/glassflow/etl/models/join.py index 66a44a5..90cd71f 100644 --- a/src/glassflow/etl/models/join.py +++ b/src/glassflow/etl/models/join.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Any, List, Optional from pydantic import BaseModel, Field, ValidationInfo, field_validator @@ -64,16 +64,23 @@ def validate_type( raise ValueError("type is required when join is enabled") return v + def update(self, patch: "JoinConfigPatch") -> "JoinConfig": + """Apply a patch to this join config.""" + update_dict: dict[str, Any] = {} -class JoinSourceConfigPatch(BaseModel): - source_id: Optional[str] = Field(default=None) - join_key: Optional[str] = Field(default=None) - time_window: Optional[str] = Field(default=None) - orientation: Optional[JoinOrientation] = Field(default=None) + if patch.enabled is not None: + update_dict["enabled"] = patch.enabled + if patch.type is not None: + update_dict["type"] = patch.type + if patch.sources is not None: + update_dict["sources"] = patch.sources + + if update_dict: + return self.model_copy(update=update_dict) + return self class JoinConfigPatch(BaseModel): enabled: Optional[bool] = Field(default=None) type: Optional[JoinType] = Field(default=None) - # TODO: How to patch an element in a list? sources: Optional[List[JoinSourceConfig]] = Field(default=None) diff --git a/src/glassflow/etl/models/pipeline.py b/src/glassflow/etl/models/pipeline.py index 2f188ee..e02368c 100644 --- a/src/glassflow/etl/models/pipeline.py +++ b/src/glassflow/etl/models/pipeline.py @@ -190,6 +190,39 @@ def validate_data_type_compatibility(cls, v: SinkConfig, info: Any) -> SinkConfi return v + def update(self, config_patch: "PipelineConfigPatch") -> "PipelineConfig": + """ + Apply a patch configuration to this pipeline configuration. + + Args: + config_patch: The patch configuration (PipelineConfigPatch or dict) + + Returns: + PipelineConfig: A new PipelineConfig instance with the patch applied + """ + # Start with a deep copy of the current config + updated_config = self.model_copy(deep=True) + + # Update name if provided + if config_patch.name is not None: + updated_config.name = config_patch.name + + # Update source if provided + if config_patch.source is not None: + updated_config.source = updated_config.source.update(config_patch.source) + + # Update join if provided + if config_patch.join is not None: + updated_config.join = (updated_config.join or JoinConfig()).update( + config_patch.join + ) + + # Update sink if provided + if config_patch.sink is not None: + updated_config.sink = updated_config.sink.update(config_patch.sink) + + return updated_config + class PipelineConfigPatch(BaseModel): name: Optional[str] = Field(default=None) diff --git a/src/glassflow/etl/models/sink.py b/src/glassflow/etl/models/sink.py index f7a36e8..db7cf49 100644 --- a/src/glassflow/etl/models/sink.py +++ b/src/glassflow/etl/models/sink.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Any, List, Optional from pydantic import BaseModel, Field @@ -33,6 +33,44 @@ class SinkConfig(BaseModel): table: str table_mapping: List[TableMapping] + def update(self, patch: "SinkConfigPatch") -> "SinkConfig": + """Apply a patch to this sink config.""" + update_dict: dict[str, Any] = {} + + # Check each field explicitly to handle model instances properly + if patch.provider is not None: + update_dict["provider"] = patch.provider + if patch.host is not None: + update_dict["host"] = patch.host + if patch.port is not None: + update_dict["port"] = patch.port + if patch.http_port is not None: + update_dict["http_port"] = patch.http_port + if patch.database is not None: + update_dict["database"] = patch.database + if patch.username is not None: + update_dict["username"] = patch.username + if patch.password is not None: + update_dict["password"] = patch.password + if patch.secure is not None: + update_dict["secure"] = patch.secure + if patch.skip_certificate_verification is not None: + update_dict["skip_certificate_verification"] = ( + patch.skip_certificate_verification + ) + if patch.max_batch_size is not None: + update_dict["max_batch_size"] = patch.max_batch_size + if patch.max_delay_time is not None: + update_dict["max_delay_time"] = patch.max_delay_time + if patch.table is not None: + update_dict["table"] = patch.table + if patch.table_mapping is not None: + update_dict["table_mapping"] = patch.table_mapping + + if update_dict: + return self.model_copy(update=update_dict) + return self + class SinkConfigPatch(BaseModel): provider: Optional[str] = Field(default=None) @@ -43,6 +81,7 @@ class SinkConfigPatch(BaseModel): username: Optional[str] = Field(default=None) password: Optional[str] = Field(default=None) secure: Optional[bool] = Field(default=None) + skip_certificate_verification: Optional[bool] = Field(default=None) max_batch_size: Optional[int] = Field(default=None) max_delay_time: Optional[str] = Field(default=None) table: Optional[str] = Field(default=None) diff --git a/src/glassflow/etl/models/source.py b/src/glassflow/etl/models/source.py index dec2ef9..d7d32a4 100644 --- a/src/glassflow/etl/models/source.py +++ b/src/glassflow/etl/models/source.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Any, List, Optional from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator @@ -40,6 +40,29 @@ class DeduplicationConfig(BaseModel): id_field_type: Optional[KafkaDataType] = Field(default=None) time_window: Optional[str] = Field(default=None) + def update(self, patch: "DeduplicationConfigPatch") -> "DeduplicationConfig": + """Apply a patch to this deduplication config.""" + update_dict: dict[str, Any] = {} + + # Check each field explicitly - use model_fields_set to distinguish + # between "not provided" and "set to None" + fields_set = ( + patch.model_fields_set if hasattr(patch, "model_fields_set") else set() + ) + + if "enabled" in fields_set or patch.enabled is not None: + update_dict["enabled"] = patch.enabled + if "id_field" in fields_set: + update_dict["id_field"] = patch.id_field + if "id_field_type" in fields_set: + update_dict["id_field_type"] = patch.id_field_type + if "time_window" in fields_set: + update_dict["time_window"] = patch.time_window + + if update_dict: + return self.model_copy(update=update_dict) + return self + @model_validator(mode="before") @classmethod def validate_deduplication_fields(cls, values): @@ -148,6 +171,13 @@ def empty_str_to_none(values): values["mechanism"] = None return values + def update(self, patch: "KafkaConnectionParamsPatch") -> "KafkaConnectionParams": + """Apply a patch to this connection params config.""" + current_dict = self.model_dump() + patch_dict = patch.model_dump(exclude_none=True) + merged_dict = {**current_dict, **patch_dict} + return KafkaConnectionParams.model_validate(merged_dict) + class SourceType(CaseInsensitiveStrEnum): KAFKA = "kafka" @@ -159,6 +189,29 @@ class SourceConfig(BaseModel): connection_params: KafkaConnectionParams topics: List[TopicConfig] + def update(self, patch: "SourceConfigPatch") -> "SourceConfig": + """Apply a patch to this source config.""" + update_dict: dict[str, Any] = {} + + if patch.type is not None: + update_dict["type"] = patch.type + if patch.provider is not None: + update_dict["provider"] = patch.provider + + # Handle connection_params patch + if patch.connection_params is not None: + update_dict["connection_params"] = self.connection_params.update( + patch.connection_params + ) + + # Handle topics patch - full replacement only if provided + if patch.topics is not None: + update_dict["topics"] = patch.topics + + if update_dict: + return self.model_copy(update=update_dict) + return self + class DeduplicationConfigPatch(BaseModel): enabled: Optional[bool] = Field(default=None) @@ -167,13 +220,6 @@ class DeduplicationConfigPatch(BaseModel): time_window: Optional[str] = Field(default=None) -class TopicConfigPatch(BaseModel): - consumer_group_initial_offset: Optional[ConsumerGroupOffset] = Field(default=None) - name: Optional[str] = Field(default=None) - event_schema: Optional[Schema] = Field(default=None) - deduplication: Optional[DeduplicationConfigPatch] = Field(default=None) - - class KafkaConnectionParamsPatch(BaseModel): brokers: Optional[List[str]] = Field(default=None) protocol: Optional[KafkaProtocol] = Field(default=None) @@ -181,10 +227,16 @@ class KafkaConnectionParamsPatch(BaseModel): username: Optional[str] = Field(default=None) password: Optional[str] = Field(default=None) root_ca: Optional[str] = Field(default=None) + kerberos_service_name: Optional[str] = Field(default=None) + kerberos_keytab: Optional[str] = Field(default=None) + kerberos_realm: Optional[str] = Field(default=None) + kerberos_config: Optional[str] = Field(default=None) skip_auth: Optional[bool] = Field(default=None) class SourceConfigPatch(BaseModel): + type: Optional[SourceType] = Field(default=None) provider: Optional[str] = Field(default=None) connection_params: Optional[KafkaConnectionParamsPatch] = Field(default=None) - topics: Optional[List[TopicConfigPatch]] = Field(default=None) + # Full replacement only; users must provide complete TopicConfig entries + topics: Optional[List[TopicConfig]] = Field(default=None) diff --git a/src/glassflow/etl/pipeline.py b/src/glassflow/etl/pipeline.py index e4c786c..f4eb909 100644 --- a/src/glassflow/etl/pipeline.py +++ b/src/glassflow/etl/pipeline.py @@ -131,7 +131,8 @@ def rename(self, name: str) -> Pipeline: def update( self, config_patch: models.PipelineConfigPatch | dict[str, Any] ) -> Pipeline: - """Updates the pipeline with the given config. + """Updates the pipeline with the given config patch. + Pipeline must be stopped or terminated before updating. Args: config_patch: Pipeline configuration patch @@ -141,9 +142,33 @@ def update( Raises: PipelineNotFoundError: If pipeline is not found + PipelineInTransitionError: If pipeline is in transition + InvalidStatusTransitionError: If pipeline is not in a state that can be + updated APIError: If the API request fails """ - raise NotImplementedError("Updating is not implemented") + self.get() # Get latest config + if isinstance(config_patch, dict): + config_patch = models.PipelineConfigPatch.model_validate(config_patch) + else: + config_patch = config_patch + updated_config = self.config.update(config_patch) + + self._request( + "POST", + f"{self.ENDPOINT}/{self.pipeline_id}/edit", + json=updated_config.model_dump( + mode="json", + by_alias=True, + exclude_none=True, + ), + event_name="PipelineUpdated", + ) + self.status = models.PipelineStatus.RESUMING + + # Update self.config with the updated configuration + self.config = updated_config + return self def delete(self) -> None: """ diff --git a/tests/test_models/test_config_update.py b/tests/test_models/test_config_update.py new file mode 100644 index 0000000..951d3a5 --- /dev/null +++ b/tests/test_models/test_config_update.py @@ -0,0 +1,383 @@ +"""Tests for config update methods.""" + +from glassflow.etl import models + + +class TestPipelineConfigUpdate: + """Tests for PipelineConfig.update() method.""" + + def test_update_name(self, valid_config): + """Test updating pipeline name.""" + config = models.PipelineConfig(**valid_config) + patch = models.PipelineConfigPatch(name="Updated Name") + + updated = config.update(patch) + + assert updated.name == "Updated Name" + assert updated.pipeline_id == config.pipeline_id + assert updated.source == config.source + assert updated.sink == config.sink + # Original config should be unchanged (immutable) + assert config.name != "Updated Name" + + def test_update_source(self, valid_config): + """Test updating source configuration.""" + config = models.PipelineConfig(**valid_config) + patch = models.PipelineConfigPatch( + source=models.SourceConfigPatch( + provider="new-provider", + connection_params=models.KafkaConnectionParamsPatch( + brokers=["new-broker:9092"] + ), + ) + ) + + updated = config.update(patch) + + assert updated.source.provider == "new-provider" + assert updated.source.connection_params.brokers == ["new-broker:9092"] + # Other source fields should remain unchanged + assert updated.source.type == config.source.type + assert updated.name == config.name + + def test_update_sink(self, valid_config): + """Test updating sink configuration.""" + config = models.PipelineConfig(**valid_config) + patch = models.PipelineConfigPatch( + sink=models.SinkConfigPatch(host="new-host", port="9000") + ) + + updated = config.update(patch) + + assert updated.sink.host == "new-host" + assert updated.sink.port == "9000" + # Other sink fields should remain unchanged + assert updated.sink.database == config.sink.database + assert updated.sink.username == config.sink.username + + def test_update_join(self, valid_config): + """Test updating join configuration.""" + config = models.PipelineConfig(**valid_config) + patch = models.PipelineConfigPatch(join=models.JoinConfigPatch(enabled=False)) + + updated = config.update(patch) + + assert updated.join.enabled is False + # Original join should have been enabled + assert config.join.enabled is True + + def test_update_join_when_none(self, valid_config_without_joins): + """Test updating join when it's initially None.""" + config = models.PipelineConfig(**valid_config_without_joins) + assert config.join is None or config.join.enabled is False + + patch = models.PipelineConfigPatch( + join=models.JoinConfigPatch( + enabled=True, + type=models.JoinType.TEMPORAL, + sources=[ + models.JoinSourceConfig( + source_id="user_logins", + join_key="user_id", + time_window="1h", + orientation=models.JoinOrientation.LEFT, + ), + models.JoinSourceConfig( + source_id="orders", + join_key="user_id", + time_window="1h", + orientation=models.JoinOrientation.RIGHT, + ), + ], + ) + ) + + updated = config.update(patch) + + assert updated.join.enabled is True + assert updated.join.type == models.JoinType.TEMPORAL + + def test_update_multiple_fields(self, valid_config): + """Test updating multiple fields at once.""" + config = models.PipelineConfig(**valid_config) + patch = models.PipelineConfigPatch( + name="Multi Update", + source=models.SourceConfigPatch(provider="updated-provider"), + sink=models.SinkConfigPatch(host="updated-host"), + ) + + updated = config.update(patch) + + assert updated.name == "Multi Update" + assert updated.source.provider == "updated-provider" + assert updated.sink.host == "updated-host" + + def test_update_empty_patch(self, valid_config): + """Test updating with an empty patch (all None).""" + config = models.PipelineConfig(**valid_config) + patch = models.PipelineConfigPatch() + + updated = config.update(patch) + + # Should return a copy with no changes + assert updated.name == config.name + assert updated.source == config.source + assert updated.sink == config.sink + + def test_update_partial_nested(self, valid_config): + """Test updating only part of a nested configuration.""" + config = models.PipelineConfig(**valid_config) + original_brokers = config.source.connection_params.brokers + original_protocol = config.source.connection_params.protocol + + patch = models.PipelineConfigPatch( + source=models.SourceConfigPatch( + connection_params=models.KafkaConnectionParamsPatch( + username="new-username" + ) + ) + ) + + updated = config.update(patch) + + # Only username should change + assert updated.source.connection_params.username == "new-username" + # Other connection params should remain unchanged + assert updated.source.connection_params.brokers == original_brokers + assert updated.source.connection_params.protocol == original_protocol + + +class TestSourceConfigUpdate: + """Tests for SourceConfig.update() method.""" + + def test_update_connection_params(self, valid_config): + """Test updating Kafka connection parameters.""" + from glassflow.etl.models.source import KafkaProtocol + + source = models.SourceConfig(**valid_config["source"]) + patch = models.SourceConfigPatch( + connection_params=models.KafkaConnectionParamsPatch( + brokers=["updated-broker:9092"], + protocol=KafkaProtocol.PLAINTEXT, + ) + ) + + updated = source.update(patch) + + assert updated.connection_params.brokers == ["updated-broker:9092"] + from glassflow.etl.models.source import KafkaProtocol + + assert updated.connection_params.protocol == KafkaProtocol.PLAINTEXT + # Other fields should remain unchanged + assert updated.type == source.type + assert updated.provider == source.provider + + def test_update_topics(self, valid_config): + """Test updating topics with full TopicConfig objects (no partial patch).""" + source = models.SourceConfig(**valid_config["source"]) + new_topic = models.TopicConfig( + name="new-topic", + schema=models.Schema( + type=models.SchemaType.JSON, + fields=[ + models.SchemaField(name="id", type=models.KafkaDataType.STRING) + ], + ), + ) + patch = models.SourceConfigPatch(topics=[new_topic]) + + updated = source.update(patch) + + assert len(updated.topics) == 1 + assert updated.topics[0].name == "new-topic" + # Topics list is replaced, not merged + assert len(source.topics) > 1 + + def test_update_provider(self, valid_config): + """Test updating provider.""" + source = models.SourceConfig(**valid_config["source"]) + patch = models.SourceConfigPatch(provider="updated-provider") + + updated = source.update(patch) + + assert updated.provider == "updated-provider" + assert updated.connection_params == source.connection_params + + +class TestSinkConfigUpdate: + """Tests for SinkConfig.update() method.""" + + def test_update_host_port(self, valid_config): + """Test updating sink host and port.""" + sink = models.SinkConfig(**valid_config["sink"]) + patch = models.SinkConfigPatch(host="new-host", port="8080") + + updated = sink.update(patch) + + assert updated.host == "new-host" + assert updated.port == "8080" + # Other fields should remain unchanged + assert updated.database == sink.database + assert updated.username == sink.username + + def test_update_credentials(self, valid_config): + """Test updating sink credentials.""" + sink = models.SinkConfig(**valid_config["sink"]) + patch = models.SinkConfigPatch(username="new-user", password="new-password") + + updated = sink.update(patch) + + assert updated.username == "new-user" + assert updated.password == "new-password" + assert updated.host == sink.host + + def test_update_table_mapping(self, valid_config): + """Test updating table mapping.""" + sink = models.SinkConfig(**valid_config["sink"]) + new_mapping = [ + models.TableMapping( + source_id="test", + field_name="id", + column_name="id", + column_type=models.ClickhouseDataType.STRING, + ) + ] + patch = models.SinkConfigPatch(table_mapping=new_mapping) + + updated = sink.update(patch) + + assert len(updated.table_mapping) == 1 + assert updated.table_mapping[0].source_id == "test" + + def test_update_multiple_sink_fields(self, valid_config): + """Test updating multiple sink fields at once.""" + sink = models.SinkConfig(**valid_config["sink"]) + patch = models.SinkConfigPatch( + host="new-host", + port="8080", + database="new-db", + table="new-table", + ) + + updated = sink.update(patch) + + assert updated.host == "new-host" + assert updated.port == "8080" + assert updated.database == "new-db" + assert updated.table == "new-table" + + +class TestJoinConfigUpdate: + """Tests for JoinConfig.update() method.""" + + def test_update_enabled(self, valid_config): + """Test updating join enabled status.""" + join = models.JoinConfig(**valid_config["join"]) + patch = models.JoinConfigPatch(enabled=False) + + updated = join.update(patch) + + assert updated.enabled is False + assert updated.type == join.type + assert updated.sources == join.sources + + def test_update_type(self, valid_config): + """Test updating join type.""" + join = models.JoinConfig(**valid_config["join"]) + patch = models.JoinConfigPatch(type=models.JoinType.TEMPORAL) + + updated = join.update(patch) + + assert updated.type == models.JoinType.TEMPORAL + assert updated.enabled == join.enabled + + def test_update_sources(self, valid_config): + """Test updating join sources.""" + join = models.JoinConfig(**valid_config["join"]) + new_sources = [ + models.JoinSourceConfig( + source_id="source1", + join_key="key1", + time_window="2h", + orientation=models.JoinOrientation.LEFT, + ), + models.JoinSourceConfig( + source_id="source2", + join_key="key2", + time_window="2h", + orientation=models.JoinOrientation.RIGHT, + ), + ] + patch = models.JoinConfigPatch(sources=new_sources) + + updated = join.update(patch) + + assert updated.sources == new_sources + assert len(updated.sources) == 2 + + +class TestKafkaConnectionParamsUpdate: + """Tests for KafkaConnectionParams.update() method.""" + + def test_update_brokers(self, valid_config): + """Test updating brokers.""" + conn_params = models.KafkaConnectionParams( + **valid_config["source"]["connection_params"] + ) + patch = models.KafkaConnectionParamsPatch(brokers=["broker1:9092"]) + + updated = conn_params.update(patch) + + assert updated.brokers == ["broker1:9092"] + # Other fields should remain unchanged + assert updated.protocol == conn_params.protocol + assert updated.mechanism == conn_params.mechanism + + def test_update_auth_fields(self, valid_config): + """Test updating authentication fields.""" + conn_params = models.KafkaConnectionParams( + **valid_config["source"]["connection_params"] + ) + patch = models.KafkaConnectionParamsPatch( + username="new-user", + password="new-pass", + mechanism=models.KafkaMechanism.PLAIN, + ) + + updated = conn_params.update(patch) + + assert updated.username == "new-user" + assert updated.password == "new-pass" + assert updated.mechanism == models.KafkaMechanism.PLAIN + + +class TestDeduplicationConfigUpdate: + """Tests for DeduplicationConfig.update() method.""" + + def test_update_enabled(self, valid_config): + """Test updating deduplication enabled status.""" + dedup = models.DeduplicationConfig( + **valid_config["source"]["topics"][0]["deduplication"] + ) + patch = models.DeduplicationConfigPatch(enabled=False) + + updated = dedup.update(patch) + + assert updated.enabled is False + # Other fields should remain unchanged + assert updated.id_field == dedup.id_field + assert updated.time_window == dedup.time_window + + def test_update_id_field(self, valid_config): + """Test updating deduplication id field.""" + dedup = models.DeduplicationConfig( + **valid_config["source"]["topics"][0]["deduplication"] + ) + patch = models.DeduplicationConfigPatch( + id_field="new_id_field", id_field_type=models.KafkaDataType.INT + ) + + updated = dedup.update(patch) + + assert updated.id_field == "new_id_field" + assert updated.id_field_type == models.KafkaDataType.INT diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index ab7d135..256be77 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -187,6 +187,90 @@ def test_rename_connection_error(self, pipeline, mock_connection_error): pipeline.rename(new_name) assert "Failed to connect to GlassFlow ETL API" in str(exc_info.value) + def test_update_success( + self, pipeline, mock_success, get_pipeline_response, get_health_payload + ): + """Test successful pipeline update.""" + config_patch = models.PipelineConfigPatch(name="Updated Name") + + with mock_success( + [get_pipeline_response, get_health_payload(pipeline.pipeline_id)] + ) as mock_request: + result = pipeline.update(config_patch) + # Should call GET pipeline, GET health (from get()), + # then POST to edit endpoint + assert len(mock_request.call_args_list) == 3 + assert mock_request.call_args_list[0][0] == ( + "GET", + f"{pipeline.ENDPOINT}/{pipeline.pipeline_id}", + ) + assert mock_request.call_args_list[1][0] == ( + "GET", + f"{pipeline.ENDPOINT}/{pipeline.pipeline_id}/health", + ) + # Check the update request includes the merged config + edit_call = mock_request.call_args_list[2] + assert edit_call[0][0] == "POST" + assert edit_call[0][1] == f"{pipeline.ENDPOINT}/{pipeline.pipeline_id}/edit" + # The request should include the full updated config + edit_json = edit_call[1]["json"] + assert edit_json["name"] == "Updated Name" + assert result == pipeline + assert pipeline.config.name == "Updated Name" + + def test_update_with_dict( + self, pipeline, mock_success, get_pipeline_response, get_health_payload + ): + """Test pipeline update with dictionary patch.""" + patch_dict = {"name": "Dict Updated Name"} + + with mock_success( + [get_pipeline_response, get_health_payload(pipeline.pipeline_id)] + ): + result = pipeline.update(patch_dict) + assert result == pipeline + assert pipeline.config.name == "Dict Updated Name" + + def test_update_nested_config( + self, pipeline, mock_success, get_pipeline_response, get_health_payload + ): + """Test pipeline update with nested configuration.""" + config_patch = models.PipelineConfigPatch( + source=models.SourceConfigPatch( + connection_params=models.KafkaConnectionParamsPatch( + brokers=["new-broker:9092"] + ) + ) + ) + + with mock_success( + [get_pipeline_response, get_health_payload(pipeline.pipeline_id)] + ): + result = pipeline.update(config_patch) + assert result == pipeline + assert pipeline.config.source.connection_params.brokers == [ + "new-broker:9092" + ] + + def test_update_not_found(self, pipeline, mock_not_found_response): + """Test pipeline update when pipeline is not found.""" + from unittest.mock import patch as mock_patch + + config_patch = models.PipelineConfigPatch(name="Updated Name") + with mock_patch("httpx.Client.request", return_value=mock_not_found_response): + with pytest.raises(errors.PipelineNotFoundError): + pipeline.update(config_patch) + + def test_update_connection_error(self, pipeline, mock_connection_error): + """Test pipeline update with connection error.""" + from unittest.mock import patch as mock_patch + + config_patch = models.PipelineConfigPatch(name="Updated Name") + with mock_patch("httpx.Client.request", side_effect=mock_connection_error): + with pytest.raises(errors.ConnectionError) as exc_info: + pipeline.update(config_patch) + assert "Failed to connect to GlassFlow ETL API" in str(exc_info.value) + class TestPipelineValidation: """Tests for config validation."""