Skip to main content

tower_http/set_header/request/
multiple_headers.rs

1//! Set multiple headers on the request.
2//!
3//! See the root [`crate::set_header::request`] module for full documentation and usage examples.
4//!
5use http::{Request, Response};
6use std::{
7    fmt,
8    task::{Context, Poll},
9};
10use tower_layer::Layer;
11use tower_service::Service;
12
13use crate::set_header::{HeaderInsertionConfig, HeaderMetadata, InsertHeaderMode};
14
15/// Layer that applies [`SetMultipleRequestHeader`] which adds multiple request headers.
16///
17/// See [`SetMultipleRequestHeader`] for more details.
18pub struct SetMultipleRequestHeadersLayer<M> {
19    headers: Vec<HeaderInsertionConfig<M>>,
20}
21
22impl<M> Clone for SetMultipleRequestHeadersLayer<M> {
23    fn clone(&self) -> Self {
24        Self {
25            headers: self.headers.clone(),
26        }
27    }
28}
29
30impl<M> fmt::Debug for SetMultipleRequestHeadersLayer<M> {
31    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32        f.debug_struct("SetMultipleRequestHeadersLayer")
33            .field("headers", &self.headers)
34            .finish()
35    }
36}
37
38impl<M> SetMultipleRequestHeadersLayer<M> {
39    /// Create a new [`SetMultipleRequestHeadersLayer`].
40    ///
41    /// If any previous value exists for the same header, it is removed and replaced with the new matching header value.
42    pub fn overriding(metadata: Vec<HeaderMetadata<M>>) -> Self {
43        let headers: Vec<HeaderInsertionConfig<M>> = metadata
44            .into_iter()
45            .map(|m| m.build_config(InsertHeaderMode::Override))
46            .collect();
47
48        Self::new(headers)
49    }
50
51    /// Create a new [`SetMultipleRequestHeadersLayer`].
52    ///
53    /// The new header is always added, preserving any existing values. If previous values exist,
54    /// the header will have multiple values.
55    pub fn appending(metadata: Vec<HeaderMetadata<M>>) -> Self {
56        let headers: Vec<HeaderInsertionConfig<M>> = metadata
57            .into_iter()
58            .map(|m| m.build_config(InsertHeaderMode::Append))
59            .collect();
60
61        Self::new(headers)
62    }
63
64    /// Create a new [`SetMultipleRequestHeadersLayer`].
65    ///
66    /// If a previous value exists for the header, the new value is not inserted.
67    pub fn if_not_present(metadata: Vec<HeaderMetadata<M>>) -> Self {
68        let headers: Vec<HeaderInsertionConfig<M>> = metadata
69            .into_iter()
70            .map(|m| m.build_config(InsertHeaderMode::IfNotPresent))
71            .collect();
72
73        Self::new(headers)
74    }
75
76    /// Internal constructor for a new [`SetMultipleRequestHeadersLayer`] from a list of headers.
77    fn new(headers: Vec<HeaderInsertionConfig<M>>) -> Self {
78        Self { headers }
79    }
80}
81
82impl<S, M> Layer<S> for SetMultipleRequestHeadersLayer<M> {
83    type Service = SetMultipleRequestHeader<S, M>;
84
85    fn layer(&self, inner: S) -> Self::Service {
86        SetMultipleRequestHeader {
87            inner,
88            headers: self.headers.clone(),
89        }
90    }
91}
92
93/// Middleware that sets multiple headers on the request.
94pub struct SetMultipleRequestHeader<S, M> {
95    inner: S,
96    headers: Vec<HeaderInsertionConfig<M>>,
97}
98
99impl<S, M> Clone for SetMultipleRequestHeader<S, M>
100where
101    S: Clone,
102{
103    fn clone(&self) -> Self {
104        Self {
105            inner: self.inner.clone(),
106            headers: self.headers.clone(),
107        }
108    }
109}
110
111impl<S, M> SetMultipleRequestHeader<S, M> {
112    /// Create a new [`SetMultipleRequestHeader`].
113    ///
114    /// If a previous value exists for the same header, it is removed and replaced with the new
115    /// header value.
116    pub fn overriding(inner: S, metadata: Vec<HeaderMetadata<M>>) -> Self {
117        let headers: Vec<HeaderInsertionConfig<M>> = metadata
118            .into_iter()
119            .map(|m| m.build_config(InsertHeaderMode::Override))
120            .collect();
121
122        Self::new(inner, headers)
123    }
124
125    /// Create a new [`SetMultipleRequestHeader`].
126    ///
127    /// The new header is always added, preserving any existing values. If previous values exist,
128    /// the header will have multiple values.
129    pub fn appending(inner: S, metadata: Vec<HeaderMetadata<M>>) -> Self {
130        let headers: Vec<HeaderInsertionConfig<M>> = metadata
131            .into_iter()
132            .map(|m| m.build_config(InsertHeaderMode::Append))
133            .collect();
134
135        Self::new(inner, headers)
136    }
137
138    /// Create a new [`SetMultipleRequestHeader`].
139    ///
140    /// If a previous value exists for the header, the new value is not inserted.
141    pub fn if_not_present(inner: S, metadata: Vec<HeaderMetadata<M>>) -> Self {
142        let headers: Vec<HeaderInsertionConfig<M>> = metadata
143            .into_iter()
144            .map(|m| m.build_config(InsertHeaderMode::IfNotPresent))
145            .collect();
146
147        Self::new(inner, headers)
148    }
149
150    /// Internal constructor for a new [`SetMultipleRequestHeader`] from an inner service and a list of headers.
151    fn new(inner: S, headers: Vec<HeaderInsertionConfig<M>>) -> Self {
152        Self { inner, headers }
153    }
154
155    define_inner_service_accessors!();
156}
157
158impl<S, M> fmt::Debug for SetMultipleRequestHeader<S, M>
159where
160    S: fmt::Debug,
161{
162    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163        f.debug_struct("SetMultipleRequestHeader")
164            .field("inner", &self.inner)
165            .field("headers", &self.headers)
166            .finish()
167    }
168}
169
170impl<ReqBody, ResBody, S> Service<Request<ReqBody>>
171    for SetMultipleRequestHeader<S, Request<ReqBody>>
172where
173    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
174{
175    type Response = S::Response;
176    type Error = S::Error;
177    type Future = S::Future;
178
179    #[inline]
180    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
181        self.inner.poll_ready(cx)
182    }
183
184    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
185        for header in &mut self.headers {
186            header
187                .mode
188                .apply(&header.header_name, &mut req, &mut header.make);
189        }
190
191        self.inner.call(req)
192    }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::Body;
    use http::{header, HeaderValue, Request, Response};
    use std::convert::Infallible;
    use tower::{service_fn, ServiceExt};

    #[tokio::test]
    async fn test_override_mode() {
        let svc = SetMultipleRequestHeader::overriding(
            service_fn(|req: Request<Body>| async move {
                assert_eq!(req.headers()["content-type"], "text/html");
                // Override must replace the pre-existing value, not add alongside it.
                assert_eq!(req.headers().get_all("content-type").iter().count(), 1);
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into()],
        );

        let mut req = Request::new(Body::empty());

        // Add an initial CONTENT_TYPE header to the request
        req.headers_mut().insert(
            header::CONTENT_TYPE,
            HeaderValue::from_static("good-content"),
        );

        let _ = svc.oneshot(req).await.unwrap();
    }

    #[tokio::test]
    async fn test_append_mode() {
        let svc = SetMultipleRequestHeader::appending(
            service_fn(|req: Request<Body>| async move {
                let mut values = req.headers().get_all("content-type").iter();
                assert_eq!(values.next().unwrap(), "good-content");
                assert_eq!(values.next().unwrap(), "text/html");
                assert_eq!(values.next(), None);

                Ok::<_, Infallible>(Response::new(Body::empty()))
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into()],
        );

        // Add an initial CONTENT_TYPE header to the request
        let mut req = Request::new(Body::empty());
        req.headers_mut().insert(
            header::CONTENT_TYPE,
            HeaderValue::from_static("good-content"),
        );

        _ = svc.oneshot(req).await.unwrap();
    }

    #[tokio::test]
    async fn test_append_mode_with_repeated_header() {
        // Entries are applied in order, so repeating a header in append mode keeps every value.
        let svc = SetMultipleRequestHeader::appending(
            service_fn(|req: Request<Body>| async move {
                let mut values = req.headers().get_all("accept").iter();
                assert_eq!(values.next().unwrap(), "text/html");
                assert_eq!(values.next().unwrap(), "application/json");
                assert_eq!(values.next(), None);

                Ok::<_, Infallible>(Response::new(Body::empty()))
            }),
            vec![
                (header::ACCEPT, HeaderValue::from_static("text/html")).into(),
                (header::ACCEPT, HeaderValue::from_static("application/json")).into(),
            ],
        );

        _ = svc.oneshot(Request::new(Body::empty())).await.unwrap();
    }

    #[tokio::test]
    async fn test_skip_if_present_mode() {
        let svc = SetMultipleRequestHeader::if_not_present(
            service_fn(|req: Request<Body>| async move {
                let mut values = req.headers().get_all("content-type").iter();
                assert_eq!(values.next().unwrap(), "good-content");
                assert_eq!(values.next(), None);

                Ok::<_, Infallible>(Response::new(Body::empty()))
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into()],
        );

        // Add an initial CONTENT_TYPE header to the request
        let mut req = Request::new(Body::empty());
        req.headers_mut().insert(
            header::CONTENT_TYPE,
            HeaderValue::from_static("good-content"),
        );

        let _ = svc.oneshot(req).await.unwrap();
    }

    #[tokio::test]
    async fn test_skip_if_present_mode_when_not_present() {
        let svc = SetMultipleRequestHeader::if_not_present(
            service_fn(|req: Request<Body>| async move {
                let mut values = req.headers().get_all("content-type").iter();
                assert_eq!(values.next().unwrap(), "text/html");
                assert_eq!(values.next(), None);
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into()],
        );

        // No header on the request
        let req = Request::new(Body::empty());

        _ = svc.oneshot(req).await.unwrap();
    }

    #[test]
    fn test_debug_impls() {
        let meta: HeaderMetadata<HeaderValue> =
            (header::CONTENT_TYPE, HeaderValue::from_static("bar")).into();
        let rh = meta
            .clone()
            .build_config(crate::set_header::InsertHeaderMode::Override);
        let layer = SetMultipleRequestHeadersLayer::overriding(vec![meta]);
        let debug_str = format!("{:?}", layer);
        assert!(debug_str.contains("SetMultipleRequestHeadersLayer"));
        let debug_rh = format!("{:?}", rh);
        assert!(debug_rh.contains("HeaderInsertionConfig"));

        let svc = SetMultipleRequestHeader::overriding(
            tower::service_fn(|_req: Request<Body>| async {
                Ok::<_, std::convert::Infallible>(Response::new(Body::empty()))
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("foo")).into()]
                as Vec<HeaderMetadata<HeaderValue>>,
        );
        let debug_svc = format!("{:?}", svc);
        assert!(debug_svc.contains("SetMultipleRequestHeader"));
    }

    #[tokio::test]
    async fn test_layer_construction_and_multiple_headers() {
        // Multiple different headers in the same vec
        let svc = tower::ServiceBuilder::new()
            .layer(SetMultipleRequestHeadersLayer::overriding(vec![
                (header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into(),
                (header::CACHE_CONTROL, HeaderValue::from_static("no-cache")).into(),
            ]))
            .service(service_fn(|req: Request<Body>| async move {
                assert_eq!(req.headers()["content-type"], "text/html");
                assert_eq!(req.headers()["cache-control"], "no-cache");

                Ok::<_, Infallible>(Response::new(Body::empty()))
            }));

        _ = svc.oneshot(Request::new(Body::empty())).await.unwrap();
    }

    #[tokio::test]
    async fn test_layer_with_empty_vec() {
        let header_metadatas: Vec<HeaderMetadata<Request<Body>>> = vec![];
        let svc = tower::ServiceBuilder::new()
            .layer(SetMultipleRequestHeadersLayer::<Request<Body>>::overriding(
                header_metadatas,
            ))
            .service(service_fn(|req: Request<Body>| async move {
                assert_eq!(req.headers().len(), 0);
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }));

        _ = svc.oneshot(Request::new(Body::empty())).await.unwrap();
    }

    #[tokio::test]
    async fn test_layer_with_static_and_closure_headers_fixed() {
        // Wrap the static value
        let static_meta: HeaderMetadata<Request<Body>> =
            (header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into();

        // Wrap the closure
        let closure_meta: HeaderMetadata<Request<Body>> =
            (header::X_FRAME_OPTIONS, |_req: &Request<Body>| {
                Some(HeaderValue::from_static("DENY"))
            })
                .into();

        let svc = tower::ServiceBuilder::new()
            .layer(SetMultipleRequestHeadersLayer::overriding(vec![
                static_meta,
                closure_meta,
            ]))
            .service(service_fn(|req: Request<Body>| async move {
                assert_eq!(req.headers()["content-type"], "text/html");
                assert_eq!(req.headers()["x-frame-options"], "DENY");

                Ok::<_, Infallible>(Response::new(Body::empty()))
            }));

        _ = svc.oneshot(Request::new(Body::empty())).await.unwrap();
    }

    #[test]
    fn test_debug_layer_and_service() {
        let meta: HeaderMetadata<HeaderValue> =
            (header::CONTENT_TYPE, HeaderValue::from_static("foo")).into();
        let layer = SetMultipleRequestHeadersLayer::overriding(vec![meta]);
        let debug_str = format!("{:?}", layer);
        assert!(debug_str.contains("SetMultipleRequestHeadersLayer"));

        // The service produced by the layer must be debuggable too, as the test name promises.
        let svc = layer.layer(service_fn(|_req: Request<Body>| async {
            Ok::<_, Infallible>(Response::new(Body::empty()))
        }));
        let debug_svc = format!("{:?}", svc);
        assert!(debug_svc.contains("SetMultipleRequestHeader"));
    }

    #[test]
    fn test_service_clone() {
        struct NonCloneBody;
        let svc = tower::ServiceBuilder::new()
            .layer(SetMultipleRequestHeadersLayer::<Request<NonCloneBody>>::overriding(vec![]))
            .check_clone()
            .service(service_fn(|_: Request<NonCloneBody>| async move {
                Ok::<_, Infallible>(Response::new(NonCloneBody))
            }));

        fn check_service_and_clone<T: Service<Request<NonCloneBody>> + Clone>(_: T) {}
        check_service_and_clone(svc);
    }
}