1use std::sync::Arc;
4use std::time::{Duration, Instant};
5use tokio::sync::Semaphore;
6use crate::{Request, Response, middleware::Middleware};
7
8#[cfg(feature = "monitoring")]
9use {
10 std::sync::atomic::{AtomicU64, Ordering},
11};
12
/// Tunables for running the server in production.
///
/// Start from [`ProductionConfig::default`] and override individual fields with
/// struct-update syntax, e.g.
/// `ProductionConfig { enable_http2: false, ..ProductionConfig::default() }`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProductionConfig {
    /// Maximum number of connections (default: `10_000`).
    pub max_connections: usize,
    /// Time budget for a single request (default: 30 s).
    pub request_timeout: Duration,
    /// Keep-alive timeout (default: 60 s).
    pub keep_alive_timeout: Duration,
    /// Maximum request body size, presumably in bytes (default: `16 * 1024 * 1024`,
    /// i.e. 16 MiB).
    pub max_body_size: usize,
    /// Whether response compression is enabled (default: `true`).
    pub enable_compression: bool,
    /// Number of worker threads (default: `None`).
    /// NOTE(review): `None` presumably defers to the runtime's default — confirm.
    pub worker_threads: Option<usize>,
    /// Whether HTTP/2 is enabled (default: `true`).
    pub enable_http2: bool,
    /// Request-rate limit in requests per second (default: `Some(1000)`).
    /// NOTE(review): `None` presumably disables rate limiting — confirm with the caller.
    pub rate_limit_rps: Option<u32>,
    /// How long a graceful shutdown may take (default: 30 s).
    pub graceful_shutdown_timeout: Duration,
}

impl Default for ProductionConfig {
    fn default() -> Self {
        Self {
            max_connections: 10_000,
            request_timeout: Duration::from_secs(30),
            keep_alive_timeout: Duration::from_secs(60),
            max_body_size: 16 * 1024 * 1024,
            enable_compression: true,
            worker_threads: None,
            enable_http2: true,
            rate_limit_rps: Some(1000),
            graceful_shutdown_timeout: Duration::from_secs(30),
        }
    }
}
53
/// Shared, cheaply cloneable handle to a pool (or any other shared resource).
///
/// The wrapped value lives behind an [`Arc`]. Cloning a `ConnectionPool` or calling
/// [`ConnectionPool::get_pool`] therefore only bumps a reference count and never copies
/// the pool itself.
pub struct ConnectionPool<T> {
    pool: Arc<T>,
}

// No trait bounds here: neither method needs `Send`/`Sync`/`'static`. Callers that share
// the pool across threads get those requirements from wherever they send it.
impl<T> ConnectionPool<T> {
    /// Wraps `pool` so it can be shared.
    pub fn new(pool: T) -> Self {
        Self {
            pool: Arc::new(pool),
        }
    }

    /// Returns a new reference-counted handle to the shared pool.
    pub fn get_pool(&self) -> Arc<T> {
        Arc::clone(&self.pool)
    }
}

