1use crate::context::RunContext;
7use crate::errors::{OutputParseError, OutputValidationError};
8use async_trait::async_trait;
9use serde::de::DeserializeOwned;
10use serde_json::Value as JsonValue;
11use std::any::TypeId;
12use std::marker::PhantomData;
13
14#[async_trait]
16pub trait OutputValidator<Output, Deps>: Send + Sync {
17 async fn validate(
21 &self,
22 output: Output,
23 ctx: &RunContext<Deps>,
24 ) -> Result<Output, OutputValidationError>;
25}
26
27pub struct AsyncValidator<F, Deps, Output, Fut>
33where
34 F: Fn(Output, &RunContext<Deps>) -> Fut + Send + Sync,
35 Fut: std::future::Future<Output = Result<Output, OutputValidationError>> + Send,
36{
37 func: F,
38 _phantom: PhantomData<(Deps, Output, Fut)>,
39}
40
41impl<F, Deps, Output, Fut> AsyncValidator<F, Deps, Output, Fut>
42where
43 F: Fn(Output, &RunContext<Deps>) -> Fut + Send + Sync,
44 Fut: std::future::Future<Output = Result<Output, OutputValidationError>> + Send,
45{
46 pub fn new(func: F) -> Self {
48 Self {
49 func,
50 _phantom: PhantomData,
51 }
52 }
53}
54
55#[async_trait]
56impl<F, Deps, Output, Fut> OutputValidator<Output, Deps> for AsyncValidator<F, Deps, Output, Fut>
57where
58 F: Fn(Output, &RunContext<Deps>) -> Fut + Send + Sync,
59 Fut: std::future::Future<Output = Result<Output, OutputValidationError>> + Send + Sync,
60 Deps: Send + Sync,
61 Output: Send + Sync,
62{
63 async fn validate(
64 &self,
65 output: Output,
66 ctx: &RunContext<Deps>,
67 ) -> Result<Output, OutputValidationError> {
68 (self.func)(output, ctx).await
69 }
70}
71
72pub struct SyncValidator<F, Deps, Output>
74where
75 F: Fn(Output, &RunContext<Deps>) -> Result<Output, OutputValidationError> + Send + Sync,
76{
77 func: F,
78 _phantom: PhantomData<(Deps, Output)>,
79}
80
81impl<F, Deps, Output> SyncValidator<F, Deps, Output>
82where
83 F: Fn(Output, &RunContext<Deps>) -> Result<Output, OutputValidationError> + Send + Sync,
84{
85 pub fn new(func: F) -> Self {
87 Self {
88 func,
89 _phantom: PhantomData,
90 }
91 }
92}
93
94#[async_trait]
95impl<F, Deps, Output> OutputValidator<Output, Deps> for SyncValidator<F, Deps, Output>
96where
97 F: Fn(Output, &RunContext<Deps>) -> Result<Output, OutputValidationError> + Send + Sync,
98 Deps: Send + Sync,
99 Output: Send + Sync,
100{
101 async fn validate(
102 &self,
103 output: Output,
104 ctx: &RunContext<Deps>,
105 ) -> Result<Output, OutputValidationError> {
106 (self.func)(output, ctx)
107 }
108}
109
110pub struct NonEmptyValidator;
116
117#[async_trait]
118impl<Deps: Send + Sync> OutputValidator<String, Deps> for NonEmptyValidator {
119 async fn validate(
120 &self,
121 output: String,
122 _ctx: &RunContext<Deps>,
123 ) -> Result<String, OutputValidationError> {
124 if output.trim().is_empty() {
125 Err(OutputValidationError::failed("Output cannot be empty"))
126 } else {
127 Ok(output)
128 }
129 }
130}
131
/// Checks that string output lies within optional, inclusive length bounds.
///
/// A bound that is not set is not checked. Configure it builder-style, e.g.
/// `LengthValidator::new().min(5).max(10)`.
pub struct LengthValidator {
    min: Option<usize>,
    max: Option<usize>,
}

