From 55e75d67770c747b8d0a89ca5a769b708359830e Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Fri, 26 Dec 2025 20:57:46 -0800 Subject: [PATCH 1/6] feat: Add initial Q10 support --- roborock/cli.py | 38 +++++-- roborock/data/code_mappings.py | 17 +++ roborock/devices/b01_q10_channel.py | 63 +++++++++++ .../{b01_channel.py => b01_q7_channel.py} | 2 +- roborock/devices/device.py | 14 ++- roborock/devices/device_manager.py | 4 +- roborock/devices/mqtt_channel.py | 18 ++- roborock/devices/traits/b01/__init__.py | 8 +- roborock/devices/traits/b01/q10/__init__.py | 105 +++++++++++++++++- roborock/devices/traits/b01/q7/__init__.py | 2 +- roborock/devices/traits/traits_mixin.py | 3 + roborock/protocols/b01_q10_protocol.py | 64 +++++++++++ .../{b01_protocol.py => b01_q7_protocol.py} | 0 tests/data/test_code_mappings.py | 27 +++++ tests/devices/test_mqtt_channel.py | 31 ++++++ tests/devices/traits/b01/q10/__init__.py | 0 tests/devices/traits/b01/q10/test_init.py | 79 +++++++++++++ tests/devices/traits/b01/q7/__init__.py | 0 .../devices/traits/b01/{ => q7}/test_init.py | 26 ++--- tests/fixtures/channel_fixtures.py | 14 ++- .../__snapshots__/test_b01_q10_protocol.ambr | 90 +++++++++++++++ tests/protocols/test_b01_q07_protocol.py | 6 +- tests/protocols/test_b01_q10_protocol.py | 69 ++++++++++++ .../testdata/b01_protocol/q10/dpBattery.json | 1 + .../testdata/b01_protocol/q10/dpFault.json | 1 + .../b01_protocol/q10/dpRequetdps.json | 1 + .../q10/dpStatus-dpCleanTaskType.json | 1 + 27 files changed, 647 insertions(+), 37 deletions(-) create mode 100644 roborock/devices/b01_q10_channel.py rename roborock/devices/{b01_channel.py => b01_q7_channel.py} (98%) create mode 100644 roborock/protocols/b01_q10_protocol.py rename roborock/protocols/{b01_protocol.py => b01_q7_protocol.py} (100%) create mode 100644 tests/data/test_code_mappings.py create mode 100644 tests/devices/traits/b01/q10/__init__.py create mode 100644 tests/devices/traits/b01/q10/test_init.py create mode 100644 tests/devices/traits/b01/q7/__init__.py rename tests/devices/traits/b01/{ => q7}/test_init.py (90%) create mode 100644 tests/protocols/__snapshots__/test_b01_q10_protocol.ambr create mode 100644 tests/protocols/test_b01_q10_protocol.py create mode 100644 tests/protocols/testdata/b01_protocol/q10/dpBattery.json create mode 100644 tests/protocols/testdata/b01_protocol/q10/dpFault.json create mode 100644 tests/protocols/testdata/b01_protocol/q10/dpRequetdps.json create mode 100644 tests/protocols/testdata/b01_protocol/q10/dpStatus-dpCleanTaskType.json diff --git a/roborock/cli.py b/roborock/cli.py index b043516d..8c1810c5 100644 --- a/roborock/cli.py +++ b/roborock/cli.py @@ -43,6 +43,7 @@ from roborock import SHORT_MODEL_TO_ENUM, RoborockCommand from roborock.data import DeviceData, RoborockBase, UserData +from roborock.data.b01_q10.b01_q10_code_mappings import B01_Q10_DP from roborock.device_features import DeviceFeatures from roborock.devices.cache import Cache, CacheData from roborock.devices.device import RoborockDevice @@ -91,7 +92,12 @@ def wrapper(*args, **kwargs): context: RoborockContext = ctx.obj async def run(): - return await func(*args, **kwargs) + try: + await func(*args, **kwargs) + except Exception: + _LOGGER.exception("Uncaught exception in command") + click.echo(f"Error: {sys.exc_info()[1]}", err=True) + await context.cleanup() if context.is_session_mode(): # Session mode - run in the persistent loop @@ -739,6 +745,16 @@ async def network_info(ctx, device_id: str): await _display_v1_trait(context, device_id, lambda v1: v1.network_info) +def _parse_b01_q10_command(cmd: str) -> B01_Q10_DP | None: + """Parse B01_Q10 command from either enum name or value.""" + for func in (B01_Q10_DP.from_code, B01_Q10_DP.from_name, B01_Q10_DP.from_value): + try: + return func(cmd) + except ValueError: + continue + return None + + @click.command() @click.option("--device_id", required=True) @click.option("--cmd", required=True) @@ -749,12 +765,20 @@ async def command(ctx, cmd, device_id, params): context: RoborockContext = ctx.obj device_manager = await context.get_device_manager() device = await device_manager.get_device(device_id) - if device.v1_properties is None: - raise RoborockException(f"Device {device.name} does not support V1 protocol") - command_trait: Trait = device.v1_properties.command - result = await command_trait.send(cmd, json.loads(params) if params is not None else None) - if result: - click.echo(dump_json(result)) + if device.v1_properties is not None: + command_trait: Trait = device.v1_properties.command + result = await command_trait.send(cmd, json.loads(params) if params is not None else {}) + if result: + click.echo(dump_json(result)) + elif device.b01_q10_properties is not None: + # Parse B01_Q10_DP from either enum name or the value + if (cmd_value := _parse_b01_q10_command(cmd)) is None: + raise RoborockException(f"Invalid command {cmd} for B01_Q10 device") + await device.b01_q10_properties.send(cmd_value, json.loads(params) if params is not None else {}) + # B10 Commands don't have a specific time to respond, so wait a bit + await asyncio.sleep(5) + else: + raise RoborockException(f"Device {device.name} does not support sending raw commands") @click.command() diff --git a/roborock/data/code_mappings.py b/roborock/data/code_mappings.py index fc5fb0cb..db09a5d5 100644 --- a/roborock/data/code_mappings.py +++ b/roborock/data/code_mappings.py @@ -65,11 +65,28 @@ def __new__(cls, value: str, code: int) -> RoborockModeEnum: @classmethod def from_code(cls, code: int): + """Find enum member by code.""" for member in cls: if member.code == code: return member raise ValueError(f"{code} is not a valid code for {cls.__name__}") + @classmethod + def from_name(cls, name: str): + """Find enum member by name (case-insensitive).""" + for member in cls: + if member.name.lower() == name.lower(): + return member + raise ValueError(f"{name} is not a valid name for {cls.__name__}") + + @classmethod + def from_value(cls, value: str): + """Find enum member by value (case-insensitive).""" + for member in cls: + if member.value.lower() == value.lower(): + return member + raise ValueError(f"{value} is not a valid value for {cls.__name__}") + @classmethod def keys(cls) -> list[str]: """Returns a list of all member values.""" diff --git a/roborock/devices/b01_q10_channel.py b/roborock/devices/b01_q10_channel.py new file mode 100644 index 00000000..7970f9a1 --- /dev/null +++ b/roborock/devices/b01_q10_channel.py @@ -0,0 +1,63 @@ +"""Thin wrapper around the MQTT channel for Roborock B01 devices.""" + +from __future__ import annotations + +import logging +from collections.abc import AsyncGenerator +from typing import Any + +from roborock.data.b01_q10.b01_q10_code_mappings import B01_Q10_DP +from roborock.exceptions import RoborockException +from roborock.protocols.b01_q10_protocol import ( + ParamsType, + decode_rpc_response, + encode_mqtt_payload, +) + +from .mqtt_channel import MqttChannel + +_LOGGER = logging.getLogger(__name__) +_TIMEOUT = 10.0 + + +async def send_command( + mqtt_channel: MqttChannel, + command: B01_Q10_DP, + params: ParamsType, +) -> None: + """Send a command on the MQTT channel, without waiting for a response""" + _LOGGER.debug( + "Sending B01 MQTT command: cmd=%s params=%s", + command, + params, + ) + roborock_message = encode_mqtt_payload(command, params) + _LOGGER.debug("Sending MQTT message: %s", roborock_message) + try: + await mqtt_channel.publish(roborock_message) + except RoborockException as ex: + _LOGGER.debug( + "Error sending B01 decoded command (method=%s params=%s): %s", + command, + params, + ex, + ) + raise + + +async def stream_decoded_responses( + mqtt_channel: MqttChannel, +) -> AsyncGenerator[dict[B01_Q10_DP, Any], None]: + """Stream decoded DPS messages received via MQTT.""" + + async for response_message in mqtt_channel.subscribe_stream(): + try: + decoded_dps = decode_rpc_response(response_message) + except RoborockException as ex: + _LOGGER.debug( + "Failed to decode B01 RPC response: %s: %s", + response_message, + ex, + ) + continue + yield decoded_dps diff --git a/roborock/devices/b01_channel.py b/roborock/devices/b01_q7_channel.py similarity index 98% rename from roborock/devices/b01_channel.py rename to roborock/devices/b01_q7_channel.py index 0c9e06d5..6b9a076e 100644 --- a/roborock/devices/b01_channel.py +++ b/roborock/devices/b01_q7_channel.py @@ -8,7 +8,7 @@ from typing import Any from roborock.exceptions import RoborockException -from roborock.protocols.b01_protocol import ( +from roborock.protocols.b01_q7_protocol import ( CommandType, ParamsType, decode_rpc_response, diff --git a/roborock/devices/device.py b/roborock/devices/device.py index e58bac9d..f169ac6d 100644 --- a/roborock/devices/device.py +++ b/roborock/devices/device.py @@ -195,12 +195,14 @@ async def connect(self) -> None: if self._unsub: raise ValueError("Already connected to the device") unsub = await self._channel.subscribe(self._on_message) - if self.v1_properties is not None: - try: + try: + if self.v1_properties is not None: await self.v1_properties.discover_features() - except RoborockException: - unsub() - raise + elif self.b01_q10_properties is not None: + await self.b01_q10_properties.start() + except RoborockException: + unsub() + raise self._logger.info("Connected to device") self._unsub = unsub @@ -212,6 +214,8 @@ async def close(self) -> None: await self._connect_task except asyncio.CancelledError: pass + if self.b01_q10_properties is not None: + await self.b01_q10_properties.close() if self._unsub: self._unsub() self._unsub = None diff --git a/roborock/devices/device_manager.py b/roborock/devices/device_manager.py index 2ff01085..72ad3177 100644 --- a/roborock/devices/device_manager.py +++ b/roborock/devices/device_manager.py @@ -240,9 +240,7 @@ def device_creator(home_data: HomeData, device: HomeDataDevice, product: HomeDat channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device) model_part = product.model.split(".")[-1] if "ss" in model_part: - raise UnsupportedDeviceError( - f"Device {device.name} has unsupported version B01 product model {product.model}" - ) + trait = b01.q10.create(channel) elif "sc" in model_part: # Q7 devices start with 'sc' in their model naming. trait = b01.q7.create(channel) diff --git a/roborock/devices/mqtt_channel.py b/roborock/devices/mqtt_channel.py index 498cef13..dcbccb81 100644 --- a/roborock/devices/mqtt_channel.py +++ b/roborock/devices/mqtt_channel.py @@ -1,7 +1,8 @@ """Modules for communicating with specific Roborock devices over MQTT.""" +import asyncio import logging -from collections.abc import Callable +from collections.abc import AsyncGenerator, Callable from roborock.callbacks import decoder_callback from roborock.data import HomeDataDevice, RRiot, UserData @@ -73,6 +74,21 @@ async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callab dispatch = decoder_callback(self._decoder, callback, _LOGGER) return await self._mqtt_session.subscribe(self._subscribe_topic, dispatch) + async def subscribe_stream(self) -> AsyncGenerator[RoborockMessage, None]: + """Subscribe to the device's message stream. + + This is useful for processing all incoming messages in an async for loop, + when they are not necessarily associated with a specific request. + """ + message_queue: asyncio.Queue[RoborockMessage] = asyncio.Queue() + unsub = await self.subscribe(message_queue.put_nowait) + try: + while True: + message = await message_queue.get() + yield message + finally: + unsub() + async def publish(self, message: RoborockMessage) -> None: """Publish a command message. diff --git a/roborock/devices/traits/b01/__init__.py b/roborock/devices/traits/b01/__init__.py index bf6d8b23..e729c686 100644 --- a/roborock/devices/traits/b01/__init__.py +++ b/roborock/devices/traits/b01/__init__.py @@ -1,5 +1,11 @@ """Traits for B01 devices.""" from .q7 import Q7PropertiesApi +from .q10 import Q10PropertiesApi -__all__ = ["Q7PropertiesApi", "q7", "q10"] +__all__ = [ + "Q7PropertiesApi", + "Q10PropertiesApi", + "q7", + "q10", +] diff --git a/roborock/devices/traits/b01/q10/__init__.py b/roborock/devices/traits/b01/q10/__init__.py index b3cd30d6..0e5661b8 100644 --- a/roborock/devices/traits/b01/q10/__init__.py +++ b/roborock/devices/traits/b01/q10/__init__.py @@ -1 +1,104 @@ -"""Q10""" +"""Traits for Q10 B01 devices.""" + +import asyncio +import logging +from typing import Any + +from roborock import B01Props +from roborock.data.b01_q10.b01_q10_code_mappings import B01_Q10_DP +from roborock.devices.b01_q10_channel import ParamsType, send_command, stream_decoded_responses +from roborock.devices.mqtt_channel import MqttChannel +from roborock.devices.traits import Trait + +_LOGGER = logging.getLogger(__name__) + +__all__ = [ + "Q10PropertiesApi", +] + + +class Q10PropertiesApi(Trait): + """API for interacting with B01 devices.""" + + def __init__(self, channel: MqttChannel) -> None: + """Initialize the B01Props API.""" + self._channel = channel + self._task: asyncio.Task | None = None + + async def start(self) -> None: + """Start any necessary subscriptions for the trait.""" + self._task = asyncio.create_task(self._run_loop()) + + async def close(self) -> None: + """Close any resources held by the trait.""" + if self._task is not None: + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + self._task = None + + async def start_clean(self) -> None: + """Start cleaning.""" + await self.send( + command=B01_Q10_DP.START_CLEAN, + # TODO: figure out other commands + # 1 = start cleaning + # 2 = electoral clean, also has "clean_paramters" + # 4 = fast create map + params={"cmd": 1}, + ) + + async def pause_clean(self) -> None: + """Pause cleaning.""" + await self.send( + command=B01_Q10_DP.PAUSE, + params={}, + ) + + async def resume_clean(self) -> None: + """Pause cleaning.""" + await self.send( + command=B01_Q10_DP.RESUME, + params={}, + ) + + async def stop_clean(self) -> None: + """Stop cleaning.""" + await self.send( + command=B01_Q10_DP.STOP, + params={}, + ) + + async def return_to_dock(self) -> None: + """Return to dock.""" + await self.send( + command=B01_Q10_DP.START_DOCK_TASK, + params={}, + ) + + async def send(self, command: B01_Q10_DP, params: ParamsType) -> None: + """Send a command to the device.""" + await send_command( + self._channel, + command=command, + params=params, + ) + + async def _run_loop(self) -> None: + """Run the main loop for processing incoming messages.""" + async for decoded_dps in stream_decoded_responses(self._channel): + _LOGGER.debug("Received B01 Q10 decoded DPS: %s", decoded_dps) + + # Temporary debugging: Log all common values + if B01_Q10_DP.COMMON not in decoded_dps: + continue + common_values = decoded_dps[B01_Q10_DP.COMMON] + for key, value in common_values.items(): + _LOGGER.debug("%s: %s", key, value) + + +def create(channel: MqttChannel) -> Q10PropertiesApi: + """Create traits for B01 devices.""" + return Q10PropertiesApi(channel) diff --git a/roborock/devices/traits/b01/q7/__init__.py b/roborock/devices/traits/b01/q7/__init__.py index 2bfa0a6d..21ce4856 100644 --- a/roborock/devices/traits/b01/q7/__init__.py +++ b/roborock/devices/traits/b01/q7/__init__.py @@ -10,7 +10,7 @@ SCWindMapping, WaterLevelMapping, ) -from roborock.devices.b01_channel import CommandType, ParamsType, send_decoded_command +from roborock.devices.b01_q7_channel import CommandType, ParamsType, send_decoded_command from roborock.devices.mqtt_channel import MqttChannel from roborock.devices.traits import Trait from roborock.roborock_message import RoborockB01Props diff --git a/roborock/devices/traits/traits_mixin.py b/roborock/devices/traits/traits_mixin.py index 92b9597e..60c15640 100644 --- a/roborock/devices/traits/traits_mixin.py +++ b/roborock/devices/traits/traits_mixin.py @@ -34,6 +34,9 @@ class TraitsMixin: b01_q7_properties: b01.Q7PropertiesApi | None = None """B01 Q7 properties trait, if supported.""" + b01_q10_properties: b01.Q10PropertiesApi | None = None + """B01 Q10 properties trait, if supported.""" + def __init__(self, trait: Trait) -> None: """Initialize the TraitsMixin with the given trait. diff --git a/roborock/protocols/b01_q10_protocol.py b/roborock/protocols/b01_q10_protocol.py new file mode 100644 index 00000000..71119030 --- /dev/null +++ b/roborock/protocols/b01_q10_protocol.py @@ -0,0 +1,64 @@ +"""Roborock B01 Protocol encoding and decoding.""" + +import json +import logging +from typing import Any + +from roborock.data.b01_q10.b01_q10_code_mappings import B01_Q10_DP +from roborock.exceptions import RoborockException +from roborock.roborock_message import ( + RoborockMessage, + RoborockMessageProtocol, +) + +_LOGGER = logging.getLogger(__name__) + +B01_VERSION = b"B01" +ParamsType = list | dict | int | None + + +def encode_mqtt_payload(command: B01_Q10_DP, params: ParamsType) -> RoborockMessage: + """Encode payload for B01 commands over MQTT.""" + dps_data = { + "dps": { + command.code: params, + } + } + # payload = pad(json.dumps(dps_data).encode("utf-8"), AES.block_size) + return RoborockMessage( + protocol=RoborockMessageProtocol.RPC_REQUEST, + version=B01_VERSION, + payload=json.dumps(dps_data).encode("utf-8"), + ) + + +def _convert_datapoints(datapoints: dict[str, Any], message: RoborockMessage) -> dict[B01_Q10_DP, Any]: + """Convert the 'dps' dictionary keys from strings to B01_Q10_DP enums.""" + result = {} + for key, value in datapoints.items(): + if not isinstance(key, str): + raise RoborockException(f"Invalid B01 message format: 'dps' keys should be strings for {message.payload!r}") + dps = B01_Q10_DP.from_code(int(key)) + result[dps] = value + return result + + +def decode_rpc_response(message: RoborockMessage) -> dict[B01_Q10_DP, Any]: + """Decode a B01 RPC_RESPONSE message.""" + if not message.payload: + raise RoborockException("Invalid B01 message format: missing payload") + try: + payload = json.loads(message.payload.decode()) + except (json.JSONDecodeError, TypeError, UnicodeDecodeError) as e: + raise RoborockException(f"Invalid B01 message payload: {e} for {message.payload!r}") from e + + datapoints = payload.get("dps", {}) + if not isinstance(datapoints, dict): + raise RoborockException(f"Invalid B01 message format: 'dps' should be a dictionary for {message.payload!r}") + + result = _convert_datapoints(datapoints, message) + # The COMMON response contains nested datapoints that also need conversion + if common_result := result.get(B01_Q10_DP.COMMON): + common_dps_result = _convert_datapoints(common_result, message) + result[B01_Q10_DP.COMMON] = common_dps_result + return result diff --git a/roborock/protocols/b01_protocol.py b/roborock/protocols/b01_q7_protocol.py similarity index 100% rename from roborock/protocols/b01_protocol.py rename to roborock/protocols/b01_q7_protocol.py diff --git a/tests/data/test_code_mappings.py b/tests/data/test_code_mappings.py new file mode 100644 index 00000000..97faab9e --- /dev/null +++ b/tests/data/test_code_mappings.py @@ -0,0 +1,27 @@ +"""Tests for code mappings. + +These tests exercise the custom enum methods using arbitrary enum values. +""" + +from roborock.data.b01_q10.b01_q10_code_mappings import B01_Q10_DP + + +def test_from_code(): + """Test from_code method.""" + assert B01_Q10_DP.START_CLEAN == B01_Q10_DP.from_code(201) + assert B01_Q10_DP.PAUSE == B01_Q10_DP.from_code(204) + assert B01_Q10_DP.STOP == B01_Q10_DP.from_code(206) + + +def test_from_name(): + """Test from_name method.""" + assert B01_Q10_DP.START_CLEAN == B01_Q10_DP.from_name("START_CLEAN") + assert B01_Q10_DP.PAUSE == B01_Q10_DP.from_name("pause") + assert B01_Q10_DP.STOP == B01_Q10_DP.from_name("Stop") + + +def test_from_value(): + """Test from_value method.""" + assert B01_Q10_DP.START_CLEAN == B01_Q10_DP.from_value("dpStartClean") + assert B01_Q10_DP.PAUSE == B01_Q10_DP.from_value("dpPause") + assert B01_Q10_DP.STOP == B01_Q10_DP.from_value("dpStop") diff --git a/tests/devices/test_mqtt_channel.py b/tests/devices/test_mqtt_channel.py index 13dc474a..5bcee7dc 100644 --- a/tests/devices/test_mqtt_channel.py +++ b/tests/devices/test_mqtt_channel.py @@ -269,3 +269,34 @@ def failing_callback(message: RoborockMessage) -> None: # Unsubscribe all remaining subscribers unsub1() unsub2() + + +async def test_subscribe_stream(mqtt_session: Mock, mqtt_channel: MqttChannel) -> None: + """Test multiple concurrent subscribers receive all messages.""" + + messages: asyncio.Queue[RoborockMessage] = asyncio.Queue() + + async def task() -> None: + async for message in mqtt_channel.subscribe_stream(): + await messages.put(message) + + subscriber_task = asyncio.create_task(task()) + await asyncio.sleep(0.01) # yield + + handler = mqtt_session.subscribe.call_args_list[0][0][1] + + # Simulate receiving messages - each handler should decode the message independently + handler(ENCODER(TEST_REQUEST)) + handler(ENCODER(TEST_REQUEST)) + handler(ENCODER(TEST_REQUEST)) + + async with asyncio.timeout(10): + resp = await messages.get() + assert resp == TEST_REQUEST + resp = await messages.get() + assert resp == TEST_REQUEST + resp = await messages.get() + assert resp == TEST_REQUEST + assert messages.empty() + + subscriber_task.cancel() diff --git a/tests/devices/traits/b01/q10/__init__.py b/tests/devices/traits/b01/q10/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/devices/traits/b01/q10/test_init.py b/tests/devices/traits/b01/q10/test_init.py new file mode 100644 index 00000000..fa176f3b --- /dev/null +++ b/tests/devices/traits/b01/q10/test_init.py @@ -0,0 +1,79 @@ +import asyncio +import json +import pathlib +from collections.abc import AsyncGenerator + +import pytest +from Crypto.Cipher import AES +from Crypto.Util.Padding import pad + +from roborock.data.b01_q10.b01_q10_code_mappings import B01_Q10_DP +from roborock.devices.traits.b01.q10 import Q10PropertiesApi +from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol +from tests.fixtures.channel_fixtures import FakeChannel + +PAYLOAD_FILE = pathlib.Path("tests/protocols/testdata/b01_protocol/q10/dpRequetdps.json") +TEST_RESPONSE_PAYLOAD = PAYLOAD_FILE.read_bytes() + + +def build_b01_message(dps_payload: bytes, seq: int) -> RoborockMessage: + """Build an encoded B01 RPC response message.""" + return RoborockMessage( + protocol=RoborockMessageProtocol.RPC_RESPONSE, + payload=pad(dps_payload, AES.block_size), + version=b"B01", + seq=seq, + ) + + +@pytest.fixture(name="fake_channel") +def fake_channel_fixture() -> FakeChannel: + return FakeChannel() + + +@pytest.fixture(name="q10_api") +async def q10_api_fixture(fake_channel: FakeChannel) -> AsyncGenerator[Q10PropertiesApi, None]: + properties = Q10PropertiesApi(fake_channel) # type: ignore[arg-type] + await properties.start() + yield properties + await properties.close() + + +async def test_subscribe(q10_api: Q10PropertiesApi, fake_channel: FakeChannel): + """Test that Q10PropertiesApi handles incoming messages.""" + assert len(fake_channel.subscribers) == 1 + + message_callback = fake_channel.subscribers[0] + message_callback(build_b01_message(TEST_RESPONSE_PAYLOAD, seq=12345)) + + # We currently don't do anything with the incoming messages in this test, + # but we want to ensure no exceptions are raised during processing. + await asyncio.sleep(0.1) # Allow some time for the message to be processed + + +async def test_start_clean(q10_api: Q10PropertiesApi, fake_channel: FakeChannel): + """Test sending a command via Q10PropertiesApi.""" + await q10_api.start_clean() + + assert len(fake_channel.published_messages) == 1 + published_message = fake_channel.published_messages[0] + assert published_message.protocol == RoborockMessageProtocol.RPC_REQUEST + assert published_message.version == b"B01" + assert published_message.payload + + payload_data = json.loads(published_message.payload.decode()) + assert payload_data == {"dps": {"201": {"cmd": 1}}} + + +async def test_send_command(q10_api: Q10PropertiesApi, fake_channel: FakeChannel): + """Test sending a command via Q10PropertiesApi.""" + await q10_api.send(B01_Q10_DP.REQUETDPS, {}) + + assert len(fake_channel.published_messages) == 1 + published_message = fake_channel.published_messages[0] + assert published_message.protocol == RoborockMessageProtocol.RPC_REQUEST + assert published_message.version == b"B01" + assert published_message.payload + + payload_data = json.loads(published_message.payload.decode()) + assert payload_data == {"dps": {"102": {}}} diff --git a/tests/devices/traits/b01/q7/__init__.py b/tests/devices/traits/b01/q7/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/devices/traits/b01/test_init.py b/tests/devices/traits/b01/q7/test_init.py similarity index 90% rename from tests/devices/traits/b01/test_init.py rename to tests/devices/traits/b01/q7/test_init.py index 1c3c8745..06850d69 100644 --- a/tests/devices/traits/b01/test_init.py +++ b/tests/devices/traits/b01/q7/test_init.py @@ -13,10 +13,10 @@ WaterLevelMapping, WorkStatusMapping, ) -from roborock.devices.b01_channel import send_decoded_command +from roborock.devices.b01_q7_channel import send_decoded_command from roborock.devices.traits.b01.q7 import Q7PropertiesApi from roborock.exceptions import RoborockException -from roborock.protocols.b01_protocol import B01_VERSION +from roborock.protocols.b01_q10_protocol import B01_VERSION from roborock.roborock_message import RoborockB01Props, RoborockMessage, RoborockMessageProtocol from tests.fixtures.channel_fixtures import FakeChannel @@ -68,7 +68,7 @@ async def test_q7_api_query_values(q7_api: Q7PropertiesApi, fake_channel: FakeCh } # Patch get_next_int to return our expected msg_id so the channel waits for it - with patch("roborock.devices.b01_channel.get_next_int", return_value=int(expected_msg_id)): + with patch("roborock.devices.b01_q7_channel.get_next_int", return_value=int(expected_msg_id)): # Queue the response fake_channel.response_queue.append(build_b01_message(response_data, msg_id=expected_msg_id)) @@ -131,7 +131,7 @@ async def test_q7_response_value_mapping( """Test Q7PropertiesApi value mapping for different statuses.""" msg_id = "987654321" - with patch("roborock.devices.b01_channel.get_next_int", return_value=int(msg_id)): + with patch("roborock.devices.b01_q7_channel.get_next_int", return_value=int(msg_id)): fake_channel.response_queue.append(build_b01_message(response_data, msg_id=msg_id)) result = await q7_api.query_values(query) @@ -165,7 +165,7 @@ async def test_send_decoded_command_non_dict_response(fake_channel: FakeChannel) fake_channel.response_queue.append(message) - with patch("roborock.devices.b01_channel.get_next_int", return_value=int(msg_id)): + with patch("roborock.devices.b01_q7_channel.get_next_int", return_value=int(msg_id)): # Use a random string for command type to avoid needing import with pytest.raises(RoborockException, match="Unexpected data type for response"): @@ -200,7 +200,7 @@ async def test_send_decoded_command_error_code(fake_channel: FakeChannel): fake_channel.response_queue.append(message) - with patch("roborock.devices.b01_channel.get_next_int", return_value=int(msg_id)): + with patch("roborock.devices.b01_q7_channel.get_next_int", return_value=int(msg_id)): with pytest.raises(RoborockException, match=f"B01 command failed with code {error_code}"): await send_decoded_command(fake_channel, 10000, "prop.get", []) # type: ignore[arg-type] @@ -208,7 +208,7 @@ async def test_send_decoded_command_error_code(fake_channel: FakeChannel): async def test_q7_api_set_fan_speed(q7_api: Q7PropertiesApi, fake_channel: FakeChannel): """Test setting fan speed.""" msg_id = "12345" - with patch("roborock.devices.b01_channel.get_next_int", return_value=int(msg_id)): + with patch("roborock.devices.b01_q7_channel.get_next_int", return_value=int(msg_id)): fake_channel.response_queue.append(build_b01_message({"result": "ok"}, msg_id=msg_id)) await q7_api.set_fan_speed(SCWindMapping.STRONG) @@ -222,7 +222,7 @@ async def test_q7_api_set_fan_speed(q7_api: Q7PropertiesApi, fake_channel: FakeC async def test_q7_api_set_water_level(q7_api: Q7PropertiesApi, fake_channel: FakeChannel): """Test setting water level.""" msg_id = "12346" - with patch("roborock.devices.b01_channel.get_next_int", return_value=int(msg_id)): + with patch("roborock.devices.b01_q7_channel.get_next_int", return_value=int(msg_id)): fake_channel.response_queue.append(build_b01_message({"result": "ok"}, msg_id=msg_id)) await q7_api.set_water_level(WaterLevelMapping.HIGH) @@ -236,7 +236,7 @@ async def test_q7_api_set_water_level(q7_api: Q7PropertiesApi, fake_channel: Fak async def test_q7_api_start_clean(q7_api: Q7PropertiesApi, fake_channel: FakeChannel): """Test starting cleaning.""" msg_id = "12347" - with patch("roborock.devices.b01_channel.get_next_int", return_value=int(msg_id)): + with patch("roborock.devices.b01_q7_channel.get_next_int", return_value=int(msg_id)): fake_channel.response_queue.append(build_b01_message({"result": "ok"}, msg_id=msg_id)) await q7_api.start_clean() @@ -254,7 +254,7 @@ async def test_q7_api_start_clean(q7_api: Q7PropertiesApi, fake_channel: FakeCha async def test_q7_api_pause_clean(q7_api: Q7PropertiesApi, fake_channel: FakeChannel): """Test pausing cleaning.""" msg_id = "12348" - with patch("roborock.devices.b01_channel.get_next_int", return_value=int(msg_id)): + with patch("roborock.devices.b01_q7_channel.get_next_int", return_value=int(msg_id)): fake_channel.response_queue.append(build_b01_message({"result": "ok"}, msg_id=msg_id)) await q7_api.pause_clean() @@ -272,7 +272,7 @@ async def test_q7_api_pause_clean(q7_api: Q7PropertiesApi, fake_channel: FakeCha async def test_q7_api_stop_clean(q7_api: Q7PropertiesApi, fake_channel: FakeChannel): """Test stopping cleaning.""" msg_id = "12349" - with patch("roborock.devices.b01_channel.get_next_int", return_value=int(msg_id)): + with patch("roborock.devices.b01_q7_channel.get_next_int", return_value=int(msg_id)): fake_channel.response_queue.append(build_b01_message({"result": "ok"}, msg_id=msg_id)) await q7_api.stop_clean() @@ -290,7 +290,7 @@ async def test_q7_api_stop_clean(q7_api: Q7PropertiesApi, fake_channel: FakeChan async def test_q7_api_return_to_dock(q7_api: Q7PropertiesApi, fake_channel: FakeChannel): """Test returning to dock.""" msg_id = "12350" - with patch("roborock.devices.b01_channel.get_next_int", return_value=int(msg_id)): + with patch("roborock.devices.b01_q7_channel.get_next_int", return_value=int(msg_id)): fake_channel.response_queue.append(build_b01_message({"result": "ok"}, msg_id=msg_id)) await q7_api.return_to_dock() @@ -304,7 +304,7 @@ async def test_q7_api_return_to_dock(q7_api: Q7PropertiesApi, fake_channel: Fake async def test_q7_api_find_me(q7_api: Q7PropertiesApi, fake_channel: FakeChannel): """Test locating the device.""" msg_id = "12351" - with patch("roborock.devices.b01_channel.get_next_int", return_value=int(msg_id)): + with patch("roborock.devices.b01_q7_channel.get_next_int", return_value=int(msg_id)): fake_channel.response_queue.append(build_b01_message({"result": "ok"}, msg_id=msg_id)) await q7_api.find_me() diff --git a/tests/fixtures/channel_fixtures.py b/tests/fixtures/channel_fixtures.py index 1faae11c..8e4296f9 100644 --- a/tests/fixtures/channel_fixtures.py +++ b/tests/fixtures/channel_fixtures.py @@ -1,4 +1,5 @@ -from collections.abc import Callable +import asyncio +from collections.abc import AsyncGenerator, Callable from unittest.mock import AsyncMock, MagicMock from roborock.mqtt.health_manager import HealthManager @@ -51,3 +52,14 @@ async def _subscribe(self, callback: Callable[[RoborockMessage], None]) -> Calla """Simulate subscribing to messages.""" self.subscribers.append(callback) return lambda: self.subscribers.remove(callback) + + async def subscribe_stream(self) -> AsyncGenerator[RoborockMessage, None]: + """Subscribe to the device's response topic and stream messages.""" + message_queue: asyncio.Queue[RoborockMessage] = asyncio.Queue() + unsub = await self.subscribe(message_queue.put_nowait) + try: + while True: + message = await message_queue.get() + yield message + finally: + unsub() diff --git a/tests/protocols/__snapshots__/test_b01_q10_protocol.ambr b/tests/protocols/__snapshots__/test_b01_q10_protocol.ambr new file mode 100644 index 00000000..a88bbf8d --- /dev/null +++ b/tests/protocols/__snapshots__/test_b01_q10_protocol.ambr @@ -0,0 +1,90 @@ +# serializer version: 1 +# name: test_decode_rpc_payload[dpBattery] + ''' + { + "dpBattery": 100 + } + ''' +# --- +# name: test_decode_rpc_payload[dpFault] + ''' + { + "dpCommon": { + "90": 0 + } + } + ''' +# --- +# name: test_decode_rpc_payload[dpRequetdps] + ''' + { + "dpCommon": { + "104": 0, + "105": false, + "109": "us", + "207": 0, + "25": 1, + "26": 74, + "29": 0, + "30": 0, + "31": 0, + "37": 1, + "40": 1, + "45": 0, + "47": 0, + "50": 0, + "51": true, + "53": false, + "6": 0, + "60": 1, + "67": 0, + "7": 0, + "76": 0, + "78": 0, + "79": { + "timeZoneCity": "America/Los_Angeles", + "timeZoneSec": -28800 + }, + "80": 0, + "81": { + "ipAdress": "1.1.1.2", + "mac": "99:AA:88:BB:77:CC", + "signal": -50, + "wifiName": "wifi-network-name" + }, + "83": 1, + "86": 1, + "87": 100, + "88": 0, + "90": 0, + "92": { + "disturb_dust_enable": 1, + "disturb_light": 1, + "disturb_resume_clean": 1, + "disturb_voice": 1 + }, + "93": 1, + "96": 0 + }, + "dpStatus": 8, + "dpBattery": 100, + "dpfunLevel": 2, + "dpWaterLevel": 1, + "dpMainBrushLife": 0, + "dpSideBrushLife": 0, + "dpFilterLife": 0, + "dpCleanCount": 1, + "dpCleanMode": 1, + "dpCleanTaskType": 0, + "dpBackType": 5 + } + ''' +# --- +# name: test_decode_rpc_payload[dpStatus-dpCleanTaskType] + ''' + { + "dpStatus": 8, + "dpCleanTaskType": 0 + } + ''' +# --- diff --git a/tests/protocols/test_b01_q07_protocol.py b/tests/protocols/test_b01_q07_protocol.py index db47bc2f..2f107c89 100644 --- a/tests/protocols/test_b01_q07_protocol.py +++ b/tests/protocols/test_b01_q07_protocol.py @@ -10,14 +10,14 @@ from freezegun import freeze_time from syrupy import SnapshotAssertion -from roborock.protocols.b01_protocol import ( +from roborock.protocols.b01_q7_protocol import ( decode_rpc_response, encode_mqtt_payload, ) from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol -TESTDATA_PATH = pathlib.Path("tests/protocols/testdata/b01_protocol") -TESTDATA_FILES = list(TESTDATA_PATH.glob("**/*.json")) +TESTDATA_PATH = pathlib.Path("tests/protocols/testdata/b01_protocol/q7") +TESTDATA_FILES = list(TESTDATA_PATH.glob("*.json")) TESTDATA_IDS = [x.stem for x in TESTDATA_FILES] diff --git a/tests/protocols/test_b01_q10_protocol.py b/tests/protocols/test_b01_q10_protocol.py new file mode 100644 index 00000000..0125dd37 --- /dev/null +++ b/tests/protocols/test_b01_q10_protocol.py @@ -0,0 +1,69 @@ +"""Tests for the B01 protocol message encoding and decoding.""" + +import json +import pathlib +from collections.abc import Generator +from typing import Any + +import pytest +from Crypto.Cipher import AES +from Crypto.Util.Padding import unpad +from freezegun import freeze_time +from syrupy import SnapshotAssertion + +from roborock.data.b01_q10.b01_q10_code_mappings import B01_Q10_DP +from roborock.protocols.b01_q10_protocol import ( + decode_rpc_response, + encode_mqtt_payload, +) +from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol + +TESTDATA_PATH = pathlib.Path("tests/protocols/testdata/b01_protocol/q10") +TESTDATA_FILES = list(TESTDATA_PATH.glob("*.json")) +TESTDATA_IDS = [x.stem for x in TESTDATA_FILES] + + +@pytest.fixture(autouse=True) +def fixed_time_fixture() -> Generator[None, None, None]: + """Fixture to freeze time for predictable request IDs.""" + with freeze_time("2025-01-20T12:00:00"): + yield + + +@pytest.mark.parametrize("filename", TESTDATA_FILES, ids=TESTDATA_IDS) +def test_decode_rpc_payload(filename: str, snapshot: SnapshotAssertion) -> None: + """Test decoding a B01 RPC response protocol message.""" + with open(filename, "rb") as f: + payload = f.read() + + message = RoborockMessage( + protocol=RoborockMessageProtocol.RPC_RESPONSE, + payload=payload, + seq=12750, + version=b"B01", + random=97431, + timestamp=1652547161, + ) + + decoded_message = decode_rpc_response(message) + assert json.dumps(decoded_message, indent=2) == snapshot + + +@pytest.mark.parametrize( + ("command", "params"), + [ + (B01_Q10_DP.REQUETDPS, {}), + ], +) +def test_encode_mqtt_payload(command: B01_Q10_DP, params: dict[str, Any]) -> None: + """Test encoding of MQTT payload for B01 commands.""" + + message = encode_mqtt_payload(command, params) + assert isinstance(message, RoborockMessage) + assert message.protocol == RoborockMessageProtocol.RPC_REQUEST + assert message.version == b"B01" + assert message.payload is not None + unpadded = unpad(message.payload, AES.block_size) + decoded_json = json.loads(unpadded.decode("utf-8")) + + assert decoded_json == {"dps": {"102": {}}} diff --git a/tests/protocols/testdata/b01_protocol/q10/dpBattery.json b/tests/protocols/testdata/b01_protocol/q10/dpBattery.json new file mode 100644 index 00000000..b7a76fda --- /dev/null +++ b/tests/protocols/testdata/b01_protocol/q10/dpBattery.json @@ -0,0 +1 @@ +{"dps":{"122":100},"t":1766800902} diff --git a/tests/protocols/testdata/b01_protocol/q10/dpFault.json b/tests/protocols/testdata/b01_protocol/q10/dpFault.json new file mode 100644 index 00000000..b0c0494a --- /dev/null +++ b/tests/protocols/testdata/b01_protocol/q10/dpFault.json @@ -0,0 +1 @@ +{"dps":{"101":{"90":0}},"t":1766800904} diff --git a/tests/protocols/testdata/b01_protocol/q10/dpRequetdps.json b/tests/protocols/testdata/b01_protocol/q10/dpRequetdps.json new file mode 100644 index 00000000..13d52243 --- /dev/null +++ b/tests/protocols/testdata/b01_protocol/q10/dpRequetdps.json @@ -0,0 +1 @@ +{"dps":{"101":{"104":0,"105":false,"109":"us","207":0,"25":1,"26":74,"29":0,"30":0,"31":0,"37":1,"40":1,"45":0,"47":0,"50":0,"51":true,"53":false,"6":0,"60":1,"67":0,"7":0,"76":0,"78":0,"79":{"timeZoneCity":"America/Los_Angeles","timeZoneSec":-28800},"80":0,"81":{"ipAdress":"1.1.1.2","mac":"99:AA:88:BB:77:CC","signal":-50,"wifiName":"wifi-network-name"},"83":1,"86":1,"87":100,"88":0,"90":0,"92":{"disturb_dust_enable":1,"disturb_light":1,"disturb_resume_clean":1,"disturb_voice":1},"93":1,"96":0},"121":8,"122":100,"123":2,"124":1,"125":0,"126":0,"127":0,"136":1,"137":1,"138":0,"139":5},"t":1766802312} diff --git a/tests/protocols/testdata/b01_protocol/q10/dpStatus-dpCleanTaskType.json b/tests/protocols/testdata/b01_protocol/q10/dpStatus-dpCleanTaskType.json new file mode 100644 index 00000000..ed9de954 --- /dev/null +++ b/tests/protocols/testdata/b01_protocol/q10/dpStatus-dpCleanTaskType.json @@ -0,0 +1 @@ +{"dps":{"121":8,"138":0},"t":1766800904} From 8e974ba309edae6aea613656ea01824ea125d871 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Fri, 26 Dec 2025 21:04:59 -0800 Subject: [PATCH 2/6] chore: fix lint errors in mqtt channel --- roborock/devices/mqtt_channel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/roborock/devices/mqtt_channel.py b/roborock/devices/mqtt_channel.py index dcbccb81..5ff0ab08 100644 --- a/roborock/devices/mqtt_channel.py +++ b/roborock/devices/mqtt_channel.py @@ -76,7 +76,7 @@ async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callab async def subscribe_stream(self) -> AsyncGenerator[RoborockMessage, None]: """Subscribe to the device's message stream. - + This is useful for processing all incoming messages in an async for loop, when they are not necessarily associated with a specific request. """ From e0305cf9f3e191f9f3533ae4aeadb2b6035f617a Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sat, 27 Dec 2025 07:20:39 -0800 Subject: [PATCH 3/6] chore: Apply suggestions for co-pilot code review --- roborock/cli.py | 5 +++-- roborock/devices/traits/b01/q10/__init__.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/roborock/cli.py b/roborock/cli.py index 8c1810c5..b99caba5 100644 --- a/roborock/cli.py +++ b/roborock/cli.py @@ -97,7 +97,8 @@ async def run(): except Exception: _LOGGER.exception("Uncaught exception in command") click.echo(f"Error: {sys.exc_info()[1]}", err=True) - await context.cleanup() + finally: + await context.cleanup() if context.is_session_mode(): # Session mode - run in the persistent loop @@ -775,7 +776,7 @@ async def command(ctx, cmd, device_id, params): if (cmd_value := _parse_b01_q10_command(cmd)) is None: raise RoborockException(f"Invalid command {cmd} for B01_Q10 device") await device.b01_q10_properties.send(cmd_value, json.loads(params) if params is not None else {}) - # B10 Commands don't have a specific time to respond, so wait a bit + # Q10 commands don't have a specific time to respond, so wait a bit await asyncio.sleep(5) else: raise RoborockException(f"Device {device.name} does not support sending raw commands") diff --git a/roborock/devices/traits/b01/q10/__init__.py b/roborock/devices/traits/b01/q10/__init__.py index 0e5661b8..fcc18173 100644 --- a/roborock/devices/traits/b01/q10/__init__.py +++ b/roborock/devices/traits/b01/q10/__init__.py @@ -45,7 +45,7 @@ async def start_clean(self) -> None: command=B01_Q10_DP.START_CLEAN, # TODO: figure out other commands # 1 = start cleaning - # 2 = electoral clean, also has "clean_paramters" + # 2 = "electoral" clean, also has "clean_parameters" # 4 = fast create map params={"cmd": 1}, ) @@ -58,7 +58,7 @@ async def pause_clean(self) -> None: ) async def resume_clean(self) -> None: - """Pause cleaning.""" + """Resume cleaning.""" await self.send( command=B01_Q10_DP.RESUME, params={}, From 7a7326867dbabf8264018b7ce769c88ff3c9f001 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sat, 27 Dec 2025 07:42:40 -0800 Subject: [PATCH 4/6] chore: Update test coverage for protocol parsing --- roborock/protocols/b01_q10_protocol.py | 33 ++++++--- .../__snapshots__/test_b01_q10_protocol.ambr | 68 +++++++++---------- tests/protocols/test_b01_q10_protocol.py | 32 +++++++-- 3 files changed, 86 insertions(+), 47 deletions(-) diff --git a/roborock/protocols/b01_q10_protocol.py b/roborock/protocols/b01_q10_protocol.py index 71119030..bd354175 100644 --- a/roborock/protocols/b01_q10_protocol.py +++ b/roborock/protocols/b01_q10_protocol.py @@ -36,9 +36,11 @@ def _convert_datapoints(datapoints: dict[str, Any], message: RoborockMessage) -> """Convert the 'dps' dictionary keys from strings to B01_Q10_DP enums.""" result = {} for key, value in datapoints.items(): - if not isinstance(key, str): - raise RoborockException(f"Invalid B01 message format: 'dps' keys should be strings for {message.payload!r}") - dps = B01_Q10_DP.from_code(int(key)) + try: + code = int(key) + except ValueError as e: + raise ValueError(f"dps key is not a valid integer: {e} for {message.payload!r}") from e + dps = B01_Q10_DP.from_code(code) result[dps] = value return result @@ -49,16 +51,29 @@ def decode_rpc_response(message: RoborockMessage) -> dict[B01_Q10_DP, Any]: raise RoborockException("Invalid B01 message format: missing payload") try: payload = json.loads(message.payload.decode()) - except (json.JSONDecodeError, TypeError, UnicodeDecodeError) as e: - raise RoborockException(f"Invalid B01 message payload: {e} for {message.payload!r}") from e + except (json.JSONDecodeError, UnicodeDecodeError) as e: + raise RoborockException(f"Invalid B01 json payload: {e} for {message.payload!r}") from e - datapoints = payload.get("dps", {}) + if (datapoints := payload.get("dps")) is None: + raise RoborockException(f"Invalid B01 json payload: missing 'dps' for {message.payload!r}") if not isinstance(datapoints, dict): raise RoborockException(f"Invalid B01 message format: 'dps' should be a dictionary for {message.payload!r}") - result = _convert_datapoints(datapoints, message) - # The COMMON response contains nested datapoints that also need conversion + try: + result = _convert_datapoints(datapoints, message) + except ValueError as e: + raise RoborockException(f"Invalid B01 message format: {e}") from e + + # The COMMON response contains nested datapoints that also need conversion. + # We will parse that here for now, but may move elsewhere as we add more + # complex response parsing. if common_result := result.get(B01_Q10_DP.COMMON): - common_dps_result = _convert_datapoints(common_result, message) + if not isinstance(common_result, dict): + raise RoborockException(f"Invalid dpCommon format: expected dict, got {type(common_result).__name__}") + try: + common_dps_result = _convert_datapoints(common_result, message) + except ValueError as e: + raise RoborockException(f"Invalid dpCommon format: {e}") from e result[B01_Q10_DP.COMMON] = common_dps_result + return result diff --git a/tests/protocols/__snapshots__/test_b01_q10_protocol.ambr b/tests/protocols/__snapshots__/test_b01_q10_protocol.ambr index a88bbf8d..04553cd4 100644 --- a/tests/protocols/__snapshots__/test_b01_q10_protocol.ambr +++ b/tests/protocols/__snapshots__/test_b01_q10_protocol.ambr @@ -10,7 +10,7 @@ ''' { "dpCommon": { - "90": 0 + "dpFault": 0 } } ''' @@ -19,52 +19,52 @@ ''' { "dpCommon": { - "104": 0, - "105": false, - "109": "us", - "207": 0, - "25": 1, - "26": 74, - "29": 0, - "30": 0, - "31": 0, - "37": 1, - "40": 1, - "45": 0, - "47": 0, - "50": 0, - "51": true, - "53": false, - "6": 0, - "60": 1, - "67": 0, - "7": 0, - "76": 0, - "78": 0, - "79": { + "dpBreakpointClean": 0, + "dpValleyPointCharging": false, + "dpRobotCountryCode": "us", + "dpUserPlan": 0, + "dpNotDisturb": 1, + "dpVolume": 74, + "dpTotalCleanArea": 0, + "dpTotalCleanCount": 0, + "dpTotalCleanTime": 0, + "dpDustSwitch": 1, + "dpMopState": 1, + "dpAutoBoost": 0, + "dpChildLock": 0, + "dpDustSetting": 0, + "dpMapSaveSwitch": true, + "dpRecendCleanRecord": false, + "dpCleanTime": 0, + "dpMultiMapSwitch": 1, + "dpSensorLife": 0, + "dpCleanArea": 0, + "dpCarpetCleanType": 0, + "dpCleanLine": 0, + "dpTimeZone": { "timeZoneCity": "America/Los_Angeles", "timeZoneSec": -28800 }, - "80": 0, - "81": { + "dpAreaUnit": 0, + "dpNetInfo": { "ipAdress": "1.1.1.2", "mac": "99:AA:88:BB:77:CC", "signal": -50, "wifiName": "wifi-network-name" }, - "83": 1, - "86": 1, - "87": 100, - "88": 0, - "90": 0, - "92": { + "dpRobotType": 1, + "dpLineLaserObstacleAvoidance": 1, + "dpCleanProgess": 100, + "dpGroundClean": 0, + "dpFault": 0, + "dpNotDisturbExpand": { "disturb_dust_enable": 1, "disturb_light": 1, "disturb_resume_clean": 1, "disturb_voice": 1 }, - "93": 1, - "96": 0 + "dpTimerType": 1, + "dpAddCleanState": 0 }, "dpStatus": 8, "dpBattery": 100, diff --git a/tests/protocols/test_b01_q10_protocol.py b/tests/protocols/test_b01_q10_protocol.py index 0125dd37..77ebcc7d 100644 --- a/tests/protocols/test_b01_q10_protocol.py +++ b/tests/protocols/test_b01_q10_protocol.py @@ -6,12 +6,11 @@ from typing import Any import pytest -from Crypto.Cipher import AES -from Crypto.Util.Padding import unpad from freezegun import freeze_time from syrupy import SnapshotAssertion from roborock.data.b01_q10.b01_q10_code_mappings import B01_Q10_DP +from roborock.exceptions import RoborockException from roborock.protocols.b01_q10_protocol import ( decode_rpc_response, encode_mqtt_payload, @@ -49,6 +48,32 @@ def test_decode_rpc_payload(filename: str, snapshot: SnapshotAssertion) -> None: assert json.dumps(decoded_message, indent=2) == snapshot +@pytest.mark.parametrize( + ("payload", "expected_error_message"), + [ + (b"", "missing payload"), + (b"n", "Invalid B01 json payload"), + (b"{}", "missing 'dps'"), + (b'{"dps": []}', "'dps' should be a dictionary"), + (b'{"dps": {"not_a_number": 123}}', "dps key is not a valid integer"), + (b'{"dps": {"101": 123}}', "Invalid dpCommon format: expected dict"), + (b'{"dps": {"101": {"not_a_number": 123}}}', "Invalid dpCommon format: dps key is not a valid intege"), + ], +) +def test_decode_invalid_rpc_payload(payload: bytes, expected_error_message: str) -> None: + """Test decoding a B01 RPC response protocol message.""" + message = RoborockMessage( + protocol=RoborockMessageProtocol.RPC_RESPONSE, + payload=payload, + seq=12750, + version=b"B01", + random=97431, + timestamp=1652547161, + ) + with pytest.raises(RoborockException, match=expected_error_message): + decode_rpc_response(message) + + @pytest.mark.parametrize( ("command", "params"), [ @@ -63,7 +88,6 @@ def test_encode_mqtt_payload(command: B01_Q10_DP, params: dict[str, Any]) -> Non assert message.protocol == RoborockMessageProtocol.RPC_REQUEST assert message.version == b"B01" assert message.payload is not None - unpadded = unpad(message.payload, AES.block_size) - decoded_json = json.loads(unpadded.decode("utf-8")) + decoded_json = json.loads(message.payload.decode("utf-8")) assert decoded_json == {"dps": {"102": {}}} From 7b5688c605f5e981cbc0a2bb7a24d3a5abcde1b4 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sat, 27 Dec 2025 07:44:20 -0800 Subject: [PATCH 5/6] chore: Improve test coverage for invalid pessage parsing --- roborock/protocols/b01_q10_protocol.py | 5 ++++- tests/protocols/test_b01_q10_protocol.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/roborock/protocols/b01_q10_protocol.py b/roborock/protocols/b01_q10_protocol.py index bd354175..d243935c 100644 --- a/roborock/protocols/b01_q10_protocol.py +++ b/roborock/protocols/b01_q10_protocol.py @@ -40,7 +40,10 @@ def _convert_datapoints(datapoints: dict[str, Any], message: RoborockMessage) -> code = int(key) except ValueError as e: raise ValueError(f"dps key is not a valid integer: {e} for {message.payload!r}") from e - dps = B01_Q10_DP.from_code(code) + try: + dps = B01_Q10_DP.from_code(code) + except ValueError as e: + raise ValueError(f"dps key is not a valid B01_Q10_DP: {e} for {message.payload!r}") from e result[dps] = value return result diff --git a/tests/protocols/test_b01_q10_protocol.py b/tests/protocols/test_b01_q10_protocol.py index 77ebcc7d..28e20d13 100644 --- a/tests/protocols/test_b01_q10_protocol.py +++ b/tests/protocols/test_b01_q10_protocol.py @@ -58,6 +58,7 @@ def test_decode_rpc_payload(filename: str, snapshot: SnapshotAssertion) -> None: (b'{"dps": {"not_a_number": 123}}', "dps key is not a valid integer"), (b'{"dps": {"101": 123}}', "Invalid dpCommon format: expected dict"), (b'{"dps": {"101": {"not_a_number": 123}}}', "Invalid dpCommon format: dps key is not a valid intege"), + (b'{"dps": {"909090": 123}}', "dps key is not a valid B01_Q10_DP"), ], ) def test_decode_invalid_rpc_payload(payload: bytes, expected_error_message: str) -> None: From 8fe70c0ab110711efd3dc6a556fa9c27437f48a2 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sat, 27 Dec 2025 07:46:08 -0800 Subject: [PATCH 6/6] chore: Improve cancel readability --- roborock/devices/traits/b01/q10/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/roborock/devices/traits/b01/q10/__init__.py b/roborock/devices/traits/b01/q10/__init__.py index fcc18173..88280069 100644 --- a/roborock/devices/traits/b01/q10/__init__.py +++ b/roborock/devices/traits/b01/q10/__init__.py @@ -36,7 +36,7 @@ async def close(self) -> None: try: await self._task except asyncio.CancelledError: - pass + pass # ignore cancellation errors self._task = None async def start_clean(self) -> None: