singleton_registry/registry.rs
1#![allow(dead_code)]
2
3//! A thread-safe dependency injection registry for storing and retrieving global instances.
4//! Currently designed for write-once, read-many pattern.
5//!
6//! This module provides a type-safe way to register and retrieve instances of any type
7//! that implements `Send + Sync + 'static`.
8//!
9//! # Examples
10//!
11//! ```
12//! use singleton_registry::{register, get};
13//! use std::sync::Arc;
14//!
15//! // Register a value
16//! register("Hello, World!".to_string());
17//!
18//! // Retrieve the value
19//! let message: Arc<String> = get().unwrap();
20//! assert_eq!(&*message, "Hello, World!");
21//! ```
22
23use std::{
24 any::{Any, TypeId},
25 collections::HashMap,
26 fmt,
27 sync::{Arc, LazyLock, Mutex},
28};
29
/// Type-erased, shareable handle under which every registered instance is stored.
type ErasedInstance = Arc<dyn Any + Send + Sync>;

/// Process-wide storage backing the registry.
///
/// Each registered type owns at most one slot, keyed by its `TypeId`. The map sits behind a
/// `Mutex` and is created by `LazyLock` on first access.
static GLOBAL_REGISTRY: LazyLock<Mutex<HashMap<TypeId, ErasedInstance>>> =
    LazyLock::new(|| Mutex::new(HashMap::new()));
