rustapi_core/middleware/request_id.rs

//! Request ID middleware
//!
//! Generates a unique UUID for each request, makes it available to handlers via the
//! `RequestId` extractor, and echoes it back in the `x-request-id` response header.
4
5use super::layer::{BoxedNext, MiddlewareLayer};
6use crate::error::{ApiError, Result};
7use crate::extract::FromRequestParts;
8use crate::request::Request;
9use crate::response::Response;
10use std::future::Future;
11use std::pin::Pin;
12
/// Opaque identifier attached to a single request, stored in its textual form.
///
/// The inner `String` is public so callers can build or destructure an ID directly.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct RequestId(pub String);
16
17impl RequestId {
18    /// Create a new RequestId with a generated UUID
19    pub fn new() -> Self {
20        Self(generate_uuid())
21    }
22
23    /// Create a RequestId from an existing string
24    pub fn from_string(id: String) -> Self {
25        Self(id)
26    }
27
28    /// Get the request ID as a string slice
29    pub fn as_str(&self) -> &str {
30        &self.0
31    }
32}
33
34impl Default for RequestId {
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40impl std::fmt::Display for RequestId {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        write!(f, "{}", self.0)
43    }
44}
45
46/// Extractor for RequestId from request extensions
47///
48/// This extractor retrieves the request ID that was generated by `RequestIdLayer`.
49/// Returns an error if the RequestIdLayer middleware was not applied.
50///
51/// # Example
52///
53/// ```rust,ignore
54/// use rustapi_core::middleware::RequestId;
55///
56/// async fn handler(request_id: RequestId) -> impl IntoResponse {
57///     format!("Request ID: {}", request_id)
58/// }
59/// ```
60impl FromRequestParts for RequestId {
61    fn from_request_parts(req: &Request) -> Result<Self> {
62        req.extensions().get::<RequestId>().cloned().ok_or_else(|| {
63            ApiError::internal(
64                "RequestId not found. Did you forget to add RequestIdLayer middleware?",
65            )
66        })
67    }
68}
69
/// Middleware layer that tags every request with a freshly generated `RequestId`.
///
/// The layer carries no state, so `RequestIdLayer::new()` and `RequestIdLayer::default()`
/// are interchangeable.
#[derive(Clone, Default)]
pub struct RequestIdLayer;

