From e6f029ba728edff2494c84e4cadc95d3654ad19a Mon Sep 17 00:00:00 2001 From: 100gle Date: Tue, 21 Nov 2023 21:40:16 +0800 Subject: [PATCH] bugfix(experimental): add api version for openai api base automatically --- backend/src/module/models/config.py | 8 +++++++- backend/src/test/models/test_config.py | 14 ++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) create mode 100644 backend/src/test/models/test_config.py diff --git a/backend/src/module/models/config.py b/backend/src/module/models/config.py index 49044fa3..1ebaa430 100644 --- a/backend/src/module/models/config.py +++ b/backend/src/module/models/config.py @@ -1,7 +1,7 @@ from os.path import expandvars from typing import Literal -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, validator class Program(BaseModel): @@ -102,6 +102,12 @@ class ExperimentalOpenAI(BaseModel): "", description="Azure OpenAI deployment id, ignored when api type is openai" ) + @validator("api_base") + def validate_api_base(cls, value: str): + if value == "https://api.openai.com/": + return "https://api.openai.com/v1" + return value + class Config(BaseModel): program: Program = Program() diff --git a/backend/src/test/models/test_config.py b/backend/src/test/models/test_config.py new file mode 100644 index 00000000..ff4cb3f5 --- /dev/null +++ b/backend/src/test/models/test_config.py @@ -0,0 +1,14 @@ +from module.models.config import ExperimentalOpenAI + + +def test_experimental_openai_validate_api_base(): + config = ExperimentalOpenAI(api_type="openai", api_base="https://api.openai.com/") + assert config.api_base == "https://api.openai.com/v1" + + config = ExperimentalOpenAI(api_base="https://api.openai.com/") + assert config.api_base == "https://api.openai.com/v1" + + config = ExperimentalOpenAI( + api_type="azure", api_base="https://custom-api-base.com" + ) + assert config.api_base == "https://custom-api-base.com"