Merge pull request #5419 from wikrin/subscribe-source-query-enhancement

This commit is contained in:
jxxghp
2026-01-25 14:04:42 +08:00
committed by GitHub
5 changed files with 122 additions and 14 deletions

View File

@@ -292,10 +292,6 @@ class DownloadChain(ChainBase):
# 登记下载记录
downloadhis = DownloadHistoryOper()
# 获取应用的识别词(如果有)
custom_words_str = None
if hasattr(_meta, 'apply_words') and _meta.apply_words:
custom_words_str = '\n'.join(_meta.apply_words)
downloadhis.add(
path=download_path.as_posix(),
type=_media.type.value,
@@ -319,7 +315,6 @@ class DownloadChain(ChainBase):
date=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
media_category=_media.category,
episode_group=_media.episode_group,
custom_words=custom_words_str,
note={"source": source}
)

View File

@@ -1119,6 +1119,19 @@ class SubscribeChain(ChainBase):
})
logger.info(f'{subscribe.name} 订阅元数据更新完成')
def get_subscribe_by_source(self, source: str) -> Optional[Subscribe]:
"""
从来源获取订阅
"""
source_keyword = self.parse_subscribe_source_keyword(source)
if not source_keyword:
return None
# 只保留需要的字段动态获取订阅
valid_fields = {k: v for k, v in source_keyword.items()
if k in ["type", "season", "tmdbid", "doubanid", "bangumiid"]}
# 暂时不考虑订阅历史, 若有必要再添加
return SubscribeOper().get_by(**valid_fields)
@staticmethod
def follow():
"""
@@ -1828,8 +1841,9 @@ class SubscribeChain(ChainBase):
def get_subscribe_source_keyword(subscribe: Subscribe) -> str:
"""
构造用于订阅来源的关键字字符串
:param subscribe: Subscribe 对象
:return: 格式化的订阅来源关键字字符串,格式为 "Subscribe|{...}"
:return str: 格式化的订阅来源关键字字符串,格式为 "Subscribe|{...}"
"""
source_keyword = {
'id': subscribe.id,
@@ -1844,3 +1858,24 @@ class SubscribeChain(ChainBase):
'bangumiid': subscribe.bangumiid
}
return f"Subscribe|{json.dumps(source_keyword, ensure_ascii=False)}"
@staticmethod
def parse_subscribe_source_keyword(source_keyword_str: str) -> Optional[dict]:
"""
解析订阅来源关键字字符串
:param source_keyword_str: 订阅来源关键字字符串,格式为 "Subscribe|{...}"
:return Dict: 如果解析失败则返回None
"""
if not source_keyword_str or not source_keyword_str.startswith("Subscribe|"):
return None
try:
# 分割字符串获取JSON部分
json_part = source_keyword_str.split("|", 1)[1]
# 解析JSON字符串
source_keyword = json.loads(json_part)
return source_keyword
except (IndexError, json.JSONDecodeError, TypeError) as e:
logger.error(f"解析订阅来源关键字失败: {e}")
return None

View File

@@ -10,6 +10,7 @@ from app import schemas
from app.chain import ChainBase
from app.chain.media import MediaChain
from app.chain.storage import StorageChain
from app.chain.subscribe import SubscribeChain
from app.chain.tmdb import TmdbChain
from app.core.config import settings, global_vars
from app.core.context import MediaInfo
@@ -1222,7 +1223,10 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
# 提前获取下载历史,以便获取自定义识别词
download_history = None
downloadhis = DownloadHistoryOper()
if bluray_dir:
if download_hash:
# 先按hash查询
download_history = downloadhis.get_by_hash(download_hash)
elif bluray_dir:
# 蓝光原盘,按目录名查询
download_history = downloadhis.get_by_path(file_path.as_posix())
else:
@@ -1231,14 +1235,14 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
if download_file:
download_history = downloadhis.get_by_hash(download_file.download_hash)
# 获取自定义识别词
custom_words_list = None
if download_history and download_history.custom_words:
custom_words_list = download_history.custom_words.split('\n')
if not meta:
# 文件元数据(传入自定义识别词)
file_meta = MetaInfoPath(file_path, custom_words=custom_words_list)
subscribe_custom_words = None
if download_history and isinstance(download_history.note, dict):
# 使用source动态获取订阅
subscribe = SubscribeChain().get_subscribe_by_source(download_history.note.get("source"))
subscribe_custom_words = subscribe.custom_words.split("\n") if subscribe and subscribe.custom_words else None
# 文件元数据(优先使用订阅识别词)
file_meta = MetaInfoPath(file_path, custom_words=subscribe_custom_words)
else:
file_meta = meta

View File

@@ -227,6 +227,66 @@ class Subscribe(Base):
)
return result.scalars().first()
@classmethod
@db_query
def get_by(cls, db: Session, type: str, season: Optional[str] = None,
tmdbid: Optional[int] = None, doubanid: Optional[str] = None, bangumiid: Optional[str] = None):
"""
根据条件查询订阅
"""
# TMDBID
if tmdbid:
if season is not None:
result = db.query(cls).filter(
cls.tmdbid == tmdbid, cls.type == type, cls.season == season
)
else:
result = db.query(cls).filter(cls.tmdbid == tmdbid, cls.type == type)
# 豆瓣ID
elif doubanid:
result = db.query(cls).filter(cls.doubanid == doubanid, cls.type == type)
# BangumiID
elif bangumiid:
result = db.query(cls).filter(cls.bangumiid == bangumiid, cls.type == type)
else:
return None
return result.first()
@classmethod
@async_db_query
async def async_get_by(cls, db: AsyncSession, type: str, season: Optional[str] = None,
tmdbid: Optional[int] = None, doubanid: Optional[str] = None, bangumiid: Optional[str] = None):
"""
根据条件查询订阅
"""
# TMDBID
if tmdbid:
if season is not None:
result = await db.execute(
select(cls).filter(
cls.tmdbid == tmdbid, cls.type == type, cls.season == season
)
)
else:
result = await db.execute(
select(cls).filter(cls.tmdbid == tmdbid, cls.type == type)
)
# 豆瓣ID
elif doubanid:
result = await db.execute(
select(cls).filter(cls.doubanid == doubanid, cls.type == type)
)
# BangumiID
elif bangumiid:
result = await db.execute(
select(cls).filter(cls.bangumiid == bangumiid, cls.type == type)
)
else:
return None
return result.scalars().first()
@db_update
def delete_by_tmdbid(self, db: Session, tmdbid: int, season: int):
subscrbies = self.get_by_tmdbid(db, tmdbid, season)

View File

@@ -111,6 +111,20 @@ class SubscribeOper(DbOper):
"""
return await Subscribe.async_get(self._db, rid=sid)
def get_by(self, type: str, season: Optional[str] = None, tmdbid: Optional[int] = None,
doubanid: Optional[str] = None, bangumiid: Optional[str] = None) -> Optional[Subscribe]:
"""
根据条件查询订阅
"""
return Subscribe.get_by(self._db, type, season, tmdbid, doubanid, bangumiid)
async def async_get_by(self, type: str, season: Optional[str] = None, tmdbid: Optional[int] = None,
doubanid: Optional[str] = None, bangumiid: Optional[str] = None) -> Optional[Subscribe]:
"""
根据条件查询订阅
"""
return await Subscribe.async_get_by(self._db, type, season, tmdbid, doubanid, bangumiid)
def list(self, state: Optional[str] = None) -> List[Subscribe]:
"""
获取订阅列表