// vine_core/context/context.rs

use std::any::{type_name, TypeId};
use std::collections::HashMap;
use std::fmt::{Display, Formatter};
use std::sync::Arc;

use dashmap::mapref::entry::Entry;
use dashmap::DashMap;
use log::{debug, trace, warn};

use crate::core::{DynBean, Error};
use crate::core::bean_def::BeanDef;
use crate::core::ty::Type;

13type InitContextFn = Arc<dyn Fn(&Context) -> Result<(), Error> + Send + Sync>;
14
15#[derive(Clone)]
16pub struct Context {
17    inner: Arc<InnerContext>,
18}
19
20struct InnerContext {
21    name: String,
22    beans: Arc<DashMap<String, DynBean>>,
23    bean_defs: Arc<DashMap<String, Arc<BeanDef>>>,
24    contexts: Arc<DashMap<String, Arc<Context>>>,
25    init_fns: DashMap<String, InitContextFn>
26}
27
28impl Context {
29    pub fn new(name: &str) -> Context {
30        Context {
31            inner: Arc::new(InnerContext {
32                name: name.to_string(),
33                beans: Default::default(),
34                bean_defs: Default::default(),
35                contexts: Default::default(),
36                init_fns: Default::default(),
37            })
38        }
39    }
40
41    pub fn name(&self) -> &str {
42        &self.inner.name
43    }
44
45    pub fn add_context(&self, context: Context) {
46        // TODO: missed feature (context allow overrides) - if override use warn log. Also think about context property to allow overrides
47        self.inner.contexts.insert(context.name().to_string(), Arc::new(context));
48    }
49
50    pub fn add_init_fn(&self, name: &str, init_fn: Arc<dyn Fn(&Context) -> Result<(), Error> + Send + Sync>) -> Result<(), Error>{
51        // TODO: missed feature (context allow init overrides) - if override use warn log. Also think about context property to allow overrides
52        self.inner.init_fns.insert(name.to_string(), init_fn);
53        Ok(())
54    }
55
56    pub fn init_contexts(&self) -> Result<(), Error> {
57        let init_fns = self.inner.get_init_context_fns();
58
59        // TODO: missed feature (disable init function by name)
60        for (init_fn_name, init_fn) in init_fns.iter() {
61            trace!("execute init fn with name {:?}", init_fn_name);
62            init_fn(self)?;
63        }
64
65        Ok(())
66    }
67
68    pub fn register(&self, bean_def: impl Into<BeanDef>) -> Result<(), Error> {
69        let bean_def = bean_def.into();
70        if let Some(_) = self.inner.get_bean_def(bean_def.name()) {
71            warn!("failed to register duplicated BeanDef(name={}, type={}) in {}", bean_def.name(), bean_def.ty().name(), self);
72            return Err(Error::from(format!("failed to register duplicated BeanDef(name={}, type={}) in {}", bean_def.name(), bean_def.ty().name(), self)));
73        };
74
75        trace!("registering {} within {}", &bean_def, self);
76        self.inner.bean_defs.insert(bean_def.name().to_string(), Arc::new(bean_def));
77        Ok(())
78    }
79
80    pub fn get_bean<T: ?Sized + 'static>(&self, name: &str) -> Result<Arc<T>, Error> {
81        if let Some(dyn_bean) = self.inner.get_bean(name) {
82            return Type::downcast::<T>(dyn_bean);
83        }
84
85        let Some(bean_def) = self.inner.get_bean_def(name) else {
86            warn!("cannot resolve Bean(name={}, type={}) in {}", name, type_name::<T>(), self);
87            return Err(Error::from(format!("cannot resolve Bean(name={}, type={}) in {}", name, type_name::<T>(), self)));
88        };
89
90        let (name, dyn_bean) = bean_def.get(self)?;
91        if let Some(_) = self.inner.beans.insert(name.clone(), dyn_bean.clone()) {
92            warn!("unexpected duplicated bean has been created Bean(name={}, type={}) in {}", &name, bean_def.ty().name(), self);
93            return Err(Error::from(format!("unexpected duplicated bean has been created Bean(name={}, type={}) in {}", name, bean_def.ty().name(), self)));
94        };
95
96        debug!("Bean(name={}, type={}) has been added to {}", &name, bean_def.ty().name(), self);
97        return Type::downcast::<T>(dyn_bean);
98    }
99
100    pub fn get_primary_bean<T: ?Sized + 'static>(&self) -> Result<Arc<T>, Error> {
101        let type_id = TypeId::of::<T>();
102        let mut candidates = self.inner.get_bean_defs_by_type(self, &type_id);
103        match candidates.len() {
104            0 => {
105                Err(Error::from(""))
106            },
107            1 => {
108                let bean_def = candidates.pop().unwrap();
109                self.get_bean::<T>(bean_def.name())
110            },
111            _ => {
112                todo!("missed feature(primary beans) - add primary to BeanDef and use it to resolve primary bean")
113            },
114        }
115    }
116
117    pub fn get_beans<T: ?Sized + 'static>(&self) -> Result<Vec<Arc<T>>, Error> {
118        let type_id = TypeId::of::<T>();
119
120        self.inner.get_bean_defs_by_type(self, &type_id)
121            .iter()
122            .map(|def| self.get_bean::<T>(def.name()))
123            .collect()
124    }
125}
126
127impl Display for Context {
128    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
129        write!(f, "Context(name={})", &self.inner.name)
130    }
131}
132
133impl InnerContext {
134    fn get_bean(&self, name: &str) -> Option<DynBean> {
135        self.beans.get(name).map(|dyn_ref| dyn_ref.value().clone())
136    }
137    fn get_bean_def(&self, name: &str) -> Option<Arc<BeanDef>> {
138        if let Some(bean_def) = self.bean_defs.get(name) {
139            trace!("found {} in Context(name={})", bean_def.value(), &self.name);
140            return Some(bean_def.value().clone());
141        }
142
143        for ctx in self.contexts.iter() {
144            if let Some(bean_def) = ctx.inner.get_bean_def(name) {
145                return Some(bean_def);
146            }
147        }
148
149        trace!("cannot find BeanDef(name={}) in Context(name={})", name, &self.name);
150        return None;
151    }
152
153    // TODO: missed feature (conditional beans) - use context to check conditional BeanDefs
154    fn get_bean_defs_within_context(&self, ctx: &Context) -> Vec<Arc<BeanDef>> {
155        let mut bean_defs = Vec::new();
156        let mut ctx_defs = self.bean_defs.iter()
157            // filter conditional beans here
158            .map(|def| def.value().clone())
159            .collect();
160
161        bean_defs.append(&mut ctx_defs);
162        for child_ctx in self.contexts.iter() {
163            let mut ctx_defs = child_ctx.inner.get_bean_defs_within_context(ctx);
164            bean_defs.append(&mut ctx_defs);
165        }
166
167        return bean_defs;
168    }
169
170    fn get_bean_defs_by_type(&self, ctx: &Context, type_id: &TypeId) -> Vec<Arc<BeanDef>> {
171        self.get_bean_defs_within_context(ctx).into_iter()
172            .filter(|def| def.ty().assignable(type_id))
173            .collect()
174    }
175
176    fn get_init_context_fns(&self) -> HashMap<String, InitContextFn> {
177        let mut fns = HashMap::new();
178        for ctx in self.contexts.iter() {
179            let mut ctx_fns: Vec<_> = ctx.inner.get_init_context_fns().into_iter().collect();
180            while let Some((key, value)) = ctx_fns.pop() {
181                fns.insert(key, value);
182            }
183        }
184
185        let mut ctx_fns: Vec<_> = self.init_fns.iter()
186            .map(|item_ref| (item_ref.key().clone(), item_ref.value().clone()))
187            .collect();
188
189        while let Some((key, value)) = ctx_fns.pop() {
190            fns.insert(key, value);
191        }
192
193        return fns;
194    }
195}
196
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::context::context::Context;
    use crate::core::bean_def::BeanDef;
    use crate::core::Error;
    use crate::core::ty::Type;

    struct TestBean {
        name: &'static str,
    }

    trait TestTrait {
        fn name(&self) -> &'static str;
    }

    impl TestTrait for TestBean {
        fn name(&self) -> &'static str {
            self.name
        }
    }

    /// Bean whose only field is injected from the context as a trait object.
    struct TestBeanWithDep {
        dyn_dep: Arc<dyn TestTrait + Sync + Send>,
    }

    /// Registers `testBean` in `ctx`: a `TestBean` named `instance_of_testBean` that can be
    /// downcast both to `TestBean` and to `dyn TestTrait + Sync + Send`.
    fn register_test_bean(ctx: &Context) -> Result<(), Error> {
        let ty = Type::of::<TestBean>();
        ty.add_downcast::<TestBean>(|b| Ok(Arc::downcast::<TestBean>(b)?));
        ty.add_downcast::<dyn TestTrait + Sync + Send>(|b| Ok(Arc::downcast::<TestBean>(b)?));

        let bean_def = BeanDef::builder()
            .ty(ty)
            .name("testBean")
            .get(Arc::new(|_ctx| Ok(Arc::new(TestBean { name: "instance_of_testBean" }))))
            .build();
        ctx.register(bean_def)
    }

    #[test]
    fn register_types_bean_def_then_should_return_and_cast_to_struct_and_dyn_trait() -> Result<(), Error> {
        let ctx = Context::new("test-context");
        register_test_bean(&ctx)?;

        let as_struct = ctx.get_bean::<TestBean>("testBean")?;
        assert_eq!(as_struct.as_ref().name, "instance_of_testBean");

        let as_trait = ctx.get_bean::<dyn TestTrait + Sync + Send>("testBean")?;
        assert_eq!(as_trait.name(), "instance_of_testBean");

        Ok(())
    }

    #[test]
    fn should_create_and_get_bean_with_dyn_dep() -> Result<(), Error> {
        let ctx = Context::new("test-context");
        register_test_bean(&ctx)?;

        let dep_ty = Type::of::<TestBeanWithDep>();
        dep_ty.add_downcast::<TestBeanWithDep>(|b| Ok(Arc::downcast::<TestBeanWithDep>(b)?));

        let dep_def = BeanDef::builder()
            .ty(dep_ty)
            .name("testBeanWithDep")
            // The factory resolves its dependency through the context it is handed.
            .get(Arc::new(|ctx| {
                let bean = Arc::new(TestBeanWithDep {
                    dyn_dep: ctx.get_bean("testBean")?,
                });
                Ok(bean)
            }))
            .build();
        ctx.register(dep_def)?;

        let dep = ctx.get_bean::<TestBean>("testBean")?;
        assert_eq!(dep.name, "instance_of_testBean");

        let bean = ctx.get_bean::<TestBeanWithDep>("testBeanWithDep")?;
        assert_eq!(bean.dyn_dep.name(), "instance_of_testBean");
        Ok(())
    }
}