Skip to main content

sentinel_proxy/
validation.rs

1//! API schema validation module for Sentinel proxy
2//!
3//! This module provides JSON Schema validation for API routes,
4//! supporting both request and response validation with OpenAPI integration.
5
6use anyhow::{Context, Result};
7use bytes::Bytes;
8use http::{Request, Response, StatusCode};
9use http_body_util::{BodyExt, Full};
10use jsonschema::Validator;
11use serde::{Deserialize, Serialize};
12use serde_json::{json, Value};
13use std::collections::HashMap;
14use std::path::Path;
15use std::sync::Arc;
16use tracing::{debug, info, warn};
17
18use sentinel_config::ApiSchemaConfig;
19
20/// API schema validator
21pub struct SchemaValidator {
22    /// Configuration for schema validation
23    config: Arc<ApiSchemaConfig>,
24    /// Compiled request schema
25    request_schema: Option<Arc<Validator>>,
26    /// Compiled response schema
27    response_schema: Option<Arc<Validator>>,
28    /// OpenAPI specification (if loaded)
29    openapi_spec: Option<OpenApiSpec>,
30}
31
32/// OpenAPI specification
33#[derive(Debug, Clone, Deserialize)]
34struct OpenApiSpec {
35    openapi: String,
36    paths: HashMap<String, PathItem>,
37    components: Option<Components>,
38}
39
40/// OpenAPI path item
41#[derive(Debug, Clone, Deserialize)]
42struct PathItem {
43    #[serde(default)]
44    get: Option<Operation>,
45    #[serde(default)]
46    post: Option<Operation>,
47    #[serde(default)]
48    put: Option<Operation>,
49    #[serde(default)]
50    delete: Option<Operation>,
51    #[serde(default)]
52    patch: Option<Operation>,
53}
54
55/// OpenAPI operation
56#[derive(Debug, Clone, Deserialize)]
57struct Operation {
58    #[serde(rename = "operationId")]
59    operation_id: Option<String>,
60    #[serde(rename = "requestBody")]
61    request_body: Option<RequestBody>,
62    responses: HashMap<String, ApiResponse>,
63}
64
65/// OpenAPI request body
66#[derive(Debug, Clone, Deserialize)]
67struct RequestBody {
68    required: Option<bool>,
69    content: HashMap<String, MediaType>,
70}
71
72/// OpenAPI response
73#[derive(Debug, Clone, Deserialize)]
74struct ApiResponse {
75    description: String,
76    content: Option<HashMap<String, MediaType>>,
77}
78
79/// OpenAPI media type
80#[derive(Debug, Clone, Deserialize)]
81struct MediaType {
82    schema: Option<Value>,
83}
84
85/// OpenAPI components
86#[derive(Debug, Clone, Deserialize)]
87struct Components {
88    schemas: Option<HashMap<String, Value>>,
89}
90
91/// Validation error response
92#[derive(Debug, Serialize)]
93pub struct ValidationErrorResponse {
94    pub error: String,
95    pub status: u16,
96    pub validation_errors: Vec<ValidationErrorDetail>,
97    pub request_id: String,
98}
99
100/// Individual validation error detail
101#[derive(Debug, Serialize)]
102pub struct ValidationErrorDetail {
103    pub field: String,
104    pub message: String,
105    pub value: Option<Value>,
106}
107
108impl SchemaValidator {
109    /// Create a new schema validator
110    pub fn new(config: ApiSchemaConfig) -> Result<Self> {
111        let mut validator = Self {
112            config: Arc::new(config.clone()),
113            request_schema: None,
114            response_schema: None,
115            openapi_spec: None,
116        };
117
118        // Load OpenAPI specification if provided
119        if let Some(ref schema_file) = config.schema_file {
120            validator.load_openapi_spec(schema_file)?;
121        }
122
123        // Compile request schema if provided
124        if let Some(ref schema) = config.request_schema {
125            validator.request_schema = Some(Arc::new(Self::compile_schema(schema)?));
126        }
127
128        // Compile response schema if provided
129        if let Some(ref schema) = config.response_schema {
130            validator.response_schema = Some(Arc::new(Self::compile_schema(schema)?));
131        }
132
133        Ok(validator)
134    }
135
136    /// Load OpenAPI specification from file
137    fn load_openapi_spec(&mut self, path: &Path) -> Result<()> {
138        let content = std::fs::read_to_string(path)
139            .with_context(|| format!("Failed to read OpenAPI spec: {:?}", path))?;
140
141        let spec: OpenApiSpec = if path.extension().is_some_and(|e| e == "yaml" || e == "yml") {
142            serde_yaml::from_str(&content)?
143        } else {
144            serde_json::from_str(&content)?
145        };
146
147        info!("Loaded OpenAPI specification from {:?}", path);
148        self.openapi_spec = Some(spec);
149        Ok(())
150    }
151
152    /// Compile a JSON schema
153    fn compile_schema(schema: &Value) -> Result<Validator> {
154        jsonschema::draft7::new(schema)
155            .map_err(|e| anyhow::anyhow!("Failed to compile schema: {}", e))
156    }
157
158    /// Validate a request
159    pub async fn validate_request<B>(
160        &self,
161        request: &Request<B>,
162        body: &[u8],
163        path: &str,
164        request_id: &str,
165    ) -> Result<()> {
166        if !self.config.validate_requests {
167            return Ok(());
168        }
169
170        // Parse JSON body
171        let json_body: Value = if body.is_empty() {
172            json!(null)
173        } else {
174            serde_json::from_slice(body).map_err(|e| self.create_parsing_error(e, request_id))?
175        };
176
177        // Get the appropriate schema
178        let schema = if let Some(ref request_schema) = self.request_schema {
179            request_schema.clone()
180        } else if let Some(ref spec) = self.openapi_spec {
181            // Try to find schema from OpenAPI spec
182            match self.get_request_schema_from_spec(spec, path, request.method().as_str()) {
183                Some(s) => Arc::new(Self::compile_schema(&s)?),
184                None => {
185                    debug!("No schema found for {} {}", request.method(), path);
186                    return Ok(());
187                }
188            }
189        } else {
190            // No schema configured
191            return Ok(());
192        };
193
194        // Validate against schema
195        self.validate_against_schema(&schema, &json_body, request_id)?;
196
197        Ok(())
198    }
199
200    /// Validate a response
201    pub async fn validate_response(
202        &self,
203        status: StatusCode,
204        body: &[u8],
205        path: &str,
206        method: &str,
207        request_id: &str,
208    ) -> Result<()> {
209        if !self.config.validate_responses {
210            return Ok(());
211        }
212
213        // Parse JSON body
214        let json_body: Value = if body.is_empty() {
215            json!(null)
216        } else {
217            serde_json::from_slice(body).map_err(|e| self.create_parsing_error(e, request_id))?
218        };
219
220        // Get the appropriate schema
221        let schema = if let Some(ref response_schema) = self.response_schema {
222            response_schema.clone()
223        } else if let Some(ref spec) = self.openapi_spec {
224            // Try to find schema from OpenAPI spec
225            match self.get_response_schema_from_spec(spec, path, method, status.as_u16()) {
226                Some(s) => Arc::new(Self::compile_schema(&s)?),
227                None => {
228                    debug!(
229                        "No schema found for {} {} response {}",
230                        method, path, status
231                    );
232                    return Ok(());
233                }
234            }
235        } else {
236            // No schema configured
237            return Ok(());
238        };
239
240        // Validate against schema
241        self.validate_against_schema(&schema, &json_body, request_id)?;
242
243        Ok(())
244    }
245
246    /// Validate JSON against a schema
247    fn validate_against_schema(
248        &self,
249        schema: &Validator,
250        instance: &Value,
251        request_id: &str,
252    ) -> Result<()> {
253        let validation_errors: Vec<ValidationErrorDetail> = schema
254            .iter_errors(instance)
255            .map(|error| self.format_validation_error(&error, instance))
256            .collect();
257
258        if !validation_errors.is_empty() {
259            return Err(self.create_validation_error(validation_errors, request_id));
260        }
261
262        // Additional strict mode checks
263        if self.config.strict_mode {
264            self.strict_mode_checks(schema, instance, request_id)?;
265        }
266
267        Ok(())
268    }
269
270    /// Format a validation error
271    fn format_validation_error(
272        &self,
273        error: &jsonschema::ValidationError,
274        instance: &Value,
275    ) -> ValidationErrorDetail {
276        let field = error.instance_path().to_string();
277        let field = if field.is_empty() {
278            "$".to_string()
279        } else {
280            field
281        };
282
283        let value = error
284            .instance_path()
285            .iter()
286            .fold(Some(instance), |acc: Option<&Value>, segment| {
287                acc.and_then(|v| match segment {
288                    jsonschema::paths::LocationSegment::Property(prop) => v.get(prop.as_ref()),
289                    jsonschema::paths::LocationSegment::Index(idx) => v.get(idx),
290                })
291            })
292            .cloned();
293
294        ValidationErrorDetail {
295            field,
296            message: error.to_string(),
297            value,
298        }
299    }
300
301    /// Perform strict mode checks
302    fn strict_mode_checks(
303        &self,
304        _schema: &Validator,
305        instance: &Value,
306        _request_id: &str,
307    ) -> Result<()> {
308        // Check for null values
309        if self.has_null_values(instance) {
310            warn!("Strict mode: Found null values in JSON");
311        }
312
313        // Check for empty strings
314        if self.has_empty_strings(instance) {
315            warn!("Strict mode: Found empty strings in JSON");
316        }
317
318        Ok(())
319    }
320
321    /// Check if JSON contains null values
322    fn has_null_values(&self, value: &Value) -> bool {
323        match value {
324            Value::Null => true,
325            Value::Array(arr) => arr.iter().any(|v| self.has_null_values(v)),
326            Value::Object(obj) => obj.values().any(|v| self.has_null_values(v)),
327            _ => false,
328        }
329    }
330
331    /// Check if JSON contains empty strings
332    fn has_empty_strings(&self, value: &Value) -> bool {
333        match value {
334            Value::String(s) if s.is_empty() => true,
335            Value::Array(arr) => arr.iter().any(|v| self.has_empty_strings(v)),
336            Value::Object(obj) => obj.values().any(|v| self.has_empty_strings(v)),
337            _ => false,
338        }
339    }
340
341    /// Get request schema from OpenAPI spec
342    fn get_request_schema_from_spec(
343        &self,
344        spec: &OpenApiSpec,
345        path: &str,
346        method: &str,
347    ) -> Option<Value> {
348        let path_item = spec.paths.get(path)?;
349        let operation = match method.to_lowercase().as_str() {
350            "get" => path_item.get.as_ref(),
351            "post" => path_item.post.as_ref(),
352            "put" => path_item.put.as_ref(),
353            "delete" => path_item.delete.as_ref(),
354            "patch" => path_item.patch.as_ref(),
355            _ => None,
356        }?;
357
358        let request_body = operation.request_body.as_ref()?;
359        let media_type = request_body.content.get("application/json")?;
360        media_type.schema.clone()
361    }
362
363    /// Get response schema from OpenAPI spec
364    fn get_response_schema_from_spec(
365        &self,
366        spec: &OpenApiSpec,
367        path: &str,
368        method: &str,
369        status: u16,
370    ) -> Option<Value> {
371        let path_item = spec.paths.get(path)?;
372        let operation = match method.to_lowercase().as_str() {
373            "get" => path_item.get.as_ref(),
374            "post" => path_item.post.as_ref(),
375            "put" => path_item.put.as_ref(),
376            "delete" => path_item.delete.as_ref(),
377            "patch" => path_item.patch.as_ref(),
378            _ => None,
379        }?;
380
381        // Try exact status code first, then default
382        let response = operation
383            .responses
384            .get(&status.to_string())
385            .or_else(|| operation.responses.get("default"))?;
386
387        let content = response.content.as_ref()?;
388        let media_type = content.get("application/json")?;
389        media_type.schema.clone()
390    }
391
392    /// Create a parsing error response
393    fn create_parsing_error(&self, error: serde_json::Error, request_id: &str) -> anyhow::Error {
394        let error_response = ValidationErrorResponse {
395            error: "Invalid JSON".to_string(),
396            status: 400,
397            validation_errors: vec![ValidationErrorDetail {
398                field: "$".to_string(),
399                message: error.to_string(),
400                value: None,
401            }],
402            request_id: request_id.to_string(),
403        };
404
405        anyhow::anyhow!(serde_json::to_string(&error_response)
406            .unwrap_or_else(|_| { format!("JSON parsing error: {}", error) }))
407    }
408
409    /// Create a validation error response
410    fn create_validation_error(
411        &self,
412        errors: Vec<ValidationErrorDetail>,
413        request_id: &str,
414    ) -> anyhow::Error {
415        let error_response = ValidationErrorResponse {
416            error: "Validation failed".to_string(),
417            status: 400,
418            validation_errors: errors,
419            request_id: request_id.to_string(),
420        };
421
422        anyhow::anyhow!(serde_json::to_string(&error_response)
423            .unwrap_or_else(|_| { "Validation failed".to_string() }))
424    }
425
426    /// Generate validation error response
427    pub fn generate_error_response(
428        &self,
429        errors: Vec<ValidationErrorDetail>,
430        request_id: &str,
431    ) -> Response<Full<Bytes>> {
432        let error_response = ValidationErrorResponse {
433            error: "Validation failed".to_string(),
434            status: 400,
435            validation_errors: errors,
436            request_id: request_id.to_string(),
437        };
438
439        let body = serde_json::to_vec(&error_response)
440            .unwrap_or_else(|_| br#"{"error":"Validation failed","status":400}"#.to_vec());
441
442        Response::builder()
443            .status(StatusCode::BAD_REQUEST)
444            .header("Content-Type", "application/json")
445            .header("X-Request-Id", request_id)
446            .body(Full::new(Bytes::from(body)))
447            .unwrap_or_else(|_| {
448                Response::builder()
449                    .status(StatusCode::INTERNAL_SERVER_ERROR)
450                    .body(Full::new(Bytes::new()))
451                    .unwrap()
452            })
453    }
454}
455
456#[cfg(test)]
457mod tests {
458    use super::*;
459    use serde_json::json;
460
461    #[test]
462    fn test_schema_validation() {
463        let schema = json!({
464            "type": "object",
465            "properties": {
466                "name": {
467                    "type": "string",
468                    "minLength": 1
469                },
470                "age": {
471                    "type": "integer",
472                    "minimum": 0
473                }
474            },
475            "required": ["name"]
476        });
477
478        let config = ApiSchemaConfig {
479            schema_file: None,
480            schema_content: None,
481            request_schema: Some(schema),
482            response_schema: None,
483            validate_requests: true,
484            validate_responses: false,
485            strict_mode: false,
486        };
487
488        let validator = SchemaValidator::new(config).unwrap();
489
490        // Valid JSON
491        let valid_json = json!({
492            "name": "John",
493            "age": 30
494        });
495
496        let schema = validator.request_schema.as_ref().unwrap();
497        let result = validator.validate_against_schema(schema, &valid_json, "test-123");
498        assert!(result.is_ok());
499
500        // Invalid JSON (missing required field)
501        let invalid_json = json!({
502            "age": 30
503        });
504
505        let result = validator.validate_against_schema(schema, &invalid_json, "test-124");
506        assert!(result.is_err());
507
508        // Invalid JSON (wrong type)
509        let invalid_json = json!({
510            "name": 123,
511            "age": "thirty"
512        });
513
514        let result = validator.validate_against_schema(schema, &invalid_json, "test-125");
515        assert!(result.is_err());
516    }
517
518    #[tokio::test]
519    async fn test_request_validation() {
520        let schema = json!({
521            "type": "object",
522            "properties": {
523                "email": {
524                    "type": "string",
525                    "format": "email"
526                },
527                "password": {
528                    "type": "string",
529                    "minLength": 8
530                }
531            },
532            "required": ["email", "password"]
533        });
534
535        let config = ApiSchemaConfig {
536            schema_file: None,
537            schema_content: None,
538            request_schema: Some(schema),
539            response_schema: None,
540            validate_requests: true,
541            validate_responses: false,
542            strict_mode: false,
543        };
544
545        let validator = SchemaValidator::new(config).unwrap();
546
547        let request = Request::post("/login")
548            .header("Content-Type", "application/json")
549            .body(())
550            .unwrap();
551
552        // Valid request body
553        let valid_body = json!({
554            "email": "user@example.com",
555            "password": "securepassword123"
556        });
557        let body_bytes = serde_json::to_vec(&valid_body).unwrap();
558
559        let result = validator
560            .validate_request(&request, &body_bytes, "/login", "req-001")
561            .await;
562        assert!(result.is_ok());
563
564        // Invalid request body
565        let invalid_body = json!({
566            "email": "not-an-email",
567            "password": "short"
568        });
569        let body_bytes = serde_json::to_vec(&invalid_body).unwrap();
570
571        let result = validator
572            .validate_request(&request, &body_bytes, "/login", "req-002")
573            .await;
574        assert!(result.is_err());
575    }
576}