Skip to main content

rust_dicore/
provider.rs

1use crate::entry::{IServiceResolver, ServiceEntry, ServiceFactory};
2use crate::error::RdiError;
3use crate::lifetime::ServiceLifetime;
4use crate::store::ServiceStore;
5use std::any::{Any, TypeId};
6use std::collections::HashMap;
7use std::sync::{Arc, RwLock};
8
9pub struct ServiceProvider {
10    store: ServiceStore,
11    /// String → TypeId lookup for IServiceLocator string-based resolution.
12    type_map: HashMap<&'static str, TypeId>,
13    /// Eager-executed singleton cache. Indexed by cache_key.
14    /// Non-singleton entries are not present.
15    /// Uses RwLock for interior mutability: during eager initialization,
16    /// a singleton factory may reference another singleton not yet populated.
17    /// The lazy fallback in get_any_by_entry handles this by executing the
18    /// factory on-demand and caching the result.
19    singleton_cache: RwLock<HashMap<usize, Arc<dyn Any + Send + Sync>>>,
20    /// String-keyed registry for cross-DLL (cdylib) service access.
21    /// Rust's `TypeId` differs across compilation units, so named
22    /// lookup is the only reliable mechanism for plugin services.
23    pub(crate) named: RwLock<HashMap<String, Arc<dyn Any + Send + Sync>>>,
24}
25
26impl ServiceProvider {
27    pub(crate) fn new(store: ServiceStore) -> Result<Self, RdiError> {
28        // Build type_name → TypeId lookup table for string-based resolution
29        let mut type_map = HashMap::new();
30        for (&tid, entries) in &store {
31            if let Some(e) = entries.first() {
32                type_map.entry(e.type_name).or_insert(tid);
33            }
34        }
35
36        // Two-phase singleton initialization to support cross-references.
37        // Phase 1: collect all singleton entries.
38        let singleton_entries: Vec<(usize, ServiceFactory)> = store
39            .values()
40            .flat_map(|entries| entries.iter())
41            .filter(|e| e.lifetime == ServiceLifetime::Singleton)
42            .map(|e| (e.cache_key, e.factory.clone()))
43            .collect();
44
45        // Phase 2: eagerly execute all singleton factories.
46        // If a singleton factory references another singleton not yet
47        // populated, the lazy fallback in get_any_by_entry handles it
48        // by executing the factory on-demand via interior mutability.
49        let sp = Self {
50            store,
51            type_map,
52            singleton_cache: RwLock::new(HashMap::new()),
53            named: RwLock::new(HashMap::new()),
54        };
55
56        for (ck, factory) in &singleton_entries {
57            let instance = (factory)(&sp as &dyn IServiceResolver);
58            sp.singleton_cache.write().unwrap().insert(*ck, instance);
59        }
60
61        Ok(sp)
62    }
63
64    /// Resolve a service by type. Works uniformly for concrete types and trait objects.
65    /// Panics if not registered.
66    pub fn get<T: ?Sized + Send + Sync + 'static>(&self) -> Arc<T> {
67        self.try_get::<T>()
68            .unwrap_or_else(|| panic!("service not registered: {}", std::any::type_name::<T>()))
69    }
70
71    /// Resolve a service by type, returning `None` if not registered.
72    pub fn get_optional<T: ?Sized + Send + Sync + 'static>(&self) -> Option<Arc<T>> {
73        self.try_get::<T>()
74    }
75
76    /// Resolve a keyed service by type and key. Panics if not found.
77    pub fn get_keyed<T: ?Sized + Send + Sync + 'static>(&self, key: &str) -> Arc<T> {
78        self.try_get_keyed::<T>(key).unwrap_or_else(|| {
79            panic!(
80                "keyed service not registered: {}:{}",
81                std::any::type_name::<T>(),
82                key
83            )
84        })
85    }
86
87    /// Return all registered instances of the given type.
88    pub fn get_all<T: ?Sized + Send + Sync + 'static>(&self) -> Vec<Arc<T>> {
89        let tid = TypeId::of::<T>();
90        match self.store.get(&tid) {
91            Some(entries) => entries
92                .iter()
93                .filter_map(|e| {
94                    let arc = self.get_any_by_entry(e)?;
95                    Self::extract(arc)
96                })
97                .collect(),
98            None => Vec::new(),
99        }
100    }
101
102    /// Create a new service scope.
103    ///
104    /// Analogous to `IServiceProvider.CreateScope()` in MEDI.
105    /// Scoped-lifetime services are cached within the returned scope.
106    pub fn scope(self: &Arc<Self>) -> crate::scope::Scope {
107        crate::scope::Scope::new(self.clone())
108    }
109
110    /// Alias for `scope()` with MEDI-inspired naming.
111    pub fn create_scope(self: &Arc<Self>) -> crate::scope::Scope {
112        self.scope()
113    }
114
115    /// Alias for `get()` with MEDI-inspired naming (`GetService<T>()`).
116    pub fn get_service<T: ?Sized + Send + Sync + 'static>(&self) -> Option<Arc<T>> {
117        self.get_optional::<T>()
118    }
119
120    /// Resolve a service by type, panicking if not registered.
121    /// MEDI-inspired naming (`GetRequiredService<T>()`).
122    pub fn get_required_service<T: ?Sized + Send + Sync + 'static>(&self) -> Arc<T> {
123        self.get::<T>()
124    }
125
126    /// Return all registered instances of the given type.
127    /// MEDI-inspired naming (`GetServices<T>()`).
128    pub fn get_services<T: ?Sized + Send + Sync + 'static>(&self) -> Vec<Arc<T>> {
129        self.get_all::<T>()
130    }
131
132    fn try_get<T: ?Sized + Send + Sync + 'static>(&self) -> Option<Arc<T>> {
133        let tid = TypeId::of::<T>();
134        let entry = self.store.get(&tid)?.iter().find(|e| e.key.is_none())?;
135        let arc = self.get_any_by_entry(entry)?;
136        Self::extract(arc)
137    }
138
139    fn try_get_keyed<T: ?Sized + Send + Sync + 'static>(&self, key: &str) -> Option<Arc<T>> {
140        let tid = TypeId::of::<T>();
141        let entry = self
142            .store
143            .get(&tid)?
144            .iter()
145            .find(|e| e.key.as_deref() == Some(key))?;
146        let arc = self.get_any_by_entry(entry)?;
147        Self::extract(arc)
148    }
149
150    pub(crate) fn get_any_by_entry(
151        &self,
152        entry: &ServiceEntry,
153    ) -> Option<Arc<dyn Any + Send + Sync>> {
154        match entry.lifetime {
155            ServiceLifetime::Singleton => {
156                // Check eager cache (populated at build time).
157                // If not found (e.g. cross-reference during eager init), execute
158                // the factory on-demand as a lazy fallback.
159                {
160                    let cache = self.singleton_cache.read().unwrap();
161                    if let Some(instance) = cache.get(&entry.cache_key) {
162                        return Some(instance.clone());
163                    }
164                }
165                // Lazy fallback: execute factory on-demand and cache.
166                let instance = (entry.factory)(self);
167                self.singleton_cache
168                    .write()
169                    .unwrap()
170                    .insert(entry.cache_key, instance.clone());
171                Some(instance)
172            }
173            ServiceLifetime::Transient | ServiceLifetime::Scoped => Some((entry.factory)(self)),
174        }
175    }
176
177    /// Extract `Arc<T>` from `Arc<Arc<T>>` stored inside `Arc<dyn Any>`.
178    /// The factory double-wraps: inner `Arc<T>`, outer `Arc<dyn Any>`.
179    pub(crate) fn extract<T: ?Sized + Send + Sync + 'static>(
180        arc: Arc<dyn Any + Send + Sync>,
181    ) -> Option<Arc<T>> {
182        let double: Arc<Arc<T>> = arc.downcast::<Arc<T>>().ok()?;
183        Some(Arc::clone(&*double))
184    }
185
186    /// Get entries by TypeId (used internally and by Scope).
187    pub(crate) fn entries_by_tid(&self, tid: &TypeId) -> Option<&Vec<ServiceEntry>> {
188        self.store.get(tid)
189    }
190
191    /// Find entry by string type_name + variant (for string-based resolution).
192    pub(crate) fn entry_by_str(&self, type_name: &str, variant: &str) -> Option<&ServiceEntry> {
193        let tid = self.type_map.get(type_name)?;
194        self.store
195            .get(tid)?
196            .iter()
197            .find(|e| e.key.as_deref() == Some(variant))
198    }
199
200    /// Get entries by string type_name (for string-based resolution).
201    pub(crate) fn entries_by_str(&self, type_name: &str) -> Option<&Vec<ServiceEntry>> {
202        let tid = self.type_map.get(type_name)?;
203        self.store.get(tid)
204    }
205
206    /// Cross-DLL safe named service resolution (generic).
207    pub fn get_named<T: Send + Sync + 'static>(&self, name: &str) -> Option<Arc<T>> {
208        self.named
209            .read()
210            .unwrap()
211            .get(name)?
212            .clone()
213            .downcast::<T>()
214            .ok()
215    }
216
217    /// Non-generic named resolution; returns `Arc<dyn Any>` for trait-object dispatch.
218    pub fn get_named_any(&self, name: &str) -> Option<Arc<dyn Any + Send + Sync>> {
219        self.named.read().unwrap().get(name).cloned()
220    }
221
222    /// Register a named service for cross-DLL plugin access.
223    pub fn register_named<T: Send + Sync + 'static>(&self, name: &str, service: Arc<T>) {
224        self.named
225            .write()
226            .unwrap()
227            .insert(name.to_string(), service);
228    }
229
230    /// Remove a named service (for plugin unload).
231    pub fn remove_named(&self, name: &str) {
232        self.named.write().unwrap().remove(name);
233    }
234
235    /// Register a named service (for `impl_service_locator!` macro).
236    pub fn rdi_register_named(&self, name: &str, service: Arc<dyn Any + Send + Sync>) {
237        self.named
238            .write()
239            .unwrap()
240            .insert(name.to_string(), service);
241    }
242
243    /// Remove a named service (for `impl_service_locator!` macro).
244    pub fn rdi_remove_named(&self, name: &str) {
245        self.named.write().unwrap().remove(name);
246    }
247}
248
249impl IServiceResolver for ServiceProvider {
250    fn get_any(&self, key: &str) -> Option<Arc<dyn Any + Send + Sync>> {
251        let tid = self.type_map.get(key)?;
252        let entry = self.store.get(tid)?.iter().find(|e| e.key.is_none())?;
253        self.get_any_by_entry(entry)
254    }
255    fn get_keyed_any(&self, key: &str, variant: &str) -> Option<Arc<dyn Any + Send + Sync>> {
256        let entry = self.entry_by_str(key, variant)?;
257        self.get_any_by_entry(entry)
258    }
259}
260
#[cfg(test)]
mod tests {
    use crate::collection::ServiceCollection;
    use std::sync::Arc;

    /// Minimal service type used as a registration target.
    #[derive(Debug, PartialEq)]
    struct Calc(i32);

    /// Resolving a type that was never registered yields `None`.
    #[test]
    fn optional_missing() {
        let provider = ServiceCollection::new().build().unwrap();
        let resolved = provider.get_optional::<i32>();
        assert!(resolved.is_none());
    }

    /// `get_all` returns one instance per keyed registration of the type.
    #[test]
    fn all_basic() {
        let provider = ServiceCollection::new()
            .keyed("a", |_| Arc::new(Calc(1)))
            .keyed("b", |_| Arc::new(Calc(2)))
            .build()
            .unwrap();
        let every_calc = provider.get_all::<Calc>();
        assert_eq!(every_calc.len(), 2);
    }
}