1use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::sync::Arc;
9
/// Source of default values for the fields of one schema.
///
/// Implementations are shared as `Arc<dyn DefaultProvider>`; the `Send + Sync`
/// supertraits allow such handles to be used from multiple threads.
pub trait DefaultProvider: Send + Sync {
    /// Encoded default for `field`, or `None` when this provider has no default for it.
    ///
    /// `type_name` is the type the caller expects; an implementation may use it to
    /// select an encoding or may ignore it entirely.
    fn default_bytes(&self, field: &str, type_name: &str) -> Option<Vec<u8>>;

    /// Human-readable form of the default for `field`, or `None` when there is none.
    fn default_string(&self, field: &str, type_name: &str) -> Option<String>;

    /// `true` when this provider holds a default for `field`.
    fn has_default(&self, field: &str) -> bool;

    /// Names of every field this provider holds a default for (order is up to the
    /// implementation).
    fn fields_with_defaults(&self) -> Vec<String>;
}
39
/// A [`DefaultProvider`] backed by an in-memory map from field name to default value.
///
/// Build one with [`MapDefaultProvider::new`] followed by chained `with_*` calls.
/// Adding a default for a field that already has one replaces the earlier value.
#[derive(Debug, Clone, Default)]
pub struct MapDefaultProvider {
    /// field name -> (type name, encoded bytes, human-readable representation)
    defaults: HashMap<String, (String, Vec<u8>, String)>,
}

impl MapDefaultProvider {
    /// Creates a provider with no field defaults.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds a `String` default: the bytes are the UTF-8 encoding of `value` and the
    /// representation is `value` itself.
    #[must_use]
    pub fn with_string(self, field: &str, value: &str) -> Self {
        self.with_bytes(field, "String", value.as_bytes().to_vec(), value)
    }

    /// Adds an `i32` default encoded as 4 little-endian bytes.
    #[must_use]
    pub fn with_i32(self, field: &str, value: i32) -> Self {
        self.with_bytes(field, "i32", value.to_le_bytes().to_vec(), &value.to_string())
    }

    /// Adds an `i64` default encoded as 8 little-endian bytes.
    #[must_use]
    pub fn with_i64(self, field: &str, value: i64) -> Self {
        self.with_bytes(field, "i64", value.to_le_bytes().to_vec(), &value.to_string())
    }

    /// Adds a `u32` default encoded as 4 little-endian bytes.
    #[must_use]
    pub fn with_u32(self, field: &str, value: u32) -> Self {
        self.with_bytes(field, "u32", value.to_le_bytes().to_vec(), &value.to_string())
    }

    /// Adds a `u64` default encoded as 8 little-endian bytes.
    #[must_use]
    pub fn with_u64(self, field: &str, value: u64) -> Self {
        self.with_bytes(field, "u64", value.to_le_bytes().to_vec(), &value.to_string())
    }

    /// Adds an `f32` default encoded as its 4 little-endian IEEE-754 bytes.
    #[must_use]
    pub fn with_f32(self, field: &str, value: f32) -> Self {
        self.with_bytes(field, "f32", value.to_le_bytes().to_vec(), &value.to_string())
    }

    /// Adds an `f64` default encoded as its 8 little-endian IEEE-754 bytes.
    #[must_use]
    pub fn with_f64(self, field: &str, value: f64) -> Self {
        self.with_bytes(field, "f64", value.to_le_bytes().to_vec(), &value.to_string())
    }

    /// Adds a `bool` default encoded as a single byte: `1` for `true`, `0` for `false`.
    #[must_use]
    pub fn with_bool(self, field: &str, value: bool) -> Self {
        self.with_bytes(field, "bool", vec![u8::from(value)], &value.to_string())
    }

