Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 84 additions & 33 deletions backend/generators/image_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,36 @@ class ImageApiGenerator(ImageGeneratorBase):
def __init__(self, config: Dict[str, Any]):
super().__init__(config)
logger.debug("初始化 ImageApiGenerator...")
self.base_url = config.get('base_url', 'https://api.example.com').rstrip('/').rstrip('/v1')
raw_base_url = (config.get('base_url') or 'https://api.example.com').rstrip('/')
if raw_base_url.endswith('/v3') or ('volces' in raw_base_url.lower()) or ('ark' in raw_base_url.lower()):
self.api_version = 'v3'
else:
self.api_version = 'v1'

if raw_base_url.endswith('/v1') or raw_base_url.endswith('/v3'):
self.base_url = raw_base_url.rsplit('/', 1)[0]
else:
self.base_url = raw_base_url

self.model = config.get('model', 'default-model')
self.default_aspect_ratio = config.get('default_aspect_ratio', '3:4')
self.image_size = config.get('image_size', '4K')

# 支持自定义端点路径
endpoint_type = config.get('endpoint_type', '/v1/images/generations')
endpoint_type = config.get('endpoint_type', f'/{self.api_version}/images/generations')
# 兼容旧的简写格式
if endpoint_type == 'images':
endpoint_type = '/v1/images/generations'
endpoint_type = f'/{self.api_version}/images/generations'
elif endpoint_type == 'chat':
endpoint_type = '/v1/chat/completions'
endpoint_type = f'/{self.api_version}/chat/completions'
# 确保以 / 开头
if not endpoint_type.startswith('/'):
endpoint_type = '/' + endpoint_type

if endpoint_type.startswith('/v1/') and self.api_version == 'v3':
endpoint_type = '/v3/' + endpoint_type[len('/v1/'):]
elif endpoint_type.startswith('/v3/') and self.api_version == 'v1':
endpoint_type = '/v1/' + endpoint_type[len('/v3/'):]
self.endpoint_type = endpoint_type

logger.info(f"ImageApiGenerator 初始化完成: base_url={self.base_url}, model={self.model}, endpoint={self.endpoint_type}")
Expand Down Expand Up @@ -133,7 +148,8 @@ def _generate_via_images_api(
"prompt": prompt,
"response_format": "b64_json",
"aspect_ratio": aspect_ratio,
"image_size": self.image_size
"image_size": self.image_size,
"watermark": False
}

