tower_http/set_header/request/
multiple_headers.rs1use http::{Request, Response};
6use std::{
7 fmt,
8 task::{Context, Poll},
9};
10use tower_layer::Layer;
11use tower_service::Service;
12
13use crate::set_header::{HeaderInsertionConfig, HeaderMetadata, InsertHeaderMode};
14
15pub struct SetMultipleRequestHeadersLayer<M> {
19 headers: Vec<HeaderInsertionConfig<M>>,
20}
21
22impl<M> Clone for SetMultipleRequestHeadersLayer<M> {
23 fn clone(&self) -> Self {
24 Self {
25 headers: self.headers.clone(),
26 }
27 }
28}
29
30impl<M> fmt::Debug for SetMultipleRequestHeadersLayer<M> {
31 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32 f.debug_struct("SetMultipleRequestHeadersLayer")
33 .field("headers", &self.headers)
34 .finish()
35 }
36}
37
38impl<M> SetMultipleRequestHeadersLayer<M> {
39 pub fn overriding(metadata: Vec<HeaderMetadata<M>>) -> Self {
43 let headers: Vec<HeaderInsertionConfig<M>> = metadata
44 .into_iter()
45 .map(|m| m.build_config(InsertHeaderMode::Override))
46 .collect();
47
48 Self::new(headers)
49 }
50
51 pub fn appending(metadata: Vec<HeaderMetadata<M>>) -> Self {
56 let headers: Vec<HeaderInsertionConfig<M>> = metadata
57 .into_iter()
58 .map(|m| m.build_config(InsertHeaderMode::Append))
59 .collect();
60
61 Self::new(headers)
62 }
63
64 pub fn if_not_present(metadata: Vec<HeaderMetadata<M>>) -> Self {
68 let headers: Vec<HeaderInsertionConfig<M>> = metadata
69 .into_iter()
70 .map(|m| m.build_config(InsertHeaderMode::IfNotPresent))
71 .collect();
72
73 Self::new(headers)
74 }
75
76 fn new(headers: Vec<HeaderInsertionConfig<M>>) -> Self {
78 Self { headers }
79 }
80}
81
82impl<S, M> Layer<S> for SetMultipleRequestHeadersLayer<M> {
83 type Service = SetMultipleRequestHeader<S, M>;
84
85 fn layer(&self, inner: S) -> Self::Service {
86 SetMultipleRequestHeader {
87 inner,
88 headers: self.headers.clone(),
89 }
90 }
91}
92
93pub struct SetMultipleRequestHeader<S, M> {
95 inner: S,
96 headers: Vec<HeaderInsertionConfig<M>>,
97}
98
99impl<S, M> Clone for SetMultipleRequestHeader<S, M>
100where
101 S: Clone,
102{
103 fn clone(&self) -> Self {
104 Self {
105 inner: self.inner.clone(),
106 headers: self.headers.clone(),
107 }
108 }
109}
110
111impl<S, M> SetMultipleRequestHeader<S, M> {
112 pub fn overriding(inner: S, metadata: Vec<HeaderMetadata<M>>) -> Self {
117 let headers: Vec<HeaderInsertionConfig<M>> = metadata
118 .into_iter()
119 .map(|m| m.build_config(InsertHeaderMode::Override))
120 .collect();
121
122 Self::new(inner, headers)
123 }
124
125 pub fn appending(inner: S, metadata: Vec<HeaderMetadata<M>>) -> Self {
130 let headers: Vec<HeaderInsertionConfig<M>> = metadata
131 .into_iter()
132 .map(|m| m.build_config(InsertHeaderMode::Append))
133 .collect();
134
135 Self::new(inner, headers)
136 }
137
138 pub fn if_not_present(inner: S, metadata: Vec<HeaderMetadata<M>>) -> Self {
142 let headers: Vec<HeaderInsertionConfig<M>> = metadata
143 .into_iter()
144 .map(|m| m.build_config(InsertHeaderMode::IfNotPresent))
145 .collect();
146
147 Self::new(inner, headers)
148 }
149
150 fn new(inner: S, headers: Vec<HeaderInsertionConfig<M>>) -> Self {
152 Self { inner, headers }
153 }
154
155 define_inner_service_accessors!();
156}
157
158impl<S, M> fmt::Debug for SetMultipleRequestHeader<S, M>
159where
160 S: fmt::Debug,
161{
162 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163 f.debug_struct("SetMultipleRequestHeader")
164 .field("inner", &self.inner)
165 .field("headers", &self.headers)
166 .finish()
167 }
168}
169
170impl<ReqBody, ResBody, S> Service<Request<ReqBody>>
171 for SetMultipleRequestHeader<S, Request<ReqBody>>
172where
173 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
174{
175 type Response = S::Response;
176 type Error = S::Error;
177 type Future = S::Future;
178
179 #[inline]
180 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
181 self.inner.poll_ready(cx)
182 }
183
184 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
185 for header in &mut self.headers {
186 header
187 .mode
188 .apply(&header.header_name, &mut req, &mut header.make);
189 }
190
191 self.inner.call(req)
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::Body;
    use http::{header, HeaderValue, Request, Response};
    use std::convert::Infallible;
    use tower::{service_fn, ServiceExt};

    /// Builds an empty-bodied request that already carries
    /// `content-type: good-content`, so each mode can be checked against a
    /// pre-existing header value.
    fn request_with_existing_content_type() -> Request<Body> {
        let mut req = Request::new(Body::empty());
        req.headers_mut().insert(
            header::CONTENT_TYPE,
            HeaderValue::from_static("good-content"),
        );
        req
    }

    /// Override mode replaces the value that was already on the request.
    #[tokio::test]
    async fn test_override_mode() {
        let svc = SetMultipleRequestHeader::overriding(
            service_fn(|req: Request<Body>| async move {
                assert_eq!(req.headers()["content-type"], "text/html");
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into()],
        );

        let _ = svc
            .oneshot(request_with_existing_content_type())
            .await
            .unwrap();
    }

    /// Append mode keeps the existing value and adds the new one after it.
    #[tokio::test]
    async fn test_append_mode() {
        let svc = SetMultipleRequestHeader::appending(
            service_fn(|req: Request<Body>| async move {
                let mut values = req.headers().get_all("content-type").iter();
                assert_eq!(values.next().unwrap(), "good-content");
                assert_eq!(values.next().unwrap(), "text/html");
                assert_eq!(values.next(), None);

                Ok::<_, Infallible>(Response::new(Body::empty()))
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into()],
        );

        let _ = svc
            .oneshot(request_with_existing_content_type())
            .await
            .unwrap();
    }

    /// If-not-present mode leaves an existing value untouched.
    #[tokio::test]
    async fn test_skip_if_present_mode() {
        let svc = SetMultipleRequestHeader::if_not_present(
            service_fn(|req: Request<Body>| async move {
                let mut values = req.headers().get_all("content-type").iter();
                assert_eq!(values.next().unwrap(), "good-content");
                assert_eq!(values.next(), None);

                Ok::<_, Infallible>(Response::new(Body::empty()))
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into()],
        );

        let _ = svc
            .oneshot(request_with_existing_content_type())
            .await
            .unwrap();
    }

    /// If-not-present mode inserts the value when the header is missing.
    #[tokio::test]
    async fn test_skip_if_present_mode_when_not_present() {
        let svc = SetMultipleRequestHeader::if_not_present(
            service_fn(|req: Request<Body>| async move {
                let mut values = req.headers().get_all("content-type").iter();
                assert_eq!(values.next().unwrap(), "text/html");
                assert_eq!(values.next(), None);
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into()],
        );

        let _ = svc.oneshot(Request::new(Body::empty())).await.unwrap();
    }

    /// The layer, a single header configuration and the service all have
    /// usable `Debug` output.
    #[test]
    fn test_debug_impls() {
        let meta: HeaderMetadata<HeaderValue> =
            (header::CONTENT_TYPE, HeaderValue::from_static("bar")).into();
        let rh = meta.clone().build_config(InsertHeaderMode::Override);
        let layer = SetMultipleRequestHeadersLayer::overriding(vec![meta]);
        let debug_str = format!("{:?}", layer);
        assert!(debug_str.contains("SetMultipleRequestHeadersLayer"));
        let debug_rh = format!("{:?}", rh);
        assert!(debug_rh.contains("HeaderInsertionConfig"));

        let svc = SetMultipleRequestHeader::overriding(
            service_fn(|_req: Request<Body>| async {
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("foo")).into()]
                as Vec<HeaderMetadata<HeaderValue>>,
        );
        let debug_svc = format!("{:?}", svc);
        assert!(debug_svc.contains("SetMultipleRequestHeader"));
    }

    /// A layer with several headers applies all of them.
    #[tokio::test]
    async fn test_layer_construction_and_multiple_headers() {
        let svc = tower::ServiceBuilder::new()
            .layer(SetMultipleRequestHeadersLayer::overriding(vec![
                (header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into(),
                (header::CACHE_CONTROL, HeaderValue::from_static("no-cache")).into(),
            ]))
            .service(service_fn(|req: Request<Body>| async move {
                assert_eq!(req.headers()["content-type"], "text/html");
                assert_eq!(req.headers()["cache-control"], "no-cache");

                Ok::<_, Infallible>(Response::new(Body::empty()))
            }));

        let _ = svc.oneshot(Request::new(Body::empty())).await.unwrap();
    }

    /// An empty header list leaves the request without any headers.
    #[tokio::test]
    async fn test_layer_with_empty_vec() {
        let header_metadatas: Vec<HeaderMetadata<Request<Body>>> = vec![];
        let svc = tower::ServiceBuilder::new()
            .layer(SetMultipleRequestHeadersLayer::<Request<Body>>::overriding(
                header_metadatas,
            ))
            .service(service_fn(|req: Request<Body>| async move {
                assert_eq!(req.headers().len(), 0);
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }));

        let _ = svc.oneshot(Request::new(Body::empty())).await.unwrap();
    }

    /// Static values and closure-produced values can be mixed in one layer.
    #[tokio::test]
    async fn test_layer_with_static_and_closure_headers_fixed() {
        let static_meta: HeaderMetadata<Request<Body>> =
            (header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into();

        let closure_meta: HeaderMetadata<Request<Body>> =
            (header::X_FRAME_OPTIONS, |_req: &Request<Body>| {
                Some(HeaderValue::from_static("DENY"))
            })
                .into();

        let svc = tower::ServiceBuilder::new()
            .layer(SetMultipleRequestHeadersLayer::overriding(vec![
                static_meta,
                closure_meta,
            ]))
            .service(service_fn(|req: Request<Body>| async move {
                assert_eq!(req.headers()["content-type"], "text/html");
                assert_eq!(req.headers()["x-frame-options"], "DENY");

                Ok::<_, Infallible>(Response::new(Body::empty()))
            }));

        let _ = svc.oneshot(Request::new(Body::empty())).await.unwrap();
    }

    /// Checks the `Debug` output of both the layer and the service, including
    /// the field names written by their `debug_struct` implementations.
    #[test]
    fn test_debug_layer_and_service() {
        let meta: HeaderMetadata<HeaderValue> =
            (header::CONTENT_TYPE, HeaderValue::from_static("foo")).into();

        let layer = SetMultipleRequestHeadersLayer::overriding(vec![meta.clone()]);
        let layer_debug = format!("{:?}", layer);
        assert!(layer_debug.contains("SetMultipleRequestHeadersLayer"));
        assert!(layer_debug.contains("headers"));

        let svc = SetMultipleRequestHeader::overriding(
            service_fn(|_req: Request<Body>| async {
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }),
            vec![meta],
        );
        let svc_debug = format!("{:?}", svc);
        assert!(svc_debug.contains("SetMultipleRequestHeader"));
        assert!(svc_debug.contains("inner"));
        assert!(svc_debug.contains("headers"));
    }

    /// The layered service stays `Clone` even when the body type is not.
    #[test]
    fn test_service_clone() {
        struct NonCloneBody;
        let svc = tower::ServiceBuilder::new()
            .layer(SetMultipleRequestHeadersLayer::<Request<NonCloneBody>>::overriding(vec![]))
            .check_clone()
            .service(service_fn(|_: Request<NonCloneBody>| async move {
                Ok::<_, Infallible>(Response::new(NonCloneBody))
            }));

        fn check_service_and_clone<T: Service<Request<NonCloneBody>> + Clone>(_: T) {}
        check_service_and_clone(svc);
    }
}