wsforge_core/
state.rs

1//! Shared application state management.
2//!
3//! This module provides a type-safe, thread-safe container for storing and retrieving
4//! application state that needs to be shared across all WebSocket connections. State
5//! is commonly used for database connections, configuration, caches, and other shared
6//! resources.
7//!
8//! # Overview
9//!
10//! The [`AppState`] type uses a type-map pattern, allowing you to store multiple
11//! different types of state in a single container. Each type is stored separately
12//! and can be retrieved by its type, ensuring type safety at compile time.
13//!
14//! # Design
15//!
16//! - **Type-safe**: Each state type is stored and retrieved by its exact type
17//! - **Thread-safe**: Uses `Arc` and `DashMap` for lock-free concurrent access
18//! - **Zero-cost abstraction**: No runtime overhead when state is not used
19//! - **Flexible**: Any type that is `Send + Sync + 'static` can be stored
20//!
21//! # Common Use Cases
22//!
23//! | Use Case | Example Type | Description |
24//! |----------|--------------|-------------|
25//! | Database | `Arc<DatabasePool>` | Shared database connection pool |
26//! | Configuration | `Arc<Config>` | Application settings and configuration |
27//! | Cache | `Arc<Cache>` | In-memory cache for frequently accessed data |
28//! | Metrics | `Arc<Metrics>` | Performance metrics and monitoring |
29//! | Connection Manager | `Arc<ConnectionManager>` | Manage all active WebSocket connections |
30//!
31//! # Examples
32//!
33//! ## Single State Type
34//!
35//! ```
36//! use wsforge::prelude::*;
37//! use std::sync::Arc;
38//!
39//! struct Database {
40//!     connection_string: String,
41//! }
42//!
43//! async fn query_handler(State(db): State<Arc<Database>>) -> Result<String> {
44//!     Ok(format!("Connected to: {}", db.connection_string))
45//! }
46//!
47//! # fn example() {
48//! let db = Arc::new(Database {
49//!     connection_string: "postgres://localhost/mydb".to_string(),
50//! });
51//!
52//! let router = Router::new()
53//!     .with_state(db)
54//!     .default_handler(handler(query_handler));
55//! # }
56//! ```
57//!
58//! ## Multiple State Types
59//!
60//! ```
61//! use wsforge::prelude::*;
62//! use std::sync::Arc;
63//!
64//! struct Database {
65//!     url: String,
66//! }
67//!
68//! struct Config {
69//!     max_connections: usize,
70//!     timeout_seconds: u64,
71//! }
72//!
73//! struct Cache {
74//!     data: std::collections::HashMap<String, String>,
75//! }
76//!
77//! async fn multi_state_handler(
78//!     State(db): State<Arc<Database>>,
79//!     State(config): State<Arc<Config>>,
80//!     State(cache): State<Arc<Cache>>,
81//! ) -> Result<String> {
82//!     Ok(format!(
83//!         "DB: {}, Max: {}, Cache size: {}",
84//!         db.url,
85//!         config.max_connections,
86//!         cache.data.len()
87//!     ))
88//! }
89//!
90//! # fn example() {
91//! let router = Router::new()
92//!     .with_state(Arc::new(Database { url: "...".to_string() }))
93//!     .with_state(Arc::new(Config { max_connections: 100, timeout_seconds: 30 }))
94//!     .with_state(Arc::new(Cache { data: Default::default() }))
95//!     .default_handler(handler(multi_state_handler));
96//! # }
97//! ```
98//!
99//! ## Mutable State with RwLock
100//!
101//! ```
102//! use wsforge::prelude::*;
103//! use std::sync::{Arc, RwLock};
104//!
105//! struct Counter {
106//!     value: RwLock<u64>,
107//! }
108//!
109//! impl Counter {
110//!     fn increment(&self) {
111//!         let mut value = self.value.write().unwrap();
112//!         *value += 1;
113//!     }
114//!
115//!     fn get(&self) -> u64 {
116//!         *self.value.read().unwrap()
117//!     }
118//! }
119//!
120//! async fn count_handler(State(counter): State<Arc<Counter>>) -> Result<String> {
121//!     counter.increment();
122//!     Ok(format!("Count: {}", counter.get()))
123//! }
124//!
125//! # fn example() {
126//! let counter = Arc::new(Counter {
127//!     value: RwLock::new(0),
128//! });
129//!
130//! let router = Router::new()
131//!     .with_state(counter)
132//!     .default_handler(handler(count_handler));
133//! # }
134//! ```
135
136use dashmap::DashMap;
137use std::any::{Any, TypeId};
138use std::sync::Arc;
139
140/// A type-safe container for shared application state.
141///
142/// `AppState` allows you to store multiple different types of state in a single
143/// container. Each type is identified by its `TypeId`, ensuring type safety when
144/// retrieving state.
145///
146/// # Thread Safety
147///
148/// `AppState` is fully thread-safe and can be cloned cheaply (uses `Arc` internally).
149/// Multiple handlers can access the same state concurrently without additional
150/// synchronization.
151///
152/// # Memory Management
153///
154/// State is stored using `Arc`, so cloning `AppState` or extracting state with
155/// the `State` extractor only increments a reference count. The actual state
156/// data is shared across all references.
157///
158/// # Type Requirements
159///
160/// Types stored in `AppState` must be:
161/// - `Send`: Can be sent between threads
162/// - `Sync`: Can be referenced from multiple threads
163/// - `'static`: Has a static lifetime
164///
165/// # Examples
166///
167/// ## Creating and Using State
168///
169/// ```
170/// use wsforge::prelude::*;
171/// use std::sync::Arc;
172///
173/// # fn example() {
174/// // Create empty state
175/// let state = AppState::new();
176///
177/// // Add some data
178/// state.insert(Arc::new("Hello".to_string()));
179/// state.insert(Arc::new(42_u32));
180///
181/// // Retrieve data
182/// let text: Option<Arc<String>> = state.get();
183/// assert_eq!(*text.unwrap(), "Hello");
184///
185/// let number: Option<Arc<u32>> = state.get();
186/// assert_eq!(*number.unwrap(), 42);
187/// # }
188/// ```
189///
190/// ## With Router
191///
192/// ```
193/// use wsforge::prelude::*;
194/// use std::sync::Arc;
195///
196/// struct AppConfig {
197///     name: String,
198///     version: String,
199/// }
200///
201/// async fn info_handler(State(config): State<Arc<AppConfig>>) -> Result<String> {
202///     Ok(format!("{} v{}", config.name, config.version))
203/// }
204///
205/// # fn example() {
206/// let config = Arc::new(AppConfig {
207///     name: "MyApp".to_string(),
208///     version: "1.0.0".to_string(),
209/// });
210///
211/// let router = Router::new()
212///     .with_state(config)
213///     .default_handler(handler(info_handler));
214/// # }
215/// ```
216///
217/// ## Complex State Management
218///
219/// ```
220/// use wsforge::prelude::*;
221/// use std::sync::Arc;
222/// use std::collections::HashMap;
223///
224/// struct UserStore {
225///     users: tokio::sync::RwLock<HashMap<u64, String>>,
226/// }
227///
228/// impl UserStore {
229///     fn new() -> Self {
230///         Self {
231///             users: tokio::sync::RwLock::new(HashMap::new()),
232///         }
233///     }
234///
235///     async fn add_user(&self, id: u64, name: String) {
236///         self.users.write().await.insert(id, name);
237///     }
238///
239///     async fn get_user(&self, id: u64) -> Option<String> {
240///         self.users.read().await.get(&id).cloned()
241///     }
242/// }
243///
244/// async fn user_handler(
245///     State(store): State<Arc<UserStore>>,
246/// ) -> Result<String> {
247///     store.add_user(1, "Alice".to_string()).await;
248///     let user = store.get_user(1).await;
249///     Ok(format!("User: {:?}", user))
250/// }
251/// # }
252/// ```
253#[derive(Clone)]
254pub struct AppState {
255    /// Internal storage mapping TypeId to Arc-wrapped values
256    data: Arc<DashMap<TypeId, Arc<dyn Any + Send + Sync>>>,
257}
258
259impl AppState {
260    /// Creates a new empty `AppState`.
261    ///
262    /// The state starts with no data. Use [`insert`](Self::insert) to add state.
263    ///
264    /// # Examples
265    ///
266    /// ```
267    /// use wsforge::prelude::*;
268    ///
269    /// let state = AppState::new();
270    /// ```
271    pub fn new() -> Self {
272        Self {
273            data: Arc::new(DashMap::new()),
274        }
275    }
276
277    /// Inserts a value into the state.
278    ///
279    /// If a value of the same type already exists, it will be replaced.
280    /// The value is automatically wrapped in an `Arc`.
281    ///
282    /// # Type Requirements
283    ///
284    /// The type `T` must implement:
285    /// - `Send`: Can be transferred across thread boundaries
286    /// - `Sync`: Can be safely shared between threads
287    /// - `'static`: Has a static lifetime (no borrowed data)
288    ///
289    /// # Arguments
290    ///
291    /// * `value` - The value to store in state
292    ///
293    /// # Examples
294    ///
295    /// ## Basic Usage
296    ///
297    /// ```
298    /// use wsforge::prelude::*;
299    /// use std::sync::Arc;
300    ///
301    /// # fn example() {
302    /// let state = AppState::new();
303    ///
304    /// // Insert different types
305    /// state.insert(Arc::new(String::from("Hello")));
306    /// state.insert(Arc::new(42_u32));
307    /// state.insert(Arc::new(true));
308    /// # }
309    /// ```
310    ///
311    /// ## Replacing Values
312    ///
313    /// ```
314    /// use wsforge::prelude::*;
315    /// use std::sync::Arc;
316    ///
317    /// # fn example() {
318    /// let state = AppState::new();
319    ///
320    /// state.insert(Arc::new(10_u32));
321    /// assert_eq!(*state.get::<u32>().unwrap(), 10);
322    ///
323    /// // Replace with new value
324    /// state.insert(Arc::new(20_u32));
325    /// assert_eq!(*state.get::<u32>().unwrap(), 20);
326    /// # }
327    /// ```
328    ///
329    /// ## Custom Types
330    ///
331    /// ```
332    /// use wsforge::prelude::*;
333    /// use std::sync::Arc;
334    ///
335    /// struct Database {
336    ///     url: String,
337    /// }
338    ///
339    /// # fn example() {
340    /// let state = AppState::new();
341    ///
342    /// let db = Arc::new(Database {
343    ///     url: "postgres://localhost/mydb".to_string(),
344    /// });
345    ///
346    /// state.insert(db);
347    /// # }
348    /// ```
349    pub fn insert<T: Send + Sync + 'static>(&self, value: Arc<T>) {
350        self.data.insert(TypeId::of::<T>(), value);
351    }
352
353    /// Retrieves a value from the state by its type.
354    ///
355    /// Returns `None` if no value of type `T` has been stored.
356    /// Returns `Some(Arc<T>)` if a value exists.
357    ///
358    /// # Type Safety
359    ///
360    /// The returned value is guaranteed to be of type `T` because values
361    /// are stored and retrieved using `TypeId`.
362    ///
363    /// # Performance
364    ///
365    /// This operation is O(1) and lock-free, making it very efficient even
366    /// with concurrent access from multiple threads.
367    ///
368    /// # Examples
369    ///
370    /// ## Basic Retrieval
371    ///
372    /// ```
373    /// use wsforge::prelude::*;
374    /// use std::sync::Arc;
375    ///
376    /// # fn example() {
377    /// let state = AppState::new();
378    /// state.insert(Arc::new(String::from("Hello")));
379    ///
380    /// let text: Option<Arc<String>> = state.get();
381    /// assert_eq!(*text.unwrap(), "Hello");
382    ///
383    /// // Trying to get a type that doesn't exist
384    /// let number: Option<Arc<u32>> = state.get();
385    /// assert!(number.is_none());
386    /// # }
387    /// ```
388    ///
389    /// ## Pattern Matching
390    ///
391    /// ```
392    /// use wsforge::prelude::*;
393    /// use std::sync::Arc;
394    ///
395    /// # fn example() {
396    /// let state = AppState::new();
397    /// state.insert(Arc::new(42_u32));
398    ///
399    /// match state.get::<u32>() {
400    ///     Some(value) => println!("Found: {}", value),
401    ///     None => println!("Not found"),
402    /// }
403    /// # }
404    /// ```
405    ///
406    /// ## Multiple Types
407    ///
408    /// ```
409    /// use wsforge::prelude::*;
410    /// use std::sync::Arc;
411    ///
412    /// struct Config { port: u16 }
413    /// struct Database { url: String }
414    ///
415    /// # fn example() {
416    /// let state = AppState::new();
417    /// state.insert(Arc::new(Config { port: 8080 }));
418    /// state.insert(Arc::new(Database { url: "...".to_string() }));
419    ///
420    /// // Each type is stored separately
421    /// let config: Arc<Config> = state.get().unwrap();
422    /// let db: Arc<Database> = state.get().unwrap();
423    ///
424    /// println!("Port: {}, DB: {}", config.port, db.url);
425    /// # }
426    /// ```
427    ///
428    /// ## With Error Handling
429    ///
430    /// ```
431    /// use wsforge::prelude::*;
432    /// use std::sync::Arc;
433    ///
434    /// struct Database;
435    ///
436    /// # fn example() -> Result<()> {
437    /// let state = AppState::new();
438    ///
439    /// let db = state
440    ///     .get::<Database>()
441    ///     .ok_or_else(|| Error::custom("Database not configured"))?;
442    ///
443    /// // Use db...
444    /// # Ok(())
445    /// # }
446    /// ```
447    pub fn get<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
448        self.data
449            .get(&TypeId::of::<T>())
450            .and_then(|arc| arc.value().clone().downcast::<T>().ok())
451    }
452
453    /// Checks if a value of type `T` exists in the state.
454    ///
455    /// This is equivalent to `state.get::<T>().is_some()` but more explicit.
456    ///
457    /// # Examples
458    ///
459    /// ```
460    /// use wsforge::prelude::*;
461    /// use std::sync::Arc;
462    ///
463    /// # fn example() {
464    /// let state = AppState::new();
465    /// state.insert(Arc::new(42_u32));
466    ///
467    /// assert!(state.contains::<u32>());
468    /// assert!(!state.contains::<String>());
469    /// # }
470    /// ```
471    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
472        self.data.contains_key(&TypeId::of::<T>())
473    }
474
475    /// Removes a value of type `T` from the state.
476    ///
477    /// Returns the removed value if it existed, or `None` otherwise.
478    ///
479    /// # Examples
480    ///
481    /// ```
482    /// use wsforge::prelude::*;
483    /// use std::sync::Arc;
484    ///
485    /// # fn example() {
486    /// let state = AppState::new();
487    /// state.insert(Arc::new(42_u32));
488    ///
489    /// assert!(state.contains::<u32>());
490    ///
491    /// let value = state.remove::<u32>();
492    /// assert_eq!(*value.unwrap(), 42);
493    ///
494    /// assert!(!state.contains::<u32>());
495    /// # }
496    /// ```
497    pub fn remove<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
498        self.data
499            .remove(&TypeId::of::<T>())
500            .and_then(|(_, arc)| arc.downcast::<T>().ok())
501    }
502
503    /// Returns the number of different types stored in the state.
504    ///
505    /// # Examples
506    ///
507    /// ```
508    /// use wsforge::prelude::*;
509    /// use std::sync::Arc;
510    ///
511    /// # fn example() {
512    /// let state = AppState::new();
513    /// assert_eq!(state.len(), 0);
514    ///
515    /// state.insert(Arc::new(String::from("Hello")));
516    /// assert_eq!(state.len(), 1);
517    ///
518    /// state.insert(Arc::new(42_u32));
519    /// assert_eq!(state.len(), 2);
520    ///
521    /// // Replacing same type doesn't increase count
522    /// state.insert(Arc::new(100_u32));
523    /// assert_eq!(state.len(), 2);
524    /// # }
525    /// ```
526    pub fn len(&self) -> usize {
527        self.data.len()
528    }
529
530    /// Checks if the state is empty (contains no data).
531    ///
532    /// # Examples
533    ///
534    /// ```
535    /// use wsforge::prelude::*;
536    /// use std::sync::Arc;
537    ///
538    /// # fn example() {
539    /// let state = AppState::new();
540    /// assert!(state.is_empty());
541    ///
542    /// state.insert(Arc::new(42_u32));
543    /// assert!(!state.is_empty());
544    /// # }
545    /// ```
546    pub fn is_empty(&self) -> bool {
547        self.data.is_empty()
548    }
549
550    /// Clears all state data.
551    ///
552    /// Removes all stored values, leaving the state empty.
553    ///
554    /// # Examples
555    ///
556    /// ```
557    /// use wsforge::prelude::*;
558    /// use std::sync::Arc;
559    ///
560    /// # fn example() {
561    /// let state = AppState::new();
562    /// state.insert(Arc::new(String::from("Hello")));
563    /// state.insert(Arc::new(42_u32));
564    ///
565    /// assert_eq!(state.len(), 2);
566    ///
567    /// state.clear();
568    ///
569    /// assert_eq!(state.len(), 0);
570    /// assert!(state.is_empty());
571    /// # }
572    /// ```
573    pub fn clear(&self) {
574        self.data.clear();
575    }
576}
577
578impl Default for AppState {
579    fn default() -> Self {
580        Self::new()
581    }
582}
583
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_insert_and_get() {
        let state = AppState::new();
        state.insert(Arc::new(String::from("test")));

        let value: Option<Arc<String>> = state.get();
        assert_eq!(*value.unwrap(), "test");
    }

    #[test]
    fn test_multiple_types() {
        let state = AppState::new();
        state.insert(Arc::new(42_u32));
        state.insert(Arc::new(String::from("hello")));
        state.insert(Arc::new(true));

        assert_eq!(*state.get::<u32>().unwrap(), 42);
        assert_eq!(*state.get::<String>().unwrap(), "hello");
        assert!(*state.get::<bool>().unwrap());
    }

    #[test]
    fn test_get_nonexistent() {
        let state = AppState::new();
        let value: Option<Arc<String>> = state.get();
        assert!(value.is_none());
    }

    #[test]
    fn test_contains() {
        let state = AppState::new();
        assert!(!state.contains::<u32>());

        state.insert(Arc::new(42_u32));
        assert!(state.contains::<u32>());
    }

    #[test]
    fn test_remove() {
        let state = AppState::new();
        state.insert(Arc::new(42_u32));

        assert!(state.contains::<u32>());

        let removed = state.remove::<u32>();
        assert_eq!(*removed.unwrap(), 42);
        assert!(!state.contains::<u32>());
    }

    #[test]
    fn test_remove_nonexistent() {
        let state = AppState::new();
        assert!(state.remove::<u32>().is_none());
        assert!(state.is_empty());
    }

    #[test]
    fn test_len_and_empty() {
        let state = AppState::new();
        assert!(state.is_empty());
        assert_eq!(state.len(), 0);

        state.insert(Arc::new(42_u32));
        assert!(!state.is_empty());
        assert_eq!(state.len(), 1);

        state.insert(Arc::new(String::from("test")));
        assert_eq!(state.len(), 2);
    }

    #[test]
    fn test_clear() {
        let state = AppState::new();
        state.insert(Arc::new(42_u32));
        state.insert(Arc::new(String::from("test")));

        assert_eq!(state.len(), 2);

        state.clear();

        assert_eq!(state.len(), 0);
        assert!(state.is_empty());
    }

    #[test]
    fn test_handed_out_arc_survives_clear() {
        let state = AppState::new();
        state.insert(Arc::new(String::from("kept")));
        let held = state.get::<String>().unwrap();

        state.clear();

        assert!(state.get::<String>().is_none());
        assert_eq!(*held, "kept");
    }

    #[test]
    fn test_replace_value() {
        let state = AppState::new();
        state.insert(Arc::new(10_u32));
        assert_eq!(*state.get::<u32>().unwrap(), 10);

        state.insert(Arc::new(20_u32));
        assert_eq!(*state.get::<u32>().unwrap(), 20);
    }

    #[test]
    fn test_types_with_same_layout_do_not_collide() {
        let state = AppState::new();
        state.insert(Arc::new(1_u32));
        state.insert(Arc::new(-2_i32));

        assert_eq!(state.len(), 2);
        assert_eq!(*state.get::<u32>().unwrap(), 1);
        assert_eq!(*state.get::<i32>().unwrap(), -2);
    }

    #[test]
    fn test_key_is_inner_type_not_arc() {
        let state = AppState::new();
        state.insert(Arc::new(7_u64));

        // The value is keyed by `u64`, not by `Arc<u64>`.
        assert!(state.get::<Arc<u64>>().is_none());
        assert_eq!(*state.get::<u64>().unwrap(), 7);
    }

    #[test]
    fn test_default_is_empty() {
        let state = AppState::default();
        assert!(state.is_empty());
        assert_eq!(state.len(), 0);
    }

    #[test]
    fn test_clone() {
        let state1 = AppState::new();
        state1.insert(Arc::new(42_u32));

        let state2 = state1.clone();
        assert_eq!(*state2.get::<u32>().unwrap(), 42);

        // Both share the same data
        state2.insert(Arc::new(100_u32));
        assert_eq!(*state1.get::<u32>().unwrap(), 100);
    }

    #[test]
    fn test_concurrent_access_from_clones() {
        let state = AppState::new();

        let workers: Vec<_> = (0..4_u64)
            .map(|i| {
                let state = state.clone();
                std::thread::spawn(move || {
                    state.insert(Arc::new(i));
                    state.get::<u64>().is_some()
                })
            })
            .collect();

        for worker in workers {
            assert!(worker.join().expect("worker thread panicked"));
        }

        // All workers wrote the same type, so exactly one entry remains.
        assert_eq!(state.len(), 1);
        assert!(*state.get::<u64>().unwrap() < 4);
    }
}