diff --git a/backend/src/module/models/config.py b/backend/src/module/models/config.py
index 22d1c6eb..49044fa3 100644
--- a/backend/src/module/models/config.py
+++ b/backend/src/module/models/config.py
@@ -1,4 +1,5 @@
from os.path import expandvars
+from typing import Literal
from pydantic import BaseModel, Field
@@ -88,7 +89,18 @@ class ExperimentalOpenAI(BaseModel):
api_base: str = Field(
"https://api.openai.com/v1", description="OpenAI api base url"
)
- model: str = Field("gpt-3.5-turbo", description="OpenAI model")
+ api_type: Literal["azure", "openai"] = Field(
+ "openai", description="OpenAI api type, usually for azure"
+ )
+ api_version: str = Field(
+ "2023-05-15", description="OpenAI api version, only for Azure"
+ )
+ model: str = Field(
+ "gpt-3.5-turbo", description="OpenAI model, ignored when api type is azure"
+ )
+ deployment_id: str = Field(
+ "", description="Azure OpenAI deployment id, ignored when api type is openai"
+ )
class Config(BaseModel):
diff --git a/backend/src/module/parser/analyser/openai.py b/backend/src/module/parser/analyser/openai.py
index 6da6d232..0d9fe946 100644
--- a/backend/src/module/parser/analyser/openai.py
+++ b/backend/src/module/parser/analyser/openai.py
@@ -1,6 +1,7 @@
import json
import logging
from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Dict
import openai
@@ -98,21 +99,10 @@ class OpenAIParser:
if not prompt:
prompt = DEFAULT_PROMPT
- with ThreadPoolExecutor(max_workers=1) as worker:
- future = worker.submit(
- openai.ChatCompletion.create,
- api_key=self._api_key,
- api_base=self.api_base,
- model=self.model,
- messages=[
- dict(role="system", content=prompt),
- dict(role="user", content=text),
- ],
- # set temperature to 0 to make results be more stable and reproducible.
- temperature=0,
- **self.openai_kwargs,
- )
+ params = self._prepare_params(text, prompt)
+ with ThreadPoolExecutor(max_workers=1) as worker:
+ future = worker.submit(openai.ChatCompletion.create, **params)
resp = future.result()
result = resp["choices"][0]["message"]["content"]
@@ -126,3 +116,36 @@ class OpenAIParser:
logger.debug(f"the parsed result is: {result}")
return result
+
+ def _prepare_params(self, text: str, prompt: str) -> Dict[str, Any]:
+ """_prepare_params is a helper function to prepare params for openai library.
+ There are some differences between openai and azure openai api, so we need to
+ prepare params for them.
+
+ Args:
+ text (str): the text to be parsed
+ prompt (str): the custom prompt
+
+ Returns:
+ Dict[str, Any]: the prepared key value pairs.
+ """
+ params = dict(
+ api_key=self._api_key,
+ api_base=self.api_base,
+ messages=[
+ dict(role="system", content=prompt),
+ dict(role="user", content=text),
+ ],
+ # set temperature to 0 to make results be more stable and reproducible.
+ temperature=0,
+ )
+
+ api_type = self.openai_kwargs.get("api_type", "openai")
+ if api_type == "azure":
+ params["deployment_id"] = self.openai_kwargs.get("deployment_id", "")
+ params["api_version"] = self.openai_kwargs.get("api_version", "2023-05-15")
+ params["api_type"] = "azure"
+ else:
+ params["model"] = self.model
+
+ return params
diff --git a/backend/src/test/test_openai.py b/backend/src/test/test_openai.py
index 0df1efdb..4709c579 100644
--- a/backend/src/test/test_openai.py
+++ b/backend/src/test/test_openai.py
@@ -1,19 +1,57 @@
import json
-import os
from unittest import mock
-from dotenv import load_dotenv
-from module.parser.analyser.openai import OpenAIParser
-
-load_dotenv()
+from module.parser.analyser.openai import DEFAULT_PROMPT, OpenAIParser
class TestOpenAIParser:
@classmethod
def setup_class(cls):
- api_key = os.getenv("OPENAI_API_KEY") or "testing!"
+ api_key = "testing!"
cls.parser = OpenAIParser(api_key=api_key)
+ def test__prepare_params_with_openai(self):
+ text = "hello world"
+ expected = dict(
+ api_key=self.parser._api_key,
+ api_base=self.parser.api_base,
+ messages=[
+ dict(role="system", content=DEFAULT_PROMPT),
+ dict(role="user", content=text),
+ ],
+ temperature=0,
+ model=self.parser.model,
+ )
+
+ params = self.parser._prepare_params(text, DEFAULT_PROMPT)
+ assert expected == params
+
+ def test__prepare_params_with_azure(self):
+ azure_parser = OpenAIParser(
+ api_key="aaabbbcc",
+ api_base="https://test.openai.azure.com/",
+ api_type="azure",
+ api_version="2023-05-15",
+ deployment_id="gpt-35-turbo",
+ )
+
+ text = "hello world"
+ expected = dict(
+ api_key=azure_parser._api_key,
+ api_base=azure_parser.api_base,
+ messages=[
+ dict(role="system", content=DEFAULT_PROMPT),
+ dict(role="user", content=text),
+ ],
+ temperature=0,
+ deployment_id="gpt-35-turbo",
+ api_version="2023-05-15",
+ api_type="azure",
+ )
+
+ params = azure_parser._prepare_params(text, DEFAULT_PROMPT)
+ assert expected == params
+
def test_parse(self):
text = "[梦蓝字幕组]New Doraemon 哆啦A梦新番[747][2023.02.25][AVC][1080P][GB_JP][MP4]"
expected = {
diff --git a/backend/src/test/test_title_parser.py b/backend/src/test/test_title_parser.py
index 36199bc2..b23ab4a2 100644
--- a/backend/src/test/test_title_parser.py
+++ b/backend/src/test/test_title_parser.py
@@ -1,6 +1,3 @@
-import json
-import os
-
import pytest
from module.conf import settings
from module.parser.title_parser import TitleParser
diff --git a/webui/src/components/setting/config-openai.vue b/webui/src/components/setting/config-openai.vue
index 2b34ad91..a19f8177 100644
--- a/webui/src/components/setting/config-openai.vue
+++ b/webui/src/components/setting/config-openai.vue
@@ -1,20 +1,29 @@
@@ -53,7 +87,7 @@ const items: SettingItem[] = [
;
export type NotificationType = UnionToTuple;
/** OpenAI Model List */
export type OpenAIModel = UnionToTuple;
+/** OpenAI API Type */
+export type OpenAIType = UnionToTuple<'openai' | 'azure'>;