mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-09 23:02:09 +08:00
disable agent message tools for ui background tasks
This commit is contained in:
@@ -30,7 +30,6 @@ from app.agent.middleware.usage import UsageMiddleware
|
||||
from app.agent.prompt import prompt_manager
|
||||
from app.agent.runtime import agent_runtime_manager
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.agent_context import agent_execution_context
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.helper.llm import LLMHelper
|
||||
@@ -174,7 +173,7 @@ class MoviePilotAgent:
|
||||
self.force_streaming = False
|
||||
self.suppress_user_reply = False
|
||||
self.persist_output_message = True
|
||||
self.suppress_message_channel_dispatch = False
|
||||
self.allow_message_tools = True
|
||||
self._streamed_output = ""
|
||||
self._session_usage = _SessionUsageSnapshot()
|
||||
|
||||
@@ -368,6 +367,7 @@ class MoviePilotAgent:
|
||||
username=self.username,
|
||||
stream_handler=self.stream_handler,
|
||||
agent_context=self._tool_context,
|
||||
allow_message_tools=self.allow_message_tools,
|
||||
)
|
||||
|
||||
def _create_agent(self, streaming: bool = False):
|
||||
@@ -457,7 +457,6 @@ class MoviePilotAgent:
|
||||
self._tool_context = {
|
||||
"user_reply_sent": False,
|
||||
"reply_mode": None,
|
||||
"suppress_message_channel_dispatch": self.suppress_message_channel_dispatch,
|
||||
}
|
||||
self._streamed_output = ""
|
||||
|
||||
@@ -486,10 +485,7 @@ class MoviePilotAgent:
|
||||
messages.append(HumanMessage(content=content))
|
||||
|
||||
# 执行推理
|
||||
with agent_execution_context(
|
||||
suppress_message_channel_dispatch=self.suppress_message_channel_dispatch
|
||||
):
|
||||
await self._execute_agent(messages)
|
||||
await self._execute_agent(messages)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"处理消息时发生错误: {str(e)}"
|
||||
@@ -994,7 +990,7 @@ class AgentManager:
|
||||
output_callback: Optional[Callable[[str], None]] = None,
|
||||
suppress_user_reply: bool = False,
|
||||
persist_output_message: bool = True,
|
||||
suppress_message_channel_dispatch: bool = False,
|
||||
allow_message_tools: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
以独立后台会话执行一段 prompt。
|
||||
@@ -1012,7 +1008,7 @@ class AgentManager:
|
||||
agent.force_streaming = bool(output_callback)
|
||||
agent.suppress_user_reply = suppress_user_reply
|
||||
agent.persist_output_message = persist_output_message
|
||||
agent.suppress_message_channel_dispatch = suppress_message_channel_dispatch
|
||||
agent.allow_message_tools = allow_message_tools
|
||||
|
||||
try:
|
||||
await agent.process(message)
|
||||
|
||||
@@ -79,6 +79,15 @@ class MoviePilotToolFactory:
|
||||
MoviePilot工具工厂
|
||||
"""
|
||||
|
||||
_MESSAGE_TOOL_CLASSES = frozenset(
|
||||
{
|
||||
SendMessageTool,
|
||||
AskUserChoiceTool,
|
||||
SendLocalFileTool,
|
||||
SendVoiceMessageTool,
|
||||
}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _should_enable_choice_tool(channel: str = None) -> bool:
|
||||
if not channel:
|
||||
@@ -100,6 +109,7 @@ class MoviePilotToolFactory:
|
||||
username: str = None,
|
||||
stream_handler: Callable = None,
|
||||
agent_context: dict = None,
|
||||
allow_message_tools: bool = True,
|
||||
) -> List[MoviePilotTool]:
|
||||
"""
|
||||
创建MoviePilot工具列表
|
||||
@@ -181,6 +191,11 @@ class MoviePilotToolFactory:
|
||||
)
|
||||
# 创建内置工具
|
||||
for ToolClass in tool_definitions:
|
||||
if (
|
||||
not allow_message_tools
|
||||
and ToolClass in MoviePilotToolFactory._MESSAGE_TOOL_CLASSES
|
||||
):
|
||||
continue
|
||||
tool = ToolClass(session_id=session_id, user_id=user_id)
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_stream_handler(stream_handler=stream_handler)
|
||||
@@ -196,6 +211,11 @@ class MoviePilotToolFactory:
|
||||
tool_classes = plugin_info.get("tools", [])
|
||||
for ToolClass in tool_classes:
|
||||
try:
|
||||
if (
|
||||
not allow_message_tools
|
||||
and ToolClass in MoviePilotToolFactory._MESSAGE_TOOL_CLASSES
|
||||
):
|
||||
continue
|
||||
# 验证工具类是否继承自 MoviePilotTool
|
||||
if not issubclass(ToolClass, MoviePilotTool):
|
||||
logger.warning(
|
||||
|
||||
@@ -64,6 +64,7 @@ class AskUserChoiceInput(BaseModel):
|
||||
|
||||
class AskUserChoiceTool(MoviePilotTool):
|
||||
name: str = "ask_user_choice"
|
||||
sends_message: bool = True
|
||||
description: str = (
|
||||
"Ask the user to choose from button options on channels that support interactive buttons. "
|
||||
"After the user clicks a button, the selected value will come back as the user's next message."
|
||||
|
||||
@@ -45,6 +45,7 @@ class SendLocalFileInput(BaseModel):
|
||||
|
||||
class SendLocalFileTool(MoviePilotTool):
|
||||
name: str = "send_local_file"
|
||||
sends_message: bool = True
|
||||
description: str = (
|
||||
"Send a local image or file from the server filesystem to the current user. "
|
||||
"Use this when you have generated or identified a local file the user should download."
|
||||
|
||||
@@ -37,6 +37,7 @@ class SendMessageInput(BaseModel):
|
||||
|
||||
class SendMessageTool(MoviePilotTool):
|
||||
name: str = "send_message"
|
||||
sends_message: bool = True
|
||||
description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Supports optional image_url on channels that can send images. Used to inform users about operation results, errors, important updates, or proactively send a relevant image."
|
||||
args_schema: Type[BaseModel] = SendMessageInput
|
||||
require_admin: bool = True
|
||||
|
||||
@@ -27,6 +27,7 @@ class SendVoiceMessageInput(BaseModel):
|
||||
|
||||
class SendVoiceMessageTool(MoviePilotTool):
|
||||
name: str = "send_voice_message"
|
||||
sends_message: bool = True
|
||||
description: str = (
|
||||
"Send a voice reply to the current user. Use this only when the user explicitly asks for "
|
||||
"a voice reply or when spoken playback is clearly better than plain text. On channels "
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
import contextvars
|
||||
from contextlib import contextmanager
|
||||
from typing import Iterator
|
||||
|
||||
_suppress_message_channel_dispatch = contextvars.ContextVar(
|
||||
"suppress_message_channel_dispatch",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
def is_message_channel_dispatch_suppressed() -> bool:
|
||||
"""
|
||||
当前 Agent 执行上下文是否禁止向外部消息渠道派发通知。
|
||||
"""
|
||||
return bool(_suppress_message_channel_dispatch.get())
|
||||
|
||||
|
||||
@contextmanager
|
||||
def agent_execution_context(
|
||||
*, suppress_message_channel_dispatch: bool = False
|
||||
) -> Iterator[None]:
|
||||
"""
|
||||
绑定当前 Agent 执行期的上下文参数。
|
||||
"""
|
||||
token = _suppress_message_channel_dispatch.set(
|
||||
bool(suppress_message_channel_dispatch)
|
||||
)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_suppress_message_channel_dispatch.reset(token)
|
||||
@@ -132,7 +132,7 @@ def _start_ai_redo_task(history_id: int, prompt: str, progress_key: str):
|
||||
output_callback=update_output,
|
||||
suppress_user_reply=True,
|
||||
persist_output_message=False,
|
||||
suppress_message_channel_dispatch=True,
|
||||
allow_message_tools=False,
|
||||
)
|
||||
progress.update(
|
||||
text="智能助手整理完成",
|
||||
@@ -178,7 +178,7 @@ def _start_batch_ai_redo_task(
|
||||
output_callback=update_output,
|
||||
suppress_user_reply=True,
|
||||
persist_output_message=False,
|
||||
suppress_message_channel_dispatch=True,
|
||||
allow_message_tools=False,
|
||||
)
|
||||
progress.update(
|
||||
text="智能助手批量整理完成",
|
||||
|
||||
@@ -12,7 +12,6 @@ from fastapi.concurrency import run_in_threadpool
|
||||
from qbittorrentapi import TorrentFilesList
|
||||
from transmission_rpc import File
|
||||
|
||||
from app.agent_context import is_message_channel_dispatch_suppressed
|
||||
from app.core.cache import FileCache, AsyncFileCache, fresh, async_fresh
|
||||
from app.core.config import settings
|
||||
from app.core.context import Context, MediaInfo, TorrentInfo
|
||||
@@ -1137,9 +1136,6 @@ class ChainBase(metaclass=ABCMeta):
|
||||
# 保存消息
|
||||
self.messagehelper.put(message, role="user", title=message.title)
|
||||
self.messageoper.add(**message.model_dump())
|
||||
if is_message_channel_dispatch_suppressed():
|
||||
logger.info("当前上下文已禁用消息渠道派发,仅保存消息记录")
|
||||
return
|
||||
dispatch_message = self._normalize_notification_for_dispatch(message)
|
||||
# 发送消息按设置隔离
|
||||
if not dispatch_message.userid and dispatch_message.mtype:
|
||||
@@ -1257,9 +1253,6 @@ class ChainBase(metaclass=ABCMeta):
|
||||
# 保存消息
|
||||
self.messagehelper.put(message, role="user", title=message.title)
|
||||
await self.messageoper.async_add(**message.model_dump())
|
||||
if is_message_channel_dispatch_suppressed():
|
||||
logger.info("当前上下文已禁用消息渠道派发,仅保存消息记录")
|
||||
return
|
||||
dispatch_message = self._normalize_notification_for_dispatch(message)
|
||||
# 发送消息按设置隔离
|
||||
if not dispatch_message.userid and dispatch_message.mtype:
|
||||
@@ -1354,9 +1347,6 @@ class ChainBase(metaclass=ABCMeta):
|
||||
message, role="user", note=note_list, title=message.title
|
||||
)
|
||||
self.messageoper.add(**message.model_dump(), note=note_list)
|
||||
if is_message_channel_dispatch_suppressed():
|
||||
logger.info("当前上下文已禁用消息渠道派发,仅保存媒体消息记录")
|
||||
return None
|
||||
dispatch_message = self._normalize_notification_for_dispatch(message)
|
||||
return self.messagequeue.send_message(
|
||||
"post_medias_message",
|
||||
@@ -1379,9 +1369,6 @@ class ChainBase(metaclass=ABCMeta):
|
||||
message, role="user", note=note_list, title=message.title
|
||||
)
|
||||
self.messageoper.add(**message.model_dump(), note=note_list)
|
||||
if is_message_channel_dispatch_suppressed():
|
||||
logger.info("当前上下文已禁用消息渠道派发,仅保存种子消息记录")
|
||||
return None
|
||||
dispatch_message = self._normalize_notification_for_dispatch(message)
|
||||
return self.messagequeue.send_message(
|
||||
"post_torrents_message",
|
||||
|
||||
@@ -228,7 +228,7 @@ class SearchChain(ChainBase):
|
||||
output_callback=on_output,
|
||||
suppress_user_reply=True,
|
||||
persist_output_message=False,
|
||||
suppress_message_channel_dispatch=True,
|
||||
allow_message_tools=False,
|
||||
)
|
||||
return full_output[0].strip()
|
||||
|
||||
|
||||
@@ -5,8 +5,6 @@ from types import SimpleNamespace
|
||||
from types import ModuleType
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import app.chain as chain_module
|
||||
|
||||
|
||||
def _stub_module(name: str, **attrs):
|
||||
module = sys.modules.get(name)
|
||||
@@ -21,11 +19,9 @@ def _stub_module(name: str, **attrs):
|
||||
_stub_module("qbittorrentapi", TorrentFilesList=list)
|
||||
_stub_module("transmission_rpc", File=object)
|
||||
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.chain.search import SearchChain
|
||||
from app.agent_context import agent_execution_context
|
||||
from app.core.config import settings
|
||||
from app.schemas import Notification
|
||||
from app.schemas.types import NotificationType
|
||||
|
||||
|
||||
def _make_result(title: str, size: int, seeders: int):
|
||||
@@ -131,7 +127,7 @@ class SearchChainAIRecommendTest(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual("[0, 2]", result)
|
||||
self.assertTrue(captured["suppress_user_reply"])
|
||||
self.assertFalse(captured["persist_output_message"])
|
||||
self.assertTrue(captured["suppress_message_channel_dispatch"])
|
||||
self.assertFalse(captured["allow_message_tools"])
|
||||
|
||||
def test_search_by_title_clears_previous_recommend_state_when_caching(self):
|
||||
chain = self._make_chain()
|
||||
@@ -156,45 +152,22 @@ class SearchChainAIRecommendTest(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertIsNone(SearchChain._ai_recommend_result)
|
||||
self.assertIsNone(SearchChain._ai_recommend_error)
|
||||
|
||||
def test_post_message_skips_channel_dispatch_when_suppressed(self):
|
||||
chain = object.__new__(SearchChain)
|
||||
queue_calls = []
|
||||
event_calls = []
|
||||
saved_messages = []
|
||||
saved_records = []
|
||||
chain.messagehelper = SimpleNamespace(
|
||||
put=lambda *args, **kwargs: saved_messages.append((args, kwargs))
|
||||
)
|
||||
chain.messageoper = SimpleNamespace(
|
||||
add=lambda **kwargs: saved_records.append(kwargs)
|
||||
)
|
||||
chain.messagequeue = SimpleNamespace(
|
||||
send_message=lambda *args, **kwargs: queue_calls.append((args, kwargs))
|
||||
)
|
||||
chain.eventmanager = SimpleNamespace(
|
||||
send_event=lambda *args, **kwargs: event_calls.append((args, kwargs))
|
||||
)
|
||||
|
||||
notification = Notification(
|
||||
mtype=NotificationType.Manual,
|
||||
title="Title",
|
||||
text="Body",
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
chain_module.MessageTemplateHelper,
|
||||
"render",
|
||||
return_value=notification,
|
||||
),
|
||||
agent_execution_context(suppress_message_channel_dispatch=True),
|
||||
def test_tool_factory_excludes_message_tools_when_disabled(self):
|
||||
with patch(
|
||||
"app.agent.tools.factory.PluginManager.get_plugin_agent_tools",
|
||||
return_value=[],
|
||||
):
|
||||
chain.post_message(message=notification)
|
||||
tools = MoviePilotToolFactory.create_tools(
|
||||
session_id="test-session",
|
||||
user_id="test-user",
|
||||
allow_message_tools=False,
|
||||
)
|
||||
|
||||
self.assertEqual(1, len(saved_messages))
|
||||
self.assertEqual(1, len(saved_records))
|
||||
self.assertEqual([], queue_calls)
|
||||
self.assertEqual([], event_calls)
|
||||
tool_names = {tool.name for tool in tools}
|
||||
self.assertNotIn("send_message", tool_names)
|
||||
self.assertNotIn("ask_user_choice", tool_names)
|
||||
self.assertNotIn("send_local_file", tool_names)
|
||||
self.assertNotIn("send_voice_message", tool_names)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user