typed_container/
container.rs

1use std::{
2    any::{Any, TypeId},
3    collections::{HashMap, HashSet},
4    sync::{Arc, RwLock},
5};
6
7use crate::{Error, ErrorKind};
8
9/// A container that you can register your services or constructors and handle their dependencies
10#[derive(Clone)]
11pub struct Container<'a>(Arc<RwLock<ContainerImpl<'a>>>);
12
13impl<'a> Container<'a> {
14    pub fn new() -> Self {
15        Self(Arc::new(RwLock::new(ContainerImpl::default())))
16    }
17
18    fn register_service_internal(&self, boxed_service: BoxedService) -> Result<(), ErrorKind> {
19        let mut impl_obj = self.0.write()?;
20        if impl_obj.services.contains_key(&boxed_service.type_id) {
21            return Err(ErrorKind::Duplicated);
22        }
23        impl_obj
24            .services
25            .insert(boxed_service.type_id.clone(), boxed_service);
26
27        Ok(())
28    }
29
30    fn register_constructor_internal(
31        &self,
32        boxed_constructor: BoxedConstructor<'a>,
33    ) -> Result<(), ErrorKind> {
34        let mut write = self.0.write()?;
35        write
36            .constructors
37            .insert(boxed_constructor.type_id, Arc::new(boxed_constructor));
38
39        Ok(())
40    }
41
42    fn construct_internal<T: Clone + 'static>(&self) -> Result<T, ErrorKind> {
43        let type_id = TypeId::of::<T>();
44
45        // check pending and get the constructor
46        let constructor = {
47            let impl_ref = self.0.read()?;
48
49            if impl_ref.pending_construction.contains(&type_id) {
50                return Err(ErrorKind::CircularReference);
51            }
52
53            impl_ref
54                .constructors
55                .get(&type_id)
56                .ok_or(ErrorKind::NotFound)?
57                .clone()
58        };
59
60        // add current type into pending
61        self.0.write()?.pending_construction.insert(type_id.clone());
62
63        // construct the object
64        let construction = constructor.construct::<T>(self.clone());
65
66        // remove pending
67        self.0.write()?.pending_construction.remove(&type_id);
68
69        match construction {
70            None => Err(ErrorKind::FailDowncast),
71            Some(v) => Ok(v),
72        }
73    }
74
75    /// Register a new service.
76    ///
77    /// Panic if error occurred.
78    pub fn register_service<T: Clone + 'static>(&self, value: T) {
79        self.register_service_internal(BoxedService::from(value))
80            .unwrap()
81    }
82
83    /// Register a new service.
84    pub fn try_register_service<T: Clone + 'static>(&self, value: T) -> Result<(), Error<T>> {
85        Ok(self.register_service_internal(BoxedService::from(value))?)
86    }
87
88    /// Register a new constructor or replace the old constructor.
89    ///
90    /// Panic if error occurred.
91    pub fn register_constructor<T: Clone + 'static>(&self, value: impl Fn(Container) -> T + 'a) {
92        self.register_constructor_internal(BoxedConstructor::from(value))
93            .unwrap()
94    }
95
96    /// Register a new constructor or replace the old constructor.
97    pub fn try_register_constructor<T: Clone + 'static>(
98        &self,
99        value: impl Fn(Container) -> T + 'a,
100    ) -> Result<(), Error<T>> {
101        Ok(self.register_constructor_internal(BoxedConstructor::from(value))?)
102    }
103
104    /// Get a service from the container, if the service doesn't exist but a available constructor exists, it will try to construct it.
105    ///
106    /// Panic if error occurred.
107    pub fn get<T: Clone + 'static>(&self) -> T {
108        self.try_get().unwrap()
109    }
110
111    /// Get a service from the container, if the service doesn't exist but a available constructor exists, it will try to construct it.
112    pub fn try_get<T: Clone + 'static>(&self) -> Result<T, Error<T>> {
113        let type_id = TypeId::of::<T>();
114        {
115            let impl_obj = self.0.read()?;
116
117            if let Some(s) = impl_obj.services.get(&type_id) {
118                return match s.get_cloned() {
119                    Some(v) => Ok(v),
120                    None => Err(ErrorKind::FailDowncast.into()),
121                };
122            }
123
124            if !impl_obj.constructors.contains_key(&type_id) {
125                return Err(ErrorKind::NotFound.into());
126            }
127        }
128
129        let new_value = self.try_construct::<T>()?;
130
131        self.register_service_internal(BoxedService::from(new_value.clone()))?;
132        Ok(new_value)
133    }
134
135    /// Just construct a new object, whether it exists or not, it will be constructed and won't be inserted into container
136    ///
137    /// Panic if error occurred.
138    pub fn construct<T: Clone + 'static>(&self) -> T {
139        self.construct_internal().unwrap()
140    }
141
142    /// Just construct a new object, whether it exists or not, it will be constructed and won't be inserted into container
143    pub fn try_construct<T: Clone + 'static>(&self) -> Result<T, Error<T>> {
144        Ok(self.construct_internal()?)
145    }
146
147    /// Remove a service
148    pub fn remove_service<T: Clone + 'static>(&self) -> Result<T, Error<T>> {
149        Ok(self
150            .0
151            .write()?
152            .services
153            .remove(&TypeId::of::<T>())
154            .ok_or(ErrorKind::NotFound)?
155            .get_cloned::<T>()
156            .ok_or(ErrorKind::FailDowncast)?)
157    }
158
159    /// Remove a constructor
160    pub fn remove_constructor<T: Clone + 'static>(&self) -> Result<(), Error<T>> {
161        self.0.write()?.constructors.remove(&TypeId::of::<T>());
162        Ok(())
163    }
164
165    /// Remove all constructors and convert the lifetime to static
166    pub fn into_static(self) -> Container<'static> {
167        let mut services = HashMap::new();
168        std::mem::swap(&mut self.0.write().unwrap().services, &mut services);
169
170        Container::<'static>(Arc::new(RwLock::new(ContainerImpl {
171            services,
172            ..Default::default()
173        })))
174    }
175}
176
/// Shared state behind every clone of a [`Container`].
#[derive(Default)]
struct ContainerImpl<'a> {
    // Registered constructors, keyed by the `TypeId` of the value each one produces. They are
    // wrapped in `Arc` so one can be cloned out and called after the lock is released.
    pub constructors: HashMap<TypeId, Arc<BoxedConstructor<'a>>>,
    // Already-built services, keyed by the `TypeId` of their concrete type.
    pub services: HashMap<TypeId, BoxedService>,
    // Types whose constructor is currently running; used to detect circular dependencies.
    pub pending_construction: HashSet<TypeId>,
}
183
/// A service value with its concrete type erased, tagged with the `TypeId` it was built from.
struct BoxedService {
    /// `TypeId` of the concrete value held in `value`; the container uses it as the map key.
    pub type_id: TypeId,
    /// The service itself, stored as `dyn Any` so values of different types share one map.
    pub value: Box<dyn Any>,
}

