This commit is contained in:
jxxghp
2025-07-31 19:51:48 +08:00
parent 713d44eac3
commit 8efba30adb
17 changed files with 361 additions and 469 deletions

View File

@@ -2,8 +2,8 @@ import time
from typing import Optional from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, JSON, select from sqlalchemy import Column, Integer, String, Sequence, JSON, select
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base, async_db_query from app.db import db_query, db_update, Base, async_db_query
@@ -56,25 +56,25 @@ class DownloadHistory(Base):
# 剧集组 # 剧集组
episode_group = Column(String) episode_group = Column(String)
@staticmethod @classmethod
@db_query @db_query
def get_by_hash(db: Session, download_hash: str): def get_by_hash(cls, db: Session, download_hash: str):
return db.query(DownloadHistory).filter(DownloadHistory.download_hash == download_hash).order_by( return db.query(DownloadHistory).filter(DownloadHistory.download_hash == download_hash).order_by(
DownloadHistory.date.desc() DownloadHistory.date.desc()
).first() ).first()
@staticmethod @classmethod
@db_query @db_query
def get_by_mediaid(db: Session, tmdbid: int, doubanid: str): def get_by_mediaid(cls, db: Session, tmdbid: int, doubanid: str):
if tmdbid: if tmdbid:
return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid).all() return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid).all()
elif doubanid: elif doubanid:
return db.query(DownloadHistory).filter(DownloadHistory.doubanid == doubanid).all() return db.query(DownloadHistory).filter(DownloadHistory.doubanid == doubanid).all()
return [] return []
@staticmethod @classmethod
@db_query @db_query
def list_by_page(db: Session, page: Optional[int] = 1, count: Optional[int] = 30): def list_by_page(cls, db: Session, page: Optional[int] = 1, count: Optional[int] = 30):
return db.query(DownloadHistory).offset((page - 1) * count).limit(count).all() return db.query(DownloadHistory).offset((page - 1) * count).limit(count).all()
@classmethod @classmethod
@@ -85,14 +85,14 @@ class DownloadHistory(Base):
) )
return result.scalars().all() return result.scalars().all()
@staticmethod @classmethod
@db_query @db_query
def get_by_path(db: Session, path: str): def get_by_path(cls, db: Session, path: str):
return db.query(DownloadHistory).filter(DownloadHistory.path == path).first() return db.query(DownloadHistory).filter(DownloadHistory.path == path).first()
@staticmethod @classmethod
@db_query @db_query
def get_last_by(db: Session, mtype: Optional[str] = None, title: Optional[str] = None, def get_last_by(cls, db: Session, mtype: Optional[str] = None, title: Optional[str] = None,
year: Optional[str] = None, season: Optional[str] = None, year: Optional[str] = None, season: Optional[str] = None,
episode: Optional[str] = None, tmdbid: Optional[int] = None): episode: Optional[str] = None, tmdbid: Optional[int] = None):
""" """
@@ -142,9 +142,9 @@ class DownloadHistory(Base):
return [] return []
@staticmethod @classmethod
@db_query @db_query
def list_by_user_date(db: Session, date: str, username: Optional[str] = None): def list_by_user_date(cls, db: Session, date: str, username: Optional[str] = None):
""" """
查询某用户某时间之后的下载历史 查询某用户某时间之后的下载历史
""" """
@@ -156,9 +156,9 @@ class DownloadHistory(Base):
return db.query(DownloadHistory).filter(DownloadHistory.date < date).order_by( return db.query(DownloadHistory).filter(DownloadHistory.date < date).order_by(
DownloadHistory.id.desc()).all() DownloadHistory.id.desc()).all()
@staticmethod @classmethod
@db_query @db_query
def list_by_date(db: Session, date: str, type: str, tmdbid: str, seasons: Optional[str] = None): def list_by_date(cls, db: Session, date: str, type: str, tmdbid: str, seasons: Optional[str] = None):
""" """
查询某时间之后的下载历史 查询某时间之后的下载历史
""" """
@@ -174,9 +174,9 @@ class DownloadHistory(Base):
DownloadHistory.tmdbid == tmdbid).order_by( DownloadHistory.tmdbid == tmdbid).order_by(
DownloadHistory.id.desc()).all() DownloadHistory.id.desc()).all()
@staticmethod @classmethod
@db_query @db_query
def list_by_type(db: Session, mtype: str, days: int): def list_by_type(cls, db: Session, mtype: str, days: int):
return db.query(DownloadHistory) \ return db.query(DownloadHistory) \
.filter(DownloadHistory.type == mtype, .filter(DownloadHistory.type == mtype,
DownloadHistory.date >= time.strftime("%Y-%m-%d %H:%M:%S", DownloadHistory.date >= time.strftime("%Y-%m-%d %H:%M:%S",
@@ -204,35 +204,35 @@ class DownloadFiles(Base):
# 状态 0-已删除 1-正常 # 状态 0-已删除 1-正常
state = Column(Integer, nullable=False, default=1) state = Column(Integer, nullable=False, default=1)
@staticmethod @classmethod
@db_query @db_query
def get_by_hash(db: Session, download_hash: str, state: Optional[int] = None): def get_by_hash(cls, db: Session, download_hash: str, state: Optional[int] = None):
if state: if state:
return db.query(DownloadFiles).filter(DownloadFiles.download_hash == download_hash, return db.query(cls).filter(cls.download_hash == download_hash,
DownloadFiles.state == state).all() cls.state == state).all()
else: else:
return db.query(DownloadFiles).filter(DownloadFiles.download_hash == download_hash).all() return db.query(cls).filter(cls.download_hash == download_hash).all()
@staticmethod @classmethod
@db_query @db_query
def get_by_fullpath(db: Session, fullpath: str, all_files: bool = False): def get_by_fullpath(cls, db: Session, fullpath: str, all_files: bool = False):
if not all_files: if not all_files:
return db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath).order_by( return db.query(cls).filter(cls.fullpath == fullpath).order_by(
DownloadFiles.id.desc()).first() cls.id.desc()).first()
else: else:
return db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath).order_by( return db.query(cls).filter(cls.fullpath == fullpath).order_by(
DownloadFiles.id.desc()).all() cls.id.desc()).all()
@staticmethod @classmethod
@db_query @db_query
def get_by_savepath(db: Session, savepath: str): def get_by_savepath(cls, db: Session, savepath: str):
return db.query(DownloadFiles).filter(DownloadFiles.savepath == savepath).all() return db.query(cls).filter(cls.savepath == savepath).all()
@staticmethod @classmethod
@db_update @db_update
def delete_by_fullpath(db: Session, fullpath: str): def delete_by_fullpath(cls, db: Session, fullpath: str):
db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath, db.query(cls).filter(cls.fullpath == fullpath,
DownloadFiles.state == 1).update( cls.state == 1).update(
{ {
"state": 0 "state": 0
} }

View File

@@ -41,28 +41,28 @@ class MediaServerItem(Base):
# 同步时间 # 同步时间
lst_mod_date = Column(String, default=datetime.now().strftime("%Y-%m-%d %H:%M:%S")) lst_mod_date = Column(String, default=datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
@staticmethod @classmethod
@db_query @db_query
def get_by_itemid(db: Session, item_id: str): def get_by_itemid(cls, db: Session, item_id: str):
return db.query(MediaServerItem).filter(MediaServerItem.item_id == item_id).first() return db.query(cls).filter(cls.item_id == item_id).first()
@staticmethod @classmethod
@db_update @db_update
def empty(db: Session, server: Optional[str] = None): def empty(cls, db: Session, server: Optional[str] = None):
if server is None: if server is None:
db.query(MediaServerItem).delete() db.query(cls).delete()
else: else:
db.query(MediaServerItem).filter(MediaServerItem.server == server).delete() db.query(cls).filter(cls.server == server).delete()
@staticmethod @classmethod
@db_query @db_query
def exist_by_tmdbid(db: Session, tmdbid: int, mtype: str): def exist_by_tmdbid(cls, db: Session, tmdbid: int, mtype: str):
return db.query(MediaServerItem).filter(MediaServerItem.tmdbid == tmdbid, return db.query(cls).filter(cls.tmdbid == tmdbid,
MediaServerItem.item_type == mtype).first() cls.item_type == mtype).first()
@staticmethod @classmethod
@db_query @db_query
def exists_by_title(db: Session, title: str, mtype: str, year: str): def exists_by_title(cls, db: Session, title: str, mtype: str, year: str):
return db.query(MediaServerItem).filter(MediaServerItem.title == title, return db.query(cls).filter(cls.title == title,
MediaServerItem.item_type == mtype, cls.item_type == mtype,
MediaServerItem.year == str(year)).first() cls.year == str(year)).first()

View File

@@ -34,7 +34,7 @@ class Message(Base):
# 附件json # 附件json
note = Column(JSON) note = Column(JSON)
@staticmethod @classmethod
@db_query @db_query
def list_by_page(db: Session, page: Optional[int] = 1, count: Optional[int] = 30): def list_by_page(cls, db: Session, page: Optional[int] = 1, count: Optional[int] = 30):
return db.query(Message).order_by(Message.reg_time.desc()).offset((page - 1) * count).limit(count).all() return db.query(cls).order_by(cls.reg_time.desc()).offset((page - 1) * count).limit(count).all()

View File

@@ -13,27 +13,27 @@ class PluginData(Base):
key = Column(String, index=True, nullable=False) key = Column(String, index=True, nullable=False)
value = Column(JSON) value = Column(JSON)
@staticmethod @classmethod
@db_query @db_query
def get_plugin_data(db: Session, plugin_id: str): def get_plugin_data(cls, db: Session, plugin_id: str):
return db.query(PluginData).filter(PluginData.plugin_id == plugin_id).all() return db.query(cls).filter(cls.plugin_id == plugin_id).all()
@staticmethod @classmethod
@db_query @db_query
def get_plugin_data_by_key(db: Session, plugin_id: str, key: str): def get_plugin_data_by_key(cls, db: Session, plugin_id: str, key: str):
return db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).first() return db.query(cls).filter(cls.plugin_id == plugin_id, cls.key == key).first()
@staticmethod @classmethod
@db_update @db_update
def del_plugin_data_by_key(db: Session, plugin_id: str, key: str): def del_plugin_data_by_key(cls, db: Session, plugin_id: str, key: str):
db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).delete() db.query(cls).filter(cls.plugin_id == plugin_id, cls.key == key).delete()
@staticmethod @classmethod
@db_update @db_update
def del_plugin_data(db: Session, plugin_id: str): def del_plugin_data(cls, db: Session, plugin_id: str):
db.query(PluginData).filter(PluginData.plugin_id == plugin_id).delete() db.query(cls).filter(cls.plugin_id == plugin_id).delete()
@staticmethod @classmethod
@db_query @db_query
def get_plugin_data_by_plugin_id(db: Session, plugin_id: str): def get_plugin_data_by_plugin_id(cls, db: Session, plugin_id: str):
return db.query(PluginData).filter(PluginData.plugin_id == plugin_id).all() return db.query(cls).filter(cls.plugin_id == plugin_id).all()

View File

@@ -54,27 +54,27 @@ class Site(Base):
# 下载器 # 下载器
downloader = Column(String) downloader = Column(String)
@staticmethod @classmethod
@db_query @db_query
def get_by_domain(db: Session, domain: str): def get_by_domain(cls, db: Session, domain: str):
return db.query(Site).filter(Site.domain == domain).first() return db.query(cls).filter(cls.domain == domain).first()
@staticmethod @classmethod
@db_query @db_query
def get_actives(db: Session): def get_actives(cls, db: Session):
return db.query(Site).filter(Site.is_active == 1).all() return db.query(cls).filter(cls.is_active == 1).all()
@staticmethod @classmethod
@db_query @db_query
def list_order_by_pri(db: Session): def list_order_by_pri(cls, db: Session):
return db.query(Site).order_by(Site.pri).all() return db.query(cls).order_by(cls.pri).all()
@staticmethod @classmethod
@db_query @db_query
def get_domains_by_ids(db: Session, ids: list): def get_domains_by_ids(cls, db: Session, ids: list):
return [r[0] for r in db.query(Site.domain).filter(Site.id.in_(ids)).all()] return [r[0] for r in db.query(cls.domain).filter(cls.id.in_(ids)).all()]
@staticmethod @classmethod
@db_update @db_update
def reset(db: Session): def reset(cls, db: Session):
db.query(Site).delete() db.query(cls).delete()

View File

@@ -19,10 +19,10 @@ class SiteIcon(Base):
# 图标Base64 # 图标Base64
base64 = Column(String) base64 = Column(String)
@staticmethod @classmethod
@db_query @db_query
def get_by_domain(db: Session, domain: str): def get_by_domain(cls, db: Session, domain: str):
return db.query(SiteIcon).filter(SiteIcon.domain == domain).first() return db.query(cls).filter(cls.domain == domain).first()
@classmethod @classmethod
@async_db_query @async_db_query

View File

@@ -26,12 +26,12 @@ class SiteStatistic(Base):
# 耗时记录 Json # 耗时记录 Json
note = Column(JSON) note = Column(JSON)
@staticmethod @classmethod
@db_query @db_query
def get_by_domain(db: Session, domain: str): def get_by_domain(cls, db: Session, domain: str):
return db.query(SiteStatistic).filter(SiteStatistic.domain == domain).first() return db.query(cls).filter(cls.domain == domain).first()
@staticmethod @classmethod
@db_update @db_update
def reset(db: Session): def reset(cls, db: Session):
db.query(SiteStatistic).delete() db.query(cls).delete()

View File

@@ -53,42 +53,42 @@ class SiteUserData(Base):
# 更新时间 # 更新时间
updated_time = Column(String, default=datetime.now().strftime('%H:%M:%S')) updated_time = Column(String, default=datetime.now().strftime('%H:%M:%S'))
@staticmethod @classmethod
@db_query @db_query
def get_by_domain(db: Session, domain: str, workdate: Optional[str] = None, worktime: Optional[str] = None): def get_by_domain(cls, db: Session, domain: str, workdate: Optional[str] = None, worktime: Optional[str] = None):
if workdate and worktime: if workdate and worktime:
return db.query(SiteUserData).filter(SiteUserData.domain == domain, return db.query(cls).filter(cls.domain == domain,
SiteUserData.updated_day == workdate, cls.updated_day == workdate,
SiteUserData.updated_time == worktime).all() cls.updated_time == worktime).all()
elif workdate: elif workdate:
return db.query(SiteUserData).filter(SiteUserData.domain == domain, return db.query(cls).filter(cls.domain == domain,
SiteUserData.updated_day == workdate).all() cls.updated_day == workdate).all()
return db.query(SiteUserData).filter(SiteUserData.domain == domain).all() return db.query(cls).filter(cls.domain == domain).all()
@staticmethod @classmethod
@db_query @db_query
def get_by_date(db: Session, date: str): def get_by_date(cls, db: Session, date: str):
return db.query(SiteUserData).filter(SiteUserData.updated_day == date).all() return db.query(cls).filter(cls.updated_day == date).all()
@staticmethod @classmethod
@db_query @db_query
def get_latest(db: Session): def get_latest(cls, db: Session):
""" """
获取各站点最新一天的数据 获取各站点最新一天的数据
""" """
subquery = ( subquery = (
db.query( db.query(
SiteUserData.domain, cls.domain,
func.max(SiteUserData.updated_day).label('latest_update_day') func.max(cls.updated_day).label('latest_update_day')
) )
.group_by(SiteUserData.domain) .group_by(cls.domain)
.filter(or_(SiteUserData.err_msg.is_(None), SiteUserData.err_msg == "")) .filter(or_(cls.err_msg.is_(None), cls.err_msg == ""))
.subquery() .subquery()
) )
# 主查询:按 domain 和 updated_day 获取最新的记录 # 主查询:按 domain 和 updated_day 获取最新的记录
return db.query(SiteUserData).join( return db.query(cls).join(
subquery, subquery,
(SiteUserData.domain == subquery.c.domain) & (cls.domain == subquery.c.domain) &
(SiteUserData.updated_day == subquery.c.latest_update_day) (cls.updated_day == subquery.c.latest_update_day)
).order_by(SiteUserData.updated_time.desc()).all() ).order_by(cls.updated_time.desc()).all()

View File

@@ -88,16 +88,17 @@ class Subscribe(Base):
# 选择的剧集组 # 选择的剧集组
episode_group = Column(String) episode_group = Column(String)
@staticmethod @classmethod
@db_query @db_query
def exists(db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None, season: Optional[int] = None): def exists(cls, db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
season: Optional[int] = None):
if tmdbid: if tmdbid:
if season: if season:
return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid, return db.query(cls).filter(cls.tmdbid == tmdbid,
Subscribe.season == season).first() cls.season == season).first()
return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid).first() return db.query(cls).filter(cls.tmdbid == tmdbid).first()
elif doubanid: elif doubanid:
return db.query(Subscribe).filter(Subscribe.doubanid == doubanid).first() return db.query(cls).filter(cls.doubanid == doubanid).first()
return None return None
@classmethod @classmethod
@@ -121,15 +122,15 @@ class Subscribe(Base):
return None return None
return result.scalars().first() return result.scalars().first()
@staticmethod @classmethod
@db_query @db_query
def get_by_state(db: Session, state: str): def get_by_state(cls, db: Session, state: str):
# 如果 state 为空或 None返回所有订阅 # 如果 state 为空或 None返回所有订阅
if not state: if not state:
return db.query(Subscribe).all() return db.query(cls).all()
else: else:
# 如果传入的状态不为空,拆分成多个状态 # 如果传入的状态不为空,拆分成多个状态
return db.query(Subscribe).filter(Subscribe.state.in_(state.split(','))).all() return db.query(cls).filter(cls.state.in_(state.split(','))).all()
@classmethod @classmethod
@async_db_query @async_db_query
@@ -144,13 +145,13 @@ class Subscribe(Base):
) )
return result.scalars().all() return result.scalars().all()
@staticmethod @classmethod
@db_query @db_query
def get_by_title(db: Session, title: str, season: Optional[int] = None): def get_by_title(cls, db: Session, title: str, season: Optional[int] = None):
if season: if season:
return db.query(Subscribe).filter(Subscribe.name == title, return db.query(cls).filter(cls.name == title,
Subscribe.season == season).first() cls.season == season).first()
return db.query(Subscribe).filter(Subscribe.name == title).first() return db.query(cls).filter(cls.name == title).first()
@classmethod @classmethod
@async_db_query @async_db_query
@@ -165,14 +166,14 @@ class Subscribe(Base):
) )
return result.scalars().first() return result.scalars().first()
@staticmethod @classmethod
@db_query @db_query
def get_by_tmdbid(db: Session, tmdbid: int, season: Optional[int] = None): def get_by_tmdbid(cls, db: Session, tmdbid: int, season: Optional[int] = None):
if season: if season:
return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid, return db.query(cls).filter(cls.tmdbid == tmdbid,
Subscribe.season == season).all() cls.season == season).all()
else: else:
return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid).all() return db.query(cls).filter(cls.tmdbid == tmdbid).all()
@classmethod @classmethod
@async_db_query @async_db_query
@@ -187,10 +188,10 @@ class Subscribe(Base):
) )
return result.scalars().all() return result.scalars().all()
@staticmethod @classmethod
@db_query @db_query
def get_by_doubanid(db: Session, doubanid: str): def get_by_doubanid(cls, db: Session, doubanid: str):
return db.query(Subscribe).filter(Subscribe.doubanid == doubanid).first() return db.query(cls).filter(cls.doubanid == doubanid).first()
@classmethod @classmethod
@async_db_query @async_db_query
@@ -200,10 +201,10 @@ class Subscribe(Base):
) )
return result.scalars().first() return result.scalars().first()
@staticmethod @classmethod
@db_query @db_query
def get_by_bangumiid(db: Session, bangumiid: int): def get_by_bangumiid(cls, db: Session, bangumiid: int):
return db.query(Subscribe).filter(Subscribe.bangumiid == bangumiid).first() return db.query(cls).filter(cls.bangumiid == bangumiid).first()
@classmethod @classmethod
@async_db_query @async_db_query
@@ -213,10 +214,10 @@ class Subscribe(Base):
) )
return result.scalars().first() return result.scalars().first()
@staticmethod @classmethod
@db_query @db_query
def get_by_mediaid(db: Session, mediaid: str): def get_by_mediaid(cls, db: Session, mediaid: str):
return db.query(Subscribe).filter(Subscribe.mediaid == mediaid).first() return db.query(cls).filter(cls.mediaid == mediaid).first()
@classmethod @classmethod
@async_db_query @async_db_query
@@ -268,23 +269,23 @@ class Subscribe(Base):
await subscribe.async_delete(db, subscribe.id) await subscribe.async_delete(db, subscribe.id)
return True return True
@staticmethod @classmethod
@db_query @db_query
def list_by_username(db: Session, username: str, state: Optional[str] = None, mtype: Optional[str] = None): def list_by_username(cls, db: Session, username: str, state: Optional[str] = None, mtype: Optional[str] = None):
if mtype: if mtype:
if state: if state:
return db.query(Subscribe).filter(Subscribe.state == state, return db.query(cls).filter(cls.state == state,
Subscribe.username == username, cls.username == username,
Subscribe.type == mtype).all() cls.type == mtype).all()
else: else:
return db.query(Subscribe).filter(Subscribe.username == username, return db.query(cls).filter(cls.username == username,
Subscribe.type == mtype).all() cls.type == mtype).all()
else: else:
if state: if state:
return db.query(Subscribe).filter(Subscribe.state == state, return db.query(cls).filter(cls.state == state,
Subscribe.username == username).all() cls.username == username).all()
else: else:
return db.query(Subscribe).filter(Subscribe.username == username).all() return db.query(cls).filter(cls.username == username).all()
@classmethod @classmethod
@async_db_query @async_db_query
@@ -310,13 +311,13 @@ class Subscribe(Base):
) )
return result.scalars().all() return result.scalars().all()
@staticmethod @classmethod
@db_query @db_query
def list_by_type(db: Session, mtype: str, days: int): def list_by_type(cls, db: Session, mtype: str, days: int):
return db.query(Subscribe) \ return db.query(cls) \
.filter(Subscribe.type == mtype, .filter(cls.type == mtype,
Subscribe.date >= time.strftime("%Y-%m-%d %H:%M:%S", cls.date >= time.strftime("%Y-%m-%d %H:%M:%S",
time.localtime(time.time() - 86400 * int(days))) time.localtime(time.time() - 86400 * int(days)))
).all() ).all()
@classmethod @classmethod

