// rustio_core/context.rs

1use std::any::{Any, TypeId};
2use std::collections::HashMap;
3use std::fmt;
4
/// A type-indexed map that holds at most one value per concrete type.
///
/// Every value is stored under the [`TypeId`] of its own type, so inserting a
/// second value of some type `T` replaces the first one. Stored values must be
/// `Send + Sync + 'static`.
pub struct Context {
    // Invariant (established by `insert`, the only method that adds entries):
    // the box stored under `TypeId::of::<T>()` always holds a `T`.
    map: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
}

impl Context {
    /// Returns a context with no entries.
    pub fn new() -> Self {
        let map = HashMap::new();
        Context { map }
    }

    /// Stores `value` under its type `T`.
    ///
    /// Returns the `T` that was stored before this call, or `None` if the
    /// context held no value of type `T`.
    pub fn insert<T: Send + Sync + 'static>(&mut self, value: T) -> Option<T> {
        let key = TypeId::of::<T>();
        let displaced = self.map.insert(key, Box::new(value))?;
        Self::unbox::<T>(displaced)
    }

    /// Borrows the stored value of type `T`, if present.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        let erased = self.map.get(&TypeId::of::<T>())?;
        erased.downcast_ref::<T>()
    }

    /// Mutably borrows the stored value of type `T`, if present.
    pub fn get_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
        let erased = self.map.get_mut(&TypeId::of::<T>())?;
        erased.downcast_mut::<T>()
    }

    /// Takes the stored value of type `T` out of the context, if present.
    pub fn remove<T: Send + Sync + 'static>(&mut self) -> Option<T> {
        let erased = self.map.remove(&TypeId::of::<T>())?;
        Self::unbox::<T>(erased)
    }

    /// Reports whether a value of type `T` is currently stored.
    pub fn contains<T: 'static>(&self) -> bool {
        let key = TypeId::of::<T>();
        self.map.contains_key(&key)
    }

    /// Number of stored values (at most one per distinct type).
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// `true` when no values are stored.
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }

    /// Converts a type-erased box back into an owned `T`.
    ///
    /// Yields `None` only if the box does not hold a `T`. The keying invariant
    /// rules that out for any box taken from `map` under `TypeId::of::<T>()`.
    fn unbox<T: Send + Sync + 'static>(erased: Box<dyn Any + Send + Sync>) -> Option<T> {
        erased.downcast::<T>().map(|concrete| *concrete).ok()
    }
}

impl Default for Context {
    /// Equivalent to [`Context::new`].
    fn default() -> Self {
        Context::new()
    }
}

impl fmt::Debug for Context {
    /// Reports only the number of stored entries. Stored types carry no
    /// `Debug` bound, so their values are not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let entries = self.map.len();
        f.debug_struct("Context").field("entries", &entries).finish()
    }
}
66
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, PartialEq)]
    struct Tag(u32);

    // Deliberately lacks `Clone` to show the context never needs to copy values.
    struct NotClone(String);

    #[test]
    fn insert_and_get_by_type() {
        let mut ctx = Context::new();
        assert!(ctx.insert(Tag(42)).is_none());
        assert_eq!(ctx.get::<Tag>(), Some(&Tag(42)));
    }

    #[test]
    fn different_types_coexist() {
        let mut ctx = Context::new();
        ctx.insert(Tag(1));
        ctx.insert(String::from("hello"));
        assert_eq!(ctx.get::<Tag>(), Some(&Tag(1)));
        assert_eq!(ctx.get::<String>().map(String::as_str), Some("hello"));
    }

    #[test]
    fn insert_replaces_same_type_and_returns_prev() {
        let mut ctx = Context::new();
        ctx.insert(Tag(1));
        let prev = ctx.insert(Tag(2));
        assert_eq!(prev, Some(Tag(1)));
        assert_eq!(ctx.get::<Tag>(), Some(&Tag(2)));
    }

    #[test]
    fn get_mut_permits_mutation_in_place() {
        let mut ctx = Context::new();
        ctx.insert(Tag(0));
        ctx.get_mut::<Tag>().unwrap().0 = 7;
        assert_eq!(ctx.get::<Tag>(), Some(&Tag(7)));
    }

    #[test]
    fn remove_returns_owned_value_and_clears() {
        let mut ctx = Context::new();
        ctx.insert(Tag(9));
        assert_eq!(ctx.remove::<Tag>(), Some(Tag(9)));
        assert!(!ctx.contains::<Tag>());
        assert!(ctx.get::<Tag>().is_none());
    }

    #[test]
    fn absent_type_returns_none() {
        let ctx = Context::new();
        assert!(ctx.get::<Tag>().is_none());
        assert!(!ctx.contains::<Tag>());
    }

    #[test]
    fn non_clone_types_are_allowed() {
        let mut ctx = Context::new();
        ctx.insert(NotClone("present".into()));
        assert_eq!(ctx.get::<NotClone>().unwrap().0, "present");
    }

    #[test]
    fn new_and_default_start_empty() {
        let ctx = Context::new();
        assert!(ctx.is_empty());
        assert_eq!(ctx.len(), 0);
        let dflt = Context::default();
        assert!(dflt.is_empty());
        assert_eq!(dflt.len(), 0);
    }

    #[test]
    fn len_counts_distinct_types_not_insertions() {
        let mut ctx = Context::new();
        ctx.insert(Tag(1));
        // Replacing a value of the same type must not grow the context.
        ctx.insert(Tag(2));
        ctx.insert(String::from("s"));
        assert_eq!(ctx.len(), 2);
        assert!(!ctx.is_empty());
        ctx.remove::<Tag>();
        assert_eq!(ctx.len(), 1);
    }

    #[test]
    fn remove_and_get_mut_on_absent_type_return_none() {
        let mut ctx = Context::new();
        ctx.insert(Tag(3));
        assert!(ctx.remove::<String>().is_none());
        assert!(ctx.get_mut::<String>().is_none());
        // Failed lookups must leave existing entries untouched.
        assert_eq!(ctx.len(), 1);
        assert_eq!(ctx.get::<Tag>(), Some(&Tag(3)));
    }

    #[test]
    fn newtype_and_inner_type_are_distinct_keys() {
        let mut ctx = Context::new();
        ctx.insert(Tag(1));
        assert!(ctx.insert(1u32).is_none());
        assert_eq!(ctx.get::<u32>(), Some(&1));
        assert_eq!(ctx.get::<Tag>(), Some(&Tag(1)));
    }

    #[test]
    fn debug_reports_entry_count_only() {
        let mut ctx = Context::new();
        assert_eq!(format!("{:?}", ctx), "Context { entries: 0 }");
        ctx.insert(Tag(5));
        assert_eq!(format!("{:?}", ctx), "Context { entries: 1 }");
    }
}