impl BoxedService {
    /// Clones the stored value out as a `T`. Returns `None` when the value is not a `T`.
    fn get_cloned<T: Clone + 'static>(&self) -> Option<T> {
        let stored = self.value.downcast_ref::<T>()?;
        Some(stored.clone())
    }
}

impl<T> From<T> for BoxedService
where
    T: Clone + 'static,
{
    /// Boxes `value` behind `dyn Any` and records `TypeId::of::<T>()` alongside it.
    fn from(value: T) -> Self {
        let erased: Box<dyn Any> = Box::new(value);
        BoxedService {
            value: erased,
            type_id: TypeId::of::<T>(),
        }
    }
}
203
204struct BoxedConstructor<'a> {
205    pub type_id: TypeId,
206    pub value: Box<dyn Fn(Container) -> Box<dyn Any> + 'a>,
207}
208
209impl<'a> BoxedConstructor<'a> {
210    fn construct<T: Clone + 'static>(&self, container: Container) -> Option<T> {
211        let value = (self.value)(container).downcast::<T>();
212        match value {
213            Err(_) => None,
214            Ok(v) => Some(*v),
215        }
216    }
217}
218
219impl<'a, T: Clone + 'static, F: Fn(Container) -> T + 'a> From<F> for BoxedConstructor<'a> {
220    fn from(value: F) -> Self {
221        Self {
222            type_id: TypeId::of::<T>(),
223            value: Box::new(move |c| Box::new(value(c))),
224        }
225    }
226}
227
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::Container;

    #[test]
    fn basic_register() {
        let container = Container::new();
        container.register_service(String::from("A"));
        container.register_service(123_u64);

        assert_eq!(container.get::<String>(), "A");
        assert_eq!(container.get::<u64>(), 123);
    }

    #[test]
    fn basic_constructor() {
        let container = Container::new();
        container.register_constructor(|_| String::from("A"));
        container.register_constructor(|_| 123_u64);

        assert_eq!(container.get::<String>(), "A");
        assert_eq!(container.get::<u64>(), 123);
    }

    // These fields are never read. The types exist only to model a dependency graph
    // (A -> B -> C and A -> D) for the `complex` test.
    #[allow(dead_code)]
    struct A {
        b: Arc<B>,
        d: Arc<D>,
    }

    #[allow(dead_code)]
    struct B {
        c: Arc<C>,
    }

    struct C;

    #[derive(Clone)]
    struct D;

    #[test]
    fn complex() {
        let container = Container::new();
        container.register_constructor(|resolver| {
            Arc::new(A {
                b: resolver.get(),
                d: resolver.get(),
            })
        });
        container.register_constructor(|resolver| Arc::new(B { c: resolver.get() }));
        container.register_constructor(|_| Arc::new(C));
        container.register_constructor(|_| Arc::new(D));

        let _ = container.get::<Arc<A>>();
    }

    #[test]
    fn constructor_with_lifetime() {
        let borrowed_string = String::from("A");
        let borrowed_d = D;

        let container = Container::new();
        container.register_constructor(|_| borrowed_string.clone());
        container.register_constructor(|_| borrowed_d.clone());

        assert_eq!(container.get::<String>(), "A");
    }

    #[derive(Debug)]
    struct RefA {
        pub _b: Arc<RefB>,
    }

    #[derive(Debug)]
    struct RefB {
        pub _a: Arc<RefA>,
    }

    #[test]
    #[should_panic]
    fn circular_reference() {
        let container = Container::new();
        container.register_constructor(|resolver| Arc::new(RefA { _b: resolver.get() }));
        container.register_constructor(|resolver| Arc::new(RefB { _a: resolver.get() }));

        let _ = container.get::<Arc<RefA>>();
    }
}
316}