fix(backend): improve database migrations, parsers, and network handling

Database:
- Add error handling and per-step version tracking in migrations
- Enable SQLite foreign keys via PRAGMA on connect
- Fix SQLAlchemy NULL comparisons to use .is_(None), add session.merge() for detached instances
- Batch commit for semantic alias merges
- Quote table/field names in fill-null-defaults SQL
- Guard against empty user data in migration

Parsers:
- TMDB: bounded LRU cache (512), asyncio.gather for parallel season
  fetches, fix season regex \d -> \d+, null-safe year, fix id shadowing
- Raw parser: re.escape() for group/prefix regex, None guard on match
- OpenAI: handle Pydantic model_dump, catch ValueError

Network:
- Null-safe get_html() return
- Error handling per RSS item in mikan parser
- Progressive retry delays (5/15/45/120/300s) with specific exceptions
- Platform detection via sys.platform instead of path heuristic
- Move filter cache to instance variable

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Estrella Pan
2026-02-23 11:46:35 +01:00
parent ec4aca5aba
commit 41298f2f8e
11 changed files with 198 additions and 84 deletions

View File

@@ -2,11 +2,15 @@ import asyncio
import functools
import logging
import httpx
from .timeout import timeout
logger = logging.getLogger(__name__)
_lock = asyncio.Lock()
_RETRY_DELAYS = [5, 15, 45, 120, 300]
def qb_connect_failed_wait(func):
@functools.wraps(func)
@@ -15,11 +19,21 @@ def qb_connect_failed_wait(func):
while times < 5:
try:
return await func(*args, **kwargs)
except Exception as e:
except (
ConnectionError,
TimeoutError,
OSError,
httpx.ConnectError,
httpx.TimeoutException,
httpx.RequestError,
) as e:
delay = _RETRY_DELAYS[times]
logger.debug("URL: %s", args[0])
logger.warning(e)
logger.warning("Cannot connect to qBittorrent. Wait 5 min and retry...")
await asyncio.sleep(300)
logger.warning(
"Cannot connect to qBittorrent. Wait %ds and retry...", delay
)
await asyncio.sleep(delay)
times += 1
return wrapper

View File

@@ -1,3 +1,4 @@
import sys
from pathlib import Path
from .config import VERSION, settings
@@ -10,4 +11,4 @@ LEGACY_DATA_PATH = Path("data/data.json")
VERSION_PATH = Path("config/version.info")
POSTERS_PATH = Path("data/posters")
PLATFORM = "Windows" if "\\" in settings.downloader.path else "Unix"
PLATFORM = "Windows" if sys.platform == "win32" else "Unix"

View File

