mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-04-13 17:52:28 +08:00
fix actions
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from app.actions import BaseAction
|
||||
from app.schemas import ActionParams, ActionContext
|
||||
from log import logger
|
||||
|
||||
|
||||
class FetchDownloadsParams(ActionParams):
|
||||
@@ -36,6 +37,7 @@ class FetchDownloadsAction(BaseAction):
|
||||
"""
|
||||
self._downloads = context.downloads
|
||||
for download in self._downloads:
|
||||
logger.info(f"获取下载任务 {download.download_id} 状态 ...")
|
||||
torrents = self.chain.list_torrents(hashs=[download.download_id])
|
||||
if not torrents:
|
||||
download.completed = True
|
||||
@@ -43,6 +45,7 @@ class FetchDownloadsAction(BaseAction):
|
||||
for t in torrents:
|
||||
download.path = t.path
|
||||
if t.progress >= 100:
|
||||
logger.info(f"下载任务 {download.download_id} 已完成")
|
||||
download.completed = True
|
||||
|
||||
self.job_done()
|
||||
|
||||
@@ -1,12 +1,22 @@
|
||||
from typing import List
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.actions import BaseAction
|
||||
from app.schemas import ActionParams, ActionContext
|
||||
from core.config import settings
|
||||
from core.event import eventmanager
|
||||
from log import logger
|
||||
from schemas import RecommendSourceEventData, MediaInfo
|
||||
from schemas.types import ChainEventType
|
||||
from utils.http import RequestUtils
|
||||
|
||||
|
||||
class FetchMediasParams(ActionParams):
|
||||
"""
|
||||
获取媒体数据参数
|
||||
"""
|
||||
pass
|
||||
sources: List[str] = Field([], description="媒体数据来源")
|
||||
|
||||
|
||||
class FetchMediasAction(BaseAction):
|
||||
@@ -14,6 +24,63 @@ class FetchMediasAction(BaseAction):
|
||||
获取媒体数据
|
||||
"""
|
||||
|
||||
__inner_sources = [
|
||||
{
|
||||
"api_path": 'recommend/tmdb_trending',
|
||||
"name": '流行趋势',
|
||||
},
|
||||
{
|
||||
"api_path": 'recommend/douban_showing',
|
||||
"name": '正在热映',
|
||||
},
|
||||
{
|
||||
"api_path": 'bangumi/calendar',
|
||||
"name": 'Bangumi每日放送',
|
||||
},
|
||||
{
|
||||
"api_path": 'recommend/tmdb_movies',
|
||||
"name": 'TMDB热门电影',
|
||||
},
|
||||
{
|
||||
"api_path": 'recommend/tmdb_tvs?with_original_language=zh|en|ja|ko',
|
||||
"name": 'TMDB热门电视剧',
|
||||
},
|
||||
{
|
||||
"api_path": 'recommend/douban_movie_hot',
|
||||
"name": '豆瓣热门电影',
|
||||
},
|
||||
{
|
||||
"api_path": 'recommend/douban_tv_hot',
|
||||
"name": '豆瓣热门电视剧',
|
||||
},
|
||||
{
|
||||
"api_path": 'recommend/douban_tv_animation',
|
||||
"name": '豆瓣热门动漫',
|
||||
},
|
||||
{
|
||||
"api_path": 'recommend/douban_movies',
|
||||
"name": '豆瓣最新电影',
|
||||
},
|
||||
{
|
||||
"api_path": 'recommend/douban_tvs',
|
||||
"name": '豆瓣最新电视剧',
|
||||
},
|
||||
{
|
||||
"api_path": 'recommend/douban_movie_top250',
|
||||
"name": '豆瓣电影TOP250',
|
||||
},
|
||||
{
|
||||
"api_path": 'recommend/douban_tv_weekly_chinese',
|
||||
"name": '豆瓣国产剧集榜',
|
||||
},
|
||||
{
|
||||
"api_path": 'recommend/douban_tv_weekly_global',
|
||||
"name": '豆瓣全球剧集榜',
|
||||
}
|
||||
]
|
||||
|
||||
__medias = []
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "获取媒体数据"
|
||||
@@ -24,7 +91,47 @@ class FetchMediasAction(BaseAction):
|
||||
|
||||
@property
|
||||
def success(self) -> bool:
|
||||
return True
|
||||
return True if self.__medias else False
|
||||
|
||||
def __get_source(self, source: str):
|
||||
"""
|
||||
获取数据源
|
||||
"""
|
||||
for s in self.__inner_sources:
|
||||
if s['name'] == source:
|
||||
return s
|
||||
return None
|
||||
|
||||
async def execute(self, params: FetchMediasParams, context: ActionContext) -> ActionContext:
|
||||
pass
|
||||
"""
|
||||
获取媒体数据,填充到medias
|
||||
"""
|
||||
# 广播事件,请示额外的推荐数据源支持
|
||||
event_data = RecommendSourceEventData()
|
||||
event = eventmanager.send_event(ChainEventType.RecommendSource, event_data)
|
||||
# 使用事件返回的上下文数据
|
||||
if event and event.event_data:
|
||||
event_data: RecommendSourceEventData = event.event_data
|
||||
if event_data.extra_sources:
|
||||
self.__inner_sources.extend([s.dict() for s in event_data.extra_sources])
|
||||
|
||||
for name in params.sources:
|
||||
source = self.__get_source(name)
|
||||
if not source:
|
||||
continue
|
||||
logger.info(f"获取媒体数据 {source} ...")
|
||||
# 调用内部API获取数据
|
||||
api_url = f"http://127.0.0.1:{settings.PORT}/api/v1/{source['api_path']}?token={settings.API_TOKEN}"
|
||||
res = RequestUtils(timeout=15).post_res(api_url)
|
||||
if res:
|
||||
results = res.json()
|
||||
logger.info(f"{name} 获取到 {len(results)} 条数据")
|
||||
self.__medias.extend([MediaInfo(**r) for r in results])
|
||||
else:
|
||||
logger.error(f"{name} 获取数据失败")
|
||||
|
||||
if self.__medias:
|
||||
context.medias.extend(self.__medias)
|
||||
|
||||
self.job_done()
|
||||
return context
|
||||
|
||||
@@ -1,12 +1,20 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.actions import BaseAction
|
||||
from app.schemas import ActionParams, ActionContext
|
||||
from schemas import MediaType
|
||||
|
||||
|
||||
class FilterMediasParams(ActionParams):
|
||||
"""
|
||||
过滤媒体数据参数
|
||||
"""
|
||||
pass
|
||||
type: Optional[str] = Field(None, description="媒体类型 (电影/电视剧)")
|
||||
category: Optional[str] = Field(None, description="媒体类别 (二级分类)")
|
||||
vote: Optional[int] = Field(0, description="评分")
|
||||
year: Optional[str] = Field(None, description="年份")
|
||||
|
||||
|
||||
class FilterMediasAction(BaseAction):
|
||||
@@ -14,6 +22,8 @@ class FilterMediasAction(BaseAction):
|
||||
过滤媒体数据
|
||||
"""
|
||||
|
||||
__medias = []
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "过滤媒体数据"
|
||||
@@ -24,7 +34,26 @@ class FilterMediasAction(BaseAction):
|
||||
|
||||
@property
|
||||
def success(self) -> bool:
|
||||
return True
|
||||
return True if self.__medias else False
|
||||
|
||||
async def execute(self, params: FilterMediasParams, context: ActionContext) -> ActionContext:
|
||||
pass
|
||||
"""
|
||||
过滤medias中媒体数据
|
||||
"""
|
||||
for media in context.medias:
|
||||
if params.type and media.type != MediaType(params.type):
|
||||
continue
|
||||
if params.category and media.category != params.category:
|
||||
continue
|
||||
if params.vote and media.vote_average < params.vote:
|
||||
continue
|
||||
if params.year and media.year != params.year:
|
||||
continue
|
||||
self.__medias.append(media)
|
||||
|
||||
if self.__medias:
|
||||
context.medias = self.__medias
|
||||
|
||||
self.job_done()
|
||||
return context
|
||||
|
||||
|
||||
@@ -12,6 +12,9 @@ class FilterTorrentsParams(ActionParams):
|
||||
过滤资源数据参数
|
||||
"""
|
||||
rule_groups: Optional[List[str]] = Field([], description="规则组")
|
||||
quality: Optional[str] = Field(None, description="资源质量")
|
||||
resolution: Optional[str] = Field(None, description="资源分辨率")
|
||||
effect: Optional[str] = Field(None, description="特效")
|
||||
include: Optional[str] = Field(None, description="包含规则")
|
||||
exclude: Optional[str] = Field(None, description="排除规则")
|
||||
size: Optional[str] = Field(None, description="资源大小范围(MB)")
|
||||
@@ -48,6 +51,9 @@ class FilterTorrentsAction(BaseAction):
|
||||
if self.torrenthelper.filter_torrent(
|
||||
torrent_info=torrent.torrent_info,
|
||||
filter_params={
|
||||
"quality": params.quality,
|
||||
"resolution": params.resolution,
|
||||
"effect": params.effect,
|
||||
"include": params.include,
|
||||
"exclude": params.exclude,
|
||||
"size": params.size
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
from app.actions import BaseAction
|
||||
from app.schemas import ActionParams, ActionContext
|
||||
from chain.media import MediaChain
|
||||
from chain.storage import StorageChain
|
||||
from core.metainfo import MetaInfoPath
|
||||
from log import logger
|
||||
|
||||
|
||||
class ScrapeFileParams(ActionParams):
|
||||
@@ -14,6 +18,13 @@ class ScrapeFileAction(BaseAction):
|
||||
刮削文件
|
||||
"""
|
||||
|
||||
__scraped_files = []
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.storagechain = StorageChain()
|
||||
self.mediachain = MediaChain()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "刮削文件"
|
||||
@@ -24,7 +35,24 @@ class ScrapeFileAction(BaseAction):
|
||||
|
||||
@property
|
||||
def success(self) -> bool:
|
||||
return True
|
||||
return True if self.__scraped_files else False
|
||||
|
||||
async def execute(self, params: ScrapeFileParams, context: ActionContext) -> ActionContext:
|
||||
pass
|
||||
"""
|
||||
刮削fileitems中的所有文件
|
||||
"""
|
||||
for fileitem in context.fileitems:
|
||||
if fileitem in self.__scraped_files:
|
||||
continue
|
||||
if not self.storagechain.exists(fileitem):
|
||||
continue
|
||||
meta = MetaInfoPath(fileitem.path)
|
||||
mediainfo = self.chain.recognize_media(meta)
|
||||
if not mediainfo:
|
||||
logger.info(f"{fileitem.path} 未识别到媒体信息,无法刮削")
|
||||
continue
|
||||
self.mediachain.scrape_metadata(fileitem=fileitem, meta=meta, mediainfo=mediainfo)
|
||||
self.__scraped_files.append(fileitem)
|
||||
|
||||
self.job_done()
|
||||
return context
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import copy
|
||||
|
||||
from app.actions import BaseAction
|
||||
from app.schemas import ActionParams, ActionContext
|
||||
from core.event import eventmanager
|
||||
|
||||
|
||||
class SendEventParams(ActionParams):
|
||||
@@ -14,6 +17,8 @@ class SendEventAction(BaseAction):
|
||||
发送事件
|
||||
"""
|
||||
|
||||
__success = False
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "发送事件"
|
||||
@@ -24,7 +29,19 @@ class SendEventAction(BaseAction):
|
||||
|
||||
@property
|
||||
def success(self) -> bool:
|
||||
return True
|
||||
return self.__success
|
||||
|
||||
async def execute(self, params: SendEventParams, context: ActionContext) -> ActionContext:
|
||||
pass
|
||||
"""
|
||||
发送events中的事件
|
||||
"""
|
||||
if context.events:
|
||||
# 按优先级排序,优先级高的先发送
|
||||
context.events.sort(key=lambda x: x.priority, reverse=True)
|
||||
for event in copy.deepcopy(context.events):
|
||||
eventmanager.send_event(etype=event.event_type, data=event.event_data)
|
||||
context.events.remove(event)
|
||||
self.__success = True
|
||||
|
||||
self.job_done()
|
||||
return context
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from app.actions import BaseAction
|
||||
from app.schemas import ActionParams, ActionContext
|
||||
from chain.storage import StorageChain
|
||||
from chain.transfer import TransferChain
|
||||
from log import logger
|
||||
|
||||
|
||||
class TransferFileParams(ActionParams):
|
||||
@@ -14,6 +17,13 @@ class TransferFileAction(BaseAction):
|
||||
整理文件
|
||||
"""
|
||||
|
||||
__fileitems = []
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.transferchain = TransferChain()
|
||||
self.storagechain = StorageChain()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "整理文件"
|
||||
@@ -24,7 +34,30 @@ class TransferFileAction(BaseAction):
|
||||
|
||||
@property
|
||||
def success(self) -> bool:
|
||||
return True
|
||||
return True if self.__fileitems else False
|
||||
|
||||
async def execute(self, params: TransferFileParams, context: ActionContext) -> ActionContext:
|
||||
pass
|
||||
"""
|
||||
从downloads中整理文件,记录到fileitems
|
||||
"""
|
||||
for download in context.downloads:
|
||||
if not download.completed:
|
||||
logger.info(f"下载任务 {download.download_id} 未完成")
|
||||
continue
|
||||
fileitem = self.storagechain.get_file_item(storage="local", path=download.path)
|
||||
if not fileitem:
|
||||
logger.info(f"文件 {download.path} 不存在")
|
||||
continue
|
||||
logger.info(f"开始整理文件 {download.path} ...")
|
||||
state, errmsg = self.transferchain.do_transfer(fileitem, background=False)
|
||||
if not state:
|
||||
logger.error(f"整理文件 {download.path} 失败: {errmsg}")
|
||||
continue
|
||||
logger.info(f"整理文件 {download.path} 完成")
|
||||
self.__fileitems.append(fileitem)
|
||||
|
||||
if self.__fileitems:
|
||||
context.fileitems.extend(self.__fileitems)
|
||||
|
||||
self.job_done()
|
||||
return context
|
||||
|
||||
@@ -84,6 +84,12 @@ class StorageChain(ChainBase):
|
||||
"""
|
||||
return self.run_module("rename_file", fileitem=fileitem, name=name)
|
||||
|
||||
def exists(self, fileitem: schemas.FileItem) -> Optional[bool]:
|
||||
"""
|
||||
判断文件或目录是否存在
|
||||
"""
|
||||
return True if self.get_item(fileitem) else False
|
||||
|
||||
def get_item(self, fileitem: schemas.FileItem) -> Optional[schemas.FileItem]:
|
||||
"""
|
||||
查询目录或文件
|
||||
|
||||
@@ -6,6 +6,15 @@ from pydantic import BaseModel, Field, root_validator
|
||||
from app.schemas import MessageChannel, FileItem
|
||||
|
||||
|
||||
class Event(BaseModel):
|
||||
"""
|
||||
事件模型
|
||||
"""
|
||||
event_type: str = Field(..., description="事件类型")
|
||||
event_data: Optional[dict] = Field({}, description="事件数据")
|
||||
priority: Optional[int] = Field(0, description="事件优先级")
|
||||
|
||||
|
||||
class BaseEventData(BaseModel):
|
||||
"""
|
||||
事件数据的基类,所有具体事件数据类应继承自此类
|
||||
|
||||
@@ -8,6 +8,7 @@ from app.schemas.download import DownloadTask
|
||||
from app.schemas.site import Site
|
||||
from app.schemas.subscribe import Subscribe
|
||||
from app.schemas.message import Notification
|
||||
from app.schemas.event import Event
|
||||
|
||||
|
||||
class Workflow(BaseModel):
|
||||
@@ -61,3 +62,4 @@ class ActionContext(BaseModel):
|
||||
sites: Optional[List[Site]] = Field([], description="站点列表")
|
||||
subscribes: Optional[List[Subscribe]] = Field([], description="订阅列表")
|
||||
messages: Optional[List[Notification]] = Field([], description="消息列表")
|
||||
events: Optional[List[Event]] = Field([], description="事件列表")
|
||||
|
||||
Reference in New Issue
Block a user