1use crate::core::api_key::ApiKey;
2use actix_web::{
3 body::EitherBody,
4 dev::{Service, ServiceRequest, ServiceResponse, Transform},
5 Error, HttpResponse,
6};
7use futures_util::future::LocalBoxFuture;
8use std::{
9 collections::{HashMap, VecDeque},
10 future::{ready, Ready},
11 sync::{Arc, Mutex},
12 time::Instant,
13};
14
15pub struct LoggingMiddleware {
16 server_logger: Arc<crate::server::logging::ServerLogger>,
17}
18
19impl LoggingMiddleware {
20 pub fn new(server_logger: Arc<crate::server::logging::ServerLogger>) -> Self {
21 Self { server_logger }
22 }
23}
24
25impl<S, B> Transform<S, ServiceRequest> for LoggingMiddleware
26where
27 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
28 S::Future: 'static,
29 B: 'static,
30{
31 type Response = ServiceResponse<B>;
32 type Error = Error;
33 type InitError = ();
34 type Transform = LoggingMiddlewareService<S>;
35 type Future = Ready<std::result::Result<Self::Transform, Self::InitError>>;
36
37 fn new_transform(&self, service: S) -> Self::Future {
38 ready(Ok(LoggingMiddlewareService {
39 service,
40 server_logger: self.server_logger.clone(),
41 }))
42 }
43}
44
45pub struct LoggingMiddlewareService<S> {
46 service: S,
47 server_logger: Arc<crate::server::logging::ServerLogger>,
48}
49
50impl<S, B> Service<ServiceRequest> for LoggingMiddlewareService<S>
51where
52 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
53 S::Future: 'static,
54 B: 'static,
55{
56 type Response = ServiceResponse<B>;
57 type Error = Error;
58 type Future = LocalBoxFuture<'static, std::result::Result<Self::Response, Self::Error>>;
59
60 actix_web::dev::forward_ready!(service);
61
62 fn call(&self, req: ServiceRequest) -> Self::Future {
63 let start_time = Instant::now();
64 let server_logger = self.server_logger.clone();
65
66 let ip = {
67 let connection_info = req.connection_info();
68 connection_info
69 .realip_remote_addr()
70 .or_else(|| connection_info.peer_addr())
71 .unwrap_or("unknown")
72 .split(':')
73 .next()
74 .unwrap_or("unknown")
75 .to_string()
76 };
77
78 let path = req.path().to_string();
79 let method = req.method().to_string();
80 let query_string = req.query_string().to_string();
81
82 let suspicious = is_suspicious_path(&path);
83
84 if suspicious {
85 let logger_clone = server_logger.clone();
86 let ip_clone = ip.clone();
87 let path_clone = path.clone();
88 tokio::spawn(async move {
89 let _ = logger_clone
90 .log_security_alert(
91 &ip_clone,
92 "Suspicious Request",
93 &format!("Suspicious path: {}", path_clone),
94 )
95 .await;
96 });
97 }
98
99 let headers: std::collections::HashMap<String, String> = req
100 .headers()
101 .iter()
102 .filter_map(|(name, value)| {
103 let header_name = name.as_str().to_lowercase();
104 if !["authorization", "cookie", "x-api-key"].contains(&header_name.as_str()) {
105 value
106 .to_str()
107 .ok()
108 .map(|v| (name.as_str().to_string(), v.to_string()))
109 } else {
110 Some((name.as_str().to_string(), "[FILTERED]".to_string()))
111 }
112 })
113 .collect();
114
115 let fut = self.service.call(req);
116
117 Box::pin(async move {
118 let res = fut.await?;
119 let response_time = start_time.elapsed().as_millis() as u64;
120 let status = res.status().as_u16();
121 let bytes_sent = res
122 .response()
123 .headers()
124 .get("content-length")
125 .and_then(|h| h.to_str().ok())
126 .and_then(|s| s.parse().ok())
127 .unwrap_or(0);
128
129 let analytics_path = path.clone();
131 let analytics_ip = ip.clone();
132 let analytics_ua = headers.get("user-agent").cloned().unwrap_or_default();
133
134 let entry = crate::server::logging::ServerLogEntry {
135 timestamp: chrono::Local::now()
136 .format("%Y-%m-%d %H:%M:%S%.3f")
137 .to_string(),
138 timestamp_unix: std::time::SystemTime::now()
139 .duration_since(std::time::UNIX_EPOCH)
140 .unwrap_or_default()
141 .as_secs(),
142 event_type: crate::server::logging::LogEventType::Request,
143 ip_address: ip,
144 user_agent: headers.get("user-agent").cloned(),
145 method,
146 path,
147 status_code: Some(status),
148 response_time_ms: Some(response_time),
149 bytes_sent: Some(bytes_sent),
150 referer: headers.get("referer").cloned(),
151 query_string: if query_string.is_empty() {
152 None
153 } else {
154 Some(query_string)
155 },
156 headers,
157 session_id: None,
158 };
159
160 if let Err(e) = server_logger.write_log_entry(entry).await {
161 log::error!("Failed to log request: {}", e);
162 }
163
164 crate::server::analytics::track_request("", &analytics_path, &analytics_ip, &analytics_ua);
165
166 Ok(res)
167 })
168 }
169}
170
/// Decodes `%XX` escapes in `input` so the result can be inspected for attacks.
///
/// - Only two ASCII hex digits form an escape. Truncated escapes, non-hex
///   digits and signs (e.g. `%+5`) are copied through verbatim.
/// - Decoded bytes are reassembled as UTF-8, so `%C3%A9` becomes `é`.
///   Invalid sequences become U+FFFD.
/// - Works on bytes and never slices the `str`, so it cannot panic on
///   non-ASCII input.
fn percent_decode(input: &str) -> String {
    // Strict hex-digit value; rejects signs and anything non-ASCII.
    fn hex_value(byte: u8) -> Option<u8> {
        match byte {
            b'0'..=b'9' => Some(byte - b'0'),
            b'a'..=b'f' => Some(byte - b'a' + 10),
            b'A'..=b'F' => Some(byte - b'A' + 10),
            _ => None,
        }
    }

    let bytes = input.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'%' {
            let hi = bytes.get(i + 1).copied().and_then(hex_value);
            let lo = bytes.get(i + 2).copied().and_then(hex_value);
            if let (Some(hi), Some(lo)) = (hi, lo) {
                decoded.push((hi << 4) | lo);
                i += 3;
                continue;
            }
        }
        decoded.push(bytes[i]);
        i += 1;
    }
    String::from_utf8_lossy(&decoded).into_owned()
}
188
189fn is_suspicious_path(path: &str) -> bool {
190 let decoded = percent_decode(path);
191 let normalized = decoded.replace('\\', "/");
192 let lower = normalized.to_lowercase();
193
194 normalized.contains("..")
195 || lower.contains("<script")
196 || lower.contains("union select")
197 || lower.contains("drop table")
198 || path.len() > 1000
199}
200
201#[derive(Clone)]
206pub struct ApiKeyAuth {
207 api_key: ApiKey,
208}
209
210impl ApiKeyAuth {
211 pub fn new(api_key: ApiKey) -> Self {
212 Self { api_key }
213 }
214}
215
216impl<S, B> Transform<S, ServiceRequest> for ApiKeyAuth
217where
218 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
219 S::Future: 'static,
220 B: 'static,
221{
222 type Response = ServiceResponse<EitherBody<B>>;
223 type Error = Error;
224 type InitError = ();
225 type Transform = ApiKeyAuthService<S>;
226 type Future = Ready<std::result::Result<Self::Transform, Self::InitError>>;
227
228 fn new_transform(&self, service: S) -> Self::Future {
229 ready(Ok(ApiKeyAuthService {
230 service,
231 api_key: self.api_key.clone(),
232 }))
233 }
234}
235
236pub struct ApiKeyAuthService<S> {
237 service: S,
238 api_key: ApiKey,
239}
240
241impl<S, B> Service<ServiceRequest> for ApiKeyAuthService<S>
242where
243 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
244 S::Future: 'static,
245 B: 'static,
246{
247 type Response = ServiceResponse<EitherBody<B>>;
248 type Error = Error;
249 type Future = LocalBoxFuture<'static, std::result::Result<Self::Response, Self::Error>>;
250
251 actix_web::dev::forward_ready!(service);
252
253 fn call(&self, req: ServiceRequest) -> Self::Future {
254 let path = req.path().to_string();
255
256 let is_public_asset = path == "/.rss/"
257 || path == "/.rss/_reset.css"
258 || path == "/.rss/style.css"
259 || path == "/.rss/favicon.svg"
260 || path.starts_with("/.rss/js/")
261 || path.starts_with("/.rss/fonts/")
262 || path == "/ws/hot-reload";
263
264 let needs_auth =
265 (path.starts_with("/api/") || path.starts_with("/.rss/") || path.starts_with("/ws/"))
266 && path != "/api/health"
267 && !path.starts_with("/api/acme/")
268 && !path.starts_with("/api/analytics")
269 && !path.starts_with("/.well-known/")
270 && !is_public_asset;
271
272 if !needs_auth || self.api_key.is_empty() {
274 let fut = self.service.call(req);
275 return Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) });
276 }
277
278 let header_key = req
280 .headers()
281 .get("x-api-key")
282 .and_then(|v| v.to_str().ok())
283 .map(|s| s.to_string());
284
285 let query_key = req
287 .query_string()
288 .split('&')
289 .find_map(|param| param.strip_prefix("api_key="))
290 .map(|s| s.to_string());
291
292 let provided_key = header_key.or(query_key);
293
294 let is_valid = provided_key
295 .as_deref()
296 .map(|k| self.api_key.verify(k))
297 .unwrap_or(false);
298 if is_valid {
299 let fut = self.service.call(req);
300 Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) })
301 } else {
302 let response = HttpResponse::Unauthorized()
303 .json(serde_json::json!({
304 "error": "Unauthorized",
305 "message": "Valid API key required. Provide via X-API-Key header or ?api_key= query parameter."
306 }));
307 Box::pin(async move { Ok(req.into_response(response).map_into_right_body()) })
308 }
309 }
310}
311
/// Middleware factory for a per-client-IP, one-second sliding-window rate limit.
///
/// Clones, and every service produced from them, share the same timestamp map
/// through the `Arc`.
#[derive(Clone)]
pub struct RateLimiter {
    // Maximum requests allowed per client within the window.
    max_rps: u32,
    // When false, every request is passed through unchecked.
    enabled: bool,
    // Client IP -> timestamps of its recently admitted requests.
    clients: Arc<Mutex<HashMap<String, VecDeque<Instant>>>>,
}

