Skip to main content

rustauth_core/utils/
url.rs

1/// Normalize a request URL pathname by removing the auth base path and trailing slashes.
2pub fn normalize_pathname(request_url: &str, base_path: &str) -> String {
3    let Some(pathname) = pathname_from_url(request_url) else {
4        return "/".to_owned();
5    };
6
7    let pathname = trim_trailing_slashes(&pathname);
8    let base_path = trim_trailing_slashes(base_path);
9
10    if base_path == "/" {
11        return pathname;
12    }
13
14    if pathname == base_path {
15        return "/".to_owned();
16    }
17
18    let base_prefix = format!("{base_path}/");
19    if let Some(without_base_path) = pathname.strip_prefix(&base_prefix) {
20        trim_trailing_slashes(&format!("/{without_base_path}"))
21    } else {
22        pathname
23    }
24}
25
/// Reject unsafe `x-forwarded-proto` values (Better Auth `validateProxyHeader` parity).
///
/// Surrounding whitespace is trimmed, then the value must equal `http` or `https`,
/// compared ASCII-case-insensitively.
pub fn is_valid_forwarded_proto(proto: &str) -> bool {
    let candidate = proto.trim();
    ["http", "https"]
        .iter()
        .any(|allowed| candidate.eq_ignore_ascii_case(allowed))
}

31/// Reject unsafe `x-forwarded-host` / authority values used for base URL inference.
32pub fn is_valid_forwarded_host(host: &str) -> bool {
33    let host = host.trim();
34    if host.is_empty() || host.contains("..") || host.starts_with('.') {
35        return false;
36    }
37    if host
38        .bytes()
39        .any(|byte| byte == 0 || byte.is_ascii_whitespace() || byte == b'<' || byte == b'>')
40    {
41        return false;
42    }
43
44    let (name, port) = match split_host_and_port(host) {
45        Some(parts) => parts,
46        None => return false,
47    };
48
49    if let Some(port) = port {
50        if port.parse::<u16>().is_err() {
51            return false;
52        }
53    }
54
55    if name.starts_with('[') && name.ends_with(']') {
56        is_valid_ipv6_literal(&name[1..name.len() - 1])
57    } else {
58        is_valid_dns_or_ipv4_literal(name)
59    }
60}
61
/// Extract the path component of `request_url`, dropping any query string and fragment.
///
/// Two input forms are understood:
/// - an origin-form request target that starts with `/` (e.g. `/api/auth/session?x=1`);
/// - an absolute URL, `scheme://authority/path?query#fragment`.
///
/// Returns `None` when the input is in neither form, or when an absolute URL has no path
/// (e.g. `https://example.com` or `https://example.com?next=/admin`).
fn pathname_from_url(request_url: &str) -> Option<String> {
    // The path ends at the first `?` (query) or `#` (fragment), whichever comes first.
    let strip_query_and_fragment = |target: &str| -> String {
        let end = target
            .find(|c: char| matches!(c, '?' | '#'))
            .unwrap_or(target.len());
        target[..end].to_owned()
    };

    if request_url.starts_with('/') {
        return Some(strip_query_and_fragment(request_url));
    }

    let (scheme, after_scheme) = request_url.split_once("://")?;

    // RFC 3986 §3.1: scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." ). Without this check a
    // relative input such as `foo?next=http://evil/x` would be parsed at the `://` in its query.
    let mut scheme_chars = scheme.chars();
    let scheme_is_valid = matches!(scheme_chars.next(), Some(c) if c.is_ascii_alphabetic())
        && scheme_chars.all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'));
    if !scheme_is_valid {
        return None;
    }

    // The authority ends at the first `/`, `?` or `#`. Searching for `/` alone would pick up a
    // slash inside the query or fragment of a path-less URL such as `https://host?next=/x`.
    let authority_end = after_scheme.find(|c: char| matches!(c, '/' | '?' | '#'))?;
    let rest = &after_scheme[authority_end..];
    if !rest.starts_with('/') {
        return None;
    }

    Some(strip_query_and_fragment(rest))
}

/// Remove every trailing `/` from `path` and make sure the result begins with `/`.
///
/// An empty input, or one consisting only of slashes, becomes `/`.
fn trim_trailing_slashes(path: &str) -> String {
    match path.trim_end_matches('/') {
        "" => String::from("/"),
        rooted if rooted.starts_with('/') => rooted.to_string(),
        relative => format!("/{relative}"),
    }
}

/// Split a host value into its name and optional port.
///
/// A bracketed IPv6 literal (`[::1]:8080`) keeps its brackets in the returned name. For any
/// other host, the text after the last `:` is treated as a port only when it is all ASCII
/// digits; otherwise the whole value is returned as the name with no port.
///
/// Returns `None` for a malformed bracketed literal: a missing `]`, anything other than
/// `:port` after the `]`, or a port containing non-digit characters. An empty port
/// (`host:` or `[::1]:`) is returned as `Some("")` and left for the caller to reject.
fn split_host_and_port(host: &str) -> Option<(&str, Option<&str>)> {
    fn all_ascii_digits(s: &str) -> bool {
        s.bytes().all(|b| b.is_ascii_digit())
    }

    if host.starts_with('[') {
        let end = host.find(']')?;
        let (name, rest) = host.split_at(end + 1);
        if rest.is_empty() {
            return Some((name, None));
        }
        // Require digits, like the unbracketed branch: `u16::from_str` accepts a leading `+`,
        // so without this check `[::1]:+443` would pass the caller's port validation.
        let port = rest
            .strip_prefix(':')
            .filter(|port| all_ascii_digits(port))?;
        return Some((name, Some(port)));
    }

    match host.rsplit_once(':') {
        Some((name, port)) if all_ascii_digits(port) => Some((name, Some(port))),
        _ => Some((host, None)),
    }
}

/// Return whether `host` is non-empty and made only of ASCII letters, digits, `.`, `-` and `_`.
///
/// This is a character-set check only; label lengths and IPv4 octet ranges are not validated.
fn is_valid_dns_or_ipv4_literal(host: &str) -> bool {
    let permitted = |c: char| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_';
    !host.is_empty() && host.chars().all(permitted)
}

/// Return whether `host` is non-empty and uses only hex digits, `:`, `.`, `-` and `%`.
///
/// Intended for the text between the `[` and `]` of a bracketed literal. This is a
/// character-set check only; the address structure itself is not validated.
fn is_valid_ipv6_literal(host: &str) -> bool {
    if host.is_empty() {
        return false;
    }
    host.chars().all(|c| match c {
        ':' | '.' | '-' | '%' => true,
        other => other.is_ascii_hexdigit(),
    })
}