// typed_container/container.rs
use std::{
    any::{Any, TypeId},
    collections::{HashMap, HashSet},
    sync::{Arc, PoisonError, RwLock},
};

use crate::{Error, ErrorKind};

9#[derive(Clone)]
11pub struct Container<'a>(Arc<RwLock<ContainerImpl<'a>>>);
12
13impl<'a> Container<'a> {
14 pub fn new() -> Self {
15 Self(Arc::new(RwLock::new(ContainerImpl::default())))
16 }
17
18 fn register_service_internal(&self, boxed_service: BoxedService) -> Result<(), ErrorKind> {
19 let mut impl_obj = self.0.write()?;
20 if impl_obj.services.contains_key(&boxed_service.type_id) {
21 return Err(ErrorKind::Duplicated);
22 }
23 impl_obj
24 .services
25 .insert(boxed_service.type_id.clone(), boxed_service);
26
27 Ok(())
28 }
29
30 fn register_constructor_internal(
31 &self,
32 boxed_constructor: BoxedConstructor<'a>,
33 ) -> Result<(), ErrorKind> {
34 let mut write = self.0.write()?;
35 write
36 .constructors
37 .insert(boxed_constructor.type_id, Arc::new(boxed_constructor));
38
39 Ok(())
40 }
41
42 fn construct_internal<T: Clone + 'static>(&self) -> Result<T, ErrorKind> {
43 let type_id = TypeId::of::<T>();
44
45 let constructor = {
47 let impl_ref = self.0.read()?;
48
49 if impl_ref.pending_construction.contains(&type_id) {
50 return Err(ErrorKind::CircularReference);
51 }
52
53 impl_ref
54 .constructors
55 .get(&type_id)
56 .ok_or(ErrorKind::NotFound)?
57 .clone()
58 };
59
60 self.0.write()?.pending_construction.insert(type_id.clone());
62
63 let construction = constructor.construct::<T>(self.clone());
65
66 self.0.write()?.pending_construction.remove(&type_id);
68
69 match construction {
70 None => Err(ErrorKind::FailDowncast),
71 Some(v) => Ok(v),
72 }
73 }
74
75 pub fn register_service<T: Clone + 'static>(&self, value: T) {
79 self.register_service_internal(BoxedService::from(value))
80 .unwrap()
81 }
82
83 pub fn try_register_service<T: Clone + 'static>(&self, value: T) -> Result<(), Error<T>> {
85 Ok(self.register_service_internal(BoxedService::from(value))?)
86 }
87
88 pub fn register_constructor<T: Clone + 'static>(&self, value: impl Fn(Container) -> T + 'a) {
92 self.register_constructor_internal(BoxedConstructor::from(value))
93 .unwrap()
94 }
95
96 pub fn try_register_constructor<T: Clone + 'static>(
98 &self,
99 value: impl Fn(Container) -> T + 'a,
100 ) -> Result<(), Error<T>> {
101 Ok(self.register_constructor_internal(BoxedConstructor::from(value))?)
102 }
103
104 pub fn get<T: Clone + 'static>(&self) -> T {
108 self.try_get().unwrap()
109 }
110
111 pub fn try_get<T: Clone + 'static>(&self) -> Result<T, Error<T>> {
113 let type_id = TypeId::of::<T>();
114 {
115 let impl_obj = self.0.read()?;
116
117 if let Some(s) = impl_obj.services.get(&type_id) {
118 return match s.get_cloned() {
119 Some(v) => Ok(v),
120 None => Err(ErrorKind::FailDowncast.into()),
121 };
122 }
123
124 if !impl_obj.constructors.contains_key(&type_id) {
125 return Err(ErrorKind::NotFound.into());
126 }
127 }
128
129 let new_value = self.try_construct::<T>()?;
130
131 self.register_service_internal(BoxedService::from(new_value.clone()))?;
132 Ok(new_value)
133 }
134
135 pub fn construct<T: Clone + 'static>(&self) -> T {
139 self.construct_internal().unwrap()
140 }
141
142 pub fn try_construct<T: Clone + 'static>(&self) -> Result<T, Error<T>> {
144 Ok(self.construct_internal()?)
145 }
146
147 pub fn remove_service<T: Clone + 'static>(&self) -> Result<T, Error<T>> {
149 Ok(self
150 .0
151 .write()?
152 .services
153 .remove(&TypeId::of::<T>())
154 .ok_or(ErrorKind::NotFound)?
155 .get_cloned::<T>()
156 .ok_or(ErrorKind::FailDowncast)?)
157 }
158
159 pub fn remove_constructor<T: Clone + 'static>(&self) -> Result<(), Error<T>> {
161 self.0.write()?.constructors.remove(&TypeId::of::<T>());
162 Ok(())
163 }
164
165 pub fn into_static(self) -> Container<'static> {
167 let mut services = HashMap::new();
168 std::mem::swap(&mut self.0.write().unwrap().services, &mut services);
169
170 Container::<'static>(Arc::new(RwLock::new(ContainerImpl {
171 services,
172 ..Default::default()
173 })))
174 }
175}
176
177#[derive(Default)]
178struct ContainerImpl<'a> {
179 pub constructors: HashMap<TypeId, Arc<BoxedConstructor<'a>>>,
180 pub services: HashMap<TypeId, BoxedService>,
181 pub pending_construction: HashSet<TypeId>,
182}
183
/// A type-erased service instance.
struct BoxedService {
    /// Type id of the concrete value in `value`; the `From` conversion fills it in.
    pub type_id: TypeId,
    /// The instance itself, recovered later with a downcast.
    pub value: Box<dyn Any>,
}

