tower_async_http/set_header/
response.rs1use super::{InsertHeaderMode, MakeHeaderValue};
98use http::{header::HeaderName, Request, Response};
99use std::fmt;
100use tower_async_layer::Layer;
101use tower_async_service::Service;
102
103pub struct SetResponseHeaderLayer<M> {
107 header_name: HeaderName,
108 make: M,
109 mode: InsertHeaderMode,
110}
111
112impl<M> fmt::Debug for SetResponseHeaderLayer<M> {
113 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114 f.debug_struct("SetResponseHeaderLayer")
115 .field("header_name", &self.header_name)
116 .field("mode", &self.mode)
117 .field("make", &std::any::type_name::<M>())
118 .finish()
119 }
120}
121
122impl<M> SetResponseHeaderLayer<M> {
123 pub fn overriding(header_name: HeaderName, make: M) -> Self {
128 Self::new(header_name, make, InsertHeaderMode::Override)
129 }
130
131 pub fn appending(header_name: HeaderName, make: M) -> Self {
136 Self::new(header_name, make, InsertHeaderMode::Append)
137 }
138
139 pub fn if_not_present(header_name: HeaderName, make: M) -> Self {
143 Self::new(header_name, make, InsertHeaderMode::IfNotPresent)
144 }
145
146 fn new(header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
147 Self {
148 make,
149 header_name,
150 mode,
151 }
152 }
153}
154
155impl<S, M> Layer<S> for SetResponseHeaderLayer<M>
156where
157 M: Clone,
158{
159 type Service = SetResponseHeader<S, M>;
160
161 fn layer(&self, inner: S) -> Self::Service {
162 SetResponseHeader {
163 inner,
164 header_name: self.header_name.clone(),
165 make: self.make.clone(),
166 mode: self.mode,
167 }
168 }
169}
170
171impl<M> Clone for SetResponseHeaderLayer<M>
172where
173 M: Clone,
174{
175 fn clone(&self) -> Self {
176 Self {
177 make: self.make.clone(),
178 header_name: self.header_name.clone(),
179 mode: self.mode,
180 }
181 }
182}
183
184#[derive(Clone)]
186pub struct SetResponseHeader<S, M> {
187 inner: S,
188 header_name: HeaderName,
189 make: M,
190 mode: InsertHeaderMode,
191}
192
193impl<S, M> SetResponseHeader<S, M> {
194 pub fn overriding(inner: S, header_name: HeaderName, make: M) -> Self {
199 Self::new(inner, header_name, make, InsertHeaderMode::Override)
200 }
201
202 pub fn appending(inner: S, header_name: HeaderName, make: M) -> Self {
207 Self::new(inner, header_name, make, InsertHeaderMode::Append)
208 }
209
210 pub fn if_not_present(inner: S, header_name: HeaderName, make: M) -> Self {
214 Self::new(inner, header_name, make, InsertHeaderMode::IfNotPresent)
215 }
216
217 fn new(inner: S, header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
218 Self {
219 inner,
220 header_name,
221 make,
222 mode,
223 }
224 }
225
226 define_inner_service_accessors!();
227}
228
229impl<S, M> fmt::Debug for SetResponseHeader<S, M>
230where
231 S: fmt::Debug,
232{
233 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
234 f.debug_struct("SetResponseHeader")
235 .field("inner", &self.inner)
236 .field("header_name", &self.header_name)
237 .field("mode", &self.mode)
238 .field("make", &std::any::type_name::<M>())
239 .finish()
240 }
241}
242
243impl<ReqBody, ResBody, S, M> Service<Request<ReqBody>> for SetResponseHeader<S, M>
244where
245 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
246 M: MakeHeaderValue<Response<ResBody>>,
247{
248 type Response = S::Response;
249 type Error = S::Error;
250
251 async fn call(&self, req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
252 let mut res = self.inner.call(req).await?;
253 self.mode.apply(&self.header_name, &mut res, &self.make);
254 Ok(res)
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    use crate::test_helpers::Body;

    use http::{header, HeaderValue};
    use std::convert::Infallible;
    use tower_async::{service_fn, ServiceExt};

    /// Inner handler whose response already carries `content-type: good-content`.
    async fn good_content_handler(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let res = Response::builder()
            .header(header::CONTENT_TYPE, "good-content")
            .body(Body::empty())
            .unwrap();
        Ok(res)
    }

    /// Inner handler whose response carries no headers at all.
    async fn empty_handler(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
        Ok(Response::builder().body(Body::empty()).unwrap())
    }

    /// Every `content-type` value on `res`, in the order the header map yields them.
    fn content_types(res: &Response<Body>) -> Vec<&HeaderValue> {
        res.headers().get_all(header::CONTENT_TYPE).iter().collect()
    }

    #[tokio::test]
    async fn test_override_mode() {
        let svc = SetResponseHeader::overriding(
            service_fn(good_content_handler),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        assert_eq!(content_types(&res), ["text/html"]);
    }

    #[tokio::test]
    async fn test_append_mode() {
        let svc = SetResponseHeader::appending(
            service_fn(good_content_handler),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        assert_eq!(content_types(&res), ["good-content", "text/html"]);
    }

    #[tokio::test]
    async fn test_skip_if_present_mode() {
        let svc = SetResponseHeader::if_not_present(
            service_fn(good_content_handler),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        assert_eq!(content_types(&res), ["good-content"]);
    }

    #[tokio::test]
    async fn test_skip_if_present_mode_when_not_present() {
        let svc = SetResponseHeader::if_not_present(
            service_fn(empty_handler),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        assert_eq!(content_types(&res), ["text/html"]);
    }
}