impl LengthValidator {
    /// Creates a validator with no bounds; it accepts everything until a bound is set.
    pub fn new() -> Self {
        LengthValidator {
            min: None,
            max: None,
        }
    }

    /// Sets the inclusive lower bound, replacing any earlier one.
    pub fn min(self, min: usize) -> Self {
        LengthValidator {
            min: Some(min),
            ..self
        }
    }

    /// Sets the inclusive upper bound, replacing any earlier one.
    pub fn max(self, max: usize) -> Self {
        LengthValidator {
            max: Some(max),
            ..self
        }
    }
}

impl Default for LengthValidator {
    /// Same as [`LengthValidator::new`]: no bounds.
    fn default() -> Self {
        LengthValidator::new()
    }
}
165
166#[async_trait]
167impl<Deps: Send + Sync> OutputValidator<String, Deps> for LengthValidator {
168 async fn validate(
169 &self,
170 output: String,
171 _ctx: &RunContext<Deps>,
172 ) -> Result<String, OutputValidationError> {
173 let len = output.len();
174
175 if let Some(min) = self.min {
176 if len < min {
177 return Err(OutputValidationError::failed(format!(
178 "Output too short: {} < {}",
179 len, min
180 )));
181 }
182 }
183
184 if let Some(max) = self.max {
185 if len > max {
186 return Err(OutputValidationError::failed(format!(
187 "Output too long: {} > {}",
188 len, max
189 )));
190 }
191 }
192
193 Ok(output)
194 }
195}
196
/// Validates string output against a regular expression (requires the `regex` feature).
///
/// Output for which the pattern matches is returned unchanged. Otherwise validation fails
/// with the configured message.
#[cfg(feature = "regex")]
pub struct RegexValidator {
    pattern: regex::Regex,
    message: String,
}

#[cfg(feature = "regex")]
impl RegexValidator {
    /// Compiles `pattern` and keeps `message` as the failure text.
    ///
    /// # Errors
    /// Returns the `regex::Error` when `pattern` does not compile.
    pub fn new(pattern: &str, message: impl Into<String>) -> Result<Self, regex::Error> {
        let compiled = regex::Regex::new(pattern)?;
        let message = message.into();
        Ok(RegexValidator {
            pattern: compiled,
            message,
        })
    }
}

#[cfg(feature = "regex")]
#[async_trait]
impl<Deps: Send + Sync> OutputValidator<String, Deps> for RegexValidator {
    async fn validate(
        &self,
        output: String,
        _ctx: &RunContext<Deps>,
    ) -> Result<String, OutputValidationError> {
        if !self.pattern.is_match(&output) {
            return Err(OutputValidationError::failed(&self.message));
        }
        Ok(output)
    }
}
229}
230
231pub struct ChainedValidator<Output, Deps> {
237 validators: Vec<Box<dyn OutputValidator<Output, Deps>>>,
238}
239
240impl<Output: Send + Sync + 'static, Deps: Send + Sync + 'static> ChainedValidator<Output, Deps> {
241 pub fn new() -> Self {
243 Self {
244 validators: Vec::new(),
245 }
246 }
247
248 #[allow(clippy::should_implement_trait)]
250 pub fn add<V: OutputValidator<Output, Deps> + 'static>(mut self, validator: V) -> Self {
251 self.validators.push(Box::new(validator));
252 self
253 }
254}
255
256impl<Output: Send + Sync + 'static, Deps: Send + Sync + 'static> Default
257 for ChainedValidator<Output, Deps>
258{
259 fn default() -> Self {
260 Self::new()
261 }
262}
263
264#[async_trait]
265impl<Output: Send + Sync, Deps: Send + Sync> OutputValidator<Output, Deps>
266 for ChainedValidator<Output, Deps>
267{
268 async fn validate(
269 &self,
270 mut output: Output,
271 ctx: &RunContext<Deps>,
272 ) -> Result<Output, OutputValidationError> {
273 for validator in &self.validators {
274 output = validator.validate(output, ctx).await?;
275 }
276 Ok(output)
277 }
278}
279
/// How a schema expects the model's final output to be delivered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum OutputMode {
    /// Plain text response (the default).
    #[default]
    Text,
    /// JSON carried in the text response.
    Json,
    /// Arguments of a call to a designated output tool.
    ToolCall,
}
295
296pub trait OutputSchema<Output>: Send + Sync {
298 fn json_schema(&self) -> Option<JsonValue> {
300 None
301 }
302
303 fn mode(&self) -> OutputMode {
305 OutputMode::Text
306 }
307
308 fn tool_name(&self) -> Option<&str> {
310 None
311 }
312
313 fn parse_text(&self, text: &str) -> Result<Output, OutputParseError>;
315
316 fn parse_tool_call(&self, _name: &str, _args: &JsonValue) -> Result<Output, OutputParseError> {
318 Err(OutputParseError::ToolNotCalled)
319 }
320}
321
322#[derive(Debug, Clone, Default)]
324pub struct TextOutputSchema;
325
326impl OutputSchema<String> for TextOutputSchema {
327 fn parse_text(&self, text: &str) -> Result<String, OutputParseError> {
328 Ok(text.to_string())
329 }
330}
331
/// Schema whose behaviour is chosen by the `Output` type: a `String` output is taken from
/// the raw text, and any other type is deserialized from JSON.
///
/// `Debug`, `Clone` and `Default` are implemented by hand rather than derived. Derives
/// would require `Output` itself to implement those traits, even though the struct only
/// stores a `PhantomData`.
pub struct DefaultOutputSchema<Output> {
    _phantom: PhantomData<Output>,
}

impl<Output> DefaultOutputSchema<Output> {
    /// Creates the schema; it holds no state.
    pub fn new() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}

impl<Output> Default for DefaultOutputSchema<Output> {
    fn default() -> Self {
        Self::new()
    }
}

impl<Output> Clone for DefaultOutputSchema<Output> {
    fn clone(&self) -> Self {
        Self::new()
    }
}

