1use crate::entry::{IServiceResolver, ServiceEntry, ServiceFactory};
2use crate::error::RdiError;
3use crate::lifetime::ServiceLifetime;
4use crate::store::ServiceStore;
5use std::any::{Any, TypeId};
6use std::collections::HashMap;
7use std::sync::{Arc, RwLock};
8
/// Dependency-injection container that resolves services from a [`ServiceStore`],
/// caching singleton instances and holding instances registered by name at runtime.
pub struct ServiceProvider {
    /// Registrations, grouped by the `TypeId` they were registered under.
    store: ServiceStore,
    /// Type-name → `TypeId` index backing the string-based lookups
    /// (`IServiceResolver::get_any`, `entry_by_str`, `entries_by_str`).
    type_map: HashMap<&'static str, TypeId>,
    /// Built singleton instances, keyed by `ServiceEntry::cache_key`.
    singleton_cache: RwLock<HashMap<usize, Arc<dyn Any + Send + Sync>>>,
    /// Instances registered by name via `register_named` / `rdi_register_named`.
    pub(crate) named: RwLock<HashMap<String, Arc<dyn Any + Send + Sync>>>,
}
25
26impl ServiceProvider {
27 pub(crate) fn new(store: ServiceStore) -> Result<Self, RdiError> {
28 let mut type_map = HashMap::new();
30 for (&tid, entries) in &store {
31 if let Some(e) = entries.first() {
32 type_map.entry(e.type_name).or_insert(tid);
33 }
34 }
35
36 let singleton_entries: Vec<(usize, ServiceFactory)> = store
39 .values()
40 .flat_map(|entries| entries.iter())
41 .filter(|e| e.lifetime == ServiceLifetime::Singleton)
42 .map(|e| (e.cache_key, e.factory.clone()))
43 .collect();
44
45 let sp = Self {
50 store,
51 type_map,
52 singleton_cache: RwLock::new(HashMap::new()),
53 named: RwLock::new(HashMap::new()),
54 };
55
56 for (ck, factory) in &singleton_entries {
57 let instance = (factory)(&sp as &dyn IServiceResolver);
58 sp.singleton_cache.write().unwrap().insert(*ck, instance);
59 }
60
61 Ok(sp)
62 }
63
64 pub fn get<T: ?Sized + Send + Sync + 'static>(&self) -> Arc<T> {
67 self.try_get::<T>()
68 .unwrap_or_else(|| panic!("service not registered: {}", std::any::type_name::<T>()))
69 }
70
71 pub fn get_optional<T: ?Sized + Send + Sync + 'static>(&self) -> Option<Arc<T>> {
73 self.try_get::<T>()
74 }
75
76 pub fn get_keyed<T: ?Sized + Send + Sync + 'static>(&self, key: &str) -> Arc<T> {
78 self.try_get_keyed::<T>(key).unwrap_or_else(|| {
79 panic!(
80 "keyed service not registered: {}:{}",
81 std::any::type_name::<T>(),
82 key
83 )
84 })
85 }
86
87 pub fn get_all<T: ?Sized + Send + Sync + 'static>(&self) -> Vec<Arc<T>> {
89 let tid = TypeId::of::<T>();
90 match self.store.get(&tid) {
91 Some(entries) => entries
92 .iter()
93 .filter_map(|e| {
94 let arc = self.get_any_by_entry(e)?;
95 Self::extract(arc)
96 })
97 .collect(),
98 None => Vec::new(),
99 }
100 }
101
102 pub fn scope(self: &Arc<Self>) -> crate::scope::Scope {
107 crate::scope::Scope::new(self.clone())
108 }
109
110 pub fn create_scope(self: &Arc<Self>) -> crate::scope::Scope {
112 self.scope()
113 }
114
115 pub fn get_service<T: ?Sized + Send + Sync + 'static>(&self) -> Option<Arc<T>> {
117 self.get_optional::<T>()
118 }
119
120 pub fn get_required_service<T: ?Sized + Send + Sync + 'static>(&self) -> Arc<T> {
123 self.get::<T>()
124 }
125
126 pub fn get_services<T: ?Sized + Send + Sync + 'static>(&self) -> Vec<Arc<T>> {
129 self.get_all::<T>()
130 }
131
132 fn try_get<T: ?Sized + Send + Sync + 'static>(&self) -> Option<Arc<T>> {
133 let tid = TypeId::of::<T>();
134 let entry = self.store.get(&tid)?.iter().find(|e| e.key.is_none())?;
135 let arc = self.get_any_by_entry(entry)?;
136 Self::extract(arc)
137 }
138
139 fn try_get_keyed<T: ?Sized + Send + Sync + 'static>(&self, key: &str) -> Option<Arc<T>> {
140 let tid = TypeId::of::<T>();
141 let entry = self
142 .store
143 .get(&tid)?
144 .iter()
145 .find(|e| e.key.as_deref() == Some(key))?;
146 let arc = self.get_any_by_entry(entry)?;
147 Self::extract(arc)
148 }
149
150 pub(crate) fn get_any_by_entry(
151 &self,
152 entry: &ServiceEntry,
153 ) -> Option<Arc<dyn Any + Send + Sync>> {
154 match entry.lifetime {
155 ServiceLifetime::Singleton => {
156 {
160 let cache = self.singleton_cache.read().unwrap();
161 if let Some(instance) = cache.get(&entry.cache_key) {
162 return Some(instance.clone());
163 }
164 }
165 let instance = (entry.factory)(self);
167 self.singleton_cache
168 .write()
169 .unwrap()
170 .insert(entry.cache_key, instance.clone());
171 Some(instance)
172 }
173 ServiceLifetime::Transient | ServiceLifetime::Scoped => Some((entry.factory)(self)),
174 }
175 }
176
177 pub(crate) fn extract<T: ?Sized + Send + Sync + 'static>(
180 arc: Arc<dyn Any + Send + Sync>,
181 ) -> Option<Arc<T>> {
182 let double: Arc<Arc<T>> = arc.downcast::<Arc<T>>().ok()?;
183 Some(Arc::clone(&*double))
184 }
185
186 pub(crate) fn entries_by_tid(&self, tid: &TypeId) -> Option<&Vec<ServiceEntry>> {
188 self.store.get(tid)
189 }
190
191 pub(crate) fn entry_by_str(&self, type_name: &str, variant: &str) -> Option<&ServiceEntry> {
193 let tid = self.type_map.get(type_name)?;
194 self.store
195 .get(tid)?
196 .iter()
197 .find(|e| e.key.as_deref() == Some(variant))
198 }
199
200 pub(crate) fn entries_by_str(&self, type_name: &str) -> Option<&Vec<ServiceEntry>> {
202 let tid = self.type_map.get(type_name)?;
203 self.store.get(tid)
204 }
205
206 pub fn get_named<T: Send + Sync + 'static>(&self, name: &str) -> Option<Arc<T>> {
208 self.named
209 .read()
210 .unwrap()
211 .get(name)?
212 .clone()
213 .downcast::<T>()
214 .ok()
215 }
216
217 pub fn get_named_any(&self, name: &str) -> Option<Arc<dyn Any + Send + Sync>> {
219 self.named.read().unwrap().get(name).cloned()
220 }
221
222 pub fn register_named<T: Send + Sync + 'static>(&self, name: &str, service: Arc<T>) {
224 self.named
225 .write()
226 .unwrap()
227 .insert(name.to_string(), service);
228 }
229
230 pub fn remove_named(&self, name: &str) {
232 self.named.write().unwrap().remove(name);
233 }
234
235 pub fn rdi_register_named(&self, name: &str, service: Arc<dyn Any + Send + Sync>) {
237 self.named
238 .write()
239 .unwrap()
240 .insert(name.to_string(), service);
241 }
242
243 pub fn rdi_remove_named(&self, name: &str) {
245 self.named.write().unwrap().remove(name);
246 }
247}
248
249impl IServiceResolver for ServiceProvider {
250 fn get_any(&self, key: &str) -> Option<Arc<dyn Any + Send + Sync>> {
251 let tid = self.type_map.get(key)?;
252 let entry = self.store.get(tid)?.iter().find(|e| e.key.is_none())?;
253 self.get_any_by_entry(entry)
254 }
255 fn get_keyed_any(&self, key: &str, variant: &str) -> Option<Arc<dyn Any + Send + Sync>> {
256 let entry = self.entry_by_str(key, variant)?;
257 self.get_any_by_entry(entry)
258 }
259}
260
#[cfg(test)]
mod tests {
    use crate::collection::ServiceCollection;
    use std::sync::Arc;

