typed_container/
container.rs

1use std::{
2    any::{Any, TypeId},
3    collections::HashMap,
4    sync::{Arc, RwLock},
5};
6
7use crate::{Error, ErrorKind};
8
9#[derive(Clone)]
10pub struct Container<'a>(Arc<RwLock<ContainerImpl<'a>>>);
11
12impl<'a> Container<'a> {
13    pub fn new() -> Self {
14        Self(Arc::new(RwLock::new(ContainerImpl::default())))
15    }
16
17    fn register_service_internal(&self, boxed_service: BoxedService) -> Result<(), ErrorKind> {
18        {
19            let read = self.0.read().map_err(|_| ErrorKind::LockPoisoned)?;
20
21            if read.services.contains_key(&boxed_service.type_id) {
22                return Err(ErrorKind::Duplicated);
23            }
24        }
25
26        {
27            let mut write = self.0.write().map_err(|_| ErrorKind::LockPoisoned)?;
28            write
29                .services
30                .insert(boxed_service.type_id.clone(), boxed_service);
31        }
32
33        Ok(())
34    }
35
36    fn register_constructor_internal(
37        &self,
38        boxed_constructor: BoxedConstructor<'a>,
39    ) -> Result<(), ErrorKind> {
40        {
41            let read = self.0.read().map_err(|_| ErrorKind::LockPoisoned)?;
42
43            if read.constructors.contains_key(&boxed_constructor.type_id) {
44                return Err(ErrorKind::Duplicated);
45            }
46        }
47
48        {
49            let mut write = self.0.write().map_err(|_| ErrorKind::LockPoisoned)?;
50            write
51                .constructors
52                .insert(boxed_constructor.type_id, Arc::new(boxed_constructor));
53        }
54
55        Ok(())
56    }
57
58    pub fn construct<T: Clone + 'static>(&self) -> T {
59        self.try_construct().unwrap()
60    }
61
62    pub fn try_construct<T: Clone + 'static>(&self) -> Result<T, Error<T>> {
63        let type_id = TypeId::of::<T>();
64        let construct = {
65            let read = self.0.read()?;
66
67            read.constructors
68                .get(&type_id)
69                .ok_or(ErrorKind::NotFound)?
70                .clone()
71        };
72
73        match construct.construct::<T>(self.clone()) {
74            None => Err(ErrorKind::FailDowncast.into()),
75            Some(v) => Ok(v),
76        }
77    }
78
79    pub fn register_service<T: Clone + 'static>(&self, value: T) {
80        self.register_service_internal(BoxedService::from(value))
81            .unwrap()
82    }
83
84    pub fn try_register_service<T: Clone + 'static>(&self, value: T) -> Result<(), Error<T>> {
85        Ok(self.register_service_internal(BoxedService::from(value))?)
86    }
87
88    pub fn register_constructor<T: Clone + 'static>(&self, value: impl Fn(Container) -> T + 'a) {
89        self.register_constructor_internal(BoxedConstructor::from(value))
90            .unwrap()
91    }
92
93    pub fn try_register_constructor<T: Clone + 'static>(
94        &self,
95        value: impl Fn(Container) -> T + 'a,
96    ) -> Result<(), Error<T>> {
97        Ok(self.register_constructor_internal(BoxedConstructor::from(value))?)
98    }
99
100    pub fn get<T: Clone + 'static>(&self) -> T {
101        self.try_get().unwrap()
102    }
103
104    pub fn try_get<T: Clone + 'static>(&self) -> Result<T, Error<T>> {
105        let type_id = TypeId::of::<T>();
106        {
107            let read = self.0.read()?;
108            if let Some(s) = read.services.get(&type_id) {
109                return match s.get_cloned() {
110                    Some(v) => Ok(v),
111                    None => Err(ErrorKind::FailDowncast.into()),
112                };
113            }
114
115            if !read.constructors.contains_key(&type_id) {
116                return Err(ErrorKind::NotFound.into());
117            }
118        }
119
120        let new_value = self.try_construct::<T>()?;
121
122        self.register_service_internal(BoxedService::from(new_value.clone()))?;
123        Ok(new_value)
124    }
125}
126
127#[derive(Default)]
128struct ContainerImpl<'a> {
129    pub constructors: HashMap<TypeId, Arc<BoxedConstructor<'a>>>,
130    pub services: HashMap<TypeId, BoxedService>,
131}
132
/// A type-erased service value, tagged with the `TypeId` of its concrete type.
struct BoxedService {
    /// The service itself, erased so values of different types fit in one map.
    pub value: Box<dyn Any>,
    /// `TypeId` of the concrete type inside `value`; used as the registry key.
    pub type_id: TypeId,
}
137
138impl BoxedService {
139    fn get_cloned<T: Clone + 'static>(&self) -> Option<T> {
140        self.value.downcast_ref::<T>().cloned()
141    }
142}
143
144impl<T: Clone + 'static> From<T> for BoxedService {
145    fn from(value: T) -> Self {
146        Self {
147            type_id: TypeId::of::<T>(),
148            value: Box::new(value),
149        }
150    }
151}
152
153struct BoxedConstructor<'a> {
154    pub type_id: TypeId,
155    pub value: Box<dyn Fn(Container) -> Box<dyn Any> + 'a>,
156}
157
158impl<'a> BoxedConstructor<'a> {
159    fn construct<T: Clone + 'static>(&self, container: Container) -> Option<T> {
160        let value = (self.value)(container).downcast::<T>();
161        match value {
162            Err(_) => None,
163            Ok(v) => Some(*v),
164        }
165    }
166}
167
168impl<'a, T: Clone + 'static, F: Fn(Container) -> T + 'a> From<F> for BoxedConstructor<'a> {
169    fn from(value: F) -> Self {
170        Self {
171            type_id: TypeId::of::<T>(),
172            value: Box::new(move |c| Box::new(value(c))),
173        }
174    }
175}
176
#[cfg(test)]
mod tests {
    use std::{cell::Cell, sync::Arc};

