Skip to main content

tower_http/set_header/response/
single_header.rs

1//! Set a single header on the response.
2//!
3//! See the root [`crate::set_header::response`] module for full documentation and usage examples.
4//!
5use http::{header::HeaderName, Request, Response};
6use pin_project_lite::pin_project;
7use std::{
8    fmt,
9    future::Future,
10    pin::Pin,
11    task::{ready, Context, Poll},
12};
13use tower_layer::Layer;
14use tower_service::Service;
15
16use crate::set_header::{InsertHeaderMode, MakeHeaderValue};
17
18/// Layer that applies [`SetResponseHeader`] which adds a response header.
19///
20/// See [`SetResponseHeader`] for more details.
21pub struct SetResponseHeaderLayer<M> {
22    header_name: HeaderName,
23    make: M,
24    mode: InsertHeaderMode,
25}
26
27impl<M> fmt::Debug for SetResponseHeaderLayer<M> {
28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29        f.debug_struct("SetResponseHeaderLayer")
30            .field("header_name", &self.header_name)
31            .field("mode", &self.mode)
32            .field("make", &std::any::type_name::<M>())
33            .finish()
34    }
35}
36
37impl<M> SetResponseHeaderLayer<M> {
38    /// Create a new [`SetResponseHeaderLayer`].
39    ///
40    /// If a previous value exists for the same header, it is removed and replaced with the new
41    /// header value.
42    pub fn overriding(header_name: HeaderName, make: M) -> Self {
43        Self::new(header_name, make, InsertHeaderMode::Override)
44    }
45
46    /// Create a new [`SetResponseHeaderLayer`].
47    ///
48    /// The new header is always added, preserving any existing values. If previous values exist,
49    /// the header will have multiple values.
50    pub fn appending(header_name: HeaderName, make: M) -> Self {
51        Self::new(header_name, make, InsertHeaderMode::Append)
52    }
53
54    /// Create a new [`SetResponseHeaderLayer`].
55    ///
56    /// If a previous value exists for the header, the new value is not inserted.
57    pub fn if_not_present(header_name: HeaderName, make: M) -> Self {
58        Self::new(header_name, make, InsertHeaderMode::IfNotPresent)
59    }
60
61    fn new(header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
62        Self {
63            make,
64            header_name,
65            mode,
66        }
67    }
68}
69
70impl<S, M> Layer<S> for SetResponseHeaderLayer<M>
71where
72    M: Clone,
73{
74    type Service = SetResponseHeader<S, M>;
75
76    fn layer(&self, inner: S) -> Self::Service {
77        SetResponseHeader {
78            inner,
79            header_name: self.header_name.clone(),
80            make: self.make.clone(),
81            mode: self.mode,
82        }
83    }
84}
85
86impl<M> Clone for SetResponseHeaderLayer<M>
87where
88    M: Clone,
89{
90    fn clone(&self) -> Self {
91        Self {
92            make: self.make.clone(),
93            header_name: self.header_name.clone(),
94            mode: self.mode,
95        }
96    }
97}
98
99/// Middleware that sets a header on the response.
100#[derive(Clone)]
101pub struct SetResponseHeader<S, M> {
102    inner: S,
103    header_name: HeaderName,
104    make: M,
105    mode: InsertHeaderMode,
106}
107
108impl<S, M> SetResponseHeader<S, M> {
109    /// Create a new [`SetResponseHeader`].
110    ///
111    /// If a previous value exists for the same header, it is removed and replaced with the new
112    /// header value.
113    pub fn overriding(inner: S, header_name: HeaderName, make: M) -> Self {
114        Self::new(inner, header_name, make, InsertHeaderMode::Override)
115    }
116
117    /// Create a new [`SetResponseHeader`].
118    ///
119    /// The new header is always added, preserving any existing values. If previous values exist,
120    /// the header will have multiple values.
121    pub fn appending(inner: S, header_name: HeaderName, make: M) -> Self {
122        Self::new(inner, header_name, make, InsertHeaderMode::Append)
123    }
124
125    /// Create a new [`SetResponseHeader`].
126    ///
127    /// If a previous value exists for the header, the new value is not inserted.
128    pub fn if_not_present(inner: S, header_name: HeaderName, make: M) -> Self {
129        Self::new(inner, header_name, make, InsertHeaderMode::IfNotPresent)
130    }
131
132    fn new(inner: S, header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
133        Self {
134            inner,
135            header_name,
136            make,
137            mode,
138        }
139    }
140
141    define_inner_service_accessors!();
142}
143
144impl<S, M> fmt::Debug for SetResponseHeader<S, M>
145where
146    S: fmt::Debug,
147{
148    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
149        f.debug_struct("SetResponseHeader")
150            .field("inner", &self.inner)
151            .field("header_name", &self.header_name)
152            .field("mode", &self.mode)
153            .field("make", &std::any::type_name::<M>())
154            .finish()
155    }
156}
157
158impl<ReqBody, ResBody, S, M> Service<Request<ReqBody>> for SetResponseHeader<S, M>
159where
160    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
161    M: MakeHeaderValue<Response<ResBody>> + Clone,
162{
163    type Response = S::Response;
164    type Error = S::Error;
165    type Future = ResponseFuture<S::Future, M>;
166
167    #[inline]
168    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
169        self.inner.poll_ready(cx)
170    }
171
172    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
173        ResponseFuture {
174            future: self.inner.call(req),
175            header_name: self.header_name.clone(),
176            make: self.make.clone(),
177            mode: self.mode,
178        }
179    }
180}
181
182pin_project! {
183    /// Response future for [`SetResponseHeader`].
184    #[derive(Debug)]
185    pub struct ResponseFuture<F, M> {
186        #[pin]
187        future: F,
188        header_name: HeaderName,
189        make: M,
190        mode: InsertHeaderMode,
191    }
192}
193
194impl<F, ResBody, E, M> Future for ResponseFuture<F, M>
195where
196    F: Future<Output = Result<Response<ResBody>, E>>,
197    M: MakeHeaderValue<Response<ResBody>>,
198{
199    type Output = F::Output;
200
201    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
202        let this = self.project();
203        let mut res = ready!(this.future.poll(cx)?);
204
205        this.mode.apply(this.header_name, &mut res, &mut *this.make);
206
207        Poll::Ready(Ok(res))
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::Body;
    use http::{header, HeaderValue};
    use std::convert::Infallible;
    use tower::{service_fn, ServiceExt};

    /// Inner service that answers every request with an empty body and one
    /// `content-type` header per entry in `existing` (no header when it is empty).
    fn inner_service(
        existing: &'static [&'static str],
    ) -> impl Service<Request<Body>, Response = Response<Body>, Error = Infallible> {
        service_fn(move |_req: Request<Body>| async move {
            let mut builder = Response::builder();
            for value in existing {
                builder = builder.header(header::CONTENT_TYPE, *value);
            }
            Ok::<_, Infallible>(builder.body(Body::empty()).unwrap())
        })
    }

    /// Sends one request through `svc` and returns every `content-type` value on the
    /// response, in order.
    async fn content_types<S>(svc: S) -> Vec<String>
    where
        S: Service<Request<Body>, Response = Response<Body>, Error = Infallible>,
    {
        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();
        res.headers()
            .get_all(header::CONTENT_TYPE)
            .iter()
            .map(|value| value.to_str().unwrap().to_owned())
            .collect()
    }

    #[tokio::test]
    async fn test_override_mode() {
        let svc = SetResponseHeader::overriding(
            inner_service(&["good-content"]),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        assert_eq!(content_types(svc).await, ["text/html"]);
    }

    #[tokio::test]
    async fn test_append_mode() {
        let svc = SetResponseHeader::appending(
            inner_service(&["good-content"]),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        assert_eq!(content_types(svc).await, ["good-content", "text/html"]);
    }

    #[tokio::test]
    async fn test_skip_if_present_mode() {
        let svc = SetResponseHeader::if_not_present(
            inner_service(&["good-content"]),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        assert_eq!(content_types(svc).await, ["good-content"]);
    }

    #[tokio::test]
    async fn test_skip_if_present_mode_when_not_present() {
        let svc = SetResponseHeader::if_not_present(
            inner_service(&[]),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        assert_eq!(content_types(svc).await, ["text/html"]);
    }

    // The layer constructors and the `Layer` impl were previously untested; these cover
    // the path most users take (`ServiceBuilder::layer` / `Layer::layer`).
    #[tokio::test]
    async fn test_layer_append_mode() {
        let layer = SetResponseHeaderLayer::appending(
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );
        let svc = layer.layer(inner_service(&["good-content"]));

        assert_eq!(content_types(svc).await, ["good-content", "text/html"]);
    }

    #[tokio::test]
    async fn test_layer_if_not_present_mode() {
        let layer = SetResponseHeaderLayer::if_not_present(
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        // The same layer can wrap several services.
        assert_eq!(
            content_types(layer.layer(inner_service(&["good-content"]))).await,
            ["good-content"]
        );
        assert_eq!(
            content_types(layer.layer(inner_service(&[]))).await,
            ["text/html"]
        );
    }
}