Skip to main content

synaptic_graph/
tool_node.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde_json::Value;
6use synaptic_core::{Message, RuntimeAwareTool, Store, SynapticError, ToolContext, ToolRuntime};
7use synaptic_middleware::{MiddlewareChain, ToolCallRequest, ToolCaller};
8use synaptic_tools::SerialToolExecutor;
9
10use crate::command::NodeOutput;
11use crate::node::Node;
12use crate::state::MessageState;
13
14/// Wraps a `SerialToolExecutor` into a `ToolCaller` for the middleware chain.
15struct BaseToolCaller {
16    executor: SerialToolExecutor,
17    #[allow(dead_code)]
18    tool_context: ToolContext,
19}
20
21#[async_trait]
22impl ToolCaller for BaseToolCaller {
23    async fn call(&self, request: ToolCallRequest) -> Result<Value, SynapticError> {
24        self.executor
25            .execute(&request.call.name, request.call.arguments.clone())
26            .await
27    }
28}
29
30/// Prebuilt node that executes tool calls from the last AI message in state.
31///
32/// Supports both regular `Tool` and `RuntimeAwareTool` instances.
33/// When a runtime-aware tool is registered, it receives the current graph
34/// state, store reference, and tool call ID via [`ToolRuntime`].
35pub struct ToolNode {
36    executor: SerialToolExecutor,
37    middleware: Option<Arc<MiddlewareChain>>,
38    /// Optional store reference injected into RuntimeAwareTool calls.
39    store: Option<Arc<dyn Store>>,
40    /// Runtime-aware tools keyed by tool name.
41    runtime_tools: HashMap<String, Arc<dyn RuntimeAwareTool>>,
42}
43
44impl ToolNode {
45    pub fn new(executor: SerialToolExecutor) -> Self {
46        Self {
47            executor,
48            middleware: None,
49            store: None,
50            runtime_tools: HashMap::new(),
51        }
52    }
53
54    /// Create a ToolNode with middleware support.
55    pub fn with_middleware(executor: SerialToolExecutor, middleware: Arc<MiddlewareChain>) -> Self {
56        Self {
57            executor,
58            middleware: Some(middleware),
59            store: None,
60            runtime_tools: HashMap::new(),
61        }
62    }
63
64    /// Set the store reference for runtime-aware tool injection.
65    pub fn with_store(mut self, store: Arc<dyn Store>) -> Self {
66        self.store = Some(store);
67        self
68    }
69
70    /// Register a runtime-aware tool.
71    ///
72    /// When a tool call matches a registered runtime-aware tool by name,
73    /// it will be called with a [`ToolRuntime`] containing the current
74    /// graph state, store, and tool call ID.
75    pub fn with_runtime_tool(mut self, tool: Arc<dyn RuntimeAwareTool>) -> Self {
76        self.runtime_tools.insert(tool.name().to_string(), tool);
77        self
78    }
79}
80
81#[async_trait]
82impl Node<MessageState> for ToolNode {
83    async fn process(
84        &self,
85        mut state: MessageState,
86    ) -> Result<NodeOutput<MessageState>, SynapticError> {
87        let last = state
88            .last_message()
89            .ok_or_else(|| SynapticError::Graph("no messages in state".to_string()))?;
90
91        let tool_calls = last.tool_calls().to_vec();
92        if tool_calls.is_empty() {
93            return Ok(state.into());
94        }
95
96        // Serialize current state for context injection
97        let state_value = serde_json::to_value(&state).ok();
98
99        for call in &tool_calls {
100            // Check if this is a runtime-aware tool
101            let result = if let Some(rt_tool) = self.runtime_tools.get(&call.name) {
102                let runtime = ToolRuntime {
103                    store: self.store.clone(),
104                    stream_writer: None,
105                    state: state_value.clone(),
106                    tool_call_id: call.id.clone(),
107                    config: None,
108                };
109                rt_tool
110                    .call_with_runtime(call.arguments.clone(), runtime)
111                    .await?
112            } else {
113                // Regular tool execution
114                let tool_ctx = ToolContext {
115                    state: state_value.clone(),
116                    tool_call_id: call.id.clone(),
117                };
118
119                if let Some(ref chain) = self.middleware {
120                    let request = ToolCallRequest { call: call.clone() };
121                    let base = BaseToolCaller {
122                        executor: self.executor.clone(),
123                        tool_context: tool_ctx,
124                    };
125                    chain.call_tool(request, &base).await?
126                } else {
127                    self.executor
128                        .execute(&call.name, call.arguments.clone())
129                        .await?
130                }
131            };
132            state
133                .messages
134                .push(Message::tool(result.to_string(), &call.id));
135        }
136
137        Ok(state.into())
138    }
139}
140
141/// Standard routing function: returns "tools" if last message has tool_calls, else END.
142///
143/// This is the standard condition function used with `add_conditional_edges`
144/// to route between an agent node and a tools node.
145pub fn tools_condition(state: &MessageState) -> String {
146    if let Some(last) = state.last_message() {
147        if !last.tool_calls().is_empty() {
148            return "tools".to_string();
149        }
150    }
151    crate::END.to_string()
152}