web_server_abstraction/
state.rs

1//! Application state management for sharing data across requests.
2
3use std::any::{Any, TypeId};
4use std::collections::HashMap;
5use std::sync::{Arc, RwLock};
6
/// Type-keyed storage shared by every clone of an [`AppState`].
type TypeMap = HashMap<TypeId, Box<dyn Any + Send + Sync>>;

/// Application state container holding at most one value per concrete type.
///
/// Values are keyed by their [`TypeId`], so inserting a second value of the same
/// type replaces the first. Cloning an `AppState` only clones the inner [`Arc`]:
/// every clone reads and writes the same underlying storage.
#[derive(Clone, Debug, Default)]
pub struct AppState {
    data: Arc<RwLock<TypeMap>>,
}

impl AppState {
    /// Create a new, empty state container.
    pub fn new() -> Self {
        Self::default()
    }

    /// Acquire the read lock, recovering the guard if the lock is poisoned.
    ///
    /// Every mutation of the map is a single `HashMap` call, so a panic in some
    /// other lock holder cannot leave the map half-updated. Recovering keeps one
    /// panic from turning into a panic on every later access.
    fn read_map(&self) -> std::sync::RwLockReadGuard<'_, TypeMap> {
        self.data
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Acquire the write lock, recovering the guard if the lock is poisoned
    /// (same reasoning as [`AppState::read_map`]).
    fn write_map(&self) -> std::sync::RwLockWriteGuard<'_, TypeMap> {
        self.data
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Insert `value`, replacing any previously stored value of type `T`.
    pub fn insert<T: Send + Sync + 'static>(&self, value: T) {
        // The write guard is a temporary of this statement, so it is released
        // before `previous` is dropped. A replaced value's `Drop` impl therefore
        // runs without the lock held and may itself access this `AppState`.
        let previous = self.write_map().insert(TypeId::of::<T>(), Box::new(value));
        drop(previous);
    }

    /// Return a clone of the stored value of type `T`, or `None` if absent.
    ///
    /// `T::clone` runs while the read lock is held.
    pub fn get<T: Send + Sync + Clone + 'static>(&self) -> Option<T> {
        self.read_map()
            .get(&TypeId::of::<T>())
            .and_then(|value| value.downcast_ref::<T>())
            .cloned()
    }

    /// Check whether a value of type `T` is currently stored.
    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
        self.read_map().contains_key(&TypeId::of::<T>())
    }

    /// Remove and return the stored value of type `T`, or `None` if absent.
    pub fn remove<T: Send + Sync + 'static>(&self) -> Option<T> {
        // Bind first so the write guard is released before the value is
        // handed back to (and eventually dropped by) the caller.
        let removed = self.write_map().remove(&TypeId::of::<T>())?;
        // Entries are keyed by `TypeId::of::<T>()`, so this downcast succeeds.
        removed.downcast::<T>().ok().map(|boxed| *boxed)
    }
}
55
/// Shared, mutable state of a single type `T`, backed by `Arc<RwLock<T>>`.
///
/// Cloning a `SharedState` is cheap: it bumps the `Arc` reference count, and the
/// clone refers to the same value. A write through one handle is visible
/// through every other handle.
#[derive(Debug)]
pub struct SharedState<T> {
    inner: Arc<RwLock<T>>,
}

// Implemented by hand instead of derived: `#[derive(Clone)]` would add a
// `T: Clone` bound, even though cloning only copies the `Arc` handle.
impl<T> Clone for SharedState<T> {
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }
}

impl<T> SharedState<T> {
    /// Create new shared state wrapping `value`.
    pub fn new(value: T) -> Self {
        Self {
            inner: Arc::new(RwLock::new(value)),
        }
    }

