tower_http/
normalize_path.rs1use http::{Request, Response, Uri};
41use std::{
42    borrow::Cow,
43    task::{Context, Poll},
44};
45use tower_layer::Layer;
46use tower_service::Service;
47
/// [`Layer`] that wraps services in [`NormalizePath`] middleware.
///
/// Build one with [`NormalizePathLayer::trim_trailing_slash`].
#[derive(Debug, Copy, Clone)]
pub struct NormalizePathLayer {}

impl NormalizePathLayer {
    /// Create a layer whose services strip trailing slashes from request paths.
    ///
    /// For example, a request for `/foo/` reaches the inner service as `/foo`.
    pub fn trim_trailing_slash() -> Self {
        Self {}
    }
}
63
64impl<S> Layer<S> for NormalizePathLayer {
65    type Service = NormalizePath<S>;
66
67    fn layer(&self, inner: S) -> Self::Service {
68        NormalizePath::trim_trailing_slash(inner)
69    }
70}
71
72#[derive(Debug, Copy, Clone)]
76pub struct NormalizePath<S> {
77    inner: S,
78}
79
80impl<S> NormalizePath<S> {
81    pub fn trim_trailing_slash(inner: S) -> Self {
86        Self { inner }
87    }
88
89    define_inner_service_accessors!();
90}
91
92impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for NormalizePath<S>
93where
94    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
95{
96    type Response = S::Response;
97    type Error = S::Error;
98    type Future = S::Future;
99
100    #[inline]
101    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
102        self.inner.poll_ready(cx)
103    }
104
105    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
106        normalize_trailing_slash(req.uri_mut());
107        self.inner.call(req)
108    }
109}
110
111fn normalize_trailing_slash(uri: &mut Uri) {
112    if !uri.path().ends_with('/') && !uri.path().starts_with("//") {
113        return;
114    }
115
116    let new_path = format!("/{}", uri.path().trim_matches('/'));
117
118    let mut parts = uri.clone().into_parts();
119
120    let new_path_and_query = if let Some(path_and_query) = &parts.path_and_query {
121        let new_path_and_query = if let Some(query) = path_and_query.query() {
122            Cow::Owned(format!("{}?{}", new_path, query))
123        } else {
124            new_path.into()
125        }
126        .parse()
127        .unwrap();
128
129        Some(new_path_and_query)
130    } else {
131        None
132    };
133
134    parts.path_and_query = new_path_and_query;
135    if let Ok(new_uri) = Uri::from_parts(parts) {
136        *uri = new_uri;
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use tower::{ServiceBuilder, ServiceExt};

    /// Parse `input`, normalize it in place, and assert it now equals `expected`.
    #[track_caller]
    fn assert_normalizes_to(input: &str, expected: &str) {
        let mut uri = input.parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, expected);
    }

    #[tokio::test]
    async fn works() {
        // Echo back the URI the inner service observed.
        async fn handle(request: Request<()>) -> Result<Response<String>, Infallible> {
            Ok(Response::new(request.uri().to_string()))
        }

        let mut svc = ServiceBuilder::new()
            .layer(NormalizePathLayer::trim_trailing_slash())
            .service_fn(handle);

        let request = Request::builder().uri("/foo/").body(()).unwrap();
        let response = svc.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(response.into_body(), "/foo");
    }

    #[test]
    fn is_noop_if_no_trailing_slash() {
        assert_normalizes_to("/foo", "/foo");
    }

    #[test]
    fn maintains_query() {
        assert_normalizes_to("/foo/?a=a", "/foo?a=a");
    }

    #[test]
    fn removes_multiple_trailing_slashes() {
        assert_normalizes_to("/foo////", "/foo");
    }

    #[test]
    fn removes_multiple_trailing_slashes_even_with_query() {
        assert_normalizes_to("/foo////?a=a", "/foo?a=a");
    }

    #[test]
    fn is_noop_on_index() {
        assert_normalizes_to("/", "/");
    }

    #[test]
    fn removes_multiple_trailing_slashes_on_index() {
        assert_normalizes_to("////", "/");
    }

    #[test]
    fn removes_multiple_trailing_slashes_on_index_even_with_query() {
        assert_normalizes_to("////?a=a", "/?a=a");
    }

    #[test]
    fn removes_multiple_preceding_slashes_even_with_query() {
        assert_normalizes_to("///foo//?a=a", "/foo?a=a");
    }

    #[test]
    fn removes_multiple_preceding_slashes() {
        assert_normalizes_to("///foo", "/foo");
    }
}