wachter_storage/
type_map.rs1use std::{any::TypeId, collections::HashMap, error::Error, fmt::Display};
2
/// Marker for every type that can serve as a [`TypeMap`] key or value.
///
/// The only requirement is `'static`, which is what [`TypeId::of`] and
/// `Box<dyn Any>` storage need; the blanket impl below grants the trait to
/// every such (sized) type automatically.
pub trait Type: 'static {}

impl<T> Type for T where T: 'static {}

/// Error returned when a [`TypeMap`] lookup fails.
///
/// Every field is a [`TypeId`]: the key that was queried, the value type the
/// caller asked for and, for [`TypeQueryError::WrongType`], the type of the
/// value actually stored under that key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TypeQueryError {
    /// Nothing is stored under the key: `Missing(key, expected)`.
    Missing(TypeId, TypeId),
    /// A value is stored under the key but is not of the requested type:
    /// `WrongType(key, expected, actual)`.
    WrongType(TypeId, TypeId, TypeId),
}
22impl Error for TypeQueryError {}
23
24impl Display for TypeQueryError {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 match self {
27 TypeQueryError::Missing(key, expected) => f.write_str(&format!(
28 "TypeQueryError::Empty(key={:?}, expected={:?})",
29 &key, &expected
30 )),
31 TypeQueryError::WrongType(key, expected, actual) => f.write_str(&format!(
32 "TypeQueryError::Empty(key={:?}, expected={:?}, actual={:?})",
33 &key, &expected, &actual
34 )),
35 }
36 }
37}
38
39pub type TypeQueryResult<T> = Result<T, TypeQueryError>;
40
41pub type RawMap = HashMap<TypeId, Box<dyn std::any::Any>>;
43
44#[derive(Debug)]
45pub struct TypeMap {
46 raw: RawMap,
47}
48
49impl TypeMap {
52 pub fn new() -> Self {
53 Self {
54 raw: HashMap::new(),
55 }
56 }
57 pub fn contains_key<Key: Type>(&self) -> bool {
58 self.raw.contains_key(&TypeId::of::<Key>())
59 }
60
61 pub fn insert_raw<Value: Type>(&mut self, key_type: TypeId, val: Value) {
63 self.raw.insert(key_type, Box::new(val));
64 }
65
66 pub fn insert<Key: Type, Value: Type>(&mut self, val: Value) {
67 self.insert_raw(TypeId::of::<Key>(), val);
68 }
69
70 #[inline]
71 pub fn get_raw<'value, Val: Type>(&'value self, key: TypeId) -> TypeQueryResult<&Val> {
72 self.raw
73 .get(&key)
74 .ok_or(TypeQueryError::Missing(key, TypeId::of::<Val>()))
75 .and_then(|elem| {
76 elem.downcast_ref::<Val>().ok_or(TypeQueryError::WrongType(
77 key,
78 TypeId::of::<Val>(),
79 elem.type_id(),
80 ))
81 })
82 }
83
84 #[inline]
85 pub fn get<'a, Key: Type, Val: Type>(&'a self) -> TypeQueryResult<&Val> {
86 let key = TypeId::of::<Key>();
87 self.get_raw::<Val>(key)
88 }
89
90 #[inline]
91 pub fn get_mut_raw<'value, Val: Type>(
92 &'value mut self,
93 key: TypeId,
94 ) -> TypeQueryResult<&mut Val> {
95 self.raw
96 .get_mut(&key)
97 .ok_or(TypeQueryError::Missing(key, TypeId::of::<Val>()))
98 .and_then(|elem| {
99 let etid = elem.type_id();
100 elem.downcast_mut::<Val>().ok_or(TypeQueryError::WrongType(
101 key,
102 TypeId::of::<Val>(),
103 etid,
104 ))
105 })
106 }
107
108 #[inline]
109 pub fn get_mut<'a, Key: Type, Val: Type>(&'a mut self) -> TypeQueryResult<&mut Val> {
110 let key = TypeId::of::<Key>();
111 self.get_mut_raw::<Val>(key)
112 }
113}
114
#[cfg(test)]
mod test {
    use std::cell::RefCell;

    use super::*;

    #[test]
    fn test_map_of_u8() {
        let mut map = TypeMap::new();
        let data: u8 = 3;
        map.insert::<u8, u8>(data);
        assert_eq!(map.get_raw::<u8>(TypeId::of::<u8>()), Ok(&data));
        assert_eq!(map.get::<u8, u8>(), Ok(&data));
    }

