singleton_registry/
registry_trait.rs

1//! Core trait defining registry behavior.
2//!
3//! This module provides the `RegistryApi` trait with default implementations for
4//! type-safe registration, retrieval, and tracing of singleton instances.
5//!
6//! The registry is type-based: each type (`TypeId`) can have exactly one instance stored.
7//! Registering a value of the same type will replace the previous instance.
8
9use std::any::{Any, TypeId};
10use std::collections::HashMap;
11use std::sync::{Arc, LazyLock, Mutex};
12
13use crate::{RegistryError, RegistryEvent};
14
15/// Type alias for the trace callback storage.
16///
17/// Note: This type is also defined in the `define_registry!` macro.
18/// Keep both definitions in sync.
19type TraceCallback = LazyLock<Mutex<Option<Arc<dyn Fn(&RegistryEvent) + Send + Sync>>>>;
20
21/// Core trait defining registry behavior.
22///
23/// Provides default implementations for all registry operations, requiring only
24/// two accessor methods (`storage` and `trace`) to be implemented by the implementor.
25///
26/// The registry stores singleton instances indexed by their type (`TypeId`).
27/// Each type can have at most one instance stored at any given time.
28pub trait RegistryApi {
29    // -------------------------------------------------------------------------------------------------
30    // Tracing
31    // -------------------------------------------------------------------------------------------------
32
33    /// Access the trace callback static.
34    ///
35    /// This method must be implemented to provide access to the registry's trace callback.
36    fn trace() -> &'static TraceCallback;
37
38    /// Set a tracing callback for registry operations.
39    ///
40    /// The callback will be invoked for every registry operation (register, get, contains).
41    ///
42    /// # Lock Poisoning Recovery
43    ///
44    /// If the trace lock is poisoned (due to a panic while holding the lock),
45    /// this method automatically recovers by extracting the inner value.
46    /// This is safe because trace operations are non-critical and idempotent.
47    ///
48    /// # Safety Restrictions
49    ///
50    /// The callback must NOT call any registry methods on the same registry,
51    /// as this will cause a deadlock. The callback is invoked while holding
52    /// the trace lock.
53    fn set_trace_callback(&self, callback: impl Fn(&RegistryEvent) + Send + Sync + 'static) {
54        let mut guard = Self::trace().lock().unwrap_or_else(|p| p.into_inner());
55        *guard = Some(Arc::new(callback));
56    }
57
58    /// Clear the tracing callback.
59    ///
60    /// After calling this, no tracing events will be emitted.
61    /// Note: This does not affect registered values, only the tracing callback.
62    ///
63    /// # Lock Poisoning Recovery
64    ///
65    /// If the trace lock is poisoned, this method automatically recovers.
66    fn clear_trace_callback(&self) {
67        let mut guard = Self::trace().lock().unwrap_or_else(|p| p.into_inner());
68        *guard = None;
69    }
70
71    /// Convenience wrapper to emit a registry event using the current callback.
72    ///
73    /// If a trace callback is set, this method will invoke it with the provided event.
74    ///
75    /// # Lock Poisoning Recovery
76    ///
77    /// Lock poisoning is automatically recovered by extracting the inner value.
78    ///
79    /// # Panics
80    ///
81    /// If the callback itself panics, the panic will propagate to the caller.
82    /// The registry lock is not held during callback execution, so this won't
83    /// poison the registry storage.
84    fn emit_event(&self, event: &RegistryEvent) {
85        let guard = Self::trace().lock().unwrap_or_else(|p| p.into_inner());
86        if let Some(callback) = guard.as_ref() {
87            callback(event);
88        }
89    }
90
91    // -------------------------------------------------------------------------------------------------
92    // Registry
93    // -------------------------------------------------------------------------------------------------
94
95    /// Access the storage static.
96    ///
97    /// This method must be implemented to provide access to the registry's storage.
98    fn storage() -> &'static LazyLock<Mutex<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>;
99
100    /// Register a value in the registry.
101    ///
102    /// Takes ownership of the value and wraps it in an `Arc` automatically.
103    /// If a value of the same type is already registered, it will be replaced.
104    ///
105    /// # Design Note
106    ///
107    /// This method does not return a `Result` because registration is designed
108    /// for the "write-once" pattern during application startup (and rarely at runtime for rewrite). Lock poisoning
109    /// is automatically recovered. If registration must succeed, ensure your
110    /// application initialization doesn't panic while holding registry locks.
111    fn register<T: Send + Sync + 'static>(&self, value: T) {
112        self.register_arc(Arc::new(value));
113    }
114
115    /// Register an Arc-wrapped value in the registry.
116    ///
117    /// More efficient than `register` when you already have an `Arc`,
118    /// as it avoids creating an additional reference count.
119    ///
120    /// # Lock Poisoning Recovery
121    ///
122    /// If the storage lock is poisoned, this method automatically recovers.
123    /// This is safe because the insert operation is idempotent.
124    fn register_arc<T: Send + Sync + 'static>(&self, value: Arc<T>) {
125        self.emit_event(&RegistryEvent::Register {
126            type_name: std::any::type_name::<T>(),
127        });
128
129        // Register the value
130        Self::storage()
131            .lock()
132            .unwrap_or_else(|p| p.into_inner())
133            .insert(TypeId::of::<T>(), value);
134    }
135
136    /// Retrieve a value from the registry.
137    ///
138    /// Returns `Ok(Arc<T>)` if the type is found.
139    ///
140    /// # Errors
141    ///
142    /// - Type `T` is not found in the registry
143    /// - Type mismatch (extremely rare)
144    /// - Registry lock is poisoned
145    fn get<T: Send + Sync + 'static>(&self) -> Result<Arc<T>, RegistryError> {
146        let map = Self::storage()
147            .lock()
148            .map_err(|_| RegistryError::RegistryLock)?;
149
150        let any_arc_opt = map.get(&TypeId::of::<T>()).cloned();
151
152        drop(map);
153
154        let result: Result<Arc<T>, RegistryError> = match any_arc_opt {
155            Some(any_arc) => any_arc
156                .downcast::<T>()
157                .map_err(|_| RegistryError::TypeMismatch {
158                    type_name: std::any::type_name::<T>(),
159                }),
160            None => Err(RegistryError::TypeNotFound {
161                type_name: std::any::type_name::<T>(),
162            }),
163        };
164
165        self.emit_event(&RegistryEvent::Get {
166            type_name: std::any::type_name::<T>(),
167            found: result.is_ok(),
168        });
169
170        result
171    }
172
173    /// Retrieve a cloned value from the registry.
174    ///
175    /// Returns an owned value by cloning the value stored in the registry.
176    /// The type `T` must implement `Clone`.
177    ///
178    /// # Errors
179    ///
180    /// - Type `T` is not found in the registry
181    /// - Type mismatch
182    fn get_cloned<T: Send + Sync + Clone + 'static>(&self) -> Result<T, RegistryError> {
183        let arc = self.get::<T>()?;
184        Ok((*arc).clone())
185    }
186
187    /// Check if a type is registered in the registry.
188    ///
189    /// Returns `Ok(true)` if the type is registered, `Ok(false)` if not found.
190    ///
191    /// # Errors
192    ///
193    /// - Registry lock is poisoned
194    fn contains<T: Send + Sync + 'static>(&self) -> Result<bool, RegistryError> {
195        let found = Self::storage()
196            .lock()
197            .map(|m| m.contains_key(&TypeId::of::<T>()))
198            .map_err(|_| RegistryError::RegistryLock)?;
199
200        self.emit_event(&RegistryEvent::Contains {
201            type_name: std::any::type_name::<T>(),
202            found,
203        });
204
205        Ok(found)
206    }
207
208    // EDUCATIONAL: Memory leak demonstration (commented out)
209    //
210    // This method demonstrates a common pitfall when working with Arc::into_raw().
211    // It leaks memory because the Arc reference count is never decremented.
212    // Every call to this method leaks one Arc reference permanently.
213    //
214    // #[doc(hidden)]
215    // fn get_ref<T: Send + Sync + Clone + 'static>(&self) -> Result<&'static T, RegistryError> {
216    //     let arc = self.get::<T>()?;
217    //     let ptr = Arc::into_raw(arc);  // ⚠️ MEMORY LEAK: Arc is never freed
218    //     Ok(unsafe { &*ptr })
219    // }
220
221    /// Clear all registered values from the registry.
222    ///
223    /// This method is primarily intended for testing. It removes all registered
224    /// values but does NOT affect:
225    /// - Already-retrieved `Arc<T>` references (they remain valid)
226    /// - The tracing callback (use `clear_trace_callback()` to clear that)
227    ///
228    /// # Lock Poisoning Recovery
229    ///
230    /// If the storage lock is poisoned, this method silently fails.
231    /// This is acceptable for a test-only method.
232    #[doc(hidden)]
233    fn clear(&self) {
234        self.emit_event(&RegistryEvent::Clear {});
235
236        if let Ok(mut registry) = Self::storage().lock() {
237            registry.clear();
238        }
239    }
240}
241
242// -------------------------------------------------------------------------------------------------
243// Tests
244// -------------------------------------------------------------------------------------------------
245
#[cfg(test)]
mod tests {
    use crate::RegistryError;