36
37// -------------------------------------------------------------------------------------------------
38// Tracing callback support
39// -------------------------------------------------------------------------------------------------
40
/// Events emitted by the registry and delivered to the tracing callback.
///
/// Every `type_name` field carries the string produced by `std::any::type_name` for the type
/// the operation was performed on.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RegistryEvent {
    /// A value was registered with `register` or `register_arc`.
    Register {
        /// Name of the type being registered.
        type_name: &'static str,
    },
    /// A value was requested with `get`.
    Get {
        /// Name of the requested type.
        type_name: &'static str,
        /// Whether the lookup produced a value.
        found: bool,
    },
    /// A `contains` check was performed.
    Contains {
        /// Name of the type that was checked.
        type_name: &'static str,
        /// Whether the type was present in the registry.
        found: bool,
    },
    /// The registry was emptied with `clear`.
    Clear {},
}
60
61impl fmt::Display for RegistryEvent {
62 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63 match self {
64 RegistryEvent::Register { type_name } => {
65 write!(f, "register {{ type_name: {type_name} }}")
66 }
67 RegistryEvent::Get { type_name, found } => {
68 write!(f, "get {{ type_name: {type_name}, found: {found} }}")
69 }
70 RegistryEvent::Contains { type_name, found } => {
71 write!(f, "contains {{ type_name: {type_name}, found: {found} }}")
72 }
73 RegistryEvent::Clear {} => write!(f, "Clearing the Registry"),
74 }
75 }
76}
77
78/// Type alias for the user-supplied tracing callback.
79///
80/// The callback receives a reference to a `RegistryEvent` every time the registry is
81/// interacted with. It must be thread-safe because the registry itself is globally shared.
82pub type TraceCallback = dyn Fn(&RegistryEvent) + Send + Sync + 'static;
83
84/// Holds an optional user-defined tracing callback.
85static TRACE_CALLBACK: LazyLock<Mutex<Option<Arc<TraceCallback>>>> =
86 LazyLock::new(|| Mutex::new(None));
87
88/// Sets a tracing callback that will be invoked on every registry interaction.
89///
90/// Pass `None` (or call `clear_trace_callback`) to disable tracing.
91///
92/// # Example
93/// ```rust
94/// use singleton_registry::{set_trace_callback, RegistryEvent};
95///
96/// set_trace_callback(|event| println!("[registry-trace] {:?}", event));
97/// ```
98pub fn set_trace_callback(callback: impl Fn(&RegistryEvent) + Send + Sync + 'static) {
99 let mut guard = TRACE_CALLBACK.lock().unwrap_or_else(|p| p.into_inner());
100 *guard = Some(Arc::new(callback));
101}
102
103/// Clears the tracing callback (disables registry tracing).
104pub fn clear_trace_callback() {
105 let mut guard = TRACE_CALLBACK.lock().unwrap_or_else(|p| p.into_inner());
106 *guard = None;
107}
108
109/// Convenience wrapper to emit a registry event using the current callback.
110fn emit_event(event: &RegistryEvent) {
111 // lock poisoning unlikely; if poisoned, keep emitting with recovered lock
112 let guard = TRACE_CALLBACK.lock().unwrap_or_else(|p| p.into_inner());
113 if let Some(callback) = guard.as_ref() {
114 callback(event);
115 }
116}
117
118// -------------------------------------------------------------------------------------------------
119// Registry
120// -------------------------------------------------------------------------------------------------
121
122/// Registers an `Arc<T>` in the global registry.
123///
124/// This is more efficient than `di_register` when you already have an `Arc`,
125/// as it avoids creating an additional reference count.
126///
127/// # Safety
128///
129/// If the registry's lock is poisoned (which can happen if a thread panicked while
130/// holding the lock), this function will recover the lock and continue execution.
131/// This is safe because the registry is used in a read-only manner after the
132/// initial registration phase in `main.rs`.
133///
134/// # Arguments
135///
136/// * `value` - The `Arc`-wrapped value to register. The inner type must implement
137/// `Send + Sync + 'static`.
138///
139/// # Examples
140///
141/// ```
142/// use std::sync::Arc;
143/// use singleton_registry::{register_arc, get};
144///
145/// let value = Arc::new("shared".to_string());
146/// register_arc(value.clone());
147///
148/// let retrieved: Arc<String> = get().expect("Failed to get value");
149/// assert_eq!(&*retrieved, "shared");
150/// ```
151pub fn register_arc<T: Send + Sync + 'static>(value: Arc<T>) {
152 emit_event(&RegistryEvent::Register {
153 type_name: std::any::type_name::<T>(),
154 });
155
156 GLOBAL_REGISTRY
157 .lock()
158 // The registry is used as read only, so we do not expect a poisoned lock.
159 // Poisoning only occurs if a thread panics while holding the lock.
160 .unwrap_or_else(|poisoned| poisoned.into_inner())
161 .insert(TypeId::of::<T>(), value);
162}
163
164/// Registers a value of type `T` in the global registry.
165///
166/// This is a convenience wrapper around `di_register_arc` that takes ownership
167/// of the value and wraps it in an `Arc` automatically.
168///
169/// # Safety
170///
171/// If the registry's lock is poisoned (which can happen if a thread panicked while
172/// holding the lock), this function will recover the lock and continue execution.
173/// This is safe because the registry is used in a read-only manner after the
174/// initial registration phase in `main.rs`.
175///
176/// # Arguments
177///
178/// * `value` - The value to register. Must implement `Send + Sync + 'static`.
179///
180/// # Examples
181///
182/// ```
183/// use singleton_registry::{register, get};
184/// use std::sync::Arc;
185///
186/// // Register a primitive
187/// register(42i32);
188///
189/// // Register a string
190/// register("Hello".to_string());
191///
192/// // Retrieve values
193/// let num: Arc<i32> = get().expect("Failed to get i32");
194/// let s: Arc<String> = get().expect("Failed to get String");
195///
196/// assert_eq!(*num, 42);
197/// assert_eq!(&*s, "Hello");
198/// ```
199pub fn register<T: Send + Sync + 'static>(value: T) {
200 register_arc::<T>(Arc::new(value));
201}
202
203/// Checks if a value of type `T` is registered in the global registry.
204///
205/// # Returns
206///
207/// - `Ok(true)` if the type is registered
208/// - `Ok(false)` if the type is not found
209/// - `Err(String)` if failed to acquire the registry lock
210///
211/// # Examples
212///
213/// ```
214/// use singleton_registry::{register, contains};
215///
216/// // Check for unregistered type
217/// assert!(!contains::<i32>().expect("Failed to check registry"));
218///
219/// // Register and check
220/// register(42i32);
221/// assert!(contains::<i32>().expect("Failed to check registry"));
222/// ```
223pub fn contains<T: Send + Sync + 'static>() -> Result<bool, String> {
224 let found = GLOBAL_REGISTRY
225 .lock()
226 .map(|m| m.contains_key(&TypeId::of::<T>()))
227 .map_err(|_| "Failed to acquire registry lock".to_string())?;
228
229 emit_event(&RegistryEvent::Contains {
230 type_name: std::any::type_name::<T>(),
231 found,
232 });
233
234 Ok(found)
235}
236
237/// Retrieves a value of type `T` from the global registry.
238///
239/// # Returns
240///
241/// - `Ok(Arc<T>)` if the type is found and the downcast is successful
242/// - `Err(String)` in the following cases:
243/// - Failed to acquire the registry lock
244/// - Type `T` is not found in the registry
245/// - Type mismatch (found a different type with the same TypeId)
246///
247/// # Examples
248///
249/// ```
250/// use singleton_registry::{register, get};
251/// use std::sync::Arc;
252///
253/// // Register and retrieve a value
254/// register(42i32);
255/// let num: Arc<i32> = get().expect("Failed to get i32");
256/// assert_eq!(*num, 42);
257///
258/// // Handle missing value
259/// let result: Result<Arc<String>, _> = get();
260/// assert!(result.is_err());
261/// ```
262pub fn get<T: Send + Sync + 'static>() -> Result<Arc<T>, String> {
263 let map = GLOBAL_REGISTRY
264 .lock()
265 .map_err(|_| "Failed to acquire registry lock")?;
266
267 let any_arc_opt = map.get(&TypeId::of::<T>()).cloned();
268
269 // Determine result and emit tracing event in a single place.
270 let result: Result<Arc<T>, String> = match any_arc_opt {
271 Some(any_arc) => any_arc.downcast::<T>().map_err(|_| {
272 format!(
273 "Type mismatch in registry for type: {}",
274 std::any::type_name::<T>()
275 )
276 }),
277 None => Err(format!(
278 "Type not found in registry: {}",
279 std::any::type_name::<T>()
280 )),
281 };
282
283 emit_event(&RegistryEvent::Get {
284 type_name: std::any::type_name::<T>(),
285 found: result.is_ok(),
286 });
287
288 result
289}
290
291/// Retrieves a clone of the value stored in the registry for the given type.
292///
293/// This function returns an owned value by cloning the value stored in the registry.
294/// The type `T` must implement `Clone`. This is useful if you need to own the value
295/// rather than share it via `Arc<T>`.
296///
297/// # Errors
298/// Returns an error if the value for the given type is not found in the registry.
299///
300/// # Examples
301/// ```
302/// use singleton_registry::{register, get_cloned};
303///
304/// register("hello".to_string());
305/// let value: String = get_cloned::<String>().expect("Value should be present");
306/// assert_eq!(value, "hello");
307/// ```
308pub fn get_cloned<T: Send + Sync + Clone + 'static>() -> Result<T, String> {
309 let arc = get::<T>()?;
310 Ok((*arc).clone())
311}
312
313/// Returns a `'static` reference to a value stored in the registry.
314///
315/// This function is here only for educational purpose and future reference. Better to avoid it.
316///
317/// # Safety
318/// This function intentionally leaks the `Arc<T>` to extend its lifetime to `'static`.
319/// Only use this for values that are truly immutable and meant to live for the entire
320/// lifetime of the application (true singletons). Never use for values that may be
321/// mutated or replaced, or if you plan to clear the registry at runtime.
322///
323/// If you need shared access, prefer using `Arc<T>` via `get`.
324#[doc(hidden)]
325pub fn get_ref<T: Send + Sync + Clone + 'static>() -> Result<&'static T, String> {
326 let arc = get::<T>()?;
327 let ptr = Arc::into_raw(arc);
328 Ok(unsafe { &*ptr })
329}
330
331#[doc(hidden)]
332pub fn clear() {
333 emit_event(&RegistryEvent::Clear {});
334
335 if let Ok(mut registry) = GLOBAL_REGISTRY.lock() {
336 registry.clear();
337 }
338}
339
340// -------------------------------------------------------------------------------------------------
341// Tests
342// -------------------------------------------------------------------------------------------------
343
#[cfg(test)]
mod tests {
    //! Unit tests for the registry.
    //!
    //! Tests that touch the global registry or the tracing callback are marked `#[serial]`
    //! and start with `clear()` so they do not observe each other's state.

