1use crate::descriptor::ServiceDescriptor;
2use crate::entry::ServiceFactory;
3use crate::lifetime::ServiceLifetime;
4use crate::registration::ServiceRegistration;
5use std::any::{Any, TypeId};
6use std::sync::Arc;
7
8pub struct ServiceCollection {
9 descriptors: Vec<ServiceDescriptor>,
10}
11
12impl ServiceCollection {
13 pub fn new() -> Self {
14 Self {
15 descriptors: Vec::new(),
16 }
17 }
18
19 pub fn singleton<T: ?Sized + Send + Sync + 'static>(
20 mut self,
21 f: impl Fn(&dyn crate::entry::IServiceResolver) -> Arc<T> + Send + Sync + 'static,
22 ) -> Self {
23 self.push(ServiceLifetime::Singleton, None, f);
24 self
25 }
26
27 pub fn transient<T: ?Sized + Send + Sync + 'static>(
28 mut self,
29 f: impl Fn(&dyn crate::entry::IServiceResolver) -> Arc<T> + Send + Sync + 'static,
30 ) -> Self {
31 self.push(ServiceLifetime::Transient, None, f);
32 self
33 }
34
35 pub fn scoped<T: ?Sized + Send + Sync + 'static>(
36 mut self,
37 f: impl Fn(&dyn crate::entry::IServiceResolver) -> Arc<T> + Send + Sync + 'static,
38 ) -> Self {
39 self.push(ServiceLifetime::Scoped, None, f);
40 self
41 }
42
43 pub fn keyed<T: ?Sized + Send + Sync + 'static>(
44 mut self,
45 k: impl Into<String>,
46 f: impl Fn(&dyn crate::entry::IServiceResolver) -> Arc<T> + Send + Sync + 'static,
47 ) -> Self {
48 self.push(ServiceLifetime::Singleton, Some(k.into()), f);
49 self
50 }
51
52 pub fn keyed_transient<T: ?Sized + Send + Sync + 'static>(
53 mut self,
54 k: impl Into<String>,
55 f: impl Fn(&dyn crate::entry::IServiceResolver) -> Arc<T> + Send + Sync + 'static,
56 ) -> Self {
57 self.push(ServiceLifetime::Transient, Some(k.into()), f);
58 self
59 }
60
61 pub fn keyed_scoped<T: ?Sized + Send + Sync + 'static>(
62 mut self,
63 k: impl Into<String>,
64 f: impl Fn(&dyn crate::entry::IServiceResolver) -> Arc<T> + Send + Sync + 'static,
65 ) -> Self {
66 self.push(ServiceLifetime::Scoped, Some(k.into()), f);
67 self
68 }
69
70 pub fn try_add<T: ?Sized + Send + Sync + 'static>(
71 mut self,
72 f: impl Fn(&dyn crate::entry::IServiceResolver) -> Arc<T> + Send + Sync + 'static,
73 ) -> Self {
74 let tid = TypeId::of::<T>();
75 if self
76 .descriptors
77 .iter()
78 .any(|d| d.type_id == tid && d.key.is_none())
79 {
80 return self;
81 }
82 self.push(ServiceLifetime::Singleton, None, f);
83 self
84 }
85
86 pub fn add<T: ?Sized + Send + Sync + 'static>(
87 mut self,
88 lt: ServiceLifetime,
89 f: impl Fn(&dyn crate::entry::IServiceResolver) -> Arc<T> + Send + Sync + 'static,
90 ) -> Self {
91 self.push(lt, None, f);
92 self
93 }
94
95 pub fn instance<T: Send + Sync + 'static>(mut self, v: Arc<T>) -> Self {
96 let ff: ServiceFactory = Arc::new(move |_| Arc::new(v.clone()));
97 self.descriptors.push(ServiceDescriptor {
98 type_id: TypeId::of::<T>(),
99 type_name: std::any::type_name::<T>(),
100 key: None,
101 factory: ff,
102 lifetime: ServiceLifetime::Singleton,
103 });
104 self
105 }
106
107 pub fn singleton_value<T: Send + Sync + 'static>(self, v: T) -> Self {
108 self.instance(Arc::new(v))
109 }
110
111 pub fn build(self) -> Result<crate::provider::ServiceProvider, crate::error::RdiError> {
112 let mut s = crate::store::ServiceStore::new();
113 for (n, d) in self.descriptors.into_iter().enumerate() {
114 let e = crate::entry::ServiceEntry {
115 cache_key: n,
116 key: d.key,
117 type_name: d.type_name,
118 factory: d.factory,
119 lifetime: d.lifetime,
120 };
121 s.entry(d.type_id).or_default().push(e);
122 }
123 crate::provider::ServiceProvider::new(s)
124 }
125
126 #[cfg(any(target_os = "linux", target_os = "macos", target_os = "windows"))]
129 pub fn from_injected() -> Self {
130 let mut descriptors = Vec::new();
131 for reg in inventory::iter::<ServiceRegistration> {
132 let factory: ServiceFactory = Arc::new(move |r| (reg.factory)(r));
133 descriptors.push(ServiceDescriptor {
134 type_id: reg.type_id,
135 type_name: (reg.type_name_fn)(),
136 key: None,
137 factory,
138 lifetime: reg.lifetime,
139 });
140 }
141 Self { descriptors }
142 }
143
144 fn push<T: ?Sized + Send + Sync + 'static>(
146 &mut self,
147 lt: ServiceLifetime,
148 key: Option<String>,
149 f: impl Fn(&dyn crate::entry::IServiceResolver) -> Arc<T> + Send + Sync + 'static,
150 ) {
151 let sf: ServiceFactory = Arc::new(move |r| {
152 let val: Arc<T> = (f)(r);
153 Arc::new(val) as Arc<dyn Any + Send + Sync>
154 });
155 self.descriptors.push(ServiceDescriptor {
156 type_id: TypeId::of::<T>(),
157 type_name: std::any::type_name::<T>(),
158 key,
159 factory: sf,
160 lifetime: lt,
161 });
162 }
163}
164impl Default for ServiceCollection {
165 fn default() -> Self {
166 Self::new()
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[derive(Debug, PartialEq)]
    struct G {
        n: String,
    }

    #[derive(Debug, PartialEq)]
    struct C {
        v: i32,
    }

    #[test]
    fn empty() {
        let p = ServiceCollection::new().build().unwrap();
        assert!(p.get_optional::<G>().is_none());
    }

    #[test]
    fn singleton() {
        let p = ServiceCollection::new()
            .singleton(|_| Arc::new(G { n: "Hi".into() }))
            .build()
            .unwrap();
        assert_eq!(p.get::<G>().n, "Hi");
    }

    #[test]
    fn singleton_caches() {
        static CNT: AtomicUsize = AtomicUsize::new(0);
        let p = ServiceCollection::new()
            .singleton(|_| {
                CNT.fetch_add(1, Ordering::SeqCst);
                Arc::new(C { v: 42 })
            })
            .build()
            .unwrap();
        let first = p.get::<C>();
        let second = p.get::<C>();
        assert_eq!(first.v, 42);
        // A single factory call means both lookups share the one allocation.
        assert!(Arc::ptr_eq(&first, &second));
        assert_eq!(CNT.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn transient_not_cached() {
        static CNT: AtomicUsize = AtomicUsize::new(0);
        let p = ServiceCollection::new()
            .transient(|_| {
                CNT.fetch_add(1, Ordering::SeqCst);
                Arc::new(C { v: 1 })
            })
            .build()
            .unwrap();
        let a = p.get::<C>();
        let b = p.get::<C>();
        assert_eq!(a.v, 1);
        // Both values are alive at once, so distinct factory calls give distinct allocations.
        assert!(!Arc::ptr_eq(&a, &b));
        assert_eq!(CNT.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn instance() {
        let g = Arc::new(G { n: "Inst".into() });
        let p = ServiceCollection::new()
            .instance(g.clone())
            .build()
            .unwrap();
        assert!(Arc::ptr_eq(&g, &p.get::<G>()));
    }

    #[test]
    fn singleton_value() {
        let p = ServiceCollection::new()
            .singleton_value(C { v: 7 })
            .build()
            .unwrap();
        assert_eq!(p.get::<C>().v, 7);
    }

    #[test]
    fn keyed_svc() {
        let p = ServiceCollection::new()
            .keyed("a", |_| Arc::new(G { n: "A".into() }))
            .keyed("b", |_| Arc::new(G { n: "B".into() }))
            .build()
            .unwrap();
        assert_eq!(p.get_keyed::<G>("a").n, "A");
        assert_eq!(p.get_keyed::<G>("b").n, "B");
    }

    #[test]
    fn get_all() {
        let p = ServiceCollection::new()
            .keyed("x", |_| Arc::new(C { v: 1 }))
            .keyed("y", |_| Arc::new(C { v: 2 }))
            .build()
            .unwrap();
        assert_eq!(p.get_all::<C>().len(), 2);
    }

    #[test]
    fn try_add_skips() {
        let p = ServiceCollection::new()
            .singleton(|_| Arc::new(G { n: "First".into() }))
            .try_add(|_| Arc::new(G { n: "Second".into() }))
            .build()
            .unwrap();
        assert_eq!(p.get::<G>().n, "First");
    }

    #[test]
    fn try_add_registers_when_absent() {
        let p = ServiceCollection::new()
            .try_add(|_| Arc::new(G { n: "Only".into() }))
            .build()
            .unwrap();
        assert_eq!(p.get::<G>().n, "Only");
    }

    #[test]
    fn add_with_explicit_lifetime() {
        static CNT: AtomicUsize = AtomicUsize::new(0);
        let p = ServiceCollection::new()
            .add(ServiceLifetime::Transient, |_| {
                CNT.fetch_add(1, Ordering::SeqCst);
                Arc::new(C { v: 3 })
            })
            .build()
            .unwrap();
        assert_eq!(p.get::<C>().v, 3);
        let _ = p.get::<C>();
        assert_eq!(CNT.load(Ordering::SeqCst), 2);
    }
}