impl RateLimiter {
    /// Creates a limiter with an empty client table.
    pub fn new(max_rps: u32, enabled: bool) -> Self {
        RateLimiter {
            clients: Arc::default(),
            enabled,
            max_rps,
        }
    }
}
332
333impl<S, B> Transform<S, ServiceRequest> for RateLimiter
334where
335 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
336 S::Future: 'static,
337 B: 'static,
338{
339 type Response = ServiceResponse<EitherBody<B>>;
340 type Error = Error;
341 type InitError = ();
342 type Transform = RateLimiterService<S>;
343 type Future = Ready<std::result::Result<Self::Transform, Self::InitError>>;
344
345 fn new_transform(&self, service: S) -> Self::Future {
346 ready(Ok(RateLimiterService {
347 service,
348 max_rps: self.max_rps,
349 enabled: self.enabled,
350 clients: self.clients.clone(),
351 }))
352 }
353}
354
/// Service produced by `RateLimiter::new_transform`.
///
/// The client table is the same `Arc` held by the factory.
pub struct RateLimiterService<S> {
    // Next service in the chain.
    service: S,
    // Maximum requests allowed per client within the window.
    max_rps: u32,
    // When false, every request is forwarded unchecked.
    enabled: bool,
    // Client IP -> timestamps of its recently admitted requests.
    clients: Arc<Mutex<HashMap<String, VecDeque<Instant>>>>,
}
361
362impl<S, B> Service<ServiceRequest> for RateLimiterService<S>
363where
364 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
365 S::Future: 'static,
366 B: 'static,
367{
368 type Response = ServiceResponse<EitherBody<B>>;
369 type Error = Error;
370 type Future = LocalBoxFuture<'static, std::result::Result<Self::Response, Self::Error>>;
371
372 actix_web::dev::forward_ready!(service);
373
374 fn call(&self, req: ServiceRequest) -> Self::Future {
375 if !self.enabled || !req.path().starts_with("/api/") {
377 let fut = self.service.call(req);
378 return Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) });
379 }
380
381 let ip = {
382 let connection_info = req.connection_info();
383 connection_info
384 .realip_remote_addr()
385 .or_else(|| connection_info.peer_addr())
386 .unwrap_or("unknown")
387 .split(':')
388 .next()
389 .unwrap_or("unknown")
390 .to_string()
391 };
392
393 let now = Instant::now();
394 let one_second_ago = now - std::time::Duration::from_secs(1);
395
396 let is_limited = if let Ok(mut clients) = self.clients.lock() {
397 let timestamps = clients.entry(ip).or_insert_with(VecDeque::new);
398
399 while timestamps.front().is_some_and(|t| *t < one_second_ago) {
401 timestamps.pop_front();
402 }
403
404 if timestamps.len() >= self.max_rps as usize {
405 true
406 } else {
407 timestamps.push_back(now);
408 false
409 }
410 } else {
411 false };
413
414 if is_limited {
415 let response = HttpResponse::TooManyRequests()
416 .insert_header(("Retry-After", "1"))
417 .json(serde_json::json!({
418 "error": "Too Many Requests",
419 "message": "Rate limit exceeded. Try again later.",
420 "retry_after": 1
421 }));
422 Box::pin(async move { Ok(req.into_response(response).map_into_right_body()) })
423 } else {
424 let fut = self.service.call(req);
425 Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) })
426 }
427 }
428}
429
#[cfg(test)]
mod tests {
    use super::*;

