runtime_context/
context.rs1use super::{Data, ShareableTid, TypeMap};
2use std::any::TypeId;
3
4pub struct Context<'ty, 'r> {
10 data: TypeMap<Data<'ty, 'r>>,
11}
12
13impl Default for Context<'_, '_> {
14 #[inline]
15 fn default() -> Self {
16 Self::new()
17 }
18}
19
20impl<'ty, 'r> Context<'ty, 'r> {
21 #[inline]
23 pub fn new() -> Self {
24 Self {
25 data: TypeMap::default(),
26 }
27 }
28
29 #[inline]
33 pub fn insert_unchecked(&mut self, key: TypeId, data: Data<'ty, 'r>) {
34 self.data.insert(key, data);
35 }
36
37 #[inline]
39 pub fn insert_ref<T: ShareableTid<'ty>>(&mut self, value: &'r T) {
40 self.data.insert(T::id(), Data::Borrowed(value));
41 }
42
43 #[inline]
45 pub fn insert_mut<T: ShareableTid<'ty>>(&mut self, value: &'r mut T) {
46 self.data.insert(T::id(), Data::Mut(value));
47 }
48
49 #[inline]
51 pub fn insert<T: ShareableTid<'ty>>(&mut self, value: T) {
52 self.data.insert(T::id(), Data::Owned(Box::new(value)));
53 }
54
55 #[inline]
57 pub fn get<'b, T: ShareableTid<'ty>>(&'b self) -> Option<&'b T> {
58 self.data.get(&T::id()).and_then(|v| v.downcast_ref())
59 }
60
61 #[inline]
63 pub fn get_mut<'b, T: ShareableTid<'ty>>(&'b mut self) -> Option<&'b mut T> {
64 self.data
65 .get_mut(&T::id())
66 .and_then(|v| v.downcast_mut())
67 }
68
69 #[inline]
71 pub fn get_data<'b>(&'b self, id: &TypeId) -> Option<&'b Data<'ty, 'r>> {
72 self.data.get(id)
73 }
74
75 #[inline]
77 pub fn get_data_mut<'b>(&'b mut self, id: &TypeId) -> Option<&'b mut Data<'ty, 'r>> {
78 self.data.get_mut(id)
79 }
80
81 #[inline]
83 pub fn get_disjoint_mut<'b, const N: usize>(
84 &'b mut self,
85 keys: [&TypeId; N],
86 ) -> [Option<&'b mut Data<'ty, 'r>>; N] {
87 self.data.get_disjoint_mut(keys)
88 }
89
90 #[inline]
92 pub fn take<T: ShareableTid<'ty>>(&mut self) -> Option<T> {
93 let id = T::id();
94 match self.data.remove(&id) {
95 Some(data) => data.try_take_owned::<T>().ok(),
96 None => None,
97 }
98 }
99
100 #[inline]
102 pub fn remove<T: ShareableTid<'ty>>(&mut self) -> Option<Data<'ty, 'r>> {
103 self.data.remove(&T::id())
104 }
105
106 #[inline]
108 pub fn contains<T: ShareableTid<'ty>>(&self) -> bool {
109 self.data.contains_key(&T::id())
110 }
111
112 #[inline]
114 pub fn clear(&mut self) {
115 self.data.clear();
116 }
117}
118
#[cfg(test)]
mod tests {
    use better_any::{Tid, tid};

    use super::*;

