tower_http/
normalize_path.rs1use http::{Request, Response, Uri};
41use std::{
42 borrow::Cow,
43 task::{Context, Poll},
44};
45use tower_layer::Layer;
46use tower_service::Service;
47
/// [`Layer`] that wraps services in [`NormalizePath`].
///
/// The layer carries no configuration of its own; every service it wraps is built with
/// [`NormalizePath::trim_trailing_slash`].
#[derive(Debug, Copy, Clone)]
pub struct NormalizePathLayer {}

impl NormalizePathLayer {
    /// Creates a layer whose services strip trailing slashes from request paths.
    ///
    /// For example, a request for `/foo/` reaches the inner service as `/foo`.
    pub fn trim_trailing_slash() -> Self {
        Self {}
    }
}

64impl<S> Layer<S> for NormalizePathLayer {
65 type Service = NormalizePath<S>;
66
67 fn layer(&self, inner: S) -> Self::Service {
68 NormalizePath::trim_trailing_slash(inner)
69 }
70}
71
72#[derive(Debug, Copy, Clone)]
76pub struct NormalizePath<S> {
77 inner: S,
78}
79
80impl<S> NormalizePath<S> {
81 pub fn trim_trailing_slash(inner: S) -> Self {
86 Self { inner }
87 }
88
89 define_inner_service_accessors!();
90}
91
92impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for NormalizePath<S>
93where
94 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
95{
96 type Response = S::Response;
97 type Error = S::Error;
98 type Future = S::Future;
99
100 #[inline]
101 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
102 self.inner.poll_ready(cx)
103 }
104
105 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
106 normalize_trailing_slash(req.uri_mut());
107 self.inner.call(req)
108 }
109}
110
111fn normalize_trailing_slash(uri: &mut Uri) {
112 if !uri.path().ends_with('/') && !uri.path().starts_with("//") {
113 return;
114 }
115
116 let new_path = format!("/{}", uri.path().trim_matches('/'));
117
118 let mut parts = uri.clone().into_parts();
119
120 let new_path_and_query = if let Some(path_and_query) = &parts.path_and_query {
121 let new_path_and_query = if let Some(query) = path_and_query.query() {
122 Cow::Owned(format!("{}?{}", new_path, query))
123 } else {
124 new_path.into()
125 }
126 .parse()
127 .unwrap();
128
129 Some(new_path_and_query)
130 } else {
131 None
132 };
133
134 parts.path_and_query = new_path_and_query;
135 if let Ok(new_uri) = Uri::from_parts(parts) {
136 *uri = new_uri;
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use tower::{ServiceBuilder, ServiceExt};

    /// Parses `input`, runs `normalize_trailing_slash` on it, and asserts that the
    /// resulting URI equals `expected`.
    fn assert_normalizes_to(input: &str, expected: &str) {
        let mut uri: Uri = input.parse().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, expected);
    }

    #[tokio::test]
    async fn works() {
        async fn handle(request: Request<()>) -> Result<Response<String>, Infallible> {
            Ok(Response::new(request.uri().to_string()))
        }

        let mut svc = ServiceBuilder::new()
            .layer(NormalizePathLayer::trim_trailing_slash())
            .service_fn(handle);

        let request = Request::builder().uri("/foo/").body(()).unwrap();
        let ready = svc.ready().await.unwrap();
        let response = ready.call(request).await.unwrap();

        assert_eq!(response.into_body(), "/foo");
    }

    #[test]
    fn is_noop_if_no_trailing_slash() {
        assert_normalizes_to("/foo", "/foo");
    }

    #[test]
    fn maintains_query() {
        assert_normalizes_to("/foo/?a=a", "/foo?a=a");
    }

    #[test]
    fn removes_multiple_trailing_slashes() {
        assert_normalizes_to("/foo////", "/foo");
    }

    #[test]
    fn removes_multiple_trailing_slashes_even_with_query() {
        assert_normalizes_to("/foo////?a=a", "/foo?a=a");
    }

    #[test]
    fn is_noop_on_index() {
        assert_normalizes_to("/", "/");
    }

    #[test]
    fn removes_multiple_trailing_slashes_on_index() {
        assert_normalizes_to("////", "/");
    }

    #[test]
    fn removes_multiple_trailing_slashes_on_index_even_with_query() {
        assert_normalizes_to("////?a=a", "/?a=a");
    }

    #[test]
    fn removes_multiple_preceding_slashes_even_with_query() {
        assert_normalizes_to("///foo//?a=a", "/foo?a=a");
    }

    #[test]
    fn removes_multiple_preceding_slashes() {
        assert_normalizes_to("///foo", "/foo");
    }
}