// rust_dicore/collection.rs

1use crate::descriptor::ServiceDescriptor;
2use crate::entry::ServiceFactory;
3use crate::lifetime::ServiceLifetime;
4use crate::registration::ServiceRegistration;
5use std::any::{Any, TypeId};
6use std::sync::Arc;
7
/// Builder that accumulates service descriptors and is later consumed by
/// `build` to produce a `ServiceProvider`.
pub struct ServiceCollection {
    /// Descriptors in the order they were registered; `build` uses each
    /// descriptor's position in this list as its cache key.
    descriptors: Vec<ServiceDescriptor>,
}
11
12impl ServiceCollection {
13    pub fn new() -> Self {
14        Self {
15            descriptors: Vec::new(),
16        }
17    }
18
19    pub fn singleton<T: ?Sized + Send + Sync + 'static>(
20        mut self,
21        f: impl Fn(&dyn crate::entry::IServiceResolver) -> Arc<T> + Send + Sync + 'static,
22    ) -> Self {
23        self.push(ServiceLifetime::Singleton, None, f);
24        self
25    }
26
27    pub fn transient<T: ?Sized + Send + Sync + 'static>(
28        mut self,
29        f: impl Fn(&dyn crate::entry::IServiceResolver) -> Arc<T> + Send + Sync + 'static,
30    ) -> Self {
31        self.push(ServiceLifetime::Transient, None, f);
32        self
33    }
34
35    pub fn scoped<T: ?Sized + Send + Sync + 'static>(
36        mut self,
37        f: impl Fn(&dyn crate::entry::IServiceResolver) -> Arc<T> + Send + Sync + 'static,
38    ) -> Self {
39        self.push(ServiceLifetime::Scoped, None, f);
40        self
41    }
42
43    pub fn keyed<T: ?Sized + Send + Sync + 'static>(
44        mut self,
45        k: impl Into<String>,
46        f: impl Fn(&dyn crate::entry::IServiceResolver) -> Arc<T> + Send + Sync + 'static,
47    ) -> Self {
48        self.push(ServiceLifetime::Singleton, Some(k.into()), f);
49        self
50    }
51
52    pub fn keyed_transient<T: ?Sized + Send + Sync + 'static>(
53        mut self,
54        k: impl Into<String>,
55        f: impl Fn(&dyn crate::entry::IServiceResolver) -> Arc<T> + Send + Sync + 'static,
56    ) -> Self {
57        self.push(ServiceLifetime::Transient, Some(k.into()), f);
58        self
59    }
60
61    pub fn keyed_scoped<T: ?Sized + Send + Sync + 'static>(
62        mut self,
63        k: impl Into<String>,
64        f: impl Fn(&dyn crate::entry::IServiceResolver) -> Arc<T> + Send + Sync + 'static,
65    ) -> Self {
66        self.push(ServiceLifetime::Scoped, Some(k.into()), f);
67        self
68    }
69
70    pub fn try_add<T: ?Sized + Send + Sync + 'static>(
71        mut self,
72        f: impl Fn(&dyn crate::entry::IServiceResolver) -> Arc<T> + Send + Sync + 'static,
73    ) -> Self {
74        let tid = TypeId::of::<T>();
75        if self
76            .descriptors
77            .iter()
78            .any(|d| d.type_id == tid && d.key.is_none())
79        {
80            return self;
81        }
82        self.push(ServiceLifetime::Singleton, None, f);
83        self
84    }
85
86    pub fn add<T: ?Sized + Send + Sync + 'static>(
87        mut self,
88        lt: ServiceLifetime,
89        f: impl Fn(&dyn crate::entry::IServiceResolver) -> Arc<T> + Send + Sync + 'static,
90    ) -> Self {
91        self.push(lt, None, f);
92        self
93    }
94
95    pub fn instance<T: Send + Sync + 'static>(mut self, v: Arc<T>) -> Self {
96        let ff: ServiceFactory = Arc::new(move |_| Arc::new(v.clone()));
97        self.descriptors.push(ServiceDescriptor {
98            type_id: TypeId::of::<T>(),
99            type_name: std::any::type_name::<T>(),
100            key: None,
101            factory: ff,
102            lifetime: ServiceLifetime::Singleton,
103        });
104        self
105    }
106
107    pub fn singleton_value<T: Send + Sync + 'static>(self, v: T) -> Self {
108        self.instance(Arc::new(v))
109    }
110
111    pub fn build(self) -> Result<crate::provider::ServiceProvider, crate::error::RdiError> {
112        let mut s = crate::store::ServiceStore::new();
113        for (n, d) in self.descriptors.into_iter().enumerate() {
114            let e = crate::entry::ServiceEntry {
115                cache_key: n,
116                key: d.key,
117                type_name: d.type_name,
118                factory: d.factory,
119                lifetime: d.lifetime,
120            };
121            s.entry(d.type_id).or_default().push(e);
122        }
123        crate::provider::ServiceProvider::new(s)
124    }
125
126    /// Build a `ServiceCollection` from all `#[rust_dicore::inject]` annotations
127    /// in the current binary (across all crates).
128    #[cfg(any(target_os = "linux", target_os = "macos", target_os = "windows"))]
129    pub fn from_injected() -> Self {
130        let mut descriptors = Vec::new();
131        for reg in inventory::iter::<ServiceRegistration> {
132            let factory: ServiceFactory = Arc::new(move |r| (reg.factory)(r));
133            descriptors.push(ServiceDescriptor {
134                type_id: reg.type_id,
135                type_name: (reg.type_name_fn)(),
136                key: None,
137                factory,
138                lifetime: reg.lifetime,
139            });
140        }
141        Self { descriptors }
142    }
143
144    /// Register a new singleton service.
145    fn push<T: ?Sized + Send + Sync + 'static>(
146        &mut self,
147        lt: ServiceLifetime,
148        key: Option<String>,
149        f: impl Fn(&dyn crate::entry::IServiceResolver) -> Arc<T> + Send + Sync + 'static,
150    ) {
151        let sf: ServiceFactory = Arc::new(move |r| {
152            let val: Arc<T> = (f)(r);
153            Arc::new(val) as Arc<dyn Any + Send + Sync>
154        });
155        self.descriptors.push(ServiceDescriptor {
156            type_id: TypeId::of::<T>(),
157            type_name: std::any::type_name::<T>(),
158            key,
159            factory: sf,
160            lifetime: lt,
161        });
162    }
163}
164impl Default for ServiceCollection {
165    fn default() -> Self {
166        Self::new()
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Named fixture service used by the lookup tests.
    #[derive(Debug, PartialEq)]
    struct Greeter {
        name: String,
    }

    /// Numeric fixture service used by the caching and multi-registration tests.
    #[derive(Debug, PartialEq)]
    struct Counter {
        value: i32,
    }

    /// Builds a shared `Greeter` with the given name.
    fn greeter(name: &str) -> Arc<Greeter> {
        Arc::new(Greeter { name: name.into() })
    }

    /// A provider built from an empty collection resolves nothing.
    #[test]
    fn empty() {
        let provider = ServiceCollection::new().build().unwrap();
        assert!(provider.get_optional::<Greeter>().is_none());
    }

    /// A singleton registration can be resolved by type.
    #[test]
    fn singleton() {
        let provider = ServiceCollection::new()
            .singleton(|_| greeter("Hi"))
            .build()
            .unwrap();
        assert_eq!(provider.get::<Greeter>().name, "Hi");
    }

    /// The singleton factory runs once no matter how often the service is resolved.
    #[test]
    fn singleton_caches() {
        static CALLS: AtomicUsize = AtomicUsize::new(0);
        let provider = ServiceCollection::new()
            .singleton(|_| {
                CALLS.fetch_add(1, Ordering::SeqCst);
                Arc::new(Counter { value: 42 })
            })
            .build()
            .unwrap();
        for _ in 0..2 {
            let _ = provider.get::<Counter>();
        }
        assert_eq!(CALLS.load(Ordering::SeqCst), 1);
    }

    /// The transient factory runs on every resolution.
    #[test]
    fn transient_not_cached() {
        static CALLS: AtomicUsize = AtomicUsize::new(0);
        let provider = ServiceCollection::new()
            .transient(|_| {
                CALLS.fetch_add(1, Ordering::SeqCst);
                Arc::new(Counter { value: 1 })
            })
            .build()
            .unwrap();
        for _ in 0..2 {
            let _ = provider.get::<Counter>();
        }
        assert_eq!(CALLS.load(Ordering::SeqCst), 2);
    }

    /// A registered instance resolves to the very same allocation.
    #[test]
    fn instance() {
        let shared = greeter("Inst");
        let provider = ServiceCollection::new()
            .instance(shared.clone())
            .build()
            .unwrap();
        let resolved = provider.get::<Greeter>();
        assert!(Arc::ptr_eq(&shared, &resolved));
    }

    /// Keyed registrations of one type are resolved independently by key.
    #[test]
    fn keyed_svc() {
        let provider = ServiceCollection::new()
            .keyed("a", |_| greeter("A"))
            .keyed("b", |_| greeter("B"))
            .build()
            .unwrap();
        assert_eq!(provider.get_keyed::<Greeter>("a").name, "A");
        assert_eq!(provider.get_keyed::<Greeter>("b").name, "B");
    }

    /// `get_all` returns every registration of the requested type.
    #[test]
    fn get_all() {
        let provider = ServiceCollection::new()
            .keyed("x", |_| Arc::new(Counter { value: 1 }))
            .keyed("y", |_| Arc::new(Counter { value: 2 }))
            .build()
            .unwrap();
        let all = provider.get_all::<Counter>();
        assert_eq!(all.len(), 2);
    }

    /// `try_add` leaves an existing unkeyed registration in place.
    #[test]
    fn try_add_skips() {
        let provider = ServiceCollection::new()
            .singleton(|_| greeter("First"))
            .try_add(|_| greeter("Second"))
            .build()
            .unwrap();
        assert_eq!(provider.get::<Greeter>().name, "First");
    }
}