1use super::layer::{BoxedNext, MiddlewareLayer};
7use super::request_id::RequestId;
8use crate::request::Request;
9use crate::response::Response;
10use std::future::Future;
11use std::pin::Pin;
12use std::time::Instant;
13use tracing::{info_span, Instrument, Level};
14
15#[derive(Clone)]
37pub struct TracingLayer {
38 level: Level,
39 custom_fields: Vec<(String, String)>,
40}
41
42impl TracingLayer {
43 pub fn new() -> Self {
45 Self {
46 level: Level::INFO,
47 custom_fields: Vec::new(),
48 }
49 }
50
51 pub fn with_level(level: Level) -> Self {
53 Self {
54 level,
55 custom_fields: Vec::new(),
56 }
57 }
58
59 pub fn with_field(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
71 self.custom_fields.push((key.into(), value.into()));
72 self
73 }
74}
75
76impl Default for TracingLayer {
77 fn default() -> Self {
78 Self::new()
79 }
80}
81
82impl MiddlewareLayer for TracingLayer {
83 fn call(
84 &self,
85 req: Request,
86 next: BoxedNext,
87 ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
88 let level = self.level;
89 let method = req.method().to_string();
90 let path = req.uri().path().to_string();
91 let custom_fields = self.custom_fields.clone();
92
93 let request_id = req
95 .extensions()
96 .get::<RequestId>()
97 .map(|id| id.as_str().to_string())
98 .unwrap_or_else(|| "unknown".to_string());
99
100 Box::pin(async move {
101 let start = Instant::now();
102
103 let span = info_span!(
106 "http_request",
107 method = %method,
108 path = %path,
109 request_id = %request_id,
110 status = tracing::field::Empty,
111 duration_ms = tracing::field::Empty,
112 error = tracing::field::Empty,
113 );
114
115 for (key, value) in &custom_fields {
117 span.record(key.as_str(), value.as_str());
118 }
119
120 let response = async { next(req).await }.instrument(span.clone()).await;
122
123 let duration = start.elapsed();
124 let status = response.status();
125 let status_code = status.as_u16();
126
127 span.record("status", status_code);
129 span.record("duration_ms", duration.as_millis() as u64);
130
131 if status.is_client_error() || status.is_server_error() {
133 span.record("error", true);
134 }
135
136 let _enter = span.enter();
138 if status.is_success() {
139 match level {
140 Level::TRACE => tracing::trace!(
141 method = %method,
142 path = %path,
143 request_id = %request_id,
144 status = %status_code,
145 duration_ms = %duration.as_millis(),
146 "Request completed"
147 ),
148 Level::DEBUG => tracing::debug!(
149 method = %method,
150 path = %path,
151 request_id = %request_id,
152 status = %status_code,
153 duration_ms = %duration.as_millis(),
154 "Request completed"
155 ),
156 Level::INFO => tracing::info!(
157 method = %method,
158 path = %path,
159 request_id = %request_id,
160 status = %status_code,
161 duration_ms = %duration.as_millis(),
162 "Request completed"
163 ),
164 Level::WARN => tracing::warn!(
165 method = %method,
166 path = %path,
167 request_id = %request_id,
168 status = %status_code,
169 duration_ms = %duration.as_millis(),
170 "Request completed"
171 ),
172 Level::ERROR => tracing::error!(
173 method = %method,
174 path = %path,
175 request_id = %request_id,
176 status = %status_code,
177 duration_ms = %duration.as_millis(),
178 "Request completed"
179 ),
180 }
181 } else {
182 tracing::warn!(
183 method = %method,
184 path = %path,
185 request_id = %request_id,
186 status = %status_code,
187 duration_ms = %duration.as_millis(),
188 error = true,
189 "Request failed"
190 );
191 }
192
193 response
194 })
195 }
196
197 fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
198 Box::new(self.clone())
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::middleware::layer::{BoxedNext, LayerStack};
    use crate::middleware::request_id::RequestIdLayer;
    use crate::path_params::PathParams;
    use bytes::Bytes;
    use http::{Extensions, Method, StatusCode};
    use proptest::prelude::*;
    use proptest::test_runner::TestCaseError;
    use std::collections::HashMap;
    use std::sync::Arc;
    use tracing_subscriber::layer::SubscriberExt;

    /// Builds a crate `Request` with the given method and path, an empty buffered body,
    /// fresh extensions and no path parameters.
    fn create_test_request(method: Method, path: &str) -> crate::request::Request {
        let uri: http::Uri = path.parse().unwrap();
        let builder = http::Request::builder().method(method).uri(uri);

        let req = builder.body(()).unwrap();
        let (parts, _) = req.into_parts();

        crate::request::Request::new(
            parts,
            crate::request::BodyVariant::Buffered(Bytes::new()),
            Arc::new(Extensions::new()),
            PathParams::new(),
        )
    }

    #[test]
    fn test_tracing_layer_creation() {
        let layer = TracingLayer::new();
        assert_eq!(layer.level, Level::INFO);
        assert!(layer.custom_fields.is_empty());

        let layer = TracingLayer::with_level(Level::DEBUG);
        assert_eq!(layer.level, Level::DEBUG);
    }

    #[test]
    fn test_tracing_layer_with_custom_fields() {
        let layer = TracingLayer::new()
            .with_field("service", "test-api")
            .with_field("version", "1.0.0");

        assert_eq!(layer.custom_fields.len(), 2);
        assert_eq!(
            layer.custom_fields[0],
            ("service".to_string(), "test-api".to_string())
        );
        assert_eq!(
            layer.custom_fields[1],
            ("version".to_string(), "1.0.0".to_string())
        );
    }

    #[test]
    fn test_tracing_layer_clone() {
        let layer = TracingLayer::new().with_field("key", "value");

        let cloned = layer.clone();
        assert_eq!(cloned.level, layer.level);
        assert_eq!(cloned.custom_fields, layer.custom_fields);
    }

    /// Test-only `tracing_subscriber` layer that records every span's name and field values.
    #[derive(Clone)]
    struct SpanFieldCapture {
        captured_fields: Arc<std::sync::Mutex<Vec<CapturedSpan>>>,
    }

    /// Snapshot of one span: the subscriber-assigned ID, the span name, and its fields
    /// rendered as strings (values recorded later are merged in by `on_record`).
    #[derive(Debug, Clone)]
    struct CapturedSpan {
        id: tracing::span::Id,
        name: String,
        fields: HashMap<String, String>,
    }

    impl SpanFieldCapture {
        fn new() -> Self {
            Self {
                captured_fields: Arc::new(std::sync::Mutex::new(Vec::new())),
            }
        }

        /// Returns a copy of every span captured so far, in creation order.
        fn get_spans(&self) -> Vec<CapturedSpan> {
            self.captured_fields.lock().unwrap().clone()
        }
    }

    impl<S> tracing_subscriber::Layer<S> for SpanFieldCapture
    where
        S: tracing::Subscriber + for<'lookup> tracing_subscriber::registry::LookupSpan<'lookup>,
    {
        fn on_new_span(
            &self,
            attrs: &tracing::span::Attributes<'_>,
            id: &tracing::span::Id,
            _ctx: tracing_subscriber::layer::Context<'_, S>,
        ) {
            let mut fields = HashMap::new();
            let mut visitor = FieldVisitor {
                fields: &mut fields,
            };
            attrs.record(&mut visitor);

            let span = CapturedSpan {
                id: id.clone(),
                name: attrs.metadata().name().to_string(),
                fields,
            };

            self.captured_fields.lock().unwrap().push(span);
        }

        fn on_record(
            &self,
            id: &tracing::span::Id,
            values: &tracing::span::Record<'_>,
            _ctx: tracing_subscriber::layer::Context<'_, S>,
        ) {
            let mut captured = self.captured_fields.lock().unwrap();
            // Attribute the values to the span being recorded, not merely the newest span.
            // IDs can be reused after a span closes, so search newest-first to hit the live one.
            if let Some(target) = captured.iter_mut().rev().find(|span| &span.id == id) {
                let mut visitor = FieldVisitor {
                    fields: &mut target.fields,
                };
                values.record(&mut visitor);
            }
        }
    }

    /// Field visitor that stringifies each recorded value into `fields`, keyed by field name.
    struct FieldVisitor<'a> {
        fields: &'a mut HashMap<String, String>,
    }

    impl<'a> tracing::field::Visit for FieldVisitor<'a> {
        fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
            self.fields
                .insert(field.name().to_string(), format!("{:?}", value));
        }

        fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
            self.fields
                .insert(field.name().to_string(), value.to_string());
        }

        fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
            self.fields
                .insert(field.name().to_string(), value.to_string());
        }

        fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
            self.fields
                .insert(field.name().to_string(), value.to_string());
        }

        fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
            self.fields
                .insert(field.name().to_string(), value.to_string());
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_tracing_span_completeness(
            method_idx in 0usize..5usize,
            path in "/[a-z]{1,10}(/[a-z]{1,10})?",
            status_code in 200u16..600u16,
            custom_key in "[a-z]{3,10}",
            custom_value in "[a-z0-9]{3,20}",
        ) {
            let rt = tokio::runtime::Runtime::new().unwrap();
            let result: Result<(), TestCaseError> = rt.block_on(async {
                let capture = SpanFieldCapture::new();
                let subscriber = tracing_subscriber::registry().with(capture.clone());

                let _guard = tracing::subscriber::set_default(subscriber);

                let mut stack = LayerStack::new();
                stack.push(Box::new(RequestIdLayer::new()));
                stack.push(Box::new(TracingLayer::new()
                    .with_field(&custom_key, &custom_value)));

                let methods = [Method::GET, Method::POST, Method::PUT, Method::DELETE, Method::PATCH];
                let method = methods[method_idx].clone();

                let response_status = StatusCode::from_u16(status_code).unwrap_or(StatusCode::OK);
                let handler: BoxedNext = Arc::new(move |_req: crate::request::Request| {
                    let status = response_status;
                    Box::pin(async move {
                        http::Response::builder()
                            .status(status)
                            .body(http_body_util::Full::new(Bytes::from("test")))
                            .unwrap()
                    }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
                });

                let request = create_test_request(method.clone(), &path);
                let response = stack.execute(request, handler).await;

                prop_assert_eq!(response.status(), response_status);

                let spans = capture.get_spans();
                let http_span = spans.iter().find(|s| s.name == "http_request");

                prop_assert!(http_span.is_some(), "Should have created an http_request span");
                let span = http_span.unwrap();

                prop_assert!(
                    span.fields.contains_key("method"),
                    "Span should contain 'method' field. Fields: {:?}", span.fields
                );
                prop_assert_eq!(
                    span.fields.get("method").map(|s| s.trim_matches('"')),
                    Some(method.as_str()),
                    "Method should match request method"
                );

                prop_assert!(
                    span.fields.contains_key("path"),
                    "Span should contain 'path' field. Fields: {:?}", span.fields
                );
                prop_assert_eq!(
                    span.fields.get("path").map(|s| s.trim_matches('"')),
                    Some(path.as_str()),
                    "Path should match request path"
                );

                prop_assert!(
                    span.fields.contains_key("request_id"),
                    "Span should contain 'request_id' field. Fields: {:?}", span.fields
                );
                let request_id = span.fields.get("request_id").unwrap();
                let request_id_trimmed = request_id.trim_matches('"');
                prop_assert!(
                    request_id_trimmed == "unknown" || request_id_trimmed.len() == 36,
                    "Request ID should be UUID format or 'unknown', got: {}", request_id
                );

                prop_assert!(
                    span.fields.contains_key("status"),
                    "Span should contain 'status' field. Fields: {:?}", span.fields
                );
                let recorded_status: u16 = span.fields.get("status")
                    .and_then(|s| s.parse().ok())
                    .unwrap_or(0);
                prop_assert_eq!(
                    recorded_status,
                    status_code,
                    "Status should match response status code"
                );

                prop_assert!(
                    span.fields.contains_key("duration_ms"),
                    "Span should contain 'duration_ms' field. Fields: {:?}", span.fields
                );
                let duration: u64 = span.fields.get("duration_ms")
                    .and_then(|s| s.parse().ok())
                    .unwrap_or(u64::MAX);
                prop_assert!(
                    duration < 10000, "Duration should be reasonable, got: {} ms", duration
                );

                if response_status.is_client_error() || response_status.is_server_error() {
                    prop_assert!(
                        span.fields.contains_key("error"),
                        "Span should contain 'error' field for error responses. Fields: {:?}", span.fields
                    );
                }

                Ok(())
            });
            result?;
        }
    }

    #[test]
    fn test_tracing_layer_records_request_id() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let capture = SpanFieldCapture::new();
            let subscriber = tracing_subscriber::registry().with(capture.clone());
            let _guard = tracing::subscriber::set_default(subscriber);

            let mut stack = LayerStack::new();
            stack.push(Box::new(RequestIdLayer::new()));
            stack.push(Box::new(TracingLayer::new()));

            let handler: BoxedNext = Arc::new(|_req: crate::request::Request| {
                Box::pin(async {
                    http::Response::builder()
                        .status(StatusCode::OK)
                        .body(http_body_util::Full::new(Bytes::from("ok")))
                        .unwrap()
                }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
            });

            let request = create_test_request(Method::GET, "/test");
            let _response = stack.execute(request, handler).await;

            let spans = capture.get_spans();
            let http_span = spans.iter().find(|s| s.name == "http_request");
            assert!(http_span.is_some(), "Should have http_request span");

            let span = http_span.unwrap();
            assert!(
                span.fields.contains_key("request_id"),
                "Should have request_id field"
            );
        });
    }

    #[test]
    fn test_tracing_layer_records_error_for_failures() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let capture = SpanFieldCapture::new();
            let subscriber = tracing_subscriber::registry().with(capture.clone());
            let _guard = tracing::subscriber::set_default(subscriber);

            let mut stack = LayerStack::new();
            stack.push(Box::new(TracingLayer::new()));

            let handler: BoxedNext = Arc::new(|_req: crate::request::Request| {
                Box::pin(async {
                    http::Response::builder()
                        .status(StatusCode::INTERNAL_SERVER_ERROR)
                        .body(http_body_util::Full::new(Bytes::from("error")))
                        .unwrap()
                }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
            });

            let request = create_test_request(Method::GET, "/test");
            let response = stack.execute(request, handler).await;

            assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);

            let spans = capture.get_spans();
            let http_span = spans.iter().find(|s| s.name == "http_request");
            assert!(http_span.is_some(), "Should have http_request span");

            let span = http_span.unwrap();
            assert!(
                span.fields.contains_key("error"),
                "Should have error field for 5xx response"
            );
        });
    }
}