1use crate::reasoning::providers::slm::strip_markdown_fences;
13use serde::de::DeserializeOwned;
14
/// Failures that can occur while turning raw response text into validated JSON.
///
/// Each variant corresponds to one stage: parsing the text as JSON, checking
/// the parsed value against a JSON Schema, and deserializing it into a typed
/// value.
#[derive(Debug, thiserror::Error)]
pub enum SchemaValidationError {
    /// The (fence-stripped) text could not be parsed as JSON.
    #[error("JSON parse error at line {line}, column {column}: {message}. Raw text starts with: {raw_prefix:?}")]
    JsonParseError {
        /// The JSON parser's error rendered as a string.
        message: String,
        /// Line of the failure as reported by the JSON parser.
        line: usize,
        /// Column of the failure as reported by the JSON parser.
        column: usize,
        /// Leading portion of the text that failed to parse, kept for
        /// diagnostics; long inputs are cut short and suffixed with `...`.
        raw_prefix: String,
    },

    /// The JSON parsed, but the schema validator reported one or more errors.
    ///
    /// Also returned by `ValidationPipeline::compile_schema` when the schema
    /// itself cannot be compiled.
    #[error("Schema validation failed: {errors:?}")]
    SchemaViolation {
        /// One rendered message per reported validation error.
        errors: Vec<String>,
    },

