From 8109ffb445f0e7e60e347fe4e6c390cd55fb3702 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Tue, 7 Apr 2026 07:32:35 +0800 Subject: [PATCH] feat(agent): add /stop_agent command for emergency stop of agent reasoning Add /stop_agent command that cancels the currently running agent reasoning task without clearing the session or memory. Unlike /clear_session which destroys the entire session, this allows users to stop a long-running or stuck agent process and continue the conversation afterward. --- app/agent/__init__.py | 37 ++++++++ app/chain/message.py | 55 +++++++++++- app/command.py | 193 +++++++++++++++++++++++++++--------------- 3 files changed, 216 insertions(+), 69 deletions(-) diff --git a/app/agent/__init__.py b/app/agent/__init__.py index ff025d7e..e2cb444e 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -604,6 +604,43 @@ class AgentManager: return await agent.process(task.message, images=task.images) + async def stop_current_task(self, session_id: str): + """ + 应急停止当前正在执行的Agent推理任务,但保留会话和记忆。 + 与 clear_session 不同,此方法不会销毁Agent实例或清除记忆, + 用户可以在停止后继续对话。 + """ + stopped = False + + # 取消该会话的worker(会触发 _execute_agent 中的 CancelledError) + if session_id in self._session_workers: + self._session_workers[session_id].cancel() + try: + await self._session_workers[session_id] + except asyncio.CancelledError: + pass + self._session_workers.pop(session_id, None) + stopped = True + + # 清空队列中待处理的消息 + if session_id in self._session_queues: + queue = self._session_queues[session_id] + while not queue.empty(): + try: + queue.get_nowait() + queue.task_done() + except asyncio.QueueEmpty: + break + self._session_queues.pop(session_id, None) + stopped = True + + if stopped: + logger.info(f"会话 {session_id} 的Agent推理已应急停止") + else: + logger.debug(f"会话 {session_id} 没有正在执行的Agent任务") + + return stopped + async def clear_session(self, session_id: str, user_id: str): """ 清空会话 diff --git a/app/chain/message.py b/app/chain/message.py index 7c755621..aeba1d25 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -580,7 +580,7 @@ class MessageChain(ChainBase): total = len(cache_list) # 加一页 cache_list = cache_list[ - (_current_page + 1) * self._page_size: (_current_page + 2) + (_current_page + 1) * self._page_size : (_current_page + 2) * self._page_size ] if not cache_list: @@ -1134,6 +1134,59 @@ class MessageChain(ChainBase): ) ) + def remote_stop_agent( + self, + channel: MessageChannel, + userid: Union[str, int], + source: Optional[str] = None, + ): + """ + 应急停止当前正在执行的Agent推理(远程命令接口)。 + 与 /clear_session 不同,此命令不会清除会话和记忆, + 停止后用户仍可继续对话。 + """ + # 查找用户的会话ID(不弹出,保留会话) + session_info = self._user_sessions.get(userid) + if session_info: + session_id, _ = session_info + try: + future = asyncio.run_coroutine_threadsafe( + agent_manager.stop_current_task(session_id=session_id), + global_vars.loop, + ) + stopped = future.result(timeout=10) + except Exception as e: + logger.warning(f"停止Agent推理失败: {e}") + stopped = False + + if stopped: + self.post_message( + Notification( + channel=channel, + source=source, + title="智能体推理已应急停止,会话记忆已保留,您可以继续对话", + userid=userid, + ) + ) + else: + self.post_message( + Notification( + channel=channel, + source=source, + title="当前没有正在执行的智能体任务", + userid=userid, + ) + ) + else: + self.post_message( + Notification( + channel=channel, + source=source, + title="您当前没有活跃的智能体会话", + userid=userid, + ) + ) + def _handle_ai_message( self, text: str, diff --git a/app/command.py b/app/command.py index 442ab630..b0cb2823 100644 --- a/app/command.py +++ b/app/command.py @@ -45,109 +45,115 @@ class Command(metaclass=Singleton): "id": "cookiecloud", "type": "scheduler", "description": "同步站点", - "category": "站点" + "category": "站点", }, "/sites": { "func": SiteChain().remote_list, "description": "查询站点", "category": "站点", - "data": {} + "data": {}, }, "/site_cookie": { "func": SiteChain().remote_cookie, "description": "更新站点Cookie", - "data": {} + "data": {}, }, "/site_statistic": { "func": SiteChain().remote_refresh_userdatas, "description": "站点数据统计", - "data": {} + "data": {}, }, "/site_enable": { "func": SiteChain().remote_enable, "description": "启用站点", - "data": {} + "data": {}, }, "/site_disable": { "func": SiteChain().remote_disable, "description": "禁用站点", - "data": {} + "data": {}, }, "/mediaserver_sync": { "id": "mediaserver_sync", "type": "scheduler", "description": "同步媒体服务器", - "category": "管理" + "category": "管理", }, "/subscribes": { "func": SubscribeChain().remote_list, "description": "查询订阅", "category": "订阅", - "data": {} + "data": {}, }, "/subscribe_refresh": { "id": "subscribe_refresh", "type": "scheduler", "description": "刷新订阅", - "category": "订阅" + "category": "订阅", }, "/subscribe_search": { "id": "subscribe_search", "type": "scheduler", "description": "搜索订阅", - "category": "订阅" + "category": "订阅", }, "/subscribe_delete": { "func": SubscribeChain().remote_delete, "description": "删除订阅", - "data": {} + "data": {}, }, "/subscribe_tmdb": { "id": "subscribe_tmdb", "type": "scheduler", - "description": "订阅元数据更新" + "description": "订阅元数据更新", }, "/downloading": { "func": DownloadChain().remote_downloading, "description": "正在下载", "category": "管理", - "data": {} + "data": {}, }, "/transfer": { "id": "transfer", "type": "scheduler", "description": "下载文件整理", - "category": "管理" + "category": "管理", }, "/redo": { "func": TransferChain().remote_transfer, "description": "手动整理", - "data": {} + "data": {}, }, "/clear_cache": { "func": SystemChain().remote_clear_cache, "description": "清理缓存", "category": "管理", - "data": {} + "data": {}, }, "/restart": { "func": SystemChain().restart, "description": "重启系统", "category": "管理", - "data": {} + "data": {}, }, "/version": { "func": SystemChain().version, "description": "当前版本", "category": "管理", - "data": {} + "data": {}, }, "/clear_session": { "func": MessageChain().remote_clear_session, "description": "清除会话", "category": "管理", - "data": {} - } + "data": {}, + }, + "/stop_agent": { + "func": MessageChain().remote_stop_agent, + "description": "停止推理", + "category": "管理", + "data": {}, + }, } # 插件命令集合 self._plugin_commands = {} @@ -182,7 +188,7 @@ class Command(metaclass=Singleton): self._commands = { **self._preset_commands, **self._plugin_commands, - **self._other_commands + **self._other_commands, } # 强制触发注册 @@ -195,32 +201,50 @@ class Command(metaclass=Singleton): event_data: CommandRegisterEventData = event.event_data # 如果事件被取消,跳过命令注册 if event_data.cancel: - logger.debug(f"Command initialization canceled by event: {event_data.source}") + logger.debug( + f"Command initialization canceled by event: {event_data.source}" + ) return # 如果拦截源与插件标识一致时,这里认为需要强制触发注册 if pid is not None and pid == event_data.source: force_register = True initial_commands = event_data.commands or {} - logger.debug(f"Registering command count from event: {len(initial_commands)}") + logger.debug( + f"Registering command count from event: {len(initial_commands)}" + ) else: - logger.debug(f"Registering initial command count: {len(initial_commands)}") + logger.debug( + f"Registering initial command count: {len(initial_commands)}" + ) # initial_commands 必须是 self._commands 的子集 - filtered_initial_commands = DictUtils.filter_keys_to_subset(initial_commands, self._commands) + filtered_initial_commands = DictUtils.filter_keys_to_subset( + initial_commands, self._commands + ) # 如果 filtered_initial_commands 为空,则跳过注册 if not filtered_initial_commands and not force_register: logger.debug("Filtered commands are empty, skipping registration.") return # 对比调整后的命令与当前命令 - if filtered_initial_commands != self._registered_commands or force_register: - logger.debug("Command set has changed or force registration is enabled.") + if ( + filtered_initial_commands != self._registered_commands + or force_register + ): + logger.debug( + "Command set has changed or force registration is enabled." + ) self._registered_commands = filtered_initial_commands CommandChain().register_commands(commands=filtered_initial_commands) else: - logger.debug("Command set unchanged, skipping broadcast registration.") + logger.debug( + "Command set unchanged, skipping broadcast registration." + ) except Exception as e: - logger.error(f"Error occurred during command initialization in background: {e}", exc_info=True) + logger.error( + f"Error occurred during command initialization in background: {e}", + exc_info=True, + ) def __trigger_register_commands_event(self) -> tuple[Optional[Event], dict]: """ @@ -238,7 +262,7 @@ class Command(metaclass=Singleton): command_data = { "type": command_type, "description": command.get("description"), - "category": command.get("category") + "category": command.get("category"), } # 如果有 pid,则添加到命令数据中 plugin_id = command.get("pid") @@ -253,7 +277,9 @@ class Command(metaclass=Singleton): add_commands(self._other_commands, "other") # 触发事件允许可以拦截和调整命令 - event_data = CommandRegisterEventData(commands=commands, origin="CommandChain", service=None) + event_data = CommandRegisterEventData( + commands=commands, origin="CommandChain", service=None + ) event = eventmanager.send_event(ChainEventType.CommandRegister, event_data) return event, commands @@ -274,13 +300,19 @@ class Command(metaclass=Singleton): "show": command.get("show", True), "data": { "etype": command.get("event"), - "data": command.get("data") - } + "data": command.get("data"), + }, } return plugin_commands - def __run_command(self, command: Dict[str, any], data_str: Optional[str] = "", - channel: MessageChannel = None, source: Optional[str] = None, userid: Union[str, int] = None): + def __run_command( + self, + command: Dict[str, any], + data_str: Optional[str] = "", + channel: MessageChannel = None, + source: Optional[str] = None, + userid: Union[str, int] = None, + ): """ 运行定时服务 """ @@ -292,7 +324,7 @@ class Command(metaclass=Singleton): channel=channel, source=source, title=f"开始执行 {command.get('description')} ...", - userid=userid + userid=userid, ) ) @@ -305,33 +337,33 @@ class Command(metaclass=Singleton): channel=channel, source=source, title=f"{command.get('description')} 执行完成", - userid=userid + userid=userid, ) ) else: # 命令 - cmd_data = copy.deepcopy(command['data']) if command.get('data') else {} - args_num = ObjectUtils.arguments(command['func']) + cmd_data = copy.deepcopy(command["data"]) if command.get("data") else {} + args_num = ObjectUtils.arguments(command["func"]) if args_num > 0: if cmd_data: # 有内置参数直接使用内置参数 data = cmd_data.get("data") or {} - data['channel'] = channel - data['source'] = source - data['user'] = userid + data["channel"] = channel + data["source"] = source + data["user"] = userid if data_str: - data['arg_str'] = data_str - cmd_data['data'] = data - command['func'](**cmd_data) + data["arg_str"] = data_str + cmd_data["data"] = data + command["func"](**cmd_data) elif args_num == 3: # 没有输入参数,只输入渠道来源、用户ID和消息来源 - command['func'](channel, userid, source) + command["func"](channel, userid, source) elif args_num > 3: # 多个输入参数:用户输入、用户ID - command['func'](data_str, channel, userid, source) + command["func"](data_str, channel, userid, source) else: # 没有参数 - command['func']() + command["func"]() def get_commands(self): """ @@ -345,9 +377,15 @@ class Command(metaclass=Singleton): """ return self._commands.get(cmd, {}) - def register(self, cmd: str, func: Any, data: Optional[dict] = None, - desc: Optional[str] = None, category: Optional[str] = None, - show: bool = True) -> None: + def register( + self, + cmd: str, + func: Any, + data: Optional[dict] = None, + desc: Optional[str] = None, + category: Optional[str] = None, + show: bool = True, + ) -> None: """ 注册单个命令 """ @@ -357,12 +395,17 @@ class Command(metaclass=Singleton): "description": desc, "category": category, "data": data or {}, - "show": show + "show": show, } - def execute(self, cmd: str, data_str: Optional[str] = "", - channel: MessageChannel = None, source: Optional[str] = None, - userid: Union[str, int] = None) -> None: + def execute( + self, + cmd: str, + data_str: Optional[str] = "", + channel: MessageChannel = None, + source: Optional[str] = None, + userid: Union[str, int] = None, + ) -> None: """ 执行命令 """ @@ -370,23 +413,32 @@ class Command(metaclass=Singleton): if command: try: if userid: - logger.info(f"用户 {userid} 开始执行:{command.get('description')} ...") + logger.info( + f"用户 {userid} 开始执行:{command.get('description')} ..." + ) else: logger.info(f"开始执行:{command.get('description')} ...") # 执行命令 - self.__run_command(command, data_str=data_str, - channel=channel, source=source, userid=userid) + self.__run_command( + command, + data_str=data_str, + channel=channel, + source=source, + userid=userid, + ) if userid: logger.info(f"用户 {userid} {command.get('description')} 执行完成") else: logger.info(f"{command.get('description')} 执行完成") except Exception as err: - logger.error(f"执行命令 {cmd} 出错:{str(err)} - {traceback.format_exc()}") - self.messagehelper.put(title=f"执行命令 {cmd} 出错", - message=str(err), - role="system") + logger.error( + f"执行命令 {cmd} 出错:{str(err)} - {traceback.format_exc()}" + ) + self.messagehelper.put( + title=f"执行命令 {cmd} 出错", message=str(err), role="system" + ) @staticmethod def send_plugin_event(etype: EventType, data: dict) -> None: @@ -404,19 +456,24 @@ class Command(metaclass=Singleton): } """ # 命令参数 - event_str = event.event_data.get('cmd') + event_str = event.event_data.get("cmd") # 消息渠道 - event_channel = event.event_data.get('channel') + event_channel = event.event_data.get("channel") # 消息来源 - event_source = event.event_data.get('source') + event_source = event.event_data.get("source") # 消息用户 - event_user = event.event_data.get('user') + event_user = event.event_data.get("user") if event_str: cmd = event_str.split()[0] args = " ".join(event_str.split()[1:]) if self.get(cmd): - self.execute(cmd=cmd, data_str=args, - channel=event_channel, source=event_source, userid=event_user) + self.execute( + cmd=cmd, + data_str=args, + channel=event_channel, + source=event_source, + userid=event_user, + ) @eventmanager.register(EventType.ModuleReload) def module_reload_event(self, _: ManagerEvent) -> None: