Skip to main content

symbi_runtime/reasoning/
schema_validation.rs

1//! Schema-first validation pipeline
2//!
3//! Provides a layered validation pipeline for LLM output:
4//! 1. Strip markdown fences
5//! 2. Parse as JSON
6//! 3. Validate against JSON Schema
7//! 4. Deserialize into target Rust type
8//!
9//! Each layer produces actionable error messages that can be fed back to
10//! the LLM as observations for self-correction.
11
12use crate::reasoning::providers::slm::strip_markdown_fences;
13use serde::de::DeserializeOwned;
14
15/// Errors from the validation pipeline, ordered by severity.
16///
17/// Each variant contains an actionable message suitable for feeding
18/// back to an LLM as an observation.
19#[derive(Debug, thiserror::Error)]
20pub enum SchemaValidationError {
21    /// The raw text couldn't be parsed as JSON.
22    #[error("JSON parse error at line {line}, column {column}: {message}. Raw text starts with: {raw_prefix:?}")]
23    JsonParseError {
24        message: String,
25        line: usize,
26        column: usize,
27        raw_prefix: String,
28    },
29
30    /// The JSON is valid but doesn't conform to the expected schema.
31    #[error("Schema validation failed: {errors:?}")]
32    SchemaViolation { errors: Vec<String> },
33
34    /// The JSON conforms to the schema but couldn't be deserialized into
35    /// the target Rust type (usually a serde issue).
36    #[error("Deserialization error: {message}")]
37    DeserializationError { message: String },
38}
39
40impl SchemaValidationError {
41    /// Format as a concise feedback message for the LLM.
42    pub fn to_llm_feedback(&self) -> String {
43        match self {
44            SchemaValidationError::JsonParseError {
45                message,
46                line,
47                column,
48                ..
49            } => {
50                format!(
51                    "Your response was not valid JSON. Error at line {}, column {}: {}. Please respond with a valid JSON object.",
52                    line, column, message
53                )
54            }
55            SchemaValidationError::SchemaViolation { errors } => {
56                let error_list = errors.join("; ");
57                format!(
58                    "Your JSON response did not match the required schema. Issues: {}. Please fix these and try again.",
59                    error_list
60                )
61            }
62            SchemaValidationError::DeserializationError { message } => {
63                format!(
64                    "Your JSON had the right structure but contained invalid values: {}. Please correct the values.",
65                    message
66                )
67            }
68        }
69    }
70}
71
/// Entry point for turning raw LLM output into validated data.
///
/// This is a stateless marker type; all functionality lives in associated
/// functions. Two flavours are offered:
///
/// - **Typed**: `validate_and_parse::<T>()` deserializes into a Rust type
///   known at compile time.
/// - **Dynamic**: `validate_dynamic()` stops after schema validation and
///   returns a `serde_json::Value`, for output shapes that are only defined
///   at runtime (for example via the DSL) and have no matching Rust type.
pub struct ValidationPipeline;
82
83impl ValidationPipeline {
84    /// Run the full validation pipeline with static typing:
85    /// strip fences → parse JSON → validate → deserialize into `T`.
86    ///
87    /// Use this when you have a compile-time Rust type for the output.
88    pub fn validate_and_parse<T: DeserializeOwned>(
89        raw_text: &str,
90        schema: Option<&jsonschema::Validator>,
91    ) -> Result<T, SchemaValidationError> {
92        let json_value = Self::parse_and_validate(raw_text, schema)?;
93
94        // Deserialize into target type
95        serde_json::from_value(json_value).map_err(|e| {
96            SchemaValidationError::DeserializationError {
97                message: e.to_string(),
98            }
99        })
100    }
101
102    /// Run the validation pipeline for dynamic schemas:
103    /// strip fences → parse JSON → validate against schema → return Value.
104    ///
105    /// Use this when output shapes are defined at runtime (e.g., from the DSL).
106    /// The returned `serde_json::Value` is guaranteed to conform to the schema.
107    pub fn validate_dynamic(
108        raw_text: &str,
109        schema: Option<&jsonschema::Validator>,
110    ) -> Result<serde_json::Value, SchemaValidationError> {
111        Self::parse_and_validate(raw_text, schema)
112    }
113
114    /// Common pipeline: strip fences → parse JSON → validate against schema.
115    fn parse_and_validate(
116        raw_text: &str,
117        schema: Option<&jsonschema::Validator>,
118    ) -> Result<serde_json::Value, SchemaValidationError> {
119        // Step 1: Strip markdown fences
120        let cleaned = strip_markdown_fences(raw_text);
121
122        // Step 2: Parse as JSON
123        let json_value: serde_json::Value = serde_json::from_str(&cleaned).map_err(|e| {
124            let prefix = if cleaned.len() > 100 {
125                format!("{}...", &cleaned[..100])
126            } else {
127                cleaned.clone()
128            };
129            SchemaValidationError::JsonParseError {
130                message: e.to_string(),
131                line: e.line(),
132                column: e.column(),
133                raw_prefix: prefix,
134            }
135        })?;
136
137        // Step 3: Validate against schema if provided
138        if let Some(validator) = schema {
139            Self::check_schema_errors(&json_value, validator)?;
140        }
141
142        Ok(json_value)
143    }
144
145    /// Validate a JSON value against a pre-compiled schema and collect errors.
146    fn check_schema_errors(
147        value: &serde_json::Value,
148        validator: &jsonschema::Validator,
149    ) -> Result<(), SchemaValidationError> {
150        let errors: Vec<String> = validator
151            .iter_errors(value)
152            .map(|e| {
153                let path = e.instance_path.to_string();
154                if path.is_empty() {
155                    e.to_string()
156                } else {
157                    format!("at '{}': {}", path, e)
158                }
159            })
160            .collect();
161
162        if errors.is_empty() {
163            Ok(())
164        } else {
165            Err(SchemaValidationError::SchemaViolation { errors })
166        }
167    }
168
169    /// Parse raw text as JSON without schema validation.
170    pub fn parse_json(raw_text: &str) -> Result<serde_json::Value, SchemaValidationError> {
171        let cleaned = strip_markdown_fences(raw_text);
172        serde_json::from_str(&cleaned).map_err(|e| {
173            let prefix = if cleaned.len() > 100 {
174                format!("{}...", &cleaned[..100])
175            } else {
176                cleaned.clone()
177            };
178            SchemaValidationError::JsonParseError {
179                message: e.to_string(),
180                line: e.line(),
181                column: e.column(),
182                raw_prefix: prefix,
183            }
184        })
185    }
186
187    /// Validate a JSON value against a pre-compiled schema.
188    pub fn validate_schema(
189        value: &serde_json::Value,
190        validator: &jsonschema::Validator,
191    ) -> Result<(), SchemaValidationError> {
192        Self::check_schema_errors(value, validator)
193    }
194
195    /// Create a validator from a raw JSON Schema value.
196    ///
197    /// This is the primary way to create validators for dynamic schemas
198    /// defined at runtime (e.g., from DSL configurations or TOML pipelines).
199    pub fn compile_schema(
200        schema: &serde_json::Value,
201    ) -> Result<jsonschema::Validator, SchemaValidationError> {
202        jsonschema::validator_for(schema).map_err(|e| SchemaValidationError::SchemaViolation {
203            errors: vec![format!("Invalid schema: {}", e)],
204        })
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    /// Typed target used by the `validate_and_parse` tests.
    #[derive(Debug, Deserialize, PartialEq)]
    struct TestOutput {
        answer: String,
        confidence: f64,
    }

    /// Compile `schema`, panicking if the schema itself is invalid.
    fn make_validator(schema: &serde_json::Value) -> jsonschema::Validator {
        jsonschema::validator_for(schema).expect("valid schema")
    }

    /// Validator for `{answer: string, confidence: number in [0, 1]}` with
    /// both fields required; shared by several typed-pipeline tests.
    fn answer_confidence_validator() -> jsonschema::Validator {
        make_validator(&serde_json::json!({
            "type": "object",
            "properties": {
                "answer": {"type": "string"},
                "confidence": {"type": "number", "minimum": 0.0, "maximum": 1.0}
            },
            "required": ["answer", "confidence"]
        }))
    }

    #[test]
    fn test_validate_and_parse_valid() {
        let validator = answer_confidence_validator();

        let raw = r#"{"answer": "42", "confidence": 0.95}"#;
        let result: TestOutput =
            ValidationPipeline::validate_and_parse(raw, Some(&validator)).unwrap();
        assert_eq!(result.answer, "42");
        assert!((result.confidence - 0.95).abs() < f64::EPSILON);
    }

    #[test]
    fn test_validate_and_parse_markdown_fenced() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "answer": {"type": "string"},
                "confidence": {"type": "number"}
            },
            "required": ["answer", "confidence"]
        });
        let validator = make_validator(&schema);

        let raw = "```json\n{\"answer\": \"hello\", \"confidence\": 0.8}\n```";
        let result: TestOutput =
            ValidationPipeline::validate_and_parse(raw, Some(&validator)).unwrap();
        assert_eq!(result.answer, "hello");
    }

    #[test]
    fn test_validate_and_parse_invalid_json() {
        let raw = "This is not JSON at all";
        let err = ValidationPipeline::validate_and_parse::<TestOutput>(raw, None)
            .expect_err("plain prose must not parse as JSON");
        assert!(matches!(err, SchemaValidationError::JsonParseError { .. }));

        let feedback = err.to_llm_feedback();
        assert!(feedback.contains("not valid JSON"));
    }

    #[test]
    fn test_validate_and_parse_schema_violation() {
        let validator = answer_confidence_validator();

        // Missing required field "confidence"
        let raw = r#"{"answer": "42"}"#;
        let err = ValidationPipeline::validate_and_parse::<TestOutput>(raw, Some(&validator))
            .expect_err("missing required field must be rejected");
        assert!(matches!(err, SchemaValidationError::SchemaViolation { .. }));

        let feedback = err.to_llm_feedback();
        assert!(feedback.contains("did not match the required schema"));
    }

    #[test]
    fn test_validate_and_parse_out_of_range() {
        let validator = answer_confidence_validator();

        // confidence out of range
        let raw = r#"{"answer": "42", "confidence": 1.5}"#;
        let err = ValidationPipeline::validate_and_parse::<TestOutput>(raw, Some(&validator))
            .expect_err("out-of-range confidence must be rejected");
        assert!(matches!(err, SchemaValidationError::SchemaViolation { .. }));
    }

    #[test]
    fn test_validate_and_parse_no_schema() {
        let raw = r#"{"answer": "hello", "confidence": 0.5}"#;
        let result: TestOutput = ValidationPipeline::validate_and_parse(raw, None).unwrap();
        assert_eq!(result.answer, "hello");
    }

    #[test]
    fn test_parse_json_standalone() {
        let raw = "```json\n{\"key\": \"value\"}\n```";
        let value = ValidationPipeline::parse_json(raw).unwrap();
        assert_eq!(value["key"], "value");
    }

    #[test]
    fn test_validate_schema_standalone() {
        let schema = serde_json::json!({
            "type": "object",
            "required": ["name"]
        });
        let validator = make_validator(&schema);

        let valid = serde_json::json!({"name": "test"});
        assert!(ValidationPipeline::validate_schema(&valid, &validator).is_ok());

        let invalid = serde_json::json!({"other": "field"});
        assert!(ValidationPipeline::validate_schema(&invalid, &validator).is_err());
    }

    #[test]
    fn test_error_feedback_messages() {
        let json_err = SchemaValidationError::JsonParseError {
            message: "expected value".into(),
            line: 1,
            column: 1,
            raw_prefix: "bad input".into(),
        };
        let feedback = json_err.to_llm_feedback();
        assert!(feedback.contains("not valid JSON"));
        assert!(feedback.contains("line 1"));

        let schema_err = SchemaValidationError::SchemaViolation {
            errors: vec!["missing field 'name'".into()],
        };
        let feedback = schema_err.to_llm_feedback();
        assert!(feedback.contains("missing field 'name'"));

        let deser_err = SchemaValidationError::DeserializationError {
            message: "invalid type: string, expected f64".into(),
        };
        let feedback = deser_err.to_llm_feedback();
        assert!(feedback.contains("invalid values"));
    }

    #[test]
    fn test_validate_dynamic_valid() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "result": {"type": "string"},
                "score": {"type": "number"}
            },
            "required": ["result"]
        });
        let validator = make_validator(&schema);

        let raw = r#"{"result": "success", "score": 95.5}"#;
        let value = ValidationPipeline::validate_dynamic(raw, Some(&validator)).unwrap();
        assert_eq!(value["result"], "success");
        assert_eq!(value["score"], 95.5);
    }

    #[test]
    fn test_validate_dynamic_invalid() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            },
            "required": ["name"]
        });
        let validator = make_validator(&schema);

        let raw = r#"{"other": "field"}"#;
        let result = ValidationPipeline::validate_dynamic(raw, Some(&validator));
        assert!(matches!(
            result,
            Err(SchemaValidationError::SchemaViolation { .. })
        ));
    }

    #[test]
    fn test_validate_dynamic_arbitrary_shape() {
        // Simulate a DSL-defined output schema at runtime
        let user_defined_schema = serde_json::json!({
            "type": "object",
            "properties": {
                "tasks": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "id": {"type": "integer"},
                            "description": {"type": "string"},
                            "priority": {"type": "string", "enum": ["low", "medium", "high"]}
                        },
                        "required": ["id", "description"]
                    }
                },
                "summary": {"type": "string"}
            },
            "required": ["tasks", "summary"]
        });
        let validator = make_validator(&user_defined_schema);

        let raw = r#"{"tasks": [{"id": 1, "description": "Do thing", "priority": "high"}], "summary": "One task"}"#;
        let value = ValidationPipeline::validate_dynamic(raw, Some(&validator)).unwrap();
        assert_eq!(value["tasks"][0]["priority"], "high");
        assert_eq!(value["summary"], "One task");

        // Invalid: wrong priority enum value
        let bad = r#"{"tasks": [{"id": 1, "description": "Do thing", "priority": "urgent"}], "summary": "x"}"#;
        let result = ValidationPipeline::validate_dynamic(bad, Some(&validator));
        assert!(result.is_err());
    }

    #[test]
    fn test_compile_schema_valid() {
        let schema = serde_json::json!({"type": "object"});
        assert!(ValidationPipeline::compile_schema(&schema).is_ok());
    }

    #[test]
    fn test_compile_schema_invalid() {
        let schema = serde_json::json!({"type": "not_a_type"});
        assert!(ValidationPipeline::compile_schema(&schema).is_err());
    }

    #[test]
    fn test_validator_performance() {
        // Pre-compiled validators should be fast (<100μs for typical schemas).
        const ITERATIONS: u32 = 1000;

        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {"type": "string", "maxLength": 100},
                "score": {"type": "number", "minimum": 0, "maximum": 100},
                "tags": {"type": "array", "items": {"type": "string"}},
                "metadata": {
                    "type": "object",
                    "properties": {
                        "source": {"type": "string"},
                        "timestamp": {"type": "string"}
                    }
                }
            },
            "required": ["name", "score"]
        });
        let validator = make_validator(&schema);

        let valid_input = serde_json::json!({
            "name": "test agent output",
            "score": 85.5,
            "tags": ["analysis", "research"],
            "metadata": {"source": "web", "timestamp": "2024-01-01T00:00:00Z"}
        });

        let start = std::time::Instant::now();
        let mut failures = 0u32;
        for _ in 0..ITERATIONS {
            if ValidationPipeline::validate_schema(&valid_input, &validator).is_err() {
                failures += 1;
            }
        }
        let elapsed = start.elapsed();

        // A fast loop that rejects valid input would be a false positive, so
        // check correctness regardless of build profile.
        assert_eq!(failures, 0, "valid input was rejected by the validator");

        // Wall-clock bounds are only meaningful with optimisations enabled;
        // unoptimised builds can be much slower and would make this flaky.
        if cfg!(debug_assertions) {
            return;
        }

        let per_validation = elapsed / ITERATIONS;
        assert!(
            per_validation.as_micros() < 100,
            "Validation took {}μs, expected <100μs",
            per_validation.as_micros()
        );
    }
}