rustapi_core/middleware/
tracing_layer.rs

1//! Enhanced Tracing middleware
2//!
3//! Logs request method, path, request_id, status code, and duration for each request.
4//! Supports custom fields that are included in all request spans.
5
6use super::layer::{BoxedNext, MiddlewareLayer};
7use super::request_id::RequestId;
8use crate::request::Request;
9use crate::response::Response;
10use std::future::Future;
11use std::pin::Pin;
12use std::time::Instant;
13use tracing::{info_span, Instrument, Level};
14
15/// Middleware layer that creates tracing spans for requests
16///
17/// This layer creates a span for each request containing:
18/// - HTTP method
19/// - Request path
20/// - Request ID (if RequestIdLayer is applied)
21/// - Response status code
22/// - Request duration
23/// - Any custom fields configured via `with_field()`
24///
25/// # Example
26///
27/// ```rust,ignore
28/// use rustapi_core::middleware::TracingLayer;
29///
30/// RustApi::new()
31///     .layer(TracingLayer::new()
32///         .with_field("service", "my-api")
33///         .with_field("version", "1.0.0"))
34///     .route("/", get(handler))
35/// ```
36#[derive(Clone)]
37pub struct TracingLayer {
38    level: Level,
39    custom_fields: Vec<(String, String)>,
40}
41
42impl TracingLayer {
43    /// Create a new TracingLayer with default INFO level
44    pub fn new() -> Self {
45        Self {
46            level: Level::INFO,
47            custom_fields: Vec::new(),
48        }
49    }
50
51    /// Create a TracingLayer with a specific log level
52    pub fn with_level(level: Level) -> Self {
53        Self {
54            level,
55            custom_fields: Vec::new(),
56        }
57    }
58
59    /// Add a custom field to all request spans
60    ///
61    /// Custom fields are included in every span created by this layer.
62    ///
63    /// # Example
64    ///
65    /// ```rust,ignore
66    /// TracingLayer::new()
67    ///     .with_field("service", "my-api")
68    ///     .with_field("environment", "production")
69    /// ```
70    pub fn with_field(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
71        self.custom_fields.push((key.into(), value.into()));
72        self
73    }
74}
75
76impl Default for TracingLayer {
77    fn default() -> Self {
78        Self::new()
79    }
80}
81
82impl MiddlewareLayer for TracingLayer {
83    fn call(
84        &self,
85        req: Request,
86        next: BoxedNext,
87    ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
88        let level = self.level;
89        let method = req.method().to_string();
90        let path = req.uri().path().to_string();
91        let custom_fields = self.custom_fields.clone();
92
93        // Extract request_id if available
94        let request_id = req
95            .extensions()
96            .get::<RequestId>()
97            .map(|id| id.as_str().to_string())
98            .unwrap_or_else(|| "unknown".to_string());
99
100        Box::pin(async move {
101            let start = Instant::now();
102
103            // Create span with all fields
104            // We use info_span! as the base and record custom fields dynamically
105            let span = info_span!(
106                "http_request",
107                method = %method,
108                path = %path,
109                request_id = %request_id,
110                status = tracing::field::Empty,
111                duration_ms = tracing::field::Empty,
112                error = tracing::field::Empty,
113            );
114
115            // Record custom fields in the span
116            for (key, value) in &custom_fields {
117                span.record(key.as_str(), value.as_str());
118            }
119
120            // Execute the request within the span
121            let response = async { next(req).await }.instrument(span.clone()).await;
122
123            let duration = start.elapsed();
124            let status = response.status();
125            let status_code = status.as_u16();
126
127            // Record response fields
128            span.record("status", status_code);
129            span.record("duration_ms", duration.as_millis() as u64);
130
131            // Record error if status indicates failure
132            if status.is_client_error() || status.is_server_error() {
133                span.record("error", true);
134            }
135
136            // Log based on status code and configured level
137            let _enter = span.enter();
138            if status.is_success() {
139                match level {
140                    Level::TRACE => tracing::trace!(
141                        method = %method,
142                        path = %path,
143                        request_id = %request_id,
144                        status = %status_code,
145                        duration_ms = %duration.as_millis(),
146                        "Request completed"
147                    ),
148                    Level::DEBUG => tracing::debug!(
149                        method = %method,
150                        path = %path,
151                        request_id = %request_id,
152                        status = %status_code,
153                        duration_ms = %duration.as_millis(),
154                        "Request completed"
155                    ),
156                    Level::INFO => tracing::info!(
157                        method = %method,
158                        path = %path,
159                        request_id = %request_id,
160                        status = %status_code,
161                        duration_ms = %duration.as_millis(),
162                        "Request completed"
163                    ),
164                    Level::WARN => tracing::warn!(
165                        method = %method,
166                        path = %path,
167                        request_id = %request_id,
168                        status = %status_code,
169                        duration_ms = %duration.as_millis(),
170                        "Request completed"
171                    ),
172                    Level::ERROR => tracing::error!(
173                        method = %method,
174                        path = %path,
175                        request_id = %request_id,
176                        status = %status_code,
177                        duration_ms = %duration.as_millis(),
178                        "Request completed"
179                    ),
180                }
181            } else {
182                tracing::warn!(
183                    method = %method,
184                    path = %path,
185                    request_id = %request_id,
186                    status = %status_code,
187                    duration_ms = %duration.as_millis(),
188                    error = true,
189                    "Request failed"
190                );
191            }
192
193            response
194        })
195    }
196
197    fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
198        Box::new(self.clone())
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::middleware::layer::{BoxedNext, LayerStack};
    use crate::middleware::request_id::RequestIdLayer;
    use crate::path_params::PathParams;
    use bytes::Bytes;
    use http::{Extensions, Method, StatusCode};
    use proptest::prelude::*;
    use proptest::test_runner::TestCaseError;
    use std::collections::HashMap;
    use std::sync::Arc;
    use tracing_subscriber::layer::SubscriberExt;

    /// Create a test request with the given method and path, an empty
    /// buffered body, empty extensions and no path parameters.
    fn create_test_request(method: Method, path: &str) -> crate::request::Request {
        let uri: http::Uri = path.parse().unwrap();
        let builder = http::Request::builder().method(method).uri(uri);

        let req = builder.body(()).unwrap();
        let (parts, _) = req.into_parts();

        crate::request::Request::new(
            parts,
            crate::request::BodyVariant::Buffered(Bytes::new()),
            Arc::new(Extensions::new()),
            PathParams::new(),
        )
    }

    /// Build a terminal handler that ignores the request and answers with
    /// `status` and the given static body.
    fn status_handler(status: StatusCode, body: &'static str) -> BoxedNext {
        Arc::new(move |_req: crate::request::Request| {
            Box::pin(async move {
                http::Response::builder()
                    .status(status)
                    .body(http_body_util::Full::new(Bytes::from(body)))
                    .unwrap()
            }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
        })
    }

    #[test]
    fn test_tracing_layer_creation() {
        let layer = TracingLayer::new();
        assert_eq!(layer.level, Level::INFO);
        assert!(layer.custom_fields.is_empty());

        let layer = TracingLayer::with_level(Level::DEBUG);
        assert_eq!(layer.level, Level::DEBUG);
    }

    #[test]
    fn test_tracing_layer_with_custom_fields() {
        let layer = TracingLayer::new()
            .with_field("service", "test-api")
            .with_field("version", "1.0.0");

        assert_eq!(layer.custom_fields.len(), 2);
        assert_eq!(
            layer.custom_fields[0],
            ("service".to_string(), "test-api".to_string())
        );
        assert_eq!(
            layer.custom_fields[1],
            ("version".to_string(), "1.0.0".to_string())
        );
    }

    #[test]
    fn test_tracing_layer_clone() {
        let layer = TracingLayer::new().with_field("key", "value");

        let cloned = layer.clone();
        assert_eq!(cloned.level, layer.level);
        assert_eq!(cloned.custom_fields, layer.custom_fields);
    }

    /// A test subscriber layer that captures span fields for verification.
    ///
    /// Clones share the same underlying storage, so a clone can be installed
    /// in the subscriber while the original is used to read the captures.
    #[derive(Clone)]
    struct SpanFieldCapture {
        captured_fields: Arc<std::sync::Mutex<Vec<CapturedSpan>>>,
    }

    /// Snapshot of one span: its id, its name and every field value seen so
    /// far, stringified.
    #[derive(Debug, Clone)]
    struct CapturedSpan {
        /// Id the subscriber assigned to the span; used to route later
        /// `record` calls to the right capture.
        id: tracing::span::Id,
        name: String,
        fields: HashMap<String, String>,
    }

    impl SpanFieldCapture {
        fn new() -> Self {
            Self {
                captured_fields: Arc::new(std::sync::Mutex::new(Vec::new())),
            }
        }

        /// Return a copy of every span captured so far, in creation order.
        fn get_spans(&self) -> Vec<CapturedSpan> {
            self.captured_fields.lock().unwrap().clone()
        }
    }

    impl<S> tracing_subscriber::Layer<S> for SpanFieldCapture
    where
        S: tracing::Subscriber + for<'lookup> tracing_subscriber::registry::LookupSpan<'lookup>,
    {
        /// Store a new capture holding the fields supplied at span creation.
        fn on_new_span(
            &self,
            attrs: &tracing::span::Attributes<'_>,
            id: &tracing::span::Id,
            _ctx: tracing_subscriber::layer::Context<'_, S>,
        ) {
            let mut fields = HashMap::new();
            let mut visitor = FieldVisitor {
                fields: &mut fields,
            };
            attrs.record(&mut visitor);

            let span = CapturedSpan {
                id: id.clone(),
                name: attrs.metadata().name().to_string(),
                fields,
            };

            self.captured_fields.lock().unwrap().push(span);
        }

        /// Merge late-recorded values into the capture of the span they
        /// belong to.
        fn on_record(
            &self,
            id: &tracing::span::Id,
            values: &tracing::span::Record<'_>,
            _ctx: tracing_subscriber::layer::Context<'_, S>,
        ) {
            let mut captured = self.captured_fields.lock().unwrap();
            // Match on the span id rather than assuming the newest capture is
            // the target. Search newest-first because a subscriber may hand
            // out the id of a closed span again.
            if let Some(target) = captured.iter_mut().rev().find(|span| span.id == *id) {
                let mut visitor = FieldVisitor {
                    fields: &mut target.fields,
                };
                values.record(&mut visitor);
            }
        }
    }

    /// Visitor that stringifies every visited field into a map keyed by
    /// field name; later values overwrite earlier ones.
    struct FieldVisitor<'a> {
        fields: &'a mut HashMap<String, String>,
    }

    impl<'a> tracing::field::Visit for FieldVisitor<'a> {
        fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
            self.fields
                .insert(field.name().to_string(), format!("{:?}", value));
        }

        fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
            self.fields
                .insert(field.name().to_string(), value.to_string());
        }

        fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
            self.fields
                .insert(field.name().to_string(), value.to_string());
        }

        fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
            self.fields
                .insert(field.name().to_string(), value.to_string());
        }

        fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
            self.fields
                .insert(field.name().to_string(), value.to_string());
        }
    }

    // **Feature: phase4-ergonomics-v1, Property 8: Tracing Span Completeness**
    //
    // For any HTTP request processed by the system with tracing enabled, the resulting
    // span should contain: request method, request path, request ID, response status code,
    // and response duration.
    //
    // **Validates: Requirements 4.1, 4.2, 4.3, 4.4**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_tracing_span_completeness(
            method_idx in 0usize..5usize,
            path in "/[a-z]{1,10}(/[a-z]{1,10})?",
            status_code in 200u16..600u16,
            // The prefix keeps generated keys from ever colliding with a
            // built-in span field name such as `path` or `status`.
            custom_key in "field_[a-z]{3,10}",
            custom_value in "[a-z0-9]{3,20}",
        ) {
            let rt = tokio::runtime::Runtime::new().unwrap();
            let result: Result<(), TestCaseError> = rt.block_on(async {
                // Set up span capture
                let capture = SpanFieldCapture::new();
                let subscriber = tracing_subscriber::registry().with(capture.clone());

                // Use a guard to set the subscriber for this test
                let _guard = tracing::subscriber::set_default(subscriber);

                // Create middleware stack with RequestIdLayer and TracingLayer
                let mut stack = LayerStack::new();
                stack.push(Box::new(RequestIdLayer::new()));
                stack.push(Box::new(TracingLayer::new()
                    .with_field(&custom_key, &custom_value)));

                // Map index to HTTP method
                let methods = [Method::GET, Method::POST, Method::PUT, Method::DELETE, Method::PATCH];
                let method = methods[method_idx].clone();

                // Create handler that returns the specified status
                let response_status = StatusCode::from_u16(status_code).unwrap_or(StatusCode::OK);
                let handler = status_handler(response_status, "test");

                // Execute request
                let request = create_test_request(method.clone(), &path);
                let response = stack.execute(request, handler).await;

                // Verify response status matches
                prop_assert_eq!(response.status(), response_status);

                // Find the http_request span
                let spans = capture.get_spans();
                let http_span = spans.iter().find(|s| s.name == "http_request");

                prop_assert!(http_span.is_some(), "Should have created an http_request span");
                let span = http_span.unwrap();

                // Verify required fields are present
                // Method
                prop_assert!(
                    span.fields.contains_key("method"),
                    "Span should contain 'method' field. Fields: {:?}", span.fields
                );
                prop_assert_eq!(
                    span.fields.get("method").map(|s| s.trim_matches('"')),
                    Some(method.as_str()),
                    "Method should match request method"
                );

                // Path
                prop_assert!(
                    span.fields.contains_key("path"),
                    "Span should contain 'path' field. Fields: {:?}", span.fields
                );
                prop_assert_eq!(
                    span.fields.get("path").map(|s| s.trim_matches('"')),
                    Some(path.as_str()),
                    "Path should match request path"
                );

                // Request ID
                prop_assert!(
                    span.fields.contains_key("request_id"),
                    "Span should contain 'request_id' field. Fields: {:?}", span.fields
                );
                let request_id = span.fields.get("request_id").unwrap();
                // Request ID should be a UUID format (36 chars with hyphens) or "unknown"
                let request_id_trimmed = request_id.trim_matches('"');
                prop_assert!(
                    request_id_trimmed == "unknown" || request_id_trimmed.len() == 36,
                    "Request ID should be UUID format or 'unknown', got: {}", request_id
                );

                // Status code (recorded after response)
                prop_assert!(
                    span.fields.contains_key("status"),
                    "Span should contain 'status' field. Fields: {:?}", span.fields
                );
                let recorded_status: u16 = span.fields.get("status")
                    .and_then(|s| s.parse().ok())
                    .unwrap_or(0);
                prop_assert_eq!(
                    recorded_status,
                    status_code,
                    "Status should match response status code"
                );

                // Duration (recorded after response)
                prop_assert!(
                    span.fields.contains_key("duration_ms"),
                    "Span should contain 'duration_ms' field. Fields: {:?}", span.fields
                );
                let duration: u64 = span.fields.get("duration_ms")
                    .and_then(|s| s.parse().ok())
                    .unwrap_or(u64::MAX);
                prop_assert!(
                    duration < 10000, // Should complete in less than 10 seconds
                    "Duration should be reasonable, got: {} ms", duration
                );

                // Error field should be present for error responses
                if response_status.is_client_error() || response_status.is_server_error() {
                    prop_assert!(
                        span.fields.contains_key("error"),
                        "Span should contain 'error' field for error responses. Fields: {:?}", span.fields
                    );
                }

                Ok(())
            });
            result?;
        }
    }

    #[test]
    fn test_tracing_layer_records_request_id() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let capture = SpanFieldCapture::new();
            let subscriber = tracing_subscriber::registry().with(capture.clone());
            let _guard = tracing::subscriber::set_default(subscriber);

            let mut stack = LayerStack::new();
            stack.push(Box::new(RequestIdLayer::new()));
            stack.push(Box::new(TracingLayer::new()));

            let handler = status_handler(StatusCode::OK, "ok");

            let request = create_test_request(Method::GET, "/test");
            let _response = stack.execute(request, handler).await;

            let spans = capture.get_spans();
            let span = spans
                .iter()
                .find(|s| s.name == "http_request")
                .expect("Should have http_request span");

            assert!(
                span.fields.contains_key("request_id"),
                "Should have request_id field"
            );
        });
    }

    #[test]
    fn test_tracing_layer_records_error_for_failures() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let capture = SpanFieldCapture::new();
            let subscriber = tracing_subscriber::registry().with(capture.clone());
            let _guard = tracing::subscriber::set_default(subscriber);

            let mut stack = LayerStack::new();
            stack.push(Box::new(TracingLayer::new()));

            let handler = status_handler(StatusCode::INTERNAL_SERVER_ERROR, "error");

            let request = create_test_request(Method::GET, "/test");
            let response = stack.execute(request, handler).await;

            assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);

            let spans = capture.get_spans();
            let span = spans
                .iter()
                .find(|s| s.name == "http_request")
                .expect("Should have http_request span");

            assert!(
                span.fields.contains_key("error"),
                "Should have error field for 5xx response"
            );
        });
    }
}
575        });
576    }
577}