mirror of
https://github.com/EstrellaXD/Auto_Bangumi.git
synced 2026-04-23 18:11:37 +08:00
Merge pull request #623 from 100gle/bugfix-openai-baseurl
bugfix(experimental): add api version for openai api base automatically
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from os.path import expandvars
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
|
||||
class Program(BaseModel):
|
||||
@@ -102,6 +102,12 @@ class ExperimentalOpenAI(BaseModel):
|
||||
"", description="Azure OpenAI deployment id, ignored when api type is openai"
|
||||
)
|
||||
|
||||
@validator("api_base")
|
||||
def validate_api_base(cls, value: str):
|
||||
if value == "https://api.openai.com/":
|
||||
return "https://api.openai.com/v1"
|
||||
return value
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
program: Program = Program()
|
||||
|
||||
14
backend/src/test/models/test_config.py
Normal file
14
backend/src/test/models/test_config.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from module.models.config import ExperimentalOpenAI
|
||||
|
||||
|
||||
def test_experimental_openai_validate_api_base():
|
||||
config = ExperimentalOpenAI(api_type="openai", api_base="https://api.openai.com/")
|
||||
assert config.api_base == "https://api.openai.com/v1"
|
||||
|
||||
config = ExperimentalOpenAI(api_base="https://api.openai.com/")
|
||||
assert config.api_base == "https://api.openai.com/v1"
|
||||
|
||||
config = ExperimentalOpenAI(
|
||||
api_type="azure", api_base="https://custom-api-base.com"
|
||||
)
|
||||
assert config.api_base == "https://custom-api-base.com"
|
||||
Reference in New Issue
Block a user