wachter_storage/
type_map.rs

1use std::{any::TypeId, collections::HashMap, error::Error, fmt::Display};
2
/// Marker trait for any type that can be used as a `TypeMap` key or value.
///
/// The `'static` bound is required because the map identifies types with
/// [`std::any::TypeId::of`] and stores values as `Box<dyn std::any::Any>`;
/// both only accept `'static` types. As a side effect, non-`'static`
/// reference types are rejected: `&'a Foo` qualifies only when `'a` is
/// `'static`, so references rarely end up as keys or values by accident.
pub trait Type: 'static {}

/// Blanket implementation: every `'static` type is a [`Type`].
impl<T: 'static> Type for T {}
14
/// Error returned when a `TypeMap` lookup fails.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TypeQueryError {
    /// Nothing is stored under the key.
    ///
    /// Fields: `(key, expected_value_type)`.
    Missing(TypeId, TypeId),
    /// Something is stored under the key, but it is not of the requested type.
    ///
    /// Fields: `(key, expected_value_type, actual_value_type)`.
    WrongType(TypeId, TypeId, TypeId),
}

impl Error for TypeQueryError {}

impl Display for TypeQueryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Write straight into the formatter; the variant name in the message
        // matches the actual enum variant.
        match self {
            TypeQueryError::Missing(key, expected) => write!(
                f,
                "TypeQueryError::Missing(key={:?}, expected={:?})",
                key, expected
            ),
            TypeQueryError::WrongType(key, expected, actual) => write!(
                f,
                "TypeQueryError::WrongType(key={:?}, expected={:?}, actual={:?})",
                key, expected, actual
            ),
        }
    }
}

/// Result type returned by `TypeMap` queries.
pub type TypeQueryResult<T> = Result<T, TypeQueryError>;
40
/// Backing storage for [`TypeMap`]: maps a key type's [`TypeId`] to a boxed
/// value of any `'static` type.
///
/// Unlike an anymap (which keys each value by its own type), the key type and
/// the stored value's type are independent, e.g. key `u8` may hold a `Vec<u8>`.
pub type RawMap = HashMap<TypeId, Box<dyn std::any::Any>>;

/// A heterogeneous map whose entries are addressed by a key *type* and hold a
/// value of an independently chosen type.
#[derive(Debug, Default)]
pub struct TypeMap {
    raw: RawMap,
}
48
49// TODO Add removals
50// TODO Add unchecked queries for performance.
51impl TypeMap {
52    pub fn new() -> Self {
53        Self {
54            raw: HashMap::new(),
55        }
56    }
57    pub fn contains_key<Key: Type>(&self) -> bool {
58        self.raw.contains_key(&TypeId::of::<Key>())
59    }
60
61    // TODO Handle overrides gracefully.
62    pub fn insert_raw<Value: Type>(&mut self, key_type: TypeId, val: Value) {
63        self.raw.insert(key_type, Box::new(val));
64    }
65
66    pub fn insert<Key: Type, Value: Type>(&mut self, val: Value) {
67        self.insert_raw(TypeId::of::<Key>(), val);
68    }
69
70    #[inline]
71    pub fn get_raw<'value, Val: Type>(&'value self, key: TypeId) -> TypeQueryResult<&Val> {
72        self.raw
73            .get(&key)
74            .ok_or(TypeQueryError::Missing(key, TypeId::of::<Val>()))
75            .and_then(|elem| {
76                elem.downcast_ref::<Val>().ok_or(TypeQueryError::WrongType(
77                    key,
78                    TypeId::of::<Val>(),
79                    elem.type_id(),
80                ))
81            })
82    }
83
84    #[inline]
85    pub fn get<'a, Key: Type, Val: Type>(&'a self) -> TypeQueryResult<&Val> {
86        let key = TypeId::of::<Key>();
87        self.get_raw::<Val>(key)
88    }
89
90    #[inline]
91    pub fn get_mut_raw<'value, Val: Type>(
92        &'value mut self,
93        key: TypeId,
94    ) -> TypeQueryResult<&mut Val> {
95        self.raw
96            .get_mut(&key)
97            .ok_or(TypeQueryError::Missing(key, TypeId::of::<Val>()))
98            .and_then(|elem| {
99                let etid = elem.type_id();
100                elem.downcast_mut::<Val>().ok_or(TypeQueryError::WrongType(
101                    key,
102                    TypeId::of::<Val>(),
103                    etid,
104                ))
105            })
106    }
107
108    #[inline]
109    pub fn get_mut<'a, Key: Type, Val: Type>(&'a mut self) -> TypeQueryResult<&mut Val> {
110        let key = TypeId::of::<Key>();
111        self.get_mut_raw::<Val>(key)
112    }
113}
114
#[cfg(test)]
mod test {
    use std::cell::RefCell;

