Skip to main content

spikard_core/validation/
mod.rs

1//! Request/response validation using JSON Schema
2
3pub mod error_mapper;
4
5use jsonschema::Validator;
6use serde_json::Value;
7use std::sync::Arc;
8
9use self::error_mapper::{ErrorCondition, ErrorMapper};
10
11/// Schema validator that compiles and validates JSON Schema
12#[derive(Clone)]
13pub struct SchemaValidator {
14    compiled: Arc<Validator>,
15    schema: Value,
16}
17
18impl std::fmt::Debug for SchemaValidator {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        f.debug_struct("SchemaValidator")
21            .field("schema", &self.schema)
22            .finish_non_exhaustive()
23    }
24}
25
26impl SchemaValidator {
27    /// Create a new validator from a JSON Schema
28    ///
29    /// # Errors
30    /// Returns an error if the schema is invalid or compilation fails.
31    pub fn new(schema: Value) -> Result<Self, String> {
32        let compiled = jsonschema::options()
33            .with_draft(jsonschema::Draft::Draft202012)
34            .should_validate_formats(true)
35            .with_pattern_options(jsonschema::PatternOptions::regex())
36            .build(&schema)
37            .map_err(|e| {
38                anyhow::anyhow!("Invalid JSON Schema")
39                    .context(format!("Schema compilation failed: {e}"))
40                    .to_string()
41            })?;
42
43        Ok(Self {
44            compiled: Arc::new(compiled),
45            schema,
46        })
47    }
48
49    /// Get the underlying JSON Schema
50    #[must_use]
51    pub const fn schema(&self) -> &Value {
52        &self.schema
53    }
54
55    /// Pre-process data to convert file objects to strings for format: `binary` validation
56    ///
57    /// Files uploaded via multipart are converted to objects like:
58    /// `{"filename": "...", "size": N, "content": "...", "content_type": "..."}`
59    ///
60    /// But schemas define them as: `{"type": "string", "format": "binary"}`
61    ///
62    /// This method recursively processes the data and converts file objects to their content strings
63    /// so that validation passes, while preserving the original structure for handlers to use.
64    fn preprocess_binary_fields(&self, data: &Value) -> Value {
65        self.preprocess_value_with_schema(data, &self.schema)
66    }
67
68    // reason: &self is needed to make recursive calls; the method carries schema state
69    // and may gain direct field access in future without breaking callers.
70    #[allow(clippy::only_used_in_recursion, clippy::self_only_used_in_recursion)]
71    fn preprocess_value_with_schema(&self, data: &Value, schema: &Value) -> Value {
72        if let Some(schema_obj) = schema.as_object() {
73            let is_string_type = schema_obj.get("type").and_then(|t| t.as_str()) == Some("string");
74            let is_binary_format = schema_obj.get("format").and_then(|f| f.as_str()) == Some("binary");
75
76            if is_string_type && is_binary_format {
77                if let Some(data_obj) = data.as_object()
78                    && data_obj.contains_key("filename")
79                    && data_obj.contains_key("content")
80                    && data_obj.contains_key("size")
81                    && data_obj.contains_key("content_type")
82                {
83                    return data_obj.get("content").unwrap_or(&Value::Null).clone();
84                }
85                return data.clone();
86            }
87
88            if schema_obj.get("type").and_then(|t| t.as_str()) == Some("array")
89                && let Some(items_schema) = schema_obj.get("items")
90                && let Some(data_array) = data.as_array()
91            {
92                let processed_array: Vec<Value> = data_array
93                    .iter()
94                    .map(|item| self.preprocess_value_with_schema(item, items_schema))
95                    .collect();
96                return Value::Array(processed_array);
97            }
98
99            if schema_obj.get("type").and_then(|t| t.as_str()) == Some("object")
100                && let Some(properties) = schema_obj.get("properties").and_then(|p| p.as_object())
101                && let Some(data_obj) = data.as_object()
102            {
103                let mut processed_obj = serde_json::Map::new();
104                for (key, value) in data_obj {
105                    if let Some(prop_schema) = properties.get(key) {
106                        processed_obj.insert(key.clone(), self.preprocess_value_with_schema(value, prop_schema));
107                    } else {
108                        processed_obj.insert(key.clone(), value.clone());
109                    }
110                }
111                return Value::Object(processed_obj);
112            }
113        }
114
115        data.clone()
116    }
117
118    /// Validate JSON data against the schema
119    ///
120    /// # Errors
121    /// Returns a `ValidationError` if the data does not conform to the schema.
122    ///
123    /// # Too Many Lines
124    /// This function is complex due to error mapping logic.
125    // reason: option_if_let_else — deeply nested closures harm readability here;
126    // uninlined_format_args — several format strings use variables that must remain
127    // separate for clarity; too_many_lines — error-mapping pipeline is inherently long.
128    #[allow(clippy::option_if_let_else, clippy::uninlined_format_args, clippy::too_many_lines)]
129    pub fn validate(&self, data: &Value) -> Result<(), ValidationError> {
130        let processed_data = self.preprocess_binary_fields(data);
131
132        let validation_errors: Vec<_> = self.compiled.iter_errors(&processed_data).collect();
133
134        if validation_errors.is_empty() {
135            return Ok(());
136        }
137
138        let errors: Vec<ValidationErrorDetail> = validation_errors
139            .into_iter()
140            .map(|err| {
141                let instance_path = err.instance_path().to_string();
142                let schema_path_str = err.schema_path().as_str();
143                let error_msg = err.to_string();
144
145                let param_name = if schema_path_str.ends_with("/required") {
146                    let field_name = if let Some(start) = error_msg.find('"') {
147                        if let Some(end) = error_msg[start + 1..].find('"') {
148                            error_msg[start + 1..start + 1 + end].to_string()
149                        } else {
150                            String::new()
151                        }
152                    } else {
153                        String::new()
154                    };
155
156                    if instance_path.starts_with('/') && instance_path.len() > 1 {
157                        let base_path = &instance_path[1..];
158                        if field_name.is_empty() {
159                            base_path.to_string()
160                        } else {
161                            format!("{base_path}/{field_name}")
162                        }
163                    } else if field_name.is_empty() {
164                        "body".to_string()
165                    } else {
166                        field_name
167                    }
168                } else if schema_path_str.contains("/additionalProperties") {
169                    if let Some(start) = error_msg.find('(') {
170                        if let Some(quote_start) = error_msg[start..].find('\'') {
171                            let abs_start = start + quote_start + 1;
172                            error_msg[abs_start..].find('\'').map_or_else(
173                                || instance_path[1..].to_string(),
174                                |quote_end| {
175                                    let property_name = error_msg[abs_start..abs_start + quote_end].to_string();
176                                    if instance_path.starts_with('/') && instance_path.len() > 1 {
177                                        format!("{}/{property_name}", &instance_path[1..])
178                                    } else {
179                                        property_name
180                                    }
181                                },
182                            )
183                        } else {
184                            instance_path[1..].to_string()
185                        }
186                    } else if instance_path.starts_with('/') && instance_path.len() > 1 {
187                        instance_path[1..].to_string()
188                    } else {
189                        "body".to_string()
190                    }
191                } else if instance_path.starts_with('/') && instance_path.len() > 1 {
192                    instance_path[1..].to_string()
193                } else if instance_path.is_empty() {
194                    "body".to_string()
195                } else {
196                    instance_path
197                };
198
199                let loc_parts: Vec<String> = if param_name.contains('/') {
200                    let mut parts = vec!["body".to_string()];
201                    parts.extend(param_name.split('/').map(ToString::to_string));
202                    parts
203                } else if param_name == "body" {
204                    vec!["body".to_string()]
205                } else {
206                    vec!["body".to_string(), param_name.clone()]
207                };
208
209                let input_value = if schema_path_str == "/required" {
210                    data.clone()
211                } else {
212                    err.instance().clone().into_owned()
213                };
214
215                let schema_prop_path = if param_name.contains('/') {
216                    format!("/properties/{}", param_name.replace('/', "/properties/"))
217                } else {
218                    format!("/properties/{param_name}")
219                };
220
221                let mut error_condition = ErrorCondition::from_schema_error(schema_path_str, &error_msg);
222
223                error_condition = match error_condition {
224                    ErrorCondition::TypeMismatch { .. } => {
225                        let expected_type = self
226                            .schema
227                            .pointer(&format!("{schema_prop_path}/type"))
228                            .and_then(|v| v.as_str())
229                            .unwrap_or("unknown")
230                            .to_string();
231                        ErrorCondition::TypeMismatch { expected_type }
232                    }
233                    ErrorCondition::AdditionalProperties { .. } => {
234                        // reason: param_name is borrowed in the if-branch via unwrap_or(&param_name);
235                        // the else-branch therefore cannot move out of param_name and must clone.
236                        #[allow(clippy::redundant_clone)]
237                        let unexpected_field = if param_name.contains('/') {
238                            param_name.split('/').next_back().unwrap_or(&param_name).to_string()
239                        } else {
240                            param_name.clone()
241                        };
242                        ErrorCondition::AdditionalProperties {
243                            field: unexpected_field,
244                        }
245                    }
246                    other => other,
247                };
248
249                let (error_type, msg, ctx) =
250                    ErrorMapper::map_error(&error_condition, &self.schema, &schema_prop_path, &error_msg);
251
252                ValidationErrorDetail {
253                    error_type,
254                    loc: loc_parts,
255                    msg,
256                    input: input_value,
257                    ctx,
258                }
259            })
260            .collect();
261
262        Err(ValidationError { errors })
263    }
264
265    /// Validate and parse JSON bytes
266    ///
267    /// # Errors
268    /// Returns a validation error if the JSON is invalid or fails validation against the schema.
269    pub fn validate_json(&self, json_bytes: &[u8]) -> Result<Value, ValidationError> {
270        let value: Value = serde_json::from_slice(json_bytes).map_err(|e| ValidationError {
271            errors: vec![ValidationErrorDetail {
272                error_type: "json_parse_error".to_string(),
273                loc: vec!["body".to_string()],
274                msg: format!("Invalid JSON: {e}"),
275                input: Value::Null,
276                ctx: None,
277            }],
278        })?;
279
280        self.validate(&value)?;
281
282        Ok(value)
283    }
284}
285
286/// Validation error containing one or more validation failures
287#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
288pub struct ValidationError {
289    pub errors: Vec<ValidationErrorDetail>,
290}
291
292/// Individual validation error detail (FastAPI-compatible format)
293#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
294pub struct ValidationErrorDetail {
295    #[serde(rename = "type")]
296    pub error_type: String,
297    pub loc: Vec<String>,
298    pub msg: String,
299    pub input: Value,
300    #[serde(skip_serializing_if = "Option::is_none")]
301    pub ctx: Option<Value>,
302}
303
304impl std::fmt::Display for ValidationError {
305    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306        write!(f, "Validation failed: {} errors", self.errors.len())
307    }
308}
309
310impl std::error::Error for ValidationError {}
311
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Compile `schema`, panicking if it is rejected.
    fn compile(schema: Value) -> SchemaValidator {
        SchemaValidator::new(schema).unwrap()
    }

    /// A valid schema compiles and accepts a conforming document.
    #[test]
    fn test_validator_creation() {
        let validator = compile(json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"}
            },
            "required": ["name"]
        }));

        let document = json!({"name": "Alice", "age": 30});
        assert!(validator.compiled.is_valid(&document));
    }

    /// A well-formed email passes `format: email`.
    #[test]
    fn test_validation_success() {
        let validator = compile(json!({
            "type": "object",
            "properties": {
                "email": {"type": "string", "format": "email"}
            }
        }));

        let outcome = validator.validate(&json!({"email": "test@example.com"}));
        assert!(outcome.is_ok());
    }

    /// A value below `minimum` is rejected.
    #[test]
    fn test_validation_failure() {
        let validator = compile(json!({
            "type": "object",
            "properties": {
                "age": {"type": "integer", "minimum": 0}
            },
            "required": ["age"]
        }));

        let outcome = validator.validate(&json!({"age": -5}));
        assert!(outcome.is_err());
    }

    /// A `maxLength` violation maps to the expected detail, both in memory and
    /// after serialization.
    #[test]
    fn test_validation_error_serialization() {
        let validator = compile(json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string",
                    "maxLength": 10
                }
            },
            "required": ["name"]
        }));
        let too_long = Value::String("this_is_way_too_long".to_string());
        let expected_ctx = json!({"max_length": 10});

        let outcome = validator.validate(&json!({"name": "this_is_way_too_long"}));
        assert!(outcome.is_err());

        let failure = outcome.unwrap_err();
        assert_eq!(failure.errors.len(), 1);

        let detail = &failure.errors[0];
        assert_eq!(detail.error_type, "string_too_long");
        assert_eq!(detail.loc, vec!["body", "name"]);
        assert_eq!(detail.msg, "String should have at most 10 characters");
        assert_eq!(detail.input, too_long);
        assert_eq!(detail.ctx, Some(expected_ctx.clone()));

        let serialized = serde_json::to_value(&failure.errors).unwrap();
        println!(
            "Serialized JSON: {}",
            serde_json::to_string_pretty(&serialized).unwrap()
        );

        let first = &serialized[0];
        assert!(first.get("type").is_some());
        assert!(first.get("loc").is_some());
        assert!(first.get("msg").is_some());
        assert!(
            first.get("input").is_some(),
            "Missing 'input' field in serialized JSON!"
        );
        assert!(
            first.get("ctx").is_some(),
            "Missing 'ctx' field in serialized JSON!"
        );

        assert_eq!(first["input"], too_long);
        assert_eq!(first["ctx"], expected_ctx);
    }

    /// Two independent violations (`minLength` and `exclusiveMinimum`) are both
    /// reported.
    #[test]
    fn test_exclusive_minimum() {
        let validator = compile(json!({
            "$schema": "https://json-schema.org/draft/2020-12/schema",
            "type": "object",
            "required": ["id", "name", "price"],
            "properties": {
                "id": {
                    "type": "integer"
                },
                "name": {
                    "type": "string",
                    "minLength": 3
                },
                "price": {
                    "type": "number",
                    "exclusiveMinimum": 0
                }
            }
        }));

        let document = json!({
            "id": 1,
            "name": "X",
            "price": -10
        });

        let result = validator.validate(&document);
        eprintln!("Validation result: {result:?}");

        assert!(result.is_err(), "Should have validation errors");
        let failure = result.unwrap_err();
        eprintln!("Errors: {:?}", failure.errors);
        assert_eq!(failure.errors.len(), 2, "Should have 2 errors");
    }
}