View File

@@ -73,13 +73,13 @@ class SubscribeHistory(Base):
# 剧集组 # 剧集组
episode_group = Column(String) episode_group = Column(String)
@staticmethod @classmethod
@db_query @db_query
def list_by_type(db: Session, mtype: str, page: Optional[int] = 1, count: Optional[int] = 30): def list_by_type(cls, db: Session, mtype: str, page: Optional[int] = 1, count: Optional[int] = 30):
return db.query(SubscribeHistory).filter( return db.query(cls).filter(
SubscribeHistory.type == mtype cls.type == mtype
).order_by( ).order_by(
SubscribeHistory.date.desc() cls.date.desc()
).offset((page - 1) * count).limit(count).all() ).offset((page - 1) * count).limit(count).all()
@classmethod @classmethod
@@ -94,16 +94,17 @@ class SubscribeHistory(Base):
) )
return result.scalars().all() return result.scalars().all()
@staticmethod @classmethod
@db_query @db_query
def exists(db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None, season: Optional[int] = None): def exists(cls, db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
season: Optional[int] = None):
if tmdbid: if tmdbid:
if season: if season:
return db.query(SubscribeHistory).filter(SubscribeHistory.tmdbid == tmdbid, return db.query(cls).filter(cls.tmdbid == tmdbid,
SubscribeHistory.season == season).first() cls.season == season).first()
return db.query(SubscribeHistory).filter(SubscribeHistory.tmdbid == tmdbid).first() return db.query(cls).filter(cls.tmdbid == tmdbid).first()
elif doubanid: elif doubanid:
return db.query(SubscribeHistory).filter(SubscribeHistory.doubanid == doubanid).first() return db.query(cls).filter(cls.doubanid == doubanid).first()
return None return None
@classmethod @classmethod