    use crate::Container;

    #[test]
    fn basic_register() {
        let c = Container::new();
        c.register_service("A".to_string());
        c.register_service(123_u64);

        assert_eq!(c.get::<String>(), "A");
        assert_eq!(c.get::<u64>(), 123);
    }

    #[test]
    fn basic_constructor() {
        let c = Container::new();
        c.register_constructor(|_| "A".to_string());
        c.register_constructor(|_| 123_u64);

        assert_eq!(c.get::<String>(), "A");
        assert_eq!(c.get::<u64>(), 123);
    }

    #[test]
    fn get_caches_but_construct_does_not() {
        // Declared before the container so it outlives the borrowing constructor.
        let calls = Cell::new(0_u32);
        let c = Container::new();
        c.register_constructor(|_| {
            calls.set(calls.get() + 1);
            calls.get()
        });

        assert_eq!(c.get::<u32>(), 1);
        assert_eq!(c.get::<u32>(), 1, "get must reuse the cached value");
        assert_eq!(c.construct::<u32>(), 2, "construct must always run the constructor");
        assert_eq!(calls.get(), 2);
    }

    #[test]
    fn duplicate_registration_is_rejected() {
        let c = Container::new();
        c.register_service(1_u8);
        assert!(c.try_register_service(2_u8).is_err());
        assert_eq!(c.get::<u8>(), 1, "the first registration must be kept");

        c.register_constructor(|_| 1_u16);
        assert!(c.try_register_constructor(|_| 2_u16).is_err());
        assert_eq!(c.get::<u16>(), 1);
    }

    #[test]
    fn missing_type_is_an_error() {
        let c = Container::new();
        assert!(c.try_get::<String>().is_err());
        assert!(c.try_construct::<String>().is_err());
    }

    struct A {
        b: Arc<B>,
        d: Arc<D>,
    }

    struct B {
        c: Arc<C>,
    }

    struct C;
    #[derive(Clone)]
    struct D;

    #[test]
    fn complex() {
        let c = Container::new();
        c.register_constructor(|container| {
            Arc::new(A {
                b: container.get(),
                d: container.get(),
            })
        });
        c.register_constructor(|container| Arc::new(B { c: container.get() }));
        c.register_constructor(|_| Arc::new(C));
        c.register_constructor(|_| Arc::new(D));

        let a = c.get::<Arc<A>>();

        // Everything resolved while building `A` is cached, so the container hands out
        // the very same instances afterwards.
        assert!(Arc::ptr_eq(&a, &c.get::<Arc<A>>()));
        assert!(Arc::ptr_eq(&a.d, &c.get::<Arc<D>>()));
        assert!(Arc::ptr_eq(&a.b.c, &c.get::<Arc<C>>()));
    }

    #[test]
    fn constructor_with_lifetime() {
        let outside_string = "A".to_string();
        let outside_d = D;

        let c = Container::new();
        c.register_constructor(|_| outside_string.clone());
        c.register_constructor(|_| outside_d.clone());

        assert_eq!(c.get::<String>(), "A");
        // `D` has no `PartialEq`; resolving it at all proves the borrowing closure ran.
        let _: D = c.get();
    }
}