tower_async_http/request_id.rs
//! Set and propagate request ids.
//!
//! # Example
//!
//! ```
//! use http::{Request, Response, header::HeaderName};
//! use http_body_util::Full;
//! use bytes::Bytes;
//! use tower_async::{Service, ServiceExt, ServiceBuilder};
//! use tower_async_http::request_id::{
//!     SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId,
//! };
//! use std::sync::{Arc, atomic::{AtomicU64, Ordering}};
//!
//! # #[tokio::main]
//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
//! # let handler = tower_async::service_fn(|request: Request<Full<Bytes>>| async move {
//! #     Ok::<_, std::convert::Infallible>(Response::new(request.into_body()))
//! # });
//! #
//! // A `MakeRequestId` that increments an atomic counter
//! #[derive(Clone, Default)]
//! struct MyMakeRequestId {
//!     counter: Arc<AtomicU64>,
//! }
//!
//! impl MakeRequestId for MyMakeRequestId {
//!     fn make_request_id<B>(&self, request: &Request<B>) -> Option<RequestId> {
//!         let request_id = self.counter
//!             .fetch_add(1, Ordering::SeqCst)
//!             .to_string()
//!             .parse()
//!             .unwrap();
//!
//!         Some(RequestId::new(request_id))
//!     }
//! }
//!
//! let x_request_id = HeaderName::from_static("x-request-id");
//!
//! let mut svc = ServiceBuilder::new()
//!     // set `x-request-id` header on all requests
//!     .layer(SetRequestIdLayer::new(
//!         x_request_id.clone(),
//!         MyMakeRequestId::default(),
//!     ))
//!     // propagate `x-request-id` headers from request to response
//!     .layer(PropagateRequestIdLayer::new(x_request_id))
//!     .service(handler);
//!
//! let request = Request::new(Full::default());
//! let response = svc.call(request).await?;
//!
//! assert_eq!(response.headers()["x-request-id"], "0");
//! #
//! # Ok(())
//! # }
//! ```
59//!
//! Additional convenience methods are available on [`ServiceBuilderExt`]:
//!
//! ```
//! use tower_async_http::ServiceBuilderExt;
//! # use http::{Request, Response, header::HeaderName};
//! # use http_body_util::Full;
//! # use bytes::Bytes;
//! # use tower_async::{Service, ServiceExt, ServiceBuilder};
//! # use tower_async_http::request_id::{
//! #     SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId,
//! # };
//! # use std::sync::{Arc, atomic::{AtomicU64, Ordering}};
//! #
//! # type Body = Full<Bytes>;
//! #
//! # #[tokio::main]
//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
//! # let handler = tower_async::service_fn(|request: Request<Body>| async move {
//! #     Ok::<_, std::convert::Infallible>(Response::new(request.into_body()))
//! # });
//! # #[derive(Clone, Default)]
//! # struct MyMakeRequestId {
//! #     counter: Arc<AtomicU64>,
//! # }
//! # impl MakeRequestId for MyMakeRequestId {
//! #     fn make_request_id<B>(&self, request: &Request<B>) -> Option<RequestId> {
//! #         let request_id = self.counter
//! #             .fetch_add(1, Ordering::SeqCst)
//! #             .to_string()
//! #             .parse()
//! #             .unwrap();
//! #         Some(RequestId::new(request_id))
//! #     }
//! # }
//!
//! let mut svc = ServiceBuilder::new()
//!     .set_x_request_id(MyMakeRequestId::default())
//!     .propagate_x_request_id()
//!     .service(handler);
//!
//! let request = Request::new(Body::default());
//! let response = svc.call(request).await?;
//!
//! assert_eq!(response.headers()["x-request-id"], "0");
//! #
//! # Ok(())
//! # }
//! ```
108//!
//! See [`SetRequestId`] and [`PropagateRequestId`] for more details.
//!
//! # Doesn't override existing headers
//!
//! [`SetRequestId`] and [`PropagateRequestId`] won't override request ids if they are already
//! present on requests or responses. Among other things, this allows other middleware to
//! conditionally set request ids and use the middleware in this module as a fallback.
//!
//! [`ServiceBuilderExt`]: crate::ServiceBuilderExt
//! [`Uuid`]: https://crates.io/crates/uuid
119
120use http::{
121 header::{HeaderName, HeaderValue},
122 Request, Response,
123};
124use tower_async_layer::Layer;
125use tower_async_service::Service;
126use uuid::Uuid;
127
/// Header name used by the `x_request_id` convenience constructors in this module.
pub(crate) const X_REQUEST_ID: &str = "x-request-id";
129
130/// Trait for producing [`RequestId`]s.
131///
132/// Used by [`SetRequestId`].
133pub trait MakeRequestId {
134 /// Try and produce a [`RequestId`] from the request.
135 fn make_request_id<B>(&self, request: &Request<B>) -> Option<RequestId>;
136}
137
138/// An identifier for a request.
139#[derive(Debug, Clone)]
140pub struct RequestId(HeaderValue);
141
142impl RequestId {
143 /// Create a new `RequestId` from a [`HeaderValue`].
144 pub fn new(header_value: HeaderValue) -> Self {
145 Self(header_value)
146 }
147
148 /// Gets a reference to the underlying [`HeaderValue`].
149 pub fn header_value(&self) -> &HeaderValue {
150 &self.0
151 }
152
153 /// Consumes `self`, returning the underlying [`HeaderValue`].
154 pub fn into_header_value(self) -> HeaderValue {
155 self.0
156 }
157}
158
159impl From<HeaderValue> for RequestId {
160 fn from(value: HeaderValue) -> Self {
161 Self::new(value)
162 }
163}
164
165/// Set request id headers and extensions on requests.
166///
167/// This layer applies the [`SetRequestId`] middleware.
168///
169/// See the [module docs](self) and [`SetRequestId`] for more details.
170#[derive(Debug, Clone)]
171pub struct SetRequestIdLayer<M> {
172 header_name: HeaderName,
173 make_request_id: M,
174}
175
176impl<M> SetRequestIdLayer<M> {
177 /// Create a new `SetRequestIdLayer`.
178 pub fn new(header_name: HeaderName, make_request_id: M) -> Self
179 where
180 M: MakeRequestId,
181 {
182 SetRequestIdLayer {
183 header_name,
184 make_request_id,
185 }
186 }
187
188 /// Create a new `SetRequestIdLayer` that uses `x-request-id` as the header name.
189 pub fn x_request_id(make_request_id: M) -> Self
190 where
191 M: MakeRequestId,
192 {
193 SetRequestIdLayer::new(HeaderName::from_static(X_REQUEST_ID), make_request_id)
194 }
195}
196
197impl<S, M> Layer<S> for SetRequestIdLayer<M>
198where
199 M: Clone + MakeRequestId,
200{
201 type Service = SetRequestId<S, M>;
202
203 fn layer(&self, inner: S) -> Self::Service {
204 SetRequestId::new(
205 inner,
206 self.header_name.clone(),
207 self.make_request_id.clone(),
208 )
209 }
210}
211
212/// Set request id headers and extensions on requests.
213///
214/// See the [module docs](self) for an example.
215///
216/// If [`MakeRequestId::make_request_id`] returns `Some(_)` and the request doesn't already have a
217/// header with the same name, then the header will be inserted.
218///
219/// Additionally [`RequestId`] will be inserted into [`Request::extensions`] so other
220/// services can access it.
221#[derive(Debug, Clone)]
222pub struct SetRequestId<S, M> {
223 inner: S,
224 header_name: HeaderName,
225 make_request_id: M,
226}
227
228impl<S, M> SetRequestId<S, M> {
229 /// Create a new `SetRequestId`.
230 pub fn new(inner: S, header_name: HeaderName, make_request_id: M) -> Self
231 where
232 M: MakeRequestId,
233 {
234 Self {
235 inner,
236 header_name,
237 make_request_id,
238 }
239 }
240
241 /// Create a new `SetRequestId` that uses `x-request-id` as the header name.
242 pub fn x_request_id(inner: S, make_request_id: M) -> Self
243 where
244 M: MakeRequestId,
245 {
246 Self::new(
247 inner,
248 HeaderName::from_static(X_REQUEST_ID),
249 make_request_id,
250 )
251 }
252
253 define_inner_service_accessors!();
254
255 /// Returns a new [`Layer`] that wraps services with a `SetRequestId` middleware.
256 pub fn layer(header_name: HeaderName, make_request_id: M) -> SetRequestIdLayer<M>
257 where
258 M: MakeRequestId,
259 {
260 SetRequestIdLayer::new(header_name, make_request_id)
261 }
262}
263
264impl<S, M, ReqBody, ResBody> Service<Request<ReqBody>> for SetRequestId<S, M>
265where
266 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
267 M: MakeRequestId,
268{
269 type Response = S::Response;
270 type Error = S::Error;
271
272 async fn call(&self, mut req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
273 if let Some(request_id) = req.headers().get(&self.header_name) {
274 if req.extensions().get::<RequestId>().is_none() {
275 let request_id = request_id.clone();
276 req.extensions_mut().insert(RequestId::new(request_id));
277 }
278 } else if let Some(request_id) = self.make_request_id.make_request_id(&req) {
279 req.extensions_mut().insert(request_id.clone());
280 req.headers_mut()
281 .insert(self.header_name.clone(), request_id.0);
282 }
283
284 self.inner.call(req).await
285 }
286}
287
288/// Propagate request ids from requests to responses.
289///
290/// This layer applies the [`PropagateRequestId`] middleware.
291///
292/// See the [module docs](self) and [`PropagateRequestId`] for more details.
293#[derive(Debug, Clone)]
294pub struct PropagateRequestIdLayer {
295 header_name: HeaderName,
296}
297
298impl PropagateRequestIdLayer {
299 /// Create a new `PropagateRequestIdLayer`.
300 pub fn new(header_name: HeaderName) -> Self {
301 PropagateRequestIdLayer { header_name }
302 }
303
304 /// Create a new `PropagateRequestIdLayer` that uses `x-request-id` as the header name.
305 pub fn x_request_id() -> Self {
306 Self::new(HeaderName::from_static(X_REQUEST_ID))
307 }
308}
309
310impl<S> Layer<S> for PropagateRequestIdLayer {
311 type Service = PropagateRequestId<S>;
312
313 fn layer(&self, inner: S) -> Self::Service {
314 PropagateRequestId::new(inner, self.header_name.clone())
315 }
316}
317
318/// Propagate request ids from requests to responses.
319///
320/// See the [module docs](self) for an example.
321///
322/// If the request contains a matching header that header will be applied to responses. If a
323/// [`RequestId`] extension is also present it will be propagated as well.
324#[derive(Debug, Clone)]
325pub struct PropagateRequestId<S> {
326 inner: S,
327 header_name: HeaderName,
328}
329
330impl<S> PropagateRequestId<S> {
331 /// Create a new `PropagateRequestId`.
332 pub fn new(inner: S, header_name: HeaderName) -> Self {
333 Self { inner, header_name }
334 }
335
336 /// Create a new `PropagateRequestId` that uses `x-request-id` as the header name.
337 pub fn x_request_id(inner: S) -> Self {
338 Self::new(inner, HeaderName::from_static(X_REQUEST_ID))
339 }
340
341 define_inner_service_accessors!();
342
343 /// Returns a new [`Layer`] that wraps services with a `PropagateRequestId` middleware.
344 pub fn layer(header_name: HeaderName) -> PropagateRequestIdLayer {
345 PropagateRequestIdLayer::new(header_name)
346 }
347}
348
349impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for PropagateRequestId<S>
350where
351 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
352{
353 type Response = S::Response;
354 type Error = S::Error;
355
356 async fn call(&self, req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
357 let request_id = req
358 .headers()
359 .get(&self.header_name)
360 .cloned()
361 .map(RequestId::new);
362
363 let mut response = self.inner.call(req).await?;
364
365 if let Some(current_id) = response.headers().get(&self.header_name) {
366 if response.extensions().get::<RequestId>().is_none() {
367 let current_id = current_id.clone();
368 response.extensions_mut().insert(RequestId::new(current_id));
369 }
370 } else if let Some(request_id) = request_id {
371 response
372 .headers_mut()
373 .insert(self.header_name.clone(), request_id.0.clone());
374 response.extensions_mut().insert(request_id);
375 }
376
377 Ok(response)
378 }
379}
380
381/// A [`MakeRequestId`] that generates `UUID`s.
382#[derive(Clone, Copy, Default)]
383pub struct MakeRequestUuid;
384
385impl MakeRequestId for MakeRequestUuid {
386 fn make_request_id<B>(&self, _request: &Request<B>) -> Option<RequestId> {
387 let request_id = Uuid::new_v4().to_string().parse().unwrap();
388 Some(RequestId::new(request_id))
389 }
390}
391
392#[cfg(test)]
393mod tests {
394 use crate::test_helpers::Body;
395 use crate::ServiceBuilderExt as _;
396 use http::Response;
397 use std::{
398 convert::Infallible,
399 sync::{
400 atomic::{AtomicU64, Ordering},
401 Arc,
402 },
403 };
404 use tower_async::{ServiceBuilder, ServiceExt};
405
406 #[allow(unused_imports)]
407 use super::*;
408
409 #[tokio::test]
410 async fn basic() {
411 let svc = ServiceBuilder::new()
412 .set_x_request_id(Counter::default())
413 .propagate_x_request_id()
414 .service_fn(handler);
415
416 // header on response
417 let req = Request::builder().body(Body::empty()).unwrap();
418 let res = svc.clone().oneshot(req).await.unwrap();
419 assert_eq!(res.headers()["x-request-id"], "0");
420
421 let req = Request::builder().body(Body::empty()).unwrap();
422 let res = svc.clone().oneshot(req).await.unwrap();
423 assert_eq!(res.headers()["x-request-id"], "1");
424
425 // doesn't override if header is already there
426 let req = Request::builder()
427 .header("x-request-id", "foo")
428 .body(Body::empty())
429 .unwrap();
430 let res = svc.clone().oneshot(req).await.unwrap();
431 assert_eq!(res.headers()["x-request-id"], "foo");
432
433 // extension propagated
434 let req = Request::builder().body(Body::empty()).unwrap();
435 let res = svc.clone().oneshot(req).await.unwrap();
436 assert_eq!(res.extensions().get::<RequestId>().unwrap().0, "2");
437 }
438
439 // #[tokio::test]
440 // async fn other_middleware_setting_request_id() {
441 // let svc = ServiceBuilder::new()
442 // .override_request_header(
443 // HeaderName::from_static("x-request-id"),
444 // HeaderValue::from_str("foo").unwrap(),
445 // )
446 // .set_x_request_id(Counter::default())
447 // .map_request(|request: Request<_>| {
448 // // `set_x_request_id` should set the extension if its missing
449 // assert_eq!(request.extensions().get::<RequestId>().unwrap().0, "foo");
450 // request
451 // })
452 // .propagate_x_request_id()
453 // .service_fn(handler);
454
455 // let req = Request::builder()
456 // .header(
457 // "x-request-id",
458 // "this-will-be-overridden-by-override_request_header-middleware",
459 // )
460 // .body(Body::empty())
461 // .unwrap();
462 // let res = svc.clone().oneshot(req).await.unwrap();
463 // assert_eq!(res.headers()["x-request-id"], "foo");
464 // assert_eq!(res.extensions().get::<RequestId>().unwrap().0, "foo");
465 // }
466
467 #[tokio::test]
468 async fn other_middleware_setting_request_id_on_response() {
469 let svc = ServiceBuilder::new()
470 .set_x_request_id(Counter::default())
471 .propagate_x_request_id()
472 .override_response_header(
473 HeaderName::from_static("x-request-id"),
474 HeaderValue::from_str("foo").unwrap(),
475 )
476 .service_fn(handler);
477
478 let req = Request::builder()
479 .header("x-request-id", "foo")
480 .body(Body::empty())
481 .unwrap();
482 let res = svc.clone().oneshot(req).await.unwrap();
483 assert_eq!(res.headers()["x-request-id"], "foo");
484 assert_eq!(res.extensions().get::<RequestId>().unwrap().0, "foo");
485 }
486
487 #[derive(Clone, Default)]
488 struct Counter(Arc<AtomicU64>);
489
490 impl MakeRequestId for Counter {
491 fn make_request_id<B>(&self, _request: &Request<B>) -> Option<RequestId> {
492 let id =
493 HeaderValue::from_str(&self.0.fetch_add(1, Ordering::SeqCst).to_string()).unwrap();
494 Some(RequestId::new(id))
495 }
496 }
497
498 async fn handler(_: Request<Body>) -> Result<Response<Body>, Infallible> {
499 Ok(Response::new(Body::empty()))
500 }
501
502 #[tokio::test]
503 async fn uuid() {
504 let svc = ServiceBuilder::new()
505 .set_x_request_id(MakeRequestUuid)
506 .propagate_x_request_id()
507 .service_fn(handler);
508
509 // header on response
510 let req = Request::builder().body(Body::empty()).unwrap();
511 let mut res = svc.clone().oneshot(req).await.unwrap();
512 let id = res.headers_mut().remove("x-request-id").unwrap();
513 id.to_str().unwrap().parse::<Uuid>().unwrap();
514 }
515}