1use http::Uri;
2
3use super::ConfigError;
4
/// Extension methods on URI types.
pub(crate) trait UriExt: Sized {
    /// Parses `input` as an origin (`scheme://host[:port]`, nothing else) and
    /// returns it as `Self`.
    ///
    /// # Errors
    ///
    /// Returns a [`ConfigError`] describing why `input` is not an acceptable
    /// origin; see the implementation for the exact variants produced.
    fn parse_origin(input: &str) -> Result<Self, ConfigError>;
}
24
25impl UriExt for Uri {
26 fn parse_origin(input: &str) -> Result<Self, ConfigError> {
27 if input.contains('#') {
28 return Err(ConfigError::InvalidOriginUrlComponents {
29 origin: input.to_owned(),
30 });
31 }
32
33 if !input.is_ascii() {
35 return Err(ConfigError::NonAsciiHostname {
36 origin: input.to_owned(),
37 });
38 }
39
40 let uri: Uri =
41 input
42 .parse()
43 .map_err(|e: http::uri::InvalidUri| ConfigError::InvalidOriginUrl {
44 origin: input.to_owned(),
45 message: e.to_string(),
46 })?;
47
48 if !matches!(uri.scheme_str(), Some("http" | "https"))
49 || uri.host().map_or(true, |h| h.is_empty())
50 {
51 return Err(ConfigError::OpaqueOrigin {
52 origin: input.to_owned(),
53 });
54 }
55
56 let after_scheme = input.split_once("://").map_or("", |(_, rest)| rest);
61
62 if after_scheme.contains('/') || uri.query().is_some() {
63 return Err(ConfigError::InvalidOriginUrlComponents {
64 origin: input.to_owned(),
65 });
66 }
67
68 Ok(uri)
69 }
70}
71
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain `http`/`https` origins, with or without an explicit port and in
    /// any letter case, must parse successfully.
    #[test]
    fn test_parse_origin_accepts() {
        let inputs = [
            "https://example.com",
            "http://example.com",
            "https://example.com:8443",
            "HTTPS://Example.COM",
        ];

        for input in inputs {
            let result = Uri::parse_origin(input);
            assert!(
                result.is_ok(),
                "expected Ok for {input:?}, got {:?}",
                result
            );
        }
    }

    /// Every input in a group must be rejected with the error variant that the
    /// group's predicate recognises.
    #[test]
    fn test_parse_origin_rejects() {
        type Check = fn(&ConfigError) -> bool;

        let groups: &[(&[&str], Check)] = &[
            // Rejected by `http::Uri` parsing itself.
            (&["not a valid url", "https://", "file:///"], |e| {
                matches!(e, ConfigError::InvalidOriginUrl { .. })
            }),
            // Parses, but lacks a scheme or uses one other than http/https.
            (
                &[
                    "example.com",
                    "file://host/path",
                    "mailto:x@y.z",
                    "javascript:alert(1)",
                ],
                |e| matches!(e, ConfigError::OpaqueOrigin { .. }),
            ),
            // Carries a path, query or fragment, which an origin must not have.
            (
                &[
                    "https://example.com/",
                    "https://example.com/path",
                    "https://example.com/path?query=value",
                    "https://example.com/path#fragment",
                    "https://example.com#fragment",
                    "https://example.com/#fragment",
                ],
                |e| matches!(e, ConfigError::InvalidOriginUrlComponents { .. }),
            ),
            // Non-ASCII hostnames are rejected outright; no IDNA conversion is done.
            (&["https://ümlaut.de", "https://日本.jp"], |e| {
                matches!(e, ConfigError::NonAsciiHostname { .. })
            }),
        ];

        for &(inputs, is_expected) in groups {
            for &input in inputs {
                match Uri::parse_origin(input) {
                    Err(err) if is_expected(&err) => {}
                    other => panic!("unexpected result for {:?}: {:?}", input, other),
                }
            }
        }
    }
}