Skip to main content

rustgate/
logging.rs

1use crate::handler::{
2    extract_body_bytes, extract_response_body_bytes, put_body_back, put_response_body_back,
3    BoxBody, Buffered, Dropped, RequestHandler,
4};
5use base64::Engine;
6use bytes::Bytes;
7use hyper::{Request, Response};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::io::Write;
11use std::path::Path;
12use std::sync::atomic::{AtomicU64, Ordering};
13use std::sync::{mpsc, Arc, Mutex};
14use tracing::info;
15
/// Unique ID for pairing a request with its response in the traffic log.
///
/// `TrafficLogHandler::handle_request` inserts it into the request's extensions, and
/// `TrafficLogHandler::handle_response` reads it back from the response's extensions
/// to find the matching pending request.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct LogId(pub u64);
19
/// Upstream target info stored in request extensions for logging.
///
/// `TrafficLogHandler` copies these fields into the logged request's `target_scheme`,
/// `target_host` and `target_port`. When no target is attached they are logged as an
/// empty string / `0`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct UpstreamTarget {
    /// URL scheme used toward the upstream (for example `"https"`).
    pub scheme: String,
    /// Upstream host name or address.
    pub host: String,
    /// Upstream port.
    pub port: u16,
}
27
/// One JSON Lines record: a request paired with its response.
///
/// The response may be synthetic (status `0`, no headers or body) when the request was
/// dropped or its response never arrived. Field order here is the key order of the
/// serialized JSON object, so reordering fields changes the on-disk format.
#[derive(Serialize, Deserialize, Debug)]
pub struct LogEntry {
    /// ID shared by the request and response halves (the value carried in `LogId`).
    pub id: u64,
    /// ISO 8601 UTC time at which the request was recorded.
    pub timestamp_req: String,
    /// ISO 8601 UTC time at which the response (or synthetic response) was recorded.
    pub timestamp_res: String,
    /// Captured request half.
    pub request: LoggedRequest,
    /// Captured (or synthetic) response half.
    pub response: LoggedResponse,
}
36
/// Request half of a `LogEntry`, captured after the inner handler has run.
#[derive(Serialize, Deserialize, Debug)]
pub struct LoggedRequest {
    /// HTTP method, e.g. `GET`.
    pub method: String,
    /// Request path plus query, with every query value replaced by `<redacted>`.
    pub uri: String,
    /// `Debug` rendering of the HTTP version, e.g. `HTTP/1.1`.
    pub version: String,
    /// Upstream scheme; `#[serde(default)]` lets records without it deserialize as "".
    #[serde(default)]
    pub target_scheme: String,
    /// Upstream host; defaults to "" when absent from a record.
    #[serde(default)]
    pub target_host: String,
    /// Upstream port; defaults to `0` when absent from a record.
    #[serde(default)]
    pub target_port: u16,
    /// `(name, value)` pairs; values of headers outside the safe list are `<redacted>`.
    pub headers: Vec<(String, String)>,
    /// Body as text when it was captured and is valid UTF-8; omitted from JSON if `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub body: Option<String>,
    /// Body as standard base64 when it was captured but is not valid UTF-8; omitted if `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub body_base64: Option<String>,
    /// True when a body was expected but could not be captured.
    pub body_truncated: bool,
}
55
/// Response half of a `LogEntry`, captured after the inner handler has run.
#[derive(Serialize, Deserialize, Debug)]
pub struct LoggedResponse {
    /// HTTP status code, or `0` for a dropped or synthetic response.
    pub status: u16,
    /// `Debug` rendering of the HTTP version; empty for synthetic responses.
    pub version: String,
    /// `(name, value)` pairs with unsafe header values redacted; empty when dropped.
    pub headers: Vec<(String, String)>,
    /// Body as text when it was captured and is valid UTF-8; omitted from JSON if `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub body: Option<String>,
    /// Body as standard base64 when it was captured but is not valid UTF-8; omitted if `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub body_base64: Option<String>,
    /// True when a body was expected but not captured, and always for dropped/synthetic responses.
    pub body_truncated: bool,
}
67
68struct PendingLogEntry {
69    created_at: std::time::Instant,
70    timestamp_req: String,
71    request: LoggedRequest,
72}
73
/// Format the current system time as an ISO 8601 UTC string with millisecond
/// precision, e.g. `2026-04-11T12:00:00.123Z`.
///
/// A clock set before the Unix epoch is clamped to the epoch instead of panicking.
fn format_timestamp() -> String {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    format_epoch_duration(since_epoch)
}