View File

@@ -14,10 +14,10 @@ class SystemConfig(Base):
# 值 # 值
value = Column(JSON) value = Column(JSON)
@staticmethod @classmethod
@db_query @db_query
def get_by_key(db: Session, key: str): def get_by_key(cls, db: Session, key: str):
return db.query(SystemConfig).filter(SystemConfig.key == key).first() return db.query(cls).filter(cls.key == key).first()
@db_update @db_update
def delete_by_key(self, db: Session, key: str): def delete_by_key(self, db: Session, key: str):

View File

@@ -60,22 +60,23 @@ class TransferHistory(Base):
# 剧集组 # 剧集组
episode_group = Column(String) episode_group = Column(String)
@staticmethod @classmethod
@db_query @db_query
def list_by_title(db: Session, title: str, page: Optional[int] = 1, count: Optional[int] = 30, status: bool = None): def list_by_title(cls, db: Session, title: str, page: Optional[int] = 1, count: Optional[int] = 30,
status: bool = None):
if status is not None: if status is not None:
return db.query(TransferHistory).filter( return db.query(cls).filter(
TransferHistory.status == status cls.status == status
).order_by( ).order_by(
TransferHistory.date.desc() cls.date.desc()
).offset((page - 1) * count).limit(count).all() ).offset((page - 1) * count).limit(count).all()
else: else:
return db.query(TransferHistory).filter(or_( return db.query(cls).filter(or_(
TransferHistory.title.like(f'%{title}%'), cls.title.like(f'%{title}%'),
TransferHistory.src.like(f'%{title}%'), cls.src.like(f'%{title}%'),
TransferHistory.dest.like(f'%{title}%'), cls.dest.like(f'%{title}%'),
)).order_by( )).order_by(
TransferHistory.date.desc() cls.date.desc()
).offset((page - 1) * count).limit(count).all() ).offset((page - 1) * count).limit(count).all()
@classmethod @classmethod
@@ -102,18 +103,18 @@ class TransferHistory(Base):
) )
return result.scalars().all() return result.scalars().all()
@staticmethod @classmethod
@db_query @db_query
def list_by_page(db: Session, page: Optional[int] = 1, count: Optional[int] = 30, status: bool = None): def list_by_page(cls, db: Session, page: Optional[int] = 1, count: Optional[int] = 30, status: bool = None):
if status is not None: if status is not None:
return db.query(TransferHistory).filter( return db.query(cls).filter(
TransferHistory.status == status cls.status == status
).order_by( ).order_by(
TransferHistory.date.desc() cls.date.desc()
).offset((page - 1) * count).limit(count).all() ).offset((page - 1) * count).limit(count).all()
else: else:
return db.query(TransferHistory).order_by( return db.query(cls).order_by(
TransferHistory.date.desc() cls.date.desc()
).offset((page - 1) * count).limit(count).all() ).offset((page - 1) * count).limit(count).all()
@classmethod @classmethod
@@ -136,49 +137,49 @@ class TransferHistory(Base):
) )
return result.scalars().all() return result.scalars().all()
@staticmethod @classmethod
@db_query @db_query
def get_by_hash(db: Session, download_hash: str): def get_by_hash(cls, db: Session, download_hash: str):
return db.query(TransferHistory).filter(TransferHistory.download_hash == download_hash).first() return db.query(cls).filter(cls.download_hash == download_hash).first()
@staticmethod @classmethod
@db_query @db_query
def get_by_src(db: Session, src: str, storage: Optional[str] = None): def get_by_src(cls, db: Session, src: str, storage: Optional[str] = None):
if storage: if storage:
return db.query(TransferHistory).filter(TransferHistory.src == src, return db.query(cls).filter(cls.src == src,
TransferHistory.src_storage == storage).first() cls.src_storage == storage).first()
else: else:
return db.query(TransferHistory).filter(TransferHistory.src == src).first() return db.query(cls).filter(cls.src == src).first()
@staticmethod @classmethod
@db_query @db_query
def get_by_dest(db: Session, dest: str): def get_by_dest(cls, db: Session, dest: str):
return db.query(TransferHistory).filter(TransferHistory.dest == dest).first() return db.query(cls).filter(cls.dest == dest).first()
@staticmethod @classmethod
@db_query @db_query
def list_by_hash(db: Session, download_hash: str): def list_by_hash(cls, db: Session, download_hash: str):
return db.query(TransferHistory).filter(TransferHistory.download_hash == download_hash).all() return db.query(cls).filter(cls.download_hash == download_hash).all()
@staticmethod @classmethod
@db_query @db_query
def statistic(db: Session, days: Optional[int] = 7): def statistic(cls, db: Session, days: Optional[int] = 7):
""" """
统计最近days天的下载历史数量按日期分组返回每日数量 统计最近days天的下载历史数量按日期分组返回每日数量
""" """
sub_query = db.query(func.substr(TransferHistory.date, 1, 10).label('date'), sub_query = db.query(func.substr(cls.date, 1, 10).label('date'),
TransferHistory.id.label('id')).filter( cls.id.label('id')).filter(
TransferHistory.date >= time.strftime("%Y-%m-%d %H:%M:%S", cls.date >= time.strftime("%Y-%m-%d %H:%M:%S",
time.localtime(time.time() - 86400 * days))).subquery() time.localtime(time.time() - 86400 * days))).subquery()
return db.query(sub_query.c.date, func.count(sub_query.c.id)).group_by(sub_query.c.date).all() return db.query(sub_query.c.date, func.count(sub_query.c.id)).group_by(sub_query.c.date).all()
@staticmethod @classmethod
@db_query @db_query
def count(db: Session, status: bool = None): def count(cls, db: Session, status: bool = None):
if status is not None: if status is not None:
return db.query(func.count(TransferHistory.id)).filter(TransferHistory.status == status).first()[0] return db.query(func.count(cls.id)).filter(cls.status == status).first()[0]
else: else:
return db.query(func.count(TransferHistory.id)).first()[0] return db.query(func.count(cls.id)).first()[0]
@classmethod @classmethod
@async_db_query @async_db_query
@@ -193,16 +194,16 @@ class TransferHistory(Base):
) )
return result.scalar() return result.scalar()
@staticmethod @classmethod
@db_query @db_query
def count_by_title(db: Session, title: str, status: bool = None): def count_by_title(cls, db: Session, title: str, status: bool = None):
if status is not None: if status is not None:
return db.query(func.count(TransferHistory.id)).filter(TransferHistory.status == status).first()[0] return db.query(func.count(cls.id)).filter(cls.status == status).first()[0]
else: else:
return db.query(func.count(TransferHistory.id)).filter(or_( return db.query(func.count(cls.id)).filter(or_(
TransferHistory.title.like(f'%{title}%'), cls.title.like(f'%{title}%'),
TransferHistory.src.like(f'%{title}%'), cls.src.like(f'%{title}%'),
TransferHistory.dest.like(f'%{title}%') cls.dest.like(f'%{title}%')
)).first()[0] )).first()[0]
@classmethod @classmethod
@@ -222,9 +223,9 @@ class TransferHistory(Base):
) )
return result.scalar() return result.scalar()
@staticmethod @classmethod
@db_query @db_query
def list_by(db: Session, mtype: Optional[str] = None, title: Optional[str] = None, year: Optional[str] = None, def list_by(cls, db: Session, mtype: Optional[str] = None, title: Optional[str] = None, year: Optional[str] = None,
season: Optional[str] = None, season: Optional[str] = None,
episode: Optional[str] = None, tmdbid: Optional[int] = None, dest: Optional[str] = None): episode: Optional[str] = None, tmdbid: Optional[int] = None, dest: Optional[str] = None):
""" """
@@ -235,80 +236,80 @@ class TransferHistory(Base):
if tmdbid and mtype: if tmdbid and mtype:
# 电视剧某季某集 # 电视剧某季某集
if season and episode: if season and episode:
return db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, return db.query(cls).filter(cls.tmdbid == tmdbid,
TransferHistory.type == mtype, cls.type == mtype,
TransferHistory.seasons == season, cls.seasons == season,
TransferHistory.episodes == episode, cls.episodes == episode,
TransferHistory.dest == dest).all() cls.dest == dest).all()
# 电视剧某季 # 电视剧某季
elif season: elif season:
return db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, return db.query(cls).filter(cls.tmdbid == tmdbid,
TransferHistory.type == mtype, cls.type == mtype,
TransferHistory.seasons == season).all() cls.seasons == season).all()
else: else:
if dest: if dest:
# 电影 # 电影
return db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, return db.query(cls).filter(cls.tmdbid == tmdbid,
TransferHistory.type == mtype, cls.type == mtype,
TransferHistory.dest == dest).all() cls.dest == dest).all()
else: else:
# 电视剧所有季集 # 电视剧所有季集
return db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, return db.query(cls).filter(cls.tmdbid == tmdbid,
TransferHistory.type == mtype).all() cls.type == mtype).all()
# 标题 + 年份 # 标题 + 年份
elif title and year: elif title and year:
# 电视剧某季某集 # 电视剧某季某集
if season and episode: if season and episode:
return db.query(TransferHistory).filter(TransferHistory.title == title, return db.query(cls).filter(cls.title == title,
TransferHistory.year == year, cls.year == year,
TransferHistory.seasons == season, cls.seasons == season,
TransferHistory.episodes == episode, cls.episodes == episode,
TransferHistory.dest == dest).all() cls.dest == dest).all()
# 电视剧某季 # 电视剧某季
elif season: elif season:
return db.query(TransferHistory).filter(TransferHistory.title == title, return db.query(cls).filter(cls.title == title,
TransferHistory.year == year, cls.year == year,
TransferHistory.seasons == season).all() cls.seasons == season).all()
else: else:
if dest: if dest:
# 电影 # 电影
return db.query(TransferHistory).filter(TransferHistory.title == title, return db.query(cls).filter(cls.title == title,
TransferHistory.year == year, cls.year == year,
TransferHistory.dest == dest).all() cls.dest == dest).all()
else: else:
# 电视剧所有季集 # 电视剧所有季集
return db.query(TransferHistory).filter(TransferHistory.title == title, return db.query(cls).filter(cls.title == title,
TransferHistory.year == year).all() cls.year == year).all()
# 类型 + 转移路径emby webhook season无tmdbid场景 # 类型 + 转移路径emby webhook season无tmdbid场景
elif mtype and season and dest: elif mtype and season and dest:
# 电视剧某季 # 电视剧某季
return db.query(TransferHistory).filter(TransferHistory.type == mtype, return db.query(cls).filter(cls.type == mtype,
TransferHistory.seasons == season, cls.seasons == season,
TransferHistory.dest.like(f"{dest}%")).all() cls.dest.like(f"{dest}%")).all()
return [] return []
@staticmethod @classmethod
@db_query @db_query
def get_by_type_tmdbid(db: Session, mtype: Optional[str] = None, tmdbid: Optional[int] = None): def get_by_type_tmdbid(cls, db: Session, mtype: Optional[str] = None, tmdbid: Optional[int] = None):
""" """
据tmdbid、type查询转移记录 据tmdbid、type查询转移记录
""" """
return db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, return db.query(cls).filter(cls.tmdbid == tmdbid,
TransferHistory.type == mtype).first() cls.type == mtype).first()
@staticmethod @classmethod
@db_update @db_update
def update_download_hash(db: Session, historyid: Optional[int] = None, download_hash: Optional[str] = None): def update_download_hash(cls, db: Session, historyid: Optional[int] = None, download_hash: Optional[str] = None):
db.query(TransferHistory).filter(TransferHistory.id == historyid).update( db.query(cls).filter(cls.id == historyid).update(
{ {
"download_hash": download_hash "download_hash": download_hash
} }
) )
@staticmethod @classmethod
@db_query @db_query
def list_by_date(db: Session, date: str): def list_by_date(cls, db: Session, date: str):
""" """
查询某时间之后的转移历史 查询某时间之后的转移历史
""" """
return db.query(TransferHistory).filter(TransferHistory.date > date).order_by(TransferHistory.id.desc()).all() return db.query(cls).filter(cls.date > date).order_by(cls.id.desc()).all()

