tower_http/
sensitive_headers.rs1use http::{header::HeaderName, Request, Response};
88use pin_project_lite::pin_project;
89use std::{
90 future::Future,
91 pin::Pin,
92 sync::Arc,
93 task::{ready, Context, Poll},
94};
95use tower_layer::Layer;
96use tower_service::Service;
97
98#[derive(Clone, Debug)]
106pub struct SetSensitiveHeadersLayer {
107 headers: Arc<[HeaderName]>,
108}
109
110impl SetSensitiveHeadersLayer {
111 pub fn new<I>(headers: I) -> Self
113 where
114 I: IntoIterator<Item = HeaderName>,
115 {
116 let headers = headers.into_iter().collect::<Vec<_>>();
117 Self::from_shared(headers.into())
118 }
119
120 pub fn from_shared(headers: Arc<[HeaderName]>) -> Self {
122 Self { headers }
123 }
124}
125
126impl<S> Layer<S> for SetSensitiveHeadersLayer {
127 type Service = SetSensitiveHeaders<S>;
128
129 fn layer(&self, inner: S) -> Self::Service {
130 SetSensitiveRequestHeaders::from_shared(
131 SetSensitiveResponseHeaders::from_shared(inner, self.headers.clone()),
132 self.headers.clone(),
133 )
134 }
135}
136
137pub type SetSensitiveHeaders<S> = SetSensitiveRequestHeaders<SetSensitiveResponseHeaders<S>>;
143
144#[derive(Clone, Debug)]
152pub struct SetSensitiveRequestHeadersLayer {
153 headers: Arc<[HeaderName]>,
154}
155
156impl SetSensitiveRequestHeadersLayer {
157 pub fn new<I>(headers: I) -> Self
159 where
160 I: IntoIterator<Item = HeaderName>,
161 {
162 let headers = headers.into_iter().collect::<Vec<_>>();
163 Self::from_shared(headers.into())
164 }
165
166 pub fn from_shared(headers: Arc<[HeaderName]>) -> Self {
168 Self { headers }
169 }
170}
171
172impl<S> Layer<S> for SetSensitiveRequestHeadersLayer {
173 type Service = SetSensitiveRequestHeaders<S>;
174
175 fn layer(&self, inner: S) -> Self::Service {
176 SetSensitiveRequestHeaders {
177 inner,
178 headers: self.headers.clone(),
179 }
180 }
181}
182
183#[derive(Clone, Debug)]
189pub struct SetSensitiveRequestHeaders<S> {
190 inner: S,
191 headers: Arc<[HeaderName]>,
192}
193
194impl<S> SetSensitiveRequestHeaders<S> {
195 pub fn new<I>(inner: S, headers: I) -> Self
197 where
198 I: IntoIterator<Item = HeaderName>,
199 {
200 let headers = headers.into_iter().collect::<Vec<_>>();
201 Self::from_shared(inner, headers.into())
202 }
203
204 pub fn from_shared(inner: S, headers: Arc<[HeaderName]>) -> Self {
206 Self { inner, headers }
207 }
208
209 define_inner_service_accessors!();
210
211 pub fn layer<I>(headers: I) -> SetSensitiveRequestHeadersLayer
215 where
216 I: IntoIterator<Item = HeaderName>,
217 {
218 SetSensitiveRequestHeadersLayer::new(headers)
219 }
220}
221
222impl<ReqBody, ResBody, S> Service<Request<ReqBody>> for SetSensitiveRequestHeaders<S>
223where
224 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
225{
226 type Response = S::Response;
227 type Error = S::Error;
228 type Future = S::Future;
229
230 #[inline]
231 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
232 self.inner.poll_ready(cx)
233 }
234
235 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
236 let headers = req.headers_mut();
237 for header in &*self.headers {
238 if let http::header::Entry::Occupied(mut entry) = headers.entry(header) {
239 for value in entry.iter_mut() {
240 value.set_sensitive(true);
241 }
242 }
243 }
244
245 self.inner.call(req)
246 }
247}
248
249#[derive(Clone, Debug)]
257pub struct SetSensitiveResponseHeadersLayer {
258 headers: Arc<[HeaderName]>,
259}
260
261impl SetSensitiveResponseHeadersLayer {
262 pub fn new<I>(headers: I) -> Self
264 where
265 I: IntoIterator<Item = HeaderName>,
266 {
267 let headers = headers.into_iter().collect::<Vec<_>>();
268 Self::from_shared(headers.into())
269 }
270
271 pub fn from_shared(headers: Arc<[HeaderName]>) -> Self {
273 Self { headers }
274 }
275}
276
277impl<S> Layer<S> for SetSensitiveResponseHeadersLayer {
278 type Service = SetSensitiveResponseHeaders<S>;
279
280 fn layer(&self, inner: S) -> Self::Service {
281 SetSensitiveResponseHeaders {
282 inner,
283 headers: self.headers.clone(),
284 }
285 }
286}
287
288#[derive(Clone, Debug)]
294pub struct SetSensitiveResponseHeaders<S> {
295 inner: S,
296 headers: Arc<[HeaderName]>,
297}
298
299impl<S> SetSensitiveResponseHeaders<S> {
300 pub fn new<I>(inner: S, headers: I) -> Self
302 where
303 I: IntoIterator<Item = HeaderName>,
304 {
305 let headers = headers.into_iter().collect::<Vec<_>>();
306 Self::from_shared(inner, headers.into())
307 }
308
309 pub fn from_shared(inner: S, headers: Arc<[HeaderName]>) -> Self {
311 Self { inner, headers }
312 }
313
314 define_inner_service_accessors!();
315
316 pub fn layer<I>(headers: I) -> SetSensitiveResponseHeadersLayer
320 where
321 I: IntoIterator<Item = HeaderName>,
322 {
323 SetSensitiveResponseHeadersLayer::new(headers)
324 }
325}
326
327impl<ReqBody, ResBody, S> Service<Request<ReqBody>> for SetSensitiveResponseHeaders<S>
328where
329 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
330{
331 type Response = S::Response;
332 type Error = S::Error;
333 type Future = SetSensitiveResponseHeadersResponseFuture<S::Future>;
334
335 #[inline]
336 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
337 self.inner.poll_ready(cx)
338 }
339
340 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
341 SetSensitiveResponseHeadersResponseFuture {
342 future: self.inner.call(req),
343 headers: self.headers.clone(),
344 }
345 }
346}
347
348pin_project! {
349 #[derive(Debug)]
351 pub struct SetSensitiveResponseHeadersResponseFuture<F> {
352 #[pin]
353 future: F,
354 headers: Arc<[HeaderName]>,
355 }
356}
357
358impl<F, ResBody, E> Future for SetSensitiveResponseHeadersResponseFuture<F>
359where
360 F: Future<Output = Result<Response<ResBody>, E>>,
361{
362 type Output = F::Output;
363
364 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
365 let this = self.project();
366 let mut res = ready!(this.future.poll(cx)?);
367
368 let headers = res.headers_mut();
369 for header in &**this.headers {
370 if let http::header::Entry::Occupied(mut entry) = headers.entry(header) {
371 for value in entry.iter_mut() {
372 value.set_sensitive(true);
373 }
374 }
375 }
376
377 Poll::Ready(Ok(res))
378 }
379}
380
#[cfg(test)]
mod tests {
    // The glob brings in the layer types and the `Service` trait used for `.call(...)`.
    // It is therefore always used, and no `#[allow(unused_imports)]` is needed.
    use super::*;
    use http::header;
    use tower::{ServiceBuilder, ServiceExt};