/// Format a duration since the Unix epoch as `YYYY-MM-DDTHH:MM:SS.mmmZ` (UTC).
///
/// Kept separate from `format_timestamp` so the formatting itself is deterministic
/// and testable.
fn format_epoch_duration(since_epoch: std::time::Duration) -> String {
    const SECS_PER_DAY: u64 = 86_400;

    let secs = since_epoch.as_secs();
    let millis = since_epoch.subsec_millis();

    let time_of_day = secs % SECS_PER_DAY;
    let hours = time_of_day / 3600;
    let minutes = (time_of_day % 3600) / 60;
    let seconds = time_of_day % 60;

    let (year, month, day) = days_to_ymd(secs / SECS_PER_DAY);

    format!(
        "{year:04}-{month:02}-{day:02}T{hours:02}:{minutes:02}:{seconds:02}.{millis:03}Z"
    )
}

/// Convert whole days since 1970-01-01 into a Gregorian `(year, month, day)` triple,
/// where `month` and `day` are 1-based.
///
/// Walks forward one year at a time, so the cost grows linearly with the year; that
/// is negligible for present-day dates.
fn days_to_ymd(mut days: u64) -> (u64, u64, u64) {
    let mut year = 1970;
    loop {
        let days_in_year = if is_leap(year) { 366 } else { 365 };
        if days < days_in_year {
            break;
        }
        days -= days_in_year;
        year += 1;
    }
    let month_days = if is_leap(year) {
        [31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
    } else {
        [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
    };
    let mut month = 1;
    for &md in &month_days {
        if days < md {
            break;
        }
        days -= md;
        month += 1;
    }
    // `days` is now the 0-based offset within the month.
    (year, month, days + 1)
}

/// Full Gregorian leap-year rule: divisible by 4, except centuries that are not
/// divisible by 400.
fn is_leap(y: u64) -> bool {
    (y.is_multiple_of(4) && !y.is_multiple_of(100)) || y.is_multiple_of(400)
}
126
127/// Encode body bytes for logging.
128/// Returns (body_text, body_base64, body_truncated).
129/// `may_have_body` is true if the message had Content-Length or Transfer-Encoding,
130/// indicating a body was expected but may not have been captured.
131fn encode_body(
132    bytes: &Bytes,
133    is_buffered: bool,
134    may_have_body: bool,
135) -> (Option<String>, Option<String>, bool) {
136    if !is_buffered || bytes.is_empty() {
137        // truncated if body was expected but we couldn't buffer it
138        let truncated = !is_buffered && may_have_body;
139        return (None, None, truncated);
140    }
141    match std::str::from_utf8(bytes) {
142        Ok(text) => (Some(text.to_string()), None, false),
143        Err(_) => {
144            let b64 = base64::engine::general_purpose::STANDARD.encode(bytes);
145            (None, Some(b64), false)
146        }
147    }
148}
149
150/// Safe headers that are logged verbatim. All others are redacted to prevent
151/// credential persistence (Authorization, Cookie, vendor API keys, etc.).
152const SAFE_LOG_HEADERS: &[&str] = &[
153    "accept", "accept-encoding", "accept-language", "cache-control",
154    "connection", "content-encoding", "content-language", "content-length",
155    "content-type", "date", "etag", "expires", "host", "if-match",
156    "if-modified-since", "if-none-match", "if-unmodified-since",
157    "last-modified", "location", "pragma", "range", "server",
158    "transfer-encoding", "user-agent", "vary", "via",
159    "access-control-allow-origin", "access-control-allow-methods",
160    "access-control-allow-headers", "access-control-max-age",
161    "x-content-type-options", "x-frame-options", "x-request-id",
162    "strict-transport-security", "content-security-policy",
163];
164
165fn capture_headers(headers: &hyper::HeaderMap) -> Vec<(String, String)> {
166    headers
167        .iter()
168        .map(|(name, value)| {
169            let val = if SAFE_LOG_HEADERS.iter().any(|h| name.as_str().eq_ignore_ascii_case(h)) {
170                value.to_str().unwrap_or("<binary>").to_string()
171            } else {
172                "<redacted>".to_string()
173            };
174            (name.to_string(), val)
175        })
176        .collect()
177}
178
179/// Redact query parameter values in a URI to prevent credential persistence.
180/// `/path?key=secret&token=xxx` → `/path?key=<redacted>&token=<redacted>`
181fn redact_query_values(uri: &hyper::Uri) -> String {
182    let path = uri.path();
183    match uri.query() {
184        None => path.to_string(),
185        Some(query) => {
186            let redacted: Vec<String> = query
187                .split('&')
188                .map(|pair| {
189                    if let Some((key, _)) = pair.split_once('=') {
190                        format!("{key}=<redacted>")
191                    } else {
192                        pair.to_string()
193                    }
194                })
195                .collect();
196            format!("{path}?{}", redacted.join("&"))
197        }
198    }
199}
200
201/// Background writer thread that receives LogEntry values and writes JSON Lines.
202struct LogWriter {
203    rx: mpsc::Receiver<LogEntry>,
204    file: std::io::BufWriter<std::fs::File>,
205}
206
207impl LogWriter {
208    fn run(mut self) {
209        while let Ok(entry) = self.rx.recv() {
210            match serde_json::to_string(&entry) {
211                Ok(json) => {
212                    if let Err(e) = writeln!(self.file, "{json}") {
213                        eprintln!("[rustgate] Traffic log write error: {e}");
214                    }
215                    if let Err(e) = self.file.flush() {
216                        eprintln!("[rustgate] Traffic log flush error: {e}");
217                    }
218                }
219                Err(e) => {
220                    eprintln!("[rustgate] Traffic log serialize error: {e}");
221                }
222            }
223        }
224    }
225}
226
/// Decorator handler that logs traffic to a JSON Lines file.
/// Wraps any inner RequestHandler.
pub struct TrafficLogHandler {
    /// Wrapped handler; it runs before logging, so the log reflects its modifications.
    inner: Arc<dyn RequestHandler>,
    /// Bounded channel feeding the background `LogWriter` thread.
    tx: mpsc::SyncSender<LogEntry>,
    /// Source of log IDs; the first ID handed out is 1.
    next_id: AtomicU64,
    /// Requests awaiting their response, keyed by log ID.
    pending: Mutex<HashMap<u64, PendingLogEntry>>,
}
235
236impl TrafficLogHandler {
237    pub fn new(
238        inner: Arc<dyn RequestHandler>,
239        path: &Path,
240    ) -> std::io::Result<Self> {
241        // Reject symlinks to prevent writing to unintended locations
242        #[cfg(unix)]
243        if let Ok(meta) = std::fs::symlink_metadata(path) {
244            if meta.file_type().is_symlink() {
245                return Err(std::io::Error::new(
246                    std::io::ErrorKind::InvalidInput,
247                    format!("Refusing to write log to symlink: {}", path.display()),
248                ));
249            }
250        }
251
252        // Create with restricted permissions (owner-only on Unix)
253        #[cfg(unix)]
254        let file = {
255            use std::os::unix::fs::OpenOptionsExt;
256            use std::os::unix::fs::PermissionsExt;
257            let f = std::fs::OpenOptions::new()
258                .create(true)
259                .append(true)
260                .mode(0o600)
261                .open(path)?;
262            // Force 0o600 even if the file already existed with broader permissions
263            f.set_permissions(std::fs::Permissions::from_mode(0o600))?;
264            f
265        };
266        #[cfg(not(unix))]
267        let file = std::fs::OpenOptions::new()
268            .create(true)
269            .append(true)
270            .open(path)?;
271        let writer = std::io::BufWriter::new(file);
272        let (tx, rx) = mpsc::sync_channel(256);
273
274        std::thread::spawn(move || {
275            LogWriter { rx, file: writer }.run();
276        });
277
278        info!("Traffic logging to {}", path.display());
279
280        Ok(Self {
281            inner,
282            tx,
283            next_id: AtomicU64::new(1),
284            pending: Mutex::new(HashMap::new()),
285        })
286    }
287}
288
289impl RequestHandler for TrafficLogHandler {
290    fn handle_request(&self, req: &mut Request<BoxBody>) {
291        // Let inner handler process first (e.g., InterceptHandler may modify/drop)
292        self.inner.handle_request(req);
293
294        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
295        let is_buffered = req.extensions().get::<Buffered>().is_some();
296        let is_dropped = req.extensions().get::<Dropped>().is_some();
297
298        // Capture request data (after inner handler's modifications)
299        let body_bytes = if is_buffered && !is_dropped {
300            let b = extract_body_bytes(req);
301            put_body_back(req, b.clone());
302            b
303        } else {
304            Bytes::new()
305        };
306
307        let may_have_body = req.headers().contains_key(hyper::header::CONTENT_LENGTH)
308            || req.headers().contains_key(hyper::header::TRANSFER_ENCODING);
309        let (body, body_base64, body_truncated) = encode_body(&body_bytes, is_buffered, may_have_body);
310
311        let upstream = req.extensions().get::<UpstreamTarget>().cloned();
312        let logged_req = LoggedRequest {
313            method: req.method().to_string(),
314            uri: redact_query_values(req.uri()),
315            version: format!("{:?}", req.version()),
316            target_scheme: upstream.as_ref().map(|t| t.scheme.clone()).unwrap_or_default(),
317            target_host: upstream.as_ref().map(|t| t.host.clone()).unwrap_or_default(),
318            target_port: upstream.as_ref().map(|t| t.port).unwrap_or(0),
319            headers: capture_headers(req.headers()),
320            body,
321            body_base64,
322            body_truncated,
323        };
324
325        // If dropped, emit log entry immediately with synthetic response
326        if is_dropped {
327            let entry = LogEntry {
328                id,
329                timestamp_req: format_timestamp(),
330                timestamp_res: format_timestamp(),
331                request: logged_req,
332                response: LoggedResponse {
333                    status: 0,
334                    version: String::new(),
335                    headers: Vec::new(),
336                    body: None,
337                    body_base64: None,
338                    body_truncated: true,
339                },
340            };
341            if self.tx.try_send(entry).is_err() {
342                tracing::warn!("Traffic log queue full, entry dropped");
343            }
344            return;
345        }
346
347        // Store pending for pairing with response
348        req.extensions_mut().insert(LogId(id));
349        if let Ok(mut pending) = self.pending.lock() {
350            // Expire stale entries (>60s) instead of evicting live ones
351            let now = std::time::Instant::now();
352            let expired: Vec<u64> = pending
353                .iter()
354                .filter(|(_, v)| now.duration_since(v.created_at).as_secs() > 300)
355                .map(|(k, _)| *k)
356                .collect();
357            for eid in &expired {
358                if let Some(stale) = pending.remove(eid) {
359                    tracing::warn!("Expired unpaired log entry {eid} (>300s)");
360                    // Emit synthetic timeout entry
361                    let entry = LogEntry {
362                        id: *eid,
363                        timestamp_req: stale.timestamp_req,
364                        timestamp_res: format_timestamp(),
365                        request: stale.request,
366                        response: LoggedResponse {
367                            status: 0,
368                            version: String::new(),
369                            headers: Vec::new(),
370                            body: None,
371                            body_base64: None,
372                            body_truncated: true,
373                        },
374                    };
375                    if self.tx.try_send(entry).is_err() {
376                        tracing::warn!("Traffic log queue full, expired entry dropped");
377                    }
378                }
379            }
380            pending.insert(id, PendingLogEntry {
381                created_at: now,
382                timestamp_req: format_timestamp(),
383                request: logged_req,
384            });
385        }
386    }
387
388    fn handle_response(&self, res: &mut Response<BoxBody>) {
389        let log_id = res.extensions().get::<LogId>().cloned();
390
391        // Let inner handler process response FIRST (e.g., interceptor may edit/drop)
392        self.inner.handle_response(res);
393
394        // Now capture the final post-interception state for logging
395        let is_buffered = res.extensions().get::<Buffered>().is_some();
396        let is_dropped = res.extensions().get::<Dropped>().is_some();
397
398        let body_bytes = if is_buffered && !is_dropped {
399            let b = extract_response_body_bytes(res);
400            put_response_body_back(res, b.clone());
401            b
402        } else {
403            Bytes::new()
404        };
405
406        let may_have_body = res.headers().contains_key(hyper::header::CONTENT_LENGTH)
407            || res.headers().contains_key(hyper::header::TRANSFER_ENCODING);
408        let (body, body_base64, body_truncated) = encode_body(&body_bytes, is_buffered, may_have_body);
409
410        let logged_res = LoggedResponse {
411            status: if is_dropped { 0 } else { res.status().as_u16() },
412            version: format!("{:?}", res.version()),
413            headers: if is_dropped { Vec::new() } else { capture_headers(res.headers()) },
414            body,
415            body_base64,
416            body_truncated: body_truncated || is_dropped,
417        };
418
419        // Pair with pending request
420        if let Some(LogId(id)) = log_id {
421            let pending_entry = self.pending.lock().ok().and_then(|mut p| p.remove(&id));
422            if let Some(pending) = pending_entry {
423                let entry = LogEntry {
424                    id,
425                    timestamp_req: pending.timestamp_req,
426                    timestamp_res: format_timestamp(),
427                    request: pending.request,
428                    response: logged_res,
429                };
430                if self.tx.try_send(entry).is_err() {
431                tracing::warn!("Traffic log queue full, entry dropped");
432            }
433            }
434        }
435    }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    fn sample_entry() -> LogEntry {
        LogEntry {
            id: 1,
            timestamp_req: "2026-04-11T12:00:00.000Z".into(),
            timestamp_res: "2026-04-11T12:00:00.123Z".into(),
            request: LoggedRequest {
                method: "GET".into(),
                uri: "/api".into(),
                version: "HTTP/1.1".into(),
                target_scheme: "https".into(),
                target_host: "example.com".into(),
                target_port: 443,
                headers: vec![("host".into(), "example.com".into())],
                body: None,
                body_base64: None,
                body_truncated: false,
            },
            response: LoggedResponse {
                status: 200,
                version: "HTTP/1.1".into(),
                headers: vec![("content-type".into(), "application/json".into())],
                body: Some("{\"ok\":true}".into()),
                body_base64: None,
                body_truncated: false,
            },
        }
    }

    #[test]
    fn test_log_entry_serde_roundtrip() {
        let entry = sample_entry();
        let json = serde_json::to_string(&entry).unwrap();
        let parsed: LogEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, 1);
        assert_eq!(parsed.request.method, "GET");
        assert_eq!(parsed.request.target_port, 443);
        assert_eq!(parsed.response.status, 200);
        assert_eq!(parsed.response.body.as_deref(), Some("{\"ok\":true}"));
    }

    #[test]
    fn test_none_bodies_are_omitted_from_json() {
        let json = serde_json::to_string(&sample_entry()).unwrap();
        assert!(!json.contains("\"body_base64\""));
        assert!(json.contains("\"body_truncated\""));
    }

    #[test]
    fn test_legacy_request_without_target_fields() {
        // Records written before the target_* fields existed must still deserialize.
        let legacy = r#"{"method":"GET","uri":"/","version":"HTTP/1.1","headers":[],"body_truncated":false}"#;
        let req: LoggedRequest = serde_json::from_str(legacy).unwrap();
        assert!(req.target_scheme.is_empty());
        assert!(req.target_host.is_empty());
        assert_eq!(req.target_port, 0);
        assert!(req.body.is_none());
    }

    #[test]
    fn test_encode_body_utf8() {
        let bytes = Bytes::from("hello world");
        let (body, b64, trunc) = encode_body(&bytes, true, true);
        assert_eq!(body.unwrap(), "hello world");
        assert!(b64.is_none());
        assert!(!trunc);
    }

    #[test]
    fn test_encode_body_binary() {
        let bytes = Bytes::from(vec![0xFF, 0xFE, 0x00, 0x01]);
        let (body, b64, trunc) = encode_body(&bytes, true, true);
        assert!(body.is_none());
        assert_eq!(b64.as_deref(), Some("//4AAQ=="));
        assert!(!trunc);
    }

    #[test]
    fn test_encode_body_buffered_empty() {
        // Buffered but empty → nothing to log and nothing missing.
        let (body, b64, trunc) = encode_body(&Bytes::new(), true, true);
        assert!(body.is_none());
        assert!(b64.is_none());
        assert!(!trunc);
    }

    #[test]
    fn test_encode_body_not_buffered_with_cl() {
        // Has Content-Length but wasn't buffered → truncated
        let bytes = Bytes::new();
        let (body, b64, trunc) = encode_body(&bytes, false, true);
        assert!(body.is_none());
        assert!(b64.is_none());
        assert!(trunc);
    }

    #[test]
    fn test_encode_body_not_buffered_no_cl() {
        // No Content-Length, not buffered → NOT truncated (bodyless request)
        let bytes = Bytes::new();
        let (body, b64, trunc) = encode_body(&bytes, false, false);
        assert!(body.is_none());
        assert!(b64.is_none());
        assert!(!trunc);
    }

    #[test]
    fn test_format_timestamp() {
        let ts = format_timestamp();
        // YYYY-MM-DDTHH:MM:SS.mmmZ
        assert_eq!(ts.len(), 24);
        assert!(ts.ends_with('Z'));
        assert_eq!(&ts[4..5], "-");
        assert_eq!(&ts[10..11], "T");
        assert_eq!(&ts[19..20], ".");
    }

    #[test]
    fn test_days_to_ymd_leap_boundaries() {
        assert_eq!(days_to_ymd(0), (1970, 1, 1));
        assert_eq!(days_to_ymd(59), (1970, 3, 1));
        assert_eq!(days_to_ymd(11016), (2000, 2, 29));
        assert_eq!(days_to_ymd(11017), (2000, 3, 1));
        assert!(is_leap(2000));
        assert!(is_leap(2024));
        assert!(!is_leap(1900));
        assert!(!is_leap(2023));
    }

    #[test]
    fn test_redact_query_values() {
        let uri: hyper::Uri = "/path?key=secret&token=xxx&flag".parse().unwrap();
        assert_eq!(
            redact_query_values(&uri),
            "/path?key=<redacted>&token=<redacted>&flag"
        );
        let plain: hyper::Uri = "/plain".parse().unwrap();
        assert_eq!(redact_query_values(&plain), "/plain");
    }

    #[test]
    fn test_capture_headers_redacts_unsafe() {
        let mut headers = hyper::HeaderMap::new();
        headers.insert(
            hyper::header::AUTHORIZATION,
            hyper::header::HeaderValue::from_static("Bearer secret"),
        );
        headers.insert(
            hyper::header::CONTENT_TYPE,
            hyper::header::HeaderValue::from_static("text/plain"),
        );
        let captured = capture_headers(&headers);
        assert!(captured.contains(&("authorization".to_string(), "<redacted>".to_string())));
        assert!(captured.contains(&("content-type".to_string(), "text/plain".to_string())));
    }
}