Skip to main content

tower_http/set_header/request/
multiple_headers.rs

1//! Set multiple headers on the request.
2//!
3//! See the root [`crate::set_header::request`] module for full documentation and usage examples.
4//!
5use http::{Request, Response};
6use std::{
7    fmt,
8    task::{Context, Poll},
9};
10use tower_layer::Layer;
11use tower_service::Service;
12
13use crate::set_header::{HeaderInsertionConfig, HeaderMetadata, InsertHeaderMode};
14
15/// Layer that applies [`SetMultipleRequestHeader`] which adds multiple request headers.
16///
17/// See [`SetMultipleRequestHeader`] for more details.
18#[derive(Clone)]
19pub struct SetMultipleRequestHeadersLayer<M> {
20    headers: Vec<HeaderInsertionConfig<M>>,
21}
22
23impl<M> fmt::Debug for SetMultipleRequestHeadersLayer<M> {
24    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25        f.debug_struct("SetMultipleRequestHeadersLayer")
26            .field("headers", &self.headers)
27            .finish()
28    }
29}
30
31impl<M> SetMultipleRequestHeadersLayer<M> {
32    /// Create a new [`SetMultipleRequestHeadersLayer`].
33    ///
34    /// If any previous value exists for the same header, it is removed and replaced with the new matching header value.
35    pub fn overriding(metadata: Vec<HeaderMetadata<M>>) -> Self {
36        let headers: Vec<HeaderInsertionConfig<M>> = metadata
37            .into_iter()
38            .map(|m| m.build_config(InsertHeaderMode::Override))
39            .collect();
40
41        Self::new(headers)
42    }
43
44    /// Create a new [`SetMultipleRequestHeadersLayer`].
45    ///
46    /// The new header is always added, preserving any existing values. If previous values exist,
47    /// the header will have multiple values.
48    pub fn appending(metadata: Vec<HeaderMetadata<M>>) -> Self {
49        let headers: Vec<HeaderInsertionConfig<M>> = metadata
50            .into_iter()
51            .map(|m| m.build_config(InsertHeaderMode::Append))
52            .collect();
53
54        Self::new(headers)
55    }
56
57    /// Create a new [`SetMultipleRequestHeadersLayer`].
58    ///
59    /// If a previous value exists for the header, the new value is not inserted.
60    pub fn if_not_present(metadata: Vec<HeaderMetadata<M>>) -> Self {
61        let headers: Vec<HeaderInsertionConfig<M>> = metadata
62            .into_iter()
63            .map(|m| m.build_config(InsertHeaderMode::IfNotPresent))
64            .collect();
65
66        Self::new(headers)
67    }
68
69    /// Internal constructor for a new [`SetMultipleRequestHeadersLayer`] from a list of headers.
70    fn new(headers: Vec<HeaderInsertionConfig<M>>) -> Self {
71        Self { headers }
72    }
73}
74
75impl<S, M> Layer<S> for SetMultipleRequestHeadersLayer<M> {
76    type Service = SetMultipleRequestHeader<S, M>;
77
78    fn layer(&self, inner: S) -> Self::Service {
79        SetMultipleRequestHeader {
80            inner,
81            headers: self.headers.clone(),
82        }
83    }
84}
85
86/// Middleware that sets multiple headers on the request.
87#[derive(Clone)]
88pub struct SetMultipleRequestHeader<S, M> {
89    inner: S,
90    headers: Vec<HeaderInsertionConfig<M>>,
91}
92
93impl<S, M> SetMultipleRequestHeader<S, M> {
94    /// Create a new [`SetMultipleRequestHeader`].
95    ///
96    /// If a previous value exists for the same header, it is removed and replaced with the new
97    /// header value.
98    pub fn overriding(inner: S, metadata: Vec<HeaderMetadata<M>>) -> Self {
99        let headers: Vec<HeaderInsertionConfig<M>> = metadata
100            .into_iter()
101            .map(|m| m.build_config(InsertHeaderMode::Override))
102            .collect();
103
104        Self::new(inner, headers)
105    }
106
107    /// Create a new [`SetMultipleRequestHeader`].
108    ///
109    /// The new header is always added, preserving any existing values. If previous values exist,
110    /// the header will have multiple values.
111    pub fn appending(inner: S, metadata: Vec<HeaderMetadata<M>>) -> Self {
112        let headers: Vec<HeaderInsertionConfig<M>> = metadata
113            .into_iter()
114            .map(|m| m.build_config(InsertHeaderMode::Append))
115            .collect();
116
117        Self::new(inner, headers)
118    }
119
120    /// Create a new [`SetMultipleRequestHeader`].
121    ///
122    /// If a previous value exists for the header, the new value is not inserted.
123    pub fn if_not_present(inner: S, metadata: Vec<HeaderMetadata<M>>) -> Self {
124        let headers: Vec<HeaderInsertionConfig<M>> = metadata
125            .into_iter()
126            .map(|m| m.build_config(InsertHeaderMode::IfNotPresent))
127            .collect();
128
129        Self::new(inner, headers)
130    }
131
132    /// Internal constructor for a new [`SetMultipleRequestHeader`] from an inner service and a list of headers.
133    fn new(inner: S, headers: Vec<HeaderInsertionConfig<M>>) -> Self {
134        Self { inner, headers }
135    }
136
137    define_inner_service_accessors!();
138}
139
140impl<S, M> fmt::Debug for SetMultipleRequestHeader<S, M>
141where
142    S: fmt::Debug,
143{
144    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
145        f.debug_struct("SetMultipleRequestHeader")
146            .field("inner", &self.inner)
147            .field("headers", &self.headers)
148            .finish()
149    }
150}
151
152impl<ReqBody, ResBody, S> Service<Request<ReqBody>>
153    for SetMultipleRequestHeader<S, Request<ReqBody>>
154where
155    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
156{
157    type Response = S::Response;
158    type Error = S::Error;
159    type Future = S::Future;
160
161    #[inline]
162    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
163        self.inner.poll_ready(cx)
164    }
165
166    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
167        for header in &mut self.headers {
168            header
169                .mode
170                .apply(&header.header_name, &mut req, &mut header.make);
171        }
172
173        self.inner.call(req)
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::Body;
    use http::{header, HeaderValue, Request, Response};
    use std::convert::Infallible;
    use tower::{service_fn, ServiceExt};

    /// Builds an empty-bodied request that already carries `content-type: good-content`.
    fn request_with_existing_content_type() -> Request<Body> {
        let mut req = Request::new(Body::empty());
        req.headers_mut().insert(
            header::CONTENT_TYPE,
            HeaderValue::from_static("good-content"),
        );
        req
    }

    /// Override mode replaces the pre-existing header value.
    #[tokio::test]
    async fn test_override_mode() {
        let svc = SetMultipleRequestHeader::overriding(
            service_fn(|req: Request<Body>| async move {
                assert_eq!(req.headers()["content-type"], "text/html");
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into()],
        );

        let _ = svc
            .oneshot(request_with_existing_content_type())
            .await
            .unwrap();
    }

    /// Append mode keeps the pre-existing value and adds the new one after it.
    #[tokio::test]
    async fn test_append_mode() {
        let svc = SetMultipleRequestHeader::appending(
            service_fn(|req: Request<Body>| async move {
                let mut values = req.headers().get_all("content-type").iter();
                assert_eq!(values.next().unwrap(), "good-content");
                assert_eq!(values.next().unwrap(), "text/html");
                assert_eq!(values.next(), None);

                Ok::<_, Infallible>(Response::new(Body::empty()))
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into()],
        );

        let _ = svc
            .oneshot(request_with_existing_content_type())
            .await
            .unwrap();
    }

    /// If-not-present mode leaves a pre-existing value untouched.
    #[tokio::test]
    async fn test_skip_if_present_mode() {
        let svc = SetMultipleRequestHeader::if_not_present(
            service_fn(|req: Request<Body>| async move {
                let mut values = req.headers().get_all("content-type").iter();
                assert_eq!(values.next().unwrap(), "good-content");
                assert_eq!(values.next(), None);

                Ok::<_, Infallible>(Response::new(Body::empty()))
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into()],
        );

        let _ = svc
            .oneshot(request_with_existing_content_type())
            .await
            .unwrap();
    }

    /// If-not-present mode inserts the header when the request has none.
    #[tokio::test]
    async fn test_skip_if_present_mode_when_not_present() {
        let svc = SetMultipleRequestHeader::if_not_present(
            service_fn(|req: Request<Body>| async move {
                let mut values = req.headers().get_all("content-type").iter();
                assert_eq!(values.next().unwrap(), "text/html");
                assert_eq!(values.next(), None);
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into()],
        );

        // The request deliberately carries no headers.
        let req = Request::new(Body::empty());

        let _ = svc.oneshot(req).await.unwrap();
    }

    /// The layer, a single header configuration, and the service all have usable `Debug` output.
    #[test]
    fn test_debug_impls() {
        let meta: HeaderMetadata<HeaderValue> =
            (header::CONTENT_TYPE, HeaderValue::from_static("bar")).into();
        let config = meta.clone().build_config(InsertHeaderMode::Override);

        let layer = SetMultipleRequestHeadersLayer::overriding(vec![meta]);
        let debug_layer = format!("{:?}", layer);
        assert!(debug_layer.contains("SetMultipleRequestHeadersLayer"));

        let debug_config = format!("{:?}", config);
        assert!(debug_config.contains("HeaderInsertionConfig"));

        let svc = SetMultipleRequestHeader::overriding(
            service_fn(|_req: Request<Body>| async {
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("foo")).into()]
                as Vec<HeaderMetadata<HeaderValue>>,
        );
        let debug_svc = format!("{:?}", svc);
        assert!(debug_svc.contains("SetMultipleRequestHeader"));
    }

    /// Several different headers supplied in one vec are all applied.
    #[tokio::test]
    async fn test_layer_construction_and_multiple_headers() {
        let svc = tower::ServiceBuilder::new()
            .layer(SetMultipleRequestHeadersLayer::overriding(vec![
                (header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into(),
                (header::CACHE_CONTROL, HeaderValue::from_static("no-cache")).into(),
            ]))
            .service(service_fn(|req: Request<Body>| async move {
                assert_eq!(req.headers()["content-type"], "text/html");
                assert_eq!(req.headers()["cache-control"], "no-cache");

                Ok::<_, Infallible>(Response::new(Body::empty()))
            }));

        let _ = svc.oneshot(Request::new(Body::empty())).await.unwrap();
    }

    /// A layer built from an empty vec leaves the request headers empty.
    #[tokio::test]
    async fn test_layer_with_empty_vec() {
        let header_metadatas: Vec<HeaderMetadata<Request<Body>>> = vec![];
        let svc = tower::ServiceBuilder::new()
            .layer(SetMultipleRequestHeadersLayer::<Request<Body>>::overriding(
                header_metadatas,
            ))
            .service(service_fn(|req: Request<Body>| async move {
                assert_eq!(req.headers().len(), 0);
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }));

        let _ = svc.oneshot(Request::new(Body::empty())).await.unwrap();
    }

    /// A static value and a closure-produced value can be mixed in the same layer.
    #[tokio::test]
    async fn test_layer_with_static_and_closure_headers_fixed() {
        // Header whose value is a fixed `HeaderValue`.
        let static_meta: HeaderMetadata<Request<Body>> =
            (header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into();

        // Header whose value is produced by a closure that receives the request.
        let closure_meta: HeaderMetadata<Request<Body>> =
            (header::X_FRAME_OPTIONS, |_req: &Request<Body>| {
                Some(HeaderValue::from_static("DENY"))
            })
                .into();

        let svc = tower::ServiceBuilder::new()
            .layer(SetMultipleRequestHeadersLayer::overriding(vec![
                static_meta,
                closure_meta,
            ]))
            .service(service_fn(|req: Request<Body>| async move {
                assert_eq!(req.headers()["content-type"], "text/html");
                assert_eq!(req.headers()["x-frame-options"], "DENY");

                Ok::<_, Infallible>(Response::new(Body::empty()))
            }));

        let _ = svc.oneshot(Request::new(Body::empty())).await.unwrap();
    }

    /// `Debug` output of the layer and of the service it produces names the type and its fields.
    #[test]
    fn test_debug_layer_and_service() {
        let meta: HeaderMetadata<HeaderValue> =
            (header::CONTENT_TYPE, HeaderValue::from_static("foo")).into();
        let layer = SetMultipleRequestHeadersLayer::overriding(vec![meta]);

        let debug_layer = format!("{:?}", layer);
        assert!(debug_layer.contains("SetMultipleRequestHeadersLayer"));
        assert!(debug_layer.contains("headers"));

        // Build the service through the layer so the service half of the test name is covered.
        let svc = layer.layer(service_fn(|_req: Request<Body>| async {
            Ok::<_, Infallible>(Response::new(Body::empty()))
        }));

        let debug_svc = format!("{:?}", svc);
        // The trailing " {" distinguishes the service name from the layer name, which shares
        // the same prefix.
        assert!(debug_svc.starts_with("SetMultipleRequestHeader {"));
        assert!(debug_svc.contains("inner"));
        assert!(debug_svc.contains("headers"));
    }
}