From f6f4557e011e8055fc34f214a6f9a117b35ad8d9 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 1 Dec 2025 09:51:49 +0000 Subject: [PATCH] Expand unit tests to increase code coverage Add comprehensive unit tests for multiple modules: - Webhook handler: Tests for initialization, JSONPath filtering, event sending, retry logic, and different payload types (92% coverage) - Queue models: Tests for CommandType, QueuedCommand, CommandResult, QueueStats, QueueInfo, and their serialization (100% coverage) - Metrics collector: Tests for all Prometheus metrics recording methods and singleton pattern (100% coverage) - Database engine: Tests for engine lifecycle, session management, and global functions (100% coverage) - Queue manager: Tests for initialization, start/stop, enqueue, debouncing, queue full handling, and worker processing (90% coverage) - API schemas: Tests for Pydantic validation, request/response schemas, tag values, and coordinate validation (97% coverage) Overall test coverage improved from 29% to 38%. --- tests/unit/test_api_schemas.py | 471 +++++++++++++++++++++++++++++ tests/unit/test_database_engine.py | 222 ++++++++++++++ tests/unit/test_metrics.py | 215 +++++++++++++ tests/unit/test_queue_manager.py | 382 +++++++++++++++++++++++ tests/unit/test_queue_models.py | 269 ++++++++++++++++ tests/unit/test_webhook_handler.py | 324 ++++++++++++++++++++ 6 files changed, 1883 insertions(+) create mode 100644 tests/unit/test_api_schemas.py create mode 100644 tests/unit/test_database_engine.py create mode 100644 tests/unit/test_metrics.py create mode 100644 tests/unit/test_queue_manager.py create mode 100644 tests/unit/test_queue_models.py create mode 100644 tests/unit/test_webhook_handler.py diff --git a/tests/unit/test_api_schemas.py b/tests/unit/test_api_schemas.py new file mode 100644 index 0000000..1a2b69e --- /dev/null +++ b/tests/unit/test_api_schemas.py @@ -0,0 +1,471 @@ +"""Unit tests for API Pydantic schemas.""" + +from datetime import datetime + +import pytest +from pydantic import ValidationError + +from meshcore_api.api.schemas import ( + AdvertisementFilters, + AdvertisementResponse, + BulkTagUpdateRequest, + CoordinateValue, + ErrorResponse, + MessageFilters, + MessageResponse, + NodeResponse, + PaginationParams, + PingRequest, + SendChannelMessageRequest, + SendMessageRequest, + SendAdvertRequest, + SendTelemetryRequestRequest, + SendTracePathRequest, + TagValueRequest, + TagValueUpdateRequest, + TelemetryFilters, + TracePathFilters, +) + + +class TestPaginationParams: + """Test PaginationParams schema.""" + + def test_default_values(self): + """Test default pagination values.""" + params = PaginationParams() + assert params.limit == 100 + assert params.offset == 0 + + def test_custom_values(self): + """Test custom pagination values.""" + params = PaginationParams(limit=50, offset=100) + assert params.limit == 50 + assert params.offset == 100 + + def test_limit_min_validation(self): + """Test limit minimum validation.""" + with pytest.raises(ValidationError): + PaginationParams(limit=0) + + def test_limit_max_validation(self): + """Test limit maximum validation.""" + with pytest.raises(ValidationError): + PaginationParams(limit=1001) + + def test_offset_min_validation(self): + """Test offset minimum validation.""" + with pytest.raises(ValidationError): + PaginationParams(offset=-1) + + +class TestErrorResponse: + """Test ErrorResponse schema.""" + + def test_error_only(self): + """Test error response with just error message.""" + response = ErrorResponse(error="Something went wrong") + assert response.error == "Something went wrong" + assert response.detail is None + + def test_error_with_detail(self): + """Test error response with detail.""" + response = ErrorResponse(error="Error", detail="Additional info") + assert response.error == "Error" + assert response.detail == "Additional info" + + +class TestNodeResponse: + """Test NodeResponse schema.""" + + def test_from_attributes(self): + """Test NodeResponse can be created with from_attributes.""" + response = NodeResponse( + id=1, + public_key="a" * 64, + first_seen=datetime.utcnow(), + created_at=datetime.utcnow(), + ) + assert response.id == 1 + assert response.public_key == "a" * 64 + + def test_optional_fields(self): + """Test NodeResponse optional fields.""" + response = NodeResponse( + id=1, + public_key="a" * 64, + node_type="repeater", + name="Test Node", + last_seen=datetime.utcnow(), + first_seen=datetime.utcnow(), + created_at=datetime.utcnow(), + tags={"friendly_name": "My Node"}, + ) + assert response.node_type == "repeater" + assert response.name == "Test Node" + assert response.tags == {"friendly_name": "My Node"} + + +class TestMessageResponse: + """Test MessageResponse schema.""" + + def test_basic_message(self): + """Test basic message response.""" + response = MessageResponse( + id=1, + direction="inbound", + message_type="contact", + content="Hello World", + received_at=datetime.utcnow(), + ) + assert response.id == 1 + assert response.direction == "inbound" + assert response.content == "Hello World" + + def test_message_with_all_fields(self): + """Test message response with all fields.""" + response = MessageResponse( + id=1, + direction="inbound", + message_type="channel", + pubkey_prefix="abc123", + channel_idx=4, + txt_type=1, + path_len=3, + signature="sig123", + content="Hello", + snr=8.5, + sender_timestamp=datetime.utcnow(), + received_at=datetime.utcnow(), + ) + assert response.channel_idx == 4 + assert response.snr == 8.5 + + +class TestMessageFilters: + """Test MessageFilters schema.""" + + def test_empty_filters(self): + """Test empty message filters.""" + filters = MessageFilters() + assert filters.pubkey_prefix is None + assert filters.channel_idx is None + assert filters.message_type is None + + def test_pubkey_prefix_validation(self): + """Test pubkey prefix length validation.""" + # Valid prefix (2-12 chars) + filters = MessageFilters(pubkey_prefix="ab") + assert filters.pubkey_prefix == "ab" + + filters = MessageFilters(pubkey_prefix="a" * 12) + assert len(filters.pubkey_prefix) == 12 + + def test_pubkey_prefix_too_short(self): + """Test pubkey prefix minimum length.""" + with pytest.raises(ValidationError): + MessageFilters(pubkey_prefix="a") + + def test_pubkey_prefix_too_long(self): + """Test pubkey prefix maximum length.""" + with pytest.raises(ValidationError): + MessageFilters(pubkey_prefix="a" * 13) + + +class TestAdvertisementResponse: + """Test AdvertisementResponse schema.""" + + def test_basic_advertisement(self): + """Test basic advertisement response.""" + response = AdvertisementResponse( + id=1, + public_key="a" * 64, + received_at=datetime.utcnow(), + ) + assert response.id == 1 + assert response.public_key == "a" * 64 + + +class TestAdvertisementFilters: + """Test AdvertisementFilters schema.""" + + def test_node_prefix_validation(self): + """Test node prefix length validation.""" + filters = AdvertisementFilters(node_prefix="ab") + assert filters.node_prefix == "ab" + + filters = AdvertisementFilters(node_prefix="a" * 64) + assert len(filters.node_prefix) == 64 + + +class TestTracePathFilters: + """Test TracePathFilters schema.""" + + def test_date_filters(self): + """Test trace path date filters.""" + now = datetime.utcnow() + filters = TracePathFilters(start_date=now, end_date=now) + assert filters.start_date == now + assert filters.end_date == now + + +class TestTelemetryFilters: + """Test TelemetryFilters schema.""" + + def test_all_filters(self): + """Test all telemetry filters.""" + now = datetime.utcnow() + filters = TelemetryFilters( + node_prefix="abc", + start_date=now, + end_date=now, + ) + assert filters.node_prefix == "abc" + + +class TestSendMessageRequest: + """Test SendMessageRequest schema.""" + + def test_valid_request(self): + """Test valid send message request.""" + request = SendMessageRequest( + destination="a" * 64, + text="Hello World", + ) + assert request.destination == "a" * 64 + assert request.text == "Hello World" + assert request.text_type == "plain" + + def test_destination_hex_validation(self): + """Test destination must be valid hex.""" + with pytest.raises(ValidationError) as exc_info: + SendMessageRequest(destination="g" * 64, text="Hello") + assert "hexadecimal" in str(exc_info.value).lower() + + def test_destination_case_normalization(self): + """Test destination is normalized to lowercase.""" + request = SendMessageRequest( + destination="A" * 64, + text="Hello", + ) + assert request.destination == "a" * 64 + + def test_destination_length_validation(self): + """Test destination must be 64 characters.""" + with pytest.raises(ValidationError): + SendMessageRequest(destination="abc", text="Hello") + + def test_text_min_length(self): + """Test text minimum length.""" + with pytest.raises(ValidationError): + SendMessageRequest(destination="a" * 64, text="") + + def test_text_max_length(self): + """Test text maximum length.""" + with pytest.raises(ValidationError): + SendMessageRequest(destination="a" * 64, text="x" * 1001) + + +class TestSendChannelMessageRequest: + """Test SendChannelMessageRequest schema.""" + + def test_valid_request(self): + """Test valid channel message request.""" + request = SendChannelMessageRequest(text="Broadcast message") + assert request.text == "Broadcast message" + assert request.flood is False + + def test_with_flood(self): + """Test channel message with flood enabled.""" + request = SendChannelMessageRequest(text="Flood message", flood=True) + assert request.flood is True + + +class TestSendAdvertRequest: + """Test SendAdvertRequest schema.""" + + def test_default_flood(self): + """Test default flood value.""" + request = SendAdvertRequest() + assert request.flood is False + + def test_with_flood(self): + """Test advert with flood enabled.""" + request = SendAdvertRequest(flood=True) + assert request.flood is True + + +class TestSendTracePathRequest: + """Test SendTracePathRequest schema.""" + + def test_valid_request(self): + """Test valid trace path request.""" + request = SendTracePathRequest(destination="a" * 64) + assert request.destination == "a" * 64 + + def test_destination_validation(self): + """Test destination validation.""" + with pytest.raises(ValidationError): + SendTracePathRequest(destination="invalid") + + +class TestPingRequest: + """Test PingRequest schema.""" + + def test_valid_request(self): + """Test valid ping request.""" + request = PingRequest(destination="b" * 64) + assert request.destination == "b" * 64 + + def test_mixed_case_normalization(self): + """Test mixed case destination is normalized.""" + request = PingRequest(destination="AbCdEf" + "0" * 58) + assert request.destination == "abcdef" + "0" * 58 + + +class TestSendTelemetryRequestRequest: + """Test SendTelemetryRequestRequest schema.""" + + def test_valid_request(self): + """Test valid telemetry request.""" + request = SendTelemetryRequestRequest(destination="c" * 64) + assert request.destination == "c" * 64 + + +class TestCoordinateValue: + """Test CoordinateValue schema.""" + + def test_valid_coordinate(self): + """Test valid coordinate.""" + coord = CoordinateValue(latitude=37.7749, longitude=-122.4194) + assert coord.latitude == 37.7749 + assert coord.longitude == -122.4194 + + def test_latitude_range(self): + """Test latitude range validation.""" + # Valid extremes + CoordinateValue(latitude=-90, longitude=0) + CoordinateValue(latitude=90, longitude=0) + + # Invalid + with pytest.raises(ValidationError): + CoordinateValue(latitude=-91, longitude=0) + with pytest.raises(ValidationError): + CoordinateValue(latitude=91, longitude=0) + + def test_longitude_range(self): + """Test longitude range validation.""" + # Valid extremes + CoordinateValue(latitude=0, longitude=-180) + CoordinateValue(latitude=0, longitude=180) + + # Invalid + with pytest.raises(ValidationError): + CoordinateValue(latitude=0, longitude=-181) + with pytest.raises(ValidationError): + CoordinateValue(latitude=0, longitude=181) + + +class TestTagValueUpdateRequest: + """Test TagValueUpdateRequest schema.""" + + def test_string_tag(self): + """Test string tag value.""" + tag = TagValueUpdateRequest(value_type="string", value="Test Value") + assert tag.value_type == "string" + assert tag.value == "Test Value" + + def test_number_tag_int(self): + """Test number tag with integer value.""" + tag = TagValueUpdateRequest(value_type="number", value=42) + assert tag.value == 42 + + def test_number_tag_float(self): + """Test number tag with float value.""" + tag = TagValueUpdateRequest(value_type="number", value=3.14) + assert tag.value == 3.14 + + def test_boolean_tag_true(self): + """Test boolean tag with True value.""" + tag = TagValueUpdateRequest(value_type="boolean", value=True) + assert tag.value is True + + def test_boolean_tag_false(self): + """Test boolean tag with False value.""" + tag = TagValueUpdateRequest(value_type="boolean", value=False) + assert tag.value is False + + def test_coordinate_tag(self): + """Test coordinate tag value.""" + coord = CoordinateValue(latitude=40.7128, longitude=-74.0060) + tag = TagValueUpdateRequest(value_type="coordinate", value=coord) + assert tag.value.latitude == 40.7128 + assert tag.value.longitude == -74.0060 + + def test_string_type_wrong_value(self): + """Test string type with non-string value fails.""" + with pytest.raises(ValidationError): + TagValueUpdateRequest(value_type="string", value=123) + + def test_number_type_wrong_value(self): + """Test number type with non-number value fails.""" + with pytest.raises(ValidationError): + TagValueUpdateRequest(value_type="number", value="not a number") + + def test_number_type_boolean_value_fails(self): + """Test number type rejects boolean values.""" + with pytest.raises(ValidationError): + TagValueUpdateRequest(value_type="number", value=True) + + def test_boolean_type_wrong_value(self): + """Test boolean type with non-boolean value fails.""" + with pytest.raises(ValidationError): + TagValueUpdateRequest(value_type="boolean", value="true") + + +class TestTagValueRequest: + """Test TagValueRequest schema (includes key).""" + + def test_with_key(self): + """Test tag value request with key.""" + tag = TagValueRequest(key="friendly_name", value_type="string", value="My Node") + assert tag.key == "friendly_name" + assert tag.value == "My Node" + + def test_key_min_length(self): + """Test key minimum length.""" + with pytest.raises(ValidationError): + TagValueRequest(key="", value_type="string", value="test") + + def test_key_max_length(self): + """Test key maximum length.""" + with pytest.raises(ValidationError): + TagValueRequest(key="a" * 129, value_type="string", value="test") + + +class TestBulkTagUpdateRequest: + """Test BulkTagUpdateRequest schema.""" + + def test_valid_bulk_update(self): + """Test valid bulk tag update.""" + request = BulkTagUpdateRequest( + tags=[ + TagValueRequest(key="name", value_type="string", value="Test"), + TagValueRequest(key="count", value_type="number", value=5), + ] + ) + assert len(request.tags) == 2 + + def test_empty_tags_fails(self): + """Test empty tags list fails.""" + with pytest.raises(ValidationError): + BulkTagUpdateRequest(tags=[]) + + def test_too_many_tags_fails(self): + """Test too many tags fails.""" + tags = [ + TagValueRequest(key=f"tag_{i}", value_type="string", value=f"val_{i}") + for i in range(51) + ] + with pytest.raises(ValidationError): + BulkTagUpdateRequest(tags=tags) diff --git a/tests/unit/test_database_engine.py b/tests/unit/test_database_engine.py new file mode 100644 index 0000000..81e9efd --- /dev/null +++ b/tests/unit/test_database_engine.py @@ -0,0 +1,222 @@ +"""Unit tests for database engine module.""" + +import os +import tempfile +from pathlib import Path + +import pytest + +from meshcore_api.database import engine +from meshcore_api.database.models import Node + + +@pytest.fixture +def temp_db_path(): + """Create a temporary database path.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield os.path.join(tmpdir, "test.db") + + +@pytest.fixture +def db_engine(temp_db_path): + """Create a database engine for testing.""" + db = engine.DatabaseEngine(temp_db_path) + db.initialize() + yield db + db.close() + + +@pytest.fixture(autouse=True) +def reset_global_engine(): + """Reset global database engine before and after each test.""" + engine._db_engine = None + yield + engine._db_engine = None + + +class TestDatabaseEngine: + """Test DatabaseEngine class.""" + + def test_init_stores_path(self, temp_db_path): + """Test DatabaseEngine stores the database path.""" + db = engine.DatabaseEngine(temp_db_path) + assert db.db_path == temp_db_path + assert db.engine is None + assert db.session_factory is None + + def test_initialize_creates_database_file(self, temp_db_path): + """Test initialize creates the database file.""" + db = engine.DatabaseEngine(temp_db_path) + db.initialize() + assert os.path.exists(temp_db_path) + db.close() + + def test_initialize_creates_parent_directory(self): + """Test initialize creates parent directories if needed.""" + with tempfile.TemporaryDirectory() as tmpdir: + nested_path = os.path.join(tmpdir, "nested", "dir", "test.db") + db = engine.DatabaseEngine(nested_path) + db.initialize() + assert os.path.exists(nested_path) + db.close() + + def test_initialize_sets_engine_and_factory(self, temp_db_path): + """Test initialize sets up engine and session factory.""" + db = engine.DatabaseEngine(temp_db_path) + db.initialize() + assert db.engine is not None + assert db.session_factory is not None + db.close() + + def test_get_session_returns_session(self, db_engine): + """Test get_session returns a valid session.""" + session = db_engine.get_session() + assert session is not None + session.close() + + def test_get_session_raises_if_not_initialized(self, temp_db_path): + """Test get_session raises error if not initialized.""" + db = engine.DatabaseEngine(temp_db_path) + with pytest.raises(RuntimeError) as exc_info: + db.get_session() + assert "not initialized" in str(exc_info.value).lower() + + def test_session_scope_commits_on_success(self, db_engine): + """Test session_scope commits changes on success.""" + public_key = "a" * 64 + with db_engine.session_scope() as session: + node = Node( + public_key=public_key, + public_key_prefix_2=public_key[:2], + public_key_prefix_8=public_key[:8], + ) + session.add(node) + + # Verify node was committed + with db_engine.session_scope() as session: + found = session.query(Node).filter_by(public_key=public_key).first() + assert found is not None + assert found.public_key == public_key + + def test_session_scope_rollbacks_on_exception(self, db_engine): + """Test session_scope rolls back changes on exception.""" + public_key = "b" * 64 + + try: + with db_engine.session_scope() as session: + node = Node( + public_key=public_key, + public_key_prefix_2=public_key[:2], + public_key_prefix_8=public_key[:8], + ) + session.add(node) + raise ValueError("Test error") + except ValueError: + pass + + # Verify node was not committed + with db_engine.session_scope() as session: + found = session.query(Node).filter_by(public_key=public_key).first() + assert found is None + + def test_close_disposes_engine(self, temp_db_path): + """Test close disposes the engine.""" + db = engine.DatabaseEngine(temp_db_path) + db.initialize() + db.close() + # Engine should be disposed (no error on close) + assert True + + def test_close_handles_none_engine(self, temp_db_path): + """Test close handles case where engine is None.""" + db = engine.DatabaseEngine(temp_db_path) + db.close() # Should not raise + assert True + + +class TestGlobalFunctions: + """Test global database functions.""" + + def test_init_database_creates_engine(self, temp_db_path): + """Test init_database creates and returns engine.""" + db = engine.init_database(temp_db_path) + assert db is not None + assert engine._db_engine is db + db.close() + + def test_init_database_initializes_tables(self, temp_db_path): + """Test init_database creates tables.""" + db = engine.init_database(temp_db_path) + # Should be able to query nodes table + with db.session_scope() as session: + count = session.query(Node).count() + assert count == 0 + db.close() + + def test_get_database_returns_engine(self, temp_db_path): + """Test get_database returns the global engine.""" + created_db = engine.init_database(temp_db_path) + retrieved_db = engine.get_database() + assert retrieved_db is created_db + created_db.close() + + def test_get_database_raises_if_not_initialized(self): + """Test get_database raises error if not initialized.""" + with pytest.raises(RuntimeError) as exc_info: + engine.get_database() + assert "not initialized" in str(exc_info.value).lower() + + def test_get_session_uses_global_engine(self, temp_db_path): + """Test get_session function uses global engine.""" + engine.init_database(temp_db_path) + session = engine.get_session() + assert session is not None + session.close() + engine.get_database().close() + + def test_get_session_raises_if_not_initialized(self): + """Test get_session raises error if not initialized.""" + with pytest.raises(RuntimeError): + engine.get_session() + + def test_session_scope_uses_global_engine(self, temp_db_path): + """Test session_scope function uses global engine.""" + engine.init_database(temp_db_path) + public_key = "c" * 64 + + with engine.session_scope() as session: + node = Node( + public_key=public_key, + public_key_prefix_2=public_key[:2], + public_key_prefix_8=public_key[:8], + ) + session.add(node) + + # Verify node was committed + with engine.session_scope() as session: + found = session.query(Node).filter_by(public_key=public_key).first() + assert found is not None + + engine.get_database().close() + + +class TestDatabaseEngineMultipleSessions: + """Test multiple sessions and concurrent access.""" + + def test_multiple_sessions_independent(self, db_engine): + """Test multiple sessions are independent.""" + session1 = db_engine.get_session() + session2 = db_engine.get_session() + assert session1 is not session2 + session1.close() + session2.close() + + def test_session_scope_yields_unique_sessions(self, db_engine): + """Test session_scope yields unique sessions each time.""" + sessions = [] + for _ in range(3): + with db_engine.session_scope() as session: + sessions.append(id(session)) + + # All sessions should be unique (different objects) + assert len(set(sessions)) == 3 diff --git a/tests/unit/test_metrics.py b/tests/unit/test_metrics.py new file mode 100644 index 0000000..64ab06b --- /dev/null +++ b/tests/unit/test_metrics.py @@ -0,0 +1,215 @@ +"""Unit tests for Prometheus metrics collector.""" + +import pytest +from prometheus_client import REGISTRY + +from meshcore_api.subscriber import metrics + + +# Use a module-level fixture that runs once to create the collector +@pytest.fixture(scope="module") +def metrics_collector(): + """Get or create a metrics collector for testing. + + Prometheus metrics are registered globally and cannot be re-registered, + so we use the singleton pattern and share one collector across all tests. + """ + # If there's already a global instance, use it + if metrics._metrics is not None: + return metrics._metrics + # Otherwise create one via get_metrics() + return metrics.get_metrics() + + +class TestMetricsCollector: + """Test MetricsCollector class.""" + + def test_collector_has_event_counter(self, metrics_collector): + """Test MetricsCollector has event counter.""" + assert metrics_collector.events_total is not None + + def test_collector_has_message_counter(self, metrics_collector): + """Test MetricsCollector has message counter.""" + assert metrics_collector.messages_total is not None + + def test_collector_has_advertisement_counter(self, metrics_collector): + """Test MetricsCollector has advertisement counter.""" + assert metrics_collector.advertisements_total is not None + + def test_collector_has_packet_counter(self, metrics_collector): + """Test MetricsCollector has packet counter.""" + assert metrics_collector.packets_total is not None + + def test_collector_has_cleanup_counter(self, metrics_collector): + """Test MetricsCollector has cleanup counter.""" + assert metrics_collector.db_cleanup_rows_deleted is not None + + def test_collector_has_error_counter(self, metrics_collector): + """Test MetricsCollector has error counter.""" + assert metrics_collector.errors_total is not None + + def test_collector_has_node_gauges(self, metrics_collector): + """Test MetricsCollector has node gauges.""" + assert metrics_collector.nodes_total is not None + assert metrics_collector.nodes_active is not None + assert metrics_collector.nodes_by_area is not None + assert metrics_collector.nodes_by_role is not None + assert metrics_collector.nodes_online is not None + assert metrics_collector.nodes_with_tags is not None + + def test_collector_has_battery_gauges(self, metrics_collector): + """Test MetricsCollector has battery gauges.""" + assert metrics_collector.battery_voltage is not None + assert metrics_collector.battery_percentage is not None + + def test_collector_has_storage_gauges(self, metrics_collector): + """Test MetricsCollector has storage gauges.""" + assert metrics_collector.storage_used_bytes is not None + assert metrics_collector.storage_total_bytes is not None + + def test_collector_has_radio_gauges(self, metrics_collector): + """Test MetricsCollector has radio gauges.""" + assert metrics_collector.radio_noise_floor_dbm is not None + assert metrics_collector.radio_airtime_percent is not None + + def test_collector_has_database_gauges(self, metrics_collector): + """Test MetricsCollector has database gauges.""" + assert metrics_collector.db_table_rows is not None + assert metrics_collector.db_size_bytes is not None + + def test_collector_has_connection_gauge(self, metrics_collector): + """Test MetricsCollector has connection status gauge.""" + assert metrics_collector.connection_status is not None + + def test_collector_has_histograms(self, metrics_collector): + """Test MetricsCollector has histograms.""" + assert metrics_collector.message_roundtrip_seconds is not None + assert metrics_collector.path_hop_count is not None + assert metrics_collector.snr_db is not None + assert metrics_collector.rssi_dbm is not None + + def test_record_event(self, metrics_collector): + """Test recording events increments counter.""" + metrics_collector.record_event("CONTACT_MSG_RECV") + metrics_collector.record_event("CHANNEL_MSG_RECV") + # Verify it doesn't raise an error + assert True + + def test_record_message(self, metrics_collector): + """Test recording messages with labels.""" + metrics_collector.record_message(direction="inbound", message_type="contact") + metrics_collector.record_message(direction="outbound", message_type="channel") + assert True + + def test_record_advertisement(self, metrics_collector): + """Test recording advertisements.""" + metrics_collector.record_advertisement(adv_type="repeater") + metrics_collector.record_advertisement(adv_type=None) # Test None handling + assert True + + def test_record_roundtrip(self, metrics_collector): + """Test recording roundtrip time.""" + metrics_collector.record_roundtrip(milliseconds=150) + metrics_collector.record_roundtrip(milliseconds=5000) + assert True + + def test_record_hop_count(self, metrics_collector): + """Test recording hop counts.""" + metrics_collector.record_hop_count(hops=3) + metrics_collector.record_hop_count(hops=1) + metrics_collector.record_hop_count(hops=10) + assert True + + def test_record_snr(self, metrics_collector): + """Test recording SNR measurements.""" + metrics_collector.record_snr(snr=15.5) + metrics_collector.record_snr(snr=-5.0) + assert True + + def test_record_rssi(self, metrics_collector): + """Test recording RSSI measurements.""" + metrics_collector.record_rssi(rssi=-80.0) + metrics_collector.record_rssi(rssi=-110.0) + assert True + + def test_update_battery(self, metrics_collector): + """Test updating battery metrics.""" + metrics_collector.update_battery(voltage=3.7) + metrics_collector.update_battery(percentage=85) + metrics_collector.update_battery(voltage=4.2, percentage=100) + metrics_collector.update_battery() # No values + assert True + + def test_update_storage(self, metrics_collector): + """Test updating storage metrics.""" + metrics_collector.update_storage(used=1024000) + metrics_collector.update_storage(total=4096000) + metrics_collector.update_storage(used=2048000, total=4096000) + metrics_collector.update_storage() # No values + assert True + + def test_update_radio_stats(self, metrics_collector): + """Test updating radio statistics.""" + metrics_collector.update_radio_stats(noise_floor=-95.0) + metrics_collector.update_radio_stats(airtime=5.5) + metrics_collector.update_radio_stats(noise_floor=-90.0, airtime=10.2) + metrics_collector.update_radio_stats() # No values + assert True + + def test_record_packet(self, metrics_collector): + """Test recording packet transmissions.""" + metrics_collector.record_packet(direction="tx", status="success") + metrics_collector.record_packet(direction="rx", status="success") + metrics_collector.record_packet(direction="tx", status="failed") + assert True + + def test_update_db_table_rows(self, metrics_collector): + """Test updating database table row counts.""" + metrics_collector.update_db_table_rows(table="nodes", count=100) + metrics_collector.update_db_table_rows(table="messages", count=5000) + assert True + + def test_update_db_size(self, metrics_collector): + """Test updating database size.""" + metrics_collector.update_db_size(size_bytes=1048576) + assert True + + def test_record_cleanup(self, metrics_collector): + """Test recording cleanup operations.""" + metrics_collector.record_cleanup(table="messages", count=100) + metrics_collector.record_cleanup(table="events_log", count=500) + assert True + + def test_set_connection_status(self, metrics_collector): + """Test setting connection status.""" + metrics_collector.set_connection_status(connected=True) + metrics_collector.set_connection_status(connected=False) + assert True + + def test_record_error(self, metrics_collector): + """Test recording errors.""" + metrics_collector.record_error(component="database", error_type="connection") + metrics_collector.record_error(component="meshcore", error_type="timeout") + assert True + + +class TestGetMetrics: + """Test get_metrics function.""" + + def test_get_metrics_returns_collector(self, metrics_collector): + """Test get_metrics returns a collector.""" + collector = metrics.get_metrics() + assert collector is not None + assert isinstance(collector, metrics.MetricsCollector) + + def test_get_metrics_returns_same_instance(self, metrics_collector): + """Test get_metrics returns the same instance on repeated calls.""" + collector1 = metrics.get_metrics() + collector2 = metrics.get_metrics() + assert collector1 is collector2 + + def test_get_metrics_is_singleton(self, metrics_collector): + """Test metrics collector follows singleton pattern.""" + # All calls should return the same collector + collector = metrics.get_metrics() + assert collector is metrics._metrics diff --git a/tests/unit/test_queue_manager.py b/tests/unit/test_queue_manager.py new file mode 100644 index 0000000..04146be --- /dev/null +++ b/tests/unit/test_queue_manager.py @@ -0,0 +1,382 @@ +"""Unit tests for command queue manager.""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock + +import pytest +import pytest_asyncio + +from meshcore_api.queue.manager import CommandQueueManager, QueueFullError +from meshcore_api.queue.models import CommandType, QueueFullBehavior + + +@pytest.fixture +def mock_meshcore(): + """Create a mock MeshCore interface.""" + mock = MagicMock() + mock.send_message = AsyncMock(return_value=MagicMock(to_dict=lambda: {})) + mock.send_channel_message = AsyncMock(return_value=MagicMock(to_dict=lambda: {})) + mock.send_advert = AsyncMock(return_value=MagicMock(to_dict=lambda: {})) + mock.send_trace_path = AsyncMock(return_value=MagicMock(to_dict=lambda: {})) + mock.ping = AsyncMock(return_value=MagicMock(to_dict=lambda: {})) + mock.send_telemetry_request = AsyncMock(return_value=MagicMock(to_dict=lambda: {})) + return mock + + +@pytest_asyncio.fixture +async def queue_manager(mock_meshcore): + """Create a queue manager for testing.""" + manager = CommandQueueManager( + meshcore=mock_meshcore, + max_queue_size=10, + rate_limit_per_second=100.0, # Fast for testing + rate_limit_burst=10, + debounce_window_seconds=0.1, # Short for testing + debounce_enabled=True, + ) + yield manager + await manager.stop() + + +@pytest.mark.asyncio +class TestCommandQueueManagerInit: + """Test CommandQueueManager initialization.""" + + async def test_init_default_values(self, mock_meshcore): + """Test initialization with default values.""" + manager = CommandQueueManager(meshcore=mock_meshcore) + assert manager.meshcore is mock_meshcore + assert manager.max_queue_size == 100 + assert manager.queue_full_behavior == QueueFullBehavior.REJECT + await manager.stop() + + async def test_init_custom_values(self, mock_meshcore): + """Test initialization with custom values.""" + manager = CommandQueueManager( + meshcore=mock_meshcore, + max_queue_size=50, + queue_full_behavior=QueueFullBehavior.DROP_OLDEST, + rate_limit_per_second=0.5, + rate_limit_burst=3, + debounce_window_seconds=30.0, + debounce_cache_max_size=500, + debounce_enabled=False, + ) + assert manager.max_queue_size == 50 + assert manager.queue_full_behavior == QueueFullBehavior.DROP_OLDEST + await manager.stop() + + async def test_init_default_debounce_commands(self, mock_meshcore): + """Test default debounce commands are set.""" + manager = CommandQueueManager(meshcore=mock_meshcore) + # Check that debouncer was initialized with default commands + assert manager._debouncer is not None + await manager.stop() + + async def test_init_custom_debounce_commands(self, mock_meshcore): + """Test custom debounce commands can be specified.""" + custom_commands = {CommandType.PING} + manager = CommandQueueManager( + meshcore=mock_meshcore, + debounce_commands=custom_commands, + ) + assert manager._debouncer.enabled_commands == custom_commands + await manager.stop() + + +@pytest.mark.asyncio +class TestCommandQueueManagerStartStop: + """Test start and stop methods.""" + + async def test_start_creates_worker_task(self, queue_manager): + """Test start creates the worker task.""" + await queue_manager.start() + assert queue_manager._worker_task is not None + assert not queue_manager._worker_task.done() + + async def test_start_idempotent(self, queue_manager): + """Test calling start multiple times is safe.""" + await queue_manager.start() + task1 = queue_manager._worker_task + await queue_manager.start() # Should not create new task + assert queue_manager._worker_task is task1 + + async def test_stop_cancels_worker(self, queue_manager): + """Test stop cancels the worker task.""" + await queue_manager.start() + await queue_manager.stop() + assert queue_manager._worker_task.done() or queue_manager._worker_task.cancelled() + + async def test_stop_without_start(self, queue_manager): + """Test stop without start is safe.""" + await queue_manager.stop() # Should not raise + assert True + + +@pytest.mark.asyncio +class TestCommandQueueManagerEnqueue: + """Test enqueue method.""" + + async def test_enqueue_command(self, queue_manager): + """Test enqueueing a command.""" + result, queue_info = await queue_manager.enqueue( + CommandType.SEND_MESSAGE, + {"destination": "a" * 64, "text": "Hello"}, + ) + assert result.success is True + assert queue_info.position == 1 + assert queue_info.debounced is False + + async def test_enqueue_multiple_commands(self, queue_manager): + """Test enqueueing multiple commands.""" + for i in range(3): + result, queue_info = await queue_manager.enqueue( + CommandType.SEND_MESSAGE, + {"destination": "a" * 64, "text": f"Message {i}"}, + ) + assert result.success is True + assert queue_info.position == i + 1 + + async def test_enqueue_debounces_duplicate(self, queue_manager): + """Test duplicate commands are debounced.""" + params = {"destination": "a" * 64, "text": "Hello"} + + # First enqueue + result1, info1 = await queue_manager.enqueue(CommandType.SEND_MESSAGE, params) + assert result1.success is True + assert info1.debounced is False + + # Second enqueue (same params - should be debounced) + result2, info2 = await queue_manager.enqueue(CommandType.SEND_MESSAGE, params) + assert result2.success is True + assert info2.debounced is True + + async def test_enqueue_queue_full_reject(self, mock_meshcore): + """Test enqueue raises error when queue is full and behavior is REJECT.""" + manager = CommandQueueManager( + meshcore=mock_meshcore, + max_queue_size=2, + queue_full_behavior=QueueFullBehavior.REJECT, + debounce_enabled=False, # Disable to test queue full + ) + + # Fill the queue + await manager.enqueue(CommandType.SEND_MESSAGE, {"destination": "a" * 64, "text": "1"}) + await manager.enqueue(CommandType.SEND_MESSAGE, {"destination": "a" * 64, "text": "2"}) + + # Third should raise + with pytest.raises(QueueFullError): + await manager.enqueue(CommandType.SEND_MESSAGE, {"destination": "a" * 64, "text": "3"}) + + await manager.stop() + + async def test_enqueue_queue_full_drop_oldest(self, mock_meshcore): + """Test enqueue drops oldest when queue is full and behavior is DROP_OLDEST.""" + manager = CommandQueueManager( + meshcore=mock_meshcore, + max_queue_size=2, + queue_full_behavior=QueueFullBehavior.DROP_OLDEST, + debounce_enabled=False, + ) + + # Fill the queue + await manager.enqueue(CommandType.SEND_MESSAGE, {"destination": "a" * 64, "text": "1"}) + await manager.enqueue(CommandType.SEND_MESSAGE, {"destination": "a" * 64, "text": "2"}) + + # Third should succeed by dropping oldest + result, info = await manager.enqueue(CommandType.SEND_MESSAGE, {"destination": "a" * 64, "text": "3"}) + assert result.success is True + assert "dropped" in result.message.lower() + + await manager.stop() + + +@pytest.mark.asyncio +class TestCommandQueueManagerEstimateWaitTime: + """Test wait time estimation.""" + + async def test_estimate_wait_time_empty_queue(self, queue_manager): + """Test wait time is 0 for empty queue.""" + wait_time = queue_manager._estimate_wait_time(0) + assert wait_time == 0.0 + + async def test_estimate_wait_time_with_position(self, queue_manager): + """Test wait time estimation based on position.""" + wait_time = queue_manager._estimate_wait_time(5) + # With rate of 100/sec, 5 commands = 0.05 seconds + assert wait_time == pytest.approx(0.05, abs=0.01) + + async def test_estimate_wait_time_rate_limit_disabled(self, mock_meshcore): + """Test wait time is 0 when rate limiting is disabled.""" + manager = CommandQueueManager( + meshcore=mock_meshcore, + rate_limit_enabled=False, + ) + wait_time = manager._estimate_wait_time(10) + assert wait_time == 0.0 + await manager.stop() + + +@pytest.mark.asyncio +class TestCommandQueueManagerWorker: + """Test the background worker.""" + + async def test_worker_processes_commands(self, queue_manager, mock_meshcore): + """Test worker processes queued commands.""" + await queue_manager.start() + + # Enqueue a command + await queue_manager.enqueue( + CommandType.SEND_MESSAGE, + {"destination": "a" * 64, "text": "Hello"}, + ) + + # Wait for processing + await asyncio.sleep(0.1) + + # Verify meshcore was called + mock_meshcore.send_message.assert_called() + + async def test_worker_processes_channel_message(self, queue_manager, mock_meshcore): + """Test worker processes channel messages.""" + await queue_manager.start() + + await queue_manager.enqueue( + CommandType.SEND_CHANNEL_MESSAGE, + {"text": "Broadcast"}, + ) + + await asyncio.sleep(0.1) + mock_meshcore.send_channel_message.assert_called() + + async def test_worker_processes_advert(self, queue_manager, mock_meshcore): + """Test worker processes advertisements.""" + await queue_manager.start() + + await queue_manager.enqueue( + CommandType.SEND_ADVERT, + {"flood": True}, + ) + + await asyncio.sleep(0.1) + mock_meshcore.send_advert.assert_called() + + async def test_worker_processes_trace_path(self, queue_manager, mock_meshcore): + """Test worker processes trace path commands.""" + await queue_manager.start() + + await queue_manager.enqueue( + CommandType.SEND_TRACE_PATH, + {"destination": "a" * 64}, + ) + + await asyncio.sleep(0.1) + mock_meshcore.send_trace_path.assert_called() + + async def test_worker_processes_ping(self, queue_manager, mock_meshcore): + """Test worker processes ping commands.""" + await queue_manager.start() + + await queue_manager.enqueue( + CommandType.PING, + {"destination": "a" * 64}, + ) + + await asyncio.sleep(0.1) + mock_meshcore.ping.assert_called() + + async def test_worker_processes_telemetry_request(self, queue_manager, mock_meshcore): + """Test worker processes telemetry request commands.""" + await queue_manager.start() + + await queue_manager.enqueue( + CommandType.SEND_TELEMETRY_REQUEST, + {"destination": "a" * 64}, + ) + + await asyncio.sleep(0.1) + mock_meshcore.send_telemetry_request.assert_called() + + async def test_worker_handles_command_error(self, mock_meshcore): + """Test worker handles command execution errors.""" + mock_meshcore.send_message = AsyncMock(side_effect=Exception("Network error")) + + manager = CommandQueueManager( + meshcore=mock_meshcore, + rate_limit_per_second=100.0, + debounce_enabled=False, + ) + await manager.start() + + await manager.enqueue( + CommandType.SEND_MESSAGE, + {"destination": "a" * 64, "text": "Hello"}, + ) + + # Should not crash, just log error + await asyncio.sleep(0.1) + await manager.stop() + + +@pytest.mark.asyncio +class TestCommandQueueManagerGetStats: + """Test get_stats method.""" + + async def test_get_stats(self, queue_manager): + """Test getting queue statistics.""" + stats = queue_manager.get_stats() + assert stats.queue_size == 0 + assert stats.max_queue_size == 10 + assert stats.commands_processed_total == 0 + assert stats.commands_dropped_total == 0 + assert stats.commands_debounced_total == 0 + + async def test_get_stats_after_enqueue(self, queue_manager): + """Test stats update after enqueueing.""" + await queue_manager.enqueue( + CommandType.SEND_MESSAGE, + {"destination": "a" * 64, "text": "Hello"}, + ) + + stats = queue_manager.get_stats() + assert stats.queue_size == 1 + + async def test_get_stats_after_processing(self, queue_manager): + """Test stats update after processing.""" + await queue_manager.start() + + await queue_manager.enqueue( + CommandType.SEND_MESSAGE, + {"destination": "a" * 64, "text": "Hello"}, + ) + + # Wait for processing + await asyncio.sleep(0.1) + + stats = queue_manager.get_stats() + assert stats.commands_processed_total >= 1 + + async def test_get_stats_debounced_count(self, queue_manager): + """Test debounced count in stats.""" + params = {"destination": "a" * 64, "text": "Hello"} + + await queue_manager.enqueue(CommandType.SEND_MESSAGE, params) + await queue_manager.enqueue(CommandType.SEND_MESSAGE, params) # Debounced + + stats = queue_manager.get_stats() + assert stats.commands_debounced_total == 1 + + +@pytest.mark.asyncio +class TestQueueFullError: + """Test QueueFullError exception.""" + + async def test_queue_full_error_message(self): + """Test QueueFullError has correct message.""" + error = QueueFullError("Test message") + assert str(error) == "Test message" + + async def test_queue_full_error_is_exception(self): + """Test QueueFullError is an Exception.""" + error = QueueFullError("Test") + assert isinstance(error, Exception) diff --git a/tests/unit/test_queue_models.py b/tests/unit/test_queue_models.py new file mode 100644 index 0000000..09202ed --- /dev/null +++ b/tests/unit/test_queue_models.py @@ -0,0 +1,269 @@ +"""Unit tests for queue data models.""" + +from datetime import datetime + +import pytest + +from meshcore_api.queue.models import ( + CommandResult, + CommandType, + QueuedCommand, + QueueFullBehavior, + QueueInfo, + QueueStats, +) + + +class TestCommandType: + """Test CommandType enum.""" + + def test_command_types_exist(self): + """Test all expected command types exist.""" + assert CommandType.SEND_MESSAGE.value == "send_message" + assert CommandType.SEND_CHANNEL_MESSAGE.value == "send_channel_message" + assert CommandType.SEND_ADVERT.value == "send_advert" + assert CommandType.SEND_TRACE_PATH.value == "send_trace_path" + assert CommandType.PING.value == "ping" + assert CommandType.SEND_TELEMETRY_REQUEST.value == "send_telemetry_request" + + def test_command_type_is_string_enum(self): + """Test CommandType inherits from str.""" + assert isinstance(CommandType.SEND_MESSAGE, str) + assert CommandType.SEND_MESSAGE == "send_message" + + +class TestQueueFullBehavior: + """Test QueueFullBehavior enum.""" + + def test_behaviors_exist(self): + """Test both queue full behaviors exist.""" + assert QueueFullBehavior.REJECT.value == "reject" + assert QueueFullBehavior.DROP_OLDEST.value == "drop_oldest" + + +class TestQueuedCommand: + """Test QueuedCommand dataclass.""" + + def test_create_queued_command(self): + """Test creating a QueuedCommand.""" + cmd = QueuedCommand( + command_type=CommandType.SEND_MESSAGE, + parameters={"destination": "abc123", "text": "Hello"}, + ) + assert cmd.command_type == CommandType.SEND_MESSAGE + assert cmd.parameters == {"destination": "abc123", "text": "Hello"} + assert cmd.request_id is not None + assert cmd.enqueued_at is not None + assert cmd.command_hash is None + + def test_queued_command_with_hash(self): + """Test QueuedCommand with command hash.""" + cmd = QueuedCommand( + command_type=CommandType.SEND_MESSAGE, + parameters={"text": "Hello"}, + command_hash="abc123hash", + ) + assert cmd.command_hash == "abc123hash" + + def test_queued_command_auto_request_id(self): + """Test QueuedCommand generates unique request IDs.""" + cmd1 = QueuedCommand(command_type=CommandType.PING, parameters={}) + cmd2 = QueuedCommand(command_type=CommandType.PING, parameters={}) + assert cmd1.request_id != cmd2.request_id + + def test_queued_command_to_dict(self): + """Test QueuedCommand serialization.""" + cmd = QueuedCommand( + command_type=CommandType.SEND_CHANNEL_MESSAGE, + parameters={"text": "Hello", "flood": False}, + ) + result = cmd.to_dict() + assert result["command_type"] == "send_channel_message" + assert result["parameters"] == {"text": "Hello", "flood": False} + assert "request_id" in result + assert "enqueued_at" in result + + def test_queued_command_to_dict_datetime_format(self): + """Test QueuedCommand to_dict produces ISO format datetime.""" + cmd = QueuedCommand(command_type=CommandType.PING, parameters={}) + result = cmd.to_dict() + # Should be ISO format string + assert isinstance(result["enqueued_at"], str) + # Should be parseable as ISO datetime + datetime.fromisoformat(result["enqueued_at"]) + + +class TestCommandResult: + """Test CommandResult dataclass.""" + + def test_create_success_result(self): + """Test creating a successful CommandResult.""" + result = CommandResult( + success=True, + message="Command executed successfully", + request_id="test-123", + ) + assert result.success is True + assert result.message == "Command executed successfully" + assert result.request_id == "test-123" + assert result.executed_at is not None + assert result.error is None + assert result.details is None + + def test_create_failure_result(self): + """Test creating a failed CommandResult.""" + result = CommandResult( + success=False, + message="Command failed", + request_id="test-456", + error="Connection timeout", + ) + assert result.success is False + assert result.error == "Connection timeout" + + def test_result_with_details(self): + """Test CommandResult with details.""" + result = CommandResult( + success=True, + message="Done", + request_id="test-789", + details={"hops": 3, "latency_ms": 150}, + ) + assert result.details == {"hops": 3, "latency_ms": 150} + + def test_command_result_to_dict_success(self): + """Test CommandResult to_dict for success.""" + result = CommandResult( + success=True, + message="Success", + request_id="test-1", + ) + d = result.to_dict() + assert d["success"] is True + assert d["message"] == "Success" + assert "error" not in d + + def test_command_result_to_dict_with_error(self): + """Test CommandResult to_dict includes error.""" + result = CommandResult( + success=False, + message="Failed", + request_id="test-2", + error="Network error", + ) + d = result.to_dict() + assert d["error"] == "Network error" + + def test_command_result_to_dict_with_details(self): + """Test CommandResult to_dict merges details.""" + result = CommandResult( + success=True, + message="Done", + request_id="test-3", + details={"extra_field": "value"}, + ) + d = result.to_dict() + assert d["extra_field"] == "value" + + +class TestQueueStats: + """Test QueueStats dataclass.""" + + def test_create_queue_stats(self): + """Test creating QueueStats.""" + stats = QueueStats( + queue_size=10, + max_queue_size=100, + rate_limit_tokens_available=4.5, + debounce_cache_size=50, + commands_processed_total=1000, + commands_dropped_total=5, + commands_debounced_total=200, + ) + assert stats.queue_size == 10 + assert stats.max_queue_size == 100 + assert stats.rate_limit_tokens_available == 4.5 + assert stats.debounce_cache_size == 50 + assert stats.commands_processed_total == 1000 + assert stats.commands_dropped_total == 5 + assert stats.commands_debounced_total == 200 + + def test_queue_stats_to_dict(self): + """Test QueueStats serialization.""" + stats = QueueStats( + queue_size=5, + max_queue_size=50, + rate_limit_tokens_available=2.333333, + debounce_cache_size=25, + commands_processed_total=500, + commands_dropped_total=2, + commands_debounced_total=100, + ) + d = stats.to_dict() + assert d["queue_size"] == 5 + assert d["max_queue_size"] == 50 + # Should be rounded to 2 decimal places + assert d["rate_limit_tokens_available"] == 2.33 + assert d["debounce_cache_size"] == 25 + assert d["commands_processed_total"] == 500 + assert d["commands_dropped_total"] == 2 + assert d["commands_debounced_total"] == 100 + + +class TestQueueInfo: + """Test QueueInfo dataclass.""" + + def test_create_queue_info(self): + """Test creating QueueInfo.""" + info = QueueInfo( + position=3, + estimated_wait_seconds=1.5, + queue_size=10, + ) + assert info.position == 3 + assert info.estimated_wait_seconds == 1.5 + assert info.queue_size == 10 + assert info.debounced is False + assert info.original_request_time is None + + def test_queue_info_debounced(self): + """Test QueueInfo for debounced command.""" + original_time = datetime.utcnow() + info = QueueInfo( + position=1, + estimated_wait_seconds=0.5, + queue_size=5, + debounced=True, + original_request_time=original_time, + ) + assert info.debounced is True + assert info.original_request_time == original_time + + def test_queue_info_to_dict(self): + """Test QueueInfo serialization.""" + info = QueueInfo( + position=2, + estimated_wait_seconds=1.2345, + queue_size=8, + ) + d = info.to_dict() + assert d["position"] == 2 + # Should be rounded + assert d["estimated_wait_seconds"] == 1.23 + assert d["queue_size"] == 8 + assert d["debounced"] is False + assert "original_request_time" not in d + + def test_queue_info_to_dict_with_original_time(self): + """Test QueueInfo serialization includes original time.""" + original_time = datetime(2024, 1, 15, 10, 30, 0) + info = QueueInfo( + position=1, + estimated_wait_seconds=0.5, + queue_size=5, + debounced=True, + original_request_time=original_time, + ) + d = info.to_dict() + assert d["debounced"] is True + assert d["original_request_time"] == "2024-01-15T10:30:00Z" diff --git a/tests/unit/test_webhook_handler.py b/tests/unit/test_webhook_handler.py new file mode 100644 index 0000000..27c444e --- /dev/null +++ b/tests/unit/test_webhook_handler.py @@ -0,0 +1,324 @@ +"""Unit tests for webhook handler.""" + +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from meshcore_api.webhook.handler import WebhookHandler + + +@pytest.mark.asyncio +class TestWebhookHandlerInit: + """Test WebhookHandler initialization.""" + + async def test_init_with_all_urls(self): + """Test initialization with all webhook URLs configured.""" + handler = WebhookHandler( + message_direct_url="https://example.com/direct", + message_channel_url="https://example.com/channel", + advertisement_url="https://example.com/advert", + ) + assert handler.message_direct_url == "https://example.com/direct" + assert handler.message_channel_url == "https://example.com/channel" + assert handler.advertisement_url == "https://example.com/advert" + assert handler.timeout == 5 + assert handler.retry_count == 3 + await handler.close() + + async def test_init_with_no_urls(self): + """Test initialization with no webhook URLs configured.""" + handler = WebhookHandler() + assert handler.message_direct_url is None + assert handler.message_channel_url is None + assert handler.advertisement_url is None + await handler.close() + + async def test_init_custom_timeout_and_retry(self): + """Test initialization with custom timeout and retry count.""" + handler = WebhookHandler(timeout=10, retry_count=5) + assert handler.timeout == 10 + assert handler.retry_count == 5 + await handler.close() + + async def test_init_event_url_mapping(self): + """Test event type to URL mapping is set up correctly.""" + handler = WebhookHandler( + message_direct_url="https://example.com/direct", + message_channel_url="https://example.com/channel", + advertisement_url="https://example.com/advert", + ) + assert handler._event_url_map["CONTACT_MSG_RECV"] == "https://example.com/direct" + assert handler._event_url_map["CHANNEL_MSG_RECV"] == "https://example.com/channel" + assert handler._event_url_map["ADVERTISEMENT"] == "https://example.com/advert" + assert handler._event_url_map["NEW_ADVERT"] == "https://example.com/advert" + await handler.close() + + async def test_init_jsonpath_expressions(self): + """Test JSONPath expressions are compiled correctly.""" + handler = WebhookHandler( + message_direct_jsonpath="$.data.text", + message_channel_jsonpath="$.data", + advertisement_jsonpath="$", + ) + assert "CONTACT_MSG_RECV" in handler._jsonpath_map + assert "CHANNEL_MSG_RECV" in handler._jsonpath_map + assert "ADVERTISEMENT" in handler._jsonpath_map + await handler.close() + + +@pytest.mark.asyncio +class TestWebhookHandlerCompileJsonpath: + """Test JSONPath compilation.""" + + async def test_compile_valid_jsonpath(self): + """Test compiling a valid JSONPath expression.""" + handler = WebhookHandler() + # Should have compiled the default "$" expressions + assert handler._jsonpath_map["CONTACT_MSG_RECV"] is not None + await handler.close() + + async def test_compile_invalid_jsonpath_falls_back_to_root(self): + """Test invalid JSONPath falls back to root expression.""" + # JSONPath that would cause parsing issues + handler = WebhookHandler(message_direct_jsonpath="[invalid") + # Should have fallen back to "$" (root expression) + assert handler._jsonpath_map["CONTACT_MSG_RECV"] is not None + await handler.close() + + +@pytest.mark.asyncio +class TestWebhookHandlerApplyJsonpath: + """Test JSONPath application.""" + + async def test_apply_jsonpath_root(self): + """Test applying root JSONPath returns full payload.""" + handler = WebhookHandler(message_direct_jsonpath="$") + payload = {"event_type": "CONTACT_MSG_RECV", "data": {"text": "Hello"}} + result = handler._apply_jsonpath("CONTACT_MSG_RECV", payload) + assert result == payload + await handler.close() + + async def test_apply_jsonpath_nested(self): + """Test applying nested JSONPath returns filtered data.""" + handler = WebhookHandler(message_direct_jsonpath="$.data.text") + payload = {"event_type": "CONTACT_MSG_RECV", "data": {"text": "Hello"}} + result = handler._apply_jsonpath("CONTACT_MSG_RECV", payload) + assert result == "Hello" + await handler.close() + + async def test_apply_jsonpath_no_matches(self): + """Test JSONPath with no matches returns full payload.""" + handler = WebhookHandler(message_direct_jsonpath="$.nonexistent") + payload = {"event_type": "CONTACT_MSG_RECV", "data": {"text": "Hello"}} + result = handler._apply_jsonpath("CONTACT_MSG_RECV", payload) + # Should return full payload when no matches + assert result == payload + await handler.close() + + async def test_apply_jsonpath_unknown_event_type(self): + """Test applying JSONPath for unknown event type.""" + handler = WebhookHandler() + payload = {"event_type": "UNKNOWN", "data": {}} + result = handler._apply_jsonpath("UNKNOWN", payload) + # Should return full payload for unknown types + assert result == payload + await handler.close() + + +@pytest.mark.asyncio +class TestWebhookHandlerSendEvent: + """Test send_event method.""" + + async def test_send_event_no_url_configured(self): + """Test send_event does nothing when no URL configured.""" + handler = WebhookHandler() + # Should not raise, just log and return + await handler.send_event("CONTACT_MSG_RECV", {"text": "Hello"}) + await handler.close() + + async def test_send_event_success(self): + """Test send_event successfully sends webhook.""" + handler = WebhookHandler(message_direct_url="https://example.com/direct") + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + + with patch.object(handler.client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + await handler.send_event("CONTACT_MSG_RECV", {"text": "Hello"}) + mock_post.assert_called_once() + + await handler.close() + + async def test_send_event_includes_timestamp(self): + """Test send_event includes timestamp in payload.""" + handler = WebhookHandler(message_direct_url="https://example.com/direct") + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + + with patch.object(handler.client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + await handler.send_event("CONTACT_MSG_RECV", {"text": "Hello"}) + + call_kwargs = mock_post.call_args + payload = call_kwargs.kwargs.get("json", {}) + assert "timestamp" in payload + assert payload["event_type"] == "CONTACT_MSG_RECV" + assert payload["data"] == {"text": "Hello"} + + await handler.close() + + +@pytest.mark.asyncio +class TestWebhookHandlerSendWebhook: + """Test _send_webhook method.""" + + async def test_send_webhook_success(self): + """Test successful webhook delivery.""" + handler = WebhookHandler() + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + + with patch.object(handler.client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + await handler._send_webhook("https://example.com/webhook", {"event": "test"}) + assert mock_post.call_count == 1 + + await handler.close() + + async def test_send_webhook_retry_on_http_error(self): + """Test webhook retries on HTTP errors.""" + handler = WebhookHandler(retry_count=2) + + mock_response = MagicMock() + mock_response.status_code = 500 + mock_response.raise_for_status = MagicMock( + side_effect=httpx.HTTPStatusError("Server Error", request=MagicMock(), response=mock_response) + ) + + with patch.object(handler.client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + with patch("asyncio.sleep", new_callable=AsyncMock): + await handler._send_webhook("https://example.com/webhook", {"event": "test"}) + # Should attempt 1 initial + 2 retries = 3 total + assert mock_post.call_count == 3 + + await handler.close() + + async def test_send_webhook_retry_on_timeout(self): + """Test webhook retries on timeout.""" + handler = WebhookHandler(retry_count=1) + + with patch.object(handler.client, "post", new_callable=AsyncMock) as mock_post: + mock_post.side_effect = httpx.TimeoutException("Timeout") + with patch("asyncio.sleep", new_callable=AsyncMock): + await handler._send_webhook("https://example.com/webhook", {"event": "test"}) + # Should attempt 1 initial + 1 retry = 2 total + assert mock_post.call_count == 2 + + await handler.close() + + async def test_send_webhook_no_retries(self): + """Test webhook with zero retry count only makes one attempt.""" + handler = WebhookHandler(retry_count=0) + + mock_response = MagicMock() + mock_response.status_code = 500 + mock_response.raise_for_status = MagicMock( + side_effect=httpx.HTTPStatusError("Server Error", request=MagicMock(), response=mock_response) + ) + + with patch.object(handler.client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + await handler._send_webhook("https://example.com/webhook", {"event": "test"}) + # Should only attempt once (no retries) + assert mock_post.call_count == 1 + + await handler.close() + + async def test_send_webhook_string_payload(self): + """Test sending string payload as plain text.""" + handler = WebhookHandler() + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + + with patch.object(handler.client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + await handler._send_webhook("https://example.com/webhook", "Hello World") + + call_kwargs = mock_post.call_args.kwargs + assert call_kwargs["content"] == "Hello World" + assert call_kwargs["headers"]["Content-Type"] == "text/plain" + + await handler.close() + + async def test_send_webhook_primitive_payload(self): + """Test sending primitive payload (number) as JSON.""" + handler = WebhookHandler() + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + + with patch.object(handler.client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + await handler._send_webhook("https://example.com/webhook", 42) + + call_kwargs = mock_post.call_args.kwargs + assert call_kwargs["content"] == "42" + assert call_kwargs["headers"]["Content-Type"] == "application/json" + + await handler.close() + + async def test_send_webhook_list_payload(self): + """Test sending list payload as JSON.""" + handler = WebhookHandler() + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + + with patch.object(handler.client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + await handler._send_webhook("https://example.com/webhook", [1, 2, 3]) + + call_kwargs = mock_post.call_args.kwargs + assert "json" in call_kwargs + assert call_kwargs["json"] == [1, 2, 3] + + await handler.close() + + async def test_send_webhook_generic_exception(self): + """Test webhook handles generic exceptions.""" + handler = WebhookHandler(retry_count=1) + + with patch.object(handler.client, "post", new_callable=AsyncMock) as mock_post: + mock_post.side_effect = Exception("Generic error") + with patch("asyncio.sleep", new_callable=AsyncMock): + await handler._send_webhook("https://example.com/webhook", {"event": "test"}) + # Should still retry + assert mock_post.call_count == 2 + + await handler.close() + + +@pytest.mark.asyncio +class TestWebhookHandlerClose: + """Test webhook handler close method.""" + + async def test_close(self): + """Test closing the HTTP client.""" + handler = WebhookHandler() + await handler.close() + # Should complete without error + assert True