diff --git a/backend/src/main.py b/backend/src/main.py index cef8d10b..e68c1378 100644 --- a/backend/src/main.py +++ b/backend/src/main.py @@ -1,5 +1,6 @@ import logging import os +from contextlib import asynccontextmanager import uvicorn from fastapi import FastAPI, Request @@ -7,6 +8,7 @@ from fastapi.responses import FileResponse, HTMLResponse, RedirectResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates from module.api import v1 +from module.api.program import program from module.conf import VERSION, settings, setup_logger setup_logger(reset=True) @@ -26,8 +28,19 @@ uvicorn_logging_config = { } +@asynccontextmanager +async def lifespan(app: FastAPI): + import asyncio + + # Startup + asyncio.create_task(program.startup()) + yield + # Shutdown + await program.stop() + + def create_app() -> FastAPI: - app = FastAPI() + app = FastAPI(lifespan=lifespan) # mount routers app.include_router(v1, prefix="/api") @@ -61,6 +74,7 @@ if VERSION != "DEV_VERSION": context = {"request": request} return templates.TemplateResponse("index.html", context) else: + @app.get("/", status_code=302, tags=["html"]) def index(): return RedirectResponse("/docs") diff --git a/backend/src/module/api/program.py b/backend/src/module/api/program.py index 3cb43291..e3cb07b9 100644 --- a/backend/src/module/api/program.py +++ b/backend/src/module/api/program.py @@ -1,4 +1,3 @@ -import asyncio import logging import os import signal @@ -18,14 +17,7 @@ program = Program() router = APIRouter(tags=["program"]) -@router.on_event("startup") -async def startup(): - asyncio.create_task(program.startup()) - - -@router.on_event("shutdown") -async def shutdown(): - await program.stop() +# Note: Lifespan events (startup/shutdown) are now handled in main.py via lifespan context manager @router.get( diff --git a/backend/src/module/conf/config.py b/backend/src/module/conf/config.py index bc78fa27..1bb2f19e 100644 --- a/backend/src/module/conf/config.py +++ b/backend/src/module/conf/config.py @@ -39,7 +39,7 @@ class Settings(Config): with open(CONFIG_PATH, "r", encoding="utf-8") as f: config = json.load(f) config = self._migrate_old_config(config) - config_obj = Config.parse_obj(config) + config_obj = Config.model_validate(config) self.__dict__.update(config_obj.__dict__) logger.info("Config loaded") @@ -69,7 +69,7 @@ class Settings(Config): def save(self, config_dict: dict | None = None): if not config_dict: - config_dict = self.dict() + config_dict = self.model_dump() with open(CONFIG_PATH, "w", encoding="utf-8") as f: json.dump(config_dict, f, indent=4, ensure_ascii=False) @@ -79,7 +79,7 @@ class Settings(Config): self.save() def __load_from_env(self): - config_dict = self.dict() + config_dict = self.model_dump() for key, section in ENV_TO_ATTR.items(): for env, attr in section.items(): if env in os.environ: @@ -92,7 +92,7 @@ class Settings(Config): else: attr_name = attr[0] if isinstance(attr, tuple) else attr config_dict[key][attr_name] = self.__val_from_env(env, attr) - config_obj = Config.parse_obj(config_dict) + config_obj = Config.model_validate(config_dict) self.__dict__.update(config_obj.__dict__) logger.info("Config loaded from env") diff --git a/backend/src/module/database/bangumi.py b/backend/src/module/database/bangumi.py index 1f505b50..cf7c9999 100644 --- a/backend/src/module/database/bangumi.py +++ b/backend/src/module/database/bangumi.py @@ -283,7 +283,7 @@ class BangumiDatabase: return False if not db_data: return False - bangumi_data = data.dict(exclude_unset=True) + bangumi_data = data.model_dump(exclude_unset=True) for key, value in bangumi_data.items(): setattr(db_data, key, value) self.session.add(db_data) diff --git a/backend/src/module/models/config.py b/backend/src/module/models/config.py index 1ebaa430..7a00426b 100644 --- a/backend/src/module/models/config.py +++ b/backend/src/module/models/config.py @@ -1,7 +1,7 @@ from os.path import expandvars from typing import Literal -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field, field_validator class Program(BaseModel): @@ -102,8 +102,9 @@ class ExperimentalOpenAI(BaseModel): "", description="Azure OpenAI deployment id, ignored when api type is openai" ) - @validator("api_base") - def validate_api_base(cls, value: str): + @field_validator("api_base") + @classmethod + def validate_api_base(cls, value: str) -> str: if value == "https://api.openai.com/": return "https://api.openai.com/v1" return value @@ -119,5 +120,9 @@ class Config(BaseModel): notification: Notification = Notification() experimental_openai: ExperimentalOpenAI = ExperimentalOpenAI() + def model_dump(self, *args, by_alias=True, **kwargs): + return super().model_dump(*args, by_alias=by_alias, **kwargs) + + # Keep dict() for backward compatibility def dict(self, *args, by_alias=True, **kwargs): - return super().dict(*args, by_alias=by_alias, **kwargs) + return self.model_dump(*args, by_alias=by_alias, **kwargs) diff --git a/backend/src/module/models/response.py b/backend/src/module/models/response.py index a3c9a5fe..5192f322 100644 --- a/backend/src/module/models/response.py +++ b/backend/src/module/models/response.py @@ -2,14 +2,14 @@ from pydantic import BaseModel, Field class ResponseModel(BaseModel): - status: bool = Field(..., example=True) - status_code: int = Field(..., example=200) + status: bool = Field(..., json_schema_extra={"example": True}) + status_code: int = Field(..., json_schema_extra={"example": 200}) msg_en: str msg_zh: str data: dict | None = None class APIResponse(BaseModel): - status: bool = Field(..., example=True) - msg_en: str = Field(..., example="Success") - msg_zh: str = Field(..., example="成功") \ No newline at end of file + status: bool = Field(..., json_schema_extra={"example": True}) + msg_en: str = Field(..., json_schema_extra={"example": "Success"}) + msg_zh: str = Field(..., json_schema_extra={"example": "成功"}) diff --git a/backend/src/module/security/jwt.py b/backend/src/module/security/jwt.py index 35c832a9..b0914619 100644 --- a/backend/src/module/security/jwt.py +++ b/backend/src/module/security/jwt.py @@ -1,4 +1,4 @@ -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from jose import JWTError, jwt from passlib.context import CryptContext @@ -21,9 +21,9 @@ app_pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") def create_access_token(data: dict, expires_delta: timedelta | None = None): to_encode = data.copy() if expires_delta: - expire = datetime.utcnow() + expires_delta + expire = datetime.now(timezone.utc) + expires_delta else: - expire = datetime.utcnow() + timedelta(minutes=1440) + expire = datetime.now(timezone.utc) + timedelta(minutes=1440) to_encode.update({"exp": expire}) encoded_jwt = jwt.encode(to_encode, app_pwd_key, algorithm=app_pwd_algorithm) return encoded_jwt @@ -46,7 +46,7 @@ def verify_token(token: str): if token_data is None: return None expires = token_data.get("exp") - if datetime.utcnow() >= datetime.fromtimestamp(expires): + if datetime.now(timezone.utc) >= datetime.fromtimestamp(expires, tz=timezone.utc): raise JWTError("Token expired") return token_data diff --git a/backend/src/test/factories.py b/backend/src/test/factories.py index b2e855b4..b8d2bc43 100644 --- a/backend/src/test/factories.py +++ b/backend/src/test/factories.py @@ -1,6 +1,6 @@ """Test data factories for creating model instances with sensible defaults.""" -from datetime import datetime +from datetime import datetime, timezone from module.models import Bangumi, RSSItem, Torrent from module.models.config import Config @@ -78,7 +78,7 @@ def make_passkey(**overrides) -> Passkey: sign_count=0, aaguid="00000000-0000-0000-0000-000000000000", transports='["internal"]', - created_at=datetime.utcnow(), + created_at=datetime.now(timezone.utc), last_used_at=None, backup_eligible=False, backup_state=False, diff --git a/backend/src/test/test_config.py b/backend/src/test/test_config.py index f8aabe08..057dc3cd 100644 --- a/backend/src/test/test_config.py +++ b/backend/src/test/test_config.py @@ -92,7 +92,7 @@ class TestConfigSerialization: with open(json_path, "r") as f: loaded = json.load(f) - loaded_config = Config.parse_obj(loaded) + loaded_config = Config.model_validate(loaded) assert loaded_config.program.rss_time == config.program.rss_time assert loaded_config.downloader.type == config.downloader.type