impl<Output> std::fmt::Debug for DefaultOutputSchema<Output> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same shape as the derived output, but with no `Output: Debug` requirement
        // (`PhantomData<T>` is `Debug` for every `T`).
        f.debug_struct("DefaultOutputSchema")
            .field("_phantom", &self._phantom)
            .finish()
    }
}
346
347impl<Output: DeserializeOwned + Send + Sync + 'static> OutputSchema<Output>
348 for DefaultOutputSchema<Output>
349{
350 fn mode(&self) -> OutputMode {
351 if TypeId::of::<Output>() == TypeId::of::<String>() {
352 OutputMode::Text
353 } else {
354 OutputMode::Json
355 }
356 }
357
358 fn parse_text(&self, text: &str) -> Result<Output, OutputParseError> {
359 if TypeId::of::<Output>() == TypeId::of::<String>() {
360 serde_json::from_value(serde_json::Value::String(text.to_string()))
362 .map_err(OutputParseError::Json)
363 } else {
364 let json_str = extract_json(text).unwrap_or(text);
365 serde_json::from_str(json_str).map_err(OutputParseError::Json)
366 }
367 }
368}
369
370pub struct JsonOutputSchema<T> {
372 schema: Option<JsonValue>,
373 _phantom: PhantomData<T>,
374}
375
376impl<T: DeserializeOwned> JsonOutputSchema<T> {
377 pub fn new() -> Self {
379 Self {
380 schema: None,
381 _phantom: PhantomData,
382 }
383 }
384
385 pub fn with_schema(mut self, schema: JsonValue) -> Self {
387 self.schema = Some(schema);
388 self
389 }
390}
391
392impl<T: DeserializeOwned> Default for JsonOutputSchema<T> {
393 fn default() -> Self {
394 Self::new()
395 }
396}
397
398impl<T: DeserializeOwned + Send + Sync> OutputSchema<T> for JsonOutputSchema<T> {
399 fn json_schema(&self) -> Option<JsonValue> {
400 self.schema.clone()
401 }
402
403 fn mode(&self) -> OutputMode {
404 OutputMode::Json
405 }
406
407 fn parse_text(&self, text: &str) -> Result<T, OutputParseError> {
408 let json_str = extract_json(text).unwrap_or(text);
410 serde_json::from_str(json_str).map_err(OutputParseError::Json)
411 }
412}
413
414pub struct ToolOutputSchema<T> {
416 tool_name: String,
417 schema: Option<JsonValue>,
418 _phantom: PhantomData<T>,
419}
420
421impl<T: DeserializeOwned> ToolOutputSchema<T> {
422 pub fn new(tool_name: impl Into<String>) -> Self {
424 Self {
425 tool_name: tool_name.into(),
426 schema: None,
427 _phantom: PhantomData,
428 }
429 }
430
431 pub fn with_schema(mut self, schema: JsonValue) -> Self {
433 self.schema = Some(schema);
434 self
435 }
436}
437
438impl<T: DeserializeOwned + Send + Sync> OutputSchema<T> for ToolOutputSchema<T> {
439 fn json_schema(&self) -> Option<JsonValue> {
440 self.schema.clone()
441 }
442
443 fn mode(&self) -> OutputMode {
444 OutputMode::ToolCall
445 }
446
447 fn tool_name(&self) -> Option<&str> {
448 Some(&self.tool_name)
449 }
450
451 fn parse_text(&self, _text: &str) -> Result<T, OutputParseError> {
452 Err(OutputParseError::ToolNotCalled)
453 }
454
455 fn parse_tool_call(&self, name: &str, args: &JsonValue) -> Result<T, OutputParseError> {
456 if name != self.tool_name {
457 return Err(OutputParseError::ToolNotCalled);
458 }
459 serde_json::from_value(args.clone()).map_err(OutputParseError::Json)
460 }
461}
462
/// Finds a candidate JSON payload inside free-form model output.
///
/// Strategies, tried in order:
/// 1. The body of the first fenced block whose opening fence is tagged `json`.
/// 2. The body of the first fenced block of any kind (the rest of the opening fence line,
///    e.g. a language tag, is skipped), accepted only if it starts with `{` or `[`.
/// 3. The span from the first `{` to the last `}`. If a `[`..`]` span strictly encloses
///    that object (e.g. `[{"a":1},{"b":2}]`), or no object exists, the span from the first
///    `[` to the last `]` is used instead.
///
/// Returns `None` when no strategy applies. The result is only a candidate; this function
/// does not check that it is valid JSON.
fn extract_json(text: &str) -> Option<&str> {
    const FENCE: &str = "```";
    const JSON_FENCE: &str = "```json";

    if let Some(start) = text.find(JSON_FENCE) {
        let body_start = start + JSON_FENCE.len();
        if let Some(len) = text[body_start..].find(FENCE) {
            return Some(text[body_start..body_start + len].trim());
        }
    }

    if let Some(start) = text.find(FENCE) {
        let after_fence = start + FENCE.len();
        // The body starts on the line after the opening fence. Without a newline there is
        // no body line to read. Guessing an offset here could run past the end of `text`
        // or split a multi-byte character, and slicing would panic.
        if let Some(newline) = text[after_fence..].find('\n') {
            let body_start = after_fence + newline + 1;
            if let Some(len) = text[body_start..].find(FENCE) {
                let body = text[body_start..body_start + len].trim();
                if body.starts_with('{') || body.starts_with('[') {
                    return Some(body);
                }
            }
        }
    }

    let object = outermost_span(text, '{', '}');
    let array = outermost_span(text, '[', ']');
    let (start, end) = match (object, array) {
        // An array wrapping the object is the real top-level value (a list of objects).
        (Some((obj_start, obj_end)), Some((arr_start, arr_end)))
            if arr_start < obj_start && arr_end > obj_end =>
        {
            (arr_start, arr_end)
        }
        (Some(obj), _) => obj,
        (None, Some(arr)) => arr,
        (None, None) => return None,
    };
    // Both delimiters are ASCII, so `start` and `end` are char boundaries.
    Some(&text[start..=end])
}

