//! `tool_node` — prebuilt graph node that executes tool calls from state.

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde_json::Value;
6use synaptic_core::{Message, RuntimeAwareTool, Store, SynapticError, ToolRuntime};
7use synaptic_middleware::{MiddlewareChain, ToolCallRequest, ToolCaller};
8use synaptic_tools::SerialToolExecutor;
9
10use crate::command::NodeOutput;
11use crate::node::Node;
12use crate::state::MessageState;
13
/// Wraps a `SerialToolExecutor` into a `ToolCaller` for the middleware chain.
///
/// This adapter is the innermost handler of the chain: middleware layers run
/// around it, and it performs the actual tool dispatch.
struct BaseToolCaller {
    /// The underlying executor that runs a tool looked up by name.
    executor: SerialToolExecutor,
}
18
19#[async_trait]
20impl ToolCaller for BaseToolCaller {
21    async fn call(&self, request: ToolCallRequest) -> Result<Value, SynapticError> {
22        self.executor
23            .execute(&request.call.name, request.call.arguments.clone())
24            .await
25    }
26}
27
/// Prebuilt node that executes tool calls from the last AI message in state.
///
/// Supports both regular `Tool` and `RuntimeAwareTool` instances.
/// When a runtime-aware tool is registered, it receives the current graph
/// state, store reference, and tool call ID via [`ToolRuntime`].
///
/// By default, tool calls are executed serially. Call [`ToolNode::with_parallel`]
/// to enable concurrent execution of multiple tool calls within a single step.
pub struct ToolNode {
    /// Executes regular (non-runtime-aware) tools, looked up by name.
    executor: SerialToolExecutor,
    /// Optional middleware chain wrapped around each regular tool call;
    /// runtime-aware tools bypass it.
    middleware: Option<Arc<MiddlewareChain>>,
    /// Optional store reference injected into RuntimeAwareTool calls.
    store: Option<Arc<dyn Store>>,
    /// Runtime-aware tools keyed by tool name.
    runtime_tools: HashMap<String, Arc<dyn RuntimeAwareTool>>,
    /// When true and multiple tool calls exist, execute them concurrently.
    parallel: bool,
}
46
47impl ToolNode {
48    pub fn new(executor: SerialToolExecutor) -> Self {
49        Self {
50            executor,
51            middleware: None,
52            store: None,
53            runtime_tools: HashMap::new(),
54            parallel: false,
55        }
56    }
57
58    /// Create a ToolNode with middleware support.
59    pub fn with_middleware(executor: SerialToolExecutor, middleware: Arc<MiddlewareChain>) -> Self {
60        Self {
61            executor,
62            middleware: Some(middleware),
63            store: None,
64            runtime_tools: HashMap::new(),
65            parallel: false,
66        }
67    }
68
69    /// Enable parallel execution of tool calls.
70    ///
71    /// When enabled and multiple tool calls exist in the last AI message,
72    /// they are executed concurrently using `futures::future::join_all`.
73    /// Results are collected in the same order as the original tool calls.
74    pub fn with_parallel(mut self, parallel: bool) -> Self {
75        self.parallel = parallel;
76        self
77    }
78
79    /// Set the store reference for runtime-aware tool injection.
80    pub fn with_store(mut self, store: Arc<dyn Store>) -> Self {
81        self.store = Some(store);
82        self
83    }
84
85    /// Register a runtime-aware tool.
86    ///
87    /// When a tool call matches a registered runtime-aware tool by name,
88    /// it will be called with a [`ToolRuntime`] containing the current
89    /// graph state, store, and tool call ID.
90    pub fn with_runtime_tool(mut self, tool: Arc<dyn RuntimeAwareTool>) -> Self {
91        self.runtime_tools.insert(tool.name().to_string(), tool);
92        self
93    }
94}
95
#[async_trait]
impl Node<MessageState> for ToolNode {
    /// Execute every tool call attached to the last message in `state`,
    /// appending one `Message::tool` result per call, in call order.
    ///
    /// # Errors
    ///
    /// Returns `SynapticError::Graph` when the state holds no messages, and
    /// propagates the first error from a runtime-aware tool, the middleware
    /// chain, or the executor. On error the partially-updated state is
    /// dropped by the caller via the `Err` return.
    async fn process(
        &self,
        mut state: MessageState,
    ) -> Result<NodeOutput<MessageState>, SynapticError> {
        let last = state
            .last_message()
            .ok_or_else(|| SynapticError::Graph("no messages in state".to_string()))?;

        // Copy the calls out so the borrow of `state` ends before we push
        // result messages onto it below.
        let tool_calls = last.tool_calls().to_vec();
        if tool_calls.is_empty() {
            // No tool calls: pass the state through unchanged.
            return Ok(state.into());
        }

        // Serialize current state for context injection.
        // Snapshot is taken once, before any results are appended, so every
        // runtime-aware tool in this step sees the same pre-execution state
        // (even in the serial path). Serialization failure degrades to None.
        let state_value = serde_json::to_value(&state).ok();

        if self.parallel && tool_calls.len() > 1 {
            // Parallel execution: run all tool calls concurrently
            let futs: Vec<_> = tool_calls
                .iter()
                .map(|call| {
                    // Clone everything each future needs so the async block
                    // owns its captures and the futures can run concurrently.
                    let executor = self.executor.clone();
                    let middleware = self.middleware.clone();
                    let rt_tool = self.runtime_tools.get(&call.name).cloned();
                    let store = self.store.clone();
                    let sv = state_value.clone();
                    let call = call.clone();
                    async move {
                        if let Some(rt) = rt_tool {
                            // Runtime-aware tools bypass the middleware chain
                            // and receive the state/store snapshot directly.
                            let runtime = ToolRuntime {
                                store,
                                stream_writer: None,
                                state: sv,
                                tool_call_id: call.id.clone(),
                                config: None,
                            };
                            rt.call_with_runtime(call.arguments.clone(), runtime).await
                        } else if let Some(ref chain) = middleware {
                            // Regular tool with middleware: executor is the
                            // innermost handler of the chain.
                            let request = ToolCallRequest { call: call.clone() };
                            let base = BaseToolCaller { executor };
                            chain.call_tool(request, &base).await
                        } else {
                            // Regular tool, no middleware: direct dispatch.
                            executor.execute(&call.name, call.arguments.clone()).await
                        }
                    }
                })
                .collect();
            let results = futures::future::join_all(futs).await;
            // join_all preserves input order, so zipping keeps each result
            // paired with its originating call. The `?` below fails the whole
            // step on the first erroring call, discarding later results.
            for (call, result) in tool_calls.iter().zip(results) {
                state
                    .messages
                    .push(Message::tool(result?.to_string(), &call.id));
            }
        } else {
            // Serial execution (default)
            for call in &tool_calls {
                let result = if let Some(rt_tool) = self.runtime_tools.get(&call.name) {
                    // Runtime-aware path: inject store + state snapshot.
                    let runtime = ToolRuntime {
                        store: self.store.clone(),
                        stream_writer: None,
                        state: state_value.clone(),
                        tool_call_id: call.id.clone(),
                        config: None,
                    };
                    rt_tool
                        .call_with_runtime(call.arguments.clone(), runtime)
                        .await?
                } else if let Some(ref chain) = self.middleware {
                    // Middleware path, mirroring the parallel branch.
                    let request = ToolCallRequest { call: call.clone() };
                    let base = BaseToolCaller {
                        executor: self.executor.clone(),
                    };
                    chain.call_tool(request, &base).await?
                } else {
                    // Direct dispatch.
                    self.executor
                        .execute(&call.name, call.arguments.clone())
                        .await?
                };
                // NOTE(review): `serde_json::Value::to_string()` renders JSON
                // text, so string results keep their surrounding quotes; this
                // matches the parallel branch — confirm it is intentional.
                state
                    .messages
                    .push(Message::tool(result.to_string(), &call.id));
            }
        }

        Ok(state.into())
    }
}
185
186/// Standard routing function: returns "tools" if last message has tool_calls, else END.
187///
188/// This is the standard condition function used with `add_conditional_edges`
189/// to route between an agent node and a tools node.
190pub fn tools_condition(state: &MessageState) -> String {
191    if let Some(last) = state.last_message() {
192        if !last.tool_calls().is_empty() {
193            return "tools".to_string();
194        }
195    }
196    crate::END.to_string()
197}