Skip to main content

tower_http/csrf/
url.rs

1use http::Uri;
2
3use super::ConfigError;
4
5/// Internal extension methods on [`http::Uri`] used by the CSRF middleware to
6/// validate trusted-origin strings.
7pub(crate) trait UriExt: Sized {
8    /// Parses a trusted-origin string of the form `scheme://host[:port]`.
9    ///
10    /// Rejects inputs that can't represent a browser `Origin`:
11    ///
12    /// - unparseable URIs ([`ConfigError::InvalidOriginUrl`]);
13    /// - non-`http`/`https` schemes or missing host ([`ConfigError::OpaqueOrigin`]);
14    /// - any path, query, or fragment component
15    ///   ([`ConfigError::InvalidOriginUrlComponents`] — including a bare trailing
16    ///   `/` and fragments that `http::Uri` would otherwise silently strip);
17    /// - non-ASCII hostnames ([`ConfigError::NonAsciiHostname`] — IDN hosts
18    ///   must be supplied in punycode, since that's what browsers send).
19    ///
20    /// The returned [`Uri`] is parsed but not normalized; the origin is matched
21    /// against the request's `Origin` header byte-for-byte.
22    fn parse_origin(input: &str) -> Result<Self, ConfigError>;
23}
24
25impl UriExt for Uri {
26    fn parse_origin(input: &str) -> Result<Self, ConfigError> {
27        if input.contains('#') {
28            return Err(ConfigError::InvalidOriginUrlComponents {
29                origin: input.to_owned(),
30            });
31        }
32
33        // browsers will send punycode anyways
34        if !input.is_ascii() {
35            return Err(ConfigError::NonAsciiHostname {
36                origin: input.to_owned(),
37            });
38        }
39
40        let uri: Uri =
41            input
42                .parse()
43                .map_err(|e: http::uri::InvalidUri| ConfigError::InvalidOriginUrl {
44                    origin: input.to_owned(),
45                    message: e.to_string(),
46                })?;
47
48        if !matches!(uri.scheme_str(), Some("http" | "https"))
49            || uri.host().map_or(true, |h| h.is_empty())
50        {
51            return Err(ConfigError::OpaqueOrigin {
52                origin: input.to_owned(),
53            });
54        }
55
56        // Reject any path/query (fragments are rejected above). `http::Uri`
57        // reports `path()` as "/" for both `scheme://host` and `scheme://host/`,
58        // so detect a path from the raw input (everything after "://") to reach
59        // parity with Go, which rejects a non-empty path — including a bare "/".
60        let after_scheme = input.split_once("://").map_or("", |(_, rest)| rest);
61
62        if after_scheme.contains('/') || uri.query().is_some() {
63            return Err(ConfigError::InvalidOriginUrlComponents {
64                origin: input.to_owned(),
65            });
66        }
67
68        Ok(uri)
69    }
70}
71
#[cfg(test)]
mod tests {
    use super::*;

    /// Bare `scheme://host[:port]` origins parse, including upper-case input.
    #[test]
    fn test_parse_origin_accepts() {
        let accepted = [
            "https://example.com",
            "http://example.com",
            "https://example.com:8443",
            "HTTPS://Example.COM",
        ];

        for input in accepted {
            assert!(
                Uri::parse_origin(input).is_ok(),
                "expected Ok for {input:?}, got {:?}",
                Uri::parse_origin(input)
            );
        }
    }

    /// Every rejected input must fail with the specific `ConfigError` variant
    /// listed for its group.
    #[test]
    fn test_parse_origin_rejects() {
        // Only the variant is asserted; the payload fields (offending origin,
        // parser message) are ignored, hence predicates built on `matches!`.
        type Check = fn(&ConfigError) -> bool;

        fn is_invalid_url(e: &ConfigError) -> bool {
            matches!(e, ConfigError::InvalidOriginUrl { .. })
        }
        fn is_opaque(e: &ConfigError) -> bool {
            matches!(e, ConfigError::OpaqueOrigin { .. })
        }
        fn is_components(e: &ConfigError) -> bool {
            matches!(e, ConfigError::InvalidOriginUrlComponents { .. })
        }
        fn is_non_ascii(e: &ConfigError) -> bool {
            matches!(e, ConfigError::NonAsciiHostname { .. })
        }

        // Panics on the first input whose result is not an `Err` accepted by
        // `expected`.
        fn assert_rejected(inputs: &[&str], expected: Check) {
            for input in inputs {
                match Uri::parse_origin(input) {
                    Err(e) if expected(&e) => {}
                    other => panic!("unexpected result for {:?}: {:?}", input, other),
                }
            }
        }

        // http::Uri rejects these outright at parse time.
        assert_rejected(&["not a valid url", "https://", "file:///"], is_invalid_url);

        // Parse OK but scheme is not http/https (or absent).
        assert_rejected(
            &[
                "example.com",
                "file://host/path",
                "mailto:x@y.z",
                "javascript:alert(1)",
            ],
            is_opaque,
        );

        // Path/query/fragment not allowed on a trusted origin. A bare
        // trailing slash is a (non-empty) path too — rejected, matching Go.
        assert_rejected(
            &[
                "https://example.com/",
                "https://example.com/path",
                "https://example.com/path?query=value",
                "https://example.com/path#fragment",
            ],
            is_components,
        );

        // http::Uri silently strips fragments; the `contains('#')` pre-check
        // surfaces these as component errors instead of letting them slip in.
        assert_rejected(
            &[
                "https://example.com#fragment",
                "https://example.com/#fragment",
            ],
            is_components,
        );

        // IDN hosts must be supplied in punycode.
        assert_rejected(&["https://ümlaut.de", "https://日本.jp"], is_non_ascii);
    }
}