torch_web/production.rs

//! Production-ready features for high-scale applications

use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Semaphore;
use crate::{Request, Response, middleware::Middleware};

#[cfg(feature = "monitoring")]
use std::sync::atomic::{AtomicU64, Ordering};

// DashMap import removed as it's not currently used

/// Production server configuration for high-scale applications
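///
/// A construction sketch; the override values below are illustrative, not
/// tuning recommendations:
///
/// ```rust,ignore
/// let config = ProductionConfig {
///     max_connections: 5_000,
///     rate_limit_rps: Some(500),
///     ..ProductionConfig::default()
/// };
/// ```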
#[derive(Debug, Clone)]
pub struct ProductionConfig {
    /// Maximum number of concurrent connections
    pub max_connections: usize,
    /// Maximum duration allowed for handling a single request
    pub request_timeout: Duration,
    /// How long idle keep-alive connections are held open
    pub keep_alive_timeout: Duration,
    /// Maximum request body size in bytes
    pub max_body_size: usize,
    /// Enable request/response compression
    pub enable_compression: bool,
    /// Number of worker threads (`None` uses one per CPU core)
    pub worker_threads: Option<usize>,
    /// Enable HTTP/2
    pub enable_http2: bool,
    /// Rate limiting: requests per second per IP
    pub rate_limit_rps: Option<u32>,
    /// How long to wait for in-flight requests during graceful shutdown
    pub graceful_shutdown_timeout: Duration,
}

impl Default for ProductionConfig {
    fn default() -> Self {
        Self {
            max_connections: 10_000,
            request_timeout: Duration::from_secs(30),
            keep_alive_timeout: Duration::from_secs(60),
            max_body_size: 16 * 1024 * 1024, // 16MB
            enable_compression: true,
            worker_threads: None, // Use default (number of CPU cores)
            enable_http2: true,
            rate_limit_rps: Some(1000), // 1000 requests per second per IP
            graceful_shutdown_timeout: Duration::from_secs(30),
        }
    }
}

/// Shared connection-pool wrapper for handing a database pool to request handlers
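///
/// A usage sketch with a hypothetical `MyDbPool` placeholder type (any
/// `Send + Sync + 'static` type will do):
///
/// ```rust,ignore
/// struct MyDbPool; // stand-in for a real database pool type
///
/// let pool = ConnectionPool::new(MyDbPool);
/// let shared = pool.get_pool(); // Arc<MyDbPool>, cheap to clone into handlers
/// ```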
pub struct ConnectionPool<T> {
    pool: Arc<T>,
}

impl<T> ConnectionPool<T>
where
    T: Send + Sync + 'static,
{
    pub fn new(pool: T) -> Self {
        Self {
            pool: Arc::new(pool),
        }
    }

    pub fn get_pool(&self) -> Arc<T> {
        self.pool.clone()
    }
}

/// Rate limiting middleware.
///
/// Note: permits are released when a request finishes, so this caps the number
/// of concurrent in-flight requests at `requests_per_second` rather than
/// enforcing a true per-second, per-IP rate.
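///
/// Construction sketch (the limit is illustrative); requests that arrive while
/// all permits are taken receive `429 Too Many Requests`:
///
/// ```rust,ignore
/// let limiter = RateLimiter::new(1_000); // at most 1,000 requests in flight
/// ```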
pub struct RateLimiter {
    semaphore: Arc<Semaphore>,
    #[allow(dead_code)]
    requests_per_second: u32,
}

impl RateLimiter {
    pub fn new(requests_per_second: u32) -> Self {
        Self {
            semaphore: Arc::new(Semaphore::new(requests_per_second as usize)),
            requests_per_second,
        }
    }
}

impl Middleware for RateLimiter {
    fn call(
        &self,
        req: Request,
        next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
        let semaphore = self.semaphore.clone();
        Box::pin(async move {
            // Try to acquire a permit
            let _permit = match semaphore.try_acquire() {
                Ok(permit) => permit,
                Err(_) => {
                    return Response::with_status(http::StatusCode::TOO_MANY_REQUESTS)
                        .body("Rate limit exceeded");
                }
            };

            // Process the request
            next(req).await
        })
    }
}

/// Request timeout middleware
pub struct RequestTimeout {
    timeout: Duration,
}

impl RequestTimeout {
    pub fn new(timeout: Duration) -> Self {
        Self { timeout }
    }
}

impl Middleware for RequestTimeout {
    fn call(
        &self,
        req: Request,
        next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
        let timeout = self.timeout;
        Box::pin(async move {
            match tokio::time::timeout(timeout, next(req)).await {
                Ok(response) => response,
                Err(_) => Response::with_status(http::StatusCode::REQUEST_TIMEOUT)
                    .body("Request timeout"),
            }
        })
    }
}

/// Request size limiting middleware
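///
/// Construction sketch (limits shown are illustrative); often paired with
/// `RequestTimeout` so both body size and handler time are bounded:
///
/// ```rust,ignore
/// let size_limit = RequestSizeLimit::new(16 * 1024 * 1024);   // 16 MB body cap
/// let timeout = RequestTimeout::new(Duration::from_secs(30)); // 30 s handler cap
/// ```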
pub struct RequestSizeLimit {
    max_size: usize,
}

impl RequestSizeLimit {
    pub fn new(max_size: usize) -> Self {
        Self { max_size }
    }
}

impl Middleware for RequestSizeLimit {
    fn call(
        &self,
        req: Request,
        next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
        let max_size = self.max_size;
        Box::pin(async move {
            if req.body().len() > max_size {
                return Response::with_status(http::StatusCode::PAYLOAD_TOO_LARGE)
                    .body("Request body too large");
            }
            next(req).await
        })
    }
}

/// Performance monitoring middleware
pub struct PerformanceMonitor;

impl Middleware for PerformanceMonitor {
    fn call(
        &self,
        req: Request,
        next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
        Box::pin(async move {
            let start = Instant::now();
            let method = req.method().clone();
            let path = req.path().to_string();

            let response = next(req).await;

            let duration = start.elapsed();
            let status = response.status_code();

            // Log performance metrics (in production, you'd send this to a monitoring system)
            if duration > Duration::from_secs(1) {
                eprintln!(
                    "SLOW REQUEST: {} {} - {} ({:.2}ms)",
                    method,
                    path,
                    status,
                    duration.as_secs_f64() * 1000.0
                );
            }

            response
        })
    }
}

/// Health check middleware: answers requests to `/health` with a small JSON
/// status payload and passes every other request through to the next handler.
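///
/// A minimal sketch; the response shown is the shape this function produces
/// (the timestamp is only populated when the `monitoring` feature is enabled):
///
/// ```rust,ignore
/// let middleware = health_check();
/// // GET /health  ->  200 OK  {"status":"healthy","timestamp":"...","uptime":"unknown"}
/// ```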
pub fn health_check() -> impl Middleware {
    |req: Request, next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>| {
        Box::pin(async move {
            if req.path() == "/health" {
                return Response::ok()
                    .json(&{
                        #[cfg(feature = "monitoring")]
                        {
                            serde_json::json!({
                                "status": "healthy",
                                "timestamp": chrono::Utc::now().to_rfc3339(),
                                "uptime": "unknown"
                            })
                        }
                        #[cfg(not(feature = "monitoring"))]
                        {
                            serde_json::json!({
                                "status": "healthy",
                                "timestamp": "unknown",
                                "uptime": "unknown"
                            })
                        }
                    })
                    .unwrap_or_else(|_| Response::ok().body("healthy"));
            }
            next(req).await
        })
    }
}

/// Metrics collection middleware with atomic request counters behind the
/// `monitoring` feature; metrics are currently logged rather than exported.
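///
/// A usage sketch; the getters are only available with the `monitoring` feature:
///
/// ```rust,ignore
/// let metrics = MetricsCollector::new();
/// // ... after traffic has flowed through the middleware ...
/// println!("total requests: {}", metrics.get_request_count());
/// println!("in flight:      {}", metrics.get_active_requests());
/// ```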
pub struct MetricsCollector {
    #[cfg(feature = "monitoring")]
    request_counter: Arc<AtomicU64>,
    #[cfg(feature = "monitoring")]
    active_requests: Arc<AtomicU64>,
    #[cfg(not(feature = "monitoring"))]
    _phantom: std::marker::PhantomData<()>,
}

impl MetricsCollector {
    pub fn new() -> Self {
        Self {
            #[cfg(feature = "monitoring")]
            request_counter: Arc::new(AtomicU64::new(0)),
            #[cfg(feature = "monitoring")]
            active_requests: Arc::new(AtomicU64::new(0)),
            #[cfg(not(feature = "monitoring"))]
            _phantom: std::marker::PhantomData,
        }
    }

    #[cfg(feature = "monitoring")]
    pub fn get_request_count(&self) -> u64 {
        self.request_counter.load(Ordering::Relaxed)
    }

    #[cfg(feature = "monitoring")]
    pub fn get_active_requests(&self) -> u64 {
        self.active_requests.load(Ordering::Relaxed)
    }
}

impl Middleware for MetricsCollector {
    fn call(
        &self,
        req: Request,
        next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
        #[cfg(feature = "monitoring")]
        let request_counter = self.request_counter.clone();
        #[cfg(feature = "monitoring")]
        let active_requests = self.active_requests.clone();
        Box::pin(async move {
            let start = Instant::now();
            let method = req.method().clone();
            let path = req.path().to_string();

            // Track totals and in-flight requests when monitoring is enabled
            #[cfg(feature = "monitoring")]
            {
                request_counter.fetch_add(1, Ordering::Relaxed);
                active_requests.fetch_add(1, Ordering::Relaxed);
            }

            let response = next(req).await;

            let duration = start.elapsed();
            let status = response.status_code();

            #[cfg(feature = "monitoring")]
            active_requests.fetch_sub(1, Ordering::Relaxed);

            // In production, send metrics to your monitoring system;
            // for now, just log them
            println!(
                "METRIC: method={} path={} status={} duration_ms={:.2}",
                method,
                path,
                status.as_u16(),
                duration.as_secs_f64() * 1000.0
            );

            response
        })
    }
}

// Tests are gated off with a placeholder cfg; switch to `#[cfg(test)]` to run them.
#[cfg(disabled_for_now)]
mod tests {
    use super::*;
    use std::pin::Pin;
    use std::future::Future;
    use crate::Response;

    #[tokio::test]
    async fn test_rate_limiter() {
        let rate_limiter = RateLimiter::new(1);

        let next = Box::new(|_req: Request| -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
            Box::pin(async { Response::ok().body("success") })
        });

        let req = crate::Request::from_hyper(
            http::Request::builder()
                .method("GET")
                .uri("/")
                .body(())
                .unwrap()
                .into_parts()
                .0,
            Vec::new(),
        )
        .await
        .unwrap();

        // First request should succeed because a permit is available
        let response = rate_limiter.call(req, next).await;
        assert_eq!(response.status_code(), http::StatusCode::OK);
    }

    #[tokio::test]
    async fn test_request_timeout() {
        let timeout_middleware = RequestTimeout::new(Duration::from_millis(100));

        let next = Box::new(|_req: Request| -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
            Box::pin(async {
                tokio::time::sleep(Duration::from_millis(200)).await;
                Response::ok().body("too slow")
            })
        });

        let req = crate::Request::from_hyper(
            http::Request::builder()
                .method("GET")
                .uri("/")
                .body(())
                .unwrap()
                .into_parts()
                .0,
            Vec::new(),
        )
        .await
        .unwrap();

        let response = timeout_middleware.call(req, next).await;
        assert_eq!(response.status_code(), http::StatusCode::REQUEST_TIMEOUT);
    }
}