diff --git a/backend/src/module/parser/openai.py b/backend/src/module/parser/openai.py index 2ac051ca..e14e7cc8 100644 --- a/backend/src/module/parser/openai.py +++ b/backend/src/module/parser/openai.py @@ -1,4 +1,5 @@ import asyncio +import json import logging import openai @@ -76,7 +77,9 @@ class OpenAIParser: self.model = model self.openai_kwargs = kwargs - def parse(self, text: str, prompt: str | None = None) -> str: + def parse( + self, text: str, prompt: str | None = None, asdict: bool = True + ) -> dict | str: """parse text with openai Args: @@ -84,9 +87,12 @@ class OpenAIParser: prompt (str | None, optional): the custom prompt. Built-in prompt will be used if no prompt is provided. \ Defaults to None. + asdict (bool, optional): + whether to return the result as dict or not. \ + Defaults to True. Returns: - str: the parsed text. + dict | str: the parsed result. """ if not prompt: prompt = DEFAULT_PROMPT @@ -111,6 +117,12 @@ class OpenAIParser: loop = asyncio.get_event_loop() result = loop.run_until_complete(complete()) + if asdict: + try: + result = json.loads(result) + except json.JSONDecodeError: + logger.warning(f"Cannot parse result {result} as python dict.") + logger.debug(f"the parsed result is: {result}") return result diff --git a/backend/src/test/test_openai.py b/backend/src/test/test_openai.py index ea40f865..60408641 100644 --- a/backend/src/test/test_openai.py +++ b/backend/src/test/test_openai.py @@ -15,7 +15,7 @@ class TestOpenAIParser: def test_parse(self): text = "[梦蓝字幕组]New Doraemon 哆啦A梦新番[747][2023.02.25][AVC][1080P][GB_JP][MP4]" - result = self.parser.parse(text=text) + result = self.parser.parse(text=text, asdict=False) assert json.loads(result) == { "group": "梦蓝字幕组", "title_en": "New Doraemon",