Skip to main content

serdes_ai_output/
toolset.rs

1//! Output toolset implementation.
2//!
3//! This module provides `OutputToolset`, an internal toolset that captures
4//! structured output via tool calls.
5
6use async_trait::async_trait;
7use parking_lot::RwLock;
8use serde::de::DeserializeOwned;
9use serde::Serialize;
10use serde_json::Value as JsonValue;
11use serdes_ai_tools::{RunContext, ToolError, ToolReturn};
12use serdes_ai_toolsets::{AbstractToolset, ToolsetTool};
13use std::collections::HashMap;
14use std::marker::PhantomData;
15use std::sync::Arc;
16
17use crate::schema::OutputSchema;
18use crate::structured::StructuredOutputSchema;
19
20/// Internal toolset that captures output via tool calls.
21///
22/// This toolset provides a special tool that the model calls to
23/// "return" its structured output. The output is captured and
24/// stored for retrieval.
25///
26/// # Example
27///
28/// ```ignore
29/// use serdes_ai_output::{OutputToolset, StructuredOutputSchema};
30///
31/// let schema = StructuredOutputSchema::<MyOutput>::new(my_json_schema);
32/// let toolset = OutputToolset::new(schema);
33///
34/// // Add to agent's tools...
35/// // After model calls the output tool, retrieve the captured output:
36/// if let Some(output) = toolset.take_output() {
37///     println!("Got output: {:?}", output);
38/// }
39/// ```
40pub struct OutputToolset<T, Deps = ()>
41where
42    T: DeserializeOwned + Send + Sync + 'static,
43{
44    schema: StructuredOutputSchema<T>,
45    captured: Arc<RwLock<Option<T>>>,
46    _phantom: PhantomData<Deps>,
47}
48
49impl<T, Deps> OutputToolset<T, Deps>
50where
51    T: DeserializeOwned + Send + Sync + 'static,
52{
53    /// Create a new output toolset.
54    pub fn new(schema: StructuredOutputSchema<T>) -> Self {
55        Self {
56            schema,
57            captured: Arc::new(RwLock::new(None)),
58            _phantom: PhantomData,
59        }
60    }
61
62    /// Check if output has been captured.
63    #[must_use]
64    pub fn has_output(&self) -> bool {
65        self.captured.read().is_some()
66    }
67
68    /// Take the captured output, if any.
69    pub fn take_output(&self) -> Option<T> {
70        self.captured.write().take()
71    }
72
73    /// Get a reference to the captured output.
74    pub fn get_output(&self) -> Option<T>
75    where
76        T: Clone,
77    {
78        self.captured.read().clone()
79    }
80
81    /// Clear any captured output.
82    pub fn clear(&self) {
83        *self.captured.write() = None;
84    }
85
86    /// Get the schema.
87    #[must_use]
88    pub fn schema(&self) -> &StructuredOutputSchema<T> {
89        &self.schema
90    }
91
92    /// Get the tool name.
93    #[must_use]
94    pub fn tool_name(&self) -> &str {
95        &self.schema.tool_name
96    }
97}
98
99impl<T, Deps> std::fmt::Debug for OutputToolset<T, Deps>
100where
101    T: DeserializeOwned + Send + Sync + 'static,
102{
103    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104        f.debug_struct("OutputToolset")
105            .field("tool_name", &self.schema.tool_name)
106            .field("has_output", &self.has_output())
107            .finish()
108    }
109}
110
111#[async_trait]
112impl<T, Deps> AbstractToolset<Deps> for OutputToolset<T, Deps>
113where
114    T: DeserializeOwned + Serialize + Send + Sync + 'static,
115    Deps: Send + Sync + 'static,
116{
117    fn id(&self) -> Option<&str> {
118        Some("__output__")
119    }
120
121    fn type_name(&self) -> &'static str {
122        "OutputToolset"
123    }
124
125    async fn get_tools(
126        &self,
127        _ctx: &RunContext<Deps>,
128    ) -> Result<HashMap<String, ToolsetTool>, ToolError> {
129        let defs = self.schema.tool_definitions();
130        let mut tools = HashMap::with_capacity(defs.len());
131
132        for def in defs {
133            let name = def.name.clone();
134            tools.insert(name, ToolsetTool::new(def).with_toolset_id("__output__"));
135        }
136
137        Ok(tools)
138    }
139
140    async fn call_tool(
141        &self,
142        name: &str,
143        args: JsonValue,
144        _ctx: &RunContext<Deps>,
145        _tool: &ToolsetTool,
146    ) -> Result<ToolReturn, ToolError> {
147        // Parse and capture the output
148        let output: T = self
149            .schema
150            .parse_tool_call(name, &args)
151            .map_err(|e| ToolError::execution_failed(e.to_string()))?;
152
153        // Store the captured output
154        *self.captured.write() = Some(output);
155
156        // Return success with the args as confirmation
157        Ok(ToolReturn::json(args))
158    }
159}
160
/// A captured output value together with the name of the tool that was called.
#[derive(Debug, Clone)]
pub struct OutputCaptured<T> {
    /// The captured value.
    pub value: T,
    /// Name of the tool that was called.
    pub tool_name: String,
}

