tower_async_http/
request_id.rs

1//! Set and propagate request ids.
2//!
3//! # Example
4//!
5//! ```
6//! use http::{Request, Response, header::HeaderName};
7//! use http_body_util::Full;
8//! use bytes::Bytes;
9//! use tower_async::{Service, ServiceExt, ServiceBuilder};
10//! use tower_async_http::request_id::{
11//!     SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId,
12//! };
13//! use std::sync::{Arc, atomic::{AtomicU64, Ordering}};
14//!
15//! # #[tokio::main]
16//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
17//! # let handler = tower_async::service_fn(|request: Request<Full<Bytes>>| async move {
18//! #     Ok::<_, std::convert::Infallible>(Response::new(request.into_body()))
19//! # });
20//! #
21//! // A `MakeRequestId` that increments an atomic counter
22//! #[derive(Clone, Default)]
23//! struct MyMakeRequestId {
24//!     counter: Arc<AtomicU64>,
25//! }
26//!
27//! impl MakeRequestId for MyMakeRequestId {
28//!     fn make_request_id<B>(&self, request: &Request<B>) -> Option<RequestId> {
29//!         let request_id = self.counter
30//!             .fetch_add(1, Ordering::SeqCst)
31//!             .to_string()
32//!             .parse()
33//!             .unwrap();
34//!
35//!         Some(RequestId::new(request_id))
36//!     }
37//! }
38//!
39//! let x_request_id = HeaderName::from_static("x-request-id");
40//!
41//! let mut svc = ServiceBuilder::new()
42//!     // set `x-request-id` header on all requests
43//!     .layer(SetRequestIdLayer::new(
44//!         x_request_id.clone(),
45//!         MyMakeRequestId::default(),
46//!     ))
47//!     // propagate `x-request-id` headers from request to response
48//!     .layer(PropagateRequestIdLayer::new(x_request_id))
49//!     .service(handler);
50//!
51//! let request = Request::new(Full::default());
52//! let response = svc.call(request).await?;
53//!
54//! assert_eq!(response.headers()["x-request-id"], "0");
55//! #
56//! # Ok(())
57//! # }
58//! ```
59//!
60//! Additional convenience methods are available on [`ServiceBuilderExt`]:
61//!
62//! ```
63//! use tower_async_http::ServiceBuilderExt;
64//! # use http::{Request, Response, header::HeaderName};
65//! # use http_body_util::Full;
66//! # use bytes::Bytes;
67//! # use tower_async::{Service, ServiceExt, ServiceBuilder};
68//! # use tower_async_http::request_id::{
69//! #     SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId,
70//! # };
71//! # use std::sync::{Arc, atomic::{AtomicU64, Ordering}};
72//! #
73//! # type Body = Full<Bytes>;
74//! #
75//! # #[tokio::main]
76//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
77//! # let handler = tower_async::service_fn(|request: Request<Body>| async move {
78//! #     Ok::<_, std::convert::Infallible>(Response::new(request.into_body()))
79//! # });
80//! # #[derive(Clone, Default)]
81//! # struct MyMakeRequestId {
82//! #     counter: Arc<AtomicU64>,
83//! # }
84//! # impl MakeRequestId for MyMakeRequestId {
85//! #     fn make_request_id<B>(&self, request: &Request<B>) -> Option<RequestId> {
86//! #         let request_id = self.counter
87//! #             .fetch_add(1, Ordering::SeqCst)
88//! #             .to_string()
89//! #             .parse()
90//! #             .unwrap();
91//! #         Some(RequestId::new(request_id))
92//! #     }
93//! # }
94//!
95//! let mut svc = ServiceBuilder::new()
96//!     .set_x_request_id(MyMakeRequestId::default())
97//!     .propagate_x_request_id()
98//!     .service(handler);
99//!
100//! let request = Request::new(Body::default());
101//! let response = svc.call(request).await?;
102//!
103//! assert_eq!(response.headers()["x-request-id"], "0");
104//! #
105//! # Ok(())
106//! # }
107//! ```
108//!
109//! See [`SetRequestId`] and [`PropagateRequestId`] for more details.
110//!
111//! # Doesn't override existing headers
112//!
//! [`SetRequestId`] and [`PropagateRequestId`] won't override a request id if one is already
//! present on the request or response. Among other things, this allows other middleware to
//! conditionally set request ids and use the middleware in this module as a fallback.
116//!
117//! [`ServiceBuilderExt`]: crate::ServiceBuilderExt
118//! [`Uuid`]: https://crates.io/crates/uuid
119
120use http::{
121    header::{HeaderName, HeaderValue},
122    Request, Response,
123};
124use tower_async_layer::Layer;
125use tower_async_service::Service;
126use uuid::Uuid;
127
/// Header name used by the `x_request_id` convenience constructors in this module.
pub(crate) const X_REQUEST_ID: &str = "x-request-id";
129
130/// Trait for producing [`RequestId`]s.
131///
132/// Used by [`SetRequestId`].
133pub trait MakeRequestId {
134    /// Try and produce a [`RequestId`] from the request.
135    fn make_request_id<B>(&self, request: &Request<B>) -> Option<RequestId>;
136}
137
138/// An identifier for a request.
139#[derive(Debug, Clone)]
140pub struct RequestId(HeaderValue);
141
142impl RequestId {
143    /// Create a new `RequestId` from a [`HeaderValue`].
144    pub fn new(header_value: HeaderValue) -> Self {
145        Self(header_value)
146    }
147
148    /// Gets a reference to the underlying [`HeaderValue`].
149    pub fn header_value(&self) -> &HeaderValue {
150        &self.0
151    }
152
153    /// Consumes `self`, returning the underlying [`HeaderValue`].
154    pub fn into_header_value(self) -> HeaderValue {
155        self.0
156    }
157}
158
159impl From<HeaderValue> for RequestId {
160    fn from(value: HeaderValue) -> Self {
161        Self::new(value)
162    }
163}
164
165/// Set request id headers and extensions on requests.
166///
167/// This layer applies the [`SetRequestId`] middleware.
168///
169/// See the [module docs](self) and [`SetRequestId`] for more details.
170#[derive(Debug, Clone)]
171pub struct SetRequestIdLayer<M> {
172    header_name: HeaderName,
173    make_request_id: M,
174}
175
176impl<M> SetRequestIdLayer<M> {
177    /// Create a new `SetRequestIdLayer`.
178    pub fn new(header_name: HeaderName, make_request_id: M) -> Self
179    where
180        M: MakeRequestId,
181    {
182        SetRequestIdLayer {
183            header_name,
184            make_request_id,
185        }
186    }
187
188    /// Create a new `SetRequestIdLayer` that uses `x-request-id` as the header name.
189    pub fn x_request_id(make_request_id: M) -> Self
190    where
191        M: MakeRequestId,
192    {
193        SetRequestIdLayer::new(HeaderName::from_static(X_REQUEST_ID), make_request_id)
194    }
195}
196
197impl<S, M> Layer<S> for SetRequestIdLayer<M>
198where
199    M: Clone + MakeRequestId,
200{
201    type Service = SetRequestId<S, M>;
202
203    fn layer(&self, inner: S) -> Self::Service {
204        SetRequestId::new(
205            inner,
206            self.header_name.clone(),
207            self.make_request_id.clone(),
208        )
209    }
210}
211
212/// Set request id headers and extensions on requests.
213///
214/// See the [module docs](self) for an example.
215///
216/// If [`MakeRequestId::make_request_id`] returns `Some(_)` and the request doesn't already have a
217/// header with the same name, then the header will be inserted.
218///
219/// Additionally [`RequestId`] will be inserted into [`Request::extensions`] so other
220/// services can access it.
221#[derive(Debug, Clone)]
222pub struct SetRequestId<S, M> {
223    inner: S,
224    header_name: HeaderName,
225    make_request_id: M,
226}
227
228impl<S, M> SetRequestId<S, M> {
229    /// Create a new `SetRequestId`.
230    pub fn new(inner: S, header_name: HeaderName, make_request_id: M) -> Self
231    where
232        M: MakeRequestId,
233    {
234        Self {
235            inner,
236            header_name,
237            make_request_id,
238        }
239    }
240
241    /// Create a new `SetRequestId` that uses `x-request-id` as the header name.
242    pub fn x_request_id(inner: S, make_request_id: M) -> Self
243    where
244        M: MakeRequestId,
245    {
246        Self::new(
247            inner,
248            HeaderName::from_static(X_REQUEST_ID),
249            make_request_id,
250        )
251    }
252
253    define_inner_service_accessors!();
254
255    /// Returns a new [`Layer`] that wraps services with a `SetRequestId` middleware.
256    pub fn layer(header_name: HeaderName, make_request_id: M) -> SetRequestIdLayer<M>
257    where
258        M: MakeRequestId,
259    {
260        SetRequestIdLayer::new(header_name, make_request_id)
261    }
262}
263
264impl<S, M, ReqBody, ResBody> Service<Request<ReqBody>> for SetRequestId<S, M>
265where
266    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
267    M: MakeRequestId,
268{
269    type Response = S::Response;
270    type Error = S::Error;
271
272    async fn call(&self, mut req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
273        if let Some(request_id) = req.headers().get(&self.header_name) {
274            if req.extensions().get::<RequestId>().is_none() {
275                let request_id = request_id.clone();
276                req.extensions_mut().insert(RequestId::new(request_id));
277            }
278        } else if let Some(request_id) = self.make_request_id.make_request_id(&req) {
279            req.extensions_mut().insert(request_id.clone());
280            req.headers_mut()
281                .insert(self.header_name.clone(), request_id.0);
282        }
283
284        self.inner.call(req).await
285    }
286}
287
288/// Propagate request ids from requests to responses.
289///
290/// This layer applies the [`PropagateRequestId`] middleware.
291///
292/// See the [module docs](self) and [`PropagateRequestId`] for more details.
293#[derive(Debug, Clone)]
294pub struct PropagateRequestIdLayer {
295    header_name: HeaderName,
296}
297
298impl PropagateRequestIdLayer {
299    /// Create a new `PropagateRequestIdLayer`.
300    pub fn new(header_name: HeaderName) -> Self {
301        PropagateRequestIdLayer { header_name }
302    }
303
304    /// Create a new `PropagateRequestIdLayer` that uses `x-request-id` as the header name.
305    pub fn x_request_id() -> Self {
306        Self::new(HeaderName::from_static(X_REQUEST_ID))
307    }
308}
309
310impl<S> Layer<S> for PropagateRequestIdLayer {
311    type Service = PropagateRequestId<S>;
312
313    fn layer(&self, inner: S) -> Self::Service {
314        PropagateRequestId::new(inner, self.header_name.clone())
315    }
316}
317
318/// Propagate request ids from requests to responses.
319///
320/// See the [module docs](self) for an example.
321///
322/// If the request contains a matching header that header will be applied to responses. If a
323/// [`RequestId`] extension is also present it will be propagated as well.
324#[derive(Debug, Clone)]
325pub struct PropagateRequestId<S> {
326    inner: S,
327    header_name: HeaderName,
328}
329
330impl<S> PropagateRequestId<S> {
331    /// Create a new `PropagateRequestId`.
332    pub fn new(inner: S, header_name: HeaderName) -> Self {
333        Self { inner, header_name }
334    }
335
336    /// Create a new `PropagateRequestId` that uses `x-request-id` as the header name.
337    pub fn x_request_id(inner: S) -> Self {
338        Self::new(inner, HeaderName::from_static(X_REQUEST_ID))
339    }
340
341    define_inner_service_accessors!();
342
343    /// Returns a new [`Layer`] that wraps services with a `PropagateRequestId` middleware.
344    pub fn layer(header_name: HeaderName) -> PropagateRequestIdLayer {
345        PropagateRequestIdLayer::new(header_name)
346    }
347}
348
349impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for PropagateRequestId<S>
350where
351    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
352{
353    type Response = S::Response;
354    type Error = S::Error;
355
356    async fn call(&self, req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
357        let request_id = req
358            .headers()
359            .get(&self.header_name)
360            .cloned()
361            .map(RequestId::new);
362
363        let mut response = self.inner.call(req).await?;
364
365        if let Some(current_id) = response.headers().get(&self.header_name) {
366            if response.extensions().get::<RequestId>().is_none() {
367                let current_id = current_id.clone();
368                response.extensions_mut().insert(RequestId::new(current_id));
369            }
370        } else if let Some(request_id) = request_id {
371            response
372                .headers_mut()
373                .insert(self.header_name.clone(), request_id.0.clone());
374            response.extensions_mut().insert(request_id);
375        }
376
377        Ok(response)
378    }
379}
380
381/// A [`MakeRequestId`] that generates `UUID`s.
382#[derive(Clone, Copy, Default)]
383pub struct MakeRequestUuid;
384
385impl MakeRequestId for MakeRequestUuid {
386    fn make_request_id<B>(&self, _request: &Request<B>) -> Option<RequestId> {
387        let request_id = Uuid::new_v4().to_string().parse().unwrap();
388        Some(RequestId::new(request_id))
389    }
390}
391
#[cfg(test)]
mod tests {
    use crate::test_helpers::Body;
    use crate::ServiceBuilderExt as _;
    use http::Response;
    use std::{
        convert::Infallible,
        sync::{
            atomic::{AtomicU64, Ordering},
            Arc,
        },
    };
    use tower_async::{ServiceBuilder, ServiceExt};

