Skip to main content

rush_sync_server/server/
middleware.rs

1use crate::core::api_key::ApiKey;
2use actix_web::{
3    body::EitherBody,
4    dev::{Service, ServiceRequest, ServiceResponse, Transform},
5    Error, HttpResponse,
6};
7use futures_util::future::LocalBoxFuture;
8use std::{
9    collections::{HashMap, VecDeque},
10    future::{ready, Ready},
11    sync::{Arc, Mutex},
12    time::Instant,
13};
14
15pub struct LoggingMiddleware {
16    server_logger: Arc<crate::server::logging::ServerLogger>,
17}
18
19impl LoggingMiddleware {
20    pub fn new(server_logger: Arc<crate::server::logging::ServerLogger>) -> Self {
21        Self { server_logger }
22    }
23}
24
25impl<S, B> Transform<S, ServiceRequest> for LoggingMiddleware
26where
27    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
28    S::Future: 'static,
29    B: 'static,
30{
31    type Response = ServiceResponse<B>;
32    type Error = Error;
33    type InitError = ();
34    type Transform = LoggingMiddlewareService<S>;
35    type Future = Ready<std::result::Result<Self::Transform, Self::InitError>>;
36
37    fn new_transform(&self, service: S) -> Self::Future {
38        ready(Ok(LoggingMiddlewareService {
39            service,
40            server_logger: self.server_logger.clone(),
41        }))
42    }
43}
44
45pub struct LoggingMiddlewareService<S> {
46    service: S,
47    server_logger: Arc<crate::server::logging::ServerLogger>,
48}
49
50impl<S, B> Service<ServiceRequest> for LoggingMiddlewareService<S>
51where
52    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
53    S::Future: 'static,
54    B: 'static,
55{
56    type Response = ServiceResponse<B>;
57    type Error = Error;
58    type Future = LocalBoxFuture<'static, std::result::Result<Self::Response, Self::Error>>;
59
60    actix_web::dev::forward_ready!(service);
61
62    fn call(&self, req: ServiceRequest) -> Self::Future {
63        let start_time = Instant::now();
64        let server_logger = self.server_logger.clone();
65
66        let ip = {
67            let connection_info = req.connection_info();
68            connection_info
69                .realip_remote_addr()
70                .or_else(|| connection_info.peer_addr())
71                .unwrap_or("unknown")
72                .split(':')
73                .next()
74                .unwrap_or("unknown")
75                .to_string()
76        };
77
78        let path = req.path().to_string();
79        let method = req.method().to_string();
80        let query_string = req.query_string().to_string();
81
82        let suspicious = is_suspicious_path(&path);
83
84        if suspicious {
85            let logger_clone = server_logger.clone();
86            let ip_clone = ip.clone();
87            let path_clone = path.clone();
88            tokio::spawn(async move {
89                let _ = logger_clone
90                    .log_security_alert(
91                        &ip_clone,
92                        "Suspicious Request",
93                        &format!("Suspicious path: {}", path_clone),
94                    )
95                    .await;
96            });
97        }
98
99        let headers: std::collections::HashMap<String, String> = req
100            .headers()
101            .iter()
102            .filter_map(|(name, value)| {
103                let header_name = name.as_str().to_lowercase();
104                if !["authorization", "cookie", "x-api-key"].contains(&header_name.as_str()) {
105                    value
106                        .to_str()
107                        .ok()
108                        .map(|v| (name.as_str().to_string(), v.to_string()))
109                } else {
110                    Some((name.as_str().to_string(), "[FILTERED]".to_string()))
111                }
112            })
113            .collect();
114
115        let fut = self.service.call(req);
116
117        Box::pin(async move {
118            let res = fut.await?;
119            let response_time = start_time.elapsed().as_millis() as u64;
120            let status = res.status().as_u16();
121            let bytes_sent = res
122                .response()
123                .headers()
124                .get("content-length")
125                .and_then(|h| h.to_str().ok())
126                .and_then(|s| s.parse().ok())
127                .unwrap_or(0);
128
129            // Clone for analytics tracking (values move into the log entry below)
130            let analytics_path = path.clone();
131            let analytics_ip = ip.clone();
132            let analytics_ua = headers.get("user-agent").cloned().unwrap_or_default();
133
134            let entry = crate::server::logging::ServerLogEntry {
135                timestamp: chrono::Local::now()
136                    .format("%Y-%m-%d %H:%M:%S%.3f")
137                    .to_string(),
138                timestamp_unix: std::time::SystemTime::now()
139                    .duration_since(std::time::UNIX_EPOCH)
140                    .unwrap_or_default()
141                    .as_secs(),
142                event_type: crate::server::logging::LogEventType::Request,
143                ip_address: ip,
144                user_agent: headers.get("user-agent").cloned(),
145                method,
146                path,
147                status_code: Some(status),
148                response_time_ms: Some(response_time),
149                bytes_sent: Some(bytes_sent),
150                referer: headers.get("referer").cloned(),
151                query_string: if query_string.is_empty() {
152                    None
153                } else {
154                    Some(query_string)
155                },
156                headers,
157                session_id: None,
158            };
159
160            if let Err(e) = server_logger.write_log_entry(entry).await {
161                log::error!("Failed to log request: {}", e);
162            }
163
164            crate::server::analytics::track_request("", &analytics_path, &analytics_ip, &analytics_ua);
165
166            Ok(res)
167        })
168    }
169}
170
/// Decodes `%XX` escapes in `input` and returns the result as a `String`.
///
/// * Only `%` followed by exactly two hexadecimal digits is decoded; an
///   incomplete or malformed escape (`"abc%2"`, `"%ZZ"`, `"%+f"`) is copied
///   through unchanged.
/// * Decoding happens at byte level and the result is converted with
///   `String::from_utf8_lossy`, so multi-byte UTF-8 characters survive
///   intact (`"%C3%A9"` becomes `"é"`), and escapes that do not form valid
///   UTF-8 become U+FFFD.
/// * Never panics, including when a `%` is followed by non-ASCII text.
fn percent_decode(input: &str) -> String {
    /// Value of a single ASCII hex digit, or `None` for any other byte.
    fn hex_value(byte: u8) -> Option<u8> {
        char::from(byte).to_digit(16).map(|digit| digit as u8)
    }

    let bytes = input.as_bytes();
    let mut decoded: Vec<u8> = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'%' {
            let high = bytes.get(i + 1).copied().and_then(hex_value);
            let low = bytes.get(i + 2).copied().and_then(hex_value);
            if let (Some(high), Some(low)) = (high, low) {
                decoded.push((high << 4) | low);
                i += 3;
                continue;
            }
        }
        decoded.push(bytes[i]);
        i += 1;
    }
    String::from_utf8_lossy(&decoded).into_owned()
}
188
189fn is_suspicious_path(path: &str) -> bool {
190    let decoded = percent_decode(path);
191    let normalized = decoded.replace('\\', "/");
192    let lower = normalized.to_lowercase();
193
194    normalized.contains("..")
195        || lower.contains("<script")
196        || lower.contains("union select")
197        || lower.contains("drop table")
198        || path.len() > 1000
199}
200
201// =============================================================================
202// API Key Authentication Middleware
203// =============================================================================
204
205#[derive(Clone)]
206pub struct ApiKeyAuth {
207    api_key: ApiKey,
208}
209
210impl ApiKeyAuth {
211    pub fn new(api_key: ApiKey) -> Self {
212        Self { api_key }
213    }
214}
215
216impl<S, B> Transform<S, ServiceRequest> for ApiKeyAuth
217where
218    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
219    S::Future: 'static,
220    B: 'static,
221{
222    type Response = ServiceResponse<EitherBody<B>>;
223    type Error = Error;
224    type InitError = ();
225    type Transform = ApiKeyAuthService<S>;
226    type Future = Ready<std::result::Result<Self::Transform, Self::InitError>>;
227
228    fn new_transform(&self, service: S) -> Self::Future {
229        ready(Ok(ApiKeyAuthService {
230            service,
231            api_key: self.api_key.clone(),
232        }))
233    }
234}
235
236pub struct ApiKeyAuthService<S> {
237    service: S,
238    api_key: ApiKey,
239}
240
241impl<S, B> Service<ServiceRequest> for ApiKeyAuthService<S>
242where
243    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
244    S::Future: 'static,
245    B: 'static,
246{
247    type Response = ServiceResponse<EitherBody<B>>;
248    type Error = Error;
249    type Future = LocalBoxFuture<'static, std::result::Result<Self::Response, Self::Error>>;
250
251    actix_web::dev::forward_ready!(service);
252
253    fn call(&self, req: ServiceRequest) -> Self::Future {
254        let path = req.path().to_string();
255
256        let is_public_asset = path == "/.rss/"
257            || path == "/.rss/_reset.css"
258            || path == "/.rss/style.css"
259            || path == "/.rss/favicon.svg"
260            || path.starts_with("/.rss/js/")
261            || path.starts_with("/.rss/fonts/")
262            || path == "/ws/hot-reload";
263
264        let needs_auth =
265            (path.starts_with("/api/") || path.starts_with("/.rss/") || path.starts_with("/ws/"))
266                && path != "/api/health"
267                && !path.starts_with("/api/acme/")
268                && !path.starts_with("/api/analytics")
269                && !path.starts_with("/.well-known/")
270                && !is_public_asset;
271
272        // Skip auth if not needed or no key configured
273        if !needs_auth || self.api_key.is_empty() {
274            let fut = self.service.call(req);
275            return Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) });
276        }
277
278        // Check X-API-Key header
279        let header_key = req
280            .headers()
281            .get("x-api-key")
282            .and_then(|v| v.to_str().ok())
283            .map(|s| s.to_string());
284
285        // Check ?api_key= query parameter
286        let query_key = req
287            .query_string()
288            .split('&')
289            .find_map(|param| param.strip_prefix("api_key="))
290            .map(|s| s.to_string());
291
292        let provided_key = header_key.or(query_key);
293
294        let is_valid = provided_key
295            .as_deref()
296            .map(|k| self.api_key.verify(k))
297            .unwrap_or(false);
298        if is_valid {
299            let fut = self.service.call(req);
300            Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) })
301        } else {
302            let response = HttpResponse::Unauthorized()
303                .json(serde_json::json!({
304                    "error": "Unauthorized",
305                    "message": "Valid API key required. Provide via X-API-Key header or ?api_key= query parameter."
306                }));
307            Box::pin(async move { Ok(req.into_response(response).map_into_right_body()) })
308        }
309    }
310}
311
312// =============================================================================
313// Rate Limiter Middleware
314// =============================================================================
315
/// Middleware factory for a per-client request rate limit.
///
/// `clients` maps a client address to the timestamps of its recent
/// requests. It sits behind an `Arc<Mutex<..>>`, so clones of the limiter
/// and every service created from it share the same table.
#[derive(Clone)]
pub struct RateLimiter {
    max_rps: u32,
    enabled: bool,
    clients: Arc<Mutex<HashMap<String, VecDeque<Instant>>>>,
}

