tower_async_http/
normalize_path.rs1use http::{Request, Response, Uri};
41use std::borrow::Cow;
42use tower_async_layer::Layer;
43use tower_async_service::Service;
44
/// [`Layer`] that wraps services in [`NormalizePath`] middleware.
///
/// See [`NormalizePath`] for the exact normalization applied to request paths.
#[derive(Clone, Copy, Debug)]
pub struct NormalizePathLayer {}

impl NormalizePathLayer {
    /// Build a layer whose services strip trailing slashes from request paths.
    ///
    /// For example, a request for `/foo/` reaches the inner service as `/foo`.
    pub fn trim_trailing_slash() -> Self {
        Self {}
    }
}
60
61impl<S> Layer<S> for NormalizePathLayer {
62 type Service = NormalizePath<S>;
63
64 fn layer(&self, inner: S) -> Self::Service {
65 NormalizePath::trim_trailing_slash(inner)
66 }
67}
68
69#[derive(Debug, Copy, Clone)]
73pub struct NormalizePath<S> {
74 inner: S,
75}
76
77impl<S> NormalizePath<S> {
78 pub fn trim_trailing_slash(inner: S) -> Self {
83 Self { inner }
84 }
85
86 define_inner_service_accessors!();
87}
88
89impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for NormalizePath<S>
90where
91 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
92{
93 type Response = S::Response;
94 type Error = S::Error;
95
96 async fn call(&self, mut req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
97 normalize_trailing_slash(req.uri_mut());
98 self.inner.call(req).await
99 }
100}
101
102fn normalize_trailing_slash(uri: &mut Uri) {
103 if !uri.path().ends_with('/') && !uri.path().starts_with("//") {
104 return;
105 }
106
107 let new_path = format!("/{}", uri.path().trim_matches('/'));
108
109 let mut parts = uri.clone().into_parts();
110
111 let new_path_and_query = if let Some(path_and_query) = &parts.path_and_query {
112 let new_path = if new_path.is_empty() {
113 "/"
114 } else {
115 new_path.as_str()
116 };
117
118 let new_path_and_query = if let Some(query) = path_and_query.query() {
119 Cow::Owned(format!("{}?{}", new_path, query))
120 } else {
121 new_path.into()
122 }
123 .parse()
124 .unwrap();
125
126 Some(new_path_and_query)
127 } else {
128 None
129 };
130
131 parts.path_and_query = new_path_and_query;
132 if let Ok(new_uri) = Uri::from_parts(parts) {
133 *uri = new_uri;
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use tower_async::ServiceBuilder;

    #[tokio::test]
    async fn works() {
        async fn handle(request: Request<()>) -> Result<Response<String>, Infallible> {
            Ok(Response::new(request.uri().to_string()))
        }

        let svc = ServiceBuilder::new()
            .layer(NormalizePathLayer::trim_trailing_slash())
            .service_fn(handle);

        let body = svc
            .call(Request::builder().uri("/foo/").body(()).unwrap())
            .await
            .unwrap()
            .into_body();

        assert_eq!(body, "/foo");
    }

    #[test]
    fn is_noop_if_no_trailing_slash() {
        let mut uri = "/foo".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/foo");
    }

    #[test]
    fn maintains_query() {
        let mut uri = "/foo/?a=a".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/foo?a=a");
    }

    #[test]
    fn removes_multiple_trailing_slashes() {
        let mut uri = "/foo////".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/foo");
    }

    #[test]
    fn removes_multiple_trailing_slashes_even_with_query() {
        let mut uri = "/foo////?a=a".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/foo?a=a");
    }

    #[test]
    fn is_noop_on_index() {
        let mut uri = "/".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/");
    }

    #[test]
    fn removes_multiple_trailing_slashes_on_index() {
        let mut uri = "////".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/");
    }

    #[test]
    fn removes_multiple_trailing_slashes_on_index_even_with_query() {
        let mut uri = "////?a=a".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/?a=a");
    }

    #[test]
    fn removes_multiple_preceding_slashes_even_with_query() {
        let mut uri = "///foo//?a=a".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/foo?a=a");
    }

    #[test]
    fn removes_multiple_preceding_slashes() {
        let mut uri = "///foo".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/foo");
    }

    #[test]
    fn removes_single_extra_preceding_slash() {
        let mut uri = "//foo".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/foo");
    }

    #[test]
    fn keeps_interior_double_slashes() {
        // Only leading/trailing slash runs are normalized.
        let mut uri = "/foo//bar/".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/foo//bar");
    }

    #[test]
    fn keeps_scheme_and_authority_of_absolute_uri() {
        let mut uri = "http://example.com/foo/?a=a".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "http://example.com/foo?a=a");
    }
}