View File

@@ -32,10 +32,10 @@ class User(Base):
# 用户个性化设置 json # 用户个性化设置 json
settings = Column(JSON, default=dict) settings = Column(JSON, default=dict)
@staticmethod @classmethod
@db_query @db_query
def get_by_name(db: Session, name: str): def get_by_name(cls, db: Session, name: str):
return db.query(User).filter(User.name == name).first() return db.query(cls).filter(cls.name == name).first()
@classmethod @classmethod
@async_db_query @async_db_query
@@ -45,10 +45,10 @@ class User(Base):
) )
return result.scalars().first() return result.scalars().first()
@staticmethod @classmethod
@db_query @db_query
def get_by_id(db: Session, user_id: int): def get_by_id(cls, db: Session, user_id: int):
return db.query(User).filter(User.id == user_id).first() return db.query(cls).filter(cls.id == user_id).first()
@classmethod @classmethod
@async_db_query @async_db_query

View File

@@ -22,12 +22,12 @@ class UserConfig(Base):
Index('ix_userconfig_username_key', 'username', 'key'), Index('ix_userconfig_username_key', 'username', 'key'),
) )
@staticmethod @classmethod
@db_query @db_query
def get_by_key(db: Session, username: str, key: str): def get_by_key(cls, db: Session, username: str, key: str):
return db.query(UserConfig) \ return db.query(cls) \
.filter(UserConfig.username == username) \ .filter(cls.username == username) \
.filter(UserConfig.key == key) \ .filter(cls.key == key) \
.first() .first()
@db_update @db_update

