refactor: adjust default and maximum limits for plugin candidates and torrent results; enhance result formatting for agents

This commit is contained in:
jxxghp
2026-05-08 14:47:20 +08:00
parent 0a0d5e6da2
commit 14b366a648
18 changed files with 297 additions and 60 deletions

View File

@@ -172,6 +172,7 @@ class LLMProviderManager(metaclass=Singleton):
}
def __init__(self):
"""初始化管理器实例及各类锁和缓存变量。"""
self._lock = threading.RLock()
self._models_dev_lock = asyncio.Lock()
self._pending_sessions: dict[str, PendingAuthSession] = {}
@@ -749,6 +750,7 @@ class LLMProviderManager(metaclass=Singleton):
return tuple(providers)
def _cached_models_dev_payload(self) -> dict[str, Any]:
"""获取缓存的 models.dev payload。"""
if isinstance(self._models_dev_data, dict):
return self._models_dev_data
@@ -772,6 +774,7 @@ class LLMProviderManager(metaclass=Singleton):
@staticmethod
def _models_dev_env_names(payload: dict[str, Any]) -> tuple[str, ...]:
"""从 models.dev 数据中提取支持的环境变量名。"""
raw_env_names = payload.get("env")
if not isinstance(raw_env_names, list):
return ()
@@ -786,6 +789,7 @@ class LLMProviderManager(metaclass=Singleton):
def _models_dev_reserved_provider_ids(
cls, specs: tuple[ProviderSpec, ...]
) -> set[str]:
"""获取所有已保留的 models_dev_provider_id 集合。"""
reserved_ids: set[str] = set()
for spec in specs:
if spec.models_dev_provider_id:
@@ -797,6 +801,7 @@ class LLMProviderManager(metaclass=Singleton):
@staticmethod
def _dynamic_api_key_label(env_names: tuple[str, ...]) -> str:
"""根据环境变量名动态推断 API Key 标签名称。"""
first_env = env_names[0].upper() if env_names else ""
if "TOKEN" in first_env and "KEY" not in first_env:
return "API Token"
@@ -806,6 +811,7 @@ class LLMProviderManager(metaclass=Singleton):
def _normalize_models_dev_base_url(
cls, runtime: str, base_url: Optional[str]
) -> Optional[str]:
"""规范化从 models.dev 获取的 Base URL。"""
normalized = cls._sanitize_base_url(base_url)
if not normalized:
return None
@@ -837,6 +843,7 @@ class LLMProviderManager(metaclass=Singleton):
payload: dict[str, Any],
sort_order: int,
) -> ProviderSpec | None:
"""根据 models.dev 数据动态生成 ProviderSpec 实例。"""
normalized_id = str(provider_id or "").strip().lower()
if not normalized_id or normalized_id in cls._MODELS_DEV_DYNAMIC_SKIP_IDS:
return None
@@ -899,6 +906,7 @@ class LLMProviderManager(metaclass=Singleton):
def _dynamic_provider_specs(
self, builtin_specs: tuple[ProviderSpec, ...]
) -> tuple[ProviderSpec, ...]:
"""获取从 models.dev 动态加载的所有 ProviderSpec 实例。"""
payload = self._cached_models_dev_payload()
if not payload:
return ()
@@ -939,12 +947,14 @@ class LLMProviderManager(metaclass=Singleton):
return tuple(dynamic_specs)
def _provider_specs(self) -> tuple[ProviderSpec, ...]:
"""获取所有支持的 ProviderSpec包括内置和动态加载的。"""
builtin_specs = self._builtin_provider_specs()
return builtin_specs + self._dynamic_provider_specs(builtin_specs)
async def _get_provider_async(
self, provider_id: str, force_refresh: bool = False
) -> ProviderSpec:
"""异步获取指定 provider 的 ProviderSpec 实例。"""
normalized_provider_id = self._normalize_provider_id(provider_id)
try:
return self.get_provider(normalized_provider_id)
@@ -953,6 +963,7 @@ class LLMProviderManager(metaclass=Singleton):
return self.get_provider(normalized_provider_id)
def _serialize_provider(self, spec: ProviderSpec) -> dict[str, Any]:
"""将 ProviderSpec 序列化为前端可用的字典。"""
return {
"id": spec.id,
"name": spec.name,
@@ -1014,6 +1025,7 @@ class LLMProviderManager(metaclass=Singleton):
@staticmethod
def _sanitize_base_url(base_url: Optional[str]) -> Optional[str]:
"""清理 Base URL 中多余的空格和结尾斜杠。"""
if base_url is None:
return None
value = str(base_url).strip()
@@ -1023,6 +1035,7 @@ class LLMProviderManager(metaclass=Singleton):
@classmethod
def _default_base_url_for_provider(cls, spec: ProviderSpec) -> Optional[str]:
"""获取 provider 的默认 Base URL。"""
default_base_url = cls._sanitize_base_url(spec.default_base_url)
if default_base_url:
return default_base_url
@@ -1032,6 +1045,7 @@ class LLMProviderManager(metaclass=Singleton):
@classmethod
def _normalize_provider_id(cls, provider_id: str) -> str:
"""规范化 provider_id 以兼容旧版配置。"""
normalized = (provider_id or "").strip().lower()
if normalized == "minimax-coding":
return "minimax"
@@ -1043,6 +1057,7 @@ class LLMProviderManager(metaclass=Singleton):
def _normalize_base_url_preset_id(
cls, provider_id: str, base_url_preset_id: Optional[str]
) -> Optional[str]:
"""规范化 Base URL 预设 ID。"""
normalized_provider_id = cls._normalize_provider_id(provider_id)
normalized_preset_id = str(base_url_preset_id or "").strip().lower() or None
if not normalized_preset_id:
@@ -1060,6 +1075,7 @@ class LLMProviderManager(metaclass=Singleton):
base_url: Optional[str],
base_url_preset_id: Optional[str] = None,
) -> Optional[ProviderUrlPreset]:
"""根据给定的参数解析出适用的 Base URL 预设。"""
normalized_preset_id = cls._normalize_base_url_preset_id(spec.id, base_url_preset_id)
if normalized_preset_id:
for preset in spec.base_url_presets:
@@ -1089,6 +1105,7 @@ class LLMProviderManager(metaclass=Singleton):
base_url: Optional[str],
base_url_preset_id: Optional[str] = None,
) -> str:
"""解析提供商最终适用的 runtime。"""
preset = cls._resolve_provider_preset(spec, base_url, base_url_preset_id)
return preset.runtime or spec.runtime if preset else spec.runtime
@@ -1099,6 +1116,7 @@ class LLMProviderManager(metaclass=Singleton):
base_url: Optional[str],
base_url_preset_id: Optional[str] = None,
) -> str:
"""解析获取模型列表的策略。"""
preset = cls._resolve_provider_preset(spec, base_url, base_url_preset_id)
return preset.model_list_strategy or spec.model_list_strategy if preset else spec.model_list_strategy
@@ -1109,6 +1127,7 @@ class LLMProviderManager(metaclass=Singleton):
base_url: Optional[str],
base_url_preset_id: Optional[str] = None,
) -> Optional[str]:
"""解析用于获取模型列表的 Base URL。"""
preset = cls._resolve_provider_preset(spec, base_url, base_url_preset_id)
if preset:
preset_value = cls._sanitize_base_url(preset.value)
@@ -1127,6 +1146,7 @@ class LLMProviderManager(metaclass=Singleton):
base_url: Optional[str],
base_url_preset_id: Optional[str] = None,
) -> Optional[str]:
"""解析对应的 models.dev provider id。"""
preset = cls._resolve_provider_preset(spec, base_url, base_url_preset_id)
if preset:
return preset.models_dev_provider_id or spec.models_dev_provider_id
@@ -1143,6 +1163,7 @@ class LLMProviderManager(metaclass=Singleton):
base_url: Optional[str],
base_url_preset_id: Optional[str] = None,
) -> Optional[str]:
"""解析对外暴露的用于获取模型列表的 Base URL。"""
spec = self.get_provider(provider_id)
return self._resolve_provider_model_list_base_url(
spec,
@@ -1157,6 +1178,7 @@ class LLMProviderManager(metaclass=Singleton):
return "proxy" if "proxy" in params else "proxies"
def _build_httpx_kwargs(self) -> dict[str, Any]:
"""构造用于 httpx 客户端的参数,如代理等。"""
kwargs: dict[str, Any] = {"timeout": self._DEFAULT_TIMEOUT}
if settings.PROXY_HOST:
kwargs[self._httpx_proxy_key()] = settings.PROXY_HOST
@@ -1164,6 +1186,7 @@ class LLMProviderManager(metaclass=Singleton):
@staticmethod
def _read_agent_config() -> dict[str, Any]:
"""读取 AI Agent 配置信息。"""
config = SystemConfigOper().get(SystemConfigKey.AIAgentConfig)
if isinstance(config, dict):
return config
@@ -1183,6 +1206,7 @@ class LLMProviderManager(metaclass=Singleton):
)
def _get_auth_store(self) -> dict[str, Any]:
"""获取所有鉴权数据。"""
config = self._read_agent_config()
auth_store = config.get("provider_auth")
if isinstance(auth_store, dict):
@@ -1230,6 +1254,7 @@ class LLMProviderManager(metaclass=Singleton):
}
async def _load_models_dev_from_disk(self) -> dict[str, Any] | None:
"""从磁盘缓存加载 models.dev 数据。"""
try:
if not self._models_dev_cache_path.exists():
return None
@@ -1242,6 +1267,7 @@ class LLMProviderManager(metaclass=Singleton):
return None
def _load_bundled_models_dev_payload(self) -> dict[str, Any] | None:
"""从随代码附带的离线文件加载 models.dev 数据。"""
try:
if not self._MODELS_DEV_BUNDLED_PATH.exists():
return None
@@ -1255,6 +1281,7 @@ class LLMProviderManager(metaclass=Singleton):
return payload if isinstance(payload, dict) else None
async def _write_models_dev_to_disk(self, payload: dict[str, Any]) -> None:
"""将 models.dev 数据写入磁盘缓存。"""
try:
self._models_dev_cache_path.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(
@@ -1265,6 +1292,7 @@ class LLMProviderManager(metaclass=Singleton):
logger.warning(f"写入 models.dev 缓存失败: {err}")
async def _fetch_models_dev(self) -> dict[str, Any]:
"""通过网络请求获取最新 models.dev 数据。"""
headers = {"User-Agent": "MoviePilot/1.0"}
async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client:
response = await client.get(self._MODELS_DEV_URL, headers=headers)
@@ -1322,6 +1350,7 @@ class LLMProviderManager(metaclass=Singleton):
base_url: Optional[str] = None,
base_url_preset_id: Optional[str] = None,
) -> dict[str, Any]:
"""获取指定 provider 在 models.dev 中的完整负载。"""
spec = await self._get_provider_async(provider_id)
models_dev_provider_id = self._resolve_provider_models_dev_provider_id(
spec,
@@ -1339,6 +1368,7 @@ class LLMProviderManager(metaclass=Singleton):
base_url: Optional[str] = None,
base_url_preset_id: Optional[str] = None,
) -> dict[str, Any] | None:
"""获取指定模型的 models.dev 元数据。"""
payload = await self._models_dev_provider_payload(
provider_id,
base_url=base_url,
@@ -1435,12 +1465,14 @@ class LLMProviderManager(metaclass=Singleton):
}
def _normalize_base_url_for_anthropic(self, base_url: str) -> str:
"""对 Anthropic 的 Base URL 进行适配处理。"""
normalized = self._sanitize_base_url(base_url) or ""
if normalized.endswith("/v1"):
return normalized[:-3]
return normalized
async def _list_models_from_google(self, api_key: str) -> list[dict[str, Any]]:
"""从 Google AI Studio 获取模型列表。"""
from google import genai
from google.genai.types import HttpOptions
@@ -1479,6 +1511,7 @@ class LLMProviderManager(metaclass=Singleton):
base_url: str,
default_headers: Optional[dict[str, str]] = None,
) -> list[dict[str, Any]]:
"""通过 OpenAI 兼容接口获取模型列表。"""
from openai import AsyncOpenAI
client = AsyncOpenAI(
@@ -1559,6 +1592,7 @@ class LLMProviderManager(metaclass=Singleton):
return headers
async def _list_models_from_copilot(self, token: str) -> list[dict[str, Any]]:
"""从 GitHub Copilot 端点获取模型列表。"""
async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client:
response = await client.get(
"https://api.githubcopilot.com/models",
@@ -1628,6 +1662,7 @@ class LLMProviderManager(metaclass=Singleton):
base_url: Optional[str] = None,
base_url_preset_id: Optional[str] = None,
) -> list[dict[str, Any]]:
"""获取开启 OAuth 的 ChatGPT 模型列表。"""
# ChatGPT OAuth 仍然是 chatgpt provider 专属能力,但模型目录不再维护
# 一份内部名单,直接跟随当前 provider 对应的 models.dev 数据。
payload = await self._models_dev_provider_payload(
@@ -1742,6 +1777,7 @@ class LLMProviderManager(metaclass=Singleton):
base_url: Optional[str] = None,
base_url_preset_id: Optional[str] = None,
) -> dict[str, Any] | None:
"""解析并返回指定模型在 models.dev 中的元数据。"""
if not model_id:
return None
metadata = await self._models_dev_model(
@@ -1761,6 +1797,7 @@ class LLMProviderManager(metaclass=Singleton):
@staticmethod
def _jwt_claims(token: str) -> dict[str, Any]:
"""解析 JWT token 内容(不验证签名)。"""
try:
return jwt.decode(token, options={"verify_signature": False})
except Exception as err:
@@ -1769,6 +1806,7 @@ class LLMProviderManager(metaclass=Singleton):
@staticmethod
def _extract_chatgpt_account_id(token_payload: dict[str, Any]) -> Optional[str]:
"""从 ChatGPT 的 Token payload 中提取 account id。"""
if token_payload.get("chatgpt_account_id"):
return token_payload["chatgpt_account_id"]
auth_payload = token_payload.get("https://api.openai.com/auth") or {}
@@ -1782,6 +1820,7 @@ class LLMProviderManager(metaclass=Singleton):
def _chatgpt_authorize_url(
self, redirect_uri: str, challenge: str, state: str
) -> str:
"""构建 ChatGPT OAuth 授权链接。"""
query = urlencode(
{
"response_type": "code",
@@ -1800,6 +1839,7 @@ class LLMProviderManager(metaclass=Singleton):
@staticmethod
def _pkce_pair() -> tuple[str, str]:
"""生成 PKCE verifier 和 challenge。"""
verifier = secrets.token_urlsafe(64).replace("=", "")
digest = hashlib.sha256(verifier.encode("utf-8")).digest()
challenge = base64.urlsafe_b64encode(digest).decode("utf-8").rstrip("=")
@@ -1958,6 +1998,7 @@ class LLMProviderManager(metaclass=Singleton):
async def _mark_session_success(
self, session: PendingAuthSession, auth_data: dict[str, Any]
) -> None:
"""标记授权会话为成功,并保存认证信息。"""
auth_data["updated_at"] = int(time.time())
await self.save_auth(session.provider_id, auth_data)
session.status = "authorized"
@@ -1965,6 +2006,7 @@ class LLMProviderManager(metaclass=Singleton):
@staticmethod
def _mark_session_error(session: PendingAuthSession, message: str) -> None:
"""标记授权会话为失败,并记录错误信息。"""
session.status = "failed"
session.message = message
@@ -2054,6 +2096,7 @@ class LLMProviderManager(metaclass=Singleton):
async def _exchange_chatgpt_code_for_tokens(
self, code: str, redirect_uri: str, code_verifier: str
) -> dict[str, Any]:
"""使用 authorization code 交换 ChatGPT 令牌。"""
async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client:
response = await client.post(
f"{self._CHATGPT_ISSUER}/oauth/token",
@@ -2070,6 +2113,7 @@ class LLMProviderManager(metaclass=Singleton):
return response.json()
async def _refresh_chatgpt_tokens(self, refresh_token: str) -> dict[str, Any]:
"""刷新 ChatGPT 的 access_token。"""
async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client:
response = await client.post(
f"{self._CHATGPT_ISSUER}/oauth/token",
@@ -2084,6 +2128,7 @@ class LLMProviderManager(metaclass=Singleton):
return response.json()
async def _poll_chatgpt_device_auth(self, session: PendingAuthSession) -> None:
"""轮询 ChatGPT Device Auth 状态。"""
async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client:
response = await client.post(
f"{self._CHATGPT_ISSUER}/api/accounts/deviceauth/token",
@@ -2127,6 +2172,7 @@ class LLMProviderManager(metaclass=Singleton):
)
async def _poll_copilot_device_auth(self, session: PendingAuthSession) -> None:
"""轮询 GitHub Copilot Device Auth 状态。"""
async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client:
response = await client.post(
"https://github.com/login/oauth/access_token",
@@ -2172,6 +2218,7 @@ class LLMProviderManager(metaclass=Singleton):
raise LLMProviderAuthError(f"GitHub Copilot 授权失败: {error}")
async def _resolve_chatgpt_oauth(self) -> dict[str, Any]:
"""解析并返回 ChatGPT OAuth 鉴权,支持自动刷新 Token。"""
auth = self.get_saved_auth("chatgpt")
if not auth or auth.get("type") != "oauth":
raise LLMProviderAuthError("尚未完成 ChatGPT Plus/Pro 授权")

View File

@@ -4,7 +4,7 @@ import threading
from abc import ABCMeta, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Any, Callable, Optional
from typing import Any, Callable, ClassVar, Optional
from langchain_core.tools import BaseTool
from pydantic import PrivateAttr
@@ -23,6 +23,56 @@ class ToolChain(ChainBase):
pass
# Fallback ceiling for a single tool result. Individual tools should still
# paginate or summarise in their own logic first; this guard only catches
# missed paths so oversized results never reach the model context directly.
DEFAULT_TOOL_RESULT_MAX_CHARS = 64 * 1024
MIN_TOOL_RESULT_PREVIEW_CHARS = 512


def serialize_tool_result_for_agent(result: Any) -> str:
    """Convert an arbitrary tool return value into a stable agent-readable string.

    Strings pass through untouched, plain numbers are stringified, and any
    other value is rendered as pretty-printed JSON, falling back to ``str``
    when JSON serialization fails.
    """
    if isinstance(result, str):
        return result
    if isinstance(result, (int, float)):
        return str(result)
    try:
        serialized = json.dumps(result, ensure_ascii=False, indent=2, default=str)
    except Exception as e:
        logger.warning(f"工具结果转换为JSON失败: {e}, 使用字符串表示")
        return str(result)
    return serialized


def format_tool_result_for_agent(
    result: Any,
    *,
    tool_name: Optional[str] = None,
    max_chars: Optional[int] = DEFAULT_TOOL_RESULT_MAX_CHARS,
) -> str:
    """Format a tool result uniformly, returning a structured preview when oversized.

    Individual tools may override the ceiling via ``result_max_chars``;
    passing ``None`` or a non-positive value disables truncation entirely.
    """
    text = serialize_tool_result_for_agent(result)
    unlimited = not max_chars or max_chars <= 0
    if unlimited or len(text) <= max_chars:
        return text
    # Never shrink the preview below the minimum, even for tiny ceilings.
    cutoff = max(MIN_TOOL_RESULT_PREVIEW_CHARS, max_chars)
    preview = text[:cutoff]
    envelope = {
        "tool_result_truncated": True,
        "tool_name": tool_name,
        "total_chars": len(text),
        "returned_chars": len(preview),
        "content_preview": preview,
        "message": (
            f"工具返回内容超过 {max_chars} 字符,已截断为预览;"
            "请使用更精确的筛选条件、分页参数或专用查询参数继续获取。"
        ),
    }
    return json.dumps(envelope, ensure_ascii=False, indent=2)
# 将常见的阻塞调用按能力域拆分到独立线程池,避免外部慢 IO 抢占同一批 worker。
_BLOCKING_BUCKET_LIMITS = {
"default": 4,
@@ -66,6 +116,8 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
MoviePilot专用工具基类LangChain v1 / langchain_core
"""
result_max_chars: ClassVar[Optional[int]] = DEFAULT_TOOL_RESULT_MAX_CHARS
_session_id: str = PrivateAttr()
_user_id: str = PrivateAttr()
_channel: Optional[str] = PrivateAttr(default=None)
@@ -160,21 +212,16 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
# 执行具体工具逻辑
try:
result = await self.run(**kwargs)
logger.debug(f"Tool {self.name} executed with result: {result}")
result_len = len(str(result)) if result is not None else 0
logger.debug(f"Tool {self.name} executed, raw result length: {result_len}")
except Exception as e:
error_message = f"工具执行异常 ({type(e).__name__}): {str(e)}"
logger.error(f"Tool {self.name} execution failed: {e}", exc_info=True)
result = error_message
# 格式化结果
if isinstance(result, str):
formatted_result = result
elif isinstance(result, (int, float)):
formatted_result = str(result)
else:
formatted_result = json.dumps(result, ensure_ascii=False, indent=2)
return formatted_result
return format_tool_result_for_agent(
result, tool_name=self.name, max_chars=self.result_max_chars
)
def get_tool_message(self, **kwargs) -> Optional[str]:
"""

View File

@@ -15,7 +15,8 @@ DEFAULT_PLUGIN_DATA_PREVIEW_CHARS = 12_000
MAX_PLUGIN_DATA_PREVIEW_CHARS = 50_000
PLUGIN_DATA_KEY_PREVIEW_LIMIT = 50
PLUGIN_DATA_TRUNCATION_SUFFIX = "\n...(插件数据内容过长,已截断)"
DEFAULT_PLUGIN_CANDIDATE_LIMIT = 500
DEFAULT_PLUGIN_CANDIDATE_LIMIT = 50
MAX_PLUGIN_CANDIDATE_LIMIT = 200
def get_plugin_snapshot(plugin_id: str) -> Optional[dict[str, Any]]:

View File

@@ -8,7 +8,7 @@ from app.utils.crypto import HashUtils
from app.utils.string import StringUtils
SEARCH_RESULT_CACHE_FILE = "__search_result__"
TORRENT_RESULT_LIMIT = 200
TORRENT_RESULT_LIMIT = 50
def build_torrent_ref(context: Optional[Context]) -> str:

View File

@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.agent.tools.impl._plugin_tool_utils import (
DEFAULT_PLUGIN_CANDIDATE_LIMIT,
MAX_PLUGIN_CANDIDATE_LIMIT,
list_installed_plugins,
search_plugin_candidates,
summarize_candidates,
@@ -29,7 +30,7 @@ class QueryInstalledPluginsInput(BaseModel):
)
max_results: Optional[int] = Field(
DEFAULT_PLUGIN_CANDIDATE_LIMIT,
description="Maximum number of plugins to return. Defaults to 10.",
description="Maximum number of plugins to return. Defaults to 50, capped at 200.",
)
@@ -53,7 +54,10 @@ class QueryInstalledPluginsTool(MoviePilotTool):
def _clamp_results(max_results: Optional[int]) -> int:
if max_results is None:
return DEFAULT_PLUGIN_CANDIDATE_LIMIT
return max(1, min(int(max_results), 200))
try:
return max(1, min(int(max_results), MAX_PLUGIN_CANDIDATE_LIMIT))
except (TypeError, ValueError):
return DEFAULT_PLUGIN_CANDIDATE_LIMIT
async def run(
self,

View File

@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.agent.tools.impl._plugin_tool_utils import (
DEFAULT_PLUGIN_CANDIDATE_LIMIT,
MAX_PLUGIN_CANDIDATE_LIMIT,
load_market_plugins,
search_plugin_candidates,
summarize_candidates,
@@ -29,7 +30,7 @@ class QueryMarketPluginsInput(BaseModel):
)
max_results: Optional[int] = Field(
DEFAULT_PLUGIN_CANDIDATE_LIMIT,
description="Maximum number of plugins to return. Defaults to 10.",
description="Maximum number of plugins to return. Defaults to 50, capped at 200.",
)
force_refresh: Optional[bool] = Field(
False,
@@ -56,7 +57,10 @@ class QueryMarketPluginsTool(MoviePilotTool):
def _clamp_results(max_results: Optional[int]) -> int:
if max_results is None:
return DEFAULT_PLUGIN_CANDIDATE_LIMIT
return max(1, min(int(max_results), 200))
try:
return max(1, min(int(max_results), MAX_PLUGIN_CANDIDATE_LIMIT))
except (TypeError, ValueError):
return DEFAULT_PLUGIN_CANDIDATE_LIMIT
async def run(
self,

View File

@@ -10,6 +10,10 @@ from app.chain.media import MediaChain
from app.log import logger
from app.schemas.types import MediaType
DIRECTOR_PREVIEW_LIMIT = 10
ACTOR_PREVIEW_LIMIT = 20
SEASON_PREVIEW_LIMIT = 100
class QueryMediaDetailInput(BaseModel):
"""查询媒体详情工具的输入参数模型"""
@@ -64,23 +68,23 @@ class QueryMediaDetailTool(MoviePilotTool):
genres = [g.get("name") for g in (mediainfo.genres or []) if g.get("name")]
# 精简 directors - 只保留姓名和职位
director_source = [d for d in (mediainfo.directors or []) if d.get("name")]
directors = [
{
"name": d.get("name"),
"job": d.get("job")
}
for d in (mediainfo.directors or [])
if d.get("name")
for d in director_source[:DIRECTOR_PREVIEW_LIMIT]
]
# 精简 actors - 只保留姓名和角色
actor_source = [a for a in (mediainfo.actors or []) if a.get("name")]
actors = [
{
"name": a.get("name"),
"character": a.get("character")
}
for a in (mediainfo.actors or [])
if a.get("name")
for a in actor_source[:ACTOR_PREVIEW_LIMIT]
]
# 构建基础媒体详情信息
@@ -88,12 +92,20 @@ class QueryMediaDetailTool(MoviePilotTool):
"status": mediainfo.status,
"genres": genres,
"directors": directors,
"actors": actors
"directors_total": len(director_source),
"directors_truncated": len(director_source) > DIRECTOR_PREVIEW_LIMIT,
"actors": actors,
"actors_total": len(actor_source),
"actors_truncated": len(actor_source) > ACTOR_PREVIEW_LIMIT,
}
# 如果是电视剧,添加电视剧特有信息
if mediainfo.type == MediaType.TV:
# 精简 season_info - 只保留基础摘要
season_source = [
s for s in (mediainfo.season_info or [])
if s.get("season_number") is not None
]
season_info = [
{
"season_number": s.get("season_number"),
@@ -101,8 +113,7 @@ class QueryMediaDetailTool(MoviePilotTool):
"episode_count": s.get("episode_count"),
"air_date": s.get("air_date")
}
for s in (mediainfo.season_info or [])
if s.get("season_number") is not None
for s in season_source[:SEASON_PREVIEW_LIMIT]
]
result.update({
@@ -110,7 +121,9 @@ class QueryMediaDetailTool(MoviePilotTool):
"number_of_episodes": mediainfo.number_of_episodes,
"first_air_date": mediainfo.first_air_date,
"last_air_date": mediainfo.last_air_date,
"season_info": season_info
"season_info": season_info,
"season_info_total": len(season_source),
"season_info_truncated": len(season_source) > SEASON_PREVIEW_LIMIT,
})
return json.dumps(result, ensure_ascii=False, indent=2)

View File

@@ -12,13 +12,15 @@ from app.helper.subscribe import SubscribeHelper
from app.log import logger
from app.schemas.types import MediaType, media_type_to_agent
MAX_PAGE_SIZE = 50
class QueryPopularSubscribesInput(BaseModel):
"""查询热门订阅工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
media_type: str = Field(..., description="Allowed values: movie, tv")
page: Optional[int] = Field(1, description="Page number for pagination (default: 1)")
count: Optional[int] = Field(30, description="Number of items per page (default: 30)")
count: Optional[int] = Field(30, description="Number of items per page (default: 30, max: 50)")
min_sub: Optional[int] = Field(None, description="Minimum number of subscribers filter (optional, e.g., 5)")
genre_id: Optional[int] = Field(None, description="Filter by genre ID (optional)")
min_rating: Optional[float] = Field(None, description="Minimum rating filter (optional, e.g., 7.5)")
@@ -69,6 +71,8 @@ class QueryPopularSubscribesTool(MoviePilotTool):
page = 1
if count is None or count < 1:
count = 30
# 外部统计接口支持传入 count这里做硬上限避免 Agent 一次拉取过多结果。
count = min(count, MAX_PAGE_SIZE)
media_type_enum = MediaType.from_agent(media_type)
if not media_type_enum:
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"

View File

@@ -11,6 +11,14 @@ from app.db.models.site import Site
from app.db.models.siteuserdata import SiteUserData
from app.log import logger
SITE_USERDATA_DETAIL_PREVIEW_LIMIT = 10
def _preview_list(value, limit: int = SITE_USERDATA_DETAIL_PREVIEW_LIMIT) -> tuple[list, int, bool]:
"""返回列表字段预览,避免做种明细或未读消息一次性撑大工具结果。"""
items = list(value) if isinstance(value, (list, tuple)) else []
return items[:limit], len(items), len(items) > limit
class QuerySiteUserdataInput(BaseModel):
"""查询站点用户数据工具的输入参数模型"""
@@ -110,6 +118,13 @@ class QuerySiteUserdataTool(MoviePilotTool):
else 0
)
seeding_preview, seeding_count, seeding_truncated = _preview_list(
user_data.seeding_info
)
unread_preview, unread_count, unread_truncated = _preview_list(
user_data.message_unread_contents
)
user_data_dict = {
"domain": user_data.domain,
"name": user_data.name,
@@ -131,13 +146,13 @@ class QuerySiteUserdataTool(MoviePilotTool):
"seeding_size_gb": round(seeding_size_gb, 2),
"leeching_size": user_data.leeching_size,
"leeching_size_gb": round(leeching_size_gb, 2),
"seeding_info": user_data.seeding_info
if user_data.seeding_info
else [],
"seeding_info_count": seeding_count,
"seeding_info": seeding_preview,
"seeding_info_truncated": seeding_truncated,
"message_unread": user_data.message_unread,
"message_unread_contents": user_data.message_unread_contents
if user_data.message_unread_contents
else [],
"message_unread_contents_count": unread_count,
"message_unread_contents": unread_preview,
"message_unread_contents_truncated": unread_truncated,
"err_msg": user_data.err_msg,
"updated_day": user_data.updated_day,
"updated_time": user_data.updated_time,

View File

@@ -9,13 +9,15 @@ from app.agent.tools.base import MoviePilotTool
from app.helper.subscribe import SubscribeHelper
from app.log import logger
MAX_PAGE_SIZE = 50
class QuerySubscribeSharesInput(BaseModel):
"""查询订阅分享工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
name: Optional[str] = Field(None, description="Filter shares by media name (partial match, optional)")
page: Optional[int] = Field(1, description="Page number for pagination (default: 1)")
count: Optional[int] = Field(30, description="Number of items per page (default: 30)")
count: Optional[int] = Field(30, description="Number of items per page (default: 30, max: 50)")
genre_id: Optional[int] = Field(None, description="Filter by genre ID (optional)")
min_rating: Optional[float] = Field(None, description="Minimum rating filter (optional, e.g., 7.5)")
max_rating: Optional[float] = Field(None, description="Maximum rating filter (optional, e.g., 10.0)")
@@ -63,6 +65,8 @@ class QuerySubscribeSharesTool(MoviePilotTool):
page = 1
if count is None or count < 1:
count = 30
# 订阅分享是外部列表型结果,限制单页大小能降低工具上下文占用。
count = min(count, MAX_PAGE_SIZE)
subscribe_helper = SubscribeHelper()
shares = await subscribe_helper.async_get_shares(

View File

@@ -62,8 +62,8 @@ class QueryTransferHistoryTool(MoviePilotTool):
if page is None or page < 1:
page = 1
# 每页记录数
count = 50
# 每页固定 30 条,与工具说明保持一致,避免整理路径等字段撑大上下文。
count = 30
# 获取数据库会话
async with AsyncSessionFactory() as db:

View File

@@ -115,9 +115,7 @@ class QueryWorkflowsTool(MoviePilotTool):
"last_time": wf.last_time,
"current_action": wf.current_action
}
# 如果有结果,添加结果信息
if wf.result:
simplified["result"] = wf.result
# wf.result 往往是执行日志或上下文快照,不适合作为列表查询结果返回。
simplified_workflows.append(simplified)
result_json = json.dumps(simplified_workflows, ensure_ascii=False, indent=2)

View File

@@ -73,7 +73,7 @@ class SearchMediaTool(MoviePilotTool):
filtered_results.append(result)
if filtered_results:
# 限制最多30条结果
# 搜索结果只返回前 30 条,后续可通过更精确的年份/类型条件缩小范围。
total_count = len(filtered_results)
limited_results = filtered_results[:30]
# 精简字段,只保留关键信息
@@ -96,8 +96,8 @@ class SearchMediaTool(MoviePilotTool):
simplified_results.append(simplified)
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
# 如果结果被裁剪,添加提示信息
if total_count > 100:
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 100 条结果。\n\n{result_json}"
if total_count > len(limited_results):
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 {len(limited_results)} 条结果。\n\n{result_json}"
return result_json
else:
return f"未找到符合条件的媒体资源: {title}"

View File

@@ -35,7 +35,7 @@ class SearchPersonTool(MoviePilotTool):
persons = await media_chain.async_search_persons(name=name)
if persons:
# 限制最多30条结果
# 人物搜索结果只返回前 30 条,避免 biography/别名等字段挤占上下文。
total_count = len(persons)
limited_persons = persons[:30]
# 精简字段,只保留关键信息
@@ -72,8 +72,8 @@ class SearchPersonTool(MoviePilotTool):
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
# 如果结果被裁剪,添加提示信息
if total_count > 50:
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 50 条结果。\n\n{result_json}"
if total_count > len(limited_persons):
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 {len(limited_persons)} 条结果。\n\n{result_json}"
return result_json
else:
return f"未找到相关人物信息: {name}"

View File

@@ -28,7 +28,7 @@ class SearchWebInput(BaseModel):
)
max_results: Optional[int] = Field(
20,
description="Maximum number of search results to return (default: 5, max: 10)",
description="Maximum number of search results to return (default: 20, max: 20)",
)

View File

@@ -2,6 +2,7 @@ import json
import uuid
from typing import Any, Dict, List, Optional
from app.agent.tools.base import format_tool_result_for_agent
from app.agent.tools.factory import MoviePilotToolFactory
from app.log import logger
@@ -237,22 +238,14 @@ class MoviePilotToolsManager:
# 规范化参数类型
normalized_arguments = self._normalize_arguments(tool_instance, arguments)
# 调用工具的run方法
# 调用工具的run方法。HTTP/MCP 工具调用不会经过 BaseTool._arun
# 因此这里也必须复用同一套返回值格式化和兜底截断逻辑。
result = await tool_instance.run(**normalized_arguments)
# 确保返回字符串
if isinstance(result, str):
formated_result = result
elif isinstance(result, (int, float)):
formated_result = str(result)
else:
try:
formated_result = json.dumps(result, ensure_ascii=False, indent=2)
except Exception as e:
logger.warning(f"结果转换为JSON失败: {e}, 使用字符串表示")
formated_result = str(result)
return formated_result
return format_tool_result_for_agent(
result,
tool_name=tool_name,
max_chars=getattr(tool_instance, "result_max_chars", None),
)
except Exception as e:
logger.error(f"调用工具 {tool_name} 时发生错误: {e}", exc_info=True)
error_msg = json.dumps(

View File

@@ -0,0 +1,56 @@
import asyncio
import json
import unittest
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
from app.agent.tools.impl.query_workflows import QueryWorkflowsTool
class _AsyncSessionContext:
"""为工作流查询工具提供最小异步 DB 上下文。"""
async def __aenter__(self):
return object()
async def __aexit__(self, exc_type, exc, tb):
return False
class TestQueryWorkflowsTool(unittest.TestCase):
    """Ensure workflow listing omits the potentially huge ``result`` field."""

    def test_query_workflows_omits_large_result_field(self):
        tool = QueryWorkflowsTool(session_id="session-1", user_id="10001")
        # A workflow row whose execution result is large enough that leaking
        # it into the tool output would be obvious.
        wf_row = SimpleNamespace(
            id=1,
            name="demo",
            description="demo workflow",
            state="S",
            trigger_type="manual",
            run_count=1,
            timer=None,
            event_type=None,
            add_time="2026-05-08 10:00:00",
            last_time="2026-05-08 10:01:00",
            current_action=None,
            result="x" * 10000,
        )
        oper_mock = MagicMock()
        oper_mock.async_list = AsyncMock(return_value=[wf_row])
        session_patch = patch(
            "app.agent.tools.impl.query_workflows.AsyncSessionFactory",
            return_value=_AsyncSessionContext(),
        )
        oper_patch = patch(
            "app.agent.tools.impl.query_workflows.WorkflowOper",
            return_value=oper_mock,
        )
        with session_patch, oper_patch:
            output = asyncio.run(tool.run())
        rows = json.loads(output)
        self.assertEqual(len(rows), 1)
        self.assertEqual(rows[0]["name"], "demo")
        self.assertNotIn("result", rows[0])


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,51 @@
import asyncio
import json
import unittest
from app.agent.tools.base import (
DEFAULT_TOOL_RESULT_MAX_CHARS,
MoviePilotTool,
format_tool_result_for_agent,
)
class OversizedResultTool(MoviePilotTool):
    """Test helper whose ``run`` deliberately exceeds the result ceiling."""

    name: str = "oversized_result_tool"
    description: str = "Tool used to verify result truncation."

    async def run(self, **kwargs) -> str:
        # 100 chars past the ceiling guarantees the truncation path triggers.
        overflow_length = DEFAULT_TOOL_RESULT_MAX_CHARS + 100
        return "x" * overflow_length
class TestAgentToolResultLimits(unittest.TestCase):
    """Verify the shared result-truncation wrapper and formatter behaviour."""

    def test_arun_truncates_oversized_tool_result(self):
        tool = OversizedResultTool(session_id="session-1", user_id="10001")
        payload = json.loads(asyncio.run(tool._arun()))
        self.assertTrue(payload["tool_result_truncated"])
        self.assertEqual(payload["tool_name"], "oversized_result_tool")
        self.assertEqual(payload["returned_chars"], DEFAULT_TOOL_RESULT_MAX_CHARS)
        self.assertGreater(payload["total_chars"], payload["returned_chars"])

    def test_formatter_preserves_sensitive_json_fields_for_agent_use(self):
        sample = {
            "cookie": "uid=abc; token=secret",
            "nested": {
                "api_key": "secret-key",
                "plugin_author": "MoviePilot",
            },
        }
        payload = json.loads(
            format_tool_result_for_agent(sample, tool_name="sensitive_tool")
        )
        self.assertEqual(payload["cookie"], "uid=abc; token=secret")
        self.assertEqual(payload["nested"]["api_key"], "secret-key")
        self.assertEqual(payload["nested"]["plugin_author"], "MoviePilot")


if __name__ == "__main__":
    unittest.main()