typed_container/
container.rs1use std::{
2 any::{Any, TypeId},
3 collections::HashMap,
4 sync::{Arc, RwLock},
5};
6
7use crate::{Error, ErrorKind};
8
9#[derive(Clone)]
10pub struct Container<'a>(Arc<RwLock<ContainerImpl<'a>>>);
11
12impl<'a> Container<'a> {
13 pub fn new() -> Self {
14 Self(Arc::new(RwLock::new(ContainerImpl::default())))
15 }
16
17 fn register_service_internal(&self, boxed_service: BoxedService) -> Result<(), ErrorKind> {
18 {
19 let read = self.0.read().map_err(|_| ErrorKind::LockPoisoned)?;
20
21 if read.services.contains_key(&boxed_service.type_id) {
22 return Err(ErrorKind::Duplicated);
23 }
24 }
25
26 {
27 let mut write = self.0.write().map_err(|_| ErrorKind::LockPoisoned)?;
28 write
29 .services
30 .insert(boxed_service.type_id.clone(), boxed_service);
31 }
32
33 Ok(())
34 }
35
36 fn register_constructor_internal(
37 &self,
38 boxed_constructor: BoxedConstructor<'a>,
39 ) -> Result<(), ErrorKind> {
40 {
41 let read = self.0.read().map_err(|_| ErrorKind::LockPoisoned)?;
42
43 if read.constructors.contains_key(&boxed_constructor.type_id) {
44 return Err(ErrorKind::Duplicated);
45 }
46 }
47
48 {
49 let mut write = self.0.write().map_err(|_| ErrorKind::LockPoisoned)?;
50 write
51 .constructors
52 .insert(boxed_constructor.type_id, Arc::new(boxed_constructor));
53 }
54
55 Ok(())
56 }
57
58 pub fn construct<T: Clone + 'static>(&self) -> T {
59 self.try_construct().unwrap()
60 }
61
62 pub fn try_construct<T: Clone + 'static>(&self) -> Result<T, Error<T>> {
63 let type_id = TypeId::of::<T>();
64 let construct = {
65 let read = self.0.read()?;
66
67 read.constructors
68 .get(&type_id)
69 .ok_or(ErrorKind::NotFound)?
70 .clone()
71 };
72
73 match construct.construct::<T>(self.clone()) {
74 None => Err(ErrorKind::FailDowncast.into()),
75 Some(v) => Ok(v),
76 }
77 }
78
79 pub fn register_service<T: Clone + 'static>(&self, value: T) {
80 self.register_service_internal(BoxedService::from(value))
81 .unwrap()
82 }
83
84 pub fn try_register_service<T: Clone + 'static>(&self, value: T) -> Result<(), Error<T>> {
85 Ok(self.register_service_internal(BoxedService::from(value))?)
86 }
87
88 pub fn register_constructor<T: Clone + 'static>(&self, value: impl Fn(Container) -> T + 'a) {
89 self.register_constructor_internal(BoxedConstructor::from(value))
90 .unwrap()
91 }
92
93 pub fn try_register_constructor<T: Clone + 'static>(
94 &self,
95 value: impl Fn(Container) -> T + 'a,
96 ) -> Result<(), Error<T>> {
97 Ok(self.register_constructor_internal(BoxedConstructor::from(value))?)
98 }
99
100 pub fn get<T: Clone + 'static>(&self) -> T {
101 self.try_get().unwrap()
102 }
103
104 pub fn try_get<T: Clone + 'static>(&self) -> Result<T, Error<T>> {
105 let type_id = TypeId::of::<T>();
106 {
107 let read = self.0.read()?;
108 if let Some(s) = read.services.get(&type_id) {
109 return match s.get_cloned() {
110 Some(v) => Ok(v),
111 None => Err(ErrorKind::FailDowncast.into()),
112 };
113 }
114
115 if !read.constructors.contains_key(&type_id) {
116 return Err(ErrorKind::NotFound.into());
117 }
118 }
119
120 let new_value = self.try_construct::<T>()?;
121
122 self.register_service_internal(BoxedService::from(new_value.clone()))?;
123 Ok(new_value)
124 }
125}
126
127#[derive(Default)]
128struct ContainerImpl<'a> {
129 pub constructors: HashMap<TypeId, Arc<BoxedConstructor<'a>>>,
130 pub services: HashMap<TypeId, BoxedService>,
131}
132
/// A type-erased, already-constructed service value.
struct BoxedService {
    /// `TypeId` of the concrete value held in `value`.
    pub type_id: TypeId,
    /// The value itself, erased to `dyn Any` so services of any type share one map.
    pub value: Box<dyn Any>,
}

