tower_async_http/
sensitive_headers.rs

1//! Middlewares that mark headers as [sensitive].
2//!
3//! [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive
4//!
5//! # Example
6//!
7//! ```
8//! use tower_async_http::sensitive_headers::SetSensitiveHeadersLayer;
9//! use tower_async::{Service, ServiceExt, ServiceBuilder, service_fn};
10//! use http::{Request, Response, header::AUTHORIZATION};
11//! use http_body_util::Full;
12//! use bytes::Bytes;
13//! use std::{iter::once, convert::Infallible};
14//!
15//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
16//!     // ...
17//!     # Ok(Response::new(Full::default()))
18//! }
19//!
20//! # #[tokio::main]
21//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
22//! let mut service = ServiceBuilder::new()
23//!     // Mark the `Authorization` header as sensitive so it doesn't show in logs
24//!     //
25//!     // `SetSensitiveHeadersLayer` will mark the header as sensitive on both the
26//!     // request and response.
27//!     //
28//!     // The middleware is constructed from an iterator of headers to easily mark
29//!     // multiple headers at once.
30//!     .layer(SetSensitiveHeadersLayer::new(once(AUTHORIZATION)))
31//!     .service(service_fn(handle));
32//!
33//! // Call the service.
34//! let response = service
35//!     .call(Request::new(Full::default()))
36//!     .await?;
37//! # Ok(())
38//! # }
39//! ```
40
41use http::{header::HeaderName, Request, Response};
42use std::sync::Arc;
43use tower_async_layer::Layer;
44use tower_async_service::Service;
45
46/// Mark headers as [sensitive] on both requests and responses.
47///
48/// Produces [`SetSensitiveHeaders`] services.
49///
50/// See the [module docs](crate::sensitive_headers) for more details.
51///
52/// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive
53#[derive(Clone, Debug)]
54pub struct SetSensitiveHeadersLayer {
55    headers: Arc<[HeaderName]>,
56}
57
58impl SetSensitiveHeadersLayer {
59    /// Create a new [`SetSensitiveHeadersLayer`].
60    pub fn new<I>(headers: I) -> Self
61    where
62        I: IntoIterator<Item = HeaderName>,
63    {
64        let headers = headers.into_iter().collect::<Vec<_>>();
65        Self::from_shared(headers.into())
66    }
67
68    /// Create a new [`SetSensitiveHeadersLayer`] from a shared slice of headers.
69    pub fn from_shared(headers: Arc<[HeaderName]>) -> Self {
70        Self { headers }
71    }
72}
73
74impl<S> Layer<S> for SetSensitiveHeadersLayer {
75    type Service = SetSensitiveHeaders<S>;
76
77    fn layer(&self, inner: S) -> Self::Service {
78        SetSensitiveRequestHeaders::from_shared(
79            SetSensitiveResponseHeaders::from_shared(inner, self.headers.clone()),
80            self.headers.clone(),
81        )
82    }
83}
84
85/// Mark headers as [sensitive] on both requests and responses.
86///
87/// See the [module docs](crate::sensitive_headers) for more details.
88///
89/// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive
90pub type SetSensitiveHeaders<S> = SetSensitiveRequestHeaders<SetSensitiveResponseHeaders<S>>;
91
92/// Mark request headers as [sensitive].
93///
94/// Produces [`SetSensitiveRequestHeaders`] services.
95///
96/// See the [module docs](crate::sensitive_headers) for more details.
97///
98/// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive
99#[derive(Clone, Debug)]
100pub struct SetSensitiveRequestHeadersLayer {
101    headers: Arc<[HeaderName]>,
102}
103
104impl SetSensitiveRequestHeadersLayer {
105    /// Create a new [`SetSensitiveRequestHeadersLayer`].
106    pub fn new<I>(headers: I) -> Self
107    where
108        I: IntoIterator<Item = HeaderName>,
109    {
110        let headers = headers.into_iter().collect::<Vec<_>>();
111        Self::from_shared(headers.into())
112    }
113
114    /// Create a new [`SetSensitiveRequestHeadersLayer`] from a shared slice of headers.
115    pub fn from_shared(headers: Arc<[HeaderName]>) -> Self {
116        Self { headers }
117    }
118}
119
120impl<S> Layer<S> for SetSensitiveRequestHeadersLayer {
121    type Service = SetSensitiveRequestHeaders<S>;
122
123    fn layer(&self, inner: S) -> Self::Service {
124        SetSensitiveRequestHeaders {
125            inner,
126            headers: self.headers.clone(),
127        }
128    }
129}
130
131/// Mark request headers as [sensitive].
132///
133/// See the [module docs](crate::sensitive_headers) for more details.
134///
135/// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive
136#[derive(Clone, Debug)]
137pub struct SetSensitiveRequestHeaders<S> {
138    inner: S,
139    headers: Arc<[HeaderName]>,
140}
141
142impl<S> SetSensitiveRequestHeaders<S> {
143    /// Create a new [`SetSensitiveRequestHeaders`].
144    pub fn new<I>(inner: S, headers: I) -> Self
145    where
146        I: IntoIterator<Item = HeaderName>,
147    {
148        let headers = headers.into_iter().collect::<Vec<_>>();
149        Self::from_shared(inner, headers.into())
150    }
151
152    /// Create a new [`SetSensitiveRequestHeaders`] from a shared slice of headers.
153    pub fn from_shared(inner: S, headers: Arc<[HeaderName]>) -> Self {
154        Self { inner, headers }
155    }
156
157    define_inner_service_accessors!();
158
159    /// Returns a new [`Layer`] that wraps services with a `SetSensitiveRequestHeaders` middleware.
160    ///
161    /// [`Layer`]: tower_async_layer::Layer
162    pub fn layer<I>(headers: I) -> SetSensitiveRequestHeadersLayer
163    where
164        I: IntoIterator<Item = HeaderName>,
165    {
166        SetSensitiveRequestHeadersLayer::new(headers)
167    }
168}
169
170impl<ReqBody, ResBody, S> Service<Request<ReqBody>> for SetSensitiveRequestHeaders<S>
171where
172    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
173{
174    type Response = S::Response;
175    type Error = S::Error;
176
177    async fn call(&self, mut req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
178        let headers = req.headers_mut();
179        for header in &*self.headers {
180            if let http::header::Entry::Occupied(mut entry) = headers.entry(header) {
181                for value in entry.iter_mut() {
182                    value.set_sensitive(true);
183                }
184            }
185        }
186
187        self.inner.call(req).await
188    }
189}
190
191/// Mark response headers as [sensitive].
192///
193/// Produces [`SetSensitiveResponseHeaders`] services.
194///
195/// See the [module docs](crate::sensitive_headers) for more details.
196///
197/// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive
198#[derive(Clone, Debug)]
199pub struct SetSensitiveResponseHeadersLayer {
200    headers: Arc<[HeaderName]>,
201}
202
203impl SetSensitiveResponseHeadersLayer {
204    /// Create a new [`SetSensitiveResponseHeadersLayer`].
205    pub fn new<I>(headers: I) -> Self
206    where
207        I: IntoIterator<Item = HeaderName>,
208    {
209        let headers = headers.into_iter().collect::<Vec<_>>();
210        Self::from_shared(headers.into())
211    }
212
213    /// Create a new [`SetSensitiveResponseHeadersLayer`] from a shared slice of headers.
214    pub fn from_shared(headers: Arc<[HeaderName]>) -> Self {
215        Self { headers }
216    }
217}
218
219impl<S> Layer<S> for SetSensitiveResponseHeadersLayer {
220    type Service = SetSensitiveResponseHeaders<S>;
221
222    fn layer(&self, inner: S) -> Self::Service {
223        SetSensitiveResponseHeaders {
224            inner,
225            headers: self.headers.clone(),
226        }
227    }
228}
229
230/// Mark response headers as [sensitive].
231///
232/// See the [module docs](crate::sensitive_headers) for more details.
233///
234/// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive
235#[derive(Clone, Debug)]
236pub struct SetSensitiveResponseHeaders<S> {
237    inner: S,
238    headers: Arc<[HeaderName]>,
239}
240
241impl<S> SetSensitiveResponseHeaders<S> {
242    /// Create a new [`SetSensitiveResponseHeaders`].
243    pub fn new<I>(inner: S, headers: I) -> Self
244    where
245        I: IntoIterator<Item = HeaderName>,
246    {
247        let headers = headers.into_iter().collect::<Vec<_>>();
248        Self::from_shared(inner, headers.into())
249    }
250
251    /// Create a new [`SetSensitiveResponseHeaders`] from a shared slice of headers.
252    pub fn from_shared(inner: S, headers: Arc<[HeaderName]>) -> Self {
253        Self { inner, headers }
254    }
255
256    define_inner_service_accessors!();
257
258    /// Returns a new [`Layer`] that wraps services with a `SetSensitiveResponseHeaders` middleware.
259    ///
260    /// [`Layer`]: tower_async_layer::Layer
261    pub fn layer<I>(headers: I) -> SetSensitiveResponseHeadersLayer
262    where
263        I: IntoIterator<Item = HeaderName>,
264    {
265        SetSensitiveResponseHeadersLayer::new(headers)
266    }
267}
268
269impl<ReqBody, ResBody, S> Service<Request<ReqBody>> for SetSensitiveResponseHeaders<S>
270where
271    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
272{
273    type Response = S::Response;
274    type Error = S::Error;
275
276    async fn call(&self, req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
277        let mut res = self.inner.call(req).await?;
278
279        let headers = res.headers_mut();
280        for header in self.headers.iter() {
281            if let http::header::Entry::Occupied(mut entry) = headers.entry(header) {
282                for value in entry.iter_mut() {
283                    value.set_sensitive(true);
284                }
285            }
286        }
287
288        Ok(res)
289    }
290}
291
#[cfg(test)]
mod tests {
    // The glob is used by both tests (layer types and the `Service` trait), so no lint
    // suppression is needed.
    use super::*;
    use http::header;
    use std::iter::once;
    use tower_async::ServiceBuilder;