    /// Acquire a read lock on the value, blocking until it is available.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned, i.e. a writer panicked while holding the
    /// write guard. Per the `RwLock` docs, it may also panic if the current
    /// thread already holds the lock.
    pub fn read(&self) -> std::sync::RwLockReadGuard<'_, T> {
        self.inner.read().unwrap()
    }

    /// Acquire a write lock on the value, blocking until it is available.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned, i.e. a writer panicked while holding the
    /// write guard. Per the `RwLock` docs, it may also panic if the current
    /// thread already holds the lock.
    pub fn write(&self) -> std::sync::RwLockWriteGuard<'_, T> {
        self.inner.write().unwrap()
    }

    /// Return a new handle to the underlying `Arc<RwLock<T>>` for sharing.
    pub fn inner(&self) -> Arc<RwLock<T>> {
        Arc::clone(&self.inner)
    }
}

impl<T: Default> Default for SharedState<T> {
    /// Create shared state wrapping `T::default()`.
    fn default() -> Self {
        Self::new(T::default())
    }
}
85
/// Configuration for the application.
///
/// Build it with [`Config::default`] for local defaults, or with
/// [`Config::from_env`] to override those defaults from environment variables.
#[derive(Clone)]
pub struct Config {
    /// Host part of the server bind address (default `127.0.0.1`, env `HOST`).
    pub host: String,
    /// Port part of the server bind address (default `8080`, env `PORT`).
    pub port: u16,
    /// Deployment environment (default `Development`, env `ENVIRONMENT`).
    pub environment: Environment,
    /// Database URL, if configured (env `DATABASE_URL`).
    pub database_url: Option<String>,
    /// Redis URL, if configured (env `REDIS_URL`).
    pub redis_url: Option<String>,
    /// Secret key, if configured (env `SECRET_KEY`). Redacted in `Debug` output.
    pub secret_key: Option<String>,
    /// CORS origins (default `http://localhost:3000`, env `CORS_ORIGINS`,
    /// comma separated).
    pub cors_origins: Vec<String>,
    /// Maximum request size in bytes (default 10MB, env `MAX_REQUEST_SIZE`).
    pub max_request_size: usize,
    /// Request timeout (default 30s, env `REQUEST_TIMEOUT` in whole seconds).
    pub request_timeout: std::time::Duration,
}

impl std::fmt::Debug for Config {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Config")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("environment", &self.environment)
            .field("database_url", &self.database_url)
            .field("redis_url", &self.redis_url)
            // Never print the secret itself, only whether one is configured.
            .field("secret_key", &self.secret_key.as_ref().map(|_| "<redacted>"))
            .field("cors_origins", &self.cors_origins)
            .field("max_request_size", &self.max_request_size)
            .field("request_timeout", &self.request_timeout)
            .finish()
    }
}

/// Deployment environment the application is configured for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Environment {
    Development,
    Staging,
    Production,
}

impl Environment {
    /// Parse an `ENVIRONMENT` value case-insensitively, ignoring surrounding
    /// whitespace. Unrecognized values fall back to `Development`.
    fn from_env_value(value: &str) -> Self {
        match value.trim().to_lowercase().as_str() {
            "production" | "prod" => Environment::Production,
            "staging" | "stage" => Environment::Staging,
            _ => Environment::Development,
        }
    }
}

impl Default for Config {
    fn default() -> Self {
        Self {
            host: "127.0.0.1".to_string(),
            port: 8080,
            environment: Environment::Development,
            database_url: None,
            redis_url: None,
            secret_key: None,
            cors_origins: vec!["http://localhost:3000".to_string()],
            max_request_size: 1024 * 1024 * 10, // 10MB
            request_timeout: std::time::Duration::from_secs(30),
        }
    }
}

/// Read `key` from the environment and parse its trimmed value.
///
/// Returns `None` if the variable is unset, is not valid Unicode, or does not
/// parse as `V`.
fn env_parsed<V: std::str::FromStr>(key: &str) -> Option<V> {
    std::env::var(key).ok()?.trim().parse().ok()
}

