-VLGK.png)
LangGraph
LangGraph 框架
什么是 LangGraph?
LangGraph 是 LangChain 发布的一个多智能体框架。通过建立在 LangChain 之上,LangGraph 使开发人员可以轻松创建强大的智能体。
LangGraph 的核心功能
- 支持循环流: LangGraph 允许定义包含循环的流程,这对于大多数代理架构至关重要。这使得 LangGraph 更适合构建需要记忆和上下文推理的应用程序。
- 状态管理: LangGraph 提供了状态管理功能,允许代理在多个步骤之间存储和检索信息。这对于构建需要跟踪对话状态或游戏状态的应用程序至关重要。
- 多参与者支持: LangGraph 支持多个代理相互交互,以实现更复杂的工作流程。这使得 LangGraph 非常适合构建需要协作或竞争的代理应用程序。
- 可扩展性: LangGraph 可以扩展到生产环境,以支持大规模应用程序。
LangGraph 和 LangChain 的区别
LangGraph 和 LangChain 是两个相关但不同的工具,都来自 LangChain 生态系统。
LangChain
LangChain 是一个用于构建大语言模型应用程序的框架
- 线性工作流:主要支持顺序执行的链式操作
- 组件库:提供丰富的预构建组件,如提示模板、向量存储、检索器等
- 简单集成:易于快速原型开发和简单的 LLM 应用
- 抽象层:为不同的 LLM 提供统一接口
LangGraph
LangGraph 是 LangChain 团队开发的更高级工具,专门用于构建复杂的多智能体系统:
- 图状工作流:支持复杂的分支、循环和条件逻辑
- 状态管理:内置强大的状态管理机制
- 多智能体协作:原生支持多个 AI 智能体之间的交互
- 复杂决策流:可以根据条件动态选择执行路径
- 持久化:支持长时间运行的工作流和状态持久化
主要区别
复杂性处理:
- LangChain 适合简单到中等复杂度的应用
- LangGraph 专为复杂的多步骤、多智能体场景设计
工作流结构:
- LangChain 主要是链式(Chain)结构
- LangGraph 是图状(Graph)结构,支持任意的节点连接
LangGraph 安装和使用
pip install -U "langgraph>=0.6.1"
简单 Agent
from langgraph.prebuilt import create_react_agent
from langchain.chat_models import init_chat_model
from langchain_community.tools import TavilySearchResults
from langchain_core.tools import tool
from datetime import datetime
import os
from dotenv import load_dotenv
load_dotenv()
llm = init_chat_model(api_key=os.getenv("DASHSCOPE_API_KEY"),
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
model_provider="openai",
model='qwen-plus-2025-01-25')
# 定义工具函数
@tool
def search_web(query: str) -> str:
"""搜索网络信息的工具"""
t_search = TavilySearchResults()
return t_search.run(query)
@tool
def get_data_tool():
"""获取目前日期的工具"""
return datetime.now().date()
tools = [search_web, get_data_tool]
system_prompt = """你是一个智能助手。你有以下工具可以使用:
1. search_web: 用于搜索互联网获取最新信息,特别是产品价格、新闻、实时数据等
2. get_current_date: 获取今天的日期
3. get_current_time: 获取当前的日期和时间
重要规则:
- 当用户询问产品价格、最新信息、新闻等需要实时数据的问题时,必须使用search_web工具
- 当用户询问时间或日期时,使用相应的时间工具
- 如果你的知识库中没有准确或最新的信息,应该使用搜索工具
- 优先使用工具获取准确信息,而不是依赖可能过时的训练数据
请根据用户问题选择合适的工具来获取准确答案。"""
agent = create_react_agent(model=llm,
tools=tools,
prompt=system_prompt
)
response = agent.invoke({"messages": [{"role": "user", "content": "小米yu7价格"}]})
print(response["messages"][3].content)
LangGraph 基础知识(核心概念)
Graph(流程图)
LangGraph 的核心是将代理工作流程建模为图表。您可以使用三个关键组件来定义代理的行为:
State
:表示应用程序当前快照的共享数据结构。它可以是任何 Python 类型,但通常是TypedDict
或 PydanticBaseModel
。Nodes
:用于编码代理逻辑的 Python 函数。它们接收当前值State
作为输入,执行一些计算或副作用,并返回更新后的State
。Edges
:根据当前条件确定下一步执行哪个操作的 Python 函数State
。它们可以是条件分支或固定转换。
通过组合 Nodes
和 Edges
,您可以创建复杂的循环工作流,使其 State
随时间推移而演化。然而,真正的强大之处在于 LangGraph 对 的管理方式 State
。需要强调的是:Nodes
和 Edges
只不过是 Python 函数而已——它们可以包含 LLM 代码,也可以只是经典的 Python 代码。
简而言之:节点负责工作,边负责告诉下一步做什么。
状态图
StateGraph
类是主要使用的图形类。这是由用户定义的 State
对象参数化的。
通俗来说,它是一张流程图 + 状态管理系统,定义了:
- 哪些步骤(节点)要执行?
- 每一步之间怎么跳转(边)?
- 整个流程中数据状态如何流动和更新(状态)?
为什么叫“状态图”而不是“流程图”?
LangGraph 不只是流程控制,还强调:
- 每个节点执行前、执行后都可以访问和修改状态(state)
- 状态是图的“血液”,在节点之间流动
- 节点的跳转可以依据状态来判断(如条件跳转)
所以叫做 State Graph(有状态的流程图),而不是“静态流程图”。
from typing import TypedDict**
**from langgraph.graph import StateGraph**
**# 定义状态结构**
**class MyState(TypedDict):**
** question: str**
** answer: str**
**# 定义节点函数**
**def search_node(state):**
** return {"answer": "这是答案"}**
**# 创建状态图**
**builder = StateGraph(state_schema=MyState)**
**# 添加一个节点**
**builder.add_node("search", search_node)**
**# 第一个要调用的节点**
**builder.set_entry_point("search")**
**# 要构建图,首先要定义状态,然后添加节点和边,最后进行编译,会进行基本的检查**
**graph = builder.compile()**
**# 执行图**
**result = graph.invoke({"question": "什么是状态图?"})**
**print(result["answer"]) # 输出:这是答案
State(状态)
在使用 LangGraph 构建流程图之前,第一件事就是定义图的状态 State
。这是整个图运行中用于共享和传递信息的核心机制。
什么是 State?
LangGraph 中的 State 是图中所有节点(Node)之间传递数据的模式结构,可以类比为一个共享的上下文字典,它包含输入、输出、中间变量等。
定义 State 时,需要包含两个部分:
- Schema(模式):指定 State 的字段结构(可以用
TypedDict
或Pydantic
)
# langgraph推荐使用TypedDict
"""
1. TypedDict 是标准库的一部分(来自 typing 模块),零依赖,零性能开销而 Pydantic 会在每一步创建模型实例,会增加运行时负担
2. LangGraph 中的 State 实质就是一个字典(dict),而 TypedDict 就是“有类型注解的 dict”,与 LangGraph 的执行机制无缝对接,而 Pydantic 是类结构,需要 .dict() 转换,略显多余
"""
from typing import TypedDict
class State1(TypedDict):
user_input: str
# 使用 pydantic 可以进行参数校验和提供默认值
from pydantic import BaseModel
class State2(BaseModel):
question: str
result: str = ""
- **多个模式(Multiple Schemas):**在大多数情况下,LangGraph 使用一个统一的 State 模式。但你也可以设置“输入模式”和“输出模式”分开
- 输入模式:接收用户输入的字段(如
question
) - 输出模式:只保留最终输出的字段(如
final_answer
)
- 输入模式:接收用户输入的字段(如
from typing import TypedDict
from langgraph.graph import StateGraph
# 1. 定义输入、输出、图内部的状态结构
# 输入字段:用户的问题
class InputState(TypedDict):
question: str
# 中间状态:包括中间结果
class InternalState(TypedDict):
question: str
search_result: str
final_answer: str
# 输出字段:只想返回最终答案
class OutputState(TypedDict):
final_answer: str
# 2. 定义节点函数(中间节点用中间字段)
def search_node(state: InternalState) -> dict:
return {"search_result": f"搜索了:{state['question']}"}
def answer_node(state: InternalState) -> dict:
return {"final_answer": f"根据搜索结果:{state['search_result']},这是答案"}
# 3. 创建 StateGraph,显式指定输入/输出 Schema
builder = StateGraph(state_schema=InternalState,
input_schema=InputState,
output_schema=OutputState)
# 4. 添加节点
builder.add_node("search", search_node)
builder.add_node("answer", answer_node)
# 5. 配置流程
builder.set_entry_point("search")
builder.add_edge("search", "answer")
# 6. 编译并执行图
app = builder.compile()
result = app.invoke({"question": "什么是LangGraph?"})
print(result) # {'final_answer': '根据搜索结果:搜索了:什么是LangGraph?,这是答案'}
- Reducer(归并函数):在 LangGraph 中,所有节点返回的都是“局部更新结果”,Reducer 是用于合并多个节点输出更新的机制。 将每个节点返回的“局部状态更新”统一合并进全局的 State。
from typing import Annotated
from typing_extensions import TypedDict
from operator import add
class State(TypedDict):
foo: int
bar: Annotated[list[str], add] # 每条消息是 {role, content},会自动追加到列表末尾
使用图形状态中的消息
为什么要使用消息?
大多数现代 LLM 提供商都提供聊天模型接口,接受消息列表作为输入。LangChain ChatModel
尤其接受对象列表 Message
作为输入。这些消息有多种形式,例如 HumanMessage
(用户输入)或 AIMessage
(LLM 响应)。
在图表中使用消息
在许多情况下,将之前的对话历史记录以消息列表的形式存储在图状态中会很有帮助。为此,我们可以向图状态添加一个键(通道),该键存储 Message
对象列表,并使用 Reducer 函数对其进行注释。Reducer 函数对于指示图如何 Message
在每次状态更新(例如,当节点发送更新时)时更新状态中的对象列表至关重要。如果您未指定 Reducer,则每次状态更新都会用最新提供的值覆盖消息列表。如果您只想将消息附加到现有列表中,可以使用 operator.add
。
operator 是 Python 的一个内置模块,**把常见的运算符(如 +、-、==、getitem 等)变成了函数**,方便函数式编程和高阶函数使用。
from typing import TypedDict
from langgraph.graph import StateGraph
from typing import Annotated
import operator
# 定义状态结构 如果定义的是list[dict],会覆盖之前的数据
class ChatState(TypedDict):
messages: Annotated[list, operator.add] # 每条消息是 {role, content},会自动追加到列表末尾
# 节点函数:添加用户问题
def user_input_node(state: ChatState) -> dict:
user_msg = {"role": "user", "content": "什么是LangGraph?"}
return {"messages": [user_msg]}
# 节点函数:添加助手回复
def assistant_node(state: ChatState) -> dict:
reply = {"role": "assistant", "content": "LangGraph 是一个有状态的图编排框架。"}
return {"messages": [reply]}
# 构建状态图
builder = StateGraph(state_schema=ChatState)
builder.add_node("user_input", user_input_node)
builder.add_node("assistant_reply", assistant_node)
builder.set_entry_point("user_input")
builder.add_edge("user_input", "assistant_reply")
graph = builder.compile()
result = graph.invoke({"messages": []})
print(result["messages"])
有场景可能还需要手动更新图状态中的消息(例如,人机交互)。
如果您使用 operator.add
,您发送到图的手动状态更新将被附加到现有消息列表中,而不是更新现有消息。
为了避免这种情况,您需要一个能够跟踪消息 ID 并在更新时覆盖现有消息的 Reducer。
为此,您可以使用预构建 add_messages
函数。
对于新消息,它只会附加到现有列表中,但它也会正确处理现有消息的更新。
from langchain_core.messages import AnyMessage
from langgraph.graph.message import add_messages
from typing import Annotated
from typing_extensions import TypedDict
class GraphState(TypedDict):
messages: Annotated[list[AnyMessage], add_messages]
MessagesState
由于在状态中包含消息列表非常常见,因此存在一个名为 MessagesState
的预建状态,它使使用消息变得非常简单。该状态 MessagesState
使用单个键定义 messages
,该键是对象列表 AnyMessage
并使用 add_messages
。通常,需要跟踪的状态不仅仅是消息,因此我们看到人们将这个状态子类化并添加更多字段,例如:
from langgraph.graph import MessagesState
# 和上述代码不同会在State类中自动维护一个messages 字段,不需要显示创建
class State(MessagesState):
documents: list[str]
Node(节点)
节点(Nodes)是图中执行逻辑的基本单位。每个节点表示一个函数步骤、处理阶段或子逻辑流程,多个节点通过边连接成有向图,组成一个完整的有状态计算流程。
LangGraph 中的节点就是你定义的一个函数(或 Runnable 对象),用于接收状态、执行逻辑,并返回更新后的状态
def my_node(state: dict) -> dict:
# 处理输入状态,并返回更新字段
return {"new_key": "new_value"}
# LangGraph 会自动用 reducer 把这些更新合并进全局状态。
START
节点
Node START
是一个特殊节点,表示将用户输入发送到图的节点。引用此节点的主要目的是确定应首先调用哪些节点。
(API 参考:START)
from langgraph.graph import START
graph.add_edge(START, "node_a")
END
节点
Node END
是一个特殊节点,表示终端节点。当需要指示哪些边在完成后没有操作时,可以引用此节点。
from langgraph.graph import END
graph.add_edge("node_a", END)
并行运行节点
import operator
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
class State(TypedDict):
# The operator.add reducer fn makes this append-only
aggregate: Annotated[list, operator.add]
def a(state: State):
print(f'Adding "A" to {state["aggregate"]}')
return {"aggregate": ["A"]}
def b(state: State):
print(f'Adding "B" to {state["aggregate"]}')
return {"aggregate": ["B"]}
def c(state: State):
print(f'Adding "C" to {state["aggregate"]}')
return {"aggregate": ["C"]}
def d(state: State):
print(f'Adding "D" to {state["aggregate"]}')
return {"aggregate": ["D"]}
builder = StateGraph(State)
builder.add_node(a)
builder.add_node(b)
builder.add_node(c)
builder.add_node(d)
builder.add_edge(START, "a")
builder.add_edge("a", "b")
builder.add_edge("a", "c")
builder.add_edge("b", "d")
builder.add_edge("c", "d")
builder.add_edge("d", END)
graph = builder.compile()
print(graph.invoke({"aggregate": ["start"]}))
Edge(边/跳转)
Edge(边) 是连接节点的通道,表示图中节点之间的执行跳转关系。你可以把它理解为「节点执行完之后,下一步去哪,是构成 LangGraph 流程图的核心。
[!TIP]
Edge 是 LangGraph 中连接两个节点的“执行路径”,控制流程的走向。
- 普通边:直接从一个节点到下一个节点。
graph.add_edge("节点A", "节点B")
- 条件边:调用一个函数来确定下一步要去哪个节点。
from typing import TypedDict
from langgraph.graph import StateGraph, END
class MyState(TypedDict):
type: str
result: str
def judge_node(state: MyState):
"""节点函数:可以做一些预处理"""
return state # 保持状态不变,只是路由
def route_condition(state: MyState):
"""条件函数:只负责路由决策"""
if state["type"] == "a":
return "a"
elif state["type"] == "b":
return "b"
else:
return "default"
def node_a(state):
return {"result": "走了 A 分支"}
def node_b(state):
return {"result": "走了 B 分支"}
def node_default(state):
return {"result": "走了默认分支"}
# 构建图
graph = StateGraph(state_schema=MyState)
# 定义节点
graph.add_node("judge_node", judge_node)
graph.add_node("a", node_a)
graph.add_node("b", node_b)
graph.add_node("default", node_default)
# 定义开始节点
graph.set_entry_point("judge_node")
# 使用不同的函数作为条件函数
graph.add_conditional_edges("judge_node", route_condition, {
"a": "a",
"b": "b",
"default": "default"
})
# 添加结束边
graph.add_edge("a", END)
graph.add_edge("b", END)
graph.add_edge("default", END)
app = graph.compile()
# 测试
print("测试 A:", app.invoke({"type": "a", "result": ""}))
- 入口点:当图开始运行时首先运行的第一个(些)节点。
from langgraph.graph import START
graph.add_edge(START, "node_a")
- 条件入口点:调用一个函数来确定当用户输入到达时首先调用哪个节点。
from typing import TypedDict
from langgraph.graph import StateGraph, END
class MyState(TypedDict):
user_type: str # "vip", "normal", "guest"
message: str
result: str
# 定义不同的处理节点
def vip_service(state):
"""VIP 用户服务"""
return {"result": f"VIP专享服务: {state['message']}"}
def normal_service(state):
"""普通用户服务"""
return {"result": f"标准服务: {state['message']}"}
def guest_service(state):
"""游客服务"""
return {"result": f"游客服务(功能受限): {state['message']}"}
# 条件入口点函数
def route_by_user_type(state):
"""根据用户类型路由到不同的服务"""
user_type = state["user_type"]
if user_type == "vip":
return "vip_service"
elif user_type == "normal":
return "normal_service"
else:
return "guest_service"
# 构建图
workflow = StateGraph(state_schema=MyState)
# 添加节点
workflow.add_node("vip_service", vip_service)
workflow.add_node("normal_service", normal_service)
workflow.add_node("guest_service", guest_service)
# 设置条件入口点 - 关键部分!
workflow.set_conditional_entry_point(
route_by_user_type, # 条件函数
{
"vip_service": "vip_service",
"normal_service": "normal_service",
"guest_service": "guest_service"
}
)
# 添加结束边
workflow.add_edge("vip_service", END)
workflow.add_edge("normal_service", END)
workflow.add_edge("guest_service", END)
# 编译图
app = workflow.compile()
# 测试 VIP 用户
result1 = app.invoke({
"user_type": "vip",
"message": "我要退款",
"result": ""
})
print("VIP用户:", result1)
Send 发送
默认情况下,Nodes
和 Edges
是提前定义的,并在相同的共享状态下运行。但是,在某些情况下,确切的边无法提前知道,并且可能希望同时存在 State
的不同版本。
send("节点名", 更新数据)
的作用是:告诉 LangGraph 把这部分更新数据发送给指定节点,让它继续执行。
主要用途
- 条件路由:根据某些条件将消息发送到不同的节点
- 并行处理:同时向多个节点发送消息
- 动态工作流:根据运行时状态决定消息的发送目标
from langgraph.graph import StateGraph, END
from langgraph.constants import Send
from typing import TypedDict
class MyState(TypedDict):
messages: list
result: str
def process_input(state: MyState) -> Send:
"""处理输入并决定发送到哪个节点"""
if len(state["messages"]) > 5:
return Send("complex_processor", state)
else:
return Send("simple_processor", state)
def simple_processor(state: MyState) -> MyState:
"""简单处理器"""
return {"messages": state["messages"], "result": "simple"}
def complex_processor(state: MyState) -> MyState:
"""复杂处理器"""
return {"messages": state["messages"], "result": "complex"}
# 构建图
graph = StateGraph(state_schema=MyState)
graph.add_node("input", process_input)
graph.add_node("simple_processor", simple_processor)
graph.add_node("complex_processor", complex_processor)
graph.set_entry_point("input")
graph.add_edge("simple_processor", END)
graph.add_edge("complex_processor", END)
app = graph.compile()
massages = []
for i in range(3):
massages.append(i)
print(app.invoke({"messages": massages}))
和 add_conditional_edges 有什么区别呢?
- add_conditional_edges:外部决策。路由逻辑在一个单独的函数中,它在节点运行之后被调用,根据图的当前状态来决定下一步去哪里。
- Send:内部决策。路由逻辑在节点自身的函数体内,节点在运行时直接、显式地决定将其结果发送到哪个特定的节点
两者通常在需要并发处理的时候配合使用
Map-Reduce 模式
**Map-Reduce **是一种经典的并行计算模式,特别适合处理大规模数据。
Map-Reduce 将复杂的数据处理任务分解为两个阶段:
- Map 阶段:将大任务分解为多个小任务,并行处理
- Reduce 阶段:将所有小任务的结果合并成最终结果
"""_
_LangGraph Map-Reduce 简单案例:数字求和_
_把一堆数字分给多个worker算平方,然后把结果加起来_
_"""_
_from typing import Annotated_
_import operator_
_from langgraph.graph import StateGraph, START, END_
_from langgraph.constants import Send_
_from typing import TypedDict, List_
_# 状态定义_
_class State(TypedDict):_
_ numbers: List[int] # 输入的数字_
_ results: Annotated[list[int], operator.add] # worker的结果_
_ final_sum: int # 最终求和_
_# 1. Map阶段:分发数字_
_def split_numbers(state: State):_
_ """把数字分发给不同的worker"""_
_ numbers = state["numbers"]_
_ print(f"📦 分发数字: {numbers}")_
_ # 每个数字发给一个worker_
_ return [Send("worker", {"number": num}) for num in numbers]_
_# 2. Worker阶段:计算平方_
_def calculate_square(state: State):_
_ """每个worker计算一个数字的平方"""_
_ number = state["number"]_
_ square = number * number_
_ print(f"⚡ Worker: {number}² = {square}")_
_ return {"results": [square]}_
_# 3. Reduce阶段:求和_
_def sum_results(state: State):_
_ """把所有结果加起来"""_
_ results = state.get("results", [])_
_ total = sum(results)_
_ print(f"📊 求和: {results} = {total}")_
_ return {"final_sum": total}_
_# 构建图_
_def create_simple_graph():_
_ graph = StateGraph(State)_
_ # 添加节点_
_ graph.add_node("splitter", lambda s: s) # 分发器_
_ graph.add_node("worker", calculate_square) # 工作节点_
_ graph.add_node("summer", sum_results) # 求和器_
_ # 连接节点_
_ graph.add_edge(START, "splitter")_
_ graph.add_conditional_edges("splitter", split_numbers, ["worker"]) # Map阶段_
_ graph.add_edge("worker", "summer") # Worker完成后求和_
_ graph.add_edge("summer", END)_
_ return graph.compile()_
_# 运行例子_
_def run_example():_
_ app = create_simple_graph()_
_ # 测试数据_
_ initial_state = {_
_ "numbers": [1, 2, 3, 4, 5],_
_ "results": [],_
_ "final_sum": 0_
_ }_
_ print("🚀 开始计算...")_
_ print("任务:计算每个数字的平方,然后求和")_
_ print()_
_ # 运行_
_ result = app.invoke(initial_state)_
_ print(result)_
_if __name__ == "__main__":_
_ run_example()
Command
命令
将控制流(边)和状态更新(节点)结合在一起可能非常有用。例如,您可能希望在同一个节点中既执行状态更新,又决定接下来要去哪个节点。
[!TIP]
在节点函数中返回时Command
,必须添加返回类型注释,其中包含节点路由到的节点名称列表,例如Command[Literal["my_other_node"]]
。这对于图形渲染是必需的,它告诉 LangGraphmy_node
可以导航到my_other_node
。
from typing import TypedDict
from langgraph.graph import StateGraph, END
from langgraph.types import Command, Literal, Send
class MyState(TypedDict):
type: str
text: str
result: str
def judge_node(state: MyState) -> Command[Literal["a", "b", "default"]]:
"""条件函数:使用Command进行路由和状态更新"""
if state["type"] == "a":
return Command(update={"text": "走了 A 分支"}, goto="a")
elif state["type"] == "b":
# Command只是更新当前节点结束的状态
# return Command(update={"text": "走了 B 分支"}, goto="b")
return Send("b", state)
else:
return Command(update={"text": "走了默认分支"}, goto="default")
def node_a(state):
return {"result": f"A节点处理: {state['text']}"}
def node_b(state):
return {"result": f"B节点处理: {state['text']}"}
def node_default(state):
return {"result": f"默认节点处理: {state['text']}"}
# 构建图
graph = StateGraph(state_schema=MyState)
graph.add_node("judge_node", judge_node)
graph.add_node("a", node_a)
graph.add_node("b", node_b)
graph.add_node("default", node_default)
graph.set_entry_point("judge_node")
# 添加结束边
graph.add_edge("a", END)
graph.add_edge("b", END)
graph.add_edge("default", END)
app = graph.compile()
# 测试
print("测试 A:", app.invoke({"type": "a", "text": "", "result": ""}))
print("测试 B:", app.invoke({"type": "b", "text": "", "result": ""}))
print("测试其他:", app.invoke({"type": "default", "text": "", "result": ""}))
什么时候应该使用命令而不是条件边?
Command
当需要同时更新图形状态和路由到其他节点时使用。例如,在实现多代理切换时,需要路由到其他代理并向该代理传递一些信息。在进行 command 更新状态的时候,更新的属性必须符合初始化状态的内容
使用条件边在节点之间有条件地路由而不更新状态。
配置 Runtime
创建图时,还可以标记图的某些部分是可配置的。这样做通常是为了方便在模型或系统提示之间切换。这允许创建单个“认知架构”(图),但拥有多个不同的实例。
在运行图时提供额外的“配置参数”而不是“状态参数”,并且通过类型约束这些参数。
from langgraph.graph import StateGraph
from langgraph.runtime import Runtime
from langchain_community.chat_models import ChatZhipuAI
from typing import TypedDict
# 定义状态结构
class MyState(TypedDict):
question: str
answer: str
# 定义配置结构
class MyContext(TypedDict):
language: str # 配置中包含语言选项,比如 "en" 或 "zh"
# 节点函数可以访问 runtime 参数 runtime 可以访问上下文和内存存储
def step1(state: MyState, runtime: Runtime[MyContext]):
if runtime.context["language"] == "zh":
answer = "你好!"
else:
answer = "Hello!"
return {"answer": answer}
# 构建图
graph = StateGraph(state_schema=MyState, context_schema=MyContext)
graph.add_node("step1", step1)
graph.set_entry_point("step1")
# 编译
app = graph.compile()
# 执行时传入 config 参数(区分于 state)
result = app.invoke({"question": "Hi"}, context={"language": "zh"})
print(result) # => {"question": "Hi", "answer": "你好!"}
在运行时指定 llm
from langgraph.graph import MessagesState
from langgraph.runtime import Runtime
from langgraph.graph import END, StateGraph, START
from typing_extensions import TypedDict
class MyContext(TypedDict):
model: str
MODELS = {
"anthropic": "anthropic:claude-3-5-haiku-latest",
"openai": "openai:gpt-4.1-mini",
}
def call_model(state: MessagesState, runtime: Runtime[MyContext]):
model = ""
if runtime.context:
model = runtime.context["model"]
model = MODELS[model]
return {"messages": {"role": "assistant", "content": model}}
builder = StateGraph(MessagesState, context_schema=MyContext)
builder.add_node("model", call_model)
builder.add_edge(START, "model")
builder.add_edge("model", END)
graph = builder.compile()
# Usage
input_message = {"role": "user", "content": "hi"}
# With no configuration, uses default (Anthropic)
response_1 = graph.invoke({"messages": [input_message]})
# Or, can set OpenAI
context = {"model": "openai"}
response_2 = graph.invoke({"messages": [input_message]}, context=context)
print(response_1)
print(response_2)
递归限制
递归限制设置图在单次执行中可以执行的最大超步数。一旦达到限制,LangGraph 将出现 GraphRecursionError
。默认情况下,此值设置为 25 步。可以在运行时在任何图上设置递归限制,并将其传递给 .invoke
/.stream
通过配置字典。重要的是,recursion_limit
是一个独立的 config
键,不应 configurable
像所有其他用户定义的配置一样在键内传递。
graph.invoke(inputs, config={"recursion_limit": 5, "configurable":{"llm": "anthropic"}})
可视化图表
"""_
_LangGraph Map-Reduce 简单案例:数字求和_
_把一堆数字分给多个worker算平方,然后把结果加起来_
_"""_
_from typing import Annotated_
_import operator_
_from langgraph.graph import StateGraph, START, END_
_from langgraph.types import Send_
_from typing import TypedDict, List_
_# 状态定义_
_class State(TypedDict):_
_ numbers: List[int] # 输入的数字_
_ results: Annotated[list[int], operator.add] # worker的结果_
_ final_sum: int # 最终求和_
_# 1. Map阶段:分发数字_
_def split_numbers(state: State):_
_ """把数字分发给不同的worker"""_
_ numbers = state["numbers"]_
_ print(f"📦 分发数字: {numbers}")_
_ # 每个数字发给一个worker_
_ return [Send("worker", {"number": num}) for num in numbers]_
_# 2. Worker阶段:计算平方_
_def calculate_square(state: State):_
_ """每个worker计算一个数字的平方"""_
_ number = state["number"]_
_ square = number * number_
_ print(f"⚡ Worker: {number}² = {square}")_
_ return {"results": [square]}_
_# 3. Reduce阶段:求和_
_def sum_results(state: State):_
_ """把所有结果加起来"""_
_ results = state.get("results", [])_
_ total = sum(results)_
_ print(f"📊 求和: {results} = {total}")_
_ return {"final_sum": total}_
_# 构建图_
_def create_simple_graph():_
_ graph = StateGraph(State)_
_ # 添加节点_
_ graph.add_node("splitter", lambda s: s) # 分发器_
_ graph.add_node("worker", calculate_square) # 工作节点_
_ graph.add_node("summer", sum_results) # 求和器_
_ # 连接节点_
_ graph.add_edge(START, "splitter")_
_ graph.add_conditional_edges("splitter", split_numbers, ["worker"]) # Map阶段_
_ graph.add_edge("worker", "summer") # Worker完成后求和_
_ graph.add_edge("summer", END)_
_ return graph.compile()_
_# 运行例子_
_def run_example():_
_ app = create_simple_graph()_
_ # 测试数据_
_ initial_state = {_
_ "numbers": [1, 2, 3, 4, 5],_
_ "results": [],_
_ "final_sum": 0_
_ }_
_ print("🚀 开始计算...")_
_ print("任务:计算每个数字的平方,然后求和")_
_ print()_
_ # 运行_
_ app.invoke(initial_state)_
_ from IPython.display import Image, display_
_ from langchain_core.runnables.graph import MermaidDrawMethod_
_ display(_
_ Image(_
_ app.get_graph().draw_mermaid_png(_
_ draw_method=MermaidDrawMethod.API,_
_ output_file_path='./可视化图.png'_
_ )_
_ )_
_ )_
_if __name__ == "__main__":_
_ run_example()
子图
LangGraph 子图(Subgraph)是一种模块化的图结构,允许您将复杂的工作流分解为更小的、可重用的组件。就像函数在编程中的作用一样,子图提供了封装和复用的能力。
子图的优势
- 代码复用:避免重复编写相同的逻辑
- 清晰的架构:将复杂流程分解为清晰的模块
- 易于维护:修改子图只需在一个地方进行
- 团队协作:不同团队可以独立开发不同的子图
- 测试友好:可以单独测试子图的功能
两种状态通讯
1.共享状态键(Shared State Keys)
父图和子图在其状态模式中有共享的状态键。在这种情况下,您可以将子图作为节点包含在父图中。
from langgraph.graph import StateGraph, MessagesState, START
from langchain.chat_models import init_chat_model
import os
from dotenv import load_dotenv
load_dotenv()
llm = init_chat_model(api_key=os.getenv("DASHSCOPE_API_KEY"),
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
model_provider="openai",
model='qwen-plus-2025-01-25')
# 创建子图
def subplot(state: MessagesState) -> MessagesState:
# 获取大模型回答的内容进行摘要总结
answer = state["messages"][-1].content
summary_prompt = f"请用一句话总结下面这句话:\n\n答:{answer}"
response = llm.invoke(summary_prompt)
return {"messages": state["messages"] + [response]}
summary_subgraph = (
StateGraph(state_schema=MessagesState)
.add_node("subplot", subplot)
.add_edge(START, "subplot")
.compile()
)
# 创建父图
def llm_answer_node(state: MessagesState) -> MessagesState:
# 使用大模型进行回答
answer = llm.invoke(state["messages"])
print("父图输出", answer)
return {"messages": state["messages"] + [answer]}
parent_graph = (
StateGraph(MessagesState)
.add_node("llm_answer", llm_answer_node)
.add_node("summarize_subgraph", summary_subgraph)
.add_edge(START, "llm_answer")
.add_edge("llm_answer", "summarize_subgraph")
.compile()
)
# 测试
input_state = {
"messages": [{"role": "user", "content": "langgraph是什么?"}],
}
result = parent_graph.invoke(input_state)
print(result)
2.不同状态模式(Different State Schemas)常用
父图和子图有不同的模式(状态模式中没有共享的状态键)。在这种情况下,您必须在父图的节点内部调用子图:这在父图和子图有不同状态模式且需要在调用子图前后转换状态时很有用。
from langgraph.graph import StateGraph, MessagesState, START
from typing_extensions import TypedDict, Annotated
from langchain_core.messages import AnyMessage
from langgraph.graph.message import add_messages
from langchain.chat_models import init_chat_model
import os
from dotenv import load_dotenv
load_dotenv()
llm = init_chat_model(api_key=os.getenv("DASHSCOPE_API_KEY"),
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
model_provider="openai",
model='qwen-plus-2025-01-25')
# 创建子图
class SubgraphMessagesState(TypedDict):
subgraph_messages: Annotated[list[AnyMessage], add_messages]
def subplot(state: SubgraphMessagesState) -> SubgraphMessagesState:
# 获取大模型回答的内容进行摘要总结
answer = state["subgraph_messages"][-1].content
summary_prompt = f"请用一句话总结下面这句话:\n\n答:{answer}"
response = llm.invoke(summary_prompt)
print("\n\n")
print("子图中问题和输出:", state["subgraph_messages"] + [response])
return {"subgraph_messages": [response]}
summary_subgraph = (
StateGraph(state_schema=SubgraphMessagesState)
.add_node("subplot", subplot)
.add_edge(START, "subplot")
.compile()
)
# 创建父图
def llm_answer_node(state: MessagesState) -> MessagesState:
# 使用大模型进行回答
answer = llm.invoke(state["messages"])
print("父图中问题和输出:", state["messages"] + [answer])
# 转换状态格式
summary_result = summary_subgraph.invoke({"subgraph_messages": state["messages"] + [answer]})
return {"messages": state["messages"] + [answer]+ [summary_result["subgraph_messages"][2]]}
parent_graph = (
StateGraph(state_schema=MessagesState)
.add_node("llm_answer", llm_answer_node)
.add_edge(START, "llm_answer")
.compile()
)
# 测试输入
input_state = {
"messages": [{"role": "user", "content": "langgraph是什么?"}],
}
result = parent_graph.invoke(input_state)
print("最终结果:", result)
流式输出
LangGraph 实施了流媒体系统来显示实时更新,从而实现响应迅速且透明的用户体验。
LangGraph 的流式传输系统可将图形运行的实时反馈显示到您的应用中。
流式输出在 LangGraph 中的重要性:
⚡ 用户立即看到反馈
🎯 减少等待时间
💾 节省内存使用
😊 提升用户体验
将以下一个或多个流模式作为列表传递给 stream()
或 astream()
方法:
五种模式代码详解
from typing import Annotated
import operator
from langgraph.graph import StateGraph, START, END
from langgraph.constants import Send
from typing import TypedDict, List
"""values、updates、debug模式"""
# 状态定义
class State(TypedDict):
numbers: List[int] # 输入的数字
results: Annotated[list[int], operator.add] # worker的结果
final_sum: int # 最终求和
# 1. Map阶段:分发数字
def split_numbers(state: State):
_"""把数字分发给不同的worker"""_
_ _numbers = state["numbers"]
# 每个数字发给一个worker
return [Send("worker", {"number": num}) for num in numbers]
# 2. Worker阶段:计算平方
def calculate_square(state: State):
_"""每个worker计算一个数字的平方"""_
_ _number = state["number"]
square = number * number
return {"results": [square]}
# 3. Reduce阶段:求和
def sum_results(state: State):
_"""把所有结果加起来"""_
_ _results = state.get("results", [])
total = sum(results)
return {"final_sum": total}
# 构建图
def create_simple_graph():
graph = StateGraph(state_schema=State)
# 添加节点
graph.add_node("splitter", lambda s: s) # 分发器
graph.add_node("worker", calculate_square) # 工作节点
graph.add_node("summer", sum_results) # 求和器
# 连接节点
graph.add_edge(START, "splitter")
graph.add_conditional_edges("splitter", split_numbers, ["worker"]) # Map阶段
graph.add_edge("worker", "summer") # Worker完成后求和
graph.add_edge("summer", END)
return graph.compile()
"""messages模式"""
from langchain.chat_models import init_chat_model
import os
from dotenv import load_dotenv
load_dotenv()
llm = init_chat_model(api_key=os.getenv("DASHSCOPE_API_KEY"),
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
model_provider="openai",
model='qwen-plus-2025-01-25')
class MyState(TypedDict):
question: str
results: str
def generate_answer(state: MyState):
question = state["question"]
answer = llm.invoke([
{"role": "user", "content": f"{question}"}
])
return {"answer": answer.content}
# 构建图
def create_llm_graph():
graph = StateGraph(state_schema=MyState)
# 添加节点
graph.add_node("generate_answer", generate_answer)
# 连接节点
graph.add_edge(START, "generate_answer")
graph.add_edge("generate_answer", END)
return graph.compile()
"""自定义模式"""
from langgraph.config import get_stream_writer
import time
# 定义状态
class FileState(TypedDict):
filename: str # 文件名称
content: str # 文件内容
word_count: int # 内容数量
processed: bool # 是否处理
def read_file(state: FileState):
_"""步骤1:读取文件"""_
_ _writer = get_stream_writer()
# 发送开始信息
writer({"step": "读取文件", "status": "开始", "progress": 0})
time.sleep(1)
# 发送进度信息
writer({"step": "读取文件", "status": "正在读取...", "progress": 50})
time.sleep(1)
# 模拟文件内容
content = "这是一个示例文件,包含一些文本内容。"
# 发送完成信息
writer({
"step": "读取文件",
"status": "完成",
"progress": 100,
"data": {"size": len(content)}
})
return {"content": content}
def count_words(state: FileState):
_"""步骤2:统计字数"""_
_ _writer = get_stream_writer()
writer({"step": "统计字数", "status": "开始", "progress": 0})
time.sleep(0.5)
writer({"step": "统计字数", "status": "正在分析...", "progress": 30})
time.sleep(1)
writer({"step": "统计字数", "status": "计算中...", "progress": 70})
time.sleep(0.5)
# 计算字数
word_count = len(state["content"])
writer({
"step": "统计字数",
"status": "完成",
"progress": 100,
"data": {"word_count": word_count}
})
return {"word_count": word_count}
def finalize_processing(state: FileState):
_"""步骤3:完成处理"""_
_ _writer = get_stream_writer()
writer({"step": "完成处理", "status": "生成报告", "progress": 50})
time.sleep(1)
writer({
"step": "完成处理",
"status": "全部完成",
"progress": 100,
"data": {
"filename": state["filename"],
"total_chars": state["word_count"],
"summary": f"文件 {state['filename']} 处理完成,共 {state['word_count']} 个字符"
}
})
return {"processed": True}
# 构建图
def create_custom_graph():
graph = (
StateGraph(state_schema=FileState)
.add_node("read_file", read_file)
.add_node("count_words", count_words)
.add_node("finalize", finalize_processing)
.add_edge(START, "read_file")
.add_edge("read_file", "count_words")
.add_edge("count_words", "finalize")
.compile()
)
return graph
# 运行例子
def run_example():
app = create_simple_graph()
app1 = create_llm_graph()
app2 = create_custom_graph()
# 测试数据
initial_state = {
"numbers": [1, 2, 3, 4, 5],
"results": [],
"final_sum": 0
}
print("====================VALUES模式=====================")
for result in app.stream(initial_state, stream_mode="values"):
print(result)
print("====================UPDATES模式=====================")
for result in app.stream(initial_state, stream_mode="updates"):
print(result)
print("====================DEBUG模式=====================")
for result in app.stream(initial_state, stream_mode="debug"):
print(result)
print("====================MESSAGES模式=====================")
for result in app1.stream({"question": "什么是状态图?"}, stream_mode="messages"):
print(result[0].content)
print("====================CUSTOM模式=====================")
# 初始状态
initial_state1 = {
"filename": "example.txt",
"content": "",
"word_count": 0,
"processed": False
}
# 使用Custom模式运行
for chunk in app2.stream(initial_state1, stream_mode="custom"):
step = chunk.get("step", "")
status = chunk.get("status", "")
progress = chunk.get("progress", 0)
data = chunk.get("data", {})
# 显示进度
progress_bar = "█" * (progress // 10) + "░" * (10 - progress // 10)
print(f"\n[{step}] {status}")
print(f"进度: [{progress_bar}] {progress}%")
# 显示额外数据
if data:
for key, value in data.items():
print(f"📊 {key}: {value}")
print("====================融合多种模式=====================")
for mode, result in app.stream(initial_state, stream_mode=["values", "updates"]):
print("模式:", mode)
print(result)
if __name__ == "__main__":
run_example()
融合多种模式
可以传递一个列表作为 stream_mode
参数来同时传输多种模式
输出将是流模式名称和 (mode, chunk)
print("====================融合多种模式=====================")
for mode, result in app.stream(initial_state, stream_mode=["values", "updates"]):
print("模式:", mode)
print(result)
从工具中流式传输数据
from langgraph.prebuilt import create_react_agent
from langchain.chat_models import init_chat_model
from langchain_tavily import TavilySearch
from langgraph.config import get_stream_writer
from langchain_core.tools import tool
from datetime import datetime
import time
import os
from dotenv import load_dotenv
load_dotenv()
llm = init_chat_model(api_key=os.getenv("DASHSCOPE_API_KEY"),
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
model_provider="openai",
model='qwen-plus-2025-01-25')
# 定义工具函数
@tool
def search_web(query: str):
_"""搜索网络信息的工具"""_
_ _writer = get_stream_writer()
writer({"step": "搜索", "status": "请等待...", "progress": 0})
time.sleep(0.5)
writer({"step": "搜索", "status": "请等待...", "progress": 50})
time.sleep(0.5)
writer({
"step": "搜索",
"status": "完成",
"progress": 100,
"data": {"结果": TavilySearch().invoke(query)}
})
@tool
def get_data_tool():
_"""获取目前日期的工具"""_
_ _return datetime.now().date()
tools = [search_web, get_data_tool]
system_prompt = """你是一个智能助手。你有以下工具可以使用:
1. search_web: 用于搜索互联网获取最新信息,特别是产品价格、新闻、实时数据等
2. get_current_date: 获取今天的日期
重要规则:
- 当用户询问产品价格、最新信息、新闻等需要实时数据的问题时,必须使用search_web工具
- 当用户询问时间或日期时,使用相应的时间工具
- 如果你的知识库中没有准确或最新的信息,应该使用搜索工具
- 优先使用工具获取准确信息,而不是依赖可能过时的训练数据
请根据用户问题选择合适的工具来获取准确答案。"""
agent = create_react_agent(model=llm,
tools=tools,
prompt=system_prompt
)
for chunk in agent.stream({"messages": [{"role": "user", "content": "小米yu7价格"}]}, stream_mode="custom"):
step = chunk.get("step", "")
status = chunk.get("status", "")
progress = chunk.get("progress", 0)
data = chunk.get("data", {})
# 显示进度
progress_bar = "█" * (progress // 10) + "░" * (10 - progress // 10)
print(f"\n[{step}] {status}")
print(f"进度: [{progress_bar}] {progress}%")
# 显示额外数据
if data:
for key, value in data.items():
print(f"📊 {key}: {value}")
从子图中进行流式传输
只需要在父图的方法中 stream()
设置 subgraphs=True
出将作为元组进行流式传输 (namespace, data)
,其中 namespace
是包含调用子图的节点路径的元组,例如 ("parent_node:<task_id>", "child_node:<task_id>")
。
from langgraph.graph import StateGraph, MessagesState, START
from typing_extensions import TypedDict, Annotated
from langchain_core.messages import AnyMessage
from langgraph.graph.message import add_messages
from langchain.chat_models import init_chat_model
import os
from dotenv import load_dotenv
load_dotenv()
llm = init_chat_model(api_key=os.getenv("DASHSCOPE_API_KEY"),
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
model_provider="openai",
model='qwen-plus-2025-01-25')
# 创建子图
class SubgraphMessagesState(TypedDict):
subgraph_messages: Annotated[list[AnyMessage], add_messages]
def subplot(state: SubgraphMessagesState) -> SubgraphMessagesState:
# 获取大模型回答的内容进行摘要总结
answer = state["subgraph_messages"][-1].content
summary_prompt = f"请用一句话总结下面这句话:\n\n答:{answer}"
response = llm.invoke(summary_prompt)
# print("\n\n")
# print("子图中问题和输出:", state["subgraph_messages"] + [response])
return {"subgraph_messages": [response]}
summary_subgraph = (
StateGraph(state_schema=SubgraphMessagesState)
.add_node("subplot", subplot)
.add_edge(START, "subplot")
.compile()
)
# 创建父图
def llm_answer_node(state: MessagesState) -> MessagesState:
# 使用大模型进行回答
answer = llm.invoke(state["messages"])
# print("父图中问题和输出:", state["messages"] + [answer])
# 转换状态格式
summary_result = summary_subgraph.invoke({"subgraph_messages": state["messages"] + [answer]})
return {"messages": state["messages"] + [answer, summary_result["subgraph_messages"][2]]}
parent_graph = (
StateGraph(state_schema=MessagesState)
.add_node("llm_answer", llm_answer_node)
.add_edge(START, "llm_answer")
.compile()
)
# 测试输入
input_state = {
"messages": [{"role": "user", "content": "langgraph是什么?"}],
}
for chunk in parent_graph.stream(
input_state,
stream_mode="updates",
subgraphs=True,
):
print(chunk)
禁用特定聊天模型的流式传输
如果您的应用程序将支持流式传输的模型与不支持流式传输的模型混合使用,则可能需要明确禁用不支持流式传输的模型。
model = init_chat_model(
"anthropic:claude-3-7-sonnet-latest",
disable_streaming=True
)
持久性
检查点(Checkpointing)是 LangGraph 持久性的核心机制。它允许你在图执行过程中的任何点保存状态,并在需要时恢复。
核心概念
- 检查点(Checkpoint): 图状态的快照
- 线程(Thread): 用于访问检查点的唯一标识
- 检查点保存器(Checkpointer): 负责保存和恢复状态的组件
线程(Threads)
线程是检查点保存器保存的每个检查点分配的唯一 ID 或线程标识符
当使用检查点调用图表时,必须指定 thread_id
作为 configurable
配置部分的一部分:
# 调用图时必须指定 thread_id
config = {"configurable": {"thread_id": "unique_thread_id"}}
result = graph.invoke(input_data, config=config)
特点
- 每个线程代表一个独立的对话或执行上下文
- 线程允许在图执行后访问图的状态
- 支持多个并发线程
from langchain.chat_models import init_chat_model
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import StateGraph, MessagesState, START
import os
from dotenv import load_dotenv
load_dotenv()
llm = init_chat_model(api_key=os.getenv("DASHSCOPE_API_KEY"),
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
model_provider="openai",
model='qwen-plus-2025-01-25')
def process_message(state: MessagesState):
response = llm.invoke(state["messages"])
return {"messages": response}
builder = StateGraph(state_schema=MessagesState)
builder.add_node("process_message", process_message)
builder.add_edge(START, "process_message")
# # 没有使用持久性
# graph = builder.compile()
#
# input_message = {"role": "user", "content": "你好呀!我的名字叫初见"}
# for chunk in graph.stream({"messages": [input_message]}, stream_mode="values"):
# chunk["messages"][-1].pretty_print()
#
# input_message = {"role": "user", "content": "我的名字叫什么?"}
# for chunk in graph.stream({"messages": [input_message]}, stream_mode="values"):
# chunk["messages"][-1].pretty_print()
# 使用持久性
checkpointer = InMemorySaver()
graph = builder.compile(checkpointer=checkpointer)
config = {"configurable": {"thread_id": "user_123"}}
input_message = {"role": "user", "content": "你好呀!我的名字叫初见"}
for chunk in graph.stream({"messages": [input_message]}, config, stream_mode="values"):
chunk["messages"][-1].pretty_print()
input_message = {"role": "user", "content": "我的名字叫什么?"}
for chunk in graph.stream({"messages": [input_message]}, config, stream_mode="values"):
chunk["messages"][-1].pretty_print()
检查点
检查点是在每个超级步骤中保存的图状态的快照,由 StateSnapshot
具有以下关键属性的对象表示:
config
:与此检查点相关的配置。metadata
:与此检查点相关的元数据。values
:当前State
的值。也就是图执行到目前为止,所有变量的状态值(如"messages"
,"steps"
,"results"
等字段的值)。next
图中接下来要执行的节点名称的元组。tasks
:包含具体要执行的任务的详细信息,用PregelTask
类型表示。比next
更详细
LangGraph 中检查点的作用
本质理解
[!TIP]
LangGraph 中的图是围绕 State** 状态对象** 构建的:
每一步(Node)执行时会读取State
,返回一个新的State
。
所谓的“检查点”就是:
在某个节点运行后,把当时的 State 存起来(比如存到数据库或磁盘)
然后如果下次因为任何原因中断或重新运行,只需:
加载上次的检查点状态 State,重新进入图流程
获取状态内容
from langgraph.graph import StateGraph, MessagesState, START
from langgraph.checkpoint.memory import InMemorySaver
from langchain.chat_models import init_chat_model
import os
from dotenv import load_dotenv
load_dotenv()
llm = init_chat_model(api_key=os.getenv("DASHSCOPE_API_KEY"),
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
model_provider="openai",
model='qwen-plus-2025-01-25')
# 创建子图
def subplot(state: MessagesState) -> MessagesState:
# 获取大模型回答的内容进行摘要总结
answer = state["messages"][-1].content
summary_prompt = f"请用一句话总结下面这句话:\n\n答:{answer}"
response = llm.invoke(summary_prompt)
return {"messages": state["messages"] + [response]}
summary_subgraph = (
StateGraph(state_schema=MessagesState)
.add_node("subplot", subplot)
.add_edge(START, "subplot")
.compile()
)
# 创建父图
def llm_answer_node(state: MessagesState) -> MessagesState:
# 使用大模型进行回答
answer = llm.invoke(state["messages"])
print("父图输出", answer)
return {"messages": state["messages"] + [answer]}
checkpointer = InMemorySaver()
parent_graph = (
StateGraph(MessagesState)
.add_node("llm_answer", llm_answer_node)
.add_node("summarize_subgraph", summary_subgraph)
.add_edge(START, "llm_answer")
.add_edge("llm_answer", "summarize_subgraph")
.compile(checkpointer=checkpointer)
)
config = {"configurable": {"thread_id": "1"}}
# 测试输入
input_state = {
"messages": [{"role": "user", "content": "langgraph是什么?请用100字介绍"}],
}
for chunk in parent_graph.stream(
input_state,
config,
stream_mode="updates",
subgraphs=True
):
# print(chunk)
pass
print("================获取状态=================")
"""
与已保存的图表状态交互时,必须指定线程标识符。您可以通过调用来查看图表的最新graph.get_state(config)状态。这将返回一个StateSnapshot对象
"""
print(parent_graph.get_state(config).values)
print("================状态历史记录=================")
"""
可以通过调用 获取给定线程的图形执行的完整历史记录graph.get_state_history(config)。
这将返回与配置中提供的线程 ID 关联的对象列表StateSnapshot。
重要的是,检查点将按时间顺序排序,最新的检查点 /StateSnapshot将位于列表中的第一个。
注意:这里采用的是共享状态的子图,可以将子图的内容持久化,如果使用的是不同状态的就需要分别存储
"""
history = list(parent_graph.get_state_history(config))
for idx, snapshot in enumerate(history):
print(f"Step {idx}:")
print(f" Checkpoint ID: {snapshot.config['configurable']['checkpoint_id']}")
print(f" Node: {snapshot.metadata.get('source')}")
print(f" Messages: {[m.content for m in snapshot.values['messages']]}")
print("")
print("================重放机制=================")
"""
可以重放先前的图执行。如果使用 thread_id 和 checkpoint_id 调用图,LangGraph 会:
1.重放检查点之前已执行的步骤(不重新执行)
2.执行检查点之后的步骤(即使之前已执行)
注意:必须传递这些内容thread_id, checkpoint_id
重要特性:
LangGraph 知道特定步骤是否已执行
检查点前的步骤会被重放(不重新执行)
检查点后的步骤会被重新执行,不包括当前检查点(创建新分支)
"""
# 获取Step2检查点,开始进行重放
step2_level_checkpoint = None
if history:
step2_level_checkpoint = list(history)[2].config['configurable']['checkpoint_id']
print("第二步检查点:", step2_level_checkpoint)
config = {"configurable": {"thread_id": "1", "checkpoint_id": step2_level_checkpoint}}
# 开始执行
result = parent_graph.invoke(None, config=config)
print(result)
更新对应状态
使用 graph.update_state()
方法编辑图状态。当其中某个节点需要人为进行控制的时候,需要用更新状态来确定是否执行剩下流程。
方法参数
- config
- 必须包含
thread_id
指定要更新的线程 - 可选包含
checkpoint_id
来分叉选定的检查点
- values
- 用于更新状态的值
- 更新会传递给 reducer 函数(如果定义了)
- 没有 reducer 的通道会被覆盖
- as_node
- 可选参数,指定更新来自哪个节点
- 影响下一步执行的节点
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import InMemorySaver
from typing_extensions import TypedDict
from typing import Annotated
from operator import add
class TaskState(TypedDict):
task_id: str # 任务id
title: str # 标题
assignee: str # 接收任务的人
priority: int # 优先级
comments: Annotated[list, add] # 评论
status: str # 状态
def create_task(state: TaskState):
"""创建任务"""
return {
"status": "进行中",
"comments": [f"任务 '{state['title']}' 已创建,分配给 {state['assignee']}"]
}
def update_task(state: TaskState):
"""更新任务"""
return {
"status": "已更新",
"comments": ["任务状态已更新"]
}
def plan_task(state: TaskState):
"""准备任务完成"""
return {
"status": "已准备",
"comments": [f"任务 '{state['title']}' 已准备"]
}
# 创建任务管理流程
workflow = StateGraph(state_schema=TaskState)
workflow.add_node("create", create_task)
workflow.add_node("update", update_task)
workflow.add_node("complete", plan_task)
workflow.add_edge(START, "create")
workflow.add_edge("update", "complete")
workflow.add_edge("complete", END)
checkpointer = InMemorySaver()
app = workflow.compile(checkpointer=checkpointer)
# 创建任务
config = {"configurable": {"thread_id": "task_001"}}
result = app.invoke({
"task_id": "T001",
"title": "开发注册功能",
"assignee": "张三",
"priority": 1,
"comments": [],
"status": "待分配"
}, config)
print("=== 任务创建完成 ===")
print(f"状态: {result['status']}")
print("评论:", result['comments'])
# 演示 as_node 参数, update节点需要手动使用update_state进行触发, 才能执行后续操作
print("\n=== 使用 as_node 参数 ===")
app.update_state(config, {
"status": "已更新",
"comments": ["自动通知:任务优先级提升"],
"priority": 3
}, as_node="update")
final_result = app.invoke(None, config)
# 查看最终状态
final_state = app.get_state(config)
print("最终状态:", final_result["status"])
print("完整评论:", final_result["comments"])
# 手动更新状态 - 添加评论
print("\n=== 手动添加评论 ===")
app.update_state(config, {
"comments": ["项目经理:请在周五前完成"],
"priority": 2
})
# 查看更新后的状态
updated_state = app.get_state(config)
print(f"优先级: {updated_state.values['priority']}")
print("所有评论:", updated_state.values['comments'])
# 继续更新 - 添加更多评论
print("\n=== 添加更多评论 ===")
app.update_state(config, {
"comments": ["张三:已完成开发"],
"status": "开发完成"
})
final_state = app.get_state(config)
print(f"最终状态: {final_state.values['status']}")
print("完整评论历史:", final_state.values['comments'])
记忆存储
状态模式指定在图执行时填充的键集合。但如果我们想在线程之间保留信息怎么办?
Store 接口
- 检查点保存器单独无法跨线程共享信息
- Store 接口解决了这个问题
- 可以在所有聊天对话中保留用户特定信息
基础用法
每种内存类型都是一个具有特定属性的 Python 类(Item
)。我们可以通过上述转换将其作为字典访问 .dict
。它具有以下属性:
value
:此内存的值(本身就是一个字典)key
:此命名空间中此内存的唯一键namespace
:字符串列表,此内存类型的命名空间created_at
:此内存创建的时间戳updated_at
:此内存更新的时间戳
print("-" * 8, "基础用法", "-" * 8)_
_from langgraph.store.memory import InMemoryStore_
_import uuid_
_# 创建存储_
_# in_memory_store = InMemoryStore()_
_#_
_# # 定义命名空间_
_# namespace_for_memory = ("user_id", "memories")_
_#_
_# # 存储记忆_
_# memory_id = str(uuid.uuid4())_
_# memory = {"hobby": "篮球、音乐、美食、编程..."}_
_# in_memory_store.put(namespace_for_memory, memory_id, memory)_
_#_
_# # 搜索记忆_
_# memories = in_memory_store.search(namespace_for_memory)_
_# # 打印数据_
_# print(memories[-1].dict())_
_print("-" * 8, "语义搜索", "-" * 8)_
_from langchain_huggingface import HuggingFaceEmbeddings_
_namespace_for_memory = ("user_id", "memories")_
_store = InMemoryStore(_
_ index={_
_ "embed": HuggingFaceEmbeddings(model_name=r"D:\llm\Local_model\BAAI\bge-large-zh-v1___5"),_
_ "dims": 1024,_
_ "fields": ["hobby", "food_preference"]_
_ }_
_)_
_# 3. 存储数据并检查_
_memory_id_1 = str(uuid.uuid4())_
_memory_1 = {"hobby": "我的爱好是:篮球、音乐、美食、编程..."}_
_store.put(namespace_for_memory, memory_id_1, memory_1)_
_print(f"✓ 存储 hobby 记忆: {memory_id_1}")_
_memory_id_2 = str(uuid.uuid4())_
_memory_2 = {"food_preference": "我最喜欢的美食是:臭豆腐、小龙虾、红烧肉..."}_
_store.put(namespace_for_memory, memory_id_2, memory_2)_
_print(f"✓ 存储 food_preference 记忆: {memory_id_2}")_
_# 4. 检查存储的数据_
_print("\n=== 调试信息 ===")_
_print(f"Namespace: {namespace_for_memory}")_
_print(f"存储的记忆数量: {len(store.search(namespace_for_memory))}")_
_# 5. 搜索测试_
_print("\n=== 搜索测试 ===")_
_# 测试 1: 搜索食物偏好_
_print("搜索: 用户喜欢吃什么?")_
_memories = store.search(_
_ namespace_for_memory,_
_ query="用户喜欢吃什么?",_
_ limit=3_
_)_
_print(f"搜索结果数量: {len(memories)}")_
_if memories:_
_ print(f"最相关结果: {memories[0].dict()}")_
_else:_
_ print("没有找到结果")_
_# 测试 2: 搜索爱好_
_print("\n搜索: 用户的爱好有哪些?")_
_memories = store.search(_
_ namespace_for_memory,_
_ query="用户的爱好有哪些?",_
_ limit=3_
_)_
_print(f"搜索结果数量: {len(memories)}")_
_if memories:_
_ print(f"最相关结果: {memories[0].dict()}")_
_else:_
_ print("没有找到结果")
langgraph 中使用存储功能
from typing import Annotated, List
from typing_extensions import TypedDict
from operator import add
import uuid
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.store.memory import InMemoryStore
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from langchain_core.runnables import RunnableConfig
from langgraph.store.base import BaseStore
from langchain.chat_models import init_chat_model
import os
from dotenv import load_dotenv
load_dotenv()
# 初始化大模型
llm = init_chat_model(api_key=os.getenv("DASHSCOPE_API_KEY"),
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
model_provider="openai",
model='qwen-plus-2025-01-25')
# 定义状态结构
class MessagesState(TypedDict):
messages: Annotated[List[BaseMessage], add]
# 创建检查点保存器和内存存储
checkpointer = InMemorySaver()
in_memory_store = InMemoryStore()
# 聊天机器人节点 *代表后面的参数必须使用显示写出参数名称 store=in_memory_store
def chatbot(state: MessagesState, config: RunnableConfig, *, store: BaseStore):
_"""主聊天机器人节点,处理用户消息并生成回复"""_
_ _# 获取用户ID和最新消息
user_id = config["configurable"]["user_id"]
last_message = state["messages"][-1]
# 定义内存命名空间
namespace = (user_id, "memories")
# 简单的聊天逻辑
user_input = last_message.content.lower()
# 将聊天历史获取并组装提示词
memories = store.search(namespace)
memory_text = "\n".join(m.value["memory"] for m in memories)
prompt = f"请参考聊天记录:{memory_text}\n\nHuman: {user_input}\nAI:"
response = llm.invoke(prompt).content
# 存储对应的问题和答案
memory = f"问题:{user_input} --- 答案:{response}"
memory_id = str(uuid.uuid4())
store.put(namespace, memory_id, {"memory": memory})
# 返回AI消息
return {"messages": [AIMessage(content=response)]}
# 创建图
def create_persistent_graph():
_"""创建持久化的聊天机器人图"""_
_ _# 创建状态图
workflow = StateGraph(MessagesState)
# 添加节点
workflow.add_node("chatbot", chatbot)
# 添加边
workflow.add_edge(START, "chatbot")
workflow.add_edge("chatbot", END)
# 编译图,使用检查点保存器和存储
graph = workflow.compile(checkpointer=checkpointer, store=in_memory_store)
return graph
# 工具函数:显示状态历史
def show_state_history(graph, config):
_"""显示状态历史"""_
_ _print("\n=== 状态历史 ===")
history = graph.get_state_history(config)
for i, snapshot in enumerate(history):
print(f"\n步骤 {i}:")
print(f" 配置: {snapshot.config}")
print(f" 值: {snapshot.values}")
print(f" 下一步: {snapshot.next}")
print(f" 元数据: {snapshot.metadata}")
# 工具函数:显示存储的记忆
def show_memories(store, user_id):
_"""显示用户的所有记忆"""_
_ _print(f"\n=== 用户 {user_id} 的记忆 ===")
namespace = (user_id, "memories")
memories = store.search(namespace)
if memories:
for memory in memories:
print(f"记忆ID: {memory.key}")
print(f"内容: {memory.value}")
print(f"创建时间: {memory.created_at}")
print(f"更新时间: {memory.updated_at}")
print("---")
else:
print("没有找到记忆")
# 主程序
def main():
# 创建图
graph = create_persistent_graph()
# 用户配置
user_id = "user_123"
thread_id = "conversation_1"
config = {
"configurable": {
"thread_id": thread_id,
"user_id": user_id
}
}
print("=== LangGraph 持久化聊天机器人 ===")
print("输入 'quit' 退出,'history' 查看状态历史,'memories' 查看记忆")
while True:
user_input = input("\n用户: ").strip()
if user_input.lower() == 'quit':
break
elif user_input.lower() == 'history':
show_state_history(graph, config)
continue
elif user_input.lower() == 'memories':
show_memories(in_memory_store, user_id)
continue
# 创建用户消息
initial_state = {
"messages": [HumanMessage(content=user_input)]
}
# 运行图
try:
result = graph.invoke(initial_state, config)
# 显示AI回复
ai_message = result["messages"][-1]
print(f"AI: {ai_message.content}")
except Exception as e:
print(f"错误: {e}")
print("\n=== 最终状态 ===")
final_state = graph.get_state(config)
print(f"最终状态: {final_state.values}")
print("\n=== 所有记忆 ===")
show_memories(in_memory_store, user_id)
# 演示不同线程间的记忆共享
def demo_cross_thread_memory():
_"""演示跨线程记忆共享"""_
_ _print("\n=== 跨线程记忆共享演示 ===")
graph = create_persistent_graph()
user_id = "user_456"
# 第一个对话线程
config1 = {
"configurable": {
"thread_id": "thread_1",
"user_id": user_id
}
}
print("线程1 - 建立记忆:")
result1 = graph.invoke({
"messages": [HumanMessage(content="我叫Alice,我喜欢音乐")]
}, config1)
print(f"AI: {result1['messages'][-1].content}")
# 第二个对话线程(相同用户)
config2 = {
"configurable": {
"thread_id": "thread_2",
"user_id": user_id
}
}
print("\n线程2 - 访问记忆:")
result2 = graph.invoke({
"messages": [HumanMessage(content="你还记得我吗?")]
}, config2)
print(f"AI: {result2['messages'][-1].content}")
# 显示共享的记忆
show_memories(in_memory_store, user_id)
if __name__ == "__main__":
# 运行主程序
main()
# 演示跨线程记忆共享
demo_cross_thread_memory()
记忆
对于人工智能代理来说,记忆至关重要,因为它能让它们记住之前的交互,从反馈中学习,并适应用户的偏好。随着代理需要处理更复杂的任务,并进行大量的用户交互,这种能力对于效率和用户满意度都至关重要。
- 短期记忆(或线程范围的记忆)通过维护会话中的消息历史记录来跟踪正在进行的对话。LangGraph 将短期记忆作为代理状态的一部分进行管理。状态使用检查点持久化到数据库中,以便线程可以随时恢复。短期记忆会在图被调用或某个步骤完成时更新,并且在每个步骤开始时读取状态。
- 长期记忆跨会话存储用户特定或应用程序级别的数据,并在对话线程之间共享。它可以在任何时间、任何线程中调用。记忆的作用域是任何自定义命名空间,而不仅仅是单个线程 ID。LangGraph 提供存储,方便您保存和调用长期记忆。
短期记忆
短期记忆(线程级持久性)使代理能够跟踪多轮对话。在持久性中已经讲过简单的线程级的短期记忆。
1.使用 redis 作为存储
pip install -U langgraph-checkpoint-redis
from langchain.chat_models import init_chat_model
from langgraph.graph import StateGraph, MessagesState, START
from langgraph.checkpoint.redis import RedisSaver
import os
from dotenv import load_dotenv
load_dotenv()
llm = init_chat_model(api_key=os.getenv("DASHSCOPE_API_KEY"),
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
model_provider="openai",
model='qwen-plus-2025-01-25')
DB_URI = "redis://localhost:6379"
with RedisSaver.from_conn_string(DB_URI) as checkpointer:
# 第一次使用 Redis 检查点时需要调用
checkpointer.setup()
# 执行模型对话的函数图
def call_llm(state: MessagesState):
response = llm.invoke(state["messages"])
return {"messages": response}
# 创建图对象
builder = StateGraph(MessagesState)
builder.add_node(call_llm)
builder.add_edge(START, "call_llm")
# 将检查点传入
graph = builder.compile(checkpointer=checkpointer)
config = {
"configurable": {
"thread_id": "1"
}
}
# 第一次对话
for chunk in graph.stream(
{"messages": [{"role": "user", "content": "你好呀,我是初见"}]},
config,
stream_mode="values"
):
chunk["messages"][-1].pretty_print()
# 第二次对话
for chunk in graph.stream(
{"messages": [{"role": "user", "content": "我的名字叫什么?"}]},
config,
stream_mode="values"
):
chunk["messages"][-1].pretty_print()
2.使用Postgres存储
下载 postgresSql,数据库存储是的加密数据。
# 1.使用docker下载对应镜像
docker pull postgres:alpine# 这边使用的是体积更小的镜像
# 2.运行对应镜像
docker run -id --name=postgresql -v postgre-data:/var/lib/postgresql/data -p 5432:5432 -e POSTGRES_PASSWORD=123456 -e LANG=C.UTF-8 postgres:alpine
pip install -U "psycopg[binary,pool]" langgraph-checkpoint-postgres
from langchain.chat_models import init_chat_model
from langgraph.graph import StateGraph, MessagesState, START
from langgraph.checkpoint.postgres import PostgresSaver
import os
from dotenv import load_dotenv
load_dotenv()
llm = init_chat_model(api_key=os.getenv("DASHSCOPE_API_KEY"),
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
model_provider="openai",
model='qwen-plus-2025-01-25')
DB_URI = "postgresql://postgres:123456@localhost:5432/postgres?sslmode=disable"
with PostgresSaver.from_conn_string(DB_URI) as checkpointer:
# checkpointer.setup()
def call_llm(state: MessagesState):
response = llm.invoke(state["messages"])
return {"messages": [response]}
builder = StateGraph(MessagesState)
builder.add_node(call_llm)
builder.add_edge(START, "call_llm")
graph = builder.compile(checkpointer=checkpointer)
config = {
"configurable": {
"thread_id": "1"
}
}
chunk = graph.invoke(
{"messages": [{"role": "user", "content": "你好,我叫法外狂徒张三"}]},
config,
stream_mode="values"
)
chunk["messages"][-1].pretty_print()
chunk1 = graph.invoke(
{"messages": [{"role": "user", "content": "我的名字叫什么?"}]},
config,
stream_mode="values"
)
chunk1["messages"][-1].pretty_print()
# 最终状态检查
final_state = graph.get_state(config)
print(f"\n=== 最终状态 ===")
print(f"总消息数: {len(final_state.values.get('messages', []))}")
for i, msg in enumerate(final_state.values.get('messages', [])):
role = getattr(msg, 'type', 'unknown')
content = getattr(msg, 'content', str(msg))
print(f" {i + 1}. [{role}] {content[:100]}...")
长期记忆
使用长期记忆来存储对话中特定于用户或特定于应用程序的数据。
1.使用Postgres存储
from langchain_core.runnables import RunnableConfig
from langchain.chat_models import init_chat_model
from langgraph.graph import StateGraph, MessagesState, START
from langgraph.checkpoint.postgres import PostgresSaver
from langgraph.store.postgres import PostgresStore
from langgraph.store.base import BaseStore
import uuid
import os
from dotenv import load_dotenv
load_dotenv()
llm = init_chat_model(api_key=os.getenv("DASHSCOPE_API_KEY"),
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
model_provider="openai",
model='qwen-plus-2025-01-25')
DB_URI = "postgresql://postgres:123456@localhost:5432/postgres?sslmode=disable"
with (
PostgresStore.from_conn_string(DB_URI) as store,
PostgresSaver.from_conn_string(DB_URI) as checkpointer,
):
# 第一次使用postgres存储是需要调用初始化
store.setup()
checkpointer.setup()
def call_model(
state: MessagesState,
config: RunnableConfig,
*,
store: BaseStore,
):
user_id = config["configurable"]["user_id"]
namespace = ("memories", user_id)
# 获取记忆
memories = store.search(namespace, query=str(state["messages"][-1].content))
info = "\n".join([d.value["data"] for d in memories])
system_msg = f"你是一个与用户交谈的有帮助的助手. 用户信息: {info}"
response = llm.invoke(
[{"role": "system", "content": system_msg}] + state["messages"]
)
# 存储新的记忆,如果用户要求模型记住
last_message = state["messages"][-1]
if "记住" in last_message.content.lower():
memory = "用户名是初见"
# 存储记忆
store.put(namespace, str(uuid.uuid4()), {"data": memory})
return {"messages": response}
# 创建图
builder = StateGraph(MessagesState)
builder.add_node(call_model)
builder.add_edge(START, "call_model")
graph = builder.compile(
checkpointer=checkpointer,
store=store,
)
config = {
"configurable": {
"thread_id": "1",
"user_id": "1",
}
}
for chunk in graph.stream(
{"messages": [{"role": "user", "content": "嗨!记住:我的名字是初见"}]},
config,
stream_mode="values",
):
chunk["messages"][-1].pretty_print()
# 使用第二个线程去访问用户信息
config = {
"configurable": {
"thread_id": "2",
"user_id": "1",
}
}
for chunk in graph.stream(
{"messages": [{"role": "user", "content": "我的名字是什么?"}]},
config,
stream_mode="values",
):
chunk["messages"][-1].pretty_print()
2.使用Redis存储
from langchain_core.runnables import RunnableConfig
from langchain.chat_models import init_chat_model
from langgraph.graph import StateGraph, MessagesState, START
from langgraph.checkpoint.redis import RedisSaver
from langgraph.store.redis import RedisStore
from langgraph.store.base import BaseStore
import uuid
import os
from dotenv import load_dotenv
load_dotenv()
llm = init_chat_model(api_key=os.getenv("DASHSCOPE_API_KEY"),
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
model_provider="openai",
model='qwen-plus-2025-01-25')
DB_URI = "redis://localhost:6379"
with (
RedisStore.from_conn_string(DB_URI) as store,
RedisSaver.from_conn_string(DB_URI) as checkpointer,
):
# 第一次使用postgres存储是需要调用初始化
store.setup()
checkpointer.setup()
def call_model(
state: MessagesState,
config: RunnableConfig,
*,
store: BaseStore,
):
user_id = config["configurable"]["user_id"]
namespace = ("memories", user_id)
# 获取记忆
memories = store.search(namespace, query=str(state["messages"][-1].content))
info = "\n".join([d.value["data"] for d in memories])
system_msg = f"你是一个与用户交谈的有帮助的助手. 用户信息: {info}"
response = llm.invoke(
[{"role": "system", "content": system_msg}] + state["messages"]
)
# 存储新的记忆,如果用户要求模型记住
last_message = state["messages"][-1]
if "记住" in last_message.content.lower():
memory = "用户名是初见"
# 存储记忆
store.put(namespace, str(uuid.uuid4()), {"data": memory})
return {"messages": response}
# 创建图
builder = StateGraph(MessagesState)
builder.add_node(call_model)
builder.add_edge(START, "call_model")
graph = builder.compile(
checkpointer=checkpointer,
store=store,
)
config = {
"configurable": {
"thread_id": "1",
"user_id": "1",
}
}
for chunk in graph.stream(
{"messages": [{"role": "user", "content": "嗨!记住:我的名字是初见"}]},
config,
stream_mode="values",
):
chunk["messages"][-1].pretty_print()
# 使用第二个线程去访问用户信息
config = {
"configurable": {
"thread_id": "2",
"user_id": "1",
}
}
for chunk in graph.stream(
{"messages": [{"role": "user", "content": "我的名字是什么?"}]},
config,
stream_mode="values",
):
chunk["messages"][-1].pretty_print()
具有语义搜索的记忆
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.chat_models import init_chat_model
from langgraph.store.base import BaseStore
from langgraph.store.memory import InMemoryStore
from langgraph.graph import START, MessagesState, StateGraph
import os
from dotenv import load_dotenv
load_dotenv()
llm = init_chat_model(api_key=os.getenv("DASHSCOPE_API_KEY"),
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
model_provider="openai",
model='qwen-plus-2025-01-25')
# 加载嵌入模型
embeddings = HuggingFaceEmbeddings(model_name=r"D:\llm\Local_model\BAAI\bge-large-zh-v1___5")
store = InMemoryStore(
index={
"embed": embeddings,
"dims": 1024,
}
)
store.put(("memories", "user_123"), "1", {"text": "我喜欢吃披萨"})
store.put(("memories", "user_123"), "2", {"text": "我喜欢吃红烧肉"})
store.put(("memories", "user_123"), "3", {"text": "我的职业是程序员"})
def chat(state, *, store: BaseStore):
# 根据用户的最后一条消息进行搜索
items = store.search(
("memories","user_123"), query=state["messages"][-1].content, limit=2
)
print(items)
memories = "\n".join(item.value["text"] for item in items)
memories = f"## 用户记忆\n{memories}" if memories else ""
response = llm.invoke(
[
{"role": "system", "content": f"你是一个乐于助人的助手.\n{memories}"},
] + state["messages"]
)
return {"messages": [response]}
builder = StateGraph(MessagesState)
builder.add_node(chat)
builder.add_edge(START, "chat")
graph = builder.compile(store=store)
for message, metadata in graph.stream(
input={"messages": [{"role": "user", "content": "我饿了?"}]},
stream_mode="messages",
):
print(message.content, end="")
管理短期记忆
启用短期记忆后,长对话可能会超出 LLM 的上下文窗口。常见的解决方案如下:
- 修剪消息:删除前 N 条或后 N 条消息(在调用 LLM 之前)
- 从 LangGraph 状态中永久删除消息
- 总结消息:总结历史记录中较早的消息,并用摘要替换它们
- 管理检查点以存储和检索消息历史记录
- 自定义策略(例如,消息过滤等
修剪消息
大多数 LLM 都有一个最大支持的上下文窗口(以 token 为单位)。决定何时截断消息的一种方法是计算消息历史记录中的 token 数量,并在接近该限制时进行截断。
from langchain_core.messages.utils import (
trim_messages,
count_tokens_approximately
)
from langgraph.checkpoint.memory import InMemorySaver
from langchain.chat_models import init_chat_model
from langgraph.graph import StateGraph, START, MessagesState
import os
from dotenv import load_dotenv
load_dotenv()
llm = init_chat_model(api_key=os.getenv("DASHSCOPE_API_KEY"),
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
model_provider="openai",
model='qwen-plus-2025-01-25')
def call_llm(state: MessagesState):
messages = trim_messages(
state["messages"],
strategy="last", # 修剪策略(last从末尾,first从开头, middle从中间)
token_counter=count_tokens_approximately, # 用来估算token数量
max_tokens=50, # 修剪后的消息总 token 不超过 200
start_on="human", # _控制修剪的起始消息类型(确保修剪后以human消息开始)_
end_on=("human", "tool"), # 允许哪些角色作为修剪终点
)
print(f"修剪后消息数量: {len(messages)}")
print(f"修剪后总tokens: {count_tokens_approximately(messages)}")
print("修剪后的消息:")
for i, msg in enumerate(messages):
print(f" {i}: {msg.type} - {msg.content[:50]}...")
print("\n" + "=" * 50 + "\n")
response = llm.invoke(messages)
return {"messages": [response]}
checkpointer = InMemorySaver()
builder = StateGraph(MessagesState)
builder.add_node(call_llm)
builder.add_edge(START, "call_llm")
graph = builder.compile(checkpointer=checkpointer)
config = {"configurable": {"thread_id": "1"}}
graph.invoke({"messages": "我的名字叫初见"}, config)
graph.invoke({"messages": "帮我家的猫写一首诗"}, config)
graph.invoke({"messages": "现在对狗做一样的事情"}, config)
final_response = graph.invoke({"messages": "我的名字叫什么?"}, config)
# print(final_response)
# final_response["messages"][-1].pretty_print()
删除消息
可以从图表状态中删除消息,以管理消息历史记录。当您想要移除特定消息或清除整个消息历史记录时,此功能非常有用。
from langchain_core.messages import RemoveMessage
from langgraph.graph import StateGraph, START, MessagesState
from langchain.chat_models import init_chat_model
from langgraph.checkpoint.memory import InMemorySaver
import os
from dotenv import load_dotenv
load_dotenv()
llm = init_chat_model(api_key=os.getenv("DASHSCOPE_API_KEY"),
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
model_provider="openai",
model='qwen-plus-2025-01-25')
def delete_messages(state):
messages = state["messages"]
if len(messages) > 2:
# 删除最早的两条消息
return {"messages": [RemoveMessage(id=m.id) for m in messages[:2]]}
return None
def call_llm(state: MessagesState):
response = llm.invoke(state["messages"])
return {"messages": response}
builder = StateGraph(MessagesState)
builder.add_sequence([call_llm, delete_messages])
builder.add_edge(START, "call_llm")
checkpointer = InMemorySaver()
app = builder.compile(checkpointer=checkpointer)
config = {
"configurable": {
"thread_id": "1"
}
}
for event in app.stream(
{"messages": [{"role": "user", "content": "你好呀,我是初见哦"}]},
config,
stream_mode="values"
):
print([(message.type, message.content) for message in event["messages"]])
for event in app.stream(
{"messages": [{"role": "user", "content": "我的名字是什么?"}]},
config,
stream_mode="values"
):
# 最终回复会把最开始的两条消息删除
print([(message.type, message.content) for message in event["messages"]])
总结消息
修剪或删除消息的问题在于,可能会因剔除消息队列而丢失信息。因此,一些应用程序受益于一种更复杂的方法,即使用聊天模型来汇总消息历史记录。
下载模块
pip install langmem
from typing import Any, TypedDict
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain.chat_models import init_chat_model
from langchain_core.messages import AnyMessage
from langchain_core.messages.utils import count_tokens_approximately
from langgraph.graph import StateGraph, START, MessagesState
from langgraph.checkpoint.memory import InMemorySaver
from langmem.short_term import SummarizationNode, RunningSummary
import os
from dotenv import load_dotenv
load_dotenv()
llm = init_chat_model(api_key=os.getenv("DASHSCOPE_API_KEY"),
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
model_provider="openai",
model='qwen-plus-2025-01-25')
# 创建一个专用于摘要的模型实例,限制输出最多 128 tokens
summarization_model = llm.bind(max_tokens=128)
# 定义状态结构,包含对话历史和摘要上下文
class State(MessagesState):
context: dict[str, RunningSummary] # 用于存储用户摘要记忆(running_summary)
# 定义输入格式,传给 call_model 函数使用
class LLMInputState(TypedDict):
summarized_messages: list[AnyMessage] # 已被压缩/摘要过的消息
context: dict[str, RunningSummary]
# 首次生成摘要
initial_summary_prompt = ChatPromptTemplate.from_template(
"""请阅读以下对话内容,并生成一个简洁的摘要,用于帮助理解对话的主要内容:
对话内容:
{messages}
摘要:"""
)
# 在已有摘要基础上追加新的对话内容,更新摘要
existing_summary_prompt = ChatPromptTemplate.from_template(
"""你之前已经生成了如下摘要:
{existing_summary}
现在,对话继续发展了,请根据新增的对话内容,更新这个摘要,使其覆盖所有关键内容。
新增对话内容:
{messages}
更新后的摘要:"""
)
# 用于最终调用模型之前,将摘要和剩余消息一起传入
final_prompt = ChatPromptTemplate.from_template(
"""你是一位智能助理。以下是用户和你的对话摘要,可帮助你快速理解上下文:
摘要:
{summary}
这是对话中未被总结的新消息,请继续处理这些信息:
{messages}
"""
)
# 创建摘要节点:超过一定 token 数时会对历史消息自动进行摘要
summarization_node = SummarizationNode(
token_counter=count_tokens_approximately, # 使用近似 token 计算
model=summarization_model, # 使用绑定了 max_tokens 的模型
max_tokens=200, # 在进行摘要之前,传给模型的输入上下文的最大token长度限制
max_tokens_before_summary=150, # 超过这个数就会触发摘要
max_summary_tokens=128, # 每次摘要最多保留 128 tokens
initial_summary_prompt=initial_summary_prompt, # 首次生成摘要的提示词
existing_summary_prompt=existing_summary_prompt, # 更新摘要的提示词
final_prompt=final_prompt # 模型回答问题之前参考的摘要上下文的提示词
)
# 模型调用节点:对压缩过的历史消息进行问答
def call_llm(state: LLMInputState):
response = llm.invoke(state["summarized_messages"])
return {
"messages": [response],
"context": state.get("context", {}) # 把上下文原样返回,里面就有摘要
}
# 使用内存存储器(可换成 Redis/Postgres)
checkpointer = InMemorySaver()
# 构建 LangGraph 的流程图
builder = StateGraph(State)
# 添加两个节点:摘要节点 和 模型调用节点
builder.add_node(call_llm)
builder.add_node("summarize", summarization_node)
# 定义边:从 START 开始 → 先摘要 → 再模型调用
builder.add_edge(START, "summarize")
builder.add_edge("summarize", "call_llm")
# 编译图
graph = builder.compile(checkpointer=checkpointer)
# ========== 流程调用 ==========
config = {"configurable": {"thread_id": "1"}} # 每个线程维护一个上下文
# 第1轮:告诉模型「我叫小明」
graph.invoke({"messages": "你好,我叫初见"}, config)
# 第2轮:要求写一首猫的诗
graph.invoke({"messages": "请写一首关于猫的诗"}, config)
# 第3轮:让它对狗做一样的事
graph.invoke({"messages": "现在也请为狗写一首诗"}, config)
# 第4轮:问它「我叫什么名字?」
final_response = graph.invoke({"messages": "你还记得我叫什么名字吗?"}, config)
# 输出最终回复
final_response["messages"][-1].pretty_print()
# 输出摘要内容(短期记忆)
print("\n摘要记忆内容(summary):", final_response)
工具
工具封装了可调用函数及其输入模式。这些可以传递给兼容的聊天模型,让模型决定是否调用工具以及使用哪些参数。
预建工具
LangChain 为常见的外部系统(包括 API、数据库、文件系统和 Web 数据)提供预构建的工具集成。
浏览集成目录以查找可用的工具。
常见类别:
- 搜索:Bing、SerpAPI、Tavily
- 代码执行:Python REPL、Node.js REPL
- 数据库:SQL、MongoDB、Redis
- Web 数据:抓取和浏览
- API:OpenWeatherMap、NewsAPI 等。
自定义工具
使用 @tool 装饰器来定义工具
from langchain_core.tools import tool
@tool
def multiply(a: int, b: int) -> int:
"""将两个数目相乘."""
return a * b
# 运行工具
print(multiply.invoke({"a": 6, "b": 7})) # returns 42
tool_call = {
"type": "tool_call",
"id": "1",
"args": {"a": 42, "b": 7}
}
print(multiply.invoke(tool_call))
print("=" * 8, "在代理中使用", "=" * 8)
from langgraph.prebuilt import create_react_agent
from langchain.chat_models import init_chat_model
import os
from dotenv import load_dotenv
load_dotenv()
llm = init_chat_model(api_key=os.getenv("DASHSCOPE_API_KEY"),
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
model_provider="openai",
model='qwen-plus-2025-01-25')
agent = create_react_agent(
model=llm,
tools=[multiply]
)
print(agent.invoke({"messages": [{"role": "user", "content": "100乘以37等于多少?"}]}))
print("=" * 8, "在工作流中使用", "=" * 8)
from langgraph.prebuilt import ToolNode
from langgraph.graph import StateGraph, MessagesState, START, END
from langchain_tavily import TavilySearch
@tool
def tavily_search_tool(query: str) -> str:
"""这是一个搜索工具"""
tool_instance = TavilySearch()
return tool_instance.run(query)
# 执行工具的节点
tool_node = ToolNode([tavily_search_tool])
# 绑定工具到模型
model_with_tools = llm.bind_tools([tavily_search_tool])
def should_continue(state: MessagesState):
messages = state["messages"]
last_message = messages[-1]
if last_message.tool_calls:
return "tools"
return END
def call_model(state: MessagesState):
messages = state["messages"]
response = model_with_tools.invoke(messages)
return {"messages": [response]}
builder = StateGraph(MessagesState)
# 定义节点和边
builder.add_node("call_model", call_model)
builder.add_node("tools", tool_node)
builder.add_edge(START, "call_model")
builder.add_conditional_edges("call_model", should_continue, ["tools", END])
builder.add_edge("tools", "call_model")
graph = builder.compile()
print(graph.invoke({"messages": [{"role": "user", "content": "上海的天气?"}]}))
[!TIP]
用户输入:“上海的天气?”
↓
🧠 call_model 节点
- 使用绑定了工具的 LLM(model_with_tools)调用模型
- 模型判断是否需要工具(如搜索)
- 如果有 tool_calls,走向 tools 节点
- 否则直接结束
↓
🔧 tools 节点(ToolNode) - 检查模型返回是否有 tool_calls(如 tavily_search_tool)
- 执行对应工具(如 TavilySearch.run(query))
- 工具执行结果构建成 AI Message 返回
↓
🧠 再次回到 call_model(循环处理新消息) - 把 tool_result 消息输入给模型,继续推理
↓
✅ END(当模型不再调用工具,流程结束)
工具定制
参数说明
from langchain_core.tools import tool
@tool("multiply_tool", parse_docstring=True)
def multiply(a: int, b: int) -> int:
"""Multiply two numbers.
Args:
a: First operand
b: Second operand
"""
return a * b
显式输入模式
from pydantic import BaseModel, Field
from langchain_core.tools import tool
class MultiplyInputSchema(BaseModel):
"""Multiply two numbers"""
# 通过创建一个类继承BaseModel去作为参数验证和工具描述
a: int = Field(description="First operand")
b: int = Field(description="Second operand")
@tool("multiply_tool", args_schema=MultiplyInputSchema)
def multiply(a: int, b: int) -> int:
return a * b
上下文管理
LangGraph 中的工具有时需要上下文数据,例如仅在运行时使用的参数(例如,用户 ID 或会话详细信息),这些数据不应由模型控制。LangGraph 提供了三种方法来管理此类上下文:
配置
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
from langgraph.prebuilt import create_react_agent
from langchain.chat_models import init_chat_model
import os
from dotenv import load_dotenv
load_dotenv()
llm = init_chat_model(api_key=os.getenv("DASHSCOPE_API_KEY"),
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
model_provider="openai",
model='qwen-plus-2025-01-25')
@tool
def get_user_info(
config: RunnableConfig,
) -> str:
"""查找用户信息."""
user_id = config["configurable"].get("user_id")
return "用户的名字是初见" if user_id == "user_123" else "Unknown user"
# 创建Agent
agent = create_react_agent(
model=llm,
tools=[get_user_info],
)
response = agent.invoke(
{"messages": [{"role": "user", "content": "查询用户信息"}]},
config={"configurable": {"user_id": "user_123"}}
)
print(response)
短期记忆
短期记忆保持在单次执行期间发生变化的动态状态
from typing import Annotated, NotRequired
from langgraph.prebuilt import InjectedState, create_react_agent
from langgraph.prebuilt.chat_agent_executor import AgentState
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool, InjectedToolCallId
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.types import Command
from langchain.chat_models import init_chat_model
import os
from dotenv import load_dotenv
load_dotenv()
llm = init_chat_model(api_key=os.getenv("DASHSCOPE_API_KEY"),
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
model_provider="openai",
model='qwen-plus-2025-01-25')
class CustomState(AgentState):
# user_name字段处于短期状态
user_name: NotRequired[str]
# 只有通过 InjectedState,LangGraph 才会自动把当前运行状态(CustomState)注入到工具函数中。
@tool
def get_user_name(
state: Annotated[CustomState, InjectedState]
) -> str:
"""从state中检索当前用户名。"""
# 返回存储的名称,如果未设置则返回默认值
return state.get("user_name", "Unknown user")
@tool
def update_user_name(
new_name: str,
# 注入当前工具调用ID(LLM会忽略此参数,用于工具消息追踪)
tool_call_id: Annotated[str, InjectedToolCallId]
) -> Command:
"""更新短期记忆中的用户名"""
return Command(update={
"user_name": new_name,
"messages": [
ToolMessage(content=f"名字已更新为:{new_name}", tool_call_id=tool_call_id)
]
})
# 创建内存持久化器
checkpointer = InMemorySaver()
# 创建代理
agent = create_react_agent(
model=llm,
tools=[get_user_name, update_user_name],
state_schema=CustomState,
checkpointer=checkpointer,
)
config = {"configurable": {"thread_id": "1"}}
# 第一次:更新用户名(模型会调用update_user_name工具)
response = agent.invoke({"messages": [{"role": "user", "content": "我的名字是初见"}]}, config)
response["messages"][-1].pretty_print()
# 第二次:获取用户名(调用get_user_name工具)
response = agent.invoke({"messages": [{"role": "user", "content": "我的名字是什么?"}]}, config)
response["messages"][-1].pretty_print()
长期记忆
使用长期记忆来存储对话中特定于用户或应用程序的数据。这对于像聊天机器人这样的应用程序非常有用。
要使用长期记忆,需要:
- 配置存储以在调用之间保留数据。
- 使用该
get_store
功能从工具或提示中访问 store。
from langchain_core.runnables import RunnableConfig
from typing_extensions import TypedDict
from langchain_core.tools import tool
from langgraph.config import get_store
from langgraph.prebuilt import create_react_agent
from langgraph.store.memory import InMemoryStore
from langgraph.checkpoint.memory import InMemorySaver
from langchain.chat_models import init_chat_model
import os
from dotenv import load_dotenv
load_dotenv()
llm = init_chat_model(api_key=os.getenv("DASHSCOPE_API_KEY"),
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
model_provider="openai",
model='qwen-plus-2025-01-25')
# 创建内存存储对象
store = InMemoryStore()
# 创建内存持久化器
checkpointer = InMemorySaver()
# 存储初始化用户信息
store.put(
("users",),
"user_123",
{
"name": "初见",
"language": "中文",
}
)
class UserInfo(TypedDict):
name: str
language: str
@tool
def update_user_info(user_info: UserInfo, config: RunnableConfig) -> str:
"""更新用户信息"""
print("工具被调用,接收到的 user_info:", user_info)
store = get_store()
user_id = config["configurable"].get("user_id")
store.put(("users",), user_id, user_info)
return "用户信息更新成功"
@tool
def get_user_info(config: RunnableConfig) -> str:
"""查找用户信息."""
store = get_store()
user_id = config["configurable"].get("user_id")
user_info = store.get(("users",), user_id)
return str(user_info.value) if user_info else "Unknown user"
# 创建代理
agent = create_react_agent(
model=llm,
tools=[update_user_info, get_user_info],
checkpointer=checkpointer,
store=store
)
config = {"configurable": {"thread_id": "1", "user_id": "user_123"}}
# 运行代理
response = agent.invoke(
{"messages": [{"role": "user", "content": "查询用户信息"}]},
config=config
)
response["messages"][-1].pretty_print()
config = {"configurable": {"thread_id": "2", "user_id": "user_123"}}
# 运行代理
response = agent.invoke(
{"messages": [{"role": "user", "content": "查询用户信息"}]},
config=config
)
response["messages"][-1].pretty_print()
# 运行代理-更新用户信息
agent.invoke(
{"messages": [{"role": "user", "content": "我的名字叫李铭,使用的语言是西班牙语"}]},
config=config
)
# 运行代理
response = agent.invoke(
{"messages": [{"role": "user", "content": "查询用户姓名和使用语言"}]},
config=config
)
response["messages"][-1].pretty_print()
人机交互
要在代理或工作流中审核、编辑和批准工具调用,请使用中断来暂停图表并等待人工输入。中断使用 LangGraph 的持久层(该层会保存图表状态)无限期暂停图表执行,直到恢复为止。
暂停使用 interrupt
动态中断(也称为动态断点)根据图表的当前状态触发。您可以通过在适当的位置调用 interrupt
函数来设置动态中断。图表将暂停,以便人工干预,然后根据人工输入恢复图表。这对于审批、编辑或收集其他上下文等任务非常有用。
from typing import TypedDict
import uuid
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.constants import START
from langgraph.graph import StateGraph
from langgraph.types import interrupt, Command
# 1. 定义状态结构,LangGraph 的每个节点输入输出都以它为标准
class State(TypedDict):
some_text: str # 保存一段文本
# 2. 定义一个“人工参与节点”
def human_node(state: State):
# 使用 interrupt 让图在这里暂停,并把 current state 输出出去, value是用户选择的内容
value = interrupt(
{
"是否同意进行下一步": "1.同意,2.拒绝" # 提示要修改的文本
}
)
if not value:
return {
"some_text": "拒绝请假" # 更新状态为人工输入后的新文本
}
# interrupt 实际上在 resume 之后会把 resume 的内容作为 value
return {
"some_text": "成功请假" # 更新状态为人工输入后的新文本
}
# 3. 构建图
graph_builder = StateGraph(State) # 指定状态类型
graph_builder.add_node("human_node", human_node) # 添加节点
graph_builder.add_edge(START, "human_node") # 从起点进入这个节点
# 4. 启用内存检查点(用于保存每个线程的执行状态)
checkpointer = InMemorySaver()
graph = graph_builder.compile(checkpointer=checkpointer) # 编译成可执行图
# 5. 设置一个线程ID,便于在中断后恢复同一个会话
config = {"configurable": {"thread_id": uuid.uuid4()}}
# 6. 启动流程,并传入初始文本。运行到中断节点时会暂停
result = graph.invoke({"some_text": "original text", "name": "张三"}, config=config)
# 7. 输出中断返回内容,通常是交给用户编辑或确认
print(result["__interrupt__"])
# 例子输出:{'name': 'human_node', 'args': {'text_to_revise': 'original text'}}
# 8. 模拟用户修改后,通过 resume 恢复图的执行(传入 Command)
print(graph.invoke(Command(resume=False), config=config))
# 输出:{'some_text': 'Edited text'} 表示人工干预后的新状态
[!TIP]
**注意:**就开发者体验而言,中断类似于 Python 的 input() 函数,但它们不会自动从中断点恢复执行。相反,它们会重新运行发生中断的整个节点。因此,中断通常最好放置在节点的起始位置或专用节点中。
多智能体系统
代理是一种使用 LLM 来决定应用程序控制流的系统。随着这些系统的开发,它们可能会随着时间的推移变得更加复杂,从而更难以管理和扩展。例如,您可能会遇到以下问题:
- 代理可以使用的工具太多,无法决定下一步调用哪个工具
- 环境变得过于复杂,单个代理无法跟踪
- 系统中需要多个专业领域(例如规划师、研究员、数学专家等)
为了解决这些问题,您可以考虑将应用程序拆分成多个较小的独立代理,并将它们组合成一个多代理系统。这些独立代理可以像提示符和 LLM 调用一样简单,也可以像 ReAct 代理一样复杂(甚至更多!)。
使用多代理系统的主要好处是:
- 模块化:独立的代理使得代理系统的开发、测试和维护变得更加容易。
- 专业化:您可以创建专注于特定领域的专家代理,这有助于提高整体系统性能。
- 控制:您可以明确控制代理如何通信。
在多代理系统中,有几种连接代理的方法:
- 网络:每个代理都可以与其他代理通信。任何代理都可以决定接下来要呼叫哪个代理。
- 主管代理:每个代理只与一个主管代理进行通信。主管代理负责决定接下来应该调用哪个代理。
- 主管(工具调用):这是主管架构的一个特例。单个代理可以表示为工具。在这种情况下,主管代理使用工具调用 LLM 来决定调用哪些代理工具,以及传递给这些代理的参数。
- 分层结构:你可以定义一个多智能体系统,其中包含多个主管的主管。这是主管架构的泛化,允许更复杂的控制流。
- 自定义多代理工作流:每个代理仅与一部分代理进行通信。流程的某些部分是确定性的,只有部分代理可以决定接下来要调用哪些其他代理。
交接(Handoffs)
交接概念
在多智能体架构中,智能体可以表示为图节点。每个智能体节点执行其步骤,并决定是完成执行还是路由至其他智能体,包括可能路由至自身(例如,循环运行)。多智能体交互中一种常见的模式是切换,即一个智能体将控制权移交给另一个智能体
基础交接示例
在线订餐系统
场景描述:顾客下单 → 订单处理 → 支付处理 → 配送安排
import os
import json
import random
from typing import Literal, TypedDict
from langgraph.types import Command
from langgraph.graph import StateGraph, START, END
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
import os
from dotenv import load_dotenv
load_dotenv()
# 初始化大模型
llm = ChatOpenAI(
api_key=os.getenv("DASHSCOPE_API_KEY"),
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
model='qwen-plus-2025-01-25',
temperature=0.7
)
def call_llm(system_prompt: str, user_message: str, temperature: float = 0.7) -> str:
"""
真实的LLM调用函数
"""
try:
messages = [
SystemMessage(content=system_prompt),
HumanMessage(content=user_message)
]
# 设置温度参数
llm.temperature = temperature
response = llm.invoke(messages)
print(f"LLM调用成功")
print(f"系统提示: {system_prompt[:50]}...")
print(f"用户输入: {user_message[:50]}...")
print(f"AI回复: {response.content[:100]}...")
return response.content
except Exception as e:
print(f"LLM调用失败: {str(e)}")
return f"LLM调用失败: {str(e)}"
class OrderState(TypedDict):
customer_name: str # 顾客名称
order_items: list # 订单内容
raw_order_text: str # 原始订单文本(用户自然语言输入)
total_amount: float # 订单总价格
payment_status: str # 支付状态
delivery_address: str # 配送地址
order_status: str # 订单状态
messages: list # 消息
llm_analysis: dict # LLM分析结果
def order_receiver(state: OrderState) -> Command[Literal["payment_processor", "order_validator"]]:
"""订单接收智能体 - 使用真实大模型理解自然语言订单"""
raw_text = state.get("raw_order_text", "")
customer_name = state.get("customer_name", "")
order_items = state.get("order_items", [])
print(f"订单接收:{customer_name} 的订单")
llm_analysis = {}
# 使用大模型理解和解析自然语言订单
if raw_text:
system_prompt = """你是一个专业的餐厅订单理解助手。请仔细分析用户的订单文本,提取以下信息:
1. 识别用户想要的具体食物和饮料
2. 判断订单信息是否清晰完整
3. 识别任何特殊要求或备注
4. 评估订单的明确程度
请严格按照以下JSON格式返回结果:
{
"items_identified": ["商品1", "商品2"],
"items_confirmed": true/false,
"confidence_score": 0.0-1.0,
"special_requests": "特殊要求描述",
"clarity_assessment": "订单清晰度评估",
"suggested_clarification": "如果不清晰,建议询问的问题"
}"""
llm_response = call_llm(system_prompt, f"用户订单:{raw_text}", temperature=0.3)
try:
llm_analysis = json.loads(llm_response)
except json.JSONDecodeError:
print("LLM返回的不是有效JSON格式,尝试解析...")
# 如果不是JSON,创建默认分析结果
llm_analysis = {
"items_identified": [],
"items_confirmed": False,
"confidence_score": 0.0,
"clarity_assessment": "解析失败",
"suggested_clarification": "请重新描述您的订单"
}
print(f"LLM分析结果: {llm_analysis}")
# 基于LLM分析决定流程
confidence = llm_analysis.get("confidence_score", 0.0)
items_confirmed = llm_analysis.get("items_confirmed", False)
if confidence < 0.7 or not items_confirmed:
return Command(
goto="order_validator",
update={
"order_status": "需要验证",
"llm_analysis": llm_analysis,
"messages": [
f"订单信息需要确认:{llm_analysis.get('suggested_clarification', '请提供更详细的订单信息')}"]
}
)
# 计算总金额
total = sum(item.get("price", 0) for item in order_items)
# 检查订单有效性
if not order_items or total <= 0:
return Command(
goto="order_validator",
update={
"order_status": "需要验证",
"llm_analysis": llm_analysis,
"messages": ["订单信息不完整,需要验证"]
}
)
else:
return Command(
goto="payment_processor",
update={
"total_amount": total,
"order_status": "待支付",
"llm_analysis": llm_analysis,
"messages": [f"订单确认成功!总金额:{total}元,准备处理支付"]
}
)
def order_validator(state: OrderState) -> Command[Literal["payment_processor", END]]:
"""订单验证智能体 - 使用大模型进行智能验证和建议"""
print("订单验证:使用AI检查订单完整性")
order_items = state.get("order_items", [])
customer_name = state.get("customer_name", "")
llm_analysis = state.get("llm_analysis", {})
raw_order_text = state.get("raw_order_text", "")
# 使用大模型进行订单验证和修正建议
system_prompt = """你是一个专业的订单验证专家。请分析订单信息,判断:
1. 订单是否有效和完整
2. 是否需要修正或补充信息
3. 给出具体的处理建议
4. 如果订单无效,说明原因
请根据分析结果返回JSON格式:
{
"is_valid": true/false,
"validation_score": 0.0-1.0,
"issues_found": ["问题1", "问题2"],
"corrections_needed": ["修正建议1", "修正建议2"],
"processing_recommendation": "继续处理/需要客户确认/取消订单",
"customer_message": "给客户的反馈消息"
}"""
order_info = f"""
客户姓名:{customer_name}
原始订单文本:{raw_order_text}
当前订单项:{order_items}
之前的LLM分析:{llm_analysis}
"""
llm_response = call_llm(system_prompt, order_info, temperature=0.2)
try:
validation_result = json.loads(llm_response)
except json.JSONDecodeError:
print("验证结果解析失败")
validation_result = {
"is_valid": False,
"customer_message": "订单验证过程中出现错误,请重新提交订单"
}
print(f"验证结果: {validation_result}")
# 根据验证结果决定流程
if not validation_result.get("is_valid", False) or validation_result.get("processing_recommendation") == "取消订单":
return Command(
goto=END,
update={
"order_status": "订单取消",
"messages": [validation_result.get("customer_message", "订单验证失败,已取消")]
}
)
# 如果需要客户确认但这里简化处理,假设已确认
corrected_total = sum(item.get("price", 0) for item in order_items) or 50.0
return Command(
goto="payment_processor",
update={
"total_amount": corrected_total,
"order_status": "验证通过,待支付",
"messages": [validation_result.get("customer_message", f"订单验证通过,总金额:{corrected_total}元")]
}
)
def payment_processor(state: OrderState) -> Command[Literal["delivery_scheduler", "order_receiver", END]]:
"""支付处理智能体 - 使用大模型生成个性化支付处理消息"""
total_amount = state.get("total_amount", 0)
customer_name = state.get("customer_name", "")
print(f"支付处理:处理 {customer_name} 的 {total_amount} 元支付")
# 模拟支付处理逻辑
payment_success = random.choice([True, False, "retry"])
if payment_success:
# 使用大模型生成支付成功消息
system_prompt = f"""请生成一条专业、友好的支付成功确认消息。要求:
1. 感谢客户
2. 确认支付金额
3. 告知下一步流程
4. 保持简洁友好的语调
客户姓名:{customer_name}
支付金额:{total_amount}元"""
success_message = call_llm(system_prompt, "生成支付成功消息", temperature=0.5)
return Command(
goto="delivery_scheduler",
update={
"payment_status": "已支付",
"order_status": "待配送",
"messages": [success_message]
}
)
elif payment_success == "retry":
return Command(
goto="order_receiver",
update={
"payment_status": "支付重试",
"messages": ["支付遇到临时问题,正在重新处理您的订单..."]
}
)
else:
# 使用大模型生成个性化的支付失败处理消息
system_prompt = f"""请生成一条专业、同理心的支付失败处理消息。要求:
1. 表达歉意和理解
2. 提供具体的解决方案
3. 给出客服联系方式
4. 保持积极正面的语调
5. 不要让客户感到沮丧
客户姓名:{customer_name}
订单金额:{total_amount}元"""
failure_message = call_llm(system_prompt, "支付失败,需要友好专业的客服回复", temperature=0.6)
return Command(
goto=END,
update={
"payment_status": "支付失败",
"order_status": "订单暂停",
"messages": [failure_message]
}
)
def delivery_scheduler(state: OrderState) -> Command[Literal[END]]:
"""配送安排智能体 - 使用大模型生成个性化配送通知"""
customer_name = state.get("customer_name", "")
delivery_address = state.get("delivery_address", "")
order_items = state.get("order_items", [])
total_amount = state.get("total_amount", 0)
print(f"配送安排:为 {customer_name} 安排配送")
# 🤖 使用大模型生成个性化配送通知
system_prompt = f"""请生成一条专业、详细的配送通知消息。要求:
1. 个性化问候客户
2. 确认订单详情
3. 提供准确的配送信息
4. 包含联系方式和注意事项
5. 语气友好专业
订单信息:
- 客户:{customer_name}
- 配送地址:{delivery_address}
- 订单商品:{[item.get('name', '未知商品') for item in order_items]}
- 订单金额:{total_amount}元"""
delivery_notification = call_llm(system_prompt, "生成专业的配送通知消息", temperature=0.4)
return Command(
goto=END,
update={
"order_status": "配送中",
"messages": [delivery_notification]
}
)
# 构建集成真实大模型的订餐系统
order_builder = StateGraph(OrderState)
order_builder.add_node("order_receiver", order_receiver)
order_builder.add_node("order_validator", order_validator)
order_builder.add_node("payment_processor", payment_processor)
order_builder.add_node("delivery_scheduler", delivery_scheduler)
order_builder.add_edge(START, "order_receiver")
order_system = order_builder.compile()
# 测试运行 - 包含自然语言订单输入
test_order_1 = {
"customer_name": "张三",
"raw_order_text": "我想要一个汉堡,再来一杯可乐,谢谢!",
"order_items": [
{"name": "汉堡", "price": 25.0},
{"name": "可乐", "price": 8.0}
],
"delivery_address": "北京市朝阳区xxx街道",
"messages": [],
"llm_analysis": {}
}
try:
result = order_system.invoke(test_order_1)
print(f"\n最终状态:{result['order_status']}")
print(f"处理消息:")
for msg in result['messages']:
print(f" - {msg}")
if result.get('llm_analysis'):
print(f"LLM分析结果:{result['llm_analysis']}")
except Exception as e:
print(f"系统执行出错: {str(e)}")
需要从子图中回到父图
def some_node_inside_alice(state):
return Command(
goto="bob",
update={"my_state_key": "my_state_value"},
# 需要在子图中添加下面代码,代表我要回到父图
graph=Command.PARENT,
)
交接的关键要点
- 交接时机
- 任务超出当前智能体能力范围
- 需要专业化处理
- 错误处理和重试机制
- 工作流程的自然转换点
2.交接信息
# 完整的交接信息示例
return Command(
goto="target_agent", # 目标智能体
update={ # 状态更新
"handoff_reason": "专业需求", # 交接原因
"context": "相关上下文", # 上下文信息
"priority": "high", # 优先级
"deadline": "2024-01-01" # 截止时间
},
graph=Command.PARENT # 图级别(可选)
)
工具进行交接
from typing import Annotated
from typing_extensions import Literal
from langchain_core.tools import tool
from langchain_core.tools.base import InjectedToolCallId
from langchain_core.messages import ToolMessage, convert_to_messages
from langgraph.prebuilt import InjectedState
from langgraph.types import Command
from langgraph.graph import MessagesState, StateGraph, START, END
from langchain_openai import ChatOpenAI
import os
from dotenv import load_dotenv
load_dotenv()
# 初始化大模型
llm = ChatOpenAI(
api_key=os.getenv("DASHSCOPE_API_KEY"),
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
model='qwen-plus-2025-01-25',
temperature=0.7
)
def make_handoff_tool(*, agent_name: str):
"""
创建一个工具交接函数,用于在代理之间进行转接
Args:
agent_name (str): 目标代理的名称
Returns:
tool: 返回一个可以执行代理转接的工具函数
"""
# 根据目标代理名称动态生成工具名称
tool_name = f"transfer_to_{agent_name}"
@tool(tool_name)
def handoff_to_agent(
# 注入当前图状态(LLM会忽略此参数,但工具内部可以使用)
state: Annotated[dict, InjectedState],
# 注入当前工具调用ID(LLM会忽略此参数,用于工具消息追踪)
tool_call_id: Annotated[str, InjectedToolCallId],
):
"""请求另一个代理的帮助进行任务交接"""
# 创建工具响应消息,表示成功转接到目标代理
tool_message = {
"role": "tool",
"content": f"成功转接到 {agent_name} 代理",
"name": tool_name,
"tool_call_id": tool_call_id,
}
# 返回Command对象,用于导航到父图中的另一个代理节点
return Command(
# 导航到目标代理节点
goto=agent_name,
# 在父图中执行导航
graph=Command.PARENT,
# 更新状态:将完整的消息历史传递给目标代理,并添加工具消息
# 这确保了聊天历史的完整性和有效性
update={"messages": state["messages"] + [tool_message]},
)
return handoff_to_agent
def make_agent(model, tools, system_prompt=None):
"""
创建一个智能代理,能够使用工具并在需要时进行代理转接
Args:
model: 语言模型实例
tools: 代理可用的工具列表
system_prompt: 系统提示词,定义代理的角色和行为
Returns:
compiled_graph: 编译后的代理图
"""
# 将工具绑定到模型上
model_with_tools = model.bind_tools(tools)
# 创建工具名称到工具对象的映射,便于快速查找
tools_by_name = {tool.name: tool for tool in tools}
def call_model(state: MessagesState) -> Command[Literal["call_tools", END]]:
"""
调用语言模型生成响应
Args:
state: 当前消息状态
Returns:
Command: 如果需要调用工具则转到call_tools,否则结束
"""
messages = state["messages"]
# 如果有系统提示词,将其添加到消息开头
if system_prompt:
messages = [{"role": "system", "content": system_prompt}] + messages
# 调用绑定了工具的模型
response = model_with_tools.invoke(messages)
# 检查模型是否决定使用工具
if len(response.tool_calls) > 0:
# 如果有工具调用,转到工具执行节点
return Command(goto="call_tools", update={"messages": [response]})
# 如果没有工具调用,直接返回响应消息
return {"messages": [response]}
def call_tools(state: MessagesState) -> Command[Literal["call_model"]]:
"""
执行工具调用
Args:
state: 当前消息状态
Returns:
list[Command]: 工具执行结果的命令列表
"""
# 获取最后一条消息中的所有工具调用
tool_calls = state["messages"][-1].tool_calls
results = []
# 逐个执行工具调用
for tool_call in tool_calls:
# 根据工具名称获取对应的工具对象
tool_ = tools_by_name[tool_call["name"]]
# 获取工具的输入参数架构
tool_input_fields = tool_.get_input_schema().model_json_schema()[
"properties"
]
# 检查工具是否需要状态注入(简化版实现)
if "state" in tool_input_fields:
# 如果工具需要状态,将当前状态注入到工具参数中
tool_call = {**tool_call, "args": {**tool_call["args"], "state": state}}
# 执行工具调用
tool_response = tool_.invoke(tool_call)
# 处理不同类型的工具响应
if isinstance(tool_response, ToolMessage):
# 标准工具消息响应
results.append(Command(update={"messages": [tool_response]}))
elif isinstance(tool_response, Command):
# 直接返回Command对象的工具(如转接工具)
results.append(tool_response)
else:
# 普通响应,转换为工具消息
tool_message = ToolMessage(
content=str(tool_response),
tool_call_id=tool_call["id"]
)
results.append(Command(update={"messages": [tool_message]}))
# 返回所有工具执行结果
return results
# 构建代理的内部图结构
graph = StateGraph(MessagesState)
# 添加模型调用节点和工具调用节点
graph.add_node("call_model", call_model)
graph.add_node("call_tools", call_tools)
# 设置图的边:从开始到模型调用,从工具调用回到模型调用
graph.add_edge(START, "call_model")
graph.add_edge("call_tools", "call_model")
# 编译并返回图
return graph.compile()
def pretty_print_messages(update):
"""
美化打印消息更新,用于调试和展示
Args:
update: 图更新信息,可能是元组或字典
"""
# 检查是否是来自子图的更新
if isinstance(update, tuple):
ns, update = update
# 跳过父图更新的打印
if len(ns) == 0:
return
# 提取子图ID并打印
graph_id = ns[-1].split(":")[0]
print(f"来自子图 {graph_id} 的更新:")
print()
# 遍历所有节点更新
for node_name, node_update in update.items():
print(f"来自节点 {node_name} 的更新:")
print()
# 美化打印所有消息
if "messages" in node_update:
for m in convert_to_messages(node_update["messages"]):
m.pretty_print()
print()
# ============= 定义数学工具 =============
@tool
def add(a: int, b: int) -> int:
"""执行两个数字的加法运算"""
result = a + b
print(f"执行加法: {a} + {b} = {result}")
return result
@tool
def multiply(a: int, b: int) -> int:
"""执行两个数字的乘法运算"""
result = a * b
print(f"执行乘法: {a} × {b} = {result}")
return result
@tool
def subtract(a: int, b: int) -> int:
"""执行两个数字的减法运算"""
result = a - b
print(f"执行减法: {a} - {b} = {result}")
return result
@tool
def divide(a: int, b: int) -> float:
"""执行两个数字的除法运算"""
if b == 0:
return "错误:不能除以零"
result = a / b
print(f"执行除法: {a} ÷ {b} = {result}")
return result
# ============= 演示单个代理 =============
def demo_single_agent():
"""演示单个具有所有数学工具的代理"""
print("=" * 60)
print("演示:单个数学代理")
print("=" * 60)
# 创建一个拥有所有数学工具的代理
math_agent = make_agent(
llm,
[add, multiply, subtract, divide],
system_prompt="你是一个数学专家,可以执行各种数学运算。请一步步解决问题。"
)
print("问题: 计算 (3 + 5) × 12")
print()
# 运行代理并显示结果
for chunk in math_agent.stream({"messages": [("user", "计算 (3 + 5) × 12")]}):
pretty_print_messages(chunk)
# ============= 演示多代理协作 =============
def demo_multi_agent_collaboration():
"""演示多个专业代理之间的协作"""
print("=" * 60)
print("演示:多代理协作系统")
print("=" * 60)
# 创建加法专家代理
addition_expert = make_agent(
llm,
[add, subtract, make_handoff_tool(agent_name="multiplication_expert")],
system_prompt="""你是加法和减法专家。你精通加法和减法运算。
当你完成加法或减法运算后,如果后续还需要乘法或除法运算,
请立即使用 transfer_to_multiplication_expert 工具转接给乘法专家。
不要尝试自己完成乘法运算。"""
)
# 创建乘法专家代理
multiplication_expert = make_agent(
llm,
[multiply, divide, make_handoff_tool(agent_name="addition_expert")],
system_prompt="""你是乘法和除法专家。你精通乘法和除法运算。
当你接收到需要乘法运算的任务时,请立即执行乘法运算。
如果后续还需要加法或减法运算,请转接给加法专家。
当前任务:执行乘法运算并给出最终答案。"""
)
# 构建多代理协作图
builder = StateGraph(MessagesState)
# 添加两个专家代理节点
builder.add_node("addition_expert", addition_expert)
builder.add_node("multiplication_expert", multiplication_expert)
# 设置入口点为加法专家
builder.add_edge(START, "addition_expert")
# 编译协作图
collaboration_graph = builder.compile()
print("问题: 计算 (3 + 5) × 12")
print("加法专家将处理加法,然后转接给乘法专家处理乘法")
print()
# 运行协作图并显示子图中的所有更新
for chunk in collaboration_graph.stream(
{"messages": [("user", "请计算 (3 + 5) × 12")]},
subgraphs=True # 包含子图更新
):
pretty_print_messages(chunk)
# ============= 更复杂的协作示例 =============
def demo_complex_collaboration():
"""演示更复杂的多步骤协作"""
print("=" * 60)
print("演示:复杂多步协作")
print("=" * 60)
# 创建基础运算专家
basic_math_expert = make_agent(
llm,
[add, subtract, make_handoff_tool(agent_name="advanced_math_expert")],
system_prompt="""你是基础数学专家,专门处理加法和减法。
对于乘法、除法等高级运算,请转接给高级数学专家。
你需要先处理括号内的基础运算。"""
)
# 创建高级运算专家
advanced_math_expert = make_agent(
llm,
[multiply, divide, make_handoff_tool(agent_name="basic_math_expert")],
system_prompt="""你是高级数学专家,专门处理乘法和除法。
对于加法、减法等基础运算,请转接给基础数学专家。
你负责处理复杂的乘除运算。"""
)
# 构建协作图
builder = StateGraph(MessagesState)
builder.add_node("basic_math_expert", basic_math_expert)
builder.add_node("advanced_math_expert", advanced_math_expert)
builder.add_edge(START, "basic_math_expert")
complex_graph = builder.compile()
print("复杂问题: 计算 ((10 + 5) × 3 - 8) ÷ 2")
print("将需要多次代理转接来完成计算")
print()
for chunk in complex_graph.stream(
{"messages": [("user", "请逐步计算 ((10 + 5) × 3 - 8) ÷ 2")]},
subgraphs=True
):
pretty_print_messages(chunk)
# ============= 主程序入口 =============
def main():
"""主程序,运行所有演示"""
print("LangGraph工具交接案例演示")
print("展示单代理和多代理协作的数学计算系统")
print()
try:
# 演示1:单个代理
# demo_single_agent()
#
# print("\n" + "-" * 20 + "\n")
# 演示2:多代理协作
demo_multi_agent_collaboration()
# print("\n" + "-" * 20 + "\n")
#
# # 演示3:复杂协作
# demo_complex_collaboration()
except Exception as e:
print(f"运行出错: {e}")
if __name__ == "__main__":
main()
如何构建多智能体应用
自定义主管架构
from typing import Literal, Annotated, Dict, Any, List
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_openai import ChatOpenAI
from langgraph.types import Command
from langchain_core.tools import tool, InjectedToolCallId
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.prebuilt import InjectedState, create_react_agent
import json
import os
from dotenv import load_dotenv
load_dotenv()
# 模拟数据
KNOWLEDGE_BASE = {
"login": "请清除浏览器缓存并重新登录,或重置密码。",
"payment": "请检查银行卡余额,确认交易状态,或联系银行。",
"bug": "我们已记录此问题,技术团队将在24小时内处理。",
"network": "请检查网络连接,或尝试切换网络环境。",
"performance": "建议清理缓存、重启应用或检查系统资源使用情况。"
}
PRODUCTS = {
"basic": {"name": "基础版", "price": 99, "features": ["基础功能", "邮件支持"]},
"pro": {"name": "专业版", "price": 299, "features": ["高级功能", "优先支持", "API访问"]},
"enterprise": {"name": "企业版", "price": 999, "features": ["企业功能", "专属客服", "定制开发"]}
}
USER_DATABASE = {
"user123": {"plan": "pro", "status": "active", "support_level": "premium", "balance": 500},
"user456": {"plan": "basic", "status": "active", "support_level": "standard", "balance": 100}
}
TICKET_SYSTEM = [] # 模拟工单系统
# 扩展状态定义
class CustomerServiceState(MessagesState):
current_agent: str # 目前智能体
customer_id: str # 用户id
issue_type: str # 问题类别
priority: str # 优先级
resolution_status: str # 解决状态
routing_reason: str # 路由原因
ticket_id: str # 工单id
remaining_steps: int #
# ================================
# 为ReAct Agent定义工具函数
# ================================
@tool
def search_knowledge_base(query: str) -> str:
_"""搜索技术知识库,查找解决方案"""_
_ _query_lower = query.lower()
results = []
for issue, solution in KNOWLEDGE_BASE.items():
if issue in query_lower:
results.append(f"{issue}: {solution}")
if results:
return f"找到以下解决方案:\n" + "\n".join(results)
else:
return "未在知识库中找到相关解决方案,建议创建技术工单进行人工处理。"
@tool
def get_product_info(product_query: str) -> str:
_"""获取产品信息和价格"""_
_ _if not product_query:
# 返回所有产品信息
result = "我们的产品线包括:\n\n"
for key, product in PRODUCTS.items():
result += f"**{product['name']}** - ¥{product['price']}/月\n"
result += f"功能: {', '.join(product['features'])}\n\n"
return result
# 搜索特定产品
query_lower = product_query.lower()
for key, product in PRODUCTS.items():
if key in query_lower or product['name'] in query_lower:
return f"**{product['name']}**\n价格: ¥{product['price']}/月\n功能: {', '.join(product['features'])}"
return f"未找到关于'{product_query}'的产品信息。请查看我们的完整产品列表。"
@tool
def get_user_account_info(user_id: str) -> str:
_"""查询用户账户信息"""_
_ _if not user_id:
return "请提供您的用户ID以查询账户信息。"
if user_id in USER_DATABASE:
user_info = USER_DATABASE[user_id]
return f"""账户信息:
用户ID: {user_id}
当前套餐: {user_info['plan']}
账户状态: {user_info['status']}
支持级别: {user_info['support_level']}
账户余额: ¥{user_info['balance']}"""
else:
return f"未找到用户ID '{user_id}' 的账户信息。请检查用户ID是否正确。"
@tool
def create_support_ticket(issue_description: str, priority: str,
state: Annotated[Dict, InjectedState]) -> str:
_"""创建技术支持工单"""_
_ _import uuid
import datetime
ticket_id = f"TICKET-{str(uuid.uuid4())[:8].upper()}"
customer_id = state.get("customer_id", "anonymous")
ticket = {
"id": ticket_id,
"customer_id": customer_id,
"issue": issue_description,
"priority": priority,
"status": "open",
"created_at": datetime.datetime.now() # 模拟时间
}
TICKET_SYSTEM.append(ticket)
# 更新状态
state["ticket_id"] = ticket_id
return f"已创建支持工单: {ticket_id}\n问题描述: {issue_description}\n优先级: {priority}\n我们的技术团队将在24小时内处理您的问题。"
@tool
def calculate_upgrade_cost(current_plan: str, target_plan: str) -> str:
_"""计算升级费用"""_
_ _if current_plan not in PRODUCTS or target_plan not in PRODUCTS:
return "无效的套餐类型。请检查套餐名称。"
current_price = PRODUCTS[current_plan]["price"]
target_price = PRODUCTS[target_plan]["price"]
if target_price <= current_price:
return f"目标套餐 ({PRODUCTS[target_plan]['name']}) 价格不高于当前套餐 ({PRODUCTS[current_plan]['name']}),无需升级费用。"
upgrade_cost = target_price - current_price
return f"""升级费用计算:
当前套餐: {PRODUCTS[current_plan]['name']} (¥{current_price}/月)
目标套餐: {PRODUCTS[target_plan]['name']} (¥{target_price}/月)
升级费用: ¥{upgrade_cost}/月
新增功能: {', '.join(set(PRODUCTS[target_plan]['features']) - set(PRODUCTS[current_plan]['features']))}"""
@tool
def process_refund_request(user_id: str, reason: str) -> str:
_"""处理退款请求"""_
_ _if not user_id or user_id not in USER_DATABASE:
return "请提供有效的用户ID以处理退款请求。"
user_info = USER_DATABASE[user_id]
if user_info["status"] != "active":
return "只有活跃账户才能申请退款。"
# 模拟退款处理
refund_amount = PRODUCTS[user_info["plan"]]["price"]
return f"""退款申请已提交:
用户ID: {user_id}
退款原因: {reason}
退款金额: ¥{refund_amount}
处理时间: 3-5个工作日
退款将原路返回到您的支付账户。"""
# ================================
# 使用create_react_agent创建专业Agent
# ================================
def create_llm_router_with_react_agents():
_"""创建使用ReAct Agent的智能路由系统"""_
_ _llm = ChatOpenAI(
api_key=os.getenv("DASHSCOPE_API_KEY"),
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
model='qwen-plus-2025-01-25',
temperature=0.1
)
# 1. 创建技术支持ReAct Agent
tech_tools = [search_knowledge_base, create_support_ticket]
tech_system_prompt = """你是一个专业的技术支持工程师。你的任务是:
1. 使用search_knowledge_base工具搜索已知问题的解决方案
2. 如果知识库中没有解决方案,使用create_support_ticket创建工单
3. 提供清晰的步骤指导
4. 对于复杂问题,建议用户联系高级技术支持
请始终保持专业、耐心的态度,并确保用户理解每个步骤。"""
tech_agent = create_react_agent(
llm,
tech_tools,
state_schema=CustomerServiceState,
prompt=tech_system_prompt
)
# 2. 创建销售ReAct Agent
sales_tools = [get_product_info, calculate_upgrade_cost]
sales_system_prompt = """你是一个专业的销售顾问。你的任务是:
1. 使用get_product_info工具提供详细的产品信息
2. 使用calculate_upgrade_cost工具帮助客户计算升级费用
3. 根据客户需求推荐合适的产品
4. 解答价格和功能相关问题
请保持热情、专业的态度,关注客户真实需求,提供有价值的建议。"""
sales_agent = create_react_agent(
llm,
sales_tools,
state_schema=CustomerServiceState,
prompt=sales_system_prompt
)
# 3. 创建客户管理ReAct Agent
admin_tools = [get_user_account_info, process_refund_request]
admin_system_prompt = """你是一个客户管理专员。你的任务是:
1. 使用get_user_account_info工具查询客户账户信息
2. 使用process_refund_request工具处理退款申请
3. 处理账户相关问题和权限调整
4. 确保客户数据的安全和隐私
请保持严谨、负责的态度,严格遵守数据保护规定。"""
admin_agent = create_react_agent(
llm,
admin_tools,
state_schema=CustomerServiceState,
prompt=admin_system_prompt
)
# 4. 创建智能路由器
def llm_router(state: CustomerServiceState) -> Command[Literal["tech_agent", "sales_agent", "admin_agent", END]]:
_"""LLM驱动的智能路由器"""_
_ _messages = state["messages"]
if not messages:
return Command(goto=END)
last_message = messages[-1].content
routing_prompt = f"""分析用户问题并决定路由到哪个专业部门。
用户问题: {last_message}
部门说明:
- tech_agent: 技术问题、Bug、登录异常、系统错误、性能问题
- sales_agent: 产品咨询、价格询问、功能对比、升级服务、购买流程
- admin_agent: 账户查询、权限管理、退款申请、账单问题、用户资料
返回JSON格式:
{{
"route_to": "部门代码",
"confidence": "置信度(0-1)",
"reason": "路由原因",
"priority": "优先级(low/normal/high/urgent)"
}}"""
try:
response = llm.invoke([SystemMessage(content=routing_prompt)])
decision = json.loads(response.content.strip())
route_to = decision.get("route_to", "tech_agent")
confidence = decision.get("confidence", 0.8)
reason = decision.get("reason", "智能路由分析")
priority = decision.get("priority", "normal")
print(f"🧠 路由决策: {route_to} | 置信度: {confidence} | 原因: {reason}")
return Command(
goto=route_to,
update={
"current_agent": route_to,
"issue_type": route_to.replace("_agent", ""),
"priority": priority,
"routing_reason": reason
}
)
except Exception as e:
print(f"⚠️ 路由失败,使用默认路由: {e}")
return Command(
goto="tech_agent",
update={
"current_agent": "tech_agent",
"routing_reason": "路由失败,默认技术支持"
}
)
# 5. 构建图结构
builder = StateGraph(CustomerServiceState)
# 添加节点
builder.add_node("llm_router", llm_router)
builder.add_node("tech_agent", tech_agent)
builder.add_node("sales_agent", sales_agent)
builder.add_node("admin_agent", admin_agent)
# 设置入口
builder.add_edge(START, "llm_router")
return builder.compile()
# ================================
# 纯ReAct Agent架构 - 监督者也是ReAct Agent-工具进行交接
# ================================
def create_pure_react_system():
_"""创建纯ReAct Agent架构 - 连监督者都是ReAct Agent"""_
_ _llm = ChatOpenAI(
api_key=os.getenv("DASHSCOPE_API_KEY"),
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
model='qwen-plus-2025-01-25',
temperature=0.2
)
# 为监督者创建路由工具
@tool
def route_to_technical_support(issue_description: str, state: Annotated[dict, InjectedState],
tool_call_id: Annotated[str, InjectedToolCallId]) -> Command:
_"""将问题路由到技术支持部门"""_
_ _tool_message = {
"role": "tool",
"content": f"Successfully transferred to tech_agent",
"name": "tech_agent",
"tool_call_id": tool_call_id,
}
return Command(
goto="tech_agent",
# 在父图中执行导航
graph=Command.PARENT,
update={
"current_agent": "tech_agent",
"issue_type": "technical",
"routing_reason": f"技术问题: {issue_description}",
"messages":
state["messages"] + [tool_message]
}
)
@tool
def route_to_sales_department(inquiry_description: str, state: Annotated[dict, InjectedState],
tool_call_id: Annotated[str, InjectedToolCallId]) -> Command:
_"""将咨询路由到销售部门"""_
_ _tool_message = {
"role": "tool",
"content": f"Successfully transferred to sales_agent",
"name": "sales_agent",
"tool_call_id": tool_call_id,
}
return Command(
goto="sales_agent",
# 在父图中执行导航
graph=Command.PARENT,
update={
"current_agent": "sales_agent",
"issue_type": "sales",
"routing_reason": f"销售咨询: {inquiry_description}",
"messages":
state["messages"] + [tool_message]
}
)
@tool
def route_to_customer_management(request_description: str, state: Annotated[dict, InjectedState],
tool_call_id: Annotated[str, InjectedToolCallId]) -> Command:
_"""将请求路由到客户管理部门"""_
_ _tool_message = {
"role": "tool",
"content": f"Successfully transferred to admin_agent",
"name": "admin_agent",
"tool_call_id": tool_call_id,
}
return Command(
goto="admin_agent",
# 在父图中执行导航
graph=Command.PARENT,
update={
"current_agent": "admin_agent",
"issue_type": "administration",
"routing_reason": f"销售咨询: {request_description}",
"messages":
state["messages"] + [tool_message]
}
)
# 监督者ReAct Agent
supervisor_tools = [route_to_technical_support, route_to_sales_department, route_to_customer_management]
supervisor_prompt = """你是一个智能客服路由监督者。根据用户问题,使用相应的路由工具将问题分配给专业部门:
- route_to_technical_support: 技术问题、系统故障、登录问题、Bug报告
- route_to_sales_department: 产品咨询、价格询问、购买升级、功能对比
- route_to_customer_management: 账户管理、退款申请、权限问题、用户资料
分析用户问题的核心内容,选择最合适的部门处理。一旦路由完成,解释你的决策理由。"""
supervisor_agent = create_react_agent(
llm,
supervisor_tools,
state_schema=CustomerServiceState,
prompt=supervisor_prompt
)
# 专业部门的ReAct Agent
tech_tools = [search_knowledge_base, create_support_ticket]
tech_agent = create_react_agent(
llm,
tech_tools,
state_schema=CustomerServiceState,
prompt="你是技术支持工程师,使用工具解决技术问题。"
)
sales_tools = [get_product_info, calculate_upgrade_cost]
sales_agent = create_react_agent(
llm,
sales_tools,
state_schema=CustomerServiceState,
prompt="你是销售顾问,使用工具提供产品信息和升级建议。"
)
admin_tools = [get_user_account_info, process_refund_request]
admin_agent = create_react_agent(
llm,
admin_tools,
state_schema=CustomerServiceState,
prompt="你是客户管理专员,使用工具处理账户和管理事务。"
)
# 构建纯ReAct架构
builder = StateGraph(CustomerServiceState)
builder.add_node("supervisor_agent", supervisor_agent)
builder.add_node("tech_agent", tech_agent)
builder.add_node("sales_agent", sales_agent)
builder.add_node("admin_agent", admin_agent)
builder.add_edge(START, "supervisor_agent")
return builder.compile()
# ================================
# 测试运行函数
# ================================
def run_react_agent_examples():
_"""运行ReAct Agent示例"""_
_ _print("🤖 基于create_react_agent的多智能体客户服务系统")
print("=" * 70)
# 测试用例
test_cases = [
# {
# "message": "我的应用一直崩溃,点击登录按钮就闪退",
# "customer_id": "user123"
# },
{
"message": "我想了解专业版和企业版的区别,我们公司大概50人,另外我想查询我的余额,用户ID是user123",
"customer_id": ""
},
# {
# "message": "我想查看我的账户余额,用户ID是user123,另外我想申请退款",
# "customer_id": "user123"
# },
# {
# "message": "系统报错500,数据库连接失败,这是生产环境的紧急问题",
# "customer_id": "user456"
# }
]
systems = [
# ("混合架构 (LLM路由 + ReAct Agent)", create_llm_router_with_react_agents),
("纯ReAct架构 (监督者也是ReAct Agent)", create_pure_react_system)
]
for system_name, create_system in systems:
print(f"\n🔥 {system_name}")
print("=" * 70)
try:
graph = create_system()
for i, test_case in enumerate(test_cases, 1):
print(f"\n--- 测试案例 {i} ---")
print(f"用户: {test_case['message']}")
print("-" * 50)
initial_state = {
"messages": [HumanMessage(content=test_case['message'])],
"current_agent": "start",
"customer_id": test_case['customer_id'],
"issue_type": "",
"priority": "normal",
"resolution_status": "pending",
"routing_reason": "",
"ticket_id": "",
"remaining_steps": 10
}
try:
result = graph.invoke(initial_state)
# 显示AI回复
print("🤖 AI回复:")
for msg in result["messages"]:
if isinstance(msg, AIMessage):
# 截取长回复以保持可读性
content = msg.content
if len(content) > 300:
content = content[:300] + "..."
print(f" {content}")
# 显示处理信息
print(f"\n📊 处理状态:")
print(f" 处理部门: {result.get('current_agent', 'N/A')}")
print(f" 问题类型: {result.get('issue_type', 'N/A')}")
print(f" 优先级: {result.get('priority', 'N/A')}")
if result.get('ticket_id'):
print(f" 工单ID: {result.get('ticket_id')}")
except Exception as e:
print(f"❌ 处理失败: {e}")
except Exception as e:
print(f"❌ 系统创建失败: {e}")
if __name__ == "__main__":
# 检查API配置
if not os.getenv("DASHSCOPE_API_KEY"):
print("⚠️ 请设置DASHSCOPE_API_KEY环境变量")
print("export DASHSCOPE_API_KEY='your-api-key-here'")
else:
run_react_agent_examples()
# 显示工单系统状态
if TICKET_SYSTEM:
print(f"\n📋 创建的工单 ({len(TICKET_SYSTEM)} 个):")
for ticket in TICKET_SYSTEM:
print(f" {ticket['id']}: {ticket['issue'][:50]}...")
预构建主管架构
pip install langgraph-supervisor
_"""_
_正确使用 LangGraph Supervisor 调用子代理的多智能体系统_
_系统包含:_
_- ResearchAgent: 研究专家,使用网络搜索工具_
_- AnalysisAgent: 数据分析专家,处理计算和分析_
_- WriteAgent: 写作专家,生成报告和总结_
_- Supervisor: 协调器,负责任务分派和流程控制_
_业务场景:市场研究和报告生成_
_"""_
from typing import Annotated, Sequence, TypedDict
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.tools import tool
from langgraph.prebuilt import create_react_agent
from langgraph_supervisor import create_supervisor
from langchain.chat_models import init_chat_model
import os
from dotenv import load_dotenv
load_dotenv()
llm = init_chat_model(api_key=os.getenv("DASHSCOPE_API_KEY"),
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
model_provider="openai",
model='qwen-plus-2025-01-25')
# ============================================================================
# 定义系统状态
# ============================================================================
class MarketResearchState(TypedDict):
_"""_
_ 市场研究系统的状态定义_
_ messages: 对话历史消息_
_ research_topic: 研究主题_
_ research_data: 收集到的研究数据_
_ analysis_results: 分析结果_
_ final_report: 最终报告_
_ """_
_ _messages: Annotated[Sequence[BaseMessage], "对话历史"]
research_topic: str
research_data: dict
analysis_results: dict
final_report: str
# ============================================================================
# 创建专业工具
# ============================================================================
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from io import BytesIO
import base64
import yfinance as yf # 获取股票信息
import re
@tool
def web_search(query: str) -> str:
_"""_
_ 真实的网络搜索工具 - 使用 Tavily API_
_ Args:_
_ query: 搜索关键词_
_ Returns:_
_ 搜索结果_
_ """_
_ _try:
from langchain_tavily import TavilySearch
tavily_search = TavilySearch()
results = tavily_search.invoke(query)
if results:
formatted_results = f"🔍 搜索关键词: {query}\n\n"
for i, result in enumerate(results["results"], 1):
title = result.get('title', '无标题')
content = result.get('content', '无内容')
url = result.get('url', '无链接')
# 清理和截取内容
content = re.sub(r'\s+', ' ', content).strip()
if len(content) > 300:
content = content[:300] + "..."
formatted_results += f"{i}. **{title}**\n"
formatted_results += f" 📖 {content}\n"
formatted_results += f" 🔗 {url}\n\n"
return formatted_results
except Exception as e:
return f"搜索过程中出现错误: {str(e)}\n请检查网络连接或API配置。"
@tool
def calculate_market_metrics(data_input: str) -> str:
_"""_
_ 真实的市场指标计算工具_
_ Args:_
_ data_input: 数据输入,可以是股票代码、数值或描述_
_ Returns:_
_ 计算结果和分析_
_ """_
_ _try:
result = "📊 市场指标分析结果:\n\n"
# 尝试解析输入中的数值
numbers = re.findall(r'[\d,]+\.?\d*', data_input.replace(',', ''))
numbers = [float(x) for x in numbers if x]
if numbers:
# 基础统计计算
if len(numbers) >= 2:
mean_val = np.mean(numbers)
std_val = np.std(numbers)
growth_rate = ((numbers[-1] - numbers[0]) / numbers[0]) * 100 if numbers[0] != 0 else 0
result += f"📈 平均值: {mean_val:,.2f}\n"
result += f"📊 标准差: {std_val:,.2f}\n"
result += f"📈 增长率: {growth_rate:+.2f}%\n"
# 复合年增长率计算 (假设数据跨度为多年)
if len(numbers) > 2:
years = len(numbers) - 1
cagr = (pow(numbers[-1] / numbers[0], 1 / years) - 1) * 100
result += f"📊 复合年增长率 (CAGR): {cagr:.2f}%\n"
# 波动率计算
if len(numbers) > 1:
returns = np.diff(numbers) / numbers[:-1]
volatility = np.std(returns) * 100
result += f"⚡ 波动率: {volatility:.2f}%\n"
# 检查是否包含股票代码
stock_symbols = re.findall(r'\b[A-Z]{1,5}\b', data_input.upper())
if stock_symbols:
result += f"\n🔍 检测到股票代码: {', '.join(stock_symbols)}\n"
for symbol in stock_symbols[:3]: # 限制最多3个
try:
ticker = yf.Ticker(symbol)
info = ticker.info
if info:
pe = info.get('trailingPE', 'N/A')
pb = info.get('priceToBook', 'N/A')
roe = info.get('returnOnEquity', 'N/A')
result += f"\n📊 {symbol} 财务指标:\n"
result += f" • 市盈率 (PE): {pe if pe != 'N/A' else 'N/A'}\n"
result += f" • 市净率 (PB): {pb if pb != 'N/A' else 'N/A'}\n"
result += f" • 净资产收益率 (ROE): {roe if roe != 'N/A' else 'N/A'}\n"
except:
continue
# 市场风险评估
result += f"\n🎯 风险评估:\n"
if numbers and len(numbers) > 1:
cv = (std_val / mean_val) * 100 if mean_val != 0 else 0
if cv < 10:
risk_level = "低风险"
elif cv < 25:
risk_level = "中等风险"
else:
risk_level = "高风险"
result += f" • 风险等级: {risk_level} (变异系数: {cv:.2f}%)\n"
# 投资建议
if numbers and len(numbers) >= 2:
if growth_rate > 15:
suggestion = "强烈看好,建议适当增加投资"
elif growth_rate > 5:
suggestion = "谨慎乐观,可考虑投资"
elif growth_rate > -5:
suggestion = "保持观望,注意风险控制"
else:
suggestion = "谨慎投资,建议降低仓位"
result += f" • 投资建议: {suggestion}\n"
return result
except Exception as e:
return f"计算市场指标时出错: {str(e)}"
@tool
def generate_real_charts(data_description: str, chart_type: str = "auto") -> str:
_"""_
_ 生成真实的数据图表_
_ Args:_
_ data_description: 数据描述或股票代码_
_ chart_type: 图表类型 (line, bar, pie, scatter, auto)_
_ Returns:_
_ 图表生成结果和base64编码的图片_
_ """_
_ _try:
# 设置中文字体,例如 'SimHei' (黑体) 或 'Microsoft YaHei' (微软雅黑)
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'WenQuanYi Micro Hei'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False # 解决负号 '-' 显示为方块的问题
plt.style.use('seaborn-v0_8')
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
fig.suptitle('市场数据分析图表', fontsize=16, fontweight='bold')
# 检查是否包含股票代码
stock_symbols = re.findall(r'\b[A-Z]{2,5}\b', data_description.upper())
if stock_symbols:
# 获取股票数据并生成图表
symbol = stock_symbols[0]
try:
ticker = yf.Ticker(symbol)
hist = ticker.history(period="6mo")
if not hist.empty:
# 价格走势图
axes[0, 0].plot(hist.index, hist['Close'], linewidth=2, color='#1f77b4')
axes[0, 0].set_title(f'{symbol} 股价走势 (6个月)', fontweight='bold')
axes[0, 0].set_ylabel('价格 ($)')
axes[0, 0].grid(True, alpha=0.3)
# 成交量柱状图
axes[0, 1].bar(hist.index, hist['Volume'], alpha=0.7, color='#ff7f0e')
axes[0, 1].set_title(f'{symbol} 成交量', fontweight='bold')
axes[0, 1].set_ylabel('成交量')
# 价格分布直方图
axes[1, 0].hist(hist['Close'], bins=20, alpha=0.7, color='#2ca02c', edgecolor='black')
axes[1, 0].set_title('价格分布', fontweight='bold')
axes[1, 0].set_xlabel('价格 ($)')
axes[1, 0].set_ylabel('频次')
# 移动平均线
hist['MA20'] = hist['Close'].rolling(window=20).mean()
hist['MA50'] = hist['Close'].rolling(window=50).mean()
axes[1, 1].plot(hist.index, hist['Close'], label='收盘价', linewidth=1)
axes[1, 1].plot(hist.index, hist['MA20'], label='20日均线', linewidth=2)
axes[1, 1].plot(hist.index, hist['MA50'], label='50日均线', linewidth=2)
axes[1, 1].set_title('移动平均线分析', fontweight='bold')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)
except Exception as e:
# 如果股票数据获取失败,生成示例图表
generate_charts(axes)
else:
# 生成通用示例图表
generate_charts(axes)
# 调整布局
plt.tight_layout()
# 保存图表为base64
buffer = BytesIO()
plt.savefig(buffer, format='png', dpi=300, bbox_inches='tight')
buffer.seek(0)
# 转换为base64
chart_base64 = base64.b64encode(buffer.getvalue()).decode()
plt.close()
result = "📊 已成功生成市场分析图表!\n\n"
result += "图表包含:\n"
result += "1. 📈 价格/趋势走势图\n"
result += "2. 📊 成交量/数据柱状图\n"
result += "3. 📊 数据分布直方图\n"
result += "4. 📈 移动平均线分析\n\n"
# 在实际应用中,你可以保存图片文件或返回base64数据
result += f"图表已生成 (图片大小: {len(chart_base64)} 字符)\n"
result += "💡 提示: 图表已保存,可用于报告展示\n"
return result
except Exception as e:
return f"生成图表时出错: {str(e)}"
def generate_charts(axes):
_"""生成示例图表"""_
_ _# 示例数据
dates = pd.date_range(start='2024-01-01', end='2024-06-30', freq='D')
np.random.seed(42)
prices = 100 + np.cumsum(np.random.randn(len(dates)) * 0.5)
volumes = np.random.randint(1000000, 5000000, len(dates))
# 趋势图
axes[0, 0].plot(dates, prices, linewidth=2, color='#1f77b4')
axes[0, 0].set_title('市场趋势分析', fontweight='bold')
axes[0, 0].set_ylabel('指数/价格')
axes[0, 0].grid(True, alpha=0.3)
# 成交量
axes[0, 1].bar(dates[::10], volumes[::10], alpha=0.7, color='#ff7f0e')
axes[0, 1].set_title('成交量分析', fontweight='bold')
axes[0, 1].set_ylabel('成交量')
# 分布图
axes[1, 0].hist(prices, bins=20, alpha=0.7, color='#2ca02c', edgecolor='black')
axes[1, 0].set_title('价格分布', fontweight='bold')
axes[1, 0].set_xlabel('价格')
axes[1, 0].set_ylabel('频次')
# 市场份额饼图
labels = ['公司A', '公司B', '公司C', '公司D', '其他']
sizes = [30, 25, 20, 15, 10]
colors = ['#ff9999', '#66b3ff', '#99ff99', '#ffcc99', '#ff99cc']
axes[1, 1].pie(sizes, labels=labels, colors=colors, autopct='%1.1f%%', startangle=90)
axes[1, 1].set_title('市场份额分析', fontweight='bold')
# ============================================================================
# 创建专业智能体
# ============================================================================
def create_research_agent():
_"""_
_ 创建研究智能体_
_ 专门负责:信息收集、数据搜索、行业调研_
_ """_
_ _research_agent = create_react_agent(
model=llm,
tools=[web_search],
prompt=(
"你是一名专业的市场研究员。你的职责包括:\n\n"
"1. 收集准确、全面的市场信息\n"
"2. 搜索行业数据和趋势\n"
"3. 识别关键市场参与者\n"
"4. 分析竞争格局\n\n"
"工作原则:\n"
"- 使用可靠的数据源\n"
"- 提供具体的数字和事实\n"
"- 关注最新的市场动态\n"
"- 只专注于研究工作,不做分析或写作\n\n"
"- 全程使用中文"
"完成研究后,将结果直接报告给supervisor。"
),
name="research_agent",
)
return research_agent
def create_analysis_agent():
_"""_
_ 创建分析智能体_
_ 专门负责:数据分析、指标计算、趋势预测_
_ """_
_ _analysis_agent = create_react_agent(
model=llm,
tools=[calculate_market_metrics, generate_real_charts],
prompt=(
"你是一名专业的数据分析师。你的职责包括:\n\n"
"1. 分析市场数据和趋势\n"
"2. 计算关键业务指标\n"
"3. 进行竞争分析和预测\n"
"4. 生成数据可视化图表\n\n"
"工作原则:\n"
"- 基于数据做客观分析\n"
"- 使用统计方法和模型\n"
"- 提供清晰的分析结论\n"
"- 只专注于分析工作,不做研究或写作\n\n"
"完成分析后,将结果直接报告给supervisor。"
),
name="analysis_agent",
)
return analysis_agent
def create_writing_agent():
_"""_
_ 创建写作智能体_
_ 专门负责:报告撰写、内容整理、总结归纳_
_ """_
_ _writing_agent = create_react_agent(
model=llm,
tools=[], # 写作智能体不需要特殊工具
prompt=(
"你是一名专业的商业报告写作专家。你的职责包括:\n\n"
"1. 整理和综合研究数据\n"
"2. 撰写结构清晰的报告\n"
"3. 提供actionable的建议\n"
"4. 确保内容专业且易懂\n\n"
"工作原则:\n"
"- 结构化组织信息\n"
"- 使用专业商业语言\n"
"- 突出关键发现和洞察\n"
"- 提供实用的建议和下一步行动\n\n"
"完成写作后,将最终报告提交给supervisor。"
),
name="writing_agent",
)
return writing_agent
# ============================================================================
# 使用官方方式创建 Supervisor
# ============================================================================
def create_market_research_supervisor():
_"""_
_ 使用官方 create_supervisor 创建市场研究协调器_
_ 这个supervisor会真正调用子代理来完成不同的任务_
_ """_
_ _# 创建各个专业智能体
research_agent = create_research_agent()
analysis_agent = create_analysis_agent()
writing_agent = create_writing_agent()
# 使用官方语法创建 supervisor
supervisor = create_supervisor(
model=llm,
agents=[research_agent, analysis_agent, writing_agent],
prompt=(
"你是一个市场研究项目的项目经理,负责协调三个专业团队:\n\n"
"1. Research Agent - 市场研究员,负责收集市场信息、行业数据和竞争情报\n"
"2. Analysis Agent - 数据分析师,负责分析数据、计算指标和预测趋势\n"
"3. Writing Agent - 报告专家,负责整理信息并撰写最终报告\n\n"
"你的工作流程:\n"
"1. 首先分派研究任务给Research Agent收集基础数据\n"
"2. 然后让Analysis Agent分析数据并计算关键指标\n"
"3. 最后让Writing Agent基于研究和分析结果撰写完整报告\n"
"4. 如需补充信息,可以重复调用相应的agent\n\n"
"重要原则:\n"
"- 一次只分配给一个agent,不要并行调用\n"
"- 确保每个agent都有明确的任务说明\n"
"- 按照逻辑顺序推进工作:研究 → 分析 → 写作\n"
"- 自己不要直接做具体工作,而是分派给专业agent"
),
add_handoff_back_messages=True, # 启用代理回传消息
output_mode="full_history", # 保留完整历史
).compile()
return supervisor
# ============================================================================
# 主要功能函数
# ============================================================================
def conduct_market_research(research_request: str):
_"""_
_ 执行市场研究任务_
_ Args:_
_ research_request: 研究需求描述_
_ Returns:_
_ 研究结果_
_ """_
_ _# 创建supervisor
supervisor = create_market_research_supervisor()
from IPython.display import display, Image
from langchain_core.runnables.graph import MermaidDrawMethod
display(
Image(supervisor.get_graph().draw_mermaid_png(draw_method=MermaidDrawMethod.API, output_file_path="./可视化图.png")))
print(f"🔍 开始市场研究任务: {research_request}")
print("=" * 60)
# 构建初始消息
initial_messages = [
HumanMessage(content=research_request)
]
final_result = None # 用于存储最终结果
try:
# 流式执行研究任务
print("📊 Supervisor开始协调各个专业团队...")
print()
for chunk in supervisor.stream({"messages": initial_messages}):
final_result = chunk # 保存每次更新的结果
# 处理不同类型的更新
if isinstance(chunk, dict):
for node_name, node_update in chunk.items():
if "messages" in node_update and node_update["messages"]:
messages = node_update["messages"]
latest_message = messages[-1]
# 显示每个阶段的进展
if hasattr(latest_message, 'name') and latest_message.name:
agent_name = latest_message.name
print(f"📝 {agent_name} 工作更新:")
# 显示消息内容(截取前200字符)
if hasattr(latest_message, 'content') and latest_message.content:
content = latest_message.content
if len(content) > 200:
content = content[:200] + "...[继续工作中]"
print(f" {content}")
print()
# 检查是否有工具调用
if hasattr(latest_message, 'tool_calls') and latest_message.tool_calls:
for tool_call in latest_message.tool_calls:
tool_name = "unknown"
if isinstance(tool_call, dict):
tool_name = tool_call.get('name', 'unknown')
elif hasattr(tool_call, 'name'):
tool_name = tool_call.name
print(f"🔧 正在使用工具: {tool_name}")
print()
# 检查是否有转移消息
if hasattr(latest_message, 'content') and latest_message.content:
if "transfer" in latest_message.content.lower():
print(f"🔄 任务转移: {latest_message.content}")
print()
print("✅ 市场研究任务完成!")
# 返回最终结果,如果没有结果则返回空字典
return final_result if final_result is not None else {}
except Exception as e:
print(f"❌ 执行过程中出错: {str(e)}")
import traceback
traceback.print_exc() # 打印详细错误信息用于调试
return {"error": str(e)}
# ============================================================================
# 测试和演示
# ============================================================================
def run_market_research_demo():
_"""_
_ 运行市场研究系统演示_
_ """_
_ _print("🚀 LangGraph Supervisor 市场研究系统演示")
print("本系统展示真正的子代理调用和协作")
print("=" * 60)
# 测试用例
test_cases = [
{
"title": "电动汽车市场研究",
"request": "请帮我做一份关于2024年电动汽车市场的综合研究报告,包括市场规模、增长趋势、主要竞争者和未来预测。",
},
# {
# "title": "人工智能行业分析",
# "request": "分析当前人工智能行业的发展状况,重点关注大模型技术的商业化前景和投资机会。",
# },
# {
# "title": "可再生能源投资研究",
# "request": "研究可再生能源领域的投资机会,分析太阳能和风能的市场潜力以及政策影响。",
# }
]
for i, test_case in enumerate(test_cases, 1):
print(f"\n📋 测试案例 {i}: {test_case['title']}")
print(f"研究需求: {test_case['request']}")
print("-" * 50)
# 执行研究任务
result = conduct_market_research(test_case['request'])
print("任务执行结果:", result)
print("\n" + "=" * 60)
# ============================================================================
# 8. 主程序入口
# ============================================================================
def main():
_"""_
_ 主程序入口_
_ """_
_ _try:
# 运行演示
run_market_research_demo()
except Exception as e:
print(f"❌ 程序执行错误: {e}")
if __name__ == "__main__":
main()