rustapi_core/middleware/
request_id.rs1use super::layer::{BoxedNext, MiddlewareLayer};
6use crate::error::{ApiError, Result};
7use crate::extract::FromRequestParts;
8use crate::request::Request;
9use crate::response::Response;
10use std::future::Future;
11use std::pin::Pin;
12
13#[derive(Debug, Clone, PartialEq, Eq, Hash)]
15pub struct RequestId(pub String);
16
17impl RequestId {
18 pub fn new() -> Self {
20 Self(generate_uuid())
21 }
22
23 pub fn from_string(id: String) -> Self {
25 Self(id)
26 }
27
28 pub fn as_str(&self) -> &str {
30 &self.0
31 }
32}
33
34impl Default for RequestId {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40impl std::fmt::Display for RequestId {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 write!(f, "{}", self.0)
43 }
44}
45
46impl FromRequestParts for RequestId {
61 fn from_request_parts(req: &Request) -> Result<Self> {
62 req.extensions().get::<RequestId>().cloned().ok_or_else(|| {
63 ApiError::internal(
64 "RequestId not found. Did you forget to add RequestIdLayer middleware?",
65 )
66 })
67 }
68}
69
/// Middleware layer that tags each request with a [`RequestId`].
///
/// The layer is stateless (a unit struct), so `new`, `default` and `clone` all
/// yield an equivalent value. `Debug` is derived so the type can appear in
/// diagnostics, matching `RequestId`.
#[derive(Debug, Clone, Default)]
pub struct RequestIdLayer;

