chore: add deserializing argument for parse method

This commit is contained in:
100gle
2023-09-30 19:31:34 +08:00
parent f540beec9f
commit d093fdba0e
2 changed files with 15 additions and 3 deletions

View File

@@ -1,4 +1,5 @@
import asyncio
import json
import logging
import openai
@@ -76,7 +77,9 @@ class OpenAIParser:
self.model = model
self.openai_kwargs = kwargs
def parse(self, text: str, prompt: str | None = None) -> str:
def parse(
self, text: str, prompt: str | None = None, asdict: bool = True
) -> dict | str:
"""parse text with openai
Args:
@@ -84,9 +87,12 @@ class OpenAIParser:
prompt (str | None, optional):
the custom prompt. Built-in prompt will be used if no prompt is provided. \
Defaults to None.
asdict (bool, optional):
whether to return the result as dict or not. \
Defaults to True.
Returns:
str: the parsed text.
dict | str: the parsed result.
"""
if not prompt:
prompt = DEFAULT_PROMPT
@@ -111,6 +117,12 @@ class OpenAIParser:
loop = asyncio.get_event_loop()
result = loop.run_until_complete(complete())
if asdict:
try:
result = json.loads(result)
except json.JSONDecodeError:
logger.warning(f"Cannot parse result {result} as python dict.")
logger.debug(f"the parsed result is: {result}")
return result

View File

@@ -15,7 +15,7 @@ class TestOpenAIParser:
def test_parse(self):
text = "[梦蓝字幕组]New Doraemon 哆啦A梦新番[747][2023.02.25][AVC][1080P][GB_JP][MP4]"
result = self.parser.parse(text=text)
result = self.parser.parse(text=text, asdict=False)
assert json.loads(result) == {
"group": "梦蓝字幕组",
"title_en": "New Doraemon",