    /// `Cookie` (request) and `Set-Cookie` (response) can occur several times; every value
    /// must be flagged, while unrelated headers stay untouched.
    #[tokio::test]
    async fn multiple_value_header() {
        async fn response_set_cookie(req: http::Request<()>) -> Result<http::Response<()>, ()> {
            let mut iter = req.headers().get_all(header::COOKIE).iter().peekable();

            assert!(iter.peek().is_some());

            for value in iter {
                assert!(value.is_sensitive())
            }

            let mut resp = http::Response::new(());
            resp.headers_mut().append(
                header::CONTENT_TYPE,
                http::HeaderValue::from_static("text/html"),
            );
            resp.headers_mut().append(
                header::SET_COOKIE,
                http::HeaderValue::from_static("cookie-1"),
            );
            resp.headers_mut().append(
                header::SET_COOKIE,
                http::HeaderValue::from_static("cookie-2"),
            );
            resp.headers_mut().append(
                header::SET_COOKIE,
                http::HeaderValue::from_static("cookie-3"),
            );
            Ok(resp)
        }

        let service = ServiceBuilder::new()
            .layer(SetSensitiveRequestHeadersLayer::new(vec![header::COOKIE]))
            .layer(SetSensitiveResponseHeadersLayer::new(vec![
                header::SET_COOKIE,
            ]))
            .service_fn(response_set_cookie);

        let mut req = http::Request::new(());
        req.headers_mut()
            .append(header::COOKIE, http::HeaderValue::from_static("cookie+1"));
        req.headers_mut()
            .append(header::COOKIE, http::HeaderValue::from_static("cookie+2"));

        let resp = service.call(req).await.unwrap();

        assert!(!resp
            .headers()
            .get(header::CONTENT_TYPE)
            .unwrap()
            .is_sensitive());

        let mut iter = resp.headers().get_all(header::SET_COOKIE).iter().peekable();

        assert!(iter.peek().is_some());

        for value in iter {
            assert!(value.is_sensitive())
        }
    }

