spikard_core/
validation.rs

1//! Request/response validation using JSON Schema
2
3use crate::debug_log_module;
4use jsonschema::Validator;
5use serde_json::Value;
6use std::sync::Arc;
7
8/// Schema validator that compiles and validates JSON Schema
9#[derive(Clone)]
10pub struct SchemaValidator {
11    compiled: Arc<Validator>,
12    schema: Value,
13}
14
15impl SchemaValidator {
16    /// Create a new validator from a JSON Schema
17    pub fn new(schema: Value) -> Result<Self, String> {
18        let compiled = jsonschema::options()
19            .with_draft(jsonschema::Draft::Draft202012)
20            .should_validate_formats(true)
21            .with_pattern_options(jsonschema::PatternOptions::regex())
22            .build(&schema)
23            .map_err(|e| {
24                anyhow::anyhow!("Invalid JSON Schema")
25                    .context(format!("Schema compilation failed: {}", e))
26                    .to_string()
27            })?;
28
29        Ok(Self {
30            compiled: Arc::new(compiled),
31            schema,
32        })
33    }
34
35    /// Get the underlying JSON Schema
36    pub fn schema(&self) -> &Value {
37        &self.schema
38    }
39
40    /// Pre-process data to convert file objects to strings for format: "binary" validation
41    ///
42    /// Files uploaded via multipart are converted to objects like:
43    /// {"filename": "...", "size": N, "content": "...", "content_type": "..."}
44    ///
45    /// But schemas define them as: {"type": "string", "format": "binary"}
46    ///
47    /// This method recursively processes the data and converts file objects to their content strings
48    /// so that validation passes, while preserving the original structure for handlers to use.
49    fn preprocess_binary_fields(&self, data: &Value) -> Value {
50        self.preprocess_value_with_schema(data, &self.schema)
51    }
52
53    #[allow(clippy::only_used_in_recursion)]
54    fn preprocess_value_with_schema(&self, data: &Value, schema: &Value) -> Value {
55        if let Some(schema_obj) = schema.as_object() {
56            let is_string_type = schema_obj.get("type").and_then(|t| t.as_str()) == Some("string");
57            let is_binary_format = schema_obj.get("format").and_then(|f| f.as_str()) == Some("binary");
58
59            #[allow(clippy::collapsible_if)]
60            if is_string_type && is_binary_format {
61                if let Some(data_obj) = data.as_object() {
62                    if data_obj.contains_key("filename")
63                        && data_obj.contains_key("content")
64                        && data_obj.contains_key("size")
65                        && data_obj.contains_key("content_type")
66                    {
67                        return data_obj.get("content").unwrap_or(&Value::Null).clone();
68                    }
69                }
70                return data.clone();
71            }
72
73            #[allow(clippy::collapsible_if)]
74            if schema_obj.get("type").and_then(|t| t.as_str()) == Some("array") {
75                if let Some(items_schema) = schema_obj.get("items") {
76                    if let Some(data_array) = data.as_array() {
77                        let processed_array: Vec<Value> = data_array
78                            .iter()
79                            .map(|item| self.preprocess_value_with_schema(item, items_schema))
80                            .collect();
81                        return Value::Array(processed_array);
82                    }
83                }
84            }
85
86            #[allow(clippy::collapsible_if)]
87            if schema_obj.get("type").and_then(|t| t.as_str()) == Some("object") {
88                if let Some(properties) = schema_obj.get("properties").and_then(|p| p.as_object()) {
89                    if let Some(data_obj) = data.as_object() {
90                        let mut processed_obj = serde_json::Map::new();
91                        for (key, value) in data_obj {
92                            if let Some(prop_schema) = properties.get(key) {
93                                processed_obj
94                                    .insert(key.clone(), self.preprocess_value_with_schema(value, prop_schema));
95                            } else {
96                                processed_obj.insert(key.clone(), value.clone());
97                            }
98                        }
99                        return Value::Object(processed_obj);
100                    }
101                }
102            }
103        }
104
105        data.clone()
106    }
107
108    /// Validate JSON data against the schema
109    pub fn validate(&self, data: &Value) -> Result<(), ValidationError> {
110        let processed_data = self.preprocess_binary_fields(data);
111
112        let validation_errors: Vec<_> = self.compiled.iter_errors(&processed_data).collect();
113
114        if validation_errors.is_empty() {
115            return Ok(());
116        }
117
118        let errors: Vec<ValidationErrorDetail> = validation_errors
119            .into_iter()
120            .map(|err| {
121                let instance_path = err.instance_path().to_string();
122                let schema_path_str = err.schema_path().as_str();
123                let error_msg = err.to_string();
124
125                let param_name = if schema_path_str.ends_with("/required") {
126                    let field_name = if let Some(start) = error_msg.find('"') {
127                        if let Some(end) = error_msg[start + 1..].find('"') {
128                            error_msg[start + 1..start + 1 + end].to_string()
129                        } else {
130                            "".to_string()
131                        }
132                    } else {
133                        "".to_string()
134                    };
135
136                    if !instance_path.is_empty() && instance_path.starts_with('/') && instance_path.len() > 1 {
137                        let base_path = &instance_path[1..];
138                        if !field_name.is_empty() {
139                            format!("{}/{}", base_path, field_name)
140                        } else {
141                            base_path.to_string()
142                        }
143                    } else if !field_name.is_empty() {
144                        field_name
145                    } else {
146                        "body".to_string()
147                    }
148                } else if schema_path_str.contains("/additionalProperties") {
149                    if let Some(start) = error_msg.find('(') {
150                        if let Some(quote_start) = error_msg[start..].find('\'') {
151                            let abs_start = start + quote_start + 1;
152                            if let Some(quote_end) = error_msg[abs_start..].find('\'') {
153                                let property_name = error_msg[abs_start..abs_start + quote_end].to_string();
154                                if !instance_path.is_empty()
155                                    && instance_path.starts_with('/')
156                                    && instance_path.len() > 1
157                                {
158                                    format!("{}/{}", &instance_path[1..], property_name)
159                                } else {
160                                    property_name
161                                }
162                            } else {
163                                instance_path[1..].to_string()
164                            }
165                        } else {
166                            instance_path[1..].to_string()
167                        }
168                    } else if instance_path.starts_with('/') && instance_path.len() > 1 {
169                        instance_path[1..].to_string()
170                    } else {
171                        "body".to_string()
172                    }
173                } else if instance_path.starts_with('/') && instance_path.len() > 1 {
174                    instance_path[1..].to_string()
175                } else if instance_path.is_empty() {
176                    "body".to_string()
177                } else {
178                    instance_path
179                };
180
181                let loc_parts: Vec<String> = if param_name.contains('/') {
182                    let mut parts = vec!["body".to_string()];
183                    parts.extend(param_name.split('/').map(|s| s.to_string()));
184                    parts
185                } else if param_name == "body" {
186                    vec!["body".to_string()]
187                } else {
188                    vec!["body".to_string(), param_name.clone()]
189                };
190
191                let input_value = if schema_path_str == "/required" {
192                    data.clone()
193                } else {
194                    err.instance().clone().into_owned()
195                };
196
197                let schema_prop_path = if param_name.contains('/') {
198                    format!("/properties/{}", param_name.replace('/', "/properties/"))
199                } else {
200                    format!("/properties/{}", param_name)
201                };
202
203                let (error_type, msg, ctx) = if schema_path_str.contains("minLength") {
204                    if let Some(min_len) = self
205                        .schema
206                        .pointer(&format!("{}/minLength", schema_prop_path))
207                        .and_then(|v| v.as_u64())
208                    {
209                        let ctx = serde_json::json!({"min_length": min_len});
210                        (
211                            "string_too_short".to_string(),
212                            format!("String should have at least {} characters", min_len),
213                            Some(ctx),
214                        )
215                    } else {
216                        ("string_too_short".to_string(), "String is too short".to_string(), None)
217                    }
218                } else if schema_path_str.contains("maxLength") {
219                    if let Some(max_len) = self
220                        .schema
221                        .pointer(&format!("{}/maxLength", schema_prop_path))
222                        .and_then(|v| v.as_u64())
223                    {
224                        let ctx = serde_json::json!({"max_length": max_len});
225                        (
226                            "string_too_long".to_string(),
227                            format!("String should have at most {} characters", max_len),
228                            Some(ctx),
229                        )
230                    } else {
231                        ("string_too_long".to_string(), "String is too long".to_string(), None)
232                    }
233                } else if schema_path_str.contains("exclusiveMinimum")
234                    || (error_msg.contains("less than or equal to") && error_msg.contains("minimum"))
235                {
236                    if let Some(min_val) = self
237                        .schema
238                        .pointer(&format!("{}/exclusiveMinimum", schema_prop_path))
239                        .and_then(|v| v.as_i64())
240                    {
241                        let ctx = serde_json::json!({"gt": min_val});
242                        (
243                            "greater_than".to_string(),
244                            format!("Input should be greater than {}", min_val),
245                            Some(ctx),
246                        )
247                    } else {
248                        (
249                            "greater_than".to_string(),
250                            "Input should be greater than the minimum".to_string(),
251                            None,
252                        )
253                    }
254                } else if schema_path_str.contains("minimum") || error_msg.contains("less than the minimum") {
255                    if let Some(min_val) = self
256                        .schema
257                        .pointer(&format!("{}/minimum", schema_prop_path))
258                        .and_then(|v| v.as_i64())
259                    {
260                        let ctx = serde_json::json!({"ge": min_val});
261                        (
262                            "greater_than_equal".to_string(),
263                            format!("Input should be greater than or equal to {}", min_val),
264                            Some(ctx),
265                        )
266                    } else {
267                        (
268                            "greater_than_equal".to_string(),
269                            "Input should be greater than or equal to the minimum".to_string(),
270                            None,
271                        )
272                    }
273                } else if schema_path_str.contains("exclusiveMaximum")
274                    || (error_msg.contains("greater than or equal to") && error_msg.contains("maximum"))
275                {
276                    if let Some(max_val) = self
277                        .schema
278                        .pointer(&format!("{}/exclusiveMaximum", schema_prop_path))
279                        .and_then(|v| v.as_i64())
280                    {
281                        let ctx = serde_json::json!({"lt": max_val});
282                        (
283                            "less_than".to_string(),
284                            format!("Input should be less than {}", max_val),
285                            Some(ctx),
286                        )
287                    } else {
288                        (
289                            "less_than".to_string(),
290                            "Input should be less than the maximum".to_string(),
291                            None,
292                        )
293                    }
294                } else if schema_path_str.contains("maximum") || error_msg.contains("greater than the maximum") {
295                    if let Some(max_val) = self
296                        .schema
297                        .pointer(&format!("{}/maximum", schema_prop_path))
298                        .and_then(|v| v.as_i64())
299                    {
300                        let ctx = serde_json::json!({"le": max_val});
301                        (
302                            "less_than_equal".to_string(),
303                            format!("Input should be less than or equal to {}", max_val),
304                            Some(ctx),
305                        )
306                    } else {
307                        (
308                            "less_than_equal".to_string(),
309                            "Input should be less than or equal to the maximum".to_string(),
310                            None,
311                        )
312                    }
313                } else if schema_path_str.contains("enum") || error_msg.contains("is not one of") {
314                    if let Some(enum_values) = self
315                        .schema
316                        .pointer(&format!("{}/enum", schema_prop_path))
317                        .and_then(|v| v.as_array())
318                    {
319                        let values: Vec<String> = enum_values
320                            .iter()
321                            .filter_map(|v| v.as_str().map(|s| format!("'{}'", s)))
322                            .collect();
323
324                        let msg = if values.len() > 1 {
325                            let last = values.last().unwrap();
326                            let rest = &values[..values.len() - 1];
327                            format!("Input should be {} or {}", rest.join(", "), last)
328                        } else if !values.is_empty() {
329                            format!("Input should be {}", values[0])
330                        } else {
331                            "Input should be one of the allowed values".to_string()
332                        };
333
334                        let expected_str = if values.len() > 1 {
335                            let last = values.last().unwrap();
336                            let rest = &values[..values.len() - 1];
337                            format!("{} or {}", rest.join(", "), last)
338                        } else if !values.is_empty() {
339                            values[0].clone()
340                        } else {
341                            "allowed values".to_string()
342                        };
343                        let ctx = serde_json::json!({"expected": expected_str});
344                        ("enum".to_string(), msg, Some(ctx))
345                    } else {
346                        (
347                            "enum".to_string(),
348                            "Input should be one of the allowed values".to_string(),
349                            None,
350                        )
351                    }
352                } else if schema_path_str.contains("pattern") || error_msg.contains("does not match") {
353                    if let Some(pattern) = self
354                        .schema
355                        .pointer(&format!("{}/pattern", schema_prop_path))
356                        .and_then(|v| v.as_str())
357                    {
358                        let ctx = serde_json::json!({"pattern": pattern});
359                        let msg = format!("String should match pattern '{}'", pattern);
360                        ("string_pattern_mismatch".to_string(), msg, Some(ctx))
361                    } else {
362                        (
363                            "string_pattern_mismatch".to_string(),
364                            "String does not match expected pattern".to_string(),
365                            None,
366                        )
367                    }
368                } else if schema_path_str.contains("format") {
369                    if error_msg.contains("email") {
370                        let email_pattern = r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$";
371                        let ctx = serde_json::json!({"pattern": email_pattern});
372                        (
373                            "string_pattern_mismatch".to_string(),
374                            format!("String should match pattern '{}'", email_pattern),
375                            Some(ctx),
376                        )
377                    } else if error_msg.contains("uuid") {
378                        (
379                            "uuid_parsing".to_string(),
380                            "Input should be a valid UUID".to_string(),
381                            None,
382                        )
383                    } else if error_msg.contains("date-time") {
384                        (
385                            "datetime_parsing".to_string(),
386                            "Input should be a valid datetime".to_string(),
387                            None,
388                        )
389                    } else if error_msg.contains("date") {
390                        (
391                            "date_parsing".to_string(),
392                            "Input should be a valid date".to_string(),
393                            None,
394                        )
395                    } else {
396                        ("format_error".to_string(), err.to_string(), None)
397                    }
398                } else if schema_path_str.contains("/type") {
399                    let expected_type = self
400                        .schema
401                        .pointer(&format!("{}/type", schema_prop_path))
402                        .and_then(|v| v.as_str())
403                        .unwrap_or("unknown");
404
405                    let (error_type, msg) = match expected_type {
406                        "integer" => (
407                            "int_parsing".to_string(),
408                            "Input should be a valid integer, unable to parse string as an integer".to_string(),
409                        ),
410                        "number" => (
411                            "float_parsing".to_string(),
412                            "Input should be a valid number, unable to parse string as a number".to_string(),
413                        ),
414                        "boolean" => (
415                            "bool_parsing".to_string(),
416                            "Input should be a valid boolean".to_string(),
417                        ),
418                        "string" => ("string_type".to_string(), "Input should be a valid string".to_string()),
419                        _ => (
420                            "type_error".to_string(),
421                            format!("Input should be a valid {}", expected_type),
422                        ),
423                    };
424                    (error_type, msg, None)
425                } else if schema_path_str.ends_with("/required") {
426                    ("missing".to_string(), "Field required".to_string(), None)
427                } else if schema_path_str.contains("/additionalProperties")
428                    || error_msg.contains("Additional properties are not allowed")
429                {
430                    let unexpected_field = if param_name.contains('/') {
431                        param_name.split('/').next_back().unwrap_or(&param_name).to_string()
432                    } else {
433                        param_name.clone()
434                    };
435
436                    let ctx = serde_json::json!({
437                        "additional_properties": false,
438                        "unexpected_field": unexpected_field
439                    });
440                    (
441                        "validation_error".to_string(),
442                        "Additional properties are not allowed".to_string(),
443                        Some(ctx),
444                    )
445                } else if schema_path_str.contains("/minItems") {
446                    let min_items = if let Some(start) = schema_path_str.rfind('/') {
447                        if let Some(_min_idx) = schema_path_str[..start].rfind("/minItems") {
448                            1
449                        } else {
450                            1
451                        }
452                    } else {
453                        1
454                    };
455
456                    let ctx = serde_json::json!({
457                        "min_length": min_items
458                    });
459                    (
460                        "too_short".to_string(),
461                        format!("List should have at least {} item after validation", min_items),
462                        Some(ctx),
463                    )
464                } else if schema_path_str.contains("/maxItems") {
465                    let ctx = serde_json::json!({
466                        "max_length": 1
467                    });
468                    (
469                        "too_long".to_string(),
470                        "List should have at most N items after validation".to_string(),
471                        Some(ctx),
472                    )
473                } else {
474                    ("validation_error".to_string(), err.to_string(), None)
475                };
476
477                ValidationErrorDetail {
478                    error_type,
479                    loc: loc_parts,
480                    msg,
481                    input: input_value,
482                    ctx,
483                }
484            })
485            .collect();
486
487        debug_log_module!("validation", "Returning {} validation errors", errors.len());
488        for (i, error) in errors.iter().enumerate() {
489            debug_log_module!(
490                "validation",
491                "  Error {}: type={}, loc={:?}, msg={}, input={}, ctx={:?}",
492                i,
493                error.error_type,
494                error.loc,
495                error.msg,
496                error.input,
497                error.ctx
498            );
499        }
500        #[allow(clippy::collapsible_if)]
501        if crate::debug::is_enabled() {
502            if let Ok(json_errors) = serde_json::to_value(&errors) {
503                if let Ok(json_str) = serde_json::to_string_pretty(&json_errors) {
504                    debug_log_module!("validation", "Serialized errors:\n{}", json_str);
505                }
506            }
507        }
508
509        Err(ValidationError { errors })
510    }
511
512    /// Validate and parse JSON bytes
513    pub fn validate_json(&self, json_bytes: &[u8]) -> Result<Value, ValidationError> {
514        let value: Value = serde_json::from_slice(json_bytes).map_err(|e| ValidationError {
515            errors: vec![ValidationErrorDetail {
516                error_type: "json_parse_error".to_string(),
517                loc: vec!["body".to_string()],
518                msg: format!("Invalid JSON: {}", e),
519                input: Value::Null,
520                ctx: None,
521            }],
522        })?;
523
524        self.validate(&value)?;
525
526        Ok(value)
527    }
528}
529
530/// Validation error containing one or more validation failures
531#[derive(Debug, Clone)]
532pub struct ValidationError {
533    pub errors: Vec<ValidationErrorDetail>,
534}
535
536/// Individual validation error detail (FastAPI-compatible format)
537#[derive(Debug, Clone, serde::Serialize)]
538pub struct ValidationErrorDetail {
539    #[serde(rename = "type")]
540    pub error_type: String,
541    pub loc: Vec<String>,
542    pub msg: String,
543    pub input: Value,
544    #[serde(skip_serializing_if = "Option::is_none")]
545    pub ctx: Option<Value>,
546}
547
548impl std::fmt::Display for ValidationError {
549    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
550        write!(f, "Validation failed: {} errors", self.errors.len())
551    }
552}
553
554impl std::error::Error for ValidationError {}
555
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Find the single error reported at `loc`, failing the test if it is absent.
    fn error_at<'a>(err: &'a ValidationError, loc: &[&str]) -> &'a ValidationErrorDetail {
        err.errors
            .iter()
            .find(|detail| detail.loc == loc)
            .unwrap_or_else(|| panic!("no error at {:?} in {:?}", loc, err.errors))
    }

    #[test]
    fn test_validator_creation() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"}
            },
            "required": ["name"]
        });

        let validator = SchemaValidator::new(schema).unwrap();
        assert!(validator.compiled.is_valid(&json!({"name": "Alice", "age": 30})));
    }

    #[test]
    fn test_invalid_schema_is_rejected() {
        let result = SchemaValidator::new(json!({"type": 12}));
        let message = result.err().expect("a non-string/array `type` must not compile");
        assert!(message.starts_with("Schema compilation failed: "), "{}", message);
    }

    #[test]
    fn test_validation_success() {
        let schema = json!({
            "type": "object",
            "properties": {
                "email": {"type": "string", "format": "email"}
            }
        });

        let validator = SchemaValidator::new(schema).unwrap();
        let data = json!({"email": "test@example.com"});

        assert!(validator.validate(&data).is_ok());
    }

    #[test]
    fn test_validation_failure() {
        let schema = json!({
            "type": "object",
            "properties": {
                "age": {"type": "integer", "minimum": 0}
            },
            "required": ["age"]
        });

        let validator = SchemaValidator::new(schema).unwrap();
        let data = json!({"age": -5});

        let err = validator.validate(&data).unwrap_err();
        assert_eq!(err.errors.len(), 1);

        let detail = &err.errors[0];
        assert_eq!(detail.error_type, "greater_than_equal");
        assert_eq!(detail.loc, vec!["body", "age"]);
        assert_eq!(detail.msg, "Input should be greater than or equal to 0");
        assert_eq!(detail.input, json!(-5));
        assert_eq!(detail.ctx, Some(json!({"ge": 0})));
    }

    #[test]
    fn test_missing_required_field() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            },
            "required": ["name"]
        });

        let validator = SchemaValidator::new(schema).unwrap();
        let err = validator.validate(&json!({})).unwrap_err();
        assert_eq!(err.errors.len(), 1);

        let detail = &err.errors[0];
        assert_eq!(detail.error_type, "missing");
        assert_eq!(detail.loc, vec!["body", "name"]);
        assert_eq!(detail.msg, "Field required");
        assert_eq!(detail.input, json!({}));
        assert_eq!(detail.ctx, None);
    }

    #[test]
    fn test_validation_error_serialization() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string",
                    "maxLength": 10
                }
            },
            "required": ["name"]
        });

        let validator = SchemaValidator::new(schema).unwrap();
        let data = json!({"name": "this_is_way_too_long"});

        let err = validator.validate(&data).unwrap_err();
        assert_eq!(err.errors.len(), 1);

        let error_detail = &err.errors[0];
        assert_eq!(error_detail.error_type, "string_too_long");
        assert_eq!(error_detail.loc, vec!["body", "name"]);
        assert_eq!(error_detail.msg, "String should have at most 10 characters");
        assert_eq!(error_detail.input, Value::String("this_is_way_too_long".to_string()));
        assert_eq!(error_detail.ctx, Some(json!({"max_length": 10})));

        let json_output = serde_json::to_value(&err.errors).unwrap();
        let serialized_error = &json_output[0];
        assert!(serialized_error.get("type").is_some());
        assert!(serialized_error.get("loc").is_some());
        assert!(serialized_error.get("msg").is_some());
        assert!(
            serialized_error.get("input").is_some(),
            "Missing 'input' field in serialized JSON!"
        );
        assert!(
            serialized_error.get("ctx").is_some(),
            "Missing 'ctx' field in serialized JSON!"
        );

        assert_eq!(
            serialized_error["input"],
            Value::String("this_is_way_too_long".to_string())
        );
        assert_eq!(serialized_error["ctx"], json!({"max_length": 10}));
    }

    #[test]
    fn test_exclusive_minimum() {
        let schema = json!({
            "$schema": "https://json-schema.org/draft/2020-12/schema",
            "type": "object",
            "required": ["id", "name", "price"],
            "properties": {
                "id": {
                    "type": "integer"
                },
                "name": {
                    "type": "string",
                    "minLength": 3
                },
                "price": {
                    "type": "number",
                    "exclusiveMinimum": 0
                }
            }
        });

        let validator = SchemaValidator::new(schema).unwrap();

        let data = json!({
            "id": 1,
            "name": "X",
            "price": -10
        });

        let err = validator.validate(&data).expect_err("Should have validation errors");
        assert_eq!(err.errors.len(), 2, "Should have 2 errors: {:?}", err.errors);

        let name_error = error_at(&err, &["body", "name"]);
        assert_eq!(name_error.error_type, "string_too_short");
        assert_eq!(name_error.msg, "String should have at least 3 characters");
        assert_eq!(name_error.ctx, Some(json!({"min_length": 3})));

        let price_error = error_at(&err, &["body", "price"]);
        assert_eq!(price_error.error_type, "greater_than");
        assert_eq!(price_error.msg, "Input should be greater than 0");
        assert_eq!(price_error.ctx, Some(json!({"gt": 0})));
    }

    #[test]
    fn test_binary_file_object_is_validated_as_string() {
        let schema = json!({
            "type": "object",
            "properties": {
                "file": {"type": "string", "format": "binary"}
            },
            "required": ["file"]
        });

        let validator = SchemaValidator::new(schema).unwrap();
        let upload = json!({
            "file": {
                "filename": "a.txt",
                "size": 3,
                "content": "abc",
                "content_type": "text/plain"
            }
        });
        assert!(validator.validate(&upload).is_ok());

        let err = validator.validate(&json!({"file": 5})).unwrap_err();
        assert_eq!(err.errors.len(), 1);
        assert_eq!(err.errors[0].error_type, "string_type");
        assert_eq!(err.errors[0].loc, vec!["body", "file"]);
    }
}