// rustapi_core/middleware/tracing_layer.rs

//! Enhanced tracing middleware.
//!
//! Logs the request method, path, request ID, status code, and duration for each request.
//! Supports custom fields that are included in every request span.
5
6use super::layer::{BoxedNext, MiddlewareLayer};
7use super::request_id::RequestId;
8use crate::request::Request;
9use crate::response::Response;
10use std::future::Future;
11use std::pin::Pin;
12use std::time::Instant;
13use tracing::{info_span, Instrument, Level};
14
15/// Middleware layer that creates tracing spans for requests
16///
17/// This layer creates a span for each request containing:
18/// - HTTP method
19/// - Request path
20/// - Request ID (if RequestIdLayer is applied)
21/// - Response status code
22/// - Request duration
23/// - Any custom fields configured via `with_field()`
24///
25/// # Example
26///
27/// ```rust,ignore
28/// use rustapi_core::middleware::TracingLayer;
29///
30/// RustApi::new()
31///     .layer(TracingLayer::new()
32///         .with_field("service", "my-api")
33///         .with_field("version", "1.0.0"))
34///     .route("/", get(handler))
35/// ```
36#[derive(Clone)]
37pub struct TracingLayer {
38    level: Level,
39    custom_fields: Vec<(String, String)>,
40}
41
42impl TracingLayer {
43    /// Create a new TracingLayer with default INFO level
44    pub fn new() -> Self {
45        Self {
46            level: Level::INFO,
47            custom_fields: Vec::new(),
48        }
49    }
50
51    /// Create a TracingLayer with a specific log level
52    pub fn with_level(level: Level) -> Self {
53        Self {
54            level,
55            custom_fields: Vec::new(),
56        }
57    }
58
59    /// Add a custom field to all request spans
60    ///
61    /// Custom fields are included in every span created by this layer.
62    ///
63    /// # Example
64    ///
65    /// ```rust,ignore
66    /// TracingLayer::new()
67    ///     .with_field("service", "my-api")
68    ///     .with_field("environment", "production")
69    /// ```
70    pub fn with_field(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
71        self.custom_fields.push((key.into(), value.into()));
72        self
73    }
74}
75
76impl Default for TracingLayer {
77    fn default() -> Self {
78        Self::new()
79    }
80}
81
82impl MiddlewareLayer for TracingLayer {
83    fn call(
84        &self,
85        req: Request,
86        next: BoxedNext,
87    ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
88        let level = self.level;
89        let method = req.method().to_string();
90        let path = req.uri().path().to_string();
91        let custom_fields = self.custom_fields.clone();
92
93        // Extract request_id if available
94        let request_id = req
95            .extensions()
96            .get::<RequestId>()
97            .map(|id| id.as_str().to_string())
98            .unwrap_or_else(|| "unknown".to_string());
99
100        Box::pin(async move {
101            let start = Instant::now();
102
103            // Create span with all fields
104            // We use info_span! as the base and record custom fields dynamically
105            let span = info_span!(
106                "http_request",
107                method = %method,
108                path = %path,
109                request_id = %request_id,
110                status = tracing::field::Empty,
111                duration_ms = tracing::field::Empty,
112                error = tracing::field::Empty,
113            );
114
115            // Record custom fields in the span
116            for (key, value) in &custom_fields {
117                span.record(key.as_str(), value.as_str());
118            }
119
120            // Execute the request within the span
121            let response = async { next(req).await }.instrument(span.clone()).await;
122
123            let duration = start.elapsed();
124            let status = response.status();
125            let status_code = status.as_u16();
126
127            // Record response fields
128            span.record("status", status_code);
129            span.record("duration_ms", duration.as_millis() as u64);
130
131            // Record error if status indicates failure
132            if status.is_client_error() || status.is_server_error() {
133                span.record("error", true);
134            }
135
136            // Log based on status code and configured level
137            let _enter = span.enter();
138            if status.is_success() {
139                match level {
140                    Level::TRACE => tracing::trace!(
141                        method = %method,
142                        path = %path,
143                        request_id = %request_id,
144                        status = %status_code,
145                        duration_ms = %duration.as_millis(),
146                        "Request completed"
147                    ),
148                    Level::DEBUG => tracing::debug!(
149                        method = %method,
150                        path = %path,
151                        request_id = %request_id,
152                        status = %status_code,
153                        duration_ms = %duration.as_millis(),
154                        "Request completed"
155                    ),
156                    Level::INFO => tracing::info!(
157                        method = %method,
158                        path = %path,
159                        request_id = %request_id,
160                        status = %status_code,
161                        duration_ms = %duration.as_millis(),
162                        "Request completed"
163                    ),
164                    Level::WARN => tracing::warn!(
165                        method = %method,
166                        path = %path,
167                        request_id = %request_id,
168                        status = %status_code,
169                        duration_ms = %duration.as_millis(),
170                        "Request completed"
171                    ),
172                    Level::ERROR => tracing::error!(
173                        method = %method,
174                        path = %path,
175                        request_id = %request_id,
176                        status = %status_code,
177                        duration_ms = %duration.as_millis(),
178                        "Request completed"
179                    ),
180                }
181            } else {
182                tracing::warn!(
183                    method = %method,
184                    path = %path,
185                    request_id = %request_id,
186                    status = %status_code,
187                    duration_ms = %duration.as_millis(),
188                    error = true,
189                    "Request failed"
190                );
191            }
192
193            response
194        })
195    }
196
197    fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
198        Box::new(self.clone())
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::middleware::layer::{BoxedNext, LayerStack};
    use crate::middleware::request_id::RequestIdLayer;
    use bytes::Bytes;
    use http::{Extensions, Method, StatusCode};
    use proptest::prelude::*;
    use proptest::test_runner::TestCaseError;
    use std::collections::HashMap;
    use std::sync::Arc;
    use tracing_subscriber::layer::SubscriberExt;

    /// Create a test request with the given method and path
    fn create_test_request(method: Method, path: &str) -> crate::request::Request {
        let uri: http::Uri = path.parse().unwrap();
        let builder = http::Request::builder().method(method).uri(uri);

        let req = builder.body(()).unwrap();
        let (parts, _) = req.into_parts();

        crate::request::Request::new(
            parts,
            Bytes::new(),
            Arc::new(Extensions::new()),
            HashMap::new(),
        )
    }

    #[test]
    fn test_tracing_layer_creation() {
        let layer = TracingLayer::new();
        assert_eq!(layer.level, Level::INFO);
        assert!(layer.custom_fields.is_empty());

        let layer = TracingLayer::with_level(Level::DEBUG);
        assert_eq!(layer.level, Level::DEBUG);
    }

    #[test]
    fn test_tracing_layer_with_custom_fields() {
        let layer = TracingLayer::new()
            .with_field("service", "test-api")
            .with_field("version", "1.0.0");

        assert_eq!(layer.custom_fields.len(), 2);
        assert_eq!(
            layer.custom_fields[0],
            ("service".to_string(), "test-api".to_string())
        );
        assert_eq!(
            layer.custom_fields[1],
            ("version".to_string(), "1.0.0".to_string())
        );
    }

    #[test]
    fn test_tracing_layer_clone() {
        let layer = TracingLayer::new().with_field("key", "value");

        let cloned = layer.clone();
        assert_eq!(cloned.level, layer.level);
        assert_eq!(cloned.custom_fields, layer.custom_fields);
    }

    /// A test subscriber layer that captures span fields for verification.
    ///
    /// Spans are keyed by their `Id`. Values recorded after creation (e.g. `status` and
    /// `duration_ms`) therefore land on the span they belong to, even when several spans exist.
    #[derive(Clone)]
    struct SpanFieldCapture {
        captured_fields: Arc<std::sync::Mutex<Vec<CapturedSpan>>>,
    }

    #[derive(Debug, Clone)]
    struct CapturedSpan {
        /// Raw span id (`Id::into_u64`) used to route `on_record` updates.
        id: u64,
        name: String,
        fields: HashMap<String, String>,
    }

    impl SpanFieldCapture {
        fn new() -> Self {
            Self {
                captured_fields: Arc::new(std::sync::Mutex::new(Vec::new())),
            }
        }

        fn get_spans(&self) -> Vec<CapturedSpan> {
            self.captured_fields.lock().unwrap().clone()
        }
    }

    impl<S> tracing_subscriber::Layer<S> for SpanFieldCapture
    where
        S: tracing::Subscriber + for<'lookup> tracing_subscriber::registry::LookupSpan<'lookup>,
    {
        fn on_new_span(
            &self,
            attrs: &tracing::span::Attributes<'_>,
            id: &tracing::span::Id,
            _ctx: tracing_subscriber::layer::Context<'_, S>,
        ) {
            let mut fields = HashMap::new();
            let mut visitor = FieldVisitor {
                fields: &mut fields,
            };
            attrs.record(&mut visitor);

            let span = CapturedSpan {
                id: id.into_u64(),
                name: attrs.metadata().name().to_string(),
                fields,
            };

            self.captured_fields.lock().unwrap().push(span);
        }

        fn on_record(
            &self,
            id: &tracing::span::Id,
            values: &tracing::span::Record<'_>,
            _ctx: tracing_subscriber::layer::Context<'_, S>,
        ) {
            let id = id.into_u64();
            let mut captured = self.captured_fields.lock().unwrap();
            // Search newest-first: span ids may be reused after a span closes, and the
            // most recent span carrying this id is the live one.
            if let Some(span) = captured.iter_mut().rev().find(|span| span.id == id) {
                let mut visitor = FieldVisitor {
                    fields: &mut span.fields,
                };
                values.record(&mut visitor);
            }
        }
    }

    struct FieldVisitor<'a> {
        fields: &'a mut HashMap<String, String>,
    }

    impl<'a> tracing::field::Visit for FieldVisitor<'a> {
        fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
            self.fields
                .insert(field.name().to_string(), format!("{:?}", value));
        }

        fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
            self.fields
                .insert(field.name().to_string(), value.to_string());
        }

        fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
            self.fields
                .insert(field.name().to_string(), value.to_string());
        }

        fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
            self.fields
                .insert(field.name().to_string(), value.to_string());
        }

        fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
            self.fields
                .insert(field.name().to_string(), value.to_string());
        }
    }

    // **Feature: phase4-ergonomics-v1, Property 8: Tracing Span Completeness**
    //
    // For any HTTP request processed by the system with tracing enabled, the resulting
    // span should contain: request method, request path, request ID, response status code,
    // and response duration.
    //
    // **Validates: Requirements 4.1, 4.2, 4.3, 4.4**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_tracing_span_completeness(
            method_idx in 0usize..5usize,
            path in "/[a-z]{1,10}(/[a-z]{1,10})?",
            status_code in 200u16..600u16,
            custom_key in "[a-z]{3,10}",
            custom_value in "[a-z0-9]{3,20}",
        ) {
            let rt = tokio::runtime::Runtime::new().unwrap();
            let result: Result<(), TestCaseError> = rt.block_on(async {
                // Set up span capture
                let capture = SpanFieldCapture::new();
                let subscriber = tracing_subscriber::registry().with(capture.clone());

                // Use a guard to set the subscriber for this test
                let _guard = tracing::subscriber::set_default(subscriber);

                // Create middleware stack with RequestIdLayer and TracingLayer
                let mut stack = LayerStack::new();
                stack.push(Box::new(RequestIdLayer::new()));
                stack.push(Box::new(TracingLayer::new()
                    .with_field(&custom_key, &custom_value)));

                // Map index to HTTP method
                let methods = [Method::GET, Method::POST, Method::PUT, Method::DELETE, Method::PATCH];
                let method = methods[method_idx].clone();

                // Create handler that returns the specified status
                let response_status = StatusCode::from_u16(status_code).unwrap_or(StatusCode::OK);
                let handler: BoxedNext = Arc::new(move |_req: crate::request::Request| {
                    let status = response_status;
                    Box::pin(async move {
                        http::Response::builder()
                            .status(status)
                            .body(http_body_util::Full::new(Bytes::from("test")))
                            .unwrap()
                    }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
                });

                // Execute request
                let request = create_test_request(method.clone(), &path);
                let response = stack.execute(request, handler).await;

                // Verify response status matches
                prop_assert_eq!(response.status(), response_status);

                // Find the http_request span
                let spans = capture.get_spans();
                let http_span = spans.iter().find(|s| s.name == "http_request");

                prop_assert!(http_span.is_some(), "Should have created an http_request span");
                let span = http_span.unwrap();

                // Verify required fields are present
                // Method
                prop_assert!(
                    span.fields.contains_key("method"),
                    "Span should contain 'method' field. Fields: {:?}", span.fields
                );
                prop_assert_eq!(
                    span.fields.get("method").map(|s| s.trim_matches('"')),
                    Some(method.as_str()),
                    "Method should match request method"
                );

                // Path
                prop_assert!(
                    span.fields.contains_key("path"),
                    "Span should contain 'path' field. Fields: {:?}", span.fields
                );
                prop_assert_eq!(
                    span.fields.get("path").map(|s| s.trim_matches('"')),
                    Some(path.as_str()),
                    "Path should match request path"
                );

                // Request ID
                prop_assert!(
                    span.fields.contains_key("request_id"),
                    "Span should contain 'request_id' field. Fields: {:?}", span.fields
                );
                let request_id = span.fields.get("request_id").unwrap();
                // Request ID should be a UUID format (36 chars with hyphens) or "unknown"
                let request_id_trimmed = request_id.trim_matches('"');
                prop_assert!(
                    request_id_trimmed == "unknown" || request_id_trimmed.len() == 36,
                    "Request ID should be UUID format or 'unknown', got: {}", request_id
                );

                // Status code (recorded after response)
                prop_assert!(
                    span.fields.contains_key("status"),
                    "Span should contain 'status' field. Fields: {:?}", span.fields
                );
                let recorded_status: u16 = span.fields.get("status")
                    .and_then(|s| s.parse().ok())
                    .unwrap_or(0);
                prop_assert_eq!(
                    recorded_status,
                    status_code,
                    "Status should match response status code"
                );

                // Duration (recorded after response)
                prop_assert!(
                    span.fields.contains_key("duration_ms"),
                    "Span should contain 'duration_ms' field. Fields: {:?}", span.fields
                );
                let duration: u64 = span.fields.get("duration_ms")
                    .and_then(|s| s.parse().ok())
                    .unwrap_or(u64::MAX);
                prop_assert!(
                    duration < 10000, // Should complete in less than 10 seconds
                    "Duration should be reasonable, got: {} ms", duration
                );

                // Error field should be present for error responses
                if response_status.is_client_error() || response_status.is_server_error() {
                    prop_assert!(
                        span.fields.contains_key("error"),
                        "Span should contain 'error' field for error responses. Fields: {:?}", span.fields
                    );
                }

                Ok(())
            });
            result?;
        }
    }

    #[test]
    fn test_tracing_layer_records_request_id() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let capture = SpanFieldCapture::new();
            let subscriber = tracing_subscriber::registry().with(capture.clone());
            let _guard = tracing::subscriber::set_default(subscriber);

            let mut stack = LayerStack::new();
            stack.push(Box::new(RequestIdLayer::new()));
            stack.push(Box::new(TracingLayer::new()));

            let handler: BoxedNext = Arc::new(|_req: crate::request::Request| {
                Box::pin(async {
                    http::Response::builder()
                        .status(StatusCode::OK)
                        .body(http_body_util::Full::new(Bytes::from("ok")))
                        .unwrap()
                }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
            });

            let request = create_test_request(Method::GET, "/test");
            let _response = stack.execute(request, handler).await;

            let spans = capture.get_spans();
            let http_span = spans.iter().find(|s| s.name == "http_request");
            assert!(http_span.is_some(), "Should have http_request span");

            let span = http_span.unwrap();
            assert!(
                span.fields.contains_key("request_id"),
                "Should have request_id field"
            );
        });
    }

    #[test]
    fn test_tracing_layer_records_error_for_failures() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let capture = SpanFieldCapture::new();
            let subscriber = tracing_subscriber::registry().with(capture.clone());
            let _guard = tracing::subscriber::set_default(subscriber);

            let mut stack = LayerStack::new();
            stack.push(Box::new(TracingLayer::new()));

            let handler: BoxedNext = Arc::new(|_req: crate::request::Request| {
                Box::pin(async {
                    http::Response::builder()
                        .status(StatusCode::INTERNAL_SERVER_ERROR)
                        .body(http_body_util::Full::new(Bytes::from("error")))
                        .unwrap()
                }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
            });

            let request = create_test_request(Method::GET, "/test");
            let response = stack.execute(request, handler).await;

            assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);

            let spans = capture.get_spans();
            let http_span = spans.iter().find(|s| s.name == "http_request");
            assert!(http_span.is_some(), "Should have http_request span");

            let span = http_span.unwrap();
            assert!(
                span.fields.contains_key("error"),
                "Should have error field for 5xx response"
            );
        });
    }
}