tower_http/set_header/response/
single_header.rs1use http::{header::HeaderName, Request, Response};
6use pin_project_lite::pin_project;
7use std::{
8 fmt,
9 future::Future,
10 pin::Pin,
11 task::{ready, Context, Poll},
12};
13use tower_layer::Layer;
14use tower_service::Service;
15
16use crate::set_header::{InsertHeaderMode, MakeHeaderValue};
17
18pub struct SetResponseHeaderLayer<M> {
22 header_name: HeaderName,
23 make: M,
24 mode: InsertHeaderMode,
25}
26
27impl<M> fmt::Debug for SetResponseHeaderLayer<M> {
28 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29 f.debug_struct("SetResponseHeaderLayer")
30 .field("header_name", &self.header_name)
31 .field("mode", &self.mode)
32 .field("make", &std::any::type_name::<M>())
33 .finish()
34 }
35}
36
37impl<M> SetResponseHeaderLayer<M> {
38 pub fn overriding(header_name: HeaderName, make: M) -> Self {
43 Self::new(header_name, make, InsertHeaderMode::Override)
44 }
45
46 pub fn appending(header_name: HeaderName, make: M) -> Self {
51 Self::new(header_name, make, InsertHeaderMode::Append)
52 }
53
54 pub fn if_not_present(header_name: HeaderName, make: M) -> Self {
58 Self::new(header_name, make, InsertHeaderMode::IfNotPresent)
59 }
60
61 fn new(header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
62 Self {
63 make,
64 header_name,
65 mode,
66 }
67 }
68}
69
70impl<S, M> Layer<S> for SetResponseHeaderLayer<M>
71where
72 M: Clone,
73{
74 type Service = SetResponseHeader<S, M>;
75
76 fn layer(&self, inner: S) -> Self::Service {
77 SetResponseHeader {
78 inner,
79 header_name: self.header_name.clone(),
80 make: self.make.clone(),
81 mode: self.mode,
82 }
83 }
84}
85
86impl<M> Clone for SetResponseHeaderLayer<M>
87where
88 M: Clone,
89{
90 fn clone(&self) -> Self {
91 Self {
92 make: self.make.clone(),
93 header_name: self.header_name.clone(),
94 mode: self.mode,
95 }
96 }
97}
98
99#[derive(Clone)]
101pub struct SetResponseHeader<S, M> {
102 inner: S,
103 header_name: HeaderName,
104 make: M,
105 mode: InsertHeaderMode,
106}
107
108impl<S, M> SetResponseHeader<S, M> {
109 pub fn overriding(inner: S, header_name: HeaderName, make: M) -> Self {
114 Self::new(inner, header_name, make, InsertHeaderMode::Override)
115 }
116
117 pub fn appending(inner: S, header_name: HeaderName, make: M) -> Self {
122 Self::new(inner, header_name, make, InsertHeaderMode::Append)
123 }
124
125 pub fn if_not_present(inner: S, header_name: HeaderName, make: M) -> Self {
129 Self::new(inner, header_name, make, InsertHeaderMode::IfNotPresent)
130 }
131
132 fn new(inner: S, header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
133 Self {
134 inner,
135 header_name,
136 make,
137 mode,
138 }
139 }
140
141 define_inner_service_accessors!();
142}
143
144impl<S, M> fmt::Debug for SetResponseHeader<S, M>
145where
146 S: fmt::Debug,
147{
148 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
149 f.debug_struct("SetResponseHeader")
150 .field("inner", &self.inner)
151 .field("header_name", &self.header_name)
152 .field("mode", &self.mode)
153 .field("make", &std::any::type_name::<M>())
154 .finish()
155 }
156}
157
158impl<ReqBody, ResBody, S, M> Service<Request<ReqBody>> for SetResponseHeader<S, M>
159where
160 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
161 M: MakeHeaderValue<Response<ResBody>> + Clone,
162{
163 type Response = S::Response;
164 type Error = S::Error;
165 type Future = ResponseFuture<S::Future, M>;
166
167 #[inline]
168 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
169 self.inner.poll_ready(cx)
170 }
171
172 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
173 ResponseFuture {
174 future: self.inner.call(req),
175 header_name: self.header_name.clone(),
176 make: self.make.clone(),
177 mode: self.mode,
178 }
179 }
180}
181
182pin_project! {
183 #[derive(Debug)]
185 pub struct ResponseFuture<F, M> {
186 #[pin]
187 future: F,
188 header_name: HeaderName,
189 make: M,
190 mode: InsertHeaderMode,
191 }
192}
193
194impl<F, ResBody, E, M> Future for ResponseFuture<F, M>
195where
196 F: Future<Output = Result<Response<ResBody>, E>>,
197 M: MakeHeaderValue<Response<ResBody>>,
198{
199 type Output = F::Output;
200
201 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
202 let this = self.project();
203 let mut res = ready!(this.future.poll(cx)?);
204
205 this.mode.apply(this.header_name, &mut res, &mut *this.make);
206
207 Poll::Ready(Ok(res))
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::Body;
    use http::{header, HeaderValue};
    use std::convert::Infallible;
    use tower::{service_fn, ServiceExt};

    /// Builds the response returned by the wrapped service: an empty body,
    /// optionally carrying a pre-existing `content-type` value.
    fn upstream_response(existing_content_type: Option<&'static str>) -> Response<Body> {
        let mut builder = Response::builder();
        if let Some(value) = existing_content_type {
            builder = builder.header(header::CONTENT_TYPE, value);
        }
        builder.body(Body::empty()).unwrap()
    }

    /// Collects every `content-type` value on `res`, in header order.
    fn content_types<B>(res: &Response<B>) -> Vec<&HeaderValue> {
        res.headers().get_all(header::CONTENT_TYPE).iter().collect()
    }

    #[tokio::test]
    async fn test_override_mode() {
        let svc = SetResponseHeader::overriding(
            service_fn(|_req: Request<Body>| async {
                Ok::<_, Infallible>(upstream_response(Some("good-content")))
            }),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        assert_eq!(content_types(&res), ["text/html"]);
    }

    #[tokio::test]
    async fn test_append_mode() {
        let svc = SetResponseHeader::appending(
            service_fn(|_req: Request<Body>| async {
                Ok::<_, Infallible>(upstream_response(Some("good-content")))
            }),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        assert_eq!(content_types(&res), ["good-content", "text/html"]);
    }

    #[tokio::test]
    async fn test_skip_if_present_mode() {
        let svc = SetResponseHeader::if_not_present(
            service_fn(|_req: Request<Body>| async {
                Ok::<_, Infallible>(upstream_response(Some("good-content")))
            }),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        assert_eq!(content_types(&res), ["good-content"]);
    }

    #[tokio::test]
    async fn test_skip_if_present_mode_when_not_present() {
        let svc = SetResponseHeader::if_not_present(
            service_fn(|_req: Request<Body>| async {
                Ok::<_, Infallible>(upstream_response(None))
            }),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        assert_eq!(content_types(&res), ["text/html"]);
    }
}