1use super::layer::{BoxedNext, MiddlewareLayer};
7use super::request_id::RequestId;
8use crate::request::Request;
9use crate::response::Response;
10use std::future::Future;
11use std::pin::Pin;
12use std::time::Instant;
13use tracing::{info_span, Instrument, Level};
14
15#[derive(Clone)]
37pub struct TracingLayer {
38 level: Level,
39 custom_fields: Vec<(String, String)>,
40}
41
42impl TracingLayer {
43 pub fn new() -> Self {
45 Self {
46 level: Level::INFO,
47 custom_fields: Vec::new(),
48 }
49 }
50
51 pub fn with_level(level: Level) -> Self {
53 Self {
54 level,
55 custom_fields: Vec::new(),
56 }
57 }
58
59 pub fn with_field(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
71 self.custom_fields.push((key.into(), value.into()));
72 self
73 }
74}
75
76impl Default for TracingLayer {
77 fn default() -> Self {
78 Self::new()
79 }
80}
81
82impl MiddlewareLayer for TracingLayer {
83 fn call(
84 &self,
85 req: Request,
86 next: BoxedNext,
87 ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
88 let level = self.level;
89 let method = req.method().to_string();
90 let path = req.uri().path().to_string();
91 let custom_fields = self.custom_fields.clone();
92
93 let request_id = req
95 .extensions()
96 .get::<RequestId>()
97 .map(|id| id.as_str().to_string())
98 .unwrap_or_else(|| "unknown".to_string());
99
100 Box::pin(async move {
101 let start = Instant::now();
102
103 let span = info_span!(
106 "http_request",
107 method = %method,
108 path = %path,
109 request_id = %request_id,
110 status = tracing::field::Empty,
111 duration_ms = tracing::field::Empty,
112 error = tracing::field::Empty,
113 );
114
115 for (key, value) in &custom_fields {
117 span.record(key.as_str(), value.as_str());
118 }
119
120 let response = async { next(req).await }.instrument(span.clone()).await;
122
123 let duration = start.elapsed();
124 let status = response.status();
125 let status_code = status.as_u16();
126
127 span.record("status", status_code);
129 span.record("duration_ms", duration.as_millis() as u64);
130
131 if status.is_client_error() || status.is_server_error() {
133 span.record("error", true);
134 }
135
136 let _enter = span.enter();
138 if status.is_success() {
139 match level {
140 Level::TRACE => tracing::trace!(
141 method = %method,
142 path = %path,
143 request_id = %request_id,
144 status = %status_code,
145 duration_ms = %duration.as_millis(),
146 "Request completed"
147 ),
148 Level::DEBUG => tracing::debug!(
149 method = %method,
150 path = %path,
151 request_id = %request_id,
152 status = %status_code,
153 duration_ms = %duration.as_millis(),
154 "Request completed"
155 ),
156 Level::INFO => tracing::info!(
157 method = %method,
158 path = %path,
159 request_id = %request_id,
160 status = %status_code,
161 duration_ms = %duration.as_millis(),
162 "Request completed"
163 ),
164 Level::WARN => tracing::warn!(
165 method = %method,
166 path = %path,
167 request_id = %request_id,
168 status = %status_code,
169 duration_ms = %duration.as_millis(),
170 "Request completed"
171 ),
172 Level::ERROR => tracing::error!(
173 method = %method,
174 path = %path,
175 request_id = %request_id,
176 status = %status_code,
177 duration_ms = %duration.as_millis(),
178 "Request completed"
179 ),
180 }
181 } else {
182 tracing::warn!(
183 method = %method,
184 path = %path,
185 request_id = %request_id,
186 status = %status_code,
187 duration_ms = %duration.as_millis(),
188 error = true,
189 "Request failed"
190 );
191 }
192
193 response
194 })
195 }
196
197 fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
198 Box::new(self.clone())
199 }
200}
201
202#[cfg(test)]
203mod tests {
204 use super::*;
205 use crate::middleware::layer::{BoxedNext, LayerStack};
206 use crate::middleware::request_id::RequestIdLayer;
207 use bytes::Bytes;
208 use http::{Extensions, Method, StatusCode};
209 use proptest::prelude::*;
210 use proptest::test_runner::TestCaseError;
211 use std::collections::HashMap;
212 use std::sync::Arc;
213 use tracing_subscriber::layer::SubscriberExt;
214
215 fn create_test_request(method: Method, path: &str) -> crate::request::Request {
217 let uri: http::Uri = path.parse().unwrap();
218 let builder = http::Request::builder().method(method).uri(uri);
219
220 let req = builder.body(()).unwrap();
221 let (parts, _) = req.into_parts();
222
223 crate::request::Request::new(
224 parts,
225 Bytes::new(),
226 Arc::new(Extensions::new()),
227 HashMap::new(),
228 )
229 }
230
231 #[test]
232 fn test_tracing_layer_creation() {
233 let layer = TracingLayer::new();
234 assert_eq!(layer.level, Level::INFO);
235 assert!(layer.custom_fields.is_empty());
236
237 let layer = TracingLayer::with_level(Level::DEBUG);
238 assert_eq!(layer.level, Level::DEBUG);
239 }
240
241 #[test]
242 fn test_tracing_layer_with_custom_fields() {
243 let layer = TracingLayer::new()
244 .with_field("service", "test-api")
245 .with_field("version", "1.0.0");
246
247 assert_eq!(layer.custom_fields.len(), 2);
248 assert_eq!(
249 layer.custom_fields[0],
250 ("service".to_string(), "test-api".to_string())
251 );
252 assert_eq!(
253 layer.custom_fields[1],
254 ("version".to_string(), "1.0.0".to_string())
255 );
256 }
257
258 #[test]
259 fn test_tracing_layer_clone() {
260 let layer = TracingLayer::new().with_field("key", "value");
261
262 let cloned = layer.clone();
263 assert_eq!(cloned.level, layer.level);
264 assert_eq!(cloned.custom_fields, layer.custom_fields);
265 }
266
267 #[derive(Clone)]
269 struct SpanFieldCapture {
270 captured_fields: Arc<std::sync::Mutex<Vec<CapturedSpan>>>,
271 }
272
273 #[derive(Debug, Clone)]
274 struct CapturedSpan {
275 name: String,
276 fields: HashMap<String, String>,
277 }
278
279 impl SpanFieldCapture {
280 fn new() -> Self {
281 Self {
282 captured_fields: Arc::new(std::sync::Mutex::new(Vec::new())),
283 }
284 }
285
286 fn get_spans(&self) -> Vec<CapturedSpan> {
287 self.captured_fields.lock().unwrap().clone()
288 }
289 }
290
291 impl<S> tracing_subscriber::Layer<S> for SpanFieldCapture
292 where
293 S: tracing::Subscriber + for<'lookup> tracing_subscriber::registry::LookupSpan<'lookup>,
294 {
295 fn on_new_span(
296 &self,
297 attrs: &tracing::span::Attributes<'_>,
298 _id: &tracing::span::Id,
299 _ctx: tracing_subscriber::layer::Context<'_, S>,
300 ) {
301 let mut fields = HashMap::new();
302 let mut visitor = FieldVisitor {
303 fields: &mut fields,
304 };
305 attrs.record(&mut visitor);
306
307 let span = CapturedSpan {
308 name: attrs.metadata().name().to_string(),
309 fields,
310 };
311
312 self.captured_fields.lock().unwrap().push(span);
313 }
314
315 fn on_record(
316 &self,
317 id: &tracing::span::Id,
318 values: &tracing::span::Record<'_>,
319 ctx: tracing_subscriber::layer::Context<'_, S>,
320 ) {
321 if let Some(_span) = ctx.span(id) {
322 let mut captured = self.captured_fields.lock().unwrap();
323 if let Some(last_span) = captured.last_mut() {
324 let mut visitor = FieldVisitor {
325 fields: &mut last_span.fields,
326 };
327 values.record(&mut visitor);
328 }
329 }
330 }
331 }
332
333 struct FieldVisitor<'a> {
334 fields: &'a mut HashMap<String, String>,
335 }
336
337 impl<'a> tracing::field::Visit for FieldVisitor<'a> {
338 fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
339 self.fields
340 .insert(field.name().to_string(), format!("{:?}", value));
341 }
342
343 fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
344 self.fields
345 .insert(field.name().to_string(), value.to_string());
346 }
347
348 fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
349 self.fields
350 .insert(field.name().to_string(), value.to_string());
351 }
352
353 fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
354 self.fields
355 .insert(field.name().to_string(), value.to_string());
356 }
357
358 fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
359 self.fields
360 .insert(field.name().to_string(), value.to_string());
361 }
362 }
363
364 proptest! {
372 #![proptest_config(ProptestConfig::with_cases(100))]
373
374 #[test]
375 fn prop_tracing_span_completeness(
376 method_idx in 0usize..5usize,
377 path in "/[a-z]{1,10}(/[a-z]{1,10})?",
378 status_code in 200u16..600u16,
379 custom_key in "[a-z]{3,10}",
380 custom_value in "[a-z0-9]{3,20}",
381 ) {
382 let rt = tokio::runtime::Runtime::new().unwrap();
383 let result: Result<(), TestCaseError> = rt.block_on(async {
384 let capture = SpanFieldCapture::new();
386 let subscriber = tracing_subscriber::registry().with(capture.clone());
387
388 let _guard = tracing::subscriber::set_default(subscriber);
390
391 let mut stack = LayerStack::new();
393 stack.push(Box::new(RequestIdLayer::new()));
394 stack.push(Box::new(TracingLayer::new()
395 .with_field(&custom_key, &custom_value)));
396
397 let methods = [Method::GET, Method::POST, Method::PUT, Method::DELETE, Method::PATCH];
399 let method = methods[method_idx].clone();
400
401 let response_status = StatusCode::from_u16(status_code).unwrap_or(StatusCode::OK);
403 let handler: BoxedNext = Arc::new(move |_req: crate::request::Request| {
404 let status = response_status;
405 Box::pin(async move {
406 http::Response::builder()
407 .status(status)
408 .body(http_body_util::Full::new(Bytes::from("test")))
409 .unwrap()
410 }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
411 });
412
413 let request = create_test_request(method.clone(), &path);
415 let response = stack.execute(request, handler).await;
416
417 prop_assert_eq!(response.status(), response_status);
419
420 let spans = capture.get_spans();
422 let http_span = spans.iter().find(|s| s.name == "http_request");
423
424 prop_assert!(http_span.is_some(), "Should have created an http_request span");
425 let span = http_span.unwrap();
426
427 prop_assert!(
430 span.fields.contains_key("method"),
431 "Span should contain 'method' field. Fields: {:?}", span.fields
432 );
433 prop_assert_eq!(
434 span.fields.get("method").map(|s| s.trim_matches('"')),
435 Some(method.as_str()),
436 "Method should match request method"
437 );
438
439 prop_assert!(
441 span.fields.contains_key("path"),
442 "Span should contain 'path' field. Fields: {:?}", span.fields
443 );
444 prop_assert_eq!(
445 span.fields.get("path").map(|s| s.trim_matches('"')),
446 Some(path.as_str()),
447 "Path should match request path"
448 );
449
450 prop_assert!(
452 span.fields.contains_key("request_id"),
453 "Span should contain 'request_id' field. Fields: {:?}", span.fields
454 );
455 let request_id = span.fields.get("request_id").unwrap();
456 let request_id_trimmed = request_id.trim_matches('"');
458 prop_assert!(
459 request_id_trimmed == "unknown" || request_id_trimmed.len() == 36,
460 "Request ID should be UUID format or 'unknown', got: {}", request_id
461 );
462
463 prop_assert!(
465 span.fields.contains_key("status"),
466 "Span should contain 'status' field. Fields: {:?}", span.fields
467 );
468 let recorded_status: u16 = span.fields.get("status")
469 .and_then(|s| s.parse().ok())
470 .unwrap_or(0);
471 prop_assert_eq!(
472 recorded_status,
473 status_code,
474 "Status should match response status code"
475 );
476
477 prop_assert!(
479 span.fields.contains_key("duration_ms"),
480 "Span should contain 'duration_ms' field. Fields: {:?}", span.fields
481 );
482 let duration: u64 = span.fields.get("duration_ms")
483 .and_then(|s| s.parse().ok())
484 .unwrap_or(u64::MAX);
485 prop_assert!(
486 duration < 10000, "Duration should be reasonable, got: {} ms", duration
488 );
489
490 if response_status.is_client_error() || response_status.is_server_error() {
492 prop_assert!(
493 span.fields.contains_key("error"),
494 "Span should contain 'error' field for error responses. Fields: {:?}", span.fields
495 );
496 }
497
498 Ok(())
499 });
500 result?;
501 }
502 }
503
504 #[test]
505 fn test_tracing_layer_records_request_id() {
506 let rt = tokio::runtime::Runtime::new().unwrap();
507 rt.block_on(async {
508 let capture = SpanFieldCapture::new();
509 let subscriber = tracing_subscriber::registry().with(capture.clone());
510 let _guard = tracing::subscriber::set_default(subscriber);
511
512 let mut stack = LayerStack::new();
513 stack.push(Box::new(RequestIdLayer::new()));
514 stack.push(Box::new(TracingLayer::new()));
515
516 let handler: BoxedNext = Arc::new(|_req: crate::request::Request| {
517 Box::pin(async {
518 http::Response::builder()
519 .status(StatusCode::OK)
520 .body(http_body_util::Full::new(Bytes::from("ok")))
521 .unwrap()
522 }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
523 });
524
525 let request = create_test_request(Method::GET, "/test");
526 let _response = stack.execute(request, handler).await;
527
528 let spans = capture.get_spans();
529 let http_span = spans.iter().find(|s| s.name == "http_request");
530 assert!(http_span.is_some(), "Should have http_request span");
531
532 let span = http_span.unwrap();
533 assert!(
534 span.fields.contains_key("request_id"),
535 "Should have request_id field"
536 );
537 });
538 }
539
540 #[test]
541 fn test_tracing_layer_records_error_for_failures() {
542 let rt = tokio::runtime::Runtime::new().unwrap();
543 rt.block_on(async {
544 let capture = SpanFieldCapture::new();
545 let subscriber = tracing_subscriber::registry().with(capture.clone());
546 let _guard = tracing::subscriber::set_default(subscriber);
547
548 let mut stack = LayerStack::new();
549 stack.push(Box::new(TracingLayer::new()));
550
551 let handler: BoxedNext = Arc::new(|_req: crate::request::Request| {
552 Box::pin(async {
553 http::Response::builder()
554 .status(StatusCode::INTERNAL_SERVER_ERROR)
555 .body(http_body_util::Full::new(Bytes::from("error")))
556 .unwrap()
557 }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
558 });
559
560 let request = create_test_request(Method::GET, "/test");
561 let response = stack.execute(request, handler).await;
562
563 assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
564
565 let spans = capture.get_spans();
566 let http_span = spans.iter().find(|s| s.name == "http_request");
567 assert!(http_span.is_some(), "Should have http_request span");
568
569 let span = http_span.unwrap();
570 assert!(
571 span.fields.contains_key("error"),
572 "Should have error field for 5xx response"
573 );
574 });
575 }
576}