1use std::any::{type_name, TypeId};
2use std::collections::HashMap;
3use std::fmt::{Display, Formatter};
4use std::sync::Arc;
5
6use dashmap::DashMap;
7use log::{debug, trace, warn};
8
9use crate::core::{DynBean, Error};
10use crate::core::bean_def::BeanDef;
11use crate::core::ty::Type;
12
13type InitContextFn = Arc<dyn Fn(&Context) -> Result<(), Error> + Send + Sync>;
14
15#[derive(Clone)]
16pub struct Context {
17 inner: Arc<InnerContext>,
18}
19
20struct InnerContext {
21 name: String,
22 beans: Arc<DashMap<String, DynBean>>,
23 bean_defs: Arc<DashMap<String, Arc<BeanDef>>>,
24 contexts: Arc<DashMap<String, Arc<Context>>>,
25 init_fns: DashMap<String, InitContextFn>
26}
27
28impl Context {
29 pub fn new(name: &str) -> Context {
30 Context {
31 inner: Arc::new(InnerContext {
32 name: name.to_string(),
33 beans: Default::default(),
34 bean_defs: Default::default(),
35 contexts: Default::default(),
36 init_fns: Default::default(),
37 })
38 }
39 }
40
41 pub fn name(&self) -> &str {
42 &self.inner.name
43 }
44
45 pub fn add_context(&self, context: Context) {
46 self.inner.contexts.insert(context.name().to_string(), Arc::new(context));
48 }
49
50 pub fn add_init_fn(&self, name: &str, init_fn: Arc<dyn Fn(&Context) -> Result<(), Error> + Send + Sync>) -> Result<(), Error>{
51 self.inner.init_fns.insert(name.to_string(), init_fn);
53 Ok(())
54 }
55
56 pub fn init_contexts(&self) -> Result<(), Error> {
57 let init_fns = self.inner.get_init_context_fns();
58
59 for (init_fn_name, init_fn) in init_fns.iter() {
61 trace!("execute init fn with name {:?}", init_fn_name);
62 init_fn(self)?;
63 }
64
65 Ok(())
66 }
67
68 pub fn register(&self, bean_def: impl Into<BeanDef>) -> Result<(), Error> {
69 let bean_def = bean_def.into();
70 if let Some(_) = self.inner.get_bean_def(bean_def.name()) {
71 warn!("failed to register duplicated BeanDef(name={}, type={}) in {}", bean_def.name(), bean_def.ty().name(), self);
72 return Err(Error::from(format!("failed to register duplicated BeanDef(name={}, type={}) in {}", bean_def.name(), bean_def.ty().name(), self)));
73 };
74
75 trace!("registering {} within {}", &bean_def, self);
76 self.inner.bean_defs.insert(bean_def.name().to_string(), Arc::new(bean_def));
77 Ok(())
78 }
79
80 pub fn get_bean<T: ?Sized + 'static>(&self, name: &str) -> Result<Arc<T>, Error> {
81 if let Some(dyn_bean) = self.inner.get_bean(name) {
82 return Type::downcast::<T>(dyn_bean);
83 }
84
85 let Some(bean_def) = self.inner.get_bean_def(name) else {
86 warn!("cannot resolve Bean(name={}, type={}) in {}", name, type_name::<T>(), self);
87 return Err(Error::from(format!("cannot resolve Bean(name={}, type={}) in {}", name, type_name::<T>(), self)));
88 };
89
90 let (name, dyn_bean) = bean_def.get(self)?;
91 if let Some(_) = self.inner.beans.insert(name.clone(), dyn_bean.clone()) {
92 warn!("unexpected duplicated bean has been created Bean(name={}, type={}) in {}", &name, bean_def.ty().name(), self);
93 return Err(Error::from(format!("unexpected duplicated bean has been created Bean(name={}, type={}) in {}", name, bean_def.ty().name(), self)));
94 };
95
96 debug!("Bean(name={}, type={}) has been added to {}", &name, bean_def.ty().name(), self);
97 return Type::downcast::<T>(dyn_bean);
98 }
99
100 pub fn get_primary_bean<T: ?Sized + 'static>(&self) -> Result<Arc<T>, Error> {
101 let type_id = TypeId::of::<T>();
102 let mut candidates = self.inner.get_bean_defs_by_type(self, &type_id);
103 match candidates.len() {
104 0 => {
105 Err(Error::from(""))
106 },
107 1 => {
108 let bean_def = candidates.pop().unwrap();
109 self.get_bean::<T>(bean_def.name())
110 },
111 _ => {
112 todo!("missed feature(primary beans) - add primary to BeanDef and use it to resolve primary bean")
113 },
114 }
115 }
116
117 pub fn get_beans<T: ?Sized + 'static>(&self) -> Result<Vec<Arc<T>>, Error> {
118 let type_id = TypeId::of::<T>();
119
120 self.inner.get_bean_defs_by_type(self, &type_id)
121 .iter()
122 .map(|def| self.get_bean::<T>(def.name()))
123 .collect()
124 }
125}
126
127impl Display for Context {
128 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
129 write!(f, "Context(name={})", &self.inner.name)
130 }
131}
132
133impl InnerContext {
134 fn get_bean(&self, name: &str) -> Option<DynBean> {
135 self.beans.get(name).map(|dyn_ref| dyn_ref.value().clone())
136 }
137 fn get_bean_def(&self, name: &str) -> Option<Arc<BeanDef>> {
138 if let Some(bean_def) = self.bean_defs.get(name) {
139 trace!("found {} in Context(name={})", bean_def.value(), &self.name);
140 return Some(bean_def.value().clone());
141 }
142
143 for ctx in self.contexts.iter() {
144 if let Some(bean_def) = ctx.inner.get_bean_def(name) {
145 return Some(bean_def);
146 }
147 }
148
149 trace!("cannot find BeanDef(name={}) in Context(name={})", name, &self.name);
150 return None;
151 }
152
153 fn get_bean_defs_within_context(&self, ctx: &Context) -> Vec<Arc<BeanDef>> {
155 let mut bean_defs = Vec::new();
156 let mut ctx_defs = self.bean_defs.iter()
157 .map(|def| def.value().clone())
159 .collect();
160
161 bean_defs.append(&mut ctx_defs);
162 for child_ctx in self.contexts.iter() {
163 let mut ctx_defs = child_ctx.inner.get_bean_defs_within_context(ctx);
164 bean_defs.append(&mut ctx_defs);
165 }
166
167 return bean_defs;
168 }
169
170 fn get_bean_defs_by_type(&self, ctx: &Context, type_id: &TypeId) -> Vec<Arc<BeanDef>> {
171 self.get_bean_defs_within_context(ctx).into_iter()
172 .filter(|def| def.ty().assignable(type_id))
173 .collect()
174 }
175
176 fn get_init_context_fns(&self) -> HashMap<String, InitContextFn> {
177 let mut fns = HashMap::new();
178 for ctx in self.contexts.iter() {
179 let mut ctx_fns: Vec<_> = ctx.inner.get_init_context_fns().into_iter().collect();
180 while let Some((key, value)) = ctx_fns.pop() {
181 fns.insert(key, value);
182 }
183 }
184
185 let mut ctx_fns: Vec<_> = self.init_fns.iter()
186 .map(|item_ref| (item_ref.key().clone(), item_ref.value().clone()))
187 .collect();
188
189 while let Some((key, value)) = ctx_fns.pop() {
190 fns.insert(key, value);
191 }
192
193 return fns;
194 }
195}
196
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::context::context::Context;
    use crate::core::bean_def::BeanDef;
    use crate::core::Error;
    use crate::core::ty::Type;

    struct TestBean {
        name: &'static str,
    }

    trait TestTrait {
        fn name(&self) -> &'static str;
    }

    impl TestTrait for TestBean {
        fn name(&self) -> &'static str {
            self.name
        }
    }

    struct TestBeanWithDep {
        dyn_dep: Arc<dyn TestTrait + Sync + Send>,
    }

    /// Registers a `TestBean` named "testBean" in `ctx`, downcastable both to the concrete
    /// struct and to `dyn TestTrait + Sync + Send`.
    fn register_test_bean(ctx: &Context) -> Result<(), Error> {
        let ty = Type::of::<TestBean>();
        ty.add_downcast::<TestBean>(|b| Ok(Arc::downcast::<TestBean>(b)?));
        ty.add_downcast::<dyn TestTrait + Sync + Send>(|b| Ok(Arc::downcast::<TestBean>(b)?));

        ctx.register(
            BeanDef::builder()
                .ty(ty)
                .name("testBean")
                .get(Arc::new(|_ctx| Ok(Arc::new(TestBean { name: "instance_of_testBean" }))))
                .build(),
        )
    }

    #[test]
    fn register_types_bean_def_then_should_return_and_cast_to_struct_and_dyn_trait() -> Result<(), Error> {
        let ctx = Context::new("test-context");
        register_test_bean(&ctx)?;

        let as_struct = ctx.get_bean::<TestBean>("testBean")?;
        assert_eq!(as_struct.as_ref().name, "instance_of_testBean");

        let as_trait = ctx.get_bean::<dyn TestTrait + Sync + Send>("testBean")?;
        assert_eq!(as_trait.name(), "instance_of_testBean");

        Ok(())
    }

    #[test]
    fn should_create_and_get_bean_with_dyn_dep() -> Result<(), Error> {
        let ctx = Context::new("test-context");
        register_test_bean(&ctx)?;

        let dep_ty = Type::of::<TestBeanWithDep>();
        dep_ty.add_downcast::<TestBeanWithDep>(|b| Ok(Arc::downcast::<TestBeanWithDep>(b)?));
        ctx.register(
            BeanDef::builder()
                .ty(dep_ty)
                .name("testBeanWithDep")
                .get(Arc::new(|ctx| {
                    let with_dep = Arc::new(TestBeanWithDep {
                        dyn_dep: ctx.get_bean("testBean")?,
                    });
                    Ok(with_dep)
                }))
                .build(),
        )?;

        let dependency = ctx.get_bean::<TestBean>("testBean")?;
        assert_eq!(dependency.name, "instance_of_testBean");

        let dependent = ctx.get_bean::<TestBeanWithDep>("testBeanWithDep")?;
        assert_eq!(dependent.dyn_dep.name(), "instance_of_testBean");

        Ok(())
    }
}