@@ -126,7 +126,9 @@ class BangumiDatabase:
return None
def add_title_alias(self, bangumi_id: int, new_title_raw: str) -> bool:
def add_title_alias(
self, bangumi_id: int, new_title_raw: str, auto_commit: bool = True
) -> bool:
"""
Add a new title_raw alias to an existing bangumi.
@@ -152,8 +154,9 @@ class BangumiDatabase:
_set_aliases_list(bangumi, aliases)
self.session.add(bangumi)
self.session.commit()
_invalidate_bangumi_cache()
if auto_commit:
self.session.commit()
_invalidate_bangumi_cache()
logger.info(
f"[Database] Added alias '{new_title_raw}' to bangumi '{bangumi.official_title}' "
f"(id: {bangumi_id})"
@@ -233,8 +236,8 @@ class BangumiDatabase:
for d in to_add:
semantic_match = self.find_semantic_duplicate(d)
if semantic_match:
# Add as alias instead of creating new entry
self.add_title_alias(semantic_match.id, d.title_raw)
# Add as alias instead of creating new entry (defer commit)
self.add_title_alias(semantic_match.id, d.title_raw, auto_commit=False)
semantic_merged += 1
logger.info(
f"[Database] Merged '{d.title_raw}' as alias to existing "
@@ -254,9 +257,10 @@ class BangumiDatabase:
if not unique_to_add:
if semantic_merged > 0:
self.session.commit()
_invalidate_bangumi_cache()
logger.debug(
"[Database] %s bangumi merged as aliases, "
"rest were duplicates.",
"[Database] %s bangumi merged as aliases, " "rest were duplicates.",
semantic_merged,
)
else:
@@ -330,7 +334,9 @@ class BangumiDatabase:
self.session.add(bangumi)
self.session.commit()
_invalidate_bangumi_cache()
logger.debug("[Database] Update %s poster_link to %s.", title_raw, poster_link)
logger.debug(
"[Database] Update %s poster_link to %s.", title_raw, poster_link
)
def delete_one(self, _id: int):
statement = select(Bangumi).where(Bangumi.id == _id)
@@ -427,6 +433,7 @@ class BangumiDatabase:
rss_link not in match_data.rss_link
and match_data.title_raw not in rss_updated
):
match_data = self.session.merge(match_data)
match_data.rss_link += f",{rss_link}"
match_data.added = False
rss_updated.add(match_data.title_raw)
@@ -481,8 +488,8 @@ class BangumiDatabase:
conditions = select(Bangumi).where(
or_(
Bangumi.added == 0,
Bangumi.rule_name is None,
Bangumi.save_path is None,
Bangumi.rule_name.is_(None),
Bangumi.save_path.is_(None),
)
)
result = self.session.execute(conditions)

View File

@@ -199,17 +199,21 @@ class Database(Session):
if "title_aliases" in columns:
needs_run = False
if needs_run:
with self.engine.connect() as conn:
for stmt in statements:
conn.execute(text(stmt))
conn.commit()
logger.info(f"[Database] Migration v{version}: {description}")
try:
with self.engine.connect() as conn:
for stmt in statements:
conn.execute(text(stmt))
conn.commit()
logger.info(f"[Database] Migration v{version}: {description}")
except Exception as e:
logger.error(f"[Database] Migration v{version} failed: {e}")
break
else:
logger.debug(
f"[Database] Migration v{version} skipped (already applied): {description}"
)
self._set_schema_version(CURRENT_SCHEMA_VERSION)
logger.info(f"[Database] Schema version is now {CURRENT_SCHEMA_VERSION}.")
self._set_schema_version(version)
logger.info(f"[Database] Schema version is now {self._get_schema_version()}.")
self._fill_null_with_defaults()
def _get_field_default(self, field_info: FieldInfo) -> tuple[bool, Any]:
@@ -290,8 +294,8 @@ class Database(Session):
result = conn.execute(
text(
f"UPDATE {table_name} SET {field_name} = :val "
f"WHERE {field_name} IS NULL"
f'UPDATE "{table_name}" SET "{field_name}" = :val '
f'WHERE "{field_name}" IS NULL'
),
{"val": sql_value},
)
@@ -309,6 +313,9 @@ class Database(Session):
# Run migration online
bangumi_data = self.bangumi.search_all()
user_data = self.exec("SELECT * FROM user").all()
if not user_data:
logger.warning("[Database] No user data found, skipping migration.")
return
readd_bangumi = []
for bangumi in bangumi_data:
dict_data = bangumi.dict()
@@ -317,7 +324,10 @@ class Database(Session):
self.drop_table()
self.create_table()
self.commit()
bangumi_data = self.bangumi.search_all()
self.bangumi.add_all(readd_bangumi)
self.add(User(**user_data[0]))
self.commit()
try:
self.bangumi.add_all(readd_bangumi)
self.add(User(**user_data[0]))
self.commit()
except Exception:
self.rollback()
raise

View File

@@ -1,3 +1,4 @@
from sqlalchemy import event
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import create_engine
@@ -10,4 +11,20 @@ engine = create_engine(DATA_PATH)
# Async engine (for passkey operations)
ASYNC_DATA_PATH = DATA_PATH.replace("sqlite:///", "sqlite+aiosqlite:///")
async_engine = create_async_engine(ASYNC_DATA_PATH)
async_session_factory = sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
async_session_factory = sessionmaker(
async_engine, class_=AsyncSession, expire_on_commit=False
)
@event.listens_for(engine, "connect")
def _set_sqlite_fk_sync(dbapi_conn, connection_record):
    """Enable SQLite foreign-key enforcement on every new sync connection.

    SQLite keeps foreign_keys OFF by default and the setting is
    per-connection, so the PRAGMA must be issued each time the pool
    hands out a fresh DBAPI connection.
    """
    cursor = dbapi_conn.cursor()
    try:
        cursor.execute("PRAGMA foreign_keys=ON")
    finally:
        # Always release the cursor, even if the PRAGMA fails.
        cursor.close()
@event.listens_for(async_engine.sync_engine, "connect")
def _set_sqlite_fk_async(dbapi_conn, connection_record):
    """Enable SQLite foreign-key enforcement on every new async connection.

    Listens on the async engine's underlying sync engine, since SQLAlchemy
    pool events fire at the sync-engine level even for aiosqlite.
    """
    cursor = dbapi_conn.cursor()
    try:
        cursor.execute("PRAGMA foreign_keys=ON")
    finally:
        # Always release the cursor, even if the PRAGMA fails.
        cursor.close()

View File

@@ -65,7 +65,7 @@ class RequestContent(RequestURL):
async def get_html(self, _url):
    """GET ``_url`` and return the response body as text.

    Returns None when ``get_url()`` produced no response (request failed),
    instead of raising AttributeError on ``resp.text``.
    """
    resp = await self.get_url(_url)
    # Null-safe: get_url() returns a falsy value on failure.
    return resp.text if resp else None
async def get_content(self, _url):
req = await self.get_url(_url)

View File

@@ -1,15 +1,24 @@
import logging

logger = logging.getLogger(__name__)


def rss_parser(soup):
    """Extract ``(title, url, homepage)`` tuples from a parsed RSS XML tree.

    Args:
        soup: an ElementTree-like root supporting ``findall``/``find``.

    Returns:
        list of (title, download_url, homepage) tuples, one per parseable
        ``<item>``. Malformed items (missing tags, bad structure) are
        logged and skipped so one broken entry cannot abort the feed.
    """
    results = []
    for item in soup.findall("./channel/item"):
        try:
            title = item.find("title").text
            enclosure = item.find("enclosure")
            if enclosure is not None:
                # Torrent-style feeds: the enclosure carries the download
                # URL and <link> points at the item's homepage.
                homepage = item.find("link").text
                url = enclosure.attrib.get("url")
            else:
                # No enclosure: <link> is the download URL itself.
                url = item.find("link").text
                homepage = ""
            results.append((title, url, homepage))
        except Exception as e:
            logger.warning("[RSS] Failed to parse RSS item: %s", e)
            continue
    return results

View File

@@ -1,16 +1,16 @@
import json
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from pydantic import BaseModel
from typing import Optional
from typing import Any, Optional
from openai import OpenAI, AzureOpenAI
from openai import AzureOpenAI, OpenAI
from pydantic import BaseModel
from module.models import Bangumi
logger = logging.getLogger(__name__)
class Episode(BaseModel):
title_en: Optional[str]
title_zh: Optional[str]
@@ -103,10 +103,15 @@ class OpenAIParser:
result = resp.choices[0].message.parsed
if asdict:
try:
result = json.loads(result[result.index("{"):result.rindex("}") + 1]) # find the first { and last } for better compatibility
except json.JSONDecodeError:
logger.warning(f"Cannot parse result {result} as python dict.")
if hasattr(result, "model_dump"):
result = result.model_dump()
else:
try:
result = json.loads(
result[result.index("{") : result.rindex("}") + 1]
) # find the first { and last } for better compatibility
except (json.JSONDecodeError, ValueError):
logger.warning(f"Cannot parse result {result} as python dict.")
logger.debug("the parsed result is: %s", result)
@@ -131,7 +136,6 @@ class OpenAIParser:
dict(role="user", content=text),
],
response_format=Episode,
# set temperature to 0 to make results be more stable and reproducible.
temperature=0,
)

View File

@@ -41,7 +41,7 @@ def pre_process(raw_name: str) -> str:
def prefix_process(raw: str, group: str) -> str:
raw = re.sub(f".{group}.", "", raw)
raw = re.sub(f".{re.escape(group)}.", "", raw)
raw_process = PREFIX_RE.sub("/", raw)
arg_group = raw_process.split("/")
while "" in arg_group:
@@ -50,9 +50,9 @@ def prefix_process(raw: str, group: str) -> str:
arg_group = arg_group[0].split(" ")
for arg in arg_group:
if re.search(r"新番|月?番", arg) and len(arg) <= 5:
raw = re.sub(f".{arg}.", "", raw)
raw = re.sub(f".{re.escape(arg)}.", "", raw)
elif re.search(r"港澳台地区", arg):
raw = re.sub(f".{arg}.", "", raw)
raw = re.sub(f".{re.escape(arg)}.", "", raw)
return raw
@@ -79,7 +79,7 @@ def season_process(season_info: str):
season = int(season_pro)
except ValueError:
season = CHINESE_NUMBER_MAP[season_pro]
break
break
return name, season_raw, season
@@ -140,6 +140,8 @@ def process(raw_title: str):
group = get_group(content_title)
# 翻译组的名字
match_obj = TITLE_RE.match(content_title)
if match_obj is None:
return None
# 处理标题
season_info, episode_info, other = list(
map(lambda x: x.strip(), match_obj.groups())

View File

@@ -1,6 +1,8 @@
import asyncio
import logging
import re
import time
from collections import OrderedDict
from dataclasses import dataclass
from module.conf import TMDB_API
@@ -12,7 +14,8 @@ logger = logging.getLogger(__name__)
TMDB_URL = "https://api.themoviedb.org"
# In-memory cache for TMDB lookups to avoid repeated API calls
_tmdb_cache: dict[str, "TMDBInfo | None"] = {}
_TMDB_CACHE_MAX = 512
_tmdb_cache: OrderedDict[str, "TMDBInfo | None"] = OrderedDict()
@dataclass
@@ -26,7 +29,9 @@ class TMDBInfo:
poster_link: str = None
series_status: str = None # "Ended", "Returning Series", etc.
season_episode_counts: dict[int, int] = None # {1: 13, 2: 12, ...}
virtual_season_starts: dict[int, list[int]] = None # {1: [1, 29], ...} - episode numbers where virtual seasons start
virtual_season_starts: dict[int, list[int]] = (
None # {1: [1, 29], ...} - episode numbers where virtual seasons start
)
def get_offset_for_season(self, season: int) -> int:
"""Calculate offset for a season (negative sum of all previous seasons' episodes).
@@ -64,7 +69,9 @@ async def is_animation(tv_id, language, req: RequestContent) -> bool:
return False
async def get_season_episode_air_dates(tv_id: int, season_number: int, language: str, req: RequestContent) -> list[dict]:
async def get_season_episode_air_dates(
tv_id: int, season_number: int, language: str, req: RequestContent
) -> list[dict]:
"""Get episode air dates for a season.
Returns:
@@ -122,13 +129,17 @@ def detect_virtual_seasons(episodes: list[dict], gap_months: int = 6) -> list[in
logger.debug(
"[TMDB] Detected virtual season break: %s days gap "
"between ep%s and ep%s",
days_diff, prev_ep['episode_number'], curr_ep['episode_number']
days_diff,
prev_ep["episode_number"],
curr_ep["episode_number"],
)
return virtual_season_starts
async def get_aired_episode_count(tv_id: int, season_number: int, language: str, req: RequestContent) -> int:
async def get_aired_episode_count(
tv_id: int, season_number: int, language: str, req: RequestContent
) -> int:
"""Get the count of episodes that have actually aired for a season.
Args:
@@ -162,20 +173,27 @@ async def get_aired_episode_count(tv_id: int, season_number: int, language: str,
# Invalid date format, skip this episode
continue
logger.debug("[TMDB] Season %s: %s aired of %s total episodes", season_number, aired_count, len(episodes))
logger.debug(
"[TMDB] Season %s: %s aired of %s total episodes",
season_number,
aired_count,
len(episodes),
)
return aired_count
def get_season(seasons: list) -> tuple[int, str]:
    """Pick the latest already-aired season number and its poster path.

    Filters out entries with no air date and "特别" (specials) seasons,
    then scans from newest to oldest for a season whose name contains a
    number and whose air year is not in the future.

    Returns:
        (season_number, poster_path). Falls back to ``(1, None)`` when
        nothing usable remains, or to the season count plus the oldest
        entry's poster when no season name carries a number.
    """
    ss = [s for s in seasons if s["air_date"] is not None and "特别" not in s["season"]]
    if not ss:
        return 1, None
    ss = sorted(ss, key=lambda e: e.get("air_date"), reverse=True)
    for season in ss:
        # \d+ (not \d) so multi-digit season numbers like "第 12 季" parse fully.
        if re.search(r"\d+", season.get("season")) is not None:
            date = season.get("air_date").split("-")
            [year, _, _] = date
            now_year = time.localtime().tm_year
            # Skip seasons announced with a future air year.
            if int(year) <= now_year:
                return int(re.findall(r"\d+", season.get("season"))[0]), season.get(
                    "poster_path"
                )
    return len(ss), ss[-1].get("poster_path")
@@ -200,11 +218,16 @@ async def tmdb_parser(title, language, test: bool = False) -> TMDBInfo | None:
contents = contents_resp.get("results")
# 判断动画
if contents:
matched_id = None
for content in contents:
id = content["id"]
if await is_animation(id, language, req):
cid = content["id"]
if await is_animation(cid, language, req):
matched_id = cid
break
url_info = info_url(id, language)
if matched_id is None:
_tmdb_cache[cache_key] = None
return None
url_info = info_url(matched_id, language)
info_content = await req.get_json(url_info)
season = [
{
@@ -221,37 +244,58 @@ async def tmdb_parser(title, language, test: bool = False) -> TMDBInfo | None:
# For ongoing series, we need to get actual aired episode counts
season_episode_counts = {}
virtual_season_starts = {}
for s in info_content.get("seasons", []):
season_num = s.get("season_number", 0)
if season_num > 0:
total_eps = s.get("episode_count", 0)
# Get episode air dates for virtual season detection
episodes = await get_season_episode_air_dates(id, season_num, language, req)
if episodes:
# Detect virtual seasons based on air date gaps
vs_starts = detect_virtual_seasons(episodes)
if len(vs_starts) > 1:
virtual_season_starts[season_num] = vs_starts
logger.debug("[TMDB] Season %s has virtual seasons starting at episodes: %s", season_num, vs_starts)
# Count only aired episodes
season_episode_counts[season_num] = len(episodes)
else:
season_episode_counts[season_num] = total_eps
season_nums = [
(s.get("season_number", 0), s.get("episode_count", 0))
for s in info_content.get("seasons", [])
if s.get("season_number", 0) > 0
]
episode_results = await asyncio.gather(
*[
get_season_episode_air_dates(matched_id, sn, language, req)
for sn, _ in season_nums
],
return_exceptions=True,
)
for (season_num, total_eps), episodes in zip(season_nums, episode_results):
if isinstance(episodes, Exception):
logger.warning(
"[TMDB] Failed to get episodes for season %s: %s",
season_num,
episodes,
)
season_episode_counts[season_num] = total_eps
continue
if episodes:
# Detect virtual seasons based on air date gaps
vs_starts = detect_virtual_seasons(episodes)
if len(vs_starts) > 1:
virtual_season_starts[season_num] = vs_starts
logger.debug(
"[TMDB] Season %s has virtual seasons starting at episodes: %s",
season_num,
vs_starts,
)
# Count only aired episodes
season_episode_counts[season_num] = len(episodes)
else:
season_episode_counts[season_num] = total_eps
if poster_path is None:
poster_path = info_content.get("poster_path")
original_title = info_content.get("original_name")
official_title = info_content.get("name")
year_number = info_content.get("first_air_date").split("-")[0]
year_number = (info_content.get("first_air_date") or "").split("-")[0]
if poster_path:
if not test:
img = await req.get_content(f"https://image.tmdb.org/t/p/w780{poster_path}")
img = await req.get_content(
f"https://image.tmdb.org/t/p/w780{poster_path}"
)
poster_link = save_image(img, "jpg")
else:
poster_link = "https://image.tmdb.org/t/p/w780" + poster_path
else:
poster_link = None
result = TMDBInfo(
id=id,
id=matched_id,
title=official_title,
original_title=original_title,
season=season,
@@ -260,15 +304,22 @@ async def tmdb_parser(title, language, test: bool = False) -> TMDBInfo | None:
poster_link=poster_link,
series_status=series_status,
season_episode_counts=season_episode_counts,
virtual_season_starts=virtual_season_starts if virtual_season_starts else None,
virtual_season_starts=(
virtual_season_starts if virtual_season_starts else None
),
)
if len(_tmdb_cache) >= _TMDB_CACHE_MAX:
_tmdb_cache.popitem(last=False)
_tmdb_cache[cache_key] = result
return result
else:
if len(_tmdb_cache) >= _TMDB_CACHE_MAX:
_tmdb_cache.popitem(last=False)
_tmdb_cache[cache_key] = None
return None
if __name__ == "__main__":
import asyncio
print(asyncio.run(tmdb_parser("魔法禁书目录", "zh")))

View File

@@ -16,6 +16,7 @@ class RSSEngine(Database):
def __init__(self, _engine=engine):
super().__init__(_engine)
self._to_refresh = False
self._filter_cache: dict[str, re.Pattern] = {}
@staticmethod
async def _get_torrents(rss: RSSItem) -> list[Torrent]:
@@ -109,8 +110,6 @@ class RSSEngine(Database):
logger.warning(f"[Engine] Failed to fetch RSS {rss_item.name}: {e}")
return [], str(e)
_filter_cache: dict[str, re.Pattern] = {}
def _get_filter_pattern(self, filter_str: str) -> re.Pattern:
if filter_str not in self._filter_cache:
self._filter_cache[filter_str] = re.compile(