    use super::*;
    use serial_test::serial;
    use std::sync::Arc;

    #[test]
    #[serial]
    fn test_register_and_get_primitive() -> Result<(), String> {
        // Clear any previous state
        clear();

        // Register a primitive type
        register(42i32);

        // Retrieve it with the type inferred from the binding
        let num: Arc<i32> = get()?;
        assert_eq!(*num, 42);

        // Retrieve it with an explicit turbofish
        let num_2 = get::<i32>()?;
        assert_eq!(*num_2, 42);

        Ok(())
    }

    #[test]
    #[serial]
    fn test_register_and_get_string() {
        // Clear the registry before the test
        clear();

        // Create and register a string
        let s = "test".to_string();
        register(s.clone());

        // Retrieve it and verify
        let retrieved: Arc<String> = get().expect("Failed to retrieve string");
        assert_eq!(&*retrieved, &s);

        // Clear the registry after the test
        clear();
    }

    #[test]
    #[serial]
    fn test_get_nonexistent() {
        clear();

        let result: Result<Arc<String>, _> = get();
        assert!(result.is_err());

        // The exact text returned by `type_name` is not guaranteed by std, so build the
        // expected message from it instead of hard-coding `alloc::string::String`.
        let expected = format!(
            "Type not found in registry: {}",
            std::any::type_name::<String>()
        );
        assert_eq!(result.unwrap_err(), expected);
    }

