From 7586a2cd42c64a187a5462374583fec9053ec8cc Mon Sep 17 00:00:00 2001 From: jxxghp Date: Wed, 29 Apr 2026 23:30:59 +0800 Subject: [PATCH] disable agent message tools for ui background tasks --- app/agent/__init__.py | 14 ++--- app/agent/tools/factory.py | 20 ++++++++ app/agent/tools/impl/ask_user_choice.py | 1 + app/agent/tools/impl/send_local_file.py | 1 + app/agent/tools/impl/send_message.py | 1 + app/agent/tools/impl/send_voice_message.py | 1 + app/agent_context.py | 31 ------------ app/api/endpoints/history.py | 4 +- app/chain/__init__.py | 13 ----- app/chain/search.py | 2 +- tests/test_search_ai_recommend.py | 59 ++++++---------------- 11 files changed, 48 insertions(+), 99 deletions(-) delete mode 100644 app/agent_context.py diff --git a/app/agent/__init__.py b/app/agent/__init__.py index 6f7406db..536ded84 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -30,7 +30,6 @@ from app.agent.middleware.usage import UsageMiddleware from app.agent.prompt import prompt_manager from app.agent.runtime import agent_runtime_manager from app.agent.tools.factory import MoviePilotToolFactory -from app.agent_context import agent_execution_context from app.chain import ChainBase from app.core.config import settings from app.helper.llm import LLMHelper @@ -174,7 +173,7 @@ class MoviePilotAgent: self.force_streaming = False self.suppress_user_reply = False self.persist_output_message = True - self.suppress_message_channel_dispatch = False + self.allow_message_tools = True self._streamed_output = "" self._session_usage = _SessionUsageSnapshot() @@ -368,6 +367,7 @@ class MoviePilotAgent: username=self.username, stream_handler=self.stream_handler, agent_context=self._tool_context, + allow_message_tools=self.allow_message_tools, ) def _create_agent(self, streaming: bool = False): @@ -457,7 +457,6 @@ class MoviePilotAgent: self._tool_context = { "user_reply_sent": False, "reply_mode": None, - "suppress_message_channel_dispatch": self.suppress_message_channel_dispatch, } self._streamed_output = "" @@ -486,10 +485,7 @@ class MoviePilotAgent: messages.append(HumanMessage(content=content)) # 执行推理 - with agent_execution_context( - suppress_message_channel_dispatch=self.suppress_message_channel_dispatch - ): - await self._execute_agent(messages) + await self._execute_agent(messages) except Exception as e: error_message = f"处理消息时发生错误: {str(e)}" @@ -994,7 +990,7 @@ class AgentManager: output_callback: Optional[Callable[[str], None]] = None, suppress_user_reply: bool = False, persist_output_message: bool = True, - suppress_message_channel_dispatch: bool = False, + allow_message_tools: bool = True, ) -> None: """ 以独立后台会话执行一段 prompt。 @@ -1012,7 +1008,7 @@ class AgentManager: agent.force_streaming = bool(output_callback) agent.suppress_user_reply = suppress_user_reply agent.persist_output_message = persist_output_message - agent.suppress_message_channel_dispatch = suppress_message_channel_dispatch + agent.allow_message_tools = allow_message_tools try: await agent.process(message) diff --git a/app/agent/tools/factory.py b/app/agent/tools/factory.py index 61653f53..51f6a177 100644 --- a/app/agent/tools/factory.py +++ b/app/agent/tools/factory.py @@ -79,6 +79,15 @@ class MoviePilotToolFactory: MoviePilot工具工厂 """ + _MESSAGE_TOOL_CLASSES = frozenset( + { + SendMessageTool, + AskUserChoiceTool, + SendLocalFileTool, + SendVoiceMessageTool, + } + ) + @staticmethod def _should_enable_choice_tool(channel: str = None) -> bool: if not channel: @@ -100,6 +109,7 @@ class MoviePilotToolFactory: username: str = None, stream_handler: Callable = None, agent_context: dict = None, + allow_message_tools: bool = True, ) -> List[MoviePilotTool]: """ 创建MoviePilot工具列表 @@ -181,6 +191,11 @@ class MoviePilotToolFactory: ) # 创建内置工具 for ToolClass in tool_definitions: + if ( + not allow_message_tools + and ToolClass in MoviePilotToolFactory._MESSAGE_TOOL_CLASSES + ): + continue tool = ToolClass(session_id=session_id, user_id=user_id) tool.set_message_attr(channel=channel, source=source, username=username) tool.set_stream_handler(stream_handler=stream_handler) @@ -196,6 +211,11 @@ class MoviePilotToolFactory: tool_classes = plugin_info.get("tools", []) for ToolClass in tool_classes: try: + if ( + not allow_message_tools + and ToolClass in MoviePilotToolFactory._MESSAGE_TOOL_CLASSES + ): + continue # 验证工具类是否继承自 MoviePilotTool if not issubclass(ToolClass, MoviePilotTool): logger.warning( diff --git a/app/agent/tools/impl/ask_user_choice.py b/app/agent/tools/impl/ask_user_choice.py index 28bf7c12..c219b97e 100644 --- a/app/agent/tools/impl/ask_user_choice.py +++ b/app/agent/tools/impl/ask_user_choice.py @@ -64,6 +64,7 @@ class AskUserChoiceInput(BaseModel): class AskUserChoiceTool(MoviePilotTool): name: str = "ask_user_choice" + sends_message: bool = True description: str = ( "Ask the user to choose from button options on channels that support interactive buttons. " "After the user clicks a button, the selected value will come back as the user's next message." diff --git a/app/agent/tools/impl/send_local_file.py b/app/agent/tools/impl/send_local_file.py index 0f3828e3..ae971c45 100644 --- a/app/agent/tools/impl/send_local_file.py +++ b/app/agent/tools/impl/send_local_file.py @@ -45,6 +45,7 @@ class SendLocalFileInput(BaseModel): class SendLocalFileTool(MoviePilotTool): name: str = "send_local_file" + sends_message: bool = True description: str = ( "Send a local image or file from the server filesystem to the current user. " "Use this when you have generated or identified a local file the user should download." diff --git a/app/agent/tools/impl/send_message.py b/app/agent/tools/impl/send_message.py index 3f145b22..8aa59196 100644 --- a/app/agent/tools/impl/send_message.py +++ b/app/agent/tools/impl/send_message.py @@ -37,6 +37,7 @@ class SendMessageInput(BaseModel): class SendMessageTool(MoviePilotTool): name: str = "send_message" + sends_message: bool = True description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Supports optional image_url on channels that can send images. Used to inform users about operation results, errors, important updates, or proactively send a relevant image." args_schema: Type[BaseModel] = SendMessageInput require_admin: bool = True diff --git a/app/agent/tools/impl/send_voice_message.py b/app/agent/tools/impl/send_voice_message.py index 1ceef892..dafdf65b 100644 --- a/app/agent/tools/impl/send_voice_message.py +++ b/app/agent/tools/impl/send_voice_message.py @@ -27,6 +27,7 @@ class SendVoiceMessageInput(BaseModel): class SendVoiceMessageTool(MoviePilotTool): name: str = "send_voice_message" + sends_message: bool = True description: str = ( "Send a voice reply to the current user. Use this only when the user explicitly asks for " "a voice reply or when spoken playback is clearly better than plain text. On channels " diff --git a/app/agent_context.py b/app/agent_context.py deleted file mode 100644 index 0d9c0ba1..00000000 --- a/app/agent_context.py +++ /dev/null @@ -1,31 +0,0 @@ -import contextvars -from contextlib import contextmanager -from typing import Iterator - -_suppress_message_channel_dispatch = contextvars.ContextVar( - "suppress_message_channel_dispatch", - default=False, -) - - -def is_message_channel_dispatch_suppressed() -> bool: - """ - 当前 Agent 执行上下文是否禁止向外部消息渠道派发通知。 - """ - return bool(_suppress_message_channel_dispatch.get()) - - -@contextmanager -def agent_execution_context( - *, suppress_message_channel_dispatch: bool = False -) -> Iterator[None]: - """ - 绑定当前 Agent 执行期的上下文参数。 - """ - token = _suppress_message_channel_dispatch.set( - bool(suppress_message_channel_dispatch) - ) - try: - yield - finally: - _suppress_message_channel_dispatch.reset(token) diff --git a/app/api/endpoints/history.py b/app/api/endpoints/history.py index 4e5f64ea..1a2abb42 100644 --- a/app/api/endpoints/history.py +++ b/app/api/endpoints/history.py @@ -132,7 +132,7 @@ def _start_ai_redo_task(history_id: int, prompt: str, progress_key: str): output_callback=update_output, suppress_user_reply=True, persist_output_message=False, - suppress_message_channel_dispatch=True, + allow_message_tools=False, ) progress.update( text="智能助手整理完成", @@ -178,7 +178,7 @@ def _start_batch_ai_redo_task( output_callback=update_output, suppress_user_reply=True, persist_output_message=False, - suppress_message_channel_dispatch=True, + allow_message_tools=False, ) progress.update( text="智能助手批量整理完成", diff --git a/app/chain/__init__.py b/app/chain/__init__.py index aa74e3f2..0073a537 100644 --- a/app/chain/__init__.py +++ b/app/chain/__init__.py @@ -12,7 +12,6 @@ from fastapi.concurrency import run_in_threadpool from qbittorrentapi import TorrentFilesList from transmission_rpc import File -from app.agent_context import is_message_channel_dispatch_suppressed from app.core.cache import FileCache, AsyncFileCache, fresh, async_fresh from app.core.config import settings from app.core.context import Context, MediaInfo, TorrentInfo @@ -1137,9 +1136,6 @@ class ChainBase(metaclass=ABCMeta): # 保存消息 self.messagehelper.put(message, role="user", title=message.title) self.messageoper.add(**message.model_dump()) - if is_message_channel_dispatch_suppressed(): - logger.info("当前上下文已禁用消息渠道派发,仅保存消息记录") - return dispatch_message = self._normalize_notification_for_dispatch(message) # 发送消息按设置隔离 if not dispatch_message.userid and dispatch_message.mtype: @@ -1257,9 +1253,6 @@ class ChainBase(metaclass=ABCMeta): # 保存消息 self.messagehelper.put(message, role="user", title=message.title) await self.messageoper.async_add(**message.model_dump()) - if is_message_channel_dispatch_suppressed(): - logger.info("当前上下文已禁用消息渠道派发,仅保存消息记录") - return dispatch_message = self._normalize_notification_for_dispatch(message) # 发送消息按设置隔离 if not dispatch_message.userid and dispatch_message.mtype: @@ -1354,9 +1347,6 @@ class ChainBase(metaclass=ABCMeta): message, role="user", note=note_list, title=message.title ) self.messageoper.add(**message.model_dump(), note=note_list) - if is_message_channel_dispatch_suppressed(): - logger.info("当前上下文已禁用消息渠道派发,仅保存媒体消息记录") - return None dispatch_message = self._normalize_notification_for_dispatch(message) return self.messagequeue.send_message( "post_medias_message", @@ -1379,9 +1369,6 @@ class ChainBase(metaclass=ABCMeta): message, role="user", note=note_list, title=message.title ) self.messageoper.add(**message.model_dump(), note=note_list) - if is_message_channel_dispatch_suppressed(): - logger.info("当前上下文已禁用消息渠道派发,仅保存种子消息记录") - return None dispatch_message = self._normalize_notification_for_dispatch(message) return self.messagequeue.send_message( "post_torrents_message", diff --git a/app/chain/search.py b/app/chain/search.py index 04acb577..7b4b39d5 100644 --- a/app/chain/search.py +++ b/app/chain/search.py @@ -228,7 +228,7 @@ class SearchChain(ChainBase): output_callback=on_output, suppress_user_reply=True, persist_output_message=False, - suppress_message_channel_dispatch=True, + allow_message_tools=False, ) return full_output[0].strip() diff --git a/tests/test_search_ai_recommend.py b/tests/test_search_ai_recommend.py index e7ea36fd..6438d637 100644 --- a/tests/test_search_ai_recommend.py +++ b/tests/test_search_ai_recommend.py @@ -5,8 +5,6 @@ from types import SimpleNamespace from types import ModuleType from unittest.mock import AsyncMock, patch -import app.chain as chain_module - def _stub_module(name: str, **attrs): module = sys.modules.get(name) @@ -21,11 +19,9 @@ def _stub_module(name: str, **attrs): _stub_module("qbittorrentapi", TorrentFilesList=list) _stub_module("transmission_rpc", File=object) +from app.agent.tools.factory import MoviePilotToolFactory from app.chain.search import SearchChain -from app.agent_context import agent_execution_context from app.core.config import settings -from app.schemas import Notification -from app.schemas.types import NotificationType def _make_result(title: str, size: int, seeders: int): @@ -131,7 +127,7 @@ class SearchChainAIRecommendTest(unittest.IsolatedAsyncioTestCase): self.assertEqual("[0, 2]", result) self.assertTrue(captured["suppress_user_reply"]) self.assertFalse(captured["persist_output_message"]) - self.assertTrue(captured["suppress_message_channel_dispatch"]) + self.assertFalse(captured["allow_message_tools"]) def test_search_by_title_clears_previous_recommend_state_when_caching(self): chain = self._make_chain() @@ -156,45 +152,22 @@ class SearchChainAIRecommendTest(unittest.IsolatedAsyncioTestCase): self.assertIsNone(SearchChain._ai_recommend_result) self.assertIsNone(SearchChain._ai_recommend_error) - def test_post_message_skips_channel_dispatch_when_suppressed(self): - chain = object.__new__(SearchChain) - queue_calls = [] - event_calls = [] - saved_messages = [] - saved_records = [] - chain.messagehelper = SimpleNamespace( - put=lambda *args, **kwargs: saved_messages.append((args, kwargs)) - ) - chain.messageoper = SimpleNamespace( - add=lambda **kwargs: saved_records.append(kwargs) - ) - chain.messagequeue = SimpleNamespace( - send_message=lambda *args, **kwargs: queue_calls.append((args, kwargs)) - ) - chain.eventmanager = SimpleNamespace( - send_event=lambda *args, **kwargs: event_calls.append((args, kwargs)) - ) - - notification = Notification( - mtype=NotificationType.Manual, - title="Title", - text="Body", - ) - - with ( - patch.object( - chain_module.MessageTemplateHelper, - "render", - return_value=notification, - ), - agent_execution_context(suppress_message_channel_dispatch=True), + def test_tool_factory_excludes_message_tools_when_disabled(self): + with patch( + "app.agent.tools.factory.PluginManager.get_plugin_agent_tools", + return_value=[], ): - chain.post_message(message=notification) + tools = MoviePilotToolFactory.create_tools( + session_id="test-session", + user_id="test-user", + allow_message_tools=False, + ) - self.assertEqual(1, len(saved_messages)) - self.assertEqual(1, len(saved_records)) - self.assertEqual([], queue_calls) - self.assertEqual([], event_calls) + tool_names = {tool.name for tool in tools} + self.assertNotIn("send_message", tool_names) + self.assertNotIn("ask_user_choice", tool_names) + self.assertNotIn("send_local_file", tool_names) + self.assertNotIn("send_voice_message", tool_names) if __name__ == "__main__":