Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 75 additions & 45 deletions backend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from flask_cors import CORS
from backend.config import Config
from backend.routes import register_routes
from backend.models import db
import os


def setup_logging():
Expand All @@ -20,17 +22,16 @@ def setup_logging():
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.DEBUG)
console_format = logging.Formatter(
'\n%(asctime)s | %(levelname)-8s | %(name)s\n'
' └─ %(message)s',
datefmt='%H:%M:%S'
"\n%(asctime)s | %(levelname)-8s | %(name)s\n └─ %(message)s",
datefmt="%H:%M:%S",
)
console_handler.setFormatter(console_format)
root_logger.addHandler(console_handler)

# 设置各模块的日志级别
logging.getLogger('backend').setLevel(logging.DEBUG)
logging.getLogger('werkzeug').setLevel(logging.INFO)
logging.getLogger('urllib3').setLevel(logging.WARNING)
logging.getLogger("backend").setLevel(logging.DEBUG)
logging.getLogger("werkzeug").setLevel(logging.INFO)
logging.getLogger("urllib3").setLevel(logging.WARNING)

return root_logger

Expand All @@ -41,27 +42,42 @@ def create_app():
logger.info("🚀 正在启动 红墨 AI图文生成器...")

# 检查是否存在前端构建产物(Docker 环境)
frontend_dist = Path(__file__).parent.parent / 'frontend' / 'dist'
frontend_dist = Path(__file__).parent.parent / "frontend" / "dist"
if frontend_dist.exists():
logger.info("📦 检测到前端构建产物,启用静态文件托管模式")
app = Flask(
__name__,
static_folder=str(frontend_dist),
static_url_path=''
)
app = Flask(__name__, static_folder=str(frontend_dist), static_url_path="")
else:
logger.info("🔧 开发模式,前端请单独启动")
app = Flask(__name__)

app.config.from_object(Config)

CORS(app, resources={
r"/api/*": {
"origins": Config.CORS_ORIGINS,
"methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
"allow_headers": ["Content-Type"],
}
})
# 数据库配置
db_path = os.path.join(
os.path.dirname(os.path.dirname(__file__)), "history", "redink.db"
)
app.config["SQLALCHEMY_DATABASE_URI"] = f"sqlite:///{db_path}"
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False

# 确保 history 目录存在
os.makedirs(os.path.dirname(db_path), exist_ok=True)

db.init_app(app)

with app.app_context():
db.create_all()
_ensure_history_indexes()

CORS(
app,
resources={
r"/api/*": {
"origins": Config.CORS_ORIGINS,
"methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
"allow_headers": ["Content-Type"],
}
},
)

# 注册所有 API 路由
register_routes(app)
Expand All @@ -71,16 +87,18 @@ def create_app():

# 根据是否有前端构建产物决定根路由行为
if frontend_dist.exists():
@app.route('/')

@app.route("/")
def serve_index():
return send_from_directory(app.static_folder, 'index.html')
return send_from_directory(app.static_folder, "index.html")

# 处理 Vue Router 的 HTML5 History 模式
@app.errorhandler(404)
def fallback(e):
return send_from_directory(app.static_folder, 'index.html')
return send_from_directory(app.static_folder, "index.html")
else:
@app.route('/')

@app.route("/")
def index():
return {
"message": "红墨 AI图文生成器 API",
Expand All @@ -89,13 +107,29 @@ def index():
"health": "/api/health",
"outline": "POST /api/outline",
"generate": "POST /api/generate",
"images": "GET /api/images/<filename>"
}
"images": "GET /api/images/<filename>",
},
}

return app


def _ensure_history_indexes():
"""确保历史记录相关索引存在(幂等)"""
from sqlalchemy import text

statements = [
"CREATE INDEX IF NOT EXISTS idx_history_records_status ON history_records (status)",
"CREATE INDEX IF NOT EXISTS idx_history_records_created_at ON history_records (created_at)",
"CREATE INDEX IF NOT EXISTS idx_history_records_task_id ON history_records (task_id)",
"CREATE INDEX IF NOT EXISTS idx_history_records_status_created_at ON history_records (status, created_at)",
]

for stmt in statements:
db.session.execute(text(stmt))
db.session.commit()


def _validate_config_on_startup(logger):
"""启动时验证配置"""
from pathlib import Path
Expand All @@ -104,19 +138,19 @@ def _validate_config_on_startup(logger):
logger.info("📋 检查配置文件...")

