diff --git a/backend/src/module/models/config.py b/backend/src/module/models/config.py index 49044fa3..1ebaa430 100644 --- a/backend/src/module/models/config.py +++ b/backend/src/module/models/config.py @@ -1,7 +1,7 @@ from os.path import expandvars from typing import Literal -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, validator class Program(BaseModel): @@ -102,6 +102,12 @@ class ExperimentalOpenAI(BaseModel): "", description="Azure OpenAI deployment id, ignored when api type is openai" ) + @validator("api_base") + def validate_api_base(cls, value: str): + if value == "https://api.openai.com/": + return "https://api.openai.com/v1" + return value + class Config(BaseModel): program: Program = Program() diff --git a/backend/src/test/models/test_config.py b/backend/src/test/models/test_config.py new file mode 100644 index 00000000..ff4cb3f5 --- /dev/null +++ b/backend/src/test/models/test_config.py @@ -0,0 +1,14 @@ +from module.models.config import ExperimentalOpenAI + + +def test_experimental_openai_validate_api_base(): + config = ExperimentalOpenAI(api_type="openai", api_base="https://api.openai.com/") + assert config.api_base == "https://api.openai.com/v1" + + config = ExperimentalOpenAI(api_base="https://api.openai.com/") + assert config.api_base == "https://api.openai.com/v1" + + config = ExperimentalOpenAI( + api_type="azure", api_base="https://custom-api-base.com" + ) + assert config.api_base == "https://custom-api-base.com"