impl RequestIdLayer {
    /// Creates the layer; equivalent to `RequestIdLayer::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}
80
81impl MiddlewareLayer for RequestIdLayer {
82 fn call(
83 &self,
84 mut req: Request,
85 next: BoxedNext,
86 ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
87 Box::pin(async move {
88 let request_id = RequestId::new();
90
91 req.extensions_mut().insert(request_id.clone());
93
94 let mut response = next(req).await;
96
97 if let Ok(header_value) = request_id.0.parse() {
99 response.headers_mut().insert("x-request-id", header_value);
100 }
101
102 response
103 })
104 }
105
106 fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
107 Box::new(self.clone())
108 }
109}
110
/// Generates a lowercase, hyphenated, version-4-shaped UUID string
/// (`xxxxxxxx-xxxx-4xxx-Vxxx-xxxxxxxxxxxx`, where `V` is 8, 9, a or b; 36 chars)
/// using only the standard library.
///
/// Byte layout before the version/variant bits are applied:
/// - `0..8`: a 64-bit hash of the current time, process id, thread id and call
///   counter. The hasher comes from a fresh `RandomState`; std documents that
///   hashers from different `RandomState` instances are unlikely to agree, which
///   spreads IDs across processes and restarts.
/// - `8..16`: the process-wide call counter, big-endian, scrambled by multiplying
///   with an odd constant (a bijection on `u64`) so it does not read as a plain
///   sequence.
///
/// The variant mask only clears the two most-significant bits of byte 8, so the
/// low 62 bits of the scrambled counter survive. IDs produced within one process
/// are therefore distinct for the first 2^62 calls, regardless of clock
/// resolution or calling thread.
///
/// Not intended as an unguessable secret.
fn generate_uuid() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash, Hasher};
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Any odd multiplier is invertible mod 2^64 (and mod 2^62), so distinct
    // counter values stay distinct after scrambling and after the variant mask.
    const COUNTER_SCRAMBLE: u64 = 0x9E37_79B9_7F4A_7C15;

    static COUNTER: AtomicU64 = AtomicU64::new(0);
    // Relaxed suffices: only the uniqueness of each fetch_add result matters,
    // which every atomic read-modify-write guarantees regardless of ordering.
    let count = COUNTER.fetch_add(1, Ordering::Relaxed);

    // A clock before the epoch degrades to zero instead of panicking.
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();

    let mut hasher = RandomState::new().build_hasher();
    nanos.hash(&mut hasher);
    std::process::id().hash(&mut hasher);
    std::thread::current().id().hash(&mut hasher);
    count.hash(&mut hasher);
    let entropy = hasher.finish();

    let scrambled = count.wrapping_mul(COUNTER_SCRAMBLE);

    let mut bytes = [0u8; 16];
    bytes[..8].copy_from_slice(&entropy.to_be_bytes());
    bytes[8..].copy_from_slice(&scrambled.to_be_bytes());

    // RFC 4122 version 4 nibble and variant bits.
    bytes[6] = (bytes[6] & 0x0f) | 0x40;
    bytes[8] = (bytes[8] & 0x3f) | 0x80;

    format!(
        "{:02x}{:02x}{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}",
        bytes[0], bytes[1], bytes[2], bytes[3],
        bytes[4], bytes[5],
        bytes[6], bytes[7],
        bytes[8], bytes[9],
        bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15]
    )
}
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::middleware::layer::{BoxedNext, LayerStack};
    use crate::path_params::PathParams;
    use bytes::Bytes;
    use http::{Extensions, Method, StatusCode};
    use proptest::prelude::*;
    use proptest::test_runner::TestCaseError;
    use std::collections::HashSet;
    use std::sync::Arc;

    /// Builds a `Request` for `method`/`path` with an empty buffered body, a new
    /// empty `Extensions` and `PathParams::new()` (presumably no params).
    fn create_test_request(method: Method, path: &str) -> Request {
        let uri: http::Uri = path.parse().unwrap();
        let builder = http::Request::builder().method(method).uri(uri);

        // Only the head (`parts`) is kept; the `()` body is discarded.
        let req = builder.body(()).unwrap();
        let (parts, _) = req.into_parts();

        Request::new(
            parts,
            crate::request::BodyVariant::Buffered(Bytes::new()),
            Arc::new(Extensions::new()),
            PathParams::new(),
        )
    }

    /// Two generated ids differ and both are 36 characters long.
    #[test]
    fn test_request_id_generation() {
        let id1 = RequestId::new();
        let id2 = RequestId::new();

        assert_ne!(id1.0, id2.0);

        assert_eq!(id1.0.len(), 36);
        assert_eq!(id2.0.len(), 36);
    }

    /// `Display` prints the wrapped string unchanged.
    #[test]
    fn test_request_id_display() {
        let id = RequestId::from_string("test-id-123".to_string());
        assert_eq!(format!("{}", id), "test-id-123");
    }

    proptest! {
        // 100 random cases, each sending between 1 and 99 requests.
        #![proptest_config(ProptestConfig::with_cases(100))]

        // Property: every request passed through RequestIdLayer sees a distinct,
        // UUID-shaped RequestId in its extensions.
        #[test]
        fn prop_request_id_uniqueness(
            num_requests in 1usize..100usize,
        ) {
            // A dedicated runtime per case; the layer stack is async.
            let rt = tokio::runtime::Runtime::new().unwrap();
            // prop_assert_* macros return Err(TestCaseError) early from this async
            // block; the error is propagated by `result?` below.
            let result: Result<(), TestCaseError> = rt.block_on(async {
                let mut stack = LayerStack::new();
                stack.push(Box::new(RequestIdLayer::new()));

                // Shared sink the handler writes each observed id into.
                let collected_ids = Arc::new(std::sync::Mutex::new(Vec::new()));

                for _ in 0..num_requests {
                    let ids = collected_ids.clone();

                    // Terminal handler: records the RequestId (if any) and returns 200 "ok".
                    let handler: BoxedNext = Arc::new(move |req: Request| {
                        let ids = ids.clone();
                        Box::pin(async move {
                            if let Some(request_id) = req.extensions().get::<RequestId>() {
                                ids.lock().unwrap().push(request_id.0.clone());
                            }

                            http::Response::builder()
                                .status(StatusCode::OK)
                                .body(http_body_util::Full::new(Bytes::from("ok")))
                                .unwrap()
                        }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
                    });

                    let request = create_test_request(Method::GET, "/test");
                    let _response = stack.execute(request, handler).await;
                }

                let ids = collected_ids.lock().unwrap();
                // Every request must have reached the handler with an id attached.
                prop_assert_eq!(ids.len(), num_requests, "Should have collected {} IDs", num_requests);

                let unique_ids: HashSet<_> = ids.iter().collect();
                prop_assert_eq!(
                    unique_ids.len(),
                    num_requests,
                    "All {} request IDs should be unique, but found {} unique IDs",
                    num_requests,
                    unique_ids.len()
                );

                // Shape check: 36 chars in 8-4-4-4-12 hyphen-separated groups.
                for id in ids.iter() {
                    prop_assert_eq!(id.len(), 36, "Request ID should be 36 characters (UUID format)");
                    let parts: Vec<&str> = id.split('-').collect();
                    prop_assert_eq!(parts.len(), 5, "UUID should have 5 parts separated by hyphens");
                    prop_assert_eq!(parts[0].len(), 8);
                    prop_assert_eq!(parts[1].len(), 4);
                    prop_assert_eq!(parts[2].len(), 4);
                    prop_assert_eq!(parts[3].len(), 4);
                    prop_assert_eq!(parts[4].len(), 12);
                }

                Ok(())
            });
            result?;
        }
    }

    /// With RequestIdLayer installed, the `FromRequestParts` extractor succeeds
    /// inside the handler and yields a 36-character id.
    #[test]
    fn test_request_id_extractor() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let mut stack = LayerStack::new();
            stack.push(Box::new(RequestIdLayer::new()));

            // Slot the handler fills with the extracted id, checked after execution.
            let extracted_id = Arc::new(std::sync::Mutex::new(None));
            let extracted_id_clone = extracted_id.clone();

            let handler: BoxedNext = Arc::new(move |req: Request| {
                let extracted_id = extracted_id_clone.clone();
                Box::pin(async move {
                    if let Ok(request_id) = RequestId::from_request_parts(&req) {
                        *extracted_id.lock().unwrap() = Some(request_id.0.clone());
                    }

                    http::Response::builder()
                        .status(StatusCode::OK)
                        .body(http_body_util::Full::new(Bytes::from("ok")))
                        .unwrap()
                }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
            });

            let request = create_test_request(Method::GET, "/test");
            let _response = stack.execute(request, handler).await;

            let id = extracted_id.lock().unwrap();
            assert!(id.is_some(), "Request ID should have been extracted");
            assert_eq!(
                id.as_ref().unwrap().len(),
                36,
                "Request ID should be UUID format"
            );
        });
    }

    /// Without the layer, the extractor reports an error instead of inventing an id.
    #[test]
    fn test_request_id_extractor_without_middleware() {
        let request = create_test_request(Method::GET, "/test");
        let result = RequestId::from_request_parts(&request);
        assert!(
            result.is_err(),
            "Should return error when RequestIdLayer is not applied"
        );
    }
}