View File

@@ -1,69 +0,0 @@
from sqlalchemy import Column, Integer, String, Sequence, Float
from sqlalchemy.orm import Session
from app.db import db_query, Base
class UserRequest(Base):
"""
用户请求表
"""
# ID
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
# 申请用户
req_user = Column(String, index=True, nullable=False)
# 申请时间
req_time = Column(String)
# 申请备注
req_remark = Column(String)
# 审批用户
app_user = Column(String, index=True, nullable=False)
# 审批时间
app_time = Column(String)
# 审批状态 0-待审批 1-通过 2-拒绝
app_status = Column(Integer, default=0)
# 类型
type = Column(String)
# 标题
title = Column(String)
# 年份
year = Column(String)
# 媒体ID
tmdbid = Column(Integer)
imdbid = Column(String)
tvdbid = Column(Integer)
doubanid = Column(String)
bangumiid = Column(Integer)
# 季号
season = Column(Integer)
# 海报
poster = Column(String)
# 背景图
backdrop = Column(String)
# 评分float
vote = Column(Float)
# 简介
description = Column(String)
@staticmethod
@db_query
def get_by_req_user(db: Session, req_user: str, status: int = None):
if status:
return db.query(UserRequest).filter(UserRequest.req_user == req_user,
UserRequest.app_status == status).all()
else:
return db.query(UserRequest).filter(UserRequest.req_user == req_user).all()
@staticmethod
@db_query
def get_by_app_user(db: Session, app_user: str, status: int = None):
if status:
return db.query(UserRequest).filter(UserRequest.app_user == app_user,
UserRequest.app_status == status).all()
else:
return db.query(UserRequest).filter(UserRequest.app_user == app_user).all()
@staticmethod
@db_query
def get_by_status(db: Session, status: int):
return db.query(UserRequest).filter(UserRequest.app_status == status).all()

