From 749aaeb0030e2f28d4173905b3ba238c3dfef5d2 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Thu, 31 Jul 2025 07:07:14 +0800 Subject: [PATCH] fix async --- app/api/endpoints/subscribe.py | 1 - app/api/endpoints/user.py | 5 +- app/api/endpoints/workflow.py | 6 +- app/db/__init__.py | 2 + app/modules/themoviedb/tmdbv3api/tmdb.py | 162 ++++++++++------------- 5 files changed, 76 insertions(+), 100 deletions(-) diff --git a/app/api/endpoints/subscribe.py b/app/api/endpoints/subscribe.py index 702c4338..8646475d 100644 --- a/app/api/endpoints/subscribe.py +++ b/app/api/endpoints/subscribe.py @@ -115,7 +115,6 @@ async def update_subscribe( # 是否手动修改过总集数 if subscribe_in.total_episode != subscribe.total_episode: subscribe_dict["manual_total_episode"] = 1 - await subscribe.async_update(db, subscribe_dict) # 发送订阅调整事件 subscribe = await subscribe.async_get(db, subscribe_in.id) eventmanager.send_event(EventType.SubscribeModified, { diff --git a/app/api/endpoints/user.py b/app/api/endpoints/user.py index c014ae38..098cac93 100644 --- a/app/api/endpoints/user.py +++ b/app/api/endpoints/user.py @@ -45,9 +45,8 @@ async def create_user( if user_info.get("password"): user_info["hashed_password"] = get_password_hash(user_info["password"]) user_info.pop("password") - user = User(**user_info) - await user.async_create(db) - return schemas.Response(success=True) + user = await User(**user_info).async_create(db) + return schemas.Response(success=True if user else False) @router.put("/", summary="更新用户", response_model=schemas.Response) diff --git a/app/api/endpoints/workflow.py b/app/api/endpoints/workflow.py index ef8bd11a..2dfdb6be 100644 --- a/app/api/endpoints/workflow.py +++ b/app/api/endpoints/workflow.py @@ -154,11 +154,7 @@ async def workflow_fork( return schemas.Response(success=False, message="已存在相同名称的工作流") # 创建新工作流 - workflow_obj = Workflow(**workflow_dict) - await workflow_obj.async_create(db) - - # 获取工作流ID(在数据库会话有效时) - workflow = await workflow_oper.async_get_by_name(workflow_dict["name"]) + workflow = await Workflow(**workflow_dict).async_create(db) # 更新复用次数 if workflow: diff --git a/app/db/__init__.py b/app/db/__init__.py index 3f07ce64..6b7905c4 100644 --- a/app/db/__init__.py +++ b/app/db/__init__.py @@ -353,6 +353,8 @@ class Base: @async_db_update async def async_create(self, db: AsyncSession): db.add(self) + await db.flush() + return self @classmethod @db_query diff --git a/app/modules/themoviedb/tmdbv3api/tmdb.py b/app/modules/themoviedb/tmdbv3api/tmdb.py index c688511e..e3ad4b6d 100644 --- a/app/modules/themoviedb/tmdbv3api/tmdb.py +++ b/app/modules/themoviedb/tmdbv3api/tmdb.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- +import asyncio import logging import time from datetime import datetime @@ -130,17 +131,11 @@ class TMDb(object): @cached(maxsize=settings.CONF.tmdb, ttl=settings.CONF.meta) def cached_request(self, method, url, data, json, _ts=datetime.strftime(datetime.now(), '%Y%m%d')): - """ - 缓存请求 - """ return self.request(method, url, data, json) @cached(maxsize=settings.CONF.tmdb, ttl=settings.CONF.meta) async def async_cached_request(self, method, url, data, json, _ts=datetime.strftime(datetime.now(), '%Y%m%d')): - """ - 缓存请求(异步版本) - """ return await self.async_request(method, url, data, json) def request(self, method, url, data, json): @@ -164,12 +159,12 @@ class TMDb(object): def cache_clear(self): return self.cached_request.cache_clear() - def _request_obj(self, action, params="", call_cached=True, - method="GET", data=None, json=None, key=None): + def _validate_api_key(self): if self.api_key is None or self.api_key == "": raise TMDbException("TheMovieDb API Key 未设置!") - url = "https://%s/3%s?api_key=%s&%s&language=%s" % ( + def _build_url(self, action, params=""): + return "https://%s/3%s?api_key=%s&%s&language=%s" % ( self.domain, action, self.api_key, @@ -177,6 +172,55 @@ class TMDb(object): self.language, ) + def _handle_headers(self, headers): + if "X-RateLimit-Remaining" in headers: + self._remaining = int(headers["X-RateLimit-Remaining"]) + + if "X-RateLimit-Reset" in headers: + self._reset = int(headers["X-RateLimit-Reset"]) + + def _handle_rate_limit(self): + if self._remaining < 1: + current_time = int(time.time()) + sleep_time = self._reset - current_time + + if self.wait_on_rate_limit: + logger.warning("达到请求频率限制,休眠:%d 秒..." % sleep_time) + return abs(sleep_time) + else: + raise TMDbException("达到请求频率限制,请稍后再试!") + return 0 + + def _process_json_response(self, json_data, is_async=False): + if "page" in json_data: + self._page = json_data["page"] + + if "total_results" in json_data: + self._total_results = json_data["total_results"] + + if "total_pages" in json_data: + self._total_pages = json_data["total_pages"] + + if self.debug: + logger.info(json_data) + if is_async: + logger.info(self.async_cached_request.cache_info()) + else: + logger.info(self.cached_request.cache_info()) + + @staticmethod + def _handle_errors(json_data): + if "errors" in json_data: + raise TMDbException(json_data["errors"]) + + if "success" in json_data and json_data["success"] is False: + raise TMDbException(json_data["status_message"]) + + def _request_obj(self, action, params="", call_cached=True, + method="GET", data=None, json=None, key=None): + self._validate_api_key() + url = self._build_url(action, params) + if self.cache and self.obj_cached and call_cached and method != "POST": req = self.cached_request(method, url, data, json) else: @@ -185,45 +229,17 @@ class TMDb(object): if req is None: return None - headers = req.headers + self._handle_headers(req.headers) - if "X-RateLimit-Remaining" in headers: - self._remaining = int(headers["X-RateLimit-Remaining"]) - - if "X-RateLimit-Reset" in headers: - self._reset = int(headers["X-RateLimit-Reset"]) - - if self._remaining < 1: - current_time = int(time.time()) - sleep_time = self._reset - current_time - - if self.wait_on_rate_limit: - logger.warning("达到请求频率限制,休眠:%d 秒..." % sleep_time) - time.sleep(abs(sleep_time)) - return self._request_obj(action, params, call_cached, method, data, json, key) - else: - raise TMDbException("达到请求频率限制,将在 %d 秒后重试..." % sleep_time) + rate_limit_result = self._handle_rate_limit() + if rate_limit_result: + logger.warning("达到请求频率限制,将在 %d 秒后重试..." % rate_limit_result) + time.sleep(rate_limit_result) + return self._request_obj(action, params, call_cached, method, data, json, key) json_data = req.json() - - if "page" in json_data: - self._page = json_data["page"] - - if "total_results" in json_data: - self._total_results = json_data["total_results"] - - if "total_pages" in json_data: - self._total_pages = json_data["total_pages"] - - if self.debug: - logger.info(json_data) - logger.info(self.cached_request.cache_info()) - - if "errors" in json_data: - raise TMDbException(json_data["errors"]) - - if "success" in json_data and json_data["success"] is False: - raise TMDbException(json_data["status_message"]) + self._process_json_response(json_data, is_async=False) + self._handle_errors(json_data) if key: return json_data.get(key) @@ -231,16 +247,8 @@ class TMDb(object): async def _async_request_obj(self, action, params="", call_cached=True, method="GET", data=None, json=None, key=None): - if self.api_key is None or self.api_key == "": - raise TMDbException("TheMovieDb API Key 未设置!") - - url = "https://%s/3%s?api_key=%s&%s&language=%s" % ( - self.domain, - action, - self.api_key, - params, - self.language, - ) + self._validate_api_key() + url = self._build_url(action, params) if self.cache and self.obj_cached and call_cached and method != "POST": req = await self.async_cached_request(method, url, data, json) @@ -250,45 +258,17 @@ class TMDb(object): if req is None: return None - headers = req.headers + self._handle_headers(req.headers) - if "X-RateLimit-Remaining" in headers: - self._remaining = int(headers["X-RateLimit-Remaining"]) - - if "X-RateLimit-Reset" in headers: - self._reset = int(headers["X-RateLimit-Reset"]) - - if self._remaining < 1: - current_time = int(time.time()) - sleep_time = self._reset - current_time - - if self.wait_on_rate_limit: - logger.warning("达到请求频率限制,休眠:%d 秒..." % sleep_time) - time.sleep(abs(sleep_time)) - return await self._async_request_obj(action, params, call_cached, method, data, json, key) - else: - raise TMDbException("达到请求频率限制,将在 %d 秒后重试..." % sleep_time) + rate_limit_result = self._handle_rate_limit() + if rate_limit_result: + logger.warning("达到请求频率限制,将在 %d 秒后重试..." % rate_limit_result) + await asyncio.sleep(rate_limit_result) + return await self._async_request_obj(action, params, call_cached, method, data, json, key) json_data = req.json() - - if "page" in json_data: - self._page = json_data["page"] - - if "total_results" in json_data: - self._total_results = json_data["total_results"] - - if "total_pages" in json_data: - self._total_pages = json_data["total_pages"] - - if self.debug: - logger.info(json_data) - logger.info(self.async_cached_request.cache_info()) - - if "errors" in json_data: - raise TMDbException(json_data["errors"]) - - if "success" in json_data and json_data["success"] is False: - raise TMDbException(json_data["status_message"]) + self._process_json_response(json_data, is_async=True) + self._handle_errors(json_data) if key: return json_data.get(key)