// spikard_core/schema_registry.rs
use crate::validation::SchemaValidator;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::{Arc, PoisonError, RwLock};
13
/// Deduplicating cache of compiled [`SchemaValidator`]s.
///
/// Validators are stored behind `Arc`, so callers requesting a schema that serializes to the
/// same JSON text share a single compiled instance. The map sits behind an `RwLock`, so
/// lookups take only a shared (read) lock.
pub struct SchemaRegistry {
    /// Compiled validators, keyed by the serialized JSON text of the schema each one was
    /// compiled from.
    schemas: RwLock<HashMap<String, Arc<SchemaValidator>>>,
}
22
23impl SchemaRegistry {
24 #[must_use]
26 pub fn new() -> Self {
27 Self {
28 schemas: RwLock::new(HashMap::new()),
29 }
30 }
31
32 pub fn get_or_compile(&self, schema: &Value) -> Result<Arc<SchemaValidator>, String> {
50 let key = serde_json::to_string(schema).map_err(|e| format!("Failed to serialize schema: {e}"))?;
51
52 {
53 let schemas = self.schemas.read().unwrap();
54 if let Some(validator) = schemas.get(&key) {
55 return Ok(Arc::clone(validator));
56 }
57 }
58
59 let validator = Arc::new(SchemaValidator::new(schema.clone())?);
60
61 {
62 let mut schemas = self.schemas.write().unwrap();
63 if let Some(existing) = schemas.get(&key) {
64 return Ok(Arc::clone(existing));
65 }
66 schemas.insert(key, Arc::clone(&validator));
67 }
68
69 Ok(validator)
70 }
71
72 #[must_use]
80 pub fn all_schemas(&self) -> Vec<Arc<SchemaValidator>> {
81 let schemas = self.schemas.read().unwrap();
82 schemas.values().cloned().collect()
83 }
84
85 #[must_use]
92 pub fn schema_count(&self) -> usize {
93 let schemas = self.schemas.read().unwrap();
94 schemas.len()
95 }
96}
97
98impl Default for SchemaRegistry {
99 fn default() -> Self {
100 Self::new()
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::thread;

    #[test]
    fn test_new_registry_is_empty() {
        let registry = SchemaRegistry::new();
        assert_eq!(registry.schema_count(), 0);
        assert!(registry.all_schemas().is_empty());

        let registry = SchemaRegistry::default();
        assert_eq!(registry.schema_count(), 0);
    }

    #[test]
    fn test_schema_deduplication() {
        let registry = SchemaRegistry::new();

        let schema1 = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        });

        let schema2 = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        });

        let validator1 = registry.get_or_compile(&schema1).unwrap();
        let validator2 = registry.get_or_compile(&schema2).unwrap();

        assert!(Arc::ptr_eq(&validator1, &validator2));

        assert_eq!(registry.schema_count(), 1);
    }

    #[test]
    fn test_different_schemas() {
        let registry = SchemaRegistry::new();

        let schema1 = json!({
            "type": "string"
        });

        let schema2 = json!({
            "type": "integer"
        });

        let validator1 = registry.get_or_compile(&schema1).unwrap();
        let validator2 = registry.get_or_compile(&schema2).unwrap();

        assert!(!Arc::ptr_eq(&validator1, &validator2));

        assert_eq!(registry.schema_count(), 2);
    }

    #[test]
    fn test_all_schemas() {
        let registry = SchemaRegistry::new();

        let string_validator = registry.get_or_compile(&json!({"type": "string"})).unwrap();
        let integer_validator = registry.get_or_compile(&json!({"type": "integer"})).unwrap();

        let all = registry.all_schemas();
        assert_eq!(all.len(), 2);
        // The snapshot must hand out the cached instances, not fresh copies.
        assert!(all.iter().any(|v| Arc::ptr_eq(v, &string_validator)));
        assert!(all.iter().any(|v| Arc::ptr_eq(v, &integer_validator)));
    }

    #[test]
    fn test_concurrent_access() {
        let registry = Arc::new(SchemaRegistry::new());
        let schema = json!({
            "type": "object",
            "properties": {
                "id": {"type": "integer"}
            }
        });

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let registry = Arc::clone(&registry);
                let schema = schema.clone();
                thread::spawn(move || registry.get_or_compile(&schema).unwrap())
            })
            .collect();

        let validators: Vec<_> = handles.into_iter().map(|h| h.join().unwrap()).collect();

        // Threads that lost the compile race must still receive the stored instance.
        let first = &validators[0];
        assert!(validators.iter().all(|v| Arc::ptr_eq(first, v)));

        assert_eq!(registry.schema_count(), 1);
    }
}