    /// Every value of a repeated header is flagged, on both the request and the response
    /// side. Unrelated headers are left alone.
    #[tokio::test]
    async fn multiple_value_header() {
        async fn response_set_cookie(req: http::Request<()>) -> Result<http::Response<()>, ()> {
            let mut iter = req.headers().get_all(header::COOKIE).iter().peekable();

            assert!(iter.peek().is_some());

            for value in iter {
                assert!(value.is_sensitive())
            }

            let mut resp = http::Response::new(());
            resp.headers_mut().append(
                header::CONTENT_TYPE,
                http::HeaderValue::from_static("text/html"),
            );
            resp.headers_mut().append(
                header::SET_COOKIE,
                http::HeaderValue::from_static("cookie-1"),
            );
            resp.headers_mut().append(
                header::SET_COOKIE,
                http::HeaderValue::from_static("cookie-2"),
            );
            resp.headers_mut().append(
                header::SET_COOKIE,
                http::HeaderValue::from_static("cookie-3"),
            );
            Ok(resp)
        }

        let mut service = ServiceBuilder::new()
            .layer(SetSensitiveRequestHeadersLayer::new(vec![header::COOKIE]))
            .layer(SetSensitiveResponseHeadersLayer::new(vec![
                header::SET_COOKIE,
            ]))
            .service_fn(response_set_cookie);

        let mut req = http::Request::new(());
        req.headers_mut()
            .append(header::COOKIE, http::HeaderValue::from_static("cookie+1"));
        req.headers_mut()
            .append(header::COOKIE, http::HeaderValue::from_static("cookie+2"));

        let resp = service.ready().await.unwrap().call(req).await.unwrap();

        assert!(!resp
            .headers()
            .get(header::CONTENT_TYPE)
            .unwrap()
            .is_sensitive());

        let mut iter = resp.headers().get_all(header::SET_COOKIE).iter().peekable();

        assert!(iter.peek().is_some());

        for value in iter {
            assert!(value.is_sensitive())
        }
    }

    /// The combined `SetSensitiveHeadersLayer` flags the same names on both sides. It does
    /// not insert configured headers that the message does not carry.
    #[tokio::test]
    async fn combined_layer_marks_request_and_response_headers() {
        async fn handle(req: http::Request<()>) -> Result<http::Response<()>, ()> {
            assert!(req
                .headers()
                .get(header::AUTHORIZATION)
                .unwrap()
                .is_sensitive());
            assert!(!req.headers().get(header::ACCEPT).unwrap().is_sensitive());
            // COOKIE is configured but absent from the request; it must stay absent.
            assert!(!req.headers().contains_key(header::COOKIE));

            let mut resp = http::Response::new(());
            resp.headers_mut().insert(
                header::AUTHORIZATION,
                http::HeaderValue::from_static("Bearer response-token"),
            );
            resp.headers_mut().insert(
                header::CONTENT_TYPE,
                http::HeaderValue::from_static("text/plain"),
            );
            Ok(resp)
        }

        let mut service = ServiceBuilder::new()
            .layer(SetSensitiveHeadersLayer::new(vec![
                header::AUTHORIZATION,
                header::COOKIE,
            ]))
            .service_fn(handle);

        let mut req = http::Request::new(());
        req.headers_mut().insert(
            header::AUTHORIZATION,
            http::HeaderValue::from_static("Bearer request-token"),
        );
        req.headers_mut()
            .insert(header::ACCEPT, http::HeaderValue::from_static("*/*"));

        let resp = service.ready().await.unwrap().call(req).await.unwrap();

        assert!(resp
            .headers()
            .get(header::AUTHORIZATION)
            .unwrap()
            .is_sensitive());
        assert!(!resp
            .headers()
            .get(header::CONTENT_TYPE)
            .unwrap()
            .is_sensitive());
        assert!(!resp.headers().contains_key(header::COOKIE));
    }
}