tirea_contract/runtime/
extensions.rs1use std::any::{Any, TypeId};
2use std::collections::HashMap;
3
/// A type-keyed map that stores at most one value per concrete type.
///
/// Each value is keyed by its [`TypeId`], so every type acts as its own slot:
/// inserting a second value of the same type replaces the first. Stored values
/// must be `Send + 'static`, which keeps `Extensions` itself `Send`.
#[derive(Default)]
pub struct Extensions {
    // Type-erased storage. Entries are only ever written under
    // `TypeId::of::<T>()` with a boxed `T`, so downcasting a value read back
    // under that same key always succeeds.
    map: HashMap<TypeId, Box<dyn Any + Send>>,
}

impl Extensions {
    /// Creates an empty `Extensions` map.
    #[must_use]
    pub fn new() -> Self {
        Self {
            map: HashMap::new(),
        }
    }

    /// Returns a shared reference to the stored `T`, or `None` if no `T` is present.
    #[must_use]
    pub fn get<T: 'static + Send>(&self) -> Option<&T> {
        self.map
            .get(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast_ref::<T>())
    }

    /// Returns a mutable reference to the stored `T`, or `None` if no `T` is present.
    pub fn get_mut<T: 'static + Send>(&mut self) -> Option<&mut T> {
        self.map
            .get_mut(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast_mut::<T>())
    }

    /// Returns a mutable reference to the stored `T`, first inserting
    /// `T::default()` if no `T` is present. An existing value is left untouched.
    ///
    /// # Panics
    ///
    /// Only if the internal invariant (the entry under `TypeId::of::<T>()`
    /// holds a `T`) were broken, which no method of this type allows.
    pub fn get_or_default<T: 'static + Send + Default>(&mut self) -> &mut T {
        self.map
            .entry(TypeId::of::<T>())
            .or_insert_with(|| Box::new(T::default()))
            .downcast_mut::<T>()
            .expect("type mismatch in Extensions (impossible)")
    }

    /// Stores `val` in the slot for `T`, returning the previously stored `T`, if any.
    pub fn insert<T: 'static + Send>(&mut self, val: T) -> Option<T> {
        self.map
            .insert(TypeId::of::<T>(), Box::new(val))
            .and_then(Self::unbox::<T>)
    }

    /// Removes and returns the stored `T`, or `None` if no `T` is present.
    ///
    /// Values of other types are unaffected.
    pub fn remove<T: 'static + Send>(&mut self) -> Option<T> {
        self.map
            .remove(&TypeId::of::<T>())
            .and_then(Self::unbox::<T>)
    }

    /// Removes every stored value.
    pub fn clear(&mut self) {
        self.map.clear();
    }

    /// Returns `true` if a value of type `T` is stored.
    #[must_use]
    pub fn contains<T: 'static + Send>(&self) -> bool {
        self.map.contains_key(&TypeId::of::<T>())
    }

    /// Returns the number of stored values (one per distinct type).
    #[must_use]
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Returns `true` if no values are stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }

    /// Moves a type-erased entry out of its box as a `T`.
    ///
    /// Yields `None` only when the box does not hold a `T`, which cannot happen
    /// for an entry that was stored under `TypeId::of::<T>()`.
    fn unbox<T: 'static>(boxed: Box<dyn Any + Send>) -> Option<T> {
        boxed.downcast::<T>().ok().map(|inner| *inner)
    }
}

impl std::fmt::Debug for Extensions {
    // Stored values are type-erased and not necessarily `Debug`, so only the
    // entry count is shown, e.g. `Extensions { len: 2 }`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Extensions")
            .field("len", &self.len())
            .finish()
    }
}
76
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn insert_and_get() {
        let mut ext = Extensions::new();
        ext.insert(42u32);
        assert_eq!(ext.get::<u32>(), Some(&42));
    }

    #[test]
    fn insert_into_empty_slot_returns_none() {
        let mut ext = Extensions::new();
        assert_eq!(ext.insert(5u16), None);
    }

    #[test]
    fn get_mut() {
        let mut ext = Extensions::new();
        ext.insert(String::from("hello"));
        ext.get_mut::<String>().unwrap().push_str(" world");
        assert_eq!(ext.get::<String>().unwrap(), "hello world");
    }

    #[test]
    fn get_mut_missing_type_returns_none() {
        let mut ext = Extensions::new();
        ext.insert(1u32);
        assert!(ext.get_mut::<String>().is_none());
    }

    #[test]
    fn get_or_default() {
        let mut ext = Extensions::new();
        let val: &mut Vec<i32> = ext.get_or_default();
        val.push(1);
        assert_eq!(ext.get::<Vec<i32>>().unwrap(), &vec![1]);
    }

    #[test]
    fn get_or_default_keeps_existing_value() {
        // An already-present value must be returned as-is, not reset to default.
        let mut ext = Extensions::new();
        ext.insert(vec![7i32]);
        ext.get_or_default::<Vec<i32>>().push(8);
        assert_eq!(ext.get::<Vec<i32>>().unwrap(), &vec![7, 8]);
    }

    #[test]
    fn insert_replaces() {
        let mut ext = Extensions::new();
        ext.insert(1u32);
        let prev = ext.insert(2u32);
        assert_eq!(prev, Some(1));
        assert_eq!(ext.get::<u32>(), Some(&2));
    }

    #[test]
    fn clear_removes_all() {
        let mut ext = Extensions::new();
        ext.insert(1u32);
        ext.insert(String::from("x"));
        ext.clear();
        assert!(!ext.contains::<u32>());
        assert!(!ext.contains::<String>());
    }

    #[test]
    fn different_types_coexist() {
        let mut ext = Extensions::new();
        ext.insert(42u32);
        ext.insert(String::from("hello"));
        ext.insert(true);
        assert_eq!(ext.get::<u32>(), Some(&42));
        assert_eq!(ext.get::<String>().unwrap(), "hello");
        assert_eq!(ext.get::<bool>(), Some(&true));
    }

    #[test]
    fn missing_type_returns_none() {
        let ext = Extensions::new();
        assert!(ext.get::<u32>().is_none());
    }

    #[test]
    fn debug_reports_entry_count() {
        let mut ext = Extensions::default();
        assert_eq!(format!("{:?}", ext), "Extensions { len: 0 }");
        ext.insert(1u8);
        ext.insert(2u64);
        assert_eq!(format!("{:?}", ext), "Extensions { len: 2 }");
    }

    #[test]
    fn extensions_is_send() {
        // Compile-time check: the `Send` bound on stored values keeps the map `Send`.
        fn assert_send<S: Send>() {}
        assert_send::<Extensions>();
    }
}