mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-16 13:55:28 +08:00
refactor: migrate LLM helper to agent module and add unified LLM API endpoints
- Move LLMHelper and related logic from app.helper.llm to app.agent.llm.helper
- Update all imports to reference the new LLMHelper location
- Introduce app/agent/llm/__init__.py for internal LLM adapter exports
- Add llm.py API router with endpoints for model listing, provider auth, and test calls
- Remove legacy LLM endpoints from system.py
- Update requirements for langchain-anthropic and anthropic
- Refactor test_llm_helper_testcall.py for async LLMHelper usage and new import paths
This commit is contained in:
@@ -33,7 +33,7 @@ from app.agent.runtime import agent_runtime_manager
|
|||||||
from app.agent.tools.factory import MoviePilotToolFactory
|
from app.agent.tools.factory import MoviePilotToolFactory
|
||||||
from app.chain import ChainBase
|
from app.chain import ChainBase
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.helper.llm import LLMHelper
|
from app.agent.llm import LLMHelper
|
||||||
from app.log import logger
|
from app.log import logger
|
||||||
from app.schemas import Notification, NotificationType
|
from app.schemas import Notification, NotificationType
|
||||||
from app.schemas.message import ChannelCapabilityManager, ChannelCapability
|
from app.schemas.message import ChannelCapabilityManager, ChannelCapability
|
||||||
@@ -310,12 +310,12 @@ class MoviePilotAgent:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _initialize_llm(streaming: bool = False):
|
async def _initialize_llm(streaming: bool = False):
|
||||||
"""
|
"""
|
||||||
初始化 LLM
|
初始化 LLM
|
||||||
:param streaming: 是否启用流式输出
|
:param streaming: 是否启用流式输出
|
||||||
"""
|
"""
|
||||||
return LLMHelper.get_llm(streaming=streaming)
|
return await LLMHelper.get_llm(streaming=streaming)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_text_content(content) -> str:
|
def _extract_text_content(content) -> str:
|
||||||
@@ -387,7 +387,7 @@ class MoviePilotAgent:
|
|||||||
allow_message_tools=self.allow_message_tools,
|
allow_message_tools=self.allow_message_tools,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_agent(self, streaming: bool = False):
|
async def _create_agent(self, streaming: bool = False):
|
||||||
"""
|
"""
|
||||||
创建 LangGraph Agent(使用 create_agent + SummarizationMiddleware)
|
创建 LangGraph Agent(使用 create_agent + SummarizationMiddleware)
|
||||||
:param streaming: 是否启用流式输出
|
:param streaming: 是否启用流式输出
|
||||||
@@ -397,12 +397,12 @@ class MoviePilotAgent:
|
|||||||
system_prompt = prompt_manager.get_agent_prompt(channel=self.channel)
|
system_prompt = prompt_manager.get_agent_prompt(channel=self.channel)
|
||||||
|
|
||||||
# LLM 模型(用于 agent 执行)
|
# LLM 模型(用于 agent 执行)
|
||||||
llm = self._initialize_llm(streaming=streaming)
|
llm = await self._initialize_llm(streaming=streaming)
|
||||||
self._sync_model_profile(llm)
|
self._sync_model_profile(llm)
|
||||||
|
|
||||||
# 为中间件内部模型调用准备非流式 LLM,避免与用户流式回复复用同一实例。
|
# 为中间件内部模型调用准备非流式 LLM,避免与用户流式回复复用同一实例。
|
||||||
non_streaming_llm = (
|
non_streaming_llm = (
|
||||||
llm if not streaming else self._initialize_llm(streaming=False)
|
llm if not streaming else await self._initialize_llm(streaming=False)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 工具列表
|
# 工具列表
|
||||||
@@ -577,7 +577,7 @@ class MoviePilotAgent:
|
|||||||
use_streaming = self._should_stream()
|
use_streaming = self._should_stream()
|
||||||
|
|
||||||
# 创建智能体(根据是否流式传入不同 LLM)
|
# 创建智能体(根据是否流式传入不同 LLM)
|
||||||
agent = self._create_agent(streaming=use_streaming)
|
agent = await self._create_agent(streaming=use_streaming)
|
||||||
|
|
||||||
if use_streaming:
|
if use_streaming:
|
||||||
self.stream_handler.set_dispatch_policy(
|
self.stream_handler.set_dispatch_policy(
|
||||||
|
|||||||
19
app/agent/llm/__init__.py
Normal file
19
app/agent/llm/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
"""Agent 内部使用的 LLM 适配层。"""
|
||||||
|
|
||||||
|
from app.agent.llm.helper import LLMHelper, LLMTestError, LLMTestTimeout
|
||||||
|
from app.agent.llm.provider import (
|
||||||
|
LLMProviderAuthError,
|
||||||
|
LLMProviderError,
|
||||||
|
LLMProviderManager,
|
||||||
|
render_auth_result_html,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"LLMHelper",
|
||||||
|
"LLMProviderAuthError",
|
||||||
|
"LLMProviderError",
|
||||||
|
"LLMProviderManager",
|
||||||
|
"LLMTestError",
|
||||||
|
"LLMTestTimeout",
|
||||||
|
"render_auth_result_html",
|
||||||
|
]
|
||||||
@@ -342,7 +342,7 @@ class LLMHelper:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
# OpenAI 原生推理模型优先走 LangChain 内置 reasoning_effort。
|
# OpenAI 原生推理模型优先走 LangChain 内置 reasoning_effort。
|
||||||
if provider_name == "openai" and model_name.startswith(
|
if provider_name in {"openai", "chatgpt"} and model_name.startswith(
|
||||||
("gpt-5", "o1", "o3", "o4")
|
("gpt-5", "o1", "o3", "o4")
|
||||||
):
|
):
|
||||||
openai_effort = cls._normalize_openai_reasoning_effort(
|
openai_effort = cls._normalize_openai_reasoning_effort(
|
||||||
@@ -366,11 +366,89 @@ class LLMHelper:
|
|||||||
return bool(settings.LLM_SUPPORT_IMAGE_INPUT)
|
return bool(settings.LLM_SUPPORT_IMAGE_INPUT)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_llm(
|
def _build_legacy_runtime(
|
||||||
|
provider_name: str,
|
||||||
|
model_name: str | None,
|
||||||
|
api_key: str | None = None,
|
||||||
|
base_url: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
在 provider 目录不可用时回退到旧的直接构造逻辑。
|
||||||
|
|
||||||
|
这主要用于单测 stub 环境以及极端的最小运行环境,正常生产路径仍优先
|
||||||
|
走 `LLMProviderManager.resolve_runtime()`。
|
||||||
|
"""
|
||||||
|
api_key_value = api_key if api_key is not None else settings.LLM_API_KEY
|
||||||
|
base_url_value = base_url if base_url is not None else settings.LLM_BASE_URL
|
||||||
|
if not api_key_value:
|
||||||
|
raise ValueError("未配置LLM API Key")
|
||||||
|
|
||||||
|
runtime_name = provider_name if provider_name in {"google", "deepseek"} else "openai_compatible"
|
||||||
|
return {
|
||||||
|
"provider_id": provider_name,
|
||||||
|
"runtime": runtime_name,
|
||||||
|
"model_id": model_name,
|
||||||
|
"api_key": api_key_value,
|
||||||
|
"base_url": base_url_value,
|
||||||
|
"default_headers": None,
|
||||||
|
"use_responses_api": None,
|
||||||
|
"model_record": None,
|
||||||
|
"model_metadata": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _resolve_thinking_level(
|
||||||
|
cls,
|
||||||
|
thinking_level: str | None = None,
|
||||||
|
disable_thinking: bool | None = None,
|
||||||
|
reasoning_effort: str | None = None,
|
||||||
|
) -> str | None:
|
||||||
|
"""
|
||||||
|
统一兼容新旧 thinking 参数。
|
||||||
|
|
||||||
|
新前端只会传 `thinking_level`,但测试和部分旧调用仍可能带
|
||||||
|
`disable_thinking` / `reasoning_effort`,这里集中做一次归一化。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _normalize(value: str | None) -> str | None:
|
||||||
|
normalized = str(value or "").strip().lower()
|
||||||
|
if not normalized:
|
||||||
|
return None
|
||||||
|
alias_map = {
|
||||||
|
"none": "off",
|
||||||
|
"disabled": "off",
|
||||||
|
"disable": "off",
|
||||||
|
"enabled": "auto",
|
||||||
|
"enable": "auto",
|
||||||
|
"default": "auto",
|
||||||
|
"dynamic": "auto",
|
||||||
|
}
|
||||||
|
normalized = alias_map.get(normalized, normalized)
|
||||||
|
if normalized in cls._SUPPORTED_THINKING_LEVELS:
|
||||||
|
return normalized
|
||||||
|
logger.warning(f"忽略不支持的思考级别: {value}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
normalized_thinking_level = _normalize(thinking_level)
|
||||||
|
if normalized_thinking_level:
|
||||||
|
return normalized_thinking_level
|
||||||
|
|
||||||
|
legacy_effort = _normalize(reasoning_effort)
|
||||||
|
if disable_thinking:
|
||||||
|
return "off"
|
||||||
|
if disable_thinking is False:
|
||||||
|
return legacy_effort or "auto"
|
||||||
|
return legacy_effort
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_llm(
|
||||||
|
cls,
|
||||||
streaming: bool = False,
|
streaming: bool = False,
|
||||||
provider: str | None = None,
|
provider: str | None = None,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
thinking_level: str | None = None,
|
thinking_level: str | None = None,
|
||||||
|
disable_thinking: bool | None = None,
|
||||||
|
reasoning_effort: str | None = None,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
base_url: str | None = None,
|
base_url: str | None = None,
|
||||||
):
|
):
|
||||||
@@ -383,28 +461,50 @@ class LLMHelper:
|
|||||||
是否启用思考模式)。支持的级别包括 "off"(关闭)、"auto"(自动)、"minimal"、"low"、"medium"、"high"、"max"/"xhigh"(最大)。
|
是否启用思考模式)。支持的级别包括 "off"(关闭)、"auto"(自动)、"minimal"、"low"、"medium"、"high"、"max"/"xhigh"(最大)。
|
||||||
不同模型对思考模式的支持和表现不同,具体映射关系请
|
不同模型对思考模式的支持和表现不同,具体映射关系请
|
||||||
参考代码实现。对于不支持思考模式的模型,该参数将被忽略。
|
参考代码实现。对于不支持思考模式的模型,该参数将被忽略。
|
||||||
|
:param disable_thinking: 兼容旧参数,若传入则会被转换为新的
|
||||||
|
`thinking_level` 语义。
|
||||||
|
:param reasoning_effort: 兼容旧参数,若传入则会被转换为新的
|
||||||
|
`thinking_level` 语义。
|
||||||
:param api_key: API Key,默认为
|
:param api_key: API Key,默认为
|
||||||
配置项LLM_API_KEY。对于某些提供商(
|
配置项LLM_API_KEY。对于某些提供商(
|
||||||
如 DeepSeek),可能需要同时提供 base_url。
|
如 DeepSeek),可能需要同时提供 base_url。
|
||||||
:param base_url: API Base URL,默认为配置项LLM_BASE_URL。
|
:param base_url: API Base URL,默认为配置项LLM_BASE_URL。
|
||||||
:return: LLM实例
|
:return: LLM实例
|
||||||
"""
|
"""
|
||||||
provider_name = str(
|
provider_name = str(provider if provider is not None else settings.LLM_PROVIDER).lower()
|
||||||
provider if provider is not None else settings.LLM_PROVIDER
|
|
||||||
).lower()
|
|
||||||
model_name = model if model is not None else settings.LLM_MODEL
|
model_name = model if model is not None else settings.LLM_MODEL
|
||||||
api_key_value = api_key if api_key is not None else settings.LLM_API_KEY
|
normalized_thinking_level = cls._resolve_thinking_level(
|
||||||
base_url_value = base_url if base_url is not None else settings.LLM_BASE_URL
|
thinking_level=thinking_level,
|
||||||
thinking_kwargs = LLMHelper._build_thinking_kwargs(
|
disable_thinking=disable_thinking,
|
||||||
|
reasoning_effort=reasoning_effort,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
# 延迟导入,避免单测在最小 stub 环境下 import `llm.py` 时被 provider
|
||||||
|
# 目录依赖链拖住。
|
||||||
|
from app.agent.llm.provider import LLMProviderManager
|
||||||
|
|
||||||
|
runtime = await LLMProviderManager().resolve_runtime(
|
||||||
|
provider_id=provider_name,
|
||||||
|
model=model_name,
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
logger.debug(f"LLM provider 目录不可用,回退到旧运行时逻辑: {err}")
|
||||||
|
runtime = cls._build_legacy_runtime(
|
||||||
|
provider_name=provider_name,
|
||||||
|
model_name=model_name,
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
)
|
||||||
|
model_name = runtime.get("model_id") or model_name
|
||||||
|
thinking_kwargs = cls._build_thinking_kwargs(
|
||||||
provider=provider_name,
|
provider=provider_name,
|
||||||
model=model_name,
|
model=model_name,
|
||||||
thinking_level=thinking_level
|
thinking_level=normalized_thinking_level,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not api_key_value:
|
if runtime["runtime"] == "google":
|
||||||
raise ValueError("未配置LLM API Key")
|
|
||||||
|
|
||||||
if provider_name == "google":
|
|
||||||
# 修补 Gemini 2.5 思考模型的 thought_signature 兼容性
|
# 修补 Gemini 2.5 思考模型的 thought_signature 兼容性
|
||||||
_patch_gemini_thought_signature()
|
_patch_gemini_thought_signature()
|
||||||
|
|
||||||
@@ -420,49 +520,76 @@ class LLMHelper:
|
|||||||
|
|
||||||
model = ChatGoogleGenerativeAI(
|
model = ChatGoogleGenerativeAI(
|
||||||
model=model_name,
|
model=model_name,
|
||||||
api_key=api_key_value,
|
api_key=runtime["api_key"],
|
||||||
retries=3,
|
retries=3,
|
||||||
temperature=settings.LLM_TEMPERATURE,
|
temperature=settings.LLM_TEMPERATURE,
|
||||||
streaming=streaming,
|
streaming=streaming,
|
||||||
client_args=client_args,
|
client_args=client_args,
|
||||||
**thinking_kwargs,
|
**thinking_kwargs,
|
||||||
)
|
)
|
||||||
elif provider_name == "deepseek":
|
elif runtime["runtime"] == "deepseek":
|
||||||
from langchain_deepseek import ChatDeepSeek
|
from langchain_deepseek import ChatDeepSeek
|
||||||
|
|
||||||
_patch_deepseek_reasoning_content_support()
|
_patch_deepseek_reasoning_content_support()
|
||||||
model = ChatDeepSeek(
|
model = ChatDeepSeek(
|
||||||
model=model_name,
|
model=model_name,
|
||||||
api_key=api_key_value,
|
api_key=runtime["api_key"],
|
||||||
api_base=base_url_value,
|
api_base=runtime["base_url"],
|
||||||
max_retries=3,
|
max_retries=3,
|
||||||
temperature=settings.LLM_TEMPERATURE,
|
temperature=settings.LLM_TEMPERATURE,
|
||||||
streaming=streaming,
|
streaming=streaming,
|
||||||
stream_usage=True,
|
stream_usage=True,
|
||||||
**thinking_kwargs,
|
**thinking_kwargs,
|
||||||
)
|
)
|
||||||
|
elif runtime["runtime"] in {"anthropic_compatible", "copilot_anthropic"}:
|
||||||
|
from langchain_anthropic import ChatAnthropic
|
||||||
|
|
||||||
|
model = ChatAnthropic(
|
||||||
|
model=model_name,
|
||||||
|
api_key=runtime["api_key"],
|
||||||
|
base_url=runtime["base_url"],
|
||||||
|
max_retries=3,
|
||||||
|
temperature=settings.LLM_TEMPERATURE,
|
||||||
|
streaming=streaming,
|
||||||
|
stream_usage=True,
|
||||||
|
anthropic_proxy=settings.PROXY_HOST,
|
||||||
|
default_headers=runtime.get("default_headers"),
|
||||||
|
**thinking_kwargs,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
model = ChatOpenAI(
|
model = ChatOpenAI(
|
||||||
model=model_name,
|
model=model_name,
|
||||||
api_key=api_key_value,
|
api_key=runtime["api_key"],
|
||||||
max_retries=3,
|
max_retries=3,
|
||||||
base_url=base_url_value,
|
base_url=runtime.get("base_url"),
|
||||||
temperature=settings.LLM_TEMPERATURE,
|
temperature=settings.LLM_TEMPERATURE,
|
||||||
streaming=streaming,
|
streaming=streaming,
|
||||||
stream_usage=True,
|
stream_usage=True,
|
||||||
openai_proxy=settings.PROXY_HOST,
|
openai_proxy=settings.PROXY_HOST,
|
||||||
|
default_headers=runtime.get("default_headers"),
|
||||||
|
use_responses_api=runtime.get("use_responses_api"),
|
||||||
**thinking_kwargs,
|
**thinking_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 检查是否有profile
|
# 优先使用 provider / models.dev 目录中的上下文上限,减少用户手填成本。
|
||||||
if hasattr(model, "profile") and model.profile:
|
model_profile = getattr(model, "profile", None)
|
||||||
|
if model_profile:
|
||||||
logger.debug(f"使用LLM模型: {model.model},Profile: {model.profile}")
|
logger.debug(f"使用LLM模型: {model.model},Profile: {model.profile}")
|
||||||
else:
|
else:
|
||||||
|
model_record = runtime.get("model_record") or {}
|
||||||
|
model_metadata = runtime.get("model_metadata") or {}
|
||||||
|
metadata_limit = model_metadata.get("limit") or {}
|
||||||
|
max_input_tokens = (
|
||||||
|
model_record.get("input_tokens")
|
||||||
|
or model_record.get("context_tokens")
|
||||||
|
or metadata_limit.get("input")
|
||||||
|
or metadata_limit.get("context")
|
||||||
|
or settings.LLM_MAX_CONTEXT_TOKENS * 1000
|
||||||
|
)
|
||||||
model.profile = {
|
model.profile = {
|
||||||
"max_input_tokens": settings.LLM_MAX_CONTEXT_TOKENS
|
"max_input_tokens": int(max_input_tokens),
|
||||||
* 1000, # 转换为token单位
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return model
|
return model
|
||||||
@@ -514,6 +641,8 @@ class LLMHelper:
|
|||||||
provider: str | None = None,
|
provider: str | None = None,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
thinking_level: str | None = None,
|
thinking_level: str | None = None,
|
||||||
|
disable_thinking: bool | None = None,
|
||||||
|
reasoning_effort: str | None = None,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
base_url: str | None = None,
|
base_url: str | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
@@ -522,16 +651,16 @@ class LLMHelper:
|
|||||||
"""
|
"""
|
||||||
provider_name = provider if provider is not None else settings.LLM_PROVIDER
|
provider_name = provider if provider is not None else settings.LLM_PROVIDER
|
||||||
model_name = model if model is not None else settings.LLM_MODEL
|
model_name = model if model is not None else settings.LLM_MODEL
|
||||||
api_key_value = api_key if api_key is not None else settings.LLM_API_KEY
|
|
||||||
base_url_value = base_url if base_url is not None else settings.LLM_BASE_URL
|
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
llm = LLMHelper.get_llm(
|
llm = await LLMHelper.get_llm(
|
||||||
streaming=False,
|
streaming=False,
|
||||||
provider=provider_name,
|
provider=provider_name,
|
||||||
model=model_name,
|
model=model_name,
|
||||||
thinking_level=thinking_level,
|
thinking_level=thinking_level,
|
||||||
api_key=api_key_value,
|
disable_thinking=disable_thinking,
|
||||||
base_url=base_url_value,
|
reasoning_effort=reasoning_effort,
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
response = await asyncio.wait_for(llm.ainvoke(prompt), timeout=timeout)
|
response = await asyncio.wait_for(llm.ainvoke(prompt), timeout=timeout)
|
||||||
@@ -556,18 +685,47 @@ class LLMHelper:
|
|||||||
data["reply_preview"] = reply_text[:120]
|
data["reply_preview"] = reply_text[:120]
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def get_models(
|
async def get_models(
|
||||||
self, provider: str, api_key: str, base_url: str = None
|
self,
|
||||||
) -> List[str]:
|
provider: str,
|
||||||
"""获取模型列表"""
|
api_key: str | None = None,
|
||||||
|
base_url: str | None = None,
|
||||||
|
force_refresh: bool = False,
|
||||||
|
) -> List[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
获取模型列表。
|
||||||
|
|
||||||
|
返回值会带上 context/supports_reasoning 等元数据,供前端直接渲染并自动
|
||||||
|
回填上下文大小。
|
||||||
|
"""
|
||||||
logger.info(f"获取 {provider} 模型列表...")
|
logger.info(f"获取 {provider} 模型列表...")
|
||||||
if provider == "google":
|
try:
|
||||||
return self._get_google_models(api_key)
|
from app.agent.llm.provider import LLMProviderManager
|
||||||
else:
|
|
||||||
return self._get_openai_compatible_models(provider, api_key, base_url)
|
return await LLMProviderManager().list_models(
|
||||||
|
provider_id=provider,
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
force_refresh=force_refresh,
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
logger.debug(f"LLM provider 目录不可用,回退旧模型列表逻辑: {err}")
|
||||||
|
if provider == "google":
|
||||||
|
return [
|
||||||
|
{"id": model_id, "name": model_id}
|
||||||
|
for model_id in await self._get_google_models(api_key or "")
|
||||||
|
]
|
||||||
|
return [
|
||||||
|
{"id": model_id, "name": model_id}
|
||||||
|
for model_id in await self._get_openai_compatible_models(
|
||||||
|
provider,
|
||||||
|
api_key or "",
|
||||||
|
base_url,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_google_models(api_key: str) -> List[str]:
|
async def _get_google_models(api_key: str) -> List[str]:
|
||||||
"""获取Google模型列表(使用 google-genai SDK v1)"""
|
"""获取Google模型列表(使用 google-genai SDK v1)"""
|
||||||
try:
|
try:
|
||||||
from google import genai
|
from google import genai
|
||||||
@@ -583,29 +741,32 @@ class LLMHelper:
|
|||||||
)
|
)
|
||||||
|
|
||||||
client = genai.Client(api_key=api_key, http_options=http_options)
|
client = genai.Client(api_key=api_key, http_options=http_options)
|
||||||
models = client.models.list()
|
models = await client.aio.models.list()
|
||||||
return [
|
result = [
|
||||||
m.name
|
m.name
|
||||||
for m in models
|
for m in models.page
|
||||||
if m.supported_actions and "generateContent" in m.supported_actions
|
if m.supported_actions and "generateContent" in m.supported_actions
|
||||||
]
|
]
|
||||||
|
await client.aio.aclose()
|
||||||
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取Google模型列表失败:{e}")
|
logger.error(f"获取Google模型列表失败:{e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_openai_compatible_models(
|
async def _get_openai_compatible_models(
|
||||||
provider: str, api_key: str, base_url: str = None
|
provider: str, api_key: str, base_url: str = None
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""获取OpenAI兼容模型列表"""
|
"""获取OpenAI兼容模型列表"""
|
||||||
try:
|
try:
|
||||||
from openai import OpenAI
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
if provider == "deepseek":
|
if provider == "deepseek":
|
||||||
base_url = base_url or "https://api.deepseek.com"
|
base_url = base_url or "https://api.deepseek.com"
|
||||||
|
|
||||||
client = OpenAI(api_key=api_key, base_url=base_url)
|
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||||
models = client.models.list()
|
models = await client.models.list()
|
||||||
|
await client.close()
|
||||||
return [model.id for model in models.data]
|
return [model.id for model in models.data]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取 {provider} 模型列表失败:{e}")
|
logger.error(f"获取 {provider} 模型列表失败:{e}")
|
||||||
1507
app/agent/llm/provider.py
Normal file
1507
app/agent/llm/provider.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -158,9 +158,9 @@ async def _summarize_with_llm(conversation_text: str) -> str | None:
|
|||||||
LLM 生成的摘要字符串,失败时返回 None。
|
LLM 生成的摘要字符串,失败时返回 None。
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from app.helper.llm import LLMHelper
|
from app.agent.llm import LLMHelper
|
||||||
|
|
||||||
llm = LLMHelper.get_llm(streaming=False)
|
llm = await LLMHelper.get_llm(streaming=False)
|
||||||
prompt = SUMMARY_PROMPT.format(conversation=conversation_text)
|
prompt = SUMMARY_PROMPT.format(conversation=conversation_text)
|
||||||
response = await llm.ainvoke(prompt)
|
response = await llm.ainvoke(prompt)
|
||||||
summary = response.content.strip()
|
summary = response.content.strip()
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from fastapi import APIRouter
|
|||||||
|
|
||||||
from app.api.endpoints import login, user, webhook, message, site, subscribe, \
|
from app.api.endpoints import login, user, webhook, message, site, subscribe, \
|
||||||
media, douban, search, plugin, tmdb, history, system, download, dashboard, \
|
media, douban, search, plugin, tmdb, history, system, download, dashboard, \
|
||||||
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa, openai, anthropic
|
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa, openai, anthropic, llm
|
||||||
|
|
||||||
api_router = APIRouter()
|
api_router = APIRouter()
|
||||||
api_router.include_router(login.router, prefix="/login", tags=["login"])
|
api_router.include_router(login.router, prefix="/login", tags=["login"])
|
||||||
@@ -18,6 +18,7 @@ api_router.include_router(douban.router, prefix="/douban", tags=["douban"])
|
|||||||
api_router.include_router(tmdb.router, prefix="/tmdb", tags=["tmdb"])
|
api_router.include_router(tmdb.router, prefix="/tmdb", tags=["tmdb"])
|
||||||
api_router.include_router(history.router, prefix="/history", tags=["history"])
|
api_router.include_router(history.router, prefix="/history", tags=["history"])
|
||||||
api_router.include_router(system.router, prefix="/system", tags=["system"])
|
api_router.include_router(system.router, prefix="/system", tags=["system"])
|
||||||
|
api_router.include_router(llm.router, prefix="/llm", tags=["llm"])
|
||||||
api_router.include_router(plugin.router, prefix="/plugin", tags=["plugin"])
|
api_router.include_router(plugin.router, prefix="/plugin", tags=["plugin"])
|
||||||
api_router.include_router(download.router, prefix="/download", tags=["download"])
|
api_router.include_router(download.router, prefix="/download", tags=["download"])
|
||||||
api_router.include_router(dashboard.router, prefix="/dashboard", tags=["dashboard"])
|
api_router.include_router(dashboard.router, prefix="/dashboard", tags=["dashboard"])
|
||||||
|
|||||||
290
app/api/endpoints/llm.py
Normal file
290
app/api/endpoints/llm.py
Normal file
@@ -0,0 +1,290 @@
|
|||||||
|
import re
|
||||||
|
from typing import Annotated, Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Body, Depends, Request
|
||||||
|
from fastapi.responses import HTMLResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from app import schemas
|
||||||
|
from app.agent.llm import (
|
||||||
|
LLMHelper,
|
||||||
|
LLMProviderManager,
|
||||||
|
LLMTestTimeout,
|
||||||
|
render_auth_result_html,
|
||||||
|
)
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.db.models import User
|
||||||
|
from app.db.user_oper import (
|
||||||
|
get_current_active_superuser_async,
|
||||||
|
get_current_active_user_async,
|
||||||
|
)
|
||||||
|
from app.log import logger
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
class LlmTestRequest(BaseModel):
|
||||||
|
enabled: Optional[bool] = None
|
||||||
|
provider: Optional[str] = None
|
||||||
|
model: Optional[str] = None
|
||||||
|
thinking_level: Optional[str] = None
|
||||||
|
disable_thinking: Optional[bool] = None
|
||||||
|
reasoning_effort: Optional[str] = None
|
||||||
|
api_key: Optional[str] = None
|
||||||
|
base_url: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class LlmProviderAuthStartRequest(BaseModel):
|
||||||
|
provider: str
|
||||||
|
method: str
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_llm_test_error(message: str, api_key: Optional[str] = None) -> str:
|
||||||
|
"""
|
||||||
|
清理错误信息中的敏感字段,避免回显密钥。
|
||||||
|
"""
|
||||||
|
if not message:
|
||||||
|
return "LLM 调用失败"
|
||||||
|
|
||||||
|
sanitized = message
|
||||||
|
if api_key:
|
||||||
|
sanitized = sanitized.replace(api_key, "***")
|
||||||
|
sanitized = re.sub(
|
||||||
|
r"(?i)(api[_-]?key\s*[:=]\s*)([^\s,;]+)",
|
||||||
|
r"\1***",
|
||||||
|
sanitized,
|
||||||
|
)
|
||||||
|
sanitized = re.sub(
|
||||||
|
r"(?i)authorization\s*:\s*bearer\s+[^\s,;]+",
|
||||||
|
"Authorization: ***",
|
||||||
|
sanitized,
|
||||||
|
)
|
||||||
|
return sanitized
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/models", summary="获取LLM模型列表", response_model=schemas.Response)
|
||||||
|
async def get_llm_models(
|
||||||
|
provider: str,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
base_url: Optional[str] = None,
|
||||||
|
force_refresh: Optional[bool] = False,
|
||||||
|
_: User = Depends(get_current_active_user_async),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取指定 provider 的模型目录。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
provider_manager = LLMProviderManager()
|
||||||
|
models = await LLMHelper().get_models(
|
||||||
|
provider=provider,
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
force_refresh=bool(force_refresh),
|
||||||
|
)
|
||||||
|
return schemas.Response(
|
||||||
|
success=True,
|
||||||
|
data={
|
||||||
|
"provider": provider,
|
||||||
|
"models": models,
|
||||||
|
"auth_status": provider_manager.get_auth_status(provider),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
return schemas.Response(success=False, message=str(err))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/providers", summary="获取LLM提供商目录", response_model=schemas.Response)
|
||||||
|
async def get_llm_providers(
|
||||||
|
_: User = Depends(get_current_active_user_async),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
返回前端可直接渲染的 provider 目录。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
providers = LLMProviderManager().list_providers()
|
||||||
|
return schemas.Response(success=True, data=providers)
|
||||||
|
except Exception as err:
|
||||||
|
return schemas.Response(success=False, message=str(err))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/provider-auth/start",
|
||||||
|
summary="启动LLM提供商授权",
|
||||||
|
response_model=schemas.Response,
|
||||||
|
)
|
||||||
|
async def start_llm_provider_auth(
|
||||||
|
payload: LlmProviderAuthStartRequest,
|
||||||
|
request: Request,
|
||||||
|
_: User = Depends(get_current_active_superuser_async),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
启动 provider 授权会话。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
callback_url = None
|
||||||
|
if payload.provider == "chatgpt" and payload.method == "browser_oauth":
|
||||||
|
callback_url = str(
|
||||||
|
request.url_for("llm_provider_auth_callback", provider_id=payload.provider)
|
||||||
|
)
|
||||||
|
result = await LLMProviderManager().start_auth(
|
||||||
|
payload.provider,
|
||||||
|
payload.method,
|
||||||
|
callback_url,
|
||||||
|
)
|
||||||
|
return schemas.Response(success=True, data=result)
|
||||||
|
except Exception as err:
|
||||||
|
return schemas.Response(success=False, message=str(err))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/provider-auth/{session_id}",
|
||||||
|
summary="获取LLM提供商授权会话状态",
|
||||||
|
response_model=schemas.Response,
|
||||||
|
)
|
||||||
|
async def get_llm_provider_auth_session(
|
||||||
|
session_id: str,
|
||||||
|
_: User = Depends(get_current_active_superuser_async),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
查询授权会话状态。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = LLMProviderManager().get_session_status(session_id)
|
||||||
|
return schemas.Response(success=True, data=result)
|
||||||
|
except Exception as err:
|
||||||
|
return schemas.Response(success=False, message=str(err))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/provider-auth/{session_id}/poll",
|
||||||
|
summary="轮询LLM提供商授权会话",
|
||||||
|
response_model=schemas.Response,
|
||||||
|
)
|
||||||
|
async def poll_llm_provider_auth_session(
|
||||||
|
session_id: str,
|
||||||
|
_: User = Depends(get_current_active_superuser_async),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
轮询 device code / OAuth 会话状态。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await LLMProviderManager().poll_auth_session(session_id)
|
||||||
|
return schemas.Response(success=True, data=result)
|
||||||
|
except Exception as err:
|
||||||
|
return schemas.Response(success=False, message=str(err))
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete(
|
||||||
|
"/provider-auth/{provider_id}",
|
||||||
|
summary="断开LLM提供商授权",
|
||||||
|
response_model=schemas.Response,
|
||||||
|
)
|
||||||
|
async def delete_llm_provider_auth(
|
||||||
|
provider_id: str,
|
||||||
|
_: User = Depends(get_current_active_superuser_async),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
删除已保存的 provider 授权信息。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await LLMProviderManager().clear_auth(provider_id)
|
||||||
|
return schemas.Response(success=True)
|
||||||
|
except Exception as err:
|
||||||
|
return schemas.Response(success=False, message=str(err))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/provider-auth/callback/{provider_id}",
|
||||||
|
summary="LLM提供商OAuth回调",
|
||||||
|
response_class=HTMLResponse,
|
||||||
|
name="llm_provider_auth_callback",
|
||||||
|
)
|
||||||
|
async def llm_provider_auth_callback(
|
||||||
|
provider_id: str,
|
||||||
|
code: Optional[str] = None,
|
||||||
|
state: Optional[str] = None,
|
||||||
|
error: Optional[str] = None,
|
||||||
|
error_description: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
处理需要浏览器回跳的 OAuth provider。
|
||||||
|
"""
|
||||||
|
success, message = await LLMProviderManager().handle_chatgpt_callback(
|
||||||
|
provider_id,
|
||||||
|
code,
|
||||||
|
state,
|
||||||
|
error,
|
||||||
|
error_description,
|
||||||
|
)
|
||||||
|
return HTMLResponse(content=render_auth_result_html(success, message))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/test", summary="测试LLM调用", response_model=schemas.Response)
async def llm_test(
    payload: Annotated[Optional[LlmTestRequest], Body()] = None,
    _: User = Depends(get_current_active_superuser_async),
):
    """
    Run one minimal LLM call using the supplied configuration, or the
    currently saved configuration when no request body is given.

    :param payload: optional configuration to test
    :return: a response whose ``success`` flag reports the outcome; validation
             failures, timeouts and other errors are returned as failure
             responses rather than raised
    """
    if not payload:
        # No body supplied: build the request from the saved settings.
        payload = LlmTestRequest(
            enabled=settings.AI_AGENT_ENABLE,
            provider=settings.LLM_PROVIDER,
            model=settings.LLM_MODEL,
            thinking_level=getattr(settings, "LLM_THINKING_LEVEL", None),
            disable_thinking=getattr(settings, "LLM_DISABLE_THINKING", None),
            reasoning_effort=getattr(settings, "LLM_REASONING_EFFORT", None),
            api_key=settings.LLM_API_KEY,
            base_url=settings.LLM_BASE_URL,
        )

    if not payload.provider:
        return schemas.Response(success=False, message="请配置LLM提供商和模型")
    if not (payload.model and payload.model.strip()):
        return schemas.Response(success=False, message="请先配置 LLM 模型")

    # Echoed back with the remaining validation failures.
    data = dict(provider=payload.provider, model=payload.model)
    if not payload.enabled:
        return schemas.Response(success=False, message="请先启用智能助手", data=data)

    # These providers are exempt from the API key check below.
    keyless_providers = {"chatgpt", "github-copilot"}
    if payload.provider not in keyless_providers and not (
        payload.api_key and payload.api_key.strip()
    ):
        return schemas.Response(
            success=False,
            message="请先配置 LLM API Key",
            data=data,
        )

    try:
        call_kwargs = {
            "provider": payload.provider,
            "model": payload.model,
            "thinking_level": payload.thinking_level,
            "disable_thinking": payload.disable_thinking,
            "reasoning_effort": payload.reasoning_effort,
            "api_key": payload.api_key,
            "base_url": payload.base_url,
        }
        result = await LLMHelper.test_current_settings(**call_kwargs)
        if result.get("reply_preview"):
            return schemas.Response(success=True, data=result)
        # The call finished but produced no reply preview.
        return schemas.Response(
            success=False,
            message="模型响应为空",
            data=result,
        )
    except (LLMTestTimeout, TimeoutError) as timeout_err:
        logger.warning(timeout_err)
        return schemas.Response(
            success=False,
            message="LLM 调用超时",
        )
    except Exception as exc:
        # The error text is sanitized before being returned to the client.
        return schemas.Response(
            success=False,
            message=_sanitize_llm_test_error(str(exc), payload.api_key),
        )
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import re
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Optional, Union, Annotated
|
from typing import Any, Optional, Union, Annotated
|
||||||
@@ -12,7 +11,6 @@ from anyio import Path as AsyncPath
|
|||||||
from app.helper.sites import SitesHelper # noqa # noqa
|
from app.helper.sites import SitesHelper # noqa # noqa
|
||||||
from fastapi import APIRouter, Body, Depends, HTTPException, Header, Request, Response
|
from fastapi import APIRouter, Body, Depends, HTTPException, Header, Request, Response
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from app import schemas
|
from app import schemas
|
||||||
from app.chain.mediaserver import MediaServerChain
|
from app.chain.mediaserver import MediaServerChain
|
||||||
@@ -31,7 +29,6 @@ from app.db.user_oper import (
|
|||||||
get_current_active_user_async,
|
get_current_active_user_async,
|
||||||
)
|
)
|
||||||
from app.helper.image import ImageHelper
|
from app.helper.image import ImageHelper
|
||||||
from app.helper.llm import LLMHelper, LLMTestTimeout
|
|
||||||
from app.helper.mediaserver import MediaServerHelper
|
from app.helper.mediaserver import MediaServerHelper
|
||||||
from app.helper.message import MessageHelper
|
from app.helper.message import MessageHelper
|
||||||
from app.helper.progress import ProgressHelper
|
from app.helper.progress import ProgressHelper
|
||||||
@@ -53,15 +50,6 @@ router = APIRouter()
|
|||||||
_NETTEST_REDIRECT_STATUS_CODES = {301, 302, 303, 307, 308}
|
_NETTEST_REDIRECT_STATUS_CODES = {301, 302, 303, 307, 308}
|
||||||
|
|
||||||
|
|
||||||
class LlmTestRequest(BaseModel):
    # Request body of the LLM test endpoint. Every field is optional and
    # defaults to None; llm_test validates the values before calling the model.
    enabled: Optional[bool] = None  # falsy -> llm_test rejects the request
    provider: Optional[str] = None  # required by llm_test
    model: Optional[str] = None  # required by llm_test, must not be blank
    thinking_level: Optional[str] = None  # forwarded to LLMHelper.test_current_settings
    api_key: Optional[str] = None  # must not be blank; also used to mask error messages
    base_url: Optional[str] = None  # forwarded to LLMHelper.test_current_settings
|
|
||||||
|
|
||||||
|
|
||||||
def _match_nettest_prefix(url: str, prefix: str) -> bool:
|
def _match_nettest_prefix(url: str, prefix: str) -> bool:
|
||||||
"""
|
"""
|
||||||
判断目标URL是否仍然落在允许的协议、主机、端口和路径前缀内。
|
判断目标URL是否仍然落在允许的协议、主机、端口和路径前缀内。
|
||||||
@@ -268,30 +256,6 @@ def _build_nettest_rules() -> list[dict[str, Any]]:
|
|||||||
)
|
)
|
||||||
return rules
|
return rules
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_llm_test_error(message: str, api_key: Optional[str] = None) -> str:
    """
    Remove sensitive values from an error message so secrets are not echoed.

    :param message: raw error text
    :param api_key: API key used for the failed call; every occurrence of it
                    (ignoring surrounding whitespace) is masked
    :return: the sanitized message, or "LLM 调用失败" when ``message`` is empty
    """
    if not message:
        return "LLM 调用失败"

    sanitized = message
    # Mask the key without its surrounding whitespace: a key submitted with a
    # trailing space or newline would otherwise not match the error text, and
    # a whitespace-only key would turn ordinary runs of spaces into "***".
    secret = api_key.strip() if api_key else ""
    if secret:
        sanitized = sanitized.replace(secret, "***")
    # Mask "api_key=<value>" / "api-key: <value>" style fragments.
    sanitized = re.sub(
        r"(?i)(api[_-]?key\s*[:=]\s*)([^\s,;]+)",
        r"\1***",
        sanitized,
    )
    # Mask bearer tokens in an echoed Authorization header.
    sanitized = re.sub(
        r"(?i)authorization\s*:\s*bearer\s+[^\s,;]+",
        "Authorization: ***",
        sanitized,
    )
    return sanitized
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_nettest_url(url: str) -> Optional[str]:
|
def _validate_nettest_url(url: str) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
对实际请求地址做基础安全校验。
|
对实际请求地址做基础安全校验。
|
||||||
@@ -643,87 +607,6 @@ async def set_setting(
|
|||||||
return schemas.Response(success=False, message=f"配置项 '{key}' 不存在")
|
return schemas.Response(success=False, message=f"配置项 '{key}' 不存在")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/llm-models", summary="获取LLM模型列表", response_model=schemas.Response)
async def get_llm_models(
    provider: str,
    api_key: str,
    base_url: Optional[str] = None,
    _: User = Depends(get_current_active_user_async),
):
    """
    Return the list of models offered by an LLM provider.

    :param provider: provider identifier
    :param api_key: API key passed to the model lookup
    :param base_url: optional base URL passed to the model lookup
    :return: a success response carrying the models, or a failure response
             whose message is the text of the exception that was raised
    """
    try:
        fetch_models = LLMHelper().get_models
        # get_models is a synchronous call, so it runs in a worker thread.
        available = await asyncio.to_thread(fetch_models, provider, api_key, base_url)
        return schemas.Response(success=True, data=available)
    except Exception as exc:
        return schemas.Response(success=False, message=str(exc))
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/llm-test", summary="测试LLM调用", response_model=schemas.Response)
async def llm_test(
    payload: Annotated[Optional[LlmTestRequest], Body()] = None,
    _: User = Depends(get_current_active_superuser_async),
):
    """
    Run one minimal LLM call using the configuration in the request body.

    :param payload: configuration to test; a missing body is rejected
    :return: a response whose ``success`` flag reports the outcome; validation
             failures, timeouts and other errors are returned as failure
             responses rather than raised
    """
    if not payload:
        return schemas.Response(success=False, message="请配置智能助手LLM相关参数后再进行测试")

    if not (payload.provider and payload.model):
        return schemas.Response(success=False, message="请配置LLM提供商和模型")

    # Echoed back with the remaining validation failures.
    data = dict(provider=payload.provider, model=payload.model)
    if not payload.enabled:
        return schemas.Response(success=False, message="请先启用智能助手", data=data)

    if not (payload.api_key and payload.api_key.strip()):
        return schemas.Response(
            success=False,
            message="请先配置 LLM API Key",
            data=data,
        )

    # Rejects a model name that consists only of whitespace.
    if not (payload.model and payload.model.strip()):
        return schemas.Response(
            success=False,
            message="请先配置 LLM 模型",
            data=data,
        )

    try:
        call_kwargs = {
            "provider": payload.provider,
            "model": payload.model,
            "thinking_level": payload.thinking_level,
            "api_key": payload.api_key,
            "base_url": payload.base_url,
        }
        result = await LLMHelper.test_current_settings(**call_kwargs)
        if result.get("reply_preview"):
            return schemas.Response(success=True, data=result)
        # The call finished but produced no reply preview.
        return schemas.Response(
            success=False,
            message="模型响应为空"
        )
    except (LLMTestTimeout, TimeoutError) as timeout_err:
        logger.warning(timeout_err)
        return schemas.Response(
            success=False,
            message="LLM 调用超时"
        )
    except Exception as exc:
        # The error text is sanitized before being returned to the client.
        return schemas.Response(
            success=False,
            message=_sanitize_llm_test_error(str(exc), payload.api_key)
        )
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/message", summary="实时消息")
|
@router.get("/message", summary="实时消息")
|
||||||
async def get_message(
|
async def get_message(
|
||||||
request: Request,
|
request: Request,
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from app.chain.transfer import TransferChain
|
|||||||
from app.core.config import settings, global_vars
|
from app.core.config import settings, global_vars
|
||||||
from app.db.models import TransferHistory
|
from app.db.models import TransferHistory
|
||||||
from app.db.transferhistory_oper import TransferHistoryOper
|
from app.db.transferhistory_oper import TransferHistoryOper
|
||||||
from app.helper.llm import LLMHelper
|
from app.agent.llm import LLMHelper
|
||||||
from app.helper.voice import VoiceHelper
|
from app.helper.voice import VoiceHelper
|
||||||
from app.log import logger
|
from app.log import logger
|
||||||
from app.schemas import Notification, CommingMessage
|
from app.schemas import Notification, CommingMessage
|
||||||
|
|||||||
@@ -79,10 +79,12 @@ httpx[socks]~=0.28.1
|
|||||||
langchain~=1.2.15
|
langchain~=1.2.15
|
||||||
langchain-core~=1.3.2
|
langchain-core~=1.3.2
|
||||||
langchain-community~=0.4.1
|
langchain-community~=0.4.1
|
||||||
|
langchain-anthropic~=1.1.0
|
||||||
langchain-openai~=1.2.1
|
langchain-openai~=1.2.1
|
||||||
langchain-google-genai~=4.2.2
|
langchain-google-genai~=4.2.2
|
||||||
langchain-deepseek~=1.0.1
|
langchain-deepseek~=1.0.1
|
||||||
langgraph~=1.1.9
|
langgraph~=1.1.9
|
||||||
|
anthropic>=0.57,<1
|
||||||
openai~=2.32.0
|
openai~=2.32.0
|
||||||
google-genai~=1.73.1
|
google-genai~=1.73.1
|
||||||
ddgs~=9.10.0
|
ddgs~=9.10.0
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from app.agent.tools.impl.send_local_file import SendLocalFileInput
|
|||||||
from app.agent import MoviePilotAgent, AgentChain
|
from app.agent import MoviePilotAgent, AgentChain
|
||||||
from app.chain.message import MessageChain
|
from app.chain.message import MessageChain
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.helper.llm import LLMHelper
|
from app.agent.llm import LLMHelper
|
||||||
from app.modules.discord import DiscordModule
|
from app.modules.discord import DiscordModule
|
||||||
from app.modules.qqbot import QQBotModule
|
from app.modules.qqbot import QQBotModule
|
||||||
from app.modules.slack import SlackModule
|
from app.modules.slack import SlackModule
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import sys
|
|||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import ModuleType, SimpleNamespace
|
from types import ModuleType, SimpleNamespace
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
|
||||||
def _stub_module(name: str, **attrs):
|
def _stub_module(name: str, **attrs):
|
||||||
@@ -30,7 +30,7 @@ class _FakeModel:
|
|||||||
return SimpleNamespace(content=self._content)
|
return SimpleNamespace(content=self._content)
|
||||||
|
|
||||||
|
|
||||||
sys.modules.pop("app.helper.llm", None)
|
sys.modules.pop("app.agent.llm.helper", None)
|
||||||
_stub_module(
|
_stub_module(
|
||||||
"app.core.config",
|
"app.core.config",
|
||||||
settings=SimpleNamespace(
|
settings=SimpleNamespace(
|
||||||
@@ -46,7 +46,7 @@ _stub_module(
|
|||||||
)
|
)
|
||||||
_stub_module("app.log", logger=_DummyLogger())
|
_stub_module("app.log", logger=_DummyLogger())
|
||||||
|
|
||||||
module_path = Path(__file__).resolve().parents[1] / "app" / "helper" / "llm.py"
|
module_path = Path(__file__).resolve().parents[1] / "app" / "agent" / "llm" / "helper.py"
|
||||||
spec = importlib.util.spec_from_file_location("test_llm_module", module_path)
|
spec = importlib.util.spec_from_file_location("test_llm_module", module_path)
|
||||||
llm_module = importlib.util.module_from_spec(spec)
|
llm_module = importlib.util.module_from_spec(spec)
|
||||||
assert spec and spec.loader
|
assert spec and spec.loader
|
||||||
@@ -67,7 +67,7 @@ class LlmHelperTestCallTest(unittest.TestCase):
|
|||||||
|
|
||||||
def test_test_current_settings_uses_explicit_snapshot(self):
|
def test_test_current_settings_uses_explicit_snapshot(self):
|
||||||
fake_model = _FakeModel("OK")
|
fake_model = _FakeModel("OK")
|
||||||
get_llm_mock = Mock(return_value=fake_model)
|
get_llm_mock = AsyncMock(return_value=fake_model)
|
||||||
|
|
||||||
with patch.object(llm_module.LLMHelper, "get_llm", get_llm_mock):
|
with patch.object(llm_module.LLMHelper, "get_llm", get_llm_mock):
|
||||||
result = asyncio.run(
|
result = asyncio.run(
|
||||||
@@ -79,7 +79,7 @@ class LlmHelperTestCallTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
get_llm_mock.assert_called_once_with(
|
get_llm_mock.assert_awaited_once_with(
|
||||||
streaming=False,
|
streaming=False,
|
||||||
provider="deepseek",
|
provider="deepseek",
|
||||||
model="deepseek-chat",
|
model="deepseek-chat",
|
||||||
@@ -101,7 +101,9 @@ class LlmHelperTestCallTest(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch.object(llm_module.LLMHelper, "get_llm", return_value=fake_model):
|
with patch.object(
|
||||||
|
llm_module.LLMHelper, "get_llm", AsyncMock(return_value=fake_model)
|
||||||
|
):
|
||||||
result = asyncio.run(
|
result = asyncio.run(
|
||||||
llm_module.LLMHelper.test_current_settings(
|
llm_module.LLMHelper.test_current_settings(
|
||||||
provider="deepseek",
|
provider="deepseek",
|
||||||
@@ -126,12 +128,14 @@ class LlmHelperTestCallTest(unittest.TestCase):
|
|||||||
sys.modules,
|
sys.modules,
|
||||||
{"langchain_openai": SimpleNamespace(ChatOpenAI=_FakeChatOpenAI)},
|
{"langchain_openai": SimpleNamespace(ChatOpenAI=_FakeChatOpenAI)},
|
||||||
):
|
):
|
||||||
llm_module.LLMHelper.get_llm(
|
asyncio.run(
|
||||||
provider="openai",
|
llm_module.LLMHelper.get_llm(
|
||||||
model="kimi-k2.6",
|
provider="openai",
|
||||||
disable_thinking=True,
|
model="kimi-k2.6",
|
||||||
api_key="sk-test",
|
disable_thinking=True,
|
||||||
base_url="https://kimi.example.com/v1",
|
api_key="sk-test",
|
||||||
|
base_url="https://kimi.example.com/v1",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(len(calls), 1)
|
self.assertEqual(len(calls), 1)
|
||||||
@@ -158,12 +162,14 @@ class LlmHelperTestCallTest(unittest.TestCase):
|
|||||||
"_patch_deepseek_reasoning_content_support",
|
"_patch_deepseek_reasoning_content_support",
|
||||||
side_effect=lambda: patch_calls.append(True),
|
side_effect=lambda: patch_calls.append(True),
|
||||||
):
|
):
|
||||||
llm_module.LLMHelper.get_llm(
|
asyncio.run(
|
||||||
provider="deepseek",
|
llm_module.LLMHelper.get_llm(
|
||||||
model="deepseek-v4-pro",
|
provider="deepseek",
|
||||||
thinking_level="xhigh",
|
model="deepseek-v4-pro",
|
||||||
api_key="sk-test",
|
thinking_level="xhigh",
|
||||||
base_url="https://api.deepseek.com",
|
api_key="sk-test",
|
||||||
|
base_url="https://api.deepseek.com",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(len(calls), 1)
|
self.assertEqual(len(calls), 1)
|
||||||
@@ -193,12 +199,14 @@ class LlmHelperTestCallTest(unittest.TestCase):
|
|||||||
"_patch_deepseek_reasoning_content_support",
|
"_patch_deepseek_reasoning_content_support",
|
||||||
side_effect=lambda: patch_calls.append(True),
|
side_effect=lambda: patch_calls.append(True),
|
||||||
):
|
):
|
||||||
llm_module.LLMHelper.get_llm(
|
asyncio.run(
|
||||||
provider="deepseek",
|
llm_module.LLMHelper.get_llm(
|
||||||
model="deepseek-v4-flash",
|
provider="deepseek",
|
||||||
thinking_level="off",
|
model="deepseek-v4-flash",
|
||||||
api_key="sk-test",
|
thinking_level="off",
|
||||||
base_url="https://proxy.example.com",
|
api_key="sk-test",
|
||||||
|
base_url="https://proxy.example.com",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(len(calls), 1)
|
self.assertEqual(len(calls), 1)
|
||||||
@@ -223,12 +231,14 @@ class LlmHelperTestCallTest(unittest.TestCase):
|
|||||||
sys.modules,
|
sys.modules,
|
||||||
{"langchain_openai": SimpleNamespace(ChatOpenAI=_FakeChatOpenAI)},
|
{"langchain_openai": SimpleNamespace(ChatOpenAI=_FakeChatOpenAI)},
|
||||||
):
|
):
|
||||||
llm_module.LLMHelper.get_llm(
|
asyncio.run(
|
||||||
provider="openai",
|
llm_module.LLMHelper.get_llm(
|
||||||
model="gpt-5-mini",
|
provider="openai",
|
||||||
thinking_level="off",
|
model="gpt-5-mini",
|
||||||
api_key="sk-test",
|
thinking_level="off",
|
||||||
base_url="https://api.openai.com/v1",
|
api_key="sk-test",
|
||||||
|
base_url="https://api.openai.com/v1",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(len(calls), 1)
|
self.assertEqual(len(calls), 1)
|
||||||
@@ -247,12 +257,14 @@ class LlmHelperTestCallTest(unittest.TestCase):
|
|||||||
sys.modules,
|
sys.modules,
|
||||||
{"langchain_openai": SimpleNamespace(ChatOpenAI=_FakeChatOpenAI)},
|
{"langchain_openai": SimpleNamespace(ChatOpenAI=_FakeChatOpenAI)},
|
||||||
):
|
):
|
||||||
llm_module.LLMHelper.get_llm(
|
asyncio.run(
|
||||||
provider="openai",
|
llm_module.LLMHelper.get_llm(
|
||||||
model="gpt-5.4",
|
provider="openai",
|
||||||
thinking_level="max",
|
model="gpt-5.4",
|
||||||
api_key="sk-test",
|
thinking_level="max",
|
||||||
base_url="https://api.openai.com/v1",
|
api_key="sk-test",
|
||||||
|
base_url="https://api.openai.com/v1",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(len(calls), 1)
|
self.assertEqual(len(calls), 1)
|
||||||
@@ -275,12 +287,14 @@ class LlmHelperTestCallTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
):
|
):
|
||||||
llm_module.LLMHelper.get_llm(
|
asyncio.run(
|
||||||
provider="google",
|
llm_module.LLMHelper.get_llm(
|
||||||
model="gemini-2.5-flash",
|
provider="google",
|
||||||
thinking_level="off",
|
model="gemini-2.5-flash",
|
||||||
api_key="sk-test",
|
thinking_level="off",
|
||||||
base_url=None,
|
api_key="sk-test",
|
||||||
|
base_url=None,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(len(calls), 1)
|
self.assertEqual(len(calls), 1)
|
||||||
@@ -304,12 +318,14 @@ class LlmHelperTestCallTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
):
|
):
|
||||||
llm_module.LLMHelper.get_llm(
|
asyncio.run(
|
||||||
provider="google",
|
llm_module.LLMHelper.get_llm(
|
||||||
model="gemini-3.1-flash",
|
provider="google",
|
||||||
thinking_level="xhigh",
|
model="gemini-3.1-flash",
|
||||||
api_key="sk-test",
|
thinking_level="xhigh",
|
||||||
base_url=None,
|
api_key="sk-test",
|
||||||
|
base_url=None,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(len(calls), 1)
|
self.assertEqual(len(calls), 1)
|
||||||
|
|||||||
Reference in New Issue
Block a user