singleton_registry/
registry.rs

1#![allow(dead_code)]
2
3//! A thread-safe dependency injection registry for storing and retrieving global instances.
4//! Currently designed for write-once, read-many pattern.
5//!
6//! This module provides a type-safe way to register and retrieve instances of any type
7//! that implements `Send + Sync + 'static`.
8//!
9//! # Examples
10//!
11//! ```
12//! use singleton_registry::{register, get};
13//! use std::sync::Arc;
14//!
15//! // Register a value
16//! register("Hello, World!".to_string());
17//!
18//! // Retrieve the value
19//! let message: Arc<String> = get().unwrap();
20//! assert_eq!(&*message, "Hello, World!");
21//! ```
22
23use std::{
24    any::{Any, TypeId},
25    collections::HashMap,
26    fmt,
27    sync::{Arc, LazyLock, Mutex},
28};
29
/// Process-wide storage for every registered instance.
///
/// Keys are the `TypeId` of the registered type; values are the type-erased
/// `Arc<dyn Any + Send + Sync>` holding the instance, so a single map can hold values of
/// arbitrary types. The `Mutex`-guarded map is built lazily by `LazyLock` the first time
/// the static is touched.
static GLOBAL_REGISTRY: LazyLock<Mutex<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>> =
    LazyLock::new(Default::default);
