tower_async_http/set_header/
response.rs

1//! Set a header on the response.
2//!
3//! The header value to be set may be provided as a fixed value when the
4//! middleware is constructed, or determined dynamically based on the response
5//! by a closure. See the [`MakeHeaderValue`] trait for details.
6//!
7//! # Example
8//!
9//! Setting a header from a fixed value provided when the middleware is constructed:
10//!
11//! ```
12//! use http::{Request, Response, header::{self, HeaderValue}};
13//! use tower_async::{Service, ServiceExt, ServiceBuilder};
14//! use tower_async_http::set_header::SetResponseHeaderLayer;
15//! use http_body_util::Full;
16//! use bytes::Bytes;
17//!
18//! # #[tokio::main]
19//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
20//! # let render_html = tower_async::service_fn(|request: Request<Full<Bytes>>| async move {
21//! #     Ok::<_, std::convert::Infallible>(Response::new(request.into_body()))
22//! # });
23//! #
//! let svc = ServiceBuilder::new()
25//!     .layer(
26//!         // Layer that sets `Content-Type: text/html` on responses.
27//!         //
28//!         // `if_not_present` will only insert the header if it does not already
29//!         // have a value.
30//!         SetResponseHeaderLayer::if_not_present(
31//!             header::CONTENT_TYPE,
32//!             HeaderValue::from_static("text/html"),
33//!         )
34//!     )
35//!     .service(render_html);
36//!
37//! let request = Request::new(Full::default());
38//!
39//! let response = svc.call(request).await?;
40//!
41//! assert_eq!(response.headers()["content-type"], "text/html");
42//! #
43//! # Ok(())
44//! # }
45//! ```
46//!
47//! Setting a header based on a value determined dynamically from the response:
48//!
49//! ```
50//! use http::{Request, Response, header::{self, HeaderValue}};
51//! use tower_async::{Service, ServiceExt, ServiceBuilder, BoxError};
52//! use tower_async_http::set_header::SetResponseHeaderLayer;
53//! use http_body_util::Full;
54//! use bytes::Bytes;
55//! use http_body::Body as _; // for `Body::size_hint`
56//!
57//! # #[tokio::main]
58//! # async fn main() -> Result<(), BoxError> {
59//! # let render_html = tower_async::service_fn(|request: Request<Full<Bytes>>| async move {
60//! #     Ok::<_, std::convert::Infallible>(Response::new(Full::from("1234567890")))
61//! # });
62//! #
//! let svc = ServiceBuilder::new()
64//!     .layer(
65//!         // Layer that sets `Content-Length` if the body has a known size.
//!         // Streaming response bodies won't have a known size.
67//!         //
68//!         // `overriding` will insert the header and override any previous values it
69//!         // may have.
70//!         SetResponseHeaderLayer::overriding(
71//!             header::CONTENT_LENGTH,
72//!             |response: &Response<Full<Bytes>>| {
73//!                 if let Some(size) = response.body().size_hint().exact() {
74//!                     // If the response body has a known size, returning `Some` will
75//!                     // set the `Content-Length` header to that value.
76//!                     Some(HeaderValue::from_str(&size.to_string()).unwrap())
77//!                 } else {
78//!                     // If the response body doesn't have a known size, return `None`
79//!                     // to skip setting the header on this response.
80//!                     None
81//!                 }
82//!             }
83//!         )
84//!     )
85//!     .service(render_html);
86//!
87//! let request = Request::new(Full::default());
88//!
89//! let response = svc.call(request).await?;
90//!
91//! assert_eq!(response.headers()["content-length"], "10");
92//! #
93//! # Ok(())
94//! # }
95//! ```
96
97use super::{InsertHeaderMode, MakeHeaderValue};
98use http::{header::HeaderName, Request, Response};
99use std::fmt;
100use tower_async_layer::Layer;
101use tower_async_service::Service;
102
103/// Layer that applies [`SetResponseHeader`] which adds a response header.
104///
105/// See [`SetResponseHeader`] for more details.
106pub struct SetResponseHeaderLayer<M> {
107    header_name: HeaderName,
108    make: M,
109    mode: InsertHeaderMode,
110}
111
112impl<M> fmt::Debug for SetResponseHeaderLayer<M> {
113    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114        f.debug_struct("SetResponseHeaderLayer")
115            .field("header_name", &self.header_name)
116            .field("mode", &self.mode)
117            .field("make", &std::any::type_name::<M>())
118            .finish()
119    }
120}
121
122impl<M> SetResponseHeaderLayer<M> {
123    /// Create a new [`SetResponseHeaderLayer`].
124    ///
125    /// If a previous value exists for the same header, it is removed and replaced with the new
126    /// header value.
127    pub fn overriding(header_name: HeaderName, make: M) -> Self {
128        Self::new(header_name, make, InsertHeaderMode::Override)
129    }
130
131    /// Create a new [`SetResponseHeaderLayer`].
132    ///
133    /// The new header is always added, preserving any existing values. If previous values exist,
134    /// the header will have multiple values.
135    pub fn appending(header_name: HeaderName, make: M) -> Self {
136        Self::new(header_name, make, InsertHeaderMode::Append)
137    }
138
139    /// Create a new [`SetResponseHeaderLayer`].
140    ///
141    /// If a previous value exists for the header, the new value is not inserted.
142    pub fn if_not_present(header_name: HeaderName, make: M) -> Self {
143        Self::new(header_name, make, InsertHeaderMode::IfNotPresent)
144    }
145
146    fn new(header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
147        Self {
148            make,
149            header_name,
150            mode,
151        }
152    }
153}
154
155impl<S, M> Layer<S> for SetResponseHeaderLayer<M>
156where
157    M: Clone,
158{
159    type Service = SetResponseHeader<S, M>;
160
161    fn layer(&self, inner: S) -> Self::Service {
162        SetResponseHeader {
163            inner,
164            header_name: self.header_name.clone(),
165            make: self.make.clone(),
166            mode: self.mode,
167        }
168    }
169}
170
171impl<M> Clone for SetResponseHeaderLayer<M>
172where
173    M: Clone,
174{
175    fn clone(&self) -> Self {
176        Self {
177            make: self.make.clone(),
178            header_name: self.header_name.clone(),
179            mode: self.mode,
180        }
181    }
182}
183
184/// Middleware that sets a header on the response.
185#[derive(Clone)]
186pub struct SetResponseHeader<S, M> {
187    inner: S,
188    header_name: HeaderName,
189    make: M,
190    mode: InsertHeaderMode,
191}
192
193impl<S, M> SetResponseHeader<S, M> {
194    /// Create a new [`SetResponseHeader`].
195    ///
196    /// If a previous value exists for the same header, it is removed and replaced with the new
197    /// header value.
198    pub fn overriding(inner: S, header_name: HeaderName, make: M) -> Self {
199        Self::new(inner, header_name, make, InsertHeaderMode::Override)
200    }
201
202    /// Create a new [`SetResponseHeader`].
203    ///
204    /// The new header is always added, preserving any existing values. If previous values exist,
205    /// the header will have multiple values.
206    pub fn appending(inner: S, header_name: HeaderName, make: M) -> Self {
207        Self::new(inner, header_name, make, InsertHeaderMode::Append)
208    }
209
210    /// Create a new [`SetResponseHeader`].
211    ///
212    /// If a previous value exists for the header, the new value is not inserted.
213    pub fn if_not_present(inner: S, header_name: HeaderName, make: M) -> Self {
214        Self::new(inner, header_name, make, InsertHeaderMode::IfNotPresent)
215    }
216
217    fn new(inner: S, header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
218        Self {
219            inner,
220            header_name,
221            make,
222            mode,
223        }
224    }
225
226    define_inner_service_accessors!();
227}
228
229impl<S, M> fmt::Debug for SetResponseHeader<S, M>
230where
231    S: fmt::Debug,
232{
233    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
234        f.debug_struct("SetResponseHeader")
235            .field("inner", &self.inner)
236            .field("header_name", &self.header_name)
237            .field("mode", &self.mode)
238            .field("make", &std::any::type_name::<M>())
239            .finish()
240    }
241}
242
243impl<ReqBody, ResBody, S, M> Service<Request<ReqBody>> for SetResponseHeader<S, M>
244where
245    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
246    M: MakeHeaderValue<Response<ResBody>>,
247{
248    type Response = S::Response;
249    type Error = S::Error;
250
251    async fn call(&self, req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
252        let mut res = self.inner.call(req).await?;
253        self.mode.apply(&self.header_name, &mut res, &self.make);
254        Ok(res)
255    }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    use crate::test_helpers::Body;

    use http::{header, HeaderValue};
    use std::convert::Infallible;
    use tower_async::{service_fn, ServiceExt};

    /// Builds an empty-bodied response, optionally pre-populated with a
    /// `Content-Type` value so the effect of each insert mode is observable.
    fn response_with(content_type: Option<&'static str>) -> Response<Body> {
        let mut builder = Response::builder();
        if let Some(value) = content_type {
            builder = builder.header(header::CONTENT_TYPE, value);
        }
        builder.body(Body::empty()).unwrap()
    }

    /// Returns every `Content-Type` value on `res`, in the order the header map
    /// yields them.
    fn content_types(res: &Response<Body>) -> Vec<&str> {
        res.headers()
            .get_all(header::CONTENT_TYPE)
            .iter()
            .map(|value| value.to_str().unwrap())
            .collect()
    }

    #[tokio::test]
    async fn test_override_mode() {
        let svc = SetResponseHeader::overriding(
            service_fn(|_req: Request<Body>| async {
                Ok::<_, Infallible>(response_with(Some("good-content")))
            }),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        assert_eq!(content_types(&res), ["text/html"]);
    }

    #[tokio::test]
    async fn test_append_mode() {
        let svc = SetResponseHeader::appending(
            service_fn(|_req: Request<Body>| async {
                Ok::<_, Infallible>(response_with(Some("good-content")))
            }),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        assert_eq!(content_types(&res), ["good-content", "text/html"]);
    }

    #[tokio::test]
    async fn test_skip_if_present_mode() {
        let svc = SetResponseHeader::if_not_present(
            service_fn(|_req: Request<Body>| async {
                Ok::<_, Infallible>(response_with(Some("good-content")))
            }),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        assert_eq!(content_types(&res), ["good-content"]);
    }

    #[tokio::test]
    async fn test_skip_if_present_mode_when_not_present() {
        let svc = SetResponseHeader::if_not_present(
            service_fn(|_req: Request<Body>| async {
                Ok::<_, Infallible>(response_with(None))
            }),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        assert_eq!(content_types(&res), ["text/html"]);
    }

    // The module docs describe closure-based values; these two tests pin both
    // the `Some` (set the header) and `None` (skip the header) outcomes.
    #[tokio::test]
    async fn test_override_mode_with_closure() {
        let svc = SetResponseHeader::overriding(
            service_fn(|_req: Request<Body>| async {
                Ok::<_, Infallible>(response_with(Some("good-content")))
            }),
            header::CONTENT_TYPE,
            |_res: &Response<Body>| Some(HeaderValue::from_static("text/html")),
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        assert_eq!(content_types(&res), ["text/html"]);
    }

    #[tokio::test]
    async fn test_override_mode_with_closure_returning_none() {
        let svc = SetResponseHeader::overriding(
            service_fn(|_req: Request<Body>| async {
                Ok::<_, Infallible>(response_with(Some("good-content")))
            }),
            header::CONTENT_TYPE,
            |_res: &Response<Body>| None::<HeaderValue>,
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        // Returning `None` skips setting the header, so the original value survives.
        assert_eq!(content_types(&res), ["good-content"]);
    }
}
349}