mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-04-09 05:39:03 +08:00
fixx loop
This commit is contained in:
@@ -8,7 +8,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.media import MediaChain
|
||||
from app.core.config import GlobalVar
|
||||
from app.core.config import global_vars
|
||||
from app.core.metainfo import MetaInfoPath
|
||||
from app.log import logger
|
||||
from app.schemas import FileItem
|
||||
@@ -17,9 +17,12 @@ from app.schemas import FileItem
|
||||
class ScrapeMetadataInput(BaseModel):
|
||||
"""刮削媒体元数据工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
path: str = Field(..., description="Path to the file or directory to scrape metadata for (e.g., '/path/to/file.mkv' or '/path/to/directory')")
|
||||
storage: Optional[str] = Field("local", description="Storage type: 'local' for local storage, 'smb', 'alist', etc. for remote storage (default: 'local')")
|
||||
overwrite: Optional[bool] = Field(False, description="Whether to overwrite existing metadata files (default: False)")
|
||||
path: str = Field(...,
|
||||
description="Path to the file or directory to scrape metadata for (e.g., '/path/to/file.mkv' or '/path/to/directory')")
|
||||
storage: Optional[str] = Field("local",
|
||||
description="Storage type: 'local' for local storage, 'smb', 'alist', etc. for remote storage (default: 'local')")
|
||||
overwrite: Optional[bool] = Field(False,
|
||||
description="Whether to overwrite existing metadata files (default: False)")
|
||||
|
||||
|
||||
class ScrapeMetadataTool(MoviePilotTool):
|
||||
@@ -32,19 +35,19 @@ class ScrapeMetadataTool(MoviePilotTool):
|
||||
path = kwargs.get("path", "")
|
||||
storage = kwargs.get("storage", "local")
|
||||
overwrite = kwargs.get("overwrite", False)
|
||||
|
||||
|
||||
message = f"正在刮削媒体元数据: {path}"
|
||||
if storage != "local":
|
||||
message += f" [存储: {storage}]"
|
||||
if overwrite:
|
||||
message += " [覆盖模式]"
|
||||
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, path: str, storage: Optional[str] = "local",
|
||||
overwrite: Optional[bool] = False, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: path={path}, storage={storage}, overwrite={overwrite}")
|
||||
|
||||
|
||||
try:
|
||||
# 验证路径
|
||||
if not path:
|
||||
@@ -52,14 +55,14 @@ class ScrapeMetadataTool(MoviePilotTool):
|
||||
"success": False,
|
||||
"message": "刮削路径不能为空"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
# 创建 FileItem
|
||||
fileitem = FileItem(
|
||||
storage=storage,
|
||||
path=path,
|
||||
type="file" if Path(path).suffix else "dir"
|
||||
)
|
||||
|
||||
|
||||
# 检查本地存储路径是否存在
|
||||
if storage == "local":
|
||||
scrape_path = Path(path)
|
||||
@@ -68,22 +71,22 @@ class ScrapeMetadataTool(MoviePilotTool):
|
||||
"success": False,
|
||||
"message": f"刮削路径不存在: {path}"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
# 识别媒体信息
|
||||
media_chain = MediaChain()
|
||||
scrape_path = Path(path)
|
||||
meta = MetaInfoPath(scrape_path)
|
||||
mediainfo = await media_chain.async_recognize_by_meta(meta)
|
||||
|
||||
|
||||
if not mediainfo:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"刮削失败,无法识别媒体信息: {path}",
|
||||
"path": path
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
# 在线程池中执行同步的刮削操作
|
||||
await GlobalVar.CURRENT_EVENT_LOOP.run_in_executor(
|
||||
await global_vars.loop.run_in_executor(
|
||||
None,
|
||||
lambda: media_chain.scrape_metadata(
|
||||
fileitem=fileitem,
|
||||
@@ -92,7 +95,7 @@ class ScrapeMetadataTool(MoviePilotTool):
|
||||
overwrite=overwrite
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"message": f"{path} 刮削完成",
|
||||
@@ -105,7 +108,7 @@ class ScrapeMetadataTool(MoviePilotTool):
|
||||
"season": mediainfo.season
|
||||
}
|
||||
}, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"刮削媒体元数据失败: {str(e)}"
|
||||
logger.error(f"刮削媒体元数据失败: {e}", exc_info=True)
|
||||
@@ -114,4 +117,3 @@ class ScrapeMetadataTool(MoviePilotTool):
|
||||
"message": error_message,
|
||||
"path": path
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.subscribe import SubscribeChain
|
||||
from app.core.config import GlobalVar
|
||||
from app.core.config import global_vars
|
||||
from app.db.subscribe_oper import SubscribeOper
|
||||
from app.log import logger
|
||||
|
||||
@@ -30,28 +30,29 @@ class SearchSubscribeTool(MoviePilotTool):
|
||||
"""根据搜索参数生成友好的提示消息"""
|
||||
subscribe_id = kwargs.get("subscribe_id")
|
||||
manual = kwargs.get("manual", False)
|
||||
|
||||
|
||||
message = f"正在搜索订阅 #{subscribe_id} 的缺失剧集"
|
||||
if manual:
|
||||
message += "(手动搜索)"
|
||||
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, subscribe_id: int, manual: Optional[bool] = False,
|
||||
async def run(self, subscribe_id: int, manual: Optional[bool] = False,
|
||||
filter_groups: Optional[List[str]] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: subscribe_id={subscribe_id}, manual={manual}, filter_groups={filter_groups}")
|
||||
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: subscribe_id={subscribe_id}, manual={manual}, filter_groups={filter_groups}")
|
||||
|
||||
try:
|
||||
# 先验证订阅是否存在
|
||||
subscribe_oper = SubscribeOper()
|
||||
subscribe = subscribe_oper.get(subscribe_id)
|
||||
|
||||
|
||||
if not subscribe:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"订阅不存在: {subscribe_id}"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
# 获取订阅信息用于返回
|
||||
subscribe_info = {
|
||||
"id": subscribe.id,
|
||||
@@ -65,7 +66,7 @@ class SearchSubscribeTool(MoviePilotTool):
|
||||
"tmdbid": subscribe.tmdbid,
|
||||
"doubanid": subscribe.doubanid
|
||||
}
|
||||
|
||||
|
||||
# 检查订阅状态
|
||||
if subscribe.state == "S":
|
||||
return json.dumps({
|
||||
@@ -73,19 +74,19 @@ class SearchSubscribeTool(MoviePilotTool):
|
||||
"message": f"订阅 #{subscribe_id} ({subscribe.name}) 已暂停,无法搜索",
|
||||
"subscribe": subscribe_info
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
# 如果提供了 filter_groups 参数,先更新订阅的规则组
|
||||
if filter_groups is not None:
|
||||
subscribe_oper.update(subscribe_id, {"filter_groups": filter_groups})
|
||||
logger.info(f"更新订阅 #{subscribe_id} 的规则组为: {filter_groups}")
|
||||
|
||||
|
||||
# 调用 SubscribeChain 的 search 方法
|
||||
# search 方法是同步的,需要在异步环境中运行
|
||||
subscribe_chain = SubscribeChain()
|
||||
|
||||
|
||||
# 在线程池中执行同步的搜索操作
|
||||
# 当 sid 有值时,state 参数会被忽略,直接处理该订阅
|
||||
await GlobalVar.CURRENT_EVENT_LOOP.run_in_executor(
|
||||
await global_vars.loop.run_in_executor(
|
||||
None,
|
||||
lambda: subscribe_chain.search(
|
||||
sid=subscribe_id,
|
||||
@@ -93,7 +94,7 @@ class SearchSubscribeTool(MoviePilotTool):
|
||||
manual=manual
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# 重新获取订阅信息以获取更新后的状态
|
||||
updated_subscribe = subscribe_oper.get(subscribe_id)
|
||||
if updated_subscribe:
|
||||
@@ -103,19 +104,19 @@ class SearchSubscribeTool(MoviePilotTool):
|
||||
"last_update": updated_subscribe.last_update,
|
||||
"filter_groups": updated_subscribe.filter_groups
|
||||
})
|
||||
|
||||
|
||||
# 如果提供了规则组,会在返回信息中显示
|
||||
result = {
|
||||
"success": True,
|
||||
"message": f"订阅 #{subscribe_id} ({subscribe.name}) 搜索完成",
|
||||
"subscribe": subscribe_info
|
||||
}
|
||||
|
||||
|
||||
if filter_groups is not None:
|
||||
result["message"] += f"(已应用规则组: {', '.join(filter_groups)})"
|
||||
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"搜索订阅缺失剧集失败: {str(e)}"
|
||||
logger.error(f"搜索订阅缺失剧集失败: {e}", exc_info=True)
|
||||
@@ -124,4 +125,3 @@ class SearchSubscribeTool(MoviePilotTool):
|
||||
"message": error_message,
|
||||
"subscribe_id": subscribe_id
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
@@ -10,7 +9,7 @@ from app.chain.download import DownloadChain
|
||||
from app.chain.media import MediaChain
|
||||
from app.chain.search import SearchChain
|
||||
from app.chain.subscribe import SubscribeChain
|
||||
from app.core.config import settings, GlobalVar
|
||||
from app.core.config import settings, global_vars
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.core.meta import MetaBase
|
||||
from app.db.user_oper import UserOper
|
||||
@@ -174,7 +173,7 @@ class MessageChain(ChainBase):
|
||||
elif text.startswith('/ai') or text.startswith('/AI'):
|
||||
# AI智能体处理
|
||||
self._handle_ai_message(text=text, channel=channel, source=source,
|
||||
userid=userid, username=username)
|
||||
userid=userid, username=username)
|
||||
elif text.startswith('/'):
|
||||
# 执行命令
|
||||
self.eventmanager.send_event(
|
||||
@@ -329,7 +328,8 @@ class MessageChain(ChainBase):
|
||||
else:
|
||||
best_version = True
|
||||
# 转换用户名
|
||||
mp_name = UserOper().get_name(**{f"{channel.name.lower()}_userid": userid}) if channel else None
|
||||
mp_name = UserOper().get_name(
|
||||
**{f"{channel.name.lower()}_userid": userid}) if channel else None
|
||||
# 添加订阅,状态为N
|
||||
SubscribeChain().add(title=mediainfo.title,
|
||||
year=mediainfo.year,
|
||||
@@ -505,7 +505,8 @@ class MessageChain(ChainBase):
|
||||
# 开始搜索
|
||||
if not medias:
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title=f"{meta.name} 没有找到对应的媒体信息!", userid=userid))
|
||||
channel=channel, source=source, title=f"{meta.name} 没有找到对应的媒体信息!",
|
||||
userid=userid))
|
||||
return
|
||||
logger.info(f"搜索到 {len(medias)} 条相关媒体信息")
|
||||
try:
|
||||
@@ -835,21 +836,22 @@ class MessageChain(ChainBase):
|
||||
如果用户上次会话在15分钟内,则复用相同的会话ID;否则创建新的会话ID
|
||||
"""
|
||||
current_time = datetime.now()
|
||||
|
||||
|
||||
# 检查用户是否有已存在的会话
|
||||
if userid in MessageChain._user_sessions:
|
||||
session_id, last_time = MessageChain._user_sessions[userid]
|
||||
|
||||
|
||||
# 计算时间差
|
||||
time_diff = current_time - last_time
|
||||
|
||||
|
||||
# 如果时间差小于等于15分钟,复用会话ID
|
||||
if time_diff <= timedelta(minutes=MessageChain._session_timeout_minutes):
|
||||
# 更新最后使用时间
|
||||
MessageChain._user_sessions[userid] = (session_id, current_time)
|
||||
logger.info(f"复用会话ID: {session_id}, 用户: {userid}, 距离上次会话: {time_diff.total_seconds() / 60:.1f}分钟")
|
||||
logger.info(
|
||||
f"复用会话ID: {session_id}, 用户: {userid}, 距离上次会话: {time_diff.total_seconds() / 60:.1f}分钟")
|
||||
return session_id
|
||||
|
||||
|
||||
# 创建新的会话ID
|
||||
new_session_id = f"user_{userid}_{int(time.time())}"
|
||||
MessageChain._user_sessions[userid] = (new_session_id, current_time)
|
||||
@@ -877,11 +879,11 @@ class MessageChain(ChainBase):
|
||||
if userid in MessageChain._user_sessions:
|
||||
session_id, _ = MessageChain._user_sessions.pop(userid)
|
||||
logger.info(f"已清除用户 {userid} 的会话: {session_id}")
|
||||
|
||||
|
||||
# 如果有会话ID,同时清除智能体的会话记忆
|
||||
if session_id:
|
||||
try:
|
||||
GlobalVar.CURRENT_EVENT_LOOP.run_until_complete(
|
||||
global_vars.loop.run_until_complete(
|
||||
agent_manager.clear_session(
|
||||
session_id=session_id,
|
||||
user_id=str(userid)
|
||||
@@ -889,7 +891,7 @@ class MessageChain(ChainBase):
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"清除智能体会话记忆失败: {e}")
|
||||
|
||||
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
@@ -905,7 +907,7 @@ class MessageChain(ChainBase):
|
||||
))
|
||||
|
||||
def _handle_ai_message(self, text: str, channel: MessageChannel, source: str,
|
||||
userid: Union[str, int], username: str) -> None:
|
||||
userid: Union[str, int], username: str) -> None:
|
||||
"""
|
||||
处理AI智能体消息
|
||||
"""
|
||||
@@ -946,9 +948,9 @@ class MessageChain(ChainBase):
|
||||
|
||||
# 生成或复用会话ID
|
||||
session_id = self._get_or_create_session_id(userid)
|
||||
|
||||
|
||||
# 在事件循环中处理
|
||||
GlobalVar.CURRENT_EVENT_LOOP.run_until_complete(
|
||||
global_vars.loop.run_until_complete(
|
||||
agent_manager.process_message(
|
||||
session_id=session_id,
|
||||
user_id=str(userid),
|
||||
@@ -962,4 +964,3 @@ class MessageChain(ChainBase):
|
||||
except Exception as e:
|
||||
logger.error(f"处理AI智能体消息失败: {e}")
|
||||
self.messagehelper.put(f"AI智能体处理失败: {str(e)}", role="system", title="MoviePilot助手")
|
||||
|
||||
|
||||
@@ -920,6 +920,19 @@ class GlobalVar(object):
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def loop(self) -> AbstractEventLoop:
|
||||
"""
|
||||
当前循环
|
||||
"""
|
||||
return self.CURRENT_EVENT_LOOP
|
||||
|
||||
def set_loop(self, loop: AbstractEventLoop):
|
||||
"""
|
||||
设置循环
|
||||
"""
|
||||
self.CURRENT_EVENT_LOOP = loop
|
||||
|
||||
|
||||
# 全局标识
|
||||
global_vars = GlobalVar()
|
||||
|
||||
@@ -11,7 +11,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union, Any
|
||||
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
|
||||
from app.core.config import GlobalVar
|
||||
from app.core.config import global_vars
|
||||
from app.helper.thread import ThreadHelper
|
||||
from app.log import logger
|
||||
from app.schemas import ChainEventData
|
||||
@@ -453,7 +453,7 @@ class EventManager(metaclass=Singleton):
|
||||
# 对于异步函数,直接在事件循环中运行
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.__safe_invoke_handler_async(handler, isolated_event),
|
||||
GlobalVar.CURRENT_EVENT_LOOP
|
||||
global_vars.loop
|
||||
)
|
||||
else:
|
||||
# 对于同步函数,在线程池中运行
|
||||
|
||||
@@ -21,7 +21,7 @@ from app.chain.site import SiteChain
|
||||
from app.chain.subscribe import SubscribeChain
|
||||
from app.chain.transfer import TransferChain
|
||||
from app.chain.workflow import WorkflowChain
|
||||
from app.core.config import settings, GlobalVar
|
||||
from app.core.config import settings, global_vars
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.core.plugin import PluginManager
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
@@ -474,7 +474,7 @@ class Scheduler(metaclass=SingletonClass):
|
||||
"""
|
||||
启动协程
|
||||
"""
|
||||
return asyncio.run_coroutine_threadsafe(coro, GlobalVar.CURRENT_EVENT_LOOP)
|
||||
return asyncio.run_coroutine_threadsafe(coro, global_vars.loop)
|
||||
|
||||
# 获取定时任务
|
||||
job = self.__prepare_job(job_id)
|
||||
|
||||
@@ -4,6 +4,7 @@ from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
|
||||
from app.chain.system import SystemChain
|
||||
from app.core.config import global_vars
|
||||
from app.helper.system import SystemHelper
|
||||
from app.startup.command_initializer import init_command, stop_command, restart_command
|
||||
from app.startup.modules_initializer import init_modules, stop_modules
|
||||
@@ -35,6 +36,8 @@ async def lifespan(app: FastAPI):
|
||||
定义应用的生命周期事件
|
||||
"""
|
||||
print("Starting up...")
|
||||
# 存储当前循环
|
||||
global_vars.set_loop(asyncio.get_event_loop())
|
||||
# 初始化路由
|
||||
init_routers(app)
|
||||
# 初始化模块
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from app.core.config import GlobalVar
|
||||
from app.core.config import global_vars
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
|
||||
@@ -8,7 +8,7 @@ async def sync_plugins() -> bool:
|
||||
初始化安装插件,并动态注册后台任务及API
|
||||
"""
|
||||
try:
|
||||
loop = GlobalVar.CURRENT_EVENT_LOOP
|
||||
loop = global_vars.loop
|
||||
plugin_manager = PluginManager()
|
||||
|
||||
sync_result = await execute_task(loop, plugin_manager.sync, "插件同步到本地")
|
||||
|
||||
Reference in New Issue
Block a user