    /// Adds a default with a caller-chosen type name, encoding and representation.
    ///
    /// All typed `with_*` helpers delegate here. An existing entry for `field` is
    /// replaced.
    #[must_use]
    pub fn with_bytes(
        mut self,
        field: &str,
        type_name: &str,
        bytes: Vec<u8>,
        repr: &str,
    ) -> Self {
        self.defaults.insert(
            field.to_string(),
            (type_name.to_string(), bytes, repr.to_string()),
        );
        self
    }
}
144
145impl DefaultProvider for MapDefaultProvider {
146 fn default_bytes(&self, field: &str, _type_name: &str) -> Option<Vec<u8>> {
147 self.defaults.get(field).map(|(_, bytes, _)| bytes.clone())
148 }
149
150 fn default_string(&self, field: &str, _type_name: &str) -> Option<String> {
151 self.defaults.get(field).map(|(_, _, repr)| repr.clone())
152 }
153
154 fn has_default(&self, field: &str) -> bool {
155 self.defaults.contains_key(field)
156 }
157
158 fn fields_with_defaults(&self) -> Vec<String> {
159 self.defaults.keys().cloned().collect()
160 }
161}
162
163pub struct DefaultRegistry {
168 providers: RwLock<HashMap<String, Arc<dyn DefaultProvider>>>,
170}
171
172impl DefaultRegistry {
173 pub fn new() -> Self {
175 Self {
176 providers: RwLock::new(HashMap::new()),
177 }
178 }
179
180 pub fn register(&self, schema: &str, provider: impl DefaultProvider + 'static) {
182 let mut providers = self.providers.write();
183 providers.insert(schema.to_string(), Arc::new(provider));
184 }
185
186 pub fn register_shared(&self, schema: &str, provider: Arc<dyn DefaultProvider>) {
188 let mut providers = self.providers.write();
189 providers.insert(schema.to_string(), provider);
190 }
191
192 pub fn get_provider(&self, schema: &str) -> Option<Arc<dyn DefaultProvider>> {
194 let providers = self.providers.read();
195 providers.get(schema).cloned()
196 }
197
198 pub fn get_default(&self, schema: &str, field: &str, type_name: &str) -> Option<Vec<u8>> {
200 self.get_provider(schema)
201 .and_then(|p| p.default_bytes(field, type_name))
202 }
203
204 pub fn get_default_string(&self, schema: &str, field: &str, type_name: &str) -> Option<String> {
206 self.get_provider(schema)
207 .and_then(|p| p.default_string(field, type_name))
208 }
209
210 pub fn has_default(&self, schema: &str, field: &str) -> bool {
212 self.get_provider(schema)
213 .map(|p| p.has_default(field))
214 .unwrap_or(false)
215 }
216
217 pub fn schemas(&self) -> Vec<String> {
219 let providers = self.providers.read();
220 providers.keys().cloned().collect()
221 }
222
223 pub fn unregister(&self, schema: &str) -> bool {
225 let mut providers = self.providers.write();
226 providers.remove(schema).is_some()
227 }
228}
229
230impl Default for DefaultRegistry {
231 fn default() -> Self {
232 Self::new()
233 }
234}
235
/// Description of a single field's default: field name, type name, the textual
/// default expression and, optionally, its pre-encoded bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FieldDefault {
    /// Name of the field the default applies to.
    pub field: String,
    /// Name of the field's type.
    pub type_name: String,
    /// Textual form of the default.
    /// NOTE(review): presumably a source-level expression — confirm.
    pub expression: String,
    /// Encoded default value, if one was attached with [`FieldDefault::with_bytes`].
    pub bytes: Option<Vec<u8>>,
}

impl FieldDefault {
    /// Creates a field default with no encoded bytes attached.
    #[must_use]
    pub fn new(
        field: impl Into<String>,
        type_name: impl Into<String>,
        expression: impl Into<String>,
    ) -> Self {
        Self {
            field: field.into(),
            type_name: type_name.into(),
            expression: expression.into(),
            bytes: None,
        }
    }

