1use crate::handler::{
2 extract_body_bytes, extract_response_body_bytes, put_body_back, put_response_body_back,
3 BoxBody, Buffered, Dropped, RequestHandler,
4};
5use base64::Engine;
6use bytes::Bytes;
7use hyper::{Request, Response};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::io::Write;
11use std::path::Path;
12use std::sync::atomic::{AtomicU64, Ordering};
13use std::sync::{mpsc, Arc, Mutex};
14use tracing::info;
15
/// Correlation id that `TrafficLogHandler::handle_request` stores in a request's extensions
/// and `handle_response` reads back from the response's extensions to pair the two halves
/// of a log entry.
///
/// NOTE(review): pairing only works if the surrounding proxy carries this extension over
/// from the request to the response — confirm at the call site.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct LogId(pub u64);
19
/// Upstream destination of a proxied request.
///
/// When present in a request's extensions, `TrafficLogHandler` copies these fields into
/// `LoggedRequest::target_scheme` / `target_host` / `target_port`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct UpstreamTarget {
    /// URI scheme of the upstream, as a string.
    pub scheme: String,
    /// Upstream host name or address.
    pub host: String,
    /// Upstream TCP port.
    pub port: u16,
}
27
/// One line of the traffic log: a request paired with its response, serialized as JSON.
///
/// Field order here determines key order in the written JSON.
#[derive(Serialize, Deserialize, Debug)]
pub struct LogEntry {
    /// Id assigned by `TrafficLogHandler` from a counter starting at 1.
    pub id: u64,
    /// UTC timestamp (from `format_timestamp`) taken when the request was recorded.
    pub timestamp_req: String,
    /// UTC timestamp taken when the response was recorded, or when the entry was written
    /// without a real response (dropped or expired request).
    pub timestamp_res: String,
    /// The request half of the exchange.
    pub request: LoggedRequest,
    /// The response half; `status == 0` marks a placeholder for a missing response.
    pub response: LoggedResponse,
}
36
/// Request half of a [`LogEntry`], with sensitive parts redacted.
///
/// Field order and serde attributes define the serialized schema.
#[derive(Serialize, Deserialize, Debug)]
pub struct LoggedRequest {
    /// HTTP method, e.g. `GET`.
    pub method: String,
    /// Path plus query, with query values redacted (see `redact_query_values`).
    pub uri: String,
    /// `Debug` rendering of the request's HTTP version.
    pub version: String,
    // The three target fields come from the `UpstreamTarget` extension (empty / 0 when it
    // is absent). `#[serde(default)]` lets log lines without these keys still deserialize.
    #[serde(default)]
    pub target_scheme: String,
    #[serde(default)]
    pub target_host: String,
    #[serde(default)]
    pub target_port: u16,
    /// Header name/value pairs; values of non-allowlisted headers are `<redacted>`.
    pub headers: Vec<(String, String)>,
    /// Body text when the buffered body was valid UTF-8; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub body: Option<String>,
    /// Base64 of the buffered body when it was not valid UTF-8; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub body_base64: Option<String>,
    /// True when a body was expected but not captured.
    pub body_truncated: bool,
}
55
/// Response half of a [`LogEntry`].
///
/// A placeholder with `status == 0`, empty `version`/`headers`, and `body_truncated == true`
/// is written when no real response was paired with the request.
#[derive(Serialize, Deserialize, Debug)]
pub struct LoggedResponse {
    /// HTTP status code, or 0 for a dropped/missing response.
    pub status: u16,
    /// `Debug` rendering of the response's HTTP version (empty in placeholders).
    pub version: String,
    /// Header name/value pairs; values of non-allowlisted headers are `<redacted>`.
    pub headers: Vec<(String, String)>,
    /// Body text when the buffered body was valid UTF-8; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub body: Option<String>,
    /// Base64 of the buffered body when it was not valid UTF-8; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub body_base64: Option<String>,
    /// True when a body was expected but not captured, or the response was dropped.
    pub body_truncated: bool,
}
67
68struct PendingLogEntry {
69 created_at: std::time::Instant,
70 timestamp_req: String,
71 request: LoggedRequest,
72}
73
74fn format_timestamp() -> String {
76 let d = std::time::SystemTime::now()
77 .duration_since(std::time::UNIX_EPOCH)
78 .unwrap_or_default();
79 let secs = d.as_secs();
80 let millis = d.subsec_millis();
81
82 let days = secs / 86400;
84 let time_secs = secs % 86400;
85 let hours = time_secs / 3600;
86 let minutes = (time_secs % 3600) / 60;
87 let seconds = time_secs % 60;
88
89 let (year, month, day) = days_to_ymd(days);
91
92 format!(
93 "{year:04}-{month:02}-{day:02}T{hours:02}:{minutes:02}:{seconds:02}.{millis:03}Z"
94 )
95}
96
97fn days_to_ymd(mut days: u64) -> (u64, u64, u64) {
98 let mut year = 1970;
99 loop {
100 let days_in_year = if is_leap(year) { 366 } else { 365 };
101 if days < days_in_year {
102 break;
103 }
104 days -= days_in_year;
105 year += 1;
106 }
107 let month_days = if is_leap(year) {
108 [31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
109 } else {
110 [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
111 };
112 let mut month = 1;
113 for &md in &month_days {
114 if days < md {
115 break;
116 }
117 days -= md;
118 month += 1;
119 }
120 (year, month, days + 1)
121}
122
/// Gregorian leap-year test: centuries are leap years only when divisible by 400;
/// all other years are leap years when divisible by 4.
fn is_leap(y: u64) -> bool {
    if y.is_multiple_of(100) {
        y.is_multiple_of(400)
    } else {
        y.is_multiple_of(4)
    }
}
126
127fn encode_body(
132 bytes: &Bytes,
133 is_buffered: bool,
134 may_have_body: bool,
135) -> (Option<String>, Option<String>, bool) {
136 if !is_buffered || bytes.is_empty() {
137 let truncated = !is_buffered && may_have_body;
139 return (None, None, truncated);
140 }
141 match std::str::from_utf8(bytes) {
142 Ok(text) => (Some(text.to_string()), None, false),
143 Err(_) => {
144 let b64 = base64::engine::general_purpose::STANDARD.encode(bytes);
145 (None, Some(b64), false)
146 }
147 }
148}
149
/// Lowercase names of headers whose values may be written to the traffic log verbatim.
///
/// Every other header's value is replaced with `<redacted>` by `capture_headers`, which
/// compares names ASCII case-insensitively.
const SAFE_LOG_HEADERS: &[&str] = &[
    // General request/response headers.
    "accept",
    "accept-encoding",
    "accept-language",
    "cache-control",
    "connection",
    "content-encoding",
    "content-language",
    "content-length",
    "content-type",
    "date",
    "etag",
    "expires",
    "host",
    "if-match",
    "if-modified-since",
    "if-none-match",
    "if-unmodified-since",
    "last-modified",
    "location",
    "pragma",
    "range",
    "server",
    "transfer-encoding",
    "user-agent",
    "vary",
    "via",
    // CORS.
    "access-control-allow-origin",
    "access-control-allow-methods",
    "access-control-allow-headers",
    "access-control-max-age",
    // Security policy and request correlation.
    "x-content-type-options",
    "x-frame-options",
    "x-request-id",
    "strict-transport-security",
    "content-security-policy",
];
164
165fn capture_headers(headers: &hyper::HeaderMap) -> Vec<(String, String)> {
166 headers
167 .iter()
168 .map(|(name, value)| {
169 let val = if SAFE_LOG_HEADERS.iter().any(|h| name.as_str().eq_ignore_ascii_case(h)) {
170 value.to_str().unwrap_or("<binary>").to_string()
171 } else {
172 "<redacted>".to_string()
173 };
174 (name.to_string(), val)
175 })
176 .collect()
177}
178
179fn redact_query_values(uri: &hyper::Uri) -> String {
182 let path = uri.path();
183 match uri.query() {
184 None => path.to_string(),
185 Some(query) => {
186 let redacted: Vec<String> = query
187 .split('&')
188 .map(|pair| {
189 if let Some((key, _)) = pair.split_once('=') {
190 format!("{key}=<redacted>")
191 } else {
192 pair.to_string()
193 }
194 })
195 .collect();
196 format!("{path}?{}", redacted.join("&"))
197 }
198 }
199}
200
/// Background consumer that serializes queued [`LogEntry`] values to the log file as one
/// JSON object per line (see `LogWriter::run`).
struct LogWriter {
    /// Receiving end of the bounded queue fed by `TrafficLogHandler`; `run` exits once all
    /// senders are gone.
    rx: mpsc::Receiver<LogEntry>,
    /// Buffered handle to the log file; `run` flushes it after every entry.
    file: std::io::BufWriter<std::fs::File>,
}
206
207impl LogWriter {
208 fn run(mut self) {
209 while let Ok(entry) = self.rx.recv() {
210 match serde_json::to_string(&entry) {
211 Ok(json) => {
212 if let Err(e) = writeln!(self.file, "{json}") {
213 eprintln!("[rustgate] Traffic log write error: {e}");
214 }
215 if let Err(e) = self.file.flush() {
216 eprintln!("[rustgate] Traffic log flush error: {e}");
217 }
218 }
219 Err(e) => {
220 eprintln!("[rustgate] Traffic log serialize error: {e}");
221 }
222 }
223 }
224 }
225}
226
/// [`RequestHandler`] decorator that delegates to `inner` and records every
/// request/response pair to a JSON-lines traffic log via a background writer thread.
pub struct TrafficLogHandler {
    /// Wrapped handler; it runs before logging in both directions.
    inner: Arc<dyn RequestHandler>,
    /// Bounded queue to the writer thread. Entries are dropped with a warning instead of
    /// blocking when it is full. Dropping the last sender ends the writer loop.
    tx: mpsc::SyncSender<LogEntry>,
    /// Source of log ids, starting at 1.
    next_id: AtomicU64,
    /// Requests awaiting their response, keyed by log id. Entries older than 300 s are
    /// written out without a response.
    pending: Mutex<HashMap<u64, PendingLogEntry>>,
}
235
236impl TrafficLogHandler {
237 pub fn new(
238 inner: Arc<dyn RequestHandler>,
239 path: &Path,
240 ) -> std::io::Result<Self> {
241 #[cfg(unix)]
243 if let Ok(meta) = std::fs::symlink_metadata(path) {
244 if meta.file_type().is_symlink() {
245 return Err(std::io::Error::new(
246 std::io::ErrorKind::InvalidInput,
247 format!("Refusing to write log to symlink: {}", path.display()),
248 ));
249 }
250 }
251
252 #[cfg(unix)]
254 let file = {
255 use std::os::unix::fs::OpenOptionsExt;
256 use std::os::unix::fs::PermissionsExt;
257 let f = std::fs::OpenOptions::new()
258 .create(true)
259 .append(true)
260 .mode(0o600)
261 .open(path)?;
262 f.set_permissions(std::fs::Permissions::from_mode(0o600))?;
264 f
265 };
266 #[cfg(not(unix))]
267 let file = std::fs::OpenOptions::new()
268 .create(true)
269 .append(true)
270 .open(path)?;
271 let writer = std::io::BufWriter::new(file);
272 let (tx, rx) = mpsc::sync_channel(256);
273
274 std::thread::spawn(move || {
275 LogWriter { rx, file: writer }.run();
276 });
277
278 info!("Traffic logging to {}", path.display());
279
280 Ok(Self {
281 inner,
282 tx,
283 next_id: AtomicU64::new(1),
284 pending: Mutex::new(HashMap::new()),
285 })
286 }
287}
288
289impl RequestHandler for TrafficLogHandler {
290 fn handle_request(&self, req: &mut Request<BoxBody>) {
291 self.inner.handle_request(req);
293
294 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
295 let is_buffered = req.extensions().get::<Buffered>().is_some();
296 let is_dropped = req.extensions().get::<Dropped>().is_some();
297
298 let body_bytes = if is_buffered && !is_dropped {
300 let b = extract_body_bytes(req);
301 put_body_back(req, b.clone());
302 b
303 } else {
304 Bytes::new()
305 };
306
307 let may_have_body = req.headers().contains_key(hyper::header::CONTENT_LENGTH)
308 || req.headers().contains_key(hyper::header::TRANSFER_ENCODING);
309 let (body, body_base64, body_truncated) = encode_body(&body_bytes, is_buffered, may_have_body);
310
311 let upstream = req.extensions().get::<UpstreamTarget>().cloned();
312 let logged_req = LoggedRequest {
313 method: req.method().to_string(),
314 uri: redact_query_values(req.uri()),
315 version: format!("{:?}", req.version()),
316 target_scheme: upstream.as_ref().map(|t| t.scheme.clone()).unwrap_or_default(),
317 target_host: upstream.as_ref().map(|t| t.host.clone()).unwrap_or_default(),
318 target_port: upstream.as_ref().map(|t| t.port).unwrap_or(0),
319 headers: capture_headers(req.headers()),
320 body,
321 body_base64,
322 body_truncated,
323 };
324
325 if is_dropped {
327 let entry = LogEntry {
328 id,
329 timestamp_req: format_timestamp(),
330 timestamp_res: format_timestamp(),
331 request: logged_req,
332 response: LoggedResponse {
333 status: 0,
334 version: String::new(),
335 headers: Vec::new(),
336 body: None,
337 body_base64: None,
338 body_truncated: true,
339 },
340 };
341 if self.tx.try_send(entry).is_err() {
342 tracing::warn!("Traffic log queue full, entry dropped");
343 }
344 return;
345 }
346
347 req.extensions_mut().insert(LogId(id));
349 if let Ok(mut pending) = self.pending.lock() {
350 let now = std::time::Instant::now();
352 let expired: Vec<u64> = pending
353 .iter()
354 .filter(|(_, v)| now.duration_since(v.created_at).as_secs() > 300)
355 .map(|(k, _)| *k)
356 .collect();
357 for eid in &expired {
358 if let Some(stale) = pending.remove(eid) {
359 tracing::warn!("Expired unpaired log entry {eid} (>300s)");
360 let entry = LogEntry {
362 id: *eid,
363 timestamp_req: stale.timestamp_req,
364 timestamp_res: format_timestamp(),
365 request: stale.request,
366 response: LoggedResponse {
367 status: 0,
368 version: String::new(),
369 headers: Vec::new(),
370 body: None,
371 body_base64: None,
372 body_truncated: true,
373 },
374 };
375 if self.tx.try_send(entry).is_err() {
376 tracing::warn!("Traffic log queue full, expired entry dropped");
377 }
378 }
379 }
380 pending.insert(id, PendingLogEntry {
381 created_at: now,
382 timestamp_req: format_timestamp(),
383 request: logged_req,
384 });
385 }
386 }
387
388 fn handle_response(&self, res: &mut Response<BoxBody>) {
389 let log_id = res.extensions().get::<LogId>().cloned();
390
391 self.inner.handle_response(res);
393
394 let is_buffered = res.extensions().get::<Buffered>().is_some();
396 let is_dropped = res.extensions().get::<Dropped>().is_some();
397
398 let body_bytes = if is_buffered && !is_dropped {
399 let b = extract_response_body_bytes(res);
400 put_response_body_back(res, b.clone());
401 b
402 } else {
403 Bytes::new()
404 };
405
406 let may_have_body = res.headers().contains_key(hyper::header::CONTENT_LENGTH)
407 || res.headers().contains_key(hyper::header::TRANSFER_ENCODING);
408 let (body, body_base64, body_truncated) = encode_body(&body_bytes, is_buffered, may_have_body);
409
410 let logged_res = LoggedResponse {
411 status: if is_dropped { 0 } else { res.status().as_u16() },
412 version: format!("{:?}", res.version()),
413 headers: if is_dropped { Vec::new() } else { capture_headers(res.headers()) },
414 body,
415 body_base64,
416 body_truncated: body_truncated || is_dropped,
417 };
418
419 if let Some(LogId(id)) = log_id {
421 let pending_entry = self.pending.lock().ok().and_then(|mut p| p.remove(&id));
422 if let Some(pending) = pending_entry {
423 let entry = LogEntry {
424 id,
425 timestamp_req: pending.timestamp_req,
426 timestamp_res: format_timestamp(),
427 request: pending.request,
428 response: logged_res,
429 };
430 if self.tx.try_send(entry).is_err() {
431 tracing::warn!("Traffic log queue full, entry dropped");
432 }
433 }
434 }
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_log_entry_serde_roundtrip() {
        let entry = LogEntry {
            id: 1,
            timestamp_req: "2026-04-11T12:00:00.000Z".into(),
            timestamp_res: "2026-04-11T12:00:00.123Z".into(),
            request: LoggedRequest {
                method: "GET".into(),
                uri: "/api".into(),
                version: "HTTP/1.1".into(),
                target_scheme: "https".into(),
                target_host: "example.com".into(),
                target_port: 443,
                headers: vec![("host".into(), "example.com".into())],
                body: None,
                body_base64: None,
                body_truncated: false,
            },
            response: LoggedResponse {
                status: 200,
                version: "HTTP/1.1".into(),
                headers: vec![("content-type".into(), "application/json".into())],
                body: Some("{\"ok\":true}".into()),
                body_base64: None,
                body_truncated: false,
            },
        };
        let json = serde_json::to_string(&entry).unwrap();
        // `None` bodies are omitted entirely rather than serialized as `null`.
        assert!(!json.contains("body_base64"));
        let parsed: LogEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, 1);
        assert_eq!(parsed.request.method, "GET");
        assert_eq!(parsed.request.target_port, 443);
        assert_eq!(parsed.response.status, 200);
        assert_eq!(parsed.response.body.as_deref(), Some("{\"ok\":true}"));
    }

    #[test]
    fn test_logged_request_missing_target_fields_default() {
        let json = r#"{"method":"GET","uri":"/","version":"HTTP/1.1","headers":[],"body_truncated":false}"#;
        let parsed: LoggedRequest = serde_json::from_str(json).unwrap();
        assert!(parsed.target_scheme.is_empty());
        assert!(parsed.target_host.is_empty());
        assert_eq!(parsed.target_port, 0);
        assert!(parsed.body.is_none());
        assert!(parsed.body_base64.is_none());
    }

    #[test]
    fn test_encode_body_utf8() {
        let bytes = Bytes::from("hello world");
        let (body, b64, trunc) = encode_body(&bytes, true, true);
        assert_eq!(body.unwrap(), "hello world");
        assert!(b64.is_none());
        assert!(!trunc);
    }

    #[test]
    fn test_encode_body_binary() {
        let bytes = Bytes::from(vec![0xFF, 0xFE, 0x00, 0x01]);
        let (body, b64, trunc) = encode_body(&bytes, true, true);
        assert!(body.is_none());
        assert_eq!(b64.as_deref(), Some("//4AAQ=="));
        assert!(!trunc);
    }

    #[test]
    fn test_encode_body_buffered_empty() {
        let bytes = Bytes::new();
        let (body, b64, trunc) = encode_body(&bytes, true, true);
        assert!(body.is_none());
        assert!(b64.is_none());
        assert!(!trunc);
    }

    #[test]
    fn test_encode_body_not_buffered_with_cl() {
        // Not buffered but a body was announced: flagged as truncated.
        let bytes = Bytes::new();
        let (body, b64, trunc) = encode_body(&bytes, false, true);
        assert!(body.is_none());
        assert!(b64.is_none());
        assert!(trunc);
    }

    #[test]
    fn test_encode_body_not_buffered_no_cl() {
        // Not buffered and no body announced: nothing was lost.
        let bytes = Bytes::new();
        let (body, b64, trunc) = encode_body(&bytes, false, false);
        assert!(body.is_none());
        assert!(b64.is_none());
        assert!(!trunc);
    }

    #[test]
    fn test_format_timestamp() {
        let ts = format_timestamp();
        // YYYY-MM-DDTHH:MM:SS.mmmZ
        assert_eq!(ts.len(), 24);
        let b = ts.as_bytes();
        assert_eq!(b[4], b'-');
        assert_eq!(b[7], b'-');
        assert_eq!(b[10], b'T');
        assert_eq!(b[13], b':');
        assert_eq!(b[16], b':');
        assert_eq!(b[19], b'.');
        assert_eq!(b[23], b'Z');
    }

    #[test]
    fn test_days_to_ymd() {
        assert_eq!(days_to_ymd(0), (1970, 1, 1));
        assert_eq!(days_to_ymd(364), (1970, 12, 31));
        assert_eq!(days_to_ymd(789), (1972, 2, 29));
        assert_eq!(days_to_ymd(790), (1972, 3, 1));
        assert_eq!(days_to_ymd(11016), (2000, 2, 29));
        assert_eq!(days_to_ymd(19782), (2024, 2, 29));
        assert_eq!(days_to_ymd(20554), (2026, 4, 11));
    }

    #[test]
    fn test_is_leap() {
        assert!(is_leap(2000));
        assert!(!is_leap(1900));
        assert!(!is_leap(2100));
        assert!(is_leap(2024));
        assert!(!is_leap(2023));
    }

    #[test]
    fn test_redact_query_values() {
        let uri: hyper::Uri = "/search?q=secret&page=2&flag".parse().unwrap();
        assert_eq!(redact_query_values(&uri), "/search?q=<redacted>&page=<redacted>&flag");
        let plain: hyper::Uri = "/plain".parse().unwrap();
        assert_eq!(redact_query_values(&plain), "/plain");
    }

    #[test]
    fn test_capture_headers_redacts_unlisted() {
        let mut headers = hyper::HeaderMap::new();
        headers.insert(
            hyper::header::CONTENT_TYPE,
            hyper::header::HeaderValue::from_static("text/plain"),
        );
        headers.insert(
            hyper::header::AUTHORIZATION,
            hyper::header::HeaderValue::from_static("Bearer secret"),
        );
        let captured = capture_headers(&headers);
        assert_eq!(captured.len(), 2);
        assert!(captured.contains(&("content-type".to_string(), "text/plain".to_string())));
        assert!(captured.contains(&("authorization".to_string(), "<redacted>".to_string())));
    }
}