Files
MoviePilot/app/agent/tools/base.py
2025-11-17 10:18:05 +08:00

75 lines
2.4 KiB
Python

"""MoviePilot工具基类"""
from abc import ABCMeta, abstractmethod
from typing import Callable, Any
from langchain.tools import BaseTool
from pydantic import PrivateAttr
from app.agent import StreamingCallbackHandler
from app.chain import ChainBase
from app.schemas import Notification
class ToolChain(ChainBase):
pass
class MoviePilotTool(BaseTool, metaclass=ABCMeta):
"""MoviePilot专用工具基类"""
_session_id: str = PrivateAttr()
_user_id: str = PrivateAttr()
_channel: str = PrivateAttr(default=None)
_source: str = PrivateAttr(default=None)
_username: str = PrivateAttr(default=None)
_callback_handler: StreamingCallbackHandler = PrivateAttr(default=None)
def __init__(self, session_id: str, user_id: str, **kwargs):
super().__init__(**kwargs)
self._session_id = session_id
self._user_id = user_id
def _run(self, *args: Any, **kwargs: Any) -> Any:
pass
async def _arun(self, **kwargs) -> str:
"""异步运行工具"""
# 发送运行工具前的消息
agent_message = await self._callback_handler.get_message()
if agent_message:
await self.send_tool_message(agent_message)
# 发送执行工具说明
explanation = kwargs.get("explanation")
if explanation:
if not explanation.startswith("正在"):
explanation = "正在" + explanation
await self.send_tool_message(f"{explanation} ...")
return await self.run(**kwargs)
@abstractmethod
async def run(self, **kwargs) -> str:
raise NotImplementedError
def set_message_attr(self, channel: str, source: str, username: str):
"""设置消息属性"""
self._channel = channel
self._source = source
self._username = username
def set_callback_handler(self, callback_handler: StreamingCallbackHandler):
"""设置回调处理器"""
self._callback_handler = callback_handler
async def send_tool_message(self, message: str, title: str = ""):
"""发送工具消息"""
await ToolChain().async_post_message(
Notification(
channel=self._channel,
source=self._source,
userid=self._user_id,
username=self._username,
title=title,
text=message
)
)