sentinel_proxy/
validation.rs

1//! API schema validation module for Sentinel proxy
2//!
3//! This module provides JSON Schema validation for API routes,
4//! supporting both request and response validation with OpenAPI integration.
5
6use anyhow::{Context, Result};
7use bytes::Bytes;
8use http::{Request, Response, StatusCode};
9use http_body_util::{BodyExt, Full};
10use jsonschema::{Draft, JSONSchema, ValidationError};
11use serde::{Deserialize, Serialize};
12use serde_json::{json, Value};
13use std::collections::HashMap;
14use std::path::Path;
15use std::sync::Arc;
16use tracing::{debug, info, warn};
17
18use sentinel_config::ApiSchemaConfig;
19
20/// API schema validator
21pub struct SchemaValidator {
22    /// Configuration for schema validation
23    config: Arc<ApiSchemaConfig>,
24    /// Compiled request schema
25    request_schema: Option<Arc<JSONSchema>>,
26    /// Compiled response schema
27    response_schema: Option<Arc<JSONSchema>>,
28    /// OpenAPI specification (if loaded)
29    openapi_spec: Option<OpenApiSpec>,
30}
31
32/// OpenAPI specification
33#[derive(Debug, Clone, Deserialize)]
34struct OpenApiSpec {
35    openapi: String,
36    paths: HashMap<String, PathItem>,
37    components: Option<Components>,
38}
39
40/// OpenAPI path item
41#[derive(Debug, Clone, Deserialize)]
42struct PathItem {
43    #[serde(default)]
44    get: Option<Operation>,
45    #[serde(default)]
46    post: Option<Operation>,
47    #[serde(default)]
48    put: Option<Operation>,
49    #[serde(default)]
50    delete: Option<Operation>,
51    #[serde(default)]
52    patch: Option<Operation>,
53}
54
55/// OpenAPI operation
56#[derive(Debug, Clone, Deserialize)]
57struct Operation {
58    #[serde(rename = "operationId")]
59    operation_id: Option<String>,
60    #[serde(rename = "requestBody")]
61    request_body: Option<RequestBody>,
62    responses: HashMap<String, ApiResponse>,
63}
64
65/// OpenAPI request body
66#[derive(Debug, Clone, Deserialize)]
67struct RequestBody {
68    required: Option<bool>,
69    content: HashMap<String, MediaType>,
70}
71
72/// OpenAPI response
73#[derive(Debug, Clone, Deserialize)]
74struct ApiResponse {
75    description: String,
76    content: Option<HashMap<String, MediaType>>,
77}
78
79/// OpenAPI media type
80#[derive(Debug, Clone, Deserialize)]
81struct MediaType {
82    schema: Option<Value>,
83}
84
85/// OpenAPI components
86#[derive(Debug, Clone, Deserialize)]
87struct Components {
88    schemas: Option<HashMap<String, Value>>,
89}
90
/// Error body reported when a JSON payload cannot be parsed or fails schema validation.
///
/// Serialized with serde; the field names below are the JSON keys, so renaming
/// them changes the wire format.
#[derive(Debug, Serialize)]
pub struct ValidationErrorResponse {
    /// Short summary such as "Invalid JSON" or "Validation failed".
    pub error: String,
    /// HTTP status code carried in the body (this module always fills in 400).
    pub status: u16,
    /// One entry per problem found in the payload.
    pub validation_errors: Vec<ValidationErrorDetail>,
    /// Caller-supplied request identifier, copied into the body for correlation.
    pub request_id: String,
}
99
/// A single failure inside a [`ValidationErrorResponse`].
#[derive(Debug, Serialize)]
pub struct ValidationErrorDetail {
    /// Location of the failure, rendered from the validator's instance path;
    /// `"$"` denotes the document root (also used for JSON parse errors).
    pub field: String,
    /// Human-readable description of the problem.
    pub message: String,
    /// The offending value when it could be located in the document; `None`
    /// otherwise (e.g. when the body was not valid JSON).
    pub value: Option<Value>,
}
107
108impl SchemaValidator {
109    /// Create a new schema validator
110    pub fn new(config: ApiSchemaConfig) -> Result<Self> {
111        let mut validator = Self {
112            config: Arc::new(config.clone()),
113            request_schema: None,
114            response_schema: None,
115            openapi_spec: None,
116        };
117
118        // Load OpenAPI specification if provided
119        if let Some(ref schema_file) = config.schema_file {
120            validator.load_openapi_spec(schema_file)?;
121        }
122
123        // Compile request schema if provided
124        if let Some(ref schema) = config.request_schema {
125            validator.request_schema = Some(Arc::new(Self::compile_schema(schema)?));
126        }
127
128        // Compile response schema if provided
129        if let Some(ref schema) = config.response_schema {
130            validator.response_schema = Some(Arc::new(Self::compile_schema(schema)?));
131        }
132
133        Ok(validator)
134    }
135
136    /// Load OpenAPI specification from file
137    fn load_openapi_spec(&mut self, path: &Path) -> Result<()> {
138        let content = std::fs::read_to_string(path)
139            .with_context(|| format!("Failed to read OpenAPI spec: {:?}", path))?;
140
141        let spec: OpenApiSpec = if path
142            .extension()
143            .map_or(false, |e| e == "yaml" || e == "yml")
144        {
145            serde_yaml::from_str(&content)?
146        } else {
147            serde_json::from_str(&content)?
148        };
149
150        info!("Loaded OpenAPI specification from {:?}", path);
151        self.openapi_spec = Some(spec);
152        Ok(())
153    }
154
155    /// Compile a JSON schema
156    fn compile_schema(schema: &Value) -> Result<JSONSchema> {
157        JSONSchema::options()
158            .with_draft(Draft::Draft7)
159            .compile(schema)
160            .map_err(|e| anyhow::anyhow!("Failed to compile schema: {}", e))
161    }
162
163    /// Validate a request
164    pub async fn validate_request<B>(
165        &self,
166        request: &Request<B>,
167        body: &[u8],
168        path: &str,
169        request_id: &str,
170    ) -> Result<()> {
171        if !self.config.validate_requests {
172            return Ok(());
173        }
174
175        // Parse JSON body
176        let json_body: Value = if body.is_empty() {
177            json!(null)
178        } else {
179            serde_json::from_slice(body).map_err(|e| self.create_parsing_error(e, request_id))?
180        };
181
182        // Get the appropriate schema
183        let schema = if let Some(ref request_schema) = self.request_schema {
184            request_schema.clone()
185        } else if let Some(ref spec) = self.openapi_spec {
186            // Try to find schema from OpenAPI spec
187            match self.get_request_schema_from_spec(spec, path, request.method().as_str()) {
188                Some(s) => Arc::new(Self::compile_schema(&s)?),
189                None => {
190                    debug!("No schema found for {} {}", request.method(), path);
191                    return Ok(());
192                }
193            }
194        } else {
195            // No schema configured
196            return Ok(());
197        };
198
199        // Validate against schema
200        self.validate_against_schema(&schema, &json_body, request_id)?;
201
202        Ok(())
203    }
204
205    /// Validate a response
206    pub async fn validate_response(
207        &self,
208        status: StatusCode,
209        body: &[u8],
210        path: &str,
211        method: &str,
212        request_id: &str,
213    ) -> Result<()> {
214        if !self.config.validate_responses {
215            return Ok(());
216        }
217
218        // Parse JSON body
219        let json_body: Value = if body.is_empty() {
220            json!(null)
221        } else {
222            serde_json::from_slice(body).map_err(|e| self.create_parsing_error(e, request_id))?
223        };
224
225        // Get the appropriate schema
226        let schema = if let Some(ref response_schema) = self.response_schema {
227            response_schema.clone()
228        } else if let Some(ref spec) = self.openapi_spec {
229            // Try to find schema from OpenAPI spec
230            match self.get_response_schema_from_spec(spec, path, method, status.as_u16()) {
231                Some(s) => Arc::new(Self::compile_schema(&s)?),
232                None => {
233                    debug!(
234                        "No schema found for {} {} response {}",
235                        method, path, status
236                    );
237                    return Ok(());
238                }
239            }
240        } else {
241            // No schema configured
242            return Ok(());
243        };
244
245        // Validate against schema
246        self.validate_against_schema(&schema, &json_body, request_id)?;
247
248        Ok(())
249    }
250
251    /// Validate JSON against a schema
252    fn validate_against_schema(
253        &self,
254        schema: &JSONSchema,
255        instance: &Value,
256        request_id: &str,
257    ) -> Result<()> {
258        let result = schema.validate(instance);
259
260        if let Err(errors) = result {
261            let validation_errors: Vec<ValidationErrorDetail> = errors
262                .map(|error| self.format_validation_error(error, instance))
263                .collect();
264
265            if !validation_errors.is_empty() {
266                return Err(self.create_validation_error(validation_errors, request_id));
267            }
268        }
269
270        // Additional strict mode checks
271        if self.config.strict_mode {
272            self.strict_mode_checks(schema, instance, request_id)?;
273        }
274
275        Ok(())
276    }
277
278    /// Format a validation error
279    fn format_validation_error(
280        &self,
281        error: ValidationError,
282        instance: &Value,
283    ) -> ValidationErrorDetail {
284        let field = error.instance_path.to_string();
285        let field = if field.is_empty() {
286            "$".to_string()
287        } else {
288            field
289        };
290
291        let value = error
292            .instance_path
293            .iter()
294            .fold(Some(instance), |acc, segment| {
295                acc.and_then(|v| match segment {
296                    jsonschema::paths::PathChunk::Property(prop) => v.get(prop.as_ref()),
297                    jsonschema::paths::PathChunk::Index(idx) => v.get(idx),
298                    _ => None,
299                })
300            })
301            .cloned();
302
303        ValidationErrorDetail {
304            field,
305            message: error.to_string(),
306            value,
307        }
308    }
309
310    /// Perform strict mode checks
311    fn strict_mode_checks(
312        &self,
313        _schema: &JSONSchema,
314        instance: &Value,
315        _request_id: &str,
316    ) -> Result<()> {
317        // Check for null values
318        if self.has_null_values(instance) {
319            warn!("Strict mode: Found null values in JSON");
320        }
321
322        // Check for empty strings
323        if self.has_empty_strings(instance) {
324            warn!("Strict mode: Found empty strings in JSON");
325        }
326
327        Ok(())
328    }
329
330    /// Check if JSON contains null values
331    fn has_null_values(&self, value: &Value) -> bool {
332        match value {
333            Value::Null => true,
334            Value::Array(arr) => arr.iter().any(|v| self.has_null_values(v)),
335            Value::Object(obj) => obj.values().any(|v| self.has_null_values(v)),
336            _ => false,
337        }
338    }
339
340    /// Check if JSON contains empty strings
341    fn has_empty_strings(&self, value: &Value) -> bool {
342        match value {
343            Value::String(s) if s.is_empty() => true,
344            Value::Array(arr) => arr.iter().any(|v| self.has_empty_strings(v)),
345            Value::Object(obj) => obj.values().any(|v| self.has_empty_strings(v)),
346            _ => false,
347        }
348    }
349
350    /// Get request schema from OpenAPI spec
351    fn get_request_schema_from_spec(
352        &self,
353        spec: &OpenApiSpec,
354        path: &str,
355        method: &str,
356    ) -> Option<Value> {
357        let path_item = spec.paths.get(path)?;
358        let operation = match method.to_lowercase().as_str() {
359            "get" => path_item.get.as_ref(),
360            "post" => path_item.post.as_ref(),
361            "put" => path_item.put.as_ref(),
362            "delete" => path_item.delete.as_ref(),
363            "patch" => path_item.patch.as_ref(),
364            _ => None,
365        }?;
366
367        let request_body = operation.request_body.as_ref()?;
368        let media_type = request_body.content.get("application/json")?;
369        media_type.schema.clone()
370    }
371
372    /// Get response schema from OpenAPI spec
373    fn get_response_schema_from_spec(
374        &self,
375        spec: &OpenApiSpec,
376        path: &str,
377        method: &str,
378        status: u16,
379    ) -> Option<Value> {
380        let path_item = spec.paths.get(path)?;
381        let operation = match method.to_lowercase().as_str() {
382            "get" => path_item.get.as_ref(),
383            "post" => path_item.post.as_ref(),
384            "put" => path_item.put.as_ref(),
385            "delete" => path_item.delete.as_ref(),
386            "patch" => path_item.patch.as_ref(),
387            _ => None,
388        }?;
389
390        // Try exact status code first, then default
391        let response = operation
392            .responses
393            .get(&status.to_string())
394            .or_else(|| operation.responses.get("default"))?;
395
396        let content = response.content.as_ref()?;
397        let media_type = content.get("application/json")?;
398        media_type.schema.clone()
399    }
400
401    /// Create a parsing error response
402    fn create_parsing_error(&self, error: serde_json::Error, request_id: &str) -> anyhow::Error {
403        let error_response = ValidationErrorResponse {
404            error: "Invalid JSON".to_string(),
405            status: 400,
406            validation_errors: vec![ValidationErrorDetail {
407                field: "$".to_string(),
408                message: error.to_string(),
409                value: None,
410            }],
411            request_id: request_id.to_string(),
412        };
413
414        anyhow::anyhow!(serde_json::to_string(&error_response)
415            .unwrap_or_else(|_| { format!("JSON parsing error: {}", error) }))
416    }
417
418    /// Create a validation error response
419    fn create_validation_error(
420        &self,
421        errors: Vec<ValidationErrorDetail>,
422        request_id: &str,
423    ) -> anyhow::Error {
424        let error_response = ValidationErrorResponse {
425            error: "Validation failed".to_string(),
426            status: 400,
427            validation_errors: errors,
428            request_id: request_id.to_string(),
429        };
430
431        anyhow::anyhow!(serde_json::to_string(&error_response)
432            .unwrap_or_else(|_| { "Validation failed".to_string() }))
433    }
434
435    /// Generate validation error response
436    pub fn generate_error_response(
437        &self,
438        errors: Vec<ValidationErrorDetail>,
439        request_id: &str,
440    ) -> Response<Full<Bytes>> {
441        let error_response = ValidationErrorResponse {
442            error: "Validation failed".to_string(),
443            status: 400,
444            validation_errors: errors,
445            request_id: request_id.to_string(),
446        };
447
448        let body = serde_json::to_vec(&error_response)
449            .unwrap_or_else(|_| br#"{"error":"Validation failed","status":400}"#.to_vec());
450
451        Response::builder()
452            .status(StatusCode::BAD_REQUEST)
453            .header("Content-Type", "application/json")
454            .header("X-Request-Id", request_id)
455            .body(Full::new(Bytes::from(body)))
456            .unwrap_or_else(|_| {
457                Response::builder()
458                    .status(StatusCode::INTERNAL_SERVER_ERROR)
459                    .body(Full::new(Bytes::new()))
460                    .unwrap()
461            })
462    }
463}
464
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build a validator with request validation enabled and only an inline request schema.
    fn request_only_validator(schema: Value) -> SchemaValidator {
        let config = ApiSchemaConfig {
            schema_file: None,
            request_schema: Some(schema),
            response_schema: None,
            validate_requests: true,
            validate_responses: false,
            strict_mode: false,
        };
        SchemaValidator::new(config).unwrap()
    }

    #[test]
    fn test_schema_validation() {
        let validator = request_only_validator(json!({
            "type": "object",
            "properties": {
                "name": { "type": "string", "minLength": 1 },
                "age": { "type": "integer", "minimum": 0 }
            },
            "required": ["name"]
        }));
        let compiled = validator
            .request_schema
            .as_ref()
            .expect("inline request schema should be compiled");

        // Well-formed object.
        let accepted = json!({ "name": "John", "age": 30 });
        assert!(validator
            .validate_against_schema(compiled, &accepted, "test-123")
            .is_ok());

        // The required `name` property is absent.
        let missing_name = json!({ "age": 30 });
        assert!(validator
            .validate_against_schema(compiled, &missing_name, "test-124")
            .is_err());

        // Both properties carry the wrong JSON type.
        let wrong_types = json!({ "name": 123, "age": "thirty" });
        assert!(validator
            .validate_against_schema(compiled, &wrong_types, "test-125")
            .is_err());
    }

    #[tokio::test]
    async fn test_request_validation() {
        let validator = request_only_validator(json!({
            "type": "object",
            "properties": {
                "email": { "type": "string", "format": "email" },
                "password": { "type": "string", "minLength": 8 }
            },
            "required": ["email", "password"]
        }));

        let request = Request::post("/login")
            .header("Content-Type", "application/json")
            .body(())
            .unwrap();

        // (body, request id, should the body be accepted?)
        let cases = [
            (
                json!({ "email": "user@example.com", "password": "securepassword123" }),
                "req-001",
                true,
            ),
            (
                json!({ "email": "not-an-email", "password": "short" }),
                "req-002",
                false,
            ),
        ];

        for (body, request_id, expect_ok) in cases.iter() {
            let bytes = serde_json::to_vec(body).unwrap();
            let outcome = validator
                .validate_request(&request, &bytes, "/login", request_id)
                .await;
            assert_eq!(outcome.is_ok(), *expect_ok, "unexpected result for {}", request_id);
        }
    }
}
581        assert!(result.is_err());
582    }
583}