tower_async_http/
normalize_path.rs

//! Middleware that normalizes paths.
//!
//! Any trailing slashes are removed from request paths, and a run of leading slashes is
//! collapsed into a single one. For example, a request with `/foo/` (or `//foo//`) will be
//! changed to `/foo` before reaching the inner service. The query string is preserved.
//!
//! # Example
//!
//! ```
//! use tower_async_http::normalize_path::NormalizePathLayer;
//! use http::{Request, Response, StatusCode};
//! use http_body_util::Full;
//! use bytes::Bytes;
//! use std::{iter::once, convert::Infallible};
//! use tower_async::{ServiceBuilder, Service, ServiceExt};
//!
//! # #[tokio::main]
//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
//!     // `req.uri().path()` will not have trailing slashes
//!     # Ok(Response::new(Full::default()))
//! }
//!
//! let mut service = ServiceBuilder::new()
//!     // trim trailing slashes from paths
//!     .layer(NormalizePathLayer::trim_trailing_slash())
//!     .service_fn(handle);
//!
//! // call the service
//! let request = Request::builder()
//!     // `handle` will see `/foo`
//!     .uri("/foo/")
//!     .body(Full::<Bytes>::default())?;
//!
//! service.call(request).await?;
//! #
//! # Ok(())
//! # }
//! ```
39
40use http::{Request, Response, Uri};
41use std::borrow::Cow;
42use tower_async_layer::Layer;
43use tower_async_service::Service;
44
/// [`Layer`] that wraps services in a [`NormalizePath`], cleaning up request paths before
/// they reach the wrapped service.
///
/// See the [module docs](self) for more details.
#[derive(Clone, Copy, Debug)]
pub struct NormalizePathLayer {}
50
51impl NormalizePathLayer {
52    /// Create a new [`NormalizePathLayer`].
53    ///
54    /// Any trailing slashes from request paths will be removed. For example, a request with `/foo/`
55    /// will be changed to `/foo` before reaching the inner service.
56    pub fn trim_trailing_slash() -> Self {
57        NormalizePathLayer {}
58    }
59}
60
61impl<S> Layer<S> for NormalizePathLayer {
62    type Service = NormalizePath<S>;
63
64    fn layer(&self, inner: S) -> Self::Service {
65        NormalizePath::trim_trailing_slash(inner)
66    }
67}
68
/// Service wrapper that normalizes the request path before delegating to the wrapped service.
///
/// See the [module docs](self) for more details.
#[derive(Clone, Copy, Debug)]
pub struct NormalizePath<S> {
    /// The wrapped service that receives the rewritten request.
    inner: S,
}
76
77impl<S> NormalizePath<S> {
78    /// Create a new [`NormalizePath`].
79    ///
80    /// Any trailing slashes from request paths will be removed. For example, a request with `/foo/`
81    /// will be changed to `/foo` before reaching the inner service.
82    pub fn trim_trailing_slash(inner: S) -> Self {
83        Self { inner }
84    }
85
86    define_inner_service_accessors!();
87}
88
89impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for NormalizePath<S>
90where
91    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
92{
93    type Response = S::Response;
94    type Error = S::Error;
95
96    async fn call(&self, mut req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
97        normalize_trailing_slash(req.uri_mut());
98        self.inner.call(req).await
99    }
100}
101
102fn normalize_trailing_slash(uri: &mut Uri) {
103    if !uri.path().ends_with('/') && !uri.path().starts_with("//") {
104        return;
105    }
106
107    let new_path = format!("/{}", uri.path().trim_matches('/'));
108
109    let mut parts = uri.clone().into_parts();
110
111    let new_path_and_query = if let Some(path_and_query) = &parts.path_and_query {
112        let new_path = if new_path.is_empty() {
113            "/"
114        } else {
115            new_path.as_str()
116        };
117
118        let new_path_and_query = if let Some(query) = path_and_query.query() {
119            Cow::Owned(format!("{}?{}", new_path, query))
120        } else {
121            new_path.into()
122        }
123        .parse()
124        .unwrap();
125
126        Some(new_path_and_query)
127    } else {
128        None
129    };
130
131    parts.path_and_query = new_path_and_query;
132    if let Ok(new_uri) = Uri::from_parts(parts) {
133        *uri = new_uri;
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use tower_async::ServiceBuilder;

    /// Parses `input`, runs `normalize_trailing_slash` on it and returns the rewritten URI.
    fn normalize(input: &str) -> Uri {
        let mut uri = input.parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        uri
    }

    #[tokio::test]
    async fn works() {
        async fn handle(request: Request<()>) -> Result<Response<String>, Infallible> {
            Ok(Response::new(request.uri().to_string()))
        }

        let svc = ServiceBuilder::new()
            .layer(NormalizePathLayer::trim_trailing_slash())
            .service_fn(handle);

        let body = svc
            .call(Request::builder().uri("/foo/").body(()).unwrap())
            .await
            .unwrap()
            .into_body();

        assert_eq!(body, "/foo");
    }

    #[test]
    fn is_noop_if_no_trailing_slash() {
        assert_eq!(normalize("/foo"), "/foo");
    }

    #[test]
    fn maintains_query() {
        assert_eq!(normalize("/foo/?a=a"), "/foo?a=a");
    }

    #[test]
    fn removes_multiple_trailing_slashes() {
        assert_eq!(normalize("/foo////"), "/foo");
    }

    #[test]
    fn removes_multiple_trailing_slashes_even_with_query() {
        assert_eq!(normalize("/foo////?a=a"), "/foo?a=a");
    }

    #[test]
    fn is_noop_on_index() {
        assert_eq!(normalize("/"), "/");
    }

    #[test]
    fn is_noop_on_index_with_query() {
        assert_eq!(normalize("/?a=a"), "/?a=a");
    }

    #[test]
    fn removes_multiple_trailing_slashes_on_index() {
        assert_eq!(normalize("////"), "/");
    }

    #[test]
    fn removes_multiple_trailing_slashes_on_index_even_with_query() {
        assert_eq!(normalize("////?a=a"), "/?a=a");
    }

    #[test]
    fn removes_multiple_preceding_slashes_even_with_query() {
        assert_eq!(normalize("///foo//?a=a"), "/foo?a=a");
    }

    #[test]
    fn removes_multiple_preceding_slashes() {
        assert_eq!(normalize("///foo"), "/foo");
    }

    #[test]
    fn preserves_inner_slashes() {
        // Only leading and trailing slashes are touched; interior runs are left alone.
        assert_eq!(normalize("/foo//bar/"), "/foo//bar");
    }

    #[test]
    fn keeps_scheme_and_authority() {
        assert_eq!(
            normalize("https://example.com/foo/?a=a"),
            "https://example.com/foo?a=a"
        );
    }
}