synaptic_graph/
tool_node.rs1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde_json::Value;
6use synaptic_core::{Message, RuntimeAwareTool, Store, SynapticError, ToolRuntime};
7use synaptic_middleware::{MiddlewareChain, ToolCallRequest, ToolCaller};
8use synaptic_tools::SerialToolExecutor;
9
10use crate::command::NodeOutput;
11use crate::node::Node;
12use crate::state::MessageState;
13
14struct BaseToolCaller {
16 executor: SerialToolExecutor,
17}
18
19#[async_trait]
20impl ToolCaller for BaseToolCaller {
21 async fn call(&self, request: ToolCallRequest) -> Result<Value, SynapticError> {
22 self.executor
23 .execute(&request.call.name, request.call.arguments.clone())
24 .await
25 }
26}
27
28pub struct ToolNode {
37 executor: SerialToolExecutor,
38 middleware: Option<Arc<MiddlewareChain>>,
39 store: Option<Arc<dyn Store>>,
41 runtime_tools: HashMap<String, Arc<dyn RuntimeAwareTool>>,
43 parallel: bool,
45}
46
47impl ToolNode {
48 pub fn new(executor: SerialToolExecutor) -> Self {
49 Self {
50 executor,
51 middleware: None,
52 store: None,
53 runtime_tools: HashMap::new(),
54 parallel: false,
55 }
56 }
57
58 pub fn with_middleware(executor: SerialToolExecutor, middleware: Arc<MiddlewareChain>) -> Self {
60 Self {
61 executor,
62 middleware: Some(middleware),
63 store: None,
64 runtime_tools: HashMap::new(),
65 parallel: false,
66 }
67 }
68
69 pub fn with_parallel(mut self, parallel: bool) -> Self {
75 self.parallel = parallel;
76 self
77 }
78
79 pub fn with_store(mut self, store: Arc<dyn Store>) -> Self {
81 self.store = Some(store);
82 self
83 }
84
85 pub fn with_runtime_tool(mut self, tool: Arc<dyn RuntimeAwareTool>) -> Self {
91 self.runtime_tools.insert(tool.name().to_string(), tool);
92 self
93 }
94}
95
96#[async_trait]
97impl Node<MessageState> for ToolNode {
98 async fn process(
99 &self,
100 mut state: MessageState,
101 ) -> Result<NodeOutput<MessageState>, SynapticError> {
102 let last = state
103 .last_message()
104 .ok_or_else(|| SynapticError::Graph("no messages in state".to_string()))?;
105
106 let tool_calls = last.tool_calls().to_vec();
107 if tool_calls.is_empty() {
108 return Ok(state.into());
109 }
110
111 let state_value = serde_json::to_value(&state).ok();
113
114 if self.parallel && tool_calls.len() > 1 {
115 let futs: Vec<_> = tool_calls
117 .iter()
118 .map(|call| {
119 let executor = self.executor.clone();
120 let middleware = self.middleware.clone();
121 let rt_tool = self.runtime_tools.get(&call.name).cloned();
122 let store = self.store.clone();
123 let sv = state_value.clone();
124 let call = call.clone();
125 async move {
126 if let Some(rt) = rt_tool {
127 let runtime = ToolRuntime {
128 store,
129 stream_writer: None,
130 state: sv,
131 tool_call_id: call.id.clone(),
132 config: None,
133 };
134 rt.call_with_runtime(call.arguments.clone(), runtime).await
135 } else if let Some(ref chain) = middleware {
136 let request = ToolCallRequest { call: call.clone() };
137 let base = BaseToolCaller { executor };
138 chain.call_tool(request, &base).await
139 } else {
140 executor.execute(&call.name, call.arguments.clone()).await
141 }
142 }
143 })
144 .collect();
145 let results = futures::future::join_all(futs).await;
146 for (call, result) in tool_calls.iter().zip(results) {
147 state
148 .messages
149 .push(Message::tool(result?.to_string(), &call.id));
150 }
151 } else {
152 for call in &tool_calls {
154 let result = if let Some(rt_tool) = self.runtime_tools.get(&call.name) {
155 let runtime = ToolRuntime {
156 store: self.store.clone(),
157 stream_writer: None,
158 state: state_value.clone(),
159 tool_call_id: call.id.clone(),
160 config: None,
161 };
162 rt_tool
163 .call_with_runtime(call.arguments.clone(), runtime)
164 .await?
165 } else if let Some(ref chain) = self.middleware {
166 let request = ToolCallRequest { call: call.clone() };
167 let base = BaseToolCaller {
168 executor: self.executor.clone(),
169 };
170 chain.call_tool(request, &base).await?
171 } else {
172 self.executor
173 .execute(&call.name, call.arguments.clone())
174 .await?
175 };
176 state
177 .messages
178 .push(Message::tool(result.to_string(), &call.id));
179 }
180 }
181
182 Ok(state.into())
183 }
184}
185
186pub fn tools_condition(state: &MessageState) -> String {
191 if let Some(last) = state.last_message() {
192 if !last.tool_calls().is_empty() {
193 return "tools".to_string();
194 }
195 }
196 crate::END.to_string()
197}