shaperail_runtime/observability/
middleware.rs1use std::collections::HashSet;
2use std::future::{ready, Future, Ready};
3use std::pin::Pin;
4use std::sync::Arc;
5use std::time::Instant;
6
7use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform};
8use actix_web::web;
9use actix_web::Error;
10
11use super::metrics::MetricsState;
12
/// Actix-web middleware factory that emits one structured log line per request and,
/// when a [`MetricsState`] is registered as app data, records request/error metrics.
///
/// The sensitive-field set is shared (via `Arc`) with every middleware instance this
/// factory creates. NOTE(review): the request path in this file does not consult it
/// yet — confirm whether redaction is still planned.
#[derive(Clone, Debug, Default)]
pub struct RequestLogger {
    /// Names of fields considered sensitive; shared read-only across middleware instances.
    sensitive_fields: Arc<HashSet<String>>,
}

impl RequestLogger {
    /// Creates a logger factory that will hand `sensitive_fields` to every middleware
    /// instance it builds.
    #[must_use]
    pub fn new(sensitive_fields: HashSet<String>) -> Self {
        Self {
            sensitive_fields: Arc::new(sensitive_fields),
        }
    }
}
29
30impl<S, B> Transform<S, ServiceRequest> for RequestLogger
31where
32 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
33 B: 'static,
34{
35 type Response = ServiceResponse<B>;
36 type Error = Error;
37 type Transform = RequestLoggerMiddleware<S>;
38 type InitError = ();
39 type Future = Ready<Result<Self::Transform, Self::InitError>>;
40
41 fn new_transform(&self, service: S) -> Self::Future {
42 ready(Ok(RequestLoggerMiddleware {
43 service,
44 _sensitive_fields: self.sensitive_fields.clone(),
45 }))
46 }
47}
48
/// Per-service middleware produced by [`RequestLogger`]; forwards requests to `S`.
pub struct RequestLoggerMiddleware<S> {
    /// The wrapped (inner) service every request is forwarded to.
    service: S,
    /// Sensitive-field names shared with the factory.
    // NOTE(review): not read by the request path (hence the leading underscore).
    _sensitive_fields: Arc<HashSet<String>>,
}
53
54impl<S, B> Service<ServiceRequest> for RequestLoggerMiddleware<S>
55where
56 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
57 B: 'static,
58{
59 type Response = ServiceResponse<B>;
60 type Error = Error;
61 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
62
63 forward_ready!(service);
64
65 fn call(&self, mut req: ServiceRequest) -> Self::Future {
66 let start = Instant::now();
67 let method = req.method().to_string();
68 let path = req.path().to_string();
69 let request_id = uuid::Uuid::new_v4().to_string();
70 let metrics = req.app_data::<web::Data<MetricsState>>().cloned();
71
72 let user_id = req
74 .headers()
75 .get("x-request-user-id")
76 .and_then(|v| v.to_str().ok())
77 .map(String::from);
78
79 req.headers_mut().insert(
81 actix_web::http::header::HeaderName::from_static("x-request-id"),
82 actix_web::http::header::HeaderValue::from_str(&request_id)
83 .unwrap_or_else(|_| actix_web::http::header::HeaderValue::from_static("unknown")),
84 );
85
86 let fut = self.service.call(req);
87
88 Box::pin(async move {
89 let res = fut.await?;
90
91 let status = res.status().as_u16();
92 let duration = start.elapsed();
93 let duration_ms = duration.as_millis() as u64;
94
95 if let Some(metrics) = metrics {
96 metrics.record_request(&method, &path, status, duration.as_secs_f64());
97 if status >= 400 {
98 metrics.record_error(&format!("http_{status}"));
99 }
100 }
101
102 tracing::info!(
103 request_id = %request_id,
104 method = %method,
105 path = %path,
106 status = status,
107 duration_ms = duration_ms,
108 user_id = user_id.as_deref().unwrap_or("-"),
109 "request completed"
110 );
111
112 Ok(res)
113 })
114 }
115}
116
117pub fn get_request_id(req: &actix_web::HttpRequest) -> String {
119 req.headers()
120 .get("x-request-id")
121 .and_then(|v| v.to_str().ok())
122 .unwrap_or("-")
123 .to_string()
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::test::TestRequest;

    #[test]
    fn request_logger_new() {
        let logger = RequestLogger::new(HashSet::new());
        assert!(
            logger.sensitive_fields.is_empty(),
            "an empty input set must produce an empty sensitive-field set"
        );
    }

    #[test]
    fn get_request_id_missing() {
        let http_req = TestRequest::default().to_http_request();
        assert_eq!(get_request_id(&http_req), "-");
    }

    #[test]
    fn get_request_id_present() {
        let http_req = TestRequest::default()
            .insert_header(("x-request-id", "abc-123"))
            .to_http_request();
        assert_eq!(get_request_id(&http_req), "abc-123");
    }
}