rush_sync_server/server/
middleware.rs

1use actix_web::{
2    dev::{Service, ServiceRequest, ServiceResponse, Transform},
3    Error,
4};
5use futures_util::future::LocalBoxFuture;
6use std::{
7    future::{ready, Ready},
8    sync::Arc,
9    time::Instant,
10};
11
12pub struct LoggingMiddleware {
13    server_logger: Arc<crate::server::logging::ServerLogger>,
14}
15
16impl LoggingMiddleware {
17    pub fn new(server_logger: Arc<crate::server::logging::ServerLogger>) -> Self {
18        Self { server_logger }
19    }
20}
21
22impl<S, B> Transform<S, ServiceRequest> for LoggingMiddleware
23where
24    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
25    S::Future: 'static,
26    B: 'static,
27{
28    type Response = ServiceResponse<B>;
29    type Error = Error;
30    type InitError = ();
31    type Transform = LoggingMiddlewareService<S>;
32    type Future = Ready<std::result::Result<Self::Transform, Self::InitError>>;
33
34    fn new_transform(&self, service: S) -> Self::Future {
35        ready(Ok(LoggingMiddlewareService {
36            service,
37            server_logger: self.server_logger.clone(),
38        }))
39    }
40}
41
42pub struct LoggingMiddlewareService<S> {
43    service: S,
44    server_logger: Arc<crate::server::logging::ServerLogger>,
45}
46
47impl<S, B> Service<ServiceRequest> for LoggingMiddlewareService<S>
48where
49    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
50    S::Future: 'static,
51    B: 'static,
52{
53    type Response = ServiceResponse<B>;
54    type Error = Error;
55    type Future = LocalBoxFuture<'static, std::result::Result<Self::Response, Self::Error>>;
56
57    actix_web::dev::forward_ready!(service);
58
59    fn call(&self, req: ServiceRequest) -> Self::Future {
60        let start_time = Instant::now();
61        let server_logger = self.server_logger.clone();
62
63        let ip = {
64            let connection_info = req.connection_info();
65            connection_info
66                .realip_remote_addr()
67                .or_else(|| connection_info.peer_addr())
68                .unwrap_or("unknown")
69                .split(':')
70                .next()
71                .unwrap_or("unknown")
72                .to_string()
73        };
74
75        let path = req.path().to_string();
76        let method = req.method().to_string();
77        let query_string = req.query_string().to_string();
78
79        let suspicious = path.contains("..")
80            || path.contains("<script")
81            || path.contains("sql")
82            || path.len() > 1000;
83
84        if suspicious {
85            let logger_clone = server_logger.clone();
86            let ip_clone = ip.clone();
87            let path_clone = path.clone();
88            tokio::spawn(async move {
89                let _ = logger_clone
90                    .log_security_alert(
91                        &ip_clone,
92                        "Suspicious Request",
93                        &format!("Suspicious path: {}", path_clone),
94                    )
95                    .await;
96            });
97        }
98
99        let headers: std::collections::HashMap<String, String> = req
100            .headers()
101            .iter()
102            .filter_map(|(name, value)| {
103                let header_name = name.as_str().to_lowercase();
104                if !["authorization", "cookie", "x-api-key"].contains(&header_name.as_str()) {
105                    value
106                        .to_str()
107                        .ok()
108                        .map(|v| (name.as_str().to_string(), v.to_string()))
109                } else {
110                    Some((name.as_str().to_string(), "[FILTERED]".to_string()))
111                }
112            })
113            .collect();
114
115        let fut = self.service.call(req);
116
117        Box::pin(async move {
118            let res = fut.await?;
119            let response_time = start_time.elapsed().as_millis() as u64;
120            let status = res.status().as_u16();
121            let bytes_sent = res
122                .response()
123                .headers()
124                .get("content-length")
125                .and_then(|h| h.to_str().ok())
126                .and_then(|s| s.parse().ok())
127                .unwrap_or(0);
128
129            let entry = crate::server::logging::ServerLogEntry {
130                timestamp: chrono::Local::now()
131                    .format("%Y-%m-%d %H:%M:%S%.3f")
132                    .to_string(),
133                timestamp_unix: std::time::SystemTime::now()
134                    .duration_since(std::time::UNIX_EPOCH)
135                    .unwrap_or_default()
136                    .as_secs(),
137                event_type: crate::server::logging::LogEventType::Request,
138                ip_address: ip,
139                user_agent: headers.get("user-agent").cloned(),
140                method,
141                path,
142                status_code: Some(status),
143                response_time_ms: Some(response_time),
144                bytes_sent: Some(bytes_sent),
145                referer: headers.get("referer").cloned(),
146                query_string: if query_string.is_empty() {
147                    None
148                } else {
149                    Some(query_string)
150                },
151                headers,
152                session_id: None,
153            };
154
155            if let Err(e) = server_logger.write_log_entry(entry).await {
156                log::error!("Failed to log request: {}", e);
157            }
158
159            Ok(res)
160        })
161    }
162}