Skip to main content

serdes_ai_output/
types.rs

1//! Output type wrappers and markers.
2//!
3//! This module provides marker types and wrappers for different output modes.
4
5use serde::{Deserialize, Serialize};
6use serde_json::{Map, Value as JsonValue};
7use std::marker::PhantomData;
8use std::ops::{Deref, DerefMut};
9
10use crate::structured::{DEFAULT_OUTPUT_TOOL_DESCRIPTION, DEFAULT_OUTPUT_TOOL_NAME};
11
/// Marker for native structured output mode.
///
/// Use this to indicate that output should use the model's native
/// structured output feature (like OpenAI's response_format).
///
/// `T` is carried purely at the type level through [`PhantomData`]; no value
/// of `T` is ever stored. `Clone` and `Debug` are implemented by hand rather
/// than derived, because `#[derive]` would add `T: Clone` / `T: Debug` bounds
/// that a zero-sized marker does not need.
pub struct NativeOutput<T> {
    _phantom: PhantomData<T>,
}

impl<T> NativeOutput<T> {
    /// Create a new native output marker.
    #[must_use]
    pub fn new() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}

impl<T> Default for NativeOutput<T> {
    /// Equivalent to [`NativeOutput::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl<T> Clone for NativeOutput<T> {
    /// Produce another marker; available for every `T`, not only `T: Clone`.
    fn clone(&self) -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}

impl<T> std::fmt::Debug for NativeOutput<T> {
    /// Formats the marker with the same field layout the derived impl used,
    /// but without requiring `T: Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("NativeOutput")
            .field("_phantom", &self._phantom)
            .finish()
    }
}
36
/// Marker for prompted output mode (JSON in text).
///
/// Use this to indicate that output should be prompted as JSON
/// but without native structured output enforcement.
///
/// `T` is carried purely at the type level through [`PhantomData`]; no value
/// of `T` is ever stored. `Clone` and `Debug` are implemented by hand rather
/// than derived, because `#[derive]` would add `T: Clone` / `T: Debug` bounds
/// that a zero-sized marker does not need.
pub struct PromptedOutput<T> {
    _phantom: PhantomData<T>,
}

impl<T> PromptedOutput<T> {
    /// Create a new prompted output marker.
    #[must_use]
    pub fn new() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}

impl<T> Default for PromptedOutput<T> {
    /// Equivalent to [`PromptedOutput::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl<T> Clone for PromptedOutput<T> {
    /// Produce another marker; available for every `T`, not only `T: Clone`.
    fn clone(&self) -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}

impl<T> std::fmt::Debug for PromptedOutput<T> {
    /// Formats the marker with the same field layout the derived impl used,
    /// but without requiring `T: Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PromptedOutput")
            .field("_phantom", &self._phantom)
            .finish()
    }
}
61
62/// Marker for tool-based output mode.
63///
64/// Use this to indicate that output should be captured via a tool call.
65/// This is the most reliable method for structured output.
66#[derive(Debug, Clone)]
67pub struct ToolOutput<T> {
68    /// The tool name.
69    pub tool_name: String,
70    /// The tool description.
71    pub tool_description: String,
72    _phantom: PhantomData<T>,
73}
74
75impl<T> ToolOutput<T> {
76    /// Create a new tool output marker with default settings.
77    #[must_use]
78    pub fn new() -> Self {
79        Self {
80            tool_name: DEFAULT_OUTPUT_TOOL_NAME.to_string(),
81            tool_description: DEFAULT_OUTPUT_TOOL_DESCRIPTION.to_string(),
82            _phantom: PhantomData,
83        }
84    }
85
86    /// Set the tool name.
87    #[must_use]
88    pub fn with_tool_name(mut self, name: impl Into<String>) -> Self {
89        self.tool_name = name.into();
90        self
91    }
92
93    /// Set the tool description.
94    #[must_use]
95    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
96        self.tool_description = desc.into();
97        self
98    }
99}
100
101impl<T> Default for ToolOutput<T> {
102    fn default() -> Self {
103        Self::new()
104    }
105}
106
/// Marker selecting plain, unstructured text output.
///
/// Pick this when the model's text response is wanted as-is, with no
/// structured parsing applied to it.
#[derive(Debug, Clone, Default)]
pub struct TextOutput;