# 收集所有参考图片
Expand Down Expand Up @@ -299,50 +315,85 @@ def _generate_via_chat_api(
result = response.json()
logger.debug(f"Chat API 响应: {str(result)[:500]}")

# 解析响应
if "choices" in result and len(result["choices"]) > 0:
choice = result["choices"][0]
if "message" in choice and "content" in choice["message"]:
content = choice["message"]["content"]
extracted = self._extract_image_from_chat_result(result)
if extracted is not None:
return extracted

raise Exception(
"❌ 无法从 Chat API 响应中提取图片数据\n\n"
f"【响应内容】\n{str(result)[:500]}\n\n"
"【可能原因】\n"
"1. 该模型不支持图片生成\n"
"2. 响应格式与预期不符\n"
"3. 提示词被安全过滤\n\n"
"【解决方案】\n"
"1. 确认模型名称正确\n"
"2. 修改提示词后重试"
)

def _extract_image_from_chat_result(self, result: Any) -> Optional[bytes]:
import re

if isinstance(result, dict):
data = result.get('data')
if isinstance(data, list) and data:
item = data[0] if isinstance(data[0], dict) else None
if item:
if 'b64_json' in item and isinstance(item['b64_json'], str):
b64_string = item['b64_json']
if b64_string.startswith('data:'):
b64_string = b64_string.split(',', 1)[1]
return base64.b64decode(b64_string)
if 'url' in item and isinstance(item['url'], str):
return self._download_image(item['url'].strip())

choices = result.get('choices')
if isinstance(choices, list) and choices:
choice0 = choices[0] if isinstance(choices[0], dict) else None
message = (choice0 or {}).get('message')
content = (message or {}).get('content') if isinstance(message, dict) else None

if isinstance(content, list):
for part in content:
if not isinstance(part, dict):
continue
part_type = part.get('type')
if part_type in {'image_url', 'image'}:
image_url = part.get('image_url')
if isinstance(image_url, dict):
url = image_url.get('url')
if isinstance(url, str) and url.strip():
url = url.strip()
if url.startswith('data:image'):
return base64.b64decode(url.split(',', 1)[1])
if url.startswith('http://') or url.startswith('https://'):
return self._download_image(url)

if isinstance(content, str):
# Markdown 图片链接: ![xxx](url)
text = content.strip()

pattern = r'!\[.*?\]\((https?://[^\s\)]+)\)'
urls = re.findall(pattern, content)
urls = re.findall(pattern, text)
if urls:
logger.info(f"从 Markdown 提取到 {len(urls)} 张图片,下载第一张...")
return self._download_image(urls[0])

# Markdown 图片 Base64: ![xxx](data:image/...)
base64_pattern = r'!\[.*?\]\((data:image\/[^;]+;base64,[^\s\)]+)\)'
base64_urls = re.findall(base64_pattern, content)
base64_urls = re.findall(base64_pattern, text)
if base64_urls:
logger.info("从 Markdown 提取到 Base64 图片数据")
base64_data = base64_urls[0].split(",")[1]
base64_data = base64_urls[0].split(",", 1)[1]
return base64.b64decode(base64_data)

# 纯 Base64 data URL
if content.startswith("data:image"):
if text.startswith('data:image') and ',' in text:
logger.info("检测到 Base64 图片数据")
base64_data = content.split(",")[1]
return base64.b64decode(base64_data)
return base64.b64decode(text.split(',', 1)[1])

# 纯 URL
if content.startswith("http://") or content.startswith("https://"):
if text.startswith('http://') or text.startswith('https://'):
logger.info("检测到图片 URL")
return self._download_image(content.strip())
return self._download_image(text)

raise Exception(
"❌ 无法从 Chat API 响应中提取图片数据\n\n"
f"【响应内容】\n{str(result)[:500]}\n\n"
"【可能原因】\n"
"1. 该模型不支持图片生成\n"
"2. 响应格式与预期不符\n"
"3. 提示词被安全过滤\n\n"
"【解决方案】\n"
"1. 确认模型名称正确\n"
"2. 修改提示词后重试"
)
return None

def _download_image(self, url: str) -> bytes:
"""下载图片并返回二进制数据"""
Expand Down
30 changes: 23 additions & 7 deletions backend/generators/openai_compatible.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,22 +69,38 @@ def __init__(self, config: Dict[str, Any]):
"解决方案:在系统设置页面编辑该服务商,填写 Base URL"
)

# 规范化 base_url:去除末尾 /v1
self.base_url = self.base_url.rstrip('/').rstrip('/v1')
# 规范化 base_url 并识别 API 版本(OpenAI 默认 v1,火山引擎 Ark/Doubao 为 v3)
raw_base_url = (self.base_url or '').rstrip('/')

# 识别 Doubao/Ark:如果 base_url 以 /v3 结尾,或域名包含 volces/ark,则使用 v3
if raw_base_url.endswith('/v3') or ('volces' in raw_base_url.lower()) or ('ark' in raw_base_url.lower()):
self.api_version = 'v3'
else:
self.api_version = 'v1'

# 去掉末尾版本号,保留基础路径(例如 https://ark.../api)
if raw_base_url.endswith('/v1') or raw_base_url.endswith('/v3'):
self.base_url = raw_base_url.rsplit('/', 1)[0]
else:
self.base_url = raw_base_url

# 默认模型
self.default_model = config.get('model', 'dall-e-3')

# API 端点类型: 支持完整路径 (如 '/v1/images/generations') 或简写 ('images', 'chat')
endpoint_type = config.get('endpoint_type', '/v1/images/generations')
# 兼容旧的简写格式
default_endpoint = f"/{self.api_version}/images/generations"
endpoint_type = config.get('endpoint_type', default_endpoint)
# 兼容简写:根据识别出的版本拼接正确路径
if endpoint_type == 'images':
endpoint_type = '/v1/images/generations'
endpoint_type = f"/{self.api_version}/images/generations"
elif endpoint_type == 'chat':
endpoint_type = '/v1/chat/completions'
endpoint_type = f"/{self.api_version}/chat/completions"
self.endpoint_type = endpoint_type