# 检查 text_providers.yaml
text_config_path = Path(__file__).parent.parent / 'text_providers.yaml'
text_config_path = Path(__file__).parent.parent / "text_providers.yaml"
if text_config_path.exists():
try:
with open(text_config_path, 'r', encoding='utf-8') as f:
with open(text_config_path, "r", encoding="utf-8") as f:
text_config = yaml.safe_load(f) or {}
active = text_config.get('active_provider', '未设置')
providers = list(text_config.get('providers', {}).keys())
active = text_config.get("active_provider", "未设置")
providers = list(text_config.get("providers", {}).keys())
logger.info(f"✅ 文本生成配置: 激活={active}, 可用服务商={providers}")

# 检查激活的服务商是否有 API Key
if active in text_config.get('providers', {}):
provider = text_config['providers'][active]
if not provider.get('api_key'):
if active in text_config.get("providers", {}):
provider = text_config["providers"][active]
if not provider.get("api_key"):
logger.warning(f"⚠️ 文本服务商 [{active}] 未配置 API Key")
else:
logger.info(f"✅ 文本服务商 [{active}] API Key 已配置")
Expand All @@ -126,19 +160,19 @@ def _validate_config_on_startup(logger):
logger.warning("⚠️ text_providers.yaml 不存在,将使用默认配置")

# 检查 image_providers.yaml
image_config_path = Path(__file__).parent.parent / 'image_providers.yaml'
image_config_path = Path(__file__).parent.parent / "image_providers.yaml"
if image_config_path.exists():
try:
with open(image_config_path, 'r', encoding='utf-8') as f:
with open(image_config_path, "r", encoding="utf-8") as f:
image_config = yaml.safe_load(f) or {}
active = image_config.get('active_provider', '未设置')
providers = list(image_config.get('providers', {}).keys())
active = image_config.get("active_provider", "未设置")
providers = list(image_config.get("providers", {}).keys())
logger.info(f"✅ 图片生成配置: 激活={active}, 可用服务商={providers}")

# 检查激活的服务商是否有 API Key
if active in image_config.get('providers', {}):
provider = image_config['providers'][active]
if not provider.get('api_key'):
if active in image_config.get("providers", {}):
provider = image_config["providers"][active]
if not provider.get("api_key"):
logger.warning(f"⚠️ 图片服务商 [{active}] 未配置 API Key")
else:
logger.info(f"✅ 图片服务商 [{active}] API Key 已配置")
Expand All @@ -150,10 +184,6 @@ def _validate_config_on_startup(logger):
logger.info("✅ 配置检查完成")


if __name__ == '__main__':
if __name__ == "__main__":
app = create_app()
app.run(
host=Config.HOST,
port=Config.PORT,
debug=Config.DEBUG
)
app.run(host=Config.HOST, port=Config.PORT, debug=Config.DEBUG)
88 changes: 88 additions & 0 deletions backend/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from datetime import datetime
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import Index
import json

db = SQLAlchemy()


class HistoryRecord(db.Model):
"""历史记录模型"""

__tablename__ = "history_records"
__table_args__ = (
Index("idx_history_records_status", "status"),
Index("idx_history_records_created_at", "created_at"),
Index("idx_history_records_task_id", "task_id"),
Index("idx_history_records_status_created_at", "status", "created_at"),
)

id = db.Column(db.String(36), primary_key=True) # 使用 UUID
title = db.Column(db.String(255), nullable=False)
status = db.Column(db.String(50), nullable=False, default="draft")
created_at = db.Column(db.DateTime, default=datetime.utcnow)
updated_at = db.Column(
db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
)

# 存储 JSON 数据为 TEXT
_outline_json = db.Column("outline", db.Text, nullable=True)
_images_json = db.Column("images", db.Text, nullable=True)

thumbnail = db.Column(db.String(255), nullable=True)
page_count = db.Column(db.Integer, default=0)
task_id = db.Column(db.String(100), nullable=True)

@property
def outline(self):
if self._outline_json:
try:
return json.loads(self._outline_json)
except Exception:
return {}
return {}

@outline.setter
def outline(self, value):
self._outline_json = json.dumps(value, ensure_ascii=False)

@property
def images(self):
if self._images_json:
try:
return json.loads(self._images_json)
except Exception:
return {"task_id": self.task_id, "generated": []}
return {"task_id": self.task_id, "generated": []}

@images.setter
def images(self, value):
self._images_json = json.dumps(value, ensure_ascii=False)

def to_dict(self):
"""转为字典供 API 返回"""
return {
"id": self.id,
"title": self.title,
"status": self.status,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
"outline": self.outline,
"images": self.images,
"thumbnail": self.thumbnail,
"page_count": self.page_count,
"task_id": self.task_id,
}

def to_index_dict(self):
"""转为简要字典供列表 API 返回"""
return {
"id": self.id,
"title": self.title,
"status": self.status,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
"thumbnail": self.thumbnail,
"page_count": self.page_count,
"task_id": self.task_id,
}
Loading