# Mirror of https://github.com/jxxghp/MoviePilot.git
# synced 2026-05-07 13:52:42 +08:00
import asyncio
import json
import secrets
import time
import uuid
from typing import AsyncIterator, List, Optional, Tuple

from fastapi import APIRouter, Request, Security
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials

from app import schemas
from app.agent import MoviePilotAgent, StreamingHandler
from app.api.openai_utils import (
    build_completion_payload,
    build_prompt,
    build_responses_input,
    build_session_id,
)
from app.core.config import settings
from app.core.security import openai_bearer_scheme
from app.schemas.types import MessageChannel
router = APIRouter()
|
|
|
|
MODEL_ID = "moviepilot-agent"
|
|
SESSION_PREFIX = "openai:"
|
|
|
|
|
|
class _CollectingMoviePilotAgent(MoviePilotAgent):
|
|
"""
|
|
捕获 Agent 最终输出,避免再通过消息渠道二次发送。
|
|
"""
|
|
|
|
def __init__(self, *args, stream_mode: bool = False, **kwargs):
|
|
super().__init__(*args, **kwargs)
|
|
self.collected_messages: List[str] = []
|
|
self.stream_mode = stream_mode
|
|
if stream_mode:
|
|
self.stream_handler = _OpenAIStreamingHandler()
|
|
|
|
def _should_stream(self) -> bool:
|
|
return self.stream_mode
|
|
|
|
async def send_agent_message(self, message: str, title: str = ""):
|
|
text = (message or "").strip()
|
|
if title and text:
|
|
text = f"{title}\n{text}"
|
|
elif title:
|
|
text = title.strip()
|
|
if text:
|
|
self.collected_messages.append(text)
|
|
if self.stream_mode:
|
|
self.stream_handler.emit(text)
|
|
|
|
async def _save_agent_message_to_db(self, message: str, title: str = ""):
|
|
return None
|
|
|
|
|
|
class _OpenAIStreamingHandler(StreamingHandler):
|
|
"""
|
|
将 Agent 流式输出转发到 OpenAI SSE 队列,不向站内消息系统落消息。
|
|
"""
|
|
|
|
def __init__(self):
|
|
super().__init__()
|
|
self._event_queue: Optional[asyncio.Queue] = None
|
|
|
|
def bind_queue(self, queue: asyncio.Queue):
|
|
self._event_queue = queue
|
|
|
|
def emit(self, token: str):
|
|
emitted = super().emit(token)
|
|
if emitted and self._event_queue is not None:
|
|
self._event_queue.put_nowait(emitted)
|
|
|
|
def flush_pending_tool_summary(self) -> str:
|
|
emitted = super().flush_pending_tool_summary()
|
|
if emitted and self._event_queue is not None:
|
|
self._event_queue.put_nowait(emitted)
|
|
return emitted
|
|
|
|
async def start_streaming(
|
|
self,
|
|
channel: Optional[str] = None,
|
|
source: Optional[str] = None,
|
|
user_id: Optional[str] = None,
|
|
username: Optional[str] = None,
|
|
title: str = "",
|
|
):
|
|
self._channel = channel
|
|
self._source = source
|
|
self._user_id = user_id
|
|
self._username = username
|
|
self._title = title
|
|
self._streaming_enabled = True
|
|
self._sent_text = ""
|
|
self._message_response = None
|
|
self._msg_start_offset = 0
|
|
self._max_message_length = 0
|
|
|
|
async def stop_streaming(self) -> Tuple[bool, str]:
|
|
if not self._streaming_enabled:
|
|
return False, ""
|
|
self._streaming_enabled = False
|
|
with self._lock:
|
|
final_text = self._buffer
|
|
self._buffer = ""
|
|
self._sent_text = ""
|
|
self._message_response = None
|
|
self._msg_start_offset = 0
|
|
return True, final_text
|
|
|
|
|
|
def _sse_payload(data: dict) -> str:
|
|
return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
|
|
|
|
|
|
async def _stream_response(
    agent: _CollectingMoviePilotAgent,
    prompt: str,
    images: List[str],
) -> AsyncIterator[str]:
    """
    Yield OpenAI ``chat.completion.chunk`` SSE frames while the agent runs.

    The agent is executed in a background task; its streamed text arrives
    through an asyncio queue bound to the agent's stream handler. ``None``
    on the queue is the end-of-stream sentinel.
    """
    event_queue: asyncio.Queue = asyncio.Queue()
    if isinstance(agent.stream_handler, _OpenAIStreamingHandler):
        agent.stream_handler.bind_queue(event_queue)

    created = int(time.time())
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    # Tracks whether the consumer loop completed normally (affects cleanup).
    finished = False

    async def _run_agent():
        # Run the agent; forward errors through the queue and always enqueue
        # the None sentinel so the consumer loop terminates.
        try:
            await agent.process(prompt, images=images, files=None)
        except Exception as exc:
            await event_queue.put({"error": str(exc)})
        finally:
            await event_queue.put(None)

    task = asyncio.create_task(_run_agent())

    try:
        # First chunk announces the assistant role, as the OpenAI spec expects.
        yield _sse_payload(
            {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": MODEL_ID,
                "choices": [
                    {
                        "index": 0,
                        "delta": {"role": "assistant"},
                        "finish_reason": None,
                    }
                ],
            }
        )

        while True:
            item = await event_queue.get()
            if item is None:
                # Sentinel: the agent task has finished (successfully or not).
                break
            if isinstance(item, dict) and item.get("error"):
                # Surface the agent failure to the caller of the generator.
                raise RuntimeError(str(item["error"]))
            text = str(item or "")
            if not text:
                continue
            yield _sse_payload(
                {
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": MODEL_ID,
                    "choices": [
                        {
                            "index": 0,
                            "delta": {"content": text},
                            "finish_reason": None,
                        }
                    ],
                }
            )

        finished = True
        # Terminal chunk with finish_reason "stop", then the [DONE] marker.
        yield _sse_payload(
            {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": MODEL_ID,
                "choices": [
                    {
                        "index": 0,
                        "delta": {},
                        "finish_reason": "stop",
                    }
                ],
            }
        )
        yield "data: [DONE]\n\n"
    finally:
        if not task.done():
            # Consumer aborted early (client disconnect / error): cancel the
            # agent task and swallow the resulting CancelledError.
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        elif finished:
            # Normal completion: await the already-finished task so any
            # unexpected stored exception would still surface.
            await task
