From 14b366a648fbfc4aaba67651a0e0cdce6fd344f9 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Fri, 8 May 2026 14:47:20 +0800 Subject: [PATCH] refactor: adjust default and maximum limits for plugin candidates and torrent results; enhance result formatting for agents --- app/agent/llm/provider.py | 47 +++++++++++++ app/agent/tools/base.py | 69 ++++++++++++++++--- app/agent/tools/impl/_plugin_tool_utils.py | 3 +- app/agent/tools/impl/_torrent_search_utils.py | 2 +- .../tools/impl/query_installed_plugins.py | 8 ++- app/agent/tools/impl/query_market_plugins.py | 8 ++- app/agent/tools/impl/query_media_detail.py | 29 +++++--- .../tools/impl/query_popular_subscribes.py | 6 +- app/agent/tools/impl/query_site_userdata.py | 27 ++++++-- .../tools/impl/query_subscribe_shares.py | 6 +- .../tools/impl/query_transfer_history.py | 4 +- app/agent/tools/impl/query_workflows.py | 4 +- app/agent/tools/impl/search_media.py | 6 +- app/agent/tools/impl/search_person.py | 6 +- app/agent/tools/impl/search_web.py | 2 +- app/agent/tools/manager.py | 23 +++---- tests/test_agent_query_workflows_tool.py | 56 +++++++++++++++ tests/test_agent_tool_result_limits.py | 51 ++++++++++++++ 18 files changed, 297 insertions(+), 60 deletions(-) create mode 100644 tests/test_agent_query_workflows_tool.py create mode 100644 tests/test_agent_tool_result_limits.py diff --git a/app/agent/llm/provider.py b/app/agent/llm/provider.py index 9c423435..65a4741a 100644 --- a/app/agent/llm/provider.py +++ b/app/agent/llm/provider.py @@ -172,6 +172,7 @@ class LLMProviderManager(metaclass=Singleton): } def __init__(self): + """初始化管理器实例及各类锁和缓存变量。""" self._lock = threading.RLock() self._models_dev_lock = asyncio.Lock() self._pending_sessions: dict[str, PendingAuthSession] = {} @@ -749,6 +750,7 @@ class LLMProviderManager(metaclass=Singleton): return tuple(providers) def _cached_models_dev_payload(self) -> dict[str, Any]: + """获取缓存的 models.dev payload。""" if isinstance(self._models_dev_data, dict): return self._models_dev_data @@ -772,6 +774,7 @@ class LLMProviderManager(metaclass=Singleton): @staticmethod def _models_dev_env_names(payload: dict[str, Any]) -> tuple[str, ...]: + """从 models.dev 数据中提取支持的环境变量名。""" raw_env_names = payload.get("env") if not isinstance(raw_env_names, list): return () @@ -786,6 +789,7 @@ class LLMProviderManager(metaclass=Singleton): def _models_dev_reserved_provider_ids( cls, specs: tuple[ProviderSpec, ...] ) -> set[str]: + """获取所有已保留的 models_dev_provider_id 集合。""" reserved_ids: set[str] = set() for spec in specs: if spec.models_dev_provider_id: @@ -797,6 +801,7 @@ class LLMProviderManager(metaclass=Singleton): @staticmethod def _dynamic_api_key_label(env_names: tuple[str, ...]) -> str: + """根据环境变量名动态推断 API Key 标签名称。""" first_env = env_names[0].upper() if env_names else "" if "TOKEN" in first_env and "KEY" not in first_env: return "API Token" @@ -806,6 +811,7 @@ class LLMProviderManager(metaclass=Singleton): def _normalize_models_dev_base_url( cls, runtime: str, base_url: Optional[str] ) -> Optional[str]: + """规范化从 models.dev 获取的 Base URL。""" normalized = cls._sanitize_base_url(base_url) if not normalized: return None @@ -837,6 +843,7 @@ class LLMProviderManager(metaclass=Singleton): payload: dict[str, Any], sort_order: int, ) -> ProviderSpec | None: + """根据 models.dev 数据动态生成 ProviderSpec 实例。""" normalized_id = str(provider_id or "").strip().lower() if not normalized_id or normalized_id in cls._MODELS_DEV_DYNAMIC_SKIP_IDS: return None @@ -899,6 +906,7 @@ class LLMProviderManager(metaclass=Singleton): def _dynamic_provider_specs( self, builtin_specs: tuple[ProviderSpec, ...] ) -> tuple[ProviderSpec, ...]: + """获取从 models.dev 动态加载的所有 ProviderSpec 实例。""" payload = self._cached_models_dev_payload() if not payload: return () @@ -939,12 +947,14 @@ class LLMProviderManager(metaclass=Singleton): return tuple(dynamic_specs) def _provider_specs(self) -> tuple[ProviderSpec, ...]: + """获取所有支持的 ProviderSpec,包括内置和动态加载的。""" builtin_specs = self._builtin_provider_specs() return builtin_specs + self._dynamic_provider_specs(builtin_specs) async def _get_provider_async( self, provider_id: str, force_refresh: bool = False ) -> ProviderSpec: + """异步获取指定 provider 的 ProviderSpec 实例。""" normalized_provider_id = self._normalize_provider_id(provider_id) try: return self.get_provider(normalized_provider_id) @@ -953,6 +963,7 @@ class LLMProviderManager(metaclass=Singleton): return self.get_provider(normalized_provider_id) def _serialize_provider(self, spec: ProviderSpec) -> dict[str, Any]: + """将 ProviderSpec 序列化为前端可用的字典。""" return { "id": spec.id, "name": spec.name, @@ -1014,6 +1025,7 @@ class LLMProviderManager(metaclass=Singleton): @staticmethod def _sanitize_base_url(base_url: Optional[str]) -> Optional[str]: + """清理 Base URL 中多余的空格和结尾斜杠。""" if base_url is None: return None value = str(base_url).strip() @@ -1023,6 +1035,7 @@ class LLMProviderManager(metaclass=Singleton): @classmethod def _default_base_url_for_provider(cls, spec: ProviderSpec) -> Optional[str]: + """获取 provider 的默认 Base URL。""" default_base_url = cls._sanitize_base_url(spec.default_base_url) if default_base_url: return default_base_url @@ -1032,6 +1045,7 @@ class LLMProviderManager(metaclass=Singleton): @classmethod def _normalize_provider_id(cls, provider_id: str) -> str: + """规范化 provider_id 以兼容旧版配置。""" normalized = (provider_id or "").strip().lower() if normalized == "minimax-coding": return "minimax" @@ -1043,6 +1057,7 @@ class LLMProviderManager(metaclass=Singleton): def _normalize_base_url_preset_id( cls, provider_id: str, base_url_preset_id: Optional[str] ) -> Optional[str]: + """规范化 Base URL 预设 ID。""" normalized_provider_id = cls._normalize_provider_id(provider_id) normalized_preset_id = str(base_url_preset_id or "").strip().lower() or None if not normalized_preset_id: @@ -1060,6 +1075,7 @@ class LLMProviderManager(metaclass=Singleton): base_url: Optional[str], base_url_preset_id: Optional[str] = None, ) -> Optional[ProviderUrlPreset]: + """根据给定的参数解析出适用的 Base URL 预设。""" normalized_preset_id = cls._normalize_base_url_preset_id(spec.id, base_url_preset_id) if normalized_preset_id: for preset in spec.base_url_presets: @@ -1089,6 +1105,7 @@ class LLMProviderManager(metaclass=Singleton): base_url: Optional[str], base_url_preset_id: Optional[str] = None, ) -> str: + """解析提供商最终适用的 runtime。""" preset = cls._resolve_provider_preset(spec, base_url, base_url_preset_id) return preset.runtime or spec.runtime if preset else spec.runtime @@ -1099,6 +1116,7 @@ class LLMProviderManager(metaclass=Singleton): base_url: Optional[str], base_url_preset_id: Optional[str] = None, ) -> str: + """解析获取模型列表的策略。""" preset = cls._resolve_provider_preset(spec, base_url, base_url_preset_id) return preset.model_list_strategy or spec.model_list_strategy if preset else spec.model_list_strategy @@ -1109,6 +1127,7 @@ class LLMProviderManager(metaclass=Singleton): base_url: Optional[str], base_url_preset_id: Optional[str] = None, ) -> Optional[str]: + """解析用于获取模型列表的 Base URL。""" preset = cls._resolve_provider_preset(spec, base_url, base_url_preset_id) if preset: preset_value = cls._sanitize_base_url(preset.value) @@ -1127,6 +1146,7 @@ class LLMProviderManager(metaclass=Singleton): base_url: Optional[str], base_url_preset_id: Optional[str] = None, ) -> Optional[str]: + """解析对应的 models.dev provider id。""" preset = cls._resolve_provider_preset(spec, base_url, base_url_preset_id) if preset: return preset.models_dev_provider_id or spec.models_dev_provider_id @@ -1143,6 +1163,7 @@ class LLMProviderManager(metaclass=Singleton): base_url: Optional[str], base_url_preset_id: Optional[str] = None, ) -> Optional[str]: + """解析对外暴露的用于获取模型列表的 Base URL。""" spec = self.get_provider(provider_id) return self._resolve_provider_model_list_base_url( spec, @@ -1157,6 +1178,7 @@ class LLMProviderManager(metaclass=Singleton): return "proxy" if "proxy" in params else "proxies" def _build_httpx_kwargs(self) -> dict[str, Any]: + """构造用于 httpx 客户端的参数,如代理等。""" kwargs: dict[str, Any] = {"timeout": self._DEFAULT_TIMEOUT} if settings.PROXY_HOST: kwargs[self._httpx_proxy_key()] = settings.PROXY_HOST @@ -1164,6 +1186,7 @@ class LLMProviderManager(metaclass=Singleton): @staticmethod def _read_agent_config() -> dict[str, Any]: + """读取 AI Agent 配置信息。""" config = SystemConfigOper().get(SystemConfigKey.AIAgentConfig) if isinstance(config, dict): return config @@ -1183,6 +1206,7 @@ class LLMProviderManager(metaclass=Singleton): ) def _get_auth_store(self) -> dict[str, Any]: + """获取所有鉴权数据。""" config = self._read_agent_config() auth_store = config.get("provider_auth") if isinstance(auth_store, dict): @@ -1230,6 +1254,7 @@ class LLMProviderManager(metaclass=Singleton): } async def _load_models_dev_from_disk(self) -> dict[str, Any] | None: + """从磁盘缓存加载 models.dev 数据。""" try: if not self._models_dev_cache_path.exists(): return None @@ -1242,6 +1267,7 @@ class LLMProviderManager(metaclass=Singleton): return None def _load_bundled_models_dev_payload(self) -> dict[str, Any] | None: + """从随代码附带的离线文件加载 models.dev 数据。""" try: if not self._MODELS_DEV_BUNDLED_PATH.exists(): return None @@ -1255,6 +1281,7 @@ class LLMProviderManager(metaclass=Singleton): return payload if isinstance(payload, dict) else None async def _write_models_dev_to_disk(self, payload: dict[str, Any]) -> None: + """将 models.dev 数据写入磁盘缓存。""" try: self._models_dev_cache_path.parent.mkdir(parents=True, exist_ok=True) async with aiofiles.open( @@ -1265,6 +1292,7 @@ class LLMProviderManager(metaclass=Singleton): logger.warning(f"写入 models.dev 缓存失败: {err}") async def _fetch_models_dev(self) -> dict[str, Any]: + """通过网络请求获取最新 models.dev 数据。""" headers = {"User-Agent": "MoviePilot/1.0"} async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client: response = await client.get(self._MODELS_DEV_URL, headers=headers) @@ -1322,6 +1350,7 @@ class LLMProviderManager(metaclass=Singleton): base_url: Optional[str] = None, base_url_preset_id: Optional[str] = None, ) -> dict[str, Any]: + """获取指定 provider 在 models.dev 中的完整负载。""" spec = await self._get_provider_async(provider_id) models_dev_provider_id = self._resolve_provider_models_dev_provider_id( spec, @@ -1339,6 +1368,7 @@ class LLMProviderManager(metaclass=Singleton): base_url: Optional[str] = None, base_url_preset_id: Optional[str] = None, ) -> dict[str, Any] | None: + """获取指定模型的 models.dev 元数据。""" payload = await self._models_dev_provider_payload( provider_id, base_url=base_url, @@ -1435,12 +1465,14 @@ class LLMProviderManager(metaclass=Singleton): } def _normalize_base_url_for_anthropic(self, base_url: str) -> str: + """对 Anthropic 的 Base URL 进行适配处理。""" normalized = self._sanitize_base_url(base_url) or "" if normalized.endswith("/v1"): return normalized[:-3] return normalized async def _list_models_from_google(self, api_key: str) -> list[dict[str, Any]]: + """从 Google AI Studio 获取模型列表。""" from google import genai from google.genai.types import HttpOptions @@ -1479,6 +1511,7 @@ class LLMProviderManager(metaclass=Singleton): base_url: str, default_headers: Optional[dict[str, str]] = None, ) -> list[dict[str, Any]]: + """通过 OpenAI 兼容接口获取模型列表。""" from openai import AsyncOpenAI client = AsyncOpenAI( @@ -1559,6 +1592,7 @@ class LLMProviderManager(metaclass=Singleton): return headers async def _list_models_from_copilot(self, token: str) -> list[dict[str, Any]]: + """从 GitHub Copilot 端点获取模型列表。""" async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client: response = await client.get( "https://api.githubcopilot.com/models", @@ -1628,6 +1662,7 @@ class LLMProviderManager(metaclass=Singleton): base_url: Optional[str] = None, base_url_preset_id: Optional[str] = None, ) -> list[dict[str, Any]]: + """获取开启 OAuth 的 ChatGPT 模型列表。""" # ChatGPT OAuth 仍然是 chatgpt provider 专属能力,但模型目录不再维护 # 一份内部名单,直接跟随当前 provider 对应的 models.dev 数据。 payload = await self._models_dev_provider_payload( @@ -1742,6 +1777,7 @@ class LLMProviderManager(metaclass=Singleton): base_url: Optional[str] = None, base_url_preset_id: Optional[str] = None, ) -> dict[str, Any] | None: + """解析并返回指定模型在 models.dev 中的元数据。""" if not model_id: return None metadata = await self._models_dev_model( @@ -1761,6 +1797,7 @@ class LLMProviderManager(metaclass=Singleton): @staticmethod def _jwt_claims(token: str) -> dict[str, Any]: + """解析 JWT token 内容(不验证签名)。""" try: return jwt.decode(token, options={"verify_signature": False}) except Exception as err: @@ -1769,6 +1806,7 @@ class LLMProviderManager(metaclass=Singleton): @staticmethod def _extract_chatgpt_account_id(token_payload: dict[str, Any]) -> Optional[str]: + """从 ChatGPT 的 Token payload 中提取 account id。""" if token_payload.get("chatgpt_account_id"): return token_payload["chatgpt_account_id"] auth_payload = token_payload.get("https://api.openai.com/auth") or {} @@ -1782,6 +1820,7 @@ class LLMProviderManager(metaclass=Singleton): def _chatgpt_authorize_url( self, redirect_uri: str, challenge: str, state: str ) -> str: + """构建 ChatGPT OAuth 授权链接。""" query = urlencode( { "response_type": "code", @@ -1800,6 +1839,7 @@ class LLMProviderManager(metaclass=Singleton): @staticmethod def _pkce_pair() -> tuple[str, str]: + """生成 PKCE verifier 和 challenge。""" verifier = secrets.token_urlsafe(64).replace("=", "") digest = hashlib.sha256(verifier.encode("utf-8")).digest() challenge = base64.urlsafe_b64encode(digest).decode("utf-8").rstrip("=") @@ -1958,6 +1998,7 @@ class LLMProviderManager(metaclass=Singleton): async def _mark_session_success( self, session: PendingAuthSession, auth_data: dict[str, Any] ) -> None: + """标记授权会话为成功,并保存认证信息。""" auth_data["updated_at"] = int(time.time()) await self.save_auth(session.provider_id, auth_data) session.status = "authorized" @@ -1965,6 +2006,7 @@ class LLMProviderManager(metaclass=Singleton): @staticmethod def _mark_session_error(session: PendingAuthSession, message: str) -> None: + """标记授权会话为失败,并记录错误信息。""" session.status = "failed" session.message = message @@ -2054,6 +2096,7 @@ class LLMProviderManager(metaclass=Singleton): async def _exchange_chatgpt_code_for_tokens( self, code: str, redirect_uri: str, code_verifier: str ) -> dict[str, Any]: + """使用 authorization code 交换 ChatGPT 令牌。""" async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client: response = await client.post( f"{self._CHATGPT_ISSUER}/oauth/token", @@ -2070,6 +2113,7 @@ class LLMProviderManager(metaclass=Singleton): return response.json() async def _refresh_chatgpt_tokens(self, refresh_token: str) -> dict[str, Any]: + """刷新 ChatGPT 的 access_token。""" async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client: response = await client.post( f"{self._CHATGPT_ISSUER}/oauth/token", @@ -2084,6 +2128,7 @@ class LLMProviderManager(metaclass=Singleton): return response.json() async def _poll_chatgpt_device_auth(self, session: PendingAuthSession) -> None: + """轮询 ChatGPT Device Auth 状态。""" async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client: response = await client.post( f"{self._CHATGPT_ISSUER}/api/accounts/deviceauth/token", @@ -2127,6 +2172,7 @@ class LLMProviderManager(metaclass=Singleton): ) async def _poll_copilot_device_auth(self, session: PendingAuthSession) -> None: + """轮询 GitHub Copilot Device Auth 状态。""" async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client: response = await client.post( "https://github.com/login/oauth/access_token", @@ -2172,6 +2218,7 @@ class LLMProviderManager(metaclass=Singleton): raise LLMProviderAuthError(f"GitHub Copilot 授权失败: {error}") async def _resolve_chatgpt_oauth(self) -> dict[str, Any]: + """解析并返回 ChatGPT OAuth 鉴权,支持自动刷新 Token。""" auth = self.get_saved_auth("chatgpt") if not auth or auth.get("type") != "oauth": raise LLMProviderAuthError("尚未完成 ChatGPT Plus/Pro 授权") diff --git a/app/agent/tools/base.py b/app/agent/tools/base.py index 89fb2492..448f9959 100644 --- a/app/agent/tools/base.py +++ b/app/agent/tools/base.py @@ -4,7 +4,7 @@ import threading from abc import ABCMeta, abstractmethod from concurrent.futures import ThreadPoolExecutor from functools import partial -from typing import Any, Callable, Optional +from typing import Any, Callable, ClassVar, Optional from langchain_core.tools import BaseTool from pydantic import PrivateAttr @@ -23,6 +23,56 @@ class ToolChain(ChainBase): pass +# 单个工具结果的兜底上限。各工具仍应优先在自身逻辑中分页或摘要化; +# 这里用于拦截遗漏路径,避免超大结果直接进入模型上下文。 +DEFAULT_TOOL_RESULT_MAX_CHARS = 64 * 1024 +MIN_TOOL_RESULT_PREVIEW_CHARS = 512 + + +def serialize_tool_result_for_agent(result: Any) -> str: + """将工具返回值稳定转换为 Agent 可消费的字符串。""" + if isinstance(result, str): + return result + if isinstance(result, (int, float)): + return str(result) + try: + return json.dumps(result, ensure_ascii=False, indent=2, default=str) + except Exception as e: + logger.warning(f"工具结果转换为JSON失败: {e}, 使用字符串表示") + return str(result) + + +def format_tool_result_for_agent( + result: Any, + *, + tool_name: Optional[str] = None, + max_chars: Optional[int] = DEFAULT_TOOL_RESULT_MAX_CHARS, +) -> str: + """ + 统一格式化工具结果,并在超长时返回结构化预览。 + + 具体工具可以通过 `result_max_chars` 覆盖上限;传入 None 或 <=0 表示不截断。 + """ + formatted_result = serialize_tool_result_for_agent(result) + if not max_chars or max_chars <= 0 or len(formatted_result) <= max_chars: + return formatted_result + + preview_limit = max(MIN_TOOL_RESULT_PREVIEW_CHARS, max_chars) + preview = formatted_result[:preview_limit] + payload = { + "tool_result_truncated": True, + "tool_name": tool_name, + "total_chars": len(formatted_result), + "returned_chars": len(preview), + "content_preview": preview, + "message": ( + f"工具返回内容超过 {max_chars} 字符,已截断为预览;" + "请使用更精确的筛选条件、分页参数或专用查询参数继续获取。" + ), + } + return json.dumps(payload, ensure_ascii=False, indent=2) + + # 将常见的阻塞调用按能力域拆分到独立线程池,避免外部慢 IO 抢占同一批 worker。 _BLOCKING_BUCKET_LIMITS = { "default": 4, @@ -66,6 +116,8 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): MoviePilot专用工具基类(LangChain v1 / langchain_core) """ + result_max_chars: ClassVar[Optional[int]] = DEFAULT_TOOL_RESULT_MAX_CHARS + _session_id: str = PrivateAttr() _user_id: str = PrivateAttr() _channel: Optional[str] = PrivateAttr(default=None) @@ -160,21 +212,16 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): # 执行具体工具逻辑 try: result = await self.run(**kwargs) - logger.debug(f"Tool {self.name} executed with result: {result}") + result_len = len(str(result)) if result is not None else 0 + logger.debug(f"Tool {self.name} executed, raw result length: {result_len}") except Exception as e: error_message = f"工具执行异常 ({type(e).__name__}): {str(e)}" logger.error(f"Tool {self.name} execution failed: {e}", exc_info=True) result = error_message - # 格式化结果 - if isinstance(result, str): - formatted_result = result - elif isinstance(result, (int, float)): - formatted_result = str(result) - else: - formatted_result = json.dumps(result, ensure_ascii=False, indent=2) - - return formatted_result + return format_tool_result_for_agent( + result, tool_name=self.name, max_chars=self.result_max_chars + ) def get_tool_message(self, **kwargs) -> Optional[str]: """ diff --git a/app/agent/tools/impl/_plugin_tool_utils.py b/app/agent/tools/impl/_plugin_tool_utils.py index db77b743..a47509d5 100644 --- a/app/agent/tools/impl/_plugin_tool_utils.py +++ b/app/agent/tools/impl/_plugin_tool_utils.py @@ -15,7 +15,8 @@ DEFAULT_PLUGIN_DATA_PREVIEW_CHARS = 12_000 MAX_PLUGIN_DATA_PREVIEW_CHARS = 50_000 PLUGIN_DATA_KEY_PREVIEW_LIMIT = 50 PLUGIN_DATA_TRUNCATION_SUFFIX = "\n...(插件数据内容过长,已截断)" -DEFAULT_PLUGIN_CANDIDATE_LIMIT = 500 +DEFAULT_PLUGIN_CANDIDATE_LIMIT = 50 +MAX_PLUGIN_CANDIDATE_LIMIT = 200 def get_plugin_snapshot(plugin_id: str) -> Optional[dict[str, Any]]: diff --git a/app/agent/tools/impl/_torrent_search_utils.py b/app/agent/tools/impl/_torrent_search_utils.py index 6aebdb9d..b5a8b3bd 100644 --- a/app/agent/tools/impl/_torrent_search_utils.py +++ b/app/agent/tools/impl/_torrent_search_utils.py @@ -8,7 +8,7 @@ from app.utils.crypto import HashUtils from app.utils.string import StringUtils SEARCH_RESULT_CACHE_FILE = "__search_result__" -TORRENT_RESULT_LIMIT = 200 +TORRENT_RESULT_LIMIT = 50 def build_torrent_ref(context: Optional[Context]) -> str: diff --git a/app/agent/tools/impl/query_installed_plugins.py b/app/agent/tools/impl/query_installed_plugins.py index c090df0d..4b51b75b 100644 --- a/app/agent/tools/impl/query_installed_plugins.py +++ b/app/agent/tools/impl/query_installed_plugins.py @@ -8,6 +8,7 @@ from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool from app.agent.tools.impl._plugin_tool_utils import ( DEFAULT_PLUGIN_CANDIDATE_LIMIT, + MAX_PLUGIN_CANDIDATE_LIMIT, list_installed_plugins, search_plugin_candidates, summarize_candidates, @@ -29,7 +30,7 @@ class QueryInstalledPluginsInput(BaseModel): ) max_results: Optional[int] = Field( DEFAULT_PLUGIN_CANDIDATE_LIMIT, - description="Maximum number of plugins to return. Defaults to 10.", + description="Maximum number of plugins to return. Defaults to 50, capped at 200.", ) @@ -53,7 +54,10 @@ class QueryInstalledPluginsTool(MoviePilotTool): def _clamp_results(max_results: Optional[int]) -> int: if max_results is None: return DEFAULT_PLUGIN_CANDIDATE_LIMIT - return max(1, min(int(max_results), 200)) + try: + return max(1, min(int(max_results), MAX_PLUGIN_CANDIDATE_LIMIT)) + except (TypeError, ValueError): + return DEFAULT_PLUGIN_CANDIDATE_LIMIT async def run( self, diff --git a/app/agent/tools/impl/query_market_plugins.py b/app/agent/tools/impl/query_market_plugins.py index 46875411..950d22b1 100644 --- a/app/agent/tools/impl/query_market_plugins.py +++ b/app/agent/tools/impl/query_market_plugins.py @@ -8,6 +8,7 @@ from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool from app.agent.tools.impl._plugin_tool_utils import ( DEFAULT_PLUGIN_CANDIDATE_LIMIT, + MAX_PLUGIN_CANDIDATE_LIMIT, load_market_plugins, search_plugin_candidates, summarize_candidates, @@ -29,7 +30,7 @@ class QueryMarketPluginsInput(BaseModel): ) max_results: Optional[int] = Field( DEFAULT_PLUGIN_CANDIDATE_LIMIT, - description="Maximum number of plugins to return. Defaults to 10.", + description="Maximum number of plugins to return. Defaults to 50, capped at 200.", ) force_refresh: Optional[bool] = Field( False, @@ -56,7 +57,10 @@ class QueryMarketPluginsTool(MoviePilotTool): def _clamp_results(max_results: Optional[int]) -> int: if max_results is None: return DEFAULT_PLUGIN_CANDIDATE_LIMIT - return max(1, min(int(max_results), 200)) + try: + return max(1, min(int(max_results), MAX_PLUGIN_CANDIDATE_LIMIT)) + except (TypeError, ValueError): + return DEFAULT_PLUGIN_CANDIDATE_LIMIT async def run( self, diff --git a/app/agent/tools/impl/query_media_detail.py b/app/agent/tools/impl/query_media_detail.py index ec0ba405..749624d5 100644 --- a/app/agent/tools/impl/query_media_detail.py +++ b/app/agent/tools/impl/query_media_detail.py @@ -10,6 +10,10 @@ from app.chain.media import MediaChain from app.log import logger from app.schemas.types import MediaType +DIRECTOR_PREVIEW_LIMIT = 10 +ACTOR_PREVIEW_LIMIT = 20 +SEASON_PREVIEW_LIMIT = 100 + class QueryMediaDetailInput(BaseModel): """查询媒体详情工具的输入参数模型""" @@ -64,23 +68,23 @@ class QueryMediaDetailTool(MoviePilotTool): genres = [g.get("name") for g in (mediainfo.genres or []) if g.get("name")] # 精简 directors - 只保留姓名和职位 + director_source = [d for d in (mediainfo.directors or []) if d.get("name")] directors = [ { "name": d.get("name"), "job": d.get("job") } - for d in (mediainfo.directors or []) - if d.get("name") + for d in director_source[:DIRECTOR_PREVIEW_LIMIT] ] # 精简 actors - 只保留姓名和角色 + actor_source = [a for a in (mediainfo.actors or []) if a.get("name")] actors = [ { "name": a.get("name"), "character": a.get("character") } - for a in (mediainfo.actors or []) - if a.get("name") + for a in actor_source[:ACTOR_PREVIEW_LIMIT] ] # 构建基础媒体详情信息 @@ -88,12 +92,20 @@ class QueryMediaDetailTool(MoviePilotTool): "status": mediainfo.status, "genres": genres, "directors": directors, - "actors": actors + "directors_total": len(director_source), + "directors_truncated": len(director_source) > DIRECTOR_PREVIEW_LIMIT, + "actors": actors, + "actors_total": len(actor_source), + "actors_truncated": len(actor_source) > ACTOR_PREVIEW_LIMIT, } # 如果是电视剧,添加电视剧特有信息 if mediainfo.type == MediaType.TV: # 精简 season_info - 只保留基础摘要 + season_source = [ + s for s in (mediainfo.season_info or []) + if s.get("season_number") is not None + ] season_info = [ { "season_number": s.get("season_number"), @@ -101,8 +113,7 @@ class QueryMediaDetailTool(MoviePilotTool): "episode_count": s.get("episode_count"), "air_date": s.get("air_date") } - for s in (mediainfo.season_info or []) - if s.get("season_number") is not None + for s in season_source[:SEASON_PREVIEW_LIMIT] ] result.update({ @@ -110,7 +121,9 @@ class QueryMediaDetailTool(MoviePilotTool): "number_of_episodes": mediainfo.number_of_episodes, "first_air_date": mediainfo.first_air_date, "last_air_date": mediainfo.last_air_date, - "season_info": season_info + "season_info": season_info, + "season_info_total": len(season_source), + "season_info_truncated": len(season_source) > SEASON_PREVIEW_LIMIT, }) return json.dumps(result, ensure_ascii=False, indent=2) diff --git a/app/agent/tools/impl/query_popular_subscribes.py b/app/agent/tools/impl/query_popular_subscribes.py index c2185cd0..203e77b1 100644 --- a/app/agent/tools/impl/query_popular_subscribes.py +++ b/app/agent/tools/impl/query_popular_subscribes.py @@ -12,13 +12,15 @@ from app.helper.subscribe import SubscribeHelper from app.log import logger from app.schemas.types import MediaType, media_type_to_agent +MAX_PAGE_SIZE = 50 + class QueryPopularSubscribesInput(BaseModel): """查询热门订阅工具的输入参数模型""" explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") media_type: str = Field(..., description="Allowed values: movie, tv") page: Optional[int] = Field(1, description="Page number for pagination (default: 1)") - count: Optional[int] = Field(30, description="Number of items per page (default: 30)") + count: Optional[int] = Field(30, description="Number of items per page (default: 30, max: 50)") min_sub: Optional[int] = Field(None, description="Minimum number of subscribers filter (optional, e.g., 5)") genre_id: Optional[int] = Field(None, description="Filter by genre ID (optional)") min_rating: Optional[float] = Field(None, description="Minimum rating filter (optional, e.g., 7.5)") @@ -69,6 +71,8 @@ class QueryPopularSubscribesTool(MoviePilotTool): page = 1 if count is None or count < 1: count = 30 + # 外部统计接口支持传入 count,这里做硬上限,避免 Agent 一次拉取过多结果。 + count = min(count, MAX_PAGE_SIZE) media_type_enum = MediaType.from_agent(media_type) if not media_type_enum: return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'" diff --git a/app/agent/tools/impl/query_site_userdata.py b/app/agent/tools/impl/query_site_userdata.py index 34c47170..5fa0c621 100644 --- a/app/agent/tools/impl/query_site_userdata.py +++ b/app/agent/tools/impl/query_site_userdata.py @@ -11,6 +11,14 @@ from app.db.models.site import Site from app.db.models.siteuserdata import SiteUserData from app.log import logger +SITE_USERDATA_DETAIL_PREVIEW_LIMIT = 10 + + +def _preview_list(value, limit: int = SITE_USERDATA_DETAIL_PREVIEW_LIMIT) -> tuple[list, int, bool]: + """返回列表字段预览,避免做种明细或未读消息一次性撑大工具结果。""" + items = list(value) if isinstance(value, (list, tuple)) else [] + return items[:limit], len(items), len(items) > limit + class QuerySiteUserdataInput(BaseModel): """查询站点用户数据工具的输入参数模型""" @@ -110,6 +118,13 @@ class QuerySiteUserdataTool(MoviePilotTool): else 0 ) + seeding_preview, seeding_count, seeding_truncated = _preview_list( + user_data.seeding_info + ) + unread_preview, unread_count, unread_truncated = _preview_list( + user_data.message_unread_contents + ) + user_data_dict = { "domain": user_data.domain, "name": user_data.name, @@ -131,13 +146,13 @@ class QuerySiteUserdataTool(MoviePilotTool): "seeding_size_gb": round(seeding_size_gb, 2), "leeching_size": user_data.leeching_size, "leeching_size_gb": round(leeching_size_gb, 2), - "seeding_info": user_data.seeding_info - if user_data.seeding_info - else [], + "seeding_info_count": seeding_count, + "seeding_info": seeding_preview, + "seeding_info_truncated": seeding_truncated, "message_unread": user_data.message_unread, - "message_unread_contents": user_data.message_unread_contents - if user_data.message_unread_contents - else [], + "message_unread_contents_count": unread_count, + "message_unread_contents": unread_preview, + "message_unread_contents_truncated": unread_truncated, "err_msg": user_data.err_msg, "updated_day": user_data.updated_day, "updated_time": user_data.updated_time, diff --git a/app/agent/tools/impl/query_subscribe_shares.py b/app/agent/tools/impl/query_subscribe_shares.py index c3fc9dfd..bb0e5d0a 100644 --- a/app/agent/tools/impl/query_subscribe_shares.py +++ b/app/agent/tools/impl/query_subscribe_shares.py @@ -9,13 +9,15 @@ from app.agent.tools.base import MoviePilotTool from app.helper.subscribe import SubscribeHelper from app.log import logger +MAX_PAGE_SIZE = 50 + class QuerySubscribeSharesInput(BaseModel): """查询订阅分享工具的输入参数模型""" explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") name: Optional[str] = Field(None, description="Filter shares by media name (partial match, optional)") page: Optional[int] = Field(1, description="Page number for pagination (default: 1)") - count: Optional[int] = Field(30, description="Number of items per page (default: 30)") + count: Optional[int] = Field(30, description="Number of items per page (default: 30, max: 50)") genre_id: Optional[int] = Field(None, description="Filter by genre ID (optional)") min_rating: Optional[float] = Field(None, description="Minimum rating filter (optional, e.g., 7.5)") max_rating: Optional[float] = Field(None, description="Maximum rating filter (optional, e.g., 10.0)") @@ -63,6 +65,8 @@ class QuerySubscribeSharesTool(MoviePilotTool): page = 1 if count is None or count < 1: count = 30 + # 订阅分享是外部列表型结果,限制单页大小能降低工具上下文占用。 + count = min(count, MAX_PAGE_SIZE) subscribe_helper = SubscribeHelper() shares = await subscribe_helper.async_get_shares( diff --git a/app/agent/tools/impl/query_transfer_history.py b/app/agent/tools/impl/query_transfer_history.py index 5201052c..2163877c 100644 --- a/app/agent/tools/impl/query_transfer_history.py +++ b/app/agent/tools/impl/query_transfer_history.py @@ -62,8 +62,8 @@ class QueryTransferHistoryTool(MoviePilotTool): if page is None or page < 1: page = 1 - # 每页记录数 - count = 50 + # 每页固定 30 条,与工具说明保持一致,避免整理路径等字段撑大上下文。 + count = 30 # 获取数据库会话 async with AsyncSessionFactory() as db: diff --git a/app/agent/tools/impl/query_workflows.py b/app/agent/tools/impl/query_workflows.py index 51ce3370..2003eda9 100644 --- a/app/agent/tools/impl/query_workflows.py +++ b/app/agent/tools/impl/query_workflows.py @@ -115,9 +115,7 @@ class QueryWorkflowsTool(MoviePilotTool): "last_time": wf.last_time, "current_action": wf.current_action } - # 如果有结果,添加结果信息 - if wf.result: - simplified["result"] = wf.result + # wf.result 往往是执行日志或上下文快照,不适合作为列表查询结果返回。 simplified_workflows.append(simplified) result_json = json.dumps(simplified_workflows, ensure_ascii=False, indent=2) diff --git a/app/agent/tools/impl/search_media.py b/app/agent/tools/impl/search_media.py index dabad39a..35086623 100644 --- a/app/agent/tools/impl/search_media.py +++ b/app/agent/tools/impl/search_media.py @@ -73,7 +73,7 @@ class SearchMediaTool(MoviePilotTool): filtered_results.append(result) if filtered_results: - # 限制最多30条结果 + # 搜索结果只返回前 30 条,后续可通过更精确的年份/类型条件缩小范围。 total_count = len(filtered_results) limited_results = filtered_results[:30] # 精简字段,只保留关键信息 @@ -96,8 +96,8 @@ class SearchMediaTool(MoviePilotTool): simplified_results.append(simplified) result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2) # 如果结果被裁剪,添加提示信息 - if total_count > 100: - return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 100 条结果。\n\n{result_json}" + if total_count > len(limited_results): + return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 {len(limited_results)} 条结果。\n\n{result_json}" return result_json else: return f"未找到符合条件的媒体资源: {title}" diff --git a/app/agent/tools/impl/search_person.py b/app/agent/tools/impl/search_person.py index 01acbdd8..1c3974c0 100644 --- a/app/agent/tools/impl/search_person.py +++ b/app/agent/tools/impl/search_person.py @@ -35,7 +35,7 @@ class SearchPersonTool(MoviePilotTool): persons = await media_chain.async_search_persons(name=name) if persons: - # 限制最多30条结果 + # 人物搜索结果只返回前 30 条,避免 biography/别名等字段挤占上下文。 total_count = len(persons) limited_persons = persons[:30] # 精简字段,只保留关键信息 @@ -72,8 +72,8 @@ class SearchPersonTool(MoviePilotTool): result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2) # 如果结果被裁剪,添加提示信息 - if total_count > 50: - return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 50 条结果。\n\n{result_json}" + if total_count > len(limited_persons): + return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 {len(limited_persons)} 条结果。\n\n{result_json}" return result_json else: return f"未找到相关人物信息: {name}" diff --git a/app/agent/tools/impl/search_web.py b/app/agent/tools/impl/search_web.py index 9c83af00..fbd551a4 100644 --- a/app/agent/tools/impl/search_web.py +++ b/app/agent/tools/impl/search_web.py @@ -28,7 +28,7 @@ class SearchWebInput(BaseModel): ) max_results: Optional[int] = Field( 20, - description="Maximum number of search results to return (default: 5, max: 10)", + description="Maximum number of search results to return (default: 20, max: 20)", ) diff --git a/app/agent/tools/manager.py b/app/agent/tools/manager.py index 5ddf7577..e1f2981b 100644 --- a/app/agent/tools/manager.py +++ b/app/agent/tools/manager.py @@ -2,6 +2,7 @@ import json import uuid from typing import Any, Dict, List, Optional +from app.agent.tools.base import format_tool_result_for_agent from app.agent.tools.factory import MoviePilotToolFactory from app.log import logger @@ -237,22 +238,14 @@ class MoviePilotToolsManager: # 规范化参数类型 normalized_arguments = self._normalize_arguments(tool_instance, arguments) - # 调用工具的run方法 + # 调用工具的run方法。HTTP/MCP 工具调用不会经过 BaseTool._arun, + # 因此这里也必须复用同一套返回值格式化和兜底截断逻辑。 result = await tool_instance.run(**normalized_arguments) - - # 确保返回字符串 - if isinstance(result, str): - formated_result = result - elif isinstance(result, (int, float)): - formated_result = str(result) - else: - try: - formated_result = json.dumps(result, ensure_ascii=False, indent=2) - except Exception as e: - logger.warning(f"结果转换为JSON失败: {e}, 使用字符串表示") - formated_result = str(result) - - return formated_result + return format_tool_result_for_agent( + result, + tool_name=tool_name, + max_chars=getattr(tool_instance, "result_max_chars", None), + ) except Exception as e: logger.error(f"调用工具 {tool_name} 时发生错误: {e}", exc_info=True) error_msg = json.dumps( diff --git a/tests/test_agent_query_workflows_tool.py b/tests/test_agent_query_workflows_tool.py new file mode 100644 index 00000000..687eb084 --- /dev/null +++ b/tests/test_agent_query_workflows_tool.py @@ -0,0 +1,56 @@ +import asyncio +import json +import unittest +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +from app.agent.tools.impl.query_workflows import QueryWorkflowsTool + + +class _AsyncSessionContext: + """为工作流查询工具提供最小异步 DB 上下文。""" + + async def __aenter__(self): + return object() + + async def __aexit__(self, exc_type, exc, tb): + return False + + +class TestQueryWorkflowsTool(unittest.TestCase): + def test_query_workflows_omits_large_result_field(self): + tool = QueryWorkflowsTool(session_id="session-1", user_id="10001") + workflow = SimpleNamespace( + id=1, + name="demo", + description="demo workflow", + state="S", + trigger_type="manual", + run_count=1, + timer=None, + event_type=None, + add_time="2026-05-08 10:00:00", + last_time="2026-05-08 10:01:00", + current_action=None, + result="x" * 10000, + ) + workflow_oper = MagicMock() + workflow_oper.async_list = AsyncMock(return_value=[workflow]) + + with patch( + "app.agent.tools.impl.query_workflows.AsyncSessionFactory", + return_value=_AsyncSessionContext(), + ), patch( + "app.agent.tools.impl.query_workflows.WorkflowOper", + return_value=workflow_oper, + ): + result = asyncio.run(tool.run()) + + payload = json.loads(result) + self.assertEqual(len(payload), 1) + self.assertEqual(payload[0]["name"], "demo") + self.assertNotIn("result", payload[0]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_agent_tool_result_limits.py b/tests/test_agent_tool_result_limits.py new file mode 100644 index 00000000..ac576de8 --- /dev/null +++ b/tests/test_agent_tool_result_limits.py @@ -0,0 +1,51 @@ +import asyncio +import json +import unittest + +from app.agent.tools.base import ( + DEFAULT_TOOL_RESULT_MAX_CHARS, + MoviePilotTool, + format_tool_result_for_agent, +) + + +class OversizedResultTool(MoviePilotTool): + name: str = "oversized_result_tool" + description: str = "Tool used to verify result truncation." + + async def run(self, **kwargs) -> str: + return "x" * (DEFAULT_TOOL_RESULT_MAX_CHARS + 100) + + +class TestAgentToolResultLimits(unittest.TestCase): + def test_arun_truncates_oversized_tool_result(self): + tool = OversizedResultTool(session_id="session-1", user_id="10001") + + result = asyncio.run(tool._arun()) + payload = json.loads(result) + + self.assertTrue(payload["tool_result_truncated"]) + self.assertEqual(payload["tool_name"], "oversized_result_tool") + self.assertEqual(payload["returned_chars"], DEFAULT_TOOL_RESULT_MAX_CHARS) + self.assertGreater(payload["total_chars"], payload["returned_chars"]) + + def test_formatter_preserves_sensitive_json_fields_for_agent_use(self): + result = format_tool_result_for_agent( + { + "cookie": "uid=abc; token=secret", + "nested": { + "api_key": "secret-key", + "plugin_author": "MoviePilot", + }, + }, + tool_name="sensitive_tool", + ) + payload = json.loads(result) + + self.assertEqual(payload["cookie"], "uid=abc; token=secret") + self.assertEqual(payload["nested"]["api_key"], "secret-key") + self.assertEqual(payload["nested"]["plugin_author"], "MoviePilot") + + +if __name__ == "__main__": + unittest.main()