    // --- percent_decode ---

    #[test]
    fn test_percent_decode_plain() {
        let input = "/api/status";
        assert_eq!(percent_decode(input), input);
    }

    #[test]
    fn test_percent_decode_encoded_slash() {
        let decoded = percent_decode("%2F");
        assert_eq!(decoded, "/");
    }

    #[test]
    fn test_percent_decode_dot_dot() {
        let decoded = percent_decode("%2e%2e");
        assert_eq!(decoded, "..");
    }

    #[test]
    fn test_percent_decode_mixed() {
        let decoded = percent_decode("/foo%2Fbar%2E%2E%2Fbaz");
        assert_eq!(decoded, "/foo/bar../baz");
    }

    #[test]
    fn test_percent_decode_incomplete_sequence() {
        // A trailing `%` with only one digit is left untouched.
        let input = "abc%2";
        assert_eq!(percent_decode(input), input);
    }

    #[test]
    fn test_percent_decode_invalid_hex() {
        let input = "%ZZ";
        assert_eq!(percent_decode(input), input);
    }

    #[test]
    fn test_percent_decode_empty() {
        assert!(percent_decode("").is_empty());
    }

    #[test]
    fn test_percent_decode_script_tag() {
        let decoded = percent_decode("%3Cscript%3E");
        assert_eq!(decoded, "<script>");
    }