impl RateLimiter {
    /// Creates a limiter allowing `max_rps` requests per second per client,
    /// starting with an empty client table. With `enabled == false` the
    /// middleware passes every request through.
    pub fn new(max_rps: u32, enabled: bool) -> Self {
        let clients = Arc::default();
        RateLimiter {
            max_rps,
            enabled,
            clients,
        }
    }
}
332
333impl<S, B> Transform<S, ServiceRequest> for RateLimiter
334where
335    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
336    S::Future: 'static,
337    B: 'static,
338{
339    type Response = ServiceResponse<EitherBody<B>>;
340    type Error = Error;
341    type InitError = ();
342    type Transform = RateLimiterService<S>;
343    type Future = Ready<std::result::Result<Self::Transform, Self::InitError>>;
344
345    fn new_transform(&self, service: S) -> Self::Future {
346        ready(Ok(RateLimiterService {
347            service,
348            max_rps: self.max_rps,
349            enabled: self.enabled,
350            clients: self.clients.clone(),
351        }))
352    }
353}
354
355pub struct RateLimiterService<S> {
356    service: S,
357    max_rps: u32,
358    enabled: bool,
359    clients: Arc<Mutex<HashMap<String, VecDeque<Instant>>>>,
360}
361
362impl<S, B> Service<ServiceRequest> for RateLimiterService<S>
363where
364    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
365    S::Future: 'static,
366    B: 'static,
367{
368    type Response = ServiceResponse<EitherBody<B>>;
369    type Error = Error;
370    type Future = LocalBoxFuture<'static, std::result::Result<Self::Response, Self::Error>>;
371
372    actix_web::dev::forward_ready!(service);
373
374    fn call(&self, req: ServiceRequest) -> Self::Future {
375        // Only rate-limit /api/* paths
376        if !self.enabled || !req.path().starts_with("/api/") {
377            let fut = self.service.call(req);
378            return Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) });
379        }
380
381        let ip = {
382            let connection_info = req.connection_info();
383            connection_info
384                .realip_remote_addr()
385                .or_else(|| connection_info.peer_addr())
386                .unwrap_or("unknown")
387                .split(':')
388                .next()
389                .unwrap_or("unknown")
390                .to_string()
391        };
392
393        let now = Instant::now();
394        let one_second_ago = now - std::time::Duration::from_secs(1);
395
396        let is_limited = if let Ok(mut clients) = self.clients.lock() {
397            let timestamps = clients.entry(ip).or_insert_with(VecDeque::new);
398
399            // Remove entries older than 1 second
400            while timestamps.front().is_some_and(|t| *t < one_second_ago) {
401                timestamps.pop_front();
402            }
403
404            if timestamps.len() >= self.max_rps as usize {
405                true
406            } else {
407                timestamps.push_back(now);
408                false
409            }
410        } else {
411            false // If lock fails, allow the request
412        };
413
414        if is_limited {
415            let response = HttpResponse::TooManyRequests()
416                .insert_header(("Retry-After", "1"))
417                .json(serde_json::json!({
418                    "error": "Too Many Requests",
419                    "message": "Rate limit exceeded. Try again later.",
420                    "retry_after": 1
421                }));
422            Box::pin(async move { Ok(req.into_response(response).map_into_right_body()) })
423        } else {
424            let fut = self.service.call(req);
425            Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) })
426        }
427    }
428}
429
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that decoding `input` yields exactly `expected`.
    fn check_decode(input: &str, expected: &str) {
        assert_eq!(percent_decode(input), expected);
    }

    /// Asserts that every path in `paths` is flagged as suspicious.
    fn check_flagged(paths: &[&str]) {
        for path in paths {
            assert!(is_suspicious_path(path));
        }
    }

    /// Asserts that no path in `paths` is flagged as suspicious.
    fn check_clean(paths: &[&str]) {
        for path in paths {
            assert!(!is_suspicious_path(path));
        }
    }

    // --- percent_decode tests ---

    #[test]
    fn test_percent_decode_plain() {
        check_decode("/api/status", "/api/status");
    }

    #[test]
    fn test_percent_decode_encoded_slash() {
        check_decode("%2F", "/");
    }

    #[test]
    fn test_percent_decode_dot_dot() {
        check_decode("%2e%2e", "..");
    }

    #[test]
    fn test_percent_decode_mixed() {
        check_decode("/foo%2Fbar%2E%2E%2Fbaz", "/foo/bar../baz");
    }

    #[test]
    fn test_percent_decode_incomplete_sequence() {
        // A truncated escape is passed through untouched.
        check_decode("abc%2", "abc%2");
    }

    #[test]
    fn test_percent_decode_invalid_hex() {
        // Non-hex digits after '%' are passed through untouched.
        check_decode("%ZZ", "%ZZ");
    }

    #[test]
    fn test_percent_decode_empty() {
        check_decode("", "");
    }

    #[test]
    fn test_percent_decode_script_tag() {
        check_decode("%3Cscript%3E", "<script>");
    }

    // --- is_suspicious_path tests ---

    #[test]
    fn test_suspicious_path_traversal() {
        check_flagged(&["/../etc/passwd", "/foo/../../etc/shadow"]);
    }

    #[test]
    fn test_suspicious_path_encoded_traversal() {
        check_flagged(&["/%2e%2e/etc/passwd", "/%2E%2E/secret"]);
    }

    #[test]
    fn test_suspicious_path_backslash_traversal() {
        check_flagged(&["/foo\\..\\etc\\passwd"]);
    }

    #[test]
    fn test_suspicious_path_script_injection() {
        check_flagged(&["/<script>alert(1)</script>", "/%3Cscript%3Ealert(1)"]);
    }

    #[test]
    fn test_suspicious_path_sql_injection() {
        check_flagged(&[
            "/api?q=1 UNION SELECT * FROM users",
            "/api?q=DROP TABLE users",
        ]);
    }

    #[test]
    fn test_suspicious_path_too_long() {
        // 1002 bytes in total, which is over the 1000-byte limit.
        let long_path = format!("/{}", "a".repeat(1001));
        assert!(is_suspicious_path(&long_path));
    }

    #[test]
    fn test_safe_paths() {
        check_clean(&[
            "/",
            "/api/status",
            "/index.html",
            "/.rss/style.css",
            "/api/logs?offset=100",
            "/ws/hot-reload",
        ]);
    }

    #[test]
    fn test_safe_path_with_dots_in_filename() {
        // Single dots are fine; only ".." is treated as traversal.
        check_clean(&["/file.name.html", "/.rss/favicon.svg"]);
    }
}