    #[derive(Debug, PartialEq)]
    struct Calc(i32);

    /// An empty collection resolves nothing.
    #[test]
    fn optional_missing() {
        let p = ServiceCollection::new().build().unwrap();
        assert!(p.get_optional::<i32>().is_none());
    }

    /// `get_all` returns every keyed registration of a type.
    #[test]
    fn all_basic() {
        let p = ServiceCollection::new()
            .keyed("a", |_| Arc::new(Calc(1)))
            .keyed("b", |_| Arc::new(Calc(2)))
            .build()
            .unwrap();
        assert_eq!(p.get_all::<Calc>().len(), 2);
    }

    /// `get_keyed` picks the registration matching the requested key.
    #[test]
    fn keyed_lookup() {
        let p = ServiceCollection::new()
            .keyed("a", |_| Arc::new(Calc(1)))
            .keyed("b", |_| Arc::new(Calc(2)))
            .build()
            .unwrap();
        assert_eq!(*p.get_keyed::<Calc>("a"), Calc(1));
        assert_eq!(*p.get_keyed::<Calc>("b"), Calc(2));
        // Keyed registrations are not visible to unkeyed lookups.
        assert!(p.get_optional::<Calc>().is_none());
    }

    /// Asking for an unregistered key panics with a descriptive message.
    #[test]
    #[should_panic(expected = "keyed service not registered")]
    fn keyed_missing_panics() {
        let p = ServiceCollection::new()
            .keyed("a", |_| Arc::new(Calc(1)))
            .build()
            .unwrap();
        let _ = p.get_keyed::<Calc>("missing");
    }
}