From 91fee83b87c21a3f5da239e064e89ea64e70ea5e Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Mon, 15 Dec 2025 18:23:21 +0800 Subject: [PATCH 01/48] add set auto earn endpoint --- okx/Account.py | 4 ++++ okx/__init__.py | 2 +- okx/consts.py | 1 + test/AccountTest.py | 8 +++++--- 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/okx/Account.py b/okx/Account.py index 911cf9b..74c84a5 100644 --- a/okx/Account.py +++ b/okx/Account.py @@ -323,3 +323,7 @@ def set_auto_repay(self, autoRepay=False): def spot_borrow_repay_history(self, ccy='', type='', after='', before='', limit=''): params = {'ccy': ccy, 'type': type, 'after': after, 'before': before, 'limit': limit} return self._request_with_params(GET, GET_BORROW_REPAY_HISTORY, params) + + def set_auto_earn(self, earnType='', ccy='', action='', apr=''): + params = {'earnType': earnType, 'ccy': ccy, 'action': action, 'apr': apr} + return self._request_with_params(POST, SET_AUTO_EARN, params) diff --git a/okx/__init__.py b/okx/__init__.py index 2dbeeb8..57f92fe 100644 --- a/okx/__init__.py +++ b/okx/__init__.py @@ -2,4 +2,4 @@ Python SDK for the OKX API v5 """ -__version__="0.4.0" \ No newline at end of file +__version__="0.4.1" \ No newline at end of file diff --git a/okx/consts.py b/okx/consts.py index aae0778..54bb5f4 100644 --- a/okx/consts.py +++ b/okx/consts.py @@ -66,6 +66,7 @@ MANUAL_REBORROW_REPAY = '/api/v5/account/spot-manual-borrow-repay' SET_AUTO_REPAY='/api/v5/account/set-auto-repay' GET_BORROW_REPAY_HISTORY='/api/v5/account/spot-borrow-repay-history' +SET_AUTO_EARN='/api/v5/account/set-auto-earn' # Funding NON_TRADABLE_ASSETS = '/api/v5/asset/non-tradable-assets' diff --git a/test/AccountTest.py b/test/AccountTest.py index 724112b..86260a0 100644 --- a/test/AccountTest.py +++ b/test/AccountTest.py @@ -7,9 +7,9 @@ class AccountTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' + api_key = 'da097c9c-2f77-4dea-be18-2bfa77d0e394' + api_secret_key = '56CC6C72D6B8A46EC993D48C83142A25' + passphrase = '123456aA.' self.AccountAPI = Account.AccountAPI(api_key, api_secret_key, passphrase, flag='1') # ''' @@ -146,6 +146,8 @@ def setUp(self): # logger.info(f'{self.AccountAPI.set_auto_repay(autoRepay=True)}') # def test_spot_borrow_repay_history(self): # logger.debug(self.AccountAPI.spot_borrow_repay_history(ccy="USDT",type="auto_borrow",after="1597026383085")) + def test_set_auto_earn(self): + logger.debug(self.AccountAPI.set_auto_earn(earnType='0',ccy="USDT",action="turn_on")) if __name__ == '__main__': unittest.main() From 7a4fa53efd11975570d3dea368cc0697a200235f Mon Sep 17 00:00:00 2001 From: "jason.hsu" Date: Tue, 16 Dec 2025 10:26:08 +0800 Subject: [PATCH 02/48] feat: add idxVol and fix wrong params for posBuilder --- okx/Account.py | 8 +++++--- test/__init__.py | 0 2 files changed, 5 insertions(+), 3 deletions(-) create mode 100644 test/__init__.py diff --git a/okx/Account.py b/okx/Account.py index 911cf9b..8de8532 100644 --- a/okx/Account.py +++ b/okx/Account.py @@ -28,20 +28,22 @@ def get_positions(self, instType='', instId='', posId=''): return self._request_with_params(GET, POSITION_INFO, params) def position_builder(self, acctLv=None,inclRealPosAndEq=False, lever=None, greeksType=None, simPos=None, - simAsset=None): + simAsset=None, idxVol=None): params = {} if acctLv is not None: params['acctLv'] = acctLv if inclRealPosAndEq is not None: params['inclRealPosAndEq'] = inclRealPosAndEq if lever is not None: - params['spotOffsetType'] = lever + params['lever'] = lever if greeksType is not None: - params['greksType'] = greeksType + params['greeksType'] = greeksType if simPos is not None: params['simPos'] = simPos if simAsset is not None: params['simAsset'] = simAsset + if idxVol is not None: + params['idxVol'] = idxVol return self._request_with_params(POST, POSITION_BUILDER, params) # Get Bills Details (recent 7 days) diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..e69de29 From f8847a91dd455775bfe9f2497c8c5160da209f8c Mon Sep 17 00:00:00 2001 From: "jason.hsu" Date: Tue, 16 Dec 2025 10:55:37 +0800 Subject: [PATCH 03/48] feat: add unit-testing for the test_account --- okx/Account.py | 2 +- test/unit/__init__.py | 10 + test/unit/okx/__init__.py | 2 + test/unit/okx/test_account.py | 441 ++++++++++++++++++++++++++++++++++ 4 files changed, 454 insertions(+), 1 deletion(-) create mode 100644 test/unit/__init__.py create mode 100644 test/unit/okx/__init__.py create mode 100644 test/unit/okx/test_account.py diff --git a/okx/Account.py b/okx/Account.py index 8de8532..e3952ff 100644 --- a/okx/Account.py +++ b/okx/Account.py @@ -27,7 +27,7 @@ def get_positions(self, instType='', instId='', posId=''): params = {'instType': instType, 'instId': instId, 'posId': posId} return self._request_with_params(GET, POSITION_INFO, params) - def position_builder(self, acctLv=None,inclRealPosAndEq=False, lever=None, greeksType=None, simPos=None, + def position_builder(self, acctLv=None, inclRealPosAndEq=None, lever=None, greeksType=None, simPos=None, simAsset=None, idxVol=None): params = {} if acctLv is not None: diff --git a/test/unit/__init__.py b/test/unit/__init__.py new file mode 100644 index 0000000..940cd5e --- /dev/null +++ b/test/unit/__init__.py @@ -0,0 +1,10 @@ +""" +Unit tests package + +Unit tests mirror the source code structure for easy navigation. + +Example: + okx/Account.py -> test/unit/okx/test_account.py + okx/Trade.py -> test/unit/okx/test_trade.py + okx/Finance/Savings.py -> test/unit/okx/Finance/test_savings.py +""" diff --git a/test/unit/okx/__init__.py b/test/unit/okx/__init__.py new file mode 100644 index 0000000..2b43621 --- /dev/null +++ b/test/unit/okx/__init__.py @@ -0,0 +1,2 @@ +"""Unit tests for okx package""" + diff --git a/test/unit/okx/test_account.py b/test/unit/okx/test_account.py new file mode 100644 index 0000000..6e5f0c4 --- /dev/null +++ b/test/unit/okx/test_account.py @@ -0,0 +1,441 @@ +""" +Unit tests for okx.Account module + +Mirrors the structure: okx/Account.py -> test/unit/okx/test_account.py +""" +import unittest +from unittest.mock import patch +from okx.Account import AccountAPI +from okx import consts as c + + +class TestAccountAPIPositionBuilder(unittest.TestCase): + """Unit tests for the position_builder method""" + + def setUp(self): + """Set up test fixtures""" + self.api_key = 'test_api_key' + self.api_secret = 'test_api_secret' + self.passphrase = 'test_passphrase' + self.account_api = AccountAPI( + api_key=self.api_key, + api_secret_key=self.api_secret, + passphrase=self.passphrase, + flag='0' + ) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_with_all_parameters(self, mock_request): + """Test position_builder with all parameters provided""" + # Arrange + mock_response = { + 'code': '0', + 'msg': '', + 'data': [{ + 'mmr': '1000', + 'imr': '2000', + 'mmrBf': '900', + 'imrBf': '1900' + }] + } + mock_request.return_value = mock_response + + sim_pos = [{'instId': 'BTC-USDT-SWAP', 'pos': '10', 'avgPx': '50000'}] + sim_asset = [{'ccy': 'USDT', 'amt': '10000'}] + + # Act + result = self.account_api.position_builder( + acctLv='2', + inclRealPosAndEq=True, + lever='5', + greeksType='PA', + simPos=sim_pos, + simAsset=sim_asset, + idxVol='0.05' + ) + + # Assert + expected_params = { + 'acctLv': '2', + 'inclRealPosAndEq': True, + 'lever': '5', + 'greeksType': 'PA', + 'simPos': sim_pos, + 'simAsset': sim_asset, + 'idxVol': '0.05' + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + self.assertEqual(result, mock_response) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_with_idxVol_only(self, mock_request): + """Test position_builder with only idxVol parameter""" + # Arrange + mock_response = { + 'code': '0', + 'msg': '', + 'data': [] + } + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(idxVol='0.1') + + # Assert + expected_params = { + 'idxVol': '0.1' + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + self.assertEqual(result, mock_response) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_negative_idxVol(self, mock_request): + """Test position_builder with negative idxVol (price decrease)""" + # Arrange + mock_response = { + 'code': '0', + 'msg': '', + 'data': [] + } + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(idxVol='-0.05') + + # Assert + expected_params = { + 'idxVol': '-0.05' + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + self.assertEqual(result, mock_response) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_with_no_parameters(self, mock_request): + """Test position_builder with no parameters (all None)""" + # Arrange + mock_response = { + 'code': '0', + 'msg': '', + 'data': [] + } + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder() + + # Assert + # Should pass empty params dict + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, {}) + self.assertEqual(result, mock_response) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_with_simulated_positions(self, mock_request): + """Test position_builder with simulated positions and assets""" + # Arrange + mock_response = { + 'code': '0', + 'msg': '', + 'data': [{ + 'mmr': '5000', + 'imr': '10000' + }] + } + mock_request.return_value = mock_response + + sim_pos = [ + {'instId': 'BTC-USDT-SWAP', 'pos': '10', 'avgPx': '50000'}, + {'instId': 'ETH-USDT-SWAP', 'pos': '100', 'avgPx': '3000'} + ] + sim_asset = [ + {'ccy': 'USDT', 'amt': '100000'}, + {'ccy': 'BTC', 'amt': '1'} + ] + + # Act + result = self.account_api.position_builder( + inclRealPosAndEq=False, + simPos=sim_pos, + simAsset=sim_asset, + idxVol='0.1' + ) + + # Assert + expected_params = { + 'inclRealPosAndEq': False, + 'simPos': sim_pos, + 'simAsset': sim_asset, + 'idxVol': '0.1' + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + self.assertEqual(result, mock_response) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_greeks_type_pa(self, mock_request): + """Test position_builder with greeksType PA""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(greeksType='PA') + + # Assert + expected_params = {'greeksType': 'PA'} + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_greeks_type_bs(self, mock_request): + """Test position_builder with greeksType BS""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(greeksType='BS') + + # Assert + expected_params = {'greeksType': 'BS'} + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_includes_real_positions(self, mock_request): + """Test position_builder with inclRealPosAndEq=True""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder( + inclRealPosAndEq=True, + idxVol='0.05' + ) + + # Assert + expected_params = { + 'inclRealPosAndEq': True, + 'idxVol': '0.05' + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_excludes_real_positions(self, mock_request): + """Test position_builder with inclRealPosAndEq=False (only virtual positions)""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + sim_pos = [{'instId': 'BTC-USDT-SWAP', 'pos': '5', 'avgPx': '60000'}] + + # Act + result = self.account_api.position_builder( + inclRealPosAndEq=False, + simPos=sim_pos + ) + + # Assert + expected_params = { + 'inclRealPosAndEq': False, + 'simPos': sim_pos + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_with_account_level(self, mock_request): + """Test position_builder with specific account level""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(acctLv='3') + + # Assert + expected_params = {'acctLv': '3'} + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_with_leverage(self, mock_request): + """Test position_builder with leverage parameter""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(lever='10') + + # Assert + expected_params = {'lever': '10'} + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_extreme_volatility_positive(self, mock_request): + """Test position_builder with maximum positive volatility""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(idxVol='1') + + # Assert + expected_params = {'idxVol': '1'} + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_extreme_volatility_negative(self, mock_request): + """Test position_builder with maximum negative volatility""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(idxVol='-0.99') + + # Assert + expected_params = {'idxVol': '-0.99'} + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_complex_scenario(self, mock_request): + """Test position_builder with a complex realistic scenario""" + # Arrange + mock_response = { + 'code': '0', + 'msg': '', + 'data': [{ + 'mmr': '15000', + 'imr': '30000', + 'mmrBf': '14000', + 'imrBf': '28000', + 'markPxBf': '49500' + }] + } + mock_request.return_value = mock_response + + sim_pos = [ + {'instId': 'BTC-USDT-SWAP', 'pos': '10', 'avgPx': '50000'}, + {'instId': 'ETH-USDT-SWAP', 'pos': '-50', 'avgPx': '3000'} + ] + sim_asset = [{'ccy': 'USDT', 'amt': '50000'}] + + # Act - Simulate a 5% market drop + result = self.account_api.position_builder( + acctLv='2', + inclRealPosAndEq=False, + lever='5', + greeksType='PA', + simPos=sim_pos, + simAsset=sim_asset, + idxVol='-0.05' + ) + + # Assert + expected_params = { + 'acctLv': '2', + 'inclRealPosAndEq': False, + 'lever': '5', + 'greeksType': 'PA', + 'simPos': sim_pos, + 'simAsset': sim_asset, + 'idxVol': '-0.05' + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + self.assertEqual(result['code'], '0') + self.assertIn('mmrBf', result['data'][0]) + self.assertIn('imrBf', result['data'][0]) + + +class TestAccountAPIPositionBuilderParameterHandling(unittest.TestCase): + """Test parameter handling and edge cases""" + + def setUp(self): + """Set up test fixtures""" + self.account_api = AccountAPI( + api_key='test_key', + api_secret_key='test_secret', + passphrase='test_pass', + flag='0' + ) + + @patch.object(AccountAPI, '_request_with_params') + def test_none_parameters_are_excluded(self, mock_request): + """Test that None parameters are not included in the request""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder( + acctLv='2', + inclRealPosAndEq=None, # Should be excluded + lever=None, # Should be excluded + greeksType='PA', + simPos=None, # Should be excluded + simAsset=None, # Should be excluded + idxVol='0.05' + ) + + # Assert - Only non-None params should be in the call + expected_params = { + 'acctLv': '2', + 'greeksType': 'PA', + 'idxVol': '0.05' + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_false_value_for_inclRealPosAndEq_is_included(self, mock_request): + """Test that False value for inclRealPosAndEq is included (not treated as None)""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(inclRealPosAndEq=False) + + # Assert - False should be included + expected_params = { + 'inclRealPosAndEq': False + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_empty_lists_are_included(self, mock_request): + """Test that empty lists are included in the request""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder( + simPos=[], + simAsset=[] + ) + + # Assert + expected_params = { + 'simPos': [], + 'simAsset': [] + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_zero_idxVol_is_included(self, mock_request): + """Test that zero idxVol is included (represents no volatility change)""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(idxVol='0') + + # Assert + expected_params = { + 'idxVol': '0' + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + +if __name__ == '__main__': + unittest.main() + From 0234eced98e4140951bad7788567370bf6ea592b Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Tue, 16 Dec 2025 17:54:45 +0800 Subject: [PATCH 04/48] add set auto earn endpoint --- test/AccountTest.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/AccountTest.py b/test/AccountTest.py index 86260a0..3a3258d 100644 --- a/test/AccountTest.py +++ b/test/AccountTest.py @@ -7,9 +7,9 @@ class AccountTest(unittest.TestCase): def setUp(self): - api_key = 'da097c9c-2f77-4dea-be18-2bfa77d0e394' - api_secret_key = '56CC6C72D6B8A46EC993D48C83142A25' - passphrase = '123456aA.' + api_key = 'your_apiKey' + api_secret_key = 'your_secretKey' + passphrase = 'your_secretKey' self.AccountAPI = Account.AccountAPI(api_key, api_secret_key, passphrase, flag='1') # ''' From 31fd4013e42ed607ee23ca3b3a95535a840468ef Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Tue, 16 Dec 2025 18:38:17 +0800 Subject: [PATCH 05/48] add id parameter to all websocket subscription --- okx/websocket/WsPrivateAsync.py | 19 ++++++++++++------- okx/websocket/WsPublicAsync.py | 19 ++++++++++++------- test/WsPrivateAsyncTest.py | 16 +++++++++++----- test/WsPublicAsyncTest.py | 9 +++++++-- 4 files changed, 42 insertions(+), 21 deletions(-) diff --git a/okx/websocket/WsPrivateAsync.py b/okx/websocket/WsPrivateAsync.py index c5359aa..dbf0390 100644 --- a/okx/websocket/WsPrivateAsync.py +++ b/okx/websocket/WsPrivateAsync.py @@ -30,16 +30,19 @@ async def consume(self): if self.callback: self.callback(message) - async def subscribe(self, params: list, callback): + async def subscribe(self, params: list, callback, id: str = None): self.callback = callback logRes = await self.login() await asyncio.sleep(5) if logRes: - payload = json.dumps({ + payload_dict = { "op": "subscribe", "args": params - }) + } + if id is not None: + payload_dict["id"] = id + payload = json.dumps(payload_dict) await self.websocket.send(payload) # await self.consume() @@ -53,12 +56,15 @@ async def login(self): await self.websocket.send(loginPayload) return True - async def unsubscribe(self, params: list, callback): + async def unsubscribe(self, params: list, callback, id: str = None): self.callback = callback - payload = json.dumps({ + payload_dict = { "op": "unsubscribe", "args": params - }) + } + if id is not None: + payload_dict["id"] = id + payload = json.dumps(payload_dict) logger.info(f"unsubscribe: {payload}") await self.websocket.send(payload) # for param in params: @@ -66,7 +72,6 @@ async def unsubscribe(self, params: list, callback): async def stop(self): await self.factory.close() - self.loop.stop() async def start(self): logger.info("Connecting to WebSocket...") diff --git a/okx/websocket/WsPublicAsync.py b/okx/websocket/WsPublicAsync.py index e576d65..ef44c5a 100644 --- a/okx/websocket/WsPublicAsync.py +++ b/okx/websocket/WsPublicAsync.py @@ -25,27 +25,32 @@ async def consume(self): if self.callback: self.callback(message) - async def subscribe(self, params: list, callback): + async def subscribe(self, params: list, callback, id: str = None): self.callback = callback - payload = json.dumps({ + payload_dict = { "op": "subscribe", "args": params - }) + } + if id is not None: + payload_dict["id"] = id + payload = json.dumps(payload_dict) await self.websocket.send(payload) # await self.consume() - async def unsubscribe(self, params: list, callback): + async def unsubscribe(self, params: list, callback, id: str = None): self.callback = callback - payload = json.dumps({ + payload_dict = { "op": "unsubscribe", "args": params - }) + } + if id is not None: + payload_dict["id"] = id + payload = json.dumps(payload_dict) logger.info(f"unsubscribe: {payload}") await self.websocket.send(payload) async def stop(self): await self.factory.close() - self.loop.stop() async def start(self): logger.info("Connecting to WebSocket...") diff --git a/test/WsPrivateAsyncTest.py b/test/WsPrivateAsyncTest.py index ba7fcff..e154a33 100644 --- a/test/WsPrivateAsyncTest.py +++ b/test/WsPrivateAsyncTest.py @@ -24,15 +24,21 @@ async def main(): args.append(arg1) args.append(arg2) args.append(arg3) - await ws.subscribe(args, callback=privateCallback) - await asyncio.sleep(30) + # 使用 id 参数来标识订阅请求,响应中会返回相同的 id + # 注意:id 只能包含字母和数字,不能包含下划线等特殊字符 + await ws.subscribe(args, callback=privateCallback, id="privateSub001") + await asyncio.sleep(10) print("-----------------------------------------unsubscribe--------------------------------------------") args2 = [arg2] - await ws.unsubscribe(args2, callback=privateCallback) - await asyncio.sleep(30) + # 使用 id 参数来标识取消订阅请求 + await ws.unsubscribe(args2, callback=privateCallback, id="privateUnsub001") + await asyncio.sleep(5) print("-----------------------------------------unsubscribe all--------------------------------------------") args3 = [arg1, arg3] - await ws.unsubscribe(args3, callback=privateCallback) + await ws.unsubscribe(args3, callback=privateCallback, id="privateUnsub002") + await asyncio.sleep(1) + # 正确关闭 websocket 连接 + await ws.stop() if __name__ == '__main__': diff --git a/test/WsPublicAsyncTest.py b/test/WsPublicAsyncTest.py index 14276a0..24364ee 100644 --- a/test/WsPublicAsyncTest.py +++ b/test/WsPublicAsyncTest.py @@ -22,15 +22,20 @@ async def main(): args.append(arg2) args.append(arg3) args.append(arg4) - await ws.subscribe(args, publicCallback) + # 使用 id 参数来标识订阅请求,响应中会返回相同的 id + await ws.subscribe(args, publicCallback, id="sub001") await asyncio.sleep(5) print("-----------------------------------------unsubscribe--------------------------------------------") args2 = [arg4] - await ws.unsubscribe(args2, publicCallback) + # 使用 id 参数来标识取消订阅请求 + await ws.unsubscribe(args2, publicCallback, id="unsub001") await asyncio.sleep(5) print("-----------------------------------------unsubscribe all--------------------------------------------") args3 = [arg1, arg2, arg3] await ws.unsubscribe(args3, publicCallback) + await asyncio.sleep(1) + # 正确关闭 websocket 连接 + await ws.stop() if __name__ == '__main__': From 14c3a94c62996bc2a375a18fb4dec01d2b8423b1 Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Wed, 17 Dec 2025 11:38:31 +0800 Subject: [PATCH 06/48] websocket enhancement --- okx/websocket/WsPrivateAsync.py | 151 +++++++++++++++++++-- okx/websocket/WsPublicAsync.py | 87 ++++++++++-- test/WsPrivateAsyncTest.py | 228 +++++++++++++++++++++++++++++++- test/WsPublicAsyncTest.py | 48 ++++++- 4 files changed, 484 insertions(+), 30 deletions(-) diff --git a/okx/websocket/WsPrivateAsync.py b/okx/websocket/WsPrivateAsync.py index c5359aa..c085c19 100644 --- a/okx/websocket/WsPrivateAsync.py +++ b/okx/websocket/WsPrivateAsync.py @@ -1,6 +1,7 @@ import asyncio import json import logging +import warnings from okx.websocket import WsUtils from okx.websocket.WebSocketFactory import WebSocketFactory @@ -9,7 +10,7 @@ class WsPrivateAsync: - def __init__(self, apiKey, passphrase, secretKey, url, useServerTime): + def __init__(self, apiKey, passphrase, secretKey, url, useServerTime=None, debug=False): self.url = url self.subscriptions = set() self.callback = None @@ -18,28 +19,43 @@ def __init__(self, apiKey, passphrase, secretKey, url, useServerTime): self.apiKey = apiKey self.passphrase = passphrase self.secretKey = secretKey - self.useServerTime = useServerTime + self.useServerTime = False self.websocket = None + self.debug = debug + + # 设置日志级别 + if debug: + logger.setLevel(logging.DEBUG) + + # 废弃 useServerTime 参数警告 + if useServerTime is not None: + warnings.warn("useServerTime parameter is deprecated. Please remove it.", DeprecationWarning) async def connect(self): self.websocket = await self.factory.connect() async def consume(self): async for message in self.websocket: - logger.debug("Received message: {%s}", message) + if self.debug: + logger.debug("Received message: {%s}", message) if self.callback: self.callback(message) - async def subscribe(self, params: list, callback): + async def subscribe(self, params: list, callback, id: str = None): self.callback = callback logRes = await self.login() await asyncio.sleep(5) if logRes: - payload = json.dumps({ + payload_dict = { "op": "subscribe", "args": params - }) + } + if id is not None: + payload_dict["id"] = id + payload = json.dumps(payload_dict) + if self.debug: + logger.debug(f"subscribe: {payload}") await self.websocket.send(payload) # await self.consume() @@ -50,26 +66,133 @@ async def login(self): passphrase=self.passphrase, secretKey=self.secretKey ) + if self.debug: + logger.debug(f"login: {loginPayload}") await self.websocket.send(loginPayload) return True - async def unsubscribe(self, params: list, callback): + async def unsubscribe(self, params: list, callback, id: str = None): self.callback = callback - payload = json.dumps({ + payload_dict = { "op": "unsubscribe", "args": params - }) - logger.info(f"unsubscribe: {payload}") + } + if id is not None: + payload_dict["id"] = id + payload = json.dumps(payload_dict) + if self.debug: + logger.debug(f"unsubscribe: {payload}") + else: + logger.info(f"unsubscribe: {payload}") + await self.websocket.send(payload) + + async def send(self, op: str, args: list, callback=None, id: str = None): + """ + 通用发送方法 + :param op: 操作类型 + :param args: 参数列表 + :param callback: 回调函数 + :param id: 可选的请求ID + """ + if callback: + self.callback = callback + payload_dict = { + "op": op, + "args": args + } + if id is not None: + payload_dict["id"] = id + payload = json.dumps(payload_dict) + if self.debug: + logger.debug(f"send: {payload}") await self.websocket.send(payload) - # for param in params: - # self.subscriptions.discard(param) + + async def place_order(self, args: list, callback=None, id: str = None): + """ + 下单 + :param args: 下单参数列表 + :param callback: 回调函数 + :param id: 可选的请求ID + """ + if callback: + self.callback = callback + await self.send("order", args, id=id) + + async def batch_orders(self, args: list, callback=None, id: str = None): + """ + 批量下单 + :param args: 批量下单参数列表 + :param callback: 回调函数 + :param id: 可选的请求ID + """ + if callback: + self.callback = callback + await self.send("batch-orders", args, id=id) + + async def cancel_order(self, args: list, callback=None, id: str = None): + """ + 撤单 + :param args: 撤单参数列表 + :param callback: 回调函数 + :param id: 可选的请求ID + """ + if callback: + self.callback = callback + await self.send("cancel-order", args, id=id) + + async def batch_cancel_orders(self, args: list, callback=None, id: str = None): + """ + 批量撤单 + :param args: 批量撤单参数列表 + :param callback: 回调函数 + :param id: 可选的请求ID + """ + if callback: + self.callback = callback + await self.send("batch-cancel-orders", args, id=id) + + async def amend_order(self, args: list, callback=None, id: str = None): + """ + 改单 + :param args: 改单参数列表 + :param callback: 回调函数 + :param id: 可选的请求ID + """ + if callback: + self.callback = callback + await self.send("amend-order", args, id=id) + + async def batch_amend_orders(self, args: list, callback=None, id: str = None): + """ + 批量改单 + :param args: 批量改单参数列表 + :param callback: 回调函数 + :param id: 可选的请求ID + """ + if callback: + self.callback = callback + await self.send("batch-amend-orders", args, id=id) + + async def mass_cancel(self, args: list, callback=None, id: str = None): + """ + Mass cancel (批量撤销) + 注意:此方法用于 /ws/v5/business 频道,限速 1次/秒 + :param args: 撤销参数列表,包含 instType 和 instFamily + :param callback: 回调函数 + :param id: 可选的请求ID + """ + if callback: + self.callback = callback + await self.send("mass-cancel", args, id=id) async def stop(self): await self.factory.close() - self.loop.stop() async def start(self): - logger.info("Connecting to WebSocket...") + if self.debug: + logger.debug("Connecting to WebSocket...") + else: + logger.info("Connecting to WebSocket...") await self.connect() self.loop.create_task(self.consume()) diff --git a/okx/websocket/WsPublicAsync.py b/okx/websocket/WsPublicAsync.py index e576d65..997b27c 100644 --- a/okx/websocket/WsPublicAsync.py +++ b/okx/websocket/WsPublicAsync.py @@ -2,53 +2,118 @@ import json import logging +from okx.websocket import WsUtils from okx.websocket.WebSocketFactory import WebSocketFactory logger = logging.getLogger(__name__) class WsPublicAsync: - def __init__(self, url): + def __init__(self, url, apiKey='', passphrase='', secretKey='', debug=False): self.url = url self.subscriptions = set() self.callback = None self.loop = asyncio.get_event_loop() self.factory = WebSocketFactory(url) self.websocket = None + self.debug = debug + # 用于 business 频道的登录凭证 + self.apiKey = apiKey + self.passphrase = passphrase + self.secretKey = secretKey + self.isLoggedIn = False + + # 设置日志级别 + if debug: + logger.setLevel(logging.DEBUG) async def connect(self): self.websocket = await self.factory.connect() async def consume(self): async for message in self.websocket: - logger.debug("Received message: {%s}", message) + if self.debug: + logger.debug("Received message: {%s}", message) if self.callback: self.callback(message) - async def subscribe(self, params: list, callback): + async def login(self): + """ + 登录方法,用于需要登录的 business 频道(如 /ws/v5/business) + """ + if not self.apiKey or not self.secretKey or not self.passphrase: + raise ValueError("apiKey, secretKey and passphrase are required for login") + + loginPayload = WsUtils.initLoginParams( + useServerTime=False, + apiKey=self.apiKey, + passphrase=self.passphrase, + secretKey=self.secretKey + ) + if self.debug: + logger.debug(f"login: {loginPayload}") + await self.websocket.send(loginPayload) + self.isLoggedIn = True + return True + + async def subscribe(self, params: list, callback, id: str = None): self.callback = callback - payload = json.dumps({ + payload_dict = { "op": "subscribe", "args": params - }) + } + if id is not None: + payload_dict["id"] = id + payload = json.dumps(payload_dict) + if self.debug: + logger.debug(f"subscribe: {payload}") await self.websocket.send(payload) # await self.consume() - async def unsubscribe(self, params: list, callback): + async def unsubscribe(self, params: list, callback, id: str = None): self.callback = callback - payload = json.dumps({ + payload_dict = { "op": "unsubscribe", "args": params - }) - logger.info(f"unsubscribe: {payload}") + } + if id is not None: + payload_dict["id"] = id + payload = json.dumps(payload_dict) + if self.debug: + logger.debug(f"unsubscribe: {payload}") + else: + logger.info(f"unsubscribe: {payload}") + await self.websocket.send(payload) + + async def send(self, op: str, args: list, callback=None, id: str = None): + """ + 通用发送方法 + :param op: 操作类型 + :param args: 参数列表 + :param callback: 回调函数 + :param id: 可选的请求ID + """ + if callback: + self.callback = callback + payload_dict = { + "op": op, + "args": args + } + if id is not None: + payload_dict["id"] = id + payload = json.dumps(payload_dict) + if self.debug: + logger.debug(f"send: {payload}") await self.websocket.send(payload) async def stop(self): await self.factory.close() - self.loop.stop() async def start(self): - logger.info("Connecting to WebSocket...") + if self.debug: + logger.debug("Connecting to WebSocket...") + else: + logger.info("Connecting to WebSocket...") await self.connect() self.loop.create_task(self.consume()) diff --git a/test/WsPrivateAsyncTest.py b/test/WsPrivateAsyncTest.py index ba7fcff..f478984 100644 --- a/test/WsPrivateAsyncTest.py +++ b/test/WsPrivateAsyncTest.py @@ -8,13 +8,14 @@ def privateCallback(message): async def main(): + """订阅测试""" url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" ws = WsPrivateAsync( apiKey="your apiKey", passphrase="your passphrase", secretKey="your secretKey", url=url, - useServerTime=False + debug=True ) await ws.start() args = [] @@ -33,7 +34,230 @@ async def main(): print("-----------------------------------------unsubscribe all--------------------------------------------") args3 = [arg1, arg3] await ws.unsubscribe(args3, callback=privateCallback) + await asyncio.sleep(1) + await ws.stop() + + +async def test_place_order(): + """ + 测试下单功能 + URL: /ws/v5/private (限速: 60次/秒) + """ + url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" + ws = WsPrivateAsync( + apiKey="your apiKey", + passphrase="your passphrase", + secretKey="your secretKey", + url=url, + debug=True + ) + await ws.start() + await ws.login() + await asyncio.sleep(5) + + # 下单参数 + order_args = [{ + "instId": "BTC-USDT", + "tdMode": "cash", + "clOrdId": "client_order_001", + "side": "buy", + "ordType": "limit", + "sz": "0.001", + "px": "30000" + }] + await ws.place_order(order_args, callback=privateCallback, id="order001") + await asyncio.sleep(5) + await ws.stop() + + +async def test_batch_orders(): + """ + 测试批量下单功能 + URL: /ws/v5/private (限速: 60次/秒, 最多20个订单) + """ + url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" + ws = WsPrivateAsync( + apiKey="your apiKey", + passphrase="your passphrase", + secretKey="your secretKey", + url=url, + debug=True + ) + await ws.start() + await ws.login() + await asyncio.sleep(5) + + # 批量下单参数 (最多20个) + order_args = [ + { + "instId": "BTC-USDT", + "tdMode": "cash", + "clOrdId": "batch_order_001", + "side": "buy", + "ordType": "limit", + "sz": "0.001", + "px": "30000" + }, + { + "instId": "ETH-USDT", + "tdMode": "cash", + "clOrdId": "batch_order_002", + "side": "buy", + "ordType": "limit", + "sz": "0.01", + "px": "2000" + } + ] + await ws.batch_orders(order_args, callback=privateCallback, id="batchOrder001") + await asyncio.sleep(5) + await ws.stop() + + +async def test_cancel_order(): + """ + 测试撤单功能 + URL: /ws/v5/private (限速: 60次/秒) + """ + url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" + ws = WsPrivateAsync( + apiKey="your apiKey", + passphrase="your passphrase", + secretKey="your secretKey", + url=url, + debug=True + ) + await ws.start() + await ws.login() + await asyncio.sleep(5) + + # 撤单参数 (ordId 和 clOrdId 必须传一个) + cancel_args = [{ + "instId": "BTC-USDT", + "ordId": "your_order_id" + # 或者使用 "clOrdId": "client_order_001" + }] + await ws.cancel_order(cancel_args, callback=privateCallback, id="cancel001") + await asyncio.sleep(5) + await ws.stop() + + +async def test_batch_cancel_orders(): + """ + 测试批量撤单功能 + URL: /ws/v5/private (限速: 60次/秒, 最多20个订单) + """ + url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" + ws = WsPrivateAsync( + apiKey="your apiKey", + passphrase="your passphrase", + secretKey="your secretKey", + url=url, + debug=True + ) + await ws.start() + await ws.login() + await asyncio.sleep(5) + + cancel_args = [ + {"instId": "BTC-USDT", "ordId": "order_id_1"}, + {"instId": "ETH-USDT", "ordId": "order_id_2"} + ] + await ws.batch_cancel_orders(cancel_args, callback=privateCallback, id="batchCancel001") + await asyncio.sleep(5) + await ws.stop() + + +async def test_amend_order(): + """ + 测试改单功能 + URL: /ws/v5/private (限速: 60次/秒) + """ + url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" + ws = WsPrivateAsync( + apiKey="your apiKey", + passphrase="your passphrase", + secretKey="your secretKey", + url=url, + debug=True + ) + await ws.start() + await ws.login() + await asyncio.sleep(5) + + # 改单参数 + amend_args = [{ + "instId": "BTC-USDT", + "ordId": "your_order_id", + "newSz": "0.002", + "newPx": "31000" + }] + await ws.amend_order(amend_args, callback=privateCallback, id="amend001") + await asyncio.sleep(5) + await ws.stop() + + +async def test_mass_cancel(): + """ + 测试批量撤销功能 + URL: /ws/v5/business (限速: 1次/秒) + 注意: 此功能使用 business 频道 + """ + url = "wss://wspap.okx.com:8443/ws/v5/business?brokerId=9999" + ws = WsPrivateAsync( + apiKey="your apiKey", + passphrase="your passphrase", + secretKey="your secretKey", + url=url, + debug=True + ) + await ws.start() + await ws.login() + await asyncio.sleep(5) + + # 批量撤销参数 + mass_cancel_args = [{ + "instType": "SPOT", + "instFamily": "BTC-USDT" + }] + await ws.mass_cancel(mass_cancel_args, callback=privateCallback, id="massCancel001") + await asyncio.sleep(5) + await ws.stop() + + +async def test_send_method(): + """测试通用send方法""" + url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" + ws = WsPrivateAsync( + apiKey="your apiKey", + passphrase="your passphrase", + secretKey="your secretKey", + url=url, + debug=True + ) + await ws.start() + await ws.login() + await asyncio.sleep(5) + + # 使用通用send方法下单 - 注意要传入callback才能收到响应 + order_args = [{ + "instId": "BTC-USDT", + "tdMode": "cash", + "side": "buy", + "ordType": "limit", + "sz": "0.001", + "px": "30000" + }] + await ws.send("order", order_args, callback=privateCallback, id="send001") + await asyncio.sleep(5) + await ws.stop() if __name__ == '__main__': - asyncio.run(main()) + # asyncio.run(main()) + asyncio.run(test_place_order()) + asyncio.run(test_batch_orders()) + asyncio.run(test_cancel_order()) + asyncio.run(test_batch_cancel_orders()) + asyncio.run(test_amend_order()) + asyncio.run(test_mass_cancel()) # 注意使用 business 频道 + asyncio.run(test_send_method()) diff --git a/test/WsPublicAsyncTest.py b/test/WsPublicAsyncTest.py index 14276a0..8fda306 100644 --- a/test/WsPublicAsyncTest.py +++ b/test/WsPublicAsyncTest.py @@ -8,10 +8,9 @@ def publicCallback(message): async def main(): - # url = "wss://wspap.okex.com:8443/ws/v5/public?brokerId=9999" url = "wss://wspap.okx.com:8443/ws/v5/public?brokerId=9999" - ws = WsPublicAsync(url=url) + ws = WsPublicAsync(url=url, debug=True) # 开启debug日志 await ws.start() args = [] arg1 = {"channel": "instruments", "instType": "FUTURES"} @@ -31,7 +30,50 @@ async def main(): print("-----------------------------------------unsubscribe all--------------------------------------------") args3 = [arg1, arg2, arg3] await ws.unsubscribe(args3, publicCallback) + await asyncio.sleep(1) + await ws.stop() + + +async def test_business_channel_with_login(): + """ + 测试 business 频道的登录功能 + business 频道需要登录后才能订阅某些私有数据 + """ + url = "wss://wspap.okx.com:8443/ws/v5/business?brokerId=9999" + ws = WsPublicAsync( + url=url, + apiKey="your apiKey", + passphrase="your passphrase", + secretKey="your secretKey", + debug=True + ) + await ws.start() + + # 登录 + await ws.login() + await asyncio.sleep(5) + + # 订阅需要登录的频道 + args = [{"channel": "candle1m", "instId": "BTC-USDT"}] + await ws.subscribe(args, publicCallback) + await asyncio.sleep(30) + await ws.stop() + + +async def test_send_method(): + """测试通用send方法""" + url = "wss://wspap.okx.com:8443/ws/v5/public?brokerId=9999" + ws = WsPublicAsync(url=url, debug=True) + await ws.start() + + # 使用通用send方法订阅 - 注意要传入callback才能收到响应 + args = [{"channel": "tickers", "instId": "BTC-USDT"}] + await ws.send("subscribe", args, callback=publicCallback, id="send001") + await asyncio.sleep(10) + await ws.stop() if __name__ == '__main__': - asyncio.run(main()) + # asyncio.run(main()) + # asyncio.run(test_business_channel_with_login()) + asyncio.run(test_send_method()) From 12ebac75c4607788ce8ef2d5ddf350f416f78165 Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Thu, 18 Dec 2025 15:15:09 +0800 Subject: [PATCH 07/48] =?UTF-8?q?=E5=88=9B=E5=BB=BA0.4.1=E7=89=88=E6=9C=AC?= =?UTF-8?q?=E5=88=86=E6=94=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- okx/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/okx/__init__.py b/okx/__init__.py index 2dbeeb8..57f92fe 100644 --- a/okx/__init__.py +++ b/okx/__init__.py @@ -2,4 +2,4 @@ Python SDK for the OKX API v5 """ -__version__="0.4.0" \ No newline at end of file +__version__="0.4.1" \ No newline at end of file From f78c2fede2145688649ce1437f720f3df7777b6a Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Thu, 18 Dec 2025 15:23:35 +0800 Subject: [PATCH 08/48] add set auto earn endpoint --- okx/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/okx/__init__.py b/okx/__init__.py index 57f92fe..2dbeeb8 100644 --- a/okx/__init__.py +++ b/okx/__init__.py @@ -2,4 +2,4 @@ Python SDK for the OKX API v5 """ -__version__="0.4.1" \ No newline at end of file +__version__="0.4.0" \ No newline at end of file From e82ee55a0832839eecd361c998b3e59700e9af42 Mon Sep 17 00:00:00 2001 From: "jason.hsu" Date: Tue, 16 Dec 2025 10:26:08 +0800 Subject: [PATCH 09/48] feat: add idxVol and fix wrong params for posBuilder --- okx/Account.py | 8 +++++--- test/__init__.py | 0 2 files changed, 5 insertions(+), 3 deletions(-) create mode 100644 test/__init__.py diff --git a/okx/Account.py b/okx/Account.py index 911cf9b..8de8532 100644 --- a/okx/Account.py +++ b/okx/Account.py @@ -28,20 +28,22 @@ def get_positions(self, instType='', instId='', posId=''): return self._request_with_params(GET, POSITION_INFO, params) def position_builder(self, acctLv=None,inclRealPosAndEq=False, lever=None, greeksType=None, simPos=None, - simAsset=None): + simAsset=None, idxVol=None): params = {} if acctLv is not None: params['acctLv'] = acctLv if inclRealPosAndEq is not None: params['inclRealPosAndEq'] = inclRealPosAndEq if lever is not None: - params['spotOffsetType'] = lever + params['lever'] = lever if greeksType is not None: - params['greksType'] = greeksType + params['greeksType'] = greeksType if simPos is not None: params['simPos'] = simPos if simAsset is not None: params['simAsset'] = simAsset + if idxVol is not None: + params['idxVol'] = idxVol return self._request_with_params(POST, POSITION_BUILDER, params) # Get Bills Details (recent 7 days) diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..e69de29 From 3f739c23f362491d8886f39c0c44089ecf183787 Mon Sep 17 00:00:00 2001 From: "jason.hsu" Date: Tue, 16 Dec 2025 10:55:37 +0800 Subject: [PATCH 10/48] feat: add unit-testing for the test_account --- okx/Account.py | 2 +- test/unit/__init__.py | 10 + test/unit/okx/__init__.py | 2 + test/unit/okx/test_account.py | 441 ++++++++++++++++++++++++++++++++++ 4 files changed, 454 insertions(+), 1 deletion(-) create mode 100644 test/unit/__init__.py create mode 100644 test/unit/okx/__init__.py create mode 100644 test/unit/okx/test_account.py diff --git a/okx/Account.py b/okx/Account.py index 8de8532..e3952ff 100644 --- a/okx/Account.py +++ b/okx/Account.py @@ -27,7 +27,7 @@ def get_positions(self, instType='', instId='', posId=''): params = {'instType': instType, 'instId': instId, 'posId': posId} return self._request_with_params(GET, POSITION_INFO, params) - def position_builder(self, acctLv=None,inclRealPosAndEq=False, lever=None, greeksType=None, simPos=None, + def position_builder(self, acctLv=None, inclRealPosAndEq=None, lever=None, greeksType=None, simPos=None, simAsset=None, idxVol=None): params = {} if acctLv is not None: diff --git a/test/unit/__init__.py b/test/unit/__init__.py new file mode 100644 index 0000000..940cd5e --- /dev/null +++ b/test/unit/__init__.py @@ -0,0 +1,10 @@ +""" +Unit tests package + +Unit tests mirror the source code structure for easy navigation. + +Example: + okx/Account.py -> test/unit/okx/test_account.py + okx/Trade.py -> test/unit/okx/test_trade.py + okx/Finance/Savings.py -> test/unit/okx/Finance/test_savings.py +""" diff --git a/test/unit/okx/__init__.py b/test/unit/okx/__init__.py new file mode 100644 index 0000000..2b43621 --- /dev/null +++ b/test/unit/okx/__init__.py @@ -0,0 +1,2 @@ +"""Unit tests for okx package""" + diff --git a/test/unit/okx/test_account.py b/test/unit/okx/test_account.py new file mode 100644 index 0000000..6e5f0c4 --- /dev/null +++ b/test/unit/okx/test_account.py @@ -0,0 +1,441 @@ +""" +Unit tests for okx.Account module + +Mirrors the structure: okx/Account.py -> test/unit/okx/test_account.py +""" +import unittest +from unittest.mock import patch +from okx.Account import AccountAPI +from okx import consts as c + + +class TestAccountAPIPositionBuilder(unittest.TestCase): + """Unit tests for the position_builder method""" + + def setUp(self): + """Set up test fixtures""" + self.api_key = 'test_api_key' + self.api_secret = 'test_api_secret' + self.passphrase = 'test_passphrase' + self.account_api = AccountAPI( + api_key=self.api_key, + api_secret_key=self.api_secret, + passphrase=self.passphrase, + flag='0' + ) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_with_all_parameters(self, mock_request): + """Test position_builder with all parameters provided""" + # Arrange + mock_response = { + 'code': '0', + 'msg': '', + 'data': [{ + 'mmr': '1000', + 'imr': '2000', + 'mmrBf': '900', + 'imrBf': '1900' + }] + } + mock_request.return_value = mock_response + + sim_pos = [{'instId': 'BTC-USDT-SWAP', 'pos': '10', 'avgPx': '50000'}] + sim_asset = [{'ccy': 'USDT', 'amt': '10000'}] + + # Act + result = self.account_api.position_builder( + acctLv='2', + inclRealPosAndEq=True, + lever='5', + greeksType='PA', + simPos=sim_pos, + simAsset=sim_asset, + idxVol='0.05' + ) + + # Assert + expected_params = { + 'acctLv': '2', + 'inclRealPosAndEq': True, + 'lever': '5', + 'greeksType': 'PA', + 'simPos': sim_pos, + 'simAsset': sim_asset, + 'idxVol': '0.05' + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + self.assertEqual(result, mock_response) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_with_idxVol_only(self, mock_request): + """Test position_builder with only idxVol parameter""" + # Arrange + mock_response = { + 'code': '0', + 'msg': '', + 'data': [] + } + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(idxVol='0.1') + + # Assert + expected_params = { + 'idxVol': '0.1' + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + self.assertEqual(result, mock_response) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_negative_idxVol(self, mock_request): + """Test position_builder with negative idxVol (price decrease)""" + # Arrange + mock_response = { + 'code': '0', + 'msg': '', + 'data': [] + } + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(idxVol='-0.05') + + # Assert + expected_params = { + 'idxVol': '-0.05' + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + self.assertEqual(result, mock_response) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_with_no_parameters(self, mock_request): + """Test position_builder with no parameters (all None)""" + # Arrange + mock_response = { + 'code': '0', + 'msg': '', + 'data': [] + } + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder() + + # Assert + # Should pass empty params dict + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, {}) + self.assertEqual(result, mock_response) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_with_simulated_positions(self, mock_request): + """Test position_builder with simulated positions and assets""" + # Arrange + mock_response = { + 'code': '0', + 'msg': '', + 'data': [{ + 'mmr': '5000', + 'imr': '10000' + }] + } + mock_request.return_value = mock_response + + sim_pos = [ + {'instId': 'BTC-USDT-SWAP', 'pos': '10', 'avgPx': '50000'}, + {'instId': 'ETH-USDT-SWAP', 'pos': '100', 'avgPx': '3000'} + ] + sim_asset = [ + {'ccy': 'USDT', 'amt': '100000'}, + {'ccy': 'BTC', 'amt': '1'} + ] + + # Act + result = self.account_api.position_builder( + inclRealPosAndEq=False, + simPos=sim_pos, + simAsset=sim_asset, + idxVol='0.1' + ) + + # Assert + expected_params = { + 'inclRealPosAndEq': False, + 'simPos': sim_pos, + 'simAsset': sim_asset, + 'idxVol': '0.1' + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + self.assertEqual(result, mock_response) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_greeks_type_pa(self, mock_request): + """Test position_builder with greeksType PA""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(greeksType='PA') + + # Assert + expected_params = {'greeksType': 'PA'} + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_greeks_type_bs(self, mock_request): + """Test position_builder with greeksType BS""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(greeksType='BS') + + # Assert + expected_params = {'greeksType': 'BS'} + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_includes_real_positions(self, mock_request): + """Test position_builder with inclRealPosAndEq=True""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder( + inclRealPosAndEq=True, + idxVol='0.05' + ) + + # Assert + expected_params = { + 'inclRealPosAndEq': True, + 'idxVol': '0.05' + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_excludes_real_positions(self, mock_request): + """Test position_builder with inclRealPosAndEq=False (only virtual positions)""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + sim_pos = [{'instId': 'BTC-USDT-SWAP', 'pos': '5', 'avgPx': '60000'}] + + # Act + result = self.account_api.position_builder( + inclRealPosAndEq=False, + simPos=sim_pos + ) + + # Assert + expected_params = { + 'inclRealPosAndEq': False, + 'simPos': sim_pos + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_with_account_level(self, mock_request): + """Test position_builder with specific account level""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(acctLv='3') + + # Assert + expected_params = {'acctLv': '3'} + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_with_leverage(self, mock_request): + """Test position_builder with leverage parameter""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(lever='10') + + # Assert + expected_params = {'lever': '10'} + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_extreme_volatility_positive(self, mock_request): + """Test position_builder with maximum positive volatility""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(idxVol='1') + + # Assert + expected_params = {'idxVol': '1'} + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_extreme_volatility_negative(self, mock_request): + """Test position_builder with maximum negative volatility""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(idxVol='-0.99') + + # Assert + expected_params = {'idxVol': '-0.99'} + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_complex_scenario(self, mock_request): + """Test position_builder with a complex realistic scenario""" + # Arrange + mock_response = { + 'code': '0', + 'msg': '', + 'data': [{ + 'mmr': '15000', + 'imr': '30000', + 'mmrBf': '14000', + 'imrBf': '28000', + 'markPxBf': '49500' + }] + } + mock_request.return_value = mock_response + + sim_pos = [ + {'instId': 'BTC-USDT-SWAP', 'pos': '10', 'avgPx': '50000'}, + {'instId': 'ETH-USDT-SWAP', 'pos': '-50', 'avgPx': '3000'} + ] + sim_asset = [{'ccy': 'USDT', 'amt': '50000'}] + + # Act - Simulate a 5% market drop + result = self.account_api.position_builder( + acctLv='2', + inclRealPosAndEq=False, + lever='5', + greeksType='PA', + simPos=sim_pos, + simAsset=sim_asset, + idxVol='-0.05' + ) + + # Assert + expected_params = { + 'acctLv': '2', + 'inclRealPosAndEq': False, + 'lever': '5', + 'greeksType': 'PA', + 'simPos': sim_pos, + 'simAsset': sim_asset, + 'idxVol': '-0.05' + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + self.assertEqual(result['code'], '0') + self.assertIn('mmrBf', result['data'][0]) + self.assertIn('imrBf', result['data'][0]) + + +class TestAccountAPIPositionBuilderParameterHandling(unittest.TestCase): + """Test parameter handling and edge cases""" + + def setUp(self): + """Set up test fixtures""" + self.account_api = AccountAPI( + api_key='test_key', + api_secret_key='test_secret', + passphrase='test_pass', + flag='0' + ) + + @patch.object(AccountAPI, '_request_with_params') + def test_none_parameters_are_excluded(self, mock_request): + """Test that None parameters are not included in the request""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder( + acctLv='2', + inclRealPosAndEq=None, # Should be excluded + lever=None, # Should be excluded + greeksType='PA', + simPos=None, # Should be excluded + simAsset=None, # Should be excluded + idxVol='0.05' + ) + + # Assert - Only non-None params should be in the call + expected_params = { + 'acctLv': '2', + 'greeksType': 'PA', + 'idxVol': '0.05' + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_false_value_for_inclRealPosAndEq_is_included(self, mock_request): + """Test that False value for inclRealPosAndEq is included (not treated as None)""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(inclRealPosAndEq=False) + + # Assert - False should be included + expected_params = { + 'inclRealPosAndEq': False + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_empty_lists_are_included(self, mock_request): + """Test that empty lists are included in the request""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder( + simPos=[], + simAsset=[] + ) + + # Assert + expected_params = { + 'simPos': [], + 'simAsset': [] + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_zero_idxVol_is_included(self, mock_request): + """Test that zero idxVol is included (represents no volatility change)""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(idxVol='0') + + # Assert + expected_params = { + 'idxVol': '0' + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + +if __name__ == '__main__': + unittest.main() + From 94c20702cfd3b9da0774ba73736827bf5512adcf Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Thu, 18 Dec 2025 17:53:20 +0800 Subject: [PATCH 11/48] websocket enhancement --- test/unit/__init__.py | 2 + test/unit/okx/__init__.py | 2 + test/unit/okx/websocket/__init__.py | 2 + .../okx/websocket/test_ws_private_async.py | 585 ++++++++++++++++++ .../okx/websocket/test_ws_public_async.py | 322 ++++++++++ 5 files changed, 913 insertions(+) create mode 100644 test/unit/__init__.py create mode 100644 test/unit/okx/__init__.py create mode 100644 test/unit/okx/websocket/__init__.py create mode 100644 test/unit/okx/websocket/test_ws_private_async.py create mode 100644 test/unit/okx/websocket/test_ws_public_async.py diff --git a/test/unit/__init__.py b/test/unit/__init__.py new file mode 100644 index 0000000..75f7509 --- /dev/null +++ b/test/unit/__init__.py @@ -0,0 +1,2 @@ +# Unit tests for okx SDK + diff --git a/test/unit/okx/__init__.py b/test/unit/okx/__init__.py new file mode 100644 index 0000000..7e12b04 --- /dev/null +++ b/test/unit/okx/__init__.py @@ -0,0 +1,2 @@ +# Unit tests for okx module + diff --git a/test/unit/okx/websocket/__init__.py b/test/unit/okx/websocket/__init__.py new file mode 100644 index 0000000..b0061bd --- /dev/null +++ b/test/unit/okx/websocket/__init__.py @@ -0,0 +1,2 @@ +# Unit tests for okx.websocket module + diff --git a/test/unit/okx/websocket/test_ws_private_async.py b/test/unit/okx/websocket/test_ws_private_async.py new file mode 100644 index 0000000..e81de0d --- /dev/null +++ b/test/unit/okx/websocket/test_ws_private_async.py @@ -0,0 +1,585 @@ +""" +Unit tests for okx.websocket.WsPrivateAsync module + +Mirrors the structure: okx/websocket/WsPrivateAsync.py -> test/unit/okx/websocket/test_ws_private_async.py +""" +import json +import unittest +import asyncio +import warnings +from unittest.mock import patch, MagicMock, AsyncMock + + +class TestWsPrivateAsyncInit(unittest.TestCase): + """Unit tests for WsPrivateAsync initialization""" + + def test_init_with_required_params(self): + """Test initialization with required parameters only""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory') as mock_factory: + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com" + ) + + self.assertEqual(ws.apiKey, "test_api_key") + self.assertEqual(ws.passphrase, "test_passphrase") + self.assertEqual(ws.secretKey, "test_secret_key") + self.assertEqual(ws.url, "wss://test.example.com") + self.assertFalse(ws.useServerTime) + self.assertFalse(ws.debug) + mock_factory.assert_called_once_with("wss://test.example.com") + + def test_init_with_debug_enabled(self): + """Test initialization with debug mode enabled""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com", + debug=True + ) + + self.assertTrue(ws.debug) + + def test_init_with_deprecated_useServerTime_shows_warning(self): + """Test that using deprecated useServerTime parameter shows warning""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + from okx.websocket.WsPrivateAsync import WsPrivateAsync + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com", + useServerTime=True + ) + + self.assertEqual(len(w), 1) + self.assertTrue(issubclass(w[0].category, DeprecationWarning)) + self.assertIn("useServerTime parameter is deprecated", str(w[0].message)) + + def test_init_without_useServerTime_no_warning(self): + """Test that not using useServerTime parameter shows no warning""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + from okx.websocket.WsPrivateAsync import WsPrivateAsync + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com" + ) + + # No deprecation warning expected + deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] + self.assertEqual(len(deprecation_warnings), 0) + + +class TestWsPrivateAsyncSubscribe(unittest.TestCase): + """Unit tests for WsPrivateAsync subscribe method""" + + def test_subscribe_without_id(self): + """Test subscribe without id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'), \ + patch('okx.websocket.WsPrivateAsync.WsUtils.initLoginParams') as mock_init_login, \ + patch('okx.websocket.WsPrivateAsync.asyncio.sleep', new_callable=AsyncMock): + + mock_init_login.return_value = '{"op":"login"}' + + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com" + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [{"channel": "account", "ccy": "BTC"}] + + async def run_test(): + await ws.subscribe(params, callback) + self.assertEqual(ws.callback, callback) + # Second call should be the subscribe (first is login) + subscribe_call = mock_websocket.send.call_args_list[1] + payload = json.loads(subscribe_call[0][0]) + self.assertEqual(payload["op"], "subscribe") + self.assertEqual(payload["args"], params) + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_subscribe_with_id(self): + """Test subscribe with id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'), \ + patch('okx.websocket.WsPrivateAsync.WsUtils.initLoginParams') as mock_init_login, \ + patch('okx.websocket.WsPrivateAsync.asyncio.sleep', new_callable=AsyncMock): + + mock_init_login.return_value = '{"op":"login"}' + + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com" + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [{"channel": "account", "ccy": "BTC"}] + + async def run_test(): + await ws.subscribe(params, callback, id="sub001") + # Second call should be the subscribe (first is login) + subscribe_call = mock_websocket.send.call_args_list[1] + payload = json.loads(subscribe_call[0][0]) + self.assertEqual(payload["op"], "subscribe") + self.assertEqual(payload["id"], "sub001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPrivateAsyncUnsubscribe(unittest.TestCase): + """Unit tests for WsPrivateAsync unsubscribe method""" + + def test_unsubscribe_without_id(self): + """Test unsubscribe without id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com" + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [{"channel": "account", "ccy": "BTC"}] + + async def run_test(): + await ws.unsubscribe(params, callback) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "unsubscribe") + self.assertEqual(payload["args"], params) + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_unsubscribe_with_id(self): + """Test unsubscribe with id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com" + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [{"channel": "account", "ccy": "BTC"}] + + async def run_test(): + await ws.unsubscribe(params, callback, id="unsub001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "unsubscribe") + self.assertEqual(payload["id"], "unsub001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPrivateAsyncSend(unittest.TestCase): + """Unit tests for WsPrivateAsync generic send method""" + + def test_send_without_id(self): + """Test generic send method without id""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com" + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + args = [{"instId": "BTC-USDT"}] + + async def run_test(): + await ws.send("custom_op", args, callback=callback) + self.assertEqual(ws.callback, callback) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "custom_op") + self.assertEqual(payload["args"], args) + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_send_with_id(self): + """Test generic send method with id""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com" + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + args = [{"instId": "BTC-USDT"}] + + async def run_test(): + await ws.send("custom_op", args, id="send001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "custom_op") + self.assertEqual(payload["id"], "send001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPrivateAsyncOrderMethods(unittest.TestCase): + """Unit tests for WsPrivateAsync order-related methods""" + + def _create_ws_instance(self): + """Helper to create WsPrivateAsync instance with mocked websocket""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com" + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + return ws, mock_websocket + + def test_place_order_sends_correct_payload(self): + """Test place_order sends correct operation""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + callback = MagicMock() + order_args = [{ + "instId": "BTC-USDT", + "tdMode": "cash", + "side": "buy", + "ordType": "limit", + "sz": "0.001", + "px": "30000" + }] + + async def run_test(): + await ws.place_order(order_args, callback=callback, id="order001") + self.assertEqual(ws.callback, callback) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "order") + self.assertEqual(payload["args"], order_args) + self.assertEqual(payload["id"], "order001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_place_order_without_id(self): + """Test place_order without id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + order_args = [{"instId": "BTC-USDT"}] + + async def run_test(): + await ws.place_order(order_args) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "order") + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_batch_orders_sends_correct_payload(self): + """Test batch_orders sends correct operation""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + callback = MagicMock() + order_args = [ + {"instId": "BTC-USDT", "side": "buy", "sz": "0.001", "px": "30000"}, + {"instId": "ETH-USDT", "side": "buy", "sz": "0.01", "px": "2000"} + ] + + async def run_test(): + await ws.batch_orders(order_args, callback=callback, id="batch001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "batch-orders") + self.assertEqual(payload["args"], order_args) + self.assertEqual(payload["id"], "batch001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_batch_orders_without_id(self): + """Test batch_orders without id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + order_args = [{"instId": "BTC-USDT"}, {"instId": "ETH-USDT"}] + + async def run_test(): + await ws.batch_orders(order_args) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "batch-orders") + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_cancel_order_sends_correct_payload(self): + """Test cancel_order sends correct operation""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + callback = MagicMock() + cancel_args = [{"instId": "BTC-USDT", "ordId": "12345"}] + + async def run_test(): + await ws.cancel_order(cancel_args, callback=callback, id="cancel001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "cancel-order") + self.assertEqual(payload["args"], cancel_args) + self.assertEqual(payload["id"], "cancel001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_cancel_order_without_id(self): + """Test cancel_order without id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + cancel_args = [{"instId": "BTC-USDT", "ordId": "12345"}] + + async def run_test(): + await ws.cancel_order(cancel_args) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "cancel-order") + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_batch_cancel_orders_sends_correct_payload(self): + """Test batch_cancel_orders sends correct operation""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + callback = MagicMock() + cancel_args = [ + {"instId": "BTC-USDT", "ordId": "12345"}, + {"instId": "ETH-USDT", "ordId": "67890"} + ] + + async def run_test(): + await ws.batch_cancel_orders(cancel_args, callback=callback, id="batchCancel001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "batch-cancel-orders") + self.assertEqual(payload["args"], cancel_args) + self.assertEqual(payload["id"], "batchCancel001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_batch_cancel_orders_without_id(self): + """Test batch_cancel_orders without id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + cancel_args = [{"instId": "BTC-USDT", "ordId": "12345"}] + + async def run_test(): + await ws.batch_cancel_orders(cancel_args) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "batch-cancel-orders") + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_amend_order_sends_correct_payload(self): + """Test amend_order sends correct operation""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + callback = MagicMock() + amend_args = [{ + "instId": "BTC-USDT", + "ordId": "12345", + "newSz": "0.002", + "newPx": "31000" + }] + + async def run_test(): + await ws.amend_order(amend_args, callback=callback, id="amend001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "amend-order") + self.assertEqual(payload["args"], amend_args) + self.assertEqual(payload["id"], "amend001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_amend_order_without_id(self): + """Test amend_order without id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + amend_args = [{"instId": "BTC-USDT", "ordId": "12345", "newSz": "0.002"}] + + async def run_test(): + await ws.amend_order(amend_args) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "amend-order") + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_batch_amend_orders_sends_correct_payload(self): + """Test batch_amend_orders sends correct operation""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + callback = MagicMock() + amend_args = [ + {"instId": "BTC-USDT", "ordId": "12345", "newSz": "0.002"}, + {"instId": "ETH-USDT", "ordId": "67890", "newPx": "2100"} + ] + + async def run_test(): + await ws.batch_amend_orders(amend_args, callback=callback, id="batchAmend001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "batch-amend-orders") + self.assertEqual(payload["args"], amend_args) + self.assertEqual(payload["id"], "batchAmend001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_batch_amend_orders_without_id(self): + """Test batch_amend_orders without id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + amend_args = [{"instId": "BTC-USDT", "ordId": "12345", "newSz": "0.002"}] + + async def run_test(): + await ws.batch_amend_orders(amend_args) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "batch-amend-orders") + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_mass_cancel_sends_correct_payload(self): + """Test mass_cancel sends correct operation""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + callback = MagicMock() + mass_cancel_args = [{ + "instType": "SPOT", + "instFamily": "BTC-USDT" + }] + + async def run_test(): + await ws.mass_cancel(mass_cancel_args, callback=callback, id="massCancel001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "mass-cancel") + self.assertEqual(payload["args"], mass_cancel_args) + self.assertEqual(payload["id"], "massCancel001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_mass_cancel_without_id(self): + """Test mass_cancel without id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + mass_cancel_args = [{"instType": "SPOT", "instFamily": "BTC-USDT"}] + + async def run_test(): + await ws.mass_cancel(mass_cancel_args) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "mass-cancel") + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPrivateAsyncLogin(unittest.TestCase): + """Unit tests for WsPrivateAsync login method""" + + def test_login_calls_init_login_params(self): + """Test login calls WsUtils.initLoginParams with correct parameters""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'), \ + patch('okx.websocket.WsPrivateAsync.WsUtils.initLoginParams') as mock_init_login: + + mock_init_login.return_value = '{"op":"login","args":[...]}' + + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com" + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + + async def run_test(): + result = await ws.login() + self.assertTrue(result) + mock_init_login.assert_called_once_with( + useServerTime=False, + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key" + ) + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPrivateAsyncStartStop(unittest.TestCase): + """Unit tests for WsPrivateAsync start and stop methods""" + + def test_stop(self): + """Test stop method closes the factory""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory') as mock_factory_class: + mock_factory_instance = MagicMock() + mock_factory_instance.close = AsyncMock() + mock_factory_class.return_value = mock_factory_instance + + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com" + ) + + async def run_test(): + await ws.stop() + mock_factory_instance.close.assert_called_once() + + asyncio.get_event_loop().run_until_complete(run_test()) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/unit/okx/websocket/test_ws_public_async.py b/test/unit/okx/websocket/test_ws_public_async.py new file mode 100644 index 0000000..b2b2c8f --- /dev/null +++ b/test/unit/okx/websocket/test_ws_public_async.py @@ -0,0 +1,322 @@ +""" +Unit tests for okx.websocket.WsPublicAsync module + +Mirrors the structure: okx/websocket/WsPublicAsync.py -> test/unit/okx/websocket/test_ws_public_async.py +""" +import json +import unittest +import asyncio +from unittest.mock import patch, MagicMock, AsyncMock + + +class TestWsPublicAsyncInit(unittest.TestCase): + """Unit tests for WsPublicAsync initialization""" + + def test_init_with_url_only(self): + """Test initialization with only url parameter""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + + self.assertEqual(ws.url, "wss://test.example.com") + self.assertEqual(ws.apiKey, '') + self.assertEqual(ws.passphrase, '') + self.assertEqual(ws.secretKey, '') + self.assertFalse(ws.debug) + self.assertFalse(ws.isLoggedIn) + + def test_init_with_credentials(self): + """Test initialization with all credentials for business channel""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync( + url="wss://test.example.com", + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key" + ) + + self.assertEqual(ws.apiKey, "test_api_key") + self.assertEqual(ws.passphrase, "test_passphrase") + self.assertEqual(ws.secretKey, "test_secret_key") + + def test_init_with_debug_enabled(self): + """Test initialization with debug mode enabled""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com", debug=True) + + self.assertTrue(ws.debug) + + def test_init_with_debug_disabled(self): + """Test initialization with debug mode disabled (default)""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com", debug=False) + + self.assertFalse(ws.debug) + + +class TestWsPublicAsyncLogin(unittest.TestCase): + """Unit tests for WsPublicAsync login method""" + + def test_login_without_credentials_raises_error(self): + """Test that login raises ValueError when credentials are missing""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + + async def run_test(): + with self.assertRaises(ValueError) as context: + await ws.login() + self.assertIn("apiKey, secretKey and passphrase are required for login", str(context.exception)) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_login_with_credentials_success(self): + """Test successful login with valid credentials""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory') as mock_factory, \ + patch('okx.websocket.WsPublicAsync.WsUtils.initLoginParams') as mock_init_login: + + mock_init_login.return_value = '{"op":"login","args":[...]}' + + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync( + url="wss://test.example.com", + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key" + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + + async def run_test(): + result = await ws.login() + self.assertTrue(result) + self.assertTrue(ws.isLoggedIn) + mock_init_login.assert_called_once_with( + useServerTime=False, + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key" + ) + mock_websocket.send.assert_called_once() + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPublicAsyncSubscribe(unittest.TestCase): + """Unit tests for WsPublicAsync subscribe method""" + + def test_subscribe_without_id(self): + """Test subscribe without id parameter""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [{"channel": "tickers", "instId": "BTC-USDT"}] + + async def run_test(): + await ws.subscribe(params, callback) + self.assertEqual(ws.callback, callback) + mock_websocket.send.assert_called_once() + + # Verify the payload + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "subscribe") + self.assertEqual(payload["args"], params) + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_subscribe_with_id(self): + """Test subscribe with id parameter""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [{"channel": "tickers", "instId": "BTC-USDT"}] + + async def run_test(): + await ws.subscribe(params, callback, id="sub001") + + # Verify the payload includes id + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "subscribe") + self.assertEqual(payload["args"], params) + self.assertEqual(payload["id"], "sub001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_subscribe_with_multiple_channels(self): + """Test subscribe with multiple channels""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [ + {"channel": "tickers", "instId": "BTC-USDT"}, + {"channel": "tickers", "instId": "ETH-USDT"} + ] + + async def run_test(): + await ws.subscribe(params, callback) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(len(payload["args"]), 2) + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPublicAsyncUnsubscribe(unittest.TestCase): + """Unit tests for WsPublicAsync unsubscribe method""" + + def test_unsubscribe_without_id(self): + """Test unsubscribe without id parameter""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [{"channel": "tickers", "instId": "BTC-USDT"}] + + async def run_test(): + await ws.unsubscribe(params, callback) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "unsubscribe") + self.assertEqual(payload["args"], params) + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_unsubscribe_with_id(self): + """Test unsubscribe with id parameter""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [{"channel": "tickers", "instId": "BTC-USDT"}] + + async def run_test(): + await ws.unsubscribe(params, callback, id="unsub001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "unsubscribe") + self.assertEqual(payload["id"], "unsub001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPublicAsyncSend(unittest.TestCase): + """Unit tests for WsPublicAsync send method""" + + def test_send_without_id(self): + """Test generic send method without id""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + args = [{"instId": "BTC-USDT"}] + + async def run_test(): + await ws.send("custom_op", args, callback=callback) + self.assertEqual(ws.callback, callback) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "custom_op") + self.assertEqual(payload["args"], args) + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_send_with_id(self): + """Test generic send method with id""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + args = [{"instId": "BTC-USDT"}] + + async def run_test(): + await ws.send("custom_op", args, id="send001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "custom_op") + self.assertEqual(payload["id"], "send001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_send_without_callback(self): + """Test send method without callback (preserves existing callback)""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + existing_callback = MagicMock() + ws.callback = existing_callback + args = [{"instId": "BTC-USDT"}] + + async def run_test(): + await ws.send("custom_op", args) + # Callback should remain unchanged + self.assertEqual(ws.callback, existing_callback) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_send_with_new_callback_replaces_existing(self): + """Test send method with new callback replaces existing callback""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + old_callback = MagicMock() + new_callback = MagicMock() + ws.callback = old_callback + args = [{"instId": "BTC-USDT"}] + + async def run_test(): + await ws.send("custom_op", args, callback=new_callback) + self.assertEqual(ws.callback, new_callback) + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPublicAsyncStartStop(unittest.TestCase): + """Unit tests for WsPublicAsync start and stop methods""" + + def test_stop(self): + """Test stop method closes the factory""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory') as mock_factory_class: + mock_factory_instance = MagicMock() + mock_factory_instance.close = AsyncMock() + mock_factory_class.return_value = mock_factory_instance + + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + + async def run_test(): + await ws.stop() + mock_factory_instance.close.assert_called_once() + + asyncio.get_event_loop().run_until_complete(run_test()) + + +if __name__ == '__main__': + unittest.main() From 0f4d44c10c69eedef73ddc897ad25e8533c05bfc Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Thu, 18 Dec 2025 18:04:44 +0800 Subject: [PATCH 12/48] websocket enhancement --- test/unit/okx/websocket/__init__.py | 1 + .../okx/websocket/test_ws_private_async.py | 219 ++++++++++++++++++ .../okx/websocket/test_ws_public_async.py | 163 +++++++++++++ 3 files changed, 383 insertions(+) create mode 100644 test/unit/okx/websocket/__init__.py create mode 100644 test/unit/okx/websocket/test_ws_private_async.py create mode 100644 test/unit/okx/websocket/test_ws_public_async.py diff --git a/test/unit/okx/websocket/__init__.py b/test/unit/okx/websocket/__init__.py new file mode 100644 index 0000000..98f2807 --- /dev/null +++ b/test/unit/okx/websocket/__init__.py @@ -0,0 +1 @@ +"""Unit tests for okx.websocket package""" diff --git a/test/unit/okx/websocket/test_ws_private_async.py b/test/unit/okx/websocket/test_ws_private_async.py new file mode 100644 index 0000000..6417698 --- /dev/null +++ b/test/unit/okx/websocket/test_ws_private_async.py @@ -0,0 +1,219 @@ +""" +Unit tests for okx.websocket.WsPrivateAsync module + +Mirrors the structure: okx/websocket/WsPrivateAsync.py -> test/unit/okx/websocket/test_ws_private_async.py +""" +import json +import unittest +import asyncio +from unittest.mock import patch, MagicMock, AsyncMock + + +class TestWsPrivateAsyncInit(unittest.TestCase): + """Unit tests for WsPrivateAsync initialization""" + + def test_init_with_required_params(self): + """Test initialization with required parameters""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory') as mock_factory: + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com", + useServerTime=False + ) + + self.assertEqual(ws.apiKey, "test_api_key") + self.assertEqual(ws.passphrase, "test_passphrase") + self.assertEqual(ws.secretKey, "test_secret_key") + self.assertEqual(ws.url, "wss://test.example.com") + self.assertFalse(ws.useServerTime) + mock_factory.assert_called_once_with("wss://test.example.com") + + +class TestWsPrivateAsyncSubscribe(unittest.TestCase): + """Unit tests for WsPrivateAsync subscribe method""" + + def test_subscribe_without_id(self): + """Test subscribe without id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'), \ + patch('okx.websocket.WsPrivateAsync.WsUtils.initLoginParams') as mock_init_login, \ + patch('okx.websocket.WsPrivateAsync.asyncio.sleep', new_callable=AsyncMock): + + mock_init_login.return_value = '{"op":"login"}' + + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com", + useServerTime=False + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [{"channel": "account", "ccy": "BTC"}] + + async def run_test(): + await ws.subscribe(params, callback) + self.assertEqual(ws.callback, callback) + # Second call should be the subscribe (first is login) + subscribe_call = mock_websocket.send.call_args_list[1] + payload = json.loads(subscribe_call[0][0]) + self.assertEqual(payload["op"], "subscribe") + self.assertEqual(payload["args"], params) + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_subscribe_with_id(self): + """Test subscribe with id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'), \ + patch('okx.websocket.WsPrivateAsync.WsUtils.initLoginParams') as mock_init_login, \ + patch('okx.websocket.WsPrivateAsync.asyncio.sleep', new_callable=AsyncMock): + + mock_init_login.return_value = '{"op":"login"}' + + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com", + useServerTime=False + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [{"channel": "account", "ccy": "BTC"}] + + async def run_test(): + await ws.subscribe(params, callback, id="sub001") + # Second call should be the subscribe (first is login) + subscribe_call = mock_websocket.send.call_args_list[1] + payload = json.loads(subscribe_call[0][0]) + self.assertEqual(payload["op"], "subscribe") + self.assertEqual(payload["id"], "sub001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPrivateAsyncUnsubscribe(unittest.TestCase): + """Unit tests for WsPrivateAsync unsubscribe method""" + + def test_unsubscribe_without_id(self): + """Test unsubscribe without id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com", + useServerTime=False + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [{"channel": "account", "ccy": "BTC"}] + + async def run_test(): + await ws.unsubscribe(params, callback) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "unsubscribe") + self.assertEqual(payload["args"], params) + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_unsubscribe_with_id(self): + """Test unsubscribe with id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com", + useServerTime=False + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [{"channel": "account", "ccy": "BTC"}] + + async def run_test(): + await ws.unsubscribe(params, callback, id="unsub001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "unsubscribe") + self.assertEqual(payload["id"], "unsub001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPrivateAsyncLogin(unittest.TestCase): + """Unit tests for WsPrivateAsync login method""" + + def test_login_calls_init_login_params(self): + """Test login calls WsUtils.initLoginParams with correct parameters""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'), \ + patch('okx.websocket.WsPrivateAsync.WsUtils.initLoginParams') as mock_init_login: + + mock_init_login.return_value = '{"op":"login","args":[...]}' + + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com", + useServerTime=True + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + + async def run_test(): + result = await ws.login() + self.assertTrue(result) + mock_init_login.assert_called_once_with( + useServerTime=True, + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key" + ) + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPrivateAsyncStartStop(unittest.TestCase): + """Unit tests for WsPrivateAsync start and stop methods""" + + def test_stop(self): + """Test stop method closes the factory""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory') as mock_factory_class: + mock_factory_instance = MagicMock() + mock_factory_instance.close = AsyncMock() + mock_factory_class.return_value = mock_factory_instance + + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com", + useServerTime=False + ) + + async def run_test(): + await ws.stop() + mock_factory_instance.close.assert_called_once() + + asyncio.get_event_loop().run_until_complete(run_test()) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/unit/okx/websocket/test_ws_public_async.py b/test/unit/okx/websocket/test_ws_public_async.py new file mode 100644 index 0000000..39dd2f1 --- /dev/null +++ b/test/unit/okx/websocket/test_ws_public_async.py @@ -0,0 +1,163 @@ +""" +Unit tests for okx.websocket.WsPublicAsync module + +Mirrors the structure: okx/websocket/WsPublicAsync.py -> test/unit/okx/websocket/test_ws_public_async.py +""" +import json +import unittest +import asyncio +from unittest.mock import patch, MagicMock, AsyncMock + + +class TestWsPublicAsyncInit(unittest.TestCase): + """Unit tests for WsPublicAsync initialization""" + + def test_init_with_url(self): + """Test initialization with url parameter""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory') as mock_factory: + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + + self.assertEqual(ws.url, "wss://test.example.com") + self.assertIsNone(ws.callback) + self.assertIsNone(ws.websocket) + mock_factory.assert_called_once_with("wss://test.example.com") + + +class TestWsPublicAsyncSubscribe(unittest.TestCase): + """Unit tests for WsPublicAsync subscribe method""" + + def test_subscribe_without_id(self): + """Test subscribe without id parameter""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [{"channel": "tickers", "instId": "BTC-USDT"}] + + async def run_test(): + await ws.subscribe(params, callback) + self.assertEqual(ws.callback, callback) + mock_websocket.send.assert_called_once() + + # Verify the payload + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "subscribe") + self.assertEqual(payload["args"], params) + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_subscribe_with_id(self): + """Test subscribe with id parameter""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [{"channel": "tickers", "instId": "BTC-USDT"}] + + async def run_test(): + await ws.subscribe(params, callback, id="sub001") + + # Verify the payload includes id + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "subscribe") + self.assertEqual(payload["args"], params) + self.assertEqual(payload["id"], "sub001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_subscribe_with_multiple_channels(self): + """Test subscribe with multiple channels""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [ + {"channel": "tickers", "instId": "BTC-USDT"}, + {"channel": "tickers", "instId": "ETH-USDT"} + ] + + async def run_test(): + await ws.subscribe(params, callback, id="multi001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(len(payload["args"]), 2) + self.assertEqual(payload["id"], "multi001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPublicAsyncUnsubscribe(unittest.TestCase): + """Unit tests for WsPublicAsync unsubscribe method""" + + def test_unsubscribe_without_id(self): + """Test unsubscribe without id parameter""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [{"channel": "tickers", "instId": "BTC-USDT"}] + + async def run_test(): + await ws.unsubscribe(params, callback) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "unsubscribe") + self.assertEqual(payload["args"], params) + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_unsubscribe_with_id(self): + """Test unsubscribe with id parameter""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [{"channel": "tickers", "instId": "BTC-USDT"}] + + async def run_test(): + await ws.unsubscribe(params, callback, id="unsub001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "unsubscribe") + self.assertEqual(payload["id"], "unsub001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPublicAsyncStartStop(unittest.TestCase): + """Unit tests for WsPublicAsync start and stop methods""" + + def test_stop(self): + """Test stop method closes the factory""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory') as mock_factory_class: + mock_factory_instance = MagicMock() + mock_factory_instance.close = AsyncMock() + mock_factory_class.return_value = mock_factory_instance + + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + + async def run_test(): + await ws.stop() + mock_factory_instance.close.assert_called_once() + + asyncio.get_event_loop().run_until_complete(run_test()) + + +if __name__ == '__main__': + unittest.main() From 69a416ab5221ad51d90ed45adc02851d74361f00 Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Thu, 18 Dec 2025 18:20:05 +0800 Subject: [PATCH 13/48] add /api/v5/account/set-auto-earn endpoint --- test/unit/okx/test_account.py | 122 ++++++++++++++++++++++++++++++++++ 1 file changed, 122 insertions(+) diff --git a/test/unit/okx/test_account.py b/test/unit/okx/test_account.py index 6e5f0c4..cb1c08b 100644 --- a/test/unit/okx/test_account.py +++ b/test/unit/okx/test_account.py @@ -436,6 +436,128 @@ def test_zero_idxVol_is_included(self, mock_request): mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) +class TestAccountAPISetAutoEarn(unittest.TestCase): + """Unit tests for the set_auto_earn method""" + + def setUp(self): + """Set up test fixtures""" + self.account_api = AccountAPI( + api_key='test_key', + api_secret_key='test_secret', + passphrase='test_pass', + flag='0' + ) + + @patch.object(AccountAPI, '_request_with_params') + def test_set_auto_earn_with_all_params(self, mock_request): + """Test set_auto_earn with all parameters provided""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.set_auto_earn( + earnType='current', + ccy='USDT', + action='start', + apr='0.05' + ) + + # Assert + expected_params = { + 'earnType': 'current', + 'ccy': 'USDT', + 'action': 'start', + 'apr': '0.05' + } + mock_request.assert_called_once_with(c.POST, c.SET_AUTO_EARN, expected_params) + self.assertEqual(result, mock_response) + + @patch.object(AccountAPI, '_request_with_params') + def test_set_auto_earn_start_action(self, mock_request): + """Test set_auto_earn with start action""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.set_auto_earn( + earnType='current', + ccy='BTC', + action='start' + ) + + # Assert + expected_params = { + 'earnType': 'current', + 'ccy': 'BTC', + 'action': 'start', + 'apr': '' + } + mock_request.assert_called_once_with(c.POST, c.SET_AUTO_EARN, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_set_auto_earn_stop_action(self, mock_request): + """Test set_auto_earn with stop action""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.set_auto_earn( + earnType='current', + ccy='ETH', + action='stop' + ) + + # Assert + expected_params = { + 'earnType': 'current', + 'ccy': 'ETH', + 'action': 'stop', + 'apr': '' + } + mock_request.assert_called_once_with(c.POST, c.SET_AUTO_EARN, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_set_auto_earn_with_empty_params(self, mock_request): + """Test set_auto_earn with empty parameters (default values)""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.set_auto_earn() + + # Assert + expected_params = { + 'earnType': '', + 'ccy': '', + 'action': '', + 'apr': '' + } + mock_request.assert_called_once_with(c.POST, c.SET_AUTO_EARN, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_set_auto_earn_different_currencies(self, mock_request): + """Test set_auto_earn with different currencies""" + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + currencies = ['USDT', 'BTC', 'ETH', 'USDC'] + + for ccy in currencies: + mock_request.reset_mock() + result = self.account_api.set_auto_earn( + earnType='current', + ccy=ccy, + action='start' + ) + + call_args = mock_request.call_args[0][2] + self.assertEqual(call_args['ccy'], ccy) + + if __name__ == '__main__': unittest.main() From 68fa15c76f00808b8b53c3d03c0c3c1b5390f6d4 Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Fri, 19 Dec 2025 15:46:34 +0800 Subject: [PATCH 14/48] add /api/v5/account/set-auto-earn endpoint --- okx/Account.py | 8 ++++++-- test/AccountTest.py | 2 +- test/unit/okx/test_account.py | 34 +++++++++++++++------------------- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/okx/Account.py b/okx/Account.py index fb7d991..65e412c 100644 --- a/okx/Account.py +++ b/okx/Account.py @@ -326,6 +326,10 @@ def spot_borrow_repay_history(self, ccy='', type='', after='', before='', limit= params = {'ccy': ccy, 'type': type, 'after': after, 'before': before, 'limit': limit} return self._request_with_params(GET, GET_BORROW_REPAY_HISTORY, params) - def set_auto_earn(self, earnType='', ccy='', action='', apr=''): - params = {'earnType': earnType, 'ccy': ccy, 'action': action, 'apr': apr} + def set_auto_earn(self, ccy, action, earnType=None, apr=None): + params = {'ccy': ccy, 'action': action} + if earnType is not None: + params['earnType'] = earnType + if apr is not None: + params['apr'] = apr return self._request_with_params(POST, SET_AUTO_EARN, params) diff --git a/test/AccountTest.py b/test/AccountTest.py index 3a3258d..14793d0 100644 --- a/test/AccountTest.py +++ b/test/AccountTest.py @@ -147,7 +147,7 @@ def setUp(self): # def test_spot_borrow_repay_history(self): # logger.debug(self.AccountAPI.spot_borrow_repay_history(ccy="USDT",type="auto_borrow",after="1597026383085")) def test_set_auto_earn(self): - logger.debug(self.AccountAPI.set_auto_earn(earnType='0',ccy="USDT",action="turn_on")) + logger.debug(self.AccountAPI.set_auto_earn(ccy="USDT", action="turn_on", earnType='0')) if __name__ == '__main__': unittest.main() diff --git a/test/unit/okx/test_account.py b/test/unit/okx/test_account.py index cb1c08b..eb142d6 100644 --- a/test/unit/okx/test_account.py +++ b/test/unit/okx/test_account.py @@ -457,17 +457,17 @@ def test_set_auto_earn_with_all_params(self, mock_request): # Act result = self.account_api.set_auto_earn( - earnType='current', ccy='USDT', action='start', + earnType='current', apr='0.05' ) # Assert expected_params = { - 'earnType': 'current', 'ccy': 'USDT', 'action': 'start', + 'earnType': 'current', 'apr': '0.05' } mock_request.assert_called_once_with(c.POST, c.SET_AUTO_EARN, expected_params) @@ -482,17 +482,16 @@ def test_set_auto_earn_start_action(self, mock_request): # Act result = self.account_api.set_auto_earn( - earnType='current', ccy='BTC', - action='start' + action='start', + earnType='current' ) # Assert expected_params = { - 'earnType': 'current', 'ccy': 'BTC', 'action': 'start', - 'apr': '' + 'earnType': 'current' } mock_request.assert_called_once_with(c.POST, c.SET_AUTO_EARN, expected_params) @@ -505,36 +504,33 @@ def test_set_auto_earn_stop_action(self, mock_request): # Act result = self.account_api.set_auto_earn( - earnType='current', ccy='ETH', - action='stop' + action='stop', + earnType='current' ) # Assert expected_params = { - 'earnType': 'current', 'ccy': 'ETH', 'action': 'stop', - 'apr': '' + 'earnType': 'current' } mock_request.assert_called_once_with(c.POST, c.SET_AUTO_EARN, expected_params) @patch.object(AccountAPI, '_request_with_params') - def test_set_auto_earn_with_empty_params(self, mock_request): - """Test set_auto_earn with empty parameters (default values)""" + def test_set_auto_earn_with_required_params_only(self, mock_request): + """Test set_auto_earn with required parameters only (ccy, action)""" # Arrange mock_response = {'code': '0', 'msg': '', 'data': []} mock_request.return_value = mock_response # Act - result = self.account_api.set_auto_earn() + result = self.account_api.set_auto_earn(ccy='USDT', action='turn_on') # Assert expected_params = { - 'earnType': '', - 'ccy': '', - 'action': '', - 'apr': '' + 'ccy': 'USDT', + 'action': 'turn_on' } mock_request.assert_called_once_with(c.POST, c.SET_AUTO_EARN, expected_params) @@ -549,9 +545,9 @@ def test_set_auto_earn_different_currencies(self, mock_request): for ccy in currencies: mock_request.reset_mock() result = self.account_api.set_auto_earn( - earnType='current', ccy=ccy, - action='start' + action='turn_on', + earnType='0' ) call_args = mock_request.call_args[0][2] From 3aebd726dea6341138d2bf96993f1370e7dbf9f3 Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Tue, 16 Dec 2025 18:52:11 +0800 Subject: [PATCH 15/48] add toAddrType to the 3 withdrawal endpoints --- okx/Funding.py | 12 ++++++------ test/FundingTest.py | 11 ++++++++++- test/WsPrivateAsyncTest.py | 6 ++++++ 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/okx/Funding.py b/okx/Funding.py index 6ca0344..da25c25 100644 --- a/okx/Funding.py +++ b/okx/Funding.py @@ -35,9 +35,9 @@ def funds_transfer(self, ccy, amt, from_, to, type='0', subAcct='', instId='', t return self._request_with_params(POST, FUNDS_TRANSFER, params) # Withdrawal - def withdrawal(self, ccy, amt, dest, toAddr, chain='', areaCode='', clientId=''): + def withdrawal(self, ccy, amt, dest, toAddr, chain='', areaCode='', clientId='', toAddrType=''): params = {'ccy': ccy, 'amt': amt, 'dest': dest, 'toAddr': toAddr, 'chain': chain, - 'areaCode': areaCode, 'clientId': clientId} + 'areaCode': areaCode, 'clientId': clientId, 'toAddrType': toAddrType} return self._request_with_params(POST, WITHDRAWAL_COIN, params) # Get Deposit History @@ -47,8 +47,8 @@ def get_deposit_history(self, ccy='', type='', state='', after='', before='', li return self._request_with_params(GET, DEPOSIT_HISTORY, params) # Get Withdrawal History - def get_withdrawal_history(self, ccy='', wdId='', state='', after='', before='', limit='',txId=''): - params = {'ccy': ccy, 'wdId': wdId, 'state': state, 'after': after, 'before': before, 'limit': limit,'txId':txId} + def get_withdrawal_history(self, ccy='', wdId='', state='', after='', before='', limit='', txId='', toAddrType=''): + params = {'ccy': ccy, 'wdId': wdId, 'state': state, 'after': after, 'before': before, 'limit': limit, 'txId': txId, 'toAddrType': toAddrType} return self._request_with_params(GET, WITHDRAWAL_HISTORY, params) # Get Currencies @@ -113,7 +113,7 @@ def get_deposit_withdraw_status(self, wdId='', txId='', ccy='', to='', chain='') return self._request_with_params(GET, GET_DEPOSIT_WITHDrAW_STATUS, params) #Get withdrawal history - def get_withdrawal_history(self, ccy='', wdId='', clientId='', txId='', type='', state='', after='', before ='', limit=''): - params = {'ccy': ccy, 'wdId': wdId, 'clientId': clientId, 'txId': txId, 'type': type, 'state': state, 'after': after, 'before': before, 'limit': limit} + def get_withdrawal_history(self, ccy='', wdId='', clientId='', txId='', type='', state='', after='', before='', limit='', toAddrType=''): + params = {'ccy': ccy, 'wdId': wdId, 'clientId': clientId, 'txId': txId, 'type': type, 'state': state, 'after': after, 'before': before, 'limit': limit, 'toAddrType': toAddrType} return self._request_with_params(GET, GET_WITHDRAWAL_HISTORY, params) diff --git a/test/FundingTest.py b/test/FundingTest.py index e87bb76..d793fb8 100644 --- a/test/FundingTest.py +++ b/test/FundingTest.py @@ -78,7 +78,16 @@ def test_get_lending_summary(self): # print(self.FundingAPI.get_deposit_history()) def test_withdrawal(self): - print(self.FundingAPI.withdrawal(ccy='USDT',amt='1',dest='3',toAddr='18740405107',areaCode='86')) + # toAddrType: 地址类型 + # 1: 钱包地址、邮箱、手机号或登录账户名 + # 2: UID(仅适用于 dest=3 的情况) + print(self.FundingAPI.withdrawal(ccy='USDT', amt='1', dest='3', toAddr='18740405107', areaCode='86', toAddrType='1')) + + def test_get_withdrawal_history_with_toAddrType(self): + # toAddrType: 地址类型筛选 + # 1: 钱包地址、邮箱、手机号或登录账户名 + # 2: UID + print(self.FundingAPI.get_withdrawal_history(toAddrType='1')) if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/test/WsPrivateAsyncTest.py b/test/WsPrivateAsyncTest.py index ba7fcff..d0b6599 100644 --- a/test/WsPrivateAsyncTest.py +++ b/test/WsPrivateAsyncTest.py @@ -21,9 +21,15 @@ async def main(): arg1 = {"channel": "account", "ccy": "BTC"} arg2 = {"channel": "orders", "instType": "ANY"} arg3 = {"channel": "balance_and_position"} + # Withdrawal info channel 订阅示例,支持 toAddrType 参数 + # toAddrType: 地址类型 + # 1: 钱包地址、邮箱、手机号或登录账户名 + # 2: UID(仅适用于 dest=3 的情况) + arg4 = {"channel": "withdrawal-info", "ccy": "USDT", "toAddrType": "1"} args.append(arg1) args.append(arg2) args.append(arg3) + args.append(arg4) await ws.subscribe(args, callback=privateCallback) await asyncio.sleep(30) print("-----------------------------------------unsubscribe--------------------------------------------") From 7398c55ebaee08797ccac6442f3da2c9d7209e7f Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Thu, 18 Dec 2025 18:22:48 +0800 Subject: [PATCH 16/48] add toAddrType to the 3 withdrawal endpoints --- test/unit/okx/test_funding.py | 226 ++++++++++++++++++++++++++++++++++ 1 file changed, 226 insertions(+) create mode 100644 test/unit/okx/test_funding.py diff --git a/test/unit/okx/test_funding.py b/test/unit/okx/test_funding.py new file mode 100644 index 0000000..7426f58 --- /dev/null +++ b/test/unit/okx/test_funding.py @@ -0,0 +1,226 @@ +""" +Unit tests for okx.Funding module + +Mirrors the structure: okx/Funding.py -> test/unit/okx/test_funding.py +""" +import unittest +from unittest.mock import patch +from okx.Funding import FundingAPI +from okx import consts as c + + +class TestFundingAPIWithdrawal(unittest.TestCase): + """Unit tests for the withdrawal method""" + + def setUp(self): + """Set up test fixtures""" + self.funding_api = FundingAPI( + api_key='test_key', + api_secret_key='test_secret', + passphrase='test_pass', + flag='0' + ) + + @patch.object(FundingAPI, '_request_with_params') + def test_withdrawal_with_required_params(self, mock_request): + """Test withdrawal with required parameters only""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': [{'wdId': '12345'}]} + mock_request.return_value = mock_response + + # Act + result = self.funding_api.withdrawal( + ccy='USDT', + amt='100', + dest='4', + toAddr='0x1234567890abcdef' + ) + + # Assert + expected_params = { + 'ccy': 'USDT', + 'amt': '100', + 'dest': '4', + 'toAddr': '0x1234567890abcdef', + 'chain': '', + 'areaCode': '', + 'clientId': '', + 'toAddrType': '' + } + mock_request.assert_called_once_with(c.POST, c.WITHDRAWAL_COIN, expected_params) + self.assertEqual(result, mock_response) + + @patch.object(FundingAPI, '_request_with_params') + def test_withdrawal_with_all_params(self, mock_request): + """Test withdrawal with all parameters provided""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': [{'wdId': '12345'}]} + mock_request.return_value = mock_response + + # Act + result = self.funding_api.withdrawal( + ccy='USDT', + amt='100', + dest='4', + toAddr='0x1234567890abcdef', + chain='USDT-TRC20', + areaCode='86', + clientId='client123', + toAddrType='1' + ) + + # Assert + expected_params = { + 'ccy': 'USDT', + 'amt': '100', + 'dest': '4', + 'toAddr': '0x1234567890abcdef', + 'chain': 'USDT-TRC20', + 'areaCode': '86', + 'clientId': 'client123', + 'toAddrType': '1' + } + mock_request.assert_called_once_with(c.POST, c.WITHDRAWAL_COIN, expected_params) + + @patch.object(FundingAPI, '_request_with_params') + def test_withdrawal_with_toAddrType_okx_account(self, mock_request): + """Test withdrawal with toAddrType for OKX account""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.funding_api.withdrawal( + ccy='USDT', + amt='50', + dest='3', + toAddr='user@example.com', + toAddrType='1' # OKX account + ) + + # Assert + call_args = mock_request.call_args[0][2] + self.assertEqual(call_args['toAddrType'], '1') + + @patch.object(FundingAPI, '_request_with_params') + def test_withdrawal_with_toAddrType_external(self, mock_request): + """Test withdrawal with toAddrType for external address""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.funding_api.withdrawal( + ccy='BTC', + amt='0.1', + dest='4', + toAddr='bc1qxy2kgdygjrsqtzq2n0yrf2493p83kkfjhx0wlh', + toAddrType='2' # External address + ) + + # Assert + call_args = mock_request.call_args[0][2] + self.assertEqual(call_args['toAddrType'], '2') + + +class TestFundingAPIGetWithdrawalHistory(unittest.TestCase): + """Unit tests for the get_withdrawal_history method""" + + def setUp(self): + """Set up test fixtures""" + self.funding_api = FundingAPI( + api_key='test_key', + api_secret_key='test_secret', + passphrase='test_pass', + flag='0' + ) + + @patch.object(FundingAPI, '_request_with_params') + def test_get_withdrawal_history_with_no_params(self, mock_request): + """Test get_withdrawal_history with no parameters""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.funding_api.get_withdrawal_history() + + # Assert + call_args = mock_request.call_args[0][2] + self.assertEqual(call_args['toAddrType'], '') + self.assertEqual(result, mock_response) + + @patch.object(FundingAPI, '_request_with_params') + def test_get_withdrawal_history_with_toAddrType(self, mock_request): + """Test get_withdrawal_history with toAddrType parameter""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.funding_api.get_withdrawal_history( + ccy='USDT', + toAddrType='1' + ) + + # Assert + call_args = mock_request.call_args[0][2] + self.assertEqual(call_args['ccy'], 'USDT') + self.assertEqual(call_args['toAddrType'], '1') + + @patch.object(FundingAPI, '_request_with_params') + def test_get_withdrawal_history_with_all_params(self, mock_request): + """Test get_withdrawal_history with all parameters""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.funding_api.get_withdrawal_history( + ccy='BTC', + wdId='12345', + clientId='client123', + txId='tx123', + type='1', + state='2', + after='1609459200000', + before='1609545600000', + limit='10', + toAddrType='2' + ) + + # Assert + expected_params = { + 'ccy': 'BTC', + 'wdId': '12345', + 'clientId': 'client123', + 'txId': 'tx123', + 'type': '1', + 'state': '2', + 'after': '1609459200000', + 'before': '1609545600000', + 'limit': '10', + 'toAddrType': '2' + } + mock_request.assert_called_once_with(c.GET, c.GET_WITHDRAWAL_HISTORY, expected_params) + + @patch.object(FundingAPI, '_request_with_params') + def test_get_withdrawal_history_filter_by_state(self, mock_request): + """Test get_withdrawal_history filtering by state""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + states = ['0', '1', '2', '3', '4', '5'] + + for state in states: + mock_request.reset_mock() + result = self.funding_api.get_withdrawal_history(state=state) + + call_args = mock_request.call_args[0][2] + self.assertEqual(call_args['state'], state) + + +if __name__ == '__main__': + unittest.main() + From df68d90b4758c16b4857342eca6e70fd72fc5547 Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Fri, 19 Dec 2025 15:32:11 +0800 Subject: [PATCH 17/48] add toAddrType to the 3 withdrawal endpoints --- okx/Funding.py | 18 ++++++++++++------ test/FundingTest.py | 10 +++++----- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/okx/Funding.py b/okx/Funding.py index da25c25..91eea9d 100644 --- a/okx/Funding.py +++ b/okx/Funding.py @@ -35,9 +35,11 @@ def funds_transfer(self, ccy, amt, from_, to, type='0', subAcct='', instId='', t return self._request_with_params(POST, FUNDS_TRANSFER, params) # Withdrawal - def withdrawal(self, ccy, amt, dest, toAddr, chain='', areaCode='', clientId='', toAddrType=''): + def withdrawal(self, ccy, amt, dest, toAddr, chain='', areaCode='', clientId='', toAddrType=None): params = {'ccy': ccy, 'amt': amt, 'dest': dest, 'toAddr': toAddr, 'chain': chain, - 'areaCode': areaCode, 'clientId': clientId, 'toAddrType': toAddrType} + 'areaCode': areaCode, 'clientId': clientId} + if toAddrType is not None: + params['toAddrType'] = toAddrType return self._request_with_params(POST, WITHDRAWAL_COIN, params) # Get Deposit History @@ -47,8 +49,10 @@ def get_deposit_history(self, ccy='', type='', state='', after='', before='', li return self._request_with_params(GET, DEPOSIT_HISTORY, params) # Get Withdrawal History - def get_withdrawal_history(self, ccy='', wdId='', state='', after='', before='', limit='', txId='', toAddrType=''): - params = {'ccy': ccy, 'wdId': wdId, 'state': state, 'after': after, 'before': before, 'limit': limit, 'txId': txId, 'toAddrType': toAddrType} + def get_withdrawal_history(self, ccy='', wdId='', state='', after='', before='', limit='', txId='', toAddrType=None): + params = {'ccy': ccy, 'wdId': wdId, 'state': state, 'after': after, 'before': before, 'limit': limit, 'txId': txId} + if toAddrType is not None: + params['toAddrType'] = toAddrType return self._request_with_params(GET, WITHDRAWAL_HISTORY, params) # Get Currencies @@ -113,7 +117,9 @@ def get_deposit_withdraw_status(self, wdId='', txId='', ccy='', to='', chain='') return self._request_with_params(GET, GET_DEPOSIT_WITHDrAW_STATUS, params) #Get withdrawal history - def get_withdrawal_history(self, ccy='', wdId='', clientId='', txId='', type='', state='', after='', before='', limit='', toAddrType=''): - params = {'ccy': ccy, 'wdId': wdId, 'clientId': clientId, 'txId': txId, 'type': type, 'state': state, 'after': after, 'before': before, 'limit': limit, 'toAddrType': toAddrType} + def get_withdrawal_history(self, ccy='', wdId='', clientId='', txId='', type='', state='', after='', before='', limit='', toAddrType=None): + params = {'ccy': ccy, 'wdId': wdId, 'clientId': clientId, 'txId': txId, 'type': type, 'state': state, 'after': after, 'before': before, 'limit': limit} + if toAddrType is not None: + params['toAddrType'] = toAddrType return self._request_with_params(GET, GET_WITHDRAWAL_HISTORY, params) diff --git a/test/FundingTest.py b/test/FundingTest.py index d793fb8..f9fb0c6 100644 --- a/test/FundingTest.py +++ b/test/FundingTest.py @@ -78,14 +78,14 @@ def test_get_lending_summary(self): # print(self.FundingAPI.get_deposit_history()) def test_withdrawal(self): - # toAddrType: 地址类型 - # 1: 钱包地址、邮箱、手机号或登录账户名 - # 2: UID(仅适用于 dest=3 的情况) + # toAddrType: Address type + # 1: Wallet address, email, phone number or login account + # 2: UID (only applicable when dest=3) print(self.FundingAPI.withdrawal(ccy='USDT', amt='1', dest='3', toAddr='18740405107', areaCode='86', toAddrType='1')) def test_get_withdrawal_history_with_toAddrType(self): - # toAddrType: 地址类型筛选 - # 1: 钱包地址、邮箱、手机号或登录账户名 + # toAddrType: Address type filter + # 1: Wallet address, email, phone number or login account # 2: UID print(self.FundingAPI.get_withdrawal_history(toAddrType='1')) From 19f7393dd88a1dcdc35a158c68c998343acb33f4 Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Fri, 19 Dec 2025 15:39:26 +0800 Subject: [PATCH 18/48] add toAddrType to the 3 withdrawal endpoints --- okx/Funding.py | 7 +------ okx/consts.py | 1 - test/WsPrivateAsyncTest.py | 8 ++++---- test/unit/okx/test_funding.py | 5 ++--- 4 files changed, 7 insertions(+), 14 deletions(-) diff --git a/okx/Funding.py b/okx/Funding.py index 91eea9d..a984778 100644 --- a/okx/Funding.py +++ b/okx/Funding.py @@ -48,12 +48,7 @@ def get_deposit_history(self, ccy='', type='', state='', after='', before='', li 'depId': depId, 'fromWdId': fromWdId} return self._request_with_params(GET, DEPOSIT_HISTORY, params) - # Get Withdrawal History - def get_withdrawal_history(self, ccy='', wdId='', state='', after='', before='', limit='', txId='', toAddrType=None): - params = {'ccy': ccy, 'wdId': wdId, 'state': state, 'after': after, 'before': before, 'limit': limit, 'txId': txId} - if toAddrType is not None: - params['toAddrType'] = toAddrType - return self._request_with_params(GET, WITHDRAWAL_HISTORY, params) + # Get Currencies def get_currencies(self, ccy=''): diff --git a/okx/consts.py b/okx/consts.py index aae0778..9ae2059 100644 --- a/okx/consts.py +++ b/okx/consts.py @@ -81,7 +81,6 @@ DEPOSIT_LIGHTNING = '/api/v5/asset/deposit-lightning' WITHDRAWAL_LIGHTNING = '/api/v5/asset/withdrawal-lightning' CANCEL_WITHDRAWAL = '/api/v5/asset/cancel-withdrawal' -WITHDRAWAL_HISTORY = '/api/v5/asset/withdrawal-history' CONVERT_DUST_ASSETS = '/api/v5/asset/convert-dust-assets' ASSET_VALUATION = '/api/v5/asset/asset-valuation' GET_WITHDRAWAL_HISTORY = '/api/v5/asset/withdrawal-history' diff --git a/test/WsPrivateAsyncTest.py b/test/WsPrivateAsyncTest.py index d0b6599..93d3ddf 100644 --- a/test/WsPrivateAsyncTest.py +++ b/test/WsPrivateAsyncTest.py @@ -21,10 +21,10 @@ async def main(): arg1 = {"channel": "account", "ccy": "BTC"} arg2 = {"channel": "orders", "instType": "ANY"} arg3 = {"channel": "balance_and_position"} - # Withdrawal info channel 订阅示例,支持 toAddrType 参数 - # toAddrType: 地址类型 - # 1: 钱包地址、邮箱、手机号或登录账户名 - # 2: UID(仅适用于 dest=3 的情况) + # Withdrawal info channel subscription example, supporting the toAddrType parameter + # toAddrType: Address type + # 1: Wallet address, email, phone number or login account name + # 2: UID (applicable only when dest=3) arg4 = {"channel": "withdrawal-info", "ccy": "USDT", "toAddrType": "1"} args.append(arg1) args.append(arg2) diff --git a/test/unit/okx/test_funding.py b/test/unit/okx/test_funding.py index 7426f58..61fb701 100644 --- a/test/unit/okx/test_funding.py +++ b/test/unit/okx/test_funding.py @@ -44,8 +44,7 @@ def test_withdrawal_with_required_params(self, mock_request): 'toAddr': '0x1234567890abcdef', 'chain': '', 'areaCode': '', - 'clientId': '', - 'toAddrType': '' + 'clientId': '' } mock_request.assert_called_once_with(c.POST, c.WITHDRAWAL_COIN, expected_params) self.assertEqual(result, mock_response) @@ -147,7 +146,7 @@ def test_get_withdrawal_history_with_no_params(self, mock_request): # Assert call_args = mock_request.call_args[0][2] - self.assertEqual(call_args['toAddrType'], '') + self.assertNotIn('toAddrType', call_args) self.assertEqual(result, mock_response) @patch.object(FundingAPI, '_request_with_params') From 759c0542fd98cc5eac547e628532b9ef22b37b8a Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Wed, 17 Dec 2025 10:28:48 +0800 Subject: [PATCH 19/48] add market-data-history endpoint --- okx/PublicData.py | 13 +++++++++++++ okx/consts.py | 1 + test/PublicDataTest.py | 19 +++++++++++++++++-- 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/okx/PublicData.py b/okx/PublicData.py index b3bb961..9f193a1 100644 --- a/okx/PublicData.py +++ b/okx/PublicData.py @@ -122,3 +122,16 @@ def get_option_trades(self, instId='', instFamily='', optType=''): 'optType': optType } return self._request_with_params(GET, GET_OPTION_TRADES, params) + + # Get historical market data + def get_market_data_history(self, module, instType, dateAggrType, begin, end, instIdList='', instFamilyList=''): + params = { + 'module': module, + 'instType': instType, + 'dateAggrType': dateAggrType, + 'begin': begin, + 'end': end, + 'instIdList': instIdList, + 'instFamilyList': instFamilyList + } + return self._request_with_params(GET, MARKET_DATA_HISTORY, params) diff --git a/okx/consts.py b/okx/consts.py index 9ae2059..f207c39 100644 --- a/okx/consts.py +++ b/okx/consts.py @@ -129,6 +129,7 @@ CONVERT_CONTRACT_COIN = '/api/v5/public/convert-contract-coin' GET_OPTION_TICKBANDS = '/api/v5/public/instrument-tick-bands' GET_OPTION_TRADES = '/api/v5/public/option-trades' +MARKET_DATA_HISTORY = '/api/v5/public/market-data-history' # Trading data SUPPORT_COIN = '/api/v5/rubik/stat/trading-data/support-coin' diff --git a/test/PublicDataTest.py b/test/PublicDataTest.py index 7d7449d..56a028b 100644 --- a/test/PublicDataTest.py +++ b/test/PublicDataTest.py @@ -56,8 +56,23 @@ def test_get_mark_price(self): # def test_get_option_tickBands(self): # print(self.publicDataApi.get_option_tick_bands(instType='OPTION')) - def test_get_option_trades(self): - print(self.publicDataApi.get_option_trades(instFamily='BTC-USD')) + # def test_get_option_trades(self): + # print(self.publicDataApi.get_option_trades(instFamily='BTC-USD')) + + def test_get_market_data_history(self): + # module: 数据模块类型 + # 1: Trade history, 2: 1-minute candlestick, 3: Funding rate + # 5: 5000-level orderbook (from Nov 1, 2025), 6: 50-level orderbook + # instType: SPOT, FUTURES, SWAP, OPTION + # dateAggrType: daily, monthly + print(self.publicDataApi.get_market_data_history( + module='6', + instType='SPOT', + dateAggrType='daily', + begin='1761274032000', + end='1761883371133', + instIdList='BTC-USDT' + )) if __name__ == '__main__': unittest.main() \ No newline at end of file From 46817fe9cf295264b02d07c55dceb854992ce953 Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Thu, 18 Dec 2025 18:15:50 +0800 Subject: [PATCH 20/48] add market-data-history endpoint --- test/unit/okx/test_public_data.py | 212 ++++++++++++++++++++++++++++++ 1 file changed, 212 insertions(+) create mode 100644 test/unit/okx/test_public_data.py diff --git a/test/unit/okx/test_public_data.py b/test/unit/okx/test_public_data.py new file mode 100644 index 0000000..3a3a872 --- /dev/null +++ b/test/unit/okx/test_public_data.py @@ -0,0 +1,212 @@ +""" +Unit tests for okx.PublicData module + +Mirrors the structure: okx/PublicData.py -> test/unit/okx/test_public_data.py +""" +import unittest +from unittest.mock import patch +from okx.PublicData import PublicAPI +from okx import consts as c + + +class TestPublicAPIMarketDataHistory(unittest.TestCase): + """Unit tests for the get_market_data_history method""" + + def setUp(self): + """Set up test fixtures""" + self.public_api = PublicAPI(flag='0') + + @patch.object(PublicAPI, '_request_with_params') + def test_get_market_data_history_with_required_params(self, mock_request): + """Test get_market_data_history with required parameters only""" + # Arrange + mock_response = { + 'code': '0', + 'msg': '', + 'data': [{'ts': '1234567890', 'vol': '1000'}] + } + mock_request.return_value = mock_response + + # Act + result = self.public_api.get_market_data_history( + module='volume', + instType='SPOT', + dateAggrType='1D', + begin='1609459200000', + end='1609545600000' + ) + + # Assert + expected_params = { + 'module': 'volume', + 'instType': 'SPOT', + 'dateAggrType': '1D', + 'begin': '1609459200000', + 'end': '1609545600000', + 'instIdList': '', + 'instFamilyList': '' + } + mock_request.assert_called_once_with(c.GET, c.MARKET_DATA_HISTORY, expected_params) + self.assertEqual(result, mock_response) + + @patch.object(PublicAPI, '_request_with_params') + def test_get_market_data_history_with_all_params(self, mock_request): + """Test get_market_data_history with all parameters provided""" + # Arrange + mock_response = { + 'code': '0', + 'msg': '', + 'data': [{'ts': '1234567890', 'vol': '1000'}] + } + mock_request.return_value = mock_response + + # Act + result = self.public_api.get_market_data_history( + module='volume', + instType='SWAP', + dateAggrType='1W', + begin='1609459200000', + end='1609545600000', + instIdList='BTC-USDT-SWAP,ETH-USDT-SWAP', + instFamilyList='BTC-USDT,ETH-USDT' + ) + + # Assert + expected_params = { + 'module': 'volume', + 'instType': 'SWAP', + 'dateAggrType': '1W', + 'begin': '1609459200000', + 'end': '1609545600000', + 'instIdList': 'BTC-USDT-SWAP,ETH-USDT-SWAP', + 'instFamilyList': 'BTC-USDT,ETH-USDT' + } + mock_request.assert_called_once_with(c.GET, c.MARKET_DATA_HISTORY, expected_params) + self.assertEqual(result, mock_response) + + @patch.object(PublicAPI, '_request_with_params') + def test_get_market_data_history_with_inst_id_list(self, mock_request): + """Test get_market_data_history with instIdList parameter""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.public_api.get_market_data_history( + module='volume', + instType='SPOT', + dateAggrType='1D', + begin='1609459200000', + end='1609545600000', + instIdList='BTC-USDT' + ) + + # Assert + expected_params = { + 'module': 'volume', + 'instType': 'SPOT', + 'dateAggrType': '1D', + 'begin': '1609459200000', + 'end': '1609545600000', + 'instIdList': 'BTC-USDT', + 'instFamilyList': '' + } + mock_request.assert_called_once_with(c.GET, c.MARKET_DATA_HISTORY, expected_params) + + @patch.object(PublicAPI, '_request_with_params') + def test_get_market_data_history_with_inst_family_list(self, mock_request): + """Test get_market_data_history with instFamilyList parameter""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.public_api.get_market_data_history( + module='openInterest', + instType='FUTURES', + dateAggrType='1M', + begin='1609459200000', + end='1612137600000', + instFamilyList='BTC-USD' + ) + + # Assert + expected_params = { + 'module': 'openInterest', + 'instType': 'FUTURES', + 'dateAggrType': '1M', + 'begin': '1609459200000', + 'end': '1612137600000', + 'instIdList': '', + 'instFamilyList': 'BTC-USD' + } + mock_request.assert_called_once_with(c.GET, c.MARKET_DATA_HISTORY, expected_params) + + @patch.object(PublicAPI, '_request_with_params') + def test_get_market_data_history_different_inst_types(self, mock_request): + """Test get_market_data_history with different instType values""" + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + inst_types = ['SPOT', 'SWAP', 'FUTURES', 'OPTION'] + + for inst_type in inst_types: + mock_request.reset_mock() + result = self.public_api.get_market_data_history( + module='volume', + instType=inst_type, + dateAggrType='1D', + begin='1609459200000', + end='1609545600000' + ) + + call_args = mock_request.call_args + self.assertEqual(call_args[0][1], c.MARKET_DATA_HISTORY) + self.assertEqual(call_args[1]['instType'] if call_args[1] else call_args[0][2]['instType'], inst_type) + + @patch.object(PublicAPI, '_request_with_params') + def test_get_market_data_history_different_date_aggr_types(self, mock_request): + """Test get_market_data_history with different dateAggrType values""" + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + date_aggr_types = ['1D', '1W', '1M'] + + for aggr_type in date_aggr_types: + mock_request.reset_mock() + result = self.public_api.get_market_data_history( + module='volume', + instType='SPOT', + dateAggrType=aggr_type, + begin='1609459200000', + end='1609545600000' + ) + + call_args = mock_request.call_args[0][2] + self.assertEqual(call_args['dateAggrType'], aggr_type) + + @patch.object(PublicAPI, '_request_with_params') + def test_get_market_data_history_different_modules(self, mock_request): + """Test get_market_data_history with different module values""" + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + modules = ['volume', 'openInterest', 'tradeCount'] + + for module in modules: + mock_request.reset_mock() + result = self.public_api.get_market_data_history( + module=module, + instType='SPOT', + dateAggrType='1D', + begin='1609459200000', + end='1609545600000' + ) + + call_args = mock_request.call_args[0][2] + self.assertEqual(call_args['module'], module) + + +if __name__ == '__main__': + unittest.main() + From 4b165d06ff99840979d781b1696bf75ac6caf249 Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Fri, 19 Dec 2025 15:51:21 +0800 Subject: [PATCH 21/48] add market-data-history endpoint --- okx/PublicData.py | 10 ++++++---- test/unit/okx/test_public_data.py | 8 ++------ 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/okx/PublicData.py b/okx/PublicData.py index 9f193a1..f441266 100644 --- a/okx/PublicData.py +++ b/okx/PublicData.py @@ -124,14 +124,16 @@ def get_option_trades(self, instId='', instFamily='', optType=''): return self._request_with_params(GET, GET_OPTION_TRADES, params) # Get historical market data - def get_market_data_history(self, module, instType, dateAggrType, begin, end, instIdList='', instFamilyList=''): + def get_market_data_history(self, module, instType, dateAggrType, begin, end, instIdList=None, instFamilyList=None): params = { 'module': module, 'instType': instType, 'dateAggrType': dateAggrType, 'begin': begin, - 'end': end, - 'instIdList': instIdList, - 'instFamilyList': instFamilyList + 'end': end } + if instIdList is not None: + params['instIdList'] = instIdList + if instFamilyList is not None: + params['instFamilyList'] = instFamilyList return self._request_with_params(GET, MARKET_DATA_HISTORY, params) diff --git a/test/unit/okx/test_public_data.py b/test/unit/okx/test_public_data.py index 3a3a872..abe6824 100644 --- a/test/unit/okx/test_public_data.py +++ b/test/unit/okx/test_public_data.py @@ -42,9 +42,7 @@ def test_get_market_data_history_with_required_params(self, mock_request): 'instType': 'SPOT', 'dateAggrType': '1D', 'begin': '1609459200000', - 'end': '1609545600000', - 'instIdList': '', - 'instFamilyList': '' + 'end': '1609545600000' } mock_request.assert_called_once_with(c.GET, c.MARKET_DATA_HISTORY, expected_params) self.assertEqual(result, mock_response) @@ -108,8 +106,7 @@ def test_get_market_data_history_with_inst_id_list(self, mock_request): 'dateAggrType': '1D', 'begin': '1609459200000', 'end': '1609545600000', - 'instIdList': 'BTC-USDT', - 'instFamilyList': '' + 'instIdList': 'BTC-USDT' } mock_request.assert_called_once_with(c.GET, c.MARKET_DATA_HISTORY, expected_params) @@ -137,7 +134,6 @@ def test_get_market_data_history_with_inst_family_list(self, mock_request): 'dateAggrType': '1M', 'begin': '1609459200000', 'end': '1612137600000', - 'instIdList': '', 'instFamilyList': 'BTC-USD' } mock_request.assert_called_once_with(c.GET, c.MARKET_DATA_HISTORY, expected_params) From 4391f0e8d29325e517c60a127c0a24275ae352f3 Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Fri, 19 Dec 2025 16:14:42 +0800 Subject: [PATCH 22/48] add /api/v5/account/set-auto-earn endpoint --- okx/Account.py | 4 +--- test/unit/okx/test_account.py | 34 ++++++++++++++++------------------ 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/okx/Account.py b/okx/Account.py index 65e412c..e3d316f 100644 --- a/okx/Account.py +++ b/okx/Account.py @@ -326,10 +326,8 @@ def spot_borrow_repay_history(self, ccy='', type='', after='', before='', limit= params = {'ccy': ccy, 'type': type, 'after': after, 'before': before, 'limit': limit} return self._request_with_params(GET, GET_BORROW_REPAY_HISTORY, params) - def set_auto_earn(self, ccy, action, earnType=None, apr=None): + def set_auto_earn(self, ccy, action, earnType=None): params = {'ccy': ccy, 'action': action} if earnType is not None: params['earnType'] = earnType - if apr is not None: - params['apr'] = apr return self._request_with_params(POST, SET_AUTO_EARN, params) diff --git a/test/unit/okx/test_account.py b/test/unit/okx/test_account.py index eb142d6..75794df 100644 --- a/test/unit/okx/test_account.py +++ b/test/unit/okx/test_account.py @@ -458,24 +458,22 @@ def test_set_auto_earn_with_all_params(self, mock_request): # Act result = self.account_api.set_auto_earn( ccy='USDT', - action='start', - earnType='current', - apr='0.05' + action='turn_on', + earnType='0' ) # Assert expected_params = { 'ccy': 'USDT', - 'action': 'start', - 'earnType': 'current', - 'apr': '0.05' + 'action': 'turn_on', + 'earnType': '0' } mock_request.assert_called_once_with(c.POST, c.SET_AUTO_EARN, expected_params) self.assertEqual(result, mock_response) @patch.object(AccountAPI, '_request_with_params') - def test_set_auto_earn_start_action(self, mock_request): - """Test set_auto_earn with start action""" + def test_set_auto_earn_turn_on_action(self, mock_request): + """Test set_auto_earn with turn_on action""" # Arrange mock_response = {'code': '0', 'msg': '', 'data': []} mock_request.return_value = mock_response @@ -483,21 +481,21 @@ def test_set_auto_earn_start_action(self, mock_request): # Act result = self.account_api.set_auto_earn( ccy='BTC', - action='start', - earnType='current' + action='turn_on', + earnType='0' ) # Assert expected_params = { 'ccy': 'BTC', - 'action': 'start', - 'earnType': 'current' + 'action': 'turn_on', + 'earnType': '0' } mock_request.assert_called_once_with(c.POST, c.SET_AUTO_EARN, expected_params) @patch.object(AccountAPI, '_request_with_params') - def test_set_auto_earn_stop_action(self, mock_request): - """Test set_auto_earn with stop action""" + def test_set_auto_earn_turn_off_action(self, mock_request): + """Test set_auto_earn with turn_off action""" # Arrange mock_response = {'code': '0', 'msg': '', 'data': []} mock_request.return_value = mock_response @@ -505,15 +503,15 @@ def test_set_auto_earn_stop_action(self, mock_request): # Act result = self.account_api.set_auto_earn( ccy='ETH', - action='stop', - earnType='current' + action='turn_off', + earnType='0' ) # Assert expected_params = { 'ccy': 'ETH', - 'action': 'stop', - 'earnType': 'current' + 'action': 'turn_off', + 'earnType': '0' } mock_request.assert_called_once_with(c.POST, c.SET_AUTO_EARN, expected_params) From 2f9fea96d181a09d4c42c30bf01bd44dd9dcd892 Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Wed, 17 Dec 2025 10:33:47 +0800 Subject: [PATCH 23/48] http version compatibility issue --- okx/okxclient.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/okx/okxclient.py b/okx/okxclient.py index e88e3a7..0357a69 100644 --- a/okx/okxclient.py +++ b/okx/okxclient.py @@ -14,7 +14,16 @@ class OkxClient(Client): def __init__(self, api_key='-1', api_secret_key='-1', passphrase='-1', use_server_time=None, flag='1',base_api=c.API_URL, debug=False, proxy=None): - super().__init__(base_url=base_api, http2=True, proxy=proxy) + # 兼容不同版本的 httpx + # 新版本(0.24.0+)使用 proxy,旧版本使用 proxies + try: + super().__init__(base_url=base_api, http2=True, proxy=proxy) + except TypeError: + # 旧版本 httpx 使用 proxies 参数 + if proxy: + super().__init__(base_url=base_api, http2=True, proxies={'http://': proxy, 'https://': proxy}) + else: + super().__init__(base_url=base_api, http2=True) self.API_KEY = api_key self.API_SECRET_KEY = api_secret_key self.PASSPHRASE = passphrase From ed85a8c10a1bca119c00454574a6abbe5b7e3c77 Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Thu, 18 Dec 2025 18:10:38 +0800 Subject: [PATCH 24/48] confirm if it's version compatibility issue --- test/unit/okx/test_okxclient.py | 186 ++++++++++++++++++ test/unit/okx/websocket/__init__.py | 1 + .../okx/websocket/test_ws_private_async.py | 163 +++++++++++++++ .../okx/websocket/test_ws_public_async.py | 121 ++++++++++++ 4 files changed, 471 insertions(+) create mode 100644 test/unit/okx/test_okxclient.py create mode 100644 test/unit/okx/websocket/__init__.py create mode 100644 test/unit/okx/websocket/test_ws_private_async.py create mode 100644 test/unit/okx/websocket/test_ws_public_async.py diff --git a/test/unit/okx/test_okxclient.py b/test/unit/okx/test_okxclient.py new file mode 100644 index 0000000..463507b --- /dev/null +++ b/test/unit/okx/test_okxclient.py @@ -0,0 +1,186 @@ +""" +Unit tests for okx.okxclient module + +Mirrors the structure: okx/okxclient.py -> test/unit/okx/test_okxclient.py +""" +import unittest +import warnings +from unittest.mock import patch, MagicMock + + +class TestOkxClientInit(unittest.TestCase): + """Unit tests for OkxClient initialization""" + + def test_init_with_default_parameters(self): + """Test initialization with default parameters""" + with patch('okx.okxclient.Client.__init__') as mock_init: + mock_init.return_value = None + + from okx.okxclient import OkxClient + client = OkxClient() + + self.assertEqual(client.API_KEY, '-1') + self.assertEqual(client.API_SECRET_KEY, '-1') + self.assertEqual(client.PASSPHRASE, '-1') + self.assertEqual(client.flag, '1') + self.assertFalse(client.debug) + + def test_init_with_custom_parameters(self): + """Test initialization with custom parameters""" + with patch('okx.okxclient.Client.__init__') as mock_init: + mock_init.return_value = None + + from okx.okxclient import OkxClient + client = OkxClient( + api_key='test_key', + api_secret_key='test_secret', + passphrase='test_pass', + flag='0', + debug=True + ) + + self.assertEqual(client.API_KEY, 'test_key') + self.assertEqual(client.API_SECRET_KEY, 'test_secret') + self.assertEqual(client.PASSPHRASE, 'test_pass') + self.assertEqual(client.flag, '0') + self.assertTrue(client.debug) + + def test_init_with_deprecated_use_server_time_shows_warning(self): + """Test that using deprecated use_server_time parameter shows warning""" + with patch('okx.okxclient.Client.__init__') as mock_init: + mock_init.return_value = None + + from okx.okxclient import OkxClient + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + client = OkxClient(use_server_time=True) + + self.assertEqual(len(w), 1) + self.assertTrue(issubclass(w[0].category, DeprecationWarning)) + self.assertIn("use_server_time parameter is deprecated", str(w[0].message)) + + +class TestOkxClientHttpxCompatibility(unittest.TestCase): + """Unit tests for httpx version compatibility in OkxClient""" + + def test_init_with_new_httpx_proxy_parameter(self): + """Test initialization with new httpx version using proxy parameter""" + with patch('okx.okxclient.Client.__init__') as mock_init: + # Simulate new httpx version (accepts proxy parameter) + mock_init.return_value = None + + from okx.okxclient import OkxClient + client = OkxClient(proxy='http://proxy.example.com:8080') + + # Should call super().__init__ with proxy parameter + mock_init.assert_called_once() + call_kwargs = mock_init.call_args + self.assertIn('proxy', call_kwargs.kwargs) + self.assertEqual(call_kwargs.kwargs['proxy'], 'http://proxy.example.com:8080') + + def test_init_with_old_httpx_falls_back_to_proxies(self): + """Test initialization falls back to proxies for old httpx version""" + call_count = [0] + + def mock_init_side_effect(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1 and 'proxy' in kwargs: + # First call with proxy parameter - simulate old httpx raising TypeError + raise TypeError("__init__() got an unexpected keyword argument 'proxy'") + # Second call should work (with proxies or without) + return None + + with patch('okx.okxclient.Client.__init__') as mock_init: + mock_init.side_effect = mock_init_side_effect + + from okx.okxclient import OkxClient + client = OkxClient(proxy='http://proxy.example.com:8080') + + # Should have been called twice + self.assertEqual(mock_init.call_count, 2) + + # Second call should use proxies parameter + second_call = mock_init.call_args_list[1] + self.assertIn('proxies', second_call.kwargs) + expected_proxies = { + 'http://': 'http://proxy.example.com:8080', + 'https://': 'http://proxy.example.com:8080' + } + self.assertEqual(second_call.kwargs['proxies'], expected_proxies) + + def test_init_with_old_httpx_no_proxy(self): + """Test initialization with old httpx version without proxy""" + call_count = [0] + + def mock_init_side_effect(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1 and 'proxy' in kwargs: + # First call with proxy parameter - simulate old httpx raising TypeError + raise TypeError("__init__() got an unexpected keyword argument 'proxy'") + return None + + with patch('okx.okxclient.Client.__init__') as mock_init: + mock_init.side_effect = mock_init_side_effect + + from okx.okxclient import OkxClient + client = OkxClient() # No proxy + + # Should have been called twice + self.assertEqual(mock_init.call_count, 2) + + # Second call should not have proxies parameter + second_call = mock_init.call_args_list[1] + self.assertNotIn('proxies', second_call.kwargs) + + def test_init_without_proxy(self): + """Test initialization without proxy parameter""" + with patch('okx.okxclient.Client.__init__') as mock_init: + mock_init.return_value = None + + from okx.okxclient import OkxClient + client = OkxClient() + + mock_init.assert_called_once() + call_kwargs = mock_init.call_args.kwargs + self.assertEqual(call_kwargs.get('proxy'), None) + + +class TestOkxClientRequest(unittest.TestCase): + """Unit tests for OkxClient request methods""" + + def setUp(self): + """Set up test fixtures""" + with patch('okx.okxclient.Client.__init__') as mock_init: + mock_init.return_value = None + from okx.okxclient import OkxClient + self.client = OkxClient( + api_key='test_key', + api_secret_key='test_secret', + passphrase='test_pass', + flag='0' + ) + + def test_request_without_params(self): + """Test _request_without_params calls _request with empty dict""" + with patch.object(self.client, '_request') as mock_request: + mock_request.return_value = {'code': '0'} + + result = self.client._request_without_params('GET', '/api/v5/test') + + mock_request.assert_called_once_with('GET', '/api/v5/test', {}) + + def test_request_with_params(self): + """Test _request_with_params passes params correctly""" + with patch.object(self.client, '_request') as mock_request: + mock_request.return_value = {'code': '0'} + params = {'instId': 'BTC-USDT'} + + result = self.client._request_with_params('GET', '/api/v5/test', params) + + mock_request.assert_called_once_with('GET', '/api/v5/test', params) + + +if __name__ == '__main__': + unittest.main() + diff --git a/test/unit/okx/websocket/__init__.py b/test/unit/okx/websocket/__init__.py new file mode 100644 index 0000000..98f2807 --- /dev/null +++ b/test/unit/okx/websocket/__init__.py @@ -0,0 +1 @@ +"""Unit tests for okx.websocket package""" diff --git a/test/unit/okx/websocket/test_ws_private_async.py b/test/unit/okx/websocket/test_ws_private_async.py new file mode 100644 index 0000000..3ee2210 --- /dev/null +++ b/test/unit/okx/websocket/test_ws_private_async.py @@ -0,0 +1,163 @@ +""" +Unit tests for okx.websocket.WsPrivateAsync module + +Mirrors the structure: okx/websocket/WsPrivateAsync.py -> test/unit/okx/websocket/test_ws_private_async.py +""" +import json +import unittest +import asyncio +from unittest.mock import patch, MagicMock, AsyncMock + + +class TestWsPrivateAsyncInit(unittest.TestCase): + """Unit tests for WsPrivateAsync initialization""" + + def test_init_with_required_params(self): + """Test initialization with required parameters""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory') as mock_factory: + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com", + useServerTime=False + ) + + self.assertEqual(ws.apiKey, "test_api_key") + self.assertEqual(ws.passphrase, "test_passphrase") + self.assertEqual(ws.secretKey, "test_secret_key") + self.assertEqual(ws.url, "wss://test.example.com") + self.assertFalse(ws.useServerTime) + mock_factory.assert_called_once_with("wss://test.example.com") + + +class TestWsPrivateAsyncSubscribe(unittest.TestCase): + """Unit tests for WsPrivateAsync subscribe method""" + + def test_subscribe_sends_correct_payload(self): + """Test subscribe sends correct payload after login""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'), \ + patch('okx.websocket.WsPrivateAsync.WsUtils.initLoginParams') as mock_init_login, \ + patch('okx.websocket.WsPrivateAsync.asyncio.sleep', new_callable=AsyncMock): + + mock_init_login.return_value = '{"op":"login"}' + + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com", + useServerTime=False + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [{"channel": "account", "ccy": "BTC"}] + + async def run_test(): + await ws.subscribe(params, callback) + self.assertEqual(ws.callback, callback) + # Second call should be the subscribe (first is login) + subscribe_call = mock_websocket.send.call_args_list[1] + payload = json.loads(subscribe_call[0][0]) + self.assertEqual(payload["op"], "subscribe") + self.assertEqual(payload["args"], params) + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPrivateAsyncUnsubscribe(unittest.TestCase): + """Unit tests for WsPrivateAsync unsubscribe method""" + + def test_unsubscribe_sends_correct_payload(self): + """Test unsubscribe sends correct payload""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com", + useServerTime=False + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [{"channel": "account", "ccy": "BTC"}] + + async def run_test(): + await ws.unsubscribe(params, callback) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "unsubscribe") + self.assertEqual(payload["args"], params) + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPrivateAsyncLogin(unittest.TestCase): + """Unit tests for WsPrivateAsync login method""" + + def test_login_calls_init_login_params(self): + """Test login calls WsUtils.initLoginParams with correct parameters""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'), \ + patch('okx.websocket.WsPrivateAsync.WsUtils.initLoginParams') as mock_init_login: + + mock_init_login.return_value = '{"op":"login","args":[...]}' + + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com", + useServerTime=True + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + + async def run_test(): + result = await ws.login() + self.assertTrue(result) + mock_init_login.assert_called_once_with( + useServerTime=True, + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key" + ) + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPrivateAsyncStartStop(unittest.TestCase): + """Unit tests for WsPrivateAsync start and stop methods""" + + def test_stop(self): + """Test stop method closes the factory and stops loop""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory') as mock_factory_class: + mock_factory_instance = MagicMock() + mock_factory_instance.close = AsyncMock() + mock_factory_class.return_value = mock_factory_instance + + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com", + useServerTime=False + ) + ws.loop = MagicMock() + + async def run_test(): + await ws.stop() + mock_factory_instance.close.assert_called_once() + ws.loop.stop.assert_called_once() + + asyncio.get_event_loop().run_until_complete(run_test()) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/unit/okx/websocket/test_ws_public_async.py b/test/unit/okx/websocket/test_ws_public_async.py new file mode 100644 index 0000000..de2d223 --- /dev/null +++ b/test/unit/okx/websocket/test_ws_public_async.py @@ -0,0 +1,121 @@ +""" +Unit tests for okx.websocket.WsPublicAsync module + +Mirrors the structure: okx/websocket/WsPublicAsync.py -> test/unit/okx/websocket/test_ws_public_async.py +""" +import json +import unittest +import asyncio +from unittest.mock import patch, MagicMock, AsyncMock + + +class TestWsPublicAsyncInit(unittest.TestCase): + """Unit tests for WsPublicAsync initialization""" + + def test_init_with_url(self): + """Test initialization with url parameter""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory') as mock_factory: + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + + self.assertEqual(ws.url, "wss://test.example.com") + self.assertIsNone(ws.callback) + self.assertIsNone(ws.websocket) + mock_factory.assert_called_once_with("wss://test.example.com") + + +class TestWsPublicAsyncSubscribe(unittest.TestCase): + """Unit tests for WsPublicAsync subscribe method""" + + def test_subscribe_sets_callback(self): + """Test subscribe sets callback correctly""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [{"channel": "tickers", "instId": "BTC-USDT"}] + + async def run_test(): + await ws.subscribe(params, callback) + self.assertEqual(ws.callback, callback) + mock_websocket.send.assert_called_once() + + # Verify the payload + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "subscribe") + self.assertEqual(payload["args"], params) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_subscribe_with_multiple_channels(self): + """Test subscribe with multiple channels""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [ + {"channel": "tickers", "instId": "BTC-USDT"}, + {"channel": "tickers", "instId": "ETH-USDT"} + ] + + async def run_test(): + await ws.subscribe(params, callback) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(len(payload["args"]), 2) + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPublicAsyncUnsubscribe(unittest.TestCase): + """Unit tests for WsPublicAsync unsubscribe method""" + + def test_unsubscribe_sends_correct_payload(self): + """Test unsubscribe sends correct payload""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [{"channel": "tickers", "instId": "BTC-USDT"}] + + async def run_test(): + await ws.unsubscribe(params, callback) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "unsubscribe") + self.assertEqual(payload["args"], params) + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPublicAsyncStartStop(unittest.TestCase): + """Unit tests for WsPublicAsync start and stop methods""" + + def test_stop(self): + """Test stop method closes the factory and stops loop""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory') as mock_factory_class: + mock_factory_instance = MagicMock() + mock_factory_instance.close = AsyncMock() + mock_factory_class.return_value = mock_factory_instance + + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + ws.loop = MagicMock() + + async def run_test(): + await ws.stop() + mock_factory_instance.close.assert_called_once() + ws.loop.stop.assert_called_once() + + asyncio.get_event_loop().run_until_complete(run_test()) + + +if __name__ == '__main__': + unittest.main() From 0023ce69de33f332d0f0c4ee75bde36bce3572df Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Fri, 19 Dec 2025 16:17:18 +0800 Subject: [PATCH 25/48] confirm if it's version compatibility issue --- okx/okxclient.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/okx/okxclient.py b/okx/okxclient.py index 0357a69..6cd4751 100644 --- a/okx/okxclient.py +++ b/okx/okxclient.py @@ -14,12 +14,12 @@ class OkxClient(Client): def __init__(self, api_key='-1', api_secret_key='-1', passphrase='-1', use_server_time=None, flag='1',base_api=c.API_URL, debug=False, proxy=None): - # 兼容不同版本的 httpx - # 新版本(0.24.0+)使用 proxy,旧版本使用 proxies + # Compatible with different versions of httpx + # New versions (0.24.0+) use proxy, older versions use proxies try: super().__init__(base_url=base_api, http2=True, proxy=proxy) except TypeError: - # 旧版本 httpx 使用 proxies 参数 + # Older versions of httpx use proxies parameter if proxy: super().__init__(base_url=base_api, http2=True, proxies={'http://': proxy, 'https://': proxy}) else: From 500baa685e86123431da90d8f6fd016058d64195 Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Fri, 19 Dec 2025 16:49:38 +0800 Subject: [PATCH 26/48] add id parameter to all websocket subscription --- okx/websocket/WsPrivateAsync.py | 6 +++++- okx/websocket/WsPublicAsync.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/okx/websocket/WsPrivateAsync.py b/okx/websocket/WsPrivateAsync.py index dbf0390..32d8679 100644 --- a/okx/websocket/WsPrivateAsync.py +++ b/okx/websocket/WsPrivateAsync.py @@ -79,4 +79,8 @@ async def start(self): self.loop.create_task(self.consume()) def stop_sync(self): - self.loop.run_until_complete(self.stop()) + if self.loop.is_running(): + future = asyncio.run_coroutine_threadsafe(self.stop(), self.loop) + future.result(timeout=10) + else: + self.loop.run_until_complete(self.stop()) diff --git a/okx/websocket/WsPublicAsync.py b/okx/websocket/WsPublicAsync.py index ef44c5a..b625b2d 100644 --- a/okx/websocket/WsPublicAsync.py +++ b/okx/websocket/WsPublicAsync.py @@ -58,4 +58,8 @@ async def start(self): self.loop.create_task(self.consume()) def stop_sync(self): - self.loop.run_until_complete(self.stop()) + if self.loop.is_running(): + future = asyncio.run_coroutine_threadsafe(self.stop(), self.loop) + future.result(timeout=10) + else: + self.loop.run_until_complete(self.stop()) From 69af8211f9d68f0b57fc13af27991dae5f0d4ad5 Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Fri, 19 Dec 2025 16:52:49 +0800 Subject: [PATCH 27/48] add id parameter to all websocket subscription --- test/WsPrivateAsyncTest.py | 4 ++-- test/WsPublicAsyncTest.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/WsPrivateAsyncTest.py b/test/WsPrivateAsyncTest.py index cf04a36..6130b83 100644 --- a/test/WsPrivateAsyncTest.py +++ b/test/WsPrivateAsyncTest.py @@ -34,14 +34,14 @@ async def main(): await asyncio.sleep(30) print("-----------------------------------------unsubscribe--------------------------------------------") args2 = [arg2] - # 使用 id 参数来标识取消订阅请求 + # Use id parameter to identify unsubscribe request await ws.unsubscribe(args2, callback=privateCallback, id="privateUnsub001") await asyncio.sleep(5) print("-----------------------------------------unsubscribe all--------------------------------------------") args3 = [arg1, arg3] await ws.unsubscribe(args3, callback=privateCallback, id="privateUnsub002") await asyncio.sleep(1) - # 正确关闭 websocket 连接 + # Properly close websocket connection await ws.stop() diff --git a/test/WsPublicAsyncTest.py b/test/WsPublicAsyncTest.py index 24364ee..1ed1172 100644 --- a/test/WsPublicAsyncTest.py +++ b/test/WsPublicAsyncTest.py @@ -22,19 +22,19 @@ async def main(): args.append(arg2) args.append(arg3) args.append(arg4) - # 使用 id 参数来标识订阅请求,响应中会返回相同的 id + # Use id parameter to identify subscribe request, the same id will be returned in response await ws.subscribe(args, publicCallback, id="sub001") await asyncio.sleep(5) print("-----------------------------------------unsubscribe--------------------------------------------") args2 = [arg4] - # 使用 id 参数来标识取消订阅请求 + # Use id parameter to identify unsubscribe request await ws.unsubscribe(args2, publicCallback, id="unsub001") await asyncio.sleep(5) print("-----------------------------------------unsubscribe all--------------------------------------------") args3 = [arg1, arg2, arg3] await ws.unsubscribe(args3, publicCallback) await asyncio.sleep(1) - # 正确关闭 websocket 连接 + # Properly close websocket connection await ws.stop() From e8084271d1b8c7a092a272b47271ca73e2475027 Mon Sep 17 00:00:00 2001 From: "jason.hsu" Date: Fri, 19 Dec 2025 16:56:48 +0800 Subject: [PATCH 28/48] feat: add api changes --- okx/Account.py | 4 +++- okx/Grid.py | 8 ++++++-- okx/Trade.py | 26 +++++++++++++++++++------- 3 files changed, 28 insertions(+), 10 deletions(-) diff --git a/okx/Account.py b/okx/Account.py index e3952ff..d280013 100644 --- a/okx/Account.py +++ b/okx/Account.py @@ -102,8 +102,10 @@ def get_instruments(self, instType='', ugly='', instFamily='', instId=''): return self._request_with_params(GET, GET_INSTRUMENTS, params) # Get the maximum loan of isolated MARGIN - def get_max_loan(self, instId, mgnMode, mgnCcy=''): + def get_max_loan(self, instId, mgnMode, mgnCcy='', tradeQuoteCcy=None): params = {'instId': instId, 'mgnMode': mgnMode, 'mgnCcy': mgnCcy} + if tradeQuoteCcy is not None: + params['tradeQuoteCcy'] = tradeQuoteCcy return self._request_with_params(GET, MAX_LOAN, params) # Get Fee Rates diff --git a/okx/Grid.py b/okx/Grid.py index 74b6a9b..d761708 100644 --- a/okx/Grid.py +++ b/okx/Grid.py @@ -7,11 +7,13 @@ def __init__(self, api_key='-1', api_secret_key='-1', passphrase='-1', use_serve OkxClient.__init__(self, api_key, api_secret_key, passphrase, use_server_time, flag, domain, debug, proxy) def grid_order_algo(self, instId='', algoOrdType='', maxPx='', minPx='', gridNum='', runType='', tpTriggerPx='', - slTriggerPx='', tag='', quoteSz='', baseSz='', sz='', direction='', lever='', basePos=''): + slTriggerPx='', tag='', quoteSz='', baseSz='', sz='', direction='', lever='', basePos='', tradeQuoteCcy=None): params = {'instId': instId, 'algoOrdType': algoOrdType, 'maxPx': maxPx, 'minPx': minPx, 'gridNum': gridNum, 'runType': runType, 'tpTriggerPx': tpTriggerPx, 'slTriggerPx': slTriggerPx, 'tag': tag, 'quoteSz': quoteSz, 'baseSz': baseSz, 'sz': sz, 'direction': direction, 'lever': lever, 'basePos': basePos} + if tradeQuoteCcy is not None: + params['tradeQuoteCcy'] = tradeQuoteCcy return self._request_with_params(POST, GRID_ORDER_ALGO, params) def grid_amend_order_algo(self, algoId='', instId='', slTriggerPx='', tpTriggerPx=''): @@ -79,11 +81,13 @@ def grid_ai_param(self, algoOrdType='', instId='', direction='', duration=''): # - Place recurring buy order def place_recurring_buy_order(self, stgyName='', recurringList=[], period='', recurringDay='', recurringTime='', - timeZone='', amt='', investmentCcy='', tdMode='', algoClOrdId='', tag=''): + timeZone='', amt='', investmentCcy='', tdMode='', algoClOrdId='', tag='', tradeQuoteCcy=None): params = {'stgyName': stgyName, 'recurringList': recurringList, 'period': period, 'recurringDay': recurringDay, 'recurringTime': recurringTime, 'timeZone': timeZone, 'amt': amt, 'investmentCcy': investmentCcy, 'tdMode': tdMode, 'algoClOrdId': algoClOrdId, 'tag': tag} + if tradeQuoteCcy is not None: + params['tradeQuoteCcy'] = tradeQuoteCcy return self._request_with_params(POST, PLACE_RECURRING_BUY_ORDER, params) # - Amend recurring buy order diff --git a/okx/Trade.py b/okx/Trade.py index db994ae..a975dbb 100644 --- a/okx/Trade.py +++ b/okx/Trade.py @@ -12,10 +12,14 @@ def __init__(self, api_key='-1', api_secret_key='-1', passphrase='-1', use_serve # Place Order def place_order(self, instId, tdMode, side, ordType, sz, ccy='', clOrdId='', tag='', posSide='', px='', - reduceOnly='', tgtCcy='', stpMode='', attachAlgoOrds=None, pxUsd='', pxVol='', banAmend='', tradeQuoteCcy=''): + reduceOnly='', tgtCcy='', stpMode='', attachAlgoOrds=None, pxUsd='', pxVol='', banAmend='', tradeQuoteCcy=None, pxAmendType=None): params = {'instId': instId, 'tdMode': tdMode, 'side': side, 'ordType': ordType, 'sz': sz, 'ccy': ccy, 'clOrdId': clOrdId, 'tag': tag, 'posSide': posSide, 'px': px, 'reduceOnly': reduceOnly, - 'tgtCcy': tgtCcy, 'stpMode': stpMode, 'pxUsd': pxUsd, 'pxVol': pxVol, 'banAmend': banAmend, 'tradeQuoteCcy': tradeQuoteCcy} + 'tgtCcy': tgtCcy, 'stpMode': stpMode, 'pxUsd': pxUsd, 'pxVol': pxVol, 'banAmend': banAmend} + if tradeQuoteCcy is not None: + params['tradeQuoteCcy'] = tradeQuoteCcy + if pxAmendType is not None: + params['pxAmendType'] = pxAmendType params['attachAlgoOrds'] = attachAlgoOrds return self._request_with_params(POST, PLACR_ORDER, params) @@ -35,11 +39,13 @@ def cancel_multiple_orders(self, orders_data): # Amend Order def amend_order(self, instId, cxlOnFail='', ordId='', clOrdId='', reqId='', newSz='', newPx='', newTpTriggerPx='', newTpOrdPx='', newSlTriggerPx='', newSlOrdPx='', newTpTriggerPxType='', newSlTriggerPxType='', - attachAlgoOrds='', newTriggerPx='', newOrdPx=''): + attachAlgoOrds='', newTriggerPx='', newOrdPx='', pxAmendType=None): params = {'instId': instId, 'cxlOnFail': cxlOnFail, 'ordId': ordId, 'clOrdId': clOrdId, 'reqId': reqId, 'newSz': newSz, 'newPx': newPx, 'newTpTriggerPx': newTpTriggerPx, 'newTpOrdPx': newTpOrdPx, 'newSlTriggerPx': newSlTriggerPx, 'newSlOrdPx': newSlOrdPx, 'newTpTriggerPxType': newTpTriggerPxType, 'newSlTriggerPxType': newSlTriggerPxType, 'newTriggerPx': newTriggerPx, 'newOrdPx': newOrdPx} + if pxAmendType is not None: + params['pxAmendType'] = pxAmendType params['attachAlgoOrds'] = attachAlgoOrds return self._request_with_params(POST, AMEND_ORDER, params) @@ -95,8 +101,8 @@ def place_algo_order(self, instId='', tdMode='', side='', ordType='', sz='', ccy pxSpread='', szLimit='', pxLimit='', timeInterval='', tpTriggerPxType='', slTriggerPxType='', callbackRatio='', callbackSpread='', activePx='', tag='', triggerPxType='', closeFraction='' - , quickMgnType='', algoClOrdId='', tradeQuoteCcy='', tpOrdKind='', cxlOnClosePos='' - , chaseType='', chaseVal='', maxChaseType='', maxChaseVal='', attachAlgoOrds=[]): + , quickMgnType='', algoClOrdId='', tradeQuoteCcy=None, tpOrdKind='', cxlOnClosePos='' + , chaseType='', chaseVal='', maxChaseType='', maxChaseVal='', attachAlgoOrds=[], pxAmendType=None): params = {'instId': instId, 'tdMode': tdMode, 'side': side, 'ordType': ordType, 'sz': sz, 'ccy': ccy, 'posSide': posSide, 'reduceOnly': reduceOnly, 'tpTriggerPx': tpTriggerPx, 'tpOrdPx': tpOrdPx, 'slTriggerPx': slTriggerPx, 'slOrdPx': slOrdPx, 'triggerPx': triggerPx, 'orderPx': orderPx, @@ -105,9 +111,13 @@ def place_algo_order(self, instId='', tdMode='', side='', ordType='', sz='', ccy 'pxSpread': pxSpread, 'tpTriggerPxType': tpTriggerPxType, 'slTriggerPxType': slTriggerPxType, 'callbackRatio': callbackRatio, 'callbackSpread': callbackSpread, 'activePx': activePx, 'tag': tag, 'triggerPxType': triggerPxType, 'closeFraction': closeFraction, - 'quickMgnType': quickMgnType, 'algoClOrdId': algoClOrdId, 'tradeQuoteCcy': tradeQuoteCcy, + 'quickMgnType': quickMgnType, 'algoClOrdId': algoClOrdId, 'tpOrdKind': tpOrdKind, 'cxlOnClosePos': cxlOnClosePos, 'chaseType': chaseType, 'chaseVal': chaseVal, 'maxChaseType': maxChaseType, 'maxChaseVal': maxChaseVal, 'attachAlgoOrds': attachAlgoOrds} + if tradeQuoteCcy is not None: + params['tradeQuoteCcy'] = tradeQuoteCcy + if pxAmendType is not None: + params['pxAmendType'] = pxAmendType return self._request_with_params(POST, PLACE_ALGO_ORDER, params) # Cancel Algo Order @@ -179,11 +189,13 @@ def get_algo_order_details(self, algoId='', algoClOrdId=''): # Amend algo order def amend_algo_order(self, instId='', algoId='', algoClOrdId='', cxlOnFail='', reqId='', newSz='', newTriggerPx='', newOrdPx='', newTpTriggerPx='', newTpOrdPx='', newSlTriggerPx='', newSlOrdPx='', newTpTriggerPxType='', - newSlTriggerPxType=''): + newSlTriggerPxType='', pxAmendType=None): params = {'instId': instId, 'algoId': algoId, 'algoClOrdId': algoClOrdId, 'cxlOnFail': cxlOnFail, 'reqId': reqId, 'newSz': newSz, 'newTriggerPx': newTriggerPx, 'newOrdPx': newOrdPx, 'newTpTriggerPx': newTpTriggerPx, 'newTpOrdPx': newTpOrdPx, 'newSlTriggerPx': newSlTriggerPx, 'newSlOrdPx': newSlOrdPx, 'newTpTriggerPxType': newTpTriggerPxType, 'newSlTriggerPxType': newSlTriggerPxType} + if pxAmendType is not None: + params['pxAmendType'] = pxAmendType return self._request_with_params(POST, AMEND_ALGO_ORDER, params) def get_oneclick_repay_list_v2(self): From 61562c4bfff8fdaf83b89b45a23ac5051e2e7e5c Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Fri, 19 Dec 2025 16:58:32 +0800 Subject: [PATCH 29/48] websocket enhancement --- test/WsPrivateAsyncTest.py | 46 +++++++++++++++++++------------------- test/WsPublicAsyncTest.py | 14 ++++++------ 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/test/WsPrivateAsyncTest.py b/test/WsPrivateAsyncTest.py index f478984..eb7e852 100644 --- a/test/WsPrivateAsyncTest.py +++ b/test/WsPrivateAsyncTest.py @@ -8,7 +8,7 @@ def privateCallback(message): async def main(): - """订阅测试""" + """Subscription test""" url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" ws = WsPrivateAsync( apiKey="your apiKey", @@ -40,8 +40,8 @@ async def main(): async def test_place_order(): """ - 测试下单功能 - URL: /ws/v5/private (限速: 60次/秒) + Test place order functionality + URL: /ws/v5/private (Rate limit: 60 requests/second) """ url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" ws = WsPrivateAsync( @@ -55,7 +55,7 @@ async def test_place_order(): await ws.login() await asyncio.sleep(5) - # 下单参数 + # Order parameters order_args = [{ "instId": "BTC-USDT", "tdMode": "cash", @@ -72,8 +72,8 @@ async def test_place_order(): async def test_batch_orders(): """ - 测试批量下单功能 - URL: /ws/v5/private (限速: 60次/秒, 最多20个订单) + Test batch orders functionality + URL: /ws/v5/private (Rate limit: 60 requests/second, max 20 orders) """ url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" ws = WsPrivateAsync( @@ -87,7 +87,7 @@ async def test_batch_orders(): await ws.login() await asyncio.sleep(5) - # 批量下单参数 (最多20个) + # Batch order parameters (max 20) order_args = [ { "instId": "BTC-USDT", @@ -115,8 +115,8 @@ async def test_batch_orders(): async def test_cancel_order(): """ - 测试撤单功能 - URL: /ws/v5/private (限速: 60次/秒) + Test cancel order functionality + URL: /ws/v5/private (Rate limit: 60 requests/second) """ url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" ws = WsPrivateAsync( @@ -130,11 +130,11 @@ async def test_cancel_order(): await ws.login() await asyncio.sleep(5) - # 撤单参数 (ordId 和 clOrdId 必须传一个) + # Cancel order parameters (either ordId or clOrdId must be provided) cancel_args = [{ "instId": "BTC-USDT", "ordId": "your_order_id" - # 或者使用 "clOrdId": "client_order_001" + # Or use "clOrdId": "client_order_001" }] await ws.cancel_order(cancel_args, callback=privateCallback, id="cancel001") await asyncio.sleep(5) @@ -143,8 +143,8 @@ async def test_cancel_order(): async def test_batch_cancel_orders(): """ - 测试批量撤单功能 - URL: /ws/v5/private (限速: 60次/秒, 最多20个订单) + Test batch cancel orders functionality + URL: /ws/v5/private (Rate limit: 60 requests/second, max 20 orders) """ url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" ws = WsPrivateAsync( @@ -169,8 +169,8 @@ async def test_batch_cancel_orders(): async def test_amend_order(): """ - 测试改单功能 - URL: /ws/v5/private (限速: 60次/秒) + Test amend order functionality + URL: /ws/v5/private (Rate limit: 60 requests/second) """ url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" ws = WsPrivateAsync( @@ -184,7 +184,7 @@ async def test_amend_order(): await ws.login() await asyncio.sleep(5) - # 改单参数 + # Amend order parameters amend_args = [{ "instId": "BTC-USDT", "ordId": "your_order_id", @@ -198,9 +198,9 @@ async def test_amend_order(): async def test_mass_cancel(): """ - 测试批量撤销功能 - URL: /ws/v5/business (限速: 1次/秒) - 注意: 此功能使用 business 频道 + Test mass cancel functionality + URL: /ws/v5/business (Rate limit: 1 request/second) + Note: This function uses the business channel """ url = "wss://wspap.okx.com:8443/ws/v5/business?brokerId=9999" ws = WsPrivateAsync( @@ -214,7 +214,7 @@ async def test_mass_cancel(): await ws.login() await asyncio.sleep(5) - # 批量撤销参数 + # Mass cancel parameters mass_cancel_args = [{ "instType": "SPOT", "instFamily": "BTC-USDT" @@ -225,7 +225,7 @@ async def test_mass_cancel(): async def test_send_method(): - """测试通用send方法""" + """Test generic send method""" url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" ws = WsPrivateAsync( apiKey="your apiKey", @@ -238,7 +238,7 @@ async def test_send_method(): await ws.login() await asyncio.sleep(5) - # 使用通用send方法下单 - 注意要传入callback才能收到响应 + # Use generic send method to place order - callback must be provided to receive response order_args = [{ "instId": "BTC-USDT", "tdMode": "cash", @@ -259,5 +259,5 @@ async def test_send_method(): asyncio.run(test_cancel_order()) asyncio.run(test_batch_cancel_orders()) asyncio.run(test_amend_order()) - asyncio.run(test_mass_cancel()) # 注意使用 business 频道 + asyncio.run(test_mass_cancel()) # Note: uses business channel asyncio.run(test_send_method()) diff --git a/test/WsPublicAsyncTest.py b/test/WsPublicAsyncTest.py index 8fda306..c3f7be8 100644 --- a/test/WsPublicAsyncTest.py +++ b/test/WsPublicAsyncTest.py @@ -10,7 +10,7 @@ def publicCallback(message): async def main(): # url = "wss://wspap.okex.com:8443/ws/v5/public?brokerId=9999" url = "wss://wspap.okx.com:8443/ws/v5/public?brokerId=9999" - ws = WsPublicAsync(url=url, debug=True) # 开启debug日志 + ws = WsPublicAsync(url=url, debug=True) # Enable debug logging await ws.start() args = [] arg1 = {"channel": "instruments", "instType": "FUTURES"} @@ -36,8 +36,8 @@ async def main(): async def test_business_channel_with_login(): """ - 测试 business 频道的登录功能 - business 频道需要登录后才能订阅某些私有数据 + Test business channel login functionality + Business channel requires login to subscribe to certain private data """ url = "wss://wspap.okx.com:8443/ws/v5/business?brokerId=9999" ws = WsPublicAsync( @@ -49,11 +49,11 @@ async def test_business_channel_with_login(): ) await ws.start() - # 登录 + # Login await ws.login() await asyncio.sleep(5) - # 订阅需要登录的频道 + # Subscribe to channels that require login args = [{"channel": "candle1m", "instId": "BTC-USDT"}] await ws.subscribe(args, publicCallback) await asyncio.sleep(30) @@ -61,12 +61,12 @@ async def test_business_channel_with_login(): async def test_send_method(): - """测试通用send方法""" + """Test generic send method""" url = "wss://wspap.okx.com:8443/ws/v5/public?brokerId=9999" ws = WsPublicAsync(url=url, debug=True) await ws.start() - # 使用通用send方法订阅 - 注意要传入callback才能收到响应 + # Use generic send method to subscribe - callback must be provided to receive response args = [{"channel": "tickers", "instId": "BTC-USDT"}] await ws.send("subscribe", args, callback=publicCallback, id="send001") await asyncio.sleep(10) From 11efc370da2f05b063082a27f39011f670c9e509 Mon Sep 17 00:00:00 2001 From: "jason.hsu" Date: Fri, 19 Dec 2025 17:04:30 +0800 Subject: [PATCH 30/48] feat: add ci cd --- .gitlab-ci.yml | 119 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 .gitlab-ci.yml diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml new file mode 100644 index 0000000..c4f99fa --- /dev/null +++ b/.gitlab-ci.yml @@ -0,0 +1,119 @@ +# GitLab CI/CD Configuration for python-okx +# Documentation: https://docs.gitlab.com/ee/ci/ + +# Define stages in order of execution +stages: + - lint + - test + - build + +# Global variables +variables: + PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip" + +# Cache pip downloads between jobs +cache: + paths: + - .cache/pip/ + - venv/ + +# ============================================ +# LINT STAGE +# ============================================ + +lint: + stage: lint + image: python:3.11-slim + before_script: + - pip install --upgrade pip + - pip install flake8 black isort + script: + - echo "Running flake8 linting..." + - flake8 okx/ --max-line-length=120 --ignore=E501,W503,E203 --count --show-source --statistics + # Optional: Check code formatting (uncomment if you want to enforce) + # - black --check okx/ + # - isort --check-only okx/ + allow_failure: true # Set to false once codebase is cleaned up + rules: + - if: $CI_PIPELINE_SOURCE == "merge_request_event" + - if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH + +# ============================================ +# TEST STAGE +# ============================================ + +# Template for test jobs +.test_template: &test_template + stage: test + before_script: + - pip install --upgrade pip + - pip install -e . + - pip install pytest pytest-cov pytest-asyncio + script: + - echo "Running unit tests..." + - python -m pytest test/unit/ -v --cov=okx --cov-report=term-missing --cov-report=xml:coverage.xml + coverage: '/(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/' + artifacts: + reports: + coverage_report: + coverage_format: cobertura + path: coverage.xml + paths: + - coverage.xml + expire_in: 1 week + +# Test with Python 3.9 +test:python3.9: + <<: *test_template + image: python:3.9-slim + rules: + - if: $CI_PIPELINE_SOURCE == "merge_request_event" + - if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH + +# Test with Python 3.10 +test:python3.10: + <<: *test_template + image: python:3.10-slim + rules: + - if: $CI_PIPELINE_SOURCE == "merge_request_event" + - if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH + +# Test with Python 3.11 +test:python3.11: + <<: *test_template + image: python:3.11-slim + rules: + - if: $CI_PIPELINE_SOURCE == "merge_request_event" + - if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH + +# Test with Python 3.12 +test:python3.12: + <<: *test_template + image: python:3.12-slim + rules: + - if: $CI_PIPELINE_SOURCE == "merge_request_event" + - if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH + +# ============================================ +# BUILD STAGE +# ============================================ + +build: + stage: build + image: python:3.11-slim + before_script: + - pip install --upgrade pip + - pip install build twine + script: + - echo "Building package..." + - python -m build + - echo "Checking package with twine..." + - twine check dist/* + artifacts: + paths: + - dist/ + expire_in: 1 week + rules: + - if: $CI_PIPELINE_SOURCE == "merge_request_event" + - if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH + - if: $CI_COMMIT_TAG From b6c3b839fc5d4d51abd8bda3a40e518910e12015 Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Fri, 19 Dec 2025 17:04:59 +0800 Subject: [PATCH 31/48] websocket enhancement --- okx/websocket/WsPublicAsync.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/okx/websocket/WsPublicAsync.py b/okx/websocket/WsPublicAsync.py index d6091f6..1a7e28d 100644 --- a/okx/websocket/WsPublicAsync.py +++ b/okx/websocket/WsPublicAsync.py @@ -37,7 +37,6 @@ async def consume(self): if self.callback: self.callback(message) - async def subscribe(self, params: list, callback, id: str = None): async def login(self): """ 登录方法,用于需要登录的 business 频道(如 /ws/v5/business) @@ -66,10 +65,6 @@ async def subscribe(self, params: list, callback, id: str = None): if id is not None: payload_dict["id"] = id payload = json.dumps(payload_dict) - } - if id is not None: - payload_dict["id"] = id - payload = json.dumps(payload_dict) if self.debug: logger.debug(f"subscribe: {payload}") await self.websocket.send(payload) @@ -84,11 +79,6 @@ async def unsubscribe(self, params: list, callback, id: str = None): if id is not None: payload_dict["id"] = id payload = json.dumps(payload_dict) - logger.info(f"unsubscribe: {payload}") - } - if id is not None: - payload_dict["id"] = id - payload = json.dumps(payload_dict) if self.debug: logger.debug(f"unsubscribe: {payload}") else: From 81c269cfd50ae6fcee6d8f1c37167b2c61484161 Mon Sep 17 00:00:00 2001 From: "jason.hsu" Date: Fri, 19 Dec 2025 17:19:14 +0800 Subject: [PATCH 32/48] feat: add ci cd and remove wrong commit for gitlab --- .github/workflows/ci.yml | 113 +++++++++++++++++++++++++++++++++++++ .gitlab-ci.yml | 119 --------------------------------------- 2 files changed, 113 insertions(+), 119 deletions(-) create mode 100644 .github/workflows/ci.yml delete mode 100644 .gitlab-ci.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..117d8be --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,113 @@ +# GitHub Actions CI/CD Configuration for python-okx +name: CI + +on: + push: + branches: [main, master] + pull_request: + branches: [main, master] + +jobs: + # ============================================ + # LINT JOB + # ============================================ + lint: + name: Lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install linting tools + run: | + python -m pip install --upgrade pip + pip install flake8 + + - name: Run flake8 + run: | + flake8 okx/ --max-line-length=120 --ignore=E501,W503,E203 --count --show-source --statistics + continue-on-error: true # Set to false once codebase is cleaned up + + # ============================================ + # TEST JOB + # ============================================ + test: + name: Test (Python ${{ matrix.python-version }}) + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Cache pip dependencies + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('setup.py') }} + restore-keys: | + ${{ runner.os }}-pip-${{ matrix.python-version }}- + ${{ runner.os }}-pip- + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e . + pip install pytest pytest-cov pytest-asyncio + + - name: Run tests + run: | + python -m pytest test/unit/ -v --cov=okx --cov-report=term-missing --cov-report=xml + + - name: Upload coverage to Codecov + if: matrix.python-version == '3.11' + uses: codecov/codecov-action@v4 + with: + file: ./coverage.xml + fail_ci_if_error: false + + # ============================================ + # BUILD JOB + # ============================================ + build: + name: Build Package + runs-on: ubuntu-latest + needs: [test] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install build tools + run: | + python -m pip install --upgrade pip + pip install build twine + + - name: Build package + run: python -m build + + - name: Check package + run: twine check dist/* + + - name: Upload build artifacts + uses: actions/upload-artifact@v4 + with: + name: dist + path: dist/ + retention-days: 7 + diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml deleted file mode 100644 index c4f99fa..0000000 --- a/.gitlab-ci.yml +++ /dev/null @@ -1,119 +0,0 @@ -# GitLab CI/CD Configuration for python-okx -# Documentation: https://docs.gitlab.com/ee/ci/ - -# Define stages in order of execution -stages: - - lint - - test - - build - -# Global variables -variables: - PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip" - -# Cache pip downloads between jobs -cache: - paths: - - .cache/pip/ - - venv/ - -# ============================================ -# LINT STAGE -# ============================================ - -lint: - stage: lint - image: python:3.11-slim - before_script: - - pip install --upgrade pip - - pip install flake8 black isort - script: - - echo "Running flake8 linting..." - - flake8 okx/ --max-line-length=120 --ignore=E501,W503,E203 --count --show-source --statistics - # Optional: Check code formatting (uncomment if you want to enforce) - # - black --check okx/ - # - isort --check-only okx/ - allow_failure: true # Set to false once codebase is cleaned up - rules: - - if: $CI_PIPELINE_SOURCE == "merge_request_event" - - if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH - -# ============================================ -# TEST STAGE -# ============================================ - -# Template for test jobs -.test_template: &test_template - stage: test - before_script: - - pip install --upgrade pip - - pip install -e . - - pip install pytest pytest-cov pytest-asyncio - script: - - echo "Running unit tests..." - - python -m pytest test/unit/ -v --cov=okx --cov-report=term-missing --cov-report=xml:coverage.xml - coverage: '/(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/' - artifacts: - reports: - coverage_report: - coverage_format: cobertura - path: coverage.xml - paths: - - coverage.xml - expire_in: 1 week - -# Test with Python 3.9 -test:python3.9: - <<: *test_template - image: python:3.9-slim - rules: - - if: $CI_PIPELINE_SOURCE == "merge_request_event" - - if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH - -# Test with Python 3.10 -test:python3.10: - <<: *test_template - image: python:3.10-slim - rules: - - if: $CI_PIPELINE_SOURCE == "merge_request_event" - - if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH - -# Test with Python 3.11 -test:python3.11: - <<: *test_template - image: python:3.11-slim - rules: - - if: $CI_PIPELINE_SOURCE == "merge_request_event" - - if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH - -# Test with Python 3.12 -test:python3.12: - <<: *test_template - image: python:3.12-slim - rules: - - if: $CI_PIPELINE_SOURCE == "merge_request_event" - - if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH - -# ============================================ -# BUILD STAGE -# ============================================ - -build: - stage: build - image: python:3.11-slim - before_script: - - pip install --upgrade pip - - pip install build twine - script: - - echo "Building package..." - - python -m build - - echo "Checking package with twine..." - - twine check dist/* - artifacts: - paths: - - dist/ - expire_in: 1 week - rules: - - if: $CI_PIPELINE_SOURCE == "merge_request_event" - - if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH - - if: $CI_COMMIT_TAG From 89e8e25629a31761f79fbcf4ee92b7865f9d87dc Mon Sep 17 00:00:00 2001 From: "jason.hsu" Date: Fri, 19 Dec 2025 17:27:29 +0800 Subject: [PATCH 33/48] feat: add cicd --- .github/workflows/ci.yml | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 117d8be..18ff677 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -3,9 +3,15 @@ name: CI on: push: - branches: [main, master] + branches: + - master + - 'release/*' + - 'releases/*' pull_request: - branches: [main, master] + branches: + - master + - 'release/*' + - 'releases/*' jobs: # ============================================ From 2d47171c08ca803740663c74d632b6281b7c0b81 Mon Sep 17 00:00:00 2001 From: "jason.hsu" Date: Fri, 19 Dec 2025 17:30:36 +0800 Subject: [PATCH 34/48] fix: missing import --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 18ff677..2e0717a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -70,7 +70,7 @@ jobs: run: | python -m pip install --upgrade pip pip install -e . - pip install pytest pytest-cov pytest-asyncio + pip install pytest pytest-cov pytest-asyncio websockets certifi - name: Run tests run: | From ee04e6b1dcecd9ef5433d04d77421a0adbbf3244 Mon Sep 17 00:00:00 2001 From: "jason.hsu" Date: Fri, 19 Dec 2025 17:41:25 +0800 Subject: [PATCH 35/48] fix: broken testing --- setup.py | 4 +- .../okx/websocket/test_ws_private_async.py | 47 +++++++++---------- .../okx/websocket/test_ws_public_async.py | 27 +++++------ 3 files changed, 35 insertions(+), 43 deletions(-) diff --git a/setup.py b/setup.py index b91e237..ce2838d 100644 --- a/setup.py +++ b/setup.py @@ -25,6 +25,8 @@ "loguru", "requests", "Twisted", - "pyOpenSSL" + "pyOpenSSL", + "websockets", + "certifi" ] ) \ No newline at end of file diff --git a/test/unit/okx/websocket/test_ws_private_async.py b/test/unit/okx/websocket/test_ws_private_async.py index 4e43114..da367f8 100644 --- a/test/unit/okx/websocket/test_ws_private_async.py +++ b/test/unit/okx/websocket/test_ws_private_async.py @@ -8,14 +8,17 @@ import asyncio from unittest.mock import patch, MagicMock, AsyncMock +# Import the module first so patch can resolve the path +import okx.websocket.WsPrivateAsync as ws_private_module +from okx.websocket.WsPrivateAsync import WsPrivateAsync + class TestWsPrivateAsyncInit(unittest.TestCase): """Unit tests for WsPrivateAsync initialization""" def test_init_with_required_params(self): """Test initialization with required parameters""" - with patch('okx.websocket.WsPrivateAsync.WebSocketFactory') as mock_factory: - from okx.websocket.WsPrivateAsync import WsPrivateAsync + with patch.object(ws_private_module, 'WebSocketFactory') as mock_factory: ws = WsPrivateAsync( apiKey="test_api_key", passphrase="test_passphrase", @@ -37,13 +40,12 @@ class TestWsPrivateAsyncSubscribe(unittest.TestCase): def test_subscribe_sends_correct_payload(self): """Test subscribe sends correct payload after login""" - with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'), \ - patch('okx.websocket.WsPrivateAsync.WsUtils.initLoginParams') as mock_init_login, \ - patch('okx.websocket.WsPrivateAsync.asyncio.sleep', new_callable=AsyncMock): + with patch.object(ws_private_module, 'WebSocketFactory'), \ + patch.object(ws_private_module, 'WsUtils') as mock_ws_utils, \ + patch.object(ws_private_module.asyncio, 'sleep', new_callable=AsyncMock): - mock_init_login.return_value = '{"op":"login"}' + mock_ws_utils.initLoginParams.return_value = '{"op":"login"}' - from okx.websocket.WsPrivateAsync import WsPrivateAsync ws = WsPrivateAsync( apiKey="test_api_key", passphrase="test_passphrase", @@ -70,13 +72,12 @@ async def run_test(): def test_subscribe_with_id(self): """Test subscribe with id parameter""" - with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'), \ - patch('okx.websocket.WsPrivateAsync.WsUtils.initLoginParams') as mock_init_login, \ - patch('okx.websocket.WsPrivateAsync.asyncio.sleep', new_callable=AsyncMock): + with patch.object(ws_private_module, 'WebSocketFactory'), \ + patch.object(ws_private_module, 'WsUtils') as mock_ws_utils, \ + patch.object(ws_private_module.asyncio, 'sleep', new_callable=AsyncMock): - mock_init_login.return_value = '{"op":"login"}' + mock_ws_utils.initLoginParams.return_value = '{"op":"login"}' - from okx.websocket.WsPrivateAsync import WsPrivateAsync ws = WsPrivateAsync( apiKey="test_api_key", passphrase="test_passphrase", @@ -105,8 +106,7 @@ class TestWsPrivateAsyncUnsubscribe(unittest.TestCase): def test_unsubscribe_sends_correct_payload(self): """Test unsubscribe sends correct payload""" - with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): - from okx.websocket.WsPrivateAsync import WsPrivateAsync + with patch.object(ws_private_module, 'WebSocketFactory'): ws = WsPrivateAsync( apiKey="test_api_key", passphrase="test_passphrase", @@ -131,8 +131,7 @@ async def run_test(): def test_unsubscribe_with_id(self): """Test unsubscribe with id parameter""" - with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): - from okx.websocket.WsPrivateAsync import WsPrivateAsync + with patch.object(ws_private_module, 'WebSocketFactory'): ws = WsPrivateAsync( apiKey="test_api_key", passphrase="test_passphrase", @@ -160,12 +159,11 @@ class TestWsPrivateAsyncLogin(unittest.TestCase): def test_login_calls_init_login_params(self): """Test login calls WsUtils.initLoginParams with correct parameters""" - with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'), \ - patch('okx.websocket.WsPrivateAsync.WsUtils.initLoginParams') as mock_init_login: + with patch.object(ws_private_module, 'WebSocketFactory'), \ + patch.object(ws_private_module, 'WsUtils') as mock_ws_utils: - mock_init_login.return_value = '{"op":"login","args":[...]}' + mock_ws_utils.initLoginParams.return_value = '{"op":"login","args":[...]}' - from okx.websocket.WsPrivateAsync import WsPrivateAsync ws = WsPrivateAsync( apiKey="test_api_key", passphrase="test_passphrase", @@ -179,7 +177,7 @@ def test_login_calls_init_login_params(self): async def run_test(): result = await ws.login() self.assertTrue(result) - mock_init_login.assert_called_once_with( + mock_ws_utils.initLoginParams.assert_called_once_with( useServerTime=True, apiKey="test_api_key", passphrase="test_passphrase", @@ -193,13 +191,12 @@ class TestWsPrivateAsyncStartStop(unittest.TestCase): """Unit tests for WsPrivateAsync start and stop methods""" def test_stop(self): - """Test stop method closes the factory and stops loop""" - with patch('okx.websocket.WsPrivateAsync.WebSocketFactory') as mock_factory_class: + """Test stop method closes the factory""" + with patch.object(ws_private_module, 'WebSocketFactory') as mock_factory_class: mock_factory_instance = MagicMock() mock_factory_instance.close = AsyncMock() mock_factory_class.return_value = mock_factory_instance - from okx.websocket.WsPrivateAsync import WsPrivateAsync ws = WsPrivateAsync( apiKey="test_api_key", passphrase="test_passphrase", @@ -207,12 +204,10 @@ def test_stop(self): url="wss://test.example.com", useServerTime=False ) - ws.loop = MagicMock() async def run_test(): await ws.stop() mock_factory_instance.close.assert_called_once() - ws.loop.stop.assert_called_once() asyncio.get_event_loop().run_until_complete(run_test()) diff --git a/test/unit/okx/websocket/test_ws_public_async.py b/test/unit/okx/websocket/test_ws_public_async.py index 6443ac4..916a0b9 100644 --- a/test/unit/okx/websocket/test_ws_public_async.py +++ b/test/unit/okx/websocket/test_ws_public_async.py @@ -8,14 +8,17 @@ import asyncio from unittest.mock import patch, MagicMock, AsyncMock +# Import the module first so patch can resolve the path +import okx.websocket.WsPublicAsync as ws_public_module +from okx.websocket.WsPublicAsync import WsPublicAsync + class TestWsPublicAsyncInit(unittest.TestCase): """Unit tests for WsPublicAsync initialization""" def test_init_with_url(self): """Test initialization with url parameter""" - with patch('okx.websocket.WsPublicAsync.WebSocketFactory') as mock_factory: - from okx.websocket.WsPublicAsync import WsPublicAsync + with patch.object(ws_public_module, 'WebSocketFactory') as mock_factory: ws = WsPublicAsync(url="wss://test.example.com") self.assertEqual(ws.url, "wss://test.example.com") @@ -29,8 +32,7 @@ class TestWsPublicAsyncSubscribe(unittest.TestCase): def test_subscribe_sets_callback(self): """Test subscribe sets callback correctly""" - with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): - from okx.websocket.WsPublicAsync import WsPublicAsync + with patch.object(ws_public_module, 'WebSocketFactory'): ws = WsPublicAsync(url="wss://test.example.com") mock_websocket = AsyncMock() ws.websocket = mock_websocket @@ -53,8 +55,7 @@ async def run_test(): def test_subscribe_with_id(self): """Test subscribe with id parameter""" - with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): - from okx.websocket.WsPublicAsync import WsPublicAsync + with patch.object(ws_public_module, 'WebSocketFactory'): ws = WsPublicAsync(url="wss://test.example.com") mock_websocket = AsyncMock() ws.websocket = mock_websocket @@ -75,8 +76,7 @@ async def run_test(): def test_subscribe_with_multiple_channels(self): """Test subscribe with multiple channels""" - with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): - from okx.websocket.WsPublicAsync import WsPublicAsync + with patch.object(ws_public_module, 'WebSocketFactory'): ws = WsPublicAsync(url="wss://test.example.com") mock_websocket = AsyncMock() ws.websocket = mock_websocket @@ -101,8 +101,7 @@ class TestWsPublicAsyncUnsubscribe(unittest.TestCase): def test_unsubscribe_without_id(self): """Test unsubscribe without id parameter""" - with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): - from okx.websocket.WsPublicAsync import WsPublicAsync + with patch.object(ws_public_module, 'WebSocketFactory'): ws = WsPublicAsync(url="wss://test.example.com") mock_websocket = AsyncMock() ws.websocket = mock_websocket @@ -121,8 +120,7 @@ async def run_test(): def test_unsubscribe_with_id(self): """Test unsubscribe with id parameter""" - with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): - from okx.websocket.WsPublicAsync import WsPublicAsync + with patch.object(ws_public_module, 'WebSocketFactory'): ws = WsPublicAsync(url="wss://test.example.com") mock_websocket = AsyncMock() ws.websocket = mock_websocket @@ -144,19 +142,16 @@ class TestWsPublicAsyncStartStop(unittest.TestCase): def test_stop(self): """Test stop method closes the factory""" - with patch('okx.websocket.WsPublicAsync.WebSocketFactory') as mock_factory_class: + with patch.object(ws_public_module, 'WebSocketFactory') as mock_factory_class: mock_factory_instance = MagicMock() mock_factory_instance.close = AsyncMock() mock_factory_class.return_value = mock_factory_instance - from okx.websocket.WsPublicAsync import WsPublicAsync ws = WsPublicAsync(url="wss://test.example.com") - ws.loop = MagicMock() async def run_test(): await ws.stop() mock_factory_instance.close.assert_called_once() - ws.loop.stop.assert_called_once() asyncio.get_event_loop().run_until_complete(run_test()) From 75a20fc4eb115843420c116b708a8c56c4cdde61 Mon Sep 17 00:00:00 2001 From: "jason.hsu" Date: Fri, 19 Dec 2025 19:05:39 +0800 Subject: [PATCH 36/48] feat: add .env --- .env.example | 10 ++++++ .gitignore | 7 +++- test/config.py | 35 +++++++++++++++++++ test/{AccountTest.py => test_account.py} | 7 ++-- ...ckTradingTest.py => test_block_trading.py} | 9 ++--- test/{ConvertTest.py => test_convert.py} | 11 +++--- ...opyTradingTest.py => test_copy_trading.py} | 8 ++--- ...{EthStakingTest.py => test_eth_staking.py} | 8 ++--- ...xibleLoanTest.py => test_flexible_loan.py} | 8 ++--- test/{FundingTest.py => test_funding.py} | 8 ++--- test/{GridTest.py => test_grid.py} | 8 ++--- test/{MarketTest.py => test_market.py} | 9 ++--- ...{PublicDataTest.py => test_public_data.py} | 9 ++--- test/{SavingsTest.py => test_savings.py} | 8 ++--- ...{SolStakingTest.py => test_sol_staking.py} | 8 ++--- test/{SpreadTest.py => test_spread.py} | 9 ++--- ...takingDefiTest.py => test_staking_defi.py} | 7 ++-- ...{SubAccountTest.py => test_sub_account.py} | 8 ++--- test/{TradeTest.py => test_trade.py} | 7 ++-- ...radingDataTest.py => test_trading_data.py} | 9 +++-- ...eAsyncTest.py => test_ws_private_async.py} | 8 +++-- ...icAsyncTest.py => test_ws_public_async.py} | 0 22 files changed, 127 insertions(+), 74 deletions(-) create mode 100644 .env.example create mode 100644 test/config.py rename test/{AccountTest.py => test_account.py} (98%) rename test/{BlockTradingTest.py => test_block_trading.py} (93%) rename test/{ConvertTest.py => test_convert.py} (83%) rename test/{CopyTradingTest.py => test_copy_trading.py} (88%) rename test/{EthStakingTest.py => test_eth_staking.py} (82%) rename test/{FlexibleLoanTest.py => test_flexible_loan.py} (84%) rename test/{FundingTest.py => test_funding.py} (95%) rename test/{GridTest.py => test_grid.py} (95%) rename test/{MarketTest.py => test_market.py} (93%) rename test/{PublicDataTest.py => test_public_data.py} (95%) rename test/{SavingsTest.py => test_savings.py} (84%) rename test/{SolStakingTest.py => test_sol_staking.py} (82%) rename test/{SpreadTest.py => test_spread.py} (91%) rename test/{StakingDefiTest.py => test_staking_defi.py} (82%) rename test/{SubAccountTest.py => test_sub_account.py} (92%) rename test/{TradeTest.py => test_trade.py} (98%) rename test/{TradingDataTest.py => test_trading_data.py} (86%) rename test/{WsPrivateAsyncTest.py => test_ws_private_async.py} (88%) rename test/{WsPublicAsyncTest.py => test_ws_public_async.py} (100%) diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..0f777f4 --- /dev/null +++ b/.env.example @@ -0,0 +1,10 @@ +# OKX API Credentials +# Copy this file to .env and fill in your actual credentials +# NEVER commit .env to version control! + +OKX_API_KEY=your_api_key_here +OKX_API_SECRET=your_api_secret_here +OKX_PASSPHRASE=your_passphrase_here + +# Optional: Set to '0' for live trading, '1' for demo trading +OKX_FLAG=1 diff --git a/.gitignore b/.gitignore index 10c5a42..41c5c7c 100644 --- a/.gitignore +++ b/.gitignore @@ -32,4 +32,9 @@ build/ ### VS Code ### .vscode/ -id_rsa* \ No newline at end of file +id_rsa* + +# Environment files +.env +.env.local +.env.*.local \ No newline at end of file diff --git a/test/config.py b/test/config.py new file mode 100644 index 0000000..090e7e8 --- /dev/null +++ b/test/config.py @@ -0,0 +1,35 @@ +""" +Test configuration module - loads API credentials from environment variables. + +Usage: + from test.config import get_api_credentials + + api_key, api_secret, passphrase, flag = get_api_credentials() +""" +import os +from pathlib import Path + +# Try to load from .env file if python-dotenv is available +try: + from dotenv import load_dotenv + # Load .env from project root + env_path = Path(__file__).parent.parent / '.env' + load_dotenv(env_path) +except ImportError: + pass # python-dotenv not installed, rely on system environment variables + + +def get_api_credentials(): + """ + Get API credentials from environment variables. + + Returns: + tuple: (api_key, api_secret, passphrase, flag) + """ + api_key = os.getenv('OKX_API_KEY', '') + api_secret = os.getenv('OKX_API_SECRET', '') + passphrase = os.getenv('OKX_PASSPHRASE', '') + flag = os.getenv('OKX_FLAG', '1') # Default to demo trading + + return api_key, api_secret, passphrase, flag + diff --git a/test/AccountTest.py b/test/test_account.py similarity index 98% rename from test/AccountTest.py rename to test/test_account.py index 14793d0..50b5493 100644 --- a/test/AccountTest.py +++ b/test/test_account.py @@ -3,14 +3,13 @@ from loguru import logger from okx import Account +from test.config import get_api_credentials class AccountTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' - self.AccountAPI = Account.AccountAPI(api_key, api_secret_key, passphrase, flag='1') + api_key, api_secret_key, passphrase, flag = get_api_credentials() + self.AccountAPI = Account.AccountAPI(api_key, api_secret_key, passphrase, flag=flag) # ''' # POSITIONS_HISTORY = '/api/v5/account/positions-history' #need add diff --git a/test/BlockTradingTest.py b/test/test_block_trading.py similarity index 93% rename from test/BlockTradingTest.py rename to test/test_block_trading.py index e8b0e88..1c02542 100644 --- a/test/BlockTradingTest.py +++ b/test/test_block_trading.py @@ -1,12 +1,13 @@ import unittest from okx import BlockTrading +from test.config import get_api_credentials + + class BlockTradingTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' - self.BlockTradingAPI = BlockTrading.BlockTradingAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag='1') + api_key, api_secret_key, passphrase, flag = get_api_credentials() + self.BlockTradingAPI = BlockTrading.BlockTradingAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag=flag) """ def test_get_counter_parties(self): diff --git a/test/ConvertTest.py b/test/test_convert.py similarity index 83% rename from test/ConvertTest.py rename to test/test_convert.py index aa1d175..fc1f048 100644 --- a/test/ConvertTest.py +++ b/test/test_convert.py @@ -1,11 +1,12 @@ import unittest -from ..okx import Convert +from okx import Convert +from test.config import get_api_credentials + + class ConvertTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' - self.ConvertAPI = Convert.ConvertAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag='1') + api_key, api_secret_key, passphrase, flag = get_api_credentials() + self.ConvertAPI = Convert.ConvertAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag=flag) ''' def test_get_currencies(self): diff --git a/test/CopyTradingTest.py b/test/test_copy_trading.py similarity index 88% rename from test/CopyTradingTest.py rename to test/test_copy_trading.py index 95c1739..b516bc5 100644 --- a/test/CopyTradingTest.py +++ b/test/test_copy_trading.py @@ -1,13 +1,13 @@ import unittest from okx import CopyTrading +from test.config import get_api_credentials + class CopyTradingTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' + api_key, api_secret_key, passphrase, flag = get_api_credentials() self.StackingAPI = CopyTrading.CopyTradingAPI(api_key, api_secret_key, passphrase, use_server_time=False, - flag='0') + flag=flag) # def test_get_existing_leading_positions(self): # print(self.StackingAPI.get_existing_leading_positions(instId='DOGE-USDT-SWAP')) diff --git a/test/EthStakingTest.py b/test/test_eth_staking.py similarity index 82% rename from test/EthStakingTest.py rename to test/test_eth_staking.py index 3be4860..f280d1d 100644 --- a/test/EthStakingTest.py +++ b/test/test_eth_staking.py @@ -1,12 +1,12 @@ import unittest from okx.Finance import EthStaking +from test.config import get_api_credentials + class EthStakingTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' - self.StackingAPI = EthStaking.EthStakingAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag='1') + api_key, api_secret_key, passphrase, flag = get_api_credentials() + self.StackingAPI = EthStaking.EthStakingAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag=flag) def test_eth_product_info(self): print(self.StackingAPI.eth_product_info()) diff --git a/test/FlexibleLoanTest.py b/test/test_flexible_loan.py similarity index 84% rename from test/FlexibleLoanTest.py rename to test/test_flexible_loan.py index b4320a5..517b194 100644 --- a/test/FlexibleLoanTest.py +++ b/test/test_flexible_loan.py @@ -1,12 +1,12 @@ import unittest from okx.Finance import FlexibleLoan +from test.config import get_api_credentials + class FlexibleLoanTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' - self.FlexibleLoanAPI = FlexibleLoan.FlexibleLoanAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag='1') + api_key, api_secret_key, passphrase, flag = get_api_credentials() + self.FlexibleLoanAPI = FlexibleLoan.FlexibleLoanAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag=flag) def test_borrow_currencies(self): print(self.FlexibleLoanAPI.borrow_currencies()) diff --git a/test/FundingTest.py b/test/test_funding.py similarity index 95% rename from test/FundingTest.py rename to test/test_funding.py index f9fb0c6..3247f65 100644 --- a/test/FundingTest.py +++ b/test/test_funding.py @@ -1,13 +1,13 @@ import unittest from okx import Funding +from test.config import get_api_credentials + class FundingTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' - self.FundingAPI = Funding.FundingAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag='0') + api_key, api_secret_key, passphrase, flag = get_api_credentials() + self.FundingAPI = Funding.FundingAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag=flag) """ CANCEL_WITHDRAWAL = '/api/v5/asset/cancel-withdrawal' #need add CONVERT_DUST_ASSETS = '/api/v5/asset/convert-dust-assets' #need add diff --git a/test/GridTest.py b/test/test_grid.py similarity index 95% rename from test/GridTest.py rename to test/test_grid.py index 6e8f7da..92b130c 100644 --- a/test/GridTest.py +++ b/test/test_grid.py @@ -1,13 +1,13 @@ import unittest from okx import Grid +from test.config import get_api_credentials + class GridTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' - self.GridAPI = Grid.GridAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag='1', debug=False) + api_key, api_secret_key, passphrase, flag = get_api_credentials() + self.GridAPI = Grid.GridAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag=flag, debug=False) """ GRID_COMPUTE_MARIGIN_BALANCE = '/api/v5/tradingBot/grid/compute-margin-balance' GRID_MARGIN_BALANCE = '/api/v5/tradingBot/grid/margin-balance' diff --git a/test/MarketTest.py b/test/test_market.py similarity index 93% rename from test/MarketTest.py rename to test/test_market.py index 2a45d7f..f3b6acf 100644 --- a/test/MarketTest.py +++ b/test/test_market.py @@ -12,12 +12,13 @@ BLOCK_TRADES = '/api/v5/market/block-trades'#need to add ''' +from test.config import get_api_credentials + + class MarketAPITest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' - self.MarketApi = MarketData.MarketAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag='1') + api_key, api_secret_key, passphrase, flag = get_api_credentials() + self.MarketApi = MarketData.MarketAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag=flag) ''' def test_index_component(self): diff --git a/test/PublicDataTest.py b/test/test_public_data.py similarity index 95% rename from test/PublicDataTest.py rename to test/test_public_data.py index 56a028b..79e5d35 100644 --- a/test/PublicDataTest.py +++ b/test/test_public_data.py @@ -1,11 +1,12 @@ import unittest from okx import PublicData +from test.config import get_api_credentials + + class publicDataTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' - self.publicDataApi = PublicData.PublicAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag='1') + api_key, api_secret_key, passphrase, flag = get_api_credentials() + self.publicDataApi = PublicData.PublicAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag=flag) ''' TestCase For: INTEREST_LOAN = '/api/v5/public/interest-rate-loan-quota' #need to add diff --git a/test/SavingsTest.py b/test/test_savings.py similarity index 84% rename from test/SavingsTest.py rename to test/test_savings.py index 9dac910..9baf663 100644 --- a/test/SavingsTest.py +++ b/test/test_savings.py @@ -1,12 +1,12 @@ import unittest from okx.Finance import Savings +from test.config import get_api_credentials + class SavingsTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' - self.StackingAPI = Savings.SavingsAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag='1') + api_key, api_secret_key, passphrase, flag = get_api_credentials() + self.StackingAPI = Savings.SavingsAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag=flag) def test_get_saving_balance(self): diff --git a/test/SolStakingTest.py b/test/test_sol_staking.py similarity index 82% rename from test/SolStakingTest.py rename to test/test_sol_staking.py index c3899d4..110fb67 100644 --- a/test/SolStakingTest.py +++ b/test/test_sol_staking.py @@ -1,12 +1,12 @@ import unittest from okx.Finance import SolStaking +from test.config import get_api_credentials + class SolStakingTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' - self.StackingAPI = SolStaking.SolStakingAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag='1') + api_key, api_secret_key, passphrase, flag = get_api_credentials() + self.StackingAPI = SolStaking.SolStakingAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag=flag) def test_sol_purchase(self): print(self.StackingAPI.sol_purchase(amt="1")) diff --git a/test/SpreadTest.py b/test/test_spread.py similarity index 91% rename from test/SpreadTest.py rename to test/test_spread.py index cbaf6e8..25a596b 100644 --- a/test/SpreadTest.py +++ b/test/test_spread.py @@ -1,11 +1,12 @@ import unittest from okx import SpreadTrading +from test.config import get_api_credentials + + class TradeTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' - self.tradeApi = SpreadTrading.SpreadTradingAPI(api_key, api_secret_key, passphrase, False, '1') + api_key, api_secret_key, passphrase, flag = get_api_credentials() + self.tradeApi = SpreadTrading.SpreadTradingAPI(api_key, api_secret_key, passphrase, False, flag) # def test_place_order(self): # print(self.tradeApi.place_order(sprdId='BTC-USDT_BTC-USDT-SWAP',clOrdId='b15',side='buy',ordType='limit', diff --git a/test/StakingDefiTest.py b/test/test_staking_defi.py similarity index 82% rename from test/StakingDefiTest.py rename to test/test_staking_defi.py index 29ac045..bfe908a 100644 --- a/test/StakingDefiTest.py +++ b/test/test_staking_defi.py @@ -1,14 +1,13 @@ import unittest from okx.Finance import StakingDefi +from test.config import get_api_credentials class StakingDefiTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' + api_key, api_secret_key, passphrase, flag = get_api_credentials() self.StackingAPI = StakingDefi.StakingDefiAPI(api_key, api_secret_key, passphrase, use_server_time=False, - flag='1') + flag=flag) def test_get_offers(self): print(self.StackingAPI.get_offers(ccy="USDT")) diff --git a/test/SubAccountTest.py b/test/test_sub_account.py similarity index 92% rename from test/SubAccountTest.py rename to test/test_sub_account.py index b4281bd..8682194 100644 --- a/test/SubAccountTest.py +++ b/test/test_sub_account.py @@ -1,12 +1,12 @@ import unittest from okx import SubAccount +from test.config import get_api_credentials + class SubAccountTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' - self.SubAccountApi = SubAccount.SubAccountAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag='1') + api_key, api_secret_key, passphrase, flag = get_api_credentials() + self.SubAccountApi = SubAccount.SubAccountAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag=flag) ''' ENTRUST_SUBACCOUNT_LIST = '/api/v5/users/entrust-subaccount-list' #need to add SET_TRSNSFER_OUT = '/api/v5/users/subaccount/set-transfer-out' #need to add diff --git a/test/TradeTest.py b/test/test_trade.py similarity index 98% rename from test/TradeTest.py rename to test/test_trade.py index 3f1b601..c460943 100644 --- a/test/TradeTest.py +++ b/test/test_trade.py @@ -1,14 +1,13 @@ import unittest from okx import Trade +from test.config import get_api_credentials class TradeTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' - self.tradeApi = Trade.TradeAPI(api_key, api_secret_key, passphrase, False, '1') + api_key, api_secret_key, passphrase, flag = get_api_credentials() + self.tradeApi = Trade.TradeAPI(api_key, api_secret_key, passphrase, False, flag) # # """ # def test_place_order(self): diff --git a/test/TradingDataTest.py b/test/test_trading_data.py similarity index 86% rename from test/TradingDataTest.py rename to test/test_trading_data.py index ac318b6..0872f37 100644 --- a/test/TradingDataTest.py +++ b/test/test_trading_data.py @@ -1,14 +1,13 @@ import unittest -from ..okx import TradingData +from okx import TradingData +from test.config import get_api_credentials class TradingDataTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' + api_key, api_secret_key, passphrase, flag = get_api_credentials() self.TradingDataAPI = TradingData.TradingDataAPI(api_key, api_secret_key, passphrase, use_server_time=False, - flag='1') + flag=flag) """ def test_get_support_coins(self): print(self.TradingDataAPI.get_support_coin()) diff --git a/test/WsPrivateAsyncTest.py b/test/test_ws_private_async.py similarity index 88% rename from test/WsPrivateAsyncTest.py rename to test/test_ws_private_async.py index 6130b83..1244387 100644 --- a/test/WsPrivateAsyncTest.py +++ b/test/test_ws_private_async.py @@ -1,6 +1,7 @@ import asyncio from okx.websocket.WsPrivateAsync import WsPrivateAsync +from test.config import get_api_credentials def privateCallback(message): @@ -8,11 +9,12 @@ def privateCallback(message): async def main(): + api_key, api_secret_key, passphrase, _ = get_api_credentials() url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" ws = WsPrivateAsync( - apiKey="your apiKey", - passphrase="your passphrase", - secretKey="your secretKey", + apiKey=api_key, + passphrase=passphrase, + secretKey=api_secret_key, url=url, useServerTime=False ) diff --git a/test/WsPublicAsyncTest.py b/test/test_ws_public_async.py similarity index 100% rename from test/WsPublicAsyncTest.py rename to test/test_ws_public_async.py From 6cee7dc266cf4237d32ef063b63be5f1d49d5c7c Mon Sep 17 00:00:00 2001 From: "jason.hsu" Date: Fri, 19 Dec 2025 19:30:38 +0800 Subject: [PATCH 37/48] feat: rmv changes --- okx/Trade.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/okx/Trade.py b/okx/Trade.py index a975dbb..9d6d9e5 100644 --- a/okx/Trade.py +++ b/okx/Trade.py @@ -189,13 +189,11 @@ def get_algo_order_details(self, algoId='', algoClOrdId=''): # Amend algo order def amend_algo_order(self, instId='', algoId='', algoClOrdId='', cxlOnFail='', reqId='', newSz='', newTriggerPx='', newOrdPx='', newTpTriggerPx='', newTpOrdPx='', newSlTriggerPx='', newSlOrdPx='', newTpTriggerPxType='', - newSlTriggerPxType='', pxAmendType=None): + newSlTriggerPxType=''): params = {'instId': instId, 'algoId': algoId, 'algoClOrdId': algoClOrdId, 'cxlOnFail': cxlOnFail, 'reqId': reqId, 'newSz': newSz, 'newTriggerPx': newTriggerPx, 'newOrdPx': newOrdPx, 'newTpTriggerPx': newTpTriggerPx, 'newTpOrdPx': newTpOrdPx, 'newSlTriggerPx': newSlTriggerPx, 'newSlOrdPx': newSlOrdPx, 'newTpTriggerPxType': newTpTriggerPxType, 'newSlTriggerPxType': newSlTriggerPxType} - if pxAmendType is not None: - params['pxAmendType'] = pxAmendType return self._request_with_params(POST, AMEND_ALGO_ORDER, params) def get_oneclick_repay_list_v2(self): From 128da2d464134505f354359b08190fc234da58cd Mon Sep 17 00:00:00 2001 From: "jason.hsu" Date: Fri, 19 Dec 2025 21:35:38 +0800 Subject: [PATCH 38/48] fix: add testing files --- test/config.py | 1 - test/test_account.py | 11 ++++++-- test/test_grid.py | 26 ++++++++++++++++- test/test_trade.py | 66 ++++++++++++++++++++++++++++++++++++++++---- 4 files changed, 94 insertions(+), 10 deletions(-) diff --git a/test/config.py b/test/config.py index 090e7e8..37fefdb 100644 --- a/test/config.py +++ b/test/config.py @@ -12,7 +12,6 @@ # Try to load from .env file if python-dotenv is available try: from dotenv import load_dotenv - # Load .env from project root env_path = Path(__file__).parent.parent / '.env' load_dotenv(env_path) except ImportError: diff --git a/test/test_account.py b/test/test_account.py index 50b5493..ef71181 100644 --- a/test/test_account.py +++ b/test/test_account.py @@ -145,8 +145,15 @@ def setUp(self): # logger.info(f'{self.AccountAPI.set_auto_repay(autoRepay=True)}') # def test_spot_borrow_repay_history(self): # logger.debug(self.AccountAPI.spot_borrow_repay_history(ccy="USDT",type="auto_borrow",after="1597026383085")) - def test_set_auto_earn(self): - logger.debug(self.AccountAPI.set_auto_earn(ccy="USDT", action="turn_on", earnType='0')) + # def test_set_auto_earn(self): + # logger.debug(self.AccountAPI.set_auto_earn(ccy="USDT", action="turn_on", earnType='0')) + #def test_get_max_loan_with_trade_quote_ccy(self): + # logger.debug(self.AccountAPI.get_max_loan( + # instId="BTC-USDT", + # mgnMode="isolated", + # mgnCcy="USDT", + # tradeQuoteCcy="USDT" + # )) if __name__ == '__main__': unittest.main() diff --git a/test/test_grid.py b/test/test_grid.py index 92b130c..fb14c95 100644 --- a/test/test_grid.py +++ b/test/test_grid.py @@ -79,7 +79,31 @@ def test_withdrawl_profits(self): # def test_get_recurring_buy_sub_orders(self): # print(self.GridAPI.get_recurring_buy_sub_orders(algoId="581191143417970688")) - #581191143417970688 + #def test_grid_order_algo_with_trade_quote_ccy(self): + # print(self.GridAPI.grid_order_algo( + # instId="BTC-USDT", + # algoOrdType="grid", + # maxPx="45000", + # minPx="20000", + # gridNum="100", + # runType="1", + # quoteSz="50", + # tradeQuoteCcy="USDT" + # )) + + #def test_place_recurring_buy_order_with_trade_quote_ccy(self): + # print(self.GridAPI.place_recurring_buy_order( + # stgyName="test_strategy", + # recurringList=[{'ccy': 'ETH', 'ratio': '1'}], + # period="daily", + # recurringDay='1', + # recurringTime='0', + # timeZone='8', + # amt='100', + # investmentCcy='USDT', + # tdMode='cash', + # tradeQuoteCcy="USDT" + # )) if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/test/test_trade.py b/test/test_trade.py index c460943..6da7a54 100644 --- a/test/test_trade.py +++ b/test/test_trade.py @@ -233,12 +233,66 @@ def setUp(self): # def test_close_all_positions(self): # print(self.tradeApi.close_positions(instId="BTC-USDT-SWAP", mgnMode="cross",clOrdId='1213124')) - def test_get_oneclick_repay_list_v2(self): - print(self.tradeApi.get_oneclick_repay_list_v2()) - def test_oneclick_repay_v2(self): - print(self.tradeApi.oneclick_repay_v2('BTC',['USDT'])) - def test_oneclick_repay_history_v2(self): - print(self.tradeApi.oneclick_repay_history_v2()) + #def test_get_oneclick_repay_list_v2(self): + # print(self.tradeApi.get_oneclick_repay_list_v2()) + #def test_oneclick_repay_v2(self): + # print(self.tradeApi.oneclick_repay_v2('BTC',['USDT'])) + #def test_oneclick_repay_history_v2(self): + # print(self.tradeApi.oneclick_repay_history_v2()) + + #def test_place_order_with_trade_quote_ccy(self): + # print(self.tradeApi.place_order( + # instId="BTC-USDT", + # tdMode="cash", + # side="buy", + # ordType="limit", + # sz="0.01", + # px="30000", + # tradeQuoteCcy="USDT" + # )) + + #def test_place_order_with_px_amend_type(self): + # print(self.tradeApi.place_order( + # instId="BTC-USDT-SWAP", + # tdMode="cash", + # side="buy", + # ordType="limit", + # sz="1", + # px="30000", + # pxAmendType="1" + # )) + + #def test_amend_order_with_px_amend_type(self): + # print(self.tradeApi.amend_order( + # instId="BTC-USDT-SWAP", + # ordId="123", + # newPx="30500", + # pxAmendType="1" + # )) + + #def test_place_algo_order_with_trade_quote_ccy(self): + # print(self.tradeApi.place_algo_order( + # instId="BTC-USDT-SWAP", + # tdMode="cash", + # side="buy", + # ordType="trigger", + # sz="1", + # triggerPx="30000", + # orderPx="-1", + # tradeQuoteCcy="USDT" + # )) + + #def test_place_algo_order_with_px_amend_type(self): + # print(self.tradeApi.place_algo_order( + # instId="BTC-USDT-SWAP", + # tdMode="cash", + # side="buy", + # ordType="trigger", + # sz="1", + # triggerPx="30000", + # orderPx="-1", + # pxAmendType="1" + # )) if __name__ == '__main__': unittest.main() From ee3c26378c3c6ae2d5f930e2103fa3516ab05eea Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Mon, 22 Dec 2025 10:16:50 +0800 Subject: [PATCH 39/48] websocket enhancement --- okx/websocket/WsPrivateAsync.py | 72 ++++++++++++++++----------------- okx/websocket/WsPublicAsync.py | 16 ++++---- 2 files changed, 44 insertions(+), 44 deletions(-) diff --git a/okx/websocket/WsPrivateAsync.py b/okx/websocket/WsPrivateAsync.py index 8ca8edd..b7e328e 100644 --- a/okx/websocket/WsPrivateAsync.py +++ b/okx/websocket/WsPrivateAsync.py @@ -23,11 +23,11 @@ def __init__(self, apiKey, passphrase, secretKey, url, useServerTime=None, debug self.websocket = None self.debug = debug - # 设置日志级别 + # Set log level if debug: logger.setLevel(logging.DEBUG) - # 废弃 useServerTime 参数警告 + # Deprecation warning for useServerTime parameter if useServerTime is not None: warnings.warn("useServerTime parameter is deprecated. Please remove it.", DeprecationWarning) @@ -88,11 +88,11 @@ async def unsubscribe(self, params: list, callback, id: str = None): async def send(self, op: str, args: list, callback=None, id: str = None): """ - 通用发送方法 - :param op: 操作类型 - :param args: 参数列表 - :param callback: 回调函数 - :param id: 可选的请求ID + Generic send method + :param op: Operation type + :param args: Parameter list + :param callback: Callback function + :param id: Optional request ID """ if callback: self.callback = callback @@ -109,10 +109,10 @@ async def send(self, op: str, args: list, callback=None, id: str = None): async def place_order(self, args: list, callback=None, id: str = None): """ - 下单 - :param args: 下单参数列表 - :param callback: 回调函数 - :param id: 可选的请求ID + Place order + :param args: Order parameter list + :param callback: Callback function + :param id: Optional request ID """ if callback: self.callback = callback @@ -120,10 +120,10 @@ async def place_order(self, args: list, callback=None, id: str = None): async def batch_orders(self, args: list, callback=None, id: str = None): """ - 批量下单 - :param args: 批量下单参数列表 - :param callback: 回调函数 - :param id: 可选的请求ID + Batch place orders + :param args: Batch order parameter list + :param callback: Callback function + :param id: Optional request ID """ if callback: self.callback = callback @@ -131,10 +131,10 @@ async def batch_orders(self, args: list, callback=None, id: str = None): async def cancel_order(self, args: list, callback=None, id: str = None): """ - 撤单 - :param args: 撤单参数列表 - :param callback: 回调函数 - :param id: 可选的请求ID + Cancel order + :param args: Cancel order parameter list + :param callback: Callback function + :param id: Optional request ID """ if callback: self.callback = callback @@ -142,10 +142,10 @@ async def cancel_order(self, args: list, callback=None, id: str = None): async def batch_cancel_orders(self, args: list, callback=None, id: str = None): """ - 批量撤单 - :param args: 批量撤单参数列表 - :param callback: 回调函数 - :param id: 可选的请求ID + Batch cancel orders + :param args: Batch cancel order parameter list + :param callback: Callback function + :param id: Optional request ID """ if callback: self.callback = callback @@ -153,10 +153,10 @@ async def batch_cancel_orders(self, args: list, callback=None, id: str = None): async def amend_order(self, args: list, callback=None, id: str = None): """ - 改单 - :param args: 改单参数列表 - :param callback: 回调函数 - :param id: 可选的请求ID + Amend order + :param args: Amend order parameter list + :param callback: Callback function + :param id: Optional request ID """ if callback: self.callback = callback @@ -164,10 +164,10 @@ async def amend_order(self, args: list, callback=None, id: str = None): async def batch_amend_orders(self, args: list, callback=None, id: str = None): """ - 批量改单 - :param args: 批量改单参数列表 - :param callback: 回调函数 - :param id: 可选的请求ID + Batch amend orders + :param args: Batch amend order parameter list + :param callback: Callback function + :param id: Optional request ID """ if callback: self.callback = callback @@ -175,11 +175,11 @@ async def batch_amend_orders(self, args: list, callback=None, id: str = None): async def mass_cancel(self, args: list, callback=None, id: str = None): """ - Mass cancel (批量撤销) - 注意:此方法用于 /ws/v5/business 频道,限速 1次/秒 - :param args: 撤销参数列表,包含 instType 和 instFamily - :param callback: 回调函数 - :param id: 可选的请求ID + Mass cancel orders + Note: This method is for /ws/v5/business channel, rate limit: 1 request/second + :param args: Cancel parameter list, contains instType and instFamily + :param callback: Callback function + :param id: Optional request ID """ if callback: self.callback = callback diff --git a/okx/websocket/WsPublicAsync.py b/okx/websocket/WsPublicAsync.py index 1a7e28d..d7eefdc 100644 --- a/okx/websocket/WsPublicAsync.py +++ b/okx/websocket/WsPublicAsync.py @@ -17,13 +17,13 @@ def __init__(self, url, apiKey='', passphrase='', secretKey='', debug=False): self.factory = WebSocketFactory(url) self.websocket = None self.debug = debug - # 用于 business 频道的登录凭证 + # Credentials for business channel login self.apiKey = apiKey self.passphrase = passphrase self.secretKey = secretKey self.isLoggedIn = False - # 设置日志级别 + # Set log level if debug: logger.setLevel(logging.DEBUG) @@ -39,7 +39,7 @@ async def consume(self): async def login(self): """ - 登录方法,用于需要登录的 business 频道(如 /ws/v5/business) + Login method for business channel that requires authentication (e.g. /ws/v5/business) """ if not self.apiKey or not self.secretKey or not self.passphrase: raise ValueError("apiKey, secretKey and passphrase are required for login") @@ -87,11 +87,11 @@ async def unsubscribe(self, params: list, callback, id: str = None): async def send(self, op: str, args: list, callback=None, id: str = None): """ - 通用发送方法 - :param op: 操作类型 - :param args: 参数列表 - :param callback: 回调函数 - :param id: 可选的请求ID + Generic send method + :param op: Operation type + :param args: Parameter list + :param callback: Callback function + :param id: Optional request ID """ if callback: self.callback = callback From 4813e7e04163527b7f2252de27c7255881128c40 Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Mon, 22 Dec 2025 10:20:17 +0800 Subject: [PATCH 40/48] websocket enhancement --- test/unit/okx/websocket/test_ws_public_async.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/unit/okx/websocket/test_ws_public_async.py b/test/unit/okx/websocket/test_ws_public_async.py index 9ac561c..ca27b38 100644 --- a/test/unit/okx/websocket/test_ws_public_async.py +++ b/test/unit/okx/websocket/test_ws_public_async.py @@ -111,9 +111,6 @@ async def run_test(): class TestWsPublicAsyncSubscribe(unittest.TestCase): """Unit tests for WsPublicAsync subscribe method""" - def test_subscribe_sets_callback(self): - """Test subscribe sets callback correctly""" - with patch.object(ws_public_module, 'WebSocketFactory'): def test_subscribe_without_id(self): """Test subscribe without id parameter""" with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): From 9dfb5b64cbef3af1da5f7ed85ffd68d12e2dd5c0 Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Mon, 22 Dec 2025 10:45:07 +0800 Subject: [PATCH 41/48] websocket enhancement --- test/test_ws_private_async.py | 49 ++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/test/test_ws_private_async.py b/test/test_ws_private_async.py index fbee060..10f34bf 100644 --- a/test/test_ws_private_async.py +++ b/test/test_ws_private_async.py @@ -51,11 +51,12 @@ async def test_place_order(): Test place order functionality URL: /ws/v5/private (Rate limit: 60 requests/second) """ + api_key, api_secret_key, passphrase, _ = get_api_credentials() url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" ws = WsPrivateAsync( - apiKey="your apiKey", - passphrase="your passphrase", - secretKey="your secretKey", + apiKey=api_key, + passphrase=passphrase, + secretKey=api_secret_key, url=url, debug=True ) @@ -83,11 +84,12 @@ async def test_batch_orders(): Test batch orders functionality URL: /ws/v5/private (Rate limit: 60 requests/second, max 20 orders) """ + api_key, api_secret_key, passphrase, _ = get_api_credentials() url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" ws = WsPrivateAsync( - apiKey="your apiKey", - passphrase="your passphrase", - secretKey="your secretKey", + apiKey=api_key, + passphrase=passphrase, + secretKey=api_secret_key, url=url, debug=True ) @@ -126,11 +128,12 @@ async def test_cancel_order(): Test cancel order functionality URL: /ws/v5/private (Rate limit: 60 requests/second) """ + api_key, api_secret_key, passphrase, _ = get_api_credentials() url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" ws = WsPrivateAsync( - apiKey="your apiKey", - passphrase="your passphrase", - secretKey="your secretKey", + apiKey=api_key, + passphrase=passphrase, + secretKey=api_secret_key, url=url, debug=True ) @@ -154,11 +157,12 @@ async def test_batch_cancel_orders(): Test batch cancel orders functionality URL: /ws/v5/private (Rate limit: 60 requests/second, max 20 orders) """ + api_key, api_secret_key, passphrase, _ = get_api_credentials() url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" ws = WsPrivateAsync( - apiKey="your apiKey", - passphrase="your passphrase", - secretKey="your secretKey", + apiKey=api_key, + passphrase=passphrase, + secretKey=api_secret_key, url=url, debug=True ) @@ -180,11 +184,12 @@ async def test_amend_order(): Test amend order functionality URL: /ws/v5/private (Rate limit: 60 requests/second) """ + api_key, api_secret_key, passphrase, _ = get_api_credentials() url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" ws = WsPrivateAsync( - apiKey="your apiKey", - passphrase="your passphrase", - secretKey="your secretKey", + apiKey=api_key, + passphrase=passphrase, + secretKey=api_secret_key, url=url, debug=True ) @@ -210,11 +215,12 @@ async def test_mass_cancel(): URL: /ws/v5/business (Rate limit: 1 request/second) Note: This function uses the business channel """ + api_key, api_secret_key, passphrase, _ = get_api_credentials() url = "wss://wspap.okx.com:8443/ws/v5/business?brokerId=9999" ws = WsPrivateAsync( - apiKey="your apiKey", - passphrase="your passphrase", - secretKey="your secretKey", + apiKey=api_key, + passphrase=passphrase, + secretKey=api_secret_key, url=url, debug=True ) @@ -234,11 +240,12 @@ async def test_mass_cancel(): async def test_send_method(): """Test generic send method""" + api_key, api_secret_key, passphrase, _ = get_api_credentials() url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" ws = WsPrivateAsync( - apiKey="your apiKey", - passphrase="your passphrase", - secretKey="your secretKey", + apiKey=api_key, + passphrase=passphrase, + secretKey=api_secret_key, url=url, debug=True ) From fb237f6a72b6bc448d8a495317c8322a26bc07f2 Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Mon, 22 Dec 2025 10:55:13 +0800 Subject: [PATCH 42/48] websocket enhancement --- test/unit/okx/websocket/test_ws_private_async.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/unit/okx/websocket/test_ws_private_async.py b/test/unit/okx/websocket/test_ws_private_async.py index 35ea1cc..531f5a8 100644 --- a/test/unit/okx/websocket/test_ws_private_async.py +++ b/test/unit/okx/websocket/test_ws_private_async.py @@ -545,7 +545,7 @@ async def run_test(): result = await ws.login() self.assertTrue(result) mock_ws_utils.initLoginParams.assert_called_once_with( - useServerTime=True, + useServerTime=False, apiKey="test_api_key", passphrase="test_passphrase", secretKey="test_secret_key" From 2dfafef97d9bd9085b2f04da51f141ae37556eb1 Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Mon, 22 Dec 2025 11:19:53 +0800 Subject: [PATCH 43/48] config update --- test/config.py | 36 ++++++++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/test/config.py b/test/config.py index 37fefdb..78885b3 100644 --- a/test/config.py +++ b/test/config.py @@ -7,15 +7,39 @@ api_key, api_secret, passphrase, flag = get_api_credentials() """ import os +import logging from pathlib import Path -# Try to load from .env file if python-dotenv is available -try: - from dotenv import load_dotenv +logger = logging.getLogger(__name__) + +# Flag to ensure .env is loaded only once +_env_loaded = False + + +def _load_env_once(): + """Load .env file only once, log any exceptions.""" + global _env_loaded + if _env_loaded: + return + + _env_loaded = True env_path = Path(__file__).parent.parent / '.env' - load_dotenv(env_path) -except ImportError: - pass # python-dotenv not installed, rely on system environment variables + + try: + from dotenv import load_dotenv + if env_path.exists(): + load_dotenv(env_path) + logger.debug(f"Loaded .env file from: {env_path}") + else: + logger.warning(f".env file not found at: {env_path}") + except ImportError: + logger.warning("python-dotenv not installed, relying on system environment variables") + except Exception as e: + logger.error(f"Failed to load .env file: {e}") + + +# Load .env when module is imported +_load_env_once() def get_api_credentials(): From c35c35545df392efebf7911b3014e1e264bf2339 Mon Sep 17 00:00:00 2001 From: "jason.hsu" Date: Mon, 22 Dec 2025 12:13:50 +0800 Subject: [PATCH 44/48] feat: add dependency file and cicd --- .github/workflows/ci.yml | 72 +++++++++++++++++++++++++++++++++++----- requirements-dev.txt | 19 +++++++++++ requirements.txt | 7 ++++ setup.py | 56 +++++++++++++++++++++++-------- 4 files changed, 131 insertions(+), 23 deletions(-) create mode 100644 requirements-dev.txt create mode 100644 requirements.txt diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2e0717a..9570f8d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,6 +14,48 @@ on: - 'releases/*' jobs: + # ============================================ + # DEPENDENCY CHECK JOB + # Ensures all imports are satisfied by requirements.txt + # ============================================ + dependency-check: + name: Dependency Check + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install only from requirements.txt (clean environment) + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + - name: Verify all production imports work + run: | + # This catches missing dependencies in requirements.txt + python -c " + import okx + from okx import Account, Trade, Funding, MarketData, PublicData + from okx import SubAccount, Convert, BlockTrading, CopyTrading + from okx import SpreadTrading, Grid, TradingData, Status + from okx.websocket import WsPublicAsync, WsPrivateAsync + print('✅ All imports successful') + print(f' okx version: {okx.__version__}') + " + + - name: Install dev dependencies and verify test imports + run: | + pip install -r requirements-dev.txt + python -c " + import pytest + import unittest + print('✅ Dev imports successful') + " + # ============================================ # LINT JOB # ============================================ @@ -31,11 +73,11 @@ jobs: - name: Install linting tools run: | python -m pip install --upgrade pip - pip install flake8 + pip install ruff - - name: Run flake8 + - name: Run ruff run: | - flake8 okx/ --max-line-length=120 --ignore=E501,W503,E203 --count --show-source --statistics + ruff check okx/ --ignore=E501 continue-on-error: true # Set to false once codebase is cleaned up # ============================================ @@ -44,10 +86,12 @@ jobs: test: name: Test (Python ${{ matrix.python-version }}) runs-on: ubuntu-latest + needs: [dependency-check] # Only run tests if dependencies are valid strategy: - fail-fast: false + fail-fast: true + max-parallel: 1 matrix: - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.9", "3.12"] steps: - uses: actions/checkout@v4 @@ -61,7 +105,7 @@ jobs: uses: actions/cache@v4 with: path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('setup.py') }} + key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('requirements*.txt') }} restore-keys: | ${{ runner.os }}-pip-${{ matrix.python-version }}- ${{ runner.os }}-pip- @@ -69,15 +113,15 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip + pip install -r requirements-dev.txt pip install -e . - pip install pytest pytest-cov pytest-asyncio websockets certifi - name: Run tests run: | python -m pytest test/unit/ -v --cov=okx --cov-report=term-missing --cov-report=xml - name: Upload coverage to Codecov - if: matrix.python-version == '3.11' + if: matrix.python-version == '3.12' uses: codecov/codecov-action@v4 with: file: ./coverage.xml @@ -110,10 +154,20 @@ jobs: - name: Check package run: twine check dist/* + - name: Test install from wheel (clean environment) + run: | + # Create a fresh venv and install the built wheel + python -m venv /tmp/test-install + /tmp/test-install/bin/pip install dist/*.whl + /tmp/test-install/bin/python -c " + import okx + from okx import Account, Trade, Funding + print(f'✅ Package installs correctly: okx {okx.__version__}') + " + - name: Upload build artifacts uses: actions/upload-artifact@v4 with: name: dist path: dist/ retention-days: 7 - diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..94c4dd4 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,19 @@ +# Development dependencies +-r requirements.txt + +# Testing +pytest>=7.0.0 +pytest-asyncio>=0.21.0 +pytest-cov>=4.0.0 + +# Linting & Formatting +ruff>=0.1.0 +mypy>=1.0.0 + +# Environment +python-dotenv>=1.0.0 + +# Build tools +build>=1.0.0 +twine>=4.0.0 + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f361bd9 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +# Core dependencies for python-okx +httpx[http2]>=0.24.0 +requests>=2.25.0 +websockets>=10.0 +certifi>=2021.0.0 +loguru>=0.7.0 + diff --git a/setup.py b/setup.py index ce2838d..3dfb2d1 100644 --- a/setup.py +++ b/setup.py @@ -1,8 +1,36 @@ import setuptools +from pathlib import Path + +# Read version from package import okx -with open("README.md", "r",encoding="utf-8") as fh: + +# Read README +with open("README.md", "r", encoding="utf-8") as fh: long_description = fh.read() + +def parse_requirements(filename): # type: (str) -> list + """Parse requirements from a requirements file.""" + requirements_path = Path(__file__).parent / filename + requirements = [] + + if not requirements_path.exists(): + return requirements + + with open(requirements_path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + # Skip empty lines, comments, and -r includes + if not line or line.startswith("#") or line.startswith("-r"): + continue + # Handle inline comments + if "#" in line: + line = line.split("#")[0].strip() + requirements.append(line) + + return requirements + + setuptools.setup( name="python-okx", version=okx.__version__, @@ -12,21 +40,21 @@ long_description=long_description, long_description_content_type="text/markdown", url="https://okx.com/docs-v5/", - packages=setuptools.find_packages(), + packages=setuptools.find_packages(exclude=["test", "test.*", "example"]), + python_requires=">=3.7", classifiers=[ "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ], - install_requires=[ - "importlib-metadata", - "httpx[http2]", - "keyring", - "loguru", - "requests", - "Twisted", - "pyOpenSSL", - "websockets", - "certifi" - ] -) \ No newline at end of file + install_requires=parse_requirements("requirements.txt"), + extras_require={ + "dev": parse_requirements("requirements-dev.txt"), + }, +) From a278d23e3ce1c70d7cb75efb6411c0df02a63edb Mon Sep 17 00:00:00 2001 From: "jason.hsu" Date: Mon, 22 Dec 2025 12:26:21 +0800 Subject: [PATCH 45/48] feat: rmv requirements-dev.txt --- .github/workflows/ci.yml | 10 +++--- README.md | 72 ++++++++++++++++++++++++++++++++++++---- requirements-dev.txt | 19 ----------- requirements.txt | 10 +++++- setup.py | 3 -- 5 files changed, 80 insertions(+), 34 deletions(-) delete mode 100644 requirements-dev.txt diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9570f8d..83b0bc3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -47,13 +47,13 @@ jobs: print(f' okx version: {okx.__version__}') " - - name: Install dev dependencies and verify test imports + - name: Verify test imports run: | - pip install -r requirements-dev.txt + pip install pytest python -c " import pytest import unittest - print('✅ Dev imports successful') + print('✅ Test imports successful') " # ============================================ @@ -105,7 +105,7 @@ jobs: uses: actions/cache@v4 with: path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('requirements*.txt') }} + key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('requirements.txt') }} restore-keys: | ${{ runner.os }}-pip-${{ matrix.python-version }}- ${{ runner.os }}-pip- @@ -113,7 +113,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements-dev.txt + pip install -r requirements.txt pip install -e . - name: Run tests diff --git a/README.md b/README.md index 19aac31..3e2d5b0 100644 --- a/README.md +++ b/README.md @@ -20,20 +20,80 @@ Make sure you update often and check the [Changelog](https://www.okx.com/docs-v5 ### Quick start #### Prerequisites -`python version:>=3.9` +`python version:>=3.7` -`WebSocketAPI: websockets package advise version 6.0` - -#### Step 1: register an account on OKX and apply for an API key +#### Step 1: Register an account on OKX and apply for an API key - Register for an account: https://www.okx.com/account/register - Apply for an API key: https://www.okx.com/account/users/myApi -#### Step 2: install python-okx +#### Step 2: Install python-okx -```python +```bash pip install python-okx ``` +### API Credentials + +#### Option 1: Hardcoded credentials + +```python +from okx import Account + +account = Account.AccountAPI( + api_key="your-api-key-here", + api_secret_key="your-api-secret-here", + passphrase="your-passphrase-here", + flag="1", # 0 = live trading, 1 = demo trading + debug=False +) +``` + +#### Option 2: Using `.env` file (recommended) + +Create a `.env` file in your project root: + +```bash +OKX_API_KEY=your-api-key-here +OKX_API_SECRET=your-api-secret-here +OKX_PASSPHRASE=your-passphrase-here +OKX_FLAG=1 +``` + +Then load it in your code: + +```python +import os +from dotenv import load_dotenv +from okx import Account + +load_dotenv() + +account = Account.AccountAPI( + api_key=os.getenv('OKX_API_KEY'), + api_secret_key=os.getenv('OKX_API_SECRET'), + passphrase=os.getenv('OKX_PASSPHRASE'), + flag=os.getenv('OKX_FLAG', '1'), + debug=False +) +``` + +### Development Setup + +For contributors or local development: + +```bash +# Clone the repository +git clone https://github.com/okxapi/python-okx.git +cd python-okx + +# Install dependencies +pip install -r requirements.txt +pip install -e . + +# Run tests +pytest test/unit/ -v +``` + #### Step 3: Run examples - Fill in API credentials in the corresponding examples diff --git a/requirements-dev.txt b/requirements-dev.txt deleted file mode 100644 index 94c4dd4..0000000 --- a/requirements-dev.txt +++ /dev/null @@ -1,19 +0,0 @@ -# Development dependencies --r requirements.txt - -# Testing -pytest>=7.0.0 -pytest-asyncio>=0.21.0 -pytest-cov>=4.0.0 - -# Linting & Formatting -ruff>=0.1.0 -mypy>=1.0.0 - -# Environment -python-dotenv>=1.0.0 - -# Build tools -build>=1.0.0 -twine>=4.0.0 - diff --git a/requirements.txt b/requirements.txt index f361bd9..2926476 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,15 @@ -# Core dependencies for python-okx +# Core dependencies httpx[http2]>=0.24.0 requests>=2.25.0 websockets>=10.0 certifi>=2021.0.0 loguru>=0.7.0 +python-dotenv>=1.0.0 +# Development & Testing +pytest>=7.0.0 +pytest-asyncio>=0.21.0 +pytest-cov>=4.0.0 +ruff>=0.1.0 +build>=1.0.0 +twine>=4.0.0 diff --git a/setup.py b/setup.py index 3dfb2d1..1eef4bb 100644 --- a/setup.py +++ b/setup.py @@ -54,7 +54,4 @@ def parse_requirements(filename): # type: (str) -> list "Operating System :: OS Independent", ], install_requires=parse_requirements("requirements.txt"), - extras_require={ - "dev": parse_requirements("requirements-dev.txt"), - }, ) From 299acc477654c706314e189a2320203495f08f6d Mon Sep 17 00:00:00 2001 From: "jason.hsu" Date: Mon, 22 Dec 2025 12:34:08 +0800 Subject: [PATCH 46/48] fix: ci build failure --- .github/workflows/ci.yml | 6 +++--- requirements.txt | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 83b0bc3..bd4ca7f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -143,13 +143,13 @@ jobs: with: python-version: "3.11" - - name: Install build tools + - name: Install dependencies run: | python -m pip install --upgrade pip - pip install build twine + pip install -r requirements.txt - name: Build package - run: python -m build + run: python -m build --no-isolation - name: Check package run: twine check dist/* diff --git a/requirements.txt b/requirements.txt index 2926476..b9a5daf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -# Core dependencies +# Dependencies for python-okx httpx[http2]>=0.24.0 requests>=2.25.0 websockets>=10.0 From 2b2936a6ac37ac5e0ea6b8c1a1aaa026d312be99 Mon Sep 17 00:00:00 2001 From: "jason.hsu" Date: Mon, 22 Dec 2025 12:38:52 +0800 Subject: [PATCH 47/48] fix: ci build failure --- MANIFEST.in | 3 +++ setup.py | 23 +++++++++++++---------- 2 files changed, 16 insertions(+), 10 deletions(-) create mode 100644 MANIFEST.in diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..23682e9 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,3 @@ +include requirements.txt +include README.md + diff --git a/setup.py b/setup.py index 1eef4bb..57568e0 100644 --- a/setup.py +++ b/setup.py @@ -1,27 +1,30 @@ +import os import setuptools -from pathlib import Path + +# Get the directory where setup.py is located +HERE = os.path.dirname(os.path.abspath(__file__)) # Read version from package import okx # Read README -with open("README.md", "r", encoding="utf-8") as fh: +with open(os.path.join(HERE, "README.md"), "r", encoding="utf-8") as fh: long_description = fh.read() -def parse_requirements(filename): # type: (str) -> list - """Parse requirements from a requirements file.""" - requirements_path = Path(__file__).parent / filename +def parse_requirements(): + """Parse requirements from requirements.txt.""" requirements = [] + req_path = os.path.join(HERE, "requirements.txt") - if not requirements_path.exists(): + if not os.path.exists(req_path): return requirements - with open(requirements_path, "r", encoding="utf-8") as f: + with open(req_path, "r", encoding="utf-8") as f: for line in f: line = line.strip() - # Skip empty lines, comments, and -r includes - if not line or line.startswith("#") or line.startswith("-r"): + # Skip empty lines and comments + if not line or line.startswith("#"): continue # Handle inline comments if "#" in line: @@ -53,5 +56,5 @@ def parse_requirements(filename): # type: (str) -> list "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ], - install_requires=parse_requirements("requirements.txt"), + install_requires=parse_requirements(), ) From 7e40b520a4fd7405c16519bf6454e29ec2e7ad7e Mon Sep 17 00:00:00 2001 From: "zihao.jiang" Date: Mon, 22 Dec 2025 15:04:12 +0800 Subject: [PATCH 48/48] add tradeQuoteCcy request param to the trade-related endpoints --- okx/Account.py | 8 +- test/unit/okx/test_account.py | 140 ++++++++++++++++++++++++++++++++++ 2 files changed, 146 insertions(+), 2 deletions(-) diff --git a/okx/Account.py b/okx/Account.py index a6b182e..5fddf9f 100644 --- a/okx/Account.py +++ b/okx/Account.py @@ -76,14 +76,18 @@ def set_leverage(self, lever, mgnMode, instId='', ccy='', posSide=''): return self._request_with_params(POST, SET_LEVERAGE, params) # Get Maximum Tradable Size For Instrument - def get_max_order_size(self, instId, tdMode, ccy='', px=''): + def get_max_order_size(self, instId, tdMode, ccy='', px='', tradeQuoteCcy=None): params = {'instId': instId, 'tdMode': tdMode, 'ccy': ccy, 'px': px} + if tradeQuoteCcy is not None: + params['tradeQuoteCcy'] = tradeQuoteCcy return self._request_with_params(GET, MAX_TRADE_SIZE, params) # Get Maximum Available Tradable Amount - def get_max_avail_size(self, instId, tdMode, ccy='', reduceOnly='', unSpotOffset='', quickMgnType=''): + def get_max_avail_size(self, instId, tdMode, ccy='', reduceOnly='', unSpotOffset='', quickMgnType='', tradeQuoteCcy=None): params = {'instId': instId, 'tdMode': tdMode, 'ccy': ccy, 'reduceOnly': reduceOnly, 'unSpotOffset': unSpotOffset, 'quickMgnType': quickMgnType} + if tradeQuoteCcy is not None: + params['tradeQuoteCcy'] = tradeQuoteCcy return self._request_with_params(GET, MAX_AVAIL_SIZE, params) # Increase / Decrease margin diff --git a/test/unit/okx/test_account.py b/test/unit/okx/test_account.py index 75794df..57f949d 100644 --- a/test/unit/okx/test_account.py +++ b/test/unit/okx/test_account.py @@ -552,6 +552,146 @@ def test_set_auto_earn_different_currencies(self, mock_request): self.assertEqual(call_args['ccy'], ccy) +class TestAccountAPIGetMaxOrderSize(unittest.TestCase): + """Unit tests for the get_max_order_size method""" + + def setUp(self): + """Set up test fixtures""" + self.account_api = AccountAPI( + api_key='test_key', + api_secret_key='test_secret', + passphrase='test_pass', + flag='0' + ) + + @patch.object(AccountAPI, '_request_with_params') + def test_get_max_order_size_with_required_params(self, mock_request): + """Test get_max_order_size with required parameters only""" + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + result = self.account_api.get_max_order_size( + instId='BTC-USDT', + tdMode='cash' + ) + + expected_params = { + 'instId': 'BTC-USDT', + 'tdMode': 'cash', + 'ccy': '', + 'px': '' + } + mock_request.assert_called_once_with(c.GET, c.MAX_TRADE_SIZE, expected_params) + self.assertEqual(result, mock_response) + + @patch.object(AccountAPI, '_request_with_params') + def test_get_max_order_size_with_tradeQuoteCcy(self, mock_request): + """Test get_max_order_size with tradeQuoteCcy parameter for Unified USD Orderbook""" + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + result = self.account_api.get_max_order_size( + instId='BTC-USD', + tdMode='cash', + tradeQuoteCcy='USDC' + ) + + expected_params = { + 'instId': 'BTC-USD', + 'tdMode': 'cash', + 'ccy': '', + 'px': '', + 'tradeQuoteCcy': 'USDC' + } + mock_request.assert_called_once_with(c.GET, c.MAX_TRADE_SIZE, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_get_max_order_size_without_tradeQuoteCcy(self, mock_request): + """Test get_max_order_size without tradeQuoteCcy (should not include in params)""" + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + result = self.account_api.get_max_order_size( + instId='BTC-USDT', + tdMode='cash' + ) + + call_args = mock_request.call_args[0][2] + self.assertNotIn('tradeQuoteCcy', call_args) + + +class TestAccountAPIGetMaxAvailSize(unittest.TestCase): + """Unit tests for the get_max_avail_size method""" + + def setUp(self): + """Set up test fixtures""" + self.account_api = AccountAPI( + api_key='test_key', + api_secret_key='test_secret', + passphrase='test_pass', + flag='0' + ) + + @patch.object(AccountAPI, '_request_with_params') + def test_get_max_avail_size_with_required_params(self, mock_request): + """Test get_max_avail_size with required parameters only""" + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + result = self.account_api.get_max_avail_size( + instId='BTC-USDT', + tdMode='cash' + ) + + expected_params = { + 'instId': 'BTC-USDT', + 'tdMode': 'cash', + 'ccy': '', + 'reduceOnly': '', + 'unSpotOffset': '', + 'quickMgnType': '' + } + mock_request.assert_called_once_with(c.GET, c.MAX_AVAIL_SIZE, expected_params) + self.assertEqual(result, mock_response) + + @patch.object(AccountAPI, '_request_with_params') + def test_get_max_avail_size_with_tradeQuoteCcy(self, mock_request): + """Test get_max_avail_size with tradeQuoteCcy parameter for Unified USD Orderbook""" + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + result = self.account_api.get_max_avail_size( + instId='BTC-USD', + tdMode='cash', + tradeQuoteCcy='USDC' + ) + + expected_params = { + 'instId': 'BTC-USD', + 'tdMode': 'cash', + 'ccy': '', + 'reduceOnly': '', + 'unSpotOffset': '', + 'quickMgnType': '', + 'tradeQuoteCcy': 'USDC' + } + mock_request.assert_called_once_with(c.GET, c.MAX_AVAIL_SIZE, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_get_max_avail_size_without_tradeQuoteCcy(self, mock_request): + """Test get_max_avail_size without tradeQuoteCcy (should not include in params)""" + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + result = self.account_api.get_max_avail_size( + instId='BTC-USDT', + tdMode='cash' + ) + + call_args = mock_request.call_args[0][2] + self.assertNotIn('tradeQuoteCcy', call_args) + + if __name__ == '__main__': unittest.main()