Skip to main content

serdes_ai_toolsets/
wrapper.rs

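//! Wrapper toolset implementation.
//!
//! This module provides [`WrapperToolset`], which runs user-supplied hooks
//! before and after each call to an inner toolset (for example for logging or
//! metrics), and [`LoggingWrapper`], a ready-made wrapper that logs tool calls
//! through `tracing`.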
5
6use async_trait::async_trait;
7use serde_json::Value as JsonValue;
8use serdes_ai_tools::{RunContext, ToolError, ToolReturn};
9use std::collections::HashMap;
10use std::marker::PhantomData;
11use std::sync::Arc;
12
13use crate::{AbstractToolset, ToolsetTool};
14
15/// Type alias for before-call hooks.
16pub type BeforeCallHook<Deps> = dyn Fn(&str, &JsonValue, &RunContext<Deps>) + Send + Sync;
17
18/// Type alias for after-call hooks.
19pub type AfterCallHook<Deps> =
20    dyn Fn(&str, &Result<ToolReturn, ToolError>, &RunContext<Deps>) + Send + Sync;
21
22/// Wrapper that allows custom pre/post processing.
23///
24/// This is useful for adding logging, metrics, or other cross-cutting
25/// concerns to tool calls.
26///
27/// # Example
28///
29/// ```ignore
30/// use serdes_ai_toolsets::{WrapperToolset, FunctionToolset};
31///
32/// let toolset = FunctionToolset::new().tool(my_tool);
33///
34/// let wrapped = WrapperToolset::new(toolset)
35///     .before(|name, args, ctx| {
36///         println!("Calling tool: {} with args: {:?}", name, args);
37///     })
38///     .after(|name, result, ctx| {
39///         match result {
40///             Ok(_) => println!("Tool {} succeeded", name),
41///             Err(e) => println!("Tool {} failed: {}", name, e),
42///         }
43///     });
44/// ```
45pub struct WrapperToolset<T, Deps = ()> {
46    inner: T,
47    before_call: Option<Arc<BeforeCallHook<Deps>>>,
48    after_call: Option<Arc<AfterCallHook<Deps>>>,
49    _phantom: PhantomData<fn() -> Deps>,
50}
51
52impl<T, Deps> WrapperToolset<T, Deps>
53where
54    T: AbstractToolset<Deps>,
55{
56    /// Create a new wrapper toolset.
57    pub fn new(inner: T) -> Self {
58        Self {
59            inner,
60            before_call: None,
61            after_call: None,
62            _phantom: PhantomData,
63        }
64    }
65
66    /// Add a before-call hook.
67    #[must_use]
68    pub fn before<F>(mut self, f: F) -> Self
69    where
70        F: Fn(&str, &JsonValue, &RunContext<Deps>) + Send + Sync + 'static,
71    {
72        self.before_call = Some(Arc::new(f));
73        self
74    }
75
76    /// Add an after-call hook.
77    #[must_use]
78    pub fn after<F>(mut self, f: F) -> Self
79    where
80        F: Fn(&str, &Result<ToolReturn, ToolError>, &RunContext<Deps>) + Send + Sync + 'static,
81    {
82        self.after_call = Some(Arc::new(f));
83        self
84    }
85
86    /// Get the inner toolset.
87    #[must_use]
88    pub fn inner(&self) -> &T {
89        &self.inner
90    }
91}
92
93#[async_trait]
94impl<T, Deps> AbstractToolset<Deps> for WrapperToolset<T, Deps>
95where
96    T: AbstractToolset<Deps>,
97    Deps: Send + Sync,
98{
99    fn id(&self) -> Option<&str> {
100        self.inner.id()
101    }
102
103    fn type_name(&self) -> &'static str {
104        "WrapperToolset"
105    }
106
107    fn label(&self) -> String {
108        format!("WrapperToolset({})", self.inner.label())
109    }
110
111    async fn get_tools(
112        &self,
113        ctx: &RunContext<Deps>,
114    ) -> Result<HashMap<String, ToolsetTool>, ToolError> {
115        self.inner.get_tools(ctx).await
116    }
117
118    async fn call_tool(
119        &self,
120        name: &str,
121        args: JsonValue,
122        ctx: &RunContext<Deps>,
123        tool: &ToolsetTool,
124    ) -> Result<ToolReturn, ToolError> {
125        // Call before hook
126        if let Some(ref before) = self.before_call {
127            before(name, &args, ctx);
128        }
129
130        // Execute the tool
131        let result = self.inner.call_tool(name, args, ctx, tool).await;
132
133        // Call after hook
134        if let Some(ref after) = self.after_call {
135            after(name, &result, ctx);
136        }
137
138        result
139    }
140
141    async fn enter(&self) -> Result<(), ToolError> {
142        self.inner.enter().await
143    }
144
145    async fn exit(&self) -> Result<(), ToolError> {
146        self.inner.exit().await
147    }
148}
149
150impl<T: std::fmt::Debug, Deps> std::fmt::Debug for WrapperToolset<T, Deps> {
151    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152        f.debug_struct("WrapperToolset")
153            .field("inner", &self.inner)
154            .field("has_before", &self.before_call.is_some())
155            .field("has_after", &self.after_call.is_some())
156            .finish()
157    }
158}
159
160/// Logging wrapper for tool calls.
161#[derive(Debug, Clone)]
162pub struct LoggingWrapper {
163    prefix: String,
164}
165
166impl LoggingWrapper {
167    /// Create a new logging wrapper.
168    #[must_use]
169    pub fn new(prefix: impl Into<String>) -> Self {
170        Self {
171            prefix: prefix.into(),
172        }
173    }
174
175    /// Wrap a toolset with logging.
176    pub fn wrap<T, Deps>(self, toolset: T) -> WrapperToolset<T, Deps>
177    where
178        T: AbstractToolset<Deps>,
179        Deps: Send + Sync + 'static,
180    {
181        let before_prefix = self.prefix.clone();
182        let after_prefix = self.prefix.clone();
183
184        WrapperToolset::new(toolset)
185            .before(move |name, args, _ctx| {
186                tracing::debug!(
187                    target: "tool_calls",
188                    "[{}] Calling tool '{}' with args: {}",
189                    before_prefix,
190                    name,
191                    args
192                );
193            })
194            .after(move |name, result, _ctx| match result {
195                Ok(_) => {
196                    tracing::debug!(
197                        target: "tool_calls",
198                        "[{}] Tool '{}' completed successfully",
199                        after_prefix,
200                        name
201                    );
202                }
203                Err(e) => {
204                    tracing::warn!(
205                        target: "tool_calls",
206                        "[{}] Tool '{}' failed: {}",
207                        after_prefix,
208                        name,
209                        e
210                    );
211                }
212            })
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FunctionToolset;
    use async_trait::async_trait;
    use serdes_ai_tools::{Tool, ToolDefinition};
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Minimal tool named `"test"` that always succeeds with the text `"result"`.
    struct TestTool;

    #[async_trait]
    impl Tool<()> for TestTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("test", "Test tool")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("result"))
        }
    }

    /// Looks up the `"test"` tool in `toolset` and invokes it once with `args`.
    ///
    /// Panics if the lookup or the call fails.
    async fn invoke_test_tool<S>(toolset: &S, args: JsonValue)
    where
        S: AbstractToolset<()>,
    {
        let ctx = RunContext::minimal("test");
        let tools = toolset.get_tools(&ctx).await.unwrap();
        let tool = tools.get("test").unwrap();
        toolset.call_tool("test", args, &ctx, tool).await.unwrap();
    }

    #[tokio::test]
    async fn test_wrapper_before_hook() {
        let hits = Arc::new(AtomicU32::new(0));
        let hits_in_hook = Arc::clone(&hits);

        let wrapped = WrapperToolset::new(FunctionToolset::new().tool(TestTool)).before(
            move |_, _, _| {
                hits_in_hook.fetch_add(1, Ordering::SeqCst);
            },
        );

        invoke_test_tool(&wrapped, serde_json::json!({})).await;

        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_wrapper_after_hook() {
        let hits = Arc::new(AtomicU32::new(0));
        let hits_in_hook = Arc::clone(&hits);

        let wrapped = WrapperToolset::new(FunctionToolset::new().tool(TestTool)).after(
            move |_, _, _| {
                hits_in_hook.fetch_add(1, Ordering::SeqCst);
            },
        );

        invoke_test_tool(&wrapped, serde_json::json!({})).await;

        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_wrapper_both_hooks() {
        let events = Arc::new(parking_lot::Mutex::new(Vec::new()));
        let before_events = Arc::clone(&events);
        let after_events = Arc::clone(&events);

        let wrapped = WrapperToolset::new(FunctionToolset::new().tool(TestTool))
            .before(move |_, _, _| {
                before_events.lock().push("before");
            })
            .after(move |_, _, _| {
                after_events.lock().push("after");
            });

        invoke_test_tool(&wrapped, serde_json::json!({})).await;

        let recorded = events.lock();
        assert_eq!(*recorded, vec!["before", "after"]);
    }

    #[tokio::test]
    async fn test_wrapper_receives_args() {
        let seen_name = Arc::new(parking_lot::Mutex::new(String::new()));
        let seen_args = Arc::new(parking_lot::Mutex::new(serde_json::Value::Null));
        let name_slot = Arc::clone(&seen_name);
        let args_slot = Arc::clone(&seen_args);

        let wrapped = WrapperToolset::new(FunctionToolset::new().tool(TestTool)).before(
            move |name, args, _| {
                *name_slot.lock() = name.to_owned();
                *args_slot.lock() = args.clone();
            },
        );

        invoke_test_tool(&wrapped, serde_json::json!({"key": "value"})).await;

        assert_eq!(*seen_name.lock(), "test");
        assert_eq!(seen_args.lock()["key"], "value");
    }
}