diff --git a/app/helper/llm.py b/app/helper/llm.py index d5edf0e6..b78369a8 100644 --- a/app/helper/llm.py +++ b/app/helper/llm.py @@ -3,6 +3,7 @@ import asyncio import inspect import time +from functools import wraps from typing import Any, List from app.core.config import settings @@ -74,6 +75,132 @@ def _get_httpx_proxy_key() -> str: return "proxies" +def _deepseek_thinking_toggle(extra_body: Any) -> bool | None: + """ + 解析 DeepSeek extra_body 中显式传入的 thinking 开关。 + """ + if not isinstance(extra_body, dict): + return None + + thinking = extra_body.get("thinking") + if not isinstance(thinking, dict): + return None + + thinking_type = str(thinking.get("type") or "").strip().lower() + if thinking_type == "enabled": + return True + if thinking_type == "disabled": + return False + return None + + +def _is_deepseek_thinking_enabled(model_name: str | None, extra_body: Any) -> bool: + """ + 判断本次 DeepSeek 调用是否处于 thinking mode。 + """ + explicit_toggle = _deepseek_thinking_toggle(extra_body) + if explicit_toggle is not None: + return explicit_toggle + + normalized_model_name = str(model_name or "").strip().lower() + if normalized_model_name == "deepseek-reasoner": + return True + if normalized_model_name.startswith("deepseek-v4-"): + # DeepSeek V4 默认启用 thinking mode,除非显式关闭。 + return True + return False + + +def _extract_input_messages(input_: Any) -> list[Any]: + """ + 将 chat model 输入还原为原始 BaseMessage 序列。 + """ + try: + from langchain_core.messages import convert_to_messages + + return list(convert_to_messages(input_)) + except Exception: + if isinstance(input_, list): + return list(input_) + return [] + + +def _patch_deepseek_reasoning_content_support(): + """ + 修补 langchain-deepseek 在 tool-call 场景下遗漏 reasoning_content 回传的问题。 + + DeepSeek thinking mode 要求:若 assistant 历史消息包含 tool_calls, + 后续请求中必须带回该条消息的顶层 reasoning_content。 + 某些 langchain-deepseek 版本虽然能从响应中拿到 reasoning_content, + 但不会在重放消息历史时写回请求载荷,导致 400。 + """ + try: + from langchain_deepseek import ChatDeepSeek + except Exception as err: + logger.debug(f"跳过 langchain-deepseek reasoning_content 修补:{err}") + return + + if getattr(ChatDeepSeek, "_moviepilot_reasoning_content_patched", False): + return + + original_get_request_payload = getattr(ChatDeepSeek, "_get_request_payload", None) + if not callable(original_get_request_payload): + logger.warning("langchain-deepseek 缺少 _get_request_payload,无法修补 reasoning_content") + return + + @wraps(original_get_request_payload) + def _patched_get_request_payload(self, input_, *, stop=None, **kwargs): + payload = original_get_request_payload(self, input_, stop=stop, **kwargs) + + try: + original_messages = _extract_input_messages(input_) + payload_messages = payload.get("messages") or [] + model_name = getattr(self, "model_name", None) or getattr( + self, "model", None + ) + extra_body = kwargs.get("extra_body") + if extra_body is None: + extra_body = getattr(self, "extra_body", None) + if extra_body is None: + extra_body = getattr(self, "model_kwargs", {}).get("extra_body") + + if not _is_deepseek_thinking_enabled(model_name, extra_body): + return payload + + for index, message in enumerate(payload_messages): + if not isinstance(message, dict): + continue + if message.get("role") != "assistant": + continue + if not message.get("tool_calls"): + continue + if message.get("reasoning_content") is not None: + continue + + reasoning_content = "" + if index < len(original_messages): + additional_kwargs = ( + getattr(original_messages[index], "additional_kwargs", None) + or {} + ) + if isinstance(additional_kwargs, dict): + captured_reasoning = additional_kwargs.get("reasoning_content") + if isinstance(captured_reasoning, str): + reasoning_content = captured_reasoning + + message["reasoning_content"] = reasoning_content + except Exception as err: + logger.warning( + f"修补 langchain-deepseek reasoning_content 请求载荷时失败,将继续使用原始载荷: {err}" + ) + + return payload + + ChatDeepSeek._get_request_payload = _patched_get_request_payload + ChatDeepSeek._moviepilot_reasoning_content_patched = True + logger.debug("已修补 langchain-deepseek thinking tool-call 的 reasoning_content 回传兼容性") + + class LLMHelper: """LLM模型相关辅助功能""" @@ -437,6 +564,7 @@ class LLMHelper: elif provider_name == "deepseek": from langchain_deepseek import ChatDeepSeek + _patch_deepseek_reasoning_content_support() model = ChatDeepSeek( model=model_name, api_key=api_key_value, diff --git a/tests/test_langchain_deepseek_compat.py b/tests/test_langchain_deepseek_compat.py new file mode 100644 index 00000000..b22d4d8a --- /dev/null +++ b/tests/test_langchain_deepseek_compat.py @@ -0,0 +1,146 @@ +import importlib.util +import sys +import unittest +from pathlib import Path +from types import ModuleType + +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage + + +def _stub_module(name: str, **attrs): + module = sys.modules.get(name) + if module is None: + module = ModuleType(name) + sys.modules[name] = module + for key, value in attrs.items(): + setattr(module, key, value) + return module + + +class _DummyLogger: + def __getattr__(self, _name): + return lambda *args, **kwargs: None + + +def _build_tool_call(name: str = "search", arguments: str = "{}"): + return [ + { + "id": "call_1", + "type": "tool_call", + "name": name, + "args": {}, + } + ] + + +class _FakeChatDeepSeek: + def __init__(self, model_name: str, model_kwargs: dict | None = None): + self.model_name = model_name + self.model_kwargs = model_kwargs or {} + + def _get_request_payload(self, input_, *, stop=None, **kwargs): + messages = [] + for message in input_: + payload_message = { + "role": message.type, + "content": message.content, + } + if message.type == "human": + payload_message["role"] = "user" + elif message.type == "ai": + payload_message["role"] = "assistant" + tool_calls = getattr(message, "tool_calls", None) + if tool_calls: + payload_message["tool_calls"] = tool_calls + elif message.type == "tool": + payload_message["role"] = "tool" + payload_message["tool_call_id"] = message.tool_call_id + messages.append(payload_message) + return {"messages": messages} + + +_ORIGINAL_GET_REQUEST_PAYLOAD = _FakeChatDeepSeek._get_request_payload + + +sys.modules.pop("app.helper.llm", None) +_stub_module( + "app.core.config", + settings=ModuleType("settings"), +) +sys.modules["app.core.config"].settings.LLM_PROVIDER = "deepseek" +sys.modules["app.core.config"].settings.LLM_MODEL = "deepseek-v4-pro" +sys.modules["app.core.config"].settings.LLM_API_KEY = "sk-test" +sys.modules["app.core.config"].settings.LLM_BASE_URL = "https://api.deepseek.com" +sys.modules["app.core.config"].settings.LLM_THINKING_LEVEL = None +sys.modules["app.core.config"].settings.LLM_DISABLE_THINKING = False +sys.modules["app.core.config"].settings.LLM_REASONING_EFFORT = None +sys.modules["app.core.config"].settings.LLM_TEMPERATURE = 0.1 +sys.modules["app.core.config"].settings.LLM_MAX_CONTEXT_TOKENS = 64 +sys.modules["app.core.config"].settings.PROXY_HOST = None +_stub_module("app.log", logger=_DummyLogger()) +_stub_module("langchain_deepseek", ChatDeepSeek=_FakeChatDeepSeek) + +module_path = Path(__file__).resolve().parents[1] / "app" / "helper" / "llm.py" +spec = importlib.util.spec_from_file_location("test_llm_module_for_deepseek_compat", module_path) +llm_module = importlib.util.module_from_spec(spec) +assert spec and spec.loader +spec.loader.exec_module(llm_module) + + +class DeepSeekCompatPatchTest(unittest.TestCase): + def setUp(self): + _FakeChatDeepSeek._get_request_payload = _ORIGINAL_GET_REQUEST_PAYLOAD + if hasattr(_FakeChatDeepSeek, "_moviepilot_reasoning_content_patched"): + delattr(_FakeChatDeepSeek, "_moviepilot_reasoning_content_patched") + llm_module._patch_deepseek_reasoning_content_support() + + def test_injects_reasoning_content_for_assistant_tool_calls(self): + llm = _FakeChatDeepSeek("deepseek-v4-pro") + messages = [ + HumanMessage(content="天气如何?"), + AIMessage( + content="", + tool_calls=_build_tool_call(), + additional_kwargs={"reasoning_content": "先调用天气工具"}, + ), + ToolMessage(content="晴天", tool_call_id="call_1"), + ] + + payload = llm._get_request_payload(messages) + + self.assertEqual( + payload["messages"][1]["reasoning_content"], + "先调用天气工具", + ) + + def test_falls_back_to_empty_reasoning_content_when_missing(self): + llm = _FakeChatDeepSeek("deepseek-v4-flash") + messages = [ + HumanMessage(content="天气如何?"), + AIMessage(content="", tool_calls=_build_tool_call()), + ToolMessage(content="晴天", tool_call_id="call_1"), + ] + + payload = llm._get_request_payload(messages) + + self.assertIn("reasoning_content", payload["messages"][1]) + self.assertEqual(payload["messages"][1]["reasoning_content"], "") + + def test_skips_injection_when_thinking_is_disabled(self): + llm = _FakeChatDeepSeek( + "deepseek-v4-pro", + model_kwargs={"extra_body": {"thinking": {"type": "disabled"}}}, + ) + messages = [ + HumanMessage(content="天气如何?"), + AIMessage( + content="", + tool_calls=_build_tool_call(), + additional_kwargs={"reasoning_content": "先调用天气工具"}, + ), + ToolMessage(content="晴天", tool_call_id="call_1"), + ] + + payload = llm._get_request_payload(messages) + + self.assertNotIn("reasoning_content", payload["messages"][1]) diff --git a/tests/test_llm_helper_testcall.py b/tests/test_llm_helper_testcall.py index 0873b0c9..7b6000ca 100644 --- a/tests/test_llm_helper_testcall.py +++ b/tests/test_llm_helper_testcall.py @@ -144,6 +144,7 @@ class LlmHelperTestCallTest(unittest.TestCase): def test_get_llm_uses_deepseek_thinking_level_controls(self): calls = [] + patch_calls = [] class _FakeChatDeepSeek: def __init__(self, **kwargs): @@ -154,6 +155,10 @@ class LlmHelperTestCallTest(unittest.TestCase): with patch.dict( sys.modules, {"langchain_deepseek": SimpleNamespace(ChatDeepSeek=_FakeChatDeepSeek)}, + ), patch.object( + llm_module, + "_patch_deepseek_reasoning_content_support", + side_effect=lambda: patch_calls.append(True), ): llm_module.LLMHelper.get_llm( provider="deepseek", @@ -168,11 +173,13 @@ class LlmHelperTestCallTest(unittest.TestCase): calls[0].get("extra_body"), {"thinking": {"type": "enabled"}}, ) + self.assertEqual(patch_calls, [True]) self.assertEqual(calls[0].get("reasoning_effort"), "max") self.assertEqual(calls[0].get("api_base"), "https://api.deepseek.com") def test_get_llm_disables_deepseek_thinking_via_thinking_level(self): calls = [] + patch_calls = [] class _FakeChatDeepSeek: def __init__(self, **kwargs): @@ -183,6 +190,10 @@ class LlmHelperTestCallTest(unittest.TestCase): with patch.dict( sys.modules, {"langchain_deepseek": SimpleNamespace(ChatDeepSeek=_FakeChatDeepSeek)}, + ), patch.object( + llm_module, + "_patch_deepseek_reasoning_content_support", + side_effect=lambda: patch_calls.append(True), ): llm_module.LLMHelper.get_llm( provider="deepseek", @@ -197,6 +208,7 @@ class LlmHelperTestCallTest(unittest.TestCase): calls[0].get("extra_body"), {"thinking": {"type": "disabled"}}, ) + self.assertEqual(patch_calls, [True]) self.assertIsNone(calls[0].get("reasoning_effort")) self.assertEqual(calls[0].get("api_base"), "https://proxy.example.com")