Skip to main content

spikard_core/validation/
mod.rs

1//! Request/response validation using JSON Schema
2
3pub mod error_mapper;
4
5use crate::debug_log_module;
6use jsonschema::Validator;
7use serde_json::Value;
8use std::sync::Arc;
9
10use self::error_mapper::{ErrorCondition, ErrorMapper};
11
12/// Schema validator that compiles and validates JSON Schema
13#[derive(Clone)]
14pub struct SchemaValidator {
15    compiled: Arc<Validator>,
16    schema: Value,
17}
18
19impl SchemaValidator {
20    /// Create a new validator from a JSON Schema
21    ///
22    /// # Errors
23    /// Returns an error if the schema is invalid or compilation fails.
24    pub fn new(schema: Value) -> Result<Self, String> {
25        let compiled = jsonschema::options()
26            .with_draft(jsonschema::Draft::Draft202012)
27            .should_validate_formats(true)
28            .with_pattern_options(jsonschema::PatternOptions::regex())
29            .build(&schema)
30            .map_err(|e| {
31                anyhow::anyhow!("Invalid JSON Schema")
32                    .context(format!("Schema compilation failed: {e}"))
33                    .to_string()
34            })?;
35
36        Ok(Self {
37            compiled: Arc::new(compiled),
38            schema,
39        })
40    }
41
42    /// Get the underlying JSON Schema
43    #[must_use]
44    pub const fn schema(&self) -> &Value {
45        &self.schema
46    }
47
48    /// Pre-process data to convert file objects to strings for format: `binary` validation
49    ///
50    /// Files uploaded via multipart are converted to objects like:
51    /// `{"filename": "...", "size": N, "content": "...", "content_type": "..."}`
52    ///
53    /// But schemas define them as: `{"type": "string", "format": "binary"}`
54    ///
55    /// This method recursively processes the data and converts file objects to their content strings
56    /// so that validation passes, while preserving the original structure for handlers to use.
57    fn preprocess_binary_fields(&self, data: &Value) -> Value {
58        self.preprocess_value_with_schema(data, &self.schema)
59    }
60
61    #[allow(clippy::only_used_in_recursion, clippy::self_only_used_in_recursion)]
62    fn preprocess_value_with_schema(&self, data: &Value, schema: &Value) -> Value {
63        if let Some(schema_obj) = schema.as_object() {
64            let is_string_type = schema_obj.get("type").and_then(|t| t.as_str()) == Some("string");
65            let is_binary_format = schema_obj.get("format").and_then(|f| f.as_str()) == Some("binary");
66
67            #[allow(clippy::collapsible_if)]
68            if is_string_type && is_binary_format {
69                if let Some(data_obj) = data.as_object() {
70                    if data_obj.contains_key("filename")
71                        && data_obj.contains_key("content")
72                        && data_obj.contains_key("size")
73                        && data_obj.contains_key("content_type")
74                    {
75                        return data_obj.get("content").unwrap_or(&Value::Null).clone();
76                    }
77                }
78                return data.clone();
79            }
80
81            #[allow(clippy::collapsible_if)]
82            if schema_obj.get("type").and_then(|t| t.as_str()) == Some("array") {
83                if let Some(items_schema) = schema_obj.get("items") {
84                    if let Some(data_array) = data.as_array() {
85                        let processed_array: Vec<Value> = data_array
86                            .iter()
87                            .map(|item| self.preprocess_value_with_schema(item, items_schema))
88                            .collect();
89                        return Value::Array(processed_array);
90                    }
91                }
92            }
93
94            #[allow(clippy::collapsible_if)]
95            if schema_obj.get("type").and_then(|t| t.as_str()) == Some("object") {
96                if let Some(properties) = schema_obj.get("properties").and_then(|p| p.as_object()) {
97                    if let Some(data_obj) = data.as_object() {
98                        let mut processed_obj = serde_json::Map::new();
99                        for (key, value) in data_obj {
100                            if let Some(prop_schema) = properties.get(key) {
101                                processed_obj
102                                    .insert(key.clone(), self.preprocess_value_with_schema(value, prop_schema));
103                            } else {
104                                processed_obj.insert(key.clone(), value.clone());
105                            }
106                        }
107                        return Value::Object(processed_obj);
108                    }
109                }
110            }
111        }
112
113        data.clone()
114    }
115
116    /// Validate JSON data against the schema
117    ///
118    /// # Errors
119    /// Returns a `ValidationError` if the data does not conform to the schema.
120    ///
121    /// # Too Many Lines
122    /// This function is complex due to error mapping logic.
123    #[allow(clippy::option_if_let_else, clippy::uninlined_format_args, clippy::too_many_lines)]
124    pub fn validate(&self, data: &Value) -> Result<(), ValidationError> {
125        let processed_data = self.preprocess_binary_fields(data);
126
127        let validation_errors: Vec<_> = self.compiled.iter_errors(&processed_data).collect();
128
129        if validation_errors.is_empty() {
130            return Ok(());
131        }
132
133        let errors: Vec<ValidationErrorDetail> = validation_errors
134            .into_iter()
135            .map(|err| {
136                let instance_path = err.instance_path().to_string();
137                let schema_path_str = err.schema_path().as_str();
138                let error_msg = err.to_string();
139
140                let param_name = if schema_path_str.ends_with("/required") {
141                    let field_name = if let Some(start) = error_msg.find('"') {
142                        if let Some(end) = error_msg[start + 1..].find('"') {
143                            error_msg[start + 1..start + 1 + end].to_string()
144                        } else {
145                            String::new()
146                        }
147                    } else {
148                        String::new()
149                    };
150
151                    if instance_path.starts_with('/') && instance_path.len() > 1 {
152                        let base_path = &instance_path[1..];
153                        if field_name.is_empty() {
154                            base_path.to_string()
155                        } else {
156                            format!("{base_path}/{field_name}")
157                        }
158                    } else if field_name.is_empty() {
159                        "body".to_string()
160                    } else {
161                        field_name
162                    }
163                } else if schema_path_str.contains("/additionalProperties") {
164                    if let Some(start) = error_msg.find('(') {
165                        if let Some(quote_start) = error_msg[start..].find('\'') {
166                            let abs_start = start + quote_start + 1;
167                            error_msg[abs_start..].find('\'').map_or_else(
168                                || instance_path[1..].to_string(),
169                                |quote_end| {
170                                    let property_name = error_msg[abs_start..abs_start + quote_end].to_string();
171                                    if instance_path.starts_with('/') && instance_path.len() > 1 {
172                                        format!("{}/{property_name}", &instance_path[1..])
173                                    } else {
174                                        property_name
175                                    }
176                                },
177                            )
178                        } else {
179                            instance_path[1..].to_string()
180                        }
181                    } else if instance_path.starts_with('/') && instance_path.len() > 1 {
182                        instance_path[1..].to_string()
183                    } else {
184                        "body".to_string()
185                    }
186                } else if instance_path.starts_with('/') && instance_path.len() > 1 {
187                    instance_path[1..].to_string()
188                } else if instance_path.is_empty() {
189                    "body".to_string()
190                } else {
191                    instance_path
192                };
193
194                let loc_parts: Vec<String> = if param_name.contains('/') {
195                    let mut parts = vec!["body".to_string()];
196                    parts.extend(param_name.split('/').map(ToString::to_string));
197                    parts
198                } else if param_name == "body" {
199                    vec!["body".to_string()]
200                } else {
201                    vec!["body".to_string(), param_name.clone()]
202                };
203
204                let input_value = if schema_path_str == "/required" {
205                    data.clone()
206                } else {
207                    err.instance().clone().into_owned()
208                };
209
210                let schema_prop_path = if param_name.contains('/') {
211                    format!("/properties/{}", param_name.replace('/', "/properties/"))
212                } else {
213                    format!("/properties/{param_name}")
214                };
215
216                let mut error_condition = ErrorCondition::from_schema_error(schema_path_str, &error_msg);
217
218                error_condition = match error_condition {
219                    ErrorCondition::TypeMismatch { .. } => {
220                        let expected_type = self
221                            .schema
222                            .pointer(&format!("{schema_prop_path}/type"))
223                            .and_then(|v| v.as_str())
224                            .unwrap_or("unknown")
225                            .to_string();
226                        ErrorCondition::TypeMismatch { expected_type }
227                    }
228                    ErrorCondition::AdditionalProperties { .. } => {
229                        #[allow(clippy::redundant_clone)]
230                        let unexpected_field = if param_name.contains('/') {
231                            param_name.split('/').next_back().unwrap_or(&param_name).to_string()
232                        } else {
233                            param_name.clone()
234                        };
235                        ErrorCondition::AdditionalProperties {
236                            field: unexpected_field,
237                        }
238                    }
239                    other => other,
240                };
241
242                let (error_type, msg, ctx) =
243                    ErrorMapper::map_error(&error_condition, &self.schema, &schema_prop_path, &error_msg);
244
245                ValidationErrorDetail {
246                    error_type,
247                    loc: loc_parts,
248                    msg,
249                    input: input_value,
250                    ctx,
251                }
252            })
253            .collect();
254
255        debug_log_module!("validation", "Returning {} validation errors", errors.len());
256        for (i, error) in errors.iter().enumerate() {
257            debug_log_module!(
258                "validation",
259                "  Error {}: type={}, loc={:?}, msg={}, input={}, ctx={:?}",
260                i,
261                error.error_type,
262                error.loc,
263                error.msg,
264                error.input,
265                error.ctx
266            );
267        }
268        #[allow(clippy::collapsible_if)]
269        if crate::debug::is_enabled() {
270            if let Ok(json_errors) = serde_json::to_value(&errors) {
271                if let Ok(json_str) = serde_json::to_string_pretty(&json_errors) {
272                    debug_log_module!("validation", "Serialized errors:\n{}", json_str);
273                }
274            }
275        }
276
277        Err(ValidationError { errors })
278    }
279
280    /// Validate and parse JSON bytes
281    ///
282    /// # Errors
283    /// Returns a validation error if the JSON is invalid or fails validation against the schema.
284    pub fn validate_json(&self, json_bytes: &[u8]) -> Result<Value, ValidationError> {
285        let value: Value = serde_json::from_slice(json_bytes).map_err(|e| ValidationError {
286            errors: vec![ValidationErrorDetail {
287                error_type: "json_parse_error".to_string(),
288                loc: vec!["body".to_string()],
289                msg: format!("Invalid JSON: {e}"),
290                input: Value::Null,
291                ctx: None,
292            }],
293        })?;
294
295        self.validate(&value)?;
296
297        Ok(value)
298    }
299}
300
301/// Validation error containing one or more validation failures
302#[derive(Debug, Clone)]
303pub struct ValidationError {
304    pub errors: Vec<ValidationErrorDetail>,
305}
306
307/// Individual validation error detail (FastAPI-compatible format)
308#[derive(Debug, Clone, serde::Serialize)]
309pub struct ValidationErrorDetail {
310    #[serde(rename = "type")]
311    pub error_type: String,
312    pub loc: Vec<String>,
313    pub msg: String,
314    pub input: Value,
315    #[serde(skip_serializing_if = "Option::is_none")]
316    pub ctx: Option<Value>,
317}
318
319impl std::fmt::Display for ValidationError {
320    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
321        write!(f, "Validation failed: {} errors", self.errors.len())
322    }
323}
324
325impl std::error::Error for ValidationError {}
326
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Compile `schema`, failing the test immediately if it is rejected.
    fn validator_for(schema: Value) -> SchemaValidator {
        SchemaValidator::new(schema).unwrap()
    }

    /// A valid schema compiles and the compiled validator accepts a conforming document.
    #[test]
    fn test_validator_creation() {
        let validator = validator_for(json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"}
            },
            "required": ["name"]
        }));

        let person = json!({"name": "Alice", "age": 30});
        assert!(validator.compiled.is_valid(&person));
    }

    /// A well-formed e-mail address passes `format: email` validation.
    #[test]
    fn test_validation_success() {
        let validator = validator_for(json!({
            "type": "object",
            "properties": {
                "email": {"type": "string", "format": "email"}
            }
        }));

        let outcome = validator.validate(&json!({"email": "test@example.com"}));
        assert!(outcome.is_ok());
    }

    /// A value below `minimum` is rejected.
    #[test]
    fn test_validation_failure() {
        let validator = validator_for(json!({
            "type": "object",
            "properties": {
                "age": {"type": "integer", "minimum": 0}
            },
            "required": ["age"]
        }));

        let outcome = validator.validate(&json!({"age": -5}));
        assert!(outcome.is_err());
    }

    /// A `maxLength` violation yields one fully populated, serializable error detail.
    #[test]
    fn test_validation_error_serialization() {
        let validator = validator_for(json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string",
                    "maxLength": 10
                }
            },
            "required": ["name"]
        }));
        let too_long = Value::String("this_is_way_too_long".to_string());
        let expected_ctx = json!({"max_length": 10});

        let result = validator.validate(&json!({"name": "this_is_way_too_long"}));
        assert!(result.is_err());

        let err = result.unwrap_err();
        assert_eq!(err.errors.len(), 1);

        let detail = &err.errors[0];
        assert_eq!(detail.error_type, "string_too_long");
        assert_eq!(detail.loc, vec!["body", "name"]);
        assert_eq!(detail.msg, "String should have at most 10 characters");
        assert_eq!(detail.input, too_long);
        assert_eq!(detail.ctx, Some(expected_ctx.clone()));

        let json_output = serde_json::to_value(&err.errors).unwrap();
        println!(
            "Serialized JSON: {}",
            serde_json::to_string_pretty(&json_output).unwrap()
        );

        let serialized = &json_output[0];
        for key in ["type", "loc", "msg"] {
            assert!(serialized.get(key).is_some());
        }
        assert!(
            serialized.get("input").is_some(),
            "Missing 'input' field in serialized JSON!"
        );
        assert!(
            serialized.get("ctx").is_some(),
            "Missing 'ctx' field in serialized JSON!"
        );

        assert_eq!(serialized["input"], too_long);
        assert_eq!(serialized["ctx"], expected_ctx);
    }

    /// `minLength` and `exclusiveMinimum` violations are both reported for one document.
    #[test]
    fn test_exclusive_minimum() {
        let validator = validator_for(json!({
            "$schema": "https://json-schema.org/draft/2020-12/schema",
            "type": "object",
            "required": ["id", "name", "price"],
            "properties": {
                "id": {
                    "type": "integer"
                },
                "name": {
                    "type": "string",
                    "minLength": 3
                },
                "price": {
                    "type": "number",
                    "exclusiveMinimum": 0
                }
            }
        }));

        let product = json!({
            "id": 1,
            "name": "X",
            "price": -10
        });

        let result = validator.validate(&product);
        eprintln!("Validation result: {result:?}");

        assert!(result.is_err(), "Should have validation errors");
        let err = result.unwrap_err();
        eprintln!("Errors: {:?}", err.errors);
        assert_eq!(err.errors.len(), 2, "Should have 2 errors");
    }
}
470}