1use std::any::{Any, TypeId};
8use std::collections::HashMap;
9use std::fmt;
10
/// A heterogeneous container that holds at most one value per concrete type.
///
/// Values are boxed as `dyn Any + Send + Sync` and indexed by the [`TypeId`]
/// of the type they were inserted as, so lookups are made by type rather than
/// by name.
pub struct Context {
    /// One boxed value per type, keyed by that type's `TypeId`.
    map: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
}

impl Context {
    /// Creates an empty `Context`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            map: HashMap::new(),
        }
    }

    /// Stores `value` under its type `T`, replacing any existing `T`.
    ///
    /// Returns the previously stored `T`, or `None` if there was none.
    pub fn insert<T: Send + Sync + 'static>(&mut self, value: T) -> Option<T> {
        self.map
            .insert(TypeId::of::<T>(), Box::new(value))
            // Entries are only ever stored under `TypeId::of::<T>()` for a boxed
            // `T`, so the downcast is expected to succeed; were it to fail, the
            // old value would be dropped and `None` returned.
            .and_then(|prev| prev.downcast::<T>().ok().map(|b| *b))
    }

    /// Returns a shared reference to the stored `T`, or `None` if absent.
    #[must_use]
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        self.map
            .get(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast_ref::<T>())
    }

    /// Returns a mutable reference to the stored `T`, or `None` if absent.
    pub fn get_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
        self.map
            .get_mut(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast_mut::<T>())
    }

    /// Removes the stored `T` and returns it by value, or `None` if absent.
    pub fn remove<T: Send + Sync + 'static>(&mut self) -> Option<T> {
        self.map
            .remove(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast::<T>().ok().map(|b| *b))
    }

    /// Reports whether a value of type `T` is currently stored.
    ///
    /// Only the `TypeId` is consulted, so `T` needs nothing beyond `'static`.
    #[must_use]
    pub fn contains<T: 'static>(&self) -> bool {
        self.map.contains_key(&TypeId::of::<T>())
    }

    /// Returns the number of stored values (one per distinct type).
    #[must_use]
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Returns `true` when no values are stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }
}
58
59impl Default for Context {
60 fn default() -> Self {
61 Self::new()
62 }
63}
64
65impl fmt::Debug for Context {
66 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67 f.debug_struct("Context")
68 .field("entries", &self.map.len())
69 .finish()
70 }
71}
72
#[cfg(test)]
mod tests {
    use super::*;

    /// Small comparable payload used to check which value is stored.
    #[derive(Debug, PartialEq)]
    struct Tag(u32);

    /// Payload that deliberately does not implement `Clone`.
    struct NotClone(String);

    #[test]
    fn insert_and_get_by_type() {
        let mut ctx = Context::new();
        assert!(ctx.insert(Tag(42)).is_none());
        assert_eq!(ctx.get::<Tag>(), Some(&Tag(42)));
    }

    #[test]
    fn different_types_coexist() {
        let mut ctx = Context::new();
        ctx.insert(Tag(1));
        ctx.insert(String::from("hello"));
        assert_eq!(ctx.get::<Tag>(), Some(&Tag(1)));
        assert_eq!(ctx.get::<String>().map(String::as_str), Some("hello"));
    }

    #[test]
    fn insert_replaces_same_type_and_returns_prev() {
        let mut ctx = Context::new();
        ctx.insert(Tag(1));
        let prev = ctx.insert(Tag(2));
        assert_eq!(prev, Some(Tag(1)));
        assert_eq!(ctx.get::<Tag>(), Some(&Tag(2)));
    }

    #[test]
    fn get_mut_permits_mutation_in_place() {
        let mut ctx = Context::new();
        ctx.insert(Tag(0));
        ctx.get_mut::<Tag>().unwrap().0 = 7;
        assert_eq!(ctx.get::<Tag>(), Some(&Tag(7)));
    }

    #[test]
    fn remove_returns_owned_value_and_clears() {
        let mut ctx = Context::new();
        ctx.insert(Tag(9));
        assert_eq!(ctx.remove::<Tag>(), Some(Tag(9)));
        assert!(!ctx.contains::<Tag>());
        assert!(ctx.get::<Tag>().is_none());
    }

    #[test]
    fn absent_type_returns_none() {
        let ctx = Context::new();
        assert!(ctx.get::<Tag>().is_none());
        assert!(!ctx.contains::<Tag>());
    }

    #[test]
    fn non_clone_types_are_allowed() {
        let mut ctx = Context::new();
        ctx.insert(NotClone("present".into()));
        assert_eq!(ctx.get::<NotClone>().unwrap().0, "present");
    }

    #[test]
    fn len_and_is_empty_count_distinct_types() {
        let mut ctx = Context::new();
        assert!(ctx.is_empty());
        assert_eq!(ctx.len(), 0);

        ctx.insert(Tag(1));
        // Same type again: replaces the entry rather than adding one.
        ctx.insert(Tag(2));
        ctx.insert(String::from("x"));

        assert_eq!(ctx.len(), 2);
        assert!(!ctx.is_empty());
    }

    #[test]
    fn remove_and_get_mut_on_absent_type_return_none() {
        let mut ctx = Context::new();
        assert!(ctx.remove::<Tag>().is_none());
        assert!(ctx.get_mut::<Tag>().is_none());
        assert!(ctx.is_empty());
    }

    #[test]
    fn default_is_empty() {
        let ctx = Context::default();
        assert!(ctx.is_empty());
        assert!(!ctx.contains::<Tag>());
    }

    #[test]
    fn debug_reports_entry_count() {
        let mut ctx = Context::new();
        assert_eq!(format!("{:?}", ctx), "Context { entries: 0 }");
        ctx.insert(Tag(1));
        assert_eq!(format!("{:?}", ctx), "Context { entries: 1 }");
    }
}