    #[test]
    #[serial]
    fn test_thread_safety() {
        clear();

        use std::sync::{mpsc, Arc, Barrier};
        use std::thread;

        let barrier = Arc::new(Barrier::new(2));
        let (main_tx, thread_rx) = mpsc::channel();
        let (thread_tx, main_rx) = mpsc::channel();

        let barrier_clone = barrier.clone();
        let handle = thread::spawn(move || {
            register(100u32);
            thread_tx.send(100u32).unwrap();

            // Wait for the main thread to register its value
            let main_value: String = thread_rx.recv().unwrap();

            // Synchronize: ensure both threads have registered before retrieval
            barrier_clone.wait();

            let s: Arc<String> = get().expect("Failed to get string in thread");
            assert_eq!(&*s, &main_value);
        });

        let thread_value = main_rx.recv().unwrap();
        let num: Arc<u32> = get().expect("Failed to get u32 in main thread");
        assert_eq!(*num, thread_value);

        // Register a string in main thread
        let main_string = "main_thread_value".to_string();
        register(main_string.clone());
        main_tx.send(main_string.clone()).unwrap();

        // Synchronize: ensure both threads have registered before retrieval
        barrier.wait();

        handle.join().unwrap();
        clear();
    }

    #[test]
    #[serial]
    fn test_multiple_types() {
        clear();

        // Define wrapper types to ensure unique TypeIds
        #[derive(Debug, PartialEq, Eq, Clone)]
        struct Num(i32);
        #[derive(Debug, PartialEq, Eq, Clone)]
        struct Text(String);
        #[derive(Debug, PartialEq, Eq, Clone)]
        struct Numbers(Vec<i32>);

        // Create the values
        let num_val = Num(42);
        let text_val = Text("hello".to_string());
        let nums_val = Numbers(vec![1, 2, 3]);

        // Register all types first
        register(num_val.clone());
        register(text_val.clone());
        register(nums_val.clone());

        // Then retrieve and verify each one
        let num: Arc<Num> = get().expect("Num not found in registry");
        assert_eq!(num.0, num_val.0);

        let text: Arc<Text> = get().expect("Text not found in registry");
        assert_eq!(text.0, text_val.0);

        let nums: Arc<Numbers> = get().expect("Numbers not found in registry");
        assert_eq!(&nums.0, &nums_val.0);

        // Clear the registry after the test
        clear();
    }

