diff --git a/backend/generators/image_api.py b/backend/generators/image_api.py index 62e99b6..fcf96f9 100644 --- a/backend/generators/image_api.py +++ b/backend/generators/image_api.py @@ -36,21 +36,36 @@ class ImageApiGenerator(ImageGeneratorBase): def __init__(self, config: Dict[str, Any]): super().__init__(config) logger.debug("初始化 ImageApiGenerator...") - self.base_url = config.get('base_url', 'https://api.example.com').rstrip('/').rstrip('/v1') + raw_base_url = (config.get('base_url') or 'https://api.example.com').rstrip('/') + if raw_base_url.endswith('/v3') or ('volces' in raw_base_url.lower()) or ('ark' in raw_base_url.lower()): + self.api_version = 'v3' + else: + self.api_version = 'v1' + + if raw_base_url.endswith('/v1') or raw_base_url.endswith('/v3'): + self.base_url = raw_base_url.rsplit('/', 1)[0] + else: + self.base_url = raw_base_url + self.model = config.get('model', 'default-model') self.default_aspect_ratio = config.get('default_aspect_ratio', '3:4') self.image_size = config.get('image_size', '4K') # 支持自定义端点路径 - endpoint_type = config.get('endpoint_type', '/v1/images/generations') + endpoint_type = config.get('endpoint_type', f'/{self.api_version}/images/generations') # 兼容旧的简写格式 if endpoint_type == 'images': - endpoint_type = '/v1/images/generations' + endpoint_type = f'/{self.api_version}/images/generations' elif endpoint_type == 'chat': - endpoint_type = '/v1/chat/completions' + endpoint_type = f'/{self.api_version}/chat/completions' # 确保以 / 开头 if not endpoint_type.startswith('/'): endpoint_type = '/' + endpoint_type + + if endpoint_type.startswith('/v1/') and self.api_version == 'v3': + endpoint_type = '/v3/' + endpoint_type[len('/v1/'):] + elif endpoint_type.startswith('/v3/') and self.api_version == 'v1': + endpoint_type = '/v1/' + endpoint_type[len('/v3/'):] self.endpoint_type = endpoint_type logger.info(f"ImageApiGenerator 初始化完成: base_url={self.base_url}, model={self.model}, endpoint={self.endpoint_type}") @@ -133,7 +148,8 @@ def _generate_via_images_api( "prompt": prompt, "response_format": "b64_json", "aspect_ratio": aspect_ratio, - "image_size": self.image_size + "image_size": self.image_size, + "watermark": False } # 收集所有参考图片 @@ -299,50 +315,85 @@ def _generate_via_chat_api( result = response.json() logger.debug(f"Chat API 响应: {str(result)[:500]}") - # 解析响应 - if "choices" in result and len(result["choices"]) > 0: - choice = result["choices"][0] - if "message" in choice and "content" in choice["message"]: - content = choice["message"]["content"] + extracted = self._extract_image_from_chat_result(result) + if extracted is not None: + return extracted + + raise Exception( + "❌ 无法从 Chat API 响应中提取图片数据\n\n" + f"【响应内容】\n{str(result)[:500]}\n\n" + "【可能原因】\n" + "1. 该模型不支持图片生成\n" + "2. 响应格式与预期不符\n" + "3. 提示词被安全过滤\n\n" + "【解决方案】\n" + "1. 确认模型名称正确\n" + "2. 修改提示词后重试" + ) + + def _extract_image_from_chat_result(self, result: Any) -> Optional[bytes]: + import re + + if isinstance(result, dict): + data = result.get('data') + if isinstance(data, list) and data: + item = data[0] if isinstance(data[0], dict) else None + if item: + if 'b64_json' in item and isinstance(item['b64_json'], str): + b64_string = item['b64_json'] + if b64_string.startswith('data:'): + b64_string = b64_string.split(',', 1)[1] + return base64.b64decode(b64_string) + if 'url' in item and isinstance(item['url'], str): + return self._download_image(item['url'].strip()) + + choices = result.get('choices') + if isinstance(choices, list) and choices: + choice0 = choices[0] if isinstance(choices[0], dict) else None + message = (choice0 or {}).get('message') + content = (message or {}).get('content') if isinstance(message, dict) else None + + if isinstance(content, list): + for part in content: + if not isinstance(part, dict): + continue + part_type = part.get('type') + if part_type in {'image_url', 'image'}: + image_url = part.get('image_url') + if isinstance(image_url, dict): + url = image_url.get('url') + if isinstance(url, str) and url.strip(): + url = url.strip() + if url.startswith('data:image'): + return base64.b64decode(url.split(',', 1)[1]) + if url.startswith('http://') or url.startswith('https://'): + return self._download_image(url) if isinstance(content, str): - # Markdown 图片链接: ![xxx](url) + text = content.strip() + pattern = r'!\[.*?\]\((https?://[^\s\)]+)\)' - urls = re.findall(pattern, content) + urls = re.findall(pattern, text) if urls: logger.info(f"从 Markdown 提取到 {len(urls)} 张图片,下载第一张...") return self._download_image(urls[0]) - # Markdown 图片 Base64: ![xxx](data:image/...) base64_pattern = r'!\[.*?\]\((data:image\/[^;]+;base64,[^\s\)]+)\)' - base64_urls = re.findall(base64_pattern, content) + base64_urls = re.findall(base64_pattern, text) if base64_urls: logger.info("从 Markdown 提取到 Base64 图片数据") - base64_data = base64_urls[0].split(",")[1] + base64_data = base64_urls[0].split(",", 1)[1] return base64.b64decode(base64_data) - # 纯 Base64 data URL - if content.startswith("data:image"): + if text.startswith('data:image') and ',' in text: logger.info("检测到 Base64 图片数据") - base64_data = content.split(",")[1] - return base64.b64decode(base64_data) + return base64.b64decode(text.split(',', 1)[1]) - # 纯 URL - if content.startswith("http://") or content.startswith("https://"): + if text.startswith('http://') or text.startswith('https://'): logger.info("检测到图片 URL") - return self._download_image(content.strip()) + return self._download_image(text) - raise Exception( - "❌ 无法从 Chat API 响应中提取图片数据\n\n" - f"【响应内容】\n{str(result)[:500]}\n\n" - "【可能原因】\n" - "1. 该模型不支持图片生成\n" - "2. 响应格式与预期不符\n" - "3. 提示词被安全过滤\n\n" - "【解决方案】\n" - "1. 确认模型名称正确\n" - "2. 修改提示词后重试" - ) + return None def _download_image(self, url: str) -> bytes: """下载图片并返回二进制数据""" diff --git a/backend/generators/openai_compatible.py b/backend/generators/openai_compatible.py index 2fdbf76..7517e1a 100644 --- a/backend/generators/openai_compatible.py +++ b/backend/generators/openai_compatible.py @@ -69,22 +69,38 @@ def __init__(self, config: Dict[str, Any]): "解决方案:在系统设置页面编辑该服务商,填写 Base URL" ) - # 规范化 base_url:去除末尾 /v1 - self.base_url = self.base_url.rstrip('/').rstrip('/v1') + # 规范化 base_url 并识别 API 版本(OpenAI 默认 v1,火山引擎 Ark/Doubao 为 v3) + raw_base_url = (self.base_url or '').rstrip('/') + + # 识别 Doubao/Ark:如果 base_url 以 /v3 结尾,或域名包含 volces/ark,则使用 v3 + if raw_base_url.endswith('/v3') or ('volces' in raw_base_url.lower()) or ('ark' in raw_base_url.lower()): + self.api_version = 'v3' + else: + self.api_version = 'v1' + + # 去掉末尾版本号,保留基础路径(例如 https://ark.../api) + if raw_base_url.endswith('/v1') or raw_base_url.endswith('/v3'): + self.base_url = raw_base_url.rsplit('/', 1)[0] + else: + self.base_url = raw_base_url # 默认模型 self.default_model = config.get('model', 'dall-e-3') # API 端点类型: 支持完整路径 (如 '/v1/images/generations') 或简写 ('images', 'chat') - endpoint_type = config.get('endpoint_type', '/v1/images/generations') - # 兼容旧的简写格式 + default_endpoint = f"/{self.api_version}/images/generations" + endpoint_type = config.get('endpoint_type', default_endpoint) + # 兼容简写:根据识别出的版本拼接正确路径 if endpoint_type == 'images': - endpoint_type = '/v1/images/generations' + endpoint_type = f"/{self.api_version}/images/generations" elif endpoint_type == 'chat': - endpoint_type = '/v1/chat/completions' + endpoint_type = f"/{self.api_version}/chat/completions" self.endpoint_type = endpoint_type - logger.info(f"OpenAICompatibleGenerator 初始化完成: base_url={self.base_url}, model={self.default_model}, endpoint={self.endpoint_type}") + logger.info( + f"OpenAICompatibleGenerator 初始化完成: base_url={self.base_url}, api_version={self.api_version}, " + f"model={self.default_model}, endpoint={self.endpoint_type}" + ) def validate_config(self) -> bool: """验证配置""" diff --git a/backend/routes/config_routes.py b/backend/routes/config_routes.py index 4f58eaf..a829753 100644 --- a/backend/routes/config_routes.py +++ b/backend/routes/config_routes.py @@ -364,8 +364,8 @@ def _test_openai_compatible(config: dict, test_prompt: str) -> dict: """测试 OpenAI 兼容接口""" import requests - base_url = config['base_url'].rstrip('/').rstrip('/v1') if config.get('base_url') else 'https://api.openai.com' - url = f"{base_url}/v1/chat/completions" + base_url, api_version = _normalize_base_url_and_version(config.get('base_url')) + url = f"{base_url}/{api_version}/chat/completions" payload = { "model": config.get('model') or 'gpt-3.5-turbo', @@ -396,8 +396,8 @@ def _test_image_api(config: dict) -> dict: """测试图片 API 连接""" import requests - base_url = config['base_url'].rstrip('/').rstrip('/v1') if config.get('base_url') else 'https://api.openai.com' - url = f"{base_url}/v1/models" + base_url, api_version = _normalize_base_url_and_version(config.get('base_url')) + url = f"{base_url}/{api_version}/models" response = requests.get( url, @@ -426,3 +426,17 @@ def _check_response(result_text: str) -> dict: "success": True, "message": f"连接成功,但响应内容不符合预期: {result_text[:100]}" } + + +def _normalize_base_url_and_version(base_url: str | None) -> tuple[str, str]: + raw = (base_url or 'https://api.openai.com').rstrip('/') + lowered = raw.lower() + if raw.endswith('/v3') or ('volces' in lowered) or ('ark' in lowered): + api_version = 'v3' + else: + api_version = 'v1' + + if raw.endswith('/v1') or raw.endswith('/v3'): + raw = raw.rsplit('/', 1)[0] + + return raw, api_version diff --git a/backend/utils/text_client.py b/backend/utils/text_client.py index 2bd2be6..bd41aa9 100644 --- a/backend/utils/text_client.py +++ b/backend/utils/text_client.py @@ -48,10 +48,26 @@ def __init__(self, api_key: str = None, base_url: str = None, endpoint_type: str "解决方案:在系统设置页面编辑文本生成服务商,填写 API Key" ) - self.base_url = (base_url or "https://api.openai.com").rstrip('/').rstrip('/v1') + # 规范化 base_url 并识别 API 版本(OpenAI 默认 v1,火山引擎 Ark/Doubao 为 v3) + raw_base_url = (base_url or "https://api.openai.com").rstrip('/') + if raw_base_url.endswith('/v3') or ('volces' in raw_base_url.lower()) or ('ark' in raw_base_url.lower()): + self.api_version = 'v3' + else: + self.api_version = 'v1' - # 支持自定义端点路径 - endpoint = endpoint_type or '/v1/chat/completions' + # 去掉末尾版本号,保留基础路径(例如 https://ark.../api) + if raw_base_url.endswith('/v1') or raw_base_url.endswith('/v3'): + self.base_url = raw_base_url.rsplit('/', 1)[0] + else: + self.base_url = raw_base_url + + # 支持自定义端点路径;简写会自动映射至正确版本 + default_endpoint = f"/{self.api_version}/chat/completions" + endpoint = endpoint_type or default_endpoint + if endpoint == 'chat': + endpoint = f"/{self.api_version}/chat/completions" + elif endpoint == 'images': + endpoint = f"/{self.api_version}/images/generations" # 确保端点以 / 开头 if not endpoint.startswith('/'): endpoint = '/' + endpoint