tower_web/extract/
pathbuf.rs1use extract::{Context, Error, Extract, Immediate};
2use std::ffi::{OsStr, OsString};
3use std::path::{self, Path, PathBuf};
4use util::buf_stream::BufStream;
5
6fn check_for_path_traversal(path: &Path) -> Result<(), Error> {
8 use self::path::Component::*;
9
10 let path_traversal_error = || Error::invalid_argument(&"Path traversal detected");
11
12 let mut depth = 0u32;
13 for c in path.components() {
14 match c {
15 Prefix(_) | RootDir => {
16 Err(path_traversal_error())?
18 }
19 CurDir => {
20 }
22 ParentDir => {
23 depth = match depth.checked_sub(1) {
24 Some(v) => v,
25 None => Err(path_traversal_error())?,
26 }
27 }
28 Normal(_) => {
29 depth += 1;
30 }
31 }
32 }
33
34 Ok(())
35}
36
37fn decode(s: &OsStr) -> Result<PathBuf, Error> {
38 let path = PathBuf::from(s);
39 check_for_path_traversal(&path)?;
40 Ok(path)
41}
42
43impl<B: BufStream> Extract<B> for PathBuf {
44 type Future = Immediate<Self>;
45
46 fn extract(ctx: &Context) -> Self::Future {
47 use extract::ExtractFuture;
48
49 let s = <OsString as Extract<B>>::extract(ctx).extract();
50 Immediate::result(decode(&s))
51 }
52}
53
#[cfg(test)]
mod test {
    use super::*;
    use std::path::Path;

    #[test]
    fn extract() {
        assert_eq!(
            decode(OsStr::new("hello, world")).unwrap(),
            Path::new("hello, world")
        );
    }

    #[test]
    fn disallows_path_traversal() {
        assert!(decode(OsStr::new("/")).unwrap_err().is_invalid_argument());
        assert!(decode(OsStr::new("..")).unwrap_err().is_invalid_argument());
        assert_eq!(decode(OsStr::new("a/..")).unwrap(), Path::new("a/.."));
        assert!(
            decode(OsStr::new("../a"))
                .unwrap_err()
                .is_invalid_argument()
        );
        assert!(
            decode(OsStr::new("../a/b"))
                .unwrap_err()
                .is_invalid_argument()
        );
        assert_eq!(decode(OsStr::new("a/../b")).unwrap(), Path::new("a/../b"));
        assert_eq!(decode(OsStr::new("a/b/..")).unwrap(), Path::new("a/b/.."));
    }

    #[test]
    fn disallows_absolute_paths_with_a_tail() {
        assert!(decode(OsStr::new("/a")).unwrap_err().is_invalid_argument());
    }

    #[test]
    fn disallows_escaping_after_descending() {
        // Descending first must not buy extra `..` budget beyond the start.
        assert!(
            decode(OsStr::new("a/../.."))
                .unwrap_err()
                .is_invalid_argument()
        );
        assert!(
            decode(OsStr::new("a/b/../../.."))
                .unwrap_err()
                .is_invalid_argument()
        );
        // A leading `.` does not count as a level to climb out of.
        assert!(
            decode(OsStr::new("./.."))
                .unwrap_err()
                .is_invalid_argument()
        );
    }

    #[test]
    fn allows_current_dir_and_balanced_parents() {
        assert_eq!(decode(OsStr::new("")).unwrap(), Path::new(""));
        assert_eq!(decode(OsStr::new("./a")).unwrap(), Path::new("./a"));
        assert_eq!(decode(OsStr::new("a/./b")).unwrap(), Path::new("a/./b"));
        assert_eq!(
            decode(OsStr::new("a/b/../../c")).unwrap(),
            Path::new("a/b/../../c")
        );
    }
}