tower_web/extract/
osstring.rs

1use extract::{Context, Error, Extract, Immediate};
2use percent_encoding;
3use std::borrow::Cow;
4use std::ffi::{OsStr, OsString};
5use std::str;
6use util::buf_stream::BufStream;
7
8fn osstr_from_bytes(bytes: &[u8]) -> Result<&OsStr, Error> {
9    // NOTE: this is too conservative, as we are rejecting valid paths on Unix
10    str::from_utf8(bytes)
11        .map_err(|e| Error::invalid_argument(&e))
12        .map(|s| OsStr::new(s))
13}
14
15fn decode(s: &str) -> Result<OsString, Error> {
16    let percent_decoded = Cow::from(percent_encoding::percent_decode(s.as_bytes()));
17    Ok(osstr_from_bytes(percent_decoded.as_ref())?.to_os_string())
18}
19
20impl<B: BufStream> Extract<B> for OsString {
21    type Future = Immediate<Self>;
22
23    fn extract(ctx: &Context) -> Self::Future {
24        use codegen::Source::*;
25
26        match ctx.callsite().source() {
27            Capture(idx) => {
28                let path = ctx.request().uri().path();
29                let value = ctx.captures().get(*idx, path);
30
31                Immediate::result(decode(value))
32            }
33            Header(header_name) => {
34                let value = match ctx.request().headers().get(header_name) {
35                    Some(value) => value,
36                    None => {
37                        return Immediate::err(Error::missing_argument());
38                    }
39                };
40
41                let r = value
42                    .to_str()
43                    .map(OsString::from)
44                    .map_err(|e| Error::invalid_argument(&e));
45                Immediate::result(r)
46            }
47            QueryString => {
48                let query = ctx.request().uri().query().unwrap_or("");
49
50                Immediate::result(decode(query))
51            }
52            Body => {
53                unimplemented!();
54            }
55            Unknown => {
56                unimplemented!();
57            }
58        }
59    }
60}
61
#[cfg(test)]
mod test {
    use super::*;
    use std::path::Path;

    /// A percent-encoded space is decoded and the result compares equal to
    /// the corresponding `Path`.
    #[test]
    fn extract() {
        assert_eq!(Path::new("hello, world"), decode("hello,%20world").unwrap());
    }

    /// Plain input passes through unchanged and percent-escapes are decoded,
    /// including multi-byte UTF-8 sequences.
    #[test]
    fn decodes_plain_and_percent_encoded() {
        assert_eq!(decode("foo").unwrap(), OsString::from("foo"));
        assert_eq!(decode("foo%20bar").unwrap(), OsString::from("foo bar"));
        // 0xC3 0xA9 is the UTF-8 encoding of U+00E9.
        assert_eq!(decode("%C3%A9").unwrap(), OsString::from("\u{e9}"));
    }

    /// Bytes that do not form valid UTF-8 after percent-decoding are rejected.
    #[test]
    fn rejects_invalid_utf8() {
        assert!(decode("%FF").is_err());
        assert!(decode("foo%C3").is_err());
    }
}