Skip to main content

wesichain_graph/
tool_node.rs

1use std::sync::Arc;
2
3use futures::stream::{self, BoxStream, StreamExt};
4use tokio::task::JoinSet;
5use wesichain_core::Tool;
6use wesichain_core::{Runnable, StreamEvent, WesichainError};
7use wesichain_llm::{Message, Role, ToolCall};
8
9use crate::{GraphState, StateSchema, StateUpdate};
10
11/// Trait for states that contain pending tool calls.
12///
13/// Implement this on your state to use [`ToolNode`] for generic tool execution.
14/// For ReAct-style agents using scratchpad-based state, use
15/// [`ReActToolNode`](crate::react_subgraph::ReActToolNode) with `ScratchpadState` instead.
16pub trait HasToolCalls {
17    fn tool_calls(&self) -> &Vec<ToolCall>;
18    fn push_tool_result(&mut self, message: Message);
19}
20
21/// Generic tool execution node for graph-based workflows.
22///
23/// `ToolNode` executes all pending tool calls from state via the [`HasToolCalls`] trait.
24/// This is the general-purpose tool executor suitable for any workflow.
25///
26/// For ReAct-style agents, prefer [`ReActToolNode`](crate::react_subgraph::ReActToolNode)
27/// which integrates with the scratchpad pattern (`ScratchpadState`).
28pub struct ToolNode {
29    tools: Vec<Arc<dyn Tool>>,
30}
31
32impl ToolNode {
33    pub fn new(tools: Vec<Arc<dyn Tool>>) -> Self {
34        Self { tools }
35    }
36
37    pub async fn invoke<S>(&self, input: GraphState<S>) -> Result<StateUpdate<S>, WesichainError>
38    where
39        S: StateSchema<Update = S> + HasToolCalls,
40    {
41        <Self as Runnable<GraphState<S>, StateUpdate<S>>>::invoke(self, input).await
42    }
43}
44
45#[async_trait::async_trait]
46impl<S> Runnable<GraphState<S>, StateUpdate<S>> for ToolNode
47where
48    S: StateSchema<Update = S> + HasToolCalls,
49{
50    async fn invoke(&self, input: GraphState<S>) -> Result<StateUpdate<S>, WesichainError> {
51        let calls: Vec<ToolCall> = input.data.tool_calls().clone();
52
53        // Dispatch all tool calls concurrently, preserving original order.
54        let mut join_set: JoinSet<(usize, String, Result<String, WesichainError>)> =
55            JoinSet::new();
56
57        for (index, call) in calls.iter().enumerate() {
58            let tool = self
59                .tools
60                .iter()
61                .find(|t| t.name() == call.name)
62                .ok_or_else(|| WesichainError::ToolCallFailed {
63                    tool_name: call.name.clone(),
64                    reason: "not found".to_string(),
65                })?;
66            let tool = tool.clone();
67            let args = call.args.clone();
68            let call_id = call.id.clone();
69            let tool_name = call.name.clone();
70            join_set.spawn(async move {
71                let result = tool.invoke(args).await.map(|v| v.to_string()).map_err(|e| {
72                    WesichainError::ToolCallFailed {
73                        tool_name,
74                        reason: e.to_string(),
75                    }
76                });
77                (index, call_id, result)
78            });
79        }
80
81        // Collect results and sort by original index so message history is deterministic.
82        let mut results: Vec<(usize, String, Result<String, WesichainError>)> =
83            Vec::with_capacity(calls.len());
84        while let Some(res) = join_set.join_next().await {
85            results.push(res.map_err(|e| WesichainError::Custom(format!("task panicked: {e}")))?);
86        }
87        results.sort_by_key(|(idx, _, _)| *idx);
88
89        let mut next = input.data.clone();
90        for (_, call_id, output) in results {
91            next.push_tool_result(Message {
92                role: Role::Tool,
93                content: output?.into(),
94                tool_call_id: Some(call_id),
95                tool_calls: Vec::new(),
96            });
97        }
98        Ok(StateUpdate::new(next))
99    }
100
101    fn stream(&self, _input: GraphState<S>) -> BoxStream<'_, Result<StreamEvent, WesichainError>> {
102        stream::empty().boxed()
103    }
104}