189impl BoxedService {
190 fn get_cloned<T: Clone + 'static>(&self) -> Option<T> {
191 self.value.downcast_ref::<T>().cloned()
192 }
193}
194
195impl<T: Clone + 'static> From<T> for BoxedService {
196 fn from(value: T) -> Self {
197 Self {
198 type_id: TypeId::of::<T>(),
199 value: Box::new(value),
200 }
201 }
202}
203
204struct BoxedConstructor<'a> {
205 pub type_id: TypeId,
206 pub value: Box<dyn Fn(Container) -> Box<dyn Any> + 'a>,
207}
208
209impl<'a> BoxedConstructor<'a> {
210 fn construct<T: Clone + 'static>(&self, container: Container) -> Option<T> {
211 let value = (self.value)(container).downcast::<T>();
212 match value {
213 Err(_) => None,
214 Ok(v) => Some(*v),
215 }
216 }
217}
218
219impl<'a, T: Clone + 'static, F: Fn(Container) -> T + 'a> From<F> for BoxedConstructor<'a> {
220 fn from(value: F) -> Self {
221 Self {
222 type_id: TypeId::of::<T>(),
223 value: Box::new(move |c| Box::new(value(c))),
224 }
225 }
226}
227
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::Container;

    #[test]
    fn basic_register() {
        let c = Container::new();
        c.register_service("A".to_string());
        c.register_service(123_u64);

        assert_eq!(c.get::<String>(), "A");
        assert_eq!(c.get::<u64>(), 123);

        // A second registration for the same type is rejected and the first value is kept.
        assert!(c.try_register_service("B".to_string()).is_err());
        assert_eq!(c.get::<String>(), "A");
    }

    #[test]
    fn basic_constructor() {
        let c = Container::new();
        c.register_constructor(|_| "A".to_string());
        c.register_constructor(|_| 123_u64);

        assert_eq!(c.get::<String>(), "A");
        assert_eq!(c.get::<u64>(), 123);
    }

    #[test]
    fn missing_type_is_an_error() {
        let c = Container::new();
        assert!(c.try_get::<String>().is_err());
        assert!(c.try_construct::<String>().is_err());
    }

    #[test]
    fn remove_service_returns_value() {
        let c = Container::new();
        c.register_service(7_u32);

        assert_eq!(c.remove_service::<u32>().unwrap(), 7);
        assert!(c.try_get::<u32>().is_err());
        assert!(c.remove_service::<u32>().is_err());
    }

    struct A {
        b: Arc<B>,
        d: Arc<D>,
    }

    struct B {
        c: Arc<C>,
    }

    struct C;
    #[derive(Clone)]
    struct D;

    #[test]
    fn complex() {
        let c = Container::new();
        c.register_constructor(|container| {
            Arc::new(A {
                b: container.get(),
                d: container.get(),
            })
        });
        c.register_constructor(|container| Arc::new(B { c: container.get() }));
        c.register_constructor(|_| Arc::new(C));
        c.register_constructor(|_| Arc::new(D));

        let a = c.get::<Arc<A>>();

        // `get` caches everything it builds, so the whole graph shares single instances.
        assert!(Arc::ptr_eq(&a, &c.get::<Arc<A>>()));
        assert!(Arc::ptr_eq(&a.b, &c.get::<Arc<B>>()));
        assert!(Arc::ptr_eq(&a.b.c, &c.get::<Arc<C>>()));
        assert!(Arc::ptr_eq(&a.d, &c.get::<Arc<D>>()));
    }

    #[test]
    fn construct_does_not_cache() {
        let c = Container::new();
        c.register_constructor(|_| Arc::new(D));

        let first = c.construct::<Arc<D>>();
        let second = c.construct::<Arc<D>>();
        assert!(!Arc::ptr_eq(&first, &second));
        assert!(c.remove_service::<Arc<D>>().is_err());
    }

    #[test]
    fn constructor_with_lifetime() {
        let outside_string = "A".to_string();
        let outside_d = D;

        let c = Container::new();
        c.register_constructor(|_| outside_string.clone());
        c.register_constructor(|_| outside_d.clone());

        assert_eq!(c.get::<String>(), "A");
        let _: D = c.get();
    }

    #[derive(Debug)]
    struct RefA {
        pub _b: Arc<RefB>,
    }

    #[derive(Debug)]
    struct RefB {
        pub _a: Arc<RefA>,
    }

    #[test]
    #[should_panic]
    fn circular_reference() {
        let c = Container::new();
        c.register_constructor(|c| Arc::new(RefA { _b: c.get() }));
        c.register_constructor(|c| Arc::new(RefB { _a: c.get() }));

        _ = c.get::<Arc<RefA>>();
    }
}