Skip to main content

rustauth_core/auth/
trusted_origins.rs

1//! Trusted origin matching.
2
/// Options that tune how [`matches_origin_pattern`] evaluates a candidate URL.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct OriginMatchSettings {
    /// Whether URLs that begin with `/` may be accepted.
    ///
    /// When `false`, every URL starting with `/` is rejected. When `true`,
    /// such a URL is accepted only if it also passes the relative-path
    /// safety check.
    pub allow_relative_paths: bool,
}
7
8/// Match a URL against an origin or origin pattern.
9pub fn matches_origin_pattern(
10    url: &str,
11    pattern: &str,
12    settings: Option<OriginMatchSettings>,
13) -> bool {
14    if url.starts_with('/') {
15        return settings
16            .is_some_and(|settings| settings.allow_relative_paths && is_safe_relative_path(url));
17    }
18
19    let has_wildcard = pattern.contains('*') || pattern.contains('?');
20    if has_wildcard {
21        if pattern.contains("://") {
22            return wildcard_match(pattern, url)
23                || origin_from_url(url).is_some_and(|origin| wildcard_match(pattern, &origin));
24        }
25        return host_from_url(url).is_some_and(|host| wildcard_match(pattern, &host));
26    }
27
28    match protocol_from_url(url).as_deref() {
29        Some("http") | Some("https") | None => {
30            origin_from_url(url).is_some_and(|origin| origin == pattern)
31        }
32        Some(_) => url.starts_with(pattern),
33    }
34}
35
/// Report whether `path` is a same-site relative path that is safe to accept.
///
/// The path must start with a single `/` (not `//` or `/\`), must not contain
/// a percent-encoded slash or backslash (`%2f` / `%5c`, any case), and may
/// consist only of ASCII letters, digits and the bytes `/ ? & = % @ . - _ +`.
fn is_safe_relative_path(path: &str) -> bool {
    const BLOCKED_PREFIXES: [&str; 2] = ["//", "/\\"];
    const BLOCKED_ENCODINGS: [&str; 2] = ["%2f", "%5c"];
    const BLOCKED_SCHEMES: [&str; 2] = ["javascript:", "data:"];

    let is_rooted = path.starts_with('/');
    if !is_rooted || BLOCKED_PREFIXES.iter().any(|prefix| path.starts_with(*prefix)) {
        return false;
    }

    // Compare case-insensitively so `%2F` and `%2f` are treated alike.
    let folded = path.to_ascii_lowercase();
    let has_encoded_separator = BLOCKED_ENCODINGS.iter().any(|seq| folded.contains(*seq));
    let has_scheme_prefix = BLOCKED_SCHEMES
        .iter()
        .any(|scheme| folded.starts_with(*scheme));
    if has_encoded_separator || has_scheme_prefix {
        return false;
    }

    // Allow-list: anything outside this set (spaces, `:`, `#`, `\`, non-ASCII,
    // control bytes, ...) makes the path unsafe.
    path.bytes().all(|byte| match byte {
        b'/' | b'?' | b'&' | b'=' | b'%' | b'@' | b'.' | b'-' | b'_' | b'+' => true,
        other => other.is_ascii_alphanumeric(),
    })
}
56
/// Return the text preceding the first `://` in `url`, or `None` when the
/// separator is absent. The scheme is returned verbatim (no case folding).
fn protocol_from_url(url: &str) -> Option<String> {
    let (scheme, _remainder) = url.split_once("://")?;
    Some(scheme.to_owned())
}
61
/// Extract the host portion of `url`: the text between `://` and the first
/// `/`, `?`, `#` or `\`.
///
/// Returns `None` when `url` has no `://` separator or the host is empty.
/// A port or userinfo component, if present, is left in the returned string.
fn host_from_url(url: &str) -> Option<String> {
    // `#` starts the fragment, and browsers treat `\` like `/` in http(s)
    // URLs. Stopping only at `/` and `?` would let
    // `https://evil.com#.example.com` report a host ending in `.example.com`.
    const AUTHORITY_END: [char; 4] = ['/', '?', '#', '\\'];

    let (_, rest) = url.split_once("://")?;
    let host = rest.split(AUTHORITY_END).next().unwrap_or(rest);
    (!host.is_empty()).then(|| host.to_owned())
}
68
/// Build the `scheme://host` origin of `url`, where the host is the text
/// between `://` and the first `/`, `?`, `#` or `\`.
///
/// Returns `None` when `url` has no `://` separator or the host is empty.
/// The scheme and host are copied verbatim; a port or userinfo component, if
/// present, stays part of the host.
fn origin_from_url(url: &str) -> Option<String> {
    // `#` starts the fragment, and browsers treat `\` like `/` in http(s)
    // URLs. Without these delimiters `https://evil.com#.example.com` would
    // yield an origin that ends in `.example.com`.
    const AUTHORITY_END: [char; 4] = ['/', '?', '#', '\\'];

    let (protocol, rest) = url.split_once("://")?;
    let host = rest.split(AUTHORITY_END).next().unwrap_or(rest);
    (!host.is_empty()).then(|| format!("{protocol}://{host}"))
}
75
76fn wildcard_match(pattern: &str, value: &str) -> bool {
77    wildcard_match_bytes(pattern.as_bytes(), value.as_bytes())
78}
79
/// Glob-match `value` against `pattern`, byte by byte.
///
/// In `pattern`, `*` matches any run of bytes (including none) and `?`
/// matches exactly one byte; every other byte must match literally. The whole
/// of `value` has to be consumed for the match to succeed.
///
/// Uses single-star backtracking: only the most recent `*` is remembered, and
/// on a mismatch that star's span is extended by one byte before retrying.
fn wildcard_match_bytes(pattern: &[u8], value: &[u8]) -> bool {
    let (mut pattern_index, mut value_index) = (0, 0);
    // Position of the most recent `*` in `pattern`, paired with the value
    // index at which that star's span currently ends.
    let mut backtrack: Option<(usize, usize)> = None;

    while value_index < value.len() {
        match pattern.get(pattern_index) {
            // The wildcard must be recognised before the literal comparison:
            // otherwise a `*` in the pattern facing a `*` in the value is
            // consumed as a literal and `"*a"` fails to match `"*ba"`.
            Some(b'*') => {
                backtrack = Some((pattern_index, value_index));
                pattern_index += 1;
            }
            Some(&expected) if expected == b'?' || expected == value[value_index] => {
                pattern_index += 1;
                value_index += 1;
            }
            _ => match backtrack {
                // Let the last star absorb one more byte and retry from the
                // pattern byte right after it.
                Some((star, resume)) => {
                    pattern_index = star + 1;
                    value_index = resume + 1;
                    backtrack = Some((star, resume + 1));
                }
                None => return false,
            },
        }
    }

    // With the value exhausted, only trailing stars may remain in the pattern.
    pattern[pattern_index..].iter().all(|&byte| byte == b'*')
}