    use super::{RegistryApi, TraceCallback};

    use serial_test::serial;
    use std::any::{Any, TypeId};
    use std::collections::HashMap;
    use std::sync::{Arc, LazyLock, Mutex};

    static STORAGE: LazyLock<Mutex<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>> =
        LazyLock::new(|| Mutex::new(HashMap::new()));

    static TRACE: TraceCallback = LazyLock::new(|| Mutex::new(None));

    struct Api;

    impl RegistryApi for Api {
        fn storage() -> &'static LazyLock<Mutex<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>> {
            &STORAGE
        }

        fn trace() -> &'static TraceCallback {
            &TRACE
        }
    }

    const API: Api = Api;

    // ---------------------------------------------------------------------------------------------
    // Trace helpers
    // ---------------------------------------------------------------------------------------------

    /// Uninstalls the trace callback when dropped.
    ///
    /// Holding one of these for the whole test means a failing assertion, which
    /// unwinds past any explicit cleanup at the end of the test, still removes the
    /// callback instead of leaking it into the next test.
    struct TraceReset;

    impl Drop for TraceReset {
        fn drop(&mut self) {
            API.clear_trace_callback();
        }
    }

    /// Installs a callback that records the `Display` form of every event.
    ///
    /// Returns the shared event log and a guard that uninstalls the callback when it
    /// goes out of scope. Bind the guard to a named variable (e.g. `_trace`), not `_`,
    /// or it is dropped immediately.
    fn capture_events() -> (Arc<Mutex<Vec<String>>>, TraceReset) {
        let events = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&events);
        API.set_trace_callback(move |event| {
            // Tolerate poisoning: a failed assertion made while the log is locked must
            // not turn every later traced operation into a second panic.
            sink.lock()
                .unwrap_or_else(|p| p.into_inner())
                .push(event.to_string());
        });
        (events, TraceReset)
    }

    // ---------------------------------------------------------------------------------------------
    // Registry
    // ---------------------------------------------------------------------------------------------

    #[test]
    #[serial]
    fn test_register_and_get_primitive() -> Result<(), RegistryError> {
        // Clear any previous state
        API.clear();

        // Register a primitive type
        API.register(42i32);

        // Retrieve it 1
        let num: Arc<i32> = API.get()?;
        assert_eq!(*num, 42);

        // Retrieve it 2
        let num_2 = API.get::<i32>()?;
        assert_eq!(*num_2, 42);

        Ok(())
    }

    #[test]
    #[serial]
    fn test_register_and_get_string() {
        // Clear the registry before the test
        API.clear();

        // Create and register a string
        let s = "test".to_string();
        API.register(s.clone());

        // Retrieve it and verify
        let retrieved: Arc<String> = API.get().unwrap();
        assert_eq!(&*retrieved, &s);

        // Clear the registry after the test
        API.clear();
    }

    #[test]
    #[serial]
    fn test_get_nonexistent() {
        API.clear();

        let result: Result<Arc<String>, RegistryError> = API.get();
        assert_eq!(
            result.unwrap_err(),
            RegistryError::TypeNotFound {
                type_name: "alloc::string::String"
            }
        );
    }

    #[test]
    #[serial]
    fn test_thread_safety() {
        API.clear();

        use std::sync::{mpsc, Barrier};
        use std::thread;

        let barrier = Arc::new(Barrier::new(2));
        let (main_tx, thread_rx) = mpsc::channel();
        let (thread_tx, main_rx) = mpsc::channel();

        let barrier_clone = barrier.clone();
        let handle = thread::spawn(move || {
            API.register(100u32);
            thread_tx.send(100u32).unwrap();

            // Wait for the main thread to register its value
            let main_value: String = thread_rx.recv().unwrap();

            // Synchronize: ensure both threads have registered before retrieval
            barrier_clone.wait();

            let s: Arc<String> = API.get().unwrap();
            assert_eq!(&*s, &main_value);
        });

        let thread_value = main_rx.recv().unwrap();
        let num: Arc<u32> = API.get().unwrap();
        assert_eq!(*num, thread_value);

        // Register a string in main thread
        let main_string = "main_thread_value".to_string();
        API.register(main_string.clone());
        main_tx.send(main_string.clone()).unwrap();

        // Synchronize: ensure both threads have registered before retrieval
        barrier.wait();

        handle.join().unwrap();
        API.clear();
    }

    #[test]
    #[serial]
    fn test_multiple_types() {
        API.clear();

        // Define wrapper types to ensure unique TypeIds
        #[derive(Debug, PartialEq, Eq, Clone)]
        struct Num(i32);
        #[derive(Debug, PartialEq, Eq, Clone)]
        struct Text(String);
        #[derive(Debug, PartialEq, Eq, Clone)]
        struct Numbers(Vec<i32>);

        // Create the values
        let num_val = Num(42);
        let text_val = Text("hello".to_string());
        let nums_val = Numbers(vec![1, 2, 3]);

        // Register all types first
        API.register(num_val.clone());
        API.register(text_val.clone());
        API.register(nums_val.clone());

        // Then retrieve and verify each one
        let num: Arc<Num> = API.get().unwrap();
        assert_eq!(num.0, num_val.0);

        let text: Arc<Text> = API.get().unwrap();
        assert_eq!(text.0, text_val.0);

        let nums: Arc<Numbers> = API.get().unwrap();
        assert_eq!(&nums.0, &nums_val.0);

        // Clear the registry after the test
        API.clear();
    }

    #[test]
    #[serial]
    fn test_custom_type() {
        API.clear();

        #[derive(Debug, PartialEq, Eq, Clone)]
        struct MyStruct {
            field: String,
        }

        let my_value = MyStruct {
            field: "test".into(),
        };
        API.register(my_value.clone());

        let retrieved: Arc<MyStruct> = API.get().unwrap();
        assert_eq!(&*retrieved, &my_value);
    }

    #[test]
    #[serial]
    fn test_tuple_type() -> Result<(), RegistryError> {
        API.clear();

        let tuple = (1, "test");
        API.register(tuple);

        let retrieved = API.get::<(i32, &str)>()?;
        assert_eq!(&*retrieved, &tuple);

        Ok(())
    }

    #[test]
    #[serial]
    fn test_overwrite_same_type() {
        API.clear();

        API.register(10i32);
        API.register(20i32); // should replace

        let num: Arc<i32> = API.get().unwrap();
        assert_eq!(*num, 20);
    }

    #[test]
    #[serial]
    fn test_get_cloned() {
        API.clear();
        API.register("hello".to_string());
        let value: String = API.get_cloned::<String>().unwrap();
        assert_eq!(value, "hello");
    }

    #[test]
    #[serial]
    fn test_get_cloned_missing_type() {
        API.clear();
        assert_eq!(
            API.get_cloned::<String>().unwrap_err(),
            RegistryError::TypeNotFound {
                type_name: "alloc::string::String"
            }
        );
    }

    // EDUCATIONAL: Memory leak test (commented out)
    //
    // Exercises the commented-out `get_ref()` in `RegistryApi`. Uncomment both to try it.
    //
    // #[test]
    // #[serial]
    // fn test_get_ref() {
    //     API.clear();
    //     API.register("world".to_string());
    //     let value: &'static String = API.get_ref::<String>().unwrap();
    //     assert_eq!(value, "world");
    //
    //     // NOTE: `clear()` only releases the registry's own reference. The strong
    //     // count leaked by `Arc::into_raw` keeps the `String` alive, so `value`
    //     // is still valid afterwards (no use-after-free), but the allocation is
    //     // never freed. The leak is exactly what makes the `&'static` lifetime
    //     // sound. If that count were ever released (e.g. via `Arc::from_raw`),
    //     // reading `value` after `clear()` would be a use-after-free (UB).
    //     // API.clear();
    //     // assert_eq!(value, "world");
    // }

    #[test]
    #[serial]
    fn test_contains() {
        API.clear();
        assert!(!API.contains::<u32>().unwrap());
        API.register(1u32);
        assert!(API.contains::<u32>().unwrap());
    }

    #[test]
    #[serial]
    fn test_clear_removes_values() {
        API.clear();
        API.register(1u8);
        API.register("value".to_string());
        let held: Arc<u8> = API.get().unwrap();

        API.clear();

        assert!(!API.contains::<u8>().unwrap());
        assert!(!API.contains::<String>().unwrap());
        // Already-retrieved handles stay valid after the registry is cleared.
        assert_eq!(*held, 1);
    }

    #[test]
    #[serial]
    fn test_function_pointer_registration() {
        API.clear();

        // Test the function pointer example from README
        let multiply_by_two: fn(i32) -> i32 = |x| x * 2;
        API.register(multiply_by_two);

        let doubler: Arc<fn(i32) -> i32> = API.get().unwrap();
        let result = doubler(21);
        assert_eq!(result, 42);
    }

    // ---------------------------------------------------------------------------------------------
    // Tracing
    // ---------------------------------------------------------------------------------------------

    #[test]
    #[serial]
    fn test_trace_callback_register_event() {
        API.clear();
        let (events, _trace) = capture_events();

        API.register(5u8);

        let captured = events.lock().unwrap();
        assert_eq!(*captured, ["register { type_name: u8 }"]);
    }

    #[test]
    #[serial]
    fn test_trace_callback_get_event() {
        API.clear();
        let (events, _trace) = capture_events();

        API.register(42i32);
        let _ = API.get::<i32>();

        let captured = events.lock().unwrap();
        assert_eq!(
            *captured,
            [
                "register { type_name: i32 }",
                "get { type_name: i32, found: true }"
            ]
        );
    }

    #[test]
    #[serial]
    fn test_trace_callback_get_missing_event() {
        API.clear();
        let (events, _trace) = capture_events();

        let _ = API.get::<i64>();

        let captured = events.lock().unwrap();
        assert_eq!(*captured, ["get { type_name: i64, found: false }"]);
    }

    #[test]
    #[serial]
    fn test_trace_callback_contains_event() {
        API.clear();
        let (events, _trace) = capture_events();

        let _ = API.contains::<String>();
        API.register("test".to_string());
        let _ = API.contains::<String>();

        let captured = events.lock().unwrap();
        assert_eq!(
            *captured,
            [
                "contains { type_name: alloc::string::String, found: false }",
                "register { type_name: alloc::string::String }",
                "contains { type_name: alloc::string::String, found: true }"
            ]
        );
    }

    #[test]
    #[serial]
    fn test_trace_callback_clear_event() {
        API.clear();
        let (events, _trace) = capture_events();

        API.clear();

        let captured = events.lock().unwrap();
        assert_eq!(*captured, ["Clearing the Registry"]);
    }

    #[test]
    #[serial]
    fn test_clear_trace_callback_stops_events() {
        API.clear();
        let (events, _trace) = capture_events();

        API.register(10u16);

        // Verify event was captured
        {
            let captured = events.lock().unwrap();
            assert_eq!(*captured, ["register { type_name: u16 }"]);
        }

        // Clear the callback
        API.clear_trace_callback();

        // Perform more operations - these should NOT be traced
        API.register(20u16);
        let _ = API.get::<u16>();
        let _ = API.contains::<u16>();

        // Verify no new events were captured
        let captured = events.lock().unwrap();
        assert_eq!(captured.len(), 1); // Still only the first event
    }

    #[test]
    #[serial]
    fn test_register_arc_directly() {
        API.clear();
        let value = Arc::new(42i32);
        let clone = value.clone();
        API.register_arc(value);

        let retrieved: Arc<i32> = API.get().unwrap();
        assert_eq!(*retrieved, 42);
        assert_eq!(Arc::strong_count(&clone), 3); // clone + registry + retrieved
    }
}
653}