mirror of
https://github.com/EstrellaXD/Auto_Bangumi.git
synced 2026-04-28 04:21:24 +08:00
feat: add WebAuthn passkey authentication support
- Add passkey login as alternative authentication method - Support multiple passkeys per user with custom names - Backend: WebAuthn service, auth strategy pattern, API endpoints - Frontend: passkey management UI in settings, login option - Fix: convert downloader check from sync requests to async httpx to prevent blocking the event loop when downloader unavailable Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -4,6 +4,7 @@ from .auth import router as auth_router
|
||||
from .bangumi import router as bangumi_router
|
||||
from .config import router as config_router
|
||||
from .log import router as log_router
|
||||
from .passkey import router as passkey_router
|
||||
from .program import router as program_router
|
||||
from .rss import router as rss_router
|
||||
from .search import router as search_router
|
||||
@@ -13,6 +14,7 @@ __all__ = "v1"
|
||||
# API 1.0
|
||||
v1 = APIRouter(prefix="/v1")
|
||||
v1.include_router(auth_router)
|
||||
v1.include_router(passkey_router)
|
||||
v1.include_router(log_router)
|
||||
v1.include_router(program_router)
|
||||
v1.include_router(bangumi_router)
|
||||
|
||||
229
backend/src/module/api/passkey.py
Normal file
229
backend/src/module/api/passkey.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""
|
||||
Passkey 管理 API
|
||||
用于注册、列表、删除 Passkey 凭证
|
||||
"""
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
|
||||
from module.database import Database
|
||||
from module.models import APIResponse
|
||||
from module.models.passkey import (
|
||||
PasskeyAuthFinish,
|
||||
PasskeyAuthStart,
|
||||
PasskeyCreate,
|
||||
PasskeyDelete,
|
||||
PasskeyList,
|
||||
)
|
||||
from module.security.api import active_user, get_current_user
|
||||
from module.security.auth_strategy import PasskeyAuthStrategy
|
||||
from module.security.jwt import create_access_token
|
||||
from module.security.webauthn import get_webauthn_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/passkey", tags=["passkey"])
|
||||
|
||||
|
||||
def _get_webauthn_from_request(request: Request):
|
||||
"""
|
||||
从请求中构造 WebAuthnService
|
||||
根据 Host header 动态确定 RP ID 和 origin
|
||||
"""
|
||||
host = request.headers.get("host", "localhost:7892")
|
||||
rp_id = host.split(":")[0] # 去掉端口
|
||||
|
||||
# 判断协议
|
||||
forwarded_proto = request.headers.get("x-forwarded-proto")
|
||||
if forwarded_proto:
|
||||
scheme = forwarded_proto
|
||||
else:
|
||||
scheme = request.url.scheme
|
||||
|
||||
if scheme == "https":
|
||||
origin = f"https://{host}"
|
||||
else:
|
||||
origin = f"http://{host}"
|
||||
|
||||
return get_webauthn_service(rp_id, "AutoBangumi", origin)
|
||||
|
||||
|
||||
# ============ 注册流程 ============
|
||||
|
||||
|
||||
@router.post("/register/options", response_model=dict)
|
||||
async def get_registration_options(
|
||||
request: Request,
|
||||
username: str = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
生成 Passkey 注册选项
|
||||
前端调用 navigator.credentials.create() 时使用
|
||||
"""
|
||||
webauthn = _get_webauthn_from_request(request)
|
||||
|
||||
async with Database() as db:
|
||||
try:
|
||||
user = await db.user.get_user(username)
|
||||
existing_passkeys = await db.passkey.get_passkeys_by_user_id(user.id)
|
||||
|
||||
options = webauthn.generate_registration_options(
|
||||
username=username,
|
||||
user_id=user.id,
|
||||
existing_passkeys=existing_passkeys,
|
||||
)
|
||||
|
||||
return options
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate registration options: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/register/verify", response_model=APIResponse)
|
||||
async def verify_registration(
|
||||
passkey_data: PasskeyCreate,
|
||||
request: Request,
|
||||
username: str = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
验证 Passkey 注册响应并保存
|
||||
"""
|
||||
webauthn = _get_webauthn_from_request(request)
|
||||
|
||||
async with Database() as db:
|
||||
try:
|
||||
user = await db.user.get_user(username)
|
||||
|
||||
# 验证 WebAuthn 响应
|
||||
passkey = webauthn.verify_registration(
|
||||
username=username,
|
||||
credential=passkey_data.attestation_response,
|
||||
device_name=passkey_data.name,
|
||||
)
|
||||
|
||||
# 设置 user_id 并保存
|
||||
passkey.user_id = user.id
|
||||
await db.passkey.create_passkey(passkey)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"msg_en": f"Passkey '{passkey_data.name}' registered successfully",
|
||||
"msg_zh": f"Passkey '{passkey_data.name}' 注册成功",
|
||||
},
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"Registration verification failed for {username}: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register passkey: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ============ 认证流程 ============
|
||||
|
||||
|
||||
@router.post("/auth/options", response_model=dict)
|
||||
async def get_passkey_login_options(
|
||||
auth_data: PasskeyAuthStart,
|
||||
request: Request,
|
||||
):
|
||||
"""
|
||||
生成 Passkey 登录选项(challenge)
|
||||
前端先调用此接口,再调用 navigator.credentials.get()
|
||||
"""
|
||||
webauthn = _get_webauthn_from_request(request)
|
||||
|
||||
async with Database() as db:
|
||||
try:
|
||||
user = await db.user.get_user(auth_data.username)
|
||||
passkeys = await db.passkey.get_passkeys_by_user_id(user.id)
|
||||
|
||||
if not passkeys:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No passkeys registered for this user"
|
||||
)
|
||||
|
||||
options = webauthn.generate_authentication_options(
|
||||
auth_data.username, passkeys
|
||||
)
|
||||
return options
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate login options: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/auth/verify", response_model=dict)
|
||||
async def login_with_passkey(
|
||||
auth_data: PasskeyAuthFinish,
|
||||
response: Response,
|
||||
request: Request,
|
||||
):
|
||||
"""
|
||||
使用 Passkey 登录(替代密码登录)
|
||||
"""
|
||||
webauthn = _get_webauthn_from_request(request)
|
||||
|
||||
strategy = PasskeyAuthStrategy(webauthn)
|
||||
resp = await strategy.authenticate(auth_data.username, auth_data.credential)
|
||||
|
||||
if resp.status:
|
||||
token = create_access_token(
|
||||
data={"sub": auth_data.username}, expires_delta=timedelta(days=1)
|
||||
)
|
||||
response.set_cookie(key="token", value=token, httponly=True, max_age=86400)
|
||||
if auth_data.username not in active_user:
|
||||
active_user.append(auth_data.username)
|
||||
return {"access_token": token, "token_type": "bearer"}
|
||||
|
||||
raise HTTPException(status_code=resp.status_code, detail=resp.msg_en)
|
||||
|
||||
|
||||
# ============ Passkey 管理 ============
|
||||
|
||||
|
||||
@router.get("/list", response_model=list[PasskeyList])
|
||||
async def list_passkeys(username: str = Depends(get_current_user)):
|
||||
"""获取用户的所有 Passkey"""
|
||||
async with Database() as db:
|
||||
try:
|
||||
user = await db.user.get_user(username)
|
||||
passkeys = await db.passkey.get_passkeys_by_user_id(user.id)
|
||||
|
||||
return [db.passkey.to_list_model(pk) for pk in passkeys]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list passkeys: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/delete", response_model=APIResponse)
|
||||
async def delete_passkey(
|
||||
delete_data: PasskeyDelete,
|
||||
username: str = Depends(get_current_user),
|
||||
):
|
||||
"""删除 Passkey"""
|
||||
async with Database() as db:
|
||||
try:
|
||||
user = await db.user.get_user(username)
|
||||
await db.passkey.delete_passkey(delete_data.passkey_id, user.id)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"msg_en": "Passkey deleted successfully",
|
||||
"msg_zh": "Passkey 删除成功",
|
||||
},
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete passkey: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -24,7 +24,7 @@ async def startup():
|
||||
|
||||
@router.on_event("shutdown")
|
||||
async def shutdown():
|
||||
program.stop()
|
||||
await program.stop()
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -69,7 +69,8 @@ async def start():
|
||||
"/stop", response_model=APIResponse, dependencies=[Depends(get_current_user)]
|
||||
)
|
||||
async def stop():
|
||||
return u_response(program.stop())
|
||||
resp = await program.stop()
|
||||
return u_response(resp)
|
||||
|
||||
|
||||
@router.get("/status", response_model=dict, dependencies=[Depends(get_current_user)])
|
||||
@@ -92,7 +93,7 @@ async def program_status():
|
||||
"/shutdown", response_model=APIResponse, dependencies=[Depends(get_current_user)]
|
||||
)
|
||||
async def shutdown_program():
|
||||
program.stop()
|
||||
await program.stop()
|
||||
logger.info("Shutting down program...")
|
||||
os.kill(os.getpid(), signal.SIGINT)
|
||||
return JSONResponse(
|
||||
@@ -112,4 +113,4 @@ async def shutdown_program():
|
||||
dependencies=[Depends(get_current_user)],
|
||||
)
|
||||
async def check_downloader_status():
|
||||
return program.check_downloader()
|
||||
return await program.check_downloader()
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
from module.conf import VERSION, settings
|
||||
from module.downloader import DownloadClient
|
||||
from module.models import Config
|
||||
from module.update import version_check
|
||||
|
||||
@@ -49,27 +48,28 @@ class Checker:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def check_downloader() -> bool:
|
||||
async def check_downloader() -> bool:
|
||||
from module.downloader import DownloadClient
|
||||
try:
|
||||
url = (
|
||||
f"http://{settings.downloader.host}"
|
||||
if "://" not in settings.downloader.host
|
||||
else f"{settings.downloader.host}"
|
||||
)
|
||||
response = requests.get(url, timeout=2)
|
||||
# if settings.downloader.type in response.text.lower():
|
||||
async with httpx.AsyncClient(timeout=2.0) as client:
|
||||
response = await client.get(url)
|
||||
if "qbittorrent" in response.text.lower() or "vuetorrent" in response.text.lower():
|
||||
with DownloadClient() as client:
|
||||
if client.authed:
|
||||
async with DownloadClient() as dl_client:
|
||||
if dl_client.authed:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
except requests.exceptions.ReadTimeout:
|
||||
except httpx.TimeoutException:
|
||||
logger.error("[Checker] Downloader connect timeout.")
|
||||
return False
|
||||
except requests.exceptions.ConnectionError:
|
||||
except httpx.ConnectError:
|
||||
logger.error("[Checker] Downloader connect failed.")
|
||||
return False
|
||||
except Exception as e:
|
||||
|
||||
@@ -16,14 +16,14 @@ from .sub_thread import RenameThread, RSSThread
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
figlet = r"""
|
||||
_ ____ _
|
||||
/\ | | | _ \ (_)
|
||||
/ \ _ _| |_ ___ | |_) | __ _ _ __ __ _ _ _ _ __ ___ _
|
||||
/ /\ \| | | | __/ _ \| _ < / _` | '_ \ / _` | | | | '_ ` _ \| |
|
||||
/ ____ \ |_| | || (_) | |_) | (_| | | | | (_| | |_| | | | | | | |
|
||||
/_/ \_\__,_|\__\___/|____/ \__,_|_| |_|\__, |\__,_|_| |_| |_|_|
|
||||
__/ |
|
||||
|___/
|
||||
_ ____ _
|
||||
/\ | | | _ \ (_)
|
||||
/ \ _ _| |_ ___ | |_) | __ _ _ __ __ _ _ _ _ __ ___ _
|
||||
/ /\ \| | | | __/ _ \| _ < / _` | '_ \ / _` | | | | '_ ` _ \| |
|
||||
/ ____ \ |_| | || (_) | |_) | (_| | | | | (_| | |_| | | | | | | |
|
||||
/_/ \_\__,_|\__\___/|____/ \__,_|_| |_|\__, |\__,_|_| |_| |_|_|
|
||||
__/ |
|
||||
|___/
|
||||
"""
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ class Program(RenameThread, RSSThread):
|
||||
async def start(self):
|
||||
self.stop_event.clear()
|
||||
settings.load()
|
||||
while not self.downloader_status:
|
||||
while not await self.check_downloader_status():
|
||||
logger.warning("Downloader is not running.")
|
||||
logger.info("Waiting for downloader to start.")
|
||||
await asyncio.sleep(30)
|
||||
@@ -77,11 +77,11 @@ class Program(RenameThread, RSSThread):
|
||||
msg_zh="程序启动成功。",
|
||||
)
|
||||
|
||||
def stop(self):
|
||||
async def stop(self):
|
||||
if self.is_running:
|
||||
self.stop_event.set()
|
||||
self.rename_stop()
|
||||
self.rss_stop()
|
||||
await self.rename_stop()
|
||||
await self.rss_stop()
|
||||
return ResponseModel(
|
||||
status=True,
|
||||
status_code=200,
|
||||
@@ -97,7 +97,7 @@ class Program(RenameThread, RSSThread):
|
||||
)
|
||||
|
||||
async def restart(self):
|
||||
self.stop()
|
||||
await self.stop()
|
||||
await self.start()
|
||||
return ResponseModel(
|
||||
status=True,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import threading
|
||||
|
||||
from module.checker import Checker
|
||||
from module.conf import LEGACY_DATA_PATH
|
||||
@@ -8,8 +7,8 @@ from module.conf import LEGACY_DATA_PATH
|
||||
class ProgramStatus(Checker):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.stop_event = threading.Event()
|
||||
self.lock = threading.Lock()
|
||||
self.stop_event = asyncio.Event()
|
||||
self.lock = asyncio.Lock()
|
||||
self._downloader_status = False
|
||||
self._torrents_status = False
|
||||
self.event = asyncio.Event()
|
||||
@@ -27,8 +26,11 @@ class ProgramStatus(Checker):
|
||||
|
||||
@property
|
||||
def downloader_status(self):
|
||||
return self._downloader_status
|
||||
|
||||
async def check_downloader_status(self) -> bool:
|
||||
if not self._downloader_status:
|
||||
self._downloader_status = self.check_downloader()
|
||||
self._downloader_status = await self.check_downloader()
|
||||
return self._downloader_status
|
||||
|
||||
@property
|
||||
|
||||
@@ -1,42 +1,64 @@
|
||||
from sqlmodel import Session, SQLModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from module.models import Bangumi, User
|
||||
from module.models import Bangumi, Passkey, User
|
||||
|
||||
from .bangumi import BangumiDatabase
|
||||
from .engine import engine as e
|
||||
from .engine import async_session_factory, engine as e
|
||||
from .passkey import PasskeyDatabase
|
||||
from .rss import RSSDatabase
|
||||
from .torrent import TorrentDatabase
|
||||
from .user import UserDatabase
|
||||
|
||||
|
||||
class Database(Session):
|
||||
def __init__(self, engine=e):
|
||||
self.engine = engine
|
||||
super().__init__(engine)
|
||||
self.rss = RSSDatabase(self)
|
||||
self.torrent = TorrentDatabase(self)
|
||||
self.bangumi = BangumiDatabase(self)
|
||||
self.user = UserDatabase(self)
|
||||
class Database:
|
||||
def __init__(self):
|
||||
self._session: AsyncSession | None = None
|
||||
self.rss: RSSDatabase | None = None
|
||||
self.torrent: TorrentDatabase | None = None
|
||||
self.bangumi: BangumiDatabase | None = None
|
||||
self.user: UserDatabase | None = None
|
||||
self.passkey: PasskeyDatabase | None = None
|
||||
|
||||
def create_table(self):
|
||||
SQLModel.metadata.create_all(self.engine)
|
||||
async def __aenter__(self):
|
||||
self._session = async_session_factory()
|
||||
self.rss = RSSDatabase(self._session)
|
||||
self.torrent = TorrentDatabase(self._session)
|
||||
self.bangumi = BangumiDatabase(self._session)
|
||||
self.user = UserDatabase(self._session)
|
||||
self.passkey = PasskeyDatabase(self._session)
|
||||
return self
|
||||
|
||||
def drop_table(self):
|
||||
SQLModel.metadata.drop_all(self.engine)
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
|
||||
def migrate(self):
|
||||
async def create_table(self):
|
||||
async with e.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
async def drop_table(self):
|
||||
async with e.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.drop_all)
|
||||
|
||||
async def commit(self):
|
||||
if self._session:
|
||||
await self._session.commit()
|
||||
|
||||
async def add(self, obj):
|
||||
if self._session:
|
||||
self._session.add(obj)
|
||||
await self._session.commit()
|
||||
|
||||
async def migrate(self):
|
||||
# Run migration online
|
||||
bangumi_data = self.bangumi.search_all()
|
||||
user_data = self.exec("SELECT * FROM user").all()
|
||||
bangumi_data = await self.bangumi.search_all()
|
||||
readd_bangumi = []
|
||||
for bangumi in bangumi_data:
|
||||
dict_data = bangumi.dict()
|
||||
del dict_data["id"]
|
||||
readd_bangumi.append(Bangumi(**dict_data))
|
||||
self.drop_table()
|
||||
self.create_table()
|
||||
self.commit()
|
||||
bangumi_data = self.bangumi.search_all()
|
||||
self.bangumi.add_all(readd_bangumi)
|
||||
self.add(User(**user_data[0]))
|
||||
self.commit()
|
||||
await self.drop_table()
|
||||
await self.create_table()
|
||||
await self.commit()
|
||||
await self.bangumi.add_all(readd_bangumi)
|
||||
|
||||
78
backend/src/module/database/passkey.py
Normal file
78
backend/src/module/database/passkey.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
Passkey 数据库操作层
|
||||
"""
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import select
|
||||
|
||||
from module.models.passkey import Passkey, PasskeyList
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PasskeyDatabase:
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def create_passkey(self, passkey: Passkey) -> Passkey:
|
||||
"""创建新的 Passkey 凭证"""
|
||||
self.session.add(passkey)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(passkey)
|
||||
logger.info(f"Created passkey '{passkey.name}' for user_id={passkey.user_id}")
|
||||
return passkey
|
||||
|
||||
async def get_passkey_by_credential_id(
|
||||
self, credential_id: str
|
||||
) -> Optional[Passkey]:
|
||||
"""通过 credential_id 查找 Passkey(用于认证)"""
|
||||
statement = select(Passkey).where(Passkey.credential_id == credential_id)
|
||||
result = await self.session.execute(statement)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_passkeys_by_user_id(self, user_id: int) -> List[Passkey]:
|
||||
"""获取用户的所有 Passkey"""
|
||||
statement = select(Passkey).where(Passkey.user_id == user_id)
|
||||
result = await self.session.execute(statement)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_passkey_by_id(self, passkey_id: int, user_id: int) -> Passkey:
|
||||
"""获取特定 Passkey(带权限检查)"""
|
||||
statement = select(Passkey).where(
|
||||
Passkey.id == passkey_id, Passkey.user_id == user_id
|
||||
)
|
||||
result = await self.session.execute(statement)
|
||||
passkey = result.scalar_one_or_none()
|
||||
if not passkey:
|
||||
raise HTTPException(status_code=404, detail="Passkey not found")
|
||||
return passkey
|
||||
|
||||
async def update_passkey_usage(self, passkey: Passkey, new_sign_count: int):
|
||||
"""更新 Passkey 使用记录(签名计数器 + 最后使用时间)"""
|
||||
passkey.sign_count = new_sign_count
|
||||
passkey.last_used_at = datetime.utcnow()
|
||||
self.session.add(passkey)
|
||||
await self.session.commit()
|
||||
|
||||
async def delete_passkey(self, passkey_id: int, user_id: int) -> bool:
|
||||
"""删除 Passkey"""
|
||||
passkey = await self.get_passkey_by_id(passkey_id, user_id)
|
||||
await self.session.delete(passkey)
|
||||
await self.session.commit()
|
||||
logger.info(f"Deleted passkey id={passkey_id} for user_id={user_id}")
|
||||
return True
|
||||
|
||||
def to_list_model(self, passkey: Passkey) -> PasskeyList:
|
||||
"""转换为安全的列表展示模型"""
|
||||
return PasskeyList(
|
||||
id=passkey.id,
|
||||
name=passkey.name,
|
||||
created_at=passkey.created_at,
|
||||
last_used_at=passkey.last_used_at,
|
||||
backup_eligible=passkey.backup_eligible,
|
||||
aaguid=passkey.aaguid,
|
||||
)
|
||||
@@ -1,5 +1,6 @@
|
||||
from .bangumi import Bangumi, BangumiUpdate, Episode, Notification
|
||||
from .config import Config
|
||||
from .passkey import Passkey, PasskeyCreate, PasskeyDelete, PasskeyList
|
||||
from .response import APIResponse, ResponseModel
|
||||
from .rss import RSSItem, RSSUpdate
|
||||
from .torrent import EpisodeFile, SubtitleFile, Torrent, TorrentUpdate
|
||||
|
||||
75
backend/src/module/models/passkey.py
Normal file
75
backend/src/module/models/passkey.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
WebAuthn Passkey 数据模型
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class Passkey(SQLModel, table=True):
|
||||
"""存储 WebAuthn 凭证的数据库模型"""
|
||||
|
||||
__tablename__ = "passkey"
|
||||
|
||||
id: int = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="user.id", index=True)
|
||||
|
||||
# 用户友好的名称 (e.g., "iPhone 15", "MacBook Pro")
|
||||
name: str = Field(min_length=1, max_length=64)
|
||||
|
||||
# WebAuthn 核心字段
|
||||
credential_id: str = Field(unique=True, index=True) # Base64URL encoded
|
||||
public_key: str # CBOR encoded public key, Base64 stored
|
||||
sign_count: int = Field(default=0) # 防止克隆攻击
|
||||
|
||||
# 可选的设备信息
|
||||
aaguid: Optional[str] = None # Authenticator AAGUID
|
||||
transports: Optional[str] = None # JSON array: ["usb", "nfc", "ble", "internal"]
|
||||
|
||||
# 审计字段
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
last_used_at: Optional[datetime] = None
|
||||
|
||||
# 备份状态 (是否为多设备凭证,如 iCloud Keychain)
|
||||
backup_eligible: bool = Field(default=False)
|
||||
backup_state: bool = Field(default=False)
|
||||
|
||||
|
||||
class PasskeyCreate(BaseModel):
|
||||
"""创建 Passkey 的请求模型"""
|
||||
|
||||
name: str = Field(min_length=1, max_length=64)
|
||||
# 注册完成后的 WebAuthn 响应
|
||||
attestation_response: dict
|
||||
|
||||
|
||||
class PasskeyList(BaseModel):
|
||||
"""返回给前端的 Passkey 列表(不含敏感数据)"""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
created_at: datetime
|
||||
last_used_at: Optional[datetime]
|
||||
backup_eligible: bool
|
||||
aaguid: Optional[str]
|
||||
|
||||
|
||||
class PasskeyDelete(BaseModel):
|
||||
"""删除 Passkey 请求"""
|
||||
|
||||
passkey_id: int
|
||||
|
||||
|
||||
class PasskeyAuthStart(BaseModel):
|
||||
"""Passkey 认证开始请求"""
|
||||
|
||||
username: str
|
||||
|
||||
|
||||
class PasskeyAuthFinish(BaseModel):
|
||||
"""Passkey 认证完成请求"""
|
||||
|
||||
username: str
|
||||
credential: dict
|
||||
115
backend/src/module/security/auth_strategy.py
Normal file
115
backend/src/module/security/auth_strategy.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
认证策略抽象层
|
||||
将密码认证和 Passkey 认证统一为策略模式
|
||||
"""
|
||||
import base64
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from module.database import Database
|
||||
from module.models import ResponseModel
|
||||
from module.models.user import User
|
||||
|
||||
|
||||
class AuthStrategy(ABC):
|
||||
"""认证策略基类"""
|
||||
|
||||
@abstractmethod
|
||||
async def authenticate(self, username: str, credential: dict) -> ResponseModel:
|
||||
"""
|
||||
执行认证
|
||||
|
||||
Args:
|
||||
username: 用户名
|
||||
credential: 认证凭证(密码或 WebAuthn 响应)
|
||||
|
||||
Returns:
|
||||
ResponseModel with status and user info
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class PasswordAuthStrategy(AuthStrategy):
|
||||
"""密码认证策略(保持向后兼容)"""
|
||||
|
||||
async def authenticate(self, username: str, credential: dict) -> ResponseModel:
|
||||
"""使用密码认证"""
|
||||
password = credential.get("password")
|
||||
if not password:
|
||||
return ResponseModel(
|
||||
status_code=401,
|
||||
status=False,
|
||||
msg_en="Password is required",
|
||||
msg_zh="密码不能为空",
|
||||
)
|
||||
|
||||
user = User(username=username, password=password)
|
||||
async with Database() as db:
|
||||
return await db.user.auth_user(user)
|
||||
|
||||
|
||||
class PasskeyAuthStrategy(AuthStrategy):
|
||||
"""Passkey 认证策略"""
|
||||
|
||||
def __init__(self, webauthn_service):
|
||||
self.webauthn_service = webauthn_service
|
||||
|
||||
async def authenticate(self, username: str, credential: dict) -> ResponseModel:
|
||||
"""使用 WebAuthn Passkey 认证"""
|
||||
async with Database() as db:
|
||||
# 1. 查找用户
|
||||
try:
|
||||
user = await db.user.get_user(username)
|
||||
except Exception:
|
||||
return ResponseModel(
|
||||
status_code=401,
|
||||
status=False,
|
||||
msg_en="User not found",
|
||||
msg_zh="用户不存在",
|
||||
)
|
||||
|
||||
# 2. 提取 credential_id 并查找对应的 passkey
|
||||
try:
|
||||
raw_id = credential.get("rawId")
|
||||
if not raw_id:
|
||||
raise ValueError("Missing credential ID")
|
||||
|
||||
# 将 rawId 从 base64url 转换为标准格式
|
||||
credential_id_str = self.webauthn_service.base64url_encode(
|
||||
self.webauthn_service.base64url_decode(raw_id)
|
||||
)
|
||||
|
||||
passkey = await db.passkey.get_passkey_by_credential_id(credential_id_str)
|
||||
if not passkey or passkey.user_id != user.id:
|
||||
raise ValueError("Passkey not found or not owned by user")
|
||||
|
||||
except Exception as e:
|
||||
return ResponseModel(
|
||||
status_code=401,
|
||||
status=False,
|
||||
msg_en="Invalid passkey credential",
|
||||
msg_zh="Passkey 凭证无效",
|
||||
)
|
||||
|
||||
# 3. 验证 WebAuthn 签名
|
||||
try:
|
||||
new_sign_count = self.webauthn_service.verify_authentication(
|
||||
username, credential, passkey
|
||||
)
|
||||
|
||||
# 4. 更新使用记录
|
||||
await db.passkey.update_passkey_usage(passkey, new_sign_count)
|
||||
|
||||
return ResponseModel(
|
||||
status_code=200,
|
||||
status=True,
|
||||
msg_en="Login successfully with passkey",
|
||||
msg_zh="通过 Passkey 登录成功",
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
return ResponseModel(
|
||||
status_code=401,
|
||||
status=False,
|
||||
msg_en=f"Passkey verification failed: {str(e)}",
|
||||
msg_zh=f"Passkey 验证失败: {str(e)}",
|
||||
)
|
||||
277
backend/src/module/security/webauthn.py
Normal file
277
backend/src/module/security/webauthn.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""
|
||||
WebAuthn 认证服务层
|
||||
封装 py_webauthn 库的复杂性,提供清晰的注册和认证接口
|
||||
"""
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from webauthn import (
|
||||
generate_authentication_options,
|
||||
generate_registration_options,
|
||||
options_to_json,
|
||||
verify_authentication_response,
|
||||
verify_registration_response,
|
||||
)
|
||||
from webauthn.helpers.cose import COSEAlgorithmIdentifier
|
||||
from webauthn.helpers.structs import (
|
||||
AuthenticatorSelectionCriteria,
|
||||
AuthenticatorTransport,
|
||||
PublicKeyCredentialDescriptor,
|
||||
PublicKeyCredentialType,
|
||||
ResidentKeyRequirement,
|
||||
UserVerificationRequirement,
|
||||
)
|
||||
|
||||
from module.models.passkey import Passkey
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebAuthnService:
|
||||
"""WebAuthn 核心业务逻辑"""
|
||||
|
||||
def __init__(self, rp_id: str, rp_name: str, origin: str):
|
||||
"""
|
||||
Args:
|
||||
rp_id: 依赖方 ID (e.g., "localhost" or "autobangumi.example.com")
|
||||
rp_name: 依赖方名称 (e.g., "AutoBangumi")
|
||||
origin: 前端 origin (e.g., "http://localhost:5173")
|
||||
"""
|
||||
self.rp_id = rp_id
|
||||
self.rp_name = rp_name
|
||||
self.origin = origin
|
||||
|
||||
# 存储临时的 challenge(生产环境应使用 Redis)
|
||||
self._challenges: dict[str, bytes] = {}
|
||||
|
||||
# ============ 注册流程 ============
|
||||
|
||||
def generate_registration_options(
|
||||
self, username: str, user_id: int, existing_passkeys: List[Passkey]
|
||||
) -> dict:
|
||||
"""
|
||||
生成 WebAuthn 注册选项
|
||||
|
||||
Args:
|
||||
username: 用户名
|
||||
user_id: 用户 ID(转为 bytes)
|
||||
existing_passkeys: 用户已有的 Passkey(用于排除)
|
||||
|
||||
Returns:
|
||||
JSON-serializable registration options
|
||||
"""
|
||||
# 将已有凭证转为排除列表
|
||||
exclude_credentials = [
|
||||
PublicKeyCredentialDescriptor(
|
||||
id=self.base64url_decode(pk.credential_id),
|
||||
type=PublicKeyCredentialType.PUBLIC_KEY,
|
||||
transports=self._parse_transports(pk.transports),
|
||||
)
|
||||
for pk in existing_passkeys
|
||||
]
|
||||
|
||||
options = generate_registration_options(
|
||||
rp_id=self.rp_id,
|
||||
rp_name=self.rp_name,
|
||||
user_id=str(user_id).encode("utf-8"),
|
||||
user_name=username,
|
||||
user_display_name=username,
|
||||
exclude_credentials=exclude_credentials if exclude_credentials else None,
|
||||
authenticator_selection=AuthenticatorSelectionCriteria(
|
||||
resident_key=ResidentKeyRequirement.PREFERRED,
|
||||
user_verification=UserVerificationRequirement.PREFERRED,
|
||||
),
|
||||
supported_pub_key_algs=[
|
||||
COSEAlgorithmIdentifier.ECDSA_SHA_256, # -7: ES256
|
||||
COSEAlgorithmIdentifier.RSASSA_PKCS1_v1_5_SHA_256, # -257: RS256
|
||||
],
|
||||
)
|
||||
|
||||
# 存储 challenge 用于后续验证
|
||||
challenge_key = f"reg_{username}"
|
||||
self._challenges[challenge_key] = options.challenge
|
||||
logger.debug(f"Generated registration challenge for {username}")
|
||||
|
||||
return json.loads(options_to_json(options))
|
||||
|
||||
def verify_registration(
|
||||
self, username: str, credential: dict, device_name: str
|
||||
) -> Passkey:
|
||||
"""
|
||||
验证注册响应并创建 Passkey 对象
|
||||
|
||||
Args:
|
||||
username: 用户名
|
||||
credential: 来自前端的 credential 响应
|
||||
device_name: 用户输入的设备名称
|
||||
|
||||
Returns:
|
||||
Passkey 对象(未保存到数据库)
|
||||
|
||||
Raises:
|
||||
ValueError: 验证失败
|
||||
"""
|
||||
challenge_key = f"reg_{username}"
|
||||
expected_challenge = self._challenges.get(challenge_key)
|
||||
if not expected_challenge:
|
||||
raise ValueError("Challenge not found or expired")
|
||||
|
||||
try:
|
||||
verification = verify_registration_response(
|
||||
credential=credential,
|
||||
expected_challenge=expected_challenge,
|
||||
expected_rp_id=self.rp_id,
|
||||
expected_origin=self.origin,
|
||||
)
|
||||
|
||||
# 构造 Passkey 对象
|
||||
passkey = Passkey(
|
||||
user_id=0, # 调用方设置
|
||||
name=device_name,
|
||||
credential_id=self.base64url_encode(verification.credential_id),
|
||||
public_key=base64.b64encode(verification.credential_public_key).decode(
|
||||
"utf-8"
|
||||
),
|
||||
sign_count=verification.sign_count,
|
||||
aaguid=verification.aaguid.hex() if verification.aaguid else None,
|
||||
backup_eligible=verification.credential_backup_eligible,
|
||||
backup_state=verification.credential_backed_up,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Successfully verified registration for {username}, device: {device_name}"
|
||||
)
|
||||
return passkey
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Registration verification failed: {e}")
|
||||
raise ValueError(f"Invalid registration response: {str(e)}")
|
||||
finally:
|
||||
# 清理使用过的 challenge(无论成功或失败都清理,防止重放攻击)
|
||||
self._challenges.pop(challenge_key, None)
|
||||
|
||||
# ============ 认证流程 ============
|
||||
|
||||
def generate_authentication_options(
|
||||
self, username: str, passkeys: List[Passkey]
|
||||
) -> dict:
|
||||
"""
|
||||
生成 WebAuthn 认证选项
|
||||
|
||||
Args:
|
||||
username: 用户名
|
||||
passkeys: 用户的 Passkey 列表(限定可用凭证)
|
||||
|
||||
Returns:
|
||||
JSON-serializable authentication options
|
||||
"""
|
||||
allow_credentials = [
|
||||
PublicKeyCredentialDescriptor(
|
||||
id=self.base64url_decode(pk.credential_id),
|
||||
type=PublicKeyCredentialType.PUBLIC_KEY,
|
||||
transports=self._parse_transports(pk.transports),
|
||||
)
|
||||
for pk in passkeys
|
||||
]
|
||||
|
||||
options = generate_authentication_options(
|
||||
rp_id=self.rp_id,
|
||||
allow_credentials=allow_credentials if allow_credentials else None,
|
||||
user_verification=UserVerificationRequirement.PREFERRED,
|
||||
)
|
||||
|
||||
# 存储 challenge
|
||||
challenge_key = f"auth_{username}"
|
||||
self._challenges[challenge_key] = options.challenge
|
||||
logger.debug(f"Generated authentication challenge for {username}")
|
||||
|
||||
return json.loads(options_to_json(options))
|
||||
|
||||
def verify_authentication(
|
||||
self, username: str, credential: dict, passkey: Passkey
|
||||
) -> int:
|
||||
"""
|
||||
验证认证响应
|
||||
|
||||
Args:
|
||||
username: 用户名
|
||||
credential: 来自前端的 credential 响应
|
||||
passkey: 对应的 Passkey 对象
|
||||
|
||||
Returns:
|
||||
新的 sign_count(用于更新数据库)
|
||||
|
||||
Raises:
|
||||
ValueError: 验证失败
|
||||
"""
|
||||
challenge_key = f"auth_{username}"
|
||||
expected_challenge = self._challenges.get(challenge_key)
|
||||
if not expected_challenge:
|
||||
raise ValueError("Challenge not found or expired")
|
||||
|
||||
try:
|
||||
# 解码 public key
|
||||
credential_public_key = base64.b64decode(passkey.public_key)
|
||||
credential_id = self.base64url_decode(passkey.credential_id)
|
||||
|
||||
verification = verify_authentication_response(
|
||||
credential=credential,
|
||||
expected_challenge=expected_challenge,
|
||||
expected_rp_id=self.rp_id,
|
||||
expected_origin=self.origin,
|
||||
credential_public_key=credential_public_key,
|
||||
credential_current_sign_count=passkey.sign_count,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
|
||||
logger.info(f"Successfully verified authentication for {username}")
|
||||
return verification.new_sign_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Authentication verification failed: {e}")
|
||||
raise ValueError(f"Invalid authentication response: {str(e)}")
|
||||
finally:
|
||||
# 清理 challenge(无论成功或失败都清理,防止重放攻击)
|
||||
self._challenges.pop(challenge_key, None)
|
||||
|
||||
# ============ 辅助方法 ============
|
||||
|
||||
def _parse_transports(
|
||||
self, transports_json: Optional[str]
|
||||
) -> List[AuthenticatorTransport]:
|
||||
"""解析存储的 transports JSON"""
|
||||
if not transports_json:
|
||||
return []
|
||||
try:
|
||||
transport_strings = json.loads(transports_json)
|
||||
return [AuthenticatorTransport(t) for t in transport_strings]
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def base64url_encode(self, data: bytes) -> str:
|
||||
"""Base64URL 编码(无 padding)"""
|
||||
return base64.urlsafe_b64encode(data).decode("utf-8").rstrip("=")
|
||||
|
||||
def base64url_decode(self, data: str) -> bytes:
|
||||
"""Base64URL 解码(补齐 padding)"""
|
||||
padding = 4 - len(data) % 4
|
||||
if padding != 4:
|
||||
data += "=" * padding
|
||||
return base64.urlsafe_b64decode(data)
|
||||
|
||||
|
||||
# 全局 WebAuthn 服务实例存储
|
||||
_webauthn_services: dict[str, WebAuthnService] = {}
|
||||
|
||||
|
||||
def get_webauthn_service(rp_id: str, rp_name: str, origin: str) -> WebAuthnService:
|
||||
"""
|
||||
获取或创建 WebAuthnService 实例
|
||||
使用缓存以保持 challenge 状态
|
||||
"""
|
||||
key = f"{rp_id}:{origin}"
|
||||
if key not in _webauthn_services:
|
||||
_webauthn_services[key] = WebAuthnService(rp_id, rp_name, origin)
|
||||
return _webauthn_services[key]
|
||||
Reference in New Issue
Block a user