diff --git a/backend/pyproject.toml b/backend/pyproject.toml index b70f6a71..65177082 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -1,15 +1,64 @@ +[project] +name = "auto-bangumi" +version = "3.1.0" +description = "AutoBangumi - Automated anime download manager" +requires-python = ">=3.10" +dependencies = [ + "anyio>=4.0.0", + "beautifulsoup4>=4.12.0", + "certifi>=2023.5.7", + "charset-normalizer>=3.1.0", + "click>=8.1.3", + "fastapi>=0.109.0", + "h11>=0.14.0", + "idna>=3.4", + "pydantic>=2.0.0", + "sniffio>=1.3.0", + "soupsieve>=2.4.1", + "typing_extensions>=4.0.0", + "urllib3>=2.0.3", + "uvicorn>=0.27.0", + "Jinja2>=3.1.2", + "python-dotenv>=1.0.0", + "python-jose>=3.3.0", + "passlib>=1.7.4", + "bcrypt>=4.0.1,<4.1", + "python-multipart>=0.0.6", + "sqlmodel>=0.0.14", + "sse-starlette>=1.6.5", + "semver>=3.0.1", + "openai>=1.54.3", + "httpx>=0.25.0", + "httpx-socks>=0.9.0", + "aiosqlite>=0.19.0", + "sqlalchemy[asyncio]>=2.0.0", + "webauthn>=2.0.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0.0", + "pytest-asyncio>=0.23.0", + "ruff>=0.1.0", + "black>=24.0.0", +] + +[tool.pytest.ini_options] +testpaths = ["src/test"] +asyncio_mode = "auto" + [tool.ruff] select = [ # pycodestyle(E): https://beta.ruff.rs/docs/rules/#pycodestyle-e-w - "E", + "E", # Pyflakes(F): https://beta.ruff.rs/docs/rules/#pyflakes-f - "F", + "F", # isort(I): https://beta.ruff.rs/docs/rules/#isort-i "I" ] ignore = [ # E501: https://beta.ruff.rs/docs/rules/line-too-long/ - 'E501', + 'E501', # F401: https://beta.ruff.rs/docs/rules/unused-import/ # avoid unused imports lint in `__init__.py` 'F401', diff --git a/backend/requirements.txt b/backend/requirements.txt index 448b57a1..6c3b848f 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,29 +1,30 @@ -anyio==3.7.0 +anyio>=4.0.0 bs4==0.0.1 -certifi==2023.5.7 -charset-normalizer==3.1.0 -click==8.1.3 -fastapi==0.97.0 -h11==0.14.0 -idna==3.4 -pydantic~=1.10 -PySocks==1.7.1 -qbittorrent-api==2023.9.53 -requests==2.31.0 -six==1.16.0 -sniffio==1.3.0 -soupsieve==2.4.1 -typing_extensions -urllib3==2.0.3 -uvicorn==0.22.0 -attrdict==2.0.1 -Jinja2==3.1.2 -python-dotenv==1.0.0 -python-jose==3.3.0 -passlib==1.7.4 -bcrypt==4.0.1 -python-multipart==0.0.6 -sqlmodel==0.0.8 -sse-starlette==1.6.5 -semver==3.0.1 -openai==1.54.3 +certifi>=2023.5.7 +charset-normalizer>=3.1.0 +click>=8.1.3 +fastapi>=0.109.0 +h11>=0.14.0 +idna>=3.4 +pydantic>=2.0.0 +six>=1.16.0 +sniffio>=1.3.0 +soupsieve>=2.4.1 +typing_extensions>=4.0.0 +urllib3>=2.0.3 +uvicorn>=0.27.0 +Jinja2>=3.1.2 +python-dotenv>=1.0.0 +python-jose>=3.3.0 +passlib>=1.7.4 +bcrypt>=4.0.1 +python-multipart>=0.0.6 +sqlmodel>=0.0.14 +sse-starlette>=1.6.5 +semver>=3.0.1 +openai>=1.54.3 +httpx>=0.25.0 +httpx-socks>=0.9.0 +aiosqlite>=0.19.0 +sqlalchemy[asyncio]>=2.0.0 +webauthn>=2.0.0 diff --git a/backend/src/dev_server.py b/backend/src/dev_server.py new file mode 100644 index 00000000..dda12bbd --- /dev/null +++ b/backend/src/dev_server.py @@ -0,0 +1,47 @@ +"""Minimal dev server that skips downloader check for UI testing.""" +import uvicorn +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi import APIRouter + +from module.database.combine import Database +from module.database.engine import engine + +# Initialize DB + migrations + default user +with Database(engine) as db: + db.create_table() + db.user.add_default_user() + +# Build v1 router without program router (which blocks on downloader check) +from module.api.auth import router as auth_router +from module.api.bangumi import router as bangumi_router +from module.api.config import router as config_router +from module.api.log import router as log_router +from module.api.rss import router as rss_router +from module.api.search import router as search_router + +v1 = APIRouter(prefix="/v1") +v1.include_router(auth_router) +v1.include_router(bangumi_router) +v1.include_router(config_router) +v1.include_router(log_router) +v1.include_router(rss_router) +v1.include_router(search_router) + +# Stub status endpoint (real one lives in program router which blocks on downloader) +@v1.get("/status") +async def stub_status(): + return {"status": True, "version": "dev"} + +app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) +app.include_router(v1, prefix="/api") + +if __name__ == "__main__": + uvicorn.run(app, host="127.0.0.1", port=7892) diff --git a/backend/src/module/api/__init__.py b/backend/src/module/api/__init__.py index 38999e4d..651e22db 100644 --- a/backend/src/module/api/__init__.py +++ b/backend/src/module/api/__init__.py @@ -4,6 +4,7 @@ from .auth import router as auth_router from .bangumi import router as bangumi_router from .config import router as config_router from .log import router as log_router +from .passkey import router as passkey_router from .program import router as program_router from .rss import router as rss_router from .search import router as search_router @@ -13,6 +14,7 @@ __all__ = "v1" # API 1.0 v1 = APIRouter(prefix="/v1") v1.include_router(auth_router) +v1.include_router(passkey_router) v1.include_router(log_router) v1.include_router(program_router) v1.include_router(bangumi_router) diff --git a/backend/src/module/api/bangumi.py b/backend/src/module/api/bangumi.py index 3ddd62b0..d24f5319 100644 --- a/backend/src/module/api/bangumi.py +++ b/backend/src/module/api/bangumi.py @@ -127,6 +127,17 @@ async def refresh_poster(bangumi_id: int): return u_response(resp) +@router.get( + path="/refresh/calendar", + response_model=APIResponse, + dependencies=[Depends(get_current_user)], +) +async def refresh_calendar(): + with TorrentManager() as manager: + resp = manager.refresh_calendar() + return u_response(resp) + + @router.get( "/reset/all", response_model=APIResponse, dependencies=[Depends(get_current_user)] ) diff --git a/backend/src/module/api/passkey.py b/backend/src/module/api/passkey.py new file mode 100644 index 00000000..0275a9a2 --- /dev/null +++ b/backend/src/module/api/passkey.py @@ -0,0 +1,281 @@ +""" +Passkey 管理 API +用于注册、列表、删除 Passkey 凭证 +""" +import logging +from datetime import timedelta + +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import JSONResponse, Response +from sqlmodel import select + +from module.database.engine import async_session_factory +from module.database.passkey import PasskeyDatabase +from module.models import APIResponse +from module.models.passkey import ( + PasskeyAuthFinish, + PasskeyAuthStart, + PasskeyCreate, + PasskeyDelete, + PasskeyList, +) +from module.models.user import User +from module.security.api import active_user, get_current_user +from module.security.auth_strategy import PasskeyAuthStrategy +from module.security.jwt import create_access_token +from module.security.webauthn import get_webauthn_service + +logger = logging.getLogger(__name__) +router = APIRouter(prefix="/passkey", tags=["passkey"]) + + +def _get_webauthn_from_request(request: Request): + """ + 从请求中构造 WebAuthnService + 优先使用浏览器的 Origin header(与 clientDataJSON 中的 origin 一致) + """ + from urllib.parse import urlparse + + origin = request.headers.get("origin") + if not origin: + # Fallback: 从 Referer 或 Host 推断 + referer = request.headers.get("referer", "") + if referer: + parsed = urlparse(referer) + origin = f"{parsed.scheme}://{parsed.netloc}" + else: + host = request.headers.get("host", "localhost:7892") + forwarded_proto = request.headers.get("x-forwarded-proto") + scheme = forwarded_proto if forwarded_proto else request.url.scheme + origin = f"{scheme}://{host}" + + parsed_origin = urlparse(origin) + rp_id = parsed_origin.hostname or "localhost" + + return get_webauthn_service(rp_id, "AutoBangumi", origin) + + +# ============ 注册流程 ============ + + +@router.post("/register/options", response_model=dict) +async def get_registration_options( + request: Request, + username: str = Depends(get_current_user), +): + """ + 生成 Passkey 注册选项 + 前端调用 navigator.credentials.create() 时使用 + """ + webauthn = _get_webauthn_from_request(request) + + async with async_session_factory() as session: + try: + # Get user + result = await session.execute( + select(User).where(User.username == username) + ) + user = result.scalar_one_or_none() + if not user: + raise HTTPException(status_code=404, detail="User not found") + + # Get existing passkeys + passkey_db = PasskeyDatabase(session) + existing_passkeys = await passkey_db.get_passkeys_by_user_id(user.id) + + options = webauthn.generate_registration_options( + username=username, + user_id=user.id, + existing_passkeys=existing_passkeys, + ) + + return options + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to generate registration options: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/register/verify", response_model=APIResponse) +async def verify_registration( + passkey_data: PasskeyCreate, + request: Request, + username: str = Depends(get_current_user), +): + """ + 验证 Passkey 注册响应并保存 + """ + webauthn = _get_webauthn_from_request(request) + + async with async_session_factory() as session: + try: + # Get user + result = await session.execute( + select(User).where(User.username == username) + ) + user = result.scalar_one_or_none() + if not user: + raise HTTPException(status_code=404, detail="User not found") + + # 验证 WebAuthn 响应 + passkey = webauthn.verify_registration( + username=username, + credential=passkey_data.attestation_response, + device_name=passkey_data.name, + ) + + # 设置 user_id 并保存 + passkey.user_id = user.id + passkey_db = PasskeyDatabase(session) + await passkey_db.create_passkey(passkey) + + return JSONResponse( + status_code=200, + content={ + "msg_en": f"Passkey '{passkey_data.name}' registered successfully", + "msg_zh": f"Passkey '{passkey_data.name}' 注册成功", + }, + ) + + except ValueError as e: + logger.warning(f"Registration verification failed for {username}: {e}") + raise HTTPException(status_code=400, detail=str(e)) + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to register passkey: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +# ============ 认证流程 ============ + + +@router.post("/auth/options", response_model=dict) +async def get_passkey_login_options( + auth_data: PasskeyAuthStart, + request: Request, +): + """ + 生成 Passkey 登录选项(challenge) + 前端先调用此接口,再调用 navigator.credentials.get() + """ + webauthn = _get_webauthn_from_request(request) + + async with async_session_factory() as session: + try: + # Get user + result = await session.execute( + select(User).where(User.username == auth_data.username) + ) + user = result.scalar_one_or_none() + if not user: + raise HTTPException(status_code=404, detail="User not found") + + passkey_db = PasskeyDatabase(session) + passkeys = await passkey_db.get_passkeys_by_user_id(user.id) + + if not passkeys: + raise HTTPException( + status_code=400, detail="No passkeys registered for this user" + ) + + options = webauthn.generate_authentication_options( + auth_data.username, passkeys + ) + return options + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to generate login options: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/auth/verify", response_model=dict) +async def login_with_passkey( + auth_data: PasskeyAuthFinish, + response: Response, + request: Request, +): + """ + 使用 Passkey 登录(替代密码登录) + """ + webauthn = _get_webauthn_from_request(request) + + strategy = PasskeyAuthStrategy(webauthn) + resp = await strategy.authenticate(auth_data.username, auth_data.credential) + + if resp.status: + token = create_access_token( + data={"sub": auth_data.username}, expires_delta=timedelta(days=1) + ) + response.set_cookie(key="token", value=token, httponly=True, max_age=86400) + if auth_data.username not in active_user: + active_user.append(auth_data.username) + return {"access_token": token, "token_type": "bearer"} + + raise HTTPException(status_code=resp.status_code, detail=resp.msg_en) + + +# ============ Passkey 管理 ============ + + +@router.get("/list", response_model=list[PasskeyList]) +async def list_passkeys(username: str = Depends(get_current_user)): + """获取用户的所有 Passkey""" + async with async_session_factory() as session: + try: + # Get user + result = await session.execute( + select(User).where(User.username == username) + ) + user = result.scalar_one_or_none() + if not user: + raise HTTPException(status_code=404, detail="User not found") + + passkey_db = PasskeyDatabase(session) + passkeys = await passkey_db.get_passkeys_by_user_id(user.id) + + return [passkey_db.to_list_model(pk) for pk in passkeys] + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to list passkeys: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/delete", response_model=APIResponse) +async def delete_passkey( + delete_data: PasskeyDelete, + username: str = Depends(get_current_user), +): + """删除 Passkey""" + async with async_session_factory() as session: + try: + # Get user + result = await session.execute( + select(User).where(User.username == username) + ) + user = result.scalar_one_or_none() + if not user: + raise HTTPException(status_code=404, detail="User not found") + + passkey_db = PasskeyDatabase(session) + await passkey_db.delete_passkey(delete_data.passkey_id, user.id) + + return JSONResponse( + status_code=200, + content={ + "msg_en": "Passkey deleted successfully", + "msg_zh": "Passkey 删除成功", + }, + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to delete passkey: {e}") + raise HTTPException(status_code=500, detail=str(e)) diff --git a/backend/src/module/api/program.py b/backend/src/module/api/program.py index fb8f5c84..4a859e4b 100644 --- a/backend/src/module/api/program.py +++ b/backend/src/module/api/program.py @@ -24,7 +24,7 @@ async def startup(): @router.on_event("shutdown") async def shutdown(): - program.stop() + await program.stop() @router.get( @@ -69,7 +69,8 @@ async def start(): "/stop", response_model=APIResponse, dependencies=[Depends(get_current_user)] ) async def stop(): - return u_response(program.stop()) + resp = await program.stop() + return u_response(resp) @router.get("/status", response_model=dict, dependencies=[Depends(get_current_user)]) @@ -92,7 +93,7 @@ async def program_status(): "/shutdown", response_model=APIResponse, dependencies=[Depends(get_current_user)] ) async def shutdown_program(): - program.stop() + await program.stop() logger.info("Shutting down program...") os.kill(os.getpid(), signal.SIGINT) return JSONResponse( @@ -112,4 +113,4 @@ async def shutdown_program(): dependencies=[Depends(get_current_user)], ) async def check_downloader_status(): - return program.check_downloader() + return await program.check_downloader() diff --git a/backend/src/module/checker/checker.py b/backend/src/module/checker/checker.py index 4ff6b3a0..43196586 100644 --- a/backend/src/module/checker/checker.py +++ b/backend/src/module/checker/checker.py @@ -1,10 +1,9 @@ import logging from pathlib import Path -import requests +import httpx from module.conf import VERSION, settings -from module.downloader import DownloadClient from module.models import Config from module.update import version_check @@ -49,27 +48,28 @@ class Checker: return True @staticmethod - def check_downloader() -> bool: + async def check_downloader() -> bool: + from module.downloader import DownloadClient try: url = ( f"http://{settings.downloader.host}" if "://" not in settings.downloader.host else f"{settings.downloader.host}" ) - response = requests.get(url, timeout=2) - # if settings.downloader.type in response.text.lower(): + async with httpx.AsyncClient(timeout=2.0) as client: + response = await client.get(url) if "qbittorrent" in response.text.lower() or "vuetorrent" in response.text.lower(): - with DownloadClient() as client: - if client.authed: + async with DownloadClient() as dl_client: + if dl_client.authed: return True else: return False else: return False - except requests.exceptions.ReadTimeout: + except httpx.TimeoutException: logger.error("[Checker] Downloader connect timeout.") return False - except requests.exceptions.ConnectionError: + except httpx.ConnectError: logger.error("[Checker] Downloader connect failed.") return False except Exception as e: diff --git a/backend/src/module/core/program.py b/backend/src/module/core/program.py index 7534b848..6e5a4f48 100644 --- a/backend/src/module/core/program.py +++ b/backend/src/module/core/program.py @@ -16,14 +16,14 @@ from .sub_thread import RenameThread, RSSThread logger = logging.getLogger(__name__) figlet = r""" - _ ____ _ - /\ | | | _ \ (_) - / \ _ _| |_ ___ | |_) | __ _ _ __ __ _ _ _ _ __ ___ _ - / /\ \| | | | __/ _ \| _ < / _` | '_ \ / _` | | | | '_ ` _ \| | - / ____ \ |_| | || (_) | |_) | (_| | | | | (_| | |_| | | | | | | | - /_/ \_\__,_|\__\___/|____/ \__,_|_| |_|\__, |\__,_|_| |_| |_|_| - __/ | - |___/ + _ ____ _ + /\ | | | _ \ (_) + / \ _ _| |_ ___ | |_) | __ _ _ __ __ _ _ _ _ __ ___ _ + / /\ \| | | | __/ _ \| _ < / _` | '_ \ / _` | | | | '_ ` _ \| | + / ____ \ |_| | || (_) | |_) | (_| | | | | (_| | |_| | | | | | | | +/_/ \_\__,_|\__\___/|____/ \__,_|_| |_|\__, |\__,_|_| |_| |_|_| + __/ | + |___/ """ @@ -61,7 +61,7 @@ class Program(RenameThread, RSSThread): async def start(self): self.stop_event.clear() settings.load() - while not self.downloader_status: + while not await self.check_downloader_status(): logger.warning("Downloader is not running.") logger.info("Waiting for downloader to start.") await asyncio.sleep(30) @@ -77,11 +77,11 @@ class Program(RenameThread, RSSThread): msg_zh="程序启动成功。", ) - def stop(self): + async def stop(self): if self.is_running: self.stop_event.set() - self.rename_stop() - self.rss_stop() + await self.rename_stop() + await self.rss_stop() return ResponseModel( status=True, status_code=200, @@ -97,7 +97,7 @@ class Program(RenameThread, RSSThread): ) async def restart(self): - self.stop() + await self.stop() await self.start() return ResponseModel( status=True, diff --git a/backend/src/module/core/status.py b/backend/src/module/core/status.py index 152e4988..86cc8672 100644 --- a/backend/src/module/core/status.py +++ b/backend/src/module/core/status.py @@ -1,5 +1,4 @@ import asyncio -import threading from module.checker import Checker from module.conf import LEGACY_DATA_PATH @@ -8,8 +7,8 @@ from module.conf import LEGACY_DATA_PATH class ProgramStatus(Checker): def __init__(self): super().__init__() - self.stop_event = threading.Event() - self.lock = threading.Lock() + self.stop_event = asyncio.Event() + self.lock = asyncio.Lock() self._downloader_status = False self._torrents_status = False self.event = asyncio.Event() @@ -27,8 +26,11 @@ class ProgramStatus(Checker): @property def downloader_status(self): + return self._downloader_status + + async def check_downloader_status(self) -> bool: if not self._downloader_status: - self._downloader_status = self.check_downloader() + self._downloader_status = await self.check_downloader() return self._downloader_status @property diff --git a/backend/src/module/database/combine.py b/backend/src/module/database/combine.py index a809b748..d91d0eda 100644 --- a/backend/src/module/database/combine.py +++ b/backend/src/module/database/combine.py @@ -1,6 +1,10 @@ +import logging + +from sqlalchemy import inspect, text from sqlmodel import Session, SQLModel from module.models import Bangumi, User +from module.models.passkey import Passkey from .bangumi import BangumiDatabase from .engine import engine as e @@ -8,6 +12,8 @@ from .rss import RSSDatabase from .torrent import TorrentDatabase from .user import UserDatabase +logger = logging.getLogger(__name__) + class Database(Session): def __init__(self, engine=e): @@ -20,6 +26,20 @@ class Database(Session): def create_table(self): SQLModel.metadata.create_all(self.engine) + self._migrate_columns() + + def _migrate_columns(self): + """Add new columns to existing tables if they don't exist.""" + inspector = inspect(self.engine) + if "bangumi" in inspector.get_table_names(): + columns = [col["name"] for col in inspector.get_columns("bangumi")] + if "air_weekday" not in columns: + with self.engine.connect() as conn: + conn.execute( + text("ALTER TABLE bangumi ADD COLUMN air_weekday INTEGER") + ) + conn.commit() + logger.info("[Database] Migrated: added air_weekday column to bangumi table.") def drop_table(self): SQLModel.metadata.drop_all(self.engine) diff --git a/backend/src/module/database/engine.py b/backend/src/module/database/engine.py index c94e792c..5bdb6f64 100644 --- a/backend/src/module/database/engine.py +++ b/backend/src/module/database/engine.py @@ -1,7 +1,13 @@ -from sqlmodel import Session, create_engine +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker +from sqlmodel import create_engine from module.conf import DATA_PATH +# Sync engine (used by Database which extends Session) engine = create_engine(DATA_PATH) -db_session = Session(engine) +# Async engine (for passkey operations) +ASYNC_DATA_PATH = DATA_PATH.replace("sqlite:///", "sqlite+aiosqlite:///") +async_engine = create_async_engine(ASYNC_DATA_PATH) +async_session_factory = sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False) diff --git a/backend/src/module/database/passkey.py b/backend/src/module/database/passkey.py new file mode 100644 index 00000000..17c9fc35 --- /dev/null +++ b/backend/src/module/database/passkey.py @@ -0,0 +1,78 @@ +""" +Passkey 数据库操作层 +""" +import logging +from datetime import datetime +from typing import List, Optional + +from fastapi import HTTPException +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import select + +from module.models.passkey import Passkey, PasskeyList + +logger = logging.getLogger(__name__) + + +class PasskeyDatabase: + def __init__(self, session: AsyncSession): + self.session = session + + async def create_passkey(self, passkey: Passkey) -> Passkey: + """创建新的 Passkey 凭证""" + self.session.add(passkey) + await self.session.commit() + await self.session.refresh(passkey) + logger.info(f"Created passkey '{passkey.name}' for user_id={passkey.user_id}") + return passkey + + async def get_passkey_by_credential_id( + self, credential_id: str + ) -> Optional[Passkey]: + """通过 credential_id 查找 Passkey(用于认证)""" + statement = select(Passkey).where(Passkey.credential_id == credential_id) + result = await self.session.execute(statement) + return result.scalar_one_or_none() + + async def get_passkeys_by_user_id(self, user_id: int) -> List[Passkey]: + """获取用户的所有 Passkey""" + statement = select(Passkey).where(Passkey.user_id == user_id) + result = await self.session.execute(statement) + return list(result.scalars().all()) + + async def get_passkey_by_id(self, passkey_id: int, user_id: int) -> Passkey: + """获取特定 Passkey(带权限检查)""" + statement = select(Passkey).where( + Passkey.id == passkey_id, Passkey.user_id == user_id + ) + result = await self.session.execute(statement) + passkey = result.scalar_one_or_none() + if not passkey: + raise HTTPException(status_code=404, detail="Passkey not found") + return passkey + + async def update_passkey_usage(self, passkey: Passkey, new_sign_count: int): + """更新 Passkey 使用记录(签名计数器 + 最后使用时间)""" + passkey.sign_count = new_sign_count + passkey.last_used_at = datetime.utcnow() + self.session.add(passkey) + await self.session.commit() + + async def delete_passkey(self, passkey_id: int, user_id: int) -> bool: + """删除 Passkey""" + passkey = await self.get_passkey_by_id(passkey_id, user_id) + await self.session.delete(passkey) + await self.session.commit() + logger.info(f"Deleted passkey id={passkey_id} for user_id={user_id}") + return True + + def to_list_model(self, passkey: Passkey) -> PasskeyList: + """转换为安全的列表展示模型""" + return PasskeyList( + id=passkey.id, + name=passkey.name, + created_at=passkey.created_at, + last_used_at=passkey.last_used_at, + backup_eligible=passkey.backup_eligible, + aaguid=passkey.aaguid, + ) diff --git a/backend/src/module/manager/torrent.py b/backend/src/module/manager/torrent.py index cfe8d4a3..fd0cd715 100644 --- a/backend/src/module/manager/torrent.py +++ b/backend/src/module/manager/torrent.py @@ -4,6 +4,7 @@ from module.database import Database from module.downloader import DownloadClient from module.models import Bangumi, BangumiUpdate, ResponseModel from module.parser import TitleParser +from module.parser.analyser.bgm_calendar import fetch_bgm_calendar, match_weekday logger = logging.getLogger(__name__) @@ -154,6 +155,37 @@ class TorrentManager(Database): msg_zh="刷新海报链接成功。", ) + def refresh_calendar(self): + """Fetch Bangumi.tv calendar and update air_weekday for all bangumi.""" + calendar_items = fetch_bgm_calendar() + if not calendar_items: + return ResponseModel( + status_code=500, + status=False, + msg_en="Failed to fetch calendar data from Bangumi.tv.", + msg_zh="从 Bangumi.tv 获取放送表失败。", + ) + bangumis = self.bangumi.search_all() + updated = 0 + for bangumi in bangumis: + if bangumi.deleted: + continue + weekday = match_weekday( + bangumi.official_title, bangumi.title_raw, calendar_items + ) + if weekday is not None and weekday != bangumi.air_weekday: + bangumi.air_weekday = weekday + updated += 1 + if updated > 0: + self.bangumi.update_all(bangumis) + logger.info(f"[Manager] Calendar refresh: updated {updated} bangumi.") + return ResponseModel( + status_code=200, + status=True, + msg_en=f"Calendar refreshed. Updated {updated} anime.", + msg_zh=f"放送表已刷新,更新了 {updated} 部番剧。", + ) + def search_all_bangumi(self): datas = self.bangumi.search_all() if not datas: diff --git a/backend/src/module/models/__init__.py b/backend/src/module/models/__init__.py index bd50c0ec..d610684d 100644 --- a/backend/src/module/models/__init__.py +++ b/backend/src/module/models/__init__.py @@ -1,5 +1,6 @@ from .bangumi import Bangumi, BangumiUpdate, Episode, Notification from .config import Config +from .passkey import Passkey, PasskeyCreate, PasskeyDelete, PasskeyList from .response import APIResponse, ResponseModel from .rss import RSSItem, RSSUpdate from .torrent import EpisodeFile, SubtitleFile, Torrent, TorrentUpdate diff --git a/backend/src/module/models/bangumi.py b/backend/src/module/models/bangumi.py index ebdaa5ee..d9b44153 100644 --- a/backend/src/module/models/bangumi.py +++ b/backend/src/module/models/bangumi.py @@ -27,6 +27,7 @@ class Bangumi(SQLModel, table=True): rule_name: Optional[str] = Field(alias="rule_name", title="番剧规则名") save_path: Optional[str] = Field(alias="save_path", title="番剧保存路径") deleted: bool = Field(False, alias="deleted", title="是否已删除") + air_weekday: Optional[int] = Field(default=None, alias="air_weekday", title="放送星期") class BangumiUpdate(SQLModel): @@ -50,6 +51,7 @@ class BangumiUpdate(SQLModel): rule_name: Optional[str] = Field(alias="rule_name", title="番剧规则名") save_path: Optional[str] = Field(alias="save_path", title="番剧保存路径") deleted: bool = Field(False, alias="deleted", title="是否已删除") + air_weekday: Optional[int] = Field(default=None, alias="air_weekday", title="放送星期") class Notification(BaseModel): diff --git a/backend/src/module/models/passkey.py b/backend/src/module/models/passkey.py new file mode 100644 index 00000000..0cfe047e --- /dev/null +++ b/backend/src/module/models/passkey.py @@ -0,0 +1,75 @@ +""" +WebAuthn Passkey 数据模型 +""" +from datetime import datetime +from typing import Optional + +from pydantic import BaseModel +from sqlmodel import Field, SQLModel + + +class Passkey(SQLModel, table=True): + """存储 WebAuthn 凭证的数据库模型""" + + __tablename__ = "passkey" + + id: int = Field(default=None, primary_key=True) + user_id: int = Field(foreign_key="user.id", index=True) + + # 用户友好的名称 (e.g., "iPhone 15", "MacBook Pro") + name: str = Field(min_length=1, max_length=64) + + # WebAuthn 核心字段 + credential_id: str = Field(unique=True, index=True) # Base64URL encoded + public_key: str # CBOR encoded public key, Base64 stored + sign_count: int = Field(default=0) # 防止克隆攻击 + + # 可选的设备信息 + aaguid: Optional[str] = None # Authenticator AAGUID + transports: Optional[str] = None # JSON array: ["usb", "nfc", "ble", "internal"] + + # 审计字段 + created_at: datetime = Field(default_factory=datetime.utcnow) + last_used_at: Optional[datetime] = None + + # 备份状态 (是否为多设备凭证,如 iCloud Keychain) + backup_eligible: bool = Field(default=False) + backup_state: bool = Field(default=False) + + +class PasskeyCreate(BaseModel): + """创建 Passkey 的请求模型""" + + name: str = Field(min_length=1, max_length=64) + # 注册完成后的 WebAuthn 响应 + attestation_response: dict + + +class PasskeyList(BaseModel): + """返回给前端的 Passkey 列表(不含敏感数据)""" + + id: int + name: str + created_at: datetime + last_used_at: Optional[datetime] + backup_eligible: bool + aaguid: Optional[str] + + +class PasskeyDelete(BaseModel): + """删除 Passkey 请求""" + + passkey_id: int + + +class PasskeyAuthStart(BaseModel): + """Passkey 认证开始请求""" + + username: str + + +class PasskeyAuthFinish(BaseModel): + """Passkey 认证完成请求""" + + username: str + credential: dict diff --git a/backend/src/module/parser/analyser/bgm_calendar.py b/backend/src/module/parser/analyser/bgm_calendar.py new file mode 100644 index 00000000..d74e4607 --- /dev/null +++ b/backend/src/module/parser/analyser/bgm_calendar.py @@ -0,0 +1,88 @@ +import logging + +from module.network import RequestContent + +logger = logging.getLogger(__name__) + +BGM_CALENDAR_URL = "https://api.bgm.tv/calendar" + + +def fetch_bgm_calendar() -> list[dict]: + """Fetch the current season's broadcast calendar from Bangumi.tv API. + + Returns a flat list of anime items with their air_weekday (0=Mon, ..., 6=Sun). + """ + with RequestContent() as req: + data = req.get_json(BGM_CALENDAR_URL) + + if not data: + logger.warning("[BGM Calendar] Failed to fetch calendar data.") + return [] + + items = [] + for day_group in data: + weekday_info = day_group.get("weekday", {}) + # Bangumi.tv uses 1=Mon, 2=Tue, ..., 7=Sun + # Convert to 0=Mon, 1=Tue, ..., 6=Sun + bgm_weekday = weekday_info.get("id") + if bgm_weekday is None: + continue + weekday = bgm_weekday - 1 # 1-7 → 0-6 + + for item in day_group.get("items", []): + items.append({ + "name": item.get("name", ""), # Japanese title + "name_cn": item.get("name_cn", ""), # Chinese title + "air_weekday": weekday, + }) + + logger.info(f"[BGM Calendar] Fetched {len(items)} airing anime from Bangumi.tv.") + return items + + +def match_weekday(official_title: str, title_raw: str, calendar_items: list[dict]) -> int | None: + """Match a bangumi against calendar items to find its air weekday. + + Matching strategy: + 1. Exact match on Chinese title (name_cn == official_title) + 2. Exact match on Japanese title (name == title_raw or official_title) + 3. Substring match (name_cn in official_title or vice versa) + 4. Substring match on Japanese title + """ + official_title_clean = official_title.strip() + title_raw_clean = title_raw.strip() + + for item in calendar_items: + name_cn = item["name_cn"].strip() + name = item["name"].strip() + + if not name_cn and not name: + continue + + # Exact match on Chinese title + if name_cn and name_cn == official_title_clean: + return item["air_weekday"] + + # Exact match on Japanese/original title + if name and (name == title_raw_clean or name == official_title_clean): + return item["air_weekday"] + + # Second pass: substring matching + for item in calendar_items: + name_cn = item["name_cn"].strip() + name = item["name"].strip() + + if not name_cn and not name: + continue + + # Chinese title substring (at least 4 chars to avoid false positives) + if name_cn and len(name_cn) >= 4: + if name_cn in official_title_clean or official_title_clean in name_cn: + return item["air_weekday"] + + # Japanese title substring + if name and len(name) >= 4: + if name in title_raw_clean or title_raw_clean in name: + return item["air_weekday"] + + return None diff --git a/backend/src/module/security/auth_strategy.py b/backend/src/module/security/auth_strategy.py new file mode 100644 index 00000000..4b17a24c --- /dev/null +++ b/backend/src/module/security/auth_strategy.py @@ -0,0 +1,104 @@ +""" +认证策略抽象层 +将密码认证和 Passkey 认证统一为策略模式 +""" +from abc import ABC, abstractmethod + +from sqlmodel import select + +from module.database.engine import async_session_factory +from module.database.passkey import PasskeyDatabase +from module.models import ResponseModel +from module.models.user import User + + +class AuthStrategy(ABC): + """认证策略基类""" + + @abstractmethod + async def authenticate(self, username: str, credential: dict) -> ResponseModel: + """ + 执行认证 + + Args: + username: 用户名 + credential: 认证凭证(密码或 WebAuthn 响应) + + Returns: + ResponseModel with status and user info + """ + pass + + +class PasskeyAuthStrategy(AuthStrategy): + """Passkey 认证策略""" + + def __init__(self, webauthn_service): + self.webauthn_service = webauthn_service + + async def authenticate(self, username: str, credential: dict) -> ResponseModel: + """使用 WebAuthn Passkey 认证""" + async with async_session_factory() as session: + # 1. 查找用户 + try: + result = await session.execute( + select(User).where(User.username == username) + ) + user = result.scalar_one_or_none() + if not user: + raise ValueError("User not found") + except ValueError: + return ResponseModel( + status_code=401, + status=False, + msg_en="User not found", + msg_zh="用户不存在", + ) + + # 2. 提取 credential_id 并查找对应的 passkey + try: + raw_id = credential.get("rawId") + if not raw_id: + raise ValueError("Missing credential ID") + + # 将 rawId 从 base64url 转换为标准格式 + credential_id_str = self.webauthn_service.base64url_encode( + self.webauthn_service.base64url_decode(raw_id) + ) + + passkey_db = PasskeyDatabase(session) + passkey = await passkey_db.get_passkey_by_credential_id(credential_id_str) + if not passkey or passkey.user_id != user.id: + raise ValueError("Passkey not found or not owned by user") + + except Exception: + return ResponseModel( + status_code=401, + status=False, + msg_en="Invalid passkey credential", + msg_zh="Passkey 凭证无效", + ) + + # 3. 验证 WebAuthn 签名 + try: + new_sign_count = self.webauthn_service.verify_authentication( + username, credential, passkey + ) + + # 4. 更新使用记录 + await passkey_db.update_passkey_usage(passkey, new_sign_count) + + return ResponseModel( + status_code=200, + status=True, + msg_en="Login successfully with passkey", + msg_zh="通过 Passkey 登录成功", + ) + + except ValueError as e: + return ResponseModel( + status_code=401, + status=False, + msg_en=f"Passkey verification failed: {str(e)}", + msg_zh=f"Passkey 验证失败: {str(e)}", + ) diff --git a/backend/src/module/security/webauthn.py b/backend/src/module/security/webauthn.py new file mode 100644 index 00000000..c7228fc1 --- /dev/null +++ b/backend/src/module/security/webauthn.py @@ -0,0 +1,277 @@ +""" +WebAuthn 认证服务层 +封装 py_webauthn 库的复杂性,提供清晰的注册和认证接口 +""" +import base64 +import json +import logging +from typing import List, Optional + +from webauthn import ( + generate_authentication_options, + generate_registration_options, + options_to_json, + verify_authentication_response, + verify_registration_response, +) +from webauthn.helpers.cose import COSEAlgorithmIdentifier +from webauthn.helpers.structs import ( + AuthenticatorSelectionCriteria, + AuthenticatorTransport, + CredentialDeviceType, + PublicKeyCredentialDescriptor, + PublicKeyCredentialType, + ResidentKeyRequirement, + UserVerificationRequirement, +) + +from module.models.passkey import Passkey + +logger = logging.getLogger(__name__) + + +class WebAuthnService: + """WebAuthn 核心业务逻辑""" + + def __init__(self, rp_id: str, rp_name: str, origin: str): + """ + Args: + rp_id: 依赖方 ID (e.g., "localhost" or "autobangumi.example.com") + rp_name: 依赖方名称 (e.g., "AutoBangumi") + origin: 前端 origin (e.g., "http://localhost:5173") + """ + self.rp_id = rp_id + self.rp_name = rp_name + self.origin = origin + + # 存储临时的 challenge(生产环境应使用 Redis) + self._challenges: dict[str, bytes] = {} + + # ============ 注册流程 ============ + + def generate_registration_options( + self, username: str, user_id: int, existing_passkeys: List[Passkey] + ) -> dict: + """ + 生成 WebAuthn 注册选项 + + Args: + username: 用户名 + user_id: 用户 ID(转为 bytes) + existing_passkeys: 用户已有的 Passkey(用于排除) + + Returns: + JSON-serializable registration options + """ + # 将已有凭证转为排除列表 + exclude_credentials = [ + PublicKeyCredentialDescriptor( + id=self.base64url_decode(pk.credential_id), + type=PublicKeyCredentialType.PUBLIC_KEY, + transports=self._parse_transports(pk.transports), + ) + for pk in existing_passkeys + ] + + options = generate_registration_options( + rp_id=self.rp_id, + rp_name=self.rp_name, + user_id=str(user_id).encode("utf-8"), + user_name=username, + user_display_name=username, + exclude_credentials=exclude_credentials if exclude_credentials else None, + authenticator_selection=AuthenticatorSelectionCriteria( + resident_key=ResidentKeyRequirement.PREFERRED, + user_verification=UserVerificationRequirement.PREFERRED, + ), + supported_pub_key_algs=[ + COSEAlgorithmIdentifier.ECDSA_SHA_256, # -7: ES256 + COSEAlgorithmIdentifier.RSASSA_PKCS1_v1_5_SHA_256, # -257: RS256 + ], + ) + + # 存储 challenge 用于后续验证 + challenge_key = f"reg_{username}" + self._challenges[challenge_key] = options.challenge + logger.debug(f"Generated registration challenge for {username}") + + return json.loads(options_to_json(options)) + + def verify_registration( + self, username: str, credential: dict, device_name: str + ) -> Passkey: + """ + 验证注册响应并创建 Passkey 对象 + + Args: + username: 用户名 + credential: 来自前端的 credential 响应 + device_name: 用户输入的设备名称 + + Returns: + Passkey 对象(未保存到数据库) + + Raises: + ValueError: 验证失败 + """ + challenge_key = f"reg_{username}" + expected_challenge = self._challenges.get(challenge_key) + if not expected_challenge: + raise ValueError("Challenge not found or expired") + + try: + verification = verify_registration_response( + credential=credential, + expected_challenge=expected_challenge, + expected_rp_id=self.rp_id, + expected_origin=self.origin, + ) + + # 构造 Passkey 对象 + passkey = Passkey( + user_id=0, # 调用方设置 + name=device_name, + credential_id=self.base64url_encode(verification.credential_id), + public_key=base64.b64encode(verification.credential_public_key).decode( + "utf-8" + ), + sign_count=verification.sign_count, + aaguid=verification.aaguid if verification.aaguid else None, + backup_eligible=verification.credential_device_type + == CredentialDeviceType.MULTI_DEVICE, + backup_state=verification.credential_backed_up, + ) + + logger.info( + f"Successfully verified registration for {username}, device: {device_name}" + ) + return passkey + + except Exception as e: + logger.error(f"Registration verification failed: {e}") + raise ValueError(f"Invalid registration response: {str(e)}") + finally: + # 清理使用过的 challenge(无论成功或失败都清理,防止重放攻击) + self._challenges.pop(challenge_key, None) + + # ============ 认证流程 ============ + + def generate_authentication_options( + self, username: str, passkeys: List[Passkey] + ) -> dict: + """ + 生成 WebAuthn 认证选项 + + Args: + username: 用户名 + passkeys: 用户的 Passkey 列表(限定可用凭证) + + Returns: + JSON-serializable authentication options + """ + allow_credentials = [ + PublicKeyCredentialDescriptor( + id=self.base64url_decode(pk.credential_id), + type=PublicKeyCredentialType.PUBLIC_KEY, + transports=self._parse_transports(pk.transports), + ) + for pk in passkeys + ] + + options = generate_authentication_options( + rp_id=self.rp_id, + allow_credentials=allow_credentials if allow_credentials else None, + user_verification=UserVerificationRequirement.PREFERRED, + ) + + # 存储 challenge + challenge_key = f"auth_{username}" + self._challenges[challenge_key] = options.challenge + logger.debug(f"Generated authentication challenge for {username}") + + return json.loads(options_to_json(options)) + + def verify_authentication( + self, username: str, credential: dict, passkey: Passkey + ) -> int: + """ + 验证认证响应 + + Args: + username: 用户名 + credential: 来自前端的 credential 响应 + passkey: 对应的 Passkey 对象 + + Returns: + 新的 sign_count(用于更新数据库) + + Raises: + ValueError: 验证失败 + """ + challenge_key = f"auth_{username}" + expected_challenge = self._challenges.get(challenge_key) + if not expected_challenge: + raise ValueError("Challenge not found or expired") + + try: + # 解码 public key + credential_public_key = base64.b64decode(passkey.public_key) + + verification = verify_authentication_response( + credential=credential, + expected_challenge=expected_challenge, + expected_rp_id=self.rp_id, + expected_origin=self.origin, + credential_public_key=credential_public_key, + credential_current_sign_count=passkey.sign_count, + ) + + logger.info(f"Successfully verified authentication for {username}") + return verification.new_sign_count + + except Exception as e: + logger.error(f"Authentication verification failed: {e}") + raise ValueError(f"Invalid authentication response: {str(e)}") + finally: + # 清理 challenge(无论成功或失败都清理,防止重放攻击) + self._challenges.pop(challenge_key, None) + + # ============ 辅助方法 ============ + + def _parse_transports( + self, transports_json: Optional[str] + ) -> List[AuthenticatorTransport]: + """解析存储的 transports JSON""" + if not transports_json: + return [] + try: + transport_strings = json.loads(transports_json) + return [AuthenticatorTransport(t) for t in transport_strings] + except Exception: + return [] + + def base64url_encode(self, data: bytes) -> str: + """Base64URL 编码(无 padding)""" + return base64.urlsafe_b64encode(data).decode("utf-8").rstrip("=") + + def base64url_decode(self, data: str) -> bytes: + """Base64URL 解码(补齐 padding)""" + padding = 4 - len(data) % 4 + if padding != 4: + data += "=" * padding + return base64.urlsafe_b64decode(data) + + +# 全局 WebAuthn 服务实例存储 +_webauthn_services: dict[str, WebAuthnService] = {} + + +def get_webauthn_service(rp_id: str, rp_name: str, origin: str) -> WebAuthnService: + """ + 获取或创建 WebAuthnService 实例 + 使用缓存以保持 challenge 状态 + """ + key = f"{rp_id}:{origin}" + if key not in _webauthn_services: + _webauthn_services[key] = WebAuthnService(rp_id, rp_name, origin) + return _webauthn_services[key] diff --git a/backend/src/test_passkey_server.py b/backend/src/test_passkey_server.py new file mode 100644 index 00000000..83a4080e --- /dev/null +++ b/backend/src/test_passkey_server.py @@ -0,0 +1,35 @@ +""" +Minimal test server for passkey development. +Uses the real auth and passkey API routes without the downloader check. +Run with: uv run python test_passkey_server.py +""" +import uvicorn +from fastapi import FastAPI + +from module.api.auth import router as auth_router +from module.api.passkey import router as passkey_router +from module.database import Database +from module.update.startup import first_run + +app = FastAPI(title="AutoBangumi Passkey Test") + +# Mount real routers +app.include_router(auth_router, prefix="/api/v1") +app.include_router(passkey_router, prefix="/api/v1") + + +@app.on_event("startup") +async def startup(): + """Create tables and default user (no downloader check)""" + with Database() as db: + db.create_table() + db.user.add_default_user() + + +@app.get("/") +def index(): + return {"status": "Passkey test server running"} + + +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=7892) diff --git a/docs/changelog/3.2.md b/docs/changelog/3.2.md new file mode 100644 index 00000000..b2fd4a7c --- /dev/null +++ b/docs/changelog/3.2.md @@ -0,0 +1,55 @@ +# [3.2] - 2025-01 + +## Backend + +### Features + +- 新增 WebAuthn Passkey 无密码登录支持 + - 支持注册、认证、管理 Passkey 凭证 + - 支持多设备凭证(iCloud Keychain 等)备份检测 + - 支持克隆攻击防护(sign_count 验证) + - 认证策略模式,统一密码登录和 Passkey 登录接口 +- 数据库层新增异步支持(aiosqlite),为 Passkey 操作提供非阻塞 I/O +- `UserDatabase` 支持同步/异步双模式,兼容新旧代码路径 +- `Database` 上下文管理器同时支持 `with`(同步)和 `async with`(异步) + +### Changes + +- 升级 WebAuthn 依赖至 py_webauthn 2.7.0 +- `_get_webauthn_from_request` 优先使用浏览器 Origin header,修复跨端口开发环境下的验证问题 +- `auth_user` 和 `update_user_info` 转为异步函数 + +### Bugfixes + +- 修复 `aaguid` 类型错误(py_webauthn 2.7.0 中为 `str`,不再是 `bytes`) +- 修复 `credential_backup_eligible` 字段不存在的问题(改用 `credential_device_type`) +- 修复 `verify_authentication_response` 传入无效参数 `credential_id` 导致 TypeError +- 修复程序启动错误 +- 修复程序重启错误 +- 修复 episode 解析支持 int 和 float 类型 +- 修复 #805、#855 +- 修复多行标题解析后处理问题 +- 修复全局 RSS 过滤器需要重启才能生效的问题 + +## Frontend + +### Features + +- 全新 UI 设计系统重构 + - 统一的设计令牌(颜色、字体、间距、阴影、动画) + - 支持亮色/暗色主题切换 + - 完善的无障碍访问支持(ARIA、键盘导航、焦点管理) + - 响应式布局适配移动端 +- 新增 Passkey 管理面板(设置页) + - WebAuthn 浏览器支持检测 + - 设备名称自动识别 + - Passkey 列表展示与删除 +- 登录页新增 Passkey 指纹登录按钮 +- 新增可调比例图片组件 +- 新增移动端搜索样式 +- 优化移动端 Bangumi 列表样式 + +### Changes + +- 重构搜索逻辑,移除 rxjs 依赖 +- 升级前端依赖 diff --git a/webui/index.html b/webui/index.html index 356c7f49..6b8f6f66 100644 --- a/webui/index.html +++ b/webui/index.html @@ -4,11 +4,24 @@ - + + + +