impl Config {
    /// Load configuration from environment variables, starting from
    /// [`Config::default`].
    ///
    /// Reads `HOST`, `PORT`, `ENVIRONMENT`, `DATABASE_URL`, `REDIS_URL`,
    /// `SECRET_KEY`, `CORS_ORIGINS`, `MAX_REQUEST_SIZE` and `REQUEST_TIMEOUT`
    /// (seconds). Numeric values that are unset or fail to parse keep their
    /// default. Loading is best-effort and never panics.
    pub fn from_env() -> Self {
        let mut config = Self::default();

        if let Ok(host) = std::env::var("HOST") {
            config.host = host;
        }
        if let Some(port) = env_parsed("PORT") {
            config.port = port;
        }
        if let Ok(name) = std::env::var("ENVIRONMENT") {
            config.environment = Environment::from_env_value(&name);
        }

        config.database_url = std::env::var("DATABASE_URL").ok();
        config.redis_url = std::env::var("REDIS_URL").ok();
        config.secret_key = std::env::var("SECRET_KEY").ok();

        if let Ok(origins) = std::env::var("CORS_ORIGINS") {
            // Skip blank entries so stray or trailing commas (or an empty value)
            // do not produce "" origins.
            config.cors_origins = origins
                .split(',')
                .map(str::trim)
                .filter(|origin| !origin.is_empty())
                .map(String::from)
                .collect();
        }

        if let Some(bytes) = env_parsed("MAX_REQUEST_SIZE") {
            config.max_request_size = bytes;
        }
        if let Some(secs) = env_parsed::<u64>("REQUEST_TIMEOUT") {
            config.request_timeout = std::time::Duration::from_secs(secs);
        }

        config
    }

    /// Check if running in production.
    pub fn is_production(&self) -> bool {
        self.environment == Environment::Production
    }

    /// Check if running in development.
    pub fn is_development(&self) -> bool {
        self.environment == Environment::Development
    }

    /// Get the server bind address as `host:port`.
    ///
    /// An IPv6 host such as `::1` is wrapped in brackets (`[::1]:8080`) so the
    /// port separator stays unambiguous.
    pub fn bind_address(&self) -> String {
        if self.host.contains(':') && !self.host.starts_with('[') {
            format!("[{}]:{}", self.host, self.port)
        } else {
            format!("{}:{}", self.host, self.port)
        }
    }
}
178
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shared_state() {
        let state = SharedState::new(42);

        // Test read
        assert_eq!(*state.read(), 42);

        // Test write
        *state.write() = 100;
        assert_eq!(*state.read(), 100);
    }

    #[test]
    fn test_shared_state_clone_shares_value() {
        let state = SharedState::new(vec![1]);
        let handle = state.clone();

        // A write through one handle is visible through the other.
        handle.write().push(2);
        assert_eq!(*state.read(), vec![1, 2]);
    }

    #[test]
    fn test_app_state_roundtrip() {
        let state = AppState::new();
        assert!(!state.contains::<u32>());
        assert_eq!(state.get::<u32>(), None);

        state.insert(7u32);
        state.insert(String::from("hello"));
        assert!(state.contains::<u32>());
        assert_eq!(state.get::<u32>(), Some(7));
        assert_eq!(state.get::<String>().as_deref(), Some("hello"));

        // Inserting again replaces the previous value of the same type.
        state.insert(9u32);
        assert_eq!(state.get::<u32>(), Some(9));

        assert_eq!(state.remove::<u32>(), Some(9));
        assert!(!state.contains::<u32>());
        assert_eq!(state.remove::<u32>(), None);
    }

    #[test]
    fn test_app_state_clones_share_storage() {
        let state = AppState::default();
        let handle = state.clone();
        handle.insert(1i64);
        assert_eq!(state.get::<i64>(), Some(1));
    }

    #[test]
    fn test_config_default() {
        let config = Config::default();
        assert_eq!(config.host, "127.0.0.1");
        assert_eq!(config.port, 8080);
        assert_eq!(config.environment, Environment::Development);
        assert!(config.is_development());
        assert!(!config.is_production());
    }

    #[test]
    fn test_config_bind_address() {
        let config = Config {
            host: "0.0.0.0".to_string(),
            port: 3000,
            ..Default::default()
        };
        assert_eq!(config.bind_address(), "0.0.0.0:3000");
    }
}
214}