1use std::any::{Any, TypeId};
10use std::collections::HashMap;
11use std::sync::{Arc, LazyLock, Mutex};
12
13use crate::{RegistryError, RegistryEvent};
14
15type TraceCallback = LazyLock<Mutex<Option<Arc<dyn Fn(&RegistryEvent) + Send + Sync>>>>;
20
21pub trait RegistryApi {
29 fn trace() -> &'static TraceCallback;
37
38 fn set_trace_callback(&self, callback: impl Fn(&RegistryEvent) + Send + Sync + 'static) {
54 let mut guard = Self::trace().lock().unwrap_or_else(|p| p.into_inner());
55 *guard = Some(Arc::new(callback));
56 }
57
58 fn clear_trace_callback(&self) {
67 let mut guard = Self::trace().lock().unwrap_or_else(|p| p.into_inner());
68 *guard = None;
69 }
70
71 fn emit_event(&self, event: &RegistryEvent) {
85 let guard = Self::trace().lock().unwrap_or_else(|p| p.into_inner());
86 if let Some(callback) = guard.as_ref() {
87 callback(event);
88 }
89 }
90
91 fn storage() -> &'static LazyLock<Mutex<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>;
99
100 fn register<T: Send + Sync + 'static>(&self, value: T) {
112 self.register_arc(Arc::new(value));
113 }
114
115 fn register_arc<T: Send + Sync + 'static>(&self, value: Arc<T>) {
125 self.emit_event(&RegistryEvent::Register {
126 type_name: std::any::type_name::<T>(),
127 });
128
129 Self::storage()
131 .lock()
132 .unwrap_or_else(|p| p.into_inner())
133 .insert(TypeId::of::<T>(), value);
134 }
135
136 fn get<T: Send + Sync + 'static>(&self) -> Result<Arc<T>, RegistryError> {
146 let map = Self::storage()
147 .lock()
148 .map_err(|_| RegistryError::RegistryLock)?;
149
150 let any_arc_opt = map.get(&TypeId::of::<T>()).cloned();
151
152 drop(map);
153
154 let result: Result<Arc<T>, RegistryError> = match any_arc_opt {
155 Some(any_arc) => any_arc
156 .downcast::<T>()
157 .map_err(|_| RegistryError::TypeMismatch {
158 type_name: std::any::type_name::<T>(),
159 }),
160 None => Err(RegistryError::TypeNotFound {
161 type_name: std::any::type_name::<T>(),
162 }),
163 };
164
165 self.emit_event(&RegistryEvent::Get {
166 type_name: std::any::type_name::<T>(),
167 found: result.is_ok(),
168 });
169
170 result
171 }
172
173 fn get_cloned<T: Send + Sync + Clone + 'static>(&self) -> Result<T, RegistryError> {
183 let arc = self.get::<T>()?;
184 Ok((*arc).clone())
185 }
186
187 fn contains<T: Send + Sync + 'static>(&self) -> Result<bool, RegistryError> {
195 let found = Self::storage()
196 .lock()
197 .map(|m| m.contains_key(&TypeId::of::<T>()))
198 .map_err(|_| RegistryError::RegistryLock)?;
199
200 self.emit_event(&RegistryEvent::Contains {
201 type_name: std::any::type_name::<T>(),
202 found,
203 });
204
205 Ok(found)
206 }
207
208 #[doc(hidden)]
233 fn clear(&self) {
234 self.emit_event(&RegistryEvent::Clear {});
235
236 if let Ok(mut registry) = Self::storage().lock() {
237 registry.clear();
238 }
239 }
240}
241
#[cfg(test)]
mod tests {
    use crate::RegistryError;

    use super::{RegistryApi, TraceCallback};

    use serial_test::serial;
    use std::any::{type_name, Any, TypeId};
    use std::collections::HashMap;
    use std::sync::{Arc, LazyLock, Mutex};

    static STORAGE: LazyLock<Mutex<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>> =
        LazyLock::new(|| Mutex::new(HashMap::new()));

    static TRACE: TraceCallback = LazyLock::new(|| Mutex::new(None));

    /// Test implementor of `RegistryApi` backed by this module's own statics.
    struct Api;