36
37// -------------------------------------------------------------------------------------------------
38// Tracing callback support
39// -------------------------------------------------------------------------------------------------
40
/// Events emitted by the dependency-injection registry to the trace callback
/// (installed with [`set_trace_callback`]).
///
/// The `type_name` fields carry `std::any::type_name::<T>()` for the type involved. That
/// string is meant for diagnostics; the standard library does not guarantee its format.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RegistryEvent {
    /// A value was registered via [`register`] or [`register_arc`].
    Register {
        /// Name of the registered type.
        type_name: &'static str,
    },
    /// A value was requested via [`get`] (also reached through [`get_cloned`] and
    /// [`get_ref`], which call it).
    Get {
        /// Name of the requested type.
        type_name: &'static str,
        /// Whether the lookup produced a value of the requested type.
        found: bool,
    },
    /// A [`contains`] check was performed.
    Contains {
        /// Name of the queried type.
        type_name: &'static str,
        /// Whether a value of that type was present.
        found: bool,
    },
    /// [`clear`] was called; emitted before the entries are removed.
    Clear {},
}
60
61impl fmt::Display for RegistryEvent {
62    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63        match self {
64            RegistryEvent::Register { type_name } => {
65                write!(f, "register {{ type_name: {type_name} }}")
66            }
67            RegistryEvent::Get { type_name, found } => {
68                write!(f, "get {{ type_name: {type_name}, found: {found} }}")
69            }
70            RegistryEvent::Contains { type_name, found } => {
71                write!(f, "contains {{ type_name: {type_name}, found: {found} }}")
72            }
73            RegistryEvent::Clear {} => write!(f, "Clearing the Registry"),
74        }
75    }
76}
77
78/// Type alias for the user-supplied tracing callback.
79///
80/// The callback receives a reference to a `RegistryEvent` every time the registry is
81/// interacted with. It must be thread-safe because the registry itself is globally shared.
82pub type TraceCallback = dyn Fn(&RegistryEvent) + Send + Sync + 'static;
83
84/// Holds an optional user-defined tracing callback.
85static TRACE_CALLBACK: LazyLock<Mutex<Option<Arc<TraceCallback>>>> =
86    LazyLock::new(|| Mutex::new(None));
87
88/// Sets a tracing callback that will be invoked on every registry interaction.
89///
90/// Pass `None` (or call `clear_trace_callback`) to disable tracing.
91///
92/// # Example
93/// ```rust
94/// use singleton_registry::{set_trace_callback, RegistryEvent};
95///
96/// set_trace_callback(|event| println!("[registry-trace] {:?}", event));
97/// ```
98pub fn set_trace_callback(callback: impl Fn(&RegistryEvent) + Send + Sync + 'static) {
99    let mut guard = TRACE_CALLBACK.lock().unwrap_or_else(|p| p.into_inner());
100    *guard = Some(Arc::new(callback));
101}
102
103/// Clears the tracing callback (disables registry tracing).
104pub fn clear_trace_callback() {
105    let mut guard = TRACE_CALLBACK.lock().unwrap_or_else(|p| p.into_inner());
106    *guard = None;
107}
108
109/// Convenience wrapper to emit a registry event using the current callback.
110fn emit_event(event: &RegistryEvent) {
111    // lock poisoning unlikely; if poisoned, keep emitting with recovered lock
112    let guard = TRACE_CALLBACK.lock().unwrap_or_else(|p| p.into_inner());
113    if let Some(callback) = guard.as_ref() {
114        callback(event);
115    }
116}
117
118// -------------------------------------------------------------------------------------------------
119// Registry
120// -------------------------------------------------------------------------------------------------
121
122/// Registers an `Arc<T>` in the global registry.
123///
124/// This is more efficient than `di_register` when you already have an `Arc`,
125/// as it avoids creating an additional reference count.
126///
127/// # Safety
128///
129/// If the registry's lock is poisoned (which can happen if a thread panicked while
130/// holding the lock), this function will recover the lock and continue execution.
131/// This is safe because the registry is used in a read-only manner after the
132/// initial registration phase in `main.rs`.
133///
134/// # Arguments
135///
136/// * `value` - The `Arc`-wrapped value to register. The inner type must implement
137///   `Send + Sync + 'static`.
138///
139/// # Examples
140///
141/// ```
142/// use std::sync::Arc;
143/// use singleton_registry::{register_arc, get};
144///
145/// let value = Arc::new("shared".to_string());
146/// register_arc(value.clone());
147///
148/// let retrieved: Arc<String> = get().expect("Failed to get value");
149/// assert_eq!(&*retrieved, "shared");
150/// ```
151pub fn register_arc<T: Send + Sync + 'static>(value: Arc<T>) {
152    emit_event(&RegistryEvent::Register {
153        type_name: std::any::type_name::<T>(),
154    });
155
156    GLOBAL_REGISTRY
157        .lock()
158        // The registry is used as read only, so we do not expect a poisoned lock.
159        // Poisoning only occurs if a thread panics while holding the lock.
160        .unwrap_or_else(|poisoned| poisoned.into_inner())
161        .insert(TypeId::of::<T>(), value);
162}
163
164/// Registers a value of type `T` in the global registry.
165///
166/// This is a convenience wrapper around `di_register_arc` that takes ownership
167/// of the value and wraps it in an `Arc` automatically.
168///
169/// # Safety
170///
171/// If the registry's lock is poisoned (which can happen if a thread panicked while
172/// holding the lock), this function will recover the lock and continue execution.
173/// This is safe because the registry is used in a read-only manner after the
174/// initial registration phase in `main.rs`.
175///
176/// # Arguments
177///
178/// * `value` - The value to register. Must implement `Send + Sync + 'static`.
179///
180/// # Examples
181///
182/// ```
183/// use singleton_registry::{register, get};
184/// use std::sync::Arc;
185///
186/// // Register a primitive
187/// register(42i32);
188///
189/// // Register a string
190/// register("Hello".to_string());
191///
192/// // Retrieve values
193/// let num: Arc<i32> = get().expect("Failed to get i32");
194/// let s: Arc<String> = get().expect("Failed to get String");
195///
196/// assert_eq!(*num, 42);
197/// assert_eq!(&*s, "Hello");
198/// ```
199pub fn register<T: Send + Sync + 'static>(value: T) {
200    register_arc::<T>(Arc::new(value));
201}
202
203/// Checks if a value of type `T` is registered in the global registry.
204///
205/// # Returns
206///
207/// - `Ok(true)` if the type is registered
208/// - `Ok(false)` if the type is not found
209/// - `Err(String)` if failed to acquire the registry lock
210///
211/// # Examples
212///
213/// ```
214/// use singleton_registry::{register, contains};
215///
216/// // Check for unregistered type
217/// assert!(!contains::<i32>().expect("Failed to check registry"));
218///
219/// // Register and check
220/// register(42i32);
221/// assert!(contains::<i32>().expect("Failed to check registry"));
222/// ```
223pub fn contains<T: Send + Sync + 'static>() -> Result<bool, String> {
224    let found = GLOBAL_REGISTRY
225        .lock()
226        .map(|m| m.contains_key(&TypeId::of::<T>()))
227        .map_err(|_| "Failed to acquire registry lock".to_string())?;
228
229    emit_event(&RegistryEvent::Contains {
230        type_name: std::any::type_name::<T>(),
231        found,
232    });
233
234    Ok(found)
235}
236
237/// Retrieves a value of type `T` from the global registry.
238///
239/// # Returns
240///
241/// - `Ok(Arc<T>)` if the type is found and the downcast is successful
242/// - `Err(String)` in the following cases:
243///   - Failed to acquire the registry lock
244///   - Type `T` is not found in the registry
245///   - Type mismatch (found a different type with the same TypeId)
246///
247/// # Examples
248///
249/// ```
250/// use singleton_registry::{register, get};
251/// use std::sync::Arc;
252///
253/// // Register and retrieve a value
254/// register(42i32);
255/// let num: Arc<i32> = get().expect("Failed to get i32");
256/// assert_eq!(*num, 42);
257///
258/// // Handle missing value
259/// let result: Result<Arc<String>, _> = get();
260/// assert!(result.is_err());
261/// ```
262pub fn get<T: Send + Sync + 'static>() -> Result<Arc<T>, String> {
263    let map = GLOBAL_REGISTRY
264        .lock()
265        .map_err(|_| "Failed to acquire registry lock")?;
266
267    let any_arc_opt = map.get(&TypeId::of::<T>()).cloned();
268
269    // Determine result and emit tracing event in a single place.
270    let result: Result<Arc<T>, String> = match any_arc_opt {
271        Some(any_arc) => any_arc.downcast::<T>().map_err(|_| {
272            format!(
273                "Type mismatch in registry for type: {}",
274                std::any::type_name::<T>()
275            )
276        }),
277        None => Err(format!(
278            "Type not found in registry: {}",
279            std::any::type_name::<T>()
280        )),
281    };
282
283    emit_event(&RegistryEvent::Get {
284        type_name: std::any::type_name::<T>(),
285        found: result.is_ok(),
286    });
287
288    result
289}
290
291/// Retrieves a clone of the value stored in the registry for the given type.
292///
293/// This function returns an owned value by cloning the value stored in the registry.
294/// The type `T` must implement `Clone`. This is useful if you need to own the value
295/// rather than share it via `Arc<T>`.
296///
297/// # Errors
298/// Returns an error if the value for the given type is not found in the registry.
299///
300/// # Examples
301/// ```
302/// use singleton_registry::{register, get_cloned};
303///
304/// register("hello".to_string());
305/// let value: String = get_cloned::<String>().expect("Value should be present");
306/// assert_eq!(value, "hello");
307/// ```
308pub fn get_cloned<T: Send + Sync + Clone + 'static>() -> Result<T, String> {
309    let arc = get::<T>()?;
310    Ok((*arc).clone())
311}
312
313/// Returns a `'static` reference to a value stored in the registry.
314///
315/// This function is here only for educational purpose and future reference. Better to avoid it.
316///
317/// # Safety
318/// This function intentionally leaks the `Arc<T>` to extend its lifetime to `'static`.
319/// Only use this for values that are truly immutable and meant to live for the entire
320/// lifetime of the application (true singletons). Never use for values that may be
321/// mutated or replaced, or if you plan to clear the registry at runtime.
322///
323/// If you need shared access, prefer using `Arc<T>` via `get`.
324#[doc(hidden)]
325pub fn get_ref<T: Send + Sync + Clone + 'static>() -> Result<&'static T, String> {
326    let arc = get::<T>()?;
327    let ptr = Arc::into_raw(arc);
328    Ok(unsafe { &*ptr })
329}
330
331#[doc(hidden)]
332pub fn clear() {
333    emit_event(&RegistryEvent::Clear {});
334
335    if let Ok(mut registry) = GLOBAL_REGISTRY.lock() {
336        registry.clear();
337    }
338}
339
340// -------------------------------------------------------------------------------------------------
341// Tests
342// -------------------------------------------------------------------------------------------------
343
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use std::sync::Arc;

    #[test]
    #[serial]
    fn test_register_and_get_primitive() -> Result<(), String> {
        // Clear any previous state
        clear();

        // Register a primitive type
        register(42i32);

        // Retrieve it via type inference...
        let num: Arc<i32> = get()?;
        assert_eq!(*num, 42);

        // ...and via an explicit turbofish.
        let num_2 = get::<i32>()?;
        assert_eq!(*num_2, 42);

        Ok(())
    }

    #[test]
    #[serial]
    fn test_register_and_get_string() {
        // Clear the registry before the test
        clear();

        // Create and register a string
        let s = "test".to_string();
        register(s.clone());

        // Retrieve it and verify
        let retrieved: Arc<String> = get().expect("Failed to retrieve string");
        assert_eq!(&*retrieved, &s);

        // Clear the registry after the test
        clear();
    }

    #[test]
    #[serial]
    fn test_get_nonexistent() {
        clear();

        let result: Result<Arc<String>, _> = get();
        assert!(result.is_err());
        // Build the expected text from `type_name` rather than hard-coding its output,
        // whose exact format the standard library does not guarantee.
        assert_eq!(
            result.unwrap_err(),
            format!(
                "Type not found in registry: {}",
                std::any::type_name::<String>()
            )
        );
    }

    #[test]
    #[serial]
    fn test_thread_safety() {
        clear();

        use std::sync::{mpsc, Arc, Barrier};
        use std::thread;

        let barrier = Arc::new(Barrier::new(2));
        let (main_tx, thread_rx) = mpsc::channel();
        let (thread_tx, main_rx) = mpsc::channel();

        let barrier_clone = barrier.clone();
        let handle = thread::spawn(move || {
            register(100u32);
            thread_tx.send(100u32).unwrap();

            // Wait for the main thread to register its value
            let main_value: String = thread_rx.recv().unwrap();

            // Synchronize: ensure both threads have registered before retrieval
            barrier_clone.wait();

            let s: Arc<String> = get().expect("Failed to get string in thread");
            assert_eq!(&*s, &main_value);
        });

        let thread_value = main_rx.recv().unwrap();
        let num: Arc<u32> = get().expect("Failed to get u32 in main thread");
        assert_eq!(*num, thread_value);

        // Register a string in main thread
        let main_string = "main_thread_value".to_string();
        register(main_string.clone());
        main_tx.send(main_string.clone()).unwrap();

        // Synchronize: ensure both threads have registered before retrieval
        barrier.wait();

        handle.join().unwrap();
        clear();
    }

    #[test]
    #[serial]
    fn test_multiple_types() {
        clear();

        // Define wrapper types to ensure unique TypeIds
        #[derive(Debug, PartialEq, Eq, Clone)]
        struct Num(i32);
        #[derive(Debug, PartialEq, Eq, Clone)]
        struct Text(String);
        #[derive(Debug, PartialEq, Eq, Clone)]
        struct Numbers(Vec<i32>);

        // Create the values
        let num_val = Num(42);
        let text_val = Text("hello".to_string());
        let nums_val = Numbers(vec![1, 2, 3]);

        // Register all types first
        register(num_val.clone());
        register(text_val.clone());
        register(nums_val.clone());

        // Then retrieve and verify each one
        let num: Arc<Num> = get().expect("Num not found in registry");
        assert_eq!(num.0, num_val.0);

        let text: Arc<Text> = get().expect("Text not found in registry");
        assert_eq!(text.0, text_val.0);

        let nums: Arc<Numbers> = get().expect("Numbers not found in registry");
        assert_eq!(&nums.0, &nums_val.0);

        // Clear the registry after the test
        clear();
    }

    #[test]
    #[serial]
    fn test_custom_type() {
        clear();

        #[derive(Debug, PartialEq, Eq, Clone)]
        struct MyStruct {
            field: String,
        }

        let my_value = MyStruct {
            field: "test".into(),
        };
        register(my_value.clone());

        let retrieved: Arc<MyStruct> = get().unwrap();
        assert_eq!(&*retrieved, &my_value);
    }

    #[test]
    #[serial]
    fn test_tuple_type() -> Result<(), String> {
        clear();

        // `(i32, &'static str)` is `Copy`, so registering it leaves `tuple` usable.
        let tuple: (i32, &'static str) = (1, "test");
        register(tuple);

        let retrieved = get::<(i32, &str)>()?;
        assert_eq!(&*retrieved, &tuple);

        Ok(())
    }

    #[test]
    #[serial]
    fn test_overwrite_same_type() {
        clear();

        register(10i32);
        register(20i32); // should replace

        let num: Arc<i32> = get().unwrap();
        assert_eq!(*num, 20);
    }

    #[test]
    #[serial]
    fn test_di_get_cloned() {
        clear();
        register("hello".to_string());
        let value: String = get_cloned::<String>().expect("Value should be present");
        assert_eq!(value, "hello");
    }

    #[test]
    #[serial]
    fn test_di_get_ref() {
        clear();
        register("world".to_string());
        let value: &'static String = get_ref::<String>().expect("Value should be present");
        assert_eq!(value, "world");

        // `get_ref` leaks one strong reference (`Arc::into_raw` with no matching
        // `Arc::from_raw`). Clearing the registry therefore only drops the registry's own
        // reference: the `String` stays allocated and `value` remains valid. It is now
        // detached from the registry, though, and its memory is never reclaimed.
        clear();
        assert!(!contains::<String>().unwrap());
        assert_eq!(value, "world");
    }

    #[test]
    #[serial]
    fn test_di_contains() {
        clear();
        assert!(!contains::<u32>().unwrap());
        register(1u32);
        assert!(contains::<u32>().unwrap());
    }

    #[test]
    #[serial]
    fn test_function_pointer_registration() {
        clear();

        // Test the function pointer example from README
        let multiply_by_two: fn(i32) -> i32 = |x| x * 2;
        register(multiply_by_two);

        let doubler: Arc<fn(i32) -> i32> = get().unwrap();
        let result = doubler(21);
        assert_eq!(result, 42);
    }

    #[test]
    #[serial]
    fn test_trace_callback_invoked() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        /// Removes the trace callback when dropped, so a failing assertion cannot leave
        /// this test's callback installed for later tests.
        struct ResetTrace;
        impl Drop for ResetTrace {
            fn drop(&mut self) {
                clear_trace_callback();
            }
        }

        static COUNT: AtomicUsize = AtomicUsize::new(0);

        clear();
        let _reset = ResetTrace;
        set_trace_callback(|event| {
            if matches!(event, RegistryEvent::Register { .. }) {
                COUNT.fetch_add(1, Ordering::SeqCst);
            }
        });

        register(5u8);
        assert_eq!(COUNT.load(Ordering::SeqCst), 1);
    }

    // -------------------------------------------------------------
    // Display implementation tests
    // -------------------------------------------------------------

    #[test]
    fn test_display_register() {
        let ev = RegistryEvent::Register { type_name: "i32" };
        assert_eq!(ev.to_string(), "register { type_name: i32 }");
    }

    #[test]
    fn test_display_get() {
        let ev = RegistryEvent::Get {
            type_name: "String",
            found: true,
        };
        assert_eq!(ev.to_string(), "get { type_name: String, found: true }");
    }

    #[test]
    fn test_display_contains() {
        let ev = RegistryEvent::Contains {
            type_name: "u8",
            found: false,
        };
        assert_eq!(ev.to_string(), "contains { type_name: u8, found: false }");
    }

    #[test]
    fn test_display_clear() {
        let ev = RegistryEvent::Clear {};
        assert_eq!(ev.to_string(), "Clearing the Registry");
    }
}