diff --git a/app/agent/__init__.py b/app/agent/__init__.py index a35423be..528f61f6 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -33,7 +33,7 @@ from app.agent.runtime import agent_runtime_manager from app.agent.tools.factory import MoviePilotToolFactory from app.chain import ChainBase from app.core.config import settings -from app.helper.llm import LLMHelper +from app.agent.llm import LLMHelper from app.log import logger from app.schemas import Notification, NotificationType from app.schemas.message import ChannelCapabilityManager, ChannelCapability @@ -310,12 +310,12 @@ class MoviePilotAgent: return False @staticmethod - def _initialize_llm(streaming: bool = False): + async def _initialize_llm(streaming: bool = False): """ 初始化 LLM :param streaming: 是否启用流式输出 """ - return LLMHelper.get_llm(streaming=streaming) + return await LLMHelper.get_llm(streaming=streaming) @staticmethod def _extract_text_content(content) -> str: @@ -387,7 +387,7 @@ class MoviePilotAgent: allow_message_tools=self.allow_message_tools, ) - def _create_agent(self, streaming: bool = False): + async def _create_agent(self, streaming: bool = False): """ 创建 LangGraph Agent(使用 create_agent + SummarizationMiddleware) :param streaming: 是否启用流式输出 @@ -397,12 +397,12 @@ class MoviePilotAgent: system_prompt = prompt_manager.get_agent_prompt(channel=self.channel) # LLM 模型(用于 agent 执行) - llm = self._initialize_llm(streaming=streaming) + llm = await self._initialize_llm(streaming=streaming) self._sync_model_profile(llm) # 为中间件内部模型调用准备非流式 LLM,避免与用户流式回复复用同一实例。 non_streaming_llm = ( - llm if not streaming else self._initialize_llm(streaming=False) + llm if not streaming else await self._initialize_llm(streaming=False) ) # 工具列表 @@ -577,7 +577,7 @@ class MoviePilotAgent: use_streaming = self._should_stream() # 创建智能体(根据是否流式传入不同 LLM) - agent = self._create_agent(streaming=use_streaming) + agent = await self._create_agent(streaming=use_streaming) if use_streaming: self.stream_handler.set_dispatch_policy( diff --git a/app/agent/llm/__init__.py b/app/agent/llm/__init__.py new file mode 100644 index 00000000..bad76183 --- /dev/null +++ b/app/agent/llm/__init__.py @@ -0,0 +1,19 @@ +"""Agent 内部使用的 LLM 适配层。""" + +from app.agent.llm.helper import LLMHelper, LLMTestError, LLMTestTimeout +from app.agent.llm.provider import ( + LLMProviderAuthError, + LLMProviderError, + LLMProviderManager, + render_auth_result_html, +) + +__all__ = [ + "LLMHelper", + "LLMProviderAuthError", + "LLMProviderError", + "LLMProviderManager", + "LLMTestError", + "LLMTestTimeout", + "render_auth_result_html", +] diff --git a/app/helper/llm.py b/app/agent/llm/helper.py similarity index 72% rename from app/helper/llm.py rename to app/agent/llm/helper.py index d84f18b1..ad7b068d 100644 --- a/app/helper/llm.py +++ b/app/agent/llm/helper.py @@ -342,7 +342,7 @@ class LLMHelper: return {} # OpenAI 原生推理模型优先走 LangChain 内置 reasoning_effort。 - if provider_name == "openai" and model_name.startswith( + if provider_name in {"openai", "chatgpt"} and model_name.startswith( ("gpt-5", "o1", "o3", "o4") ): openai_effort = cls._normalize_openai_reasoning_effort( @@ -366,11 +366,89 @@ class LLMHelper: return bool(settings.LLM_SUPPORT_IMAGE_INPUT) @staticmethod - def get_llm( + def _build_legacy_runtime( + provider_name: str, + model_name: str | None, + api_key: str | None = None, + base_url: str | None = None, + ) -> dict[str, Any]: + """ + 在 provider 目录不可用时回退到旧的直接构造逻辑。 + + 这主要用于单测 stub 环境以及极端的最小运行环境,正常生产路径仍优先 + 走 `LLMProviderManager.resolve_runtime()`。 + """ + api_key_value = api_key if api_key is not None else settings.LLM_API_KEY + base_url_value = base_url if base_url is not None else settings.LLM_BASE_URL + if not api_key_value: + raise ValueError("未配置LLM API Key") + + runtime_name = provider_name if provider_name in {"google", "deepseek"} else "openai_compatible" + return { + "provider_id": provider_name, + "runtime": runtime_name, + "model_id": model_name, + "api_key": api_key_value, + "base_url": base_url_value, + "default_headers": None, + "use_responses_api": None, + "model_record": None, + "model_metadata": None, + } + + @classmethod + def _resolve_thinking_level( + cls, + thinking_level: str | None = None, + disable_thinking: bool | None = None, + reasoning_effort: str | None = None, + ) -> str | None: + """ + 统一兼容新旧 thinking 参数。 + + 新前端只会传 `thinking_level`,但测试和部分旧调用仍可能带 + `disable_thinking` / `reasoning_effort`,这里集中做一次归一化。 + """ + + def _normalize(value: str | None) -> str | None: + normalized = str(value or "").strip().lower() + if not normalized: + return None + alias_map = { + "none": "off", + "disabled": "off", + "disable": "off", + "enabled": "auto", + "enable": "auto", + "default": "auto", + "dynamic": "auto", + } + normalized = alias_map.get(normalized, normalized) + if normalized in cls._SUPPORTED_THINKING_LEVELS: + return normalized + logger.warning(f"忽略不支持的思考级别: {value}") + return None + + normalized_thinking_level = _normalize(thinking_level) + if normalized_thinking_level: + return normalized_thinking_level + + legacy_effort = _normalize(reasoning_effort) + if disable_thinking: + return "off" + if disable_thinking is False: + return legacy_effort or "auto" + return legacy_effort + + @classmethod + async def get_llm( + cls, streaming: bool = False, provider: str | None = None, model: str | None = None, thinking_level: str | None = None, + disable_thinking: bool | None = None, + reasoning_effort: str | None = None, api_key: str | None = None, base_url: str | None = None, ): @@ -383,28 +461,50 @@ class LLMHelper: 是否启用思考模式)。支持的级别包括 "off"(关闭)、"auto"(自动)、"minimal"、"low"、"medium"、"high"、"max"/"xhigh"(最大)。 不同模型对思考模式的支持和表现不同,具体映射关系请 参考代码实现。对于不支持思考模式的模型,该参数将被忽略。 + :param disable_thinking: 兼容旧参数,若传入则会被转换为新的 + `thinking_level` 语义。 + :param reasoning_effort: 兼容旧参数,若传入则会被转换为新的 + `thinking_level` 语义。 :param api_key: API Key,默认为 配置项LLM_API_KEY。对于某些提供商( 如 DeepSeek),可能需要同时提供 base_url。 :param base_url: API Base URL,默认为配置项LLM_BASE_URL。 :return: LLM实例 """ - provider_name = str( - provider if provider is not None else settings.LLM_PROVIDER - ).lower() + provider_name = str(provider if provider is not None else settings.LLM_PROVIDER).lower() model_name = model if model is not None else settings.LLM_MODEL - api_key_value = api_key if api_key is not None else settings.LLM_API_KEY - base_url_value = base_url if base_url is not None else settings.LLM_BASE_URL - thinking_kwargs = LLMHelper._build_thinking_kwargs( + normalized_thinking_level = cls._resolve_thinking_level( + thinking_level=thinking_level, + disable_thinking=disable_thinking, + reasoning_effort=reasoning_effort, + ) + try: + # 延迟导入,避免单测在最小 stub 环境下 import `llm.py` 时被 provider + # 目录依赖链拖住。 + from app.agent.llm.provider import LLMProviderManager + + runtime = await LLMProviderManager().resolve_runtime( + provider_id=provider_name, + model=model_name, + api_key=api_key, + base_url=base_url, + ) + except Exception as err: + logger.debug(f"LLM provider 目录不可用,回退到旧运行时逻辑: {err}") + runtime = cls._build_legacy_runtime( + provider_name=provider_name, + model_name=model_name, + api_key=api_key, + base_url=base_url, + ) + model_name = runtime.get("model_id") or model_name + thinking_kwargs = cls._build_thinking_kwargs( provider=provider_name, model=model_name, - thinking_level=thinking_level + thinking_level=normalized_thinking_level, ) - if not api_key_value: - raise ValueError("未配置LLM API Key") - - if provider_name == "google": + if runtime["runtime"] == "google": # 修补 Gemini 2.5 思考模型的 thought_signature 兼容性 _patch_gemini_thought_signature() @@ -420,49 +520,76 @@ class LLMHelper: model = ChatGoogleGenerativeAI( model=model_name, - api_key=api_key_value, + api_key=runtime["api_key"], retries=3, temperature=settings.LLM_TEMPERATURE, streaming=streaming, client_args=client_args, **thinking_kwargs, ) - elif provider_name == "deepseek": + elif runtime["runtime"] == "deepseek": from langchain_deepseek import ChatDeepSeek _patch_deepseek_reasoning_content_support() model = ChatDeepSeek( model=model_name, - api_key=api_key_value, - api_base=base_url_value, + api_key=runtime["api_key"], + api_base=runtime["base_url"], max_retries=3, temperature=settings.LLM_TEMPERATURE, streaming=streaming, stream_usage=True, **thinking_kwargs, ) + elif runtime["runtime"] in {"anthropic_compatible", "copilot_anthropic"}: + from langchain_anthropic import ChatAnthropic + + model = ChatAnthropic( + model=model_name, + api_key=runtime["api_key"], + base_url=runtime["base_url"], + max_retries=3, + temperature=settings.LLM_TEMPERATURE, + streaming=streaming, + stream_usage=True, + anthropic_proxy=settings.PROXY_HOST, + default_headers=runtime.get("default_headers"), + **thinking_kwargs, + ) else: from langchain_openai import ChatOpenAI model = ChatOpenAI( model=model_name, - api_key=api_key_value, + api_key=runtime["api_key"], max_retries=3, - base_url=base_url_value, + base_url=runtime.get("base_url"), temperature=settings.LLM_TEMPERATURE, streaming=streaming, stream_usage=True, openai_proxy=settings.PROXY_HOST, + default_headers=runtime.get("default_headers"), + use_responses_api=runtime.get("use_responses_api"), **thinking_kwargs, ) - # 检查是否有profile - if hasattr(model, "profile") and model.profile: + # 优先使用 provider / models.dev 目录中的上下文上限,减少用户手填成本。 + model_profile = getattr(model, "profile", None) + if model_profile: logger.debug(f"使用LLM模型: {model.model},Profile: {model.profile}") else: + model_record = runtime.get("model_record") or {} + model_metadata = runtime.get("model_metadata") or {} + metadata_limit = model_metadata.get("limit") or {} + max_input_tokens = ( + model_record.get("input_tokens") + or model_record.get("context_tokens") + or metadata_limit.get("input") + or metadata_limit.get("context") + or settings.LLM_MAX_CONTEXT_TOKENS * 1000 + ) model.profile = { - "max_input_tokens": settings.LLM_MAX_CONTEXT_TOKENS - * 1000, # 转换为token单位 + "max_input_tokens": int(max_input_tokens), } return model @@ -514,6 +641,8 @@ class LLMHelper: provider: str | None = None, model: str | None = None, thinking_level: str | None = None, + disable_thinking: bool | None = None, + reasoning_effort: str | None = None, api_key: str | None = None, base_url: str | None = None, ) -> dict: @@ -522,16 +651,16 @@ class LLMHelper: """ provider_name = provider if provider is not None else settings.LLM_PROVIDER model_name = model if model is not None else settings.LLM_MODEL - api_key_value = api_key if api_key is not None else settings.LLM_API_KEY - base_url_value = base_url if base_url is not None else settings.LLM_BASE_URL start = time.perf_counter() - llm = LLMHelper.get_llm( + llm = await LLMHelper.get_llm( streaming=False, provider=provider_name, model=model_name, thinking_level=thinking_level, - api_key=api_key_value, - base_url=base_url_value, + disable_thinking=disable_thinking, + reasoning_effort=reasoning_effort, + api_key=api_key, + base_url=base_url, ) try: response = await asyncio.wait_for(llm.ainvoke(prompt), timeout=timeout) @@ -556,18 +685,47 @@ class LLMHelper: data["reply_preview"] = reply_text[:120] return data - def get_models( - self, provider: str, api_key: str, base_url: str = None - ) -> List[str]: - """获取模型列表""" + async def get_models( + self, + provider: str, + api_key: str | None = None, + base_url: str | None = None, + force_refresh: bool = False, + ) -> List[dict[str, Any]]: + """ + 获取模型列表。 + + 返回值会带上 context/supports_reasoning 等元数据,供前端直接渲染并自动 + 回填上下文大小。 + """ logger.info(f"获取 {provider} 模型列表...") - if provider == "google": - return self._get_google_models(api_key) - else: - return self._get_openai_compatible_models(provider, api_key, base_url) + try: + from app.agent.llm.provider import LLMProviderManager + + return await LLMProviderManager().list_models( + provider_id=provider, + api_key=api_key, + base_url=base_url, + force_refresh=force_refresh, + ) + except Exception as err: + logger.debug(f"LLM provider 目录不可用,回退旧模型列表逻辑: {err}") + if provider == "google": + return [ + {"id": model_id, "name": model_id} + for model_id in await self._get_google_models(api_key or "") + ] + return [ + {"id": model_id, "name": model_id} + for model_id in await self._get_openai_compatible_models( + provider, + api_key or "", + base_url, + ) + ] @staticmethod - def _get_google_models(api_key: str) -> List[str]: + async def _get_google_models(api_key: str) -> List[str]: """获取Google模型列表(使用 google-genai SDK v1)""" try: from google import genai @@ -583,29 +741,32 @@ class LLMHelper: ) client = genai.Client(api_key=api_key, http_options=http_options) - models = client.models.list() - return [ + models = await client.aio.models.list() + result = [ m.name - for m in models + for m in models.page if m.supported_actions and "generateContent" in m.supported_actions ] + await client.aio.aclose() + return result except Exception as e: logger.error(f"获取Google模型列表失败:{e}") raise e @staticmethod - def _get_openai_compatible_models( + async def _get_openai_compatible_models( provider: str, api_key: str, base_url: str = None ) -> List[str]: """获取OpenAI兼容模型列表""" try: - from openai import OpenAI + from openai import AsyncOpenAI if provider == "deepseek": base_url = base_url or "https://api.deepseek.com" - client = OpenAI(api_key=api_key, base_url=base_url) - models = client.models.list() + client = AsyncOpenAI(api_key=api_key, base_url=base_url) + models = await client.models.list() + await client.close() return [model.id for model in models.data] except Exception as e: logger.error(f"获取 {provider} 模型列表失败:{e}") diff --git a/app/agent/llm/provider.py b/app/agent/llm/provider.py new file mode 100644 index 00000000..8fbaeb03 --- /dev/null +++ b/app/agent/llm/provider.py @@ -0,0 +1,1507 @@ +"""LLM provider registry, auth flows, and model metadata helpers.""" + +from __future__ import annotations + +import asyncio +import base64 +import copy +import hashlib +import json +import secrets +import threading +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, Optional, Tuple +from urllib.parse import urlencode + +import aiofiles +import httpx +import jwt + +from app.core.config import settings +from app.db.systemconfig_oper import SystemConfigOper +from app.log import logger +from app.schemas.types import SystemConfigKey +from app.utils.singleton import Singleton + + +class LLMProviderError(RuntimeError): + """通用 LLM provider 异常。""" + + +class LLMProviderAuthError(LLMProviderError): + """LLM provider 鉴权异常。""" + + +@dataclass(frozen=True) +class ProviderAuthMethod: + """前端展示用的授权方式定义。""" + + id: str + type: str + label: str + description: str = "" + + +@dataclass(frozen=True) +class ProviderSpec: + """描述一个可接入的 LLM provider。""" + + id: str + name: str + runtime: str + models_dev_provider_id: Optional[str] = None + default_base_url: Optional[str] = None + base_url_editable: bool = False + requires_base_url: bool = False + supports_api_key: bool = True + api_key_label: str = "API Key" + api_key_hint: str = "" + oauth_methods: Tuple[ProviderAuthMethod, ...] = () + supports_model_refresh: bool = True + model_list_strategy: str = "openai_compatible" + sort_order: int = 100 + description: str = "" + + +@dataclass +class PendingAuthSession: + """保存临时鉴权会话,避免把 PKCE/device code 等状态写回配置。""" + + session_id: str + provider_id: str + method_id: str + flow_type: str + status: str = "pending" + message: str = "" + authorize_url: Optional[str] = None + instructions: Optional[str] = None + verification_url: Optional[str] = None + user_code: Optional[str] = None + interval_seconds: int = 5 + expires_at: float = 0 + created_at: float = field(default_factory=time.time) + context: Dict[str, Any] = field(default_factory=dict) + + +class LLMProviderManager(metaclass=Singleton): + """统一维护 provider 目录、models.dev 缓存和 OAuth 状态。""" + + _MODELS_DEV_URL = "https://models.dev/api.json" + _MODELS_DEV_CACHE_TTL = 5 * 60 + _CHATGPT_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann" + _CHATGPT_ISSUER = "https://auth.openai.com" + _CHATGPT_CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex" + _COPILOT_CLIENT_ID = "Ov23li8tweQw6odWQebz" + _DEFAULT_TIMEOUT = httpx.Timeout(15.0, connect=10.0) + _CHATGPT_ALLOWED_OAUTH_MODELS = { + "gpt-5.1-codex", + "gpt-5.1-codex-max", + "gpt-5.1-codex-mini", + "gpt-5.2", + "gpt-5.2-codex", + "gpt-5.3-codex", + "gpt-5.4", + "gpt-5.4-mini", + "gpt-5.5", + } + + def __init__(self): + self._lock = threading.RLock() + self._models_dev_lock = asyncio.Lock() + self._pending_sessions: dict[str, PendingAuthSession] = {} + self._oauth_state_index: dict[str, str] = {} + self._models_dev_data: dict[str, Any] | None = None + self._models_dev_loaded_at: float = 0 + self._models_dev_cache_path = ( + Path(settings.TEMP_PATH) / "llm_provider_models_dev_cache.json" + ) + + @staticmethod + def _provider_specs() -> tuple[ProviderSpec, ...]: + """ + 返回受支持的 provider 定义。 + + OpenAI 保留为用户自定义 OpenAI-compatible 兜底入口,因此仍要求填写 + Base URL;ChatGPT 则单独承接官方 API Key / ChatGPT 订阅鉴权。 + """ + browser_auth = ProviderAuthMethod( + id="browser_oauth", + type="oauth", + label="浏览器授权", + description="使用 ChatGPT Plus/Pro 浏览器登录并回调授权。", + ) + device_auth = ProviderAuthMethod( + id="device_code", + type="device", + label="设备码授权", + description="适合无回调环境,复制设备码到浏览器完成登录。", + ) + return ( + ProviderSpec( + id="chatgpt", + name="ChatGPT", + runtime="chatgpt", + models_dev_provider_id="openai", + default_base_url="https://api.openai.com/v1", + api_key_hint="可直接填写 OpenAI API Key,或使用 ChatGPT Plus/Pro 登录授权。", + oauth_methods=(browser_auth, device_auth), + model_list_strategy="chatgpt", + description="支持 ChatGPT Plus/Pro 鉴权或 OpenAI 官方 API Key。", + sort_order=10, + ), + ProviderSpec( + id="google", + name="Google", + runtime="google", + models_dev_provider_id="google", + supports_api_key=True, + api_key_hint="填写 Gemini / Google AI Studio API Key。", + model_list_strategy="google", + description="Gemini / Google AI Studio。", + sort_order=20, + ), + ProviderSpec( + id="deepseek", + name="DeepSeek", + runtime="deepseek", + models_dev_provider_id="deepseek", + default_base_url="https://api.deepseek.com", + api_key_hint="填写 DeepSeek API Key。", + description="DeepSeek 官方平台。", + sort_order=30, + ), + ProviderSpec( + id="openrouter", + name="OpenRouter", + runtime="openai_compatible", + models_dev_provider_id="openrouter", + default_base_url="https://openrouter.ai/api/v1", + api_key_hint="填写 OpenRouter API Key。", + description="OpenRouter 聚合模型平台。", + sort_order=40, + ), + ProviderSpec( + id="github-copilot", + name="GitHub Copilot", + runtime="github_copilot", + models_dev_provider_id="github-copilot", + supports_api_key=False, + api_key_label="GitHub Token", + oauth_methods=( + ProviderAuthMethod( + id="device_code", + type="device", + label="GitHub 设备码授权", + description="使用 GitHub Copilot 订阅登录授权。", + ), + ), + model_list_strategy="github_copilot", + description="通过 GitHub Copilot 订阅接入。", + sort_order=50, + ), + ProviderSpec( + id="nvidia", + name="Nvidia", + runtime="openai_compatible", + models_dev_provider_id="nvidia", + default_base_url="https://integrate.api.nvidia.com/v1", + api_key_hint="填写 Nvidia API Key。", + description="Nvidia 集成推理平台。", + sort_order=60, + ), + ProviderSpec( + id="siliconflow", + name="硅基流动", + runtime="openai_compatible", + models_dev_provider_id="siliconflow", + default_base_url="https://api.siliconflow.com/v1", + api_key_hint="填写硅基流动 API Key。", + description="SiliconFlow 官方兼容端点。", + sort_order=70, + ), + ProviderSpec( + id="ollama-cloud", + name="Ollama Cloud", + runtime="openai_compatible", + models_dev_provider_id="ollama-cloud", + default_base_url="https://ollama.com/v1", + api_key_hint="填写 Ollama Cloud API Key。", + description="Ollama Cloud 云端模型服务。", + sort_order=80, + ), + ProviderSpec( + id="alibaba", + name="Alibaba", + runtime="openai_compatible", + models_dev_provider_id="alibaba", + default_base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1", + api_key_hint="填写 DashScope / Alibaba API Key。", + description="阿里云百炼兼容端点。", + sort_order=90, + ), + ProviderSpec( + id="minimax", + name="MiniMax", + runtime="anthropic_compatible", + models_dev_provider_id="minimax", + default_base_url="https://api.minimax.io/anthropic/v1", + api_key_hint="填写 MiniMax API Key。", + model_list_strategy="anthropic_compatible", + description="MiniMax Anthropic-compatible 端点。", + sort_order=100, + ), + ProviderSpec( + id="xiaomi", + name="Xiaomi", + runtime="openai_compatible", + models_dev_provider_id="xiaomi", + default_base_url="https://api.xiaomimimo.com/v1", + api_key_hint="填写 Xiaomi API Key。", + description="小米 Mimo 兼容端点。", + sort_order=110, + ), + ProviderSpec( + id="tencent", + name="Tencent", + runtime="openai_compatible", + models_dev_provider_id="tencent", + default_base_url="https://api.lkeap.cloud.tencent.com/coding/v3", + api_key_hint="填写 Tencent API Key。", + description="腾讯兼容端点。", + sort_order=120, + ), + ProviderSpec( + id="openai", + name="OpenAI Compatible", + runtime="openai_compatible", + default_base_url="", + base_url_editable=True, + requires_base_url=True, + supports_api_key=True, + api_key_hint="通用 OpenAI-compatible 兜底入口,需要手动填写 Base URL。", + description="通用 OpenAI-compatible 模型服务。", + sort_order=200, + ), + ) + + def list_providers(self) -> list[dict[str, Any]]: + """返回前端可渲染的 provider 目录。""" + providers = [] + for spec in sorted(self._provider_specs(), key=lambda item: item.sort_order): + providers.append( + { + "id": spec.id, + "name": spec.name, + "runtime": spec.runtime, + "default_base_url": spec.default_base_url or "", + "base_url_editable": spec.base_url_editable, + "requires_base_url": spec.requires_base_url, + "supports_api_key": spec.supports_api_key, + "api_key_label": spec.api_key_label, + "api_key_hint": spec.api_key_hint, + "supports_model_refresh": spec.supports_model_refresh, + "oauth_methods": [ + { + "id": method.id, + "type": method.type, + "label": method.label, + "description": method.description, + } + for method in spec.oauth_methods + ], + "description": spec.description, + "auth_status": self.get_auth_status(spec.id), + } + ) + return providers + + def get_provider(self, provider_id: str) -> ProviderSpec: + """按 provider id 获取定义。""" + normalized = (provider_id or "").strip().lower() + for spec in self._provider_specs(): + if spec.id == normalized: + return spec + raise LLMProviderError(f"不支持的 LLM 提供商:{provider_id}") + + @staticmethod + def _sanitize_base_url(base_url: Optional[str]) -> Optional[str]: + if base_url is None: + return None + value = str(base_url).strip() + if not value: + return None + return value.rstrip("/") + + @staticmethod + def _httpx_proxy_key() -> str: + """兼容不同 httpx 版本的 proxy 参数名。""" + params = httpx.Client.__init__.__code__.co_varnames + return "proxy" if "proxy" in params else "proxies" + + def _build_httpx_kwargs(self) -> dict[str, Any]: + kwargs: dict[str, Any] = {"timeout": self._DEFAULT_TIMEOUT} + if settings.PROXY_HOST: + kwargs[self._httpx_proxy_key()] = settings.PROXY_HOST + return kwargs + + @staticmethod + def _read_agent_config() -> dict[str, Any]: + config = SystemConfigOper().get(SystemConfigKey.AIAgentConfig) + if isinstance(config, dict): + return config + return {} + + @staticmethod + async def _write_agent_config(value: dict[str, Any]) -> None: + """ + 使用异步持久化写回 provider 鉴权配置。 + + `SystemConfigOper().get()` 读取的是内存缓存,这里保留同步调用; + 但写入需要落库,因此统一走 `async_set()`。 + """ + await SystemConfigOper().async_set( + SystemConfigKey.AIAgentConfig, + copy.deepcopy(value) or None, + ) + + def _get_auth_store(self) -> dict[str, Any]: + config = self._read_agent_config() + auth_store = config.get("provider_auth") + if isinstance(auth_store, dict): + return auth_store + return {} + + def get_saved_auth(self, provider_id: str) -> dict[str, Any] | None: + """读取持久化 provider 鉴权信息。""" + return copy.deepcopy(self._get_auth_store().get(provider_id)) + + async def save_auth(self, provider_id: str, auth_data: dict[str, Any]) -> None: + """写入 provider 鉴权信息。""" + config = self._read_agent_config() + auth_store = config.get("provider_auth") + if not isinstance(auth_store, dict): + auth_store = {} + auth_store[provider_id] = copy.deepcopy(auth_data) + config["provider_auth"] = auth_store + await self._write_agent_config(config) + + async def clear_auth(self, provider_id: str) -> None: + """移除 provider 鉴权信息。""" + config = self._read_agent_config() + auth_store = config.get("provider_auth") + if not isinstance(auth_store, dict): + return + auth_store.pop(provider_id, None) + if auth_store: + config["provider_auth"] = auth_store + else: + config.pop("provider_auth", None) + await self._write_agent_config(config) + + def get_auth_status(self, provider_id: str) -> dict[str, Any]: + """返回前端展示用的 provider 鉴权摘要。""" + auth = self.get_saved_auth(provider_id) + if not auth: + return {"connected": False} + return { + "connected": True, + "type": auth.get("type"), + "label": auth.get("label") or auth.get("email") or auth.get("account_id") or "已授权", + "expires_at": auth.get("expires_at"), + "updated_at": auth.get("updated_at"), + } + + async def _load_models_dev_from_disk(self) -> dict[str, Any] | None: + try: + if not self._models_dev_cache_path.exists(): + return None + async with aiofiles.open( + self._models_dev_cache_path, mode="r", encoding="utf-8" + ) as stream: + return json.loads(await stream.read()) + except Exception as err: + logger.warning(f"读取 models.dev 缓存失败: {err}") + return None + + async def _write_models_dev_to_disk(self, payload: dict[str, Any]) -> None: + try: + self._models_dev_cache_path.parent.mkdir(parents=True, exist_ok=True) + async with aiofiles.open( + self._models_dev_cache_path, mode="w", encoding="utf-8" + ) as stream: + await stream.write(json.dumps(payload, ensure_ascii=False)) + except Exception as err: + logger.warning(f"写入 models.dev 缓存失败: {err}") + + async def _fetch_models_dev(self) -> dict[str, Any]: + headers = {"User-Agent": "MoviePilot/1.0"} + async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client: + response = await client.get(self._MODELS_DEV_URL, headers=headers) + response.raise_for_status() + return response.json() + + async def get_models_dev_data(self, force_refresh: bool = False) -> dict[str, Any]: + """ + 返回 models.dev 原始数据。 + + 这里复用 opencode 的做法,把公共模型目录缓存到本地文件中,避免每次 + 刷新模型列表都直接打到远端。 + """ + async with self._models_dev_lock: + now = time.time() + if ( + not force_refresh + and self._models_dev_data is not None + and now - self._models_dev_loaded_at < self._MODELS_DEV_CACHE_TTL + ): + return self._models_dev_data + + if not force_refresh and self._models_dev_cache_path.exists(): + mtime = self._models_dev_cache_path.stat().st_mtime + if now - mtime < self._MODELS_DEV_CACHE_TTL: + cached = await self._load_models_dev_from_disk() + if isinstance(cached, dict): + self._models_dev_data = cached + self._models_dev_loaded_at = now + return cached + + try: + payload = await self._fetch_models_dev() + self._models_dev_data = payload + self._models_dev_loaded_at = now + await self._write_models_dev_to_disk(payload) + return payload + except Exception as err: + logger.warning(f"刷新 models.dev 失败,尝试回退本地缓存: {err}") + cached = await self._load_models_dev_from_disk() + if isinstance(cached, dict): + self._models_dev_data = cached + self._models_dev_loaded_at = now + return cached + raise LLMProviderError(f"获取 models.dev 数据失败: {err}") from err + + async def _models_dev_provider_payload(self, provider_id: str) -> dict[str, Any]: + spec = self.get_provider(provider_id) + if not spec.models_dev_provider_id: + return {} + return (await self.get_models_dev_data()).get(spec.models_dev_provider_id, {}) or {} + + async def _models_dev_model( + self, provider_id: str, model_id: str + ) -> dict[str, Any] | None: + payload = await self._models_dev_provider_payload(provider_id) + models = payload.get("models") if isinstance(payload, dict) else None + if not isinstance(models, dict): + return None + + candidates = [model_id] + if model_id.startswith("models/"): + candidates.append(model_id.removeprefix("models/")) + + for candidate in candidates: + if candidate in models: + return models[candidate] + return None + + @staticmethod + def _normalize_model_record( + model_id: str, + display_name: Optional[str] = None, + metadata: Optional[dict[str, Any]] = None, + transport: str = "openai", + live_context: Optional[int] = None, + live_input: Optional[int] = None, + live_output: Optional[int] = None, + live_supports_tools: Optional[bool] = None, + live_supports_reasoning: Optional[bool] = None, + live_supports_image: Optional[bool] = None, + live_supports_audio: Optional[bool] = None, + source: str = "provider", + ) -> dict[str, Any]: + """ + 统一输出模型记录格式,前端据此直接渲染和自动回填上下文等参数。 + """ + metadata = metadata or {} + limit = metadata.get("limit") or {} + modalities = metadata.get("modalities") or {} + input_modalities = set(modalities.get("input") or []) + + context_tokens = live_context or limit.get("context") + input_tokens = live_input or limit.get("input") + output_tokens = live_output or limit.get("output") + supports_image_input = ( + live_supports_image + if live_supports_image is not None + else "image" in input_modalities + ) + supports_audio_input = ( + live_supports_audio + if live_supports_audio is not None + else "audio" in input_modalities + ) + supports_tools = ( + live_supports_tools + if live_supports_tools is not None + else bool(metadata.get("tool_call")) + ) + supports_reasoning = ( + live_supports_reasoning + if live_supports_reasoning is not None + else bool(metadata.get("reasoning")) + ) + + return { + "id": model_id, + "name": display_name or metadata.get("name") or model_id, + "family": metadata.get("family"), + "context_tokens": context_tokens, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "context_tokens_k": max(1, int((int(context_tokens) + 999) / 1000)) + if context_tokens + else None, + "supports_reasoning": supports_reasoning, + "supports_tools": supports_tools, + "supports_image_input": supports_image_input, + "supports_audio_input": supports_audio_input, + "transport": transport, + "source": source, + "release_date": metadata.get("release_date"), + "status": metadata.get("status"), + } + + def _normalize_base_url_for_anthropic(self, base_url: str) -> str: + normalized = self._sanitize_base_url(base_url) or "" + if normalized.endswith("/v1"): + return normalized[:-3] + return normalized + + async def _list_models_from_google(self, api_key: str) -> list[dict[str, Any]]: + from google import genai + from google.genai.types import HttpOptions + + http_options = None + if settings.PROXY_HOST: + proxy_key = self._httpx_proxy_key() + proxy_args = {proxy_key: settings.PROXY_HOST} + http_options = HttpOptions( + client_args=proxy_args, + async_client_args=proxy_args, + ) + + client = genai.Client(api_key=api_key, http_options=http_options) + response = await client.aio.models.list() + results = [] + for model in response.page: + supported = set(model.supported_actions or []) + if "generateContent" not in supported: + continue + model_id = model.name + metadata = await self._models_dev_model("google", model_id) or {} + results.append( + self._normalize_model_record( + model_id=model_id, + display_name=model.display_name or metadata.get("name") or model_id, + metadata=metadata, + source="provider", + ) + ) + return sorted(results, key=lambda item: item["name"].lower()) + + async def _list_models_from_openai_compatible( + self, + provider_id: str, + api_key: str, + base_url: str, + default_headers: Optional[dict[str, str]] = None, + ) -> list[dict[str, Any]]: + from openai import AsyncOpenAI + + client = AsyncOpenAI( + api_key=api_key, + base_url=base_url, + default_headers=default_headers, + timeout=15.0, + max_retries=2, + ) + results = [] + response = await client.models.list() + for model in response.data: + metadata = await self._models_dev_model(provider_id, model.id) or {} + results.append( + self._normalize_model_record( + model_id=model.id, + display_name=metadata.get("name") or model.id, + metadata=metadata, + source="provider", + ) + ) + return sorted(results, key=lambda item: item["name"].lower()) + + async def _list_models_from_anthropic_compatible( + self, + provider_id: str, + ) -> list[dict[str, Any]]: + """ + Anthropic-compatible 生态没有像 OpenAI 那样统一稳定的 models.list 行为, + 因此优先读取 models.dev 目录;若未来 provider 暴露标准 models 接口, + 再平滑补充实时刷新即可。 + """ + payload = await self._models_dev_provider_payload(provider_id) + models = payload.get("models") if isinstance(payload, dict) else None + if not isinstance(models, dict): + raise LLMProviderError(f"{provider_id} 暂无可用模型目录") + results = [] + for model_id, metadata in models.items(): + results.append( + self._normalize_model_record( + model_id=model_id, + display_name=metadata.get("name") or model_id, + metadata=metadata, + transport="anthropic", + source="models.dev", + ) + ) + return sorted(results, key=lambda item: item["name"].lower()) + + @staticmethod + def _copilot_headers( + token: Optional[str] = None, include_auth: bool = True + ) -> dict[str, str]: + """ + 构造 GitHub Copilot 请求头。 + + OpenAI-compatible 调用会由 SDK 自行补 Authorization,因此这里允许 + 仅补充 Copilot 必需的意图头,避免重复覆盖。 + """ + headers = { + "User-Agent": "MoviePilot/1.0", + "Openai-Intent": "conversation-edits", + "x-initiator": "user", + } + if include_auth and token: + headers["Authorization"] = f"Bearer {token}" + return headers + + async def _list_models_from_copilot(self, token: str) -> list[dict[str, Any]]: + async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client: + response = await client.get( + "https://api.githubcopilot.com/models", + headers=self._copilot_headers(token), + ) + response.raise_for_status() + payload = response.json() + + raw_models = payload.get("data") if isinstance(payload, dict) else None + if not isinstance(raw_models, list): + raise LLMProviderError("GitHub Copilot 模型列表响应格式不正确") + + results = [] + for item in raw_models: + if not isinstance(item, dict): + continue + if not item.get("model_picker_enabled", True): + continue + if (item.get("policy") or {}).get("state") == "disabled": + continue + + model_id = str(item.get("id") or "").strip() + if not model_id: + continue + + endpoints = set(item.get("supported_endpoints") or []) + # 优先兼容 OpenAI 风格端点;仅在缺失时再切到 Anthropic 风格消息接口。 + transport = ( + "anthropic" + if "/v1/messages" in endpoints + and "/v1/chat/completions" not in endpoints + and "/v1/responses" not in endpoints + else "openai" + ) + + limits = ((item.get("capabilities") or {}).get("limits") or {}) + supports = ((item.get("capabilities") or {}).get("supports") or {}) + metadata = await self._models_dev_model("github-copilot", model_id) or {} + results.append( + self._normalize_model_record( + model_id=model_id, + display_name=item.get("name") or metadata.get("name") or model_id, + metadata=metadata, + transport=transport, + live_context=limits.get("max_context_window_tokens"), + live_input=limits.get("max_prompt_tokens"), + live_output=limits.get("max_output_tokens"), + live_supports_tools=supports.get("tool_calls"), + live_supports_reasoning=bool( + supports.get("adaptive_thinking") + or supports.get("reasoning_effort") + or supports.get("max_thinking_budget") is not None + or supports.get("min_thinking_budget") is not None + ), + live_supports_image=bool( + supports.get("vision") + or ((limits.get("vision") or {}).get("supported_media_types")) + ), + source="provider", + ) + ) + return sorted(results, key=lambda i: i["name"].lower()) + + async def _list_chatgpt_oauth_models(self) -> list[dict[str, Any]]: + payload = await self._models_dev_provider_payload("chatgpt") + models = payload.get("models") if isinstance(payload, dict) else None + if not isinstance(models, dict): + return [ + { + "id": model_id, + "name": model_id, + "context_tokens": None, + "input_tokens": None, + "output_tokens": None, + "context_tokens_k": settings.LLM_MAX_CONTEXT_TOKENS, + "supports_reasoning": True, + "supports_tools": True, + "supports_image_input": True, + "supports_audio_input": False, + "transport": "openai", + "source": "builtin", + "release_date": None, + "status": None, + } + for model_id in sorted(self._CHATGPT_ALLOWED_OAUTH_MODELS) + ] + + results = [] + for model_id, metadata in models.items(): + if "codex" in model_id or model_id in self._CHATGPT_ALLOWED_OAUTH_MODELS: + match = None + if model_id.startswith("gpt-"): + try: + match = float(model_id.split("-")[1].replace(".mini", "")) + except Exception as err: + print(err) + match = None + if match is not None and match > 5.4 and "codex" not in model_id: + continue + results.append( + self._normalize_model_record( + model_id=model_id, + display_name=metadata.get("name") or model_id, + metadata=metadata, + source="models.dev", + ) + ) + return sorted(results, key=lambda item: item["name"].lower()) + + async def list_models( + self, + provider_id: str, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + force_refresh: bool = False, + ) -> list[dict[str, Any]]: + """返回标准化后的模型目录。""" + spec = self.get_provider(provider_id) + if force_refresh and spec.models_dev_provider_id: + # 对依赖 models.dev 的 provider 主动刷新一次缓存,保证“刷新模型列表” + # 在使用目录型 provider 时也能拿到最新参数。 + await self.get_models_dev_data(force_refresh=True) + runtime = await self.resolve_runtime( + provider_id, + model=None, + api_key=api_key, + base_url=base_url, + ) + + if spec.model_list_strategy == "google": + return await self._list_models_from_google(runtime["api_key"]) + + if spec.model_list_strategy == "github_copilot": + return await self._list_models_from_copilot(runtime["api_key"]) + + if spec.model_list_strategy == "chatgpt": + if runtime.get("auth_mode") == "oauth": + return await self._list_chatgpt_oauth_models() + return await self._list_models_from_openai_compatible( + provider_id="chatgpt", + api_key=runtime["api_key"], + base_url=runtime["base_url"], + default_headers=runtime.get("default_headers"), + ) + + if spec.model_list_strategy == "anthropic_compatible": + return await self._list_models_from_anthropic_compatible( + provider_id=provider_id, + ) + + # openai-compatible / deepseek 默认走官方 models 端点。 + return await self._list_models_from_openai_compatible( + provider_id=provider_id, + api_key=runtime["api_key"], + base_url=runtime["base_url"], + default_headers=runtime.get("default_headers"), + ) + + async def resolve_model_metadata( + self, provider_id: str, model_id: Optional[str] + ) -> dict[str, Any] | None: + if not model_id: + return None + metadata = await self._models_dev_model(provider_id, model_id) + if metadata: + return metadata + if provider_id == "chatgpt": + return await self._models_dev_model("openai", model_id) + if provider_id == "openai": + models_dev = await self.get_models_dev_data() + return models_dev.get("openai", {}).get("models", {}).get(model_id) + return None + + @staticmethod + def _jwt_claims(token: str) -> dict[str, Any]: + try: + return jwt.decode(token, options={"verify_signature": False}) + except Exception as err: + print(err) + return {} + + @staticmethod + def _extract_chatgpt_account_id(token_payload: dict[str, Any]) -> Optional[str]: + if token_payload.get("chatgpt_account_id"): + return token_payload["chatgpt_account_id"] + auth_payload = token_payload.get("https://api.openai.com/auth") or {} + if auth_payload.get("chatgpt_account_id"): + return auth_payload["chatgpt_account_id"] + organizations = token_payload.get("organizations") or [] + if organizations and isinstance(organizations[0], dict): + return organizations[0].get("id") + return None + + def _chatgpt_authorize_url( + self, redirect_uri: str, challenge: str, state: str + ) -> str: + query = urlencode( + { + "response_type": "code", + "client_id": self._CHATGPT_CLIENT_ID, + "redirect_uri": redirect_uri, + "scope": "openid profile email offline_access", + "code_challenge": challenge, + "code_challenge_method": "S256", + "id_token_add_organizations": "true", + "codex_cli_simplified_flow": "true", + "state": state, + "originator": "moviepilot", + } + ) + return f"{self._CHATGPT_ISSUER}/oauth/authorize?{query}" + + @staticmethod + def _pkce_pair() -> tuple[str, str]: + verifier = secrets.token_urlsafe(64).replace("=", "") + digest = hashlib.sha256(verifier.encode("utf-8")).digest() + challenge = base64.urlsafe_b64encode(digest).decode("utf-8").rstrip("=") + return verifier, challenge + + async def start_auth( + self, + provider_id: str, + method_id: str, + callback_url: Optional[str] = None, + ) -> dict[str, Any]: + """ + 启动 OAuth / device code 会话。 + + API Key 方式已经由普通设置表单覆盖,这里只处理需要交互式授权的 provider。 + """ + provider = self.get_provider(provider_id) + method = next( + (item for item in provider.oauth_methods if item.id == method_id), + None, + ) + if not method: + raise LLMProviderAuthError(f"{provider.name} 不支持授权方式:{method_id}") + + session = PendingAuthSession( + session_id=secrets.token_urlsafe(18), + provider_id=provider_id, + method_id=method_id, + flow_type=method.type, + expires_at=time.time() + 600, + ) + + if provider_id == "chatgpt" and method_id == "browser_oauth": + if not callback_url: + raise LLMProviderAuthError("ChatGPT 浏览器授权缺少回调地址") + verifier, challenge = self._pkce_pair() + state = secrets.token_urlsafe(24) + session.authorize_url = self._chatgpt_authorize_url( + redirect_uri=callback_url, + challenge=challenge, + state=state, + ) + session.instructions = "请在浏览器中完成 ChatGPT Plus/Pro 登录授权。" + session.context.update( + { + "code_verifier": verifier, + "state": state, + "redirect_uri": callback_url, + } + ) + with self._lock: + self._pending_sessions[session.session_id] = session + self._oauth_state_index[state] = session.session_id + return { + "session_id": session.session_id, + "flow_type": "oauth_browser", + "authorize_url": session.authorize_url, + "instructions": session.instructions, + "expires_at": session.expires_at, + } + + if provider_id == "chatgpt" and method_id == "device_code": + async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client: + response = await client.post( + f"{self._CHATGPT_ISSUER}/api/accounts/deviceauth/usercode", + headers={ + "Content-Type": "application/json", + "User-Agent": "MoviePilot/1.0", + }, + json={"client_id": self._CHATGPT_CLIENT_ID}, + ) + response.raise_for_status() + payload = response.json() + + session.verification_url = f"{self._CHATGPT_ISSUER}/codex/device" + session.user_code = payload.get("user_code") + session.interval_seconds = max(int(payload.get("interval") or 5), 1) + session.instructions = f"请在浏览器输入设备码:{session.user_code}" + session.context.update( + { + "device_auth_id": payload.get("device_auth_id"), + "user_code": payload.get("user_code"), + } + ) + with self._lock: + self._pending_sessions[session.session_id] = session + return { + "session_id": session.session_id, + "flow_type": "device_code", + "verification_url": session.verification_url, + "user_code": session.user_code, + "interval_seconds": session.interval_seconds, + "instructions": session.instructions, + "expires_at": session.expires_at, + } + + if provider_id == "github-copilot" and method_id == "device_code": + async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client: + response = await client.post( + "https://github.com/login/device/code", + headers={ + "Accept": "application/json", + "Content-Type": "application/json", + "User-Agent": "MoviePilot/1.0", + }, + json={ + "client_id": self._COPILOT_CLIENT_ID, + "scope": "read:user", + }, + ) + response.raise_for_status() + payload = response.json() + + session.verification_url = payload.get("verification_uri") + session.user_code = payload.get("user_code") + session.interval_seconds = max(int(payload.get("interval") or 5), 1) + session.instructions = f"请在 GitHub 页面输入设备码:{session.user_code}" + session.context.update( + { + "device_code": payload.get("device_code"), + } + ) + with self._lock: + self._pending_sessions[session.session_id] = session + return { + "session_id": session.session_id, + "flow_type": "device_code", + "verification_url": session.verification_url, + "user_code": session.user_code, + "interval_seconds": session.interval_seconds, + "instructions": session.instructions, + "expires_at": session.expires_at, + } + + raise LLMProviderAuthError(f"暂未实现 {provider.name} 的授权方式:{method.label}") + + def get_session_status(self, session_id: str) -> dict[str, Any]: + """读取临时授权会话状态。""" + with self._lock: + session = self._pending_sessions.get(session_id) + if not session: + raise LLMProviderAuthError("授权会话不存在或已过期") + return { + "session_id": session.session_id, + "provider_id": session.provider_id, + "status": session.status, + "message": session.message, + "user_code": session.user_code, + "verification_url": session.verification_url, + "authorize_url": session.authorize_url, + "instructions": session.instructions, + "interval_seconds": session.interval_seconds, + "expires_at": session.expires_at, + } + + async def _mark_session_success( + self, session: PendingAuthSession, auth_data: dict[str, Any] + ) -> None: + auth_data["updated_at"] = int(time.time()) + await self.save_auth(session.provider_id, auth_data) + session.status = "authorized" + session.message = "授权成功" + + @staticmethod + def _mark_session_error(session: PendingAuthSession, message: str) -> None: + session.status = "failed" + session.message = message + + async def handle_chatgpt_callback( + self, + provider_id: str, + code: Optional[str], + state: Optional[str], + error: Optional[str], + error_description: Optional[str], + ) -> tuple[bool, str]: + """处理 ChatGPT 浏览器 OAuth 回调。""" + if provider_id != "chatgpt": + return False, "当前 provider 不支持浏览器 OAuth 回调" + + if error: + message = error_description or error + with self._lock: + session_id = self._oauth_state_index.pop(state or "", None) + if session_id and session_id in self._pending_sessions: + self._mark_session_error(self._pending_sessions[session_id], message) + return False, message + + if not code or not state: + return False, "缺少授权码或 state 参数" + + with self._lock: + session_id = self._oauth_state_index.pop(state, None) + session = self._pending_sessions.get(session_id or "") + + if not session: + return False, "授权会话不存在或已失效" + + if state != session.context.get("state"): + self._mark_session_error(session, "state 校验失败") + return False, "state 校验失败" + + try: + payload = await self._exchange_chatgpt_code_for_tokens( + code=code, + redirect_uri=session.context["redirect_uri"], + code_verifier=session.context["code_verifier"], + ) + claims = self._jwt_claims(payload.get("id_token") or payload["access_token"]) + account_id = self._extract_chatgpt_account_id(claims) + auth_data = { + "type": "oauth", + "provider": "chatgpt", + "access_token": payload["access_token"], + "refresh_token": payload["refresh_token"], + "expires_at": int(time.time() + int(payload.get("expires_in") or 3600)), + "account_id": account_id, + "email": claims.get("email"), + "label": claims.get("email") or account_id or "ChatGPT Plus/Pro", + } + await self._mark_session_success(session, auth_data) + return True, "ChatGPT 授权成功" + except Exception as err: + message = f"ChatGPT 授权失败: {err}" + self._mark_session_error(session, message) + return False, message + + async def poll_auth_session(self, session_id: str) -> dict[str, Any]: + """ + 执行一次 device code 轮询,并返回最新状态。 + + 前端可按 interval_seconds 轮询,直到状态变为 authorized / failed。 + """ + with self._lock: + session = self._pending_sessions.get(session_id) + if not session: + raise LLMProviderAuthError("授权会话不存在或已过期") + if session.status != "pending": + return self.get_session_status(session_id) + + try: + if session.provider_id == "chatgpt" and session.method_id == "device_code": + await self._poll_chatgpt_device_auth(session) + elif session.provider_id == "github-copilot" and session.method_id == "device_code": + await self._poll_copilot_device_auth(session) + else: + raise LLMProviderAuthError("当前授权会话不支持轮询") + except Exception as err: + self._mark_session_error(session, str(err)) + return self.get_session_status(session_id) + + async def _exchange_chatgpt_code_for_tokens( + self, code: str, redirect_uri: str, code_verifier: str + ) -> dict[str, Any]: + async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client: + response = await client.post( + f"{self._CHATGPT_ISSUER}/oauth/token", + headers={"Content-Type": "application/x-www-form-urlencoded"}, + data={ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + "client_id": self._CHATGPT_CLIENT_ID, + "code_verifier": code_verifier, + }, + ) + response.raise_for_status() + return response.json() + + async def _refresh_chatgpt_tokens(self, refresh_token: str) -> dict[str, Any]: + async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client: + response = await client.post( + f"{self._CHATGPT_ISSUER}/oauth/token", + headers={"Content-Type": "application/x-www-form-urlencoded"}, + data={ + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": self._CHATGPT_CLIENT_ID, + }, + ) + response.raise_for_status() + return response.json() + + async def _poll_chatgpt_device_auth(self, session: PendingAuthSession) -> None: + async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client: + response = await client.post( + f"{self._CHATGPT_ISSUER}/api/accounts/deviceauth/token", + headers={ + "Content-Type": "application/json", + "User-Agent": "MoviePilot/1.0", + }, + json={ + "device_auth_id": session.context["device_auth_id"], + "user_code": session.context["user_code"], + }, + ) + + if response.status_code in {403, 404}: + session.message = "等待用户在浏览器完成授权" + return + + response.raise_for_status() + payload = response.json() + token_payload = await self._exchange_chatgpt_code_for_tokens( + code=payload["authorization_code"], + redirect_uri=f"{self._CHATGPT_ISSUER}/deviceauth/callback", + code_verifier=payload["code_verifier"], + ) + claims = self._jwt_claims( + token_payload.get("id_token") or token_payload["access_token"] + ) + account_id = self._extract_chatgpt_account_id(claims) + await self._mark_session_success( + session, + { + "type": "oauth", + "provider": "chatgpt", + "access_token": token_payload["access_token"], + "refresh_token": token_payload["refresh_token"], + "expires_at": int(time.time() + int(token_payload.get("expires_in") or 3600)), + "account_id": account_id, + "email": claims.get("email"), + "label": claims.get("email") or account_id or "ChatGPT Plus/Pro", + }, + ) + + async def _poll_copilot_device_auth(self, session: PendingAuthSession) -> None: + async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client: + response = await client.post( + "https://github.com/login/oauth/access_token", + headers={ + "Accept": "application/json", + "Content-Type": "application/json", + "User-Agent": "MoviePilot/1.0", + }, + json={ + "client_id": self._COPILOT_CLIENT_ID, + "device_code": session.context["device_code"], + "grant_type": "urn:ietf:params:oauth:grant-type:device_code", + }, + ) + response.raise_for_status() + payload = response.json() + + access_token = payload.get("access_token") + if access_token: + await self._mark_session_success( + session, + { + "type": "oauth", + "provider": "github-copilot", + "access_token": access_token, + # Copilot 设备码授权返回的是长期可复用 token,这里复用 access 字段即可。 + "refresh_token": access_token, + "expires_at": None, + "label": "GitHub Copilot", + }, + ) + return + + error = payload.get("error") + if error == "authorization_pending": + session.message = "等待用户在 GitHub 页面完成授权" + return + if error == "slow_down": + session.interval_seconds = max(session.interval_seconds + 5, 10) + session.message = "GitHub 要求降低轮询频率,稍后继续。" + return + if error: + raise LLMProviderAuthError(f"GitHub Copilot 授权失败: {error}") + + async def _resolve_chatgpt_oauth(self) -> dict[str, Any]: + auth = self.get_saved_auth("chatgpt") + if not auth or auth.get("type") != "oauth": + raise LLMProviderAuthError("尚未完成 ChatGPT Plus/Pro 授权") + + expires_at = auth.get("expires_at") + refresh_token = auth.get("refresh_token") + # 预留 60 秒刷新缓冲,避免刚发起请求就遇到过期。 + if expires_at and refresh_token and int(expires_at) <= int(time.time()) + 60: + payload = await self._refresh_chatgpt_tokens(refresh_token) + claims = self._jwt_claims(payload.get("id_token") or payload["access_token"]) + auth.update( + { + "access_token": payload["access_token"], + "refresh_token": payload.get("refresh_token") or refresh_token, + "expires_at": int(time.time() + int(payload.get("expires_in") or 3600)), + "account_id": auth.get("account_id") + or self._extract_chatgpt_account_id(claims), + "email": auth.get("email") or claims.get("email"), + "label": auth.get("label") + or claims.get("email") + or auth.get("account_id") + or "ChatGPT Plus/Pro", + } + ) + await self.save_auth("chatgpt", auth) + return auth + + async def resolve_runtime( + self, + provider_id: str, + model: Optional[str], + api_key: Optional[str] = None, + base_url: Optional[str] = None, + ) -> dict[str, Any]: + """ + 解析 provider 运行时参数。 + + 返回统一结构,供 `LLMHelper` 创建具体 LangChain 模型实例时使用。 + """ + spec = self.get_provider(provider_id) + normalized_api_key = str(api_key or "").strip() or None + normalized_base_url = self._sanitize_base_url(base_url) + model_record = None + if model: + try: + model_record = next( + ( + item + for item in await self.list_models( + provider_id, + api_key=api_key, + base_url=base_url, + ) + if item["id"] == model + ), + None, + ) + except Exception as err: + print(err) + model_record = None + + result: dict[str, Any] = { + "provider_id": provider_id, + "runtime": spec.runtime, + "model_id": model, + "model_record": model_record, + "model_metadata": await self.resolve_model_metadata(provider_id, model), + "default_headers": None, + "use_responses_api": None, + "auth_mode": "api_key", + } + + if provider_id == "chatgpt": + auth = None + try: + auth = await self._resolve_chatgpt_oauth() + except Exception: + pass + + if auth: + headers = {"originator": "moviepilot"} + if auth.get("account_id"): + headers["ChatGPT-Account-Id"] = auth["account_id"] + result.update( + { + "runtime": "chatgpt", + "api_key": auth["access_token"], + "base_url": self._CHATGPT_CODEX_BASE_URL, + "default_headers": headers, + "use_responses_api": True, + "auth_mode": "oauth", + } + ) + return result + + if normalized_api_key: + result.update( + { + "runtime": "openai_compatible", + "api_key": normalized_api_key, + "base_url": normalized_base_url or spec.default_base_url, + "auth_mode": "api_key", + } + ) + return result + + raise LLMProviderAuthError("请提供 API Key 或完成 ChatGPT 授权") + + if provider_id == "github-copilot": + auth = self.get_saved_auth("github-copilot") + if auth and auth.get("type") == "oauth": + token = auth.get("refresh_token") or auth.get("access_token") + elif normalized_api_key: + token = normalized_api_key + else: + raise LLMProviderAuthError("请先完成 GitHub Copilot 授权") + + transport = (model_record or {}).get("transport") or "openai" + result.update( + { + "runtime": "copilot_anthropic" + if transport == "anthropic" + else "github_copilot", + "api_key": token, + "base_url": "https://api.githubcopilot.com", + "default_headers": self._copilot_headers( + token, + include_auth=transport == "anthropic", + ), + "auth_mode": "oauth" if auth else "api_key", + } + ) + return result + + if spec.runtime == "google": + if not normalized_api_key: + raise LLMProviderAuthError(f"{spec.name} 需要填写 API Key") + result.update( + { + "api_key": normalized_api_key, + "base_url": None, + "auth_mode": "api_key", + } + ) + return result + + if spec.runtime == "anthropic_compatible": + effective_base_url = normalized_base_url or spec.default_base_url + if not normalized_api_key: + raise LLMProviderAuthError(f"{spec.name} 需要填写 API Key") + if not effective_base_url: + raise LLMProviderAuthError(f"{spec.name} 缺少 Base URL") + result.update( + { + "api_key": normalized_api_key, + "base_url": self._normalize_base_url_for_anthropic( + effective_base_url + ), + "auth_mode": "api_key", + } + ) + return result + + effective_base_url = normalized_base_url or spec.default_base_url + if spec.requires_base_url and not effective_base_url: + raise LLMProviderAuthError(f"{spec.name} 需要填写 Base URL") + if not normalized_api_key: + raise LLMProviderAuthError(f"{spec.name} 需要填写 API Key") + result.update( + { + "api_key": normalized_api_key, + "base_url": effective_base_url, + "auth_mode": "api_key", + } + ) + return result + + +def render_auth_result_html(success: bool, message: str) -> str: + """OAuth 回调落地页。""" + title = "授权成功" if success else "授权失败" + accent = "#3aa675" if success else "#e24b4b" + return f""" + +
+ + +{message}
+