fix(security): scope subscriptions to owner (#6056)

This commit is contained in:
InfinityPacer
2026-07-05 09:45:27 +08:00
committed by GitHub
parent d977e4c48a
commit 7f82a9ea4d
5 changed files with 1017 additions and 68 deletions

View File

@@ -17,7 +17,7 @@ from app.db.models.subscribe import Subscribe
from app.db.models.subscribehistory import SubscribeHistory
from app.db.models.user import User
from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_user_async
from app.db.user_oper import get_current_active_user, get_current_active_user_async
from app.helper.server import MoviePilotServerHelper
from app.log import logger
from app.scheduler import Scheduler
@@ -51,14 +51,69 @@ def build_subscribe_event_payload(subscribe: Subscribe) -> dict:
return {column.name: values.get(column.name) for column in subscribe.__table__.columns}
def can_access_subscribe(
subscribe: Subscribe | SubscribeHistory | None, current_user: User
) -> bool:
"""
判断当前用户是否可访问订阅及其历史记录。
超级用户拥有全局订阅管理能力;普通用户只能访问 username 精确匹配自己的订阅。
空 username 表示无法归属的 legacy 订阅,只能由超级用户管理。
"""
if not subscribe:
return False
if current_user.is_superuser:
return True
username = subscribe.username
return bool(username) and username == current_user.name
async def get_accessible_subscribe(
db: AsyncSession, subscribe_id: int, current_user: User
) -> Subscribe | None:
"""
按订阅 ID 读取当前用户可访问的订阅行。
"""
subscribe = await Subscribe.async_get(db, subscribe_id)
if can_access_subscribe(subscribe, current_user):
return subscribe
return None
def get_accessible_subscribe_sync(
db: Session, subscribe_id: int, current_user: User
) -> Subscribe | None:
"""
同步读取当前用户可访问的订阅行。
"""
subscribe = Subscribe.get(db, subscribe_id)
if can_access_subscribe(subscribe, current_user):
return subscribe
return None
def select_accessible_subscribe(
subscribes: List[Subscribe], current_user: User
) -> Subscribe | None:
"""
从候选订阅中选择当前用户可访问的第一条记录。
"""
for subscribe in subscribes or []:
if can_access_subscribe(subscribe, current_user):
return subscribe
return None
@router.get("/", summary="查询所有订阅", response_model=List[schemas.Subscribe])
async def read_subscribes(
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token),
current_user: User = Depends(get_current_active_user_async),
) -> Any:
"""
查询所有订阅
"""
if not current_user.is_superuser:
return await Subscribe.async_list_by_username(db, current_user.name)
return await Subscribe.async_list(db)
@@ -69,7 +124,7 @@ async def list_subscribes(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
"""
查询所有订阅 API_TOKEN认证?token=xxx
"""
return await read_subscribes()
return await Subscribe.async_list()
@router.post("/", summary="新增订阅", response_model=schemas.Response)
@@ -106,7 +161,11 @@ async def create_subscribe(
# completed_episode 是响应派生字段,禁止写入持久层
subscribe_dict.pop("completed_episode", None)
sid, message = await SubscribeChain().async_add(
mtype=mtype, title=title, exist_ok=True, **subscribe_dict
mtype=mtype,
title=title,
exist_ok=True,
owner_scope=not current_user.is_superuser,
**subscribe_dict,
)
return schemas.Response(success=bool(sid), message=message, data={"id": sid})
@@ -116,17 +175,18 @@ async def update_subscribe(
*,
subscribe_in: schemas.Subscribe,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token),
current_user: User = Depends(get_current_active_user_async),
) -> Any:
"""
更新订阅信息
"""
subscribe = await Subscribe.async_get(db, subscribe_in.id)
subscribe = await get_accessible_subscribe(db, subscribe_in.id, current_user)
if not subscribe:
return schemas.Response(success=False, message="订阅不存在")
# 避免更新缺失集数
old_subscribe_dict = subscribe.to_dict()
subscribe_dict = subscribe_in.model_dump()
subscribe_dict["username"] = subscribe.username
if subscribe_in.episode_priority is None:
subscribe_dict.pop("episode_priority", None)
# completed_episode 是响应派生字段,禁止写入持久层
@@ -165,12 +225,12 @@ async def update_subscribe_status(
subid: int,
state: str,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token),
current_user: User = Depends(get_current_active_user_async),
) -> Any:
"""
更新订阅状态
"""
subscribe = await Subscribe.async_get(db, subid)
subscribe = await get_accessible_subscribe(db, subid, current_user)
if not subscribe:
return schemas.Response(success=False, message="订阅不存在")
valid_states = ["R", "P", "S"]
@@ -199,7 +259,7 @@ async def subscribe_mediaid(
season: Optional[int] = None,
title: Optional[str] = None,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token),
current_user: User = Depends(get_current_active_user_async),
) -> Any:
"""
根据 TMDBID/豆瓣ID/BangumiId 查询订阅 tmdb:/douban:
@@ -209,23 +269,27 @@ async def subscribe_mediaid(
tmdbid = mediaid[5:]
if not tmdbid or not str(tmdbid).isdigit():
return Subscribe()
result = await Subscribe.async_exists(db, tmdbid=int(tmdbid), season=season)
subscribes = await Subscribe.async_get_by_tmdbid(db, int(tmdbid), season)
result = select_accessible_subscribe(subscribes, current_user)
elif mediaid.startswith("douban:"):
doubanid = mediaid[7:]
if not doubanid:
return Subscribe()
result = await Subscribe.async_get_by_doubanid(db, doubanid)
subscribes = await Subscribe.async_list_by_doubanid(db, doubanid)
result = select_accessible_subscribe(subscribes, current_user)
if not result and title:
title_check = True
elif mediaid.startswith("bangumi:"):
bangumiid = mediaid[8:]
if not bangumiid or not str(bangumiid).isdigit():
return Subscribe()
result = await Subscribe.async_get_by_bangumiid(db, int(bangumiid))
subscribes = await Subscribe.async_list_by_bangumiid(db, int(bangumiid))
result = select_accessible_subscribe(subscribes, current_user)
if not result and title:
title_check = True
else:
result = await Subscribe.async_get_by_mediaid(db, mediaid)
subscribes = await Subscribe.async_list_by_mediaid(db, mediaid)
result = select_accessible_subscribe(subscribes, current_user)
if not result and title:
title_check = True
# 使用名称检查订阅
@@ -233,18 +297,23 @@ async def subscribe_mediaid(
meta = MetaInfo(title)
if season is not None:
meta.begin_season = season
result = await Subscribe.async_get_by_title(
subscribes = await Subscribe.async_list_by_title(
db, title=meta.name, season=meta.begin_season
)
result = select_accessible_subscribe(subscribes, current_user)
return result if result else Subscribe()
@router.get("/refresh", summary="刷新订阅", response_model=schemas.Response)
def refresh_subscribes(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
def refresh_subscribes(
current_user: User = Depends(get_current_active_user),
) -> Any:
"""
刷新所有订阅
"""
if not current_user.is_superuser:
return schemas.Response(success=False, message="订阅不存在")
Scheduler().start("subscribe_refresh")
return schemas.Response(success=True)
@@ -253,12 +322,12 @@ def refresh_subscribes(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def reset_subscribes(
subid: int,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token),
current_user: User = Depends(get_current_active_user_async),
) -> Any:
"""
重置订阅
"""
subscribe = await Subscribe.async_get(db, subid)
subscribe = await get_accessible_subscribe(db, subid, current_user)
if subscribe:
# 在更新之前获取旧数据
old_subscribe_dict = subscribe.to_dict()
@@ -292,26 +361,43 @@ async def reset_subscribes(
@router.get("/check", summary="刷新订阅 TMDB 信息", response_model=schemas.Response)
def check_subscribes(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
def check_subscribes(
current_user: User = Depends(get_current_active_user),
) -> Any:
"""
刷新订阅 TMDB 信息
"""
if not current_user.is_superuser:
return schemas.Response(success=False, message="订阅不存在")
Scheduler().start("subscribe_tmdb")
return schemas.Response(success=True)
@router.get("/search", summary="搜索所有订阅", response_model=schemas.Response)
async def search_subscribes(
background_tasks: BackgroundTasks, _: schemas.TokenPayload = Depends(verify_token)
background_tasks: BackgroundTasks,
db: AsyncSession = Depends(get_async_db),
current_user: User = Depends(get_current_active_user_async),
) -> Any:
"""
搜索所有订阅
"""
background_tasks.add_task(
Scheduler().start,
job_id="subscribe_search",
**{"sid": None, "state": "R", "manual": True},
)
if current_user.is_superuser:
background_tasks.add_task(
Scheduler().start,
job_id="subscribe_search",
**{"sid": None, "state": "R", "manual": True},
)
else:
subscribes = await Subscribe.async_list_by_username(
db, current_user.name, state="R"
)
for subscribe in subscribes:
background_tasks.add_task(
Scheduler().start,
job_id="subscribe_search",
**{"sid": subscribe.id, "state": None, "manual": True},
)
return schemas.Response(success=True)
@@ -321,11 +407,15 @@ async def search_subscribes(
async def search_subscribe(
subscribe_id: int,
background_tasks: BackgroundTasks,
_: schemas.TokenPayload = Depends(verify_token),
db: AsyncSession = Depends(get_async_db),
current_user: User = Depends(get_current_active_user_async),
) -> Any:
"""
根据订阅编号搜索订阅
"""
subscribe = await get_accessible_subscribe(db, subscribe_id, current_user)
if not subscribe:
return schemas.Response(success=False, message="订阅不存在")
background_tasks.add_task(
Scheduler().start,
job_id="subscribe_search",
@@ -339,7 +429,7 @@ async def delete_subscribe_by_mediaid(
mediaid: str,
season: Optional[int] = None,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token),
current_user: User = Depends(get_current_active_user_async),
) -> Any:
"""
根据TMDBID或豆瓣ID删除订阅 tmdb:/douban:
@@ -355,15 +445,17 @@ async def delete_subscribe_by_mediaid(
doubanid = mediaid[7:]
if not doubanid:
return schemas.Response(success=False)
subscribe = await Subscribe.async_get_by_doubanid(db, doubanid)
if subscribe:
delete_subscribes.append(subscribe)
subscribes = await Subscribe.async_list_by_doubanid(db, doubanid)
delete_subscribes.extend(subscribes)
else:
subscribe = await Subscribe.async_get_by_mediaid(db, mediaid)
if subscribe:
delete_subscribes.append(subscribe)
subscribes = await Subscribe.async_list_by_mediaid(db, mediaid)
delete_subscribes.extend(subscribes)
delete_events = []
for subscribe in delete_subscribes:
for subscribe in [
subscribe
for subscribe in delete_subscribes
if can_access_subscribe(subscribe, current_user)
]:
subscribe_info = build_subscribe_event_payload(subscribe)
subscribe_id = subscribe_info.get("id")
if not subscribe_id:
@@ -464,14 +556,19 @@ async def subscribe_history(
page: Optional[int] = 1,
count: Optional[int] = 30,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token),
current_user: User = Depends(get_current_active_user_async),
) -> Any:
"""
查询电影/电视剧订阅历史
"""
histories = await SubscribeHistory.async_list_by_type(
db, mtype=mtype, page=page, count=count
)
if current_user.is_superuser:
histories = await SubscribeHistory.async_list_by_type(
db, mtype=mtype, page=page, count=count
)
else:
histories = await SubscribeHistory.async_list_by_type_and_username(
db, mtype=mtype, username=current_user.name, page=page, count=count
)
result = []
for history in histories:
history_item = schemas.Subscribe.model_validate(history, from_attributes=True)
@@ -488,12 +585,14 @@ async def subscribe_history(
async def delete_subscribe_history(
history_id: int,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token),
current_user: User = Depends(get_current_active_user_async),
) -> Any:
"""
删除订阅历史
"""
await SubscribeHistory.async_delete(db, history_id)
history = await SubscribeHistory.async_get(db, history_id)
if can_access_subscribe(history, current_user):
await SubscribeHistory.async_delete(db, history_id)
return schemas.Response(success=True)
@@ -565,11 +664,13 @@ async def popular_subscribes(
async def user_subscribes(
username: str,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token),
current_user: User = Depends(get_current_active_user_async),
) -> Any:
"""
查询用户订阅
"""
if not current_user.is_superuser and username != current_user.name:
return []
return await Subscribe.async_list_by_username(db, username)
@@ -581,12 +682,12 @@ async def user_subscribes(
def subscribe_files(
subscribe_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token),
current_user: User = Depends(get_current_active_user),
) -> Any:
"""
订阅相关文件信息
"""
subscribe = Subscribe.get(db, subscribe_id)
subscribe = get_accessible_subscribe_sync(db, subscribe_id, current_user)
if subscribe:
return SubscribeChain().subscribe_files_info(subscribe)
return schemas.SubscrbieInfo()
@@ -594,11 +695,16 @@ def subscribe_files(
@router.post("/share", summary="分享订阅", response_model=schemas.Response)
async def subscribe_share(
sub: schemas.SubscribeShare, _: schemas.TokenPayload = Depends(verify_token)
sub: schemas.SubscribeShare,
db: AsyncSession = Depends(get_async_db),
current_user: User = Depends(get_current_active_user_async),
) -> Any:
"""
分享订阅
"""
subscribe = await get_accessible_subscribe(db, sub.subscribe_id, current_user)
if not subscribe:
return schemas.Response(success=False, message="订阅不存在")
state, errmsg = await MoviePilotServerHelper.async_sub_share(
subscribe_id=sub.subscribe_id,
share_title=sub.share_title,
@@ -728,26 +834,27 @@ async def subscribe_share_statistics(
async def read_subscribe(
subscribe_id: int,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token),
current_user: User = Depends(get_current_active_user_async),
) -> Any:
"""
根据订阅编号查询订阅信息
"""
if not subscribe_id:
return Subscribe()
return await Subscribe.async_get(db, subscribe_id)
subscribe = await get_accessible_subscribe(db, subscribe_id, current_user)
return subscribe if subscribe else Subscribe()
@router.delete("/{subscribe_id}", summary="删除订阅", response_model=schemas.Response)
async def delete_subscribe(
subscribe_id: int,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token),
current_user: User = Depends(get_current_active_user_async),
) -> Any:
"""
删除订阅信息
"""
subscribe = await Subscribe.async_get(db, subscribe_id)
subscribe = await get_accessible_subscribe(db, subscribe_id, current_user)
if subscribe:
# 在删除之前获取订阅信息
subscribe_info = build_subscribe_event_payload(subscribe)

View File

@@ -130,6 +130,46 @@ class Subscribe(Base):
return None
return result.scalars().first()
@classmethod
@db_query
def exists_by_username(cls, db: Session, username: str, tmdbid: Optional[int] = None,
doubanid: Optional[str] = None, season: Optional[int] = None):
"""
按订阅 owner 查询同一媒体的订阅行。
"""
if not username:
return None
if tmdbid:
query = db.query(cls).filter(cls.username == username, cls.tmdbid == tmdbid)
if season is not None:
query = query.filter(cls.season == season)
return query.first()
elif doubanid:
return db.query(cls).filter(cls.username == username, cls.doubanid == doubanid).first()
return None
@classmethod
@async_db_query
async def async_exists_by_username(cls, db: AsyncSession, username: str, tmdbid: Optional[int] = None,
doubanid: Optional[str] = None, season: Optional[int] = None):
"""
异步按订阅 owner 查询同一媒体的订阅行。
"""
if not username:
return None
if tmdbid:
query = select(cls).filter(cls.username == username, cls.tmdbid == tmdbid)
if season is not None:
query = query.filter(cls.season == season)
result = await db.execute(query)
elif doubanid:
result = await db.execute(
select(cls).filter(cls.username == username, cls.doubanid == doubanid)
)
else:
return None
return result.scalars().first()
@classmethod
@db_query
def get_by_state(cls, db: Session, state: str):
@@ -174,6 +214,22 @@ class Subscribe(Base):
)
return result.scalars().first()
@classmethod
@async_db_query
async def async_list_by_title(cls, db: AsyncSession, title: str, season: Optional[int] = None):
"""
异步按标题查询候选订阅列表。
"""
if season is not None:
result = await db.execute(
select(cls).filter(cls.name == title, cls.season == season)
)
else:
result = await db.execute(
select(cls).filter(cls.name == title)
)
return result.scalars().all()
@classmethod
@db_query
def get_by_tmdbid(cls, db: Session, tmdbid: int, season: Optional[int] = None):
@@ -209,6 +265,17 @@ class Subscribe(Base):
)
return result.scalars().first()
@classmethod
@async_db_query
async def async_list_by_doubanid(cls, db: AsyncSession, doubanid: str):
"""
异步按豆瓣 ID 查询候选订阅列表。
"""
result = await db.execute(
select(cls).filter(cls.doubanid == doubanid)
)
return result.scalars().all()
@classmethod
@db_query
def get_by_bangumiid(cls, db: Session, bangumiid: int):
@@ -222,6 +289,17 @@ class Subscribe(Base):
)
return result.scalars().first()
@classmethod
@async_db_query
async def async_list_by_bangumiid(cls, db: AsyncSession, bangumiid: int):
"""
异步按 Bangumi ID 查询候选订阅列表。
"""
result = await db.execute(
select(cls).filter(cls.bangumiid == bangumiid)
)
return result.scalars().all()
@classmethod
@db_query
def get_by_mediaid(cls, db: Session, mediaid: str):
@@ -235,6 +313,17 @@ class Subscribe(Base):
)
return result.scalars().first()
@classmethod
@async_db_query
async def async_list_by_mediaid(cls, db: AsyncSession, mediaid: str):
"""
异步按自定义媒体 ID 查询候选订阅列表。
"""
result = await db.execute(
select(cls).filter(cls.mediaid == mediaid)
)
return result.scalars().all()
@classmethod
@db_query
def get_by(cls, db: Session, type: str, season: Optional[str] = None,

View File

@@ -102,6 +102,31 @@ class SubscribeHistory(Base):
)
return result.scalars().all()
@classmethod
@async_db_query
async def async_list_by_type_and_username(
cls,
db: AsyncSession,
mtype: str,
username: str,
page: Optional[int] = 1,
count: Optional[int] = 30
):
"""
按订阅 owner 查询指定类型的历史分页。
"""
if not username:
return []
result = await db.execute(
select(cls).filter(
cls.type == mtype,
cls.username == username
).order_by(
cls.date.desc()
).offset((page - 1) * count).limit(count)
)
return result.scalars().all()
@classmethod
@db_query
def exists(cls, db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,

View File

@@ -29,10 +29,19 @@ class SubscribeOper(DbOper):
"""
新增订阅
"""
subscribe = Subscribe.exists(self._db,
tmdbid=mediainfo.tmdb_id,
doubanid=mediainfo.douban_id,
season=kwargs.get('season'))
owner_scope = bool(kwargs.pop("owner_scope", False))
username = kwargs.get("username") if owner_scope else None
if username:
subscribe = Subscribe.exists_by_username(self._db,
username=username,
tmdbid=mediainfo.tmdb_id,
doubanid=mediainfo.douban_id,
season=kwargs.get('season'))
else:
subscribe = Subscribe.exists(self._db,
tmdbid=mediainfo.tmdb_id,
doubanid=mediainfo.douban_id,
season=kwargs.get('season'))
kwargs.update({
"name": mediainfo.title,
"year": mediainfo.year,
@@ -55,10 +64,17 @@ class SubscribeOper(DbOper):
subscribe = Subscribe(**kwargs)
subscribe.create(self._db)
# 查询订阅
subscribe = Subscribe.exists(self._db,
tmdbid=mediainfo.tmdb_id,
doubanid=mediainfo.douban_id,
season=kwargs.get('season'))
if username:
subscribe = Subscribe.exists_by_username(self._db,
username=username,
tmdbid=mediainfo.tmdb_id,
doubanid=mediainfo.douban_id,
season=kwargs.get('season'))
else:
subscribe = Subscribe.exists(self._db,
tmdbid=mediainfo.tmdb_id,
doubanid=mediainfo.douban_id,
season=kwargs.get('season'))
return subscribe.id, "新增订阅成功"
else:
return subscribe.id, "订阅已存在"
@@ -67,10 +83,19 @@ class SubscribeOper(DbOper):
"""
异步新增订阅
"""
subscribe = await Subscribe.async_exists(self._db,
tmdbid=mediainfo.tmdb_id,
doubanid=mediainfo.douban_id,
season=kwargs.get('season'))
owner_scope = bool(kwargs.pop("owner_scope", False))
username = kwargs.get("username") if owner_scope else None
if username:
subscribe = await Subscribe.async_exists_by_username(self._db,
username=username,
tmdbid=mediainfo.tmdb_id,
doubanid=mediainfo.douban_id,
season=kwargs.get('season'))
else:
subscribe = await Subscribe.async_exists(self._db,
tmdbid=mediainfo.tmdb_id,
doubanid=mediainfo.douban_id,
season=kwargs.get('season'))
kwargs.update({
"name": mediainfo.title,
"year": mediainfo.year,
@@ -93,10 +118,17 @@ class SubscribeOper(DbOper):
subscribe = Subscribe(**kwargs)
await subscribe.async_create(self._db)
# 查询订阅
subscribe = await Subscribe.async_exists(self._db,
tmdbid=mediainfo.tmdb_id,
doubanid=mediainfo.douban_id,
season=kwargs.get('season'))
if username:
subscribe = await Subscribe.async_exists_by_username(self._db,
username=username,
tmdbid=mediainfo.tmdb_id,
doubanid=mediainfo.douban_id,
season=kwargs.get('season'))
else:
subscribe = await Subscribe.async_exists(self._db,
tmdbid=mediainfo.tmdb_id,
doubanid=mediainfo.douban_id,
season=kwargs.get('season'))
return subscribe.id, "新增订阅成功"
else:
return subscribe.id, "订阅已存在"

View File

@@ -13,6 +13,585 @@ class SubscribeEndpointTest(TestCase):
订阅接口回归测试。
"""
def test_read_subscribes_scopes_regular_user_and_keeps_superuser_global(self):
"""
普通用户只能看到自己创建的订阅,超级用户保留全局视图。
"""
from app.api.endpoints.subscribe import list_subscribes, read_subscribes
own = _EndpointSubscribe(id=1, username="alice", name="自己的订阅")
other = _EndpointSubscribe(id=2, username="bob", name="他人的订阅")
legacy = _EndpointSubscribe(id=3, username=None, name="旧订阅")
all_subscribes = [own, other, legacy]
with patch(
"app.api.endpoints.subscribe.Subscribe.async_list",
new=AsyncMock(return_value=all_subscribes),
), patch(
"app.api.endpoints.subscribe.Subscribe.async_list_by_username",
new=AsyncMock(return_value=[own]),
):
api_token_result = asyncio.run(list_subscribes(_="api-token"))
self.assertEqual([sub.id for sub in api_token_result], [1, 2, 3])
regular_result = asyncio.run(
read_subscribes(
db=object(),
current_user=_EndpointUser(name="alice", is_superuser=False),
)
)
self.assertEqual([sub.id for sub in regular_result], [1])
superuser_result = asyncio.run(
read_subscribes(
db=object(),
current_user=_EndpointUser(name="admin", is_superuser=True),
)
)
self.assertEqual([sub.id for sub in superuser_result], [1, 2, 3])
def test_read_subscribe_hides_other_and_legacy_from_regular_user(self):
"""
订阅详情按 owner 隐藏他人和 legacy 订阅,避免泄露订阅行存在性。
"""
from app.api.endpoints.subscribe import read_subscribe
current_user = _EndpointUser(name="alice", is_superuser=False)
cases = [
(_EndpointSubscribe(id=1, username="alice", name="自己的订阅"), 1),
(_EndpointSubscribe(id=2, username="bob", name="他人的订阅"), None),
(_EndpointSubscribe(id=3, username=None, name="旧订阅"), None),
]
for subscribe, expected_id in cases:
with self.subTest(subscribe_id=subscribe.id), patch(
"app.api.endpoints.subscribe.Subscribe.async_get",
new=AsyncMock(return_value=subscribe),
):
result = asyncio.run(
read_subscribe(
subscribe_id=subscribe.id,
db=object(),
current_user=current_user,
)
)
self.assertEqual(getattr(result, "id", None), expected_id)
def test_manage_permission_does_not_allow_cross_user_update(self):
"""
manage 权限不等于跨用户订阅管理权限,普通用户不能修改他人或 legacy 订阅。
"""
from app.api.endpoints.subscribe import update_subscribe
manage_user = _EndpointUser(
name="alice",
is_superuser=False,
permissions={"manage": True},
)
for subscribe in [
_EndpointSubscribe(
id=2,
username="bob",
name="他人的订阅",
total_episode=8,
lack_episode=2,
),
_EndpointSubscribe(
id=3,
username=None,
name="旧订阅",
total_episode=8,
lack_episode=2,
),
]:
with self.subTest(subscribe_id=subscribe.id), patch(
"app.api.endpoints.subscribe.Subscribe.async_get",
new=AsyncMock(return_value=subscribe),
), patch(
"app.api.endpoints.subscribe.eventmanager.async_send_event",
new=AsyncMock(),
) as send_event:
response = asyncio.run(
update_subscribe(
subscribe_in=Subscribe(
id=subscribe.id,
name="改名",
total_episode=8,
lack_episode=2,
),
db=object(),
current_user=manage_user,
)
)
self.assertFalse(response.success)
self.assertEqual(response.message, "订阅不存在")
send_event.assert_not_awaited()
def test_owner_can_update_own_subscribe(self):
"""
owner 可以继续管理自己创建的订阅。
"""
from app.api.endpoints.subscribe import update_subscribe
subscribe = _EndpointSubscribe(
id=4,
username="alice",
name="旧标题",
total_episode=8,
lack_episode=2,
vote=0.0,
sites=[],
search_imdbid=0,
filter_groups=[],
start_episode=0,
)
with patch(
"app.api.endpoints.subscribe.Subscribe.async_get",
new=AsyncMock(side_effect=[subscribe, subscribe]),
), patch(
"app.api.endpoints.subscribe.eventmanager.async_send_event",
new=AsyncMock(),
) as send_event:
response = asyncio.run(
update_subscribe(
subscribe_in=Subscribe(
id=4,
name="新标题",
total_episode=8,
lack_episode=2,
),
db=object(),
current_user=_EndpointUser(name="alice", is_superuser=False),
)
)
self.assertTrue(response.success)
send_event.assert_awaited_once()
def test_update_subscribe_preserves_existing_owner(self):
"""
普通更新不得允许请求体改写订阅 owner。
"""
from app.api.endpoints.subscribe import update_subscribe
subscribe = _EndpointSubscribe(
id=12,
username="alice",
name="旧标题",
total_episode=8,
lack_episode=2,
vote=0.0,
sites=[],
search_imdbid=0,
filter_groups=[],
start_episode=0,
)
subscribe_in = Subscribe(
id=12,
username="bob",
name="新标题",
total_episode=8,
lack_episode=2,
)
with patch(
"app.api.endpoints.subscribe.Subscribe.async_get",
new=AsyncMock(side_effect=[subscribe, subscribe]),
), patch(
"app.api.endpoints.subscribe.eventmanager.async_send_event",
new=AsyncMock(),
) as send_event:
response = asyncio.run(
update_subscribe(
subscribe_in=subscribe_in,
db=object(),
current_user=_EndpointUser(name="alice", is_superuser=False),
)
)
self.assertTrue(response.success)
self.assertEqual(subscribe.username, "alice")
event_type, payload = send_event.await_args.args
self.assertEqual(event_type, EventType.SubscribeModified)
self.assertNotIn("username", payload["fields"])
self.assertEqual(payload["subscribe_info"]["username"], "alice")
def test_superuser_can_update_other_and_legacy_subscribe(self):
"""
超级用户可以管理他人和 legacy 订阅。
"""
from app.api.endpoints.subscribe import update_subscribe_status
current_user = _EndpointUser(name="admin", is_superuser=True)
for subscribe in [
_EndpointSubscribe(id=5, username="bob", state="R", name="他人的订阅"),
_EndpointSubscribe(id=6, username=None, state="R", name="旧订阅"),
]:
with self.subTest(subscribe_id=subscribe.id), patch(
"app.api.endpoints.subscribe.Subscribe.async_get",
new=AsyncMock(side_effect=[subscribe, subscribe]),
), patch(
"app.api.endpoints.subscribe.eventmanager.async_send_event",
new=AsyncMock(),
) as send_event:
response = asyncio.run(
update_subscribe_status(
subid=subscribe.id,
state="S",
db=object(),
current_user=current_user,
)
)
self.assertTrue(response.success)
send_event.assert_awaited_once()
self.assertEqual(subscribe.state, "S")
def test_share_subscribe_requires_local_owner(self):
"""
分享本地订阅前必须确认当前用户有权读取该订阅行。
"""
from app.api.endpoints.subscribe import subscribe_share
from app.schemas.subscribe import SubscribeShare
other = _EndpointSubscribe(id=7, username="bob", name="他人的订阅")
with patch(
"app.api.endpoints.subscribe.Subscribe.async_get",
new=AsyncMock(return_value=other),
), patch(
"app.api.endpoints.subscribe.MoviePilotServerHelper.async_sub_share",
new=AsyncMock(return_value=(True, "")),
) as sub_share:
response = asyncio.run(
subscribe_share(
sub=SubscribeShare(
subscribe_id=7,
share_title="分享",
share_comment="",
share_user="alice",
),
db=object(),
current_user=_EndpointUser(name="alice", is_superuser=False),
)
)
self.assertFalse(response.success)
self.assertEqual(response.message, "订阅不存在")
sub_share.assert_not_awaited()
def test_subscribe_mediaid_returns_owner_when_other_candidate_matches_first(self):
"""
按媒体查询订阅时,他人订阅不能挡住当前用户自己的订阅。
"""
from app.api.endpoints.subscribe import subscribe_mediaid
other = _EndpointSubscribe(id=13, username="bob", tmdbid=123, season=1)
own = _EndpointSubscribe(id=14, username="alice", tmdbid=123, season=1)
with patch(
"app.api.endpoints.subscribe.Subscribe.async_exists",
new=AsyncMock(return_value=other),
), patch(
"app.api.endpoints.subscribe.Subscribe.async_get_by_tmdbid",
new=AsyncMock(return_value=[other, own]),
):
result = asyncio.run(
subscribe_mediaid(
mediaid="tmdb:123",
season=1,
db=object(),
current_user=_EndpointUser(name="alice", is_superuser=False),
)
)
self.assertEqual(result.id, 14)
def test_delete_subscribe_by_mediaid_deletes_owner_when_other_douban_match_first(self):
"""
按媒体删除订阅时,应在候选集合中删除当前用户自己的订阅。
"""
from app.api.endpoints.subscribe import delete_subscribe_by_mediaid
other = _EndpointSubscribe(id=15, username="bob", doubanid="douban-1")
own = _EndpointSubscribe(id=16, username="alice", doubanid="douban-1")
db = _EndpointAsyncDb()
with patch(
"app.api.endpoints.subscribe.Subscribe.async_get_by_doubanid",
new=AsyncMock(return_value=other),
), patch(
"app.api.endpoints.subscribe.Subscribe.async_list_by_doubanid",
new=AsyncMock(return_value=[other, own]),
create=True,
), patch(
"app.api.endpoints.subscribe.build_subscribe_event_payload",
return_value={"id": 16, "doubanid": "douban-1"},
), patch(
"app.api.endpoints.subscribe.eventmanager.async_send_event",
new=AsyncMock(),
) as send_event:
response = asyncio.run(
delete_subscribe_by_mediaid(
mediaid="douban:douban-1",
db=db,
current_user=_EndpointUser(name="alice", is_superuser=False),
)
)
self.assertTrue(response.success)
self.assertEqual(db.deleted, [own])
send_event.assert_awaited_once()
def test_search_subscribes_regular_user_schedules_only_owned_rows(self):
"""
普通用户批量搜索只按自己的订阅 ID 入队。
"""
from app.api.endpoints.subscribe import search_subscribes
background_tasks = _EndpointBackgroundTasks()
owned = [
_EndpointSubscribe(id=17, username="alice", state="R"),
_EndpointSubscribe(id=18, username="alice", state="R"),
]
with patch(
"app.api.endpoints.subscribe.Subscribe.async_list_by_username",
new=AsyncMock(return_value=owned),
):
response = asyncio.run(
search_subscribes(
background_tasks=background_tasks,
db=object(),
current_user=_EndpointUser(name="alice", is_superuser=False),
)
)
self.assertTrue(response.success)
self.assertEqual(
[task["kwargs"]["sid"] for task in background_tasks.tasks],
[17, 18],
)
def test_subscribe_files_hides_other_user_row(self):
"""
订阅文件接口不能向普通用户暴露他人的订阅文件信息。
"""
from app.api.endpoints.subscribe import subscribe_files
other = _EndpointSubscribe(id=19, username="bob", name="他人的订阅")
with patch(
"app.api.endpoints.subscribe.Subscribe.get",
return_value=other,
), patch(
"app.api.endpoints.subscribe.SubscribeChain"
) as subscribe_chain:
result = subscribe_files(
subscribe_id=19,
db=object(),
current_user=_EndpointUser(name="alice", is_superuser=False),
)
self.assertEqual(result.episodes, {})
subscribe_chain.return_value.subscribe_files_info.assert_not_called()
def test_user_subscribes_hides_other_user_list(self):
"""
普通用户不能通过 username 参数读取其他用户订阅列表。
"""
from app.api.endpoints.subscribe import user_subscribes
with patch(
"app.api.endpoints.subscribe.Subscribe.async_list_by_username",
new=AsyncMock(return_value=[_EndpointSubscribe(id=20, username="bob")]),
) as list_by_username:
result = asyncio.run(
user_subscribes(
username="bob",
db=object(),
current_user=_EndpointUser(name="alice", is_superuser=False),
)
)
self.assertEqual(result, [])
list_by_username.assert_not_awaited()
def test_subscribe_oper_async_add_scopes_duplicate_lookup_by_owner(self):
"""
owner-aware 创建不应把他人已有订阅当作当前用户订阅。
"""
from app.db.subscribe_oper import SubscribeOper
other = _EndpointSubscribe(id=21, username="bob")
own = _EndpointSubscribe(id=22, username="alice")
created = SimpleNamespace(async_create=AsyncMock())
with patch("app.db.subscribe_oper.Subscribe") as subscribe_model:
subscribe_model.async_exists = AsyncMock(return_value=other)
subscribe_model.async_exists_by_username = AsyncMock(
side_effect=[None, own]
)
subscribe_model.return_value = created
sid, message = asyncio.run(
SubscribeOper(db=object()).async_add(
mediainfo=_EndpointMediaInfo(),
username="alice",
owner_scope=True,
season=1,
)
)
self.assertEqual(sid, 22)
self.assertEqual(message, "新增订阅成功")
subscribe_model.async_exists.assert_not_awaited()
self.assertEqual(subscribe_model.async_exists_by_username.await_count, 2)
created.async_create.assert_awaited_once()
def test_subscribe_history_scopes_regular_user_and_keeps_superuser_global(self):
"""
订阅历史分页必须在 DB 层按 owner 收窄,避免全局页过滤后误判没有更多数据。
"""
from app.api.endpoints.subscribe import subscribe_history
own = _EndpointSubscribe(
id=8,
username="alice",
name="自己的历史",
type=MediaType.MOVIE.value,
)
other = _EndpointSubscribe(
id=9,
username="bob",
name="他人的历史",
type=MediaType.MOVIE.value,
)
legacy = _EndpointSubscribe(
id=10,
username="",
name="旧历史",
type=MediaType.MOVIE.value,
)
db = object()
owner_query = AsyncMock(return_value=[own])
global_query = AsyncMock(return_value=[other, legacy])
with patch(
"app.api.endpoints.subscribe.SubscribeHistory.async_list_by_type",
new=global_query,
), patch(
"app.api.endpoints.subscribe.SubscribeHistory.async_list_by_type_and_username",
new=owner_query,
create=True,
):
regular_result = asyncio.run(
subscribe_history(
mtype=MediaType.MOVIE.value,
page=1,
count=2,
db=db,
current_user=_EndpointUser(name="alice", is_superuser=False),
)
)
self.assertEqual([history.id for history in regular_result], [8])
owner_query.assert_awaited_once_with(
db,
mtype=MediaType.MOVIE.value,
username="alice",
page=1,
count=2,
)
global_query.assert_not_awaited()
owner_query.reset_mock()
global_query.reset_mock(return_value=True)
global_query.return_value = [own, other, legacy]
superuser_result = asyncio.run(
subscribe_history(
mtype=MediaType.MOVIE.value,
page=1,
count=3,
db=db,
current_user=_EndpointUser(name="admin", is_superuser=True),
)
)
self.assertEqual([history.id for history in superuser_result], [8, 9, 10])
global_query.assert_awaited_once_with(
db,
mtype=MediaType.MOVIE.value,
page=1,
count=3,
)
owner_query.assert_not_awaited()
def test_delete_subscribe_history_hides_other_from_regular_user(self):
"""
普通用户删除他人订阅历史时按不存在处理。
"""
from app.api.endpoints.subscribe import delete_subscribe_history
other = _EndpointSubscribe(
id=11,
username="bob",
name="他人的历史",
type=MediaType.MOVIE.value,
)
with patch(
"app.api.endpoints.subscribe.SubscribeHistory.async_get",
new=AsyncMock(return_value=other),
), patch(
"app.api.endpoints.subscribe.SubscribeHistory.async_delete",
new=AsyncMock(),
) as async_delete:
response = asyncio.run(
delete_subscribe_history(
history_id=11,
db=object(),
current_user=_EndpointUser(name="alice", is_superuser=False),
)
)
self.assertTrue(response.success)
async_delete.assert_not_awaited()
def test_global_refresh_and_check_require_superuser(self):
"""
没有 owner 参数的全局订阅任务只允许超级用户触发。
"""
from app.api.endpoints.subscribe import check_subscribes, refresh_subscribes
regular_user = _EndpointUser(name="alice", is_superuser=False)
superuser = _EndpointUser(name="admin", is_superuser=True)
for endpoint in [refresh_subscribes, check_subscribes]:
with self.subTest(endpoint=endpoint.__name__), patch(
"app.api.endpoints.subscribe.Scheduler"
) as scheduler:
response = endpoint(current_user=regular_user)
self.assertFalse(response.success)
self.assertEqual(response.message, "订阅不存在")
scheduler.return_value.start.assert_not_called()
for endpoint, job_id in [
(refresh_subscribes, "subscribe_refresh"),
(check_subscribes, "subscribe_tmdb"),
]:
with self.subTest(endpoint=endpoint.__name__), patch(
"app.api.endpoints.subscribe.Scheduler"
) as scheduler:
response = endpoint(current_user=superuser)
self.assertTrue(response.success)
scheduler.return_value.start.assert_called_once_with(job_id)
def test_create_subscribe_excludes_completed_episode_from_write_payload(self):
"""
新增订阅时不应把 completed_episode 派生字段传入持久化链路。
@@ -35,13 +614,14 @@ class SubscribeEndpointTest(TestCase):
response = asyncio.run(
create_subscribe(
subscribe_in=subscribe_in,
current_user=SimpleNamespace(name="moviepilot-user"),
current_user=_EndpointUser(name="moviepilot-user", is_superuser=False),
)
)
self.assertTrue(response.success)
self.assertNotIn("completed_episode", async_add.await_args.kwargs)
self.assertEqual(async_add.await_args.kwargs["username"], "moviepilot-user")
self.assertTrue(async_add.await_args.kwargs["owner_scope"])
def test_create_subscribe_preserves_special_season_zero_with_doubanid(self):
"""
@@ -67,12 +647,37 @@ class SubscribeEndpointTest(TestCase):
response = asyncio.run(
create_subscribe(
subscribe_in=subscribe_in,
current_user=SimpleNamespace(name="moviepilot-user"),
current_user=_EndpointUser(name="moviepilot-user", is_superuser=False),
)
)
self.assertTrue(response.success)
self.assertEqual(async_add.await_args.kwargs["season"], 0)
self.assertTrue(async_add.await_args.kwargs["owner_scope"])
def test_create_subscribe_keeps_superuser_global_deduplication(self):
"""
超级用户新增订阅保持全局去重语义。
"""
subscribe_in = Subscribe(
name="测试电影",
year="2026",
type=MediaType.MOVIE.value,
)
with patch(
"app.api.endpoints.subscribe.SubscribeChain.async_add",
new=AsyncMock(return_value=(1, "订阅已存在")),
) as async_add:
response = asyncio.run(
create_subscribe(
subscribe_in=subscribe_in,
current_user=_EndpointUser(name="admin", is_superuser=True),
)
)
self.assertTrue(response.success)
self.assertFalse(async_add.await_args.kwargs["owner_scope"])
def test_update_status_sends_modified_event_payload_with_scene_and_fields(self):
"""
@@ -89,7 +694,14 @@ class SubscribeEndpointTest(TestCase):
"app.api.endpoints.subscribe.eventmanager.async_send_event",
new=AsyncMock(),
) as send_event:
response = asyncio.run(update_subscribe_status(subid=5, state="S", db=object()))
response = asyncio.run(
update_subscribe_status(
subid=5,
state="S",
db=object(),
current_user=_EndpointUser(name="admin", is_superuser=True),
)
)
self.assertTrue(response.success)
send_event.assert_awaited_once()
@@ -125,7 +737,13 @@ class SubscribeEndpointTest(TestCase):
"app.api.endpoints.subscribe.eventmanager.async_send_event",
new=AsyncMock(),
) as send_event:
response = asyncio.run(reset_subscribes(subid=6, db=object()))
response = asyncio.run(
reset_subscribes(
subid=6,
db=object(),
current_user=_EndpointUser(name="admin", is_superuser=True),
)
)
self.assertTrue(response.success)
send_event.assert_awaited_once()
@@ -166,7 +784,13 @@ class SubscribeEndpointTest(TestCase):
"app.api.endpoints.subscribe.eventmanager.async_send_event",
new=AsyncMock(),
) as send_event:
response = asyncio.run(update_subscribe(subscribe_in=subscribe_in, db=object()))
response = asyncio.run(
update_subscribe(
subscribe_in=subscribe_in,
db=object(),
current_user=_EndpointUser(name="admin", is_superuser=True),
)
)
self.assertTrue(response.success)
send_event.assert_awaited_once()
@@ -179,6 +803,77 @@ class SubscribeEndpointTest(TestCase):
self.assertEqual(payload["subscribe_info"]["name"], "新标题")
class _EndpointUser(SimpleNamespace):
"""
最小用户替身,模拟订阅 endpoint 依赖的用户权限字段。
"""
def __init__(self, name: str, is_superuser: bool, permissions: dict | None = None):
super().__init__(
name=name,
is_superuser=is_superuser,
permissions=permissions or {},
)
class _EndpointAsyncDb:
"""
最小异步数据库替身,用于观察 endpoint 删除的订阅对象。
"""
def __init__(self):
self.deleted = []
self.committed = False
self.rolled_back = False
async def delete(self, obj):
self.deleted.append(obj)
async def commit(self):
self.committed = True
async def rollback(self):
self.rolled_back = True
class _EndpointBackgroundTasks:
"""
最小后台任务替身,记录 endpoint 入队的任务参数。
"""
def __init__(self):
self.tasks = []
def add_task(self, func, **kwargs):
self.tasks.append({"func": func, "kwargs": kwargs})
class _EndpointMediaInfo:
"""
最小媒体信息替身,模拟 SubscribeOper 写订阅行所需字段。
"""
title = "测试剧集"
year = "2026"
type = MediaType.TV
tmdb_id = 123
imdb_id = "tt123"
tvdb_id = 456
douban_id = "douban-1"
bangumi_id = 789
episode_group = None
vote_average = 8.0
overview = "测试简介"
@staticmethod
def get_poster_image():
return "poster.jpg"
@staticmethod
def get_backdrop_image():
return "backdrop.jpg"
class _EndpointSubscribe:
"""
最小订阅替身,模拟 endpoint 依赖的 ORM 对象接口。
@@ -186,6 +881,7 @@ class _EndpointSubscribe:
def __init__(self, **kwargs):
self.id = kwargs.pop("id", None)
self.username = kwargs.pop("username", None)
self.name = kwargs.pop("name", None)
self.total_episode = kwargs.pop("total_episode", None)
self.lack_episode = kwargs.pop("lack_episode", None)