1use std::{
2 any::{Any, TypeId},
3 collections::HashMap,
4};
5
/// Registers a boxed value by calling `put` on `$jab_state` (normally a
/// [`JabDI`]).
///
/// Three forms are accepted:
///
/// * `provide!(di, dyn Trait, value)` stores `value` as
///   `Box<dyn Trait + Send + Sync>`. Retrieve it with `fetch!(di, dyn Trait)`.
///   `Trait` must be a single token (a bare trait name). A multi-token path
///   such as `a::Trait` does not match this form. It falls through to the
///   `Type` form below, where no `Send + Sync` bounds are added.
/// * `provide!(di, Type, value)` stores `value` as `Box<Type>`. Retrieve it
///   with `fetch!(di, Type)`.
/// * `provide!(di, value)` stores `value` as `Box<_>`, with the box contents
///   inferred from `value`.
///
/// Every form expands to a single `()`-valued expression, so the macro works in
/// both statement and expression position. Paths are fully qualified so that
/// names in scope at the call site (for example a user-defined `Box`) cannot
/// change the expansion.
#[macro_export]
macro_rules! provide {
    ($jab_state:expr, dyn $trait:tt, $value:expr) => {{
        // The annotation forces the unsizing coercion to the trait object, so
        // the value is registered under `Box<dyn Trait + Send + Sync>`.
        let _temp: ::std::boxed::Box<dyn $trait + ::std::marker::Send + ::std::marker::Sync> =
            ::std::boxed::Box::new($value);
        $jab_state.put(_temp);
    }};
    ($jab_state:expr, $trait:ty, $value:expr) => {{
        let _temp: ::std::boxed::Box<$trait> = ::std::boxed::Box::new($value);
        $jab_state.put(_temp);
    }};
    ($jab_state:expr, $value:expr) => {
        $jab_state.put(::std::boxed::Box::new($value))
    };
}
22
/// Looks up a value registered with [`provide!`] by calling `get` on
/// `$jab_state` (normally a [`JabDI`]). It returns whatever that `get` returns;
/// for [`JabDI::get`] that is a reference to the stored box.
///
/// * `fetch!(di, dyn Trait)` calls `get::<Box<dyn Trait + Send + Sync>>()`.
///   `Trait` must be a single token (a bare trait name).
/// * `fetch!(di, Type)` calls `get::<Box<Type>>()`.
///
/// With a [`JabDI`], this panics if nothing is registered under that type.
/// Paths are fully qualified so that names in scope at the call site (for
/// example a user-defined `Box`) cannot change which type is looked up.
#[macro_export]
macro_rules! fetch {
    ($jab_state:expr, dyn $trait:tt) => {
        $jab_state.get::<::std::boxed::Box<dyn $trait + ::std::marker::Send + ::std::marker::Sync>>()
    };
    ($jab_state:expr, $trait:ty) => {
        $jab_state.get::<::std::boxed::Box<$trait>>()
    };
}
32
33trait JabStateWithDI {
34 fn get_mut<'a>(&'a mut self) -> &'a mut JabDI;
35 fn get<'a>(&'a self) -> &'a JabDI;
36}
37
/// Storage behind [`JabDI`]: each registered type's `TypeId` maps to its
/// type-erased value.
type DepMap = HashMap<TypeId, Box<dyn Any + Send + Sync>>;

/// A minimal type-keyed dependency container.
///
/// Entries are keyed by `TypeId`, so the map holds at most one value per key
/// type. Stored values are `dyn Any + Send + Sync`, which keeps `JabDI` itself
/// `Send + Sync`.
#[derive(Default, Debug)]
pub struct JabDI {
    dep_map: DepMap,
}
42
43impl JabDI {
44 pub fn put<T: 'static + Send + Sync>(&mut self, val: T) {
45 self.dep_map
46 .insert(TypeId::of::<T>(), Box::new(Box::new(val)));
47 }
48
49 pub fn get<T: 'static + ?Sized>(&self) -> &T {
50 if let Some(v) = self.try_get() {
51 v
52 } else {
53 panic!("Could not find requested type");
54 }
55 }
56
57 pub fn try_get<T: 'static + ?Sized>(&self) -> Option<&T> {
58 if let Some(dep) = self.dep_map.get(&TypeId::of::<T>()) {
59 if let Some(val) = dep.downcast_ref::<Box<T>>() {
60 return Some(val);
61 }
62 }
63
64 None
65 }
66
67 pub fn get_mut<T: 'static>(&mut self) -> &mut T {
68 if let Some(v) = self.try_get_mut() {
69 v
70 } else {
71 panic!("Could not find requested type");
72 }
73 }
74
75 pub fn try_get_mut<T: 'static>(&mut self) -> Option<&mut T> {
76 if let Some(dep) = self.dep_map.get_mut(&TypeId::of::<T>()) {
77 if let Some(val) = Box::new(dep).downcast_mut::<T>() {
78 return Some(val);
79 }
80 }
81
82 None
83 }
84}
85
#[cfg(test)]
mod tests {
    use crate::JabDI;

    // Two distinct concrete types, so lookups keyed by type can be told apart.
    #[derive(Debug, PartialEq)]
    struct A(i32);

    #[derive(Debug, PartialEq)]
    struct B(i32);

    // One trait per concrete type, used to exercise the `dyn Trait` forms of
    // `provide!` / `fetch!`.
    trait C {
        fn valc(&self) -> i32;
    }

    trait D {
        fn vald(&self) -> i32;
    }

    impl C for A {
        fn valc(&self) -> i32 {
            let A(inner) = self;
            *inner
        }
    }

    impl D for B {
        fn vald(&self) -> i32 {
            let B(inner) = self;
            *inner
        }
    }

    #[test]
    fn test_get_struct() {
        let mut container = JabDI::default();
        let (first, second) = (A(0), B(1));

        provide!(container, first);
        provide!(container, second);

        let found_a = fetch!(container, A);
        assert_eq!(0, found_a.0, "it should correctly find struct A for struct A");

        let found_b = fetch!(container, B);
        assert_eq!(1, found_b.0, "it should correctly find struct B for struct B");
    }

    #[test]
    fn test_get_trait() {
        let mut container = JabDI::default();
        let (first, second) = (A(0), B(1));

        provide!(container, dyn C, first);
        provide!(container, dyn D, second);

        let as_c = fetch!(container, dyn C);
        assert_eq!(0, as_c.valc(), "it should correctly find struct A for trait C");

        let as_d = fetch!(container, dyn D);
        assert_eq!(1, as_d.vald(), "it should correctly find struct B for trait D");
    }
}