def _error_response(
    message: str,
    status_code: int,
    error_type: str = "invalid_request_error",
    code: Optional[str] = None,
) -> JSONResponse:
    """
    Build an OpenAI-style error JSON response.

    :param message: Human-readable error message.
    :param status_code: HTTP status code for the response.
    :param error_type: OpenAI error ``type`` field.
    :param code: Optional machine-readable error code.
    :return: ``JSONResponse`` wrapping an ``OpenAIErrorResponse`` payload.
    """
    # Fix: the WWW-Authenticate challenge is only meaningful on 401 responses
    # (RFC 7235); previously it was attached to every error status (400/503/…).
    headers = {"WWW-Authenticate": "Bearer"} if status_code == 401 else None
    return JSONResponse(
        status_code=status_code,
        content=schemas.OpenAIErrorResponse(
            error=schemas.OpenAIErrorDetail(
                message=message,
                type=error_type,
                code=code,
            )
        ).model_dump(),
        headers=headers,
    )
def _check_auth(
    credentials: Optional[HTTPAuthorizationCredentials],
) -> Optional[JSONResponse]:
    """
    Validate the Bearer token against ``settings.API_TOKEN``.

    :param credentials: Parsed Authorization header, or None when absent.
    :return: A 401 error response when authentication fails, otherwise None.
    """
    if (
        not credentials
        or credentials.scheme.lower() != "bearer"
        # Constant-time comparison to avoid a timing side channel on the
        # untrusted token (was a plain ``!=`` string comparison).
        or not secrets.compare_digest(
            str(credentials.credentials), str(settings.API_TOKEN)
        )
    ):
        return _error_response(
            "Invalid bearer token.",
            401,
            error_type="authentication_error",
            code="invalid_api_key",
        )
    return None
@router.get("/models", summary="OpenAI compatible models", response_model=schemas.OpenAIModelListResponse)
async def list_models(
    credentials: Optional[HTTPAuthorizationCredentials] = Security(openai_bearer_scheme),
):
    """Return the single pseudo-model this service exposes."""
    auth_error = _check_auth(credentials)
    if auth_error is not None:
        return auth_error
    timestamp = int(time.time())
    model = schemas.OpenAIModelInfo(id=MODEL_ID, created=timestamp)
    return schemas.OpenAIModelListResponse(data=[model])
