mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-07-02 01:49:43 +08:00
feat(subscribe): add modified event payload contract (#6012)
This commit is contained in:
@@ -10,6 +10,7 @@ from app.agent.tools.tags import ToolTag
|
||||
from app.core.event import eventmanager
|
||||
from app.db.subscribe_oper import SubscribeOper
|
||||
from app.log import logger
|
||||
from app.schemas.event import SubscribeModifiedEventData
|
||||
from app.schemas.types import EventType
|
||||
|
||||
|
||||
@@ -261,13 +262,14 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
# 发送订阅调整事件
|
||||
await eventmanager.async_send_event(
|
||||
EventType.SubscribeModified,
|
||||
{
|
||||
"subscribe_id": subscribe_id,
|
||||
"old_subscribe_info": old_subscribe_dict,
|
||||
"subscribe_info": updated_subscribe.to_dict()
|
||||
SubscribeModifiedEventData(
|
||||
subscribe_id=subscribe_id,
|
||||
old_subscribe_info=old_subscribe_dict,
|
||||
subscribe_info=updated_subscribe.to_dict()
|
||||
if updated_subscribe
|
||||
else {},
|
||||
},
|
||||
scene="agent_update",
|
||||
).to_dict(),
|
||||
)
|
||||
|
||||
# 构建返回结果
|
||||
|
||||
@@ -21,6 +21,7 @@ from app.db.user_oper import get_current_active_user_async
|
||||
from app.helper.server import MoviePilotServerHelper
|
||||
from app.log import logger
|
||||
from app.scheduler import Scheduler
|
||||
from app.schemas.event import SubscribeModifiedEventData
|
||||
from app.schemas.types import MediaType, EventType, SystemConfigKey
|
||||
|
||||
router = APIRouter()
|
||||
@@ -149,11 +150,12 @@ async def update_subscribe(
|
||||
# 发送订阅调整事件
|
||||
await eventmanager.async_send_event(
|
||||
EventType.SubscribeModified,
|
||||
{
|
||||
"subscribe_id": subscribe_in.id,
|
||||
"old_subscribe_info": old_subscribe_dict,
|
||||
"subscribe_info": updated_subscribe.to_dict() if updated_subscribe else {},
|
||||
},
|
||||
SubscribeModifiedEventData(
|
||||
subscribe_id=subscribe_in.id,
|
||||
old_subscribe_info=old_subscribe_dict,
|
||||
subscribe_info=updated_subscribe.to_dict() if updated_subscribe else {},
|
||||
scene="update",
|
||||
).to_dict(),
|
||||
)
|
||||
return schemas.Response(success=True)
|
||||
|
||||
@@ -181,11 +183,12 @@ async def update_subscribe_status(
|
||||
# 发送订阅调整事件
|
||||
await eventmanager.async_send_event(
|
||||
EventType.SubscribeModified,
|
||||
{
|
||||
"subscribe_id": subid,
|
||||
"old_subscribe_info": old_subscribe_dict,
|
||||
"subscribe_info": updated_subscribe.to_dict() if updated_subscribe else {},
|
||||
},
|
||||
SubscribeModifiedEventData(
|
||||
subscribe_id=subid,
|
||||
old_subscribe_info=old_subscribe_dict,
|
||||
subscribe_info=updated_subscribe.to_dict() if updated_subscribe else {},
|
||||
scene="status",
|
||||
).to_dict(),
|
||||
)
|
||||
return schemas.Response(success=True)
|
||||
|
||||
@@ -275,13 +278,14 @@ async def reset_subscribes(
|
||||
# 发送订阅调整事件
|
||||
await eventmanager.async_send_event(
|
||||
EventType.SubscribeModified,
|
||||
{
|
||||
"subscribe_id": subid,
|
||||
"old_subscribe_info": old_subscribe_dict,
|
||||
"subscribe_info": updated_subscribe.to_dict()
|
||||
SubscribeModifiedEventData(
|
||||
subscribe_id=subid,
|
||||
old_subscribe_info=old_subscribe_dict,
|
||||
subscribe_info=updated_subscribe.to_dict()
|
||||
if updated_subscribe
|
||||
else {},
|
||||
},
|
||||
scene="reset",
|
||||
).to_dict(),
|
||||
)
|
||||
return schemas.Response(success=True)
|
||||
return schemas.Response(success=False, message="订阅不存在")
|
||||
|
||||
@@ -603,6 +603,48 @@ class SubscribeEpisodesRefreshEventData(ChainEventData):
|
||||
reason: str = Field(default="", description="覆盖原因")
|
||||
|
||||
|
||||
class SubscribeModifiedEventData(BaseEventData):
|
||||
"""
|
||||
SubscribeModified 广播事件数据。
|
||||
|
||||
主程序在订阅字段被普通更新、状态入口、重置或 Agent 更新后发出。payload
|
||||
继续保持 dict 形态,scene 用于表达操作场景,fields 表达最终快照里的真实字段差异。
|
||||
"""
|
||||
|
||||
subscribe_id: int = Field(description="订阅 ID")
|
||||
old_subscribe_info: Dict[str, Any] = Field(default_factory=dict, description="更新前订阅快照")
|
||||
subscribe_info: Dict[str, Any] = Field(default_factory=dict, description="更新后订阅快照")
|
||||
scene: str = Field(default="update", description="触发场景:update/status/reset/agent_update")
|
||||
fields: List[str] = Field(default_factory=list, description="真实变更字段")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def compute_fields(self):
|
||||
self.fields = self._diff_fields(self.old_subscribe_info, self.subscribe_info)
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def _diff_fields(old_info: Dict[str, Any], new_info: Dict[str, Any]) -> List[str]:
|
||||
"""
|
||||
按 old/new 快照并集计算真实字段差异;缺失 key 按 None 参与比较。
|
||||
"""
|
||||
old_info = old_info or {}
|
||||
new_info = new_info or {}
|
||||
keys = set(old_info) | set(new_info)
|
||||
return sorted(key for key in keys if old_info.get(key) != new_info.get(key))
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
输出公开事件 payload,避免内部属性被未来扩展意外暴露。
|
||||
"""
|
||||
return {
|
||||
"subscribe_id": self.subscribe_id,
|
||||
"old_subscribe_info": self.old_subscribe_info,
|
||||
"subscribe_info": self.subscribe_info,
|
||||
"scene": self.scene,
|
||||
"fields": list(self.fields),
|
||||
}
|
||||
|
||||
|
||||
class SubscribeCompletionCheckEventData(ChainEventData):
|
||||
"""
|
||||
SubscribeCompletionCheck 事件的数据模型
|
||||
|
||||
74
tests/test_agent_update_subscribe.py
Normal file
74
tests/test_agent_update_subscribe.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from app.agent.tools.impl.update_subscribe import UpdateSubscribeTool
|
||||
from app.schemas.types import EventType
|
||||
|
||||
|
||||
def test_agent_update_subscribe_sends_modified_event_payload_with_agent_scene():
|
||||
"""
|
||||
Agent 更新订阅后只发送 modify 事件,并标记 agent_update 场景。
|
||||
"""
|
||||
subscribe = _AgentSubscribe(id=9, name="旧标题", state="R", total_episode=8)
|
||||
oper = _SubscribeOperStub(subscribe)
|
||||
|
||||
with patch(
|
||||
"app.agent.tools.impl.update_subscribe.SubscribeOper",
|
||||
return_value=oper,
|
||||
), patch(
|
||||
"app.agent.tools.impl.update_subscribe.eventmanager.async_send_event",
|
||||
new=AsyncMock(),
|
||||
) as send_event:
|
||||
result = asyncio.run(
|
||||
UpdateSubscribeTool(session_id="session-1", user_id="10001").run(
|
||||
subscribe_id=9,
|
||||
name="新标题",
|
||||
state="S",
|
||||
)
|
||||
)
|
||||
|
||||
payload = json.loads(result)
|
||||
assert payload["success"] is True
|
||||
assert oper.updates == [(9, {"name": "新标题", "state": "S"})]
|
||||
send_event.assert_awaited_once()
|
||||
event_type, event_payload = send_event.await_args.args
|
||||
assert event_type == EventType.SubscribeModified
|
||||
assert event_payload["subscribe_id"] == 9
|
||||
assert event_payload["scene"] == "agent_update"
|
||||
assert event_payload["fields"] == ["name", "state"]
|
||||
assert event_payload["old_subscribe_info"]["name"] == "旧标题"
|
||||
assert event_payload["subscribe_info"]["name"] == "新标题"
|
||||
|
||||
|
||||
class _AgentSubscribe:
|
||||
"""
|
||||
最小订阅替身,模拟 Agent 工具依赖的订阅对象接口。
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.__dict__.update(kwargs)
|
||||
|
||||
def __getattr__(self, item):
|
||||
return None
|
||||
|
||||
def to_dict(self):
|
||||
return dict(self.__dict__)
|
||||
|
||||
|
||||
class _SubscribeOperStub:
|
||||
"""
|
||||
内存订阅操作替身,记录工具最终提交的更新字段。
|
||||
"""
|
||||
|
||||
def __init__(self, subscribe):
|
||||
self.subscribe = subscribe
|
||||
self.updates = []
|
||||
|
||||
async def async_get(self, subscribe_id):
|
||||
return self.subscribe if subscribe_id == self.subscribe.id else None
|
||||
|
||||
async def async_update(self, subscribe_id, payload):
|
||||
self.updates.append((subscribe_id, dict(payload)))
|
||||
self.subscribe.__dict__.update(payload)
|
||||
return self.subscribe
|
||||
@@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, patch
|
||||
|
||||
from app.api.endpoints.subscribe import create_subscribe
|
||||
from app.schemas.subscribe import Subscribe
|
||||
from app.schemas.types import MediaType
|
||||
from app.schemas.types import EventType, MediaType
|
||||
|
||||
|
||||
class SubscribeEndpointTest(TestCase):
|
||||
@@ -73,3 +73,135 @@ class SubscribeEndpointTest(TestCase):
|
||||
|
||||
self.assertTrue(response.success)
|
||||
self.assertEqual(async_add.await_args.kwargs["season"], 0)
|
||||
|
||||
def test_update_status_sends_modified_event_payload_with_scene_and_fields(self):
|
||||
"""
|
||||
状态更新只负责发出订阅修改事件,并携带场景和真实变更字段。
|
||||
"""
|
||||
from app.api.endpoints.subscribe import update_subscribe_status
|
||||
|
||||
subscribe = _EndpointSubscribe(id=5, state="R", name="测试订阅")
|
||||
|
||||
with patch(
|
||||
"app.api.endpoints.subscribe.Subscribe.async_get",
|
||||
new=AsyncMock(side_effect=[subscribe, subscribe]),
|
||||
), patch(
|
||||
"app.api.endpoints.subscribe.eventmanager.async_send_event",
|
||||
new=AsyncMock(),
|
||||
) as send_event:
|
||||
response = asyncio.run(update_subscribe_status(subid=5, state="S", db=object()))
|
||||
|
||||
self.assertTrue(response.success)
|
||||
send_event.assert_awaited_once()
|
||||
event_type, payload = send_event.await_args.args
|
||||
self.assertEqual(event_type, EventType.SubscribeModified)
|
||||
self.assertEqual(payload["subscribe_id"], 5)
|
||||
self.assertEqual(payload["scene"], "status")
|
||||
self.assertEqual(payload["fields"], ["state"])
|
||||
self.assertEqual(payload["old_subscribe_info"]["state"], "R")
|
||||
self.assertEqual(payload["subscribe_info"]["state"], "S")
|
||||
|
||||
def test_reset_sends_modified_event_payload_with_reset_scene(self):
|
||||
"""
|
||||
reset 事件需要明确 scene,消费者不需要再从字段差异猜测用户意图。
|
||||
"""
|
||||
from app.api.endpoints.subscribe import reset_subscribes
|
||||
|
||||
subscribe = _EndpointSubscribe(
|
||||
id=6,
|
||||
state="S",
|
||||
name="测试订阅",
|
||||
total_episode=10,
|
||||
lack_episode=3,
|
||||
note=[1, 2],
|
||||
current_priority=80,
|
||||
episode_priority={"1": 80},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.api.endpoints.subscribe.Subscribe.async_get",
|
||||
new=AsyncMock(side_effect=[subscribe, subscribe]),
|
||||
), patch(
|
||||
"app.api.endpoints.subscribe.eventmanager.async_send_event",
|
||||
new=AsyncMock(),
|
||||
) as send_event:
|
||||
response = asyncio.run(reset_subscribes(subid=6, db=object()))
|
||||
|
||||
self.assertTrue(response.success)
|
||||
send_event.assert_awaited_once()
|
||||
event_type, payload = send_event.await_args.args
|
||||
self.assertEqual(event_type, EventType.SubscribeModified)
|
||||
self.assertEqual(payload["subscribe_id"], 6)
|
||||
self.assertEqual(payload["scene"], "reset")
|
||||
self.assertEqual(
|
||||
payload["fields"],
|
||||
["current_priority", "episode_priority", "lack_episode", "note", "state"],
|
||||
)
|
||||
self.assertEqual(payload["subscribe_info"]["note"], [])
|
||||
self.assertEqual(payload["subscribe_info"]["lack_episode"], 10)
|
||||
|
||||
def test_update_subscribe_sends_modified_event_payload_without_progress_refresh(self):
|
||||
"""
|
||||
普通更新只发送 modify 事件;进度刷新由事件消费者或后续流程处理。
|
||||
"""
|
||||
from app.api.endpoints.subscribe import update_subscribe
|
||||
|
||||
subscribe = _EndpointSubscribe(
|
||||
id=7,
|
||||
name="旧标题",
|
||||
total_episode=8,
|
||||
lack_episode=2,
|
||||
vote=0.0,
|
||||
sites=[],
|
||||
search_imdbid=0,
|
||||
filter_groups=[],
|
||||
start_episode=0,
|
||||
)
|
||||
subscribe_in = Subscribe(id=7, name="新标题", total_episode=8, lack_episode=2)
|
||||
|
||||
with patch(
|
||||
"app.api.endpoints.subscribe.Subscribe.async_get",
|
||||
new=AsyncMock(side_effect=[subscribe, subscribe]),
|
||||
), patch(
|
||||
"app.api.endpoints.subscribe.eventmanager.async_send_event",
|
||||
new=AsyncMock(),
|
||||
) as send_event:
|
||||
response = asyncio.run(update_subscribe(subscribe_in=subscribe_in, db=object()))
|
||||
|
||||
self.assertTrue(response.success)
|
||||
send_event.assert_awaited_once()
|
||||
event_type, payload = send_event.await_args.args
|
||||
self.assertEqual(event_type, EventType.SubscribeModified)
|
||||
self.assertEqual(payload["subscribe_id"], 7)
|
||||
self.assertEqual(payload["scene"], "update")
|
||||
self.assertEqual(payload["fields"], ["name"])
|
||||
self.assertEqual(payload["old_subscribe_info"]["name"], "旧标题")
|
||||
self.assertEqual(payload["subscribe_info"]["name"], "新标题")
|
||||
|
||||
|
||||
class _EndpointSubscribe:
|
||||
"""
|
||||
最小订阅替身,模拟 endpoint 依赖的 ORM 对象接口。
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.id = kwargs.pop("id", None)
|
||||
self.name = kwargs.pop("name", None)
|
||||
self.total_episode = kwargs.pop("total_episode", None)
|
||||
self.lack_episode = kwargs.pop("lack_episode", None)
|
||||
self.state = kwargs.pop("state", None)
|
||||
self.note = kwargs.pop("note", None)
|
||||
self.current_priority = kwargs.pop("current_priority", None)
|
||||
self.episode_priority = kwargs.pop("episode_priority", None)
|
||||
self.manual_total_episode = kwargs.pop("manual_total_episode", None)
|
||||
self.__dict__.update(kwargs)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
key: value
|
||||
for key, value in self.__dict__.items()
|
||||
if value is not None
|
||||
}
|
||||
|
||||
async def async_update(self, _db, payload):
|
||||
self.__dict__.update(payload)
|
||||
|
||||
50
tests/test_subscribe_modified_event.py
Normal file
50
tests/test_subscribe_modified_event.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from app.schemas.event import SubscribeModifiedEventData
|
||||
|
||||
|
||||
def test_subscribe_modified_event_data_computes_sorted_fields():
|
||||
data = SubscribeModifiedEventData(
|
||||
subscribe_id=7,
|
||||
old_subscribe_info={"state": "R", "lack_episode": 3, "name": "A"},
|
||||
subscribe_info={"state": "S", "lack_episode": 3, "name": "B"},
|
||||
scene="status",
|
||||
)
|
||||
|
||||
assert data.fields == ["name", "state"]
|
||||
assert data.to_dict() == {
|
||||
"subscribe_id": 7,
|
||||
"old_subscribe_info": {"state": "R", "lack_episode": 3, "name": "A"},
|
||||
"subscribe_info": {"state": "S", "lack_episode": 3, "name": "B"},
|
||||
"scene": "status",
|
||||
"fields": ["name", "state"],
|
||||
}
|
||||
|
||||
|
||||
def test_subscribe_modified_event_data_diffs_missing_keys_as_none():
|
||||
data = SubscribeModifiedEventData(
|
||||
subscribe_id=8,
|
||||
old_subscribe_info={"state": "R", "episode_priority": {"1": 80}},
|
||||
subscribe_info={"state": "R"},
|
||||
scene="reset",
|
||||
)
|
||||
|
||||
assert data.fields == ["episode_priority"]
|
||||
assert set(data.to_dict()) == {
|
||||
"subscribe_id",
|
||||
"old_subscribe_info",
|
||||
"subscribe_info",
|
||||
"scene",
|
||||
"fields",
|
||||
}
|
||||
|
||||
|
||||
def test_subscribe_modified_event_data_ignores_caller_supplied_fields():
|
||||
data = SubscribeModifiedEventData(
|
||||
subscribe_id=9,
|
||||
old_subscribe_info={"state": "R"},
|
||||
subscribe_info={"state": "S"},
|
||||
scene="update",
|
||||
fields=["fake"],
|
||||
)
|
||||
|
||||
assert data.fields == ["state"]
|
||||
assert data.to_dict()["fields"] == ["state"]
|
||||
Reference in New Issue
Block a user