diff --git a/app/agent/tools/impl/scrape_metadata.py b/app/agent/tools/impl/scrape_metadata.py index 085d37cd..3d547c5a 100644 --- a/app/agent/tools/impl/scrape_metadata.py +++ b/app/agent/tools/impl/scrape_metadata.py @@ -8,7 +8,7 @@ from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool from app.chain.media import MediaChain -from app.core.config import GlobalVar +from app.core.config import global_vars from app.core.metainfo import MetaInfoPath from app.log import logger from app.schemas import FileItem @@ -17,9 +17,12 @@ from app.schemas import FileItem class ScrapeMetadataInput(BaseModel): """刮削媒体元数据工具的输入参数模型""" explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - path: str = Field(..., description="Path to the file or directory to scrape metadata for (e.g., '/path/to/file.mkv' or '/path/to/directory')") - storage: Optional[str] = Field("local", description="Storage type: 'local' for local storage, 'smb', 'alist', etc. for remote storage (default: 'local')") - overwrite: Optional[bool] = Field(False, description="Whether to overwrite existing metadata files (default: False)") + path: str = Field(..., + description="Path to the file or directory to scrape metadata for (e.g., '/path/to/file.mkv' or '/path/to/directory')") + storage: Optional[str] = Field("local", + description="Storage type: 'local' for local storage, 'smb', 'alist', etc. for remote storage (default: 'local')") + overwrite: Optional[bool] = Field(False, + description="Whether to overwrite existing metadata files (default: False)") class ScrapeMetadataTool(MoviePilotTool): @@ -32,19 +35,19 @@ class ScrapeMetadataTool(MoviePilotTool): path = kwargs.get("path", "") storage = kwargs.get("storage", "local") overwrite = kwargs.get("overwrite", False) - + message = f"正在刮削媒体元数据: {path}" if storage != "local": message += f" [存储: {storage}]" if overwrite: message += " [覆盖模式]" - + return message async def run(self, path: str, storage: Optional[str] = "local", overwrite: Optional[bool] = False, **kwargs) -> str: logger.info(f"执行工具: {self.name}, 参数: path={path}, storage={storage}, overwrite={overwrite}") - + try: # 验证路径 if not path: @@ -52,14 +55,14 @@ class ScrapeMetadataTool(MoviePilotTool): "success": False, "message": "刮削路径不能为空" }, ensure_ascii=False) - + # 创建 FileItem fileitem = FileItem( storage=storage, path=path, type="file" if Path(path).suffix else "dir" ) - + # 检查本地存储路径是否存在 if storage == "local": scrape_path = Path(path) @@ -68,22 +71,22 @@ class ScrapeMetadataTool(MoviePilotTool): "success": False, "message": f"刮削路径不存在: {path}" }, ensure_ascii=False) - + # 识别媒体信息 media_chain = MediaChain() scrape_path = Path(path) meta = MetaInfoPath(scrape_path) mediainfo = await media_chain.async_recognize_by_meta(meta) - + if not mediainfo: return json.dumps({ "success": False, "message": f"刮削失败,无法识别媒体信息: {path}", "path": path }, ensure_ascii=False) - + # 在线程池中执行同步的刮削操作 - await GlobalVar.CURRENT_EVENT_LOOP.run_in_executor( + await global_vars.loop.run_in_executor( None, lambda: media_chain.scrape_metadata( fileitem=fileitem, @@ -92,7 +95,7 @@ class ScrapeMetadataTool(MoviePilotTool): overwrite=overwrite ) ) - + return json.dumps({ "success": True, "message": f"{path} 刮削完成", @@ -105,7 +108,7 @@ class ScrapeMetadataTool(MoviePilotTool): "season": mediainfo.season } }, ensure_ascii=False, indent=2) - + except Exception as e: error_message = f"刮削媒体元数据失败: {str(e)}" logger.error(f"刮削媒体元数据失败: {e}", exc_info=True) @@ -114,4 +117,3 @@ class ScrapeMetadataTool(MoviePilotTool): "message": error_message, "path": path }, ensure_ascii=False) - diff --git a/app/agent/tools/impl/search_subscribe.py b/app/agent/tools/impl/search_subscribe.py index 9e601c16..ebde39cb 100644 --- a/app/agent/tools/impl/search_subscribe.py +++ b/app/agent/tools/impl/search_subscribe.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool from app.chain.subscribe import SubscribeChain -from app.core.config import GlobalVar +from app.core.config import global_vars from app.db.subscribe_oper import SubscribeOper from app.log import logger @@ -30,28 +30,29 @@ class SearchSubscribeTool(MoviePilotTool): """根据搜索参数生成友好的提示消息""" subscribe_id = kwargs.get("subscribe_id") manual = kwargs.get("manual", False) - + message = f"正在搜索订阅 #{subscribe_id} 的缺失剧集" if manual: message += "(手动搜索)" - + return message - async def run(self, subscribe_id: int, manual: Optional[bool] = False, + async def run(self, subscribe_id: int, manual: Optional[bool] = False, filter_groups: Optional[List[str]] = None, **kwargs) -> str: - logger.info(f"执行工具: {self.name}, 参数: subscribe_id={subscribe_id}, manual={manual}, filter_groups={filter_groups}") - + logger.info( + f"执行工具: {self.name}, 参数: subscribe_id={subscribe_id}, manual={manual}, filter_groups={filter_groups}") + try: # 先验证订阅是否存在 subscribe_oper = SubscribeOper() subscribe = subscribe_oper.get(subscribe_id) - + if not subscribe: return json.dumps({ "success": False, "message": f"订阅不存在: {subscribe_id}" }, ensure_ascii=False) - + # 获取订阅信息用于返回 subscribe_info = { "id": subscribe.id, @@ -65,7 +66,7 @@ class SearchSubscribeTool(MoviePilotTool): "tmdbid": subscribe.tmdbid, "doubanid": subscribe.doubanid } - + # 检查订阅状态 if subscribe.state == "S": return json.dumps({ @@ -73,19 +74,19 @@ class SearchSubscribeTool(MoviePilotTool): "message": f"订阅 #{subscribe_id} ({subscribe.name}) 已暂停,无法搜索", "subscribe": subscribe_info }, ensure_ascii=False) - + # 如果提供了 filter_groups 参数,先更新订阅的规则组 if filter_groups is not None: subscribe_oper.update(subscribe_id, {"filter_groups": filter_groups}) logger.info(f"更新订阅 #{subscribe_id} 的规则组为: {filter_groups}") - + # 调用 SubscribeChain 的 search 方法 # search 方法是同步的,需要在异步环境中运行 subscribe_chain = SubscribeChain() - + # 在线程池中执行同步的搜索操作 # 当 sid 有值时,state 参数会被忽略,直接处理该订阅 - await GlobalVar.CURRENT_EVENT_LOOP.run_in_executor( + await global_vars.loop.run_in_executor( None, lambda: subscribe_chain.search( sid=subscribe_id, @@ -93,7 +94,7 @@ class SearchSubscribeTool(MoviePilotTool): manual=manual ) ) - + # 重新获取订阅信息以获取更新后的状态 updated_subscribe = subscribe_oper.get(subscribe_id) if updated_subscribe: @@ -103,19 +104,19 @@ class SearchSubscribeTool(MoviePilotTool): "last_update": updated_subscribe.last_update, "filter_groups": updated_subscribe.filter_groups }) - + # 如果提供了规则组,会在返回信息中显示 result = { "success": True, "message": f"订阅 #{subscribe_id} ({subscribe.name}) 搜索完成", "subscribe": subscribe_info } - + if filter_groups is not None: result["message"] += f"(已应用规则组: {', '.join(filter_groups)})" - + return json.dumps(result, ensure_ascii=False, indent=2) - + except Exception as e: error_message = f"搜索订阅缺失剧集失败: {str(e)}" logger.error(f"搜索订阅缺失剧集失败: {e}", exc_info=True) @@ -124,4 +125,3 @@ class SearchSubscribeTool(MoviePilotTool): "message": error_message, "subscribe_id": subscribe_id }, ensure_ascii=False) - diff --git a/app/chain/message.py b/app/chain/message.py index add67730..d138a3bf 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -1,4 +1,3 @@ -import asyncio import re import time from datetime import datetime, timedelta @@ -10,7 +9,7 @@ from app.chain.download import DownloadChain from app.chain.media import MediaChain from app.chain.search import SearchChain from app.chain.subscribe import SubscribeChain -from app.core.config import settings, GlobalVar +from app.core.config import settings, global_vars from app.core.context import MediaInfo, Context from app.core.meta import MetaBase from app.db.user_oper import UserOper @@ -174,7 +173,7 @@ class MessageChain(ChainBase): elif text.startswith('/ai') or text.startswith('/AI'): # AI智能体处理 self._handle_ai_message(text=text, channel=channel, source=source, - userid=userid, username=username) + userid=userid, username=username) elif text.startswith('/'): # 执行命令 self.eventmanager.send_event( @@ -329,7 +328,8 @@ class MessageChain(ChainBase): else: best_version = True # 转换用户名 - mp_name = UserOper().get_name(**{f"{channel.name.lower()}_userid": userid}) if channel else None + mp_name = UserOper().get_name( + **{f"{channel.name.lower()}_userid": userid}) if channel else None # 添加订阅,状态为N SubscribeChain().add(title=mediainfo.title, year=mediainfo.year, @@ -505,7 +505,8 @@ class MessageChain(ChainBase): # 开始搜索 if not medias: self.post_message(Notification( - channel=channel, source=source, title=f"{meta.name} 没有找到对应的媒体信息!", userid=userid)) + channel=channel, source=source, title=f"{meta.name} 没有找到对应的媒体信息!", + userid=userid)) return logger.info(f"搜索到 {len(medias)} 条相关媒体信息") try: @@ -835,21 +836,22 @@ class MessageChain(ChainBase): 如果用户上次会话在15分钟内,则复用相同的会话ID;否则创建新的会话ID """ current_time = datetime.now() - + # 检查用户是否有已存在的会话 if userid in MessageChain._user_sessions: session_id, last_time = MessageChain._user_sessions[userid] - + # 计算时间差 time_diff = current_time - last_time - + # 如果时间差小于等于15分钟,复用会话ID if time_diff <= timedelta(minutes=MessageChain._session_timeout_minutes): # 更新最后使用时间 MessageChain._user_sessions[userid] = (session_id, current_time) - logger.info(f"复用会话ID: {session_id}, 用户: {userid}, 距离上次会话: {time_diff.total_seconds() / 60:.1f}分钟") + logger.info( + f"复用会话ID: {session_id}, 用户: {userid}, 距离上次会话: {time_diff.total_seconds() / 60:.1f}分钟") return session_id - + # 创建新的会话ID new_session_id = f"user_{userid}_{int(time.time())}" MessageChain._user_sessions[userid] = (new_session_id, current_time) @@ -877,11 +879,11 @@ class MessageChain(ChainBase): if userid in MessageChain._user_sessions: session_id, _ = MessageChain._user_sessions.pop(userid) logger.info(f"已清除用户 {userid} 的会话: {session_id}") - + # 如果有会话ID,同时清除智能体的会话记忆 if session_id: try: - GlobalVar.CURRENT_EVENT_LOOP.run_until_complete( + global_vars.loop.run_until_complete( agent_manager.clear_session( session_id=session_id, user_id=str(userid) @@ -889,7 +891,7 @@ class MessageChain(ChainBase): ) except Exception as e: logger.warning(f"清除智能体会话记忆失败: {e}") - + self.post_message(Notification( channel=channel, source=source, @@ -905,7 +907,7 @@ class MessageChain(ChainBase): )) def _handle_ai_message(self, text: str, channel: MessageChannel, source: str, - userid: Union[str, int], username: str) -> None: + userid: Union[str, int], username: str) -> None: """ 处理AI智能体消息 """ @@ -946,9 +948,9 @@ class MessageChain(ChainBase): # 生成或复用会话ID session_id = self._get_or_create_session_id(userid) - + # 在事件循环中处理 - GlobalVar.CURRENT_EVENT_LOOP.run_until_complete( + global_vars.loop.run_until_complete( agent_manager.process_message( session_id=session_id, user_id=str(userid), @@ -962,4 +964,3 @@ class MessageChain(ChainBase): except Exception as e: logger.error(f"处理AI智能体消息失败: {e}") self.messagehelper.put(f"AI智能体处理失败: {str(e)}", role="system", title="MoviePilot助手") - diff --git a/app/core/config.py b/app/core/config.py index 0f57536a..52a49318 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -920,6 +920,19 @@ class GlobalVar(object): return True return False + @property + def loop(self) -> AbstractEventLoop: + """ + 当前循环 + """ + return self.CURRENT_EVENT_LOOP + + def set_loop(self, loop: AbstractEventLoop): + """ + 设置循环 + """ + self.CURRENT_EVENT_LOOP = loop + # 全局标识 global_vars = GlobalVar() diff --git a/app/core/event.py b/app/core/event.py index a4daeb66..c78b5a5a 100644 --- a/app/core/event.py +++ b/app/core/event.py @@ -11,7 +11,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union, Any from fastapi.concurrency import run_in_threadpool -from app.core.config import GlobalVar +from app.core.config import global_vars from app.helper.thread import ThreadHelper from app.log import logger from app.schemas import ChainEventData @@ -453,7 +453,7 @@ class EventManager(metaclass=Singleton): # 对于异步函数,直接在事件循环中运行 asyncio.run_coroutine_threadsafe( self.__safe_invoke_handler_async(handler, isolated_event), - GlobalVar.CURRENT_EVENT_LOOP + global_vars.loop ) else: # 对于同步函数,在线程池中运行 diff --git a/app/scheduler.py b/app/scheduler.py index 3c6552b5..b8c58e0c 100644 --- a/app/scheduler.py +++ b/app/scheduler.py @@ -21,7 +21,7 @@ from app.chain.site import SiteChain from app.chain.subscribe import SubscribeChain from app.chain.transfer import TransferChain from app.chain.workflow import WorkflowChain -from app.core.config import settings, GlobalVar +from app.core.config import settings, global_vars from app.core.event import eventmanager, Event from app.core.plugin import PluginManager from app.db.systemconfig_oper import SystemConfigOper @@ -474,7 +474,7 @@ class Scheduler(metaclass=SingletonClass): """ 启动协程 """ - return asyncio.run_coroutine_threadsafe(coro, GlobalVar.CURRENT_EVENT_LOOP) + return asyncio.run_coroutine_threadsafe(coro, global_vars.loop) # 获取定时任务 job = self.__prepare_job(job_id) diff --git a/app/startup/lifecycle.py b/app/startup/lifecycle.py index 8609c073..59f26e8d 100644 --- a/app/startup/lifecycle.py +++ b/app/startup/lifecycle.py @@ -4,6 +4,7 @@ from contextlib import asynccontextmanager from fastapi import FastAPI from app.chain.system import SystemChain +from app.core.config import global_vars from app.helper.system import SystemHelper from app.startup.command_initializer import init_command, stop_command, restart_command from app.startup.modules_initializer import init_modules, stop_modules @@ -35,6 +36,8 @@ async def lifespan(app: FastAPI): 定义应用的生命周期事件 """ print("Starting up...") + # 存储当前循环 + global_vars.set_loop(asyncio.get_event_loop()) # 初始化路由 init_routers(app) # 初始化模块 diff --git a/app/startup/plugins_initializer.py b/app/startup/plugins_initializer.py index 36ee1d61..61a45fbf 100644 --- a/app/startup/plugins_initializer.py +++ b/app/startup/plugins_initializer.py @@ -1,4 +1,4 @@ -from app.core.config import GlobalVar +from app.core.config import global_vars from app.core.plugin import PluginManager from app.log import logger @@ -8,7 +8,7 @@ async def sync_plugins() -> bool: 初始化安装插件,并动态注册后台任务及API """ try: - loop = GlobalVar.CURRENT_EVENT_LOOP + loop = global_vars.loop plugin_manager = PluginManager() sync_result = await execute_task(loop, plugin_manager.sync, "插件同步到本地")