tower_default_headers/
lib.rs

// Crate-wide lint policy: lints in the `clippy::all` group and any `unsafe` code are hard
// errors; public items without documentation produce a warning.
#![deny(clippy::all, unsafe_code)]
#![warn(missing_docs)]
3
//! When building an HTTP service, you may find that many or all of your endpoints are required
//! to return the same set of HTTP headers. This crate provides a convenient way to centralise
//! these common headers in a single middleware.
9//!
//! This middleware adds each default header to an outgoing response only when that response does
//! not already contain a header with the same name, so headers set by your handlers always win.
12//!
//! # Example
14//! ```
15//! use axum::{
16//!     body::Body,
17//!     http::header::{HeaderMap, HeaderValue, X_FRAME_OPTIONS},
18//!     routing::{get, Router},
19//! };
20//! use tower_default_headers::DefaultHeadersLayer;
21//!
22//! # async fn create_and_bind_server() {
23//! let mut default_headers = HeaderMap::new();
24//! default_headers.insert(X_FRAME_OPTIONS, HeaderValue::from_static("deny"));
25//!
26//! let app = Router::new()
27//!     .route("/", get(|| async { "hello, world!" }))
28//!     .layer(DefaultHeadersLayer::new(default_headers));
29//!
30//! let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
31//! axum::serve(listener, app).await.unwrap();
32//! # }
33//! ```
34
use std::{
    future::Future,
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};

use futures_util::ready;
use http::{header::HeaderMap, Request, Response};
use pin_project::pin_project;
use tower_layer::Layer;
use tower_service::Service;
46
47#[doc(hidden)]
48#[pin_project]
49pub struct ResponseFuture<F> {
50    #[pin]
51    default_headers: HeaderMap,
52    #[pin]
53    future: F,
54}
55impl<F, ResponseBody, E> Future for ResponseFuture<F>
56where
57    F: Future<Output = Result<Response<ResponseBody>, E>>,
58{
59    type Output = F::Output;
60
61    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
62        let this = self.project();
63        let mut res = ready!(this.future.poll(cx)?);
64        let headers = res.headers_mut();
65
66        for (name, value) in this.default_headers.iter() {
67            if !headers.contains_key(name) {
68                headers.insert(name, value.clone());
69            }
70        }
71
72        Poll::Ready(Ok(res))
73    }
74}
75
76#[doc(hidden)]
77#[derive(Clone)]
78pub struct DefaultHeaders<S> {
79    default_headers: HeaderMap,
80    inner: S,
81}
82impl<S> DefaultHeaders<S> {}
83impl<RequestBody, ResponseBody, S> Service<Request<RequestBody>> for DefaultHeaders<S>
84where
85    S: Service<Request<RequestBody>, Response = Response<ResponseBody>>,
86{
87    type Error = S::Error;
88    type Future = ResponseFuture<S::Future>;
89    type Response = S::Response;
90
91    fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
92        self.inner.poll_ready(cx)
93    }
94
95    fn call(&mut self, req: Request<RequestBody>) -> Self::Future {
96        ResponseFuture {
97            // TODO: juggle lifetimes and pass this in as a borrow
98            default_headers: self.default_headers.clone(),
99            future: self.inner.call(req),
100        }
101    }
102}
103
104/// middleware to set default HTTP response headers
105#[derive(Clone)]
106pub struct DefaultHeadersLayer {
107    default_headers: HeaderMap,
108}
109impl DefaultHeadersLayer {
110    /// Example
111    /// ```
112    /// use http::header::{HeaderMap, HeaderValue, X_FRAME_OPTIONS};
113    /// use tower_default_headers::DefaultHeadersLayer;
114    ///
115    /// # fn main() {
116    /// let mut default_headers = HeaderMap::new();
117    /// default_headers.insert(X_FRAME_OPTIONS, HeaderValue::from_static("deny"));
118    ///
119    /// let layer = DefaultHeadersLayer::new(default_headers);
120    /// # }
121    /// ```
122    pub fn new(default_headers: HeaderMap) -> Self {
123        Self { default_headers }
124    }
125}
126impl<S> Layer<S> for DefaultHeadersLayer {
127    type Service = DefaultHeaders<S>;
128
129    fn layer(&self, inner: S) -> Self::Service {
130        Self::Service {
131            // TODO: juggle lifetimes and pass this in as a borrow
132            default_headers: self.default_headers.clone(),
133            inner,
134        }
135    }
136}
137
#[cfg(test)]
mod tests {
    use axum::{
        body::Body,
        http::{
            header::{HeaderValue, X_FRAME_OPTIONS},
            Request, StatusCode,
        },
        routing::{get, Router},
    };
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    use super::*;

    /// Default header set shared by every test: `X-Frame-Options: deny`.
    fn frame_deny_defaults() -> HeaderMap {
        let mut defaults = HeaderMap::new();
        defaults.insert(X_FRAME_OPTIONS, HeaderValue::from_static("deny"));
        defaults
    }

    /// Sends an empty `GET /` through `app`, checks it returned 200 OK, and hands back the
    /// response headers together with the fully-read body.
    async fn get_root(app: Router) -> (HeaderMap, Vec<u8>) {
        let request = Request::builder().uri("/").body(Body::empty()).unwrap();
        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let (parts, body) = response.into_parts();
        let collected = body.collect().await.expect("should read body bytes");
        (parts.headers, collected.to_bytes().to_vec())
    }

    #[tokio::test]
    async fn test_headers_when_missing() {
        let app = Router::new()
            .route("/", get(|| async { "hello, world!" }))
            .layer(DefaultHeadersLayer::new(frame_deny_defaults()));

        let (headers, body) = get_root(app).await;

        assert_eq!(headers["x-frame-options"], "deny");
        assert_eq!(&body[..], b"hello, world!");
    }

    #[tokio::test]
    async fn test_headers_when_already_set_by_handler() {
        let handler = || async {
            let mut headers = HeaderMap::new();
            headers.insert("x-frame-options", HeaderValue::from_static("sameorigin"));
            (headers, "hello, world!")
        };
        let app = Router::new()
            .route("/", get(handler))
            .layer(DefaultHeadersLayer::new(frame_deny_defaults()));

        let (headers, body) = get_root(app).await;

        assert_eq!(headers["x-frame-options"], "sameorigin");
        assert_eq!(&body[..], b"hello, world!");
    }
}