Skip to main content

rush_sync_server/server/
middleware.rs

1use crate::core::api_key::ApiKey;
2use actix_web::{
3    body::EitherBody,
4    dev::{Service, ServiceRequest, ServiceResponse, Transform},
5    Error, HttpResponse,
6};
7use futures_util::future::LocalBoxFuture;
8use std::{
9    collections::{HashMap, VecDeque},
10    future::{ready, Ready},
11    sync::{Arc, Mutex},
12    time::Instant,
13};
14
15pub struct LoggingMiddleware {
16    server_logger: Arc<crate::server::logging::ServerLogger>,
17}
18
19impl LoggingMiddleware {
20    pub fn new(server_logger: Arc<crate::server::logging::ServerLogger>) -> Self {
21        Self { server_logger }
22    }
23}
24
25impl<S, B> Transform<S, ServiceRequest> for LoggingMiddleware
26where
27    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
28    S::Future: 'static,
29    B: 'static,
30{
31    type Response = ServiceResponse<B>;
32    type Error = Error;
33    type InitError = ();
34    type Transform = LoggingMiddlewareService<S>;
35    type Future = Ready<std::result::Result<Self::Transform, Self::InitError>>;
36
37    fn new_transform(&self, service: S) -> Self::Future {
38        ready(Ok(LoggingMiddlewareService {
39            service,
40            server_logger: self.server_logger.clone(),
41        }))
42    }
43}
44
45pub struct LoggingMiddlewareService<S> {
46    service: S,
47    server_logger: Arc<crate::server::logging::ServerLogger>,
48}
49
50impl<S, B> Service<ServiceRequest> for LoggingMiddlewareService<S>
51where
52    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
53    S::Future: 'static,
54    B: 'static,
55{
56    type Response = ServiceResponse<B>;
57    type Error = Error;
58    type Future = LocalBoxFuture<'static, std::result::Result<Self::Response, Self::Error>>;
59
60    actix_web::dev::forward_ready!(service);
61
62    fn call(&self, req: ServiceRequest) -> Self::Future {
63        let start_time = Instant::now();
64        let server_logger = self.server_logger.clone();
65
66        let ip = {
67            let connection_info = req.connection_info();
68            connection_info
69                .realip_remote_addr()
70                .or_else(|| connection_info.peer_addr())
71                .unwrap_or("unknown")
72                .split(':')
73                .next()
74                .unwrap_or("unknown")
75                .to_string()
76        };
77
78        let path = req.path().to_string();
79        let method = req.method().to_string();
80        let query_string = req.query_string().to_string();
81
82        let suspicious = is_suspicious_path(&path);
83
84        if suspicious {
85            let logger_clone = server_logger.clone();
86            let ip_clone = ip.clone();
87            let path_clone = path.clone();
88            tokio::spawn(async move {
89                let _ = logger_clone
90                    .log_security_alert(
91                        &ip_clone,
92                        "Suspicious Request",
93                        &format!("Suspicious path: {}", path_clone),
94                    )
95                    .await;
96            });
97        }
98
99        let headers: std::collections::HashMap<String, String> = req
100            .headers()
101            .iter()
102            .filter_map(|(name, value)| {
103                let header_name = name.as_str().to_lowercase();
104                if !["authorization", "cookie", "x-api-key"].contains(&header_name.as_str()) {
105                    value
106                        .to_str()
107                        .ok()
108                        .map(|v| (name.as_str().to_string(), v.to_string()))
109                } else {
110                    Some((name.as_str().to_string(), "[FILTERED]".to_string()))
111                }
112            })
113            .collect();
114
115        let fut = self.service.call(req);
116
117        Box::pin(async move {
118            let res = fut.await?;
119            let response_time = start_time.elapsed().as_millis() as u64;
120            let status = res.status().as_u16();
121            let bytes_sent = res
122                .response()
123                .headers()
124                .get("content-length")
125                .and_then(|h| h.to_str().ok())
126                .and_then(|s| s.parse().ok())
127                .unwrap_or(0);
128
129            // Analytics: only track if NOT proxied (proxy handler tracks with real client IP)
130            let is_proxied = headers.contains_key("x-forwarded-for") || headers.contains_key("x-real-ip");
131            let analytics_path = path.clone();
132            let analytics_ip = ip.clone();
133            let analytics_ua = headers.get("user-agent").cloned().unwrap_or_default();
134
135            let entry = crate::server::logging::ServerLogEntry {
136                timestamp: chrono::Local::now()
137                    .format("%Y-%m-%d %H:%M:%S%.3f")
138                    .to_string(),
139                timestamp_unix: std::time::SystemTime::now()
140                    .duration_since(std::time::UNIX_EPOCH)
141                    .unwrap_or_default()
142                    .as_secs(),
143                event_type: crate::server::logging::LogEventType::Request,
144                ip_address: ip,
145                user_agent: headers.get("user-agent").cloned(),
146                method,
147                path,
148                status_code: Some(status),
149                response_time_ms: Some(response_time),
150                bytes_sent: Some(bytes_sent),
151                referer: headers.get("referer").cloned(),
152                query_string: if query_string.is_empty() {
153                    None
154                } else {
155                    Some(query_string)
156                },
157                headers,
158                session_id: None,
159            };
160
161            if let Err(e) = server_logger.write_log_entry(entry).await {
162                log::error!("Failed to log request: {}", e);
163            }
164
165            if !is_proxied {
166                crate::server::analytics::track_request("", &analytics_path, &analytics_ip, &analytics_ua);
167            }
168
169            Ok(res)
170        })
171    }
172}
173
/// Decodes `%XX` escapes in `input` and returns the result as UTF-8.
///
/// Only a `%` followed by exactly two ASCII hex digits is decoded. Any other `%`
/// (truncated, non-hex, or followed by a sign such as `%+F`) is kept verbatim.
/// Decoding works on raw bytes, so multi-byte sequences such as `%C3%A9` become
/// `é`. Byte sequences that are not valid UTF-8 are replaced with U+FFFD.
/// The function never panics, including on non-ASCII input.
fn percent_decode(input: &str) -> String {
    let bytes = input.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'%' && i + 2 < bytes.len() {
            // `char::to_digit` accepts only [0-9a-fA-F]; unlike `u8::from_str_radix` it
            // rejects a leading `+`, and working on bytes avoids slicing the `str` off a
            // char boundary.
            let hi = char::from(bytes[i + 1]).to_digit(16);
            let lo = char::from(bytes[i + 2]).to_digit(16);
            if let (Some(hi), Some(lo)) = (hi, lo) {
                // hi, lo < 16, so the value fits in a byte.
                decoded.push((hi * 16 + lo) as u8);
                i += 3;
                continue;
            }
        }
        decoded.push(bytes[i]);
        i += 1;
    }
    String::from_utf8_lossy(&decoded).into_owned()
}
191
192fn is_suspicious_path(path: &str) -> bool {
193    let decoded = percent_decode(path);
194    let normalized = decoded.replace('\\', "/");
195    let lower = normalized.to_lowercase();
196
197    normalized.contains("..")
198        || lower.contains("<script")
199        || lower.contains("union select")
200        || lower.contains("drop table")
201        || path.len() > 1000
202}
203
// =============================================================================
// API Key Authentication Middleware
// =============================================================================
207
208#[derive(Clone)]
209pub struct ApiKeyAuth {
210    api_key: ApiKey,
211}
212
213impl ApiKeyAuth {
214    pub fn new(api_key: ApiKey) -> Self {
215        Self { api_key }
216    }
217}
218
219impl<S, B> Transform<S, ServiceRequest> for ApiKeyAuth
220where
221    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
222    S::Future: 'static,
223    B: 'static,
224{
225    type Response = ServiceResponse<EitherBody<B>>;
226    type Error = Error;
227    type InitError = ();
228    type Transform = ApiKeyAuthService<S>;
229    type Future = Ready<std::result::Result<Self::Transform, Self::InitError>>;
230
231    fn new_transform(&self, service: S) -> Self::Future {
232        ready(Ok(ApiKeyAuthService {
233            service,
234            api_key: self.api_key.clone(),
235        }))
236    }
237}
238
239pub struct ApiKeyAuthService<S> {
240    service: S,
241    api_key: ApiKey,
242}
243
244impl<S, B> Service<ServiceRequest> for ApiKeyAuthService<S>
245where
246    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
247    S::Future: 'static,
248    B: 'static,
249{
250    type Response = ServiceResponse<EitherBody<B>>;
251    type Error = Error;
252    type Future = LocalBoxFuture<'static, std::result::Result<Self::Response, Self::Error>>;
253
254    actix_web::dev::forward_ready!(service);
255
256    fn call(&self, req: ServiceRequest) -> Self::Future {
257        let path = req.path().to_string();
258
259        let is_public_asset = path == "/.rss/"
260            || path == "/.rss/_reset.css"
261            || path == "/.rss/style.css"
262            || path == "/.rss/favicon.svg"
263            || path.starts_with("/.rss/js/")
264            || path.starts_with("/.rss/fonts/")
265            || path == "/ws/hot-reload";
266
267        let needs_auth =
268            (path.starts_with("/api/") || path.starts_with("/.rss/") || path.starts_with("/ws/"))
269                && path != "/api/health"
270                && path != "/api/logs/raw"
271                && !path.starts_with("/api/acme/")
272                && !path.starts_with("/api/analytics")
273                && !path.starts_with("/.well-known/")
274                && !is_public_asset;
275
276        // Skip auth if not needed or no key configured
277        if !needs_auth || self.api_key.is_empty() {
278            let fut = self.service.call(req);
279            return Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) });
280        }
281
282        // Check X-API-Key header
283        let header_key = req
284            .headers()
285            .get("x-api-key")
286            .and_then(|v| v.to_str().ok())
287            .map(|s| s.to_string());
288
289        // Check ?api_key= query parameter
290        let query_key = req
291            .query_string()
292            .split('&')
293            .find_map(|param| param.strip_prefix("api_key="))
294            .map(|s| s.to_string());
295
296        let provided_key = header_key.or(query_key);
297
298        let is_valid = provided_key
299            .as_deref()
300            .map(|k| self.api_key.verify(k))
301            .unwrap_or(false);
302        if is_valid {
303            let fut = self.service.call(req);
304            Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) })
305        } else {
306            let response = HttpResponse::Unauthorized()
307                .json(serde_json::json!({
308                    "error": "Unauthorized",
309                    "message": "Valid API key required. Provide via X-API-Key header or ?api_key= query parameter."
310                }));
311            Box::pin(async move { Ok(req.into_response(response).map_into_right_body()) })
312        }
313    }
314}
315
// =============================================================================
// Rate Limiter Middleware
// =============================================================================
319
320#[derive(Clone)]
321pub struct RateLimiter {
322    max_rps: u32,
323    enabled: bool,
324    clients: Arc<Mutex<HashMap<String, VecDeque<Instant>>>>,
325}
326
327impl RateLimiter {
328    pub fn new(max_rps: u32, enabled: bool) -> Self {
329        Self {
330            max_rps,
331            enabled,
332            clients: Arc::new(Mutex::new(HashMap::new())),
333        }
334    }
335}
336
337impl<S, B> Transform<S, ServiceRequest> for RateLimiter
338where
339    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
340    S::Future: 'static,
341    B: 'static,
342{
343    type Response = ServiceResponse<EitherBody<B>>;
344    type Error = Error;
345    type InitError = ();
346    type Transform = RateLimiterService<S>;
347    type Future = Ready<std::result::Result<Self::Transform, Self::InitError>>;
348
349    fn new_transform(&self, service: S) -> Self::Future {
350        ready(Ok(RateLimiterService {
351            service,
352            max_rps: self.max_rps,
353            enabled: self.enabled,
354            clients: self.clients.clone(),
355        }))
356    }
357}
358
359pub struct RateLimiterService<S> {
360    service: S,
361    max_rps: u32,
362    enabled: bool,
363    clients: Arc<Mutex<HashMap<String, VecDeque<Instant>>>>,
364}
365
366impl<S, B> Service<ServiceRequest> for RateLimiterService<S>
367where
368    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
369    S::Future: 'static,
370    B: 'static,
371{
372    type Response = ServiceResponse<EitherBody<B>>;
373    type Error = Error;
374    type Future = LocalBoxFuture<'static, std::result::Result<Self::Response, Self::Error>>;
375
376    actix_web::dev::forward_ready!(service);
377
378    fn call(&self, req: ServiceRequest) -> Self::Future {
379        // Only rate-limit /api/* paths
380        if !self.enabled || !req.path().starts_with("/api/") {
381            let fut = self.service.call(req);
382            return Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) });
383        }
384
385        let ip = {
386            let connection_info = req.connection_info();
387            connection_info
388                .realip_remote_addr()
389                .or_else(|| connection_info.peer_addr())
390                .unwrap_or("unknown")
391                .split(':')
392                .next()
393                .unwrap_or("unknown")
394                .to_string()
395        };
396
397        let now = Instant::now();
398        let one_second_ago = now - std::time::Duration::from_secs(1);
399
400        let is_limited = if let Ok(mut clients) = self.clients.lock() {
401            let timestamps = clients.entry(ip).or_insert_with(VecDeque::new);
402
403            // Remove entries older than 1 second
404            while timestamps.front().is_some_and(|t| *t < one_second_ago) {
405                timestamps.pop_front();
406            }
407
408            if timestamps.len() >= self.max_rps as usize {
409                true
410            } else {
411                timestamps.push_back(now);
412                false
413            }
414        } else {
415            false // If lock fails, allow the request
416        };
417
418        if is_limited {
419            let response = HttpResponse::TooManyRequests()
420                .insert_header(("Retry-After", "1"))
421                .json(serde_json::json!({
422                    "error": "Too Many Requests",
423                    "message": "Rate limit exceeded. Try again later.",
424                    "retry_after": 1
425                }));
426            Box::pin(async move { Ok(req.into_response(response).map_into_right_body()) })
427        } else {
428            let fut = self.service.call(req);
429            Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) })
430        }
431    }
432}
433
// =============================================================================
// PIN Protection Middleware
// =============================================================================
437
438#[derive(Clone)]
439pub struct PinProtection {
440    server_name: String,
441    server_port: u16,
442}
443
444impl PinProtection {
445    pub fn new(server_name: &str, server_port: u16) -> Self {
446        Self {
447            server_name: server_name.to_string(),
448            server_port,
449        }
450    }
451}
452
453impl<S, B> Transform<S, ServiceRequest> for PinProtection
454where
455    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
456    S::Future: 'static,
457    B: 'static,
458{
459    type Response = ServiceResponse<EitherBody<B>>;
460    type Error = Error;
461    type InitError = ();
462    type Transform = PinProtectionService<S>;
463    type Future = Ready<std::result::Result<Self::Transform, Self::InitError>>;
464
465    fn new_transform(&self, service: S) -> Self::Future {
466        ready(Ok(PinProtectionService {
467            service,
468            server_name: self.server_name.clone(),
469            server_port: self.server_port,
470        }))
471    }
472}
473
474pub struct PinProtectionService<S> {
475    service: S,
476    server_name: String,
477    server_port: u16,
478}
479
480impl<S, B> Service<ServiceRequest> for PinProtectionService<S>
481where
482    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
483    S::Future: 'static,
484    B: 'static,
485{
486    type Response = ServiceResponse<EitherBody<B>>;
487    type Error = Error;
488    type Future = LocalBoxFuture<'static, std::result::Result<Self::Response, Self::Error>>;
489
490    actix_web::dev::forward_ready!(service);
491
492    fn call(&self, req: ServiceRequest) -> Self::Future {
493        let path = req.path().to_string();
494
495        // Always allow these paths (needed for PIN page + unlock + user content)
496        // /.rss/ exact = dashboard handler shows PIN page itself
497        let is_exempt = path == "/api/pin/verify"
498            || path == "/api/pin/logout"
499            || path == "/api/health"
500            || path == "/.rss/"
501            || path == "/.rss/style.css"
502            || path == "/.rss/_reset.css"
503            || path == "/.rss/favicon.svg"
504            || path.starts_with("/.rss/fonts/")
505            || path.starts_with("/.well-known/")
506            || path == "/rss.js"
507            || path == "/ws/hot-reload";
508
509        // Only protect dashboard and API paths
510        let is_protected = path.starts_with("/api/")
511            || path.starts_with("/.rss/");
512
513        if is_exempt || !is_protected {
514            let fut = self.service.call(req);
515            return Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) });
516        }
517
518        // Load settings and check PIN
519        let server_dir = crate::server::settings::ServerSettings::get_server_dir(
520            &self.server_name,
521            self.server_port,
522        );
523        let settings = match server_dir {
524            Some(ref dir) => crate::server::settings::ServerSettings::load(dir),
525            None => crate::server::settings::ServerSettings::default(),
526        };
527
528        // If PIN not enabled, pass through
529        if !settings.pin_enabled || settings.pin_code.is_empty() {
530            let fut = self.service.call(req);
531            return Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) });
532        }
533
534        // Check for valid PIN cookie
535        let expected_token = format!("rss-pin-{}-{}", self.server_name, self.server_port);
536        let has_valid_cookie = req
537            .cookie("rss_pin")
538            .map(|c| c.value() == expected_token)
539            .unwrap_or(false);
540
541        if has_valid_cookie {
542            let fut = self.service.call(req);
543            return Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) });
544        }
545
546        // Blocked — return 401 for API, redirect for dashboard
547        if path.starts_with("/api/") {
548            let response = HttpResponse::Unauthorized()
549                .json(serde_json::json!({
550                    "error": "PIN required",
551                    "message": "Dashboard is PIN protected. Unlock at /.rss/"
552                }));
553            Box::pin(async move { Ok(req.into_response(response).map_into_right_body()) })
554        } else {
555            // For /.rss/ paths, redirect to dashboard (which shows PIN page)
556            let response = HttpResponse::Found()
557                .insert_header(("Location", "/.rss/"))
558                .finish();
559            Box::pin(async move { Ok(req.into_response(response).map_into_right_body()) })
560        }
561    }
562}
563
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks every `(input, expected)` pair against `percent_decode`.
    fn assert_decodes(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            assert_eq!(percent_decode(input), expected, "percent_decode({input:?})");
        }
    }

    /// Asserts that every path in `paths` is classified as `suspicious`.
    fn assert_classified(paths: &[&str], suspicious: bool) {
        for &path in paths {
            assert_eq!(
                is_suspicious_path(path),
                suspicious,
                "is_suspicious_path({path:?})"
            );
        }
    }

    // --- percent_decode tests ---

    #[test]
    fn test_percent_decode_plain() {
        assert_decodes(&[("/api/status", "/api/status")]);
    }

    #[test]
    fn test_percent_decode_encoded_slash() {
        assert_decodes(&[("%2F", "/")]);
    }

    #[test]
    fn test_percent_decode_dot_dot() {
        assert_decodes(&[("%2e%2e", "..")]);
    }

    #[test]
    fn test_percent_decode_mixed() {
        assert_decodes(&[("/foo%2Fbar%2E%2E%2Fbaz", "/foo/bar../baz")]);
    }

    #[test]
    fn test_percent_decode_incomplete_sequence() {
        assert_decodes(&[("abc%2", "abc%2")]);
    }

    #[test]
    fn test_percent_decode_invalid_hex() {
        assert_decodes(&[("%ZZ", "%ZZ")]);
    }

    #[test]
    fn test_percent_decode_empty() {
        assert_decodes(&[("", "")]);
    }

    #[test]
    fn test_percent_decode_script_tag() {
        assert_decodes(&[("%3Cscript%3E", "<script>")]);
    }

    // --- is_suspicious_path tests ---

    #[test]
    fn test_suspicious_path_traversal() {
        assert_classified(&["/../etc/passwd", "/foo/../../etc/shadow"], true);
    }

    #[test]
    fn test_suspicious_path_encoded_traversal() {
        assert_classified(&["/%2e%2e/etc/passwd", "/%2E%2E/secret"], true);
    }

    #[test]
    fn test_suspicious_path_backslash_traversal() {
        assert_classified(&["/foo\\..\\etc\\passwd"], true);
    }

    #[test]
    fn test_suspicious_path_script_injection() {
        assert_classified(&["/<script>alert(1)</script>", "/%3Cscript%3Ealert(1)"], true);
    }

    #[test]
    fn test_suspicious_path_sql_injection() {
        assert_classified(
            &[
                "/api?q=1 UNION SELECT * FROM users",
                "/api?q=DROP TABLE users",
            ],
            true,
        );
    }

    #[test]
    fn test_suspicious_path_too_long() {
        let long_path = format!("/{}", "a".repeat(1001));
        assert_classified(&[long_path.as_str()], true);
    }

    #[test]
    fn test_safe_paths() {
        assert_classified(
            &[
                "/",
                "/api/status",
                "/index.html",
                "/.rss/style.css",
                "/api/logs?offset=100",
                "/ws/hot-reload",
            ],
            false,
        );
    }

    #[test]
    fn test_safe_path_with_dots_in_filename() {
        assert_classified(&["/file.name.html", "/.rss/favicon.svg"], false);
    }
}