// torch_web/extractors/state.rs

//! Application state extraction.
//!
//! Provides the [`State`] extractor, which pulls shared application state out of the request context.
4
5use std::pin::Pin;
6use std::future::Future;
7use std::any::{Any, TypeId};
8use std::collections::HashMap;
9use std::sync::Arc;
10use crate::{Request, extractors::{FromRequestParts, ExtractionError}};
11
/// Extractor that hands a handler its own clone of the application state of type `T`.
///
/// The state has to be registered on the application beforehand (for example with
/// `App::with_state`); if no value of type `T` is registered, extraction fails.
///
/// # Example
///
/// ```rust,no_run
/// use std::sync::Arc;
/// use tokio::sync::Mutex;
/// use torch_web::{extractors::State, App};
///
/// #[derive(Clone)]
/// struct AppState {
///     counter: Arc<Mutex<u64>>,
/// }
///
/// async fn increment(State(state): State<AppState>) {
///     *state.counter.lock().await += 1;
/// }
///
/// #[tokio::main]
/// async fn main() {
///     let shared = AppState {
///         counter: Arc::new(Mutex::new(0)),
///     };
///
///     let _app = App::new()
///         .with_state(shared)
///         .get("/increment", increment);
/// }
/// ```
pub struct State<T>(pub T);
43
44impl<T> FromRequestParts for State<T>
45where
46    T: Clone + Send + Sync + 'static,
47{
48    type Error = ExtractionError;
49
50    fn from_request_parts(
51        req: &mut Request,
52    ) -> Pin<Box<dyn Future<Output = Result<Self, Self::Error>> + Send + 'static>> {
53        let type_id = TypeId::of::<T>();
54
55        // Clone the state to avoid lifetime issues
56        let state_result = if let Some(state_any) = req.get_state(type_id) {
57            match state_any.downcast_ref::<T>() {
58                Some(state) => Ok(state.clone()),
59                None => Err(ExtractionError::MissingState(
60                    format!("State type mismatch for {}", std::any::type_name::<T>())
61                )),
62            }
63        } else {
64            Err(ExtractionError::MissingState(
65                format!("No state found for type {}", std::any::type_name::<T>())
66            ))
67        };
68
69        Box::pin(async move {
70            match state_result {
71                Ok(state) => Ok(State(state)),
72                Err(err) => Err(err),
73            }
74        })
75    }
76}
77
/// Type-keyed store of shared application state.
///
/// At most one value is kept per concrete type. Each value lives behind an [`Arc`], so
/// cloning the map shares the stored values rather than copying them.
#[derive(Clone, Default)]
pub struct StateMap {
    states: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}

impl StateMap {
    /// Creates a map with no state registered.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `state` under its own type.
    ///
    /// Any value previously stored for the same type `T` is replaced.
    pub fn insert<T>(&mut self, state: T)
    where
        T: Send + Sync + 'static,
    {
        let erased: Arc<dyn Any + Send + Sync> = Arc::new(state);
        self.states.insert(TypeId::of::<T>(), erased);
    }

    /// Borrows the value stored for type `T`, if there is one.
    pub fn get<T>(&self) -> Option<&T>
    where
        T: Send + Sync + 'static,
    {
        let entry = self.get_by_type_id(TypeId::of::<T>())?;
        entry.downcast_ref::<T>()
    }

    /// Returns the type-erased entry registered under `type_id` (crate-internal lookup).
    pub(crate) fn get_by_type_id(&self, type_id: TypeId) -> Option<&Arc<dyn Any + Send + Sync>> {
        self.states.get(&type_id)
    }

    /// Reports whether a value of type `T` has been stored.
    pub fn contains<T>(&self) -> bool
    where
        T: Send + Sync + 'static,
    {
        self.get_by_type_id(TypeId::of::<T>()).is_some()
    }

    /// Takes the value stored for type `T` out of the map.
    ///
    /// The value comes back type-erased; call `downcast_ref::<T>()` on it to reach the `T`.
    pub fn remove<T>(&mut self) -> Option<Arc<dyn Any + Send + Sync>>
    where
        T: Send + Sync + 'static,
    {
        let key = TypeId::of::<T>();
        self.states.remove(&key)
    }

    /// Returns how many distinct state types are stored.
    pub fn len(&self) -> usize {
        self.states.len()
    }

