From 0cd049bfc2f33628bff4da69acddc4abc9b60155 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Tue, 23 Jun 2026 10:05:45 +0800 Subject: [PATCH] Refactor movie pilot config and test coverage --- app/agent/__init__.py | 132 ++++++++- app/agent/llm/provider.py | 81 +++-- app/agent/middleware/activity_log.py | 37 ++- app/agent/middleware/jobs.py | 5 - app/agent/middleware/memory.py | 5 - app/agent/middleware/skills.py | 7 +- app/agent/middleware/subagents.py | 48 ++- app/agent/middleware/tool_selection.py | 16 - app/agent/runtime.py | 44 ++- app/agent/tools/base.py | 5 +- app/agent/tools/impl/_command_safety.py | 88 ++++++ app/agent/tools/impl/_filter_rule_utils.py | 100 ++++--- app/agent/tools/impl/_system_setting_utils.py | 58 +++- app/agent/tools/impl/_terminal_session.py | 33 +-- app/agent/tools/impl/browse_webpage.py | 5 + .../tools/impl/delete_download_history.py | 8 +- app/agent/tools/impl/edit_file.py | 6 +- app/agent/tools/impl/execute_command.py | 69 ++--- app/agent/tools/impl/query_site_userdata.py | 213 +++++++------- .../tools/impl/query_subscribe_history.py | 142 +++++---- app/agent/tools/impl/query_system_settings.py | 42 ++- .../tools/impl/query_transfer_history.py | 118 ++++---- app/agent/tools/impl/query_workflows.py | 133 ++++----- app/agent/tools/impl/read_file.py | 2 +- app/agent/tools/impl/run_workflow.py | 32 +- app/agent/tools/impl/update_site.py | 197 ++++++------- app/agent/tools/impl/update_subscribe.py | 277 +++++++++--------- .../tools/impl/update_system_settings.py | 42 ++- app/agent/tools/impl/write_file.py | 4 +- app/db/downloadhistory_oper.py | 6 + app/db/site_oper.py | 19 ++ app/db/subscribe_oper.py | 16 + app/db/subscribehistory_oper.py | 26 ++ app/db/transferhistory_oper.py | 45 +++ tests/test_agent_activity_log.py | 16 + tests/test_agent_graph_cache.py | 92 ++++++ tests/test_agent_query_workflows_tool.py | 13 - tests/test_agent_subagents.py | 6 +- tests/test_agent_system_settings_tools.py | 96 +++++- tests/test_execute_command_tool.py | 25 +- 40 files changed, 1481 insertions(+), 828 deletions(-) create mode 100644 app/agent/tools/impl/_command_safety.py create mode 100644 app/db/subscribehistory_oper.py create mode 100644 tests/test_agent_graph_cache.py diff --git a/app/agent/__init__.py b/app/agent/__init__.py index 37c03960..401fb70e 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -1,4 +1,5 @@ import asyncio +import hashlib import json import re import traceback @@ -148,6 +149,16 @@ class _SessionUsageSnapshot: } +@dataclass +class _CompiledAgentBundle: + """会话内可复用的 Agent 图及其构造签名。""" + + signature: tuple[Any, ...] + agent: Any + streaming: bool + created_at: datetime + + class _ThinkTagStripper: """ 流式剥离 ... 标签的辅助类。 @@ -280,6 +291,8 @@ class MoviePilotAgent: self._llm_runtime_config: Optional[Dict[str, Any]] = None self._llm_provider_selection: Dict[str, Any] = {} self._agent_started_at: Optional[datetime] = None + self._compiled_agent_bundle: Optional[_CompiledAgentBundle] = None + self._last_agent_cache_hit = False # 流式token管理 self.stream_handler = StreamingHandler() @@ -980,6 +993,90 @@ class MoviePilotAgent: allow_message_tools=self.allow_message_tools, ) + def _refresh_tool_context(self, values: Dict[str, object]) -> None: + """ + 刷新本轮工具共享上下文。 + + 工具对象可能随会话内 Agent 图缓存被复用,因此这里保留 dict 对象本身, + 只替换其中内容,确保缓存工具看到的是最新权限与回复状态。 + """ + self._tool_context.clear() + self._tool_context.update(values) + + @staticmethod + def _public_runtime_config_signature(runtime_config: Dict[str, Any]) -> tuple: + """生成不包含密钥明文的 LLM 运行时签名。""" + api_key = runtime_config.get("api_key") or "" + api_key_digest = ( + hashlib.sha256(str(api_key).encode("utf-8")).hexdigest()[:12] + if api_key + else "" + ) + return ( + runtime_config.get("provider"), + runtime_config.get("model"), + api_key_digest, + runtime_config.get("base_url"), + runtime_config.get("base_url_preset"), + runtime_config.get("user_agent"), + bool(runtime_config.get("use_proxy")), + runtime_config.get("thinking_level"), + ) + + async def _agent_bundle_signature(self, streaming: bool) -> tuple[Any, ...]: + """构造会话内 Agent 图缓存签名。""" + runtime_config = await self._resolve_llm_runtime_config() + return ( + streaming, + self.channel, + self.source, + self.user_id, + self.username, + self.allow_message_tools, + bool(self._tool_context.get("is_admin")), + self.has_message_context, + self.is_background, + settings.AI_AGENT_VERBOSE, + settings.LLM_MAX_TOOLS, + settings.LLM_MAX_ITERATIONS, + self._public_runtime_config_signature(runtime_config), + agent_runtime_manager.current_signature(), + ) + + def _get_cached_agent( + self, signature: tuple[Any, ...], streaming: bool + ) -> Optional[Any]: + """按签名读取当前会话已编译的 Agent 图。""" + bundle = self._compiled_agent_bundle + if ( + bundle + and bundle.streaming == streaming + and bundle.signature == signature + ): + return bundle.agent + return None + + def _cache_agent( + self, + *, + signature: tuple[Any, ...], + agent: Any, + streaming: bool, + ) -> Any: + """保存当前会话可复用的 Agent 图。""" + self._compiled_agent_bundle = _CompiledAgentBundle( + signature=signature, + agent=agent, + streaming=streaming, + created_at=datetime.now(), + ) + return agent + + @staticmethod + def _latest_turn_messages(messages: List[BaseMessage]) -> List[BaseMessage]: + """从完整历史中提取本轮新增用户消息。""" + return [messages[-1]] if messages else [] + def _initialize_subagent_tools(self) -> List: """ 初始化子代理专用静默工具列表。 @@ -1006,6 +1103,13 @@ class MoviePilotAgent: :param streaming: 是否启用流式输出 """ try: + bundle_signature = await self._agent_bundle_signature(streaming) + cached_agent = self._get_cached_agent(bundle_signature, streaming) + self._last_agent_cache_hit = bool(cached_agent) + if cached_agent: + logger.debug(f"复用会话内 Agent 图: session_id={self.session_id}") + return cached_agent + # 系统提示词 system_prompt = prompt_manager.get_agent_prompt(channel=self.channel) @@ -1113,13 +1217,18 @@ class MoviePilotAgent: ) ) - return create_agent( + agent = create_agent( model=agent_model, tools=[*tools, *skill_tools, *activity_log_tools], system_prompt=system_prompt, middleware=middlewares, checkpointer=InMemorySaver(), ) + return self._cache_agent( + signature=bundle_signature, + agent=agent, + streaming=streaming, + ) except Exception as e: logger.error(f"创建 Agent 失败: {e}") raise e @@ -1137,12 +1246,15 @@ class MoviePilotAgent: user_display_saved = False try: logger.info( - f"Agent推理: session_id={self.session_id}, input={message}, " + f"Agent推理: session_id={self.session_id}, " + f"input_chars={len(message or '')}, " f"images={len(images) if images else 0}, files={len(files) if files else 0}, " f"audio_input={has_audio_input}" ) - self._tool_context = await self._build_tool_context( - should_dispatch_reply=self.should_dispatch_reply + self._refresh_tool_context( + await self._build_tool_context( + should_dispatch_reply=self.should_dispatch_reply + ) ) self._streamed_output = "" @@ -1330,6 +1442,11 @@ class MoviePilotAgent: # 创建智能体(根据是否流式传入不同 LLM) agent = await self._create_agent(streaming=use_streaming) + input_messages = ( + self._latest_turn_messages(messages) + if self._last_agent_cache_hit + else messages + ) if use_streaming: self.stream_handler.set_dispatch_policy( @@ -1348,7 +1465,7 @@ class MoviePilotAgent: # 流式运行智能体,token 直接推送到 stream_handler await self._stream_agent_tokens( agent=agent, - messages={"messages": messages}, + messages={"messages": input_messages}, config=agent_config, on_token=self._handle_stream_text, ) @@ -1387,7 +1504,7 @@ class MoviePilotAgent: else: # 非流式模式:后台任务或渠道不支持消息编辑 await agent.ainvoke( - {"messages": messages}, + {"messages": input_messages}, config=agent_config, ) @@ -1446,9 +1563,11 @@ class MoviePilotAgent: except asyncio.CancelledError: logger.info(f"Agent执行被取消: session_id={self.session_id}") + self._compiled_agent_bundle = None execution_error = "任务已取消" return "任务已取消", {} except Exception as e: + self._compiled_agent_bundle = None execution_error = str(e) if self._messages_have_image_input(messages) and self._is_unsupported_image_input_error(e): logger.warning( @@ -1492,6 +1611,7 @@ class MoviePilotAgent: """ 清理智能体资源 """ + self._compiled_agent_bundle = None logger.info(f"MoviePilot智能体已清理: session_id={self.session_id}") diff --git a/app/agent/llm/provider.py b/app/agent/llm/provider.py index 1905bcce..70bf98b6 100644 --- a/app/agent/llm/provider.py +++ b/app/agent/llm/provider.py @@ -1628,6 +1628,37 @@ class LLMProviderManager(metaclass=Singleton): ) return None + def _resolve_cached_model_record( + self, + provider_id: str, + model_id: Optional[str], + base_url: Optional[str] = None, + base_url_preset_id: Optional[str] = None, + transport: str = "openai", + ) -> dict[str, Any] | None: + """从缓存中的模型元数据构造轻量模型记录,不触发远端模型列表刷新。""" + if not model_id: + return None + metadata = self.resolve_cached_model_metadata( + provider_id, + model_id, + base_url=base_url, + base_url_preset_id=base_url_preset_id, + ) or {} + if not metadata: + return self._normalize_model_record( + model_id=model_id, + transport=transport, + source="configured", + ) + return self._normalize_model_record( + model_id=model_id, + display_name=metadata.get("name") or model_id, + metadata=metadata, + transport=transport, + source="models.dev-cache", + ) + @staticmethod def _normalize_model_record( model_id: str, @@ -2104,7 +2135,7 @@ class LLMProviderManager(metaclass=Singleton): try: return jwt.decode(token, options={"verify_signature": False}) except Exception as err: - print(err) + logger.debug(f"解析 JWT token 内容失败: {err}") return {} @staticmethod @@ -2587,40 +2618,29 @@ class LLMProviderManager(metaclass=Singleton): ) normalized_api_key = str(api_key or "").strip() or None normalized_base_url = self._sanitize_base_url(base_url) - model_record = None - if model: - try: - model_record = next( - ( - item - for item in await self.list_models( - normalized_provider_id, - api_key=api_key, - base_url=base_url, - base_url_preset_id=normalized_base_url_preset_id, - user_agent=user_agent, - use_proxy=use_proxy, - ) - if item["id"] == model - ), - None, - ) - except Exception as err: - print(err) - model_record = None + default_transport = ( + "anthropic" if resolved_runtime == "anthropic_compatible" else "openai" + ) + model_record = self._resolve_cached_model_record( + normalized_provider_id, + model, + base_url=base_url, + base_url_preset_id=normalized_base_url_preset_id, + transport=default_transport, + ) + model_metadata = self.resolve_cached_model_metadata( + normalized_provider_id, + model, + base_url=base_url, + base_url_preset_id=normalized_base_url_preset_id, + ) result: dict[str, Any] = { "provider_id": normalized_provider_id, "runtime": resolved_runtime, "model_id": model, "model_record": model_record, - "model_metadata": await self.resolve_model_metadata( - normalized_provider_id, - model, - base_url=base_url, - base_url_preset_id=normalized_base_url_preset_id, - use_proxy=use_proxy, - ), + "model_metadata": model_metadata, "default_headers": None, "use_responses_api": None, "auth_mode": "api_key", @@ -2631,8 +2651,7 @@ class LLMProviderManager(metaclass=Singleton): try: auth = await self._resolve_chatgpt_oauth() except Exception as err: - print(err) - pass + logger.debug(f"解析 ChatGPT OAuth 鉴权失败,回退 API Key 模式: {err}") if auth: headers = {"originator": "moviepilot"} diff --git a/app/agent/middleware/activity_log.py b/app/agent/middleware/activity_log.py index 825b9dc4..01ce6c25 100644 --- a/app/agent/middleware/activity_log.py +++ b/app/agent/middleware/activity_log.py @@ -7,12 +7,14 @@ """ import json +import os import re from collections.abc import Awaitable, Callable from datetime import datetime, timedelta from pathlib import Path from typing import Annotated, Any, NotRequired, Optional, TypedDict +import anyio from anyio import Path as AsyncPath from langchain.agents.middleware.types import ( AgentMiddleware, @@ -579,14 +581,29 @@ class ActivityLogMiddleware(AgentMiddleware[ActivityLogState, ContextT, Response entry = f"- **{now_str}** {summary}\n" try: if await log_path.exists(): - existing = await log_path.read_text(encoding="utf-8", errors="replace") - await log_path.write_text(existing + entry, encoding="utf-8") + async with await anyio.open_file( + log_path, + mode="a", + encoding="utf-8", + ) as stream: + await stream.write(entry) else: header = f"# {today_str} 活动日志\n\n" - await log_path.write_text(header + entry, encoding="utf-8") - logger.debug("Activity logged: %s", summary[:80]) + try: + fd = os.open(log_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644) + except FileExistsError: + async with await anyio.open_file( + log_path, + mode="a", + encoding="utf-8", + ) as stream: + await stream.write(entry) + else: + with os.fdopen(fd, "w", encoding="utf-8") as stream: + stream.write(header + entry) + logger.debug(f"Activity logged: {summary[:80]}") except Exception as e: - logger.warning("Failed to append activity log: %s", e) + logger.warning(f"Failed to append activity log: {e}") async def _cleanup_old_logs(self) -> None: """清理超过保留天数的旧日志文件。""" @@ -608,20 +625,16 @@ class ActivityLogMiddleware(AgentMiddleware[ActivityLogState, ContextT, Response file_date = datetime.strptime(match.group(1), "%Y-%m-%d").date() if file_date < cutoff_date: await path.unlink() - logger.debug("Cleaned up old activity log: %s", path.name) + logger.debug(f"Cleaned up old activity log: {path.name}") except ValueError: continue except Exception as e: - logger.warning("Failed to cleanup old activity logs: %s", e) + logger.warning(f"Failed to cleanup old activity logs: {e}") async def abefore_agent( self, state: ActivityLogState, runtime: Runtime ) -> Optional[ActivityLogStateUpdate]: """在 Agent 执行前加载近期活动日志。""" - # 如果已经加载则跳过 - if "activity_log_contents" in state: - return None - contents = await self._load_recent_logs() # 趁机清理旧日志(低频操作,不影响性能) @@ -709,7 +722,7 @@ class ActivityLogMiddleware(AgentMiddleware[ActivityLogState, ContextT, Response if summary: await self._append_activity(summary) except Exception as e: - logger.warning("Failed to record activity: %s", e) + logger.warning(f"Failed to record activity: {e}") return None diff --git a/app/agent/middleware/jobs.py b/app/agent/middleware/jobs.py index 9dc3602e..d70190d3 100644 --- a/app/agent/middleware/jobs.py +++ b/app/agent/middleware/jobs.py @@ -283,12 +283,7 @@ class JobsMiddleware(AgentMiddleware[JobsState, ContextT, ResponseT]): # noqa ) -> JobsStateUpdate | None: """在 Agent 执行前异步加载任务元数据。 - 每个会话仅加载一次。若 state 中已有则跳过。 """ - # 如果 state 中已存在元数据则跳过 - if "jobs_metadata" in state: - return None - return JobsStateUpdate( jobs_metadata=await load_jobs_metadata(self.sources) ) diff --git a/app/agent/middleware/memory.py b/app/agent/middleware/memory.py index 077d9c95..365031f7 100644 --- a/app/agent/middleware/memory.py +++ b/app/agent/middleware/memory.py @@ -302,7 +302,6 @@ class MemoryMiddleware(AgentMiddleware[MemoryState, ContextT, ResponseT]): # no """在代理执行前扫描记忆目录并加载所有 .md 文件的内容。 自动发现目录下所有 `.md` 文件并加载其内容到状态中。 - 如果状态中尚未存在则进行加载。 同时检测记忆文件是否为空,设置 memory_empty 标志位, 以便在系统提示词中触发初始化引导流程。 @@ -314,10 +313,6 @@ class MemoryMiddleware(AgentMiddleware[MemoryState, ContextT, ResponseT]): # no 返回: 填充了 memory_contents 和 memory_empty 的状态更新。 """ - # 如果已经加载则跳过 - if "memory_contents" in state: - return None - # 扫描目录下所有 .md 文件 md_files = await self._scan_memory_files() diff --git a/app/agent/middleware/skills.py b/app/agent/middleware/skills.py index 1c64eea9..15ec7e72 100644 --- a/app/agent/middleware/skills.py +++ b/app/agent/middleware/skills.py @@ -322,7 +322,7 @@ def _extract_version(skill_md: Path) -> int: try: content = skill_md.read_text(encoding="utf-8", errors="replace") except Exception as err: - print(err) + logger.debug(f"读取技能版本失败: {err}") return 0 match = re.match(r"^---\s*\n(.*?)\n---\s*\n", content, re.DOTALL) if not match: @@ -627,13 +627,8 @@ class SkillsMiddleware(AgentMiddleware[SkillsState, ContextT, ResponseT]): # no ) -> SkillsStateUpdate | None: # ty: ignore[invalid-method-override] """在 Agent 执行前异步加载技能元数据。 - 每个会话仅加载一次。若 state 中已有则跳过。 首次加载时,会先将内置技能同步到用户目录(如不存在)。 """ - # 如果 state 中已存在元数据则跳过 - if "skills_metadata" in state: - return None - self._sync_bundled_skills() all_skills: dict[str, SkillMetadata] = {} diff --git a/app/agent/middleware/subagents.py b/app/agent/middleware/subagents.py index 0b2e5159..ee4356e2 100644 --- a/app/agent/middleware/subagents.py +++ b/app/agent/middleware/subagents.py @@ -197,18 +197,44 @@ def is_subagent_stream_metadata(metadata: Any) -> bool: ) == SUBAGENT_STREAM_MARKER_VALUE: return True - return bool(metadata.get("lc_agent_name") in builtin_subagent_names()) + return bool( + metadata.get("lc_agent_name") + in builtin_subagent_names(agent_runtime_manager.current_signature()) + ) -@lru_cache(maxsize=1) -def builtin_subagent_names() -> frozenset[str]: +def builtin_subagent_names( + runtime_signature: Optional[tuple[tuple[str, int, int], ...]] = None, +) -> frozenset[str]: """返回内置子代理名称集合。""" - return frozenset(profile.name for profile in _builtin_subagent_profiles()) + runtime_signature = runtime_signature or agent_runtime_manager.current_signature() + return _cached_builtin_subagent_names(runtime_signature) -@lru_cache(maxsize=1) -def _builtin_subagent_profiles() -> tuple[_SubAgentProfile, ...]: +@lru_cache(maxsize=8) +def _cached_builtin_subagent_names( + runtime_signature: tuple[tuple[str, int, int], ...], +) -> frozenset[str]: + """按运行时签名缓存内置子代理名称集合。""" + return frozenset( + profile.name + for profile in _builtin_subagent_profiles(runtime_signature) + ) + + +def _builtin_subagent_profiles( + runtime_signature: Optional[tuple[tuple[str, int, int], ...]] = None, +) -> tuple[_SubAgentProfile, ...]: """从运行时配置目录加载 MoviePilot 子代理定义。""" + runtime_signature = runtime_signature or agent_runtime_manager.current_signature() + return _cached_builtin_subagent_profiles(runtime_signature) + + +@lru_cache(maxsize=8) +def _cached_builtin_subagent_profiles( + runtime_signature: tuple[tuple[str, int, int], ...], +) -> tuple[_SubAgentProfile, ...]: + """按运行时签名缓存 MoviePilot 子代理定义。""" definitions = agent_runtime_manager.list_subagents() profiles = tuple( _profile_from_runtime_definition(definition) @@ -237,6 +263,10 @@ def _builtin_subagent_profiles() -> tuple[_SubAgentProfile, ...]: ) +builtin_subagent_names.cache_clear = _cached_builtin_subagent_names.cache_clear +_builtin_subagent_profiles.cache_clear = _cached_builtin_subagent_profiles.cache_clear + + def _profile_from_runtime_definition( definition: SubAgentDefinition, ) -> _SubAgentProfile: @@ -1044,6 +1074,7 @@ class SubAgentTaskControlMiddleware(AgentMiddleware): if unfinished_records: logger.info(f"Agent 结束,取消未完成子代理任务: tasks={len(unfinished_records)}") await self._cancel_records(unfinished_records) + self._tasks.clear() async def awrap_tool_call( self, @@ -1083,9 +1114,8 @@ def create_subagent_middlewares( stream_handler: Any = None, ) -> tuple[list[AgentMiddleware], list[BaseTool]]: """创建子代理中间件列表和任务工具列表。""" - _builtin_subagent_profiles.cache_clear() - builtin_subagent_names.cache_clear() - profiles = _builtin_subagent_profiles() + runtime_signature = agent_runtime_manager.current_signature() + profiles = _builtin_subagent_profiles(runtime_signature) subagent_middleware = MoviePilotSubAgentMiddleware( model=model, profiles=profiles, diff --git a/app/agent/middleware/tool_selection.py b/app/agent/middleware/tool_selection.py index 16a6b07b..82a704b8 100644 --- a/app/agent/middleware/tool_selection.py +++ b/app/agent/middleware/tool_selection.py @@ -592,22 +592,6 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware): 这样后续多轮 `model -> tools -> model` 循环都只复用这一次结果, 不会为每次模型回合重复追加一笔 selector LLM 开销。 """ - if "selected_tool_names" in state: - self._log_selection_attempt( - _ToolSelectionAttempt( - request=ModelRequest( - model=self.model, - tools=list(self.selection_tools), - messages=state["messages"], - state=state, - runtime=runtime, - ), - selected_tool_names=state.get("selected_tool_names") or [], - status="reused", - ) - ) - return None - if not self.selection_tools or self.model is None: detail = "没有可筛选工具" if not self.selection_tools else "未配置筛选模型" self._log_selection_attempt( diff --git a/app/agent/runtime.py b/app/agent/runtime.py index 648a0fd0..0a57618f 100644 --- a/app/agent/runtime.py +++ b/app/agent/runtime.py @@ -5,6 +5,7 @@ from __future__ import annotations import re import shutil import threading +import time from dataclasses import dataclass, field from pathlib import Path from typing import Any, Iterable, Optional @@ -243,9 +244,15 @@ class AgentRuntimeManager: self._cache_lock = threading.Lock() self._cached_signature: Optional[tuple[tuple[str, int, int], ...]] = None self._cached_config: Optional[AgentRuntimeConfig] = None + self._cached_signature_checked_at = 0.0 + self._signature_check_interval = 1.0 + self._layout_ready = False def ensure_layout(self) -> None: """创建目录、同步默认文件,并清理废弃的旧版 runtime 文件。""" + with self._cache_lock: + if self._layout_ready: + return self.agent_root_dir.mkdir(parents=True, exist_ok=True) self.runtime_dir.mkdir(parents=True, exist_ok=True) self.memory_dir.mkdir(parents=True, exist_ok=True) @@ -257,11 +264,13 @@ class AgentRuntimeManager: self._remove_obsolete_runtime_files() self._sync_bundled_defaults() self._migrate_root_memory_files() + with self._cache_lock: + self._layout_ready = True def load_runtime_config(self) -> AgentRuntimeConfig: """加载配置。用户目录损坏时自动回退到内置默认配置。""" self.ensure_layout() - signature = self._build_signature() + signature = self.current_signature() with self._cache_lock: if self._cached_signature == signature and self._cached_config: return self._cached_config @@ -269,7 +278,7 @@ class AgentRuntimeManager: try: config = self._load_from_root(self.runtime_dir) except AgentRuntimeConfigError as err: - logger.warning("Agent 根层配置无效,回退到内置默认配置: %s", err) + logger.warning(f"Agent 根层配置无效,回退到内置默认配置: {err}") config = self._load_from_root(self.bundled_defaults_dir) config.used_fallback = True config.warnings.insert( @@ -285,6 +294,25 @@ class AgentRuntimeManager: with self._cache_lock: self._cached_signature = None self._cached_config = None + self._cached_signature_checked_at = 0.0 + self._layout_ready = False + + def current_signature(self) -> tuple[tuple[str, int, int], ...]: + """返回当前运行时配置文件签名,供调用方判断缓存是否仍可复用。""" + now = time.monotonic() + with self._cache_lock: + if ( + self._cached_signature is not None + and now - self._cached_signature_checked_at + < self._signature_check_interval + ): + return self._cached_signature + + signature = self._build_signature() + with self._cache_lock: + self._cached_signature = signature + self._cached_signature_checked_at = now + return signature def set_active_persona(self, persona_query: str) -> AgentRuntimeConfig: """切换当前激活人格,并立即刷新缓存。""" @@ -308,7 +336,7 @@ class AgentRuntimeManager: ) current_path.write_text(document, encoding="utf-8") self.invalidate_cache() - logger.info("已切换 Agent 人格: %s", persona.persona_id) + logger.info(f"已切换 Agent 人格: {persona.persona_id}") return self.load_runtime_config() def list_personas(self) -> list[PersonaDefinition]: @@ -439,7 +467,7 @@ class AgentRuntimeManager: continue target.parent.mkdir(parents=True, exist_ok=True) shutil.copy2(path, target) - logger.info("已同步默认 Agent 运行时文件: %s", target) + logger.info(f"已同步默认 Agent 运行时文件: {target}") @classmethod def _should_update_bundled_subagent( @@ -478,7 +506,7 @@ class AgentRuntimeManager: return target.parent.mkdir(parents=True, exist_ok=True) source.rename(target) - logger.info("已迁移旧版 Agent 根配置文件: %s -> %s", source, target) + logger.info(f"已迁移旧版 Agent 根配置文件: {source} -> {target}") def _remove_obsolete_runtime_files(self) -> None: """删除不再支持的旧版 Agent 配置文件,避免被误迁移到 memory。""" @@ -487,14 +515,14 @@ class AgentRuntimeManager: if not path.exists() or not path.is_file(): continue path.unlink() - logger.info("已删除废弃的 Agent 根配置文件: %s", path) + logger.info(f"已删除废弃的 Agent 根配置文件: {path}") for relative_path in sorted(OBSOLETE_RUNTIME_FILES): path = self.runtime_dir / relative_path if not path.exists() or not path.is_file(): continue path.unlink() - logger.info("已删除废弃的 Agent 运行时文件: %s", path) + logger.info(f"已删除废弃的 Agent 运行时文件: {path}") def _migrate_root_memory_files(self) -> None: """将旧版根目录 memory 文件移入 `config/agent/memory`。""" @@ -505,7 +533,7 @@ class AgentRuntimeManager: if target.exists(): continue path.rename(target) - logger.info("已迁移旧版 Agent memory 文件: %s -> %s", path, target) + logger.info(f"已迁移旧版 Agent memory 文件: {path} -> {target}") def _load_from_root(self, root: Path) -> AgentRuntimeConfig: current_persona_path = root / CURRENT_PERSONA_FILE diff --git a/app/agent/tools/base.py b/app/agent/tools/base.py index 36eebaf0..9f715698 100644 --- a/app/agent/tools/base.py +++ b/app/agent/tools/base.py @@ -431,7 +431,8 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): :return: 普通用户允许读写的本地目录列表 """ roots = [ - settings.CONFIG_PATH / "agent" + settings.CONFIG_PATH / "agent", + settings.LOG_PATH, ] resolved_roots = [] for root in roots: @@ -467,7 +468,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): allowed_text = "、".join(str(root) for root in allowed_roots) return ( resolved_path, - f"抱歉,普通用户只能{operation}配置目录、Agent记忆目录和日志目录内的文件或目录:{allowed_text}", + f"抱歉,普通用户只能{operation}Agent配置目录和日志目录内的文件或目录:{allowed_text}", ) async def _check_local_storage_access( diff --git a/app/agent/tools/impl/_command_safety.py b/app/agent/tools/impl/_command_safety.py new file mode 100644 index 00000000..45e6cfc2 --- /dev/null +++ b/app/agent/tools/impl/_command_safety.py @@ -0,0 +1,88 @@ +"""Agent 命令工具的安全校验逻辑。""" + +from __future__ import annotations + +import os.path +import re +import shlex + + +COMMAND_FORBIDDEN_KEYWORDS = ( + ":(){ :|:& };:", + "dd if=/dev/zero", + "mkfs", + "reboot", + "shutdown", +) + +COMMAND_DANGEROUS_PATTERNS = ( + re.compile(r"\brm\s+[^;&|]*-[^\s;&|]*[rR][fF]?[^\s;&|]*\s+/(?:\s|$|[;&|])"), + re.compile(r"\bdd\s+[^;&|]*(?:of=/dev/(?:sd[a-z]\d*|nvme\d+n\d+p?\d*|disk\d+)|if=/dev/zero)"), + re.compile(r"\b(?:mkfs|fdisk|parted|diskutil)\b"), + re.compile(r"\b(?:chmod|chown)\s+[^;&|]*-R[^;&|]*\s+/(?:\s|$|[;&|])"), + re.compile(r"\b(?:reboot|shutdown|halt|poweroff)\b"), +) + + +def _command_tokens(command: str) -> list[str]: + """尽力解析 shell 命令 token,解析失败时退回空白分割。""" + try: + return shlex.split(command, posix=True) + except ValueError: + return re.split(r"\s+", command.strip()) + + +def _contains_recursive_root_delete(command: str) -> bool: + """识别递归删除根目录或一级目录的 rm 命令。""" + tokens = _command_tokens(command) + if not any(token == "rm" or token.endswith("/rm") for token in tokens): + return False + has_recursive = any( + token.startswith("-") and ("r" in token or "R" in token) + for token in tokens + ) + if not has_recursive: + return False + + for token in tokens: + clean_token = re.match(r"^([^;|&><]+)", token) + if not clean_token: + continue + path_value = clean_token.group(1).strip("\"'") + if not path_value.startswith("/"): + continue + norm_path = os.path.normpath(path_value) + if norm_path == "/" or re.match(r"^/[^/]+$", norm_path): + return True + return False + + +def detect_dangerous_command(command: str) -> str: + """返回危险命令原因,安全时返回空字符串。""" + normalized = str(command or "").strip() + if not normalized: + return "命令不能为空" + for keyword in COMMAND_FORBIDDEN_KEYWORDS: + if keyword in normalized: + return f"命令包含禁止使用的关键字 '{keyword}'" + if _contains_recursive_root_delete(normalized): + return "命令疑似递归删除根目录或一级目录" + for pattern in COMMAND_DANGEROUS_PATTERNS: + if pattern.search(normalized): + return "命令匹配高危系统操作模式" + return "" + + +def validate_command_safety(command: str, *, confirmed: bool = False) -> None: + """ + 校验 shell 命令安全性。 + + :param command: 待执行命令 + :param confirmed: 是否已经通过显式参数确认高危操作 + """ + reason = detect_dangerous_command(command) + if not reason: + return + if confirmed and reason != "命令不能为空": + return + raise ValueError(f"{reason}。如确认需要执行,请设置 confirm_dangerous=true") diff --git a/app/agent/tools/impl/_filter_rule_utils.py b/app/agent/tools/impl/_filter_rule_utils.py index 25dfcc49..c62e07e8 100644 --- a/app/agent/tools/impl/_filter_rule_utils.py +++ b/app/agent/tools/impl/_filter_rule_utils.py @@ -5,8 +5,7 @@ import re from typing import Any, Dict, Iterable, Optional from app.core.event import eventmanager -from app.db import AsyncSessionFactory -from app.db.models.subscribe import Subscribe +from app.db.subscribe_oper import SubscribeOper from app.db.systemconfig_oper import SystemConfigOper from app.helper.rule import RuleHelper from app.modules.filter.RuleParser import RuleParser @@ -284,23 +283,22 @@ async def collect_rule_group_usages( continue ensure_usage(name)["used_in_global_best_version"] = True - async with AsyncSessionFactory() as db: - subscribes = await Subscribe.async_list(db) - for subscribe in subscribes: - filter_groups = subscribe.filter_groups or [] - for name in filter_groups: - if target_names and name not in target_names: - continue - ensure_usage(name)["subscribes"].append( - { - "subscribe_id": subscribe.id, - "name": subscribe.name, - "season": subscribe.season, - "type": subscribe.type, - "username": subscribe.username, - "best_version": bool(subscribe.best_version), - } - ) + subscribes = await SubscribeOper().async_list() + for subscribe in subscribes: + filter_groups = subscribe.filter_groups or [] + for name in filter_groups: + if target_names and name not in target_names: + continue + ensure_usage(name)["subscribes"].append( + { + "subscribe_id": subscribe.id, + "name": subscribe.name, + "season": subscribe.season, + "type": subscribe.type, + "username": subscribe.username, + "best_version": bool(subscribe.best_version), + } + ) return usage_map @@ -482,22 +480,22 @@ async def rename_rule_group_references(old_name: str, new_name: str) -> dict: await save_system_config(config_key, updated) changed["global_settings"][config_key.value] = updated - async with AsyncSessionFactory() as db: - subscribes = await Subscribe.async_list(db) - for subscribe in subscribes: - original = subscribe.filter_groups or [] - updated = replace_group_name_in_list(original, old_name, new_name) - if updated == original: - continue - await subscribe.async_update(db, {"filter_groups": updated}) - changed["subscribes"].append( - { - "subscribe_id": subscribe.id, - "name": subscribe.name, - "season": subscribe.season, - "filter_groups": updated, - } - ) + subscribe_oper = SubscribeOper() + subscribes = await subscribe_oper.async_list() + for subscribe in subscribes: + original = subscribe.filter_groups or [] + updated = replace_group_name_in_list(original, old_name, new_name) + if updated == original: + continue + await subscribe_oper.async_update_filter_groups(subscribe.id, updated) + changed["subscribes"].append( + { + "subscribe_id": subscribe.id, + "name": subscribe.name, + "season": subscribe.season, + "filter_groups": updated, + } + ) return changed @@ -520,21 +518,21 @@ async def remove_rule_group_references(group_name: str) -> dict: await save_system_config(config_key, updated) changed["global_settings"][config_key.value] = updated - async with AsyncSessionFactory() as db: - subscribes = await Subscribe.async_list(db) - for subscribe in subscribes: - original = subscribe.filter_groups or [] - updated = [value for value in original if value != group_name] - if updated == original: - continue - await subscribe.async_update(db, {"filter_groups": updated}) - changed["subscribes"].append( - { - "subscribe_id": subscribe.id, - "name": subscribe.name, - "season": subscribe.season, - "filter_groups": updated, - } - ) + subscribe_oper = SubscribeOper() + subscribes = await subscribe_oper.async_list() + for subscribe in subscribes: + original = subscribe.filter_groups or [] + updated = [value for value in original if value != group_name] + if updated == original: + continue + await subscribe_oper.async_update_filter_groups(subscribe.id, updated) + changed["subscribes"].append( + { + "subscribe_id": subscribe.id, + "name": subscribe.name, + "season": subscribe.season, + "filter_groups": updated, + } + ) return changed diff --git a/app/agent/tools/impl/_system_setting_utils.py b/app/agent/tools/impl/_system_setting_utils.py index 774217f5..b9b131a1 100644 --- a/app/agent/tools/impl/_system_setting_utils.py +++ b/app/agent/tools/impl/_system_setting_utils.py @@ -1,7 +1,7 @@ """系统设置工具共用的键解析与分组元数据。""" from dataclasses import dataclass -from typing import Optional +from typing import Any, Optional from app.core.config import Settings from app.schemas.types import SystemConfigKey @@ -15,6 +15,7 @@ class SettingSpec: source: str group: str label: str + systemconfig_key: Optional[SystemConfigKey] = None SYSTEMCONFIG_SETTING_METADATA = { @@ -234,6 +235,7 @@ def _build_specs() -> tuple[dict[str, SettingSpec], dict[str, SettingSpec]]: source="systemconfig", group=metadata.get("group", "misc"), label=metadata.get("label", item.value), + systemconfig_key=item, ) return core_specs, system_specs @@ -333,3 +335,57 @@ def list_setting_specs( def get_default_list_match_field(setting_key: str) -> Optional[str]: return LIST_ITEM_MATCH_FIELD_DEFAULTS.get(setting_key) + + +SECRET_KEYWORDS = ( + "api_key", + "apikey", + "token", + "secret", + "password", + "passwd", + "cookie", + "authorization", + "refresh_token", + "access_token", +) + + +def is_secret_setting_key(key: str) -> bool: + """判断设置键名是否疑似敏感字段。""" + normalized = _normalize_token(key) + return any(keyword in normalized for keyword in SECRET_KEYWORDS) + + +def redact_secret_value(value: Any, *, redact_scalar: bool = False) -> Any: + """递归脱敏配置值中的密钥、Cookie、Token 等敏感字段。""" + if isinstance(value, dict): + return { + key: "***" + if is_secret_setting_key(str(key)) + else redact_secret_value(item, redact_scalar=redact_scalar) + for key, item in value.items() + } + if isinstance(value, list): + return [ + redact_secret_value(item, redact_scalar=redact_scalar) + for item in value + ] + if isinstance(value, str): + return "***" if value and redact_scalar else value + return value + + +def should_redact_setting(spec: SettingSpec, value: Any) -> bool: + """判断某项设置在默认查询响应中是否需要脱敏。""" + if is_secret_setting_key(spec.key): + return True + if isinstance(value, dict): + return any(is_secret_setting_key(str(key)) for key in value.keys()) + if isinstance(value, list): + return any( + should_redact_setting(spec, item) + for item in value + if isinstance(item, dict) + ) + return False diff --git a/app/agent/tools/impl/_terminal_session.py b/app/agent/tools/impl/_terminal_session.py index 55e0d91e..ba265057 100644 --- a/app/agent/tools/impl/_terminal_session.py +++ b/app/agent/tools/impl/_terminal_session.py @@ -13,6 +13,7 @@ from dataclasses import dataclass, field from pathlib import Path from typing import Any, Optional +from app.agent.tools.impl._command_safety import validate_command_safety from app.core.config import settings from app.log import logger @@ -34,14 +35,6 @@ TERMINAL_PTY_POLL_INTERVAL = 0.05 TERMINAL_WAIT_DEFAULT_MS = 1000 TERMINAL_WAIT_MAX_MS = 60 * 1000 TERMINAL_KILL_GRACE_SECONDS = 3 -TERMINAL_FORBIDDEN_KEYWORDS = ( - "rm -rf /", - ":(){ :|:& };:", - "dd if=/dev/zero", - "mkfs", - "reboot", - "shutdown", -) @dataclass @@ -176,13 +169,9 @@ class _TerminalSessionManager: return merged_env @staticmethod - def _validate_command(command: str) -> None: + def _validate_command(command: str, *, confirmed: bool = False) -> None: """拒绝明显危险或空白命令。""" - if not command or not command.strip(): - raise ValueError("命令不能为空") - for keyword in TERMINAL_FORBIDDEN_KEYWORDS: - if keyword in command: - raise ValueError(f"命令包含禁止使用的关键字 '{keyword}'") + validate_command_safety(command, confirmed=confirmed) @staticmethod def _set_nonblocking(fd: int) -> None: @@ -213,9 +202,10 @@ class _TerminalSessionManager: cwd: Optional[str] = None, env: Optional[dict[str, Any]] = None, use_pty: Any = True, + confirm_dangerous: bool = False, ) -> dict[str, Any]: """启动后台命令并立即返回会话 ID。""" - self._validate_command(command) + self._validate_command(command, confirmed=confirm_dangerous) normalized_cwd = self._normalize_cwd(cwd) normalized_env = self._build_env(env) should_use_pty = self._normalize_bool(use_pty, default=True) and os.name == "posix" @@ -313,7 +303,10 @@ class _TerminalSessionManager: continue except OSError as err: if err.errno not in {errno.EIO, errno.EBADF}: - logger.debug("PTY 输出读取异常: session_id=%s, error=%s", session.session_id, err) + logger.debug( + f"PTY 输出读取异常: session_id={session.session_id}, " + f"error={err}" + ) break if not data: @@ -343,7 +336,9 @@ class _TerminalSessionManager: session.mark_finished(session.exit_code) except Exception as err: session.mark_error(str(err)) - logger.warning("等待 PTY 进程失败: session_id=%s, error=%s", session.session_id, err) + logger.warning( + f"等待 PTY 进程失败: session_id={session.session_id}, error={err}" + ) finally: await self._finish_reader_tasks(session) session.close_pty() @@ -358,7 +353,9 @@ class _TerminalSessionManager: session.mark_finished(exit_code) except Exception as err: session.mark_error(str(err)) - logger.warning("等待管道进程失败: session_id=%s, error=%s", session.session_id, err) + logger.warning( + f"等待管道进程失败: session_id={session.session_id}, error={err}" + ) finally: await self._finish_reader_tasks(session) diff --git a/app/agent/tools/impl/browse_webpage.py b/app/agent/tools/impl/browse_webpage.py index c1e657a6..d576f573 100644 --- a/app/agent/tools/impl/browse_webpage.py +++ b/app/agent/tools/impl/browse_webpage.py @@ -228,6 +228,11 @@ class BrowseWebpageTool(MoviePilotTool): return "错误: 'fill_ref' 操作需要提供 value 参数" if browser_action == BrowserAction.EVALUATE and not script: return "错误: 'evaluate' 操作需要提供 script 参数" + if ( + browser_action == BrowserAction.EVALUATE + and not await self.is_admin_user() + ): + return "错误: 'evaluate' 操作仅允许管理员使用" if ( browser_action in (BrowserAction.FOCUS_TAB, BrowserAction.CLOSE_TAB) and tab_index is None diff --git a/app/agent/tools/impl/delete_download_history.py b/app/agent/tools/impl/delete_download_history.py index f2c1fb26..5a3753d0 100644 --- a/app/agent/tools/impl/delete_download_history.py +++ b/app/agent/tools/impl/delete_download_history.py @@ -6,8 +6,7 @@ from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool from app.agent.tools.tags import ToolTag -from app.db import AsyncSessionFactory -from app.db.models.downloadhistory import DownloadHistory +from app.db.downloadhistory_oper import DownloadHistoryOper from app.log import logger @@ -40,9 +39,8 @@ class DeleteDownloadHistoryTool(MoviePilotTool): logger.info(f"执行工具: {self.name}, 参数: history_id={history_id}") try: - async with AsyncSessionFactory() as db: - await DownloadHistory.async_delete(db, history_id) - return f"下载历史记录 ID: {history_id} 已成功删除" + await DownloadHistoryOper().async_delete_history(history_id) + return f"下载历史记录 ID: {history_id} 已成功删除" except Exception as e: logger.error(f"删除下载历史记录失败: {e}", exc_info=True) return f"删除下载历史记录时发生错误: {str(e)}" diff --git a/app/agent/tools/impl/edit_file.py b/app/agent/tools/impl/edit_file.py index 3c44b275..581d0e34 100644 --- a/app/agent/tools/impl/edit_file.py +++ b/app/agent/tools/impl/edit_file.py @@ -12,7 +12,7 @@ from app.log import logger class EditFileInput(BaseModel): - """Input parameters for edit file tool""" + """文件编辑工具的输入参数模型。""" file_path: str = Field(..., description="The absolute path of the file to edit") old_text: str = Field(..., description="The exact old text to be replaced") @@ -27,8 +27,8 @@ class EditFileTool(MoviePilotTool): ] description: str = ( "Edit a local text file by replacing specific old text with new text. " - "Non-admin users can only edit files inside the MoviePilot config, " - "Agent memory/activity, and log directories." + "Non-admin users can only edit files inside the MoviePilot Agent config " + "and log directories." ) args_schema: Type[BaseModel] = EditFileInput diff --git a/app/agent/tools/impl/execute_command.py b/app/agent/tools/impl/execute_command.py index c638a17c..18127d1b 100644 --- a/app/agent/tools/impl/execute_command.py +++ b/app/agent/tools/impl/execute_command.py @@ -13,6 +13,7 @@ from typing import Any, Literal, Optional, TextIO, Type from pydantic import BaseModel, Field +from app.agent.tools.impl._command_safety import validate_command_safety from app.agent.tools.base import MoviePilotTool from app.agent.tools.tags import ToolTag from app.agent.tools.impl._terminal_session import ( @@ -30,14 +31,6 @@ MAX_OUTPUT_PREVIEW_BYTES = 10 * 1024 READ_CHUNK_SIZE = 4096 KILL_GRACE_SECONDS = 3 COMMAND_CONCURRENCY_LIMIT = 2 -COMMAND_FORBIDDEN_KEYWORDS = ( - ":(){ :|:& };:", - "dd if=/dev/zero", - "mkfs", - "reboot", - "shutdown", -) - _command_semaphore = asyncio.Semaphore(COMMAND_CONCURRENCY_LIMIT) @@ -195,6 +188,13 @@ class ExecuteCommandInput(BaseModel): 60, description="For action=run, max execution time in seconds.", ) + confirm_dangerous: Optional[bool] = Field( + False, + description=( + "Explicit confirmation for high-risk commands such as recursive root deletion, " + "disk formatting, shutdown/reboot, or destructive permission changes." + ), + ) class ExecuteCommandTool(MoviePilotTool): @@ -255,34 +255,9 @@ class ExecuteCommandTool(MoviePilotTool): return command @staticmethod - def _validate_command(command: str) -> None: + def _validate_command(command: str, *, confirmed: bool = False) -> None: """复用旧工具的基础危险命令过滤,避免明显破坏性命令进入 shell。""" - for keyword in COMMAND_FORBIDDEN_KEYWORDS: - if keyword in command: - raise ValueError(f"命令包含禁止使用的关键字 '{keyword}'") - - # 检查是否使用了 rm -r/R 删除根目录或一级目录,防止误杀多级目录 - import re - import os.path - tokens = re.split(r'\s+', command.strip()) - if any(t == "rm" or t.endswith("/rm") for t in tokens): - has_r = False - for token in tokens: - if token.startswith("-") and ("r" in token or "R" in token): - has_r = True - break - - if has_r: - for token in tokens: - # 提取可能包含目标路径的部分(去除重定向、管道、分号等末尾干扰) - m = re.match(r'^([^;\|&><]+)', token) - if m: - clean_token = m.group(1).strip('"\'') - # 仅对绝对路径进行一级目录限制 - if clean_token.startswith('/'): - norm_path = os.path.normpath(clean_token) - if re.match(r'^/[^/]*$', norm_path) or re.match(r'^/[^/]*/$', norm_path): - raise ValueError(f"不允许使用 rm 命令删除根目录或一级目录: {clean_token}") + validate_command_safety(command, confirmed=confirmed) @staticmethod def _normalize_timeout(timeout: Optional[int]) -> tuple[int, Optional[str]]: @@ -367,7 +342,7 @@ class ExecuteCommandTool(MoviePilotTool): asyncio.shield(wait_task), timeout=KILL_GRACE_SECONDS ) except asyncio.TimeoutError: - logger.warning("命令进程强制清理超时: pid=%s", process.pid) + logger.warning(f"命令进程强制清理超时: pid={process.pid}") @staticmethod async def _finish_reader_tasks(reader_tasks: list[asyncio.Task]) -> None: @@ -382,7 +357,7 @@ class ExecuteCommandTool(MoviePilotTool): if isinstance(result, Exception) and not isinstance( result, asyncio.CancelledError ): - logger.debug("命令输出读取任务异常: %s", result) + logger.debug(f"命令输出读取任务异常: {result}") @staticmethod def _format_run_result( @@ -425,9 +400,10 @@ class ExecuteCommandTool(MoviePilotTool): command: str, timeout: Optional[int], cwd: Optional[str] = None, + confirm_dangerous: bool = False, ) -> str: """按旧模式一次性执行命令,等待完成或超时后返回文本结果。""" - self._validate_command(command) + self._validate_command(command, confirmed=confirm_dangerous) normalized_timeout, timeout_note = self._normalize_timeout(timeout) async with _command_semaphore: @@ -482,27 +458,29 @@ class ExecuteCommandTool(MoviePilotTool): max_bytes: Optional[int] = TERMINAL_DEFAULT_READ_BYTES, timeout_ms: Optional[int] = TERMINAL_WAIT_DEFAULT_MS, timeout: Optional[int] = 60, + confirm_dangerous: Optional[bool] = False, **kwargs, ) -> str: """执行命令动作:默认后台启动,也支持读取、等待、写入、终止和一次性执行。""" normalized_action = (action or "start").strip().lower() logger.info( - "执行工具: %s, action=%s, command=%s, session_id=%s", - self.name, - normalized_action, - command, - session_id, + f"执行工具: {self.name}, action={normalized_action}, " + f"command={command}, session_id={session_id}" ) try: if normalized_action == "start": start_command = self._require_command(command) - self._validate_command(start_command) + self._validate_command( + start_command, + confirmed=bool(confirm_dangerous), + ) payload = await terminal_session_manager.start( command=start_command, cwd=cwd, env=env, use_pty=use_pty, + confirm_dangerous=bool(confirm_dangerous), ) return self._dump(payload) @@ -542,9 +520,10 @@ class ExecuteCommandTool(MoviePilotTool): command=self._require_command(command), timeout=timeout, cwd=cwd, + confirm_dangerous=bool(confirm_dangerous), ) raise ValueError(f"不支持的 action: {action}") except Exception as err: - logger.error("执行命令 action 失败: %s", err, exc_info=True) + logger.error(f"执行命令 action 失败: {err}", exc_info=True) return self._dump({"error": str(err), "status": "error", "action": normalized_action}) diff --git a/app/agent/tools/impl/query_site_userdata.py b/app/agent/tools/impl/query_site_userdata.py index 8566166f..29def856 100644 --- a/app/agent/tools/impl/query_site_userdata.py +++ b/app/agent/tools/impl/query_site_userdata.py @@ -7,9 +7,7 @@ from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool from app.agent.tools.tags import ToolTag -from app.db import AsyncSessionFactory -from app.db.models.site import Site -from app.db.models.siteuserdata import SiteUserData +from app.db.site_oper import SiteOper from app.log import logger SITE_USERDATA_DETAIL_PREVIEW_LIMIT = 10 @@ -66,118 +64,115 @@ class QuerySiteUserdataTool(MoviePilotTool): ) try: - # 获取数据库会话 - async with AsyncSessionFactory() as db: - # 获取站点 - site = await Site.async_get(db, site_id) - if not site: - return json.dumps( - {"success": False, "message": f"站点不存在: {site_id}"}, - ensure_ascii=False, - ) - - # 获取站点用户数据 - user_data_list = await SiteUserData.async_get_by_domain( - db, domain=site.domain, workdate=workdate + site_oper = SiteOper() + site = await site_oper.async_get(site_id) + if not site: + return json.dumps( + {"success": False, "message": f"站点不存在: {site_id}"}, + ensure_ascii=False, ) - if not user_data_list: - return json.dumps( - { - "success": False, - "message": f"站点 {site.name} ({site.domain}) 暂无用户数据", - "site_id": site_id, - "site_name": site.name, - "site_domain": site.domain, - "workdate": workdate, - }, - ensure_ascii=False, - ) + user_data_list = await site_oper.async_get_userdata_by_domain( + domain=site.domain, workdate=workdate + ) - # 格式化用户数据 - result = { - "success": True, - "site_id": site_id, - "site_name": site.name, - "site_domain": site.domain, - "workdate": workdate, - "data_count": len(user_data_list), - "user_data": [], + if not user_data_list: + return json.dumps( + { + "success": False, + "message": f"站点 {site.name} ({site.domain}) 暂无用户数据", + "site_id": site_id, + "site_name": site.name, + "site_domain": site.domain, + "workdate": workdate, + }, + ensure_ascii=False, + ) + + # 格式化用户数据 + result = { + "success": True, + "site_id": site_id, + "site_name": site.name, + "site_domain": site.domain, + "workdate": workdate, + "data_count": len(user_data_list), + "user_data": [], + } + + for user_data in user_data_list: + # 格式化上传/下载量(转换为可读格式) + upload_gb = user_data.upload / (1024**3) if user_data.upload else 0 + download_gb = ( + user_data.download / (1024**3) if user_data.download else 0 + ) + seeding_size_gb = ( + user_data.seeding_size / (1024**3) + if user_data.seeding_size + else 0 + ) + leeching_size_gb = ( + user_data.leeching_size / (1024**3) + if user_data.leeching_size + else 0 + ) + + seeding_preview, seeding_count, seeding_truncated = _preview_list( + user_data.seeding_info + ) + unread_preview, unread_count, unread_truncated = _preview_list( + user_data.message_unread_contents + ) + + user_data_dict = { + "domain": user_data.domain, + "name": user_data.name, + "username": user_data.username, + "userid": user_data.userid, + "user_level": user_data.user_level, + "join_at": user_data.join_at, + "bonus": user_data.bonus, + "upload": user_data.upload, + "upload_gb": round(upload_gb, 2), + "download": user_data.download, + "download_gb": round(download_gb, 2), + "ratio": round(user_data.ratio, 2) if user_data.ratio else 0, + "seeding": int(user_data.seeding) if user_data.seeding else 0, + "leeching": int(user_data.leeching) + if user_data.leeching + else 0, + "seeding_size": user_data.seeding_size, + "seeding_size_gb": round(seeding_size_gb, 2), + "leeching_size": user_data.leeching_size, + "leeching_size_gb": round(leeching_size_gb, 2), + "seeding_info_count": seeding_count, + "seeding_info": seeding_preview, + "seeding_info_truncated": seeding_truncated, + "message_unread": user_data.message_unread, + "message_unread_contents_count": unread_count, + "message_unread_contents": unread_preview, + "message_unread_contents_truncated": unread_truncated, + "err_msg": user_data.err_msg, + "updated_day": user_data.updated_day, + "updated_time": user_data.updated_time, } + result["user_data"].append(user_data_dict) - for user_data in user_data_list: - # 格式化上传/下载量(转换为可读格式) - upload_gb = user_data.upload / (1024**3) if user_data.upload else 0 - download_gb = ( - user_data.download / (1024**3) if user_data.download else 0 - ) - seeding_size_gb = ( - user_data.seeding_size / (1024**3) - if user_data.seeding_size - else 0 - ) - leeching_size_gb = ( - user_data.leeching_size / (1024**3) - if user_data.leeching_size - else 0 - ) + # 如果有多条数据,只返回最新的(按更新时间排序) + if len(result["user_data"]) > 1: + result["user_data"].sort( + key=lambda x: ( + x.get("updated_day", ""), + x.get("updated_time", ""), + ), + reverse=True, + ) + result["message"] = ( + f"找到 {len(result['user_data'])} 条数据,显示最新的一条" + ) + result["user_data"] = [result["user_data"][0]] - seeding_preview, seeding_count, seeding_truncated = _preview_list( - user_data.seeding_info - ) - unread_preview, unread_count, unread_truncated = _preview_list( - user_data.message_unread_contents - ) - - user_data_dict = { - "domain": user_data.domain, - "name": user_data.name, - "username": user_data.username, - "userid": user_data.userid, - "user_level": user_data.user_level, - "join_at": user_data.join_at, - "bonus": user_data.bonus, - "upload": user_data.upload, - "upload_gb": round(upload_gb, 2), - "download": user_data.download, - "download_gb": round(download_gb, 2), - "ratio": round(user_data.ratio, 2) if user_data.ratio else 0, - "seeding": int(user_data.seeding) if user_data.seeding else 0, - "leeching": int(user_data.leeching) - if user_data.leeching - else 0, - "seeding_size": user_data.seeding_size, - "seeding_size_gb": round(seeding_size_gb, 2), - "leeching_size": user_data.leeching_size, - "leeching_size_gb": round(leeching_size_gb, 2), - "seeding_info_count": seeding_count, - "seeding_info": seeding_preview, - "seeding_info_truncated": seeding_truncated, - "message_unread": user_data.message_unread, - "message_unread_contents_count": unread_count, - "message_unread_contents": unread_preview, - "message_unread_contents_truncated": unread_truncated, - "err_msg": user_data.err_msg, - "updated_day": user_data.updated_day, - "updated_time": user_data.updated_time, - } - result["user_data"].append(user_data_dict) - - # 如果有多条数据,只返回最新的(按更新时间排序) - if len(result["user_data"]) > 1: - result["user_data"].sort( - key=lambda x: ( - x.get("updated_day", ""), - x.get("updated_time", ""), - ), - reverse=True, - ) - result["message"] = ( - f"找到 {len(result['user_data'])} 条数据,显示最新的一条" - ) - result["user_data"] = [result["user_data"][0]] - - return json.dumps(result, ensure_ascii=False, indent=2) + return json.dumps(result, ensure_ascii=False, indent=2) except Exception as e: error_message = f"查询站点用户数据失败: {str(e)}" diff --git a/app/agent/tools/impl/query_subscribe_history.py b/app/agent/tools/impl/query_subscribe_history.py index b9fd8187..cee03d9c 100644 --- a/app/agent/tools/impl/query_subscribe_history.py +++ b/app/agent/tools/impl/query_subscribe_history.py @@ -7,8 +7,7 @@ from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool from app.agent.tools.tags import ToolTag -from app.db import AsyncSessionFactory -from app.db.models.subscribehistory import SubscribeHistory +from app.db.subscribehistory_oper import SubscribeHistoryOper from app.log import logger from app.schemas.types import media_type_to_agent @@ -74,88 +73,87 @@ class QuerySubscribeHistoryTool(MoviePilotTool): if media_type not in ["all", "movie", "tv"]: return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv', 'all'" - # 获取数据库会话 - async with AsyncSessionFactory() as db: - if name: - # 有名称过滤时,获取足够多的记录在内存中过滤,不分页 - fetch_count = 500 - if media_type == "all": - movie_history = await SubscribeHistory.async_list_by_type( - db, mtype="movie", page=1, count=fetch_count - ) - tv_history = await SubscribeHistory.async_list_by_type( - db, mtype="tv", page=1, count=fetch_count - ) - all_history = list(movie_history) + list(tv_history) - all_history.sort(key=lambda x: x.date or "", reverse=True) - else: - all_history = list( - await SubscribeHistory.async_list_by_type( - db, mtype=media_type, page=1, count=fetch_count - ) - ) - - # 按名称过滤 - name_lower = name.lower() - filtered_history = [ - record - for record in all_history - if record.name and name_lower in record.name.lower() - ] - - if not filtered_history: - return "未找到相关订阅历史记录" - - # 名称过滤时直接返回所有匹配结果,不分页 - simplified_records = self._simplify_records(filtered_history) - result_json = json.dumps( - simplified_records, ensure_ascii=False, indent=2 + subscribe_history_oper = SubscribeHistoryOper() + if name: + # 有名称过滤时,获取足够多的记录在内存中过滤,不分页 + fetch_count = 500 + if media_type == "all": + movie_history = await subscribe_history_oper.async_list_by_type( + mtype="movie", page=1, count=fetch_count ) - return result_json + tv_history = await subscribe_history_oper.async_list_by_type( + mtype="tv", page=1, count=fetch_count + ) + all_history = list(movie_history) + list(tv_history) + all_history.sort(key=lambda x: x.date or "", reverse=True) else: - # 无名称过滤时,直接利用数据库分页 - if media_type == "all": - movie_history = await SubscribeHistory.async_list_by_type( - db, mtype="movie", page=1, count=page * PAGE_SIZE - ) - tv_history = await SubscribeHistory.async_list_by_type( - db, mtype="tv", page=1, count=page * PAGE_SIZE - ) - all_history = list(movie_history) + list(tv_history) - all_history.sort(key=lambda x: x.date or "", reverse=True) - filtered_history = all_history - else: - filtered_history = list( - await SubscribeHistory.async_list_by_type( - db, mtype=media_type, page=1, count=page * PAGE_SIZE - ) + all_history = list( + await subscribe_history_oper.async_list_by_type( + mtype=media_type, page=1, count=fetch_count ) + ) + + # 按名称过滤 + name_lower = name.lower() + filtered_history = [ + record + for record in all_history + if record.name and name_lower in record.name.lower() + ] if not filtered_history: return "未找到相关订阅历史记录" - # 分页切片 - total_count = len(filtered_history) - start = (page - 1) * PAGE_SIZE - end = start + PAGE_SIZE - page_records = filtered_history[start:end] - - if not page_records: - return f"第 {page} 页没有数据。" - - simplified_records = self._simplify_records(page_records) + # 名称过滤时直接返回所有匹配结果,不分页 + simplified_records = self._simplify_records(filtered_history) result_json = json.dumps( simplified_records, ensure_ascii=False, indent=2 ) - - has_more = total_count > end - payload_msg = f"第 {page} 页,当前页 {len(simplified_records)} 条结果。" - if has_more: - payload_msg += ( - f" 可能有更多数据,可使用 page={page + 1} 获取下一页。" + return result_json + else: + # 无名称过滤时,直接利用数据库分页 + if media_type == "all": + movie_history = await subscribe_history_oper.async_list_by_type( + mtype="movie", page=1, count=page * PAGE_SIZE + ) + tv_history = await subscribe_history_oper.async_list_by_type( + mtype="tv", page=1, count=page * PAGE_SIZE + ) + all_history = list(movie_history) + list(tv_history) + all_history.sort(key=lambda x: x.date or "", reverse=True) + filtered_history = all_history + else: + filtered_history = list( + await subscribe_history_oper.async_list_by_type( + mtype=media_type, page=1, count=page * PAGE_SIZE + ) ) - return f"{payload_msg}\n\n{result_json}" + if not filtered_history: + return "未找到相关订阅历史记录" + + # 分页切片 + total_count = len(filtered_history) + start = (page - 1) * PAGE_SIZE + end = start + PAGE_SIZE + page_records = filtered_history[start:end] + + if not page_records: + return f"第 {page} 页没有数据。" + + simplified_records = self._simplify_records(page_records) + result_json = json.dumps( + simplified_records, ensure_ascii=False, indent=2 + ) + + has_more = total_count > end + payload_msg = f"第 {page} 页,当前页 {len(simplified_records)} 条结果。" + if has_more: + payload_msg += ( + f" 可能有更多数据,可使用 page={page + 1} 获取下一页。" + ) + + return f"{payload_msg}\n\n{result_json}" except Exception as e: logger.error(f"查询订阅历史失败: {e}", exc_info=True) return f"查询订阅历史时发生错误: {str(e)}" diff --git a/app/agent/tools/impl/query_system_settings.py b/app/agent/tools/impl/query_system_settings.py index 71284cc4..5b7417e7 100644 --- a/app/agent/tools/impl/query_system_settings.py +++ b/app/agent/tools/impl/query_system_settings.py @@ -9,8 +9,11 @@ from app.agent.tools.base import MoviePilotTool from app.agent.tools.tags import ToolTag from app.agent.tools.impl._system_setting_utils import ( SettingSpec, + is_secret_setting_key, list_setting_specs, + redact_secret_value, resolve_setting_spec, + should_redact_setting, ) from app.core.config import settings from app.db.systemconfig_oper import SystemConfigOper @@ -53,6 +56,13 @@ class QuerySystemSettingsInput(BaseModel): "when multiple settings are matched it returns summaries only unless this is explicitly set to true." ), ) + show_secrets: Optional[bool] = Field( + False, + description=( + "Whether to return raw secret values such as API keys, tokens, cookies, and passwords. " + "Defaults to false; secret-like fields are redacted in returned values and previews." + ), + ) class QuerySystemSettingsTool(MoviePilotTool): @@ -85,15 +95,18 @@ class QuerySystemSettingsTool(MoviePilotTool): @staticmethod def _load_setting_value(spec: SettingSpec): + """读取指定设置项的当前值。""" if spec.source == "settings": return getattr(settings, spec.key) - return SystemConfigOper().get(spec.key) + return SystemConfigOper().get(spec.systemconfig_key) @staticmethod - def _summarize_value(value) -> dict: + def _summarize_value(value, *, redacted: bool = False) -> dict: + """生成设置值摘要,避免列表和字典默认输出过长。""" summary = { "has_value": value is not None, "value_type": type(value).__name__, + "redacted": redacted, } if isinstance(value, list): summary["item_count"] = len(value) @@ -122,14 +135,12 @@ class QuerySystemSettingsTool(MoviePilotTool): group: Optional[str] = "all", keyword: Optional[str] = None, include_values: Optional[bool] = None, + show_secrets: Optional[bool] = False, **kwargs, ) -> str: logger.info( - "执行工具: %s, setting_key=%s, group=%s, keyword=%s", - self.name, - setting_key, - group, - keyword, + f"执行工具: {self.name}, setting_key={setting_key}, " + f"group={group}, keyword={keyword}" ) try: @@ -158,18 +169,30 @@ class QuerySystemSettingsTool(MoviePilotTool): should_include_values = ( include_values if include_values is not None else len(specs) == 1 ) + allow_secret_values = bool(show_secrets) and await self.is_admin_user() settings_payload = [] for spec in specs: value = self._load_setting_value(spec) + should_redact = ( + should_redact_setting(spec, value) and not allow_secret_values + ) + response_value = ( + redact_secret_value( + value, + redact_scalar=is_secret_setting_key(spec.key), + ) + if should_redact + else value + ) item = { "setting_key": spec.key, "source": spec.source, "group": spec.group, "label": spec.label, } - item.update(self._summarize_value(value)) + item.update(self._summarize_value(response_value, redacted=should_redact)) if should_include_values: - item["value"] = value + item["value"] = response_value settings_payload.append(item) return json.dumps( @@ -177,6 +200,7 @@ class QuerySystemSettingsTool(MoviePilotTool): "success": True, "matched_count": len(settings_payload), "include_values": should_include_values, + "show_secrets": allow_secret_values, "settings": settings_payload, }, ensure_ascii=False, diff --git a/app/agent/tools/impl/query_transfer_history.py b/app/agent/tools/impl/query_transfer_history.py index 821cb7e6..1c3d1461 100644 --- a/app/agent/tools/impl/query_transfer_history.py +++ b/app/agent/tools/impl/query_transfer_history.py @@ -7,8 +7,7 @@ from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool from app.agent.tools.tags import ToolTag -from app.db import AsyncSessionFactory -from app.db.models.transferhistory import TransferHistory +from app.db.transferhistory_oper import TransferHistoryOper from app.log import logger from app.schemas.types import media_type_to_agent from app.utils.jieba import cut as jieba_cut @@ -70,70 +69,69 @@ class QueryTransferHistoryTool(MoviePilotTool): # 每页固定 30 条,与工具说明保持一致,避免整理路径等字段撑大上下文。 count = 30 - # 获取数据库会话 - async with AsyncSessionFactory() as db: - # 处理标题搜索 - if title: - # 使用统一分词封装处理标题,便于替换底层实现。 - words = jieba_cut(title, HMM=False) - title_search = "%".join(words) - # 查询记录 - result = await TransferHistory.async_list_by_title( - db, title=title_search, page=page, count=count, status=status_bool - ) - total = await TransferHistory.async_count_by_title( - db, title=title_search, status=status_bool - ) - else: - # 查询所有记录 - result = await TransferHistory.async_list_by_page( - db, page=page, count=count, status=status_bool - ) - total = await TransferHistory.async_count(db, status=status_bool) + transferhis = TransferHistoryOper() + # 处理标题搜索 + if title: + # 使用统一分词封装处理标题,便于替换底层实现。 + words = jieba_cut(title, HMM=False) + title_search = "%".join(words) + # 查询记录 + result = await transferhis.async_list_by_title( + title=title_search, page=page, count=count, status=status_bool + ) + total = await transferhis.async_count_by_title( + title=title_search, status=status_bool + ) + else: + # 查询所有记录 + result = await transferhis.async_list_by_page( + page=page, count=count, status=status_bool + ) + total = await transferhis.async_count(status=status_bool) - if not result: - return "未找到相关整理历史记录" + if not result: + return "未找到相关整理历史记录" - # 转换为字典格式,只保留关键信息 - simplified_records = [] - for record in result: - simplified = { - "id": record.id, - "title": record.title, - "year": record.year, - "type": media_type_to_agent(record.type), - "category": record.category, - "seasons": record.seasons, - "episodes": record.episodes, - "src": record.src, - "dest": record.dest, - "mode": record.mode, - "status": "成功" if record.status else "失败", - "date": record.date, - "downloader": record.downloader, - "download_hash": record.download_hash - } - # 如果失败,添加错误信息 - if not record.status and record.errmsg: - simplified["errmsg"] = record.errmsg - # 添加媒体ID信息(如果有) - if record.tmdbid: - simplified["tmdbid"] = record.tmdbid - if record.imdbid: - simplified["imdbid"] = record.imdbid - if record.doubanid: - simplified["doubanid"] = record.doubanid - simplified_records.append(simplified) + # 转换为字典格式,只保留关键信息 + simplified_records = [] + for record in result: + simplified = { + "id": record.id, + "title": record.title, + "year": record.year, + "type": media_type_to_agent(record.type), + "category": record.category, + "seasons": record.seasons, + "episodes": record.episodes, + "src": record.src, + "dest": record.dest, + "mode": record.mode, + "status": "成功" if record.status else "失败", + "date": record.date, + "downloader": record.downloader, + "download_hash": record.download_hash + } + # 如果失败,添加错误信息 + if not record.status and record.errmsg: + simplified["errmsg"] = record.errmsg + # 添加媒体ID信息(如果有) + if record.tmdbid: + simplified["tmdbid"] = record.tmdbid + if record.imdbid: + simplified["imdbid"] = record.imdbid + if record.doubanid: + simplified["doubanid"] = record.doubanid + simplified_records.append(simplified) - result_json = json.dumps(simplified_records, ensure_ascii=False, indent=2) + result_json = json.dumps(simplified_records, ensure_ascii=False, indent=2) - # 计算总页数 - total_pages = (total + count - 1) // count if total > 0 else 1 + # 计算总页数 + total_pages = (total + count - 1) // count if total > 0 else 1 - # 构建分页信息 - pagination_info = f"第 {page}/{total_pages} 页,共 {total} 条记录(每页 {count} 条)" + # 构建分页信息 + pagination_info = f"第 {page}/{total_pages} 页,共 {total} 条记录(每页 {count} 条)" - return f"{pagination_info}\n\n{result_json}" + return f"{pagination_info}\n\n{result_json}" except Exception as e: logger.error(f"查询整理历史记录失败: {e}", exc_info=True) return f"查询整理历史记录时发生错误: {str(e)}" diff --git a/app/agent/tools/impl/query_workflows.py b/app/agent/tools/impl/query_workflows.py index f6d8289a..964a58fb 100644 --- a/app/agent/tools/impl/query_workflows.py +++ b/app/agent/tools/impl/query_workflows.py @@ -7,7 +7,6 @@ from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool from app.agent.tools.tags import ToolTag -from app.db import AsyncSessionFactory from app.db.workflow_oper import WorkflowOper from app.log import logger @@ -56,75 +55,73 @@ class QueryWorkflowsTool(MoviePilotTool): logger.info(f"执行工具: {self.name}, 参数: state={state}, name={name}, trigger_type={trigger_type}") try: - # 获取数据库会话 - async with AsyncSessionFactory() as db: - workflow_oper = WorkflowOper(db) - workflows = await workflow_oper.async_list() - - # 过滤工作流 - filtered_workflows = [] - for wf in workflows: - # 按状态过滤 - if state != "all" and wf.state != state: + workflow_oper = WorkflowOper() + workflows = await workflow_oper.async_list() + + # 过滤工作流 + filtered_workflows = [] + for wf in workflows: + # 按状态过滤 + if state != "all" and wf.state != state: + continue + + # 按触发类型过滤 + if trigger_type != "all": + if trigger_type == "timer" and wf.trigger_type not in ["timer", None]: continue - - # 按触发类型过滤 - if trigger_type != "all": - if trigger_type == "timer" and wf.trigger_type not in ["timer", None]: - continue - elif trigger_type == "event" and wf.trigger_type != "event": - continue - elif trigger_type == "manual" and wf.trigger_type != "manual": - continue - - # 按名称过滤(部分匹配) - if name and wf.name and name.lower() not in wf.name.lower(): + elif trigger_type == "event" and wf.trigger_type != "event": continue - - filtered_workflows.append(wf) - - if not filtered_workflows: - return "未找到相关工作流" - - # 转换为字典格式,只保留关键信息 - simplified_workflows = [] - for wf in filtered_workflows: - # 状态说明 - state_map = { - "W": "等待", - "R": "运行中", - "P": "暂停", - "S": "成功", - "F": "失败" - } - state_desc = state_map.get(wf.state, wf.state) - - # 触发类型说明 - trigger_type_map = { - "timer": "定时触发", - "event": "事件触发", - "manual": "手动触发" - } - trigger_type_desc = trigger_type_map.get(wf.trigger_type, wf.trigger_type or "定时触发") - - simplified = { - "id": wf.id, - "name": wf.name, - "description": wf.description, - "trigger_type": trigger_type_desc, - "state": state_desc, - "run_count": wf.run_count, - "timer": wf.timer, - "event_type": wf.event_type, - "add_time": wf.add_time, - "last_time": wf.last_time, - "current_action": wf.current_action - } - # wf.result 往往是执行日志或上下文快照,不适合作为列表查询结果返回。 - simplified_workflows.append(simplified) - - result_json = json.dumps(simplified_workflows, ensure_ascii=False, indent=2) - return result_json + elif trigger_type == "manual" and wf.trigger_type != "manual": + continue + + # 按名称过滤(部分匹配) + if name and wf.name and name.lower() not in wf.name.lower(): + continue + + filtered_workflows.append(wf) + + if not filtered_workflows: + return "未找到相关工作流" + + # 转换为字典格式,只保留关键信息 + simplified_workflows = [] + for wf in filtered_workflows: + # 状态说明 + state_map = { + "W": "等待", + "R": "运行中", + "P": "暂停", + "S": "成功", + "F": "失败" + } + state_desc = state_map.get(wf.state, wf.state) + + # 触发类型说明 + trigger_type_map = { + "timer": "定时触发", + "event": "事件触发", + "manual": "手动触发" + } + trigger_type_desc = trigger_type_map.get(wf.trigger_type, wf.trigger_type or "定时触发") + + simplified = { + "id": wf.id, + "name": wf.name, + "description": wf.description, + "trigger_type": trigger_type_desc, + "state": state_desc, + "run_count": wf.run_count, + "timer": wf.timer, + "event_type": wf.event_type, + "add_time": wf.add_time, + "last_time": wf.last_time, + "current_action": wf.current_action + } + # wf.result 往往是执行日志或上下文快照,不适合作为列表查询结果返回。 + simplified_workflows.append(simplified) + + result_json = json.dumps(simplified_workflows, ensure_ascii=False, indent=2) + return result_json except Exception as e: logger.error(f"查询工作流失败: {e}", exc_info=True) return f"查询工作流时发生错误: {str(e)}" diff --git a/app/agent/tools/impl/read_file.py b/app/agent/tools/impl/read_file.py index ee99eddf..a0a7ccb9 100644 --- a/app/agent/tools/impl/read_file.py +++ b/app/agent/tools/impl/read_file.py @@ -15,7 +15,7 @@ MAX_READ_SIZE = 50 * 1024 class ReadFileInput(BaseModel): - """Input parameters for read file tool""" + """文件读取工具的输入参数模型。""" file_path: str = Field(..., description="The absolute path of the file to read") start_line: Optional[int] = Field(None, description="The starting line number (1-based, inclusive). If not provided, reading starts from the beginning of the file.") end_line: Optional[int] = Field(None, description="The ending line number (1-based, inclusive). If not provided, reading goes until the end of the file.") diff --git a/app/agent/tools/impl/run_workflow.py b/app/agent/tools/impl/run_workflow.py index 688df01d..21bbc25a 100644 --- a/app/agent/tools/impl/run_workflow.py +++ b/app/agent/tools/impl/run_workflow.py @@ -7,7 +7,6 @@ from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool from app.agent.tools.tags import ToolTag from app.chain.workflow import WorkflowChain -from app.db import AsyncSessionFactory from app.db.workflow_oper import WorkflowOper from app.log import logger @@ -65,26 +64,23 @@ class RunWorkflowTool(MoviePilotTool): ) try: - # 获取数据库会话 - async with AsyncSessionFactory() as db: - workflow_oper = WorkflowOper(db) - workflow = await workflow_oper.async_get(workflow_id) + workflow = await WorkflowOper().async_get(workflow_id) - if not workflow: - return f"未找到工作流:{workflow_id},请使用 query_workflows 工具查询可用的工作流" + if not workflow: + return f"未找到工作流:{workflow_id},请使用 query_workflows 工具查询可用的工作流" - # 工作流执行链路包含大量同步步骤,统一放到 workflow 线程池。 - state, errmsg = await self.run_blocking( - "workflow", - self._run_workflow_sync, - workflow.id, - from_begin, - ) + # 工作流执行链路包含大量同步步骤,统一放到 workflow 线程池。 + state, errmsg = await self.run_blocking( + "workflow", + self._run_workflow_sync, + workflow.id, + from_begin, + ) - if not state: - return f"执行工作流失败:{workflow.name} (ID: {workflow.id})\n错误原因:{errmsg}" - else: - return f"工作流执行成功:{workflow.name} (ID: {workflow.id})" + if not state: + return f"执行工作流失败:{workflow.name} (ID: {workflow.id})\n错误原因:{errmsg}" + else: + return f"工作流执行成功:{workflow.name} (ID: {workflow.id})" except Exception as e: logger.error(f"执行工作流失败: {e}", exc_info=True) return f"执行工作流时发生错误: {str(e)}" diff --git a/app/agent/tools/impl/update_site.py b/app/agent/tools/impl/update_site.py index 72b1f65d..b0ffeb43 100644 --- a/app/agent/tools/impl/update_site.py +++ b/app/agent/tools/impl/update_site.py @@ -8,8 +8,7 @@ from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool from app.agent.tools.tags import ToolTag from app.core.event import eventmanager -from app.db import AsyncSessionFactory -from app.db.models.site import Site +from app.db.site_oper import SiteOper from app.log import logger from app.schemas.types import EventType from app.utils.string import StringUtils @@ -127,108 +126,106 @@ class UpdateSiteTool(MoviePilotTool): logger.info(f"执行工具: {self.name}, 参数: site_id={site_id}") try: - # 获取数据库会话 - async with AsyncSessionFactory() as db: - # 获取站点 - site = await Site.async_get(db, site_id) - if not site: - return json.dumps( - {"success": False, "message": f"站点不存在: {site_id}"}, - ensure_ascii=False, - ) - - # 构建更新字典 - site_dict = {} - - # 基本信息 - if name is not None: - site_dict["name"] = name - - # URL处理(需要校正格式) - if url is not None: - _scheme, _netloc = StringUtils.get_url_netloc(url) - site_dict["url"] = f"{_scheme}://{_netloc}/" - - if pri is not None: - site_dict["pri"] = pri - if rss is not None: - site_dict["rss"] = rss - - # 认证信息 - if cookie is not None: - site_dict["cookie"] = cookie - if ua is not None: - site_dict["ua"] = ua - if apikey is not None: - site_dict["apikey"] = apikey - if token is not None: - site_dict["token"] = token - - # 配置选项 - if proxy is not None: - site_dict["proxy"] = proxy - if filter is not None: - site_dict["filter"] = filter - if note is not None: - site_dict["note"] = note - if timeout is not None: - site_dict["timeout"] = timeout - - # 流控设置 - if limit_interval is not None: - site_dict["limit_interval"] = limit_interval - if limit_count is not None: - site_dict["limit_count"] = limit_count - if limit_seconds is not None: - site_dict["limit_seconds"] = limit_seconds - - # 状态和下载器 - if is_active is not None: - site_dict["is_active"] = is_active - if downloader is not None: - site_dict["downloader"] = downloader - - # 如果没有要更新的字段 - if not site_dict: - return json.dumps( - {"success": False, "message": "没有提供要更新的字段"}, - ensure_ascii=False, - ) - - # 更新站点 - await site.async_update(db, site_dict) - - # 重新获取更新后的站点数据 - updated_site = await Site.async_get(db, site_id) - - # 发送站点更新事件 - await eventmanager.async_send_event( - EventType.SiteUpdated, - {"domain": updated_site.domain if updated_site else site.domain}, + site_oper = SiteOper() + site = await site_oper.async_get(site_id) + if not site: + return json.dumps( + {"success": False, "message": f"站点不存在: {site_id}"}, + ensure_ascii=False, ) - # 构建返回结果 - result = { - "success": True, - "message": f"站点 #{site_id} 更新成功", - "site_id": site_id, - "updated_fields": list(site_dict.keys()), + # 构建更新字典 + site_dict = {} + + # 基本信息 + if name is not None: + site_dict["name"] = name + + # URL处理(需要校正格式) + if url is not None: + _scheme, _netloc = StringUtils.get_url_netloc(url) + site_dict["url"] = f"{_scheme}://{_netloc}/" + + if pri is not None: + site_dict["pri"] = pri + if rss is not None: + site_dict["rss"] = rss + + # 认证信息 + if cookie is not None: + site_dict["cookie"] = cookie + if ua is not None: + site_dict["ua"] = ua + if apikey is not None: + site_dict["apikey"] = apikey + if token is not None: + site_dict["token"] = token + + # 配置选项 + if proxy is not None: + site_dict["proxy"] = proxy + if filter is not None: + site_dict["filter"] = filter + if note is not None: + site_dict["note"] = note + if timeout is not None: + site_dict["timeout"] = timeout + + # 流控设置 + if limit_interval is not None: + site_dict["limit_interval"] = limit_interval + if limit_count is not None: + site_dict["limit_count"] = limit_count + if limit_seconds is not None: + site_dict["limit_seconds"] = limit_seconds + + # 状态和下载器 + if is_active is not None: + site_dict["is_active"] = is_active + if downloader is not None: + site_dict["downloader"] = downloader + + # 如果没有要更新的字段 + if not site_dict: + return json.dumps( + {"success": False, "message": "没有提供要更新的字段"}, + ensure_ascii=False, + ) + + # 更新站点 + await site_oper.async_update(site_id, site_dict) + + # 重新获取更新后的站点数据 + updated_site = await site_oper.async_get(site_id) + + # 发送站点更新事件 + await eventmanager.async_send_event( + EventType.SiteUpdated, + {"domain": updated_site.domain if updated_site else site.domain}, + ) + + # 构建返回结果 + result = { + "success": True, + "message": f"站点 #{site_id} 更新成功", + "site_id": site_id, + "updated_fields": list(site_dict.keys()), + } + + if updated_site: + result["site"] = { + "id": updated_site.id, + "name": updated_site.name, + "domain": updated_site.domain, + "url": updated_site.url, + "pri": updated_site.pri, + "is_active": updated_site.is_active, + "downloader": updated_site.downloader, + "proxy": updated_site.proxy, + "timeout": updated_site.timeout, } - if updated_site: - result["site"] = { - "id": updated_site.id, - "name": updated_site.name, - "domain": updated_site.domain, - "url": updated_site.url, - "pri": updated_site.pri, - "is_active": updated_site.is_active, - "downloader": updated_site.downloader, - "proxy": updated_site.proxy, - "timeout": updated_site.timeout, - } - - return json.dumps(result, ensure_ascii=False, indent=2) + return json.dumps(result, ensure_ascii=False, indent=2) except Exception as e: error_message = f"更新站点失败: {str(e)}" diff --git a/app/agent/tools/impl/update_subscribe.py b/app/agent/tools/impl/update_subscribe.py index 86edf2a3..6428d648 100644 --- a/app/agent/tools/impl/update_subscribe.py +++ b/app/agent/tools/impl/update_subscribe.py @@ -8,8 +8,7 @@ from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool from app.agent.tools.tags import ToolTag from app.core.event import eventmanager -from app.db import AsyncSessionFactory -from app.db.models.subscribe import Subscribe +from app.db.subscribe_oper import SubscribeOper from app.log import logger from app.schemas.types import EventType @@ -157,149 +156,147 @@ class UpdateSubscribeTool(MoviePilotTool): logger.info(f"执行工具: {self.name}, 参数: subscribe_id={subscribe_id}") try: - # 获取数据库会话 - async with AsyncSessionFactory() as db: - # 获取订阅 - subscribe = await Subscribe.async_get(db, subscribe_id) - if not subscribe: - return json.dumps( - {"success": False, "message": f"订阅不存在: {subscribe_id}"}, - ensure_ascii=False, - ) - - # 保存旧数据用于事件 - old_subscribe_dict = subscribe.to_dict() - - # 构建更新字典 - subscribe_dict = {} - - # 基本信息 - if name is not None: - subscribe_dict["name"] = name - if year is not None: - subscribe_dict["year"] = year - if season is not None: - subscribe_dict["season"] = season - - # 集数相关 - if total_episode is not None: - subscribe_dict["total_episode"] = total_episode - # 如果总集数增加,缺失集数也要相应增加 - if total_episode > (subscribe.total_episode or 0): - old_lack = subscribe.lack_episode or 0 - subscribe_dict["lack_episode"] = old_lack + ( - total_episode - (subscribe.total_episode or 0) - ) - # 标记为手动修改过总集数 - subscribe_dict["manual_total_episode"] = 1 - - # 缺失集数处理(只有在没有提供总集数时才单独处理) - # 注意:如果 lack_episode 为 0,不更新(避免更新为0) - if lack_episode is not None and total_episode is None: - if lack_episode > 0: - subscribe_dict["lack_episode"] = lack_episode - # 如果 lack_episode 为 0,不添加到更新字典中(保持原值或由总集数逻辑处理) - - if start_episode is not None: - subscribe_dict["start_episode"] = start_episode - - # 过滤规则 - if quality is not None: - subscribe_dict["quality"] = quality - if resolution is not None: - subscribe_dict["resolution"] = resolution - if effect is not None: - subscribe_dict["effect"] = effect - if include is not None: - subscribe_dict["include"] = include - if exclude is not None: - subscribe_dict["exclude"] = exclude - if filter is not None: - subscribe_dict["filter"] = filter - - # 状态 - if state is not None: - valid_states = ["R", "P", "S", "N"] - if state not in valid_states: - return json.dumps( - { - "success": False, - "message": f"无效的订阅状态: {state},有效状态: {', '.join(valid_states)}", - }, - ensure_ascii=False, - ) - subscribe_dict["state"] = state - - # 下载配置 - if sites is not None: - subscribe_dict["sites"] = sites - if downloader is not None: - subscribe_dict["downloader"] = downloader - if save_path is not None: - subscribe_dict["save_path"] = save_path - if best_version is not None: - subscribe_dict["best_version"] = best_version - if best_version_full is not None: - subscribe_dict["best_version_full"] = best_version_full - - # 其他配置 - if custom_words is not None: - subscribe_dict["custom_words"] = custom_words - if media_category is not None: - subscribe_dict["media_category"] = media_category - if episode_group is not None: - subscribe_dict["episode_group"] = episode_group - - # 如果没有要更新的字段 - if not subscribe_dict: - return json.dumps( - {"success": False, "message": "没有提供要更新的字段"}, - ensure_ascii=False, - ) - - # 更新订阅 - await subscribe.async_update(db, subscribe_dict) - - # 重新获取更新后的订阅数据 - updated_subscribe = await Subscribe.async_get(db, subscribe_id) - - # 发送订阅调整事件 - await eventmanager.async_send_event( - EventType.SubscribeModified, - { - "subscribe_id": subscribe_id, - "old_subscribe_info": old_subscribe_dict, - "subscribe_info": updated_subscribe.to_dict() - if updated_subscribe - else {}, - }, + subscribe_oper = SubscribeOper() + subscribe = await subscribe_oper.async_get(subscribe_id) + if not subscribe: + return json.dumps( + {"success": False, "message": f"订阅不存在: {subscribe_id}"}, + ensure_ascii=False, ) - # 构建返回结果 - result = { - "success": True, - "message": f"订阅 #{subscribe_id} 更新成功", + # 保存旧数据用于事件 + old_subscribe_dict = subscribe.to_dict() + + # 构建更新字典 + subscribe_dict = {} + + # 基本信息 + if name is not None: + subscribe_dict["name"] = name + if year is not None: + subscribe_dict["year"] = year + if season is not None: + subscribe_dict["season"] = season + + # 集数相关 + if total_episode is not None: + subscribe_dict["total_episode"] = total_episode + # 如果总集数增加,缺失集数也要相应增加 + if total_episode > (subscribe.total_episode or 0): + old_lack = subscribe.lack_episode or 0 + subscribe_dict["lack_episode"] = old_lack + ( + total_episode - (subscribe.total_episode or 0) + ) + # 标记为手动修改过总集数 + subscribe_dict["manual_total_episode"] = 1 + + # 缺失集数处理(只有在没有提供总集数时才单独处理) + # 注意:如果 lack_episode 为 0,不更新(避免更新为0) + if lack_episode is not None and total_episode is None: + if lack_episode > 0: + subscribe_dict["lack_episode"] = lack_episode + # 如果 lack_episode 为 0,不添加到更新字典中(保持原值或由总集数逻辑处理) + + if start_episode is not None: + subscribe_dict["start_episode"] = start_episode + + # 过滤规则 + if quality is not None: + subscribe_dict["quality"] = quality + if resolution is not None: + subscribe_dict["resolution"] = resolution + if effect is not None: + subscribe_dict["effect"] = effect + if include is not None: + subscribe_dict["include"] = include + if exclude is not None: + subscribe_dict["exclude"] = exclude + if filter is not None: + subscribe_dict["filter"] = filter + + # 状态 + if state is not None: + valid_states = ["R", "P", "S", "N"] + if state not in valid_states: + return json.dumps( + { + "success": False, + "message": f"无效的订阅状态: {state},有效状态: {', '.join(valid_states)}", + }, + ensure_ascii=False, + ) + subscribe_dict["state"] = state + + # 下载配置 + if sites is not None: + subscribe_dict["sites"] = sites + if downloader is not None: + subscribe_dict["downloader"] = downloader + if save_path is not None: + subscribe_dict["save_path"] = save_path + if best_version is not None: + subscribe_dict["best_version"] = best_version + if best_version_full is not None: + subscribe_dict["best_version_full"] = best_version_full + + # 其他配置 + if custom_words is not None: + subscribe_dict["custom_words"] = custom_words + if media_category is not None: + subscribe_dict["media_category"] = media_category + if episode_group is not None: + subscribe_dict["episode_group"] = episode_group + + # 如果没有要更新的字段 + if not subscribe_dict: + return json.dumps( + {"success": False, "message": "没有提供要更新的字段"}, + ensure_ascii=False, + ) + + # 更新订阅 + await subscribe_oper.async_update(subscribe_id, subscribe_dict) + + # 重新获取更新后的订阅数据 + updated_subscribe = await subscribe_oper.async_get(subscribe_id) + + # 发送订阅调整事件 + await eventmanager.async_send_event( + EventType.SubscribeModified, + { "subscribe_id": subscribe_id, - "updated_fields": list(subscribe_dict.keys()), + "old_subscribe_info": old_subscribe_dict, + "subscribe_info": updated_subscribe.to_dict() + if updated_subscribe + else {}, + }, + ) + + # 构建返回结果 + result = { + "success": True, + "message": f"订阅 #{subscribe_id} 更新成功", + "subscribe_id": subscribe_id, + "updated_fields": list(subscribe_dict.keys()), + } + + if updated_subscribe: + result["subscribe"] = { + "id": updated_subscribe.id, + "name": updated_subscribe.name, + "year": updated_subscribe.year, + "type": updated_subscribe.type, + "season": updated_subscribe.season, + "state": updated_subscribe.state, + "total_episode": updated_subscribe.total_episode, + "lack_episode": updated_subscribe.lack_episode, + "start_episode": updated_subscribe.start_episode, + "quality": updated_subscribe.quality, + "resolution": updated_subscribe.resolution, + "effect": updated_subscribe.effect, } - if updated_subscribe: - result["subscribe"] = { - "id": updated_subscribe.id, - "name": updated_subscribe.name, - "year": updated_subscribe.year, - "type": updated_subscribe.type, - "season": updated_subscribe.season, - "state": updated_subscribe.state, - "total_episode": updated_subscribe.total_episode, - "lack_episode": updated_subscribe.lack_episode, - "start_episode": updated_subscribe.start_episode, - "quality": updated_subscribe.quality, - "resolution": updated_subscribe.resolution, - "effect": updated_subscribe.effect, - } - - return json.dumps(result, ensure_ascii=False, indent=2) + return json.dumps(result, ensure_ascii=False, indent=2) except Exception as e: error_message = f"更新订阅失败: {str(e)}" diff --git a/app/agent/tools/impl/update_system_settings.py b/app/agent/tools/impl/update_system_settings.py index 5027474c..6d2764d5 100644 --- a/app/agent/tools/impl/update_system_settings.py +++ b/app/agent/tools/impl/update_system_settings.py @@ -11,7 +11,10 @@ from app.agent.tools.tags import ToolTag from app.agent.tools.impl._system_setting_utils import ( SettingSpec, get_default_list_match_field, + is_secret_setting_key, + redact_secret_value, resolve_setting_spec, + should_redact_setting, ) from app.core.config import settings from app.core.event import eventmanager @@ -102,12 +105,14 @@ class UpdateSystemSettingsTool(MoviePilotTool): @staticmethod def _load_setting_value(spec: SettingSpec): + """读取指定设置项的当前值。""" if spec.source == "settings": return getattr(settings, spec.key) - return SystemConfigOper().get(spec.key) + return SystemConfigOper().get(spec.systemconfig_key) @staticmethod def _normalize_systemconfig_value(value: Any): + """规范化写入 SystemConfig 的空列表值。""" if isinstance(value, list): filtered = [item for item in value if item is not None] return filtered or None @@ -221,10 +226,7 @@ class UpdateSystemSettingsTool(MoviePilotTool): **kwargs, ) -> str: logger.info( - "执行工具: %s, setting_key=%s, operation=%s", - self.name, - setting_key, - operation, + f"执行工具: {self.name}, setting_key={setting_key}, operation={operation}" ) try: @@ -266,7 +268,10 @@ class UpdateSystemSettingsTool(MoviePilotTool): else: normalized_value = self._normalize_systemconfig_value(next_value) event_value = normalized_value - success = await SystemConfigOper().async_set(spec.key, normalized_value) + success = await SystemConfigOper().async_set( + spec.systemconfig_key, + normalized_value, + ) changed = success is True if changed: @@ -280,6 +285,26 @@ class UpdateSystemSettingsTool(MoviePilotTool): ) saved_value = self._load_setting_value(spec) + redact_values = ( + should_redact_setting(spec, saved_value) + or should_redact_setting(spec, current_value) + ) + response_previous_value = ( + redact_secret_value( + current_value, + redact_scalar=is_secret_setting_key(spec.key), + ) + if redact_values + else current_value + ) + response_saved_value = ( + redact_secret_value( + saved_value, + redact_scalar=is_secret_setting_key(spec.key), + ) + if redact_values + else saved_value + ) if not changed and not message: message = "配置值未发生变化" @@ -295,8 +320,9 @@ class UpdateSystemSettingsTool(MoviePilotTool): "group": spec.group, "label": spec.label, }, - "previous_value": current_value, - "saved_value": saved_value, + "values_redacted": redact_values, + "previous_value": response_previous_value, + "saved_value": response_saved_value, }, ensure_ascii=False, indent=2, diff --git a/app/agent/tools/impl/write_file.py b/app/agent/tools/impl/write_file.py index ad19cc5f..b6a73205 100644 --- a/app/agent/tools/impl/write_file.py +++ b/app/agent/tools/impl/write_file.py @@ -12,7 +12,7 @@ from app.log import logger class WriteFileInput(BaseModel): - """Input parameters for write file tool""" + """文件写入工具的输入参数模型。""" file_path: str = Field(..., description="The absolute path of the file to write") content: str = Field(..., description="The content to write into the file") @@ -26,7 +26,7 @@ class WriteFileTool(MoviePilotTool): ] description: str = ( "Write full content to a local text file. Non-admin users can only write " - "inside the MoviePilot config, Agent memory/activity, and log directories." + "inside the MoviePilot Agent config and log directories." ) args_schema: Type[BaseModel] = WriteFileInput diff --git a/app/db/downloadhistory_oper.py b/app/db/downloadhistory_oper.py index c568e12b..6f6f614b 100644 --- a/app/db/downloadhistory_oper.py +++ b/app/db/downloadhistory_oper.py @@ -114,6 +114,12 @@ class DownloadHistoryOper(DbOper): """ return DownloadHistory.list_by_page(self._db, page, count) + async def async_delete_history(self, historyid: int): + """ + 异步删除下载记录。 + """ + await DownloadHistory.async_delete(self._db, historyid) + def truncate(self): """ 清空下载记录 diff --git a/app/db/site_oper.py b/app/db/site_oper.py index c215a471..f5fdda5b 100644 --- a/app/db/site_oper.py +++ b/app/db/site_oper.py @@ -79,6 +79,15 @@ class SiteOper(DbOper): site.update(self._db, payload) return site + async def async_update(self, sid: int, payload: dict) -> Site: + """ + 异步更新站点。 + """ + site = await self.async_get(sid) + if site: + await site.async_update(self._db, payload) + return site + def get_by_domain(self, domain: str) -> Site: """ 按域名获取站点 @@ -170,6 +179,16 @@ class SiteOper(DbOper): """ return SiteUserData.get_by_domain(self._db, domain=domain, workdate=workdate) + async def async_get_userdata_by_domain( + self, domain: str, workdate: Optional[str] = None + ) -> List[SiteUserData]: + """ + 异步获取站点用户数据。 + """ + return await SiteUserData.async_get_by_domain( + self._db, domain=domain, workdate=workdate + ) + def get_userdata_by_date(self, date: str) -> List[SiteUserData]: """ 获取站点用户数据 diff --git a/app/db/subscribe_oper.py b/app/db/subscribe_oper.py index 80959db2..4167a85a 100644 --- a/app/db/subscribe_oper.py +++ b/app/db/subscribe_oper.py @@ -169,6 +169,22 @@ class SubscribeOper(DbOper): """ await Subscribe.async_delete(self._db, rid=sid) + async def async_update(self, sid: int, payload: dict) -> Subscribe: + """ + 异步更新订阅。 + """ + subscribe = await self.async_get(sid) + if subscribe: + payload = _normalize_integer_flags(payload) + await subscribe.async_update(self._db, payload) + return subscribe + + async def async_update_filter_groups(self, sid: int, filter_groups: list) -> Subscribe: + """ + 异步更新订阅使用的过滤规则组。 + """ + return await self.async_update(sid, {"filter_groups": filter_groups}) + def update(self, sid: int, payload: dict) -> Subscribe: """ 更新订阅 diff --git a/app/db/subscribehistory_oper.py b/app/db/subscribehistory_oper.py new file mode 100644 index 00000000..b8bb5703 --- /dev/null +++ b/app/db/subscribehistory_oper.py @@ -0,0 +1,26 @@ +from typing import List, Optional + +from app.db import DbOper +from app.db.models.subscribehistory import SubscribeHistory + + +class SubscribeHistoryOper(DbOper): + """ + 订阅历史管理。 + """ + + async def async_list_by_type( + self, + mtype: str, + page: Optional[int] = 1, + count: Optional[int] = 30, + ) -> List[SubscribeHistory]: + """ + 异步按媒体类型分页查询订阅历史。 + """ + return await SubscribeHistory.async_list_by_type( + self._db, + mtype=mtype, + page=page, + count=count, + ) diff --git a/app/db/transferhistory_oper.py b/app/db/transferhistory_oper.py index 2e78c81c..287c7289 100644 --- a/app/db/transferhistory_oper.py +++ b/app/db/transferhistory_oper.py @@ -26,6 +26,51 @@ class TransferHistoryOper(DbOper): """ return await TransferHistory.async_get(self._db, historyid) + async def async_list_by_title( + self, + title: str, + page: Optional[int] = 1, + count: Optional[int] = 30, + status: Optional[bool] = None, + ) -> List[TransferHistory]: + """ + 异步按标题分页查询转移记录。 + """ + return await TransferHistory.async_list_by_title( + self._db, title=title, page=page, count=count, status=status + ) + + async def async_list_by_page( + self, + page: Optional[int] = 1, + count: Optional[int] = 30, + status: Optional[bool] = None, + ) -> List[TransferHistory]: + """ + 异步分页查询转移记录。 + """ + return await TransferHistory.async_list_by_page( + self._db, page=page, count=count, status=status + ) + + async def async_count(self, status: Optional[bool] = None) -> int: + """ + 异步统计转移记录数量。 + """ + return await TransferHistory.async_count(self._db, status=status) + + async def async_count_by_title( + self, + title: str, + status: Optional[bool] = None, + ) -> int: + """ + 异步按标题统计转移记录数量。 + """ + return await TransferHistory.async_count_by_title( + self._db, title=title, status=status + ) + def get_by_title(self, title: str) -> List[TransferHistory]: """ 按标题查询转移记录 diff --git a/tests/test_agent_activity_log.py b/tests/test_agent_activity_log.py index 5296daab..3248b320 100644 --- a/tests/test_agent_activity_log.py +++ b/tests/test_agent_activity_log.py @@ -73,6 +73,22 @@ def test_activity_log_prompt_injects_index_not_full_log(tmp_path): assert "query_activity_log" in system_text +def test_activity_log_abefore_agent_refreshes_existing_state(tmp_path): + """复用 Agent 图时,活动日志索引仍应在每轮执行前刷新。""" + date_str = datetime.now().strftime("%Y-%m-%d") + middleware = ActivityLogMiddleware(activity_dir=str(tmp_path), prompt_load_days=1) + state = {"activity_log_contents": {"old": "旧索引"}} + + _write_activity_log( + tmp_path, + date_str, + ["- **10:00** 新增活动记录"], + ) + state_update = asyncio.run(middleware.abefore_agent(state, runtime=None)) + + assert state_update == {"activity_log_contents": {date_str: "1 条活动记录"}} + + def test_activity_log_skips_trivial_greeting_without_llm(tmp_path): """无实际任务的寒暄不应调用 LLM,也不应写入活动日志。""" middleware = ActivityLogMiddleware(activity_dir=str(tmp_path)) diff --git a/tests/test_agent_graph_cache.py b/tests/test_agent_graph_cache.py new file mode 100644 index 00000000..3027e7e8 --- /dev/null +++ b/tests/test_agent_graph_cache.py @@ -0,0 +1,92 @@ +"""Agent 图缓存行为测试。""" + +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import AsyncMock, patch + +import pytest +from langchain_core.messages import AIMessage, HumanMessage + +from app.agent import MoviePilotAgent, ReplyMode, _CompiledAgentBundle + + +@pytest.fixture +def anyio_backend(): + """使用 asyncio 后端运行 anyio 异步测试。""" + return "asyncio" + + +class _FakeGraphState: + """提供 LangGraph get_state 测试替身。""" + + def __init__(self, messages): + """保存测试消息状态。""" + self.values = {"messages": messages} + + +class _CapturingAgent: + """捕获传入消息的非流式 Agent 测试替身。""" + + def __init__(self): + """初始化捕获容器。""" + self.payload = None + + async def ainvoke(self, payload, config=None): + """记录 Agent 调用输入。""" + self.payload = payload + + def get_state(self, _config): + """返回包含最终 AI 回复的图状态。""" + return _FakeGraphState([AIMessage(content="ok")]) + + +@pytest.mark.anyio +async def test_create_agent_reuses_cached_graph_when_signature_matches(): + """构造签名一致时应直接复用已编译 Agent 图。""" + cached_graph = object() + agent = MoviePilotAgent(session_id="cache-hit", user_id="user-1") + agent._compiled_agent_bundle = _CompiledAgentBundle( + signature=("sig",), + agent=cached_graph, + streaming=False, + created_at=datetime.now(), + ) + + with patch.object( + agent, + "_agent_bundle_signature", + new=AsyncMock(return_value=("sig",)), + ), patch("app.agent.create_agent") as create_agent: + graph = await agent._create_agent(streaming=False) + + assert graph is cached_graph + assert agent._last_agent_cache_hit is True + create_agent.assert_not_called() + + +@pytest.mark.anyio +async def test_execute_agent_sends_only_latest_message_on_cache_hit(): + """缓存命中时只把本轮新消息交给 LangGraph,避免重复提交历史。""" + fake_graph = _CapturingAgent() + agent = MoviePilotAgent(session_id="cache-hit", user_id="user-1") + agent.reply_mode = ReplyMode.CAPTURE_ONLY + agent._tool_context = {"user_reply_sent": False} + agent._streamed_output = "" + agent._should_stream = lambda: False + agent.stream_handler = SimpleNamespace( + stop_streaming=AsyncMock(return_value=(False, "")) + ) + + async def _create_agent(streaming=False): + """模拟缓存命中后的 Agent 创建结果。""" + agent._last_agent_cache_hit = True + return fake_graph + + agent._create_agent = _create_agent + messages = [HumanMessage(content="上一轮"), HumanMessage(content="本轮")] + + with patch("app.agent.eventmanager.send_event"): + await agent._execute_agent(messages) + + assert agent._streamed_output == "ok" + assert fake_graph.payload["messages"] == [messages[-1]] diff --git a/tests/test_agent_query_workflows_tool.py b/tests/test_agent_query_workflows_tool.py index 02dce44b..58904610 100644 --- a/tests/test_agent_query_workflows_tool.py +++ b/tests/test_agent_query_workflows_tool.py @@ -7,16 +7,6 @@ from unittest.mock import AsyncMock, MagicMock, patch from app.agent.tools.impl.query_workflows import QueryWorkflowsTool -class _AsyncSessionContext: - """为工作流查询工具提供最小异步 DB 上下文。""" - - async def __aenter__(self): - return object() - - async def __aexit__(self, exc_type, exc, tb): - return False - - class TestQueryWorkflowsTool(unittest.TestCase): def test_query_workflows_omits_large_result_field(self): tool = QueryWorkflowsTool(session_id="session-1", user_id="10001") @@ -38,9 +28,6 @@ class TestQueryWorkflowsTool(unittest.TestCase): workflow_oper.async_list = AsyncMock(return_value=[workflow]) with patch( - "app.agent.tools.impl.query_workflows.AsyncSessionFactory", - return_value=_AsyncSessionContext(), - ), patch( "app.agent.tools.impl.query_workflows.WorkflowOper", return_value=workflow_oper, ): diff --git a/tests/test_agent_subagents.py b/tests/test_agent_subagents.py index f8b66e27..e9156dbc 100644 --- a/tests/test_agent_subagents.py +++ b/tests/test_agent_subagents.py @@ -466,14 +466,16 @@ def test_after_agent_cancels_unfinished_tasks(): ) ) await asyncio.wait_for(task_started.wait(), timeout=1) + task_id = start_payload["tasks"][0]["task_id"] await middleware.aafter_agent({}, None) status_payload = json.loads( await middleware._control_task( action="status", - task_ids=[start_payload["tasks"][0]["task_id"]], + task_ids=[task_id], ) ) - assert status_payload["tasks"][0]["status"] == "cancelled" + assert status_payload["tasks"] == [] + assert status_payload["missing_task_ids"] == [task_id] asyncio.run(_run_test()) diff --git a/tests/test_agent_system_settings_tools.py b/tests/test_agent_system_settings_tools.py index 4d58d4fa..e5287fb0 100644 --- a/tests/test_agent_system_settings_tools.py +++ b/tests/test_agent_system_settings_tools.py @@ -7,6 +7,7 @@ from app.agent.tools.impl.query_system_settings import QuerySystemSettingsTool from app.agent.tools.impl.update_system_settings import UpdateSystemSettingsTool from app.agent.tools.manager import MoviePilotToolsManager from app.core.config import settings +from app.schemas.types import SystemConfigKey class TestAgentSystemSettingsTools(unittest.TestCase): @@ -24,6 +25,63 @@ class TestAgentSystemSettingsTools(unittest.TestCase): self.assertEqual(payload["matched_count"], 1) self.assertEqual(payload["settings"][0]["setting_key"], "Downloaders") self.assertEqual(payload["settings"][0]["value"][0]["name"], "qb") + system_config_oper.return_value.get.assert_called_once_with( + SystemConfigKey.Downloaders + ) + + def test_query_system_settings_redacts_secret_values_by_default(self): + """查询系统设置默认应脱敏 API Key、Token、Cookie 等敏感字段。""" + tool = QuerySystemSettingsTool(session_id="session-1", user_id="10001") + + with patch( + "app.agent.tools.impl.query_system_settings.SystemConfigOper" + ) as system_config_oper: + system_config_oper.return_value.get.return_value = [ + { + "name": "site-a", + "apikey": "site-api-key", + "token": "site-token", + "cookie": "uid=1; passkey=secret", + "url": "https://example.com", + } + ] + result = asyncio.run( + tool.run(setting_key="UserSiteAuthParams", include_values=True) + ) + + payload = json.loads(result) + item = payload["settings"][0] + self.assertTrue(item["redacted"]) + self.assertFalse(payload["show_secrets"]) + self.assertEqual("***", item["value"][0]["apikey"]) + self.assertEqual("***", item["value"][0]["token"]) + self.assertEqual("***", item["value"][0]["cookie"]) + self.assertEqual("https://example.com", item["value"][0]["url"]) + + def test_query_system_settings_show_secrets_requires_admin_context(self): + """只有管理员显式请求 show_secrets 时才返回敏感配置原值。""" + tool = QuerySystemSettingsTool(session_id="session-1", user_id="admin") + tool.set_agent_context({"is_admin": True}) + + with patch( + "app.agent.tools.impl.query_system_settings.SystemConfigOper" + ) as system_config_oper: + system_config_oper.return_value.get.return_value = [ + {"name": "site-a", "apikey": "site-api-key"} + ] + result = asyncio.run( + tool.run( + setting_key="UserSiteAuthParams", + include_values=True, + show_secrets=True, + ) + ) + + payload = json.loads(result) + item = payload["settings"][0] + self.assertTrue(payload["show_secrets"]) + self.assertFalse(item["redacted"]) + self.assertEqual("site-api-key", item["value"][0]["apikey"]) def test_query_system_settings_group_defaults_to_summary_for_multiple_items(self): tool = QuerySystemSettingsTool(session_id="session-1", user_id="10001") @@ -67,7 +125,7 @@ class TestAgentSystemSettingsTools(unittest.TestCase): self.assertTrue(payload["success"]) self.assertTrue(payload["changed"]) config_oper.async_set.assert_awaited_once_with( - "AIAgentConfig", + SystemConfigKey.AIAgentConfig, {"chatgpt": {"enabled": False}, "gemini": {"enabled": True}}, ) send_event.assert_awaited_once() @@ -99,6 +157,42 @@ class TestAgentSystemSettingsTools(unittest.TestCase): payload = json.loads(result) self.assertTrue(payload["success"]) self.assertEqual(payload["saved_value"], [{"name": "qb", "enabled": True}]) + config_oper.async_set.assert_awaited_once_with( + SystemConfigKey.Downloaders, + [{"name": "qb", "enabled": True}], + ) + + def test_update_system_settings_redacts_secret_values_in_response(self): + """更新敏感系统设置后响应不应回显旧值和新值中的密钥。""" + tool = UpdateSystemSettingsTool(session_id="session-1", user_id="10001") + config_oper = MagicMock() + config_oper.get.side_effect = [ + [{"name": "site-a", "apikey": "old-key"}], + [{"name": "site-a", "apikey": "new-key"}], + ] + config_oper.async_set = AsyncMock(return_value=True) + + with patch( + "app.agent.tools.impl.update_system_settings.SystemConfigOper", + return_value=config_oper, + ), patch( + "app.agent.tools.impl.update_system_settings.eventmanager.async_send_event", + new=AsyncMock(), + ): + result = asyncio.run( + tool.run( + setting_key="UserSiteAuthParams", + operation="upsert_list_item", + value={"name": "site-a", "apikey": "new-key"}, + match_field="name", + ) + ) + + payload = json.loads(result) + self.assertTrue(payload["success"]) + self.assertTrue(payload["values_redacted"]) + self.assertEqual("***", payload["previous_value"][0]["apikey"]) + self.assertEqual("***", payload["saved_value"][0]["apikey"]) def test_update_system_settings_updates_basic_settings(self): tool = UpdateSystemSettingsTool(session_id="session-1", user_id="10001") diff --git a/tests/test_execute_command_tool.py b/tests/test_execute_command_tool.py index 94baa380..1b6cd0f2 100644 --- a/tests/test_execute_command_tool.py +++ b/tests/test_execute_command_tool.py @@ -137,8 +137,29 @@ class TestExecuteCommandTool(unittest.TestCase): payload = json.loads(result) self.assertEqual(payload["status"], "error") - # rm -rf / 命中删除根目录防护;断言拒绝原因点明 rm 根目录,避免锁死单一文案 - self.assertIn("不允许使用 rm", payload["error"]) + # rm -rf / 命中高危命令防护;断言拒绝且提示需要显式确认,避免锁死单一文案。 + self.assertIn("confirm_dangerous=true", payload["error"]) + + def test_dangerous_command_requires_explicit_confirmation(self): + """高危命令只有携带显式确认参数时才允许进入执行层。""" + tool = ExecuteCommandTool(session_id="session-1", user_id="10001") + + rejected = asyncio.run( + tool.run(action="run", command="echo ok && shutdown now", timeout=1) + ) + allowed = asyncio.run( + tool.run( + action="run", + command=_python_command("print('shutdown now confirmed')"), + timeout=1, + confirm_dangerous=True, + ) + ) + + rejected_payload = json.loads(rejected) + self.assertEqual("error", rejected_payload["status"]) + self.assertIn("confirm_dangerous=true", rejected_payload["error"]) + self.assertIn("shutdown now confirmed", allowed) class TestExecuteCommandSessionTool(unittest.IsolatedAsyncioTestCase):