diff --git a/app/agent/__init__.py b/app/agent/__init__.py index 95841f85..71561d6e 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -1,7 +1,6 @@ """MoviePilot AI智能体实现""" import asyncio -import threading from typing import Dict, List, Any from langchain.agents import AgentExecutor, create_openai_tools_agent @@ -21,9 +20,6 @@ from app.helper.message import MessageHelper from app.log import logger from app.schemas import Notification -# 用于保护环境变量修改的线程锁 -_env_lock = threading.Lock() - class AgentChain(ChainBase): pass @@ -75,36 +71,21 @@ class MoviePilotAgent: api_key = settings.LLM_API_KEY if provider == "google": - import os - from contextlib import contextmanager - from langchain_google_genai import ChatGoogleGenerativeAI - - # 使用线程锁保护的临时环境变量配置 - @contextmanager - def _temp_proxy_env(): - """线程安全的临时设置代理环境变量的上下文管理器""" - with _env_lock: - old_http = os.environ.get("HTTP_PROXY") - old_https = os.environ.get("HTTPS_PROXY") - try: - if settings.PROXY_HOST: - os.environ["HTTP_PROXY"] = settings.PROXY_HOST - os.environ["HTTPS_PROXY"] = settings.PROXY_HOST - yield - finally: - # 恢复原始环境变量 - if old_http is not None: - os.environ["HTTP_PROXY"] = old_http - elif "HTTP_PROXY" in os.environ: - del os.environ["HTTP_PROXY"] - - if old_https is not None: - os.environ["HTTPS_PROXY"] = old_https - elif "HTTPS_PROXY" in os.environ: - del os.environ["HTTPS_PROXY"] - - # 在临时环境变量中初始化 ChatGoogleGenerativeAI - with _temp_proxy_env(): + if settings.PROXY_HOST: + from langchain_openai import ChatOpenAI + return ChatOpenAI( + model=settings.LLM_MODEL, + api_key=api_key, + max_retries=3, + base_url="https://generativelanguage.googleapis.com/v1beta/openai", + temperature=settings.LLM_TEMPERATURE, + streaming=True, + callbacks=[self.callback_handler], + stream_usage=True, + openai_proxy=settings.PROXY_HOST + ) + else: + from langchain_google_genai import ChatGoogleGenerativeAI return ChatGoogleGenerativeAI( model=settings.LLM_MODEL, google_api_key=api_key, diff --git a/app/agent/tools/factory.py b/app/agent/tools/factory.py index 2023ba38..d84ba45e 100644 --- a/app/agent/tools/factory.py +++ b/app/agent/tools/factory.py @@ -21,6 +21,8 @@ from app.agent.tools.impl.query_popular_subscribes import QueryPopularSubscribes from app.agent.tools.impl.query_subscribe_history import QuerySubscribeHistoryTool from app.agent.tools.impl.delete_subscribe import DeleteSubscribeTool from app.agent.tools.impl.search_media import SearchMediaTool +from app.agent.tools.impl.search_person import SearchPersonTool +from app.agent.tools.impl.search_person_credits import SearchPersonCreditsTool from app.agent.tools.impl.recognize_media import RecognizeMediaTool from app.agent.tools.impl.scrape_metadata import ScrapeMetadataTool from app.agent.tools.impl.query_episode_schedule import QueryEpisodeScheduleTool @@ -53,6 +55,8 @@ class MoviePilotToolFactory: tools = [] tool_definitions = [ SearchMediaTool, + SearchPersonTool, + SearchPersonCreditsTool, RecognizeMediaTool, ScrapeMetadataTool, QueryEpisodeScheduleTool, diff --git a/app/agent/tools/impl/add_subscribe.py b/app/agent/tools/impl/add_subscribe.py index 8e0dcbd0..163e1a76 100644 --- a/app/agent/tools/impl/add_subscribe.py +++ b/app/agent/tools/impl/add_subscribe.py @@ -33,6 +33,8 @@ class AddSubscribeInput(BaseModel): description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')") filter_groups: Optional[List[str]] = Field(None, description="List of filter rule group names to apply (optional, use query_rule_groups tool to get available rule groups)") + sites: Optional[List[int]] = Field(None, + description="List of site IDs to search from (optional, use query_sites tool to get available site IDs)") class AddSubscribeTool(MoviePilotTool): @@ -61,12 +63,13 @@ class AddSubscribeTool(MoviePilotTool): season: Optional[int] = None, tmdb_id: Optional[str] = None, start_episode: Optional[int] = None, total_episode: Optional[int] = None, quality: Optional[str] = None, resolution: Optional[str] = None, - effect: Optional[str] = None, filter_groups: Optional[List[str]] = None, **kwargs) -> str: + effect: Optional[str] = None, filter_groups: Optional[List[str]] = None, + sites: Optional[List[int]] = None, **kwargs) -> str: logger.info( f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, " f"season={season}, tmdb_id={tmdb_id}, start_episode={start_episode}, " f"total_episode={total_episode}, quality={quality}, resolution={resolution}, " - f"effect={effect}, filter_groups={filter_groups}") + f"effect={effect}, filter_groups={filter_groups}, sites={sites}") try: subscribe_chain = SubscribeChain() @@ -92,6 +95,8 @@ class AddSubscribeTool(MoviePilotTool): subscribe_kwargs['effect'] = effect if filter_groups: subscribe_kwargs['filter_groups'] = filter_groups + if sites: + subscribe_kwargs['sites'] = sites sid, message = await subscribe_chain.async_add( mtype=MediaType(media_type), @@ -118,6 +123,8 @@ class AddSubscribeTool(MoviePilotTool): params.append(f"特效过滤: {effect}") if filter_groups: params.append(f"规则组: {', '.join(filter_groups)}") + if sites: + params.append(f"站点: {', '.join(map(str, sites))}") if params: result_msg += f"\n配置参数: {', '.join(params)}" return result_msg diff --git a/app/agent/tools/impl/query_media_library.py b/app/agent/tools/impl/query_media_library.py index 04e34369..249bd2e1 100644 --- a/app/agent/tools/impl/query_media_library.py +++ b/app/agent/tools/impl/query_media_library.py @@ -1,7 +1,7 @@ """查询媒体库工具""" import json -from typing import Optional, List, Type +from typing import Optional, Type from pydantic import BaseModel, Field @@ -9,7 +9,6 @@ from app.agent.tools.base import MoviePilotTool from app.chain.mediaserver import MediaServerChain from app.core.context import MediaInfo from app.log import logger -from app.schemas import MediaServerItem from app.schemas.types import MediaType diff --git a/app/agent/tools/impl/search_person.py b/app/agent/tools/impl/search_person.py new file mode 100644 index 00000000..44c400ed --- /dev/null +++ b/app/agent/tools/impl/search_person.py @@ -0,0 +1,83 @@ +"""搜索人物工具""" + +import json +from typing import Optional, Type + +from pydantic import BaseModel, Field + +from app.agent.tools.base import MoviePilotTool +from app.chain.media import MediaChain +from app.log import logger + + +class SearchPersonInput(BaseModel): + """搜索人物工具的输入参数模型""" + explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") + name: str = Field(..., description="The name of the person to search for (e.g., 'Tom Hanks', '周杰伦')") + + +class SearchPersonTool(MoviePilotTool): + name: str = "search_person" + description: str = "Search for person information including actors, directors, etc. Supports searching by name. Returns detailed person information from TMDB, Douban, or Bangumi database." + args_schema: Type[BaseModel] = SearchPersonInput + + def get_tool_message(self, **kwargs) -> Optional[str]: + """根据搜索参数生成友好的提示消息""" + name = kwargs.get("name", "") + return f"正在搜索人物: {name}" + + async def run(self, name: str, **kwargs) -> str: + logger.info(f"执行工具: {self.name}, 参数: name={name}") + + try: + media_chain = MediaChain() + # 使用 MediaChain.async_search_persons 方法搜索人物 + persons = await media_chain.async_search_persons(name=name) + + if persons: + # 限制最多30条结果 + total_count = len(persons) + limited_persons = persons[:30] + # 精简字段,只保留关键信息 + simplified_results = [] + for person in limited_persons: + simplified = { + "name": person.name, + "id": person.id, + "source": person.source, + "profile_path": person.profile_path, + "original_name": person.original_name, + "known_for_department": person.known_for_department, + "popularity": person.popularity, + "biography": person.biography[:200] + "..." if person.biography and len(person.biography) > 200 else person.biography, + "birthday": person.birthday, + "deathday": person.deathday, + "place_of_birth": person.place_of_birth, + "gender": person.gender, + "imdb_id": person.imdb_id, + "also_known_as": person.also_known_as[:5] if person.also_known_as else [], # 限制别名数量 + } + # 添加豆瓣特有字段 + if person.source == "douban": + simplified["url"] = person.url + simplified["avatar"] = person.avatar + simplified["latin_name"] = person.latin_name + simplified["roles"] = person.roles[:5] if person.roles else [] # 限制角色数量 + # 添加Bangumi特有字段 + if person.source == "bangumi": + simplified["career"] = person.career + simplified["relation"] = person.relation + + simplified_results.append(simplified) + + result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2) + # 如果结果被裁剪,添加提示信息 + if total_count > 30: + return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 30 条结果。\n\n{result_json}" + return result_json + else: + return f"未找到相关人物信息: {name}" + except Exception as e: + error_message = f"搜索人物失败: {str(e)}" + logger.error(f"搜索人物失败: {e}", exc_info=True) + return error_message diff --git a/app/agent/tools/impl/search_person_credits.py b/app/agent/tools/impl/search_person_credits.py new file mode 100644 index 00000000..b5724f54 --- /dev/null +++ b/app/agent/tools/impl/search_person_credits.py @@ -0,0 +1,85 @@ +"""搜索演员参演作品工具""" + +import json +from typing import Optional, Type + +from pydantic import BaseModel, Field + +from app.agent.tools.base import MoviePilotTool +from app.chain.douban import DoubanChain +from app.chain.tmdb import TmdbChain +from app.chain.bangumi import BangumiChain +from app.log import logger + + +class SearchPersonCreditsInput(BaseModel): + """搜索演员参演作品工具的输入参数模型""" + explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") + person_id: int = Field(..., description="The ID of the person/actor to search for credits (e.g., 31 for Tom Hanks in TMDB)") + source: str = Field(..., description="The data source: 'tmdb' for TheMovieDB, 'douban' for Douban, 'bangumi' for Bangumi") + page: Optional[int] = Field(1, description="Page number for pagination (default: 1)") + + +class SearchPersonCreditsTool(MoviePilotTool): + name: str = "search_person_credits" + description: str = "Search for films and TV shows that a person/actor has appeared in (filmography). Supports searching by person ID from TMDB, Douban, or Bangumi database. Returns a list of media works the person has participated in." + args_schema: Type[BaseModel] = SearchPersonCreditsInput + + def get_tool_message(self, **kwargs) -> Optional[str]: + """根据搜索参数生成友好的提示消息""" + person_id = kwargs.get("person_id", "") + source = kwargs.get("source", "") + return f"正在搜索人物参演作品: {source} ID {person_id}" + + async def run(self, person_id: int, source: str, page: Optional[int] = 1, **kwargs) -> str: + logger.info(f"执行工具: {self.name}, 参数: person_id={person_id}, source={source}, page={page}") + + try: + # 根据source选择相应的chain + if source.lower() == "tmdb": + tmdb_chain = TmdbChain() + medias = await tmdb_chain.async_person_credits(person_id=person_id, page=page) + elif source.lower() == "douban": + douban_chain = DoubanChain() + medias = await douban_chain.async_person_credits(person_id=person_id, page=page) + elif source.lower() == "bangumi": + bangumi_chain = BangumiChain() + medias = await bangumi_chain.async_person_credits(person_id=person_id) + else: + return f"不支持的数据源: {source}。支持的数据源: tmdb, douban, bangumi" + + if medias: + # 限制最多30条结果 + total_count = len(medias) + limited_medias = medias[:30] + # 精简字段,只保留关键信息 + simplified_results = [] + for media in limited_medias: + simplified = { + "title": media.title, + "en_title": media.en_title, + "year": media.year, + "type": media.type.value if media.type else None, + "season": media.season, + "tmdb_id": media.tmdb_id, + "imdb_id": media.imdb_id, + "douban_id": media.douban_id, + "overview": media.overview[:200] + "..." if media.overview and len(media.overview) > 200 else media.overview, + "vote_average": media.vote_average, + "poster_path": media.poster_path, + "backdrop_path": media.backdrop_path, + "detail_link": media.detail_link + } + simplified_results.append(simplified) + + result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2) + # 如果结果被裁剪,添加提示信息 + if total_count > 30: + return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 30 条结果。\n\n{result_json}" + return result_json + else: + return f"未找到人物 ID {person_id} ({source}) 的参演作品" + except Exception as e: + error_message = f"搜索演员参演作品失败: {str(e)}" + logger.error(f"搜索演员参演作品失败: {e}", exc_info=True) + return error_message diff --git a/app/agent/tools/impl/search_web.py b/app/agent/tools/impl/search_web.py index b65ec345..2642b6d2 100644 --- a/app/agent/tools/impl/search_web.py +++ b/app/agent/tools/impl/search_web.py @@ -64,7 +64,8 @@ class SearchWebTool(MoviePilotTool): logger.error(f"搜索网络内容失败: {e}", exc_info=True) return error_message - async def _search_duckduckgo_api(self, query: str, max_results: int) -> list: + @staticmethod + async def _search_duckduckgo_api(query: str, max_results: int) -> list: """ 使用DuckDuckGo API进行搜索 @@ -143,7 +144,8 @@ class SearchWebTool(MoviePilotTool): logger.warning(f"DuckDuckGo API搜索失败: {e}") return [] - def _format_and_truncate_results(self, results: list, max_results: int) -> dict: + @staticmethod + def _format_and_truncate_results(results: list, max_results: int) -> dict: """ 格式化并裁剪搜索结果以避免占用过多上下文 diff --git a/app/api/endpoints/media.py b/app/api/endpoints/media.py index 1ebf4646..81c55e42 100644 --- a/app/api/endpoints/media.py +++ b/app/api/endpoints/media.py @@ -85,25 +85,26 @@ async def search(title: str, return obj.get("source") return obj.source - result = [] media_chain = MediaChain() if type == "media": _, medias = await media_chain.async_search(title=title) - if medias: - result = [media.to_dict() for media in medias] + result = [media.to_dict() for media in medias] if medias else [] elif type == "collection": - result = await media_chain.async_search_collections(name=title) - else: - result = await media_chain.async_search_persons(name=title) - if result: - # 按设置的顺序对结果进行排序 - setting_order = settings.SEARCH_SOURCE.split(',') or [] - sort_order = {} - for index, source in enumerate(setting_order): - sort_order[source] = index - result = sorted(result, key=lambda x: sort_order.get(__get_source(x), 4)) - return result[(page - 1) * count:page * count] - return [] + collections = await media_chain.async_search_collections(name=title) + result = [collection.to_dict() for collection in collections] if collections else [] + else: # person + persons = await media_chain.async_search_persons(name=title) + result = [person.model_dump() for person in persons] if persons else [] + + if not result: + return [] + + # 排序和分页 + setting_order = settings.SEARCH_SOURCE.split(',') if settings.SEARCH_SOURCE else [] + sort_order = {source: index for index, source in enumerate(setting_order)} + + sorted_result = sorted(result, key=lambda x: sort_order.get(__get_source(x), 4)) + return sorted_result[(page - 1) * count:page * count] @router.post("/scrape/{storage}", summary="刮削媒体信息", response_model=schemas.Response) diff --git a/app/chain/message.py b/app/chain/message.py index ab5ae807..25469417 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -164,19 +164,15 @@ class MessageChain(ChainBase): ) # 处理消息 if text.startswith('CALLBACK:'): - # 处理按钮回调(适配支持回调的渠道) + # 处理按钮回调(适配支持回调的渠),优先级最高 if ChannelCapabilityManager.supports_callbacks(channel): self._handle_callback(text=text, channel=channel, source=source, userid=userid, username=username, original_message_id=original_message_id, original_chat_id=original_chat_id) else: logger.warning(f"渠道 {channel.value} 不支持回调,但收到了回调消息:{text}") - elif text.startswith('/ai') or text.startswith('/AI'): - # AI智能体处理 - self._handle_ai_message(text=text, channel=channel, source=source, - userid=userid, username=username) - elif text.startswith('/'): - # 执行命令 + elif text.startswith('/') and not text.lower().startswith('/ai'): + # 执行特定命令命令(但不是/ai) self.eventmanager.send_event( EventType.CommandExcute, { @@ -186,266 +182,226 @@ class MessageChain(ChainBase): "source": source } ) - elif text.isdigit(): - # 用户选择了具体的条目 - # 缓存 - cache_data: dict = user_cache.get(userid).copy() - # 选择项目 - if not cache_data \ - or not cache_data.get('items') \ - or len(cache_data.get('items')) < int(text): - # 发送消息 - self.post_message(Notification(channel=channel, source=source, title="输入有误!", userid=userid)) - return - try: - # 选择的序号 - _choice = int(text) + _current_page * self._page_size - 1 - # 缓存类型 - cache_type: str = cache_data.get('type') - # 缓存列表 - cache_list: list = cache_data.get('items').copy() - # 选择 + elif text.lower().startswith('/ai'): + # 用户指定AI智能体消息响应 + self._handle_ai_message(text=text, channel=channel, source=source, + userid=userid, username=username) + elif settings.AI_AGENT_ENABLE and settings.AI_AGENT_GLOBAL: + # 普通消息,全局智能体响应 + self._handle_ai_message(text=text, channel=channel, source=source, + userid=userid, username=username) + else: + # 非智能体普通消息响应 + if text.isdigit(): + # 用户选择了具体的条目 + # 缓存 + cache_data: dict = user_cache.get(userid).copy() + # 选择项目 + if not cache_data \ + or not cache_data.get('items') \ + or len(cache_data.get('items')) < int(text): + # 发送消息 + self.post_message(Notification(channel=channel, source=source, title="输入有误!", userid=userid)) + return try: - if cache_type in ["Search", "ReSearch"]: - # 当前媒体信息 - mediainfo: MediaInfo = cache_list[_choice] - _current_media = mediainfo - # 查询缺失的媒体信息 - exist_flag, no_exists = DownloadChain().get_no_exists_info(meta=_current_meta, - mediainfo=_current_media) - if exist_flag and cache_type == "Search": - # 媒体库中已存在 + # 选择的序号 + _choice = int(text) + _current_page * self._page_size - 1 + # 缓存类型 + cache_type: str = cache_data.get('type') + # 缓存列表 + cache_list: list = cache_data.get('items').copy() + # 选择 + try: + if cache_type in ["Search", "ReSearch"]: + # 当前媒体信息 + mediainfo: MediaInfo = cache_list[_choice] + _current_media = mediainfo + # 查询缺失的媒体信息 + exist_flag, no_exists = DownloadChain().get_no_exists_info(meta=_current_meta, + mediainfo=_current_media) + if exist_flag and cache_type == "Search": + # 媒体库中已存在 + self.post_message( + Notification(channel=channel, + source=source, + title=f"【{_current_media.title_year}" + f"{_current_meta.sea} 媒体库中已存在,如需重新下载请发送:搜索 名称 或 下载 名称】", + userid=userid)) + return + elif exist_flag: + # 没有缺失,但要全量重新搜索和下载 + no_exists = self.__get_noexits_info(_current_meta, _current_media) + # 发送缺失的媒体信息 + messages = [] + if no_exists and cache_type == "Search": + # 发送缺失消息 + mediakey = mediainfo.tmdb_id or mediainfo.douban_id + messages = [ + f"第 {sea} 季缺失 {StringUtils.str_series(no_exist.episodes) if no_exist.episodes else no_exist.total_episode} 集" + for sea, no_exist in no_exists.get(mediakey).items()] + elif no_exists: + # 发送总集数的消息 + mediakey = mediainfo.tmdb_id or mediainfo.douban_id + messages = [ + f"第 {sea} 季总 {no_exist.total_episode} 集" + for sea, no_exist in no_exists.get(mediakey).items()] + if messages: + self.post_message(Notification(channel=channel, + source=source, + title=f"{mediainfo.title_year}:\n" + "\n".join(messages), + userid=userid)) + # 搜索种子,过滤掉不需要的剧集,以便选择 + logger.info(f"开始搜索 {mediainfo.title_year} ...") self.post_message( Notification(channel=channel, source=source, - title=f"【{_current_media.title_year}" - f"{_current_meta.sea} 媒体库中已存在,如需重新下载请发送:搜索 名称 或 下载 名称】", + title=f"开始搜索 {mediainfo.type.value} {mediainfo.title_year} ...", userid=userid)) - return - elif exist_flag: - # 没有缺失,但要全量重新搜索和下载 - no_exists = self.__get_noexits_info(_current_meta, _current_media) - # 发送缺失的媒体信息 - messages = [] - if no_exists and cache_type == "Search": - # 发送缺失消息 - mediakey = mediainfo.tmdb_id or mediainfo.douban_id - messages = [ - f"第 {sea} 季缺失 {StringUtils.str_series(no_exist.episodes) if no_exist.episodes else no_exist.total_episode} 集" - for sea, no_exist in no_exists.get(mediakey).items()] - elif no_exists: - # 发送总集数的消息 - mediakey = mediainfo.tmdb_id or mediainfo.douban_id - messages = [ - f"第 {sea} 季总 {no_exist.total_episode} 集" - for sea, no_exist in no_exists.get(mediakey).items()] - if messages: - self.post_message(Notification(channel=channel, - source=source, - title=f"{mediainfo.title_year}:\n" + "\n".join(messages), - userid=userid)) - # 搜索种子,过滤掉不需要的剧集,以便选择 - logger.info(f"开始搜索 {mediainfo.title_year} ...") - self.post_message( - Notification(channel=channel, - source=source, - title=f"开始搜索 {mediainfo.type.value} {mediainfo.title_year} ...", - userid=userid)) - # 开始搜索 - contexts = SearchChain().process(mediainfo=mediainfo, - no_exists=no_exists) - if not contexts: - # 没有数据 - self.post_message(Notification( - channel=channel, - source=source, - title=f"{mediainfo.title}" - f"{_current_meta.sea} 未搜索到需要的资源!", - userid=userid)) - return - # 搜索结果排序 - contexts = TorrentHelper().sort_torrents(contexts) - try: - # 判断是否设置自动下载 - auto_download_user = settings.AUTO_DOWNLOAD_USER - # 匹配到自动下载用户 - if auto_download_user \ - and (auto_download_user == "all" - or any(userid == user for user in auto_download_user.split(","))): - logger.info(f"用户 {userid} 在自动下载用户中,开始自动择优下载 ...") - # 自动选择下载 - self.__auto_download(channel=channel, - source=source, - cache_list=contexts, - userid=userid, - username=username, - no_exists=no_exists) - else: - # 更新缓存 - user_cache[userid] = { - "type": "Torrent", - "items": contexts - } - _current_page = 0 - # 保存缓存 - self.save_cache(user_cache, self._cache_file) - # 删除原消息 - if (original_message_id and original_chat_id and - ChannelCapabilityManager.supports_deletion(channel)): - self.delete_message( - channel=channel, - source=source, - message_id=original_message_id, - chat_id=original_chat_id - ) - # 发送种子数据 - logger.info(f"搜索到 {len(contexts)} 条数据,开始发送选择消息 ...") - self.__post_torrents_message(channel=channel, - source=source, - title=mediainfo.title, - items=contexts[:self._page_size], - userid=userid, - total=len(contexts)) - finally: - contexts.clear() - del contexts - elif cache_type in ["Subscribe", "ReSubscribe"]: - # 订阅或洗版媒体 - mediainfo: MediaInfo = cache_list[_choice] - # 洗版标识 - best_version = False - # 查询缺失的媒体信息 - if cache_type == "Subscribe": - exist_flag, _ = DownloadChain().get_no_exists_info(meta=_current_meta, - mediainfo=mediainfo) - if exist_flag: + # 开始搜索 + contexts = SearchChain().process(mediainfo=mediainfo, + no_exists=no_exists) + if not contexts: + # 没有数据 self.post_message(Notification( channel=channel, source=source, - title=f"【{mediainfo.title_year}" - f"{_current_meta.sea} 媒体库中已存在,如需洗版请发送:洗版 XXX】", + title=f"{mediainfo.title}" + f"{_current_meta.sea} 未搜索到需要的资源!", userid=userid)) return - else: - best_version = True - # 转换用户名 - mp_name = UserOper().get_name( - **{f"{channel.name.lower()}_userid": userid}) if channel else None - # 添加订阅,状态为N - SubscribeChain().add(title=mediainfo.title, - year=mediainfo.year, - mtype=mediainfo.type, - tmdbid=mediainfo.tmdb_id, - season=_current_meta.begin_season, - channel=channel, - source=source, - userid=userid, - username=mp_name or username, - best_version=best_version) - elif cache_type == "Torrent": - if int(text) == 0: - # 自动选择下载,强制下载模式 - self.__auto_download(channel=channel, + # 搜索结果排序 + contexts = TorrentHelper().sort_torrents(contexts) + try: + # 判断是否设置自动下载 + auto_download_user = settings.AUTO_DOWNLOAD_USER + # 匹配到自动下载用户 + if auto_download_user \ + and (auto_download_user == "all" + or any(userid == user for user in auto_download_user.split(","))): + logger.info(f"用户 {userid} 在自动下载用户中,开始自动择优下载 ...") + # 自动选择下载 + self.__auto_download(channel=channel, + source=source, + cache_list=contexts, + userid=userid, + username=username, + no_exists=no_exists) + else: + # 更新缓存 + user_cache[userid] = { + "type": "Torrent", + "items": contexts + } + _current_page = 0 + # 保存缓存 + self.save_cache(user_cache, self._cache_file) + # 删除原消息 + if (original_message_id and original_chat_id and + ChannelCapabilityManager.supports_deletion(channel)): + self.delete_message( + channel=channel, + source=source, + message_id=original_message_id, + chat_id=original_chat_id + ) + # 发送种子数据 + logger.info(f"搜索到 {len(contexts)} 条数据,开始发送选择消息 ...") + self.__post_torrents_message(channel=channel, + source=source, + title=mediainfo.title, + items=contexts[:self._page_size], + userid=userid, + total=len(contexts)) + finally: + contexts.clear() + del contexts + elif cache_type in ["Subscribe", "ReSubscribe"]: + # 订阅或洗版媒体 + mediainfo: MediaInfo = cache_list[_choice] + # 洗版标识 + best_version = False + # 查询缺失的媒体信息 + if cache_type == "Subscribe": + exist_flag, _ = DownloadChain().get_no_exists_info(meta=_current_meta, + mediainfo=mediainfo) + if exist_flag: + self.post_message(Notification( + channel=channel, + source=source, + title=f"【{mediainfo.title_year}" + f"{_current_meta.sea} 媒体库中已存在,如需洗版请发送:洗版 XXX】", + userid=userid)) + return + else: + best_version = True + # 转换用户名 + mp_name = UserOper().get_name( + **{f"{channel.name.lower()}_userid": userid}) if channel else None + # 添加订阅,状态为N + SubscribeChain().add(title=mediainfo.title, + year=mediainfo.year, + mtype=mediainfo.type, + tmdbid=mediainfo.tmdb_id, + season=_current_meta.begin_season, + channel=channel, source=source, - cache_list=cache_list, userid=userid, - username=username) - else: - # 下载种子 - context: Context = cache_list[_choice] - # 下载 - DownloadChain().download_single(context, channel=channel, source=source, - userid=userid, username=username) + username=mp_name or username, + best_version=best_version) + elif cache_type == "Torrent": + if int(text) == 0: + # 自动选择下载,强制下载模式 + self.__auto_download(channel=channel, + source=source, + cache_list=cache_list, + userid=userid, + username=username) + else: + # 下载种子 + context: Context = cache_list[_choice] + # 下载 + DownloadChain().download_single(context, channel=channel, source=source, + userid=userid, username=username) + finally: + cache_list.clear() + del cache_list finally: - cache_list.clear() - del cache_list - finally: - cache_data.clear() - del cache_data - elif text.lower() == "p": - # 上一页 - cache_data: dict = user_cache.get(userid).copy() - if not cache_data: - # 没有缓存 - self.post_message(Notification( - channel=channel, source=source, title="输入有误!", userid=userid)) - return - try: - if _current_page == 0: - # 第一页 + cache_data.clear() + del cache_data + elif text.lower() == "p": + # 上一页 + cache_data: dict = user_cache.get(userid).copy() + if not cache_data: + # 没有缓存 self.post_message(Notification( - channel=channel, source=source, title="已经是第一页了!", userid=userid)) + channel=channel, source=source, title="输入有误!", userid=userid)) return - # 减一页 - _current_page -= 1 - cache_type: str = cache_data.get('type') - # 产生副本,避免修改原值 - cache_list: list = cache_data.get('items').copy() try: if _current_page == 0: - start = 0 - end = self._page_size - else: - start = _current_page * self._page_size - end = start + self._page_size - if cache_type == "Torrent": - # 发送种子数据 - self.__post_torrents_message(channel=channel, - source=source, - title=_current_media.title, - items=cache_list[start:end], - userid=userid, - total=len(cache_list), - original_message_id=original_message_id, - original_chat_id=original_chat_id) - else: - # 发送媒体数据 - self.__post_medias_message(channel=channel, - source=source, - title=_current_meta.name, - items=cache_list[start:end], - userid=userid, - total=len(cache_list), - original_message_id=original_message_id, - original_chat_id=original_chat_id) - finally: - cache_list.clear() - del cache_list - finally: - cache_data.clear() - del cache_data - elif text.lower() == "n": - # 下一页 - cache_data: dict = user_cache.get(userid).copy() - if not cache_data: - # 没有缓存 - self.post_message(Notification( - channel=channel, source=source, title="输入有误!", userid=userid)) - return - try: - cache_type: str = cache_data.get('type') - # 产生副本,避免修改原值 - cache_list: list = cache_data.get('items').copy() - total = len(cache_list) - # 加一页 - cache_list = cache_list[(_current_page + 1) * self._page_size:(_current_page + 2) * self._page_size] - if not cache_list: - # 没有数据 - self.post_message(Notification( - channel=channel, source=source, title="已经是最后一页了!", userid=userid)) - return - else: + # 第一页 + self.post_message(Notification( + channel=channel, source=source, title="已经是第一页了!", userid=userid)) + return + # 减一页 + _current_page -= 1 + cache_type: str = cache_data.get('type') + # 产生副本,避免修改原值 + cache_list: list = cache_data.get('items').copy() try: - # 加一页 - _current_page += 1 + if _current_page == 0: + start = 0 + end = self._page_size + else: + start = _current_page * self._page_size + end = start + self._page_size if cache_type == "Torrent": # 发送种子数据 self.__post_torrents_message(channel=channel, source=source, title=_current_media.title, - items=cache_list, + items=cache_list[start:end], userid=userid, - total=total, + total=len(cache_list), original_message_id=original_message_id, original_chat_id=original_chat_id) else: @@ -453,94 +409,144 @@ class MessageChain(ChainBase): self.__post_medias_message(channel=channel, source=source, title=_current_meta.name, - items=cache_list, + items=cache_list[start:end], userid=userid, - total=total, + total=len(cache_list), original_message_id=original_message_id, original_chat_id=original_chat_id) finally: cache_list.clear() del cache_list - finally: - cache_data.clear() - del cache_data - else: - # 搜索或订阅 - if text.startswith("订阅"): - # 订阅 - content = re.sub(r"订阅[::\s]*", "", text) - action = "Subscribe" - elif text.startswith("洗版"): - # 洗版 - content = re.sub(r"洗版[::\s]*", "", text) - action = "ReSubscribe" - elif text.startswith("搜索") or text.startswith("下载"): - # 重新搜索/下载 - content = re.sub(r"(搜索|下载)[::\s]*", "", text) - action = "ReSearch" - elif text.startswith("#") \ - or re.search(r"^请[问帮你]", text) \ - or re.search(r"[??]$", text) \ - or StringUtils.count_words(text) > 10 \ - or text.find("继续") != -1: - # 聊天 - content = text - action = "Chat" - elif StringUtils.is_link(text): - # 链接 - content = text - action = "Link" - else: - # 搜索 - content = text - action = "Search" - - if action in ["Search", "ReSearch", "Subscribe", "ReSubscribe"]: - # 搜索 - meta, medias = MediaChain().search(content) - # 识别 - if not meta.name: - self.post_message(Notification( - channel=channel, source=source, title="无法识别输入内容!", userid=userid)) - return - # 开始搜索 - if not medias: - self.post_message(Notification( - channel=channel, source=source, title=f"{meta.name} 没有找到对应的媒体信息!", - userid=userid)) - return - logger.info(f"搜索到 {len(medias)} 条相关媒体信息") - try: - # 记录当前状态 - _current_meta = meta - # 保存缓存 - user_cache[userid] = { - 'type': action, - 'items': medias - } - self.save_cache(user_cache, self._cache_file) - _current_page = 0 - _current_media = None - # 发送媒体列表 - self.__post_medias_message(channel=channel, - source=source, - title=meta.name, - items=medias[:self._page_size], - userid=userid, total=len(medias)) finally: - medias.clear() - del medias + cache_data.clear() + del cache_data + elif text.lower() == "n": + # 下一页 + cache_data: dict = user_cache.get(userid).copy() + if not cache_data: + # 没有缓存 + self.post_message(Notification( + channel=channel, source=source, title="输入有误!", userid=userid)) + return + try: + cache_type: str = cache_data.get('type') + # 产生副本,避免修改原值 + cache_list: list = cache_data.get('items').copy() + total = len(cache_list) + # 加一页 + cache_list = cache_list[(_current_page + 1) * self._page_size:(_current_page + 2) * self._page_size] + if not cache_list: + # 没有数据 + self.post_message(Notification( + channel=channel, source=source, title="已经是最后一页了!", userid=userid)) + return + else: + try: + # 加一页 + _current_page += 1 + if cache_type == "Torrent": + # 发送种子数据 + self.__post_torrents_message(channel=channel, + source=source, + title=_current_media.title, + items=cache_list, + userid=userid, + total=total, + original_message_id=original_message_id, + original_chat_id=original_chat_id) + else: + # 发送媒体数据 + self.__post_medias_message(channel=channel, + source=source, + title=_current_meta.name, + items=cache_list, + userid=userid, + total=total, + original_message_id=original_message_id, + original_chat_id=original_chat_id) + finally: + cache_list.clear() + del cache_list + finally: + cache_data.clear() + del cache_data else: - # 广播事件 - self.eventmanager.send_event( - EventType.UserMessage, - { - "text": content, - "userid": userid, - "channel": channel, - "source": source - } - ) + # 搜索或订阅 + if text.startswith("订阅"): + # 订阅 + content = re.sub(r"订阅[::\s]*", "", text) + action = "Subscribe" + elif text.startswith("洗版"): + # 洗版 + content = re.sub(r"洗版[::\s]*", "", text) + action = "ReSubscribe" + elif text.startswith("搜索") or text.startswith("下载"): + # 重新搜索/下载 + content = re.sub(r"(搜索|下载)[::\s]*", "", text) + action = "ReSearch" + elif text.startswith("#") \ + or re.search(r"^请[问帮你]", text) \ + or re.search(r"[??]$", text) \ + or StringUtils.count_words(text) > 10 \ + or text.find("继续") != -1: + # 聊天 + content = text + action = "Chat" + elif StringUtils.is_link(text): + # 链接 + content = text + action = "Link" + else: + # 搜索 + content = text + action = "Search" + + if action in ["Search", "ReSearch", "Subscribe", "ReSubscribe"]: + # 搜索 + meta, medias = MediaChain().search(content) + # 识别 + if not meta.name: + self.post_message(Notification( + channel=channel, source=source, title="无法识别输入内容!", userid=userid)) + return + # 开始搜索 + if not medias: + self.post_message(Notification( + channel=channel, source=source, title=f"{meta.name} 没有找到对应的媒体信息!", + userid=userid)) + return + logger.info(f"搜索到 {len(medias)} 条相关媒体信息") + try: + # 记录当前状态 + _current_meta = meta + # 保存缓存 + user_cache[userid] = { + 'type': action, + 'items': medias + } + self.save_cache(user_cache, self._cache_file) + _current_page = 0 + _current_media = None + # 发送媒体列表 + self.__post_medias_message(channel=channel, + source=source, + title=meta.name, + items=medias[:self._page_size], + userid=userid, total=len(medias)) + finally: + medias.clear() + del medias + else: + # 广播事件 + self.eventmanager.send_event( + EventType.UserMessage, + { + "text": content, + "userid": userid, + "channel": channel, + "source": source + } + ) finally: user_cache.clear() del user_cache @@ -926,7 +932,10 @@ class MessageChain(ChainBase): return # 提取用户消息 - user_message = text[3:].strip() # 移除 "/ai" 前缀 + if text.lower().startswith("/ai"): + user_message = text[3:].strip() # 移除 "/ai" 前缀(大小写不敏感) + else: + user_message = text.strip() # 按原消息处理 if not user_message: self.post_message(Notification( channel=channel, diff --git a/app/core/cache.py b/app/core/cache.py index 0982d5a0..2bb1f1e9 100644 --- a/app/core/cache.py +++ b/app/core/cache.py @@ -1024,13 +1024,11 @@ def fresh(fresh: bool = True): with fresh(): result = some_cached_function() """ - token = _fresh.set(fresh) - logger.debug(f"Setting fresh mode to {fresh}. {id(token):#x}") + token = _fresh.set(fresh or is_fresh()) try: yield finally: _fresh.reset(token) - logger.debug(f"Reset fresh mode. {id(token):#x}") @asynccontextmanager async def async_fresh(fresh: bool = True): @@ -1041,13 +1039,11 @@ async def async_fresh(fresh: bool = True): async with async_fresh(): result = await some_async_cached_function() """ - token = _fresh.set(fresh) - logger.debug(f"Setting async_fresh mode to {fresh}. {id(token):#x}") + token = _fresh.set(fresh or is_fresh()) try: yield finally: _fresh.reset(token) - logger.debug(f"Reset async_fresh mode. {id(token):#x}") def is_fresh() -> bool: """ diff --git a/app/core/config.py b/app/core/config.py index 52a49318..2f3ad3bc 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -411,6 +411,8 @@ class ConfigModel(BaseModel): # ==================== AI智能体配置 ==================== # AI智能体开关 AI_AGENT_ENABLE: bool = False + # 合局AI智能体 + AI_AGENT_GLOBAL: bool = False # LLM提供商 (openai/google/deepseek) LLM_PROVIDER: str = "deepseek" # LLM模型名称 diff --git a/app/modules/telegram/telegram.py b/app/modules/telegram/telegram.py index a3c72898..b4ba7c6d 100644 --- a/app/modules/telegram/telegram.py +++ b/app/modules/telegram/telegram.py @@ -1,19 +1,21 @@ +import asyncio import re import threading -import uuid -from pathlib import Path from threading import Event from typing import Optional, List, Dict, Callable from urllib.parse import urljoin import telebot +from telegramify_markdown import standardize, telegramify +from telegramify_markdown.type import ContentTypes, SentType from telebot import apihelper -from telebot.types import InputFile, InlineKeyboardMarkup, InlineKeyboardButton +from telebot.types import InlineKeyboardMarkup, InlineKeyboardButton from telebot.types import InputMediaPhoto from app.core.config import settings from app.core.context import MediaInfo, Context from app.core.metainfo import MetaInfo +from app.helper.thread import ThreadHelper from app.log import logger from app.utils.common import retry from app.utils.http import RequestUtils @@ -52,7 +54,7 @@ class Telegram: else: apihelper.proxy = settings.PROXY # bot - _bot = telebot.TeleBot(self._telegram_token, parse_mode="Markdown") + _bot = telebot.TeleBot(self._telegram_token, parse_mode="MarkdownV2") # 记录句柄 self._bot = _bot # 获取并存储bot用户名用于@检测 @@ -236,12 +238,14 @@ class Telegram: return False try: - if text: - # 对text进行Markdown特殊字符转义 - text = re.sub(r"([_`])", r"\\\1", text) - caption = f"*{title}*\n{text}" + if title and text: + caption = f"**{title}**\n{text}" + elif title: + caption = f"**{title}**" + elif text: + caption = text else: - caption = f"*{title}*" + caption = "" if link: caption = f"{caption}\n[查看详情]({link})" @@ -499,7 +503,7 @@ class Telegram: if image: # 如果有图片,使用edit_message_media - media = InputMediaPhoto(media=image, caption=text, parse_mode="Markdown") + media = InputMediaPhoto(media=image, caption=standardize(text), parse_mode="MarkdownV2") self._bot.edit_message_media( chat_id=chat_id, message_id=message_id, @@ -511,8 +515,8 @@ class Telegram: self._bot.edit_message_text( chat_id=chat_id, message_id=message_id, - text=text, - parse_mode="Markdown", + text=standardize(text), + parse_mode="MarkdownV2", reply_markup=reply_markup ) return True @@ -520,49 +524,120 @@ class Telegram: logger.error(f"编辑消息失败:{str(e)}") return False - @retry(RetryException, logger=logger) def __send_request(self, userid: Optional[str] = None, image="", caption="", reply_markup: Optional[InlineKeyboardMarkup] = None) -> bool: """ 向Telegram发送报文 :param reply_markup: 内联键盘 """ - if image: - res = RequestUtils(proxies=settings.PROXY, ua=settings.NORMAL_USER_AGENT).get_res(image) - if res is None: - raise Exception("获取图片失败") - if res.content: - # 使用随机标识构建图片文件的完整路径,并写入图片内容到文件 - image_file = Path(settings.TEMP_PATH) / "telegram" / str(uuid.uuid4()) - if not image_file.parent.exists(): - image_file.parent.mkdir(parents=True, exist_ok=True) - image_file.write_bytes(res.content) - photo = InputFile(image_file) - # 发送图片到Telegram - ret = self._bot.send_photo(chat_id=userid or self._telegram_chat_id, - photo=photo, - caption=caption, - parse_mode="Markdown", - reply_markup=reply_markup) - if ret is None: - raise RetryException("发送图片消息失败") - return True - # 按4096分段循环发送消息 - ret = None - if len(caption) > 4095: - for i in range(0, len(caption), 4095): - ret = self._bot.send_message(chat_id=userid or self._telegram_chat_id, - text=caption[i:i + 4095], - parse_mode="Markdown", - reply_markup=reply_markup if i == 0 else None) - else: - ret = self._bot.send_message(chat_id=userid or self._telegram_chat_id, - text=caption, - parse_mode="Markdown", - reply_markup=reply_markup) - if ret is None: - raise RetryException("发送文本消息失败") - return True if ret else False + kwargs = { + 'chat_id': userid or self._telegram_chat_id, + 'parse_mode': "MarkdownV2", + 'reply_markup': reply_markup + } + + try: + # 处理图片 + image = self.__process_image(image) if image else None + + # 图片消息的标题长度限制为1024,文本消息为4096 + caption_limit = 1024 if image else 4096 + if len(caption) < caption_limit: + ret = self.__send_short_message(image, caption, **kwargs) + else: + sent_idx = set() + ret = self.__send_long_message(image, caption, sent_idx, **kwargs) + + return ret is not None + except Exception as e: + logger.error(f"发送Telegram消息失败: {e}") + return False + + @retry(RetryException, logger=logger) + def __process_image(self, image_url: str) -> bytes: + """ + 处理图片URL,获取图片内容 + """ + try: + res = RequestUtils( + proxies=settings.PROXY, + ua=settings.NORMAL_USER_AGENT + ).get_res(image_url) + + if not res or not res.content: + raise RetryException("获取图片失败") + + return res.content + except Exception as e: + raise + + @retry(RetryException, logger=logger) + def __send_short_message(self, image: Optional[bytes], caption: str, **kwargs): + """ + 发送短消息 + """ + try: + if image: + return self._bot.send_photo( + photo=image, + caption=standardize(caption), + **kwargs + ) + else: + return self._bot.send_message( + text=standardize(caption), + **kwargs + ) + except Exception as e: + raise RetryException(f"发送{'图片' if image else '文本'}消息失败") + + @retry(RetryException, logger=logger) + def __send_long_message(self, image: Optional[bytes], caption: str, sent_idx: set, **kwargs): + """ + 发送长消息 + """ + try: + reply_markup = kwargs.pop("reply_markup", None) + + boxs: SentType = ThreadHelper().submit(lambda x: asyncio.run(telegramify(x)), caption).result() + + ret = None + for i, item in enumerate(boxs): + if i in sent_idx: + # 跳过已发送消息 + continue + + current_reply_markup = reply_markup if i == 0 else None + + if item.content_type == ContentTypes.TEXT and (i != 0 or not image): + ret = self._bot.send_message(**kwargs, + text=item.content, + reply_markup=current_reply_markup + ) + + elif item.content_type == ContentTypes.PHOTO or (image and i == 0): + ret = self._bot.send_photo(**kwargs, + photo=(getattr(item, "file_name", ""), + getattr(item, "file_data", image)), + caption=getattr(item, "caption", item.content), + reply_markup=current_reply_markup + ) + + elif item.content_type == ContentTypes.FILE: + ret = self._bot.send_document(**kwargs, + document=(item.file_name, item.file_data), + caption=item.caption, + reply_markup=current_reply_markup + ) + + sent_idx.add(i) + + return ret + except Exception as e: + try: + raise RetryException(f"消息 [{i + 1}/{len(boxs)}] 发送失败") from e + except NameError: + raise RetryException("发送长消息失败") from e def register_commands(self, commands: Dict[str, dict]): """ diff --git a/app/schemas/message.py b/app/schemas/message.py index f3a66aec..187b9798 100644 --- a/app/schemas/message.py +++ b/app/schemas/message.py @@ -14,7 +14,7 @@ class CommingMessage(BaseModel): # 用户ID userid: Optional[Union[str, int]] = None # 用户名称 - username: Optional[str] = None + username: Optional[Union[str, int]] = None # 消息渠道 channel: Optional[MessageChannel] = None # 来源(渠道名称) diff --git a/version.py b/version.py index 36f5975d..65f64d57 100644 --- a/version.py +++ b/version.py @@ -1,2 +1,2 @@ -APP_VERSION = 'v2.8.5' -FRONTEND_VERSION = 'v2.8.5' +APP_VERSION = 'v2.8.6' +FRONTEND_VERSION = 'v2.8.6'