1use std::collections::{BTreeMap, HashSet};
2use std::sync::Arc;
3
4use schemars::schema::RootSchema;
5use serde::Deserialize;
6use serde_json::Value;
7
8use crate::error::ToolDispatchError;
9pub use wesichain_core::{CancellationToken, Tool, ToolContext, ToolSpec, TypedTool};
10
11pub type ToolError = wesichain_core::ToolError;
12
13#[derive(Clone, Debug)]
14pub struct ToolSchema {
15 pub args_schema: RootSchema,
16 pub output_schema: RootSchema,
17}
18
19#[derive(Clone, Debug, Deserialize)]
20pub struct ToolCallEnvelope {
21 pub name: String,
22 pub args: Value,
23 pub call_id: String,
24}
25
26#[derive(Clone)]
27pub struct ToolSet {
28 entries: Vec<ToolMetadata>,
29 schema_catalog: BTreeMap<String, ToolSchema>,
30 dispatchers: BTreeMap<String, Arc<dyn ErasedToolRunner>>,
31}
32
33impl std::fmt::Debug for ToolSet {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 f.debug_struct("ToolSet")
36 .field("entries", &self.entries)
37 .field("schema_catalog_len", &self.schema_catalog.len())
38 .field("dispatchers_len", &self.dispatchers.len())
39 .finish()
40 }
41}
42
43impl ToolSet {
44 #[allow(
45 clippy::new_ret_no_self,
46 reason = "ToolSet::new intentionally starts a builder-first registration API"
47 )]
48 pub fn new() -> ToolSetBuilder {
49 ToolSetBuilder {
50 entries: Vec::new(),
51 dispatchers: Vec::new(),
52 }
53 }
54
55 pub fn names(&self) -> Vec<&str> {
56 self.entries
57 .iter()
58 .map(|entry| entry.name.as_str())
59 .collect()
60 }
61
62 pub fn schema_catalog(&self) -> &BTreeMap<String, ToolSchema> {
63 &self.schema_catalog
64 }
65
66 pub fn tool_specs(&self) -> Vec<ToolSpec> {
74 self.entries
75 .iter()
76 .map(|e| {
77 let description = e
78 .schema
79 .args_schema
80 .schema
81 .metadata
82 .as_ref()
83 .and_then(|m| m.description.clone())
84 .unwrap_or_else(|| e.name.clone());
85
86 let parameters =
87 serde_json::to_value(&e.schema.args_schema).unwrap_or(Value::Object(
88 serde_json::Map::new(),
89 ));
90
91 ToolSpec { name: e.name.clone(), description, parameters }
92 })
93 .collect()
94 }
95
96 pub async fn dispatch(
97 &self,
98 envelope: ToolCallEnvelope,
99 ctx: ToolContext,
100 ) -> Result<Value, ToolDispatchError> {
101 let Some(dispatcher) = self.dispatchers.get(&envelope.name) else {
102 return Err(ToolDispatchError::UnknownTool {
103 name: envelope.name,
104 call_id: envelope.call_id,
105 });
106 };
107
108 dispatcher
109 .dispatch(&envelope.name, envelope.args, envelope.call_id, ctx)
110 .await
111 }
112
113 pub async fn dispatch_many(
117 &self,
118 envelopes: Vec<ToolCallEnvelope>,
119 ctx: ToolContext,
120 ) -> Vec<(String, Result<Value, ToolDispatchError>)> {
121 let mut handles = Vec::with_capacity(envelopes.len());
122
123 for envelope in envelopes {
124 let call_id = envelope.call_id.clone();
125 match self.dispatchers.get(&envelope.name) {
126 None => {
127 let err = Err(ToolDispatchError::UnknownTool {
128 name: envelope.name.clone(),
129 call_id: envelope.call_id.clone(),
130 });
131 handles.push((call_id, tokio::spawn(async move { err })));
132 }
133 Some(dispatcher) => {
134 let dispatcher = dispatcher.clone();
135 let ctx = ctx.clone();
136 let name = envelope.name.clone();
137 let args = envelope.args.clone();
138 let cid = envelope.call_id.clone();
139 handles.push((
140 call_id,
141 tokio::spawn(async move {
142 dispatcher.dispatch(&name, args, cid, ctx).await
143 }),
144 ));
145 }
146 }
147 }
148
149 let mut results = Vec::with_capacity(handles.len());
150 for (call_id, handle) in handles {
151 let result = match handle.await {
152 Ok(r) => r,
153 Err(join_err) => Err(ToolDispatchError::Execution {
154 name: String::new(),
155 call_id: call_id.clone(),
156 source: crate::ToolError::ExecutionFailed(format!("task panicked: {join_err}")),
157 }),
158 };
159 results.push((call_id, result));
160 }
161 results
162 }
163}
164
165#[derive(Clone, Default)]
166pub struct ToolSetBuilder {
167 entries: Vec<ToolMetadata>,
168 dispatchers: Vec<ToolDispatchMetadata>,
169}
170
171impl ToolSetBuilder {
172 pub fn register<T>(mut self) -> Self
173 where
174 T: TypedTool,
175 {
176 self.entries.push(ToolMetadata {
177 name: T::NAME.to_string(),
178 schema: ToolSchema {
179 args_schema: schemars::schema_for!(T::Args),
180 output_schema: schemars::schema_for!(T::Output),
181 },
182 });
183 self
184 }
185
186 pub fn register_with<T>(mut self, tool: T) -> Self
187 where
188 T: TypedTool + Send + Sync + 'static,
189 {
190 self.entries.push(ToolMetadata {
191 name: T::NAME.to_string(),
192 schema: ToolSchema {
193 args_schema: schemars::schema_for!(T::Args),
194 output_schema: schemars::schema_for!(T::Output),
195 },
196 });
197 self.dispatchers.push(ToolDispatchMetadata {
198 name: T::NAME.to_string(),
199 runner: Arc::new(TypedToolRunner { tool }),
200 });
201 self
202 }
203
204 pub fn register_dynamic(mut self, tool: impl Tool + 'static) -> Self {
209 let name = tool.name().to_string();
210 let arc: Arc<dyn Tool> = Arc::new(tool);
211
212 let args_root: RootSchema = serde_json::from_value(arc.schema())
215 .unwrap_or_else(|_| schemars::schema_for!(serde_json::Value));
216 let output_root: RootSchema = schemars::schema_for!(serde_json::Value);
217
218 self.entries.push(ToolMetadata {
219 name: name.clone(),
220 schema: ToolSchema {
221 args_schema: args_root,
222 output_schema: output_root,
223 },
224 });
225 self.dispatchers.push(ToolDispatchMetadata {
226 name,
227 runner: Arc::new(DynamicToolRunner { tool: arc }),
228 });
229 self
230 }
231
232 pub fn build(self) -> Result<ToolSet, ToolSetBuildError> {
233 let mut seen = HashSet::new();
234 let mut catalog = BTreeMap::new();
235 let mut dispatchers = BTreeMap::new();
236
237 for entry in &self.entries {
238 if entry.name.trim().is_empty() {
239 return Err(ToolSetBuildError::InvalidName {
240 name: entry.name.clone(),
241 });
242 }
243
244 if !seen.insert(entry.name.clone()) {
245 return Err(ToolSetBuildError::DuplicateName {
246 name: entry.name.clone(),
247 });
248 }
249
250 catalog.insert(entry.name.clone(), entry.schema.clone());
251 }
252
253 for dispatch in self.dispatchers {
254 dispatchers.insert(dispatch.name, dispatch.runner);
255 }
256
257 Ok(ToolSet {
258 entries: self.entries,
259 schema_catalog: catalog,
260 dispatchers,
261 })
262 }
263}
264
265#[derive(Clone, Debug)]
266struct ToolMetadata {
267 name: String,
268 schema: ToolSchema,
269}
270
271#[derive(Clone)]
272struct ToolDispatchMetadata {
273 name: String,
274 runner: Arc<dyn ErasedToolRunner>,
275}
276
277#[async_trait::async_trait]
278trait ErasedToolRunner: Send + Sync {
279 async fn dispatch(
280 &self,
281 name: &str,
282 args: Value,
283 call_id: String,
284 ctx: ToolContext,
285 ) -> Result<Value, ToolDispatchError>;
286}
287
/// Adapts a [`TypedTool`] to [`ErasedToolRunner`] by converting JSON arguments
/// into the tool's typed `Args` and its typed `Output` back into JSON.
#[derive(Clone)]
struct TypedToolRunner<T> {
    /// The wrapped tool instance.
    tool: T,
}
292
293#[async_trait::async_trait]
294impl<T> ErasedToolRunner for TypedToolRunner<T>
295where
296 T: TypedTool + Send + Sync,
297{
298 async fn dispatch(
299 &self,
300 name: &str,
301 args: Value,
302 call_id: String,
303 ctx: ToolContext,
304 ) -> Result<Value, ToolDispatchError> {
305 let typed_args = serde_json::from_value::<T::Args>(args).map_err(|source| {
306 ToolDispatchError::InvalidArgs {
307 name: name.to_string(),
308 call_id: call_id.clone(),
309 source,
310 }
311 })?;
312
313 let output = self.tool.run(typed_args, ctx).await.map_err(|source| {
314 ToolDispatchError::Execution {
315 name: name.to_string(),
316 call_id: call_id.clone(),
317 source,
318 }
319 })?;
320
321 serde_json::to_value(output).map_err(|source| ToolDispatchError::Serialization {
322 name: name.to_string(),
323 call_id,
324 source,
325 })
326 }
327}
328
329struct DynamicToolRunner {
331 tool: Arc<dyn Tool>,
332}
333
334#[async_trait::async_trait]
335impl ErasedToolRunner for DynamicToolRunner {
336 async fn dispatch(
337 &self,
338 name: &str,
339 args: Value,
340 call_id: String,
341 _ctx: ToolContext,
342 ) -> Result<Value, ToolDispatchError> {
343 self.tool.invoke(args).await.map_err(|source| ToolDispatchError::Execution {
344 name: name.to_string(),
345 call_id,
346 source,
347 })
348 }
349}
350
/// Reasons `ToolSetBuilder::build` can reject a set of registrations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolSetBuildError {
    /// A tool name was empty or consisted only of whitespace.
    InvalidName { name: String },
    /// The same tool name was registered more than once.
    DuplicateName { name: String },
}

impl std::fmt::Display for ToolSetBuildError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidName { name } => {
                write!(f, "tool name must not be empty or whitespace: {name:?}")
            }
            Self::DuplicateName { name } => write!(f, "duplicate tool name: {name}"),
        }
    }
}

// There is no underlying cause to expose, so the default `source()` is kept.
impl std::error::Error for ToolSetBuildError {}