diff --git a/app/api/endpoints/system.py b/app/api/endpoints/system.py index 6f8c7c81..b7904750 100644 --- a/app/api/endpoints/system.py +++ b/app/api/endpoints/system.py @@ -23,8 +23,11 @@ from app.core.module import ModuleManager from app.core.security import verify_apitoken, verify_resource_token, verify_token from app.db.models import User from app.db.systemconfig_oper import SystemConfigOper -from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async, \ - get_current_active_user_async +from app.db.user_oper import ( + get_current_active_superuser, + get_current_active_superuser_async, + get_current_active_user_async, +) from app.helper.llm import LLMHelper from app.helper.mediaserver import MediaServerHelper from app.helper.message import MessageHelper @@ -47,12 +50,13 @@ router = APIRouter() async def fetch_image( - url: str, - proxy: Optional[bool] = None, - use_cache: bool = False, - if_none_match: Optional[str] = None, - cookies: Optional[str | dict] = None, - allowed_domains: Optional[set[str]] = None) -> Optional[Response]: + url: str, + proxy: Optional[bool] = None, + use_cache: bool = False, + if_none_match: Optional[str] = None, + cookies: Optional[str | dict] = None, + allowed_domains: Optional[set[str]] = None, +) -> Optional[Response]: """ 处理图片缓存逻辑,支持HTTP缓存和磁盘缓存 """ @@ -83,47 +87,57 @@ async def fetch_image( return Response( content=content, media_type=UrlUtils.get_mime_type(url, "image/jpeg"), - headers=headers + headers=headers, ) @router.get("/img/{proxy}", summary="图片代理") async def proxy_img( - imgurl: str, - proxy: bool = False, - cache: bool = False, - use_cookies: bool = False, - if_none_match: Annotated[str | None, Header()] = None, - _: schemas.TokenPayload = Depends(verify_resource_token) + imgurl: str, + proxy: bool = False, + cache: bool = False, + use_cookies: bool = False, + if_none_match: Annotated[str | None, Header()] = None, + _: schemas.TokenPayload = Depends(verify_resource_token), ) -> Response: """ 图片代理,可选是否使用代理服务器,支持 HTTP 缓存 """ # 媒体服务器添加图片代理支持 - hosts = [config.config.get("host") for config in MediaServerHelper().get_configs().values() if - config and config.config and config.config.get("host")] + hosts = [ + config.config.get("host") + for config in MediaServerHelper().get_configs().values() + if config and config.config and config.config.get("host") + ] allowed_domains = set(settings.SECURITY_IMAGE_DOMAINS) | set(hosts) cookies = ( MediaServerChain().get_image_cookies(server=None, image_url=imgurl) if use_cookies else None ) - return await fetch_image(url=imgurl, proxy=proxy, use_cache=cache, cookies=cookies, - if_none_match=if_none_match, allowed_domains=allowed_domains) + return await fetch_image( + url=imgurl, + proxy=proxy, + use_cache=cache, + cookies=cookies, + if_none_match=if_none_match, + allowed_domains=allowed_domains, + ) @router.get("/cache/image", summary="图片缓存") async def cache_img( - url: str, - if_none_match: Annotated[str | None, Header()] = None, - _: schemas.TokenPayload = Depends(verify_resource_token) + url: str, + if_none_match: Annotated[str | None, Header()] = None, + _: schemas.TokenPayload = Depends(verify_resource_token), ) -> Response: """ 本地缓存图片文件,支持 HTTP 缓存,如果启用全局图片缓存,则使用磁盘缓存 """ # 如果没有启用全局图片缓存,则不使用磁盘缓存 - return await fetch_image(url=url, use_cache=settings.GLOBAL_IMAGE_CACHE, - if_none_match=if_none_match) + return await fetch_image( + url=url, use_cache=settings.GLOBAL_IMAGE_CACHE, if_none_match=if_none_match + ) @router.get("/global", summary="查询非敏感系统设置", response_model=schemas.Response) @@ -144,15 +158,18 @@ def get_global_setting(token: str): } ) # 追加版本信息(用于版本检查) - info.update({ - "FRONTEND_VERSION": SystemChain.get_frontend_version(), - "BACKEND_VERSION": APP_VERSION - }) - return schemas.Response(success=True, - data=info) + info.update( + { + "FRONTEND_VERSION": SystemChain.get_frontend_version(), + "BACKEND_VERSION": APP_VERSION, + } + ) + return schemas.Response(success=True, data=info) -@router.get("/global/user", summary="查询用户相关系统设置", response_model=schemas.Response) +@router.get( + "/global/user", summary="查询用户相关系统设置", response_model=schemas.Response +) async def get_user_global_setting(_: User = Depends(get_current_active_user_async)): """ 查询用户相关系统设置(登录后获取) @@ -164,7 +181,7 @@ async def get_user_global_setting(_: User = Depends(get_current_active_user_asyn "RECOGNIZE_SOURCE", "SEARCH_SOURCE", "AI_RECOMMEND_ENABLED", - "PASSKEY_ALLOW_REGISTER_WITHOUT_OTP" + "PASSKEY_ALLOW_REGISTER_WITHOUT_OTP", } ) # 智能助手总开关未开启,智能推荐状态强制返回False @@ -173,13 +190,14 @@ async def get_user_global_setting(_: User = Depends(get_current_active_user_asyn # 追加用户唯一ID和订阅分享管理权限 share_admin = SubscribeHelper().is_admin_user() - info.update({ - "USER_UNIQUE_ID": SubscribeHelper().get_user_uuid(), - "SUBSCRIBE_SHARE_MANAGE": share_admin, - "WORKFLOW_SHARE_MANAGE": share_admin, - }) - return schemas.Response(success=True, - data=info) + info.update( + { + "USER_UNIQUE_ID": SubscribeHelper().get_user_uuid(), + "SUBSCRIBE_SHARE_MANAGE": share_admin, + "WORKFLOW_SHARE_MANAGE": share_admin, + } + ) + return schemas.Response(success=True, data=info) @router.get("/env", summary="查询系统配置", response_model=schemas.Response) @@ -187,22 +205,22 @@ async def get_env_setting(_: User = Depends(get_current_active_user_async)): """ 查询系统环境变量,包括当前版本号(仅管理员) """ - info = settings.model_dump( - exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY"} + info = settings.model_dump(exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY"}) + info.update( + { + "VERSION": APP_VERSION, + "AUTH_VERSION": SitesHelper().auth_version, + "INDEXER_VERSION": SitesHelper().indexer_version, + "FRONTEND_VERSION": SystemChain().get_frontend_version(), + } ) - info.update({ - "VERSION": APP_VERSION, - "AUTH_VERSION": SitesHelper().auth_version, - "INDEXER_VERSION": SitesHelper().indexer_version, - "FRONTEND_VERSION": SystemChain().get_frontend_version() - }) - return schemas.Response(success=True, - data=info) + return schemas.Response(success=True, data=info) @router.post("/env", summary="更新系统配置", response_model=schemas.Response) -async def set_env_setting(env: dict, - _: User = Depends(get_current_active_superuser_async)): +async def set_env_setting( + env: dict, _: User = Depends(get_current_active_superuser_async) +): """ 更新系统环境变量(仅管理员) """ @@ -215,30 +233,31 @@ async def set_env_setting(env: dict, return schemas.Response( success=False, message=f"{', '.join([v[1] for v in failed_updates.values()])}", - data={ - "success_updates": success_updates, - "failed_updates": failed_updates - } + data={"success_updates": success_updates, "failed_updates": failed_updates}, ) if success_updates: # 发送配置变更事件 - await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData( - key=success_updates.keys(), - change_type="update" - )) + await eventmanager.async_send_event( + etype=EventType.ConfigChanged, + data=ConfigChangeEventData( + key=success_updates.keys(), change_type="update" + ), + ) return schemas.Response( success=True, message="所有配置项更新成功", - data={ - "success_updates": success_updates - } + data={"success_updates": success_updates}, ) @router.get("/progress/{process_type}", summary="实时进度") -async def get_progress(request: Request, process_type: str, _: schemas.TokenPayload = Depends(verify_resource_token)): +async def get_progress( + request: Request, + process_type: str, + _: schemas.TokenPayload = Depends(verify_resource_token), +): """ 实时获取处理进度,返回格式为SSE """ @@ -259,8 +278,7 @@ async def get_progress(request: Request, process_type: str, _: schemas.TokenPayl @router.get("/setting/{key}", summary="查询系统设置", response_model=schemas.Response) -async def get_setting(key: str, - _: User = Depends(get_current_active_user_async)): +async def get_setting(key: str, _: User = Depends(get_current_active_user_async)): """ 查询系统设置(仅管理员) """ @@ -268,16 +286,14 @@ async def get_setting(key: str, value = getattr(settings, key) else: value = SystemConfigOper().get(key) - return schemas.Response(success=True, data={ - "value": value - }) + return schemas.Response(success=True, data={"value": value}) @router.post("/setting/{key}", summary="更新系统设置", response_model=schemas.Response) async def set_setting( - key: str, - value: Annotated[Union[list, dict, bool, int, str] | None, Body()] = None, - _: User = Depends(get_current_active_superuser_async), + key: str, + value: Annotated[Union[list, dict, bool, int, str] | None, Body()] = None, + _: User = Depends(get_current_active_superuser_async), ): """ 更新系统设置(仅管理员) @@ -286,11 +302,10 @@ async def set_setting( success, message = settings.update_setting(key=key, value=value) if success: # 发送配置变更事件 - await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData( - key=key, - value=value, - change_type="update" - )) + await eventmanager.async_send_event( + etype=EventType.ConfigChanged, + data=ConfigChangeEventData(key=key, value=value, change_type="update"), + ) elif success is None: success = True return schemas.Response(success=success, message=message) @@ -301,31 +316,40 @@ async def set_setting( success = await SystemConfigOper().async_set(key, value) if success: # 发送配置变更事件 - await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData( - key=key, - value=value, - change_type="update" - )) + await eventmanager.async_send_event( + etype=EventType.ConfigChanged, + data=ConfigChangeEventData(key=key, value=value, change_type="update"), + ) return schemas.Response(success=True) else: return schemas.Response(success=False, message=f"配置项 '{key}' 不存在") @router.get("/llm-models", summary="获取LLM模型列表", response_model=schemas.Response) -async def get_llm_models(provider: str, api_key: str, base_url: Optional[str] = None, _: User = Depends(get_current_active_user_async)): +async def get_llm_models( + provider: str, + api_key: str, + base_url: Optional[str] = None, + _: User = Depends(get_current_active_user_async), +): """ 获取LLM模型列表 """ try: - models = LLMHelper().get_models(provider, api_key, base_url) + models = await asyncio.to_thread( + LLMHelper().get_models, provider, api_key, base_url + ) return schemas.Response(success=True, data=models) except Exception as e: return schemas.Response(success=False, message=str(e)) @router.get("/message", summary="实时消息") -async def get_message(request: Request, role: Optional[str] = "system", - _: schemas.TokenPayload = Depends(verify_resource_token)): +async def get_message( + request: Request, + role: Optional[str] = "system", + _: schemas.TokenPayload = Depends(verify_resource_token), +): """ 实时获取系统消息,返回格式为SSE """ @@ -346,8 +370,12 @@ async def get_message(request: Request, role: Optional[str] = "system", @router.get("/logging", summary="实时日志") -async def get_logging(request: Request, length: Optional[int] = 50, logfile: Optional[str] = "moviepilot.log", - _: schemas.TokenPayload = Depends(verify_resource_token)): +async def get_logging( + request: Request, + length: Optional[int] = 50, + logfile: Optional[str] = "moviepilot.log", + _: schemas.TokenPayload = Depends(verify_resource_token), +): """ 实时获取系统日志 length = -1 时, 返回text/plain @@ -356,7 +384,9 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt base_path = AsyncPath(settings.LOG_PATH) log_path = base_path / logfile - if not await SecurityUtils.async_is_safe_path(base_path=base_path, user_path=log_path, allowed_suffixes={".log"}): + if not await SecurityUtils.async_is_safe_path( + base_path=base_path, user_path=log_path, allowed_suffixes={".log"} + ): raise HTTPException(status_code=404, detail="Not Found") if not await log_path.exists() or not await log_path.is_file(): @@ -371,7 +401,9 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt file_size = file_stat.st_size # 读取历史日志 - async with aiofiles.open(log_path, mode="r", encoding="utf-8", errors="ignore") as f: + async with aiofiles.open( + log_path, mode="r", encoding="utf-8", errors="ignore" + ) as f: # 优化大文件读取策略 if file_size > 100 * 1024: # 只读取最后100KB的内容 @@ -380,9 +412,9 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt await f.seek(position) content = await f.read() # 找到第一个完整的行 - first_newline = content.find('\n') + first_newline = content.find("\n") if first_newline != -1: - content = content[first_newline + 1:] + content = content[first_newline + 1 :] else: # 小文件直接读取全部内容 content = await f.read() @@ -390,7 +422,7 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt # 按行分割并添加到队列,只保留非空行 lines = [line.strip() for line in content.splitlines() if line.strip()] # 只取最后N行 - for line in lines[-max(length, 50):]: + for line in lines[-max(length, 50) :]: lines_queue.append(line) # 输出历史日志 @@ -398,7 +430,9 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt yield f"data: {line}\n\n" # 实时监听新日志 - async with aiofiles.open(log_path, mode="r", encoding="utf-8", errors="ignore") as f: + async with aiofiles.open( + log_path, mode="r", encoding="utf-8", errors="ignore" + ) as f: # 移动文件指针到文件末尾,继续监听新增内容 await f.seek(0, 2) # 记录初始文件大小 @@ -435,7 +469,9 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt return Response(content="日志文件不存在!", media_type="text/plain") try: # 使用 aiofiles 异步读取文件 - async with aiofiles.open(log_path, mode="r", encoding="utf-8", errors="ignore") as file: + async with aiofiles.open( + log_path, mode="r", encoding="utf-8", errors="ignore" + ) as file: text = await file.read() # 倒序输出 text = "\n".join(text.split("\n")[::-1]) @@ -447,13 +483,16 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt return StreamingResponse(log_generator(), media_type="text/event-stream") -@router.get("/versions", summary="查询Github所有Release版本", response_model=schemas.Response) +@router.get( + "/versions", summary="查询Github所有Release版本", response_model=schemas.Response +) async def latest_version(_: schemas.TokenPayload = Depends(verify_token)): """ 查询Github所有Release版本 """ - version_res = await AsyncRequestUtils(proxies=settings.PROXY, headers=settings.GITHUB_HEADERS).get_res( - f"https://api.github.com/repos/jxxghp/MoviePilot/releases") + version_res = await AsyncRequestUtils( + proxies=settings.PROXY, headers=settings.GITHUB_HEADERS + ).get_res(f"https://api.github.com/repos/jxxghp/MoviePilot/releases") if version_res: ver_json = version_res.json() if ver_json: @@ -462,10 +501,12 @@ async def latest_version(_: schemas.TokenPayload = Depends(verify_token)): @router.get("/ruletest", summary="过滤规则测试", response_model=schemas.Response) -def ruletest(title: str, - rulegroup_name: str, - subtitle: Optional[str] = None, - _: schemas.TokenPayload = Depends(verify_token)): +def ruletest( + title: str, + rulegroup_name: str, + subtitle: Optional[str] = None, + _: schemas.TokenPayload = Depends(verify_token), +): """ 过滤规则测试,规则类型 1-订阅,2-洗版,3-搜索 """ @@ -476,7 +517,9 @@ def ruletest(title: str, # 查询规则组详情 rulegroup = RuleHelper().get_rule_group(rulegroup_name) if not rulegroup: - return schemas.Response(success=False, message=f"过滤规则组 {rulegroup_name} 不存在!") + return schemas.Response( + success=False, message=f"过滤规则组 {rulegroup_name} 不存在!" + ) # 根据标题查询媒体信息 media_info = SearchChain().recognize_media(MetaInfo(title=title, subtitle=subtitle)) @@ -484,21 +527,22 @@ def ruletest(title: str, return schemas.Response(success=False, message="未识别到媒体信息!") # 过滤 - result = SearchChain().filter_torrents(rule_groups=[rulegroup.name], - torrent_list=[torrent], mediainfo=media_info) + result = SearchChain().filter_torrents( + rule_groups=[rulegroup.name], torrent_list=[torrent], mediainfo=media_info + ) if not result: return schemas.Response(success=False, message="不符合过滤规则!") - return schemas.Response(success=True, data={ - "priority": 100 - result[0].pri_order + 1 - }) + return schemas.Response( + success=True, data={"priority": 100 - result[0].pri_order + 1} + ) @router.get("/nettest", summary="测试网络连通性") async def nettest( - url: str, - proxy: bool, - include: Optional[str] = None, - _: schemas.TokenPayload = Depends(verify_token), + url: str, + proxy: bool, + include: Optional[str] = None, + _: schemas.TokenPayload = Depends(verify_token), ): """ 测试网络连通性 @@ -570,21 +614,26 @@ async def nettest( return schemas.Response(success=False, message=message, data={"time": time}) -@router.get("/modulelist", summary="查询已加载的模块ID列表", response_model=schemas.Response) +@router.get( + "/modulelist", summary="查询已加载的模块ID列表", response_model=schemas.Response +) def modulelist(_: schemas.TokenPayload = Depends(verify_token)): """ 查询已加载的模块ID列表 """ - modules = [{ - "id": k, - "name": v.get_name(), - } for k, v in ModuleManager().get_modules().items()] - return schemas.Response(success=True, data={ - "modules": modules - }) + modules = [ + { + "id": k, + "name": v.get_name(), + } + for k, v in ModuleManager().get_modules().items() + ] + return schemas.Response(success=True, data={"modules": modules}) -@router.get("/moduletest/{moduleid}", summary="模块可用性测试", response_model=schemas.Response) +@router.get( + "/moduletest/{moduleid}", summary="模块可用性测试", response_model=schemas.Response +) def moduletest(moduleid: str, _: schemas.TokenPayload = Depends(verify_token)): """ 模块可用性测试接口 @@ -608,8 +657,7 @@ def restart_system(_: User = Depends(get_current_active_superuser)): @router.get("/runscheduler", summary="运行服务", response_model=schemas.Response) -def run_scheduler(jobid: str, - _: User = Depends(get_current_active_superuser)): +def run_scheduler(jobid: str, _: User = Depends(get_current_active_superuser)): """ 执行命令(仅管理员) """ @@ -622,9 +670,10 @@ def run_scheduler(jobid: str, return schemas.Response(success=True) -@router.get("/runscheduler2", summary="运行服务(API_TOKEN)", response_model=schemas.Response) -def run_scheduler2(jobid: str, - _: Annotated[str, Depends(verify_apitoken)]): +@router.get( + "/runscheduler2", summary="运行服务(API_TOKEN)", response_model=schemas.Response +) +def run_scheduler2(jobid: str, _: Annotated[str, Depends(verify_apitoken)]): """ 执行命令(API_TOKEN认证) """ diff --git a/app/helper/llm.py b/app/helper/llm.py index 46e2ca09..5526d020 100644 --- a/app/helper/llm.py +++ b/app/helper/llm.py @@ -46,7 +46,7 @@ class LLMHelper: api_key=api_key, retries=3, temperature=settings.LLM_TEMPERATURE, - streaming=streaming + streaming=streaming, ) elif provider == "deepseek": from langchain_deepseek import ChatDeepSeek @@ -78,13 +78,14 @@ class LLMHelper: logger.info(f"使用LLM模型: {model.model},Profile: {model.profile}") else: model.profile = { - "max_input_tokens": settings.LLM_MAX_CONTEXT_TOKENS * 1000, # 转换为token单位 + "max_input_tokens": settings.LLM_MAX_CONTEXT_TOKENS + * 1000, # 转换为token单位 } return model def get_models( - self, provider: str, api_key: str, base_url: str = None + self, provider: str, api_key: str, base_url: str = None ) -> List[str]: """获取模型列表""" logger.info(f"获取 {provider} 模型列表...") @@ -98,8 +99,16 @@ class LLMHelper: """获取Google模型列表(使用 google-genai SDK v1)""" try: from google import genai + from google.genai.types import HttpOptions - client = genai.Client(api_key=api_key) + http_options = None + if settings.PROXY_HOST: + http_options = HttpOptions( + client_args={"proxy": settings.PROXY_HOST}, + async_client_args={"proxy": settings.PROXY_HOST}, + ) + + client = genai.Client(api_key=api_key, http_options=http_options) models = client.models.list() return [ m.name @@ -112,7 +121,7 @@ class LLMHelper: @staticmethod def _get_openai_compatible_models( - provider: str, api_key: str, base_url: str = None + provider: str, api_key: str, base_url: str = None ) -> List[str]: """获取OpenAI兼容模型列表""" try: