diff --git a/app/chain/mediaserver.py b/app/chain/mediaserver.py index e6ae098e..6f239817 100644 --- a/app/chain/mediaserver.py +++ b/app/chain/mediaserver.py @@ -1,4 +1,5 @@ import threading +from datetime import datetime from typing import List, Union, Optional, Generator, Any from app.chain import ChainBase @@ -134,9 +135,10 @@ class MediaServerChain(ChainBase): with lock: # 汇总统计 total_count = 0 - # 清空登记薄 dboper = MediaServerOper() - dboper.empty() + enabled_servers = [mediaserver.name for mediaserver in mediaservers + if mediaserver and mediaserver.enabled and mediaserver.name] + dboper.delete_excluded_servers(enabled_servers) # 遍历媒体服务器 for mediaserver in mediaservers: if not mediaserver: @@ -152,6 +154,7 @@ class MediaServerChain(ChainBase): if not libraries: logger.info(f"没有获取到媒体服务器 {server_name} 的媒体库,跳过") continue + sync_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f") for library in libraries: if sync_libraries \ and "all" not in sync_libraries \ @@ -180,8 +183,11 @@ class MediaServerChain(ChainBase): item_dict = item.model_dump() item_dict["seasoninfo"] = seasoninfo item_dict["item_type"] = item_type - dboper.add(**item_dict) + item_dict["lst_mod_date"] = sync_time + dboper.upsert(**item_dict) logger.info(f"{server_name} 媒体库 {library.name} 同步完成,共同步数量:{library_count}") # 总数累加 total_count += library_count + stale_count = dboper.delete_stale(server=server_name, sync_time=sync_time) + logger.info(f"媒体服务器 {server_name} 清理陈旧数据完成,删除数量:{stale_count}") logger.info(f"媒体服务器 {server_name} 数据同步完成,总同步数量:{total_count}") diff --git a/app/db/mediaserver_oper.py b/app/db/mediaserver_oper.py index 57ef2ebf..1c5d818c 100644 --- a/app/db/mediaserver_oper.py +++ b/app/db/mediaserver_oper.py @@ -14,24 +14,67 @@ class MediaServerOper(DbOper): def __init__(self, db: Session = None): super().__init__(db) + @staticmethod + def __prepare_payload(kwargs: dict) -> dict: + """ + 过滤数据库模型不存在或不应由远端覆盖的字段 + """ + return { + k: v for k, v in kwargs.items() + if hasattr(MediaServerItem, k) and k != "id" + } + def add(self, **kwargs) -> bool: """ 新增媒体服务器数据 """ - # MediaServerItem中没有的属性剔除 - kwargs = {k: v for k, v in kwargs.items() if hasattr(MediaServerItem, k)} + kwargs = self.__prepare_payload(kwargs) + server = kwargs.get("server") + item_id = kwargs.get("item_id") + if not server or not item_id: + return False item = MediaServerItem(**kwargs) - if not item.get_by_itemid(self._db, kwargs.get("item_id")): + if not item.get_by_server_itemid(self._db, server, item_id): item.create(self._db) return True return False + def upsert(self, **kwargs) -> bool: + """ + 按媒体服务器和条目ID新增或更新数据 + """ + kwargs = self.__prepare_payload(kwargs) + server = kwargs.get("server") + item_id = kwargs.get("item_id") + if not server or not item_id: + return False + + item = MediaServerItem.get_by_server_itemid(self._db, server, item_id) + if item: + item.update(self._db, kwargs) + return False + + MediaServerItem(**kwargs).create(self._db) + return True + def empty(self, server: Optional[str] = None): """ 清空媒体服务器数据 """ MediaServerItem.empty(self._db, server) + def delete_stale(self, server: str, sync_time: str) -> int: + """ + 删除本轮同步未更新的旧数据 + """ + return MediaServerItem.delete_stale(self._db, server, sync_time) + + def delete_excluded_servers(self, servers: list[str]) -> int: + """ + 删除未启用或已移除媒体服务器的数据 + """ + return MediaServerItem.delete_excluded_servers(self._db, servers) + def exists(self, **kwargs) -> Optional[MediaServerItem]: """ 判断媒体服务器数据是否存在 diff --git a/app/db/models/mediaserver.py b/app/db/models/mediaserver.py index 69c4db84..6c4d826d 100644 --- a/app/db/models/mediaserver.py +++ b/app/db/models/mediaserver.py @@ -1,7 +1,7 @@ from datetime import datetime -from typing import Optional +from typing import Optional, List -from sqlalchemy import Column, Integer, String, JSON, Index +from sqlalchemy import Column, Integer, String, JSON, Index, or_ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session @@ -44,6 +44,7 @@ class MediaServerItem(Base): lst_mod_date = Column(String, default=datetime.now().strftime("%Y-%m-%d %H:%M:%S")) __table_args__ = ( + Index('ux_mediaserveritem_server_item_id', 'server', 'item_id', unique=True), Index('ix_mediaserveritem_tmdbid_item_type', 'tmdbid', 'item_type'), ) @@ -52,13 +53,34 @@ class MediaServerItem(Base): def get_by_itemid(cls, db: Session, item_id: str): return db.query(cls).filter(cls.item_id == item_id).first() + @classmethod + @db_query + def get_by_server_itemid(cls, db: Session, server: str, item_id: str): + return db.query(cls).filter(cls.server == server, + cls.item_id == item_id).first() + @classmethod @db_update def empty(cls, db: Session, server: Optional[str] = None): if server is None: - db.query(cls).delete() + db.query(cls).delete(synchronize_session=False) else: - db.query(cls).filter(cls.server == server).delete() + db.query(cls).filter(cls.server == server).delete(synchronize_session=False) + + @classmethod + @db_update + def delete_stale(cls, db: Session, server: str, sync_time: str): + return db.query(cls).filter(cls.server == server, + or_(cls.lst_mod_date.is_(None), + cls.lst_mod_date != sync_time)).delete(synchronize_session=False) + + @classmethod + @db_update + def delete_excluded_servers(cls, db: Session, servers: List[str]): + if not servers: + return db.query(cls).delete(synchronize_session=False) + return db.query(cls).filter(or_(cls.server.is_(None), + ~cls.server.in_(servers))).delete(synchronize_session=False) @classmethod @db_query diff --git a/database/versions/b8f6e3a1c2d4_2_2_5.py b/database/versions/b8f6e3a1c2d4_2_2_5.py new file mode 100644 index 00000000..b5115c55 --- /dev/null +++ b/database/versions/b8f6e3a1c2d4_2_2_5.py @@ -0,0 +1,87 @@ +"""2.2.5 +mediaserveritem 改为按 server + item_id 唯一 + +Revision ID: b8f6e3a1c2d4 +Revises: 93f8cb6a4d1e +Create Date: 2026-05-09 +""" + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "b8f6e3a1c2d4" +down_revision = "93f8cb6a4d1e" +branch_labels = None +depends_on = None + +TABLE_NAME = "mediaserveritem" +INDEX_NAME = "ux_mediaserveritem_server_item_id" +INDEX_COLUMNS = ["server", "item_id"] + +mediaserveritem = sa.table( + TABLE_NAME, + sa.column("id", sa.Integer), + sa.column("server", sa.String), + sa.column("item_id", sa.String), +) + + +def _table_exists(inspector: sa.Inspector) -> bool: + return TABLE_NAME in inspector.get_table_names() + + +def _has_index_signature(inspector: sa.Inspector, unique: bool) -> bool: + target_columns = tuple(INDEX_COLUMNS) + for index in inspector.get_indexes(TABLE_NAME): + if tuple(index.get("column_names") or []) == target_columns and bool(index.get("unique")) == unique: + return True + return False + + +def _drop_index_if_exists(inspector: sa.Inspector) -> None: + for index in inspector.get_indexes(TABLE_NAME): + if index.get("name") == INDEX_NAME: + op.drop_index(INDEX_NAME, table_name=TABLE_NAME) + return + + +def _deduplicate_rows() -> None: + bind = op.get_bind() + keep_ids = ( + sa.select(sa.func.max(mediaserveritem.c.id)) + .where( + mediaserveritem.c.server.is_not(None), + mediaserveritem.c.item_id.is_not(None), + ) + .group_by(mediaserveritem.c.server, mediaserveritem.c.item_id) + ) + bind.execute( + mediaserveritem.delete().where( + sa.and_( + mediaserveritem.c.server.is_not(None), + mediaserveritem.c.item_id.is_not(None), + mediaserveritem.c.id.not_in(keep_ids), + ) + ) + ) + + +def upgrade() -> None: + inspector = sa.inspect(op.get_bind()) + if not _table_exists(inspector): + return + + _deduplicate_rows() + + inspector = sa.inspect(op.get_bind()) + if not _has_index_signature(inspector, unique=True): + op.create_index(INDEX_NAME, TABLE_NAME, INDEX_COLUMNS, unique=True) + + +def downgrade() -> None: + inspector = sa.inspect(op.get_bind()) + if not _table_exists(inspector): + return + + _drop_index_if_exists(inspector) diff --git a/tests/test_mediaserver_sync_incremental.py b/tests/test_mediaserver_sync_incremental.py new file mode 100644 index 00000000..5cece8d0 --- /dev/null +++ b/tests/test_mediaserver_sync_incremental.py @@ -0,0 +1,243 @@ +import importlib.util +import sqlite3 +import sys +import tempfile +import types +import unittest +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import patch + +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +if "psutil" not in sys.modules: + sys.modules["psutil"] = types.ModuleType("psutil") + +if "aiosqlite" not in sys.modules: + aiosqlite_module = types.ModuleType("aiosqlite") + for attr in ( + "DatabaseError", + "Error", + "IntegrityError", + "InterfaceError", + "InternalError", + "NotSupportedError", + "OperationalError", + "ProgrammingError", + "sqlite_version", + "sqlite_version_info", + ): + setattr(aiosqlite_module, attr, getattr(sqlite3, attr)) + aiosqlite_module.connect = sqlite3.connect + aiosqlite_module.paramstyle = "qmark" + aiosqlite_module.threadsafety = sqlite3.threadsafety + sys.modules["aiosqlite"] = aiosqlite_module + +if "app.log" not in sys.modules: + log_module = types.ModuleType("app.log") + + class _Logger: + def info(self, *_args, **_kwargs): + return None + + def debug(self, *_args, **_kwargs): + return None + + def warning(self, *_args, **_kwargs): + return None + + def error(self, *_args, **_kwargs): + return None + + log_module.logger = _Logger() + log_module.log_settings = SimpleNamespace() + log_module.LogConfigModel = type("LogConfigModel", (), {}) + sys.modules["app.log"] = log_module + +from app import schemas +from app.db import Base +from app.db.mediaserver_oper import MediaServerOper +from app.db.models.mediaserver import MediaServerItem + + +def _load_mediaserver_chain_class(): + """隔离加载 MediaServerChain,避免测试依赖完整运行时环境。""" + module_name = "_test_mediaserver_chain" + if module_name in sys.modules: + module = sys.modules[module_name] + return module, module.MediaServerChain + + if "app.chain" not in sys.modules: + chain_module = types.ModuleType("app.chain") + chain_module.ChainBase = type("ChainBase", (), {}) + sys.modules["app.chain"] = chain_module + + if "app.core.config" not in sys.modules: + config_module = types.ModuleType("app.core.config") + config_module.global_vars = SimpleNamespace(is_system_stopped=False) + sys.modules["app.core.config"] = config_module + + if "app.helper.service" not in sys.modules: + service_module = types.ModuleType("app.helper.service") + + class _ServiceConfigHelper: + @staticmethod + def get_mediaserver_configs(): + return [] + + service_module.ServiceConfigHelper = _ServiceConfigHelper + sys.modules["app.helper.service"] = service_module + + mediaserver_path = Path(__file__).resolve().parents[1] / "app" / "chain" / "mediaserver.py" + spec = importlib.util.spec_from_file_location(module_name, mediaserver_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec and spec.loader + spec.loader.exec_module(module) + return module, module.MediaServerChain + + +MEDIA_SERVER_CHAIN_MODULE, MediaServerChain = _load_mediaserver_chain_class() + + +class MediaServerIncrementalSyncTest(unittest.TestCase): + """验证媒体库同步改为按条目增量更新。""" + + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + db_path = Path(self.temp_dir.name) / "mediaserver.db" + self.engine = create_engine(f"sqlite:///{db_path}") + self.SessionFactory = sessionmaker(bind=self.engine) + Base.metadata.create_all(bind=self.engine) + + def tearDown(self): + self.engine.dispose() + self.temp_dir.cleanup() + + def test_add_allows_same_item_id_across_servers(self): + """不同媒体服务器允许复用相同 item_id。""" + with self.SessionFactory() as db: + oper = MediaServerOper(db) + + self.assertTrue( + oper.add( + server="plex", + library="movies", + item_id="same-item-id", + item_type="电影", + title="Movie A", + ) + ) + self.assertTrue( + oper.add( + server="jellyfin", + library="movies", + item_id="same-item-id", + item_type="电影", + title="Movie B", + ) + ) + + items = ( + db.query(MediaServerItem) + .order_by(MediaServerItem.server.asc()) + .all() + ) + + self.assertEqual(len(items), 2) + self.assertEqual([item.server for item in items], ["jellyfin", "plex"]) + + def test_sync_updates_rows_and_removes_stale_entries(self): + """同步应更新已存在条目,并清理未再出现或已移除服务的数据。""" + old_sync_time = "2026-05-01 00:00:00" + + with self.SessionFactory() as db: + db.add_all( + [ + MediaServerItem( + server="plex", + library="movies", + item_id="/library/metadata/1", + item_type="电影", + title="Old Title", + year="2024", + path="/media/old.mkv", + lst_mod_date=old_sync_time, + ), + MediaServerItem( + server="plex", + library="movies", + item_id="/library/metadata/2", + item_type="电影", + title="Stale Title", + year="2020", + path="/media/stale.mkv", + lst_mod_date=old_sync_time, + ), + MediaServerItem( + server="jellyfin", + library="movies", + item_id="/library/metadata/1", + item_type="电影", + title="Removed Server Title", + year="2024", + path="/media/removed.mkv", + lst_mod_date=old_sync_time, + ), + ] + ) + db.commit() + existing_id = ( + db.query(MediaServerItem.id) + .filter( + MediaServerItem.server == "plex", + MediaServerItem.item_id == "/library/metadata/1", + ) + .scalar() + ) + + chain = object.__new__(MediaServerChain) + chain.librarys = lambda _server: [SimpleNamespace(id="movies", name="电影库")] + chain.items = lambda **_kwargs: iter( + [ + schemas.MediaServerItem( + server="plex", + library="movies", + item_id="/library/metadata/1", + item_type="Movie", + title="New Title", + year="2024", + tmdbid=100, + path="/media/new.mkv", + ) + ] + ) + chain.episodes = lambda *_args, **_kwargs: [] + + with patch("app.db.ScopedSession", self.SessionFactory), patch.object( + MEDIA_SERVER_CHAIN_MODULE.ServiceConfigHelper, + "get_mediaserver_configs", + return_value=[SimpleNamespace(name="plex", enabled=True, sync_libraries=["all"])], + ): + chain.sync() + + with self.SessionFactory() as db: + items = ( + db.query(MediaServerItem) + .order_by(MediaServerItem.server.asc(), MediaServerItem.item_id.asc()) + .all() + ) + + self.assertEqual(len(items), 1) + self.assertEqual(items[0].id, existing_id) + self.assertEqual(items[0].server, "plex") + self.assertEqual(items[0].item_id, "/library/metadata/1") + self.assertEqual(items[0].item_type, "电影") + self.assertEqual(items[0].title, "New Title") + self.assertEqual(items[0].path, "/media/new.mkv") + self.assertNotEqual(items[0].lst_mod_date, old_sync_time) + + +if __name__ == "__main__": + unittest.main()