tower_async_http/set_header/
request.rs

1//! Set a header on the request.
2//!
3//! The header value to be set may be provided as a fixed value when the
4//! middleware is constructed, or determined dynamically based on the request
5//! by a closure. See the [`MakeHeaderValue`] trait for details.
6//!
7//! # Example
8//!
9//! Setting a header from a fixed value provided when the middleware is constructed:
10//!
11//! ```
12//! use http::{Request, Response, header::{self, HeaderValue}};
13//! use tower_async::{Service, ServiceExt, ServiceBuilder};
14//! use tower_async_http::set_header::SetRequestHeaderLayer;
15//! use http_body_util::Full;
16//! use bytes::Bytes;
17//!
18//! # #[tokio::main]
19//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
20//! # let http_client = tower_async::service_fn(|_: Request<Full<Bytes>>| async move {
21//! #     Ok::<_, std::convert::Infallible>(Response::new(Full::<Bytes>::default()))
22//! # });
23//! #
24//! let mut svc = ServiceBuilder::new()
25//!     .layer(
26//!         // Layer that sets `User-Agent: my very cool app` on requests.
27//!         //
28//!         // `if_not_present` will only insert the header if it does not already
29//!         // have a value.
30//!         SetRequestHeaderLayer::if_not_present(
31//!             header::USER_AGENT,
32//!             HeaderValue::from_static("my very cool app"),
33//!         )
34//!     )
35//!     .service(http_client);
36//!
37//! let request = Request::new(Full::default());
38//!
39//! let response = svc.call(request).await?;
40//! #
41//! # Ok(())
42//! # }
43//! ```
44//!
45//! Setting a header based on a value determined dynamically from the request:
46//!
47//! ```
48//! use http::{Request, Response, header::{self, HeaderValue}};
49//! use tower_async::{Service, ServiceExt, ServiceBuilder};
50//! use tower_async_http::set_header::SetRequestHeaderLayer;
51//! use http_body_util::Full;
52//! use bytes::Bytes;
53//!
54//! # #[tokio::main]
55//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
56//! # let http_client = tower_async::service_fn(|_: Request<Full<Bytes>>| async move {
57//! #     Ok::<_, std::convert::Infallible>(Response::new(Full::<Bytes>::default()))
58//! # });
59//! fn date_header_value() -> HeaderValue {
60//!     // ...
61//!     # HeaderValue::from_static("now")
62//! }
63//!
64//! let mut svc = ServiceBuilder::new()
65//!     .layer(
66//!         // Layer that sets `Date` to the current date and time.
67//!         //
68//!         // `overriding` will insert the header and override any previous values it
69//!         // may have.
70//!         SetRequestHeaderLayer::overriding(
71//!             header::DATE,
72//!             |request: &Request<Full<Bytes>>| {
73//!                 Some(date_header_value())
74//!             }
75//!         )
76//!     )
77//!     .service(http_client);
78//!
79//! let request = Request::new(Full::default());
80//!
81//! let response = svc.call(request).await?;
82//! #
83//! # Ok(())
84//! # }
85//! ```
86
87use super::{InsertHeaderMode, MakeHeaderValue};
88use http::{header::HeaderName, Request, Response};
89use std::fmt;
90use tower_async_layer::Layer;
91use tower_async_service::Service;
92
93/// Layer that applies [`SetRequestHeader`] which adds a request header.
94///
95/// See [`SetRequestHeader`] for more details.
96pub struct SetRequestHeaderLayer<M> {
97    header_name: HeaderName,
98    make: M,
99    mode: InsertHeaderMode,
100}
101
102impl<M> fmt::Debug for SetRequestHeaderLayer<M> {
103    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104        f.debug_struct("SetRequestHeaderLayer")
105            .field("header_name", &self.header_name)
106            .field("mode", &self.mode)
107            .field("make", &std::any::type_name::<M>())
108            .finish()
109    }
110}
111
112impl<M> SetRequestHeaderLayer<M> {
113    /// Create a new [`SetRequestHeaderLayer`].
114    ///
115    /// If a previous value exists for the same header, it is removed and replaced with the new
116    /// header value.
117    pub fn overriding(header_name: HeaderName, make: M) -> Self {
118        Self::new(header_name, make, InsertHeaderMode::Override)
119    }
120
121    /// Create a new [`SetRequestHeaderLayer`].
122    ///
123    /// The new header is always added, preserving any existing values. If previous values exist,
124    /// the header will have multiple values.
125    pub fn appending(header_name: HeaderName, make: M) -> Self {
126        Self::new(header_name, make, InsertHeaderMode::Append)
127    }
128
129    /// Create a new [`SetRequestHeaderLayer`].
130    ///
131    /// If a previous value exists for the header, the new value is not inserted.
132    pub fn if_not_present(header_name: HeaderName, make: M) -> Self {
133        Self::new(header_name, make, InsertHeaderMode::IfNotPresent)
134    }
135
136    fn new(header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
137        Self {
138            make,
139            header_name,
140            mode,
141        }
142    }
143}
144
145impl<S, M> Layer<S> for SetRequestHeaderLayer<M>
146where
147    M: Clone,
148{
149    type Service = SetRequestHeader<S, M>;
150
151    fn layer(&self, inner: S) -> Self::Service {
152        SetRequestHeader {
153            inner,
154            header_name: self.header_name.clone(),
155            make: self.make.clone(),
156            mode: self.mode,
157        }
158    }
159}
160
161impl<M> Clone for SetRequestHeaderLayer<M>
162where
163    M: Clone,
164{
165    fn clone(&self) -> Self {
166        Self {
167            make: self.make.clone(),
168            header_name: self.header_name.clone(),
169            mode: self.mode,
170        }
171    }
172}
173
174/// Middleware that sets a header on the request.
175#[derive(Clone)]
176pub struct SetRequestHeader<S, M> {
177    inner: S,
178    header_name: HeaderName,
179    make: M,
180    mode: InsertHeaderMode,
181}
182
183impl<S, M> SetRequestHeader<S, M> {
184    /// Create a new [`SetRequestHeader`].
185    ///
186    /// If a previous value exists for the same header, it is removed and replaced with the new
187    /// header value.
188    pub fn overriding(inner: S, header_name: HeaderName, make: M) -> Self {
189        Self::new(inner, header_name, make, InsertHeaderMode::Override)
190    }
191
192    /// Create a new [`SetRequestHeader`].
193    ///
194    /// The new header is always added, preserving any existing values. If previous values exist,
195    /// the header will have multiple values.
196    pub fn appending(inner: S, header_name: HeaderName, make: M) -> Self {
197        Self::new(inner, header_name, make, InsertHeaderMode::Append)
198    }
199
200    /// Create a new [`SetRequestHeader`].
201    ///
202    /// If a previous value exists for the header, the new value is not inserted.
203    pub fn if_not_present(inner: S, header_name: HeaderName, make: M) -> Self {
204        Self::new(inner, header_name, make, InsertHeaderMode::IfNotPresent)
205    }
206
207    fn new(inner: S, header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
208        Self {
209            inner,
210            header_name,
211            make,
212            mode,
213        }
214    }
215
216    define_inner_service_accessors!();
217}
218
219impl<S, M> fmt::Debug for SetRequestHeader<S, M>
220where
221    S: fmt::Debug,
222{
223    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
224        f.debug_struct("SetRequestHeader")
225            .field("inner", &self.inner)
226            .field("header_name", &self.header_name)
227            .field("mode", &self.mode)
228            .field("make", &std::any::type_name::<M>())
229            .finish()
230    }
231}
232
233impl<ReqBody, ResBody, S, M> Service<Request<ReqBody>> for SetRequestHeader<S, M>
234where
235    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
236    M: MakeHeaderValue<Request<ReqBody>>,
237{
238    type Response = S::Response;
239    type Error = S::Error;
240
241    async fn call(&self, mut req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
242        self.mode.apply(&self.header_name, &mut req, &self.make);
243        self.inner.call(req).await
244    }
245}