Skip to main content

swink_agent/
fn_tool.rs

1//! Closure-based tool builder that implements [`AgentTool`] without requiring
2//! a custom struct or trait implementation.
3//!
4//! # Example
5//!
6//! ```
7//! use schemars::JsonSchema;
8//! use serde::Deserialize;
9//! use swink_agent::{AgentToolResult, FnTool};
10//!
11//! #[derive(Deserialize, JsonSchema)]
12//! struct Params { city: String }
13//!
14//! let tool = FnTool::new("get_weather", "Weather", "Get weather for a city.")
15//!     .with_execute_typed(|params: Params, _cancel| async move {
16//!         AgentToolResult::text(format!("72F in {}", params.city))
17//!     });
18//!
19//! assert_eq!(swink_agent::AgentTool::name(&tool), "get_weather");
20//! ```
21
22use std::future::Future;
23use std::sync::Arc;
24
25use serde::de::DeserializeOwned;
26use serde_json::Value;
27use tokio_util::sync::CancellationToken;
28
29use crate::tool::{
30    AgentTool, AgentToolResult, ToolFuture, debug_validated_schema, permissive_object_schema,
31    validated_schema_for,
32};
33
34// ─── Type aliases for stored closures ───────────────────────────────────────
35
36type ExecuteFn = Arc<
37    dyn Fn(
38            String,
39            Value,
40            CancellationToken,
41            Option<Box<dyn Fn(AgentToolResult) + Send + Sync>>,
42        ) -> ToolFuture<'static>
43        + Send
44        + Sync,
45>;
46
47type ApprovalContextFn = Arc<dyn Fn(&Value) -> Option<Value> + Send + Sync>;
48
49// ─── FnTool ─────────────────────────────────────────────────────────────────
50
51/// A tool built entirely from closures and configuration, implementing
52/// [`AgentTool`] without requiring a custom struct.
53///
54/// Use the builder methods to configure the tool's schema, approval
55/// requirements, and execution logic.
56pub struct FnTool {
57    name: String,
58    label: String,
59    description: String,
60    schema: Value,
61    requires_approval: bool,
62    execute_fn: ExecuteFn,
63    approval_context_fn: Option<ApprovalContextFn>,
64}
65
66impl FnTool {
67    /// Create a new `FnTool` with the given name, label, and description.
68    ///
69    /// The default schema accepts any object and the default execute returns
70    /// an error indicating the tool is not implemented.
71    #[must_use]
72    pub fn new(
73        name: impl Into<String>,
74        label: impl Into<String>,
75        description: impl Into<String>,
76    ) -> Self {
77        Self {
78            name: name.into(),
79            label: label.into(),
80            description: description.into(),
81            schema: permissive_object_schema(),
82            requires_approval: false,
83            execute_fn: Arc::new(|_, _, _, _| {
84                Box::pin(async { AgentToolResult::error("not implemented") })
85            }),
86            approval_context_fn: None,
87        }
88    }
89
90    /// Set the parameters schema from a type implementing
91    /// [`JsonSchema`](schemars::JsonSchema).
92    #[must_use]
93    pub fn with_schema_for<T: schemars::JsonSchema>(mut self) -> Self {
94        self.schema = validated_schema_for::<T>();
95        self
96    }
97
98    /// Set the parameters schema from a raw JSON value.
99    #[must_use]
100    pub fn with_schema(mut self, schema: Value) -> Self {
101        self.schema = debug_validated_schema(schema);
102        self
103    }
104
105    /// Set whether this tool requires user approval before execution.
106    #[must_use]
107    pub const fn with_requires_approval(mut self, requires: bool) -> Self {
108        self.requires_approval = requires;
109        self
110    }
111
112    /// Set the execution function using the full signature.
113    ///
114    /// The closure receives `(tool_call_id, params, cancellation_token, on_update)`.
115    #[must_use]
116    pub fn with_execute<F, Fut>(mut self, f: F) -> Self
117    where
118        F: Fn(
119                String,
120                Value,
121                CancellationToken,
122                Option<Box<dyn Fn(AgentToolResult) + Send + Sync>>,
123            ) -> Fut
124            + Send
125            + Sync
126            + 'static,
127        Fut: Future<Output = AgentToolResult> + Send + 'static,
128    {
129        self.execute_fn = Arc::new(move |id, params, cancel, on_update| {
130            Box::pin(f(id, params, cancel, on_update))
131        });
132        self
133    }
134
135    /// Set the execution function using a simplified signature.
136    ///
137    /// The closure receives only `(params, cancellation_token)`, ignoring the
138    /// tool call ID and update callback.
139    #[must_use]
140    pub fn with_execute_simple<F, Fut>(mut self, f: F) -> Self
141    where
142        F: Fn(Value, CancellationToken) -> Fut + Send + Sync + 'static,
143        Fut: Future<Output = AgentToolResult> + Send + 'static,
144    {
145        self.execute_fn =
146            Arc::new(move |_id, params, cancel, _on_update| Box::pin(f(params, cancel)));
147        self
148    }
149
150    /// Set the execution function using an explicit untyped async signature.
151    ///
152    /// This is equivalent to [`Self::with_execute_simple`] and exists as a
153    /// discoverability alias for callers looking for an untyped async builder.
154    #[must_use]
155    pub fn with_execute_async<F, Fut>(self, f: F) -> Self
156    where
157        F: Fn(Value, CancellationToken) -> Fut + Send + Sync + 'static,
158        Fut: Future<Output = AgentToolResult> + Send + 'static,
159    {
160        self.with_execute_simple(f)
161    }
162
163    /// Set the execution function using a typed parameter struct.
164    ///
165    /// This derives the schema from `T` and deserializes validated params into
166    /// `T` before calling the closure. On deserialization failure, execution
167    /// returns `AgentToolResult::error("invalid parameters: ...")`.
168    #[must_use]
169    pub fn with_execute_typed<T, F, Fut>(mut self, f: F) -> Self
170    where
171        T: DeserializeOwned + schemars::JsonSchema + Send + 'static,
172        F: Fn(T, CancellationToken) -> Fut + Send + Sync + 'static,
173        Fut: Future<Output = AgentToolResult> + Send + 'static,
174    {
175        self.schema = validated_schema_for::<T>();
176        self.execute_fn = Arc::new(move |_id, params, cancel, _on_update| {
177            let parsed: T = match serde_json::from_value(params) {
178                Ok(parsed) => parsed,
179                Err(err) => {
180                    return Box::pin(async move {
181                        AgentToolResult::error(format!("invalid parameters: {err}"))
182                    });
183                }
184            };
185            Box::pin(f(parsed, cancel))
186        });
187        self
188    }
189
190    /// Set a closure that provides rich context for the approval UI.
191    ///
192    /// When the tool requires approval, this closure is called to produce
193    /// context that is attached to the [`ToolApprovalRequest`](crate::ToolApprovalRequest).
194    #[must_use]
195    pub fn with_approval_context<F>(mut self, f: F) -> Self
196    where
197        F: Fn(&Value) -> Option<Value> + Send + Sync + 'static,
198    {
199        self.approval_context_fn = Some(Arc::new(f));
200        self
201    }
202}
203
204impl AgentTool for FnTool {
205    fn name(&self) -> &str {
206        &self.name
207    }
208
209    fn label(&self) -> &str {
210        &self.label
211    }
212
213    fn description(&self) -> &str {
214        &self.description
215    }
216
217    fn parameters_schema(&self) -> &Value {
218        &self.schema
219    }
220
221    fn requires_approval(&self) -> bool {
222        self.requires_approval
223    }
224
225    fn approval_context(&self, params: &Value) -> Option<Value> {
226        self.approval_context_fn.as_ref().and_then(|f| f(params))
227    }
228
229    fn execute(
230        &self,
231        tool_call_id: &str,
232        params: Value,
233        cancellation_token: CancellationToken,
234        on_update: Option<Box<dyn Fn(AgentToolResult) + Send + Sync>>,
235        _state: std::sync::Arc<std::sync::RwLock<crate::SessionState>>,
236        _credential: Option<crate::credential::ResolvedCredential>,
237    ) -> ToolFuture<'_> {
238        let fut = (self.execute_fn)(
239            tool_call_id.to_owned(),
240            params,
241            cancellation_token,
242            on_update,
243        );
244        Box::pin(fut)
245    }
246}
247
248impl std::fmt::Debug for FnTool {
249    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
250        f.debug_struct("FnTool")
251            .field("name", &self.name)
252            .field("label", &self.label)
253            .field("description", &self.description)
254            .field("requires_approval", &self.requires_approval)
255            .finish_non_exhaustive()
256    }
257}
258
259// ─── Compile-time Send + Sync assertion ─────────────────────────────────────
260
261const _: () = {
262    const fn assert_send_sync<T: Send + Sync>() {}
263    assert_send_sync::<FnTool>();
264};
265
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use schemars::JsonSchema;
    use serde::Deserialize;
    use serde_json::json;
    use tokio_util::sync::CancellationToken;

    use super::*;
    use crate::ContentBlock;

    /// Fresh session state; `FnTool` ignores it, but `execute` requires one.
    fn test_state() -> std::sync::Arc<std::sync::RwLock<crate::SessionState>> {
        std::sync::Arc::new(std::sync::RwLock::new(crate::SessionState::new()))
    }

    fn sample_tool() -> FnTool {
        FnTool::new("test", "Test", "A test tool.")
    }

    /// Executes `tool` with no update callback and no credential.
    async fn run(tool: &FnTool, call_id: &str, params: Value) -> AgentToolResult {
        tool.execute(
            call_id,
            params,
            CancellationToken::new(),
            None,
            test_state(),
            None,
        )
        .await
    }

    /// Asserts `schema` is an object schema whose `required` list names `city`.
    fn assert_requires_city(schema: &Value) {
        assert_eq!(schema["type"], "object");
        let required = schema["required"]
            .as_array()
            .expect("schema should list required properties");
        assert!(required.contains(&json!("city")));
    }

    /// Typed parameters shared by the schema and typed-execute tests. The
    /// `city` field is read by the typed closures, so no dead-code allowance
    /// is needed.
    #[derive(Deserialize, JsonSchema)]
    struct CityParams {
        city: String,
    }

    #[test]
    fn metadata_matches_constructor() {
        let tool = sample_tool();
        assert_eq!(tool.name(), "test");
        assert_eq!(tool.label(), "Test");
        assert_eq!(tool.description(), "A test tool.");
        assert!(!tool.requires_approval());
    }

    #[tokio::test]
    async fn default_execute_returns_error() {
        let result = run(&sample_tool(), "id", json!({})).await;
        assert!(result.is_error);
        assert!(
            ContentBlock::extract_text(&result.content).contains("not implemented"),
            "expected not-implemented error, got: {:?}",
            result.content
        );
    }

    #[tokio::test]
    async fn simple_execute_receives_params() {
        let tool = FnTool::new("echo", "Echo", "Echo params.").with_execute_simple(
            |params, _cancel| async move {
                let msg = params["msg"].as_str().unwrap_or("none").to_owned();
                AgentToolResult::text(msg)
            },
        );

        let result = run(&tool, "id", json!({"msg": "hello"})).await;
        assert!(!result.is_error);
        assert_eq!(result.content.len(), 1);
        assert_eq!(ContentBlock::extract_text(&result.content), "hello");
    }

    #[tokio::test]
    async fn async_execute_receives_params() {
        let tool = FnTool::new("echo", "Echo", "Echo params.").with_execute_async(
            |params, _cancel| async move {
                let msg = params["msg"].as_str().unwrap_or("none").to_owned();
                AgentToolResult::text(msg)
            },
        );

        let result = run(&tool, "id", json!({"msg": "hello"})).await;
        assert!(!result.is_error);
        assert_eq!(ContentBlock::extract_text(&result.content), "hello");
    }

    #[test]
    fn with_schema_for_sets_schema() {
        let tool = sample_tool().with_schema_for::<CityParams>();
        assert_requires_city(tool.parameters_schema());
    }

    #[test]
    fn approval_flag_is_configurable() {
        let tool = sample_tool().with_requires_approval(true);
        assert!(tool.requires_approval());
    }

    #[test]
    fn approval_context_defaults_to_none() {
        assert_eq!(
            sample_tool().approval_context(&json!({"city": "Paris"})),
            None
        );
    }

    #[test]
    fn approval_context_uses_configured_closure() {
        let tool = sample_tool()
            .with_requires_approval(true)
            .with_approval_context(|params| Some(json!({ "summary": params["city"] })));
        assert_eq!(
            tool.approval_context(&json!({"city": "Paris"})),
            Some(json!({"summary": "Paris"}))
        );
    }

    #[tokio::test]
    async fn full_execute_receives_all_args() {
        let tool = FnTool::new("full", "Full", "Full signature.").with_execute(
            |id, _params, _cancel, _on_update| async move {
                AgentToolResult::text(format!("id={id}"))
            },
        );

        let result = run(&tool, "call_42", json!({})).await;
        assert!(!result.is_error);
        assert_eq!(ContentBlock::extract_text(&result.content), "id=call_42");
    }

    #[tokio::test]
    async fn full_execute_forwards_update_callback() {
        let tool = FnTool::new("stream", "Stream", "Emits an update.").with_execute(
            |_id, _params, _cancel, on_update| async move {
                if let Some(emit) = on_update {
                    emit(AgentToolResult::text("partial".to_owned()));
                }
                AgentToolResult::text("done".to_owned())
            },
        );

        let updates = std::sync::Arc::new(AtomicUsize::new(0));
        let counter = std::sync::Arc::clone(&updates);
        let on_update: Box<dyn Fn(AgentToolResult) + Send + Sync> = Box::new(move |_partial| {
            counter.fetch_add(1, Ordering::SeqCst);
        });

        let result = tool
            .execute(
                "id",
                json!({}),
                CancellationToken::new(),
                Some(on_update),
                test_state(),
                None,
            )
            .await;
        assert!(!result.is_error);
        assert_eq!(ContentBlock::extract_text(&result.content), "done");
        assert_eq!(updates.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn typed_execute_deserializes_params_and_sets_schema() {
        let tool = FnTool::new("typed", "Typed", "Typed params.").with_execute_typed(
            |params: CityParams, _cancel| async move { AgentToolResult::text(params.city) },
        );

        assert_requires_city(tool.parameters_schema());

        let result = run(&tool, "id", json!({"city": "Chicago"})).await;
        assert!(!result.is_error);
        assert_eq!(ContentBlock::extract_text(&result.content), "Chicago");
    }

    #[tokio::test]
    async fn typed_execute_reports_deserialization_errors() {
        let tool = FnTool::new("typed", "Typed", "Typed params.").with_execute_typed(
            |params: CityParams, _cancel| async move { AgentToolResult::text(params.city) },
        );

        let result = run(&tool, "id", json!({"city": 42})).await;
        assert!(result.is_error);
        assert!(
            ContentBlock::extract_text(&result.content).contains("invalid parameters"),
            "expected invalid parameters error, got: {:?}",
            result.content
        );
    }

    #[test]
    fn debug_output_lists_metadata() {
        let rendered = format!("{:?}", sample_tool().with_requires_approval(true));
        assert!(rendered.starts_with("FnTool"));
        assert!(rendered.contains("name: \"test\""));
        assert!(rendered.contains("requires_approval: true"));
    }
}