tower_default_headers/
lib.rs1#![deny(clippy::all, unsafe_code)]
2#![warn(missing_docs)]
3
4use std::{
36 future::Future,
37 pin::Pin,
38 task::{Context, Poll},
39};
40
41use futures_util::ready;
42use http::{header::HeaderMap, Request, Response};
43use pin_project::pin_project;
44use tower_layer::Layer;
45use tower_service::Service;
46
47#[doc(hidden)]
48#[pin_project]
49pub struct ResponseFuture<F> {
50 #[pin]
51 default_headers: HeaderMap,
52 #[pin]
53 future: F,
54}
55impl<F, ResponseBody, E> Future for ResponseFuture<F>
56where
57 F: Future<Output = Result<Response<ResponseBody>, E>>,
58{
59 type Output = F::Output;
60
61 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
62 let this = self.project();
63 let mut res = ready!(this.future.poll(cx)?);
64 let headers = res.headers_mut();
65
66 for (name, value) in this.default_headers.iter() {
67 if !headers.contains_key(name) {
68 headers.insert(name, value.clone());
69 }
70 }
71
72 Poll::Ready(Ok(res))
73 }
74}
75
76#[doc(hidden)]
77#[derive(Clone)]
78pub struct DefaultHeaders<S> {
79 default_headers: HeaderMap,
80 inner: S,
81}
82impl<S> DefaultHeaders<S> {}
83impl<RequestBody, ResponseBody, S> Service<Request<RequestBody>> for DefaultHeaders<S>
84where
85 S: Service<Request<RequestBody>, Response = Response<ResponseBody>>,
86{
87 type Error = S::Error;
88 type Future = ResponseFuture<S::Future>;
89 type Response = S::Response;
90
91 fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
92 self.inner.poll_ready(cx)
93 }
94
95 fn call(&mut self, req: Request<RequestBody>) -> Self::Future {
96 ResponseFuture {
97 default_headers: self.default_headers.clone(),
99 future: self.inner.call(req),
100 }
101 }
102}
103
104#[derive(Clone)]
106pub struct DefaultHeadersLayer {
107 default_headers: HeaderMap,
108}
109impl DefaultHeadersLayer {
110 pub fn new(default_headers: HeaderMap) -> Self {
123 Self { default_headers }
124 }
125}
126impl<S> Layer<S> for DefaultHeadersLayer {
127 type Service = DefaultHeaders<S>;
128
129 fn layer(&self, inner: S) -> Self::Service {
130 Self::Service {
131 default_headers: self.default_headers.clone(),
133 inner,
134 }
135 }
136}
137
#[cfg(test)]
mod tests {
    use axum::{
        body::Body,
        http::{
            header::{HeaderValue, X_FRAME_OPTIONS},
            Request, StatusCode,
        },
        routing::{get, Router},
    };
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    use super::*;

    /// Default header set used by both tests: `x-frame-options: deny`.
    fn deny_framing_defaults() -> HeaderMap {
        let mut defaults = HeaderMap::new();
        defaults.insert(X_FRAME_OPTIONS, HeaderValue::from_static("deny"));
        defaults
    }

    /// Sends `GET /` through `app` and returns the status, headers and fully-read body.
    async fn fetch_root(app: Router) -> (StatusCode, HeaderMap, Vec<u8>) {
        let request = Request::builder().uri("/").body(Body::empty()).unwrap();
        let response = app.oneshot(request).await.unwrap();
        let (parts, body) = response.into_parts();
        let bytes = body
            .collect()
            .await
            .expect("should read body bytes")
            .to_bytes();
        (parts.status, parts.headers, bytes.to_vec())
    }

    #[tokio::test]
    async fn test_headers_when_missing() {
        let app = Router::new()
            .route("/", get(|| async { "hello, world!" }))
            .layer(DefaultHeadersLayer::new(deny_framing_defaults()));

        let (status, headers, body) = fetch_root(app).await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(headers["x-frame-options"], "deny");
        assert_eq!(&body[..], b"hello, world!");
    }

    #[tokio::test]
    async fn test_headers_when_already_set_by_handler() {
        let handler = || async {
            let mut own_headers = HeaderMap::new();
            own_headers.insert("x-frame-options", HeaderValue::from_static("sameorigin"));
            (own_headers, "hello, world!")
        };
        let app = Router::new()
            .route("/", get(handler))
            .layer(DefaultHeadersLayer::new(deny_framing_defaults()));

        let (status, headers, body) = fetch_root(app).await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(headers["x-frame-options"], "sameorigin");
        assert_eq!(&body[..], b"hello, world!");
    }
}