tower_http/set_header/response/
multiple_headers.rs1use http::{Request, Response};
6use pin_project_lite::pin_project;
7use std::{
8 fmt,
9 future::Future,
10 pin::Pin,
11 task::{ready, Context, Poll},
12};
13use tower_layer::Layer;
14use tower_service::Service;
15
16use crate::set_header::{HeaderInsertionConfig, HeaderMetadata, InsertHeaderMode};
17
18#[derive(Clone)]
22pub struct SetMultipleResponseHeadersLayer<M> {
23 headers: Vec<HeaderInsertionConfig<M>>,
24}
25
26impl<M> fmt::Debug for SetMultipleResponseHeadersLayer<M> {
27 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28 f.debug_struct("SetMultipleResponseHeadersLayer")
29 .field("headers", &self.headers)
30 .finish()
31 }
32}
33
34impl<M> SetMultipleResponseHeadersLayer<M> {
35 pub fn overriding(metadata: Vec<HeaderMetadata<M>>) -> Self {
39 let headers: Vec<HeaderInsertionConfig<M>> = metadata
40 .into_iter()
41 .map(|m| m.build_config(InsertHeaderMode::Override))
42 .collect();
43
44 Self::new(headers)
45 }
46
47 pub fn appending(metadata: Vec<HeaderMetadata<M>>) -> Self {
51 let headers: Vec<HeaderInsertionConfig<M>> = metadata
52 .into_iter()
53 .map(|m| m.build_config(InsertHeaderMode::Append))
54 .collect();
55
56 Self::new(headers)
57 }
58
59 pub fn if_not_present(metadata: Vec<HeaderMetadata<M>>) -> Self {
63 let headers: Vec<HeaderInsertionConfig<M>> = metadata
64 .into_iter()
65 .map(|m| m.build_config(InsertHeaderMode::IfNotPresent))
66 .collect();
67
68 Self::new(headers)
69 }
70
71 fn new(headers: Vec<HeaderInsertionConfig<M>>) -> Self {
73 Self { headers }
74 }
75}
76
77impl<S, M> Layer<S> for SetMultipleResponseHeadersLayer<M> {
78 type Service = SetMultipleResponseHeader<S, M>;
79
80 fn layer(&self, inner: S) -> Self::Service {
81 SetMultipleResponseHeader {
82 inner,
83 headers: self.headers.clone(),
84 }
85 }
86}
87
88#[derive(Clone)]
91pub struct SetMultipleResponseHeader<S, M> {
92 inner: S,
93 headers: Vec<HeaderInsertionConfig<M>>,
94}
95
96impl<S, M> SetMultipleResponseHeader<S, M> {
97 pub fn overriding(inner: S, metadata: Vec<HeaderMetadata<M>>) -> Self {
101 let headers: Vec<HeaderInsertionConfig<M>> = metadata
102 .into_iter()
103 .map(|m| m.build_config(InsertHeaderMode::Override))
104 .collect();
105
106 Self::new(inner, headers)
107 }
108
109 pub fn appending(inner: S, metadata: Vec<HeaderMetadata<M>>) -> Self {
113 let headers: Vec<HeaderInsertionConfig<M>> = metadata
114 .into_iter()
115 .map(|m| m.build_config(InsertHeaderMode::Append))
116 .collect();
117
118 Self::new(inner, headers)
119 }
120
121 pub fn if_not_present(inner: S, metadata: Vec<HeaderMetadata<M>>) -> Self {
125 let headers: Vec<HeaderInsertionConfig<M>> = metadata
126 .into_iter()
127 .map(|m| m.build_config(InsertHeaderMode::IfNotPresent))
128 .collect();
129
130 Self::new(inner, headers)
131 }
132
133 fn new(inner: S, headers: Vec<HeaderInsertionConfig<M>>) -> Self {
135 Self { inner, headers }
136 }
137
138 define_inner_service_accessors!();
139}
140
141impl<S, M> fmt::Debug for SetMultipleResponseHeader<S, M>
142where
143 S: fmt::Debug,
144{
145 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
146 f.debug_struct("SetMultipleResponseHeader")
147 .field("inner", &self.inner)
148 .field("headers", &self.headers)
149 .finish()
150 }
151}
152
153impl<ReqBody, ResBody, S> Service<Request<ReqBody>>
154 for SetMultipleResponseHeader<S, Response<ResBody>>
155where
156 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
157{
158 type Response = S::Response;
159 type Error = S::Error;
160 type Future = ResponseFuture<S::Future, Response<ResBody>>;
161
162 #[inline]
163 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
164 self.inner.poll_ready(cx)
165 }
166
167 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
169 ResponseFuture {
170 future: self.inner.call(req),
171 headers: self.headers.clone(),
172 }
173 }
174}
175
176pin_project! {
177 #[derive(Debug)]
179 pub struct ResponseFuture<F, M> {
180 #[pin]
181 future: F,
182 headers: Vec<HeaderInsertionConfig<M>>,
183 }
184}
185
186impl<F, ResBody, E> Future for ResponseFuture<F, Response<ResBody>>
187where
188 F: Future<Output = Result<Response<ResBody>, E>>,
189{
190 type Output = F::Output;
191
192 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
194 let this = self.project();
195 let mut res = ready!(this.future.poll(cx)?);
196
197 for header in this.headers {
198 header
199 .mode
200 .apply(&header.header_name, &mut res, &mut header.make);
201 }
202
203 Poll::Ready(Ok(res))
204 }
205}
206
#[cfg(test)]
mod tests {
    //! Tests for the multiple-response-headers middleware: the three insertion
    //! modes, layer construction, metadata conversions and `Debug` output.

