rustapi_core/middleware/tracing_layer.rs

1//! Enhanced Tracing middleware
2//!
3//! Logs request method, path, request_id, status code, and duration for each request.
4//! Supports custom fields that are included in all request spans.
5
6use super::layer::{BoxedNext, MiddlewareLayer};
7use super::request_id::RequestId;
8use crate::request::Request;
9use crate::response::Response;
10use std::future::Future;
11use std::pin::Pin;
12use std::time::Instant;
13use tracing::{info_span, Instrument, Level};
14
15/// Middleware layer that creates tracing spans for requests
16///
17/// This layer creates a span for each request containing:
18/// - HTTP method
19/// - Request path
20/// - Request ID (if RequestIdLayer is applied)
21/// - Response status code
22/// - Request duration
23/// - Any custom fields configured via `with_field()`
24///
25/// # Example
26///
27/// ```rust,ignore
28/// use rustapi_core::middleware::TracingLayer;
29///
30/// RustApi::new()
31///     .layer(TracingLayer::new()
32///         .with_field("service", "my-api")
33///         .with_field("version", "1.0.0"))
34///     .route("/", get(handler))
35/// ```
36#[derive(Clone)]
37pub struct TracingLayer {
38    level: Level,
39    custom_fields: Vec<(String, String)>,
40}
41
42impl TracingLayer {
43    /// Create a new TracingLayer with default INFO level
44    pub fn new() -> Self {
45        Self {
46            level: Level::INFO,
47            custom_fields: Vec::new(),
48        }
49    }
50
51    /// Create a TracingLayer with a specific log level
52    pub fn with_level(level: Level) -> Self {
53        Self {
54            level,
55            custom_fields: Vec::new(),
56        }
57    }
58
59    /// Add a custom field to all request spans
60    ///
61    /// Custom fields are included in every span created by this layer.
62    ///
63    /// # Example
64    ///
65    /// ```rust,ignore
66    /// TracingLayer::new()
67    ///     .with_field("service", "my-api")
68    ///     .with_field("environment", "production")
69    /// ```
70    pub fn with_field(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
71        self.custom_fields.push((key.into(), value.into()));
72        self
73    }
74}
75
76impl Default for TracingLayer {
77    fn default() -> Self {
78        Self::new()
79    }
80}
81
82impl MiddlewareLayer for TracingLayer {
83    fn call(
84        &self,
85        req: Request,
86        next: BoxedNext,
87    ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
88        let level = self.level;
89        let method = req.method().to_string();
90        let path = req.uri().path().to_string();
91        let custom_fields = self.custom_fields.clone();
92
93        // Extract request_id if available
94        let request_id = req
95            .extensions()
96            .get::<RequestId>()
97            .map(|id| id.as_str().to_string())
98            .unwrap_or_else(|| "unknown".to_string());
99
100        Box::pin(async move {
101            let start = Instant::now();
102
103            // Create span with all fields
104            // We use info_span! as the base and record custom fields dynamically
105            let span = info_span!(
106                "http_request",
107                method = %method,
108                path = %path,
109                request_id = %request_id,
110                status = tracing::field::Empty,
111                duration_ms = tracing::field::Empty,
112                error = tracing::field::Empty,
113            );
114
115            // Record custom fields in the span
116            for (key, value) in &custom_fields {
117                span.record(key.as_str(), value.as_str());
118            }
119
120            // Execute the request within the span
121            let response = async {
122                next(req).await
123            }
124            .instrument(span.clone())
125            .await;
126
127            let duration = start.elapsed();
128            let status = response.status();
129            let status_code = status.as_u16();
130
131            // Record response fields
132            span.record("status", status_code);
133            span.record("duration_ms", duration.as_millis() as u64);
134
135            // Record error if status indicates failure
136            if status.is_client_error() || status.is_server_error() {
137                span.record("error", true);
138            }
139
140            // Log based on status code and configured level
141            let _enter = span.enter();
142            if status.is_success() {
143                match level {
144                    Level::TRACE => tracing::trace!(
145                        method = %method,
146                        path = %path,
147                        request_id = %request_id,
148                        status = %status_code,
149                        duration_ms = %duration.as_millis(),
150                        "Request completed"
151                    ),
152                    Level::DEBUG => tracing::debug!(
153                        method = %method,
154                        path = %path,
155                        request_id = %request_id,
156                        status = %status_code,
157                        duration_ms = %duration.as_millis(),
158                        "Request completed"
159                    ),
160                    Level::INFO => tracing::info!(
161                        method = %method,
162                        path = %path,
163                        request_id = %request_id,
164                        status = %status_code,
165                        duration_ms = %duration.as_millis(),
166                        "Request completed"
167                    ),
168                    Level::WARN => tracing::warn!(
169                        method = %method,
170                        path = %path,
171                        request_id = %request_id,
172                        status = %status_code,
173                        duration_ms = %duration.as_millis(),
174                        "Request completed"
175                    ),
176                    Level::ERROR => tracing::error!(
177                        method = %method,
178                        path = %path,
179                        request_id = %request_id,
180                        status = %status_code,
181                        duration_ms = %duration.as_millis(),
182                        "Request completed"
183                    ),
184                }
185            } else {
186                tracing::warn!(
187                    method = %method,
188                    path = %path,
189                    request_id = %request_id,
190                    status = %status_code,
191                    duration_ms = %duration.as_millis(),
192                    error = true,
193                    "Request failed"
194                );
195            }
196
197            response
198        })
199    }
200
201    fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
202        Box::new(self.clone())
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::middleware::layer::{BoxedNext, LayerStack};
    use crate::middleware::request_id::RequestIdLayer;
    use bytes::Bytes;
    use http::{Extensions, Method, StatusCode};
    use proptest::prelude::*;
    use proptest::test_runner::TestCaseError;
    use std::collections::HashMap;
    use std::sync::Arc;
    use tracing_subscriber::layer::SubscriberExt;

    /// Build a runtime for driving one async test body with `block_on`.
    ///
    /// The capturing subscriber is installed with `set_default`, which is
    /// thread-local. `block_on` polls the future on the calling thread, so a
    /// current-thread runtime sees the same subscriber. It also avoids the
    /// worker-thread pool that `Runtime::new()` would start for every
    /// proptest case.
    fn test_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("failed to build Tokio test runtime")
    }

    /// Create a test request with the given method and path and an empty body.
    fn create_test_request(method: Method, path: &str) -> crate::request::Request {
        let uri: http::Uri = path.parse().unwrap();
        let (parts, _) = http::Request::builder()
            .method(method)
            .uri(uri)
            .body(())
            .unwrap()
            .into_parts();

        crate::request::Request::new(
            parts,
            Bytes::new(),
            Arc::new(Extensions::new()),
            HashMap::new(),
        )
    }

    #[test]
    fn test_tracing_layer_creation() {
        let layer = TracingLayer::new();
        assert_eq!(layer.level, Level::INFO);
        assert!(layer.custom_fields.is_empty());

        let layer = TracingLayer::with_level(Level::DEBUG);
        assert_eq!(layer.level, Level::DEBUG);
        assert!(layer.custom_fields.is_empty());
    }

    #[test]
    fn test_tracing_layer_with_custom_fields() {
        let layer = TracingLayer::new()
            .with_field("service", "test-api")
            .with_field("version", "1.0.0");

        assert_eq!(layer.custom_fields.len(), 2);
        assert_eq!(layer.custom_fields[0], ("service".to_string(), "test-api".to_string()));
        assert_eq!(layer.custom_fields[1], ("version".to_string(), "1.0.0".to_string()));
    }

    #[test]
    fn test_tracing_layer_clone() {
        let layer = TracingLayer::new().with_field("key", "value");

        let cloned = layer.clone();
        assert_eq!(cloned.level, layer.level);
        assert_eq!(cloned.custom_fields, layer.custom_fields);
    }

    /// A test subscriber layer that captures span fields for verification.
    ///
    /// Fields that have values at span creation are captured in `on_new_span`.
    /// Fields recorded later (e.g. `status`, `duration_ms`) are merged in
    /// `on_record` into the entry for the span they were recorded on.
    #[derive(Clone)]
    struct SpanFieldCapture {
        captured_fields: Arc<std::sync::Mutex<Vec<CapturedSpan>>>,
    }

    #[derive(Debug, Clone)]
    struct CapturedSpan {
        /// Subscriber-assigned span ID, used to route later `record` calls.
        id: tracing::span::Id,
        name: String,
        fields: HashMap<String, String>,
    }

    impl SpanFieldCapture {
        fn new() -> Self {
            Self {
                captured_fields: Arc::new(std::sync::Mutex::new(Vec::new())),
            }
        }

        fn get_spans(&self) -> Vec<CapturedSpan> {
            self.captured_fields.lock().unwrap().clone()
        }
    }

    impl<S> tracing_subscriber::Layer<S> for SpanFieldCapture
    where
        S: tracing::Subscriber + for<'lookup> tracing_subscriber::registry::LookupSpan<'lookup>,
    {
        fn on_new_span(
            &self,
            attrs: &tracing::span::Attributes<'_>,
            id: &tracing::span::Id,
            _ctx: tracing_subscriber::layer::Context<'_, S>,
        ) {
            let mut fields = HashMap::new();
            let mut visitor = FieldVisitor { fields: &mut fields };
            attrs.record(&mut visitor);

            let span = CapturedSpan {
                id: id.clone(),
                name: attrs.metadata().name().to_string(),
                fields,
            };

            self.captured_fields.lock().unwrap().push(span);
        }

        fn on_record(
            &self,
            id: &tracing::span::Id,
            values: &tracing::span::Record<'_>,
            _ctx: tracing_subscriber::layer::Context<'_, S>,
        ) {
            let mut captured = self.captured_fields.lock().unwrap();
            // Attribute the values to the span they were recorded on, not just
            // the newest span. Search newest-first because the registry may
            // reuse IDs of closed spans, and only a live span can be recorded to.
            if let Some(span) = captured.iter_mut().rev().find(|span| span.id == *id) {
                let mut visitor = FieldVisitor { fields: &mut span.fields };
                values.record(&mut visitor);
            }
        }
    }

    /// Visitor that stringifies every field it sees into `fields`.
    struct FieldVisitor<'a> {
        fields: &'a mut HashMap<String, String>,
    }

    impl<'a> tracing::field::Visit for FieldVisitor<'a> {
        fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
            self.fields.insert(field.name().to_string(), format!("{:?}", value));
        }

        fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
            self.fields.insert(field.name().to_string(), value.to_string());
        }

        fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
            self.fields.insert(field.name().to_string(), value.to_string());
        }

        fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
            self.fields.insert(field.name().to_string(), value.to_string());
        }

        fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
            self.fields.insert(field.name().to_string(), value.to_string());
        }
    }

    // **Feature: phase4-ergonomics-v1, Property 8: Tracing Span Completeness**
    //
    // For any HTTP request processed by the system with tracing enabled, the resulting
    // span should contain: request method, request path, request ID, response status code,
    // and response duration.
    //
    // **Validates: Requirements 4.1, 4.2, 4.3, 4.4**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_tracing_span_completeness(
            method_idx in 0usize..5usize,
            path in "/[a-z]{1,10}(/[a-z]{1,10})?",
            status_code in 200u16..600u16,
            custom_key in "[a-z]{3,10}",
            custom_value in "[a-z0-9]{3,20}",
        ) {
            let rt = test_runtime();
            let result: Result<(), TestCaseError> = rt.block_on(async {
                // Set up span capture
                let capture = SpanFieldCapture::new();
                let subscriber = tracing_subscriber::registry().with(capture.clone());

                // Use a guard to set the subscriber for this test
                let _guard = tracing::subscriber::set_default(subscriber);

                // Create middleware stack with RequestIdLayer and TracingLayer
                let mut stack = LayerStack::new();
                stack.push(Box::new(RequestIdLayer::new()));
                stack.push(Box::new(TracingLayer::new()
                    .with_field(&custom_key, &custom_value)));

                // Map index to HTTP method
                let methods = [Method::GET, Method::POST, Method::PUT, Method::DELETE, Method::PATCH];
                let method = methods[method_idx].clone();

                // Create handler that returns the specified status
                let response_status = StatusCode::from_u16(status_code).unwrap_or(StatusCode::OK);
                let handler: BoxedNext = Arc::new(move |_req: crate::request::Request| {
                    let status = response_status;
                    Box::pin(async move {
                        http::Response::builder()
                            .status(status)
                            .body(http_body_util::Full::new(Bytes::from("test")))
                            .unwrap()
                    }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
                });

                // Execute request
                let request = create_test_request(method.clone(), &path);
                let response = stack.execute(request, handler).await;

                // Verify response status matches
                prop_assert_eq!(response.status(), response_status);

                // Find the http_request span
                let spans = capture.get_spans();
                let http_span = spans.iter().find(|s| s.name == "http_request");

                prop_assert!(http_span.is_some(), "Should have created an http_request span");
                let span = http_span.unwrap();

                // Method
                prop_assert!(
                    span.fields.contains_key("method"),
                    "Span should contain 'method' field. Fields: {:?}", span.fields
                );
                prop_assert_eq!(
                    span.fields.get("method").map(|s| s.trim_matches('"')),
                    Some(method.as_str()),
                    "Method should match request method"
                );

                // Path
                prop_assert!(
                    span.fields.contains_key("path"),
                    "Span should contain 'path' field. Fields: {:?}", span.fields
                );
                prop_assert_eq!(
                    span.fields.get("path").map(|s| s.trim_matches('"')),
                    Some(path.as_str()),
                    "Path should match request path"
                );

                // Request ID
                prop_assert!(
                    span.fields.contains_key("request_id"),
                    "Span should contain 'request_id' field. Fields: {:?}", span.fields
                );
                let request_id = span.fields.get("request_id").unwrap();
                // Request ID should be a UUID format (36 chars with hyphens) or "unknown"
                let request_id_trimmed = request_id.trim_matches('"');
                prop_assert!(
                    request_id_trimmed == "unknown" || request_id_trimmed.len() == 36,
                    "Request ID should be UUID format or 'unknown', got: {}", request_id
                );

                // Status code (recorded after response)
                prop_assert!(
                    span.fields.contains_key("status"),
                    "Span should contain 'status' field. Fields: {:?}", span.fields
                );
                let recorded_status: u16 = span.fields.get("status")
                    .and_then(|s| s.parse().ok())
                    .unwrap_or(0);
                prop_assert_eq!(
                    recorded_status,
                    status_code,
                    "Status should match response status code"
                );

                // Duration (recorded after response)
                prop_assert!(
                    span.fields.contains_key("duration_ms"),
                    "Span should contain 'duration_ms' field. Fields: {:?}", span.fields
                );
                let duration: u64 = span.fields.get("duration_ms")
                    .and_then(|s| s.parse().ok())
                    .unwrap_or(u64::MAX);
                prop_assert!(
                    duration < 10000, // Should complete in less than 10 seconds
                    "Duration should be reasonable, got: {} ms", duration
                );

                // Error field should be present for error responses
                if response_status.is_client_error() || response_status.is_server_error() {
                    prop_assert!(
                        span.fields.contains_key("error"),
                        "Span should contain 'error' field for error responses. Fields: {:?}", span.fields
                    );
                }

                Ok(())
            });
            result?;
        }
    }

    #[test]
    fn test_tracing_layer_records_request_id() {
        let rt = test_runtime();
        rt.block_on(async {
            let capture = SpanFieldCapture::new();
            let subscriber = tracing_subscriber::registry().with(capture.clone());
            let _guard = tracing::subscriber::set_default(subscriber);

            let mut stack = LayerStack::new();
            stack.push(Box::new(RequestIdLayer::new()));
            stack.push(Box::new(TracingLayer::new()));

            let handler: BoxedNext = Arc::new(|_req: crate::request::Request| {
                Box::pin(async {
                    http::Response::builder()
                        .status(StatusCode::OK)
                        .body(http_body_util::Full::new(Bytes::from("ok")))
                        .unwrap()
                }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
            });

            let request = create_test_request(Method::GET, "/test");
            let _response = stack.execute(request, handler).await;

            let spans = capture.get_spans();
            let span = spans
                .iter()
                .find(|s| s.name == "http_request")
                .expect("Should have http_request span");

            assert!(span.fields.contains_key("request_id"), "Should have request_id field");
        });
    }

    #[test]
    fn test_tracing_layer_records_error_for_failures() {
        let rt = test_runtime();
        rt.block_on(async {
            let capture = SpanFieldCapture::new();
            let subscriber = tracing_subscriber::registry().with(capture.clone());
            let _guard = tracing::subscriber::set_default(subscriber);

            let mut stack = LayerStack::new();
            stack.push(Box::new(TracingLayer::new()));

            let handler: BoxedNext = Arc::new(|_req: crate::request::Request| {
                Box::pin(async {
                    http::Response::builder()
                        .status(StatusCode::INTERNAL_SERVER_ERROR)
                        .body(http_body_util::Full::new(Bytes::from("error")))
                        .unwrap()
                }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
            });

            let request = create_test_request(Method::GET, "/test");
            let response = stack.execute(request, handler).await;

            assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);

            let spans = capture.get_spans();
            let span = spans
                .iter()
                .find(|s| s.name == "http_request")
                .expect("Should have http_request span");

            assert!(span.fields.contains_key("error"), "Should have error field for 5xx response");
        });
    }
}