/// Inclusive byte range from the first `open` to the last `close` in `text`, provided the
/// closing delimiter comes after the opening one.
fn outermost_span(text: &str, open: char, close: char) -> Option<(usize, usize)> {
    let start = text.find(open)?;
    let end = text.rfind(close)?;
    if end > start {
        Some((start, end))
    } else {
        None
    }
}
498
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use std::sync::Arc;

    // Minimal context: unit deps and placeholder identifiers.
    fn make_context() -> RunContext<()> {
        RunContext {
            deps: Arc::new(()),
            run_id: String::from("test"),
            start_time: Utc::now(),
            model_name: String::from("test"),
            model_settings: Default::default(),
            tool_name: None,
            tool_call_id: None,
            retry_count: 0,
            metadata: None,
        }
    }

    #[tokio::test]
    async fn test_non_empty_validator() {
        let ctx = make_context();
        let check = NonEmptyValidator;

        assert!(check.validate(String::from("hello"), &ctx).await.is_ok());

        for blank in ["", " "].iter() {
            let outcome = check.validate(blank.to_string(), &ctx).await;
            assert!(outcome.is_err(), "{:?} should be rejected", blank);
        }
    }

    #[tokio::test]
    async fn test_length_validator() {
        let ctx = make_context();
        let bounded = LengthValidator::new().min(5).max(10);

        let cases = [("hello", true), ("hi", false), ("hello world!", false)];
        for (input, should_pass) in cases.iter() {
            let outcome = bounded.validate(input.to_string(), &ctx).await;
            assert_eq!(outcome.is_ok(), *should_pass, "input {:?}", input);
        }
    }

    #[tokio::test]
    async fn test_chained_validator() {
        let ctx = make_context();
        let chain = ChainedValidator::<String, ()>::new()
            .add(NonEmptyValidator)
            .add(LengthValidator::new().min(3));

        assert!(chain.validate(String::from("hello"), &ctx).await.is_ok());
        assert!(chain.validate(String::from("hi"), &ctx).await.is_err());
    }

    #[test]
    fn test_text_output_schema() {
        let parsed = TextOutputSchema.parse_text("hello world");
        assert_eq!(parsed.unwrap(), "hello world");
    }

    #[test]
    fn test_json_output_schema() {
        use serde::Deserialize;

        #[derive(Debug, Deserialize, PartialEq)]
        struct Person {
            name: String,
            age: u32,
        }

        let schema = JsonOutputSchema::<Person>::new();

        // Bare JSON.
        let alice = schema.parse_text(r#"{"name": "Alice", "age": 30}"#).unwrap();
        assert_eq!(
            alice,
            Person {
                name: "Alice".to_string(),
                age: 30
            }
        );

        // JSON inside a fenced block preceded by prose.
        let text = r#"Here's the person:
```json
{"name": "Bob", "age": 25}
```"#;
        let bob = schema.parse_text(text).unwrap();
        assert_eq!(
            bob,
            Person {
                name: "Bob".to_string(),
                age: 25
            }
        );
    }

    #[test]
    fn test_extract_json() {
        let inline = "Here's some JSON: {\"a\": 1}";
        assert_eq!(extract_json(inline), Some("{\"a\": 1}"));

        let fenced = "```json\n{\"a\": 1}\n```";
        assert!(extract_json(fenced).is_some());
    }
}