torch_web/extractors/
state.rs1use std::pin::Pin;
6use std::future::Future;
7use std::any::{Any, TypeId};
8use std::collections::HashMap;
9use std::sync::Arc;
10use crate::{Request, extractors::{FromRequestParts, ExtractionError}};
11
/// Extractor wrapper around a piece of application state of type `T`.
///
/// The wrapped value is public, so handlers can destructure it directly
/// (`State(value)`) or read it through `.0`. The `FromRequestParts`
/// implementation in this file produces a `State<T>` by cloning the value
/// registered on the request for the type `T`.
///
/// The derives are conditional on `T`: `State<T>` is `Debug`, `Clone`, `Copy`
/// or `Default` only when `T` is.
#[derive(Clone, Copy, Debug, Default)]
pub struct State<T>(pub T);
43
44impl<T> FromRequestParts for State<T>
45where
46 T: Clone + Send + Sync + 'static,
47{
48 type Error = ExtractionError;
49
50 fn from_request_parts(
51 req: &mut Request,
52 ) -> Pin<Box<dyn Future<Output = Result<Self, Self::Error>> + Send + 'static>> {
53 let type_id = TypeId::of::<T>();
54
55 let state_result = if let Some(state_any) = req.get_state(type_id) {
57 match state_any.downcast_ref::<T>() {
58 Some(state) => Ok(state.clone()),
59 None => Err(ExtractionError::MissingState(
60 format!("State type mismatch for {}", std::any::type_name::<T>())
61 )),
62 }
63 } else {
64 Err(ExtractionError::MissingState(
65 format!("No state found for type {}", std::any::type_name::<T>())
66 ))
67 };
68
69 Box::pin(async move {
70 match state_result {
71 Ok(state) => Ok(State(state)),
72 Err(err) => Err(err),
73 }
74 })
75 }
76}
77
/// Type-indexed container for shared application state.
///
/// Holds at most one value per concrete type, keyed by that type's
/// [`TypeId`]. Every value is stored behind an [`Arc`], so cloning the map
/// clones the pointers rather than the stored values.
#[derive(Clone, Debug, Default)]
pub struct StateMap {
    /// One type-erased entry per registered type.
    states: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}

impl StateMap {
    /// Creates an empty map.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `state` under its own type.
    ///
    /// A value of the same type that was inserted earlier is replaced.
    pub fn insert<T>(&mut self, state: T)
    where
        T: Send + Sync + 'static,
    {
        self.states.insert(TypeId::of::<T>(), Arc::new(state));
    }

    /// Returns a reference to the stored value of type `T`, or `None` when no
    /// value of that type is present.
    pub fn get<T>(&self) -> Option<&T>
    where
        T: Send + Sync + 'static,
    {
        self.states
            .get(&TypeId::of::<T>())
            .and_then(|entry| entry.downcast_ref::<T>())
    }

    /// Looks up the type-erased entry stored under `type_id`.
    ///
    /// Crate-internal variant of [`StateMap::get`] for callers that only have
    /// a runtime [`TypeId`] and perform the downcast themselves.
    pub(crate) fn get_by_type_id(&self, type_id: TypeId) -> Option<&Arc<dyn Any + Send + Sync>> {
        self.states.get(&type_id)
    }

    /// Reports whether a value of type `T` is currently stored.
    pub fn contains<T>(&self) -> bool
    where
        T: Send + Sync + 'static,
    {
        self.states.contains_key(&TypeId::of::<T>())
    }

    /// Removes the entry for type `T` and returns it in type-erased form, or
    /// `None` when there was no such entry.
    pub fn remove<T>(&mut self) -> Option<Arc<dyn Any + Send + Sync>>
    where
        T: Send + Sync + 'static,
    {
        self.states.remove(&TypeId::of::<T>())
    }

    /// Number of distinct types that currently have a value stored.
    pub fn len(&self) -> usize {
        self.states.len()
    }

    /// Returns `true` when nothing is stored.
    pub fn is_empty(&self) -> bool {
        self.states.is_empty()
    }
}
145
/// Access to application state carried by a request-like type.
///
/// No implementation is visible in this file; the notes below describe the
/// contract as used by the `State<T>` extractor, which calls `get_state` with
/// `TypeId::of::<T>()` and then downcasts the returned entry to `T`.
pub trait RequestStateExt {
    /// Returns the type-erased state entry associated with `type_id`, or
    /// `None` when there is none.
    fn get_state(&self, type_id: TypeId) -> Option<&Arc<dyn Any + Send + Sync>>;

    /// Hands `state_map` to the implementor.
    ///
    /// NOTE(review): presumably this map then backs `get_state` and
    /// `state_map` — confirm against the implementing type.
    fn set_state_map(&mut self, state_map: StateMap);

    /// Returns the implementor's [`StateMap`], if it has one.
    fn state_map(&self) -> Option<&StateMap>;
}
157
#[cfg(test)]
mod tests {
    use super::*;

    /// Numeric fixture stored in the map under test.
    #[derive(Clone, Debug, PartialEq)]
    struct TestState {
        value: u32,
    }

    /// Second fixture type, used to check that entries are kept per type.
    #[derive(Clone, Debug, PartialEq)]
    struct AnotherState {
        name: String,
    }

    /// A value that was inserted can be read back through its type.
    #[test]
    fn test_state_map_insert_and_get() {
        let expected = TestState { value: 42 };

        let mut map = StateMap::new();
        map.insert(expected.clone());

        assert_eq!(map.get::<TestState>(), Some(&expected));
    }

    /// Values of different types live side by side without clobbering.
    #[test]
    fn test_state_map_multiple_types() {
        let numeric = TestState { value: 42 };
        let named = AnotherState {
            name: String::from("test"),
        };

        let mut map = StateMap::new();
        map.insert(numeric.clone());
        map.insert(named.clone());

        assert_eq!(map.get::<TestState>(), Some(&numeric));
        assert_eq!(map.get::<AnotherState>(), Some(&named));
    }

    /// Looking up a type that was never inserted yields `None`.
    #[test]
    fn test_state_map_missing_type() {
        let empty = StateMap::new();
        assert_eq!(empty.get::<TestState>(), None);
    }

    /// `contains` reflects exactly the types that have been inserted.
    #[test]
    fn test_state_map_contains() {
        let mut map = StateMap::new();
        assert!(!map.contains::<TestState>());

        map.insert(TestState { value: 42 });

        assert!(map.contains::<TestState>());
        assert!(!map.contains::<AnotherState>());
    }

    /// `remove` returns the entry and leaves the type absent afterwards.
    #[test]
    fn test_state_map_remove() {
        let mut map = StateMap::new();
        map.insert(TestState { value: 42 }.clone());
        assert!(map.contains::<TestState>());

        let taken = map.remove::<TestState>();

        assert!(taken.is_some());
        assert!(!map.contains::<TestState>());
    }

    /// `len` and `is_empty` track the number of stored types.
    #[test]
    fn test_state_map_len_and_empty() {
        let mut map = StateMap::new();
        assert_eq!(map.len(), 0);
        assert!(map.is_empty());

        map.insert(TestState { value: 42 });
        assert_eq!(map.len(), 1);
        assert!(!map.is_empty());

        map.insert(AnotherState {
            name: String::from("test"),
        });
        assert_eq!(map.len(), 2);
    }
}