@router.post(
    "/chat/completions",
    summary="OpenAI compatible chat completions",
    response_model=schemas.OpenAIChatCompletionResponse,
)
async def chat_completions(
    payload: schemas.OpenAIChatCompletionsRequest,
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Security(openai_bearer_scheme),
):
    """
    OpenAI-compatible ``/chat/completions`` endpoint backed by the MoviePilot agent.

    Supports blocking JSON responses as well as SSE streaming when
    ``payload.stream`` is true.
    """
    auth_error = _check_auth(credentials)
    if auth_error:
        return auth_error

    if not settings.AI_AGENT_ENABLE:
        return _error_response(
            "MoviePilot AI agent is disabled.",
            503,
            error_type="server_error",
            code="ai_agent_disabled",
        )

    if not payload.messages:
        return _error_response(
            "`messages` must be a non-empty array.",
            400,
            code="invalid_messages",
        )

    # Fix: the caller-supplied identity was computed twice with identical
    # expressions (once for session_key, once for use_server_session);
    # compute it once. A `user` field or `x-session-id` header opts into
    # server-side session continuity; otherwise each request gets a
    # throwaway session id.
    explicit_key = (
        str(payload.user or "").strip()
        or str(request.headers.get("x-session-id") or "").strip()
    )
    use_server_session = bool(explicit_key)
    session_key = explicit_key or str(uuid.uuid4())

    try:
        prompt, images = build_prompt(payload.messages, use_server_session=use_server_session)
    except ValueError as exc:
        return _error_response(str(exc), 400, code="invalid_messages")

    session_id = build_session_id(session_key, SESSION_PREFIX)
    username = str(payload.user or "openai-client")
    agent = _CollectingMoviePilotAgent(
        session_id=session_id,
        user_id=session_key,
        channel=MessageChannel.Web.value,
        source="openai",
        username=username,
        stream_mode=payload.stream,
    )

    if payload.stream:
        return StreamingResponse(
            _stream_response(agent=agent, prompt=prompt, images=images),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                # Disable proxy buffering so chunks reach the client promptly.
                "X-Accel-Buffering": "no",
            },
        )

    try:
        result = await agent.process(prompt, images=images, files=None)
    except Exception as exc:
        return _error_response(
            str(exc),
            500,
            error_type="server_error",
            code="agent_execution_failed",
        )

    # Join everything the agent tried to send; fall back to the raw return
    # value, then to a fixed notice, so the client never gets empty content.
    content = "\n\n".join(
        message.strip()
        for message in agent.collected_messages
        if message and message.strip()
    ).strip()
    if not content and result:
        content = str(result).strip()
    if not content:
        content = "未获得有效回复。"

    return JSONResponse(content=build_completion_payload(content, MODEL_ID))
@router.post("/responses", summary="OpenAI compatible responses", response_model=schemas.OpenAIResponsesResponse)
async def responses(
    payload: schemas.OpenAIResponsesRequest,
    credentials: Optional[HTTPAuthorizationCredentials] = Security(openai_bearer_scheme),
):
    """
    OpenAI-compatible ``/responses`` endpoint (non-streaming only).

    Normalizes the Responses-style ``input`` into chat messages, runs the
    agent once, and returns the collected output as a single message.
    """
    auth_error = _check_auth(credentials)
    if auth_error is not None:
        return auth_error

    if not settings.AI_AGENT_ENABLE:
        return _error_response(
            "MoviePilot AI agent is disabled.",
            503,
            error_type="server_error",
            code="ai_agent_disabled",
        )

    if payload.stream:
        return _error_response(
            "Streaming is not supported for /responses yet.",
            400,
            code="unsupported_stream",
        )

    normalized_messages = build_responses_input(payload.input, instructions=payload.instructions)
    if not normalized_messages:
        return _error_response(
            "`input` must include at least one usable message.",
            400,
            code="invalid_input",
        )

    try:
        prompt, images = build_prompt(normalized_messages, use_server_session=bool(payload.user))
    except ValueError as exc:
        return _error_response(str(exc), 400, code="invalid_input")

    session_key = str(payload.user or uuid.uuid4())
    agent = _CollectingMoviePilotAgent(
        session_id=build_session_id(session_key, SESSION_PREFIX),
        user_id=session_key,
        channel=MessageChannel.Web.value,
        source="openai.responses",
        username=str(payload.user or "openai-client"),
        stream_mode=False,
    )

    try:
        result = await agent.process(prompt, images=images, files=None)
    except Exception as exc:
        return _error_response(
            str(exc),
            500,
            error_type="server_error",
            code="agent_execution_failed",
        )

    # Prefer collected messages, then the raw result, then a fixed notice.
    parts = [m.strip() for m in agent.collected_messages if m and m.strip()]
    content = "\n\n".join(parts).strip()
    if not content and result:
        content = str(result).strip()
    if not content:
        content = "未获得有效回复。"

    return schemas.OpenAIResponsesResponse(
        id=f"resp_{uuid.uuid4().hex}",
        created_at=int(time.time()),
        model=MODEL_ID,
        output=[
            schemas.OpenAIResponsesOutputMessage(
                id=f"msg_{uuid.uuid4().hex}",
                content=[schemas.OpenAIResponsesOutputText(text=content)],
            )
        ],
        usage=schemas.OpenAIUsage(),
    )