Skip to main content

shaperail_runtime/observability/
middleware.rs

1use std::collections::HashSet;
2use std::future::{ready, Future, Ready};
3use std::pin::Pin;
4use std::sync::Arc;
5use std::time::Instant;
6
7use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform};
8use actix_web::web;
9use actix_web::Error;
10
11use super::metrics::MetricsState;
12
/// Actix-web middleware factory that emits one structured `tracing` event per
/// handled request.
///
/// The event carries `request_id`, `method`, `path`, `status`, `duration_ms`
/// and `user_id`. Request and response bodies are never read by the logger.
///
/// The set of sensitive field names is shared (via [`Arc`]) with every
/// middleware instance created from this factory, so cloning the logger is
/// cheap. NOTE(review): in this file the set is only stored, never consulted.
#[derive(Clone)]
pub struct RequestLogger {
    /// Names of schema fields considered sensitive.
    sensitive_fields: Arc<HashSet<String>>,
}
21
22impl RequestLogger {
23    pub fn new(sensitive_fields: HashSet<String>) -> Self {
24        Self {
25            sensitive_fields: Arc::new(sensitive_fields),
26        }
27    }
28}
29
30impl<S, B> Transform<S, ServiceRequest> for RequestLogger
31where
32    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
33    B: 'static,
34{
35    type Response = ServiceResponse<B>;
36    type Error = Error;
37    type Transform = RequestLoggerMiddleware<S>;
38    type InitError = ();
39    type Future = Ready<Result<Self::Transform, Self::InitError>>;
40
41    fn new_transform(&self, service: S) -> Self::Future {
42        ready(Ok(RequestLoggerMiddleware {
43            service,
44            _sensitive_fields: self.sensitive_fields.clone(),
45        }))
46    }
47}
48
/// Service wrapper produced by [`RequestLogger`]; sits in front of the
/// downstream service `S` and observes each request passing through it.
pub struct RequestLoggerMiddleware<S> {
    /// The wrapped downstream service that actually handles the request.
    service: S,
    /// Shared set of sensitive field names. The leading underscore marks it
    /// as intentionally unused: it is carried along but not read in this file.
    _sensitive_fields: Arc<HashSet<String>>,
}
53
54impl<S, B> Service<ServiceRequest> for RequestLoggerMiddleware<S>
55where
56    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
57    B: 'static,
58{
59    type Response = ServiceResponse<B>;
60    type Error = Error;
61    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
62
63    forward_ready!(service);
64
65    fn call(&self, mut req: ServiceRequest) -> Self::Future {
66        let start = Instant::now();
67        let method = req.method().to_string();
68        let path = req.path().to_string();
69        let request_id = uuid::Uuid::new_v4().to_string();
70        let metrics = req.app_data::<web::Data<MetricsState>>().cloned();
71
72        // Extract user_id from request extensions if auth middleware has run
73        let user_id = req
74            .headers()
75            .get("x-request-user-id")
76            .and_then(|v| v.to_str().ok())
77            .map(String::from);
78
79        // Insert request_id header for downstream use
80        req.headers_mut().insert(
81            actix_web::http::header::HeaderName::from_static("x-request-id"),
82            actix_web::http::header::HeaderValue::from_str(&request_id)
83                .unwrap_or_else(|_| actix_web::http::header::HeaderValue::from_static("unknown")),
84        );
85
86        let fut = self.service.call(req);
87
88        Box::pin(async move {
89            let res = fut.await?;
90
91            let status = res.status().as_u16();
92            let duration = start.elapsed();
93            let duration_ms = duration.as_millis() as u64;
94
95            if let Some(metrics) = metrics {
96                metrics.record_request(&method, &path, status, duration.as_secs_f64());
97                if status >= 400 {
98                    metrics.record_error(&format!("http_{status}"));
99                }
100            }
101
102            tracing::info!(
103                request_id = %request_id,
104                method = %method,
105                path = %path,
106                status = status,
107                duration_ms = duration_ms,
108                user_id = user_id.as_deref().unwrap_or("-"),
109                "request completed"
110            );
111
112            Ok(res)
113        })
114    }
115}
116
117/// Extracts the request_id from the request headers.
118pub fn get_request_id(req: &actix_web::HttpRequest) -> String {
119    req.headers()
120        .get("x-request-id")
121        .and_then(|v| v.to_str().ok())
122        .unwrap_or("-")
123        .to_string()
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    /// A logger built from an empty set stores an empty set.
    #[test]
    fn request_logger_new() {
        let logger = RequestLogger::new(HashSet::new());
        assert!(logger.sensitive_fields.is_empty());
    }

    /// Without an `x-request-id` header the placeholder `"-"` is returned.
    #[test]
    fn get_request_id_missing() {
        let request = actix_web::test::TestRequest::default().to_http_request();
        let id = get_request_id(&request);
        assert_eq!(id, "-");
    }

    /// An existing `x-request-id` header is returned verbatim.
    #[test]
    fn get_request_id_present() {
        let builder =
            actix_web::test::TestRequest::default().insert_header(("x-request-id", "abc-123"));
        let request = builder.to_http_request();
        let id = get_request_id(&request);
        assert_eq!(id, "abc-123");
    }
}