1use crate::entry::IServiceResolver;
2use crate::lifetime::ServiceLifetime;
3use crate::provider::ServiceProvider;
4use std::any::{Any, TypeId};
5use std::collections::HashMap;
6use std::sync::{Arc, RwLock};
7
8pub struct Scope {
9 parent: Arc<ServiceProvider>,
10 scoped_cache: RwLock<HashMap<usize, Arc<dyn Any + Send + Sync>>>,
11}
12
13impl Scope {
14 pub(crate) fn new(parent: Arc<ServiceProvider>) -> Self {
15 Self {
16 parent,
17 scoped_cache: RwLock::new(HashMap::new()),
18 }
19 }
20
21 pub fn get<T: ?Sized + Send + Sync + 'static>(&self) -> Arc<T> {
22 self.try_get::<T>()
23 .unwrap_or_else(|| panic!("service not registered: {}", std::any::type_name::<T>()))
24 }
25
26 pub fn get_optional<T: ?Sized + Send + Sync + 'static>(&self) -> Option<Arc<T>> {
27 self.try_get::<T>()
28 }
29
30 pub fn get_service<T: ?Sized + Send + Sync + 'static>(&self) -> Option<Arc<T>> {
32 self.get_optional::<T>()
33 }
34
35 pub fn get_required_service<T: ?Sized + Send + Sync + 'static>(&self) -> Arc<T> {
37 self.get::<T>()
38 }
39
40 pub fn get_services<T: ?Sized + Send + Sync + 'static>(&self) -> Vec<Arc<T>> {
43 self.get_all::<T>()
44 }
45
46 pub fn get_keyed<T: ?Sized + Send + Sync + 'static>(&self, key: &str) -> Arc<T> {
47 self.try_get_keyed::<T>(key).unwrap_or_else(|| {
48 panic!(
49 "keyed service not registered: {}:{}",
50 std::any::type_name::<T>(),
51 key
52 )
53 })
54 }
55
56 pub fn get_all<T: ?Sized + Send + Sync + 'static>(&self) -> Vec<Arc<T>> {
57 let tid = TypeId::of::<T>();
58 if let Some(entries) = self.parent.entries_by_tid(&tid) {
59 entries
60 .iter()
61 .filter_map(|e| {
62 let arc = self.get_any_by_entry(e)?;
63 ServiceProvider::extract(arc)
64 })
65 .collect()
66 } else {
67 Vec::new()
68 }
69 }
70
71 pub fn get_named_any(&self, name: &str) -> Option<Arc<dyn Any + Send + Sync>> {
72 self.parent.get_named_any(name)
73 }
74
75 fn try_get<T: ?Sized + Send + Sync + 'static>(&self) -> Option<Arc<T>> {
76 let tid = TypeId::of::<T>();
77 let entry = self
78 .parent
79 .entries_by_tid(&tid)?
80 .iter()
81 .find(|e| e.key.is_none())?;
82 let arc = self.get_any_by_entry(entry)?;
83 ServiceProvider::extract(arc)
84 }
85
86 fn try_get_keyed<T: ?Sized + Send + Sync + 'static>(&self, key: &str) -> Option<Arc<T>> {
87 let tid = TypeId::of::<T>();
88 let entries = self.parent.entries_by_tid(&tid)?;
89 let entry = entries.iter().find(|e| e.key.as_deref() == Some(key))?;
90 let arc = self.get_any_by_entry(entry)?;
91 ServiceProvider::extract(arc)
92 }
93
94 fn get_any_by_entry(
95 &self,
96 entry: &crate::entry::ServiceEntry,
97 ) -> Option<Arc<dyn Any + Send + Sync>> {
98 match entry.lifetime {
99 ServiceLifetime::Singleton => {
100 self.parent.get_any_by_entry(entry)
102 }
103 ServiceLifetime::Transient => Some((entry.factory)(self.parent.as_ref())),
104 ServiceLifetime::Scoped => {
105 {
106 let cache = self.scoped_cache.read().unwrap();
107 if let Some(instance) = cache.get(&entry.cache_key) {
108 return Some(instance.clone());
109 }
110 }
111 let instance = (entry.factory)(self);
112 {
113 self.scoped_cache
114 .write()
115 .unwrap()
116 .insert(entry.cache_key, instance.clone());
117 }
118 Some(instance)
119 }
120 }
121 }
122}
123
124impl IServiceResolver for Scope {
125 fn get_any(&self, key: &str) -> Option<Arc<dyn Any + Send + Sync>> {
126 if let Some(entries) = self.parent.entries_by_str(key) {
127 for entry in entries {
128 if entry.key.is_none() {
129 if let Some(r) = self.get_any_by_entry(entry) {
130 return Some(r);
131 }
132 }
133 }
134 }
135 None
136 }
137 fn get_keyed_any(&self, key: &str, variant: &str) -> Option<Arc<dyn Any + Send + Sync>> {
138 let entry = self.parent.entry_by_str(key, variant)?;
139 self.get_any_by_entry(entry)
140 }
141}
142
143impl Scope {
144 pub fn rdi_register_named(&self, name: &str, service: Arc<dyn Any + Send + Sync>) {
146 self.parent.rdi_register_named(name, service);
147 }
148
149 pub fn rdi_remove_named(&self, name: &str) {
151 self.parent.rdi_remove_named(name);
152 }
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::collection::ServiceCollection;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Test service carrying the sequence number it was constructed with.
    #[derive(Debug, PartialEq)]
    struct Sd(u64);

    /// A scoped service is built once per scope and again for each new scope.
    #[test]
    fn scoped_cached_per_scope() {
        static NXT: AtomicU64 = AtomicU64::new(0);
        let provider = Arc::new(
            ServiceCollection::new()
                .scoped(|_| Arc::new(Sd(NXT.fetch_add(1, Ordering::SeqCst))))
                .build()
                .unwrap(),
        );

        // Two lookups in the same scope observe the same sequence number.
        let first_scope = provider.scope();
        let first = first_scope.get::<Sd>();
        let repeat = first_scope.get::<Sd>();
        assert_eq!(first.0, repeat.0);

        // A separate scope gets its own instance.
        let second_scope = provider.scope();
        let other = second_scope.get::<Sd>();
        assert_ne!(first.0, other.0);
    }
}
179}