1use super::layer::{BoxedNext, MiddlewareLayer};
7use super::request_id::RequestId;
8use crate::request::Request;
9use crate::response::Response;
10use std::future::Future;
11use std::pin::Pin;
12use std::time::Instant;
13use tracing::{info_span, Instrument, Level};
14
15#[derive(Clone)]
37pub struct TracingLayer {
38 level: Level,
39 custom_fields: Vec<(String, String)>,
40}
41
42impl TracingLayer {
43 pub fn new() -> Self {
45 Self {
46 level: Level::INFO,
47 custom_fields: Vec::new(),
48 }
49 }
50
51 pub fn with_level(level: Level) -> Self {
53 Self {
54 level,
55 custom_fields: Vec::new(),
56 }
57 }
58
59 pub fn with_field(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
71 self.custom_fields.push((key.into(), value.into()));
72 self
73 }
74}
75
76impl Default for TracingLayer {
77 fn default() -> Self {
78 Self::new()
79 }
80}
81
82impl MiddlewareLayer for TracingLayer {
83 fn call(
84 &self,
85 req: Request,
86 next: BoxedNext,
87 ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
88 let level = self.level;
89 let method = req.method().to_string();
90 let path = req.uri().path().to_string();
91 let custom_fields = self.custom_fields.clone();
92
93 let request_id = req
95 .extensions()
96 .get::<RequestId>()
97 .map(|id| id.as_str().to_string())
98 .unwrap_or_else(|| "unknown".to_string());
99
100 Box::pin(async move {
101 let start = Instant::now();
102
103 let span = info_span!(
106 "http_request",
107 method = %method,
108 path = %path,
109 request_id = %request_id,
110 status = tracing::field::Empty,
111 duration_ms = tracing::field::Empty,
112 error = tracing::field::Empty,
113 );
114
115 for (key, value) in &custom_fields {
117 span.record(key.as_str(), value.as_str());
118 }
119
120 let response = async {
122 next(req).await
123 }
124 .instrument(span.clone())
125 .await;
126
127 let duration = start.elapsed();
128 let status = response.status();
129 let status_code = status.as_u16();
130
131 span.record("status", status_code);
133 span.record("duration_ms", duration.as_millis() as u64);
134
135 if status.is_client_error() || status.is_server_error() {
137 span.record("error", true);
138 }
139
140 let _enter = span.enter();
142 if status.is_success() {
143 match level {
144 Level::TRACE => tracing::trace!(
145 method = %method,
146 path = %path,
147 request_id = %request_id,
148 status = %status_code,
149 duration_ms = %duration.as_millis(),
150 "Request completed"
151 ),
152 Level::DEBUG => tracing::debug!(
153 method = %method,
154 path = %path,
155 request_id = %request_id,
156 status = %status_code,
157 duration_ms = %duration.as_millis(),
158 "Request completed"
159 ),
160 Level::INFO => tracing::info!(
161 method = %method,
162 path = %path,
163 request_id = %request_id,
164 status = %status_code,
165 duration_ms = %duration.as_millis(),
166 "Request completed"
167 ),
168 Level::WARN => tracing::warn!(
169 method = %method,
170 path = %path,
171 request_id = %request_id,
172 status = %status_code,
173 duration_ms = %duration.as_millis(),
174 "Request completed"
175 ),
176 Level::ERROR => tracing::error!(
177 method = %method,
178 path = %path,
179 request_id = %request_id,
180 status = %status_code,
181 duration_ms = %duration.as_millis(),
182 "Request completed"
183 ),
184 }
185 } else {
186 tracing::warn!(
187 method = %method,
188 path = %path,
189 request_id = %request_id,
190 status = %status_code,
191 duration_ms = %duration.as_millis(),
192 error = true,
193 "Request failed"
194 );
195 }
196
197 response
198 })
199 }
200
201 fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
202 Box::new(self.clone())
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::middleware::layer::{BoxedNext, LayerStack};
    use crate::middleware::request_id::RequestIdLayer;
    use bytes::Bytes;
    use http::{Extensions, Method, StatusCode};
    use proptest::prelude::*;
    use proptest::test_runner::TestCaseError;
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};
    use tracing_subscriber::layer::SubscriberExt;

    /// Builds a test request for `method` and `path` with an empty body.
    fn create_test_request(method: Method, path: &str) -> crate::request::Request {
        let uri: http::Uri = path.parse().unwrap();
        let (parts, _) = http::Request::builder()
            .method(method)
            .uri(uri)
            .body(())
            .unwrap()
            .into_parts();

        crate::request::Request::new(
            parts,
            Bytes::new(),
            Arc::new(Extensions::new()),
            HashMap::new(),
        )
    }

    /// Drives `fut` to completion on a single-threaded runtime.
    ///
    /// `tracing::subscriber::set_default` installs a thread-local subscriber, so
    /// keeping every poll on the calling thread guarantees the capture layer sees
    /// all spans.
    fn block_on<F: std::future::Future>(fut: F) -> F::Output {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("failed to build tokio runtime")
            .block_on(fut)
    }

    /// Builds a terminal handler that always responds with `status` and `body`.
    fn respond_with(status: StatusCode, body: &'static str) -> BoxedNext {
        Arc::new(move |_req: crate::request::Request| {
            Box::pin(async move {
                http::Response::builder()
                    .status(status)
                    .body(http_body_util::Full::new(Bytes::from(body)))
                    .unwrap()
            }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
        })
    }

    #[test]
    fn test_tracing_layer_creation() {
        let layer = TracingLayer::new();
        assert_eq!(layer.level, Level::INFO);
        assert!(layer.custom_fields.is_empty());

        let layer = TracingLayer::with_level(Level::DEBUG);
        assert_eq!(layer.level, Level::DEBUG);
    }

    #[test]
    fn test_tracing_layer_with_custom_fields() {
        let layer = TracingLayer::new()
            .with_field("service", "test-api")
            .with_field("version", "1.0.0");

        assert_eq!(layer.custom_fields.len(), 2);
        assert_eq!(layer.custom_fields[0], ("service".to_string(), "test-api".to_string()));
        assert_eq!(layer.custom_fields[1], ("version".to_string(), "1.0.0".to_string()));
    }

    #[test]
    fn test_tracing_layer_clone() {
        let layer = TracingLayer::new().with_field("key", "value");

        let cloned = layer.clone();
        assert_eq!(cloned.level, layer.level);
        assert_eq!(cloned.custom_fields, layer.custom_fields);
    }

    /// `tracing_subscriber` layer that records every span's name and field values.
    #[derive(Clone)]
    struct SpanFieldCapture {
        captured_fields: Arc<Mutex<Vec<CapturedSpan>>>,
    }

    /// Snapshot of one span as seen by [`SpanFieldCapture`].
    #[derive(Debug, Clone)]
    struct CapturedSpan {
        /// Subscriber-assigned id, used to route later `record` calls to this span.
        id: tracing::span::Id,
        name: String,
        fields: HashMap<String, String>,
    }

    impl SpanFieldCapture {
        fn new() -> Self {
            Self {
                captured_fields: Arc::new(Mutex::new(Vec::new())),
            }
        }

        /// Returns a copy of every span captured so far, in creation order.
        fn get_spans(&self) -> Vec<CapturedSpan> {
            self.captured_fields.lock().unwrap().clone()
        }
    }

    impl<S> tracing_subscriber::Layer<S> for SpanFieldCapture
    where
        S: tracing::Subscriber + for<'lookup> tracing_subscriber::registry::LookupSpan<'lookup>,
    {
        fn on_new_span(
            &self,
            attrs: &tracing::span::Attributes<'_>,
            id: &tracing::span::Id,
            _ctx: tracing_subscriber::layer::Context<'_, S>,
        ) {
            let mut fields = HashMap::new();
            attrs.record(&mut FieldVisitor { fields: &mut fields });

            self.captured_fields.lock().unwrap().push(CapturedSpan {
                id: id.clone(),
                name: attrs.metadata().name().to_string(),
                fields,
            });
        }

        fn on_record(
            &self,
            id: &tracing::span::Id,
            values: &tracing::span::Record<'_>,
            _ctx: tracing_subscriber::layer::Context<'_, S>,
        ) {
            let mut captured = self.captured_fields.lock().unwrap();
            // Route the values to the span they belong to rather than to whichever
            // span happened to be created last. Search newest-first because the
            // registry may reuse the ids of spans that have already closed.
            if let Some(span) = captured.iter_mut().rev().find(|span| span.id == *id) {
                values.record(&mut FieldVisitor { fields: &mut span.fields });
            }
        }
    }

    /// Visitor that stringifies every field value into a map keyed by field name.
    struct FieldVisitor<'a> {
        fields: &'a mut HashMap<String, String>,
    }

    impl<'a> tracing::field::Visit for FieldVisitor<'a> {
        fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
            self.fields.insert(field.name().to_string(), format!("{:?}", value));
        }

        fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
            self.fields.insert(field.name().to_string(), value.to_string());
        }

        fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
            self.fields.insert(field.name().to_string(), value.to_string());
        }

        fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
            self.fields.insert(field.name().to_string(), value.to_string());
        }

        fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
            self.fields.insert(field.name().to_string(), value.to_string());
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // Every request yields an `http_request` span carrying the request metadata,
        // the response status, a duration and, for 4xx/5xx responses, an error marker.
        #[test]
        fn prop_tracing_span_completeness(
            method_idx in 0usize..5usize,
            path in "/[a-z]{1,10}(/[a-z]{1,10})?",
            status_code in 200u16..600u16,
            custom_key in "[a-z]{3,10}",
            custom_value in "[a-z0-9]{3,20}",
        ) {
            let result: Result<(), TestCaseError> = block_on(async {
                let capture = SpanFieldCapture::new();
                let subscriber = tracing_subscriber::registry().with(capture.clone());
                let _guard = tracing::subscriber::set_default(subscriber);

                let mut stack = LayerStack::new();
                stack.push(Box::new(RequestIdLayer::new()));
                stack.push(Box::new(
                    TracingLayer::new().with_field(&custom_key, &custom_value),
                ));

                let methods = [Method::GET, Method::POST, Method::PUT, Method::DELETE, Method::PATCH];
                let method = methods[method_idx].clone();

                let response_status = StatusCode::from_u16(status_code).unwrap_or(StatusCode::OK);
                let handler = respond_with(response_status, "test");

                let request = create_test_request(method.clone(), &path);
                let response = stack.execute(request, handler).await;

                prop_assert_eq!(response.status(), response_status);

                let spans = capture.get_spans();
                let http_span = spans.iter().find(|s| s.name == "http_request");

                prop_assert!(http_span.is_some(), "Should have created an http_request span");
                let span = http_span.unwrap();

                prop_assert!(
                    span.fields.contains_key("method"),
                    "Span should contain 'method' field. Fields: {:?}", span.fields
                );
                prop_assert_eq!(
                    span.fields.get("method").map(|s| s.trim_matches('"')),
                    Some(method.as_str()),
                    "Method should match request method"
                );

                prop_assert!(
                    span.fields.contains_key("path"),
                    "Span should contain 'path' field. Fields: {:?}", span.fields
                );
                prop_assert_eq!(
                    span.fields.get("path").map(|s| s.trim_matches('"')),
                    Some(path.as_str()),
                    "Path should match request path"
                );

                prop_assert!(
                    span.fields.contains_key("request_id"),
                    "Span should contain 'request_id' field. Fields: {:?}", span.fields
                );
                let request_id = span.fields.get("request_id").unwrap();
                let request_id_trimmed = request_id.trim_matches('"');
                prop_assert!(
                    request_id_trimmed == "unknown" || request_id_trimmed.len() == 36,
                    "Request ID should be UUID format or 'unknown', got: {}", request_id
                );

                prop_assert!(
                    span.fields.contains_key("status"),
                    "Span should contain 'status' field. Fields: {:?}", span.fields
                );
                let recorded_status: u16 = span.fields.get("status")
                    .and_then(|s| s.parse().ok())
                    .unwrap_or(0);
                prop_assert_eq!(
                    recorded_status,
                    status_code,
                    "Status should match response status code"
                );

                prop_assert!(
                    span.fields.contains_key("duration_ms"),
                    "Span should contain 'duration_ms' field. Fields: {:?}", span.fields
                );
                let duration: u64 = span.fields.get("duration_ms")
                    .and_then(|s| s.parse().ok())
                    .unwrap_or(u64::MAX);
                prop_assert!(
                    duration < 10000, "Duration should be reasonable, got: {} ms", duration
                );

                if response_status.is_client_error() || response_status.is_server_error() {
                    prop_assert!(
                        span.fields.contains_key("error"),
                        "Span should contain 'error' field for error responses. Fields: {:?}", span.fields
                    );
                }

                Ok(())
            });
            result?;
        }
    }

    #[test]
    fn test_tracing_layer_records_request_id() {
        block_on(async {
            let capture = SpanFieldCapture::new();
            let subscriber = tracing_subscriber::registry().with(capture.clone());
            let _guard = tracing::subscriber::set_default(subscriber);

            let mut stack = LayerStack::new();
            stack.push(Box::new(RequestIdLayer::new()));
            stack.push(Box::new(TracingLayer::new()));

            let handler = respond_with(StatusCode::OK, "ok");

            let request = create_test_request(Method::GET, "/test");
            let _response = stack.execute(request, handler).await;

            let spans = capture.get_spans();
            let http_span = spans.iter().find(|s| s.name == "http_request");
            assert!(http_span.is_some(), "Should have http_request span");

            let span = http_span.unwrap();
            assert!(span.fields.contains_key("request_id"), "Should have request_id field");
        });
    }

    #[test]
    fn test_tracing_layer_records_error_for_failures() {
        block_on(async {
            let capture = SpanFieldCapture::new();
            let subscriber = tracing_subscriber::registry().with(capture.clone());
            let _guard = tracing::subscriber::set_default(subscriber);

            let mut stack = LayerStack::new();
            stack.push(Box::new(TracingLayer::new()));

            let handler = respond_with(StatusCode::INTERNAL_SERVER_ERROR, "error");

            let request = create_test_request(Method::GET, "/test");
            let response = stack.execute(request, handler).await;

            assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);

            let spans = capture.get_spans();
            let http_span = spans.iter().find(|s| s.name == "http_request");
            assert!(http_span.is_some(), "Should have http_request span");

            let span = http_span.unwrap();
            assert!(span.fields.contains_key("error"), "Should have error field for 5xx response");
        });
    }
}