// wesichain_graph/tool_node.rs
use std::sync::Arc;

use futures::stream::{self, BoxStream, StreamExt};
use tokio::task::JoinSet;
use wesichain_core::Tool;
use wesichain_core::{Runnable, StreamEvent, WesichainError};
use wesichain_llm::{Message, Role, ToolCall};

use crate::{GraphState, StateSchema, StateUpdate};

11pub trait HasToolCalls {
17 fn tool_calls(&self) -> &Vec<ToolCall>;
18 fn push_tool_result(&mut self, message: Message);
19}
20
21pub struct ToolNode {
29 tools: Vec<Arc<dyn Tool>>,
30}
31
32impl ToolNode {
33 pub fn new(tools: Vec<Arc<dyn Tool>>) -> Self {
34 Self { tools }
35 }
36
37 pub async fn invoke<S>(&self, input: GraphState<S>) -> Result<StateUpdate<S>, WesichainError>
38 where
39 S: StateSchema<Update = S> + HasToolCalls,
40 {
41 <Self as Runnable<GraphState<S>, StateUpdate<S>>>::invoke(self, input).await
42 }
43}
44
45#[async_trait::async_trait]
46impl<S> Runnable<GraphState<S>, StateUpdate<S>> for ToolNode
47where
48 S: StateSchema<Update = S> + HasToolCalls,
49{
50 async fn invoke(&self, input: GraphState<S>) -> Result<StateUpdate<S>, WesichainError> {
51 let calls: Vec<ToolCall> = input.data.tool_calls().clone();
52
53 let mut join_set: JoinSet<(usize, String, Result<String, WesichainError>)> =
55 JoinSet::new();
56
57 for (index, call) in calls.iter().enumerate() {
58 let tool = self
59 .tools
60 .iter()
61 .find(|t| t.name() == call.name)
62 .ok_or_else(|| WesichainError::ToolCallFailed {
63 tool_name: call.name.clone(),
64 reason: "not found".to_string(),
65 })?;
66 let tool = tool.clone();
67 let args = call.args.clone();
68 let call_id = call.id.clone();
69 let tool_name = call.name.clone();
70 join_set.spawn(async move {
71 let result = tool.invoke(args).await.map(|v| v.to_string()).map_err(|e| {
72 WesichainError::ToolCallFailed {
73 tool_name,
74 reason: e.to_string(),
75 }
76 });
77 (index, call_id, result)
78 });
79 }
80
81 let mut results: Vec<(usize, String, Result<String, WesichainError>)> =
83 Vec::with_capacity(calls.len());
84 while let Some(res) = join_set.join_next().await {
85 results.push(res.map_err(|e| WesichainError::Custom(format!("task panicked: {e}")))?);
86 }
87 results.sort_by_key(|(idx, _, _)| *idx);
88
89 let mut next = input.data.clone();
90 for (_, call_id, output) in results {
91 next.push_tool_result(Message {
92 role: Role::Tool,
93 content: output?.into(),
94 tool_call_id: Some(call_id),
95 tool_calls: Vec::new(),
96 });
97 }
98 Ok(StateUpdate::new(next))
99 }
100
101 fn stream(&self, _input: GraphState<S>) -> BoxStream<'_, Result<StreamEvent, WesichainError>> {
102 stream::empty().boxed()
103 }
104}