    // --- is_suspicious_path ---

    #[test]
    fn test_suspicious_path_traversal() {
        for path in ["/../etc/passwd", "/foo/../../etc/shadow"] {
            assert!(is_suspicious_path(path), "expected suspicious: {path}");
        }
    }

    #[test]
    fn test_suspicious_path_encoded_traversal() {
        for path in ["/%2e%2e/etc/passwd", "/%2E%2E/secret"] {
            assert!(is_suspicious_path(path), "expected suspicious: {path}");
        }
    }

    #[test]
    fn test_suspicious_path_backslash_traversal() {
        let path = "/foo\\..\\etc\\passwd";
        assert!(is_suspicious_path(path), "expected suspicious: {path}");
    }

    #[test]
    fn test_suspicious_path_script_injection() {
        for path in ["/<script>alert(1)</script>", "/%3Cscript%3Ealert(1)"] {
            assert!(is_suspicious_path(path), "expected suspicious: {path}");
        }
    }

    #[test]
    fn test_suspicious_path_sql_injection() {
        for path in [
            "/api?q=1 UNION SELECT * FROM users",
            "/api?q=DROP TABLE users",
        ] {
            assert!(is_suspicious_path(path), "expected suspicious: {path}");
        }
    }

    #[test]
    fn test_suspicious_path_too_long() {
        let long_path = format!("/{}", "a".repeat(1001));
        assert!(is_suspicious_path(&long_path));
    }

    #[test]
    fn test_safe_paths() {
        let safe = [
            "/",
            "/api/status",
            "/index.html",
            "/.rss/style.css",
            "/api/logs?offset=100",
            "/ws/hot-reload",
        ];
        for path in safe {
            assert!(!is_suspicious_path(path), "expected safe: {path}");
        }
    }

    #[test]
    fn test_safe_path_with_dots_in_filename() {
        for path in ["/file.name.html", "/.rss/favicon.svg"] {
            assert!(!is_suspicious_path(path), "expected safe: {path}");
        }
    }
}