feat(search): add AI recommendation and refine related logic

PKC278
2026-01-15 02:49:29 +08:00
parent 91354295f2
commit 95f2ac3811
10 changed files with 543 additions and 72 deletions

View File

@@ -16,6 +16,7 @@ from app.agent.prompt import PromptManager
from app.agent.tools.factory import MoviePilotToolFactory
from app.chain import ChainBase
from app.core.config import settings
from app.helper.llm import LLMHelper
from app.helper.message import MessageHelper
from app.log import logger
from app.schemas import Notification
@@ -64,57 +65,7 @@ class MoviePilotAgent:
def _initialize_llm(self):
"""初始化LLM模型"""
- provider = settings.LLM_PROVIDER.lower()
- api_key = settings.LLM_API_KEY
- if provider == "google":
- if settings.PROXY_HOST:
- from langchain_openai import ChatOpenAI
- return ChatOpenAI(
- model=settings.LLM_MODEL,
- api_key=api_key,
- max_retries=3,
- base_url="https://generativelanguage.googleapis.com/v1beta/openai",
- temperature=settings.LLM_TEMPERATURE,
- streaming=True,
- callbacks=[self.callback_handler],
- stream_usage=True,
- openai_proxy=settings.PROXY_HOST
- )
- else:
- from langchain_google_genai import ChatGoogleGenerativeAI
- return ChatGoogleGenerativeAI(
- model=settings.LLM_MODEL,
- google_api_key=api_key,
- max_retries=3,
- temperature=settings.LLM_TEMPERATURE,
- streaming=True,
- callbacks=[self.callback_handler]
- )
- elif provider == "deepseek":
- from langchain_deepseek import ChatDeepSeek
- return ChatDeepSeek(
- model=settings.LLM_MODEL,
- api_key=api_key,
- max_retries=3,
- temperature=settings.LLM_TEMPERATURE,
- streaming=True,
- callbacks=[self.callback_handler],
- stream_usage=True
- )
- else:
- from langchain_openai import ChatOpenAI
- return ChatOpenAI(
- model=settings.LLM_MODEL,
- api_key=api_key,
- max_retries=3,
- base_url=settings.LLM_BASE_URL,
- temperature=settings.LLM_TEMPERATURE,
- streaming=True,
- callbacks=[self.callback_handler],
- stream_usage=True,
- openai_proxy=settings.PROXY_HOST
- )
+ return LLMHelper.get_llm(streaming=True, callbacks=[self.callback_handler])
def _initialize_tools(self) -> List:
"""初始化工具列表"""

View File

@@ -10,6 +10,7 @@ from app.agent.tools.base import MoviePilotTool
from app.chain.search import SearchChain
from app.log import logger
from app.schemas.types import MediaType
from app.utils.string import StringUtils
class SearchTorrentsInput(BaseModel):
@@ -99,7 +100,7 @@ class SearchTorrentsTool(MoviePilotTool):
if t.torrent_info:
simplified["torrent_info"] = {
"title": t.torrent_info.title,
"size": t.torrent_info.size,
"size": StringUtils.format_size(t.torrent_info.size),
"seeders": t.torrent_info.seeders,
"peers": t.torrent_info.peers,
"site_name": t.torrent_info.site_name,

View File

@@ -1,14 +1,16 @@
from typing import List, Any, Optional
- from fastapi import APIRouter, Depends
+ from fastapi import APIRouter, Depends, Body
from app import schemas
from app.chain.media import MediaChain
from app.chain.search import SearchChain
from app.chain.ai_recommend import AIRecommendChain
from app.core.config import settings
from app.core.event import eventmanager
from app.core.metainfo import MetaInfo
from app.core.security import verify_token
from app.log import logger
from app.schemas import MediaRecognizeConvertEventData
from app.schemas.types import MediaType, ChainEventType
@@ -36,6 +38,9 @@ async def search_by_id(mediaid: str,
"""
Exact search of site resources by TMDB/Douban/Bangumi ID (prefixes: tmdb:/douban:/bangumi:)
"""
# Cancel any running AI recommendation and clear its database cache
AIRecommendChain().cancel_ai_recommend()
if mtype:
media_type = MediaType(mtype)
else:
@@ -159,6 +164,9 @@ async def search_by_title(keyword: Optional[str] = None,
"""
Fuzzy search of site resources by title, with pagination; returns homepage resources when the keyword is empty
"""
# Cancel any running AI recommendation and clear its database cache
AIRecommendChain().cancel_ai_recommend()
torrents = await SearchChain().async_search_by_title(
title=keyword, page=page,
sites=[int(site) for site in sites.split(",") if site] if sites else None,
@@ -167,3 +175,82 @@ async def search_by_title(keyword: Optional[str] = None,
if not torrents:
return schemas.Response(success=False, message="未搜索到任何资源")
return schemas.Response(success=True, data=[torrent.to_dict() for torrent in torrents])
@router.post("/recommend", summary="AI推荐资源", response_model=schemas.Response)
async def recommend_search_results(
filtered_indices: Optional[List[int]] = Body(None, embed=True, description="筛选后的索引列表"),
check_only: bool = Body(False, embed=True, description="仅检查状态,不启动新任务"),
force: bool = Body(False, embed=True, description="强制重新推荐,清除旧结果"),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
AI推荐资源 - 轮询接口
前端轮询此接口,发送筛选后的索引(如果有筛选)
后端根据请求变化自动取消旧任务并启动新任务
参数:
- filtered_indices: 筛选后的索引列表(可选,为空或不提供时使用所有结果)
- check_only: 仅检查状态(首次打开页面时使用,避免触发不必要的重新推理)
- force: 强制重新推荐(清除旧结果并重新启动)
返回数据结构:
{
"success": bool,
"message": string, // 错误信息(仅在错误时存在)
"data": {
"status": string, // 状态: disabled | idle | running | completed | error
"results": array // 推荐结果仅status=completed时存在
}
}
"""
# Fetch the last search results from cache
results = await SearchChain().async_last_search_results() or []
if not results:
return schemas.Response(success=False, message="没有可用的搜索结果", data={
"status": "error"
})
recommend_chain = AIRecommendChain()
# Force mode: cancel and clear the old results first, then start a new task directly
if force:
logger.info("收到新推荐请求,清除旧结果并启动新任务")
recommend_chain.cancel_ai_recommend()
recommend_chain.start_recommend_task(filtered_indices, len(results), results)
# Return the running status immediately
return schemas.Response(success=True, data={
"status": "running"
})
# Check-only mode: filtered_indices is not passed, to avoid triggering request-change detection
if check_only:
# Return the current status without starting or cancelling any task
current_status = recommend_chain.get_current_status_only()
# On error, surface the error message via "message"
if current_status.get("status") == "error":
error_msg = current_status.pop("error", "未知错误")
return schemas.Response(success=False, message=error_msg, data=current_status)
return schemas.Response(success=True, data=current_status)
# Get the current status (this detects whether the request has changed)
status_data = recommend_chain.get_status(filtered_indices, len(results))
# If the feature is disabled, return the disabled status directly
if status_data.get("status") == "disabled":
return schemas.Response(success=True, data=status_data)
# If idle, start a new task
if status_data["status"] == "idle":
recommend_chain.start_recommend_task(filtered_indices, len(results), results)
# Return the running status immediately
return schemas.Response(success=True, data={
"status": "running"
})
# On error, surface the error message via "message"
if status_data.get("status") == "error":
error_msg = status_data.pop("error", "未知错误")
return schemas.Response(success=False, message=error_msg, data=status_data)
# Return the current status
return schemas.Response(success=True, data=status_data)
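A minimal polling-client sketch against this endpoint (illustrative only: the full route prefix, port, and auth header below are assumptions, since the diff shows just the /recommend sub-path):

import time
import requests

RECOMMEND_URL = "http://localhost:3001/api/v1/search/recommend"  # assumed prefix/port
HEADERS = {"Authorization": "Bearer <access-token>"}             # assumed auth scheme

def poll_recommend(filtered_indices=None, interval=2.0, timeout=120.0):
    """Poll until the recommendation completes, errors out, or times out."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        resp = requests.post(RECOMMEND_URL, headers=HEADERS,
                             json={"filtered_indices": filtered_indices})
        body = resp.json()
        status = (body.get("data") or {}).get("status")
        if status == "completed":
            return body["data"]["results"]     # indices into the search results
        if status in ("error", "disabled"):
            raise RuntimeError(body.get("message") or status)
        time.sleep(interval)                   # idle/running: poll again
    raise TimeoutError("AI recommendation did not finish in time")

On the first call with a new filter the backend answers "running" and starts the task; identical follow-up calls hash to the same request and eventually return "completed" with the index list.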

View File

@@ -130,18 +130,40 @@ async def cache_img(
def get_global_setting(token: str):
"""
Query non-sensitive system settings (default auth)
Contains only the fields required for pre-login UI initialization
"""
if token != "moviepilot":
raise HTTPException(status_code=403, detail="Forbidden")
- # Whitelist mode: only fields required by frontend business logic
+ # Whitelist mode: only fields required for pre-login UI initialization
info = settings.model_dump(
include={
"TMDB_IMAGE_DOMAIN",
"GLOBAL_IMAGE_CACHE",
"ADVANCED_MODE",
}
)
# Append version info (for version checks)
info.update({
"FRONTEND_VERSION": SystemChain.get_frontend_version(),
"BACKEND_VERSION": APP_VERSION
})
return schemas.Response(success=True,
data=info)
@router.get("/global/user", summary="查询用户相关系统设置", response_model=schemas.Response)
async def get_user_global_setting(_: User = Depends(get_current_active_user_async)):
"""
Query user-related system settings (fetched after login)
Contains business-feature configuration and the user's permission info
"""
# Business-feature configuration fields
info = settings.model_dump(
include={
"RECOGNIZE_SOURCE",
"SEARCH_SOURCE"
"SEARCH_SOURCE",
"AI_RECOMMEND_ENABLED"
}
)
# Append the user's unique ID and subscription share-management permissions
@@ -150,8 +172,6 @@ def get_global_setting(token: str):
"USER_UNIQUE_ID": SubscribeHelper().get_user_uuid(),
"SUBSCRIBE_SHARE_MANAGE": share_admin,
"WORKFLOW_SHARE_MANAGE": share_admin,
"FRONTEND_VERSION": SystemChain.get_frontend_version(),
"BACKEND_VERSION": APP_VERSION
})
return schemas.Response(success=True,
data=info)

app/chain/ai_recommend.py (new file, +310 lines)
View File

@@ -0,0 +1,310 @@
import re
from typing import List, Optional, Dict, Any
import asyncio
import hashlib
import json
from app.chain import ChainBase
from app.core.config import settings
from app.log import logger
from app.utils.common import log_execution_time
from app.utils.singleton import Singleton
from app.utils.string import StringUtils
class AIRecommendChain(ChainBase, metaclass=Singleton):
"""
AI recommendation chain (runs as a singleton)
Provides AI-based recommendations on top of search results
"""
# Cache file name
__ai_indices_cache_file = "__ai_recommend_indices__"
# AI recommendation state
_ai_recommend_running = False
_ai_recommend_task: Optional[asyncio.Task] = None
_current_request_hash: Optional[str] = None  # Hash of the current request
_ai_recommend_result: Optional[List[int]] = None  # Cached AI recommendation result (index list)
_ai_recommend_error: Optional[str] = None  # AI recommendation error message
@staticmethod
def _calculate_request_hash(
filtered_indices: Optional[List[int]], search_results_count: int
) -> str:
"""
Compute a hash of the request, used to detect whether the request has changed
"""
request_data = {
"filtered_indices": filtered_indices or [],
"search_results_count": search_results_count,
}
return hashlib.md5(
json.dumps(request_data, sort_keys=True).encode()
).hexdigest()
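# Illustrative note: the digest is deterministic for identical payloads and
# order-sensitive for the index list, e.g. ([1, 2], 10) and ([2, 1], 10) hash
# differently, so reordering the filter counts as a new request.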
def _build_status(self) -> Dict[str, Any]:
"""
Build the AI recommendation status dict
:return: the status dict
"""
if not settings.AI_RECOMMEND_ENABLED:
return {"status": "disabled"}
if self._ai_recommend_running:
return {"status": "running"}
# Try loading cached results from the database
if self._ai_recommend_result is None:
cached_indices = self.load_cache(self.__ai_indices_cache_file)
if cached_indices is not None:
self._ai_recommend_result = cached_indices
# Whenever results exist, always return completed along with the data
if self._ai_recommend_result is not None:
return {"status": "completed", "results": self._ai_recommend_result}
if self._ai_recommend_error is not None:
return {"status": "error", "error": self._ai_recommend_error}
return {"status": "idle"}
def get_current_status_only(self) -> Dict[str, Any]:
"""
Get the current status without validating the hash (for check_only mode)
"""
return self._build_status()
def get_status(
self, filtered_indices: Optional[List[int]], search_results_count: int
) -> Dict[str, Any]:
"""
Get the AI recommendation status, checking whether the request has changed (for first requests or force mode)
Returns idle if the request (i.e. the filter conditions) has changed
"""
# Compute the hash of the current request
request_hash = self._calculate_request_hash(
filtered_indices, search_results_count
)
# Check whether the request has changed
is_same_request = request_hash == self._current_request_hash
# If the request changed (filter conditions differ), return idle
if not is_same_request:
return (
{"status": "idle"}
if settings.AI_RECOMMEND_ENABLED
else {"status": "disabled"}
)
# Request unchanged: return the actual current status
return self._build_status()
@log_execution_time(logger=logger)
async def async_ai_recommend(self, items: List[str], preference: str = None) -> str:
"""
AI recommendation
:param items: candidate resource list (JSON strings)
:param preference: user preference (optional)
:return: the raw recommendation response from the AI
"""
# Mark as running
self._ai_recommend_running = True
try:
# Import LLMHelper
from app.helper.llm import LLMHelper
# Get an LLM instance
llm = LLMHelper.get_llm()
# Build the prompt
user_preference = (
preference
or settings.AI_RECOMMEND_USER_PREFERENCE
or "Prefer high-quality resources with more seeders"
)
# Add the task instruction
instruction = """
Task: Select the best matching items from the list based on user preferences.
Each item contains:
- index: Item number
- title: Full torrent title
- size: File size
- seeders: Number of seeders
Output Format: Return ONLY a JSON array of "index" numbers (e.g., [0, 3, 1]). Do NOT include any explanations or other text.
"""
message = (
f"User Preference: {user_preference}\n{instruction}\nCandidate Resources:\n"
+ "\n".join(items)
)
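# Illustrative example (hypothetical data): each candidate line appended above
# looks like
#   {"index": 0, "title": "Some.Movie.2024.2160p.WEB-DL", "size": "12.50 GB", "seeders": 128}
# and a well-formed model reply is a bare JSON array such as [0, 3, 1].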
# Invoke the LLM
response = await llm.ainvoke(message)
return response.content
except ValueError as e:
logger.error(f"AI推荐配置错误: {e}")
raise
except Exception:
# Re-raise unchanged; the task wrapper records the error message
raise
finally:
# Clear the running state
self._ai_recommend_running = False
self._ai_recommend_task = None
def is_ai_recommend_running(self) -> bool:
"""
Check whether an AI recommendation is currently running
"""
return self._ai_recommend_running
def cancel_ai_recommend(self):
"""
Cancel the running AI recommendation task and clear all cached state
"""
if self._ai_recommend_task and not self._ai_recommend_task.done():
self._ai_recommend_task.cancel()
self._ai_recommend_running = False
self._ai_recommend_task = None
self._current_request_hash = None
self._ai_recommend_result = None
self._ai_recommend_error = None
self.remove_cache(self.__ai_indices_cache_file)
def start_recommend_task(
self,
filtered_indices: Optional[List[int]],
search_results_count: int,
results: List[Any],
) -> None:
"""
Start the AI recommendation task
:param filtered_indices: filtered index list
:param search_results_count: total number of search results
:param results: search result list
"""
# Compute the new request's hash
new_request_hash = self._calculate_request_hash(
filtered_indices, search_results_count
)
# If the request changed, cancel the old task
if new_request_hash != self._current_request_hash:
self.cancel_ai_recommend()
# Update the request hash
self._current_request_hash = new_request_hash
# Reset state
self._ai_recommend_result = None
self._ai_recommend_error = None
# Start the new task
async def run_recommend():
# Capture the current task object for comparison in finally
current_task = asyncio.current_task()
try:
self._ai_recommend_running = True
# Prepare the data
items = []
valid_indices = []
max_items = settings.AI_RECOMMEND_MAX_ITEMS or 50
# If filter indices were provided, filter the results first; otherwise use all results
if filtered_indices is not None and len(filtered_indices) > 0:
results_to_process = [
results[i]
for i in filtered_indices
if 0 <= i < len(results)
]
else:
results_to_process = results
for i, torrent in enumerate(results_to_process):
if len(items) >= max_items:
break
if not torrent.torrent_info:
continue
valid_indices.append(i)
item_info = {
"index": i,
"title": torrent.torrent_info.title or "未知",
"size": (
StringUtils.format_size(torrent.torrent_info.size)
if torrent.torrent_info.size
else "0 B"
),
"seeders": torrent.torrent_info.seeders or 0,
}
items.append(json.dumps(item_info, ensure_ascii=False))
if not items:
self._ai_recommend_error = "没有可用于AI推荐的资源"
return
# Invoke the AI recommendation
ai_response = await self.async_ai_recommend(items)
# Parse the indices returned by the AI
try:
# Extract the JSON array via regex (non-greedy, to avoid spanning multiple arrays)
json_match = re.search(r'\[.*?\]', ai_response, re.DOTALL)
if not json_match:
raise ValueError(ai_response)
ai_indices = json.loads(json_match.group())
if not isinstance(ai_indices, list):
raise ValueError(f"AI返回格式错误: {ai_response}")
# Map back to original indices. The model echoes the "index" field values,
# which are positions within results_to_process, so they are translated
# through the (validated) filter list rather than through valid_indices.
valid_set = set(valid_indices)
if filtered_indices:
kept = [idx for idx in filtered_indices if 0 <= idx < len(results)]
original_indices = [
kept[i] for i in ai_indices
if isinstance(i, int) and i in valid_set
]
else:
original_indices = [
i for i in ai_indices
if isinstance(i, int) and i in valid_set
]
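# Worked example (hypothetical): with filtered_indices=[5, 9, 12], where the
# item at original index 9 lacked torrent_info, items carried index 0 and 2;
# an AI reply of [2, 0] maps back to original indices [12, 5].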
# Return only the index list, not the full data
self._ai_recommend_result = original_indices
# Persist to the database
self.save_cache(original_indices, self.__ai_indices_cache_file)
logger.info(f"AI推荐完成: {len(original_indices)}")
except Exception as e:
logger.error(
f"解析AI返回结果失败: {e}, 原始响应: {ai_response}"
)
self._ai_recommend_error = str(e)
except asyncio.CancelledError:
logger.info("AI推荐任务被取消")
except Exception as e:
logger.error(f"AI推荐任务失败: {e}")
self._ai_recommend_error = str(e)
finally:
# Clean up only while self._ai_recommend_task still points at this task;
# if it was cancelled and a new task started, self._ai_recommend_task
# already references the new task and must not be reset
if self._ai_recommend_task == current_task:
self._ai_recommend_running = False
self._ai_recommend_task = None
# Create and start the task
self._ai_recommend_task = asyncio.create_task(run_recommend())
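# Note: asyncio.create_task requires an already-running event loop; that holds
# here because start_recommend_task is only called from async FastAPI handlers
# in this commit.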

View File

@@ -29,6 +29,7 @@ class SearchChain(ChainBase):
"""
__result_temp_file = "__search_result__"
__ai_result_temp_file = "__ai_search_result__"
def search_by_id(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
mtype: MediaType = None, area: Optional[str] = "title", season: Optional[int] = None,
@@ -98,6 +99,18 @@ class SearchChain(ChainBase):
"""
return await self.async_load_cache(self.__result_temp_file)
async def async_last_ai_results(self) -> Optional[List[Context]]:
"""
Asynchronously fetch the last AI recommendation results
"""
return await self.async_load_cache(self.__ai_result_temp_file)
async def async_save_ai_results(self, results: List[Context]):
"""
Asynchronously save AI recommendation results
"""
await self.async_save_cache(results, self.__ai_result_temp_file)
async def async_search_by_id(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
mtype: MediaType = None, area: Optional[str] = "title", season: Optional[int] = None,
sites: List[int] = None, cache_local: bool = False) -> List[Context]:

View File

@@ -489,20 +489,18 @@ class SiteChain(ChainBase):
logger.warn(f"站点 {domain} 索引器不存在!")
return
- # Query the cached site icon
- site_icon = siteoper.get_icon_by_domain(domain)
- if not site_icon or not site_icon.base64:
- logger.info(f"开始缓存站点 {indexer.get('name')} 图标 ...")
- icon_url, icon_base64 = self.__parse_favicon(url=indexer.get("domain"),
- cookie=cookie,
- ua=settings.USER_AGENT)
- if icon_url:
- siteoper.update_icon(name=indexer.get("name"),
- domain=domain,
- icon_url=icon_url,
- icon_base64=icon_base64)
- logger.info(f"缓存站点 {indexer.get('name')} 图标成功")
- else:
- logger.warn(f"缓存站点 {indexer.get('name')} 图标失败")
+ logger.info(f"开始缓存站点 {indexer.get('name')} 图标 ...")
+ icon_url, icon_base64 = self.__parse_favicon(url=indexer.get("domain"),
+ cookie=cookie,
+ ua=settings.USER_AGENT)
+ if icon_url:
+ siteoper.update_icon(name=indexer.get("name"),
+ domain=domain,
+ icon_url=icon_url,
+ icon_base64=icon_base64)
+ logger.info(f"缓存站点 {indexer.get('name')} 图标成功")
+ else:
+ logger.warn(f"缓存站点 {indexer.get('name')} 图标失败")
@eventmanager.register(EventType.SiteUpdated)
def clear_site_data(self, event: Event):

View File

@@ -439,6 +439,12 @@ class ConfigModel(BaseModel):
LLM_MEMORY_RETENTION_DAYS: int = 1
# Redis memory retention days (when Redis is used)
LLM_REDIS_MEMORY_RETENTION_DAYS: int = 7
# Whether AI recommendation is enabled
AI_RECOMMEND_ENABLED: bool = False
# AI recommendation user preference
AI_RECOMMEND_USER_PREFERENCE: str = ""
# Maximum number of candidate items for AI recommendation
AI_RECOMMEND_MAX_ITEMS: int = 50
class Settings(BaseSettings, ConfigModel, LogConfigModel):
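A quick sketch of toggling the new flags for local testing (illustrative; assumes the standard pydantic BaseSettings environment-variable loading, and all values are hypothetical):

import os

# Must run before app.core.config instantiates Settings, so BaseSettings picks these up
os.environ["AI_RECOMMEND_ENABLED"] = "true"
os.environ["AI_RECOMMEND_USER_PREFERENCE"] = "prefer 2160p remuxes with many seeders"
os.environ["AI_RECOMMEND_MAX_ITEMS"] = "30"   # cap prompt size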

View File

@@ -1,12 +1,76 @@
"""LLM模型相关辅助功能"""
from typing import List
from typing import List, Optional
from app.core.config import settings
from app.log import logger
class LLMHelper:
"""LLM模型相关辅助功能"""
@staticmethod
def get_llm(streaming: bool = False, callbacks: Optional[list] = None):
"""
Get an LLM instance
:param streaming: whether to enable streaming output
:param callbacks: list of callback handlers
:return: LLM instance
"""
provider = settings.LLM_PROVIDER.lower()
api_key = settings.LLM_API_KEY
if not api_key:
raise ValueError("未配置LLM API Key")
if provider == "google":
if settings.PROXY_HOST:
from langchain_openai import ChatOpenAI
return ChatOpenAI(
model=settings.LLM_MODEL,
api_key=api_key,
max_retries=3,
base_url="https://generativelanguage.googleapis.com/v1beta/openai",
temperature=settings.LLM_TEMPERATURE,
streaming=streaming,
callbacks=callbacks,
stream_usage=True,
openai_proxy=settings.PROXY_HOST
)
else:
from langchain_google_genai import ChatGoogleGenerativeAI
return ChatGoogleGenerativeAI(
model=settings.LLM_MODEL,
google_api_key=api_key,
max_retries=3,
temperature=settings.LLM_TEMPERATURE,
streaming=streaming,
callbacks=callbacks
)
elif provider == "deepseek":
from langchain_deepseek import ChatDeepSeek
return ChatDeepSeek(
model=settings.LLM_MODEL,
api_key=api_key,
max_retries=3,
temperature=settings.LLM_TEMPERATURE,
streaming=streaming,
callbacks=callbacks,
stream_usage=True
)
else:
from langchain_openai import ChatOpenAI
return ChatOpenAI(
model=settings.LLM_MODEL,
api_key=api_key,
max_retries=3,
base_url=settings.LLM_BASE_URL,
temperature=settings.LLM_TEMPERATURE,
streaming=streaming,
callbacks=callbacks,
stream_usage=True,
openai_proxy=settings.PROXY_HOST
)
def get_models(self, provider: str, api_key: str, base_url: str = None) -> List[str]:
"""获取模型列表"""
logger.info(f"获取 {provider} 模型列表...")

View File

@@ -242,6 +242,27 @@ class StringUtils:
else:
return size + "B"
@staticmethod
def format_size(size_bytes: int) -> str:
"""
Convert a byte count to a human-readable string
"""
if not size_bytes:
return "0 B"
units = ["B", "KB", "MB", "GB", "TB", "PB"]
size = float(size_bytes)
unit_index = 0
while size >= 1024 and unit_index < len(units) - 1:
size /= 1024
unit_index += 1
# Keep two decimal places
if unit_index == 0:
return f"{int(size)} {units[unit_index]}"
return f"{size:.2f} {units[unit_index]}"
@staticmethod
def url_equal(url1: str, url2: str) -> bool:
"""