Merge remote-tracking branch 'origin/v2' into v2

This commit is contained in:
jxxghp
2025-11-24 21:30:07 +08:00
14 changed files with 681 additions and 437 deletions

View File

@@ -1,7 +1,6 @@
"""MoviePilot AI智能体实现"""
import asyncio
import threading
from typing import Dict, List, Any
from langchain.agents import AgentExecutor, create_openai_tools_agent
@@ -21,9 +20,6 @@ from app.helper.message import MessageHelper
from app.log import logger
from app.schemas import Notification
# 用于保护环境变量修改的线程锁
_env_lock = threading.Lock()
class AgentChain(ChainBase):
pass
@@ -75,36 +71,21 @@ class MoviePilotAgent:
api_key = settings.LLM_API_KEY
if provider == "google":
import os
from contextlib import contextmanager
from langchain_google_genai import ChatGoogleGenerativeAI
# 使用线程锁保护的临时环境变量配置
@contextmanager
def _temp_proxy_env():
"""线程安全的临时设置代理环境变量的上下文管理器"""
with _env_lock:
old_http = os.environ.get("HTTP_PROXY")
old_https = os.environ.get("HTTPS_PROXY")
try:
if settings.PROXY_HOST:
os.environ["HTTP_PROXY"] = settings.PROXY_HOST
os.environ["HTTPS_PROXY"] = settings.PROXY_HOST
yield
finally:
# 恢复原始环境变量
if old_http is not None:
os.environ["HTTP_PROXY"] = old_http
elif "HTTP_PROXY" in os.environ:
del os.environ["HTTP_PROXY"]
if old_https is not None:
os.environ["HTTPS_PROXY"] = old_https
elif "HTTPS_PROXY" in os.environ:
del os.environ["HTTPS_PROXY"]
# 在临时环境变量中初始化 ChatGoogleGenerativeAI
with _temp_proxy_env():
if settings.PROXY_HOST:
from langchain_openai import ChatOpenAI
return ChatOpenAI(
model=settings.LLM_MODEL,
api_key=api_key,
max_retries=3,
base_url="https://generativelanguage.googleapis.com/v1beta/openai",
temperature=settings.LLM_TEMPERATURE,
streaming=True,
callbacks=[self.callback_handler],
stream_usage=True,
openai_proxy=settings.PROXY_HOST
)
else:
from langchain_google_genai import ChatGoogleGenerativeAI
return ChatGoogleGenerativeAI(
model=settings.LLM_MODEL,
google_api_key=api_key,

View File

@@ -21,6 +21,8 @@ from app.agent.tools.impl.query_popular_subscribes import QueryPopularSubscribes
from app.agent.tools.impl.query_subscribe_history import QuerySubscribeHistoryTool
from app.agent.tools.impl.delete_subscribe import DeleteSubscribeTool
from app.agent.tools.impl.search_media import SearchMediaTool
from app.agent.tools.impl.search_person import SearchPersonTool
from app.agent.tools.impl.search_person_credits import SearchPersonCreditsTool
from app.agent.tools.impl.recognize_media import RecognizeMediaTool
from app.agent.tools.impl.scrape_metadata import ScrapeMetadataTool
from app.agent.tools.impl.query_episode_schedule import QueryEpisodeScheduleTool
@@ -53,6 +55,8 @@ class MoviePilotToolFactory:
tools = []
tool_definitions = [
SearchMediaTool,
SearchPersonTool,
SearchPersonCreditsTool,
RecognizeMediaTool,
ScrapeMetadataTool,
QueryEpisodeScheduleTool,

View File

@@ -33,6 +33,8 @@ class AddSubscribeInput(BaseModel):
description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')")
filter_groups: Optional[List[str]] = Field(None,
description="List of filter rule group names to apply (optional, use query_rule_groups tool to get available rule groups)")
sites: Optional[List[int]] = Field(None,
description="List of site IDs to search from (optional, use query_sites tool to get available site IDs)")
class AddSubscribeTool(MoviePilotTool):
@@ -61,12 +63,13 @@ class AddSubscribeTool(MoviePilotTool):
season: Optional[int] = None, tmdb_id: Optional[str] = None,
start_episode: Optional[int] = None, total_episode: Optional[int] = None,
quality: Optional[str] = None, resolution: Optional[str] = None,
effect: Optional[str] = None, filter_groups: Optional[List[str]] = None, **kwargs) -> str:
effect: Optional[str] = None, filter_groups: Optional[List[str]] = None,
sites: Optional[List[int]] = None, **kwargs) -> str:
logger.info(
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, "
f"season={season}, tmdb_id={tmdb_id}, start_episode={start_episode}, "
f"total_episode={total_episode}, quality={quality}, resolution={resolution}, "
f"effect={effect}, filter_groups={filter_groups}")
f"effect={effect}, filter_groups={filter_groups}, sites={sites}")
try:
subscribe_chain = SubscribeChain()
@@ -92,6 +95,8 @@ class AddSubscribeTool(MoviePilotTool):
subscribe_kwargs['effect'] = effect
if filter_groups:
subscribe_kwargs['filter_groups'] = filter_groups
if sites:
subscribe_kwargs['sites'] = sites
sid, message = await subscribe_chain.async_add(
mtype=MediaType(media_type),
@@ -118,6 +123,8 @@ class AddSubscribeTool(MoviePilotTool):
params.append(f"特效过滤: {effect}")
if filter_groups:
params.append(f"规则组: {', '.join(filter_groups)}")
if sites:
params.append(f"站点: {', '.join(map(str, sites))}")
if params:
result_msg += f"\n配置参数: {', '.join(params)}"
return result_msg

View File

@@ -1,7 +1,7 @@
"""查询媒体库工具"""
import json
from typing import Optional, List, Type
from typing import Optional, Type
from pydantic import BaseModel, Field
@@ -9,7 +9,6 @@ from app.agent.tools.base import MoviePilotTool
from app.chain.mediaserver import MediaServerChain
from app.core.context import MediaInfo
from app.log import logger
from app.schemas import MediaServerItem
from app.schemas.types import MediaType

View File

@@ -0,0 +1,83 @@
"""搜索人物工具"""
import json
from typing import Optional, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.chain.media import MediaChain
from app.log import logger
class SearchPersonInput(BaseModel):
"""搜索人物工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
name: str = Field(..., description="The name of the person to search for (e.g., 'Tom Hanks', '周杰伦')")
class SearchPersonTool(MoviePilotTool):
name: str = "search_person"
description: str = "Search for person information including actors, directors, etc. Supports searching by name. Returns detailed person information from TMDB, Douban, or Bangumi database."
args_schema: Type[BaseModel] = SearchPersonInput
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据搜索参数生成友好的提示消息"""
name = kwargs.get("name", "")
return f"正在搜索人物: {name}"
async def run(self, name: str, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: name={name}")
try:
media_chain = MediaChain()
# 使用 MediaChain.async_search_persons 方法搜索人物
persons = await media_chain.async_search_persons(name=name)
if persons:
# 限制最多30条结果
total_count = len(persons)
limited_persons = persons[:30]
# 精简字段,只保留关键信息
simplified_results = []
for person in limited_persons:
simplified = {
"name": person.name,
"id": person.id,
"source": person.source,
"profile_path": person.profile_path,
"original_name": person.original_name,
"known_for_department": person.known_for_department,
"popularity": person.popularity,
"biography": person.biography[:200] + "..." if person.biography and len(person.biography) > 200 else person.biography,
"birthday": person.birthday,
"deathday": person.deathday,
"place_of_birth": person.place_of_birth,
"gender": person.gender,
"imdb_id": person.imdb_id,
"also_known_as": person.also_known_as[:5] if person.also_known_as else [], # 限制别名数量
}
# 添加豆瓣特有字段
if person.source == "douban":
simplified["url"] = person.url
simplified["avatar"] = person.avatar
simplified["latin_name"] = person.latin_name
simplified["roles"] = person.roles[:5] if person.roles else [] # 限制角色数量
# 添加Bangumi特有字段
if person.source == "bangumi":
simplified["career"] = person.career
simplified["relation"] = person.relation
simplified_results.append(simplified)
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
# 如果结果被裁剪,添加提示信息
if total_count > 30:
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 30 条结果。\n\n{result_json}"
return result_json
else:
return f"未找到相关人物信息: {name}"
except Exception as e:
error_message = f"搜索人物失败: {str(e)}"
logger.error(f"搜索人物失败: {e}", exc_info=True)
return error_message

View File

@@ -0,0 +1,85 @@
"""搜索演员参演作品工具"""
import json
from typing import Optional, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.chain.douban import DoubanChain
from app.chain.tmdb import TmdbChain
from app.chain.bangumi import BangumiChain
from app.log import logger
class SearchPersonCreditsInput(BaseModel):
"""搜索演员参演作品工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
person_id: int = Field(..., description="The ID of the person/actor to search for credits (e.g., 31 for Tom Hanks in TMDB)")
source: str = Field(..., description="The data source: 'tmdb' for TheMovieDB, 'douban' for Douban, 'bangumi' for Bangumi")
page: Optional[int] = Field(1, description="Page number for pagination (default: 1)")
class SearchPersonCreditsTool(MoviePilotTool):
name: str = "search_person_credits"
description: str = "Search for films and TV shows that a person/actor has appeared in (filmography). Supports searching by person ID from TMDB, Douban, or Bangumi database. Returns a list of media works the person has participated in."
args_schema: Type[BaseModel] = SearchPersonCreditsInput
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据搜索参数生成友好的提示消息"""
person_id = kwargs.get("person_id", "")
source = kwargs.get("source", "")
return f"正在搜索人物参演作品: {source} ID {person_id}"
async def run(self, person_id: int, source: str, page: Optional[int] = 1, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: person_id={person_id}, source={source}, page={page}")
try:
# 根据source选择相应的chain
if source.lower() == "tmdb":
tmdb_chain = TmdbChain()
medias = await tmdb_chain.async_person_credits(person_id=person_id, page=page)
elif source.lower() == "douban":
douban_chain = DoubanChain()
medias = await douban_chain.async_person_credits(person_id=person_id, page=page)
elif source.lower() == "bangumi":
bangumi_chain = BangumiChain()
medias = await bangumi_chain.async_person_credits(person_id=person_id)
else:
return f"不支持的数据源: {source}。支持的数据源: tmdb, douban, bangumi"
if medias:
# 限制最多30条结果
total_count = len(medias)
limited_medias = medias[:30]
# 精简字段,只保留关键信息
simplified_results = []
for media in limited_medias:
simplified = {
"title": media.title,
"en_title": media.en_title,
"year": media.year,
"type": media.type.value if media.type else None,
"season": media.season,
"tmdb_id": media.tmdb_id,
"imdb_id": media.imdb_id,
"douban_id": media.douban_id,
"overview": media.overview[:200] + "..." if media.overview and len(media.overview) > 200 else media.overview,
"vote_average": media.vote_average,
"poster_path": media.poster_path,
"backdrop_path": media.backdrop_path,
"detail_link": media.detail_link
}
simplified_results.append(simplified)
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
# 如果结果被裁剪,添加提示信息
if total_count > 30:
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 30 条结果。\n\n{result_json}"
return result_json
else:
return f"未找到人物 ID {person_id} ({source}) 的参演作品"
except Exception as e:
error_message = f"搜索演员参演作品失败: {str(e)}"
logger.error(f"搜索演员参演作品失败: {e}", exc_info=True)
return error_message

View File

@@ -64,7 +64,8 @@ class SearchWebTool(MoviePilotTool):
logger.error(f"搜索网络内容失败: {e}", exc_info=True)
return error_message
async def _search_duckduckgo_api(self, query: str, max_results: int) -> list:
@staticmethod
async def _search_duckduckgo_api(query: str, max_results: int) -> list:
"""
使用DuckDuckGo API进行搜索
@@ -143,7 +144,8 @@ class SearchWebTool(MoviePilotTool):
logger.warning(f"DuckDuckGo API搜索失败: {e}")
return []
def _format_and_truncate_results(self, results: list, max_results: int) -> dict:
@staticmethod
def _format_and_truncate_results(results: list, max_results: int) -> dict:
"""
格式化并裁剪搜索结果以避免占用过多上下文

View File

@@ -85,25 +85,26 @@ async def search(title: str,
return obj.get("source")
return obj.source
result = []
media_chain = MediaChain()
if type == "media":
_, medias = await media_chain.async_search(title=title)
if medias:
result = [media.to_dict() for media in medias]
result = [media.to_dict() for media in medias] if medias else []
elif type == "collection":
result = await media_chain.async_search_collections(name=title)
else:
result = await media_chain.async_search_persons(name=title)
if result:
# 按设置的顺序对结果进行排序
setting_order = settings.SEARCH_SOURCE.split(',') or []
sort_order = {}
for index, source in enumerate(setting_order):
sort_order[source] = index
result = sorted(result, key=lambda x: sort_order.get(__get_source(x), 4))
return result[(page - 1) * count:page * count]
return []
collections = await media_chain.async_search_collections(name=title)
result = [collection.to_dict() for collection in collections] if collections else []
else: # person
persons = await media_chain.async_search_persons(name=title)
result = [person.model_dump() for person in persons] if persons else []
if not result:
return []
# 排序和分页
setting_order = settings.SEARCH_SOURCE.split(',') if settings.SEARCH_SOURCE else []
sort_order = {source: index for index, source in enumerate(setting_order)}
sorted_result = sorted(result, key=lambda x: sort_order.get(__get_source(x), 4))
return sorted_result[(page - 1) * count:page * count]
@router.post("/scrape/{storage}", summary="刮削媒体信息", response_model=schemas.Response)

View File

@@ -164,19 +164,15 @@ class MessageChain(ChainBase):
)
# 处理消息
if text.startswith('CALLBACK:'):
# 处理按钮回调(适配支持回调的渠
# 处理按钮回调(适配支持回调的渠),优先级最高
if ChannelCapabilityManager.supports_callbacks(channel):
self._handle_callback(text=text, channel=channel, source=source,
userid=userid, username=username,
original_message_id=original_message_id, original_chat_id=original_chat_id)
else:
logger.warning(f"渠道 {channel.value} 不支持回调,但收到了回调消息:{text}")
elif text.startswith('/ai') or text.startswith('/AI'):
# AI智能体处理
self._handle_ai_message(text=text, channel=channel, source=source,
userid=userid, username=username)
elif text.startswith('/'):
# 执行命令
elif text.startswith('/') and not text.lower().startswith('/ai'):
# 执行特定命令命令(但不是/ai
self.eventmanager.send_event(
EventType.CommandExcute,
{
@@ -186,266 +182,226 @@ class MessageChain(ChainBase):
"source": source
}
)
elif text.isdigit():
# 用户选择了具体的条目
# 缓存
cache_data: dict = user_cache.get(userid).copy()
# 选择项目
if not cache_data \
or not cache_data.get('items') \
or len(cache_data.get('items')) < int(text):
# 发送消息
self.post_message(Notification(channel=channel, source=source, title="输入有误!", userid=userid))
return
try:
# 选择的序号
_choice = int(text) + _current_page * self._page_size - 1
# 缓存类型
cache_type: str = cache_data.get('type')
# 缓存列表
cache_list: list = cache_data.get('items').copy()
# 选择
elif text.lower().startswith('/ai'):
# 用户指定AI智能体消息响应
self._handle_ai_message(text=text, channel=channel, source=source,
userid=userid, username=username)
elif settings.AI_AGENT_ENABLE and settings.AI_AGENT_GLOBAL:
# 普通消息,全局智能体响应
self._handle_ai_message(text=text, channel=channel, source=source,
userid=userid, username=username)
else:
# 非智能体普通消息响应
if text.isdigit():
# 用户选择了具体的条目
# 缓存
cache_data: dict = user_cache.get(userid).copy()
# 选择项目
if not cache_data \
or not cache_data.get('items') \
or len(cache_data.get('items')) < int(text):
# 发送消息
self.post_message(Notification(channel=channel, source=source, title="输入有误!", userid=userid))
return
try:
if cache_type in ["Search", "ReSearch"]:
# 当前媒体信息
mediainfo: MediaInfo = cache_list[_choice]
_current_media = mediainfo
# 查询缺失的媒体信息
exist_flag, no_exists = DownloadChain().get_no_exists_info(meta=_current_meta,
mediainfo=_current_media)
if exist_flag and cache_type == "Search":
# 媒体库中已存在
# 选择的序号
_choice = int(text) + _current_page * self._page_size - 1
# 缓存类型
cache_type: str = cache_data.get('type')
# 缓存列表
cache_list: list = cache_data.get('items').copy()
# 选择
try:
if cache_type in ["Search", "ReSearch"]:
# 当前媒体信息
mediainfo: MediaInfo = cache_list[_choice]
_current_media = mediainfo
# 查询缺失的媒体信息
exist_flag, no_exists = DownloadChain().get_no_exists_info(meta=_current_meta,
mediainfo=_current_media)
if exist_flag and cache_type == "Search":
# 媒体库中已存在
self.post_message(
Notification(channel=channel,
source=source,
title=f"{_current_media.title_year}"
f"{_current_meta.sea} 媒体库中已存在,如需重新下载请发送:搜索 名称 或 下载 名称】",
userid=userid))
return
elif exist_flag:
# 没有缺失,但要全量重新搜索和下载
no_exists = self.__get_noexits_info(_current_meta, _current_media)
# 发送缺失的媒体信息
messages = []
if no_exists and cache_type == "Search":
# 发送缺失消息
mediakey = mediainfo.tmdb_id or mediainfo.douban_id
messages = [
f"{sea} 季缺失 {StringUtils.str_series(no_exist.episodes) if no_exist.episodes else no_exist.total_episode}"
for sea, no_exist in no_exists.get(mediakey).items()]
elif no_exists:
# 发送总集数的消息
mediakey = mediainfo.tmdb_id or mediainfo.douban_id
messages = [
f"{sea} 季总 {no_exist.total_episode}"
for sea, no_exist in no_exists.get(mediakey).items()]
if messages:
self.post_message(Notification(channel=channel,
source=source,
title=f"{mediainfo.title_year}\n" + "\n".join(messages),
userid=userid))
# 搜索种子,过滤掉不需要的剧集,以便选择
logger.info(f"开始搜索 {mediainfo.title_year} ...")
self.post_message(
Notification(channel=channel,
source=source,
title=f"{_current_media.title_year}"
f"{_current_meta.sea} 媒体库中已存在,如需重新下载请发送:搜索 名称 或 下载 名称】",
title=f"开始搜索 {mediainfo.type.value} {mediainfo.title_year} ...",
userid=userid))
return
elif exist_flag:
# 没有缺失,但要全量重新搜索和下载
no_exists = self.__get_noexits_info(_current_meta, _current_media)
# 发送缺失的媒体信息
messages = []
if no_exists and cache_type == "Search":
# 发送缺失消息
mediakey = mediainfo.tmdb_id or mediainfo.douban_id
messages = [
f"{sea} 季缺失 {StringUtils.str_series(no_exist.episodes) if no_exist.episodes else no_exist.total_episode}"
for sea, no_exist in no_exists.get(mediakey).items()]
elif no_exists:
# 发送总集数的消息
mediakey = mediainfo.tmdb_id or mediainfo.douban_id
messages = [
f"{sea} 季总 {no_exist.total_episode}"
for sea, no_exist in no_exists.get(mediakey).items()]
if messages:
self.post_message(Notification(channel=channel,
source=source,
title=f"{mediainfo.title_year}\n" + "\n".join(messages),
userid=userid))
# 搜索种子,过滤掉不需要的剧集,以便选择
logger.info(f"开始搜索 {mediainfo.title_year} ...")
self.post_message(
Notification(channel=channel,
source=source,
title=f"开始搜索 {mediainfo.type.value} {mediainfo.title_year} ...",
userid=userid))
# 开始搜索
contexts = SearchChain().process(mediainfo=mediainfo,
no_exists=no_exists)
if not contexts:
# 没有数据
self.post_message(Notification(
channel=channel,
source=source,
title=f"{mediainfo.title}"
f"{_current_meta.sea} 未搜索到需要的资源!",
userid=userid))
return
# 搜索结果排序
contexts = TorrentHelper().sort_torrents(contexts)
try:
# 判断是否设置自动下载
auto_download_user = settings.AUTO_DOWNLOAD_USER
# 匹配到自动下载用户
if auto_download_user \
and (auto_download_user == "all"
or any(userid == user for user in auto_download_user.split(","))):
logger.info(f"用户 {userid} 在自动下载用户中,开始自动择优下载 ...")
# 自动选择下载
self.__auto_download(channel=channel,
source=source,
cache_list=contexts,
userid=userid,
username=username,
no_exists=no_exists)
else:
# 更新缓存
user_cache[userid] = {
"type": "Torrent",
"items": contexts
}
_current_page = 0
# 保存缓存
self.save_cache(user_cache, self._cache_file)
# 删除原消息
if (original_message_id and original_chat_id and
ChannelCapabilityManager.supports_deletion(channel)):
self.delete_message(
channel=channel,
source=source,
message_id=original_message_id,
chat_id=original_chat_id
)
# 发送种子数据
logger.info(f"搜索到 {len(contexts)} 条数据,开始发送选择消息 ...")
self.__post_torrents_message(channel=channel,
source=source,
title=mediainfo.title,
items=contexts[:self._page_size],
userid=userid,
total=len(contexts))
finally:
contexts.clear()
del contexts
elif cache_type in ["Subscribe", "ReSubscribe"]:
# 订阅或洗版媒体
mediainfo: MediaInfo = cache_list[_choice]
# 洗版标识
best_version = False
# 查询缺失的媒体信息
if cache_type == "Subscribe":
exist_flag, _ = DownloadChain().get_no_exists_info(meta=_current_meta,
mediainfo=mediainfo)
if exist_flag:
# 开始搜索
contexts = SearchChain().process(mediainfo=mediainfo,
no_exists=no_exists)
if not contexts:
# 没有数据
self.post_message(Notification(
channel=channel,
source=source,
title=f"{mediainfo.title_year}"
f"{_current_meta.sea} 媒体库中已存在,如需洗版请发送:洗版 XXX】",
title=f"{mediainfo.title}"
f"{_current_meta.sea} 未搜索到需要的资源!",
userid=userid))
return
else:
best_version = True
# 转换用户名
mp_name = UserOper().get_name(
**{f"{channel.name.lower()}_userid": userid}) if channel else None
# 添加订阅状态为N
SubscribeChain().add(title=mediainfo.title,
year=mediainfo.year,
mtype=mediainfo.type,
tmdbid=mediainfo.tmdb_id,
season=_current_meta.begin_season,
channel=channel,
source=source,
userid=userid,
username=mp_name or username,
best_version=best_version)
elif cache_type == "Torrent":
if int(text) == 0:
# 自动选择下载,强制下载模式
self.__auto_download(channel=channel,
# 搜索结果排序
contexts = TorrentHelper().sort_torrents(contexts)
try:
# 判断是否设置自动下载
auto_download_user = settings.AUTO_DOWNLOAD_USER
# 匹配到自动下载用户
if auto_download_user \
and (auto_download_user == "all"
or any(userid == user for user in auto_download_user.split(","))):
logger.info(f"用户 {userid} 在自动下载用户中,开始自动择优下载 ...")
# 自动选择下载
self.__auto_download(channel=channel,
source=source,
cache_list=contexts,
userid=userid,
username=username,
no_exists=no_exists)
else:
# 更新缓存
user_cache[userid] = {
"type": "Torrent",
"items": contexts
}
_current_page = 0
# 保存缓存
self.save_cache(user_cache, self._cache_file)
# 删除原消息
if (original_message_id and original_chat_id and
ChannelCapabilityManager.supports_deletion(channel)):
self.delete_message(
channel=channel,
source=source,
message_id=original_message_id,
chat_id=original_chat_id
)
# 发送种子数据
logger.info(f"搜索到 {len(contexts)} 条数据,开始发送选择消息 ...")
self.__post_torrents_message(channel=channel,
source=source,
title=mediainfo.title,
items=contexts[:self._page_size],
userid=userid,
total=len(contexts))
finally:
contexts.clear()
del contexts
elif cache_type in ["Subscribe", "ReSubscribe"]:
# 订阅或洗版媒体
mediainfo: MediaInfo = cache_list[_choice]
# 洗版标识
best_version = False
# 查询缺失的媒体信息
if cache_type == "Subscribe":
exist_flag, _ = DownloadChain().get_no_exists_info(meta=_current_meta,
mediainfo=mediainfo)
if exist_flag:
self.post_message(Notification(
channel=channel,
source=source,
title=f"{mediainfo.title_year}"
f"{_current_meta.sea} 媒体库中已存在,如需洗版请发送:洗版 XXX】",
userid=userid))
return
else:
best_version = True
# 转换用户名
mp_name = UserOper().get_name(
**{f"{channel.name.lower()}_userid": userid}) if channel else None
# 添加订阅状态为N
SubscribeChain().add(title=mediainfo.title,
year=mediainfo.year,
mtype=mediainfo.type,
tmdbid=mediainfo.tmdb_id,
season=_current_meta.begin_season,
channel=channel,
source=source,
cache_list=cache_list,
userid=userid,
username=username)
else:
# 下载种子
context: Context = cache_list[_choice]
# 下载
DownloadChain().download_single(context, channel=channel, source=source,
userid=userid, username=username)
username=mp_name or username,
best_version=best_version)
elif cache_type == "Torrent":
if int(text) == 0:
# 自动选择下载,强制下载模式
self.__auto_download(channel=channel,
source=source,
cache_list=cache_list,
userid=userid,
username=username)
else:
# 下载种子
context: Context = cache_list[_choice]
# 下载
DownloadChain().download_single(context, channel=channel, source=source,
userid=userid, username=username)
finally:
cache_list.clear()
del cache_list
finally:
cache_list.clear()
del cache_list
finally:
cache_data.clear()
del cache_data
elif text.lower() == "p":
# 上一页
cache_data: dict = user_cache.get(userid).copy()
if not cache_data:
# 没有缓存
self.post_message(Notification(
channel=channel, source=source, title="输入有误!", userid=userid))
return
try:
if _current_page == 0:
# 第一页
cache_data.clear()
del cache_data
elif text.lower() == "p":
# 上一页
cache_data: dict = user_cache.get(userid).copy()
if not cache_data:
# 没有缓存
self.post_message(Notification(
channel=channel, source=source, title="已经是第一页了", userid=userid))
channel=channel, source=source, title="输入有误", userid=userid))
return
# 减一页
_current_page -= 1
cache_type: str = cache_data.get('type')
# 产生副本,避免修改原值
cache_list: list = cache_data.get('items').copy()
try:
if _current_page == 0:
start = 0
end = self._page_size
else:
start = _current_page * self._page_size
end = start + self._page_size
if cache_type == "Torrent":
# 发送种子数据
self.__post_torrents_message(channel=channel,
source=source,
title=_current_media.title,
items=cache_list[start:end],
userid=userid,
total=len(cache_list),
original_message_id=original_message_id,
original_chat_id=original_chat_id)
else:
# 发送媒体数据
self.__post_medias_message(channel=channel,
source=source,
title=_current_meta.name,
items=cache_list[start:end],
userid=userid,
total=len(cache_list),
original_message_id=original_message_id,
original_chat_id=original_chat_id)
finally:
cache_list.clear()
del cache_list
finally:
cache_data.clear()
del cache_data
elif text.lower() == "n":
# 下一页
cache_data: dict = user_cache.get(userid).copy()
if not cache_data:
# 没有缓存
self.post_message(Notification(
channel=channel, source=source, title="输入有误!", userid=userid))
return
try:
cache_type: str = cache_data.get('type')
# 产生副本,避免修改原值
cache_list: list = cache_data.get('items').copy()
total = len(cache_list)
# 加一页
cache_list = cache_list[(_current_page + 1) * self._page_size:(_current_page + 2) * self._page_size]
if not cache_list:
# 没有数据
self.post_message(Notification(
channel=channel, source=source, title="已经是最后一页了!", userid=userid))
return
else:
# 第一页
self.post_message(Notification(
channel=channel, source=source, title="已经是第一页了!", userid=userid))
return
# 减一页
_current_page -= 1
cache_type: str = cache_data.get('type')
# 产生副本,避免修改原值
cache_list: list = cache_data.get('items').copy()
try:
# 加一页
_current_page += 1
if _current_page == 0:
start = 0
end = self._page_size
else:
start = _current_page * self._page_size
end = start + self._page_size
if cache_type == "Torrent":
# 发送种子数据
self.__post_torrents_message(channel=channel,
source=source,
title=_current_media.title,
items=cache_list,
items=cache_list[start:end],
userid=userid,
total=total,
total=len(cache_list),
original_message_id=original_message_id,
original_chat_id=original_chat_id)
else:
@@ -453,94 +409,144 @@ class MessageChain(ChainBase):
self.__post_medias_message(channel=channel,
source=source,
title=_current_meta.name,
items=cache_list,
items=cache_list[start:end],
userid=userid,
total=total,
total=len(cache_list),
original_message_id=original_message_id,
original_chat_id=original_chat_id)
finally:
cache_list.clear()
del cache_list
finally:
cache_data.clear()
del cache_data
else:
# 搜索或订阅
if text.startswith("订阅"):
# 订阅
content = re.sub(r"订阅[:\s]*", "", text)
action = "Subscribe"
elif text.startswith("洗版"):
# 洗版
content = re.sub(r"洗版[:\s]*", "", text)
action = "ReSubscribe"
elif text.startswith("搜索") or text.startswith("下载"):
# 重新搜索/下载
content = re.sub(r"(搜索|下载)[:\s]*", "", text)
action = "ReSearch"
elif text.startswith("#") \
or re.search(r"^请[问帮你]", text) \
or re.search(r"[?]$", text) \
or StringUtils.count_words(text) > 10 \
or text.find("继续") != -1:
# 聊天
content = text
action = "Chat"
elif StringUtils.is_link(text):
# 链接
content = text
action = "Link"
else:
# 搜索
content = text
action = "Search"
if action in ["Search", "ReSearch", "Subscribe", "ReSubscribe"]:
# 搜索
meta, medias = MediaChain().search(content)
# 识别
if not meta.name:
self.post_message(Notification(
channel=channel, source=source, title="无法识别输入内容!", userid=userid))
return
# 开始搜索
if not medias:
self.post_message(Notification(
channel=channel, source=source, title=f"{meta.name} 没有找到对应的媒体信息!",
userid=userid))
return
logger.info(f"搜索到 {len(medias)} 条相关媒体信息")
try:
# 记录当前状态
_current_meta = meta
# 保存缓存
user_cache[userid] = {
'type': action,
'items': medias
}
self.save_cache(user_cache, self._cache_file)
_current_page = 0
_current_media = None
# 发送媒体列表
self.__post_medias_message(channel=channel,
source=source,
title=meta.name,
items=medias[:self._page_size],
userid=userid, total=len(medias))
finally:
medias.clear()
del medias
cache_data.clear()
del cache_data
elif text.lower() == "n":
# 下一页
cache_data: dict = user_cache.get(userid).copy()
if not cache_data:
# 没有缓存
self.post_message(Notification(
channel=channel, source=source, title="输入有误!", userid=userid))
return
try:
cache_type: str = cache_data.get('type')
# 产生副本,避免修改原值
cache_list: list = cache_data.get('items').copy()
total = len(cache_list)
# 加一页
cache_list = cache_list[(_current_page + 1) * self._page_size:(_current_page + 2) * self._page_size]
if not cache_list:
# 没有数据
self.post_message(Notification(
channel=channel, source=source, title="已经是最后一页了!", userid=userid))
return
else:
try:
# 加一页
_current_page += 1
if cache_type == "Torrent":
# 发送种子数据
self.__post_torrents_message(channel=channel,
source=source,
title=_current_media.title,
items=cache_list,
userid=userid,
total=total,
original_message_id=original_message_id,
original_chat_id=original_chat_id)
else:
# 发送媒体数据
self.__post_medias_message(channel=channel,
source=source,
title=_current_meta.name,
items=cache_list,
userid=userid,
total=total,
original_message_id=original_message_id,
original_chat_id=original_chat_id)
finally:
cache_list.clear()
del cache_list
finally:
cache_data.clear()
del cache_data
else:
# 广播事件
self.eventmanager.send_event(
EventType.UserMessage,
{
"text": content,
"userid": userid,
"channel": channel,
"source": source
}
)
# 搜索或订阅
if text.startswith("订阅"):
# 订阅
content = re.sub(r"订阅[:\s]*", "", text)
action = "Subscribe"
elif text.startswith("洗版"):
# 洗版
content = re.sub(r"洗版[:\s]*", "", text)
action = "ReSubscribe"
elif text.startswith("搜索") or text.startswith("下载"):
# 重新搜索/下载
content = re.sub(r"(搜索|下载)[:\s]*", "", text)
action = "ReSearch"
elif text.startswith("#") \
or re.search(r"^请[问帮你]", text) \
or re.search(r"[?]$", text) \
or StringUtils.count_words(text) > 10 \
or text.find("继续") != -1:
# 聊天
content = text
action = "Chat"
elif StringUtils.is_link(text):
# 链接
content = text
action = "Link"
else:
# 搜索
content = text
action = "Search"
if action in ["Search", "ReSearch", "Subscribe", "ReSubscribe"]:
# 搜索
meta, medias = MediaChain().search(content)
# 识别
if not meta.name:
self.post_message(Notification(
channel=channel, source=source, title="无法识别输入内容!", userid=userid))
return
# 开始搜索
if not medias:
self.post_message(Notification(
channel=channel, source=source, title=f"{meta.name} 没有找到对应的媒体信息!",
userid=userid))
return
logger.info(f"搜索到 {len(medias)} 条相关媒体信息")
try:
# 记录当前状态
_current_meta = meta
# 保存缓存
user_cache[userid] = {
'type': action,
'items': medias
}
self.save_cache(user_cache, self._cache_file)
_current_page = 0
_current_media = None
# 发送媒体列表
self.__post_medias_message(channel=channel,
source=source,
title=meta.name,
items=medias[:self._page_size],
userid=userid, total=len(medias))
finally:
medias.clear()
del medias
else:
# 广播事件
self.eventmanager.send_event(
EventType.UserMessage,
{
"text": content,
"userid": userid,
"channel": channel,
"source": source
}
)
finally:
user_cache.clear()
del user_cache
@@ -926,7 +932,10 @@ class MessageChain(ChainBase):
return
# 提取用户消息
user_message = text[3:].strip() # 移除 "/ai" 前缀
if text.lower().startswith("/ai"):
user_message = text[3:].strip() # 移除 "/ai" 前缀(大小写不敏感)
else:
user_message = text.strip() # 按原消息处理
if not user_message:
self.post_message(Notification(
channel=channel,

View File

@@ -1024,13 +1024,11 @@ def fresh(fresh: bool = True):
with fresh():
result = some_cached_function()
"""
token = _fresh.set(fresh)
logger.debug(f"Setting fresh mode to {fresh}. {id(token):#x}")
token = _fresh.set(fresh or is_fresh())
try:
yield
finally:
_fresh.reset(token)
logger.debug(f"Reset fresh mode. {id(token):#x}")
@asynccontextmanager
async def async_fresh(fresh: bool = True):
@@ -1041,13 +1039,11 @@ async def async_fresh(fresh: bool = True):
async with async_fresh():
result = await some_async_cached_function()
"""
token = _fresh.set(fresh)
logger.debug(f"Setting async_fresh mode to {fresh}. {id(token):#x}")
token = _fresh.set(fresh or is_fresh())
try:
yield
finally:
_fresh.reset(token)
logger.debug(f"Reset async_fresh mode. {id(token):#x}")
def is_fresh() -> bool:
"""

View File

@@ -411,6 +411,8 @@ class ConfigModel(BaseModel):
# ==================== AI智能体配置 ====================
# AI智能体开关
AI_AGENT_ENABLE: bool = False
# 合局AI智能体
AI_AGENT_GLOBAL: bool = False
# LLM提供商 (openai/google/deepseek)
LLM_PROVIDER: str = "deepseek"
# LLM模型名称

View File

@@ -1,19 +1,21 @@
import asyncio
import re
import threading
import uuid
from pathlib import Path
from threading import Event
from typing import Optional, List, Dict, Callable
from urllib.parse import urljoin
import telebot
from telegramify_markdown import standardize, telegramify
from telegramify_markdown.type import ContentTypes, SentType
from telebot import apihelper
from telebot.types import InputFile, InlineKeyboardMarkup, InlineKeyboardButton
from telebot.types import InlineKeyboardMarkup, InlineKeyboardButton
from telebot.types import InputMediaPhoto
from app.core.config import settings
from app.core.context import MediaInfo, Context
from app.core.metainfo import MetaInfo
from app.helper.thread import ThreadHelper
from app.log import logger
from app.utils.common import retry
from app.utils.http import RequestUtils
@@ -52,7 +54,7 @@ class Telegram:
else:
apihelper.proxy = settings.PROXY
# bot
_bot = telebot.TeleBot(self._telegram_token, parse_mode="Markdown")
_bot = telebot.TeleBot(self._telegram_token, parse_mode="MarkdownV2")
# 记录句柄
self._bot = _bot
# 获取并存储bot用户名用于@检测
@@ -236,12 +238,14 @@ class Telegram:
return False
try:
if text:
# 对text进行Markdown特殊字符转义
text = re.sub(r"([_`])", r"\\\1", text)
caption = f"*{title}*\n{text}"
if title and text:
caption = f"**{title}**\n{text}"
elif title:
caption = f"**{title}**"
elif text:
caption = text
else:
caption = f"*{title}*"
caption = ""
if link:
caption = f"{caption}\n[查看详情]({link})"
@@ -499,7 +503,7 @@ class Telegram:
if image:
# 如果有图片使用edit_message_media
media = InputMediaPhoto(media=image, caption=text, parse_mode="Markdown")
media = InputMediaPhoto(media=image, caption=standardize(text), parse_mode="MarkdownV2")
self._bot.edit_message_media(
chat_id=chat_id,
message_id=message_id,
@@ -511,8 +515,8 @@ class Telegram:
self._bot.edit_message_text(
chat_id=chat_id,
message_id=message_id,
text=text,
parse_mode="Markdown",
text=standardize(text),
parse_mode="MarkdownV2",
reply_markup=reply_markup
)
return True
@@ -520,49 +524,120 @@ class Telegram:
logger.error(f"编辑消息失败:{str(e)}")
return False
@retry(RetryException, logger=logger)
def __send_request(self, userid: Optional[str] = None, image="", caption="",
reply_markup: Optional[InlineKeyboardMarkup] = None) -> bool:
"""
向Telegram发送报文
:param reply_markup: 内联键盘
"""
if image:
res = RequestUtils(proxies=settings.PROXY, ua=settings.NORMAL_USER_AGENT).get_res(image)
if res is None:
raise Exception("获取图片失败")
if res.content:
# 使用随机标识构建图片文件的完整路径,并写入图片内容到文件
image_file = Path(settings.TEMP_PATH) / "telegram" / str(uuid.uuid4())
if not image_file.parent.exists():
image_file.parent.mkdir(parents=True, exist_ok=True)
image_file.write_bytes(res.content)
photo = InputFile(image_file)
# 发送图片到Telegram
ret = self._bot.send_photo(chat_id=userid or self._telegram_chat_id,
photo=photo,
caption=caption,
parse_mode="Markdown",
reply_markup=reply_markup)
if ret is None:
raise RetryException("发送图片消息失败")
return True
# 按4096分段循环发送消息
ret = None
if len(caption) > 4095:
for i in range(0, len(caption), 4095):
ret = self._bot.send_message(chat_id=userid or self._telegram_chat_id,
text=caption[i:i + 4095],
parse_mode="Markdown",
reply_markup=reply_markup if i == 0 else None)
else:
ret = self._bot.send_message(chat_id=userid or self._telegram_chat_id,
text=caption,
parse_mode="Markdown",
reply_markup=reply_markup)
if ret is None:
raise RetryException("发送文本消息失败")
return True if ret else False
kwargs = {
'chat_id': userid or self._telegram_chat_id,
'parse_mode': "MarkdownV2",
'reply_markup': reply_markup
}
try:
# 处理图片
image = self.__process_image(image) if image else None
# 图片消息的标题长度限制为1024文本消息为4096
caption_limit = 1024 if image else 4096
if len(caption) < caption_limit:
ret = self.__send_short_message(image, caption, **kwargs)
else:
sent_idx = set()
ret = self.__send_long_message(image, caption, sent_idx, **kwargs)
return ret is not None
except Exception as e:
logger.error(f"发送Telegram消息失败: {e}")
return False
@retry(RetryException, logger=logger)
def __process_image(self, image_url: str) -> bytes:
"""
处理图片URL获取图片内容
"""
try:
res = RequestUtils(
proxies=settings.PROXY,
ua=settings.NORMAL_USER_AGENT
).get_res(image_url)
if not res or not res.content:
raise RetryException("获取图片失败")
return res.content
except Exception as e:
raise
@retry(RetryException, logger=logger)
def __send_short_message(self, image: Optional[bytes], caption: str, **kwargs):
"""
发送短消息
"""
try:
if image:
return self._bot.send_photo(
photo=image,
caption=standardize(caption),
**kwargs
)
else:
return self._bot.send_message(
text=standardize(caption),
**kwargs
)
except Exception as e:
raise RetryException(f"发送{'图片' if image else '文本'}消息失败")
@retry(RetryException, logger=logger)
def __send_long_message(self, image: Optional[bytes], caption: str, sent_idx: set, **kwargs):
"""
发送长消息
"""
try:
reply_markup = kwargs.pop("reply_markup", None)
boxs: SentType = ThreadHelper().submit(lambda x: asyncio.run(telegramify(x)), caption).result()
ret = None
for i, item in enumerate(boxs):
if i in sent_idx:
# 跳过已发送消息
continue
current_reply_markup = reply_markup if i == 0 else None
if item.content_type == ContentTypes.TEXT and (i != 0 or not image):
ret = self._bot.send_message(**kwargs,
text=item.content,
reply_markup=current_reply_markup
)
elif item.content_type == ContentTypes.PHOTO or (image and i == 0):
ret = self._bot.send_photo(**kwargs,
photo=(getattr(item, "file_name", ""),
getattr(item, "file_data", image)),
caption=getattr(item, "caption", item.content),
reply_markup=current_reply_markup
)
elif item.content_type == ContentTypes.FILE:
ret = self._bot.send_document(**kwargs,
document=(item.file_name, item.file_data),
caption=item.caption,
reply_markup=current_reply_markup
)
sent_idx.add(i)
return ret
except Exception as e:
try:
raise RetryException(f"消息 [{i + 1}/{len(boxs)}] 发送失败") from e
except NameError:
raise RetryException("发送长消息失败") from e
def register_commands(self, commands: Dict[str, dict]):
"""

View File

@@ -14,7 +14,7 @@ class CommingMessage(BaseModel):
# 用户ID
userid: Optional[Union[str, int]] = None
# 用户名称
username: Optional[str] = None
username: Optional[Union[str, int]] = None
# 消息渠道
channel: Optional[MessageChannel] = None
# 来源(渠道名称)