tower_http/
normalize_path.rs

1//! Middleware that normalizes paths.
2//!
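//! Any trailing slashes are removed from request paths, and a leading run of slashes is collapsed
//! into a single one. For example, requests with `/foo/` or `///foo` will be changed to `/foo`
//! before reaching the inner service. Slashes elsewhere in the path, and the query string, are
//! left untouched.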
5//!
6//! # Example
7//!
8//! ```
9//! use tower_http::normalize_path::NormalizePathLayer;
10//! use http::{Request, Response, StatusCode};
11//! use http_body_util::Full;
12//! use bytes::Bytes;
13//! use std::{iter::once, convert::Infallible};
14//! use tower::{ServiceBuilder, Service, ServiceExt};
15//!
16//! # #[tokio::main]
17//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
18//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
19//!     // `req.uri().path()` will not have trailing slashes
20//!     # Ok(Response::new(Full::default()))
21//! }
22//!
23//! let mut service = ServiceBuilder::new()
24//!     // trim trailing slashes from paths
25//!     .layer(NormalizePathLayer::trim_trailing_slash())
26//!     .service_fn(handle);
27//!
28//! // call the service
29//! let request = Request::builder()
30//!     // `handle` will see `/foo`
31//!     .uri("/foo/")
32//!     .body(Full::default())?;
33//!
34//! service.ready().await?.call(request).await?;
35//! #
36//! # Ok(())
37//! # }
38//! ```
39
40use http::{Request, Response, Uri};
41use std::{
42    borrow::Cow,
43    task::{Context, Poll},
44};
45use tower_layer::Layer;
46use tower_service::Service;
47
/// Layer that applies [`NormalizePath`] which normalizes paths.
///
/// See the [module docs](self) for more details.
#[derive(Debug, Copy, Clone)]
pub struct NormalizePathLayer {}

impl NormalizePathLayer {
    /// Create a new [`NormalizePathLayer`].
    ///
    /// Any trailing slashes are removed from request paths, and a leading run of slashes is
    /// collapsed into one. For example, requests to `/foo/` and `///foo` will both be changed
    /// to `/foo` before reaching the inner service. Slashes in the middle of the path (such as
    /// `/foo//bar`) and the query string are left untouched.
    pub fn trim_trailing_slash() -> Self {
        NormalizePathLayer {}
    }
}
63
64impl<S> Layer<S> for NormalizePathLayer {
65    type Service = NormalizePath<S>;
66
67    fn layer(&self, inner: S) -> Self::Service {
68        NormalizePath::trim_trailing_slash(inner)
69    }
70}
71
72/// Middleware that normalizes paths.
73///
74/// See the [module docs](self) for more details.
75#[derive(Debug, Copy, Clone)]
76pub struct NormalizePath<S> {
77    inner: S,
78}
79
80impl<S> NormalizePath<S> {
81    /// Create a new [`NormalizePath`].
82    ///
83    /// Any trailing slashes from request paths will be removed. For example, a request with `/foo/`
84    /// will be changed to `/foo` before reaching the inner service.
85    pub fn trim_trailing_slash(inner: S) -> Self {
86        Self { inner }
87    }
88
89    define_inner_service_accessors!();
90}
91
92impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for NormalizePath<S>
93where
94    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
95{
96    type Response = S::Response;
97    type Error = S::Error;
98    type Future = S::Future;
99
100    #[inline]
101    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
102        self.inner.poll_ready(cx)
103    }
104
105    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
106        normalize_trailing_slash(req.uri_mut());
107        self.inner.call(req)
108    }
109}
110
111fn normalize_trailing_slash(uri: &mut Uri) {
112    if !uri.path().ends_with('/') && !uri.path().starts_with("//") {
113        return;
114    }
115
116    let new_path = format!("/{}", uri.path().trim_matches('/'));
117
118    let mut parts = uri.clone().into_parts();
119
120    let new_path_and_query = if let Some(path_and_query) = &parts.path_and_query {
121        let new_path_and_query = if let Some(query) = path_and_query.query() {
122            Cow::Owned(format!("{}?{}", new_path, query))
123        } else {
124            new_path.into()
125        }
126        .parse()
127        .unwrap();
128
129        Some(new_path_and_query)
130    } else {
131        None
132    };
133
134    parts.path_and_query = new_path_and_query;
135    if let Ok(new_uri) = Uri::from_parts(parts) {
136        *uri = new_uri;
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use tower::{ServiceBuilder, ServiceExt};

    #[tokio::test]
    async fn works() {
        async fn handle(request: Request<()>) -> Result<Response<String>, Infallible> {
            Ok(Response::new(request.uri().to_string()))
        }

        let mut svc = ServiceBuilder::new()
            .layer(NormalizePathLayer::trim_trailing_slash())
            .service_fn(handle);

        let body = svc
            .ready()
            .await
            .unwrap()
            .call(Request::builder().uri("/foo/").body(()).unwrap())
            .await
            .unwrap()
            .into_body();

        assert_eq!(body, "/foo");
    }

    #[test]
    fn is_noop_if_no_trailing_slash() {
        let mut uri = "/foo".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/foo");
    }

    #[test]
    fn maintains_query() {
        let mut uri = "/foo/?a=a".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/foo?a=a");
    }

    #[test]
    fn removes_multiple_trailing_slashes() {
        let mut uri = "/foo////".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/foo");
    }

    #[test]
    fn removes_multiple_trailing_slashes_even_with_query() {
        let mut uri = "/foo////?a=a".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/foo?a=a");
    }

    #[test]
    fn is_noop_on_index() {
        let mut uri = "/".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/");
    }

    #[test]
    fn is_noop_on_index_with_query() {
        let mut uri = "/?a=a".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/?a=a");
    }

    #[test]
    fn removes_multiple_trailing_slashes_on_index() {
        let mut uri = "////".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/");
    }

    #[test]
    fn removes_multiple_trailing_slashes_on_index_even_with_query() {
        let mut uri = "////?a=a".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/?a=a");
    }

    #[test]
    fn removes_multiple_preceding_slashes_even_with_query() {
        let mut uri = "///foo//?a=a".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/foo?a=a");
    }

    #[test]
    fn removes_multiple_preceding_slashes() {
        let mut uri = "///foo".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/foo");
    }

    #[test]
    fn keeps_interior_slashes() {
        let mut uri = "/foo//bar/".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/foo//bar");
    }

    #[test]
    fn keeps_scheme_and_authority() {
        let mut uri = "http://example.com/foo/".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "http://example.com/foo");
    }

    #[test]
    fn is_noop_on_asterisk_form() {
        let mut uri = "*".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "*");
    }
}
230}