1use crate::error::OutputParseError;
7use crate::mode::OutputMode;
8use crate::schema::OutputSchema;
9use async_trait::async_trait;
10use regex::Regex;
11use serde_json::Value as JsonValue;
12
13#[derive(Debug, Clone, Default)]
31pub struct TextOutputSchema {
32 pattern: Option<Regex>,
34 pattern_str: Option<String>,
36 min_length: Option<usize>,
38 max_length: Option<usize>,
40 trim_whitespace: bool,
42}
43
44impl TextOutputSchema {
45 #[must_use]
47 pub fn new() -> Self {
48 Self::default()
49 }
50
51 pub fn with_pattern(mut self, pattern: &str) -> Result<Self, regex::Error> {
57 self.pattern = Some(Regex::new(pattern)?);
58 self.pattern_str = Some(pattern.to_string());
59 Ok(self)
60 }
61
62 #[must_use]
64 pub fn with_min_length(mut self, len: usize) -> Self {
65 self.min_length = Some(len);
66 self
67 }
68
69 #[must_use]
71 pub fn with_max_length(mut self, len: usize) -> Self {
72 self.max_length = Some(len);
73 self
74 }
75
76 #[must_use]
78 pub fn trim(mut self) -> Self {
79 self.trim_whitespace = true;
80 self
81 }
82
83 fn validate_text(&self, text: &str) -> Result<String, OutputParseError> {
85 let text = if self.trim_whitespace {
86 text.trim().to_string()
87 } else {
88 text.to_string()
89 };
90
91 if let Some(min) = self.min_length {
93 if text.len() < min {
94 return Err(OutputParseError::too_short(text.len(), min));
95 }
96 }
97
98 if let Some(max) = self.max_length {
100 if text.len() > max {
101 return Err(OutputParseError::too_long(text.len(), max));
102 }
103 }
104
105 if let Some(ref pattern) = self.pattern {
107 if !pattern.is_match(&text) {
108 return Err(OutputParseError::PatternMismatch {
109 pattern: self.pattern_str.clone().unwrap_or_default(),
110 });
111 }
112 }
113
114 Ok(text)
115 }
116}
117
118#[async_trait]
119impl OutputSchema<String> for TextOutputSchema {
120 fn mode(&self) -> OutputMode {
121 OutputMode::Text
122 }
123
124 fn parse_text(&self, text: &str) -> Result<String, OutputParseError> {
125 self.validate_text(text)
126 }
127
128 fn parse_tool_call(&self, _name: &str, args: &JsonValue) -> Result<String, OutputParseError> {
129 if let Some(text) = args.as_str() {
131 return self.validate_text(text);
132 }
133
134 for field in ["text", "content", "message", "result", "output"] {
136 if let Some(text) = args.get(field).and_then(|v| v.as_str()) {
137 return self.validate_text(text);
138 }
139 }
140
141 let text = serde_json::to_string(args).map_err(OutputParseError::JsonParse)?;
143 self.validate_text(&text)
144 }
145
146 fn parse_native(&self, value: &JsonValue) -> Result<String, OutputParseError> {
147 if let Some(text) = value.as_str() {
148 return self.validate_text(text);
149 }
150
151 if let Some(obj) = value.as_object() {
153 for field in ["text", "content", "message", "result", "output"] {
154 if let Some(text) = obj.get(field).and_then(|v| v.as_str()) {
155 return self.validate_text(text);
156 }
157 }
158 }
159
160 Err(OutputParseError::NotJson)
161 }
162}
163
164#[derive(Debug, Default)]
166pub struct TextOutputSchemaBuilder {
167 schema: TextOutputSchema,
168}
169
170impl TextOutputSchemaBuilder {
171 #[must_use]
173 pub fn new() -> Self {
174 Self::default()
175 }
176
177 pub fn pattern(mut self, pattern: &str) -> Result<Self, regex::Error> {
179 self.schema = self.schema.with_pattern(pattern)?;
180 Ok(self)
181 }
182
183 #[must_use]
185 pub fn min_length(mut self, len: usize) -> Self {
186 self.schema = self.schema.with_min_length(len);
187 self
188 }
189
190 #[must_use]
192 pub fn max_length(mut self, len: usize) -> Self {
193 self.schema = self.schema.with_max_length(len);
194 self
195 }
196
197 #[must_use]
199 pub fn trim(mut self) -> Self {
200 self.schema = self.schema.trim();
201 self
202 }
203
204 #[must_use]
206 pub fn build(self) -> TextOutputSchema {
207 self.schema
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_text_schema_default() {
        assert_eq!(TextOutputSchema::new().mode(), OutputMode::Text);
    }

    #[test]
    fn test_text_schema_parse_text() {
        let parsed = TextOutputSchema::new()
            .parse_text("hello world")
            .expect("an unconstrained schema accepts any text");
        assert_eq!(parsed, "hello world");
    }

    #[test]
    fn test_text_schema_trim() {
        let parsed = TextOutputSchema::new()
            .trim()
            .parse_text(" hello world ")
            .expect("trimmed text has no constraints to violate");
        assert_eq!(parsed, "hello world");
    }

    #[test]
    fn test_text_schema_min_length() {
        let schema = TextOutputSchema::new().with_min_length(10);

        // Below the minimum is rejected; at or above it is accepted.
        assert!(schema.parse_text("short").is_err());
        assert!(schema.parse_text("this is long enough").is_ok());
    }

    #[test]
    fn test_text_schema_max_length() {
        let schema = TextOutputSchema::new().with_max_length(10);

        // Above the maximum is rejected; at or below it is accepted.
        assert!(schema.parse_text("this is too long").is_err());
        assert!(schema.parse_text("short").is_ok());
    }

    #[test]
    fn test_text_schema_pattern() {
        let schema = TextOutputSchema::new()
            .with_pattern(r"^\d{3}-\d{4}$")
            .expect("pattern is a valid regex");

        assert!(schema.parse_text("123-4567").is_ok());
        assert!(schema.parse_text("abc-defg").is_err());
    }

    #[test]
    fn test_text_schema_combined_constraints() {
        let schema = TextOutputSchema::new()
            .with_min_length(5)
            .with_max_length(20)
            .trim();

        let trimmed = schema
            .parse_text(" hello world ")
            .expect("trimmed text fits within 5..=20");
        assert_eq!(trimmed, "hello world");

        // Trimming happens before the length check, so " hi " is too short.
        assert!(schema.parse_text(" hi ").is_err());
    }

    #[test]
    fn test_text_schema_parse_tool_call_string() {
        let args = serde_json::json!("hello");
        let parsed = TextOutputSchema::new()
            .parse_tool_call("result", &args)
            .expect("bare string arguments are accepted");
        assert_eq!(parsed, "hello");
    }

    #[test]
    fn test_text_schema_parse_tool_call_object() {
        let args = serde_json::json!({"text": "hello from tool"});
        let parsed = TextOutputSchema::new()
            .parse_tool_call("result", &args)
            .expect("the `text` field is extracted");
        assert_eq!(parsed, "hello from tool");
    }

    #[test]
    fn test_text_schema_parse_native() {
        let schema = TextOutputSchema::new();

        let bare = serde_json::json!("native output");
        let parsed = schema
            .parse_native(&bare)
            .expect("bare strings are accepted");
        assert_eq!(parsed, "native output");

        let wrapped = serde_json::json!({"content": "from content"});
        let parsed = schema
            .parse_native(&wrapped)
            .expect("the `content` field is extracted");
        assert_eq!(parsed, "from content");
    }

    #[test]
    fn test_builder() {
        let schema = TextOutputSchemaBuilder::new()
            .min_length(5)
            .max_length(100)
            .trim()
            .build();

        let parsed = schema
            .parse_text(" hello ")
            .expect("trimmed text fits within 5..=100");
        assert_eq!(parsed, "hello");
    }
}