tower_http/
normalize_path.rs1use http::{Request, Response, Uri};
38use std::{
39 borrow::Cow,
40 task::{Context, Poll},
41};
42use tower_layer::Layer;
43use tower_service::Service;
44
/// Which trailing-slash normalization [`NormalizePath`] applies to request paths.
#[derive(Clone, Copy, Debug)]
enum NormalizeMode {
    /// Strip trailing slashes, e.g. `/foo/` becomes `/foo`.
    Trim,
    /// Ensure the path ends in a single slash, e.g. `/foo` becomes `/foo/`.
    Append,
}
53
54#[derive(Debug, Copy, Clone)]
58pub struct NormalizePathLayer {
59 mode: NormalizeMode,
60}
61
62impl NormalizePathLayer {
63 pub fn trim_trailing_slash() -> Self {
68 NormalizePathLayer {
69 mode: NormalizeMode::Trim,
70 }
71 }
72
73 pub fn append_trailing_slash() -> Self {
78 NormalizePathLayer {
79 mode: NormalizeMode::Append,
80 }
81 }
82}
83
84impl<S> Layer<S> for NormalizePathLayer {
85 type Service = NormalizePath<S>;
86
87 fn layer(&self, inner: S) -> Self::Service {
88 NormalizePath {
89 mode: self.mode,
90 inner,
91 }
92 }
93}
94
95#[derive(Debug, Copy, Clone)]
99pub struct NormalizePath<S> {
100 mode: NormalizeMode,
101 inner: S,
102}
103
104impl<S> NormalizePath<S> {
105 pub fn trim_trailing_slash(inner: S) -> Self {
107 Self {
108 mode: NormalizeMode::Trim,
109 inner,
110 }
111 }
112
113 pub fn append_trailing_slash(inner: S) -> Self {
115 Self {
116 mode: NormalizeMode::Append,
117 inner,
118 }
119 }
120
121 define_inner_service_accessors!();
122}
123
124impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for NormalizePath<S>
125where
126 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
127{
128 type Response = S::Response;
129 type Error = S::Error;
130 type Future = S::Future;
131
132 #[inline]
133 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
134 self.inner.poll_ready(cx)
135 }
136
137 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
138 match self.mode {
139 NormalizeMode::Trim => trim_trailing_slash(req.uri_mut()),
140 NormalizeMode::Append => append_trailing_slash(req.uri_mut()),
141 }
142 self.inner.call(req)
143 }
144}
145
146fn trim_trailing_slash(uri: &mut Uri) {
147 if !uri.path().ends_with('/') && !uri.path().starts_with("//") {
148 return;
149 }
150
151 let new_path = format!("/{}", uri.path().trim_matches('/'));
152
153 let mut parts = uri.clone().into_parts();
154
155 let new_path_and_query = if let Some(path_and_query) = &parts.path_and_query {
156 let new_path_and_query = if let Some(query) = path_and_query.query() {
157 Cow::Owned(format!("{}?{}", new_path, query))
158 } else {
159 new_path.into()
160 }
161 .parse()
162 .unwrap();
163
164 Some(new_path_and_query)
165 } else {
166 None
167 };
168
169 parts.path_and_query = new_path_and_query;
170 if let Ok(new_uri) = Uri::from_parts(parts) {
171 *uri = new_uri;
172 }
173}
174
175fn append_trailing_slash(uri: &mut Uri) {
176 if uri.path().ends_with("/") && !uri.path().ends_with("//") {
177 return;
178 }
179
180 let trimmed = uri.path().trim_matches('/');
181 let new_path = if trimmed.is_empty() {
182 "/".to_string()
183 } else {
184 format!("/{trimmed}/")
185 };
186
187 let mut parts = uri.clone().into_parts();
188
189 let new_path_and_query = if let Some(path_and_query) = &parts.path_and_query {
190 let new_path_and_query = if let Some(query) = path_and_query.query() {
191 Cow::Owned(format!("{new_path}?{query}"))
192 } else {
193 new_path.into()
194 }
195 .parse()
196 .unwrap();
197
198 Some(new_path_and_query)
199 } else {
200 Some(new_path.parse().unwrap())
201 };
202
203 parts.path_and_query = new_path_and_query;
204 if let Ok(new_uri) = Uri::from_parts(parts) {
205 *uri = new_uri;
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use tower::{ServiceBuilder, ServiceExt};

    /// Inner service for the middleware tests: echoes the URI it receives as the body.
    async fn echo_uri(request: Request<()>) -> Result<Response<String>, Infallible> {
        Ok(Response::new(request.uri().to_string()))
    }

    /// Sends a request for `uri` through `layer` wrapped around [`echo_uri`] and returns
    /// the URI the inner service observed.
    async fn uri_seen_by_inner(layer: NormalizePathLayer, uri: &str) -> String {
        let mut svc = ServiceBuilder::new().layer(layer).service_fn(echo_uri);
        let request = Request::builder().uri(uri).body(()).unwrap();
        svc.ready()
            .await
            .unwrap()
            .call(request)
            .await
            .unwrap()
            .into_body()
    }

    /// Parses `input`, runs `normalize` on it and asserts the result equals `expected`.
    fn assert_normalized(normalize: fn(&mut Uri), input: &str, expected: &str) {
        let mut uri = input.parse::<Uri>().unwrap();
        normalize(&mut uri);
        assert_eq!(uri, expected);
    }

    // ----- trim mode -----

    #[tokio::test]
    async fn trim_works() {
        let seen = uri_seen_by_inner(NormalizePathLayer::trim_trailing_slash(), "/foo/").await;
        assert_eq!(seen, "/foo");
    }

    #[test]
    fn is_noop_if_no_trailing_slash() {
        assert_normalized(trim_trailing_slash, "/foo", "/foo");
    }

    #[test]
    fn maintains_query() {
        assert_normalized(trim_trailing_slash, "/foo/?a=a", "/foo?a=a");
    }

    #[test]
    fn removes_multiple_trailing_slashes() {
        assert_normalized(trim_trailing_slash, "/foo////", "/foo");
    }

    #[test]
    fn removes_multiple_trailing_slashes_even_with_query() {
        assert_normalized(trim_trailing_slash, "/foo////?a=a", "/foo?a=a");
    }

    #[test]
    fn is_noop_on_index() {
        assert_normalized(trim_trailing_slash, "/", "/");
    }

    #[test]
    fn removes_multiple_trailing_slashes_on_index() {
        assert_normalized(trim_trailing_slash, "////", "/");
    }

    #[test]
    fn removes_multiple_trailing_slashes_on_index_even_with_query() {
        assert_normalized(trim_trailing_slash, "////?a=a", "/?a=a");
    }

    #[test]
    fn removes_multiple_preceding_slashes_even_with_query() {
        assert_normalized(trim_trailing_slash, "///foo//?a=a", "/foo?a=a");
    }

    #[test]
    fn removes_multiple_preceding_slashes() {
        assert_normalized(trim_trailing_slash, "///foo", "/foo");
    }

    // ----- append mode -----

    #[tokio::test]
    async fn append_works() {
        let seen = uri_seen_by_inner(NormalizePathLayer::append_trailing_slash(), "/foo").await;
        assert_eq!(seen, "/foo/");
    }

    #[test]
    fn is_noop_if_trailing_slash() {
        assert_normalized(append_trailing_slash, "/foo/", "/foo/");
    }

    #[test]
    fn append_maintains_query() {
        assert_normalized(append_trailing_slash, "/foo?a=a", "/foo/?a=a");
    }

    #[test]
    fn append_only_keeps_one_slash() {
        assert_normalized(append_trailing_slash, "/foo////", "/foo/");
    }

    #[test]
    fn append_only_keeps_one_slash_even_with_query() {
        assert_normalized(append_trailing_slash, "/foo////?a=a", "/foo/?a=a");
    }

    #[test]
    fn append_is_noop_on_index() {
        assert_normalized(append_trailing_slash, "/", "/");
    }

    #[test]
    fn append_removes_multiple_trailing_slashes_on_index() {
        assert_normalized(append_trailing_slash, "////", "/");
    }

    #[test]
    fn append_removes_multiple_trailing_slashes_on_index_even_with_query() {
        assert_normalized(append_trailing_slash, "////?a=a", "/?a=a");
    }

    #[test]
    fn append_removes_multiple_preceding_slashes_even_with_query() {
        assert_normalized(append_trailing_slash, "///foo//?a=a", "/foo/?a=a");
    }

    #[test]
    fn append_removes_multiple_preceding_slashes() {
        assert_normalized(append_trailing_slash, "///foo", "/foo/");
    }
}