From 0c8fd5121a54c6888f655648fe0ca09ac38111c1 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Fri, 1 Aug 2025 14:19:34 +0800 Subject: [PATCH] fix async apis --- app/api/endpoints/dashboard.py | 11 ++- app/api/endpoints/download.py | 2 +- app/api/endpoints/history.py | 2 +- app/api/endpoints/mediaserver.py | 6 +- app/api/endpoints/message.py | 18 ++-- app/api/endpoints/plugin.py | 43 ++++---- app/api/endpoints/site.py | 162 +++++++++++++++---------------- app/api/endpoints/storage.py | 4 +- app/api/endpoints/subscribe.py | 14 +-- app/api/endpoints/system.py | 24 ++--- app/api/endpoints/transfer.py | 3 +- app/api/endpoints/user.py | 18 ++-- app/chain/__init__.py | 13 +++ app/chain/site.py | 10 +- app/chain/subscribe.py | 6 +- app/chain/system.py | 3 +- app/chain/torrents.py | 16 +++ app/db/models/message.py | 13 ++- app/db/models/site.py | 15 ++- app/db/models/siteuserdata.py | 41 +++++++- app/db/models/systemconfig.py | 11 ++- app/db/models/transferhistory.py | 15 +++ app/db/site_oper.py | 6 ++ app/db/systemconfig_oper.py | 27 ++++++ app/helper/browser.py | 1 - app/helper/cookie.py | 128 ++++++++++++------------ app/modules/indexer/__init__.py | 14 +++ app/scheduler.py | 57 +++++++---- 28 files changed, 427 insertions(+), 256 deletions(-) diff --git a/app/api/endpoints/dashboard.py b/app/api/endpoints/dashboard.py index 07b018a1..8fbdf6ac 100644 --- a/app/api/endpoints/dashboard.py +++ b/app/api/endpoints/dashboard.py @@ -111,7 +111,7 @@ def downloader2(_: Annotated[str, Depends(verify_apitoken)]) -> Any: @router.get("/schedule", summary="后台服务", response_model=List[schemas.ScheduleInfo]) -def schedule(_: schemas.TokenPayload = Depends(verify_token)) -> Any: +async def schedule(_: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 查询后台服务信息 """ @@ -119,7 +119,7 @@ def schedule(_: schemas.TokenPayload = Depends(verify_token)) -> Any: @router.get("/schedule2", summary="后台服务(API_TOKEN)", response_model=List[schemas.ScheduleInfo]) -def schedule2(_: Annotated[str, Depends(verify_apitoken)]) -> Any: +async def schedule2(_: Annotated[str, Depends(verify_apitoken)]) -> Any: """ 查询下载器信息 API_TOKEN认证(?token=xxx) """ @@ -127,12 +127,13 @@ def schedule2(_: Annotated[str, Depends(verify_apitoken)]) -> Any: @router.get("/transfer", summary="文件整理统计", response_model=List[int]) -def transfer(days: Optional[int] = 7, db: Session = Depends(get_db), - _: schemas.TokenPayload = Depends(verify_token)) -> Any: +async def transfer(days: Optional[int] = 7, + db: Session = Depends(get_db), + _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 查询文件整理统计信息 """ - transfer_stat = TransferHistory.statistic(db, days) + transfer_stat = await TransferHistory.async_statistic(db, days) return [stat[1] for stat in transfer_stat] diff --git a/app/api/endpoints/download.py b/app/api/endpoints/download.py index 2c644461..0e5ea889 100644 --- a/app/api/endpoints/download.py +++ b/app/api/endpoints/download.py @@ -116,7 +116,7 @@ def stop(hashString: str, name: Optional[str] = None, @router.get("/clients", summary="查询可用下载器", response_model=List[dict]) -def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any: +async def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 查询可用下载器 """ diff --git a/app/api/endpoints/history.py b/app/api/endpoints/history.py index 3cc59d3f..8f969f3a 100644 --- a/app/api/endpoints/history.py +++ b/app/api/endpoints/history.py @@ -80,7 +80,7 @@ def delete_transfer_history(history_in: schemas.TransferHistory, deletesrc: Optional[bool] = False, deletedest: Optional[bool] = False, db: Session = Depends(get_db), - _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: + _: User = Depends(get_current_active_superuser)) -> Any: """ 删除整理记录 """ diff --git a/app/api/endpoints/mediaserver.py b/app/api/endpoints/mediaserver.py index de90248d..92db747e 100644 --- a/app/api/endpoints/mediaserver.py +++ b/app/api/endpoints/mediaserver.py @@ -1,7 +1,7 @@ from typing import Any, List, Dict, Optional from fastapi import APIRouter, Depends -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession from app import schemas from app.chain.download import DownloadChain @@ -48,7 +48,7 @@ async def exists_local(title: Optional[str] = None, mtype: Optional[str] = None, tmdbid: Optional[int] = None, season: Optional[int] = None, - db: Session = Depends(get_async_db), + db: AsyncSession = Depends(get_async_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 判断本地是否存在 @@ -148,7 +148,7 @@ def library(server: str, hidden: Optional[bool] = False, @router.get("/clients", summary="查询可用媒体服务器", response_model=List[dict]) -def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any: +async def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 查询可用媒体服务器 """ diff --git a/app/api/endpoints/message.py b/app/api/endpoints/message.py index 184b3d2a..45116520 100644 --- a/app/api/endpoints/message.py +++ b/app/api/endpoints/message.py @@ -3,14 +3,14 @@ from typing import Union, Any, List, Optional from fastapi import APIRouter, BackgroundTasks, Depends, Request from pywebpush import WebPushException, webpush -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession from starlette.responses import PlainTextResponse from app import schemas from app.chain.message import MessageChain from app.core.config import settings, global_vars from app.core.security import verify_token, verify_apitoken -from app.db import get_db +from app.db import get_async_db from app.db.models import User from app.db.models.message import Message from app.db.user_oper import get_current_active_superuser @@ -58,15 +58,15 @@ def web_message(text: str, current_user: User = Depends(get_current_active_super @router.get("/web", summary="获取WEB消息", response_model=List[dict]) -def get_web_message(_: schemas.TokenPayload = Depends(verify_token), - db: Session = Depends(get_db), - page: Optional[int] = 1, - count: Optional[int] = 20): +async def get_web_message(_: schemas.TokenPayload = Depends(verify_token), + db: AsyncSession = Depends(get_async_db), + page: Optional[int] = 1, + count: Optional[int] = 20): """ 获取WEB消息列表 """ ret_messages = [] - messages = Message.list_by_page(db, page=page, count=count) + messages = Message.async_list_by_page(db, page=page, count=count) for message in messages: try: ret_messages.append(message.to_dict()) @@ -106,7 +106,7 @@ def wechat_verify(echostr: str, msg_signature: str, timestamp: Union[str, int], return str(err) -def vocechat_verify() -> Any: +async def vocechat_verify() -> Any: """ VoceChat验证响应 """ @@ -128,7 +128,7 @@ def incoming_verify(token: Optional[str] = None, echostr: Optional[str] = None, @router.post("/webpush/subscribe", summary="客户端webpush通知订阅", response_model=schemas.Response) -def subscribe(subscription: schemas.Subscription, _: schemas.TokenPayload = Depends(verify_token)): +async def subscribe(subscription: schemas.Subscription, _: schemas.TokenPayload = Depends(verify_token)): """ 客户端webpush通知订阅 """ diff --git a/app/api/endpoints/plugin.py b/app/api/endpoints/plugin.py index dd11f9c9..078e4333 100644 --- a/app/api/endpoints/plugin.py +++ b/app/api/endpoints/plugin.py @@ -13,6 +13,7 @@ from app.command import Command from app.core.config import settings from app.core.plugin import PluginManager from app.core.security import verify_apikey, verify_token +from app.db.models import User from app.db.systemconfig_oper import SystemConfigOper from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async from app.factory import app @@ -138,7 +139,7 @@ def register_plugin(plugin_id: str): @router.get("/", summary="所有插件", response_model=List[schemas.Plugin]) -async def all_plugins(_: schemas.TokenPayload = Depends(get_current_active_superuser_async), +async def all_plugins(_: User = Depends(get_current_active_superuser_async), state: Optional[str] = "all", force: bool = False) -> List[schemas.Plugin]: """ 查询所有插件清单,包括本地插件和在线插件,插件状态:installed, market, all @@ -187,7 +188,7 @@ async def all_plugins(_: schemas.TokenPayload = Depends(get_current_active_super @router.get("/installed", summary="已安装插件", response_model=List[str]) -def installed(_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: +async def installed(_: User = Depends(get_current_active_superuser_async)) -> Any: """ 查询用户已安装插件清单 """ @@ -203,7 +204,7 @@ async def statistic(_: schemas.TokenPayload = Depends(verify_token)) -> Any: @router.get("/reload/{plugin_id}", summary="重新加载插件", response_model=schemas.Response) -def reload_plugin(plugin_id: str, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: +def reload_plugin(plugin_id: str, _: User = Depends(get_current_active_superuser)) -> Any: """ 重新加载插件 """ @@ -218,7 +219,7 @@ def reload_plugin(plugin_id: str, _: schemas.TokenPayload = Depends(get_current_ def install(plugin_id: str, repo_url: Optional[str] = "", force: Optional[bool] = False, - _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: + _: User = Depends(get_current_active_superuser)) -> Any: """ 安装插件 """ @@ -260,7 +261,7 @@ def remotes(token: str) -> Any: @router.get("/form/{plugin_id}", summary="获取插件表单页面") def plugin_form(plugin_id: str, - _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> dict: + _: User = Depends(get_current_active_superuser)) -> dict: """ 根据插件ID获取插件配置表单或Vue组件URL """ @@ -284,7 +285,7 @@ def plugin_form(plugin_id: str, @router.get("/page/{plugin_id}", summary="获取插件数据页面") -def plugin_page(plugin_id: str, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> dict: +def plugin_page(plugin_id: str, _: User = Depends(get_current_active_superuser)) -> dict: """ 根据插件ID获取插件数据页面 """ @@ -333,7 +334,7 @@ def plugin_dashboard(plugin_id: str, user_agent: Annotated[str | None, Header()] @router.get("/reset/{plugin_id}", summary="重置插件配置及数据", response_model=schemas.Response) def reset_plugin(plugin_id: str, - _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: + _: User = Depends(get_current_active_superuser)) -> Any: """ 根据插件ID重置插件配置及数据 """ @@ -394,7 +395,7 @@ async def plugin_static_file(plugin_id: str, filepath: str): @router.get("/folders", summary="获取插件文件夹配置", response_model=dict) -def get_plugin_folders(_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> dict: +async def get_plugin_folders(_: User = Depends(get_current_active_superuser_async)) -> dict: """ 获取插件文件夹分组配置 """ @@ -407,7 +408,7 @@ def get_plugin_folders(_: schemas.TokenPayload = Depends(get_current_active_supe @router.post("/folders", summary="保存插件文件夹配置", response_model=schemas.Response) -def save_plugin_folders(folders: dict, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: +async def save_plugin_folders(folders: dict, _: User = Depends(get_current_active_superuser_async)) -> Any: """ 保存插件文件夹分组配置 """ @@ -420,7 +421,8 @@ def save_plugin_folders(folders: dict, _: schemas.TokenPayload = Depends(get_cur @router.post("/folders/{folder_name}", summary="创建插件文件夹", response_model=schemas.Response) -def create_plugin_folder(folder_name: str, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: +async def create_plugin_folder(folder_name: str, + _: User = Depends(get_current_active_superuser_async)) -> Any: """ 创建新的插件文件夹 """ @@ -434,34 +436,35 @@ def create_plugin_folder(folder_name: str, _: schemas.TokenPayload = Depends(get @router.delete("/folders/{folder_name}", summary="删除插件文件夹", response_model=schemas.Response) -def delete_plugin_folder(folder_name: str, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: +async def delete_plugin_folder(folder_name: str, + _: User = Depends(get_current_active_superuser_async)) -> Any: """ 删除插件文件夹 """ folders = SystemConfigOper().get(SystemConfigKey.PluginFolders) or {} if folder_name in folders: del folders[folder_name] - SystemConfigOper().set(SystemConfigKey.PluginFolders, folders) + await SystemConfigOper().async_set(SystemConfigKey.PluginFolders, folders) return schemas.Response(success=True, message=f"文件夹 '{folder_name}' 删除成功") else: return schemas.Response(success=False, message=f"文件夹 '{folder_name}' 不存在") @router.put("/folders/{folder_name}/plugins", summary="更新文件夹中的插件", response_model=schemas.Response) -def update_folder_plugins(folder_name: str, plugin_ids: List[str], - _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: +async def update_folder_plugins(folder_name: str, plugin_ids: List[str], + _: User = Depends(get_current_active_superuser_async)) -> Any: """ 更新指定文件夹中的插件列表 """ folders = SystemConfigOper().get(SystemConfigKey.PluginFolders) or {} folders[folder_name] = plugin_ids - SystemConfigOper().set(SystemConfigKey.PluginFolders, folders) + await SystemConfigOper().async_set(SystemConfigKey.PluginFolders, folders) return schemas.Response(success=True, message=f"文件夹 '{folder_name}' 中的插件已更新") @router.get("/{plugin_id}", summary="获取插件配置") -def plugin_config(plugin_id: str, - _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> dict: +async def plugin_config(plugin_id: str, + _: User = Depends(get_current_active_superuser_async)) -> dict: """ 根据插件ID获取插件配置信息 """ @@ -470,7 +473,7 @@ def plugin_config(plugin_id: str, @router.put("/{plugin_id}", summary="更新插件配置", response_model=schemas.Response) def set_plugin_config(plugin_id: str, conf: dict, - _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: + _: User = Depends(get_current_active_superuser)) -> Any: """ 更新插件配置 """ @@ -486,7 +489,7 @@ def set_plugin_config(plugin_id: str, conf: dict, @router.delete("/{plugin_id}", summary="卸载插件", response_model=schemas.Response) def uninstall_plugin(plugin_id: str, - _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: + _: User = Depends(get_current_active_superuser)) -> Any: """ 卸载插件 """ @@ -527,7 +530,7 @@ def uninstall_plugin(plugin_id: str, @router.post("/clone/{plugin_id}", summary="创建插件分身", response_model=schemas.Response) def clone_plugin(plugin_id: str, clone_data: dict, - _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: + _: User = Depends(get_current_active_superuser)) -> Any: """ 创建插件分身 """ diff --git a/app/api/endpoints/site.py b/app/api/endpoints/site.py index 64851ee0..3c0cf14d 100644 --- a/app/api/endpoints/site.py +++ b/app/api/endpoints/site.py @@ -10,7 +10,7 @@ from app.api.endpoints.plugin import register_plugin_api from app.chain.site import SiteChain from app.chain.torrents import TorrentsChain from app.command import Command -from app.core.event import EventManager +from app.core.event import eventmanager from app.core.plugin import PluginManager from app.core.security import verify_token from app.db import get_db, get_async_db @@ -21,7 +21,7 @@ from app.db.models.sitestatistic import SiteStatistic from app.db.models.siteuserdata import SiteUserData from app.db.site_oper import SiteOper from app.db.systemconfig_oper import SystemConfigOper -from app.db.user_oper import get_current_active_superuser +from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async from app.helper.sites import SitesHelper # noqa from app.scheduler import Scheduler from app.schemas.types import SystemConfigKey, EventType @@ -31,20 +31,20 @@ router = APIRouter() @router.get("/", summary="所有站点", response_model=List[schemas.Site]) -def read_sites(db: Session = Depends(get_db), - _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> List[dict]: +async def read_sites(db: AsyncSession = Depends(get_async_db), + _: User = Depends(get_current_active_superuser)) -> List[dict]: """ 获取站点列表 """ - return Site.list_order_by_pri(db) + return Site.async_list_order_by_pri(db) @router.post("/", summary="新增站点", response_model=schemas.Response) -def add_site( +async def add_site( *, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), site_in: schemas.Site, - _: schemas.TokenPayload = Depends(get_current_active_superuser) + _: User = Depends(get_current_active_superuser) ) -> Any: """ 新增站点 @@ -54,10 +54,10 @@ def add_site( if SitesHelper().auth_level < 2: return schemas.Response(success=False, message="用户未通过认证,无法使用站点功能!") domain = StringUtils.get_url_domain(site_in.url) - site_info = SitesHelper().get_indexer(domain) + site_info = await SitesHelper().async_get_indexer(domain) if not site_info: return schemas.Response(success=False, message="该站点不支持,请检查站点域名是否正确") - if Site.get_by_domain(db, domain): + if await Site.async_get_by_domain(db, domain): return schemas.Response(success=False, message=f"{domain} 站点己存在") # 保存站点信息 site_in.domain = domain @@ -70,39 +70,39 @@ def add_site( site = Site(**site_in.dict()) site.create(db) # 通知站点更新 - EventManager().send_event(EventType.SiteUpdated, { + await eventmanager.async_send_event(EventType.SiteUpdated, { "domain": domain }) return schemas.Response(success=True) @router.put("/", summary="更新站点", response_model=schemas.Response) -def update_site( +async def update_site( *, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), site_in: schemas.Site, - _: schemas.TokenPayload = Depends(get_current_active_superuser) + _: User = Depends(get_current_active_superuser) ) -> Any: """ 更新站点信息 """ - site = Site.get(db, site_in.id) + site = await Site.async_get(db, site_in.id) if not site: return schemas.Response(success=False, message="站点不存在") # 校正地址格式 _scheme, _netloc = StringUtils.get_url_netloc(site_in.url) site_in.url = f"{_scheme}://{_netloc}/" - site.update(db, site_in.dict()) + await site.async_update(db, site_in.dict()) # 通知站点更新 - EventManager().send_event(EventType.SiteUpdated, { + await eventmanager.async_send_event(EventType.SiteUpdated, { "domain": site_in.domain }) return schemas.Response(success=True) @router.get("/cookiecloud", summary="CookieCloud同步", response_model=schemas.Response) -def cookie_cloud_sync(background_tasks: BackgroundTasks, - _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: +async def cookie_cloud_sync(background_tasks: BackgroundTasks, + _: User = Depends(get_current_active_superuser_async)) -> Any: """ 运行CookieCloud同步站点信息 """ @@ -111,7 +111,7 @@ def cookie_cloud_sync(background_tasks: BackgroundTasks, @router.get("/reset", summary="重置站点", response_model=schemas.Response) -def reset(db: Session = Depends(get_db), +def reset(db: AsyncSession = Depends(get_db), _: User = Depends(get_current_active_superuser)) -> Any: """ 清空所有站点数据并重新同步CookieCloud站点信息 @@ -122,25 +122,25 @@ def reset(db: Session = Depends(get_db), # 启动定时服务 Scheduler().start("cookiecloud", manual=True) # 插件站点删除 - EventManager().send_event(EventType.SiteDeleted, - { - "site_id": "*" - }) + eventmanager.send_event(EventType.SiteDeleted, + { + "site_id": "*" + }) return schemas.Response(success=True, message="站点已重置!") @router.post("/priorities", summary="批量更新站点优先级", response_model=schemas.Response) -def update_sites_priority( +async def update_sites_priority( priorities: List[dict], - db: Session = Depends(get_db), - _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: + db: AsyncSession = Depends(get_async_db), + _: User = Depends(get_current_active_superuser_async)) -> Any: """ 批量更新站点优先级 """ for priority in priorities: - site = Site.get(db, priority.get("id")) + site = await Site.async_get(db, priority.get("id")) if site: - site.update(db, {"pri": priority.get("pri")}) + await site.async_update(db, {"pri": priority.get("pri")}) return schemas.Response(success=True) @@ -151,7 +151,7 @@ def update_cookie( password: str, code: Optional[str] = None, db: Session = Depends(get_db), - _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: + _: User = Depends(get_current_active_superuser)) -> Any: """ 使用用户密码更新站点Cookie """ @@ -174,7 +174,7 @@ def update_cookie( def refresh_userdata( site_id: int, db: Session = Depends(get_db), - _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: + _: User = Depends(get_current_active_superuser)) -> Any: """ 刷新站点用户数据 """ @@ -192,34 +192,34 @@ def refresh_userdata( @router.get("/userdata/latest", summary="查询所有站点最新用户数据", response_model=List[schemas.SiteUserData]) -def read_userdata_latest( - db: Session = Depends(get_db), - _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: +async def read_userdata_latest( + db: AsyncSession = Depends(get_async_db), + _: User = Depends(get_current_active_superuser_async)) -> Any: """ 查询所有站点最新用户数据 """ - user_datas = SiteUserData.get_latest(db) + user_datas = await SiteUserData.async_get_latest(db) if not user_datas: return [] return [user_data.to_dict() for user_data in user_datas] @router.get("/userdata/{site_id}", summary="查询某站点用户数据", response_model=schemas.Response) -def read_userdata( +async def read_userdata( site_id: int, workdate: Optional[str] = None, - db: Session = Depends(get_db), - _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: + db: AsyncSession = Depends(get_async_db), + _: User = Depends(get_current_active_superuser_async)) -> Any: """ 查询站点用户数据 """ - site = Site.get(db, site_id) + site = await Site.async_get(db, site_id) if not site: raise HTTPException( status_code=404, detail=f"站点 {site_id} 不存在", ) - user_data = SiteUserData.get_by_domain(db, domain=site.domain, workdate=workdate) + user_data = await SiteUserData.async_get_by_domain(db, domain=site.domain, workdate=workdate) if not user_data: return schemas.Response(success=False, data=[]) return schemas.Response(success=True, data=user_data) @@ -264,19 +264,19 @@ async def site_icon(site_id: int, @router.get("/category/{site_id}", summary="站点分类", response_model=List[schemas.SiteCategory]) -def site_category(site_id: int, - db: Session = Depends(get_db), - _: schemas.TokenPayload = Depends(verify_token)) -> Any: +async def site_category(site_id: int, + db: AsyncSession = Depends(get_async_db), + _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 获取站点分类 """ - site = Site.get(db, site_id) + site = await Site.async_get(db, site_id) if not site: raise HTTPException( status_code=404, detail=f"站点 {site_id} 不存在", ) - indexer = SitesHelper().get_indexer(site.domain) + indexer = await SitesHelper().async_get_indexer(site.domain) if not indexer: raise HTTPException( status_code=404, @@ -294,38 +294,38 @@ def site_category(site_id: int, @router.get("/resource/{site_id}", summary="站点资源", response_model=List[schemas.TorrentInfo]) -def site_resource(site_id: int, - keyword: Optional[str] = None, - cat: Optional[str] = None, - page: Optional[int] = 0, - db: Session = Depends(get_db), - _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: +async def site_resource(site_id: int, + keyword: Optional[str] = None, + cat: Optional[str] = None, + page: Optional[int] = 0, + db: AsyncSession = Depends(get_async_db), + _: User = Depends(get_current_active_superuser_async)) -> Any: """ 浏览站点资源 """ - site = Site.get(db, site_id) + site = await Site.async_get(db, site_id) if not site: raise HTTPException( status_code=404, detail=f"站点 {site_id} 不存在", ) - torrents = TorrentsChain().browse(domain=site.domain, keyword=keyword, cat=cat, page=page) + torrents = await TorrentsChain().async_browse(domain=site.domain, keyword=keyword, cat=cat, page=page) if not torrents: return [] return [torrent.to_dict() for torrent in torrents] @router.get("/domain/{site_url}", summary="站点详情", response_model=schemas.Site) -def read_site_by_domain( +async def read_site_by_domain( site_url: str, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), _: schemas.TokenPayload = Depends(verify_token) ) -> Any: """ 通过域名获取站点信息 """ domain = StringUtils.get_url_domain(site_url) - site = Site.get_by_domain(db, domain) + site = await Site.async_get_by_domain(db, domain) if not site: raise HTTPException( status_code=404, @@ -335,35 +335,35 @@ def read_site_by_domain( @router.get("/statistic/{site_url}", summary="特定站点统计信息", response_model=schemas.SiteStatistic) -def read_statistic_by_domain( +async def read_statistic_by_domain( site_url: str, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), _: schemas.TokenPayload = Depends(verify_token) ) -> Any: """ 通过域名获取站点统计信息 """ domain = StringUtils.get_url_domain(site_url) - sitestatistic = SiteStatistic.get_by_domain(db, domain) + sitestatistic = await SiteStatistic.async_get_by_domain(db, domain) if sitestatistic: return sitestatistic return schemas.SiteStatistic(domain=domain) @router.get("/statistic", summary="所有站点统计信息", response_model=List[schemas.SiteStatistic]) -def read_statistics( - db: Session = Depends(get_db), +async def read_statistics( + db: AsyncSession = Depends(get_async_db), _: schemas.TokenPayload = Depends(verify_token) ) -> Any: """ 获取所有站点统计信息 """ - return SiteStatistic.list(db) + return await SiteStatistic.async_list(db) @router.get("/rss", summary="所有订阅站点", response_model=List[schemas.Site]) -def read_rss_sites(db: Session = Depends(get_db), - _: schemas.TokenPayload = Depends(verify_token)) -> List[dict]: +async def read_rss_sites(db: AsyncSession = Depends(get_async_db), + _: schemas.TokenPayload = Depends(verify_token)) -> List[dict]: """ 获取站点列表 """ @@ -371,7 +371,7 @@ def read_rss_sites(db: Session = Depends(get_db), selected_sites = SystemConfigOper().get(SystemConfigKey.RssSites) or [] # 所有站点 - all_site = Site.list_order_by_pri(db) + all_site = await Site.async_list_order_by_pri(db) if not selected_sites: return all_site @@ -381,7 +381,7 @@ def read_rss_sites(db: Session = Depends(get_db), @router.get("/auth", summary="查询认证站点", response_model=dict) -def read_auth_sites(_: schemas.TokenPayload = Depends(verify_token)) -> dict: +async def read_auth_sites(_: schemas.TokenPayload = Depends(verify_token)) -> dict: """ 获取可认证站点列表 """ @@ -409,12 +409,12 @@ def auth_site( @router.get("/mapping", summary="获取站点域名到名称的映射", response_model=schemas.Response) -def site_mapping(_: User = Depends(get_current_active_superuser)): +async def site_mapping(_: User = Depends(get_current_active_superuser_async)): """ 获取站点域名到名称的映射关系 """ try: - sites = SiteOper().list() + sites = await SiteOper().async_list() mapping = {} for site in sites: mapping[site.domain] = site.name @@ -424,7 +424,7 @@ def site_mapping(_: User = Depends(get_current_active_superuser)): @router.get("/supporting", summary="获取支持的站点列表", response_model=dict) -def support_sites(_: User = Depends(get_current_active_superuser)): +async def support_sites(_: User = Depends(get_current_active_superuser_async)): """ 获取支持的站点列表 """ @@ -432,15 +432,15 @@ def support_sites(_: User = Depends(get_current_active_superuser)): @router.get("/{site_id}", summary="站点详情", response_model=schemas.Site) -def read_site( +async def read_site( site_id: int, - db: Session = Depends(get_db), - _: schemas.TokenPayload = Depends(get_current_active_superuser) + db: AsyncSession = Depends(get_async_db), + _: User = Depends(get_current_active_superuser_async) ) -> Any: """ 通过ID获取站点信息 """ - site = Site.get(db, site_id) + site = await Site.async_get(db, site_id) if not site: raise HTTPException( status_code=404, @@ -450,18 +450,18 @@ def read_site( @router.delete("/{site_id}", summary="删除站点", response_model=schemas.Response) -def delete_site( +async def delete_site( site_id: int, - db: Session = Depends(get_db), - _: User = Depends(get_current_active_superuser) + db: AsyncSession = Depends(get_async_db), + _: User = Depends(get_current_active_superuser_async) ) -> Any: """ 删除站点 """ - Site.delete(db, site_id) + await Site.async_delete(db, site_id) # 插件站点删除 - EventManager().send_event(EventType.SiteDeleted, - { - "site_id": site_id - }) + await eventmanager.async_send_event(EventType.SiteDeleted, + { + "site_id": site_id + }) return schemas.Response(success=True) diff --git a/app/api/endpoints/storage.py b/app/api/endpoints/storage.py index b1e22f46..6e500e16 100644 --- a/app/api/endpoints/storage.py +++ b/app/api/endpoints/storage.py @@ -12,7 +12,7 @@ from app.core.config import settings from app.core.metainfo import MetaInfoPath from app.core.security import verify_token from app.db.models import User -from app.db.user_oper import get_current_active_superuser +from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async from app.helper.progress import ProgressHelper from app.schemas.types import ProgressKey @@ -222,7 +222,7 @@ def usage(name: str, _: User = Depends(get_current_active_superuser)) -> Any: @router.get("/transtype/{name}", summary="支持的整理方式获取", response_model=schemas.StorageTransType) -def transtype(name: str, _: User = Depends(get_current_active_superuser)) -> Any: +async def transtype(name: str, _: User = Depends(get_current_active_superuser_async)) -> Any: """ 查询支持的整理方式 """ diff --git a/app/api/endpoints/subscribe.py b/app/api/endpoints/subscribe.py index 8646475d..fdaac97a 100644 --- a/app/api/endpoints/subscribe.py +++ b/app/api/endpoints/subscribe.py @@ -117,7 +117,7 @@ async def update_subscribe( subscribe_dict["manual_total_episode"] = 1 # 发送订阅调整事件 subscribe = await subscribe.async_get(db, subscribe_in.id) - eventmanager.send_event(EventType.SubscribeModified, { + await eventmanager.async_send_event(EventType.SubscribeModified, { "subscribe_id": subscribe_in.id, "old_subscribe_info": old_subscribe_dict, "subscribe_info": subscribe.to_dict(), @@ -145,7 +145,7 @@ async def update_subscribe_status( "state": state }) # 发送订阅调整事件 - eventmanager.send_event(EventType.SubscribeModified, { + eventmanager.async_send_event(EventType.SubscribeModified, { "subscribe_id": subscribe.id, "old_subscribe_info": old_subscribe_dict, "subscribe_info": subscribe.to_dict(), @@ -224,7 +224,7 @@ async def reset_subscribes( "state": "R" }) # 发送订阅调整事件 - eventmanager.send_event(EventType.SubscribeModified, { + eventmanager.async_send_event(EventType.SubscribeModified, { "subscribe_id": subscribe.id, "old_subscribe_info": old_subscribe_dict, "subscribe_info": subscribe.to_dict(), @@ -313,7 +313,7 @@ async def delete_subscribe_by_mediaid( for subscribe in delete_subscribes: await Subscribe.async_delete(db, subscribe.id) # 发送事件 - eventmanager.send_event(EventType.SubscribeDeleted, { + eventmanager.async_send_event(EventType.SubscribeDeleted, { "subscribe_id": subscribe.id, "subscribe_info": subscribe.to_dict() }) @@ -531,7 +531,7 @@ async def follow_subscriber( subscribers = SystemConfigOper().get(SystemConfigKey.FollowSubscribers) or [] if share_uid and share_uid not in subscribers: subscribers.append(share_uid) - SystemConfigOper().set(SystemConfigKey.FollowSubscribers, subscribers) + await SystemConfigOper().async_set(SystemConfigKey.FollowSubscribers, subscribers) return schemas.Response(success=True) @@ -545,7 +545,7 @@ async def unfollow_subscriber( subscribers = SystemConfigOper().get(SystemConfigKey.FollowSubscribers) or [] if share_uid and share_uid in subscribers: subscribers.remove(share_uid) - SystemConfigOper().set(SystemConfigKey.FollowSubscribers, subscribers) + await SystemConfigOper().async_set(SystemConfigKey.FollowSubscribers, subscribers) return schemas.Response(success=True) @@ -596,7 +596,7 @@ async def delete_subscribe( if subscribe: await Subscribe.async_delete(db, subscribe_id) # 发送事件 - eventmanager.send_event(EventType.SubscribeDeleted, { + eventmanager.async_send_event(EventType.SubscribeDeleted, { "subscribe_id": subscribe_id, "subscribe_info": subscribe.to_dict() }) diff --git a/app/api/endpoints/system.py b/app/api/endpoints/system.py index 06aa501e..d957b615 100644 --- a/app/api/endpoints/system.py +++ b/app/api/endpoints/system.py @@ -23,7 +23,7 @@ from app.core.module import ModuleManager from app.core.security import verify_apitoken, verify_resource_token, verify_token from app.db.models import User from app.db.systemconfig_oper import SystemConfigOper -from app.db.user_oper import get_current_active_superuser +from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async from app.helper.mediaserver import MediaServerHelper from app.helper.message import MessageHelper from app.helper.progress import ProgressHelper @@ -202,7 +202,7 @@ def get_global_setting(token: str): @router.get("/env", summary="查询系统配置", response_model=schemas.Response) -def get_env_setting(_: User = Depends(get_current_active_superuser)): +async def get_env_setting(_: User = Depends(get_current_active_superuser_async)): """ 查询系统环境变量,包括当前版本号(仅管理员) """ @@ -220,8 +220,8 @@ def get_env_setting(_: User = Depends(get_current_active_superuser)): @router.post("/env", summary="更新系统配置", response_model=schemas.Response) -def set_env_setting(env: dict, - _: User = Depends(get_current_active_superuser)): +async def set_env_setting(env: dict, + _: User = Depends(get_current_active_superuser_async)): """ 更新系统环境变量(仅管理员) """ @@ -243,7 +243,7 @@ def set_env_setting(env: dict, if success_updates: for key in success_updates.keys(): # 发送配置变更事件 - eventmanager.send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData( + await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData( key=key, value=getattr(settings, key, None), change_type="update" @@ -280,8 +280,8 @@ async def get_progress(request: Request, process_type: str, _: schemas.TokenPayl @router.get("/setting/{key}", summary="查询系统设置", response_model=schemas.Response) -def get_setting(key: str, - _: User = Depends(get_current_active_superuser)): +async def get_setting(key: str, + _: User = Depends(get_current_active_superuser_async)): """ 查询系统设置(仅管理员) """ @@ -295,10 +295,10 @@ def get_setting(key: str, @router.post("/setting/{key}", summary="更新系统设置", response_model=schemas.Response) -def set_setting( +async def set_setting( key: str, value: Annotated[Union[list, dict, bool, int, str] | None, Body()] = None, - _: User = Depends(get_current_active_superuser), + _: User = Depends(get_current_active_superuser_async), ): """ 更新系统设置(仅管理员) @@ -307,7 +307,7 @@ def set_setting( success, message = settings.update_setting(key=key, value=value) if success: # 发送配置变更事件 - eventmanager.send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData( + await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData( key=key, value=value, change_type="update" @@ -319,10 +319,10 @@ def set_setting( if isinstance(value, list): value = list(filter(None, value)) value = value if value else None - success = SystemConfigOper().set(key, value) + success = await SystemConfigOper().async_set(key, value) if success: # 发送配置变更事件 - eventmanager.send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData( + await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData( key=key, value=value, change_type="update" diff --git a/app/api/endpoints/transfer.py b/app/api/endpoints/transfer.py index 70ec280e..698dc7fb 100644 --- a/app/api/endpoints/transfer.py +++ b/app/api/endpoints/transfer.py @@ -12,6 +12,7 @@ from app.core.config import settings from app.core.metainfo import MetaInfoPath from app.core.security import verify_token, verify_apitoken from app.db import get_db +from app.db.models import User from app.db.models.transferhistory import TransferHistory from app.db.user_oper import get_current_active_superuser from app.helper.directory import DirectoryHelper @@ -81,7 +82,7 @@ def remove_queue(fileitem: schemas.FileItem, _: schemas.TokenPayload = Depends(v def manual_transfer(transer_item: ManualTransferItem, background: Optional[bool] = False, db: Session = Depends(get_db), - _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: + _: User = Depends(get_current_active_superuser)) -> Any: """ 手动转移,文件或历史记录,支持自定义剧集识别格式 :param transer_item: 手工整理项 diff --git a/app/api/endpoints/user.py b/app/api/endpoints/user.py index 098cac93..a4133aba 100644 --- a/app/api/endpoints/user.py +++ b/app/api/endpoints/user.py @@ -25,7 +25,7 @@ async def list_users( """ 查询用户列表 """ - return await User.async_list(db) + return await current_user.async_list(db) @router.post("/", summary="新增用户", response_model=schemas.Response) @@ -38,7 +38,7 @@ async def create_user( """ 新增用户 """ - user = await User.async_get_by_name(db, name=user_in.name) + user = await current_user.async_get_by_name(db, name=user_in.name) if user: return schemas.Response(success=False, message="用户已存在") user_info = user_in.dict() @@ -68,12 +68,12 @@ async def update_user( message="密码需要同时包含字母、数字、特殊字符中的至少两项,且长度大于6位") user_info["hashed_password"] = get_password_hash(user_info["password"]) user_info.pop("password") - user = await User.async_get_by_id(db, user_id=user_info["id"]) + user = await current_user.async_get_by_id(db, user_id=user_info["id"]) user_name = user_info.get("name") if not user_name: return schemas.Response(success=False, message="用户名不能为空") # 新用户名去重 - users = await User.async_list(db) + users = await current_user.async_list(db) for u in users: if u.name == user_name and u.id != user_info["id"]: return schemas.Response(success=False, message="用户名已被使用") @@ -185,10 +185,10 @@ async def delete_user_by_id( """ 通过唯一ID删除用户 """ - user = await User.async_get_by_id(db, user_id=user_id) + user = await current_user.async_get_by_id(db, user_id=user_id) if not user: return schemas.Response(success=False, message="用户不存在") - await User.async_delete(db, user_id) + await current_user.async_delete(db, user_id) return schemas.Response(success=True) @@ -202,10 +202,10 @@ async def delete_user_by_name( """ 通过用户名删除用户 """ - user = await User.async_get_by_name(db, name=user_name) + user = await current_user.async_get_by_name(db, name=user_name) if not user: return schemas.Response(success=False, message="用户不存在") - await User.async_delete(db, user.id) + await current_user.async_delete(db, user.id) return schemas.Response(success=True) @@ -218,7 +218,7 @@ async def read_user_by_name( """ 查询用户详情 """ - user = await User.async_get_by_name(db, name=username) + user = await current_user.async_get_by_name(db, name=username) if not user: raise HTTPException( status_code=404, diff --git a/app/chain/__init__.py b/app/chain/__init__.py index 53e959b2..66e5a93a 100644 --- a/app/chain/__init__.py +++ b/app/chain/__init__.py @@ -657,6 +657,19 @@ class ChainBase(metaclass=ABCMeta): """ return self.run_module("refresh_torrents", site=site, keyword=keyword, cat=cat, page=page) + async def async_refresh_torrents(self, site: dict, keyword: Optional[str] = None, + cat: Optional[str] = None, page: Optional[int] = 0) -> List[TorrentInfo]: + """ + 异步获取站点最新一页的种子,多个站点需要多线程处理 + :param site: 站点 + :param keyword: 标题 + :param cat: 分类 + :param page: 页码 + :reutrn: 种子资源列表 + """ + return await self.async_run_module("async_refresh_torrents", + site=site, keyword=keyword, cat=cat, page=page) + def filter_torrents(self, rule_groups: List[str], torrent_list: List[TorrentInfo], mediainfo: MediaInfo = None) -> List[TorrentInfo]: diff --git a/app/chain/site.py b/app/chain/site.py index 257564f7..0ce99ab8 100644 --- a/app/chain/site.py +++ b/app/chain/site.py @@ -4,12 +4,11 @@ from datetime import datetime from typing import Optional, Tuple, Union, Dict from urllib.parse import urljoin -from app.helper.sites import SitesHelper # noqa from lxml import etree from app.chain import ChainBase from app.core.config import global_vars, settings -from app.core.event import Event, EventManager, eventmanager +from app.core.event import Event, eventmanager from app.db.models.site import Site from app.db.site_oper import SiteOper from app.db.systemconfig_oper import SystemConfigOper @@ -18,6 +17,7 @@ from app.helper.cloudflare import under_challenge from app.helper.cookie import CookieHelper from app.helper.cookiecloud import CookieCloudHelper from app.helper.rss import RssHelper +from app.helper.sites import SitesHelper # noqa from app.log import logger from app.schemas import MessageChannel, Notification, SiteUserData from app.schemas.types import EventType, NotificationType @@ -58,7 +58,7 @@ class SiteChain(ChainBase): name=site.get("name"), payload=userdata.dict()) # 发送事件 - EventManager().send_event(EventType.SiteRefreshed, { + eventmanager.send_event(EventType.SiteRefreshed, { "site_id": site.get("id") }) # 发送站点消息 @@ -103,7 +103,7 @@ class SiteChain(ChainBase): any_site_updated = True result[site.get("name")] = userdata if any_site_updated: - EventManager().send_event(EventType.SiteRefreshed, { + eventmanager.send_event(EventType.SiteRefreshed, { "site_id": "*" }) @@ -415,7 +415,7 @@ class SiteChain(ChainBase): # 通知站点更新 if indexer: - EventManager().send_event(EventType.SiteUpdated, { + eventmanager.send_event(EventType.SiteUpdated, { "domain": domain, }) # 处理完成 diff --git a/app/chain/subscribe.py b/app/chain/subscribe.py index 6528dd80..53ce22c8 100644 --- a/app/chain/subscribe.py +++ b/app/chain/subscribe.py @@ -15,7 +15,7 @@ from app.chain.tmdb import TmdbChain from app.chain.torrents import TorrentsChain from app.core.config import settings, global_vars from app.core.context import TorrentInfo, Context, MediaInfo -from app.core.event import eventmanager, Event, EventManager +from app.core.event import eventmanager, Event from app.core.meta import MetaBase from app.core.meta.words import WordsMatcher from app.core.metainfo import MetaInfo @@ -237,7 +237,7 @@ class SubscribeChain(ChainBase): username=username ) # 发送事件 - EventManager().send_event(EventType.SubscribeAdded, { + eventmanager.send_event(EventType.SubscribeAdded, { "subscribe_id": sid, "username": username, "mediainfo": mediainfo.to_dict(), @@ -1090,7 +1090,7 @@ class SubscribeChain(ChainBase): username=subscribe.username ) # 发送事件 - EventManager().send_event(EventType.SubscribeComplete, { + eventmanager.send_event(EventType.SubscribeComplete, { "subscribe_id": subscribe.id, "subscribe_info": subscribe.to_dict(), "mediainfo": mediainfo.to_dict(), diff --git a/app/chain/system.py b/app/chain/system.py index 9685decc..874e3012 100644 --- a/app/chain/system.py +++ b/app/chain/system.py @@ -7,12 +7,11 @@ from typing import Union, Optional from app.chain import ChainBase from app.core.config import settings from app.core.plugin import PluginManager +from app.helper.system import SystemHelper from app.log import logger from app.schemas import Notification, MessageChannel from app.utils.http import RequestUtils from app.utils.system import SystemUtils -from app.helper.system import SystemHelper -from app.helper.plugin import PluginHelper from version import FRONTEND_VERSION, APP_VERSION diff --git a/app/chain/torrents.py b/app/chain/torrents.py index 01188a89..4b89b6a6 100644 --- a/app/chain/torrents.py +++ b/app/chain/torrents.py @@ -85,6 +85,22 @@ class TorrentsChain(ChainBase): return [] return self.refresh_torrents(site=site, keyword=keyword, cat=cat, page=page) + async def async_browse(self, domain: str, keyword: Optional[str] = None, cat: Optional[str] = None, + page: Optional[int] = 0) -> List[TorrentInfo]: + """ + 异步浏览站点首页内容,返回种子清单,TTL缓存5分钟 + :param domain: 站点域名 + :param keyword: 搜索标题 + :param cat: 搜索分类 + :param page: 页码 + """ + logger.info(f'开始获取站点 {domain} 最新种子 ...') + site = await SitesHelper().async_get_indexer(domain) + if not site: + logger.error(f'站点 {domain} 不存在!') + return [] + return await self.async_refresh_torrents(site=site, keyword=keyword, cat=cat, page=page) + def rss(self, domain: str) -> List[TorrentInfo]: """ 获取站点RSS内容,返回种子清单,TTL缓存3分钟 diff --git a/app/db/models/message.py b/app/db/models/message.py index 55841272..33678672 100644 --- a/app/db/models/message.py +++ b/app/db/models/message.py @@ -1,9 +1,10 @@ from typing import Optional -from sqlalchemy import Column, Integer, String, Sequence, JSON +from sqlalchemy import Column, Integer, String, Sequence, JSON, select +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session -from app.db import db_query, Base +from app.db import db_query, Base, async_db_query class Message(Base): @@ -38,3 +39,11 @@ class Message(Base): @db_query def list_by_page(cls, db: Session, page: Optional[int] = 1, count: Optional[int] = 30): return db.query(cls).order_by(cls.reg_time.desc()).offset((page - 1) * count).limit(count).all() + + @classmethod + @async_db_query + async def async_list_by_page(cls, db: AsyncSession, page: Optional[int] = 1, count: Optional[int] = 30): + result = await db.execute( + select(cls).order_by(cls.reg_time.desc()).offset((page - 1) * count).limit(count) + ) + return result.scalars().all() diff --git a/app/db/models/site.py b/app/db/models/site.py index edef9604..d6d856a1 100644 --- a/app/db/models/site.py +++ b/app/db/models/site.py @@ -1,10 +1,10 @@ from datetime import datetime -from sqlalchemy import Boolean, Column, Integer, String, Sequence, JSON, select +from sqlalchemy import Boolean, Column, Integer, String, Sequence, JSON, select, delete from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session -from app.db import db_query, db_update, Base, async_db_query +from app.db import db_query, db_update, Base, async_db_query, async_db_update class Site(Base): @@ -82,6 +82,12 @@ class Site(Base): def list_order_by_pri(cls, db: Session): return db.query(cls).order_by(cls.pri).all() + @classmethod + @async_db_query + async def async_list_order_by_pri(cls, db: AsyncSession): + result = await db.execute(select(cls).order_by(cls.pri)) + return result.scalars().all() + @classmethod @db_query def get_domains_by_ids(cls, db: Session, ids: list): @@ -91,3 +97,8 @@ class Site(Base): @db_update def reset(cls, db: Session): db.query(cls).delete() + + @classmethod + @async_db_update + async def async_reset(cls, db: AsyncSession): + await db.execute(delete(cls)) diff --git a/app/db/models/siteuserdata.py b/app/db/models/siteuserdata.py index b1f5b46b..a18d356a 100644 --- a/app/db/models/siteuserdata.py +++ b/app/db/models/siteuserdata.py @@ -1,10 +1,11 @@ from datetime import datetime from typing import Optional -from sqlalchemy import Column, Integer, String, Sequence, Float, JSON, func, or_ +from sqlalchemy import Column, Integer, String, Sequence, Float, JSON, func, or_, select +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session -from app.db import db_query, Base +from app.db import db_query, Base, async_db_query class SiteUserData(Base): @@ -65,6 +66,17 @@ class SiteUserData(Base): cls.updated_day == workdate).all() return db.query(cls).filter(cls.domain == domain).all() + @classmethod + @async_db_query + async def async_get_by_domain(cls, db: AsyncSession, domain: str, workdate: Optional[str] = None, worktime: Optional[str] = None): + query = select(cls).filter(cls.domain == domain) + if workdate and worktime: + query = query.filter(cls.updated_day == workdate, cls.updated_time == worktime) + elif workdate: + query = query.filter(cls.updated_day == workdate) + result = await db.execute(query) + return result.scalars().all() + @classmethod @db_query def get_by_date(cls, db: Session, date: str): @@ -92,3 +104,28 @@ class SiteUserData(Base): (cls.domain == subquery.c.domain) & (cls.updated_day == subquery.c.latest_update_day) ).order_by(cls.updated_time.desc()).all() + + @classmethod + @async_db_query + async def async_get_latest(cls, db: AsyncSession): + """ + 异步获取各站点最新一天的数据 + """ + subquery = ( + select( + cls.domain, + func.max(cls.updated_day).label('latest_update_day') + ) + .group_by(cls.domain) + .filter(or_(cls.err_msg.is_(None), cls.err_msg == "")) + .subquery() + ) + + # 主查询:按 domain 和 updated_day 获取最新的记录 + result = await db.execute( + select(cls).join( + subquery, + (cls.domain == subquery.c.domain) & + (cls.updated_day == subquery.c.latest_update_day) + ).order_by(cls.updated_time.desc())) + return result.scalars().all() diff --git a/app/db/models/systemconfig.py b/app/db/models/systemconfig.py index 254c725d..5ecaa039 100644 --- a/app/db/models/systemconfig.py +++ b/app/db/models/systemconfig.py @@ -1,7 +1,8 @@ -from sqlalchemy import Column, Integer, String, Sequence, JSON +from sqlalchemy import Column, Integer, String, Sequence, JSON, select +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session -from app.db import db_query, db_update, Base +from app.db import db_query, db_update, Base, async_db_query class SystemConfig(Base): @@ -19,6 +20,12 @@ class SystemConfig(Base): def get_by_key(cls, db: Session, key: str): return db.query(cls).filter(cls.key == key).first() + @classmethod + @async_db_query + async def async_get_by_key(cls, db: AsyncSession, key: str): + result = await db.execute(select(cls).where(cls.key == key)) + return result.scalar_one_or_none() + @db_update def delete_by_key(self, db: Session, key: str): systemconfig = self.get_by_key(db, key) diff --git a/app/db/models/transferhistory.py b/app/db/models/transferhistory.py index fdba6267..7f8c9588 100644 --- a/app/db/models/transferhistory.py +++ b/app/db/models/transferhistory.py @@ -173,6 +173,21 @@ class TransferHistory(Base): time.localtime(time.time() - 86400 * days))).subquery() return db.query(sub_query.c.date, func.count(sub_query.c.id)).group_by(sub_query.c.date).all() + @classmethod + @async_db_query + async def async_statistic(cls, db: AsyncSession, days: Optional[int] = 7): + """ + 统计最近days天的下载历史数量,按日期分组返回每日数量 + """ + sub_query = select(func.substr(cls.date, 1, 10).label('date'), + cls.id.label('id')).filter( + cls.date >= time.strftime("%Y-%m-%d %H:%M:%S", + time.localtime(time.time() - 86400 * days))).subquery() + result = await db.execute( + select(sub_query.c.date, func.count(sub_query.c.id)).group_by(sub_query.c.date) + ) + return result.scalars().all() + @classmethod @db_query def count(cls, db: Session, status: bool = None): diff --git a/app/db/site_oper.py b/app/db/site_oper.py index 723c0789..9149b733 100644 --- a/app/db/site_oper.py +++ b/app/db/site_oper.py @@ -35,6 +35,12 @@ class SiteOper(DbOper): """ return Site.list(self._db) + async def async_list(self) -> List[Site]: + """ + 异步获取站点列表 + """ + return await Site.async_list(self._db) + def list_order_by_pri(self) -> List[Site]: """ 获取站点列表 diff --git a/app/db/systemconfig_oper.py b/app/db/systemconfig_oper.py index 4fdc92b5..47d27ae2 100644 --- a/app/db/systemconfig_oper.py +++ b/app/db/systemconfig_oper.py @@ -47,6 +47,33 @@ class SystemConfigOper(DbOper, metaclass=Singleton): conf.create(self._db) return True + async def async_set(self, key: Union[str, SystemConfigKey], value: Any) -> Optional[bool]: + """ + 异步设置系统设置 + :param key: 配置键 + :param value: 配置值 + :return: 是否设置成功(True 成功/False 失败/None 无需更新) + """ + if isinstance(key, SystemConfigKey): + key = key.value + # 旧值 + old_value = self.__SYSTEMCONF.get(key) + # 更新内存(deepcopy避免内存共享) + self.__SYSTEMCONF[key] = copy.deepcopy(value) + conf = await SystemConfig.async_get_by_key(self._db, key) + if conf: + if old_value != value: + if value: + conf.update(self._db, {"value": value}) + else: + conf.delete(self._db, conf.id) + return True + return None + else: + conf = SystemConfig(key=key, value=value) + await conf.async_create(self._db) + return True + def get(self, key: Union[str, SystemConfigKey] = None) -> Any: """ 获取系统设置 diff --git a/app/helper/browser.py b/app/helper/browser.py index 7c36ce5a..dc598849 100644 --- a/app/helper/browser.py +++ b/app/helper/browser.py @@ -60,7 +60,6 @@ class PlaywrightHelper: except Exception as e: logger.error(f"网页操作失败: {str(e)}") finally: - # 确保资源被正确清理 if page: page.close() if context: diff --git a/app/helper/cookie.py b/app/helper/cookie.py index 07bc536f..f57104d1 100644 --- a/app/helper/cookie.py +++ b/app/helper/cookie.py @@ -144,53 +144,50 @@ class CookieHelper: break if not submit_xpath: return None, None, "未找到登录按钮" - finally: - if html is not None: - del html - # 点击登录按钮 - try: - # 等待登录按钮准备好 - page.wait_for_selector(submit_xpath) - # 输入用户名 - page.fill(username_xpath, username) - # 输入密码 - page.fill(password_xpath, password) - # 输入二步验证码 - if twostep_xpath: - page.fill(twostep_xpath, otp_code) - # 识别验证码 - if captcha_xpath and captcha_img_url: - captcha_element = page.query_selector(captcha_xpath) - if captcha_element.is_visible(): - # 验证码图片地址 - code_url = self.__get_captcha_url(url, captcha_img_url) - # 获取当前的cookie和ua - cookie = self.parse_cookies(page.context.cookies()) - ua = page.evaluate("() => window.navigator.userAgent") - # 自动OCR识别验证码 - captcha = self.__get_captcha_text(cookie=cookie, ua=ua, code_url=code_url) - if captcha: - logger.info("验证码地址为:%s,识别结果:%s" % (code_url, captcha)) - else: - return None, None, "验证码识别失败" - # 输入验证码 - captcha_element.fill(captcha) - else: - # 不可见元素不处理 - pass # 点击登录按钮 - page.click(submit_xpath) - page.wait_for_load_state("networkidle", timeout=30 * 1000) - except Exception as e: - logger.error(f"仿真登录失败:{str(e)}") - return None, None, f"仿真登录失败:{str(e)}" - # 对于某二次验证码为单页面的站点,输入二次验证码 - if "verify" in page.url: - if not otp_code: - return None, None, "需要二次验证码" - html = etree.HTML(page.content()) try: + # 等待登录按钮准备好 + page.wait_for_selector(submit_xpath) + # 输入用户名 + page.fill(username_xpath, username) + # 输入密码 + page.fill(password_xpath, password) + # 输入二步验证码 + if twostep_xpath: + page.fill(twostep_xpath, otp_code) + # 识别验证码 + if captcha_xpath and captcha_img_url: + captcha_element = page.query_selector(captcha_xpath) + if captcha_element.is_visible(): + # 验证码图片地址 + code_url = self.__get_captcha_url(url, captcha_img_url) + # 获取当前的cookie和ua + cookie = self.parse_cookies(page.context.cookies()) + ua = page.evaluate("() => window.navigator.userAgent") + # 自动OCR识别验证码 + captcha = self.__get_captcha_text(cookie=cookie, ua=ua, code_url=code_url) + if captcha: + logger.info("验证码地址为:%s,识别结果:%s" % (code_url, captcha)) + else: + return None, None, "验证码识别失败" + # 输入验证码 + captcha_element.fill(captcha) + else: + # 不可见元素不处理 + pass + # 点击登录按钮 + page.click(submit_xpath) + page.wait_for_load_state("networkidle", timeout=30 * 1000) + except Exception as e: + logger.error(f"仿真登录失败:{str(e)}") + return None, None, f"仿真登录失败:{str(e)}" + + # 对于某二次验证码为单页面的站点,输入二次验证码 + if "verify" in page.url: + if not otp_code: + return None, None, "需要二次验证码" + html = etree.HTML(page.content()) for xpath in self._SITE_LOGIN_XPATH.get("twostep"): if html.xpath(xpath): try: @@ -204,28 +201,29 @@ class CookieHelper: logger.error(f"二次验证码输入失败:{str(e)}") return None, None, f"二次验证码输入失败:{str(e)}" break - finally: - if html is not None: - del html - # 登录后的源码 - html_text = page.content() - if not html_text: - return None, None, "获取网页源码失败" - if SiteUtils.is_logged_in(html_text): - return self.parse_cookies(page.context.cookies()), \ - page.evaluate("() => window.navigator.userAgent"), "" - else: - # 读取错误信息 - error_xpath = None - for xpath in self._SITE_LOGIN_XPATH.get("error"): - if html.xpath(xpath): - error_xpath = xpath - break - if not error_xpath: - return None, None, "登录失败" + + # 登录后的源码 + html_text = page.content() + if not html_text: + return None, None, "获取网页源码失败" + if SiteUtils.is_logged_in(html_text): + return self.parse_cookies(page.context.cookies()), \ + page.evaluate("() => window.navigator.userAgent"), "" else: - error_msg = html.xpath(error_xpath)[0] - return None, None, error_msg + # 读取错误信息 + error_xpath = None + for xpath in self._SITE_LOGIN_XPATH.get("error"): + if html.xpath(xpath): + error_xpath = xpath + break + if not error_xpath: + return None, None, "登录失败" + else: + error_msg = html.xpath(error_xpath)[0] + return None, None, error_msg + finally: + if html: + del html if not url or not username or not password: return None, None, "参数错误" diff --git a/app/modules/indexer/__init__.py b/app/modules/indexer/__init__.py index a25ea5da..5a528b25 100644 --- a/app/modules/indexer/__init__.py +++ b/app/modules/indexer/__init__.py @@ -457,6 +457,20 @@ class IndexerModule(_ModuleBase): """ return self.search_torrents(site=site, keywords=[keyword], cat=cat, page=page) + async def async_refresh_torrents(self, site: dict, + keyword: Optional[str] = None, + cat: Optional[str] = None, + page: Optional[int] = 0) -> Optional[List[TorrentInfo]]: + """ + 异步获取站点最新一页的种子,多个站点需要多线程处理 + :param site: 站点 + :param keyword: 关键字 + :param cat: 分类 + :param page: 页码 + :reutrn: 种子资源列表 + """ + return await self.async_search_torrents(site=site, keywords=[keyword], cat=cat, page=page) + def refresh_userdata(self, site: dict) -> Optional[SiteUserData]: """ 刷新站点的用户数据 diff --git a/app/scheduler.py b/app/scheduler.py index 478f36d7..1b8e6ce4 100644 --- a/app/scheduler.py +++ b/app/scheduler.py @@ -21,6 +21,7 @@ from app.core.config import settings from app.core.event import EventManager, eventmanager, Event from app.core.plugin import PluginManager from app.db.systemconfig_oper import SystemConfigOper +from app.helper.message import MessageHelper from app.helper.sites import SitesHelper # noqa from app.helper.wallpaper import WallpaperHelper from app.log import logger @@ -380,46 +381,60 @@ class Scheduler(metaclass=Singleton): # 启动定时服务 self._scheduler.start() + def __prepare_job(self, job_id: str) -> Optional[dict]: + """ + 准备定时任务 + """ + with self._lock: + job = self._jobs.get(job_id) + if not job: + return None + if job.get("running"): + logger.warning(f"定时任务 {job_id} - {job.get("name")} 正在运行 ...") + return None + self._jobs[job_id]["running"] = True + return job + + def __finish_job(self, job_id: str): + """ + 完成定时任务 + """ + with self._lock: + try: + self._jobs[job_id]["running"] = False + except KeyError: + pass + def start(self, job_id: str, *args, **kwargs): """ 启动定时服务 """ - # 处理job_id格式 - with self._lock: - job = self._jobs.get(job_id) - if not job: - return - job_name = job.get("name") - if job.get("running"): - logger.warning(f"定时任务 {job_id} - {job_name} 正在运行 ...") - return - self._jobs[job_id]["running"] = True + # 获取定时任务 + job = self.__prepare_job(job_id) + if not job: + return # 开始运行 try: if not kwargs: kwargs = job.get("kwargs") or {} job["func"](*args, **kwargs) except Exception as e: - logger.error(f"定时任务 {job_name} 执行失败:{str(e)} - {traceback.format_exc()}") - SchedulerChain().messagehelper.put(title=f"{job_name} 执行失败", - message=str(e), - role="system") - EventManager().send_event( + logger.error(f"定时任务 {job.get('name')} 执行失败:{str(e)} - {traceback.format_exc()}") + MessageHelper().put(title=f"{job.get('name')} 执行失败", + message=str(e), + role="system") + eventmanager.send_event( EventType.SystemError, { "type": "scheduler", "scheduler_id": job_id, - "scheduler_name": job_name, + "scheduler_name": job.get('name'), "error": str(e), "traceback": traceback.format_exc() } ) # 运行结束 - with self._lock: - try: - self._jobs[job_id]["running"] = False - except KeyError: - pass + self.__finish_job(job_id) def init_plugin_jobs(self): """