    /// The combined `SetSensitiveHeadersLayer` must flag the configured header both on the
    /// request seen by the inner service and on the response returned to the caller.
    #[tokio::test]
    async fn combined_layer_marks_request_and_response() {
        async fn handle(req: http::Request<()>) -> Result<http::Response<()>, ()> {
            assert!(req
                .headers()
                .get(header::AUTHORIZATION)
                .unwrap()
                .is_sensitive());
            assert!(!req.headers().get(header::ACCEPT).unwrap().is_sensitive());

            // Fresh, non-sensitive values: only the middleware may flag them.
            let mut resp = http::Response::new(());
            resp.headers_mut().insert(
                header::AUTHORIZATION,
                http::HeaderValue::from_static("Bearer response"),
            );
            resp.headers_mut().insert(
                header::CONTENT_TYPE,
                http::HeaderValue::from_static("text/plain"),
            );
            Ok(resp)
        }

        let service = ServiceBuilder::new()
            .layer(SetSensitiveHeadersLayer::new(once(header::AUTHORIZATION)))
            .service_fn(handle);

        let mut req = http::Request::new(());
        req.headers_mut().insert(
            header::AUTHORIZATION,
            http::HeaderValue::from_static("Bearer request"),
        );
        req.headers_mut()
            .insert(header::ACCEPT, http::HeaderValue::from_static("*/*"));

        let resp = service.call(req).await.unwrap();

        assert!(resp
            .headers()
            .get(header::AUTHORIZATION)
            .unwrap()
            .is_sensitive());
        assert!(!resp
            .headers()
            .get(header::CONTENT_TYPE)
            .unwrap()
            .is_sensitive());
    }
}
359}