diff --git a/roborock/devices/README.md b/roborock/devices/README.md index ff875076..f5add151 100644 --- a/roborock/devices/README.md +++ b/roborock/devices/README.md @@ -490,7 +490,7 @@ For each V1 command: | **RPC Abstraction** | `RpcChannel` with strategies | Helper functions | | **Strategy Pattern** | ✅ Multi-strategy (Local → MQTT) | ❌ Direct MQTT only | | **Health Manager** | ✅ Tracks local/MQTT health | ❌ Not needed | -| **Code Location** | `v1_channel.py` | `a01_channel.py`, `b01_channel.py` | +| **Code Location** | `v1_channel.py` | `a01_channel.py`, `b01_q7_channel.py` | #### Health Management (V1 Only) @@ -572,7 +572,7 @@ roborock/ │ ├── local_channel.py # Local TCP channel implementation │ ├── v1_channel.py # V1 protocol channel with RPC strategies │ ├── a01_channel.py # A01 protocol helpers -│ ├── b01_channel.py # B01 protocol helpers +│ ├── b01_q7_channel.py # B01 Q7 protocol helpers │ └── traits/ # Device-specific command traits │ └── v1/ # V1 device traits │ ├── __init__.py # Trait initialization @@ -585,7 +585,7 @@ roborock/ ├── protocols/ # Protocol encoders/decoders │ ├── v1_protocol.py # V1 JSON RPC protocol │ ├── a01_protocol.py # A01 protocol -│ ├── b01_protocol.py # B01 protocol +│ ├── b01_q7_protocol.py # B01 Q7 protocol │ └── ... └── data/ # Data containers and mappings ├── containers.py # Status, HomeData, etc. diff --git a/roborock/devices/b01_channel.py b/roborock/devices/b01_q7_channel.py similarity index 64% rename from roborock/devices/b01_channel.py rename to roborock/devices/b01_q7_channel.py index 0c9e06d5..55e53dbc 100644 --- a/roborock/devices/b01_channel.py +++ b/roborock/devices/b01_q7_channel.py @@ -8,14 +8,12 @@ from typing import Any from roborock.exceptions import RoborockException -from roborock.protocols.b01_protocol import ( - CommandType, - ParamsType, +from roborock.protocols.b01_q7_protocol import ( + Q7RequestMessage, decode_rpc_response, encode_mqtt_payload, ) from roborock.roborock_message import RoborockMessage -from roborock.util import get_next_int from .mqtt_channel import MqttChannel @@ -25,20 +23,11 @@ async def send_decoded_command( mqtt_channel: MqttChannel, - dps: int, - command: CommandType, - params: ParamsType, + request_message: Q7RequestMessage, ) -> dict[str, Any] | None: """Send a command on the MQTT channel and get a decoded response.""" - msg_id = str(get_next_int(100000000000, 999999999999)) - _LOGGER.debug( - "Sending B01 MQTT command: dps=%s method=%s msg_id=%s params=%s", - dps, - command, - msg_id, - params, - ) - roborock_message = encode_mqtt_payload(dps, command, params, msg_id) + _LOGGER.debug("Sending B01 MQTT command: %s", request_message) + roborock_message = encode_mqtt_payload(request_message) future: asyncio.Future[Any] = asyncio.get_running_loop().create_future() def find_response(response_message: RoborockMessage) -> None: @@ -48,13 +37,12 @@ def find_response(response_message: RoborockMessage) -> None: except RoborockException as ex: _LOGGER.debug( "Failed to decode B01 RPC response (expecting method=%s msg_id=%s): %s: %s", - command, - msg_id, + request_message.command, + request_message.msg_id, response_message, ex, ) return - for dps_value in decoded_dps.values(): # valid responses are JSON strings wrapped in the dps value if not isinstance(dps_value, str): @@ -66,29 +54,22 @@ def find_response(response_message: RoborockMessage) -> None: except (json.JSONDecodeError, TypeError): _LOGGER.debug("Received unexpected response: %s", dps_value) continue - - if isinstance(inner, dict) and inner.get("msgId") == msg_id: + if isinstance(inner, dict) and inner.get("msgId") == str(request_message.msg_id): _LOGGER.debug("Received query response: %s", inner) # Check for error code (0 = success, non-zero = error) code = inner.get("code", 0) if code != 0: - error_msg = ( - f"B01 command failed with code {code} " - f"(method={command}, msg_id={msg_id}, dps={dps}, params={params})" - ) + error_msg = f"B01 command failed with code {code} ({request_message})" _LOGGER.debug("B01 error response: %s", error_msg) if not future.done(): future.set_exception(RoborockException(error_msg)) return data = inner.get("data") # All get commands should be dicts - if command.endswith(".get") and not isinstance(data, dict): + if request_message.command.endswith(".get") and not isinstance(data, dict): if not future.done(): future.set_exception( - RoborockException( - f"Unexpected data type for response " - f"(method={command}, msg_id={msg_id}, dps={dps}, params={params})" - ) + RoborockException(f"Unexpected data type for response {data} ({request_message})") ) return if not future.done(): @@ -101,27 +82,19 @@ def find_response(response_message: RoborockMessage) -> None: await mqtt_channel.publish(roborock_message) return await asyncio.wait_for(future, timeout=_TIMEOUT) except TimeoutError as ex: - raise RoborockException( - f"B01 command timed out after {_TIMEOUT}s (method={command}, msg_id={msg_id}, dps={dps}, params={params})" - ) from ex + raise RoborockException(f"B01 command timed out after {_TIMEOUT}s ({request_message})") from ex except RoborockException as ex: _LOGGER.warning( - "Error sending B01 decoded command (method=%s msg_id=%s dps=%s params=%s): %s", - command, - msg_id, - dps, - params, + "Error sending B01 decoded command (%ss): %s", + request_message, ex, ) raise except Exception as ex: _LOGGER.exception( - "Error sending B01 decoded command (method=%s msg_id=%s dps=%s params=%s): %s", - command, - msg_id, - dps, - params, + "Error sending B01 decoded command (%ss): %s", + request_message, ex, ) raise diff --git a/roborock/devices/traits/b01/q7/__init__.py b/roborock/devices/traits/b01/q7/__init__.py index 2bfa0a6d..47a4c6c2 100644 --- a/roborock/devices/traits/b01/q7/__init__.py +++ b/roborock/devices/traits/b01/q7/__init__.py @@ -10,9 +10,10 @@ SCWindMapping, WaterLevelMapping, ) -from roborock.devices.b01_channel import CommandType, ParamsType, send_decoded_command +from roborock.devices.b01_q7_channel import send_decoded_command from roborock.devices.mqtt_channel import MqttChannel from roborock.devices.traits import Trait +from roborock.protocols.b01_q7_protocol import CommandType, ParamsType, Q7RequestMessage from roborock.roborock_message import RoborockB01Props from roborock.roborock_typing import RoborockB01Q7Methods @@ -104,9 +105,7 @@ async def send(self, command: CommandType, params: ParamsType) -> Any: """Send a command to the device.""" return await send_decoded_command( self._channel, - dps=10000, - command=command, - params=params, + Q7RequestMessage(dps=10000, command=command, params=params), ) diff --git a/roborock/protocols/b01_protocol.py b/roborock/protocols/b01_q7_protocol.py similarity index 73% rename from roborock/protocols/b01_protocol.py rename to roborock/protocols/b01_q7_protocol.py index eeaad03a..06c11143 100644 --- a/roborock/protocols/b01_protocol.py +++ b/roborock/protocols/b01_q7_protocol.py @@ -2,6 +2,7 @@ import json import logging +from dataclasses import dataclass, field from typing import Any from Crypto.Cipher import AES @@ -13,6 +14,7 @@ RoborockMessage, RoborockMessageProtocol, ) +from roborock.util import get_next_int _LOGGER = logging.getLogger(__name__) @@ -21,20 +23,32 @@ ParamsType = list | dict | int | None -def encode_mqtt_payload(dps: int, command: CommandType, params: ParamsType, msg_id: str) -> RoborockMessage: - """Encode payload for B01 commands over MQTT.""" - dps_data = { - "dps": { - dps: { - "method": str(command), - "msgId": msg_id, +@dataclass +class Q7RequestMessage: + """Data class for B01 Q7 request message.""" + + dps: int + command: CommandType + params: ParamsType + msg_id: int = field(default_factory=lambda: get_next_int(100000000000, 999999999999)) + + def to_dps_value(self) -> dict[int, Any]: + """Return the 'dps' payload dictionary.""" + return { + self.dps: { + "method": str(self.command), + "msgId": str(self.msg_id), # Important: some B01 methods use an empty object `{}` (not `[]`) for # "no params", and some setters legitimately send `0` which is falsy. # Only default to `[]` when params is actually None. - "params": params if params is not None else [], + "params": self.params if self.params is not None else [], } } - } + + +def encode_mqtt_payload(request: Q7RequestMessage) -> RoborockMessage: + """Encode payload for B01 commands over MQTT.""" + dps_data = {"dps": request.to_dps_value()} payload = pad(json.dumps(dps_data).encode("utf-8"), AES.block_size) return RoborockMessage( protocol=RoborockMessageProtocol.RPC_REQUEST, diff --git a/tests/devices/traits/b01/q7/__init__.py b/tests/devices/traits/b01/q7/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/devices/traits/b01/test_init.py b/tests/devices/traits/b01/q7/test_init.py similarity index 53% rename from tests/devices/traits/b01/test_init.py rename to tests/devices/traits/b01/q7/test_init.py index 1c3c8745..39acbb53 100644 --- a/tests/devices/traits/b01/test_init.py +++ b/tests/devices/traits/b01/q7/test_init.py @@ -1,4 +1,7 @@ import json +import math +import time +from collections.abc import Generator from typing import Any from unittest.mock import patch @@ -13,35 +16,44 @@ WaterLevelMapping, WorkStatusMapping, ) -from roborock.devices.b01_channel import send_decoded_command +from roborock.devices.b01_q7_channel import send_decoded_command from roborock.devices.traits.b01.q7 import Q7PropertiesApi from roborock.exceptions import RoborockException -from roborock.protocols.b01_protocol import B01_VERSION +from roborock.protocols.b01_q7_protocol import B01_VERSION, Q7RequestMessage from roborock.roborock_message import RoborockB01Props, RoborockMessage, RoborockMessageProtocol from tests.fixtures.channel_fixtures import FakeChannel -def build_b01_message(message: dict[Any, Any], msg_id: str = "123456789", seq: int = 2020) -> RoborockMessage: - """Build an encoded B01 RPC response message.""" - dps_payload = { - "dps": { - "10000": json.dumps( - { - "msgId": msg_id, - "data": message, - } - ) +class B01MessageBuilder: + """Helper class to build B01 RPC response messages for tests.""" + + def __init__(self) -> None: + self.msg_id = 123456789 + self.seq = 2020 + + def build(self, data: dict[str, Any] | str, code: int | None = None) -> RoborockMessage: + """Build an encoded B01 RPC response message.""" + message: dict[str, Any] = { + "msgId": str(self.msg_id), + "data": data, } - } - return RoborockMessage( - protocol=RoborockMessageProtocol.RPC_RESPONSE, - payload=pad( - json.dumps(dps_payload).encode(), - AES.block_size, - ), - version=b"B01", - seq=seq, - ) + if code is not None: + message["code"] = code + return self._build_dps(message) + + def _build_dps(self, message: dict[str, Any] | str) -> RoborockMessage: + """Build an encoded B01 RPC response message.""" + dps_payload = {"dps": {"10000": json.dumps(message)}} + self.seq += 1 + return RoborockMessage( + protocol=RoborockMessageProtocol.RPC_RESPONSE, + payload=pad( + json.dumps(dps_payload).encode(), + AES.block_size, + ), + version=b"B01", + seq=self.seq, + ) @pytest.fixture(name="fake_channel") @@ -54,10 +66,31 @@ def q7_api_fixture(fake_channel: FakeChannel) -> Q7PropertiesApi: return Q7PropertiesApi(fake_channel) # type: ignore[arg-type] -async def test_q7_api_query_values(q7_api: Q7PropertiesApi, fake_channel: FakeChannel): - """Test that Q7PropertiesApi correctly converts raw values.""" - expected_msg_id = "123456789" +@pytest.fixture(name="expected_msg_id", autouse=True) +def next_message_id_fixture() -> Generator[int, None, None]: + """Fixture to patch get_next_int to return the expected message ID. + + We pick an arbitrary number, but just need it to ensure we can craft a fake + response with the message id matched to the outgoing RPC. + """ + expected_msg_id = math.floor(time.time()) + + # Patch get_next_int to return our expected msg_id so the channel waits for it + with patch("roborock.protocols.b01_q7_protocol.get_next_int", return_value=expected_msg_id): + yield expected_msg_id + + +@pytest.fixture(name="message_builder") +def message_builder_fixture(expected_msg_id: int) -> B01MessageBuilder: + builder = B01MessageBuilder() + builder.msg_id = expected_msg_id + return builder + +async def test_q7_api_query_values( + q7_api: Q7PropertiesApi, fake_channel: FakeChannel, message_builder: B01MessageBuilder +): + """Test that Q7PropertiesApi correctly converts raw values.""" # We need to construct the expected result based on the mappings # status: 1 -> WAITING_FOR_ORDERS # wind: 1 -> STANDARD @@ -67,17 +100,15 @@ async def test_q7_api_query_values(q7_api: Q7PropertiesApi, fake_channel: FakeCh "battery": 100, } - # Patch get_next_int to return our expected msg_id so the channel waits for it - with patch("roborock.devices.b01_channel.get_next_int", return_value=int(expected_msg_id)): - # Queue the response - fake_channel.response_queue.append(build_b01_message(response_data, msg_id=expected_msg_id)) - - result = await q7_api.query_values( - [ - RoborockB01Props.STATUS, - RoborockB01Props.WIND, - ] - ) + # Queue the response + fake_channel.response_queue.append(message_builder.build(response_data)) + + result = await q7_api.query_values( + [ + RoborockB01Props.STATUS, + RoborockB01Props.WIND, + ] + ) assert result is not None assert result.status == WorkStatusMapping.WAITING_FOR_ORDERS @@ -102,7 +133,7 @@ async def test_q7_api_query_values(q7_api: Q7PropertiesApi, fake_channel: FakeCh assert "10000" in payload_data["dps"] inner = payload_data["dps"]["10000"] assert inner["method"] == "prop.get" - assert inner["msgId"] == expected_msg_id + assert inner["msgId"] == str(message_builder.msg_id) assert inner["params"] == {"property": [RoborockB01Props.STATUS, RoborockB01Props.WIND]} @@ -127,90 +158,42 @@ async def test_q7_response_value_mapping( expected_status: WorkStatusMapping, q7_api: Q7PropertiesApi, fake_channel: FakeChannel, + message_builder: B01MessageBuilder, ): """Test Q7PropertiesApi value mapping for different statuses.""" - msg_id = "987654321" - - with patch("roborock.devices.b01_channel.get_next_int", return_value=int(msg_id)): - fake_channel.response_queue.append(build_b01_message(response_data, msg_id=msg_id)) + fake_channel.response_queue.append(message_builder.build(response_data)) - result = await q7_api.query_values(query) + result = await q7_api.query_values(query) assert result is not None -async def test_send_decoded_command_non_dict_response(fake_channel: FakeChannel): +async def test_send_decoded_command_non_dict_response(fake_channel: FakeChannel, message_builder: B01MessageBuilder): """Test validity of handling non-dict responses (should not timeout).""" - msg_id = "123456789" - - dps_payload = { - "dps": { - "10000": json.dumps( - { - "msgId": msg_id, - "data": "some_string_error", - } - ) - } - } - message = RoborockMessage( - protocol=RoborockMessageProtocol.RPC_RESPONSE, - payload=pad( - json.dumps(dps_payload).encode(), - AES.block_size, - ), - version=b"B01", - seq=2021, - ) - + message = message_builder.build("some_string_error") fake_channel.response_queue.append(message) - with patch("roborock.devices.b01_channel.get_next_int", return_value=int(msg_id)): - # Use a random string for command type to avoid needing import + # Use a random string for command type to avoid needing import - with pytest.raises(RoborockException, match="Unexpected data type for response"): - await send_decoded_command(fake_channel, 10000, "prop.get", []) # type: ignore[arg-type] + with pytest.raises(RoborockException, match="Unexpected data type for response"): + await send_decoded_command(fake_channel, Q7RequestMessage(dps=10000, command="prop.get", params=[])) # type: ignore[arg-type] -async def test_send_decoded_command_error_code(fake_channel: FakeChannel): +async def test_send_decoded_command_error_code(fake_channel: FakeChannel, message_builder: B01MessageBuilder): """Test that non-zero error codes from device are properly handled.""" - msg_id = "999888777" - error_code = 5001 - - dps_payload = { - "dps": { - "10000": json.dumps( - { - "msgId": msg_id, - "code": error_code, - "data": {}, - } - ) - } - } - message = RoborockMessage( - protocol=RoborockMessageProtocol.RPC_RESPONSE, - payload=pad( - json.dumps(dps_payload).encode(), - AES.block_size, - ), - version=b"B01", - seq=2022, - ) - + message = message_builder.build({}, code=5001) fake_channel.response_queue.append(message) - with patch("roborock.devices.b01_channel.get_next_int", return_value=int(msg_id)): - with pytest.raises(RoborockException, match=f"B01 command failed with code {error_code}"): - await send_decoded_command(fake_channel, 10000, "prop.get", []) # type: ignore[arg-type] + with pytest.raises(RoborockException, match="B01 command failed with code 5001"): + await send_decoded_command(fake_channel, Q7RequestMessage(dps=10000, command="prop.get", params=[])) # type: ignore[arg-type] -async def test_q7_api_set_fan_speed(q7_api: Q7PropertiesApi, fake_channel: FakeChannel): +async def test_q7_api_set_fan_speed( + q7_api: Q7PropertiesApi, fake_channel: FakeChannel, message_builder: B01MessageBuilder +): """Test setting fan speed.""" - msg_id = "12345" - with patch("roborock.devices.b01_channel.get_next_int", return_value=int(msg_id)): - fake_channel.response_queue.append(build_b01_message({"result": "ok"}, msg_id=msg_id)) - await q7_api.set_fan_speed(SCWindMapping.STRONG) + fake_channel.response_queue.append(message_builder.build({"result": "ok"})) + await q7_api.set_fan_speed(SCWindMapping.STRONG) assert len(fake_channel.published_messages) == 1 message = fake_channel.published_messages[0] @@ -219,12 +202,12 @@ async def test_q7_api_set_fan_speed(q7_api: Q7PropertiesApi, fake_channel: FakeC assert payload_data["dps"]["10000"]["params"] == {RoborockB01Props.WIND: SCWindMapping.STRONG.code} -async def test_q7_api_set_water_level(q7_api: Q7PropertiesApi, fake_channel: FakeChannel): +async def test_q7_api_set_water_level( + q7_api: Q7PropertiesApi, fake_channel: FakeChannel, message_builder: B01MessageBuilder +): """Test setting water level.""" - msg_id = "12346" - with patch("roborock.devices.b01_channel.get_next_int", return_value=int(msg_id)): - fake_channel.response_queue.append(build_b01_message({"result": "ok"}, msg_id=msg_id)) - await q7_api.set_water_level(WaterLevelMapping.HIGH) + fake_channel.response_queue.append(message_builder.build({"result": "ok"})) + await q7_api.set_water_level(WaterLevelMapping.HIGH) assert len(fake_channel.published_messages) == 1 message = fake_channel.published_messages[0] @@ -233,12 +216,12 @@ async def test_q7_api_set_water_level(q7_api: Q7PropertiesApi, fake_channel: Fak assert payload_data["dps"]["10000"]["params"] == {RoborockB01Props.WATER: WaterLevelMapping.HIGH.code} -async def test_q7_api_start_clean(q7_api: Q7PropertiesApi, fake_channel: FakeChannel): +async def test_q7_api_start_clean( + q7_api: Q7PropertiesApi, fake_channel: FakeChannel, message_builder: B01MessageBuilder +): """Test starting cleaning.""" - msg_id = "12347" - with patch("roborock.devices.b01_channel.get_next_int", return_value=int(msg_id)): - fake_channel.response_queue.append(build_b01_message({"result": "ok"}, msg_id=msg_id)) - await q7_api.start_clean() + fake_channel.response_queue.append(message_builder.build({"result": "ok"})) + await q7_api.start_clean() assert len(fake_channel.published_messages) == 1 message = fake_channel.published_messages[0] @@ -251,12 +234,12 @@ async def test_q7_api_start_clean(q7_api: Q7PropertiesApi, fake_channel: FakeCha } -async def test_q7_api_pause_clean(q7_api: Q7PropertiesApi, fake_channel: FakeChannel): +async def test_q7_api_pause_clean( + q7_api: Q7PropertiesApi, fake_channel: FakeChannel, message_builder: B01MessageBuilder +): """Test pausing cleaning.""" - msg_id = "12348" - with patch("roborock.devices.b01_channel.get_next_int", return_value=int(msg_id)): - fake_channel.response_queue.append(build_b01_message({"result": "ok"}, msg_id=msg_id)) - await q7_api.pause_clean() + fake_channel.response_queue.append(message_builder.build({"result": "ok"})) + await q7_api.pause_clean() assert len(fake_channel.published_messages) == 1 message = fake_channel.published_messages[0] @@ -269,12 +252,12 @@ async def test_q7_api_pause_clean(q7_api: Q7PropertiesApi, fake_channel: FakeCha } -async def test_q7_api_stop_clean(q7_api: Q7PropertiesApi, fake_channel: FakeChannel): +async def test_q7_api_stop_clean( + q7_api: Q7PropertiesApi, fake_channel: FakeChannel, message_builder: B01MessageBuilder +): """Test stopping cleaning.""" - msg_id = "12349" - with patch("roborock.devices.b01_channel.get_next_int", return_value=int(msg_id)): - fake_channel.response_queue.append(build_b01_message({"result": "ok"}, msg_id=msg_id)) - await q7_api.stop_clean() + fake_channel.response_queue.append(message_builder.build({"result": "ok"})) + await q7_api.stop_clean() assert len(fake_channel.published_messages) == 1 message = fake_channel.published_messages[0] @@ -287,12 +270,12 @@ async def test_q7_api_stop_clean(q7_api: Q7PropertiesApi, fake_channel: FakeChan } -async def test_q7_api_return_to_dock(q7_api: Q7PropertiesApi, fake_channel: FakeChannel): +async def test_q7_api_return_to_dock( + q7_api: Q7PropertiesApi, fake_channel: FakeChannel, message_builder: B01MessageBuilder +): """Test returning to dock.""" - msg_id = "12350" - with patch("roborock.devices.b01_channel.get_next_int", return_value=int(msg_id)): - fake_channel.response_queue.append(build_b01_message({"result": "ok"}, msg_id=msg_id)) - await q7_api.return_to_dock() + fake_channel.response_queue.append(message_builder.build({"result": "ok"})) + await q7_api.return_to_dock() assert len(fake_channel.published_messages) == 1 message = fake_channel.published_messages[0] @@ -301,12 +284,10 @@ async def test_q7_api_return_to_dock(q7_api: Q7PropertiesApi, fake_channel: Fake assert payload_data["dps"]["10000"]["params"] == {} -async def test_q7_api_find_me(q7_api: Q7PropertiesApi, fake_channel: FakeChannel): +async def test_q7_api_find_me(q7_api: Q7PropertiesApi, fake_channel: FakeChannel, message_builder: B01MessageBuilder): """Test locating the device.""" - msg_id = "12351" - with patch("roborock.devices.b01_channel.get_next_int", return_value=int(msg_id)): - fake_channel.response_queue.append(build_b01_message({"result": "ok"}, msg_id=msg_id)) - await q7_api.find_me() + fake_channel.response_queue.append(message_builder.build({"result": "ok"})) + await q7_api.find_me() assert len(fake_channel.published_messages) == 1 message = fake_channel.published_messages[0] diff --git a/tests/protocols/test_b01_q07_protocol.py b/tests/protocols/test_b01_q07_protocol.py index db47bc2f..a507fa41 100644 --- a/tests/protocols/test_b01_q07_protocol.py +++ b/tests/protocols/test_b01_q07_protocol.py @@ -10,14 +10,11 @@ from freezegun import freeze_time from syrupy import SnapshotAssertion -from roborock.protocols.b01_protocol import ( - decode_rpc_response, - encode_mqtt_payload, -) +from roborock.protocols.b01_q7_protocol import Q7RequestMessage, decode_rpc_response, encode_mqtt_payload from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol -TESTDATA_PATH = pathlib.Path("tests/protocols/testdata/b01_protocol") -TESTDATA_FILES = list(TESTDATA_PATH.glob("**/*.json")) +TESTDATA_PATH = pathlib.Path("tests/protocols/testdata/b01_q7_protocol") +TESTDATA_FILES = list(TESTDATA_PATH.glob("*.json")) TESTDATA_IDS = [x.stem for x in TESTDATA_FILES] @@ -54,14 +51,14 @@ def test_decode_rpc_payload(filename: str, snapshot: SnapshotAssertion) -> None: 10000, "prop.get", {"property": ["status", "fault"]}, - "123456789", + 123456789, ), ], ) -def test_encode_mqtt_payload(dps: int, command: str, params: dict[str, list[str]], msg_id: str) -> None: +def test_encode_mqtt_payload(dps: int, command: str, params: dict[str, list[str]], msg_id: int) -> None: """Test encoding of MQTT payload for B01 commands.""" - message = encode_mqtt_payload(dps, command, params, msg_id) + message = encode_mqtt_payload(Q7RequestMessage(dps, command, params, msg_id)) assert isinstance(message, RoborockMessage) assert message.protocol == RoborockMessageProtocol.RPC_REQUEST assert message.version == b"B01" @@ -70,5 +67,5 @@ def test_encode_mqtt_payload(dps: int, command: str, params: dict[str, list[str] decoded_json = json.loads(unpadded.decode("utf-8")) assert decoded_json["dps"][str(dps)]["method"] == command - assert decoded_json["dps"][str(dps)]["msgId"] == msg_id + assert decoded_json["dps"][str(dps)]["msgId"] == str(msg_id) assert decoded_json["dps"][str(dps)]["params"] == params diff --git a/tests/protocols/testdata/b01_protocol/q7/get_prop.json b/tests/protocols/testdata/b01_q7_protocol/get_prop.json similarity index 100% rename from tests/protocols/testdata/b01_protocol/q7/get_prop.json rename to tests/protocols/testdata/b01_q7_protocol/get_prop.json