mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-09 06:42:38 +08:00
refactor: unify message interactions
This commit is contained in:
@@ -1,107 +0,0 @@
|
||||
"""Agent 客户端交互请求管理。"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from threading import Lock
|
||||
from typing import Dict, List, Optional
|
||||
import uuid
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AgentInteractionOption:
|
||||
"""交互选项。"""
|
||||
|
||||
label: str
|
||||
value: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingAgentInteraction:
|
||||
"""待处理的 Agent 客户端交互请求。"""
|
||||
|
||||
request_id: str
|
||||
session_id: str
|
||||
user_id: str
|
||||
channel: Optional[str]
|
||||
source: Optional[str]
|
||||
username: Optional[str]
|
||||
title: Optional[str]
|
||||
prompt: str
|
||||
options: List[AgentInteractionOption]
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
class AgentInteractionManager:
|
||||
"""管理 Agent 发起的客户端交互请求。"""
|
||||
|
||||
_ttl = timedelta(hours=24)
|
||||
|
||||
def __init__(self):
|
||||
self._pending_interactions: Dict[str, PendingAgentInteraction] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def _cleanup_locked(self):
|
||||
expire_before = datetime.now() - self._ttl
|
||||
expired_ids = [
|
||||
request_id
|
||||
for request_id, request in self._pending_interactions.items()
|
||||
if request.created_at < expire_before
|
||||
]
|
||||
for request_id in expired_ids:
|
||||
self._pending_interactions.pop(request_id, None)
|
||||
|
||||
def create_request(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
channel: Optional[str],
|
||||
source: Optional[str],
|
||||
username: Optional[str],
|
||||
title: Optional[str],
|
||||
prompt: str,
|
||||
options: List[AgentInteractionOption],
|
||||
) -> PendingAgentInteraction:
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request_id = uuid.uuid4().hex[:12]
|
||||
while request_id in self._pending_interactions:
|
||||
request_id = uuid.uuid4().hex[:12]
|
||||
request = PendingAgentInteraction(
|
||||
request_id=request_id,
|
||||
session_id=session_id,
|
||||
user_id=str(user_id),
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username,
|
||||
title=title,
|
||||
prompt=prompt,
|
||||
options=options,
|
||||
)
|
||||
self._pending_interactions[request_id] = request
|
||||
return request
|
||||
|
||||
def resolve(
|
||||
self,
|
||||
request_id: str,
|
||||
option_index: int,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Optional[tuple[PendingAgentInteraction, AgentInteractionOption]]:
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request = self._pending_interactions.get(request_id)
|
||||
if not request:
|
||||
return None
|
||||
if user_id is not None and str(request.user_id) != str(user_id):
|
||||
return None
|
||||
if option_index < 1 or option_index > len(request.options):
|
||||
return None
|
||||
option = request.options[option_index - 1]
|
||||
self._pending_interactions.pop(request_id, None)
|
||||
return request, option
|
||||
|
||||
def clear(self):
|
||||
with self._lock:
|
||||
self._pending_interactions.clear()
|
||||
|
||||
|
||||
agent_interaction_manager = AgentInteractionManager()
|
||||
@@ -5,7 +5,7 @@ from typing import List, Optional, Type
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.agent.interaction import (
|
||||
from app.chain.interaction import (
|
||||
AgentInteractionOption,
|
||||
agent_interaction_manager,
|
||||
)
|
||||
|
||||
1363
app/chain/interaction.py
Normal file
1363
app/chain/interaction.py
Normal file
File diff suppressed because it is too large
Load Diff
1074
app/chain/message.py
1074
app/chain/message.py
File diff suppressed because it is too large
Load Diff
@@ -8,7 +8,7 @@ from app.agent.tools.impl.ask_user_choice import (
|
||||
AskUserChoiceTool,
|
||||
UserChoiceOptionInput,
|
||||
)
|
||||
from app.agent.interaction import (
|
||||
from app.chain.interaction import (
|
||||
AgentInteractionOption,
|
||||
agent_interaction_manager,
|
||||
)
|
||||
|
||||
158
tests/test_media_interaction.py
Normal file
158
tests/test_media_interaction.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import sys
|
||||
import unittest
|
||||
from types import ModuleType
|
||||
from unittest.mock import patch
|
||||
|
||||
sys.modules.setdefault("qbittorrentapi", ModuleType("qbittorrentapi"))
|
||||
setattr(sys.modules["qbittorrentapi"], "TorrentFilesList", list)
|
||||
sys.modules.setdefault("transmission_rpc", ModuleType("transmission_rpc"))
|
||||
setattr(sys.modules["transmission_rpc"], "File", object)
|
||||
sys.modules.setdefault("psutil", ModuleType("psutil"))
|
||||
|
||||
from app.chain.interaction import MediaInteractionChain, media_interaction_manager
|
||||
from app.chain.message import MessageChain
|
||||
from app.core.context import MediaInfo
|
||||
from app.core.meta import MetaBase
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class TestMediaInteraction(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
media_interaction_manager.clear()
|
||||
|
||||
@staticmethod
|
||||
def _build_meta(name: str) -> MetaBase:
|
||||
meta = MetaBase(name)
|
||||
meta.name = name
|
||||
meta.begin_season = 1
|
||||
return meta
|
||||
|
||||
def test_message_routes_text_reply_to_media_interaction_before_ai(self):
|
||||
chain = MessageChain()
|
||||
request = media_interaction_manager.create_or_replace(
|
||||
user_id="10001",
|
||||
channel=MessageChannel.Wechat,
|
||||
source="wechat-test",
|
||||
username="tester",
|
||||
action="Search",
|
||||
keyword="星际穿越",
|
||||
title="星际穿越",
|
||||
meta=self._build_meta("星际穿越"),
|
||||
items=[MediaInfo(title="星际穿越", year="2014")],
|
||||
)
|
||||
self.assertIsNotNone(request)
|
||||
|
||||
with patch.object(chain, "_record_user_message"), patch(
|
||||
"app.chain.message.MediaInteractionChain.handle_text_interaction",
|
||||
return_value=True,
|
||||
) as handle_text, patch.object(chain, "_handle_ai_message") as handle_ai:
|
||||
chain.handle_message(
|
||||
channel=MessageChannel.Wechat,
|
||||
source="wechat-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
text="1",
|
||||
)
|
||||
|
||||
handle_text.assert_called_once()
|
||||
handle_ai.assert_not_called()
|
||||
|
||||
def test_callback_routes_to_media_interaction_chain(self):
|
||||
chain = MessageChain()
|
||||
request = media_interaction_manager.create_or_replace(
|
||||
user_id="10001",
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
username="tester",
|
||||
action="Search",
|
||||
keyword="星际穿越",
|
||||
title="星际穿越",
|
||||
meta=self._build_meta("星际穿越"),
|
||||
items=[MediaInfo(title="星际穿越", year="2014")],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.chain.message.MediaInteractionChain.handle_callback_interaction",
|
||||
return_value=True,
|
||||
) as handle_callback:
|
||||
chain._handle_callback(
|
||||
text=f"CALLBACK:media:{request.request_id}:page-next",
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
)
|
||||
|
||||
handle_callback.assert_called_once()
|
||||
|
||||
def test_media_interaction_starts_search_and_posts_media_list(self):
|
||||
chain = MediaInteractionChain()
|
||||
meta = self._build_meta("星际穿越")
|
||||
medias = [
|
||||
MediaInfo(title="星际穿越", year="2014"),
|
||||
MediaInfo(title="Interstellar", year="2014"),
|
||||
]
|
||||
|
||||
with patch(
|
||||
"app.chain.interaction.MediaChain.search",
|
||||
return_value=(meta, medias),
|
||||
), patch.object(chain, "post_medias_message") as post_medias_message:
|
||||
handled = chain.handle_text_interaction(
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
text="星际穿越",
|
||||
)
|
||||
|
||||
self.assertTrue(handled)
|
||||
post_medias_message.assert_called_once()
|
||||
notification = post_medias_message.call_args.args[0]
|
||||
self.assertTrue(notification.buttons)
|
||||
self.assertTrue(
|
||||
notification.buttons[0][0]["callback_data"].startswith("media:")
|
||||
)
|
||||
|
||||
request = media_interaction_manager.get_by_user("10001")
|
||||
self.assertIsNotNone(request)
|
||||
self.assertEqual(request.action, "Search")
|
||||
self.assertEqual(len(request.items), 2)
|
||||
|
||||
def test_media_interaction_legacy_page_callback_updates_existing_request(self):
|
||||
chain = MediaInteractionChain()
|
||||
request = media_interaction_manager.create_or_replace(
|
||||
user_id="10001",
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
username="tester",
|
||||
action="Search",
|
||||
keyword="星际穿越",
|
||||
title="星际穿越",
|
||||
meta=self._build_meta("星际穿越"),
|
||||
items=[
|
||||
MediaInfo(title=f"资源 {index}", year="2024")
|
||||
for index in range(1, 11)
|
||||
],
|
||||
)
|
||||
|
||||
with patch.object(chain, "post_medias_message") as post_medias_message:
|
||||
handled = chain.handle_callback_interaction(
|
||||
callback_data="page_n",
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
original_message_id=123,
|
||||
original_chat_id="456",
|
||||
)
|
||||
|
||||
self.assertTrue(handled)
|
||||
self.assertEqual(request.page, 1)
|
||||
post_medias_message.assert_called_once()
|
||||
notification = post_medias_message.call_args.args[0]
|
||||
self.assertEqual(notification.original_message_id, 123)
|
||||
self.assertEqual(notification.original_chat_id, "456")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user