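//! `rust_dicore/wrapper.rs` — a service-provider wrapper that resolves services from a
//! child provider first and falls back to a root provider.
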
use std::any::Any;
use std::sync::Arc;

use crate::entry::IServiceResolver;
use crate::provider::ServiceProvider;

6pub struct ServiceProviderWrapper {
7    child: Arc<ServiceProvider>,
8    root: Arc<ServiceProvider>,
9}
10
11impl ServiceProviderWrapper {
12    pub fn new(child: Arc<ServiceProvider>, root: Arc<ServiceProvider>) -> Arc<Self> {
13        Arc::new(Self { child, root })
14    }
15    pub fn child(&self) -> &Arc<ServiceProvider> {
16        &self.child
17    }
18    pub fn root(&self) -> &Arc<ServiceProvider> {
19        &self.root
20    }
21
22    pub fn get<T: ?Sized + Send + Sync + 'static>(&self) -> Arc<T> {
23        self.try_get::<T>()
24            .unwrap_or_else(|| panic!("service not registered: {}", std::any::type_name::<T>()))
25    }
26    pub fn get_optional<T: ?Sized + Send + Sync + 'static>(&self) -> Option<Arc<T>> {
27        self.try_get::<T>()
28    }
29    pub fn get_keyed<T: ?Sized + Send + Sync + 'static>(&self, key: &str) -> Arc<T> {
30        self.try_get_keyed::<T>(key).unwrap_or_else(|| {
31            panic!(
32                "keyed service not found: {}:{}",
33                std::any::type_name::<T>(),
34                key
35            )
36        })
37    }
38    pub fn get_all<T: ?Sized + Send + Sync + 'static>(&self) -> Vec<Arc<T>> {
39        let mut r = self.child.get_all::<T>();
40        r.extend(self.root.get_all::<T>());
41        r
42    }
43    pub fn get_named<T: Send + Sync + 'static>(&self, name: &str) -> Option<Arc<T>> {
44        self.child
45            .get_named::<T>(name)
46            .or_else(|| self.root.get_named::<T>(name))
47    }
48    pub fn get_named_any(&self, name: &str) -> Option<Arc<dyn Any + Send + Sync>> {
49        self.child
50            .get_named_any(name)
51            .or_else(|| self.root.get_named_any(name))
52    }
53
54    fn try_get<T: ?Sized + Send + Sync + 'static>(&self) -> Option<Arc<T>> {
55        self.child
56            .get_optional::<T>()
57            .or_else(|| self.root.get_optional::<T>())
58    }
59
60    fn try_get_keyed<T: ?Sized + Send + Sync + 'static>(&self, key: &str) -> Option<Arc<T>> {
61        // Delegate to IServiceResolver which goes through the cache properly.
62        IServiceResolver::get_keyed_any(self.child.as_ref(), std::any::type_name::<T>(), key)
63            .or_else(|| {
64                IServiceResolver::get_keyed_any(self.root.as_ref(), std::any::type_name::<T>(), key)
65            })
66            .and_then(|arc| crate::provider::ServiceProvider::extract(arc))
67    }
68}
69
70impl IServiceResolver for ServiceProviderWrapper {
71    fn get_any(&self, key: &str) -> Option<Arc<dyn Any + Send + Sync>> {
72        IServiceResolver::get_any(self.child.as_ref(), key)
73            .or_else(|| IServiceResolver::get_any(self.root.as_ref(), key))
74    }
75    fn get_keyed_any(&self, key: &str, variant: &str) -> Option<Arc<dyn Any + Send + Sync>> {
76        IServiceResolver::get_keyed_any(self.child.as_ref(), key, variant)
77            .or_else(|| IServiceResolver::get_keyed_any(self.root.as_ref(), key, variant))
78    }
79}
80
81impl ServiceProviderWrapper {
82    /// Register a named service (for `impl_service_locator!` macro).
83    pub fn rdi_register_named(&self, name: &str, service: Arc<dyn Any + Send + Sync>) {
84        self.child.rdi_register_named(name, service);
85    }
86
87    /// Remove a named service (for `impl_service_locator!` macro).
88    pub fn rdi_remove_named(&self, name: &str) {
89        self.child.rdi_remove_named(name);
90    }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use crate::collection::ServiceCollection;

    #[derive(Debug, PartialEq)]
    struct PO;
    #[derive(Debug, PartialEq)]
    struct RO {
        v: i32,
    }
    #[derive(Debug, PartialEq)]
    struct B {
        s: String,
    }

    /// Builds a provider with no registrations.
    fn empty() -> Arc<ServiceProvider> {
        Arc::new(ServiceCollection::new().build().unwrap())
    }

    #[test]
    fn child_prio() {
        let r = Arc::new(
            ServiceCollection::new()
                .singleton(|_| Arc::new(B { s: "root".into() }))
                .build()
                .unwrap(),
        );
        let c = ServiceCollection::new()
            .singleton(|_| Arc::new(B { s: "child".into() }))
            .build()
            .unwrap();
        let w = ServiceProviderWrapper::new(Arc::new(c), r);
        assert_eq!(w.get::<B>().s, "child");
    }

    #[test]
    fn root_fallback() {
        let r = Arc::new(
            ServiceCollection::new()
                .singleton(|_| Arc::new(RO { v: 42 }))
                .build()
                .unwrap(),
        );
        let c = ServiceCollection::new().build().unwrap();
        let w = ServiceProviderWrapper::new(Arc::new(c), r);
        assert_eq!(w.get::<RO>().v, 42);
    }

    #[test]
    fn child_invisible() {
        let r = empty();
        let c = ServiceCollection::new()
            .singleton(|_| Arc::new(PO))
            .build()
            .unwrap();
        let w = ServiceProviderWrapper::new(Arc::new(c), r.clone());
        // The wrapper sees the child-only service...
        assert!(w.get_optional::<PO>().is_some());
        // ...but the root provider does not.
        assert!(r.get_optional::<PO>().is_none());
    }

    #[test]
    fn missing_everywhere_is_none() {
        let w = ServiceProviderWrapper::new(empty(), empty());
        assert!(w.get_optional::<PO>().is_none());
    }

    #[test]
    #[should_panic(expected = "service not registered")]
    fn missing_required_panics() {
        let w = ServiceProviderWrapper::new(empty(), empty());
        let _ = w.get::<PO>();
    }
}