tower_sanitize_path/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::{
4    borrow::Cow,
5    path::{Component, PathBuf},
6    str::FromStr,
7    task::{Context, Poll},
8};
9
10use http::{Request, Response, Uri};
11use tower_layer::Layer;
12use tower_service::Service;
13use url_escape::decode;
14
/// [`Layer`] that wraps a service in [`SanitizePath`], which strips path
/// traversal components from request URLs.
///
/// The layer carries no state, so it is cheap to copy and can be created with
/// [`Default`].
///
/// See the [module docs](self) for more details.
#[derive(Clone, Copy, Debug, Default)]
pub struct SanitizePathLayer;
19
20impl<S> Layer<S> for SanitizePathLayer {
21    type Service = SanitizePath<S>;
22
23    fn layer(&self, inner: S) -> Self::Service {
24        SanitizePath::sanitize_paths(inner)
25    }
26}
27
/// Middleware that removes filesystem path traversal attempts from the URL
/// path of each request before passing it to the wrapped service.
///
/// See the [module docs](self) for more details.
#[derive(Debug, Clone, Copy)]
pub struct SanitizePath<S> {
    /// The wrapped service that receives the sanitized request.
    inner: S,
}
35
36impl<S> SanitizePath<S> {
37    /// Sanitize all paths for the given service.
38    ///
39    /// This will make all paths on the URL safe for the service to consume.
40    pub fn sanitize_paths(inner: S) -> Self {
41        Self { inner }
42    }
43
44    /// Access the wrapped service.
45    pub fn inner(&self) -> &S {
46        &self.inner
47    }
48}
49
50impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for SanitizePath<S>
51where
52    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
53{
54    type Response = S::Response;
55    type Error = S::Error;
56    type Future = S::Future;
57
58    #[inline]
59    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
60        self.inner.poll_ready(cx)
61    }
62
63    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
64        sanitize_path(req.uri_mut());
65
66        self.inner.call(req)
67    }
68}
69
70fn sanitize_path(uri: &mut Uri) {
71    let path = uri.path();
72    let path_decoded = decode(path);
73
74    // Check if the path contains a trailing slash and that it is not the only
75    // character.
76    let trailing_slash = path_decoded.len() > 1
77        && path_decoded
78            .chars()
79            .last()
80            .filter(|char| *char == '/')
81            .is_some();
82
83    let path_buf = PathBuf::from_str(&path_decoded).expect("infallible");
84
85    let mut new_path = path_buf
86        .components()
87        .filter(|c| matches!(c, Component::RootDir | Component::Normal(_)))
88        .collect::<PathBuf>()
89        .display()
90        .to_string();
91
92    // Path::components above will normalize away the trailing slash if there is one,
93    // so we add it back.
94    if trailing_slash {
95        new_path += "/";
96    }
97
98    if path == new_path {
99        return;
100    }
101
102    let mut parts = uri.clone().into_parts();
103
104    let new_path_and_query = if let Some(path_and_query) = parts.path_and_query {
105        let new_path_and_query = if let Some(query) = path_and_query.query() {
106            Cow::Owned(format!("{new_path}?{query}"))
107        } else {
108            new_path.into()
109        }
110        .parse()
111        .expect("url to still be valid");
112
113        Some(new_path_and_query)
114    } else {
115        None
116    };
117
118    parts.path_and_query = new_path_and_query;
119    if let Ok(new_uri) = Uri::from_parts(parts) {
120        *uri = new_uri;
121    }
122}
123
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use tower::{ServiceBuilder, ServiceExt};

    use super::*;

    /// Parses `raw` as a URI and returns it after running `sanitize_path`.
    fn sanitized(raw: &str) -> Uri {
        let mut uri: Uri = raw.parse().unwrap();
        sanitize_path(&mut uri);
        uri
    }

    /// The layer must sanitize the URI before the inner service sees it.
    #[tokio::test]
    async fn layer() {
        // Echoes the URI the handler received back as the response body.
        async fn handle(request: Request<()>) -> Result<Response<String>, Infallible> {
            Ok(Response::new(request.uri().to_string()))
        }

        let mut svc = ServiceBuilder::new()
            .layer(SanitizePathLayer)
            .service_fn(handle);

        let request = Request::builder().uri("/../../secret").body(()).unwrap();
        let ready = svc.ready().await.unwrap();
        let response = ready.call(request).await.unwrap();

        assert_eq!(response.into_body(), "/secret");
    }

    #[test]
    fn no_path() {
        assert_eq!(sanitized("/"), "/");
    }

    #[test]
    fn maintain_query() {
        assert_eq!(sanitized("/?test"), "/?test");
    }

    #[test]
    fn path_maintain_query() {
        assert_eq!(sanitized("/path?test=true"), "/path?test=true");
    }

    #[test]
    fn remove_path_parent_traversal() {
        assert_eq!(sanitized("/../../path"), "/path");
    }

    #[test]
    fn remove_path_parent_traversal_maintain_query() {
        assert_eq!(sanitized("/../../path?name=John"), "/path?name=John");
    }

    #[test]
    fn remove_path_current_traversal() {
        assert_eq!(sanitized("/.././path"), "/path");
    }

    #[test]
    fn remove_path_encoded_traversal() {
        assert_eq!(sanitized("/..%2f..%2fpath"), "/path");
    }

    #[test]
    fn keep_trailing_slash() {
        assert_eq!(sanitized("/path/"), "/path/");
    }

    #[test]
    fn keep_only_one_trailing_slash() {
        assert_eq!(sanitized("/path//"), "/path/");
    }
}