tower_http/
normalize_path.rs

1//! Middleware that normalizes paths.
2//!
3//! # Example
4//!
5//! ```
6//! use tower_http::normalize_path::NormalizePathLayer;
7//! use http::{Request, Response, StatusCode};
8//! use http_body_util::Full;
9//! use bytes::Bytes;
10//! use std::{iter::once, convert::Infallible};
11//! use tower::{ServiceBuilder, Service, ServiceExt};
12//!
13//! # #[tokio::main]
14//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
15//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
16//!     // `req.uri().path()` will not have trailing slashes
17//!     # Ok(Response::new(Full::default()))
18//! }
19//!
20//! let mut service = ServiceBuilder::new()
21//!     // trim trailing slashes from paths
22//!     .layer(NormalizePathLayer::trim_trailing_slash())
23//!     .service_fn(handle);
24//!
25//! // call the service
26//! let request = Request::builder()
27//!     // `handle` will see `/foo`
28//!     .uri("/foo/")
29//!     .body(Full::default())?;
30//!
31//! service.ready().await?.call(request).await?;
32//! #
33//! # Ok(())
34//! # }
35//! ```
36
37use http::{Request, Response, Uri};
38use std::{
39    borrow::Cow,
40    task::{Context, Poll},
41};
42use tower_layer::Layer;
43use tower_service::Service;
44
/// Strategy used by the `NormalizePath` middleware to rewrite request paths.
#[derive(Clone, Copy, Debug)]
enum NormalizeMode {
    /// Strip trailing slashes, e.g. `/foo/` becomes `/foo`.
    Trim,
    /// Ensure a trailing slash is present, e.g. `/foo` becomes `/foo/`.
    Append,
}
53
54/// Layer that applies [`NormalizePath`] which normalizes paths.
55///
56/// See the [module docs](self) for more details.
57#[derive(Debug, Copy, Clone)]
58pub struct NormalizePathLayer {
59    mode: NormalizeMode,
60}
61
62impl NormalizePathLayer {
63    /// Create a new [`NormalizePathLayer`].
64    ///
65    /// Any trailing slashes from request paths will be removed. For example, a request with `/foo/`
66    /// will be changed to `/foo` before reaching the inner service.
67    pub fn trim_trailing_slash() -> Self {
68        NormalizePathLayer {
69            mode: NormalizeMode::Trim,
70        }
71    }
72
73    /// Create a new [`NormalizePathLayer`].
74    ///
75    /// Request paths without trailing slash will be appended with a trailing slash. For example, a request with `/foo`
76    /// will be changed to `/foo/` before reaching the inner service.
77    pub fn append_trailing_slash() -> Self {
78        NormalizePathLayer {
79            mode: NormalizeMode::Append,
80        }
81    }
82}
83
84impl<S> Layer<S> for NormalizePathLayer {
85    type Service = NormalizePath<S>;
86
87    fn layer(&self, inner: S) -> Self::Service {
88        NormalizePath {
89            mode: self.mode,
90            inner,
91        }
92    }
93}
94
95/// Middleware that normalizes paths.
96///
97/// See the [module docs](self) for more details.
98#[derive(Debug, Copy, Clone)]
99pub struct NormalizePath<S> {
100    mode: NormalizeMode,
101    inner: S,
102}
103
104impl<S> NormalizePath<S> {
105    /// Construct a new [`NormalizePath`] with trim mode.
106    pub fn trim_trailing_slash(inner: S) -> Self {
107        Self {
108            mode: NormalizeMode::Trim,
109            inner,
110        }
111    }
112
113    /// Construct a new [`NormalizePath`] with append mode.
114    pub fn append_trailing_slash(inner: S) -> Self {
115        Self {
116            mode: NormalizeMode::Append,
117            inner,
118        }
119    }
120
121    define_inner_service_accessors!();
122}
123
124impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for NormalizePath<S>
125where
126    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
127{
128    type Response = S::Response;
129    type Error = S::Error;
130    type Future = S::Future;
131
132    #[inline]
133    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
134        self.inner.poll_ready(cx)
135    }
136
137    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
138        match self.mode {
139            NormalizeMode::Trim => trim_trailing_slash(req.uri_mut()),
140            NormalizeMode::Append => append_trailing_slash(req.uri_mut()),
141        }
142        self.inner.call(req)
143    }
144}
145
146fn trim_trailing_slash(uri: &mut Uri) {
147    if !uri.path().ends_with('/') && !uri.path().starts_with("//") {
148        return;
149    }
150
151    let new_path = format!("/{}", uri.path().trim_matches('/'));
152
153    let mut parts = uri.clone().into_parts();
154
155    let new_path_and_query = if let Some(path_and_query) = &parts.path_and_query {
156        let new_path_and_query = if let Some(query) = path_and_query.query() {
157            Cow::Owned(format!("{}?{}", new_path, query))
158        } else {
159            new_path.into()
160        }
161        .parse()
162        .unwrap();
163
164        Some(new_path_and_query)
165    } else {
166        None
167    };
168
169    parts.path_and_query = new_path_and_query;
170    if let Ok(new_uri) = Uri::from_parts(parts) {
171        *uri = new_uri;
172    }
173}
174
175fn append_trailing_slash(uri: &mut Uri) {
176    if uri.path().ends_with("/") && !uri.path().ends_with("//") {
177        return;
178    }
179
180    let trimmed = uri.path().trim_matches('/');
181    let new_path = if trimmed.is_empty() {
182        "/".to_string()
183    } else {
184        format!("/{trimmed}/")
185    };
186
187    let mut parts = uri.clone().into_parts();
188
189    let new_path_and_query = if let Some(path_and_query) = &parts.path_and_query {
190        let new_path_and_query = if let Some(query) = path_and_query.query() {
191            Cow::Owned(format!("{new_path}?{query}"))
192        } else {
193            new_path.into()
194        }
195        .parse()
196        .unwrap();
197
198        Some(new_path_and_query)
199    } else {
200        Some(new_path.parse().unwrap())
201    };
202
203    parts.path_and_query = new_path_and_query;
204    if let Ok(new_uri) = Uri::from_parts(parts) {
205        *uri = new_uri;
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use tower::{ServiceBuilder, ServiceExt};

    /// Inner service that echoes the URI it received back as the response body.
    async fn echo_uri(request: Request<()>) -> Result<Response<String>, Infallible> {
        Ok(Response::new(request.uri().to_string()))
    }

    /// Sends a request for `uri` through `layer` and returns the URI the inner service saw.
    async fn body_seen_through(layer: NormalizePathLayer, uri: &str) -> String {
        let mut svc = ServiceBuilder::new().layer(layer).service_fn(echo_uri);
        let request = Request::builder().uri(uri).body(()).unwrap();
        svc.ready()
            .await
            .unwrap()
            .call(request)
            .await
            .unwrap()
            .into_body()
    }

    /// Parses `input`, runs `trim_trailing_slash` on it and returns the result.
    fn trimmed(input: &str) -> Uri {
        let mut uri: Uri = input.parse().unwrap();
        trim_trailing_slash(&mut uri);
        uri
    }

    /// Parses `input`, runs `append_trailing_slash` on it and returns the result.
    fn appended(input: &str) -> Uri {
        let mut uri: Uri = input.parse().unwrap();
        append_trailing_slash(&mut uri);
        uri
    }

    #[tokio::test]
    async fn trim_works() {
        let seen = body_seen_through(NormalizePathLayer::trim_trailing_slash(), "/foo/").await;
        assert_eq!(seen, "/foo");
    }

    #[test]
    fn is_noop_if_no_trailing_slash() {
        assert_eq!(trimmed("/foo"), "/foo");
    }

    #[test]
    fn maintains_query() {
        assert_eq!(trimmed("/foo/?a=a"), "/foo?a=a");
    }

    #[test]
    fn removes_multiple_trailing_slashes() {
        assert_eq!(trimmed("/foo////"), "/foo");
    }

    #[test]
    fn removes_multiple_trailing_slashes_even_with_query() {
        assert_eq!(trimmed("/foo////?a=a"), "/foo?a=a");
    }

    #[test]
    fn is_noop_on_index() {
        assert_eq!(trimmed("/"), "/");
    }

    #[test]
    fn removes_multiple_trailing_slashes_on_index() {
        assert_eq!(trimmed("////"), "/");
    }

    #[test]
    fn removes_multiple_trailing_slashes_on_index_even_with_query() {
        assert_eq!(trimmed("////?a=a"), "/?a=a");
    }

    #[test]
    fn removes_multiple_preceding_slashes_even_with_query() {
        assert_eq!(trimmed("///foo//?a=a"), "/foo?a=a");
    }

    #[test]
    fn removes_multiple_preceding_slashes() {
        assert_eq!(trimmed("///foo"), "/foo");
    }

    #[tokio::test]
    async fn append_works() {
        let seen = body_seen_through(NormalizePathLayer::append_trailing_slash(), "/foo").await;
        assert_eq!(seen, "/foo/");
    }

    #[test]
    fn is_noop_if_trailing_slash() {
        assert_eq!(appended("/foo/"), "/foo/");
    }

    #[test]
    fn append_maintains_query() {
        assert_eq!(appended("/foo?a=a"), "/foo/?a=a");
    }

    #[test]
    fn append_only_keeps_one_slash() {
        assert_eq!(appended("/foo////"), "/foo/");
    }

    #[test]
    fn append_only_keeps_one_slash_even_with_query() {
        assert_eq!(appended("/foo////?a=a"), "/foo/?a=a");
    }

    #[test]
    fn append_is_noop_on_index() {
        assert_eq!(appended("/"), "/");
    }

    #[test]
    fn append_removes_multiple_trailing_slashes_on_index() {
        assert_eq!(appended("////"), "/");
    }

    #[test]
    fn append_removes_multiple_trailing_slashes_on_index_even_with_query() {
        assert_eq!(appended("////?a=a"), "/?a=a");
    }

    #[test]
    fn append_removes_multiple_preceding_slashes_even_with_query() {
        assert_eq!(appended("///foo//?a=a"), "/foo/?a=a");
    }

    #[test]
    fn append_removes_multiple_preceding_slashes() {
        assert_eq!(appended("///foo"), "/foo/");
    }
}