fix: preserve reasoning content for compatible llms

This commit is contained in:
jxxghp
2026-05-14 14:01:53 +08:00
parent 0f3a4e4c15
commit 4322f8a3c1
3 changed files with 341 additions and 13 deletions

View File

@@ -7,7 +7,7 @@ import time
from functools import wraps
from typing import Any, List
from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessage, AIMessageChunk
from app.core.config import settings
from app.log import logger
@@ -142,9 +142,15 @@ def _patch_deepseek_reasoning_content_support():
def _patched_get_request_payload(self, input_, *, stop=None, **kwargs):
payload = original_get_request_payload(self, input_, stop=stop, **kwargs)
# Resolve original messages so we can extract reasoning_content from
# additional_kwargs. The parent's payload builder does not propagate
# this DeepSeek-specific field.
extra_body = (getattr(self, "model_kwargs", None) or {}).get("extra_body")
if not _is_deepseek_thinking_enabled(
getattr(self, "model_name", None) or getattr(self, "model", None),
extra_body,
):
return payload
# 从原始 LangChain 消息中取回 reasoning_content。上游 payload 构造器
# 不会自动透传这个 DeepSeek 扩展字段。
messages = self._convert_input(input_).to_messages()
for i, message in enumerate(payload["messages"]):
@@ -152,9 +158,8 @@ def _patch_deepseek_reasoning_content_support():
message["content"] = json.dumps(message["content"])
elif message["role"] == "assistant":
if isinstance(message["content"], list):
# DeepSeek API expects assistant content to be a string,
# not a list. Extract text blocks and join them, or use
# empty string if none exist.
# DeepSeek API 要求 assistant content 为字符串;工具场景下
# LangChain 可能保留为内容块列表,这里只拼回可见文本块。
text_parts = [
block.get("text", "")
for block in message["content"]
@@ -162,10 +167,8 @@ def _patch_deepseek_reasoning_content_support():
]
message["content"] = "".join(text_parts) if text_parts else ""
# DeepSeek reasoning models require every assistant message to
# carry a reasoning_content field (even when empty). The value
# is stored in AIMessage.additional_kwargs by
# _create_chat_result(); re-inject it into the API payload.
# DeepSeek thinking mode 要求历史 assistant 消息携带
# reasoning_content,即便本地只保存到了 additional_kwargs。
if (
"reasoning_content" not in message
and i < len(messages)
@@ -182,6 +185,103 @@ def _patch_deepseek_reasoning_content_support():
logger.debug("已修补 langchain-deepseek thinking tool-call 的 reasoning_content 回传兼容性")
def _patch_openai_interleaved_reasoning_content_support():
"""
修补 OpenAI-compatible 模型的 interleaved reasoning 内容回传。
小米 MiMo、部分 Kimi/GLM 等兼容端点会把思考内容放在响应顶层
`reasoning_content` 字段;如果下一轮请求没有把它随历史 assistant
消息带回,工具调用后续请求会被服务端以 400 拒绝。
这里不按 provider 白名单判断,而是只在历史 AIMessage 真实保存过
`reasoning_content` 时回传,避免以后每接入一个同类模型都要单独适配。
"""
try:
import langchain_openai.chat_models.base as _openai_base
from langchain_openai import ChatOpenAI
except Exception as err:
logger.debug(f"跳过 langchain-openai reasoning_content 修补:{err}")
return
if not getattr(_openai_base, "_moviepilot_reasoning_response_patched", False):
original_convert_dict = getattr(_openai_base, "_convert_dict_to_message", None)
original_convert_delta = getattr(
_openai_base, "_convert_delta_to_message_chunk", None
)
if callable(original_convert_dict):
@wraps(original_convert_dict)
def _patched_convert_dict_to_message(message_dict):
message = original_convert_dict(message_dict)
if (
isinstance(message, AIMessage)
and "reasoning_content" in message_dict
):
message.additional_kwargs["reasoning_content"] = (
message_dict.get("reasoning_content") or ""
)
return message
_openai_base._convert_dict_to_message = _patched_convert_dict_to_message
if callable(original_convert_delta):
@wraps(original_convert_delta)
def _patched_convert_delta_to_message_chunk(delta, default_class):
chunk = original_convert_delta(delta, default_class)
if (
isinstance(chunk, AIMessageChunk)
and "reasoning_content" in delta
):
chunk.additional_kwargs["reasoning_content"] = (
delta.get("reasoning_content") or ""
)
return chunk
_openai_base._convert_delta_to_message_chunk = (
_patched_convert_delta_to_message_chunk
)
_openai_base._moviepilot_reasoning_response_patched = True
if getattr(ChatOpenAI, "_moviepilot_interleaved_reasoning_patched", False):
return
original_get_request_payload = getattr(ChatOpenAI, "_get_request_payload", None)
if not callable(original_get_request_payload):
logger.warning("langchain-openai 缺少 _get_request_payload无法修补 reasoning_content")
return
@wraps(original_get_request_payload)
def _patched_get_request_payload(self, input_, *, stop=None, **kwargs):
payload = original_get_request_payload(self, input_, stop=stop, **kwargs)
if "messages" not in payload:
return payload
messages = self._convert_input(input_).to_messages()
for index, payload_message in enumerate(payload["messages"]):
if (
payload_message.get("role") != "assistant"
or index >= len(messages)
or not isinstance(messages[index], AIMessage)
or "reasoning_content" in payload_message
):
continue
reasoning_content = messages[index].additional_kwargs.get(
"reasoning_content"
)
if reasoning_content is not None:
# 只回传模型真实返回过的思考字段。普通模型没有该字段时,
# payload 保持原样,不额外塞未知参数。
payload_message["reasoning_content"] = reasoning_content
return payload
ChatOpenAI._get_request_payload = _patched_get_request_payload
ChatOpenAI._moviepilot_interleaved_reasoning_patched = True
logger.debug("已修补 langchain-openai interleaved reasoning_content 回传兼容性")
def _patch_openai_responses_instructions_support():
"""
修补 langchain-openai 在使用 use_responses_api=True 时,
@@ -195,6 +295,8 @@ def _patch_openai_responses_instructions_support():
logger.debug(f"跳过 langchain-openai instructions 修补:{err}")
return
_patch_openai_interleaved_reasoning_content_support()
if getattr(ChatOpenAI, "_moviepilot_responses_instructions_patched", False):
return

View File

@@ -38,6 +38,9 @@ class _FakeChatDeepSeek:
self.model_name = model_name
self.model_kwargs = model_kwargs or {}
def _convert_input(self, input_):
return type("_FakeInput", (), {"to_messages": lambda _self: input_})()
def _get_request_payload(self, input_, *, stop=None, **kwargs):
messages = []
for message in input_:
@@ -62,7 +65,7 @@ class _FakeChatDeepSeek:
_ORIGINAL_GET_REQUEST_PAYLOAD = _FakeChatDeepSeek._get_request_payload
sys.modules.pop("app.helper.llm", None)
sys.modules.pop("app.agent.llm.helper", None)
_stub_module(
"app.core.config",
settings=ModuleType("settings"),
@@ -78,7 +81,7 @@ sys.modules["app.core.config"].settings.PROXY_HOST = None
_stub_module("app.log", logger=_DummyLogger())
_stub_module("langchain_deepseek", ChatDeepSeek=_FakeChatDeepSeek)
module_path = Path(__file__).resolve().parents[1] / "app" / "helper" / "llm.py"
module_path = Path(__file__).resolve().parents[1] / "app" / "agent" / "llm" / "helper.py"
spec = importlib.util.spec_from_file_location("test_llm_module_for_deepseek_compat", module_path)
llm_module = importlib.util.module_from_spec(spec)
assert spec and spec.loader

View File

@@ -6,6 +6,8 @@ from pathlib import Path
from types import ModuleType, SimpleNamespace
from unittest.mock import AsyncMock, patch
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
def _stub_module(name: str, **attrs):
module = sys.modules.get(name)
@@ -30,6 +32,92 @@ class _FakeModel:
return SimpleNamespace(content=self._content)
def _build_tool_call(name: str = "search"):
return [
{
"id": "call_1",
"type": "tool_call",
"name": name,
"args": {},
}
]
class _FakeOpenAIInput:
def __init__(self, messages):
self._messages = messages
def to_messages(self):
return self._messages
class _FakeChatOpenAIForPatch:
def __init__(self, **kwargs):
self.model = kwargs["model"]
self.model_name = kwargs["model"]
self.openai_api_base = kwargs.get("base_url")
self.profile = None
def _convert_input(self, input_):
return _FakeOpenAIInput(input_)
def _get_request_payload(self, input_, *, stop=None, **kwargs):
messages = []
for message in input_:
payload_message = {
"role": message.type,
"content": message.content,
}
if message.type == "human":
payload_message["role"] = "user"
elif message.type == "ai":
payload_message["role"] = "assistant"
tool_calls = getattr(message, "tool_calls", None)
if tool_calls:
payload_message["tool_calls"] = tool_calls
elif message.type == "tool":
payload_message["role"] = "tool"
payload_message["tool_call_id"] = message.tool_call_id
messages.append(payload_message)
return {"messages": messages}
def _build_fake_openai_modules(chat_openai_cls=_FakeChatOpenAIForPatch):
"""构造最小 langchain_openai stub避免单测触发真实依赖链。"""
from langchain_core.messages import AIMessageChunk
for attr in (
"_moviepilot_interleaved_reasoning_patched",
"_moviepilot_responses_instructions_patched",
):
if hasattr(chat_openai_cls, attr):
delattr(chat_openai_cls, attr)
openai_module = ModuleType("langchain_openai")
openai_module.__path__ = []
openai_module.ChatOpenAI = chat_openai_cls
chat_models_module = ModuleType("langchain_openai.chat_models")
chat_models_module.__path__ = []
base_module = ModuleType("langchain_openai.chat_models.base")
def _convert_dict_to_message(message_dict):
return AIMessage(content=message_dict.get("content") or "")
def _convert_delta_to_message_chunk(delta, default_class):
return AIMessageChunk(content=delta.get("content") or "")
base_module._convert_dict_to_message = _convert_dict_to_message
base_module._convert_delta_to_message_chunk = _convert_delta_to_message_chunk
return {
"langchain_openai": openai_module,
"langchain_openai.chat_models": chat_models_module,
"langchain_openai.chat_models.base": base_module,
}, base_module
sys.modules.pop("app.agent.llm.helper", None)
_stub_module(
"app.core.config",
@@ -144,6 +232,97 @@ class LlmHelperTestCallTest(unittest.TestCase):
{"thinking": {"type": "disabled"}},
)
def test_openai_compatible_patch_preserves_stream_reasoning_content(self):
from langchain_core.messages import AIMessageChunk
fake_modules, openai_base = _build_fake_openai_modules()
with patch.dict(sys.modules, fake_modules):
llm_module._patch_openai_interleaved_reasoning_content_support()
chunk = openai_base._convert_delta_to_message_chunk(
{"role": "assistant", "content": "", "reasoning_content": "先调用工具"},
AIMessageChunk,
)
self.assertEqual(
chunk.additional_kwargs.get("reasoning_content"),
"先调用工具",
)
def test_openai_compatible_patch_injects_xiaomi_reasoning_content(self):
fake_modules, _ = _build_fake_openai_modules()
with patch.dict(sys.modules, fake_modules):
llm_module._patch_openai_interleaved_reasoning_content_support()
llm = _FakeChatOpenAIForPatch(
model="mimo-v2.5-pro",
api_key="sk-test",
base_url="https://api.xiaomimimo.com/v1",
)
messages = [
HumanMessage(content="天气如何?"),
AIMessage(
content="",
tool_calls=_build_tool_call(),
additional_kwargs={"reasoning_content": "先调用天气工具"},
),
ToolMessage(content="晴天", tool_call_id="call_1"),
]
payload = llm._get_request_payload(messages)
self.assertEqual(
payload["messages"][1]["reasoning_content"],
"先调用天气工具",
)
def test_openai_compatible_patch_injects_any_model_with_reasoning_content(self):
fake_modules, _ = _build_fake_openai_modules()
with patch.dict(sys.modules, fake_modules):
llm_module._patch_openai_interleaved_reasoning_content_support()
llm = _FakeChatOpenAIForPatch(
model="glm-5",
api_key="sk-test",
base_url="https://open.bigmodel.cn/api/paas/v4",
)
messages = [
HumanMessage(content="天气如何?"),
AIMessage(
content="",
tool_calls=_build_tool_call(),
additional_kwargs={"reasoning_content": "先规划工具调用"},
),
ToolMessage(content="晴天", tool_call_id="call_1"),
]
payload = llm._get_request_payload(messages)
self.assertEqual(
payload["messages"][1]["reasoning_content"],
"先规划工具调用",
)
def test_openai_compatible_patch_skips_when_reasoning_content_missing(self):
fake_modules, _ = _build_fake_openai_modules()
with patch.dict(sys.modules, fake_modules):
llm_module._patch_openai_interleaved_reasoning_content_support()
llm = _FakeChatOpenAIForPatch(
model="gpt-4o-mini",
api_key="sk-test",
base_url="https://api.openai.com/v1",
)
messages = [
HumanMessage(content="天气如何?"),
AIMessage(
content="",
tool_calls=_build_tool_call(),
),
ToolMessage(content="晴天", tool_call_id="call_1"),
]
payload = llm._get_request_payload(messages)
self.assertNotIn("reasoning_content", payload["messages"][1])
def test_get_llm_uses_deepseek_thinking_level_controls(self):
calls = []
patch_calls = []
@@ -308,6 +487,50 @@ class LlmHelperTestCallTest(unittest.TestCase):
"https://updated.example.com/v1",
)
def test_get_llm_keeps_openai_patch_global_without_model_marker(self):
class _FakeProviderManager:
async def resolve_runtime(self, **kwargs):
return {
"provider_id": kwargs["provider_id"],
"runtime": "openai_compatible",
"model_id": kwargs["model"],
"api_key": kwargs["api_key"],
"base_url": kwargs["base_url"],
"default_headers": None,
"use_responses_api": None,
"model_record": None,
"model_metadata": {},
}
provider_module = ModuleType("app.agent.llm.provider")
provider_module.LLMProviderManager = _FakeProviderManager
fake_openai_modules, _ = _build_fake_openai_modules()
with patch.dict(
sys.modules,
{
"app.agent.llm.provider": provider_module,
**fake_openai_modules,
},
):
created = asyncio.run(
llm_module.LLMHelper.get_llm(
provider="openai",
model="mimo-v2.5-pro",
api_key="sk-test",
base_url="https://api.xiaomimimo.com/v1",
)
)
self.assertTrue(
getattr(
sys.modules["langchain_openai"].ChatOpenAI,
"_moviepilot_interleaved_reasoning_patched",
False,
)
)
self.assertFalse(hasattr(created, "_moviepilot_interleaved_reasoning_field"))
def test_get_llm_maps_unified_max_to_openai_xhigh(self):
calls = []