//! Schema registry for `spikard_core` (`schema_registry.rs`).
//!
//! Compiles JSON schemas into [`SchemaValidator`]s once and hands out shared
//! `Arc` handles, keyed by each schema's serialized JSON text.
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

use serde_json::Value;

use crate::validation::SchemaValidator;

14pub struct SchemaRegistry {
19 schemas: RwLock<HashMap<String, Arc<SchemaValidator>>>,
21}
22
23impl SchemaRegistry {
24 pub fn new() -> Self {
26 Self {
27 schemas: RwLock::new(HashMap::new()),
28 }
29 }
30
31 pub fn get_or_compile(&self, schema: &Value) -> Result<Arc<SchemaValidator>, String> {
43 let key = serde_json::to_string(schema).map_err(|e| format!("Failed to serialize schema: {}", e))?;
44
45 {
46 let schemas = self.schemas.read().unwrap();
47 if let Some(validator) = schemas.get(&key) {
48 return Ok(Arc::clone(validator));
49 }
50 }
51
52 let validator = Arc::new(SchemaValidator::new(schema.clone())?);
53
54 {
55 let mut schemas = self.schemas.write().unwrap();
56 if let Some(existing) = schemas.get(&key) {
57 return Ok(Arc::clone(existing));
58 }
59 schemas.insert(key, Arc::clone(&validator));
60 }
61
62 Ok(validator)
63 }
64
65 pub fn all_schemas(&self) -> Vec<Arc<SchemaValidator>> {
70 let schemas = self.schemas.read().unwrap();
71 schemas.values().cloned().collect()
72 }
73
74 pub fn schema_count(&self) -> usize {
78 let schemas = self.schemas.read().unwrap();
79 schemas.len()
80 }
81}
82
83impl Default for SchemaRegistry {
84 fn default() -> Self {
85 Self::new()
86 }
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::thread;

    #[test]
    fn test_schema_deduplication() {
        let registry = SchemaRegistry::new();

        // Two separately built but identical schemas must share one compiled validator.
        let schema1 = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        });

        let schema2 = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        });

        let validator1 = registry.get_or_compile(&schema1).unwrap();
        let validator2 = registry.get_or_compile(&schema2).unwrap();

        assert!(Arc::ptr_eq(&validator1, &validator2));

        assert_eq!(registry.schema_count(), 1);
    }

    #[test]
    fn test_different_schemas() {
        let registry = SchemaRegistry::new();

        let schema1 = json!({
            "type": "string"
        });

        let schema2 = json!({
            "type": "integer"
        });

        let validator1 = registry.get_or_compile(&schema1).unwrap();
        let validator2 = registry.get_or_compile(&schema2).unwrap();

        assert!(!Arc::ptr_eq(&validator1, &validator2));

        assert_eq!(registry.schema_count(), 2);
    }

    #[test]
    fn test_all_schemas() {
        let registry = SchemaRegistry::new();

        let schema1 = json!({"type": "string"});
        let schema2 = json!({"type": "integer"});

        registry.get_or_compile(&schema1).unwrap();
        registry.get_or_compile(&schema2).unwrap();

        let all = registry.all_schemas();
        assert_eq!(all.len(), 2);
    }

    #[test]
    fn test_new_registry_is_empty() {
        let registry = SchemaRegistry::default();

        assert_eq!(registry.schema_count(), 0);
        assert!(registry.all_schemas().is_empty());
    }

    #[test]
    fn test_concurrent_access() {
        let registry = Arc::new(SchemaRegistry::new());
        let schema = json!({
            "type": "object",
            "properties": {
                "id": {"type": "integer"}
            }
        });

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let registry = Arc::clone(&registry);
                let schema = schema.clone();
                thread::spawn(move || registry.get_or_compile(&schema).unwrap())
            })
            .collect();

        let validators: Vec<_> = handles.into_iter().map(|h| h.join().unwrap()).collect();

        // Every thread must observe the same cached instance, even if several compiled it.
        let first = &validators[0];
        assert!(validators.iter().all(|v| Arc::ptr_eq(first, v)));

        assert_eq!(registry.schema_count(), 1);
    }
}