diff --git a/app/agent/prompt/Agent Prompt.txt b/app/agent/prompt/Agent Prompt.txt index 0920731c..e33fbbf3 100644 --- a/app/agent/prompt/Agent Prompt.txt +++ b/app/agent/prompt/Agent Prompt.txt @@ -9,6 +9,7 @@ Core Capabilities: 2. Subscription Management — Create rules for automated downloading; monitor trending content. 3. Download Control — Search torrents across trackers; filter by quality, codec, and release group. 4. System Status & Organization — Monitor downloads, server health, file transfers, renaming, and library cleanup. +5. Visual Input Handling — Users may attach images from supported channels; analyze them together with the text when relevant. {verbose_spec} @@ -19,6 +20,7 @@ Core Capabilities: - Use Markdown for structured data. Use `inline code` for media titles/paths. - Include key details (year, rating, resolution) but do NOT over-explain. - Do not stop for approval on read-only operations. Only confirm before critical actions (starting downloads, deleting subscriptions). +- If the current channel supports image sending and an image would materially help, you may use the `send_message` tool with `image_url` to send it. - NOT a coding assistant. Do not offer code snippets. - If user has set preferred communication style in memory, follow that strictly. diff --git a/app/agent/tools/base.py b/app/agent/tools/base.py index e3ef5b01..40f2bc3a 100644 --- a/app/agent/tools/base.py +++ b/app/agent/tools/base.py @@ -249,7 +249,9 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): return None - async def send_tool_message(self, message: str, title: str = ""): + async def send_tool_message( + self, message: str, title: str = "", image: Optional[str] = None + ): """ 发送工具消息 """ @@ -261,5 +263,6 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): username=self._username, title=title, text=message, + image=image, ) ) diff --git a/app/agent/tools/impl/send_message.py b/app/agent/tools/impl/send_message.py index d3ae67e5..9e2dadc7 100644 --- a/app/agent/tools/impl/send_message.py +++ b/app/agent/tools/impl/send_message.py @@ -2,7 +2,7 @@ from typing import Optional, Type -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from app.agent.tools.base import MoviePilotTool from app.log import logger @@ -15,42 +15,64 @@ class SendMessageInput(BaseModel): ..., description="Clear explanation of why this tool is being used in the current context", ) - message: str = Field( - ..., + message: Optional[str] = Field( + None, description="The message content to send to the user (should be clear and informative)", ) message_type: Optional[str] = Field( None, description="Title of the message, a short summary of the message content", ) + image_url: Optional[str] = Field( + None, + description="Optional image URL to send together with the message on channels that support images (such as Telegram and Slack)", + ) + + @model_validator(mode="after") + def validate_payload(self): + if not self.message and not self.message_type and not self.image_url: + raise ValueError("message、message_type、image_url 至少需要提供一个") + return self class SendMessageTool(MoviePilotTool): name: str = "send_message" - description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Used to inform users about operation results, errors, or important updates." + description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Supports optional image_url on channels that can send images. Used to inform users about operation results, errors, important updates, or proactively send a relevant image." args_schema: Type[BaseModel] = SendMessageInput require_admin: bool = True def get_tool_message(self, **kwargs) -> Optional[str]: """根据消息参数生成友好的提示消息""" - message = kwargs.get("message", "") + message = kwargs.get("message", "") or "" title = kwargs.get("message_type") or "" + image_url = kwargs.get("image_url") # 截断过长的消息 if len(message) > 50: message = message[:50] + "..." + if title and image_url: + return f"正在发送图文消息: [{title}] {message}" if title: return f"正在发送消息: [{title}] {message}" + if image_url: + return f"正在发送图片消息: {message}" return f"正在发送消息: {message}" async def run( - self, message: str, message_type: Optional[str] = None, **kwargs + self, + message: Optional[str] = None, + message_type: Optional[str] = None, + image_url: Optional[str] = None, + **kwargs, ) -> str: - title = message_type or "" - logger.info(f"执行工具: {self.name}, 参数: title={title}, message={message}") + title = message_type or ("图片" if image_url and not message else "") + text = message or "" + logger.info( + f"执行工具: {self.name}, 参数: title={title}, message={text}, image_url={image_url}" + ) try: - await self.send_tool_message(message, title=title) + await self.send_tool_message(text, title=title, image=image_url) return "消息已发送" except Exception as e: logger.error(f"发送消息失败: {e}") diff --git a/app/chain/message.py b/app/chain/message.py index aeba1d25..c12e610b 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -126,15 +126,15 @@ class MessageChain(ChainBase): logger.debug(f"未识别到用户ID:{body}{form}{args}") return # 消息内容 - text = str(info.text).strip() if info.text else None - if not text: + text = str(info.text).strip() if info.text else "" + images = info.images + if not text and not images: logger.debug(f"未识别到消息内容::{body}{form}{args}") return # 获取原消息ID信息 original_message_id = info.message_id original_chat_id = info.chat_id - images = info.images # 处理消息 self.handle_message( @@ -221,6 +221,16 @@ class MessageChain(ChainBase): username=username, images=images, ) + elif settings.AI_AGENT_ENABLE and images: + # 带图消息优先交给智能体处理,避免图片在传统消息链路中丢失 + self._handle_ai_message( + text=text, + channel=channel, + source=source, + userid=userid, + username=username, + images=images, + ) elif settings.AI_AGENT_ENABLE and settings.AI_AGENT_GLOBAL: # 普通消息,全局智能体响应 self._handle_ai_message( @@ -1234,8 +1244,20 @@ class MessageChain(ChainBase): session_id = self._get_or_create_session_id(userid) # 下载图片并转为base64 + original_images = images if images: images = self._download_images_to_base64(images, channel, source) + if original_images and not images and not user_message: + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title="图片读取失败,请稍后重试", + ) + ) + return # 在事件循环中处理 asyncio.run_coroutine_threadsafe( @@ -1275,6 +1297,12 @@ class MessageChain(ChainBase): ) if base64_data: base64_images.append(f"data:image/jpeg;base64,{base64_data}") + elif channel == MessageChannel.Slack: + data_url = self.run_module( + "download_file_to_data_url", file_url=img, source=source + ) + if data_url: + base64_images.append(data_url) elif img.startswith("http"): resp = RequestUtils(timeout=30).get_res(img) if resp and resp.content: diff --git a/app/modules/slack/__init__.py b/app/modules/slack/__init__.py index cdcd7e4b..4f6d3cff 100644 --- a/app/modules/slack/__init__.py +++ b/app/modules/slack/__init__.py @@ -279,12 +279,40 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): return None images = [] for file in files: - if file.get("type") in ("image", "jpg", "jpeg", "png", "gif", "webp"): + file_type = str(file.get("type", "")).lower() + file_ext = str(file.get("filetype", "")).lower() + mime_type = str(file.get("mimetype", "")).lower() + if ( + file_type == "image" + or file_ext in ("jpg", "jpeg", "png", "gif", "webp", "bmp") + or mime_type.startswith("image/") + ): url = file.get("url_private") or file.get("url_private_download") if url: images.append(url) return images if images else None + def download_file_to_data_url(self, file_url: str, source: str) -> Optional[str]: + """ + 下载Slack文件并转为data URL + :param file_url: Slack私有文件URL + :param source: 来源名称 + :return: data URL + """ + config = self.get_config(source) + if not config: + return None + client = self.get_instance(config.name) + if not client: + return None + file_data = client.download_file(file_url) + if file_data: + import base64 + + content, mime_type = file_data + return f"data:{mime_type};base64,{base64.b64encode(content).decode()}" + return None + def post_message(self, message: Notification, **kwargs) -> None: """ 发送消息 diff --git a/app/modules/slack/slack.py b/app/modules/slack/slack.py index 16f890c0..982dc717 100644 --- a/app/modules/slack/slack.py +++ b/app/modules/slack/slack.py @@ -1,6 +1,6 @@ import re from threading import Lock -from typing import List, Optional +from typing import List, Optional, Tuple from urllib.parse import quote import requests @@ -12,6 +12,7 @@ from app.core.config import settings from app.core.context import MediaInfo, Context from app.core.metainfo import MetaInfo from app.log import logger +from app.utils.http import RequestUtils from app.utils.string import StringUtils lock = Lock() @@ -22,6 +23,7 @@ class Slack: _service: SocketModeHandler = None _ds_url = f"http://127.0.0.1:{settings.PORT}/api/v1/message?token={settings.API_TOKEN}" _channel = "" + _oauth_token = "" def __init__(self, SLACK_OAUTH_TOKEN: Optional[str] = None, SLACK_APP_TOKEN: Optional[str] = None, SLACK_CHANNEL: Optional[str] = None, **kwargs): @@ -40,6 +42,7 @@ class Slack: self._client = slack_app.client self._channel = SLACK_CHANNEL + self._oauth_token = SLACK_OAUTH_TOKEN # 标记消息来源 if kwargs.get("name"): @@ -102,6 +105,28 @@ class Slack: """ return True if self._client else False + def download_file(self, file_url: str) -> Optional[Tuple[bytes, str]]: + """ + 下载Slack私有文件 + :param file_url: Slack文件URL + :return: (文件内容, MIME类型) + """ + if not self._client or not self._oauth_token or not file_url: + return None + try: + headers = { + "Authorization": f"Bearer {self._oauth_token}", + "User-Agent": settings.USER_AGENT, + "Accept": "*/*", + } + resp = RequestUtils(headers=headers, timeout=30).get_res(file_url) + if resp and resp.content: + mime_type = resp.headers.get("Content-Type", "image/jpeg") + return resp.content, mime_type.split(";")[0] + except Exception as e: + logger.error(f"下载Slack文件失败: {e}") + return None + def send_msg(self, title: str, text: Optional[str] = None, image: Optional[str] = None, link: Optional[str] = None, userid: Optional[str] = None, buttons: Optional[List[List[dict]]] = None, diff --git a/app/modules/telegram/__init__.py b/app/modules/telegram/__init__.py index cff0f7ed..547ebc08 100644 --- a/app/modules/telegram/__init__.py +++ b/app/modules/telegram/__init__.py @@ -267,14 +267,14 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): largest_photo = photo[-1] file_id = largest_photo.get("file_id") if file_id: - images.append(file_id) + images.append(f"tg://file_id/{file_id}") document = msg.get("document") if document: file_id = document.get("file_id") mime_type = document.get("mime_type", "") if file_id and mime_type.startswith("image/"): - images.append(file_id) + images.append(f"tg://file_id/{file_id}") return images if images else None diff --git a/tests/test_agent_image_support.py b/tests/test_agent_image_support.py new file mode 100644 index 00000000..d52e0d08 --- /dev/null +++ b/tests/test_agent_image_support.py @@ -0,0 +1,119 @@ +import base64 +import unittest +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from app.agent.tools.impl.send_message import SendMessageInput +from app.chain.message import MessageChain +from app.core.config import settings +from app.modules.slack import SlackModule +from app.modules.telegram import TelegramModule +from app.schemas import CommingMessage +from app.schemas.types import MessageChannel + + +class AgentImageSupportTest(unittest.TestCase): + def test_telegram_extract_images_returns_prefixed_file_ids(self): + images = TelegramModule._extract_images( + { + "photo": [{"file_id": "small"}, {"file_id": "large"}], + "document": {"file_id": "doc-image", "mime_type": "image/png"}, + } + ) + + self.assertEqual( + images, + ["tg://file_id/large", "tg://file_id/doc-image"], + ) + + def test_process_allows_image_only_message(self): + chain = MessageChain() + message = CommingMessage( + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + username="tester", + images=["tg://file_id/image-1"], + ) + + with patch.object(chain, "message_parser", return_value=message), patch.object( + chain, "handle_message" + ) as handle_message: + chain.process(body="{}", form={}, args={"source": "telegram-test"}) + + handle_kwargs = handle_message.call_args.kwargs + self.assertEqual(handle_kwargs["text"], "") + self.assertEqual(handle_kwargs["images"], ["tg://file_id/image-1"]) + + def test_image_message_routes_to_agent_even_when_global_agent_is_disabled(self): + chain = MessageChain() + + with patch.object(chain, "load_cache", return_value={}), patch.object( + chain.messagehelper, "put" + ), patch.object(chain.messageoper, "add"), patch.object( + chain, "_handle_ai_message" + ) as handle_ai_message, patch.object( + settings, "AI_AGENT_ENABLE", True + ), patch.object( + settings, "AI_AGENT_GLOBAL", False + ): + chain.handle_message( + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + username="tester", + text="", + images=["tg://file_id/image-1"], + ) + + handle_ai_message.assert_called_once() + + def test_slack_images_use_authenticated_data_url_download(self): + chain = MessageChain() + + with patch.object( + chain, + "run_module", + return_value="data:image/png;base64,abc123", + ) as run_module: + images = chain._download_images_to_base64( + images=["https://files.slack.com/files-pri/T1-F1/test.png"], + channel=MessageChannel.Slack, + source="slack-test", + ) + + self.assertEqual(images, ["data:image/png;base64,abc123"]) + run_module.assert_called_once_with( + "download_file_to_data_url", + file_url="https://files.slack.com/files-pri/T1-F1/test.png", + source="slack-test", + ) + + def test_slack_module_download_file_to_data_url(self): + module = SlackModule() + client = Mock() + client.download_file.return_value = (b"png-binary", "image/png") + + with patch.object( + module, "get_config", return_value=SimpleNamespace(name="slack-test") + ), patch.object(module, "get_instance", return_value=client): + data_url = module.download_file_to_data_url( + "https://files.slack.com/files-pri/T1-F1/test.png", + "slack-test", + ) + + self.assertEqual( + data_url, + f"data:image/png;base64,{base64.b64encode(b'png-binary').decode()}", + ) + + def test_send_message_input_accepts_image_only_payload(self): + payload = SendMessageInput( + explanation="send poster image", + image_url="https://example.com/poster.png", + ) + + self.assertEqual(payload.image_url, "https://example.com/poster.png") + +if __name__ == "__main__": + unittest.main()