fixx loop

This commit is contained in:
jxxghp
2025-11-20 08:15:37 +08:00
parent 5c983b64bc
commit 48da5c976c
8 changed files with 77 additions and 58 deletions

View File

@@ -8,7 +8,7 @@ from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.chain.media import MediaChain
from app.core.config import GlobalVar
from app.core.config import global_vars
from app.core.metainfo import MetaInfoPath
from app.log import logger
from app.schemas import FileItem
@@ -17,9 +17,12 @@ from app.schemas import FileItem
class ScrapeMetadataInput(BaseModel):
"""刮削媒体元数据工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
path: str = Field(..., description="Path to the file or directory to scrape metadata for (e.g., '/path/to/file.mkv' or '/path/to/directory')")
storage: Optional[str] = Field("local", description="Storage type: 'local' for local storage, 'smb', 'alist', etc. for remote storage (default: 'local')")
overwrite: Optional[bool] = Field(False, description="Whether to overwrite existing metadata files (default: False)")
path: str = Field(...,
description="Path to the file or directory to scrape metadata for (e.g., '/path/to/file.mkv' or '/path/to/directory')")
storage: Optional[str] = Field("local",
description="Storage type: 'local' for local storage, 'smb', 'alist', etc. for remote storage (default: 'local')")
overwrite: Optional[bool] = Field(False,
description="Whether to overwrite existing metadata files (default: False)")
class ScrapeMetadataTool(MoviePilotTool):
@@ -32,19 +35,19 @@ class ScrapeMetadataTool(MoviePilotTool):
path = kwargs.get("path", "")
storage = kwargs.get("storage", "local")
overwrite = kwargs.get("overwrite", False)
message = f"正在刮削媒体元数据: {path}"
if storage != "local":
message += f" [存储: {storage}]"
if overwrite:
message += " [覆盖模式]"
return message
async def run(self, path: str, storage: Optional[str] = "local",
overwrite: Optional[bool] = False, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: path={path}, storage={storage}, overwrite={overwrite}")
try:
# 验证路径
if not path:
@@ -52,14 +55,14 @@ class ScrapeMetadataTool(MoviePilotTool):
"success": False,
"message": "刮削路径不能为空"
}, ensure_ascii=False)
# 创建 FileItem
fileitem = FileItem(
storage=storage,
path=path,
type="file" if Path(path).suffix else "dir"
)
# 检查本地存储路径是否存在
if storage == "local":
scrape_path = Path(path)
@@ -68,22 +71,22 @@ class ScrapeMetadataTool(MoviePilotTool):
"success": False,
"message": f"刮削路径不存在: {path}"
}, ensure_ascii=False)
# 识别媒体信息
media_chain = MediaChain()
scrape_path = Path(path)
meta = MetaInfoPath(scrape_path)
mediainfo = await media_chain.async_recognize_by_meta(meta)
if not mediainfo:
return json.dumps({
"success": False,
"message": f"刮削失败,无法识别媒体信息: {path}",
"path": path
}, ensure_ascii=False)
# 在线程池中执行同步的刮削操作
await GlobalVar.CURRENT_EVENT_LOOP.run_in_executor(
await global_vars.loop.run_in_executor(
None,
lambda: media_chain.scrape_metadata(
fileitem=fileitem,
@@ -92,7 +95,7 @@ class ScrapeMetadataTool(MoviePilotTool):
overwrite=overwrite
)
)
return json.dumps({
"success": True,
"message": f"{path} 刮削完成",
@@ -105,7 +108,7 @@ class ScrapeMetadataTool(MoviePilotTool):
"season": mediainfo.season
}
}, ensure_ascii=False, indent=2)
except Exception as e:
error_message = f"刮削媒体元数据失败: {str(e)}"
logger.error(f"刮削媒体元数据失败: {e}", exc_info=True)
@@ -114,4 +117,3 @@ class ScrapeMetadataTool(MoviePilotTool):
"message": error_message,
"path": path
}, ensure_ascii=False)

View File

@@ -7,7 +7,7 @@ from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.chain.subscribe import SubscribeChain
from app.core.config import GlobalVar
from app.core.config import global_vars
from app.db.subscribe_oper import SubscribeOper
from app.log import logger
@@ -30,28 +30,29 @@ class SearchSubscribeTool(MoviePilotTool):
"""根据搜索参数生成友好的提示消息"""
subscribe_id = kwargs.get("subscribe_id")
manual = kwargs.get("manual", False)
message = f"正在搜索订阅 #{subscribe_id} 的缺失剧集"
if manual:
message += "(手动搜索)"
return message
async def run(self, subscribe_id: int, manual: Optional[bool] = False,
async def run(self, subscribe_id: int, manual: Optional[bool] = False,
filter_groups: Optional[List[str]] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: subscribe_id={subscribe_id}, manual={manual}, filter_groups={filter_groups}")
logger.info(
f"执行工具: {self.name}, 参数: subscribe_id={subscribe_id}, manual={manual}, filter_groups={filter_groups}")
try:
# 先验证订阅是否存在
subscribe_oper = SubscribeOper()
subscribe = subscribe_oper.get(subscribe_id)
if not subscribe:
return json.dumps({
"success": False,
"message": f"订阅不存在: {subscribe_id}"
}, ensure_ascii=False)
# 获取订阅信息用于返回
subscribe_info = {
"id": subscribe.id,
@@ -65,7 +66,7 @@ class SearchSubscribeTool(MoviePilotTool):
"tmdbid": subscribe.tmdbid,
"doubanid": subscribe.doubanid
}
# 检查订阅状态
if subscribe.state == "S":
return json.dumps({
@@ -73,19 +74,19 @@ class SearchSubscribeTool(MoviePilotTool):
"message": f"订阅 #{subscribe_id} ({subscribe.name}) 已暂停,无法搜索",
"subscribe": subscribe_info
}, ensure_ascii=False)
# 如果提供了 filter_groups 参数,先更新订阅的规则组
if filter_groups is not None:
subscribe_oper.update(subscribe_id, {"filter_groups": filter_groups})
logger.info(f"更新订阅 #{subscribe_id} 的规则组为: {filter_groups}")
# 调用 SubscribeChain 的 search 方法
# search 方法是同步的,需要在异步环境中运行
subscribe_chain = SubscribeChain()
# 在线程池中执行同步的搜索操作
# 当 sid 有值时state 参数会被忽略,直接处理该订阅
await GlobalVar.CURRENT_EVENT_LOOP.run_in_executor(
await global_vars.loop.run_in_executor(
None,
lambda: subscribe_chain.search(
sid=subscribe_id,
@@ -93,7 +94,7 @@ class SearchSubscribeTool(MoviePilotTool):
manual=manual
)
)
# 重新获取订阅信息以获取更新后的状态
updated_subscribe = subscribe_oper.get(subscribe_id)
if updated_subscribe:
@@ -103,19 +104,19 @@ class SearchSubscribeTool(MoviePilotTool):
"last_update": updated_subscribe.last_update,
"filter_groups": updated_subscribe.filter_groups
})
# 如果提供了规则组,会在返回信息中显示
result = {
"success": True,
"message": f"订阅 #{subscribe_id} ({subscribe.name}) 搜索完成",
"subscribe": subscribe_info
}
if filter_groups is not None:
result["message"] += f"(已应用规则组: {', '.join(filter_groups)}"
return json.dumps(result, ensure_ascii=False, indent=2)
except Exception as e:
error_message = f"搜索订阅缺失剧集失败: {str(e)}"
logger.error(f"搜索订阅缺失剧集失败: {e}", exc_info=True)
@@ -124,4 +125,3 @@ class SearchSubscribeTool(MoviePilotTool):
"message": error_message,
"subscribe_id": subscribe_id
}, ensure_ascii=False)

View File

@@ -1,4 +1,3 @@
import asyncio
import re
import time
from datetime import datetime, timedelta
@@ -10,7 +9,7 @@ from app.chain.download import DownloadChain
from app.chain.media import MediaChain
from app.chain.search import SearchChain
from app.chain.subscribe import SubscribeChain
from app.core.config import settings, GlobalVar
from app.core.config import settings, global_vars
from app.core.context import MediaInfo, Context
from app.core.meta import MetaBase
from app.db.user_oper import UserOper
@@ -174,7 +173,7 @@ class MessageChain(ChainBase):
elif text.startswith('/ai') or text.startswith('/AI'):
# AI智能体处理
self._handle_ai_message(text=text, channel=channel, source=source,
userid=userid, username=username)
userid=userid, username=username)
elif text.startswith('/'):
# 执行命令
self.eventmanager.send_event(
@@ -329,7 +328,8 @@ class MessageChain(ChainBase):
else:
best_version = True
# 转换用户名
mp_name = UserOper().get_name(**{f"{channel.name.lower()}_userid": userid}) if channel else None
mp_name = UserOper().get_name(
**{f"{channel.name.lower()}_userid": userid}) if channel else None
# 添加订阅状态为N
SubscribeChain().add(title=mediainfo.title,
year=mediainfo.year,
@@ -505,7 +505,8 @@ class MessageChain(ChainBase):
# 开始搜索
if not medias:
self.post_message(Notification(
channel=channel, source=source, title=f"{meta.name} 没有找到对应的媒体信息!", userid=userid))
channel=channel, source=source, title=f"{meta.name} 没有找到对应的媒体信息!",
userid=userid))
return
logger.info(f"搜索到 {len(medias)} 条相关媒体信息")
try:
@@ -835,21 +836,22 @@ class MessageChain(ChainBase):
如果用户上次会话在15分钟内则复用相同的会话ID否则创建新的会话ID
"""
current_time = datetime.now()
# 检查用户是否有已存在的会话
if userid in MessageChain._user_sessions:
session_id, last_time = MessageChain._user_sessions[userid]
# 计算时间差
time_diff = current_time - last_time
# 如果时间差小于等于15分钟复用会话ID
if time_diff <= timedelta(minutes=MessageChain._session_timeout_minutes):
# 更新最后使用时间
MessageChain._user_sessions[userid] = (session_id, current_time)
logger.info(f"复用会话ID: {session_id}, 用户: {userid}, 距离上次会话: {time_diff.total_seconds() / 60:.1f}分钟")
logger.info(
f"复用会话ID: {session_id}, 用户: {userid}, 距离上次会话: {time_diff.total_seconds() / 60:.1f}分钟")
return session_id
# 创建新的会话ID
new_session_id = f"user_{userid}_{int(time.time())}"
MessageChain._user_sessions[userid] = (new_session_id, current_time)
@@ -877,11 +879,11 @@ class MessageChain(ChainBase):
if userid in MessageChain._user_sessions:
session_id, _ = MessageChain._user_sessions.pop(userid)
logger.info(f"已清除用户 {userid} 的会话: {session_id}")
# 如果有会话ID同时清除智能体的会话记忆
if session_id:
try:
GlobalVar.CURRENT_EVENT_LOOP.run_until_complete(
global_vars.loop.run_until_complete(
agent_manager.clear_session(
session_id=session_id,
user_id=str(userid)
@@ -889,7 +891,7 @@ class MessageChain(ChainBase):
)
except Exception as e:
logger.warning(f"清除智能体会话记忆失败: {e}")
self.post_message(Notification(
channel=channel,
source=source,
@@ -905,7 +907,7 @@ class MessageChain(ChainBase):
))
def _handle_ai_message(self, text: str, channel: MessageChannel, source: str,
userid: Union[str, int], username: str) -> None:
userid: Union[str, int], username: str) -> None:
"""
处理AI智能体消息
"""
@@ -946,9 +948,9 @@ class MessageChain(ChainBase):
# 生成或复用会话ID
session_id = self._get_or_create_session_id(userid)
# 在事件循环中处理
GlobalVar.CURRENT_EVENT_LOOP.run_until_complete(
global_vars.loop.run_until_complete(
agent_manager.process_message(
session_id=session_id,
user_id=str(userid),
@@ -962,4 +964,3 @@ class MessageChain(ChainBase):
except Exception as e:
logger.error(f"处理AI智能体消息失败: {e}")
self.messagehelper.put(f"AI智能体处理失败: {str(e)}", role="system", title="MoviePilot助手")

View File

@@ -920,6 +920,19 @@ class GlobalVar(object):
return True
return False
@property
def loop(self) -> AbstractEventLoop:
"""
当前循环
"""
return self.CURRENT_EVENT_LOOP
def set_loop(self, loop: AbstractEventLoop):
"""
设置循环
"""
self.CURRENT_EVENT_LOOP = loop
# 全局标识
global_vars = GlobalVar()

View File

@@ -11,7 +11,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union, Any
from fastapi.concurrency import run_in_threadpool
from app.core.config import GlobalVar
from app.core.config import global_vars
from app.helper.thread import ThreadHelper
from app.log import logger
from app.schemas import ChainEventData
@@ -453,7 +453,7 @@ class EventManager(metaclass=Singleton):
# 对于异步函数,直接在事件循环中运行
asyncio.run_coroutine_threadsafe(
self.__safe_invoke_handler_async(handler, isolated_event),
GlobalVar.CURRENT_EVENT_LOOP
global_vars.loop
)
else:
# 对于同步函数,在线程池中运行

View File

@@ -21,7 +21,7 @@ from app.chain.site import SiteChain
from app.chain.subscribe import SubscribeChain
from app.chain.transfer import TransferChain
from app.chain.workflow import WorkflowChain
from app.core.config import settings, GlobalVar
from app.core.config import settings, global_vars
from app.core.event import eventmanager, Event
from app.core.plugin import PluginManager
from app.db.systemconfig_oper import SystemConfigOper
@@ -474,7 +474,7 @@ class Scheduler(metaclass=SingletonClass):
"""
启动协程
"""
return asyncio.run_coroutine_threadsafe(coro, GlobalVar.CURRENT_EVENT_LOOP)
return asyncio.run_coroutine_threadsafe(coro, global_vars.loop)
# 获取定时任务
job = self.__prepare_job(job_id)

View File

@@ -4,6 +4,7 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI
from app.chain.system import SystemChain
from app.core.config import global_vars
from app.helper.system import SystemHelper
from app.startup.command_initializer import init_command, stop_command, restart_command
from app.startup.modules_initializer import init_modules, stop_modules
@@ -35,6 +36,8 @@ async def lifespan(app: FastAPI):
定义应用的生命周期事件
"""
print("Starting up...")
# 存储当前循环
global_vars.set_loop(asyncio.get_event_loop())
# 初始化路由
init_routers(app)
# 初始化模块

View File

@@ -1,4 +1,4 @@
from app.core.config import GlobalVar
from app.core.config import global_vars
from app.core.plugin import PluginManager
from app.log import logger
@@ -8,7 +8,7 @@ async def sync_plugins() -> bool:
初始化安装插件并动态注册后台任务及API
"""
try:
loop = GlobalVar.CURRENT_EVENT_LOOP
loop = global_vars.loop
plugin_manager = PluginManager()
sync_result = await execute_task(loop, plugin_manager.sync, "插件同步到本地")