web_server_abstraction/
state.rs1use std::any::{Any, TypeId};
4use std::collections::HashMap;
5use std::sync::{Arc, RwLock};
6
/// Thread-safe, type-indexed container for application-wide values.
///
/// At most one value is stored per concrete type; the value's [`TypeId`] is
/// the map key. Cloning an `AppState` yields another handle to the same
/// underlying map, because the map lives behind an [`Arc`].
///
/// Lock poisoning is deliberately ignored by every method. The write critical
/// sections below consist of a single `HashMap` operation and run no
/// caller-supplied code, so the map cannot be observed half-updated.
#[derive(Clone, Debug)]
pub struct AppState {
    data: Arc<RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>>,
}

impl AppState {
    /// Creates an empty state container.
    pub fn new() -> Self {
        Self {
            data: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Stores `value`, replacing any previously stored value of type `T`.
    ///
    /// The replaced value (if any) is dropped only after the write lock has
    /// been released, so a `Drop` implementation may safely call back into
    /// this `AppState`.
    pub fn insert<T: Send + Sync + 'static>(&self, value: T) {
        // Allocate before locking to keep the critical section minimal.
        let boxed: Box<dyn Any + Send + Sync> = Box::new(value);
        let displaced = {
            let mut map = self
                .data
                .write()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            map.insert(TypeId::of::<T>(), boxed)
        };
        // The guard is gone at this point; running the old value's destructor
        // here cannot deadlock on, or poison, the lock.
        drop(displaced);
    }

    /// Returns a clone of the stored value of type `T`, or `None` if no value
    /// of that type is present.
    ///
    /// `T::clone` runs while the read lock is held.
    pub fn get<T: Send + Sync + Clone + 'static>(&self) -> Option<T> {
        let map = self
            .data
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        map.get(&TypeId::of::<T>())
            .and_then(|entry| entry.downcast_ref::<T>())
            .cloned()
    }

    /// Returns `true` if a value of type `T` is currently stored.
    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
        self.data
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .contains_key(&TypeId::of::<T>())
    }

    /// Removes and returns the stored value of type `T`, or `None` if no value
    /// of that type is present.
    pub fn remove<T: Send + Sync + 'static>(&self) -> Option<T> {
        // The temporary write guard is released at the end of this statement,
        // before the downcast below.
        let entry = self
            .data
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .remove(&TypeId::of::<T>())?;
        entry.downcast::<T>().ok().map(|boxed| *boxed)
    }
}

impl Default for AppState {
    /// Equivalent to [`AppState::new`].
    fn default() -> Self {
        Self::new()
    }
}
55
/// Handle to a single value guarded by a read-write lock.
///
/// All clones of a `SharedState` refer to the same value, because the value
/// lives behind an [`Arc`].
#[derive(Debug)]
pub struct SharedState<T> {
    inner: Arc<RwLock<T>>,
}

// Written by hand on purpose: `#[derive(Clone)]` would add a `T: Clone` bound,
// although cloning the handle only clones the `Arc` and never touches `T`.
impl<T> Clone for SharedState<T> {
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }
}

impl<T> SharedState<T> {
    /// Wraps `value` in a new shared, lock-protected cell.
    pub fn new(value: T) -> Self {
        Self {
            inner: Arc::new(RwLock::new(value)),
        }
    }

    /// Acquires the lock for shared read access and returns the guard.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned, i.e. a writer panicked while holding
    /// the write guard.
    pub fn read(&self) -> std::sync::RwLockReadGuard<'_, T> {
        self.inner.read().expect("SharedState lock poisoned")
    }

    /// Acquires the lock for exclusive write access and returns the guard.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned, i.e. a writer panicked while holding
    /// the write guard.
    pub fn write(&self) -> std::sync::RwLockWriteGuard<'_, T> {
        self.inner.write().expect("SharedState lock poisoned")
    }

    /// Returns another reference-counted pointer to the underlying lock.
    pub fn inner(&self) -> Arc<RwLock<T>> {
        Arc::clone(&self.inner)
    }
}
85
/// Runtime configuration for the web server.
///
/// Build it with [`Config::default`] for the built-in values or with
/// [`Config::from_env`] to overlay values taken from environment variables.
#[derive(Debug, Clone)]
pub struct Config {
    /// Host part of the bind address (see [`Config::bind_address`]).
    pub host: String,
    /// Port part of the bind address.
    pub port: u16,
    /// Deployment environment.
    pub environment: Environment,
    /// Value of `DATABASE_URL`, if set.
    pub database_url: Option<String>,
    /// Value of `REDIS_URL`, if set.
    pub redis_url: Option<String>,
    /// Value of `SECRET_KEY`, if set.
    pub secret_key: Option<String>,
    /// Origins listed in `CORS_ORIGINS`.
    pub cors_origins: Vec<String>,
    /// Maximum request size (presumably in bytes; confirm against the code
    /// that enforces it).
    pub max_request_size: usize,
    /// Request timeout; `REQUEST_TIMEOUT` supplies it in whole seconds.
    pub request_timeout: std::time::Duration,
}