    #[derive(PartialEq, Eq, Debug, Clone)]
    struct TestPoint {
        x: u8,
        y: u8,
    }

    #[test]
    fn test_map_of_struct() {
        let mut map = TypeMap::new();
        let data: TestPoint = TestPoint { x: 8, y: 8 };
        map.insert::<TestPoint, TestPoint>(data.clone());
        assert_eq!(map.get::<TestPoint, TestPoint>(), Ok(&data));
    }

    #[test]
    fn test_map_of_vec_u8() {
        let mut map = TypeMap::new();
        let data: Vec<u8> = vec![2, 3];
        map.insert::<u8, Vec<u8>>(data.clone());
        assert_eq!(map.get::<u8, Vec<u8>>(), Ok(&data));
    }

    #[test]
    fn test_mixed_map() {
        let mut map = TypeMap::new();
        let data_u8 = 8;
        map.insert::<u8, u8>(data_u8);
        map.insert::<u16, u8>(data_u8);
        assert_eq!(map.get::<u8, u8>(), Ok(&data_u8));
        assert_eq!(map.get::<u16, u8>(), Ok(&data_u8));
    }

    #[test]
    fn test_mutation() {
        let mut map = TypeMap::new();
        let data_u8 = 8;
        map.insert::<u8, RefCell<u8>>(RefCell::new(data_u8));
        let reference = map.get::<u8, RefCell<u8>>().expect("Error getting value");
        *reference.borrow_mut() = 9;
        assert_eq!(map.get::<u8, RefCell<u8>>(), Ok(&RefCell::new(9)));
    }

    #[test]
    fn test_get_mut() {
        let mut map = TypeMap::new();
        map.insert::<u8, Vec<u8>>(vec![2, 3]);
        map.get_mut::<u8, Vec<u8>>()
            .expect("value is stored with the requested type")
            .push(4);
        assert_eq!(map.get::<u8, Vec<u8>>(), Ok(&vec![2, 3, 4]));
        *map
            .get_mut_raw::<Vec<u8>>(TypeId::of::<u8>())
            .expect("value is stored with the requested type") = vec![7];
        assert_eq!(map.get::<u8, Vec<u8>>(), Ok(&vec![7]));
    }

    #[test]
    fn test_contains_key() {
        let mut map = TypeMap::new();
        assert!(!map.contains_key::<u8>());
        map.insert::<u8, String>(String::from("value"));
        assert!(map.contains_key::<u8>());
        assert!(!map.contains_key::<u16>());
    }

    #[test]
    fn test_insert_replaces_value_of_any_type() {
        let mut map = TypeMap::new();
        map.insert::<u8, u8>(1);
        map.insert::<u8, String>(String::from("two"));
        assert_eq!(map.get::<u8, String>(), Ok(&String::from("two")));
        assert_eq!(
            map.get::<u8, u8>(),
            Err(TypeQueryError::WrongType(
                TypeId::of::<u8>(),
                TypeId::of::<u8>(),
                TypeId::of::<String>()
            ))
        );
    }

    #[test]
    fn test_missing_value() {
        let mut map = TypeMap::new();
        map.insert::<u8, Vec<u8>>(vec![2, 3]);
        let typeof_u8 = TypeId::of::<u8>();
        let typeof_u16 = TypeId::of::<u16>();
        assert_eq!(
            map.get::<u16, u8>(),
            Err(TypeQueryError::Missing(typeof_u16, typeof_u8))
        );
        assert_eq!(
            map.get_mut::<u16, u8>(),
            Err(TypeQueryError::Missing(typeof_u16, typeof_u8))
        );
    }

    #[test]
    fn test_wrong_type() {
        let mut map = TypeMap::new();
        map.insert::<u8, Vec<u8>>(vec![2, 3]);
        let typeof_u8 = TypeId::of::<u8>();
        let typeof_vec_u8 = TypeId::of::<Vec<u8>>();
        assert_eq!(
            map.get::<u8, u8>(),
            Err(TypeQueryError::WrongType(
                typeof_u8,
                typeof_u8,
                typeof_vec_u8
            ))
        );
        assert_eq!(
            map.get_mut::<u8, u8>(),
            Err(TypeQueryError::WrongType(
                typeof_u8,
                typeof_u8,
                typeof_vec_u8
            ))
        );
    }
}