Files
MoviePilot/app/agent/tools/impl/search_web.py
2026-01-24 07:39:07 +08:00

183 lines
7.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import json
import re
from typing import Optional, Type, List, Dict
import httpx
from ddgs import DDGS
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.core.config import settings
from app.log import logger
# Search timeout in seconds; passed as the timeout to both the Tavily HTTP
# client and the DuckDuckGo (DDGS) client below.
SEARCH_TIMEOUT = 20
class SearchWebInput(BaseModel):
    """Input parameter model for the web search tool."""

    # Free-text justification supplied by the caller for invoking the tool.
    explanation: str = Field(
        ...,
        description=(
            "Clear explanation of why this tool is being used "
            "in the current context"
        ),
    )
    # The text that is sent to the search backend.
    query: str = Field(
        ...,
        description="The search query string to search for on the web",
    )
    # Requested result count; the tool itself clamps the value when it runs.
    max_results: Optional[int] = Field(
        default=5,
        description=(
            "Maximum number of search results to return "
            "(default: 5, max: 10)"
        ),
    )
class SearchWebTool(MoviePilotTool):
    """Agent tool that searches the web and returns the hits as a JSON string.

    Tavily is queried first when ``settings.TAVILY_API_KEY`` is set. DuckDuckGo
    (through ``ddgs``) is used whenever Tavily is not configured, fails, or
    returns no results.
    """

    name: str = "search_web"
    description: str = (
        "Search the web for information when you need to find current information, "
        "facts, or references that you're uncertain about. "
        "Returns search results with titles, snippets, and URLs. "
        "Use this tool to get up-to-date information from the internet."
    )
    args_schema: Type[BaseModel] = SearchWebInput

    def get_tool_message(self, **kwargs) -> Optional[str]:
        """Build a friendly progress message from the search arguments.

        Args:
            **kwargs: Tool arguments; ``query`` and ``max_results`` are read,
                defaulting to ``""`` and ``5`` when absent.

        Returns:
            The message describing the search being performed.
        """
        query = kwargs.get("query", "")
        max_results = kwargs.get("max_results", 5)
        return f"正在搜索网络内容: {query} (最多返回 {max_results} 条结果)"

    async def run(self, query: str, max_results: Optional[int] = 5, **kwargs) -> str:
        """Execute the web search.

        Args:
            query: The search query string.
            max_results: Requested number of results. A falsy value (``None``
                or ``0``) falls back to 5, and the value is clamped to 1..10.
            **kwargs: Extra tool arguments; not used here.

        Returns:
            A pretty-printed JSON document with the formatted results, or a
            plain-text message when nothing was found or the search failed.
        """
        logger.info(f"执行工具: {self.name}, 参数: query={query}, max_results={max_results}")
        try:
            # Clamp the requested result count to the supported range.
            max_results = min(max(1, max_results or 5), 10)
            results = []

            # 1. Prefer Tavily when an API key is configured.
            if settings.TAVILY_API_KEY:
                logger.info("使用 Tavily 进行搜索...")
                results = await self._search_tavily(query, max_results)

            # 2. Fall back to DuckDuckGo when Tavily is not configured or
            #    produced no results (its errors are reported as an empty list).
            if not results:
                logger.info("使用 DuckDuckGo 进行搜索...")
                results = await self._search_duckduckgo(query, max_results)

            if not results:
                return f"未找到与 '{query}' 相关的搜索结果"

            # Format and trim the results before serializing them.
            formatted_results = self._format_and_truncate_results(results, max_results)
            return json.dumps(formatted_results, ensure_ascii=False, indent=2)
        except Exception as e:
            # Tool boundary: report the failure as text instead of raising.
            error_message = f"搜索网络内容失败: {str(e)}"
            logger.error(f"搜索网络内容失败: {e}", exc_info=True)
            return error_message

    @staticmethod
    async def _search_tavily(query: str, max_results: int) -> List[Dict]:
        """Search with the Tavily API.

        Args:
            query: The search query string.
            max_results: Number of results requested from Tavily.

        Returns:
            A list of dicts with ``title``, ``snippet``, ``url`` and ``source``
            keys. Any error is logged as a warning and yields an empty list.
        """
        try:
            async with httpx.AsyncClient(timeout=SEARCH_TIMEOUT) as client:
                response = await client.post(
                    "https://api.tavily.com/search",
                    json={
                        "api_key": settings.TAVILY_API_KEY,
                        "query": query,
                        "search_depth": "basic",
                        "max_results": max_results,
                        "include_answer": False,
                        "include_images": False,
                        "include_raw_content": False,
                    }
                )
                response.raise_for_status()
                data = response.json()
                # ``or []`` also covers a payload where "results" is present
                # but null, which ``data.get("results", [])`` would not.
                return [
                    {
                        'title': result.get('title', ''),
                        'snippet': result.get('content', ''),
                        'url': result.get('url', ''),
                        'source': 'Tavily'
                    }
                    for result in (data.get("results") or [])
                ]
        except Exception as e:
            logger.warning(f"Tavily 搜索失败: {e}")
            return []

    @staticmethod
    def _get_proxy_url(proxy_setting) -> Optional[str]:
        """Extract a proxy URL from the proxy setting.

        Args:
            proxy_setting: Either a falsy value, a dict with ``http`` and/or
                ``https`` entries, or a value that is returned unchanged.

        Returns:
            ``None`` when no proxy is set; for a dict, the ``http`` entry or,
            failing that, the ``https`` entry; otherwise the setting itself.
        """
        if not proxy_setting:
            return None
        if isinstance(proxy_setting, dict):
            return proxy_setting.get('http') or proxy_setting.get('https')
        return proxy_setting

    async def _search_duckduckgo(self, query: str, max_results: int) -> List[Dict]:
        """Search with DuckDuckGo through the ``ddgs`` client.

        The DDGS client is driven from a synchronous helper that is run in the
        event loop's default executor.

        Args:
            query: The search query string.
            max_results: Number of results requested from DuckDuckGo.

        Returns:
            A list of dicts with ``title``, ``snippet``, ``url`` and ``source``
            keys. Errors are logged as warnings; results gathered before a
            failure inside the helper are still returned.
        """
        try:
            def sync_search():
                results = []
                ddgs_kwargs = {
                    'timeout': SEARCH_TIMEOUT
                }
                proxy_url = self._get_proxy_url(settings.PROXY)
                if proxy_url:
                    ddgs_kwargs['proxy'] = proxy_url

                try:
                    with DDGS(**ddgs_kwargs) as ddgs:
                        ddgs_gen = ddgs.text(
                            query,
                            max_results=max_results
                        )
                        if ddgs_gen:
                            for result in ddgs_gen:
                                results.append({
                                    'title': result.get('title', ''),
                                    'snippet': result.get('body', ''),
                                    'url': result.get('href', ''),
                                    'source': 'DuckDuckGo'
                                })
                except Exception as err:
                    # Best effort: keep whatever was collected so far.
                    logger.warning(f"DuckDuckGo search process failed: {err}")
                return results

            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, sync_search)
        except Exception as e:
            logger.warning(f"DuckDuckGo 搜索失败: {e}")
            return []

    @staticmethod
    def _format_and_truncate_results(results: List[Dict], max_results: int) -> Dict:
        """Format the search results and trim them to a bounded size.

        Args:
            results: Raw result dicts with optional ``title``, ``snippet``,
                ``url`` and ``source`` entries; ``None`` values are tolerated.
            max_results: Maximum number of entries to include.

        Returns:
            A dict with ``total_results`` (count before trimming), ``results``
            (ranked entries starting at 1) and, when entries were dropped, a
            ``note`` saying only the first ``max_results`` are shown.
        """
        max_title_length = 200
        # 500 characters of snippet to give the reader more context.
        max_snippet_length = 500

        formatted = {
            "total_results": len(results),
            "results": []
        }
        for idx, result in enumerate(results[:max_results], 1):
            # ``dict.get(key, "")`` still yields None for an explicit null, and
            # slicing or len() on None raises; coerce to text first.
            title = str(result.get("title") or "")[:max_title_length]

            # Collapse whitespace before truncating so the length limit
            # applies to the text that is actually emitted.
            snippet = re.sub(r'\s+', ' ', str(result.get("snippet") or "")).strip()
            if len(snippet) > max_snippet_length:
                snippet = snippet[:max_snippet_length].rstrip() + "..."

            formatted["results"].append({
                "rank": idx,
                "title": title,
                "snippet": snippet,
                "url": result.get("url") or "",
                "source": result.get("source") or "Unknown"
            })

        if len(results) > max_results:
            formatted["note"] = f"仅显示前 {max_results} 条结果。"
        return formatted