impl TextOutput {
    /// Build a text output marker.
    #[must_use]
    pub fn new() -> Self {
        TextOutput
    }
}
121
122/// Structured dictionary output (like TypedDict in Python).
123///
124/// This provides a flexible key-value structure for when you
125/// don't want to define a specific struct type.
126#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
127pub struct StructuredDict(pub Map<String, JsonValue>);
128
129impl StructuredDict {
130    /// Create a new empty structured dict.
131    #[must_use]
132    pub fn new() -> Self {
133        Self(Map::new())
134    }
135
136    /// Create from a JSON map.
137    #[must_use]
138    pub fn from_map(map: Map<String, JsonValue>) -> Self {
139        Self(map)
140    }
141
142    /// Get a value by key.
143    #[must_use]
144    pub fn get(&self, key: &str) -> Option<&JsonValue> {
145        self.0.get(key)
146    }
147
148    /// Get a mutable value by key.
149    pub fn get_mut(&mut self, key: &str) -> Option<&mut JsonValue> {
150        self.0.get_mut(key)
151    }
152
153    /// Insert a value.
154    pub fn insert(&mut self, key: impl Into<String>, value: impl Into<JsonValue>) {
155        self.0.insert(key.into(), value.into());
156    }
157
158    /// Remove a value.
159    pub fn remove(&mut self, key: &str) -> Option<JsonValue> {
160        self.0.remove(key)
161    }
162
163    /// Check if the dict contains a key.
164    #[must_use]
165    pub fn contains_key(&self, key: &str) -> bool {
166        self.0.contains_key(key)
167    }
168
169    /// Get the number of entries.
170    #[must_use]
171    pub fn len(&self) -> usize {
172        self.0.len()
173    }
174
175    /// Check if empty.
176    #[must_use]
177    pub fn is_empty(&self) -> bool {
178        self.0.is_empty()
179    }
180
181    /// Get an iterator over keys.
182    pub fn keys(&self) -> impl Iterator<Item = &String> {
183        self.0.keys()
184    }
185
186    /// Get an iterator over values.
187    pub fn values(&self) -> impl Iterator<Item = &JsonValue> {
188        self.0.values()
189    }
190
191    /// Get an iterator over entries.
192    pub fn iter(&self) -> impl Iterator<Item = (&String, &JsonValue)> {
193        self.0.iter()
194    }
195
196    /// Convert to a JSON value.
197    #[must_use]
198    pub fn to_json(&self) -> JsonValue {
199        JsonValue::Object(self.0.clone())
200    }
201
202    /// Try to get a string value.
203    #[must_use]
204    pub fn get_str(&self, key: &str) -> Option<&str> {
205        self.0.get(key).and_then(|v| v.as_str())
206    }
207
208    /// Try to get an i64 value.
209    #[must_use]
210    pub fn get_i64(&self, key: &str) -> Option<i64> {
211        self.0.get(key).and_then(|v| v.as_i64())
212    }
213
214    /// Try to get a f64 value.
215    #[must_use]
216    pub fn get_f64(&self, key: &str) -> Option<f64> {
217        self.0.get(key).and_then(|v| v.as_f64())
218    }
219
220    /// Try to get a bool value.
221    #[must_use]
222    pub fn get_bool(&self, key: &str) -> Option<bool> {
223        self.0.get(key).and_then(|v| v.as_bool())
224    }
225
226    /// Try to get an array value.
227    #[must_use]
228    pub fn get_array(&self, key: &str) -> Option<&Vec<JsonValue>> {
229        self.0.get(key).and_then(|v| v.as_array())
230    }
231
232    /// Try to get an object value.
233    #[must_use]
234    pub fn get_object(&self, key: &str) -> Option<&Map<String, JsonValue>> {
235        self.0.get(key).and_then(|v| v.as_object())
236    }
237}
238
239impl Deref for StructuredDict {
240    type Target = Map<String, JsonValue>;
241
242    fn deref(&self) -> &Self::Target {
243        &self.0
244    }
245}
246
247impl DerefMut for StructuredDict {
248    fn deref_mut(&mut self) -> &mut Self::Target {
249        &mut self.0
250    }
251}
252
253impl From<Map<String, JsonValue>> for StructuredDict {
254    fn from(map: Map<String, JsonValue>) -> Self {
255        Self(map)
256    }
257}
258
259impl From<StructuredDict> for JsonValue {
260    fn from(dict: StructuredDict) -> Self {
261        JsonValue::Object(dict.0)
262    }
263}
264
265impl TryFrom<JsonValue> for StructuredDict {
266    type Error = &'static str;
267
268    fn try_from(value: JsonValue) -> Result<Self, Self::Error> {
269        match value {
270            JsonValue::Object(map) => Ok(Self(map)),
271            _ => Err("Expected JSON object"),
272        }
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    // The marker constructors only need to be callable for a concrete `T`.
    #[test]
    fn test_native_output() {
        let _marker = NativeOutput::<String>::new();
    }

    #[test]
    fn test_prompted_output() {
        let _marker = PromptedOutput::<String>::new();
    }

    // Builder-style setters must overwrite the defaults.
    #[test]
    fn test_tool_output() {
        let marker = ToolOutput::<String>::new()
            .with_tool_name("result")
            .with_description("The result");

        assert_eq!(marker.tool_name, "result");
        assert_eq!(marker.tool_description, "The result");
    }

    #[test]
    fn test_text_output() {
        let _marker = TextOutput::new();
    }

    #[test]
    fn test_structured_dict_new() {
        let empty = StructuredDict::new();
        assert!(empty.is_empty());
    }

    // Typed getters read back what `insert` stored.
    #[test]
    fn test_structured_dict_insert_get() {
        let mut person = StructuredDict::new();
        person.insert("name", "Alice");
        person.insert("age", 30);

        assert_eq!(person.get_str("name"), Some("Alice"));
        assert_eq!(person.get_i64("age"), Some(30));
    }

    #[test]
    fn test_structured_dict_remove() {
        let mut entries = StructuredDict::new();
        entries.insert("key", "value");
        assert!(entries.contains_key("key"));

        entries.remove("key");
        assert!(!entries.contains_key("key"));
    }

    // Round-trip through a JSON string.
    #[test]
    fn test_structured_dict_serde() {
        let mut original = StructuredDict::new();
        original.insert("name", "Bob");
        original.insert("score", 95.5);

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: StructuredDict = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.get_str("name"), Some("Bob"));
        assert_eq!(decoded.get_f64("score"), Some(95.5));
    }

    #[test]
    fn test_structured_dict_from_json() {
        let source = serde_json::json!({"a": 1, "b": "two"});
        let converted = StructuredDict::try_from(source).unwrap();

        assert_eq!(converted.get_i64("a"), Some(1));
        assert_eq!(converted.get_str("b"), Some("two"));
    }

    #[test]
    fn test_structured_dict_to_json() {
        let mut entries = StructuredDict::new();
        entries.insert("key", "value");

        let as_value = entries.to_json();
        assert_eq!(as_value["key"], "value");
    }

    #[test]
    fn test_structured_dict_iter() {
        let mut entries = StructuredDict::new();
        entries.insert("a", 1);
        entries.insert("b", 2);

        assert_eq!(entries.keys().count(), 2);
    }
}