diff --git a/app/chain/cookiecloud.py b/app/chain/cookiecloud.py index 298efc92..f4d65086 100644 --- a/app/chain/cookiecloud.py +++ b/app/chain/cookiecloud.py @@ -6,8 +6,8 @@ from lxml import etree from app.chain import ChainBase from app.core.config import settings -from app.db.siteicons import SiteIcons -from app.db.sites import Sites +from app.db.siteicon_oper import SiteIconOper +from app.db.site_oper import SiteOper from app.helper.cookiecloud import CookieCloudHelper from app.helper.sites import SitesHelper from app.log import logger @@ -21,8 +21,8 @@ class CookieCloudChain(ChainBase): def __init__(self): super().__init__() - self.sites = Sites() - self.siteicons = SiteIcons() + self.siteoper = SiteOper() + self.siteiconoper = SiteIconOper() self.siteshelper = SitesHelper() self.cookiecloud = CookieCloudHelper( server=settings.COOKIECLOUD_HOST, @@ -45,16 +45,16 @@ class CookieCloudChain(ChainBase): for domain, cookie in cookies.items(): # 获取站点信息 indexer = self.siteshelper.get_indexer(domain) - if self.sites.exists(domain): + if self.siteoper.exists(domain): # 更新站点Cookie - self.sites.update_cookie(domain=domain, cookies=cookie) + self.siteoper.update_cookie(domain=domain, cookies=cookie) _update_count += 1 elif indexer: # 新增站点 - self.sites.add(name=indexer.get("name"), - url=indexer.get("domain"), - domain=domain, - cookie=cookie) + self.siteoper.add(name=indexer.get("name"), + url=indexer.get("domain"), + domain=domain, + cookie=cookie) _add_count += 1 # 保存站点图标 if indexer: @@ -62,10 +62,10 @@ class CookieCloudChain(ChainBase): cookie=cookie, ua=settings.USER_AGENT) if icon_url: - self.siteicons.update_icon(name=indexer.get("name"), - domain=domain, - icon_url=icon_url, - icon_base64=icon_base64) + self.siteiconoper.update_icon(name=indexer.get("name"), + domain=domain, + icon_url=icon_url, + icon_base64=icon_base64) # 处理完成 ret_msg = f"更新了{_update_count}个站点,新增了{_add_count}个站点" logger.info(f"CookieCloud同步成功:{ret_msg}") diff --git a/app/chain/site_manage.py b/app/chain/site_manage.py index 3ca6e607..19c95128 100644 --- a/app/chain/site_manage.py +++ b/app/chain/site_manage.py @@ -1,5 +1,5 @@ from app.chain import ChainBase -from app.db.sites import Sites +from app.db.site_oper import SiteOper class SiteManageChain(ChainBase): @@ -7,17 +7,17 @@ class SiteManageChain(ChainBase): 站点远程管理处理链 """ - _sites: Sites = None + _sites: SiteOper = None def __init__(self): super().__init__() - self._sites = Sites() + self._siteoper = SiteOper() def process(self): """ 查询所有站点,发送消息 """ - site_list = self._sites.list() + site_list = self._siteoper.list() if not site_list: self.post_message(title="没有维护任何站点信息!") title = f"共有 {len(site_list)} 个站点,回复 `/site_disable` `[id]` 禁用站点,回复 `/site_enable` `[id]` 启用站点:" @@ -44,12 +44,12 @@ class SiteManageChain(ChainBase): if not arg_str.isdigit(): return site_id = int(arg_str) - site = self._sites.get(site_id) + site = self._siteoper.get(site_id) if not site: self.post_message(title=f"站点编号 {site_id} 不存在!") return # 禁用站点 - self._sites.update(site_id, { + self._siteoper.update(site_id, { "is_active": False }) # 重新发送消息 @@ -65,12 +65,12 @@ class SiteManageChain(ChainBase): if not arg_str.isdigit(): return site_id = int(arg_str) - site = self._sites.get(site_id) + site = self._siteoper.get(site_id) if not site: self.post_message(title=f"站点编号 {site_id} 不存在!") return # 禁用站点 - self._sites.update(site_id, { + self._siteoper.update(site_id, { "is_active": True }) # 重新发送消息 diff --git a/app/chain/subscribe.py b/app/chain/subscribe.py index 8bf92c94..bc7dd4d6 100644 --- a/app/chain/subscribe.py +++ b/app/chain/subscribe.py @@ -6,7 +6,7 @@ from app.chain.search import SearchChain from app.core.metainfo import MetaInfo from app.core.context import TorrentInfo, Context, MediaInfo from app.core.config import settings -from app.db.subscribes import Subscribes +from app.db.subscribe_oper import SubscribeOper from app.helper.sites import SitesHelper from app.log import logger from app.schemas.context import NotExistMediaInfo @@ -26,7 +26,7 @@ class SubscribeChain(ChainBase): super().__init__() self.downloadchain = DownloadChain() self.searchchain = SearchChain() - self.subscribes = Subscribes() + self.subscribehelper = SubscribeOper() self.siteshelper = SitesHelper() def process(self, title: str, year: str, @@ -89,7 +89,7 @@ class SubscribeChain(ChainBase): 'lack_episode': kwargs.get('total_episode') }) # 添加订阅 - sid, err_msg = self.subscribes.add(mediainfo, season=season, **kwargs) + sid, err_msg = self.subscribehelper.add(mediainfo, season=season, **kwargs) if not sid: logger.error(f'{mediainfo.title_year} {err_msg}') # 发回原用户 @@ -115,15 +115,15 @@ class SubscribeChain(ChainBase): :return: 更新订阅状态为R或删除订阅 """ if sid: - subscribes = [self.subscribes.get(sid)] + subscribes = [self.subscribehelper.get(sid)] else: - subscribes = self.subscribes.list(state) + subscribes = self.subscribehelper.list(state) # 遍历订阅 for subscribe in subscribes: logger.info(f'开始搜索订阅,标题:{subscribe.name} ...') # 如果状态为N则更新为R if subscribe.state == 'N': - self.subscribes.update(subscribe.id, {'state': 'R'}) + self.subscribehelper.update(subscribe.id, {'state': 'R'}) # 生成元数据 meta = MetaInfo(subscribe.name) meta.year = subscribe.year @@ -138,7 +138,7 @@ class SubscribeChain(ChainBase): exist_flag, no_exists = self.downloadchain.get_no_exists_info(meta=meta, mediainfo=mediainfo) if exist_flag: logger.info(f'{mediainfo.title_year} 媒体库中已存在,完成订阅') - self.subscribes.delete(subscribe.id) + self.subscribehelper.delete(subscribe.id) # 发送通知 self.post_message(title=f'{mediainfo.title_year}{meta.season} 已完成订阅', image=mediainfo.get_message_image()) @@ -165,7 +165,7 @@ class SubscribeChain(ChainBase): if downloads and not lefts: # 全部下载完成 logger.info(f'{mediainfo.title_year} 下载完成,完成订阅') - self.subscribes.delete(subscribe.id) + self.subscribehelper.delete(subscribe.id) # 发送通知 self.post_message(title=f'{mediainfo.title_year}{meta.season} 已完成订阅', image=mediainfo.get_message_image()) @@ -224,7 +224,7 @@ class SubscribeChain(ChainBase): 从缓存中匹配订阅,并自动下载 """ # 所有订阅 - subscribes = self.subscribes.list('R') + subscribes = self.subscribehelper.list('R') # 遍历订阅 for subscribe in subscribes: logger.info(f'开始匹配订阅,标题:{subscribe.name} ...') @@ -242,7 +242,7 @@ class SubscribeChain(ChainBase): exist_flag, no_exists = self.downloadchain.get_no_exists_info(meta=meta, mediainfo=mediainfo) if exist_flag: logger.info(f'{mediainfo.title_year} 媒体库中已存在,完成订阅') - self.subscribes.delete(subscribe.id) + self.subscribehelper.delete(subscribe.id) # 发送通知 self.post_message(title=f'{mediainfo.title_year}{meta.season} 已完成订阅', image=mediainfo.get_message_image()) @@ -278,7 +278,7 @@ class SubscribeChain(ChainBase): if downloads and not lefts: # 全部下载完成 logger.info(f'{mediainfo.title_year} 下载完成,完成订阅') - self.subscribes.delete(subscribe.id) + self.subscribehelper.delete(subscribe.id) # 发送通知 self.post_message(title=f'{mediainfo.title_year}{meta.season} 已完成订阅', image=mediainfo.get_message_image()) @@ -291,7 +291,7 @@ class SubscribeChain(ChainBase): left_episodes = season_info.get('episodes') logger.info(f'{mediainfo.title_year} 季 {season} 未下载完整,' f'更新缺失集数为{len(left_episodes)} ...') - self.subscribes.update(subscribe.id, { + self.subscribehelper.update(subscribe.id, { "lack_episode": len(left_episodes) }) @@ -299,7 +299,7 @@ class SubscribeChain(ChainBase): """ 查询订阅并发送消息 """ - subscribes = self.subscribes.list() + subscribes = self.subscribehelper.list() if not subscribes: self.post_message(title='没有任何订阅!') return @@ -328,12 +328,12 @@ class SubscribeChain(ChainBase): if not arg_str.isdigit(): return subscribe_id = int(arg_str) - subscribe = self.subscribes.get(subscribe_id) + subscribe = self.subscribehelper.get(subscribe_id) if not subscribe: self.post_message(title=f"订阅编号 {subscribe_id} 不存在!") return # 删除订阅 - self.subscribes.delete(subscribe_id) + self.subscribehelper.delete(subscribe_id) # 重新发送消息 self.list() diff --git a/app/core/plugin.py b/app/core/plugin.py index 04f7c4b1..8fea972d 100644 --- a/app/core/plugin.py +++ b/app/core/plugin.py @@ -1,7 +1,7 @@ import traceback from typing import List, Any -from app.db.systemconfigs import SystemConfigs +from app.db.systemconfig_oper import SystemConfigOper from app.helper.module import ModuleHelper from app.log import logger from app.utils.singleton import Singleton @@ -11,7 +11,7 @@ class PluginManager(metaclass=Singleton): """ 插件管理器 """ - systemconfigs: SystemConfigs = None + systemconfigs: SystemConfigOper = None # 插件列表 _plugins: dict = {} @@ -24,7 +24,7 @@ class PluginManager(metaclass=Singleton): self.init_config() def init_config(self): - self.systemconfigs = SystemConfigs() + self.systemconfigs = SystemConfigOper() # 停止已有插件 self.stop() # 启动插件 diff --git a/app/db/__init__.py b/app/db/__init__.py index 54a690cf..44b0c613 100644 --- a/app/db/__init__.py +++ b/app/db/__init__.py @@ -1,5 +1,5 @@ from sqlalchemy import create_engine, QueuePool -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import sessionmaker, Session from app.core.config import settings @@ -27,3 +27,15 @@ def get_db(): finally: if db: db.close() + + +class DbOper: + + _db: Session = None + + def __init__(self, _db=SessionLocal()): + self._db = _db + + def __del__(self): + if self._db: + self._db.close() diff --git a/app/db/models/__init__.py b/app/db/models/__init__.py index 953b90f9..4a6b5842 100644 --- a/app/db/models/__init__.py +++ b/app/db/models/__init__.py @@ -11,7 +11,6 @@ class Base: def create(self, db): db.add(self) db.commit() - db.refresh(self) return self @classmethod diff --git a/app/db/plugindata_oper.py b/app/db/plugindata_oper.py new file mode 100644 index 00000000..1f2da2dd --- /dev/null +++ b/app/db/plugindata_oper.py @@ -0,0 +1,35 @@ +import json +from typing import Any + +from app.db import DbOper +from app.db.models import Base +from app.db.models.plugin import PluginData +from app.utils.object import ObjectUtils + + +class PluginDataOper(DbOper): + """ + 插件数据管理 + """ + + def save(self, plugin_id: str, key: str, value: Any) -> Base: + """ + 保存插件数据 + :param plugin_id: 插件id + :param key: 数据key + :param value: 数据值 + """ + if ObjectUtils.is_obj(value): + value = json.dumps(value) + plugin = PluginData(plugin_id=plugin_id, key=key, value=value) + return plugin.create(self._db) + + def get_data(self, key: str) -> Any: + """ + 获取插件数据 + :param key: 数据key + """ + data = PluginData.get_plugin_data_by_key(self._db, self.__class__.__name__, key) + if ObjectUtils.is_obj(data): + return json.load(data) + return data diff --git a/app/db/sites.py b/app/db/site_oper.py similarity index 91% rename from app/db/sites.py rename to app/db/site_oper.py index 1621c260..67d6a122 100644 --- a/app/db/sites.py +++ b/app/db/site_oper.py @@ -1,19 +1,13 @@ from typing import Tuple, List -from sqlalchemy.orm import Session - -from app.db import SessionLocal +from app.db import DbOper from app.db.models.site import Site -class Sites: +class SiteOper(DbOper): """ 站点管理 """ - _db: Session = None - - def __init__(self, _db=SessionLocal()): - self._db = _db def add(self, **kwargs) -> Tuple[bool, str]: """ diff --git a/app/db/siteicons.py b/app/db/siteicon_oper.py similarity index 83% rename from app/db/siteicons.py rename to app/db/siteicon_oper.py index ca0ba5e3..b0569f01 100644 --- a/app/db/siteicons.py +++ b/app/db/siteicon_oper.py @@ -1,19 +1,13 @@ from typing import List -from sqlalchemy.orm import Session - -from app.db import SessionLocal +from app.db import DbOper from app.db.models.siteicon import SiteIcon -class SiteIcons: +class SiteIconOper(DbOper): """ 站点管理 """ - _db: Session = None - - def __init__(self, _db=SessionLocal()): - self._db = _db def list(self) -> List[SiteIcon]: """ diff --git a/app/db/subscribes.py b/app/db/subscribe_oper.py similarity index 90% rename from app/db/subscribes.py rename to app/db/subscribe_oper.py index 5229830c..32fb873a 100644 --- a/app/db/subscribes.py +++ b/app/db/subscribe_oper.py @@ -1,20 +1,14 @@ from typing import Tuple, List -from sqlalchemy.orm import Session - from app.core.context import MediaInfo -from app.db import SessionLocal +from app.db import DbOper from app.db.models.subscribe import Subscribe -class Subscribes: +class SubscribeOper(DbOper): """ 订阅管理 """ - _db: Session = None - - def __init__(self, _db=SessionLocal()): - self._db = _db def add(self, mediainfo: MediaInfo, **kwargs) -> Tuple[int, str]: """ diff --git a/app/db/systemconfigs.py b/app/db/systemconfig_oper.py similarity index 88% rename from app/db/systemconfigs.py rename to app/db/systemconfig_oper.py index 119c3719..8daa6ac7 100644 --- a/app/db/systemconfigs.py +++ b/app/db/systemconfig_oper.py @@ -1,25 +1,22 @@ import json from typing import Any, Union -from sqlalchemy.orm import Session - -from app.db import SessionLocal +from app.db import DbOper from app.db.models.systemconfig import SystemConfig from app.utils.object import ObjectUtils from app.utils.singleton import Singleton from app.utils.types import SystemConfigKey -class SystemConfigs(metaclass=Singleton): +class SystemConfigOper(DbOper, metaclass=Singleton): # 配置对象 __SYSTEMCONF: dict = {} - _db: Session = None - def __init__(self, _db=SessionLocal()): + def __init__(self): """ 加载配置到内存 """ - self._db = _db + super().__init__() for item in SystemConfig.list(self._db): if ObjectUtils.is_obj(item.value): self.__SYSTEMCONF[item.key] = json.loads(item.value) diff --git a/app/helper/sites.cp310-win_amd64.pyd b/app/helper/sites.cp310-win_amd64.pyd index b7237775..ceafd246 100644 Binary files a/app/helper/sites.cp310-win_amd64.pyd and b/app/helper/sites.cp310-win_amd64.pyd differ diff --git a/app/plugins/__init__.py b/app/plugins/__init__.py index d74aafa3..8e220b85 100644 --- a/app/plugins/__init__.py +++ b/app/plugins/__init__.py @@ -1,15 +1,12 @@ -import json from abc import ABCMeta, abstractmethod from pathlib import Path from typing import Any from app.chain import ChainBase from app.core.config import settings -from app.db import SessionLocal from app.db.models import Base -from app.db.models.plugin import PluginData -from app.db.systemconfigs import SystemConfigs -from app.utils.object import ObjectUtils +from app.db.plugindata_oper import PluginDataOper +from app.db.systemconfig_oper import SystemConfigOper class PluginChian(ChainBase): @@ -39,8 +36,9 @@ class _PluginBase(metaclass=ABCMeta): plugin_desc: str = "" def __init__(self): - self.db = SessionLocal() + self.plugindata = PluginDataOper() self.chain = PluginChian() + self.systemconfig = SystemConfigOper() @abstractmethod def init_plugin(self, config: dict = None): @@ -65,7 +63,7 @@ class _PluginBase(metaclass=ABCMeta): """ if not plugin_id: plugin_id = self.__class__.__name__ - return SystemConfigs().set(f"plugin.{plugin_id}", config) + return self.systemconfig.set(f"plugin.{plugin_id}", config) def get_config(self, plugin_id: str = None) -> Any: """ @@ -74,7 +72,7 @@ class _PluginBase(metaclass=ABCMeta): """ if not plugin_id: plugin_id = self.__class__.__name__ - return SystemConfigs().get(f"plugin.{plugin_id}") + return self.systemconfig.get(f"plugin.{plugin_id}") def get_data_path(self, plugin_id: str = None) -> Path: """ @@ -93,17 +91,11 @@ class _PluginBase(metaclass=ABCMeta): :param key: 数据key :param value: 数据值 """ - if ObjectUtils.is_obj(value): - value = json.dumps(value) - plugin = PluginData(plugin_id=self.__class__.__name__, key=key, value=value) - return plugin.create(self.db) + return self.plugindata.save(self.__class__.__name__, key, value) def get_data(self, key: str) -> Any: """ 获取插件数据 :param key: 数据key """ - data = PluginData.get_plugin_data_by_key(self.db, self.__class__.__name__, key) - if ObjectUtils.is_obj(data): - return json.load(data) - return data + return self.plugindata.get_data(key)