    /// Attaches the encoded default value, replacing any bytes set previously.
    #[must_use]
    pub fn with_bytes(mut self, bytes: Vec<u8>) -> Self {
        self.bytes = Some(bytes);
        self
    }
}
273
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn map_provider_string_default() {
        let provider = MapDefaultProvider::new().with_string("currency", "USD");

        assert!(provider.has_default("currency"));
        assert!(!provider.has_default("other"));

        let bytes = provider.default_bytes("currency", "String").unwrap();
        assert_eq!(String::from_utf8(bytes).unwrap(), "USD");

        let repr = provider.default_string("currency", "String").unwrap();
        assert_eq!(repr, "USD");
    }

    #[test]
    fn map_provider_numeric_defaults() {
        let provider = MapDefaultProvider::new()
            .with_i32("count", 42)
            .with_f64("rate", 0.05)
            .with_bool("active", true);

        let count_bytes = provider.default_bytes("count", "i32").unwrap();
        assert_eq!(i32::from_le_bytes(count_bytes.try_into().unwrap()), 42);

        let rate_bytes = provider.default_bytes("rate", "f64").unwrap();
        assert_eq!(f64::from_le_bytes(rate_bytes.try_into().unwrap()), 0.05);

        // Booleans are stored as exactly one byte.
        let active_bytes = provider.default_bytes("active", "bool").unwrap();
        assert_eq!(active_bytes, vec![1]);
    }

    #[test]
    fn map_provider_wide_numeric_defaults() {
        let provider = MapDefaultProvider::new()
            .with_i64("big", -(1i64 << 40))
            .with_u32("small", 7)
            .with_u64("max", u64::MAX)
            .with_f32("ratio", 1.5);

        let big = provider.default_bytes("big", "i64").unwrap();
        assert_eq!(i64::from_le_bytes(big.try_into().unwrap()), -(1i64 << 40));

        let small = provider.default_bytes("small", "u32").unwrap();
        assert_eq!(u32::from_le_bytes(small.try_into().unwrap()), 7);

        let max = provider.default_bytes("max", "u64").unwrap();
        assert_eq!(u64::from_le_bytes(max.try_into().unwrap()), u64::MAX);

        let ratio = provider.default_bytes("ratio", "f32").unwrap();
        assert_eq!(f32::from_le_bytes(ratio.try_into().unwrap()), 1.5);

        assert_eq!(
            provider.default_string("max", "u64").unwrap(),
            u64::MAX.to_string()
        );
        assert_eq!(provider.default_string("ratio", "f32").unwrap(), "1.5");
    }

    #[test]
    fn map_provider_later_value_wins_and_custom_bytes() {
        let provider = MapDefaultProvider::new()
            .with_i32("n", 1)
            .with_i32("n", 2)
            .with_bytes("blob", "Blob", vec![0xde, 0xad], "0xdead");

        assert_eq!(provider.default_string("n", "i32").unwrap(), "2");
        assert_eq!(provider.fields_with_defaults().len(), 2);
        assert_eq!(provider.default_bytes("blob", "Blob"), Some(vec![0xde, 0xad]));
        assert_eq!(provider.default_string("blob", "Blob").unwrap(), "0xdead");
        assert_eq!(provider.default_bytes("absent", "Blob"), None);
        assert_eq!(provider.default_string("absent", "Blob"), None);
    }

    #[test]
    fn registry_register_and_get() {
        let registry = DefaultRegistry::new();

        let provider = MapDefaultProvider::new()
            .with_string("currency", "USD")
            .with_f64("amount", 0.0);

        registry.register("OrderInput@v2", provider);

        assert!(registry.has_default("OrderInput@v2", "currency"));
        assert!(!registry.has_default("OrderInput@v2", "other"));
        assert!(!registry.has_default("OrderInput@v1", "currency"));

        // Pin the exact encoding rather than mere presence.
        let default = registry.get_default("OrderInput@v2", "currency", "String");
        assert_eq!(default, Some(b"USD".to_vec()));

        let repr = registry.get_default_string("OrderInput@v2", "currency", "String");
        assert_eq!(repr, Some("USD".to_string()));
    }

    #[test]
    fn registry_unknown_schema_yields_nothing() {
        let registry = DefaultRegistry::default();

        assert!(registry.get_provider("Missing").is_none());
        assert_eq!(registry.get_default("Missing", "f", "i32"), None);
        assert_eq!(registry.get_default_string("Missing", "f", "i32"), None);
        assert!(!registry.has_default("Missing", "f"));
        assert!(registry.schemas().is_empty());
    }

    #[test]
    fn registry_register_shared_and_replace() {
        let registry = DefaultRegistry::new();
        let shared: std::sync::Arc<dyn DefaultProvider> =
            std::sync::Arc::new(MapDefaultProvider::new().with_bool("active", true));

        registry.register_shared("A", std::sync::Arc::clone(&shared));
        registry.register_shared("B", shared);
        assert_eq!(registry.get_default("A", "active", "bool"), Some(vec![1]));
        assert_eq!(registry.get_default("B", "active", "bool"), Some(vec![1]));

        // Re-registering a schema replaces its provider entirely.
        registry.register("A", MapDefaultProvider::new().with_u32("limit", 10));
        assert!(!registry.has_default("A", "active"));
        assert_eq!(
            registry.get_default_string("A", "limit", "u32"),
            Some("10".to_string())
        );
        assert!(registry.get_provider("B").is_some());
    }

    #[test]
    fn registry_schemas() {
        let registry = DefaultRegistry::new();

        registry.register("Schema1", MapDefaultProvider::new());
        registry.register("Schema2", MapDefaultProvider::new());

        let schemas = registry.schemas();
        assert_eq!(schemas.len(), 2);
        assert!(schemas.contains(&"Schema1".to_string()));
        assert!(schemas.contains(&"Schema2".to_string()));
    }

    #[test]
    fn registry_unregister() {
        let registry = DefaultRegistry::new();

        registry.register("Schema1", MapDefaultProvider::new());
        assert!(registry.unregister("Schema1"));
        assert!(!registry.unregister("Schema1"));
        assert!(registry.schemas().is_empty());
    }

    #[test]
    fn fields_with_defaults() {
        let provider = MapDefaultProvider::new()
            .with_string("a", "value")
            .with_i32("b", 0)
            .with_bool("c", false);

        let fields = provider.fields_with_defaults();
        assert_eq!(fields.len(), 3);
        assert!(fields.contains(&"a".to_string()));
        assert!(fields.contains(&"b".to_string()));
        assert!(fields.contains(&"c".to_string()));
    }

    #[test]
    fn field_default_builder() {
        let fd = FieldDefault::new("currency", "String", "\"USD\"");
        assert_eq!(fd.field, "currency");
        assert_eq!(fd.type_name, "String");
        assert_eq!(fd.expression, "\"USD\"");
        assert!(fd.bytes.is_none());

        let fd = fd.with_bytes(b"USD".to_vec());
        assert_eq!(fd.bytes.as_deref(), Some(&b"USD"[..]));
    }
}