tower_sanitize_path/
lib.rs1#![doc = include_str!("../README.md")]
2
3use std::{
4 borrow::Cow,
5 path::{Component, PathBuf},
6 str::FromStr,
7 task::{Context, Poll},
8};
9
10use http::{Request, Response, Uri};
11use tower_layer::Layer;
12use tower_service::Service;
13use url_escape::decode;
14
/// [`Layer`] that wraps a service in [`SanitizePath`], so every request's URI
/// path is sanitized before the wrapped service sees it.
///
/// The layer carries no configuration, so it is freely copyable and can be
/// reused to wrap any number of services.
#[derive(Clone, Copy, Debug, Default)]
pub struct SanitizePathLayer;
19
20impl<S> Layer<S> for SanitizePathLayer {
21 type Service = SanitizePath<S>;
22
23 fn layer(&self, inner: S) -> Self::Service {
24 SanitizePath::sanitize_paths(inner)
25 }
26}
27
/// Middleware that percent-decodes the request URI path, drops its `.` and
/// `..` segments, and then forwards the request to the wrapped service.
///
/// Construct it with [`SanitizePath::sanitize_paths`] or through
/// [`SanitizePathLayer`].
#[derive(Debug, Clone, Copy)]
pub struct SanitizePath<S> {
    /// The service that receives the request after its path is sanitized.
    inner: S,
}
35
36impl<S> SanitizePath<S> {
37 pub fn sanitize_paths(inner: S) -> Self {
41 Self { inner }
42 }
43
44 pub fn inner(&self) -> &S {
46 &self.inner
47 }
48}
49
50impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for SanitizePath<S>
51where
52 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
53{
54 type Response = S::Response;
55 type Error = S::Error;
56 type Future = S::Future;
57
58 #[inline]
59 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
60 self.inner.poll_ready(cx)
61 }
62
63 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
64 sanitize_path(req.uri_mut());
65
66 self.inner.call(req)
67 }
68}
69
70fn sanitize_path(uri: &mut Uri) {
71 let path = uri.path();
72 let path_decoded = decode(path);
73
74 let trailing_slash = path_decoded.len() > 1
77 && path_decoded
78 .chars()
79 .last()
80 .filter(|char| *char == '/')
81 .is_some();
82
83 let path_buf = PathBuf::from_str(&path_decoded).expect("infallible");
84
85 let mut new_path = path_buf
86 .components()
87 .filter(|c| matches!(c, Component::RootDir | Component::Normal(_)))
88 .collect::<PathBuf>()
89 .display()
90 .to_string();
91
92 if trailing_slash {
95 new_path += "/";
96 }
97
98 if path == new_path {
99 return;
100 }
101
102 let mut parts = uri.clone().into_parts();
103
104 let new_path_and_query = if let Some(path_and_query) = parts.path_and_query {
105 let new_path_and_query = if let Some(query) = path_and_query.query() {
106 Cow::Owned(format!("{new_path}?{query}"))
107 } else {
108 new_path.into()
109 }
110 .parse()
111 .expect("url to still be valid");
112
113 Some(new_path_and_query)
114 } else {
115 None
116 };
117
118 parts.path_and_query = new_path_and_query;
119 if let Ok(new_uri) = Uri::from_parts(parts) {
120 *uri = new_uri;
121 }
122}
123
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use tower::{ServiceBuilder, ServiceExt};

    use super::*;

    /// Parses `input` as a URI, sanitizes it, and returns the result.
    fn sanitized(input: &str) -> Uri {
        let mut uri: Uri = input.parse().unwrap();
        sanitize_path(&mut uri);
        uri
    }

    #[tokio::test]
    async fn layer() {
        async fn echo_uri(request: Request<()>) -> Result<Response<String>, Infallible> {
            Ok(Response::new(request.uri().to_string()))
        }

        let mut service = ServiceBuilder::new()
            .layer(SanitizePathLayer)
            .service_fn(echo_uri);

        let request = Request::builder().uri("/../../secret").body(()).unwrap();
        let ready = service.ready().await.unwrap();
        let response = ready.call(request).await.unwrap();

        assert_eq!(response.into_body(), "/secret");
    }

    #[test]
    fn no_path() {
        assert_eq!(sanitized("/"), "/");
    }

    #[test]
    fn maintain_query() {
        assert_eq!(sanitized("/?test"), "/?test");
    }

    #[test]
    fn path_maintain_query() {
        assert_eq!(sanitized("/path?test=true"), "/path?test=true");
    }

    #[test]
    fn remove_path_parent_traversal() {
        assert_eq!(sanitized("/../../path"), "/path");
    }

    #[test]
    fn remove_path_parent_traversal_maintain_query() {
        assert_eq!(sanitized("/../../path?name=John"), "/path?name=John");
    }

    #[test]
    fn remove_path_current_traversal() {
        assert_eq!(sanitized("/.././path"), "/path");
    }

    #[test]
    fn remove_path_encoded_traversal() {
        assert_eq!(sanitized("/..%2f..%2fpath"), "/path");
    }

    #[test]
    fn keep_trailing_slash() {
        assert_eq!(sanitized("/path/"), "/path/");
    }

    #[test]
    fn keep_only_one_trailing_slash() {
        assert_eq!(sanitized("/path//"), "/path/");
    }
}