mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-16 15:08:22 +08:00
fix: preserve reasoning content for compatible llms
This commit is contained in:
@@ -7,7 +7,7 @@ import time
|
||||
from functools import wraps
|
||||
from typing import Any, List
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
@@ -142,9 +142,15 @@ def _patch_deepseek_reasoning_content_support():
|
||||
def _patched_get_request_payload(self, input_, *, stop=None, **kwargs):
|
||||
payload = original_get_request_payload(self, input_, stop=stop, **kwargs)
|
||||
|
||||
# Resolve original messages so we can extract reasoning_content from
|
||||
# additional_kwargs. The parent's payload builder does not propagate
|
||||
# this DeepSeek-specific field.
|
||||
extra_body = (getattr(self, "model_kwargs", None) or {}).get("extra_body")
|
||||
if not _is_deepseek_thinking_enabled(
|
||||
getattr(self, "model_name", None) or getattr(self, "model", None),
|
||||
extra_body,
|
||||
):
|
||||
return payload
|
||||
|
||||
# 从原始 LangChain 消息中取回 reasoning_content。上游 payload 构造器
|
||||
# 不会自动透传这个 DeepSeek 扩展字段。
|
||||
messages = self._convert_input(input_).to_messages()
|
||||
|
||||
for i, message in enumerate(payload["messages"]):
|
||||
@@ -152,9 +158,8 @@ def _patch_deepseek_reasoning_content_support():
|
||||
message["content"] = json.dumps(message["content"])
|
||||
elif message["role"] == "assistant":
|
||||
if isinstance(message["content"], list):
|
||||
# DeepSeek API expects assistant content to be a string,
|
||||
# not a list. Extract text blocks and join them, or use
|
||||
# empty string if none exist.
|
||||
# DeepSeek API 要求 assistant content 为字符串;工具场景下
|
||||
# LangChain 可能保留为内容块列表,这里只拼回可见文本块。
|
||||
text_parts = [
|
||||
block.get("text", "")
|
||||
for block in message["content"]
|
||||
@@ -162,10 +167,8 @@ def _patch_deepseek_reasoning_content_support():
|
||||
]
|
||||
message["content"] = "".join(text_parts) if text_parts else ""
|
||||
|
||||
# DeepSeek reasoning models require every assistant message to
|
||||
# carry a reasoning_content field (even when empty). The value
|
||||
# is stored in AIMessage.additional_kwargs by
|
||||
# _create_chat_result(); re-inject it into the API payload.
|
||||
# DeepSeek thinking mode 要求历史 assistant 消息携带
|
||||
# reasoning_content,即便本地只保存到了 additional_kwargs。
|
||||
if (
|
||||
"reasoning_content" not in message
|
||||
and i < len(messages)
|
||||
@@ -182,6 +185,103 @@ def _patch_deepseek_reasoning_content_support():
|
||||
logger.debug("已修补 langchain-deepseek thinking tool-call 的 reasoning_content 回传兼容性")
|
||||
|
||||
|
||||
def _patch_openai_interleaved_reasoning_content_support():
|
||||
"""
|
||||
修补 OpenAI-compatible 模型的 interleaved reasoning 内容回传。
|
||||
|
||||
小米 MiMo、部分 Kimi/GLM 等兼容端点会把思考内容放在响应顶层
|
||||
`reasoning_content` 字段;如果下一轮请求没有把它随历史 assistant
|
||||
消息带回,工具调用后续请求会被服务端以 400 拒绝。
|
||||
|
||||
这里不按 provider 白名单判断,而是只在历史 AIMessage 真实保存过
|
||||
`reasoning_content` 时回传,避免以后每接入一个同类模型都要单独适配。
|
||||
"""
|
||||
try:
|
||||
import langchain_openai.chat_models.base as _openai_base
|
||||
from langchain_openai import ChatOpenAI
|
||||
except Exception as err:
|
||||
logger.debug(f"跳过 langchain-openai reasoning_content 修补:{err}")
|
||||
return
|
||||
|
||||
if not getattr(_openai_base, "_moviepilot_reasoning_response_patched", False):
|
||||
original_convert_dict = getattr(_openai_base, "_convert_dict_to_message", None)
|
||||
original_convert_delta = getattr(
|
||||
_openai_base, "_convert_delta_to_message_chunk", None
|
||||
)
|
||||
|
||||
if callable(original_convert_dict):
|
||||
@wraps(original_convert_dict)
|
||||
def _patched_convert_dict_to_message(message_dict):
|
||||
message = original_convert_dict(message_dict)
|
||||
if (
|
||||
isinstance(message, AIMessage)
|
||||
and "reasoning_content" in message_dict
|
||||
):
|
||||
message.additional_kwargs["reasoning_content"] = (
|
||||
message_dict.get("reasoning_content") or ""
|
||||
)
|
||||
return message
|
||||
|
||||
_openai_base._convert_dict_to_message = _patched_convert_dict_to_message
|
||||
|
||||
if callable(original_convert_delta):
|
||||
@wraps(original_convert_delta)
|
||||
def _patched_convert_delta_to_message_chunk(delta, default_class):
|
||||
chunk = original_convert_delta(delta, default_class)
|
||||
if (
|
||||
isinstance(chunk, AIMessageChunk)
|
||||
and "reasoning_content" in delta
|
||||
):
|
||||
chunk.additional_kwargs["reasoning_content"] = (
|
||||
delta.get("reasoning_content") or ""
|
||||
)
|
||||
return chunk
|
||||
|
||||
_openai_base._convert_delta_to_message_chunk = (
|
||||
_patched_convert_delta_to_message_chunk
|
||||
)
|
||||
|
||||
_openai_base._moviepilot_reasoning_response_patched = True
|
||||
|
||||
if getattr(ChatOpenAI, "_moviepilot_interleaved_reasoning_patched", False):
|
||||
return
|
||||
|
||||
original_get_request_payload = getattr(ChatOpenAI, "_get_request_payload", None)
|
||||
if not callable(original_get_request_payload):
|
||||
logger.warning("langchain-openai 缺少 _get_request_payload,无法修补 reasoning_content")
|
||||
return
|
||||
|
||||
@wraps(original_get_request_payload)
|
||||
def _patched_get_request_payload(self, input_, *, stop=None, **kwargs):
|
||||
payload = original_get_request_payload(self, input_, stop=stop, **kwargs)
|
||||
if "messages" not in payload:
|
||||
return payload
|
||||
|
||||
messages = self._convert_input(input_).to_messages()
|
||||
for index, payload_message in enumerate(payload["messages"]):
|
||||
if (
|
||||
payload_message.get("role") != "assistant"
|
||||
or index >= len(messages)
|
||||
or not isinstance(messages[index], AIMessage)
|
||||
or "reasoning_content" in payload_message
|
||||
):
|
||||
continue
|
||||
|
||||
reasoning_content = messages[index].additional_kwargs.get(
|
||||
"reasoning_content"
|
||||
)
|
||||
if reasoning_content is not None:
|
||||
# 只回传模型真实返回过的思考字段。普通模型没有该字段时,
|
||||
# payload 保持原样,不额外塞未知参数。
|
||||
payload_message["reasoning_content"] = reasoning_content
|
||||
|
||||
return payload
|
||||
|
||||
ChatOpenAI._get_request_payload = _patched_get_request_payload
|
||||
ChatOpenAI._moviepilot_interleaved_reasoning_patched = True
|
||||
logger.debug("已修补 langchain-openai interleaved reasoning_content 回传兼容性")
|
||||
|
||||
|
||||
def _patch_openai_responses_instructions_support():
|
||||
"""
|
||||
修补 langchain-openai 在使用 use_responses_api=True 时,
|
||||
@@ -195,6 +295,8 @@ def _patch_openai_responses_instructions_support():
|
||||
logger.debug(f"跳过 langchain-openai instructions 修补:{err}")
|
||||
return
|
||||
|
||||
_patch_openai_interleaved_reasoning_content_support()
|
||||
|
||||
if getattr(ChatOpenAI, "_moviepilot_responses_instructions_patched", False):
|
||||
return
|
||||
|
||||
|
||||
@@ -38,6 +38,9 @@ class _FakeChatDeepSeek:
|
||||
self.model_name = model_name
|
||||
self.model_kwargs = model_kwargs or {}
|
||||
|
||||
def _convert_input(self, input_):
|
||||
return type("_FakeInput", (), {"to_messages": lambda _self: input_})()
|
||||
|
||||
def _get_request_payload(self, input_, *, stop=None, **kwargs):
|
||||
messages = []
|
||||
for message in input_:
|
||||
@@ -62,7 +65,7 @@ class _FakeChatDeepSeek:
|
||||
_ORIGINAL_GET_REQUEST_PAYLOAD = _FakeChatDeepSeek._get_request_payload
|
||||
|
||||
|
||||
sys.modules.pop("app.helper.llm", None)
|
||||
sys.modules.pop("app.agent.llm.helper", None)
|
||||
_stub_module(
|
||||
"app.core.config",
|
||||
settings=ModuleType("settings"),
|
||||
@@ -78,7 +81,7 @@ sys.modules["app.core.config"].settings.PROXY_HOST = None
|
||||
_stub_module("app.log", logger=_DummyLogger())
|
||||
_stub_module("langchain_deepseek", ChatDeepSeek=_FakeChatDeepSeek)
|
||||
|
||||
module_path = Path(__file__).resolve().parents[1] / "app" / "helper" / "llm.py"
|
||||
module_path = Path(__file__).resolve().parents[1] / "app" / "agent" / "llm" / "helper.py"
|
||||
spec = importlib.util.spec_from_file_location("test_llm_module_for_deepseek_compat", module_path)
|
||||
llm_module = importlib.util.module_from_spec(spec)
|
||||
assert spec and spec.loader
|
||||
|
||||
@@ -6,6 +6,8 @@ from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
|
||||
|
||||
def _stub_module(name: str, **attrs):
|
||||
module = sys.modules.get(name)
|
||||
@@ -30,6 +32,92 @@ class _FakeModel:
|
||||
return SimpleNamespace(content=self._content)
|
||||
|
||||
|
||||
def _build_tool_call(name: str = "search"):
|
||||
return [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "tool_call",
|
||||
"name": name,
|
||||
"args": {},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
class _FakeOpenAIInput:
|
||||
def __init__(self, messages):
|
||||
self._messages = messages
|
||||
|
||||
def to_messages(self):
|
||||
return self._messages
|
||||
|
||||
|
||||
class _FakeChatOpenAIForPatch:
|
||||
def __init__(self, **kwargs):
|
||||
self.model = kwargs["model"]
|
||||
self.model_name = kwargs["model"]
|
||||
self.openai_api_base = kwargs.get("base_url")
|
||||
self.profile = None
|
||||
|
||||
def _convert_input(self, input_):
|
||||
return _FakeOpenAIInput(input_)
|
||||
|
||||
def _get_request_payload(self, input_, *, stop=None, **kwargs):
|
||||
messages = []
|
||||
for message in input_:
|
||||
payload_message = {
|
||||
"role": message.type,
|
||||
"content": message.content,
|
||||
}
|
||||
if message.type == "human":
|
||||
payload_message["role"] = "user"
|
||||
elif message.type == "ai":
|
||||
payload_message["role"] = "assistant"
|
||||
tool_calls = getattr(message, "tool_calls", None)
|
||||
if tool_calls:
|
||||
payload_message["tool_calls"] = tool_calls
|
||||
elif message.type == "tool":
|
||||
payload_message["role"] = "tool"
|
||||
payload_message["tool_call_id"] = message.tool_call_id
|
||||
messages.append(payload_message)
|
||||
return {"messages": messages}
|
||||
|
||||
|
||||
def _build_fake_openai_modules(chat_openai_cls=_FakeChatOpenAIForPatch):
|
||||
"""构造最小 langchain_openai stub,避免单测触发真实依赖链。"""
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
|
||||
for attr in (
|
||||
"_moviepilot_interleaved_reasoning_patched",
|
||||
"_moviepilot_responses_instructions_patched",
|
||||
):
|
||||
if hasattr(chat_openai_cls, attr):
|
||||
delattr(chat_openai_cls, attr)
|
||||
|
||||
openai_module = ModuleType("langchain_openai")
|
||||
openai_module.__path__ = []
|
||||
openai_module.ChatOpenAI = chat_openai_cls
|
||||
|
||||
chat_models_module = ModuleType("langchain_openai.chat_models")
|
||||
chat_models_module.__path__ = []
|
||||
|
||||
base_module = ModuleType("langchain_openai.chat_models.base")
|
||||
|
||||
def _convert_dict_to_message(message_dict):
|
||||
return AIMessage(content=message_dict.get("content") or "")
|
||||
|
||||
def _convert_delta_to_message_chunk(delta, default_class):
|
||||
return AIMessageChunk(content=delta.get("content") or "")
|
||||
|
||||
base_module._convert_dict_to_message = _convert_dict_to_message
|
||||
base_module._convert_delta_to_message_chunk = _convert_delta_to_message_chunk
|
||||
|
||||
return {
|
||||
"langchain_openai": openai_module,
|
||||
"langchain_openai.chat_models": chat_models_module,
|
||||
"langchain_openai.chat_models.base": base_module,
|
||||
}, base_module
|
||||
|
||||
|
||||
sys.modules.pop("app.agent.llm.helper", None)
|
||||
_stub_module(
|
||||
"app.core.config",
|
||||
@@ -144,6 +232,97 @@ class LlmHelperTestCallTest(unittest.TestCase):
|
||||
{"thinking": {"type": "disabled"}},
|
||||
)
|
||||
|
||||
def test_openai_compatible_patch_preserves_stream_reasoning_content(self):
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
|
||||
fake_modules, openai_base = _build_fake_openai_modules()
|
||||
with patch.dict(sys.modules, fake_modules):
|
||||
llm_module._patch_openai_interleaved_reasoning_content_support()
|
||||
|
||||
chunk = openai_base._convert_delta_to_message_chunk(
|
||||
{"role": "assistant", "content": "", "reasoning_content": "先调用工具"},
|
||||
AIMessageChunk,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
chunk.additional_kwargs.get("reasoning_content"),
|
||||
"先调用工具",
|
||||
)
|
||||
|
||||
def test_openai_compatible_patch_injects_xiaomi_reasoning_content(self):
|
||||
fake_modules, _ = _build_fake_openai_modules()
|
||||
with patch.dict(sys.modules, fake_modules):
|
||||
llm_module._patch_openai_interleaved_reasoning_content_support()
|
||||
llm = _FakeChatOpenAIForPatch(
|
||||
model="mimo-v2.5-pro",
|
||||
api_key="sk-test",
|
||||
base_url="https://api.xiaomimimo.com/v1",
|
||||
)
|
||||
messages = [
|
||||
HumanMessage(content="天气如何?"),
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=_build_tool_call(),
|
||||
additional_kwargs={"reasoning_content": "先调用天气工具"},
|
||||
),
|
||||
ToolMessage(content="晴天", tool_call_id="call_1"),
|
||||
]
|
||||
|
||||
payload = llm._get_request_payload(messages)
|
||||
|
||||
self.assertEqual(
|
||||
payload["messages"][1]["reasoning_content"],
|
||||
"先调用天气工具",
|
||||
)
|
||||
|
||||
def test_openai_compatible_patch_injects_any_model_with_reasoning_content(self):
|
||||
fake_modules, _ = _build_fake_openai_modules()
|
||||
with patch.dict(sys.modules, fake_modules):
|
||||
llm_module._patch_openai_interleaved_reasoning_content_support()
|
||||
llm = _FakeChatOpenAIForPatch(
|
||||
model="glm-5",
|
||||
api_key="sk-test",
|
||||
base_url="https://open.bigmodel.cn/api/paas/v4",
|
||||
)
|
||||
messages = [
|
||||
HumanMessage(content="天气如何?"),
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=_build_tool_call(),
|
||||
additional_kwargs={"reasoning_content": "先规划工具调用"},
|
||||
),
|
||||
ToolMessage(content="晴天", tool_call_id="call_1"),
|
||||
]
|
||||
|
||||
payload = llm._get_request_payload(messages)
|
||||
|
||||
self.assertEqual(
|
||||
payload["messages"][1]["reasoning_content"],
|
||||
"先规划工具调用",
|
||||
)
|
||||
|
||||
def test_openai_compatible_patch_skips_when_reasoning_content_missing(self):
|
||||
fake_modules, _ = _build_fake_openai_modules()
|
||||
with patch.dict(sys.modules, fake_modules):
|
||||
llm_module._patch_openai_interleaved_reasoning_content_support()
|
||||
llm = _FakeChatOpenAIForPatch(
|
||||
model="gpt-4o-mini",
|
||||
api_key="sk-test",
|
||||
base_url="https://api.openai.com/v1",
|
||||
)
|
||||
messages = [
|
||||
HumanMessage(content="天气如何?"),
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=_build_tool_call(),
|
||||
),
|
||||
ToolMessage(content="晴天", tool_call_id="call_1"),
|
||||
]
|
||||
|
||||
payload = llm._get_request_payload(messages)
|
||||
|
||||
self.assertNotIn("reasoning_content", payload["messages"][1])
|
||||
|
||||
def test_get_llm_uses_deepseek_thinking_level_controls(self):
|
||||
calls = []
|
||||
patch_calls = []
|
||||
@@ -308,6 +487,50 @@ class LlmHelperTestCallTest(unittest.TestCase):
|
||||
"https://updated.example.com/v1",
|
||||
)
|
||||
|
||||
def test_get_llm_keeps_openai_patch_global_without_model_marker(self):
|
||||
class _FakeProviderManager:
|
||||
async def resolve_runtime(self, **kwargs):
|
||||
return {
|
||||
"provider_id": kwargs["provider_id"],
|
||||
"runtime": "openai_compatible",
|
||||
"model_id": kwargs["model"],
|
||||
"api_key": kwargs["api_key"],
|
||||
"base_url": kwargs["base_url"],
|
||||
"default_headers": None,
|
||||
"use_responses_api": None,
|
||||
"model_record": None,
|
||||
"model_metadata": {},
|
||||
}
|
||||
|
||||
provider_module = ModuleType("app.agent.llm.provider")
|
||||
provider_module.LLMProviderManager = _FakeProviderManager
|
||||
fake_openai_modules, _ = _build_fake_openai_modules()
|
||||
|
||||
with patch.dict(
|
||||
sys.modules,
|
||||
{
|
||||
"app.agent.llm.provider": provider_module,
|
||||
**fake_openai_modules,
|
||||
},
|
||||
):
|
||||
created = asyncio.run(
|
||||
llm_module.LLMHelper.get_llm(
|
||||
provider="openai",
|
||||
model="mimo-v2.5-pro",
|
||||
api_key="sk-test",
|
||||
base_url="https://api.xiaomimimo.com/v1",
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
getattr(
|
||||
sys.modules["langchain_openai"].ChatOpenAI,
|
||||
"_moviepilot_interleaved_reasoning_patched",
|
||||
False,
|
||||
)
|
||||
)
|
||||
|
||||
self.assertFalse(hasattr(created, "_moviepilot_interleaved_reasoning_field"))
|
||||
|
||||
def test_get_llm_maps_unified_max_to_openai_xhigh(self):
|
||||
calls = []
|
||||
|
||||
|
||||
Reference in New Issue
Block a user