// rustauth_core/auth/trusted_origins.rs
/// Options that relax [`matches_origin_pattern`].
///
/// The [`Default`] value disables every relaxation, which behaves exactly like passing
/// `None` as the settings argument.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct OriginMatchSettings {
    /// Accept relative redirect targets (values starting with `/`), provided they also pass
    /// the safe-relative-path check. Disabled by default.
    pub allow_relative_paths: bool,
}

8pub fn matches_origin_pattern(
10 url: &str,
11 pattern: &str,
12 settings: Option<OriginMatchSettings>,
13) -> bool {
14 if url.starts_with('/') {
15 return settings
16 .is_some_and(|settings| settings.allow_relative_paths && is_safe_relative_path(url));
17 }
18
19 let has_wildcard = pattern.contains('*') || pattern.contains('?');
20 if has_wildcard {
21 if pattern.contains("://") {
22 return wildcard_match(pattern, url)
23 || origin_from_url(url).is_some_and(|origin| wildcard_match(pattern, &origin));
24 }
25 return host_from_url(url).is_some_and(|host| wildcard_match(pattern, &host));
26 }
27
28 match protocol_from_url(url).as_deref() {
29 Some("http") | Some("https") | None => {
30 origin_from_url(url).is_some_and(|origin| origin == pattern)
31 }
32 Some(_) => url.starts_with(pattern),
33 }
34}
35
/// Returns `true` when `path` is a relative redirect target that stays on the current origin.
///
/// A path is accepted only if all of the following hold:
/// - it begins with exactly one `/`. `//…` and `/\…` are rejected because browsers resolve
///   them as protocol-relative URLs pointing at another host;
/// - it contains no percent-encoded slash or backslash (`%2f` / `%5c`, any letter case);
/// - every byte is an ASCII letter or digit, or one of `/ ? & = % @ . - _ +`.
///
/// The `javascript:` / `data:` prefix checks cannot fire once the leading-`/` rule holds
/// (and `:` is not an allowed byte); they are kept purely as defence in depth.
fn is_safe_relative_path(path: &str) -> bool {
    let Some(after_slash) = path.strip_prefix('/') else {
        return false;
    };
    if after_slash.starts_with('/') || after_slash.starts_with('\\') {
        return false;
    }

    let folded = path.to_ascii_lowercase();
    let smuggles_separator = ["%2f", "%5c"].iter().any(|&needle| folded.contains(needle));
    let has_script_scheme = ["javascript:", "data:"]
        .iter()
        .any(|&scheme| folded.starts_with(scheme));
    if smuggles_separator || has_script_scheme {
        return false;
    }

    let is_allowed_byte =
        |byte: u8| byte.is_ascii_alphanumeric() || b"/?&=%@.-_+".contains(&byte);
    path.bytes().all(is_allowed_byte)
}

/// Returns the scheme of `url`: everything before the first `://`, copied verbatim (no case
/// folding). Returns `None` when `url` contains no `://`.
fn protocol_from_url(url: &str) -> Option<String> {
    let (scheme, _) = url.split_once("://")?;
    Some(String::from(scheme))
}

/// Extracts the host (including any `:port`) from an absolute `url`.
///
/// The authority is the text after the first `://` up to the first `/`, `?`, `#` or `\`.
/// Userinfo (`user:pass@`) is dropped by keeping only the text after the last `@`.
/// Returns `None` when `url` has no `://` or the resulting host is empty.
fn host_from_url(url: &str) -> Option<String> {
    let (_, rest) = url.split_once("://")?;
    // `#` and `\` must terminate the authority as well: browsers end the host there (the
    // WHATWG URL parser treats `\` like `/` for http(s)). Leaving them in would make
    // `https://evil.com#.example.com` look like a subdomain of `example.com`.
    let authority_end = rest
        .find(|c: char| matches!(c, '/' | '?' | '#' | '\\'))
        .unwrap_or(rest.len());
    let authority = &rest[..authority_end];
    // Drop userinfo: in `https://trusted.com@evil.com` the real host is `evil.com`.
    let host = authority.rsplit_once('@').map_or(authority, |(_, host)| host);
    (!host.is_empty()).then(|| host.to_owned())
}

/// Builds the origin `"{scheme}://{host}"` (host includes any `:port`) of an absolute `url`.
///
/// The scheme is the text before the first `://`, kept verbatim. The authority ends at the
/// first `/`, `?`, `#` or `\`, and any userinfo (`user:pass@`) is dropped by keeping only
/// the text after the last `@`. Returns `None` when `url` has no `://` or the host is empty.
fn origin_from_url(url: &str) -> Option<String> {
    let (protocol, rest) = url.split_once("://")?;
    // Browsers end the host at `#` and `\` too (the WHATWG URL parser treats `\` like `/`
    // for http(s)); without this `https://evil.com#.trusted.com` would yield a fake origin.
    let authority_end = rest
        .find(|c: char| matches!(c, '/' | '?' | '#' | '\\'))
        .unwrap_or(rest.len());
    let authority = &rest[..authority_end];
    // The origin excludes userinfo: `https://trusted.com@evil.com` is served by `evil.com`.
    let host = authority.rsplit_once('@').map_or(authority, |(_, host)| host);
    (!host.is_empty()).then(|| format!("{protocol}://{host}"))
}

76fn wildcard_match(pattern: &str, value: &str) -> bool {
77 wildcard_match_bytes(pattern.as_bytes(), value.as_bytes())
78}
79
80fn wildcard_match_bytes(pattern: &[u8], value: &[u8]) -> bool {
81 let (mut pattern_index, mut value_index) = (0, 0);
82 let mut star_index = None;
83 let mut match_index = 0;
84
85 while value_index < value.len() {
86 if pattern_index < pattern.len()
87 && (pattern[pattern_index] == b'?' || pattern[pattern_index] == value[value_index])
88 {
89 pattern_index += 1;
90 value_index += 1;
91 } else if pattern_index < pattern.len() && pattern[pattern_index] == b'*' {
92 star_index = Some(pattern_index);
93 match_index = value_index;
94 pattern_index += 1;
95 } else if let Some(star) = star_index {
96 pattern_index = star + 1;
97 match_index += 1;
98 value_index = match_index;
99 } else {
100 return false;
101 }
102 }
103
104 while pattern_index < pattern.len() && pattern[pattern_index] == b'*' {
105 pattern_index += 1;
106 }
107
108 pattern_index == pattern.len()
109}