/// Deployment environment selected through the `ENVIRONMENT` variable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Environment {
    Development,
    Staging,
    Production,
}

impl Default for Config {
    /// Returns the built-in configuration: `127.0.0.1:8080`, development
    /// environment, no database/redis/secret, a single
    /// `http://localhost:3000` CORS origin, a request size limit of
    /// `10 * 1024 * 1024` and a 30 second timeout.
    fn default() -> Self {
        Self {
            host: "127.0.0.1".to_string(),
            port: 8080,
            environment: Environment::Development,
            database_url: None,
            redis_url: None,
            secret_key: None,
            cors_origins: vec!["http://localhost:3000".to_string()],
            max_request_size: 10 * 1024 * 1024,
            request_timeout: std::time::Duration::from_secs(30),
        }
    }
}

impl Config {
    /// Builds a configuration from the process environment, starting from
    /// [`Config::default`].
    ///
    /// Recognised variables:
    ///
    /// * `HOST` – replaces the host verbatim.
    /// * `PORT` – replaces the port when it parses as `u16`; ignored otherwise.
    /// * `ENVIRONMENT` – `production`/`prod` and `staging`/`stage`
    ///   (case-insensitive); any other value selects development.
    /// * `DATABASE_URL`, `REDIS_URL`, `SECRET_KEY` – copied when set.
    /// * `CORS_ORIGINS` – comma-separated list; entries are trimmed and empty
    ///   entries are skipped.
    /// * `REQUEST_TIMEOUT` – whole seconds; ignored when it does not parse as
    ///   `u64`.
    ///
    /// Surrounding whitespace is ignored for `PORT`, `ENVIRONMENT` and
    /// `REQUEST_TIMEOUT`.
    pub fn from_env() -> Self {
        let mut config = Self::default();

        if let Ok(host) = std::env::var("HOST") {
            config.host = host;
        }

        // An unparseable port is ignored on purpose: the default stays in place.
        if let Some(port) = std::env::var("PORT")
            .ok()
            .and_then(|raw| raw.trim().parse::<u16>().ok())
        {
            config.port = port;
        }

        if let Ok(raw) = std::env::var("ENVIRONMENT") {
            config.environment = match raw.trim().to_lowercase().as_str() {
                "production" | "prod" => Environment::Production,
                "staging" | "stage" => Environment::Staging,
                _ => Environment::Development,
            };
        }

        config.database_url = std::env::var("DATABASE_URL").ok();
        config.redis_url = std::env::var("REDIS_URL").ok();
        config.secret_key = std::env::var("SECRET_KEY").ok();

        if let Ok(origins) = std::env::var("CORS_ORIGINS") {
            // Skip empty entries so that `""` or a trailing comma never turns
            // into an empty-string origin.
            config.cors_origins = origins
                .split(',')
                .map(str::trim)
                .filter(|origin| !origin.is_empty())
                .map(String::from)
                .collect();
        }

        if let Some(seconds) = std::env::var("REQUEST_TIMEOUT")
            .ok()
            .and_then(|raw| raw.trim().parse::<u64>().ok())
        {
            config.request_timeout = std::time::Duration::from_secs(seconds);
        }

        config
    }

    /// Returns `true` when the environment is [`Environment::Production`].
    pub fn is_production(&self) -> bool {
        self.environment == Environment::Production
    }

    /// Returns `true` when the environment is [`Environment::Development`].
    pub fn is_development(&self) -> bool {
        self.environment == Environment::Development
    }

    /// Returns the address to bind to in `host:port` form.
    ///
    /// A host that is a bare IPv6 literal (for example `::1`) is wrapped in
    /// brackets, giving `[::1]:8080`; without them the colons of the address
    /// could not be told apart from the port separator.
    pub fn bind_address(&self) -> String {
        if self.host.parse::<std::net::Ipv6Addr>().is_ok() {
            format!("[{}]:{}", self.host, self.port)
        } else {
            format!("{}:{}", self.host, self.port)
        }
    }
}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// A value stored through the write guard is visible to later reads.
    #[test]
    fn test_shared_state() {
        let shared = SharedState::new(42);

        {
            let current = shared.read();
            assert_eq!(*current, 42);
        }

        {
            // The read guard above is already released, so this cannot block.
            let mut slot = shared.write();
            *slot = 100;
        }

        let updated = *shared.read();
        assert_eq!(updated, 100);
    }

    /// The built-in configuration targets local development.
    #[test]
    fn test_config_default() {
        let defaults = Config::default();

        assert_eq!(defaults.host.as_str(), "127.0.0.1");
        assert_eq!(defaults.port, 8080);
        assert_eq!(defaults.environment, Environment::Development);

        let (is_dev, is_prod) = (defaults.is_development(), defaults.is_production());
        assert!(is_dev);
        assert!(!is_prod);
    }

    /// Host and port are joined with a colon.
    #[test]
    fn test_config_bind_address() {
        let custom = Config {
            port: 3000,
            host: String::from("0.0.0.0"),
            ..Config::default()
        };

        let address = custom.bind_address();
        assert_eq!(address, "0.0.0.0:3000");
    }
}
214}