Skip to main content

rust_dicore/
scope.rs

1use crate::entry::IServiceResolver;
2use crate::lifetime::ServiceLifetime;
3use crate::provider::ServiceProvider;
4use std::any::{Any, TypeId};
5use std::collections::HashMap;
6use std::sync::{Arc, RwLock};
7
8pub struct Scope {
9    parent: Arc<ServiceProvider>,
10    scoped_cache: RwLock<HashMap<usize, Arc<dyn Any + Send + Sync>>>,
11}
12
13impl Scope {
14    pub(crate) fn new(parent: Arc<ServiceProvider>) -> Self {
15        Self {
16            parent,
17            scoped_cache: RwLock::new(HashMap::new()),
18        }
19    }
20
21    pub fn get<T: ?Sized + Send + Sync + 'static>(&self) -> Arc<T> {
22        self.try_get::<T>()
23            .unwrap_or_else(|| panic!("service not registered: {}", std::any::type_name::<T>()))
24    }
25
26    pub fn get_optional<T: ?Sized + Send + Sync + 'static>(&self) -> Option<Arc<T>> {
27        self.try_get::<T>()
28    }
29
30    /// Alias for `get_optional()` with MEDI-inspired naming.
31    pub fn get_service<T: ?Sized + Send + Sync + 'static>(&self) -> Option<Arc<T>> {
32        self.get_optional::<T>()
33    }
34
35    /// Alias for `get()` with MEDI-inspired naming.
36    pub fn get_required_service<T: ?Sized + Send + Sync + 'static>(&self) -> Arc<T> {
37        self.get::<T>()
38    }
39
40    /// Return all registered instances of the given type.
41    /// MEDI-inspired naming.
42    pub fn get_services<T: ?Sized + Send + Sync + 'static>(&self) -> Vec<Arc<T>> {
43        self.get_all::<T>()
44    }
45
46    pub fn get_keyed<T: ?Sized + Send + Sync + 'static>(&self, key: &str) -> Arc<T> {
47        self.try_get_keyed::<T>(key).unwrap_or_else(|| {
48            panic!(
49                "keyed service not registered: {}:{}",
50                std::any::type_name::<T>(),
51                key
52            )
53        })
54    }
55
56    pub fn get_all<T: ?Sized + Send + Sync + 'static>(&self) -> Vec<Arc<T>> {
57        let tid = TypeId::of::<T>();
58        if let Some(entries) = self.parent.entries_by_tid(&tid) {
59            entries
60                .iter()
61                .filter_map(|e| {
62                    let arc = self.get_any_by_entry(e)?;
63                    ServiceProvider::extract(arc)
64                })
65                .collect()
66        } else {
67            Vec::new()
68        }
69    }
70
71    pub fn get_named_any(&self, name: &str) -> Option<Arc<dyn Any + Send + Sync>> {
72        self.parent.get_named_any(name)
73    }
74
75    fn try_get<T: ?Sized + Send + Sync + 'static>(&self) -> Option<Arc<T>> {
76        let tid = TypeId::of::<T>();
77        let entry = self
78            .parent
79            .entries_by_tid(&tid)?
80            .iter()
81            .find(|e| e.key.is_none())?;
82        let arc = self.get_any_by_entry(entry)?;
83        ServiceProvider::extract(arc)
84    }
85
86    fn try_get_keyed<T: ?Sized + Send + Sync + 'static>(&self, key: &str) -> Option<Arc<T>> {
87        let tid = TypeId::of::<T>();
88        let entries = self.parent.entries_by_tid(&tid)?;
89        let entry = entries.iter().find(|e| e.key.as_deref() == Some(key))?;
90        let arc = self.get_any_by_entry(entry)?;
91        ServiceProvider::extract(arc)
92    }
93
94    fn get_any_by_entry(
95        &self,
96        entry: &crate::entry::ServiceEntry,
97    ) -> Option<Arc<dyn Any + Send + Sync>> {
98        match entry.lifetime {
99            ServiceLifetime::Singleton => {
100                // Singleton cache lives in parent's eager cache
101                self.parent.get_any_by_entry(entry)
102            }
103            ServiceLifetime::Transient => Some((entry.factory)(self.parent.as_ref())),
104            ServiceLifetime::Scoped => {
105                {
106                    let cache = self.scoped_cache.read().unwrap();
107                    if let Some(instance) = cache.get(&entry.cache_key) {
108                        return Some(instance.clone());
109                    }
110                }
111                let instance = (entry.factory)(self);
112                {
113                    self.scoped_cache
114                        .write()
115                        .unwrap()
116                        .insert(entry.cache_key, instance.clone());
117                }
118                Some(instance)
119            }
120        }
121    }
122}
123
124impl IServiceResolver for Scope {
125    fn get_any(&self, key: &str) -> Option<Arc<dyn Any + Send + Sync>> {
126        if let Some(entries) = self.parent.entries_by_str(key) {
127            for entry in entries {
128                if entry.key.is_none() {
129                    if let Some(r) = self.get_any_by_entry(entry) {
130                        return Some(r);
131                    }
132                }
133            }
134        }
135        None
136    }
137    fn get_keyed_any(&self, key: &str, variant: &str) -> Option<Arc<dyn Any + Send + Sync>> {
138        let entry = self.parent.entry_by_str(key, variant)?;
139        self.get_any_by_entry(entry)
140    }
141}
142
143impl Scope {
144    /// Register a named service (for `impl_service_locator!` macro).
145    pub fn rdi_register_named(&self, name: &str, service: Arc<dyn Any + Send + Sync>) {
146        self.parent.rdi_register_named(name, service);
147    }
148
149    /// Remove a named service (for `impl_service_locator!` macro).
150    pub fn rdi_remove_named(&self, name: &str) {
151        self.parent.rdi_remove_named(name);
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::collection::ServiceCollection;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Test service whose payload is the sequence number of its construction.
    #[derive(Debug, PartialEq)]
    struct Sd(u64);

    #[test]
    fn scoped_cached_per_scope() {
        static NEXT_ID: AtomicU64 = AtomicU64::new(0);
        let built = ServiceCollection::new()
            .scoped(|_| Arc::new(Sd(NEXT_ID.fetch_add(1, Ordering::SeqCst))))
            .build()
            .unwrap();
        let provider = Arc::new(built);

        // Two lookups within one scope must share a single instance.
        let first_scope = provider.scope();
        let first = first_scope.get::<Sd>();
        let repeat = first_scope.get::<Sd>();
        assert_eq!(first.0, repeat.0);

        // A fresh scope must construct its own instance.
        let second_scope = provider.scope();
        let other = second_scope.get::<Sd>();
        assert_ne!(first.0, other.0);
    }
}