    #[allow(unused_imports)]
    use super::*;

    /// Id generator backed by a shared atomic counter; ids are "0", "1", "2", ...
    #[derive(Clone, Default)]
    struct Counter(Arc<AtomicU64>);

    impl MakeRequestId for Counter {
        fn make_request_id<B>(&self, _request: &Request<B>) -> Option<RequestId> {
            let next = self.0.fetch_add(1, Ordering::SeqCst);
            let value = HeaderValue::from_str(&next.to_string()).unwrap();
            Some(RequestId::new(value))
        }
    }

    /// Innermost service: ignores the request and answers with an empty response.
    async fn handler(_: Request<Body>) -> Result<Response<Body>, Infallible> {
        Ok(Response::new(Body::empty()))
    }

    /// Request with an empty body and no headers.
    fn empty_request() -> Request<Body> {
        Request::builder().body(Body::empty()).unwrap()
    }

    /// Request with an empty body whose `x-request-id` header is already set to `id`.
    fn request_with_id(id: &str) -> Request<Body> {
        Request::builder()
            .header("x-request-id", id)
            .body(Body::empty())
            .unwrap()
    }

    #[tokio::test]
    async fn basic() {
        let svc = ServiceBuilder::new()
            .set_x_request_id(Counter::default())
            .propagate_x_request_id()
            .service_fn(handler);

        // A generated id shows up as a header on the response.
        let first = svc.clone().oneshot(empty_request()).await.unwrap();
        assert_eq!(first.headers()["x-request-id"], "0");

        let second = svc.clone().oneshot(empty_request()).await.unwrap();
        assert_eq!(second.headers()["x-request-id"], "1");

        // An id supplied by the caller is not overridden.
        let preset = svc.clone().oneshot(request_with_id("foo")).await.unwrap();
        assert_eq!(preset.headers()["x-request-id"], "foo");

        // The id is also propagated as a response extension.
        let third = svc.clone().oneshot(empty_request()).await.unwrap();
        assert_eq!(third.extensions().get::<RequestId>().unwrap().0, "2");
    }