impl RequestIdLayer {
    /// Returns the (stateless) layer.
    pub fn new() -> Self {
        RequestIdLayer
    }
}
80
81impl MiddlewareLayer for RequestIdLayer {
82    fn call(
83        &self,
84        mut req: Request,
85        next: BoxedNext,
86    ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
87        Box::pin(async move {
88            // Generate a unique request ID
89            let request_id = RequestId::new();
90
91            // Store in request extensions
92            req.extensions_mut().insert(request_id.clone());
93
94            // Call the next handler
95            let mut response = next(req).await;
96
97            // Add request ID to response headers
98            if let Ok(header_value) = request_id.0.parse() {
99                response.headers_mut().insert("x-request-id", header_value);
100            }
101
102            response
103        })
104    }
105
106    fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
107        Box::new(self.clone())
108    }
109}
110
/// Generates a random, version-4-formatted UUID string (lowercase hex, `8-4-4-4-12`).
///
/// Only the standard library is used. The 128 bits come from two 64-bit outputs of the
/// std default hasher, keyed by a freshly constructed [`RandomState`] (which std initialises
/// with random keys). Each output mixes in a process-wide call counter, the wall-clock time
/// and the current thread id.
///
/// - The counter ensures no two calls in a process ever hash the same input.
/// - The random keys make IDs from different processes independent and hard to guess.
/// - The version and variant bits are then forced as RFC 4122 requires.
///
/// Suitable for correlation/request IDs only. This is **not** a cryptographically secure
/// generator and must not be used for secrets or tokens.
fn generate_uuid() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash, Hasher};
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Only atomicity of the increment matters (every caller gets a distinct value), so
    // Relaxed ordering is sufficient.
    static COUNTER: AtomicU64 = AtomicU64::new(0);

    let count = COUNTER.fetch_add(1, Ordering::Relaxed);
    // A clock set before 1970 degrades to 0 instead of panicking. Uniqueness still comes
    // from the counter, and unpredictability from the hasher keys.
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_nanos())
        .unwrap_or_default();
    let thread = std::thread::current().id();

    let keys = RandomState::new();
    // Each lane is an independent 64-bit hash of the same inputs, separated by `index`.
    let lane = |index: u8| -> u64 {
        let mut hasher = keys.build_hasher();
        index.hash(&mut hasher);
        count.hash(&mut hasher);
        nanos.hash(&mut hasher);
        thread.hash(&mut hasher);
        hasher.finish()
    };

    let mut bytes = [0u8; 16];
    bytes[..8].copy_from_slice(&lane(0).to_le_bytes());
    bytes[8..].copy_from_slice(&lane(1).to_le_bytes());

    // Set version (4) and variant (RFC 4122) bits for UUID v4 format
    bytes[6] = (bytes[6] & 0x0f) | 0x40; // Version 4
    bytes[8] = (bytes[8] & 0x3f) | 0x80; // Variant 1

    format!(
        "{:02x}{:02x}{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}",
        bytes[0], bytes[1], bytes[2], bytes[3],
        bytes[4], bytes[5],
        bytes[6], bytes[7],
        bytes[8], bytes[9],
        bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15]
    )
}
165
// Unit and property tests for `RequestId`, its `FromRequestParts` extractor, and
// `RequestIdLayer`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::middleware::layer::{BoxedNext, LayerStack};
    use crate::path_params::PathParams;
    use bytes::Bytes;
    use http::{Extensions, Method, StatusCode};
    use proptest::prelude::*;
    use proptest::test_runner::TestCaseError;
    use std::collections::HashSet;
    use std::sync::Arc;

    /// Create a test request with the given method and path
    ///
    /// The request has an empty buffered body, a fresh `Extensions` value and no path
    /// parameters.
    fn create_test_request(method: Method, path: &str) -> Request {
        let uri: http::Uri = path.parse().unwrap();
        let builder = http::Request::builder().method(method).uri(uri);

        // Only the `Parts` are kept; the body is supplied separately below.
        let req = builder.body(()).unwrap();
        let (parts, _) = req.into_parts();

        Request::new(
            parts,
            crate::request::BodyVariant::Buffered(Bytes::new()),
            Arc::new(Extensions::new()),
            PathParams::new(),
        )
    }

    // Two directly generated IDs differ and have the 36-character UUID string length.
    #[test]
    fn test_request_id_generation() {
        let id1 = RequestId::new();
        let id2 = RequestId::new();

        // IDs should be different
        assert_ne!(id1.0, id2.0);

        // IDs should be valid UUID format (36 chars with hyphens)
        assert_eq!(id1.0.len(), 36);
        assert_eq!(id2.0.len(), 36);
    }

    // `Display` prints the wrapped string verbatim.
    #[test]
    fn test_request_id_display() {
        let id = RequestId::from_string("test-id-123".to_string());
        assert_eq!(format!("{}", id), "test-id-123");
    }

    // **Feature: phase3-batteries-included, Property 3: Request ID uniqueness**
    //
    // For any set of N concurrent requests processed with RequestIdLayer enabled,
    // the System SHALL generate N distinct UUID values, each accessible via the
    // `RequestId` extractor.
    //
    // NOTE(review): despite the wording above, the test below drives the N requests
    // sequentially (one `.await` per loop iteration) through a single `LayerStack`.
    // Concurrent execution is not exercised.
    //
    // **Validates: Requirements 1.3**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_request_id_uniqueness(
            num_requests in 1usize..100usize,
        ) {
            // A fresh runtime per proptest case; the async body returns the proptest
            // result so `prop_assert*` failures propagate via `result?` below.
            let rt = tokio::runtime::Runtime::new().unwrap();
            let result: Result<(), TestCaseError> = rt.block_on(async {
                let mut stack = LayerStack::new();
                stack.push(Box::new(RequestIdLayer::new()));

                // Collect all generated request IDs
                let collected_ids = Arc::new(std::sync::Mutex::new(Vec::new()));

                // Process multiple requests through the middleware
                for _ in 0..num_requests {
                    let ids = collected_ids.clone();

                    // Create a handler that extracts and stores the request ID
                    let handler: BoxedNext = Arc::new(move |req: Request| {
                        let ids = ids.clone();
                        Box::pin(async move {
                            // Extract the request ID from extensions
                            if let Some(request_id) = req.extensions().get::<RequestId>() {
                                ids.lock().unwrap().push(request_id.0.clone());
                            }

                            http::Response::builder()
                                .status(StatusCode::OK)
                                .body(http_body_util::Full::new(Bytes::from("ok")))
                                .unwrap()
                        }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
                    });

                    let request = create_test_request(Method::GET, "/test");
                    let _response = stack.execute(request, handler).await;
                }

                // Verify all IDs are unique
                let ids = collected_ids.lock().unwrap();
                // A count mismatch here means some request reached the handler without
                // a `RequestId` extension.
                prop_assert_eq!(ids.len(), num_requests, "Should have collected {} IDs", num_requests);

                let unique_ids: HashSet<_> = ids.iter().collect();
                prop_assert_eq!(
                    unique_ids.len(),
                    num_requests,
                    "All {} request IDs should be unique, but found {} unique IDs",
                    num_requests,
                    unique_ids.len()
                );

                // Verify all IDs are valid UUID format (36 chars with hyphens)
                for id in ids.iter() {
                    prop_assert_eq!(id.len(), 36, "Request ID should be 36 characters (UUID format)");
                    // Check UUID format: 8-4-4-4-12
                    let parts: Vec<&str> = id.split('-').collect();
                    prop_assert_eq!(parts.len(), 5, "UUID should have 5 parts separated by hyphens");
                    prop_assert_eq!(parts[0].len(), 8);
                    prop_assert_eq!(parts[1].len(), 4);
                    prop_assert_eq!(parts[2].len(), 4);
                    prop_assert_eq!(parts[3].len(), 4);
                    prop_assert_eq!(parts[4].len(), 12);
                }

                Ok(())
            });
            result?;
        }
    }

    // The `FromRequestParts` extractor finds the ID that `RequestIdLayer` inserted.
    // Only the ID's length is checked; it is not compared against the `x-request-id`
    // response header.
    #[test]
    fn test_request_id_extractor() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let mut stack = LayerStack::new();
            stack.push(Box::new(RequestIdLayer::new()));

            // Shared slot the handler writes the extracted ID into.
            let extracted_id = Arc::new(std::sync::Mutex::new(None));
            let extracted_id_clone = extracted_id.clone();

            let handler: BoxedNext = Arc::new(move |req: Request| {
                let extracted_id = extracted_id_clone.clone();
                Box::pin(async move {
                    // Use the FromRequestParts implementation
                    if let Ok(request_id) = RequestId::from_request_parts(&req) {
                        *extracted_id.lock().unwrap() = Some(request_id.0.clone());
                    }

                    http::Response::builder()
                        .status(StatusCode::OK)
                        .body(http_body_util::Full::new(Bytes::from("ok")))
                        .unwrap()
                }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
            });

            let request = create_test_request(Method::GET, "/test");
            let _response = stack.execute(request, handler).await;

            // Verify the request ID was extracted
            let id = extracted_id.lock().unwrap();
            assert!(id.is_some(), "Request ID should have been extracted");
            assert_eq!(
                id.as_ref().unwrap().len(),
                36,
                "Request ID should be UUID format"
            );
        });
    }

    #[test]
    fn test_request_id_extractor_without_middleware() {
        // Test that extractor returns error when middleware is not applied
        let request = create_test_request(Method::GET, "/test");
        let result = RequestId::from_request_parts(&request);
        assert!(
            result.is_err(),
            "Should return error when RequestIdLayer is not applied"
        );
    }
}