diff --git a/backend/src/module/ab_decorator/__init__.py b/backend/src/module/ab_decorator/__init__.py
index 02467804..3897a8e9 100644
--- a/backend/src/module/ab_decorator/__init__.py
+++ b/backend/src/module/ab_decorator/__init__.py
@@ -2,11 +2,15 @@ import asyncio
 import functools
 import logging
 
+import httpx
+
 from .timeout import timeout
 
 logger = logging.getLogger(__name__)
 _lock = asyncio.Lock()
 
+_RETRY_DELAYS = [5, 15, 45, 120, 300]
+
 
 def qb_connect_failed_wait(func):
     @functools.wraps(func)
@@ -15,11 +19,21 @@ def qb_connect_failed_wait(func):
         while times < 5:
             try:
                 return await func(*args, **kwargs)
-            except Exception as e:
+            except (
+                ConnectionError,
+                TimeoutError,
+                OSError,
+                httpx.ConnectError,
+                httpx.TimeoutException,
+                httpx.RequestError,
+            ) as e:
+                delay = _RETRY_DELAYS[times]
                 logger.debug("URL: %s", args[0])
                 logger.warning(e)
-                logger.warning("Cannot connect to qBittorrent. Wait 5 min and retry...")
-                await asyncio.sleep(300)
+                logger.warning(
+                    "Cannot connect to qBittorrent. Wait %ds and retry...", delay
+                )
+                await asyncio.sleep(delay)
                 times += 1
         return wrapper
 
diff --git a/backend/src/module/conf/__init__.py b/backend/src/module/conf/__init__.py
index 67acbd55..a4190581 100644
--- a/backend/src/module/conf/__init__.py
+++ b/backend/src/module/conf/__init__.py
@@ -1,3 +1,4 @@
+import sys
 from pathlib import Path
 
 from .config import VERSION, settings
@@ -10,4 +11,4 @@ LEGACY_DATA_PATH = Path("data/data.json")
 VERSION_PATH = Path("config/version.info")
 POSTERS_PATH = Path("data/posters")
 
-PLATFORM = "Windows" if "\\" in settings.downloader.path else "Unix"
+PLATFORM = "Windows" if sys.platform == "win32" else "Unix"
diff --git a/backend/src/module/database/bangumi.py b/backend/src/module/database/bangumi.py
index 5483b3a4..fe41029d 100644
--- a/backend/src/module/database/bangumi.py
+++ b/backend/src/module/database/bangumi.py
@@ -126,7 +126,9 @@ class BangumiDatabase:
 
         return None
 
-    def add_title_alias(self, bangumi_id: int, new_title_raw: str) -> bool:
+    def add_title_alias(
+        self, bangumi_id: int, new_title_raw: str, auto_commit: bool = True
+    ) -> bool:
         """
         Add a new title_raw alias to an existing bangumi.
@@ -152,8 +154,9 @@ class BangumiDatabase:
         _set_aliases_list(bangumi, aliases)
         self.session.add(bangumi)
-        self.session.commit()
-        _invalidate_bangumi_cache()
+        if auto_commit:
+            self.session.commit()
+            _invalidate_bangumi_cache()
         logger.info(
             f"[Database] Added alias '{new_title_raw}' to bangumi '{bangumi.official_title}' "
             f"(id: {bangumi_id})"
         )
@@ -233,8 +236,8 @@ class BangumiDatabase:
         for d in to_add:
             semantic_match = self.find_semantic_duplicate(d)
             if semantic_match:
-                # Add as alias instead of creating new entry
-                self.add_title_alias(semantic_match.id, d.title_raw)
+                # Add as alias instead of creating new entry (defer commit)
+                self.add_title_alias(semantic_match.id, d.title_raw, auto_commit=False)
                 semantic_merged += 1
                 logger.info(
                     f"[Database] Merged '{d.title_raw}' as alias to existing "
@@ -254,9 +257,10 @@ class BangumiDatabase:
 
         if not unique_to_add:
             if semantic_merged > 0:
+                self.session.commit()
+                _invalidate_bangumi_cache()
                 logger.debug(
-                    "[Database] %s bangumi merged as aliases, "
-                    "rest were duplicates.",
+                    "[Database] %s bangumi merged as aliases, " "rest were duplicates.",
                     semantic_merged,
                 )
             else:
@@ -330,7 +334,9 @@ class BangumiDatabase:
         self.session.add(bangumi)
         self.session.commit()
         _invalidate_bangumi_cache()
-        logger.debug("[Database] Update %s poster_link to %s.", title_raw, poster_link)
+        logger.debug(
+            "[Database] Update %s poster_link to %s.", title_raw, poster_link
+        )
 
     def delete_one(self, _id: int):
         statement = select(Bangumi).where(Bangumi.id == _id)
@@ -427,6 +433,7 @@ class BangumiDatabase:
                 rss_link not in match_data.rss_link
                 and match_data.title_raw not in rss_updated
             ):
+                match_data = self.session.merge(match_data)
                 match_data.rss_link += f",{rss_link}"
                 match_data.added = False
                 rss_updated.add(match_data.title_raw)
@@ -481,8 +488,8 @@ class BangumiDatabase:
         conditions = select(Bangumi).where(
             or_(
                 Bangumi.added == 0,
-                Bangumi.rule_name is None,
-                Bangumi.save_path is None,
+                Bangumi.rule_name.is_(None),
+                Bangumi.save_path.is_(None),
             )
         )
         result = self.session.execute(conditions)
diff --git a/backend/src/module/database/combine.py b/backend/src/module/database/combine.py
index 36e68a31..b69889f5 100644
--- a/backend/src/module/database/combine.py
+++ b/backend/src/module/database/combine.py
@@ -199,17 +199,21 @@ class Database(Session):
             if "title_aliases" in columns:
                 needs_run = False
             if needs_run:
-                with self.engine.connect() as conn:
-                    for stmt in statements:
-                        conn.execute(text(stmt))
-                    conn.commit()
-                logger.info(f"[Database] Migration v{version}: {description}")
+                try:
+                    with self.engine.connect() as conn:
+                        for stmt in statements:
+                            conn.execute(text(stmt))
+                        conn.commit()
+                    logger.info(f"[Database] Migration v{version}: {description}")
+                except Exception as e:
+                    logger.error(f"[Database] Migration v{version} failed: {e}")
+                    break
             else:
                 logger.debug(
                     f"[Database] Migration v{version} skipped (already applied): {description}"
                 )
-        self._set_schema_version(CURRENT_SCHEMA_VERSION)
-        logger.info(f"[Database] Schema version is now {CURRENT_SCHEMA_VERSION}.")
+            self._set_schema_version(version)
+        logger.info(f"[Database] Schema version is now {self._get_schema_version()}.")
         self._fill_null_with_defaults()
 
     def _get_field_default(self, field_info: FieldInfo) -> tuple[bool, Any]:
@@ -290,8 +294,8 @@ class Database(Session):
                 result = conn.execute(
                     text(
-                        f"UPDATE {table_name} SET {field_name} = :val "
-                        f"WHERE {field_name} IS NULL"
+                        f'UPDATE "{table_name}" SET "{field_name}" = :val '
+                        f'WHERE "{field_name}" IS NULL'
                     ),
                     {"val": sql_value},
                 )
@@ -309,6 +313,9 @@ class Database(Session):
         # Run migration online
         bangumi_data = self.bangumi.search_all()
         user_data = self.exec("SELECT * FROM user").all()
+        if not user_data:
+            logger.warning("[Database] No user data found, skipping migration.")
+            return
         readd_bangumi = []
         for bangumi in bangumi_data:
             dict_data = bangumi.dict()
@@ -317,7 +324,10 @@ class Database(Session):
         self.drop_table()
         self.create_table()
         self.commit()
-        bangumi_data = self.bangumi.search_all()
-        self.bangumi.add_all(readd_bangumi)
-        self.add(User(**user_data[0]))
-        self.commit()
+        try:
+            self.bangumi.add_all(readd_bangumi)
+            self.add(User(**user_data[0]))
+            self.commit()
+        except Exception:
+            self.rollback()
+            raise
diff --git a/backend/src/module/database/engine.py b/backend/src/module/database/engine.py
index 5bdb6f64..c0d84360 100644
--- a/backend/src/module/database/engine.py
+++ b/backend/src/module/database/engine.py
@@ -1,3 +1,4 @@
+from sqlalchemy import event
 from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
 from sqlalchemy.orm import sessionmaker
 from sqlmodel import create_engine
@@ -10,4 +11,20 @@ engine = create_engine(DATA_PATH)
 # Async engine (for passkey operations)
 ASYNC_DATA_PATH = DATA_PATH.replace("sqlite:///", "sqlite+aiosqlite:///")
 async_engine = create_async_engine(ASYNC_DATA_PATH)
-async_session_factory = sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
+async_session_factory = sessionmaker(
+    async_engine, class_=AsyncSession, expire_on_commit=False
+)
+
+
+@event.listens_for(engine, "connect")
+def _set_sqlite_fk_sync(dbapi_conn, connection_record):
+    cursor = dbapi_conn.cursor()
+    cursor.execute("PRAGMA foreign_keys=ON")
+    cursor.close()
+
+
+@event.listens_for(async_engine.sync_engine, "connect")
+def _set_sqlite_fk_async(dbapi_conn, connection_record):
+    cursor = dbapi_conn.cursor()
+    cursor.execute("PRAGMA foreign_keys=ON")
+    cursor.close()
diff --git a/backend/src/module/network/request_contents.py b/backend/src/module/network/request_contents.py
index 0ea39d4a..e271947e 100644
--- a/backend/src/module/network/request_contents.py
+++ b/backend/src/module/network/request_contents.py
@@ -65,7 +65,7 @@ class RequestContent(RequestURL):
 
     async def get_html(self, _url):
         resp = await self.get_url(_url)
-        return resp.text
+        return resp.text if resp else None
 
     async def get_content(self, _url):
         req = await self.get_url(_url)
diff --git a/backend/src/module/network/site/mikan.py b/backend/src/module/network/site/mikan.py
index 5490fdec..d000e5a7 100644
--- a/backend/src/module/network/site/mikan.py
+++ b/backend/src/module/network/site/mikan.py
@@ -1,15 +1,24 @@
+import logging
+
+logger = logging.getLogger(__name__)
+
+
 def rss_parser(soup):
     results = []
     for item in soup.findall("./channel/item"):
-        title = item.find("title").text
-        enclosure = item.find("enclosure")
-        if enclosure is not None:
-            homepage = item.find("link").text
-            url = enclosure.attrib.get("url")
-        else:
-            url = item.find("link").text
-            homepage = ""
-        results.append((title, url, homepage))
+        try:
+            title = item.find("title").text
+            enclosure = item.find("enclosure")
+            if enclosure is not None:
+                homepage = item.find("link").text
+                url = enclosure.attrib.get("url")
+            else:
+                url = item.find("link").text
+                homepage = ""
+            results.append((title, url, homepage))
+        except Exception as e:
+            logger.warning("[RSS] Failed to parse RSS item: %s", e)
+            continue
     return results
diff --git a/backend/src/module/parser/analyser/openai.py b/backend/src/module/parser/analyser/openai.py
index c8b44a4f..5aa7a981 100644
--- a/backend/src/module/parser/analyser/openai.py
+++ b/backend/src/module/parser/analyser/openai.py
@@ -1,16 +1,16 @@
 import json
 import logging
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any
-from pydantic import BaseModel
-from typing import Optional
+from typing import Any, Optional
 
-from openai import OpenAI, AzureOpenAI
+from openai import AzureOpenAI, OpenAI
+from pydantic import BaseModel
 
 from module.models import Bangumi
 
 logger = logging.getLogger(__name__)
 
+
 class Episode(BaseModel):
     title_en: Optional[str]
     title_zh: Optional[str]
@@ -103,10 +103,15 @@ class OpenAIParser:
         result = resp.choices[0].message.parsed
 
         if asdict:
-            try:
-                result = json.loads(result[result.index("{"):result.rindex("}") + 1])  # find the first { and last } for better compatibility
-            except json.JSONDecodeError:
-                logger.warning(f"Cannot parse result {result} as python dict.")
+            if hasattr(result, "model_dump"):
+                result = result.model_dump()
+            else:
+                try:
+                    result = json.loads(
+                        result[result.index("{") : result.rindex("}") + 1]
+                    )  # find the first { and last } for better compatibility
+                except (json.JSONDecodeError, ValueError):
+                    logger.warning(f"Cannot parse result {result} as python dict.")
 
         logger.debug("the parsed result is: %s", result)
 
@@ -131,7 +136,6 @@ class OpenAIParser:
                 dict(role="user", content=text),
             ],
             response_format=Episode,
-            # set temperature to 0 to make results be more stable and reproducible.
             temperature=0,
         )
 
diff --git a/backend/src/module/parser/analyser/raw_parser.py b/backend/src/module/parser/analyser/raw_parser.py
index 46b2a948..c2ab641c 100644
--- a/backend/src/module/parser/analyser/raw_parser.py
+++ b/backend/src/module/parser/analyser/raw_parser.py
@@ -41,7 +41,7 @@ def pre_process(raw_name: str) -> str:
 
 
 def prefix_process(raw: str, group: str) -> str:
-    raw = re.sub(f".{group}.", "", raw)
+    raw = re.sub(f".{re.escape(group)}.", "", raw)
     raw_process = PREFIX_RE.sub("/", raw)
     arg_group = raw_process.split("/")
     while "" in arg_group:
@@ -50,9 +50,9 @@ def prefix_process(raw: str, group: str) -> str:
         arg_group = arg_group[0].split(" ")
     for arg in arg_group:
         if re.search(r"新番|月?番", arg) and len(arg) <= 5:
-            raw = re.sub(f".{arg}.", "", raw)
+            raw = re.sub(f".{re.escape(arg)}.", "", raw)
         elif re.search(r"港澳台地区", arg):
-            raw = re.sub(f".{arg}.", "", raw)
+            raw = re.sub(f".{re.escape(arg)}.", "", raw)
     return raw
 
 
@@ -79,7 +79,7 @@ def season_process(season_info: str):
                 season = int(season_pro)
             except ValueError:
                 season = CHINESE_NUMBER_MAP[season_pro]
-                break
+            break
     return name, season_raw, season
 
 
@@ -140,6 +140,8 @@ def process(raw_title: str):
     group = get_group(content_title)  # 翻译组的名字
     match_obj = TITLE_RE.match(content_title)
+    if match_obj is None:
+        return None
     # 处理标题
     season_info, episode_info, other = list(
         map(lambda x: x.strip(), match_obj.groups())
     )
diff --git a/backend/src/module/parser/analyser/tmdb_parser.py b/backend/src/module/parser/analyser/tmdb_parser.py
index 04ef10f9..5aca52d8 100644
--- a/backend/src/module/parser/analyser/tmdb_parser.py
+++ b/backend/src/module/parser/analyser/tmdb_parser.py
@@ -1,6 +1,8 @@
+import asyncio
 import logging
 import re
 import time
+from collections import OrderedDict
 from dataclasses import dataclass
 
 from module.conf import TMDB_API
@@ -12,7 +14,8 @@ logger = logging.getLogger(__name__)
 TMDB_URL = "https://api.themoviedb.org"
 
 # In-memory cache for TMDB lookups to avoid repeated API calls
-_tmdb_cache: dict[str, "TMDBInfo | None"] = {}
+_TMDB_CACHE_MAX = 512
+_tmdb_cache: OrderedDict[str, "TMDBInfo | None"] = OrderedDict()
 
 
 @dataclass
@@ -26,7 +29,9 @@ class TMDBInfo:
     poster_link: str = None
     series_status: str = None  # "Ended", "Returning Series", etc.
     season_episode_counts: dict[int, int] = None  # {1: 13, 2: 12, ...}
-    virtual_season_starts: dict[int, list[int]] = None  # {1: [1, 29], ...} - episode numbers where virtual seasons start
+    virtual_season_starts: dict[int, list[int]] = (
+        None  # {1: [1, 29], ...} - episode numbers where virtual seasons start
+    )
 
     def get_offset_for_season(self, season: int) -> int:
         """Calculate offset for a season (negative sum of all previous seasons' episodes).
@@ -64,7 +69,9 @@ async def is_animation(tv_id, language, req: RequestContent) -> bool:
     return False
 
 
-async def get_season_episode_air_dates(tv_id: int, season_number: int, language: str, req: RequestContent) -> list[dict]:
+async def get_season_episode_air_dates(
+    tv_id: int, season_number: int, language: str, req: RequestContent
+) -> list[dict]:
     """Get episode air dates for a season.
 
     Returns:
@@ -122,13 +129,17 @@ def detect_virtual_seasons(episodes: list[dict], gap_months: int = 6) -> list[int]:
             logger.debug(
                 "[TMDB] Detected virtual season break: %s days gap "
                 "between ep%s and ep%s",
-                days_diff, prev_ep['episode_number'], curr_ep['episode_number']
+                days_diff,
+                prev_ep["episode_number"],
+                curr_ep["episode_number"],
             )
 
     return virtual_season_starts
 
 
-async def get_aired_episode_count(tv_id: int, season_number: int, language: str, req: RequestContent) -> int:
+async def get_aired_episode_count(
+    tv_id: int, season_number: int, language: str, req: RequestContent
+) -> int:
     """Get the count of episodes that have actually aired for a season.
 
     Args:
@@ -162,20 +173,27 @@ async def get_aired_episode_count(tv_id: int, season_number: int, language: str,
             # Invalid date format, skip this episode
             continue
 
-    logger.debug("[TMDB] Season %s: %s aired of %s total episodes", season_number, aired_count, len(episodes))
+    logger.debug(
+        "[TMDB] Season %s: %s aired of %s total episodes",
+        season_number,
+        aired_count,
+        len(episodes),
+    )
     return aired_count
 
 
 def get_season(seasons: list) -> tuple[int, str]:
     ss = [s for s in seasons if s["air_date"] is not None and "特别" not in s["season"]]
+    if not ss:
+        return 1, None
     ss = sorted(ss, key=lambda e: e.get("air_date"), reverse=True)
     for season in ss:
-        if re.search(r"第 \d 季", season.get("season")) is not None:
+        if re.search(r"第 \d+ 季", season.get("season")) is not None:
             date = season.get("air_date").split("-")
             [year, _, _] = date
             now_year = time.localtime().tm_year
             if int(year) <= now_year:
-                return int(re.findall(r"\d", season.get("season"))[0]), season.get(
+                return int(re.findall(r"\d+", season.get("season"))[0]), season.get(
                     "poster_path"
                 )
     return len(ss), ss[-1].get("poster_path")
@@ -200,11 +218,16 @@ async def tmdb_parser(title, language, test: bool = False) -> TMDBInfo | None:
     contents = contents_resp.get("results")
     # 判断动画
     if contents:
+        matched_id = None
         for content in contents:
-            id = content["id"]
-            if await is_animation(id, language, req):
+            cid = content["id"]
+            if await is_animation(cid, language, req):
+                matched_id = cid
                 break
-        url_info = info_url(id, language)
+        if matched_id is None:
+            _tmdb_cache[cache_key] = None
+            return None
+        url_info = info_url(matched_id, language)
         info_content = await req.get_json(url_info)
         season = [
             {
@@ -221,37 +244,58 @@ async def tmdb_parser(title, language, test: bool = False) -> TMDBInfo | None:
         # For ongoing series, we need to get actual aired episode counts
         season_episode_counts = {}
         virtual_season_starts = {}
-        for s in info_content.get("seasons", []):
-            season_num = s.get("season_number", 0)
-            if season_num > 0:
-                total_eps = s.get("episode_count", 0)
-                # Get episode air dates for virtual season detection
-                episodes = await get_season_episode_air_dates(id, season_num, language, req)
-                if episodes:
-                    # Detect virtual seasons based on air date gaps
-                    vs_starts = detect_virtual_seasons(episodes)
-                    if len(vs_starts) > 1:
-                        virtual_season_starts[season_num] = vs_starts
-                        logger.debug("[TMDB] Season %s has virtual seasons starting at episodes: %s", season_num, vs_starts)
-                    # Count only aired episodes
-                    season_episode_counts[season_num] = len(episodes)
-                else:
-                    season_episode_counts[season_num] = total_eps
+        season_nums = [
+            (s.get("season_number", 0), s.get("episode_count", 0))
+            for s in info_content.get("seasons", [])
+            if s.get("season_number", 0) > 0
+        ]
+        episode_results = await asyncio.gather(
+            *[
+                get_season_episode_air_dates(matched_id, sn, language, req)
+                for sn, _ in season_nums
+            ],
+            return_exceptions=True,
+        )
+        for (season_num, total_eps), episodes in zip(season_nums, episode_results):
+            if isinstance(episodes, Exception):
+                logger.warning(
+                    "[TMDB] Failed to get episodes for season %s: %s",
+                    season_num,
+                    episodes,
+                )
+                season_episode_counts[season_num] = total_eps
+                continue
+            if episodes:
+                # Detect virtual seasons based on air date gaps
+                vs_starts = detect_virtual_seasons(episodes)
+                if len(vs_starts) > 1:
+                    virtual_season_starts[season_num] = vs_starts
+                    logger.debug(
+                        "[TMDB] Season %s has virtual seasons starting at episodes: %s",
+                        season_num,
+                        vs_starts,
+                    )
+                # Count only aired episodes
+                season_episode_counts[season_num] = len(episodes)
+            else:
+                season_episode_counts[season_num] = total_eps
         if poster_path is None:
             poster_path = info_content.get("poster_path")
         original_title = info_content.get("original_name")
         official_title = info_content.get("name")
-        year_number = info_content.get("first_air_date").split("-")[0]
+        year_number = (info_content.get("first_air_date") or "").split("-")[0]
         if poster_path:
             if not test:
-                img = await req.get_content(f"https://image.tmdb.org/t/p/w780{poster_path}")
+                img = await req.get_content(
+                    f"https://image.tmdb.org/t/p/w780{poster_path}"
+                )
                 poster_link = save_image(img, "jpg")
             else:
                 poster_link = "https://image.tmdb.org/t/p/w780" + poster_path
         else:
             poster_link = None
         result = TMDBInfo(
-            id=id,
+            id=matched_id,
             title=official_title,
             original_title=original_title,
             season=season,
@@ -260,15 +304,22 @@ async def tmdb_parser(title, language, test: bool = False) -> TMDBInfo | None:
             poster_link=poster_link,
             series_status=series_status,
             season_episode_counts=season_episode_counts,
-            virtual_season_starts=virtual_season_starts if virtual_season_starts else None,
+            virtual_season_starts=(
+                virtual_season_starts if virtual_season_starts else None
+            ),
         )
+        if len(_tmdb_cache) >= _TMDB_CACHE_MAX:
+            _tmdb_cache.popitem(last=False)
         _tmdb_cache[cache_key] = result
         return result
     else:
+        if len(_tmdb_cache) >= _TMDB_CACHE_MAX:
+            _tmdb_cache.popitem(last=False)
        _tmdb_cache[cache_key] = None
         return None
 
 
 if __name__ == "__main__":
     import asyncio
+
     print(asyncio.run(tmdb_parser("魔法禁书目录", "zh")))
diff --git a/backend/src/module/rss/engine.py b/backend/src/module/rss/engine.py
index aa7c2083..eaea3dc2 100644
--- a/backend/src/module/rss/engine.py
+++ b/backend/src/module/rss/engine.py
@@ -16,6 +16,7 @@ class RSSEngine(Database):
     def __init__(self, _engine=engine):
         super().__init__(_engine)
         self._to_refresh = False
+        self._filter_cache: dict[str, re.Pattern] = {}
 
     @staticmethod
     async def _get_torrents(rss: RSSItem) -> list[Torrent]:
@@ -109,8 +110,6 @@ class RSSEngine(Database):
             logger.warning(f"[Engine] Failed to fetch RSS {rss_item.name}: {e}")
             return [], str(e)
 
-    _filter_cache: dict[str, re.Pattern] = {}
-
     def _get_filter_pattern(self, filter_str: str) -> re.Pattern:
         if filter_str not in self._filter_cache:
            self._filter_cache[filter_str] = re.compile(
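
Note on the retry change in `ab_decorator/__init__.py`: the fixed five-minute sleep becomes a staged back-off driven by `_RETRY_DELAYS`, and only connection-type errors (including the httpx ones) are retried, so genuine programming errors now propagate instead of being swallowed for five attempts. The sketch below is illustrative only and not part of the patch; `call_with_staged_retry` and `fetch` are hypothetical stand-ins for the decorated qBittorrent call.

```python
import asyncio

_RETRY_DELAYS = [5, 15, 45, 120, 300]  # seconds between attempts, as in the diff


async def call_with_staged_retry(fetch, retriable=(ConnectionError, TimeoutError, OSError)):
    """Retry `fetch` up to five times, sleeping longer after each failure."""
    for attempt, delay in enumerate(_RETRY_DELAYS, start=1):
        try:
            return await fetch()
        except retriable as exc:
            # Each failed attempt waits a little longer before the next try.
            print(f"attempt {attempt} failed ({exc!r}); sleeping {delay}s")
            await asyncio.sleep(delay)
    return None  # all attempts failed, like the decorator falling through
```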
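
Note on the `.is_(None)` change in `database/bangumi.py`: `Bangumi.rule_name is None` is a plain Python identity test, evaluated before SQLAlchemy ever sees the expression, so it is always `False` and the old filter could never match rows with a NULL `rule_name` or `save_path`; `.is_(None)` builds a SQL `IS NULL` clause instead. A minimal sketch, using a hypothetical stand-in model rather than the project's real `Bangumi`:

```python
from sqlalchemy import Column, Integer, String, select
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class Demo(Base):  # stand-in model for illustration only
    __tablename__ = "demo"
    id = Column(Integer, primary_key=True)
    rule_name = Column(String)


print(Demo.rule_name is None)  # False: Python identity check, never becomes a SQL filter
print(select(Demo).where(Demo.rule_name.is_(None)))  # ... WHERE demo.rule_name IS NULL
```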
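
Note on the TMDB cache change in `parser/analyser/tmdb_parser.py`: the unbounded module-level dict becomes an `OrderedDict` capped at `_TMDB_CACHE_MAX` entries, with `popitem(last=False)` dropping the oldest insertion once the cap is hit. A small self-contained sketch of that FIFO eviction (cap lowered to 3 for the demo; the patch uses 512):

```python
from collections import OrderedDict

CACHE_MAX = 3
cache = OrderedDict()


def cache_put(key, value):
    # FIFO eviction: drop the oldest inserted entry once the cap is reached.
    if len(cache) >= CACHE_MAX:
        cache.popitem(last=False)
    cache[key] = value


for title in ["A", "B", "C", "D"]:
    cache_put(title, f"info-{title}")
print(list(cache))  # ['B', 'C', 'D']; 'A' was evicted first
```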