mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-07 13:52:42 +08:00
285 lines
8.3 KiB
Python
285 lines
8.3 KiB
Python
import re
|
|
from typing import Annotated, Optional
|
|
|
|
from fastapi import APIRouter, Body, Depends, Request
|
|
from fastapi.responses import HTMLResponse
|
|
from pydantic import BaseModel
|
|
|
|
from app import schemas
|
|
from app.agent.llm import (
|
|
LLMHelper,
|
|
LLMProviderManager,
|
|
LLMTestTimeout,
|
|
render_auth_result_html,
|
|
)
|
|
from app.core.config import settings
|
|
from app.db.models import User
|
|
from app.db.user_oper import (
|
|
get_current_active_superuser_async,
|
|
get_current_active_user_async,
|
|
)
|
|
from app.log import logger
|
|
|
|
router = APIRouter()
|
|
|
|
|
|
class LlmTestRequest(BaseModel):
|
|
enabled: Optional[bool] = None
|
|
provider: Optional[str] = None
|
|
model: Optional[str] = None
|
|
thinking_level: Optional[str] = None
|
|
api_key: Optional[str] = None
|
|
base_url: Optional[str] = None
|
|
|
|
|
|
class LlmProviderAuthStartRequest(BaseModel):
|
|
provider: str
|
|
method: str
|
|
|
|
|
|
def _sanitize_llm_test_error(message: str, api_key: Optional[str] = None) -> str:
|
|
"""
|
|
清理错误信息中的敏感字段,避免回显密钥。
|
|
"""
|
|
if not message:
|
|
return "LLM 调用失败"
|
|
|
|
sanitized = message
|
|
if api_key:
|
|
sanitized = sanitized.replace(api_key, "***")
|
|
sanitized = re.sub(
|
|
r"(?i)(api[_-]?key\s*[:=]\s*)([^\s,;]+)",
|
|
r"\1***",
|
|
sanitized,
|
|
)
|
|
sanitized = re.sub(
|
|
r"(?i)authorization\s*:\s*bearer\s+[^\s,;]+",
|
|
"Authorization: ***",
|
|
sanitized,
|
|
)
|
|
return sanitized
|
|
|
|
|
|
@router.get("/models", summary="获取LLM模型列表", response_model=schemas.Response)
|
|
async def get_llm_models(
|
|
provider: str,
|
|
api_key: Optional[str] = None,
|
|
base_url: Optional[str] = None,
|
|
force_refresh: Optional[bool] = False,
|
|
_: User = Depends(get_current_active_user_async),
|
|
):
|
|
"""
|
|
获取指定 provider 的模型目录。
|
|
"""
|
|
try:
|
|
provider_manager = LLMProviderManager()
|
|
models = await LLMHelper().get_models(
|
|
provider=provider,
|
|
api_key=api_key,
|
|
base_url=base_url,
|
|
force_refresh=bool(force_refresh),
|
|
)
|
|
return schemas.Response(
|
|
success=True,
|
|
data={
|
|
"provider": provider,
|
|
"models": models,
|
|
"auth_status": provider_manager.get_auth_status(provider),
|
|
},
|
|
)
|
|
except Exception as err:
|
|
return schemas.Response(success=False, message=str(err))
|
|
|
|
|
|
@router.get("/providers", summary="获取LLM提供商目录", response_model=schemas.Response)
|
|
async def get_llm_providers(
|
|
_: User = Depends(get_current_active_user_async),
|
|
):
|
|
"""
|
|
返回前端可直接渲染的 provider 目录。
|
|
"""
|
|
try:
|
|
providers = LLMProviderManager().list_providers()
|
|
return schemas.Response(success=True, data=providers)
|
|
except Exception as err:
|
|
return schemas.Response(success=False, message=str(err))
|
|
|
|
|
|
@router.post(
|
|
"/provider-auth/start",
|
|
summary="启动LLM提供商授权",
|
|
response_model=schemas.Response,
|
|
)
|
|
async def start_llm_provider_auth(
|
|
payload: LlmProviderAuthStartRequest,
|
|
request: Request,
|
|
_: User = Depends(get_current_active_superuser_async),
|
|
):
|
|
"""
|
|
启动 provider 授权会话。
|
|
"""
|
|
try:
|
|
callback_url = None
|
|
if payload.provider == "chatgpt" and payload.method == "browser_oauth":
|
|
callback_url = str(
|
|
request.url_for("llm_provider_auth_callback", provider_id=payload.provider)
|
|
)
|
|
result = await LLMProviderManager().start_auth(
|
|
payload.provider,
|
|
payload.method,
|
|
callback_url,
|
|
)
|
|
return schemas.Response(success=True, data=result)
|
|
except Exception as err:
|
|
return schemas.Response(success=False, message=str(err))
|
|
|
|
|
|
@router.get(
|
|
"/provider-auth/{session_id}",
|
|
summary="获取LLM提供商授权会话状态",
|
|
response_model=schemas.Response,
|
|
)
|
|
async def get_llm_provider_auth_session(
|
|
session_id: str,
|
|
_: User = Depends(get_current_active_superuser_async),
|
|
):
|
|
"""
|
|
查询授权会话状态。
|
|
"""
|
|
try:
|
|
result = LLMProviderManager().get_session_status(session_id)
|
|
return schemas.Response(success=True, data=result)
|
|
except Exception as err:
|
|
return schemas.Response(success=False, message=str(err))
|
|
|
|
|
|
@router.post(
|
|
"/provider-auth/{session_id}/poll",
|
|
summary="轮询LLM提供商授权会话",
|
|
response_model=schemas.Response,
|
|
)
|
|
async def poll_llm_provider_auth_session(
|
|
session_id: str,
|
|
_: User = Depends(get_current_active_superuser_async),
|
|
):
|
|
"""
|
|
轮询 device code / OAuth 会话状态。
|
|
"""
|
|
try:
|
|
result = await LLMProviderManager().poll_auth_session(session_id)
|
|
return schemas.Response(success=True, data=result)
|
|
except Exception as err:
|
|
return schemas.Response(success=False, message=str(err))
|
|
|
|
|
|
@router.delete(
|
|
"/provider-auth/{provider_id}",
|
|
summary="断开LLM提供商授权",
|
|
response_model=schemas.Response,
|
|
)
|
|
async def delete_llm_provider_auth(
|
|
provider_id: str,
|
|
_: User = Depends(get_current_active_superuser_async),
|
|
):
|
|
"""
|
|
删除已保存的 provider 授权信息。
|
|
"""
|
|
try:
|
|
await LLMProviderManager().clear_auth(provider_id)
|
|
return schemas.Response(success=True)
|
|
except Exception as err:
|
|
return schemas.Response(success=False, message=str(err))
|
|
|
|
|
|
@router.get(
|
|
"/provider-auth/callback/{provider_id}",
|
|
summary="LLM提供商OAuth回调",
|
|
response_class=HTMLResponse,
|
|
name="llm_provider_auth_callback",
|
|
)
|
|
async def llm_provider_auth_callback(
|
|
provider_id: str,
|
|
code: Optional[str] = None,
|
|
state: Optional[str] = None,
|
|
error: Optional[str] = None,
|
|
error_description: Optional[str] = None,
|
|
):
|
|
"""
|
|
处理需要浏览器回跳的 OAuth provider。
|
|
"""
|
|
success, message = await LLMProviderManager().handle_chatgpt_callback(
|
|
provider_id,
|
|
code,
|
|
state,
|
|
error,
|
|
error_description,
|
|
)
|
|
return HTMLResponse(content=render_auth_result_html(success, message))
|
|
|
|
|
|
@router.post("/test", summary="测试LLM调用", response_model=schemas.Response)
|
|
async def llm_test(
|
|
payload: Annotated[Optional[LlmTestRequest], Body()] = None,
|
|
_: User = Depends(get_current_active_superuser_async),
|
|
):
|
|
"""
|
|
使用传入配置或当前已保存配置执行一次最小 LLM 调用。
|
|
"""
|
|
payload = payload or LlmTestRequest(
|
|
enabled=settings.AI_AGENT_ENABLE,
|
|
provider=settings.LLM_PROVIDER,
|
|
model=settings.LLM_MODEL,
|
|
thinking_level=settings.LLM_THINKING_LEVEL,
|
|
api_key=settings.LLM_API_KEY,
|
|
base_url=settings.LLM_BASE_URL,
|
|
)
|
|
|
|
if not payload.provider:
|
|
return schemas.Response(success=False, message="请配置LLM提供商和模型")
|
|
if not payload.model or not payload.model.strip():
|
|
return schemas.Response(success=False, message="请先配置 LLM 模型")
|
|
|
|
data = {
|
|
"provider": payload.provider,
|
|
"model": payload.model,
|
|
}
|
|
if not payload.enabled:
|
|
return schemas.Response(success=False, message="请先启用智能助手", data=data)
|
|
|
|
if (
|
|
payload.provider not in {"chatgpt", "github-copilot"}
|
|
and (not payload.api_key or not payload.api_key.strip())
|
|
):
|
|
return schemas.Response(
|
|
success=False,
|
|
message="请先配置 LLM API Key",
|
|
data=data,
|
|
)
|
|
|
|
try:
|
|
result = await LLMHelper.test_current_settings(
|
|
provider=payload.provider,
|
|
model=payload.model,
|
|
thinking_level=payload.thinking_level,
|
|
api_key=payload.api_key,
|
|
base_url=payload.base_url,
|
|
)
|
|
if not result.get("reply_preview"):
|
|
return schemas.Response(
|
|
success=False,
|
|
message="模型响应为空",
|
|
data=result,
|
|
)
|
|
return schemas.Response(success=True, data=result)
|
|
except (LLMTestTimeout, TimeoutError) as err:
|
|
logger.warning(err)
|
|
return schemas.Response(
|
|
success=False,
|
|
message="LLM 调用超时",
|
|
)
|
|
except Exception as err:
|
|
return schemas.Response(
|
|
success=False,
|
|
message=_sanitize_llm_test_error(str(err), payload.api_key),
|
|
)
|