tower_async_http/
propagate_header.rs

1//! Propagate a header from the request to the response.
2//!
3//! # Example
4//!
5//! ```rust
6//! use http::{Request, Response, header::HeaderName};
7//! use http_body_util::Full;
8//! use bytes::Bytes;
9//! use std::convert::Infallible;
10//! use tower_async::{Service, ServiceExt, ServiceBuilder, service_fn};
11//! use tower_async_http::propagate_header::PropagateHeaderLayer;
12//!
13//! # #[tokio::main]
14//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
15//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
16//!     // ...
17//!     # Ok(Response::new(Full::default()))
18//! }
19//!
20//! let mut svc = ServiceBuilder::new()
21//!     // This will copy `x-request-id` headers from requests onto responses.
22//!     .layer(PropagateHeaderLayer::new(HeaderName::from_static("x-request-id")))
23//!     .service_fn(handle);
24//!
25//! // Call the service.
26//! let request = Request::builder()
27//!     .header("x-request-id", "1337")
28//!     .body(Full::default())?;
29//!
30//! let response = svc.call(request).await?;
31//!
32//! assert_eq!(response.headers()["x-request-id"], "1337");
33//! #
34//! # Ok(())
35//! # }
36//! ```
37
38use http::{header::HeaderName, Request, Response};
39use tower_async_layer::Layer;
40use tower_async_service::Service;
41
42/// Layer that applies [`PropagateHeader`] which propagates headers from requests to responses.
43///
44/// If the header is present on the request it'll be applied to the response as well. This could
45/// for example be used to propagate headers such as `X-Request-Id`.
46///
47/// See the [module docs](crate::propagate_header) for more details.
48#[derive(Clone, Debug)]
49pub struct PropagateHeaderLayer {
50    header: HeaderName,
51}
52
53impl PropagateHeaderLayer {
54    /// Create a new [`PropagateHeaderLayer`].
55    pub fn new(header: HeaderName) -> Self {
56        Self { header }
57    }
58}
59
60impl<S> Layer<S> for PropagateHeaderLayer {
61    type Service = PropagateHeader<S>;
62
63    fn layer(&self, inner: S) -> Self::Service {
64        PropagateHeader {
65            inner,
66            header: self.header.clone(),
67        }
68    }
69}
70
71/// Middleware that propagates headers from requests to responses.
72///
73/// If the header is present on the request it'll be applied to the response as well. This could
74/// for example be used to propagate headers such as `X-Request-Id`.
75///
76/// See the [module docs](crate::propagate_header) for more details.
77#[derive(Clone, Debug)]
78pub struct PropagateHeader<S> {
79    inner: S,
80    header: HeaderName,
81}
82
83impl<S> PropagateHeader<S> {
84    /// Create a new [`PropagateHeader`] that propagates the given header.
85    pub fn new(inner: S, header: HeaderName) -> Self {
86        Self { inner, header }
87    }
88
89    define_inner_service_accessors!();
90
91    /// Returns a new [`Layer`] that wraps services with a `PropagateHeader` middleware.
92    ///
93    /// [`Layer`]: tower_async_layer::Layer
94    pub fn layer(header: HeaderName) -> PropagateHeaderLayer {
95        PropagateHeaderLayer::new(header)
96    }
97}
98
99impl<ReqBody, ResBody, S> Service<Request<ReqBody>> for PropagateHeader<S>
100where
101    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
102{
103    type Response = S::Response;
104    type Error = S::Error;
105
106    async fn call(&self, req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
107        let value = req.headers().get(&self.header).cloned();
108
109        let mut res = self.inner.call(req).await?;
110
111        if let Some(value) = value {
112            res.headers_mut().insert(self.header.clone(), value);
113        }
114
115        Ok(res)
116    }
117}