tower_http/set_header/request/
multiple_headers.rs1use http::{Request, Response};
6use std::{
7 fmt,
8 task::{Context, Poll},
9};
10use tower_layer::Layer;
11use tower_service::Service;
12
13use crate::set_header::{HeaderInsertionConfig, HeaderMetadata, InsertHeaderMode};
14
15#[derive(Clone)]
19pub struct SetMultipleRequestHeadersLayer<M> {
20 headers: Vec<HeaderInsertionConfig<M>>,
21}
22
23impl<M> fmt::Debug for SetMultipleRequestHeadersLayer<M> {
24 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25 f.debug_struct("SetMultipleRequestHeadersLayer")
26 .field("headers", &self.headers)
27 .finish()
28 }
29}
30
31impl<M> SetMultipleRequestHeadersLayer<M> {
32 pub fn overriding(metadata: Vec<HeaderMetadata<M>>) -> Self {
36 let headers: Vec<HeaderInsertionConfig<M>> = metadata
37 .into_iter()
38 .map(|m| m.build_config(InsertHeaderMode::Override))
39 .collect();
40
41 Self::new(headers)
42 }
43
44 pub fn appending(metadata: Vec<HeaderMetadata<M>>) -> Self {
49 let headers: Vec<HeaderInsertionConfig<M>> = metadata
50 .into_iter()
51 .map(|m| m.build_config(InsertHeaderMode::Append))
52 .collect();
53
54 Self::new(headers)
55 }
56
57 pub fn if_not_present(metadata: Vec<HeaderMetadata<M>>) -> Self {
61 let headers: Vec<HeaderInsertionConfig<M>> = metadata
62 .into_iter()
63 .map(|m| m.build_config(InsertHeaderMode::IfNotPresent))
64 .collect();
65
66 Self::new(headers)
67 }
68
69 fn new(headers: Vec<HeaderInsertionConfig<M>>) -> Self {
71 Self { headers }
72 }
73}
74
75impl<S, M> Layer<S> for SetMultipleRequestHeadersLayer<M> {
76 type Service = SetMultipleRequestHeader<S, M>;
77
78 fn layer(&self, inner: S) -> Self::Service {
79 SetMultipleRequestHeader {
80 inner,
81 headers: self.headers.clone(),
82 }
83 }
84}
85
86#[derive(Clone)]
88pub struct SetMultipleRequestHeader<S, M> {
89 inner: S,
90 headers: Vec<HeaderInsertionConfig<M>>,
91}
92
93impl<S, M> SetMultipleRequestHeader<S, M> {
94 pub fn overriding(inner: S, metadata: Vec<HeaderMetadata<M>>) -> Self {
99 let headers: Vec<HeaderInsertionConfig<M>> = metadata
100 .into_iter()
101 .map(|m| m.build_config(InsertHeaderMode::Override))
102 .collect();
103
104 Self::new(inner, headers)
105 }
106
107 pub fn appending(inner: S, metadata: Vec<HeaderMetadata<M>>) -> Self {
112 let headers: Vec<HeaderInsertionConfig<M>> = metadata
113 .into_iter()
114 .map(|m| m.build_config(InsertHeaderMode::Append))
115 .collect();
116
117 Self::new(inner, headers)
118 }
119
120 pub fn if_not_present(inner: S, metadata: Vec<HeaderMetadata<M>>) -> Self {
124 let headers: Vec<HeaderInsertionConfig<M>> = metadata
125 .into_iter()
126 .map(|m| m.build_config(InsertHeaderMode::IfNotPresent))
127 .collect();
128
129 Self::new(inner, headers)
130 }
131
132 fn new(inner: S, headers: Vec<HeaderInsertionConfig<M>>) -> Self {
134 Self { inner, headers }
135 }
136
137 define_inner_service_accessors!();
138}
139
140impl<S, M> fmt::Debug for SetMultipleRequestHeader<S, M>
141where
142 S: fmt::Debug,
143{
144 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
145 f.debug_struct("SetMultipleRequestHeader")
146 .field("inner", &self.inner)
147 .field("headers", &self.headers)
148 .finish()
149 }
150}
151
152impl<ReqBody, ResBody, S> Service<Request<ReqBody>>
153 for SetMultipleRequestHeader<S, Request<ReqBody>>
154where
155 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
156{
157 type Response = S::Response;
158 type Error = S::Error;
159 type Future = S::Future;
160
161 #[inline]
162 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
163 self.inner.poll_ready(cx)
164 }
165
166 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
167 for header in &mut self.headers {
168 header
169 .mode
170 .apply(&header.header_name, &mut req, &mut header.make);
171 }
172
173 self.inner.call(req)
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::Body;
    use http::{header, HeaderValue, Request, Response};
    use std::convert::Infallible;
    use tower::{service_fn, ServiceExt};

    /// An existing header value is replaced in override mode.
    #[tokio::test]
    async fn test_override_mode() {
        let svc = SetMultipleRequestHeader::overriding(
            service_fn(|req: Request<Body>| async move {
                assert_eq!(req.headers()["content-type"], "text/html");
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into()],
        );

        let mut req = Request::new(Body::empty());
        req.headers_mut().insert(
            header::CONTENT_TYPE,
            HeaderValue::from_static("good-content"),
        );

        let _ = svc.oneshot(req).await.unwrap();
    }

    /// Append mode keeps the existing value and adds the new one after it.
    #[tokio::test]
    async fn test_append_mode() {
        let svc = SetMultipleRequestHeader::appending(
            service_fn(|req: Request<Body>| async move {
                let mut values = req.headers().get_all("content-type").iter();
                assert_eq!(values.next().unwrap(), "good-content");
                assert_eq!(values.next().unwrap(), "text/html");
                assert_eq!(values.next(), None);

                Ok::<_, Infallible>(Response::new(Body::empty()))
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into()],
        );

        let mut req = Request::new(Body::empty());
        req.headers_mut().insert(
            header::CONTENT_TYPE,
            HeaderValue::from_static("good-content"),
        );

        let _ = svc.oneshot(req).await.unwrap();
    }

    /// If-not-present mode leaves an existing header untouched.
    #[tokio::test]
    async fn test_skip_if_present_mode() {
        let svc = SetMultipleRequestHeader::if_not_present(
            service_fn(|req: Request<Body>| async move {
                let mut values = req.headers().get_all("content-type").iter();
                assert_eq!(values.next().unwrap(), "good-content");
                assert_eq!(values.next(), None);

                Ok::<_, Infallible>(Response::new(Body::empty()))
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into()],
        );

        let mut req = Request::new(Body::empty());
        req.headers_mut().insert(
            header::CONTENT_TYPE,
            HeaderValue::from_static("good-content"),
        );

        let _ = svc.oneshot(req).await.unwrap();
    }

    /// If-not-present mode inserts the header when the request lacks it.
    #[tokio::test]
    async fn test_skip_if_present_mode_when_not_present() {
        let svc = SetMultipleRequestHeader::if_not_present(
            service_fn(|req: Request<Body>| async move {
                let mut values = req.headers().get_all("content-type").iter();
                assert_eq!(values.next().unwrap(), "text/html");
                assert_eq!(values.next(), None);
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into()],
        );

        let req = Request::new(Body::empty());

        let _ = svc.oneshot(req).await.unwrap();
    }

    #[test]
    fn test_debug_impls() {
        let meta: HeaderMetadata<HeaderValue> =
            (header::CONTENT_TYPE, HeaderValue::from_static("bar")).into();
        let rh = meta
            .clone()
            .build_config(crate::set_header::InsertHeaderMode::Override);
        let layer = SetMultipleRequestHeadersLayer::overriding(vec![meta]);
        let debug_str = format!("{:?}", layer);
        assert!(debug_str.contains("SetMultipleRequestHeadersLayer"));
        let debug_rh = format!("{:?}", rh);
        assert!(debug_rh.contains("HeaderInsertionConfig"));

        let metadata: Vec<HeaderMetadata<HeaderValue>> =
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("foo")).into()];
        let svc = SetMultipleRequestHeader::overriding(
            service_fn(|_req: Request<Body>| async {
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }),
            metadata,
        );
        let debug_svc = format!("{:?}", svc);
        assert!(debug_svc.contains("SetMultipleRequestHeader"));
    }

    /// Multiple headers configured on the layer all reach the inner service.
    #[tokio::test]
    async fn test_layer_construction_and_multiple_headers() {
        let svc = tower::ServiceBuilder::new()
            .layer(SetMultipleRequestHeadersLayer::overriding(vec![
                (header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into(),
                (header::CACHE_CONTROL, HeaderValue::from_static("no-cache")).into(),
            ]))
            .service(service_fn(|req: Request<Body>| async move {
                assert_eq!(req.headers()["content-type"], "text/html");
                assert_eq!(req.headers()["cache-control"], "no-cache");

                Ok::<_, Infallible>(Response::new(Body::empty()))
            }));

        let _ = svc.oneshot(Request::new(Body::empty())).await.unwrap();
    }

    /// An empty header list leaves the request headers empty.
    #[tokio::test]
    async fn test_layer_with_empty_vec() {
        let header_metadatas: Vec<HeaderMetadata<Request<Body>>> = vec![];
        let svc = tower::ServiceBuilder::new()
            .layer(SetMultipleRequestHeadersLayer::<Request<Body>>::overriding(
                header_metadatas,
            ))
            .service(service_fn(|req: Request<Body>| async move {
                assert_eq!(req.headers().len(), 0);
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }));

        let _ = svc.oneshot(Request::new(Body::empty())).await.unwrap();
    }

    /// Static values and closure-produced values can be mixed in one layer.
    #[tokio::test]
    async fn test_layer_with_static_and_closure_headers_fixed() {
        let static_meta: HeaderMetadata<Request<Body>> =
            (header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into();

        let closure_meta: HeaderMetadata<Request<Body>> =
            (header::X_FRAME_OPTIONS, |_req: &Request<Body>| {
                Some(HeaderValue::from_static("DENY"))
            })
                .into();

        let svc = tower::ServiceBuilder::new()
            .layer(SetMultipleRequestHeadersLayer::overriding(vec![
                static_meta,
                closure_meta,
            ]))
            .service(service_fn(|req: Request<Body>| async move {
                assert_eq!(req.headers()["content-type"], "text/html");
                assert_eq!(req.headers()["x-frame-options"], "DENY");

                Ok::<_, Infallible>(Response::new(Body::empty()))
            }));

        let _ = svc.oneshot(Request::new(Body::empty())).await.unwrap();
    }

    /// Both the layer and the service it produces render their fields via Debug.
    #[test]
    fn test_debug_layer_and_service() {
        let meta: HeaderMetadata<HeaderValue> =
            (header::CONTENT_TYPE, HeaderValue::from_static("foo")).into();
        let layer = SetMultipleRequestHeadersLayer::overriding(vec![meta]);
        let debug_str = format!("{:?}", layer);
        assert!(debug_str.contains("SetMultipleRequestHeadersLayer"));
        assert!(debug_str.contains("headers"));

        // The original version of this test never exercised the service's Debug impl.
        let metadata: Vec<HeaderMetadata<HeaderValue>> =
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("foo")).into()];
        let svc = SetMultipleRequestHeader::overriding(
            service_fn(|_req: Request<Body>| async {
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }),
            metadata,
        );
        let debug_svc = format!("{:?}", svc);
        assert!(debug_svc.contains("SetMultipleRequestHeader"));
        assert!(debug_svc.contains("inner"));
        assert!(debug_svc.contains("headers"));
    }
}