1use async_trait::async_trait;
7use serde_json::Value as JsonValue;
8use serdes_ai_tools::{RunContext, ToolError, ToolReturn};
9use std::collections::HashMap;
10use std::marker::PhantomData;
11use std::sync::Arc;
12
13use crate::{AbstractToolset, ToolsetTool};
14
15pub type BeforeCallHook<Deps> = dyn Fn(&str, &JsonValue, &RunContext<Deps>) + Send + Sync;
17
18pub type AfterCallHook<Deps> =
20 dyn Fn(&str, &Result<ToolReturn, ToolError>, &RunContext<Deps>) + Send + Sync;
21
22pub struct WrapperToolset<T, Deps = ()> {
46 inner: T,
47 before_call: Option<Arc<BeforeCallHook<Deps>>>,
48 after_call: Option<Arc<AfterCallHook<Deps>>>,
49 _phantom: PhantomData<fn() -> Deps>,
50}
51
52impl<T, Deps> WrapperToolset<T, Deps>
53where
54 T: AbstractToolset<Deps>,
55{
56 pub fn new(inner: T) -> Self {
58 Self {
59 inner,
60 before_call: None,
61 after_call: None,
62 _phantom: PhantomData,
63 }
64 }
65
66 #[must_use]
68 pub fn before<F>(mut self, f: F) -> Self
69 where
70 F: Fn(&str, &JsonValue, &RunContext<Deps>) + Send + Sync + 'static,
71 {
72 self.before_call = Some(Arc::new(f));
73 self
74 }
75
76 #[must_use]
78 pub fn after<F>(mut self, f: F) -> Self
79 where
80 F: Fn(&str, &Result<ToolReturn, ToolError>, &RunContext<Deps>) + Send + Sync + 'static,
81 {
82 self.after_call = Some(Arc::new(f));
83 self
84 }
85
86 #[must_use]
88 pub fn inner(&self) -> &T {
89 &self.inner
90 }
91}
92
93#[async_trait]
94impl<T, Deps> AbstractToolset<Deps> for WrapperToolset<T, Deps>
95where
96 T: AbstractToolset<Deps>,
97 Deps: Send + Sync,
98{
99 fn id(&self) -> Option<&str> {
100 self.inner.id()
101 }
102
103 fn type_name(&self) -> &'static str {
104 "WrapperToolset"
105 }
106
107 fn label(&self) -> String {
108 format!("WrapperToolset({})", self.inner.label())
109 }
110
111 async fn get_tools(
112 &self,
113 ctx: &RunContext<Deps>,
114 ) -> Result<HashMap<String, ToolsetTool>, ToolError> {
115 self.inner.get_tools(ctx).await
116 }
117
118 async fn call_tool(
119 &self,
120 name: &str,
121 args: JsonValue,
122 ctx: &RunContext<Deps>,
123 tool: &ToolsetTool,
124 ) -> Result<ToolReturn, ToolError> {
125 if let Some(ref before) = self.before_call {
127 before(name, &args, ctx);
128 }
129
130 let result = self.inner.call_tool(name, args, ctx, tool).await;
132
133 if let Some(ref after) = self.after_call {
135 after(name, &result, ctx);
136 }
137
138 result
139 }
140
141 async fn enter(&self) -> Result<(), ToolError> {
142 self.inner.enter().await
143 }
144
145 async fn exit(&self) -> Result<(), ToolError> {
146 self.inner.exit().await
147 }
148}
149
150impl<T: std::fmt::Debug, Deps> std::fmt::Debug for WrapperToolset<T, Deps> {
151 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152 f.debug_struct("WrapperToolset")
153 .field("inner", &self.inner)
154 .field("has_before", &self.before_call.is_some())
155 .field("has_after", &self.after_call.is_some())
156 .finish()
157 }
158}
159
/// Factory for a [`WrapperToolset`] whose hooks emit `tracing` events, each
/// tagged with a caller-chosen prefix.
#[derive(Debug, Clone)]
pub struct LoggingWrapper {
    /// Text shown in square brackets at the start of every log line.
    prefix: String,
}
165
166impl LoggingWrapper {
167 #[must_use]
169 pub fn new(prefix: impl Into<String>) -> Self {
170 Self {
171 prefix: prefix.into(),
172 }
173 }
174
175 pub fn wrap<T, Deps>(self, toolset: T) -> WrapperToolset<T, Deps>
177 where
178 T: AbstractToolset<Deps>,
179 Deps: Send + Sync + 'static,
180 {
181 let before_prefix = self.prefix.clone();
182 let after_prefix = self.prefix.clone();
183
184 WrapperToolset::new(toolset)
185 .before(move |name, args, _ctx| {
186 tracing::debug!(
187 target: "tool_calls",
188 "[{}] Calling tool '{}' with args: {}",
189 before_prefix,
190 name,
191 args
192 );
193 })
194 .after(move |name, result, _ctx| match result {
195 Ok(_) => {
196 tracing::debug!(
197 target: "tool_calls",
198 "[{}] Tool '{}' completed successfully",
199 after_prefix,
200 name
201 );
202 }
203 Err(e) => {
204 tracing::warn!(
205 target: "tool_calls",
206 "[{}] Tool '{}' failed: {}",
207 after_prefix,
208 name,
209 e
210 );
211 }
212 })
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FunctionToolset;
    use async_trait::async_trait;
    use serdes_ai_tools::{Tool, ToolDefinition};
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Minimal tool named `"test"` that always returns the text `"result"`.
    struct TestTool;

    #[async_trait]
    impl Tool<()> for TestTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("test", "Test tool")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("result"))
        }
    }

    /// Looks up the `"test"` tool through `toolset` and invokes it once with
    /// `args`, panicking if either discovery or the call fails.
    async fn invoke_test_tool<T>(toolset: &T, args: JsonValue)
    where
        T: AbstractToolset<()>,
    {
        let ctx = RunContext::minimal("test");
        let tools = toolset.get_tools(&ctx).await.unwrap();
        let tool = tools.get("test").unwrap();
        toolset.call_tool("test", args, &ctx, tool).await.unwrap();
    }

    #[tokio::test]
    async fn test_wrapper_before_hook() {
        let hits = Arc::new(AtomicU32::new(0));
        let hook_hits = Arc::clone(&hits);

        let wrapped =
            WrapperToolset::new(FunctionToolset::new().tool(TestTool)).before(move |_, _, _| {
                hook_hits.fetch_add(1, Ordering::SeqCst);
            });

        invoke_test_tool(&wrapped, serde_json::json!({})).await;

        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_wrapper_after_hook() {
        let hits = Arc::new(AtomicU32::new(0));
        let hook_hits = Arc::clone(&hits);

        let wrapped =
            WrapperToolset::new(FunctionToolset::new().tool(TestTool)).after(move |_, _, _| {
                hook_hits.fetch_add(1, Ordering::SeqCst);
            });

        invoke_test_tool(&wrapped, serde_json::json!({})).await;

        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_wrapper_both_hooks() {
        let events = Arc::new(parking_lot::Mutex::new(Vec::new()));
        let before_events = Arc::clone(&events);
        let after_events = Arc::clone(&events);

        let wrapped = WrapperToolset::new(FunctionToolset::new().tool(TestTool))
            .before(move |_, _, _| {
                before_events.lock().push("before");
            })
            .after(move |_, _, _| {
                after_events.lock().push("after");
            });

        invoke_test_tool(&wrapped, serde_json::json!({})).await;

        let recorded = events.lock();
        assert_eq!(*recorded, vec!["before", "after"]);
    }

    #[tokio::test]
    async fn test_wrapper_receives_args() {
        let seen_name = Arc::new(parking_lot::Mutex::new(String::new()));
        let seen_args = Arc::new(parking_lot::Mutex::new(serde_json::Value::Null));

        let name_slot = Arc::clone(&seen_name);
        let args_slot = Arc::clone(&seen_args);

        let wrapped = WrapperToolset::new(FunctionToolset::new().tool(TestTool)).before(
            move |name, args, _| {
                *name_slot.lock() = name.to_string();
                *args_slot.lock() = args.clone();
            },
        );

        invoke_test_tool(&wrapped, serde_json::json!({"key": "value"})).await;

        assert_eq!(*seen_name.lock(), "test");
        assert_eq!(seen_args.lock()["key"], "value");
    }
}