logger.info(f"OpenAICompatibleGenerator 初始化完成: base_url={self.base_url}, model={self.default_model}, endpoint={self.endpoint_type}")
logger.info(
f"OpenAICompatibleGenerator 初始化完成: base_url={self.base_url}, api_version={self.api_version}, "
f"model={self.default_model}, endpoint={self.endpoint_type}"
)

def validate_config(self) -> bool:
"""验证配置"""
Expand Down
22 changes: 18 additions & 4 deletions backend/routes/config_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,8 @@ def _test_openai_compatible(config: dict, test_prompt: str) -> dict:
"""测试 OpenAI 兼容接口"""
import requests

base_url = config['base_url'].rstrip('/').rstrip('/v1') if config.get('base_url') else 'https://api.openai.com'
url = f"{base_url}/v1/chat/completions"
base_url, api_version = _normalize_base_url_and_version(config.get('base_url'))
url = f"{base_url}/{api_version}/chat/completions"

payload = {
"model": config.get('model') or 'gpt-3.5-turbo',
Expand Down Expand Up @@ -396,8 +396,8 @@ def _test_image_api(config: dict) -> dict:
"""测试图片 API 连接"""
import requests

base_url = config['base_url'].rstrip('/').rstrip('/v1') if config.get('base_url') else 'https://api.openai.com'
url = f"{base_url}/v1/models"
base_url, api_version = _normalize_base_url_and_version(config.get('base_url'))
url = f"{base_url}/{api_version}/models"

response = requests.get(
url,
Expand Down Expand Up @@ -426,3 +426,17 @@ def _check_response(result_text: str) -> dict:
"success": True,
"message": f"连接成功,但响应内容不符合预期: {result_text[:100]}"
}


def _normalize_base_url_and_version(base_url: str | None) -> tuple[str, str]:
raw = (base_url or 'https://api.openai.com').rstrip('/')
lowered = raw.lower()
if raw.endswith('/v3') or ('volces' in lowered) or ('ark' in lowered):
api_version = 'v3'
else:
api_version = 'v1'

if raw.endswith('/v1') or raw.endswith('/v3'):
raw = raw.rsplit('/', 1)[0]

return raw, api_version
22 changes: 19 additions & 3 deletions backend/utils/text_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,26 @@ def __init__(self, api_key: str = None, base_url: str = None, endpoint_type: str
"解决方案:在系统设置页面编辑文本生成服务商,填写 API Key"
)

self.base_url = (base_url or "https://api.openai.com").rstrip('/').rstrip('/v1')
# 规范化 base_url 并识别 API 版本(OpenAI 默认 v1,火山引擎 Ark/Doubao 为 v3)
raw_base_url = (base_url or "https://api.openai.com").rstrip('/')
if raw_base_url.endswith('/v3') or ('volces' in raw_base_url.lower()) or ('ark' in raw_base_url.lower()):
self.api_version = 'v3'
else:
self.api_version = 'v1'

# 支持自定义端点路径
endpoint = endpoint_type or '/v1/chat/completions'
# 去掉末尾版本号,保留基础路径(例如 https://ark.../api)
if raw_base_url.endswith('/v1') or raw_base_url.endswith('/v3'):
self.base_url = raw_base_url.rsplit('/', 1)[0]
else:
self.base_url = raw_base_url

# 支持自定义端点路径;简写会自动映射至正确版本
default_endpoint = f"/{self.api_version}/chat/completions"
endpoint = endpoint_type or default_endpoint
if endpoint == 'chat':
endpoint = f"/{self.api_version}/chat/completions"
elif endpoint == 'images':
endpoint = f"/{self.api_version}/images/generations"
# 确保端点以 / 开头
if not endpoint.startswith('/'):
endpoint = '/' + endpoint
Expand Down