View File

@@ -44,135 +44,135 @@ class Workflow(Base):
# 最后执行时间 # 最后执行时间
last_time = Column(String) last_time = Column(String)
@staticmethod @classmethod
@db_query @db_query
def list(db): def list(cls, db):
return db.query(Workflow).all() return db.query(cls).all()
@staticmethod @classmethod
@async_db_query @async_db_query
async def async_list(db: AsyncSession): async def async_list(cls, db: AsyncSession):
result = await db.execute(select(Workflow)) result = await db.execute(select(cls))
return result.scalars().all() return result.scalars().all()
@staticmethod @classmethod
@db_query @db_query
def get_enabled_workflows(db): def get_enabled_workflows(cls, db):
return db.query(Workflow).filter(Workflow.state != 'P').all() return db.query(cls).filter(cls.state != 'P').all()
@staticmethod @classmethod
@async_db_query @async_db_query
async def async_get_enabled_workflows(db: AsyncSession): async def async_get_enabled_workflows(cls, db: AsyncSession):
result = await db.execute(select(Workflow).where(Workflow.state != 'P')) result = await db.execute(select(cls).where(cls.state != 'P'))
return result.scalars().all() return result.scalars().all()
@staticmethod @classmethod
@db_query @db_query
def get_timer_triggered_workflows(db): def get_timer_triggered_workflows(cls, db):
"""获取定时触发的工作流""" """获取定时触发的工作流"""
return db.query(Workflow).filter( return db.query(cls).filter(
and_( and_(
or_( or_(
Workflow.trigger_type == 'timer', cls.trigger_type == 'timer',
not Workflow.trigger_type not cls.trigger_type
), ),
Workflow.state != 'P' cls.state != 'P'
) )
).all() ).all()
@staticmethod @classmethod
@async_db_query @async_db_query
async def async_get_timer_triggered_workflows(db: AsyncSession): async def async_get_timer_triggered_workflows(cls, db: AsyncSession):
"""异步获取定时触发的工作流""" """异步获取定时触发的工作流"""
result = await db.execute(select(Workflow).where( result = await db.execute(select(cls).where(
and_( and_(
or_( or_(
Workflow.trigger_type == 'timer', cls.trigger_type == 'timer',
not Workflow.trigger_type not cls.trigger_type
), ),
Workflow.state != 'P' cls.state != 'P'
) )
)) ))
return result.scalars().all() return result.scalars().all()
@staticmethod @classmethod
@db_query @db_query
def get_event_triggered_workflows(db): def get_event_triggered_workflows(cls, db):
"""获取事件触发的工作流""" """获取事件触发的工作流"""
return db.query(Workflow).filter( return db.query(cls).filter(
and_( and_(
Workflow.trigger_type == 'event', cls.trigger_type == 'event',
Workflow.state != 'P' cls.state != 'P'
) )
).all() ).all()
@staticmethod @classmethod
@async_db_query @async_db_query
async def async_get_event_triggered_workflows(db: AsyncSession): async def async_get_event_triggered_workflows(cls, db: AsyncSession):
"""异步获取事件触发的工作流""" """异步获取事件触发的工作流"""
result = await db.execute(select(Workflow).where( result = await db.execute(select(cls).where(
and_( and_(
Workflow.trigger_type == 'event', cls.trigger_type == 'event',
Workflow.state != 'P' cls.state != 'P'
) )
)) ))
return result.scalars().all() return result.scalars().all()
@staticmethod @classmethod
@db_query @db_query
def get_by_name(db, name: str): def get_by_name(cls, db, name: str):
return db.query(Workflow).filter(Workflow.name == name).first() return db.query(cls).filter(cls.name == name).first()
@staticmethod @classmethod
@async_db_query @async_db_query
async def async_get_by_name(db: AsyncSession, name: str): async def async_get_by_name(cls, db: AsyncSession, name: str):
result = await db.execute(select(Workflow).where(Workflow.name == name)) result = await db.execute(select(cls).where(cls.name == name))
return result.scalars().first() return result.scalars().first()
@staticmethod @classmethod
@db_update @db_update
def update_state(db, wid: int, state: str): def update_state(cls, db, wid: int, state: str):
db.query(Workflow).filter(Workflow.id == wid).update({"state": state}) db.query(cls).filter(cls.id == wid).update({"state": state})
return True return True
@staticmethod @classmethod
@async_db_update @async_db_update
async def async_update_state(db: AsyncSession, wid: int, state: str): async def async_update_state(cls, db: AsyncSession, wid: int, state: str):
from sqlalchemy import update from sqlalchemy import update
await db.execute(update(Workflow).where(Workflow.id == wid).values(state=state)) await db.execute(update(cls).where(cls.id == wid).values(state=state))
return True return True
@staticmethod @classmethod
@db_update @db_update
def start(db, wid: int): def start(cls, db, wid: int):
db.query(Workflow).filter(Workflow.id == wid).update({ db.query(cls).filter(cls.id == wid).update({
"state": 'R' "state": 'R'
}) })
return True return True
@staticmethod @classmethod
@async_db_update @async_db_update
async def async_start(db: AsyncSession, wid: int): async def async_start(cls, db: AsyncSession, wid: int):
from sqlalchemy import update from sqlalchemy import update
await db.execute(update(Workflow).where(Workflow.id == wid).values(state='R')) await db.execute(update(cls).where(cls.id == wid).values(state='R'))
return True return True
@staticmethod @classmethod
@db_update @db_update
def fail(db, wid: int, result: str): def fail(cls, db, wid: int, result: str):
db.query(Workflow).filter(and_(Workflow.id == wid, Workflow.state != "P")).update({ db.query(cls).filter(and_(cls.id == wid, cls.state != "P")).update({
"state": 'F', "state": 'F',
"result": result, "result": result,
"last_time": datetime.now().strftime('%Y-%m-%d %H:%M:%S') "last_time": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
}) })
return True return True
@staticmethod @classmethod
@async_db_update @async_db_update
async def async_fail(db: AsyncSession, wid: int, result: str): async def async_fail(cls, db: AsyncSession, wid: int, result: str):
from sqlalchemy import update from sqlalchemy import update
await db.execute(update(Workflow).where( await db.execute(update(cls).where(
and_(Workflow.id == wid, Workflow.state != "P") and_(cls.id == wid, cls.state != "P")
).values( ).values(
state='F', state='F',
result=result, result=result,
@@ -180,73 +180,73 @@ class Workflow(Base):
)) ))
return True return True
@staticmethod @classmethod
@db_update @db_update
def success(db, wid: int, result: Optional[str] = None): def success(cls, db, wid: int, result: Optional[str] = None):
db.query(Workflow).filter(and_(Workflow.id == wid, Workflow.state != "P")).update({ db.query(cls).filter(and_(cls.id == wid, cls.state != "P")).update({
"state": 'S', "state": 'S',
"result": result, "result": result,
"run_count": Workflow.run_count + 1, "run_count": cls.run_count + 1,
"last_time": datetime.now().strftime('%Y-%m-%d %H:%M:%S') "last_time": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
}) })
return True return True
@staticmethod @classmethod
@async_db_update @async_db_update
async def async_success(db: AsyncSession, wid: int, result: Optional[str] = None): async def async_success(cls, db: AsyncSession, wid: int, result: Optional[str] = None):
from sqlalchemy import update from sqlalchemy import update
await db.execute(update(Workflow).where( await db.execute(update(cls).where(
and_(Workflow.id == wid, Workflow.state != "P") and_(cls.id == wid, cls.state != "P")
).values( ).values(
state='S', state='S',
result=result, result=result,
run_count=Workflow.run_count + 1, run_count=cls.run_count + 1,
last_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S') last_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')
)) ))
return True return True
@staticmethod @classmethod
@db_update @db_update
def reset(db, wid: int, reset_count: Optional[bool] = False): def reset(cls, db, wid: int, reset_count: Optional[bool] = False):
db.query(Workflow).filter(Workflow.id == wid).update({ db.query(cls).filter(cls.id == wid).update({
"state": 'W', "state": 'W',
"result": None, "result": None,
"current_action": None, "current_action": None,
"run_count": 0 if reset_count else Workflow.run_count, "run_count": 0 if reset_count else cls.run_count,
}) })
return True return True
@staticmethod @classmethod
@async_db_update @async_db_update
async def async_reset(db: AsyncSession, wid: int, reset_count: Optional[bool] = False): async def async_reset(cls, db: AsyncSession, wid: int, reset_count: Optional[bool] = False):
from sqlalchemy import update from sqlalchemy import update
await db.execute(update(Workflow).where(Workflow.id == wid).values( await db.execute(update(cls).where(cls.id == wid).values(
state='W', state='W',
result=None, result=None,
current_action=None, current_action=None,
run_count=0 if reset_count else Workflow.run_count, run_count=0 if reset_count else cls.run_count,
)) ))
return True return True
@staticmethod @classmethod
@db_update @db_update
def update_current_action(db, wid: int, action_id: str, context: dict): def update_current_action(cls, db, wid: int, action_id: str, context: dict):
db.query(Workflow).filter(Workflow.id == wid).update({ db.query(cls).filter(cls.id == wid).update({
"current_action": Workflow.current_action + f",{action_id}" if Workflow.current_action else action_id, "current_action": cls.current_action + f",{action_id}" if cls.current_action else action_id,
"context": context "context": context
}) })
return True return True
@staticmethod @classmethod
@async_db_update @async_db_update
async def async_update_current_action(db: AsyncSession, wid: int, action_id: str, context: dict): async def async_update_current_action(cls, db: AsyncSession, wid: int, action_id: str, context: dict):
from sqlalchemy import update from sqlalchemy import update
# 先获取当前current_action # 先获取当前current_action
result = await db.execute(select(Workflow.current_action).where(Workflow.id == wid)) result = await db.execute(select(cls.current_action).where(cls.id == wid))
current_action = result.scalar() current_action = result.scalar()
new_current_action = current_action + f",{action_id}" if current_action else action_id new_current_action = current_action + f",{action_id}" if current_action else action_id
await db.execute(update(Workflow).where(Workflow.id == wid).values( await db.execute(update(cls).where(cls.id == wid).values(
current_action=new_current_action, current_action=new_current_action,
context=context context=context
)) ))

View File

@@ -1,42 +0,0 @@
from typing import Optional
from app.db import DbOper
from app.db.models.userrequest import UserRequest
class UserRequestOper(DbOper):
"""
用户请求管理
"""
def get_need_approve(self) -> Optional[UserRequest]:
"""
获取待审批申请
"""
return UserRequest.get_by_status(self._db, 0)
def get_my_requests(self, username: str) -> Optional[UserRequest]:
"""
获取我的申请
"""
return UserRequest.get_by_req_user(self._db, username)
def approve(self, rid: int) -> bool:
"""
审批申请
"""
user_request = UserRequest.get(self._db, rid)
if user_request:
user_request.update(self._db, {"status": 1})
return True
return False
def deny(self, rid: int) -> bool:
"""
拒绝申请
"""
user_request = UserRequest.get(self._db, rid)
if user_request:
user_request.update(self._db, {"status": 2})
return True
return False