1use anyhow::{Context, Result};
7use bytes::Bytes;
8use http::{Request, Response, StatusCode};
9use http_body_util::{BodyExt, Full};
10use jsonschema::Validator;
11use serde::{Deserialize, Serialize};
12use serde_json::{json, Value};
13use std::collections::HashMap;
14use std::path::Path;
15use std::sync::Arc;
16use tracing::{debug, info, warn};
17
18use sentinel_config::ApiSchemaConfig;
19
20pub struct SchemaValidator {
22 config: Arc<ApiSchemaConfig>,
24 request_schema: Option<Arc<Validator>>,
26 response_schema: Option<Arc<Validator>>,
28 openapi_spec: Option<OpenApiSpec>,
30}
31
32#[derive(Debug, Clone, Deserialize)]
34struct OpenApiSpec {
35 openapi: String,
36 paths: HashMap<String, PathItem>,
37 components: Option<Components>,
38}
39
40#[derive(Debug, Clone, Deserialize)]
42struct PathItem {
43 #[serde(default)]
44 get: Option<Operation>,
45 #[serde(default)]
46 post: Option<Operation>,
47 #[serde(default)]
48 put: Option<Operation>,
49 #[serde(default)]
50 delete: Option<Operation>,
51 #[serde(default)]
52 patch: Option<Operation>,
53}
54
55#[derive(Debug, Clone, Deserialize)]
57struct Operation {
58 #[serde(rename = "operationId")]
59 operation_id: Option<String>,
60 #[serde(rename = "requestBody")]
61 request_body: Option<RequestBody>,
62 responses: HashMap<String, ApiResponse>,
63}
64
65#[derive(Debug, Clone, Deserialize)]
67struct RequestBody {
68 required: Option<bool>,
69 content: HashMap<String, MediaType>,
70}
71
72#[derive(Debug, Clone, Deserialize)]
74struct ApiResponse {
75 description: String,
76 content: Option<HashMap<String, MediaType>>,
77}
78
79#[derive(Debug, Clone, Deserialize)]
81struct MediaType {
82 schema: Option<Value>,
83}
84
85#[derive(Debug, Clone, Deserialize)]
87struct Components {
88 schemas: Option<HashMap<String, Value>>,
89}
90
91#[derive(Debug, Serialize)]
93pub struct ValidationErrorResponse {
94 pub error: String,
95 pub status: u16,
96 pub validation_errors: Vec<ValidationErrorDetail>,
97 pub request_id: String,
98}
99
100#[derive(Debug, Serialize)]
102pub struct ValidationErrorDetail {
103 pub field: String,
104 pub message: String,
105 pub value: Option<Value>,
106}
107
108impl SchemaValidator {
109 pub fn new(config: ApiSchemaConfig) -> Result<Self> {
111 let mut validator = Self {
112 config: Arc::new(config.clone()),
113 request_schema: None,
114 response_schema: None,
115 openapi_spec: None,
116 };
117
118 if let Some(ref schema_file) = config.schema_file {
120 validator.load_openapi_spec(schema_file)?;
121 }
122
123 if let Some(ref schema) = config.request_schema {
125 validator.request_schema = Some(Arc::new(Self::compile_schema(schema)?));
126 }
127
128 if let Some(ref schema) = config.response_schema {
130 validator.response_schema = Some(Arc::new(Self::compile_schema(schema)?));
131 }
132
133 Ok(validator)
134 }
135
136 fn load_openapi_spec(&mut self, path: &Path) -> Result<()> {
138 let content = std::fs::read_to_string(path)
139 .with_context(|| format!("Failed to read OpenAPI spec: {:?}", path))?;
140
141 let spec: OpenApiSpec = if path.extension().is_some_and(|e| e == "yaml" || e == "yml") {
142 serde_yaml::from_str(&content)?
143 } else {
144 serde_json::from_str(&content)?
145 };
146
147 info!("Loaded OpenAPI specification from {:?}", path);
148 self.openapi_spec = Some(spec);
149 Ok(())
150 }
151
152 fn compile_schema(schema: &Value) -> Result<Validator> {
154 jsonschema::draft7::new(schema)
155 .map_err(|e| anyhow::anyhow!("Failed to compile schema: {}", e))
156 }
157
158 pub async fn validate_request<B>(
160 &self,
161 request: &Request<B>,
162 body: &[u8],
163 path: &str,
164 request_id: &str,
165 ) -> Result<()> {
166 if !self.config.validate_requests {
167 return Ok(());
168 }
169
170 let json_body: Value = if body.is_empty() {
172 json!(null)
173 } else {
174 serde_json::from_slice(body).map_err(|e| self.create_parsing_error(e, request_id))?
175 };
176
177 let schema = if let Some(ref request_schema) = self.request_schema {
179 request_schema.clone()
180 } else if let Some(ref spec) = self.openapi_spec {
181 match self.get_request_schema_from_spec(spec, path, request.method().as_str()) {
183 Some(s) => Arc::new(Self::compile_schema(&s)?),
184 None => {
185 debug!("No schema found for {} {}", request.method(), path);
186 return Ok(());
187 }
188 }
189 } else {
190 return Ok(());
192 };
193
194 self.validate_against_schema(&schema, &json_body, request_id)?;
196
197 Ok(())
198 }
199
200 pub async fn validate_response(
202 &self,
203 status: StatusCode,
204 body: &[u8],
205 path: &str,
206 method: &str,
207 request_id: &str,
208 ) -> Result<()> {
209 if !self.config.validate_responses {
210 return Ok(());
211 }
212
213 let json_body: Value = if body.is_empty() {
215 json!(null)
216 } else {
217 serde_json::from_slice(body).map_err(|e| self.create_parsing_error(e, request_id))?
218 };
219
220 let schema = if let Some(ref response_schema) = self.response_schema {
222 response_schema.clone()
223 } else if let Some(ref spec) = self.openapi_spec {
224 match self.get_response_schema_from_spec(spec, path, method, status.as_u16()) {
226 Some(s) => Arc::new(Self::compile_schema(&s)?),
227 None => {
228 debug!(
229 "No schema found for {} {} response {}",
230 method, path, status
231 );
232 return Ok(());
233 }
234 }
235 } else {
236 return Ok(());
238 };
239
240 self.validate_against_schema(&schema, &json_body, request_id)?;
242
243 Ok(())
244 }
245
246 fn validate_against_schema(
248 &self,
249 schema: &Validator,
250 instance: &Value,
251 request_id: &str,
252 ) -> Result<()> {
253 let validation_errors: Vec<ValidationErrorDetail> = schema
254 .iter_errors(instance)
255 .map(|error| self.format_validation_error(&error, instance))
256 .collect();
257
258 if !validation_errors.is_empty() {
259 return Err(self.create_validation_error(validation_errors, request_id));
260 }
261
262 if self.config.strict_mode {
264 self.strict_mode_checks(schema, instance, request_id)?;
265 }
266
267 Ok(())
268 }
269
270 fn format_validation_error(
272 &self,
273 error: &jsonschema::ValidationError,
274 instance: &Value,
275 ) -> ValidationErrorDetail {
276 let field = error.instance_path().to_string();
277 let field = if field.is_empty() {
278 "$".to_string()
279 } else {
280 field
281 };
282
283 let value = error
284 .instance_path()
285 .iter()
286 .fold(Some(instance), |acc: Option<&Value>, segment| {
287 acc.and_then(|v| match segment {
288 jsonschema::paths::LocationSegment::Property(prop) => v.get(prop.as_ref()),
289 jsonschema::paths::LocationSegment::Index(idx) => v.get(idx),
290 })
291 })
292 .cloned();
293
294 ValidationErrorDetail {
295 field,
296 message: error.to_string(),
297 value,
298 }
299 }
300
301 fn strict_mode_checks(
303 &self,
304 _schema: &Validator,
305 instance: &Value,
306 _request_id: &str,
307 ) -> Result<()> {
308 if self.has_null_values(instance) {
310 warn!("Strict mode: Found null values in JSON");
311 }
312
313 if self.has_empty_strings(instance) {
315 warn!("Strict mode: Found empty strings in JSON");
316 }
317
318 Ok(())
319 }
320
321 fn has_null_values(&self, value: &Value) -> bool {
323 match value {
324 Value::Null => true,
325 Value::Array(arr) => arr.iter().any(|v| self.has_null_values(v)),
326 Value::Object(obj) => obj.values().any(|v| self.has_null_values(v)),
327 _ => false,
328 }
329 }
330
331 fn has_empty_strings(&self, value: &Value) -> bool {
333 match value {
334 Value::String(s) if s.is_empty() => true,
335 Value::Array(arr) => arr.iter().any(|v| self.has_empty_strings(v)),
336 Value::Object(obj) => obj.values().any(|v| self.has_empty_strings(v)),
337 _ => false,
338 }
339 }
340
341 fn get_request_schema_from_spec(
343 &self,
344 spec: &OpenApiSpec,
345 path: &str,
346 method: &str,
347 ) -> Option<Value> {
348 let path_item = spec.paths.get(path)?;
349 let operation = match method.to_lowercase().as_str() {
350 "get" => path_item.get.as_ref(),
351 "post" => path_item.post.as_ref(),
352 "put" => path_item.put.as_ref(),
353 "delete" => path_item.delete.as_ref(),
354 "patch" => path_item.patch.as_ref(),
355 _ => None,
356 }?;
357
358 let request_body = operation.request_body.as_ref()?;
359 let media_type = request_body.content.get("application/json")?;
360 media_type.schema.clone()
361 }
362
363 fn get_response_schema_from_spec(
365 &self,
366 spec: &OpenApiSpec,
367 path: &str,
368 method: &str,
369 status: u16,
370 ) -> Option<Value> {
371 let path_item = spec.paths.get(path)?;
372 let operation = match method.to_lowercase().as_str() {
373 "get" => path_item.get.as_ref(),
374 "post" => path_item.post.as_ref(),
375 "put" => path_item.put.as_ref(),
376 "delete" => path_item.delete.as_ref(),
377 "patch" => path_item.patch.as_ref(),
378 _ => None,
379 }?;
380
381 let response = operation
383 .responses
384 .get(&status.to_string())
385 .or_else(|| operation.responses.get("default"))?;
386
387 let content = response.content.as_ref()?;
388 let media_type = content.get("application/json")?;
389 media_type.schema.clone()
390 }
391
392 fn create_parsing_error(&self, error: serde_json::Error, request_id: &str) -> anyhow::Error {
394 let error_response = ValidationErrorResponse {
395 error: "Invalid JSON".to_string(),
396 status: 400,
397 validation_errors: vec![ValidationErrorDetail {
398 field: "$".to_string(),
399 message: error.to_string(),
400 value: None,
401 }],
402 request_id: request_id.to_string(),
403 };
404
405 anyhow::anyhow!(serde_json::to_string(&error_response)
406 .unwrap_or_else(|_| { format!("JSON parsing error: {}", error) }))
407 }
408
409 fn create_validation_error(
411 &self,
412 errors: Vec<ValidationErrorDetail>,
413 request_id: &str,
414 ) -> anyhow::Error {
415 let error_response = ValidationErrorResponse {
416 error: "Validation failed".to_string(),
417 status: 400,
418 validation_errors: errors,
419 request_id: request_id.to_string(),
420 };
421
422 anyhow::anyhow!(serde_json::to_string(&error_response)
423 .unwrap_or_else(|_| { "Validation failed".to_string() }))
424 }
425
426 pub fn generate_error_response(
428 &self,
429 errors: Vec<ValidationErrorDetail>,
430 request_id: &str,
431 ) -> Response<Full<Bytes>> {
432 let error_response = ValidationErrorResponse {
433 error: "Validation failed".to_string(),
434 status: 400,
435 validation_errors: errors,
436 request_id: request_id.to_string(),
437 };
438
439 let body = serde_json::to_vec(&error_response)
440 .unwrap_or_else(|_| br#"{"error":"Validation failed","status":400}"#.to_vec());
441
442 Response::builder()
443 .status(StatusCode::BAD_REQUEST)
444 .header("Content-Type", "application/json")
445 .header("X-Request-Id", request_id)
446 .body(Full::new(Bytes::from(body)))
447 .unwrap_or_else(|_| {
448 Response::builder()
449 .status(StatusCode::INTERNAL_SERVER_ERROR)
450 .body(Full::new(Bytes::new()))
451 .unwrap()
452 })
453 }
454}
455
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Config that validates requests (only) against the given inline schema.
    fn request_only_config(schema: Value) -> ApiSchemaConfig {
        ApiSchemaConfig {
            schema_file: None,
            schema_content: None,
            request_schema: Some(schema),
            response_schema: None,
            validate_requests: true,
            validate_responses: false,
            strict_mode: false,
        }
    }

    #[test]
    fn test_schema_validation() {
        let validator = SchemaValidator::new(request_only_config(json!({
            "type": "object",
            "properties": {
                "name": { "type": "string", "minLength": 1 },
                "age": { "type": "integer", "minimum": 0 }
            },
            "required": ["name"]
        })))
        .unwrap();
        let compiled = validator.request_schema.as_ref().unwrap();

        // (instance, request id, expected to validate?)
        let cases = [
            (json!({ "name": "John", "age": 30 }), "test-123", true),
            // Missing the required "name".
            (json!({ "age": 30 }), "test-124", false),
            // Both properties carry the wrong type.
            (json!({ "name": 123, "age": "thirty" }), "test-125", false),
        ];

        for (instance, request_id, should_pass) in cases {
            let outcome = validator.validate_against_schema(compiled, &instance, request_id);
            assert_eq!(outcome.is_ok(), should_pass, "unexpected outcome for {}", instance);
        }
    }

    #[tokio::test]
    async fn test_request_validation() {
        let validator = SchemaValidator::new(request_only_config(json!({
            "type": "object",
            "properties": {
                "email": { "type": "string", "format": "email" },
                "password": { "type": "string", "minLength": 8 }
            },
            "required": ["email", "password"]
        })))
        .unwrap();

        let request = Request::post("/login")
            .header("Content-Type", "application/json")
            .body(())
            .unwrap();

        let accepted = serde_json::to_vec(&json!({
            "email": "user@example.com",
            "password": "securepassword123"
        }))
        .unwrap();
        assert!(validator
            .validate_request(&request, &accepted, "/login", "req-001")
            .await
            .is_ok());

        let rejected = serde_json::to_vec(&json!({
            "email": "not-an-email",
            "password": "short"
        }))
        .unwrap();
        assert!(validator
            .validate_request(&request, &rejected, "/login", "req-002")
            .await
            .is_err());
    }
}