    impl RegistryApi for Api {
        fn storage() -> &'static LazyLock<Mutex<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>> {
            &STORAGE
        }

        fn trace() -> &'static TraceCallback {
            &TRACE
        }
    }

    const API: Api = Api;

    /// Calls `clear_trace_callback` when dropped, so a test that panics part-way
    /// through cannot leave its callback installed for the next test.
    struct TraceReset;

    impl Drop for TraceReset {
        fn drop(&mut self) {
            API.clear_trace_callback();
        }
    }

    /// Installs a trace callback that appends each event's `Display` text to a
    /// shared log, and returns the log together with a `TraceReset` guard.
    ///
    /// Bind the guard to a named variable such as `_reset` (not `_`, which
    /// drops it immediately) so the callback stays installed for the test.
    fn record_events() -> (Arc<Mutex<Vec<String>>>, TraceReset) {
        let events = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&events);
        API.set_trace_callback(move |event| {
            sink.lock().unwrap().push(event.to_string());
        });
        (events, TraceReset)
    }

    #[test]
    #[serial]
    fn test_register_and_get_primitive() -> Result<(), RegistryError> {
        API.clear();

        API.register(42i32);

        let num: Arc<i32> = API.get()?;
        assert_eq!(*num, 42);

        let num_2 = API.get::<i32>()?;
        assert_eq!(*num_2, 42);

        Ok(())
    }

    #[test]
    #[serial]
    fn test_register_and_get_string() {
        API.clear();

        let s = "test".to_string();
        API.register(s.clone());

        let retrieved: Arc<String> = API.get().unwrap();
        assert_eq!(&*retrieved, &s);

        API.clear();
    }

    #[test]
    #[serial]
    fn test_get_nonexistent() {
        API.clear();

        let result: Result<Arc<String>, RegistryError> = API.get();
        // `type_name` output is not guaranteed stable, so compare against it
        // rather than a hard-coded path.
        assert_eq!(
            result.unwrap_err(),
            RegistryError::TypeNotFound {
                type_name: type_name::<String>()
            }
        );
    }

    #[test]
    #[serial]
    fn test_thread_safety() {
        use std::sync::{mpsc, Barrier};
        use std::thread;

        API.clear();

        let barrier = Arc::new(Barrier::new(2));
        let (main_tx, thread_rx) = mpsc::channel();
        let (thread_tx, main_rx) = mpsc::channel();

        let barrier_clone = Arc::clone(&barrier);
        let handle = thread::spawn(move || {
            API.register(100u32);
            thread_tx.send(100u32).unwrap();

            let main_value: String = thread_rx.recv().unwrap();

            barrier_clone.wait();

            let s: Arc<String> = API.get().unwrap();
            assert_eq!(&*s, &main_value);
        });

        let thread_value = main_rx.recv().unwrap();
        let num: Arc<u32> = API.get().unwrap();
        assert_eq!(*num, thread_value);

        let main_string = "main_thread_value".to_string();
        API.register(main_string.clone());
        main_tx.send(main_string).unwrap();

        barrier.wait();

        handle.join().unwrap();
        API.clear();
    }

    #[test]
    #[serial]
    fn test_multiple_types() {
        API.clear();

        #[derive(Debug, PartialEq, Eq, Clone)]
        struct Num(i32);
        #[derive(Debug, PartialEq, Eq, Clone)]
        struct Text(String);
        #[derive(Debug, PartialEq, Eq, Clone)]
        struct Numbers(Vec<i32>);

        let num_val = Num(42);
        let text_val = Text("hello".to_string());
        let nums_val = Numbers(vec![1, 2, 3]);

        API.register(num_val.clone());
        API.register(text_val.clone());
        API.register(nums_val.clone());

        let num: Arc<Num> = API.get().unwrap();
        assert_eq!(num.0, num_val.0);

        let text: Arc<Text> = API.get().unwrap();
        assert_eq!(text.0, text_val.0);

        let nums: Arc<Numbers> = API.get().unwrap();
        assert_eq!(&nums.0, &nums_val.0);

        API.clear();
    }

    #[test]
    #[serial]
    fn test_custom_type() {
        API.clear();

        #[derive(Debug, PartialEq, Eq, Clone)]
        struct MyStruct {
            field: String,
        }

        let my_value = MyStruct {
            field: "test".into(),
        };
        API.register(my_value.clone());

        let retrieved: Arc<MyStruct> = API.get().unwrap();
        assert_eq!(&*retrieved, &my_value);
    }

    #[test]
    #[serial]
    fn test_tuple_type() -> Result<(), RegistryError> {
        API.clear();

        let tuple = (1, "test");
        API.register(tuple);

        let retrieved = API.get::<(i32, &str)>()?;
        assert_eq!(&*retrieved, &tuple);

        Ok(())
    }

    #[test]
    #[serial]
    fn test_overwrite_same_type() {
        API.clear();

        API.register(10i32);
        API.register(20i32);

        let num: Arc<i32> = API.get().unwrap();
        assert_eq!(*num, 20);
    }

    #[test]
    #[serial]
    fn test_get_cloned() {
        API.clear();
        API.register("hello".to_string());
        let value: String = API.get_cloned::<String>().unwrap();
        assert_eq!(value, "hello");
    }

    #[test]
    #[serial]
    fn test_contains() {
        API.clear();
        assert!(!API.contains::<u32>().unwrap());
        API.register(1u32);
        assert!(API.contains::<u32>().unwrap());
    }

    #[test]
    #[serial]
    fn test_function_pointer_registration() {
        API.clear();

        let multiply_by_two: fn(i32) -> i32 = |x| x * 2;
        API.register(multiply_by_two);

        let doubler: Arc<fn(i32) -> i32> = API.get().unwrap();
        assert_eq!(doubler(21), 42);
    }

    #[test]
    #[serial]
    fn test_trace_callback_register_event() {
        API.clear();
        let (events, _reset) = record_events();

        API.register(5u8);

        assert_eq!(*events.lock().unwrap(), ["register { type_name: u8 }"]);
    }

    #[test]
    #[serial]
    fn test_trace_callback_get_event() {
        API.clear();
        let (events, _reset) = record_events();

        API.register(42i32);
        let _ = API.get::<i32>();

        assert_eq!(
            *events.lock().unwrap(),
            [
                "register { type_name: i32 }",
                "get { type_name: i32, found: true }"
            ]
        );
    }

    #[test]
    #[serial]
    fn test_trace_callback_contains_event() {
        API.clear();
        let (events, _reset) = record_events();

        let _ = API.contains::<String>();
        API.register("test".to_string());
        let _ = API.contains::<String>();

        let string = type_name::<String>();
        assert_eq!(
            *events.lock().unwrap(),
            [
                format!("contains {{ type_name: {string}, found: false }}"),
                format!("register {{ type_name: {string} }}"),
                format!("contains {{ type_name: {string}, found: true }}"),
            ]
        );
    }

    #[test]
    #[serial]
    fn test_trace_callback_clear_event() {
        API.clear();
        let (events, _reset) = record_events();

        API.clear();

        assert_eq!(*events.lock().unwrap(), ["Clearing the Registry"]);
    }

    #[test]
    #[serial]
    fn test_clear_trace_callback_stops_events() {
        API.clear();
        let (events, _reset) = record_events();

        API.register(10u16);
        assert_eq!(*events.lock().unwrap(), ["register { type_name: u16 }"]);

        API.clear_trace_callback();

        API.register(20u16);
        let _ = API.get::<u16>();
        let _ = API.contains::<u16>();

        // Nothing recorded after the callback was removed.
        assert_eq!(events.lock().unwrap().len(), 1);
    }

    #[test]
    #[serial]
    fn test_register_arc_directly() {
        API.clear();
        let value = Arc::new(42i32);
        let clone = Arc::clone(&value);
        API.register_arc(value);

        let retrieved: Arc<i32> = API.get().unwrap();
        assert_eq!(*retrieved, 42);
        // The registry keeps the same allocation rather than a copy: one
        // reference each for the registry entry, `clone` and `retrieved`.
        assert_eq!(Arc::strong_count(&clone), 3);
    }
}