// Written by hand rather than derived so that `T` itself does not need to be `Clone`.
impl<T> Clone for ConnectionPool<T> {
    fn clone(&self) -> Self {
        Self {
            pool: Arc::clone(&self.pool),
        }
    }
}
73
74pub struct RateLimiter {
76 semaphore: Arc<Semaphore>,
77 #[allow(dead_code)]
78 requests_per_second: u32,
79}
80
81impl RateLimiter {
82 pub fn new(requests_per_second: u32) -> Self {
83 Self {
84 semaphore: Arc::new(Semaphore::new(requests_per_second as usize)),
85 requests_per_second,
86 }
87 }
88}
89
90impl Middleware for RateLimiter {
91 fn call(
92 &self,
93 req: Request,
94 next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
95 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
96 let semaphore = self.semaphore.clone();
97 Box::pin(async move {
98 let _permit = match semaphore.try_acquire() {
100 Ok(permit) => permit,
101 Err(_) => {
102 return Response::with_status(http::StatusCode::TOO_MANY_REQUESTS)
103 .body("Rate limit exceeded");
104 }
105 };
106
107 next(req).await
109 })
110 }
111}
112
113pub struct RequestTimeout {
115 timeout: Duration,
116}
117
118impl RequestTimeout {
119 pub fn new(timeout: Duration) -> Self {
120 Self { timeout }
121 }
122}
123
124impl Middleware for RequestTimeout {
125 fn call(
126 &self,
127 req: Request,
128 next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
129 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
130 let timeout = self.timeout;
131 Box::pin(async move {
132 match tokio::time::timeout(timeout, next(req)).await {
133 Ok(response) => response,
134 Err(_) => Response::with_status(http::StatusCode::REQUEST_TIMEOUT)
135 .body("Request timeout"),
136 }
137 })
138 }
139}
140
141pub struct RequestSizeLimit {
143 max_size: usize,
144}
145
146impl RequestSizeLimit {
147 pub fn new(max_size: usize) -> Self {
148 Self { max_size }
149 }
150}
151
152impl Middleware for RequestSizeLimit {
153 fn call(
154 &self,
155 req: Request,
156 next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
157 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
158 let max_size = self.max_size;
159 Box::pin(async move {
160 if req.body().len() > max_size {
161 return Response::with_status(http::StatusCode::PAYLOAD_TOO_LARGE)
162 .body("Request body too large");
163 }
164 next(req).await
165 })
166 }
167}
168
169pub struct PerformanceMonitor;
171
172impl Middleware for PerformanceMonitor {
173 fn call(
174 &self,
175 req: Request,
176 next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
177 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
178 Box::pin(async move {
179 let start = Instant::now();
180 let method = req.method().clone();
181 let path = req.path().to_string();
182
183 let response = next(req).await;
184
185 let duration = start.elapsed();
186 let status = response.status_code();
187
188 if duration > Duration::from_millis(1000) {
190 eprintln!(
191 "SLOW REQUEST: {} {} - {} ({:.2}ms)",
192 method,
193 path,
194 status,
195 duration.as_secs_f64() * 1000.0
196 );
197 }
198
199 response
200 })
201 }
202}
203
204pub fn health_check() -> impl Middleware {
206 |req: Request, next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>| {
207 Box::pin(async move {
208 if req.path() == "/health" {
209 return Response::ok()
210 .json(&{
211 #[cfg(feature = "monitoring")]
212 {
213 serde_json::json!({
214 "status": "healthy",
215 "timestamp": chrono::Utc::now().to_rfc3339(),
216 "uptime": "unknown"
217 })
218 }
219 #[cfg(not(feature = "monitoring"))]
220 {
221 serde_json::json!({
222 "status": "healthy",
223 "timestamp": "unknown",
224 "uptime": "unknown"
225 })
226 }
227 })
228 .unwrap_or_else(|_| Response::ok().body("healthy"));
229 }
230 next(req).await
231 })
232 }
233}
234
235pub struct MetricsCollector {
237 #[cfg(feature = "monitoring")]
238 request_counter: Arc<AtomicU64>,
239 #[cfg(feature = "monitoring")]
240 active_requests: Arc<AtomicU64>,
241 #[cfg(not(feature = "monitoring"))]
242 _phantom: std::marker::PhantomData<()>,
243}
244
245impl MetricsCollector {
246 pub fn new() -> Self {
247 Self {
248 #[cfg(feature = "monitoring")]
249 request_counter: Arc::new(AtomicU64::new(0)),
250 #[cfg(feature = "monitoring")]
251 active_requests: Arc::new(AtomicU64::new(0)),
252 #[cfg(not(feature = "monitoring"))]
253 _phantom: std::marker::PhantomData,
254 }
255 }
256
257 #[cfg(feature = "monitoring")]
258 pub fn get_request_count(&self) -> u64 {
259 self.request_counter.load(Ordering::Relaxed)
260 }
261
262 #[cfg(feature = "monitoring")]
263 pub fn get_active_requests(&self) -> u64 {
264 self.active_requests.load(Ordering::Relaxed)
265 }
266}
267
268impl Middleware for MetricsCollector {
269 fn call(
270 &self,
271 req: Request,
272 next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
273 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
274 Box::pin(async move {
275 let start = Instant::now();
276 let method = req.method().clone();
277 let path = req.path().to_string();
278
279 let response = next(req).await;
280
281 let duration = start.elapsed();
282 let status = response.status_code();
283
284 println!(
287 "METRIC: method={} path={} status={} duration_ms={:.2}",
288 method,
289 path,
290 status.as_u16(),
291 duration.as_secs_f64() * 1000.0
292 );
293
294 response
295 })
296 }
297}
298
/// Middleware unit tests. They are currently switched off: this module is compiled only
/// when the crate is built with `--cfg disabled_for_now`.
#[cfg(disabled_for_now)]
mod tests {
    use super::*;
    use std::pin::Pin;
    use std::future::Future;
    use crate::Response;

    /// Future type returned by the stub `next` handlers below.
    type ResponseFuture = Pin<Box<dyn Future<Output = Response> + Send + 'static>>;

    /// Builds a body-less `GET /` request to feed into the middleware under test.
    async fn get_root_request() -> Request {
        let (parts, ()) = http::Request::builder()
            .method("GET")
            .uri("/")
            .body(())
            .unwrap()
            .into_parts();
        crate::Request::from_hyper(parts, Vec::new()).await.unwrap()
    }

    #[tokio::test]
    async fn test_rate_limiter() {
        let limiter = RateLimiter::new(1);
        let next = Box::new(|_req: Request| -> ResponseFuture {
            Box::pin(async { Response::ok().body("success") })
        });
        let req = get_root_request().await;

        let response = limiter.call(req, next.clone()).await;
        assert_eq!(response.status_code(), http::StatusCode::OK);
    }

    #[tokio::test]
    async fn test_request_timeout() {
        let middleware = RequestTimeout::new(Duration::from_millis(100));
        let next = Box::new(|_req: Request| -> ResponseFuture {
            Box::pin(async {
                tokio::time::sleep(Duration::from_millis(200)).await;
                Response::ok().body("too slow")
            })
        });
        let req = get_root_request().await;

        let response = middleware.call(req, next).await;
        assert_eq!(response.status_code(), http::StatusCode::REQUEST_TIMEOUT);
    }
}