    // NOTE(review): this test is disabled and kept for reference only. It relies on the
    // `override_request_header` and `map_request` builder methods; confirm they are available
    // before re-enabling it.
    //
    // #[tokio::test]
    // async fn other_middleware_setting_request_id() {
    //     let svc = ServiceBuilder::new()
    //         .override_request_header(
    //             HeaderName::from_static("x-request-id"),
    //             HeaderValue::from_str("foo").unwrap(),
    //         )
    //         .set_x_request_id(Counter::default())
    //         .map_request(|request: Request<_>| {
    //             // `set_x_request_id` should set the extension if it's missing
    //             assert_eq!(request.extensions().get::<RequestId>().unwrap().0, "foo");
    //             request
    //         })
    //         .propagate_x_request_id()
    //         .service_fn(handler);
    //
    //     let req = Request::builder()
    //         .header(
    //             "x-request-id",
    //             "this-will-be-overridden-by-override_request_header-middleware",
    //         )
    //         .body(Body::empty())
    //         .unwrap();
    //     let res = svc.clone().oneshot(req).await.unwrap();
    //     assert_eq!(res.headers()["x-request-id"], "foo");
    //     assert_eq!(res.extensions().get::<RequestId>().unwrap().0, "foo");
    // }

    #[tokio::test]
    async fn other_middleware_setting_request_id_on_response() {
        let svc = ServiceBuilder::new()
            .set_x_request_id(Counter::default())
            .propagate_x_request_id()
            .override_response_header(
                HeaderName::from_static("x-request-id"),
                HeaderValue::from_str("foo").unwrap(),
            )
            .service_fn(handler);

        let res = svc.clone().oneshot(request_with_id("foo")).await.unwrap();

        assert_eq!(res.headers()["x-request-id"], "foo");
        assert_eq!(res.extensions().get::<RequestId>().unwrap().0, "foo");
    }

    #[tokio::test]
    async fn uuid() {
        let svc = ServiceBuilder::new()
            .set_x_request_id(MakeRequestUuid)
            .propagate_x_request_id()
            .service_fn(handler);

        // The response header must be present and parse as a UUID.
        let mut res = svc.clone().oneshot(empty_request()).await.unwrap();
        let id = res.headers_mut().remove("x-request-id").unwrap();
        let _parsed: Uuid = id.to_str().unwrap().parse().unwrap();
    }
}
513        id.to_str().unwrap().parse::<Uuid>().unwrap();
514    }
515}