diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..0f777f4 --- /dev/null +++ b/.env.example @@ -0,0 +1,10 @@ +# OKX API Credentials +# Copy this file to .env and fill in your actual credentials +# NEVER commit .env to version control! + +OKX_API_KEY=your_api_key_here +OKX_API_SECRET=your_api_secret_here +OKX_PASSPHRASE=your_passphrase_here + +# Optional: Set to '0' for live trading, '1' for demo trading +OKX_FLAG=1 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..bd4ca7f --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,173 @@ +# GitHub Actions CI/CD Configuration for python-okx +name: CI + +on: + push: + branches: + - master + - 'release/*' + - 'releases/*' + pull_request: + branches: + - master + - 'release/*' + - 'releases/*' + +jobs: + # ============================================ + # DEPENDENCY CHECK JOB + # Ensures all imports are satisfied by requirements.txt + # ============================================ + dependency-check: + name: Dependency Check + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install only from requirements.txt (clean environment) + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + - name: Verify all production imports work + run: | + # This catches missing dependencies in requirements.txt + python -c " + import okx + from okx import Account, Trade, Funding, MarketData, PublicData + from okx import SubAccount, Convert, BlockTrading, CopyTrading + from okx import SpreadTrading, Grid, TradingData, Status + from okx.websocket import WsPublicAsync, WsPrivateAsync + print('✅ All imports successful') + print(f' okx version: {okx.__version__}') + " + + - name: Verify test imports + run: | + pip install pytest + python -c " + import pytest + import unittest + print('✅ Test imports successful') + " + + # ============================================ + # LINT JOB + # ============================================ + lint: + name: Lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install linting tools + run: | + python -m pip install --upgrade pip + pip install ruff + + - name: Run ruff + run: | + ruff check okx/ --ignore=E501 + continue-on-error: true # Set to false once codebase is cleaned up + + # ============================================ + # TEST JOB + # ============================================ + test: + name: Test (Python ${{ matrix.python-version }}) + runs-on: ubuntu-latest + needs: [dependency-check] # Only run tests if dependencies are valid + strategy: + fail-fast: true + max-parallel: 1 + matrix: + python-version: ["3.9", "3.12"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Cache pip dependencies + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pip-${{ matrix.python-version }}- + ${{ runner.os }}-pip- + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install -e . + + - name: Run tests + run: | + python -m pytest test/unit/ -v --cov=okx --cov-report=term-missing --cov-report=xml + + - name: Upload coverage to Codecov + if: matrix.python-version == '3.12' + uses: codecov/codecov-action@v4 + with: + file: ./coverage.xml + fail_ci_if_error: false + + # ============================================ + # BUILD JOB + # ============================================ + build: + name: Build Package + runs-on: ubuntu-latest + needs: [test] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + - name: Build package + run: python -m build --no-isolation + + - name: Check package + run: twine check dist/* + + - name: Test install from wheel (clean environment) + run: | + # Create a fresh venv and install the built wheel + python -m venv /tmp/test-install + /tmp/test-install/bin/pip install dist/*.whl + /tmp/test-install/bin/python -c " + import okx + from okx import Account, Trade, Funding + print(f'✅ Package installs correctly: okx {okx.__version__}') + " + + - name: Upload build artifacts + uses: actions/upload-artifact@v4 + with: + name: dist + path: dist/ + retention-days: 7 diff --git a/.gitignore b/.gitignore index 10c5a42..41c5c7c 100644 --- a/.gitignore +++ b/.gitignore @@ -32,4 +32,9 @@ build/ ### VS Code ### .vscode/ -id_rsa* \ No newline at end of file +id_rsa* + +# Environment files +.env +.env.local +.env.*.local \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..23682e9 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,3 @@ +include requirements.txt +include README.md + diff --git a/README.md b/README.md index 19aac31..3e2d5b0 100644 --- a/README.md +++ b/README.md @@ -20,20 +20,80 @@ Make sure you update often and check the [Changelog](https://www.okx.com/docs-v5 ### Quick start #### Prerequisites -`python version:>=3.9` +`python version:>=3.7` -`WebSocketAPI: websockets package advise version 6.0` - -#### Step 1: register an account on OKX and apply for an API key +#### Step 1: Register an account on OKX and apply for an API key - Register for an account: https://www.okx.com/account/register - Apply for an API key: https://www.okx.com/account/users/myApi -#### Step 2: install python-okx +#### Step 2: Install python-okx -```python +```bash pip install python-okx ``` +### API Credentials + +#### Option 1: Hardcoded credentials + +```python +from okx import Account + +account = Account.AccountAPI( + api_key="your-api-key-here", + api_secret_key="your-api-secret-here", + passphrase="your-passphrase-here", + flag="1", # 0 = live trading, 1 = demo trading + debug=False +) +``` + +#### Option 2: Using `.env` file (recommended) + +Create a `.env` file in your project root: + +```bash +OKX_API_KEY=your-api-key-here +OKX_API_SECRET=your-api-secret-here +OKX_PASSPHRASE=your-passphrase-here +OKX_FLAG=1 +``` + +Then load it in your code: + +```python +import os +from dotenv import load_dotenv +from okx import Account + +load_dotenv() + +account = Account.AccountAPI( + api_key=os.getenv('OKX_API_KEY'), + api_secret_key=os.getenv('OKX_API_SECRET'), + passphrase=os.getenv('OKX_PASSPHRASE'), + flag=os.getenv('OKX_FLAG', '1'), + debug=False +) +``` + +### Development Setup + +For contributors or local development: + +```bash +# Clone the repository +git clone https://github.com/okxapi/python-okx.git +cd python-okx + +# Install dependencies +pip install -r requirements.txt +pip install -e . + +# Run tests +pytest test/unit/ -v +``` + #### Step 3: Run examples - Fill in API credentials in the corresponding examples diff --git a/okx/Account.py b/okx/Account.py index 911cf9b..5fddf9f 100644 --- a/okx/Account.py +++ b/okx/Account.py @@ -27,21 +27,23 @@ def get_positions(self, instType='', instId='', posId=''): params = {'instType': instType, 'instId': instId, 'posId': posId} return self._request_with_params(GET, POSITION_INFO, params) - def position_builder(self, acctLv=None,inclRealPosAndEq=False, lever=None, greeksType=None, simPos=None, - simAsset=None): + def position_builder(self, acctLv=None, inclRealPosAndEq=None, lever=None, greeksType=None, simPos=None, + simAsset=None, idxVol=None): params = {} if acctLv is not None: params['acctLv'] = acctLv if inclRealPosAndEq is not None: params['inclRealPosAndEq'] = inclRealPosAndEq if lever is not None: - params['spotOffsetType'] = lever + params['lever'] = lever if greeksType is not None: - params['greksType'] = greeksType + params['greeksType'] = greeksType if simPos is not None: params['simPos'] = simPos if simAsset is not None: params['simAsset'] = simAsset + if idxVol is not None: + params['idxVol'] = idxVol return self._request_with_params(POST, POSITION_BUILDER, params) # Get Bills Details (recent 7 days) @@ -74,14 +76,18 @@ def set_leverage(self, lever, mgnMode, instId='', ccy='', posSide=''): return self._request_with_params(POST, SET_LEVERAGE, params) # Get Maximum Tradable Size For Instrument - def get_max_order_size(self, instId, tdMode, ccy='', px=''): + def get_max_order_size(self, instId, tdMode, ccy='', px='', tradeQuoteCcy=None): params = {'instId': instId, 'tdMode': tdMode, 'ccy': ccy, 'px': px} + if tradeQuoteCcy is not None: + params['tradeQuoteCcy'] = tradeQuoteCcy return self._request_with_params(GET, MAX_TRADE_SIZE, params) # Get Maximum Available Tradable Amount - def get_max_avail_size(self, instId, tdMode, ccy='', reduceOnly='', unSpotOffset='', quickMgnType=''): + def get_max_avail_size(self, instId, tdMode, ccy='', reduceOnly='', unSpotOffset='', quickMgnType='', tradeQuoteCcy=None): params = {'instId': instId, 'tdMode': tdMode, 'ccy': ccy, 'reduceOnly': reduceOnly, 'unSpotOffset': unSpotOffset, 'quickMgnType': quickMgnType} + if tradeQuoteCcy is not None: + params['tradeQuoteCcy'] = tradeQuoteCcy return self._request_with_params(GET, MAX_AVAIL_SIZE, params) # Increase / Decrease margin @@ -100,8 +106,10 @@ def get_instruments(self, instType='', ugly='', instFamily='', instId=''): return self._request_with_params(GET, GET_INSTRUMENTS, params) # Get the maximum loan of isolated MARGIN - def get_max_loan(self, instId, mgnMode, mgnCcy=''): + def get_max_loan(self, instId, mgnMode, mgnCcy='', tradeQuoteCcy=None): params = {'instId': instId, 'mgnMode': mgnMode, 'mgnCcy': mgnCcy} + if tradeQuoteCcy is not None: + params['tradeQuoteCcy'] = tradeQuoteCcy return self._request_with_params(GET, MAX_LOAN, params) # Get Fee Rates @@ -323,3 +331,9 @@ def set_auto_repay(self, autoRepay=False): def spot_borrow_repay_history(self, ccy='', type='', after='', before='', limit=''): params = {'ccy': ccy, 'type': type, 'after': after, 'before': before, 'limit': limit} return self._request_with_params(GET, GET_BORROW_REPAY_HISTORY, params) + + def set_auto_earn(self, ccy, action, earnType=None): + params = {'ccy': ccy, 'action': action} + if earnType is not None: + params['earnType'] = earnType + return self._request_with_params(POST, SET_AUTO_EARN, params) diff --git a/okx/Funding.py b/okx/Funding.py index 6ca0344..a984778 100644 --- a/okx/Funding.py +++ b/okx/Funding.py @@ -35,9 +35,11 @@ def funds_transfer(self, ccy, amt, from_, to, type='0', subAcct='', instId='', t return self._request_with_params(POST, FUNDS_TRANSFER, params) # Withdrawal - def withdrawal(self, ccy, amt, dest, toAddr, chain='', areaCode='', clientId=''): + def withdrawal(self, ccy, amt, dest, toAddr, chain='', areaCode='', clientId='', toAddrType=None): params = {'ccy': ccy, 'amt': amt, 'dest': dest, 'toAddr': toAddr, 'chain': chain, 'areaCode': areaCode, 'clientId': clientId} + if toAddrType is not None: + params['toAddrType'] = toAddrType return self._request_with_params(POST, WITHDRAWAL_COIN, params) # Get Deposit History @@ -46,10 +48,7 @@ def get_deposit_history(self, ccy='', type='', state='', after='', before='', li 'depId': depId, 'fromWdId': fromWdId} return self._request_with_params(GET, DEPOSIT_HISTORY, params) - # Get Withdrawal History - def get_withdrawal_history(self, ccy='', wdId='', state='', after='', before='', limit='',txId=''): - params = {'ccy': ccy, 'wdId': wdId, 'state': state, 'after': after, 'before': before, 'limit': limit,'txId':txId} - return self._request_with_params(GET, WITHDRAWAL_HISTORY, params) + # Get Currencies def get_currencies(self, ccy=''): @@ -113,7 +112,9 @@ def get_deposit_withdraw_status(self, wdId='', txId='', ccy='', to='', chain='') return self._request_with_params(GET, GET_DEPOSIT_WITHDrAW_STATUS, params) #Get withdrawal history - def get_withdrawal_history(self, ccy='', wdId='', clientId='', txId='', type='', state='', after='', before ='', limit=''): + def get_withdrawal_history(self, ccy='', wdId='', clientId='', txId='', type='', state='', after='', before='', limit='', toAddrType=None): params = {'ccy': ccy, 'wdId': wdId, 'clientId': clientId, 'txId': txId, 'type': type, 'state': state, 'after': after, 'before': before, 'limit': limit} + if toAddrType is not None: + params['toAddrType'] = toAddrType return self._request_with_params(GET, GET_WITHDRAWAL_HISTORY, params) diff --git a/okx/Grid.py b/okx/Grid.py index 74b6a9b..d761708 100644 --- a/okx/Grid.py +++ b/okx/Grid.py @@ -7,11 +7,13 @@ def __init__(self, api_key='-1', api_secret_key='-1', passphrase='-1', use_serve OkxClient.__init__(self, api_key, api_secret_key, passphrase, use_server_time, flag, domain, debug, proxy) def grid_order_algo(self, instId='', algoOrdType='', maxPx='', minPx='', gridNum='', runType='', tpTriggerPx='', - slTriggerPx='', tag='', quoteSz='', baseSz='', sz='', direction='', lever='', basePos=''): + slTriggerPx='', tag='', quoteSz='', baseSz='', sz='', direction='', lever='', basePos='', tradeQuoteCcy=None): params = {'instId': instId, 'algoOrdType': algoOrdType, 'maxPx': maxPx, 'minPx': minPx, 'gridNum': gridNum, 'runType': runType, 'tpTriggerPx': tpTriggerPx, 'slTriggerPx': slTriggerPx, 'tag': tag, 'quoteSz': quoteSz, 'baseSz': baseSz, 'sz': sz, 'direction': direction, 'lever': lever, 'basePos': basePos} + if tradeQuoteCcy is not None: + params['tradeQuoteCcy'] = tradeQuoteCcy return self._request_with_params(POST, GRID_ORDER_ALGO, params) def grid_amend_order_algo(self, algoId='', instId='', slTriggerPx='', tpTriggerPx=''): @@ -79,11 +81,13 @@ def grid_ai_param(self, algoOrdType='', instId='', direction='', duration=''): # - Place recurring buy order def place_recurring_buy_order(self, stgyName='', recurringList=[], period='', recurringDay='', recurringTime='', - timeZone='', amt='', investmentCcy='', tdMode='', algoClOrdId='', tag=''): + timeZone='', amt='', investmentCcy='', tdMode='', algoClOrdId='', tag='', tradeQuoteCcy=None): params = {'stgyName': stgyName, 'recurringList': recurringList, 'period': period, 'recurringDay': recurringDay, 'recurringTime': recurringTime, 'timeZone': timeZone, 'amt': amt, 'investmentCcy': investmentCcy, 'tdMode': tdMode, 'algoClOrdId': algoClOrdId, 'tag': tag} + if tradeQuoteCcy is not None: + params['tradeQuoteCcy'] = tradeQuoteCcy return self._request_with_params(POST, PLACE_RECURRING_BUY_ORDER, params) # - Amend recurring buy order diff --git a/okx/PublicData.py b/okx/PublicData.py index b3bb961..f441266 100644 --- a/okx/PublicData.py +++ b/okx/PublicData.py @@ -122,3 +122,18 @@ def get_option_trades(self, instId='', instFamily='', optType=''): 'optType': optType } return self._request_with_params(GET, GET_OPTION_TRADES, params) + + # Get historical market data + def get_market_data_history(self, module, instType, dateAggrType, begin, end, instIdList=None, instFamilyList=None): + params = { + 'module': module, + 'instType': instType, + 'dateAggrType': dateAggrType, + 'begin': begin, + 'end': end + } + if instIdList is not None: + params['instIdList'] = instIdList + if instFamilyList is not None: + params['instFamilyList'] = instFamilyList + return self._request_with_params(GET, MARKET_DATA_HISTORY, params) diff --git a/okx/Trade.py b/okx/Trade.py index db994ae..9d6d9e5 100644 --- a/okx/Trade.py +++ b/okx/Trade.py @@ -12,10 +12,14 @@ def __init__(self, api_key='-1', api_secret_key='-1', passphrase='-1', use_serve # Place Order def place_order(self, instId, tdMode, side, ordType, sz, ccy='', clOrdId='', tag='', posSide='', px='', - reduceOnly='', tgtCcy='', stpMode='', attachAlgoOrds=None, pxUsd='', pxVol='', banAmend='', tradeQuoteCcy=''): + reduceOnly='', tgtCcy='', stpMode='', attachAlgoOrds=None, pxUsd='', pxVol='', banAmend='', tradeQuoteCcy=None, pxAmendType=None): params = {'instId': instId, 'tdMode': tdMode, 'side': side, 'ordType': ordType, 'sz': sz, 'ccy': ccy, 'clOrdId': clOrdId, 'tag': tag, 'posSide': posSide, 'px': px, 'reduceOnly': reduceOnly, - 'tgtCcy': tgtCcy, 'stpMode': stpMode, 'pxUsd': pxUsd, 'pxVol': pxVol, 'banAmend': banAmend, 'tradeQuoteCcy': tradeQuoteCcy} + 'tgtCcy': tgtCcy, 'stpMode': stpMode, 'pxUsd': pxUsd, 'pxVol': pxVol, 'banAmend': banAmend} + if tradeQuoteCcy is not None: + params['tradeQuoteCcy'] = tradeQuoteCcy + if pxAmendType is not None: + params['pxAmendType'] = pxAmendType params['attachAlgoOrds'] = attachAlgoOrds return self._request_with_params(POST, PLACR_ORDER, params) @@ -35,11 +39,13 @@ def cancel_multiple_orders(self, orders_data): # Amend Order def amend_order(self, instId, cxlOnFail='', ordId='', clOrdId='', reqId='', newSz='', newPx='', newTpTriggerPx='', newTpOrdPx='', newSlTriggerPx='', newSlOrdPx='', newTpTriggerPxType='', newSlTriggerPxType='', - attachAlgoOrds='', newTriggerPx='', newOrdPx=''): + attachAlgoOrds='', newTriggerPx='', newOrdPx='', pxAmendType=None): params = {'instId': instId, 'cxlOnFail': cxlOnFail, 'ordId': ordId, 'clOrdId': clOrdId, 'reqId': reqId, 'newSz': newSz, 'newPx': newPx, 'newTpTriggerPx': newTpTriggerPx, 'newTpOrdPx': newTpOrdPx, 'newSlTriggerPx': newSlTriggerPx, 'newSlOrdPx': newSlOrdPx, 'newTpTriggerPxType': newTpTriggerPxType, 'newSlTriggerPxType': newSlTriggerPxType, 'newTriggerPx': newTriggerPx, 'newOrdPx': newOrdPx} + if pxAmendType is not None: + params['pxAmendType'] = pxAmendType params['attachAlgoOrds'] = attachAlgoOrds return self._request_with_params(POST, AMEND_ORDER, params) @@ -95,8 +101,8 @@ def place_algo_order(self, instId='', tdMode='', side='', ordType='', sz='', ccy pxSpread='', szLimit='', pxLimit='', timeInterval='', tpTriggerPxType='', slTriggerPxType='', callbackRatio='', callbackSpread='', activePx='', tag='', triggerPxType='', closeFraction='' - , quickMgnType='', algoClOrdId='', tradeQuoteCcy='', tpOrdKind='', cxlOnClosePos='' - , chaseType='', chaseVal='', maxChaseType='', maxChaseVal='', attachAlgoOrds=[]): + , quickMgnType='', algoClOrdId='', tradeQuoteCcy=None, tpOrdKind='', cxlOnClosePos='' + , chaseType='', chaseVal='', maxChaseType='', maxChaseVal='', attachAlgoOrds=[], pxAmendType=None): params = {'instId': instId, 'tdMode': tdMode, 'side': side, 'ordType': ordType, 'sz': sz, 'ccy': ccy, 'posSide': posSide, 'reduceOnly': reduceOnly, 'tpTriggerPx': tpTriggerPx, 'tpOrdPx': tpOrdPx, 'slTriggerPx': slTriggerPx, 'slOrdPx': slOrdPx, 'triggerPx': triggerPx, 'orderPx': orderPx, @@ -105,9 +111,13 @@ def place_algo_order(self, instId='', tdMode='', side='', ordType='', sz='', ccy 'pxSpread': pxSpread, 'tpTriggerPxType': tpTriggerPxType, 'slTriggerPxType': slTriggerPxType, 'callbackRatio': callbackRatio, 'callbackSpread': callbackSpread, 'activePx': activePx, 'tag': tag, 'triggerPxType': triggerPxType, 'closeFraction': closeFraction, - 'quickMgnType': quickMgnType, 'algoClOrdId': algoClOrdId, 'tradeQuoteCcy': tradeQuoteCcy, + 'quickMgnType': quickMgnType, 'algoClOrdId': algoClOrdId, 'tpOrdKind': tpOrdKind, 'cxlOnClosePos': cxlOnClosePos, 'chaseType': chaseType, 'chaseVal': chaseVal, 'maxChaseType': maxChaseType, 'maxChaseVal': maxChaseVal, 'attachAlgoOrds': attachAlgoOrds} + if tradeQuoteCcy is not None: + params['tradeQuoteCcy'] = tradeQuoteCcy + if pxAmendType is not None: + params['pxAmendType'] = pxAmendType return self._request_with_params(POST, PLACE_ALGO_ORDER, params) # Cancel Algo Order diff --git a/okx/__init__.py b/okx/__init__.py index 2dbeeb8..57f92fe 100644 --- a/okx/__init__.py +++ b/okx/__init__.py @@ -2,4 +2,4 @@ Python SDK for the OKX API v5 """ -__version__="0.4.0" \ No newline at end of file +__version__="0.4.1" \ No newline at end of file diff --git a/okx/consts.py b/okx/consts.py index aae0778..cef00ef 100644 --- a/okx/consts.py +++ b/okx/consts.py @@ -66,6 +66,7 @@ MANUAL_REBORROW_REPAY = '/api/v5/account/spot-manual-borrow-repay' SET_AUTO_REPAY='/api/v5/account/set-auto-repay' GET_BORROW_REPAY_HISTORY='/api/v5/account/spot-borrow-repay-history' +SET_AUTO_EARN='/api/v5/account/set-auto-earn' # Funding NON_TRADABLE_ASSETS = '/api/v5/asset/non-tradable-assets' @@ -81,7 +82,6 @@ DEPOSIT_LIGHTNING = '/api/v5/asset/deposit-lightning' WITHDRAWAL_LIGHTNING = '/api/v5/asset/withdrawal-lightning' CANCEL_WITHDRAWAL = '/api/v5/asset/cancel-withdrawal' -WITHDRAWAL_HISTORY = '/api/v5/asset/withdrawal-history' CONVERT_DUST_ASSETS = '/api/v5/asset/convert-dust-assets' ASSET_VALUATION = '/api/v5/asset/asset-valuation' GET_WITHDRAWAL_HISTORY = '/api/v5/asset/withdrawal-history' @@ -130,6 +130,7 @@ CONVERT_CONTRACT_COIN = '/api/v5/public/convert-contract-coin' GET_OPTION_TICKBANDS = '/api/v5/public/instrument-tick-bands' GET_OPTION_TRADES = '/api/v5/public/option-trades' +MARKET_DATA_HISTORY = '/api/v5/public/market-data-history' # Trading data SUPPORT_COIN = '/api/v5/rubik/stat/trading-data/support-coin' diff --git a/okx/okxclient.py b/okx/okxclient.py index e88e3a7..6cd4751 100644 --- a/okx/okxclient.py +++ b/okx/okxclient.py @@ -14,7 +14,16 @@ class OkxClient(Client): def __init__(self, api_key='-1', api_secret_key='-1', passphrase='-1', use_server_time=None, flag='1',base_api=c.API_URL, debug=False, proxy=None): - super().__init__(base_url=base_api, http2=True, proxy=proxy) + # Compatible with different versions of httpx + # New versions (0.24.0+) use proxy, older versions use proxies + try: + super().__init__(base_url=base_api, http2=True, proxy=proxy) + except TypeError: + # Older versions of httpx use proxies parameter + if proxy: + super().__init__(base_url=base_api, http2=True, proxies={'http://': proxy, 'https://': proxy}) + else: + super().__init__(base_url=base_api, http2=True) self.API_KEY = api_key self.API_SECRET_KEY = api_secret_key self.PASSPHRASE = passphrase diff --git a/okx/websocket/WsPrivateAsync.py b/okx/websocket/WsPrivateAsync.py index c5359aa..b7e328e 100644 --- a/okx/websocket/WsPrivateAsync.py +++ b/okx/websocket/WsPrivateAsync.py @@ -1,6 +1,7 @@ import asyncio import json import logging +import warnings from okx.websocket import WsUtils from okx.websocket.WebSocketFactory import WebSocketFactory @@ -9,7 +10,7 @@ class WsPrivateAsync: - def __init__(self, apiKey, passphrase, secretKey, url, useServerTime): + def __init__(self, apiKey, passphrase, secretKey, url, useServerTime=None, debug=False): self.url = url self.subscriptions = set() self.callback = None @@ -18,28 +19,43 @@ def __init__(self, apiKey, passphrase, secretKey, url, useServerTime): self.apiKey = apiKey self.passphrase = passphrase self.secretKey = secretKey - self.useServerTime = useServerTime + self.useServerTime = False self.websocket = None + self.debug = debug + + # Set log level + if debug: + logger.setLevel(logging.DEBUG) + + # Deprecation warning for useServerTime parameter + if useServerTime is not None: + warnings.warn("useServerTime parameter is deprecated. Please remove it.", DeprecationWarning) async def connect(self): self.websocket = await self.factory.connect() async def consume(self): async for message in self.websocket: - logger.debug("Received message: {%s}", message) + if self.debug: + logger.debug("Received message: {%s}", message) if self.callback: self.callback(message) - async def subscribe(self, params: list, callback): + async def subscribe(self, params: list, callback, id: str = None): self.callback = callback logRes = await self.login() await asyncio.sleep(5) if logRes: - payload = json.dumps({ + payload_dict = { "op": "subscribe", "args": params - }) + } + if id is not None: + payload_dict["id"] = id + payload = json.dumps(payload_dict) + if self.debug: + logger.debug(f"subscribe: {payload}") await self.websocket.send(payload) # await self.consume() @@ -50,28 +66,139 @@ async def login(self): passphrase=self.passphrase, secretKey=self.secretKey ) + if self.debug: + logger.debug(f"login: {loginPayload}") await self.websocket.send(loginPayload) return True - async def unsubscribe(self, params: list, callback): + async def unsubscribe(self, params: list, callback, id: str = None): self.callback = callback - payload = json.dumps({ + payload_dict = { "op": "unsubscribe", "args": params - }) - logger.info(f"unsubscribe: {payload}") + } + if id is not None: + payload_dict["id"] = id + payload = json.dumps(payload_dict) + if self.debug: + logger.debug(f"unsubscribe: {payload}") + else: + logger.info(f"unsubscribe: {payload}") + await self.websocket.send(payload) + + async def send(self, op: str, args: list, callback=None, id: str = None): + """ + Generic send method + :param op: Operation type + :param args: Parameter list + :param callback: Callback function + :param id: Optional request ID + """ + if callback: + self.callback = callback + payload_dict = { + "op": op, + "args": args + } + if id is not None: + payload_dict["id"] = id + payload = json.dumps(payload_dict) + if self.debug: + logger.debug(f"send: {payload}") await self.websocket.send(payload) - # for param in params: - # self.subscriptions.discard(param) + + async def place_order(self, args: list, callback=None, id: str = None): + """ + Place order + :param args: Order parameter list + :param callback: Callback function + :param id: Optional request ID + """ + if callback: + self.callback = callback + await self.send("order", args, id=id) + + async def batch_orders(self, args: list, callback=None, id: str = None): + """ + Batch place orders + :param args: Batch order parameter list + :param callback: Callback function + :param id: Optional request ID + """ + if callback: + self.callback = callback + await self.send("batch-orders", args, id=id) + + async def cancel_order(self, args: list, callback=None, id: str = None): + """ + Cancel order + :param args: Cancel order parameter list + :param callback: Callback function + :param id: Optional request ID + """ + if callback: + self.callback = callback + await self.send("cancel-order", args, id=id) + + async def batch_cancel_orders(self, args: list, callback=None, id: str = None): + """ + Batch cancel orders + :param args: Batch cancel order parameter list + :param callback: Callback function + :param id: Optional request ID + """ + if callback: + self.callback = callback + await self.send("batch-cancel-orders", args, id=id) + + async def amend_order(self, args: list, callback=None, id: str = None): + """ + Amend order + :param args: Amend order parameter list + :param callback: Callback function + :param id: Optional request ID + """ + if callback: + self.callback = callback + await self.send("amend-order", args, id=id) + + async def batch_amend_orders(self, args: list, callback=None, id: str = None): + """ + Batch amend orders + :param args: Batch amend order parameter list + :param callback: Callback function + :param id: Optional request ID + """ + if callback: + self.callback = callback + await self.send("batch-amend-orders", args, id=id) + + async def mass_cancel(self, args: list, callback=None, id: str = None): + """ + Mass cancel orders + Note: This method is for /ws/v5/business channel, rate limit: 1 request/second + :param args: Cancel parameter list, contains instType and instFamily + :param callback: Callback function + :param id: Optional request ID + """ + if callback: + self.callback = callback + await self.send("mass-cancel", args, id=id) async def stop(self): await self.factory.close() - self.loop.stop() async def start(self): - logger.info("Connecting to WebSocket...") + if self.debug: + logger.debug("Connecting to WebSocket...") + else: + logger.info("Connecting to WebSocket...") await self.connect() self.loop.create_task(self.consume()) def stop_sync(self): - self.loop.run_until_complete(self.stop()) + if self.loop.is_running(): + future = asyncio.run_coroutine_threadsafe(self.stop(), self.loop) + future.result(timeout=10) + else: + self.loop.run_until_complete(self.stop()) diff --git a/okx/websocket/WsPublicAsync.py b/okx/websocket/WsPublicAsync.py index e576d65..d7eefdc 100644 --- a/okx/websocket/WsPublicAsync.py +++ b/okx/websocket/WsPublicAsync.py @@ -2,55 +2,124 @@ import json import logging +from okx.websocket import WsUtils from okx.websocket.WebSocketFactory import WebSocketFactory logger = logging.getLogger(__name__) class WsPublicAsync: - def __init__(self, url): + def __init__(self, url, apiKey='', passphrase='', secretKey='', debug=False): self.url = url self.subscriptions = set() self.callback = None self.loop = asyncio.get_event_loop() self.factory = WebSocketFactory(url) self.websocket = None + self.debug = debug + # Credentials for business channel login + self.apiKey = apiKey + self.passphrase = passphrase + self.secretKey = secretKey + self.isLoggedIn = False + + # Set log level + if debug: + logger.setLevel(logging.DEBUG) async def connect(self): self.websocket = await self.factory.connect() async def consume(self): async for message in self.websocket: - logger.debug("Received message: {%s}", message) + if self.debug: + logger.debug("Received message: {%s}", message) if self.callback: self.callback(message) - async def subscribe(self, params: list, callback): + async def login(self): + """ + Login method for business channel that requires authentication (e.g. /ws/v5/business) + """ + if not self.apiKey or not self.secretKey or not self.passphrase: + raise ValueError("apiKey, secretKey and passphrase are required for login") + + loginPayload = WsUtils.initLoginParams( + useServerTime=False, + apiKey=self.apiKey, + passphrase=self.passphrase, + secretKey=self.secretKey + ) + if self.debug: + logger.debug(f"login: {loginPayload}") + await self.websocket.send(loginPayload) + self.isLoggedIn = True + return True + + async def subscribe(self, params: list, callback, id: str = None): self.callback = callback - payload = json.dumps({ + payload_dict = { "op": "subscribe", "args": params - }) + } + if id is not None: + payload_dict["id"] = id + payload = json.dumps(payload_dict) + if self.debug: + logger.debug(f"subscribe: {payload}") await self.websocket.send(payload) # await self.consume() - async def unsubscribe(self, params: list, callback): + async def unsubscribe(self, params: list, callback, id: str = None): self.callback = callback - payload = json.dumps({ + payload_dict = { "op": "unsubscribe", "args": params - }) - logger.info(f"unsubscribe: {payload}") + } + if id is not None: + payload_dict["id"] = id + payload = json.dumps(payload_dict) + if self.debug: + logger.debug(f"unsubscribe: {payload}") + else: + logger.info(f"unsubscribe: {payload}") + await self.websocket.send(payload) + + async def send(self, op: str, args: list, callback=None, id: str = None): + """ + Generic send method + :param op: Operation type + :param args: Parameter list + :param callback: Callback function + :param id: Optional request ID + """ + if callback: + self.callback = callback + payload_dict = { + "op": op, + "args": args + } + if id is not None: + payload_dict["id"] = id + payload = json.dumps(payload_dict) + if self.debug: + logger.debug(f"send: {payload}") await self.websocket.send(payload) async def stop(self): await self.factory.close() - self.loop.stop() async def start(self): - logger.info("Connecting to WebSocket...") + if self.debug: + logger.debug("Connecting to WebSocket...") + else: + logger.info("Connecting to WebSocket...") await self.connect() self.loop.create_task(self.consume()) def stop_sync(self): - self.loop.run_until_complete(self.stop()) + if self.loop.is_running(): + future = asyncio.run_coroutine_threadsafe(self.stop(), self.loop) + future.result(timeout=10) + else: + self.loop.run_until_complete(self.stop()) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..b9a5daf --- /dev/null +++ b/requirements.txt @@ -0,0 +1,15 @@ +# Dependencies for python-okx +httpx[http2]>=0.24.0 +requests>=2.25.0 +websockets>=10.0 +certifi>=2021.0.0 +loguru>=0.7.0 +python-dotenv>=1.0.0 + +# Development & Testing +pytest>=7.0.0 +pytest-asyncio>=0.21.0 +pytest-cov>=4.0.0 +ruff>=0.1.0 +build>=1.0.0 +twine>=4.0.0 diff --git a/setup.py b/setup.py index b91e237..57568e0 100644 --- a/setup.py +++ b/setup.py @@ -1,8 +1,39 @@ +import os import setuptools + +# Get the directory where setup.py is located +HERE = os.path.dirname(os.path.abspath(__file__)) + +# Read version from package import okx -with open("README.md", "r",encoding="utf-8") as fh: + +# Read README +with open(os.path.join(HERE, "README.md"), "r", encoding="utf-8") as fh: long_description = fh.read() + +def parse_requirements(): + """Parse requirements from requirements.txt.""" + requirements = [] + req_path = os.path.join(HERE, "requirements.txt") + + if not os.path.exists(req_path): + return requirements + + with open(req_path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + # Skip empty lines and comments + if not line or line.startswith("#"): + continue + # Handle inline comments + if "#" in line: + line = line.split("#")[0].strip() + requirements.append(line) + + return requirements + + setuptools.setup( name="python-okx", version=okx.__version__, @@ -12,19 +43,18 @@ long_description=long_description, long_description_content_type="text/markdown", url="https://okx.com/docs-v5/", - packages=setuptools.find_packages(), + packages=setuptools.find_packages(exclude=["test", "test.*", "example"]), + python_requires=">=3.7", classifiers=[ "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ], - install_requires=[ - "importlib-metadata", - "httpx[http2]", - "keyring", - "loguru", - "requests", - "Twisted", - "pyOpenSSL" - ] -) \ No newline at end of file + install_requires=parse_requirements(), +) diff --git a/test/WsPrivateAsyncTest.py b/test/WsPrivateAsyncTest.py deleted file mode 100644 index ba7fcff..0000000 --- a/test/WsPrivateAsyncTest.py +++ /dev/null @@ -1,39 +0,0 @@ -import asyncio - -from okx.websocket.WsPrivateAsync import WsPrivateAsync - - -def privateCallback(message): - print("privateCallback", message) - - -async def main(): - url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" - ws = WsPrivateAsync( - apiKey="your apiKey", - passphrase="your passphrase", - secretKey="your secretKey", - url=url, - useServerTime=False - ) - await ws.start() - args = [] - arg1 = {"channel": "account", "ccy": "BTC"} - arg2 = {"channel": "orders", "instType": "ANY"} - arg3 = {"channel": "balance_and_position"} - args.append(arg1) - args.append(arg2) - args.append(arg3) - await ws.subscribe(args, callback=privateCallback) - await asyncio.sleep(30) - print("-----------------------------------------unsubscribe--------------------------------------------") - args2 = [arg2] - await ws.unsubscribe(args2, callback=privateCallback) - await asyncio.sleep(30) - print("-----------------------------------------unsubscribe all--------------------------------------------") - args3 = [arg1, arg3] - await ws.unsubscribe(args3, callback=privateCallback) - - -if __name__ == '__main__': - asyncio.run(main()) diff --git a/test/WsPublicAsyncTest.py b/test/WsPublicAsyncTest.py deleted file mode 100644 index 14276a0..0000000 --- a/test/WsPublicAsyncTest.py +++ /dev/null @@ -1,37 +0,0 @@ -import asyncio - -from okx.websocket.WsPublicAsync import WsPublicAsync - - -def publicCallback(message): - print("publicCallback", message) - - -async def main(): - - # url = "wss://wspap.okex.com:8443/ws/v5/public?brokerId=9999" - url = "wss://wspap.okx.com:8443/ws/v5/public?brokerId=9999" - ws = WsPublicAsync(url=url) - await ws.start() - args = [] - arg1 = {"channel": "instruments", "instType": "FUTURES"} - arg2 = {"channel": "instruments", "instType": "SPOT"} - arg3 = {"channel": "tickers", "instId": "BTC-USDT-SWAP"} - arg4 = {"channel": "tickers", "instId": "ETH-USDT"} - args.append(arg1) - args.append(arg2) - args.append(arg3) - args.append(arg4) - await ws.subscribe(args, publicCallback) - await asyncio.sleep(5) - print("-----------------------------------------unsubscribe--------------------------------------------") - args2 = [arg4] - await ws.unsubscribe(args2, publicCallback) - await asyncio.sleep(5) - print("-----------------------------------------unsubscribe all--------------------------------------------") - args3 = [arg1, arg2, arg3] - await ws.unsubscribe(args3, publicCallback) - - -if __name__ == '__main__': - asyncio.run(main()) diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/config.py b/test/config.py new file mode 100644 index 0000000..78885b3 --- /dev/null +++ b/test/config.py @@ -0,0 +1,58 @@ +""" +Test configuration module - loads API credentials from environment variables. + +Usage: + from test.config import get_api_credentials + + api_key, api_secret, passphrase, flag = get_api_credentials() +""" +import os +import logging +from pathlib import Path + +logger = logging.getLogger(__name__) + +# Flag to ensure .env is loaded only once +_env_loaded = False + + +def _load_env_once(): + """Load .env file only once, log any exceptions.""" + global _env_loaded + if _env_loaded: + return + + _env_loaded = True + env_path = Path(__file__).parent.parent / '.env' + + try: + from dotenv import load_dotenv + if env_path.exists(): + load_dotenv(env_path) + logger.debug(f"Loaded .env file from: {env_path}") + else: + logger.warning(f".env file not found at: {env_path}") + except ImportError: + logger.warning("python-dotenv not installed, relying on system environment variables") + except Exception as e: + logger.error(f"Failed to load .env file: {e}") + + +# Load .env when module is imported +_load_env_once() + + +def get_api_credentials(): + """ + Get API credentials from environment variables. + + Returns: + tuple: (api_key, api_secret, passphrase, flag) + """ + api_key = os.getenv('OKX_API_KEY', '') + api_secret = os.getenv('OKX_API_SECRET', '') + passphrase = os.getenv('OKX_PASSPHRASE', '') + flag = os.getenv('OKX_FLAG', '1') # Default to demo trading + + return api_key, api_secret, passphrase, flag + diff --git a/test/AccountTest.py b/test/test_account.py similarity index 93% rename from test/AccountTest.py rename to test/test_account.py index 724112b..ef71181 100644 --- a/test/AccountTest.py +++ b/test/test_account.py @@ -3,14 +3,13 @@ from loguru import logger from okx import Account +from test.config import get_api_credentials class AccountTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' - self.AccountAPI = Account.AccountAPI(api_key, api_secret_key, passphrase, flag='1') + api_key, api_secret_key, passphrase, flag = get_api_credentials() + self.AccountAPI = Account.AccountAPI(api_key, api_secret_key, passphrase, flag=flag) # ''' # POSITIONS_HISTORY = '/api/v5/account/positions-history' #need add @@ -146,6 +145,15 @@ def setUp(self): # logger.info(f'{self.AccountAPI.set_auto_repay(autoRepay=True)}') # def test_spot_borrow_repay_history(self): # logger.debug(self.AccountAPI.spot_borrow_repay_history(ccy="USDT",type="auto_borrow",after="1597026383085")) + # def test_set_auto_earn(self): + # logger.debug(self.AccountAPI.set_auto_earn(ccy="USDT", action="turn_on", earnType='0')) + #def test_get_max_loan_with_trade_quote_ccy(self): + # logger.debug(self.AccountAPI.get_max_loan( + # instId="BTC-USDT", + # mgnMode="isolated", + # mgnCcy="USDT", + # tradeQuoteCcy="USDT" + # )) if __name__ == '__main__': unittest.main() diff --git a/test/BlockTradingTest.py b/test/test_block_trading.py similarity index 93% rename from test/BlockTradingTest.py rename to test/test_block_trading.py index e8b0e88..1c02542 100644 --- a/test/BlockTradingTest.py +++ b/test/test_block_trading.py @@ -1,12 +1,13 @@ import unittest from okx import BlockTrading +from test.config import get_api_credentials + + class BlockTradingTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' - self.BlockTradingAPI = BlockTrading.BlockTradingAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag='1') + api_key, api_secret_key, passphrase, flag = get_api_credentials() + self.BlockTradingAPI = BlockTrading.BlockTradingAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag=flag) """ def test_get_counter_parties(self): diff --git a/test/ConvertTest.py b/test/test_convert.py similarity index 83% rename from test/ConvertTest.py rename to test/test_convert.py index aa1d175..fc1f048 100644 --- a/test/ConvertTest.py +++ b/test/test_convert.py @@ -1,11 +1,12 @@ import unittest -from ..okx import Convert +from okx import Convert +from test.config import get_api_credentials + + class ConvertTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' - self.ConvertAPI = Convert.ConvertAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag='1') + api_key, api_secret_key, passphrase, flag = get_api_credentials() + self.ConvertAPI = Convert.ConvertAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag=flag) ''' def test_get_currencies(self): diff --git a/test/CopyTradingTest.py b/test/test_copy_trading.py similarity index 88% rename from test/CopyTradingTest.py rename to test/test_copy_trading.py index 95c1739..b516bc5 100644 --- a/test/CopyTradingTest.py +++ b/test/test_copy_trading.py @@ -1,13 +1,13 @@ import unittest from okx import CopyTrading +from test.config import get_api_credentials + class CopyTradingTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' + api_key, api_secret_key, passphrase, flag = get_api_credentials() self.StackingAPI = CopyTrading.CopyTradingAPI(api_key, api_secret_key, passphrase, use_server_time=False, - flag='0') + flag=flag) # def test_get_existing_leading_positions(self): # print(self.StackingAPI.get_existing_leading_positions(instId='DOGE-USDT-SWAP')) diff --git a/test/EthStakingTest.py b/test/test_eth_staking.py similarity index 82% rename from test/EthStakingTest.py rename to test/test_eth_staking.py index 3be4860..f280d1d 100644 --- a/test/EthStakingTest.py +++ b/test/test_eth_staking.py @@ -1,12 +1,12 @@ import unittest from okx.Finance import EthStaking +from test.config import get_api_credentials + class EthStakingTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' - self.StackingAPI = EthStaking.EthStakingAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag='1') + api_key, api_secret_key, passphrase, flag = get_api_credentials() + self.StackingAPI = EthStaking.EthStakingAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag=flag) def test_eth_product_info(self): print(self.StackingAPI.eth_product_info()) diff --git a/test/FlexibleLoanTest.py b/test/test_flexible_loan.py similarity index 84% rename from test/FlexibleLoanTest.py rename to test/test_flexible_loan.py index b4320a5..517b194 100644 --- a/test/FlexibleLoanTest.py +++ b/test/test_flexible_loan.py @@ -1,12 +1,12 @@ import unittest from okx.Finance import FlexibleLoan +from test.config import get_api_credentials + class FlexibleLoanTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' - self.FlexibleLoanAPI = FlexibleLoan.FlexibleLoanAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag='1') + api_key, api_secret_key, passphrase, flag = get_api_credentials() + self.FlexibleLoanAPI = FlexibleLoan.FlexibleLoanAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag=flag) def test_borrow_currencies(self): print(self.FlexibleLoanAPI.borrow_currencies()) diff --git a/test/FundingTest.py b/test/test_funding.py similarity index 82% rename from test/FundingTest.py rename to test/test_funding.py index e87bb76..3247f65 100644 --- a/test/FundingTest.py +++ b/test/test_funding.py @@ -1,13 +1,13 @@ import unittest from okx import Funding +from test.config import get_api_credentials + class FundingTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' - self.FundingAPI = Funding.FundingAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag='0') + api_key, api_secret_key, passphrase, flag = get_api_credentials() + self.FundingAPI = Funding.FundingAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag=flag) """ CANCEL_WITHDRAWAL = '/api/v5/asset/cancel-withdrawal' #need add CONVERT_DUST_ASSETS = '/api/v5/asset/convert-dust-assets' #need add @@ -78,7 +78,16 @@ def test_get_lending_summary(self): # print(self.FundingAPI.get_deposit_history()) def test_withdrawal(self): - print(self.FundingAPI.withdrawal(ccy='USDT',amt='1',dest='3',toAddr='18740405107',areaCode='86')) + # toAddrType: Address type + # 1: Wallet address, email, phone number or login account + # 2: UID (only applicable when dest=3) + print(self.FundingAPI.withdrawal(ccy='USDT', amt='1', dest='3', toAddr='18740405107', areaCode='86', toAddrType='1')) + + def test_get_withdrawal_history_with_toAddrType(self): + # toAddrType: Address type filter + # 1: Wallet address, email, phone number or login account + # 2: UID + print(self.FundingAPI.get_withdrawal_history(toAddrType='1')) if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/test/GridTest.py b/test/test_grid.py similarity index 77% rename from test/GridTest.py rename to test/test_grid.py index 6e8f7da..fb14c95 100644 --- a/test/GridTest.py +++ b/test/test_grid.py @@ -1,13 +1,13 @@ import unittest from okx import Grid +from test.config import get_api_credentials + class GridTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' - self.GridAPI = Grid.GridAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag='1', debug=False) + api_key, api_secret_key, passphrase, flag = get_api_credentials() + self.GridAPI = Grid.GridAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag=flag, debug=False) """ GRID_COMPUTE_MARIGIN_BALANCE = '/api/v5/tradingBot/grid/compute-margin-balance' GRID_MARGIN_BALANCE = '/api/v5/tradingBot/grid/margin-balance' @@ -79,7 +79,31 @@ def test_withdrawl_profits(self): # def test_get_recurring_buy_sub_orders(self): # print(self.GridAPI.get_recurring_buy_sub_orders(algoId="581191143417970688")) - #581191143417970688 + #def test_grid_order_algo_with_trade_quote_ccy(self): + # print(self.GridAPI.grid_order_algo( + # instId="BTC-USDT", + # algoOrdType="grid", + # maxPx="45000", + # minPx="20000", + # gridNum="100", + # runType="1", + # quoteSz="50", + # tradeQuoteCcy="USDT" + # )) + + #def test_place_recurring_buy_order_with_trade_quote_ccy(self): + # print(self.GridAPI.place_recurring_buy_order( + # stgyName="test_strategy", + # recurringList=[{'ccy': 'ETH', 'ratio': '1'}], + # period="daily", + # recurringDay='1', + # recurringTime='0', + # timeZone='8', + # amt='100', + # investmentCcy='USDT', + # tdMode='cash', + # tradeQuoteCcy="USDT" + # )) if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/test/MarketTest.py b/test/test_market.py similarity index 93% rename from test/MarketTest.py rename to test/test_market.py index 2a45d7f..f3b6acf 100644 --- a/test/MarketTest.py +++ b/test/test_market.py @@ -12,12 +12,13 @@ BLOCK_TRADES = '/api/v5/market/block-trades'#need to add ''' +from test.config import get_api_credentials + + class MarketAPITest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' - self.MarketApi = MarketData.MarketAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag='1') + api_key, api_secret_key, passphrase, flag = get_api_credentials() + self.MarketApi = MarketData.MarketAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag=flag) ''' def test_index_component(self): diff --git a/test/PublicDataTest.py b/test/test_public_data.py similarity index 75% rename from test/PublicDataTest.py rename to test/test_public_data.py index 7d7449d..79e5d35 100644 --- a/test/PublicDataTest.py +++ b/test/test_public_data.py @@ -1,11 +1,12 @@ import unittest from okx import PublicData +from test.config import get_api_credentials + + class publicDataTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' - self.publicDataApi = PublicData.PublicAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag='1') + api_key, api_secret_key, passphrase, flag = get_api_credentials() + self.publicDataApi = PublicData.PublicAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag=flag) ''' TestCase For: INTEREST_LOAN = '/api/v5/public/interest-rate-loan-quota' #need to add @@ -56,8 +57,23 @@ def test_get_mark_price(self): # def test_get_option_tickBands(self): # print(self.publicDataApi.get_option_tick_bands(instType='OPTION')) - def test_get_option_trades(self): - print(self.publicDataApi.get_option_trades(instFamily='BTC-USD')) + # def test_get_option_trades(self): + # print(self.publicDataApi.get_option_trades(instFamily='BTC-USD')) + + def test_get_market_data_history(self): + # module: 数据模块类型 + # 1: Trade history, 2: 1-minute candlestick, 3: Funding rate + # 5: 5000-level orderbook (from Nov 1, 2025), 6: 50-level orderbook + # instType: SPOT, FUTURES, SWAP, OPTION + # dateAggrType: daily, monthly + print(self.publicDataApi.get_market_data_history( + module='6', + instType='SPOT', + dateAggrType='daily', + begin='1761274032000', + end='1761883371133', + instIdList='BTC-USDT' + )) if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/test/SavingsTest.py b/test/test_savings.py similarity index 84% rename from test/SavingsTest.py rename to test/test_savings.py index 9dac910..9baf663 100644 --- a/test/SavingsTest.py +++ b/test/test_savings.py @@ -1,12 +1,12 @@ import unittest from okx.Finance import Savings +from test.config import get_api_credentials + class SavingsTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' - self.StackingAPI = Savings.SavingsAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag='1') + api_key, api_secret_key, passphrase, flag = get_api_credentials() + self.StackingAPI = Savings.SavingsAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag=flag) def test_get_saving_balance(self): diff --git a/test/SolStakingTest.py b/test/test_sol_staking.py similarity index 82% rename from test/SolStakingTest.py rename to test/test_sol_staking.py index c3899d4..110fb67 100644 --- a/test/SolStakingTest.py +++ b/test/test_sol_staking.py @@ -1,12 +1,12 @@ import unittest from okx.Finance import SolStaking +from test.config import get_api_credentials + class SolStakingTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' - self.StackingAPI = SolStaking.SolStakingAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag='1') + api_key, api_secret_key, passphrase, flag = get_api_credentials() + self.StackingAPI = SolStaking.SolStakingAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag=flag) def test_sol_purchase(self): print(self.StackingAPI.sol_purchase(amt="1")) diff --git a/test/SpreadTest.py b/test/test_spread.py similarity index 91% rename from test/SpreadTest.py rename to test/test_spread.py index cbaf6e8..25a596b 100644 --- a/test/SpreadTest.py +++ b/test/test_spread.py @@ -1,11 +1,12 @@ import unittest from okx import SpreadTrading +from test.config import get_api_credentials + + class TradeTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' - self.tradeApi = SpreadTrading.SpreadTradingAPI(api_key, api_secret_key, passphrase, False, '1') + api_key, api_secret_key, passphrase, flag = get_api_credentials() + self.tradeApi = SpreadTrading.SpreadTradingAPI(api_key, api_secret_key, passphrase, False, flag) # def test_place_order(self): # print(self.tradeApi.place_order(sprdId='BTC-USDT_BTC-USDT-SWAP',clOrdId='b15',side='buy',ordType='limit', diff --git a/test/StakingDefiTest.py b/test/test_staking_defi.py similarity index 82% rename from test/StakingDefiTest.py rename to test/test_staking_defi.py index 29ac045..bfe908a 100644 --- a/test/StakingDefiTest.py +++ b/test/test_staking_defi.py @@ -1,14 +1,13 @@ import unittest from okx.Finance import StakingDefi +from test.config import get_api_credentials class StakingDefiTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' + api_key, api_secret_key, passphrase, flag = get_api_credentials() self.StackingAPI = StakingDefi.StakingDefiAPI(api_key, api_secret_key, passphrase, use_server_time=False, - flag='1') + flag=flag) def test_get_offers(self): print(self.StackingAPI.get_offers(ccy="USDT")) diff --git a/test/SubAccountTest.py b/test/test_sub_account.py similarity index 92% rename from test/SubAccountTest.py rename to test/test_sub_account.py index b4281bd..8682194 100644 --- a/test/SubAccountTest.py +++ b/test/test_sub_account.py @@ -1,12 +1,12 @@ import unittest from okx import SubAccount +from test.config import get_api_credentials + class SubAccountTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' - self.SubAccountApi = SubAccount.SubAccountAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag='1') + api_key, api_secret_key, passphrase, flag = get_api_credentials() + self.SubAccountApi = SubAccount.SubAccountAPI(api_key, api_secret_key, passphrase, use_server_time=False, flag=flag) ''' ENTRUST_SUBACCOUNT_LIST = '/api/v5/users/entrust-subaccount-list' #need to add SET_TRSNSFER_OUT = '/api/v5/users/subaccount/set-transfer-out' #need to add diff --git a/test/TradeTest.py b/test/test_trade.py similarity index 82% rename from test/TradeTest.py rename to test/test_trade.py index 3f1b601..6da7a54 100644 --- a/test/TradeTest.py +++ b/test/test_trade.py @@ -1,14 +1,13 @@ import unittest from okx import Trade +from test.config import get_api_credentials class TradeTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' - self.tradeApi = Trade.TradeAPI(api_key, api_secret_key, passphrase, False, '1') + api_key, api_secret_key, passphrase, flag = get_api_credentials() + self.tradeApi = Trade.TradeAPI(api_key, api_secret_key, passphrase, False, flag) # # """ # def test_place_order(self): @@ -234,12 +233,66 @@ def setUp(self): # def test_close_all_positions(self): # print(self.tradeApi.close_positions(instId="BTC-USDT-SWAP", mgnMode="cross",clOrdId='1213124')) - def test_get_oneclick_repay_list_v2(self): - print(self.tradeApi.get_oneclick_repay_list_v2()) - def test_oneclick_repay_v2(self): - print(self.tradeApi.oneclick_repay_v2('BTC',['USDT'])) - def test_oneclick_repay_history_v2(self): - print(self.tradeApi.oneclick_repay_history_v2()) + #def test_get_oneclick_repay_list_v2(self): + # print(self.tradeApi.get_oneclick_repay_list_v2()) + #def test_oneclick_repay_v2(self): + # print(self.tradeApi.oneclick_repay_v2('BTC',['USDT'])) + #def test_oneclick_repay_history_v2(self): + # print(self.tradeApi.oneclick_repay_history_v2()) + + #def test_place_order_with_trade_quote_ccy(self): + # print(self.tradeApi.place_order( + # instId="BTC-USDT", + # tdMode="cash", + # side="buy", + # ordType="limit", + # sz="0.01", + # px="30000", + # tradeQuoteCcy="USDT" + # )) + + #def test_place_order_with_px_amend_type(self): + # print(self.tradeApi.place_order( + # instId="BTC-USDT-SWAP", + # tdMode="cash", + # side="buy", + # ordType="limit", + # sz="1", + # px="30000", + # pxAmendType="1" + # )) + + #def test_amend_order_with_px_amend_type(self): + # print(self.tradeApi.amend_order( + # instId="BTC-USDT-SWAP", + # ordId="123", + # newPx="30500", + # pxAmendType="1" + # )) + + #def test_place_algo_order_with_trade_quote_ccy(self): + # print(self.tradeApi.place_algo_order( + # instId="BTC-USDT-SWAP", + # tdMode="cash", + # side="buy", + # ordType="trigger", + # sz="1", + # triggerPx="30000", + # orderPx="-1", + # tradeQuoteCcy="USDT" + # )) + + #def test_place_algo_order_with_px_amend_type(self): + # print(self.tradeApi.place_algo_order( + # instId="BTC-USDT-SWAP", + # tdMode="cash", + # side="buy", + # ordType="trigger", + # sz="1", + # triggerPx="30000", + # orderPx="-1", + # pxAmendType="1" + # )) if __name__ == '__main__': unittest.main() diff --git a/test/TradingDataTest.py b/test/test_trading_data.py similarity index 86% rename from test/TradingDataTest.py rename to test/test_trading_data.py index ac318b6..0872f37 100644 --- a/test/TradingDataTest.py +++ b/test/test_trading_data.py @@ -1,14 +1,13 @@ import unittest -from ..okx import TradingData +from okx import TradingData +from test.config import get_api_credentials class TradingDataTest(unittest.TestCase): def setUp(self): - api_key = 'your_apiKey' - api_secret_key = 'your_secretKey' - passphrase = 'your_secretKey' + api_key, api_secret_key, passphrase, flag = get_api_credentials() self.TradingDataAPI = TradingData.TradingDataAPI(api_key, api_secret_key, passphrase, use_server_time=False, - flag='1') + flag=flag) """ def test_get_support_coins(self): print(self.TradingDataAPI.get_support_coin()) diff --git a/test/test_ws_private_async.py b/test/test_ws_private_async.py new file mode 100644 index 0000000..10f34bf --- /dev/null +++ b/test/test_ws_private_async.py @@ -0,0 +1,278 @@ +import asyncio + +from okx.websocket.WsPrivateAsync import WsPrivateAsync +from test.config import get_api_credentials + + +def privateCallback(message): + print("privateCallback", message) + + +async def main(): + api_key, api_secret_key, passphrase, _ = get_api_credentials() + url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" + ws = WsPrivateAsync( + apiKey=api_key, + passphrase=passphrase, + secretKey=api_secret_key, + url=url, + debug=True + ) + await ws.start() + args = [] + arg1 = {"channel": "account", "ccy": "BTC"} + arg2 = {"channel": "orders", "instType": "ANY"} + arg3 = {"channel": "balance_and_position"} + # Withdrawal info channel subscription example, supporting the toAddrType parameter + # toAddrType: Address type + # 1: Wallet address, email, phone number or login account name + # 2: UID (applicable only when dest=3) + arg4 = {"channel": "withdrawal-info", "ccy": "USDT", "toAddrType": "1"} + args.append(arg1) + args.append(arg2) + args.append(arg3) + args.append(arg4) + await ws.subscribe(args, callback=privateCallback) + await asyncio.sleep(30) + print("-----------------------------------------unsubscribe--------------------------------------------") + args2 = [arg2] + # Use id parameter to identify unsubscribe request + await ws.unsubscribe(args2, callback=privateCallback, id="privateUnsub001") + await asyncio.sleep(5) + print("-----------------------------------------unsubscribe all--------------------------------------------") + args3 = [arg1, arg3] + await ws.unsubscribe(args3, callback=privateCallback) + await asyncio.sleep(1) + await ws.stop() + + +async def test_place_order(): + """ + Test place order functionality + URL: /ws/v5/private (Rate limit: 60 requests/second) + """ + api_key, api_secret_key, passphrase, _ = get_api_credentials() + url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" + ws = WsPrivateAsync( + apiKey=api_key, + passphrase=passphrase, + secretKey=api_secret_key, + url=url, + debug=True + ) + await ws.start() + await ws.login() + await asyncio.sleep(5) + + # Order parameters + order_args = [{ + "instId": "BTC-USDT", + "tdMode": "cash", + "clOrdId": "client_order_001", + "side": "buy", + "ordType": "limit", + "sz": "0.001", + "px": "30000" + }] + await ws.place_order(order_args, callback=privateCallback, id="order001") + await asyncio.sleep(5) + await ws.stop() + + +async def test_batch_orders(): + """ + Test batch orders functionality + URL: /ws/v5/private (Rate limit: 60 requests/second, max 20 orders) + """ + api_key, api_secret_key, passphrase, _ = get_api_credentials() + url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" + ws = WsPrivateAsync( + apiKey=api_key, + passphrase=passphrase, + secretKey=api_secret_key, + url=url, + debug=True + ) + await ws.start() + await ws.login() + await asyncio.sleep(5) + + # Batch order parameters (max 20) + order_args = [ + { + "instId": "BTC-USDT", + "tdMode": "cash", + "clOrdId": "batch_order_001", + "side": "buy", + "ordType": "limit", + "sz": "0.001", + "px": "30000" + }, + { + "instId": "ETH-USDT", + "tdMode": "cash", + "clOrdId": "batch_order_002", + "side": "buy", + "ordType": "limit", + "sz": "0.01", + "px": "2000" + } + ] + await ws.batch_orders(order_args, callback=privateCallback, id="batchOrder001") + await asyncio.sleep(5) + await ws.stop() + + +async def test_cancel_order(): + """ + Test cancel order functionality + URL: /ws/v5/private (Rate limit: 60 requests/second) + """ + api_key, api_secret_key, passphrase, _ = get_api_credentials() + url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" + ws = WsPrivateAsync( + apiKey=api_key, + passphrase=passphrase, + secretKey=api_secret_key, + url=url, + debug=True + ) + await ws.start() + await ws.login() + await asyncio.sleep(5) + + # Cancel order parameters (either ordId or clOrdId must be provided) + cancel_args = [{ + "instId": "BTC-USDT", + "ordId": "your_order_id" + # Or use "clOrdId": "client_order_001" + }] + await ws.cancel_order(cancel_args, callback=privateCallback, id="cancel001") + await asyncio.sleep(5) + await ws.stop() + + +async def test_batch_cancel_orders(): + """ + Test batch cancel orders functionality + URL: /ws/v5/private (Rate limit: 60 requests/second, max 20 orders) + """ + api_key, api_secret_key, passphrase, _ = get_api_credentials() + url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" + ws = WsPrivateAsync( + apiKey=api_key, + passphrase=passphrase, + secretKey=api_secret_key, + url=url, + debug=True + ) + await ws.start() + await ws.login() + await asyncio.sleep(5) + + cancel_args = [ + {"instId": "BTC-USDT", "ordId": "order_id_1"}, + {"instId": "ETH-USDT", "ordId": "order_id_2"} + ] + await ws.batch_cancel_orders(cancel_args, callback=privateCallback, id="batchCancel001") + await asyncio.sleep(5) + await ws.stop() + + +async def test_amend_order(): + """ + Test amend order functionality + URL: /ws/v5/private (Rate limit: 60 requests/second) + """ + api_key, api_secret_key, passphrase, _ = get_api_credentials() + url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" + ws = WsPrivateAsync( + apiKey=api_key, + passphrase=passphrase, + secretKey=api_secret_key, + url=url, + debug=True + ) + await ws.start() + await ws.login() + await asyncio.sleep(5) + + # Amend order parameters + amend_args = [{ + "instId": "BTC-USDT", + "ordId": "your_order_id", + "newSz": "0.002", + "newPx": "31000" + }] + await ws.amend_order(amend_args, callback=privateCallback, id="amend001") + await asyncio.sleep(5) + await ws.stop() + + +async def test_mass_cancel(): + """ + Test mass cancel functionality + URL: /ws/v5/business (Rate limit: 1 request/second) + Note: This function uses the business channel + """ + api_key, api_secret_key, passphrase, _ = get_api_credentials() + url = "wss://wspap.okx.com:8443/ws/v5/business?brokerId=9999" + ws = WsPrivateAsync( + apiKey=api_key, + passphrase=passphrase, + secretKey=api_secret_key, + url=url, + debug=True + ) + await ws.start() + await ws.login() + await asyncio.sleep(5) + + # Mass cancel parameters + mass_cancel_args = [{ + "instType": "SPOT", + "instFamily": "BTC-USDT" + }] + await ws.mass_cancel(mass_cancel_args, callback=privateCallback, id="massCancel001") + await asyncio.sleep(5) + await ws.stop() + + +async def test_send_method(): + """Test generic send method""" + api_key, api_secret_key, passphrase, _ = get_api_credentials() + url = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999" + ws = WsPrivateAsync( + apiKey=api_key, + passphrase=passphrase, + secretKey=api_secret_key, + url=url, + debug=True + ) + await ws.start() + await ws.login() + await asyncio.sleep(5) + + # Use generic send method to place order - callback must be provided to receive response + order_args = [{ + "instId": "BTC-USDT", + "tdMode": "cash", + "side": "buy", + "ordType": "limit", + "sz": "0.001", + "px": "30000" + }] + await ws.send("order", order_args, callback=privateCallback, id="send001") + await asyncio.sleep(5) + await ws.stop() + + +if __name__ == '__main__': + # asyncio.run(main()) + asyncio.run(test_place_order()) + asyncio.run(test_batch_orders()) + asyncio.run(test_cancel_order()) + asyncio.run(test_batch_cancel_orders()) + asyncio.run(test_amend_order()) + asyncio.run(test_mass_cancel()) # Note: uses business channel + asyncio.run(test_send_method()) diff --git a/test/test_ws_public_async.py b/test/test_ws_public_async.py new file mode 100644 index 0000000..dac6c84 --- /dev/null +++ b/test/test_ws_public_async.py @@ -0,0 +1,81 @@ +import asyncio + +from okx.websocket.WsPublicAsync import WsPublicAsync + + +def publicCallback(message): + print("publicCallback", message) + + +async def main(): + # url = "wss://wspap.okex.com:8443/ws/v5/public?brokerId=9999" + url = "wss://wspap.okx.com:8443/ws/v5/public?brokerId=9999" + ws = WsPublicAsync(url=url, debug=True) # Enable debug logging + await ws.start() + args = [] + arg1 = {"channel": "instruments", "instType": "FUTURES"} + arg2 = {"channel": "instruments", "instType": "SPOT"} + arg3 = {"channel": "tickers", "instId": "BTC-USDT-SWAP"} + arg4 = {"channel": "tickers", "instId": "ETH-USDT"} + args.append(arg1) + args.append(arg2) + args.append(arg3) + args.append(arg4) + # Use id parameter to identify subscribe request, the same id will be returned in response + await ws.subscribe(args, publicCallback, id="sub001") + await asyncio.sleep(5) + print("-----------------------------------------unsubscribe--------------------------------------------") + args2 = [arg4] + # Use id parameter to identify unsubscribe request + await ws.unsubscribe(args2, publicCallback, id="unsub001") + await asyncio.sleep(5) + print("-----------------------------------------unsubscribe all--------------------------------------------") + args3 = [arg1, arg2, arg3] + await ws.unsubscribe(args3, publicCallback) + await asyncio.sleep(1) + await ws.stop() + + +async def test_business_channel_with_login(): + """ + Test business channel login functionality + Business channel requires login to subscribe to certain private data + """ + url = "wss://wspap.okx.com:8443/ws/v5/business?brokerId=9999" + ws = WsPublicAsync( + url=url, + apiKey="your apiKey", + passphrase="your passphrase", + secretKey="your secretKey", + debug=True + ) + await ws.start() + + # Login + await ws.login() + await asyncio.sleep(5) + + # Subscribe to channels that require login + args = [{"channel": "candle1m", "instId": "BTC-USDT"}] + await ws.subscribe(args, publicCallback) + await asyncio.sleep(30) + await ws.stop() + + +async def test_send_method(): + """Test generic send method""" + url = "wss://wspap.okx.com:8443/ws/v5/public?brokerId=9999" + ws = WsPublicAsync(url=url, debug=True) + await ws.start() + + # Use generic send method to subscribe - callback must be provided to receive response + args = [{"channel": "tickers", "instId": "BTC-USDT"}] + await ws.send("subscribe", args, callback=publicCallback, id="send001") + await asyncio.sleep(10) + await ws.stop() + + +if __name__ == '__main__': + # asyncio.run(main()) + # asyncio.run(test_business_channel_with_login()) + asyncio.run(test_send_method()) diff --git a/test/unit/__init__.py b/test/unit/__init__.py new file mode 100644 index 0000000..940cd5e --- /dev/null +++ b/test/unit/__init__.py @@ -0,0 +1,10 @@ +""" +Unit tests package + +Unit tests mirror the source code structure for easy navigation. + +Example: + okx/Account.py -> test/unit/okx/test_account.py + okx/Trade.py -> test/unit/okx/test_trade.py + okx/Finance/Savings.py -> test/unit/okx/Finance/test_savings.py +""" diff --git a/test/unit/okx/__init__.py b/test/unit/okx/__init__.py new file mode 100644 index 0000000..2b43621 --- /dev/null +++ b/test/unit/okx/__init__.py @@ -0,0 +1,2 @@ +"""Unit tests for okx package""" + diff --git a/test/unit/okx/test_account.py b/test/unit/okx/test_account.py new file mode 100644 index 0000000..57f949d --- /dev/null +++ b/test/unit/okx/test_account.py @@ -0,0 +1,697 @@ +""" +Unit tests for okx.Account module + +Mirrors the structure: okx/Account.py -> test/unit/okx/test_account.py +""" +import unittest +from unittest.mock import patch +from okx.Account import AccountAPI +from okx import consts as c + + +class TestAccountAPIPositionBuilder(unittest.TestCase): + """Unit tests for the position_builder method""" + + def setUp(self): + """Set up test fixtures""" + self.api_key = 'test_api_key' + self.api_secret = 'test_api_secret' + self.passphrase = 'test_passphrase' + self.account_api = AccountAPI( + api_key=self.api_key, + api_secret_key=self.api_secret, + passphrase=self.passphrase, + flag='0' + ) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_with_all_parameters(self, mock_request): + """Test position_builder with all parameters provided""" + # Arrange + mock_response = { + 'code': '0', + 'msg': '', + 'data': [{ + 'mmr': '1000', + 'imr': '2000', + 'mmrBf': '900', + 'imrBf': '1900' + }] + } + mock_request.return_value = mock_response + + sim_pos = [{'instId': 'BTC-USDT-SWAP', 'pos': '10', 'avgPx': '50000'}] + sim_asset = [{'ccy': 'USDT', 'amt': '10000'}] + + # Act + result = self.account_api.position_builder( + acctLv='2', + inclRealPosAndEq=True, + lever='5', + greeksType='PA', + simPos=sim_pos, + simAsset=sim_asset, + idxVol='0.05' + ) + + # Assert + expected_params = { + 'acctLv': '2', + 'inclRealPosAndEq': True, + 'lever': '5', + 'greeksType': 'PA', + 'simPos': sim_pos, + 'simAsset': sim_asset, + 'idxVol': '0.05' + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + self.assertEqual(result, mock_response) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_with_idxVol_only(self, mock_request): + """Test position_builder with only idxVol parameter""" + # Arrange + mock_response = { + 'code': '0', + 'msg': '', + 'data': [] + } + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(idxVol='0.1') + + # Assert + expected_params = { + 'idxVol': '0.1' + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + self.assertEqual(result, mock_response) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_negative_idxVol(self, mock_request): + """Test position_builder with negative idxVol (price decrease)""" + # Arrange + mock_response = { + 'code': '0', + 'msg': '', + 'data': [] + } + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(idxVol='-0.05') + + # Assert + expected_params = { + 'idxVol': '-0.05' + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + self.assertEqual(result, mock_response) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_with_no_parameters(self, mock_request): + """Test position_builder with no parameters (all None)""" + # Arrange + mock_response = { + 'code': '0', + 'msg': '', + 'data': [] + } + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder() + + # Assert + # Should pass empty params dict + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, {}) + self.assertEqual(result, mock_response) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_with_simulated_positions(self, mock_request): + """Test position_builder with simulated positions and assets""" + # Arrange + mock_response = { + 'code': '0', + 'msg': '', + 'data': [{ + 'mmr': '5000', + 'imr': '10000' + }] + } + mock_request.return_value = mock_response + + sim_pos = [ + {'instId': 'BTC-USDT-SWAP', 'pos': '10', 'avgPx': '50000'}, + {'instId': 'ETH-USDT-SWAP', 'pos': '100', 'avgPx': '3000'} + ] + sim_asset = [ + {'ccy': 'USDT', 'amt': '100000'}, + {'ccy': 'BTC', 'amt': '1'} + ] + + # Act + result = self.account_api.position_builder( + inclRealPosAndEq=False, + simPos=sim_pos, + simAsset=sim_asset, + idxVol='0.1' + ) + + # Assert + expected_params = { + 'inclRealPosAndEq': False, + 'simPos': sim_pos, + 'simAsset': sim_asset, + 'idxVol': '0.1' + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + self.assertEqual(result, mock_response) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_greeks_type_pa(self, mock_request): + """Test position_builder with greeksType PA""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(greeksType='PA') + + # Assert + expected_params = {'greeksType': 'PA'} + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_greeks_type_bs(self, mock_request): + """Test position_builder with greeksType BS""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(greeksType='BS') + + # Assert + expected_params = {'greeksType': 'BS'} + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_includes_real_positions(self, mock_request): + """Test position_builder with inclRealPosAndEq=True""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder( + inclRealPosAndEq=True, + idxVol='0.05' + ) + + # Assert + expected_params = { + 'inclRealPosAndEq': True, + 'idxVol': '0.05' + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_excludes_real_positions(self, mock_request): + """Test position_builder with inclRealPosAndEq=False (only virtual positions)""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + sim_pos = [{'instId': 'BTC-USDT-SWAP', 'pos': '5', 'avgPx': '60000'}] + + # Act + result = self.account_api.position_builder( + inclRealPosAndEq=False, + simPos=sim_pos + ) + + # Assert + expected_params = { + 'inclRealPosAndEq': False, + 'simPos': sim_pos + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_with_account_level(self, mock_request): + """Test position_builder with specific account level""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(acctLv='3') + + # Assert + expected_params = {'acctLv': '3'} + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_with_leverage(self, mock_request): + """Test position_builder with leverage parameter""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(lever='10') + + # Assert + expected_params = {'lever': '10'} + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_extreme_volatility_positive(self, mock_request): + """Test position_builder with maximum positive volatility""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(idxVol='1') + + # Assert + expected_params = {'idxVol': '1'} + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_extreme_volatility_negative(self, mock_request): + """Test position_builder with maximum negative volatility""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(idxVol='-0.99') + + # Assert + expected_params = {'idxVol': '-0.99'} + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_position_builder_complex_scenario(self, mock_request): + """Test position_builder with a complex realistic scenario""" + # Arrange + mock_response = { + 'code': '0', + 'msg': '', + 'data': [{ + 'mmr': '15000', + 'imr': '30000', + 'mmrBf': '14000', + 'imrBf': '28000', + 'markPxBf': '49500' + }] + } + mock_request.return_value = mock_response + + sim_pos = [ + {'instId': 'BTC-USDT-SWAP', 'pos': '10', 'avgPx': '50000'}, + {'instId': 'ETH-USDT-SWAP', 'pos': '-50', 'avgPx': '3000'} + ] + sim_asset = [{'ccy': 'USDT', 'amt': '50000'}] + + # Act - Simulate a 5% market drop + result = self.account_api.position_builder( + acctLv='2', + inclRealPosAndEq=False, + lever='5', + greeksType='PA', + simPos=sim_pos, + simAsset=sim_asset, + idxVol='-0.05' + ) + + # Assert + expected_params = { + 'acctLv': '2', + 'inclRealPosAndEq': False, + 'lever': '5', + 'greeksType': 'PA', + 'simPos': sim_pos, + 'simAsset': sim_asset, + 'idxVol': '-0.05' + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + self.assertEqual(result['code'], '0') + self.assertIn('mmrBf', result['data'][0]) + self.assertIn('imrBf', result['data'][0]) + + +class TestAccountAPIPositionBuilderParameterHandling(unittest.TestCase): + """Test parameter handling and edge cases""" + + def setUp(self): + """Set up test fixtures""" + self.account_api = AccountAPI( + api_key='test_key', + api_secret_key='test_secret', + passphrase='test_pass', + flag='0' + ) + + @patch.object(AccountAPI, '_request_with_params') + def test_none_parameters_are_excluded(self, mock_request): + """Test that None parameters are not included in the request""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder( + acctLv='2', + inclRealPosAndEq=None, # Should be excluded + lever=None, # Should be excluded + greeksType='PA', + simPos=None, # Should be excluded + simAsset=None, # Should be excluded + idxVol='0.05' + ) + + # Assert - Only non-None params should be in the call + expected_params = { + 'acctLv': '2', + 'greeksType': 'PA', + 'idxVol': '0.05' + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_false_value_for_inclRealPosAndEq_is_included(self, mock_request): + """Test that False value for inclRealPosAndEq is included (not treated as None)""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(inclRealPosAndEq=False) + + # Assert - False should be included + expected_params = { + 'inclRealPosAndEq': False + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_empty_lists_are_included(self, mock_request): + """Test that empty lists are included in the request""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder( + simPos=[], + simAsset=[] + ) + + # Assert + expected_params = { + 'simPos': [], + 'simAsset': [] + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_zero_idxVol_is_included(self, mock_request): + """Test that zero idxVol is included (represents no volatility change)""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.position_builder(idxVol='0') + + # Assert + expected_params = { + 'idxVol': '0' + } + mock_request.assert_called_once_with(c.POST, c.POSITION_BUILDER, expected_params) + + +class TestAccountAPISetAutoEarn(unittest.TestCase): + """Unit tests for the set_auto_earn method""" + + def setUp(self): + """Set up test fixtures""" + self.account_api = AccountAPI( + api_key='test_key', + api_secret_key='test_secret', + passphrase='test_pass', + flag='0' + ) + + @patch.object(AccountAPI, '_request_with_params') + def test_set_auto_earn_with_all_params(self, mock_request): + """Test set_auto_earn with all parameters provided""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.set_auto_earn( + ccy='USDT', + action='turn_on', + earnType='0' + ) + + # Assert + expected_params = { + 'ccy': 'USDT', + 'action': 'turn_on', + 'earnType': '0' + } + mock_request.assert_called_once_with(c.POST, c.SET_AUTO_EARN, expected_params) + self.assertEqual(result, mock_response) + + @patch.object(AccountAPI, '_request_with_params') + def test_set_auto_earn_turn_on_action(self, mock_request): + """Test set_auto_earn with turn_on action""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.set_auto_earn( + ccy='BTC', + action='turn_on', + earnType='0' + ) + + # Assert + expected_params = { + 'ccy': 'BTC', + 'action': 'turn_on', + 'earnType': '0' + } + mock_request.assert_called_once_with(c.POST, c.SET_AUTO_EARN, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_set_auto_earn_turn_off_action(self, mock_request): + """Test set_auto_earn with turn_off action""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.set_auto_earn( + ccy='ETH', + action='turn_off', + earnType='0' + ) + + # Assert + expected_params = { + 'ccy': 'ETH', + 'action': 'turn_off', + 'earnType': '0' + } + mock_request.assert_called_once_with(c.POST, c.SET_AUTO_EARN, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_set_auto_earn_with_required_params_only(self, mock_request): + """Test set_auto_earn with required parameters only (ccy, action)""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.account_api.set_auto_earn(ccy='USDT', action='turn_on') + + # Assert + expected_params = { + 'ccy': 'USDT', + 'action': 'turn_on' + } + mock_request.assert_called_once_with(c.POST, c.SET_AUTO_EARN, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_set_auto_earn_different_currencies(self, mock_request): + """Test set_auto_earn with different currencies""" + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + currencies = ['USDT', 'BTC', 'ETH', 'USDC'] + + for ccy in currencies: + mock_request.reset_mock() + result = self.account_api.set_auto_earn( + ccy=ccy, + action='turn_on', + earnType='0' + ) + + call_args = mock_request.call_args[0][2] + self.assertEqual(call_args['ccy'], ccy) + + +class TestAccountAPIGetMaxOrderSize(unittest.TestCase): + """Unit tests for the get_max_order_size method""" + + def setUp(self): + """Set up test fixtures""" + self.account_api = AccountAPI( + api_key='test_key', + api_secret_key='test_secret', + passphrase='test_pass', + flag='0' + ) + + @patch.object(AccountAPI, '_request_with_params') + def test_get_max_order_size_with_required_params(self, mock_request): + """Test get_max_order_size with required parameters only""" + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + result = self.account_api.get_max_order_size( + instId='BTC-USDT', + tdMode='cash' + ) + + expected_params = { + 'instId': 'BTC-USDT', + 'tdMode': 'cash', + 'ccy': '', + 'px': '' + } + mock_request.assert_called_once_with(c.GET, c.MAX_TRADE_SIZE, expected_params) + self.assertEqual(result, mock_response) + + @patch.object(AccountAPI, '_request_with_params') + def test_get_max_order_size_with_tradeQuoteCcy(self, mock_request): + """Test get_max_order_size with tradeQuoteCcy parameter for Unified USD Orderbook""" + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + result = self.account_api.get_max_order_size( + instId='BTC-USD', + tdMode='cash', + tradeQuoteCcy='USDC' + ) + + expected_params = { + 'instId': 'BTC-USD', + 'tdMode': 'cash', + 'ccy': '', + 'px': '', + 'tradeQuoteCcy': 'USDC' + } + mock_request.assert_called_once_with(c.GET, c.MAX_TRADE_SIZE, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_get_max_order_size_without_tradeQuoteCcy(self, mock_request): + """Test get_max_order_size without tradeQuoteCcy (should not include in params)""" + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + result = self.account_api.get_max_order_size( + instId='BTC-USDT', + tdMode='cash' + ) + + call_args = mock_request.call_args[0][2] + self.assertNotIn('tradeQuoteCcy', call_args) + + +class TestAccountAPIGetMaxAvailSize(unittest.TestCase): + """Unit tests for the get_max_avail_size method""" + + def setUp(self): + """Set up test fixtures""" + self.account_api = AccountAPI( + api_key='test_key', + api_secret_key='test_secret', + passphrase='test_pass', + flag='0' + ) + + @patch.object(AccountAPI, '_request_with_params') + def test_get_max_avail_size_with_required_params(self, mock_request): + """Test get_max_avail_size with required parameters only""" + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + result = self.account_api.get_max_avail_size( + instId='BTC-USDT', + tdMode='cash' + ) + + expected_params = { + 'instId': 'BTC-USDT', + 'tdMode': 'cash', + 'ccy': '', + 'reduceOnly': '', + 'unSpotOffset': '', + 'quickMgnType': '' + } + mock_request.assert_called_once_with(c.GET, c.MAX_AVAIL_SIZE, expected_params) + self.assertEqual(result, mock_response) + + @patch.object(AccountAPI, '_request_with_params') + def test_get_max_avail_size_with_tradeQuoteCcy(self, mock_request): + """Test get_max_avail_size with tradeQuoteCcy parameter for Unified USD Orderbook""" + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + result = self.account_api.get_max_avail_size( + instId='BTC-USD', + tdMode='cash', + tradeQuoteCcy='USDC' + ) + + expected_params = { + 'instId': 'BTC-USD', + 'tdMode': 'cash', + 'ccy': '', + 'reduceOnly': '', + 'unSpotOffset': '', + 'quickMgnType': '', + 'tradeQuoteCcy': 'USDC' + } + mock_request.assert_called_once_with(c.GET, c.MAX_AVAIL_SIZE, expected_params) + + @patch.object(AccountAPI, '_request_with_params') + def test_get_max_avail_size_without_tradeQuoteCcy(self, mock_request): + """Test get_max_avail_size without tradeQuoteCcy (should not include in params)""" + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + result = self.account_api.get_max_avail_size( + instId='BTC-USDT', + tdMode='cash' + ) + + call_args = mock_request.call_args[0][2] + self.assertNotIn('tradeQuoteCcy', call_args) + + +if __name__ == '__main__': + unittest.main() + diff --git a/test/unit/okx/test_funding.py b/test/unit/okx/test_funding.py new file mode 100644 index 0000000..61fb701 --- /dev/null +++ b/test/unit/okx/test_funding.py @@ -0,0 +1,225 @@ +""" +Unit tests for okx.Funding module + +Mirrors the structure: okx/Funding.py -> test/unit/okx/test_funding.py +""" +import unittest +from unittest.mock import patch +from okx.Funding import FundingAPI +from okx import consts as c + + +class TestFundingAPIWithdrawal(unittest.TestCase): + """Unit tests for the withdrawal method""" + + def setUp(self): + """Set up test fixtures""" + self.funding_api = FundingAPI( + api_key='test_key', + api_secret_key='test_secret', + passphrase='test_pass', + flag='0' + ) + + @patch.object(FundingAPI, '_request_with_params') + def test_withdrawal_with_required_params(self, mock_request): + """Test withdrawal with required parameters only""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': [{'wdId': '12345'}]} + mock_request.return_value = mock_response + + # Act + result = self.funding_api.withdrawal( + ccy='USDT', + amt='100', + dest='4', + toAddr='0x1234567890abcdef' + ) + + # Assert + expected_params = { + 'ccy': 'USDT', + 'amt': '100', + 'dest': '4', + 'toAddr': '0x1234567890abcdef', + 'chain': '', + 'areaCode': '', + 'clientId': '' + } + mock_request.assert_called_once_with(c.POST, c.WITHDRAWAL_COIN, expected_params) + self.assertEqual(result, mock_response) + + @patch.object(FundingAPI, '_request_with_params') + def test_withdrawal_with_all_params(self, mock_request): + """Test withdrawal with all parameters provided""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': [{'wdId': '12345'}]} + mock_request.return_value = mock_response + + # Act + result = self.funding_api.withdrawal( + ccy='USDT', + amt='100', + dest='4', + toAddr='0x1234567890abcdef', + chain='USDT-TRC20', + areaCode='86', + clientId='client123', + toAddrType='1' + ) + + # Assert + expected_params = { + 'ccy': 'USDT', + 'amt': '100', + 'dest': '4', + 'toAddr': '0x1234567890abcdef', + 'chain': 'USDT-TRC20', + 'areaCode': '86', + 'clientId': 'client123', + 'toAddrType': '1' + } + mock_request.assert_called_once_with(c.POST, c.WITHDRAWAL_COIN, expected_params) + + @patch.object(FundingAPI, '_request_with_params') + def test_withdrawal_with_toAddrType_okx_account(self, mock_request): + """Test withdrawal with toAddrType for OKX account""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.funding_api.withdrawal( + ccy='USDT', + amt='50', + dest='3', + toAddr='user@example.com', + toAddrType='1' # OKX account + ) + + # Assert + call_args = mock_request.call_args[0][2] + self.assertEqual(call_args['toAddrType'], '1') + + @patch.object(FundingAPI, '_request_with_params') + def test_withdrawal_with_toAddrType_external(self, mock_request): + """Test withdrawal with toAddrType for external address""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.funding_api.withdrawal( + ccy='BTC', + amt='0.1', + dest='4', + toAddr='bc1qxy2kgdygjrsqtzq2n0yrf2493p83kkfjhx0wlh', + toAddrType='2' # External address + ) + + # Assert + call_args = mock_request.call_args[0][2] + self.assertEqual(call_args['toAddrType'], '2') + + +class TestFundingAPIGetWithdrawalHistory(unittest.TestCase): + """Unit tests for the get_withdrawal_history method""" + + def setUp(self): + """Set up test fixtures""" + self.funding_api = FundingAPI( + api_key='test_key', + api_secret_key='test_secret', + passphrase='test_pass', + flag='0' + ) + + @patch.object(FundingAPI, '_request_with_params') + def test_get_withdrawal_history_with_no_params(self, mock_request): + """Test get_withdrawal_history with no parameters""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.funding_api.get_withdrawal_history() + + # Assert + call_args = mock_request.call_args[0][2] + self.assertNotIn('toAddrType', call_args) + self.assertEqual(result, mock_response) + + @patch.object(FundingAPI, '_request_with_params') + def test_get_withdrawal_history_with_toAddrType(self, mock_request): + """Test get_withdrawal_history with toAddrType parameter""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.funding_api.get_withdrawal_history( + ccy='USDT', + toAddrType='1' + ) + + # Assert + call_args = mock_request.call_args[0][2] + self.assertEqual(call_args['ccy'], 'USDT') + self.assertEqual(call_args['toAddrType'], '1') + + @patch.object(FundingAPI, '_request_with_params') + def test_get_withdrawal_history_with_all_params(self, mock_request): + """Test get_withdrawal_history with all parameters""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.funding_api.get_withdrawal_history( + ccy='BTC', + wdId='12345', + clientId='client123', + txId='tx123', + type='1', + state='2', + after='1609459200000', + before='1609545600000', + limit='10', + toAddrType='2' + ) + + # Assert + expected_params = { + 'ccy': 'BTC', + 'wdId': '12345', + 'clientId': 'client123', + 'txId': 'tx123', + 'type': '1', + 'state': '2', + 'after': '1609459200000', + 'before': '1609545600000', + 'limit': '10', + 'toAddrType': '2' + } + mock_request.assert_called_once_with(c.GET, c.GET_WITHDRAWAL_HISTORY, expected_params) + + @patch.object(FundingAPI, '_request_with_params') + def test_get_withdrawal_history_filter_by_state(self, mock_request): + """Test get_withdrawal_history filtering by state""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + states = ['0', '1', '2', '3', '4', '5'] + + for state in states: + mock_request.reset_mock() + result = self.funding_api.get_withdrawal_history(state=state) + + call_args = mock_request.call_args[0][2] + self.assertEqual(call_args['state'], state) + + +if __name__ == '__main__': + unittest.main() + diff --git a/test/unit/okx/test_okxclient.py b/test/unit/okx/test_okxclient.py new file mode 100644 index 0000000..463507b --- /dev/null +++ b/test/unit/okx/test_okxclient.py @@ -0,0 +1,186 @@ +""" +Unit tests for okx.okxclient module + +Mirrors the structure: okx/okxclient.py -> test/unit/okx/test_okxclient.py +""" +import unittest +import warnings +from unittest.mock import patch, MagicMock + + +class TestOkxClientInit(unittest.TestCase): + """Unit tests for OkxClient initialization""" + + def test_init_with_default_parameters(self): + """Test initialization with default parameters""" + with patch('okx.okxclient.Client.__init__') as mock_init: + mock_init.return_value = None + + from okx.okxclient import OkxClient + client = OkxClient() + + self.assertEqual(client.API_KEY, '-1') + self.assertEqual(client.API_SECRET_KEY, '-1') + self.assertEqual(client.PASSPHRASE, '-1') + self.assertEqual(client.flag, '1') + self.assertFalse(client.debug) + + def test_init_with_custom_parameters(self): + """Test initialization with custom parameters""" + with patch('okx.okxclient.Client.__init__') as mock_init: + mock_init.return_value = None + + from okx.okxclient import OkxClient + client = OkxClient( + api_key='test_key', + api_secret_key='test_secret', + passphrase='test_pass', + flag='0', + debug=True + ) + + self.assertEqual(client.API_KEY, 'test_key') + self.assertEqual(client.API_SECRET_KEY, 'test_secret') + self.assertEqual(client.PASSPHRASE, 'test_pass') + self.assertEqual(client.flag, '0') + self.assertTrue(client.debug) + + def test_init_with_deprecated_use_server_time_shows_warning(self): + """Test that using deprecated use_server_time parameter shows warning""" + with patch('okx.okxclient.Client.__init__') as mock_init: + mock_init.return_value = None + + from okx.okxclient import OkxClient + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + client = OkxClient(use_server_time=True) + + self.assertEqual(len(w), 1) + self.assertTrue(issubclass(w[0].category, DeprecationWarning)) + self.assertIn("use_server_time parameter is deprecated", str(w[0].message)) + + +class TestOkxClientHttpxCompatibility(unittest.TestCase): + """Unit tests for httpx version compatibility in OkxClient""" + + def test_init_with_new_httpx_proxy_parameter(self): + """Test initialization with new httpx version using proxy parameter""" + with patch('okx.okxclient.Client.__init__') as mock_init: + # Simulate new httpx version (accepts proxy parameter) + mock_init.return_value = None + + from okx.okxclient import OkxClient + client = OkxClient(proxy='http://proxy.example.com:8080') + + # Should call super().__init__ with proxy parameter + mock_init.assert_called_once() + call_kwargs = mock_init.call_args + self.assertIn('proxy', call_kwargs.kwargs) + self.assertEqual(call_kwargs.kwargs['proxy'], 'http://proxy.example.com:8080') + + def test_init_with_old_httpx_falls_back_to_proxies(self): + """Test initialization falls back to proxies for old httpx version""" + call_count = [0] + + def mock_init_side_effect(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1 and 'proxy' in kwargs: + # First call with proxy parameter - simulate old httpx raising TypeError + raise TypeError("__init__() got an unexpected keyword argument 'proxy'") + # Second call should work (with proxies or without) + return None + + with patch('okx.okxclient.Client.__init__') as mock_init: + mock_init.side_effect = mock_init_side_effect + + from okx.okxclient import OkxClient + client = OkxClient(proxy='http://proxy.example.com:8080') + + # Should have been called twice + self.assertEqual(mock_init.call_count, 2) + + # Second call should use proxies parameter + second_call = mock_init.call_args_list[1] + self.assertIn('proxies', second_call.kwargs) + expected_proxies = { + 'http://': 'http://proxy.example.com:8080', + 'https://': 'http://proxy.example.com:8080' + } + self.assertEqual(second_call.kwargs['proxies'], expected_proxies) + + def test_init_with_old_httpx_no_proxy(self): + """Test initialization with old httpx version without proxy""" + call_count = [0] + + def mock_init_side_effect(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1 and 'proxy' in kwargs: + # First call with proxy parameter - simulate old httpx raising TypeError + raise TypeError("__init__() got an unexpected keyword argument 'proxy'") + return None + + with patch('okx.okxclient.Client.__init__') as mock_init: + mock_init.side_effect = mock_init_side_effect + + from okx.okxclient import OkxClient + client = OkxClient() # No proxy + + # Should have been called twice + self.assertEqual(mock_init.call_count, 2) + + # Second call should not have proxies parameter + second_call = mock_init.call_args_list[1] + self.assertNotIn('proxies', second_call.kwargs) + + def test_init_without_proxy(self): + """Test initialization without proxy parameter""" + with patch('okx.okxclient.Client.__init__') as mock_init: + mock_init.return_value = None + + from okx.okxclient import OkxClient + client = OkxClient() + + mock_init.assert_called_once() + call_kwargs = mock_init.call_args.kwargs + self.assertEqual(call_kwargs.get('proxy'), None) + + +class TestOkxClientRequest(unittest.TestCase): + """Unit tests for OkxClient request methods""" + + def setUp(self): + """Set up test fixtures""" + with patch('okx.okxclient.Client.__init__') as mock_init: + mock_init.return_value = None + from okx.okxclient import OkxClient + self.client = OkxClient( + api_key='test_key', + api_secret_key='test_secret', + passphrase='test_pass', + flag='0' + ) + + def test_request_without_params(self): + """Test _request_without_params calls _request with empty dict""" + with patch.object(self.client, '_request') as mock_request: + mock_request.return_value = {'code': '0'} + + result = self.client._request_without_params('GET', '/api/v5/test') + + mock_request.assert_called_once_with('GET', '/api/v5/test', {}) + + def test_request_with_params(self): + """Test _request_with_params passes params correctly""" + with patch.object(self.client, '_request') as mock_request: + mock_request.return_value = {'code': '0'} + params = {'instId': 'BTC-USDT'} + + result = self.client._request_with_params('GET', '/api/v5/test', params) + + mock_request.assert_called_once_with('GET', '/api/v5/test', params) + + +if __name__ == '__main__': + unittest.main() + diff --git a/test/unit/okx/test_public_data.py b/test/unit/okx/test_public_data.py new file mode 100644 index 0000000..abe6824 --- /dev/null +++ b/test/unit/okx/test_public_data.py @@ -0,0 +1,208 @@ +""" +Unit tests for okx.PublicData module + +Mirrors the structure: okx/PublicData.py -> test/unit/okx/test_public_data.py +""" +import unittest +from unittest.mock import patch +from okx.PublicData import PublicAPI +from okx import consts as c + + +class TestPublicAPIMarketDataHistory(unittest.TestCase): + """Unit tests for the get_market_data_history method""" + + def setUp(self): + """Set up test fixtures""" + self.public_api = PublicAPI(flag='0') + + @patch.object(PublicAPI, '_request_with_params') + def test_get_market_data_history_with_required_params(self, mock_request): + """Test get_market_data_history with required parameters only""" + # Arrange + mock_response = { + 'code': '0', + 'msg': '', + 'data': [{'ts': '1234567890', 'vol': '1000'}] + } + mock_request.return_value = mock_response + + # Act + result = self.public_api.get_market_data_history( + module='volume', + instType='SPOT', + dateAggrType='1D', + begin='1609459200000', + end='1609545600000' + ) + + # Assert + expected_params = { + 'module': 'volume', + 'instType': 'SPOT', + 'dateAggrType': '1D', + 'begin': '1609459200000', + 'end': '1609545600000' + } + mock_request.assert_called_once_with(c.GET, c.MARKET_DATA_HISTORY, expected_params) + self.assertEqual(result, mock_response) + + @patch.object(PublicAPI, '_request_with_params') + def test_get_market_data_history_with_all_params(self, mock_request): + """Test get_market_data_history with all parameters provided""" + # Arrange + mock_response = { + 'code': '0', + 'msg': '', + 'data': [{'ts': '1234567890', 'vol': '1000'}] + } + mock_request.return_value = mock_response + + # Act + result = self.public_api.get_market_data_history( + module='volume', + instType='SWAP', + dateAggrType='1W', + begin='1609459200000', + end='1609545600000', + instIdList='BTC-USDT-SWAP,ETH-USDT-SWAP', + instFamilyList='BTC-USDT,ETH-USDT' + ) + + # Assert + expected_params = { + 'module': 'volume', + 'instType': 'SWAP', + 'dateAggrType': '1W', + 'begin': '1609459200000', + 'end': '1609545600000', + 'instIdList': 'BTC-USDT-SWAP,ETH-USDT-SWAP', + 'instFamilyList': 'BTC-USDT,ETH-USDT' + } + mock_request.assert_called_once_with(c.GET, c.MARKET_DATA_HISTORY, expected_params) + self.assertEqual(result, mock_response) + + @patch.object(PublicAPI, '_request_with_params') + def test_get_market_data_history_with_inst_id_list(self, mock_request): + """Test get_market_data_history with instIdList parameter""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.public_api.get_market_data_history( + module='volume', + instType='SPOT', + dateAggrType='1D', + begin='1609459200000', + end='1609545600000', + instIdList='BTC-USDT' + ) + + # Assert + expected_params = { + 'module': 'volume', + 'instType': 'SPOT', + 'dateAggrType': '1D', + 'begin': '1609459200000', + 'end': '1609545600000', + 'instIdList': 'BTC-USDT' + } + mock_request.assert_called_once_with(c.GET, c.MARKET_DATA_HISTORY, expected_params) + + @patch.object(PublicAPI, '_request_with_params') + def test_get_market_data_history_with_inst_family_list(self, mock_request): + """Test get_market_data_history with instFamilyList parameter""" + # Arrange + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + # Act + result = self.public_api.get_market_data_history( + module='openInterest', + instType='FUTURES', + dateAggrType='1M', + begin='1609459200000', + end='1612137600000', + instFamilyList='BTC-USD' + ) + + # Assert + expected_params = { + 'module': 'openInterest', + 'instType': 'FUTURES', + 'dateAggrType': '1M', + 'begin': '1609459200000', + 'end': '1612137600000', + 'instFamilyList': 'BTC-USD' + } + mock_request.assert_called_once_with(c.GET, c.MARKET_DATA_HISTORY, expected_params) + + @patch.object(PublicAPI, '_request_with_params') + def test_get_market_data_history_different_inst_types(self, mock_request): + """Test get_market_data_history with different instType values""" + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + inst_types = ['SPOT', 'SWAP', 'FUTURES', 'OPTION'] + + for inst_type in inst_types: + mock_request.reset_mock() + result = self.public_api.get_market_data_history( + module='volume', + instType=inst_type, + dateAggrType='1D', + begin='1609459200000', + end='1609545600000' + ) + + call_args = mock_request.call_args + self.assertEqual(call_args[0][1], c.MARKET_DATA_HISTORY) + self.assertEqual(call_args[1]['instType'] if call_args[1] else call_args[0][2]['instType'], inst_type) + + @patch.object(PublicAPI, '_request_with_params') + def test_get_market_data_history_different_date_aggr_types(self, mock_request): + """Test get_market_data_history with different dateAggrType values""" + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + date_aggr_types = ['1D', '1W', '1M'] + + for aggr_type in date_aggr_types: + mock_request.reset_mock() + result = self.public_api.get_market_data_history( + module='volume', + instType='SPOT', + dateAggrType=aggr_type, + begin='1609459200000', + end='1609545600000' + ) + + call_args = mock_request.call_args[0][2] + self.assertEqual(call_args['dateAggrType'], aggr_type) + + @patch.object(PublicAPI, '_request_with_params') + def test_get_market_data_history_different_modules(self, mock_request): + """Test get_market_data_history with different module values""" + mock_response = {'code': '0', 'msg': '', 'data': []} + mock_request.return_value = mock_response + + modules = ['volume', 'openInterest', 'tradeCount'] + + for module in modules: + mock_request.reset_mock() + result = self.public_api.get_market_data_history( + module=module, + instType='SPOT', + dateAggrType='1D', + begin='1609459200000', + end='1609545600000' + ) + + call_args = mock_request.call_args[0][2] + self.assertEqual(call_args['module'], module) + + +if __name__ == '__main__': + unittest.main() + diff --git a/test/unit/okx/websocket/__init__.py b/test/unit/okx/websocket/__init__.py new file mode 100644 index 0000000..b0061bd --- /dev/null +++ b/test/unit/okx/websocket/__init__.py @@ -0,0 +1,2 @@ +# Unit tests for okx.websocket module + diff --git a/test/unit/okx/websocket/test_ws_private_async.py b/test/unit/okx/websocket/test_ws_private_async.py new file mode 100644 index 0000000..531f5a8 --- /dev/null +++ b/test/unit/okx/websocket/test_ws_private_async.py @@ -0,0 +1,582 @@ +""" +Unit tests for okx.websocket.WsPrivateAsync module + +Mirrors the structure: okx/websocket/WsPrivateAsync.py -> test/unit/okx/websocket/test_ws_private_async.py +""" +import json +import unittest +import asyncio +import warnings +from unittest.mock import patch, MagicMock, AsyncMock + +# Import the module first so patch can resolve the path +import okx.websocket.WsPrivateAsync as ws_private_module +from okx.websocket.WsPrivateAsync import WsPrivateAsync + + +class TestWsPrivateAsyncInit(unittest.TestCase): + """Unit tests for WsPrivateAsync initialization""" + + def test_init_with_required_params(self): + """Test initialization with required parameters""" + with patch.object(ws_private_module, 'WebSocketFactory') as mock_factory: + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com" + ) + + self.assertEqual(ws.apiKey, "test_api_key") + self.assertEqual(ws.passphrase, "test_passphrase") + self.assertEqual(ws.secretKey, "test_secret_key") + self.assertEqual(ws.url, "wss://test.example.com") + self.assertFalse(ws.useServerTime) + self.assertFalse(ws.debug) + mock_factory.assert_called_once_with("wss://test.example.com") + + def test_init_with_debug_enabled(self): + """Test initialization with debug mode enabled""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com", + debug=True + ) + + self.assertTrue(ws.debug) + + def test_init_with_deprecated_useServerTime_shows_warning(self): + """Test that using deprecated useServerTime parameter shows warning""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + from okx.websocket.WsPrivateAsync import WsPrivateAsync + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com", + useServerTime=True + ) + + self.assertEqual(len(w), 1) + self.assertTrue(issubclass(w[0].category, DeprecationWarning)) + self.assertIn("useServerTime parameter is deprecated", str(w[0].message)) + + def test_init_without_useServerTime_no_warning(self): + """Test that not using useServerTime parameter shows no warning""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + from okx.websocket.WsPrivateAsync import WsPrivateAsync + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com" + ) + + # No deprecation warning expected + deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] + self.assertEqual(len(deprecation_warnings), 0) + + +class TestWsPrivateAsyncSubscribe(unittest.TestCase): + """Unit tests for WsPrivateAsync subscribe method""" + + def test_subscribe_sends_correct_payload(self): + """Test subscribe sends correct payload after login""" + with patch.object(ws_private_module, 'WebSocketFactory'), \ + patch.object(ws_private_module, 'WsUtils') as mock_ws_utils, \ + patch.object(ws_private_module.asyncio, 'sleep', new_callable=AsyncMock): + + mock_ws_utils.initLoginParams.return_value = '{"op":"login"}' + + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com" + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [{"channel": "account", "ccy": "BTC"}] + + async def run_test(): + await ws.subscribe(params, callback) + self.assertEqual(ws.callback, callback) + # Second call should be the subscribe (first is login) + subscribe_call = mock_websocket.send.call_args_list[1] + payload = json.loads(subscribe_call[0][0]) + self.assertEqual(payload["op"], "subscribe") + self.assertEqual(payload["args"], params) + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_subscribe_with_id(self): + """Test subscribe with id parameter""" + with patch.object(ws_private_module, 'WebSocketFactory'), \ + patch.object(ws_private_module, 'WsUtils') as mock_ws_utils, \ + patch.object(ws_private_module.asyncio, 'sleep', new_callable=AsyncMock): + + mock_ws_utils.initLoginParams.return_value = '{"op":"login"}' + + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com" + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [{"channel": "account", "ccy": "BTC"}] + + async def run_test(): + await ws.subscribe(params, callback, id="sub001") + # Second call should be the subscribe (first is login) + subscribe_call = mock_websocket.send.call_args_list[1] + payload = json.loads(subscribe_call[0][0]) + self.assertEqual(payload["op"], "subscribe") + self.assertEqual(payload["id"], "sub001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPrivateAsyncUnsubscribe(unittest.TestCase): + """Unit tests for WsPrivateAsync unsubscribe method""" + + def test_unsubscribe_sends_correct_payload(self): + """Test unsubscribe sends correct payload""" + with patch.object(ws_private_module, 'WebSocketFactory'): + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com" + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [{"channel": "account", "ccy": "BTC"}] + + async def run_test(): + await ws.unsubscribe(params, callback) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "unsubscribe") + self.assertEqual(payload["args"], params) + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_unsubscribe_with_id(self): + """Test unsubscribe with id parameter""" + with patch.object(ws_private_module, 'WebSocketFactory'): + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com" + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [{"channel": "account", "ccy": "BTC"}] + + async def run_test(): + await ws.unsubscribe(params, callback, id="unsub001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "unsubscribe") + self.assertEqual(payload["id"], "unsub001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPrivateAsyncSend(unittest.TestCase): + """Unit tests for WsPrivateAsync generic send method""" + + def test_send_without_id(self): + """Test generic send method without id""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com" + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + args = [{"instId": "BTC-USDT"}] + + async def run_test(): + await ws.send("custom_op", args, callback=callback) + self.assertEqual(ws.callback, callback) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "custom_op") + self.assertEqual(payload["args"], args) + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_send_with_id(self): + """Test generic send method with id""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com" + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + args = [{"instId": "BTC-USDT"}] + + async def run_test(): + await ws.send("custom_op", args, id="send001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "custom_op") + self.assertEqual(payload["id"], "send001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPrivateAsyncOrderMethods(unittest.TestCase): + """Unit tests for WsPrivateAsync order-related methods""" + + def _create_ws_instance(self): + """Helper to create WsPrivateAsync instance with mocked websocket""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + from okx.websocket.WsPrivateAsync import WsPrivateAsync + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com" + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + return ws, mock_websocket + + def test_place_order_sends_correct_payload(self): + """Test place_order sends correct operation""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + callback = MagicMock() + order_args = [{ + "instId": "BTC-USDT", + "tdMode": "cash", + "side": "buy", + "ordType": "limit", + "sz": "0.001", + "px": "30000" + }] + + async def run_test(): + await ws.place_order(order_args, callback=callback, id="order001") + self.assertEqual(ws.callback, callback) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "order") + self.assertEqual(payload["args"], order_args) + self.assertEqual(payload["id"], "order001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_place_order_without_id(self): + """Test place_order without id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + order_args = [{"instId": "BTC-USDT"}] + + async def run_test(): + await ws.place_order(order_args) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "order") + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_batch_orders_sends_correct_payload(self): + """Test batch_orders sends correct operation""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + callback = MagicMock() + order_args = [ + {"instId": "BTC-USDT", "side": "buy", "sz": "0.001", "px": "30000"}, + {"instId": "ETH-USDT", "side": "buy", "sz": "0.01", "px": "2000"} + ] + + async def run_test(): + await ws.batch_orders(order_args, callback=callback, id="batch001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "batch-orders") + self.assertEqual(payload["args"], order_args) + self.assertEqual(payload["id"], "batch001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_batch_orders_without_id(self): + """Test batch_orders without id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + order_args = [{"instId": "BTC-USDT"}, {"instId": "ETH-USDT"}] + + async def run_test(): + await ws.batch_orders(order_args) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "batch-orders") + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_cancel_order_sends_correct_payload(self): + """Test cancel_order sends correct operation""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + callback = MagicMock() + cancel_args = [{"instId": "BTC-USDT", "ordId": "12345"}] + + async def run_test(): + await ws.cancel_order(cancel_args, callback=callback, id="cancel001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "cancel-order") + self.assertEqual(payload["args"], cancel_args) + self.assertEqual(payload["id"], "cancel001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_cancel_order_without_id(self): + """Test cancel_order without id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + cancel_args = [{"instId": "BTC-USDT", "ordId": "12345"}] + + async def run_test(): + await ws.cancel_order(cancel_args) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "cancel-order") + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_batch_cancel_orders_sends_correct_payload(self): + """Test batch_cancel_orders sends correct operation""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + callback = MagicMock() + cancel_args = [ + {"instId": "BTC-USDT", "ordId": "12345"}, + {"instId": "ETH-USDT", "ordId": "67890"} + ] + + async def run_test(): + await ws.batch_cancel_orders(cancel_args, callback=callback, id="batchCancel001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "batch-cancel-orders") + self.assertEqual(payload["args"], cancel_args) + self.assertEqual(payload["id"], "batchCancel001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_batch_cancel_orders_without_id(self): + """Test batch_cancel_orders without id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + cancel_args = [{"instId": "BTC-USDT", "ordId": "12345"}] + + async def run_test(): + await ws.batch_cancel_orders(cancel_args) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "batch-cancel-orders") + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_amend_order_sends_correct_payload(self): + """Test amend_order sends correct operation""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + callback = MagicMock() + amend_args = [{ + "instId": "BTC-USDT", + "ordId": "12345", + "newSz": "0.002", + "newPx": "31000" + }] + + async def run_test(): + await ws.amend_order(amend_args, callback=callback, id="amend001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "amend-order") + self.assertEqual(payload["args"], amend_args) + self.assertEqual(payload["id"], "amend001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_amend_order_without_id(self): + """Test amend_order without id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + amend_args = [{"instId": "BTC-USDT", "ordId": "12345", "newSz": "0.002"}] + + async def run_test(): + await ws.amend_order(amend_args) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "amend-order") + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_batch_amend_orders_sends_correct_payload(self): + """Test batch_amend_orders sends correct operation""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + callback = MagicMock() + amend_args = [ + {"instId": "BTC-USDT", "ordId": "12345", "newSz": "0.002"}, + {"instId": "ETH-USDT", "ordId": "67890", "newPx": "2100"} + ] + + async def run_test(): + await ws.batch_amend_orders(amend_args, callback=callback, id="batchAmend001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "batch-amend-orders") + self.assertEqual(payload["args"], amend_args) + self.assertEqual(payload["id"], "batchAmend001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_batch_amend_orders_without_id(self): + """Test batch_amend_orders without id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + amend_args = [{"instId": "BTC-USDT", "ordId": "12345", "newSz": "0.002"}] + + async def run_test(): + await ws.batch_amend_orders(amend_args) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "batch-amend-orders") + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_mass_cancel_sends_correct_payload(self): + """Test mass_cancel sends correct operation""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + callback = MagicMock() + mass_cancel_args = [{ + "instType": "SPOT", + "instFamily": "BTC-USDT" + }] + + async def run_test(): + await ws.mass_cancel(mass_cancel_args, callback=callback, id="massCancel001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "mass-cancel") + self.assertEqual(payload["args"], mass_cancel_args) + self.assertEqual(payload["id"], "massCancel001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_mass_cancel_without_id(self): + """Test mass_cancel without id parameter""" + with patch('okx.websocket.WsPrivateAsync.WebSocketFactory'): + ws, mock_websocket = self._create_ws_instance() + mass_cancel_args = [{"instType": "SPOT", "instFamily": "BTC-USDT"}] + + async def run_test(): + await ws.mass_cancel(mass_cancel_args) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "mass-cancel") + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPrivateAsyncLogin(unittest.TestCase): + """Unit tests for WsPrivateAsync login method""" + + def test_login_calls_init_login_params(self): + """Test login calls WsUtils.initLoginParams with correct parameters""" + with patch.object(ws_private_module, 'WebSocketFactory'), \ + patch.object(ws_private_module, 'WsUtils') as mock_ws_utils: + + mock_ws_utils.initLoginParams.return_value = '{"op":"login","args":[...]}' + + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com" + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + + async def run_test(): + result = await ws.login() + self.assertTrue(result) + mock_ws_utils.initLoginParams.assert_called_once_with( + useServerTime=False, + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key" + ) + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPrivateAsyncStartStop(unittest.TestCase): + """Unit tests for WsPrivateAsync start and stop methods""" + + def test_stop(self): + """Test stop method closes the factory""" + with patch.object(ws_private_module, 'WebSocketFactory') as mock_factory_class: + mock_factory_instance = MagicMock() + mock_factory_instance.close = AsyncMock() + mock_factory_class.return_value = mock_factory_instance + + ws = WsPrivateAsync( + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key", + url="wss://test.example.com" + ) + + async def run_test(): + await ws.stop() + mock_factory_instance.close.assert_called_once() + + asyncio.get_event_loop().run_until_complete(run_test()) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/unit/okx/websocket/test_ws_public_async.py b/test/unit/okx/websocket/test_ws_public_async.py new file mode 100644 index 0000000..ca27b38 --- /dev/null +++ b/test/unit/okx/websocket/test_ws_public_async.py @@ -0,0 +1,321 @@ +""" +Unit tests for okx.websocket.WsPublicAsync module + +Mirrors the structure: okx/websocket/WsPublicAsync.py -> test/unit/okx/websocket/test_ws_public_async.py +""" +import json +import unittest +import asyncio +from unittest.mock import patch, MagicMock, AsyncMock + +# Import the module first so patch can resolve the path +import okx.websocket.WsPublicAsync as ws_public_module +from okx.websocket.WsPublicAsync import WsPublicAsync + + +class TestWsPublicAsyncInit(unittest.TestCase): + """Unit tests for WsPublicAsync initialization""" + + def test_init_with_url(self): + """Test initialization with url parameter""" + with patch.object(ws_public_module, 'WebSocketFactory') as mock_factory: + ws = WsPublicAsync(url="wss://test.example.com") + + self.assertEqual(ws.url, "wss://test.example.com") + self.assertEqual(ws.apiKey, '') + self.assertEqual(ws.passphrase, '') + self.assertEqual(ws.secretKey, '') + self.assertFalse(ws.debug) + self.assertFalse(ws.isLoggedIn) + + def test_init_with_credentials(self): + """Test initialization with all credentials for business channel""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync( + url="wss://test.example.com", + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key" + ) + + self.assertEqual(ws.apiKey, "test_api_key") + self.assertEqual(ws.passphrase, "test_passphrase") + self.assertEqual(ws.secretKey, "test_secret_key") + + def test_init_with_debug_enabled(self): + """Test initialization with debug mode enabled""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com", debug=True) + + self.assertTrue(ws.debug) + + def test_init_with_debug_disabled(self): + """Test initialization with debug mode disabled (default)""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com", debug=False) + + self.assertFalse(ws.debug) + + +class TestWsPublicAsyncLogin(unittest.TestCase): + """Unit tests for WsPublicAsync login method""" + + def test_login_without_credentials_raises_error(self): + """Test that login raises ValueError when credentials are missing""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + + async def run_test(): + with self.assertRaises(ValueError) as context: + await ws.login() + self.assertIn("apiKey, secretKey and passphrase are required for login", str(context.exception)) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_login_with_credentials_success(self): + """Test successful login with valid credentials""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory') as mock_factory, \ + patch('okx.websocket.WsPublicAsync.WsUtils.initLoginParams') as mock_init_login: + + mock_init_login.return_value = '{"op":"login","args":[...]}' + + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync( + url="wss://test.example.com", + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key" + ) + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + + async def run_test(): + result = await ws.login() + self.assertTrue(result) + self.assertTrue(ws.isLoggedIn) + mock_init_login.assert_called_once_with( + useServerTime=False, + apiKey="test_api_key", + passphrase="test_passphrase", + secretKey="test_secret_key" + ) + mock_websocket.send.assert_called_once() + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPublicAsyncSubscribe(unittest.TestCase): + """Unit tests for WsPublicAsync subscribe method""" + + def test_subscribe_without_id(self): + """Test subscribe without id parameter""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [{"channel": "tickers", "instId": "BTC-USDT"}] + + async def run_test(): + await ws.subscribe(params, callback) + self.assertEqual(ws.callback, callback) + mock_websocket.send.assert_called_once() + + # Verify the payload + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "subscribe") + self.assertEqual(payload["args"], params) + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_subscribe_with_id(self): + """Test subscribe with id parameter""" + with patch.object(ws_public_module, 'WebSocketFactory'): + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [{"channel": "tickers", "instId": "BTC-USDT"}] + + async def run_test(): + await ws.subscribe(params, callback, id="sub001") + + # Verify the payload includes id + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "subscribe") + self.assertEqual(payload["args"], params) + self.assertEqual(payload["id"], "sub001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_subscribe_with_multiple_channels(self): + """Test subscribe with multiple channels""" + with patch.object(ws_public_module, 'WebSocketFactory'): + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [ + {"channel": "tickers", "instId": "BTC-USDT"}, + {"channel": "tickers", "instId": "ETH-USDT"} + ] + + async def run_test(): + await ws.subscribe(params, callback, id="multi001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(len(payload["args"]), 2) + self.assertEqual(payload["id"], "multi001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPublicAsyncUnsubscribe(unittest.TestCase): + """Unit tests for WsPublicAsync unsubscribe method""" + + def test_unsubscribe_without_id(self): + """Test unsubscribe without id parameter""" + with patch.object(ws_public_module, 'WebSocketFactory'): + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [{"channel": "tickers", "instId": "BTC-USDT"}] + + async def run_test(): + await ws.unsubscribe(params, callback) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "unsubscribe") + self.assertEqual(payload["args"], params) + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_unsubscribe_with_id(self): + """Test unsubscribe with id parameter""" + with patch.object(ws_public_module, 'WebSocketFactory'): + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + params = [{"channel": "tickers", "instId": "BTC-USDT"}] + + async def run_test(): + await ws.unsubscribe(params, callback, id="unsub001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "unsubscribe") + self.assertEqual(payload["id"], "unsub001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPublicAsyncSend(unittest.TestCase): + """Unit tests for WsPublicAsync send method""" + + def test_send_without_id(self): + """Test generic send method without id""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + callback = MagicMock() + args = [{"instId": "BTC-USDT"}] + + async def run_test(): + await ws.send("custom_op", args, callback=callback) + self.assertEqual(ws.callback, callback) + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "custom_op") + self.assertEqual(payload["args"], args) + self.assertNotIn("id", payload) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_send_with_id(self): + """Test generic send method with id""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + args = [{"instId": "BTC-USDT"}] + + async def run_test(): + await ws.send("custom_op", args, id="send001") + call_args = mock_websocket.send.call_args[0][0] + payload = json.loads(call_args) + self.assertEqual(payload["op"], "custom_op") + self.assertEqual(payload["id"], "send001") + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_send_without_callback(self): + """Test send method without callback (preserves existing callback)""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + existing_callback = MagicMock() + ws.callback = existing_callback + args = [{"instId": "BTC-USDT"}] + + async def run_test(): + await ws.send("custom_op", args) + # Callback should remain unchanged + self.assertEqual(ws.callback, existing_callback) + + asyncio.get_event_loop().run_until_complete(run_test()) + + def test_send_with_new_callback_replaces_existing(self): + """Test send method with new callback replaces existing callback""" + with patch('okx.websocket.WsPublicAsync.WebSocketFactory'): + from okx.websocket.WsPublicAsync import WsPublicAsync + ws = WsPublicAsync(url="wss://test.example.com") + mock_websocket = AsyncMock() + ws.websocket = mock_websocket + old_callback = MagicMock() + new_callback = MagicMock() + ws.callback = old_callback + args = [{"instId": "BTC-USDT"}] + + async def run_test(): + await ws.send("custom_op", args, callback=new_callback) + self.assertEqual(ws.callback, new_callback) + + asyncio.get_event_loop().run_until_complete(run_test()) + + +class TestWsPublicAsyncStartStop(unittest.TestCase): + """Unit tests for WsPublicAsync start and stop methods""" + + def test_stop(self): + """Test stop method closes the factory""" + with patch.object(ws_public_module, 'WebSocketFactory') as mock_factory_class: + mock_factory_instance = MagicMock() + mock_factory_instance.close = AsyncMock() + mock_factory_class.return_value = mock_factory_instance + + ws = WsPublicAsync(url="wss://test.example.com") + + async def run_test(): + await ws.stop() + mock_factory_instance.close.assert_called_once() + + asyncio.get_event_loop().run_until_complete(run_test()) + + +if __name__ == '__main__': + unittest.main()