Skip to main content

wesichain_agent/
tooling.rs

1use std::collections::{BTreeMap, HashSet};
2use std::sync::Arc;
3
4use schemars::schema::RootSchema;
5use serde::Deserialize;
6use serde_json::Value;
7
8use crate::error::ToolDispatchError;
9pub use wesichain_core::{CancellationToken, Tool, ToolContext, ToolSpec, TypedTool};
10
11pub type ToolError = wesichain_core::ToolError;
12
13#[derive(Clone, Debug)]
14pub struct ToolSchema {
15    pub args_schema: RootSchema,
16    pub output_schema: RootSchema,
17}
18
19#[derive(Clone, Debug, Deserialize)]
20pub struct ToolCallEnvelope {
21    pub name: String,
22    pub args: Value,
23    pub call_id: String,
24}
25
26#[derive(Clone)]
27pub struct ToolSet {
28    entries: Vec<ToolMetadata>,
29    schema_catalog: BTreeMap<String, ToolSchema>,
30    dispatchers: BTreeMap<String, Arc<dyn ErasedToolRunner>>,
31}
32
33impl std::fmt::Debug for ToolSet {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("ToolSet")
36            .field("entries", &self.entries)
37            .field("schema_catalog_len", &self.schema_catalog.len())
38            .field("dispatchers_len", &self.dispatchers.len())
39            .finish()
40    }
41}
42
43impl ToolSet {
44    #[allow(
45        clippy::new_ret_no_self,
46        reason = "ToolSet::new intentionally starts a builder-first registration API"
47    )]
48    pub fn new() -> ToolSetBuilder {
49        ToolSetBuilder {
50            entries: Vec::new(),
51            dispatchers: Vec::new(),
52        }
53    }
54
55    pub fn names(&self) -> Vec<&str> {
56        self.entries
57            .iter()
58            .map(|entry| entry.name.as_str())
59            .collect()
60    }
61
62    pub fn schema_catalog(&self) -> &BTreeMap<String, ToolSchema> {
63        &self.schema_catalog
64    }
65
66    /// Build a [`Vec<ToolSpec>`] suitable for inclusion in an [`LlmRequest`].
67    ///
68    /// The `parameters` field is the JSON Schema for the tool's arguments.
69    /// The `description` field is derived from the schema's metadata when
70    /// available; otherwise the tool name is used as a fallback.
71    ///
72    /// [`LlmRequest`]: wesichain_core::LlmRequest
73    pub fn tool_specs(&self) -> Vec<ToolSpec> {
74        self.entries
75            .iter()
76            .map(|e| {
77                let description = e
78                    .schema
79                    .args_schema
80                    .schema
81                    .metadata
82                    .as_ref()
83                    .and_then(|m| m.description.clone())
84                    .unwrap_or_else(|| e.name.clone());
85
86                let parameters =
87                    serde_json::to_value(&e.schema.args_schema).unwrap_or(Value::Object(
88                        serde_json::Map::new(),
89                    ));
90
91                ToolSpec { name: e.name.clone(), description, parameters }
92            })
93            .collect()
94    }
95
96    pub async fn dispatch(
97        &self,
98        envelope: ToolCallEnvelope,
99        ctx: ToolContext,
100    ) -> Result<Value, ToolDispatchError> {
101        let Some(dispatcher) = self.dispatchers.get(&envelope.name) else {
102            return Err(ToolDispatchError::UnknownTool {
103                name: envelope.name,
104                call_id: envelope.call_id,
105            });
106        };
107
108        dispatcher
109            .dispatch(&envelope.name, envelope.args, envelope.call_id, ctx)
110            .await
111    }
112
113    /// Dispatch multiple tool calls concurrently via `tokio::spawn`.
114    ///
115    /// Results are returned in the same order as `envelopes`.
116    pub async fn dispatch_many(
117        &self,
118        envelopes: Vec<ToolCallEnvelope>,
119        ctx: ToolContext,
120    ) -> Vec<(String, Result<Value, ToolDispatchError>)> {
121        let mut handles = Vec::with_capacity(envelopes.len());
122
123        for envelope in envelopes {
124            let call_id = envelope.call_id.clone();
125            match self.dispatchers.get(&envelope.name) {
126                None => {
127                    let err = Err(ToolDispatchError::UnknownTool {
128                        name: envelope.name.clone(),
129                        call_id: envelope.call_id.clone(),
130                    });
131                    handles.push((call_id, tokio::spawn(async move { err })));
132                }
133                Some(dispatcher) => {
134                    let dispatcher = dispatcher.clone();
135                    let ctx = ctx.clone();
136                    let name = envelope.name.clone();
137                    let args = envelope.args.clone();
138                    let cid = envelope.call_id.clone();
139                    handles.push((
140                        call_id,
141                        tokio::spawn(async move {
142                            dispatcher.dispatch(&name, args, cid, ctx).await
143                        }),
144                    ));
145                }
146            }
147        }
148
149        let mut results = Vec::with_capacity(handles.len());
150        for (call_id, handle) in handles {
151            let result = match handle.await {
152                Ok(r) => r,
153                Err(join_err) => Err(ToolDispatchError::Execution {
154                    name: String::new(),
155                    call_id: call_id.clone(),
156                    source: crate::ToolError::ExecutionFailed(format!("task panicked: {join_err}")),
157                }),
158            };
159            results.push((call_id, result));
160        }
161        results
162    }
163}
164
165#[derive(Clone, Default)]
166pub struct ToolSetBuilder {
167    entries: Vec<ToolMetadata>,
168    dispatchers: Vec<ToolDispatchMetadata>,
169}
170
171impl ToolSetBuilder {
172    pub fn register<T>(mut self) -> Self
173    where
174        T: TypedTool,
175    {
176        self.entries.push(ToolMetadata {
177            name: T::NAME.to_string(),
178            schema: ToolSchema {
179                args_schema: schemars::schema_for!(T::Args),
180                output_schema: schemars::schema_for!(T::Output),
181            },
182        });
183        self
184    }
185
186    pub fn register_with<T>(mut self, tool: T) -> Self
187    where
188        T: TypedTool + Send + Sync + 'static,
189    {
190        self.entries.push(ToolMetadata {
191            name: T::NAME.to_string(),
192            schema: ToolSchema {
193                args_schema: schemars::schema_for!(T::Args),
194                output_schema: schemars::schema_for!(T::Output),
195            },
196        });
197        self.dispatchers.push(ToolDispatchMetadata {
198            name: T::NAME.to_string(),
199            runner: Arc::new(TypedToolRunner { tool }),
200        });
201        self
202    }
203
204    /// Register a `Tool` implementation (dynamic dispatch) by instance.
205    ///
206    /// Use this for tools that implement `wesichain_core::Tool` directly rather
207    /// than `TypedTool` — e.g. `AgentAsTool` and MCP bridge tools.
208    pub fn register_dynamic(mut self, tool: impl Tool + 'static) -> Self {
209        let name = tool.name().to_string();
210        let arc: Arc<dyn Tool> = Arc::new(tool);
211
212        // Build a minimal schema entry using the tool's schema() Value.
213        // We store the schema as args_schema and leave output_schema empty.
214        let args_root: RootSchema = serde_json::from_value(arc.schema())
215            .unwrap_or_else(|_| schemars::schema_for!(serde_json::Value));
216        let output_root: RootSchema = schemars::schema_for!(serde_json::Value);
217
218        self.entries.push(ToolMetadata {
219            name: name.clone(),
220            schema: ToolSchema {
221                args_schema: args_root,
222                output_schema: output_root,
223            },
224        });
225        self.dispatchers.push(ToolDispatchMetadata {
226            name,
227            runner: Arc::new(DynamicToolRunner { tool: arc }),
228        });
229        self
230    }
231
232    pub fn build(self) -> Result<ToolSet, ToolSetBuildError> {
233        let mut seen = HashSet::new();
234        let mut catalog = BTreeMap::new();
235        let mut dispatchers = BTreeMap::new();
236
237        for entry in &self.entries {
238            if entry.name.trim().is_empty() {
239                return Err(ToolSetBuildError::InvalidName {
240                    name: entry.name.clone(),
241                });
242            }
243
244            if !seen.insert(entry.name.clone()) {
245                return Err(ToolSetBuildError::DuplicateName {
246                    name: entry.name.clone(),
247                });
248            }
249
250            catalog.insert(entry.name.clone(), entry.schema.clone());
251        }
252
253        for dispatch in self.dispatchers {
254            dispatchers.insert(dispatch.name, dispatch.runner);
255        }
256
257        Ok(ToolSet {
258            entries: self.entries,
259            schema_catalog: catalog,
260            dispatchers,
261        })
262    }
263}
264
265#[derive(Clone, Debug)]
266struct ToolMetadata {
267    name: String,
268    schema: ToolSchema,
269}
270
271#[derive(Clone)]
272struct ToolDispatchMetadata {
273    name: String,
274    runner: Arc<dyn ErasedToolRunner>,
275}
276
277#[async_trait::async_trait]
278trait ErasedToolRunner: Send + Sync {
279    async fn dispatch(
280        &self,
281        name: &str,
282        args: Value,
283        call_id: String,
284        ctx: ToolContext,
285    ) -> Result<Value, ToolDispatchError>;
286}
287
/// Wraps a `TypedTool` instance so it can be used as an `ErasedToolRunner`.
#[derive(Clone)]
struct TypedToolRunner<T> {
    /// The wrapped tool instance.
    tool: T,
}
292
293#[async_trait::async_trait]
294impl<T> ErasedToolRunner for TypedToolRunner<T>
295where
296    T: TypedTool + Send + Sync,
297{
298    async fn dispatch(
299        &self,
300        name: &str,
301        args: Value,
302        call_id: String,
303        ctx: ToolContext,
304    ) -> Result<Value, ToolDispatchError> {
305        let typed_args = serde_json::from_value::<T::Args>(args).map_err(|source| {
306            ToolDispatchError::InvalidArgs {
307                name: name.to_string(),
308                call_id: call_id.clone(),
309                source,
310            }
311        })?;
312
313        let output = self.tool.run(typed_args, ctx).await.map_err(|source| {
314            ToolDispatchError::Execution {
315                name: name.to_string(),
316                call_id: call_id.clone(),
317                source,
318            }
319        })?;
320
321        serde_json::to_value(output).map_err(|source| ToolDispatchError::Serialization {
322            name: name.to_string(),
323            call_id,
324            source,
325        })
326    }
327}
328
329/// Wraps a `Tool` (dynamic dispatch) as an `ErasedToolRunner`.
330struct DynamicToolRunner {
331    tool: Arc<dyn Tool>,
332}
333
334#[async_trait::async_trait]
335impl ErasedToolRunner for DynamicToolRunner {
336    async fn dispatch(
337        &self,
338        name: &str,
339        args: Value,
340        call_id: String,
341        _ctx: ToolContext,
342    ) -> Result<Value, ToolDispatchError> {
343        self.tool.invoke(args).await.map_err(|source| ToolDispatchError::Execution {
344            name: name.to_string(),
345            call_id,
346            source,
347        })
348    }
349}
350
/// Reasons a tool set can fail validation when it is built.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolSetBuildError {
    /// A tool name was empty or consisted only of whitespace.
    InvalidName { name: String },
    /// The same tool name was registered more than once.
    DuplicateName { name: String },
}

impl std::fmt::Display for ToolSetBuildError {
    /// Render a human-readable message; the invalid name is shown in its
    /// `Debug` (quoted) form so whitespace is visible.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidName { name } => {
                write!(f, "tool name must not be empty or whitespace: {name:?}")
            }
            Self::DuplicateName { name } => write!(f, "duplicate tool name: {name}"),
        }
    }
}

impl std::error::Error for ToolSetBuildError {}