    /// Small value type holding a borrowed string; used as the stored value in
    /// most tests below.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct Dummy<'a>(&'a str);
    tid!(Dummy<'_>);

    /// A value added with `insert` is reachable through `get`, `get_mut`, and
    /// `contains`.
    #[test]
    fn test_context_owned() {
        let dummy = Dummy("Hello, World!");
        let mut context = Context::new();

        context.insert(dummy);
        assert!(context.get::<Dummy>().is_some());
        assert!(context.get_mut::<Dummy>().is_some());
        assert!(context.contains::<Dummy>());
    }

    /// A value added with `insert_ref` is readable but not mutably accessible.
    #[test]
    fn test_context_ref() {
        let dummy = Dummy("Hello, World!");
        let mut context = Context::new();

        context.insert_ref(&dummy);
        assert_eq!(context.get::<Dummy>(), Some(&dummy));
        assert_eq!(context.get_mut::<Dummy>(), None);
        assert!(context.contains::<Dummy>());
    }

    /// A value added with `insert_mut` is reachable through `get`, `get_mut`, and
    /// `contains`.
    #[test]
    fn test_context_mut() {
        let mut dummy = Dummy("Hello, World!");
        let mut context = Context::new();

        context.insert_mut(&mut dummy);
        assert!(context.get::<Dummy>().is_some());
        assert!(context.get_mut::<Dummy>().is_some());
        assert!(context.contains::<Dummy>());
    }

    /// The mutable borrow taken by `insert_mut` ends with the context, so the
    /// original value is usable again afterwards.
    #[test]
    fn test_context_no_immutable_err() {
        let mut dummy = Dummy("Hello, World!");
        {
            let mut context = Context::new();
            context.insert_mut(&mut dummy);
        }

        assert_eq!(dummy.0, "Hello, World!");
    }

    /// A generic wrapper can be looked up by its type id and downcast inside
    /// functions that only know the wrapped type through a trait bound.
    #[test]
    fn test_downcast_to_trait() {
        trait Foo {
            fn foo(&self) -> &str;
        }

        impl Foo for Dummy<'_> {
            fn foo(&self) -> &str {
                self.0
            }
        }

        struct FooWrapper<'a, T: Foo + 'static>(&'a mut T);
        tid! { impl<'a, T: 'static> TidAble<'a> for FooWrapper<'a, T> where T: Foo }

        let mut dummy = Dummy("Hello, World!");
        let mut context = Context::new();
        context.insert(FooWrapper(&mut dummy));

        fn inner_ref_fn<T: Foo + 'static>(context: &Context) {
            let data = context
                .get_data(&FooWrapper::<T>::id())
                .expect("Data not found");
            let wrapper = data
                .downcast_ref::<FooWrapper<T>>()
                .expect("Downcast failed");
            // Check the value, not just that the call chain does not panic.
            assert_eq!(wrapper.0.foo(), "Hello, World!");
        }

        inner_ref_fn::<Dummy>(&context);

        fn inner_mut_fn<T: Foo + 'static>(context: &mut Context) {
            let data = context
                .get_data_mut(&FooWrapper::<T>::id())
                .expect("Data not found");
            let wrapper = data
                .downcast_mut::<FooWrapper<T>>()
                .expect("Downcast failed");
            // Check the value, not just that the call chain does not panic.
            assert_eq!(wrapper.0.foo(), "Hello, World!");
        }

        inner_mut_fn::<Dummy>(&mut context);
    }

    /// `take` yields the owned value and removes the entry; `remove` yields the
    /// raw `Data`, which can then be converted to the owned value.
    #[test]
    fn test_take_and_remove() {
        #[derive(Debug, Clone, PartialEq, Eq)]
        struct TakeMe(u64);
        tid!(TakeMe);

        let mut context = Context::new();
        context.insert(TakeMe(7));

        let owned = context.take::<TakeMe>().unwrap();
        assert_eq!(owned, TakeMe(7));
        assert!(!context.contains::<TakeMe>());

        context.insert(TakeMe(9));
        let data = context.remove::<TakeMe>().unwrap();
        assert!(matches!(data.try_take_owned::<TakeMe>(), Ok(TakeMe(9))));
    }

    /// Two distinct entries can be borrowed mutably at the same time, and writes
    /// through those borrows are visible on later reads.
    #[test]
    fn test_get_disjoint_mut() {
        #[derive(Debug, Clone, PartialEq, Eq)]
        struct A(u8);
        #[derive(Debug, Clone, PartialEq, Eq)]
        struct B(u8);
        tid!(A);
        tid!(B);

        let mut context = Context::new();
        context.insert(A(1));
        context.insert(B(2));

        let [a, b] = context.get_disjoint_mut([&A::id(), &B::id()]);
        let a = a.unwrap().downcast_mut::<A>().unwrap();
        let b = b.unwrap().downcast_mut::<B>().unwrap();

        a.0 += 1;
        b.0 += 2;

        assert_eq!(context.get::<A>().unwrap().0, 2);
        assert_eq!(context.get::<B>().unwrap().0, 4);
    }

    /// `get_data` exposes the stored entry, and `clear` removes it.
    #[test]
    fn test_clear_and_get_data() {
        #[derive(Debug, Clone, PartialEq, Eq)]
        struct C(i32);
        tid!(C);

        let mut context = Context::new();
        context.insert(C(10));

        let data = context.get_data(&C::id()).unwrap();
        assert_eq!(data.downcast_ref::<C>().unwrap().0, 10);

        context.clear();
        assert_eq!(context.get::<C>(), None);
    }
}