    #[test]
    #[serial]
    fn test_custom_type() {
        clear();

        #[derive(Debug, PartialEq, Eq, Clone)]
        struct MyStruct {
            field: String,
        }

        let my_value = MyStruct {
            field: "test".into(),
        };
        register(my_value.clone());

        let retrieved: Arc<MyStruct> = get().unwrap();
        assert_eq!(&*retrieved, &my_value);
    }

    #[test]
    #[serial]
    fn test_tuple_type() -> Result<(), String> {
        clear();

        // The tuple is `Copy`, so it can be registered and still compared afterwards.
        let tuple = (1, "test");
        register(tuple);

        let retrieved = get::<(i32, &str)>()?;
        assert_eq!(&*retrieved, &tuple);

        Ok(())
    }

    #[test]
    #[serial]
    fn test_overwrite_same_type() {
        clear();

        register(10i32);
        register(20i32); // should replace

        let num: Arc<i32> = get().unwrap();
        assert_eq!(*num, 20);
    }

    #[test]
    #[serial]
    fn test_di_get_cloned() {
        clear();
        register("hello".to_string());
        let value: String = get_cloned::<String>().expect("Value should be present");
        assert_eq!(value, "hello");
    }

    #[test]
    #[serial]
    fn test_di_get_ref() {
        clear();
        register("world".to_string());
        let value: &'static String = get_ref::<String>().expect("Value should be present");
        assert_eq!(value, "world");

        // `get_ref` leaks one strong reference to the stored `String`, so clearing the
        // registry does not drop it: the reference stays valid (at the price of never
        // freeing the value).
        clear();
        assert!(get::<String>().is_err());
        assert_eq!(value, "world");
    }

    #[test]
    #[serial]
    fn test_di_contains() {
        clear();
        assert!(!contains::<u32>().unwrap());
        register(1u32);
        assert!(contains::<u32>().unwrap());
    }

    #[test]
    #[serial]
    fn test_function_pointer_registration() {
        clear();

        // Test the function pointer example from README
        let multiply_by_two: fn(i32) -> i32 = |x| x * 2;
        register(multiply_by_two);

        let doubler: Arc<fn(i32) -> i32> = get().unwrap();
        let result = doubler(21);
        assert_eq!(result, 42);
    }

    #[test]
    #[serial]
    fn test_trace_callback_invoked() {
        clear();
        use std::sync::atomic::{AtomicUsize, Ordering};
        static COUNT: AtomicUsize = AtomicUsize::new(0);
        set_trace_callback(|_e| {
            COUNT.fetch_add(1, Ordering::SeqCst);
        });

        // Traced: exactly one `Register` event.
        register(5u8);

        // Remove the callback before asserting so a failing assertion cannot leave it
        // installed for other tests; the second registration must not be traced.
        clear_trace_callback();
        register(6u8);

        assert_eq!(COUNT.load(Ordering::SeqCst), 1);
    }

    // -------------------------------------------------------------
    // Display implementation tests
    // -------------------------------------------------------------

    #[test]
    fn test_display_register() {
        let ev = RegistryEvent::Register { type_name: "i32" };
        assert_eq!(ev.to_string(), "register { type_name: i32 }");
    }

    #[test]
    fn test_display_get() {
        let ev = RegistryEvent::Get {
            type_name: "String",
            found: true,
        };
        assert_eq!(ev.to_string(), "get { type_name: String, found: true }");
    }

    #[test]
    fn test_display_contains() {
        let ev = RegistryEvent::Contains {
            type_name: "u8",
            found: false,
        };
        assert_eq!(ev.to_string(), "contains { type_name: u8, found: false }");
    }

    #[test]
    fn test_display_clear() {
        let ev = RegistryEvent::Clear {};
        assert_eq!(ev.to_string(), "Clearing the Registry");
    }
}