sentinel_proxy/
validation.rs

1//! API schema validation module for Sentinel proxy
2//!
3//! This module provides JSON Schema validation for API routes,
4//! supporting both request and response validation with OpenAPI integration.
5
6use anyhow::{Context, Result};
7use bytes::Bytes;
8use http::{Request, Response, StatusCode};
9use http_body_util::{BodyExt, Full};
10use jsonschema::{Draft, JSONSchema, ValidationError};
11use serde::{Deserialize, Serialize};
12use serde_json::{json, Value};
13use std::collections::HashMap;
14use std::path::Path;
15use std::sync::Arc;
16use tracing::{debug, info, warn};
17
18use sentinel_config::ApiSchemaConfig;
19
20/// API schema validator
21pub struct SchemaValidator {
22    /// Configuration for schema validation
23    config: Arc<ApiSchemaConfig>,
24    /// Compiled request schema
25    request_schema: Option<Arc<JSONSchema>>,
26    /// Compiled response schema
27    response_schema: Option<Arc<JSONSchema>>,
28    /// OpenAPI specification (if loaded)
29    openapi_spec: Option<OpenApiSpec>,
30}
31
32/// OpenAPI specification
33#[derive(Debug, Clone, Deserialize)]
34struct OpenApiSpec {
35    openapi: String,
36    paths: HashMap<String, PathItem>,
37    components: Option<Components>,
38}
39
40/// OpenAPI path item
41#[derive(Debug, Clone, Deserialize)]
42struct PathItem {
43    #[serde(default)]
44    get: Option<Operation>,
45    #[serde(default)]
46    post: Option<Operation>,
47    #[serde(default)]
48    put: Option<Operation>,
49    #[serde(default)]
50    delete: Option<Operation>,
51    #[serde(default)]
52    patch: Option<Operation>,
53}
54
55/// OpenAPI operation
56#[derive(Debug, Clone, Deserialize)]
57struct Operation {
58    #[serde(rename = "operationId")]
59    operation_id: Option<String>,
60    #[serde(rename = "requestBody")]
61    request_body: Option<RequestBody>,
62    responses: HashMap<String, ApiResponse>,
63}
64
65/// OpenAPI request body
66#[derive(Debug, Clone, Deserialize)]
67struct RequestBody {
68    required: Option<bool>,
69    content: HashMap<String, MediaType>,
70}
71
72/// OpenAPI response
73#[derive(Debug, Clone, Deserialize)]
74struct ApiResponse {
75    description: String,
76    content: Option<HashMap<String, MediaType>>,
77}
78
79/// OpenAPI media type
80#[derive(Debug, Clone, Deserialize)]
81struct MediaType {
82    schema: Option<Value>,
83}
84
85/// OpenAPI components
86#[derive(Debug, Clone, Deserialize)]
87struct Components {
88    schemas: Option<HashMap<String, Value>>,
89}
90
91/// Validation error response
92#[derive(Debug, Serialize)]
93pub struct ValidationErrorResponse {
94    pub error: String,
95    pub status: u16,
96    pub validation_errors: Vec<ValidationErrorDetail>,
97    pub request_id: String,
98}
99
100/// Individual validation error detail
101#[derive(Debug, Serialize)]
102pub struct ValidationErrorDetail {
103    pub field: String,
104    pub message: String,
105    pub value: Option<Value>,
106}
107
108impl SchemaValidator {
109    /// Create a new schema validator
110    pub fn new(config: ApiSchemaConfig) -> Result<Self> {
111        let mut validator = Self {
112            config: Arc::new(config.clone()),
113            request_schema: None,
114            response_schema: None,
115            openapi_spec: None,
116        };
117
118        // Load OpenAPI specification if provided
119        if let Some(ref schema_file) = config.schema_file {
120            validator.load_openapi_spec(schema_file)?;
121        }
122
123        // Compile request schema if provided
124        if let Some(ref schema) = config.request_schema {
125            validator.request_schema = Some(Arc::new(Self::compile_schema(schema)?));
126        }
127
128        // Compile response schema if provided
129        if let Some(ref schema) = config.response_schema {
130            validator.response_schema = Some(Arc::new(Self::compile_schema(schema)?));
131        }
132
133        Ok(validator)
134    }
135
136    /// Load OpenAPI specification from file
137    fn load_openapi_spec(&mut self, path: &Path) -> Result<()> {
138        let content = std::fs::read_to_string(path)
139            .with_context(|| format!("Failed to read OpenAPI spec: {:?}", path))?;
140
141        let spec: OpenApiSpec = if path.extension().is_some_and(|e| e == "yaml" || e == "yml") {
142            serde_yaml::from_str(&content)?
143        } else {
144            serde_json::from_str(&content)?
145        };
146
147        info!("Loaded OpenAPI specification from {:?}", path);
148        self.openapi_spec = Some(spec);
149        Ok(())
150    }
151
152    /// Compile a JSON schema
153    fn compile_schema(schema: &Value) -> Result<JSONSchema> {
154        JSONSchema::options()
155            .with_draft(Draft::Draft7)
156            .compile(schema)
157            .map_err(|e| anyhow::anyhow!("Failed to compile schema: {}", e))
158    }
159
160    /// Validate a request
161    pub async fn validate_request<B>(
162        &self,
163        request: &Request<B>,
164        body: &[u8],
165        path: &str,
166        request_id: &str,
167    ) -> Result<()> {
168        if !self.config.validate_requests {
169            return Ok(());
170        }
171
172        // Parse JSON body
173        let json_body: Value = if body.is_empty() {
174            json!(null)
175        } else {
176            serde_json::from_slice(body).map_err(|e| self.create_parsing_error(e, request_id))?
177        };
178
179        // Get the appropriate schema
180        let schema = if let Some(ref request_schema) = self.request_schema {
181            request_schema.clone()
182        } else if let Some(ref spec) = self.openapi_spec {
183            // Try to find schema from OpenAPI spec
184            match self.get_request_schema_from_spec(spec, path, request.method().as_str()) {
185                Some(s) => Arc::new(Self::compile_schema(&s)?),
186                None => {
187                    debug!("No schema found for {} {}", request.method(), path);
188                    return Ok(());
189                }
190            }
191        } else {
192            // No schema configured
193            return Ok(());
194        };
195
196        // Validate against schema
197        self.validate_against_schema(&schema, &json_body, request_id)?;
198
199        Ok(())
200    }
201
202    /// Validate a response
203    pub async fn validate_response(
204        &self,
205        status: StatusCode,
206        body: &[u8],
207        path: &str,
208        method: &str,
209        request_id: &str,
210    ) -> Result<()> {
211        if !self.config.validate_responses {
212            return Ok(());
213        }
214
215        // Parse JSON body
216        let json_body: Value = if body.is_empty() {
217            json!(null)
218        } else {
219            serde_json::from_slice(body).map_err(|e| self.create_parsing_error(e, request_id))?
220        };
221
222        // Get the appropriate schema
223        let schema = if let Some(ref response_schema) = self.response_schema {
224            response_schema.clone()
225        } else if let Some(ref spec) = self.openapi_spec {
226            // Try to find schema from OpenAPI spec
227            match self.get_response_schema_from_spec(spec, path, method, status.as_u16()) {
228                Some(s) => Arc::new(Self::compile_schema(&s)?),
229                None => {
230                    debug!(
231                        "No schema found for {} {} response {}",
232                        method, path, status
233                    );
234                    return Ok(());
235                }
236            }
237        } else {
238            // No schema configured
239            return Ok(());
240        };
241
242        // Validate against schema
243        self.validate_against_schema(&schema, &json_body, request_id)?;
244
245        Ok(())
246    }
247
248    /// Validate JSON against a schema
249    fn validate_against_schema(
250        &self,
251        schema: &JSONSchema,
252        instance: &Value,
253        request_id: &str,
254    ) -> Result<()> {
255        let result = schema.validate(instance);
256
257        if let Err(errors) = result {
258            let validation_errors: Vec<ValidationErrorDetail> = errors
259                .map(|error| self.format_validation_error(error, instance))
260                .collect();
261
262            if !validation_errors.is_empty() {
263                return Err(self.create_validation_error(validation_errors, request_id));
264            }
265        }
266
267        // Additional strict mode checks
268        if self.config.strict_mode {
269            self.strict_mode_checks(schema, instance, request_id)?;
270        }
271
272        Ok(())
273    }
274
275    /// Format a validation error
276    fn format_validation_error(
277        &self,
278        error: ValidationError,
279        instance: &Value,
280    ) -> ValidationErrorDetail {
281        let field = error.instance_path.to_string();
282        let field = if field.is_empty() {
283            "$".to_string()
284        } else {
285            field
286        };
287
288        let value = error
289            .instance_path
290            .iter()
291            .fold(Some(instance), |acc, segment| {
292                acc.and_then(|v| match segment {
293                    jsonschema::paths::PathChunk::Property(prop) => v.get(prop.as_ref()),
294                    jsonschema::paths::PathChunk::Index(idx) => v.get(idx),
295                    _ => None,
296                })
297            })
298            .cloned();
299
300        ValidationErrorDetail {
301            field,
302            message: error.to_string(),
303            value,
304        }
305    }
306
307    /// Perform strict mode checks
308    fn strict_mode_checks(
309        &self,
310        _schema: &JSONSchema,
311        instance: &Value,
312        _request_id: &str,
313    ) -> Result<()> {
314        // Check for null values
315        if self.has_null_values(instance) {
316            warn!("Strict mode: Found null values in JSON");
317        }
318
319        // Check for empty strings
320        if self.has_empty_strings(instance) {
321            warn!("Strict mode: Found empty strings in JSON");
322        }
323
324        Ok(())
325    }
326
327    /// Check if JSON contains null values
328    fn has_null_values(&self, value: &Value) -> bool {
329        match value {
330            Value::Null => true,
331            Value::Array(arr) => arr.iter().any(|v| self.has_null_values(v)),
332            Value::Object(obj) => obj.values().any(|v| self.has_null_values(v)),
333            _ => false,
334        }
335    }
336
337    /// Check if JSON contains empty strings
338    fn has_empty_strings(&self, value: &Value) -> bool {
339        match value {
340            Value::String(s) if s.is_empty() => true,
341            Value::Array(arr) => arr.iter().any(|v| self.has_empty_strings(v)),
342            Value::Object(obj) => obj.values().any(|v| self.has_empty_strings(v)),
343            _ => false,
344        }
345    }
346
347    /// Get request schema from OpenAPI spec
348    fn get_request_schema_from_spec(
349        &self,
350        spec: &OpenApiSpec,
351        path: &str,
352        method: &str,
353    ) -> Option<Value> {
354        let path_item = spec.paths.get(path)?;
355        let operation = match method.to_lowercase().as_str() {
356            "get" => path_item.get.as_ref(),
357            "post" => path_item.post.as_ref(),
358            "put" => path_item.put.as_ref(),
359            "delete" => path_item.delete.as_ref(),
360            "patch" => path_item.patch.as_ref(),
361            _ => None,
362        }?;
363
364        let request_body = operation.request_body.as_ref()?;
365        let media_type = request_body.content.get("application/json")?;
366        media_type.schema.clone()
367    }
368
369    /// Get response schema from OpenAPI spec
370    fn get_response_schema_from_spec(
371        &self,
372        spec: &OpenApiSpec,
373        path: &str,
374        method: &str,
375        status: u16,
376    ) -> Option<Value> {
377        let path_item = spec.paths.get(path)?;
378        let operation = match method.to_lowercase().as_str() {
379            "get" => path_item.get.as_ref(),
380            "post" => path_item.post.as_ref(),
381            "put" => path_item.put.as_ref(),
382            "delete" => path_item.delete.as_ref(),
383            "patch" => path_item.patch.as_ref(),
384            _ => None,
385        }?;
386
387        // Try exact status code first, then default
388        let response = operation
389            .responses
390            .get(&status.to_string())
391            .or_else(|| operation.responses.get("default"))?;
392
393        let content = response.content.as_ref()?;
394        let media_type = content.get("application/json")?;
395        media_type.schema.clone()
396    }
397
398    /// Create a parsing error response
399    fn create_parsing_error(&self, error: serde_json::Error, request_id: &str) -> anyhow::Error {
400        let error_response = ValidationErrorResponse {
401            error: "Invalid JSON".to_string(),
402            status: 400,
403            validation_errors: vec![ValidationErrorDetail {
404                field: "$".to_string(),
405                message: error.to_string(),
406                value: None,
407            }],
408            request_id: request_id.to_string(),
409        };
410
411        anyhow::anyhow!(serde_json::to_string(&error_response)
412            .unwrap_or_else(|_| { format!("JSON parsing error: {}", error) }))
413    }
414
415    /// Create a validation error response
416    fn create_validation_error(
417        &self,
418        errors: Vec<ValidationErrorDetail>,
419        request_id: &str,
420    ) -> anyhow::Error {
421        let error_response = ValidationErrorResponse {
422            error: "Validation failed".to_string(),
423            status: 400,
424            validation_errors: errors,
425            request_id: request_id.to_string(),
426        };
427
428        anyhow::anyhow!(serde_json::to_string(&error_response)
429            .unwrap_or_else(|_| { "Validation failed".to_string() }))
430    }
431
432    /// Generate validation error response
433    pub fn generate_error_response(
434        &self,
435        errors: Vec<ValidationErrorDetail>,
436        request_id: &str,
437    ) -> Response<Full<Bytes>> {
438        let error_response = ValidationErrorResponse {
439            error: "Validation failed".to_string(),
440            status: 400,
441            validation_errors: errors,
442            request_id: request_id.to_string(),
443        };
444
445        let body = serde_json::to_vec(&error_response)
446            .unwrap_or_else(|_| br#"{"error":"Validation failed","status":400}"#.to_vec());
447
448        Response::builder()
449            .status(StatusCode::BAD_REQUEST)
450            .header("Content-Type", "application/json")
451            .header("X-Request-Id", request_id)
452            .body(Full::new(Bytes::from(body)))
453            .unwrap_or_else(|_| {
454                Response::builder()
455                    .status(StatusCode::INTERNAL_SERVER_ERROR)
456                    .body(Full::new(Bytes::new()))
457                    .unwrap()
458            })
459    }
460}
461
462#[cfg(test)]
463mod tests {
464    use super::*;
465    use serde_json::json;
466
467    #[test]
468    fn test_schema_validation() {
469        let schema = json!({
470            "type": "object",
471            "properties": {
472                "name": {
473                    "type": "string",
474                    "minLength": 1
475                },
476                "age": {
477                    "type": "integer",
478                    "minimum": 0
479                }
480            },
481            "required": ["name"]
482        });
483
484        let config = ApiSchemaConfig {
485            schema_file: None,
486            request_schema: Some(schema),
487            response_schema: None,
488            validate_requests: true,
489            validate_responses: false,
490            strict_mode: false,
491        };
492
493        let validator = SchemaValidator::new(config).unwrap();
494
495        // Valid JSON
496        let valid_json = json!({
497            "name": "John",
498            "age": 30
499        });
500
501        let schema = validator.request_schema.as_ref().unwrap();
502        let result = validator.validate_against_schema(schema, &valid_json, "test-123");
503        assert!(result.is_ok());
504
505        // Invalid JSON (missing required field)
506        let invalid_json = json!({
507            "age": 30
508        });
509
510        let result = validator.validate_against_schema(schema, &invalid_json, "test-124");
511        assert!(result.is_err());
512
513        // Invalid JSON (wrong type)
514        let invalid_json = json!({
515            "name": 123,
516            "age": "thirty"
517        });
518
519        let result = validator.validate_against_schema(schema, &invalid_json, "test-125");
520        assert!(result.is_err());
521    }
522
523    #[tokio::test]
524    async fn test_request_validation() {
525        let schema = json!({
526            "type": "object",
527            "properties": {
528                "email": {
529                    "type": "string",
530                    "format": "email"
531                },
532                "password": {
533                    "type": "string",
534                    "minLength": 8
535                }
536            },
537            "required": ["email", "password"]
538        });
539
540        let config = ApiSchemaConfig {
541            schema_file: None,
542            request_schema: Some(schema),
543            response_schema: None,
544            validate_requests: true,
545            validate_responses: false,
546            strict_mode: false,
547        };
548
549        let validator = SchemaValidator::new(config).unwrap();
550
551        let request = Request::post("/login")
552            .header("Content-Type", "application/json")
553            .body(())
554            .unwrap();
555
556        // Valid request body
557        let valid_body = json!({
558            "email": "user@example.com",
559            "password": "securepassword123"
560        });
561        let body_bytes = serde_json::to_vec(&valid_body).unwrap();
562
563        let result = validator
564            .validate_request(&request, &body_bytes, "/login", "req-001")
565            .await;
566        assert!(result.is_ok());
567
568        // Invalid request body
569        let invalid_body = json!({
570            "email": "not-an-email",
571            "password": "short"
572        });
573        let body_bytes = serde_json::to_vec(&invalid_body).unwrap();
574
575        let result = validator
576            .validate_request(&request, &body_bytes, "/login", "req-002")
577            .await;
578        assert!(result.is_err());
579    }
580}