    /// Returns `true` when no state of any type is stored.
    pub fn is_empty(&self) -> bool {
        self.states.is_empty()
    }
}
145
146/// Extension trait for Request to handle state
147pub trait RequestStateExt {
148    /// Get state by TypeId
149    fn get_state(&self, type_id: TypeId) -> Option<&Arc<dyn Any + Send + Sync>>;
150    
151    /// Set the state map for this request
152    fn set_state_map(&mut self, state_map: StateMap);
153    
154    /// Get a reference to the state map
155    fn state_map(&self) -> Option<&StateMap>;
156}
157
// `RequestStateExt` is meant to be implemented for `Request` in request.rs.
159
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, Debug, PartialEq)]
    struct TestState {
        value: u32,
    }

    #[derive(Clone, Debug, PartialEq)]
    struct AnotherState {
        name: String,
    }

    #[test]
    fn test_state_map_insert_and_get() {
        let mut state_map = StateMap::new();

        let test_state = TestState { value: 42 };
        state_map.insert(test_state.clone());

        let retrieved = state_map.get::<TestState>();
        assert_eq!(retrieved, Some(&test_state));
    }

    #[test]
    fn test_state_map_insert_replaces_same_type() {
        let mut state_map = StateMap::new();

        state_map.insert(TestState { value: 1 });
        state_map.insert(TestState { value: 2 });

        // One slot per type: the second insert overwrites the first.
        assert_eq!(state_map.get::<TestState>(), Some(&TestState { value: 2 }));
        assert_eq!(state_map.len(), 1);
    }

    #[test]
    fn test_state_map_multiple_types() {
        let mut state_map = StateMap::new();

        let test_state = TestState { value: 42 };
        let another_state = AnotherState { name: "test".to_string() };

        state_map.insert(test_state.clone());
        state_map.insert(another_state.clone());

        assert_eq!(state_map.get::<TestState>(), Some(&test_state));
        assert_eq!(state_map.get::<AnotherState>(), Some(&another_state));
    }

    #[test]
    fn test_state_map_missing_type() {
        let state_map = StateMap::new();
        let retrieved = state_map.get::<TestState>();
        assert_eq!(retrieved, None);
    }

    #[test]
    fn test_state_map_get_by_type_id() {
        let mut state_map = StateMap::new();
        let type_id = std::any::TypeId::of::<TestState>();

        assert!(state_map.get_by_type_id(type_id).is_none());

        state_map.insert(TestState { value: 7 });

        let erased = state_map
            .get_by_type_id(type_id)
            .expect("TestState was just inserted");
        assert_eq!(erased.downcast_ref::<TestState>(), Some(&TestState { value: 7 }));
        // An entry found by TypeId must not downcast to an unrelated type.
        assert!(erased.downcast_ref::<AnotherState>().is_none());
    }

    #[test]
    fn test_state_map_contains() {
        let mut state_map = StateMap::new();

        assert!(!state_map.contains::<TestState>());

        state_map.insert(TestState { value: 42 });

        assert!(state_map.contains::<TestState>());
        assert!(!state_map.contains::<AnotherState>());
    }

    #[test]
    fn test_state_map_remove() {
        let mut state_map = StateMap::new();

        let test_state = TestState { value: 42 };
        state_map.insert(test_state.clone());

        assert!(state_map.contains::<TestState>());

        let removed = state_map
            .remove::<TestState>()
            .expect("state was inserted above");
        // The type-erased value handed back must still be the one that was stored.
        assert_eq!(removed.downcast_ref::<TestState>(), Some(&test_state));
        assert!(!state_map.contains::<TestState>());
        // A second removal finds nothing.
        assert!(state_map.remove::<TestState>().is_none());
    }

    #[test]
    fn test_state_map_len_and_empty() {
        let mut state_map = StateMap::new();

        assert_eq!(state_map.len(), 0);
        assert!(state_map.is_empty());

        state_map.insert(TestState { value: 42 });

        assert_eq!(state_map.len(), 1);
        assert!(!state_map.is_empty());

        state_map.insert(AnotherState { name: "test".to_string() });

        assert_eq!(state_map.len(), 2);
    }

    #[test]
    fn test_state_map_clone_shares_values() {
        let mut state_map = StateMap::new();
        state_map.insert(AnotherState { name: "shared".to_string() });

        let cloned = state_map.clone();

        let original_ref = state_map.get::<AnotherState>().expect("inserted above");
        let cloned_ref = cloned.get::<AnotherState>().expect("clone keeps entries");
        // Values sit behind `Arc`, so both maps point at the same allocation.
        assert!(std::ptr::eq(original_ref, cloned_ref));
    }
}