diff --git a/okx/websocket/WsPrivateAsync.py b/okx/websocket/WsPrivateAsync.py index 32d8679..b7e328e 100644 --- a/okx/websocket/WsPrivateAsync.py +++ b/okx/websocket/WsPrivateAsync.py @@ -1,6 +1,7 @@ import asyncio import json import logging +import warnings from okx.websocket import WsUtils from okx.websocket.WebSocketFactory import WebSocketFactory @@ -9,7 +10,7 @@ class WsPrivateAsync: - def __init__(self, apiKey, passphrase, secretKey, url, useServerTime): + def __init__(self, apiKey, passphrase, secretKey, url, useServerTime=None, debug=False): self.url = url self.subscriptions = set() self.callback = None @@ -18,15 +19,25 @@ def __init__(self, apiKey, passphrase, secretKey, url, useServerTime): self.apiKey = apiKey self.passphrase = passphrase self.secretKey = secretKey - self.useServerTime = useServerTime + self.useServerTime = False self.websocket = None + self.debug = debug + + # Set log level + if debug: + logger.setLevel(logging.DEBUG) + + # Deprecation warning for useServerTime parameter + if useServerTime is not None: + warnings.warn("useServerTime parameter is deprecated. Please remove it.", DeprecationWarning) async def connect(self): self.websocket = await self.factory.connect() async def consume(self): async for message in self.websocket: - logger.debug("Received message: {%s}", message) + if self.debug: + logger.debug("Received message: {%s}", message) if self.callback: self.callback(message) @@ -43,6 +54,8 @@ async def subscribe(self, params: list, callback, id: str = None): if id is not None: payload_dict["id"] = id payload = json.dumps(payload_dict) + if self.debug: + logger.debug(f"subscribe: {payload}") await self.websocket.send(payload) # await self.consume() @@ -53,6 +66,8 @@ async def login(self): passphrase=self.passphrase, secretKey=self.secretKey ) + if self.debug: + logger.debug(f"login: {loginPayload}") await self.websocket.send(loginPayload) return True @@ -65,16 +80,119 @@ async def unsubscribe(self, params: list, callback, id: str = None): if id is not None: payload_dict["id"] = id payload = json.dumps(payload_dict) - logger.info(f"unsubscribe: {payload}") + if self.debug: + logger.debug(f"unsubscribe: {payload}") + else: + logger.info(f"unsubscribe: {payload}") + await self.websocket.send(payload) + + async def send(self, op: str, args: list, callback=None, id: str = None): + """ + Generic send method + :param op: Operation type + :param args: Parameter list + :param callback: Callback function + :param id: Optional request ID + """ + if callback: + self.callback = callback + payload_dict = { + "op": op, + "args": args + } + if id is not None: + payload_dict["id"] = id + payload = json.dumps(payload_dict) + if self.debug: + logger.debug(f"send: {payload}") await self.websocket.send(payload) - # for param in params: - # self.subscriptions.discard(param) + + async def place_order(self, args: list, callback=None, id: str = None): + """ + Place order + :param args: Order parameter list + :param callback: Callback function + :param id: Optional request ID + """ + if callback: + self.callback = callback + await self.send("order", args, id=id) + + async def batch_orders(self, args: list, callback=None, id: str = None): + """ + Batch place orders + :param args: Batch order parameter list + :param callback: Callback function + :param id: Optional request ID + """ + if callback: + self.callback = callback + await self.send("batch-orders", args, id=id) + + async def cancel_order(self, args: list, callback=None, id: str = None): + """ + Cancel order + :param args: Cancel order parameter list + :param callback: Callback function + :param id: Optional request ID + """ + if callback: + self.callback = callback + await self.send("cancel-order", args, id=id) + + async def batch_cancel_orders(self, args: list, callback=None, id: str = None): + """ + Batch cancel orders + :param args: Batch cancel order parameter list + :param callback: Callback function + :param id: Optional request ID + """ + if callback: + self.callback = callback + await self.send("batch-cancel-orders", args, id=id) + + async def amend_order(self, args: list, callback=None, id: str = None): + """ + Amend order + :param args: Amend order parameter list + :param callback: Callback function + :param id: Optional request ID + """ + if callback: + self.callback = callback + await self.send("amend-order", args, id=id) + + async def batch_amend_orders(self, args: list, callback=None, id: str = None): + """ + Batch amend orders + :param args: Batch amend order parameter list + :param callback: Callback function + :param id: Optional request ID + """ + if callback: + self.callback = callback + await self.send("batch-amend-orders", args, id=id) + + async def mass_cancel(self, args: list, callback=None, id: str = None): + """ + Mass cancel orders + Note: This method is for /ws/v5/business channel, rate limit: 1 request/second + :param args: Cancel parameter list, contains instType and instFamily + :param callback: Callback function + :param id: Optional request ID + """ + if callback: + self.callback = callback + await self.send("mass-cancel", args, id=id) async def stop(self): await self.factory.close() async def start(self): - logger.info("Connecting to WebSocket...") + if self.debug: + logger.debug("Connecting to WebSocket...") + else: + logger.info("Connecting to WebSocket...") await self.connect() self.loop.create_task(self.consume()) diff --git a/okx/websocket/WsPublicAsync.py b/okx/websocket/WsPublicAsync.py index b625b2d..d7eefdc 100644 --- a/okx/websocket/WsPublicAsync.py +++ b/okx/websocket/WsPublicAsync.py @@ -2,29 +2,60 @@ import json import logging +from okx.websocket import WsUtils from okx.websocket.WebSocketFactory import WebSocketFactory logger = logging.getLogger(__name__) class WsPublicAsync: - def __init__(self, url): + def __init__(self, url, apiKey='', passphrase='', secretKey='', debug=False): self.url = url self.subscriptions = set() self.callback = None self.loop = asyncio.get_event_loop() self.factory = WebSocketFactory(url) self.websocket = None + self.debug = debug + # Credentials for business channel login + self.apiKey = apiKey + self.passphrase = passphrase + self.secretKey = secretKey + self.isLoggedIn = False + + # Set log level + if debug: + logger.setLevel(logging.DEBUG) async def connect(self): self.websocket = await self.factory.connect() async def consume(self): async for message in self.websocket: - logger.debug("Received message: {%s}", message) + if self.debug: + logger.debug("Received message: {%s}", message) if self.callback: self.callback(message) + async def login(self): + """ + Login method for business channel that requires authentication (e.g. /ws/v5/business) + """ + if not self.apiKey or not self.secretKey or not self.passphrase: + raise ValueError("apiKey, secretKey and passphrase are required for login") + + loginPayload = WsUtils.initLoginParams( + useServerTime=False, + apiKey=self.apiKey, + passphrase=self.passphrase, + secretKey=self.secretKey + ) + if self.debug: + logger.debug(f"login: {loginPayload}") + await self.websocket.send(loginPayload) + self.isLoggedIn = True + return True + async def subscribe(self, params: list, callback, id: str = None): self.callback = callback payload_dict = { @@ -34,6 +65,8 @@ async def subscribe(self, params: list, callback, id: str = None): if id is not None: payload_dict["id"] = id payload = json.dumps(payload_dict) + if self.debug: + logger.debug(f"subscribe: {payload}") await self.websocket.send(payload) # await self.consume() @@ -46,14 +79,41 @@ async def unsubscribe(self, params: list, callback, id: str = None): if id is not None: payload_dict["id"] = id payload = json.dumps(payload_dict) - logger.info(f"unsubscribe: {payload}") + if self.debug: + logger.debug(f"unsubscribe: {payload}") + else: + logger.info(f"unsubscribe: {payload}") + await self.websocket.send(payload) + + async def send(self, op: str, args: list, callback=None, id: str = None): + """ + Generic send method + :param op: Operation type + :param args: Parameter list + :param callback: Callback function + :param id: Optional request ID + """ + if callback: + self.callback = callback + payload_dict = { + "op": op, + "args": args + } + if id is not None: + payload_dict["id"] = id + payload = json.dumps(payload_dict) + if self.debug: + logger.debug(f"send: {payload}") await self.websocket.send(payload) async def stop(self): await self.factory.close() async def start(self): - logger.info("Connecting to WebSocket...") + if self.debug: + logger.debug("Connecting to WebSocket...") + else: + logger.info("Connecting to WebSocket...") await self.connect() self.loop.create_task(self.consume()) diff --git a/test/test_ws_private_async.py b/test/test_ws_private_async.py index 1244387..10f34bf 100644 --- a/test/test_ws_private_async.py +++ b/test/test_ws_private_async.py @@ -16,7 +16,7 @@ async def main(): passphrase=passphrase, secretKey=api_secret_key, url=url, - useServerTime=False + debug=True ) await ws.start() args = [] @@ -41,11 +41,238 @@ async def main(): await asyncio.sleep(5) print("-----------------------------------------unsubscribe all--------------------------------------------") args3 = [arg1, arg3] - await ws.unsubscribe(args3, callback=privateCallback, id="privateUnsub002") + await ws.unsubscribe(args3, callback=privateCallback) await asyncio.sleep(1) - # Properly close websocket connection + await ws.stop() + + +async def test_place_order(): + """ + Test place order functionality + URL: /ws/v5/private (Rate limit: 60 requests/second) + """ + api_key, api_secret_key, passphrase, _ = get_api_credentials() + url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" + ws = WsPrivateAsync( + apiKey=api_key, + passphrase=passphrase, + secretKey=api_secret_key, + url=url, + debug=True + ) + await ws.start() + await ws.login() + await asyncio.sleep(5) + + # Order parameters + order_args = [{ + "instId": "BTC-USDT", + "tdMode": "cash", + "clOrdId": "client_order_001", + "side": "buy", + "ordType": "limit", + "sz": "0.001", + "px": "30000" + }] + await ws.place_order(order_args, callback=privateCallback, id="order001") + await asyncio.sleep(5) + await ws.stop() + + +async def test_batch_orders(): + """ + Test batch orders functionality + URL: /ws/v5/private (Rate limit: 60 requests/second, max 20 orders) + """ + api_key, api_secret_key, passphrase, _ = get_api_credentials() + url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" + ws = WsPrivateAsync( + apiKey=api_key, + passphrase=passphrase, + secretKey=api_secret_key, + url=url, + debug=True + ) + await ws.start() + await ws.login() + await asyncio.sleep(5) + + # Batch order parameters (max 20) + order_args = [ + { + "instId": "BTC-USDT", + "tdMode": "cash", + "clOrdId": "batch_order_001", + "side": "buy", + "ordType": "limit", + "sz": "0.001", + "px": "30000" + }, + { + "instId": "ETH-USDT", + "tdMode": "cash", + "clOrdId": "batch_order_002", + "side": "buy", + "ordType": "limit", + "sz": "0.01", + "px": "2000" + } + ] + await ws.batch_orders(order_args, callback=privateCallback, id="batchOrder001") + await asyncio.sleep(5) + await ws.stop() + + +async def test_cancel_order(): + """ + Test cancel order functionality + URL: /ws/v5/private (Rate limit: 60 requests/second) + """ + api_key, api_secret_key, passphrase, _ = get_api_credentials() + url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" + ws = WsPrivateAsync( + apiKey=api_key, + passphrase=passphrase, + secretKey=api_secret_key, + url=url, + debug=True + ) + await ws.start() + await ws.login() + await asyncio.sleep(5) + + # Cancel order parameters (either ordId or clOrdId must be provided) + cancel_args = [{ + "instId": "BTC-USDT", + "ordId": "your_order_id" + # Or use "clOrdId": "client_order_001" + }] + await ws.cancel_order(cancel_args, callback=privateCallback, id="cancel001") + await asyncio.sleep(5) + await ws.stop() + + +async def test_batch_cancel_orders(): + """ + Test batch cancel orders functionality + URL: /ws/v5/private (Rate limit: 60 requests/second, max 20 orders) + """ + api_key, api_secret_key, passphrase, _ = get_api_credentials() + url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" + ws = WsPrivateAsync( + apiKey=api_key, + passphrase=passphrase, + secretKey=api_secret_key, + url=url, + debug=True + ) + await ws.start() + await ws.login() + await asyncio.sleep(5) + + cancel_args = [ + {"instId": "BTC-USDT", "ordId": "order_id_1"}, + {"instId": "ETH-USDT", "ordId": "order_id_2"} + ] + await ws.batch_cancel_orders(cancel_args, callback=privateCallback, id="batchCancel001") + await asyncio.sleep(5) + await ws.stop() + + +async def test_amend_order(): + """ + Test amend order functionality + URL: /ws/v5/private (Rate limit: 60 requests/second) + """ + api_key, api_secret_key, passphrase, _ = get_api_credentials() + url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" + ws = WsPrivateAsync( + apiKey=api_key, + passphrase=passphrase, + secretKey=api_secret_key, + url=url, + debug=True + ) + await ws.start() + await ws.login() + await asyncio.sleep(5) + + # Amend order parameters + amend_args = [{ + "instId": "BTC-USDT", + "ordId": "your_order_id", + "newSz": "0.002", + "newPx": "31000" + }] + await ws.amend_order(amend_args, callback=privateCallback, id="amend001") + await asyncio.sleep(5) + await ws.stop() + + +async def test_mass_cancel(): + """ + Test mass cancel functionality + URL: /ws/v5/business (Rate limit: 1 request/second) + Note: This function uses the business channel + """ + api_key, api_secret_key, passphrase, _ = get_api_credentials() + url = "wss://wspap.okx.com:8443/ws/v5/business?brokerId=9999" + ws = WsPrivateAsync( + apiKey=api_key, + passphrase=passphrase, + secretKey=api_secret_key, + url=url, + debug=True + ) + await ws.start() + await ws.login() + await asyncio.sleep(5) + + # Mass cancel parameters + mass_cancel_args = [{ + "instType": "SPOT", + "instFamily": "BTC-USDT" + }] + await ws.mass_cancel(mass_cancel_args, callback=privateCallback, id="massCancel001") + await asyncio.sleep(5) + await ws.stop() + + +async def test_send_method(): + """Test generic send method""" + api_key, api_secret_key, passphrase, _ = get_api_credentials() + url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" + ws = WsPrivateAsync( + apiKey=api_key, + passphrase=passphrase, + secretKey=api_secret_key, + url=url, + debug=True + ) + await ws.start() + await ws.login() + await asyncio.sleep(5) + + # Use generic send method to place order - callback must be provided to receive response + order_args = [{ + "instId": "BTC-USDT", + "tdMode": "cash", + "side": "buy", + "ordType": "limit", + "sz": "0.001", + "px": "30000" + }] + await ws.send("order", order_args, callback=privateCallback, id="send001") + await asyncio.sleep(5) await ws.stop() if __name__ == '__main__': - asyncio.run(main()) + # asyncio.run(main()) + asyncio.run(test_place_order()) + asyncio.run(test_batch_orders()) + asyncio.run(test_cancel_order()) + asyncio.run(test_batch_cancel_orders()) + asyncio.run(test_amend_order()) + asyncio.run(test_mass_cancel()) # Note: uses business channel + asyncio.run(test_send_method()) diff --git a/test/test_ws_public_async.py b/test/test_ws_public_async.py index 1ed1172..dac6c84 100644 --- a/test/test_ws_public_async.py +++ b/test/test_ws_public_async.py @@ -8,10 +8,9 @@ def publicCallback(message): async def main(): - # url = "wss://wspap.okex.com:8443/ws/v5/public?brokerId=9999" url = "wss://wspap.okx.com:8443/ws/v5/public?brokerId=9999" - ws = WsPublicAsync(url=url) + ws = WsPublicAsync(url=url, debug=True) # Enable debug logging await ws.start() args = [] arg1 = {"channel": "instruments", "instType": "FUTURES"} @@ -34,9 +33,49 @@ async def main(): args3 = [arg1, arg2, arg3] await ws.unsubscribe(args3, publicCallback) await asyncio.sleep(1) - # Properly close websocket connection + await ws.stop() + + +async def test_business_channel_with_login(): + """ + Test business channel login functionality + Business channel requires login to subscribe to certain private data + """ + url = "wss://wspap.okx.com:8443/ws/v5/business?brokerId=9999" + ws = WsPublicAsync( + url=url, + apiKey="your apiKey", + passphrase="your passphrase", + secretKey="your secretKey", + debug=True + ) + await ws.start() + + # Login + await ws.login() + await asyncio.sleep(5) + + # Subscribe to channels that require login + args = [{"channel": "candle1m", "instId": "BTC-USDT"}] + await ws.subscribe(args, publicCallback) + await asyncio.sleep(30) + await ws.stop() + + +async def test_send_method(): + """Test generic send method""" + url = "wss://wspap.okx.com:8443/ws/v5/public?brokerId=9999" + ws = WsPublicAsync(url=url, debug=True) + await ws.start() + + # Use generic send method to subscribe - callback must be provided to receive response + args = [{"channel": "tickers", "instId": "BTC-USDT"}] + await ws.send("subscribe", args, callback=publicCallback, id="send001") + await asyncio.sleep(10) await ws.stop() if __name__ == '__main__': - asyncio.run(main()) + # asyncio.run(main()) + # asyncio.run(test_business_channel_with_login()) + asyncio.run(test_send_method()) diff --git a/test/unit/okx/websocket/__init__.py b/test/unit/okx/websocket/__init__.py index 98f2807..b0061bd 100644 --- a/test/unit/okx/websocket/__init__.py +++ b/test/unit/okx/websocket/__init__.py @@ -1 +1,2 @@ -"""Unit tests for okx.websocket package""" +# Unit tests for okx.websocket module + diff --git a/test/unit/okx/websocket/test_ws_private_async.py b/test/unit/okx/websocket/test_ws_private_async.py index da367f8..531f5a8 100644 --- a/test/unit/okx/websocket/test_ws_private_async.py +++ b/test/unit/okx/websocket/test_ws_private_async.py @@ -6,6 +6,7 @@ import json import unittest import asyncio +import warnings from unittest.mock import patch, MagicMock, AsyncMock # Import the module first so patch can resolve the path @@ -23,8 +24,7 @@ def test_init_with_required_params(self): apiKey="test_api_key", passphrase="test_passphrase", secretKey="test_secret_key", - url="wss://test.example.com", - useServerTime=False + url="wss://test.example.com" ) self.assertEqual(ws.apiKey, "test_api_key") @@ -32,8 +32,60 @@ def test_init_with_required_params(self): self.assertEqual(ws.secretKey, "test_secret_key") self.assertEqual(ws.url, "wss://test.example.com") self.assertFalse(ws.useServerTime) + self.assertFalse(ws.debug) mock_factory.assert_called_once_with("wss://test.example.com") + def test_init_with_debug_enabled(self): + """Test initialization with debug mode enabled""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com", + debug=True + ) + + self.assertTrue(ws.debug) + + def test_init_with_deprecated_useServerTime_shows_warning(self): + """Test that using deprecated useServerTime parameter shows warning""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + from okx.websocket.WsPrivateAsync import WsPrivateAsync + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com", + useServerTime=True + ) + + self.assertEqual(len(w), 1) + self.assertTrue(issubclass(w[0].category, DeprecationWarning)) + self.assertIn("useServerTime parameter is deprecated", str(w[0].message)) + + def test_init_without_useServerTime_no_warning(self): + """Test that not using useServerTime parameter shows no warning""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + from okx.websocket.WsPrivateAsync import WsPrivateAsync + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com" + ) + + # No deprecation warning expected + deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] + self.assertEqual(len(deprecation_warnings), 0) + class TestWsPrivateAsyncSubscribe(unittest.TestCase): """Unit tests for WsPrivateAsync subscribe method""" @@ -45,13 +97,12 @@ def test_subscribe_sends_correct_payload(self): patch.object(ws_private_module.asyncio, 'sleep', new_callable=AsyncMock): mock_ws_utils.initLoginParams.return_value = '{"op":"login"}' - + ws = WsPrivateAsync( apiKey="test_api_key", passphrase="test_passphrase", secretKey="test_secret_key", - url="wss://test.example.com", - useServerTime=False + url="wss://test.example.com" ) mock_websocket = AsyncMock() ws.websocket = mock_websocket @@ -82,8 +133,7 @@ def test_subscribe_with_id(self): apiKey="test_api_key", passphrase="test_passphrase", secretKey="test_secret_key", - url="wss://test.example.com", - useServerTime=False + url="wss://test.example.com" ) mock_websocket = AsyncMock() ws.websocket = mock_websocket @@ -111,8 +161,7 @@ def test_unsubscribe_sends_correct_payload(self): apiKey="test_api_key", passphrase="test_passphrase", secretKey="test_secret_key", - url="wss://test.example.com", - useServerTime=False + url="wss://test.example.com" ) mock_websocket = AsyncMock() ws.websocket = mock_websocket @@ -136,8 +185,7 @@ def test_unsubscribe_with_id(self): apiKey="test_api_key", passphrase="test_passphrase", secretKey="test_secret_key", - url="wss://test.example.com", - useServerTime=False + url="wss://test.example.com" ) mock_websocket = AsyncMock() ws.websocket = mock_websocket @@ -154,6 +202,326 @@ async def run_test(): asyncio.get_event_loop().run_until_complete(run_test()) +class TestWsPrivateAsyncSend(unittest.TestCase): + """Unit tests for WsPrivateAsync generic send method""" + + def test_send_without_id(self): + """Test generic send method without id""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com" + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + args = [{"instId": "BTC-USDT"}] + + async def run_test(): + await ws.send("custom_op", args, callback=callback) + self.assertEqual(ws.callback, callback) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "custom_op") + self.assertEqual(payload["args"], args) + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_send_with_id(self): + """Test generic send method with id""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com" + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + args = [{"instId": "BTC-USDT"}] + + async def run_test(): + await ws.send("custom_op", args, id="send001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "custom_op") + self.assertEqual(payload["id"], "send001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPrivateAsyncOrderMethods(unittest.TestCase): + """Unit tests for WsPrivateAsync order-related methods""" + + def _create_ws_instance(self): + """Helper to create WsPrivateAsync instance with mocked websocket""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com" + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + return ws, mock_websocket + + def test_place_order_sends_correct_payload(self): + """Test place_order sends correct operation""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + callback = MagicMock() + order_args = [{ + "instId": "BTC-USDT", + "tdMode": "cash", + "side": "buy", + "ordType": "limit", + "sz": "0.001", + "px": "30000" + }] + + async def run_test(): + await ws.place_order(order_args, callback=callback, id="order001") + self.assertEqual(ws.callback, callback) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "order") + self.assertEqual(payload["args"], order_args) + self.assertEqual(payload["id"], "order001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_place_order_without_id(self): + """Test place_order without id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + order_args = [{"instId": "BTC-USDT"}] + + async def run_test(): + await ws.place_order(order_args) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "order") + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_batch_orders_sends_correct_payload(self): + """Test batch_orders sends correct operation""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + callback = MagicMock() + order_args = [ + {"instId": "BTC-USDT", "side": "buy", "sz": "0.001", "px": "30000"}, + {"instId": "ETH-USDT", "side": "buy", "sz": "0.01", "px": "2000"} + ] + + async def run_test(): + await ws.batch_orders(order_args, callback=callback, id="batch001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "batch-orders") + self.assertEqual(payload["args"], order_args) + self.assertEqual(payload["id"], "batch001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_batch_orders_without_id(self): + """Test batch_orders without id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + order_args = [{"instId": "BTC-USDT"}, {"instId": "ETH-USDT"}] + + async def run_test(): + await ws.batch_orders(order_args) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "batch-orders") + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_cancel_order_sends_correct_payload(self): + """Test cancel_order sends correct operation""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + callback = MagicMock() + cancel_args = [{"instId": "BTC-USDT", "ordId": "12345"}] + + async def run_test(): + await ws.cancel_order(cancel_args, callback=callback, id="cancel001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "cancel-order") + self.assertEqual(payload["args"], cancel_args) + self.assertEqual(payload["id"], "cancel001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_cancel_order_without_id(self): + """Test cancel_order without id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + cancel_args = [{"instId": "BTC-USDT", "ordId": "12345"}] + + async def run_test(): + await ws.cancel_order(cancel_args) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "cancel-order") + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_batch_cancel_orders_sends_correct_payload(self): + """Test batch_cancel_orders sends correct operation""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + callback = MagicMock() + cancel_args = [ + {"instId": "BTC-USDT", "ordId": "12345"}, + {"instId": "ETH-USDT", "ordId": "67890"} + ] + + async def run_test(): + await ws.batch_cancel_orders(cancel_args, callback=callback, id="batchCancel001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "batch-cancel-orders") + self.assertEqual(payload["args"], cancel_args) + self.assertEqual(payload["id"], "batchCancel001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_batch_cancel_orders_without_id(self): + """Test batch_cancel_orders without id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + cancel_args = [{"instId": "BTC-USDT", "ordId": "12345"}] + + async def run_test(): + await ws.batch_cancel_orders(cancel_args) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "batch-cancel-orders") + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_amend_order_sends_correct_payload(self): + """Test amend_order sends correct operation""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + callback = MagicMock() + amend_args = [{ + "instId": "BTC-USDT", + "ordId": "12345", + "newSz": "0.002", + "newPx": "31000" + }] + + async def run_test(): + await ws.amend_order(amend_args, callback=callback, id="amend001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "amend-order") + self.assertEqual(payload["args"], amend_args) + self.assertEqual(payload["id"], "amend001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_amend_order_without_id(self): + """Test amend_order without id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + amend_args = [{"instId": "BTC-USDT", "ordId": "12345", "newSz": "0.002"}] + + async def run_test(): + await ws.amend_order(amend_args) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "amend-order") + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_batch_amend_orders_sends_correct_payload(self): + """Test batch_amend_orders sends correct operation""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + callback = MagicMock() + amend_args = [ + {"instId": "BTC-USDT", "ordId": "12345", "newSz": "0.002"}, + {"instId": "ETH-USDT", "ordId": "67890", "newPx": "2100"} + ] + + async def run_test(): + await ws.batch_amend_orders(amend_args, callback=callback, id="batchAmend001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "batch-amend-orders") + self.assertEqual(payload["args"], amend_args) + self.assertEqual(payload["id"], "batchAmend001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_batch_amend_orders_without_id(self): + """Test batch_amend_orders without id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + amend_args = [{"instId": "BTC-USDT", "ordId": "12345", "newSz": "0.002"}] + + async def run_test(): + await ws.batch_amend_orders(amend_args) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "batch-amend-orders") + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_mass_cancel_sends_correct_payload(self): + """Test mass_cancel sends correct operation""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + callback = MagicMock() + mass_cancel_args = [{ + "instType": "SPOT", + "instFamily": "BTC-USDT" + }] + + async def run_test(): + await ws.mass_cancel(mass_cancel_args, callback=callback, id="massCancel001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "mass-cancel") + self.assertEqual(payload["args"], mass_cancel_args) + self.assertEqual(payload["id"], "massCancel001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_mass_cancel_without_id(self): + """Test mass_cancel without id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + mass_cancel_args = [{"instType": "SPOT", "instFamily": "BTC-USDT"}] + + async def run_test(): + await ws.mass_cancel(mass_cancel_args) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "mass-cancel") + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + class TestWsPrivateAsyncLogin(unittest.TestCase): """Unit tests for WsPrivateAsync login method""" @@ -163,13 +531,12 @@ def test_login_calls_init_login_params(self): patch.object(ws_private_module, 'WsUtils') as mock_ws_utils: mock_ws_utils.initLoginParams.return_value = '{"op":"login","args":[...]}' - + ws = WsPrivateAsync( apiKey="test_api_key", passphrase="test_passphrase", secretKey="test_secret_key", - url="wss://test.example.com", - useServerTime=True + url="wss://test.example.com" ) mock_websocket = AsyncMock() ws.websocket = mock_websocket @@ -178,7 +545,7 @@ async def run_test(): result = await ws.login() self.assertTrue(result) mock_ws_utils.initLoginParams.assert_called_once_with( - useServerTime=True, + useServerTime=False, apiKey="test_api_key", passphrase="test_passphrase", secretKey="test_secret_key" @@ -201,8 +568,7 @@ def test_stop(self): apiKey="test_api_key", passphrase="test_passphrase", secretKey="test_secret_key", - url="wss://test.example.com", - useServerTime=False + url="wss://test.example.com" ) async def run_test(): diff --git a/test/unit/okx/websocket/test_ws_public_async.py b/test/unit/okx/websocket/test_ws_public_async.py index 916a0b9..ca27b38 100644 --- a/test/unit/okx/websocket/test_ws_public_async.py +++ b/test/unit/okx/websocket/test_ws_public_async.py @@ -22,17 +22,99 @@ def test_init_with_url(self): ws = WsPublicAsync(url="wss://test.example.com") self.assertEqual(ws.url, "wss://test.example.com") - self.assertIsNone(ws.callback) - self.assertIsNone(ws.websocket) - mock_factory.assert_called_once_with("wss://test.example.com") + self.assertEqual(ws.apiKey, '') + self.assertEqual(ws.passphrase, '') + self.assertEqual(ws.secretKey, '') + self.assertFalse(ws.debug) + self.assertFalse(ws.isLoggedIn) + + def test_init_with_credentials(self): + """Test initialization with all credentials for business channel""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync( + url="wss://test.example.com", + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key" + ) + + self.assertEqual(ws.apiKey, "test_api_key") + self.assertEqual(ws.passphrase, "test_passphrase") + self.assertEqual(ws.secretKey, "test_secret_key") + + def test_init_with_debug_enabled(self): + """Test initialization with debug mode enabled""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com", debug=True) + + self.assertTrue(ws.debug) + + def test_init_with_debug_disabled(self): + """Test initialization with debug mode disabled (default)""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com", debug=False) + + self.assertFalse(ws.debug) + + +class TestWsPublicAsyncLogin(unittest.TestCase): + """Unit tests for WsPublicAsync login method""" + + def test_login_without_credentials_raises_error(self): + """Test that login raises ValueError when credentials are missing""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + + async def run_test(): + with self.assertRaises(ValueError) as context: + await ws.login() + self.assertIn("apiKey, secretKey and passphrase are required for login", str(context.exception)) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_login_with_credentials_success(self): + """Test successful login with valid credentials""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory') as mock_factory, \ + patch('okx.websocket.WsPublicAsync.WsUtils.initLoginParams') as mock_init_login: + + mock_init_login.return_value = '{"op":"login","args":[...]}' + + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync( + url="wss://test.example.com", + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key" + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + + async def run_test(): + result = await ws.login() + self.assertTrue(result) + self.assertTrue(ws.isLoggedIn) + mock_init_login.assert_called_once_with( + useServerTime=False, + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key" + ) + mock_websocket.send.assert_called_once() + + asyncio.get_event_loop().run_until_complete(run_test()) class TestWsPublicAsyncSubscribe(unittest.TestCase): """Unit tests for WsPublicAsync subscribe method""" - def test_subscribe_sets_callback(self): - """Test subscribe sets callback correctly""" - with patch.object(ws_public_module, 'WebSocketFactory'): + def test_subscribe_without_id(self): + """Test subscribe without id parameter""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync ws = WsPublicAsync(url="wss://test.example.com") mock_websocket = AsyncMock() ws.websocket = mock_websocket @@ -137,6 +219,85 @@ async def run_test(): asyncio.get_event_loop().run_until_complete(run_test()) +class TestWsPublicAsyncSend(unittest.TestCase): + """Unit tests for WsPublicAsync send method""" + + def test_send_without_id(self): + """Test generic send method without id""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + args = [{"instId": "BTC-USDT"}] + + async def run_test(): + await ws.send("custom_op", args, callback=callback) + self.assertEqual(ws.callback, callback) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "custom_op") + self.assertEqual(payload["args"], args) + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_send_with_id(self): + """Test generic send method with id""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + args = [{"instId": "BTC-USDT"}] + + async def run_test(): + await ws.send("custom_op", args, id="send001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "custom_op") + self.assertEqual(payload["id"], "send001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_send_without_callback(self): + """Test send method without callback (preserves existing callback)""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + existing_callback = MagicMock() + ws.callback = existing_callback + args = [{"instId": "BTC-USDT"}] + + async def run_test(): + await ws.send("custom_op", args) + # Callback should remain unchanged + self.assertEqual(ws.callback, existing_callback) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_send_with_new_callback_replaces_existing(self): + """Test send method with new callback replaces existing callback""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + old_callback = MagicMock() + new_callback = MagicMock() + ws.callback = old_callback + args = [{"instId": "BTC-USDT"}] + + async def run_test(): + await ws.send("custom_op", args, callback=new_callback) + self.assertEqual(ws.callback, new_callback) + + asyncio.get_event_loop().run_until_complete(run_test()) + + class TestWsPublicAsyncStartStop(unittest.TestCase): """Unit tests for WsPublicAsync start and stop methods"""