mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-07-05 19:38:40 +08:00
fix(security): scope subscriptions to owner (#6056)
This commit is contained in:
@@ -17,7 +17,7 @@ from app.db.models.subscribe import Subscribe
|
||||
from app.db.models.subscribehistory import SubscribeHistory
|
||||
from app.db.models.user import User
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.db.user_oper import get_current_active_user_async
|
||||
from app.db.user_oper import get_current_active_user, get_current_active_user_async
|
||||
from app.helper.server import MoviePilotServerHelper
|
||||
from app.log import logger
|
||||
from app.scheduler import Scheduler
|
||||
@@ -51,14 +51,69 @@ def build_subscribe_event_payload(subscribe: Subscribe) -> dict:
|
||||
return {column.name: values.get(column.name) for column in subscribe.__table__.columns}
|
||||
|
||||
|
||||
def can_access_subscribe(
|
||||
subscribe: Subscribe | SubscribeHistory | None, current_user: User
|
||||
) -> bool:
|
||||
"""
|
||||
判断当前用户是否可访问订阅及其历史记录。
|
||||
|
||||
超级用户拥有全局订阅管理能力;普通用户只能访问 username 精确匹配自己的订阅。
|
||||
空 username 表示无法归属的 legacy 订阅,只能由超级用户管理。
|
||||
"""
|
||||
if not subscribe:
|
||||
return False
|
||||
if current_user.is_superuser:
|
||||
return True
|
||||
username = subscribe.username
|
||||
return bool(username) and username == current_user.name
|
||||
|
||||
|
||||
async def get_accessible_subscribe(
|
||||
db: AsyncSession, subscribe_id: int, current_user: User
|
||||
) -> Subscribe | None:
|
||||
"""
|
||||
按订阅 ID 读取当前用户可访问的订阅行。
|
||||
"""
|
||||
subscribe = await Subscribe.async_get(db, subscribe_id)
|
||||
if can_access_subscribe(subscribe, current_user):
|
||||
return subscribe
|
||||
return None
|
||||
|
||||
|
||||
def get_accessible_subscribe_sync(
|
||||
db: Session, subscribe_id: int, current_user: User
|
||||
) -> Subscribe | None:
|
||||
"""
|
||||
同步读取当前用户可访问的订阅行。
|
||||
"""
|
||||
subscribe = Subscribe.get(db, subscribe_id)
|
||||
if can_access_subscribe(subscribe, current_user):
|
||||
return subscribe
|
||||
return None
|
||||
|
||||
|
||||
def select_accessible_subscribe(
|
||||
subscribes: List[Subscribe], current_user: User
|
||||
) -> Subscribe | None:
|
||||
"""
|
||||
从候选订阅中选择当前用户可访问的第一条记录。
|
||||
"""
|
||||
for subscribe in subscribes or []:
|
||||
if can_access_subscribe(subscribe, current_user):
|
||||
return subscribe
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/", summary="查询所有订阅", response_model=List[schemas.Subscribe])
|
||||
async def read_subscribes(
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
current_user: User = Depends(get_current_active_user_async),
|
||||
) -> Any:
|
||||
"""
|
||||
查询所有订阅
|
||||
"""
|
||||
if not current_user.is_superuser:
|
||||
return await Subscribe.async_list_by_username(db, current_user.name)
|
||||
return await Subscribe.async_list(db)
|
||||
|
||||
|
||||
@@ -69,7 +124,7 @@ async def list_subscribes(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
|
||||
"""
|
||||
查询所有订阅 API_TOKEN认证(?token=xxx)
|
||||
"""
|
||||
return await read_subscribes()
|
||||
return await Subscribe.async_list()
|
||||
|
||||
|
||||
@router.post("/", summary="新增订阅", response_model=schemas.Response)
|
||||
@@ -106,7 +161,11 @@ async def create_subscribe(
|
||||
# completed_episode 是响应派生字段,禁止写入持久层
|
||||
subscribe_dict.pop("completed_episode", None)
|
||||
sid, message = await SubscribeChain().async_add(
|
||||
mtype=mtype, title=title, exist_ok=True, **subscribe_dict
|
||||
mtype=mtype,
|
||||
title=title,
|
||||
exist_ok=True,
|
||||
owner_scope=not current_user.is_superuser,
|
||||
**subscribe_dict,
|
||||
)
|
||||
return schemas.Response(success=bool(sid), message=message, data={"id": sid})
|
||||
|
||||
@@ -116,17 +175,18 @@ async def update_subscribe(
|
||||
*,
|
||||
subscribe_in: schemas.Subscribe,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
current_user: User = Depends(get_current_active_user_async),
|
||||
) -> Any:
|
||||
"""
|
||||
更新订阅信息
|
||||
"""
|
||||
subscribe = await Subscribe.async_get(db, subscribe_in.id)
|
||||
subscribe = await get_accessible_subscribe(db, subscribe_in.id, current_user)
|
||||
if not subscribe:
|
||||
return schemas.Response(success=False, message="订阅不存在")
|
||||
# 避免更新缺失集数
|
||||
old_subscribe_dict = subscribe.to_dict()
|
||||
subscribe_dict = subscribe_in.model_dump()
|
||||
subscribe_dict["username"] = subscribe.username
|
||||
if subscribe_in.episode_priority is None:
|
||||
subscribe_dict.pop("episode_priority", None)
|
||||
# completed_episode 是响应派生字段,禁止写入持久层
|
||||
@@ -165,12 +225,12 @@ async def update_subscribe_status(
|
||||
subid: int,
|
||||
state: str,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
current_user: User = Depends(get_current_active_user_async),
|
||||
) -> Any:
|
||||
"""
|
||||
更新订阅状态
|
||||
"""
|
||||
subscribe = await Subscribe.async_get(db, subid)
|
||||
subscribe = await get_accessible_subscribe(db, subid, current_user)
|
||||
if not subscribe:
|
||||
return schemas.Response(success=False, message="订阅不存在")
|
||||
valid_states = ["R", "P", "S"]
|
||||
@@ -199,7 +259,7 @@ async def subscribe_mediaid(
|
||||
season: Optional[int] = None,
|
||||
title: Optional[str] = None,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
current_user: User = Depends(get_current_active_user_async),
|
||||
) -> Any:
|
||||
"""
|
||||
根据 TMDBID/豆瓣ID/BangumiId 查询订阅 tmdb:/douban:
|
||||
@@ -209,23 +269,27 @@ async def subscribe_mediaid(
|
||||
tmdbid = mediaid[5:]
|
||||
if not tmdbid or not str(tmdbid).isdigit():
|
||||
return Subscribe()
|
||||
result = await Subscribe.async_exists(db, tmdbid=int(tmdbid), season=season)
|
||||
subscribes = await Subscribe.async_get_by_tmdbid(db, int(tmdbid), season)
|
||||
result = select_accessible_subscribe(subscribes, current_user)
|
||||
elif mediaid.startswith("douban:"):
|
||||
doubanid = mediaid[7:]
|
||||
if not doubanid:
|
||||
return Subscribe()
|
||||
result = await Subscribe.async_get_by_doubanid(db, doubanid)
|
||||
subscribes = await Subscribe.async_list_by_doubanid(db, doubanid)
|
||||
result = select_accessible_subscribe(subscribes, current_user)
|
||||
if not result and title:
|
||||
title_check = True
|
||||
elif mediaid.startswith("bangumi:"):
|
||||
bangumiid = mediaid[8:]
|
||||
if not bangumiid or not str(bangumiid).isdigit():
|
||||
return Subscribe()
|
||||
result = await Subscribe.async_get_by_bangumiid(db, int(bangumiid))
|
||||
subscribes = await Subscribe.async_list_by_bangumiid(db, int(bangumiid))
|
||||
result = select_accessible_subscribe(subscribes, current_user)
|
||||
if not result and title:
|
||||
title_check = True
|
||||
else:
|
||||
result = await Subscribe.async_get_by_mediaid(db, mediaid)
|
||||
subscribes = await Subscribe.async_list_by_mediaid(db, mediaid)
|
||||
result = select_accessible_subscribe(subscribes, current_user)
|
||||
if not result and title:
|
||||
title_check = True
|
||||
# 使用名称检查订阅
|
||||
@@ -233,18 +297,23 @@ async def subscribe_mediaid(
|
||||
meta = MetaInfo(title)
|
||||
if season is not None:
|
||||
meta.begin_season = season
|
||||
result = await Subscribe.async_get_by_title(
|
||||
subscribes = await Subscribe.async_list_by_title(
|
||||
db, title=meta.name, season=meta.begin_season
|
||||
)
|
||||
result = select_accessible_subscribe(subscribes, current_user)
|
||||
|
||||
return result if result else Subscribe()
|
||||
|
||||
|
||||
@router.get("/refresh", summary="刷新订阅", response_model=schemas.Response)
|
||||
def refresh_subscribes(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
def refresh_subscribes(
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
) -> Any:
|
||||
"""
|
||||
刷新所有订阅
|
||||
"""
|
||||
if not current_user.is_superuser:
|
||||
return schemas.Response(success=False, message="订阅不存在")
|
||||
Scheduler().start("subscribe_refresh")
|
||||
return schemas.Response(success=True)
|
||||
|
||||
@@ -253,12 +322,12 @@ def refresh_subscribes(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def reset_subscribes(
|
||||
subid: int,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
current_user: User = Depends(get_current_active_user_async),
|
||||
) -> Any:
|
||||
"""
|
||||
重置订阅
|
||||
"""
|
||||
subscribe = await Subscribe.async_get(db, subid)
|
||||
subscribe = await get_accessible_subscribe(db, subid, current_user)
|
||||
if subscribe:
|
||||
# 在更新之前获取旧数据
|
||||
old_subscribe_dict = subscribe.to_dict()
|
||||
@@ -292,26 +361,43 @@ async def reset_subscribes(
|
||||
|
||||
|
||||
@router.get("/check", summary="刷新订阅 TMDB 信息", response_model=schemas.Response)
|
||||
def check_subscribes(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
def check_subscribes(
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
) -> Any:
|
||||
"""
|
||||
刷新订阅 TMDB 信息
|
||||
"""
|
||||
if not current_user.is_superuser:
|
||||
return schemas.Response(success=False, message="订阅不存在")
|
||||
Scheduler().start("subscribe_tmdb")
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.get("/search", summary="搜索所有订阅", response_model=schemas.Response)
|
||||
async def search_subscribes(
|
||||
background_tasks: BackgroundTasks, _: schemas.TokenPayload = Depends(verify_token)
|
||||
background_tasks: BackgroundTasks,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
current_user: User = Depends(get_current_active_user_async),
|
||||
) -> Any:
|
||||
"""
|
||||
搜索所有订阅
|
||||
"""
|
||||
background_tasks.add_task(
|
||||
Scheduler().start,
|
||||
job_id="subscribe_search",
|
||||
**{"sid": None, "state": "R", "manual": True},
|
||||
)
|
||||
if current_user.is_superuser:
|
||||
background_tasks.add_task(
|
||||
Scheduler().start,
|
||||
job_id="subscribe_search",
|
||||
**{"sid": None, "state": "R", "manual": True},
|
||||
)
|
||||
else:
|
||||
subscribes = await Subscribe.async_list_by_username(
|
||||
db, current_user.name, state="R"
|
||||
)
|
||||
for subscribe in subscribes:
|
||||
background_tasks.add_task(
|
||||
Scheduler().start,
|
||||
job_id="subscribe_search",
|
||||
**{"sid": subscribe.id, "state": None, "manual": True},
|
||||
)
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@@ -321,11 +407,15 @@ async def search_subscribes(
|
||||
async def search_subscribe(
|
||||
subscribe_id: int,
|
||||
background_tasks: BackgroundTasks,
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
current_user: User = Depends(get_current_active_user_async),
|
||||
) -> Any:
|
||||
"""
|
||||
根据订阅编号搜索订阅
|
||||
"""
|
||||
subscribe = await get_accessible_subscribe(db, subscribe_id, current_user)
|
||||
if not subscribe:
|
||||
return schemas.Response(success=False, message="订阅不存在")
|
||||
background_tasks.add_task(
|
||||
Scheduler().start,
|
||||
job_id="subscribe_search",
|
||||
@@ -339,7 +429,7 @@ async def delete_subscribe_by_mediaid(
|
||||
mediaid: str,
|
||||
season: Optional[int] = None,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
current_user: User = Depends(get_current_active_user_async),
|
||||
) -> Any:
|
||||
"""
|
||||
根据TMDBID或豆瓣ID删除订阅 tmdb:/douban:
|
||||
@@ -355,15 +445,17 @@ async def delete_subscribe_by_mediaid(
|
||||
doubanid = mediaid[7:]
|
||||
if not doubanid:
|
||||
return schemas.Response(success=False)
|
||||
subscribe = await Subscribe.async_get_by_doubanid(db, doubanid)
|
||||
if subscribe:
|
||||
delete_subscribes.append(subscribe)
|
||||
subscribes = await Subscribe.async_list_by_doubanid(db, doubanid)
|
||||
delete_subscribes.extend(subscribes)
|
||||
else:
|
||||
subscribe = await Subscribe.async_get_by_mediaid(db, mediaid)
|
||||
if subscribe:
|
||||
delete_subscribes.append(subscribe)
|
||||
subscribes = await Subscribe.async_list_by_mediaid(db, mediaid)
|
||||
delete_subscribes.extend(subscribes)
|
||||
delete_events = []
|
||||
for subscribe in delete_subscribes:
|
||||
for subscribe in [
|
||||
subscribe
|
||||
for subscribe in delete_subscribes
|
||||
if can_access_subscribe(subscribe, current_user)
|
||||
]:
|
||||
subscribe_info = build_subscribe_event_payload(subscribe)
|
||||
subscribe_id = subscribe_info.get("id")
|
||||
if not subscribe_id:
|
||||
@@ -464,14 +556,19 @@ async def subscribe_history(
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
current_user: User = Depends(get_current_active_user_async),
|
||||
) -> Any:
|
||||
"""
|
||||
查询电影/电视剧订阅历史
|
||||
"""
|
||||
histories = await SubscribeHistory.async_list_by_type(
|
||||
db, mtype=mtype, page=page, count=count
|
||||
)
|
||||
if current_user.is_superuser:
|
||||
histories = await SubscribeHistory.async_list_by_type(
|
||||
db, mtype=mtype, page=page, count=count
|
||||
)
|
||||
else:
|
||||
histories = await SubscribeHistory.async_list_by_type_and_username(
|
||||
db, mtype=mtype, username=current_user.name, page=page, count=count
|
||||
)
|
||||
result = []
|
||||
for history in histories:
|
||||
history_item = schemas.Subscribe.model_validate(history, from_attributes=True)
|
||||
@@ -488,12 +585,14 @@ async def subscribe_history(
|
||||
async def delete_subscribe_history(
|
||||
history_id: int,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
current_user: User = Depends(get_current_active_user_async),
|
||||
) -> Any:
|
||||
"""
|
||||
删除订阅历史
|
||||
"""
|
||||
await SubscribeHistory.async_delete(db, history_id)
|
||||
history = await SubscribeHistory.async_get(db, history_id)
|
||||
if can_access_subscribe(history, current_user):
|
||||
await SubscribeHistory.async_delete(db, history_id)
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@@ -565,11 +664,13 @@ async def popular_subscribes(
|
||||
async def user_subscribes(
|
||||
username: str,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
current_user: User = Depends(get_current_active_user_async),
|
||||
) -> Any:
|
||||
"""
|
||||
查询用户订阅
|
||||
"""
|
||||
if not current_user.is_superuser and username != current_user.name:
|
||||
return []
|
||||
return await Subscribe.async_list_by_username(db, username)
|
||||
|
||||
|
||||
@@ -581,12 +682,12 @@ async def user_subscribes(
|
||||
def subscribe_files(
|
||||
subscribe_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
) -> Any:
|
||||
"""
|
||||
订阅相关文件信息
|
||||
"""
|
||||
subscribe = Subscribe.get(db, subscribe_id)
|
||||
subscribe = get_accessible_subscribe_sync(db, subscribe_id, current_user)
|
||||
if subscribe:
|
||||
return SubscribeChain().subscribe_files_info(subscribe)
|
||||
return schemas.SubscrbieInfo()
|
||||
@@ -594,11 +695,16 @@ def subscribe_files(
|
||||
|
||||
@router.post("/share", summary="分享订阅", response_model=schemas.Response)
|
||||
async def subscribe_share(
|
||||
sub: schemas.SubscribeShare, _: schemas.TokenPayload = Depends(verify_token)
|
||||
sub: schemas.SubscribeShare,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
current_user: User = Depends(get_current_active_user_async),
|
||||
) -> Any:
|
||||
"""
|
||||
分享订阅
|
||||
"""
|
||||
subscribe = await get_accessible_subscribe(db, sub.subscribe_id, current_user)
|
||||
if not subscribe:
|
||||
return schemas.Response(success=False, message="订阅不存在")
|
||||
state, errmsg = await MoviePilotServerHelper.async_sub_share(
|
||||
subscribe_id=sub.subscribe_id,
|
||||
share_title=sub.share_title,
|
||||
@@ -728,26 +834,27 @@ async def subscribe_share_statistics(
|
||||
async def read_subscribe(
|
||||
subscribe_id: int,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
current_user: User = Depends(get_current_active_user_async),
|
||||
) -> Any:
|
||||
"""
|
||||
根据订阅编号查询订阅信息
|
||||
"""
|
||||
if not subscribe_id:
|
||||
return Subscribe()
|
||||
return await Subscribe.async_get(db, subscribe_id)
|
||||
subscribe = await get_accessible_subscribe(db, subscribe_id, current_user)
|
||||
return subscribe if subscribe else Subscribe()
|
||||
|
||||
|
||||
@router.delete("/{subscribe_id}", summary="删除订阅", response_model=schemas.Response)
|
||||
async def delete_subscribe(
|
||||
subscribe_id: int,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
current_user: User = Depends(get_current_active_user_async),
|
||||
) -> Any:
|
||||
"""
|
||||
删除订阅信息
|
||||
"""
|
||||
subscribe = await Subscribe.async_get(db, subscribe_id)
|
||||
subscribe = await get_accessible_subscribe(db, subscribe_id, current_user)
|
||||
if subscribe:
|
||||
# 在删除之前获取订阅信息
|
||||
subscribe_info = build_subscribe_event_payload(subscribe)
|
||||
|
||||
@@ -130,6 +130,46 @@ class Subscribe(Base):
|
||||
return None
|
||||
return result.scalars().first()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def exists_by_username(cls, db: Session, username: str, tmdbid: Optional[int] = None,
|
||||
doubanid: Optional[str] = None, season: Optional[int] = None):
|
||||
"""
|
||||
按订阅 owner 查询同一媒体的订阅行。
|
||||
"""
|
||||
if not username:
|
||||
return None
|
||||
if tmdbid:
|
||||
query = db.query(cls).filter(cls.username == username, cls.tmdbid == tmdbid)
|
||||
if season is not None:
|
||||
query = query.filter(cls.season == season)
|
||||
return query.first()
|
||||
elif doubanid:
|
||||
return db.query(cls).filter(cls.username == username, cls.doubanid == doubanid).first()
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_exists_by_username(cls, db: AsyncSession, username: str, tmdbid: Optional[int] = None,
|
||||
doubanid: Optional[str] = None, season: Optional[int] = None):
|
||||
"""
|
||||
异步按订阅 owner 查询同一媒体的订阅行。
|
||||
"""
|
||||
if not username:
|
||||
return None
|
||||
if tmdbid:
|
||||
query = select(cls).filter(cls.username == username, cls.tmdbid == tmdbid)
|
||||
if season is not None:
|
||||
query = query.filter(cls.season == season)
|
||||
result = await db.execute(query)
|
||||
elif doubanid:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.username == username, cls.doubanid == doubanid)
|
||||
)
|
||||
else:
|
||||
return None
|
||||
return result.scalars().first()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_state(cls, db: Session, state: str):
|
||||
@@ -174,6 +214,22 @@ class Subscribe(Base):
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_list_by_title(cls, db: AsyncSession, title: str, season: Optional[int] = None):
|
||||
"""
|
||||
异步按标题查询候选订阅列表。
|
||||
"""
|
||||
if season is not None:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.name == title, cls.season == season)
|
||||
)
|
||||
else:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.name == title)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_tmdbid(cls, db: Session, tmdbid: int, season: Optional[int] = None):
|
||||
@@ -209,6 +265,17 @@ class Subscribe(Base):
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_list_by_doubanid(cls, db: AsyncSession, doubanid: str):
|
||||
"""
|
||||
异步按豆瓣 ID 查询候选订阅列表。
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.doubanid == doubanid)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_bangumiid(cls, db: Session, bangumiid: int):
|
||||
@@ -222,6 +289,17 @@ class Subscribe(Base):
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_list_by_bangumiid(cls, db: AsyncSession, bangumiid: int):
|
||||
"""
|
||||
异步按 Bangumi ID 查询候选订阅列表。
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.bangumiid == bangumiid)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_mediaid(cls, db: Session, mediaid: str):
|
||||
@@ -235,6 +313,17 @@ class Subscribe(Base):
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_list_by_mediaid(cls, db: AsyncSession, mediaid: str):
|
||||
"""
|
||||
异步按自定义媒体 ID 查询候选订阅列表。
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.mediaid == mediaid)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by(cls, db: Session, type: str, season: Optional[str] = None,
|
||||
|
||||
@@ -102,6 +102,31 @@ class SubscribeHistory(Base):
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_list_by_type_and_username(
|
||||
cls,
|
||||
db: AsyncSession,
|
||||
mtype: str,
|
||||
username: str,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30
|
||||
):
|
||||
"""
|
||||
按订阅 owner 查询指定类型的历史分页。
|
||||
"""
|
||||
if not username:
|
||||
return []
|
||||
result = await db.execute(
|
||||
select(cls).filter(
|
||||
cls.type == mtype,
|
||||
cls.username == username
|
||||
).order_by(
|
||||
cls.date.desc()
|
||||
).offset((page - 1) * count).limit(count)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def exists(cls, db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
|
||||
@@ -29,10 +29,19 @@ class SubscribeOper(DbOper):
|
||||
"""
|
||||
新增订阅
|
||||
"""
|
||||
subscribe = Subscribe.exists(self._db,
|
||||
tmdbid=mediainfo.tmdb_id,
|
||||
doubanid=mediainfo.douban_id,
|
||||
season=kwargs.get('season'))
|
||||
owner_scope = bool(kwargs.pop("owner_scope", False))
|
||||
username = kwargs.get("username") if owner_scope else None
|
||||
if username:
|
||||
subscribe = Subscribe.exists_by_username(self._db,
|
||||
username=username,
|
||||
tmdbid=mediainfo.tmdb_id,
|
||||
doubanid=mediainfo.douban_id,
|
||||
season=kwargs.get('season'))
|
||||
else:
|
||||
subscribe = Subscribe.exists(self._db,
|
||||
tmdbid=mediainfo.tmdb_id,
|
||||
doubanid=mediainfo.douban_id,
|
||||
season=kwargs.get('season'))
|
||||
kwargs.update({
|
||||
"name": mediainfo.title,
|
||||
"year": mediainfo.year,
|
||||
@@ -55,10 +64,17 @@ class SubscribeOper(DbOper):
|
||||
subscribe = Subscribe(**kwargs)
|
||||
subscribe.create(self._db)
|
||||
# 查询订阅
|
||||
subscribe = Subscribe.exists(self._db,
|
||||
tmdbid=mediainfo.tmdb_id,
|
||||
doubanid=mediainfo.douban_id,
|
||||
season=kwargs.get('season'))
|
||||
if username:
|
||||
subscribe = Subscribe.exists_by_username(self._db,
|
||||
username=username,
|
||||
tmdbid=mediainfo.tmdb_id,
|
||||
doubanid=mediainfo.douban_id,
|
||||
season=kwargs.get('season'))
|
||||
else:
|
||||
subscribe = Subscribe.exists(self._db,
|
||||
tmdbid=mediainfo.tmdb_id,
|
||||
doubanid=mediainfo.douban_id,
|
||||
season=kwargs.get('season'))
|
||||
return subscribe.id, "新增订阅成功"
|
||||
else:
|
||||
return subscribe.id, "订阅已存在"
|
||||
@@ -67,10 +83,19 @@ class SubscribeOper(DbOper):
|
||||
"""
|
||||
异步新增订阅
|
||||
"""
|
||||
subscribe = await Subscribe.async_exists(self._db,
|
||||
tmdbid=mediainfo.tmdb_id,
|
||||
doubanid=mediainfo.douban_id,
|
||||
season=kwargs.get('season'))
|
||||
owner_scope = bool(kwargs.pop("owner_scope", False))
|
||||
username = kwargs.get("username") if owner_scope else None
|
||||
if username:
|
||||
subscribe = await Subscribe.async_exists_by_username(self._db,
|
||||
username=username,
|
||||
tmdbid=mediainfo.tmdb_id,
|
||||
doubanid=mediainfo.douban_id,
|
||||
season=kwargs.get('season'))
|
||||
else:
|
||||
subscribe = await Subscribe.async_exists(self._db,
|
||||
tmdbid=mediainfo.tmdb_id,
|
||||
doubanid=mediainfo.douban_id,
|
||||
season=kwargs.get('season'))
|
||||
kwargs.update({
|
||||
"name": mediainfo.title,
|
||||
"year": mediainfo.year,
|
||||
@@ -93,10 +118,17 @@ class SubscribeOper(DbOper):
|
||||
subscribe = Subscribe(**kwargs)
|
||||
await subscribe.async_create(self._db)
|
||||
# 查询订阅
|
||||
subscribe = await Subscribe.async_exists(self._db,
|
||||
tmdbid=mediainfo.tmdb_id,
|
||||
doubanid=mediainfo.douban_id,
|
||||
season=kwargs.get('season'))
|
||||
if username:
|
||||
subscribe = await Subscribe.async_exists_by_username(self._db,
|
||||
username=username,
|
||||
tmdbid=mediainfo.tmdb_id,
|
||||
doubanid=mediainfo.douban_id,
|
||||
season=kwargs.get('season'))
|
||||
else:
|
||||
subscribe = await Subscribe.async_exists(self._db,
|
||||
tmdbid=mediainfo.tmdb_id,
|
||||
doubanid=mediainfo.douban_id,
|
||||
season=kwargs.get('season'))
|
||||
return subscribe.id, "新增订阅成功"
|
||||
else:
|
||||
return subscribe.id, "订阅已存在"
|
||||
|
||||
@@ -13,6 +13,585 @@ class SubscribeEndpointTest(TestCase):
|
||||
订阅接口回归测试。
|
||||
"""
|
||||
|
||||
def test_read_subscribes_scopes_regular_user_and_keeps_superuser_global(self):
|
||||
"""
|
||||
普通用户只能看到自己创建的订阅,超级用户保留全局视图。
|
||||
"""
|
||||
from app.api.endpoints.subscribe import list_subscribes, read_subscribes
|
||||
|
||||
own = _EndpointSubscribe(id=1, username="alice", name="自己的订阅")
|
||||
other = _EndpointSubscribe(id=2, username="bob", name="他人的订阅")
|
||||
legacy = _EndpointSubscribe(id=3, username=None, name="旧订阅")
|
||||
all_subscribes = [own, other, legacy]
|
||||
|
||||
with patch(
|
||||
"app.api.endpoints.subscribe.Subscribe.async_list",
|
||||
new=AsyncMock(return_value=all_subscribes),
|
||||
), patch(
|
||||
"app.api.endpoints.subscribe.Subscribe.async_list_by_username",
|
||||
new=AsyncMock(return_value=[own]),
|
||||
):
|
||||
api_token_result = asyncio.run(list_subscribes(_="api-token"))
|
||||
self.assertEqual([sub.id for sub in api_token_result], [1, 2, 3])
|
||||
|
||||
regular_result = asyncio.run(
|
||||
read_subscribes(
|
||||
db=object(),
|
||||
current_user=_EndpointUser(name="alice", is_superuser=False),
|
||||
)
|
||||
)
|
||||
self.assertEqual([sub.id for sub in regular_result], [1])
|
||||
|
||||
superuser_result = asyncio.run(
|
||||
read_subscribes(
|
||||
db=object(),
|
||||
current_user=_EndpointUser(name="admin", is_superuser=True),
|
||||
)
|
||||
)
|
||||
self.assertEqual([sub.id for sub in superuser_result], [1, 2, 3])
|
||||
|
||||
def test_read_subscribe_hides_other_and_legacy_from_regular_user(self):
|
||||
"""
|
||||
订阅详情按 owner 隐藏他人和 legacy 订阅,避免泄露订阅行存在性。
|
||||
"""
|
||||
from app.api.endpoints.subscribe import read_subscribe
|
||||
|
||||
current_user = _EndpointUser(name="alice", is_superuser=False)
|
||||
cases = [
|
||||
(_EndpointSubscribe(id=1, username="alice", name="自己的订阅"), 1),
|
||||
(_EndpointSubscribe(id=2, username="bob", name="他人的订阅"), None),
|
||||
(_EndpointSubscribe(id=3, username=None, name="旧订阅"), None),
|
||||
]
|
||||
|
||||
for subscribe, expected_id in cases:
|
||||
with self.subTest(subscribe_id=subscribe.id), patch(
|
||||
"app.api.endpoints.subscribe.Subscribe.async_get",
|
||||
new=AsyncMock(return_value=subscribe),
|
||||
):
|
||||
result = asyncio.run(
|
||||
read_subscribe(
|
||||
subscribe_id=subscribe.id,
|
||||
db=object(),
|
||||
current_user=current_user,
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(getattr(result, "id", None), expected_id)
|
||||
|
||||
def test_manage_permission_does_not_allow_cross_user_update(self):
|
||||
"""
|
||||
manage 权限不等于跨用户订阅管理权限,普通用户不能修改他人或 legacy 订阅。
|
||||
"""
|
||||
from app.api.endpoints.subscribe import update_subscribe
|
||||
|
||||
manage_user = _EndpointUser(
|
||||
name="alice",
|
||||
is_superuser=False,
|
||||
permissions={"manage": True},
|
||||
)
|
||||
|
||||
for subscribe in [
|
||||
_EndpointSubscribe(
|
||||
id=2,
|
||||
username="bob",
|
||||
name="他人的订阅",
|
||||
total_episode=8,
|
||||
lack_episode=2,
|
||||
),
|
||||
_EndpointSubscribe(
|
||||
id=3,
|
||||
username=None,
|
||||
name="旧订阅",
|
||||
total_episode=8,
|
||||
lack_episode=2,
|
||||
),
|
||||
]:
|
||||
with self.subTest(subscribe_id=subscribe.id), patch(
|
||||
"app.api.endpoints.subscribe.Subscribe.async_get",
|
||||
new=AsyncMock(return_value=subscribe),
|
||||
), patch(
|
||||
"app.api.endpoints.subscribe.eventmanager.async_send_event",
|
||||
new=AsyncMock(),
|
||||
) as send_event:
|
||||
response = asyncio.run(
|
||||
update_subscribe(
|
||||
subscribe_in=Subscribe(
|
||||
id=subscribe.id,
|
||||
name="改名",
|
||||
total_episode=8,
|
||||
lack_episode=2,
|
||||
),
|
||||
db=object(),
|
||||
current_user=manage_user,
|
||||
)
|
||||
)
|
||||
|
||||
self.assertFalse(response.success)
|
||||
self.assertEqual(response.message, "订阅不存在")
|
||||
send_event.assert_not_awaited()
|
||||
|
||||
def test_owner_can_update_own_subscribe(self):
|
||||
"""
|
||||
owner 可以继续管理自己创建的订阅。
|
||||
"""
|
||||
from app.api.endpoints.subscribe import update_subscribe
|
||||
|
||||
subscribe = _EndpointSubscribe(
|
||||
id=4,
|
||||
username="alice",
|
||||
name="旧标题",
|
||||
total_episode=8,
|
||||
lack_episode=2,
|
||||
vote=0.0,
|
||||
sites=[],
|
||||
search_imdbid=0,
|
||||
filter_groups=[],
|
||||
start_episode=0,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.api.endpoints.subscribe.Subscribe.async_get",
|
||||
new=AsyncMock(side_effect=[subscribe, subscribe]),
|
||||
), patch(
|
||||
"app.api.endpoints.subscribe.eventmanager.async_send_event",
|
||||
new=AsyncMock(),
|
||||
) as send_event:
|
||||
response = asyncio.run(
|
||||
update_subscribe(
|
||||
subscribe_in=Subscribe(
|
||||
id=4,
|
||||
name="新标题",
|
||||
total_episode=8,
|
||||
lack_episode=2,
|
||||
),
|
||||
db=object(),
|
||||
current_user=_EndpointUser(name="alice", is_superuser=False),
|
||||
)
|
||||
)
|
||||
|
||||
self.assertTrue(response.success)
|
||||
send_event.assert_awaited_once()
|
||||
|
||||
def test_update_subscribe_preserves_existing_owner(self):
|
||||
"""
|
||||
普通更新不得允许请求体改写订阅 owner。
|
||||
"""
|
||||
from app.api.endpoints.subscribe import update_subscribe
|
||||
|
||||
subscribe = _EndpointSubscribe(
|
||||
id=12,
|
||||
username="alice",
|
||||
name="旧标题",
|
||||
total_episode=8,
|
||||
lack_episode=2,
|
||||
vote=0.0,
|
||||
sites=[],
|
||||
search_imdbid=0,
|
||||
filter_groups=[],
|
||||
start_episode=0,
|
||||
)
|
||||
subscribe_in = Subscribe(
|
||||
id=12,
|
||||
username="bob",
|
||||
name="新标题",
|
||||
total_episode=8,
|
||||
lack_episode=2,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.api.endpoints.subscribe.Subscribe.async_get",
|
||||
new=AsyncMock(side_effect=[subscribe, subscribe]),
|
||||
), patch(
|
||||
"app.api.endpoints.subscribe.eventmanager.async_send_event",
|
||||
new=AsyncMock(),
|
||||
) as send_event:
|
||||
response = asyncio.run(
|
||||
update_subscribe(
|
||||
subscribe_in=subscribe_in,
|
||||
db=object(),
|
||||
current_user=_EndpointUser(name="alice", is_superuser=False),
|
||||
)
|
||||
)
|
||||
|
||||
self.assertTrue(response.success)
|
||||
self.assertEqual(subscribe.username, "alice")
|
||||
event_type, payload = send_event.await_args.args
|
||||
self.assertEqual(event_type, EventType.SubscribeModified)
|
||||
self.assertNotIn("username", payload["fields"])
|
||||
self.assertEqual(payload["subscribe_info"]["username"], "alice")
|
||||
|
||||
def test_superuser_can_update_other_and_legacy_subscribe(self):
|
||||
"""
|
||||
超级用户可以管理他人和 legacy 订阅。
|
||||
"""
|
||||
from app.api.endpoints.subscribe import update_subscribe_status
|
||||
|
||||
current_user = _EndpointUser(name="admin", is_superuser=True)
|
||||
for subscribe in [
|
||||
_EndpointSubscribe(id=5, username="bob", state="R", name="他人的订阅"),
|
||||
_EndpointSubscribe(id=6, username=None, state="R", name="旧订阅"),
|
||||
]:
|
||||
with self.subTest(subscribe_id=subscribe.id), patch(
|
||||
"app.api.endpoints.subscribe.Subscribe.async_get",
|
||||
new=AsyncMock(side_effect=[subscribe, subscribe]),
|
||||
), patch(
|
||||
"app.api.endpoints.subscribe.eventmanager.async_send_event",
|
||||
new=AsyncMock(),
|
||||
) as send_event:
|
||||
response = asyncio.run(
|
||||
update_subscribe_status(
|
||||
subid=subscribe.id,
|
||||
state="S",
|
||||
db=object(),
|
||||
current_user=current_user,
|
||||
)
|
||||
)
|
||||
|
||||
self.assertTrue(response.success)
|
||||
send_event.assert_awaited_once()
|
||||
self.assertEqual(subscribe.state, "S")
|
||||
|
||||
def test_share_subscribe_requires_local_owner(self):
|
||||
"""
|
||||
分享本地订阅前必须确认当前用户有权读取该订阅行。
|
||||
"""
|
||||
from app.api.endpoints.subscribe import subscribe_share
|
||||
from app.schemas.subscribe import SubscribeShare
|
||||
|
||||
other = _EndpointSubscribe(id=7, username="bob", name="他人的订阅")
|
||||
|
||||
with patch(
|
||||
"app.api.endpoints.subscribe.Subscribe.async_get",
|
||||
new=AsyncMock(return_value=other),
|
||||
), patch(
|
||||
"app.api.endpoints.subscribe.MoviePilotServerHelper.async_sub_share",
|
||||
new=AsyncMock(return_value=(True, "")),
|
||||
) as sub_share:
|
||||
response = asyncio.run(
|
||||
subscribe_share(
|
||||
sub=SubscribeShare(
|
||||
subscribe_id=7,
|
||||
share_title="分享",
|
||||
share_comment="",
|
||||
share_user="alice",
|
||||
),
|
||||
db=object(),
|
||||
current_user=_EndpointUser(name="alice", is_superuser=False),
|
||||
)
|
||||
)
|
||||
|
||||
self.assertFalse(response.success)
|
||||
self.assertEqual(response.message, "订阅不存在")
|
||||
sub_share.assert_not_awaited()
|
||||
|
||||
def test_subscribe_mediaid_returns_owner_when_other_candidate_matches_first(self):
|
||||
"""
|
||||
按媒体查询订阅时,他人订阅不能挡住当前用户自己的订阅。
|
||||
"""
|
||||
from app.api.endpoints.subscribe import subscribe_mediaid
|
||||
|
||||
other = _EndpointSubscribe(id=13, username="bob", tmdbid=123, season=1)
|
||||
own = _EndpointSubscribe(id=14, username="alice", tmdbid=123, season=1)
|
||||
|
||||
with patch(
|
||||
"app.api.endpoints.subscribe.Subscribe.async_exists",
|
||||
new=AsyncMock(return_value=other),
|
||||
), patch(
|
||||
"app.api.endpoints.subscribe.Subscribe.async_get_by_tmdbid",
|
||||
new=AsyncMock(return_value=[other, own]),
|
||||
):
|
||||
result = asyncio.run(
|
||||
subscribe_mediaid(
|
||||
mediaid="tmdb:123",
|
||||
season=1,
|
||||
db=object(),
|
||||
current_user=_EndpointUser(name="alice", is_superuser=False),
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(result.id, 14)
|
||||
|
||||
def test_delete_subscribe_by_mediaid_deletes_owner_when_other_douban_match_first(self):
|
||||
"""
|
||||
按媒体删除订阅时,应在候选集合中删除当前用户自己的订阅。
|
||||
"""
|
||||
from app.api.endpoints.subscribe import delete_subscribe_by_mediaid
|
||||
|
||||
other = _EndpointSubscribe(id=15, username="bob", doubanid="douban-1")
|
||||
own = _EndpointSubscribe(id=16, username="alice", doubanid="douban-1")
|
||||
db = _EndpointAsyncDb()
|
||||
|
||||
with patch(
|
||||
"app.api.endpoints.subscribe.Subscribe.async_get_by_doubanid",
|
||||
new=AsyncMock(return_value=other),
|
||||
), patch(
|
||||
"app.api.endpoints.subscribe.Subscribe.async_list_by_doubanid",
|
||||
new=AsyncMock(return_value=[other, own]),
|
||||
create=True,
|
||||
), patch(
|
||||
"app.api.endpoints.subscribe.build_subscribe_event_payload",
|
||||
return_value={"id": 16, "doubanid": "douban-1"},
|
||||
), patch(
|
||||
"app.api.endpoints.subscribe.eventmanager.async_send_event",
|
||||
new=AsyncMock(),
|
||||
) as send_event:
|
||||
response = asyncio.run(
|
||||
delete_subscribe_by_mediaid(
|
||||
mediaid="douban:douban-1",
|
||||
db=db,
|
||||
current_user=_EndpointUser(name="alice", is_superuser=False),
|
||||
)
|
||||
)
|
||||
|
||||
self.assertTrue(response.success)
|
||||
self.assertEqual(db.deleted, [own])
|
||||
send_event.assert_awaited_once()
|
||||
|
||||
def test_search_subscribes_regular_user_schedules_only_owned_rows(self):
|
||||
"""
|
||||
普通用户批量搜索只按自己的订阅 ID 入队。
|
||||
"""
|
||||
from app.api.endpoints.subscribe import search_subscribes
|
||||
|
||||
background_tasks = _EndpointBackgroundTasks()
|
||||
owned = [
|
||||
_EndpointSubscribe(id=17, username="alice", state="R"),
|
||||
_EndpointSubscribe(id=18, username="alice", state="R"),
|
||||
]
|
||||
|
||||
with patch(
|
||||
"app.api.endpoints.subscribe.Subscribe.async_list_by_username",
|
||||
new=AsyncMock(return_value=owned),
|
||||
):
|
||||
response = asyncio.run(
|
||||
search_subscribes(
|
||||
background_tasks=background_tasks,
|
||||
db=object(),
|
||||
current_user=_EndpointUser(name="alice", is_superuser=False),
|
||||
)
|
||||
)
|
||||
|
||||
self.assertTrue(response.success)
|
||||
self.assertEqual(
|
||||
[task["kwargs"]["sid"] for task in background_tasks.tasks],
|
||||
[17, 18],
|
||||
)
|
||||
|
||||
def test_subscribe_files_hides_other_user_row(self):
|
||||
"""
|
||||
订阅文件接口不能向普通用户暴露他人的订阅文件信息。
|
||||
"""
|
||||
from app.api.endpoints.subscribe import subscribe_files
|
||||
|
||||
other = _EndpointSubscribe(id=19, username="bob", name="他人的订阅")
|
||||
|
||||
with patch(
|
||||
"app.api.endpoints.subscribe.Subscribe.get",
|
||||
return_value=other,
|
||||
), patch(
|
||||
"app.api.endpoints.subscribe.SubscribeChain"
|
||||
) as subscribe_chain:
|
||||
result = subscribe_files(
|
||||
subscribe_id=19,
|
||||
db=object(),
|
||||
current_user=_EndpointUser(name="alice", is_superuser=False),
|
||||
)
|
||||
|
||||
self.assertEqual(result.episodes, {})
|
||||
subscribe_chain.return_value.subscribe_files_info.assert_not_called()
|
||||
|
||||
def test_user_subscribes_hides_other_user_list(self):
|
||||
"""
|
||||
普通用户不能通过 username 参数读取其他用户订阅列表。
|
||||
"""
|
||||
from app.api.endpoints.subscribe import user_subscribes
|
||||
|
||||
with patch(
|
||||
"app.api.endpoints.subscribe.Subscribe.async_list_by_username",
|
||||
new=AsyncMock(return_value=[_EndpointSubscribe(id=20, username="bob")]),
|
||||
) as list_by_username:
|
||||
result = asyncio.run(
|
||||
user_subscribes(
|
||||
username="bob",
|
||||
db=object(),
|
||||
current_user=_EndpointUser(name="alice", is_superuser=False),
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(result, [])
|
||||
list_by_username.assert_not_awaited()
|
||||
|
||||
def test_subscribe_oper_async_add_scopes_duplicate_lookup_by_owner(self):
|
||||
"""
|
||||
owner-aware 创建不应把他人已有订阅当作当前用户订阅。
|
||||
"""
|
||||
from app.db.subscribe_oper import SubscribeOper
|
||||
|
||||
other = _EndpointSubscribe(id=21, username="bob")
|
||||
own = _EndpointSubscribe(id=22, username="alice")
|
||||
created = SimpleNamespace(async_create=AsyncMock())
|
||||
|
||||
with patch("app.db.subscribe_oper.Subscribe") as subscribe_model:
|
||||
subscribe_model.async_exists = AsyncMock(return_value=other)
|
||||
subscribe_model.async_exists_by_username = AsyncMock(
|
||||
side_effect=[None, own]
|
||||
)
|
||||
subscribe_model.return_value = created
|
||||
|
||||
sid, message = asyncio.run(
|
||||
SubscribeOper(db=object()).async_add(
|
||||
mediainfo=_EndpointMediaInfo(),
|
||||
username="alice",
|
||||
owner_scope=True,
|
||||
season=1,
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(sid, 22)
|
||||
self.assertEqual(message, "新增订阅成功")
|
||||
subscribe_model.async_exists.assert_not_awaited()
|
||||
self.assertEqual(subscribe_model.async_exists_by_username.await_count, 2)
|
||||
created.async_create.assert_awaited_once()
|
||||
|
||||
def test_subscribe_history_scopes_regular_user_and_keeps_superuser_global(self):
|
||||
"""
|
||||
订阅历史分页必须在 DB 层按 owner 收窄,避免全局页过滤后误判没有更多数据。
|
||||
"""
|
||||
from app.api.endpoints.subscribe import subscribe_history
|
||||
|
||||
own = _EndpointSubscribe(
|
||||
id=8,
|
||||
username="alice",
|
||||
name="自己的历史",
|
||||
type=MediaType.MOVIE.value,
|
||||
)
|
||||
other = _EndpointSubscribe(
|
||||
id=9,
|
||||
username="bob",
|
||||
name="他人的历史",
|
||||
type=MediaType.MOVIE.value,
|
||||
)
|
||||
legacy = _EndpointSubscribe(
|
||||
id=10,
|
||||
username="",
|
||||
name="旧历史",
|
||||
type=MediaType.MOVIE.value,
|
||||
)
|
||||
db = object()
|
||||
owner_query = AsyncMock(return_value=[own])
|
||||
global_query = AsyncMock(return_value=[other, legacy])
|
||||
|
||||
with patch(
|
||||
"app.api.endpoints.subscribe.SubscribeHistory.async_list_by_type",
|
||||
new=global_query,
|
||||
), patch(
|
||||
"app.api.endpoints.subscribe.SubscribeHistory.async_list_by_type_and_username",
|
||||
new=owner_query,
|
||||
create=True,
|
||||
):
|
||||
regular_result = asyncio.run(
|
||||
subscribe_history(
|
||||
mtype=MediaType.MOVIE.value,
|
||||
page=1,
|
||||
count=2,
|
||||
db=db,
|
||||
current_user=_EndpointUser(name="alice", is_superuser=False),
|
||||
)
|
||||
)
|
||||
self.assertEqual([history.id for history in regular_result], [8])
|
||||
owner_query.assert_awaited_once_with(
|
||||
db,
|
||||
mtype=MediaType.MOVIE.value,
|
||||
username="alice",
|
||||
page=1,
|
||||
count=2,
|
||||
)
|
||||
global_query.assert_not_awaited()
|
||||
|
||||
owner_query.reset_mock()
|
||||
global_query.reset_mock(return_value=True)
|
||||
global_query.return_value = [own, other, legacy]
|
||||
|
||||
superuser_result = asyncio.run(
|
||||
subscribe_history(
|
||||
mtype=MediaType.MOVIE.value,
|
||||
page=1,
|
||||
count=3,
|
||||
db=db,
|
||||
current_user=_EndpointUser(name="admin", is_superuser=True),
|
||||
)
|
||||
)
|
||||
self.assertEqual([history.id for history in superuser_result], [8, 9, 10])
|
||||
global_query.assert_awaited_once_with(
|
||||
db,
|
||||
mtype=MediaType.MOVIE.value,
|
||||
page=1,
|
||||
count=3,
|
||||
)
|
||||
owner_query.assert_not_awaited()
|
||||
|
||||
def test_delete_subscribe_history_hides_other_from_regular_user(self):
|
||||
"""
|
||||
普通用户删除他人订阅历史时按不存在处理。
|
||||
"""
|
||||
from app.api.endpoints.subscribe import delete_subscribe_history
|
||||
|
||||
other = _EndpointSubscribe(
|
||||
id=11,
|
||||
username="bob",
|
||||
name="他人的历史",
|
||||
type=MediaType.MOVIE.value,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.api.endpoints.subscribe.SubscribeHistory.async_get",
|
||||
new=AsyncMock(return_value=other),
|
||||
), patch(
|
||||
"app.api.endpoints.subscribe.SubscribeHistory.async_delete",
|
||||
new=AsyncMock(),
|
||||
) as async_delete:
|
||||
response = asyncio.run(
|
||||
delete_subscribe_history(
|
||||
history_id=11,
|
||||
db=object(),
|
||||
current_user=_EndpointUser(name="alice", is_superuser=False),
|
||||
)
|
||||
)
|
||||
|
||||
self.assertTrue(response.success)
|
||||
async_delete.assert_not_awaited()
|
||||
|
||||
def test_global_refresh_and_check_require_superuser(self):
|
||||
"""
|
||||
没有 owner 参数的全局订阅任务只允许超级用户触发。
|
||||
"""
|
||||
from app.api.endpoints.subscribe import check_subscribes, refresh_subscribes
|
||||
|
||||
regular_user = _EndpointUser(name="alice", is_superuser=False)
|
||||
superuser = _EndpointUser(name="admin", is_superuser=True)
|
||||
|
||||
for endpoint in [refresh_subscribes, check_subscribes]:
|
||||
with self.subTest(endpoint=endpoint.__name__), patch(
|
||||
"app.api.endpoints.subscribe.Scheduler"
|
||||
) as scheduler:
|
||||
response = endpoint(current_user=regular_user)
|
||||
|
||||
self.assertFalse(response.success)
|
||||
self.assertEqual(response.message, "订阅不存在")
|
||||
scheduler.return_value.start.assert_not_called()
|
||||
|
||||
for endpoint, job_id in [
|
||||
(refresh_subscribes, "subscribe_refresh"),
|
||||
(check_subscribes, "subscribe_tmdb"),
|
||||
]:
|
||||
with self.subTest(endpoint=endpoint.__name__), patch(
|
||||
"app.api.endpoints.subscribe.Scheduler"
|
||||
) as scheduler:
|
||||
response = endpoint(current_user=superuser)
|
||||
|
||||
self.assertTrue(response.success)
|
||||
scheduler.return_value.start.assert_called_once_with(job_id)
|
||||
|
||||
def test_create_subscribe_excludes_completed_episode_from_write_payload(self):
|
||||
"""
|
||||
新增订阅时不应把 completed_episode 派生字段传入持久化链路。
|
||||
@@ -35,13 +614,14 @@ class SubscribeEndpointTest(TestCase):
|
||||
response = asyncio.run(
|
||||
create_subscribe(
|
||||
subscribe_in=subscribe_in,
|
||||
current_user=SimpleNamespace(name="moviepilot-user"),
|
||||
current_user=_EndpointUser(name="moviepilot-user", is_superuser=False),
|
||||
)
|
||||
)
|
||||
|
||||
self.assertTrue(response.success)
|
||||
self.assertNotIn("completed_episode", async_add.await_args.kwargs)
|
||||
self.assertEqual(async_add.await_args.kwargs["username"], "moviepilot-user")
|
||||
self.assertTrue(async_add.await_args.kwargs["owner_scope"])
|
||||
|
||||
def test_create_subscribe_preserves_special_season_zero_with_doubanid(self):
|
||||
"""
|
||||
@@ -67,12 +647,37 @@ class SubscribeEndpointTest(TestCase):
|
||||
response = asyncio.run(
|
||||
create_subscribe(
|
||||
subscribe_in=subscribe_in,
|
||||
current_user=SimpleNamespace(name="moviepilot-user"),
|
||||
current_user=_EndpointUser(name="moviepilot-user", is_superuser=False),
|
||||
)
|
||||
)
|
||||
|
||||
self.assertTrue(response.success)
|
||||
self.assertEqual(async_add.await_args.kwargs["season"], 0)
|
||||
self.assertTrue(async_add.await_args.kwargs["owner_scope"])
|
||||
|
||||
def test_create_subscribe_keeps_superuser_global_deduplication(self):
|
||||
"""
|
||||
超级用户新增订阅保持全局去重语义。
|
||||
"""
|
||||
subscribe_in = Subscribe(
|
||||
name="测试电影",
|
||||
year="2026",
|
||||
type=MediaType.MOVIE.value,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.api.endpoints.subscribe.SubscribeChain.async_add",
|
||||
new=AsyncMock(return_value=(1, "订阅已存在")),
|
||||
) as async_add:
|
||||
response = asyncio.run(
|
||||
create_subscribe(
|
||||
subscribe_in=subscribe_in,
|
||||
current_user=_EndpointUser(name="admin", is_superuser=True),
|
||||
)
|
||||
)
|
||||
|
||||
self.assertTrue(response.success)
|
||||
self.assertFalse(async_add.await_args.kwargs["owner_scope"])
|
||||
|
||||
def test_update_status_sends_modified_event_payload_with_scene_and_fields(self):
|
||||
"""
|
||||
@@ -89,7 +694,14 @@ class SubscribeEndpointTest(TestCase):
|
||||
"app.api.endpoints.subscribe.eventmanager.async_send_event",
|
||||
new=AsyncMock(),
|
||||
) as send_event:
|
||||
response = asyncio.run(update_subscribe_status(subid=5, state="S", db=object()))
|
||||
response = asyncio.run(
|
||||
update_subscribe_status(
|
||||
subid=5,
|
||||
state="S",
|
||||
db=object(),
|
||||
current_user=_EndpointUser(name="admin", is_superuser=True),
|
||||
)
|
||||
)
|
||||
|
||||
self.assertTrue(response.success)
|
||||
send_event.assert_awaited_once()
|
||||
@@ -125,7 +737,13 @@ class SubscribeEndpointTest(TestCase):
|
||||
"app.api.endpoints.subscribe.eventmanager.async_send_event",
|
||||
new=AsyncMock(),
|
||||
) as send_event:
|
||||
response = asyncio.run(reset_subscribes(subid=6, db=object()))
|
||||
response = asyncio.run(
|
||||
reset_subscribes(
|
||||
subid=6,
|
||||
db=object(),
|
||||
current_user=_EndpointUser(name="admin", is_superuser=True),
|
||||
)
|
||||
)
|
||||
|
||||
self.assertTrue(response.success)
|
||||
send_event.assert_awaited_once()
|
||||
@@ -166,7 +784,13 @@ class SubscribeEndpointTest(TestCase):
|
||||
"app.api.endpoints.subscribe.eventmanager.async_send_event",
|
||||
new=AsyncMock(),
|
||||
) as send_event:
|
||||
response = asyncio.run(update_subscribe(subscribe_in=subscribe_in, db=object()))
|
||||
response = asyncio.run(
|
||||
update_subscribe(
|
||||
subscribe_in=subscribe_in,
|
||||
db=object(),
|
||||
current_user=_EndpointUser(name="admin", is_superuser=True),
|
||||
)
|
||||
)
|
||||
|
||||
self.assertTrue(response.success)
|
||||
send_event.assert_awaited_once()
|
||||
@@ -179,6 +803,77 @@ class SubscribeEndpointTest(TestCase):
|
||||
self.assertEqual(payload["subscribe_info"]["name"], "新标题")
|
||||
|
||||
|
||||
class _EndpointUser(SimpleNamespace):
|
||||
"""
|
||||
最小用户替身,模拟订阅 endpoint 依赖的用户权限字段。
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, is_superuser: bool, permissions: dict | None = None):
|
||||
super().__init__(
|
||||
name=name,
|
||||
is_superuser=is_superuser,
|
||||
permissions=permissions or {},
|
||||
)
|
||||
|
||||
|
||||
class _EndpointAsyncDb:
|
||||
"""
|
||||
最小异步数据库替身,用于观察 endpoint 删除的订阅对象。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.deleted = []
|
||||
self.committed = False
|
||||
self.rolled_back = False
|
||||
|
||||
async def delete(self, obj):
|
||||
self.deleted.append(obj)
|
||||
|
||||
async def commit(self):
|
||||
self.committed = True
|
||||
|
||||
async def rollback(self):
|
||||
self.rolled_back = True
|
||||
|
||||
|
||||
class _EndpointBackgroundTasks:
|
||||
"""
|
||||
最小后台任务替身,记录 endpoint 入队的任务参数。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.tasks = []
|
||||
|
||||
def add_task(self, func, **kwargs):
|
||||
self.tasks.append({"func": func, "kwargs": kwargs})
|
||||
|
||||
|
||||
class _EndpointMediaInfo:
|
||||
"""
|
||||
最小媒体信息替身,模拟 SubscribeOper 写订阅行所需字段。
|
||||
"""
|
||||
|
||||
title = "测试剧集"
|
||||
year = "2026"
|
||||
type = MediaType.TV
|
||||
tmdb_id = 123
|
||||
imdb_id = "tt123"
|
||||
tvdb_id = 456
|
||||
douban_id = "douban-1"
|
||||
bangumi_id = 789
|
||||
episode_group = None
|
||||
vote_average = 8.0
|
||||
overview = "测试简介"
|
||||
|
||||
@staticmethod
|
||||
def get_poster_image():
|
||||
return "poster.jpg"
|
||||
|
||||
@staticmethod
|
||||
def get_backdrop_image():
|
||||
return "backdrop.jpg"
|
||||
|
||||
|
||||
class _EndpointSubscribe:
|
||||
"""
|
||||
最小订阅替身,模拟 endpoint 依赖的 ORM 对象接口。
|
||||
@@ -186,6 +881,7 @@ class _EndpointSubscribe:
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.id = kwargs.pop("id", None)
|
||||
self.username = kwargs.pop("username", None)
|
||||
self.name = kwargs.pop("name", None)
|
||||
self.total_episode = kwargs.pop("total_episode", None)
|
||||
self.lack_episode = kwargs.pop("lack_episode", None)
|
||||
|
||||
Reference in New Issue
Block a user