Skip to main content

serdes_ai_output/
text.rs

1//! Text output schema implementation.
2//!
3//! This module provides `TextOutputSchema` for handling plain text output
4//! with optional validation constraints like patterns, length limits, etc.
5
6use crate::error::OutputParseError;
7use crate::mode::OutputMode;
8use crate::schema::OutputSchema;
9use async_trait::async_trait;
10use regex::Regex;
11use serde_json::Value as JsonValue;
12
13/// Schema for plain text output.
14///
15/// This schema validates text output with optional constraints:
16/// - Pattern matching via regex
17/// - Minimum/maximum length
18/// - Trim whitespace
19///
20/// # Example
21///
22/// ```rust
23/// use serdes_ai_output::TextOutputSchema;
24///
25/// let schema = TextOutputSchema::new()
26///     .with_min_length(10)
27///     .with_max_length(1000)
28///     .trim();
29/// ```
30#[derive(Debug, Clone, Default)]
31pub struct TextOutputSchema {
32    /// Optional regex pattern to match.
33    pattern: Option<Regex>,
34    /// Pattern string (for error messages).
35    pattern_str: Option<String>,
36    /// Minimum length.
37    min_length: Option<usize>,
38    /// Maximum length.
39    max_length: Option<usize>,
40    /// Whether to trim whitespace.
41    trim_whitespace: bool,
42}
43
44impl TextOutputSchema {
45    /// Create a new text output schema with no constraints.
46    #[must_use]
47    pub fn new() -> Self {
48        Self::default()
49    }
50
51    /// Set a regex pattern the output must match.
52    ///
53    /// # Errors
54    ///
55    /// Returns an error if the pattern is invalid.
56    pub fn with_pattern(mut self, pattern: &str) -> Result<Self, regex::Error> {
57        self.pattern = Some(Regex::new(pattern)?);
58        self.pattern_str = Some(pattern.to_string());
59        Ok(self)
60    }
61
62    /// Set the minimum length constraint.
63    #[must_use]
64    pub fn with_min_length(mut self, len: usize) -> Self {
65        self.min_length = Some(len);
66        self
67    }
68
69    /// Set the maximum length constraint.
70    #[must_use]
71    pub fn with_max_length(mut self, len: usize) -> Self {
72        self.max_length = Some(len);
73        self
74    }
75
76    /// Enable whitespace trimming.
77    #[must_use]
78    pub fn trim(mut self) -> Self {
79        self.trim_whitespace = true;
80        self
81    }
82
83    /// Validate the text against configured constraints.
84    fn validate_text(&self, text: &str) -> Result<String, OutputParseError> {
85        let text = if self.trim_whitespace {
86            text.trim().to_string()
87        } else {
88            text.to_string()
89        };
90
91        // Check minimum length
92        if let Some(min) = self.min_length {
93            if text.len() < min {
94                return Err(OutputParseError::too_short(text.len(), min));
95            }
96        }
97
98        // Check maximum length
99        if let Some(max) = self.max_length {
100            if text.len() > max {
101                return Err(OutputParseError::too_long(text.len(), max));
102            }
103        }
104
105        // Check pattern
106        if let Some(ref pattern) = self.pattern {
107            if !pattern.is_match(&text) {
108                return Err(OutputParseError::PatternMismatch {
109                    pattern: self.pattern_str.clone().unwrap_or_default(),
110                });
111            }
112        }
113
114        Ok(text)
115    }
116}
117
118#[async_trait]
119impl OutputSchema<String> for TextOutputSchema {
120    fn mode(&self) -> OutputMode {
121        OutputMode::Text
122    }
123
124    fn parse_text(&self, text: &str) -> Result<String, OutputParseError> {
125        self.validate_text(text)
126    }
127
128    fn parse_tool_call(&self, _name: &str, args: &JsonValue) -> Result<String, OutputParseError> {
129        // Try to extract text from tool arguments
130        if let Some(text) = args.as_str() {
131            return self.validate_text(text);
132        }
133
134        // Try common field names
135        for field in ["text", "content", "message", "result", "output"] {
136            if let Some(text) = args.get(field).and_then(|v| v.as_str()) {
137                return self.validate_text(text);
138            }
139        }
140
141        // Fall back to JSON string representation
142        let text = serde_json::to_string(args).map_err(OutputParseError::JsonParse)?;
143        self.validate_text(&text)
144    }
145
146    fn parse_native(&self, value: &JsonValue) -> Result<String, OutputParseError> {
147        if let Some(text) = value.as_str() {
148            return self.validate_text(text);
149        }
150
151        // Try common field names for objects
152        if let Some(obj) = value.as_object() {
153            for field in ["text", "content", "message", "result", "output"] {
154                if let Some(text) = obj.get(field).and_then(|v| v.as_str()) {
155                    return self.validate_text(text);
156                }
157            }
158        }
159
160        Err(OutputParseError::NotJson)
161    }
162}
163
164/// Builder for text output schema with more options.
165#[derive(Debug, Default)]
166pub struct TextOutputSchemaBuilder {
167    schema: TextOutputSchema,
168}
169
170impl TextOutputSchemaBuilder {
171    /// Create a new builder.
172    #[must_use]
173    pub fn new() -> Self {
174        Self::default()
175    }
176
177    /// Set a regex pattern.
178    pub fn pattern(mut self, pattern: &str) -> Result<Self, regex::Error> {
179        self.schema = self.schema.with_pattern(pattern)?;
180        Ok(self)
181    }
182
183    /// Set minimum length.
184    #[must_use]
185    pub fn min_length(mut self, len: usize) -> Self {
186        self.schema = self.schema.with_min_length(len);
187        self
188    }
189
190    /// Set maximum length.
191    #[must_use]
192    pub fn max_length(mut self, len: usize) -> Self {
193        self.schema = self.schema.with_max_length(len);
194        self
195    }
196
197    /// Enable trimming.
198    #[must_use]
199    pub fn trim(mut self) -> Self {
200        self.schema = self.schema.trim();
201        self
202    }
203
204    /// Build the schema.
205    #[must_use]
206    pub fn build(self) -> TextOutputSchema {
207        self.schema
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_text_schema_default() {
        let schema = TextOutputSchema::new();
        assert_eq!(schema.mode(), OutputMode::Text);
    }

    #[test]
    fn test_text_schema_parse_text() {
        let schema = TextOutputSchema::new();
        let result = schema.parse_text("hello world").unwrap();
        assert_eq!(result, "hello world");
    }

    #[test]
    fn test_text_schema_trim() {
        let schema = TextOutputSchema::new().trim();
        let result = schema.parse_text("  hello world  ").unwrap();
        assert_eq!(result, "hello world");
    }

    #[test]
    fn test_text_schema_no_trim_preserves_whitespace() {
        let schema = TextOutputSchema::new();
        let result = schema.parse_text("  padded  ").unwrap();
        assert_eq!(result, "  padded  ");
    }

    #[test]
    fn test_text_schema_min_length() {
        let schema = TextOutputSchema::new().with_min_length(10);

        // Too short
        let result = schema.parse_text("short");
        assert!(result.is_err());

        // Long enough
        let result = schema.parse_text("this is long enough");
        assert!(result.is_ok());
    }

    #[test]
    fn test_text_schema_max_length() {
        let schema = TextOutputSchema::new().with_max_length(10);

        // Too long
        let result = schema.parse_text("this is too long");
        assert!(result.is_err());

        // Short enough
        let result = schema.parse_text("short");
        assert!(result.is_ok());
    }

    #[test]
    fn test_text_schema_length_bounds_are_inclusive() {
        let schema = TextOutputSchema::new()
            .with_min_length(5)
            .with_max_length(5);

        assert_eq!(schema.parse_text("hello").unwrap(), "hello");
        assert!(schema.parse_text("hell").is_err());
        assert!(schema.parse_text("hello!").is_err());
    }

    #[test]
    fn test_text_schema_pattern() {
        let schema = TextOutputSchema::new()
            .with_pattern(r"^\d{3}-\d{4}$")
            .unwrap();

        // Matches
        let result = schema.parse_text("123-4567");
        assert!(result.is_ok());

        // Doesn't match, and the error quotes the configured pattern
        match schema.parse_text("abc-defg") {
            Err(OutputParseError::PatternMismatch { pattern }) => {
                assert_eq!(pattern, r"^\d{3}-\d{4}$");
            }
            other => panic!("expected PatternMismatch, got {:?}", other),
        }
    }

    #[test]
    fn test_text_schema_invalid_pattern() {
        assert!(TextOutputSchema::new().with_pattern("(").is_err());
        assert!(TextOutputSchemaBuilder::new().pattern("[").is_err());
    }

    #[test]
    fn test_text_schema_combined_constraints() {
        let schema = TextOutputSchema::new()
            .with_min_length(5)
            .with_max_length(20)
            .trim();

        // Valid
        let result = schema.parse_text("  hello world  ");
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello world");

        // Too short after trim
        let result = schema.parse_text("  hi  ");
        assert!(result.is_err());
    }

    #[test]
    fn test_text_schema_parse_tool_call_string() {
        let schema = TextOutputSchema::new();
        let args = serde_json::json!("hello");
        let result = schema.parse_tool_call("result", &args).unwrap();
        assert_eq!(result, "hello");
    }

    #[test]
    fn test_text_schema_parse_tool_call_object() {
        let schema = TextOutputSchema::new();
        let args = serde_json::json!({"text": "hello from tool"});
        let result = schema.parse_tool_call("result", &args).unwrap();
        assert_eq!(result, "hello from tool");
    }

    #[test]
    fn test_text_schema_parse_tool_call_field_priority() {
        let schema = TextOutputSchema::new();

        // "text" wins over "content" regardless of insertion order.
        let args = serde_json::json!({"content": "second", "text": "first"});
        assert_eq!(schema.parse_tool_call("result", &args).unwrap(), "first");

        // Non-string values are skipped in favour of later fields.
        let args = serde_json::json!({"text": 1, "message": "from message"});
        assert_eq!(
            schema.parse_tool_call("result", &args).unwrap(),
            "from message"
        );
    }

    #[test]
    fn test_text_schema_parse_tool_call_falls_back_to_json() {
        let schema = TextOutputSchema::new();
        let args = serde_json::json!({"answer": 42});
        let result = schema.parse_tool_call("result", &args).unwrap();
        assert_eq!(result, r#"{"answer":42}"#);
    }

    #[test]
    fn test_text_schema_parse_tool_call_applies_constraints() {
        let schema = TextOutputSchema::new().trim().with_max_length(5);
        let args = serde_json::json!({"text": "  hi  "});
        assert_eq!(schema.parse_tool_call("result", &args).unwrap(), "hi");

        let args = serde_json::json!({"text": "far too long"});
        assert!(schema.parse_tool_call("result", &args).is_err());
    }

    #[test]
    fn test_text_schema_parse_native() {
        let schema = TextOutputSchema::new();

        // Direct string
        let value = serde_json::json!("native output");
        let result = schema.parse_native(&value).unwrap();
        assert_eq!(result, "native output");

        // Object with content field
        let value = serde_json::json!({"content": "from content"});
        let result = schema.parse_native(&value).unwrap();
        assert_eq!(result, "from content");
    }

    #[test]
    fn test_text_schema_parse_native_rejects_unsupported() {
        let schema = TextOutputSchema::new();

        let value = serde_json::json!(42);
        assert!(matches!(
            schema.parse_native(&value),
            Err(OutputParseError::NotJson)
        ));

        let value = serde_json::json!({"unrelated": "x"});
        assert!(matches!(
            schema.parse_native(&value),
            Err(OutputParseError::NotJson)
        ));
    }

    #[test]
    fn test_builder() {
        let schema = TextOutputSchemaBuilder::new()
            .min_length(5)
            .max_length(100)
            .trim()
            .build();

        let result = schema.parse_text("  hello  ").unwrap();
        assert_eq!(result, "hello");
    }

    #[test]
    fn test_builder_pattern() {
        let schema = TextOutputSchemaBuilder::new()
            .pattern(r"^[a-z]+$")
            .unwrap()
            .build();

        assert_eq!(schema.parse_text("hello").unwrap(), "hello");
        assert!(schema.parse_text("Hello").is_err());
    }
}