diff --git a/test/eth_trade_bot b/test/eth_trade_bot new file mode 100644 index 0000000..d1e1956 --- /dev/null +++ b/test/eth_trade_bot @@ -0,0 +1,1215 @@ +# ======================== 导入必要的库 ======================== +import asyncio +import time +import json +import logging +import string +import random +import os +from datetime import datetime, timedelta +from collections import defaultdict +import pytz +import pandas as pd +import numpy as np +import requests +from okx.websocket.WsPrivateAsync import WsPrivateAsync as PrivateWs +from okx.websocket.WsPublicAsync import WsPublicAsync as PublicWs +import okx.Account as Account +import okx.Trade as Trade +import okx.MarketData as MarketData +from supertrend_lib1 import SupertrendAnalyzer # 使用您提供的库 + +from decimal import Decimal, getcontext, ROUND_HALF_UP + +# ======================== Decimal 配置 ======================== +getcontext().prec = 28 # 高精度以避免价格/数量浮点误差 + +# ======================== 日志系统配置 ======================== +def setup_logging(): + os.makedirs("logs", exist_ok=True) + logger = logging.getLogger("intelligent_trading_bot") + logger.setLevel(logging.DEBUG) + + file_handler = logging.FileHandler("logs/trading.log") + file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) + + console_handler = logging.StreamHandler() + console_handler.setFormatter(logging.Formatter('%(message)s')) + + # 避免重复添加 handler(在多次导入时) + if not logger.handlers: + logger.addHandler(file_handler) + logger.addHandler(console_handler) + return logger + + +logger = setup_logging() + + +def log_action(action, details, level="info", extra_data=None, exc_info=False): + symbols = { + "debug": "🔵", + "info": "🟢", + "warning": "🟠", + "error": "🔴", + "critical": "⛔" + } + symbol = symbols.get(level, "⚪") + header = f"\n{'-' * 80}\n[{datetime.now().strftime('%H:%M:%S.%f')}] {symbol} {action}" + log_line = header + f"\n • {details}" + + if extra_data: + try: + if isinstance(extra_data, dict): + log_line += f"\n • 附加数据: {json.dumps(extra_data, indent=2, ensure_ascii=False)}" + else: + log_line += f"\n • 附加数据: {extra_data}" + except Exception: + log_line += f"\n • 附加数据: [无法序列化]" + + log_line += f"\n{'-' * 80}" + + # 适配 debug 等级 + if level == "debug": + logger.debug(log_line, exc_info=exc_info) + elif level == "info": + logger.info(log_line, exc_info=exc_info) + elif level == "warning": + logger.warning(log_line, exc_info=exc_info) + elif level == "error": + logger.error(log_line, exc_info=exc_info) + elif level == "critical": + logger.critical(log_line, exc_info=exc_info) + else: + logger.info(log_line, exc_info=exc_info) + + +def log_state_transition(current_state, new_state, reason): + """记录状态转换 - 增加更多上下文信息""" + global trading_phase + + # 获取当前仓位信息 + long_key = get_position_key(SYMBOL, "long") + short_key = get_position_key(SYMBOL, "short") + long_pos = position_info[long_key]["pos"] + short_pos = position_info[short_key]["pos"] + + # 计算ETH价值 + long_eth = calculate_contract_value(long_pos) + short_eth = calculate_contract_value(short_pos) + + logger.info(f"\n{'=' * 80}") + logger.info(f"🔄 状态变更: [{current_state}] → [{new_state}]") + logger.info(f" 原因: {reason}") + logger.info(f" 当前方向: {trading_direction.upper()}") + logger.info(f" 多仓: {long_pos}张 ({long_eth:.6f} ETH) | 空仓: {short_pos}张 ({short_eth:.6f} ETH)") + logger.info(f" 当前价格: ${current_price:.4f}") + logger.info(f" 账户净值: ${account_equity:.2f}") + logger.info(f" 活跃订单: {len(active_orders)}个") + logger.info(f" 最后趋势检查: {datetime.fromtimestamp(last_trend_check).strftime('%H:%M:%S')}") + logger.info(f"{'=' * 80}\n") + + +# ======================== 基础配置 ======================== +API_KEY = "a3a2a008-576a-4548-a1c0-f28ac940bd6b" +SECRET_KEY = "3F9170BCBE9C88EDFA6675438CD0DBAA" +PASSPHRASE = "Aa123414.." + +# ======================== 交易对基础信息(可由 API 获取覆盖) ======================== +CONTRACT_INFO = { + "symbol": "ETH-USDT-SWAP", + "lotSz": 1, # 下单数量精度 + "minSz": 1, # 最小下单数量 + "ctVal": 0.01, # 合约面值 (每张合约代表 0.01 ETH) -- 已修正注释 + "tickSz": 0.1, # 价格精度 + "ctValCcy": "ETH", # 合约价值货币 + "instType": "SWAP" # 合约类型 +} + +# 使用合约信息配置全局变量(会在启动时尝试从 API 更新) +SYMBOL = CONTRACT_INFO["symbol"] +TICK_SIZE = CONTRACT_INFO["tickSz"] + +# Supertrend策略配置 +SUPERTREND_CONFIG = { + 'symbol': SYMBOL, + 'timeframe': '1H', + 'atr_period': 7, + 'multiplier': 2.5, + 'change_atr': True, + 'min_data': 50, + 'data_delay': 1 +} + +# ======================== 全局配置常量 ======================== +PING_INTERVAL = 25 # WebSocket心跳间隔 +PONG_TIMEOUT = 5 # Pong响应超时时间 + +# 交易策略配置 +TRADE_STRATEGY = { + "price_offset": 0.015, # 价格偏移 15%(例如 0.015 表示当前价格的 15%) + "eth_position": 0.01, # 目标 ETH 持仓量(以 ETH 计) + "leverage": 10, # 杠杆倍数 + "atr_multiplier": 0.7, # ATR系数 + "order_increment": 0, # 开仓单在基础仓位上增加的合约张数;0 表示不启用增仓 + "fixed_trend_direction": "long", # 固定趋势方向 (long/short) + "trend_mode": "fixed" # "fixed" 或 "auto" +} + +# 全局变量 +account_equity = 0.0 +initial_equity = 0.0 +current_price = 0.0 +last_price = 0.0 +trading_direction = TRADE_STRATEGY.get("fixed_trend_direction", "long") +last_trend_check = 0 +last_ws_price_update = 0 +last_api_price_update = 0 +price_source = "unknown" +# 日志 +last_status_log_time = 0 +STATUS_LOG_INTERVAL = 60 # 60秒记录一次摘要 + +# 订单和仓位管理 +active_orders = {} # 使用 cl_ord_id 作为键 +position_info = defaultdict(lambda: { + "pos": 0.0, # 合约张数 + "eth_value": 0.0, # ETH价值 + "avg_px": 0.0, + "upl": 0.0, + "entry_time": 0 +}) + +# 订单对映射 +order_pair_mapping = {} # 使用 pair_id 作为键,存储订单对信息 + +# API客户端 +flag = "0" +account_api = Account.AccountAPI(API_KEY, SECRET_KEY, PASSPHRASE, False, flag) +trade_api = Trade.TradeAPI(API_KEY, SECRET_KEY, PASSPHRASE, False, flag) +market_api = MarketData.MarketAPI(API_KEY, SECRET_KEY, PASSPHRASE, False, flag) + +# WebSocket实例(占位) +ws_private = None +ws_public = None + +# 初始化Supertrend分析器 +supertrend_analyzer = SupertrendAnalyzer(SUPERTREND_CONFIG) + + +# ======================== ETH合约工具函数 ======================== +def calculate_contract_value(sz): + """计算合约实际价值 (ETH数量)""" + return sz * CONTRACT_INFO["ctVal"] + + +def calculate_contract_size(eth_amount): + """根据ETH数量计算合约张数""" + return eth_amount / CONTRACT_INFO["ctVal"] + + +def round_to_min_size(size): + """调整数量到最小交易单位的整数倍""" + min_lot = CONTRACT_INFO["minSz"] + # 保证返回与 min_lot 同单位(整数倍) + if min_lot == 0: + return size + # 使用 Decimal 以获得更好精度 + s = Decimal(str(size)) + m = Decimal(str(min_lot)) + rounded = (s / m).quantize(Decimal('1'), rounding=ROUND_HALF_UP) * m + return float(rounded) + + +def validate_position_size(size): + """验证仓位大小是否符合要求""" + min_size = CONTRACT_INFO["minSz"] + if size < min_size: + log_action("风控", f"仓位大小{size}小于最小值{min_size},自动修正", "warning") + return min_size + + # 检查是否为最小单位的整数倍 + if not (size / min_size).is_integer(): + log_action("风控", f"仓位大小{size}不是最小单位{min_size}的整数倍", "warning") + return round_to_min_size(size) + + return size + + +def get_position_size(): + """根据ETH目标持仓计算合约张数""" + eth_amount = TRADE_STRATEGY["eth_position"] + contract_size = calculate_contract_size(eth_amount) + return round_to_min_size(contract_size) # 确保符合最小单位 + + +def get_size_precision(): + """获取数量精度的小数位数""" + min_sz = CONTRACT_INFO["minSz"] + # 计算需要的小数位数 + if min_sz >= 1: + return 0 + elif min_sz >= 0.1: + return 1 + elif min_sz >= 0.01: + return 2 + elif min_sz >= 0.001: + return 3 + else: + return 4 # 默认4位小数 + + +def get_price_precision(): + """获取价格精度的小数位数""" + tick_sz = CONTRACT_INFO["tickSz"] + # 计算需要的小数位数 + if tick_sz >= 1: + return 0 + elif tick_sz >= 0.1: + return 1 + elif tick_sz >= 0.01: + return 2 + elif tick_sz >= 0.001: + return 3 + else: + return 4 # 默认4位小数 + + +def validate_order_increment(): + """验证开仓增量配置是否有效。允许为 0(表示不启用增仓)。""" + increment = TRADE_STRATEGY.get("order_increment", 0) + min_sz = CONTRACT_INFO["minSz"] + + if increment < 0: + log_action("配置验证", "开仓增量不能为负数", "error") + return False + + if increment == 0: + log_action("配置验证", "开仓增量为0,表示不启用分批增仓", "info") + return True + + # 如果小于最小单位,自动修正到最小单位 + if increment < min_sz: + log_action("配置验证", + f"开仓增量({increment})小于最小交易单位({min_sz}),自动修正为最小单位", + "warning") + TRADE_STRATEGY["order_increment"] = min_sz + return True + + # 检查是否为最小单位的整数倍 + if not (increment / min_sz).is_integer(): + log_action("配置验证", + f"开仓增量({increment})不是最小单位({min_sz})的整数倍,自动四舍五入", + "warning") + TRADE_STRATEGY["order_increment"] = round_to_min_size(increment) + return True + + return True + + +# ======================== 交易阶段定义 ======================== +class TradingPhase: + INIT = "INIT" + A1_POSITION_SETUP = "A1" # 新增: 仓位建立阶段 + A2_ORDER_PAIR = "A2" # 订单对阶段 + B2_WAIT_PAIR = "B2" # 订单对监控阶段 + + +# 当前交易阶段 +trading_phase = TradingPhase.INIT +phase_start_time = 0 + + +# ======================== 核心功能函数 ======================== +def get_position_key(inst_id, pos_side): + """获取持仓唯一键""" + return f"{inst_id}-{pos_side.lower()}" + + +def generate_order_id(prefix): + """生成符合OKX要求的订单ID""" + clean_prefix = ''.join(c for c in prefix if c.isalnum()) + rand_part = ''.join(random.choices(string.ascii_letters + string.digits, k=16)) + return (clean_prefix + rand_part)[:32] + + +def round_price(price): + """根据 tick 精度调整价格(使用 Decimal 以避免 float 精度问题)""" + tick = Decimal(str(TICK_SIZE)) + p = Decimal(str(price)) + quant = (p / tick).quantize(Decimal('1'), rounding=ROUND_HALF_UP) + rounded = (quant * tick).normalize() + return float(rounded) + + +def safe_float(value, default=0.0): + """安全转换浮点数""" + try: + return float(value) if value else default + except (ValueError, TypeError): + return default + + +# ======================== 账户与市场数据 ======================== +async def fetch_account_balance(): + """获取账户余额 - 使用线程池执行阻塞 API 调用""" + global account_equity, initial_equity + + try: + log_action("账户查询", "发送账户余额请求", "debug") + # 尝试直接按 SDK 常用签名调用,放到线程池以防阻塞 + response = await asyncio.to_thread(account_api.get_account_balance, ccy="USDT") + + # 记录完整响应 + log_action("账户查询", "收到账户余额响应", "debug", response) + + if response.get("code", "") == "0" and response.get("data"): + for detail_group in response["data"]: + for detail in detail_group.get("details", []): + if detail.get("ccy") == "USDT": + equity = float(detail.get("eq", 0.0)) + account_equity = equity + if initial_equity == 0: + initial_equity = equity + log_action("账户初始化", f"初始权益: ${initial_equity:.2f}") + return True + return False + except Exception as e: + log_action("账户查询", f"请求失败: {str(e)}", "error", exc_info=True) + return False + + +def log_periodic_status(): + """记录周期性状态摘要""" + global last_status_log_time + + current_time = time.time() + if current_time - last_status_log_time >= STATUS_LOG_INTERVAL: + last_status_log_time = current_time + + long_key = get_position_key(SYMBOL, "long") + short_key = get_position_key(SYMBOL, "short") + long_pos = position_info[long_key]["pos"] + short_pos = position_info[short_key]["pos"] + + # 计算ETH价值 + long_eth = calculate_contract_value(long_pos) + short_eth = calculate_contract_value(short_pos) + + logger.info(f"\n{'=' * 80}") + logger.info(f"📊 当前状态摘要 - {trading_phase}") + logger.info(f" 当前方向: {trading_direction.upper()}") + logger.info(f" 多仓: {long_pos}张 ({long_eth:.6f} ETH) | 空仓: {short_pos}张 ({short_eth:.6f} ETH)") + logger.info(f" 当前价格: ${current_price:.4f}") + logger.info(f" 账户净值: ${account_equity:.2f}") + logger.info(f" 活跃订单: {len(active_orders)}个") + logger.info(f" 最后趋势检查: {datetime.fromtimestamp(last_trend_check).strftime('%H:%M:%S')}") + logger.info(f" 开仓增量配置: {TRADE_STRATEGY['order_increment']}张") + logger.info(f"{'=' * 80}\n") + + +async def update_position_info(): + """获取持仓信息 - 使用线程池执行阻塞 API 调用""" + try: + log_action("仓位查询", "发送仓位查询请求", "debug") + response = await asyncio.to_thread(account_api.get_positions, instType="SWAP") + + log_action("仓位查询", "收到仓位查询响应", "debug", response) + + if response.get("code") == "0": + positions = response.get("data", []) + + # 重置仓位信息 + for key in list(position_info.keys()): + position_info[key]["pos"] = 0.0 + position_info[key]["eth_value"] = 0.0 + position_info[key]["avg_px"] = 0.0 + position_info[key]["upl"] = 0.0 + position_info[key]["entry_time"] = 0 + + # 更新有效仓位 + for pos in positions: + if isinstance(pos, dict) and pos.get("instId") == SYMBOL: + pos_side = pos.get("posSide", "net").lower() + pos_key = get_position_key(SYMBOL, pos_side) + + contract_size = float(pos.get("pos", "0")) + eth_value = calculate_contract_value(contract_size) + + position_info[pos_key]["pos"] = contract_size + position_info[pos_key]["eth_value"] = eth_value + position_info[pos_key]["avg_px"] = float(pos.get("avgPx", "0")) + position_info[pos_key]["upl"] = float(pos.get("upl", "0")) + + log_action("仓位更新", + f"{pos_side}仓: {contract_size}张 ({eth_value:.6f} ETH)", + "debug") + return True + return False + except Exception as e: + log_action("仓位查询", f"请求失败: {str(e)}", "error", exc_info=True) + return False + + +async def update_current_price(): + """获取当前价格 - 通过API备选 - 使用线程池执行阻塞调用""" + global current_price, last_price, last_api_price_update, price_source + + try: + log_action("价格查询", "发送价格查询请求", "debug") + response = await asyncio.to_thread(market_api.get_ticker, instId=SYMBOL) + + log_action("价格查询", "收到价格查询响应", "debug", response) + + if response.get("code") == "0" and response.get("data"): + price = float(response["data"][0]["last"]) + + # 更新价格和时间戳 + last_price = current_price + current_price = price + last_api_price_update = time.time() + price_source = "rest_api" + + log_action("API价格更新", f"${last_price:.4f} → ${price:.4f}", "info") + return True + return False + except Exception as e: + log_action("价格查询", f"请求失败: {str(e)}", "error", exc_info=True) + return False + + +async def ensure_price_freshness(): + """确保价格足够新鲜 - 增加详细日志""" + global price_source + + # 记录当前状态 + ws_freshness = time.time() - last_ws_price_update + api_freshness = time.time() - last_api_price_update + log_action("价格新鲜度", + f"WS新鲜度: {ws_freshness:.2f}s | API新鲜度: {api_freshness:.2f}s", + "debug") + + # 如果WebSocket价格在2秒内更新过,使用WebSocket价格 + if ws_freshness <= 2.0: + price_source = "websocket" + log_action("价格新鲜度", f"使用WS价格(新鲜度:{ws_freshness:.2f}s)", "debug") + return True + + # 如果API价格在5秒内更新过,使用API价格 + if api_freshness <= 5.0: + price_source = "rest_api" + log_action("价格新鲜度", f"使用API价格(新鲜度:{api_freshness:.2f}s)", "debug") + return True + + # 主动更新价格 + log_action("价格更新", "价格数据过期,主动获取最新价格", "warning") + return await update_current_price() + + +# ======================== 订单管理 ======================== +class RateLimiter: + """OKX API速率限制管理器 - 增加详细日志""" + + def __init__(self): + self.last_request_time = 0 + self.request_count = 0 + self.window_start = time.time() + # 默认通用上限(保守),但支持自定义短时间窗口限制 + self.max_orders_per_window = 300 # 每2秒300个订单(可视为上限,谨慎使用) + self.window_seconds = 2 + + async def check_limit(self, orders_count, max_per_window=None, window_seconds=None): + """检查并遵守API速率限制。允许传入自定义限速(用于特定端点)""" + # 支持自定义值 + max_orders = max_per_window if max_per_window is not None else self.max_orders_per_window + win_seconds = window_seconds if window_seconds is not None else self.window_seconds + + current_time = time.time() + + # 检查当前时间窗口 + window_elapsed = current_time - self.window_start + if window_elapsed > win_seconds: + log_action("限速器", f"窗口重置 (已过{window_elapsed:.2f}s > {win_seconds}s)", "debug") + self.request_count = 0 + self.window_start = current_time + window_elapsed = 0 + + # 预测请求后是否超出限制 + predicted_count = self.request_count + orders_count + while predicted_count > max_orders: + # 计算需要等待的时间 + wait_time = max(0.0, win_seconds - window_elapsed + 0.1) + log_action("限速等待", + f"需等待{wait_time:.2f}秒 (当前{self.request_count}/{max_orders},请求后{predicted_count})", + "warning") + await asyncio.sleep(wait_time) + current_time = time.time() + window_elapsed = current_time - self.window_start + + if window_elapsed > win_seconds: + log_action("限速器", f"等待后窗口重置 (已过{window_elapsed:.2f}s > {win_seconds}s)", "debug") + self.request_count = 0 + self.window_start = current_time + break + + predicted_count = self.request_count + orders_count + + # 避免请求过快(即使未达限制) + if current_time - self.last_request_time < 0.1: + wait_time = 0.1 - (current_time - self.last_request_time) + log_action("限速等待", f"避免请求过快,等待{wait_time:.3f}s", "debug") + await asyncio.sleep(wait_time) + + # 更新计数器 + self.request_count += orders_count + self.last_request_time = time.time() + log_action("限速器", f"请求计数: {self.request_count}/{max_orders}", "debug", + {"订单数": orders_count, "窗口秒数": win_seconds}) + + +# 全局速率限制器实例 +rate_limiter = RateLimiter() + + +async def check_rate_limit(orders_count, max_per_window=None, window_seconds=None): + """应用速率限制""" + await rate_limiter.check_limit(orders_count, max_per_window=max_per_window, window_seconds=window_seconds) + + +async def fetch_instrument_info_from_api(): + """使用 OKX API 获取指定合约的基础信息并更新 CONTRACT_INFO(遵守 20 requests / 2s 限速)""" + global CONTRACT_INFO, TICK_SIZE, SYMBOL + try: + # OKX 文档:获取交易产品基础信息,限速 20 次/2s(User ID + Instrument Type) + await check_rate_limit(1, max_per_window=20, window_seconds=2) + + log_action("合约信息", f"请求合约信息: instType={CONTRACT_INFO.get('instType','SWAP')} instId={CONTRACT_INFO.get('symbol')}", "debug") + # account_api.get_instruments(instType="SWAP") or with instId + response = await asyncio.to_thread(account_api.get_instruments, instType=CONTRACT_INFO.get("instType", "SWAP")) + log_action("合约信息", "收到合约信息响应", "debug", response) + + if response.get("code") == "0" and response.get("data"): + # 寻找匹配的 instId(SYMBOL 形如 ETH-USDT-SWAP) + inst_list = response.get("data", []) + # Normalize symbol: OKX api uses instId like 'ETH-USDT-SWAP' depending on API; try match by prefix + target = CONTRACT_INFO.get("symbol") + found = None + for item in inst_list: + # fields often instId, tickSz, minSz, ctVal, lotSz, ctValCcy + inst_id = item.get("instId") or item.get("inst_id") or "" + if inst_id == target or inst_id.startswith(target.split('-')[0] + '-'): + found = item + break + if not found: + # fallback: try exact match where instId may be without '-SWAP' + for item in inst_list: + inst_id = item.get("instId") or "" + if target.split('-')[0] in inst_id and 'USDT' in inst_id: + found = item + break + + if found: + # Parse and update CONTRACT_INFO fields if present + minSz = safe_float(found.get("minSz", CONTRACT_INFO["minSz"])) + tickSz = safe_float(found.get("tickSz", CONTRACT_INFO["tickSz"])) + ctVal = found.get("ctVal", CONTRACT_INFO["ctVal"]) + try: + ctVal = float(ctVal) if ctVal not in (None, "", []) else CONTRACT_INFO["ctVal"] + except Exception: + ctVal = CONTRACT_INFO["ctVal"] + lotSz = found.get("lotSz", CONTRACT_INFO["lotSz"]) + try: + lotSz = float(lotSz) if lotSz not in (None, "", []) else CONTRACT_INFO["lotSz"] + except Exception: + lotSz = CONTRACT_INFO["lotSz"] + ctValCcy = found.get("ctValCcy", CONTRACT_INFO.get("ctValCcy", "ETH")) + + CONTRACT_INFO.update({ + "minSz": minSz, + "tickSz": tickSz, + "ctVal": ctVal, + "lotSz": lotSz, + "ctValCcy": ctValCcy + }) + + # Update derived globals + TICK_SIZE = CONTRACT_INFO["tickSz"] + SYMBOL = CONTRACT_INFO["symbol"] + + log_action("合约信息", f"已更新 CONTRACT_INFO: minSz={minSz}, tickSz={tickSz}, ctVal={ctVal}, lotSz={lotSz}", "info") + return True + else: + log_action("合约信息", f"未在返回列表中找到匹配合约: {CONTRACT_INFO.get('symbol')}", "warning", {"returned_count": len(inst_list)}) + return False + else: + log_action("合约信息", f"获取合约信息失败: {response.get('msg','')}", "warning", response) + return False + except Exception as e: + log_action("合约信息", f"获取合约信息异常: {e}", "error", exc_info=True) + return False + + +async def cancel_all_orders(): + """取消所有活跃订单 - 增加详细日志""" + try: + if active_orders: + order_ids = list(active_orders.keys()) + order_count = len(order_ids) + log_action("订单取消", f"取消{order_count}个订单", "info") + + # 应用速率限制 + await check_rate_limit(order_count) + + # 批量取消请求 + cancel_reqs = [{"instId": SYMBOL, "clOrdId": cl_ord_id} for cl_ord_id in order_ids] + log_action("批量取消", "发送批量取消请求", "debug", {"订单列表": order_ids}) + # trade_api.cancel_multiple_orders likely blocking; run in thread + response = await asyncio.to_thread(trade_api.cancel_multiple_orders, cancel_reqs) + + # 记录完整响应 + log_action("批量取消", "收到批量取消响应", "debug", response) + + if response.get("code") == "0": + active_orders.clear() + return True + else: + log_action("批量取消", "批量取消返回非 0 code,未清除本地订单", "warning", response) + # Do not clear locally unless confirmed; attempt best-effort update + return False + return True # 没有订单也算成功 + except Exception as e: + log_action("取消订单", f"取消失败: {str(e)}", "error", exc_info=True) + return False + + +async def cancel_single_order(cl_ord_id): + """取消单个订单 - 增加详细日志""" + try: + # 应用速率限制 + await check_rate_limit(1) + + request = {"instId": SYMBOL, "clOrdId": cl_ord_id} + log_action("取消订单", f"发送取消请求: {cl_ord_id}", "debug", request) + response = await asyncio.to_thread(trade_api.cancel_order, **request) + + # 记录完整响应 + log_action("取消订单", f"收到取消响应: {cl_ord_id}", "debug", response) + + if str(response.get("code", "")) == "0": + log_action("订单取消", f"订单 {cl_ord_id} 取消成功") + # 更新本地 active_orders + if cl_ord_id in active_orders: + del active_orders[cl_ord_id] + return True + else: + log_action("订单取消", f"订单 {cl_ord_id} 取消失败: {response.get('msg', '未知错误')}", "warning") + return False + except Exception as e: + log_action("取消订单", f"取消订单 {cl_ord_id} 失败: {str(e)}", "error") + return False + + +# ======================== 新增: A1阶段 - 仓位建立 ======================== +async def place_market_setup_order(adjust_size): + """根据当前趋势方向市价建仓 - 使用线程池执行阻塞 trade_api 调用""" + # 确保数量有效 + if adjust_size <= 0: + log_action("A1下单", "调整量为0,无需建仓", "warning") + return True + + try: + # 确定下单方向 + side = "buy" if trading_direction == "long" else "sell" + + # 生成订单ID + cl_ord_id = generate_order_id(f"A1_{trading_direction[:1]}") + + # 验证并调整仓位大小 + validated_size = validate_position_size(adjust_size) + eth_value = calculate_contract_value(validated_size) + + request = { + "instId": SYMBOL, + "tdMode": "isolated", + "clOrdId": cl_ord_id, + "side": side, + "posSide": trading_direction, + "ordType": "market", + "sz": str(validated_size) + } + + # 应用速率限制 + await check_rate_limit(1) + + # 发送下单请求(线程池) + log_action("A1下单", "发送市价开仓请求", "debug", request) + response = await asyncio.to_thread(trade_api.place_order, **request) + + # 记录完整响应 + log_action("A1下单", "收到下单响应", "debug", response) + + if str(response.get("code", "")) == "0": + # 记录活跃订单 + ord_id = response['data'][0].get('ordId', '') + active_orders[cl_ord_id] = { + "ord_id": ord_id, + "state": "live", + "type": "market", + "tag": "A1_SETUP", + "create_time": time.time(), + "side": side, + "posSide": trading_direction + } + + log_action("A1下单", f"市价{trading_direction}建仓单已提交", "info", { + "调整仓位": f"{validated_size}张 ({eth_value:.6f} ETH)", + "订单ID": cl_ord_id + }) + return True + + # 处理下单失败 + log_action("A1下单", f"市价建仓失败: {response.get('msg', '未知错误')}", "error", response) + return False + except Exception as e: + log_action("A1下单", f"市价建仓异常: {str(e)}", "error", exc_info=True) + return False + + +# ======================== 重构: A2阶段 - 挂订单对 (OKX批量下单) ======================== +async def calculate_dynamic_offset(): + """计算价格偏移 - 简化版本,只使用固定偏移""" + base_offset = TRADE_STRATEGY["price_offset"] + # 记录简单的调试信息 + log_action("价格偏移", f"使用固定偏移: {base_offset * 100:.2f}% ({base_offset})", "debug") + return base_offset + + +async def place_open_order_only(): + """只挂开仓单的逻辑""" + # 生成订单ID + open_cl_ord_id = generate_order_id("OPEN_ONLY") + + # 计算开仓价格 + offset_percent = await calculate_dynamic_offset() + # 使用 Decimal 以提升精度 + if trading_direction == "long": + open_price = round_price(Decimal(str(current_price)) * (Decimal('1') - Decimal(str(offset_percent)))) + else: + open_price = round_price(Decimal(str(current_price)) * (Decimal('1') + Decimal(str(offset_percent)))) + + # 获取仓位大小 + contract_size = get_position_size() + eth_value = calculate_contract_value(contract_size) + + # 创建开仓单请求 + request = { + "instId": SYMBOL, + "tdMode": "isolated", + "clOrdId": open_cl_ord_id, + "side": "buy" if trading_direction == "long" else "sell", + "posSide": trading_direction, + "ordType": "limit", + "px": str(open_price), + "sz": str(contract_size), + "reduceOnly": False + } + + try: + await check_rate_limit(1) + response = await asyncio.to_thread(trade_api.place_order, **request) + + if str(response.get("code", "")) == "0": + ord_id = response['data'][0].get('ordId', '') + active_orders[open_cl_ord_id] = { + "ord_id": ord_id, + "state": "live", + "type": "limit", + "tag": "OPEN_ONLY", + "create_time": time.time(), + "side": request["side"], + "posSide": trading_direction, + "px": open_price, + "sz": contract_size + } + log_action("开仓单", f"开仓单已提交 {open_cl_ord_id}", "info", { + "价格": open_price, + "方向": trading_direction, + "数量": f"{contract_size}张 ({eth_value:.6f} ETH)" + }) + return True + log_action("开仓单", f"开仓单失败: {response.get('msg', '未知错误')}", "error", response) + return False + except Exception as e: + log_action("开仓单", f"挂单失败: {str(e)}", "error", exc_info=True) + return False + + +async def place_full_order_pair(offset_percent): + """挂完整订单对""" + # 生成订单对ID + pair_id = generate_order_id("PAIR") + open_tag = "BL" if trading_direction == "long" else "SS" # BL = Buy Long, SS = Sell Short + close_tag = "SL" if trading_direction == "short" else "SC" # SL = Sell Long, SC = Cover Short + + # 计算开仓和平仓价格(使用 Decimal 再转 float via round_price) + if trading_direction == "long": + open_price = round_price(Decimal(str(current_price)) * (Decimal('1') - Decimal(str(offset_percent)))) + close_price = round_price(Decimal(str(current_price)) * (Decimal('1') + Decimal(str(offset_percent)))) + else: + open_price = round_price(Decimal(str(current_price)) * (Decimal('1') + Decimal(str(offset_percent)))) + close_price = round_price(Decimal(str(current_price)) * (Decimal('1') - Decimal(str(offset_percent)))) + + # 获取基础仓位大小 + base_contract_size = get_position_size() + + # 获取开仓增量配置值 + order_increment = TRADE_STRATEGY["order_increment"] + + # 开仓单增加配置的增量 + open_contract_size = base_contract_size + order_increment + open_contract_size = validate_position_size(open_contract_size) + + # 平仓单保持原大小 + close_contract_size = base_contract_size + + # 计算ETH价值 + base_eth_value = calculate_contract_value(base_contract_size) + increment_eth_value = calculate_contract_value(order_increment) + open_eth_value = calculate_contract_value(open_contract_size) + close_eth_value = calculate_contract_value(close_contract_size) + + # 创建开仓单 + open_cl_ord_id = generate_order_id(f"{open_tag}_{pair_id}") + open_request = { + "instId": SYMBOL, + "tdMode": "isolated", + "clOrdId": open_cl_ord_id, + "side": "buy" if trading_direction == "long" else "sell", + "posSide": trading_direction, + "ordType": "limit", + "px": str(open_price), + "sz": str(open_contract_size), # 使用修改后的开仓数量 + "reduceOnly": False + } + + # 创建平仓单 + close_cl_ord_id = generate_order_id(f"{close_tag}_{pair_id}") + close_request = { + "instId": SYMBOL, + "tdMode": "isolated", + "clOrdId": close_cl_ord_id, + "side": "sell" if trading_direction == "long" else "buy", + "posSide": trading_direction, + "ordType": "limit", + "px": str(close_price), + "sz": str(close_contract_size), # 使用原始平仓数量 + "reduceOnly": True + } + + # 批量订单请求 + batch_requests = [open_request, close_request] + + # 记录订单对关系(提前记录以处理回调) + order_pair_mapping[pair_id] = { + "open_cl_ord_id": open_cl_ord_id, + "close_cl_ord_id": close_cl_ord_id, + "status": "pending", + "create_time": time.time() + } + + try: + # 应用速率限制 (2个订单) + await check_rate_limit(2) + + # 提交批量订单 + log_action("批量下单", "提交批量订单请求", "info", { + "订单对ID": pair_id, + "开仓请求": open_request, + "平仓请求": close_request, + "价格偏移": f"{offset_percent * 100:.2f}%", + "当前价格": current_price, + "基础仓位": f"{base_contract_size}张 ({base_eth_value:.6f} ETH)", + "开仓增量": f"{order_increment}张 ({increment_eth_value:.6f} ETH)", + "开仓总量": f"{open_contract_size}张 ({open_eth_value:.6f} ETH)" + }) + response = await asyncio.to_thread(trade_api.place_multiple_orders, batch_requests) + + # 记录完整响应 + log_action("批量下单", "收到批量下单响应", "debug", response) + + # 检查主响应代码 + if response.get("code") != "0": + log_action("批量下单", "批量接口主响应错误", "error", response) + del order_pair_mapping[pair_id] + return False + + # 处理每个订单的响应 + order_data = response.get("data", []) + success_orders = [] + failure_orders = [] + + for result in order_data: + cl_ord_id = result.get("clOrdId", "") + s_code = result.get("sCode", "") + ord_id = result.get("ordId", "") # 服务器分配的订单ID + s_msg = result.get("sMsg", "") + + if s_code == "0": # 成功 + success_orders.append({ + "cl_ord_id": cl_ord_id, + "ord_id": ord_id, + "request": next((r for r in batch_requests if r["clOrdId"] == cl_ord_id), None) + }) + else: + failure_orders.append({ + "cl_ord_id": cl_ord_id, + "s_code": s_code, + "s_msg": s_msg, + "ord_id": ord_id + }) + + # 处理部分失败情况 + if failure_orders: + log_action("批量下单", f"{len(failure_orders)}个订单失败", "error", failure_orders) + + # 取消已成功的订单 + if success_orders: + for order in success_orders: + await cancel_single_order(order["cl_ord_id"]) + log_action("批量下单", f"已取消{len(success_orders)}个成功订单", "warning") + + # 清理订单对映射 + if pair_id in order_pair_mapping: + del order_pair_mapping[pair_id] + return False + + # 记录活跃订单 + for order in success_orders: + # 确定订单类型(开仓单还是平仓单) + if open_tag in order["cl_ord_id"]: + order_type = "open" + tag = open_tag + else: + order_type = "close" + tag = close_tag + + # 记录到活跃订单 + active_orders[order["cl_ord_id"]] = { + "ord_id": order["ord_id"], + "state": "live", + "type": "limit", + "tag": tag, + "create_time": time.time(), + "side": order["request"]["side"], + "posSide": trading_direction, + "pair_id": pair_id, + "px": order["request"]["px"], + "sz": order["request"]["sz"] + } + + log_action("订单记录", f"{order_type}订单已记录", "info", { + "cl_ord_id": order["cl_ord_id"], + "ord_id": order["ord_id"], + "price": order["request"]["px"], + "size": f"{order['request']['sz']}张" + }) + + # 更新订单对状态 + order_pair_mapping[pair_id]["status"] = "active" + log_action("批量下单", "✅ 订单对挂单成功", "info", { + "pair_id": pair_id, + "开仓价": open_price, + "平仓价": close_price, + "价格偏移": f"{offset_percent * 100:.2f}%", + "当前价格": current_price, + "基础仓位": f"{base_contract_size}张 ({base_eth_value:.6f} ETH)", + "开仓增量": f"{order_increment}张 ({increment_eth_value:.6f} ETH)", + "开仓总量": f"{open_contract_size}张 ({open_eth_value:.6f} ETH)" + }) + return True + + except Exception as e: + log_action("批量下单", f"批量下单异常: {str(e)}", "error", exc_info=True) + if pair_id in order_pair_mapping: + del order_pair_mapping[pair_id] + return False + + +async def place_order_pair(): + """挂完整订单对 - 简化版本,不考虑持仓均价""" + # 确保价格新鲜 + if not await ensure_price_freshness(): + log_action("A2下单", "无法获取最新价格", "error") + return False + + # 计算动态偏移量 + offset_percent = await calculate_dynamic_offset() + + # 直接挂完整订单对 + log_action("订单对", "无价格限制,挂完整订单对", "info") + return await place_full_order_pair(offset_percent) + + +# ======================== 趋势分析 ======================== +async def analyze_trend(): + """分析Supertrend趋势并返回方向 - 支持 fixed 或 auto 模式""" + global trading_direction, last_trend_check + + # 更新最后趋势检查时间 + last_trend_check = time.time() + + mode = TRADE_STRATEGY.get("trend_mode", "fixed") + if mode == "fixed": + fixed_direction = TRADE_STRATEGY.get("fixed_trend_direction", "long") + if trading_direction != fixed_direction: + log_action("趋势更新", + f"更新趋势方向: {trading_direction} → {fixed_direction}", + "warning", + {"reason": "使用配置的固定趋势方向"}) + trading_direction = fixed_direction + # 返回 False 表示没有触发方向变化事件外部需处理 + return False + else: + # auto 模式:调用 supertrend_analyzer(假设存在同步接口 get_direction 或 analyze) + try: + # 将可能阻塞的计算放到线程池 + direction = await asyncio.to_thread(supertrend_analyzer.get_direction) + # direction 应返回 'long' 或 'short' + if direction not in ["long", "short"]: + log_action("趋势分析", f"Supertrend 返回的方向不在预期集合: {direction}", "error") + return False + if direction != trading_direction: + log_action("趋势更新", f"自动趋势更新: {trading_direction} → {direction}", "info", {"source": "supertrend"}) + trading_direction = direction + return True + return False + except Exception as e: + log_action("趋势分析", f"Supertrend 分析失败: {e}", "error") + return False + + +async def close_position(pos_side): + """平掉指定方向的仓位 - 增加详细日志""" + if pos_side not in ["long", "short"]: + log_action("平仓", f"非法平仓方向: {pos_side}", "error") + return False + + pos_key = get_position_key(SYMBOL, pos_side) + pos_size = position_info[pos_key]["pos"] + + if pos_size <= 0: + log_action("平仓", f"{pos_side} 仓位为0,无需平仓", "info") + return True + + try: + # 市价平仓 + side = "sell" if pos_side == "long" else "buy" + cl_ord_id = generate_order_id(f"MC_{pos_side[:1]}") + + # 验证仓位大小 + validated_size = validate_position_size(pos_size) + eth_value = calculate_contract_value(validated_size) + + request = { + "instId": SYMBOL, + "tdMode": "isolated", + "clOrdId": cl_ord_id, + "side": side, + "posSide": pos_side, + "ordType": "market", + "sz": str(validated_size), + "reduceOnly": True + } + + # 发送下单请求(线程池) + log_action("平仓下单", "发送平仓请求", "debug", request) + response = await asyncio.to_thread(trade_api.place_order, **request) + + # 记录完整响应 + log_action("平仓下单", "收到平仓响应", "debug", response) + + if str(response.get("code", "")) == "0": + log_action("平仓下单", f"{pos_side} 仓位市价平仓提交成功", "info", { + "数量": f"{validated_size}张 ({eth_value:.6f} ETH)", + "订单ID": cl_ord_id + }) + # 在本地更新仓位信息为 0(由于我们没有等待 websocket 真实成交通知) + position_info[pos_key]["pos"] = 0.0 + return True + else: + log_action("平仓下单", f"市价平仓失败: {response.get('msg', '未知错误')}", "error", response) + return False + except Exception as e: + log_action("平仓下单", f"市价平仓异常: {str(e)}", "error", exc_info=True) + return False + + +# ======================== WebSocket 及回调占位(测试用) ======================== +# 为了测试流程,提供最小的 WebSocket 占位(不做实际连接) +async def ws_price_update_simulator(): + """模拟 WebSocket 价格更新 - 在测试时可启动""" + global current_price, last_ws_price_update + while True: + # 模拟价格微动 + if current_price == 0: + current_price = 1000.0 # 初始测试价格 + else: + current_price = current_price * (1 + (random.random() - 0.5) * 0.001) + last_ws_price_update = time.time() + await asyncio.sleep(0.5) + + +# ======================== 主流程入口(测试/调试专用) ======================== +async def main_loop_once(): + """单次运行逻辑,用于测试脚本在没有真实 websocket 回调的环境下运行""" + # 尝试从 API 获取合约信息以覆盖本地 CONTRACT_INFO(如果成功则会修正 minSz/tickSz/ctVal 等) + await fetch_instrument_info_from_api() + + # 刷新账户、仓位、价格 + await update_current_price() + await fetch_account_balance() + await update_position_info() + + # 确保 order_increment 配置有效 + validate_order_increment() + + # 分析趋势(fixed 或 auto) + await analyze_trend() + + # 根据阶段与策略做简单动作(示例) + # 如果没有活跃订单,挂一对订单作为测试 + if not active_orders: + log_action("主流程", "当前无活跃订单,尝试挂一对订单", "info") + success = await place_order_pair() + log_action("主流程", f"挂单对结果: {success}", "info") + else: + log_action("主流程", f"已有活跃订单: {len(active_orders)},跳过挂单", "debug") + + # 记录周期状态 + log_periodic_status() + + +async def main(run_forever=False): + """主入口,用于测试与调试""" + # 启动一个价格模拟器(测试用) + sim_task = asyncio.create_task(ws_price_update_simulator()) + + try: + # 单次运行并退出,或持续运行 + if run_forever: + while True: + await main_loop_once() + await asyncio.sleep(5) + else: + await main_loop_once() + finally: + sim_task.cancel() + try: + await sim_task + except asyncio.CancelledError: + pass + + +if __name__ == "__main__": + # 便于测试:运行一次主流程 + asyncio.run(main(run_forever=False)) diff --git a/test/eth_trade_bot.py b/test/eth_trade_bot.py new file mode 100644 index 0000000..cd088df --- /dev/null +++ b/test/eth_trade_bot.py @@ -0,0 +1,1122 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +ETH perpetual trading bot for OKX (SDK-only, testnet-ready). + +Security: +- Do NOT hard-code API keys into source files. +- This script reads credentials from environment variables: + OKX_API_KEY, OKX_SECRET_KEY, OKX_PASSPHRASE, OKX_FLAG +- For quick local testing, export/set these env vars in your shell before running. + +This file includes: + - SDK init (Account/Trade/MarketData, PrivateWs/PublicWs) + - REST helpers (fetch instrument, balances, positions, orders) + - WebSocket handling (private channels: account, orders, fills, positions, balance_and_position; public: tickers) + - Strategy helpers: continuous-eaten protection, gatekeeper, dynamic offset, place_pair_if_ok +""" + +import asyncio +import time +import json +import logging +import random +import os +import sys +from datetime import datetime, timezone +from collections import defaultdict, deque +from decimal import Decimal, getcontext, ROUND_HALF_UP +from typing import Optional, Dict, Any, List +import math + +# ---------------- Credentials (from environment) ---------------- +# API_KEY = os.getenv("OKX_API_KEY", "") +# SECRET_KEY = os.getenv("OKX_SECRET_KEY", "") +# PASSPHRASE = os.getenv("OKX_PASSPHRASE", "") +# OKX_FLAG = os.getenv("OKX_FLAG", "1") # "1" for testnet, "0" for mainnet + +API_KEY = os.getenv("OKX_API_KEY", "52c6b3db-8827-477d-8e25-9c8b14d816e7") +SECRET_KEY = os.getenv("OKX_SECRET_KEY", "6AA11170CBC857418B3FEA38127703CA") +PASSPHRASE = os.getenv("OKX_PASSPHRASE", "Jinquan169..") +OKX_FLAG = os.getenv("OKX_FLAG", "1") # "1" for testnet, "0" for mainnet + +# ---------------- SDK imports ---------------- +try: + from okx.websocket.WsPrivateAsync import WsPrivateAsync as PrivateWs + from okx.websocket.WsPublicAsync import WsPublicAsync as PublicWs + import okx.Account as Account + import okx.Trade as Trade + import okx.MarketData as MarketData +except Exception: + PrivateWs = None + PublicWs = None + Account = None + Trade = None + MarketData = None + +getcontext().prec = 28 + +# ---------------- Logging ---------------- +LOG_DIR = "logs" +os.makedirs(LOG_DIR, exist_ok=True) +LOG_FILE = os.path.join(LOG_DIR, "trading.log") + +logger = logging.getLogger("eth_trade_bot") +logger.setLevel(logging.DEBUG) +for h in list(logger.handlers): + logger.removeHandler(h) +fh = logging.FileHandler(LOG_FILE, encoding="utf-8") +fh.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) +logger.addHandler(fh) +try: + if hasattr(sys.stdout, "reconfigure"): + try: + sys.stdout.reconfigure(encoding="utf-8") + except Exception: + pass +except Exception: + pass +ch = logging.StreamHandler(sys.stdout) +ch.setFormatter(logging.Formatter('%(message)s')) +logger.addHandler(ch) + + +def log_action(action: str, details: str, level: str = "info", extra: Optional[dict] = None, exc_info: bool = False): + prefix = { + "debug": "🔵", + "info": "🟢", + "warning": "🟠", + "error": "🔴", + "critical": "⛔" + }.get(level, "⚪") + ts = datetime.now(timezone.utc).isoformat() + msg = f"[{ts}] {prefix} {action} - {details}" + if extra: + try: + msg += " | " + json.dumps(extra, ensure_ascii=False) + except Exception: + msg += f" | {extra}" + try: + if level == "debug": + logger.debug(msg, exc_info=exc_info) + elif level == "warning": + logger.warning(msg, exc_info=exc_info) + elif level == "error": + logger.error(msg, exc_info=exc_info) + elif level == "critical": + logger.critical(msg, exc_info=exc_info) + else: + logger.info(msg, exc_info=exc_info) + except UnicodeEncodeError: + safe_msg = msg.encode("utf-8", errors="replace").decode("utf-8") + if level == "debug": + logger.debug(safe_msg, exc_info=exc_info) + elif level == "warning": + logger.warning(safe_msg, exc_info=exc_info) + elif level == "error": + logger.error(safe_msg, exc_info=exc_info) + elif level == "critical": + logger.critical(safe_msg, exc_info=exc_info) + else: + logger.info(safe_msg, exc_info=exc_info) + +# ---------------- Globals & Config ---------------- +SYMBOL = "ETH-USDT-SWAP" +CONTRACT_INFO: Dict[str, Any] = {"symbol": SYMBOL, "minSz": 0.0, "tickSz": 0.0, "ctVal": 0.0, "ctValCcy": "ETH", "instType": "SWAP"} +TICK_SIZE = CONTRACT_INFO["tickSz"] + +STRATEGY = { + "base_notional_fraction": 0.25, + "leverage": 5, + "price_offset": 0.001, + "expected_hold_seconds": 300, + "expected_slippage_pct": 0.0002, + "order_type": "limit" +} + +# Runtime state +account_equity_usdt = 0.0 +initial_equity_usdt = 0.0 +current_price = 0.0 +last_price = 0.0 +price_source = "unknown" + +active_orders: Dict[str, dict] = {} +position_info = defaultdict(lambda: {"pos": 0.0, "avg_px": 0.0, "usdt_value": 0.0}) + +# pairs & lock +active_pairs: Dict[str, dict] = {} +orders_lock = asyncio.Lock() + +# SDK clients +account_api = None +trade_api = None +market_api = None + +# WS instances +_ws_instance = None +_public_ws_instance = None + +# Dedup sets +seen_trade_ids = set() +seen_filled_ordids = set() +seen_reqids = set() + +# Public ticker best bid/ask +best_bid = 0.0 +best_ask = 0.0 + +# Controls +ENABLE_FILLS_CHANNEL = False +ENABLE_PUBLIC_TICKER = True + +# ---------------- Utilities ---------------- +def safe_float(v, default=0.0): + try: + if v is None or v == "": + return default + return float(v) + except Exception: + return default + + +def generate_order_id(prefix: str = "o"): + suffix = ''.join(random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=10)) + ts = int(time.time() * 1000) % 1000000 + return f"{prefix}{ts}{suffix}"[:32] + + +def round_price_by_tick(p: float, tick: float): + try: + if tick <= 0: + return p + q = Decimal(str(p)) / Decimal(str(tick)) + qr = q.quantize(Decimal("1"), rounding=ROUND_HALF_UP) + return float((qr * Decimal(str(tick))).normalize()) + except Exception: + return p + + +def round_to_min_size(size: float) -> float: + min_sz = CONTRACT_INFO.get("minSz", 0) or 0 + try: + m = Decimal(str(min_sz)) + s = Decimal(str(size)) + if m == 0: + return float(s) + q = (s / m).quantize(Decimal("1"), rounding=ROUND_HALF_UP) + rounded = (q * m).normalize() + return float(rounded) + except Exception: + return float(size) + + +# ---------------- Rate limiter ---------------- +class RateLimiter: + def __init__(self): + self.window_start = time.time() + self.count = 0 + self.window_seconds = 2 + self.max_per_window = 300 + self.last_request_time = 0.0 + + async def wait_for(self, n=1, max_per_window=None, window_seconds=None): + max_w = max_per_window if max_per_window is not None else self.max_per_window + win = window_seconds if window_seconds is not None else self.window_seconds + while True: + now = time.time() + if now - self.window_start >= win: + self.window_start = now + self.count = 0 + if self.count + n <= max_w: + if now - self.last_request_time < 0.05: + await asyncio.sleep(max(0.0, 0.05 - (now - self.last_request_time))) + self.count += n + self.last_request_time = time.time() + return + await asyncio.sleep(0.05) + + +rate_limiter = RateLimiter() + +# ---------------- Initialization ---------------- +def initialize_clients(): + global account_api, trade_api, market_api + if Account is None or Trade is None or MarketData is None: + log_action("初始化", "OKX SDK 未安装,请 pip install python-okx", "error") + raise RuntimeError("OKX SDK not installed") + account_api = Account.AccountAPI(API_KEY, SECRET_KEY, PASSPHRASE, False, OKX_FLAG) + trade_api = Trade.TradeAPI(API_KEY, SECRET_KEY, PASSPHRASE, False, OKX_FLAG) + market_api = MarketData.MarketAPI(API_KEY, SECRET_KEY, PASSPHRASE, False, OKX_FLAG) + log_action("初始化", f"OKX SDK 已初始化 (flag={OKX_FLAG})", "info") + + +# ---------------- REST helpers ---------------- +async def fetch_instrument_info(): + global CONTRACT_INFO, TICK_SIZE, SYMBOL + await rate_limiter.wait_for(1, max_per_window=20, window_seconds=2) + try: + resp = await asyncio.to_thread(account_api.get_instruments, instType=CONTRACT_INFO.get("instType", "SWAP"), instId=SYMBOL) + except TypeError: + resp = await asyncio.to_thread(account_api.get_instruments, CONTRACT_INFO.get("instType", "SWAP"), SYMBOL) + log_action("合约", "拿到合约信息", "debug", {"resp_code": resp.get("code") if isinstance(resp, dict) else None}) + if not isinstance(resp, dict) or str(resp.get("code", "")) != "0" or not resp.get("data"): + log_action("合约", "获取合约信息失败", "warning", resp) + return False + data = resp.get("data", []) + found = None + for it in data: + if it.get("instId") == SYMBOL: + found = it + break + if not found: + for it in data: + if it.get("instId", "").startswith(SYMBOL.split("-")[0]) and "USDT" in it.get("instId", ""): + found = it + break + if not found: + log_action("合约", "没有找到合约信息", "error", {"symbol": SYMBOL}) + return False + CONTRACT_INFO["minSz"] = safe_float(found.get("minSz", CONTRACT_INFO.get("minSz", 0))) + CONTRACT_INFO["tickSz"] = safe_float(found.get("tickSz", CONTRACT_INFO.get("tickSz", 0))) + CONTRACT_INFO["ctVal"] = safe_float(found.get("ctVal", CONTRACT_INFO.get("ctVal", 0))) + CONTRACT_INFO["ctValCcy"] = found.get("ctValCcy", CONTRACT_INFO.get("ctValCcy")) + TICK_SIZE = CONTRACT_INFO["tickSz"] + log_action("合约", "合约信息更新成功", "info", CONTRACT_INFO) + return True + + +async def update_price_from_rest(): + global current_price, last_price, price_source + await rate_limiter.wait_for(1) + try: + resp = await asyncio.to_thread(market_api.get_ticker, instId=SYMBOL) + except TypeError: + resp = await asyncio.to_thread(market_api.get_ticker, SYMBOL) + if isinstance(resp, dict) and str(resp.get("code", "")) in ("0", 0) and resp.get("data"): + d = resp["data"][0] + p = None + for k in ("last", "lastPx", "price"): + if k in d: + p = safe_float(d.get(k)) + break + if p is not None: + last_price = current_price + current_price = p + price_source = "rest" + log_action("价格", f"REST 价格更新 {current_price}", "debug") + return True + return False + + +async def fetch_account_and_positions(): + global account_equity_usdt, initial_equity_usdt + await rate_limiter.wait_for(1) + try: + resp = await asyncio.to_thread(account_api.get_account_balance, ccy="") + except TypeError: + resp = await asyncio.to_thread(account_api.get_account_balance, "") + log_action("账户", "余额获取(简略)", "debug") + try: + if isinstance(resp, dict) and resp.get("data"): + total = 0 + for grp in resp.get("data", []): + for d in grp.get("details", []) if grp.get("details") else []: + if (d.get("ccy") or d.get("currency") or "").upper() == "USDT": + total += safe_float(d.get("eq", d.get("availEq", 0)), 0.0) + if total > 0: + account_equity_usdt = float(total) + if initial_equity_usdt == 0.0: + initial_equity_usdt = account_equity_usdt + except Exception: + pass + try: + resp2 = await asyncio.to_thread(account_api.get_positions, instType=CONTRACT_INFO.get("instType", "SWAP"), instId=SYMBOL) + log_action("仓位", "仓位查询返回(简略)", "debug", resp2) + except Exception: + pass + + +# ---------------- Order helpers ---------------- +async def place_order_simple(side: str, pos_side: str, ord_type: str, sz: str, px: Optional[str] = None, cl_ord_id: Optional[str] = None): + cl = cl_ord_id or generate_order_id("cl") + try: + szf = float(sz) + except Exception: + szf = float(sz) + minsz = CONTRACT_INFO.get("minSz", 0) or 0.0 + if minsz > 0 and szf < minsz: + szf = minsz + sz_str = str(szf) + px_str = None + if px is not None: + px_str = str(round_price_by_tick(float(px), CONTRACT_INFO.get("tickSz", 0.0))) + req = {"instId": SYMBOL, "tdMode": "isolated", "clOrdId": cl, "side": side, "ordType": ord_type, "sz": sz_str} + if pos_side: + req["posSide"] = pos_side + if px_str: + req["px"] = px_str + await rate_limiter.wait_for(1, max_per_window=60, window_seconds=2) + try: + resp = await asyncio.to_thread(trade_api.place_order, **req) + except Exception as e: + log_action("下单", f"下单异常: {e}", "error", exc_info=True) + prev = active_orders.get(cl, {}) + prev.update({"ord_id": "", "cl": cl, "px": px_str, "sz": sz_str, "state": "error", "raw": str(e)}) + active_orders[cl] = prev + return {"code": "-1", "msg": str(e)} + if isinstance(resp, dict): + data0 = (resp.get("data") or [{}])[0] + ord_id = data0.get("ordId") or "" + prev = active_orders.get(cl, {}) + prev.update({"ord_id": ord_id, "cl": cl, "px": px_str, "sz": sz_str, "state": "accepted", "raw": data0}) + active_orders[cl] = prev + else: + prev = active_orders.get(cl, {}) + prev.update({"ord_id": "", "cl": cl, "px": px_str, "sz": sz_str, "state": "unknown", "raw": resp}) + active_orders[cl] = prev + log_action("下单", f"已提交订单 cl={cl}", "info", resp) + return resp + + +async def cancel_order_by_cl(cl: str): + await rate_limiter.wait_for(1, max_per_window=60, window_seconds=2) + try: + resp = await asyncio.to_thread(trade_api.cancel_order, instId=SYMBOL, clOrdId=cl) + log_action("撤单", f"撤单请求提交 cl={cl}", "info", resp) + return resp + except Exception as e: + log_action("撤单", f"撤单异常: {e}", "error", exc_info=True) + return {"code": "-1", "msg": str(e)} + + +# ---------------- Helper: find other cl in same pair ---------------- +def get_other_cl_of_active_pair(cl: str) -> Optional[str]: + """ + From active_pairs mapping return the other cl in the same pair, or None. + """ + try: + for pid, p in active_pairs.items(): + buy = p.get("buy") or {} + sell = p.get("sell") or {} + if buy.get("cl") == cl: + return sell.get("cl") + if sell.get("cl") == cl: + return buy.get("cl") + except Exception: + return None + return None + + +# ---------------- Exposure helper ---------------- +def net_exposure_exceeds_threshold(threshold_fraction: float = 0.3) -> bool: + """ + Compute an approximate net exposure in USDT and check if it exceeds threshold_fraction of equity. + Uses position_info[*]['usdt_value'] when available, otherwise estimates as pos*avg_px*ctVal. + """ + try: + equity = float(account_equity_usdt or 0.0) + if equity <= 0: + return False + net = 0.0 + for k, v in position_info.items(): + usdt_val = v.get("usdt_value", None) + if usdt_val is None: + pos = safe_float(v.get("pos", 0.0)) + avg_px = safe_float(v.get("avg_px", current_price or 0.0)) + ct = safe_float(v.get("ctVal", CONTRACT_INFO.get("ctVal", 1.0))) + est = pos * avg_px * (ct or 1.0) + usdt_val = est + net += safe_float(usdt_val, 0.0) + return abs(net) > (threshold_fraction * equity) + except Exception: + return False + + +# ---------------- Price ticks buffer helper ---------------- +_price_ticks = deque() + + +def feed_price_tick(price: float) -> None: + """ + Append a price tick (timestamp, price) into the rolling deque _price_ticks. + Keeps data for VOL_WINDOW_SECONDS seconds (used by estimate_short_volatility). + """ + try: + ts = time.time() + _price_ticks.append((ts, float(price))) + while _price_ticks and _price_ticks[0][0] < ts - VOL_WINDOW_SECONDS: + _price_ticks.popleft() + except Exception: + return + + +# ---------------- Reduce positions to safe level (implementation) ---------------- +async def reduce_positions_to_safe_level(target_fraction: float = 0.1): + """ + 将净暴露降到 equity * target_fraction 以内(示例实现)。 + - target_fraction: 目标净暴露占 equity 的比例,例如 0.1 表示 10%。 + 注意:该实现为示例,使用市价/IOC 下单可能导致滑点,请在 testnet 验证并按需改用限价分批。 + """ + try: + equity = float(account_equity_usdt or initial_equity_usdt or 0.0) + if equity <= 0: + log_action("风控", "账户权益未知或为0,无法降仓", "warning") + return + # compute current net exposure (USDT) + net = 0.0 + positions_to_reduce = [] + for k, v in position_info.items(): + # skip balance-only entries + if str(k).startswith("BAL-"): + continue + usdt_val = v.get("usdt_value", None) + if usdt_val is None: + pos = safe_float(v.get("pos", 0.0)) + avg_px = safe_float(v.get("avg_px", current_price or 0.0)) + ct = safe_float(v.get("ctVal", CONTRACT_INFO.get("ctVal", 1.0))) + est = pos * avg_px * (ct or 1.0) + usdt_val = est + net += safe_float(usdt_val, 0.0) + positions_to_reduce.append((k, v)) + target_exposure = equity * target_fraction + if abs(net) <= target_exposure: + log_action("风控", "净暴露在目标范围内,无需降仓", "info", {"net": net, "target": target_exposure}) + return + reduce_amount = abs(net) - target_exposure + log_action("风控", f"开始降仓,目标减仓金额约 {reduce_amount:.4f} USDT", "warning") + remaining = reduce_amount + # Naive reduction: iterate positions and submit opposite market IOC orders until reduce_amount <= 0 + for (k, v) in positions_to_reduce: + inst = k.split("-")[0] if "-" in k else SYMBOL + pos_size = safe_float(v.get("pos", 0.0)) + if pos_size == 0: + continue + avg_px = safe_float(v.get("avg_px", current_price or 0.0)) + notional = abs(pos_size) * avg_px * v.get("ctVal", CONTRACT_INFO.get("ctVal", 1.0)) if avg_px and pos_size else 0.0 + if pos_size > 0: + side = "sell" + else: + side = "buy" + proportion = min(1.0, remaining / (notional + 1e-12)) + sz_to_close = abs(pos_size) * proportion + sz_s = str(round_to_min_size(sz_to_close)) + if float(sz_s) <= 0: + continue + try: + await rate_limiter.wait_for(1, max_per_window=60, window_seconds=2) + resp = await asyncio.to_thread(trade_api.place_order, instId=inst, tdMode="isolated", side=side, ordType="market", sz=sz_s) + log_action("风控", f"降仓下单 inst={inst} side={side} sz={sz_s}", "info", resp) + except Exception as e: + log_action("风控", f"降仓下单异常: {e}", "error", exc_info=True) + remaining -= notional * proportion + if remaining <= 0: + break + log_action("风控", "降仓动作完成(示例实现)", "info") + except Exception as e: + log_action("风控", f"降仓异常: {e}", "error", exc_info=True) + + +# ---------------- WebSocket message handlers ---------------- +def _handle_account_ws_entry(entry: dict): + try: + data = entry.get("data", []) or [] + if not data: + return + snapshot = data[0] + total_eq = snapshot.get("totalEq") or snapshot.get("adjEq") or "" + if total_eq: + try: + global account_equity_usdt, initial_equity_usdt + account_equity_usdt = float(total_eq) + if initial_equity_usdt == 0.0: + initial_equity_usdt = account_equity_usdt + except Exception: + pass + details = snapshot.get("details", []) or [] + for d in details: + ccy = d.get("ccy") + if not ccy: + continue + avail = safe_float(d.get("availBal", 0.0)) + eq = safe_float(d.get("eq", 0.0)) + try: + position_info[f"BAL-{ccy}"]["pos"] = avail + position_info[f"BAL-{ccy}"]["avg_px"] = 0.0 + position_info[f"BAL-{ccy}"]["usdt_value"] = eq + except Exception: + position_info[f"BAL-{ccy}"] = {"pos": avail, "avg_px": 0.0, "usdt_value": eq} + log_action("WS账户", "处理 account 推送", "debug", {"totalEq": total_eq}) + except Exception as e: + log_action("WS账户处理", f"异常: {e}", "error", exc_info=True) + + +def _handle_order_ws_entry(entry: dict): + try: + ord_id = entry.get("ordId") or "" + cl = entry.get("clOrdId") or "" + trade_id = entry.get("tradeId") or "" + state = entry.get("state") or "" + fill_sz = safe_float(entry.get("fillSz", 0)) + acc_fill = safe_float(entry.get("accFillSz", 0)) + req_id = entry.get("reqId") or "" + if req_id and req_id in seen_reqids: + return + if req_id: + seen_reqids.add(req_id) + if trade_id: + if trade_id in seen_trade_ids: + return + seen_trade_ids.add(trade_id) + if not trade_id and state == "filled" and ord_id: + if ord_id in seen_filled_ordids: + return + seen_filled_ordids.add(ord_id) + key = cl or ord_id or generate_order_id("ws") + ao = active_orders.get(key, {}) + ao.update({ + "ord_id": ord_id, + "cl": cl, + "state": state, + "fillSz": fill_sz, + "accFillSz": acc_fill, + "tradeId": trade_id, + "raw": entry, + "last_update": time.time() + }) + active_orders[key] = ao + log_action("WS订单", f"更新订单 {key} state={state}", "debug", ao) + if entry.get("tradeId") or str(state).lower() in ("filled", "partially_filled", "partial-filled"): + try: + order_side = (entry.get("side") or "").lower() or "buy" + our_filled_side = order_side + cl_local = cl or ao.get("cl") or "" + async def _update_and_handle(): + async with orders_lock: + ao = active_orders.get(cl_local, {}) + ao.update({"raw_ws": entry, "state_ws": state, "last_update_ws": time.time()}) + active_orders[cl_local] = ao + try: + record_fill_event(our_filled_side, entry) + asyncio.create_task(on_fill_event(cl_local, our_filled_side, entry)) + except Exception: + pass + asyncio.create_task(_update_and_handle()) + except Exception: + pass + except Exception as e: + log_action("WS订单处理", f"异常: {e}", "error", exc_info=True) + + +def _handle_positions_ws_entry(entry: dict): + try: + inst = entry.get("instId") or SYMBOL + pos = safe_float(entry.get("pos", 0)) + pos_side = (entry.get("posSide") or "net").lower() + if pos_side == "net": + side = "long" if pos >= 0 else "short" + size = abs(pos) + else: + side = pos_side + size = abs(pos) + key = f"{inst}-{side}" + # store pos and avg_px + avg_px = safe_float(entry.get("avgPx", 0)) + position_info[key]["pos"] = pos if pos_side == "net" else size + position_info[key]["avg_px"] = avg_px + # compute signed usdt_value: long positive, short negative + try: + ct = safe_float(entry.get("ctVal", CONTRACT_INFO.get("ctVal", 1.0))) + except Exception: + ct = CONTRACT_INFO.get("ctVal", 1.0) or 1.0 + # signed exposure: pos (can be negative) * avg_px * ct + signed_usdt = float(pos) * float(avg_px or current_price or 0.0) * float(ct) + position_info[key]["usdt_value"] = signed_usdt + position_info[key]["ctVal"] = ct + log_action("WS仓位", f"{key} -> {size}", "debug", position_info[key]) + except Exception as e: + log_action("WS仓位处理", f"异常: {e}", "error", exc_info=True) + + +def _handle_balance_ws_entry(entry: dict): + log_action("WS余额", "收到 balance_and_position 更新", "debug", entry) + + +def _handle_fills_ws_entry(entry: dict): + trade_id = entry.get("tradeId") + if not trade_id: + return + if trade_id in seen_trade_ids: + return + seen_trade_ids.add(trade_id) + ord_id = entry.get("ordId") or "" + cl = entry.get("clOrdId") or "" + key = cl or ord_id + ao = active_orders.get(key, {}) + ao.update({ + "tradeId": trade_id, + "fillSz": safe_float(entry.get("fillSz", 0)), + "fillPx": safe_float(entry.get("fillPx", 0)), + "last_update": time.time() + }) + active_orders[key] = ao + try: + record_fill_event(entry.get("side") or "buy", entry) + asyncio.create_task(on_fill_event(cl or ao.get("cl",""), entry.get("side") or "buy", entry)) + except Exception: + pass + + +def _ws_message_callback(message: Any): + try: + data = message + if isinstance(message, str): + try: + data = json.loads(message) + except Exception: + log_action("WS", "非 JSON 消息", "debug", {"raw": message}) + return + if "event" in data and data.get("event"): + log_action("WS 事件", f"event={data.get('event')}", "debug", data.get("arg")) + arg = data.get("arg") or {} + channel = arg.get("channel") or data.get("channel") + if channel == "orders": + payloads = data.get("data", []) or [] + if isinstance(payloads, dict): + payloads = [payloads] + for p in payloads: + _handle_order_ws_entry(p) + return + if channel == "positions": + payloads = data.get("data", []) or [] + if isinstance(payloads, dict): + payloads = [payloads] + for p in payloads: + _handle_positions_ws_entry(p) + return + if channel == "balance_and_position": + payloads = data.get("data", []) or [] + for p in payloads: + _handle_balance_ws_entry(p) + return + if channel == "fills": + payloads = data.get("data", []) or [] + for p in payloads: + _handle_fills_ws_entry(p) + return + if channel == "account": + _handle_account_ws_entry(data) + return + log_action("WS 未知消息", "未处理的频道/消息", "debug", data) + except Exception as e: + log_action("WS 回调", f"异常: {e}", "error", exc_info=True) + + +# ---------------- WS keepalive ---------------- +_ws_last_recv = time.time() +_ws_ping_task: Optional[asyncio.Task] = None +_WS_PING_INTERVAL = 20 +_WS_PONG_WAIT = 5 + + +def _ws_mark_recv(): + global _ws_last_recv + _ws_last_recv = time.time() + + +async def _ws_ping_loop(get_ws_callable, interval=_WS_PING_INTERVAL, pong_wait=_WS_PONG_WAIT): + """ + Robust ping/pong loop: + - Try multiple ways to send ping (ws.ping, inner._ws, ws.ws, fallback). + - If cannot send or no pong received within pong_wait, attempt reconnect with simple backoff. + - Avoid raising out of loop; handle exceptions and continue. + """ + backoff_seconds = 1.0 + max_backoff = 30.0 + try: + while True: + await asyncio.sleep(interval) + ws = get_ws_callable() + if ws is None: + # no instance, try reconnect with backoff + log_action("WS 保活", "没有 WS 实例,等待并重试", "debug") + await asyncio.sleep(backoff_seconds) + backoff_seconds = min(max_backoff, backoff_seconds * 1.5) + continue + # reset backoff when we have an instance + backoff_seconds = 1.0 + # 近期已有消息则跳过 ping + if time.time() - _ws_last_recv < interval: + continue + try: + sent = False + # 1) prefer coroutine ping() + if hasattr(ws, "ping") and asyncio.iscoroutinefunction(getattr(ws, "ping")): + await ws.ping() + sent = True + log_action("WS 保活", "通过 ws.ping() 发送 ping", "debug") + elif hasattr(ws, "ping") and callable(getattr(ws, "ping")): + await asyncio.to_thread(ws.ping) + sent = True + log_action("WS 保活", "通过 ws.ping() (sync) 发送 ping", "debug") + # 2) inner attributes commonly used by wrappers + elif hasattr(ws, "_ws"): + inner = getattr(ws, "_ws") + if inner is not None: + if hasattr(inner, "ping"): + if asyncio.iscoroutinefunction(getattr(inner, "ping")): + await inner.ping() + else: + await asyncio.to_thread(inner.ping) + sent = True + log_action("WS 保活", "通过 ws._ws.ping() 发送 ping", "debug") + elif hasattr(inner, "send"): + if asyncio.iscoroutinefunction(getattr(inner, "send")): + await inner.send("ping") + else: + await asyncio.to_thread(inner.send, "ping") + sent = True + log_action("WS 保活", "通过 ws._ws.send() 发送 ping", "debug") + elif hasattr(ws, "ws"): + inner = getattr(ws, "ws") + if inner is not None: + if hasattr(inner, "ping"): + if asyncio.iscoroutinefunction(getattr(inner, "ping")): + await inner.ping() + else: + await asyncio.to_thread(inner.ping) + sent = True + log_action("WS 保活", "通过 ws.ws.ping() 发送 ping", "debug") + elif hasattr(inner, "send"): + if asyncio.iscoroutinefunction(getattr(inner, "send")): + await inner.send("ping") + else: + await asyncio.to_thread(inner.send, "ping") + sent = True + log_action("WS 保活", "通过 ws.ws.send() 发送 ping", "debug") + # 3) fallback to ws.send if exists + elif hasattr(ws, "send"): + if asyncio.iscoroutinefunction(getattr(ws, "send")): + await ws.send("ping") + else: + await asyncio.to_thread(ws.send, "ping") + sent = True + log_action("WS 保活", "通过 ws.send() 发送 ping (fallback)", "debug") + + if not sent: + # 无法发送 ping:记录并尝试重建连接(但用 backoff 防止风暴) + log_action("WS 保活", "WS 实例不支持发送 ping (no ping/send/_ws/ws)", "warning") + try: + await stop_private_ws() + except Exception: + pass + await asyncio.sleep(backoff_seconds) + asyncio.create_task(start_private_ws()) + backoff_seconds = min(max_backoff, backoff_seconds * 1.5) + continue + except Exception as e: + # 发送失败:记录并重连(带 backoff) + log_action("WS 保活", f"发送 ping 失败: {e}", "warning", exc_info=True) + try: + await stop_private_ws() + except Exception: + pass + await asyncio.sleep(backoff_seconds) + asyncio.create_task(start_private_ws()) + backoff_seconds = min(max_backoff, backoff_seconds * 1.5) + continue + + # 等待 pong_wait,看是否有任何消息/心跳到达 + await asyncio.sleep(pong_wait) + if time.time() - _ws_last_recv >= pong_wait: + log_action("WS 保活", "未收到 pong/消息,重连", "warning") + try: + await stop_private_ws() + except Exception: + pass + await asyncio.sleep(backoff_seconds) + asyncio.create_task(start_private_ws()) + backoff_seconds = min(max_backoff, backoff_seconds * 1.5) + except asyncio.CancelledError: + return + except Exception as e: + log_action("WS 保活", f"异常: {e}", "error", exc_info=True) + + +# ---------------- Public / Private WS start/stop ---------------- +async def start_private_ws(): + global _ws_instance, _ws_ping_task + if PrivateWs is None: + log_action("WS", "PrivateWs SDK 未安装,无法启动私有 WS", "error") + return None + try: + if str(OKX_FLAG) == "1": + ws_url = "wss://wspap.okx.com:8443/ws/v5/private" + else: + ws_url = "wss://ws.okx.com:8443/ws/v5/private" + log_action("WS", f"Connecting private WS to {ws_url}", "info") + ws = PrivateWs(apiKey=API_KEY, passphrase=PASSPHRASE, secretKey=SECRET_KEY, url=ws_url, useServerTime=False) + await ws.start() + _ws_instance = ws + log_action("WS", "私有 WS 已启动", "info") + args = [ + {"channel": "positions", "instType": "ANY"}, + {"channel": "balance_and_position"}, + {"channel": "orders", "instType": "ANY"}, + {"channel": "account"} + ] + if ENABLE_FILLS_CHANNEL: + args.append({"channel": "fills"}) + await ws.subscribe(args, callback=_ws_message_callback) + log_action("WS", "已订阅私有频道", "info", {"args": args}) + if _ws_ping_task is None or _ws_ping_task.done(): + _ws_ping_task = asyncio.create_task(_ws_ping_loop(lambda: _ws_instance)) + return ws + except Exception as e: + log_action("WS 启动", f"失败: {e}", "error", exc_info=True) + return None + + +async def stop_private_ws(): + global _ws_instance, _ws_ping_task + try: + if _ws_instance: + try: + await _ws_instance.stop() + except Exception: + pass + _ws_instance = None + if _ws_ping_task: + _ws_ping_task.cancel() + _ws_ping_task = None + log_action("WS", "私有 WS 已停止", "info") + except Exception as e: + log_action("WS 停止", f"异常: {e}", "warning", exc_info=True) + + +async def start_public_ws(): + global _public_ws_instance + if PublicWs is None: + log_action("Public WS", "PublicWs SDK 未安装,跳过", "warning") + return None + try: + if str(OKX_FLAG) == "1": + public_url = "wss://wspap.okx.com:8443/ws/v5/public" + else: + public_url = "wss://ws.okx.com:8443/ws/v5/public" + ws = PublicWs(url=public_url) + await ws.start() + _public_ws_instance = ws + args = [{"channel": "tickers", "instId": SYMBOL}] + await ws.subscribe(args, callback=_public_ws_ticker_callback) + log_action("Public WS", "已订阅 tickers", "info", {"inst": SYMBOL, "url": public_url}) + return ws + except Exception as e: + log_action("Public WS", f"启动失败: {e}", "error", exc_info=True) + return None + + +async def stop_public_ws(): + global _public_ws_instance + try: + if _public_ws_instance: + try: + await _public_ws_instance.stop() + except Exception: + pass + _public_ws_instance = None + log_action("Public WS", "已停止", "info") + except Exception as e: + log_action("Public WS 停止", f"异常: {e}", "warning", exc_info=True) + + +def _public_ws_ticker_callback(message: Any): + global current_price, last_price, price_source, best_bid, best_ask + try: + data = message + if isinstance(message, str): + try: + data = json.loads(message) + except Exception: + return + if "event" in data: + return + arg = data.get("arg") or {} + channel = arg.get("channel") or data.get("channel") + if channel != "tickers": + return + payloads = data.get("data", []) or [] + if isinstance(payloads, dict): + payloads = [payloads] + for p in payloads: + price = None + for k in ("last", "lastPx", "price"): + if k in p: + price = safe_float(p.get(k)) + break + if price is None: + bid = safe_float(p.get("bidPx", 0)) + ask = safe_float(p.get("askPx", 0)) + if bid and ask: + price = (bid + ask) / 2.0 + elif bid: + price = bid + elif ask: + price = ask + if price is None: + continue + best_bid = safe_float(p.get("bidPx", best_bid)) + best_ask = safe_float(p.get("askPx", best_ask)) + last_price = current_price + current_price = float(price) + price_source = "ws_ticker" + try: + feed_price_tick(current_price) + except Exception: + pass + except Exception: + pass + + +# ---------------- Strategy helpers (continuous-eaten, gatekeeper, dynamic offset) ---------------- +N_CONSEC = 3 +WINDOW_SECONDS = 60 +MIN_FILLS_WINDOW = 5 +P_THRESHOLD = 0.7 +PAUSE_SECONDS_AFTER_CONSEC = 120 +SCALE_DOWN_FACTOR = 0.3 + +SAFETY_MARGIN_USDT = 0.5 +MAX_REQUIRED_MOVE_PCT = 0.015 + +MIN_OFFSET_PCT = 0.0008 +VOL_K = 1.0 +VOL_WINDOW_SECONDS = 60 + +_recent_fills = deque() +_consec_same_side = 0 +_last_fill_side: Optional[str] = None +_paused_until = 0.0 +_trend_watch_until = 0.0 +_mark_scale_down_next = False + +_price_ticks = deque() + +def record_fill_event(side: str, fill_info: Dict[str, Any]) -> None: + global _consec_same_side, _last_fill_side, _recent_fills + ts = time.time() + if _last_fill_side == side: + _consec_same_side += 1 + else: + _consec_same_side = 1 + _last_fill_side = side + _recent_fills.append((ts, side, fill_info)) + while _recent_fills and _recent_fills[0][0] < ts - WINDOW_SECONDS: + _recent_fills.popleft() + +def check_window_rule() -> bool: + if len(_recent_fills) < MIN_FILLS_WINDOW: + return False + same = sum(1 for (_, s, _) in _recent_fills if s == _last_fill_side) + frac = same / len(_recent_fills) + return frac >= P_THRESHOLD + +async def on_fill_event(cl: str, side: str, fill_info: Dict[str, Any]): + global _paused_until, _trend_watch_until, _mark_scale_down_next, _consec_same_side + record_fill_event(side, fill_info) + window_trigger = check_window_rule() + if _consec_same_side >= N_CONSEC or window_trigger: + _paused_until = time.time() + PAUSE_SECONDS_AFTER_CONSEC + _trend_watch_until = max(_trend_watch_until, time.time() + PAUSE_SECONDS_AFTER_CONSEC) + _mark_scale_down_next = True + try: + other_cl = get_other_cl_of_active_pair(cl) + if other_cl: + await cancel_order_by_cl(other_cl) + except Exception: + pass + try: + log_action("风控", f"连续被吃保护触发 side={side} consec={_consec_same_side}", "warning", {"fill": fill_info}) + except Exception: + pass + try: + if net_exposure_exceeds_threshold(): + asyncio.create_task(reduce_positions_to_safe_level()) + except Exception: + pass + return + return + +def is_paused() -> bool: + return time.time() < _paused_until + +def mark_and_consume_scale_down() -> bool: + global _mark_scale_down_next + if _mark_scale_down_next: + _mark_scale_down_next = False + return True + return False + +def should_place_pair(position_value: float, + fee_maker: float, + fee_taker: float, + funding_rate_per_8h: float, + expected_hold_seconds: float, + expected_slippage_pct: float, + safety_margin_usdt: float = SAFETY_MARGIN_USDT, + max_required_move_pct: float = MAX_REQUIRED_MOVE_PCT) -> Dict[str, Any]: + roundtrip_fee_usdt = position_value * (fee_maker + fee_maker) + funding_usdt = position_value * funding_rate_per_8h * (expected_hold_seconds / (8*3600)) + slippage_usdt = position_value * expected_slippage_pct + total_cost_usdt = roundtrip_fee_usdt + funding_usdt + slippage_usdt + safety_margin_usdt + required_move_pct = total_cost_usdt / position_value + can_place = (required_move_pct < max_required_move_pct) + return { + "roundtrip_fee_usdt": roundtrip_fee_usdt, + "funding_usdt": funding_usdt, + "slippage_usdt": slippage_usdt, + "total_cost_usdt": total_cost_usdt, + "required_move_pct": required_move_pct, + "can_place": can_place + } + +# ---------------- Main loops ---------------- +async def main_loop_once(): + ok = await fetch_instrument_info() + if not ok: + log_action("主流程", "未能加载合约信息,退出", "error") + return + await update_price_from_rest() + await fetch_account_and_positions() + log_action("主流程", "一次循环完成", "info") + + +async def main_loop_continuous(): + private_task = asyncio.create_task(start_private_ws()) + public_task = None + if ENABLE_PUBLIC_TICKER: + public_task = asyncio.create_task(start_public_ws()) + try: + while True: + try: + await main_loop_once() + except Exception as e: + log_action("主循环", f"异常: {e}", "error", exc_info=True) + await asyncio.sleep(5) + finally: + try: + await stop_private_ws() + except Exception: + pass + if public_task: + try: + await stop_public_ws() + except Exception: + pass + + +# ---------------- Entrypoint ---------------- +if __name__ == "__main__": + initialize_clients() + log_action("启动", f"OKX_FLAG={OKX_FLAG} (测试网=1)", "info") + asyncio.run(main_loop_continuous())