Skip to main content

smith_config/
http.rs

1//! HTTP server configuration
2
3use anyhow::Result;
4use serde::{Deserialize, Serialize};
5use std::time::Duration;
6
7/// HTTP server configuration
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct HttpConfig {
10    /// Server bind address
11    pub bind_address: String,
12
13    /// Server port
14    pub port: u16,
15
16    /// Smith service connection URL
17    pub smith_service_url: String,
18
19    /// JWT secret for authentication
20    pub jwt_secret: String,
21
22    /// Enable CORS
23    pub cors_enabled: bool,
24
25    /// Rate limiting configuration
26    pub rate_limit: RateLimitConfig,
27
28    /// WebSocket configuration
29    pub websocket: WebSocketConfig,
30
31    /// Security configuration
32    pub security: SecurityConfig,
33
34    /// Performance configuration
35    pub performance: PerformanceConfig,
36}
37
38/// Rate limiting configuration
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct RateLimitConfig {
41    /// Requests per minute per client
42    pub requests_per_minute: u32,
43
44    /// Burst allowance
45    pub burst_size: u32,
46
47    /// WebSocket messages per minute per connection
48    pub websocket_messages_per_minute: u32,
49
50    /// Maximum concurrent WebSocket connections per IP
51    pub max_connections_per_ip: u32,
52}
53
54/// WebSocket configuration
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct WebSocketConfig {
57    /// Maximum message size in bytes
58    pub max_message_size: usize,
59
60    /// Ping interval in seconds
61    #[serde(with = "duration_serde")]
62    pub ping_interval: Duration,
63
64    /// Connection timeout in seconds
65    #[serde(with = "duration_serde")]
66    pub connection_timeout: Duration,
67
68    /// Maximum concurrent connections
69    pub max_connections: usize,
70
71    /// Event buffer size per connection
72    pub event_buffer_size: usize,
73
74    /// Heartbeat interval in seconds
75    #[serde(with = "duration_serde")]
76    pub heartbeat_interval: Duration,
77}
78
79/// Security configuration
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct SecurityConfig {
82    /// JWT token expiration in seconds
83    #[serde(with = "duration_serde")]
84    pub jwt_expiration: Duration,
85
86    /// Require authentication for WebSocket connections
87    pub require_auth_websocket: bool,
88
89    /// Require authentication for API endpoints
90    pub require_auth_api: bool,
91
92    /// HTTPS only (for production)
93    pub https_only: bool,
94
95    /// Trusted proxy headers for rate limiting
96    pub trusted_proxies: Vec<String>,
97}
98
99/// Performance configuration
100#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct PerformanceConfig {
102    /// Event batching size for WebSocket
103    pub event_batch_size: usize,
104
105    /// Event batching timeout in milliseconds
106    #[serde(with = "duration_serde")]
107    pub event_batch_timeout: Duration,
108
109    /// Connection pool size for Smith service
110    pub smith_connection_pool_size: usize,
111
112    /// Request timeout in seconds
113    #[serde(with = "duration_serde")]
114    pub request_timeout: Duration,
115
116    /// Enable Gzip compression
117    pub enable_compression: bool,
118
119    /// Maximum request size in bytes
120    pub max_request_size: usize,
121}
122
123impl Default for HttpConfig {
124    fn default() -> Self {
125        Self {
126            bind_address: "127.0.0.1".to_string(),
127            port: 3000,
128            smith_service_url: "tcp://127.0.0.1:7878".to_string(),
129            jwt_secret: "dev-secret-change-in-production-secure-key".to_string(),
130            cors_enabled: false,
131            rate_limit: RateLimitConfig::default(),
132            websocket: WebSocketConfig::default(),
133            security: SecurityConfig::default(),
134            performance: PerformanceConfig::default(),
135        }
136    }
137}
138
139impl Default for RateLimitConfig {
140    fn default() -> Self {
141        Self {
142            requests_per_minute: 1000,
143            burst_size: 100,
144            websocket_messages_per_minute: 6000, // 100/second
145            max_connections_per_ip: 10,
146        }
147    }
148}
149
150impl Default for WebSocketConfig {
151    fn default() -> Self {
152        Self {
153            max_message_size: 64 * 1024, // 64KB
154            ping_interval: Duration::from_secs(30),
155            connection_timeout: Duration::from_secs(300), // 5 minutes
156            max_connections: 1000,
157            event_buffer_size: 1000,
158            heartbeat_interval: Duration::from_secs(10),
159        }
160    }
161}
162
163impl Default for SecurityConfig {
164    fn default() -> Self {
165        Self {
166            jwt_expiration: Duration::from_secs(24 * 60 * 60), // 24 hours
167            require_auth_websocket: false,                     // Development default
168            require_auth_api: false,                           // Development default
169            https_only: false,                                 // Development default
170            trusted_proxies: vec![],
171        }
172    }
173}
174
175impl Default for PerformanceConfig {
176    fn default() -> Self {
177        Self {
178            event_batch_size: 50,
179            event_batch_timeout: Duration::from_millis(10), // Sub-100ms requirement
180            smith_connection_pool_size: 10,
181            request_timeout: Duration::from_secs(30),
182            enable_compression: true,
183            max_request_size: 16 * 1024 * 1024, // 16MB
184        }
185    }
186}
187
188impl HttpConfig {
189    pub fn validate(&self) -> Result<()> {
190        // Validate bind address
191        if self.bind_address.is_empty() {
192            return Err(anyhow::anyhow!("Bind address cannot be empty"));
193        }
194
195        // Validate port range (note: u16 cannot exceed 65535, so just check minimum)
196        if self.port < 1024 {
197            return Err(anyhow::anyhow!(
198                "Port must be between 1024 and 65535, got: {}",
199                self.port
200            ));
201        }
202
203        // Validate Smith service URL
204        if self.smith_service_url.is_empty() {
205            return Err(anyhow::anyhow!("Smith service URL cannot be empty"));
206        }
207
208        // Warn about default JWT secret
209        if self.jwt_secret.contains("dev-secret-change-in-production") {
210            tracing::warn!("⚠️  Using default JWT secret - change this in production!");
211        }
212
213        if self.jwt_secret.len() < 32 {
214            return Err(anyhow::anyhow!("JWT secret must be at least 32 characters"));
215        }
216
217        // Validate sub-configurations
218        self.rate_limit.validate()?;
219        self.websocket.validate()?;
220        self.security.validate()?;
221        self.performance.validate()?;
222
223        Ok(())
224    }
225
226    pub fn development() -> Self {
227        Self {
228            bind_address: "127.0.0.1".to_string(),
229            port: 3000,
230            cors_enabled: true, // Allow CORS for development
231            security: SecurityConfig {
232                require_auth_websocket: false,
233                require_auth_api: false,
234                https_only: false,
235                ..Default::default()
236            },
237            performance: PerformanceConfig {
238                event_batch_timeout: Duration::from_millis(50), // Relaxed for development
239                ..Default::default()
240            },
241            ..Default::default()
242        }
243    }
244
245    pub fn production() -> Self {
246        Self {
247            bind_address: "0.0.0.0".to_string(),
248            port: 3000,
249            cors_enabled: false, // Strict CORS in production
250            security: SecurityConfig {
251                require_auth_websocket: true,
252                require_auth_api: true,
253                https_only: true,
254                jwt_expiration: Duration::from_secs(8 * 60 * 60), // 8 hours
255                trusted_proxies: vec!["127.0.0.1".to_string(), "::1".to_string()],
256            },
257            rate_limit: RateLimitConfig {
258                requests_per_minute: 2000,
259                burst_size: 200,
260                websocket_messages_per_minute: 12000, // 200/second
261                max_connections_per_ip: 20,
262            },
263            performance: PerformanceConfig {
264                event_batch_size: 100,                         // Larger batches
265                event_batch_timeout: Duration::from_millis(5), // Aggressive batching
266                smith_connection_pool_size: 20,
267                request_timeout: Duration::from_secs(15), // Shorter timeout
268                enable_compression: true,
269                max_request_size: 8 * 1024 * 1024, // 8MB (smaller for production)
270            },
271            ..Default::default()
272        }
273    }
274
275    pub fn testing() -> Self {
276        Self {
277            bind_address: "127.0.0.1".to_string(),
278            port: 0, // Let OS assign port for tests
279            cors_enabled: true,
280            rate_limit: RateLimitConfig {
281                requests_per_minute: 100, // Lower limits for tests
282                burst_size: 50,
283                websocket_messages_per_minute: 600,
284                max_connections_per_ip: 5,
285            },
286            websocket: WebSocketConfig {
287                max_connections: 10, // Few connections for tests
288                event_buffer_size: 100,
289                connection_timeout: Duration::from_secs(10), // Shorter timeout
290                ..Default::default()
291            },
292            performance: PerformanceConfig {
293                request_timeout: Duration::from_secs(5), // Shorter for tests
294                max_request_size: 1024 * 1024,           // 1MB for tests
295                ..Default::default()
296            },
297            ..Default::default()
298        }
299    }
300}
301
302impl RateLimitConfig {
303    pub fn validate(&self) -> Result<()> {
304        if self.requests_per_minute == 0 {
305            return Err(anyhow::anyhow!(
306                "Rate limit requests_per_minute must be > 0"
307            ));
308        }
309
310        if self.burst_size == 0 {
311            return Err(anyhow::anyhow!("Rate limit burst_size must be > 0"));
312        }
313
314        if self.websocket_messages_per_minute == 0 {
315            return Err(anyhow::anyhow!("WebSocket rate limit must be > 0"));
316        }
317
318        if self.max_connections_per_ip == 0 {
319            return Err(anyhow::anyhow!("Max connections per IP must be > 0"));
320        }
321
322        Ok(())
323    }
324}
325
326impl WebSocketConfig {
327    pub fn validate(&self) -> Result<()> {
328        if self.max_message_size < 1024 {
329            return Err(anyhow::anyhow!("WebSocket max message size must be >= 1KB"));
330        }
331
332        if self.max_message_size > 100 * 1024 * 1024 {
333            return Err(anyhow::anyhow!(
334                "WebSocket max message size must be <= 100MB"
335            ));
336        }
337
338        if self.ping_interval.as_secs() == 0 {
339            return Err(anyhow::anyhow!("WebSocket ping interval must be > 0"));
340        }
341
342        if self.connection_timeout.as_secs() == 0 {
343            return Err(anyhow::anyhow!("WebSocket connection timeout must be > 0"));
344        }
345
346        if self.max_connections == 0 {
347            return Err(anyhow::anyhow!("WebSocket max_connections must be > 0"));
348        }
349
350        if self.event_buffer_size == 0 {
351            return Err(anyhow::anyhow!("WebSocket event buffer size must be > 0"));
352        }
353
354        if self.heartbeat_interval.as_secs() == 0 {
355            return Err(anyhow::anyhow!("WebSocket heartbeat interval must be > 0"));
356        }
357
358        Ok(())
359    }
360}
361
362impl SecurityConfig {
363    pub fn validate(&self) -> Result<()> {
364        if self.jwt_expiration.as_secs() == 0 {
365            return Err(anyhow::anyhow!("JWT expiration must be > 0"));
366        }
367
368        if self.jwt_expiration.as_secs() > 7 * 24 * 60 * 60 {
369            tracing::warn!("JWT expiration > 7 days may be a security risk");
370        }
371
372        // Validate trusted proxy IPs
373        for proxy in &self.trusted_proxies {
374            if proxy.parse::<std::net::IpAddr>().is_err() && proxy != "localhost" {
375                return Err(anyhow::anyhow!("Invalid trusted proxy address: {}", proxy));
376            }
377        }
378
379        Ok(())
380    }
381}
382
383impl PerformanceConfig {
384    pub fn validate(&self) -> Result<()> {
385        if self.event_batch_size == 0 {
386            return Err(anyhow::anyhow!("Event batch size must be > 0"));
387        }
388
389        if self.event_batch_timeout.as_millis() == 0 {
390            return Err(anyhow::anyhow!("Event batch timeout must be > 0"));
391        }
392
393        if self.event_batch_timeout.as_millis() > 100 {
394            tracing::warn!(
395                "⚠️  Event batch timeout > 100ms may not meet sub-100ms latency requirement"
396            );
397        }
398
399        if self.smith_connection_pool_size == 0 {
400            return Err(anyhow::anyhow!("Smith connection pool size must be > 0"));
401        }
402
403        if self.request_timeout.as_secs() == 0 {
404            return Err(anyhow::anyhow!("Request timeout must be > 0"));
405        }
406
407        if self.max_request_size < 1024 {
408            return Err(anyhow::anyhow!("Max request size must be >= 1KB"));
409        }
410
411        Ok(())
412    }
413}
414
415// Helper module for Duration serialization
416mod duration_serde {
417    use serde::{Deserialize, Deserializer, Serializer};
418    use std::time::Duration;
419
420    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
421    where
422        S: Serializer,
423    {
424        serializer.serialize_u64(duration.as_millis() as u64)
425    }
426
427    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
428    where
429        D: Deserializer<'de>,
430    {
431        let millis = u64::deserialize(deserializer)?;
432        Ok(Duration::from_millis(millis))
433    }
434}