spikard_core/
schema_registry.rs

1//! Schema registry for deduplication and `OpenAPI` generation
2//!
3//! This module provides a global registry that compiles JSON schemas once at application
4//! startup and reuses them across all routes. This enables:
5//! - Schema deduplication (same schema used by multiple routes)
6//! - `OpenAPI` spec generation (access to all schemas)
7//! - Memory efficiency (one compiled validator per unique schema)
8
9use crate::validation::SchemaValidator;
10use serde_json::Value;
11use std::collections::HashMap;
12use std::sync::{Arc, RwLock};
13
14/// Global schema registry for compiled validators
15///
16/// Thread-safe registry that ensures each unique schema is compiled exactly once.
17/// Uses `RwLock` for concurrent read access with occasional writes during startup.
18pub struct SchemaRegistry {
19    /// Map from schema JSON string to compiled validator
20    schemas: RwLock<HashMap<String, Arc<SchemaValidator>>>,
21}
22
23impl SchemaRegistry {
24    /// Create a new empty schema registry
25    #[must_use]
26    pub fn new() -> Self {
27        Self {
28            schemas: RwLock::new(HashMap::new()),
29        }
30    }
31
32    /// Get or compile a schema, returning `Arc` to the compiled validator
33    ///
34    /// This method is thread-safe and uses a double-check pattern:
35    /// 1. Fast path: Read lock to check if schema exists
36    /// 2. Slow path: Write lock to compile and store new schema
37    ///
38    /// # Arguments
39    /// * `schema` - The JSON schema to compile
40    ///
41    /// # Returns
42    /// `Arc`-wrapped compiled validator that can be cheaply cloned
43    ///
44    /// # Errors
45    /// Returns an error if schema serialization or compilation fails.
46    ///
47    /// # Panics
48    /// Panics if the read or write lock is poisoned.
49    pub fn get_or_compile(&self, schema: &Value) -> Result<Arc<SchemaValidator>, String> {
50        let key = serde_json::to_string(schema).map_err(|e| format!("Failed to serialize schema: {e}"))?;
51
52        {
53            let schemas = self.schemas.read().unwrap();
54            if let Some(validator) = schemas.get(&key) {
55                return Ok(Arc::clone(validator));
56            }
57        }
58
59        let validator = Arc::new(SchemaValidator::new(schema.clone())?);
60
61        {
62            let mut schemas = self.schemas.write().unwrap();
63            if let Some(existing) = schemas.get(&key) {
64                return Ok(Arc::clone(existing));
65            }
66            schemas.insert(key, Arc::clone(&validator));
67        }
68
69        Ok(validator)
70    }
71
72    /// Get all registered schemas (for `OpenAPI` generation)
73    ///
74    /// Returns a snapshot of all compiled validators.
75    /// Useful for generating `OpenAPI` specifications from runtime schema information.
76    ///
77    /// # Panics
78    /// Panics if the read lock is poisoned.
79    #[must_use]
80    pub fn all_schemas(&self) -> Vec<Arc<SchemaValidator>> {
81        let schemas = self.schemas.read().unwrap();
82        schemas.values().cloned().collect()
83    }
84
85    /// Get the number of unique schemas registered
86    ///
87    /// Useful for diagnostics and understanding schema deduplication effectiveness.
88    ///
89    /// # Panics
90    /// Panics if the read lock is poisoned.
91    #[must_use]
92    pub fn schema_count(&self) -> usize {
93        let schemas = self.schemas.read().unwrap();
94        schemas.len()
95    }
96}
97
98impl Default for SchemaRegistry {
99    fn default() -> Self {
100        Self::new()
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::thread;

    /// Two separately built but identical schemas must share one validator.
    #[test]
    fn test_schema_deduplication() {
        let registry = SchemaRegistry::new();

        let schema1 = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        });

        let schema2 = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        });

        let validator1 = registry.get_or_compile(&schema1).unwrap();
        let validator2 = registry.get_or_compile(&schema2).unwrap();

        assert!(Arc::ptr_eq(&validator1, &validator2));

        assert_eq!(registry.schema_count(), 1);
    }

    /// Distinct schemas must get distinct validators and separate entries.
    #[test]
    fn test_different_schemas() {
        let registry = SchemaRegistry::new();

        let schema1 = json!({
            "type": "string"
        });

        let schema2 = json!({
            "type": "integer"
        });

        let validator1 = registry.get_or_compile(&schema1).unwrap();
        let validator2 = registry.get_or_compile(&schema2).unwrap();

        assert!(!Arc::ptr_eq(&validator1, &validator2));

        assert_eq!(registry.schema_count(), 2);
    }

    /// `all_schemas` must return one validator per registered schema.
    #[test]
    fn test_all_schemas() {
        let registry = SchemaRegistry::new();

        let schema1 = json!({"type": "string"});
        let schema2 = json!({"type": "integer"});

        registry.get_or_compile(&schema1).unwrap();
        registry.get_or_compile(&schema2).unwrap();

        let all = registry.all_schemas();
        assert_eq!(all.len(), 2);
    }

    /// Threads racing to register the same schema must all end up with the
    /// same stored validator, and the registry must hold a single entry.
    #[test]
    fn test_concurrent_access() {
        let registry = Arc::new(SchemaRegistry::new());
        let schema = json!({
            "type": "object",
            "properties": {
                "id": {"type": "integer"}
            }
        });

        // Spawn every worker before joining any of them. Iterator adaptors are
        // lazy, so chaining `spawn` and `join` in one pipeline would join each
        // thread before the next one starts and never run them concurrently.
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let registry = Arc::clone(&registry);
                let schema = schema.clone();
                thread::spawn(move || registry.get_or_compile(&schema).unwrap())
            })
            .collect();

        let validators: Vec<_> = handles.into_iter().map(|handle| handle.join().unwrap()).collect();

        let (first, rest) = validators.split_first().unwrap();
        assert!(rest.iter().all(|validator| Arc::ptr_eq(first, validator)));

        assert_eq!(registry.schema_count(), 1);
    }
}