mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-04-05 03:38:36 +08:00
fix async
This commit is contained in:
@@ -15,6 +15,7 @@ from app.core.cache import cache_backend, cached
|
||||
from app.core.config import settings, global_vars
|
||||
from app.log import logger
|
||||
from app.schemas import MediaType
|
||||
from app.utils.asyncio import AsyncUtils
|
||||
from app.utils.common import log_execution_time
|
||||
from app.utils.http import AsyncRequestUtils
|
||||
from app.utils.security import SecurityUtils
|
||||
@@ -38,7 +39,7 @@ class RecommendChain(ChainBase, metaclass=Singleton):
|
||||
刷新推荐数据 - 同步包装器
|
||||
"""
|
||||
try:
|
||||
asyncio.run(self.async_refresh_recommend())
|
||||
AsyncUtils.run_async(self.async_refresh_recommend())
|
||||
except Exception as e:
|
||||
logger.error(f"刷新推荐数据失败:{str(e)}")
|
||||
raise
|
||||
|
||||
@@ -92,7 +92,12 @@ class WorkflowHelper(metaclass=WeakSingleton):
|
||||
cache_backend.clear(region=self._shares_cache_region)
|
||||
return True, ""
|
||||
else:
|
||||
return False, res.json().get("message")
|
||||
try:
|
||||
error_msg = res.json().get("message", "未知错误")
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
logger.error(f"工作流响应JSON解析失败: {e}")
|
||||
error_msg = f"响应解析失败: {res.text[:100]}..."
|
||||
return False, error_msg
|
||||
|
||||
@staticmethod
|
||||
def _handle_list_response(res) -> List[dict]:
|
||||
@@ -100,7 +105,11 @@ class WorkflowHelper(metaclass=WeakSingleton):
|
||||
处理返回List的HTTP响应
|
||||
"""
|
||||
if res and res.status_code == 200:
|
||||
return res.json()
|
||||
try:
|
||||
return res.json()
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
logger.error(f"工作流列表响应JSON解析失败: {e}")
|
||||
return []
|
||||
return []
|
||||
|
||||
def workflow_share(self, workflow_id: int,
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
|
||||
from app.core.cache import cached
|
||||
from app.core.config import settings
|
||||
from app.utils.asyncio import AsyncUtils
|
||||
from app.utils.http import RequestUtils, AsyncRequestUtils
|
||||
|
||||
|
||||
@@ -29,7 +31,8 @@ class BangumiApi(object):
|
||||
def __init__(self):
|
||||
self._session = requests.Session()
|
||||
self._req = RequestUtils(session=self._session)
|
||||
self._async_req = AsyncRequestUtils()
|
||||
self._client = httpx.AsyncClient()
|
||||
self._async_req = AsyncRequestUtils(client=self._client)
|
||||
|
||||
@cached(maxsize=settings.CONF.bangumi, ttl=settings.CONF.meta)
|
||||
def __invoke(self, url, key: Optional[str] = None, **kwargs):
|
||||
@@ -303,3 +306,5 @@ class BangumiApi(object):
|
||||
def close(self):
|
||||
if self._session:
|
||||
self._session.close()
|
||||
if self._client:
|
||||
AsyncUtils.run_async(self._client.aclose())
|
||||
|
||||
@@ -12,6 +12,7 @@ import requests
|
||||
|
||||
from app.core.cache import cached
|
||||
from app.core.config import settings
|
||||
from app.utils.asyncio import AsyncUtils
|
||||
from app.utils.http import RequestUtils, AsyncRequestUtils
|
||||
from app.utils.singleton import WeakSingleton
|
||||
|
||||
@@ -155,7 +156,7 @@ class DoubanApi(metaclass=WeakSingleton):
|
||||
|
||||
def __init__(self):
|
||||
self._session = requests.Session()
|
||||
self._async_req = AsyncRequestUtils()
|
||||
self._client = httpx.AsyncClient()
|
||||
|
||||
@classmethod
|
||||
def __sign(cls, url: str, ts: str, method='GET') -> str:
|
||||
@@ -249,7 +250,10 @@ class DoubanApi(metaclass=WeakSingleton):
|
||||
GET请求(异步版本)
|
||||
"""
|
||||
req_url, params = self._prepare_get_request(url, **kwargs)
|
||||
resp = await self._async_req.get_res(url=req_url, params=params)
|
||||
resp = await AsyncRequestUtils(
|
||||
ua=choice(self._user_agents),
|
||||
client=self._client
|
||||
).get_res(url=req_url, params=params)
|
||||
return self._handle_response(resp)
|
||||
|
||||
def _prepare_post_request(self, url: str, **kwargs) -> tuple[str, dict]:
|
||||
@@ -292,7 +296,10 @@ class DoubanApi(metaclass=WeakSingleton):
|
||||
POST请求(异步版本)
|
||||
"""
|
||||
req_url, params = self._prepare_post_request(url, **kwargs)
|
||||
resp = await self._async_req.post_res(url=req_url, data=params)
|
||||
resp = await AsyncRequestUtils(
|
||||
ua=settings.NORMAL_USER_AGENT,
|
||||
client=self._client
|
||||
).post_res(url=req_url, data=params)
|
||||
return self._handle_response(resp)
|
||||
|
||||
def imdbid(self, imdbid: str,
|
||||
@@ -869,3 +876,5 @@ class DoubanApi(metaclass=WeakSingleton):
|
||||
def close(self):
|
||||
if self._session:
|
||||
self._session.close()
|
||||
if self._client:
|
||||
AsyncUtils.run_async(self._client.aclose())
|
||||
|
||||
@@ -352,7 +352,11 @@ class SiteParserBase(metaclass=ABCMeta):
|
||||
headers=req_headers).get_res(url=url)
|
||||
if res is not None and res.status_code in (200, 500, 403):
|
||||
if req_headers and "application/json" in str(req_headers.get("Accept")):
|
||||
return json.dumps(res.json())
|
||||
try:
|
||||
return json.dumps(res.json())
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
logger.error(f"{self._site_name} API响应JSON解析失败: {e}")
|
||||
return ""
|
||||
else:
|
||||
# 如果cloudflare 有防护,尝试使用浏览器仿真
|
||||
if under_challenge(res.text):
|
||||
|
||||
@@ -5,11 +5,13 @@ import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
import requests.exceptions
|
||||
|
||||
from app.core.cache import cached
|
||||
from app.core.config import settings
|
||||
from app.utils.asyncio import AsyncUtils
|
||||
from app.utils.http import RequestUtils, AsyncRequestUtils
|
||||
from .exceptions import TMDbException
|
||||
|
||||
@@ -17,14 +19,13 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TMDb(object):
|
||||
_req = None
|
||||
_async_req = None
|
||||
_session = None
|
||||
|
||||
def __init__(self, obj_cached=True, session=None, language=None):
|
||||
def __init__(self, obj_cached=True, session=None, client=None, language=None):
|
||||
self._api_key = settings.TMDB_API_KEY
|
||||
self._language = language or settings.TMDB_LOCALE or "en-US"
|
||||
self._session_id = None
|
||||
self._session = session
|
||||
self._client = client
|
||||
self._wait_on_rate_limit = True
|
||||
self._debug_enabled = False
|
||||
self._cache_enabled = obj_cached
|
||||
@@ -34,13 +35,14 @@ class TMDb(object):
|
||||
self._total_results = None
|
||||
self._total_pages = None
|
||||
|
||||
if session is not None:
|
||||
self._req = RequestUtils(session=session, proxies=self.proxies)
|
||||
else:
|
||||
if not self._session:
|
||||
self._session = requests.Session()
|
||||
self._req = RequestUtils(session=self._session, proxies=self.proxies)
|
||||
# 初始化异步请求客户端
|
||||
self._async_req = AsyncRequestUtils(proxies=self.proxies)
|
||||
self._req = RequestUtils(session=self._session, proxies=self.proxies)
|
||||
|
||||
if not self._client:
|
||||
self._client = httpx.AsyncClient()
|
||||
self._async_req = AsyncRequestUtils(client=self._client, proxies=self.proxies)
|
||||
|
||||
self._remaining = 40
|
||||
self._reset = None
|
||||
self._timeout = 15
|
||||
@@ -277,3 +279,5 @@ class TMDb(object):
|
||||
def close(self):
|
||||
if self._session:
|
||||
self._session.close()
|
||||
if self._client:
|
||||
AsyncUtils.run_async(self._client.aclose())
|
||||
|
||||
83
app/utils/asyncio.py
Normal file
83
app/utils/asyncio.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import asyncio
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Coroutine, Any, TypeVar
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class AsyncUtils:
|
||||
"""
|
||||
异步工具类,用于在同步环境中调用异步方法
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def run_async(coro: Coroutine[Any, Any, T]) -> T:
|
||||
"""
|
||||
在同步环境中安全地执行异步协程
|
||||
|
||||
:param coro: 要执行的协程
|
||||
:return: 协程的返回值
|
||||
:raises: 协程执行过程中的任何异常
|
||||
"""
|
||||
try:
|
||||
# 尝试获取当前运行的事件循环
|
||||
asyncio.get_running_loop()
|
||||
# 如果有运行中的事件循环,在新线程中执行
|
||||
return AsyncUtils._run_in_thread(coro)
|
||||
except RuntimeError:
|
||||
# 没有运行中的事件循环,直接使用 asyncio.run
|
||||
return asyncio.run(coro)
|
||||
|
||||
@staticmethod
|
||||
def _run_in_thread(coro: Coroutine[Any, Any, T]) -> T:
|
||||
"""
|
||||
在新线程中创建事件循环并执行协程
|
||||
|
||||
:param coro: 要执行的协程
|
||||
:return: 协程的返回值
|
||||
"""
|
||||
result = None
|
||||
exception = None
|
||||
|
||||
def _run():
|
||||
nonlocal result, exception
|
||||
try:
|
||||
# 在新线程中创建新的事件循环
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
try:
|
||||
result = new_loop.run_until_complete(coro)
|
||||
finally:
|
||||
new_loop.close()
|
||||
except Exception as e:
|
||||
exception = e
|
||||
|
||||
# 在新线程中执行
|
||||
thread = threading.Thread(target=_run)
|
||||
thread.start()
|
||||
thread.join()
|
||||
|
||||
if exception:
|
||||
raise exception
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def run_async_in_executor(coro: Coroutine[Any, Any, T]) -> T:
|
||||
"""
|
||||
使用线程池执行器在新线程中运行异步协程
|
||||
|
||||
:param coro: 要执行的协程
|
||||
:return: 协程的返回值
|
||||
"""
|
||||
try:
|
||||
# 检查是否有运行中的事件循环
|
||||
asyncio.get_running_loop()
|
||||
# 有运行中的事件循环,使用线程池
|
||||
with ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(asyncio.run, coro)
|
||||
return future.result()
|
||||
except RuntimeError:
|
||||
# 没有运行中的事件循环,直接运行
|
||||
return asyncio.run(coro)
|
||||
Reference in New Issue
Block a user