impl BoxedService {
    /// Clones the stored value out as a `T`.
    ///
    /// Returns `None` if the stored value is not a `T`.
    fn get_cloned<T: Clone + 'static>(&self) -> Option<T> {
        let typed: Option<&T> = self.value.downcast_ref::<T>();
        typed.map(T::clone)
    }
}

impl<T: Clone + 'static> From<T> for BoxedService {
    /// Erases `value`, recording `TypeId::of::<T>()` for later lookup.
    fn from(value: T) -> Self {
        let erased: Box<dyn Any> = Box::new(value);
        Self {
            value: erased,
            type_id: TypeId::of::<T>(),
        }
    }
}
152
153struct BoxedConstructor<'a> {
154 pub type_id: TypeId,
155 pub value: Box<dyn Fn(Container) -> Box<dyn Any> + 'a>,
156}
157
158impl<'a> BoxedConstructor<'a> {
159 fn construct<T: Clone + 'static>(&self, container: Container) -> Option<T> {
160 let value = (self.value)(container).downcast::<T>();
161 match value {
162 Err(_) => None,
163 Ok(v) => Some(*v),
164 }
165 }
166}
167
168impl<'a, T: Clone + 'static, F: Fn(Container) -> T + 'a> From<F> for BoxedConstructor<'a> {
169 fn from(value: F) -> Self {
170 Self {
171 type_id: TypeId::of::<T>(),
172 value: Box::new(move |c| Box::new(value(c))),
173 }
174 }
175}
176
#[cfg(test)]
mod tests {
    use std::{cell::Cell, sync::Arc};

    use crate::Container;

    #[test]
    fn basic_register() {
        let c = Container::new();
        c.register_service("A".to_string());
        c.register_service(123_u64);

        assert_eq!(c.get::<String>(), "A");
        assert_eq!(c.get::<u64>(), 123);
    }

    #[test]
    fn basic_constructor() {
        let c = Container::new();
        c.register_constructor(|_| "A".to_string());
        c.register_constructor(|_| 123_u64);

        assert_eq!(c.get::<String>(), "A");
        assert_eq!(c.get::<u64>(), 123);
    }

    #[test]
    fn duplicate_registration_is_rejected() {
        let c = Container::new();

        c.register_service(1_u8);
        assert!(c.try_register_service(2_u8).is_err());
        // The first registration is kept.
        assert_eq!(c.get::<u8>(), 1);

        c.register_constructor(|_| 1_i16);
        assert!(c.try_register_constructor(|_| 2_i16).is_err());
        assert_eq!(c.construct::<i16>(), 1);
    }

    #[test]
    fn missing_type_is_an_error() {
        let c = Container::new();
        assert!(c.try_get::<u32>().is_err());
        assert!(c.try_construct::<u32>().is_err());
    }

    #[test]
    fn get_caches_but_construct_rebuilds() {
        // Declared before the container so it outlives the borrowing constructor.
        let calls = Cell::new(0);
        let c = Container::new();
        c.register_constructor(|_| {
            calls.set(calls.get() + 1);
            7_u32
        });

        assert_eq!(c.get::<u32>(), 7);
        assert_eq!(c.get::<u32>(), 7);
        assert_eq!(calls.get(), 1);

        assert_eq!(c.construct::<u32>(), 7);
        assert_eq!(calls.get(), 2);
    }

    // The fields exist only to model a dependency graph; they are never read.
    #[allow(dead_code)]
    struct A {
        b: Arc<B>,
        d: Arc<D>,
    }

    // Same as `A`: fields model dependencies and are never read.
    #[allow(dead_code)]
    struct B {
        c: Arc<C>,
    }

    struct C;
    #[derive(Clone)]
    struct D;

    #[test]
    fn complex() {
        let c = Container::new();
        c.register_constructor(|container| {
            Arc::new(A {
                b: container.get(),
                d: container.get(),
            })
        });
        c.register_constructor(|container| Arc::new(B { c: container.get() }));
        c.register_constructor(|_| Arc::new(C));
        c.register_constructor(|_| Arc::new(D));

        _ = c.get::<Arc<A>>();
    }

    #[test]
    fn constructor_with_lifetime() {
        let outside_string = "A".to_string();
        let outside_d = D;

        let c = Container::new();
        c.register_constructor(|_| outside_string.clone());
        c.register_constructor(|_| outside_d.clone());

        assert_eq!(c.get::<String>(), "A");
        _ = c.get::<D>();
    }
}