    use super::*;

    #[test]
    fn test_map_of_u8() {
        let mut map = TypeMap::new();
        let data: u8 = 3;
        map.insert::<u8, u8>(data);
        assert_eq!(map.get_raw::<u8>(TypeId::of::<u8>()), Ok(&data));
        assert_eq!(map.get::<u8, u8>(), Ok(&data));
    }

    #[derive(PartialEq, Eq, Debug, Clone)]
    struct TestPoint {
        x: u8,
        y: u8,
    }

    #[test]
    fn test_map_of_struct() {
        let mut map = TypeMap::new();
        let data: TestPoint = TestPoint { x: 8, y: 8 };
        map.insert::<TestPoint, TestPoint>(data.clone());
        // Previously this test inserted without ever checking the stored value.
        assert_eq!(map.get::<TestPoint, TestPoint>(), Ok(&data));
    }

    #[test]
    fn test_map_of_vec_u8() {
        let mut map = TypeMap::new();
        let data: Vec<u8> = vec![2, 3];
        map.insert::<u8, Vec<u8>>(data.clone());
        assert_eq!(map.get::<u8, Vec<u8>>(), Ok(&data));
    }

    #[test]
    fn test_mixed_map() {
        let mut map = TypeMap::new();
        let data_u8 = 8;
        map.insert::<u8, u8>(data_u8);
        map.insert::<u16, u8>(data_u8);
        assert_eq!(map.get::<u16, u8>(), Ok(&data_u8));
        assert_eq!(map.get::<u8, u8>(), Ok(&data_u8));
    }

    #[test]
    fn test_mutation() {
        let mut map = TypeMap::new();
        let data_u8 = 8;
        map.insert::<u8, RefCell<u8>>(RefCell::new(data_u8));
        let reference = map.get::<u8, RefCell<u8>>().expect("Error getting value");
        *reference.borrow_mut() = 9;
        assert_eq!(map.get::<u8, RefCell<u8>>(), Ok(&RefCell::new(9)));
    }

    #[test]
    fn test_get_mut() {
        let mut map = TypeMap::new();
        map.insert::<u8, Vec<u8>>(vec![1]);
        map.get_mut::<u8, Vec<u8>>()
            .expect("Error getting value")
            .push(2);
        assert_eq!(map.get::<u8, Vec<u8>>(), Ok(&vec![1, 2]));
    }

    #[test]
    fn test_contains_key() {
        let mut map = TypeMap::new();
        assert!(!map.contains_key::<u8>());
        map.insert::<u8, u16>(1);
        assert!(map.contains_key::<u8>());
        assert!(!map.contains_key::<u16>());
    }

    #[test]
    fn test_insert_overwrites() {
        let mut map = TypeMap::new();
        map.insert::<u8, u8>(1);
        map.insert::<u8, u16>(2);
        assert_eq!(map.get::<u8, u16>(), Ok(&2));
    }

    #[test]
    fn test_missing_value() {
        let mut map = TypeMap::new();
        map.insert::<u8, Vec<u8>>(vec![2, 3]);
        let typeof_u8 = TypeId::of::<u8>();
        let typeof_u16 = TypeId::of::<u16>();
        assert_eq!(
            map.get::<u16, u8>(),
            Err(TypeQueryError::Missing(typeof_u16, typeof_u8))
        );
        assert_eq!(
            map.get_mut::<u16, u8>(),
            Err(TypeQueryError::Missing(typeof_u16, typeof_u8))
        );
    }

    #[test]
    fn test_wrong_type() {
        let mut map = TypeMap::new();
        map.insert::<u8, Vec<u8>>(vec![2, 3]);
        let typeof_u8 = TypeId::of::<u8>();
        let typeof_vec_u8 = TypeId::of::<Vec<u8>>();
        // The reported actual type must be the stored value's type, not
        // `Box<dyn Any>`.
        assert_eq!(
            map.get::<u8, u8>(),
            Err(TypeQueryError::WrongType(
                typeof_u8,
                typeof_u8,
                typeof_vec_u8
            ))
        );
        assert_eq!(
            map.get_mut::<u8, u8>(),
            Err(TypeQueryError::WrongType(
                typeof_u8,
                typeof_u8,
                typeof_vec_u8
            ))
        );
    }
}