    /// The JSON value could not be deserialized into the requested Rust type.
    #[error("Deserialization error: {message}")]
    DeserializationError {
        /// The deserializer's error rendered as a string.
        message: String,
    },
}
39
40impl SchemaValidationError {
41 pub fn to_llm_feedback(&self) -> String {
43 match self {
44 SchemaValidationError::JsonParseError {
45 message,
46 line,
47 column,
48 ..
49 } => {
50 format!(
51 "Your response was not valid JSON. Error at line {}, column {}: {}. Please respond with a valid JSON object.",
52 line, column, message
53 )
54 }
55 SchemaValidationError::SchemaViolation { errors } => {
56 let error_list = errors.join("; ");
57 format!(
58 "Your JSON response did not match the required schema. Issues: {}. Please fix these and try again.",
59 error_list
60 )
61 }
62 SchemaValidationError::DeserializationError { message } => {
63 format!(
64 "Your JSON had the right structure but contained invalid values: {}. Please correct the values.",
65 message
66 )
67 }
68 }
69 }
70}
71
/// Namespace type for the JSON parsing and schema-validation helpers.
///
/// It carries no data; its functionality is exposed as associated functions.
pub struct ValidationPipeline;
82
83impl ValidationPipeline {
84 pub fn validate_and_parse<T: DeserializeOwned>(
89 raw_text: &str,
90 schema: Option<&jsonschema::Validator>,
91 ) -> Result<T, SchemaValidationError> {
92 let json_value = Self::parse_and_validate(raw_text, schema)?;
93
94 serde_json::from_value(json_value).map_err(|e| {
96 SchemaValidationError::DeserializationError {
97 message: e.to_string(),
98 }
99 })
100 }
101
102 pub fn validate_dynamic(
108 raw_text: &str,
109 schema: Option<&jsonschema::Validator>,
110 ) -> Result<serde_json::Value, SchemaValidationError> {
111 Self::parse_and_validate(raw_text, schema)
112 }
113
114 fn parse_and_validate(
116 raw_text: &str,
117 schema: Option<&jsonschema::Validator>,
118 ) -> Result<serde_json::Value, SchemaValidationError> {
119 let cleaned = strip_markdown_fences(raw_text);
121
122 let json_value: serde_json::Value = serde_json::from_str(&cleaned).map_err(|e| {
124 let prefix = if cleaned.len() > 100 {
125 format!("{}...", &cleaned[..100])
126 } else {
127 cleaned.clone()
128 };
129 SchemaValidationError::JsonParseError {
130 message: e.to_string(),
131 line: e.line(),
132 column: e.column(),
133 raw_prefix: prefix,
134 }
135 })?;
136
137 if let Some(validator) = schema {
139 Self::check_schema_errors(&json_value, validator)?;
140 }
141
142 Ok(json_value)
143 }
144
145 fn check_schema_errors(
147 value: &serde_json::Value,
148 validator: &jsonschema::Validator,
149 ) -> Result<(), SchemaValidationError> {
150 let errors: Vec<String> = validator
151 .iter_errors(value)
152 .map(|e| {
153 let path = e.instance_path.to_string();
154 if path.is_empty() {
155 e.to_string()
156 } else {
157 format!("at '{}': {}", path, e)
158 }
159 })
160 .collect();
161
162 if errors.is_empty() {
163 Ok(())
164 } else {
165 Err(SchemaValidationError::SchemaViolation { errors })
166 }
167 }
168
169 pub fn parse_json(raw_text: &str) -> Result<serde_json::Value, SchemaValidationError> {
171 let cleaned = strip_markdown_fences(raw_text);
172 serde_json::from_str(&cleaned).map_err(|e| {
173 let prefix = if cleaned.len() > 100 {
174 format!("{}...", &cleaned[..100])
175 } else {
176 cleaned.clone()
177 };
178 SchemaValidationError::JsonParseError {
179 message: e.to_string(),
180 line: e.line(),
181 column: e.column(),
182 raw_prefix: prefix,
183 }
184 })
185 }
186
187 pub fn validate_schema(
189 value: &serde_json::Value,
190 validator: &jsonschema::Validator,
191 ) -> Result<(), SchemaValidationError> {
192 Self::check_schema_errors(value, validator)
193 }
194
195 pub fn compile_schema(
200 schema: &serde_json::Value,
201 ) -> Result<jsonschema::Validator, SchemaValidationError> {
202 jsonschema::validator_for(schema).map_err(|e| SchemaValidationError::SchemaViolation {
203 errors: vec![format!("Invalid schema: {}", e)],
204 })
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;
    use serde_json::json;

    /// Typed target used to exercise `validate_and_parse`.
    #[derive(Debug, Deserialize, PartialEq)]
    struct TestOutput {
        answer: String,
        confidence: f64,
    }

    /// Compiles `schema`, panicking if it is not a valid JSON Schema.
    fn make_validator(schema: &serde_json::Value) -> jsonschema::Validator {
        jsonschema::validator_for(schema).expect("valid schema")
    }

    /// Schema requiring a string `answer` and a `confidence` within `[0, 1]`.
    fn bounded_answer_schema() -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "answer": {"type": "string"},
                "confidence": {"type": "number", "minimum": 0.0, "maximum": 1.0}
            },
            "required": ["answer", "confidence"]
        })
    }

    #[test]
    fn test_validate_and_parse_valid() {
        let validator = make_validator(&bounded_answer_schema());

        let parsed: TestOutput = ValidationPipeline::validate_and_parse(
            r#"{"answer": "42", "confidence": 0.95}"#,
            Some(&validator),
        )
        .unwrap();

        assert_eq!(parsed.answer, "42");
        assert!((parsed.confidence - 0.95).abs() < f64::EPSILON);
    }

    #[test]
    fn test_validate_and_parse_markdown_fenced() {
        let validator = make_validator(&json!({
            "type": "object",
            "properties": {
                "answer": {"type": "string"},
                "confidence": {"type": "number"}
            },
            "required": ["answer", "confidence"]
        }));

        let fenced = "```json\n{\"answer\": \"hello\", \"confidence\": 0.8}\n```";
        let parsed: TestOutput =
            ValidationPipeline::validate_and_parse(fenced, Some(&validator)).unwrap();
        assert_eq!(parsed.answer, "hello");
    }

    #[test]
    fn test_validate_and_parse_invalid_json() {
        let err =
            ValidationPipeline::validate_and_parse::<TestOutput>("This is not JSON at all", None)
                .expect_err("plain prose must be rejected");
        assert!(matches!(err, SchemaValidationError::JsonParseError { .. }));
        assert!(err.to_llm_feedback().contains("not valid JSON"));
    }

    #[test]
    fn test_validate_and_parse_schema_violation() {
        let validator = make_validator(&bounded_answer_schema());

        // `confidence` is required by the schema but absent here.
        let err = ValidationPipeline::validate_and_parse::<TestOutput>(
            r#"{"answer": "42"}"#,
            Some(&validator),
        )
        .expect_err("missing required field must be rejected");
        assert!(matches!(err, SchemaValidationError::SchemaViolation { .. }));
        assert!(err
            .to_llm_feedback()
            .contains("did not match the required schema"));
    }

    #[test]
    fn test_validate_and_parse_out_of_range() {
        let validator = make_validator(&bounded_answer_schema());

        // 1.5 exceeds the schema's `maximum` of 1.0.
        let err = ValidationPipeline::validate_and_parse::<TestOutput>(
            r#"{"answer": "42", "confidence": 1.5}"#,
            Some(&validator),
        )
        .expect_err("out-of-range value must be rejected");
        assert!(matches!(err, SchemaValidationError::SchemaViolation { .. }));
    }

    #[test]
    fn test_validate_and_parse_no_schema() {
        let parsed: TestOutput = ValidationPipeline::validate_and_parse(
            r#"{"answer": "hello", "confidence": 0.5}"#,
            None,
        )
        .unwrap();
        assert_eq!(parsed.answer, "hello");
    }

    #[test]
    fn test_parse_json_standalone() {
        let parsed = ValidationPipeline::parse_json("```json\n{\"key\": \"value\"}\n```").unwrap();
        assert_eq!(parsed["key"], "value");
    }

    #[test]
    fn test_validate_schema_standalone() {
        let validator = make_validator(&json!({
            "type": "object",
            "required": ["name"]
        }));

        let with_name = json!({"name": "test"});
        let without_name = json!({"other": "field"});

        assert!(ValidationPipeline::validate_schema(&with_name, &validator).is_ok());
        assert!(ValidationPipeline::validate_schema(&without_name, &validator).is_err());
    }

    #[test]
    fn test_error_feedback_messages() {
        let parse_feedback = SchemaValidationError::JsonParseError {
            message: "expected value".into(),
            line: 1,
            column: 1,
            raw_prefix: "bad input".into(),
        }
        .to_llm_feedback();
        assert!(parse_feedback.contains("not valid JSON"));
        assert!(parse_feedback.contains("line 1"));

        let schema_feedback = SchemaValidationError::SchemaViolation {
            errors: vec!["missing field 'name'".into()],
        }
        .to_llm_feedback();
        assert!(schema_feedback.contains("missing field 'name'"));

        let deser_feedback = SchemaValidationError::DeserializationError {
            message: "invalid type: string, expected f64".into(),
        }
        .to_llm_feedback();
        assert!(deser_feedback.contains("invalid values"));
    }

    #[test]
    fn test_validate_dynamic_valid() {
        let validator = make_validator(&json!({
            "type": "object",
            "properties": {
                "result": {"type": "string"},
                "score": {"type": "number"}
            },
            "required": ["result"]
        }));

        let value = ValidationPipeline::validate_dynamic(
            r#"{"result": "success", "score": 95.5}"#,
            Some(&validator),
        )
        .unwrap();
        assert_eq!(value["result"], "success");
        assert_eq!(value["score"], 95.5);
    }

    #[test]
    fn test_validate_dynamic_invalid() {
        let validator = make_validator(&json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            },
            "required": ["name"]
        }));

        let outcome =
            ValidationPipeline::validate_dynamic(r#"{"other": "field"}"#, Some(&validator));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_dynamic_arbitrary_shape() {
        // A nested schema with no matching Rust type on this side.
        let validator = make_validator(&json!({
            "type": "object",
            "properties": {
                "tasks": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "id": {"type": "integer"},
                            "description": {"type": "string"},
                            "priority": {"type": "string", "enum": ["low", "medium", "high"]}
                        },
                        "required": ["id", "description"]
                    }
                },
                "summary": {"type": "string"}
            },
            "required": ["tasks", "summary"]
        }));

        let accepted = r#"{"tasks": [{"id": 1, "description": "Do thing", "priority": "high"}], "summary": "One task"}"#;
        let value = ValidationPipeline::validate_dynamic(accepted, Some(&validator)).unwrap();
        assert_eq!(value["tasks"][0]["priority"], "high");
        assert_eq!(value["summary"], "One task");

        // "urgent" is not one of the enum's allowed priorities.
        let rejected = r#"{"tasks": [{"id": 1, "description": "Do thing", "priority": "urgent"}], "summary": "x"}"#;
        assert!(ValidationPipeline::validate_dynamic(rejected, Some(&validator)).is_err());
    }

    #[test]
    fn test_compile_schema_valid() {
        assert!(ValidationPipeline::compile_schema(&json!({"type": "object"})).is_ok());
    }

    #[test]
    fn test_compile_schema_invalid() {
        assert!(ValidationPipeline::compile_schema(&json!({"type": "not_a_type"})).is_err());
    }

    #[test]
    fn test_validator_performance() {
        const ITERATIONS: u32 = 1000;

        let validator = make_validator(&json!({
            "type": "object",
            "properties": {
                "name": {"type": "string", "maxLength": 100},
                "score": {"type": "number", "minimum": 0, "maximum": 100},
                "tags": {"type": "array", "items": {"type": "string"}},
                "metadata": {
                    "type": "object",
                    "properties": {
                        "source": {"type": "string"},
                        "timestamp": {"type": "string"}
                    }
                }
            },
            "required": ["name", "score"]
        }));

        let valid_input = json!({
            "name": "test agent output",
            "score": 85.5,
            "tags": ["analysis", "research"],
            "metadata": {"source": "web", "timestamp": "2024-01-01T00:00:00Z"}
        });

        let started = std::time::Instant::now();
        for _ in 0..ITERATIONS {
            let _ = ValidationPipeline::validate_schema(&valid_input, &validator);
        }
        let per_validation = started.elapsed() / ITERATIONS;

        // NOTE(review): wall-clock threshold; may be sensitive to build
        // profile and machine load.
        let micros = per_validation.as_micros();
        assert!(
            micros < 100,
            "Validation took {}μs, expected <100μs",
            micros
        );
    }
}