Skip to main content

tower_http/set_header/response/
multiple_headers.rs

1//! Set multiple headers on the response.
2//!
3//! See the root [`crate::set_header::response`] module for full documentation and usage examples.
4//!
5use http::{Request, Response};
6use pin_project_lite::pin_project;
7use std::{
8    fmt,
9    future::Future,
10    pin::Pin,
11    task::{ready, Context, Poll},
12};
13use tower_layer::Layer;
14use tower_service::Service;
15
16use crate::set_header::{HeaderInsertionConfig, HeaderMetadata, InsertHeaderMode};
17
18/// Layer that applies [`SetMultipleResponseHeader`] which adds multiple response headers.
19///
20/// See [`SetMultipleResponseHeader`] for more details.
21#[derive(Clone)]
22pub struct SetMultipleResponseHeadersLayer<M> {
23    headers: Vec<HeaderInsertionConfig<M>>,
24}
25
26impl<M> fmt::Debug for SetMultipleResponseHeadersLayer<M> {
27    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28        f.debug_struct("SetMultipleResponseHeadersLayer")
29            .field("headers", &self.headers)
30            .finish()
31    }
32}
33
34impl<M> SetMultipleResponseHeadersLayer<M> {
35    /// Create a new [`SetMultipleResponseHeadersLayer`] that overrides any existing values for the same header.
36    ///
37    /// If any previous value exists for the same header, it is removed and replaced with the new matching header value.
38    pub fn overriding(metadata: Vec<HeaderMetadata<M>>) -> Self {
39        let headers: Vec<HeaderInsertionConfig<M>> = metadata
40            .into_iter()
41            .map(|m| m.build_config(InsertHeaderMode::Override))
42            .collect();
43
44        Self::new(headers)
45    }
46
47    /// Create a new [`SetMultipleResponseHeadersLayer`] that appends header values.
48    ///
49    /// The new header is always added, preserving any existing values. If previous values exist, the header will have multiple values.
50    pub fn appending(metadata: Vec<HeaderMetadata<M>>) -> Self {
51        let headers: Vec<HeaderInsertionConfig<M>> = metadata
52            .into_iter()
53            .map(|m| m.build_config(InsertHeaderMode::Append))
54            .collect();
55
56        Self::new(headers)
57    }
58
59    /// Create a new [`SetMultipleResponseHeadersLayer`] that only inserts if the header is not already present.
60    ///
61    /// If a previous value exists for the header, the new value is not inserted.
62    pub fn if_not_present(metadata: Vec<HeaderMetadata<M>>) -> Self {
63        let headers: Vec<HeaderInsertionConfig<M>> = metadata
64            .into_iter()
65            .map(|m| m.build_config(InsertHeaderMode::IfNotPresent))
66            .collect();
67
68        Self::new(headers)
69    }
70
71    /// Internal constructor for a new [`SetMultipleResponseHeadersLayer`] from a list of headers.
72    fn new(headers: Vec<HeaderInsertionConfig<M>>) -> Self {
73        Self { headers }
74    }
75}
76
77impl<S, M> Layer<S> for SetMultipleResponseHeadersLayer<M> {
78    type Service = SetMultipleResponseHeader<S, M>;
79
80    fn layer(&self, inner: S) -> Self::Service {
81        SetMultipleResponseHeader {
82            inner,
83            headers: self.headers.clone(),
84        }
85    }
86}
87
88/// Middleware that sets multiple headers on the response.
89
90#[derive(Clone)]
91pub struct SetMultipleResponseHeader<S, M> {
92    inner: S,
93    headers: Vec<HeaderInsertionConfig<M>>,
94}
95
96impl<S, M> SetMultipleResponseHeader<S, M> {
97    /// Create a new [`SetMultipleResponseHeader`] that overrides any existing values for the same header.
98    ///
99    /// If a previous value exists for the same header, it is removed and replaced with the new header value.
100    pub fn overriding(inner: S, metadata: Vec<HeaderMetadata<M>>) -> Self {
101        let headers: Vec<HeaderInsertionConfig<M>> = metadata
102            .into_iter()
103            .map(|m| m.build_config(InsertHeaderMode::Override))
104            .collect();
105
106        Self::new(inner, headers)
107    }
108
109    /// Create a new [`SetMultipleResponseHeader`] that appends header values.
110    ///
111    /// The new header is always added, preserving any existing values. If previous values exist, the header will have multiple values.
112    pub fn appending(inner: S, metadata: Vec<HeaderMetadata<M>>) -> Self {
113        let headers: Vec<HeaderInsertionConfig<M>> = metadata
114            .into_iter()
115            .map(|m| m.build_config(InsertHeaderMode::Append))
116            .collect();
117
118        Self::new(inner, headers)
119    }
120
121    /// Create a new [`SetMultipleResponseHeader`] that only inserts if the header is not already present.
122    ///
123    /// If a previous value exists for the header, the new value is not inserted.
124    pub fn if_not_present(inner: S, metadata: Vec<HeaderMetadata<M>>) -> Self {
125        let headers: Vec<HeaderInsertionConfig<M>> = metadata
126            .into_iter()
127            .map(|m| m.build_config(InsertHeaderMode::IfNotPresent))
128            .collect();
129
130        Self::new(inner, headers)
131    }
132
133    /// Internal constructor for a new [`SetMultipleResponseHeader`] from an inner service and a list of headers.
134    fn new(inner: S, headers: Vec<HeaderInsertionConfig<M>>) -> Self {
135        Self { inner, headers }
136    }
137
138    define_inner_service_accessors!();
139}
140
141impl<S, M> fmt::Debug for SetMultipleResponseHeader<S, M>
142where
143    S: fmt::Debug,
144{
145    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
146        f.debug_struct("SetMultipleResponseHeader")
147            .field("inner", &self.inner)
148            .field("headers", &self.headers)
149            .finish()
150    }
151}
152
153impl<ReqBody, ResBody, S> Service<Request<ReqBody>>
154    for SetMultipleResponseHeader<S, Response<ResBody>>
155where
156    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
157{
158    type Response = S::Response;
159    type Error = S::Error;
160    type Future = ResponseFuture<S::Future, Response<ResBody>>;
161
162    #[inline]
163    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
164        self.inner.poll_ready(cx)
165    }
166
167    /// Call the inner service and apply all configured headers to the response.
168    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
169        ResponseFuture {
170            future: self.inner.call(req),
171            headers: self.headers.clone(),
172        }
173    }
174}
175
pin_project! {
    /// Response future for [`SetMultipleResponseHeader`].
    ///
    /// Resolves to the inner service's output. On `Ok`, every configured header is applied to
    /// the response, in configuration order, before it is returned; errors pass through unchanged.
    #[derive(Debug)]
    pub struct ResponseFuture<F, M> {
        // Future returned by the inner service; structurally pinned so it can be polled in place.
        #[pin]
        future: F,
        // This request's copy of the header configurations, applied once `future` succeeds.
        headers: Vec<HeaderInsertionConfig<M>>,
    }
}
185
186impl<F, ResBody, E> Future for ResponseFuture<F, Response<ResBody>>
187where
188    F: Future<Output = Result<Response<ResBody>, E>>,
189{
190    type Output = F::Output;
191
192    /// Polls the inner future and applies all configured headers to the response before returning it.
193    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
194        let this = self.project();
195        let mut res = ready!(this.future.poll(cx)?);
196
197        for header in this.headers {
198            header
199                .mode
200                .apply(&header.header_name, &mut res, &mut header.make);
201        }
202
203        Poll::Ready(Ok(res))
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        set_header::{BoxedMakeHeaderValue, MakeHeaderValue as _},
        test_helpers::Body,
    };
    use http::{header, HeaderName, HeaderValue};
    use std::convert::Infallible;
    use tower::{service_fn, ServiceExt};

    #[tokio::test]
    async fn test_override_mode() {
        let svc = SetMultipleResponseHeader::overriding(
            service_fn(|_req: Request<Body>| async {
                let res = Response::builder()
                    .header(header::CONTENT_TYPE, "good-content")
                    .body(Body::empty())
                    .unwrap();
                Ok::<_, Infallible>(res)
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into()],
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        let mut values = res.headers().get_all(header::CONTENT_TYPE).iter();
        assert_eq!(values.next().unwrap(), "text/html");
        assert_eq!(values.next(), None);
    }

    #[tokio::test]
    async fn test_override_mode_last_entry_wins() {
        // Entries are applied in order, so a later override of the same header replaces an
        // earlier one.
        let svc = SetMultipleResponseHeader::overriding(
            service_fn(|_req: Request<Body>| async {
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }),
            vec![
                (header::CONTENT_TYPE, HeaderValue::from_static("text/plain")).into(),
                (header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into(),
            ],
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        let mut values = res.headers().get_all(header::CONTENT_TYPE).iter();
        assert_eq!(values.next().unwrap(), "text/html");
        assert_eq!(values.next(), None);
    }

    #[tokio::test]
    async fn test_append_mode() {
        let svc = SetMultipleResponseHeader::appending(
            service_fn(|_req: Request<Body>| async {
                let res = Response::builder()
                    .header(header::CONTENT_TYPE, "good-content")
                    .body(Body::empty())
                    .unwrap();
                Ok::<_, Infallible>(res)
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into()],
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        let mut values = res.headers().get_all(header::CONTENT_TYPE).iter();
        assert_eq!(values.next().unwrap(), "good-content");
        assert_eq!(values.next().unwrap(), "text/html");
        assert_eq!(values.next(), None);
    }

    #[tokio::test]
    async fn test_append_mode_keeps_entry_order() {
        // Repeated entries for the same header are all appended, in configuration order,
        // after any value the inner service already set.
        let svc = SetMultipleResponseHeader::appending(
            service_fn(|_req: Request<Body>| async {
                let res = Response::builder()
                    .header(header::CONTENT_TYPE, "good-content")
                    .body(Body::empty())
                    .unwrap();
                Ok::<_, Infallible>(res)
            }),
            vec![
                (header::CONTENT_TYPE, HeaderValue::from_static("first")).into(),
                (header::CONTENT_TYPE, HeaderValue::from_static("second")).into(),
            ],
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        let mut values = res.headers().get_all(header::CONTENT_TYPE).iter();
        assert_eq!(values.next().unwrap(), "good-content");
        assert_eq!(values.next().unwrap(), "first");
        assert_eq!(values.next().unwrap(), "second");
        assert_eq!(values.next(), None);
    }

    #[tokio::test]
    async fn test_skip_if_present_mode() {
        let svc = SetMultipleResponseHeader::if_not_present(
            service_fn(|_req: Request<Body>| async {
                let res = Response::builder()
                    .header(header::CONTENT_TYPE, "good-content")
                    .body(Body::empty())
                    .unwrap();
                Ok::<_, Infallible>(res)
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into()],
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        let mut values = res.headers().get_all(header::CONTENT_TYPE).iter();
        assert_eq!(values.next().unwrap(), "good-content");
        assert_eq!(values.next(), None);
    }

    #[tokio::test]
    async fn test_skip_if_present_mode_when_not_present() {
        let svc = SetMultipleResponseHeader::if_not_present(
            service_fn(|_req: Request<Body>| async {
                let res = Response::builder().body(Body::empty()).unwrap();
                Ok::<_, Infallible>(res)
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into()],
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        let mut values = res.headers().get_all(header::CONTENT_TYPE).iter();
        assert_eq!(values.next().unwrap(), "text/html");
        assert_eq!(values.next(), None);
    }

    #[test]
    fn test_tuple_metadata_impl() {
        let tuple: (HeaderName, HeaderValue) =
            (header::CONTENT_TYPE, HeaderValue::from_static("foo"));
        let meta: HeaderMetadata<HeaderValue> = tuple.into();
        assert_eq!(meta.header_name, header::CONTENT_TYPE);
        // Check that the header value is correct by making a header value from meta.make
        let mut make = meta.make.clone();
        assert_eq!(
            make.make_header_value(&HeaderValue::from_static("foo")),
            Some(HeaderValue::from_static("foo"))
        );
    }

    #[test]
    fn test_convert_to_header_config_struct_and_tuple() {
        let meta: HeaderMetadata<HeaderValue> = HeaderMetadata::<HeaderValue> {
            header_name: header::CONTENT_TYPE,
            make: BoxedMakeHeaderValue::new(HeaderValue::from_static("bar")),
        };
        let rh = meta.build_config(crate::set_header::InsertHeaderMode::Override);
        assert_eq!(rh.header_name, header::CONTENT_TYPE);
        let mut make = rh.make.clone();
        assert_eq!(
            make.make_header_value(&HeaderValue::from_static("bar")),
            Some(HeaderValue::from_static("bar"))
        );

        let tuple: (HeaderName, HeaderValue) =
            (header::CONTENT_TYPE, HeaderValue::from_static("baz"));
        let meta: HeaderMetadata<HeaderValue> = tuple.into();
        let rh2 = meta.build_config(crate::set_header::InsertHeaderMode::Override);
        assert_eq!(rh2.header_name, header::CONTENT_TYPE);
        let mut make2 = rh2.make.clone();
        assert_eq!(
            make2.make_header_value(&HeaderValue::from_static("baz")),
            Some(HeaderValue::from_static("baz"))
        );
    }

    #[test]
    fn test_debug_impls() {
        let meta: HeaderMetadata<HeaderValue> =
            (header::CONTENT_TYPE, HeaderValue::from_static("bar")).into();
        let rh = meta
            .clone()
            .build_config(crate::set_header::InsertHeaderMode::Override);
        let layer = SetMultipleResponseHeadersLayer::overriding(vec![meta]);
        let debug_str = format!("{:?}", layer);
        assert!(debug_str.contains("SetMultipleResponseHeadersLayer"));
        let debug_rh = format!("{:?}", rh);
        assert!(debug_rh.contains("HeaderInsertionConfig"));

        let svc = SetMultipleResponseHeader::overriding(
            tower::service_fn(|_req: Request<Body>| async {
                Ok::<_, std::convert::Infallible>(Response::new(Body::empty()))
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("foo")).into()]
                as Vec<HeaderMetadata<HeaderValue>>,
        );
        let debug_svc = format!("{:?}", svc);
        assert!(debug_svc.contains("SetMultipleResponseHeader"));
    }

    #[tokio::test]
    async fn test_layer_construction_and_multiple_headers() {
        // Multiple different headers in the same vec
        let svc = tower::ServiceBuilder::new()
            .layer(SetMultipleResponseHeadersLayer::overriding(vec![
                (header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into(),
                (header::CACHE_CONTROL, HeaderValue::from_static("no-cache")).into(),
            ]))
            .service(service_fn(|_req: Request<Body>| async {
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }));

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();
        assert_eq!(res.headers()["content-type"], "text/html");
        assert_eq!(res.headers()["cache-control"], "no-cache");
    }

    #[tokio::test]
    async fn test_layer_with_empty_vec() {
        let svc = tower::ServiceBuilder::new()
            .layer(SetMultipleResponseHeadersLayer::<Response<Body>>::overriding(vec![]))
            .service(service_fn(|_req: Request<Body>| async {
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }));

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();
        // No headers should be set
        assert_eq!(res.headers().len(), 0);
    }

    #[tokio::test]
    async fn test_layer_with_static_and_closure_headers() {
        // Wrap the static value
        let static_meta = (header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into();

        // Wrap the closure
        let closure_meta = (header::X_FRAME_OPTIONS, |_res: &Response<Body>| {
            Some(HeaderValue::from_static("DENY"))
        })
            .into();

        let svc = tower::ServiceBuilder::new()
            .layer(SetMultipleResponseHeadersLayer::overriding(vec![
                static_meta,
                closure_meta,
            ]))
            .service(service_fn(|_req: Request<Body>| async {
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }));

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();
        assert_eq!(res.headers()["content-type"], "text/html");
        assert_eq!(res.headers()["x-frame-options"], "DENY");
    }

    #[test]
    fn test_debug_layer_and_service() {
        let meta: HeaderMetadata<HeaderValue> =
            (header::CONTENT_TYPE, HeaderValue::from_static("foo")).into();
        let layer = SetMultipleResponseHeadersLayer::overriding(vec![meta.clone()]);
        let debug_str = format!("{:?}", layer);
        assert!(debug_str.contains("SetMultipleResponseHeadersLayer"));

        // The test name promises the service too: its `Debug` output must name the type and
        // show both the wrapped service and the header list.
        let svc = SetMultipleResponseHeader::overriding(
            service_fn(|_req: Request<Body>| async {
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }),
            vec![meta],
        );
        let debug_svc = format!("{:?}", svc);
        assert!(debug_svc.starts_with("SetMultipleResponseHeader"));
        assert!(debug_svc.contains("inner"));
        assert!(debug_svc.contains("headers"));
    }
}