impl<T> OutputCaptured<T> {
    /// Pair `value` with `tool_name`, converting the name into an owned `String`.
    pub fn new(value: T, tool_name: impl Into<String>) -> Self {
        let tool_name = tool_name.into();
        Self { value, tool_name }
    }

    /// Consume the wrapper and hand back the captured value, dropping the tool name.
    pub fn into_inner(self) -> T {
        let Self { value, .. } = self;
        value
    }
}
184
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;
    use serdes_ai_tools::{ObjectJsonSchema, PropertySchema};

    /// Minimal output type used by every test in this module.
    #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
    struct TestOutput {
        message: String,
        count: i32,
    }

    /// Schema with a required string `message` and a required integer `count`.
    fn test_schema() -> StructuredOutputSchema<TestOutput> {
        let json_schema = ObjectJsonSchema::new()
            .with_property(
                "message",
                PropertySchema::string("The message").build(),
                true,
            )
            .with_property("count", PropertySchema::integer("The count").build(), true);

        StructuredOutputSchema::new(json_schema)
    }

    #[test]
    fn test_output_toolset_new() {
        let toolset: OutputToolset<TestOutput> = OutputToolset::new(test_schema());
        assert!(!toolset.has_output());
        assert_eq!(toolset.id(), Some("__output__"));
    }

    #[tokio::test]
    async fn test_output_toolset_get_tools() {
        let toolset: OutputToolset<TestOutput, ()> = OutputToolset::new(test_schema());
        let ctx = RunContext::minimal("test");

        let tools = toolset.get_tools(&ctx).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert!(tools.contains_key("final_result"));
    }

    #[tokio::test]
    async fn test_output_toolset_call_and_capture() {
        let toolset: OutputToolset<TestOutput, ()> = OutputToolset::new(test_schema());
        let ctx = RunContext::minimal("test");

        let tools = toolset.get_tools(&ctx).await.unwrap();
        let tool = tools.get("final_result").unwrap();

        let args = serde_json::json!({
            "message": "Hello, World!",
            "count": 42
        });

        let result = toolset.call_tool("final_result", args, &ctx, tool).await;
        assert!(result.is_ok());

        // Check captured output
        assert!(toolset.has_output());
        let output = toolset.take_output().unwrap();
        assert_eq!(output.message, "Hello, World!");
        assert_eq!(output.count, 42);

        // After take, should be empty
        assert!(!toolset.has_output());
    }

    #[tokio::test]
    async fn test_output_toolset_get_output_keeps_value() {
        let toolset: OutputToolset<TestOutput, ()> = OutputToolset::new(test_schema());
        let ctx = RunContext::minimal("test");

        let tools = toolset.get_tools(&ctx).await.unwrap();
        let tool = tools.get("final_result").unwrap();

        let args = serde_json::json!({
            "message": "Keep me",
            "count": 7
        });

        toolset
            .call_tool("final_result", args, &ctx, tool)
            .await
            .unwrap();

        let expected = TestOutput {
            message: "Keep me".to_string(),
            count: 7,
        };

        // `get_output` clones, so repeated reads see the same value...
        assert_eq!(toolset.get_output(), Some(expected.clone()));
        assert_eq!(toolset.get_output(), Some(expected.clone()));

        // ...and the stored value is still there for `take_output`.
        assert!(toolset.has_output());
        assert_eq!(toolset.take_output(), Some(expected));
        assert_eq!(toolset.get_output(), None);
    }

    #[tokio::test]
    async fn test_output_toolset_clear() {
        let toolset: OutputToolset<TestOutput, ()> = OutputToolset::new(test_schema());
        let ctx = RunContext::minimal("test");

        let tools = toolset.get_tools(&ctx).await.unwrap();
        let tool = tools.get("final_result").unwrap();

        let args = serde_json::json!({
            "message": "Test",
            "count": 1
        });

        toolset
            .call_tool("final_result", args, &ctx, tool)
            .await
            .unwrap();

        assert!(toolset.has_output());
        toolset.clear();
        assert!(!toolset.has_output());
    }

    #[tokio::test]
    async fn test_output_toolset_wrong_tool_name() {
        let toolset: OutputToolset<TestOutput, ()> = OutputToolset::new(test_schema());
        let ctx = RunContext::minimal("test");

        let tools = toolset.get_tools(&ctx).await.unwrap();
        let tool = tools.get("final_result").unwrap();

        let args = serde_json::json!({"message": "Test", "count": 1});

        let result = toolset.call_tool("wrong_name", args, &ctx, tool).await;

        assert!(result.is_err());
        // A rejected call must not leave anything behind in the capture slot.
        assert!(!toolset.has_output());
        assert!(toolset.take_output().is_none());
    }

    #[test]
    fn test_output_captured() {
        let captured = OutputCaptured::new("test".to_string(), "my_tool");
        assert_eq!(captured.tool_name, "my_tool");
        assert_eq!(captured.into_inner(), "test");
    }
}
296}