    use super::*;
    use crate::{
        set_header::{BoxedMakeHeaderValue, MakeHeaderValue as _},
        test_helpers::Body,
    };
    use http::{header, HeaderName, HeaderValue};
    use std::convert::Infallible;
    use tower::{service_fn, ServiceExt};

    // Override mode: the configured value replaces the one set by the inner service.
    #[tokio::test]
    async fn test_override_mode() {
        let svc = SetMultipleResponseHeader::overriding(
            service_fn(|_req: Request<Body>| async {
                let res = Response::builder()
                    .header(header::CONTENT_TYPE, "good-content")
                    .body(Body::empty())
                    .unwrap();
                Ok::<_, Infallible>(res)
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into()],
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        let mut values = res.headers().get_all(header::CONTENT_TYPE).iter();
        assert_eq!(values.next().unwrap(), "text/html");
        assert_eq!(values.next(), None);
    }

    // Append mode: the configured value is added after the inner service's value.
    #[tokio::test]
    async fn test_append_mode() {
        let svc = SetMultipleResponseHeader::appending(
            service_fn(|_req: Request<Body>| async {
                let res = Response::builder()
                    .header(header::CONTENT_TYPE, "good-content")
                    .body(Body::empty())
                    .unwrap();
                Ok::<_, Infallible>(res)
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into()],
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        let mut values = res.headers().get_all(header::CONTENT_TYPE).iter();
        assert_eq!(values.next().unwrap(), "good-content");
        assert_eq!(values.next().unwrap(), "text/html");
        assert_eq!(values.next(), None);
    }

    // If-not-present mode with the header already set: the existing value is kept.
    #[tokio::test]
    async fn test_skip_if_present_mode() {
        let svc = SetMultipleResponseHeader::if_not_present(
            service_fn(|_req: Request<Body>| async {
                let res = Response::builder()
                    .header(header::CONTENT_TYPE, "good-content")
                    .body(Body::empty())
                    .unwrap();
                Ok::<_, Infallible>(res)
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into()],
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        let mut values = res.headers().get_all(header::CONTENT_TYPE).iter();
        assert_eq!(values.next().unwrap(), "good-content");
        assert_eq!(values.next(), None);
    }

    // If-not-present mode with the header absent: the configured value is inserted.
    #[tokio::test]
    async fn test_skip_if_present_mode_when_not_present() {
        let svc = SetMultipleResponseHeader::if_not_present(
            service_fn(|_req: Request<Body>| async {
                let res = Response::builder().body(Body::empty()).unwrap();
                Ok::<_, Infallible>(res)
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into()],
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        let mut values = res.headers().get_all(header::CONTENT_TYPE).iter();
        assert_eq!(values.next().unwrap(), "text/html");
        assert_eq!(values.next(), None);
    }

    // A `(HeaderName, HeaderValue)` tuple converts into `HeaderMetadata` whose
    // maker always yields the fixed value.
    #[test]
    fn test_tuple_metadata_impl() {
        let tuple: (HeaderName, HeaderValue) =
            (header::CONTENT_TYPE, HeaderValue::from_static("foo"));
        let meta: HeaderMetadata<HeaderValue> = tuple.into();
        assert_eq!(meta.header_name, header::CONTENT_TYPE);
        let mut make = meta.make.clone();
        assert_eq!(
            make.make_header_value(&HeaderValue::from_static("foo")),
            Some(HeaderValue::from_static("foo"))
        );
    }

    // `build_config` keeps the header name and maker, whether the metadata was
    // built as a struct literal or converted from a tuple.
    #[test]
    fn test_convert_to_header_config_struct_and_tuple() {
        let meta: HeaderMetadata<HeaderValue> = HeaderMetadata::<HeaderValue> {
            header_name: header::CONTENT_TYPE,
            make: BoxedMakeHeaderValue::new(HeaderValue::from_static("bar")),
        };
        let rh = meta.build_config(crate::set_header::InsertHeaderMode::Override);
        assert_eq!(rh.header_name, header::CONTENT_TYPE);
        let mut make = rh.make.clone();
        assert_eq!(
            make.make_header_value(&HeaderValue::from_static("bar")),
            Some(HeaderValue::from_static("bar"))
        );

        let tuple: (HeaderName, HeaderValue) =
            (header::CONTENT_TYPE, HeaderValue::from_static("baz"));
        let meta: HeaderMetadata<HeaderValue> = tuple.into();
        let rh2 = meta.build_config(crate::set_header::InsertHeaderMode::Override);
        assert_eq!(rh2.header_name, header::CONTENT_TYPE);
        let mut make2 = rh2.make.clone();
        assert_eq!(
            make2.make_header_value(&HeaderValue::from_static("baz")),
            Some(HeaderValue::from_static("baz"))
        );
    }

    // `Debug` output of the layer, a single insertion config and the service
    // includes each type's name.
    #[test]
    fn test_debug_impls() {
        let meta: HeaderMetadata<HeaderValue> =
            (header::CONTENT_TYPE, HeaderValue::from_static("bar")).into();
        let rh = meta
            .clone()
            .build_config(crate::set_header::InsertHeaderMode::Override);
        let layer = SetMultipleResponseHeadersLayer::overriding(vec![meta]);
        let debug_str = format!("{:?}", layer);
        assert!(debug_str.contains("SetMultipleResponseHeadersLayer"));
        let debug_rh = format!("{:?}", rh);
        assert!(debug_rh.contains("HeaderInsertionConfig"));

        let svc = SetMultipleResponseHeader::overriding(
            tower::service_fn(|_req: Request<Body>| async {
                Ok::<_, std::convert::Infallible>(Response::new(Body::empty()))
            }),
            vec![(header::CONTENT_TYPE, HeaderValue::from_static("foo")).into()]
                as Vec<HeaderMetadata<HeaderValue>>,
        );
        let debug_svc = format!("{:?}", svc);
        assert!(debug_svc.contains("SetMultipleResponseHeader"));
    }

    // Through `ServiceBuilder`, one layer sets several distinct headers at once.
    #[tokio::test]
    async fn test_layer_construction_and_multiple_headers() {
        let svc = tower::ServiceBuilder::new()
            .layer(SetMultipleResponseHeadersLayer::overriding(vec![
                (header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into(),
                (header::CACHE_CONTROL, HeaderValue::from_static("no-cache")).into(),
            ]))
            .service(service_fn(|_req: Request<Body>| async {
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }));

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();
        assert_eq!(res.headers()["content-type"], "text/html");
        assert_eq!(res.headers()["cache-control"], "no-cache");
    }

    // An empty header list leaves the response headers untouched.
    #[tokio::test]
    async fn test_layer_with_empty_vec() {
        let svc = tower::ServiceBuilder::new()
            .layer(SetMultipleResponseHeadersLayer::<Response<Body>>::overriding(vec![]))
            .service(service_fn(|_req: Request<Body>| async {
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }));

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();
        assert_eq!(res.headers().len(), 0);
    }

    // Static values and closures that compute a value from the response can be
    // mixed in the same header list.
    #[tokio::test]
    async fn test_layer_with_static_and_closure_headers_fixed() {
        let static_meta = (header::CONTENT_TYPE, HeaderValue::from_static("text/html")).into();

        let closure_meta = (header::X_FRAME_OPTIONS, |_res: &Response<Body>| {
            Some(HeaderValue::from_static("DENY"))
        })
        .into();

        let svc = tower::ServiceBuilder::new()
            .layer(SetMultipleResponseHeadersLayer::overriding(vec![
                static_meta,
                closure_meta,
            ]))
            .service(service_fn(|_req: Request<Body>| async {
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }));

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();
        assert_eq!(res.headers()["content-type"], "text/html");
        assert_eq!(res.headers()["x-frame-options"], "DENY");
    }

    // NOTE(review): despite its name, this only checks the layer's `Debug`
    // output. The service's `Debug` output is covered by `test_debug_impls`.
    #[test]
    fn test_debug_layer_and_service() {
        let meta: HeaderMetadata<HeaderValue> =
            (header::CONTENT_TYPE, HeaderValue::from_static("foo")).into();
        let layer = SetMultipleResponseHeadersLayer::overriding(vec![meta]);
        let debug_str = format!("{:?}", layer);
        assert!(debug_str.contains("SetMultipleResponseHeadersLayer"));
    }
}
423}