From b1018d41feab5b2db1219c9eb44b803d9ef4fbb1 Mon Sep 17 00:00:00 2001 From: Estrella Pan Date: Sat, 24 Jan 2026 06:10:20 +0100 Subject: [PATCH] fix(test): use sync sessions in database tests to match production code The database classes (BangumiDatabase, TorrentDatabase, RSSDatabase) use synchronous Session, but tests were incorrectly using AsyncSession with await calls, causing AttributeError on coroutine objects. Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering) Co-Authored-By: Claude Co-Authored-By: Happy --- backend/src/test/test_database.py | 59 +++++++++++++------------------ 1 file changed, 24 insertions(+), 35 deletions(-) diff --git a/backend/src/test/test_database.py b/backend/src/test/test_database.py index 4cd48b13..83edf119 100644 --- a/backend/src/test/test_database.py +++ b/backend/src/test/test_database.py @@ -1,33 +1,24 @@ import pytest -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine -from sqlalchemy.orm import sessionmaker -from sqlmodel import SQLModel +from sqlmodel import Session, SQLModel, create_engine from module.database.bangumi import BangumiDatabase from module.database.rss import RSSDatabase from module.database.torrent import TorrentDatabase from module.models import Bangumi, RSSItem, Torrent -# sqlite async mock engine -engine = create_async_engine( - "sqlite+aiosqlite://", - echo=False, -) -async_session_factory = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) +# sqlite sync engine for testing +engine = create_engine("sqlite://", echo=False) @pytest.fixture -async def db_session(): - async with engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.create_all) - async with async_session_factory() as session: +def db_session(): + SQLModel.metadata.create_all(engine) + with Session(engine) as session: yield session - async with engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.drop_all) + SQLModel.metadata.drop_all(engine) -@pytest.mark.asyncio -async def test_bangumi_database(db_session): +def test_bangumi_database(db_session): test_data = Bangumi( official_title="无职转生,到了异世界就拿出真本事", year="2021", @@ -51,34 +42,33 @@ async def test_bangumi_database(db_session): db = BangumiDatabase(db_session) # insert - await db.add(test_data) - result = await db.search_id(1) + db.add(test_data) + result = db.search_id(1) assert result.official_title == test_data.official_title # update test_data.official_title = "无职转生,到了异世界就拿出真本事II" - await db.update(test_data) - result = await db.search_id(1) + db.update(test_data) + result = db.search_id(1) assert result.official_title == test_data.official_title # search poster - poster = await db.match_poster("无职转生,到了异世界就拿出真本事II (2021)") + poster = db.match_poster("无职转生,到了异世界就拿出真本事II (2021)") assert poster == "/test/test.jpg" # match torrent - result = await db.match_torrent( + result = db.match_torrent( "[Lilith-Raws] 无职转生,到了异世界就拿出真本事 / Mushoku Tensei - 11 [Baha][WEB-DL][1080p][AVC AAC][CHT][MP4]" ) assert result.official_title == "无职转生,到了异世界就拿出真本事II" # delete - await db.delete_one(1) - result = await db.search_id(1) + db.delete_one(1) + result = db.search_id(1) assert result is None -@pytest.mark.asyncio -async def test_torrent_database(db_session): +def test_torrent_database(db_session): test_data = Torrent( name="[Sub Group]test S02 01 [720p].mkv", url="https://test.com/test.mkv", @@ -86,22 +76,21 @@ async def test_torrent_database(db_session): db = TorrentDatabase(db_session) # insert - await db.add(test_data) - result = await db.search(1) + db.add(test_data) + result = db.search(1) assert result.name == test_data.name # update test_data.downloaded = True - await db.update(test_data) - result = await db.search(1) + db.update(test_data) + result = db.search(1) assert result.downloaded == True -@pytest.mark.asyncio -async def test_rss_database(db_session): +def test_rss_database(db_session): rss_url = "https://test.com/test.xml" db = RSSDatabase(db_session) - await db.add(RSSItem(url=rss_url, name="Test RSS")) - result = await db.search_id(1) + db.add(RSSItem(url=rss_url, name="Test RSS")) + result = db.search_id(1) assert result.url == rss_url