1use async_trait::async_trait;
7use serde_json::Value as JsonValue;
8use std::future::Future;
9use std::marker::PhantomData;
10use std::pin::Pin;
11use std::sync::Arc;
12
13use crate::{
14 definition::ToolDefinition, return_types::ToolResult, schema::SchemaBuilder, RunContext,
15};
16
17#[async_trait]
47pub trait Tool<Deps = ()>: Send + Sync {
48 fn definition(&self) -> ToolDefinition;
53
54 async fn call(&self, ctx: &RunContext<Deps>, args: JsonValue) -> ToolResult;
65
66 fn max_retries(&self) -> Option<u32> {
70 None
71 }
72
73 async fn prepare(
80 &self,
81 _ctx: &RunContext<Deps>,
82 def: ToolDefinition,
83 ) -> Option<ToolDefinition> {
84 Some(def)
85 }
86
87 fn name(&self) -> String {
89 self.definition().name.clone()
90 }
91
92 fn description(&self) -> String {
94 self.definition().description.clone()
95 }
96}
97
/// A shared, type-erased tool: despite the name this is an [`Arc`] (not a `Box`) around a
/// `dyn` [`Tool`], so tools of different concrete types can be stored together and cloned
/// by bumping a reference count.
pub type BoxedTool<Deps> = Arc<dyn Tool<Deps>>;
100
/// A tool backed by a function/closure that returns a pinned, boxed future.
///
/// Construct it with [`FunctionTool::new`] and configure it with the `with_*` builders. It
/// implements [`Tool`] when `F` has the signature required by that impl.
pub struct FunctionTool<F, Deps = ()> {
    /// Name reported in the tool definition.
    name: String,
    /// Description reported in the tool definition.
    description: String,
    /// Parameter schema passed to `ToolDefinition::with_parameters`.
    parameters: JsonValue,
    /// The callable invoked by `Tool::call`.
    function: F,
    /// Per-tool retry limit; `None` until `with_max_retries` is called.
    max_retries: Option<u32>,
    /// Strict flag, forwarded to the definition only when set via `with_strict`.
    strict: Option<bool>,
    /// Ties the struct to `Deps` without storing one. `fn() -> Deps` keeps the struct's
    /// `Send`/`Sync` auto traits independent of `Deps`.
    _phantom: PhantomData<fn() -> Deps>,
}
135
136impl<F, Deps> FunctionTool<F, Deps> {
137 pub fn new(
139 name: impl Into<String>,
140 description: impl Into<String>,
141 parameters: impl Into<JsonValue>,
142 function: F,
143 ) -> Self {
144 Self {
145 name: name.into(),
146 description: description.into(),
147 parameters: parameters.into(),
148 function,
149 max_retries: None,
150 strict: None,
151 _phantom: PhantomData,
152 }
153 }
154
155 #[must_use]
157 pub fn with_max_retries(mut self, retries: u32) -> Self {
158 self.max_retries = Some(retries);
159 self
160 }
161
162 #[must_use]
164 pub fn with_strict(mut self, strict: bool) -> Self {
165 self.strict = Some(strict);
166 self
167 }
168}
169
/// A heap-allocated, pinned, `Send` future that resolves to `T`.
///
/// The `'static` bound is spelled out here, but it is also the implicit default for
/// `Box<dyn Trait>`. As a result, a future of this type cannot borrow from the arguments of
/// the function that created it. In particular, a `FunctionTool` closure cannot keep using
/// its `&RunContext` inside the future it returns.
type PinnedFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
172
173#[async_trait]
174impl<F, Deps> Tool<Deps> for FunctionTool<F, Deps>
175where
176 F: for<'a> Fn(&'a RunContext<Deps>, JsonValue) -> PinnedFuture<ToolResult> + Send + Sync,
177 Deps: Send + Sync,
178{
179 fn definition(&self) -> ToolDefinition {
180 let mut def = ToolDefinition::new(&self.name, &self.description)
181 .with_parameters(self.parameters.clone());
182 if let Some(strict) = self.strict {
183 def = def.with_strict(strict);
184 }
185 def
186 }
187
188 async fn call(&self, ctx: &RunContext<Deps>, args: JsonValue) -> ToolResult {
189 (self.function)(ctx, args).await
190 }
191
192 fn max_retries(&self) -> Option<u32> {
193 self.max_retries
194 }
195}
196
197impl<F, Deps> std::fmt::Debug for FunctionTool<F, Deps> {
198 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199 f.debug_struct("FunctionTool")
200 .field("name", &self.name)
201 .field("description", &self.description)
202 .field("max_retries", &self.max_retries)
203 .finish()
204 }
205}
206
/// A tool backed by a plain synchronous function that returns a [`ToolResult`] directly.
///
/// Unlike `FunctionTool`, it has no strict-mode setting.
pub struct SyncFunctionTool<F, Deps = ()> {
    /// Name reported in the tool definition.
    name: String,
    /// Description reported in the tool definition.
    description: String,
    /// Parameter schema passed to `ToolDefinition::with_parameters`.
    parameters: JsonValue,
    /// The synchronous callable invoked by `Tool::call`.
    function: F,
    /// Per-tool retry limit; `None` until `with_max_retries` is called.
    max_retries: Option<u32>,
    /// Ties the struct to `Deps` without storing one. `fn() -> Deps` keeps the struct's
    /// `Send`/`Sync` auto traits independent of `Deps`.
    _phantom: PhantomData<fn() -> Deps>,
}
218
219impl<F, Deps> SyncFunctionTool<F, Deps>
220where
221 F: Fn(&RunContext<Deps>, JsonValue) -> ToolResult + Send + Sync,
222{
223 pub fn new(
225 name: impl Into<String>,
226 description: impl Into<String>,
227 parameters: impl Into<JsonValue>,
228 function: F,
229 ) -> Self {
230 Self {
231 name: name.into(),
232 description: description.into(),
233 parameters: parameters.into(),
234 function,
235 max_retries: None,
236 _phantom: PhantomData,
237 }
238 }
239
240 #[must_use]
242 pub fn with_max_retries(mut self, retries: u32) -> Self {
243 self.max_retries = Some(retries);
244 self
245 }
246}
247
248#[async_trait]
249impl<F, Deps> Tool<Deps> for SyncFunctionTool<F, Deps>
250where
251 F: Fn(&RunContext<Deps>, JsonValue) -> ToolResult + Send + Sync,
252 Deps: Send + Sync,
253{
254 fn definition(&self) -> ToolDefinition {
255 ToolDefinition::new(&self.name, &self.description).with_parameters(self.parameters.clone())
256 }
257
258 async fn call(&self, ctx: &RunContext<Deps>, args: JsonValue) -> ToolResult {
259 (self.function)(ctx, args)
260 }
261
262 fn max_retries(&self) -> Option<u32> {
263 self.max_retries
264 }
265}
266
267impl<F, Deps> std::fmt::Debug for SyncFunctionTool<F, Deps> {
268 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
269 f.debug_struct("SyncFunctionTool")
270 .field("name", &self.name)
271 .field("description", &self.description)
272 .field("max_retries", &self.max_retries)
273 .finish()
274 }
275}
276
277pub fn sync_tool<F, Deps>(
292 name: impl Into<String>,
293 description: impl Into<String>,
294 function: F,
295) -> SyncFunctionTool<F, Deps>
296where
297 F: Fn(&RunContext<Deps>, JsonValue) -> ToolResult + Send + Sync,
298{
299 SyncFunctionTool::new(
300 name,
301 description,
302 SchemaBuilder::new()
303 .build()
304 .expect("SchemaBuilder JSON serialization failed"),
305 function,
306 )
307}
308
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolReturn;

    #[derive(Debug, Clone, Default)]
    struct TestDeps;

    /// Hand-written `Tool` implementation that overrides `max_retries` and relies on the
    /// default `prepare`, `name` and `description`.
    struct TestTool;

    #[async_trait]
    impl Tool<TestDeps> for TestTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("test", "Test tool").with_parameters(
                SchemaBuilder::new()
                    .integer("x", "A number", true)
                    .build()
                    .expect("SchemaBuilder JSON serialization failed"),
            )
        }

        async fn call(&self, _ctx: &RunContext<TestDeps>, args: JsonValue) -> ToolResult {
            let x = args["x"].as_i64().unwrap_or(0);
            Ok(ToolReturn::text(format!("x = {x}")))
        }

        fn max_retries(&self) -> Option<u32> {
            Some(5)
        }
    }

    /// Doubles the integer argument `x`; a missing or non-integer `x` counts as 0.
    fn double(args: &JsonValue) -> ToolResult {
        let x = args["x"].as_i64().unwrap_or(0);
        Ok(ToolReturn::text((x * 2).to_string()))
    }

    #[tokio::test]
    async fn test_tool_trait() {
        let tool = TestTool;
        let ctx = RunContext::new(TestDeps, "test-model");

        assert_eq!(tool.name(), "test");
        assert_eq!(tool.description(), "Test tool");
        assert_eq!(tool.max_retries(), Some(5));

        let result = tool.call(&ctx, serde_json::json!({"x": 42})).await.unwrap();
        assert_eq!(result.as_text(), Some("x = 42"));
    }

    #[tokio::test]
    async fn test_function_tool() {
        let tool = FunctionTool::<_, ()>::new(
            "double",
            "Double a number",
            SchemaBuilder::new()
                .integer("x", "A number", true)
                .build()
                .expect("SchemaBuilder JSON serialization failed"),
            |_ctx: &RunContext<()>, args: JsonValue| -> PinnedFuture<ToolResult> {
                Box::pin(async move { double(&args) })
            },
        )
        .with_max_retries(3)
        .with_strict(true);

        assert_eq!(tool.name(), "double");
        assert_eq!(tool.description(), "Double a number");
        assert_eq!(tool.max_retries(), Some(3));

        let ctx = RunContext::minimal("test");
        let result = tool.call(&ctx, serde_json::json!({"x": 21})).await.unwrap();
        assert_eq!(result.as_text(), Some("42"));
    }

    #[tokio::test]
    async fn test_sync_function_tool() {
        let tool = SyncFunctionTool::new(
            "add",
            "Add numbers",
            SchemaBuilder::new()
                .number("a", "First", true)
                .number("b", "Second", true)
                .build()
                .expect("SchemaBuilder JSON serialization failed"),
            |_ctx: &RunContext<()>, args: JsonValue| {
                let a = args["a"].as_f64().unwrap_or(0.0);
                let b = args["b"].as_f64().unwrap_or(0.0);
                Ok(ToolReturn::text(format!("{}", a + b)))
            },
        );

        let ctx = RunContext::minimal("test");
        let result = tool
            .call(&ctx, serde_json::json!({"a": 1.5, "b": 2.5}))
            .await
            .unwrap();
        assert_eq!(result.as_text(), Some("4"));
    }

    #[tokio::test]
    async fn test_tool_prepare() {
        let tool = TestTool;
        let ctx = RunContext::new(TestDeps, "test");
        let def = tool.definition();
        let prepared = tool
            .prepare(&ctx, def.clone())
            .await
            .expect("default prepare keeps the tool");
        assert_eq!(prepared.name, def.name);
    }

    #[tokio::test]
    async fn test_sync_tool_helper() {
        let tool = sync_tool::<_, ()>("echo", "Echo", |_ctx, args| {
            let msg = args["message"].as_str().unwrap_or("default");
            Ok(ToolReturn::text(msg))
        });
        assert_eq!(tool.name, "echo");
        assert_eq!(tool.description(), "Echo");

        // The helper must actually wire up the function, not just store the name.
        let ctx = RunContext::minimal("test");
        let echoed = tool
            .call(&ctx, serde_json::json!({"message": "hi"}))
            .await
            .unwrap();
        assert_eq!(echoed.as_text(), Some("hi"));

        let fallback = tool.call(&ctx, serde_json::json!({})).await.unwrap();
        assert_eq!(fallback.as_text(), Some("default"));
    }
}